// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package tar implements access to tar archives.
//
// Tape archives (tar) are a file format for storing a sequence of files that
// can be read and written in a streaming manner.
// This package aims to cover most variations of the format,
// including those produced by GNU and BSD tar tools.
package tar
import (
"errors"
"fmt"
"internal/godebug"
"io/fs"
"math"
"path"
"reflect"
"strconv"
"strings"
"time"
)
// BUG: Use of the Uid and Gid fields in Header could overflow on 32-bit
// architectures. If a large value is encountered when decoding, the result
// stored in Header will be the truncated version.
var tarinsecurepath = godebug.New("tarinsecurepath")
var (
ErrHeader = errors.New("archive/tar: invalid tar header")
ErrWriteTooLong = errors.New("archive/tar: write too long")
ErrFieldTooLong = errors.New("archive/tar: header field too long")
ErrWriteAfterClose = errors.New("archive/tar: write after close")
ErrInsecurePath = errors.New("archive/tar: insecure file path")
errMissData = errors.New("archive/tar: sparse file references non-existent data")
errUnrefData = errors.New("archive/tar: sparse file contains unreferenced data")
errWriteHole = errors.New("archive/tar: write non-NUL byte in sparse hole")
)
type headerError []string
func (he headerError) Error() string {
const prefix = "archive/tar: cannot encode header"
var ss []string
for _, s := range he {
if s != "" {
ss = append(ss, s)
}
}
if len(ss) == 0 {
return prefix
}
return fmt.Sprintf("%s: %v", prefix, strings.Join(ss, "; and "))
}
// Type flags for Header.Typeflag.
const (
// Type '0' indicates a regular file.
TypeReg = '0'
// Deprecated: Use TypeReg instead.
TypeRegA = '\x00'
// Types '1' through '6' are header-only flags and may not have a data body.
TypeLink = '1' // Hard link
TypeSymlink = '2' // Symbolic link
TypeChar = '3' // Character device node
TypeBlock = '4' // Block device node
TypeDir = '5' // Directory
TypeFifo = '6' // FIFO node
// Type '7' is reserved.
TypeCont = '7'
// Type 'x' is used by the PAX format to store key-value records that
// are only relevant to the next file.
// This package transparently handles these types.
TypeXHeader = 'x'
// Type 'g' is used by the PAX format to store key-value records that
// are relevant to all subsequent files.
// This package only supports parsing and composing such headers,
// but does not currently support persisting the global state across files.
TypeXGlobalHeader = 'g'
// Type 'S' indicates a sparse file in the GNU format.
TypeGNUSparse = 'S'
// Types 'L' and 'K' are used by the GNU format for a meta file
// used to store the path or link name for the next file.
// This package transparently handles these types.
TypeGNULongName = 'L'
TypeGNULongLink = 'K'
)
// Keywords for PAX extended header records.
const (
paxNone = "" // Indicates that no PAX key is suitable
paxPath = "path"
paxLinkpath = "linkpath"
paxSize = "size"
paxUid = "uid"
paxGid = "gid"
paxUname = "uname"
paxGname = "gname"
paxMtime = "mtime"
paxAtime = "atime"
paxCtime = "ctime" // Removed from later revision of PAX spec, but was valid
paxCharset = "charset" // Currently unused
paxComment = "comment" // Currently unused
paxSchilyXattr = "SCHILY.xattr."
// Keywords for GNU sparse files in a PAX extended header.
paxGNUSparse = "GNU.sparse."
paxGNUSparseNumBlocks = "GNU.sparse.numblocks"
paxGNUSparseOffset = "GNU.sparse.offset"
paxGNUSparseNumBytes = "GNU.sparse.numbytes"
paxGNUSparseMap = "GNU.sparse.map"
paxGNUSparseName = "GNU.sparse.name"
paxGNUSparseMajor = "GNU.sparse.major"
paxGNUSparseMinor = "GNU.sparse.minor"
paxGNUSparseSize = "GNU.sparse.size"
paxGNUSparseRealSize = "GNU.sparse.realsize"
)
// basicKeys is a set of the PAX keys for which we have built-in support.
// This does not contain "charset" or "comment", which are both PAX-specific,
// so adding them as first-class features of Header is unlikely.
// Users can use the PAXRecords field to set them themselves.
var basicKeys = map[string]bool{
paxPath: true, paxLinkpath: true, paxSize: true, paxUid: true, paxGid: true,
paxUname: true, paxGname: true, paxMtime: true, paxAtime: true, paxCtime: true,
}
// A Header represents a single header in a tar archive.
// Some fields may not be populated.
//
// For forward compatibility, users that retrieve a Header from Reader.Next,
// mutate it in some ways, and then pass it back to Writer.WriteHeader
// should do so by creating a new Header and copying the fields
// that they are interested in preserving.
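//
// A minimal sketch of populating a Header by hand from a client's
// perspective (the file name and contents here are illustrative only):
//
//	data := []byte("hello world")
//	hdr := &tar.Header{
//		Name:    "hello.txt",
//		Mode:    0644,
//		Size:    int64(len(data)),
//		ModTime: time.Now(),
//	}
//	// Writer.WriteHeader(hdr) would then precede a Write of data.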
type Header struct {
// Typeflag is the type of header entry.
// The zero value is automatically promoted to either TypeReg or TypeDir
// depending on the presence of a trailing slash in Name.
Typeflag byte
Name string // Name of file entry
Linkname string // Target name of link (valid for TypeLink or TypeSymlink)
Size int64 // Logical file size in bytes
Mode int64 // Permission and mode bits
Uid int // User ID of owner
Gid int // Group ID of owner
Uname string // User name of owner
Gname string // Group name of owner
// If the Format is unspecified, then Writer.WriteHeader rounds ModTime
// to the nearest second and ignores the AccessTime and ChangeTime fields.
//
// To use AccessTime or ChangeTime, specify the Format as PAX or GNU.
// To use sub-second resolution, specify the Format as PAX.
ModTime time.Time // Modification time
AccessTime time.Time // Access time (requires either PAX or GNU support)
ChangeTime time.Time // Change time (requires either PAX or GNU support)
Devmajor int64 // Major device number (valid for TypeChar or TypeBlock)
Devminor int64 // Minor device number (valid for TypeChar or TypeBlock)
// Xattrs stores extended attributes as PAX records under the
// "SCHILY.xattr." namespace.
//
// The following are semantically equivalent:
// h.Xattrs[key] = value
// h.PAXRecords["SCHILY.xattr."+key] = value
//
// When Writer.WriteHeader is called, the contents of Xattrs will take
// precedence over those in PAXRecords.
//
// Deprecated: Use PAXRecords instead.
Xattrs map[string]string
// PAXRecords is a map of PAX extended header records.
//
// User-defined records should have keys of the following form:
// VENDOR.keyword
// Where VENDOR is some namespace in all uppercase, and keyword may
// not contain the '=' character (e.g., "GOLANG.pkg.version").
// The key and value should be non-empty UTF-8 strings.
//
// When Writer.WriteHeader is called, PAX records derived from the
// other fields in Header take precedence over PAXRecords.
PAXRecords map[string]string
// Format specifies the format of the tar header.
//
// This is set by Reader.Next as a best-effort guess at the format.
// Since the Reader liberally reads some non-compliant files,
// it is possible for this to be FormatUnknown.
//
// If the format is unspecified when Writer.WriteHeader is called,
// then it uses the first format (in the order of USTAR, PAX, GNU)
// capable of encoding this Header (see Format).
Format Format
}
// sparseEntry represents a Length-sized fragment at Offset in the file.
type sparseEntry struct{ Offset, Length int64 }
func (s sparseEntry) endOffset() int64 { return s.Offset + s.Length }
// A sparse file can be represented as either a sparseDatas or a sparseHoles.
// As long as the total size is known, they are equivalent and one can be
// converted to the other form and back. The various tar formats with sparse
// file support represent sparse files in the sparseDatas form. That is, they
// specify the fragments in the file that contain data, and treat everything else as
// having zero bytes. As such, the encoding and decoding logic in this package
// deals with sparseDatas.
//
// However, the external API uses sparseHoles instead of sparseDatas because the
// zero value of sparseHoles logically represents a normal file (i.e., there are
// no holes in it). On the other hand, the zero value of sparseDatas implies
// that the file has no data in it, which is rather odd.
//
// As an example, if the underlying raw file contains the 8-byte data:
//
// var compactFile = "abcdefgh"
//
// And the sparse map has the following entries:
//
// var spd sparseDatas = []sparseEntry{
// {Offset: 2, Length: 5}, // Data fragment for 2..6
// {Offset: 18, Length: 3}, // Data fragment for 18..20
// }
// var sph sparseHoles = []sparseEntry{
// {Offset: 0, Length: 2}, // Hole fragment for 0..1
// {Offset: 7, Length: 11}, // Hole fragment for 7..17
// {Offset: 21, Length: 4}, // Hole fragment for 21..24
// }
//
// Then the content of the resulting sparse file with a Header.Size of 25 is:
//
// var sparseFile = "\x00"*2 + "abcde" + "\x00"*11 + "fgh" + "\x00"*4
type (
sparseDatas []sparseEntry
sparseHoles []sparseEntry
)
// validateSparseEntries reports whether sp is a valid sparse map.
// It does not matter whether sp represents data fragments or hole fragments.
func validateSparseEntries(sp []sparseEntry, size int64) bool {
// Validate all sparse entries. These are the same checks as performed by
// the BSD tar utility.
if size < 0 {
return false
}
var pre sparseEntry
for _, cur := range sp {
switch {
case cur.Offset < 0 || cur.Length < 0:
return false // Negative values are never okay
case cur.Offset > math.MaxInt64-cur.Length:
return false // Integer overflow with large length
case cur.endOffset() > size:
return false // Region extends beyond the actual size
case pre.endOffset() > cur.Offset:
return false // Regions cannot overlap and must be in order
}
pre = cur
}
return true
}
// alignSparseEntries mutates src and returns dst where each fragment's
// starting offset is aligned up to the nearest block edge, and each
// ending offset is aligned down to the nearest block edge.
//
// Even though the Go tar Reader and the BSD tar utility can handle entries
// with arbitrary offsets and lengths, the GNU tar utility can only handle
// offsets and lengths that are multiples of blockSize.
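//
// For example (a sketch, assuming blockSize is 512 and a total size of 2000):
//
//	src := []sparseEntry{{Offset: 100, Length: 1000}} // Spans bytes 100..1099
//	dst := alignSparseEntries(src, 2000)
//	// dst is {{Offset: 512, Length: 512}}: offset 100 is rounded up to 512,
//	// and the ending offset 1100 is rounded down to 1024.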
func alignSparseEntries(src []sparseEntry, size int64) []sparseEntry {
dst := src[:0]
for _, s := range src {
pos, end := s.Offset, s.endOffset()
pos += blockPadding(+pos) // Round-up to nearest blockSize
if end != size {
end -= blockPadding(-end) // Round-down to nearest blockSize
}
if pos < end {
dst = append(dst, sparseEntry{Offset: pos, Length: end - pos})
}
}
return dst
}
// invertSparseEntries converts a sparse map from one form to the other.
// If the input is sparseHoles, then it will output sparseDatas and vice-versa.
// The input must have been already validated.
//
// This function mutates src and returns a normalized map where:
// - adjacent fragments are coalesced together
// - only the last fragment may be empty
// - the endOffset of the last fragment is the total size
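//
// For example, inverting the sparseDatas from the example above
// (a sketch, assuming a total size of 25):
//
//	spd := []sparseEntry{{Offset: 2, Length: 5}, {Offset: 18, Length: 3}}
//	sph := invertSparseEntries(spd, 25)
//	// sph is {{Offset: 0, Length: 2}, {Offset: 7, Length: 11}, {Offset: 21, Length: 4}}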
func invertSparseEntries(src []sparseEntry, size int64) []sparseEntry {
dst := src[:0]
var pre sparseEntry
for _, cur := range src {
if cur.Length == 0 {
continue // Skip empty fragments
}
pre.Length = cur.Offset - pre.Offset
if pre.Length > 0 {
dst = append(dst, pre) // Only add non-empty fragments
}
pre.Offset = cur.endOffset()
}
pre.Length = size - pre.Offset // Possibly the only empty fragment
return append(dst, pre)
}
// fileState tracks the number of logical (includes sparse holes) and physical
// (actual in tar archive) bytes remaining for the current file.
//
// Invariant: logicalRemaining >= physicalRemaining
type fileState interface {
logicalRemaining() int64
physicalRemaining() int64
}
// allowedFormats determines which formats can be used.
// The value returned is the logical OR of multiple possible formats.
// If the value is FormatUnknown, then the input Header cannot be encoded
// and an error is returned explaining why.
//
// As a by-product of checking the fields, this function returns paxHdrs, which
// contains all fields that could not be directly encoded.
// A value receiver ensures that this method does not mutate the source Header.
func (h Header) allowedFormats() (format Format, paxHdrs map[string]string, err error) {
format = FormatUSTAR | FormatPAX | FormatGNU
paxHdrs = make(map[string]string)
var whyNoUSTAR, whyNoPAX, whyNoGNU string
var preferPAX bool // Prefer PAX over USTAR
verifyString := func(s string, size int, name, paxKey string) {
// NUL-terminator is optional for path and linkpath.
// Technically, it is required for uname and gname,
// but neither GNU nor BSD tar checks for it.
tooLong := len(s) > size
allowLongGNU := paxKey == paxPath || paxKey == paxLinkpath
if hasNUL(s) || (tooLong && !allowLongGNU) {
whyNoGNU = fmt.Sprintf("GNU cannot encode %s=%q", name, s)
format.mustNotBe(FormatGNU)
}
if !isASCII(s) || tooLong {
canSplitUSTAR := paxKey == paxPath
if _, _, ok := splitUSTARPath(s); !canSplitUSTAR || !ok {
whyNoUSTAR = fmt.Sprintf("USTAR cannot encode %s=%q", name, s)
format.mustNotBe(FormatUSTAR)
}
if paxKey == paxNone {
whyNoPAX = fmt.Sprintf("PAX cannot encode %s=%q", name, s)
format.mustNotBe(FormatPAX)
} else {
paxHdrs[paxKey] = s
}
}
if v, ok := h.PAXRecords[paxKey]; ok && v == s {
paxHdrs[paxKey] = v
}
}
verifyNumeric := func(n int64, size int, name, paxKey string) {
if !fitsInBase256(size, n) {
whyNoGNU = fmt.Sprintf("GNU cannot encode %s=%d", name, n)
format.mustNotBe(FormatGNU)
}
if !fitsInOctal(size, n) {
whyNoUSTAR = fmt.Sprintf("USTAR cannot encode %s=%d", name, n)
format.mustNotBe(FormatUSTAR)
if paxKey == paxNone {
whyNoPAX = fmt.Sprintf("PAX cannot encode %s=%d", name, n)
format.mustNotBe(FormatPAX)
} else {
paxHdrs[paxKey] = strconv.FormatInt(n, 10)
}
}
if v, ok := h.PAXRecords[paxKey]; ok && v == strconv.FormatInt(n, 10) {
paxHdrs[paxKey] = v
}
}
verifyTime := func(ts time.Time, size int, name, paxKey string) {
if ts.IsZero() {
return // Always okay
}
if !fitsInBase256(size, ts.Unix()) {
whyNoGNU = fmt.Sprintf("GNU cannot encode %s=%v", name, ts)
format.mustNotBe(FormatGNU)
}
isMtime := paxKey == paxMtime
fitsOctal := fitsInOctal(size, ts.Unix())
if !isMtime || !fitsOctal {
whyNoUSTAR = fmt.Sprintf("USTAR cannot encode %s=%v", name, ts)
format.mustNotBe(FormatUSTAR)
}
needsNano := ts.Nanosecond() != 0
if !isMtime || !fitsOctal || needsNano {
preferPAX = true // USTAR may truncate sub-second measurements
if paxKey == paxNone {
whyNoPAX = fmt.Sprintf("PAX cannot encode %s=%v", name, ts)
format.mustNotBe(FormatPAX)
} else {
paxHdrs[paxKey] = formatPAXTime(ts)
}
}
if v, ok := h.PAXRecords[paxKey]; ok && v == formatPAXTime(ts) {
paxHdrs[paxKey] = v
}
}
// Check basic fields.
var blk block
v7 := blk.toV7()
ustar := blk.toUSTAR()
gnu := blk.toGNU()
verifyString(h.Name, len(v7.name()), "Name", paxPath)
verifyString(h.Linkname, len(v7.linkName()), "Linkname", paxLinkpath)
verifyString(h.Uname, len(ustar.userName()), "Uname", paxUname)
verifyString(h.Gname, len(ustar.groupName()), "Gname", paxGname)
verifyNumeric(h.Mode, len(v7.mode()), "Mode", paxNone)
verifyNumeric(int64(h.Uid), len(v7.uid()), "Uid", paxUid)
verifyNumeric(int64(h.Gid), len(v7.gid()), "Gid", paxGid)
verifyNumeric(h.Size, len(v7.size()), "Size", paxSize)
verifyNumeric(h.Devmajor, len(ustar.devMajor()), "Devmajor", paxNone)
verifyNumeric(h.Devminor, len(ustar.devMinor()), "Devminor", paxNone)
verifyTime(h.ModTime, len(v7.modTime()), "ModTime", paxMtime)
verifyTime(h.AccessTime, len(gnu.accessTime()), "AccessTime", paxAtime)
verifyTime(h.ChangeTime, len(gnu.changeTime()), "ChangeTime", paxCtime)
// Check for header-only types.
var whyOnlyPAX, whyOnlyGNU string
switch h.Typeflag {
case TypeReg, TypeChar, TypeBlock, TypeFifo, TypeGNUSparse:
// Exclude TypeLink and TypeSymlink, since they may reference directories.
if strings.HasSuffix(h.Name, "/") {
return FormatUnknown, nil, headerError{"filename may not have trailing slash"}
}
case TypeXHeader, TypeGNULongName, TypeGNULongLink:
return FormatUnknown, nil, headerError{"cannot manually encode TypeXHeader, TypeGNULongName, or TypeGNULongLink headers"}
case TypeXGlobalHeader:
h2 := Header{Name: h.Name, Typeflag: h.Typeflag, Xattrs: h.Xattrs, PAXRecords: h.PAXRecords, Format: h.Format}
if !reflect.DeepEqual(h, h2) {
return FormatUnknown, nil, headerError{"only PAXRecords should be set for TypeXGlobalHeader"}
}
whyOnlyPAX = "only PAX supports TypeXGlobalHeader"
format.mayOnlyBe(FormatPAX)
}
if !isHeaderOnlyType(h.Typeflag) && h.Size < 0 {
return FormatUnknown, nil, headerError{"negative size on non-header-only type"}
}
// Check PAX records.
if len(h.Xattrs) > 0 {
for k, v := range h.Xattrs {
paxHdrs[paxSchilyXattr+k] = v
}
whyOnlyPAX = "only PAX supports Xattrs"
format.mayOnlyBe(FormatPAX)
}
if len(h.PAXRecords) > 0 {
for k, v := range h.PAXRecords {
switch _, exists := paxHdrs[k]; {
case exists:
continue // Do not overwrite existing records
case h.Typeflag == TypeXGlobalHeader:
paxHdrs[k] = v // Copy all records
case !basicKeys[k] && !strings.HasPrefix(k, paxGNUSparse):
paxHdrs[k] = v // Ignore local records that may conflict
}
}
whyOnlyPAX = "only PAX supports PAXRecords"
format.mayOnlyBe(FormatPAX)
}
for k, v := range paxHdrs {
if !validPAXRecord(k, v) {
return FormatUnknown, nil, headerError{fmt.Sprintf("invalid PAX record: %q", k+" = "+v)}
}
}
// TODO(dsnet): Re-enable this when adding sparse support.
// See https://golang.org/issue/22735
/*
// Check sparse files.
if len(h.SparseHoles) > 0 || h.Typeflag == TypeGNUSparse {
if isHeaderOnlyType(h.Typeflag) {
return FormatUnknown, nil, headerError{"header-only type cannot be sparse"}
}
if !validateSparseEntries(h.SparseHoles, h.Size) {
return FormatUnknown, nil, headerError{"invalid sparse holes"}
}
if h.Typeflag == TypeGNUSparse {
whyOnlyGNU = "only GNU supports TypeGNUSparse"
format.mayOnlyBe(FormatGNU)
} else {
whyNoGNU = "GNU supports sparse files only with TypeGNUSparse"
format.mustNotBe(FormatGNU)
}
whyNoUSTAR = "USTAR does not support sparse files"
format.mustNotBe(FormatUSTAR)
}
*/
// Check desired format.
if wantFormat := h.Format; wantFormat != FormatUnknown {
if wantFormat.has(FormatPAX) && !preferPAX {
wantFormat.mayBe(FormatUSTAR) // PAX implies USTAR allowed too
}
format.mayOnlyBe(wantFormat) // Set union of formats allowed and format wanted
}
if format == FormatUnknown {
switch h.Format {
case FormatUSTAR:
err = headerError{"Format specifies USTAR", whyNoUSTAR, whyOnlyPAX, whyOnlyGNU}
case FormatPAX:
err = headerError{"Format specifies PAX", whyNoPAX, whyOnlyGNU}
case FormatGNU:
err = headerError{"Format specifies GNU", whyNoGNU, whyOnlyPAX}
default:
err = headerError{whyNoUSTAR, whyNoPAX, whyNoGNU, whyOnlyPAX, whyOnlyGNU}
}
}
return format, paxHdrs, err
}
// FileInfo returns an fs.FileInfo for the Header.
func (h *Header) FileInfo() fs.FileInfo {
return headerFileInfo{h}
}
// headerFileInfo implements fs.FileInfo.
type headerFileInfo struct {
h *Header
}
func (fi headerFileInfo) Size() int64 { return fi.h.Size }
func (fi headerFileInfo) IsDir() bool { return fi.Mode().IsDir() }
func (fi headerFileInfo) ModTime() time.Time { return fi.h.ModTime }
func (fi headerFileInfo) Sys() any { return fi.h }
// Name returns the base name of the file.
func (fi headerFileInfo) Name() string {
if fi.IsDir() {
return path.Base(path.Clean(fi.h.Name))
}
return path.Base(fi.h.Name)
}
// Mode returns the permission and mode bits for the headerFileInfo.
func (fi headerFileInfo) Mode() (mode fs.FileMode) {
// Set file permission bits.
mode = fs.FileMode(fi.h.Mode).Perm()
// Set setuid, setgid and sticky bits.
if fi.h.Mode&c_ISUID != 0 {
mode |= fs.ModeSetuid
}
if fi.h.Mode&c_ISGID != 0 {
mode |= fs.ModeSetgid
}
if fi.h.Mode&c_ISVTX != 0 {
mode |= fs.ModeSticky
}
// Set file mode bits; clear perm, setuid, setgid, and sticky bits.
switch m := fs.FileMode(fi.h.Mode) &^ 07777; m {
case c_ISDIR:
mode |= fs.ModeDir
case c_ISFIFO:
mode |= fs.ModeNamedPipe
case c_ISLNK:
mode |= fs.ModeSymlink
case c_ISBLK:
mode |= fs.ModeDevice
case c_ISCHR:
mode |= fs.ModeDevice
mode |= fs.ModeCharDevice
case c_ISSOCK:
mode |= fs.ModeSocket
}
switch fi.h.Typeflag {
case TypeSymlink:
mode |= fs.ModeSymlink
case TypeChar:
mode |= fs.ModeDevice
mode |= fs.ModeCharDevice
case TypeBlock:
mode |= fs.ModeDevice
case TypeDir:
mode |= fs.ModeDir
case TypeFifo:
mode |= fs.ModeNamedPipe
}
return mode
}
// sysStat, if non-nil, populates h from system-dependent fields of fi.
var sysStat func(fi fs.FileInfo, h *Header) error
const (
// Mode constants from the USTAR spec:
// See http://pubs.opengroup.org/onlinepubs/9699919799/utilities/pax.html#tag_20_92_13_06
c_ISUID = 04000 // Set uid
c_ISGID = 02000 // Set gid
c_ISVTX = 01000 // Save text (sticky bit)
// Common Unix mode constants; these are not defined in any common tar standard.
// Header.FileInfo understands these, but FileInfoHeader will never produce these.
c_ISDIR = 040000 // Directory
c_ISFIFO = 010000 // FIFO
c_ISREG = 0100000 // Regular file
c_ISLNK = 0120000 // Symbolic link
c_ISBLK = 060000 // Block special file
c_ISCHR = 020000 // Character special file
c_ISSOCK = 0140000 // Socket
)
// FileInfoHeader creates a partially-populated Header from fi.
// If fi describes a symlink, FileInfoHeader records link as the link target.
// If fi describes a directory, a slash is appended to the name.
//
// Since fs.FileInfo's Name method only returns the base name of
// the file it describes, it may be necessary to modify Header.Name
// to provide the full path name of the file.
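//
// A common pattern when archiving a file from disk (a sketch; fi is assumed
// to come from os.Stat, and fullPath is a hypothetical archive-relative path):
//
//	hdr, err := tar.FileInfoHeader(fi, "")
//	if err != nil {
//		return err
//	}
//	hdr.Name = fullPath // fi.Name() yields only the base name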
func FileInfoHeader(fi fs.FileInfo, link string) (*Header, error) {
if fi == nil {
return nil, errors.New("archive/tar: FileInfo is nil")
}
fm := fi.Mode()
h := &Header{
Name: fi.Name(),
ModTime: fi.ModTime(),
Mode: int64(fm.Perm()), // or'd with c_IS* constants later
}
switch {
case fm.IsRegular():
h.Typeflag = TypeReg
h.Size = fi.Size()
case fi.IsDir():
h.Typeflag = TypeDir
h.Name += "/"
case fm&fs.ModeSymlink != 0:
h.Typeflag = TypeSymlink
h.Linkname = link
case fm&fs.ModeDevice != 0:
if fm&fs.ModeCharDevice != 0 {
h.Typeflag = TypeChar
} else {
h.Typeflag = TypeBlock
}
case fm&fs.ModeNamedPipe != 0:
h.Typeflag = TypeFifo
case fm&fs.ModeSocket != 0:
return nil, fmt.Errorf("archive/tar: sockets not supported")
default:
return nil, fmt.Errorf("archive/tar: unknown file mode %v", fm)
}
if fm&fs.ModeSetuid != 0 {
h.Mode |= c_ISUID
}
if fm&fs.ModeSetgid != 0 {
h.Mode |= c_ISGID
}
if fm&fs.ModeSticky != 0 {
h.Mode |= c_ISVTX
}
// If possible, populate additional fields from OS-specific
// FileInfo fields.
if sys, ok := fi.Sys().(*Header); ok {
// This FileInfo came from a Header (not the OS). Use the
// original Header to populate all remaining fields.
h.Uid = sys.Uid
h.Gid = sys.Gid
h.Uname = sys.Uname
h.Gname = sys.Gname
h.AccessTime = sys.AccessTime
h.ChangeTime = sys.ChangeTime
if sys.Xattrs != nil {
h.Xattrs = make(map[string]string)
for k, v := range sys.Xattrs {
h.Xattrs[k] = v
}
}
if sys.Typeflag == TypeLink {
// hard link
h.Typeflag = TypeLink
h.Size = 0
h.Linkname = sys.Linkname
}
if sys.PAXRecords != nil {
h.PAXRecords = make(map[string]string)
for k, v := range sys.PAXRecords {
h.PAXRecords[k] = v
}
}
}
if sysStat != nil {
return h, sysStat(fi, h)
}
return h, nil
}
// isHeaderOnlyType checks if the given type flag is of the type that has no
// data section even if a size is specified.
func isHeaderOnlyType(flag byte) bool {
switch flag {
case TypeLink, TypeSymlink, TypeChar, TypeBlock, TypeDir, TypeFifo:
return true
default:
return false
}
}
func min(a, b int64) int64 {
if a < b {
return a
}
return b
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package tar
import "strings"
// Format represents the tar archive format.
//
// The original tar format was introduced in Unix V7.
// Since then, there have been multiple competing formats attempting to
// standardize or extend the V7 format to overcome its limitations.
// The most common formats are the USTAR, PAX, and GNU formats,
// each with their own advantages and limitations.
//
// The following table captures the capabilities of each format:
//
//	                  |  USTAR |       PAX |       GNU
//	------------------+--------+-----------+----------
//	Name              |   256B | unlimited | unlimited
//	Linkname          |   100B | unlimited | unlimited
//	Size              | uint33 | unlimited |    uint89
//	Mode              | uint21 |    uint21 |    uint57
//	Uid/Gid           | uint21 | unlimited |    uint57
//	Uname/Gname       |    32B | unlimited |       32B
//	ModTime           | uint33 | unlimited |     int89
//	AccessTime        |    n/a | unlimited |     int89
//	ChangeTime        |    n/a | unlimited |     int89
//	Devmajor/Devminor | uint21 |    uint21 |    uint57
//	------------------+--------+-----------+----------
//	string encoding   |  ASCII |     UTF-8 |    binary
//	sub-second times  |     no |       yes |        no
//	sparse files      |     no |       yes |       yes
//
// The table's upper portion shows the Header fields, where each format reports
// the maximum number of bytes allowed for each string field and
// the integer type used to store each numeric field
// (where timestamps are stored as the number of seconds since the Unix epoch).
//
// The table's lower portion shows specialized features of each format,
// such as supported string encodings, support for sub-second timestamps,
// or support for sparse files.
//
// The Writer currently provides no support for sparse files.
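//
// To request a specific format when writing, set Header.Format before calling
// Writer.WriteHeader (a sketch):
//
//	hdr.Format = tar.FormatPAX // WriteHeader fails if PAX cannot encode hdr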
type Format int
// Constants to identify various tar formats.
const (
// Deliberately hide the meaning of constants from public API.
_ Format = (1 << iota) / 4 // Sequence of 0, 0, 1, 2, 4, 8, etc...
// FormatUnknown indicates that the format is unknown.
FormatUnknown
// The format of the original Unix V7 tar tool prior to standardization.
formatV7
// FormatUSTAR represents the USTAR header format defined in POSIX.1-1988.
//
// While this format is compatible with most tar readers,
// the format has several limitations making it unsuitable for some usages.
// Most notably, it cannot support sparse files, files larger than 8GiB,
// filenames larger than 256 characters, and non-ASCII filenames.
//
// Reference:
// http://pubs.opengroup.org/onlinepubs/9699919799/utilities/pax.html#tag_20_92_13_06
FormatUSTAR
// FormatPAX represents the PAX header format defined in POSIX.1-2001.
//
// PAX extends USTAR by writing a special file with Typeflag TypeXHeader
// preceding the original header. This file contains a set of key-value
// records, which are used to overcome USTAR's shortcomings, in addition to
// providing the ability to have sub-second resolution for timestamps.
//
// Some newer formats add their own extensions to PAX by defining their
// own keys and assigning certain semantic meaning to the associated values.
// For example, sparse file support in PAX is implemented using keys
// defined by the GNU manual (e.g., "GNU.sparse.map").
//
// Reference:
// http://pubs.opengroup.org/onlinepubs/009695399/utilities/pax.html
FormatPAX
// FormatGNU represents the GNU header format.
//
// The GNU header format is older than the USTAR and PAX standards and
// is not compatible with them. The GNU format supports
// arbitrary file sizes, filenames of arbitrary encoding and length,
// sparse files, and other features.
//
// It is recommended that PAX be chosen over GNU unless the target
// application can only parse GNU formatted archives.
//
// Reference:
// https://www.gnu.org/software/tar/manual/html_node/Standard.html
FormatGNU
// Schily's tar format, which is incompatible with USTAR.
// This does not cover STAR extensions to the PAX format; these fall under
// the PAX format.
formatSTAR
formatMax
)
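// Bit-set helpers over Format values: has reports whether f contains f2;
// mayBe adds f2 to the set, mayOnlyBe intersects the set with f2, and
// mustNotBe removes f2 from the set.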
func (f Format) has(f2 Format) bool { return f&f2 != 0 }
func (f *Format) mayBe(f2 Format) { *f |= f2 }
func (f *Format) mayOnlyBe(f2 Format) { *f &= f2 }
func (f *Format) mustNotBe(f2 Format) { *f &^= f2 }
var formatNames = map[Format]string{
formatV7: "V7", FormatUSTAR: "USTAR", FormatPAX: "PAX", FormatGNU: "GNU", formatSTAR: "STAR",
}
func (f Format) String() string {
var ss []string
for f2 := Format(1); f2 < formatMax; f2 <<= 1 {
if f.has(f2) {
ss = append(ss, formatNames[f2])
}
}
switch len(ss) {
case 0:
return "<unknown>"
case 1:
return ss[0]
default:
return "(" + strings.Join(ss, " | ") + ")"
}
}
// Magics used to identify various formats.
const (
magicGNU, versionGNU = "ustar ", " \x00"
magicUSTAR, versionUSTAR = "ustar\x00", "00"
trailerSTAR = "tar\x00"
)
// Size constants from various tar specifications.
const (
blockSize = 512 // Size of each block in a tar stream
nameSize = 100 // Max length of the name field in USTAR format
prefixSize = 155 // Max length of the prefix field in USTAR format
// Max length of a special file (PAX header, GNU long name or link).
// This matches the limit used by libarchive.
maxSpecialFileSize = 1 << 20
)
// blockPadding computes the number of bytes needed to pad offset up to the
// nearest block edge where 0 <= n < blockSize.
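//
// For example (a sketch; blockSize is 512):
//
//	blockPadding(612) // 412, since 612+412 == 1024
//	blockPadding(512) // 0, already block-aligned
//
// The bitwise form -offset & (blockSize - 1) is valid because blockSize
// is a power of two.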
func blockPadding(offset int64) (n int64) {
return -offset & (blockSize - 1)
}
var zeroBlock block
type block [blockSize]byte
// Convert block to any number of formats.
func (b *block) toV7() *headerV7 { return (*headerV7)(b) }
func (b *block) toGNU() *headerGNU { return (*headerGNU)(b) }
func (b *block) toSTAR() *headerSTAR { return (*headerSTAR)(b) }
func (b *block) toUSTAR() *headerUSTAR { return (*headerUSTAR)(b) }
func (b *block) toSparse() sparseArray { return sparseArray(b[:]) }
// getFormat checks that the block is a valid tar header based on the checksum.
// It then attempts to guess the specific format based on magic values.
// If the checksum fails, then FormatUnknown is returned.
func (b *block) getFormat() Format {
// Verify checksum.
var p parser
value := p.parseOctal(b.toV7().chksum())
chksum1, chksum2 := b.computeChecksum()
if p.err != nil || (value != chksum1 && value != chksum2) {
return FormatUnknown
}
// Guess the magic values.
magic := string(b.toUSTAR().magic())
version := string(b.toUSTAR().version())
trailer := string(b.toSTAR().trailer())
switch {
case magic == magicUSTAR && trailer == trailerSTAR:
return formatSTAR
case magic == magicUSTAR:
return FormatUSTAR | FormatPAX
case magic == magicGNU && version == versionGNU:
return FormatGNU
default:
return formatV7
}
}
// setFormat writes the magic values necessary for the specified format
// and then updates the checksum accordingly.
func (b *block) setFormat(format Format) {
// Set the magic values.
switch {
case format.has(formatV7):
// Do nothing.
case format.has(FormatGNU):
copy(b.toGNU().magic(), magicGNU)
copy(b.toGNU().version(), versionGNU)
case format.has(formatSTAR):
copy(b.toSTAR().magic(), magicUSTAR)
copy(b.toSTAR().version(), versionUSTAR)
copy(b.toSTAR().trailer(), trailerSTAR)
case format.has(FormatUSTAR | FormatPAX):
copy(b.toUSTAR().magic(), magicUSTAR)
copy(b.toUSTAR().version(), versionUSTAR)
default:
panic("invalid format")
}
// Update checksum.
// This field is special in that it is terminated by a NUL then space.
var f formatter
field := b.toV7().chksum()
chksum, _ := b.computeChecksum() // Possible values are 256..128776
f.formatOctal(field[:7], chksum) // Never fails since 128776 < 262143
field[7] = ' '
}
// computeChecksum computes the checksum for the header block.
// POSIX specifies a sum of the unsigned byte values, but the Sun tar used
// signed byte values.
// We compute and return both.
func (b *block) computeChecksum() (unsigned, signed int64) {
for i, c := range b {
if 148 <= i && i < 156 {
c = ' ' // Treat the checksum field itself as all spaces.
}
unsigned += int64(c)
signed += int64(int8(c))
}
return unsigned, signed
}
// reset clears the block with all zeros.
func (b *block) reset() {
*b = block{}
}
type headerV7 [blockSize]byte
func (h *headerV7) name() []byte { return h[000:][:100] }
func (h *headerV7) mode() []byte { return h[100:][:8] }
func (h *headerV7) uid() []byte { return h[108:][:8] }
func (h *headerV7) gid() []byte { return h[116:][:8] }
func (h *headerV7) size() []byte { return h[124:][:12] }
func (h *headerV7) modTime() []byte { return h[136:][:12] }
func (h *headerV7) chksum() []byte { return h[148:][:8] }
func (h *headerV7) typeFlag() []byte { return h[156:][:1] }
func (h *headerV7) linkName() []byte { return h[157:][:100] }
type headerGNU [blockSize]byte
func (h *headerGNU) v7() *headerV7 { return (*headerV7)(h) }
func (h *headerGNU) magic() []byte { return h[257:][:6] }
func (h *headerGNU) version() []byte { return h[263:][:2] }
func (h *headerGNU) userName() []byte { return h[265:][:32] }
func (h *headerGNU) groupName() []byte { return h[297:][:32] }
func (h *headerGNU) devMajor() []byte { return h[329:][:8] }
func (h *headerGNU) devMinor() []byte { return h[337:][:8] }
func (h *headerGNU) accessTime() []byte { return h[345:][:12] }
func (h *headerGNU) changeTime() []byte { return h[357:][:12] }
func (h *headerGNU) sparse() sparseArray { return sparseArray(h[386:][:24*4+1]) }
func (h *headerGNU) realSize() []byte { return h[483:][:12] }
type headerSTAR [blockSize]byte
func (h *headerSTAR) v7() *headerV7 { return (*headerV7)(h) }
func (h *headerSTAR) magic() []byte { return h[257:][:6] }
func (h *headerSTAR) version() []byte { return h[263:][:2] }
func (h *headerSTAR) userName() []byte { return h[265:][:32] }
func (h *headerSTAR) groupName() []byte { return h[297:][:32] }
func (h *headerSTAR) devMajor() []byte { return h[329:][:8] }
func (h *headerSTAR) devMinor() []byte { return h[337:][:8] }
func (h *headerSTAR) prefix() []byte { return h[345:][:131] }
func (h *headerSTAR) accessTime() []byte { return h[476:][:12] }
func (h *headerSTAR) changeTime() []byte { return h[488:][:12] }
func (h *headerSTAR) trailer() []byte { return h[508:][:4] }
type headerUSTAR [blockSize]byte
func (h *headerUSTAR) v7() *headerV7 { return (*headerV7)(h) }
func (h *headerUSTAR) magic() []byte { return h[257:][:6] }
func (h *headerUSTAR) version() []byte { return h[263:][:2] }
func (h *headerUSTAR) userName() []byte { return h[265:][:32] }
func (h *headerUSTAR) groupName() []byte { return h[297:][:32] }
func (h *headerUSTAR) devMajor() []byte { return h[329:][:8] }
func (h *headerUSTAR) devMinor() []byte { return h[337:][:8] }
func (h *headerUSTAR) prefix() []byte { return h[345:][:155] }
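// sparseArray overlays the sparse-entry region of an old GNU header:
// a sequence of 24-byte entries (a 12-byte offset and a 12-byte length),
// followed by a single isExtended byte.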
type sparseArray []byte
func (s sparseArray) entry(i int) sparseElem { return sparseElem(s[i*24:]) }
func (s sparseArray) isExtended() []byte { return s[24*s.maxEntries():][:1] }
func (s sparseArray) maxEntries() int { return len(s) / 24 }
type sparseElem []byte
func (s sparseElem) offset() []byte { return s[00:][:12] }
func (s sparseElem) length() []byte { return s[12:][:12] }
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package tar
import (
"bytes"
"io"
"path/filepath"
"strconv"
"strings"
"time"
)
// Reader provides sequential access to the contents of a tar archive.
// Reader.Next advances to the next file in the archive (including the first),
// and then Reader can be treated as an io.Reader to access the file's data.
type Reader struct {
r io.Reader
pad int64 // Amount of padding (ignored) after current file entry
curr fileReader // Reader for current file entry
blk block // Buffer to use as temporary local storage
// err is a persistent error.
// Only the exported methods of Reader are responsible for
// ensuring that this error is sticky.
err error
}
type fileReader interface {
io.Reader
fileState
WriteTo(io.Writer) (int64, error)
}
// NewReader creates a new Reader reading from r.
func NewReader(r io.Reader) *Reader {
return &Reader{r: r, curr: &regFileReader{r, 0}}
}
// Next advances to the next entry in the tar archive.
// The Header.Size determines how many bytes can be read for the next file.
// Any remaining data in the current file is automatically discarded.
// At the end of the archive, Next returns the error io.EOF.
//
// If Next encounters a non-local name (as defined by [filepath.IsLocal])
// and the GODEBUG environment variable contains `tarinsecurepath=0`,
// Next returns the header with an ErrInsecurePath error.
// A future version of Go may introduce this behavior by default.
// Programs that want to accept non-local names can ignore
// the ErrInsecurePath error and use the returned header.
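//
// A typical iteration over an archive (a sketch; f is assumed to be an
// open io.Reader such as an *os.File, and error handling is abbreviated):
//
//	tr := tar.NewReader(f)
//	for {
//		hdr, err := tr.Next()
//		if err == io.EOF {
//			break // End of archive
//		}
//		if err != nil {
//			return err
//		}
//		fmt.Println(hdr.Name)
//		io.Copy(io.Discard, tr) // Consume the file body
//	}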
func (tr *Reader) Next() (*Header, error) {
if tr.err != nil {
return nil, tr.err
}
hdr, err := tr.next()
tr.err = err
if err == nil && !filepath.IsLocal(hdr.Name) {
if tarinsecurepath.Value() == "0" {
tarinsecurepath.IncNonDefault()
err = ErrInsecurePath
}
}
return hdr, err
}
func (tr *Reader) next() (*Header, error) {
var paxHdrs map[string]string
var gnuLongName, gnuLongLink string
// Externally, Next iterates through the tar archive as if it is a series of
// files. Internally, the tar format often uses fake "files" to add metadata
// that describes the next file. These metadata "files" should not normally
// be visible to the outside. As such, this loop iterates through one or
// more "header files" until it finds a "normal file".
format := FormatUSTAR | FormatPAX | FormatGNU
for {
// Discard the remainder of the file and any padding.
if err := discard(tr.r, tr.curr.physicalRemaining()); err != nil {
return nil, err
}
if _, err := tryReadFull(tr.r, tr.blk[:tr.pad]); err != nil {
return nil, err
}
tr.pad = 0
hdr, rawHdr, err := tr.readHeader()
if err != nil {
return nil, err
}
if err := tr.handleRegularFile(hdr); err != nil {
return nil, err
}
format.mayOnlyBe(hdr.Format)
// Check for PAX/GNU special headers and files.
switch hdr.Typeflag {
case TypeXHeader, TypeXGlobalHeader:
format.mayOnlyBe(FormatPAX)
paxHdrs, err = parsePAX(tr)
if err != nil {
return nil, err
}
if hdr.Typeflag == TypeXGlobalHeader {
mergePAX(hdr, paxHdrs)
return &Header{
Name: hdr.Name,
Typeflag: hdr.Typeflag,
Xattrs: hdr.Xattrs,
PAXRecords: hdr.PAXRecords,
Format: format,
}, nil
}
continue // This is a meta header affecting the next header
case TypeGNULongName, TypeGNULongLink:
format.mayOnlyBe(FormatGNU)
realname, err := readSpecialFile(tr)
if err != nil {
return nil, err
}
var p parser
switch hdr.Typeflag {
case TypeGNULongName:
gnuLongName = p.parseString(realname)
case TypeGNULongLink:
gnuLongLink = p.parseString(realname)
}
continue // This is a meta header affecting the next header
default:
// The old GNU sparse format is handled here since it is technically
// just a regular file with additional attributes.
if err := mergePAX(hdr, paxHdrs); err != nil {
return nil, err
}
if gnuLongName != "" {
hdr.Name = gnuLongName
}
if gnuLongLink != "" {
hdr.Linkname = gnuLongLink
}
if hdr.Typeflag == TypeRegA {
if strings.HasSuffix(hdr.Name, "/") {
hdr.Typeflag = TypeDir // Legacy archives use trailing slash for directories
} else {
hdr.Typeflag = TypeReg
}
}
// The extended headers may have updated the size.
// Thus, set up the regFileReader again after merging PAX headers.
if err := tr.handleRegularFile(hdr); err != nil {
return nil, err
}
// Sparse formats rely on being able to read from the logical data
// section; there must be a preceding call to handleRegularFile.
if err := tr.handleSparseFile(hdr, rawHdr); err != nil {
return nil, err
}
// Set the final guess at the format.
if format.has(FormatUSTAR) && format.has(FormatPAX) {
format.mayOnlyBe(FormatUSTAR)
}
hdr.Format = format
return hdr, nil // This is a file, so stop
}
}
}
// handleRegularFile sets up the current file reader and padding such that it
// can only read the following logical data section. It will properly handle
// special headers that contain no data section.
func (tr *Reader) handleRegularFile(hdr *Header) error {
nb := hdr.Size
if isHeaderOnlyType(hdr.Typeflag) {
nb = 0
}
if nb < 0 {
return ErrHeader
}
tr.pad = blockPadding(nb)
tr.curr = &regFileReader{r: tr.r, nb: nb}
return nil
}
// handleSparseFile checks if the current file is a sparse format of any type
// and sets the curr reader appropriately.
func (tr *Reader) handleSparseFile(hdr *Header, rawHdr *block) error {
var spd sparseDatas
var err error
if hdr.Typeflag == TypeGNUSparse {
spd, err = tr.readOldGNUSparseMap(hdr, rawHdr)
} else {
spd, err = tr.readGNUSparsePAXHeaders(hdr)
}
// If spd is non-nil, then this is a sparse file.
// Note that it is possible for len(spd) == 0.
if err == nil && spd != nil {
if isHeaderOnlyType(hdr.Typeflag) || !validateSparseEntries(spd, hdr.Size) {
return ErrHeader
}
sph := invertSparseEntries(spd, hdr.Size)
tr.curr = &sparseFileReader{tr.curr, sph, 0}
}
return err
}
// readGNUSparsePAXHeaders checks the PAX headers for GNU sparse headers.
// If they are found, then this function reads the sparse map and returns it.
// This assumes that 0.0 headers have already been converted to 0.1 headers
// by the PAX header parsing logic.
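//
// A sketch of the PAX records announcing a version 1.0 sparse file
// (values are illustrative):
//
//	GNU.sparse.major = "1"
//	GNU.sparse.minor = "0"
//	GNU.sparse.name = "real-filename"
//	GNU.sparse.realsize = "25"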
func (tr *Reader) readGNUSparsePAXHeaders(hdr *Header) (sparseDatas, error) {
// Identify the version of GNU headers.
var is1x0 bool
major, minor := hdr.PAXRecords[paxGNUSparseMajor], hdr.PAXRecords[paxGNUSparseMinor]
switch {
case major == "0" && (minor == "0" || minor == "1"):
is1x0 = false
case major == "1" && minor == "0":
is1x0 = true
case major != "" || minor != "":
return nil, nil // Unknown GNU sparse PAX version
case hdr.PAXRecords[paxGNUSparseMap] != "":
is1x0 = false // 0.0 and 0.1 did not have explicit version records, so guess
default:
return nil, nil // Not a PAX format GNU sparse file.
}
hdr.Format.mayOnlyBe(FormatPAX)
// Update hdr from GNU sparse PAX headers.
if name := hdr.PAXRecords[paxGNUSparseName]; name != "" {
hdr.Name = name
}
size := hdr.PAXRecords[paxGNUSparseSize]
if size == "" {
size = hdr.PAXRecords[paxGNUSparseRealSize]
}
if size != "" {
n, err := strconv.ParseInt(size, 10, 64)
if err != nil {
return nil, ErrHeader
}
hdr.Size = n
}
// Read the sparse map according to the appropriate format.
if is1x0 {
return readGNUSparseMap1x0(tr.curr)
}
return readGNUSparseMap0x1(hdr.PAXRecords)
}
// mergePAX merges paxHdrs into hdr for all relevant fields of Header.
func mergePAX(hdr *Header, paxHdrs map[string]string) (err error) {
for k, v := range paxHdrs {
if v == "" {
continue // Keep the original USTAR value
}
var id64 int64
switch k {
case paxPath:
hdr.Name = v
case paxLinkpath:
hdr.Linkname = v
case paxUname:
hdr.Uname = v
case paxGname:
hdr.Gname = v
case paxUid:
id64, err = strconv.ParseInt(v, 10, 64)
hdr.Uid = int(id64) // Integer overflow possible
case paxGid:
id64, err = strconv.ParseInt(v, 10, 64)
hdr.Gid = int(id64) // Integer overflow possible
case paxAtime:
hdr.AccessTime, err = parsePAXTime(v)
case paxMtime:
hdr.ModTime, err = parsePAXTime(v)
case paxCtime:
hdr.ChangeTime, err = parsePAXTime(v)
case paxSize:
hdr.Size, err = strconv.ParseInt(v, 10, 64)
default:
if strings.HasPrefix(k, paxSchilyXattr) {
if hdr.Xattrs == nil {
hdr.Xattrs = make(map[string]string)
}
hdr.Xattrs[k[len(paxSchilyXattr):]] = v
}
}
if err != nil {
return ErrHeader
}
}
hdr.PAXRecords = paxHdrs
return nil
}
// parsePAX parses PAX headers.
// If an extended header (type 'x') is invalid, ErrHeader is returned.
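//
// Each PAX record is encoded as "%d %s=%s\n" (decimal length, key, value),
// where the length counts the entire record, including the length field
// itself and the trailing newline. A sketch of a single record:
//
//	"30 mtime=1350244992.023960108\n"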
func parsePAX(r io.Reader) (map[string]string, error) {
buf, err := readSpecialFile(r)
if err != nil {
return nil, err
}
sbuf := string(buf)
// For GNU PAX sparse format 0.0 support.
// This function transforms the sparse format 0.0 headers into format 0.1
// headers since 0.0 headers were not PAX compliant.
var sparseMap []string
paxHdrs := make(map[string]string)
for len(sbuf) > 0 {
key, value, residual, err := parsePAXRecord(sbuf)
if err != nil {
return nil, ErrHeader
}
sbuf = residual
switch key {
case paxGNUSparseOffset, paxGNUSparseNumBytes:
// Validate sparse header order and value.
if (len(sparseMap)%2 == 0 && key != paxGNUSparseOffset) ||
(len(sparseMap)%2 == 1 && key != paxGNUSparseNumBytes) ||
strings.Contains(value, ",") {
return nil, ErrHeader
}
sparseMap = append(sparseMap, value)
default:
paxHdrs[key] = value
}
}
if len(sparseMap) > 0 {
paxHdrs[paxGNUSparseMap] = strings.Join(sparseMap, ",")
}
return paxHdrs, nil
}
// readHeader reads the next block header and assumes that the underlying reader
// is already aligned to a block boundary. It returns the raw block of the
// header in case further processing is required.
//
// The err will be set to io.EOF only when one of the following occurs:
// - Exactly 0 bytes are read and EOF is hit.
// - Exactly 1 block of zeros is read and EOF is hit.
// - At least 2 blocks of zeros are read.
func (tr *Reader) readHeader() (*Header, *block, error) {
// Two blocks of zero bytes mark the end of the archive.
if _, err := io.ReadFull(tr.r, tr.blk[:]); err != nil {
return nil, nil, err // EOF is okay here; exactly 0 bytes read
}
if bytes.Equal(tr.blk[:], zeroBlock[:]) {
if _, err := io.ReadFull(tr.r, tr.blk[:]); err != nil {
return nil, nil, err // EOF is okay here; exactly 1 block of zeros read
}
if bytes.Equal(tr.blk[:], zeroBlock[:]) {
return nil, nil, io.EOF // normal EOF; exactly 2 blocks of zeros read
}
return nil, nil, ErrHeader // Zero block and then non-zero block
}
// Verify the header matches a known format.
format := tr.blk.getFormat()
if format == FormatUnknown {
return nil, nil, ErrHeader
}
var p parser
hdr := new(Header)
// Unpack the V7 header.
v7 := tr.blk.toV7()
hdr.Typeflag = v7.typeFlag()[0]
hdr.Name = p.parseString(v7.name())
hdr.Linkname = p.parseString(v7.linkName())
hdr.Size = p.parseNumeric(v7.size())
hdr.Mode = p.parseNumeric(v7.mode())
hdr.Uid = int(p.parseNumeric(v7.uid()))
hdr.Gid = int(p.parseNumeric(v7.gid()))
hdr.ModTime = time.Unix(p.parseNumeric(v7.modTime()), 0)
// Unpack format specific fields.
if format > formatV7 {
ustar := tr.blk.toUSTAR()
hdr.Uname = p.parseString(ustar.userName())
hdr.Gname = p.parseString(ustar.groupName())
hdr.Devmajor = p.parseNumeric(ustar.devMajor())
hdr.Devminor = p.parseNumeric(ustar.devMinor())
var prefix string
switch {
case format.has(FormatUSTAR | FormatPAX):
hdr.Format = format
ustar := tr.blk.toUSTAR()
prefix = p.parseString(ustar.prefix())
// For Format detection, check if block is properly formatted since
// the parser is more liberal than what USTAR actually permits.
notASCII := func(r rune) bool { return r >= 0x80 }
if bytes.IndexFunc(tr.blk[:], notASCII) >= 0 {
hdr.Format = FormatUnknown // Non-ASCII characters in block.
}
nul := func(b []byte) bool { return int(b[len(b)-1]) == 0 }
if !(nul(v7.size()) && nul(v7.mode()) && nul(v7.uid()) && nul(v7.gid()) &&
nul(v7.modTime()) && nul(ustar.devMajor()) && nul(ustar.devMinor())) {
hdr.Format = FormatUnknown // Numeric fields must end in NUL
}
case format.has(formatSTAR):
star := tr.blk.toSTAR()
prefix = p.parseString(star.prefix())
hdr.AccessTime = time.Unix(p.parseNumeric(star.accessTime()), 0)
hdr.ChangeTime = time.Unix(p.parseNumeric(star.changeTime()), 0)
case format.has(FormatGNU):
hdr.Format = format
var p2 parser
gnu := tr.blk.toGNU()
if b := gnu.accessTime(); b[0] != 0 {
hdr.AccessTime = time.Unix(p2.parseNumeric(b), 0)
}
if b := gnu.changeTime(); b[0] != 0 {
hdr.ChangeTime = time.Unix(p2.parseNumeric(b), 0)
}
// Prior to Go1.8, the Writer had a bug where it would output
// an invalid tar file in certain rare situations because the logic
// incorrectly believed that the old GNU format had a prefix field.
// This is wrong and leads to an output file that mangles the
// atime and ctime fields, which are often left unused.
//
// In order to continue reading tar files created by former, buggy
// versions of Go, we skeptically parse the atime and ctime fields.
// If we are unable to parse them and the prefix field looks like
// an ASCII string, then we fall back on the pre-Go1.8 behavior
// of treating these fields as the USTAR prefix field.
//
// Note that this will not use the fallback logic for all possible
// files generated by a pre-Go1.8 toolchain. If the generated file
// happened to have a prefix field that parses as valid
// atime and ctime fields (e.g., when they are valid octal strings),
// then it is impossible to distinguish between a valid GNU file
// and an invalid pre-Go1.8 file.
//
// See https://golang.org/issues/12594
// See https://golang.org/issues/21005
if p2.err != nil {
hdr.AccessTime, hdr.ChangeTime = time.Time{}, time.Time{}
ustar := tr.blk.toUSTAR()
if s := p.parseString(ustar.prefix()); isASCII(s) {
prefix = s
}
hdr.Format = FormatUnknown // Buggy file is not GNU
}
}
if len(prefix) > 0 {
hdr.Name = prefix + "/" + hdr.Name
}
}
return hdr, &tr.blk, p.err
}
// readOldGNUSparseMap reads the sparse map from the old GNU sparse format.
// The sparse map is stored in the tar header if it's small enough.
// If the map has more than four entries, then one or more extension headers are used
// to store the rest of the sparse map.
//
// The Header.Size does not reflect the size of any extended headers used.
// Thus, this function will read from the raw io.Reader to fetch extra headers.
// This method mutates blk in the process.
func (tr *Reader) readOldGNUSparseMap(hdr *Header, blk *block) (sparseDatas, error) {
// Make sure that the input format is GNU.
// Unfortunately, the STAR format also has a sparse header format that uses
// the same type flag but has a completely different layout.
if blk.getFormat() != FormatGNU {
return nil, ErrHeader
}
hdr.Format.mayOnlyBe(FormatGNU)
var p parser
hdr.Size = p.parseNumeric(blk.toGNU().realSize())
if p.err != nil {
return nil, p.err
}
s := blk.toGNU().sparse()
spd := make(sparseDatas, 0, s.maxEntries())
for {
for i := 0; i < s.maxEntries(); i++ {
// This termination condition is identical to GNU and BSD tar.
if s.entry(i).offset()[0] == 0x00 {
break // Don't return, need to process extended headers (even if empty)
}
offset := p.parseNumeric(s.entry(i).offset())
length := p.parseNumeric(s.entry(i).length())
if p.err != nil {
return nil, p.err
}
spd = append(spd, sparseEntry{Offset: offset, Length: length})
}
if s.isExtended()[0] > 0 {
// There are more entries. Read an extension header and parse its entries.
if _, err := mustReadFull(tr.r, blk[:]); err != nil {
return nil, err
}
s = blk.toSparse()
continue
}
return spd, nil // Done
}
}
// readGNUSparseMap1x0 reads the sparse map as stored in GNU's PAX sparse format
// version 1.0. The format of the sparse map consists of a series of
// newline-terminated numeric fields. The first field is the number of entries
// and is always present. Following this are the entries, consisting of two
// fields (offset, length). This function must stop reading at the end
// boundary of the block containing the last newline.
//
// Note that the GNU manual says that numeric values should be encoded in octal
// format. However, the GNU tar utility itself outputs these values in decimal.
// As such, this library treats values as being encoded in decimal.
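//
// A sketch of a version 1.0 sparse map with two entries,
// (offset 2, length 5) and (offset 18, length 3), where the leading "2"
// is the entry count (padded with NULs to a full block on the wire):
//
//	"2\n2\n5\n18\n3\n"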
func readGNUSparseMap1x0(r io.Reader) (sparseDatas, error) {
var (
cntNewline int64
buf bytes.Buffer
blk block
)
// feedTokens copies data in blocks from r into buf until there are
// at least n newlines in buf. It will not read more blocks than needed.
feedTokens := func(n int64) error {
for cntNewline < n {
if _, err := mustReadFull(r, blk[:]); err != nil {
return err
}
buf.Write(blk[:])
for _, c := range blk {
if c == '\n' {
cntNewline++
}
}
}
return nil
}
// nextToken gets the next token delimited by a newline. This assumes that
// at least one newline exists in the buffer.
nextToken := func() string {
cntNewline--
tok, _ := buf.ReadString('\n')
return strings.TrimRight(tok, "\n")
}
// Parse for the number of entries.
// Use integer overflow resistant math to check this.
if err := feedTokens(1); err != nil {
return nil, err
}
numEntries, err := strconv.ParseInt(nextToken(), 10, 0) // Intentionally parse as native int
if err != nil || numEntries < 0 || int(2*numEntries) < int(numEntries) {
return nil, ErrHeader
}
// Parse for all member entries.
// numEntries is trusted after this since a potential attacker must have
// committed resources proportional to what this library used.
if err := feedTokens(2 * numEntries); err != nil {
return nil, err
}
spd := make(sparseDatas, 0, numEntries)
for i := int64(0); i < numEntries; i++ {
offset, err1 := strconv.ParseInt(nextToken(), 10, 64)
length, err2 := strconv.ParseInt(nextToken(), 10, 64)
if err1 != nil || err2 != nil {
return nil, ErrHeader
}
spd = append(spd, sparseEntry{Offset: offset, Length: length})
}
return spd, nil
}
// readGNUSparseMap0x1 reads the sparse map as stored in GNU's PAX sparse format
// version 0.1. The sparse map is stored in the PAX headers.
func readGNUSparseMap0x1(paxHdrs map[string]string) (sparseDatas, error) {
// Get number of entries.
// Use integer overflow resistant math to check this.
numEntriesStr := paxHdrs[paxGNUSparseNumBlocks]
numEntries, err := strconv.ParseInt(numEntriesStr, 10, 0) // Intentionally parse as native int
if err != nil || numEntries < 0 || int(2*numEntries) < int(numEntries) {
return nil, ErrHeader
}
// There should be two numbers in sparseMap for each entry.
sparseMap := strings.Split(paxHdrs[paxGNUSparseMap], ",")
if len(sparseMap) == 1 && sparseMap[0] == "" {
sparseMap = sparseMap[:0]
}
if int64(len(sparseMap)) != 2*numEntries {
return nil, ErrHeader
}
// Loop through the entries in the sparse map.
// numEntries is trusted now.
spd := make(sparseDatas, 0, numEntries)
for len(sparseMap) >= 2 {
offset, err1 := strconv.ParseInt(sparseMap[0], 10, 64)
length, err2 := strconv.ParseInt(sparseMap[1], 10, 64)
if err1 != nil || err2 != nil {
return nil, ErrHeader
}
spd = append(spd, sparseEntry{Offset: offset, Length: length})
sparseMap = sparseMap[2:]
}
return spd, nil
}
// Read reads from the current file in the tar archive.
// It returns (0, io.EOF) when it reaches the end of that file,
// until Next is called to advance to the next file.
//
// If the current file is sparse, then the regions marked as a hole
// are read back as NUL-bytes.
//
// Calling Read on special types like TypeLink, TypeSymlink, TypeChar,
// TypeBlock, TypeDir, and TypeFifo returns (0, io.EOF) regardless of what
// the Header.Size claims.
func (tr *Reader) Read(b []byte) (int, error) {
if tr.err != nil {
return 0, tr.err
}
n, err := tr.curr.Read(b)
if err != nil && err != io.EOF {
tr.err = err
}
return n, err
}
// writeTo writes the content of the current file to w.
// The number of bytes written matches the number of bytes remaining in the current file.
//
// If the current file is sparse and w is an io.WriteSeeker,
// then writeTo uses Seek to skip past holes defined in Header.SparseHoles,
// assuming that skipped regions are filled with NULs.
// This always writes the last byte to ensure w is the right size.
//
// TODO(dsnet): Re-export this when adding sparse file support.
// See https://golang.org/issue/22735
func (tr *Reader) writeTo(w io.Writer) (int64, error) {
if tr.err != nil {
return 0, tr.err
}
n, err := tr.curr.WriteTo(w)
if err != nil {
tr.err = err
}
return n, err
}
// regFileReader is a fileReader for reading data from a regular file entry.
type regFileReader struct {
r io.Reader // Underlying Reader
nb int64 // Number of remaining bytes to read
}
func (fr *regFileReader) Read(b []byte) (n int, err error) {
if int64(len(b)) > fr.nb {
b = b[:fr.nb]
}
if len(b) > 0 {
n, err = fr.r.Read(b)
fr.nb -= int64(n)
}
switch {
case err == io.EOF && fr.nb > 0:
return n, io.ErrUnexpectedEOF
case err == nil && fr.nb == 0:
return n, io.EOF
default:
return n, err
}
}
func (fr *regFileReader) WriteTo(w io.Writer) (int64, error) {
// Wrap fr in an anonymous struct to hide this WriteTo method from
// io.Copy, which would otherwise call it recursively.
return io.Copy(w, struct{ io.Reader }{fr})
}
// logicalRemaining implements fileState.logicalRemaining.
func (fr regFileReader) logicalRemaining() int64 {
return fr.nb
}
// physicalRemaining implements fileState.physicalRemaining.
func (fr regFileReader) physicalRemaining() int64 {
return fr.nb
}
// sparseFileReader is a fileReader for reading data from a sparse file entry.
type sparseFileReader struct {
fr fileReader // Underlying fileReader
sp sparseHoles // Normalized list of sparse holes
pos int64 // Current position in sparse file
}
func (sr *sparseFileReader) Read(b []byte) (n int, err error) {
finished := int64(len(b)) >= sr.logicalRemaining()
if finished {
b = b[:sr.logicalRemaining()]
}
b0 := b
endPos := sr.pos + int64(len(b))
for endPos > sr.pos && err == nil {
var nf int // Bytes read in fragment
holeStart, holeEnd := sr.sp[0].Offset, sr.sp[0].endOffset()
if sr.pos < holeStart { // In a data fragment
bf := b[:min(int64(len(b)), holeStart-sr.pos)]
nf, err = tryReadFull(sr.fr, bf)
} else { // In a hole fragment
bf := b[:min(int64(len(b)), holeEnd-sr.pos)]
nf, err = tryReadFull(zeroReader{}, bf)
}
b = b[nf:]
sr.pos += int64(nf)
if sr.pos >= holeEnd && len(sr.sp) > 1 {
sr.sp = sr.sp[1:] // Ensure last fragment always remains
}
}
n = len(b0) - len(b)
switch {
case err == io.EOF:
return n, errMissData // Less data in dense file than sparse file
case err != nil:
return n, err
case sr.logicalRemaining() == 0 && sr.physicalRemaining() > 0:
return n, errUnrefData // More data in dense file than sparse file
case finished:
return n, io.EOF
default:
return n, nil
}
}
func (sr *sparseFileReader) WriteTo(w io.Writer) (n int64, err error) {
ws, ok := w.(io.WriteSeeker)
if ok {
if _, err := ws.Seek(0, io.SeekCurrent); err != nil {
ok = false // Not all io.Seeker can really seek
}
}
if !ok {
return io.Copy(w, struct{ io.Reader }{sr})
}
var writeLastByte bool
pos0 := sr.pos
for sr.logicalRemaining() > 0 && !writeLastByte && err == nil {
var nf int64 // Size of fragment
holeStart, holeEnd := sr.sp[0].Offset, sr.sp[0].endOffset()
if sr.pos < holeStart { // In a data fragment
nf = holeStart - sr.pos
nf, err = io.CopyN(ws, sr.fr, nf)
} else { // In a hole fragment
nf = holeEnd - sr.pos
if sr.physicalRemaining() == 0 {
writeLastByte = true
nf--
}
_, err = ws.Seek(nf, io.SeekCurrent)
}
sr.pos += nf
if sr.pos >= holeEnd && len(sr.sp) > 1 {
sr.sp = sr.sp[1:] // Ensure last fragment always remains
}
}
// If the last fragment is a hole, then seek to 1 byte before EOF, and
// write a single byte to ensure the file is the right size.
if writeLastByte && err == nil {
_, err = ws.Write([]byte{0})
sr.pos++
}
n = sr.pos - pos0
switch {
case err == io.EOF:
return n, errMissData // Less data in dense file than sparse file
case err != nil:
return n, err
case sr.logicalRemaining() == 0 && sr.physicalRemaining() > 0:
return n, errUnrefData // More data in dense file than sparse file
default:
return n, nil
}
}
func (sr sparseFileReader) logicalRemaining() int64 {
return sr.sp[len(sr.sp)-1].endOffset() - sr.pos
}
func (sr sparseFileReader) physicalRemaining() int64 {
return sr.fr.physicalRemaining()
}
type zeroReader struct{}
func (zeroReader) Read(b []byte) (int, error) {
for i := range b {
b[i] = 0
}
return len(b), nil
}
// mustReadFull is like io.ReadFull except it returns
// io.ErrUnexpectedEOF when io.EOF is hit before len(b) bytes are read.
func mustReadFull(r io.Reader, b []byte) (int, error) {
n, err := tryReadFull(r, b)
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return n, err
}
// tryReadFull is like io.ReadFull except it returns
// io.EOF when it is hit before len(b) bytes are read.
func tryReadFull(r io.Reader, b []byte) (n int, err error) {
for len(b) > n && err == nil {
var nn int
nn, err = r.Read(b[n:])
n += nn
}
if len(b) == n && err == io.EOF {
err = nil
}
return n, err
}
// readSpecialFile is like io.ReadAll except it returns
// ErrFieldTooLong if more than maxSpecialFileSize is read.
func readSpecialFile(r io.Reader) ([]byte, error) {
buf, err := io.ReadAll(io.LimitReader(r, maxSpecialFileSize+1))
if len(buf) > maxSpecialFileSize {
return nil, ErrFieldTooLong
}
return buf, err
}
// discard skips n bytes in r, reporting an error if unable to do so.
func discard(r io.Reader, n int64) error {
// If possible, Seek to the last byte before the end of the data section.
// Do this because Seek is often lazy about reporting errors; this will mask
// the fact that the stream may be truncated. We can rely on the
// io.CopyN done shortly afterwards to trigger any IO errors.
var seekSkipped int64 // Number of bytes skipped via Seek
if sr, ok := r.(io.Seeker); ok && n > 1 {
// Not all io.Seeker can actually Seek. For example, os.Stdin implements
// io.Seeker, but calling Seek always returns an error and performs
// no action. Thus, we try an innocent seek to the current position
// to see if Seek is really supported.
pos1, err := sr.Seek(0, io.SeekCurrent)
if pos1 >= 0 && err == nil {
// Seek seems supported, so perform the real Seek.
pos2, err := sr.Seek(n-1, io.SeekCurrent)
if pos2 < 0 || err != nil {
return err
}
seekSkipped = pos2 - pos1
}
}
copySkipped, err := io.CopyN(io.Discard, r, n-seekSkipped)
if err == io.EOF && seekSkipped+copySkipped < n {
err = io.ErrUnexpectedEOF
}
return err
}
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix || linux || dragonfly || openbsd || solaris
package tar
import (
"syscall"
"time"
)
func statAtime(st *syscall.Stat_t) time.Time {
return time.Unix(st.Atim.Unix())
}
func statCtime(st *syscall.Stat_t) time.Time {
return time.Unix(st.Ctim.Unix())
}
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix
package tar
import (
"io/fs"
"os/user"
"runtime"
"strconv"
"sync"
"syscall"
)
func init() {
sysStat = statUnix
}
// userMap and groupMap cache UID and GID lookups for performance reasons.
// The downside is that renaming a uname or gname on the OS never takes effect.
var userMap, groupMap sync.Map // map[int]string
func statUnix(fi fs.FileInfo, h *Header) error {
sys, ok := fi.Sys().(*syscall.Stat_t)
if !ok {
return nil
}
h.Uid = int(sys.Uid)
h.Gid = int(sys.Gid)
// Best effort at populating Uname and Gname.
// The os/user functions may fail for any number of reasons
// (not implemented on that platform, cgo not enabled, etc).
if u, ok := userMap.Load(h.Uid); ok {
h.Uname = u.(string)
} else if u, err := user.LookupId(strconv.Itoa(h.Uid)); err == nil {
h.Uname = u.Username
userMap.Store(h.Uid, h.Uname)
}
if g, ok := groupMap.Load(h.Gid); ok {
h.Gname = g.(string)
} else if g, err := user.LookupGroupId(strconv.Itoa(h.Gid)); err == nil {
h.Gname = g.Name
groupMap.Store(h.Gid, h.Gname)
}
h.AccessTime = statAtime(sys)
h.ChangeTime = statCtime(sys)
// Best effort at populating Devmajor and Devminor.
if h.Typeflag == TypeChar || h.Typeflag == TypeBlock {
dev := uint64(sys.Rdev) // May be int32 or uint32
switch runtime.GOOS {
case "aix":
var major, minor uint32
major = uint32((dev & 0x3fffffff00000000) >> 32)
minor = uint32((dev & 0x00000000ffffffff) >> 0)
h.Devmajor, h.Devminor = int64(major), int64(minor)
case "linux":
// Copied from golang.org/x/sys/unix/dev_linux.go.
major := uint32((dev & 0x00000000000fff00) >> 8)
major |= uint32((dev & 0xfffff00000000000) >> 32)
minor := uint32((dev & 0x00000000000000ff) >> 0)
minor |= uint32((dev & 0x00000ffffff00000) >> 12)
h.Devmajor, h.Devminor = int64(major), int64(minor)
case "darwin", "ios":
// Copied from golang.org/x/sys/unix/dev_darwin.go.
major := uint32((dev >> 24) & 0xff)
minor := uint32(dev & 0xffffff)
h.Devmajor, h.Devminor = int64(major), int64(minor)
case "dragonfly":
// Copied from golang.org/x/sys/unix/dev_dragonfly.go.
major := uint32((dev >> 8) & 0xff)
minor := uint32(dev & 0xffff00ff)
h.Devmajor, h.Devminor = int64(major), int64(minor)
case "freebsd":
// Copied from golang.org/x/sys/unix/dev_freebsd.go.
major := uint32((dev >> 8) & 0xff)
minor := uint32(dev & 0xffff00ff)
h.Devmajor, h.Devminor = int64(major), int64(minor)
case "netbsd":
// Copied from golang.org/x/sys/unix/dev_netbsd.go.
major := uint32((dev & 0x000fff00) >> 8)
minor := uint32((dev & 0x000000ff) >> 0)
minor |= uint32((dev & 0xfff00000) >> 12)
h.Devmajor, h.Devminor = int64(major), int64(minor)
case "openbsd":
// Copied from golang.org/x/sys/unix/dev_openbsd.go.
major := uint32((dev & 0x0000ff00) >> 8)
minor := uint32((dev & 0x000000ff) >> 0)
minor |= uint32((dev & 0xffff0000) >> 8)
h.Devmajor, h.Devminor = int64(major), int64(minor)
default:
// TODO: Implement solaris (see https://golang.org/issue/8106)
}
}
return nil
}
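// For example, on Linux a (hypothetical) sys.Rdev of 0x0806 decodes to
// Devmajor 8 and Devminor 6, the conventional numbering for /dev/sda6.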
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tar
import (
"bytes"
"fmt"
"strconv"
"strings"
"time"
)
// hasNUL reports whether the NUL character exists within s.
func hasNUL(s string) bool {
return strings.Contains(s, "\x00")
}
// isASCII reports whether the input is an ASCII C-style string.
func isASCII(s string) bool {
for _, c := range s {
if c >= 0x80 || c == 0x00 {
return false
}
}
return true
}
// toASCII converts the input to an ASCII C-style string.
// This is a best effort conversion, so invalid characters are dropped.
func toASCII(s string) string {
if isASCII(s) {
return s
}
b := make([]byte, 0, len(s))
for _, c := range s {
if c < 0x80 && c != 0x00 {
b = append(b, byte(c))
}
}
return string(b)
}
type parser struct {
err error // Last error seen
}
type formatter struct {
err error // Last error seen
}
// parseString parses bytes as a NUL-terminated C-style string.
// If a NUL byte is not found then the whole slice is returned as a string.
func (*parser) parseString(b []byte) string {
if i := bytes.IndexByte(b, 0); i >= 0 {
return string(b[:i])
}
return string(b)
}
// formatString copies s into b, NUL-terminating if possible.
func (f *formatter) formatString(b []byte, s string) {
if len(s) > len(b) {
f.err = ErrFieldTooLong
}
copy(b, s)
if len(s) < len(b) {
b[len(s)] = 0
}
// Some buggy readers treat regular files with a trailing slash
// in the V7 path field as a directory even though the full path
// recorded elsewhere (e.g., via PAX record) contains no trailing slash.
if len(s) > len(b) && b[len(b)-1] == '/' {
n := len(strings.TrimRight(s[:len(b)], "/"))
b[n] = 0 // Replace trailing slash with NUL terminator
}
}
// fitsInBase256 reports whether x can be encoded into n bytes using base-256
// encoding. Unlike octal encoding, base-256 encoding does not require that the
// string ends with a NUL character. Thus, all n bytes are available for output.
//
// If operating in binary mode, this assumes strict GNU binary mode, which means
// that the first byte can only be either 0x80 or 0xff. Thus, the first byte is
// equivalent to the sign bit in two's complement form.
func fitsInBase256(n int, x int64) bool {
binBits := uint(n-1) * 8
return n >= 9 || (x >= -1<<binBits && x < 1<<binBits)
}
// parseNumeric parses the input as being encoded in either base-256 or octal.
// This function may return negative numbers.
// If parsing fails or an integer overflow occurs, err will be set.
func (p *parser) parseNumeric(b []byte) int64 {
// Check for base-256 (binary) format first.
// If the first bit is set, then all following bits constitute a two's
// complement encoded number in big-endian byte order.
if len(b) > 0 && b[0]&0x80 != 0 {
// Handling negative numbers relies on the following identity:
// -a-1 == ^a
//
// If the number is negative, we use an inversion mask to invert the
// data bytes and treat the value as an unsigned number.
var inv byte // 0x00 if positive or zero, 0xff if negative
if b[0]&0x40 != 0 {
inv = 0xff
}
var x uint64
for i, c := range b {
c ^= inv // Inverts c only if inv is 0xff, otherwise does nothing
if i == 0 {
c &= 0x7f // Ignore sign bit in first byte
}
if (x >> 56) > 0 {
p.err = ErrHeader // Integer overflow
return 0
}
x = x<<8 | uint64(c)
}
if (x >> 63) > 0 {
p.err = ErrHeader // Integer overflow
return 0
}
if inv == 0xff {
return ^int64(x)
}
return int64(x)
}
// Normal case is base-8 (octal) format.
return p.parseOctal(b)
}
// formatNumeric encodes x into b using base-8 (octal) encoding if possible.
// Otherwise it will attempt to use base-256 (binary) encoding.
func (f *formatter) formatNumeric(b []byte, x int64) {
if fitsInOctal(len(b), x) {
f.formatOctal(b, x)
return
}
if fitsInBase256(len(b), x) {
for i := len(b) - 1; i >= 0; i-- {
b[i] = byte(x)
x >>= 8
}
b[0] |= 0x80 // Highest bit indicates binary format
return
}
f.formatOctal(b, 0) // Last resort, just write zero
f.err = ErrFieldTooLong
}
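// For example, a 12-byte size field can hold at most 8^11-1 (about 8GiB)
// in octal, so a 16GiB (1<<34) value is encoded in base-256 as
// {0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00}.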
func (p *parser) parseOctal(b []byte) int64 {
// Because unused fields are filled with NULs, we need to skip leading NULs.
// Fields may also be padded with spaces or NULs, so we trim leading and
// trailing NULs and spaces to be safe.
b = bytes.Trim(b, " \x00")
if len(b) == 0 {
return 0
}
x, perr := strconv.ParseUint(p.parseString(b), 8, 64)
if perr != nil {
p.err = ErrHeader
}
return int64(x)
}
func (f *formatter) formatOctal(b []byte, x int64) {
if !fitsInOctal(len(b), x) {
x = 0 // Last resort, just write zero
f.err = ErrFieldTooLong
}
s := strconv.FormatInt(x, 8)
// Add leading zeros, but leave room for a NUL.
if n := len(b) - len(s) - 1; n > 0 {
s = strings.Repeat("0", n) + s
}
f.formatString(b, s)
}
// fitsInOctal reports whether the integer x fits in a field n bytes long
// using octal encoding with the appropriate NUL terminator.
func fitsInOctal(n int, x int64) bool {
octBits := uint(n-1) * 3
return x >= 0 && (n >= 22 || x < 1<<octBits)
}
// parsePAXTime takes a string of the form %d.%d as described in the PAX
// specification. Note that this implementation permits negative timestamps,
// which the PAX specification allows but which are not always portable.
func parsePAXTime(s string) (time.Time, error) {
const maxNanoSecondDigits = 9
// Split string into seconds and sub-seconds parts.
ss, sn, _ := strings.Cut(s, ".")
// Parse the seconds.
secs, err := strconv.ParseInt(ss, 10, 64)
if err != nil {
return time.Time{}, ErrHeader
}
if len(sn) == 0 {
return time.Unix(secs, 0), nil // No sub-second values
}
// Parse the nanoseconds.
if strings.Trim(sn, "0123456789") != "" {
return time.Time{}, ErrHeader
}
if len(sn) < maxNanoSecondDigits {
sn += strings.Repeat("0", maxNanoSecondDigits-len(sn)) // Right pad
} else {
sn = sn[:maxNanoSecondDigits] // Right truncate
}
nsecs, _ := strconv.ParseInt(sn, 10, 64) // Must succeed
if len(ss) > 0 && ss[0] == '-' {
return time.Unix(secs, -1*nsecs), nil // Negative correction
}
return time.Unix(secs, nsecs), nil
}
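// For example, "1350244992.023960108" parses to
// time.Unix(1350244992, 23960108), while "1350244992.3" is right-padded
// to 9 sub-second digits and parses to time.Unix(1350244992, 300000000).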
// formatPAXTime converts ts into a time of the form %d.%d as described in the
// PAX specification. This function correctly handles negative timestamps.
func formatPAXTime(ts time.Time) (s string) {
secs, nsecs := ts.Unix(), ts.Nanosecond()
if nsecs == 0 {
return strconv.FormatInt(secs, 10)
}
// If seconds is negative, then perform correction.
sign := ""
if secs < 0 {
sign = "-" // Remember sign
secs = -(secs + 1) // Add a second to secs
nsecs = -(nsecs - 1e9) // Take that second away from nsecs
}
return strings.TrimRight(fmt.Sprintf("%s%d.%09d", sign, secs, nsecs), "0")
}
// parsePAXRecord parses the input PAX record string into a key-value pair.
// If parsing is successful, it will slice off the currently read record and
// return the remainder as r.
func parsePAXRecord(s string) (k, v, r string, err error) {
// The size field ends at the first space.
nStr, rest, ok := strings.Cut(s, " ")
if !ok {
return "", "", s, ErrHeader
}
// Parse the first token as a decimal integer.
n, perr := strconv.ParseInt(nStr, 10, 0) // Intentionally parse as native int
if perr != nil || n < 5 || n > int64(len(s)) {
return "", "", s, ErrHeader
}
n -= int64(len(nStr) + 1) // convert from index in s to index in rest
if n <= 0 {
return "", "", s, ErrHeader
}
// Extract everything between the space and the final newline.
rec, nl, rem := rest[:n-1], rest[n-1:n], rest[n:]
if nl != "\n" {
return "", "", s, ErrHeader
}
// The first equals separates the key from the value.
k, v, ok = strings.Cut(rec, "=")
if !ok {
return "", "", s, ErrHeader
}
if !validPAXRecord(k, v) {
return "", "", s, ErrHeader
}
return k, v, rem, nil
}
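// For example, parsing "30 mtime=1350244992.023960108\n" yields
// k="mtime" and v="1350244992.023960108"; the leading 30 is the length
// of the whole record, counting the size digits, space, and newline.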
// formatPAXRecord formats a single PAX record, prefixing it with the
// appropriate length.
func formatPAXRecord(k, v string) (string, error) {
if !validPAXRecord(k, v) {
return "", ErrHeader
}
const padding = 3 // Extra padding for ' ', '=', and '\n'
size := len(k) + len(v) + padding
size += len(strconv.Itoa(size))
record := strconv.Itoa(size) + " " + k + "=" + v + "\n"
// Final adjustment if adding size field increased the record size.
if len(record) != size {
size = len(record)
record = strconv.Itoa(size) + " " + k + "=" + v + "\n"
}
return record, nil
}
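// Because the size field counts its own digits, a second pass is
// occasionally needed: for k="a" and v="b", the payload " a=b\n" is
// 5 bytes, adding the 1-digit size makes 6, so the record is "6 a=b\n".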
// validPAXRecord reports whether the key-value pair is valid where each
// record is formatted as:
//
// "%d %s=%s\n" % (size, key, value)
//
// Keys and values should be UTF-8, but the number of bad writers out there
// forces us to be more liberal.
// Thus, we reject only keys that contain a NUL, and reject NULs in values
// only for the PAX versions of the USTAR string fields.
// The key must not contain an '=' character.
func validPAXRecord(k, v string) bool {
if k == "" || strings.Contains(k, "=") {
return false
}
switch k {
case paxPath, paxLinkpath, paxUname, paxGname:
return !hasNUL(v)
default:
return !hasNUL(k)
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tar
import (
"fmt"
"io"
"path"
"sort"
"strings"
"time"
)
// Writer provides sequential writing of a tar archive.
// Writer.WriteHeader begins a new file with the provided Header,
// and then Writer can be treated as an io.Writer to supply that file's data.
type Writer struct {
w io.Writer
pad int64 // Amount of padding to write after current file entry
curr fileWriter // Writer for current file entry
hdr Header // Shallow copy of Header that is safe for mutations
blk block // Buffer to use as temporary local storage
// err is a persistent error.
// Only the exported methods of Writer are responsible for
// ensuring that this error stays sticky.
err error
}
// NewWriter creates a new Writer writing to w.
func NewWriter(w io.Writer) *Writer {
return &Writer{w: w, curr: &regFileWriter{w, 0}}
}
type fileWriter interface {
io.Writer
fileState
ReadFrom(io.Reader) (int64, error)
}
// Flush finishes writing the current file's block padding.
// The current file must be fully written before Flush can be called.
//
// This is unnecessary as the next call to WriteHeader or Close
// will implicitly flush out the file's padding.
func (tw *Writer) Flush() error {
if tw.err != nil {
return tw.err
}
if nb := tw.curr.logicalRemaining(); nb > 0 {
return fmt.Errorf("archive/tar: missed writing %d bytes", nb)
}
if _, tw.err = tw.w.Write(zeroBlock[:tw.pad]); tw.err != nil {
return tw.err
}
tw.pad = 0
return nil
}
// WriteHeader writes hdr and prepares to accept the file's contents.
// The Header.Size determines how many bytes can be written for the next file.
// If the current file is not fully written, then this returns an error.
// This implicitly flushes any padding necessary before writing the header.
func (tw *Writer) WriteHeader(hdr *Header) error {
if err := tw.Flush(); err != nil {
return err
}
tw.hdr = *hdr // Shallow copy of Header
// Avoid usage of the legacy TypeRegA flag, and automatically promote
// it to use TypeReg or TypeDir.
if tw.hdr.Typeflag == TypeRegA {
if strings.HasSuffix(tw.hdr.Name, "/") {
tw.hdr.Typeflag = TypeDir
} else {
tw.hdr.Typeflag = TypeReg
}
}
// Round ModTime and ignore AccessTime and ChangeTime unless
// the format is explicitly chosen.
// This ensures nominal usage of WriteHeader (without specifying the format)
// does not always result in the PAX format being chosen, which
// causes a 1KiB increase to every header.
if tw.hdr.Format == FormatUnknown {
tw.hdr.ModTime = tw.hdr.ModTime.Round(time.Second)
tw.hdr.AccessTime = time.Time{}
tw.hdr.ChangeTime = time.Time{}
}
allowedFormats, paxHdrs, err := tw.hdr.allowedFormats()
switch {
case allowedFormats.has(FormatUSTAR):
tw.err = tw.writeUSTARHeader(&tw.hdr)
return tw.err
case allowedFormats.has(FormatPAX):
tw.err = tw.writePAXHeader(&tw.hdr, paxHdrs)
return tw.err
case allowedFormats.has(FormatGNU):
tw.err = tw.writeGNUHeader(&tw.hdr)
return tw.err
default:
return err // Non-fatal error
}
}
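// A minimal client-side sketch of writing a single file
// (w is any io.Writer; error handling abbreviated):
//
//	tw := tar.NewWriter(w)
//	body := []byte("hello")
//	hdr := &tar.Header{Name: "readme.txt", Mode: 0600, Size: int64(len(body))}
//	if err := tw.WriteHeader(hdr); err != nil {
//		log.Fatal(err)
//	}
//	if _, err := tw.Write(body); err != nil {
//		log.Fatal(err)
//	}
//	if err := tw.Close(); err != nil {
//		log.Fatal(err)
//	}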
func (tw *Writer) writeUSTARHeader(hdr *Header) error {
// Check if we can use USTAR prefix/suffix splitting.
var namePrefix string
if prefix, suffix, ok := splitUSTARPath(hdr.Name); ok {
namePrefix, hdr.Name = prefix, suffix
}
// Pack the main header.
var f formatter
blk := tw.templateV7Plus(hdr, f.formatString, f.formatOctal)
f.formatString(blk.toUSTAR().prefix(), namePrefix)
blk.setFormat(FormatUSTAR)
if f.err != nil {
return f.err // Should never happen since header is validated
}
return tw.writeRawHeader(blk, hdr.Size, hdr.Typeflag)
}
func (tw *Writer) writePAXHeader(hdr *Header, paxHdrs map[string]string) error {
realName, realSize := hdr.Name, hdr.Size
// TODO(dsnet): Re-enable this when adding sparse support.
// See https://golang.org/issue/22735
/*
// Handle sparse files.
var spd sparseDatas
var spb []byte
if len(hdr.SparseHoles) > 0 {
sph := append([]sparseEntry{}, hdr.SparseHoles...) // Copy sparse map
sph = alignSparseEntries(sph, hdr.Size)
spd = invertSparseEntries(sph, hdr.Size)
// Format the sparse map.
hdr.Size = 0 // Replace with encoded size
spb = append(strconv.AppendInt(spb, int64(len(spd)), 10), '\n')
for _, s := range spd {
hdr.Size += s.Length
spb = append(strconv.AppendInt(spb, s.Offset, 10), '\n')
spb = append(strconv.AppendInt(spb, s.Length, 10), '\n')
}
pad := blockPadding(int64(len(spb)))
spb = append(spb, zeroBlock[:pad]...)
hdr.Size += int64(len(spb)) // Accounts for encoded sparse map
// Add and modify appropriate PAX records.
dir, file := path.Split(realName)
hdr.Name = path.Join(dir, "GNUSparseFile.0", file)
paxHdrs[paxGNUSparseMajor] = "1"
paxHdrs[paxGNUSparseMinor] = "0"
paxHdrs[paxGNUSparseName] = realName
paxHdrs[paxGNUSparseRealSize] = strconv.FormatInt(realSize, 10)
paxHdrs[paxSize] = strconv.FormatInt(hdr.Size, 10)
delete(paxHdrs, paxPath) // Recorded by paxGNUSparseName
}
*/
_ = realSize
// Write PAX records to the output.
isGlobal := hdr.Typeflag == TypeXGlobalHeader
if len(paxHdrs) > 0 || isGlobal {
// Sort keys for deterministic ordering.
var keys []string
for k := range paxHdrs {
keys = append(keys, k)
}
sort.Strings(keys)
// Write each record to a buffer.
var buf strings.Builder
for _, k := range keys {
rec, err := formatPAXRecord(k, paxHdrs[k])
if err != nil {
return err
}
buf.WriteString(rec)
}
// Write the extended header file.
var name string
var flag byte
if isGlobal {
name = realName
if name == "" {
name = "GlobalHead.0.0"
}
flag = TypeXGlobalHeader
} else {
dir, file := path.Split(realName)
name = path.Join(dir, "PaxHeaders.0", file)
flag = TypeXHeader
}
data := buf.String()
if len(data) > maxSpecialFileSize {
return ErrFieldTooLong
}
if err := tw.writeRawFile(name, data, flag, FormatPAX); err != nil || isGlobal {
return err // Global headers return here
}
}
// Pack the main header.
var f formatter // Ignore errors since they are expected
fmtStr := func(b []byte, s string) { f.formatString(b, toASCII(s)) }
blk := tw.templateV7Plus(hdr, fmtStr, f.formatOctal)
blk.setFormat(FormatPAX)
if err := tw.writeRawHeader(blk, hdr.Size, hdr.Typeflag); err != nil {
return err
}
// TODO(dsnet): Re-enable this when adding sparse support.
// See https://golang.org/issue/22735
/*
// Write the sparse map and setup the sparse writer if necessary.
if len(spd) > 0 {
// Use tw.curr since the sparse map is accounted for in hdr.Size.
if _, err := tw.curr.Write(spb); err != nil {
return err
}
tw.curr = &sparseFileWriter{tw.curr, spd, 0}
}
*/
return nil
}
func (tw *Writer) writeGNUHeader(hdr *Header) error {
// Use long-link files if Name or Linkname exceeds the field size.
const longName = "././@LongLink"
if len(hdr.Name) > nameSize {
data := hdr.Name + "\x00"
if err := tw.writeRawFile(longName, data, TypeGNULongName, FormatGNU); err != nil {
return err
}
}
if len(hdr.Linkname) > nameSize {
data := hdr.Linkname + "\x00"
if err := tw.writeRawFile(longName, data, TypeGNULongLink, FormatGNU); err != nil {
return err
}
}
// Pack the main header.
var f formatter // Ignore errors since they are expected
var spd sparseDatas
var spb []byte
blk := tw.templateV7Plus(hdr, f.formatString, f.formatNumeric)
if !hdr.AccessTime.IsZero() {
f.formatNumeric(blk.toGNU().accessTime(), hdr.AccessTime.Unix())
}
if !hdr.ChangeTime.IsZero() {
f.formatNumeric(blk.toGNU().changeTime(), hdr.ChangeTime.Unix())
}
// TODO(dsnet): Re-enable this when adding sparse support.
// See https://golang.org/issue/22735
/*
if hdr.Typeflag == TypeGNUSparse {
sph := append([]sparseEntry{}, hdr.SparseHoles...) // Copy sparse map
sph = alignSparseEntries(sph, hdr.Size)
spd = invertSparseEntries(sph, hdr.Size)
// Format the sparse map.
formatSPD := func(sp sparseDatas, sa sparseArray) sparseDatas {
for i := 0; len(sp) > 0 && i < sa.MaxEntries(); i++ {
f.formatNumeric(sa.Entry(i).Offset(), sp[0].Offset)
f.formatNumeric(sa.Entry(i).Length(), sp[0].Length)
sp = sp[1:]
}
if len(sp) > 0 {
sa.IsExtended()[0] = 1
}
return sp
}
sp2 := formatSPD(spd, blk.GNU().Sparse())
for len(sp2) > 0 {
var spHdr block
sp2 = formatSPD(sp2, spHdr.Sparse())
spb = append(spb, spHdr[:]...)
}
// Update size fields in the header block.
realSize := hdr.Size
hdr.Size = 0 // Encoded size; does not account for encoded sparse map
for _, s := range spd {
hdr.Size += s.Length
}
copy(blk.V7().Size(), zeroBlock[:]) // Reset field
f.formatNumeric(blk.V7().Size(), hdr.Size)
f.formatNumeric(blk.GNU().RealSize(), realSize)
}
*/
blk.setFormat(FormatGNU)
if err := tw.writeRawHeader(blk, hdr.Size, hdr.Typeflag); err != nil {
return err
}
// Write the extended sparse map and setup the sparse writer if necessary.
if len(spd) > 0 {
// Use tw.w since the sparse map is not accounted for in hdr.Size.
if _, err := tw.w.Write(spb); err != nil {
return err
}
tw.curr = &sparseFileWriter{tw.curr, spd, 0}
}
return nil
}
type (
stringFormatter func([]byte, string)
numberFormatter func([]byte, int64)
)
// templateV7Plus fills out the V7 fields of a block using values from hdr.
// It also fills out fields (uname, gname, devmajor, devminor) that are
// shared in the USTAR, PAX, and GNU formats using the provided formatters.
//
// The block returned is only valid until the next call to
// templateV7Plus or writeRawFile.
func (tw *Writer) templateV7Plus(hdr *Header, fmtStr stringFormatter, fmtNum numberFormatter) *block {
tw.blk.reset()
modTime := hdr.ModTime
if modTime.IsZero() {
modTime = time.Unix(0, 0)
}
v7 := tw.blk.toV7()
v7.typeFlag()[0] = hdr.Typeflag
fmtStr(v7.name(), hdr.Name)
fmtStr(v7.linkName(), hdr.Linkname)
fmtNum(v7.mode(), hdr.Mode)
fmtNum(v7.uid(), int64(hdr.Uid))
fmtNum(v7.gid(), int64(hdr.Gid))
fmtNum(v7.size(), hdr.Size)
fmtNum(v7.modTime(), modTime.Unix())
ustar := tw.blk.toUSTAR()
fmtStr(ustar.userName(), hdr.Uname)
fmtStr(ustar.groupName(), hdr.Gname)
fmtNum(ustar.devMajor(), hdr.Devmajor)
fmtNum(ustar.devMinor(), hdr.Devminor)
return &tw.blk
}
// writeRawFile writes a minimal file with the given name and flag type.
// It uses format to encode the header format and will write data as the body.
// It uses default values for all of the other fields (as BSD and GNU tar do).
func (tw *Writer) writeRawFile(name, data string, flag byte, format Format) error {
tw.blk.reset()
// Best effort for the filename.
name = toASCII(name)
if len(name) > nameSize {
name = name[:nameSize]
}
name = strings.TrimRight(name, "/")
var f formatter
v7 := tw.blk.toV7()
v7.typeFlag()[0] = flag
f.formatString(v7.name(), name)
f.formatOctal(v7.mode(), 0)
f.formatOctal(v7.uid(), 0)
f.formatOctal(v7.gid(), 0)
f.formatOctal(v7.size(), int64(len(data))) // Must be < 8GiB
f.formatOctal(v7.modTime(), 0)
tw.blk.setFormat(format)
if f.err != nil {
return f.err // Only occurs if size condition is violated
}
// Write the header and data.
if err := tw.writeRawHeader(&tw.blk, int64(len(data)), flag); err != nil {
return err
}
_, err := io.WriteString(tw, data)
return err
}
// writeRawHeader writes the contents of blk, regardless of its validity.
// It sets up the Writer such that it can accept a file of the given size.
// If the flag is a special header-only flag, then the size is treated as zero.
func (tw *Writer) writeRawHeader(blk *block, size int64, flag byte) error {
if err := tw.Flush(); err != nil {
return err
}
if _, err := tw.w.Write(blk[:]); err != nil {
return err
}
if isHeaderOnlyType(flag) {
size = 0
}
tw.curr = &regFileWriter{tw.w, size}
tw.pad = blockPadding(size)
return nil
}
// splitUSTARPath splits a path according to USTAR prefix and suffix rules.
// If the path is not splittable, then it will return ("", "", false).
func splitUSTARPath(name string) (prefix, suffix string, ok bool) {
length := len(name)
if length <= nameSize || !isASCII(name) {
return "", "", false
} else if length > prefixSize+1 {
length = prefixSize + 1
} else if name[length-1] == '/' {
length--
}
i := strings.LastIndex(name[:length], "/")
nlen := len(name) - i - 1 // nlen is length of suffix
plen := i // plen is length of prefix
if i <= 0 || nlen > nameSize || nlen == 0 || plen > prefixSize {
return "", "", false
}
return name[:i], name[i+1:], true
}
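// For example, a (hypothetical) 120-byte ASCII path whose last '/'
// leaves a final element "file.txt" splits into a 111-byte prefix and
// the suffix "file.txt", since the suffix fits in nameSize (100) bytes
// and the prefix in prefixSize (155) bytes.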
// Write writes to the current file in the tar archive.
// Write returns the error ErrWriteTooLong if more than
// Header.Size bytes are written after WriteHeader.
//
// Calling Write on special types like TypeLink, TypeSymlink, TypeChar,
// TypeBlock, TypeDir, and TypeFifo returns (0, ErrWriteTooLong) regardless
// of what the Header.Size claims.
func (tw *Writer) Write(b []byte) (int, error) {
if tw.err != nil {
return 0, tw.err
}
n, err := tw.curr.Write(b)
if err != nil && err != ErrWriteTooLong {
tw.err = err
}
return n, err
}
// readFrom populates the content of the current file by reading from r.
// The bytes read must match the number of remaining bytes in the current file.
//
// If the current file is sparse and r is an io.ReadSeeker,
// then readFrom uses Seek to skip past holes defined in Header.SparseHoles,
// assuming that skipped regions are all NULs.
// This always reads the last byte to ensure r is the right size.
//
// TODO(dsnet): Re-export this when adding sparse file support.
// See https://golang.org/issue/22735
func (tw *Writer) readFrom(r io.Reader) (int64, error) {
if tw.err != nil {
return 0, tw.err
}
n, err := tw.curr.ReadFrom(r)
if err != nil && err != ErrWriteTooLong {
tw.err = err
}
return n, err
}
// Close closes the tar archive by flushing the padding, and writing the footer.
// If the current file (from a prior call to WriteHeader) is not fully written,
// then this returns an error.
func (tw *Writer) Close() error {
if tw.err == ErrWriteAfterClose {
return nil
}
if tw.err != nil {
return tw.err
}
// Trailer: two zero blocks.
err := tw.Flush()
for i := 0; i < 2 && err == nil; i++ {
_, err = tw.w.Write(zeroBlock[:])
}
// Ensure all future actions are invalid.
tw.err = ErrWriteAfterClose
return err // Report IO errors
}
// regFileWriter is a fileWriter for writing data to a regular file entry.
type regFileWriter struct {
w io.Writer // Underlying Writer
nb int64 // Number of remaining bytes to write
}
func (fw *regFileWriter) Write(b []byte) (n int, err error) {
overwrite := int64(len(b)) > fw.nb
if overwrite {
b = b[:fw.nb]
}
if len(b) > 0 {
n, err = fw.w.Write(b)
fw.nb -= int64(n)
}
switch {
case err != nil:
return n, err
case overwrite:
return n, ErrWriteTooLong
default:
return n, nil
}
}
func (fw *regFileWriter) ReadFrom(r io.Reader) (int64, error) {
// Wrap fw in an anonymous struct to hide its ReadFrom method,
// preventing io.Copy from recursing back into this function.
return io.Copy(struct{ io.Writer }{fw}, r)
}
// logicalRemaining implements fileState.logicalRemaining.
func (fw regFileWriter) logicalRemaining() int64 {
return fw.nb
}
// physicalRemaining implements fileState.physicalRemaining.
func (fw regFileWriter) physicalRemaining() int64 {
return fw.nb
}
// sparseFileWriter is a fileWriter for writing data to a sparse file entry.
type sparseFileWriter struct {
fw fileWriter // Underlying fileWriter
sp sparseDatas // Normalized list of data fragments
pos int64 // Current position in sparse file
}
func (sw *sparseFileWriter) Write(b []byte) (n int, err error) {
overwrite := int64(len(b)) > sw.logicalRemaining()
if overwrite {
b = b[:sw.logicalRemaining()]
}
b0 := b
endPos := sw.pos + int64(len(b))
for endPos > sw.pos && err == nil {
var nf int // Bytes written in fragment
dataStart, dataEnd := sw.sp[0].Offset, sw.sp[0].endOffset()
if sw.pos < dataStart { // In a hole fragment
bf := b[:min(int64(len(b)), dataStart-sw.pos)]
nf, err = zeroWriter{}.Write(bf)
} else { // In a data fragment
bf := b[:min(int64(len(b)), dataEnd-sw.pos)]
nf, err = sw.fw.Write(bf)
}
b = b[nf:]
sw.pos += int64(nf)
if sw.pos >= dataEnd && len(sw.sp) > 1 {
sw.sp = sw.sp[1:] // Ensure last fragment always remains
}
}
n = len(b0) - len(b)
switch {
case err == ErrWriteTooLong:
return n, errMissData // Not possible; implies bug in validation logic
case err != nil:
return n, err
case sw.logicalRemaining() == 0 && sw.physicalRemaining() > 0:
return n, errUnrefData // Not possible; implies bug in validation logic
case overwrite:
return n, ErrWriteTooLong
default:
return n, nil
}
}
func (sw *sparseFileWriter) ReadFrom(r io.Reader) (n int64, err error) {
rs, ok := r.(io.ReadSeeker)
if ok {
if _, err := rs.Seek(0, io.SeekCurrent); err != nil {
ok = false // Not all io.Seeker can really seek
}
}
if !ok {
return io.Copy(struct{ io.Writer }{sw}, r)
}
var readLastByte bool
pos0 := sw.pos
for sw.logicalRemaining() > 0 && !readLastByte && err == nil {
var nf int64 // Size of fragment
dataStart, dataEnd := sw.sp[0].Offset, sw.sp[0].endOffset()
if sw.pos < dataStart { // In a hole fragment
nf = dataStart - sw.pos
if sw.physicalRemaining() == 0 {
readLastByte = true
nf--
}
_, err = rs.Seek(nf, io.SeekCurrent)
} else { // In a data fragment
nf = dataEnd - sw.pos
nf, err = io.CopyN(sw.fw, rs, nf)
}
sw.pos += nf
if sw.pos >= dataEnd && len(sw.sp) > 1 {
sw.sp = sw.sp[1:] // Ensure last fragment always remains
}
}
// If the last fragment is a hole, then seek to 1 byte before EOF, and
// read a single byte to ensure the file is the right size.
if readLastByte && err == nil {
_, err = mustReadFull(rs, []byte{0})
sw.pos++
}
n = sw.pos - pos0
switch {
case err == io.EOF:
return n, io.ErrUnexpectedEOF
case err == ErrWriteTooLong:
return n, errMissData // Not possible; implies bug in validation logic
case err != nil:
return n, err
case sw.logicalRemaining() == 0 && sw.physicalRemaining() > 0:
return n, errUnrefData // Not possible; implies bug in validation logic
default:
return n, ensureEOF(rs)
}
}
func (sw sparseFileWriter) logicalRemaining() int64 {
return sw.sp[len(sw.sp)-1].endOffset() - sw.pos
}
func (sw sparseFileWriter) physicalRemaining() int64 {
return sw.fw.physicalRemaining()
}
// zeroWriter may only be written with NULs, otherwise it returns errWriteHole.
type zeroWriter struct{}
func (zeroWriter) Write(b []byte) (int, error) {
for i, c := range b {
if c != 0 {
return i, errWriteHole
}
}
return len(b), nil
}
// ensureEOF checks whether r is at EOF, reporting ErrWriteTooLong if it is not.
func ensureEOF(r io.Reader) error {
n, err := tryReadFull(r, []byte{0})
switch {
case n > 0:
return ErrWriteTooLong
case err == io.EOF:
return nil
default:
return err
}
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zip
import (
"bufio"
"encoding/binary"
"errors"
"hash"
"hash/crc32"
"internal/godebug"
"io"
"io/fs"
"os"
"path"
"path/filepath"
"sort"
"strings"
"sync"
"time"
)
var zipinsecurepath = godebug.New("zipinsecurepath")
var (
ErrFormat = errors.New("zip: not a valid zip file")
ErrAlgorithm = errors.New("zip: unsupported compression algorithm")
ErrChecksum = errors.New("zip: checksum error")
ErrInsecurePath = errors.New("zip: insecure file path")
)
// A Reader serves content from a ZIP archive.
type Reader struct {
r io.ReaderAt
File []*File
Comment string
decompressors map[uint16]Decompressor
// Some JAR files are zip files with a prefix that is a bash script.
// The baseOffset field is the start of the zip file proper.
baseOffset int64
// fileList is a list of files sorted by ename,
// for use by the Open method.
fileListOnce sync.Once
fileList []fileListEntry
}
// A ReadCloser is a Reader that must be closed when no longer needed.
type ReadCloser struct {
f *os.File
Reader
}
// A File is a single file in a ZIP archive.
// The file information is in the embedded FileHeader.
// The file content can be accessed by calling Open.
type File struct {
FileHeader
zip *Reader
zipr io.ReaderAt
headerOffset int64 // includes overall ZIP archive baseOffset
zip64 bool // zip64 extended information extra field presence
}
// OpenReader will open the Zip file specified by name and return a ReadCloser.
func OpenReader(name string) (*ReadCloser, error) {
f, err := os.Open(name)
if err != nil {
return nil, err
}
fi, err := f.Stat()
if err != nil {
f.Close()
return nil, err
}
r := new(ReadCloser)
if err := r.init(f, fi.Size()); err != nil {
f.Close()
return nil, err
}
r.f = f
return r, nil
}
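// A minimal client-side sketch of listing an archive
// (name is a hypothetical path; error handling abbreviated):
//
//	rc, err := zip.OpenReader(name)
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer rc.Close()
//	for _, f := range rc.File {
//		fmt.Println(f.Name, f.UncompressedSize64)
//	}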
// NewReader returns a new Reader reading from r, which is assumed to
// have the given size in bytes.
//
// If any file inside the archive uses a non-local name
// (as defined by [filepath.IsLocal]) or a name containing backslashes
// and the GODEBUG environment variable contains `zipinsecurepath=0`,
// NewReader returns the reader with an ErrInsecurePath error.
// A future version of Go may introduce this behavior by default.
// Programs that want to accept non-local names can ignore
// the ErrInsecurePath error and use the returned reader.
func NewReader(r io.ReaderAt, size int64) (*Reader, error) {
if size < 0 {
return nil, errors.New("zip: size cannot be negative")
}
zr := new(Reader)
if err := zr.init(r, size); err != nil {
return nil, err
}
for _, f := range zr.File {
if f.Name == "" {
// Zip permits an empty file name field.
continue
}
// The zip specification states that names must use forward slashes,
// so consider any backslashes in the name insecure.
if !filepath.IsLocal(f.Name) || strings.Contains(f.Name, `\`) {
if zipinsecurepath.Value() != "0" {
continue
}
zipinsecurepath.IncNonDefault()
return zr, ErrInsecurePath
}
}
return zr, nil
}
func (r *Reader) init(rdr io.ReaderAt, size int64) error {
end, baseOffset, err := readDirectoryEnd(rdr, size)
if err != nil {
return err
}
r.r = rdr
r.baseOffset = baseOffset
// Since the number of directory records is not validated, it is not
// safe to preallocate r.File without first checking that the specified
// number of files is reasonable: a malformed archive may claim to
// contain up to 1 << 128 - 1 files. Since each file has a header which
// will be _at least_ 30 bytes, we can safely preallocate
// if (data size / 30) >= end.directoryRecords.
if end.directorySize < uint64(size) && (uint64(size)-end.directorySize)/30 >= end.directoryRecords {
r.File = make([]*File, 0, end.directoryRecords)
}
r.Comment = end.comment
rs := io.NewSectionReader(rdr, 0, size)
if _, err = rs.Seek(r.baseOffset+int64(end.directoryOffset), io.SeekStart); err != nil {
return err
}
buf := bufio.NewReader(rs)
// The count of files inside a zip is truncated to fit in a uint16.
// Gloss over this by reading headers until we encounter
// a bad one, and then only report an ErrFormat or UnexpectedEOF if
// the file count modulo 65536 is incorrect.
for {
f := &File{zip: r, zipr: rdr}
err = readDirectoryHeader(f, buf)
if err == ErrFormat || err == io.ErrUnexpectedEOF {
break
}
if err != nil {
return err
}
f.headerOffset += r.baseOffset
r.File = append(r.File, f)
}
if uint16(len(r.File)) != uint16(end.directoryRecords) { // only compare 16 bits here
// Return the readDirectoryHeader error if we read
// the wrong number of directory entries.
return err
}
return nil
}
// RegisterDecompressor registers or overrides a custom decompressor for a
// specific method ID. If a decompressor for a given method is not found,
// Reader will default to looking up the decompressor at the package level.
func (r *Reader) RegisterDecompressor(method uint16, dcomp Decompressor) {
if r.decompressors == nil {
r.decompressors = make(map[uint16]Decompressor)
}
r.decompressors[method] = dcomp
}
func (r *Reader) decompressor(method uint16) Decompressor {
dcomp := r.decompressors[method]
if dcomp == nil {
dcomp = decompressor(method)
}
return dcomp
}
// Close closes the Zip file, rendering it unusable for I/O.
func (rc *ReadCloser) Close() error {
return rc.f.Close()
}
// DataOffset returns the offset of the file's possibly-compressed
// data, relative to the beginning of the zip file.
//
// Most callers should instead use Open, which transparently
// decompresses data and verifies checksums.
func (f *File) DataOffset() (offset int64, err error) {
bodyOffset, err := f.findBodyOffset()
if err != nil {
return
}
return f.headerOffset + bodyOffset, nil
}
// Open returns a ReadCloser that provides access to the File's contents.
// Multiple files may be read concurrently.
func (f *File) Open() (io.ReadCloser, error) {
bodyOffset, err := f.findBodyOffset()
if err != nil {
return nil, err
}
if strings.HasSuffix(f.Name, "/") {
// The ZIP specification (APPNOTE.TXT) specifies that directories, which
// are technically zero-byte files, must not have any associated file
// data. We previously tried failing here if f.CompressedSize64 != 0,
// but it turns out that a number of implementations (namely, the Java
// jar tool) don't properly set the storage method on directories
// resulting in a file with compressed size > 0 but uncompressed size ==
// 0. We still want to fail when a directory has associated uncompressed
// data, but we are tolerant of cases where the uncompressed size is
// zero but compressed size is not.
if f.UncompressedSize64 != 0 {
return &dirReader{ErrFormat}, nil
} else {
return &dirReader{io.EOF}, nil
}
}
size := int64(f.CompressedSize64)
r := io.NewSectionReader(f.zipr, f.headerOffset+bodyOffset, size)
dcomp := f.zip.decompressor(f.Method)
if dcomp == nil {
return nil, ErrAlgorithm
}
var rc io.ReadCloser = dcomp(r)
var desr io.Reader
if f.hasDataDescriptor() {
desr = io.NewSectionReader(f.zipr, f.headerOffset+bodyOffset+size, dataDescriptorLen)
}
rc = &checksumReader{
rc: rc,
hash: crc32.NewIEEE(),
f: f,
desr: desr,
}
return rc, nil
}
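// A minimal client-side sketch of reading one file's contents
// (f is a *zip.File obtained from Reader.File; error handling abbreviated):
//
//	r, err := f.Open()
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer r.Close()
//	if _, err := io.Copy(os.Stdout, r); err != nil {
//		log.Fatal(err)
//	}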
// OpenRaw returns a Reader that provides access to the File's contents without
// decompression.
func (f *File) OpenRaw() (io.Reader, error) {
bodyOffset, err := f.findBodyOffset()
if err != nil {
return nil, err
}
r := io.NewSectionReader(f.zipr, f.headerOffset+bodyOffset, int64(f.CompressedSize64))
return r, nil
}
type dirReader struct {
err error
}
func (r *dirReader) Read([]byte) (int, error) {
return 0, r.err
}
func (r *dirReader) Close() error {
return nil
}
type checksumReader struct {
rc io.ReadCloser
hash hash.Hash32
nread uint64 // number of bytes read so far
f *File
desr io.Reader // if non-nil, where to read the data descriptor
err error // sticky error
}
func (r *checksumReader) Stat() (fs.FileInfo, error) {
return headerFileInfo{&r.f.FileHeader}, nil
}
func (r *checksumReader) Read(b []byte) (n int, err error) {
if r.err != nil {
return 0, r.err
}
n, err = r.rc.Read(b)
r.hash.Write(b[:n])
r.nread += uint64(n)
if r.nread > r.f.UncompressedSize64 {
return 0, ErrFormat
}
if err == nil {
return
}
if err == io.EOF {
if r.nread != r.f.UncompressedSize64 {
return 0, io.ErrUnexpectedEOF
}
if r.desr != nil {
if err1 := readDataDescriptor(r.desr, r.f); err1 != nil {
if err1 == io.EOF {
err = io.ErrUnexpectedEOF
} else {
err = err1
}
} else if r.hash.Sum32() != r.f.CRC32 {
err = ErrChecksum
}
} else {
// If there's not a data descriptor, we still compare
// the CRC32 of what we've read against the file header
// or TOC's CRC32, if it seems like it was set.
if r.f.CRC32 != 0 && r.hash.Sum32() != r.f.CRC32 {
err = ErrChecksum
}
}
}
r.err = err
return
}
func (r *checksumReader) Close() error { return r.rc.Close() }
// findBodyOffset does the minimum work to verify the file has a header
// and returns the file body offset.
func (f *File) findBodyOffset() (int64, error) {
var buf [fileHeaderLen]byte
if _, err := f.zipr.ReadAt(buf[:], f.headerOffset); err != nil {
return 0, err
}
b := readBuf(buf[:])
if sig := b.uint32(); sig != fileHeaderSignature {
return 0, ErrFormat
}
b = b[22:] // skip over most of the header
filenameLen := int(b.uint16())
extraLen := int(b.uint16())
return int64(fileHeaderLen + filenameLen + extraLen), nil
}
// readDirectoryHeader attempts to read a directory header from r.
// It returns io.ErrUnexpectedEOF if it cannot read a complete header,
// and ErrFormat if it doesn't find a valid header signature.
func readDirectoryHeader(f *File, r io.Reader) error {
var buf [directoryHeaderLen]byte
if _, err := io.ReadFull(r, buf[:]); err != nil {
return err
}
b := readBuf(buf[:])
if sig := b.uint32(); sig != directoryHeaderSignature {
return ErrFormat
}
f.CreatorVersion = b.uint16()
f.ReaderVersion = b.uint16()
f.Flags = b.uint16()
f.Method = b.uint16()
f.ModifiedTime = b.uint16()
f.ModifiedDate = b.uint16()
f.CRC32 = b.uint32()
f.CompressedSize = b.uint32()
f.UncompressedSize = b.uint32()
f.CompressedSize64 = uint64(f.CompressedSize)
f.UncompressedSize64 = uint64(f.UncompressedSize)
filenameLen := int(b.uint16())
extraLen := int(b.uint16())
commentLen := int(b.uint16())
b = b[4:] // skipped start disk number and internal attributes (2x uint16)
f.ExternalAttrs = b.uint32()
f.headerOffset = int64(b.uint32())
d := make([]byte, filenameLen+extraLen+commentLen)
if _, err := io.ReadFull(r, d); err != nil {
return err
}
f.Name = string(d[:filenameLen])
f.Extra = d[filenameLen : filenameLen+extraLen]
f.Comment = string(d[filenameLen+extraLen:])
// Determine the character encoding.
utf8Valid1, utf8Require1 := detectUTF8(f.Name)
utf8Valid2, utf8Require2 := detectUTF8(f.Comment)
switch {
case !utf8Valid1 || !utf8Valid2:
// Name and Comment definitely not UTF-8.
f.NonUTF8 = true
case !utf8Require1 && !utf8Require2:
// Name and Comment use only single-byte runes that overlap with UTF-8.
f.NonUTF8 = false
default:
// Might be UTF-8, might be some other encoding; preserve existing flag.
// Some ZIP writers use UTF-8 encoding without setting the UTF-8 flag.
// Since it is impossible to always distinguish valid UTF-8 from some
// other encoding (e.g., GBK or Shift-JIS), we trust the flag.
f.NonUTF8 = f.Flags&0x800 == 0
}
needUSize := f.UncompressedSize == ^uint32(0)
needCSize := f.CompressedSize == ^uint32(0)
needHeaderOffset := f.headerOffset == int64(^uint32(0))
// Best effort to find what we need.
// Other zip authors might not even follow the basic format,
// and we'll just ignore the Extra content in that case.
var modified time.Time
parseExtras:
for extra := readBuf(f.Extra); len(extra) >= 4; { // need at least tag and size
fieldTag := extra.uint16()
fieldSize := int(extra.uint16())
if len(extra) < fieldSize {
break
}
fieldBuf := extra.sub(fieldSize)
switch fieldTag {
case zip64ExtraID:
f.zip64 = true
// update directory values from the zip64 extra block.
// They should only be consulted if the sizes read earlier
// are maxed out.
// See golang.org/issue/13367.
if needUSize {
needUSize = false
if len(fieldBuf) < 8 {
return ErrFormat
}
f.UncompressedSize64 = fieldBuf.uint64()
}
if needCSize {
needCSize = false
if len(fieldBuf) < 8 {
return ErrFormat
}
f.CompressedSize64 = fieldBuf.uint64()
}
if needHeaderOffset {
needHeaderOffset = false
if len(fieldBuf) < 8 {
return ErrFormat
}
f.headerOffset = int64(fieldBuf.uint64())
}
case ntfsExtraID:
if len(fieldBuf) < 4 {
continue parseExtras
}
fieldBuf.uint32() // reserved (ignored)
for len(fieldBuf) >= 4 { // need at least tag and size
attrTag := fieldBuf.uint16()
attrSize := int(fieldBuf.uint16())
if len(fieldBuf) < attrSize {
continue parseExtras
}
attrBuf := fieldBuf.sub(attrSize)
if attrTag != 1 || attrSize != 24 {
continue // Ignore irrelevant attributes
}
const ticksPerSecond = 1e7 // Windows timestamp resolution
ts := int64(attrBuf.uint64()) // ModTime since Windows epoch
secs := int64(ts / ticksPerSecond)
nsecs := (1e9 / ticksPerSecond) * int64(ts%ticksPerSecond)
epoch := time.Date(1601, time.January, 1, 0, 0, 0, 0, time.UTC)
modified = time.Unix(epoch.Unix()+secs, nsecs)
}
case unixExtraID, infoZipUnixExtraID:
if len(fieldBuf) < 8 {
continue parseExtras
}
fieldBuf.uint32() // AcTime (ignored)
ts := int64(fieldBuf.uint32()) // ModTime since Unix epoch
modified = time.Unix(ts, 0)
case extTimeExtraID:
if len(fieldBuf) < 5 || fieldBuf.uint8()&1 == 0 {
continue parseExtras
}
ts := int64(fieldBuf.uint32()) // ModTime since Unix epoch
modified = time.Unix(ts, 0)
}
}
msdosModified := msDosTimeToTime(f.ModifiedDate, f.ModifiedTime)
f.Modified = msdosModified
if !modified.IsZero() {
f.Modified = modified.UTC()
// If legacy MS-DOS timestamps are set, we can use the delta between
// the legacy and extended versions to estimate timezone offset.
//
// A non-UTC timezone is always used (even if offset is zero).
// Thus, FileHeader.Modified.Location() == time.UTC is useful for
// determining whether extended timestamps are present.
// This is necessary for users that need to do additional time
// calculations when dealing with legacy ZIP formats.
if f.ModifiedTime != 0 || f.ModifiedDate != 0 {
f.Modified = modified.In(timeZone(msdosModified.Sub(modified)))
}
}
// Assume that uncompressed size 2³²-1 could plausibly happen in
// an old zip32 file that was sharding inputs into the largest chunks
// possible (or is just malicious; search the web for 42.zip).
// If needUSize is true still, it means we didn't see a zip64 extension.
// As long as the compressed size is not also 2³²-1 (implausible)
// and the header is not also 2³²-1 (equally implausible),
// accept the uncompressed size 2³²-1 as valid.
// If nothing else, this keeps archive/zip working with 42.zip.
_ = needUSize
if needCSize || needHeaderOffset {
return ErrFormat
}
return nil
}
func readDataDescriptor(r io.Reader, f *File) error {
var buf [dataDescriptorLen]byte
// The spec says: "Although not originally assigned a
// signature, the value 0x08074b50 has commonly been adopted
// as a signature value for the data descriptor record.
// Implementers should be aware that ZIP files may be
// encountered with or without this signature marking data
// descriptors and should account for either case when reading
// ZIP files to ensure compatibility."
//
// dataDescriptorLen includes the size of the signature but
// first read just those 4 bytes to see if it exists.
if _, err := io.ReadFull(r, buf[:4]); err != nil {
return err
}
off := 0
maybeSig := readBuf(buf[:4])
if maybeSig.uint32() != dataDescriptorSignature {
// No data descriptor signature. Keep these four
// bytes.
off += 4
}
if _, err := io.ReadFull(r, buf[off:12]); err != nil {
return err
}
b := readBuf(buf[:12])
if b.uint32() != f.CRC32 {
return ErrChecksum
}
// The two sizes that follow here can be either 32 bits or 64 bits
// but the spec is not very clear on this and different
// interpretations have been made, causing incompatibilities. We
// already have the sizes from the central directory so we can
// just ignore these.
return nil
}
func readDirectoryEnd(r io.ReaderAt, size int64) (dir *directoryEnd, baseOffset int64, err error) {
// look for directoryEndSignature in the last 1k, then in the last 65k
var buf []byte
var directoryEndOffset int64
for i, bLen := range []int64{1024, 65 * 1024} {
if bLen > size {
bLen = size
}
buf = make([]byte, int(bLen))
if _, err := r.ReadAt(buf, size-bLen); err != nil && err != io.EOF {
return nil, 0, err
}
if p := findSignatureInBlock(buf); p >= 0 {
buf = buf[p:]
directoryEndOffset = size - bLen + int64(p)
break
}
if i == 1 || bLen == size {
return nil, 0, ErrFormat
}
}
// read header into struct
b := readBuf(buf[4:]) // skip signature
d := &directoryEnd{
diskNbr: uint32(b.uint16()),
dirDiskNbr: uint32(b.uint16()),
dirRecordsThisDisk: uint64(b.uint16()),
directoryRecords: uint64(b.uint16()),
directorySize: uint64(b.uint32()),
directoryOffset: uint64(b.uint32()),
commentLen: b.uint16(),
}
l := int(d.commentLen)
if l > len(b) {
return nil, 0, errors.New("zip: invalid comment length")
}
d.comment = string(b[:l])
// These values mean that the file can be a zip64 file
if d.directoryRecords == 0xffff || d.directorySize == 0xffff || d.directoryOffset == 0xffffffff {
p, err := findDirectory64End(r, directoryEndOffset)
if err == nil && p >= 0 {
directoryEndOffset = p
err = readDirectory64End(r, p, d)
}
if err != nil {
return nil, 0, err
}
}
baseOffset = directoryEndOffset - int64(d.directorySize) - int64(d.directoryOffset)
// Make sure directoryOffset points to somewhere in our file.
if o := baseOffset + int64(d.directoryOffset); o < 0 || o >= size {
return nil, 0, ErrFormat
}
// If the directory end data tells us to use a non-zero baseOffset,
// but we would find a valid directory entry if we assume that the
// baseOffset is 0, then just use a baseOffset of 0.
// We've seen files in which the directory end data gives us
// an incorrect baseOffset.
if baseOffset > 0 {
off := int64(d.directoryOffset)
rs := io.NewSectionReader(r, off, size-off)
if readDirectoryHeader(&File{}, rs) == nil {
baseOffset = 0
}
}
return d, baseOffset, nil
}
// findDirectory64End tries to read the zip64 locator just before the
// directory end and returns the offset of the zip64 directory end if
// found.
func findDirectory64End(r io.ReaderAt, directoryEndOffset int64) (int64, error) {
locOffset := directoryEndOffset - directory64LocLen
if locOffset < 0 {
return -1, nil // no need to look for a header outside the file
}
buf := make([]byte, directory64LocLen)
if _, err := r.ReadAt(buf, locOffset); err != nil {
return -1, err
}
b := readBuf(buf)
if sig := b.uint32(); sig != directory64LocSignature {
return -1, nil
}
if b.uint32() != 0 { // number of the disk with the start of the zip64 end of central directory
return -1, nil // the file is not a valid zip64-file
}
p := b.uint64() // relative offset of the zip64 end of central directory record
if b.uint32() != 1 { // total number of disks
return -1, nil // the file is not a valid zip64-file
}
return int64(p), nil
}
// readDirectory64End reads the zip64 directory end and updates the
// directory end with the zip64 directory end values.
func readDirectory64End(r io.ReaderAt, offset int64, d *directoryEnd) (err error) {
buf := make([]byte, directory64EndLen)
if _, err := r.ReadAt(buf, offset); err != nil {
return err
}
b := readBuf(buf)
if sig := b.uint32(); sig != directory64EndSignature {
return ErrFormat
}
b = b[12:] // skip dir size, version and version needed (uint64 + 2x uint16)
d.diskNbr = b.uint32() // number of this disk
d.dirDiskNbr = b.uint32() // number of the disk with the start of the central directory
d.dirRecordsThisDisk = b.uint64() // total number of entries in the central directory on this disk
d.directoryRecords = b.uint64() // total number of entries in the central directory
d.directorySize = b.uint64() // size of the central directory
d.directoryOffset = b.uint64() // offset of start of central directory with respect to the starting disk number
return nil
}
func findSignatureInBlock(b []byte) int {
for i := len(b) - directoryEndLen; i >= 0; i-- {
// defined from directoryEndSignature in struct.go
if b[i] == 'P' && b[i+1] == 'K' && b[i+2] == 0x05 && b[i+3] == 0x06 {
// n is length of comment
n := int(b[i+directoryEndLen-2]) | int(b[i+directoryEndLen-1])<<8
if n+directoryEndLen+i <= len(b) {
return i
}
}
}
return -1
}
type readBuf []byte
func (b *readBuf) uint8() uint8 {
v := (*b)[0]
*b = (*b)[1:]
return v
}
func (b *readBuf) uint16() uint16 {
v := binary.LittleEndian.Uint16(*b)
*b = (*b)[2:]
return v
}
func (b *readBuf) uint32() uint32 {
v := binary.LittleEndian.Uint32(*b)
*b = (*b)[4:]
return v
}
func (b *readBuf) uint64() uint64 {
v := binary.LittleEndian.Uint64(*b)
*b = (*b)[8:]
return v
}
func (b *readBuf) sub(n int) readBuf {
b2 := (*b)[:n]
*b = (*b)[n:]
return b2
}
// A fileListEntry is a File and its ename.
// If file == nil, the fileListEntry describes a directory without metadata.
type fileListEntry struct {
name string
file *File
isDir bool
isDup bool
}
type fileInfoDirEntry interface {
fs.FileInfo
fs.DirEntry
}
func (f *fileListEntry) stat() (fileInfoDirEntry, error) {
if f.isDup {
return nil, errors.New(f.name + ": duplicate entries in zip file")
}
if !f.isDir {
return headerFileInfo{&f.file.FileHeader}, nil
}
return f, nil
}
// Only used for directories.
func (f *fileListEntry) Name() string { _, elem, _ := split(f.name); return elem }
func (f *fileListEntry) Size() int64 { return 0 }
func (f *fileListEntry) Mode() fs.FileMode { return fs.ModeDir | 0555 }
func (f *fileListEntry) Type() fs.FileMode { return fs.ModeDir }
func (f *fileListEntry) IsDir() bool { return true }
func (f *fileListEntry) Sys() any { return nil }
func (f *fileListEntry) ModTime() time.Time {
if f.file == nil {
return time.Time{}
}
return f.file.FileHeader.Modified.UTC()
}
func (f *fileListEntry) Info() (fs.FileInfo, error) { return f, nil }
// toValidName coerces name to be a valid name for fs.FS.Open.
func toValidName(name string) string {
name = strings.ReplaceAll(name, `\`, `/`)
p := path.Clean(name)
p = strings.TrimPrefix(p, "/")
for strings.HasPrefix(p, "../") {
p = p[len("../"):]
}
return p
}
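// For example, the (hypothetical) stored name `..\dir\file.txt` becomes
// "dir/file.txt": backslashes are normalized to slashes, the path is
// cleaned, and any leading "/" and "../" elements are stripped.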
func (r *Reader) initFileList() {
r.fileListOnce.Do(func() {
// files and knownDirs map from a file/directory name
// to an index into the r.fileList slice that we are
// building. They are used to mark duplicate entries.
files := make(map[string]int)
knownDirs := make(map[string]int)
// dirs[name] is true if name is known to be a directory,
// because it appears as a prefix in a path.
dirs := make(map[string]bool)
for _, file := range r.File {
isDir := len(file.Name) > 0 && file.Name[len(file.Name)-1] == '/'
name := toValidName(file.Name)
if name == "" {
continue
}
if idx, ok := files[name]; ok {
r.fileList[idx].isDup = true
continue
}
if idx, ok := knownDirs[name]; ok {
r.fileList[idx].isDup = true
continue
}
for dir := path.Dir(name); dir != "."; dir = path.Dir(dir) {
dirs[dir] = true
}
idx := len(r.fileList)
entry := fileListEntry{
name: name,
file: file,
isDir: isDir,
}
r.fileList = append(r.fileList, entry)
if isDir {
knownDirs[name] = idx
} else {
files[name] = idx
}
}
for dir := range dirs {
if _, ok := knownDirs[dir]; !ok {
if idx, ok := files[dir]; ok {
r.fileList[idx].isDup = true
} else {
entry := fileListEntry{
name: dir,
file: nil,
isDir: true,
}
r.fileList = append(r.fileList, entry)
}
}
}
sort.Slice(r.fileList, func(i, j int) bool { return fileEntryLess(r.fileList[i].name, r.fileList[j].name) })
})
}
func fileEntryLess(x, y string) bool {
xdir, xelem, _ := split(x)
ydir, yelem, _ := split(y)
return xdir < ydir || xdir == ydir && xelem < yelem
}
// Open opens the named file in the ZIP archive,
// using the semantics of fs.FS.Open:
// paths are always slash separated, with no
// leading / or ../ elements.
func (r *Reader) Open(name string) (fs.File, error) {
r.initFileList()
if !fs.ValidPath(name) {
return nil, &fs.PathError{Op: "open", Path: name, Err: fs.ErrInvalid}
}
e := r.openLookup(name)
if e == nil {
return nil, &fs.PathError{Op: "open", Path: name, Err: fs.ErrNotExist}
}
if e.isDir {
return &openDir{e, r.openReadDir(name), 0}, nil
}
rc, err := e.file.Open()
if err != nil {
return nil, err
}
return rc.(fs.File), nil
}
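// A minimal usage sketch for Open via the fs.FS interface. The archive name
// and the log/fmt error handling are illustrative, not part of this package:
//
//	zr, err := zip.OpenReader("archive.zip")
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer zr.Close()
//	// *Reader implements fs.FS, so the io/fs helpers work directly.
//	err = fs.WalkDir(zr, ".", func(p string, d fs.DirEntry, err error) error {
//		if err != nil {
//			return err
//		}
//		fmt.Println(p, d.IsDir())
//		return nil
//	})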
func split(name string) (dir, elem string, isDir bool) {
if len(name) > 0 && name[len(name)-1] == '/' {
isDir = true
name = name[:len(name)-1]
}
i := len(name) - 1
for i >= 0 && name[i] != '/' {
i--
}
if i < 0 {
return ".", name, isDir
}
return name[:i], name[i+1:], isDir
}
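// Examples of how split decomposes entry names (illustrative):
//
//	split("a/b/c") // dir="a/b", elem="c",   isDir=false
//	split("a/b/")  // dir="a",   elem="b",   isDir=true
//	split("top")   // dir=".",   elem="top", isDir=false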
var dotFile = &fileListEntry{name: "./", isDir: true}
func (r *Reader) openLookup(name string) *fileListEntry {
if name == "." {
return dotFile
}
dir, elem, _ := split(name)
files := r.fileList
i := sort.Search(len(files), func(i int) bool {
idir, ielem, _ := split(files[i].name)
return idir > dir || idir == dir && ielem >= elem
})
if i < len(files) {
fname := files[i].name
if fname == name || len(fname) == len(name)+1 && fname[len(name)] == '/' && fname[:len(name)] == name {
return &files[i]
}
}
return nil
}
func (r *Reader) openReadDir(dir string) []fileListEntry {
files := r.fileList
i := sort.Search(len(files), func(i int) bool {
idir, _, _ := split(files[i].name)
return idir >= dir
})
j := sort.Search(len(files), func(j int) bool {
jdir, _, _ := split(files[j].name)
return jdir > dir
})
return files[i:j]
}
type openDir struct {
e *fileListEntry
files []fileListEntry
offset int
}
func (d *openDir) Close() error { return nil }
func (d *openDir) Stat() (fs.FileInfo, error) { return d.e.stat() }
func (d *openDir) Read([]byte) (int, error) {
return 0, &fs.PathError{Op: "read", Path: d.e.name, Err: errors.New("is a directory")}
}
func (d *openDir) ReadDir(count int) ([]fs.DirEntry, error) {
n := len(d.files) - d.offset
if count > 0 && n > count {
n = count
}
if n == 0 {
if count <= 0 {
return nil, nil
}
return nil, io.EOF
}
list := make([]fs.DirEntry, n)
for i := range list {
s, err := d.files[d.offset+i].stat()
if err != nil {
return nil, err
}
list[i] = s
}
d.offset += n
return list, nil
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zip
import (
"compress/flate"
"errors"
"io"
"sync"
)
// A Compressor returns a new compressing writer, writing to w.
// The WriteCloser's Close method must be used to flush pending data to w.
// The Compressor itself must be safe to invoke from multiple goroutines
// simultaneously, but each returned writer will be used only by
// one goroutine at a time.
type Compressor func(w io.Writer) (io.WriteCloser, error)
// A Decompressor returns a new decompressing reader, reading from r.
// The ReadCloser's Close method must be used to release associated resources.
// The Decompressor itself must be safe to invoke from multiple goroutines
// simultaneously, but each returned reader will be used only by
// one goroutine at a time.
type Decompressor func(r io.Reader) io.ReadCloser
var flateWriterPool sync.Pool
func newFlateWriter(w io.Writer) io.WriteCloser {
fw, ok := flateWriterPool.Get().(*flate.Writer)
if ok {
fw.Reset(w)
} else {
fw, _ = flate.NewWriter(w, 5)
}
return &pooledFlateWriter{fw: fw}
}
type pooledFlateWriter struct {
mu sync.Mutex // guards Close and Write
fw *flate.Writer
}
func (w *pooledFlateWriter) Write(p []byte) (n int, err error) {
w.mu.Lock()
defer w.mu.Unlock()
if w.fw == nil {
return 0, errors.New("Write after Close")
}
return w.fw.Write(p)
}
func (w *pooledFlateWriter) Close() error {
w.mu.Lock()
defer w.mu.Unlock()
var err error
if w.fw != nil {
err = w.fw.Close()
flateWriterPool.Put(w.fw)
w.fw = nil
}
return err
}
var flateReaderPool sync.Pool
func newFlateReader(r io.Reader) io.ReadCloser {
fr, ok := flateReaderPool.Get().(io.ReadCloser)
if ok {
fr.(flate.Resetter).Reset(r, nil)
} else {
fr = flate.NewReader(r)
}
return &pooledFlateReader{fr: fr}
}
type pooledFlateReader struct {
mu sync.Mutex // guards Close and Read
fr io.ReadCloser
}
func (r *pooledFlateReader) Read(p []byte) (n int, err error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.fr == nil {
return 0, errors.New("Read after Close")
}
return r.fr.Read(p)
}
func (r *pooledFlateReader) Close() error {
r.mu.Lock()
defer r.mu.Unlock()
var err error
if r.fr != nil {
err = r.fr.Close()
flateReaderPool.Put(r.fr)
r.fr = nil
}
return err
}
var (
compressors sync.Map // map[uint16]Compressor
decompressors sync.Map // map[uint16]Decompressor
)
func init() {
compressors.Store(Store, Compressor(func(w io.Writer) (io.WriteCloser, error) { return &nopCloser{w}, nil }))
compressors.Store(Deflate, Compressor(func(w io.Writer) (io.WriteCloser, error) { return newFlateWriter(w), nil }))
decompressors.Store(Store, Decompressor(io.NopCloser))
decompressors.Store(Deflate, Decompressor(newFlateReader))
}
// RegisterDecompressor allows custom decompressors for a specified method ID.
// The common methods Store and Deflate are built in.
func RegisterDecompressor(method uint16, dcomp Decompressor) {
if _, dup := decompressors.LoadOrStore(method, dcomp); dup {
panic("decompressor already registered")
}
}
// RegisterCompressor registers custom compressors for a specified method ID.
// The common methods Store and Deflate are built in.
func RegisterCompressor(method uint16, comp Compressor) {
if _, dup := compressors.LoadOrStore(method, comp); dup {
panic("compressor already registered")
}
}
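// A sketch of wiring up a custom method. The method ID 0x7777 and the
// xcodec package are hypothetical stand-ins for a real codec:
//
//	const methodX = 0x7777 // hypothetical vendor-specific method ID
//	zip.RegisterCompressor(methodX, func(w io.Writer) (io.WriteCloser, error) {
//		return xcodec.NewWriter(w), nil // assumed constructor
//	})
//	zip.RegisterDecompressor(methodX, func(r io.Reader) io.ReadCloser {
//		return xcodec.NewReader(r) // assumed constructor
//	})
//
// Registering the same method twice on either side panics, so packages
// typically do this once in an init function.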
func compressor(method uint16) Compressor {
ci, ok := compressors.Load(method)
if !ok {
return nil
}
return ci.(Compressor)
}
func decompressor(method uint16) Decompressor {
di, ok := decompressors.Load(method)
if !ok {
return nil
}
return di.(Decompressor)
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package zip provides support for reading and writing ZIP archives.
See the [ZIP specification] for details.
This package does not support disk spanning.
A note about ZIP64:
To be backwards compatible, the FileHeader has both 32 and 64 bit Size
fields. The 64 bit fields will always contain the correct value, and
for normal archives both fields will be the same. For files requiring
the ZIP64 format, the 32 bit fields will be 0xffffffff and the 64 bit
fields must be used instead.
[ZIP specification]: https://www.pkware.com/appnote
*/
package zip
import (
"io/fs"
"path"
"time"
)
// Compression methods.
const (
Store uint16 = 0 // no compression
Deflate uint16 = 8 // DEFLATE compressed
)
const (
fileHeaderSignature = 0x04034b50
directoryHeaderSignature = 0x02014b50
directoryEndSignature = 0x06054b50
directory64LocSignature = 0x07064b50
directory64EndSignature = 0x06064b50
dataDescriptorSignature = 0x08074b50 // de-facto standard; required by OS X Finder
fileHeaderLen = 30 // + filename + extra
directoryHeaderLen = 46 // + filename + extra + comment
directoryEndLen = 22 // + comment
dataDescriptorLen = 16 // four uint32: descriptor signature, crc32, compressed size, size
dataDescriptor64Len = 24 // two uint32: signature, crc32 | two uint64: compressed size, size
directory64LocLen = 20 //
directory64EndLen = 56 // + extra
// Constants for the first byte in CreatorVersion.
creatorFAT = 0
creatorUnix = 3
creatorNTFS = 11
creatorVFAT = 14
creatorMacOSX = 19
// Version numbers.
zipVersion20 = 20 // 2.0
zipVersion45 = 45 // 4.5 (reads and writes zip64 archives)
// Limits for non zip64 files.
uint16max = (1 << 16) - 1
uint32max = (1 << 32) - 1
// Extra header IDs.
//
// IDs 0..31 are reserved for official use by PKWARE.
// IDs above that range are defined by third-party vendors.
// Since ZIP lacks high precision timestamps (and has no official
// specification of the timezone used for the date fields), many competing
// extra fields have been invented. Pervasive use effectively makes them
// "official".
//
// See http://mdfs.net/Docs/Comp/Archiving/Zip/ExtraField
zip64ExtraID = 0x0001 // Zip64 extended information
ntfsExtraID = 0x000a // NTFS
unixExtraID = 0x000d // UNIX
extTimeExtraID = 0x5455 // Extended timestamp
infoZipUnixExtraID = 0x5855 // Info-ZIP Unix extension
)
// FileHeader describes a file within a ZIP file.
// See the [ZIP specification] for details.
//
// [ZIP specification]: https://www.pkware.com/appnote
type FileHeader struct {
// Name is the name of the file.
//
// It must be a relative path, not start with a drive letter (such as "C:"),
// and must use forward slashes instead of back slashes. A trailing slash
// indicates that this file is a directory and should have no data.
Name string
// Comment is any arbitrary user-defined string shorter than 64KiB.
Comment string
// NonUTF8 indicates that Name and Comment are not encoded in UTF-8.
//
// By specification, the only other encoding permitted should be CP-437,
// but historically many ZIP readers interpret Name and Comment as whatever
// the system's local character encoding happens to be.
//
// This flag should only be set if the user intends to encode a non-portable
// ZIP file for a specific localized region. Otherwise, the Writer
// automatically sets the ZIP format's UTF-8 flag for valid UTF-8 strings.
NonUTF8 bool
CreatorVersion uint16
ReaderVersion uint16
Flags uint16
// Method is the compression method. If zero, Store is used.
Method uint16
// Modified is the modified time of the file.
//
// When reading, an extended timestamp is preferred over the legacy MS-DOS
// date field, and the offset between the times is used as the timezone.
// If only the MS-DOS date is present, the timezone is assumed to be UTC.
//
// When writing, an extended timestamp (which is timezone-agnostic) is
// always emitted. The legacy MS-DOS date field is encoded according to the
// location of the Modified time.
Modified time.Time
// ModifiedTime is an MS-DOS-encoded time.
//
// Deprecated: Use Modified instead.
ModifiedTime uint16
// ModifiedDate is an MS-DOS-encoded date.
//
// Deprecated: Use Modified instead.
ModifiedDate uint16
// CRC32 is the CRC32 checksum of the file content.
CRC32 uint32
// CompressedSize is the compressed size of the file in bytes.
// If either the uncompressed or compressed size of the file
// does not fit in 32 bits, CompressedSize is set to ^uint32(0).
//
// Deprecated: Use CompressedSize64 instead.
CompressedSize uint32
// UncompressedSize is the uncompressed size of the file in bytes.
// If either the uncompressed or compressed size of the file
// does not fit in 32 bits, UncompressedSize is set to ^uint32(0).
//
// Deprecated: Use UncompressedSize64 instead.
UncompressedSize uint32
// CompressedSize64 is the compressed size of the file in bytes.
CompressedSize64 uint64
// UncompressedSize64 is the uncompressed size of the file in bytes.
UncompressedSize64 uint64
Extra []byte
ExternalAttrs uint32 // Meaning depends on CreatorVersion
}
// FileInfo returns an fs.FileInfo for the FileHeader.
func (h *FileHeader) FileInfo() fs.FileInfo {
return headerFileInfo{h}
}
// headerFileInfo implements fs.FileInfo.
type headerFileInfo struct {
fh *FileHeader
}
func (fi headerFileInfo) Name() string { return path.Base(fi.fh.Name) }
func (fi headerFileInfo) Size() int64 {
if fi.fh.UncompressedSize64 > 0 {
return int64(fi.fh.UncompressedSize64)
}
return int64(fi.fh.UncompressedSize)
}
func (fi headerFileInfo) IsDir() bool { return fi.Mode().IsDir() }
func (fi headerFileInfo) ModTime() time.Time {
if fi.fh.Modified.IsZero() {
return fi.fh.ModTime()
}
return fi.fh.Modified.UTC()
}
func (fi headerFileInfo) Mode() fs.FileMode { return fi.fh.Mode() }
func (fi headerFileInfo) Type() fs.FileMode { return fi.fh.Mode().Type() }
func (fi headerFileInfo) Sys() any { return fi.fh }
func (fi headerFileInfo) Info() (fs.FileInfo, error) { return fi, nil }
// FileInfoHeader creates a partially-populated FileHeader from an
// fs.FileInfo.
// Because fs.FileInfo's Name method returns only the base name of
// the file it describes, it may be necessary to modify the Name field
// of the returned header to provide the full path name of the file.
// If compression is desired, callers should set the FileHeader.Method
// field; it is unset by default.
func FileInfoHeader(fi fs.FileInfo) (*FileHeader, error) {
size := fi.Size()
fh := &FileHeader{
Name: fi.Name(),
UncompressedSize64: uint64(size),
}
fh.SetModTime(fi.ModTime())
fh.SetMode(fi.Mode())
if fh.UncompressedSize64 > uint32max {
fh.UncompressedSize = uint32max
} else {
fh.UncompressedSize = uint32(fh.UncompressedSize64)
}
return fh, nil
}
type directoryEnd struct {
diskNbr uint32 // unused
dirDiskNbr uint32 // unused
dirRecordsThisDisk uint64 // unused
directoryRecords uint64
directorySize uint64
directoryOffset uint64 // relative to file
commentLen uint16
comment string
}
// timeZone returns a *time.Location based on the provided offset.
// If the offset is not sensible, then an offset of zero is used.
func timeZone(offset time.Duration) *time.Location {
const (
minOffset = -12 * time.Hour // E.g., Baker Island at -12:00
maxOffset = +14 * time.Hour // E.g., Line Islands at +14:00
offsetAlias = 15 * time.Minute // E.g., Nepal at +5:45
)
offset = offset.Round(offsetAlias)
if offset < minOffset || maxOffset < offset {
offset = 0
}
return time.FixedZone("", int(offset/time.Second))
}
// msDosTimeToTime converts an MS-DOS date and time into a time.Time.
// The resolution is 2s.
// See: https://msdn.microsoft.com/en-us/library/ms724247(v=VS.85).aspx
func msDosTimeToTime(dosDate, dosTime uint16) time.Time {
return time.Date(
// date bits 0-4: day of month; 5-8: month; 9-15: years since 1980
int(dosDate>>9+1980),
time.Month(dosDate>>5&0xf),
int(dosDate&0x1f),
// time bits 0-4: second/2; 5-10: minute; 11-15: hour
int(dosTime>>11),
int(dosTime>>5&0x3f),
int(dosTime&0x1f*2),
0, // nanoseconds
time.UTC,
)
}
// timeToMsDosTime converts a time.Time to an MS-DOS date and time.
// The resolution is 2s.
// See: https://msdn.microsoft.com/en-us/library/ms724274(v=VS.85).aspx
func timeToMsDosTime(t time.Time) (fDate uint16, fTime uint16) {
fDate = uint16(t.Day() + int(t.Month())<<5 + (t.Year()-1980)<<9)
fTime = uint16(t.Second()/2 + t.Minute()<<5 + t.Hour()<<11)
return
}
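// A worked example of the encoding (arithmetic checked by hand, not part
// of the original file): for t = 2009-11-10 23:00:00 UTC,
//
//	fDate = 10 + 11<<5 + (2009-1980)<<9 = 15210 (0x3b6a)
//	fTime = 0/2 + 0<<5 + 23<<11         = 47104 (0xb800)
//
// and msDosTimeToTime(0x3b6a, 0xb800) recovers the same instant at the
// format's 2-second resolution.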
// ModTime returns the modification time in UTC using the legacy
// ModifiedDate and ModifiedTime fields.
//
// Deprecated: Use Modified instead.
func (h *FileHeader) ModTime() time.Time {
return msDosTimeToTime(h.ModifiedDate, h.ModifiedTime)
}
// SetModTime sets the Modified, ModifiedTime, and ModifiedDate fields
// to the given time in UTC.
//
// Deprecated: Use Modified instead.
func (h *FileHeader) SetModTime(t time.Time) {
t = t.UTC() // Convert to UTC for compatibility
h.Modified = t
h.ModifiedDate, h.ModifiedTime = timeToMsDosTime(t)
}
const (
// Unix constants. The specification doesn't mention them,
// but these seem to be the values agreed on by tools.
s_IFMT = 0xf000
s_IFSOCK = 0xc000
s_IFLNK = 0xa000
s_IFREG = 0x8000
s_IFBLK = 0x6000
s_IFDIR = 0x4000
s_IFCHR = 0x2000
s_IFIFO = 0x1000
s_ISUID = 0x800
s_ISGID = 0x400
s_ISVTX = 0x200
msdosDir = 0x10
msdosReadOnly = 0x01
)
// Mode returns the permission and mode bits for the FileHeader.
func (h *FileHeader) Mode() (mode fs.FileMode) {
switch h.CreatorVersion >> 8 {
case creatorUnix, creatorMacOSX:
mode = unixModeToFileMode(h.ExternalAttrs >> 16)
case creatorNTFS, creatorVFAT, creatorFAT:
mode = msdosModeToFileMode(h.ExternalAttrs)
}
if len(h.Name) > 0 && h.Name[len(h.Name)-1] == '/' {
mode |= fs.ModeDir
}
return mode
}
// SetMode changes the permission and mode bits for the FileHeader.
func (h *FileHeader) SetMode(mode fs.FileMode) {
h.CreatorVersion = h.CreatorVersion&0xff | creatorUnix<<8
h.ExternalAttrs = fileModeToUnixMode(mode) << 16
// set MSDOS attributes too, as the original zip does.
if mode&fs.ModeDir != 0 {
h.ExternalAttrs |= msdosDir
}
if mode&0200 == 0 {
h.ExternalAttrs |= msdosReadOnly
}
}
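// For example (a sketch of the arithmetic above), SetMode(fs.ModeDir|0755)
// stores the Unix mode s_IFDIR|0755 = 0x41ed in the high 16 bits and also
// sets the MSDOS directory bit:
//
//	h.ExternalAttrs = 0x41ed<<16 | msdosDir // 0x41ed0010
//
// With CreatorVersion's high byte set to creatorUnix, Mode() then recovers
// fs.ModeDir | 0755 via unixModeToFileMode.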
// isZip64 reports whether the file size exceeds the 32 bit limit.
func (h *FileHeader) isZip64() bool {
return h.CompressedSize64 >= uint32max || h.UncompressedSize64 >= uint32max
}
func (h *FileHeader) hasDataDescriptor() bool {
return h.Flags&0x8 != 0
}
func msdosModeToFileMode(m uint32) (mode fs.FileMode) {
if m&msdosDir != 0 {
mode = fs.ModeDir | 0777
} else {
mode = 0666
}
if m&msdosReadOnly != 0 {
mode &^= 0222
}
return mode
}
func fileModeToUnixMode(mode fs.FileMode) uint32 {
var m uint32
switch mode & fs.ModeType {
default:
m = s_IFREG
case fs.ModeDir:
m = s_IFDIR
case fs.ModeSymlink:
m = s_IFLNK
case fs.ModeNamedPipe:
m = s_IFIFO
case fs.ModeSocket:
m = s_IFSOCK
case fs.ModeDevice:
m = s_IFBLK
case fs.ModeDevice | fs.ModeCharDevice:
m = s_IFCHR
}
if mode&fs.ModeSetuid != 0 {
m |= s_ISUID
}
if mode&fs.ModeSetgid != 0 {
m |= s_ISGID
}
if mode&fs.ModeSticky != 0 {
m |= s_ISVTX
}
return m | uint32(mode&0777)
}
func unixModeToFileMode(m uint32) fs.FileMode {
mode := fs.FileMode(m & 0777)
switch m & s_IFMT {
case s_IFBLK:
mode |= fs.ModeDevice
case s_IFCHR:
mode |= fs.ModeDevice | fs.ModeCharDevice
case s_IFDIR:
mode |= fs.ModeDir
case s_IFIFO:
mode |= fs.ModeNamedPipe
case s_IFLNK:
mode |= fs.ModeSymlink
case s_IFREG:
// nothing to do
case s_IFSOCK:
mode |= fs.ModeSocket
}
if m&s_ISGID != 0 {
mode |= fs.ModeSetgid
}
if m&s_ISUID != 0 {
mode |= fs.ModeSetuid
}
if m&s_ISVTX != 0 {
mode |= fs.ModeSticky
}
return mode
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zip
import (
"bufio"
"encoding/binary"
"errors"
"hash"
"hash/crc32"
"io"
"strings"
"unicode/utf8"
)
var (
errLongName = errors.New("zip: FileHeader.Name too long")
errLongExtra = errors.New("zip: FileHeader.Extra too long")
)
// Writer implements a zip file writer.
type Writer struct {
cw *countWriter
dir []*header
last *fileWriter
closed bool
compressors map[uint16]Compressor
comment string
// testHookCloseSizeOffset, if non-nil, is called with the size
// and offset of the central directory at Close.
testHookCloseSizeOffset func(size, offset uint64)
}
type header struct {
*FileHeader
offset uint64
raw bool
}
// NewWriter returns a new Writer writing a zip file to w.
func NewWriter(w io.Writer) *Writer {
return &Writer{cw: &countWriter{w: bufio.NewWriter(w)}}
}
// SetOffset sets the offset of the beginning of the zip data within the
// underlying writer. It should be used when the zip data is appended to an
// existing file, such as a binary executable.
// It must be called before any data is written.
func (w *Writer) SetOffset(n int64) {
if w.cw.count != 0 {
panic("zip: SetOffset called after data was written")
}
w.cw.count = n
}
// Flush flushes any buffered data to the underlying writer.
// Calling Flush is not normally necessary; calling Close is sufficient.
func (w *Writer) Flush() error {
return w.cw.w.(*bufio.Writer).Flush()
}
// SetComment sets the end-of-central-directory comment field.
// It can only be called before Close.
func (w *Writer) SetComment(comment string) error {
if len(comment) > uint16max {
return errors.New("zip: Writer.Comment too long")
}
w.comment = comment
return nil
}
// Close finishes writing the zip file by writing the central directory.
// It does not close the underlying writer.
func (w *Writer) Close() error {
if w.last != nil && !w.last.closed {
if err := w.last.close(); err != nil {
return err
}
w.last = nil
}
if w.closed {
return errors.New("zip: writer closed twice")
}
w.closed = true
// write central directory
start := w.cw.count
for _, h := range w.dir {
var buf [directoryHeaderLen]byte
b := writeBuf(buf[:])
b.uint32(uint32(directoryHeaderSignature))
b.uint16(h.CreatorVersion)
b.uint16(h.ReaderVersion)
b.uint16(h.Flags)
b.uint16(h.Method)
b.uint16(h.ModifiedTime)
b.uint16(h.ModifiedDate)
b.uint32(h.CRC32)
if h.isZip64() || h.offset >= uint32max {
// the file needs a zip64 header. store the maximum uint32 value in
// both 32 bit size fields (and the offset later) to signal that the
// zip64 extra block should be used.
b.uint32(uint32max) // compressed size
b.uint32(uint32max) // uncompressed size
// append a zip64 extra block to Extra
var buf [28]byte // 2x uint16 + 3x uint64
eb := writeBuf(buf[:])
eb.uint16(zip64ExtraID)
eb.uint16(24) // size = 3x uint64
eb.uint64(h.UncompressedSize64)
eb.uint64(h.CompressedSize64)
eb.uint64(h.offset)
h.Extra = append(h.Extra, buf[:]...)
} else {
b.uint32(h.CompressedSize)
b.uint32(h.UncompressedSize)
}
b.uint16(uint16(len(h.Name)))
b.uint16(uint16(len(h.Extra)))
b.uint16(uint16(len(h.Comment)))
b = b[4:] // skip disk number start and internal file attr (2x uint16)
b.uint32(h.ExternalAttrs)
if h.offset > uint32max {
b.uint32(uint32max)
} else {
b.uint32(uint32(h.offset))
}
if _, err := w.cw.Write(buf[:]); err != nil {
return err
}
if _, err := io.WriteString(w.cw, h.Name); err != nil {
return err
}
if _, err := w.cw.Write(h.Extra); err != nil {
return err
}
if _, err := io.WriteString(w.cw, h.Comment); err != nil {
return err
}
}
end := w.cw.count
records := uint64(len(w.dir))
size := uint64(end - start)
offset := uint64(start)
if f := w.testHookCloseSizeOffset; f != nil {
f(size, offset)
}
if records >= uint16max || size >= uint32max || offset >= uint32max {
var buf [directory64EndLen + directory64LocLen]byte
b := writeBuf(buf[:])
// zip64 end of central directory record
b.uint32(directory64EndSignature)
b.uint64(directory64EndLen - 12) // length minus signature (uint32) and length fields (uint64)
b.uint16(zipVersion45) // version made by
b.uint16(zipVersion45) // version needed to extract
b.uint32(0) // number of this disk
b.uint32(0) // number of the disk with the start of the central directory
b.uint64(records) // total number of entries in the central directory on this disk
b.uint64(records) // total number of entries in the central directory
b.uint64(size) // size of the central directory
b.uint64(offset) // offset of start of central directory with respect to the starting disk number
// zip64 end of central directory locator
b.uint32(directory64LocSignature)
b.uint32(0) // number of the disk with the start of the zip64 end of central directory
b.uint64(uint64(end)) // relative offset of the zip64 end of central directory record
b.uint32(1) // total number of disks
if _, err := w.cw.Write(buf[:]); err != nil {
return err
}
// store max values in the regular end record to signal
// that the zip64 values should be used instead
records = uint16max
size = uint32max
offset = uint32max
}
// write end record
var buf [directoryEndLen]byte
b := writeBuf(buf[:])
b.uint32(uint32(directoryEndSignature))
b = b[4:] // skip over disk number and first disk number (2x uint16)
b.uint16(uint16(records)) // number of entries this disk
b.uint16(uint16(records)) // number of entries total
b.uint32(uint32(size)) // size of directory
b.uint32(uint32(offset)) // start of directory
b.uint16(uint16(len(w.comment))) // byte size of EOCD comment
if _, err := w.cw.Write(buf[:]); err != nil {
return err
}
if _, err := io.WriteString(w.cw, w.comment); err != nil {
return err
}
return w.cw.w.(*bufio.Writer).Flush()
}
// Create adds a file to the zip file using the provided name.
// It returns a Writer to which the file contents should be written.
// The file contents will be compressed using the Deflate method.
// The name must be a relative path: it must not start with a drive
// letter (e.g. C:) or leading slash, and only forward slashes are
// allowed. To create a directory instead of a file, add a trailing
// slash to the name.
// The file's contents must be written to the io.Writer before the next
// call to Create, CreateHeader, or Close.
func (w *Writer) Create(name string) (io.Writer, error) {
header := &FileHeader{
Name: name,
Method: Deflate,
}
return w.CreateHeader(header)
}
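// A minimal writing sketch. The bytes.Buffer destination and file contents
// are illustrative:
//
//	var buf bytes.Buffer
//	zw := zip.NewWriter(&buf)
//	f, err := zw.Create("hello.txt")
//	if err != nil {
//		log.Fatal(err)
//	}
//	if _, err := f.Write([]byte("hello, zip\n")); err != nil {
//		log.Fatal(err)
//	}
//	if err := zw.Close(); err != nil { // Close writes the central directory
//		log.Fatal(err)
//	}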
// detectUTF8 reports whether s is a valid UTF-8 string, and whether the string
// must be considered UTF-8 encoding (i.e., not compatible with CP-437, ASCII,
// or any other common encoding).
func detectUTF8(s string) (valid, require bool) {
for i := 0; i < len(s); {
r, size := utf8.DecodeRuneInString(s[i:])
i += size
// Officially, ZIP uses CP-437, but many readers use the system's
// local character encoding. Most encodings are compatible with a large
// subset of CP-437, which is itself ASCII-like.
//
// Forbid 0x7e and 0x5c since EUC-KR and Shift-JIS replace those
// characters with localized currency and overline characters.
if r < 0x20 || r > 0x7d || r == 0x5c {
if !utf8.ValidRune(r) || (r == utf8.RuneError && size == 1) {
return false, false
}
require = true
}
}
return true, require
}
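// Illustrative results, computed from the rules above (not exhaustive):
//
//	detectUTF8("hello.txt") // valid=true,  require=false (CP-437 compatible)
//	detectUTF8("日本語.txt")  // valid=true,  require=true  (needs the UTF-8 flag)
//	detectUTF8("\xff\xfe")  // valid=false, require=false (not UTF-8 at all)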
// prepare performs the bookkeeping operations required at the start of
// CreateHeader and CreateRaw.
func (w *Writer) prepare(fh *FileHeader) error {
if w.last != nil && !w.last.closed {
if err := w.last.close(); err != nil {
return err
}
}
if len(w.dir) > 0 && w.dir[len(w.dir)-1].FileHeader == fh {
// See https://golang.org/issue/11144 for background on this confusion.
return errors.New("archive/zip: invalid duplicate FileHeader")
}
return nil
}
// CreateHeader adds a file to the zip archive using the provided FileHeader
// for the file metadata. Writer takes ownership of fh and may mutate
// its fields. The caller must not modify fh after calling CreateHeader.
//
// This returns a Writer to which the file contents should be written.
// The file's contents must be written to the io.Writer before the next
// call to Create, CreateHeader, CreateRaw, or Close.
func (w *Writer) CreateHeader(fh *FileHeader) (io.Writer, error) {
if err := w.prepare(fh); err != nil {
return nil, err
}
// The ZIP format has a sad state of affairs regarding character encoding.
// Officially, the name and comment fields are supposed to be encoded
// in CP-437 (which is mostly compatible with ASCII), unless the UTF-8
// flag bit is set. However, there are several problems:
//
// * Many ZIP readers still do not support UTF-8.
// * If the UTF-8 flag is cleared, several readers simply interpret the
// name and comment fields as whatever the local system encoding is.
//
// In order to avoid breaking readers without UTF-8 support,
// we avoid setting the UTF-8 flag if the strings are CP-437 compatible.
// However, if the strings require multibyte UTF-8 encoding and are
// valid UTF-8 strings, then we set the UTF-8 bit.
//
// For the case where the user explicitly wants to specify the encoding
// as UTF-8, they will need to set the flag bit themselves.
utf8Valid1, utf8Require1 := detectUTF8(fh.Name)
utf8Valid2, utf8Require2 := detectUTF8(fh.Comment)
switch {
case fh.NonUTF8:
fh.Flags &^= 0x800
case (utf8Require1 || utf8Require2) && (utf8Valid1 && utf8Valid2):
fh.Flags |= 0x800
}
fh.CreatorVersion = fh.CreatorVersion&0xff00 | zipVersion20 // preserve compatibility byte
fh.ReaderVersion = zipVersion20
// If Modified is set, this takes precedence over MS-DOS timestamp fields.
if !fh.Modified.IsZero() {
// Contrary to the FileHeader.SetModTime method, we intentionally
// do not convert to UTC, because we assume the user intends to encode
// the date using the specified timezone. A user may want this control
// because many legacy ZIP readers interpret the timestamp according
// to the local timezone.
//
// The timezone is only non-UTC if the user sets the Modified field
// directly. All other approaches set UTC.
fh.ModifiedDate, fh.ModifiedTime = timeToMsDosTime(fh.Modified)
// Use "extended timestamp" format since this is what Info-ZIP uses.
// Nearly every major ZIP implementation uses a different format,
// but at least most seem to be able to understand the other formats.
//
// This format happens to be identical for both local and central header
// if modification time is the only timestamp being encoded.
var mbuf [9]byte // 2*SizeOf(uint16) + SizeOf(uint8) + SizeOf(uint32)
mt := uint32(fh.Modified.Unix())
eb := writeBuf(mbuf[:])
eb.uint16(extTimeExtraID)
eb.uint16(5) // Size: SizeOf(uint8) + SizeOf(uint32)
eb.uint8(1) // Flags: ModTime
eb.uint32(mt) // ModTime
fh.Extra = append(fh.Extra, mbuf[:]...)
}
var (
ow io.Writer
fw *fileWriter
)
h := &header{
FileHeader: fh,
offset: uint64(w.cw.count),
}
if strings.HasSuffix(fh.Name, "/") {
// Set the compression method to Store to ensure data length is truly zero,
// which the writeHeader method always encodes for the size fields.
// This is necessary as most compression formats have non-zero lengths
// even when compressing an empty string.
fh.Method = Store
fh.Flags &^= 0x8 // we will not write a data descriptor
// Explicitly clear sizes as they have no meaning for directories.
fh.CompressedSize = 0
fh.CompressedSize64 = 0
fh.UncompressedSize = 0
fh.UncompressedSize64 = 0
ow = dirWriter{}
} else {
fh.Flags |= 0x8 // we will write a data descriptor
fw = &fileWriter{
zipw: w.cw,
compCount: &countWriter{w: w.cw},
crc32: crc32.NewIEEE(),
}
comp := w.compressor(fh.Method)
if comp == nil {
return nil, ErrAlgorithm
}
var err error
fw.comp, err = comp(fw.compCount)
if err != nil {
return nil, err
}
fw.rawCount = &countWriter{w: fw.comp}
fw.header = h
ow = fw
}
w.dir = append(w.dir, h)
if err := writeHeader(w.cw, h); err != nil {
return nil, err
}
// If we're creating a directory, fw is nil.
w.last = fw
return ow, nil
}
func writeHeader(w io.Writer, h *header) error {
const maxUint16 = 1<<16 - 1
if len(h.Name) > maxUint16 {
return errLongName
}
if len(h.Extra) > maxUint16 {
return errLongExtra
}
var buf [fileHeaderLen]byte
b := writeBuf(buf[:])
b.uint32(uint32(fileHeaderSignature))
b.uint16(h.ReaderVersion)
b.uint16(h.Flags)
b.uint16(h.Method)
b.uint16(h.ModifiedTime)
b.uint16(h.ModifiedDate)
// In raw mode (caller does the compression), the values are either
// written here or in the trailing data descriptor based on the header
// flags.
if h.raw && !h.hasDataDescriptor() {
b.uint32(h.CRC32)
b.uint32(uint32(min64(h.CompressedSize64, uint32max)))
b.uint32(uint32(min64(h.UncompressedSize64, uint32max)))
} else {
// When this package handles the compression, these values are
// always written to the trailing data descriptor.
b.uint32(0) // crc32
b.uint32(0) // compressed size
b.uint32(0) // uncompressed size
}
b.uint16(uint16(len(h.Name)))
b.uint16(uint16(len(h.Extra)))
if _, err := w.Write(buf[:]); err != nil {
return err
}
if _, err := io.WriteString(w, h.Name); err != nil {
return err
}
_, err := w.Write(h.Extra)
return err
}
func min64(x, y uint64) uint64 {
if x < y {
return x
}
return y
}
// CreateRaw adds a file to the zip archive using the provided FileHeader and
// returns a Writer to which the file contents should be written. The file's
// contents must be written to the io.Writer before the next call to Create,
// CreateHeader, CreateRaw, or Close.
//
// In contrast to CreateHeader, the bytes passed to Writer are not compressed.
func (w *Writer) CreateRaw(fh *FileHeader) (io.Writer, error) {
if err := w.prepare(fh); err != nil {
return nil, err
}
fh.CompressedSize = uint32(min64(fh.CompressedSize64, uint32max))
fh.UncompressedSize = uint32(min64(fh.UncompressedSize64, uint32max))
h := &header{
FileHeader: fh,
offset: uint64(w.cw.count),
raw: true,
}
w.dir = append(w.dir, h)
if err := writeHeader(w.cw, h); err != nil {
return nil, err
}
if strings.HasSuffix(fh.Name, "/") {
w.last = nil
return dirWriter{}, nil
}
fw := &fileWriter{
header: h,
zipw: w.cw,
}
w.last = fw
return fw, nil
}
// Copy copies the file f (obtained from a Reader) into w. It copies the raw
// form directly bypassing decompression, compression, and validation.
func (w *Writer) Copy(f *File) error {
r, err := f.OpenRaw()
if err != nil {
return err
}
fw, err := w.CreateRaw(&f.FileHeader)
if err != nil {
return err
}
_, err = io.Copy(fw, r)
return err
}
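// A sketch of filtering one archive into another with Copy. The reader zr,
// writer zw, and keep predicate are assumed to exist:
//
//	for _, f := range zr.File {
//		if !keep(f.Name) { // hypothetical filter
//			continue
//		}
//		if err := zw.Copy(f); err != nil { // raw copy; no recompression
//			return err
//		}
//	}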
// RegisterCompressor registers or overrides a custom compressor for a specific
// method ID. If a compressor for a given method is not found, Writer will
// default to looking up the compressor at the package level.
func (w *Writer) RegisterCompressor(method uint16, comp Compressor) {
if w.compressors == nil {
w.compressors = make(map[uint16]Compressor)
}
w.compressors[method] = comp
}
func (w *Writer) compressor(method uint16) Compressor {
comp := w.compressors[method]
if comp == nil {
comp = compressor(method)
}
return comp
}
type dirWriter struct{}
func (dirWriter) Write(b []byte) (int, error) {
if len(b) == 0 {
return 0, nil
}
return 0, errors.New("zip: write to directory")
}
type fileWriter struct {
*header
zipw io.Writer
rawCount *countWriter
comp io.WriteCloser
compCount *countWriter
crc32 hash.Hash32
closed bool
}
func (w *fileWriter) Write(p []byte) (int, error) {
if w.closed {
return 0, errors.New("zip: write to closed file")
}
if w.raw {
return w.zipw.Write(p)
}
w.crc32.Write(p)
return w.rawCount.Write(p)
}
func (w *fileWriter) close() error {
if w.closed {
return errors.New("zip: file closed twice")
}
w.closed = true
if w.raw {
return w.writeDataDescriptor()
}
if err := w.comp.Close(); err != nil {
return err
}
// update FileHeader
fh := w.header.FileHeader
fh.CRC32 = w.crc32.Sum32()
fh.CompressedSize64 = uint64(w.compCount.count)
fh.UncompressedSize64 = uint64(w.rawCount.count)
if fh.isZip64() {
fh.CompressedSize = uint32max
fh.UncompressedSize = uint32max
fh.ReaderVersion = zipVersion45 // requires 4.5 - File uses ZIP64 format extensions
} else {
fh.CompressedSize = uint32(fh.CompressedSize64)
fh.UncompressedSize = uint32(fh.UncompressedSize64)
}
return w.writeDataDescriptor()
}
func (w *fileWriter) writeDataDescriptor() error {
if !w.hasDataDescriptor() {
return nil
}
// Write data descriptor. This is more complicated than one would
// think; see e.g. the comments in zipfile.c:putextended() and
// http://bugs.sun.com/bugdatabase/view_bug.do?bug_id=7073588.
// The approach here is to write 8 byte sizes if needed without
// adding a zip64 extra in the local header (too late anyway).
var buf []byte
if w.isZip64() {
buf = make([]byte, dataDescriptor64Len)
} else {
buf = make([]byte, dataDescriptorLen)
}
b := writeBuf(buf)
b.uint32(dataDescriptorSignature) // de-facto standard, required by OS X
b.uint32(w.CRC32)
if w.isZip64() {
b.uint64(w.CompressedSize64)
b.uint64(w.UncompressedSize64)
} else {
b.uint32(w.CompressedSize)
b.uint32(w.UncompressedSize)
}
_, err := w.zipw.Write(buf)
return err
}
type countWriter struct {
w io.Writer
count int64
}
func (w *countWriter) Write(p []byte) (int, error) {
n, err := w.w.Write(p)
w.count += int64(n)
return n, err
}
type nopCloser struct {
io.Writer
}
func (w nopCloser) Close() error {
return nil
}
type writeBuf []byte
func (b *writeBuf) uint8(v uint8) {
(*b)[0] = v
*b = (*b)[1:]
}
func (b *writeBuf) uint16(v uint16) {
binary.LittleEndian.PutUint16(*b, v)
*b = (*b)[2:]
}
func (b *writeBuf) uint32(v uint32) {
binary.LittleEndian.PutUint32(*b, v)
*b = (*b)[4:]
}
func (b *writeBuf) uint64(v uint64) {
binary.LittleEndian.PutUint64(*b, v)
*b = (*b)[8:]
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package bufio implements buffered I/O. It wraps an io.Reader or io.Writer
// object, creating another object (Reader or Writer) that also implements
// the interface but provides buffering and some help for textual I/O.
package bufio
import (
"bytes"
"errors"
"io"
"strings"
"unicode/utf8"
)
const (
defaultBufSize = 4096
)
var (
ErrInvalidUnreadByte = errors.New("bufio: invalid use of UnreadByte")
ErrInvalidUnreadRune = errors.New("bufio: invalid use of UnreadRune")
ErrBufferFull = errors.New("bufio: buffer full")
ErrNegativeCount = errors.New("bufio: negative count")
)
// Buffered input.
// Reader implements buffering for an io.Reader object.
type Reader struct {
buf []byte
rd io.Reader // reader provided by the client
r, w int // buf read and write positions
err error
lastByte int // last byte read for UnreadByte; -1 means invalid
lastRuneSize int // size of last rune read for UnreadRune; -1 means invalid
}
const minReadBufferSize = 16
const maxConsecutiveEmptyReads = 100
// NewReaderSize returns a new Reader whose buffer has at least the specified
// size. If the argument io.Reader is already a Reader with large enough
// size, it returns the underlying Reader.
func NewReaderSize(rd io.Reader, size int) *Reader {
// Is it already a Reader?
b, ok := rd.(*Reader)
if ok && len(b.buf) >= size {
return b
}
if size < minReadBufferSize {
size = minReadBufferSize
}
r := new(Reader)
r.reset(make([]byte, size), rd)
return r
}
// NewReader returns a new Reader whose buffer has the default size.
func NewReader(rd io.Reader) *Reader {
return NewReaderSize(rd, defaultBufSize)
}
// Size returns the size of the underlying buffer in bytes.
func (b *Reader) Size() int { return len(b.buf) }
// Reset discards any buffered data, resets all state, and switches
// the buffered reader to read from r.
// Calling Reset on the zero value of Reader initializes the internal buffer
// to the default size.
// Calling b.Reset(b) (that is, resetting a Reader to itself) does nothing.
func (b *Reader) Reset(r io.Reader) {
// If a Reader r is passed to NewReader, NewReader will return r.
// Different layers of code may do that, and then later pass r
// to Reset. Avoid infinite recursion in that case.
if b == r {
return
}
if b.buf == nil {
b.buf = make([]byte, defaultBufSize)
}
b.reset(b.buf, r)
}
func (b *Reader) reset(buf []byte, r io.Reader) {
*b = Reader{
buf: buf,
rd: r,
lastByte: -1,
lastRuneSize: -1,
}
}
var errNegativeRead = errors.New("bufio: reader returned negative count from Read")
// fill reads a new chunk into the buffer.
func (b *Reader) fill() {
// Slide existing data to beginning.
if b.r > 0 {
copy(b.buf, b.buf[b.r:b.w])
b.w -= b.r
b.r = 0
}
if b.w >= len(b.buf) {
panic("bufio: tried to fill full buffer")
}
// Read new data: try a limited number of times.
for i := maxConsecutiveEmptyReads; i > 0; i-- {
n, err := b.rd.Read(b.buf[b.w:])
if n < 0 {
panic(errNegativeRead)
}
b.w += n
if err != nil {
b.err = err
return
}
if n > 0 {
return
}
}
b.err = io.ErrNoProgress
}
func (b *Reader) readErr() error {
err := b.err
b.err = nil
return err
}
// Peek returns the next n bytes without advancing the reader. The bytes stop
// being valid at the next read call. If Peek returns fewer than n bytes, it
// also returns an error explaining why the read is short. The error is
// ErrBufferFull if n is larger than b's buffer size.
//
// Calling Peek prevents an UnreadByte or UnreadRune call from succeeding
// until the next read operation.
func (b *Reader) Peek(n int) ([]byte, error) {
if n < 0 {
return nil, ErrNegativeCount
}
b.lastByte = -1
b.lastRuneSize = -1
for b.w-b.r < n && b.w-b.r < len(b.buf) && b.err == nil {
b.fill() // b.w-b.r < len(b.buf) => buffer is not full
}
if n > len(b.buf) {
return b.buf[b.r:b.w], ErrBufferFull
}
// 0 <= n <= len(b.buf)
var err error
if avail := b.w - b.r; avail < n {
// not enough data in buffer
n = avail
err = b.readErr()
if err == nil {
err = ErrBufferFull
}
}
return b.buf[b.r : b.r+n], err
}
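// A small sketch of sniffing a stream without consuming it; the gzip magic
// number check is illustrative:
//
//	br := bufio.NewReader(r)
//	magic, err := br.Peek(2)
//	if err == nil && magic[0] == 0x1f && magic[1] == 0x8b {
//		// Looks like gzip; br still delivers the stream from the start.
//	}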
// Discard skips the next n bytes, returning the number of bytes discarded.
//
// If Discard skips fewer than n bytes, it also returns an error.
// If 0 <= n <= b.Buffered(), Discard is guaranteed to succeed without
// reading from the underlying io.Reader.
func (b *Reader) Discard(n int) (discarded int, err error) {
if n < 0 {
return 0, ErrNegativeCount
}
if n == 0 {
return
}
b.lastByte = -1
b.lastRuneSize = -1
remain := n
for {
skip := b.Buffered()
if skip == 0 {
b.fill()
skip = b.Buffered()
}
if skip > remain {
skip = remain
}
b.r += skip
remain -= skip
if remain == 0 {
return n, nil
}
if b.err != nil {
return n - remain, b.readErr()
}
}
}
// Read reads data into p.
// It returns the number of bytes read into p.
// The bytes are taken from at most one Read on the underlying Reader,
// hence n may be less than len(p).
// To read exactly len(p) bytes, use io.ReadFull(b, p).
// If the underlying Reader can return a non-zero count with io.EOF,
// then this Read method can do so as well; see the [io.Reader] docs.
func (b *Reader) Read(p []byte) (n int, err error) {
n = len(p)
if n == 0 {
if b.Buffered() > 0 {
return 0, nil
}
return 0, b.readErr()
}
if b.r == b.w {
if b.err != nil {
return 0, b.readErr()
}
if len(p) >= len(b.buf) {
// Large read, empty buffer.
// Read directly into p to avoid copy.
n, b.err = b.rd.Read(p)
if n < 0 {
panic(errNegativeRead)
}
if n > 0 {
b.lastByte = int(p[n-1])
b.lastRuneSize = -1
}
return n, b.readErr()
}
// One read.
// Do not use b.fill, which will loop.
b.r = 0
b.w = 0
n, b.err = b.rd.Read(b.buf)
if n < 0 {
panic(errNegativeRead)
}
if n == 0 {
return 0, b.readErr()
}
b.w += n
}
// copy as much as we can
// Note: if the slice panics here, it is probably because
// the underlying reader returned a bad count. See issue 49795.
n = copy(p, b.buf[b.r:b.w])
b.r += n
b.lastByte = int(b.buf[b.r-1])
b.lastRuneSize = -1
return n, nil
}
// ReadByte reads and returns a single byte.
// If no byte is available, it returns an error.
func (b *Reader) ReadByte() (byte, error) {
b.lastRuneSize = -1
for b.r == b.w {
if b.err != nil {
return 0, b.readErr()
}
b.fill() // buffer is empty
}
c := b.buf[b.r]
b.r++
b.lastByte = int(c)
return c, nil
}
// UnreadByte unreads the last byte. Only the most recently read byte can be unread.
//
// UnreadByte returns an error if the most recent method called on the
// Reader was not a read operation. Notably, Peek, Discard, and WriteTo are not
// considered read operations.
func (b *Reader) UnreadByte() error {
if b.lastByte < 0 || b.r == 0 && b.w > 0 {
return ErrInvalidUnreadByte
}
// b.r > 0 || b.w == 0
if b.r > 0 {
b.r--
} else {
// b.r == 0 && b.w == 0
b.w = 1
}
b.buf[b.r] = byte(b.lastByte)
b.lastByte = -1
b.lastRuneSize = -1
return nil
}
// ReadRune reads a single UTF-8 encoded Unicode character and returns the
// rune and its size in bytes. If the encoded rune is invalid, it consumes one byte
// and returns unicode.ReplacementChar (U+FFFD) with a size of 1.
func (b *Reader) ReadRune() (r rune, size int, err error) {
for b.r+utf8.UTFMax > b.w && !utf8.FullRune(b.buf[b.r:b.w]) && b.err == nil && b.w-b.r < len(b.buf) {
b.fill() // b.w-b.r < len(buf) => buffer is not full
}
b.lastRuneSize = -1
if b.r == b.w {
return 0, 0, b.readErr()
}
r, size = rune(b.buf[b.r]), 1
if r >= utf8.RuneSelf {
r, size = utf8.DecodeRune(b.buf[b.r:b.w])
}
b.r += size
b.lastByte = int(b.buf[b.r-1])
b.lastRuneSize = size
return r, size, nil
}
// UnreadRune unreads the last rune. If the most recent method called on
// the Reader was not a ReadRune, UnreadRune returns an error. (In this
// regard it is stricter than UnreadByte, which will unread the last byte
// from any read operation.)
func (b *Reader) UnreadRune() error {
if b.lastRuneSize < 0 || b.r < b.lastRuneSize {
return ErrInvalidUnreadRune
}
b.r -= b.lastRuneSize
b.lastByte = -1
b.lastRuneSize = -1
return nil
}
// Buffered returns the number of bytes that can be read from the current buffer.
func (b *Reader) Buffered() int { return b.w - b.r }
// ReadSlice reads until the first occurrence of delim in the input,
// returning a slice pointing at the bytes in the buffer.
// The bytes stop being valid at the next read.
// If ReadSlice encounters an error before finding a delimiter,
// it returns all the data in the buffer and the error itself (often io.EOF).
// ReadSlice fails with error ErrBufferFull if the buffer fills without a delim.
// Because the data returned from ReadSlice will be overwritten
// by the next I/O operation, most clients should use
// ReadBytes or ReadString instead.
// ReadSlice returns err != nil if and only if line does not end in delim.
func (b *Reader) ReadSlice(delim byte) (line []byte, err error) {
s := 0 // search start index
for {
// Search buffer.
if i := bytes.IndexByte(b.buf[b.r+s:b.w], delim); i >= 0 {
i += s
line = b.buf[b.r : b.r+i+1]
b.r += i + 1
break
}
// Pending error?
if b.err != nil {
line = b.buf[b.r:b.w]
b.r = b.w
err = b.readErr()
break
}
// Buffer full?
if b.Buffered() >= len(b.buf) {
b.r = b.w
line = b.buf
err = ErrBufferFull
break
}
s = b.w - b.r // do not rescan area we scanned before
b.fill() // buffer is not full
}
// Handle last byte, if any.
if i := len(line) - 1; i >= 0 {
b.lastByte = int(line[i])
b.lastRuneSize = -1
}
return
}
// ReadLine is a low-level line-reading primitive. Most callers should use
// ReadBytes('\n') or ReadString('\n') instead or use a Scanner.
//
// ReadLine tries to return a single line, not including the end-of-line bytes.
// If the line was too long for the buffer then isPrefix is set and the
// beginning of the line is returned. The rest of the line will be returned
// from future calls. isPrefix will be false when returning the last fragment
// of the line. The returned buffer is only valid until the next call to
// ReadLine. ReadLine either returns a non-nil line or it returns an error,
// never both.
//
// The text returned from ReadLine does not include the line end ("\r\n" or "\n").
// No indication or error is given if the input ends without a final line end.
// Calling UnreadByte after ReadLine will always unread the last byte read
// (possibly a character belonging to the line end) even if that byte is not
// part of the line returned by ReadLine.
func (b *Reader) ReadLine() (line []byte, isPrefix bool, err error) {
line, err = b.ReadSlice('\n')
if err == ErrBufferFull {
// Handle the case where "\r\n" straddles the buffer.
if len(line) > 0 && line[len(line)-1] == '\r' {
// Put the '\r' back on buf and drop it from line.
// Let the next call to ReadLine check for "\r\n".
if b.r == 0 {
// should be unreachable
panic("bufio: tried to rewind past start of buffer")
}
b.r--
line = line[:len(line)-1]
}
return line, true, nil
}
if len(line) == 0 {
if err != nil {
line = nil
}
return
}
err = nil
if line[len(line)-1] == '\n' {
drop := 1
if len(line) > 1 && line[len(line)-2] == '\r' {
drop = 2
}
line = line[:len(line)-drop]
}
return
}
// collectFragments reads until the first occurrence of delim in the input. It
// returns (slice of full buffers, remaining bytes before delim, total number
// of bytes in the combined first two elements, error).
// The complete result is equal to
// `bytes.Join(append(fullBuffers, finalFragment), nil)`, which has a
// length of `totalLen`. The result is structured in this way to allow callers
// to minimize allocations and copies.
func (b *Reader) collectFragments(delim byte) (fullBuffers [][]byte, finalFragment []byte, totalLen int, err error) {
var frag []byte
// Use ReadSlice to look for delim, accumulating full buffers.
for {
var e error
frag, e = b.ReadSlice(delim)
if e == nil { // got final fragment
break
}
if e != ErrBufferFull { // unexpected error
err = e
break
}
// Make a copy of the buffer.
buf := bytes.Clone(frag)
fullBuffers = append(fullBuffers, buf)
totalLen += len(buf)
}
totalLen += len(frag)
return fullBuffers, frag, totalLen, err
}
// ReadBytes reads until the first occurrence of delim in the input,
// returning a slice containing the data up to and including the delimiter.
// If ReadBytes encounters an error before finding a delimiter,
// it returns the data read before the error and the error itself (often io.EOF).
// ReadBytes returns err != nil if and only if the returned data does not end in
// delim.
// For simple uses, a Scanner may be more convenient.
func (b *Reader) ReadBytes(delim byte) ([]byte, error) {
full, frag, n, err := b.collectFragments(delim)
// Allocate new buffer to hold the full pieces and the fragment.
buf := make([]byte, n)
n = 0
// Copy full pieces and fragment in.
for i := range full {
n += copy(buf[n:], full[i])
}
copy(buf[n:], frag)
return buf, err
}
// ReadString reads until the first occurrence of delim in the input,
// returning a string containing the data up to and including the delimiter.
// If ReadString encounters an error before finding a delimiter,
// it returns the data read before the error and the error itself (often io.EOF).
// ReadString returns err != nil if and only if the returned data does not end in
// delim.
// For simple uses, a Scanner may be more convenient.
func (b *Reader) ReadString(delim byte) (string, error) {
full, frag, n, err := b.collectFragments(delim)
// Allocate new buffer to hold the full pieces and the fragment.
var buf strings.Builder
buf.Grow(n)
// Copy full pieces and fragment in.
for _, fb := range full {
buf.Write(fb)
}
buf.Write(frag)
return buf.String(), err
}
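// A common line-reading loop built on ReadString (sketch; r is any
// io.Reader and process is hypothetical):
//
//	br := bufio.NewReader(r)
//	for {
//		line, err := br.ReadString('\n')
//		if len(line) > 0 {
//			process(line) // line includes the trailing '\n', if present
//		}
//		if err != nil {
//			break // typically io.EOF after a final unterminated line
//		}
//	}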
// WriteTo implements io.WriterTo.
// This may make multiple calls to the Read method of the underlying Reader.
// If the underlying reader supports the WriteTo method,
// this calls the underlying WriteTo without buffering.
func (b *Reader) WriteTo(w io.Writer) (n int64, err error) {
b.lastByte = -1
b.lastRuneSize = -1
n, err = b.writeBuf(w)
if err != nil {
return
}
if r, ok := b.rd.(io.WriterTo); ok {
m, err := r.WriteTo(w)
n += m
return n, err
}
if w, ok := w.(io.ReaderFrom); ok {
m, err := w.ReadFrom(b.rd)
n += m
return n, err
}
if b.w-b.r < len(b.buf) {
b.fill() // buffer not full
}
for b.r < b.w {
// b.r < b.w => buffer is not empty
m, err := b.writeBuf(w)
n += m
if err != nil {
return n, err
}
b.fill() // buffer is empty
}
if b.err == io.EOF {
b.err = nil
}
return n, b.readErr()
}
var errNegativeWrite = errors.New("bufio: writer returned negative count from Write")
// writeBuf writes the Reader's buffer to the writer.
func (b *Reader) writeBuf(w io.Writer) (int64, error) {
n, err := w.Write(b.buf[b.r:b.w])
if n < 0 {
panic(errNegativeWrite)
}
b.r += n
return int64(n), err
}
// buffered output
// Writer implements buffering for an io.Writer object.
// If an error occurs writing to a Writer, no more data will be
// accepted and all subsequent writes, and Flush, will return the error.
// After all data has been written, the client should call the
// Flush method to guarantee all data has been forwarded to
// the underlying io.Writer.
type Writer struct {
err error
buf []byte
n int
wr io.Writer
}
// NewWriterSize returns a new Writer whose buffer has at least the specified
// size. If the argument io.Writer is already a Writer with large enough
// size, it returns the underlying Writer.
func NewWriterSize(w io.Writer, size int) *Writer {
// Is it already a Writer?
b, ok := w.(*Writer)
if ok && len(b.buf) >= size {
return b
}
if size <= 0 {
size = defaultBufSize
}
return &Writer{
buf: make([]byte, size),
wr: w,
}
}
// NewWriter returns a new Writer whose buffer has the default size.
// If the argument io.Writer is already a Writer with large enough buffer size,
// it returns the underlying Writer.
func NewWriter(w io.Writer) *Writer {
return NewWriterSize(w, defaultBufSize)
}
// Size returns the size of the underlying buffer in bytes.
func (b *Writer) Size() int { return len(b.buf) }
// Reset discards any unflushed buffered data, clears any error, and
// resets b to write its output to w.
// Calling Reset on the zero value of Writer initializes the internal buffer
// to the default size.
// Calling w.Reset(w) (that is, resetting a Writer to itself) does nothing.
func (b *Writer) Reset(w io.Writer) {
// If a Writer w is passed to NewWriter, NewWriter will return w.
// Different layers of code may do that, and then later pass w
// to Reset. Avoid infinite recursion in that case.
if b == w {
return
}
if b.buf == nil {
b.buf = make([]byte, defaultBufSize)
}
b.err = nil
b.n = 0
b.wr = w
}
// Flush writes any buffered data to the underlying io.Writer.
func (b *Writer) Flush() error {
if b.err != nil {
return b.err
}
if b.n == 0 {
return nil
}
n, err := b.wr.Write(b.buf[0:b.n])
if n < b.n && err == nil {
err = io.ErrShortWrite
}
if err != nil {
if n > 0 && n < b.n {
copy(b.buf[0:b.n-n], b.buf[n:b.n])
}
b.n -= n
b.err = err
return err
}
b.n = 0
return nil
}
// Available returns how many bytes are unused in the buffer.
func (b *Writer) Available() int { return len(b.buf) - b.n }
// AvailableBuffer returns an empty buffer with b.Available() capacity.
// This buffer is intended to be appended to and
// passed to an immediately succeeding Write call.
// The buffer is only valid until the next write operation on b.
func (b *Writer) AvailableBuffer() []byte {
return b.buf[b.n:][:0]
}
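// A sketch of the append-then-Write pattern this enables; the numbers and
// the strconv import are illustrative:
//
//	bw := bufio.NewWriter(w)
//	for _, n := range []int64{1, 2, 3} {
//		b := bw.AvailableBuffer()
//		b = strconv.AppendInt(b, n, 10)
//		b = append(b, '\n')
//		bw.Write(b) // copies b in place; no extra allocation
//	}
//	bw.Flush()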
// Buffered returns the number of bytes that have been written into the current buffer.
func (b *Writer) Buffered() int { return b.n }
// Write writes the contents of p into the buffer.
// It returns the number of bytes written.
// If nn < len(p), it also returns an error explaining
// why the write is short.
func (b *Writer) Write(p []byte) (nn int, err error) {
for len(p) > b.Available() && b.err == nil {
var n int
if b.Buffered() == 0 {
// Large write, empty buffer.
// Write directly from p to avoid copy.
n, b.err = b.wr.Write(p)
} else {
n = copy(b.buf[b.n:], p)
b.n += n
b.Flush()
}
nn += n
p = p[n:]
}
if b.err != nil {
return nn, b.err
}
n := copy(b.buf[b.n:], p)
b.n += n
nn += n
return nn, nil
}
// WriteByte writes a single byte.
func (b *Writer) WriteByte(c byte) error {
if b.err != nil {
return b.err
}
if b.Available() <= 0 && b.Flush() != nil {
return b.err
}
b.buf[b.n] = c
b.n++
return nil
}
// WriteRune writes a single Unicode code point, returning
// the number of bytes written and any error.
func (b *Writer) WriteRune(r rune) (size int, err error) {
// Compare as uint32 to correctly handle negative runes.
if uint32(r) < utf8.RuneSelf {
err = b.WriteByte(byte(r))
if err != nil {
return 0, err
}
return 1, nil
}
if b.err != nil {
return 0, b.err
}
n := b.Available()
if n < utf8.UTFMax {
if b.Flush(); b.err != nil {
return 0, b.err
}
n = b.Available()
if n < utf8.UTFMax {
// Can only happen if the buffer is unreasonably small.
return b.WriteString(string(r))
}
}
size = utf8.EncodeRune(b.buf[b.n:], r)
b.n += size
return size, nil
}
// WriteString writes a string.
// It returns the number of bytes written.
// If the count is less than len(s), it also returns an error explaining
// why the write is short.
func (b *Writer) WriteString(s string) (int, error) {
var sw io.StringWriter
tryStringWriter := true
nn := 0
for len(s) > b.Available() && b.err == nil {
var n int
if b.Buffered() == 0 && sw == nil && tryStringWriter {
// Check at most once whether b.wr is a StringWriter.
sw, tryStringWriter = b.wr.(io.StringWriter)
}
if b.Buffered() == 0 && tryStringWriter {
// Large write, empty buffer, and the underlying writer supports
// WriteString: forward the write to the underlying StringWriter.
// This avoids an extra copy.
n, b.err = sw.WriteString(s)
} else {
n = copy(b.buf[b.n:], s)
b.n += n
b.Flush()
}
nn += n
s = s[n:]
}
if b.err != nil {
return nn, b.err
}
n := copy(b.buf[b.n:], s)
b.n += n
nn += n
return nn, nil
}
// ReadFrom implements io.ReaderFrom. If the underlying writer
// supports the ReadFrom method, this calls the underlying ReadFrom.
// If there is buffered data and an underlying ReadFrom, this fills
// the buffer and writes it before calling ReadFrom.
func (b *Writer) ReadFrom(r io.Reader) (n int64, err error) {
if b.err != nil {
return 0, b.err
}
readerFrom, readerFromOK := b.wr.(io.ReaderFrom)
var m int
for {
if b.Available() == 0 {
if err1 := b.Flush(); err1 != nil {
return n, err1
}
}
if readerFromOK && b.Buffered() == 0 {
nn, err := readerFrom.ReadFrom(r)
b.err = err
n += nn
return n, err
}
nr := 0
for nr < maxConsecutiveEmptyReads {
m, err = r.Read(b.buf[b.n:])
if m != 0 || err != nil {
break
}
nr++
}
if nr == maxConsecutiveEmptyReads {
return n, io.ErrNoProgress
}
b.n += m
n += int64(m)
if err != nil {
break
}
}
if err == io.EOF {
// If we filled the buffer exactly, flush preemptively.
if b.Available() == 0 {
err = b.Flush()
} else {
err = nil
}
}
return n, err
}
// buffered input and output
// ReadWriter stores pointers to a Reader and a Writer.
// It implements io.ReadWriter.
type ReadWriter struct {
*Reader
*Writer
}
// NewReadWriter allocates a new ReadWriter that dispatches to r and w.
func NewReadWriter(r *Reader, w *Writer) *ReadWriter {
return &ReadWriter{r, w}
}
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package bufio
import (
"bytes"
"errors"
"io"
"unicode/utf8"
)
// Scanner provides a convenient interface for reading data such as
// a file of newline-delimited lines of text. Successive calls to
// the Scan method will step through the 'tokens' of a file, skipping
// the bytes between the tokens. The specification of a token is
// defined by a split function of type SplitFunc; the default split
// function breaks the input into lines with line termination stripped. Split
// functions are defined in this package for scanning a file into
// lines, bytes, UTF-8-encoded runes, and space-delimited words. The
// client may instead provide a custom split function.
//
// Scanning stops unrecoverably at EOF, the first I/O error, or a token too
// large to fit in the buffer. When a scan stops, the reader may have
// advanced arbitrarily far past the last token. Programs that need more
// control over error handling or large tokens, or must run sequential scans
// on a reader, should use bufio.Reader instead.
type Scanner struct {
r io.Reader // The reader provided by the client.
split SplitFunc // The function to split the tokens.
maxTokenSize int // Maximum size of a token; modified by tests.
token []byte // Last token returned by split.
buf []byte // Buffer used as argument to split.
start int // First non-processed byte in buf.
end int // End of data in buf.
err error // Sticky error.
empties int // Count of successive empty tokens.
scanCalled bool // Scan has been called; buffer is in use.
done bool // Scan has finished.
}
// SplitFunc is the signature of the split function used to tokenize the
// input. The arguments are an initial substring of the remaining unprocessed
// data and a flag, atEOF, that reports whether the Reader has no more data
// to give. The return values are the number of bytes to advance the input
// and the next token to return to the user, if any, plus an error, if any.
//
// Scanning stops if the function returns an error, in which case some of
// the input may be discarded. If that error is ErrFinalToken, scanning
// stops with no error.
//
// Otherwise, the Scanner advances the input. If the token is not nil,
// the Scanner returns it to the user. If the token is nil, the
// Scanner reads more data and continues scanning; if there is no more
// data--if atEOF was true--the Scanner returns. If the data does not
// yet hold a complete token, for instance if it has no newline while
// scanning lines, a SplitFunc can return (0, nil, nil) to signal the
// Scanner to read more data into the slice and try again with a
// longer slice starting at the same point in the input.
//
// The function is never called with an empty data slice unless atEOF
// is true. If atEOF is true, however, data may be non-empty and,
// as always, holds unprocessed text.
type SplitFunc func(data []byte, atEOF bool) (advance int, token []byte, err error)
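// scanCommas is an illustrative sketch (not part of the original source) of
// a custom SplitFunc following the protocol above: it tokenizes
// comma-separated fields, returns (0, nil, nil) to request more data while
// no comma is buffered yet, and emits the remaining bytes as the final
// token at EOF.
func scanCommas(data []byte, atEOF bool) (advance int, token []byte, err error) {
	if atEOF && len(data) == 0 {
		return 0, nil, nil
	}
	if i := bytes.IndexByte(data, ','); i >= 0 {
		// Deliver the field and advance past the comma.
		return i + 1, data[:i], nil
	}
	if atEOF {
		// No trailing comma; the rest of the input is the last field.
		return len(data), data, nil
	}
	// Request more data.
	return 0, nil, nil
}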
// Errors returned by Scanner.
var (
ErrTooLong = errors.New("bufio.Scanner: token too long")
ErrNegativeAdvance = errors.New("bufio.Scanner: SplitFunc returns negative advance count")
ErrAdvanceTooFar = errors.New("bufio.Scanner: SplitFunc returns advance count beyond input")
ErrBadReadCount = errors.New("bufio.Scanner: Read returned impossible count")
)
const (
// MaxScanTokenSize is the maximum size used to buffer a token
// unless the user provides an explicit buffer with Scanner.Buffer.
// The actual maximum token size may be smaller as the buffer
// may need to include, for instance, a newline.
MaxScanTokenSize = 64 * 1024
startBufSize = 4096 // Size of initial allocation for buffer.
)
// NewScanner returns a new Scanner to read from r.
// The split function defaults to ScanLines.
func NewScanner(r io.Reader) *Scanner {
return &Scanner{
r: r,
split: ScanLines,
maxTokenSize: MaxScanTokenSize,
}
}
// Err returns the first non-EOF error that was encountered by the Scanner.
func (s *Scanner) Err() error {
if s.err == io.EOF {
return nil
}
return s.err
}
// Bytes returns the most recent token generated by a call to Scan.
// The underlying array may point to data that will be overwritten
// by a subsequent call to Scan. It does no allocation.
func (s *Scanner) Bytes() []byte {
return s.token
}
// Text returns the most recent token generated by a call to Scan
// as a newly allocated string holding its bytes.
func (s *Scanner) Text() string {
return string(s.token)
}
// ErrFinalToken is a special sentinel error value. It is intended to be
// returned by a Split function to indicate that the token being delivered
// with the error is the last token and scanning should stop after this one.
// After ErrFinalToken is received by Scan, scanning stops with no error.
// The value is useful to stop processing early or when it is necessary to
// deliver a final empty token. One could achieve the same behavior
// with a custom error value but providing one here is tidier.
// See the emptyFinalToken example for a use of this value.
var ErrFinalToken = errors.New("final token")
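// scanUntilStop is an illustrative sketch (not part of the original source):
// it wraps ScanWords and uses ErrFinalToken to stop the Scanner once a
// hypothetical sentinel word "STOP" is delivered.
func scanUntilStop(data []byte, atEOF bool) (advance int, token []byte, err error) {
	advance, token, err = ScanWords(data, atEOF)
	if err == nil && string(token) == "STOP" {
		// Deliver "STOP" as the last token; Scan reports no error afterwards.
		err = ErrFinalToken
	}
	return advance, token, err
}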
// Scan advances the Scanner to the next token, which will then be
// available through the Bytes or Text method. It returns false when the
// scan stops, either by reaching the end of the input or an error.
// After Scan returns false, the Err method will return any error that
// occurred during scanning, except that if it was io.EOF, Err
// will return nil.
// Scan panics if the split function returns too many empty
// tokens without advancing the input. This is a common error mode for
// scanners.
func (s *Scanner) Scan() bool {
if s.done {
return false
}
s.scanCalled = true
// Loop until we have a token.
for {
// See if we can get a token with what we already have.
// If we've run out of data but have an error, give the split function
// a chance to recover any remaining, possibly empty token.
if s.end > s.start || s.err != nil {
advance, token, err := s.split(s.buf[s.start:s.end], s.err != nil)
if err != nil {
if err == ErrFinalToken {
s.token = token
s.done = true
return true
}
s.setErr(err)
return false
}
if !s.advance(advance) {
return false
}
s.token = token
if token != nil {
if s.err == nil || advance > 0 {
s.empties = 0
} else {
// Returning tokens without advancing the input at EOF.
s.empties++
if s.empties > maxConsecutiveEmptyReads {
panic("bufio.Scan: too many empty tokens without progressing")
}
}
return true
}
}
// We cannot generate a token with what we are holding.
// If we've already hit EOF or an I/O error, we are done.
if s.err != nil {
// Shut it down.
s.start = 0
s.end = 0
return false
}
// Must read more data.
// First, shift data to beginning of buffer if there's lots of empty space
// or space is needed.
if s.start > 0 && (s.end == len(s.buf) || s.start > len(s.buf)/2) {
copy(s.buf, s.buf[s.start:s.end])
s.end -= s.start
s.start = 0
}
// Is the buffer full? If so, resize.
if s.end == len(s.buf) {
// Guarantee no overflow in the multiplication below.
const maxInt = int(^uint(0) >> 1)
if len(s.buf) >= s.maxTokenSize || len(s.buf) > maxInt/2 {
s.setErr(ErrTooLong)
return false
}
newSize := len(s.buf) * 2
if newSize == 0 {
newSize = startBufSize
}
if newSize > s.maxTokenSize {
newSize = s.maxTokenSize
}
newBuf := make([]byte, newSize)
copy(newBuf, s.buf[s.start:s.end])
s.buf = newBuf
s.end -= s.start
s.start = 0
}
// Finally we can read some input. Make sure we don't get stuck with
// a misbehaving Reader. Officially we don't need to do this, but let's
// be extra careful: Scanner is for safe, simple jobs.
for loop := 0; ; {
n, err := s.r.Read(s.buf[s.end:len(s.buf)])
if n < 0 || len(s.buf)-s.end < n {
s.setErr(ErrBadReadCount)
break
}
s.end += n
if err != nil {
s.setErr(err)
break
}
if n > 0 {
s.empties = 0
break
}
loop++
if loop > maxConsecutiveEmptyReads {
s.setErr(io.ErrNoProgress)
break
}
}
}
}
// advance consumes n bytes of the buffer. It reports whether the advance was legal.
func (s *Scanner) advance(n int) bool {
if n < 0 {
s.setErr(ErrNegativeAdvance)
return false
}
if n > s.end-s.start {
s.setErr(ErrAdvanceTooFar)
return false
}
s.start += n
return true
}
// setErr records the first error encountered.
func (s *Scanner) setErr(err error) {
if s.err == nil || s.err == io.EOF {
s.err = err
}
}
// Buffer sets the initial buffer to use when scanning and the maximum
// size of buffer that may be allocated during scanning. The maximum
// token size is the larger of max and cap(buf). If max <= cap(buf),
// Scan will use this buffer only and do no allocation.
//
// By default, Scan uses an internal buffer and sets the
// maximum token size to MaxScanTokenSize.
//
// Buffer panics if it is called after scanning has started.
func (s *Scanner) Buffer(buf []byte, max int) {
if s.scanCalled {
panic("Buffer called after Scan")
}
s.buf = buf[0:cap(buf)]
s.maxTokenSize = max
}
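// exampleLargeTokens is an illustrative sketch (not part of the original
// source): it prepares a Scanner for inputs whose tokens may exceed
// MaxScanTokenSize. The 1 MiB cap is an arbitrary value chosen for
// illustration.
func exampleLargeTokens(r io.Reader) *Scanner {
	s := NewScanner(r)
	s.Buffer(make([]byte, 0, 64*1024), 1024*1024)
	return s
}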
// Split sets the split function for the Scanner.
// The default split function is ScanLines.
//
// Split panics if it is called after scanning has started.
func (s *Scanner) Split(split SplitFunc) {
if s.scanCalled {
panic("Split called after Scan")
}
s.split = split
}
// Split functions
// ScanBytes is a split function for a Scanner that returns each byte as a token.
func ScanBytes(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
return 1, data[0:1], nil
}
var errorRune = []byte(string(utf8.RuneError))
// ScanRunes is a split function for a Scanner that returns each
// UTF-8-encoded rune as a token. The sequence of runes returned is
// equivalent to that from a range loop over the input as a string, which
// means that erroneous UTF-8 encodings translate to U+FFFD = "\xef\xbf\xbd".
// Because of the Scan interface, this makes it impossible for the client to
// distinguish correctly encoded replacement runes from encoding errors.
func ScanRunes(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
// Fast path 1: ASCII.
if data[0] < utf8.RuneSelf {
return 1, data[0:1], nil
}
// Fast path 2: Correct UTF-8 decode without error.
_, width := utf8.DecodeRune(data)
if width > 1 {
// It's a valid encoding. Width cannot be one for a correctly encoded
// non-ASCII rune.
return width, data[0:width], nil
}
// We know it's an error: we have width==1 and implicitly r==utf8.RuneError.
// Is the error because there wasn't a full rune to be decoded?
// FullRune distinguishes correctly between erroneous and incomplete encodings.
if !atEOF && !utf8.FullRune(data) {
// Incomplete; get more bytes.
return 0, nil, nil
}
// We have a real UTF-8 encoding error. Return a properly encoded error rune
// but advance only one byte. This matches the behavior of a range loop over
// an incorrectly encoded string.
return 1, errorRune, nil
}
// dropCR drops a terminal \r from the data.
func dropCR(data []byte) []byte {
if len(data) > 0 && data[len(data)-1] == '\r' {
return data[0 : len(data)-1]
}
return data
}
// ScanLines is a split function for a Scanner that returns each line of
// text, stripped of any trailing end-of-line marker. The returned line may
// be empty. The end-of-line marker is one optional carriage return followed
// by one mandatory newline. In regular expression notation, it is `\r?\n`.
// The last non-empty line of input will be returned even if it has no
// newline.
func ScanLines(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := bytes.IndexByte(data, '\n'); i >= 0 {
// We have a full newline-terminated line.
return i + 1, dropCR(data[0:i]), nil
}
// If we're at EOF, we have a final, non-terminated line. Return it.
if atEOF {
return len(data), dropCR(data), nil
}
// Request more data.
return 0, nil, nil
}
// isSpace reports whether the character is a Unicode white space character.
// We avoid dependency on the unicode package, but check validity of the implementation
// in the tests.
func isSpace(r rune) bool {
if r <= '\u00FF' {
// Obvious ASCII ones: \t through \r plus space. Plus two Latin-1 oddballs.
switch r {
case ' ', '\t', '\n', '\v', '\f', '\r':
return true
case '\u0085', '\u00A0':
return true
}
return false
}
// High-valued ones.
if '\u2000' <= r && r <= '\u200a' {
return true
}
switch r {
case '\u1680', '\u2028', '\u2029', '\u202f', '\u205f', '\u3000':
return true
}
return false
}
// ScanWords is a split function for a Scanner that returns each
// space-separated word of text, with surrounding spaces deleted. It will
// never return an empty token. The definition of space is set by
// unicode.IsSpace.
func ScanWords(data []byte, atEOF bool) (advance int, token []byte, err error) {
// Skip leading spaces.
start := 0
for width := 0; start < len(data); start += width {
var r rune
r, width = utf8.DecodeRune(data[start:])
if !isSpace(r) {
break
}
}
// Scan until space, marking end of word.
for width, i := 0, start; i < len(data); i += width {
var r rune
r, width = utf8.DecodeRune(data[i:])
if isSpace(r) {
return i + width, data[start:i], nil
}
}
// If we're at EOF, we have a final, non-empty, non-terminated word. Return it.
if atEOF && len(data) > start {
return len(data), data[start:], nil
}
// Request more data.
return start, nil, nil
}
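// exampleScanWords is an illustrative sketch (not part of the original
// source) of the standard Scanner loop: call Scan until it reports false,
// then consult Err for any non-EOF error.
func exampleScanWords() ([]string, error) {
	s := NewScanner(bytes.NewReader([]byte("  one two\tthree\n")))
	s.Split(ScanWords)
	var words []string
	for s.Scan() {
		words = append(words, s.Text())
	}
	// words is now ["one" "two" "three"].
	return words, s.Err()
}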
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package bytes
// Simple byte buffer for marshaling data.
import (
"errors"
"io"
"unicode/utf8"
)
// smallBufferSize is an initial allocation minimal capacity.
const smallBufferSize = 64
// A Buffer is a variable-sized buffer of bytes with Read and Write methods.
// The zero value for Buffer is an empty buffer ready to use.
type Buffer struct {
buf []byte // contents are the bytes buf[off : len(buf)]
off int // read at &buf[off], write at &buf[len(buf)]
lastRead readOp // last read operation, so that Unread* can work correctly.
}
// The readOp constants describe the last action performed on
// the buffer, so that UnreadRune and UnreadByte can check for
// invalid usage. opReadRuneX constants are chosen such that
// converted to int they correspond to the rune size that was read.
type readOp int8
// Don't use iota for these, as the values need to correspond with the
// names and comments, which is easier to see when being explicit.
const (
opRead readOp = -1 // Any other read operation.
opInvalid readOp = 0 // Non-read operation.
opReadRune1 readOp = 1 // Read rune of size 1.
opReadRune2 readOp = 2 // Read rune of size 2.
opReadRune3 readOp = 3 // Read rune of size 3.
opReadRune4 readOp = 4 // Read rune of size 4.
)
// ErrTooLarge is passed to panic if memory cannot be allocated to store data in a buffer.
var ErrTooLarge = errors.New("bytes.Buffer: too large")
var errNegativeRead = errors.New("bytes.Buffer: reader returned negative count from Read")
const maxInt = int(^uint(0) >> 1)
// Bytes returns a slice of length b.Len() holding the unread portion of the buffer.
// The slice is valid for use only until the next buffer modification (that is,
// only until the next call to a method like Read, Write, Reset, or Truncate).
// The slice aliases the buffer content at least until the next buffer modification,
// so immediate changes to the slice will affect the result of future reads.
func (b *Buffer) Bytes() []byte { return b.buf[b.off:] }
// String returns the contents of the unread portion of the buffer
// as a string. If the Buffer is a nil pointer, it returns "<nil>".
//
// To build strings more efficiently, see the strings.Builder type.
func (b *Buffer) String() string {
if b == nil {
// Special case, useful in debugging.
return "<nil>"
}
return string(b.buf[b.off:])
}
// empty reports whether the unread portion of the buffer is empty.
func (b *Buffer) empty() bool { return len(b.buf) <= b.off }
// Len returns the number of bytes of the unread portion of the buffer;
// b.Len() == len(b.Bytes()).
func (b *Buffer) Len() int { return len(b.buf) - b.off }
// Cap returns the capacity of the buffer's underlying byte slice, that is, the
// total space allocated for the buffer's data.
func (b *Buffer) Cap() int { return cap(b.buf) }
// Truncate discards all but the first n unread bytes from the buffer
// but continues to use the same allocated storage.
// It panics if n is negative or greater than the length of the buffer.
func (b *Buffer) Truncate(n int) {
if n == 0 {
b.Reset()
return
}
b.lastRead = opInvalid
if n < 0 || n > b.Len() {
panic("bytes.Buffer: truncation out of range")
}
b.buf = b.buf[:b.off+n]
}
// Reset resets the buffer to be empty,
// but it retains the underlying storage for use by future writes.
// Reset is the same as Truncate(0).
func (b *Buffer) Reset() {
b.buf = b.buf[:0]
b.off = 0
b.lastRead = opInvalid
}
// tryGrowByReslice is an inlineable version of grow for the fast case where
// the internal buffer only needs to be resliced.
// It returns the index where bytes should be written and whether it succeeded.
func (b *Buffer) tryGrowByReslice(n int) (int, bool) {
if l := len(b.buf); n <= cap(b.buf)-l {
b.buf = b.buf[:l+n]
return l, true
}
return 0, false
}
// grow grows the buffer to guarantee space for n more bytes.
// It returns the index where bytes should be written.
// If the buffer can't grow it will panic with ErrTooLarge.
func (b *Buffer) grow(n int) int {
m := b.Len()
// If buffer is empty, reset to recover space.
if m == 0 && b.off != 0 {
b.Reset()
}
// Try to grow by means of a reslice.
if i, ok := b.tryGrowByReslice(n); ok {
return i
}
if b.buf == nil && n <= smallBufferSize {
b.buf = make([]byte, n, smallBufferSize)
return 0
}
c := cap(b.buf)
if n <= c/2-m {
// We can slide things down instead of allocating a new
// slice. We only need m+n <= c to slide, but
// we instead let capacity get twice as large so we
// don't spend all our time copying.
copy(b.buf, b.buf[b.off:])
} else if c > maxInt-c-n {
panic(ErrTooLarge)
} else {
// Add b.off to account for b.buf[:b.off] being sliced off the front.
b.buf = growSlice(b.buf[b.off:], b.off+n)
}
// Restore b.off and len(b.buf).
b.off = 0
b.buf = b.buf[:m+n]
return m
}
// Grow grows the buffer's capacity, if necessary, to guarantee space for
// another n bytes. After Grow(n), at least n bytes can be written to the
// buffer without another allocation.
// If n is negative, Grow will panic.
// If the buffer can't grow it will panic with ErrTooLarge.
func (b *Buffer) Grow(n int) {
if n < 0 {
panic("bytes.Buffer.Grow: negative count")
}
m := b.grow(n)
b.buf = b.buf[:m]
}
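// exampleGrow is an illustrative sketch (not part of the original source):
// when the total output size is known up front, a single Grow call lets all
// subsequent writes proceed without further reallocation.
func exampleGrow(chunks [][]byte) []byte {
	var b Buffer // the zero value is ready to use
	total := 0
	for _, c := range chunks {
		total += len(c)
	}
	b.Grow(total) // one allocation up front
	for _, c := range chunks {
		b.Write(c)
	}
	return b.Bytes()
}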
// Write appends the contents of p to the buffer, growing the buffer as
// needed. The return value n is the length of p; err is always nil. If the
// buffer becomes too large, Write will panic with ErrTooLarge.
func (b *Buffer) Write(p []byte) (n int, err error) {
b.lastRead = opInvalid
m, ok := b.tryGrowByReslice(len(p))
if !ok {
m = b.grow(len(p))
}
return copy(b.buf[m:], p), nil
}
// WriteString appends the contents of s to the buffer, growing the buffer as
// needed. The return value n is the length of s; err is always nil. If the
// buffer becomes too large, WriteString will panic with ErrTooLarge.
func (b *Buffer) WriteString(s string) (n int, err error) {
b.lastRead = opInvalid
m, ok := b.tryGrowByReslice(len(s))
if !ok {
m = b.grow(len(s))
}
return copy(b.buf[m:], s), nil
}
// MinRead is the minimum slice size passed to a Read call by
// Buffer.ReadFrom. As long as the Buffer has at least MinRead bytes beyond
// what is required to hold the contents of r, ReadFrom will not grow the
// underlying buffer.
const MinRead = 512
// ReadFrom reads data from r until EOF and appends it to the buffer, growing
// the buffer as needed. The return value n is the number of bytes read. Any
// error except io.EOF encountered during the read is also returned. If the
// buffer becomes too large, ReadFrom will panic with ErrTooLarge.
func (b *Buffer) ReadFrom(r io.Reader) (n int64, err error) {
b.lastRead = opInvalid
for {
i := b.grow(MinRead)
b.buf = b.buf[:i]
m, e := r.Read(b.buf[i:cap(b.buf)])
if m < 0 {
panic(errNegativeRead)
}
b.buf = b.buf[:i+m]
n += int64(m)
if e == io.EOF {
return n, nil // e is EOF, so return nil explicitly
}
if e != nil {
return n, e
}
}
}
// growSlice grows b by n, preserving the original content of b.
// If the allocation fails, it panics with ErrTooLarge.
func growSlice(b []byte, n int) []byte {
defer func() {
if recover() != nil {
panic(ErrTooLarge)
}
}()
// TODO(http://golang.org/issue/51462): We should rely on the append-make
// pattern so that the compiler can call runtime.growslice. For example:
// return append(b, make([]byte, n)...)
// This avoids unnecessary zero-ing of the first len(b) bytes of the
// allocated slice, but this pattern causes b to escape onto the heap.
//
// Instead use the append-make pattern with a nil slice to ensure that
// we allocate buffers rounded up to the closest size class.
c := len(b) + n // ensure enough space for n elements
if c < 2*cap(b) {
// The growth rate has historically always been 2x. In the future,
// we could rely purely on append to determine the growth rate.
c = 2 * cap(b)
}
b2 := append([]byte(nil), make([]byte, c)...)
copy(b2, b)
return b2[:len(b)]
}
// WriteTo writes data to w until the buffer is drained or an error occurs.
// The return value n is the number of bytes written; it always fits into an
// int, but it is int64 to match the io.WriterTo interface. Any error
// encountered during the write is also returned.
func (b *Buffer) WriteTo(w io.Writer) (n int64, err error) {
b.lastRead = opInvalid
if nBytes := b.Len(); nBytes > 0 {
m, e := w.Write(b.buf[b.off:])
if m > nBytes {
panic("bytes.Buffer.WriteTo: invalid Write count")
}
b.off += m
n = int64(m)
if e != nil {
return n, e
}
// all bytes should have been written, by definition of
// Write method in io.Writer
if m != nBytes {
return n, io.ErrShortWrite
}
}
// Buffer is now empty; reset.
b.Reset()
return n, nil
}
// WriteByte appends the byte c to the buffer, growing the buffer as needed.
// The returned error is always nil, but is included to match bufio.Writer's
// WriteByte. If the buffer becomes too large, WriteByte will panic with
// ErrTooLarge.
func (b *Buffer) WriteByte(c byte) error {
b.lastRead = opInvalid
m, ok := b.tryGrowByReslice(1)
if !ok {
m = b.grow(1)
}
b.buf[m] = c
return nil
}
// WriteRune appends the UTF-8 encoding of Unicode code point r to the
// buffer, returning its length and an error, which is always nil but is
// included to match bufio.Writer's WriteRune. The buffer is grown as needed;
// if it becomes too large, WriteRune will panic with ErrTooLarge.
func (b *Buffer) WriteRune(r rune) (n int, err error) {
// Compare as uint32 to correctly handle negative runes.
if uint32(r) < utf8.RuneSelf {
b.WriteByte(byte(r))
return 1, nil
}
b.lastRead = opInvalid
m, ok := b.tryGrowByReslice(utf8.UTFMax)
if !ok {
m = b.grow(utf8.UTFMax)
}
b.buf = utf8.AppendRune(b.buf[:m], r)
return len(b.buf) - m, nil
}
// Read reads the next len(p) bytes from the buffer or until the buffer
// is drained. The return value n is the number of bytes read. If the
// buffer has no data to return, err is io.EOF (unless len(p) is zero);
// otherwise it is nil.
func (b *Buffer) Read(p []byte) (n int, err error) {
b.lastRead = opInvalid
if b.empty() {
// Buffer is empty, reset to recover space.
b.Reset()
if len(p) == 0 {
return 0, nil
}
return 0, io.EOF
}
n = copy(p, b.buf[b.off:])
b.off += n
if n > 0 {
b.lastRead = opRead
}
return n, nil
}
// Next returns a slice containing the next n bytes from the buffer,
// advancing the buffer as if the bytes had been returned by Read.
// If there are fewer than n bytes in the buffer, Next returns the entire buffer.
// The slice is only valid until the next call to a read or write method.
func (b *Buffer) Next(n int) []byte {
b.lastRead = opInvalid
m := b.Len()
if n > m {
n = m
}
data := b.buf[b.off : b.off+n]
b.off += n
if n > 0 {
b.lastRead = opRead
}
return data
}
// ReadByte reads and returns the next byte from the buffer.
// If no byte is available, it returns error io.EOF.
func (b *Buffer) ReadByte() (byte, error) {
if b.empty() {
// Buffer is empty, reset to recover space.
b.Reset()
return 0, io.EOF
}
c := b.buf[b.off]
b.off++
b.lastRead = opRead
return c, nil
}
// ReadRune reads and returns the next UTF-8-encoded
// Unicode code point from the buffer.
// If no bytes are available, the error returned is io.EOF.
// If the bytes are an erroneous UTF-8 encoding, it
// consumes one byte and returns U+FFFD, 1.
func (b *Buffer) ReadRune() (r rune, size int, err error) {
if b.empty() {
// Buffer is empty, reset to recover space.
b.Reset()
return 0, 0, io.EOF
}
c := b.buf[b.off]
if c < utf8.RuneSelf {
b.off++
b.lastRead = opReadRune1
return rune(c), 1, nil
}
r, n := utf8.DecodeRune(b.buf[b.off:])
b.off += n
b.lastRead = readOp(n)
return r, n, nil
}
// UnreadRune unreads the last rune returned by ReadRune.
// If the most recent read or write operation on the buffer was
// not a successful ReadRune, UnreadRune returns an error. (In this regard
// it is stricter than UnreadByte, which will unread the last byte
// from any read operation.)
func (b *Buffer) UnreadRune() error {
if b.lastRead <= opInvalid {
return errors.New("bytes.Buffer: UnreadRune: previous operation was not a successful ReadRune")
}
if b.off >= int(b.lastRead) {
b.off -= int(b.lastRead)
}
b.lastRead = opInvalid
return nil
}
var errUnreadByte = errors.New("bytes.Buffer: UnreadByte: previous operation was not a successful read")
// UnreadByte unreads the last byte returned by the most recent successful
// read operation that read at least one byte. If a write has happened since
// the last read, if the last read returned an error, or if the read returned
// zero bytes, UnreadByte returns an error.
func (b *Buffer) UnreadByte() error {
if b.lastRead == opInvalid {
return errUnreadByte
}
b.lastRead = opInvalid
if b.off > 0 {
b.off--
}
return nil
}
// ReadBytes reads until the first occurrence of delim in the input,
// returning a slice containing the data up to and including the delimiter.
// If ReadBytes encounters an error before finding a delimiter,
// it returns the data read before the error and the error itself (often io.EOF).
// ReadBytes returns err != nil if and only if the returned data does not end in
// delim.
func (b *Buffer) ReadBytes(delim byte) (line []byte, err error) {
slice, err := b.readSlice(delim)
// return a copy of slice. The buffer's backing array may
// be overwritten by later calls.
line = append(line, slice...)
return line, err
}
// readSlice is like ReadBytes but returns a reference to internal buffer data.
func (b *Buffer) readSlice(delim byte) (line []byte, err error) {
i := IndexByte(b.buf[b.off:], delim)
end := b.off + i + 1
if i < 0 {
end = len(b.buf)
err = io.EOF
}
line = b.buf[b.off:end]
b.off = end
b.lastRead = opRead
return line, err
}
// ReadString reads until the first occurrence of delim in the input,
// returning a string containing the data up to and including the delimiter.
// If ReadString encounters an error before finding a delimiter,
// it returns the data read before the error and the error itself (often io.EOF).
// ReadString returns err != nil if and only if the returned data does not end
// in delim.
func (b *Buffer) ReadString(delim byte) (line string, err error) {
slice, err := b.readSlice(delim)
return string(slice), err
}
// NewBuffer creates and initializes a new Buffer using buf as its
// initial contents. The new Buffer takes ownership of buf, and the
// caller should not use buf after this call. NewBuffer is intended to
// prepare a Buffer to read existing data. It can also be used to set
// the initial size of the internal buffer for writing. To do that,
// buf should have the desired capacity but a length of zero.
//
// In most cases, new(Buffer) (or just declaring a Buffer variable) is
// sufficient to initialize a Buffer.
func NewBuffer(buf []byte) *Buffer { return &Buffer{buf: buf} }
// NewBufferString creates and initializes a new Buffer using string s as its
// initial contents. It is intended to prepare a buffer to read an existing
// string.
//
// In most cases, new(Buffer) (or just declaring a Buffer variable) is
// sufficient to initialize a Buffer.
func NewBufferString(s string) *Buffer {
return &Buffer{buf: []byte(s)}
}
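// exampleBufferQueue is an illustrative sketch (not part of the original
// source): a Buffer behaves as a byte queue, with writes appending at the
// end and reads consuming from the front.
func exampleBufferQueue() (string, string) {
	var b Buffer
	b.WriteString("hello, world")
	head, _ := b.ReadString(',') // head is "hello," including the delimiter
	return head, b.String()      // the unread remainder is " world"
}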
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package bytes implements functions for the manipulation of byte slices.
// It is analogous to the facilities of the strings package.
package bytes
import (
"internal/bytealg"
"unicode"
"unicode/utf8"
)
// Equal reports whether a and b
// are the same length and contain the same bytes.
// A nil argument is equivalent to an empty slice.
func Equal(a, b []byte) bool {
// Neither cmd/compile nor gccgo allocates for these string conversions.
return string(a) == string(b)
}
// Compare returns an integer comparing two byte slices lexicographically.
// The result will be 0 if a == b, -1 if a < b, and +1 if a > b.
// A nil argument is equivalent to an empty slice.
func Compare(a, b []byte) int {
return bytealg.Compare(a, b)
}
// explode splits s into a slice of UTF-8 sequences, one per Unicode code point (still slices of bytes),
// up to a maximum of n byte slices. Invalid UTF-8 sequences are chopped into individual bytes.
func explode(s []byte, n int) [][]byte {
if n <= 0 || n > len(s) {
n = len(s)
}
a := make([][]byte, n)
var size int
na := 0
for len(s) > 0 {
if na+1 >= n {
a[na] = s
na++
break
}
_, size = utf8.DecodeRune(s)
a[na] = s[0:size:size]
s = s[size:]
na++
}
return a[0:na]
}
// Count counts the number of non-overlapping instances of sep in s.
// If sep is an empty slice, Count returns 1 + the number of UTF-8-encoded code points in s.
func Count(s, sep []byte) int {
// special case
if len(sep) == 0 {
return utf8.RuneCount(s) + 1
}
if len(sep) == 1 {
return bytealg.Count(s, sep[0])
}
n := 0
for {
i := Index(s, sep)
if i == -1 {
return n
}
n++
s = s[i+len(sep):]
}
}
// Contains reports whether subslice is within b.
func Contains(b, subslice []byte) bool {
return Index(b, subslice) != -1
}
// ContainsAny reports whether any of the UTF-8-encoded code points in chars are within b.
func ContainsAny(b []byte, chars string) bool {
return IndexAny(b, chars) >= 0
}
// ContainsRune reports whether the rune is contained in the UTF-8-encoded byte slice b.
func ContainsRune(b []byte, r rune) bool {
return IndexRune(b, r) >= 0
}
// ContainsFunc reports whether any of the UTF-8-encoded code points r within b satisfy f(r).
func ContainsFunc(b []byte, f func(rune) bool) bool {
return IndexFunc(b, f) >= 0
}
// IndexByte returns the index of the first instance of c in b, or -1 if c is not present in b.
func IndexByte(b []byte, c byte) int {
return bytealg.IndexByte(b, c)
}
func indexBytePortable(s []byte, c byte) int {
for i, b := range s {
if b == c {
return i
}
}
return -1
}
// LastIndex returns the index of the last instance of sep in s, or -1 if sep is not present in s.
func LastIndex(s, sep []byte) int {
n := len(sep)
switch {
case n == 0:
return len(s)
case n == 1:
return LastIndexByte(s, sep[0])
case n == len(s):
if Equal(s, sep) {
return 0
}
return -1
case n > len(s):
return -1
}
// Rabin-Karp search from the end of the string
hashss, pow := bytealg.HashStrRevBytes(sep)
last := len(s) - n
var h uint32
for i := len(s) - 1; i >= last; i-- {
h = h*bytealg.PrimeRK + uint32(s[i])
}
if h == hashss && Equal(s[last:], sep) {
return last
}
for i := last - 1; i >= 0; i-- {
h *= bytealg.PrimeRK
h += uint32(s[i])
h -= pow * uint32(s[i+n])
if h == hashss && Equal(s[i:i+n], sep) {
return i
}
}
return -1
}
// LastIndexByte returns the index of the last instance of c in s, or -1 if c is not present in s.
func LastIndexByte(s []byte, c byte) int {
for i := len(s) - 1; i >= 0; i-- {
if s[i] == c {
return i
}
}
return -1
}
// IndexRune interprets s as a sequence of UTF-8-encoded code points.
// It returns the byte index of the first occurrence in s of the given rune.
// It returns -1 if rune is not present in s.
// If r is utf8.RuneError, it returns the first instance of any
// invalid UTF-8 byte sequence.
func IndexRune(s []byte, r rune) int {
switch {
case 0 <= r && r < utf8.RuneSelf:
return IndexByte(s, byte(r))
case r == utf8.RuneError:
for i := 0; i < len(s); {
r1, n := utf8.DecodeRune(s[i:])
if r1 == utf8.RuneError {
return i
}
i += n
}
return -1
case !utf8.ValidRune(r):
return -1
default:
var b [utf8.UTFMax]byte
n := utf8.EncodeRune(b[:], r)
return Index(s, b[:n])
}
}
// IndexAny interprets s as a sequence of UTF-8-encoded Unicode code points.
// It returns the byte index of the first occurrence in s of any of the Unicode
// code points in chars. It returns -1 if chars is empty or if there is no code
// point in common.
func IndexAny(s []byte, chars string) int {
if chars == "" {
// Avoid scanning all of s.
return -1
}
if len(s) == 1 {
r := rune(s[0])
if r >= utf8.RuneSelf {
// A single byte >= utf8.RuneSelf is invalid UTF-8 on its own,
// so it can only match utf8.RuneError in chars.
for _, r = range chars {
if r == utf8.RuneError {
return 0
}
}
return -1
}
if bytealg.IndexByteString(chars, s[0]) >= 0 {
return 0
}
return -1
}
if len(chars) == 1 {
r := rune(chars[0])
if r >= utf8.RuneSelf {
r = utf8.RuneError
}
return IndexRune(s, r)
}
if len(s) > 8 {
if as, isASCII := makeASCIISet(chars); isASCII {
for i, c := range s {
if as.contains(c) {
return i
}
}
return -1
}
}
var width int
for i := 0; i < len(s); i += width {
r := rune(s[i])
if r < utf8.RuneSelf {
if bytealg.IndexByteString(chars, s[i]) >= 0 {
return i
}
width = 1
continue
}
r, width = utf8.DecodeRune(s[i:])
if r != utf8.RuneError {
// r is 2 to 4 bytes
if len(chars) == width {
if chars == string(r) {
return i
}
continue
}
// Use bytealg.IndexString for performance if available.
if bytealg.MaxLen >= width {
if bytealg.IndexString(chars, string(r)) >= 0 {
return i
}
continue
}
}
for _, ch := range chars {
if r == ch {
return i
}
}
}
return -1
}
// LastIndexAny interprets s as a sequence of UTF-8-encoded Unicode code
// points. It returns the byte index of the last occurrence in s of any of
// the Unicode code points in chars. It returns -1 if chars is empty or if
// there is no code point in common.
func LastIndexAny(s []byte, chars string) int {
if chars == "" {
// Avoid scanning all of s.
return -1
}
if len(s) > 8 {
if as, isASCII := makeASCIISet(chars); isASCII {
for i := len(s) - 1; i >= 0; i-- {
if as.contains(s[i]) {
return i
}
}
return -1
}
}
if len(s) == 1 {
r := rune(s[0])
if r >= utf8.RuneSelf {
for _, r = range chars {
if r == utf8.RuneError {
return 0
}
}
return -1
}
if bytealg.IndexByteString(chars, s[0]) >= 0 {
return 0
}
return -1
}
if len(chars) == 1 {
cr := rune(chars[0])
if cr >= utf8.RuneSelf {
cr = utf8.RuneError
}
for i := len(s); i > 0; {
r, size := utf8.DecodeLastRune(s[:i])
i -= size
if r == cr {
return i
}
}
return -1
}
for i := len(s); i > 0; {
r := rune(s[i-1])
if r < utf8.RuneSelf {
if bytealg.IndexByteString(chars, s[i-1]) >= 0 {
return i - 1
}
i--
continue
}
r, size := utf8.DecodeLastRune(s[:i])
i -= size
if r != utf8.RuneError {
// r is 2 to 4 bytes
if len(chars) == size {
if chars == string(r) {
return i
}
continue
}
// Use bytealg.IndexString for performance if available.
if bytealg.MaxLen >= size {
if bytealg.IndexString(chars, string(r)) >= 0 {
return i
}
continue
}
}
for _, ch := range chars {
if r == ch {
return i
}
}
}
return -1
}
// Generic split: splits after each instance of sep,
// including sepSave bytes of sep in the subslices.
func genSplit(s, sep []byte, sepSave, n int) [][]byte {
if n == 0 {
return nil
}
if len(sep) == 0 {
return explode(s, n)
}
if n < 0 {
n = Count(s, sep) + 1
}
if n > len(s)+1 {
n = len(s) + 1
}
a := make([][]byte, n)
n--
i := 0
for i < n {
m := Index(s, sep)
if m < 0 {
break
}
a[i] = s[: m+sepSave : m+sepSave]
s = s[m+len(sep):]
i++
}
a[i] = s
return a[:i+1]
}
// SplitN slices s into subslices separated by sep and returns a slice of
// the subslices between those separators.
// If sep is empty, SplitN splits after each UTF-8 sequence.
// The count determines the number of subslices to return:
//
// n > 0: at most n subslices; the last subslice will be the unsplit remainder.
// n == 0: the result is nil (zero subslices)
// n < 0: all subslices
//
// To split around the first instance of a separator, see Cut.
func SplitN(s, sep []byte, n int) [][]byte { return genSplit(s, sep, 0, n) }
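// exampleSplitN is an illustrative sketch (not part of the original source)
// of how the count n shapes SplitN's result.
func exampleSplitN() {
	two := SplitN([]byte("a,b,c"), []byte(","), 2)
	// two is ["a" "b,c"]: the last subslice holds the unsplit remainder.
	all := SplitN([]byte("a,b,c"), []byte(","), -1)
	// all is ["a" "b" "c"]: a negative count means no limit.
	_, _ = two, all
}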
// SplitAfterN slices s into subslices after each instance of sep and
// returns a slice of those subslices.
// If sep is empty, SplitAfterN splits after each UTF-8 sequence.
// The count determines the number of subslices to return:
//
// n > 0: at most n subslices; the last subslice will be the unsplit remainder.
// n == 0: the result is nil (zero subslices)
// n < 0: all subslices
func SplitAfterN(s, sep []byte, n int) [][]byte {
return genSplit(s, sep, len(sep), n)
}
// Split slices s into all subslices separated by sep and returns a slice of
// the subslices between those separators.
// If sep is empty, Split splits after each UTF-8 sequence.
// It is equivalent to SplitN with a count of -1.
//
// To split around the first instance of a separator, see Cut.
func Split(s, sep []byte) [][]byte { return genSplit(s, sep, 0, -1) }
// SplitAfter slices s into all subslices after each instance of sep and
// returns a slice of those subslices.
// If sep is empty, SplitAfter splits after each UTF-8 sequence.
// It is equivalent to SplitAfterN with a count of -1.
func SplitAfter(s, sep []byte) [][]byte {
return genSplit(s, sep, len(sep), -1)
}
var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}
// Fields interprets s as a sequence of UTF-8-encoded code points.
// It splits the slice s around each instance of one or more consecutive white space
// characters, as defined by unicode.IsSpace, returning a slice of subslices of s or an
// empty slice if s contains only white space.
func Fields(s []byte) [][]byte {
// First count the fields.
// This is an exact count if s is ASCII, otherwise it is an approximation.
n := 0
wasSpace := 1
// setBits is used to track which bits are set in the bytes of s.
setBits := uint8(0)
for i := 0; i < len(s); i++ {
r := s[i]
setBits |= r
isSpace := int(asciiSpace[r])
n += wasSpace & ^isSpace
wasSpace = isSpace
}
if setBits >= utf8.RuneSelf {
// Some runes in the input slice are not ASCII.
return FieldsFunc(s, unicode.IsSpace)
}
// ASCII fast path
a := make([][]byte, n)
na := 0
fieldStart := 0
i := 0
// Skip spaces in the front of the input.
for i < len(s) && asciiSpace[s[i]] != 0 {
i++
}
fieldStart = i
for i < len(s) {
if asciiSpace[s[i]] == 0 {
i++
continue
}
a[na] = s[fieldStart:i:i]
na++
i++
// Skip spaces in between fields.
for i < len(s) && asciiSpace[s[i]] != 0 {
i++
}
fieldStart = i
}
if fieldStart < len(s) { // Last field might end at EOF.
a[na] = s[fieldStart:len(s):len(s)]
}
return a
}
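// exampleFields is an illustrative sketch (not part of the original source):
// Fields collapses runs of white space, so it never yields empty subslices.
func exampleFields() {
	f := Fields([]byte("  foo bar  baz   "))
	// f is ["foo" "bar" "baz"].
	_ = f
}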
// FieldsFunc interprets s as a sequence of UTF-8-encoded code points.
// It splits the slice s at each run of code points c satisfying f(c) and
// returns a slice of subslices of s. If all code points in s satisfy f(c), or
// len(s) == 0, an empty slice is returned.
//
// FieldsFunc makes no guarantees about the order in which it calls f(c)
// and assumes that f always returns the same value for a given c.
func FieldsFunc(s []byte, f func(rune) bool) [][]byte {
// A span is used to record a slice of s of the form s[start:end].
// The start index is inclusive and the end index is exclusive.
type span struct {
start int
end int
}
spans := make([]span, 0, 32)
// Find the field start and end indices.
// Doing this in a separate pass (rather than slicing the string s
// and collecting the result substrings right away) is significantly
// more efficient, possibly due to cache effects.
start := -1 // valid span start if >= 0
for i := 0; i < len(s); {
size := 1
r := rune(s[i])
if r >= utf8.RuneSelf {
r, size = utf8.DecodeRune(s[i:])
}
if f(r) {
if start >= 0 {
spans = append(spans, span{start, i})
start = -1
}
} else {
if start < 0 {
start = i
}
}
i += size
}
// Last field might end at EOF.
if start >= 0 {
spans = append(spans, span{start, len(s)})
}
// Create subslices from recorded field indices.
a := make([][]byte, len(spans))
for i, span := range spans {
a[i] = s[span.start:span.end:span.end]
}
return a
}
// Join concatenates the elements of s to create a new byte slice. The separator
// sep is placed between elements in the resulting slice.
func Join(s [][]byte, sep []byte) []byte {
if len(s) == 0 {
return []byte{}
}
if len(s) == 1 {
// Just return a copy.
return append([]byte(nil), s[0]...)
}
var n int
if len(sep) > 0 {
if len(sep) >= maxInt/(len(s)-1) {
panic("bytes: Join output length overflow")
}
n += len(sep) * (len(s) - 1)
}
for _, v := range s {
if len(v) > maxInt-n {
panic("bytes: Join output length overflow")
}
n += len(v)
}
b := bytealg.MakeNoZero(n)
bp := copy(b, s[0])
for _, v := range s[1:] {
bp += copy(b[bp:], sep)
bp += copy(b[bp:], v)
}
return b
}
// HasPrefix tests whether the byte slice s begins with prefix.
func HasPrefix(s, prefix []byte) bool {
return len(s) >= len(prefix) && Equal(s[0:len(prefix)], prefix)
}
// HasSuffix tests whether the byte slice s ends with suffix.
func HasSuffix(s, suffix []byte) bool {
return len(s) >= len(suffix) && Equal(s[len(s)-len(suffix):], suffix)
}
// Map returns a copy of the byte slice s with all its characters modified
// according to the mapping function. If mapping returns a negative value, the character is
// dropped from the byte slice with no replacement. The characters in s and the
// output are interpreted as UTF-8-encoded code points.
func Map(mapping func(r rune) rune, s []byte) []byte {
// In the worst case, the slice can grow when mapped, making
// things unpleasant. But it's so rare we barge in assuming it's
// fine. It could also shrink but that falls out naturally.
b := make([]byte, 0, len(s))
for i := 0; i < len(s); {
wid := 1
r := rune(s[i])
if r >= utf8.RuneSelf {
r, wid = utf8.DecodeRune(s[i:])
}
r = mapping(r)
if r >= 0 {
b = utf8.AppendRune(b, r)
}
i += wid
}
return b
}
// Repeat returns a new byte slice consisting of count copies of b.
//
// It panics if count is negative or if the result of (len(b) * count)
// overflows.
func Repeat(b []byte, count int) []byte {
if count == 0 {
return []byte{}
}
// Since we cannot return an error on overflow,
// we should panic if the repeat will generate an overflow.
// See golang.org/issue/16237.
if count < 0 {
panic("bytes: negative Repeat count")
}
if len(b) >= maxInt/count {
panic("bytes: Repeat output length overflow")
}
n := len(b) * count
if len(b) == 0 {
return []byte{}
}
// Past a certain chunk size it is counterproductive to use
// larger chunks as the source of the write, as when the source
// is too large we are basically just thrashing the CPU D-cache.
// So if the result length is larger than an empirically-found
// limit (8KB), we stop growing the source string once the limit
// is reached and keep reusing the same source string - that
// should therefore be always resident in the L1 cache - until we
// have completed the construction of the result.
// This yields significant speedups (up to +100%) in cases where
// the result length is large (roughly, over L2 cache size).
const chunkLimit = 8 * 1024
chunkMax := n
if chunkMax > chunkLimit {
chunkMax = chunkLimit / len(b) * len(b)
if chunkMax == 0 {
chunkMax = len(b)
}
}
nb := bytealg.MakeNoZero(n)
bp := copy(nb, b)
for bp < n {
chunk := bp
if chunk > chunkMax {
chunk = chunkMax
}
bp += copy(nb[bp:], nb[:chunk])
}
return nb
}
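// exampleRepeat is an illustrative sketch (not part of the original source);
// the chunked copying above is invisible to the caller.
func exampleRepeat() {
	banner := Repeat([]byte("na"), 4)
	// banner is "nananana".
	_ = banner
}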
// ToUpper returns a copy of the byte slice s with all Unicode letters mapped to
// their upper case.
func ToUpper(s []byte) []byte {
isASCII, hasLower := true, false
for i := 0; i < len(s); i++ {
c := s[i]
if c >= utf8.RuneSelf {
isASCII = false
break
}
hasLower = hasLower || ('a' <= c && c <= 'z')
}
if isASCII { // optimize for ASCII-only byte slices.
if !hasLower {
// Just return a copy.
return append([]byte(""), s...)
}
b := bytealg.MakeNoZero(len(s))
for i := 0; i < len(s); i++ {
c := s[i]
if 'a' <= c && c <= 'z' {
c -= 'a' - 'A'
}
b[i] = c
}
return b
}
return Map(unicode.ToUpper, s)
}
// ToLower returns a copy of the byte slice s with all Unicode letters mapped to
// their lower case.
func ToLower(s []byte) []byte {
isASCII, hasUpper := true, false
for i := 0; i < len(s); i++ {
c := s[i]
if c >= utf8.RuneSelf {
isASCII = false
break
}
hasUpper = hasUpper || ('A' <= c && c <= 'Z')
}
if isASCII { // optimize for ASCII-only byte slices.
if !hasUpper {
return append([]byte(""), s...)
}
b := bytealg.MakeNoZero(len(s))
for i := 0; i < len(s); i++ {
c := s[i]
if 'A' <= c && c <= 'Z' {
c += 'a' - 'A'
}
b[i] = c
}
return b
}
return Map(unicode.ToLower, s)
}
// ToTitle treats s as UTF-8-encoded bytes and returns a copy with all the Unicode letters mapped to their title case.
func ToTitle(s []byte) []byte { return Map(unicode.ToTitle, s) }
// ToUpperSpecial treats s as UTF-8-encoded bytes and returns a copy with all the Unicode letters mapped to their
// upper case, giving priority to the special casing rules.
func ToUpperSpecial(c unicode.SpecialCase, s []byte) []byte {
return Map(c.ToUpper, s)
}
// ToLowerSpecial treats s as UTF-8-encoded bytes and returns a copy with all the Unicode letters mapped to their
// lower case, giving priority to the special casing rules.
func ToLowerSpecial(c unicode.SpecialCase, s []byte) []byte {
return Map(c.ToLower, s)
}
// ToTitleSpecial treats s as UTF-8-encoded bytes and returns a copy with all the Unicode letters mapped to their
// title case, giving priority to the special casing rules.
func ToTitleSpecial(c unicode.SpecialCase, s []byte) []byte {
return Map(c.ToTitle, s)
}
// ToValidUTF8 treats s as UTF-8-encoded bytes and returns a copy with each run of bytes
// representing invalid UTF-8 replaced with the bytes in replacement, which may be empty.
func ToValidUTF8(s, replacement []byte) []byte {
b := make([]byte, 0, len(s)+len(replacement))
invalid := false // previous byte was from an invalid UTF-8 sequence
for i := 0; i < len(s); {
c := s[i]
if c < utf8.RuneSelf {
i++
invalid = false
b = append(b, c)
continue
}
_, wid := utf8.DecodeRune(s[i:])
if wid == 1 {
i++
if !invalid {
invalid = true
b = append(b, replacement...)
}
continue
}
invalid = false
b = append(b, s[i:i+wid]...)
i += wid
}
return b
}
// isSeparator reports whether the rune could mark a word boundary.
// TODO: update when package unicode captures more of the properties.
func isSeparator(r rune) bool {
// ASCII alphanumerics and underscore are not separators
if r <= 0x7F {
switch {
case '0' <= r && r <= '9':
return false
case 'a' <= r && r <= 'z':
return false
case 'A' <= r && r <= 'Z':
return false
case r == '_':
return false
}
return true
}
// Letters and digits are not separators
if unicode.IsLetter(r) || unicode.IsDigit(r) {
return false
}
// Otherwise, all we can do for now is treat spaces as separators.
return unicode.IsSpace(r)
}
// Title treats s as UTF-8-encoded bytes and returns a copy with all Unicode letters that begin
// words mapped to their title case.
//
// Deprecated: The rule Title uses for word boundaries does not handle Unicode
// punctuation properly. Use golang.org/x/text/cases instead.
func Title(s []byte) []byte {
// Use a closure here to remember state.
// Hackish but effective. Depends on Map scanning in order and calling
// the closure once per rune.
prev := ' '
return Map(
func(r rune) rune {
if isSeparator(prev) {
prev = r
return unicode.ToTitle(r)
}
prev = r
return r
},
s)
}
// TrimLeftFunc treats s as UTF-8-encoded bytes and returns a subslice of s by slicing off
// all leading UTF-8-encoded code points c that satisfy f(c).
func TrimLeftFunc(s []byte, f func(r rune) bool) []byte {
i := indexFunc(s, f, false)
if i == -1 {
return nil
}
return s[i:]
}
// TrimRightFunc returns a subslice of s by slicing off all trailing
// UTF-8-encoded code points c that satisfy f(c).
func TrimRightFunc(s []byte, f func(r rune) bool) []byte {
i := lastIndexFunc(s, f, false)
if i >= 0 && s[i] >= utf8.RuneSelf {
_, wid := utf8.DecodeRune(s[i:])
i += wid
} else {
i++
}
return s[0:i]
}
// TrimFunc returns a subslice of s by slicing off all leading and trailing
// UTF-8-encoded code points c that satisfy f(c).
func TrimFunc(s []byte, f func(r rune) bool) []byte {
return TrimRightFunc(TrimLeftFunc(s, f), f)
}
// TrimPrefix returns s without the provided leading prefix string.
// If s doesn't start with prefix, s is returned unchanged.
func TrimPrefix(s, prefix []byte) []byte {
if HasPrefix(s, prefix) {
return s[len(prefix):]
}
return s
}
// TrimSuffix returns s without the provided trailing suffix string.
// If s doesn't end with suffix, s is returned unchanged.
func TrimSuffix(s, suffix []byte) []byte {
if HasSuffix(s, suffix) {
return s[:len(s)-len(suffix)]
}
return s
}
// IndexFunc interprets s as a sequence of UTF-8-encoded code points.
// It returns the byte index in s of the first Unicode
// code point satisfying f(c), or -1 if none do.
func IndexFunc(s []byte, f func(r rune) bool) int {
return indexFunc(s, f, true)
}
// LastIndexFunc interprets s as a sequence of UTF-8-encoded code points.
// It returns the byte index in s of the last Unicode
// code point satisfying f(c), or -1 if none do.
func LastIndexFunc(s []byte, f func(r rune) bool) int {
return lastIndexFunc(s, f, true)
}
// indexFunc is the same as IndexFunc except that if
// truth==false, the sense of the predicate function is
// inverted.
func indexFunc(s []byte, f func(r rune) bool, truth bool) int {
start := 0
for start < len(s) {
wid := 1
r := rune(s[start])
if r >= utf8.RuneSelf {
r, wid = utf8.DecodeRune(s[start:])
}
if f(r) == truth {
return start
}
start += wid
}
return -1
}
// lastIndexFunc is the same as LastIndexFunc except that if
// truth==false, the sense of the predicate function is
// inverted.
func lastIndexFunc(s []byte, f func(r rune) bool, truth bool) int {
for i := len(s); i > 0; {
r, size := rune(s[i-1]), 1
if r >= utf8.RuneSelf {
r, size = utf8.DecodeLastRune(s[0:i])
}
i -= size
if f(r) == truth {
return i
}
}
return -1
}
// asciiSet is a 32-byte value, where each bit represents the presence of a
// given ASCII character in the set. The 128 bits of the lower 16 bytes,
// starting with the least-significant bit of the lowest word to the
// most-significant bit of the highest word, map to the full range of all
// 128 ASCII characters. The 128 bits of the upper 16 bytes will be zeroed,
// ensuring that any non-ASCII character will be reported as not in the set.
// This allocates a total of 32 bytes even though the upper half
// is unused to avoid bounds checks in asciiSet.contains.
type asciiSet [8]uint32
// makeASCIISet creates a set of ASCII characters and reports whether all
// characters in chars are ASCII.
func makeASCIISet(chars string) (as asciiSet, ok bool) {
for i := 0; i < len(chars); i++ {
c := chars[i]
if c >= utf8.RuneSelf {
return as, false
}
as[c/32] |= 1 << (c % 32)
}
return as, true
}
// contains reports whether c is inside the set.
func (as *asciiSet) contains(c byte) bool {
return (as[c/32] & (1 << (c % 32))) != 0
}
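// Worked example (added for clarity; not part of the original source): for
// c = 'A' (65), c/32 == 2 and c%32 == 1, so membership of 'A' is bit 1 of
// as[2]; makeASCIISet sets that bit and contains tests it. Any byte >= 128
// selects as[4] through as[7], which are always zero, so non-ASCII bytes
// are reported as absent without an explicit range check.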
// containsRune is a simplified version of strings.ContainsRune
// to avoid importing the strings package.
// We avoid bytes.ContainsRune to avoid allocating a temporary copy of s.
func containsRune(s string, r rune) bool {
for _, c := range s {
if c == r {
return true
}
}
return false
}
// Trim returns a subslice of s by slicing off all leading and
// trailing UTF-8-encoded code points contained in cutset.
func Trim(s []byte, cutset string) []byte {
if len(s) == 0 {
// This is what we've historically done.
return nil
}
if cutset == "" {
return s
}
if len(cutset) == 1 && cutset[0] < utf8.RuneSelf {
return trimLeftByte(trimRightByte(s, cutset[0]), cutset[0])
}
if as, ok := makeASCIISet(cutset); ok {
return trimLeftASCII(trimRightASCII(s, &as), &as)
}
return trimLeftUnicode(trimRightUnicode(s, cutset), cutset)
}
// TrimLeft returns a subslice of s by slicing off all leading
// UTF-8-encoded code points contained in cutset.
func TrimLeft(s []byte, cutset string) []byte {
if len(s) == 0 {
// This is what we've historically done.
return nil
}
if cutset == "" {
return s
}
if len(cutset) == 1 && cutset[0] < utf8.RuneSelf {
return trimLeftByte(s, cutset[0])
}
if as, ok := makeASCIISet(cutset); ok {
return trimLeftASCII(s, &as)
}
return trimLeftUnicode(s, cutset)
}
func trimLeftByte(s []byte, c byte) []byte {
for len(s) > 0 && s[0] == c {
s = s[1:]
}
if len(s) == 0 {
// This is what we've historically done.
return nil
}
return s
}
func trimLeftASCII(s []byte, as *asciiSet) []byte {
for len(s) > 0 {
if !as.contains(s[0]) {
break
}
s = s[1:]
}
if len(s) == 0 {
// This is what we've historically done.
return nil
}
return s
}
func trimLeftUnicode(s []byte, cutset string) []byte {
for len(s) > 0 {
r, n := rune(s[0]), 1
if r >= utf8.RuneSelf {
r, n = utf8.DecodeRune(s)
}
if !containsRune(cutset, r) {
break
}
s = s[n:]
}
if len(s) == 0 {
// This is what we've historically done.
return nil
}
return s
}
// TrimRight returns a subslice of s by slicing off all trailing
// UTF-8-encoded code points that are contained in cutset.
func TrimRight(s []byte, cutset string) []byte {
if len(s) == 0 || cutset == "" {
return s
}
if len(cutset) == 1 && cutset[0] < utf8.RuneSelf {
return trimRightByte(s, cutset[0])
}
if as, ok := makeASCIISet(cutset); ok {
return trimRightASCII(s, &as)
}
return trimRightUnicode(s, cutset)
}
func trimRightByte(s []byte, c byte) []byte {
for len(s) > 0 && s[len(s)-1] == c {
s = s[:len(s)-1]
}
return s
}
func trimRightASCII(s []byte, as *asciiSet) []byte {
for len(s) > 0 {
if !as.contains(s[len(s)-1]) {
break
}
s = s[:len(s)-1]
}
return s
}
func trimRightUnicode(s []byte, cutset string) []byte {
for len(s) > 0 {
r, n := rune(s[len(s)-1]), 1
if r >= utf8.RuneSelf {
r, n = utf8.DecodeLastRune(s)
}
if !containsRune(cutset, r) {
break
}
s = s[:len(s)-n]
}
return s
}
// TrimSpace returns a subslice of s by slicing off all leading and
// trailing white space, as defined by Unicode.
func TrimSpace(s []byte) []byte {
// Fast path for ASCII: look for the first ASCII non-space byte
start := 0
for ; start < len(s); start++ {
c := s[start]
if c >= utf8.RuneSelf {
// If we run into a non-ASCII byte, fall back to the
// slower unicode-aware method on the remaining bytes
return TrimFunc(s[start:], unicode.IsSpace)
}
if asciiSpace[c] == 0 {
break
}
}
// Now look for the first ASCII non-space byte from the end
stop := len(s)
for ; stop > start; stop-- {
c := s[stop-1]
if c >= utf8.RuneSelf {
return TrimFunc(s[start:stop], unicode.IsSpace)
}
if asciiSpace[c] == 0 {
break
}
}
// At this point s[start:stop] starts and ends with a non-space ASCII
// byte, so we're done. Non-ASCII cases have already been handled above.
if start == stop {
// Special case to preserve previous TrimLeftFunc behavior,
// returning nil instead of empty slice if all spaces.
return nil
}
return s[start:stop]
}
// Runes interprets s as a sequence of UTF-8-encoded code points.
// It returns a slice of runes (Unicode code points) equivalent to s.
func Runes(s []byte) []rune {
t := make([]rune, utf8.RuneCount(s))
i := 0
for len(s) > 0 {
r, l := utf8.DecodeRune(s)
t[i] = r
i++
s = s[l:]
}
return t
}
// Replace returns a copy of the slice s with the first n
// non-overlapping instances of old replaced by new.
// If old is empty, it matches at the beginning of the slice
// and after each UTF-8 sequence, yielding up to k+1 replacements
// for a k-rune slice.
// If n < 0, there is no limit on the number of replacements.
func Replace(s, old, new []byte, n int) []byte {
m := 0
if n != 0 {
// Compute number of replacements.
m = Count(s, old)
}
if m == 0 {
// Just return a copy.
return append([]byte(nil), s...)
}
if n < 0 || m < n {
n = m
}
// Apply replacements to buffer.
t := make([]byte, len(s)+n*(len(new)-len(old)))
w := 0
start := 0
for i := 0; i < n; i++ {
j := start
if len(old) == 0 {
if i > 0 {
_, wid := utf8.DecodeRune(s[start:])
j += wid
}
} else {
j += Index(s[start:], old)
}
w += copy(t[w:], s[start:j])
w += copy(t[w:], new)
start = j + len(old)
}
w += copy(t[w:], s[start:])
return t[0:w]
}
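// exampleReplace is an illustrative sketch (not part of the original
// source): with n >= 0, only the first n instances are replaced.
func exampleReplace() {
	out := Replace([]byte("oink oink oink"), []byte("k"), []byte("ky"), 2)
	// out is "oinky oinky oink".
	_ = out
}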
// ReplaceAll returns a copy of the slice s with all
// non-overlapping instances of old replaced by new.
// If old is empty, it matches at the beginning of the slice
// and after each UTF-8 sequence, yielding up to k+1 replacements
// for a k-rune slice.
func ReplaceAll(s, old, new []byte) []byte {
return Replace(s, old, new, -1)
}
// EqualFold reports whether s and t, interpreted as UTF-8 strings,
// are equal under simple Unicode case-folding, which is a more general
// form of case-insensitivity.
func EqualFold(s, t []byte) bool {
// ASCII fast path
i := 0
for ; i < len(s) && i < len(t); i++ {
sr := s[i]
tr := t[i]
if sr|tr >= utf8.RuneSelf {
goto hasUnicode
}
// Easy case.
if tr == sr {
continue
}
// Make sr < tr to simplify what follows.
if tr < sr {
tr, sr = sr, tr
}
// ASCII only, sr/tr must be upper/lower case
if 'A' <= sr && sr <= 'Z' && tr == sr+'a'-'A' {
continue
}
return false
}
// Check if we've exhausted both strings.
return len(s) == len(t)
hasUnicode:
s = s[i:]
t = t[i:]
for len(s) != 0 && len(t) != 0 {
// Extract first rune from each.
var sr, tr rune
if s[0] < utf8.RuneSelf {
sr, s = rune(s[0]), s[1:]
} else {
r, size := utf8.DecodeRune(s)
sr, s = r, s[size:]
}
if t[0] < utf8.RuneSelf {
tr, t = rune(t[0]), t[1:]
} else {
r, size := utf8.DecodeRune(t)
tr, t = r, t[size:]
}
// If they match, keep going; if not, return false.
// Easy case.
if tr == sr {
continue
}
// Make sr < tr to simplify what follows.
if tr < sr {
tr, sr = sr, tr
}
// Fast check for ASCII.
if tr < utf8.RuneSelf {
// ASCII only, sr/tr must be upper/lower case
if 'A' <= sr && sr <= 'Z' && tr == sr+'a'-'A' {
continue
}
return false
}
// General case. SimpleFold(x) returns the next equivalent rune > x
// or wraps around to smaller values.
r := unicode.SimpleFold(sr)
for r != sr && r < tr {
r = unicode.SimpleFold(r)
}
if r == tr {
continue
}
return false
}
// One string is empty. Are both?
return len(s) == len(t)
}
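// Usage sketch from an importing package (illustrative values):
//
//	bytes.EqualFold([]byte("Go"), []byte("GO")) // true
//	bytes.EqualFold([]byte("ſ"), []byte("S"))   // true: 'ſ' simple-folds to 'S'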
// Index returns the index of the first instance of sep in s, or -1 if sep is not present in s.
func Index(s, sep []byte) int {
n := len(sep)
switch {
case n == 0:
return 0
case n == 1:
return IndexByte(s, sep[0])
case n == len(s):
if Equal(sep, s) {
return 0
}
return -1
case n > len(s):
return -1
case n <= bytealg.MaxLen:
// Use brute force when s and sep both are small
if len(s) <= bytealg.MaxBruteForce {
return bytealg.Index(s, sep)
}
c0 := sep[0]
c1 := sep[1]
i := 0
t := len(s) - n + 1
fails := 0
for i < t {
if s[i] != c0 {
// IndexByte is faster than bytealg.Index, so use it as long as
// we're not getting lots of false positives.
o := IndexByte(s[i+1:t], c0)
if o < 0 {
return -1
}
i += o + 1
}
if s[i+1] == c1 && Equal(s[i:i+n], sep) {
return i
}
fails++
i++
// Switch to bytealg.Index when IndexByte produces too many false positives.
if fails > bytealg.Cutover(i) {
r := bytealg.Index(s[i:], sep)
if r >= 0 {
return r + i
}
return -1
}
}
return -1
}
c0 := sep[0]
c1 := sep[1]
i := 0
fails := 0
t := len(s) - n + 1
for i < t {
if s[i] != c0 {
o := IndexByte(s[i+1:t], c0)
if o < 0 {
break
}
i += o + 1
}
if s[i+1] == c1 && Equal(s[i:i+n], sep) {
return i
}
i++
fails++
if fails >= 4+i>>4 && i < t {
// Give up on IndexByte, it isn't skipping ahead
// far enough to be better than Rabin-Karp.
// Experiments (using IndexPeriodic) suggest
// the cutover is about 16 byte skips.
// TODO: if large prefixes of sep are matching
// we should cutover at even larger average skips,
// because Equal becomes that much more expensive.
// This code does not take that effect into account.
j := bytealg.IndexRabinKarpBytes(s[i:], sep)
if j < 0 {
return -1
}
return i + j
}
}
return -1
}
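// Usage sketch from an importing package (illustrative values):
//
//	bytes.Index([]byte("chicken"), []byte("ken")) // 4
//	bytes.Index([]byte("chicken"), []byte("dmr")) // -1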
// Cut slices s around the first instance of sep,
// returning the text before and after sep.
// The found result reports whether sep appears in s.
// If sep does not appear in s, cut returns s, nil, false.
//
// Cut returns slices of the original slice s, not copies.
func Cut(s, sep []byte) (before, after []byte, found bool) {
if i := Index(s, sep); i >= 0 {
return s[:i], s[i+len(sep):], true
}
return s, nil, false
}
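// Usage sketch from an importing package (illustrative values):
//
//	before, after, found := bytes.Cut([]byte("key=value"), []byte("="))
//	// before = "key", after = "value", found = true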
// Clone returns a copy of b[:len(b)].
// The result may have additional unused capacity.
// Clone(nil) returns nil.
func Clone(b []byte) []byte {
if b == nil {
return nil
}
return append([]byte{}, b...)
}
// CutPrefix returns s without the provided leading prefix byte slice
// and reports whether it found the prefix.
// If s doesn't start with prefix, CutPrefix returns s, false.
// If prefix is the empty byte slice, CutPrefix returns s, true.
//
// CutPrefix returns slices of the original slice s, not copies.
func CutPrefix(s, prefix []byte) (after []byte, found bool) {
if !HasPrefix(s, prefix) {
return s, false
}
return s[len(prefix):], true
}
// CutSuffix returns s without the provided ending suffix byte slice
// and reports whether it found the suffix.
// If s doesn't end with suffix, CutSuffix returns s, false.
// If suffix is the empty byte slice, CutSuffix returns s, true.
//
// CutSuffix returns slices of the original slice s, not copies.
func CutSuffix(s, suffix []byte) (before []byte, found bool) {
if !HasSuffix(s, suffix) {
return s, false
}
return s[:len(s)-len(suffix)], true
}
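// Usage sketch from an importing package (illustrative values):
//
//	after, found := bytes.CutPrefix([]byte("v1.2.3"), []byte("v"))         // "1.2.3", true
//	before, found2 := bytes.CutSuffix([]byte("file.txt"), []byte(".txt")) // "file", true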
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package bytes
import (
"errors"
"io"
"unicode/utf8"
)
// A Reader implements the io.Reader, io.ReaderAt, io.WriterTo, io.Seeker,
// io.ByteScanner, and io.RuneScanner interfaces by reading from
// a byte slice.
// Unlike a Buffer, a Reader is read-only and supports seeking.
// The zero value for Reader operates like a Reader of an empty slice.
type Reader struct {
s []byte
i int64 // current reading index
prevRune int // index of previous rune; or < 0
}
// Len returns the number of bytes of the unread portion of the
// slice.
func (r *Reader) Len() int {
if r.i >= int64(len(r.s)) {
return 0
}
return int(int64(len(r.s)) - r.i)
}
// Size returns the original length of the underlying byte slice.
// Size is the number of bytes available for reading via ReadAt.
// The result is unaffected by any method calls except Reset.
func (r *Reader) Size() int64 { return int64(len(r.s)) }
// Read implements the io.Reader interface.
func (r *Reader) Read(b []byte) (n int, err error) {
if r.i >= int64(len(r.s)) {
return 0, io.EOF
}
r.prevRune = -1
n = copy(b, r.s[r.i:])
r.i += int64(n)
return
}
// ReadAt implements the io.ReaderAt interface.
func (r *Reader) ReadAt(b []byte, off int64) (n int, err error) {
// cannot modify state - see io.ReaderAt
if off < 0 {
return 0, errors.New("bytes.Reader.ReadAt: negative offset")
}
if off >= int64(len(r.s)) {
return 0, io.EOF
}
n = copy(b, r.s[off:])
if n < len(b) {
err = io.EOF
}
return
}
// ReadByte implements the io.ByteReader interface.
func (r *Reader) ReadByte() (byte, error) {
r.prevRune = -1
if r.i >= int64(len(r.s)) {
return 0, io.EOF
}
b := r.s[r.i]
r.i++
return b, nil
}
// UnreadByte complements ReadByte in implementing the io.ByteScanner interface.
func (r *Reader) UnreadByte() error {
if r.i <= 0 {
return errors.New("bytes.Reader.UnreadByte: at beginning of slice")
}
r.prevRune = -1
r.i--
return nil
}
// ReadRune implements the io.RuneReader interface.
func (r *Reader) ReadRune() (ch rune, size int, err error) {
if r.i >= int64(len(r.s)) {
r.prevRune = -1
return 0, 0, io.EOF
}
r.prevRune = int(r.i)
if c := r.s[r.i]; c < utf8.RuneSelf {
r.i++
return rune(c), 1, nil
}
ch, size = utf8.DecodeRune(r.s[r.i:])
r.i += int64(size)
return
}
// UnreadRune complements ReadRune in implementing the io.RuneScanner interface.
func (r *Reader) UnreadRune() error {
if r.i <= 0 {
return errors.New("bytes.Reader.UnreadRune: at beginning of slice")
}
if r.prevRune < 0 {
return errors.New("bytes.Reader.UnreadRune: previous operation was not ReadRune")
}
r.i = int64(r.prevRune)
r.prevRune = -1
return nil
}
// Seek implements the io.Seeker interface.
func (r *Reader) Seek(offset int64, whence int) (int64, error) {
r.prevRune = -1
var abs int64
switch whence {
case io.SeekStart:
abs = offset
case io.SeekCurrent:
abs = r.i + offset
case io.SeekEnd:
abs = int64(len(r.s)) + offset
default:
return 0, errors.New("bytes.Reader.Seek: invalid whence")
}
if abs < 0 {
return 0, errors.New("bytes.Reader.Seek: negative position")
}
r.i = abs
return abs, nil
}
// WriteTo implements the io.WriterTo interface.
func (r *Reader) WriteTo(w io.Writer) (n int64, err error) {
r.prevRune = -1
if r.i >= int64(len(r.s)) {
return 0, nil
}
b := r.s[r.i:]
m, err := w.Write(b)
if m > len(b) {
panic("bytes.Reader.WriteTo: invalid Write count")
}
r.i += int64(m)
n = int64(m)
if m != len(b) && err == nil {
err = io.ErrShortWrite
}
return
}
// Reset resets the Reader to be reading from b.
func (r *Reader) Reset(b []byte) { *r = Reader{b, 0, -1} }
// NewReader returns a new Reader reading from b.
func NewReader(b []byte) *Reader { return &Reader{b, 0, -1} }
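// A minimal usage sketch from an importing package (illustrative values):
//
//	r := bytes.NewReader([]byte("héllo"))
//	ch, size, _ := r.ReadRune() // 'h', 1
//	ch, size, _ = r.ReadRune()  // 'é', 2
//	r.Seek(0, io.SeekStart)     // rewind; r.Len() is 6 again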
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package bzip2
import (
"bufio"
"io"
)
// bitReader wraps an io.Reader and provides the ability to read values,
// bit-by-bit, from it. Its Read* methods don't return the usual error
// value because that would make every call site verbose. Instead, any
// error is kept and can be checked afterwards by calling Err.
type bitReader struct {
r io.ByteReader
n uint64
bits uint
err error
}
// newBitReader returns a new bitReader reading from r. If r is not
// already an io.ByteReader, it will be converted via a bufio.Reader.
func newBitReader(r io.Reader) bitReader {
byter, ok := r.(io.ByteReader)
if !ok {
byter = bufio.NewReader(r)
}
return bitReader{r: byter}
}
// ReadBits64 reads the given number of bits and returns them in the
// least-significant part of a uint64. In the event of an error, it returns 0
// and the error can be obtained by calling Err().
func (br *bitReader) ReadBits64(bits uint) (n uint64) {
for bits > br.bits {
b, err := br.r.ReadByte()
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
if err != nil {
br.err = err
return 0
}
br.n <<= 8
br.n |= uint64(b)
br.bits += 8
}
// br.n looks like this (assuming that br.bits = 14 and bits = 6):
// Bit: 111111
//      5432109876543210
//
//         (6 bits, the desired output)
//        |-----|
//        V     V
//      0101101101001110
//        ^            ^
//        |------------|
//           br.bits (num valid bits)
//
// The next line right-shifts the desired bits into the
// least-significant places and masks off anything above.
n = (br.n >> (br.bits - bits)) & ((1 << bits) - 1)
br.bits -= bits
return
}
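// For illustration (a sketch; assumes the bytes package is imported):
// bits are consumed most-significant first, so with a single source byte
// 0xAB two successive 4-bit reads yield 0xA and then 0xB:
//
//	br := newBitReader(bytes.NewReader([]byte{0xAB}))
//	hi := br.ReadBits64(4) // 0xA
//	lo := br.ReadBits64(4) // 0xB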
func (br *bitReader) ReadBits(bits uint) (n int) {
n64 := br.ReadBits64(bits)
return int(n64)
}
func (br *bitReader) ReadBit() bool {
n := br.ReadBits(1)
return n != 0
}
func (br *bitReader) Err() error {
return br.err
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package bzip2 implements bzip2 decompression.
package bzip2
import "io"
// There's no RFC for bzip2. I used the Wikipedia page for reference and a lot
// of guessing: https://en.wikipedia.org/wiki/Bzip2
// The source code to pyflate was useful for debugging:
// http://www.paul.sladen.org/projects/pyflate
// A StructuralError is returned when the bzip2 data is found to be
// syntactically invalid.
type StructuralError string
func (s StructuralError) Error() string {
return "bzip2 data invalid: " + string(s)
}
// A reader decompresses bzip2 compressed data.
type reader struct {
br bitReader
fileCRC uint32
blockCRC uint32
wantBlockCRC uint32
setupDone bool // true if we have parsed the bzip2 header.
blockSize int // blockSize in bytes, i.e. 900 * 1000.
eof bool
c [256]uint // the ``C'' array for the inverse BWT.
tt []uint32 // mirrors the ``tt'' array in the bzip2 source and contains the P array in the upper 24 bits.
tPos uint32 // Index of the next output byte in tt.
preRLE []uint32 // contains the RLE data still to be processed.
preRLEUsed int // number of entries of preRLE used.
lastByte int // the last byte value seen.
byteRepeats uint // the number of repeats of lastByte seen.
repeats uint // the number of copies of lastByte to output.
}
// NewReader returns an io.Reader which decompresses bzip2 data from r.
// If r does not also implement io.ByteReader,
// the decompressor may read more data than necessary from r.
func NewReader(r io.Reader) io.Reader {
bz2 := new(reader)
bz2.br = newBitReader(r)
return bz2
}
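// A minimal usage sketch (file name assumed; error handling elided):
//
//	f, _ := os.Open("data.bz2")
//	defer f.Close()
//	io.Copy(os.Stdout, bzip2.NewReader(f))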
const bzip2FileMagic = 0x425a // "BZ"
const bzip2BlockMagic = 0x314159265359
const bzip2FinalMagic = 0x177245385090
// setup parses the bzip2 header.
func (bz2 *reader) setup(needMagic bool) error {
br := &bz2.br
if needMagic {
magic := br.ReadBits(16)
if magic != bzip2FileMagic {
return StructuralError("bad magic value")
}
}
t := br.ReadBits(8)
if t != 'h' {
return StructuralError("non-Huffman entropy encoding")
}
level := br.ReadBits(8)
if level < '1' || level > '9' {
return StructuralError("invalid compression level")
}
bz2.fileCRC = 0
bz2.blockSize = 100 * 1000 * (level - '0')
if bz2.blockSize > len(bz2.tt) {
bz2.tt = make([]uint32, bz2.blockSize)
}
return nil
}
func (bz2 *reader) Read(buf []byte) (n int, err error) {
if bz2.eof {
return 0, io.EOF
}
if !bz2.setupDone {
err = bz2.setup(true)
brErr := bz2.br.Err()
if brErr != nil {
err = brErr
}
if err != nil {
return 0, err
}
bz2.setupDone = true
}
n, err = bz2.read(buf)
brErr := bz2.br.Err()
if brErr != nil {
err = brErr
}
return
}
func (bz2 *reader) readFromBlock(buf []byte) int {
// bzip2 is a block based compressor, except that it has a run-length
// preprocessing step. The block based nature means that we can
// preallocate fixed-size buffers and reuse them. However, the RLE
// preprocessing would require allocating huge buffers to store the
// maximum expansion. Thus we process blocks all at once, except for
// the RLE which we decompress as required.
n := 0
for (bz2.repeats > 0 || bz2.preRLEUsed < len(bz2.preRLE)) && n < len(buf) {
// We have RLE data pending.
// The run-length encoding works like this:
// Any sequence of four equal bytes is followed by a length
// byte which contains the number of repeats of that byte to
// include. (The number of repeats can be zero.) Because we are
// decompressing on-demand our state is kept in the reader
// object.
if bz2.repeats > 0 {
buf[n] = byte(bz2.lastByte)
n++
bz2.repeats--
if bz2.repeats == 0 {
bz2.lastByte = -1
}
continue
}
bz2.tPos = bz2.preRLE[bz2.tPos]
b := byte(bz2.tPos)
bz2.tPos >>= 8
bz2.preRLEUsed++
if bz2.byteRepeats == 3 {
bz2.repeats = uint(b)
bz2.byteRepeats = 0
continue
}
if bz2.lastByte == int(b) {
bz2.byteRepeats++
} else {
bz2.byteRepeats = 0
}
bz2.lastByte = int(b)
buf[n] = b
n++
}
return n
}
func (bz2 *reader) read(buf []byte) (int, error) {
for {
n := bz2.readFromBlock(buf)
if n > 0 || len(buf) == 0 {
bz2.blockCRC = updateCRC(bz2.blockCRC, buf[:n])
return n, nil
}
// End of block. Check CRC.
if bz2.blockCRC != bz2.wantBlockCRC {
bz2.br.err = StructuralError("block checksum mismatch")
return 0, bz2.br.err
}
// Find next block.
br := &bz2.br
switch br.ReadBits64(48) {
default:
return 0, StructuralError("bad magic value found")
case bzip2BlockMagic:
// Start of block.
err := bz2.readBlock()
if err != nil {
return 0, err
}
case bzip2FinalMagic:
// Check end-of-file CRC.
wantFileCRC := uint32(br.ReadBits64(32))
if br.err != nil {
return 0, br.err
}
if bz2.fileCRC != wantFileCRC {
br.err = StructuralError("file checksum mismatch")
return 0, br.err
}
// Skip ahead to byte boundary.
// Is there a file concatenated to this one?
// It would start with BZ.
if br.bits%8 != 0 {
br.ReadBits(br.bits % 8)
}
b, err := br.r.ReadByte()
if err == io.EOF {
br.err = io.EOF
bz2.eof = true
return 0, io.EOF
}
if err != nil {
br.err = err
return 0, err
}
z, err := br.r.ReadByte()
if err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
br.err = err
return 0, err
}
if b != 'B' || z != 'Z' {
return 0, StructuralError("bad magic value in continuation file")
}
if err := bz2.setup(false); err != nil {
return 0, err
}
}
}
}
// readBlock reads a bzip2 block. The magic number should already have been consumed.
func (bz2 *reader) readBlock() (err error) {
br := &bz2.br
bz2.wantBlockCRC = uint32(br.ReadBits64(32)) // the expected block CRC, checked against blockCRC when the block ends.
bz2.blockCRC = 0
bz2.fileCRC = (bz2.fileCRC<<1 | bz2.fileCRC>>31) ^ bz2.wantBlockCRC
randomized := br.ReadBits(1)
if randomized != 0 {
return StructuralError("deprecated randomized files")
}
origPtr := uint(br.ReadBits(24))
// If not every byte value is used in the block (i.e., it's text) then
// the symbol set is reduced. The symbols used are stored as a
// two-level, 16x16 bitmap.
symbolRangeUsedBitmap := br.ReadBits(16)
symbolPresent := make([]bool, 256)
numSymbols := 0
for symRange := uint(0); symRange < 16; symRange++ {
if symbolRangeUsedBitmap&(1<<(15-symRange)) != 0 {
bits := br.ReadBits(16)
for symbol := uint(0); symbol < 16; symbol++ {
if bits&(1<<(15-symbol)) != 0 {
symbolPresent[16*symRange+symbol] = true
numSymbols++
}
}
}
}
if numSymbols == 0 {
// There must be an EOF symbol.
return StructuralError("no symbols in input")
}
// A block uses between two and six different Huffman trees.
numHuffmanTrees := br.ReadBits(3)
if numHuffmanTrees < 2 || numHuffmanTrees > 6 {
return StructuralError("invalid number of Huffman trees")
}
// The Huffman tree can switch every 50 symbols so there's a list of
// tree indexes telling us which tree to use for each 50 symbol block.
numSelectors := br.ReadBits(15)
treeIndexes := make([]uint8, numSelectors)
// The tree indexes are move-to-front transformed and stored as unary
// numbers.
mtfTreeDecoder := newMTFDecoderWithRange(numHuffmanTrees)
for i := range treeIndexes {
c := 0
for {
inc := br.ReadBits(1)
if inc == 0 {
break
}
c++
}
if c >= numHuffmanTrees {
return StructuralError("tree index too large")
}
treeIndexes[i] = mtfTreeDecoder.Decode(c)
}
// The list of symbols for the move-to-front transform is taken from
// the previously decoded symbol bitmap.
symbols := make([]byte, numSymbols)
nextSymbol := 0
for i := 0; i < 256; i++ {
if symbolPresent[i] {
symbols[nextSymbol] = byte(i)
nextSymbol++
}
}
mtf := newMTFDecoder(symbols)
numSymbols += 2 // to account for RUNA and RUNB symbols
huffmanTrees := make([]huffmanTree, numHuffmanTrees)
// Now we decode the arrays of code-lengths for each tree.
lengths := make([]uint8, numSymbols)
for i := range huffmanTrees {
// The code lengths are delta encoded from a 5-bit base value.
length := br.ReadBits(5)
for j := range lengths {
for {
if length < 1 || length > 20 {
return StructuralError("Huffman length out of range")
}
if !br.ReadBit() {
break
}
if br.ReadBit() {
length--
} else {
length++
}
}
lengths[j] = uint8(length)
}
huffmanTrees[i], err = newHuffmanTree(lengths)
if err != nil {
return err
}
}
selectorIndex := 1 // the next tree index to use
if len(treeIndexes) == 0 {
return StructuralError("no tree selectors given")
}
if int(treeIndexes[0]) >= len(huffmanTrees) {
return StructuralError("tree selector out of range")
}
currentHuffmanTree := huffmanTrees[treeIndexes[0]]
bufIndex := 0 // indexes bz2.tt, the output buffer.
// The output of the move-to-front transform is run-length encoded and
// we merge the decoding into the Huffman parsing loop. These two
// variables accumulate the repeat count. See the Wikipedia page for
// details.
repeat := 0
repeatPower := 0
// The `C' array (used by the inverse BWT) needs to be zero initialized.
for i := range bz2.c {
bz2.c[i] = 0
}
decoded := 0 // counts the number of symbols decoded by the current tree.
for {
if decoded == 50 {
if selectorIndex >= numSelectors {
return StructuralError("insufficient selector indices for number of symbols")
}
if int(treeIndexes[selectorIndex]) >= len(huffmanTrees) {
return StructuralError("tree selector out of range")
}
currentHuffmanTree = huffmanTrees[treeIndexes[selectorIndex]]
selectorIndex++
decoded = 0
}
v := currentHuffmanTree.Decode(br)
decoded++
if v < 2 {
// This is either the RUNA or RUNB symbol.
if repeat == 0 {
repeatPower = 1
}
repeat += repeatPower << v
repeatPower <<= 1
// This limit of 2 million comes from the bzip2 source
// code. It prevents repeat from overflowing.
if repeat > 2*1024*1024 {
return StructuralError("repeat count too large")
}
continue
}
if repeat > 0 {
// We have decoded a complete run-length so we need to
// replicate the last output symbol.
if repeat > bz2.blockSize-bufIndex {
return StructuralError("repeats past end of block")
}
for i := 0; i < repeat; i++ {
b := mtf.First()
bz2.tt[bufIndex] = uint32(b)
bz2.c[b]++
bufIndex++
}
repeat = 0
}
if int(v) == numSymbols-1 {
// This is the EOF symbol. Because it's always at the
// end of the move-to-front list, and never gets moved
// to the front, it has this unique value.
break
}
// Since two metasymbols (RUNA and RUNB) have values 0 and 1,
// one would expect |v-2| to be passed to the MTF decoder.
// However, the front of the MTF list is never referenced as 0,
// it's always referenced with a run-length of 1. Thus 0
// doesn't need to be encoded and we have |v-1| in the next
// line.
b := mtf.Decode(int(v - 1))
if bufIndex >= bz2.blockSize {
return StructuralError("data exceeds block size")
}
bz2.tt[bufIndex] = uint32(b)
bz2.c[b]++
bufIndex++
}
if origPtr >= uint(bufIndex) {
return StructuralError("origPtr out of bounds")
}
// We have completed the entropy decoding. Now we can perform the
// inverse BWT and setup the RLE buffer.
bz2.preRLE = bz2.tt[:bufIndex]
bz2.preRLEUsed = 0
bz2.tPos = inverseBWT(bz2.preRLE, origPtr, bz2.c[:])
bz2.lastByte = -1
bz2.byteRepeats = 0
bz2.repeats = 0
return nil
}
// inverseBWT implements the inverse Burrows-Wheeler transform as described in
// http://www.hpl.hp.com/techreports/Compaq-DEC/SRC-RR-124.pdf, section 4.2.
// In that document, origPtr is called “I” and c is the “C” array after the
// first pass over the data. It's an argument here because we merge the first
// pass with the Huffman decoding.
//
// This also implements the “single array” method from the bzip2 source code
// which leaves the output, still shuffled, in the bottom 8 bits of tt with the
// index of the next byte in the top 24-bits. The index of the first byte is
// returned.
func inverseBWT(tt []uint32, origPtr uint, c []uint) uint32 {
sum := uint(0)
for i := 0; i < 256; i++ {
sum += c[i]
c[i] = sum - c[i]
}
for i := range tt {
b := tt[i] & 0xff
tt[c[b]] |= uint32(i) << 8
c[b]++
}
return tt[origPtr] >> 8
}
// This is a standard CRC32 like in hash/crc32 except that all the shifts are reversed,
// causing the bits in the input to be processed in the reverse of the usual order.
var crctab [256]uint32
func init() {
const poly = 0x04C11DB7
for i := range crctab {
crc := uint32(i) << 24
for j := 0; j < 8; j++ {
if crc&0x80000000 != 0 {
crc = (crc << 1) ^ poly
} else {
crc <<= 1
}
}
crctab[i] = crc
}
}
// updateCRC updates the crc value to incorporate the data in b.
// The initial value is 0.
func updateCRC(val uint32, b []byte) uint32 {
crc := ^val
for _, v := range b {
crc = crctab[byte(crc>>24)^v] ^ (crc << 8)
}
return ^crc
}
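// Because of the pre- and post-inversion, updateCRC can be applied
// incrementally, which is how (*reader).read accumulates the block CRC:
//
//	crc := uint32(0)
//	crc = updateCRC(crc, chunk1)
//	crc = updateCRC(crc, chunk2) // same result as one call over chunk1 then chunk2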
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package bzip2
import "sort"
// A huffmanTree is a binary tree which is navigated, bit-by-bit to reach a
// symbol.
type huffmanTree struct {
// nodes contains all the non-leaf nodes in the tree. nodes[0] is the
// root of the tree and nextNode contains the index of the next element
// of nodes to use when the tree is being constructed.
nodes []huffmanNode
nextNode int
}
// A huffmanNode is a node in the tree. left and right contain indexes into the
// nodes slice of the tree. If left or right is invalidNodeValue then the child
// is a leaf node and its value is in leftValue/rightValue.
//
// The symbols are uint16s because bzip2 encodes not only MTF indexes in the
// tree, but also two magic values for run-length encoding and an EOF symbol.
// Thus there are more than 256 possible symbols.
type huffmanNode struct {
left, right uint16
leftValue, rightValue uint16
}
// invalidNodeValue is an invalid index which marks a leaf node in the tree.
const invalidNodeValue = 0xffff
// Decode reads bits from the given bitReader and navigates the tree until a
// symbol is found.
func (t *huffmanTree) Decode(br *bitReader) (v uint16) {
nodeIndex := uint16(0) // node 0 is the root of the tree.
for {
node := &t.nodes[nodeIndex]
var bit uint16
if br.bits > 0 {
// Get next bit - fast path.
br.bits--
bit = uint16(br.n>>(br.bits&63)) & 1
} else {
// Get next bit - slow path.
// Use ReadBits to retrieve a single bit
// from the underlying io.ByteReader.
bit = uint16(br.ReadBits(1))
}
// Trick a compiler into generating conditional move instead of branch,
// by making both loads unconditional.
l, r := node.left, node.right
if bit == 1 {
nodeIndex = l
} else {
nodeIndex = r
}
if nodeIndex == invalidNodeValue {
// We found a leaf. Use the value of bit to decide
// whether it is a left or a right value.
l, r := node.leftValue, node.rightValue
if bit == 1 {
v = l
} else {
v = r
}
return
}
}
}
// newHuffmanTree builds a Huffman tree from a slice containing the code
// lengths of each symbol. The maximum code length is 32 bits.
func newHuffmanTree(lengths []uint8) (huffmanTree, error) {
// There are many possible trees that assign the same code length to
// each symbol (consider reflecting a tree down the middle, for
// example). Since the code length assignments determine the
// efficiency of the tree, each of these trees is equally good. In
// order to minimize the amount of information needed to build a tree
// bzip2 uses a canonical tree so that it can be reconstructed given
// only the code length assignments.
if len(lengths) < 2 {
panic("newHuffmanTree: too few symbols")
}
var t huffmanTree
// First we sort the code length assignments by ascending code length,
// using the symbol value to break ties.
pairs := make([]huffmanSymbolLengthPair, len(lengths))
for i, length := range lengths {
pairs[i].value = uint16(i)
pairs[i].length = length
}
sort.Slice(pairs, func(i, j int) bool {
if pairs[i].length < pairs[j].length {
return true
}
if pairs[i].length > pairs[j].length {
return false
}
if pairs[i].value < pairs[j].value {
return true
}
return false
})
// Now we assign codes to the symbols, starting with the longest code.
// We keep the codes packed into a uint32, at the most-significant end.
// So branches are taken from the MSB downwards. This makes it easy to
// sort them later.
code := uint32(0)
length := uint8(32)
codes := make([]huffmanCode, len(lengths))
for i := len(pairs) - 1; i >= 0; i-- {
if length > pairs[i].length {
length = pairs[i].length
}
codes[i].code = code
codes[i].codeLen = length
codes[i].value = pairs[i].value
// We need to 'increment' the code, which means treating |code|
// like a |length| bit number.
code += 1 << (32 - length)
}
// Now we can sort by the code so that the left half of each branch are
// grouped together, recursively.
sort.Slice(codes, func(i, j int) bool {
return codes[i].code < codes[j].code
})
t.nodes = make([]huffmanNode, len(codes))
_, err := buildHuffmanNode(&t, codes, 0)
return t, err
}
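// For illustration (values worked out by hand from the loop above):
// newHuffmanTree([]uint8{2, 1, 3, 3}) assigns, reading each code from the
// most-significant bit downwards:
//
//	symbol 1 -> 1
//	symbol 0 -> 01
//	symbol 2 -> 001
//	symbol 3 -> 000
//
// Shorter codes receive the numerically larger top bits, and the result
// is a valid prefix code determined entirely by the lengths.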
// huffmanSymbolLengthPair contains a symbol and its code length.
type huffmanSymbolLengthPair struct {
value uint16
length uint8
}
// huffmanCode contains a symbol, its code and code length.
type huffmanCode struct {
code uint32
codeLen uint8
value uint16
}
// buildHuffmanNode takes a slice of sorted huffmanCodes and builds a node in
// the Huffman tree at the given level. It returns the index of the newly
// constructed node.
func buildHuffmanNode(t *huffmanTree, codes []huffmanCode, level uint32) (nodeIndex uint16, err error) {
test := uint32(1) << (31 - level)
// We have to search the list of codes to find the divide between the left and right sides.
firstRightIndex := len(codes)
for i, code := range codes {
if code.code&test != 0 {
firstRightIndex = i
break
}
}
left := codes[:firstRightIndex]
right := codes[firstRightIndex:]
if len(left) == 0 || len(right) == 0 {
// There is a superfluous level in the Huffman tree indicating
// a bug in the encoder. However, this bug has been observed in
// the wild so we handle it.
// If this function was called recursively then we know that
// len(codes) >= 2 because, otherwise, we would have hit the
// "leaf node" case, below, and not recurred.
//
// However, for the initial call it's possible that len(codes)
// is zero or one. Both cases are invalid because a zero length
// tree cannot encode anything and a length-1 tree can only
// encode EOF and so is superfluous. We reject both.
if len(codes) < 2 {
return 0, StructuralError("empty Huffman tree")
}
// In this case the recursion doesn't always reduce the length
// of codes so we need to ensure termination via another
// mechanism.
if level == 31 {
// Since len(codes) >= 2 the only way that the values
// can match at all 32 bits is if they are equal, which
// is invalid. This ensures that we never enter
// infinite recursion.
return 0, StructuralError("equal symbols in Huffman tree")
}
if len(left) == 0 {
return buildHuffmanNode(t, right, level+1)
}
return buildHuffmanNode(t, left, level+1)
}
nodeIndex = uint16(t.nextNode)
node := &t.nodes[t.nextNode]
t.nextNode++
if len(left) == 1 {
// leaf node
node.left = invalidNodeValue
node.leftValue = left[0].value
} else {
node.left, err = buildHuffmanNode(t, left, level+1)
}
if err != nil {
return
}
if len(right) == 1 {
// leaf node
node.right = invalidNodeValue
node.rightValue = right[0].value
} else {
node.right, err = buildHuffmanNode(t, right, level+1)
}
return
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package bzip2
// moveToFrontDecoder implements a move-to-front list. Such a list is an
// efficient way to transform a string with repeating elements into one with
// many small valued numbers, which is suitable for entropy encoding. It works
// by starting with an initial list of symbols and references symbols by their
// index into that list. When a symbol is referenced, it's moved to the front
// of the list. Thus, a repeated symbol ends up being encoded with many zeros,
// as the symbol will be at the front of the list after the first access.
type moveToFrontDecoder []byte
// newMTFDecoder creates a move-to-front decoder with an explicit initial list
// of symbols.
func newMTFDecoder(symbols []byte) moveToFrontDecoder {
if len(symbols) > 256 {
panic("too many symbols")
}
return moveToFrontDecoder(symbols)
}
// newMTFDecoderWithRange creates a move-to-front decoder with an initial
// symbol list of 0...n-1.
func newMTFDecoderWithRange(n int) moveToFrontDecoder {
if n > 256 {
panic("newMTFDecoderWithRange: cannot have > 256 symbols")
}
m := make([]byte, n)
for i := 0; i < n; i++ {
m[i] = byte(i)
}
return moveToFrontDecoder(m)
}
func (m moveToFrontDecoder) Decode(n int) (b byte) {
// Implement move-to-front with a simple copy. This approach
// beats more sophisticated approaches in benchmarking, probably
// because it has high locality of reference inside of a
// single cache line (most move-to-front operations have n < 64).
b = m[n]
copy(m[1:], m[:n])
m[0] = b
return
}
// First returns the symbol at the front of the list.
func (m moveToFrontDecoder) First() byte {
return m[0]
}
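// For illustration (assumed symbols): with the initial list ['a','b','c'],
// Decode(2) returns 'c' and moves it to the front, so a repeated index
// refers to a different symbol each time:
//
//	m := newMTFDecoder([]byte{'a', 'b', 'c'})
//	_ = m.Decode(2) // 'c'; list is now ['c','a','b']
//	_ = m.Decode(2) // 'b'; list is now ['b','c','a']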
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package flate
import (
"errors"
"fmt"
"io"
"math"
)
const (
NoCompression = 0
BestSpeed = 1
BestCompression = 9
DefaultCompression = -1
// HuffmanOnly disables Lempel-Ziv match searching and only performs Huffman
// entropy encoding. This mode is useful in compressing data that has
// already been compressed with an LZ style algorithm (e.g. Snappy or LZ4)
// that lacks an entropy encoder. Compression gains are achieved when
// certain bytes in the input stream occur more frequently than others.
//
// Note that HuffmanOnly produces a compressed output that is
// RFC 1951 compliant. That is, any valid DEFLATE decompressor will
// continue to be able to decompress this output.
HuffmanOnly = -2
)
const (
logWindowSize = 15
windowSize = 1 << logWindowSize
windowMask = windowSize - 1
// The LZ77 step produces a sequence of literal tokens and <length, offset>
// pair tokens. The offset is also known as distance. The underlying wire
// format limits the range of lengths and offsets. For example, there are
// 256 legitimate lengths: those in the range [3, 258]. This package's
// compressor uses a higher minimum match length, enabling optimizations
// such as finding matches via 32-bit loads and compares.
baseMatchLength = 3 // The smallest match length per the RFC section 3.2.5
minMatchLength = 4 // The smallest match length that the compressor actually emits
maxMatchLength = 258 // The largest match length
baseMatchOffset = 1 // The smallest match offset
maxMatchOffset = 1 << 15 // The largest match offset
// The maximum number of tokens we put into a single flate block, just to
// stop things from getting too large.
maxFlateBlockTokens = 1 << 14
maxStoreBlockSize = 65535
hashBits = 17 // After 17 performance degrades
hashSize = 1 << hashBits
hashMask = (1 << hashBits) - 1
maxHashOffset = 1 << 24
skipNever = math.MaxInt32
)
// compressionLevel holds the tuning parameters for a compression level:
// good: once a match at least this long is found, search only 1/4 of the chain;
// lazy: attempt lazy matching only while the previous match is shorter than this;
// nice: stop searching as soon as a match at least this long is found;
// chain: the maximum number of hash-chain entries to examine per match attempt;
// fastSkipHashing: matches longer than this skip hash insertion of their
// intermediate positions (skipNever selects the lazy-matching code path instead).
type compressionLevel struct {
level, good, lazy, nice, chain, fastSkipHashing int
}
var levels = []compressionLevel{
{0, 0, 0, 0, 0, 0}, // NoCompression.
{1, 0, 0, 0, 0, 0}, // BestSpeed uses a custom algorithm; see deflatefast.go.
// For levels 2-3 we don't bother trying with lazy matches.
{2, 4, 0, 16, 8, 5},
{3, 4, 0, 32, 32, 6},
// Levels 4-9 use increasingly more lazy matching
// and increasingly stringent conditions for "good enough".
{4, 4, 4, 16, 16, skipNever},
{5, 8, 16, 32, 32, skipNever},
{6, 8, 16, 128, 128, skipNever},
{7, 8, 32, 128, 256, skipNever},
{8, 32, 128, 258, 1024, skipNever},
{9, 32, 258, 258, 4096, skipNever},
}
type compressor struct {
compressionLevel
w *huffmanBitWriter
bulkHasher func([]byte, []uint32)
// compression algorithm
fill func(*compressor, []byte) int // copy data to window
step func(*compressor) // process window
sync bool // requesting flush
bestSpeed *deflateFast // Encoder for BestSpeed
// Input hash chains
// hashHead[hashValue] contains the largest inputIndex with the specified hash value
// If hashHead[hashValue] is within the current window, then
// hashPrev[hashHead[hashValue] & windowMask] contains the previous index
// with the same hash value.
chainHead int
hashHead [hashSize]uint32
hashPrev [windowSize]uint32
hashOffset int
// input window: unprocessed data is window[index:windowEnd]
index int
window []byte
windowEnd int
blockStart int // window index where current tokens start
byteAvailable bool // if true, still need to process window[index-1].
// queued output tokens
tokens []token
// deflate state
length int
offset int
maxInsertIndex int
err error
// hashMatch must be able to contain hashes for the maximum match length.
hashMatch [maxMatchLength - 1]uint32
}
func (d *compressor) fillDeflate(b []byte) int {
if d.index >= 2*windowSize-(minMatchLength+maxMatchLength) {
// shift the window by windowSize
copy(d.window, d.window[windowSize:2*windowSize])
d.index -= windowSize
d.windowEnd -= windowSize
if d.blockStart >= windowSize {
d.blockStart -= windowSize
} else {
d.blockStart = math.MaxInt32
}
d.hashOffset += windowSize
if d.hashOffset > maxHashOffset {
delta := d.hashOffset - 1
d.hashOffset -= delta
d.chainHead -= delta
// Iterate over slices instead of arrays to avoid copying
// the entire table onto the stack (Issue #18625).
for i, v := range d.hashPrev[:] {
if int(v) > delta {
d.hashPrev[i] = uint32(int(v) - delta)
} else {
d.hashPrev[i] = 0
}
}
for i, v := range d.hashHead[:] {
if int(v) > delta {
d.hashHead[i] = uint32(int(v) - delta)
} else {
d.hashHead[i] = 0
}
}
}
}
n := copy(d.window[d.windowEnd:], b)
d.windowEnd += n
return n
}
func (d *compressor) writeBlock(tokens []token, index int) error {
if index > 0 {
var window []byte
if d.blockStart <= index {
window = d.window[d.blockStart:index]
}
d.blockStart = index
d.w.writeBlock(tokens, false, window)
return d.w.err
}
return nil
}
// fillWindow will fill the current window with the supplied
// dictionary and calculate all hashes.
// This is much faster than doing a full encode.
// Should only be used after a reset.
func (d *compressor) fillWindow(b []byte) {
// Do not fill window if we are in store-only mode.
if d.compressionLevel.level < 2 {
return
}
if d.index != 0 || d.windowEnd != 0 {
panic("internal error: fillWindow called with stale data")
}
// If we are given too much, cut it.
if len(b) > windowSize {
b = b[len(b)-windowSize:]
}
// Add all to window.
n := copy(d.window, b)
// Calculate 256 hashes at the time (more L1 cache hits)
loops := (n + 256 - minMatchLength) / 256
for j := 0; j < loops; j++ {
index := j * 256
end := index + 256 + minMatchLength - 1
if end > n {
end = n
}
toCheck := d.window[index:end]
dstSize := len(toCheck) - minMatchLength + 1
if dstSize <= 0 {
continue
}
dst := d.hashMatch[:dstSize]
d.bulkHasher(toCheck, dst)
for i, val := range dst {
di := i + index
hh := &d.hashHead[val&hashMask]
// Get previous value with the same hash.
// Our chain should point to the previous value.
d.hashPrev[di&windowMask] = *hh
// Set the head of the hash chain to us.
*hh = uint32(di + d.hashOffset)
}
}
// Update window information.
d.windowEnd = n
d.index = n
}
// Try to find a match starting at index whose length is greater than prevSize.
// We only look at chainCount possibilities before giving up.
func (d *compressor) findMatch(pos int, prevHead int, prevLength int, lookahead int) (length, offset int, ok bool) {
minMatchLook := maxMatchLength
if lookahead < minMatchLook {
minMatchLook = lookahead
}
win := d.window[0 : pos+minMatchLook]
// We quit when we get a match that's at least nice long
nice := len(win) - pos
if d.nice < nice {
nice = d.nice
}
// If we've got a match that's good enough, only look in 1/4 the chain.
tries := d.chain
length = prevLength
if length >= d.good {
tries >>= 2
}
wEnd := win[pos+length]
wPos := win[pos:]
minIndex := pos - windowSize
for i := prevHead; tries > 0; tries-- {
if wEnd == win[i+length] {
n := matchLen(win[i:], wPos, minMatchLook)
if n > length && (n > minMatchLength || pos-i <= 4096) {
length = n
offset = pos - i
ok = true
if n >= nice {
// The match is good enough that we don't try to find a better one.
break
}
wEnd = win[pos+n]
}
}
if i == minIndex {
// hashPrev[i & windowMask] has already been overwritten, so stop now.
break
}
i = int(d.hashPrev[i&windowMask]) - d.hashOffset
if i < minIndex || i < 0 {
break
}
}
return
}
func (d *compressor) writeStoredBlock(buf []byte) error {
if d.w.writeStoredHeader(len(buf), false); d.w.err != nil {
return d.w.err
}
d.w.writeBytes(buf)
return d.w.err
}
const hashmul = 0x1e35a7bd
// hash4 returns a hash representation of the first 4 bytes
// of the supplied slice.
// The caller must ensure that len(b) >= 4.
func hash4(b []byte) uint32 {
return ((uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24) * hashmul) >> (32 - hashBits)
}
// bulkHash4 will compute hashes using the same
// algorithm as hash4.
func bulkHash4(b []byte, dst []uint32) {
if len(b) < minMatchLength {
return
}
hb := uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24
dst[0] = (hb * hashmul) >> (32 - hashBits)
end := len(b) - minMatchLength + 1
for i := 1; i < end; i++ {
hb = (hb << 8) | uint32(b[i+3])
dst[i] = (hb * hashmul) >> (32 - hashBits)
}
}
// matchLen returns the number of matching bytes in a and b
// up to length 'max'. Both slices must be at least 'max'
// bytes in size.
func matchLen(a, b []byte, max int) int {
a = a[:max]
b = b[:len(a)]
for i, av := range a {
if b[i] != av {
return i
}
}
return max
}
// encSpeed will compress and store the currently added data,
// if enough has been accumulated or we are at the end of the stream.
// Any error that occurred will be in d.err.
func (d *compressor) encSpeed() {
// We only compress if we have maxStoreBlockSize.
if d.windowEnd < maxStoreBlockSize {
if !d.sync {
return
}
// Handle small sizes.
if d.windowEnd < 128 {
switch {
case d.windowEnd == 0:
return
case d.windowEnd <= 16:
d.err = d.writeStoredBlock(d.window[:d.windowEnd])
default:
d.w.writeBlockHuff(false, d.window[:d.windowEnd])
d.err = d.w.err
}
d.windowEnd = 0
d.bestSpeed.reset()
return
}
}
// Encode the block.
d.tokens = d.bestSpeed.encode(d.tokens[:0], d.window[:d.windowEnd])
// If we removed less than 1/16th, Huffman compress the block.
if len(d.tokens) > d.windowEnd-(d.windowEnd>>4) {
d.w.writeBlockHuff(false, d.window[:d.windowEnd])
} else {
d.w.writeBlockDynamic(d.tokens, false, d.window[:d.windowEnd])
}
d.err = d.w.err
d.windowEnd = 0
}
func (d *compressor) initDeflate() {
d.window = make([]byte, 2*windowSize)
d.hashOffset = 1
d.tokens = make([]token, 0, maxFlateBlockTokens+1)
d.length = minMatchLength - 1
d.offset = 0
d.byteAvailable = false
d.index = 0
d.chainHead = -1
d.bulkHasher = bulkHash4
}
func (d *compressor) deflate() {
if d.windowEnd-d.index < minMatchLength+maxMatchLength && !d.sync {
return
}
d.maxInsertIndex = d.windowEnd - (minMatchLength - 1)
Loop:
for {
if d.index > d.windowEnd {
panic("index > windowEnd")
}
lookahead := d.windowEnd - d.index
if lookahead < minMatchLength+maxMatchLength {
if !d.sync {
break Loop
}
if d.index > d.windowEnd {
panic("index > windowEnd")
}
if lookahead == 0 {
// Flush current output block if any.
if d.byteAvailable {
// There is still one pending token that needs to be flushed
d.tokens = append(d.tokens, literalToken(uint32(d.window[d.index-1])))
d.byteAvailable = false
}
if len(d.tokens) > 0 {
if d.err = d.writeBlock(d.tokens, d.index); d.err != nil {
return
}
d.tokens = d.tokens[:0]
}
break Loop
}
}
if d.index < d.maxInsertIndex {
// Update the hash
hash := hash4(d.window[d.index : d.index+minMatchLength])
hh := &d.hashHead[hash&hashMask]
d.chainHead = int(*hh)
d.hashPrev[d.index&windowMask] = uint32(d.chainHead)
*hh = uint32(d.index + d.hashOffset)
}
prevLength := d.length
prevOffset := d.offset
d.length = minMatchLength - 1
d.offset = 0
minIndex := d.index - windowSize
if minIndex < 0 {
minIndex = 0
}
if d.chainHead-d.hashOffset >= minIndex &&
(d.fastSkipHashing != skipNever && lookahead > minMatchLength-1 ||
d.fastSkipHashing == skipNever && lookahead > prevLength && prevLength < d.lazy) {
if newLength, newOffset, ok := d.findMatch(d.index, d.chainHead-d.hashOffset, minMatchLength-1, lookahead); ok {
d.length = newLength
d.offset = newOffset
}
}
if d.fastSkipHashing != skipNever && d.length >= minMatchLength ||
d.fastSkipHashing == skipNever && prevLength >= minMatchLength && d.length <= prevLength {
// There was a match at the previous step, and the current match is
// not better. Output the previous match.
if d.fastSkipHashing != skipNever {
d.tokens = append(d.tokens, matchToken(uint32(d.length-baseMatchLength), uint32(d.offset-baseMatchOffset)))
} else {
d.tokens = append(d.tokens, matchToken(uint32(prevLength-baseMatchLength), uint32(prevOffset-baseMatchOffset)))
}
// Insert in the hash table all strings up to the end of the match.
// index and index-1 are already inserted. If there is not enough
// lookahead, the last two strings are not inserted into the hash
// table.
if d.length <= d.fastSkipHashing {
var newIndex int
if d.fastSkipHashing != skipNever {
newIndex = d.index + d.length
} else {
newIndex = d.index + prevLength - 1
}
index := d.index
for index++; index < newIndex; index++ {
if index < d.maxInsertIndex {
hash := hash4(d.window[index : index+minMatchLength])
// Get previous value with the same hash.
// Our chain should point to the previous value.
hh := &d.hashHead[hash&hashMask]
d.hashPrev[index&windowMask] = *hh
// Set the head of the hash chain to us.
*hh = uint32(index + d.hashOffset)
}
}
d.index = index
if d.fastSkipHashing == skipNever {
d.byteAvailable = false
d.length = minMatchLength - 1
}
} else {
// For matches this long, we don't bother inserting each individual
// item into the table.
d.index += d.length
}
if len(d.tokens) == maxFlateBlockTokens {
// The block includes the current character
if d.err = d.writeBlock(d.tokens, d.index); d.err != nil {
return
}
d.tokens = d.tokens[:0]
}
} else {
if d.fastSkipHashing != skipNever || d.byteAvailable {
i := d.index - 1
if d.fastSkipHashing != skipNever {
i = d.index
}
d.tokens = append(d.tokens, literalToken(uint32(d.window[i])))
if len(d.tokens) == maxFlateBlockTokens {
if d.err = d.writeBlock(d.tokens, i+1); d.err != nil {
return
}
d.tokens = d.tokens[:0]
}
}
d.index++
if d.fastSkipHashing == skipNever {
d.byteAvailable = true
}
}
}
}
func (d *compressor) fillStore(b []byte) int {
n := copy(d.window[d.windowEnd:], b)
d.windowEnd += n
return n
}
func (d *compressor) store() {
if d.windowEnd > 0 && (d.windowEnd == maxStoreBlockSize || d.sync) {
d.err = d.writeStoredBlock(d.window[:d.windowEnd])
d.windowEnd = 0
}
}
// storeHuff compresses and stores the currently added data
// when d.window is full or we are at the end of the stream.
// Any error that occurred will be in d.err.
func (d *compressor) storeHuff() {
if d.windowEnd < len(d.window) && !d.sync || d.windowEnd == 0 {
return
}
d.w.writeBlockHuff(false, d.window[:d.windowEnd])
d.err = d.w.err
d.windowEnd = 0
}
func (d *compressor) write(b []byte) (n int, err error) {
if d.err != nil {
return 0, d.err
}
n = len(b)
for len(b) > 0 {
d.step(d)
b = b[d.fill(d, b):]
if d.err != nil {
return 0, d.err
}
}
return n, nil
}
func (d *compressor) syncFlush() error {
if d.err != nil {
return d.err
}
d.sync = true
d.step(d)
if d.err == nil {
d.w.writeStoredHeader(0, false)
d.w.flush()
d.err = d.w.err
}
d.sync = false
return d.err
}
func (d *compressor) init(w io.Writer, level int) (err error) {
d.w = newHuffmanBitWriter(w)
switch {
case level == NoCompression:
d.window = make([]byte, maxStoreBlockSize)
d.fill = (*compressor).fillStore
d.step = (*compressor).store
case level == HuffmanOnly:
d.window = make([]byte, maxStoreBlockSize)
d.fill = (*compressor).fillStore
d.step = (*compressor).storeHuff
case level == BestSpeed:
d.compressionLevel = levels[level]
d.window = make([]byte, maxStoreBlockSize)
d.fill = (*compressor).fillStore
d.step = (*compressor).encSpeed
d.bestSpeed = newDeflateFast()
d.tokens = make([]token, maxStoreBlockSize)
case level == DefaultCompression:
level = 6
fallthrough
case 2 <= level && level <= 9:
d.compressionLevel = levels[level]
d.initDeflate()
d.fill = (*compressor).fillDeflate
d.step = (*compressor).deflate
default:
return fmt.Errorf("flate: invalid compression level %d: want value in range [-2, 9]", level)
}
return nil
}
func (d *compressor) reset(w io.Writer) {
d.w.reset(w)
d.sync = false
d.err = nil
switch d.compressionLevel.level {
case NoCompression:
d.windowEnd = 0
case BestSpeed:
d.windowEnd = 0
d.tokens = d.tokens[:0]
d.bestSpeed.reset()
default:
d.chainHead = -1
for i := range d.hashHead {
d.hashHead[i] = 0
}
for i := range d.hashPrev {
d.hashPrev[i] = 0
}
d.hashOffset = 1
d.index, d.windowEnd = 0, 0
d.blockStart, d.byteAvailable = 0, false
d.tokens = d.tokens[:0]
d.length = minMatchLength - 1
d.offset = 0
d.maxInsertIndex = 0
}
}
func (d *compressor) close() error {
if d.err == errWriterClosed {
return nil
}
if d.err != nil {
return d.err
}
d.sync = true
d.step(d)
if d.err != nil {
return d.err
}
if d.w.writeStoredHeader(0, true); d.w.err != nil {
return d.w.err
}
d.w.flush()
if d.w.err != nil {
return d.w.err
}
d.err = errWriterClosed
return nil
}
// NewWriter returns a new Writer compressing data at the given level.
// Following zlib, levels range from 1 (BestSpeed) to 9 (BestCompression);
// higher levels typically run slower but compress more. Level 0
// (NoCompression) does not attempt any compression; it only adds the
// necessary DEFLATE framing.
// Level -1 (DefaultCompression) uses the default compression level.
// Level -2 (HuffmanOnly) will use Huffman compression only, giving
// a very fast compression for all types of input, but sacrificing considerable
// compression efficiency.
//
// If level is in the range [-2, 9] then the error returned will be nil.
// Otherwise the error returned will be non-nil.
func NewWriter(w io.Writer, level int) (*Writer, error) {
var dw Writer
if err := dw.d.init(w, level); err != nil {
return nil, err
}
return &dw, nil
}
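// A minimal round-trip sketch from an importing package (error handling
// elided):
//
//	var buf bytes.Buffer
//	zw, _ := flate.NewWriter(&buf, flate.BestSpeed)
//	zw.Write([]byte("hello, hello, hello"))
//	zw.Close() // buf now holds the raw DEFLATE stream
//	zr := flate.NewReader(&buf)
//	io.Copy(os.Stdout, zr)
//	zr.Close()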
// NewWriterDict is like NewWriter but initializes the new
// Writer with a preset dictionary. The returned Writer behaves
// as if the dictionary had been written to it without producing
// any compressed output. The compressed data written to w
// can only be decompressed by a Reader initialized with the
// same dictionary.
func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, error) {
dw := &dictWriter{w}
zw, err := NewWriter(dw, level)
if err != nil {
return nil, err
}
zw.d.fillWindow(dict)
zw.dict = append(zw.dict, dict...) // duplicate dictionary for Reset method.
return zw, err
}
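// Sketch of the matching decompression side (error handling elided); the
// reader must be constructed with NewReaderDict and the same dictionary:
//
//	zw, _ := flate.NewWriterDict(&buf, flate.DefaultCompression, dict)
//	zw.Write(data)
//	zw.Close()
//	zr := flate.NewReaderDict(&buf, dict)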
type dictWriter struct {
w io.Writer
}
func (w *dictWriter) Write(b []byte) (n int, err error) {
return w.w.Write(b)
}
var errWriterClosed = errors.New("flate: closed writer")
// A Writer takes data written to it and writes the compressed
// form of that data to an underlying writer (see NewWriter).
type Writer struct {
d compressor
dict []byte
}
// Write writes data to w, which will eventually write the
// compressed form of data to its underlying writer.
func (w *Writer) Write(data []byte) (n int, err error) {
return w.d.write(data)
}
// Flush flushes any pending data to the underlying writer.
// It is useful mainly in compressed network protocols, to ensure that
// a remote reader has enough data to reconstruct a packet.
// Flush does not return until the data has been written.
// Calling Flush when there is no pending data still causes the Writer
// to emit a sync marker of at least 4 bytes.
// If the underlying writer returns an error, Flush returns that error.
//
// In the terminology of the zlib library, Flush is equivalent to Z_SYNC_FLUSH.
func (w *Writer) Flush() error {
// For more about flushing:
// https://www.bolet.org/~pornin/deflate-flush.html
return w.d.syncFlush()
}
// Close flushes and closes the writer.
func (w *Writer) Close() error {
return w.d.close()
}
// Reset discards the writer's state and makes it equivalent to
// the result of NewWriter or NewWriterDict called with dst
// and w's level and dictionary.
func (w *Writer) Reset(dst io.Writer) {
if dw, ok := w.d.w.writer.(*dictWriter); ok {
// w was created with NewWriterDict
dw.w = dst
w.d.reset(dw)
w.d.fillWindow(w.dict)
} else {
// w was created with NewWriter
w.d.reset(dst)
}
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package flate
import "math"
// This encoding algorithm, which prioritizes speed over output size, is
// based on Snappy's LZ77-style encoder: github.com/golang/snappy
const (
tableBits = 14 // Bits used in the table.
tableSize = 1 << tableBits // Size of the table.
tableMask = tableSize - 1 // Mask for table indices. Redundant, but can eliminate bounds checks.
tableShift = 32 - tableBits // Right-shift to get the tableBits most significant bits of a uint32.
// Reset the buffer offset when reaching this.
// Offsets are stored between blocks as int32 values.
// Since the offset we are checking against is at the beginning
// of the buffer, we leave headroom for the current block and one
// more maximum-size input block so the int32 cannot overflow.
bufferReset = math.MaxInt32 - maxStoreBlockSize*2
)
func load32(b []byte, i int32) uint32 {
b = b[i : i+4 : len(b)] // Help the compiler eliminate bounds checks on the next line.
return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
}
func load64(b []byte, i int32) uint64 {
b = b[i : i+8 : len(b)] // Help the compiler eliminate bounds checks on the next line.
return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 |
uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56
}
func hash(u uint32) uint32 {
return (u * 0x1e35a7bd) >> tableShift
}
// These constants are defined by the Snappy implementation so that its
// assembly implementation can fast-path some 16-bytes-at-a-time copies. They
// aren't necessary in the pure Go implementation, as we don't use those same
// optimizations, but using the same thresholds doesn't really hurt.
const (
inputMargin = 16 - 1
minNonLiteralBlockSize = 1 + 1 + inputMargin
)
type tableEntry struct {
val uint32 // Value at destination
offset int32
}
// deflateFast maintains the table for matches,
// and the previous byte block for cross block matching.
type deflateFast struct {
table [tableSize]tableEntry
prev []byte // Previous block, zero length if unknown.
cur int32 // Current match offset.
}
func newDeflateFast() *deflateFast {
return &deflateFast{cur: maxStoreBlockSize, prev: make([]byte, 0, maxStoreBlockSize)}
}
// encode encodes a block given in src and appends tokens
// to dst and returns the result.
func (e *deflateFast) encode(dst []token, src []byte) []token {
// Ensure that e.cur doesn't wrap.
if e.cur >= bufferReset {
e.shiftOffsets()
}
// This check isn't in the Snappy implementation, but there, the caller
// instead of the callee handles this case.
if len(src) < minNonLiteralBlockSize {
e.cur += maxStoreBlockSize
e.prev = e.prev[:0]
return emitLiteral(dst, src)
}
// sLimit is when to stop looking for offset/length copies. The inputMargin
// lets us use a fast path for emitLiteral in the main loop, while we are
// looking for copies.
sLimit := int32(len(src) - inputMargin)
// nextEmit is where in src the next emitLiteral should start from.
nextEmit := int32(0)
s := int32(0)
cv := load32(src, s)
nextHash := hash(cv)
for {
// Copied from the C++ snappy implementation:
//
// Heuristic match skipping: If 32 bytes are scanned with no matches
// found, start looking only at every other byte. If 32 more bytes are
// scanned (or skipped), look at every third byte, etc.. When a match
// is found, immediately go back to looking at every byte. This is a
// small loss (~5% performance, ~0.1% density) for compressible data
// due to more bookkeeping, but for non-compressible data (such as
// JPEG) it's a huge win since the compressor quickly "realizes" the
// data is incompressible and doesn't bother looking for matches
// everywhere.
//
// The "skip" variable keeps track of how many bytes there are since
// the last match; dividing it by 32 (i.e. right-shifting by five) gives
// the number of bytes to move ahead for each iteration.
skip := int32(32)
nextS := s
var candidate tableEntry
for {
s = nextS
bytesBetweenHashLookups := skip >> 5
nextS = s + bytesBetweenHashLookups
skip += bytesBetweenHashLookups
if nextS > sLimit {
goto emitRemainder
}
candidate = e.table[nextHash&tableMask]
now := load32(src, nextS)
e.table[nextHash&tableMask] = tableEntry{offset: s + e.cur, val: cv}
nextHash = hash(now)
offset := s - (candidate.offset - e.cur)
if offset > maxMatchOffset || cv != candidate.val {
// Out of range or not matched.
cv = now
continue
}
break
}
// A 4-byte match has been found. We'll later see if more than 4 bytes
// match. But, prior to the match, src[nextEmit:s] are unmatched. Emit
// them as literal bytes.
dst = emitLiteral(dst, src[nextEmit:s])
// Call emitCopy, and then see if another emitCopy could be our next
// move. Repeat until we find no match for the input immediately after
// what was consumed by the last emitCopy call.
//
// If we exit this loop normally then we need to call emitLiteral next,
// though we don't yet know how big the literal will be. We handle that
// by proceeding to the next iteration of the main loop. We also can
// exit this loop via goto if we get close to exhausting the input.
for {
// Invariant: we have a 4-byte match at s, and no need to emit any
// literal bytes prior to s.
// Extend the 4-byte match as long as possible.
//
s += 4
t := candidate.offset - e.cur + 4
l := e.matchLen(s, t, src)
// matchToken is flate's equivalent of Snappy's emitCopy. (length,offset)
dst = append(dst, matchToken(uint32(l+4-baseMatchLength), uint32(s-t-baseMatchOffset)))
s += l
nextEmit = s
if s >= sLimit {
goto emitRemainder
}
// We could immediately start working at s now, but to improve
// compression we first update the hash table at s-1 and at s. If
// another emitCopy is not our next move, also calculate nextHash
// at s+1. At least on GOARCH=amd64, these three hash calculations
// are faster as one load64 call (with some shifts) instead of
// three load32 calls.
x := load64(src, s-1)
prevHash := hash(uint32(x))
e.table[prevHash&tableMask] = tableEntry{offset: e.cur + s - 1, val: uint32(x)}
x >>= 8
currHash := hash(uint32(x))
candidate = e.table[currHash&tableMask]
e.table[currHash&tableMask] = tableEntry{offset: e.cur + s, val: uint32(x)}
offset := s - (candidate.offset - e.cur)
if offset > maxMatchOffset || uint32(x) != candidate.val {
cv = uint32(x >> 8)
nextHash = hash(cv)
s++
break
}
}
}
emitRemainder:
if int(nextEmit) < len(src) {
dst = emitLiteral(dst, src[nextEmit:])
}
e.cur += int32(len(src))
e.prev = e.prev[:len(src)]
copy(e.prev, src)
return dst
}
func emitLiteral(dst []token, lit []byte) []token {
for _, v := range lit {
dst = append(dst, literalToken(uint32(v)))
}
return dst
}
// matchLen returns the match length between src[s:] and src[t:].
// t can be negative to indicate the match is starting in e.prev.
// We assume that src[s-4:s] and src[t-4:t] already match.
func (e *deflateFast) matchLen(s, t int32, src []byte) int32 {
s1 := int(s) + maxMatchLength - 4
if s1 > len(src) {
s1 = len(src)
}
// If we are inside the current block
if t >= 0 {
b := src[t:]
a := src[s:s1]
b = b[:len(a)]
// Extend the match to be as long as possible.
for i := range a {
if a[i] != b[i] {
return int32(i)
}
}
return int32(len(a))
}
// We found a match in the previous block.
tp := int32(len(e.prev)) + t
if tp < 0 {
return 0
}
// Extend the match to be as long as possible.
a := src[s:s1]
b := e.prev[tp:]
if len(b) > len(a) {
b = b[:len(a)]
}
a = a[:len(b)]
for i := range b {
if a[i] != b[i] {
return int32(i)
}
}
// If we reached our limit, we matched everything we are
// allowed to in the previous block and we return.
n := int32(len(b))
if int(s+n) == s1 {
return n
}
// Continue looking for more matches in the current block.
a = src[s+n : s1]
b = src[:len(a)]
for i := range a {
if a[i] != b[i] {
return int32(i) + n
}
}
return int32(len(a)) + n
}
// Reset resets the encoding history.
// This ensures that no matches are made to the previous block.
func (e *deflateFast) reset() {
e.prev = e.prev[:0]
// Bump the offset, so all matches will fail distance check.
// Nothing should be >= e.cur in the table.
e.cur += maxMatchOffset
// Protect against e.cur wraparound.
if e.cur >= bufferReset {
e.shiftOffsets()
}
}
// shiftOffsets will shift down all match offsets.
// This is only called in rare situations to prevent integer overflow.
//
// See https://golang.org/issue/18636 and https://github.com/golang/go/issues/34121.
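//
// Shifting is safe because match distances are computed as differences
// (s - (candidate.offset - e.cur)): subtracting the same delta from e.cur
// and from every stored offset leaves every distance unchanged. For example
// (values assumed for illustration), an entry recorded at e.cur-100 still
// decodes to distance 100 after both are shifted down to the
// maxMatchOffset+1 baseline.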
func (e *deflateFast) shiftOffsets() {
if len(e.prev) == 0 {
// We have no history; just clear the table.
for i := range e.table[:] {
e.table[i] = tableEntry{}
}
e.cur = maxMatchOffset + 1
return
}
// Shift down everything in the table that isn't already too far away.
for i := range e.table[:] {
v := e.table[i].offset - e.cur + maxMatchOffset + 1
if v < 0 {
// We want to reset e.cur to maxMatchOffset + 1, so we need to shift
// all table entries down by (e.cur - (maxMatchOffset + 1)).
// Because we ignore matches > maxMatchOffset, we can cap
// any negative offsets at 0.
v = 0
}
e.table[i].offset = v
}
e.cur = maxMatchOffset + 1
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package flate
// dictDecoder implements the LZ77 sliding dictionary as used in decompression.
// LZ77 decompresses data through sequences of two forms of commands:
//
// - Literal insertions: Runs of one or more symbols are inserted into the data
// stream as is. This is accomplished through the writeByte method for a
// single symbol, or combinations of writeSlice/writeMark for multiple symbols.
// Any valid stream must start with a literal insertion if no preset dictionary
// is used.
//
// - Backward copies: Runs of one or more symbols are copied from previously
// emitted data. Backward copies come as the tuple (dist, length) where dist
// determines how far back in the stream to copy from and length determines how
// many bytes to copy. Note that it is valid for the length to be greater than
// the distance. Since LZ77 uses forward copies, that situation is used to
// perform a form of run-length encoding on repeated runs of symbols.
// The writeCopy and tryWriteCopy methods are used to implement this command.
//
// For performance reasons, this implementation performs little to no sanity
// checks about the arguments. As such, the invariants documented for each
// method call must be respected.
type dictDecoder struct {
hist []byte // Sliding window history
// Invariant: 0 <= rdPos <= wrPos <= len(hist)
wrPos int // Current output position in buffer
rdPos int // Have emitted hist[:rdPos] already
full bool // Has a full window length been written yet?
}
// init initializes dictDecoder to have a sliding window dictionary of the given
// size. If a preset dict is provided, it will initialize the dictionary with
// the contents of dict.
func (dd *dictDecoder) init(size int, dict []byte) {
*dd = dictDecoder{hist: dd.hist}
if cap(dd.hist) < size {
dd.hist = make([]byte, size)
}
dd.hist = dd.hist[:size]
if len(dict) > len(dd.hist) {
dict = dict[len(dict)-len(dd.hist):]
}
dd.wrPos = copy(dd.hist, dict)
if dd.wrPos == len(dd.hist) {
dd.wrPos = 0
dd.full = true
}
dd.rdPos = dd.wrPos
}
// histSize reports the total amount of historical data in the dictionary.
func (dd *dictDecoder) histSize() int {
if dd.full {
return len(dd.hist)
}
return dd.wrPos
}
// availRead reports the number of bytes that can be flushed by readFlush.
func (dd *dictDecoder) availRead() int {
return dd.wrPos - dd.rdPos
}
// availWrite reports the available amount of output buffer space.
func (dd *dictDecoder) availWrite() int {
return len(dd.hist) - dd.wrPos
}
// writeSlice returns a slice of the available buffer to write data to.
//
// This invariant will be kept: len(s) <= availWrite()
func (dd *dictDecoder) writeSlice() []byte {
return dd.hist[dd.wrPos:]
}
// writeMark advances the writer pointer by cnt.
//
// This invariant must be kept: 0 <= cnt <= availWrite()
func (dd *dictDecoder) writeMark(cnt int) {
dd.wrPos += cnt
}
// writeByte writes a single byte to the dictionary.
//
// This invariant must be kept: 0 < availWrite()
func (dd *dictDecoder) writeByte(c byte) {
dd.hist[dd.wrPos] = c
dd.wrPos++
}
// writeCopy copies a string at a given (dist, length) to the output.
// This returns the number of bytes copied and may be less than the requested
// length if the available space in the output buffer is too small.
//
// This invariant must be kept: 0 < dist <= histSize()
func (dd *dictDecoder) writeCopy(dist, length int) int {
dstBase := dd.wrPos
dstPos := dstBase
srcPos := dstPos - dist
endPos := dstPos + length
if endPos > len(dd.hist) {
endPos = len(dd.hist)
}
// Copy non-overlapping section after destination position.
//
// This section is non-overlapping in that the copy length for this section
// is always less than or equal to the backwards distance. This can occur
// if a distance refers to data that wraps around in the buffer.
// Thus, a backwards copy is performed here; that is, the exact bytes in
// the source prior to the copy are placed in the destination.
if srcPos < 0 {
srcPos += len(dd.hist)
dstPos += copy(dd.hist[dstPos:endPos], dd.hist[srcPos:])
srcPos = 0
}
// Copy possibly overlapping section before destination position.
//
// This section can overlap if the copy length for this section is larger
// than the backwards distance. This is allowed by LZ77 so that repeated
// strings can be succinctly represented using (dist, length) pairs.
// Thus, a forwards copy is performed here; that is, the bytes copied are
// possibly dependent on the resulting bytes in the destination as the copy
// progresses along. This is functionally equivalent to the following:
//
// for i := 0; i < endPos-dstPos; i++ {
// dd.hist[dstPos+i] = dd.hist[srcPos+i]
// }
// dstPos = endPos
//
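//
// As a concrete sketch (values assumed for illustration): with history
// "ab" and wrPos == 2, writeCopy(1, 4) repeatedly re-reads the byte it
// has just written, appending "bbbb" so the history becomes "abbbbb".
// That is how LZ77 expresses run-length encoding with (dist, length) pairs.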
for dstPos < endPos {
dstPos += copy(dd.hist[dstPos:endPos], dd.hist[srcPos:dstPos])
}
dd.wrPos = dstPos
return dstPos - dstBase
}
// tryWriteCopy tries to copy a string at a given (distance, length) to the
// output. This specialized version is optimized for short distances.
//
// This method is designed to be inlined for performance reasons.
//
// This invariant must be kept: 0 < dist <= histSize()
func (dd *dictDecoder) tryWriteCopy(dist, length int) int {
dstPos := dd.wrPos
endPos := dstPos + length
if dstPos < dist || endPos > len(dd.hist) {
return 0
}
dstBase := dstPos
srcPos := dstPos - dist
// Copy possibly overlapping section before destination position.
for dstPos < endPos {
dstPos += copy(dd.hist[dstPos:endPos], dd.hist[srcPos:dstPos])
}
dd.wrPos = dstPos
return dstPos - dstBase
}
// readFlush returns a slice of the historical buffer that is ready to be
// emitted to the user. The data returned by readFlush must be fully consumed
// before calling any other dictDecoder methods.
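//
// A typical caller drives the dictionary with a loop of roughly this shape
// (an assumed sketch of intended usage, not code from this package):
//
// for /* each decoded command */ {
//     /* dd.writeByte, dd.writeCopy, or writeSlice+writeMark */
//     if dd.availWrite() == 0 {
//         out = append(out, dd.readFlush()...)
//     }
// }
// out = append(out, dd.readFlush()...) // drain whatever remains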
func (dd *dictDecoder) readFlush() []byte {
toRead := dd.hist[dd.rdPos:dd.wrPos]
dd.rdPos = dd.wrPos
if dd.wrPos == len(dd.hist) {
dd.wrPos, dd.rdPos = 0, 0
dd.full = true
}
return toRead
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package flate
import (
"io"
)
const (
// The largest offset code.
offsetCodeCount = 30
// The special code used to mark the end of a block.
endBlockMarker = 256
// The first length code.
lengthCodesStart = 257
// The number of codegen codes.
codegenCodeCount = 19
badCode = 255
// bufferFlushSize indicates the buffer size
// after which bytes are flushed to the writer.
// Should preferably be a multiple of 6, since
// we accumulate 6 bytes between writes to the buffer.
bufferFlushSize = 240
// bufferSize is the actual output byte buffer size.
// It must have additional headroom for a flush
// which can contain up to 8 bytes.
bufferSize = bufferFlushSize + 8
)
// The number of extra bits needed by length code X - LENGTH_CODES_START.
var lengthExtraBits = []int8{
/* 257 */ 0, 0, 0,
/* 260 */ 0, 0, 0, 0, 0, 1, 1, 1, 1, 2,
/* 270 */ 2, 2, 2, 3, 3, 3, 3, 4, 4, 4,
/* 280 */ 4, 5, 5, 5, 5, 0,
}
// The length indicated by length code X - LENGTH_CODES_START.
var lengthBase = []uint32{
0, 1, 2, 3, 4, 5, 6, 7, 8, 10,
12, 14, 16, 20, 24, 28, 32, 40, 48, 56,
64, 80, 96, 112, 128, 160, 192, 224, 255,
}
// offset code word extra bits.
var offsetExtraBits = []int8{
0, 0, 0, 0, 1, 1, 2, 2, 3, 3,
4, 4, 5, 5, 6, 6, 7, 7, 8, 8,
9, 9, 10, 10, 11, 11, 12, 12, 13, 13,
}
var offsetBase = []uint32{
0x000000, 0x000001, 0x000002, 0x000003, 0x000004,
0x000006, 0x000008, 0x00000c, 0x000010, 0x000018,
0x000020, 0x000030, 0x000040, 0x000060, 0x000080,
0x0000c0, 0x000100, 0x000180, 0x000200, 0x000300,
0x000400, 0x000600, 0x000800, 0x000c00, 0x001000,
0x001800, 0x002000, 0x003000, 0x004000, 0x006000,
}
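// As a worked example of how these tables combine (assuming, as in the
// encoder above, baseMatchLength = 3 and baseMatchOffset = 1): a match of
// length 13 has xlength 10, which falls in length code 9 (lengthBase 10,
// 1 extra bit), so the writer emits symbol 257+9 = 266 followed by the
// extra bit 0. An xoffset of 9 falls in offset code 6 (offsetBase 8,
// 2 extra bits), so the writer emits offset code 6 followed by the extra
// value 9-8 = 1 in 2 bits.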
// The odd order in which the codegen code sizes are written.
var codegenOrder = []uint32{16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15}
type huffmanBitWriter struct {
// writer is the underlying writer.
// Do not use it directly; use the write method, which ensures
// that Write errors are sticky.
writer io.Writer
// Data waiting to be written is bytes[0:nbytes]
// and then the low nbits of bits. Data is always written
// sequentially into the bytes array.
bits uint64
nbits uint
bytes [bufferSize]byte
codegenFreq [codegenCodeCount]int32
nbytes int
literalFreq []int32
offsetFreq []int32
codegen []uint8
literalEncoding *huffmanEncoder
offsetEncoding *huffmanEncoder
codegenEncoding *huffmanEncoder
err error
}
func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter {
return &huffmanBitWriter{
writer: w,
literalFreq: make([]int32, maxNumLit),
offsetFreq: make([]int32, offsetCodeCount),
codegen: make([]uint8, maxNumLit+offsetCodeCount+1),
literalEncoding: newHuffmanEncoder(maxNumLit),
codegenEncoding: newHuffmanEncoder(codegenCodeCount),
offsetEncoding: newHuffmanEncoder(offsetCodeCount),
}
}
func (w *huffmanBitWriter) reset(writer io.Writer) {
w.writer = writer
w.bits, w.nbits, w.nbytes, w.err = 0, 0, 0, nil
}
func (w *huffmanBitWriter) flush() {
if w.err != nil {
w.nbits = 0
return
}
n := w.nbytes
for w.nbits != 0 {
w.bytes[n] = byte(w.bits)
w.bits >>= 8
if w.nbits > 8 { // Avoid underflow
w.nbits -= 8
} else {
w.nbits = 0
}
n++
}
w.bits = 0
w.write(w.bytes[:n])
w.nbytes = 0
}
func (w *huffmanBitWriter) write(b []byte) {
if w.err != nil {
return
}
_, w.err = w.writer.Write(b)
}
func (w *huffmanBitWriter) writeBits(b int32, nb uint) {
if w.err != nil {
return
}
w.bits |= uint64(b) << w.nbits
w.nbits += nb
if w.nbits >= 48 {
bits := w.bits
w.bits >>= 48
w.nbits -= 48
n := w.nbytes
bytes := w.bytes[n : n+6]
bytes[0] = byte(bits)
bytes[1] = byte(bits >> 8)
bytes[2] = byte(bits >> 16)
bytes[3] = byte(bits >> 24)
bytes[4] = byte(bits >> 32)
bytes[5] = byte(bits >> 40)
n += 6
if n >= bufferFlushSize {
w.write(w.bytes[:n])
n = 0
}
w.nbytes = n
}
}
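// For example (a hand-traced sketch of the accumulator above): starting
// from empty, writeBits(0b101, 3) leaves bits = 0b101, nbits = 3; a
// following writeBits(0b11, 2) ORs 0b11 << 3 into bits, giving
// bits = 0b11101, nbits = 5. Bytes are emitted low byte first, which is
// the LSB-first bit packing DEFLATE requires.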
func (w *huffmanBitWriter) writeBytes(bytes []byte) {
if w.err != nil {
return
}
n := w.nbytes
if w.nbits&7 != 0 {
w.err = InternalError("writeBytes with unfinished bits")
return
}
for w.nbits != 0 {
w.bytes[n] = byte(w.bits)
w.bits >>= 8
w.nbits -= 8
n++
}
if n != 0 {
w.write(w.bytes[:n])
}
w.nbytes = 0
w.write(bytes)
}
// RFC 1951 3.2.7 specifies a special run-length encoding for specifying
// the literal and offset lengths arrays (which are concatenated into a single
// array). This method generates that run-length encoding.
//
// The result is written into the codegen array, and the frequency
// of each code is written into the codegenFreq array.
// Codes 0-15 are single-byte codes. Codes 16-18 are followed by additional
// information. Code badCode is an end marker.
//
// numLiterals The number of literals in literalEncoding
// numOffsets The number of offsets in offsetEncoding
// litEnc, offEnc The literal and offset encoders to use
func (w *huffmanBitWriter) generateCodegen(numLiterals int, numOffsets int, litEnc, offEnc *huffmanEncoder) {
for i := range w.codegenFreq {
w.codegenFreq[i] = 0
}
// Note that we are using codegen both as a temporary variable for holding
// a copy of the frequencies, and as the place where we put the result.
// This is fine because the output is always shorter than the input used
// so far.
codegen := w.codegen // cache
// Copy the concatenated code sizes to codegen. Put a marker at the end.
cgnl := codegen[:numLiterals]
for i := range cgnl {
cgnl[i] = uint8(litEnc.codes[i].len)
}
cgnl = codegen[numLiterals : numLiterals+numOffsets]
for i := range cgnl {
cgnl[i] = uint8(offEnc.codes[i].len)
}
codegen[numLiterals+numOffsets] = badCode
size := codegen[0]
count := 1
outIndex := 0
for inIndex := 1; size != badCode; inIndex++ {
// INVARIANT: We have seen "count" copies of size that have not yet
// had output generated for them.
nextSize := codegen[inIndex]
if nextSize == size {
count++
continue
}
// We need to generate codegen indicating "count" of size.
if size != 0 {
codegen[outIndex] = size
outIndex++
w.codegenFreq[size]++
count--
for count >= 3 {
n := 6
if n > count {
n = count
}
codegen[outIndex] = 16
outIndex++
codegen[outIndex] = uint8(n - 3)
outIndex++
w.codegenFreq[16]++
count -= n
}
} else {
for count >= 11 {
n := 138
if n > count {
n = count
}
codegen[outIndex] = 18
outIndex++
codegen[outIndex] = uint8(n - 11)
outIndex++
w.codegenFreq[18]++
count -= n
}
if count >= 3 {
// count >= 3 && count <= 10
codegen[outIndex] = 17
outIndex++
codegen[outIndex] = uint8(count - 3)
outIndex++
w.codegenFreq[17]++
count = 0
}
}
count--
for ; count >= 0; count-- {
codegen[outIndex] = size
outIndex++
w.codegenFreq[size]++
}
// Set up invariant for next time through the loop.
size = nextSize
count = 1
}
// Marker indicating the end of the codegen.
codegen[outIndex] = badCode
}
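// A short hand-traced example of the encoding above (illustrative only):
// seven consecutive code lengths of 8 become the bytes 8, 16, 3, that is,
// one literal 8 followed by code 16 ("repeat previous 3-6 times") with
// argument 6-3 = 3. Fifteen consecutive zeros become 18, 4, that is,
// code 18 ("repeat zero 11-138 times") with argument 15-11 = 4.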
// dynamicSize returns the size of dynamically encoded data in bits.
func (w *huffmanBitWriter) dynamicSize(litEnc, offEnc *huffmanEncoder, extraBits int) (size, numCodegens int) {
numCodegens = len(w.codegenFreq)
for numCodegens > 4 && w.codegenFreq[codegenOrder[numCodegens-1]] == 0 {
numCodegens--
}
header := 3 + 5 + 5 + 4 + (3 * numCodegens) +
w.codegenEncoding.bitLength(w.codegenFreq[:]) +
int(w.codegenFreq[16])*2 +
int(w.codegenFreq[17])*3 +
int(w.codegenFreq[18])*7
size = header +
litEnc.bitLength(w.literalFreq) +
offEnc.bitLength(w.offsetFreq) +
extraBits
return size, numCodegens
}
// fixedSize returns the size of data encoded with fixed Huffman codes, in bits.
func (w *huffmanBitWriter) fixedSize(extraBits int) int {
return 3 +
fixedLiteralEncoding.bitLength(w.literalFreq) +
fixedOffsetEncoding.bitLength(w.offsetFreq) +
extraBits
}
// storedSize calculates the stored size, including header.
// It returns the size in bits and whether the data fits within a single
// stored block. The 5-byte overhead is the block header bits rounded up to
// a byte plus the two-byte LEN and two-byte NLEN fields.
func (w *huffmanBitWriter) storedSize(in []byte) (int, bool) {
if in == nil {
return 0, false
}
if len(in) <= maxStoreBlockSize {
return (len(in) + 5) * 8, true
}
return 0, false
}
func (w *huffmanBitWriter) writeCode(c hcode) {
if w.err != nil {
return
}
w.bits |= uint64(c.code) << w.nbits
w.nbits += uint(c.len)
if w.nbits >= 48 {
bits := w.bits
w.bits >>= 48
w.nbits -= 48
n := w.nbytes
bytes := w.bytes[n : n+6]
bytes[0] = byte(bits)
bytes[1] = byte(bits >> 8)
bytes[2] = byte(bits >> 16)
bytes[3] = byte(bits >> 24)
bytes[4] = byte(bits >> 32)
bytes[5] = byte(bits >> 40)
n += 6
if n >= bufferFlushSize {
w.write(w.bytes[:n])
n = 0
}
w.nbytes = n
}
}
// Write the header of a dynamic Huffman block to the output stream.
//
// numLiterals The number of literals specified in codegen
// numOffsets The number of offsets specified in codegen
// numCodegens The number of codegens used in codegen
func (w *huffmanBitWriter) writeDynamicHeader(numLiterals int, numOffsets int, numCodegens int, isEof bool) {
if w.err != nil {
return
}
var firstBits int32 = 4
if isEof {
firstBits = 5
}
w.writeBits(firstBits, 3)
w.writeBits(int32(numLiterals-257), 5)
w.writeBits(int32(numOffsets-1), 5)
w.writeBits(int32(numCodegens-4), 4)
for i := 0; i < numCodegens; i++ {
value := uint(w.codegenEncoding.codes[codegenOrder[i]].len)
w.writeBits(int32(value), 3)
}
i := 0
for {
var codeWord int = int(w.codegen[i])
i++
if codeWord == badCode {
break
}
w.writeCode(w.codegenEncoding.codes[uint32(codeWord)])
switch codeWord {
case 16:
w.writeBits(int32(w.codegen[i]), 2)
i++
case 17:
w.writeBits(int32(w.codegen[i]), 3)
i++
case 18:
w.writeBits(int32(w.codegen[i]), 7)
i++
}
}
}
func (w *huffmanBitWriter) writeStoredHeader(length int, isEof bool) {
if w.err != nil {
return
}
var flag int32
if isEof {
flag = 1
}
w.writeBits(flag, 3)
w.flush()
w.writeBits(int32(length), 16)
w.writeBits(int32(^uint16(length)), 16)
}
func (w *huffmanBitWriter) writeFixedHeader(isEof bool) {
if w.err != nil {
return
}
// Indicate that we are a fixed Huffman block
var value int32 = 2
if isEof {
value = 3
}
w.writeBits(value, 3)
}
// writeBlock will write a block of tokens with the smallest encoding.
// The original input can be supplied, and if the Huffman encoded data
// is larger than the original bytes, the data will be written as a
// stored block.
// If the input is nil, the tokens will always be Huffman encoded.
func (w *huffmanBitWriter) writeBlock(tokens []token, eof bool, input []byte) {
if w.err != nil {
return
}
tokens = append(tokens, endBlockMarker)
numLiterals, numOffsets := w.indexTokens(tokens)
var extraBits int
storedSize, storable := w.storedSize(input)
if storable {
// We only bother calculating the costs of the extra bits required by
// the length of offset fields (which will be the same for both fixed
// and dynamic encoding), if we need to compare those two encodings
// against stored encoding.
for lengthCode := lengthCodesStart + 8; lengthCode < numLiterals; lengthCode++ {
// First eight length codes have extra size = 0.
extraBits += int(w.literalFreq[lengthCode]) * int(lengthExtraBits[lengthCode-lengthCodesStart])
}
for offsetCode := 4; offsetCode < numOffsets; offsetCode++ {
// First four offset codes have extra size = 0.
extraBits += int(w.offsetFreq[offsetCode]) * int(offsetExtraBits[offsetCode])
}
}
// Figure out smallest code.
// Fixed Huffman baseline.
var literalEncoding = fixedLiteralEncoding
var offsetEncoding = fixedOffsetEncoding
var size = w.fixedSize(extraBits)
// Dynamic Huffman?
var numCodegens int
// Generate codegen and codegenFrequencies, which indicate how to encode
// the literalEncoding and the offsetEncoding.
w.generateCodegen(numLiterals, numOffsets, w.literalEncoding, w.offsetEncoding)
w.codegenEncoding.generate(w.codegenFreq[:], 7)
dynamicSize, numCodegens := w.dynamicSize(w.literalEncoding, w.offsetEncoding, extraBits)
if dynamicSize < size {
size = dynamicSize
literalEncoding = w.literalEncoding
offsetEncoding = w.offsetEncoding
}
// Stored bytes?
if storable && storedSize < size {
w.writeStoredHeader(len(input), eof)
w.writeBytes(input)
return
}
// Huffman.
if literalEncoding == fixedLiteralEncoding {
w.writeFixedHeader(eof)
} else {
w.writeDynamicHeader(numLiterals, numOffsets, numCodegens, eof)
}
// Write the tokens.
w.writeTokens(tokens, literalEncoding.codes, offsetEncoding.codes)
}
// writeBlockDynamic encodes a block using a dynamic Huffman table.
// This should be used if the distribution of symbols is heavily skewed.
// If input is supplied and the compression savings are below 1/16th of the
// input size, the block is stored.
func (w *huffmanBitWriter) writeBlockDynamic(tokens []token, eof bool, input []byte) {
if w.err != nil {
return
}
tokens = append(tokens, endBlockMarker)
numLiterals, numOffsets := w.indexTokens(tokens)
// Generate codegen and codegenFrequencies, which indicate how to encode
// the literalEncoding and the offsetEncoding.
w.generateCodegen(numLiterals, numOffsets, w.literalEncoding, w.offsetEncoding)
w.codegenEncoding.generate(w.codegenFreq[:], 7)
size, numCodegens := w.dynamicSize(w.literalEncoding, w.offsetEncoding, 0)
// Store bytes, if we don't get a reasonable improvement.
if ssize, storable := w.storedSize(input); storable && ssize < (size+size>>4) {
w.writeStoredHeader(len(input), eof)
w.writeBytes(input)
return
}
// Write Huffman table.
w.writeDynamicHeader(numLiterals, numOffsets, numCodegens, eof)
// Write the tokens.
w.writeTokens(tokens, w.literalEncoding.codes, w.offsetEncoding.codes)
}
// indexTokens indexes a slice of tokens, updates literalFreq and
// offsetFreq, and generates literalEncoding and offsetEncoding.
// The number of literal and offset tokens is returned.
func (w *huffmanBitWriter) indexTokens(tokens []token) (numLiterals, numOffsets int) {
for i := range w.literalFreq {
w.literalFreq[i] = 0
}
for i := range w.offsetFreq {
w.offsetFreq[i] = 0
}
for _, t := range tokens {
if t < matchType {
w.literalFreq[t.literal()]++
continue
}
length := t.length()
offset := t.offset()
w.literalFreq[lengthCodesStart+lengthCode(length)]++
w.offsetFreq[offsetCode(offset)]++
}
// get the number of literals
numLiterals = len(w.literalFreq)
for w.literalFreq[numLiterals-1] == 0 {
numLiterals--
}
// get the number of offsets
numOffsets = len(w.offsetFreq)
for numOffsets > 0 && w.offsetFreq[numOffsets-1] == 0 {
numOffsets--
}
if numOffsets == 0 {
// We haven't found a single match. If we want to go with the dynamic encoding,
// we should count at least one offset to be sure that the offset Huffman tree can be encoded.
w.offsetFreq[0] = 1
numOffsets = 1
}
w.literalEncoding.generate(w.literalFreq, 15)
w.offsetEncoding.generate(w.offsetFreq, 15)
return
}
// writeTokens writes a slice of tokens to the output.
// codes for literal and offset encoding must be supplied.
func (w *huffmanBitWriter) writeTokens(tokens []token, leCodes, oeCodes []hcode) {
if w.err != nil {
return
}
for _, t := range tokens {
if t < matchType {
w.writeCode(leCodes[t.literal()])
continue
}
// Write the length
length := t.length()
lengthCode := lengthCode(length)
w.writeCode(leCodes[lengthCode+lengthCodesStart])
extraLengthBits := uint(lengthExtraBits[lengthCode])
if extraLengthBits > 0 {
extraLength := int32(length - lengthBase[lengthCode])
w.writeBits(extraLength, extraLengthBits)
}
// Write the offset
offset := t.offset()
offsetCode := offsetCode(offset)
w.writeCode(oeCodes[offsetCode])
extraOffsetBits := uint(offsetExtraBits[offsetCode])
if extraOffsetBits > 0 {
extraOffset := int32(offset - offsetBase[offsetCode])
w.writeBits(extraOffset, extraOffsetBits)
}
}
}
// huffOffset is a static offset encoder used for Huffman-only encoding.
// It can be reused since we will not be encoding offset values.
var huffOffset *huffmanEncoder
func init() {
offsetFreq := make([]int32, offsetCodeCount)
offsetFreq[0] = 1
huffOffset = newHuffmanEncoder(offsetCodeCount)
huffOffset.generate(offsetFreq, 15)
}
// writeBlockHuff encodes a block of bytes as either
// Huffman encoded literals or uncompressed bytes if the
// result gains very little from compression.
func (w *huffmanBitWriter) writeBlockHuff(eof bool, input []byte) {
if w.err != nil {
return
}
// Clear histogram
for i := range w.literalFreq {
w.literalFreq[i] = 0
}
// Add everything as literals
histogram(input, w.literalFreq)
w.literalFreq[endBlockMarker] = 1
const numLiterals = endBlockMarker + 1
w.offsetFreq[0] = 1
const numOffsets = 1
w.literalEncoding.generate(w.literalFreq, 15)
// Figure out smallest code.
// Always use dynamic Huffman or Store
var numCodegens int
// Generate codegen and codegenFrequencies, which indicate how to encode
// the literalEncoding and the offsetEncoding.
w.generateCodegen(numLiterals, numOffsets, w.literalEncoding, huffOffset)
w.codegenEncoding.generate(w.codegenFreq[:], 7)
size, numCodegens := w.dynamicSize(w.literalEncoding, huffOffset, 0)
// Store bytes, if we don't get a reasonable improvement.
if ssize, storable := w.storedSize(input); storable && ssize < (size+size>>4) {
w.writeStoredHeader(len(input), eof)
w.writeBytes(input)
return
}
// Huffman.
w.writeDynamicHeader(numLiterals, numOffsets, numCodegens, eof)
encoding := w.literalEncoding.codes[:257]
n := w.nbytes
for _, t := range input {
// Bitwriting inlined, ~30% speedup
c := encoding[t]
w.bits |= uint64(c.code) << w.nbits
w.nbits += uint(c.len)
if w.nbits < 48 {
continue
}
// Store 6 bytes
bits := w.bits
w.bits >>= 48
w.nbits -= 48
bytes := w.bytes[n : n+6]
bytes[0] = byte(bits)
bytes[1] = byte(bits >> 8)
bytes[2] = byte(bits >> 16)
bytes[3] = byte(bits >> 24)
bytes[4] = byte(bits >> 32)
bytes[5] = byte(bits >> 40)
n += 6
if n < bufferFlushSize {
continue
}
w.write(w.bytes[:n])
if w.err != nil {
return // Return early in the event of write failures
}
n = 0
}
w.nbytes = n
w.writeCode(encoding[endBlockMarker])
}
// histogram accumulates a histogram of b in h.
//
// len(h) must be >= 256, and h's elements must be all zeroes.
func histogram(b []byte, h []int32) {
h = h[:256]
for _, t := range b {
h[t]++
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package flate
import (
"math"
"math/bits"
"sort"
)
// hcode is a huffman code with a bit code and bit length.
type hcode struct {
code, len uint16
}
type huffmanEncoder struct {
codes []hcode
freqcache []literalNode
bitCount [17]int32
lns byLiteral // stored to avoid repeated allocation in generate
lfs byFreq // stored to avoid repeated allocation in generate
}
type literalNode struct {
literal uint16
freq int32
}
// A levelInfo describes the state of the constructed tree for a given depth.
type levelInfo struct {
// Our level, for better printing.
level int32
// The frequency of the last node at this level
lastFreq int32
// The frequency of the next character to add to this level
nextCharFreq int32
// The frequency of the next pair (from level below) to add to this level.
// Only valid if the "needed" value of the next lower level is 0.
nextPairFreq int32
// The number of chains remaining to generate for this level before moving
// up to the next level
needed int32
}
// set sets the code and length of an hcode.
func (h *hcode) set(code uint16, length uint16) {
h.len = length
h.code = code
}
func maxNode() literalNode { return literalNode{math.MaxUint16, math.MaxInt32} }
func newHuffmanEncoder(size int) *huffmanEncoder {
return &huffmanEncoder{codes: make([]hcode, size)}
}
// generateFixedLiteralEncoding generates the Huffman encoder for the fixed literal table.
func generateFixedLiteralEncoding() *huffmanEncoder {
h := newHuffmanEncoder(maxNumLit)
codes := h.codes
var ch uint16
for ch = 0; ch < maxNumLit; ch++ {
var bits uint16
var size uint16
switch {
case ch < 144:
// size 8, 00110000 .. 10111111
bits = ch + 48
size = 8
case ch < 256:
// size 9, 110010000 .. 111111111
bits = ch + 400 - 144
size = 9
case ch < 280:
// size 7, 0000000 .. 0010111
bits = ch - 256
size = 7
default:
// size 8, 11000000 .. 11000111
bits = ch + 192 - 280
size = 8
}
codes[ch] = hcode{code: reverseBits(bits, byte(size)), len: size}
}
return h
}
func generateFixedOffsetEncoding() *huffmanEncoder {
h := newHuffmanEncoder(30)
codes := h.codes
for ch := range codes {
codes[ch] = hcode{code: reverseBits(uint16(ch), 5), len: 5}
}
return h
}
var fixedLiteralEncoding *huffmanEncoder = generateFixedLiteralEncoding()
var fixedOffsetEncoding *huffmanEncoder = generateFixedOffsetEncoding()
func (h *huffmanEncoder) bitLength(freq []int32) int {
var total int
for i, f := range freq {
if f != 0 {
total += int(f) * int(h.codes[i].len)
}
}
return total
}
const maxBitsLimit = 16
// bitCounts computes the number of literals assigned to each bit size in the Huffman encoding.
// It is only called when len(list) >= 3.
// The cases of 0, 1, and 2 literals are handled by special case code.
//
// list is an array of the literals with non-zero frequencies
// and their associated frequencies. The array is in order of increasing
// frequency and has as its last element a special element with frequency
// MaxInt32.
//
// maxBits is the maximum number of bits that should be used to encode any literal.
// It must be less than 16.
//
// bitCounts returns an integer slice in which slice[i] indicates the number of literals
// that should be encoded in i bits.
func (h *huffmanEncoder) bitCounts(list []literalNode, maxBits int32) []int32 {
if maxBits >= maxBitsLimit {
panic("flate: maxBits too large")
}
n := int32(len(list))
list = list[0 : n+1]
list[n] = maxNode()
// The tree can't have greater depth than n - 1, no matter what. This
// saves a little bit of work in some small cases.
if maxBits > n-1 {
maxBits = n - 1
}
// Create information about each of the levels.
// A bogus "Level 0" whose sole purpose is so that
// level1.prev.needed==0. This makes level1.nextPairFreq
// be a legitimate value that never gets chosen.
var levels [maxBitsLimit]levelInfo
// leafCounts[i] counts the number of literals at the left
// of ancestors of the rightmost node at level i.
// leafCounts[i][j] is the number of literals at the left
// of the level j ancestor.
var leafCounts [maxBitsLimit][maxBitsLimit]int32
for level := int32(1); level <= maxBits; level++ {
// For every level, the first two items are the first two characters.
// We initialize the levels as if we had already figured this out.
levels[level] = levelInfo{
level: level,
lastFreq: list[1].freq,
nextCharFreq: list[2].freq,
nextPairFreq: list[0].freq + list[1].freq,
}
leafCounts[level][level] = 2
if level == 1 {
levels[level].nextPairFreq = math.MaxInt32
}
}
// We need a total of 2*n - 2 items at top level and have already generated 2.
levels[maxBits].needed = 2*n - 4
level := maxBits
for {
l := &levels[level]
if l.nextPairFreq == math.MaxInt32 && l.nextCharFreq == math.MaxInt32 {
// We've run out of both leafs and pairs.
// End all calculations for this level.
// To make sure we never come back to this level or any lower level,
// set nextPairFreq impossibly large.
l.needed = 0
levels[level+1].nextPairFreq = math.MaxInt32
level++
continue
}
prevFreq := l.lastFreq
if l.nextCharFreq < l.nextPairFreq {
// The next item on this row is a leaf node.
n := leafCounts[level][level] + 1
l.lastFreq = l.nextCharFreq
// Lower leafCounts are the same as in the previous node.
leafCounts[level][level] = n
l.nextCharFreq = list[n].freq
} else {
// The next item on this row is a pair from the previous row.
// nextPairFreq isn't valid until we generate two
// more values in the level below.
l.lastFreq = l.nextPairFreq
// Take leaf counts from the lower level, except counts[level] remains the same.
copy(leafCounts[level][:level], leafCounts[level-1][:level])
levels[l.level-1].needed = 2
}
if l.needed--; l.needed == 0 {
// We've done everything we need to do for this level.
// Continue calculating one level up. Fill in nextPairFreq
// of that level with the sum of the two nodes we've just calculated on
// this level.
if l.level == maxBits {
// All done!
break
}
levels[l.level+1].nextPairFreq = prevFreq + l.lastFreq
level++
} else {
// If we stole from below, move down temporarily to replenish it.
for levels[level-1].needed > 0 {
level--
}
}
}
// Something is wrong if, at the end, the top level is null or hasn't used
// all of the leaves.
if leafCounts[maxBits][maxBits] != n {
panic("leafCounts[maxBits][maxBits] != n")
}
bitCount := h.bitCount[:maxBits+1]
bits := 1
counts := &leafCounts[maxBits]
for level := maxBits; level > 0; level-- {
// chain.leafCount gives the number of literals requiring at least "bits"
// bits to encode.
bitCount[bits] = counts[level] - counts[level-1]
bits++
}
return bitCount
}
// Look at the leaves and assign them a bit count and an encoding as specified
// in RFC 1951 3.2.2.
func (h *huffmanEncoder) assignEncodingAndSize(bitCount []int32, list []literalNode) {
code := uint16(0)
for n, bits := range bitCount {
code <<= 1
if n == 0 || bits == 0 {
continue
}
// The literals list[len(list)-bits] .. list[len(list)-1]
// are encoded using "bits" bits, and get the values
// code, code + 1, .... The code values are
// assigned in literal order (not frequency order).
chunk := list[len(list)-int(bits):]
h.lns.sort(chunk)
for _, node := range chunk {
h.codes[node.literal] = hcode{code: reverseBits(code, uint8(n)), len: uint16(n)}
code++
}
list = list[0 : len(list)-int(bits)]
}
}
// Update this Huffman Code object to be the minimum code for the specified frequency count.
//
// freq is an array of frequencies, in which freq[i] gives the frequency of literal i.
// maxBits The maximum number of bits to use for any literal.
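//
// As a small worked example (hand-traced, illustrative only): with freq
// giving literal 'a' count 1, 'b' count 1, and 'c' count 5, bitCounts
// assigns one 1-bit code and two 2-bit codes. assignEncodingAndSize then
// hands out canonical values in literal order per bit length, so 'c' gets
// the 1-bit code 0 while 'a' and 'b' get the 2-bit codes 10 and 11, each
// stored bit-reversed by reverseBits for DEFLATE's LSB-first bit order.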
func (h *huffmanEncoder) generate(freq []int32, maxBits int32) {
if h.freqcache == nil {
// Allocate a reusable buffer with the longest possible frequency table.
// Possible lengths are codegenCodeCount, offsetCodeCount and maxNumLit.
// The largest of these is maxNumLit, so we allocate for that case.
h.freqcache = make([]literalNode, maxNumLit+1)
}
list := h.freqcache[:len(freq)+1]
// Number of non-zero literals
count := 0
// Set list to be the set of all non-zero literals and their frequencies
for i, f := range freq {
if f != 0 {
list[count] = literalNode{uint16(i), f}
count++
} else {
h.codes[i].len = 0
}
}
list = list[:count]
if count <= 2 {
// Handle the small cases here, because they are awkward for the general case code. With
// two or fewer literals, everything has bit length 1.
for i, node := range list {
// "list" is in order of increasing literal value.
h.codes[node.literal].set(uint16(i), 1)
}
return
}
h.lfs.sort(list)
// Get the number of literals for each bit count
bitCount := h.bitCounts(list, maxBits)
// And do the assignment
h.assignEncodingAndSize(bitCount, list)
}
type byLiteral []literalNode
func (s *byLiteral) sort(a []literalNode) {
*s = byLiteral(a)
sort.Sort(s)
}
func (s byLiteral) Len() int { return len(s) }
func (s byLiteral) Less(i, j int) bool {
return s[i].literal < s[j].literal
}
func (s byLiteral) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
type byFreq []literalNode
func (s *byFreq) sort(a []literalNode) {
*s = byFreq(a)
sort.Sort(s)
}
func (s byFreq) Len() int { return len(s) }
func (s byFreq) Less(i, j int) bool {
if s[i].freq == s[j].freq {
return s[i].literal < s[j].literal
}
return s[i].freq < s[j].freq
}
func (s byFreq) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
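// reverseBits reverses the low bitLength bits of number. For example
// (a hand-checked case), reverseBits(0b110, 3) == 0b011: the code is
// shifted into the top bits of a uint16 and bits.Reverse16 mirrors it
// back down.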
func reverseBits(number uint16, bitLength byte) uint16 {
return bits.Reverse16(number << (16 - bitLength))
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package flate implements the DEFLATE compressed data format, described in
// RFC 1951. The gzip and zlib packages implement access to DEFLATE-based file
// formats.
package flate
import (
"bufio"
"io"
"math/bits"
"strconv"
"sync"
)
const (
maxCodeLen = 16 // max length of Huffman code
// The next three numbers come from the RFC section 3.2.7, with the
// additional proviso in section 3.2.5 which implies that distance codes
// 30 and 31 should never occur in compressed data.
maxNumLit = 286
maxNumDist = 30
numCodes = 19 // number of codes in Huffman meta-code
)
// Initialize the fixedHuffmanDecoder only once upon first use.
var fixedOnce sync.Once
var fixedHuffmanDecoder huffmanDecoder
// A CorruptInputError reports the presence of corrupt input at a given offset.
type CorruptInputError int64
func (e CorruptInputError) Error() string {
return "flate: corrupt input before offset " + strconv.FormatInt(int64(e), 10)
}
// An InternalError reports an error in the flate code itself.
type InternalError string
func (e InternalError) Error() string { return "flate: internal error: " + string(e) }
// A ReadError reports an error encountered while reading input.
//
// Deprecated: No longer returned.
type ReadError struct {
Offset int64 // byte offset where error occurred
Err error // error returned by underlying Read
}
func (e *ReadError) Error() string {
return "flate: read error at offset " + strconv.FormatInt(e.Offset, 10) + ": " + e.Err.Error()
}
// A WriteError reports an error encountered while writing output.
//
// Deprecated: No longer returned.
type WriteError struct {
Offset int64 // byte offset where error occurred
Err error // error returned by underlying Write
}
func (e *WriteError) Error() string {
return "flate: write error at offset " + strconv.FormatInt(e.Offset, 10) + ": " + e.Err.Error()
}
// Resetter resets a ReadCloser returned by NewReader or NewReaderDict
// to switch to a new underlying Reader. This permits reusing a ReadCloser
// instead of allocating a new one.
type Resetter interface {
// Reset discards any buffered data and resets the Resetter as if it was
// newly initialized with the given reader.
Reset(r io.Reader, dict []byte) error
}
// The data structure for decoding Huffman tables is based on that of
// zlib. There is a lookup table of a fixed bit width (huffmanChunkBits).
// For codes smaller than the table width, there are multiple entries
// (each combination of trailing bits has the same value). For codes
// larger than the table width, the table contains a link to an overflow
// table. The width of each entry in the link table is the maximum code
// size minus the chunk width.
//
// Note that you can do a lookup in the table even without all bits
// filled. Since the extra bits are zero, and the DEFLATE Huffman codes
// have the property that shorter codes come before longer ones, the
// bit length estimate in the result is a lower bound on the actual
// number of bits.
//
// See the following:
// https://github.com/madler/zlib/raw/master/doc/algorithm.txt
// chunk & 15 is number of bits
// chunk >> 4 is value, including table link
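//
// For one assumed value: a chunk of 0x1f3 decodes as a code length of
// 0x1f3 & 15 == 3 bits and a symbol/link value of 0x1f3 >> 4 == 0x1f.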
const (
huffmanChunkBits = 9
huffmanNumChunks = 1 << huffmanChunkBits
huffmanCountMask = 15
huffmanValueShift = 4
)
type huffmanDecoder struct {
min int // the minimum code length
chunks [huffmanNumChunks]uint32 // chunks as described above
links [][]uint32 // overflow links
linkMask uint32 // mask the width of the link table
}
// Initialize Huffman decoding tables from array of code lengths.
// Following this function, h is guaranteed to be initialized into a complete
// tree (i.e., neither over-subscribed nor under-subscribed). The exception is a
// degenerate case where the tree has only a single symbol with length 1. Empty
// trees are permitted.
func (h *huffmanDecoder) init(lengths []int) bool {
// Sanity enables additional runtime tests during Huffman
// table construction. It's intended to be used during
// development to supplement the currently ad-hoc unit tests.
const sanity = false
if h.min != 0 {
*h = huffmanDecoder{}
}
// Count number of codes of each length,
// compute min and max length.
var count [maxCodeLen]int
var min, max int
for _, n := range lengths {
if n == 0 {
continue
}
if min == 0 || n < min {
min = n
}
if n > max {
max = n
}
count[n]++
}
// Empty tree. The decompressor.huffSym function will fail later if the tree
// is used. Technically, an empty tree is only valid for the HDIST tree and
// not the HCLEN and HLIT tree. However, a stream with an empty HCLEN tree
// is guaranteed to fail since it will attempt to use the tree to decode the
// codes for the HLIT and HDIST trees. Similarly, an empty HLIT tree is
// guaranteed to fail later since the compressed data section must be
// composed of at least one symbol (the end-of-block marker).
if max == 0 {
return true
}
code := 0
var nextcode [maxCodeLen]int
for i := min; i <= max; i++ {
code <<= 1
nextcode[i] = code
code += count[i]
}
// Check that the coding is complete (i.e., that we've
// assigned all 2-to-the-max possible bit sequences).
// Exception: To be compatible with zlib, we also need to
// accept degenerate single-code codings. See also
// TestDegenerateHuffmanCoding.
if code != 1<<uint(max) && !(code == 1 && max == 1) {
return false
}
h.min = min
if max > huffmanChunkBits {
numLinks := 1 << (uint(max) - huffmanChunkBits)
h.linkMask = uint32(numLinks - 1)
// create link tables
link := nextcode[huffmanChunkBits+1] >> 1
h.links = make([][]uint32, huffmanNumChunks-link)
for j := uint(link); j < huffmanNumChunks; j++ {
reverse := int(bits.Reverse16(uint16(j)))
reverse >>= uint(16 - huffmanChunkBits)
off := j - uint(link)
if sanity && h.chunks[reverse] != 0 {
panic("impossible: overwriting existing chunk")
}
h.chunks[reverse] = uint32(off<<huffmanValueShift | (huffmanChunkBits + 1))
h.links[off] = make([]uint32, numLinks)
}
}
for i, n := range lengths {
if n == 0 {
continue
}
code := nextcode[n]
nextcode[n]++
chunk := uint32(i<<huffmanValueShift | n)
reverse := int(bits.Reverse16(uint16(code)))
reverse >>= uint(16 - n)
if n <= huffmanChunkBits {
for off := reverse; off < len(h.chunks); off += 1 << uint(n) {
// We should never need to overwrite
// an existing chunk. Also, 0 is
// never a valid chunk, because the
// lower 4 "count" bits should be
// between 1 and 15.
if sanity && h.chunks[off] != 0 {
panic("impossible: overwriting existing chunk")
}
h.chunks[off] = chunk
}
} else {
j := reverse & (huffmanNumChunks - 1)
if sanity && h.chunks[j]&huffmanCountMask != huffmanChunkBits+1 {
// Longer codes should have been
// associated with a link table above.
panic("impossible: not an indirect chunk")
}
value := h.chunks[j] >> huffmanValueShift
linktab := h.links[value]
reverse >>= huffmanChunkBits
for off := reverse; off < len(linktab); off += 1 << uint(n-huffmanChunkBits) {
if sanity && linktab[off] != 0 {
panic("impossible: overwriting existing chunk")
}
linktab[off] = chunk
}
}
}
if sanity {
// Above we've sanity checked that we never overwrote
// an existing entry. Here we additionally check that
// we filled the tables completely.
for i, chunk := range h.chunks {
if chunk == 0 {
// As an exception, in the degenerate
// single-code case, we allow odd
// chunks to be missing.
if code == 1 && i%2 == 1 {
continue
}
panic("impossible: missing chunk")
}
}
for _, linktab := range h.links {
for _, chunk := range linktab {
if chunk == 0 {
panic("impossible: missing chunk")
}
}
}
}
return true
}
// The actual read interface needed by NewReader.
// If the passed-in io.Reader does not also have ReadByte,
// NewReader will introduce its own buffering.
type Reader interface {
io.Reader
io.ByteReader
}
// Decompress state.
type decompressor struct {
// Input source.
r Reader
roffset int64
// Input bits, in top of b.
b uint32
nb uint
// Huffman decoders for literal/length, distance.
h1, h2 huffmanDecoder
// Length arrays used to define Huffman codes.
bits *[maxNumLit + maxNumDist]int
codebits *[numCodes]int
// Output history, buffer.
dict dictDecoder
// Temporary buffer (avoids repeated allocation).
buf [4]byte
// Next step in the decompression,
// and decompression state.
step func(*decompressor)
stepState int
final bool
err error
toRead []byte
hl, hd *huffmanDecoder
copyLen int
copyDist int
}
func (f *decompressor) nextBlock() {
for f.nb < 1+2 {
if f.err = f.moreBits(); f.err != nil {
return
}
}
f.final = f.b&1 == 1
f.b >>= 1
typ := f.b & 3
f.b >>= 2
f.nb -= 1 + 2
switch typ {
case 0:
f.dataBlock()
case 1:
// compressed, fixed Huffman tables
f.hl = &fixedHuffmanDecoder
f.hd = nil
f.huffmanBlock()
case 2:
// compressed, dynamic Huffman tables
if f.err = f.readHuffman(); f.err != nil {
break
}
f.hl = &f.h1
f.hd = &f.h2
f.huffmanBlock()
default:
// 3 is reserved.
f.err = CorruptInputError(f.roffset)
}
}
func (f *decompressor) Read(b []byte) (int, error) {
for {
if len(f.toRead) > 0 {
n := copy(b, f.toRead)
f.toRead = f.toRead[n:]
if len(f.toRead) == 0 {
return n, f.err
}
return n, nil
}
if f.err != nil {
return 0, f.err
}
f.step(f)
if f.err != nil && len(f.toRead) == 0 {
f.toRead = f.dict.readFlush() // Flush what's left in case of error
}
}
}
func (f *decompressor) Close() error {
if f.err == io.EOF {
return nil
}
return f.err
}
// RFC 1951 section 3.2.7.
// Compression with dynamic Huffman codes
var codeOrder = [...]int{16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15}
func (f *decompressor) readHuffman() error {
// HLIT[5], HDIST[5], HCLEN[4].
for f.nb < 5+5+4 {
if err := f.moreBits(); err != nil {
return err
}
}
nlit := int(f.b&0x1F) + 257
if nlit > maxNumLit {
return CorruptInputError(f.roffset)
}
f.b >>= 5
ndist := int(f.b&0x1F) + 1
if ndist > maxNumDist {
return CorruptInputError(f.roffset)
}
f.b >>= 5
nclen := int(f.b&0xF) + 4
// numCodes is 19, so nclen is always valid.
f.b >>= 4
f.nb -= 5 + 5 + 4
// (HCLEN+4)*3 bits: code lengths in the magic codeOrder order.
for i := 0; i < nclen; i++ {
for f.nb < 3 {
if err := f.moreBits(); err != nil {
return err
}
}
f.codebits[codeOrder[i]] = int(f.b & 0x7)
f.b >>= 3
f.nb -= 3
}
for i := nclen; i < len(codeOrder); i++ {
f.codebits[codeOrder[i]] = 0
}
if !f.h1.init(f.codebits[0:]) {
return CorruptInputError(f.roffset)
}
// HLIT + 257 code lengths, HDIST + 1 code lengths,
// using the code length Huffman code.
for i, n := 0, nlit+ndist; i < n; {
x, err := f.huffSym(&f.h1)
if err != nil {
return err
}
if x < 16 {
// Actual length.
f.bits[i] = x
i++
continue
}
// Repeat previous length or zero.
var rep int
var nb uint
var b int
switch x {
default:
return InternalError("unexpected length code")
case 16:
rep = 3
nb = 2
if i == 0 {
return CorruptInputError(f.roffset)
}
b = f.bits[i-1]
case 17:
rep = 3
nb = 3
b = 0
case 18:
rep = 11
nb = 7
b = 0
}
for f.nb < nb {
if err := f.moreBits(); err != nil {
return err
}
}
rep += int(f.b & uint32(1<<nb-1))
f.b >>= nb
f.nb -= nb
if i+rep > n {
return CorruptInputError(f.roffset)
}
for j := 0; j < rep; j++ {
f.bits[i] = b
i++
}
}
if !f.h1.init(f.bits[0:nlit]) || !f.h2.init(f.bits[nlit:nlit+ndist]) {
return CorruptInputError(f.roffset)
}
// As an optimization, we can initialize the min bits to read at a time
// for the HLIT tree to the length of the EOB marker since we know that
// every block must terminate with one. This preserves the property that
// we never read any extra bytes after the end of the DEFLATE stream.
if f.h1.min < f.bits[endBlockMarker] {
f.h1.min = f.bits[endBlockMarker]
}
return nil
}
// Decode a single Huffman block from f.
// hl and hd are the Huffman states for the lit/length values
// and the distance values, respectively. If hd == nil, the fixed
// distance encoding associated with fixed Huffman blocks is used.
func (f *decompressor) huffmanBlock() {
const (
stateInit = iota // Zero value must be stateInit
stateDict
)
switch f.stepState {
case stateInit:
goto readLiteral
case stateDict:
goto copyHistory
}
readLiteral:
// Read literal and/or (length, distance) according to RFC section 3.2.3.
{
v, err := f.huffSym(f.hl)
if err != nil {
f.err = err
return
}
var n uint // number of bits extra
var length int
switch {
case v < 256:
f.dict.writeByte(byte(v))
if f.dict.availWrite() == 0 {
f.toRead = f.dict.readFlush()
f.step = (*decompressor).huffmanBlock
f.stepState = stateInit
return
}
goto readLiteral
case v == 256:
f.finishBlock()
return
// otherwise, reference to older data
case v < 265:
length = v - (257 - 3)
n = 0
case v < 269:
length = v*2 - (265*2 - 11)
n = 1
case v < 273:
length = v*4 - (269*4 - 19)
n = 2
case v < 277:
length = v*8 - (273*8 - 35)
n = 3
case v < 281:
length = v*16 - (277*16 - 67)
n = 4
case v < 285:
length = v*32 - (281*32 - 131)
n = 5
case v < maxNumLit:
length = 258
n = 0
default:
f.err = CorruptInputError(f.roffset)
return
}
if n > 0 {
for f.nb < n {
if err = f.moreBits(); err != nil {
f.err = err
return
}
}
length += int(f.b & uint32(1<<n-1))
f.b >>= n
f.nb -= n
}
var dist int
if f.hd == nil {
for f.nb < 5 {
if err = f.moreBits(); err != nil {
f.err = err
return
}
}
dist = int(bits.Reverse8(uint8(f.b & 0x1F << 3)))
f.b >>= 5
f.nb -= 5
} else {
if dist, err = f.huffSym(f.hd); err != nil {
f.err = err
return
}
}
switch {
case dist < 4:
dist++
case dist < maxNumDist:
nb := uint(dist-2) >> 1
// have 1 bit in bottom of dist, need nb more.
extra := (dist & 1) << nb
for f.nb < nb {
if err = f.moreBits(); err != nil {
f.err = err
return
}
}
extra |= int(f.b & uint32(1<<nb-1))
f.b >>= nb
f.nb -= nb
dist = 1<<(nb+1) + 1 + extra
default:
f.err = CorruptInputError(f.roffset)
return
}
// No check on length; encoding can be prescient.
if dist > f.dict.histSize() {
f.err = CorruptInputError(f.roffset)
return
}
f.copyLen, f.copyDist = length, dist
goto copyHistory
}
copyHistory:
// Perform a backwards copy according to RFC section 3.2.3.
{
cnt := f.dict.tryWriteCopy(f.copyDist, f.copyLen)
if cnt == 0 {
cnt = f.dict.writeCopy(f.copyDist, f.copyLen)
}
f.copyLen -= cnt
if f.dict.availWrite() == 0 || f.copyLen > 0 {
f.toRead = f.dict.readFlush()
f.step = (*decompressor).huffmanBlock // We need to continue this work
f.stepState = stateDict
return
}
goto readLiteral
}
}
// Copy a single uncompressed data block from input to output.
func (f *decompressor) dataBlock() {
// Uncompressed.
// Discard current half-byte.
f.nb = 0
f.b = 0
// Length then ones-complement of length.
nr, err := io.ReadFull(f.r, f.buf[0:4])
f.roffset += int64(nr)
if err != nil {
f.err = noEOF(err)
return
}
n := int(f.buf[0]) | int(f.buf[1])<<8
nn := int(f.buf[2]) | int(f.buf[3])<<8
if uint16(nn) != uint16(^n) {
f.err = CorruptInputError(f.roffset)
return
}
if n == 0 {
f.toRead = f.dict.readFlush()
f.finishBlock()
return
}
f.copyLen = n
f.copyData()
}
// copyData copies f.copyLen bytes from the underlying reader into the
// dictionary's history window. It pauses for reads when the window is full.
func (f *decompressor) copyData() {
buf := f.dict.writeSlice()
if len(buf) > f.copyLen {
buf = buf[:f.copyLen]
}
cnt, err := io.ReadFull(f.r, buf)
f.roffset += int64(cnt)
f.copyLen -= cnt
f.dict.writeMark(cnt)
if err != nil {
f.err = noEOF(err)
return
}
if f.dict.availWrite() == 0 || f.copyLen > 0 {
f.toRead = f.dict.readFlush()
f.step = (*decompressor).copyData
return
}
f.finishBlock()
}
func (f *decompressor) finishBlock() {
if f.final {
if f.dict.availRead() > 0 {
f.toRead = f.dict.readFlush()
}
f.err = io.EOF
}
f.step = (*decompressor).nextBlock
}
// noEOF returns err, unless err == io.EOF, in which case it returns io.ErrUnexpectedEOF.
func noEOF(e error) error {
if e == io.EOF {
return io.ErrUnexpectedEOF
}
return e
}
func (f *decompressor) moreBits() error {
c, err := f.r.ReadByte()
if err != nil {
return noEOF(err)
}
f.roffset++
f.b |= uint32(c) << f.nb
f.nb += 8
return nil
}
// Read the next Huffman-encoded symbol from f according to h.
func (f *decompressor) huffSym(h *huffmanDecoder) (int, error) {
// Since a huffmanDecoder can be empty or be composed of a degenerate tree
// with a single element, huffSym must error on these two edge cases. In both
// cases, the chunks slice will be 0 for the invalid sequence, leading it to
// satisfy the n == 0 check below.
n := uint(h.min)
// Optimization. Compiler isn't smart enough to keep f.b,f.nb in registers,
// but is smart enough to keep local variables in registers, so use nb and b,
// inline call to moreBits and reassign b,nb back to f on return.
nb, b := f.nb, f.b
for {
for nb < n {
c, err := f.r.ReadByte()
if err != nil {
f.b = b
f.nb = nb
return 0, noEOF(err)
}
f.roffset++
b |= uint32(c) << (nb & 31)
nb += 8
}
chunk := h.chunks[b&(huffmanNumChunks-1)]
n = uint(chunk & huffmanCountMask)
if n > huffmanChunkBits {
chunk = h.links[chunk>>huffmanValueShift][(b>>huffmanChunkBits)&h.linkMask]
n = uint(chunk & huffmanCountMask)
}
if n <= nb {
if n == 0 {
f.b = b
f.nb = nb
f.err = CorruptInputError(f.roffset)
return 0, f.err
}
f.b = b >> (n & 31)
f.nb = nb - n
return int(chunk >> huffmanValueShift), nil
}
}
}
func makeReader(r io.Reader) Reader {
if rr, ok := r.(Reader); ok {
return rr
}
return bufio.NewReader(r)
}
func fixedHuffmanDecoderInit() {
fixedOnce.Do(func() {
// These come from the RFC section 3.2.6.
var bits [288]int
for i := 0; i < 144; i++ {
bits[i] = 8
}
for i := 144; i < 256; i++ {
bits[i] = 9
}
for i := 256; i < 280; i++ {
bits[i] = 7
}
for i := 280; i < 288; i++ {
bits[i] = 8
}
fixedHuffmanDecoder.init(bits[:])
})
}
func (f *decompressor) Reset(r io.Reader, dict []byte) error {
*f = decompressor{
r: makeReader(r),
bits: f.bits,
codebits: f.codebits,
dict: f.dict,
step: (*decompressor).nextBlock,
}
f.dict.init(maxMatchOffset, dict)
return nil
}
// NewReader returns a new ReadCloser that can be used
// to read the uncompressed version of r.
// If r does not also implement io.ByteReader,
// the decompressor may read more data than necessary from r.
// The reader returns io.EOF after the final block in the DEFLATE stream has
// been encountered. Any trailing data after the final block is ignored.
//
// The ReadCloser returned by NewReader also implements Resetter.
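//
// A minimal usage sketch (src and dst are assumed placeholders for a
// reader of DEFLATE data and an output writer):
//
// zr := NewReader(src)
// if _, err := io.Copy(dst, zr); err != nil {
//     // handle the error
// }
// zr.Close()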
func NewReader(r io.Reader) io.ReadCloser {
fixedHuffmanDecoderInit()
var f decompressor
f.r = makeReader(r)
f.bits = new([maxNumLit + maxNumDist]int)
f.codebits = new([numCodes]int)
f.step = (*decompressor).nextBlock
f.dict.init(maxMatchOffset, nil)
return &f
}
// NewReaderDict is like NewReader but initializes the reader
// with a preset dictionary. The returned Reader behaves as if
// the uncompressed data stream started with the given dictionary,
// which has already been read. NewReaderDict is typically used
// to read data compressed by NewWriterDict.
//
// The ReadCloser returned by NewReader also implements Resetter.
func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser {
fixedHuffmanDecoderInit()
var f decompressor
f.r = makeReader(r)
f.bits = new([maxNumLit + maxNumDist]int)
f.codebits = new([numCodes]int)
f.step = (*decompressor).nextBlock
f.dict.init(maxMatchOffset, dict)
return &f
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package flate
const (
// 2 bits:  type: 0 = literal, 1 = EOF, 2 = match, 3 = unused
// 8 bits:  xlength = length - MIN_MATCH_LENGTH
// 22 bits: xoffset = offset - MIN_OFFSET_SIZE, or literal
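//
// For example (hand-packed; baseMatchLength = 3 and baseMatchOffset = 1
// are assumed from the encoder): matchToken(5, 100) yields
// matchType + 5<<lengthShift + 100 = 0x41400064, from which length()
// recovers 5 and offset() recovers 100, i.e. an 8-byte match at
// distance 101.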
lengthShift = 22
offsetMask = 1<<lengthShift - 1
typeMask = 3 << 30
literalType = 0 << 30
matchType = 1 << 30
)
// The length code for length X (MIN_MATCH_LENGTH <= X <= MAX_MATCH_LENGTH)
// is lengthCodes[length - MIN_MATCH_LENGTH]
var lengthCodes = [...]uint32{
0, 1, 2, 3, 4, 5, 6, 7, 8, 8,
9, 9, 10, 10, 11, 11, 12, 12, 12, 12,
13, 13, 13, 13, 14, 14, 14, 14, 15, 15,
15, 15, 16, 16, 16, 16, 16, 16, 16, 16,
17, 17, 17, 17, 17, 17, 17, 17, 18, 18,
18, 18, 18, 18, 18, 18, 19, 19, 19, 19,
19, 19, 19, 19, 20, 20, 20, 20, 20, 20,
20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
21, 21, 21, 21, 21, 21, 22, 22, 22, 22,
22, 22, 22, 22, 22, 22, 22, 22, 22, 22,
22, 22, 23, 23, 23, 23, 23, 23, 23, 23,
23, 23, 23, 23, 23, 23, 23, 23, 24, 24,
24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
25, 25, 26, 26, 26, 26, 26, 26, 26, 26,
26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
26, 26, 26, 26, 27, 27, 27, 27, 27, 27,
27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
27, 27, 27, 27, 27, 28,
}
var offsetCodes = [...]uint32{
0, 1, 2, 3, 4, 4, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7,
8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9,
10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11,
12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
}
type token uint32
// Convert a literal into a literal token.
func literalToken(literal uint32) token { return token(literalType + literal) }
// Convert an (xlength, xoffset) pair into a match token.
func matchToken(xlength uint32, xoffset uint32) token {
return token(matchType + xlength<<lengthShift + xoffset)
}
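// Worked example (derived from the constants above): matchToken(10, 300)
// packs as 1<<30 | 10<<22 | 300 == 0x4280012C, from which length() and
// offset() recover 10 and 300 respectively.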
// Returns the literal of a literal token.
func (t token) literal() uint32 { return uint32(t - literalType) }
// Returns the extra offset of a match token.
func (t token) offset() uint32 { return uint32(t) & offsetMask }
// Returns the length of a match token.
func (t token) length() uint32 { return uint32((t - matchType) >> lengthShift) }
// Returns the length code for an xlength (length - MIN_MATCH_LENGTH).
func lengthCode(len uint32) uint32 { return lengthCodes[len] }
// Returns the offset code corresponding to a specific offset.
func offsetCode(off uint32) uint32 {
if off < uint32(len(offsetCodes)) {
return offsetCodes[off]
}
if off>>7 < uint32(len(offsetCodes)) {
return offsetCodes[off>>7] + 14
}
return offsetCodes[off>>14] + 28
}
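// Worked examples (derived from the tables above): offsetCode(5) == 4 by
// direct lookup, while offsetCode(30000) falls through to the second tier:
// offsetCodes[30000>>7] + 14 == offsetCodes[234] + 14 == 15 + 14 == 29.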
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package gzip implements reading and writing of gzip format compressed files,
// as specified in RFC 1952.
package gzip
import (
"bufio"
"compress/flate"
"encoding/binary"
"errors"
"hash/crc32"
"io"
"time"
)
const (
gzipID1 = 0x1f
gzipID2 = 0x8b
gzipDeflate = 8
flagText = 1 << 0
flagHdrCrc = 1 << 1
flagExtra = 1 << 2
flagName = 1 << 3
flagComment = 1 << 4
)
var (
// ErrChecksum is returned when reading GZIP data that has an invalid checksum.
ErrChecksum = errors.New("gzip: invalid checksum")
// ErrHeader is returned when reading GZIP data that has an invalid header.
ErrHeader = errors.New("gzip: invalid header")
)
var le = binary.LittleEndian
// noEOF converts io.EOF to io.ErrUnexpectedEOF.
func noEOF(err error) error {
if err == io.EOF {
return io.ErrUnexpectedEOF
}
return err
}
// The gzip file stores a header giving metadata about the compressed file.
// That header is exposed as the fields of the Writer and Reader structs.
//
// Strings must be UTF-8 encoded and may only contain Unicode code points
// U+0001 through U+00FF, due to limitations of the GZIP file format.
type Header struct {
Comment string // comment
Extra []byte // "extra data"
ModTime time.Time // modification time
Name string // file name
OS byte // operating system type
}
// A Reader is an io.Reader that can be read to retrieve
// uncompressed data from a gzip-format compressed file.
//
// In general, a gzip file can be a concatenation of gzip files,
// each with its own header. Reads from the Reader
// return the concatenation of the uncompressed data of each.
// Only the first header is recorded in the Reader fields.
//
// Gzip files store a length and checksum of the uncompressed data.
// The Reader will return an ErrChecksum when Read
// reaches the end of the uncompressed data if it does not
// have the expected length or checksum. Clients should treat data
// returned by Read as tentative until they receive the io.EOF
// marking the end of the data.
type Reader struct {
Header // valid after NewReader or Reader.Reset
r flate.Reader
decompressor io.ReadCloser
digest uint32 // CRC-32, IEEE polynomial (section 8)
size uint32 // Uncompressed size (section 2.3.1)
buf [512]byte
err error
multistream bool
}
// NewReader creates a new Reader reading the given reader.
// If r does not also implement io.ByteReader,
// the decompressor may read more data than necessary from r.
//
// It is the caller's responsibility to call Close on the Reader when done.
//
// The Reader.Header fields will be valid in the Reader returned.
func NewReader(r io.Reader) (*Reader, error) {
z := new(Reader)
if err := z.Reset(r); err != nil {
return nil, err
}
return z, nil
}
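// A minimal usage sketch (compressedFile is an assumed io.Reader): the
// checksum and length stored in the gzip trailer are only verified once
// Read reaches io.EOF.
//
//	zr, err := gzip.NewReader(compressedFile)
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer zr.Close()
//	if _, err := io.Copy(os.Stdout, zr); err != nil {
//		log.Fatal(err)
//	}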
// Reset discards the Reader z's state and makes it equivalent to the
// result of its original state from NewReader, but reading from r instead.
// This permits reusing a Reader rather than allocating a new one.
func (z *Reader) Reset(r io.Reader) error {
*z = Reader{
decompressor: z.decompressor,
multistream: true,
}
if rr, ok := r.(flate.Reader); ok {
z.r = rr
} else {
z.r = bufio.NewReader(r)
}
z.Header, z.err = z.readHeader()
return z.err
}
// Multistream controls whether the reader supports multistream files.
//
// If enabled (the default), the Reader expects the input to be a sequence
// of individually gzipped data streams, each with its own header and
// trailer, ending at EOF. The effect is that the concatenation of a sequence
// of gzipped files is treated as equivalent to the gzip of the concatenation
// of the sequence. This is standard behavior for gzip readers.
//
// Calling Multistream(false) disables this behavior; disabling the behavior
// can be useful when reading file formats that distinguish individual gzip
// data streams or mix gzip data streams with other data streams.
// In this mode, when the Reader reaches the end of the data stream,
// Read returns io.EOF. The underlying reader must implement io.ByteReader
// in order to be left positioned just after the gzip stream.
// To start the next stream, call z.Reset(r) followed by z.Multistream(false).
// If there is no next stream, z.Reset(r) will return io.EOF.
func (z *Reader) Multistream(ok bool) {
z.multistream = ok
}
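// A sketch of the per-stream pattern described above (br is an assumed
// *bufio.Reader, so the Reader is left positioned just after each member):
//
//	zr, err := gzip.NewReader(br)
//	if err != nil {
//		log.Fatal(err)
//	}
//	for {
//		zr.Multistream(false)
//		if _, err := io.Copy(io.Discard, zr); err != nil {
//			log.Fatal(err)
//		}
//		err = zr.Reset(br)
//		if err == io.EOF {
//			break // no further streams
//		}
//		if err != nil {
//			log.Fatal(err)
//		}
//	}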
// readString reads a NUL-terminated string from z.r.
// It treats the bytes read as being encoded as ISO 8859-1 (Latin-1) and
// will output a string encoded using UTF-8.
// This method always updates z.digest with the data read.
func (z *Reader) readString() (string, error) {
var err error
needConv := false
for i := 0; ; i++ {
if i >= len(z.buf) {
return "", ErrHeader
}
z.buf[i], err = z.r.ReadByte()
if err != nil {
return "", err
}
if z.buf[i] > 0x7f {
needConv = true
}
if z.buf[i] == 0 {
// Digest covers the NUL terminator.
z.digest = crc32.Update(z.digest, crc32.IEEETable, z.buf[:i+1])
// Strings are ISO 8859-1, Latin-1 (RFC 1952, section 2.3.1).
if needConv {
s := make([]rune, 0, i)
for _, v := range z.buf[:i] {
s = append(s, rune(v))
}
return string(s), nil
}
return string(z.buf[:i]), nil
}
}
}
// readHeader reads the GZIP header according to section 2.3.1.
// This method does not set z.err.
func (z *Reader) readHeader() (hdr Header, err error) {
if _, err = io.ReadFull(z.r, z.buf[:10]); err != nil {
// RFC 1952, section 2.2, says the following:
// A gzip file consists of a series of "members" (compressed data sets).
//
// Other than this, the specification does not clarify whether a
// "series" is defined as "one or more" or "zero or more". To err on the
// side of caution, Go interprets this to mean "zero or more".
// Thus, it is okay to return io.EOF here.
return hdr, err
}
if z.buf[0] != gzipID1 || z.buf[1] != gzipID2 || z.buf[2] != gzipDeflate {
return hdr, ErrHeader
}
flg := z.buf[3]
if t := int64(le.Uint32(z.buf[4:8])); t > 0 {
// Section 2.3.1, the zero value for MTIME means that the
// modified time is not set.
hdr.ModTime = time.Unix(t, 0)
}
// z.buf[8] is XFL and is currently ignored.
hdr.OS = z.buf[9]
z.digest = crc32.ChecksumIEEE(z.buf[:10])
if flg&flagExtra != 0 {
if _, err = io.ReadFull(z.r, z.buf[:2]); err != nil {
return hdr, noEOF(err)
}
z.digest = crc32.Update(z.digest, crc32.IEEETable, z.buf[:2])
data := make([]byte, le.Uint16(z.buf[:2]))
if _, err = io.ReadFull(z.r, data); err != nil {
return hdr, noEOF(err)
}
z.digest = crc32.Update(z.digest, crc32.IEEETable, data)
hdr.Extra = data
}
var s string
if flg&flagName != 0 {
if s, err = z.readString(); err != nil {
return hdr, noEOF(err)
}
hdr.Name = s
}
if flg&flagComment != 0 {
if s, err = z.readString(); err != nil {
return hdr, noEOF(err)
}
hdr.Comment = s
}
if flg&flagHdrCrc != 0 {
if _, err = io.ReadFull(z.r, z.buf[:2]); err != nil {
return hdr, noEOF(err)
}
digest := le.Uint16(z.buf[:2])
if digest != uint16(z.digest) {
return hdr, ErrHeader
}
}
z.digest = 0
if z.decompressor == nil {
z.decompressor = flate.NewReader(z.r)
} else {
z.decompressor.(flate.Resetter).Reset(z.r, nil)
}
return hdr, nil
}
// Read implements io.Reader, reading uncompressed bytes from its underlying Reader.
func (z *Reader) Read(p []byte) (n int, err error) {
if z.err != nil {
return 0, z.err
}
for n == 0 {
n, z.err = z.decompressor.Read(p)
z.digest = crc32.Update(z.digest, crc32.IEEETable, p[:n])
z.size += uint32(n)
if z.err != io.EOF {
// In the normal case we return here.
return n, z.err
}
// Finished file; check checksum and size.
if _, err := io.ReadFull(z.r, z.buf[:8]); err != nil {
z.err = noEOF(err)
return n, z.err
}
digest := le.Uint32(z.buf[:4])
size := le.Uint32(z.buf[4:8])
if digest != z.digest || size != z.size {
z.err = ErrChecksum
return n, z.err
}
z.digest, z.size = 0, 0
// File is ok; check if there is another.
if !z.multistream {
return n, io.EOF
}
z.err = nil // Remove io.EOF
if _, z.err = z.readHeader(); z.err != nil {
return n, z.err
}
}
return n, nil
}
// Close closes the Reader. It does not close the underlying io.Reader.
// In order for the GZIP checksum to be verified, the reader must be
// fully consumed until the io.EOF.
func (z *Reader) Close() error { return z.decompressor.Close() }
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gzip
import (
"compress/flate"
"errors"
"fmt"
"hash/crc32"
"io"
"time"
)
// These constants are copied from the flate package, so that code that imports
// "compress/gzip" does not also have to import "compress/flate".
const (
NoCompression = flate.NoCompression
BestSpeed = flate.BestSpeed
BestCompression = flate.BestCompression
DefaultCompression = flate.DefaultCompression
HuffmanOnly = flate.HuffmanOnly
)
// A Writer is an io.WriteCloser.
// Writes to a Writer are compressed and written to w.
type Writer struct {
Header // written at first call to Write, Flush, or Close
w io.Writer
level int
wroteHeader bool
compressor *flate.Writer
digest uint32 // CRC-32, IEEE polynomial (section 8)
size uint32 // Uncompressed size (section 2.3.1)
closed bool
buf [10]byte
err error
}
// NewWriter returns a new Writer.
// Writes to the returned writer are compressed and written to w.
//
// It is the caller's responsibility to call Close on the Writer when done.
// Writes may be buffered and not flushed until Close.
//
// Callers that wish to set the fields in Writer.Header must do so before
// the first call to Write, Flush, or Close.
func NewWriter(w io.Writer) *Writer {
z, _ := NewWriterLevel(w, DefaultCompression)
return z
}
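// A minimal usage sketch (buf is an assumed bytes.Buffer): header fields are
// set before the first Write, and Close writes the CRC-32 and size trailer.
//
//	zw := gzip.NewWriter(&buf)
//	zw.Name = "hello.txt"
//	zw.ModTime = time.Now()
//	if _, err := zw.Write([]byte("hello, world\n")); err != nil {
//		log.Fatal(err)
//	}
//	if err := zw.Close(); err != nil {
//		log.Fatal(err)
//	}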
// NewWriterLevel is like NewWriter but specifies the compression level instead
// of assuming DefaultCompression.
//
// The compression level can be DefaultCompression, NoCompression, HuffmanOnly
// or any integer value between BestSpeed and BestCompression inclusive.
// The error returned will be nil if the level is valid.
func NewWriterLevel(w io.Writer, level int) (*Writer, error) {
if level < HuffmanOnly || level > BestCompression {
return nil, fmt.Errorf("gzip: invalid compression level: %d", level)
}
z := new(Writer)
z.init(w, level)
return z, nil
}
func (z *Writer) init(w io.Writer, level int) {
compressor := z.compressor
if compressor != nil {
compressor.Reset(w)
}
*z = Writer{
Header: Header{
OS: 255, // unknown
},
w: w,
level: level,
compressor: compressor,
}
}
// Reset discards the Writer z's state and makes it equivalent to the
// result of its original state from NewWriter or NewWriterLevel, but
// writing to w instead. This permits reusing a Writer rather than
// allocating a new one.
func (z *Writer) Reset(w io.Writer) {
z.init(w, z.level)
}
// writeBytes writes a length-prefixed byte slice to z.w.
func (z *Writer) writeBytes(b []byte) error {
if len(b) > 0xffff {
return errors.New("gzip.Write: Extra data is too large")
}
le.PutUint16(z.buf[:2], uint16(len(b)))
_, err := z.w.Write(z.buf[:2])
if err != nil {
return err
}
_, err = z.w.Write(b)
return err
}
// writeString writes a UTF-8 string s in GZIP's format to z.w.
// GZIP (RFC 1952) specifies that strings are NUL-terminated ISO 8859-1 (Latin-1).
func (z *Writer) writeString(s string) (err error) {
// GZIP stores Latin-1 strings; error if non-Latin-1; convert if non-ASCII.
needconv := false
for _, v := range s {
if v == 0 || v > 0xff {
return errors.New("gzip.Write: non-Latin-1 header string")
}
if v > 0x7f {
needconv = true
}
}
if needconv {
b := make([]byte, 0, len(s))
for _, v := range s {
b = append(b, byte(v))
}
_, err = z.w.Write(b)
} else {
_, err = io.WriteString(z.w, s)
}
if err != nil {
return err
}
// GZIP strings are NUL-terminated.
z.buf[0] = 0
_, err = z.w.Write(z.buf[:1])
return err
}
// Write writes a compressed form of p to the underlying io.Writer. The
// compressed bytes are not necessarily flushed until the Writer is closed.
func (z *Writer) Write(p []byte) (int, error) {
if z.err != nil {
return 0, z.err
}
var n int
// Write the GZIP header lazily.
if !z.wroteHeader {
z.wroteHeader = true
z.buf = [10]byte{0: gzipID1, 1: gzipID2, 2: gzipDeflate}
if z.Extra != nil {
z.buf[3] |= flagExtra
}
if z.Name != "" {
z.buf[3] |= flagName
}
if z.Comment != "" {
z.buf[3] |= flagComment
}
if z.ModTime.After(time.Unix(0, 0)) {
// Section 2.3.1, the zero value for MTIME means that the
// modified time is not set.
le.PutUint32(z.buf[4:8], uint32(z.ModTime.Unix()))
}
if z.level == BestCompression {
z.buf[8] = 2
} else if z.level == BestSpeed {
z.buf[8] = 4
}
z.buf[9] = z.OS
_, z.err = z.w.Write(z.buf[:10])
if z.err != nil {
return 0, z.err
}
if z.Extra != nil {
z.err = z.writeBytes(z.Extra)
if z.err != nil {
return 0, z.err
}
}
if z.Name != "" {
z.err = z.writeString(z.Name)
if z.err != nil {
return 0, z.err
}
}
if z.Comment != "" {
z.err = z.writeString(z.Comment)
if z.err != nil {
return 0, z.err
}
}
if z.compressor == nil {
z.compressor, _ = flate.NewWriter(z.w, z.level)
}
}
z.size += uint32(len(p))
z.digest = crc32.Update(z.digest, crc32.IEEETable, p)
n, z.err = z.compressor.Write(p)
return n, z.err
}
// Flush flushes any pending compressed data to the underlying writer.
//
// It is useful mainly in compressed network protocols, to ensure that
// a remote reader has enough data to reconstruct a packet. Flush does
// not return until the data has been written. If the underlying
// writer returns an error, Flush returns that error.
//
// In the terminology of the zlib library, Flush is equivalent to Z_SYNC_FLUSH.
func (z *Writer) Flush() error {
if z.err != nil {
return z.err
}
if z.closed {
return nil
}
if !z.wroteHeader {
z.Write(nil)
if z.err != nil {
return z.err
}
}
z.err = z.compressor.Flush()
return z.err
}
// Close closes the Writer by flushing any unwritten data to the underlying
// io.Writer and writing the GZIP footer.
// It does not close the underlying io.Writer.
func (z *Writer) Close() error {
if z.err != nil {
return z.err
}
if z.closed {
return nil
}
z.closed = true
if !z.wroteHeader {
z.Write(nil)
if z.err != nil {
return z.err
}
}
z.err = z.compressor.Close()
if z.err != nil {
return z.err
}
le.PutUint32(z.buf[:4], z.digest)
le.PutUint32(z.buf[4:8], z.size)
_, z.err = z.w.Write(z.buf[:8])
return z.err
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package lzw implements the Lempel-Ziv-Welch compressed data format,
// described in T. A. Welch, “A Technique for High-Performance Data
// Compression”, Computer, 17(6) (June 1984), pp 8-19.
//
// In particular, it implements LZW as used by the GIF and PDF file
// formats, which means variable-width codes up to 12 bits and the first
// two non-literal codes are a clear code and an EOF code.
//
// The TIFF file format uses a similar but incompatible version of the LZW
// algorithm. See the golang.org/x/image/tiff/lzw package for an
// implementation.
package lzw
// TODO(nigeltao): check that PDF uses LZW in the same way as GIF,
// modulo LSB/MSB packing order.
import (
"bufio"
"errors"
"fmt"
"io"
)
// Order specifies the bit ordering in an LZW data stream.
type Order int
const (
// LSB means Least Significant Bits first, as used in the GIF file format.
LSB Order = iota
// MSB means Most Significant Bits first, as used in the TIFF and PDF
// file formats.
MSB
)
const (
maxWidth = 12
decoderInvalidCode = 0xffff
flushBuffer = 1 << maxWidth
)
// Reader is an io.Reader which can be used to read compressed data in the
// LZW format.
type Reader struct {
r io.ByteReader
bits uint32
nBits uint
width uint
read func(*Reader) (uint16, error) // readLSB or readMSB
litWidth int // width in bits of literal codes
err error
// The first 1<<litWidth codes are literal codes.
// The next two codes mean clear and EOF.
// Other valid codes are in the range [lo, hi] where lo := clear + 2,
// with the upper bound incrementing on each code seen.
//
// overflow is the code at which hi overflows the code width. It always
// equals 1 << width.
//
// last is the most recently seen code, or decoderInvalidCode.
//
// An invariant is that hi < overflow.
clear, eof, hi, overflow, last uint16
// Each code c in [lo, hi] expands to two or more bytes. For c != hi:
// suffix[c] is the last of these bytes.
// prefix[c] is the code for all but the last byte.
// This code can either be a literal code or another code in [lo, c).
// The c == hi case is a special case.
suffix [1 << maxWidth]uint8
prefix [1 << maxWidth]uint16
// output is the temporary output buffer.
// Literal codes are accumulated from the start of the buffer.
// Non-literal codes decode to a sequence of suffixes that are first
// written right-to-left from the end of the buffer before being copied
// to the start of the buffer.
// It is flushed when it contains >= 1<<maxWidth bytes,
// so that there is always room to decode an entire code.
output [2 * 1 << maxWidth]byte
o int // write index into output
toRead []byte // bytes to return from Read
}
// readLSB returns the next code for "Least Significant Bits first" data.
func (r *Reader) readLSB() (uint16, error) {
for r.nBits < r.width {
x, err := r.r.ReadByte()
if err != nil {
return 0, err
}
r.bits |= uint32(x) << r.nBits
r.nBits += 8
}
code := uint16(r.bits & (1<<r.width - 1))
r.bits >>= r.width
r.nBits -= r.width
return code, nil
}
// readMSB returns the next code for "Most Significant Bits first" data.
func (r *Reader) readMSB() (uint16, error) {
for r.nBits < r.width {
x, err := r.r.ReadByte()
if err != nil {
return 0, err
}
r.bits |= uint32(x) << (24 - r.nBits)
r.nBits += 8
}
code := uint16(r.bits >> (32 - r.width))
r.bits <<= r.width
r.nBits -= r.width
return code, nil
}
// Read implements io.Reader, reading uncompressed bytes from its underlying Reader.
func (r *Reader) Read(b []byte) (int, error) {
for {
if len(r.toRead) > 0 {
n := copy(b, r.toRead)
r.toRead = r.toRead[n:]
return n, nil
}
if r.err != nil {
return 0, r.err
}
r.decode()
}
}
// decode decompresses bytes from r.r and leaves them in r.toRead.
// r.read specifies how to decode bytes into codes.
// r.litWidth is the width in bits of literal codes.
func (r *Reader) decode() {
// Loop over the code stream, converting codes into decompressed bytes.
loop:
for {
code, err := r.read(r)
if err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
r.err = err
break
}
switch {
case code < r.clear:
// We have a literal code.
r.output[r.o] = uint8(code)
r.o++
if r.last != decoderInvalidCode {
// Save what the hi code expands to.
r.suffix[r.hi] = uint8(code)
r.prefix[r.hi] = r.last
}
case code == r.clear:
r.width = 1 + uint(r.litWidth)
r.hi = r.eof
r.overflow = 1 << r.width
r.last = decoderInvalidCode
continue
case code == r.eof:
r.err = io.EOF
break loop
case code <= r.hi:
c, i := code, len(r.output)-1
if code == r.hi && r.last != decoderInvalidCode {
// code == hi is a special case which expands to the last expansion
// followed by the head of the last expansion. To find the head, we walk
// the prefix chain until we find a literal code.
c = r.last
for c >= r.clear {
c = r.prefix[c]
}
r.output[i] = uint8(c)
i--
c = r.last
}
// Copy the suffix chain into the end of output, then move that
// expansion to the accumulation area starting at r.o.
for c >= r.clear {
r.output[i] = r.suffix[c]
i--
c = r.prefix[c]
}
r.output[i] = uint8(c)
r.o += copy(r.output[r.o:], r.output[i:])
if r.last != decoderInvalidCode {
// Save what the hi code expands to.
r.suffix[r.hi] = uint8(c)
r.prefix[r.hi] = r.last
}
default:
r.err = errors.New("lzw: invalid code")
break loop
}
r.last, r.hi = code, r.hi+1
if r.hi >= r.overflow {
if r.hi > r.overflow {
panic("unreachable")
}
if r.width == maxWidth {
r.last = decoderInvalidCode
// Undo the r.hi++ a few lines above, so that (1) we maintain
// the invariant that r.hi < r.overflow, and (2) r.hi does not
// eventually overflow a uint16.
r.hi--
} else {
r.width++
r.overflow = 1 << r.width
}
}
if r.o >= flushBuffer {
break
}
}
// Flush pending output.
r.toRead = r.output[:r.o]
r.o = 0
}
var errClosed = errors.New("lzw: reader/writer is closed")
// Close closes the Reader and returns an error for any future read operation.
// It does not close the underlying io.Reader.
func (r *Reader) Close() error {
r.err = errClosed // in case any Reads come along
return nil
}
// Reset clears the Reader's state and allows it to be reused
// as a new Reader.
func (r *Reader) Reset(src io.Reader, order Order, litWidth int) {
*r = Reader{}
r.init(src, order, litWidth)
}
// NewReader creates a new io.ReadCloser.
// Reads from the returned io.ReadCloser read and decompress data from r.
// If r does not also implement io.ByteReader,
// the decompressor may read more data than necessary from r.
// It is the caller's responsibility to call Close on the ReadCloser when
// finished reading.
// The number of bits to use for literal codes, litWidth, must be in the
// range [2,8] and is typically 8. It must equal the litWidth
// used during compression.
//
// It is guaranteed that the underlying type of the returned io.ReadCloser
// is a *Reader.
func NewReader(r io.Reader, order Order, litWidth int) io.ReadCloser {
return newReader(r, order, litWidth)
}
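// A round-trip sketch (illustrative): the reader's order and litWidth must
// match the values used by the writer.
//
//	var buf bytes.Buffer
//	w := lzw.NewWriter(&buf, lzw.LSB, 8)
//	w.Write([]byte("TOBEORNOTTOBEORTOBEORNOT"))
//	w.Close()
//	r := lzw.NewReader(&buf, lzw.LSB, 8)
//	plain, err := io.ReadAll(r)
//	r.Close()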
func newReader(src io.Reader, order Order, litWidth int) *Reader {
r := new(Reader)
r.init(src, order, litWidth)
return r
}
func (r *Reader) init(src io.Reader, order Order, litWidth int) {
switch order {
case LSB:
r.read = (*Reader).readLSB
case MSB:
r.read = (*Reader).readMSB
default:
r.err = errors.New("lzw: unknown order")
return
}
if litWidth < 2 || 8 < litWidth {
r.err = fmt.Errorf("lzw: litWidth %d out of range", litWidth)
return
}
br, ok := src.(io.ByteReader)
if !ok && src != nil {
br = bufio.NewReader(src)
}
r.r = br
r.litWidth = litWidth
r.width = 1 + uint(litWidth)
r.clear = uint16(1) << uint(litWidth)
r.eof, r.hi = r.clear+1, r.clear+1
r.overflow = uint16(1) << r.width
r.last = decoderInvalidCode
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package lzw
import (
"bufio"
"errors"
"fmt"
"io"
)
// A writer is a buffered, flushable writer.
type writer interface {
io.ByteWriter
Flush() error
}
const (
// A code is a 12 bit value, stored as a uint32 when encoding to avoid
// type conversions when shifting bits.
maxCode = 1<<12 - 1
invalidCode = 1<<32 - 1
// There are 1<<12 possible codes, which is an upper bound on the number of
// valid hash table entries at any given point in time. tableSize is 4x that.
tableSize = 4 * 1 << 12
tableMask = tableSize - 1
// A hash table entry is a uint32. Zero is an invalid entry since the
// lower 12 bits of a valid entry must be a non-literal code.
invalidEntry = 0
)
// Writer is an LZW compressor. It writes the compressed form of the data
// to an underlying writer (see NewWriter).
type Writer struct {
// w is the writer that compressed bytes are written to.
w writer
// order, write, bits, nBits and width are the state for
// converting a code stream into a byte stream.
order Order
write func(*Writer, uint32) error
bits uint32
nBits uint
width uint
// litWidth is the width in bits of literal codes.
litWidth uint
// hi is the code implied by the next code emission.
// overflow is the code at which hi overflows the code width.
hi, overflow uint32
// savedCode is the accumulated code at the end of the most recent Write
// call. It is equal to invalidCode if there was no such call.
savedCode uint32
// err is the first error encountered during writing. Closing the writer
// will make any future Write calls return errClosed
err error
// table is the hash table from 20-bit keys to 12-bit values. Each table
// entry contains key<<12|val and collisions resolve by linear probing.
// The keys consist of a 12-bit code prefix and an 8-bit byte suffix.
// The values are a 12-bit code.
table [tableSize]uint32
}
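// Worked example (derived from the comments above): a prefix code 0x123
// followed by the byte 0x45 forms the 20-bit key 0x123<<8|0x45 == 0x12345;
// storing the 12-bit value 0x678 for it yields the table entry
// 0x12345<<12 | 0x678 == 0x12345678.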
// writeLSB writes the code c for "Least Significant Bits first" data.
func (w *Writer) writeLSB(c uint32) error {
w.bits |= c << w.nBits
w.nBits += w.width
for w.nBits >= 8 {
if err := w.w.WriteByte(uint8(w.bits)); err != nil {
return err
}
w.bits >>= 8
w.nBits -= 8
}
return nil
}
// writeMSB writes the code c for "Most Significant Bits first" data.
func (w *Writer) writeMSB(c uint32) error {
w.bits |= c << (32 - w.width - w.nBits)
w.nBits += w.width
for w.nBits >= 8 {
if err := w.w.WriteByte(uint8(w.bits >> 24)); err != nil {
return err
}
w.bits <<= 8
w.nBits -= 8
}
return nil
}
// errOutOfCodes is an internal error that means that the writer has run out
// of unused codes and a clear code needs to be sent next.
var errOutOfCodes = errors.New("lzw: out of codes")
// incHi increments w.hi and checks for both overflow and running out of
// unused codes. In the latter case, incHi sends a clear code, resets the
// writer state and returns errOutOfCodes.
func (w *Writer) incHi() error {
w.hi++
if w.hi == w.overflow {
w.width++
w.overflow <<= 1
}
if w.hi == maxCode {
clear := uint32(1) << w.litWidth
if err := w.write(w, clear); err != nil {
return err
}
w.width = w.litWidth + 1
w.hi = clear + 1
w.overflow = clear << 1
for i := range w.table {
w.table[i] = invalidEntry
}
return errOutOfCodes
}
return nil
}
// Write writes a compressed representation of p to w's underlying writer.
func (w *Writer) Write(p []byte) (n int, err error) {
if w.err != nil {
return 0, w.err
}
if len(p) == 0 {
return 0, nil
}
if maxLit := uint8(1<<w.litWidth - 1); maxLit != 0xff {
for _, x := range p {
if x > maxLit {
w.err = errors.New("lzw: input byte too large for the litWidth")
return 0, w.err
}
}
}
n = len(p)
code := w.savedCode
if code == invalidCode {
// This is the first write; send a clear code.
// https://www.w3.org/Graphics/GIF/spec-gif89a.txt Appendix F
// "Variable-Length-Code LZW Compression" says that "Encoders should
// output a Clear code as the first code of each image data stream".
//
// LZW compression isn't only used by GIF, but it's cheap to follow
// that directive unconditionally.
clear := uint32(1) << w.litWidth
if err := w.write(w, clear); err != nil {
return 0, err
}
// After the starting clear code, the next code sent (for non-empty
// input) is always a literal code.
code, p = uint32(p[0]), p[1:]
}
loop:
for _, x := range p {
literal := uint32(x)
key := code<<8 | literal
// If there is a hash table hit for this key then we continue the loop
// and do not emit a code yet.
hash := (key>>12 ^ key) & tableMask
for h, t := hash, w.table[hash]; t != invalidEntry; {
if key == t>>12 {
code = t & maxCode
continue loop
}
h = (h + 1) & tableMask
t = w.table[h]
}
// Otherwise, write the current code, and literal becomes the start of
// the next emitted code.
if w.err = w.write(w, code); w.err != nil {
return 0, w.err
}
code = literal
// Increment w.hi, the next implied code. If we run out of codes, reset
// the writer state (including clearing the hash table) and continue.
if err1 := w.incHi(); err1 != nil {
if err1 == errOutOfCodes {
continue
}
w.err = err1
return 0, w.err
}
// Otherwise, insert key -> w.hi into the map that w.table represents.
for {
if w.table[hash] == invalidEntry {
w.table[hash] = (key << 12) | w.hi
break
}
hash = (hash + 1) & tableMask
}
}
w.savedCode = code
return n, nil
}
// Close closes the Writer, flushing any pending output. It does not close
// w's underlying writer.
func (w *Writer) Close() error {
if w.err != nil {
if w.err == errClosed {
return nil
}
return w.err
}
// Make any future calls to Write return errClosed.
w.err = errClosed
// Write the savedCode if valid.
if w.savedCode != invalidCode {
if err := w.write(w, w.savedCode); err != nil {
return err
}
if err := w.incHi(); err != nil && err != errOutOfCodes {
return err
}
} else {
// Write the starting clear code, as w.Write did not.
clear := uint32(1) << w.litWidth
if err := w.write(w, clear); err != nil {
return err
}
}
// Write the eof code.
eof := uint32(1)<<w.litWidth + 1
if err := w.write(w, eof); err != nil {
return err
}
// Write the final bits.
if w.nBits > 0 {
if w.order == MSB {
w.bits >>= 24
}
if err := w.w.WriteByte(uint8(w.bits)); err != nil {
return err
}
}
return w.w.Flush()
}
// Reset clears the Writer's state and allows it to be reused
// as a new Writer.
func (w *Writer) Reset(dst io.Writer, order Order, litWidth int) {
*w = Writer{}
w.init(dst, order, litWidth)
}
// NewWriter creates a new io.WriteCloser.
// Writes to the returned io.WriteCloser are compressed and written to w.
// It is the caller's responsibility to call Close on the WriteCloser when
// finished writing.
// The number of bits to use for literal codes, litWidth, must be in the
// range [2,8] and is typically 8. Input bytes must be less than 1<<litWidth.
//
// It is guaranteed that the underlying type of the returned io.WriteCloser
// is a *Writer.
func NewWriter(w io.Writer, order Order, litWidth int) io.WriteCloser {
return newWriter(w, order, litWidth)
}
func newWriter(dst io.Writer, order Order, litWidth int) *Writer {
w := new(Writer)
w.init(dst, order, litWidth)
return w
}
func (w *Writer) init(dst io.Writer, order Order, litWidth int) {
switch order {
case LSB:
w.write = (*Writer).writeLSB
case MSB:
w.write = (*Writer).writeMSB
default:
w.err = errors.New("lzw: unknown order")
return
}
if litWidth < 2 || 8 < litWidth {
w.err = fmt.Errorf("lzw: litWidth %d out of range", litWidth)
return
}
bw, ok := dst.(writer)
if !ok && dst != nil {
bw = bufio.NewWriter(dst)
}
w.w = bw
lw := uint(litWidth)
w.order = order
w.width = 1 + lw
w.litWidth = lw
w.hi = 1<<lw + 1
w.overflow = 1 << (lw + 1)
w.savedCode = invalidCode
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package zlib implements reading and writing of zlib format compressed data,
as specified in RFC 1950.
The implementation provides filters that uncompress during reading
and compress during writing. For example, to write compressed data
to a buffer:
var b bytes.Buffer
w := zlib.NewWriter(&b)
w.Write([]byte("hello, world\n"))
w.Close()
and to read that data back:
r, err := zlib.NewReader(&b)
io.Copy(os.Stdout, r)
r.Close()
*/
package zlib
import (
"bufio"
"compress/flate"
"encoding/binary"
"errors"
"hash"
"hash/adler32"
"io"
)
const (
zlibDeflate = 8
zlibMaxWindow = 7
)
var (
// ErrChecksum is returned when reading ZLIB data that has an invalid checksum.
ErrChecksum = errors.New("zlib: invalid checksum")
// ErrDictionary is returned when reading ZLIB data that has an invalid dictionary.
ErrDictionary = errors.New("zlib: invalid dictionary")
// ErrHeader is returned when reading ZLIB data that has an invalid header.
ErrHeader = errors.New("zlib: invalid header")
)
type reader struct {
r flate.Reader
decompressor io.ReadCloser
digest hash.Hash32
err error
scratch [4]byte
}
// Resetter resets a ReadCloser returned by NewReader or NewReaderDict
// to switch to a new underlying Reader. This permits reusing a ReadCloser
// instead of allocating a new one.
type Resetter interface {
// Reset discards any buffered data and resets the Resetter as if it was
// newly initialized with the given reader.
Reset(r io.Reader, dict []byte) error
}
// NewReader creates a new ReadCloser.
// Reads from the returned ReadCloser read and decompress data from r.
// If r does not implement io.ByteReader, the decompressor may read more
// data than necessary from r.
// It is the caller's responsibility to call Close on the ReadCloser when done.
//
// The ReadCloser returned by NewReader also implements Resetter.
func NewReader(r io.Reader) (io.ReadCloser, error) {
return NewReaderDict(r, nil)
}
// NewReaderDict is like NewReader but uses a preset dictionary.
// NewReaderDict ignores the dictionary if the compressed data does not refer to it.
// If the compressed data refers to a different dictionary, NewReaderDict returns ErrDictionary.
//
// The ReadCloser returned by NewReaderDict also implements Resetter.
func NewReaderDict(r io.Reader, dict []byte) (io.ReadCloser, error) {
z := new(reader)
err := z.Reset(r, dict)
if err != nil {
return nil, err
}
return z, nil
}
func (z *reader) Read(p []byte) (int, error) {
if z.err != nil {
return 0, z.err
}
var n int
n, z.err = z.decompressor.Read(p)
z.digest.Write(p[0:n])
if z.err != io.EOF {
// In the normal case we return here.
return n, z.err
}
// Finished file; check checksum.
if _, err := io.ReadFull(z.r, z.scratch[0:4]); err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
z.err = err
return n, z.err
}
// ZLIB (RFC 1950) is big-endian, unlike GZIP (RFC 1952).
checksum := binary.BigEndian.Uint32(z.scratch[:4])
if checksum != z.digest.Sum32() {
z.err = ErrChecksum
return n, z.err
}
return n, io.EOF
}
// Close closes the reader. It does not close the wrapped io.Reader
// originally passed to NewReader.
// In order for the ZLIB checksum to be verified, the reader must be
// fully consumed until the io.EOF.
func (z *reader) Close() error {
if z.err != nil && z.err != io.EOF {
return z.err
}
z.err = z.decompressor.Close()
return z.err
}
func (z *reader) Reset(r io.Reader, dict []byte) error {
*z = reader{decompressor: z.decompressor}
if fr, ok := r.(flate.Reader); ok {
z.r = fr
} else {
z.r = bufio.NewReader(r)
}
// Read the header (RFC 1950 section 2.2).
_, z.err = io.ReadFull(z.r, z.scratch[0:2])
if z.err != nil {
if z.err == io.EOF {
z.err = io.ErrUnexpectedEOF
}
return z.err
}
h := binary.BigEndian.Uint16(z.scratch[:2])
if (z.scratch[0]&0x0f != zlibDeflate) || (z.scratch[0]>>4 > zlibMaxWindow) || (h%31 != 0) {
z.err = ErrHeader
return z.err
}
haveDict := z.scratch[1]&0x20 != 0
if haveDict {
_, z.err = io.ReadFull(z.r, z.scratch[0:4])
if z.err != nil {
if z.err == io.EOF {
z.err = io.ErrUnexpectedEOF
}
return z.err
}
checksum := binary.BigEndian.Uint32(z.scratch[:4])
if checksum != adler32.Checksum(dict) {
z.err = ErrDictionary
return z.err
}
}
if z.decompressor == nil {
if haveDict {
z.decompressor = flate.NewReaderDict(z.r, dict)
} else {
z.decompressor = flate.NewReader(z.r)
}
} else {
z.decompressor.(flate.Resetter).Reset(z.r, dict)
}
z.digest = adler32.New()
return nil
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zlib
import (
"compress/flate"
"encoding/binary"
"fmt"
"hash"
"hash/adler32"
"io"
)
// These constants are copied from the flate package, so that code that imports
// "compress/zlib" does not also have to import "compress/flate".
const (
NoCompression = flate.NoCompression
BestSpeed = flate.BestSpeed
BestCompression = flate.BestCompression
DefaultCompression = flate.DefaultCompression
HuffmanOnly = flate.HuffmanOnly
)
// A Writer takes data written to it and writes the compressed
// form of that data to an underlying writer (see NewWriter).
type Writer struct {
w io.Writer
level int
dict []byte
compressor *flate.Writer
digest hash.Hash32
err error
scratch [4]byte
wroteHeader bool
}
// NewWriter creates a new Writer.
// Writes to the returned Writer are compressed and written to w.
//
// It is the caller's responsibility to call Close on the Writer when done.
// Writes may be buffered and not flushed until Close.
func NewWriter(w io.Writer) *Writer {
z, _ := NewWriterLevelDict(w, DefaultCompression, nil)
return z
}
// NewWriterLevel is like NewWriter but specifies the compression level instead
// of assuming DefaultCompression.
//
// The compression level can be DefaultCompression, NoCompression, HuffmanOnly
// or any integer value between BestSpeed and BestCompression inclusive.
// The error returned will be nil if the level is valid.
func NewWriterLevel(w io.Writer, level int) (*Writer, error) {
return NewWriterLevelDict(w, level, nil)
}
// NewWriterLevelDict is like NewWriterLevel but specifies a dictionary to
// compress with.
//
// The dictionary may be nil. If not, its contents should not be modified until
// the Writer is closed.
func NewWriterLevelDict(w io.Writer, level int, dict []byte) (*Writer, error) {
if level < HuffmanOnly || level > BestCompression {
return nil, fmt.Errorf("zlib: invalid compression level: %d", level)
}
return &Writer{
w: w,
level: level,
dict: dict,
}, nil
}
// Reset clears the state of the Writer z such that it is equivalent to its
// initial state from NewWriterLevel or NewWriterLevelDict, but writing
// to w instead.
func (z *Writer) Reset(w io.Writer) {
z.w = w
// z.level and z.dict left unchanged.
if z.compressor != nil {
z.compressor.Reset(w)
}
if z.digest != nil {
z.digest.Reset()
}
z.err = nil
z.scratch = [4]byte{}
z.wroteHeader = false
}
// writeHeader writes the ZLIB header.
func (z *Writer) writeHeader() (err error) {
z.wroteHeader = true
// ZLIB has a two-byte header (as documented in RFC 1950).
// The first four bits are the CINFO (compression info), which is 7 for the default deflate window size.
// The next four bits are the CM (compression method), which is 8 for deflate.
z.scratch[0] = 0x78
// The next two bits are the FLEVEL (compression level). The four values are:
// 0=fastest, 1=fast, 2=default, 3=best.
// The next bit, FDICT, is set if a dictionary is given.
// The final five FCHECK bits form a mod-31 checksum.
switch z.level {
case -2, 0, 1:
z.scratch[1] = 0 << 6
case 2, 3, 4, 5:
z.scratch[1] = 1 << 6
case 6, -1:
z.scratch[1] = 2 << 6
case 7, 8, 9:
z.scratch[1] = 3 << 6
default:
panic("unreachable")
}
if z.dict != nil {
z.scratch[1] |= 1 << 5
}
z.scratch[1] += uint8(31 - binary.BigEndian.Uint16(z.scratch[:2])%31)
if _, err = z.w.Write(z.scratch[0:2]); err != nil {
return err
}
if z.dict != nil {
// The next four bytes are the Adler-32 checksum of the dictionary.
binary.BigEndian.PutUint32(z.scratch[:], adler32.Checksum(z.dict))
if _, err = z.w.Write(z.scratch[0:4]); err != nil {
return err
}
}
if z.compressor == nil {
// Initialize deflater unless the Writer is being reused
// after a Reset call.
z.compressor, err = flate.NewWriterDict(z.w, z.level, z.dict)
if err != nil {
return err
}
z.digest = adler32.New()
}
return nil
}
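// Worked example (derived from the rules above): for the default level with
// no dictionary, scratch[0] is 0x78 and FLEVEL is 2, so scratch[1] starts at
// 0x80. Since 0x7880 % 31 == 3, FCHECK is 31 - 3 == 28 (0x1C), producing the
// familiar two-byte header 0x78 0x9C.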
// Write writes a compressed form of p to the underlying io.Writer. The
// compressed bytes are not necessarily flushed until the Writer is closed or
// explicitly flushed.
func (z *Writer) Write(p []byte) (n int, err error) {
if !z.wroteHeader {
z.err = z.writeHeader()
}
if z.err != nil {
return 0, z.err
}
if len(p) == 0 {
return 0, nil
}
n, err = z.compressor.Write(p)
if err != nil {
z.err = err
return
}
z.digest.Write(p)
return
}
// Flush flushes the Writer to its underlying io.Writer.
func (z *Writer) Flush() error {
if !z.wroteHeader {
z.err = z.writeHeader()
}
if z.err != nil {
return z.err
}
z.err = z.compressor.Flush()
return z.err
}
// Close closes the Writer, flushing any unwritten data to the underlying
// io.Writer, but does not close the underlying io.Writer.
func (z *Writer) Close() error {
if !z.wroteHeader {
z.err = z.writeHeader()
}
if z.err != nil {
return z.err
}
z.err = z.compressor.Close()
if z.err != nil {
return z.err
}
checksum := z.digest.Sum32()
// ZLIB (RFC 1950) is big-endian, unlike GZIP (RFC 1952).
binary.BigEndian.PutUint32(z.scratch[:], checksum)
_, z.err = z.w.Write(z.scratch[0:4])
return z.err
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package heap provides heap operations for any type that implements
// heap.Interface. A heap is a tree with the property that each node is the
// minimum-valued node in its subtree.
//
// The minimum element in the tree is the root, at index 0.
//
// A heap is a common way to implement a priority queue. To build a priority
// queue, implement the Heap interface with the (negative) priority as the
// ordering for the Less method, so Push adds items while Pop removes the
// highest-priority item from the queue. The Examples include such an
// implementation; the file example_pq_test.go has the complete source.
package heap
import "sort"
// The Interface type describes the requirements
// for a type using the routines in this package.
// Any type that implements it may be used as a
// min-heap with the following invariants (established after
// Init has been called or if the data is empty or sorted):
//
// !h.Less(j, i) for 0 <= i < h.Len() and 2*i+1 <= j <= 2*i+2 and j < h.Len()
//
// Note that Push and Pop in this interface are for package heap's
// implementation to call. To add and remove things from the heap,
// use heap.Push and heap.Pop.
type Interface interface {
sort.Interface
Push(x any) // add x as element Len()
Pop() any // remove and return element Len() - 1.
}
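// A minimal Interface implementation (a sketch mirroring the package's own
// IntHeap example):
//
//	type IntHeap []int
//
//	func (h IntHeap) Len() int           { return len(h) }
//	func (h IntHeap) Less(i, j int) bool { return h[i] < h[j] }
//	func (h IntHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
//	func (h *IntHeap) Push(x any)        { *h = append(*h, x.(int)) }
//	func (h *IntHeap) Pop() any {
//		old := *h
//		n := len(old)
//		x := old[n-1]
//		*h = old[:n-1]
//		return x
//	}
//
//	h := &IntHeap{2, 1, 5}
//	heap.Init(h)
//	heap.Push(h, 3)
//	fmt.Println(heap.Pop(h)) // 1, the minimum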
// Init establishes the heap invariants required by the other routines in this package.
// Init is idempotent with respect to the heap invariants
// and may be called whenever the heap invariants may have been invalidated.
// The complexity is O(n) where n = h.Len().
func Init(h Interface) {
// heapify
n := h.Len()
for i := n/2 - 1; i >= 0; i-- {
down(h, i, n)
}
}
// Push pushes the element x onto the heap.
// The complexity is O(log n) where n = h.Len().
func Push(h Interface, x any) {
h.Push(x)
up(h, h.Len()-1)
}
// Pop removes and returns the minimum element (according to Less) from the heap.
// The complexity is O(log n) where n = h.Len().
// Pop is equivalent to Remove(h, 0).
func Pop(h Interface) any {
n := h.Len() - 1
h.Swap(0, n)
down(h, 0, n)
return h.Pop()
}
// Remove removes and returns the element at index i from the heap.
// The complexity is O(log n) where n = h.Len().
func Remove(h Interface, i int) any {
n := h.Len() - 1
if n != i {
h.Swap(i, n)
if !down(h, i, n) {
up(h, i)
}
}
return h.Pop()
}
// Fix re-establishes the heap ordering after the element at index i has changed its value.
// Changing the value of the element at index i and then calling Fix is equivalent to,
// but less expensive than, calling Remove(h, i) followed by a Push of the new value.
// The complexity is O(log n) where n = h.Len().
func Fix(h Interface, i int) {
if !down(h, i, h.Len()) {
up(h, i)
}
}
func up(h Interface, j int) {
for {
i := (j - 1) / 2 // parent
if i == j || !h.Less(j, i) {
break
}
h.Swap(i, j)
j = i
}
}
func down(h Interface, i0, n int) bool {
i := i0
for {
j1 := 2*i + 1
if j1 >= n || j1 < 0 { // j1 < 0 after int overflow
break
}
j := j1 // left child
if j2 := j1 + 1; j2 < n && h.Less(j2, j1) {
j = j2 // = 2*i + 2 // right child
}
if !h.Less(j, i) {
break
}
h.Swap(i, j)
i = j
}
return i > i0
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package list implements a doubly linked list.
//
// To iterate over a list (where l is a *List):
//
// for e := l.Front(); e != nil; e = e.Next() {
// // do something with e.Value
// }
package list
// Element is an element of a linked list.
type Element struct {
// Next and previous pointers in the doubly-linked list of elements.
// To simplify the implementation, internally a list l is implemented
// as a ring, such that &l.root is both the next element of the last
// list element (l.Back()) and the previous element of the first list
// element (l.Front()).
next, prev *Element
// The list to which this element belongs.
list *List
// The value stored with this element.
Value any
}
// Next returns the next list element or nil.
func (e *Element) Next() *Element {
if p := e.next; e.list != nil && p != &e.list.root {
return p
}
return nil
}
// Prev returns the previous list element or nil.
func (e *Element) Prev() *Element {
if p := e.prev; e.list != nil && p != &e.list.root {
return p
}
return nil
}
// List represents a doubly linked list.
// The zero value for List is an empty list ready to use.
type List struct {
root Element // sentinel list element, only &root, root.prev, and root.next are used
len int // current list length excluding (this) sentinel element
}
// Init initializes or clears list l.
func (l *List) Init() *List {
l.root.next = &l.root
l.root.prev = &l.root
l.len = 0
return l
}
// New returns an initialized list.
func New() *List { return new(List).Init() }
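// A minimal usage sketch:
//
//	l := list.New()
//	l.PushBack(1)
//	l.PushBack(2)
//	l.PushFront(0)
//	for e := l.Front(); e != nil; e = e.Next() {
//		fmt.Println(e.Value) // 0, 1, 2
//	}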
// Len returns the number of elements of list l.
// The complexity is O(1).
func (l *List) Len() int { return l.len }
// Front returns the first element of list l or nil if the list is empty.
func (l *List) Front() *Element {
if l.len == 0 {
return nil
}
return l.root.next
}
// Back returns the last element of list l or nil if the list is empty.
func (l *List) Back() *Element {
if l.len == 0 {
return nil
}
return l.root.prev
}
// lazyInit lazily initializes a zero List value.
func (l *List) lazyInit() {
if l.root.next == nil {
l.Init()
}
}
// insert inserts e after at, increments l.len, and returns e.
func (l *List) insert(e, at *Element) *Element {
e.prev = at
e.next = at.next
e.prev.next = e
e.next.prev = e
e.list = l
l.len++
return e
}
// insertValue is a convenience wrapper for insert(&Element{Value: v}, at).
func (l *List) insertValue(v any, at *Element) *Element {
return l.insert(&Element{Value: v}, at)
}
// remove removes e from its list and decrements l.len.
func (l *List) remove(e *Element) {
e.prev.next = e.next
e.next.prev = e.prev
e.next = nil // avoid memory leaks
e.prev = nil // avoid memory leaks
e.list = nil
l.len--
}
// move moves e to next to at.
func (l *List) move(e, at *Element) {
if e == at {
return
}
e.prev.next = e.next
e.next.prev = e.prev
e.prev = at
e.next = at.next
e.prev.next = e
e.next.prev = e
}
// Remove removes e from l if e is an element of list l.
// It returns the element value e.Value.
// The element must not be nil.
func (l *List) Remove(e *Element) any {
if e.list == l {
// if e.list == l, l must have been initialized when e was inserted
// in l or l == nil (e is a zero Element) and l.remove will crash
l.remove(e)
}
return e.Value
}
// PushFront inserts a new element e with value v at the front of list l and returns e.
func (l *List) PushFront(v any) *Element {
l.lazyInit()
return l.insertValue(v, &l.root)
}
// PushBack inserts a new element e with value v at the back of list l and returns e.
func (l *List) PushBack(v any) *Element {
l.lazyInit()
return l.insertValue(v, l.root.prev)
}
// InsertBefore inserts a new element e with value v immediately before mark and returns e.
// If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *List) InsertBefore(v any, mark *Element) *Element {
if mark.list != l {
return nil
}
// see comment in List.Remove about initialization of l
return l.insertValue(v, mark.prev)
}
// InsertAfter inserts a new element e with value v immediately after mark and returns e.
// If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *List) InsertAfter(v any, mark *Element) *Element {
if mark.list != l {
return nil
}
// see comment in List.Remove about initialization of l
return l.insertValue(v, mark)
}
// MoveToFront moves element e to the front of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *List) MoveToFront(e *Element) {
if e.list != l || l.root.next == e {
return
}
// see comment in List.Remove about initialization of l
l.move(e, &l.root)
}
// MoveToBack moves element e to the back of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *List) MoveToBack(e *Element) {
if e.list != l || l.root.prev == e {
return
}
// see comment in List.Remove about initialization of l
l.move(e, l.root.prev)
}
// MoveBefore moves element e to its new position before mark.
// If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *List) MoveBefore(e, mark *Element) {
if e.list != l || e == mark || mark.list != l {
return
}
l.move(e, mark.prev)
}
// MoveAfter moves element e to its new position after mark.
// If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *List) MoveAfter(e, mark *Element) {
if e.list != l || e == mark || mark.list != l {
return
}
l.move(e, mark)
}
// PushBackList inserts a copy of another list at the back of list l.
// The lists l and other may be the same. They must not be nil.
func (l *List) PushBackList(other *List) {
l.lazyInit()
for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() {
l.insertValue(e.Value, l.root.prev)
}
}
// PushFrontList inserts a copy of another list at the front of list l.
// The lists l and other may be the same. They must not be nil.
func (l *List) PushFrontList(other *List) {
l.lazyInit()
for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() {
l.insertValue(e.Value, &l.root)
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package ring implements operations on circular lists.
package ring
// A Ring is an element of a circular list, or ring.
// Rings do not have a beginning or end; a pointer to any ring element
// serves as reference to the entire ring. Empty rings are represented
// as nil Ring pointers. The zero value for a Ring is a one-element
// ring with a nil Value.
type Ring struct {
next, prev *Ring
Value any // for use by client; untouched by this library
}
func (r *Ring) init() *Ring {
r.next = r
r.prev = r
return r
}
// Next returns the next ring element. r must not be empty.
func (r *Ring) Next() *Ring {
if r.next == nil {
return r.init()
}
return r.next
}
// Prev returns the previous ring element. r must not be empty.
func (r *Ring) Prev() *Ring {
if r.next == nil {
return r.init()
}
return r.prev
}
// Move moves n % r.Len() elements backward (n < 0) or forward (n >= 0)
// in the ring and returns that ring element. r must not be empty.
func (r *Ring) Move(n int) *Ring {
if r.next == nil {
return r.init()
}
switch {
case n < 0:
for ; n < 0; n++ {
r = r.prev
}
case n > 0:
for ; n > 0; n-- {
r = r.next
}
}
return r
}
// New creates a ring of n elements.
func New(n int) *Ring {
if n <= 0 {
return nil
}
r := new(Ring)
p := r
for i := 1; i < n; i++ {
p.next = &Ring{prev: p}
p = p.next
}
p.next = r
r.prev = p
return r
}
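// A minimal usage sketch:
//
//	r := ring.New(3)
//	for i := 0; i < r.Len(); i++ {
//		r.Value = i
//		r = r.Next()
//	}
//	r.Do(func(v any) { fmt.Println(v) }) // 0, 1, 2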
// Link connects ring r with ring s such that r.Next()
// becomes s and returns the original value for r.Next().
// r must not be empty.
//
// If r and s point to the same ring, linking
// them removes the elements between r and s from the ring.
// The removed elements form a subring and the result is a
// reference to that subring (if no elements were removed,
// the result is still the original value for r.Next(),
// and not nil).
//
// If r and s point to different rings, linking
// them creates a single ring with the elements of s inserted
// after r. The result points to the element following the
// last element of s after insertion.
func (r *Ring) Link(s *Ring) *Ring {
n := r.Next()
if s != nil {
p := s.Prev()
// Note: Cannot use multiple assignment because
// evaluation order of LHS is not specified.
r.next = s
s.prev = r
n.prev = p
p.next = n
}
return n
}
// Unlink removes n % r.Len() elements from the ring r, starting
// at r.Next(). If n % r.Len() == 0, r remains unchanged.
// The result is the removed subring. r must not be empty.
func (r *Ring) Unlink(n int) *Ring {
if n <= 0 {
return nil
}
return r.Link(r.Move(n + 1))
}
// Len computes the number of elements in ring r.
// It executes in time proportional to the number of elements.
func (r *Ring) Len() int {
n := 0
if r != nil {
n = 1
for p := r.Next(); p != r; p = p.next {
n++
}
}
return n
}
// Do calls function f on each element of the ring, in forward order.
// The behavior of Do is undefined if f changes *r.
func (r *Ring) Do(f func(any)) {
if r != nil {
f(r.Value)
for p := r.Next(); p != r; p = p.next {
f(p.Value)
}
}
}
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package context defines the Context type, which carries deadlines,
// cancellation signals, and other request-scoped values across API boundaries
// and between processes.
//
// Incoming requests to a server should create a Context, and outgoing
// calls to servers should accept a Context. The chain of function
// calls between them must propagate the Context, optionally replacing
// it with a derived Context created using WithCancel, WithDeadline,
// WithTimeout, or WithValue. When a Context is canceled, all
// Contexts derived from it are also canceled.
//
// The WithCancel, WithDeadline, and WithTimeout functions take a
// Context (the parent) and return a derived Context (the child) and a
// CancelFunc. Calling the CancelFunc cancels the child and its
// children, removes the parent's reference to the child, and stops
// any associated timers. Failing to call the CancelFunc leaks the
// child and its children until the parent is canceled or the timer
// fires. The go vet tool checks that CancelFuncs are used on all
// control-flow paths.
//
// The WithCancelCause function returns a CancelCauseFunc, which
// takes an error and records it as the cancellation cause. Calling
// Cause on the canceled context or any of its children retrieves
// the cause. If no cause is specified, Cause(ctx) returns the same
// value as ctx.Err().
//
// Programs that use Contexts should follow these rules to keep interfaces
// consistent across packages and enable static analysis tools to check context
// propagation:
//
// Do not store Contexts inside a struct type; instead, pass a Context
// explicitly to each function that needs it. The Context should be the first
// parameter, typically named ctx:
//
// func DoSomething(ctx context.Context, arg Arg) error {
// // ... use ctx ...
// }
//
// Do not pass a nil Context, even if a function permits it. Pass context.TODO
// if you are unsure about which Context to use.
//
// Use context Values only for request-scoped data that transits processes and
// APIs, not for passing optional parameters to functions.
//
// The same Context may be passed to functions running in different goroutines;
// Contexts are safe for simultaneous use by multiple goroutines.
//
// See https://blog.golang.org/context for example code for a server that uses
// Contexts.
package context
import (
"errors"
"internal/reflectlite"
"sync"
"sync/atomic"
"time"
)
// A Context carries a deadline, a cancellation signal, and other values across
// API boundaries.
//
// Context's methods may be called by multiple goroutines simultaneously.
type Context interface {
// Deadline returns the time when work done on behalf of this context
// should be canceled. Deadline returns ok==false when no deadline is
// set. Successive calls to Deadline return the same results.
Deadline() (deadline time.Time, ok bool)
// Done returns a channel that's closed when work done on behalf of this
// context should be canceled. Done may return nil if this context can
// never be canceled. Successive calls to Done return the same value.
// The close of the Done channel may happen asynchronously,
// after the cancel function returns.
//
// WithCancel arranges for Done to be closed when cancel is called;
// WithDeadline arranges for Done to be closed when the deadline
// expires; WithTimeout arranges for Done to be closed when the timeout
// elapses.
//
// Done is provided for use in select statements:
//
// // Stream generates values with DoSomething and sends them to out
// // until DoSomething returns an error or ctx.Done is closed.
// func Stream(ctx context.Context, out chan<- Value) error {
// for {
// v, err := DoSomething(ctx)
// if err != nil {
// return err
// }
// select {
// case <-ctx.Done():
// return ctx.Err()
// case out <- v:
// }
// }
// }
//
// See https://blog.golang.org/pipelines for more examples of how to use
// a Done channel for cancellation.
Done() <-chan struct{}
// If Done is not yet closed, Err returns nil.
// If Done is closed, Err returns a non-nil error explaining why:
// Canceled if the context was canceled
// or DeadlineExceeded if the context's deadline passed.
// After Err returns a non-nil error, successive calls to Err return the same error.
Err() error
// Value returns the value associated with this context for key, or nil
// if no value is associated with key. Successive calls to Value with
// the same key return the same result.
//
// Use context values only for request-scoped data that transits
// processes and API boundaries, not for passing optional parameters to
// functions.
//
// A key identifies a specific value in a Context. Functions that wish
// to store values in Context typically allocate a key in a global
// variable then use that key as the argument to context.WithValue and
// Context.Value. A key can be any type that supports equality;
// packages should define keys as an unexported type to avoid
// collisions.
//
// Packages that define a Context key should provide type-safe accessors
// for the values stored using that key:
//
// // Package user defines a User type that's stored in Contexts.
// package user
//
// import "context"
//
// // User is the type of value stored in the Contexts.
// type User struct {...}
//
// // key is an unexported type for keys defined in this package.
// // This prevents collisions with keys defined in other packages.
// type key int
//
// // userKey is the key for user.User values in Contexts. It is
// // unexported; clients use user.NewContext and user.FromContext
// // instead of using this key directly.
// var userKey key
//
// // NewContext returns a new Context that carries value u.
// func NewContext(ctx context.Context, u *User) context.Context {
// return context.WithValue(ctx, userKey, u)
// }
//
// // FromContext returns the User value stored in ctx, if any.
// func FromContext(ctx context.Context) (*User, bool) {
// u, ok := ctx.Value(userKey).(*User)
// return u, ok
// }
Value(key any) any
}
// Canceled is the error returned by Context.Err when the context is canceled.
var Canceled = errors.New("context canceled")
// DeadlineExceeded is the error returned by Context.Err when the context's
// deadline passes.
var DeadlineExceeded error = deadlineExceededError{}
type deadlineExceededError struct{}
func (deadlineExceededError) Error() string { return "context deadline exceeded" }
func (deadlineExceededError) Timeout() bool { return true }
func (deadlineExceededError) Temporary() bool { return true }
// An emptyCtx is never canceled, has no values, and has no deadline.
// It is the common base of backgroundCtx and todoCtx.
type emptyCtx struct{}
func (emptyCtx) Deadline() (deadline time.Time, ok bool) {
return
}
func (emptyCtx) Done() <-chan struct{} {
return nil
}
func (emptyCtx) Err() error {
return nil
}
func (emptyCtx) Value(key any) any {
return nil
}
type backgroundCtx struct{ emptyCtx }
func (backgroundCtx) String() string {
return "context.Background"
}
type todoCtx struct{ emptyCtx }
func (todoCtx) String() string {
return "context.TODO"
}
// Background returns a non-nil, empty Context. It is never canceled, has no
// values, and has no deadline. It is typically used by the main function,
// initialization, and tests, and as the top-level Context for incoming
// requests.
func Background() Context {
return backgroundCtx{}
}
// TODO returns a non-nil, empty Context. Code should use context.TODO when
// it's unclear which Context to use or it is not yet available (because the
// surrounding function has not yet been extended to accept a Context
// parameter).
func TODO() Context {
return todoCtx{}
}
// A CancelFunc tells an operation to abandon its work.
// A CancelFunc does not wait for the work to stop.
// A CancelFunc may be called by multiple goroutines simultaneously.
// After the first call, subsequent calls to a CancelFunc do nothing.
type CancelFunc func()
// WithCancel returns a copy of parent with a new Done channel. The returned
// context's Done channel is closed when the returned cancel function is called
// or when the parent context's Done channel is closed, whichever happens first.
//
// Canceling this context releases resources associated with it, so code should
// call cancel as soon as the operations running in this Context complete.
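//
// Example use (a sketch; doWork stands in for any context-aware operation):
//
// ctx, cancel := context.WithCancel(parent)
// defer cancel() // runs the cancel function on every return path
// return doWork(ctx)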
func WithCancel(parent Context) (ctx Context, cancel CancelFunc) {
c := withCancel(parent)
return c, func() { c.cancel(true, Canceled, nil) }
}
// A CancelCauseFunc behaves like a CancelFunc but additionally sets the cancellation cause.
// This cause can be retrieved by calling Cause on the canceled Context or on
// any of its derived Contexts.
//
// If the context has already been canceled, CancelCauseFunc does not set the cause.
// For example, if childContext is derived from parentContext:
// - if parentContext is canceled with cause1 before childContext is canceled with cause2,
// then Cause(parentContext) == Cause(childContext) == cause1
// - if childContext is canceled with cause2 before parentContext is canceled with cause1,
// then Cause(parentContext) == cause1 and Cause(childContext) == cause2
type CancelCauseFunc func(cause error)
// WithCancelCause behaves like WithCancel but returns a CancelCauseFunc instead of a CancelFunc.
// Calling cancel with a non-nil error (the "cause") records that error in ctx;
// it can then be retrieved using Cause(ctx).
// Calling cancel with nil sets the cause to Canceled.
//
// Example use:
//
// ctx, cancel := context.WithCancelCause(parent)
// cancel(myError)
// ctx.Err() // returns context.Canceled
// context.Cause(ctx) // returns myError
func WithCancelCause(parent Context) (ctx Context, cancel CancelCauseFunc) {
c := withCancel(parent)
return c, func(cause error) { c.cancel(true, Canceled, cause) }
}
func withCancel(parent Context) *cancelCtx {
if parent == nil {
panic("cannot create context from nil parent")
}
c := &cancelCtx{Context: parent}
propagateCancel(parent, c)
return c
}
// Cause returns a non-nil error explaining why c was canceled.
// The first cancellation of c or one of its parents sets the cause.
// If that cancellation happened via a call to CancelCauseFunc(err),
// then Cause returns err.
// Otherwise Cause(c) returns the same value as c.Err().
// Cause returns nil if c has not been canceled yet.
func Cause(c Context) error {
if cc, ok := c.Value(&cancelCtxKey).(*cancelCtx); ok {
cc.mu.Lock()
defer cc.mu.Unlock()
return cc.cause
}
return nil
}
// goroutines counts the number of goroutines ever created; for testing.
var goroutines atomic.Int32
// propagateCancel arranges for child to be canceled when parent is.
func propagateCancel(parent Context, child canceler) {
done := parent.Done()
if done == nil {
return // parent is never canceled
}
select {
case <-done:
// parent is already canceled
child.cancel(false, parent.Err(), Cause(parent))
return
default:
}
if p, ok := parentCancelCtx(parent); ok {
p.mu.Lock()
if p.err != nil {
// parent has already been canceled
child.cancel(false, p.err, p.cause)
} else {
if p.children == nil {
p.children = make(map[canceler]struct{})
}
p.children[child] = struct{}{}
}
p.mu.Unlock()
} else {
goroutines.Add(1)
go func() {
select {
case <-parent.Done():
child.cancel(false, parent.Err(), Cause(parent))
case <-child.Done():
}
}()
}
}
// &cancelCtxKey is the key that a cancelCtx returns itself for.
var cancelCtxKey int
// parentCancelCtx returns the underlying *cancelCtx for parent.
// It does this by looking up parent.Value(&cancelCtxKey) to find
// the innermost enclosing *cancelCtx and then checking whether
// parent.Done() matches that *cancelCtx. (If not, the *cancelCtx
// has been wrapped in a custom implementation providing a
// different done channel, in which case we should not bypass it.)
func parentCancelCtx(parent Context) (*cancelCtx, bool) {
done := parent.Done()
if done == closedchan || done == nil {
return nil, false
}
p, ok := parent.Value(&cancelCtxKey).(*cancelCtx)
if !ok {
return nil, false
}
pdone, _ := p.done.Load().(chan struct{})
if pdone != done {
return nil, false
}
return p, true
}
// removeChild removes a context from its parent.
func removeChild(parent Context, child canceler) {
p, ok := parentCancelCtx(parent)
if !ok {
return
}
p.mu.Lock()
if p.children != nil {
delete(p.children, child)
}
p.mu.Unlock()
}
// A canceler is a context type that can be canceled directly. The
// implementations are *cancelCtx and *timerCtx.
type canceler interface {
cancel(removeFromParent bool, err, cause error)
Done() <-chan struct{}
}
// closedchan is a reusable closed channel.
var closedchan = make(chan struct{})
func init() {
close(closedchan)
}
// A cancelCtx can be canceled. When canceled, it also cancels any children
// that implement canceler.
type cancelCtx struct {
Context
mu sync.Mutex // protects following fields
done atomic.Value // of chan struct{}, created lazily, closed by first cancel call
children map[canceler]struct{} // set to nil by the first cancel call
err error // set to non-nil by the first cancel call
cause error // set to non-nil by the first cancel call
}
func (c *cancelCtx) Value(key any) any {
if key == &cancelCtxKey {
return c
}
return value(c.Context, key)
}
func (c *cancelCtx) Done() <-chan struct{} {
d := c.done.Load()
if d != nil {
return d.(chan struct{})
}
c.mu.Lock()
defer c.mu.Unlock()
d = c.done.Load()
if d == nil {
d = make(chan struct{})
c.done.Store(d)
}
return d.(chan struct{})
}
func (c *cancelCtx) Err() error {
c.mu.Lock()
err := c.err
c.mu.Unlock()
return err
}
type stringer interface {
String() string
}
func contextName(c Context) string {
if s, ok := c.(stringer); ok {
return s.String()
}
return reflectlite.TypeOf(c).String()
}
func (c *cancelCtx) String() string {
return contextName(c.Context) + ".WithCancel"
}
// cancel closes c.done, cancels each of c's children, and, if
// removeFromParent is true, removes c from its parent's children.
// cancel sets c.cause to cause if this is the first time c is canceled.
func (c *cancelCtx) cancel(removeFromParent bool, err, cause error) {
if err == nil {
panic("context: internal error: missing cancel error")
}
if cause == nil {
cause = err
}
c.mu.Lock()
if c.err != nil {
c.mu.Unlock()
return // already canceled
}
c.err = err
c.cause = cause
d, _ := c.done.Load().(chan struct{})
if d == nil {
c.done.Store(closedchan)
} else {
close(d)
}
for child := range c.children {
// NOTE: acquiring the child's lock while holding parent's lock.
child.cancel(false, err, cause)
}
c.children = nil
c.mu.Unlock()
if removeFromParent {
removeChild(c.Context, c)
}
}
// WithDeadline returns a copy of the parent context with the deadline adjusted
// to be no later than d. If the parent's deadline is already earlier than d,
// WithDeadline(parent, d) is semantically equivalent to parent. The returned
// context's Done channel is closed when the deadline expires, when the returned
// cancel function is called, or when the parent context's Done channel is
// closed, whichever happens first.
//
// Canceling this context releases resources associated with it, so code should
// call cancel as soon as the operations running in this Context complete.
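//
// Example use (a sketch; note that the deadline is an absolute time, unlike
// WithTimeout's relative duration):
//
// d := time.Now().Add(2 * time.Second)
// ctx, cancel := context.WithDeadline(parent, d)
// defer cancel()
// select {
// case <-time.After(5 * time.Second):
// fmt.Println("overslept")
// case <-ctx.Done():
// fmt.Println(ctx.Err()) // prints "context deadline exceeded"
// }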
func WithDeadline(parent Context, d time.Time) (Context, CancelFunc) {
return WithDeadlineCause(parent, d, nil)
}
// WithDeadlineCause behaves like WithDeadline but also sets the cause of the
// returned Context when the deadline is exceeded. The returned CancelFunc does
// not set the cause.
func WithDeadlineCause(parent Context, d time.Time, cause error) (Context, CancelFunc) {
if parent == nil {
panic("cannot create context from nil parent")
}
if cur, ok := parent.Deadline(); ok && cur.Before(d) {
// The current deadline is already sooner than the new one.
return WithCancel(parent)
}
c := &timerCtx{
cancelCtx: cancelCtx{Context: parent},
deadline: d,
}
propagateCancel(parent, c)
dur := time.Until(d)
if dur <= 0 {
c.cancel(true, DeadlineExceeded, cause) // deadline has already passed
return c, func() { c.cancel(false, Canceled, nil) }
}
c.mu.Lock()
defer c.mu.Unlock()
if c.err == nil {
c.timer = time.AfterFunc(dur, func() {
c.cancel(true, DeadlineExceeded, cause)
})
}
return c, func() { c.cancel(true, Canceled, nil) }
}
// A timerCtx carries a timer and a deadline. It embeds a cancelCtx to
// implement Done and Err. It implements cancel by stopping its timer then
// delegating to cancelCtx.cancel.
type timerCtx struct {
cancelCtx
timer *time.Timer // Under cancelCtx.mu.
deadline time.Time
}
func (c *timerCtx) Deadline() (deadline time.Time, ok bool) {
return c.deadline, true
}
func (c *timerCtx) String() string {
return contextName(c.cancelCtx.Context) + ".WithDeadline(" +
c.deadline.String() + " [" +
time.Until(c.deadline).String() + "])"
}
func (c *timerCtx) cancel(removeFromParent bool, err, cause error) {
c.cancelCtx.cancel(false, err, cause)
if removeFromParent {
// Remove this timerCtx from its parent cancelCtx's children.
removeChild(c.cancelCtx.Context, c)
}
c.mu.Lock()
if c.timer != nil {
c.timer.Stop()
c.timer = nil
}
c.mu.Unlock()
}
// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)).
//
// Canceling this context releases resources associated with it, so code should
// call cancel as soon as the operations running in this Context complete:
//
// func slowOperationWithTimeout(ctx context.Context) (Result, error) {
// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
// defer cancel() // releases resources if slowOperation completes before timeout elapses
// return slowOperation(ctx)
// }
func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) {
return WithDeadline(parent, time.Now().Add(timeout))
}
// WithTimeoutCause behaves like WithTimeout but also sets the cause of the
// returned Context when the timeout expires. The returned CancelFunc does
// not set the cause.
func WithTimeoutCause(parent Context, timeout time.Duration, cause error) (Context, CancelFunc) {
return WithDeadlineCause(parent, time.Now().Add(timeout), cause)
}
// WithValue returns a copy of parent in which the value associated with key is
// val.
//
// Use context Values only for request-scoped data that transits processes and
// APIs, not for passing optional parameters to functions.
//
// The provided key must be comparable and should not be of type
// string or any other built-in type to avoid collisions between
// packages using context. Users of WithValue should define their own
// types for keys. To avoid allocating when assigning to an
// interface{}, context keys often have concrete type
// struct{}. Alternatively, exported context key variables' static
// type should be a pointer or interface.
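//
// Example use (a sketch; favKey is an illustrative unexported key type,
// following the advice above):
//
// type favKey struct{}
// ctx := context.WithValue(parent, favKey{}, "Go")
// v, ok := ctx.Value(favKey{}).(string) // v == "Go", ok == true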
func WithValue(parent Context, key, val any) Context {
if parent == nil {
panic("cannot create context from nil parent")
}
if key == nil {
panic("nil key")
}
if !reflectlite.TypeOf(key).Comparable() {
panic("key is not comparable")
}
return &valueCtx{parent, key, val}
}
// A valueCtx carries a key-value pair. It implements Value for that key and
// delegates all other calls to the embedded Context.
type valueCtx struct {
Context
key, val any
}
// stringify tries a bit to stringify v, without using fmt, since we don't
// want context depending on the unicode tables. This is only used by
// *valueCtx.String().
func stringify(v any) string {
switch s := v.(type) {
case stringer:
return s.String()
case string:
return s
}
return "<not Stringer>"
}
func (c *valueCtx) String() string {
return contextName(c.Context) + ".WithValue(type " +
reflectlite.TypeOf(c.key).String() +
", val " + stringify(c.val) + ")"
}
func (c *valueCtx) Value(key any) any {
if c.key == key {
return c.val
}
return value(c.Context, key)
}
func value(c Context, key any) any {
for {
switch ctx := c.(type) {
case *valueCtx:
if key == ctx.key {
return ctx.val
}
c = ctx.Context
case *cancelCtx:
if key == &cancelCtxKey {
return c
}
c = ctx.Context
case *timerCtx:
if key == &cancelCtxKey {
return &ctx.cancelCtx
}
c = ctx.Context
case backgroundCtx, todoCtx:
return nil
default:
return c.Value(key)
}
}
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build amd64 || arm64
package aes
import (
"crypto/cipher"
"crypto/internal/alias"
"crypto/subtle"
"errors"
)
// The following functions are defined in gcm_*.s.
//go:noescape
func gcmAesInit(productTable *[256]byte, ks []uint32)
//go:noescape
func gcmAesData(productTable *[256]byte, data []byte, T *[16]byte)
//go:noescape
func gcmAesEnc(productTable *[256]byte, dst, src []byte, ctr, T *[16]byte, ks []uint32)
//go:noescape
func gcmAesDec(productTable *[256]byte, dst, src []byte, ctr, T *[16]byte, ks []uint32)
//go:noescape
func gcmAesFinish(productTable *[256]byte, tagMask, T *[16]byte, pLen, dLen uint64)
const (
gcmBlockSize = 16
gcmTagSize = 16
gcmMinimumTagSize = 12 // NIST SP 800-38D recommends tags with 12 or more bytes.
gcmStandardNonceSize = 12
)
var errOpen = errors.New("cipher: message authentication failed")
// Assert that aesCipherGCM implements the gcmAble interface.
var _ gcmAble = (*aesCipherGCM)(nil)
// NewGCM returns the AES cipher wrapped in Galois Counter Mode. This is only
// called by crypto/cipher.NewGCM via the gcmAble interface.
func (c *aesCipherGCM) NewGCM(nonceSize, tagSize int) (cipher.AEAD, error) {
g := &gcmAsm{ks: c.enc, nonceSize: nonceSize, tagSize: tagSize}
gcmAesInit(&g.productTable, g.ks)
return g, nil
}
type gcmAsm struct {
// ks is the key schedule, the length of which depends on the size of
// the AES key.
ks []uint32
// productTable contains pre-computed multiples of the binary-field
// element used in GHASH.
productTable [256]byte
// nonceSize contains the expected size of the nonce, in bytes.
nonceSize int
// tagSize contains the size of the tag, in bytes.
tagSize int
}
func (g *gcmAsm) NonceSize() int {
return g.nonceSize
}
func (g *gcmAsm) Overhead() int {
return g.tagSize
}
// sliceForAppend takes a slice and a requested number of bytes. It returns a
// slice with the contents of the given slice followed by that many bytes and a
// second slice that aliases into it and contains only the extra bytes. If the
// original slice has sufficient capacity then no allocation is performed.
func sliceForAppend(in []byte, n int) (head, tail []byte) {
if total := len(in) + n; cap(in) >= total {
head = in[:total]
} else {
head = make([]byte, total)
copy(head, in)
}
tail = head[len(in):]
return
}
// Seal encrypts and authenticates plaintext. See the cipher.AEAD interface for
// details.
func (g *gcmAsm) Seal(dst, nonce, plaintext, data []byte) []byte {
if len(nonce) != g.nonceSize {
panic("crypto/cipher: incorrect nonce length given to GCM")
}
if uint64(len(plaintext)) > ((1<<32)-2)*BlockSize {
panic("crypto/cipher: message too large for GCM")
}
var counter, tagMask [gcmBlockSize]byte
if len(nonce) == gcmStandardNonceSize {
// Init counter to nonce||1
copy(counter[:], nonce)
counter[gcmBlockSize-1] = 1
} else {
// Otherwise counter = GHASH(nonce)
gcmAesData(&g.productTable, nonce, &counter)
gcmAesFinish(&g.productTable, &tagMask, &counter, uint64(len(nonce)), uint64(0))
}
encryptBlockAsm(len(g.ks)/4-1, &g.ks[0], &tagMask[0], &counter[0])
var tagOut [gcmTagSize]byte
gcmAesData(&g.productTable, data, &tagOut)
ret, out := sliceForAppend(dst, len(plaintext)+g.tagSize)
if alias.InexactOverlap(out[:len(plaintext)], plaintext) {
panic("crypto/cipher: invalid buffer overlap")
}
if len(plaintext) > 0 {
gcmAesEnc(&g.productTable, out, plaintext, &counter, &tagOut, g.ks)
}
gcmAesFinish(&g.productTable, &tagMask, &tagOut, uint64(len(plaintext)), uint64(len(data)))
copy(out[len(plaintext):], tagOut[:])
return ret
}
// Open authenticates and decrypts ciphertext. See the cipher.AEAD interface
// for details.
func (g *gcmAsm) Open(dst, nonce, ciphertext, data []byte) ([]byte, error) {
if len(nonce) != g.nonceSize {
panic("crypto/cipher: incorrect nonce length given to GCM")
}
// Sanity check to prevent the authentication from always succeeding if an implementation
// leaves tagSize uninitialized, for example.
if g.tagSize < gcmMinimumTagSize {
panic("crypto/cipher: incorrect GCM tag size")
}
if len(ciphertext) < g.tagSize {
return nil, errOpen
}
if uint64(len(ciphertext)) > ((1<<32)-2)*uint64(BlockSize)+uint64(g.tagSize) {
return nil, errOpen
}
tag := ciphertext[len(ciphertext)-g.tagSize:]
ciphertext = ciphertext[:len(ciphertext)-g.tagSize]
// See GCM spec, section 7.1.
var counter, tagMask [gcmBlockSize]byte
if len(nonce) == gcmStandardNonceSize {
// Init counter to nonce||1
copy(counter[:], nonce)
counter[gcmBlockSize-1] = 1
} else {
// Otherwise counter = GHASH(nonce)
gcmAesData(&g.productTable, nonce, &counter)
gcmAesFinish(&g.productTable, &tagMask, &counter, uint64(len(nonce)), uint64(0))
}
encryptBlockAsm(len(g.ks)/4-1, &g.ks[0], &tagMask[0], &counter[0])
var expectedTag [gcmTagSize]byte
gcmAesData(&g.productTable, data, &expectedTag)
ret, out := sliceForAppend(dst, len(ciphertext))
if alias.InexactOverlap(out, ciphertext) {
panic("crypto/cipher: invalid buffer overlap")
}
if len(ciphertext) > 0 {
gcmAesDec(&g.productTable, out, ciphertext, &counter, &expectedTag, g.ks)
}
gcmAesFinish(&g.productTable, &tagMask, &expectedTag, uint64(len(ciphertext)), uint64(len(data)))
if subtle.ConstantTimeCompare(expectedTag[:g.tagSize], tag) != 1 {
for i := range out {
out[i] = 0
}
return nil, errOpen
}
return ret, nil
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This Go implementation is derived in part from the reference
// ANSI C implementation, which carries the following notice:
//
// rijndael-alg-fst.c
//
// @version 3.0 (December 2000)
//
// Optimised ANSI C code for the Rijndael cipher (now AES)
//
// @author Vincent Rijmen <vincent.rijmen@esat.kuleuven.ac.be>
// @author Antoon Bosselaers <antoon.bosselaers@esat.kuleuven.ac.be>
// @author Paulo Barreto <paulo.barreto@terra.com.br>
//
// This code is hereby placed in the public domain.
//
// THIS SOFTWARE IS PROVIDED BY THE AUTHORS ''AS IS'' AND ANY EXPRESS
// OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
// BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
// OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
// EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
// See FIPS 197 for specification, and see Daemen and Rijmen's Rijndael submission
// for implementation details.
// https://csrc.nist.gov/csrc/media/publications/fips/197/final/documents/fips-197.pdf
// https://csrc.nist.gov/archive/aes/rijndael/Rijndael-ammended.pdf
package aes
import (
"encoding/binary"
)
// Encrypt one block from src into dst, using the expanded key xk.
func encryptBlockGo(xk []uint32, dst, src []byte) {
_ = src[15] // early bounds check
s0 := binary.BigEndian.Uint32(src[0:4])
s1 := binary.BigEndian.Uint32(src[4:8])
s2 := binary.BigEndian.Uint32(src[8:12])
s3 := binary.BigEndian.Uint32(src[12:16])
// First round just XORs input with key.
s0 ^= xk[0]
s1 ^= xk[1]
s2 ^= xk[2]
s3 ^= xk[3]
// Middle rounds shuffle using tables.
// Number of rounds is set by length of expanded key.
nr := len(xk)/4 - 2 // - 2: one above, one more below
k := 4
var t0, t1, t2, t3 uint32
for r := 0; r < nr; r++ {
t0 = xk[k+0] ^ te0[uint8(s0>>24)] ^ te1[uint8(s1>>16)] ^ te2[uint8(s2>>8)] ^ te3[uint8(s3)]
t1 = xk[k+1] ^ te0[uint8(s1>>24)] ^ te1[uint8(s2>>16)] ^ te2[uint8(s3>>8)] ^ te3[uint8(s0)]
t2 = xk[k+2] ^ te0[uint8(s2>>24)] ^ te1[uint8(s3>>16)] ^ te2[uint8(s0>>8)] ^ te3[uint8(s1)]
t3 = xk[k+3] ^ te0[uint8(s3>>24)] ^ te1[uint8(s0>>16)] ^ te2[uint8(s1>>8)] ^ te3[uint8(s2)]
k += 4
s0, s1, s2, s3 = t0, t1, t2, t3
}
// Last round uses s-box directly and XORs to produce output.
s0 = uint32(sbox0[t0>>24])<<24 | uint32(sbox0[t1>>16&0xff])<<16 | uint32(sbox0[t2>>8&0xff])<<8 | uint32(sbox0[t3&0xff])
s1 = uint32(sbox0[t1>>24])<<24 | uint32(sbox0[t2>>16&0xff])<<16 | uint32(sbox0[t3>>8&0xff])<<8 | uint32(sbox0[t0&0xff])
s2 = uint32(sbox0[t2>>24])<<24 | uint32(sbox0[t3>>16&0xff])<<16 | uint32(sbox0[t0>>8&0xff])<<8 | uint32(sbox0[t1&0xff])
s3 = uint32(sbox0[t3>>24])<<24 | uint32(sbox0[t0>>16&0xff])<<16 | uint32(sbox0[t1>>8&0xff])<<8 | uint32(sbox0[t2&0xff])
s0 ^= xk[k+0]
s1 ^= xk[k+1]
s2 ^= xk[k+2]
s3 ^= xk[k+3]
_ = dst[15] // early bounds check
binary.BigEndian.PutUint32(dst[0:4], s0)
binary.BigEndian.PutUint32(dst[4:8], s1)
binary.BigEndian.PutUint32(dst[8:12], s2)
binary.BigEndian.PutUint32(dst[12:16], s3)
}
// Decrypt one block from src into dst, using the expanded key xk.
func decryptBlockGo(xk []uint32, dst, src []byte) {
_ = src[15] // early bounds check
s0 := binary.BigEndian.Uint32(src[0:4])
s1 := binary.BigEndian.Uint32(src[4:8])
s2 := binary.BigEndian.Uint32(src[8:12])
s3 := binary.BigEndian.Uint32(src[12:16])
// First round just XORs input with key.
s0 ^= xk[0]
s1 ^= xk[1]
s2 ^= xk[2]
s3 ^= xk[3]
// Middle rounds shuffle using tables.
// Number of rounds is set by length of expanded key.
nr := len(xk)/4 - 2 // - 2: one above, one more below
k := 4
var t0, t1, t2, t3 uint32
for r := 0; r < nr; r++ {
t0 = xk[k+0] ^ td0[uint8(s0>>24)] ^ td1[uint8(s3>>16)] ^ td2[uint8(s2>>8)] ^ td3[uint8(s1)]
t1 = xk[k+1] ^ td0[uint8(s1>>24)] ^ td1[uint8(s0>>16)] ^ td2[uint8(s3>>8)] ^ td3[uint8(s2)]
t2 = xk[k+2] ^ td0[uint8(s2>>24)] ^ td1[uint8(s1>>16)] ^ td2[uint8(s0>>8)] ^ td3[uint8(s3)]
t3 = xk[k+3] ^ td0[uint8(s3>>24)] ^ td1[uint8(s2>>16)] ^ td2[uint8(s1>>8)] ^ td3[uint8(s0)]
k += 4
s0, s1, s2, s3 = t0, t1, t2, t3
}
// Last round uses s-box directly and XORs to produce output.
s0 = uint32(sbox1[t0>>24])<<24 | uint32(sbox1[t3>>16&0xff])<<16 | uint32(sbox1[t2>>8&0xff])<<8 | uint32(sbox1[t1&0xff])
s1 = uint32(sbox1[t1>>24])<<24 | uint32(sbox1[t0>>16&0xff])<<16 | uint32(sbox1[t3>>8&0xff])<<8 | uint32(sbox1[t2&0xff])
s2 = uint32(sbox1[t2>>24])<<24 | uint32(sbox1[t1>>16&0xff])<<16 | uint32(sbox1[t0>>8&0xff])<<8 | uint32(sbox1[t3&0xff])
s3 = uint32(sbox1[t3>>24])<<24 | uint32(sbox1[t2>>16&0xff])<<16 | uint32(sbox1[t1>>8&0xff])<<8 | uint32(sbox1[t0&0xff])
s0 ^= xk[k+0]
s1 ^= xk[k+1]
s2 ^= xk[k+2]
s3 ^= xk[k+3]
_ = dst[15] // early bounds check
binary.BigEndian.PutUint32(dst[0:4], s0)
binary.BigEndian.PutUint32(dst[4:8], s1)
binary.BigEndian.PutUint32(dst[8:12], s2)
binary.BigEndian.PutUint32(dst[12:16], s3)
}
// Apply sbox0 to each byte in w.
func subw(w uint32) uint32 {
return uint32(sbox0[w>>24])<<24 |
uint32(sbox0[w>>16&0xff])<<16 |
uint32(sbox0[w>>8&0xff])<<8 |
uint32(sbox0[w&0xff])
}
// rotw rotates the word w left by one byte (8 bits).
func rotw(w uint32) uint32 { return w<<8 | w>>24 }
// Key expansion algorithm. See FIPS-197, Figure 11.
// Their rcon[i] is our powx[i-1] << 24.
func expandKeyGo(key []byte, enc, dec []uint32) {
// Encryption key setup.
var i int
nk := len(key) / 4
for i = 0; i < nk; i++ {
enc[i] = binary.BigEndian.Uint32(key[4*i:])
}
for ; i < len(enc); i++ {
t := enc[i-1]
if i%nk == 0 {
t = subw(rotw(t)) ^ (uint32(powx[i/nk-1]) << 24)
} else if nk > 6 && i%nk == 4 {
t = subw(t)
}
enc[i] = enc[i-nk] ^ t
}
// Derive decryption key from encryption key.
// Reverse the 4-word round key sets from enc to produce dec.
// All sets but the first and last get the MixColumn transform applied.
if dec == nil {
return
}
n := len(enc)
for i := 0; i < n; i += 4 {
ei := n - i - 4
for j := 0; j < 4; j++ {
x := enc[ei+j]
if i > 0 && i+4 < n {
x = td0[sbox0[x>>24]] ^ td1[sbox0[x>>16&0xff]] ^ td2[sbox0[x>>8&0xff]] ^ td3[sbox0[x&0xff]]
}
dec[i+j] = x
}
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package aes
import (
"crypto/cipher"
"crypto/internal/alias"
"crypto/internal/boring"
"strconv"
)
// The AES block size in bytes.
const BlockSize = 16
// A cipher is an instance of AES encryption using a particular key.
type aesCipher struct {
enc []uint32
dec []uint32
}
type KeySizeError int
func (k KeySizeError) Error() string {
return "crypto/aes: invalid key size " + strconv.Itoa(int(k))
}
// NewCipher creates and returns a new cipher.Block.
// The key argument should be the AES key,
// either 16, 24, or 32 bytes to select
// AES-128, AES-192, or AES-256.
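//
// Example use (a sketch; a real key must come from a secure source such as a
// key derivation function or crypto/rand):
//
// key := make([]byte, 32) // zero key for illustration only; selects AES-256
// block, err := aes.NewCipher(key)
// if err != nil {
// // handle the error
// }
// _ = block // use block with a mode from crypto/cipher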
func NewCipher(key []byte) (cipher.Block, error) {
k := len(key)
switch k {
default:
return nil, KeySizeError(k)
case 16, 24, 32:
break
}
if boring.Enabled {
return boring.NewAESCipher(key)
}
return newCipher(key)
}
// newCipherGeneric creates and returns a new cipher.Block
// implemented in pure Go.
func newCipherGeneric(key []byte) (cipher.Block, error) {
n := len(key) + 28
c := aesCipher{make([]uint32, n), make([]uint32, n)}
expandKeyGo(key, c.enc, c.dec)
return &c, nil
}
func (c *aesCipher) BlockSize() int { return BlockSize }
func (c *aesCipher) Encrypt(dst, src []byte) {
if len(src) < BlockSize {
panic("crypto/aes: input not full block")
}
if len(dst) < BlockSize {
panic("crypto/aes: output not full block")
}
if alias.InexactOverlap(dst[:BlockSize], src[:BlockSize]) {
panic("crypto/aes: invalid buffer overlap")
}
encryptBlockGo(c.enc, dst, src)
}
func (c *aesCipher) Decrypt(dst, src []byte) {
if len(src) < BlockSize {
panic("crypto/aes: input not full block")
}
if len(dst) < BlockSize {
panic("crypto/aes: output not full block")
}
if alias.InexactOverlap(dst[:BlockSize], src[:BlockSize]) {
panic("crypto/aes: invalid buffer overlap")
}
decryptBlockGo(c.dec, dst, src)
}
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build amd64 || arm64 || ppc64 || ppc64le
package aes
import (
"crypto/cipher"
"crypto/internal/alias"
"crypto/internal/boring"
"internal/cpu"
"internal/goarch"
)
// defined in asm_*.s
//go:noescape
func encryptBlockAsm(nr int, xk *uint32, dst, src *byte)
//go:noescape
func decryptBlockAsm(nr int, xk *uint32, dst, src *byte)
//go:noescape
func expandKeyAsm(nr int, key *byte, enc *uint32, dec *uint32)
type aesCipherAsm struct {
aesCipher
}
// aesCipherGCM implements crypto/cipher.gcmAble so that crypto/cipher.NewGCM
// will use the optimised implementation in aes_gcm.go when possible.
// Instances of this type only exist when both supportsAES and supportsGFMUL
// are true. Likewise, the gcmAble implementation is in aes_gcm.go.
type aesCipherGCM struct {
aesCipherAsm
}
var supportsAES = cpu.X86.HasAES || cpu.ARM64.HasAES || goarch.IsPpc64 == 1 || goarch.IsPpc64le == 1
var supportsGFMUL = cpu.X86.HasPCLMULQDQ || cpu.ARM64.HasPMULL
func newCipher(key []byte) (cipher.Block, error) {
if !supportsAES {
return newCipherGeneric(key)
}
n := len(key) + 28
c := aesCipherAsm{aesCipher{make([]uint32, n), make([]uint32, n)}}
var rounds int
switch len(key) {
case 128 / 8:
rounds = 10
case 192 / 8:
rounds = 12
case 256 / 8:
rounds = 14
default:
return nil, KeySizeError(len(key))
}
expandKeyAsm(rounds, &key[0], &c.enc[0], &c.dec[0])
if supportsAES && supportsGFMUL {
return &aesCipherGCM{c}, nil
}
return &c, nil
}
func (c *aesCipherAsm) BlockSize() int { return BlockSize }
func (c *aesCipherAsm) Encrypt(dst, src []byte) {
boring.Unreachable()
if len(src) < BlockSize {
panic("crypto/aes: input not full block")
}
if len(dst) < BlockSize {
panic("crypto/aes: output not full block")
}
if alias.InexactOverlap(dst[:BlockSize], src[:BlockSize]) {
panic("crypto/aes: invalid buffer overlap")
}
encryptBlockAsm(len(c.enc)/4-1, &c.enc[0], &dst[0], &src[0])
}
func (c *aesCipherAsm) Decrypt(dst, src []byte) {
boring.Unreachable()
if len(src) < BlockSize {
panic("crypto/aes: input not full block")
}
if len(dst) < BlockSize {
panic("crypto/aes: output not full block")
}
if alias.InexactOverlap(dst[:BlockSize], src[:BlockSize]) {
panic("crypto/aes: invalid buffer overlap")
}
decryptBlockAsm(len(c.dec)/4-1, &c.dec[0], &dst[0], &src[0])
}
// expandKey is used by BenchmarkExpand to ensure that the asm implementation
// of key expansion is used for the benchmark when it is available.
func expandKey(key []byte, enc, dec []uint32) {
if supportsAES {
rounds := 10 // rounds needed for AES128
switch len(key) {
case 192 / 8:
rounds = 12
case 256 / 8:
rounds = 14
}
expandKeyAsm(rounds, &key[0], &enc[0], &dec[0])
} else {
expandKeyGo(key, enc, dec)
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Cipher block chaining (CBC) mode.
// CBC provides confidentiality by xoring (chaining) each plaintext block
// with the previous ciphertext block before applying the block cipher.
// See NIST SP 800-38A, pp 10-11
package cipher
import (
"bytes"
"crypto/internal/alias"
"crypto/subtle"
)
type cbc struct {
b Block
blockSize int
iv []byte
tmp []byte
}
func newCBC(b Block, iv []byte) *cbc {
return &cbc{
b: b,
blockSize: b.BlockSize(),
iv: bytes.Clone(iv),
tmp: make([]byte, b.BlockSize()),
}
}
type cbcEncrypter cbc
// cbcEncAble is an interface implemented by ciphers that have a specific
// optimized implementation of CBC encryption, like crypto/aes.
// NewCBCEncrypter will check for this interface and return the specific
// BlockMode if found.
type cbcEncAble interface {
NewCBCEncrypter(iv []byte) BlockMode
}
// NewCBCEncrypter returns a BlockMode which encrypts in cipher block chaining
// mode, using the given Block. The length of iv must be the same as the
// Block's block size.
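//
// Example use (a sketch; assumes plaintext is already padded to a whole
// number of blocks and iv holds b.BlockSize() random bytes):
//
// mode := cipher.NewCBCEncrypter(block, iv)
// ciphertext := make([]byte, len(plaintext))
// mode.CryptBlocks(ciphertext, plaintext)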
func NewCBCEncrypter(b Block, iv []byte) BlockMode {
if len(iv) != b.BlockSize() {
panic("cipher.NewCBCEncrypter: IV length must equal block size")
}
if cbc, ok := b.(cbcEncAble); ok {
return cbc.NewCBCEncrypter(iv)
}
return (*cbcEncrypter)(newCBC(b, iv))
}
// newCBCGenericEncrypter returns a BlockMode which encrypts in cipher block chaining
// mode, using the given Block. The length of iv must be the same as the
// Block's block size. This always returns the generic non-asm encrypter for use
// in fuzz testing.
func newCBCGenericEncrypter(b Block, iv []byte) BlockMode {
if len(iv) != b.BlockSize() {
panic("cipher.NewCBCEncrypter: IV length must equal block size")
}
return (*cbcEncrypter)(newCBC(b, iv))
}
func (x *cbcEncrypter) BlockSize() int { return x.blockSize }
func (x *cbcEncrypter) CryptBlocks(dst, src []byte) {
if len(src)%x.blockSize != 0 {
panic("crypto/cipher: input not full blocks")
}
if len(dst) < len(src) {
panic("crypto/cipher: output smaller than input")
}
if alias.InexactOverlap(dst[:len(src)], src) {
panic("crypto/cipher: invalid buffer overlap")
}
iv := x.iv
for len(src) > 0 {
// Write the xor to dst, then encrypt in place.
subtle.XORBytes(dst[:x.blockSize], src[:x.blockSize], iv)
x.b.Encrypt(dst[:x.blockSize], dst[:x.blockSize])
// Move to the next block with this block as the next iv.
iv = dst[:x.blockSize]
src = src[x.blockSize:]
dst = dst[x.blockSize:]
}
// Save the iv for the next CryptBlocks call.
copy(x.iv, iv)
}
func (x *cbcEncrypter) SetIV(iv []byte) {
if len(iv) != len(x.iv) {
panic("cipher: incorrect length IV")
}
copy(x.iv, iv)
}
type cbcDecrypter cbc
// cbcDecAble is an interface implemented by ciphers that have a specific
// optimized implementation of CBC decryption, like crypto/aes.
// NewCBCDecrypter will check for this interface and return the specific
// BlockMode if found.
type cbcDecAble interface {
NewCBCDecrypter(iv []byte) BlockMode
}
// NewCBCDecrypter returns a BlockMode which decrypts in cipher block chaining
// mode, using the given Block. The length of iv must be the same as the
// Block's block size and must match the iv used to encrypt the data.
func NewCBCDecrypter(b Block, iv []byte) BlockMode {
if len(iv) != b.BlockSize() {
panic("cipher.NewCBCDecrypter: IV length must equal block size")
}
if cbc, ok := b.(cbcDecAble); ok {
return cbc.NewCBCDecrypter(iv)
}
return (*cbcDecrypter)(newCBC(b, iv))
}
// newCBCGenericDecrypter returns a BlockMode which decrypts in cipher block chaining
// mode, using the given Block. The length of iv must be the same as the
// Block's block size. This always returns the generic non-asm decrypter for use in
// fuzz testing.
func newCBCGenericDecrypter(b Block, iv []byte) BlockMode {
if len(iv) != b.BlockSize() {
panic("cipher.NewCBCDecrypter: IV length must equal block size")
}
return (*cbcDecrypter)(newCBC(b, iv))
}
func (x *cbcDecrypter) BlockSize() int { return x.blockSize }
func (x *cbcDecrypter) CryptBlocks(dst, src []byte) {
if len(src)%x.blockSize != 0 {
panic("crypto/cipher: input not full blocks")
}
if len(dst) < len(src) {
panic("crypto/cipher: output smaller than input")
}
if alias.InexactOverlap(dst[:len(src)], src) {
panic("crypto/cipher: invalid buffer overlap")
}
if len(src) == 0 {
return
}
// For each block, we need to xor the decrypted data with the previous block's ciphertext (the iv).
// To avoid making a copy each time, we loop over the blocks BACKWARDS.
end := len(src)
start := end - x.blockSize
prev := start - x.blockSize
// Copy the last block of ciphertext in preparation as the new iv.
copy(x.tmp, src[start:end])
// Loop over all but the first block.
for start > 0 {
x.b.Decrypt(dst[start:end], src[start:end])
subtle.XORBytes(dst[start:end], dst[start:end], src[prev:start])
end = start
start = prev
prev -= x.blockSize
}
// The first block is special because it uses the saved iv.
x.b.Decrypt(dst[start:end], src[start:end])
subtle.XORBytes(dst[start:end], dst[start:end], x.iv)
// Set the new iv to the first block we copied earlier.
x.iv, x.tmp = x.tmp, x.iv
}
func (x *cbcDecrypter) SetIV(iv []byte) {
if len(iv) != len(x.iv) {
panic("cipher: incorrect length IV")
}
copy(x.iv, iv)
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// CFB (Cipher Feedback) Mode.
package cipher
import (
"crypto/internal/alias"
"crypto/subtle"
)
type cfb struct {
b Block
next []byte
out []byte
outUsed int
decrypt bool
}
func (x *cfb) XORKeyStream(dst, src []byte) {
if len(dst) < len(src) {
panic("crypto/cipher: output smaller than input")
}
if alias.InexactOverlap(dst[:len(src)], src) {
panic("crypto/cipher: invalid buffer overlap")
}
for len(src) > 0 {
if x.outUsed == len(x.out) {
x.b.Encrypt(x.out, x.next)
x.outUsed = 0
}
if x.decrypt {
// We can precompute a larger segment of the
// keystream on decryption. This will allow
// larger batches for xor, and we should be
// able to match CTR/OFB performance.
copy(x.next[x.outUsed:], src)
}
n := subtle.XORBytes(dst, src, x.out[x.outUsed:])
if !x.decrypt {
copy(x.next[x.outUsed:], dst)
}
dst = dst[n:]
src = src[n:]
x.outUsed += n
}
}
// NewCFBEncrypter returns a Stream which encrypts with cipher feedback mode,
// using the given Block. The iv must be the same length as the Block's block
// size.
func NewCFBEncrypter(block Block, iv []byte) Stream {
return newCFB(block, iv, false)
}
// NewCFBDecrypter returns a Stream which decrypts with cipher feedback mode,
// using the given Block. The iv must be the same length as the Block's block
// size.
func NewCFBDecrypter(block Block, iv []byte) Stream {
return newCFB(block, iv, true)
}
func newCFB(block Block, iv []byte, decrypt bool) Stream {
blockSize := block.BlockSize()
if len(iv) != blockSize {
// stack trace will indicate whether it was encryption or decryption
panic("cipher.newCFB: IV length must equal block size")
}
x := &cfb{
b: block,
out: make([]byte, blockSize),
next: make([]byte, blockSize),
outUsed: blockSize,
decrypt: decrypt,
}
copy(x.next, iv)
return x
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Counter (CTR) mode.
// CTR converts a block cipher into a stream cipher by
// repeatedly encrypting an incrementing counter and
// xoring the resulting stream of data with the input.
// See NIST SP 800-38A, pp 13-15
package cipher
import (
"bytes"
"crypto/internal/alias"
"crypto/subtle"
)
type ctr struct {
b Block
ctr []byte
out []byte
outUsed int
}
const streamBufferSize = 512
// ctrAble is an interface implemented by ciphers that have a specific optimized
// implementation of CTR, like crypto/aes. NewCTR will check for this interface
// and return the specific Stream if found.
type ctrAble interface {
NewCTR(iv []byte) Stream
}
// NewCTR returns a Stream which encrypts/decrypts using the given Block in
// counter mode. The length of iv must be the same as the Block's block size.
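//
// Example use (a sketch; the same Stream construction both encrypts and
// decrypts, but a fresh iv is required for every message under one key):
//
// stream := cipher.NewCTR(block, iv)
// ciphertext := make([]byte, len(plaintext))
// stream.XORKeyStream(ciphertext, plaintext)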
func NewCTR(block Block, iv []byte) Stream {
if ctr, ok := block.(ctrAble); ok {
return ctr.NewCTR(iv)
}
if len(iv) != block.BlockSize() {
panic("cipher.NewCTR: IV length must equal block size")
}
bufSize := streamBufferSize
if bufSize < block.BlockSize() {
bufSize = block.BlockSize()
}
return &ctr{
b: block,
ctr: bytes.Clone(iv),
out: make([]byte, 0, bufSize),
outUsed: 0,
}
}
func (x *ctr) refill() {
remain := len(x.out) - x.outUsed
copy(x.out, x.out[x.outUsed:])
x.out = x.out[:cap(x.out)]
bs := x.b.BlockSize()
for remain <= len(x.out)-bs {
x.b.Encrypt(x.out[remain:], x.ctr)
remain += bs
// Increment counter
for i := len(x.ctr) - 1; i >= 0; i-- {
x.ctr[i]++
if x.ctr[i] != 0 {
break
}
}
}
x.out = x.out[:remain]
x.outUsed = 0
}
func (x *ctr) XORKeyStream(dst, src []byte) {
if len(dst) < len(src) {
panic("crypto/cipher: output smaller than input")
}
if alias.InexactOverlap(dst[:len(src)], src) {
panic("crypto/cipher: invalid buffer overlap")
}
for len(src) > 0 {
if x.outUsed >= len(x.out)-x.b.BlockSize() {
x.refill()
}
n := subtle.XORBytes(dst, src, x.out[x.outUsed:])
dst = dst[n:]
src = src[n:]
x.outUsed += n
}
}
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cipher
import (
"crypto/internal/alias"
"crypto/subtle"
"encoding/binary"
"errors"
)
// AEAD is a cipher mode providing authenticated encryption with associated
// data. For a description of the methodology, see
// https://en.wikipedia.org/wiki/Authenticated_encryption.
type AEAD interface {
// NonceSize returns the size of the nonce that must be passed to Seal
// and Open.
NonceSize() int
// Overhead returns the maximum difference between the lengths of a
// plaintext and its ciphertext.
Overhead() int
// Seal encrypts and authenticates plaintext, authenticates the
// additional data and appends the result to dst, returning the updated
// slice. The nonce must be NonceSize() bytes long and unique for all
// time, for a given key.
//
// To reuse plaintext's storage for the encrypted output, use plaintext[:0]
// as dst. Otherwise, the remaining capacity of dst must not overlap plaintext.
Seal(dst, nonce, plaintext, additionalData []byte) []byte
// Open decrypts and authenticates ciphertext, authenticates the
// additional data and, if successful, appends the resulting plaintext
// to dst, returning the updated slice. The nonce must be NonceSize()
// bytes long and both it and the additional data must match the
// value passed to Seal.
//
// To reuse ciphertext's storage for the decrypted output, use ciphertext[:0]
// as dst. Otherwise, the remaining capacity of dst must not overlap ciphertext.
//
// Even if the function fails, the contents of dst, up to its capacity,
// may be overwritten.
Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error)
}
// gcmAble is an interface implemented by ciphers that have a specific optimized
// implementation of GCM, like crypto/aes. NewGCM will check for this interface
// and return the specific AEAD if found.
type gcmAble interface {
NewGCM(nonceSize, tagSize int) (AEAD, error)
}
// gcmFieldElement represents a value in GF(2¹²⁸). In order to reflect the GCM
// standard and make binary.BigEndian suitable for marshaling these values, the
// bits are stored in big endian order. For example:
//
// the coefficient of x⁰ can be obtained by v.low >> 63.
// the coefficient of x⁶³ can be obtained by v.low & 1.
// the coefficient of x⁶⁴ can be obtained by v.high >> 63.
// the coefficient of x¹²⁷ can be obtained by v.high & 1.
type gcmFieldElement struct {
low, high uint64
}
// gcm represents a Galois Counter Mode with a specific key. See
// https://csrc.nist.gov/groups/ST/toolkit/BCM/documents/proposedmodes/gcm/gcm-revised-spec.pdf
type gcm struct {
cipher Block
nonceSize int
tagSize int
// productTable contains the first sixteen powers of the key, H.
// However, they are in bit reversed order. See NewGCMWithNonceSize.
productTable [16]gcmFieldElement
}
// NewGCM returns the given 128-bit block cipher wrapped in Galois Counter Mode
// with the standard nonce length.
//
// In general, the GHASH operation performed by this implementation of GCM is not constant-time.
// An exception is when the underlying Block was created by aes.NewCipher
// on systems with hardware support for AES. See the crypto/aes package documentation for details.
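//
// Example use (a sketch; nonce must be aead.NonceSize() bytes long and must
// never repeat for a given key):
//
// aead, err := cipher.NewGCM(block)
// if err != nil {
// // handle the error
// }
// ciphertext := aead.Seal(nil, nonce, plaintext, nil)
// plaintext, err = aead.Open(nil, nonce, ciphertext, nil)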
func NewGCM(cipher Block) (AEAD, error) {
return newGCMWithNonceAndTagSize(cipher, gcmStandardNonceSize, gcmTagSize)
}
// NewGCMWithNonceSize returns the given 128-bit block cipher wrapped in Galois
// Counter Mode, which accepts nonces of the given length. The length must not
// be zero.
//
// Only use this function if you require compatibility with an existing
// cryptosystem that uses non-standard nonce lengths. All other users should use
// NewGCM, which is faster and more resistant to misuse.
func NewGCMWithNonceSize(cipher Block, size int) (AEAD, error) {
return newGCMWithNonceAndTagSize(cipher, size, gcmTagSize)
}
// NewGCMWithTagSize returns the given 128-bit block cipher wrapped in Galois
// Counter Mode, which generates tags with the given length.
//
// Tag sizes between 12 and 16 bytes are allowed.
//
// Only use this function if you require compatibility with an existing
// cryptosystem that uses non-standard tag lengths. All other users should use
// NewGCM, which is more resistant to misuse.
func NewGCMWithTagSize(cipher Block, tagSize int) (AEAD, error) {
return newGCMWithNonceAndTagSize(cipher, gcmStandardNonceSize, tagSize)
}
func newGCMWithNonceAndTagSize(cipher Block, nonceSize, tagSize int) (AEAD, error) {
if tagSize < gcmMinimumTagSize || tagSize > gcmBlockSize {
return nil, errors.New("cipher: incorrect tag size given to GCM")
}
if nonceSize <= 0 {
return nil, errors.New("cipher: the nonce can't have zero length, or the security of the key will be immediately compromised")
}
if cipher, ok := cipher.(gcmAble); ok {
return cipher.NewGCM(nonceSize, tagSize)
}
if cipher.BlockSize() != gcmBlockSize {
return nil, errors.New("cipher: NewGCM requires 128-bit block cipher")
}
var key [gcmBlockSize]byte
cipher.Encrypt(key[:], key[:])
g := &gcm{cipher: cipher, nonceSize: nonceSize, tagSize: tagSize}
// We precompute 16 multiples of |key|. However, when we do lookups
// into this table we'll be using bits from a field element and
// therefore the bits will be in the reverse order. So normally one
// would expect, say, 4*key to be in index 4 of the table but due to
// this bit ordering it will actually be in index 0010 (base 2) = 2.
x := gcmFieldElement{
binary.BigEndian.Uint64(key[:8]),
binary.BigEndian.Uint64(key[8:]),
}
g.productTable[reverseBits(1)] = x
for i := 2; i < 16; i += 2 {
g.productTable[reverseBits(i)] = gcmDouble(&g.productTable[reverseBits(i/2)])
g.productTable[reverseBits(i+1)] = gcmAdd(&g.productTable[reverseBits(i)], &x)
}
return g, nil
}
const (
gcmBlockSize = 16
gcmTagSize = 16
gcmMinimumTagSize = 12 // NIST SP 800-38D recommends tags with 12 or more bytes.
gcmStandardNonceSize = 12
)
func (g *gcm) NonceSize() int {
return g.nonceSize
}
func (g *gcm) Overhead() int {
return g.tagSize
}
func (g *gcm) Seal(dst, nonce, plaintext, data []byte) []byte {
if len(nonce) != g.nonceSize {
panic("crypto/cipher: incorrect nonce length given to GCM")
}
if uint64(len(plaintext)) > ((1<<32)-2)*uint64(g.cipher.BlockSize()) {
panic("crypto/cipher: message too large for GCM")
}
ret, out := sliceForAppend(dst, len(plaintext)+g.tagSize)
if alias.InexactOverlap(out, plaintext) {
panic("crypto/cipher: invalid buffer overlap")
}
var counter, tagMask [gcmBlockSize]byte
g.deriveCounter(&counter, nonce)
g.cipher.Encrypt(tagMask[:], counter[:])
gcmInc32(&counter)
g.counterCrypt(out, plaintext, &counter)
var tag [gcmTagSize]byte
g.auth(tag[:], out[:len(plaintext)], data, &tagMask)
copy(out[len(plaintext):], tag[:])
return ret
}
var errOpen = errors.New("cipher: message authentication failed")
func (g *gcm) Open(dst, nonce, ciphertext, data []byte) ([]byte, error) {
if len(nonce) != g.nonceSize {
panic("crypto/cipher: incorrect nonce length given to GCM")
}
// Sanity check to prevent the authentication from always succeeding if an implementation
// leaves tagSize uninitialized, for example.
if g.tagSize < gcmMinimumTagSize {
panic("crypto/cipher: incorrect GCM tag size")
}
if len(ciphertext) < g.tagSize {
return nil, errOpen
}
if uint64(len(ciphertext)) > ((1<<32)-2)*uint64(g.cipher.BlockSize())+uint64(g.tagSize) {
return nil, errOpen
}
tag := ciphertext[len(ciphertext)-g.tagSize:]
ciphertext = ciphertext[:len(ciphertext)-g.tagSize]
var counter, tagMask [gcmBlockSize]byte
g.deriveCounter(&counter, nonce)
g.cipher.Encrypt(tagMask[:], counter[:])
gcmInc32(&counter)
var expectedTag [gcmTagSize]byte
g.auth(expectedTag[:], ciphertext, data, &tagMask)
ret, out := sliceForAppend(dst, len(ciphertext))
if alias.InexactOverlap(out, ciphertext) {
panic("crypto/cipher: invalid buffer overlap")
}
if subtle.ConstantTimeCompare(expectedTag[:g.tagSize], tag) != 1 {
// The AESNI code decrypts and authenticates concurrently, and
// so overwrites dst in the event of a tag mismatch. That
// behavior is mimicked here in order to be consistent across
// platforms.
for i := range out {
out[i] = 0
}
return nil, errOpen
}
g.counterCrypt(out, ciphertext, &counter)
return ret, nil
}
// reverseBits reverses the order of the bits of 4-bit number in i.
func reverseBits(i int) int {
i = ((i << 2) & 0xc) | ((i >> 2) & 0x3)
i = ((i << 1) & 0xa) | ((i >> 1) & 0x5)
return i
}
// gcmAdd adds two elements of GF(2¹²⁸) and returns the sum.
func gcmAdd(x, y *gcmFieldElement) gcmFieldElement {
// Addition in a characteristic 2 field is just XOR.
return gcmFieldElement{x.low ^ y.low, x.high ^ y.high}
}
// gcmDouble returns the result of doubling an element of GF(2¹²⁸).
func gcmDouble(x *gcmFieldElement) (double gcmFieldElement) {
msbSet := x.high&1 == 1
// Because of the bit-ordering, doubling is actually a right shift.
double.high = x.high >> 1
double.high |= x.low << 63
double.low = x.low >> 1
// If the most-significant bit was set before shifting then it,
// conceptually, becomes a term of x^128. This is greater than the
// irreducible polynomial so the result has to be reduced. The
// irreducible polynomial is 1+x+x^2+x^7+x^128. We can subtract that to
// eliminate the term at x^128 which also means subtracting the other
// four terms. In characteristic 2 fields, subtraction == addition ==
// XOR.
if msbSet {
double.low ^= 0xe100000000000000
}
return
}
var gcmReductionTable = []uint16{
0x0000, 0x1c20, 0x3840, 0x2460, 0x7080, 0x6ca0, 0x48c0, 0x54e0,
0xe100, 0xfd20, 0xd940, 0xc560, 0x9180, 0x8da0, 0xa9c0, 0xb5e0,
}
// mul sets y to y*H, where H is the GCM key, fixed during NewGCMWithNonceSize.
func (g *gcm) mul(y *gcmFieldElement) {
var z gcmFieldElement
for i := 0; i < 2; i++ {
word := y.high
if i == 1 {
word = y.low
}
// Multiplication works by multiplying z by 16 and adding in
// one of the precomputed multiples of H.
for j := 0; j < 64; j += 4 {
msw := z.high & 0xf
z.high >>= 4
z.high |= z.low << 60
z.low >>= 4
z.low ^= uint64(gcmReductionTable[msw]) << 48
// the values in |table| are ordered for
// little-endian bit positions. See the comment
// in NewGCMWithNonceSize.
t := &g.productTable[word&0xf]
z.low ^= t.low
z.high ^= t.high
word >>= 4
}
}
*y = z
}
// updateBlocks extends y with more polynomial terms from blocks, based on
// Horner's rule. The length of blocks must be a multiple of gcmBlockSize.
func (g *gcm) updateBlocks(y *gcmFieldElement, blocks []byte) {
for len(blocks) > 0 {
y.low ^= binary.BigEndian.Uint64(blocks)
y.high ^= binary.BigEndian.Uint64(blocks[8:])
g.mul(y)
blocks = blocks[gcmBlockSize:]
}
}
// update extends y with more polynomial terms from data. If data is not a
// multiple of gcmBlockSize bytes long then the remainder is zero padded.
func (g *gcm) update(y *gcmFieldElement, data []byte) {
fullBlocks := (len(data) >> 4) << 4
g.updateBlocks(y, data[:fullBlocks])
if len(data) != fullBlocks {
var partialBlock [gcmBlockSize]byte
copy(partialBlock[:], data[fullBlocks:])
g.updateBlocks(y, partialBlock[:])
}
}
// gcmInc32 treats the final four bytes of counterBlock as a big-endian value
// and increments it.
func gcmInc32(counterBlock *[16]byte) {
ctr := counterBlock[len(counterBlock)-4:]
binary.BigEndian.PutUint32(ctr, binary.BigEndian.Uint32(ctr)+1)
}
// sliceForAppend takes a slice and a requested number of bytes. It returns a
// slice with the contents of the given slice followed by that many bytes and a
// second slice that aliases into it and contains only the extra bytes. If the
// original slice has sufficient capacity then no allocation is performed.
func sliceForAppend(in []byte, n int) (head, tail []byte) {
if total := len(in) + n; cap(in) >= total {
head = in[:total]
} else {
head = make([]byte, total)
copy(head, in)
}
tail = head[len(in):]
return
}
// counterCrypt crypts in to out using g.cipher in counter mode.
func (g *gcm) counterCrypt(out, in []byte, counter *[gcmBlockSize]byte) {
var mask [gcmBlockSize]byte
for len(in) >= gcmBlockSize {
g.cipher.Encrypt(mask[:], counter[:])
gcmInc32(counter)
subtle.XORBytes(out, in, mask[:])
out = out[gcmBlockSize:]
in = in[gcmBlockSize:]
}
if len(in) > 0 {
g.cipher.Encrypt(mask[:], counter[:])
gcmInc32(counter)
subtle.XORBytes(out, in, mask[:])
}
}
// deriveCounter computes the initial GCM counter state from the given nonce.
// See NIST SP 800-38D, section 7.1. This assumes that counter is filled with
// zeros on entry.
func (g *gcm) deriveCounter(counter *[gcmBlockSize]byte, nonce []byte) {
// GCM has two modes of operation with respect to the initial counter
// state: a "fast path" for 96-bit (12-byte) nonces, and a "slow path"
// for nonces of other lengths. For a 96-bit nonce, the nonce, along
// with a four-byte big-endian counter starting at one, is used
// directly as the starting counter. For other nonce sizes, the counter
// is computed by passing it through the GHASH function.
if len(nonce) == gcmStandardNonceSize {
copy(counter[:], nonce)
counter[gcmBlockSize-1] = 1
} else {
var y gcmFieldElement
g.update(&y, nonce)
y.high ^= uint64(len(nonce)) * 8
g.mul(&y)
binary.BigEndian.PutUint64(counter[:8], y.low)
binary.BigEndian.PutUint64(counter[8:], y.high)
}
}
// auth calculates GHASH(ciphertext, additionalData), masks the result with
// tagMask and writes the result to out.
func (g *gcm) auth(out, ciphertext, additionalData []byte, tagMask *[gcmTagSize]byte) {
var y gcmFieldElement
g.update(&y, additionalData)
g.update(&y, ciphertext)
y.low ^= uint64(len(additionalData)) * 8
y.high ^= uint64(len(ciphertext)) * 8
g.mul(&y)
binary.BigEndian.PutUint64(out, y.low)
binary.BigEndian.PutUint64(out[8:], y.high)
subtle.XORBytes(out, out, tagMask[:])
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cipher
import "io"
// The Stream* objects are so simple that all their members are public. Users
// can create them themselves.
// StreamReader wraps a Stream into an io.Reader. It calls XORKeyStream
// to process each slice of data which passes through.
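//
// For example (a sketch; stream is any Stream, such as one returned by
// NewCTR, and the file names are illustrative):
//
// r := cipher.StreamReader{S: stream, R: encryptedFile}
// io.Copy(plaintextFile, r) // reads from r yield decrypted bytes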
type StreamReader struct {
S Stream
R io.Reader
}
func (r StreamReader) Read(dst []byte) (n int, err error) {
n, err = r.R.Read(dst)
r.S.XORKeyStream(dst[:n], dst[:n])
return
}
// StreamWriter wraps a Stream into an io.Writer. It calls XORKeyStream
// to process each slice of data which passes through. If any Write call
// returns short then the StreamWriter is out of sync and must be discarded.
// A StreamWriter has no internal buffering; Close does not need
// to be called to flush write data.
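//
// For example (a sketch; stream and the file names are illustrative):
//
// w := cipher.StreamWriter{S: stream, W: encryptedFile}
// io.Copy(w, plaintextFile) // writes are encrypted before reaching W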
type StreamWriter struct {
S Stream
W io.Writer
Err error // unused
}
func (w StreamWriter) Write(src []byte) (n int, err error) {
c := make([]byte, len(src))
w.S.XORKeyStream(c, src)
n, err = w.W.Write(c)
if n != len(src) && err == nil { // should never happen
err = io.ErrShortWrite
}
return
}
// Close closes the underlying Writer and returns its Close return value, if the Writer
// is also an io.Closer. Otherwise it returns nil.
func (w StreamWriter) Close() error {
if c, ok := w.W.(io.Closer); ok {
return c.Close()
}
return nil
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// OFB (Output Feedback) Mode.
package cipher
import (
"crypto/internal/alias"
"crypto/subtle"
)
type ofb struct {
b Block
cipher []byte
out []byte
outUsed int
}
// NewOFB returns a Stream that encrypts or decrypts using the block cipher b
// in output feedback mode. The initialization vector iv's length must be equal
// to b's block size.
func NewOFB(b Block, iv []byte) Stream {
blockSize := b.BlockSize()
if len(iv) != blockSize {
panic("cipher.NewOFB: IV length must equal block size")
}
bufSize := streamBufferSize
if bufSize < blockSize {
bufSize = blockSize
}
x := &ofb{
b: b,
cipher: make([]byte, blockSize),
out: make([]byte, 0, bufSize),
outUsed: 0,
}
copy(x.cipher, iv)
return x
}
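// A short roundtrip sketch, assuming AES with placeholder key and IV. Because
// OFB derives the keystream from the cipher output alone, the same
// construction both encrypts and decrypts:
//
//	block, err := aes.NewCipher(make([]byte, 16)) // placeholder key
//	if err != nil {
//		panic(err)
//	}
//	iv := make([]byte, aes.BlockSize) // placeholder IV
//	pt := []byte("attack at dawn")
//	ct := make([]byte, len(pt))
//	cipher.NewOFB(block, iv).XORKeyStream(ct, pt)
//	back := make([]byte, len(ct))
//	cipher.NewOFB(block, iv).XORKeyStream(back, ct) // back == pt
//
// refill tops up the buffered keystream in x.out: it moves any unused bytes
// to the front of the buffer, then appends fresh blocks by repeatedly
// encrypting the feedback register x.cipher.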
func (x *ofb) refill() {
bs := x.b.BlockSize()
remain := len(x.out) - x.outUsed
if remain > x.outUsed {
return
}
copy(x.out, x.out[x.outUsed:])
x.out = x.out[:cap(x.out)]
for remain < len(x.out)-bs {
x.b.Encrypt(x.cipher, x.cipher)
copy(x.out[remain:], x.cipher)
remain += bs
}
x.out = x.out[:remain]
x.outUsed = 0
}
func (x *ofb) XORKeyStream(dst, src []byte) {
if len(dst) < len(src) {
panic("crypto/cipher: output smaller than input")
}
if alias.InexactOverlap(dst[:len(src)], src) {
panic("crypto/cipher: invalid buffer overlap")
}
for len(src) > 0 {
if x.outUsed >= len(x.out)-x.b.BlockSize() {
x.refill()
}
n := subtle.XORBytes(dst, src, x.out[x.outUsed:])
dst = dst[n:]
src = src[n:]
x.outUsed += n
}
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package crypto collects common cryptographic constants.
package crypto
import (
"hash"
"io"
"strconv"
)
// Hash identifies a cryptographic hash function that is implemented in another
// package.
type Hash uint
// HashFunc simply returns the value of h so that Hash implements SignerOpts.
func (h Hash) HashFunc() Hash {
return h
}
func (h Hash) String() string {
switch h {
case MD4:
return "MD4"
case MD5:
return "MD5"
case SHA1:
return "SHA-1"
case SHA224:
return "SHA-224"
case SHA256:
return "SHA-256"
case SHA384:
return "SHA-384"
case SHA512:
return "SHA-512"
case MD5SHA1:
return "MD5+SHA1"
case RIPEMD160:
return "RIPEMD-160"
case SHA3_224:
return "SHA3-224"
case SHA3_256:
return "SHA3-256"
case SHA3_384:
return "SHA3-384"
case SHA3_512:
return "SHA3-512"
case SHA512_224:
return "SHA-512/224"
case SHA512_256:
return "SHA-512/256"
case BLAKE2s_256:
return "BLAKE2s-256"
case BLAKE2b_256:
return "BLAKE2b-256"
case BLAKE2b_384:
return "BLAKE2b-384"
case BLAKE2b_512:
return "BLAKE2b-512"
default:
return "unknown hash value " + strconv.Itoa(int(h))
}
}
const (
MD4 Hash = 1 + iota // import golang.org/x/crypto/md4
MD5 // import crypto/md5
SHA1 // import crypto/sha1
SHA224 // import crypto/sha256
SHA256 // import crypto/sha256
SHA384 // import crypto/sha512
SHA512 // import crypto/sha512
MD5SHA1 // no implementation; MD5+SHA1 used for TLS RSA
RIPEMD160 // import golang.org/x/crypto/ripemd160
SHA3_224 // import golang.org/x/crypto/sha3
SHA3_256 // import golang.org/x/crypto/sha3
SHA3_384 // import golang.org/x/crypto/sha3
SHA3_512 // import golang.org/x/crypto/sha3
SHA512_224 // import crypto/sha512
SHA512_256 // import crypto/sha512
BLAKE2s_256 // import golang.org/x/crypto/blake2s
BLAKE2b_256 // import golang.org/x/crypto/blake2b
BLAKE2b_384 // import golang.org/x/crypto/blake2b
BLAKE2b_512 // import golang.org/x/crypto/blake2b
maxHash
)
var digestSizes = []uint8{
MD4: 16,
MD5: 16,
SHA1: 20,
SHA224: 28,
SHA256: 32,
SHA384: 48,
SHA512: 64,
SHA512_224: 28,
SHA512_256: 32,
SHA3_224: 28,
SHA3_256: 32,
SHA3_384: 48,
SHA3_512: 64,
MD5SHA1: 36,
RIPEMD160: 20,
BLAKE2s_256: 32,
BLAKE2b_256: 32,
BLAKE2b_384: 48,
BLAKE2b_512: 64,
}
// Size returns the length, in bytes, of a digest resulting from the given hash
// function. It doesn't require that the hash function in question be linked
// into the program.
func (h Hash) Size() int {
if h > 0 && h < maxHash {
return int(digestSizes[h])
}
panic("crypto: Size of unknown hash function")
}
var hashes = make([]func() hash.Hash, maxHash)
// New returns a new hash.Hash calculating the given hash function. New panics
// if the hash function is not linked into the binary.
func (h Hash) New() hash.Hash {
if h > 0 && h < maxHash {
f := hashes[h]
if f != nil {
return f()
}
}
panic("crypto: requested hash function #" + strconv.Itoa(int(h)) + " is unavailable")
}
// Available reports whether the given hash function is linked into the binary.
func (h Hash) Available() bool {
return h < maxHash && hashes[h] != nil
}
// RegisterHash registers a function that returns a new instance of the given
// hash function. This is intended to be called from the init function in
// packages that implement hash functions.
func RegisterHash(h Hash, f func() hash.Hash) {
if h >= maxHash {
panic("crypto: RegisterHash of unknown hash function")
}
hashes[h] = f
}
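// A usage sketch: a blank import links the implementation, whose init
// function calls RegisterHash; the digest is then reachable through the
// crypto.Hash value alone.
//
//	package main
//
//	import (
//		"crypto"
//		_ "crypto/sha256" // init registers SHA-224 and SHA-256
//		"fmt"
//	)
//
//	func main() {
//		if !crypto.SHA256.Available() {
//			panic("SHA-256 not linked into the binary")
//		}
//		h := crypto.SHA256.New()
//		h.Write([]byte("hello"))
//		fmt.Printf("%x\n", h.Sum(nil)) // 32 bytes, matching crypto.SHA256.Size()
//	}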
// PublicKey represents a public key using an unspecified algorithm.
//
// Although this type is an empty interface for backwards compatibility reasons,
// all public key types in the standard library implement the following interface
//
// interface{
// Equal(x crypto.PublicKey) bool
// }
//
// which can be used for increased type safety within applications.
type PublicKey any
// PrivateKey represents a private key using an unspecified algorithm.
//
// Although this type is an empty interface for backwards compatibility reasons,
// all private key types in the standard library implement the following interface
//
// interface{
// Public() crypto.PublicKey
// Equal(x crypto.PrivateKey) bool
// }
//
// as well as purpose-specific interfaces such as Signer and Decrypter, which
// can be used for increased type safety within applications.
type PrivateKey any
// Signer is an interface for an opaque private key that can be used for
// signing operations. For example, an RSA key kept in a hardware module.
type Signer interface {
// Public returns the public key corresponding to the opaque,
// private key.
Public() PublicKey
// Sign signs digest with the private key, possibly using entropy from
// rand. For an RSA key, the resulting signature should be either a
// PKCS #1 v1.5 or PSS signature (as indicated by opts). For an (EC)DSA
// key, it should be a DER-serialised, ASN.1 signature structure.
//
// Hash implements the SignerOpts interface and, in most cases, one can
// simply pass in the hash function used as opts. Sign may also attempt
// to type assert opts to other types in order to obtain algorithm
// specific values. See the documentation in each package for details.
//
// Note that when a signature of a hash of a larger message is needed,
// the caller is responsible for hashing the larger message and passing
// the hash (as digest) and the hash function (as opts) to Sign.
Sign(rand io.Reader, digest []byte, opts SignerOpts) (signature []byte, err error)
}
// SignerOpts contains options for signing with a Signer.
type SignerOpts interface {
// HashFunc returns an identifier for the hash function used to produce
// the message passed to Signer.Sign, or else zero to indicate that no
// hashing was done.
HashFunc() Hash
}
// Decrypter is an interface for an opaque private key that can be used for
// asymmetric decryption operations. An example would be an RSA key
// kept in a hardware module.
type Decrypter interface {
// Public returns the public key corresponding to the opaque,
// private key.
Public() PublicKey
// Decrypt decrypts msg. The opts argument should be appropriate for
// the primitive used. See the documentation in each implementation for
// details.
Decrypt(rand io.Reader, msg []byte, opts DecrypterOpts) (plaintext []byte, err error)
}
type DecrypterOpts any
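// A sketch of programming against Signer rather than a concrete key type,
// assuming the usual crypto/ecdsa, crypto/elliptic, crypto/rand, and
// crypto/sha256 imports; a hardware-backed key satisfying Signer could be
// swapped in without changing the call site:
//
//	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
//	if err != nil {
//		panic(err)
//	}
//	var signer crypto.Signer = priv
//	digest := sha256.Sum256([]byte("message"))
//	sig, err := signer.Sign(rand.Reader, digest[:], crypto.SHA256)
//	// sig is an ASN.1 DER-encoded ECDSA signature over digest.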
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package des
import (
"encoding/binary"
"sync"
)
func cryptBlock(subkeys []uint64, dst, src []byte, decrypt bool) {
b := binary.BigEndian.Uint64(src)
b = permuteInitialBlock(b)
left, right := uint32(b>>32), uint32(b)
left = (left << 1) | (left >> 31)
right = (right << 1) | (right >> 31)
if decrypt {
for i := 0; i < 8; i++ {
left, right = feistel(left, right, subkeys[15-2*i], subkeys[15-(2*i+1)])
}
} else {
for i := 0; i < 8; i++ {
left, right = feistel(left, right, subkeys[2*i], subkeys[2*i+1])
}
}
left = (left << 31) | (left >> 1)
right = (right << 31) | (right >> 1)
// switch left & right and perform final permutation
preOutput := (uint64(right) << 32) | uint64(left)
binary.BigEndian.PutUint64(dst, permuteFinalBlock(preOutput))
}
// Encrypt one block from src into dst, using the subkeys.
func encryptBlock(subkeys []uint64, dst, src []byte) {
cryptBlock(subkeys, dst, src, false)
}
// Decrypt one block from src into dst, using the subkeys.
func decryptBlock(subkeys []uint64, dst, src []byte) {
cryptBlock(subkeys, dst, src, true)
}
// DES Feistel function. feistelBox must be initialized via
// feistelBoxOnce.Do(initFeistelBox) first.
func feistel(l, r uint32, k0, k1 uint64) (lout, rout uint32) {
var t uint32
t = r ^ uint32(k0>>32)
l ^= feistelBox[7][t&0x3f] ^
feistelBox[5][(t>>8)&0x3f] ^
feistelBox[3][(t>>16)&0x3f] ^
feistelBox[1][(t>>24)&0x3f]
t = ((r << 28) | (r >> 4)) ^ uint32(k0)
l ^= feistelBox[6][(t)&0x3f] ^
feistelBox[4][(t>>8)&0x3f] ^
feistelBox[2][(t>>16)&0x3f] ^
feistelBox[0][(t>>24)&0x3f]
t = l ^ uint32(k1>>32)
r ^= feistelBox[7][t&0x3f] ^
feistelBox[5][(t>>8)&0x3f] ^
feistelBox[3][(t>>16)&0x3f] ^
feistelBox[1][(t>>24)&0x3f]
t = ((l << 28) | (l >> 4)) ^ uint32(k1)
r ^= feistelBox[6][(t)&0x3f] ^
feistelBox[4][(t>>8)&0x3f] ^
feistelBox[2][(t>>16)&0x3f] ^
feistelBox[0][(t>>24)&0x3f]
return l, r
}
// feistelBox[s][16*i+j] contains the output of permutationFunction
// for sBoxes[s][i][j] << 4*(7-s)
var feistelBox [8][64]uint32
var feistelBoxOnce sync.Once
// permuteBlock is a general purpose function to perform DES block permutations.
func permuteBlock(src uint64, permutation []uint8) (block uint64) {
for position, n := range permutation {
bit := (src >> n) & 1
block |= bit << uint((len(permutation)-1)-position)
}
return
}
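// For example, permuteBlock(0b01, []uint8{0, 1}) == 0b10: entry 0 of the
// permutation selects source bit 0 and places it in the most significant
// output position, so the two low bits swap.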
func initFeistelBox() {
for s := range sBoxes {
for i := 0; i < 4; i++ {
for j := 0; j < 16; j++ {
f := uint64(sBoxes[s][i][j]) << (4 * (7 - uint(s)))
f = permuteBlock(f, permutationFunction[:])
// Row is determined by the 1st and 6th bit.
// Column is the middle four bits.
row := uint8(((i & 2) << 4) | i&1)
col := uint8(j << 1)
t := row | col
// The rotation that the feistel rounds would otherwise perform has been
// factored out of them and is mixed into feistelBox here.
f = (f << 1) | (f >> 31)
feistelBox[s][t] = uint32(f)
}
}
}
}
// permuteInitialBlock is equivalent to the permutation defined
// by initialPermutation.
func permuteInitialBlock(block uint64) uint64 {
// block = b7 b6 b5 b4 b3 b2 b1 b0 (8 bytes)
b1 := block >> 48
b2 := block << 48
block ^= b1 ^ b2 ^ b1<<48 ^ b2>>48
// block = b1 b0 b5 b4 b3 b2 b7 b6
b1 = block >> 32 & 0xff00ff
b2 = (block & 0xff00ff00)
block ^= b1<<32 ^ b2 ^ b1<<8 ^ b2<<24 // exchange b0 b4 with b3 b7
// block is now b1 b3 b5 b7 b0 b2 b4 b6, the permutation:
//                  ...  8
//                  ... 24
//                  ... 40
//                  ... 56
//  7  6  5  4  3  2  1  0
// 23 22 21 20 19 18 17 16
//                  ... 32
//                  ... 48
// exchange 4,5,6,7 with 32,33,34,35 etc.
b1 = block & 0x0f0f00000f0f0000
b2 = block & 0x0000f0f00000f0f0
block ^= b1 ^ b2 ^ b1>>12 ^ b2<<12
// block is the permutation:
//
//   [+8]         [+40]
//
//  7  6  5  4
// 23 22 21 20
//  3  2  1  0
// 19 18 17 16    [+32]
// exchange 0,1,4,5 with 18,19,22,23
b1 = block & 0x3300330033003300
b2 = block & 0x00cc00cc00cc00cc
block ^= b1 ^ b2 ^ b1>>6 ^ b2<<6
// block is the permutation:
// 15 14
// 13 12
// 11 10
//  9  8
//  7  6
//  5  4
//  3  2
//  1  0 [+16] [+32] [+64]
// exchange 0,2,4,6 with 9,11,13,15:
b1 = block & 0xaaaaaaaa55555555
block ^= b1 ^ b1>>33 ^ b1<<33
// block is the permutation:
// 6 14 22 30 38 46 54 62
// 4 12 20 28 36 44 52 60
// 2 10 18 26 34 42 50 58
// 0  8 16 24 32 40 48 56
// 7 15 23 31 39 47 55 63
// 5 13 21 29 37 45 53 61
// 3 11 19 27 35 43 51 59
// 1  9 17 25 33 41 49 57
return block
}
// permuteFinalBlock is equivalent to the permutation defined
// by finalPermutation.
func permuteFinalBlock(block uint64) uint64 {
// Perform the same bit exchanges as permuteInitialBlock
// but in reverse order.
b1 := block & 0xaaaaaaaa55555555
block ^= b1 ^ b1>>33 ^ b1<<33
b1 = block & 0x3300330033003300
b2 := block & 0x00cc00cc00cc00cc
block ^= b1 ^ b2 ^ b1>>6 ^ b2<<6
b1 = block & 0x0f0f00000f0f0000
b2 = block & 0x0000f0f00000f0f0
block ^= b1 ^ b2 ^ b1>>12 ^ b2<<12
b1 = block >> 32 & 0xff00ff
b2 = (block & 0xff00ff00)
block ^= b1<<32 ^ b2 ^ b1<<8 ^ b2<<24
b1 = block >> 48
b2 = block << 48
block ^= b1 ^ b2 ^ b1<<48 ^ b2>>48
return block
}
// ksRotate creates 16 28-bit blocks rotated according
// to the rotation schedule.
func ksRotate(in uint32) (out []uint32) {
out = make([]uint32, 16)
last := in
for i := 0; i < 16; i++ {
// 28-bit circular left shift
left := (last << (4 + ksRotations[i])) >> 4
right := (last << 4) >> (32 - ksRotations[i])
out[i] = left | right
last = out[i]
}
return
}
// generateSubkeys creates 16 56-bit subkeys from the original key.
func (c *desCipher) generateSubkeys(keyBytes []byte) {
feistelBoxOnce.Do(initFeistelBox)
// apply PC1 permutation to key
key := binary.BigEndian.Uint64(keyBytes)
permutedKey := permuteBlock(key, permutedChoice1[:])
// rotate halves of permuted key according to the rotation schedule
leftRotations := ksRotate(uint32(permutedKey >> 28))
rightRotations := ksRotate(uint32(permutedKey<<4) >> 4)
// generate subkeys
for i := 0; i < 16; i++ {
// combine halves to form 56-bit input to PC2
pc2Input := uint64(leftRotations[i])<<28 | uint64(rightRotations[i])
// apply PC2 permutation to 7 byte input
c.subkeys[i] = unpack(permuteBlock(pc2Input, permutedChoice2[:]))
}
}
// Expand 48-bit input to 64-bit, with each 6-bit block padded with two extra bits at the top.
// By doing so, the input blocks (four bits each) and the key blocks (six bits each) stay
// well-aligned without extra shifts or rotations.
func unpack(x uint64) uint64 {
return ((x>>(6*1))&0xff)<<(8*0) |
((x>>(6*3))&0xff)<<(8*1) |
((x>>(6*5))&0xff)<<(8*2) |
((x>>(6*7))&0xff)<<(8*3) |
((x>>(6*0))&0xff)<<(8*4) |
((x>>(6*2))&0xff)<<(8*5) |
((x>>(6*4))&0xff)<<(8*6) |
((x>>(6*6))&0xff)<<(8*7)
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package des
import (
"crypto/cipher"
"crypto/internal/alias"
"encoding/binary"
"strconv"
)
// The DES block size in bytes.
const BlockSize = 8
type KeySizeError int
func (k KeySizeError) Error() string {
return "crypto/des: invalid key size " + strconv.Itoa(int(k))
}
// desCipher is an instance of DES encryption.
type desCipher struct {
subkeys [16]uint64
}
// NewCipher creates and returns a new cipher.Block.
func NewCipher(key []byte) (cipher.Block, error) {
if len(key) != 8 {
return nil, KeySizeError(len(key))
}
c := new(desCipher)
c.generateSubkeys(key)
return c, nil
}
func (c *desCipher) BlockSize() int { return BlockSize }
func (c *desCipher) Encrypt(dst, src []byte) {
if len(src) < BlockSize {
panic("crypto/des: input not full block")
}
if len(dst) < BlockSize {
panic("crypto/des: output not full block")
}
if alias.InexactOverlap(dst[:BlockSize], src[:BlockSize]) {
panic("crypto/des: invalid buffer overlap")
}
encryptBlock(c.subkeys[:], dst, src)
}
func (c *desCipher) Decrypt(dst, src []byte) {
if len(src) < BlockSize {
panic("crypto/des: input not full block")
}
if len(dst) < BlockSize {
panic("crypto/des: output not full block")
}
if alias.InexactOverlap(dst[:BlockSize], src[:BlockSize]) {
panic("crypto/des: invalid buffer overlap")
}
decryptBlock(c.subkeys[:], dst, src)
}
// A tripleDESCipher is an instance of TripleDES encryption.
type tripleDESCipher struct {
cipher1, cipher2, cipher3 desCipher
}
// NewTripleDESCipher creates and returns a new cipher.Block.
func NewTripleDESCipher(key []byte) (cipher.Block, error) {
if len(key) != 24 {
return nil, KeySizeError(len(key))
}
c := new(tripleDESCipher)
c.cipher1.generateSubkeys(key[:8])
c.cipher2.generateSubkeys(key[8:16])
c.cipher3.generateSubkeys(key[16:])
return c, nil
}
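// A single-block sketch (DES and TripleDES are legacy ciphers and should not
// be used in new designs); the repeated key byte is a placeholder for a real
// 24-byte key:
//
//	key := bytes.Repeat([]byte{0x6b}, 24) // placeholder key
//	c, err := des.NewTripleDESCipher(key)
//	if err != nil {
//		panic(err)
//	}
//	src := []byte("8 bytes!") // exactly one block
//	dst := make([]byte, des.BlockSize)
//	c.Encrypt(dst, src)
//	pt := make([]byte, des.BlockSize)
//	c.Decrypt(pt, dst) // pt == src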
func (c *tripleDESCipher) BlockSize() int { return BlockSize }
func (c *tripleDESCipher) Encrypt(dst, src []byte) {
if len(src) < BlockSize {
panic("crypto/des: input not full block")
}
if len(dst) < BlockSize {
panic("crypto/des: output not full block")
}
if alias.InexactOverlap(dst[:BlockSize], src[:BlockSize]) {
panic("crypto/des: invalid buffer overlap")
}
b := binary.BigEndian.Uint64(src)
b = permuteInitialBlock(b)
left, right := uint32(b>>32), uint32(b)
left = (left << 1) | (left >> 31)
right = (right << 1) | (right >> 31)
for i := 0; i < 8; i++ {
left, right = feistel(left, right, c.cipher1.subkeys[2*i], c.cipher1.subkeys[2*i+1])
}
for i := 0; i < 8; i++ {
right, left = feistel(right, left, c.cipher2.subkeys[15-2*i], c.cipher2.subkeys[15-(2*i+1)])
}
for i := 0; i < 8; i++ {
left, right = feistel(left, right, c.cipher3.subkeys[2*i], c.cipher3.subkeys[2*i+1])
}
left = (left << 31) | (left >> 1)
right = (right << 31) | (right >> 1)
preOutput := (uint64(right) << 32) | uint64(left)
binary.BigEndian.PutUint64(dst, permuteFinalBlock(preOutput))
}
func (c *tripleDESCipher) Decrypt(dst, src []byte) {
if len(src) < BlockSize {
panic("crypto/des: input not full block")
}
if len(dst) < BlockSize {
panic("crypto/des: output not full block")
}
if alias.InexactOverlap(dst[:BlockSize], src[:BlockSize]) {
panic("crypto/des: invalid buffer overlap")
}
b := binary.BigEndian.Uint64(src)
b = permuteInitialBlock(b)
left, right := uint32(b>>32), uint32(b)
left = (left << 1) | (left >> 31)
right = (right << 1) | (right >> 31)
for i := 0; i < 8; i++ {
left, right = feistel(left, right, c.cipher3.subkeys[15-2*i], c.cipher3.subkeys[15-(2*i+1)])
}
for i := 0; i < 8; i++ {
right, left = feistel(right, left, c.cipher2.subkeys[2*i], c.cipher2.subkeys[2*i+1])
}
for i := 0; i < 8; i++ {
left, right = feistel(left, right, c.cipher1.subkeys[15-2*i], c.cipher1.subkeys[15-(2*i+1)])
}
left = (left << 31) | (left >> 1)
right = (right << 31) | (right >> 1)
preOutput := (uint64(right) << 32) | uint64(left)
binary.BigEndian.PutUint64(dst, permuteFinalBlock(preOutput))
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package dsa implements the Digital Signature Algorithm, as defined in FIPS 186-3.
//
// The DSA operations in this package are not implemented using constant-time algorithms.
//
// Deprecated: DSA is a legacy algorithm, and modern alternatives such as
// Ed25519 (implemented by package crypto/ed25519) should be used instead. Keys
// with 1024-bit moduli (L1024N160 parameters) are cryptographically weak, while
// bigger keys are not widely supported. Note that FIPS 186-5 no longer approves
// DSA for signature generation.
package dsa
import (
"errors"
"io"
"math/big"
"crypto/internal/randutil"
)
// Parameters represents the domain parameters for a key. These parameters can
// be shared across many keys. The bit length of Q must be a multiple of 8.
type Parameters struct {
P, Q, G *big.Int
}
// PublicKey represents a DSA public key.
type PublicKey struct {
Parameters
Y *big.Int
}
// PrivateKey represents a DSA private key.
type PrivateKey struct {
PublicKey
X *big.Int
}
// ErrInvalidPublicKey results when a public key is not usable by this code.
// FIPS is quite strict about the format of DSA keys, but other code may be
// less so. Thus, when using keys which may have been generated by other code,
// this error must be handled.
var ErrInvalidPublicKey = errors.New("crypto/dsa: invalid public key")
// ParameterSizes is an enumeration of the acceptable bit lengths of the primes
// in a set of DSA parameters. See FIPS 186-3, section 4.2.
type ParameterSizes int
const (
L1024N160 ParameterSizes = iota
L2048N224
L2048N256
L3072N256
)
// numMRTests is the number of Miller-Rabin primality tests that we perform. We
// pick the largest recommended number from table C.1 of FIPS 186-3.
const numMRTests = 64
// GenerateParameters puts a random, valid set of DSA parameters into params.
// This function can take many seconds, even on fast machines.
func GenerateParameters(params *Parameters, rand io.Reader, sizes ParameterSizes) error {
// This function doesn't follow FIPS 186-3 exactly in that it doesn't
// use a verification seed to generate the primes. The verification
// seed doesn't appear to be exported or used by other code and
// omitting it makes the code cleaner.
var L, N int
switch sizes {
case L1024N160:
L = 1024
N = 160
case L2048N224:
L = 2048
N = 224
case L2048N256:
L = 2048
N = 256
case L3072N256:
L = 3072
N = 256
default:
return errors.New("crypto/dsa: invalid ParameterSizes")
}
qBytes := make([]byte, N/8)
pBytes := make([]byte, L/8)
q := new(big.Int)
p := new(big.Int)
rem := new(big.Int)
one := new(big.Int)
one.SetInt64(1)
GeneratePrimes:
for {
if _, err := io.ReadFull(rand, qBytes); err != nil {
return err
}
qBytes[len(qBytes)-1] |= 1
qBytes[0] |= 0x80
q.SetBytes(qBytes)
if !q.ProbablyPrime(numMRTests) {
continue
}
for i := 0; i < 4*L; i++ {
if _, err := io.ReadFull(rand, pBytes); err != nil {
return err
}
pBytes[len(pBytes)-1] |= 1
pBytes[0] |= 0x80
p.SetBytes(pBytes)
rem.Mod(p, q)
rem.Sub(rem, one)
p.Sub(p, rem)
if p.BitLen() < L {
continue
}
if !p.ProbablyPrime(numMRTests) {
continue
}
params.P = p
params.Q = q
break GeneratePrimes
}
}
h := new(big.Int)
h.SetInt64(2)
g := new(big.Int)
pm1 := new(big.Int).Sub(p, one)
e := new(big.Int).Div(pm1, q)
for {
g.Exp(h, e, p)
if g.Cmp(one) == 0 {
h.Add(h, one)
continue
}
params.G = g
return nil
}
}
// GenerateKey generates a public and private key pair. The Parameters of the
// PrivateKey must already be valid (see GenerateParameters).
func GenerateKey(priv *PrivateKey, rand io.Reader) error {
if priv.P == nil || priv.Q == nil || priv.G == nil {
return errors.New("crypto/dsa: parameters not set up before generating key")
}
x := new(big.Int)
xBytes := make([]byte, priv.Q.BitLen()/8)
for {
_, err := io.ReadFull(rand, xBytes)
if err != nil {
return err
}
x.SetBytes(xBytes)
if x.Sign() != 0 && x.Cmp(priv.Q) < 0 {
break
}
}
priv.X = x
priv.Y = new(big.Int)
priv.Y.Exp(priv.G, x, priv.P)
return nil
}
// fermatInverse calculates the inverse of k in GF(P) using Fermat's method.
// This has better constant-time properties than Euclid's method (implemented
// in math/big.Int.ModInverse) although math/big itself isn't strictly
// constant-time so it's not perfect.
func fermatInverse(k, P *big.Int) *big.Int {
two := big.NewInt(2)
pMinus2 := new(big.Int).Sub(P, two)
return new(big.Int).Exp(k, pMinus2, P)
}
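// For example, with P = 7 and k = 3: k^(P-2) = 3^5 = 243 ≡ 5 (mod 7), and
// indeed 3 * 5 = 15 ≡ 1 (mod 7), so 5 is the inverse of 3 in GF(7).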
// Sign signs an arbitrary length hash (which should be the result of hashing a
// larger message) using the private key, priv. It returns the signature as a
// pair of integers. The security of the private key depends on the entropy of
// rand.
//
// Note that FIPS 186-3 section 4.6 specifies that the hash should be truncated
// to the byte-length of the subgroup. This function does not perform that
// truncation itself.
//
// Be aware that calling Sign with an attacker-controlled PrivateKey may
// require an arbitrary amount of CPU.
func Sign(rand io.Reader, priv *PrivateKey, hash []byte) (r, s *big.Int, err error) {
randutil.MaybeReadByte(rand)
// FIPS 186-3, section 4.6
n := priv.Q.BitLen()
if priv.Q.Sign() <= 0 || priv.P.Sign() <= 0 || priv.G.Sign() <= 0 || priv.X.Sign() <= 0 || n%8 != 0 {
err = ErrInvalidPublicKey
return
}
n >>= 3
var attempts int
for attempts = 10; attempts > 0; attempts-- {
k := new(big.Int)
buf := make([]byte, n)
for {
_, err = io.ReadFull(rand, buf)
if err != nil {
return
}
k.SetBytes(buf)
// priv.Q must be >= 128 because the test above
// requires it to be > 0 and that
// ceil(log_2(Q)) mod 8 = 0
// Thus this loop will quickly terminate.
if k.Sign() > 0 && k.Cmp(priv.Q) < 0 {
break
}
}
kInv := fermatInverse(k, priv.Q)
r = new(big.Int).Exp(priv.G, k, priv.P)
r.Mod(r, priv.Q)
if r.Sign() == 0 {
continue
}
z := k.SetBytes(hash)
s = new(big.Int).Mul(priv.X, r)
s.Add(s, z)
s.Mod(s, priv.Q)
s.Mul(s, kInv)
s.Mod(s, priv.Q)
if s.Sign() != 0 {
break
}
}
// Only degenerate private keys will require more than a handful of
// attempts.
if attempts == 0 {
return nil, nil, ErrInvalidPublicKey
}
return
}
// Verify verifies the signature in r, s of hash using the public key, pub. It
// reports whether the signature is valid.
//
// Note that FIPS 186-3 section 4.6 specifies that the hash should be truncated
// to the byte-length of the subgroup. This function does not perform that
// truncation itself.
func Verify(pub *PublicKey, hash []byte, r, s *big.Int) bool {
// FIPS 186-3, section 4.7
if pub.P.Sign() == 0 {
return false
}
if r.Sign() < 1 || r.Cmp(pub.Q) >= 0 {
return false
}
if s.Sign() < 1 || s.Cmp(pub.Q) >= 0 {
return false
}
w := new(big.Int).ModInverse(s, pub.Q)
if w == nil {
return false
}
n := pub.Q.BitLen()
if n%8 != 0 {
return false
}
z := new(big.Int).SetBytes(hash)
u1 := new(big.Int).Mul(z, w)
u1.Mod(u1, pub.Q)
u2 := w.Mul(r, w)
u2.Mod(u2, pub.Q)
v := u1.Exp(pub.G, u1, pub.P)
u2.Exp(pub.Y, u2, pub.P)
v.Mul(v, u2)
v.Mod(v, pub.P)
v.Mod(v, pub.Q)
return v.Cmp(r) == 0
}
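// An end-to-end sketch, assuming crypto/rand and crypto/sha256 imports
// (remember that the package as a whole is deprecated, and that
// GenerateParameters can take many seconds):
//
//	var params dsa.Parameters
//	if err := dsa.GenerateParameters(&params, rand.Reader, dsa.L2048N256); err != nil {
//		panic(err)
//	}
//	priv := &dsa.PrivateKey{PublicKey: dsa.PublicKey{Parameters: params}}
//	if err := dsa.GenerateKey(priv, rand.Reader); err != nil {
//		panic(err)
//	}
//	sum := sha256.Sum256([]byte("message")) // 32 bytes, matching N = 256 bits
//	r, s, err := dsa.Sign(rand.Reader, priv, sum[:])
//	if err != nil {
//		panic(err)
//	}
//	ok := dsa.Verify(&priv.PublicKey, sum[:], r, s) // true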
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package ecdh implements Elliptic Curve Diffie-Hellman over
// NIST curves and Curve25519.
package ecdh
import (
"crypto"
"crypto/internal/boring"
"crypto/subtle"
"errors"
"io"
"sync"
)
type Curve interface {
// GenerateKey generates a new PrivateKey from rand.
GenerateKey(rand io.Reader) (*PrivateKey, error)
// NewPrivateKey checks that key is valid and returns a PrivateKey.
//
// For NIST curves, this follows SEC 1, Version 2.0, Section 2.3.6, which
// amounts to decoding the bytes as a fixed length big endian integer and
// checking that the result is lower than the order of the curve. The zero
// private key is also rejected, as the encoding of the corresponding public
// key would be irregular.
//
// For X25519, this only checks the scalar length.
NewPrivateKey(key []byte) (*PrivateKey, error)
// NewPublicKey checks that key is valid and returns a PublicKey.
//
// For NIST curves, this decodes an uncompressed point according to SEC 1,
// Version 2.0, Section 2.3.4. Compressed encodings and the point at
// infinity are rejected.
//
// For X25519, this only checks the u-coordinate length. Adversarially
// selected public keys can cause ECDH to return an error.
NewPublicKey(key []byte) (*PublicKey, error)
// ecdh performs an ECDH exchange and returns the shared secret. It's exposed
// as the PrivateKey.ECDH method.
//
// The private method also allows us to expand the ECDH interface with more
// methods in the future without breaking backwards compatibility.
ecdh(local *PrivateKey, remote *PublicKey) ([]byte, error)
// privateKeyToPublicKey converts a PrivateKey to a PublicKey. It's exposed
// as the PrivateKey.PublicKey method.
//
// This method always succeeds: for X25519, the zero key can't be
// constructed due to clamping; for NIST curves, it is rejected by
// NewPrivateKey.
privateKeyToPublicKey(*PrivateKey) *PublicKey
}
// PublicKey is an ECDH public key, usually a peer's ECDH share sent over the wire.
//
// These keys can be parsed with [crypto/x509.ParsePKIXPublicKey] and encoded
// with [crypto/x509.MarshalPKIXPublicKey]. For NIST curves, they then need to
// be converted with [crypto/ecdsa.PublicKey.ECDH] after parsing.
type PublicKey struct {
curve Curve
publicKey []byte
boring *boring.PublicKeyECDH
}
// Bytes returns a copy of the encoding of the public key.
func (k *PublicKey) Bytes() []byte {
// Copy the public key to a fixed size buffer that can get allocated on the
// caller's stack after inlining.
var buf [133]byte
return append(buf[:0], k.publicKey...)
}
// Equal returns whether x represents the same public key as k.
//
// Note that there can be equivalent public keys with different encodings which
// would return false from this check but behave the same way as inputs to ECDH.
//
// This check is performed in constant time as long as the key types and their
// curve match.
func (k *PublicKey) Equal(x crypto.PublicKey) bool {
xx, ok := x.(*PublicKey)
if !ok {
return false
}
return k.curve == xx.curve &&
subtle.ConstantTimeCompare(k.publicKey, xx.publicKey) == 1
}
func (k *PublicKey) Curve() Curve {
return k.curve
}
// PrivateKey is an ECDH private key, usually kept secret.
//
// These keys can be parsed with [crypto/x509.ParsePKCS8PrivateKey] and encoded
// with [crypto/x509.MarshalPKCS8PrivateKey]. For NIST curves, they then need to
// be converted with [crypto/ecdsa.PrivateKey.ECDH] after parsing.
type PrivateKey struct {
curve Curve
privateKey []byte
boring *boring.PrivateKeyECDH
// publicKey is set under publicKeyOnce, to allow loading private keys with
// NewPrivateKey without having to perform a scalar multiplication.
publicKey *PublicKey
publicKeyOnce sync.Once
}
// ECDH performs an ECDH exchange and returns the shared secret. The PrivateKey
// and PublicKey must use the same curve.
//
// For NIST curves, this performs ECDH as specified in SEC 1, Version 2.0,
// Section 3.3.1, and returns the x-coordinate encoded according to SEC 1,
// Version 2.0, Section 2.3.5. The result is never the point at infinity.
//
// For X25519, this performs ECDH as specified in RFC 7748, Section 6.1. If
// the result is the all-zero value, ECDH returns an error.
func (k *PrivateKey) ECDH(remote *PublicKey) ([]byte, error) {
if k.curve != remote.curve {
return nil, errors.New("crypto/ecdh: private key and public key curves do not match")
}
return k.curve.ecdh(k, remote)
}
// Bytes returns a copy of the encoding of the private key.
func (k *PrivateKey) Bytes() []byte {
// Copy the private key to a fixed size buffer that can get allocated on the
// caller's stack after inlining.
var buf [66]byte
return append(buf[:0], k.privateKey...)
}
// Equal returns whether x represents the same private key as k.
//
// Note that there can be equivalent private keys with different encodings which
// would return false from this check but behave the same way as inputs to ECDH.
//
// This check is performed in constant time as long as the key types and their
// curve match.
func (k *PrivateKey) Equal(x crypto.PrivateKey) bool {
xx, ok := x.(*PrivateKey)
if !ok {
return false
}
return k.curve == xx.curve &&
subtle.ConstantTimeCompare(k.privateKey, xx.privateKey) == 1
}
func (k *PrivateKey) Curve() Curve {
return k.curve
}
func (k *PrivateKey) PublicKey() *PublicKey {
k.publicKeyOnce.Do(func() {
if k.boring != nil {
// Because we already checked in NewPrivateKey that the key is valid,
// there should not be any possible errors from BoringCrypto,
// so we turn the error into a panic.
// (We can't return it anyhow.)
kpub, err := k.boring.PublicKey()
if err != nil {
panic("boringcrypto: " + err.Error())
}
k.publicKey = &PublicKey{
curve: k.curve,
publicKey: kpub.Bytes(),
boring: kpub,
}
} else {
k.publicKey = k.curve.privateKeyToPublicKey(k)
}
})
return k.publicKey
}
// Public implements the implicit interface of all standard library private
// keys. See the docs of crypto.PrivateKey.
func (k *PrivateKey) Public() crypto.PublicKey {
return k.PublicKey()
}
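// A key-agreement sketch, assuming crypto/rand; both sides derive the same
// shared secret, shown here with X25519 (the NIST curves work identically):
//
//	alice, err := ecdh.X25519().GenerateKey(rand.Reader)
//	if err != nil {
//		panic(err)
//	}
//	bob, err := ecdh.X25519().GenerateKey(rand.Reader)
//	if err != nil {
//		panic(err)
//	}
//	s1, _ := alice.ECDH(bob.PublicKey())
//	s2, _ := bob.ECDH(alice.PublicKey())
//	// bytes.Equal(s1, s2) == true; run the raw secret through a KDF
//	// before using it as a symmetric key.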
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ecdh
import (
"crypto/internal/boring"
"crypto/internal/nistec"
"crypto/internal/randutil"
"encoding/binary"
"errors"
"io"
"math/bits"
)
type nistCurve[Point nistPoint[Point]] struct {
name string
newPoint func() Point
scalarOrder []byte
}
// nistPoint is a generic constraint for the nistec Point types.
type nistPoint[T any] interface {
Bytes() []byte
BytesX() ([]byte, error)
SetBytes([]byte) (T, error)
ScalarMult(T, []byte) (T, error)
ScalarBaseMult([]byte) (T, error)
}
func (c *nistCurve[Point]) String() string {
return c.name
}
var errInvalidPrivateKey = errors.New("crypto/ecdh: invalid private key")
func (c *nistCurve[Point]) GenerateKey(rand io.Reader) (*PrivateKey, error) {
if boring.Enabled && rand == boring.RandReader {
key, bytes, err := boring.GenerateKeyECDH(c.name)
if err != nil {
return nil, err
}
return newBoringPrivateKey(c, key, bytes)
}
key := make([]byte, len(c.scalarOrder))
randutil.MaybeReadByte(rand)
for {
if _, err := io.ReadFull(rand, key); err != nil {
return nil, err
}
// Mask off any excess bits if the size of the underlying field is not a
// whole number of bytes, which is only the case for P-521. We use a
// pointer to the scalarOrder field because comparing generic and
// instantiated types is not supported.
if &c.scalarOrder[0] == &p521Order[0] {
key[0] &= 0b0000_0001
}
// In tests, rand will return all zeros and NewPrivateKey will reject
// the zero key as it generates the identity as a public key. This also
// makes this function consistent with crypto/elliptic.GenerateKey.
key[1] ^= 0x42
k, err := c.NewPrivateKey(key)
if err == errInvalidPrivateKey {
continue
}
return k, err
}
}
func (c *nistCurve[Point]) NewPrivateKey(key []byte) (*PrivateKey, error) {
if len(key) != len(c.scalarOrder) {
return nil, errors.New("crypto/ecdh: invalid private key size")
}
if isZero(key) || !isLess(key, c.scalarOrder) {
return nil, errInvalidPrivateKey
}
if boring.Enabled {
bk, err := boring.NewPrivateKeyECDH(c.name, key)
if err != nil {
return nil, err
}
return newBoringPrivateKey(c, bk, key)
}
k := &PrivateKey{
curve: c,
privateKey: append([]byte{}, key...),
}
return k, nil
}
func newBoringPrivateKey(c Curve, bk *boring.PrivateKeyECDH, privateKey []byte) (*PrivateKey, error) {
k := &PrivateKey{
curve: c,
boring: bk,
privateKey: append([]byte(nil), privateKey...),
}
return k, nil
}
func (c *nistCurve[Point]) privateKeyToPublicKey(key *PrivateKey) *PublicKey {
boring.Unreachable()
if key.curve != c {
panic("crypto/ecdh: internal error: converting the wrong key type")
}
p, err := c.newPoint().ScalarBaseMult(key.privateKey)
if err != nil {
// This is unreachable because the only error condition of
// ScalarBaseMult is if the input is not the right size.
panic("crypto/ecdh: internal error: nistec ScalarBaseMult failed for a fixed-size input")
}
publicKey := p.Bytes()
if len(publicKey) == 1 {
// The encoding of the identity is a single 0x00 byte. This is
// unreachable because the only scalar that generates the identity is
// zero, which is rejected by NewPrivateKey.
panic("crypto/ecdh: internal error: nistec ScalarBaseMult returned the identity")
}
return &PublicKey{
curve: key.curve,
publicKey: publicKey,
}
}
// isZero returns whether a is all zeroes in constant time.
func isZero(a []byte) bool {
var acc byte
for _, b := range a {
acc |= b
}
return acc == 0
}
// isLess returns whether a < b, where a and b are big-endian buffers of the
// same length and shorter than 72 bytes.
func isLess(a, b []byte) bool {
if len(a) != len(b) {
panic("crypto/ecdh: internal error: mismatched isLess inputs")
}
// Copy the values into a fixed-size preallocated little-endian buffer.
// 72 bytes is enough for every scalar in this package, and having a fixed
// size lets us avoid heap allocations.
if len(a) > 72 {
panic("crypto/ecdh: internal error: isLess input too large")
}
bufA, bufB := make([]byte, 72), make([]byte, 72)
for i := range a {
bufA[i], bufB[i] = a[len(a)-i-1], b[len(b)-i-1]
}
// Perform a subtraction with borrow.
var borrow uint64
for i := 0; i < len(bufA); i += 8 {
limbA, limbB := binary.LittleEndian.Uint64(bufA[i:]), binary.LittleEndian.Uint64(bufB[i:])
_, borrow = bits.Sub64(limbA, limbB, borrow)
}
// If there is a borrow at the end of the operation, then a < b.
return borrow == 1
}
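// For example, bits.Sub64(1, 2, 0) returns the wrapped difference 2^64-1
// together with a borrow of 1: the borrow-out reports that the subtrahend
// was larger, which is exactly the signal isLess propagates across limbs.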
func (c *nistCurve[Point]) NewPublicKey(key []byte) (*PublicKey, error) {
// Reject the point at infinity and compressed encodings.
if len(key) == 0 || key[0] != 4 {
return nil, errors.New("crypto/ecdh: invalid public key")
}
k := &PublicKey{
curve: c,
publicKey: append([]byte{}, key...),
}
if boring.Enabled {
bk, err := boring.NewPublicKeyECDH(c.name, k.publicKey)
if err != nil {
return nil, err
}
k.boring = bk
} else {
// SetBytes also checks that the point is on the curve.
if _, err := c.newPoint().SetBytes(key); err != nil {
return nil, err
}
}
return k, nil
}
func (c *nistCurve[Point]) ecdh(local *PrivateKey, remote *PublicKey) ([]byte, error) {
// Note that this function can't return an error, as NewPublicKey rejects
// invalid points and the point at infinity, and NewPrivateKey rejects
// invalid scalars and the zero value. BytesX returns an error for the point
// at infinity, but in a prime order group such as the NIST curves that can
// only be the result of a scalar multiplication if one of the inputs is the
// zero scalar or the point at infinity.
if boring.Enabled {
return boring.ECDH(local.boring, remote.boring)
}
boring.Unreachable()
p, err := c.newPoint().SetBytes(remote.publicKey)
if err != nil {
return nil, err
}
if _, err := p.ScalarMult(p, local.privateKey); err != nil {
return nil, err
}
return p.BytesX()
}
// P256 returns a Curve which implements NIST P-256 (FIPS 186-3, section D.2.3),
// also known as secp256r1 or prime256v1.
//
// Multiple invocations of this function will return the same value, which can
// be used for equality checks and switch statements.
func P256() Curve { return p256 }
var p256 = &nistCurve[*nistec.P256Point]{
name: "P-256",
newPoint: nistec.NewP256Point,
scalarOrder: p256Order,
}
var p256Order = []byte{
0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xbc, 0xe6, 0xfa, 0xad, 0xa7, 0x17, 0x9e, 0x84,
0xf3, 0xb9, 0xca, 0xc2, 0xfc, 0x63, 0x25, 0x51}
// P384 returns a Curve which implements NIST P-384 (FIPS 186-3, section D.2.4),
// also known as secp384r1.
//
// Multiple invocations of this function will return the same value, which can
// be used for equality checks and switch statements.
func P384() Curve { return p384 }
var p384 = &nistCurve[*nistec.P384Point]{
name: "P-384",
newPoint: nistec.NewP384Point,
scalarOrder: p384Order,
}
var p384Order = []byte{
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xc7, 0x63, 0x4d, 0x81, 0xf4, 0x37, 0x2d, 0xdf,
0x58, 0x1a, 0x0d, 0xb2, 0x48, 0xb0, 0xa7, 0x7a,
0xec, 0xec, 0x19, 0x6a, 0xcc, 0xc5, 0x29, 0x73}
// P521 returns a Curve which implements NIST P-521 (FIPS 186-3, section D.2.5),
// also known as secp521r1.
//
// Multiple invocations of this function will return the same value, which can
// be used for equality checks and switch statements.
func P521() Curve { return p521 }
var p521 = &nistCurve[*nistec.P521Point]{
name: "P-521",
newPoint: nistec.NewP521Point,
scalarOrder: p521Order,
}
var p521Order = []byte{0x01, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfa,
0x51, 0x86, 0x87, 0x83, 0xbf, 0x2f, 0x96, 0x6b,
0x7f, 0xcc, 0x01, 0x48, 0xf7, 0x09, 0xa5, 0xd0,
0x3b, 0xb5, 0xc9, 0xb8, 0x89, 0x9c, 0x47, 0xae,
0xbb, 0x6f, 0xb7, 0x1e, 0x91, 0x38, 0x64, 0x09}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ecdh
import (
"crypto/internal/edwards25519/field"
"crypto/internal/randutil"
"errors"
"io"
)
var (
x25519PublicKeySize = 32
x25519PrivateKeySize = 32
x25519SharedSecretSize = 32
)
// X25519 returns a Curve which implements the X25519 function over Curve25519
// (RFC 7748, Section 5).
//
// Multiple invocations of this function will return the same value, so it can
// be used for equality checks and switch statements.
func X25519() Curve { return x25519 }
var x25519 = &x25519Curve{}
type x25519Curve struct{}
func (c *x25519Curve) String() string {
return "X25519"
}
func (c *x25519Curve) GenerateKey(rand io.Reader) (*PrivateKey, error) {
key := make([]byte, x25519PrivateKeySize)
randutil.MaybeReadByte(rand)
if _, err := io.ReadFull(rand, key); err != nil {
return nil, err
}
return c.NewPrivateKey(key)
}
func (c *x25519Curve) NewPrivateKey(key []byte) (*PrivateKey, error) {
if len(key) != x25519PrivateKeySize {
return nil, errors.New("crypto/ecdh: invalid private key size")
}
return &PrivateKey{
curve: c,
privateKey: append([]byte{}, key...),
}, nil
}
func (c *x25519Curve) privateKeyToPublicKey(key *PrivateKey) *PublicKey {
if key.curve != c {
panic("crypto/ecdh: internal error: converting the wrong key type")
}
k := &PublicKey{
curve: key.curve,
publicKey: make([]byte, x25519PublicKeySize),
}
x25519Basepoint := [32]byte{9}
x25519ScalarMult(k.publicKey, key.privateKey, x25519Basepoint[:])
return k
}
func (c *x25519Curve) NewPublicKey(key []byte) (*PublicKey, error) {
if len(key) != x25519PublicKeySize {
return nil, errors.New("crypto/ecdh: invalid public key")
}
return &PublicKey{
curve: c,
publicKey: append([]byte{}, key...),
}, nil
}
func (c *x25519Curve) ecdh(local *PrivateKey, remote *PublicKey) ([]byte, error) {
out := make([]byte, x25519SharedSecretSize)
x25519ScalarMult(out, local.privateKey, remote.publicKey)
if isZero(out) {
return nil, errors.New("crypto/ecdh: bad X25519 remote ECDH input: low order point")
}
return out, nil
}
func x25519ScalarMult(dst, scalar, point []byte) {
var e [32]byte
copy(e[:], scalar[:])
e[0] &= 248
e[31] &= 127
e[31] |= 64
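// The three masks above implement RFC 7748 clamping: clearing the low
// three bits makes the scalar a multiple of the cofactor 8, removing any
// small-subgroup component of the peer's point, while clearing bit 255
// and setting bit 254 fixes the scalar's top so the Montgomery ladder
// below runs the same number of iterations for every key.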
var x1, x2, z2, x3, z3, tmp0, tmp1 field.Element
x1.SetBytes(point[:])
x2.One()
x3.Set(&x1)
z3.One()
swap := 0
for pos := 254; pos >= 0; pos-- {
b := e[pos/8] >> uint(pos&7)
b &= 1
swap ^= int(b)
x2.Swap(&x3, swap)
z2.Swap(&z3, swap)
swap = int(b)
tmp0.Subtract(&x3, &z3)
tmp1.Subtract(&x2, &z2)
x2.Add(&x2, &z2)
z2.Add(&x3, &z3)
z3.Multiply(&tmp0, &x2)
z2.Multiply(&z2, &tmp1)
tmp0.Square(&tmp1)
tmp1.Square(&x2)
x3.Add(&z3, &z2)
z2.Subtract(&z3, &z2)
x2.Multiply(&tmp1, &tmp0)
tmp1.Subtract(&tmp1, &tmp0)
z2.Square(&z2)
z3.Mult32(&tmp1, 121666)
x3.Square(&x3)
tmp0.Add(&tmp0, &z3)
z3.Multiply(&x1, &z2)
z2.Multiply(&tmp1, &tmp0)
}
x2.Swap(&x3, swap)
z2.Swap(&z3, swap)
z2.Invert(&z2)
x2.Multiply(&x2, &z2)
copy(dst[:], x2.Bytes())
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package ecdsa implements the Elliptic Curve Digital Signature Algorithm, as
// defined in FIPS 186-4 and SEC 1, Version 2.0.
//
// Signatures generated by this package are not deterministic, but entropy is
// mixed with the private key and the message, achieving the same level of
// security in case of randomness source failure.
package ecdsa
// [FIPS 186-4] references ANSI X9.62-2005 for the bulk of the ECDSA algorithm.
// That standard is not freely available, which is a problem in an open source
// implementation, because not only the implementer, but also any maintainer,
// contributor, reviewer, auditor, and learner needs access to it. Instead, this
// package references and follows the equivalent [SEC 1, Version 2.0].
//
// [FIPS 186-4]: https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.186-4.pdf
// [SEC 1, Version 2.0]: https://www.secg.org/sec1-v2.pdf
import (
"bytes"
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/ecdh"
"crypto/elliptic"
"crypto/internal/bigmod"
"crypto/internal/boring"
"crypto/internal/boring/bbig"
"crypto/internal/nistec"
"crypto/internal/randutil"
"crypto/sha512"
"crypto/subtle"
"errors"
"io"
"math/big"
"sync"
"golang.org/x/crypto/cryptobyte"
"golang.org/x/crypto/cryptobyte/asn1"
)
// PublicKey represents an ECDSA public key.
type PublicKey struct {
elliptic.Curve
X, Y *big.Int
}
// Any methods implemented on PublicKey might need to also be implemented on
// PrivateKey, as the latter embeds the former and will expose its methods.
// ECDH returns k as an [ecdh.PublicKey]. It returns an error if the key is
// invalid according to the definition of [ecdh.Curve.NewPublicKey], or if the
// Curve is not supported by crypto/ecdh.
func (k *PublicKey) ECDH() (*ecdh.PublicKey, error) {
c := curveToECDH(k.Curve)
if c == nil {
return nil, errors.New("ecdsa: unsupported curve by crypto/ecdh")
}
if !k.Curve.IsOnCurve(k.X, k.Y) {
return nil, errors.New("ecdsa: invalid public key")
}
return c.NewPublicKey(elliptic.Marshal(k.Curve, k.X, k.Y))
}
// Equal reports whether pub and x have the same value.
//
// Two keys are only considered to have the same value if they have the same Curve value.
// Note that for example elliptic.P256() and elliptic.P256().Params() are different
// values, as the latter is a generic, non-constant-time implementation.
func (pub *PublicKey) Equal(x crypto.PublicKey) bool {
xx, ok := x.(*PublicKey)
if !ok {
return false
}
return bigIntEqual(pub.X, xx.X) && bigIntEqual(pub.Y, xx.Y) &&
// Standard library Curve implementations are singletons, so this check
// will work for those. Other Curves might be equivalent even if not
// singletons, but there is no definitive way to check for that, and
// better to err on the side of safety.
pub.Curve == xx.Curve
}
// PrivateKey represents an ECDSA private key.
type PrivateKey struct {
PublicKey
D *big.Int
}
// ECDH returns k as an [ecdh.PrivateKey]. It returns an error if the key is
// invalid according to the definition of [ecdh.Curve.NewPrivateKey], or if the
// Curve is not supported by crypto/ecdh.
func (k *PrivateKey) ECDH() (*ecdh.PrivateKey, error) {
c := curveToECDH(k.Curve)
if c == nil {
return nil, errors.New("ecdsa: unsupported curve by crypto/ecdh")
}
size := (k.Curve.Params().N.BitLen() + 7) / 8
if k.D.BitLen() > size*8 {
return nil, errors.New("ecdsa: invalid private key")
}
return c.NewPrivateKey(k.D.FillBytes(make([]byte, size)))
}
func curveToECDH(c elliptic.Curve) ecdh.Curve {
switch c {
case elliptic.P256():
return ecdh.P256()
case elliptic.P384():
return ecdh.P384()
case elliptic.P521():
return ecdh.P521()
default:
return nil
}
}
// Public returns the public key corresponding to priv.
func (priv *PrivateKey) Public() crypto.PublicKey {
return &priv.PublicKey
}
// Equal reports whether priv and x have the same value.
//
// See PublicKey.Equal for details on how Curve is compared.
func (priv *PrivateKey) Equal(x crypto.PrivateKey) bool {
xx, ok := x.(*PrivateKey)
if !ok {
return false
}
return priv.PublicKey.Equal(&xx.PublicKey) && bigIntEqual(priv.D, xx.D)
}
// bigIntEqual reports whether a and b are equal leaking only their bit length
// through timing side-channels.
func bigIntEqual(a, b *big.Int) bool {
return subtle.ConstantTimeCompare(a.Bytes(), b.Bytes()) == 1
}
// Sign signs digest with priv, reading randomness from rand. The opts argument
// is not currently used but, in keeping with the crypto.Signer interface,
// should be the hash function used to digest the message.
//
// This method implements crypto.Signer, which is an interface to support keys
// where the private part is kept in, for example, a hardware module. Common
// uses can use the SignASN1 function in this package directly.
func (priv *PrivateKey) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
return SignASN1(rand, priv, digest)
}
// GenerateKey generates a public and private key pair.
func GenerateKey(c elliptic.Curve, rand io.Reader) (*PrivateKey, error) {
randutil.MaybeReadByte(rand)
if boring.Enabled && rand == boring.RandReader {
x, y, d, err := boring.GenerateKeyECDSA(c.Params().Name)
if err != nil {
return nil, err
}
return &PrivateKey{PublicKey: PublicKey{Curve: c, X: bbig.Dec(x), Y: bbig.Dec(y)}, D: bbig.Dec(d)}, nil
}
boring.UnreachableExceptTests()
switch c.Params() {
case elliptic.P224().Params():
return generateNISTEC(p224(), rand)
case elliptic.P256().Params():
return generateNISTEC(p256(), rand)
case elliptic.P384().Params():
return generateNISTEC(p384(), rand)
case elliptic.P521().Params():
return generateNISTEC(p521(), rand)
default:
return generateLegacy(c, rand)
}
}
func generateNISTEC[Point nistPoint[Point]](c *nistCurve[Point], rand io.Reader) (*PrivateKey, error) {
k, Q, err := randomPoint(c, rand)
if err != nil {
return nil, err
}
priv := new(PrivateKey)
priv.PublicKey.Curve = c.curve
priv.D = new(big.Int).SetBytes(k.Bytes(c.N))
priv.PublicKey.X, priv.PublicKey.Y, err = c.pointToAffine(Q)
if err != nil {
return nil, err
}
return priv, nil
}
// randomPoint returns a random scalar and the corresponding point using the
// procedure given in FIPS 186-4, Appendix B.5.2 (rejection sampling).
func randomPoint[Point nistPoint[Point]](c *nistCurve[Point], rand io.Reader) (k *bigmod.Nat, p Point, err error) {
k = bigmod.NewNat()
for {
b := make([]byte, c.N.Size())
if _, err = io.ReadFull(rand, b); err != nil {
return
}
// Mask off any excess bits to increase the chance of hitting a value in
// (0, N). These are the most dangerous lines in the package and maybe in
// the library: a single bit of bias in the selection of nonces would likely
// lead to key recovery, but no tests would fail. Look but DO NOT TOUCH.
if excess := len(b)*8 - c.N.BitLen(); excess > 0 {
// Just to be safe, assert that this only happens for the one curve that
// doesn't have a round number of bits.
if excess != 0 && c.curve.Params().Name != "P-521" {
panic("ecdsa: internal error: unexpectedly masking off bits")
}
b[0] >>= excess
}
// FIPS 186-4 makes us check k <= N - 2 and then add one.
// Checking 0 < k <= N - 1 is strictly equivalent.
// None of this matters anyway because the chance of selecting
// zero is cryptographically negligible.
if _, err = k.SetBytes(b, c.N); err == nil && k.IsZero() == 0 {
break
}
if testingOnlyRejectionSamplingLooped != nil {
testingOnlyRejectionSamplingLooped()
}
}
p, err = c.newPoint().ScalarBaseMult(k.Bytes(c.N))
return
}
// testingOnlyRejectionSamplingLooped is called when rejection sampling in
// randomPoint rejects a candidate for being higher than the modulus.
var testingOnlyRejectionSamplingLooped func()
// errNoAsm is returned by signAsm and verifyAsm when the assembly
// implementation is not available.
var errNoAsm = errors.New("no assembly implementation available")
// SignASN1 signs a hash (which should be the result of hashing a larger message)
// using the private key, priv. If the hash is longer than the bit-length of the
// private key's curve order, the hash will be truncated to that length. It
// returns the ASN.1 encoded signature.
func SignASN1(rand io.Reader, priv *PrivateKey, hash []byte) ([]byte, error) {
randutil.MaybeReadByte(rand)
if boring.Enabled && rand == boring.RandReader {
b, err := boringPrivateKey(priv)
if err != nil {
return nil, err
}
return boring.SignMarshalECDSA(b, hash)
}
boring.UnreachableExceptTests()
csprng, err := mixedCSPRNG(rand, priv, hash)
if err != nil {
return nil, err
}
if sig, err := signAsm(priv, csprng, hash); err != errNoAsm {
return sig, err
}
switch priv.Curve.Params() {
case elliptic.P224().Params():
return signNISTEC(p224(), priv, csprng, hash)
case elliptic.P256().Params():
return signNISTEC(p256(), priv, csprng, hash)
case elliptic.P384().Params():
return signNISTEC(p384(), priv, csprng, hash)
case elliptic.P521().Params():
return signNISTEC(p521(), priv, csprng, hash)
default:
return signLegacy(priv, csprng, hash)
}
}
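// A sign/verify sketch, assuming crypto/elliptic, crypto/rand, and
// crypto/sha256 imports:
//
//	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
//	if err != nil {
//		panic(err)
//	}
//	sum := sha256.Sum256([]byte("message"))
//	sig, err := ecdsa.SignASN1(rand.Reader, priv, sum[:])
//	if err != nil {
//		panic(err)
//	}
//	ok := ecdsa.VerifyASN1(&priv.PublicKey, sum[:], sig) // true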
func signNISTEC[Point nistPoint[Point]](c *nistCurve[Point], priv *PrivateKey, csprng io.Reader, hash []byte) (sig []byte, err error) {
// SEC 1, Version 2.0, Section 4.1.3
k, R, err := randomPoint(c, csprng)
if err != nil {
return nil, err
}
// kInv = k⁻¹
kInv := bigmod.NewNat()
inverse(c, kInv, k)
Rx, err := R.BytesX()
if err != nil {
return nil, err
}
r, err := bigmod.NewNat().SetOverflowingBytes(Rx, c.N)
if err != nil {
return nil, err
}
// The spec wants us to retry here, but the chance of hitting this condition
// on a large prime-order group like the NIST curves we support is
// cryptographically negligible. If we hit it, something is awfully wrong.
if r.IsZero() == 1 {
return nil, errors.New("ecdsa: internal error: r is zero")
}
e := bigmod.NewNat()
hashToNat(c, e, hash)
s, err := bigmod.NewNat().SetBytes(priv.D.Bytes(), c.N)
if err != nil {
return nil, err
}
s.Mul(r, c.N)
s.Add(e, c.N)
s.Mul(kInv, c.N)
// Again, the chance of this happening is cryptographically negligible.
if s.IsZero() == 1 {
return nil, errors.New("ecdsa: internal error: s is zero")
}
return encodeSignature(r.Bytes(c.N), s.Bytes(c.N))
}
func encodeSignature(r, s []byte) ([]byte, error) {
var b cryptobyte.Builder
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
addASN1IntBytes(b, r)
addASN1IntBytes(b, s)
})
return b.Bytes()
}
// addASN1IntBytes encodes in ASN.1 a positive integer represented as
// a big-endian byte slice with zero or more leading zeroes.
func addASN1IntBytes(b *cryptobyte.Builder, bytes []byte) {
for len(bytes) > 0 && bytes[0] == 0 {
bytes = bytes[1:]
}
if len(bytes) == 0 {
b.SetError(errors.New("invalid integer"))
return
}
b.AddASN1(asn1.INTEGER, func(c *cryptobyte.Builder) {
if bytes[0]&0x80 != 0 {
c.AddUint8(0)
}
c.AddBytes(bytes)
})
}
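// For example, the single byte 0x80 encodes as 02 02 00 80: the 0x00 pad
// keeps DER from reading the high bit as a sign bit, whereas 0x7f needs no
// padding and encodes as 02 01 7f.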
// inverse sets kInv to the inverse of k modulo the order of the curve.
func inverse[Point nistPoint[Point]](c *nistCurve[Point], kInv, k *bigmod.Nat) {
if c.curve.Params().Name == "P-256" {
kBytes, err := nistec.P256OrdInverse(k.Bytes(c.N))
// Some platforms don't implement P256OrdInverse, and always return an error.
if err == nil {
_, err := kInv.SetBytes(kBytes, c.N)
if err != nil {
panic("ecdsa: internal error: P256OrdInverse produced an invalid value")
}
return
}
}
// Calculate the inverse of k in GF(N) using Fermat's method
// (raising k to the power N - 2 modulo N, per Fermat's little theorem).
kInv.Exp(k, c.nMinus2, c.N)
}
// hashToNat sets e to the left-most bits of hash, according to
// SEC 1, Section 4.1.3, point 5 and Section 4.1.4, point 3.
func hashToNat[Point nistPoint[Point]](c *nistCurve[Point], e *bigmod.Nat, hash []byte) {
// ECDSA asks us to take the left-most log2(N) bits of hash, and use them as
// an integer modulo N. This is the absolute worst of all worlds: we still
// have to reduce, because the result might still overflow N, but to take
// the left-most bits for P-521 we have to do a right shift.
if size := c.N.Size(); len(hash) > size {
hash = hash[:size]
if excess := len(hash)*8 - c.N.BitLen(); excess > 0 {
hash = bytes.Clone(hash)
for i := len(hash) - 1; i >= 0; i-- {
hash[i] >>= excess
if i > 0 {
hash[i] |= hash[i-1] << (8 - excess)
}
}
}
}
_, err := e.SetOverflowingBytes(hash, c.N)
if err != nil {
panic("ecdsa: internal error: truncated hash is too long")
}
}
// mixedCSPRNG returns a CSPRNG that mixes entropy from rand with the message
// and the private key, to protect the key in case rand fails. This is
// equivalent in security to RFC 6979 deterministic nonce generation, but still
// produces randomized signatures.
func mixedCSPRNG(rand io.Reader, priv *PrivateKey, hash []byte) (io.Reader, error) {
// This implementation derives the nonce from an AES-CTR CSPRNG keyed by:
//
// SHA2-512(priv.D || entropy || hash)[:32]
//
// The CSPRNG key is indifferentiable from a random oracle as shown in
// [Coron], and the AES-CTR stream is indifferentiable from a random oracle
// under standard cryptographic assumptions (see [Larsson] for examples).
//
// [Coron]: https://cs.nyu.edu/~dodis/ps/merkle.pdf
// [Larsson]: https://web.archive.org/web/20040719170906/https://www.nada.kth.se/kurser/kth/2D1441/semteo03/lecturenotes/assump.pdf
// Get 256 bits of entropy from rand.
entropy := make([]byte, 32)
if _, err := io.ReadFull(rand, entropy); err != nil {
return nil, err
}
// Initialize an SHA-512 hash context; digest...
md := sha512.New()
md.Write(priv.D.Bytes()) // the private key,
md.Write(entropy) // the entropy,
md.Write(hash) // and the input hash;
key := md.Sum(nil)[:32] // and compute ChopMD-256(SHA-512),
// which is an indifferentiable MAC.
// Create an AES-CTR instance to use as a CSPRNG.
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
// Create a CSPRNG that xors a stream of zeros with
// the output of the AES-CTR instance.
const aesIV = "IV for ECDSA CTR"
return &cipher.StreamReader{
R: zeroReader,
S: cipher.NewCTR(block, []byte(aesIV)),
}, nil
}
type zr struct{}
var zeroReader = zr{}
// Read replaces the contents of dst with zeros. It is safe for concurrent use.
func (zr) Read(dst []byte) (n int, err error) {
for i := range dst {
dst[i] = 0
}
return len(dst), nil
}
// VerifyASN1 verifies the ASN.1 encoded signature, sig, of hash using the
// public key, pub. Its return value records whether the signature is valid.
func VerifyASN1(pub *PublicKey, hash, sig []byte) bool {
if boring.Enabled {
key, err := boringPublicKey(pub)
if err != nil {
return false
}
return boring.VerifyECDSA(key, hash, sig)
}
boring.UnreachableExceptTests()
if err := verifyAsm(pub, hash, sig); err != errNoAsm {
return err == nil
}
switch pub.Curve.Params() {
case elliptic.P224().Params():
return verifyNISTEC(p224(), pub, hash, sig)
case elliptic.P256().Params():
return verifyNISTEC(p256(), pub, hash, sig)
case elliptic.P384().Params():
return verifyNISTEC(p384(), pub, hash, sig)
case elliptic.P521().Params():
return verifyNISTEC(p521(), pub, hash, sig)
default:
return verifyLegacy(pub, hash, sig)
}
}
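// A minimal end-to-end sketch of this API from a client's perspective
// (hypothetical variable names; error handling elided; assumes crypto/rand
// and crypto/sha256):
//
//	priv, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
//	digest := sha256.Sum256([]byte("message"))
//	sig, _ := ecdsa.SignASN1(rand.Reader, priv, digest[:])
//	ok := ecdsa.VerifyASN1(&priv.PublicKey, digest[:], sig) // true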
func verifyNISTEC[Point nistPoint[Point]](c *nistCurve[Point], pub *PublicKey, hash, sig []byte) bool {
rBytes, sBytes, err := parseSignature(sig)
if err != nil {
return false
}
Q, err := c.pointFromAffine(pub.X, pub.Y)
if err != nil {
return false
}
// SEC 1, Version 2.0, Section 4.1.4
r, err := bigmod.NewNat().SetBytes(rBytes, c.N)
if err != nil || r.IsZero() == 1 {
return false
}
s, err := bigmod.NewNat().SetBytes(sBytes, c.N)
if err != nil || s.IsZero() == 1 {
return false
}
e := bigmod.NewNat()
hashToNat(c, e, hash)
// w = s⁻¹
w := bigmod.NewNat()
inverse(c, w, s)
// p₁ = [e * s⁻¹]G
p1, err := c.newPoint().ScalarBaseMult(e.Mul(w, c.N).Bytes(c.N))
if err != nil {
return false
}
// p₂ = [r * s⁻¹]Q
p2, err := Q.ScalarMult(Q, w.Mul(r, c.N).Bytes(c.N))
if err != nil {
return false
}
// BytesX returns an error for the point at infinity.
Rx, err := p1.Add(p1, p2).BytesX()
if err != nil {
return false
}
v, err := bigmod.NewNat().SetOverflowingBytes(Rx, c.N)
if err != nil {
return false
}
return v.Equal(r) == 1
}
func parseSignature(sig []byte) (r, s []byte, err error) {
var inner cryptobyte.String
input := cryptobyte.String(sig)
if !input.ReadASN1(&inner, asn1.SEQUENCE) ||
!input.Empty() ||
!inner.ReadASN1Integer(&r) ||
!inner.ReadASN1Integer(&s) ||
!inner.Empty() {
return nil, nil, errors.New("invalid ASN.1")
}
return r, s, nil
}
type nistCurve[Point nistPoint[Point]] struct {
newPoint func() Point
curve elliptic.Curve
N *bigmod.Modulus
nMinus2 []byte
}
// nistPoint is a generic constraint for the nistec Point types.
type nistPoint[T any] interface {
Bytes() []byte
BytesX() ([]byte, error)
SetBytes([]byte) (T, error)
Add(T, T) T
ScalarMult(T, []byte) (T, error)
ScalarBaseMult([]byte) (T, error)
}
// pointFromAffine is used to convert the PublicKey to a nistec Point.
func (curve *nistCurve[Point]) pointFromAffine(x, y *big.Int) (p Point, err error) {
bitSize := curve.curve.Params().BitSize
// Reject values that would not get correctly encoded.
if x.Sign() < 0 || y.Sign() < 0 {
return p, errors.New("negative coordinate")
}
if x.BitLen() > bitSize || y.BitLen() > bitSize {
return p, errors.New("overflowing coordinate")
}
// Encode the coordinates and let SetBytes reject invalid points.
byteLen := (bitSize + 7) / 8
buf := make([]byte, 1+2*byteLen)
buf[0] = 4 // uncompressed point
x.FillBytes(buf[1 : 1+byteLen])
y.FillBytes(buf[1+byteLen : 1+2*byteLen])
return curve.newPoint().SetBytes(buf)
}
// pointToAffine is used to convert a nistec Point to a PublicKey.
func (curve *nistCurve[Point]) pointToAffine(p Point) (x, y *big.Int, err error) {
out := p.Bytes()
if len(out) == 1 && out[0] == 0 {
// This is the encoding of the point at infinity.
return nil, nil, errors.New("ecdsa: public key point is the infinity")
}
byteLen := (curve.curve.Params().BitSize + 7) / 8
x = new(big.Int).SetBytes(out[1 : 1+byteLen])
y = new(big.Int).SetBytes(out[1+byteLen:])
return x, y, nil
}
var p224Once sync.Once
var _p224 *nistCurve[*nistec.P224Point]
func p224() *nistCurve[*nistec.P224Point] {
p224Once.Do(func() {
_p224 = &nistCurve[*nistec.P224Point]{
newPoint: func() *nistec.P224Point { return nistec.NewP224Point() },
}
precomputeParams(_p224, elliptic.P224())
})
return _p224
}
var p256Once sync.Once
var _p256 *nistCurve[*nistec.P256Point]
func p256() *nistCurve[*nistec.P256Point] {
p256Once.Do(func() {
_p256 = &nistCurve[*nistec.P256Point]{
newPoint: func() *nistec.P256Point { return nistec.NewP256Point() },
}
precomputeParams(_p256, elliptic.P256())
})
return _p256
}
var p384Once sync.Once
var _p384 *nistCurve[*nistec.P384Point]
func p384() *nistCurve[*nistec.P384Point] {
p384Once.Do(func() {
_p384 = &nistCurve[*nistec.P384Point]{
newPoint: func() *nistec.P384Point { return nistec.NewP384Point() },
}
precomputeParams(_p384, elliptic.P384())
})
return _p384
}
var p521Once sync.Once
var _p521 *nistCurve[*nistec.P521Point]
func p521() *nistCurve[*nistec.P521Point] {
p521Once.Do(func() {
_p521 = &nistCurve[*nistec.P521Point]{
newPoint: func() *nistec.P521Point { return nistec.NewP521Point() },
}
precomputeParams(_p521, elliptic.P521())
})
return _p521
}
func precomputeParams[Point nistPoint[Point]](c *nistCurve[Point], curve elliptic.Curve) {
params := curve.Params()
c.curve = curve
c.N = bigmod.NewModulusFromBig(params.N)
c.nMinus2 = new(big.Int).Sub(params.N, big.NewInt(2)).Bytes()
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ecdsa
import (
"crypto/elliptic"
"errors"
"io"
"math/big"
"golang.org/x/crypto/cryptobyte"
"golang.org/x/crypto/cryptobyte/asn1"
)
// This file contains a math/big implementation of ECDSA that is only used for
// deprecated custom curves.
func generateLegacy(c elliptic.Curve, rand io.Reader) (*PrivateKey, error) {
k, err := randFieldElement(c, rand)
if err != nil {
return nil, err
}
priv := new(PrivateKey)
priv.PublicKey.Curve = c
priv.D = k
priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes())
return priv, nil
}
// hashToInt converts a hash value to an integer. Per FIPS 186-4, Section 6.4,
// we use the left-most bits of the hash to match the bit-length of the order of
// the curve. This also performs Step 5 of SEC 1, Version 2.0, Section 4.1.3.
func hashToInt(hash []byte, c elliptic.Curve) *big.Int {
orderBits := c.Params().N.BitLen()
orderBytes := (orderBits + 7) / 8
if len(hash) > orderBytes {
hash = hash[:orderBytes]
}
ret := new(big.Int).SetBytes(hash)
excess := len(hash)*8 - orderBits
if excess > 0 {
ret.Rsh(ret, uint(excess))
}
return ret
}
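// For example, with P-224 (orderBits = 224) and a 32-byte SHA-256 digest:
// orderBytes = 28, so the hash is truncated to its first 28 bytes, and
// excess = 28*8 - 224 = 0, so no shift is needed. The right shift only
// triggers for curves whose order is not a whole number of bytes, such as
// P-521 given an input longer than 66 bytes.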
var errZeroParam = errors.New("zero parameter")
// Sign signs a hash (which should be the result of hashing a larger message)
// using the private key, priv. If the hash is longer than the bit-length of the
// private key's curve order, the hash will be truncated to that length. It
// returns the signature as a pair of integers. Most applications should use
// SignASN1 instead of dealing directly with r, s.
func Sign(rand io.Reader, priv *PrivateKey, hash []byte) (r, s *big.Int, err error) {
sig, err := SignASN1(rand, priv, hash)
if err != nil {
return nil, nil, err
}
r, s = new(big.Int), new(big.Int)
var inner cryptobyte.String
input := cryptobyte.String(sig)
if !input.ReadASN1(&inner, asn1.SEQUENCE) ||
!input.Empty() ||
!inner.ReadASN1Integer(r) ||
!inner.ReadASN1Integer(s) ||
!inner.Empty() {
return nil, nil, errors.New("invalid ASN.1 from SignASN1")
}
return r, s, nil
}
func signLegacy(priv *PrivateKey, csprng io.Reader, hash []byte) (sig []byte, err error) {
c := priv.Curve
// SEC 1, Version 2.0, Section 4.1.3
N := c.Params().N
if N.Sign() == 0 {
return nil, errZeroParam
}
var k, kInv, r, s *big.Int
for {
for {
k, err = randFieldElement(c, csprng)
if err != nil {
return nil, err
}
kInv = new(big.Int).ModInverse(k, N)
r, _ = c.ScalarBaseMult(k.Bytes())
r.Mod(r, N)
if r.Sign() != 0 {
break
}
}
e := hashToInt(hash, c)
s = new(big.Int).Mul(priv.D, r)
s.Add(s, e)
s.Mul(s, kInv)
s.Mod(s, N) // N != 0
if s.Sign() != 0 {
break
}
}
return encodeSignature(r.Bytes(), s.Bytes())
}
// Verify verifies the signature in r, s of hash using the public key, pub. Its
// return value records whether the signature is valid. Most applications should
// use VerifyASN1 instead of dealing directly with r, s.
func Verify(pub *PublicKey, hash []byte, r, s *big.Int) bool {
if r.Sign() <= 0 || s.Sign() <= 0 {
return false
}
sig, err := encodeSignature(r.Bytes(), s.Bytes())
if err != nil {
return false
}
return VerifyASN1(pub, hash, sig)
}
func verifyLegacy(pub *PublicKey, hash []byte, sig []byte) bool {
rBytes, sBytes, err := parseSignature(sig)
if err != nil {
return false
}
r, s := new(big.Int).SetBytes(rBytes), new(big.Int).SetBytes(sBytes)
c := pub.Curve
N := c.Params().N
if r.Sign() <= 0 || s.Sign() <= 0 {
return false
}
if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
return false
}
// SEC 1, Version 2.0, Section 4.1.4
e := hashToInt(hash, c)
w := new(big.Int).ModInverse(s, N)
u1 := e.Mul(e, w)
u1.Mod(u1, N)
u2 := w.Mul(r, w)
u2.Mod(u2, N)
x1, y1 := c.ScalarBaseMult(u1.Bytes())
x2, y2 := c.ScalarMult(pub.X, pub.Y, u2.Bytes())
x, y := c.Add(x1, y1, x2, y2)
if x.Sign() == 0 && y.Sign() == 0 {
return false
}
x.Mod(x, N)
return x.Cmp(r) == 0
}
var one = new(big.Int).SetInt64(1)
// randFieldElement returns a random element in [1, N-1], where N is the order
// of the given curve, using the procedure given in FIPS 186-4, Appendix B.5.2.
func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error) {
// See randomPoint for notes on the algorithm. This has to match, or s390x
// signatures will come out different from other architectures, which will
// break TLS recorded tests.
for {
N := c.Params().N
b := make([]byte, (N.BitLen()+7)/8)
if _, err = io.ReadFull(rand, b); err != nil {
return
}
if excess := len(b)*8 - N.BitLen(); excess > 0 {
b[0] >>= excess
}
k = new(big.Int).SetBytes(b)
if k.Sign() != 0 && k.Cmp(N) < 0 {
return
}
}
}
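// Note that for the NIST curves this loop terminates quickly: for P-256,
// N is just below 2^256 (2^256 - N is roughly 2^224), so a fresh 256-bit
// candidate falls outside [1, N-1] with probability only about 2^-32.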
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !s390x
package ecdsa
import "io"
func verifyAsm(pub *PublicKey, hash []byte, sig []byte) error {
return errNoAsm
}
func signAsm(priv *PrivateKey, csprng io.Reader, hash []byte) (sig []byte, err error) {
return nil, errNoAsm
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !boringcrypto
package ecdsa
import "crypto/internal/boring"
func boringPublicKey(*PublicKey) (*boring.PublicKeyECDSA, error) {
panic("boringcrypto: not available")
}
func boringPrivateKey(*PrivateKey) (*boring.PrivateKeyECDSA, error) {
panic("boringcrypto: not available")
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package ed25519 implements the Ed25519 signature algorithm. See
// https://ed25519.cr.yp.to/.
//
// These functions are also compatible with the “Ed25519” function defined in
// RFC 8032. However, unlike RFC 8032's formulation, this package's private key
// representation includes a public key suffix to make multiple signing
// operations with the same key more efficient. This package refers to the RFC
// 8032 private key as the “seed”.
package ed25519
import (
"bytes"
"crypto"
"crypto/internal/edwards25519"
cryptorand "crypto/rand"
"crypto/sha512"
"errors"
"io"
"strconv"
)
const (
// PublicKeySize is the size, in bytes, of public keys as used in this package.
PublicKeySize = 32
// PrivateKeySize is the size, in bytes, of private keys as used in this package.
PrivateKeySize = 64
// SignatureSize is the size, in bytes, of signatures generated and verified by this package.
SignatureSize = 64
// SeedSize is the size, in bytes, of private key seeds. These are the private key representations used by RFC 8032.
SeedSize = 32
)
// PublicKey is the type of Ed25519 public keys.
type PublicKey []byte
// Any methods implemented on PublicKey might need to also be implemented on
// PrivateKey, as the latter embeds the former and will expose its methods.
// Equal reports whether pub and x have the same value.
func (pub PublicKey) Equal(x crypto.PublicKey) bool {
xx, ok := x.(PublicKey)
if !ok {
return false
}
return bytes.Equal(pub, xx)
}
// PrivateKey is the type of Ed25519 private keys. It implements [crypto.Signer].
type PrivateKey []byte
// Public returns the [PublicKey] corresponding to priv.
func (priv PrivateKey) Public() crypto.PublicKey {
publicKey := make([]byte, PublicKeySize)
copy(publicKey, priv[32:])
return PublicKey(publicKey)
}
// Equal reports whether priv and x have the same value.
func (priv PrivateKey) Equal(x crypto.PrivateKey) bool {
xx, ok := x.(PrivateKey)
if !ok {
return false
}
return bytes.Equal(priv, xx)
}
// Seed returns the private key seed corresponding to priv. It is provided for
// interoperability with RFC 8032. RFC 8032's private keys correspond to seeds
// in this package.
func (priv PrivateKey) Seed() []byte {
return bytes.Clone(priv[:SeedSize])
}
// Sign signs the given message with priv. rand is ignored and can be nil.
//
// If opts.HashFunc() is [crypto.SHA512], the pre-hashed variant Ed25519ph is used
// and message is expected to be a SHA-512 hash, otherwise opts.HashFunc() must
// be [crypto.Hash](0) and the message must not be hashed, as Ed25519 performs two
// passes over messages to be signed.
//
// A value of type [Options] can be used as opts, or crypto.Hash(0) or
// crypto.SHA512 directly to select plain Ed25519 or Ed25519ph, respectively.
func (priv PrivateKey) Sign(rand io.Reader, message []byte, opts crypto.SignerOpts) (signature []byte, err error) {
hash := opts.HashFunc()
context := ""
if opts, ok := opts.(*Options); ok {
context = opts.Context
}
if l := len(context); l > 255 {
return nil, errors.New("ed25519: bad Ed25519ph context length: " + strconv.Itoa(l))
}
switch {
case hash == crypto.SHA512: // Ed25519ph
if l := len(message); l != sha512.Size {
return nil, errors.New("ed25519: bad Ed25519ph message hash length: " + strconv.Itoa(l))
}
signature := make([]byte, SignatureSize)
sign(signature, priv, message, domPrefixPh, context)
return signature, nil
case hash == crypto.Hash(0) && context != "": // Ed25519ctx
signature := make([]byte, SignatureSize)
sign(signature, priv, message, domPrefixCtx, context)
return signature, nil
case hash == crypto.Hash(0): // Ed25519
return Sign(priv, message), nil
default:
return nil, errors.New("ed25519: expected opts.HashFunc() zero (unhashed message, for standard Ed25519) or SHA-512 (for Ed25519ph)")
}
}
// Options can be used with [PrivateKey.Sign] or [VerifyWithOptions]
// to select Ed25519 variants.
type Options struct {
// Hash can be zero for regular Ed25519, or crypto.SHA512 for Ed25519ph.
Hash crypto.Hash
// Context, if not empty, selects Ed25519ctx or provides the context string
// for Ed25519ph. It can be at most 255 bytes in length.
Context string
}
// HashFunc returns o.Hash.
func (o *Options) HashFunc() crypto.Hash { return o.Hash }
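// A minimal Ed25519ph sketch (hypothetical names; assumes crypto and
// crypto/sha512; error handling elided):
//
//	digest := sha512.Sum512(message)
//	sig, _ := priv.Sign(nil, digest[:], &Options{Hash: crypto.SHA512})
//	err := VerifyWithOptions(pub, digest[:], sig, &Options{Hash: crypto.SHA512})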
// GenerateKey generates a public/private key pair using entropy from rand.
// If rand is nil, [crypto/rand.Reader] will be used.
func GenerateKey(rand io.Reader) (PublicKey, PrivateKey, error) {
if rand == nil {
rand = cryptorand.Reader
}
seed := make([]byte, SeedSize)
if _, err := io.ReadFull(rand, seed); err != nil {
return nil, nil, err
}
privateKey := NewKeyFromSeed(seed)
publicKey := make([]byte, PublicKeySize)
copy(publicKey, privateKey[32:])
return publicKey, privateKey, nil
}
// NewKeyFromSeed calculates a private key from a seed. It will panic if
// len(seed) is not [SeedSize]. This function is provided for interoperability
// with RFC 8032. RFC 8032's private keys correspond to seeds in this
// package.
func NewKeyFromSeed(seed []byte) PrivateKey {
// Outline the function body so that the returned key can be stack-allocated.
privateKey := make([]byte, PrivateKeySize)
newKeyFromSeed(privateKey, seed)
return privateKey
}
func newKeyFromSeed(privateKey, seed []byte) {
if l := len(seed); l != SeedSize {
panic("ed25519: bad seed length: " + strconv.Itoa(l))
}
h := sha512.Sum512(seed)
s, err := edwards25519.NewScalar().SetBytesWithClamping(h[:32])
if err != nil {
panic("ed25519: internal error: setting scalar failed")
}
A := (&edwards25519.Point{}).ScalarBaseMult(s)
publicKey := A.Bytes()
copy(privateKey, seed)
copy(privateKey[32:], publicKey)
}
// Sign signs the message with privateKey and returns a signature. It will
// panic if len(privateKey) is not [PrivateKeySize].
func Sign(privateKey PrivateKey, message []byte) []byte {
// Outline the function body so that the returned signature can be
// stack-allocated.
signature := make([]byte, SignatureSize)
sign(signature, privateKey, message, domPrefixPure, "")
return signature
}
// Domain separation prefixes used to disambiguate Ed25519/Ed25519ph/Ed25519ctx.
// See RFC 8032, Section 2 and Section 5.1.
const (
// domPrefixPure is empty for pure Ed25519.
domPrefixPure = ""
// domPrefixPh is dom2(phflag=1) for Ed25519ph. It must be followed by the
// uint8-length prefixed context.
domPrefixPh = "SigEd25519 no Ed25519 collisions\x01"
// domPrefixCtx is dom2(phflag=0) for Ed25519ctx. It must be followed by the
// uint8-length prefixed context.
domPrefixCtx = "SigEd25519 no Ed25519 collisions\x00"
)
func sign(signature, privateKey, message []byte, domPrefix, context string) {
if l := len(privateKey); l != PrivateKeySize {
panic("ed25519: bad private key length: " + strconv.Itoa(l))
}
seed, publicKey := privateKey[:SeedSize], privateKey[SeedSize:]
h := sha512.Sum512(seed)
s, err := edwards25519.NewScalar().SetBytesWithClamping(h[:32])
if err != nil {
panic("ed25519: internal error: setting scalar failed")
}
prefix := h[32:]
mh := sha512.New()
if domPrefix != domPrefixPure {
mh.Write([]byte(domPrefix))
mh.Write([]byte{byte(len(context))})
mh.Write([]byte(context))
}
mh.Write(prefix)
mh.Write(message)
messageDigest := make([]byte, 0, sha512.Size)
messageDigest = mh.Sum(messageDigest)
r, err := edwards25519.NewScalar().SetUniformBytes(messageDigest)
if err != nil {
panic("ed25519: internal error: setting scalar failed")
}
R := (&edwards25519.Point{}).ScalarBaseMult(r)
kh := sha512.New()
if domPrefix != domPrefixPure {
kh.Write([]byte(domPrefix))
kh.Write([]byte{byte(len(context))})
kh.Write([]byte(context))
}
kh.Write(R.Bytes())
kh.Write(publicKey)
kh.Write(message)
hramDigest := make([]byte, 0, sha512.Size)
hramDigest = kh.Sum(hramDigest)
k, err := edwards25519.NewScalar().SetUniformBytes(hramDigest)
if err != nil {
panic("ed25519: internal error: setting scalar failed")
}
S := edwards25519.NewScalar().MultiplyAdd(k, s, r)
copy(signature[:32], R.Bytes())
copy(signature[32:], S.Bytes())
}
// Verify reports whether sig is a valid signature of message by publicKey. It
// will panic if len(publicKey) is not [PublicKeySize].
func Verify(publicKey PublicKey, message, sig []byte) bool {
return verify(publicKey, message, sig, domPrefixPure, "")
}
// VerifyWithOptions verifies that sig is a valid signature of message by
// publicKey; a valid signature is indicated by a nil error. It will panic
// if len(publicKey) is not [PublicKeySize].
//
// If opts.Hash is [crypto.SHA512], the pre-hashed variant Ed25519ph is used and
// message is expected to be a SHA-512 hash, otherwise opts.Hash must be
// [crypto.Hash](0) and the message must not be hashed, as Ed25519 performs two
// passes over messages to be signed.
func VerifyWithOptions(publicKey PublicKey, message, sig []byte, opts *Options) error {
switch {
case opts.Hash == crypto.SHA512: // Ed25519ph
if l := len(message); l != sha512.Size {
return errors.New("ed25519: bad Ed25519ph message hash length: " + strconv.Itoa(l))
}
if l := len(opts.Context); l > 255 {
return errors.New("ed25519: bad Ed25519ph context length: " + strconv.Itoa(l))
}
if !verify(publicKey, message, sig, domPrefixPh, opts.Context) {
return errors.New("ed25519: invalid signature")
}
return nil
case opts.Hash == crypto.Hash(0) && opts.Context != "": // Ed25519ctx
if l := len(opts.Context); l > 255 {
return errors.New("ed25519: bad Ed25519ctx context length: " + strconv.Itoa(l))
}
if !verify(publicKey, message, sig, domPrefixCtx, opts.Context) {
return errors.New("ed25519: invalid signature")
}
return nil
case opts.Hash == crypto.Hash(0): // Ed25519
if !verify(publicKey, message, sig, domPrefixPure, "") {
return errors.New("ed25519: invalid signature")
}
return nil
default:
return errors.New("ed25519: expected opts.Hash zero (unhashed message, for standard Ed25519) or SHA-512 (for Ed25519ph)")
}
}
func verify(publicKey PublicKey, message, sig []byte, domPrefix, context string) bool {
if l := len(publicKey); l != PublicKeySize {
panic("ed25519: bad public key length: " + strconv.Itoa(l))
}
if len(sig) != SignatureSize || sig[63]&224 != 0 {
return false
}
A, err := (&edwards25519.Point{}).SetBytes(publicKey)
if err != nil {
return false
}
kh := sha512.New()
if domPrefix != domPrefixPure {
kh.Write([]byte(domPrefix))
kh.Write([]byte{byte(len(context))})
kh.Write([]byte(context))
}
kh.Write(sig[:32])
kh.Write(publicKey)
kh.Write(message)
hramDigest := make([]byte, 0, sha512.Size)
hramDigest = kh.Sum(hramDigest)
k, err := edwards25519.NewScalar().SetUniformBytes(hramDigest)
if err != nil {
panic("ed25519: internal error: setting scalar failed")
}
S, err := edwards25519.NewScalar().SetCanonicalBytes(sig[32:])
if err != nil {
return false
}
// [S]B = R + [k]A --> [k](-A) + [S]B = R
minusA := (&edwards25519.Point{}).Negate(A)
R := (&edwards25519.Point{}).VarTimeDoubleScalarBaseMult(k, minusA, S)
return bytes.Equal(sig[:32], R.Bytes())
}
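// A minimal plain Ed25519 round trip (hypothetical names; GenerateKey's nil
// reader selects crypto/rand.Reader):
//
//	pub, priv, _ := GenerateKey(nil)
//	sig := Sign(priv, []byte("hello"))
//	ok := Verify(pub, []byte("hello"), sig) // true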
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package elliptic implements the standard NIST P-224, P-256, P-384, and P-521
// elliptic curves over prime fields.
//
// The P224(), P256(), P384() and P521() values are necessary to use the crypto/ecdsa package.
// Most other uses should migrate to the more efficient and safer crypto/ecdh package.
package elliptic
import (
"io"
"math/big"
"sync"
)
// A Curve represents a short-form Weierstrass curve with a=-3.
//
// The behavior of Add, Double, and ScalarMult when the input is not a point on
// the curve is undefined.
//
// Note that the conventional point at infinity (0, 0) is not considered on the
// curve, although it can be returned by Add, Double, ScalarMult, or
// ScalarBaseMult (but not the Unmarshal or UnmarshalCompressed functions).
type Curve interface {
// Params returns the parameters for the curve.
Params() *CurveParams
// IsOnCurve reports whether the given (x,y) lies on the curve.
//
// Note: this is a low-level unsafe API. For ECDH, use the crypto/ecdh
// package. The NewPublicKey methods of NIST curves in crypto/ecdh accept
// the same encoding as the Unmarshal function, and perform on-curve checks.
IsOnCurve(x, y *big.Int) bool
// Add returns the sum of (x1,y1) and (x2,y2).
//
// Note: this is a low-level unsafe API.
Add(x1, y1, x2, y2 *big.Int) (x, y *big.Int)
// Double returns 2*(x,y).
//
// Note: this is a low-level unsafe API.
Double(x1, y1 *big.Int) (x, y *big.Int)
// ScalarMult returns k*(x,y) where k is an integer in big-endian form.
//
// Note: this is a low-level unsafe API. For ECDH, use the crypto/ecdh
// package. Most uses of ScalarMult can be replaced by a call to the ECDH
// methods of NIST curves in crypto/ecdh.
ScalarMult(x1, y1 *big.Int, k []byte) (x, y *big.Int)
// ScalarBaseMult returns k*G, where G is the base point of the group
// and k is an integer in big-endian form.
//
// Note: this is a low-level unsafe API. For ECDH, use the crypto/ecdh
// package. Most uses of ScalarBaseMult can be replaced by a call to the
// PrivateKey.PublicKey method in crypto/ecdh.
ScalarBaseMult(k []byte) (x, y *big.Int)
}
var mask = []byte{0xff, 0x1, 0x3, 0x7, 0xf, 0x1f, 0x3f, 0x7f}
// GenerateKey returns a public/private key pair. The private key is
// generated using the given reader, which must return random data.
//
// Note: for ECDH, use the GenerateKey methods of the crypto/ecdh package;
// for ECDSA, use the GenerateKey function of the crypto/ecdsa package.
func GenerateKey(curve Curve, rand io.Reader) (priv []byte, x, y *big.Int, err error) {
N := curve.Params().N
bitSize := N.BitLen()
byteLen := (bitSize + 7) / 8
priv = make([]byte, byteLen)
for x == nil {
_, err = io.ReadFull(rand, priv)
if err != nil {
return
}
// We have to mask off any excess bits in the case that the size of the
// underlying field is not a whole number of bytes.
priv[0] &= mask[bitSize%8]
// This is because, in tests, rand will return all zeros and we don't
// want to get the point at infinity and loop forever.
priv[1] ^= 0x42
// If the scalar is out of range, sample another random number.
if new(big.Int).SetBytes(priv).Cmp(N) >= 0 {
continue
}
x, y = curve.ScalarBaseMult(priv)
}
return
}
// Marshal converts a point on the curve into the uncompressed form specified in
// SEC 1, Version 2.0, Section 2.3.3. If the point is not on the curve (or is
// the conventional point at infinity), the behavior is undefined.
//
// Note: for ECDH, use the crypto/ecdh package. This function returns an
// encoding equivalent to that of PublicKey.Bytes in crypto/ecdh.
func Marshal(curve Curve, x, y *big.Int) []byte {
panicIfNotOnCurve(curve, x, y)
byteLen := (curve.Params().BitSize + 7) / 8
ret := make([]byte, 1+2*byteLen)
ret[0] = 4 // uncompressed point
x.FillBytes(ret[1 : 1+byteLen])
y.FillBytes(ret[1+byteLen : 1+2*byteLen])
return ret
}
// MarshalCompressed converts a point on the curve into the compressed form
// specified in SEC 1, Version 2.0, Section 2.3.3. If the point is not on the
// curve (or is the conventional point at infinity), the behavior is undefined.
func MarshalCompressed(curve Curve, x, y *big.Int) []byte {
panicIfNotOnCurve(curve, x, y)
byteLen := (curve.Params().BitSize + 7) / 8
compressed := make([]byte, 1+byteLen)
compressed[0] = byte(y.Bit(0)) | 2
x.FillBytes(compressed[1:])
return compressed
}
// unmarshaler is implemented by curves with their own constant-time Unmarshal.
//
// There isn't an equivalent interface for Marshal/MarshalCompressed because
// that doesn't involve any mathematical operations, only FillBytes and Bit.
type unmarshaler interface {
Unmarshal([]byte) (x, y *big.Int)
UnmarshalCompressed([]byte) (x, y *big.Int)
}
// Assert that the known curves implement unmarshaler.
var _ = []unmarshaler{p224, p256, p384, p521}
// Unmarshal converts a point, serialized by Marshal, into an x, y pair. It is
// an error if the point is not in uncompressed form, is not on the curve, or is
// the point at infinity. On error, x = nil.
//
// Note: for ECDH, use the crypto/ecdh package. This function accepts an
// encoding equivalent to that of the NewPublicKey methods in crypto/ecdh.
func Unmarshal(curve Curve, data []byte) (x, y *big.Int) {
if c, ok := curve.(unmarshaler); ok {
return c.Unmarshal(data)
}
byteLen := (curve.Params().BitSize + 7) / 8
if len(data) != 1+2*byteLen {
return nil, nil
}
if data[0] != 4 { // uncompressed form
return nil, nil
}
p := curve.Params().P
x = new(big.Int).SetBytes(data[1 : 1+byteLen])
y = new(big.Int).SetBytes(data[1+byteLen:])
if x.Cmp(p) >= 0 || y.Cmp(p) >= 0 {
return nil, nil
}
if !curve.IsOnCurve(x, y) {
return nil, nil
}
return
}
// UnmarshalCompressed converts a point, serialized by MarshalCompressed, into
// an x, y pair. It is an error if the point is not in compressed form, is not
// on the curve, or is the point at infinity. On error, x = nil.
func UnmarshalCompressed(curve Curve, data []byte) (x, y *big.Int) {
if c, ok := curve.(unmarshaler); ok {
return c.UnmarshalCompressed(data)
}
byteLen := (curve.Params().BitSize + 7) / 8
if len(data) != 1+byteLen {
return nil, nil
}
if data[0] != 2 && data[0] != 3 { // compressed form
return nil, nil
}
p := curve.Params().P
x = new(big.Int).SetBytes(data[1:])
if x.Cmp(p) >= 0 {
return nil, nil
}
// y² = x³ - 3x + b
y = curve.Params().polynomial(x)
y = y.ModSqrt(y, p)
if y == nil {
return nil, nil
}
if byte(y.Bit(0)) != data[0]&1 {
y.Neg(y).Mod(y, p)
}
if !curve.IsOnCurve(x, y) {
return nil, nil
}
return
}
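// A minimal round-trip sketch (hypothetical names; error handling elided;
// assumes crypto/rand). For P-256, Marshal produces 1+2*32 = 65 bytes with a
// leading 0x04:
//
//	_, x, y, _ := GenerateKey(P256(), rand.Reader)
//	buf := Marshal(P256(), x, y)
//	x2, y2 := Unmarshal(P256(), buf) // x2, y2 match x, y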
func panicIfNotOnCurve(curve Curve, x, y *big.Int) {
// (0, 0) is the point at infinity by convention. It's ok to operate on it,
// although IsOnCurve is documented to return false for it. See Issue 37294.
if x.Sign() == 0 && y.Sign() == 0 {
return
}
if !curve.IsOnCurve(x, y) {
panic("crypto/elliptic: attempted operation on invalid point")
}
}
var initonce sync.Once
func initAll() {
initP224()
initP256()
initP384()
initP521()
}
// P224 returns a Curve which implements NIST P-224 (FIPS 186-3, section D.2.2),
// also known as secp224r1. The CurveParams.Name of this Curve is "P-224".
//
// Multiple invocations of this function will return the same value, so it can
// be used for equality checks and switch statements.
//
// The cryptographic operations are implemented using constant-time algorithms.
func P224() Curve {
initonce.Do(initAll)
return p224
}
// P256 returns a Curve which implements NIST P-256 (FIPS 186-3, section D.2.3),
// also known as secp256r1 or prime256v1. The CurveParams.Name of this Curve is
// "P-256".
//
// Multiple invocations of this function will return the same value, so it can
// be used for equality checks and switch statements.
//
// The cryptographic operations are implemented using constant-time algorithms.
func P256() Curve {
initonce.Do(initAll)
return p256
}
// P384 returns a Curve which implements NIST P-384 (FIPS 186-3, section D.2.4),
// also known as secp384r1. The CurveParams.Name of this Curve is "P-384".
//
// Multiple invocations of this function will return the same value, so it can
// be used for equality checks and switch statements.
//
// The cryptographic operations are implemented using constant-time algorithms.
func P384() Curve {
initonce.Do(initAll)
return p384
}
// P521 returns a Curve which implements NIST P-521 (FIPS 186-3, section D.2.5),
// also known as secp521r1. The CurveParams.Name of this Curve is "P-521".
//
// Multiple invocations of this function will return the same value, so it can
// be used for equality checks and switch statements.
//
// The cryptographic operations are implemented using constant-time algorithms.
func P521() Curve {
initonce.Do(initAll)
return p521
}
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package elliptic
import (
"crypto/internal/nistec"
"errors"
"math/big"
)
var p224 = &nistCurve[*nistec.P224Point]{
newPoint: nistec.NewP224Point,
}
func initP224() {
p224.params = &CurveParams{
Name: "P-224",
BitSize: 224,
// FIPS 186-4, section D.1.2.2
P: bigFromDecimal("26959946667150639794667015087019630673557916260026308143510066298881"),
N: bigFromDecimal("26959946667150639794667015087019625940457807714424391721682722368061"),
B: bigFromHex("b4050a850c04b3abf54132565044b0b7d7bfd8ba270b39432355ffb4"),
Gx: bigFromHex("b70e0cbd6bb4bf7f321390b94a03c1d356c21122343280d6115c1d21"),
Gy: bigFromHex("bd376388b5f723fb4c22dfe6cd4375a05a07476444d5819985007e34"),
}
}
type p256Curve struct {
nistCurve[*nistec.P256Point]
}
var p256 = &p256Curve{nistCurve[*nistec.P256Point]{
newPoint: nistec.NewP256Point,
}}
func initP256() {
p256.params = &CurveParams{
Name: "P-256",
BitSize: 256,
// FIPS 186-4, section D.1.2.3
P: bigFromDecimal("115792089210356248762697446949407573530086143415290314195533631308867097853951"),
N: bigFromDecimal("115792089210356248762697446949407573529996955224135760342422259061068512044369"),
B: bigFromHex("5ac635d8aa3a93e7b3ebbd55769886bc651d06b0cc53b0f63bce3c3e27d2604b"),
Gx: bigFromHex("6b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c296"),
Gy: bigFromHex("4fe342e2fe1a7f9b8ee7eb4a7c0f9e162bce33576b315ececbb6406837bf51f5"),
}
}
var p384 = &nistCurve[*nistec.P384Point]{
newPoint: nistec.NewP384Point,
}
func initP384() {
p384.params = &CurveParams{
Name: "P-384",
BitSize: 384,
// FIPS 186-4, section D.1.2.4
P: bigFromDecimal("394020061963944792122790401001436138050797392704654" +
"46667948293404245721771496870329047266088258938001861606973112319"),
N: bigFromDecimal("394020061963944792122790401001436138050797392704654" +
"46667946905279627659399113263569398956308152294913554433653942643"),
B: bigFromHex("b3312fa7e23ee7e4988e056be3f82d19181d9c6efe8141120314088" +
"f5013875ac656398d8a2ed19d2a85c8edd3ec2aef"),
Gx: bigFromHex("aa87ca22be8b05378eb1c71ef320ad746e1d3b628ba79b9859f741" +
"e082542a385502f25dbf55296c3a545e3872760ab7"),
Gy: bigFromHex("3617de4a96262c6f5d9e98bf9292dc29f8f41dbd289a147ce9da31" +
"13b5f0b8c00a60b1ce1d7e819d7a431d7c90ea0e5f"),
}
}
var p521 = &nistCurve[*nistec.P521Point]{
newPoint: nistec.NewP521Point,
}
func initP521() {
p521.params = &CurveParams{
Name: "P-521",
BitSize: 521,
// FIPS 186-4, section D.1.2.5
P: bigFromDecimal("68647976601306097149819007990813932172694353001433" +
"0540939446345918554318339765605212255964066145455497729631139148" +
"0858037121987999716643812574028291115057151"),
N: bigFromDecimal("68647976601306097149819007990813932172694353001433" +
"0540939446345918554318339765539424505774633321719753296399637136" +
"3321113864768612440380340372808892707005449"),
B: bigFromHex("0051953eb9618e1c9a1f929a21a0b68540eea2da725b99b315f3b8" +
"b489918ef109e156193951ec7e937b1652c0bd3bb1bf073573df883d2c34f1ef" +
"451fd46b503f00"),
Gx: bigFromHex("00c6858e06b70404e9cd9e3ecb662395b4429c648139053fb521f8" +
"28af606b4d3dbaa14b5e77efe75928fe1dc127a2ffa8de3348b3c1856a429bf9" +
"7e7e31c2e5bd66"),
Gy: bigFromHex("011839296a789a3bc0045c8a5fb42c7d1bd998f54449579b446817" +
"afbd17273e662c97ee72995ef42640c550b9013fad0761353c7086a272c24088" +
"be94769fd16650"),
}
}
// nistCurve is a Curve implementation based on a nistec Point.
//
// It's a wrapper that exposes the big.Int-based Curve interface and encodes the
// legacy idiosyncrasies it requires, such as invalid and infinity point
// handling.
//
// To interact with the nistec package, points are encoded into and decoded from
// properly formatted byte slices. All big.Int use is limited to this package.
// Encoding and decoding is 1/1000th of the runtime of a scalar multiplication,
// so the overhead is acceptable.
type nistCurve[Point nistPoint[Point]] struct {
newPoint func() Point
params *CurveParams
}
// nistPoint is a generic constraint for the nistec Point types.
type nistPoint[T any] interface {
Bytes() []byte
SetBytes([]byte) (T, error)
Add(T, T) T
Double(T) T
ScalarMult(T, []byte) (T, error)
ScalarBaseMult([]byte) (T, error)
}
func (curve *nistCurve[Point]) Params() *CurveParams {
return curve.params
}
func (curve *nistCurve[Point]) IsOnCurve(x, y *big.Int) bool {
// IsOnCurve is documented to reject (0, 0), the conventional point at
// infinity, which however is accepted by pointFromAffine.
if x.Sign() == 0 && y.Sign() == 0 {
return false
}
_, err := curve.pointFromAffine(x, y)
return err == nil
}
func (curve *nistCurve[Point]) pointFromAffine(x, y *big.Int) (p Point, err error) {
// (0, 0) is by convention the point at infinity, which can't be represented
// in affine coordinates. See Issue 37294.
if x.Sign() == 0 && y.Sign() == 0 {
return curve.newPoint(), nil
}
// Reject values that would not get correctly encoded.
if x.Sign() < 0 || y.Sign() < 0 {
return p, errors.New("negative coordinate")
}
if x.BitLen() > curve.params.BitSize || y.BitLen() > curve.params.BitSize {
return p, errors.New("overflowing coordinate")
}
// Encode the coordinates and let SetBytes reject invalid points.
byteLen := (curve.params.BitSize + 7) / 8
buf := make([]byte, 1+2*byteLen)
buf[0] = 4 // uncompressed point
x.FillBytes(buf[1 : 1+byteLen])
y.FillBytes(buf[1+byteLen : 1+2*byteLen])
return curve.newPoint().SetBytes(buf)
}
func (curve *nistCurve[Point]) pointToAffine(p Point) (x, y *big.Int) {
out := p.Bytes()
if len(out) == 1 && out[0] == 0 {
// This is the encoding of the point at infinity, which the affine
// coordinates API represents as (0, 0) by convention.
return new(big.Int), new(big.Int)
}
byteLen := (curve.params.BitSize + 7) / 8
x = new(big.Int).SetBytes(out[1 : 1+byteLen])
y = new(big.Int).SetBytes(out[1+byteLen:])
return x, y
}
func (curve *nistCurve[Point]) Add(x1, y1, x2, y2 *big.Int) (*big.Int, *big.Int) {
p1, err := curve.pointFromAffine(x1, y1)
if err != nil {
panic("crypto/elliptic: Add was called on an invalid point")
}
p2, err := curve.pointFromAffine(x2, y2)
if err != nil {
panic("crypto/elliptic: Add was called on an invalid point")
}
return curve.pointToAffine(p1.Add(p1, p2))
}
func (curve *nistCurve[Point]) Double(x1, y1 *big.Int) (*big.Int, *big.Int) {
p, err := curve.pointFromAffine(x1, y1)
if err != nil {
panic("crypto/elliptic: Double was called on an invalid point")
}
return curve.pointToAffine(p.Double(p))
}
// normalizeScalar brings the scalar within the byte size of the order of the
// curve, as expected by the nistec scalar multiplication functions.
func (curve *nistCurve[Point]) normalizeScalar(scalar []byte) []byte {
byteSize := (curve.params.N.BitLen() + 7) / 8
if len(scalar) == byteSize {
return scalar
}
s := new(big.Int).SetBytes(scalar)
if len(scalar) > byteSize {
s.Mod(s, curve.params.N)
}
out := make([]byte, byteSize)
return s.FillBytes(out)
}
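// For example, P-521's order N is 521 bits, so byteSize is 66: a shorter
// scalar is left-padded with zeroes, while a longer one is first reduced
// modulo N and then encoded into exactly 66 bytes.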
func (curve *nistCurve[Point]) ScalarMult(Bx, By *big.Int, scalar []byte) (*big.Int, *big.Int) {
p, err := curve.pointFromAffine(Bx, By)
if err != nil {
panic("crypto/elliptic: ScalarMult was called on an invalid point")
}
scalar = curve.normalizeScalar(scalar)
p, err = p.ScalarMult(p, scalar)
if err != nil {
panic("crypto/elliptic: nistec rejected normalized scalar")
}
return curve.pointToAffine(p)
}
func (curve *nistCurve[Point]) ScalarBaseMult(scalar []byte) (*big.Int, *big.Int) {
scalar = curve.normalizeScalar(scalar)
p, err := curve.newPoint().ScalarBaseMult(scalar)
if err != nil {
panic("crypto/elliptic: nistec rejected normalized scalar")
}
return curve.pointToAffine(p)
}
// CombinedMult returns [s1]G + [s2]P where G is the generator. It's used
// through an interface upgrade in crypto/ecdsa.
func (curve *nistCurve[Point]) CombinedMult(Px, Py *big.Int, s1, s2 []byte) (x, y *big.Int) {
s1 = curve.normalizeScalar(s1)
q, err := curve.newPoint().ScalarBaseMult(s1)
if err != nil {
panic("crypto/elliptic: nistec rejected normalized scalar")
}
p, err := curve.pointFromAffine(Px, Py)
if err != nil {
panic("crypto/elliptic: CombinedMult was called on an invalid point")
}
s2 = curve.normalizeScalar(s2)
p, err = p.ScalarMult(p, s2)
if err != nil {
panic("crypto/elliptic: nistec rejected normalized scalar")
}
return curve.pointToAffine(p.Add(p, q))
}
func (curve *nistCurve[Point]) Unmarshal(data []byte) (x, y *big.Int) {
if len(data) == 0 || data[0] != 4 {
return nil, nil
}
// Use SetBytes to check that data encodes a valid point.
_, err := curve.newPoint().SetBytes(data)
if err != nil {
return nil, nil
}
// We don't use pointToAffine because it involves an expensive field
// inversion to convert from Jacobian to affine coordinates, which we
// already have.
byteLen := (curve.params.BitSize + 7) / 8
x = new(big.Int).SetBytes(data[1 : 1+byteLen])
y = new(big.Int).SetBytes(data[1+byteLen:])
return x, y
}
func (curve *nistCurve[Point]) UnmarshalCompressed(data []byte) (x, y *big.Int) {
if len(data) == 0 || (data[0] != 2 && data[0] != 3) {
return nil, nil
}
p, err := curve.newPoint().SetBytes(data)
if err != nil {
return nil, nil
}
return curve.pointToAffine(p)
}
func bigFromDecimal(s string) *big.Int {
b, ok := new(big.Int).SetString(s, 10)
if !ok {
panic("crypto/elliptic: internal error: invalid encoding")
}
return b
}
func bigFromHex(s string) *big.Int {
b, ok := new(big.Int).SetString(s, 16)
if !ok {
panic("crypto/elliptic: internal error: invalid encoding")
}
return b
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build amd64 || arm64
package elliptic
import (
"crypto/internal/nistec"
"math/big"
)
func (c p256Curve) Inverse(k *big.Int) *big.Int {
if k.Sign() < 0 {
// This should never happen.
k = new(big.Int).Neg(k)
}
if k.Cmp(c.params.N) >= 0 {
// This should never happen.
k = new(big.Int).Mod(k, c.params.N)
}
scalar := k.FillBytes(make([]byte, 32))
inverse, err := nistec.P256OrdInverse(scalar)
if err != nil {
panic("crypto/elliptic: nistec rejected normalized scalar")
}
return new(big.Int).SetBytes(inverse)
}
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package elliptic
import "math/big"
// CurveParams contains the parameters of an elliptic curve and also provides
// a generic, non-constant time implementation of Curve.
//
// Note: Custom curves (those not returned by P224(), P256(), P384(), and P521())
// are not guaranteed to provide any security property.
type CurveParams struct {
P *big.Int // the order of the underlying field
N *big.Int // the order of the base point
B *big.Int // the constant of the curve equation
Gx, Gy *big.Int // (x,y) of the base point
BitSize int // the size of the underlying field
Name string // the canonical name of the curve
}
func (curve *CurveParams) Params() *CurveParams {
return curve
}
// CurveParams operates, internally, on Jacobian coordinates. For a given
// (x, y) position on the curve, the Jacobian coordinates are (x1, y1, z1)
// where x = x1/z1² and y = y1/z1³. The greatest speedups come when the whole
// calculation can be performed within the transform (as in ScalarMult and
// ScalarBaseMult). But even for Add and Double, it's faster to apply and
// reverse the transform than to operate in affine coordinates.
// polynomial returns x³ - 3x + b.
func (curve *CurveParams) polynomial(x *big.Int) *big.Int {
x3 := new(big.Int).Mul(x, x)
x3.Mul(x3, x)
threeX := new(big.Int).Lsh(x, 1)
threeX.Add(threeX, x)
x3.Sub(x3, threeX)
x3.Add(x3, curve.B)
x3.Mod(x3, curve.P)
return x3
}
// IsOnCurve implements Curve.IsOnCurve.
//
// Note: the CurveParams methods are not guaranteed to
// provide any security property. For ECDH, use the crypto/ecdh package.
// For ECDSA, use the crypto/ecdsa package with a Curve value returned directly
// from P224(), P256(), P384(), or P521().
func (curve *CurveParams) IsOnCurve(x, y *big.Int) bool {
// If there is a dedicated constant-time implementation for this curve operation,
// use that instead of the generic one.
if specific, ok := matchesSpecificCurve(curve); ok {
return specific.IsOnCurve(x, y)
}
if x.Sign() < 0 || x.Cmp(curve.P) >= 0 ||
y.Sign() < 0 || y.Cmp(curve.P) >= 0 {
return false
}
// y² = x³ - 3x + b
y2 := new(big.Int).Mul(y, y)
y2.Mod(y2, curve.P)
return curve.polynomial(x).Cmp(y2) == 0
}
// zForAffine returns a Jacobian Z value for the affine point (x, y). If x and
// y are zero, it assumes that they represent the point at infinity because (0,
// 0) is not on any of the curves handled here.
func zForAffine(x, y *big.Int) *big.Int {
z := new(big.Int)
if x.Sign() != 0 || y.Sign() != 0 {
z.SetInt64(1)
}
return z
}
// affineFromJacobian reverses the Jacobian transform. See the comment at the
// top of the file. If the point is ∞ it returns 0, 0.
func (curve *CurveParams) affineFromJacobian(x, y, z *big.Int) (xOut, yOut *big.Int) {
if z.Sign() == 0 {
return new(big.Int), new(big.Int)
}
zinv := new(big.Int).ModInverse(z, curve.P)
zinvsq := new(big.Int).Mul(zinv, zinv)
xOut = new(big.Int).Mul(x, zinvsq)
xOut.Mod(xOut, curve.P)
zinvsq.Mul(zinvsq, zinv)
yOut = new(big.Int).Mul(y, zinvsq)
yOut.Mod(yOut, curve.P)
return
}
// Add implements Curve.Add.
//
// Note: the CurveParams methods are not guaranteed to
// provide any security property. For ECDH, use the crypto/ecdh package.
// For ECDSA, use the crypto/ecdsa package with a Curve value returned directly
// from P224(), P256(), P384(), or P521().
func (curve *CurveParams) Add(x1, y1, x2, y2 *big.Int) (*big.Int, *big.Int) {
// If there is a dedicated constant-time implementation for this curve operation,
// use that instead of the generic one.
if specific, ok := matchesSpecificCurve(curve); ok {
return specific.Add(x1, y1, x2, y2)
}
panicIfNotOnCurve(curve, x1, y1)
panicIfNotOnCurve(curve, x2, y2)
z1 := zForAffine(x1, y1)
z2 := zForAffine(x2, y2)
return curve.affineFromJacobian(curve.addJacobian(x1, y1, z1, x2, y2, z2))
}
// addJacobian takes two points in Jacobian coordinates, (x1, y1, z1) and
// (x2, y2, z2) and returns their sum, also in Jacobian form.
func (curve *CurveParams) addJacobian(x1, y1, z1, x2, y2, z2 *big.Int) (*big.Int, *big.Int, *big.Int) {
// See https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-3.html#addition-add-2007-bl
x3, y3, z3 := new(big.Int), new(big.Int), new(big.Int)
if z1.Sign() == 0 {
x3.Set(x2)
y3.Set(y2)
z3.Set(z2)
return x3, y3, z3
}
if z2.Sign() == 0 {
x3.Set(x1)
y3.Set(y1)
z3.Set(z1)
return x3, y3, z3
}
z1z1 := new(big.Int).Mul(z1, z1)
z1z1.Mod(z1z1, curve.P)
z2z2 := new(big.Int).Mul(z2, z2)
z2z2.Mod(z2z2, curve.P)
u1 := new(big.Int).Mul(x1, z2z2)
u1.Mod(u1, curve.P)
u2 := new(big.Int).Mul(x2, z1z1)
u2.Mod(u2, curve.P)
h := new(big.Int).Sub(u2, u1)
xEqual := h.Sign() == 0
if h.Sign() == -1 {
h.Add(h, curve.P)
}
i := new(big.Int).Lsh(h, 1)
i.Mul(i, i)
j := new(big.Int).Mul(h, i)
s1 := new(big.Int).Mul(y1, z2)
s1.Mul(s1, z2z2)
s1.Mod(s1, curve.P)
s2 := new(big.Int).Mul(y2, z1)
s2.Mul(s2, z1z1)
s2.Mod(s2, curve.P)
r := new(big.Int).Sub(s2, s1)
if r.Sign() == -1 {
r.Add(r, curve.P)
}
yEqual := r.Sign() == 0
if xEqual && yEqual {
return curve.doubleJacobian(x1, y1, z1)
}
r.Lsh(r, 1)
v := new(big.Int).Mul(u1, i)
x3.Set(r)
x3.Mul(x3, x3)
x3.Sub(x3, j)
x3.Sub(x3, v)
x3.Sub(x3, v)
x3.Mod(x3, curve.P)
y3.Set(r)
v.Sub(v, x3)
y3.Mul(y3, v)
s1.Mul(s1, j)
s1.Lsh(s1, 1)
y3.Sub(y3, s1)
y3.Mod(y3, curve.P)
z3.Add(z1, z2)
z3.Mul(z3, z3)
z3.Sub(z3, z1z1)
z3.Sub(z3, z2z2)
z3.Mul(z3, h)
z3.Mod(z3, curve.P)
return x3, y3, z3
}
// Double implements Curve.Double.
//
// Note: the CurveParams methods are not guaranteed to
// provide any security property. For ECDH, use the crypto/ecdh package.
// For ECDSA, use the crypto/ecdsa package with a Curve value returned directly
// from P224(), P256(), P384(), or P521().
func (curve *CurveParams) Double(x1, y1 *big.Int) (*big.Int, *big.Int) {
// If there is a dedicated constant-time implementation for this curve operation,
// use that instead of the generic one.
if specific, ok := matchesSpecificCurve(curve); ok {
return specific.Double(x1, y1)
}
panicIfNotOnCurve(curve, x1, y1)
z1 := zForAffine(x1, y1)
return curve.affineFromJacobian(curve.doubleJacobian(x1, y1, z1))
}
// doubleJacobian takes a point in Jacobian coordinates, (x, y, z), and
// returns its double, also in Jacobian form.
func (curve *CurveParams) doubleJacobian(x, y, z *big.Int) (*big.Int, *big.Int, *big.Int) {
// See https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-3.html#doubling-dbl-2001-b
delta := new(big.Int).Mul(z, z)
delta.Mod(delta, curve.P)
gamma := new(big.Int).Mul(y, y)
gamma.Mod(gamma, curve.P)
alpha := new(big.Int).Sub(x, delta)
if alpha.Sign() == -1 {
alpha.Add(alpha, curve.P)
}
alpha2 := new(big.Int).Add(x, delta)
alpha.Mul(alpha, alpha2)
alpha2.Set(alpha)
alpha.Lsh(alpha, 1)
alpha.Add(alpha, alpha2)
beta := alpha2.Mul(x, gamma)
x3 := new(big.Int).Mul(alpha, alpha)
beta8 := new(big.Int).Lsh(beta, 3)
beta8.Mod(beta8, curve.P)
x3.Sub(x3, beta8)
if x3.Sign() == -1 {
x3.Add(x3, curve.P)
}
x3.Mod(x3, curve.P)
z3 := new(big.Int).Add(y, z)
z3.Mul(z3, z3)
z3.Sub(z3, gamma)
if z3.Sign() == -1 {
z3.Add(z3, curve.P)
}
z3.Sub(z3, delta)
if z3.Sign() == -1 {
z3.Add(z3, curve.P)
}
z3.Mod(z3, curve.P)
beta.Lsh(beta, 2)
beta.Sub(beta, x3)
if beta.Sign() == -1 {
beta.Add(beta, curve.P)
}
y3 := alpha.Mul(alpha, beta)
gamma.Mul(gamma, gamma)
gamma.Lsh(gamma, 3)
gamma.Mod(gamma, curve.P)
y3.Sub(y3, gamma)
if y3.Sign() == -1 {
y3.Add(y3, curve.P)
}
y3.Mod(y3, curve.P)
return x3, y3, z3
}
// ScalarMult implements Curve.ScalarMult.
//
// Note: the CurveParams methods are not guaranteed to
// provide any security property. For ECDH, use the crypto/ecdh package.
// For ECDSA, use the crypto/ecdsa package with a Curve value returned directly
// from P224(), P256(), P384(), or P521().
func (curve *CurveParams) ScalarMult(Bx, By *big.Int, k []byte) (*big.Int, *big.Int) {
// If there is a dedicated constant-time implementation for this curve operation,
// use that instead of the generic one.
if specific, ok := matchesSpecificCurve(curve); ok {
return specific.ScalarMult(Bx, By, k)
}
panicIfNotOnCurve(curve, Bx, By)
Bz := new(big.Int).SetInt64(1)
x, y, z := new(big.Int), new(big.Int), new(big.Int)
for _, byte := range k {
for bitNum := 0; bitNum < 8; bitNum++ {
x, y, z = curve.doubleJacobian(x, y, z)
if byte&0x80 == 0x80 {
x, y, z = curve.addJacobian(Bx, By, Bz, x, y, z)
}
byte <<= 1
}
}
return curve.affineFromJacobian(x, y, z)
}
// ScalarBaseMult implements Curve.ScalarBaseMult.
//
// Note: the CurveParams methods are not guaranteed to
// provide any security property. For ECDH, use the crypto/ecdh package.
// For ECDSA, use the crypto/ecdsa package with a Curve value returned directly
// from P224(), P256(), P384(), or P521().
func (curve *CurveParams) ScalarBaseMult(k []byte) (*big.Int, *big.Int) {
// If there is a dedicated constant-time implementation for this curve operation,
// use that instead of the generic one.
if specific, ok := matchesSpecificCurve(curve); ok {
return specific.ScalarBaseMult(k)
}
return curve.ScalarMult(curve.Gx, curve.Gy, k)
}
func matchesSpecificCurve(params *CurveParams) (Curve, bool) {
for _, c := range []Curve{p224, p256, p384, p521} {
if params == c.Params() {
return c, true
}
}
return nil, false
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package hmac implements the Keyed-Hash Message Authentication Code (HMAC) as
defined in U.S. Federal Information Processing Standards Publication 198.
An HMAC is a cryptographic hash that uses a key to sign a message.
The receiver verifies the hash by recomputing it using the same key.
Receivers should be careful to use Equal to compare MACs in order to avoid
timing side-channels:
// ValidMAC reports whether messageMAC is a valid HMAC tag for message.
func ValidMAC(message, messageMAC, key []byte) bool {
mac := hmac.New(sha256.New, key)
mac.Write(message)
expectedMAC := mac.Sum(nil)
return hmac.Equal(messageMAC, expectedMAC)
}
*/
package hmac
import (
"crypto/internal/boring"
"crypto/subtle"
"hash"
)
// FIPS 198-1:
// https://csrc.nist.gov/publications/fips/fips198-1/FIPS-198-1_final.pdf
// key is zero padded to the block size of the hash function
// ipad = 0x36 byte repeated for key length
// opad = 0x5c byte repeated for key length
// hmac = H([key ^ opad] H([key ^ ipad] text))
// marshalable is the combination of encoding.BinaryMarshaler and
// encoding.BinaryUnmarshaler. Their method definitions are repeated here to
// avoid a dependency on the encoding package.
type marshalable interface {
MarshalBinary() ([]byte, error)
UnmarshalBinary([]byte) error
}
type hmac struct {
opad, ipad []byte
outer, inner hash.Hash
// If marshaled is true, then opad and ipad do not contain a padded
// copy of the key, but rather the marshaled state of outer/inner after
// opad/ipad has been fed into it.
marshaled bool
}
func (h *hmac) Sum(in []byte) []byte {
origLen := len(in)
in = h.inner.Sum(in)
if h.marshaled {
if err := h.outer.(marshalable).UnmarshalBinary(h.opad); err != nil {
panic(err)
}
} else {
h.outer.Reset()
h.outer.Write(h.opad)
}
h.outer.Write(in[origLen:])
return h.outer.Sum(in[:origLen])
}
func (h *hmac) Write(p []byte) (n int, err error) {
return h.inner.Write(p)
}
func (h *hmac) Size() int { return h.outer.Size() }
func (h *hmac) BlockSize() int { return h.inner.BlockSize() }
func (h *hmac) Reset() {
if h.marshaled {
if err := h.inner.(marshalable).UnmarshalBinary(h.ipad); err != nil {
panic(err)
}
return
}
h.inner.Reset()
h.inner.Write(h.ipad)
// If the underlying hash is marshalable, we can save some time by
// saving a copy of the hash state now, and restoring it on future
// calls to Reset and Sum instead of writing ipad/opad every time.
//
// If either hash is unmarshalable for whatever reason,
// it's safe to bail out here.
marshalableInner, innerOK := h.inner.(marshalable)
if !innerOK {
return
}
marshalableOuter, outerOK := h.outer.(marshalable)
if !outerOK {
return
}
imarshal, err := marshalableInner.MarshalBinary()
if err != nil {
return
}
h.outer.Reset()
h.outer.Write(h.opad)
omarshal, err := marshalableOuter.MarshalBinary()
if err != nil {
return
}
// Marshaling succeeded; save the marshaled state for later
h.ipad = imarshal
h.opad = omarshal
h.marshaled = true
}
// New returns a new HMAC hash using the given hash.Hash type and key.
// New functions like sha256.New from crypto/sha256 can be used as h.
// h must return a new Hash every time it is called.
// Note that unlike other hash implementations in the standard library,
// the returned Hash does not implement encoding.BinaryMarshaler
// or encoding.BinaryUnmarshaler.
func New(h func() hash.Hash, key []byte) hash.Hash {
if boring.Enabled {
hm := boring.NewHMAC(h, key)
if hm != nil {
return hm
}
// BoringCrypto did not recognize h, so fall through to standard Go code.
}
hm := new(hmac)
hm.outer = h()
hm.inner = h()
unique := true
func() {
defer func() {
// The comparison might panic if the underlying types are not comparable.
_ = recover()
}()
if hm.outer == hm.inner {
unique = false
}
}()
if !unique {
panic("crypto/hmac: hash generation function does not produce unique values")
}
blocksize := hm.inner.BlockSize()
hm.ipad = make([]byte, blocksize)
hm.opad = make([]byte, blocksize)
if len(key) > blocksize {
// If key is too big, hash it.
hm.outer.Write(key)
key = hm.outer.Sum(nil)
}
copy(hm.ipad, key)
copy(hm.opad, key)
for i := range hm.ipad {
hm.ipad[i] ^= 0x36
}
for i := range hm.opad {
hm.opad[i] ^= 0x5c
}
hm.inner.Write(hm.ipad)
return hm
}
// Equal compares two MACs for equality without leaking timing information.
func Equal(mac1, mac2 []byte) bool {
// We don't have to be constant time if the lengths of the MACs are
// different as that suggests that a completely different hash function
// was used.
return subtle.ConstantTimeCompare(mac1, mac2) == 1
}
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package alias implements memory aliasing tests.
// This code also exists as golang.org/x/crypto/internal/alias.
package alias
import "unsafe"
// AnyOverlap reports whether x and y share memory at any (not necessarily
// corresponding) index. The memory beyond the slice length is ignored.
func AnyOverlap(x, y []byte) bool {
return len(x) > 0 && len(y) > 0 &&
uintptr(unsafe.Pointer(&x[0])) <= uintptr(unsafe.Pointer(&y[len(y)-1])) &&
uintptr(unsafe.Pointer(&y[0])) <= uintptr(unsafe.Pointer(&x[len(x)-1]))
}
// InexactOverlap reports whether x and y share memory at any non-corresponding
// index. The memory beyond the slice length is ignored. Note that x and y can
// have different lengths and still not have any inexact overlap.
//
// InexactOverlap can be used to implement the requirements of the crypto/cipher
// AEAD, Block, BlockMode and Stream interfaces.
func InexactOverlap(x, y []byte) bool {
if len(x) == 0 || len(y) == 0 || &x[0] == &y[0] {
return false
}
return AnyOverlap(x, y)
}
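// For example (illustrative):
//
//	buf := make([]byte, 16)
//	AnyOverlap(buf[:8], buf[4:12])     // true: the slices share cells
//	InexactOverlap(buf[:8], buf[:8])   // false: exact overlap is permitted
//	InexactOverlap(buf[:8], buf[4:12]) // true: shifted overlap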
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package bigmod
import (
"errors"
"math/big"
"math/bits"
)
const (
// _W is the number of bits we use for our limbs.
_W = bits.UintSize - 1
// _MASK selects _W bits from a full machine word.
_MASK = (1 << _W) - 1
)
// choice represents a constant-time boolean. The value of choice is always
// either 1 or 0. We use an int instead of bool in order to make decisions in
// constant time by turning it into a mask.
type choice uint
func not(c choice) choice { return 1 ^ c }
const yes = choice(1)
const no = choice(0)
// ctSelect returns x if on == 1, and y if on == 0. The execution time of this
// function does not depend on its inputs. If on is any value besides 1 or 0,
// the result is undefined.
func ctSelect(on choice, x, y uint) uint {
// When on == 1, mask is 0b111..., otherwise mask is 0b000...
mask := -uint(on)
// When mask is all zeros, we just have y, otherwise, y cancels with itself.
return y ^ (mask & (y ^ x))
}
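// For example, ctSelect(yes, 7, 9) == 7 and ctSelect(no, 7, 9) == 9; the
// result is computed with a mask and xor rather than a branch, so the cost
// is identical either way.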
// ctEq returns 1 if x == y, and 0 otherwise. The execution time of this
// function does not depend on its inputs.
func ctEq(x, y uint) choice {
// If x != y, then either x - y or y - x will generate a carry.
_, c1 := bits.Sub(x, y, 0)
_, c2 := bits.Sub(y, x, 0)
return not(choice(c1 | c2))
}
// ctGeq returns 1 if x >= y, and 0 otherwise. The execution time of this
// function does not depend on its inputs.
func ctGeq(x, y uint) choice {
// If x < y, then x - y generates a carry.
_, carry := bits.Sub(x, y, 0)
return not(choice(carry))
}
// Nat represents an arbitrary natural number.
//
// Each Nat has an announced length, which is the number of limbs it has stored.
// Operations on this number are allowed to leak this length, but will not leak
// any information about the values contained in those limbs.
type Nat struct {
// limbs is a little-endian representation in base 2^W with
// W = bits.UintSize - 1. The top bit is always unset between operations.
//
// The top bit is left unset to optimize Montgomery multiplication, in the
// inner loop of exponentiation. Using fully saturated limbs would leave us
// working with 129-bit numbers on 64-bit platforms, wasting a lot of space,
// and thus time.
limbs []uint
}
// preallocTarget is the size in bits of the numbers used to implement the most
// common and most performant RSA key size. It's also enough to cover some of
// the operations of key sizes up to 4096.
const preallocTarget = 2048
const preallocLimbs = (preallocTarget + _W - 1) / _W
// NewNat returns a new nat with a size of zero, just like new(Nat), but with
// the preallocated capacity to hold a number of up to preallocTarget bits.
// NewNat inlines, so the allocation can live on the stack.
func NewNat() *Nat {
limbs := make([]uint, 0, preallocLimbs)
return &Nat{limbs}
}
// expand expands x to n limbs, leaving its value unchanged.
func (x *Nat) expand(n int) *Nat {
if len(x.limbs) > n {
panic("bigmod: internal error: shrinking nat")
}
if cap(x.limbs) < n {
newLimbs := make([]uint, n)
copy(newLimbs, x.limbs)
x.limbs = newLimbs
return x
}
extraLimbs := x.limbs[len(x.limbs):n]
for i := range extraLimbs {
extraLimbs[i] = 0
}
x.limbs = x.limbs[:n]
return x
}
// reset returns a zero nat of n limbs, reusing x's storage if n <= cap(x.limbs).
func (x *Nat) reset(n int) *Nat {
if cap(x.limbs) < n {
x.limbs = make([]uint, n)
return x
}
for i := range x.limbs {
x.limbs[i] = 0
}
x.limbs = x.limbs[:n]
return x
}
// set assigns x = y, optionally resizing x to the appropriate size.
func (x *Nat) set(y *Nat) *Nat {
x.reset(len(y.limbs))
copy(x.limbs, y.limbs)
return x
}
// setBig assigns x = n, optionally resizing x to the appropriate size.
//
// The announced length of x is set based on the actual bit size of the input,
// ignoring leading zeroes.
func (x *Nat) setBig(n *big.Int) *Nat {
requiredLimbs := (n.BitLen() + _W - 1) / _W
x.reset(requiredLimbs)
outI := 0
shift := 0
limbs := n.Bits()
for i := range limbs {
xi := uint(limbs[i])
x.limbs[outI] |= (xi << shift) & _MASK
outI++
if outI == requiredLimbs {
return x
}
x.limbs[outI] = xi >> (_W - shift)
shift++ // this assumes bits.UintSize - _W = 1
if shift == _W {
shift = 0
outI++
}
}
return x
}
// Bytes returns x as a zero-extended big-endian byte slice. The size of the
// slice will match the size of m.
//
// x must have the same size as m and it must be reduced modulo m.
func (x *Nat) Bytes(m *Modulus) []byte {
bytes := make([]byte, m.Size())
shift := 0
outI := len(bytes) - 1
for _, limb := range x.limbs {
remainingBits := _W
for remainingBits >= 8 {
bytes[outI] |= byte(limb) << shift
consumed := 8 - shift
limb >>= consumed
remainingBits -= consumed
shift = 0
outI--
if outI < 0 {
return bytes
}
}
bytes[outI] = byte(limb)
shift = remainingBits
}
return bytes
}
// SetBytes assigns x = b, where b is a slice of big-endian bytes.
// SetBytes returns an error if b >= m.
//
// The output will be resized to the size of m and overwritten.
func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) {
if err := x.setBytes(b, m); err != nil {
return nil, err
}
if x.cmpGeq(m.nat) == yes {
return nil, errors.New("input overflows the modulus")
}
return x, nil
}
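// Editor's sketch (not part of the original source, would live in this
// package, e.g. in a _test.go file, since bigmod is internal): a round trip
// through SetBytes and Bytes with a small odd modulus.
func exampleRoundTrip() {
	m := NewModulusFromBig(big.NewInt(997))            // odd modulus, 10 bits
	x, err := NewNat().SetBytes([]byte{0x01, 0x2c}, m) // 300 < 997
	if err != nil {
		panic(err)
	}
	if got := x.Bytes(m); got[0] != 0x01 || got[1] != 0x2c {
		panic("round trip failed") // Bytes zero-pads to m.Size() = 2
	}
}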
// SetOverflowingBytes assigns x = b, where b is a slice of big-endian bytes. SetOverflowingBytes
// returns an error if b has a longer bit length than m, but reduces overflowing
// values up to 2^⌈log2(m)⌉ - 1.
//
// The output will be resized to the size of m and overwritten.
func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) {
if err := x.setBytes(b, m); err != nil {
return nil, err
}
leading := _W - bitLen(x.limbs[len(x.limbs)-1])
if leading < m.leading {
return nil, errors.New("input overflows the modulus")
}
x.sub(x.cmpGeq(m.nat), m.nat)
return x, nil
}
func (x *Nat) setBytes(b []byte, m *Modulus) error {
outI := 0
shift := 0
x.resetFor(m)
for i := len(b) - 1; i >= 0; i-- {
bi := b[i]
x.limbs[outI] |= uint(bi) << shift
shift += 8
if shift >= _W {
shift -= _W
x.limbs[outI] &= _MASK
overflow := bi >> (8 - shift)
outI++
if outI >= len(x.limbs) {
if overflow > 0 || i > 0 {
return errors.New("input overflows the modulus")
}
break
}
x.limbs[outI] = uint(overflow)
}
}
return nil
}
// Equal returns 1 if x == y, and 0 otherwise.
//
// Both operands must have the same announced length.
func (x *Nat) Equal(y *Nat) choice {
// Eliminate bounds checks in the loop.
size := len(x.limbs)
xLimbs := x.limbs[:size]
yLimbs := y.limbs[:size]
equal := yes
for i := 0; i < size; i++ {
equal &= ctEq(xLimbs[i], yLimbs[i])
}
return equal
}
// IsZero returns 1 if x == 0, and 0 otherwise.
func (x *Nat) IsZero() choice {
// Eliminate bounds checks in the loop.
size := len(x.limbs)
xLimbs := x.limbs[:size]
zero := yes
for i := 0; i < size; i++ {
zero &= ctEq(xLimbs[i], 0)
}
return zero
}
// cmpGeq returns 1 if x >= y, and 0 otherwise.
//
// Both operands must have the same announced length.
func (x *Nat) cmpGeq(y *Nat) choice {
// Eliminate bounds checks in the loop.
size := len(x.limbs)
xLimbs := x.limbs[:size]
yLimbs := y.limbs[:size]
var c uint
for i := 0; i < size; i++ {
c = (xLimbs[i] - yLimbs[i] - c) >> _W
}
// If there was a carry, then subtracting y underflowed, so
// x is not greater than or equal to y.
return not(choice(c))
}
// assign sets x <- y if on == 1, and does nothing otherwise.
//
// Both operands must have the same announced length.
func (x *Nat) assign(on choice, y *Nat) *Nat {
// Eliminate bounds checks in the loop.
size := len(x.limbs)
xLimbs := x.limbs[:size]
yLimbs := y.limbs[:size]
for i := 0; i < size; i++ {
xLimbs[i] = ctSelect(on, yLimbs[i], xLimbs[i])
}
return x
}
// add computes x += y if on == 1, and does nothing otherwise. It returns the
// carry of the addition regardless of on.
//
// Both operands must have the same announced length.
func (x *Nat) add(on choice, y *Nat) (c uint) {
// Eliminate bounds checks in the loop.
size := len(x.limbs)
xLimbs := x.limbs[:size]
yLimbs := y.limbs[:size]
for i := 0; i < size; i++ {
res := xLimbs[i] + yLimbs[i] + c
xLimbs[i] = ctSelect(on, res&_MASK, xLimbs[i])
c = res >> _W
}
return
}
// sub computes x -= y if on == 1, and does nothing otherwise. It returns the
// borrow of the subtraction regardless of on.
//
// Both operands must have the same announced length.
func (x *Nat) sub(on choice, y *Nat) (c uint) {
// Eliminate bounds checks in the loop.
size := len(x.limbs)
xLimbs := x.limbs[:size]
yLimbs := y.limbs[:size]
for i := 0; i < size; i++ {
res := xLimbs[i] - yLimbs[i] - c
xLimbs[i] = ctSelect(on, res&_MASK, xLimbs[i])
c = res >> _W
}
return
}
// Modulus is used for modular arithmetic, precomputing relevant constants.
//
// Moduli are assumed to be odd numbers. Moduli can also leak the exact
// number of bits needed to store their value, and are stored without padding.
//
// Their actual value is still kept secret.
type Modulus struct {
// The underlying natural number for this modulus.
//
// This will be stored without any padding, and shouldn't alias with any
// other natural number being used.
nat *Nat
leading int // number of leading zeros in the modulus
m0inv uint // -nat.limbs[0]⁻¹ mod 2^_W
rr *Nat // R*R for montgomeryRepresentation
}
// rr returns R*R with R = 2^(_W * n) and n = len(m.nat.limbs).
func rr(m *Modulus) *Nat {
rr := NewNat().ExpandFor(m)
// R*R is 2^(2 * _W * n). We can safely get 2^(_W * (n - 1)) by setting the
// most significant limb to 1. We then get to R*R by shifting left by _W
// n + 1 times.
n := len(rr.limbs)
rr.limbs[n-1] = 1
for i := n - 1; i < 2*n; i++ {
rr.shiftIn(0, m) // x = x * 2^_W mod m
}
return rr
}
// minusInverseModW computes -x⁻¹ mod 2^_W with x odd.
//
// This operation is used to precompute a constant involved in Montgomery
// multiplication.
func minusInverseModW(x uint) uint {
// Every iteration of this loop doubles the least-significant bits of
// correct inverse in y. The first three bits are already correct (1⁻¹ = 1,
// 3⁻¹ = 3, 5⁻¹ = 5, and 7⁻¹ = 7 mod 8), so doubling five times is enough
// for 61 bits (and wastes only one iteration for 31 bits).
//
// See https://crypto.stackexchange.com/a/47496.
y := x
for i := 0; i < 5; i++ {
y = y * (2 - x*y)
}
return (1 << _W) - (y & _MASK)
}
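// Editor's sketch (not part of the original source): checking the Newton
// iteration on a concrete odd value. Five doublings take the three correct
// low bits of y past _W bits, so x * minusInverseModW(x) ≡ -1 ≡ _MASK
// mod 2^_W.
func exampleMinusInverse() {
	x := uint(0x12345) // any odd value
	if x*minusInverseModW(x)&_MASK != _MASK {
		panic("inverse check failed")
	}
}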
// NewModulusFromBig creates a new Modulus from a [big.Int].
//
// The Int must be odd. The number of significant bits must be leakable.
func NewModulusFromBig(n *big.Int) *Modulus {
m := &Modulus{}
m.nat = NewNat().setBig(n)
m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1])
m.m0inv = minusInverseModW(m.nat.limbs[0])
m.rr = rr(m)
return m
}
// bitLen is a version of bits.Len that only leaks the bit length of n, but not
// its value. bits.Len and bits.LeadingZeros use a lookup table for the
// low-order bits on some architectures.
func bitLen(n uint) int {
var len int
// We assume, here and elsewhere, that comparison to zero is constant time
// with respect to different non-zero values.
for n != 0 {
len++
n >>= 1
}
return len
}
// Size returns the size of m in bytes.
func (m *Modulus) Size() int {
return (m.BitLen() + 7) / 8
}
// BitLen returns the size of m in bits.
func (m *Modulus) BitLen() int {
return len(m.nat.limbs)*_W - int(m.leading)
}
// Nat returns m as a Nat. The return value must not be written to.
func (m *Modulus) Nat() *Nat {
return m.nat
}
// shiftIn calculates x = (x << _W + y) mod m.
//
// This assumes that x is already reduced mod m, and that y < 2^_W.
func (x *Nat) shiftIn(y uint, m *Modulus) *Nat {
d := NewNat().resetFor(m)
// Eliminate bounds checks in the loop.
size := len(m.nat.limbs)
xLimbs := x.limbs[:size]
dLimbs := d.limbs[:size]
mLimbs := m.nat.limbs[:size]
// Each iteration of this loop computes x = 2x + b mod m, where b is a bit
// from y. Effectively, it left-shifts x and adds y one bit at a time,
// reducing it every time.
//
// To do the reduction, each iteration computes both 2x + b and 2x + b - m.
// The next iteration (and finally the return line) will use either result
// based on whether the subtraction underflowed.
needSubtraction := no
for i := _W - 1; i >= 0; i-- {
carry := (y >> i) & 1
var borrow uint
for i := 0; i < size; i++ {
l := ctSelect(needSubtraction, dLimbs[i], xLimbs[i])
res := l<<1 + carry
xLimbs[i] = res & _MASK
carry = res >> _W
res = xLimbs[i] - mLimbs[i] - borrow
dLimbs[i] = res & _MASK
borrow = res >> _W
}
// See Add for how carry (aka overflow), borrow (aka underflow), and
// needSubtraction relate.
needSubtraction = ctEq(carry, borrow)
}
return x.assign(needSubtraction, d)
}
// Mod calculates out = x mod m.
//
// This works regardless of how large the value of x is.
//
// The output will be resized to the size of m and overwritten.
func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
out.resetFor(m)
// Working our way from the most significant to the least significant limb,
// we can insert each limb at the least significant position, shifting all
// previous limbs left by _W. This way each limb will get shifted by the
// correct number of bits. We can insert at least N - 1 limbs without
// overflowing m. After that, we need to reduce every time we shift.
i := len(x.limbs) - 1
// For the first N - 1 limbs we can skip the actual shifting and position
// them at the shifted position, which starts at min(N - 2, i).
start := len(m.nat.limbs) - 2
if i < start {
start = i
}
for j := start; j >= 0; j-- {
out.limbs[j] = x.limbs[i]
i--
}
// We shift in the remaining limbs, reducing modulo m each time.
for i >= 0 {
out.shiftIn(x.limbs[i], m)
i--
}
return out
}
// ExpandFor ensures out has the right size to work with operations modulo m.
//
// The announced size of out must be smaller than or equal to that of m.
func (out *Nat) ExpandFor(m *Modulus) *Nat {
return out.expand(len(m.nat.limbs))
}
// resetFor ensures out has the right size to work with operations modulo m.
//
// out is zeroed and may start at any size.
func (out *Nat) resetFor(m *Modulus) *Nat {
return out.reset(len(m.nat.limbs))
}
// Sub computes x = x - y mod m.
//
// The length of both operands must be the same as the modulus. Both operands
// must already be reduced modulo m.
func (x *Nat) Sub(y *Nat, m *Modulus) *Nat {
underflow := x.sub(yes, y)
// If the subtraction underflowed, add m.
x.add(choice(underflow), m.nat)
return x
}
// Add computes x = x + y mod m.
//
// The length of both operands must be the same as the modulus. Both operands
// must already be reduced modulo m.
func (x *Nat) Add(y *Nat, m *Modulus) *Nat {
overflow := x.add(yes, y)
underflow := not(x.cmpGeq(m.nat)) // x < m
// Three cases are possible:
//
// - overflow = 0, underflow = 0
//
// In this case, addition fits in our limbs, but we can still subtract away
// m without an underflow, so we need to perform the subtraction to reduce
// our result.
//
// - overflow = 0, underflow = 1
//
// The addition fits in our limbs, but we can't subtract m without
// underflowing. The result is already reduced.
//
// - overflow = 1, underflow = 1
//
// The addition does not fit in our limbs, and the subtraction's borrow
// would cancel out with the addition's carry. We need to subtract m to
// reduce our result.
//
// The overflow = 1, underflow = 0 case is not possible, because y is at
// most m - 1, and if adding m - 1 overflows, then subtracting m must
// necessarily underflow.
needSubtraction := ctEq(overflow, uint(underflow))
x.sub(needSubtraction, m.nat)
return x
}
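// Editor's sketch (not part of the original source, would live in this
// package since bigmod is internal): modular Add and Sub on single-limb
// values, checking (9 + 7) mod 13 == 3 and then (3 - 7) mod 13 == 9.
func exampleAddSub() {
	m := NewModulusFromBig(big.NewInt(13))
	x, _ := NewNat().SetBytes([]byte{9}, m)
	y, _ := NewNat().SetBytes([]byte{7}, m)
	x.Add(y, m) // 16 mod 13 = 3: the overflow-free, subtraction-needed case
	if x.Bytes(m)[0] != 3 {
		panic("add failed")
	}
	x.Sub(y, m) // 3 - 7 underflows, so m is added back: result 9
	if x.Bytes(m)[0] != 9 {
		panic("sub failed")
	}
}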
// montgomeryRepresentation calculates x = x * R mod m, with R = 2^(_W * n) and
// n = len(m.nat.limbs).
//
// Faster Montgomery multiplication replaces standard modular multiplication for
// numbers in this representation.
//
// This assumes that x is already reduced mod m.
func (x *Nat) montgomeryRepresentation(m *Modulus) *Nat {
// A Montgomery multiplication (which computes a * b / R) by R * R works out
// to a multiplication by R, which takes the value into the Montgomery domain.
return x.montgomeryMul(NewNat().set(x), m.rr, m)
}
// montgomeryReduction calculates x = x / R mod m, with R = 2^(_W * n) and
// n = len(m.nat.limbs).
//
// This assumes that x is already reduced mod m.
func (x *Nat) montgomeryReduction(m *Modulus) *Nat {
// By Montgomery multiplying with 1 not in Montgomery representation, we
// convert x back from Montgomery representation, because it works out to
// dividing by R.
t0 := NewNat().set(x)
t1 := NewNat().ExpandFor(m)
t1.limbs[0] = 1
return x.montgomeryMul(t0, t1, m)
}
// montgomeryMul calculates d = a * b / R mod m, with R = 2^(_W * n) and
// n = len(m.nat.limbs), using the Montgomery Multiplication technique.
//
// All inputs should be the same length, not aliasing d, and already
// reduced modulo m. d will be resized to the size of m and overwritten.
func (d *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat {
d.resetFor(m)
if len(a.limbs) != len(m.nat.limbs) || len(b.limbs) != len(m.nat.limbs) {
panic("bigmod: invalid montgomeryMul input")
}
// See https://bearssl.org/bigint.html#montgomery-reduction-and-multiplication
// for a description of the algorithm implemented mostly in montgomeryLoop.
// See Add for how overflow, underflow, and needSubtraction relate.
overflow := montgomeryLoop(d.limbs, a.limbs, b.limbs, m.nat.limbs, m.m0inv)
underflow := not(d.cmpGeq(m.nat)) // d < m
needSubtraction := ctEq(overflow, uint(underflow))
d.sub(needSubtraction, m.nat)
return d
}
func montgomeryLoopGeneric(d, a, b, m []uint, m0inv uint) (overflow uint) {
// Eliminate bounds checks in the loop.
size := len(d)
a = a[:size]
b = b[:size]
m = m[:size]
for _, ai := range a {
// This is an unrolled iteration of the loop below with j = 0.
hi, lo := bits.Mul(ai, b[0])
z_lo, c := bits.Add(d[0], lo, 0)
f := (z_lo * m0inv) & _MASK // (d[0] + a[i] * b[0]) * m0inv
z_hi, _ := bits.Add(0, hi, c)
hi, lo = bits.Mul(f, m[0])
z_lo, c = bits.Add(z_lo, lo, 0)
z_hi, _ = bits.Add(z_hi, hi, c)
carry := z_hi<<1 | z_lo>>_W
for j := 1; j < size; j++ {
// z = d[j] + a[i] * b[j] + f * m[j] + carry <= 2^(2W+1) - 2^(W+1) + 2^W
hi, lo := bits.Mul(ai, b[j])
z_lo, c := bits.Add(d[j], lo, 0)
z_hi, _ := bits.Add(0, hi, c)
hi, lo = bits.Mul(f, m[j])
z_lo, c = bits.Add(z_lo, lo, 0)
z_hi, _ = bits.Add(z_hi, hi, c)
z_lo, c = bits.Add(z_lo, carry, 0)
z_hi, _ = bits.Add(z_hi, 0, c)
d[j-1] = z_lo & _MASK
carry = z_hi<<1 | z_lo>>_W // carry <= 2^(W+1) - 2
}
z := overflow + carry // z <= 2^(W+1) - 1
d[size-1] = z & _MASK
overflow = z >> _W // overflow <= 1
}
return
}
// Mul calculates x *= y mod m.
//
// x and y must already be reduced modulo m, they must share its announced
// length, and they may not alias.
func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
// A Montgomery multiplication by a value out of the Montgomery domain
// takes the result out of Montgomery representation.
xR := NewNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m
return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m
}
// Exp calculates out = x^e mod m.
//
// The exponent e is represented in big-endian order. The output will be resized
// to the size of m and overwritten. x must already be reduced modulo m.
func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
// We use a 4 bit window. For our RSA workload, 4 bit windows are faster
// than 2 bit windows, but use an extra 12 nats worth of scratch space.
// Bit sizes that don't divide 8 are more complex to implement.
table := [(1 << 4) - 1]*Nat{ // table[i] = x ^ (i+1)
// newNat calls are unrolled so they are allocated on the stack.
NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
}
table[0].set(x).montgomeryRepresentation(m)
for i := 1; i < len(table); i++ {
table[i].montgomeryMul(table[i-1], table[0], m)
}
out.resetFor(m)
out.limbs[0] = 1
out.montgomeryRepresentation(m)
t0 := NewNat().ExpandFor(m)
t1 := NewNat().ExpandFor(m)
for _, b := range e {
for _, j := range []int{4, 0} {
// Square four times.
t1.montgomeryMul(out, out, m)
out.montgomeryMul(t1, t1, m)
t1.montgomeryMul(out, out, m)
out.montgomeryMul(t1, t1, m)
// Select x^k in constant time from the table.
k := uint((b >> j) & 0b1111)
for i := range table {
t0.assign(ctEq(k, uint(i+1)), table[i])
}
// Multiply by x^k, discarding the result if k = 0.
t1.montgomeryMul(out, t0, m)
out.assign(not(ctEq(k, 0)), t1)
}
}
return out.montgomeryReduction(m)
}
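// Editor's sketch (not part of the original source, would live in this
// package since bigmod is internal): a toy modular exponentiation,
// 4^13 mod 497 = 445, with the exponent passed as big-endian bytes.
func exampleExp() {
	m := NewModulusFromBig(big.NewInt(497)) // odd modulus
	x, _ := NewNat().SetBytes([]byte{4}, m)
	out := NewNat().Exp(x, []byte{13}, m)
	if b := out.Bytes(m); b[0] != 0x01 || b[1] != 0xbd { // 445 = 0x01bd
		panic("exp failed")
	}
}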
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package bcache implements a GC-friendly cache (see [Cache]) for BoringCrypto.
package bcache
import (
"sync/atomic"
"unsafe"
)
// A Cache is a GC-friendly concurrent map from unsafe.Pointer to
// unsafe.Pointer. It is meant to be used for maintaining shadow
// BoringCrypto state associated with certain allocated structs, in
// particular public and private RSA and ECDSA keys.
//
// The cache is GC-friendly in the sense that the keys do not
// indefinitely prevent the garbage collector from collecting them.
// Instead, at the start of each GC, the cache is cleared entirely. That
// is, the cache is lossy, and the loss happens at the start of each GC.
// This means that clients need to be able to cope with cache entries
// disappearing, but it also means that clients don't need to worry about
// cache entries keeping the keys from being collected.
type Cache[K, V any] struct {
// The runtime atomically stores nil to ptable at the start of each GC.
ptable atomic.Pointer[cacheTable[K, V]]
}
type cacheTable[K, V any] [cacheSize]atomic.Pointer[cacheEntry[K, V]]
// A cacheEntry is a single entry in the linked list for a given hash table entry.
type cacheEntry[K, V any] struct {
k *K // immutable once created
v atomic.Pointer[V] // read and written atomically to allow updates
next *cacheEntry[K, V] // immutable once linked into table
}
func registerCache(unsafe.Pointer) // provided by runtime
// Register registers the cache with the runtime,
// so that c.ptable can be cleared at the start of each GC.
// Register must be called during package initialization.
func (c *Cache[K, V]) Register() {
registerCache(unsafe.Pointer(&c.ptable))
}
// cacheSize is the number of entries in the hash table.
// The hash is the pointer value mod cacheSize, a prime.
// Collisions are resolved by maintaining a linked list in each hash slot.
const cacheSize = 1021
// table returns a pointer to the current cache hash table,
// coping with the possibility of the GC clearing it out from under us.
func (c *Cache[K, V]) table() *cacheTable[K, V] {
for {
p := c.ptable.Load()
if p == nil {
p = new(cacheTable[K, V])
if !c.ptable.CompareAndSwap(nil, p) {
continue
}
}
return p
}
}
// Clear clears the cache.
// The runtime does this automatically at each garbage collection;
// this method is exposed only for testing.
func (c *Cache[K, V]) Clear() {
// The runtime does this at the start of every garbage collection
// (itself, not by calling this function).
c.ptable.Store(nil)
}
// Get returns the cached value associated with k,
// which is either the value v corresponding to the most recent call to Put(k, v)
// or nil if that cache entry has been dropped.
func (c *Cache[K, V]) Get(k *K) *V {
head := &c.table()[uintptr(unsafe.Pointer(k))%cacheSize]
e := head.Load()
for ; e != nil; e = e.next {
if e.k == k {
return e.v.Load()
}
}
return nil
}
// Put sets the cached value associated with k to v.
func (c *Cache[K, V]) Put(k *K, v *V) {
head := &c.table()[uintptr(unsafe.Pointer(k))%cacheSize]
// Strategy is to walk the linked list at head,
// same as in Get, to look for existing entry.
// If we find one, we update v atomically in place.
// If not, then we race to replace the start = *head
// we observed with a new k, v entry.
// If we win that race, we're done.
// Otherwise, we try the whole thing again,
// with two optimizations:
//
// 1. We track in noK the start of the section of
// the list that we've confirmed has no entry for k.
// The next time down the list, we can stop at noK,
// because new entries are inserted at the front of the list.
// This guarantees we never traverse an entry
// multiple times.
//
// 2. We only allocate the entry to be added once,
// saving it in add for the next attempt.
var add, noK *cacheEntry[K, V]
n := 0
for {
e := head.Load()
start := e
for ; e != nil && e != noK; e = e.next {
if e.k == k {
e.v.Store(v)
return
}
n++
}
if add == nil {
add = &cacheEntry[K, V]{k: k}
add.v.Store(v)
}
add.next = start
if n >= 1000 {
// If an individual list gets too long, which shouldn't happen,
// throw it away to avoid quadratic lookup behavior.
add.next = nil
}
if head.CompareAndSwap(start, add) {
return
}
noK = start
}
}
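// Editor's sketch (not part of the original source): intended usage from a
// client inside the standard library. The key and shadow types and the
// newBoringShadow constructor are hypothetical stand-ins.
//
//	type pubKey struct{ /* ... */ }
//	type boringShadow struct{ /* ... */ }
//
//	var keyCache bcache.Cache[pubKey, boringShadow]
//
//	func init() { keyCache.Register() } // required, during initialization
//
//	func shadowFor(k *pubKey) *boringShadow {
//		if v := keyCache.Get(k); v != nil {
//			return v // hit; entries may vanish at any GC, so misses are normal
//		}
//		v := newBoringShadow(k) // hypothetical: build the BoringCrypto state
//		keyCache.Put(k, v)
//		return v
//	}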
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !(boringcrypto && linux && (amd64 || arm64) && !android && !cmd_go_bootstrap && !msan && cgo)
package boring
import (
"crypto"
"crypto/cipher"
"crypto/internal/boring/sig"
"hash"
)
const available = false
// Unreachable marks code that should be unreachable
// when BoringCrypto is in use. It is a no-op without BoringCrypto.
func Unreachable() {
// Code that's unreachable when using BoringCrypto
// is exactly the code we want to detect for reporting
// standard Go crypto.
sig.StandardCrypto()
}
// UnreachableExceptTests marks code that should be unreachable
// when BoringCrypto is in use. It is a no-op without BoringCrypto.
func UnreachableExceptTests() {}
type randReader int
func (randReader) Read(b []byte) (int, error) { panic("boringcrypto: not available") }
const RandReader = randReader(0)
func NewSHA1() hash.Hash { panic("boringcrypto: not available") }
func NewSHA224() hash.Hash { panic("boringcrypto: not available") }
func NewSHA256() hash.Hash { panic("boringcrypto: not available") }
func NewSHA384() hash.Hash { panic("boringcrypto: not available") }
func NewSHA512() hash.Hash { panic("boringcrypto: not available") }
func SHA1([]byte) [20]byte { panic("boringcrypto: not available") }
func SHA224([]byte) [28]byte { panic("boringcrypto: not available") }
func SHA256([]byte) [32]byte { panic("boringcrypto: not available") }
func SHA384([]byte) [48]byte { panic("boringcrypto: not available") }
func SHA512([]byte) [64]byte { panic("boringcrypto: not available") }
func NewHMAC(h func() hash.Hash, key []byte) hash.Hash { panic("boringcrypto: not available") }
func NewAESCipher(key []byte) (cipher.Block, error) { panic("boringcrypto: not available") }
func NewGCMTLS(cipher.Block) (cipher.AEAD, error) { panic("boringcrypto: not available") }
type PublicKeyECDSA struct{ _ int }
type PrivateKeyECDSA struct{ _ int }
func GenerateKeyECDSA(curve string) (X, Y, D BigInt, err error) {
panic("boringcrypto: not available")
}
func NewPrivateKeyECDSA(curve string, X, Y, D BigInt) (*PrivateKeyECDSA, error) {
panic("boringcrypto: not available")
}
func NewPublicKeyECDSA(curve string, X, Y BigInt) (*PublicKeyECDSA, error) {
panic("boringcrypto: not available")
}
func SignMarshalECDSA(priv *PrivateKeyECDSA, hash []byte) ([]byte, error) {
panic("boringcrypto: not available")
}
func VerifyECDSA(pub *PublicKeyECDSA, hash []byte, sig []byte) bool {
panic("boringcrypto: not available")
}
type PublicKeyRSA struct{ _ int }
type PrivateKeyRSA struct{ _ int }
func DecryptRSAOAEP(h, mgfHash hash.Hash, priv *PrivateKeyRSA, ciphertext, label []byte) ([]byte, error) {
panic("boringcrypto: not available")
}
func DecryptRSAPKCS1(priv *PrivateKeyRSA, ciphertext []byte) ([]byte, error) {
panic("boringcrypto: not available")
}
func DecryptRSANoPadding(priv *PrivateKeyRSA, ciphertext []byte) ([]byte, error) {
panic("boringcrypto: not available")
}
func EncryptRSAOAEP(h, mgfHash hash.Hash, pub *PublicKeyRSA, msg, label []byte) ([]byte, error) {
panic("boringcrypto: not available")
}
func EncryptRSAPKCS1(pub *PublicKeyRSA, msg []byte) ([]byte, error) {
panic("boringcrypto: not available")
}
func EncryptRSANoPadding(pub *PublicKeyRSA, msg []byte) ([]byte, error) {
panic("boringcrypto: not available")
}
func GenerateKeyRSA(bits int) (N, E, D, P, Q, Dp, Dq, Qinv BigInt, err error) {
panic("boringcrypto: not available")
}
func NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv BigInt) (*PrivateKeyRSA, error) {
panic("boringcrypto: not available")
}
func NewPublicKeyRSA(N, E BigInt) (*PublicKeyRSA, error) { panic("boringcrypto: not available") }
func SignRSAPKCS1v15(priv *PrivateKeyRSA, h crypto.Hash, hashed []byte) ([]byte, error) {
panic("boringcrypto: not available")
}
func SignRSAPSS(priv *PrivateKeyRSA, h crypto.Hash, hashed []byte, saltLen int) ([]byte, error) {
panic("boringcrypto: not available")
}
func VerifyRSAPKCS1v15(pub *PublicKeyRSA, h crypto.Hash, hashed, sig []byte) error {
panic("boringcrypto: not available")
}
func VerifyRSAPSS(pub *PublicKeyRSA, h crypto.Hash, hashed, sig []byte, saltLen int) error {
panic("boringcrypto: not available")
}
type PublicKeyECDH struct{}
type PrivateKeyECDH struct{}
func ECDH(*PrivateKeyECDH, *PublicKeyECDH) ([]byte, error) { panic("boringcrypto: not available") }
func GenerateKeyECDH(string) (*PrivateKeyECDH, []byte, error) { panic("boringcrypto: not available") }
func NewPrivateKeyECDH(string, []byte) (*PrivateKeyECDH, error) { panic("boringcrypto: not available") }
func NewPublicKeyECDH(string, []byte) (*PublicKeyECDH, error) { panic("boringcrypto: not available") }
func (*PublicKeyECDH) Bytes() []byte { panic("boringcrypto: not available") }
func (*PrivateKeyECDH) PublicKey() (*PublicKeyECDH, error) { panic("boringcrypto: not available") }
// Copyright (c) 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import (
"crypto/internal/edwards25519/field"
"errors"
)
// Point types.
type projP1xP1 struct {
X, Y, Z, T field.Element
}
type projP2 struct {
X, Y, Z field.Element
}
// Point represents a point on the edwards25519 curve.
//
// This type works similarly to math/big.Int, and all arguments and receivers
// are allowed to alias.
//
// The zero value is NOT valid, and it may be used only as a receiver.
type Point struct {
// Make the type not comparable (that is, not usable with == or as a map key), as
// equivalent points can be represented by different Go values.
_ incomparable
// The point is internally represented in extended coordinates (X, Y, Z, T)
// where x = X/Z, y = Y/Z, and xy = T/Z per https://eprint.iacr.org/2008/522.
x, y, z, t field.Element
}
type incomparable [0]func()
func checkInitialized(points ...*Point) {
for _, p := range points {
if p.x == (field.Element{}) && p.y == (field.Element{}) {
panic("edwards25519: use of uninitialized Point")
}
}
}
type projCached struct {
YplusX, YminusX, Z, T2d field.Element
}
type affineCached struct {
YplusX, YminusX, T2d field.Element
}
// Constructors.
func (v *projP2) Zero() *projP2 {
v.X.Zero()
v.Y.One()
v.Z.One()
return v
}
// identity is the identity element of the group, the point (0, 1).
var identity, _ = new(Point).SetBytes([]byte{
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})
// NewIdentityPoint returns a new Point set to the identity.
func NewIdentityPoint() *Point {
return new(Point).Set(identity)
}
// generator is the canonical curve basepoint. See TestGenerator for the
// correspondence of this encoding with the values in RFC 8032.
var generator, _ = new(Point).SetBytes([]byte{
0x58, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66,
0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66})
// NewGeneratorPoint returns a new Point set to the canonical generator.
func NewGeneratorPoint() *Point {
return new(Point).Set(generator)
}
func (v *projCached) Zero() *projCached {
v.YplusX.One()
v.YminusX.One()
v.Z.One()
v.T2d.Zero()
return v
}
func (v *affineCached) Zero() *affineCached {
v.YplusX.One()
v.YminusX.One()
v.T2d.Zero()
return v
}
// Assignments.
// Set sets v = u, and returns v.
func (v *Point) Set(u *Point) *Point {
*v = *u
return v
}
// Encoding.
// Bytes returns the canonical 32-byte encoding of v, according to RFC 8032,
// Section 5.1.2.
func (v *Point) Bytes() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var buf [32]byte
return v.bytes(&buf)
}
func (v *Point) bytes(buf *[32]byte) []byte {
checkInitialized(v)
var zInv, x, y field.Element
zInv.Invert(&v.z) // zInv = 1 / Z
x.Multiply(&v.x, &zInv) // x = X / Z
y.Multiply(&v.y, &zInv) // y = Y / Z
out := copyFieldElement(buf, &y)
out[31] |= byte(x.IsNegative() << 7)
return out
}
var feOne = new(field.Element).One()
// SetBytes sets v = x, where x is a 32-byte encoding of v. If x does not
// represent a valid point on the curve, SetBytes returns nil and an error and
// the receiver is unchanged. Otherwise, SetBytes returns v.
//
// Note that SetBytes accepts all non-canonical encodings of valid points.
// That is, it follows decoding rules that match most implementations in
// the ecosystem rather than RFC 8032.
func (v *Point) SetBytes(x []byte) (*Point, error) {
// Specifically, the non-canonical encodings that are accepted are
// 1) the ones where the field element is not reduced (see the
// (*field.Element).SetBytes docs) and
// 2) the ones where the x-coordinate is zero and the sign bit is set.
//
// Read more at https://hdevalence.ca/blog/2020-10-04-its-25519am,
// specifically the "Canonical A, R" section.
y, err := new(field.Element).SetBytes(x)
if err != nil {
return nil, errors.New("edwards25519: invalid point encoding length")
}
// -x² + y² = 1 + dx²y²
// x² + dx²y² = x²(dy² + 1) = y² - 1
// x² = (y² - 1) / (dy² + 1)
// u = y² - 1
y2 := new(field.Element).Square(y)
u := new(field.Element).Subtract(y2, feOne)
// v = dy² + 1
vv := new(field.Element).Multiply(y2, d)
vv = vv.Add(vv, feOne)
// x = +√(u/v)
xx, wasSquare := new(field.Element).SqrtRatio(u, vv)
if wasSquare == 0 {
return nil, errors.New("edwards25519: invalid point encoding")
}
// Select the negative square root if the sign bit is set.
xxNeg := new(field.Element).Negate(xx)
xx = xx.Select(xxNeg, xx, int(x[31]>>7))
v.x.Set(xx)
v.y.Set(y)
v.z.One()
v.t.Multiply(xx, y) // xy = T / Z
return v, nil
}
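// Editor's sketch (not part of the original source): a decode/encode round
// trip. The same API is mirrored by the public filippo.io/edwards25519
// module; inside this package the types can be used directly.
func examplePointRoundTrip() {
	enc := NewGeneratorPoint().Bytes() // canonical 32-byte encoding
	p, err := new(Point).SetBytes(enc) // decompress it again
	if err != nil || p.Equal(NewGeneratorPoint()) != 1 {
		panic("round trip failed")
	}
}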
func copyFieldElement(buf *[32]byte, v *field.Element) []byte {
copy(buf[:], v.Bytes())
return buf[:]
}
// Conversions.
func (v *projP2) FromP1xP1(p *projP1xP1) *projP2 {
v.X.Multiply(&p.X, &p.T)
v.Y.Multiply(&p.Y, &p.Z)
v.Z.Multiply(&p.Z, &p.T)
return v
}
func (v *projP2) FromP3(p *Point) *projP2 {
v.X.Set(&p.x)
v.Y.Set(&p.y)
v.Z.Set(&p.z)
return v
}
func (v *Point) fromP1xP1(p *projP1xP1) *Point {
v.x.Multiply(&p.X, &p.T)
v.y.Multiply(&p.Y, &p.Z)
v.z.Multiply(&p.Z, &p.T)
v.t.Multiply(&p.X, &p.Y)
return v
}
func (v *Point) fromP2(p *projP2) *Point {
v.x.Multiply(&p.X, &p.Z)
v.y.Multiply(&p.Y, &p.Z)
v.z.Square(&p.Z)
v.t.Multiply(&p.X, &p.Y)
return v
}
// d is a constant in the curve equation.
var d, _ = new(field.Element).SetBytes([]byte{
0xa3, 0x78, 0x59, 0x13, 0xca, 0x4d, 0xeb, 0x75,
0xab, 0xd8, 0x41, 0x41, 0x4d, 0x0a, 0x70, 0x00,
0x98, 0xe8, 0x79, 0x77, 0x79, 0x40, 0xc7, 0x8c,
0x73, 0xfe, 0x6f, 0x2b, 0xee, 0x6c, 0x03, 0x52})
var d2 = new(field.Element).Add(d, d)
func (v *projCached) FromP3(p *Point) *projCached {
v.YplusX.Add(&p.y, &p.x)
v.YminusX.Subtract(&p.y, &p.x)
v.Z.Set(&p.z)
v.T2d.Multiply(&p.t, d2)
return v
}
func (v *affineCached) FromP3(p *Point) *affineCached {
v.YplusX.Add(&p.y, &p.x)
v.YminusX.Subtract(&p.y, &p.x)
v.T2d.Multiply(&p.t, d2)
var invZ field.Element
invZ.Invert(&p.z)
v.YplusX.Multiply(&v.YplusX, &invZ)
v.YminusX.Multiply(&v.YminusX, &invZ)
v.T2d.Multiply(&v.T2d, &invZ)
return v
}
// (Re)addition and subtraction.
// Add sets v = p + q, and returns v.
func (v *Point) Add(p, q *Point) *Point {
checkInitialized(p, q)
qCached := new(projCached).FromP3(q)
result := new(projP1xP1).Add(p, qCached)
return v.fromP1xP1(result)
}
// Subtract sets v = p - q, and returns v.
func (v *Point) Subtract(p, q *Point) *Point {
checkInitialized(p, q)
qCached := new(projCached).FromP3(q)
result := new(projP1xP1).Sub(p, qCached)
return v.fromP1xP1(result)
}
func (v *projP1xP1) Add(p *Point, q *projCached) *projP1xP1 {
var YplusX, YminusX, PP, MM, TT2d, ZZ2 field.Element
YplusX.Add(&p.y, &p.x)
YminusX.Subtract(&p.y, &p.x)
PP.Multiply(&YplusX, &q.YplusX)
MM.Multiply(&YminusX, &q.YminusX)
TT2d.Multiply(&p.t, &q.T2d)
ZZ2.Multiply(&p.z, &q.Z)
ZZ2.Add(&ZZ2, &ZZ2)
v.X.Subtract(&PP, &MM)
v.Y.Add(&PP, &MM)
v.Z.Add(&ZZ2, &TT2d)
v.T.Subtract(&ZZ2, &TT2d)
return v
}
func (v *projP1xP1) Sub(p *Point, q *projCached) *projP1xP1 {
var YplusX, YminusX, PP, MM, TT2d, ZZ2 field.Element
YplusX.Add(&p.y, &p.x)
YminusX.Subtract(&p.y, &p.x)
PP.Multiply(&YplusX, &q.YminusX) // flipped sign
MM.Multiply(&YminusX, &q.YplusX) // flipped sign
TT2d.Multiply(&p.t, &q.T2d)
ZZ2.Multiply(&p.z, &q.Z)
ZZ2.Add(&ZZ2, &ZZ2)
v.X.Subtract(&PP, &MM)
v.Y.Add(&PP, &MM)
v.Z.Subtract(&ZZ2, &TT2d) // flipped sign
v.T.Add(&ZZ2, &TT2d) // flipped sign
return v
}
func (v *projP1xP1) AddAffine(p *Point, q *affineCached) *projP1xP1 {
var YplusX, YminusX, PP, MM, TT2d, Z2 field.Element
YplusX.Add(&p.y, &p.x)
YminusX.Subtract(&p.y, &p.x)
PP.Multiply(&YplusX, &q.YplusX)
MM.Multiply(&YminusX, &q.YminusX)
TT2d.Multiply(&p.t, &q.T2d)
Z2.Add(&p.z, &p.z)
v.X.Subtract(&PP, &MM)
v.Y.Add(&PP, &MM)
v.Z.Add(&Z2, &TT2d)
v.T.Subtract(&Z2, &TT2d)
return v
}
func (v *projP1xP1) SubAffine(p *Point, q *affineCached) *projP1xP1 {
var YplusX, YminusX, PP, MM, TT2d, Z2 field.Element
YplusX.Add(&p.y, &p.x)
YminusX.Subtract(&p.y, &p.x)
PP.Multiply(&YplusX, &q.YminusX) // flipped sign
MM.Multiply(&YminusX, &q.YplusX) // flipped sign
TT2d.Multiply(&p.t, &q.T2d)
Z2.Add(&p.z, &p.z)
v.X.Subtract(&PP, &MM)
v.Y.Add(&PP, &MM)
v.Z.Subtract(&Z2, &TT2d) // flipped sign
v.T.Add(&Z2, &TT2d) // flipped sign
return v
}
// Doubling.
func (v *projP1xP1) Double(p *projP2) *projP1xP1 {
var XX, YY, ZZ2, XplusYsq field.Element
XX.Square(&p.X)
YY.Square(&p.Y)
ZZ2.Square(&p.Z)
ZZ2.Add(&ZZ2, &ZZ2)
XplusYsq.Add(&p.X, &p.Y)
XplusYsq.Square(&XplusYsq)
v.Y.Add(&YY, &XX)
v.Z.Subtract(&YY, &XX)
v.X.Subtract(&XplusYsq, &v.Y)
v.T.Subtract(&ZZ2, &v.Z)
return v
}
// Negation.
// Negate sets v = -p, and returns v.
func (v *Point) Negate(p *Point) *Point {
checkInitialized(p)
v.x.Negate(&p.x)
v.y.Set(&p.y)
v.z.Set(&p.z)
v.t.Negate(&p.t)
return v
}
// Equal returns 1 if v is equivalent to u, and 0 otherwise.
func (v *Point) Equal(u *Point) int {
checkInitialized(v, u)
var t1, t2, t3, t4 field.Element
t1.Multiply(&v.x, &u.z)
t2.Multiply(&u.x, &v.z)
t3.Multiply(&v.y, &u.z)
t4.Multiply(&u.y, &v.z)
return t1.Equal(&t2) & t3.Equal(&t4)
}
// Constant-time operations
// Select sets v to a if cond == 1 and to b if cond == 0.
func (v *projCached) Select(a, b *projCached, cond int) *projCached {
v.YplusX.Select(&a.YplusX, &b.YplusX, cond)
v.YminusX.Select(&a.YminusX, &b.YminusX, cond)
v.Z.Select(&a.Z, &b.Z, cond)
v.T2d.Select(&a.T2d, &b.T2d, cond)
return v
}
// Select sets v to a if cond == 1 and to b if cond == 0.
func (v *affineCached) Select(a, b *affineCached, cond int) *affineCached {
v.YplusX.Select(&a.YplusX, &b.YplusX, cond)
v.YminusX.Select(&a.YminusX, &b.YminusX, cond)
v.T2d.Select(&a.T2d, &b.T2d, cond)
return v
}
// CondNeg negates v if cond == 1 and leaves it unchanged if cond == 0.
func (v *projCached) CondNeg(cond int) *projCached {
v.YplusX.Swap(&v.YminusX, cond)
v.T2d.Select(new(field.Element).Negate(&v.T2d), &v.T2d, cond)
return v
}
// CondNeg negates v if cond == 1 and leaves it unchanged if cond == 0.
func (v *affineCached) CondNeg(cond int) *affineCached {
v.YplusX.Swap(&v.YminusX, cond)
v.T2d.Select(new(field.Element).Negate(&v.T2d), &v.T2d, cond)
return v
}
// Copyright (c) 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package field implements fast arithmetic modulo 2^255-19.
package field
import (
"crypto/subtle"
"encoding/binary"
"errors"
"math/bits"
)
// Element represents an element of the field GF(2^255-19). Note that this
// is not a cryptographically secure group, and should only be used to interact
// with edwards25519.Point coordinates.
//
// This type works similarly to math/big.Int, and all arguments and receivers
// are allowed to alias.
//
// The zero value is a valid zero element.
type Element struct {
// An element t represents the integer
// t.l0 + t.l1*2^51 + t.l2*2^102 + t.l3*2^153 + t.l4*2^204
//
// Between operations, all limbs are expected to be lower than 2^52.
l0 uint64
l1 uint64
l2 uint64
l3 uint64
l4 uint64
}
const maskLow51Bits uint64 = (1 << 51) - 1
var feZero = &Element{0, 0, 0, 0, 0}
// Zero sets v = 0, and returns v.
func (v *Element) Zero() *Element {
*v = *feZero
return v
}
var feOne = &Element{1, 0, 0, 0, 0}
// One sets v = 1, and returns v.
func (v *Element) One() *Element {
*v = *feOne
return v
}
// reduce reduces v modulo 2^255 - 19 and returns it.
func (v *Element) reduce() *Element {
v.carryPropagate()
// After the light reduction we now have a field element representation
// v < 2^255 + 2^13 * 19, but need v < 2^255 - 19.
// If v >= 2^255 - 19, then v + 19 >= 2^255, which would overflow 2^255 - 1,
// generating a carry. That is, c will be 0 if v < 2^255 - 19, and 1 otherwise.
c := (v.l0 + 19) >> 51
c = (v.l1 + c) >> 51
c = (v.l2 + c) >> 51
c = (v.l3 + c) >> 51
c = (v.l4 + c) >> 51
// If v < 2^255 - 19 and c = 0, this will be a no-op. Otherwise, it's
// effectively applying the reduction identity to the carry.
v.l0 += 19 * c
v.l1 += v.l0 >> 51
v.l0 = v.l0 & maskLow51Bits
v.l2 += v.l1 >> 51
v.l1 = v.l1 & maskLow51Bits
v.l3 += v.l2 >> 51
v.l2 = v.l2 & maskLow51Bits
v.l4 += v.l3 >> 51
v.l3 = v.l3 & maskLow51Bits
// no additional carry
v.l4 = v.l4 & maskLow51Bits
return v
}
// Add sets v = a + b, and returns v.
func (v *Element) Add(a, b *Element) *Element {
v.l0 = a.l0 + b.l0
v.l1 = a.l1 + b.l1
v.l2 = a.l2 + b.l2
v.l3 = a.l3 + b.l3
v.l4 = a.l4 + b.l4
// Using the generic implementation here is actually faster than the
// assembly. Probably because the body of this function is so simple that
// the compiler can figure out better optimizations by inlining the carry
// propagation.
return v.carryPropagateGeneric()
}
// Subtract sets v = a - b, and returns v.
func (v *Element) Subtract(a, b *Element) *Element {
// We first add 2 * p, to guarantee the subtraction won't underflow, and
// then subtract b (which can be up to 2^255 + 2^13 * 19).
v.l0 = (a.l0 + 0xFFFFFFFFFFFDA) - b.l0
v.l1 = (a.l1 + 0xFFFFFFFFFFFFE) - b.l1
v.l2 = (a.l2 + 0xFFFFFFFFFFFFE) - b.l2
v.l3 = (a.l3 + 0xFFFFFFFFFFFFE) - b.l3
v.l4 = (a.l4 + 0xFFFFFFFFFFFFE) - b.l4
return v.carryPropagate()
}
// Negate sets v = -a, and returns v.
func (v *Element) Negate(a *Element) *Element {
return v.Subtract(feZero, a)
}
// Invert sets v = 1/z mod p, and returns v.
//
// If z == 0, Invert returns v = 0.
func (v *Element) Invert(z *Element) *Element {
// Inversion is implemented as exponentiation with exponent p − 2. It uses the
// same sequence of 255 squarings and 11 multiplications as [Curve25519].
var z2, z9, z11, z2_5_0, z2_10_0, z2_20_0, z2_50_0, z2_100_0, t Element
z2.Square(z) // 2
t.Square(&z2) // 4
t.Square(&t) // 8
z9.Multiply(&t, z) // 9
z11.Multiply(&z9, &z2) // 11
t.Square(&z11) // 22
z2_5_0.Multiply(&t, &z9) // 31 = 2^5 - 2^0
t.Square(&z2_5_0) // 2^6 - 2^1
for i := 0; i < 4; i++ {
t.Square(&t) // 2^10 - 2^5
}
z2_10_0.Multiply(&t, &z2_5_0) // 2^10 - 2^0
t.Square(&z2_10_0) // 2^11 - 2^1
for i := 0; i < 9; i++ {
t.Square(&t) // 2^20 - 2^10
}
z2_20_0.Multiply(&t, &z2_10_0) // 2^20 - 2^0
t.Square(&z2_20_0) // 2^21 - 2^1
for i := 0; i < 19; i++ {
t.Square(&t) // 2^40 - 2^20
}
t.Multiply(&t, &z2_20_0) // 2^40 - 2^0
t.Square(&t) // 2^41 - 2^1
for i := 0; i < 9; i++ {
t.Square(&t) // 2^50 - 2^10
}
z2_50_0.Multiply(&t, &z2_10_0) // 2^50 - 2^0
t.Square(&z2_50_0) // 2^51 - 2^1
for i := 0; i < 49; i++ {
t.Square(&t) // 2^100 - 2^50
}
z2_100_0.Multiply(&t, &z2_50_0) // 2^100 - 2^0
t.Square(&z2_100_0) // 2^101 - 2^1
for i := 0; i < 99; i++ {
t.Square(&t) // 2^200 - 2^100
}
t.Multiply(&t, &z2_100_0) // 2^200 - 2^0
t.Square(&t) // 2^201 - 2^1
for i := 0; i < 49; i++ {
t.Square(&t) // 2^250 - 2^50
}
t.Multiply(&t, &z2_50_0) // 2^250 - 2^0
t.Square(&t) // 2^251 - 2^1
t.Square(&t) // 2^252 - 2^2
t.Square(&t) // 2^253 - 2^3
t.Square(&t) // 2^254 - 2^4
t.Square(&t) // 2^255 - 2^5
return v.Multiply(&t, &z11) // 2^255 - 21
}
// Set sets v = a, and returns v.
func (v *Element) Set(a *Element) *Element {
*v = *a
return v
}
// SetBytes sets v to x, where x is a 32-byte little-endian encoding. If x is
// not of the right length, SetBytes returns nil and an error, and the
// receiver is unchanged.
//
// Consistent with RFC 7748, the most significant bit (the high bit of the
// last byte) is ignored, and non-canonical values (2^255-19 through 2^255-1)
// are accepted. Note that this is laxer than specified by RFC 8032, but
// consistent with most Ed25519 implementations.
func (v *Element) SetBytes(x []byte) (*Element, error) {
if len(x) != 32 {
return nil, errors.New("edwards25519: invalid field element input size")
}
// Bits 0:51 (bytes 0:8, bits 0:64, shift 0, mask 51).
v.l0 = binary.LittleEndian.Uint64(x[0:8])
v.l0 &= maskLow51Bits
// Bits 51:102 (bytes 6:14, bits 48:112, shift 3, mask 51).
v.l1 = binary.LittleEndian.Uint64(x[6:14]) >> 3
v.l1 &= maskLow51Bits
// Bits 102:153 (bytes 12:20, bits 96:160, shift 6, mask 51).
v.l2 = binary.LittleEndian.Uint64(x[12:20]) >> 6
v.l2 &= maskLow51Bits
// Bits 153:204 (bytes 19:27, bits 152:216, shift 1, mask 51).
v.l3 = binary.LittleEndian.Uint64(x[19:27]) >> 1
v.l3 &= maskLow51Bits
// Bits 204:255 (bytes 24:32, bits 192:256, shift 12, mask 51).
// Note: not bytes 25:33, shift 4, to avoid overread.
v.l4 = binary.LittleEndian.Uint64(x[24:32]) >> 12
v.l4 &= maskLow51Bits
return v, nil
}
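// Editor's sketch (not part of the original source): SetBytes and Bytes
// round-trip any canonical 32-byte encoding; here, the value 2^200 + 7.
func exampleElementRoundTrip() {
	var in [32]byte
	in[0] = 7
	in[25] = 1 // byte 25 holds bit 200 in little-endian order
	v, err := new(Element).SetBytes(in[:])
	if err != nil {
		panic(err)
	}
	out := v.Bytes()
	for i := range out {
		if out[i] != in[i] {
			panic("round trip failed")
		}
	}
}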
// Bytes returns the canonical 32-byte little-endian encoding of v.
func (v *Element) Bytes() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var out [32]byte
return v.bytes(&out)
}
func (v *Element) bytes(out *[32]byte) []byte {
t := *v
t.reduce()
var buf [8]byte
for i, l := range [5]uint64{t.l0, t.l1, t.l2, t.l3, t.l4} {
bitsOffset := i * 51
binary.LittleEndian.PutUint64(buf[:], l<<uint(bitsOffset%8))
for i, bb := range buf {
off := bitsOffset/8 + i
if off >= len(out) {
break
}
out[off] |= bb
}
}
return out[:]
}
// Equal returns 1 if v and u are equal, and 0 otherwise.
func (v *Element) Equal(u *Element) int {
sa, sv := u.Bytes(), v.Bytes()
return subtle.ConstantTimeCompare(sa, sv)
}
// mask64Bits returns 0xffffffffffffffff if cond is 1, and 0 otherwise.
func mask64Bits(cond int) uint64 { return ^(uint64(cond) - 1) }
// Select sets v to a if cond == 1, and to b if cond == 0.
func (v *Element) Select(a, b *Element, cond int) *Element {
m := mask64Bits(cond)
v.l0 = (m & a.l0) | (^m & b.l0)
v.l1 = (m & a.l1) | (^m & b.l1)
v.l2 = (m & a.l2) | (^m & b.l2)
v.l3 = (m & a.l3) | (^m & b.l3)
v.l4 = (m & a.l4) | (^m & b.l4)
return v
}
// Swap swaps v and u if cond == 1, and leaves them unchanged if cond == 0.
func (v *Element) Swap(u *Element, cond int) {
m := mask64Bits(cond)
t := m & (v.l0 ^ u.l0)
v.l0 ^= t
u.l0 ^= t
t = m & (v.l1 ^ u.l1)
v.l1 ^= t
u.l1 ^= t
t = m & (v.l2 ^ u.l2)
v.l2 ^= t
u.l2 ^= t
t = m & (v.l3 ^ u.l3)
v.l3 ^= t
u.l3 ^= t
t = m & (v.l4 ^ u.l4)
v.l4 ^= t
u.l4 ^= t
}
// IsNegative returns 1 if v is negative, and 0 otherwise.
func (v *Element) IsNegative() int {
return int(v.Bytes()[0] & 1)
}
// Absolute sets v to |u|, and returns v.
func (v *Element) Absolute(u *Element) *Element {
return v.Select(new(Element).Negate(u), u, u.IsNegative())
}
// Multiply sets v = x * y, and returns v.
func (v *Element) Multiply(x, y *Element) *Element {
feMul(v, x, y)
return v
}
// Square sets v = x * x, and returns v.
func (v *Element) Square(x *Element) *Element {
feSquare(v, x)
return v
}
// Mult32 sets v = x * y, and returns v.
func (v *Element) Mult32(x *Element, y uint32) *Element {
x0lo, x0hi := mul51(x.l0, y)
x1lo, x1hi := mul51(x.l1, y)
x2lo, x2hi := mul51(x.l2, y)
x3lo, x3hi := mul51(x.l3, y)
x4lo, x4hi := mul51(x.l4, y)
v.l0 = x0lo + 19*x4hi // carried over per the reduction identity
v.l1 = x1lo + x0hi
v.l2 = x2lo + x1hi
v.l3 = x3lo + x2hi
v.l4 = x4lo + x3hi
// The hi portions are going to be only 32 bits, plus any previous excess,
// so we can skip the carry propagation.
return v
}
// mul51 returns lo + hi * 2⁵¹ = a * b.
func mul51(a uint64, b uint32) (lo uint64, hi uint64) {
mh, ml := bits.Mul64(a, uint64(b))
lo = ml & maskLow51Bits
hi = (mh << 13) | (ml >> 51)
return
}
// Pow22523 sets v = x^((p-5)/8), and returns v. (p-5)/8 is 2^252-3.
func (v *Element) Pow22523(x *Element) *Element {
var t0, t1, t2 Element
t0.Square(x) // x^2
t1.Square(&t0) // x^4
t1.Square(&t1) // x^8
t1.Multiply(x, &t1) // x^9
t0.Multiply(&t0, &t1) // x^11
t0.Square(&t0) // x^22
t0.Multiply(&t1, &t0) // x^31
t1.Square(&t0) // x^62
for i := 1; i < 5; i++ { // x^992
t1.Square(&t1)
}
t0.Multiply(&t1, &t0) // x^1023 -> 1023 = 2^10 - 1
t1.Square(&t0) // 2^11 - 2
for i := 1; i < 10; i++ { // 2^20 - 2^10
t1.Square(&t1)
}
t1.Multiply(&t1, &t0) // 2^20 - 1
t2.Square(&t1) // 2^21 - 2
for i := 1; i < 20; i++ { // 2^40 - 2^20
t2.Square(&t2)
}
t1.Multiply(&t2, &t1) // 2^40 - 1
t1.Square(&t1) // 2^41 - 2
for i := 1; i < 10; i++ { // 2^50 - 2^10
t1.Square(&t1)
}
t0.Multiply(&t1, &t0) // 2^50 - 1
t1.Square(&t0) // 2^51 - 2
for i := 1; i < 50; i++ { // 2^100 - 2^50
t1.Square(&t1)
}
t1.Multiply(&t1, &t0) // 2^100 - 1
t2.Square(&t1) // 2^101 - 2
for i := 1; i < 100; i++ { // 2^200 - 2^100
t2.Square(&t2)
}
t1.Multiply(&t2, &t1) // 2^200 - 1
t1.Square(&t1) // 2^201 - 2
for i := 1; i < 50; i++ { // 2^250 - 2^50
t1.Square(&t1)
}
t0.Multiply(&t1, &t0) // 2^250 - 1
t0.Square(&t0) // 2^251 - 2
t0.Square(&t0) // 2^252 - 4
return v.Multiply(&t0, x) // 2^252 - 3 -> x^(2^252-3)
}
// sqrtM1 is 2^((p-1)/4), which squared is equal to -1 by Euler's Criterion.
var sqrtM1 = &Element{1718705420411056, 234908883556509,
2233514472574048, 2117202627021982, 765476049583133}
// SqrtRatio sets r to the non-negative square root of the ratio of u and v.
//
// If u/v is square, SqrtRatio returns r and 1. If u/v is not square, SqrtRatio
// sets r according to Section 4.3 of draft-irtf-cfrg-ristretto255-decaf448-00,
// and returns r and 0.
func (r *Element) SqrtRatio(u, v *Element) (R *Element, wasSquare int) {
t0 := new(Element)
// r = (u * v3) * (u * v7)^((p-5)/8)
v2 := new(Element).Square(v)
uv3 := new(Element).Multiply(u, t0.Multiply(v2, v))
uv7 := new(Element).Multiply(uv3, t0.Square(v2))
rr := new(Element).Multiply(uv3, t0.Pow22523(uv7))
check := new(Element).Multiply(v, t0.Square(rr)) // check = v * r^2
uNeg := new(Element).Negate(u)
correctSignSqrt := check.Equal(u)
flippedSignSqrt := check.Equal(uNeg)
flippedSignSqrtI := check.Equal(t0.Multiply(uNeg, sqrtM1))
rPrime := new(Element).Multiply(rr, sqrtM1) // r_prime = SQRT_M1 * r
// r = CT_SELECT(r_prime IF flipped_sign_sqrt | flipped_sign_sqrt_i ELSE r)
rr.Select(rPrime, rr, flippedSignSqrt|flippedSignSqrtI)
r.Absolute(rr) // Choose the nonnegative square root.
return r, correctSignSqrt | flippedSignSqrt
}
// Copyright (c) 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !arm64 || !gc || purego
package field
func (v *Element) carryPropagate() *Element {
return v.carryPropagateGeneric()
}
// Copyright (c) 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package field
import "math/bits"
// uint128 holds a 128-bit number as two 64-bit limbs, for use with the
// bits.Mul64 and bits.Add64 intrinsics.
type uint128 struct {
lo, hi uint64
}
// mul64 returns a * b.
func mul64(a, b uint64) uint128 {
hi, lo := bits.Mul64(a, b)
return uint128{lo, hi}
}
// addMul64 returns v + a * b.
func addMul64(v uint128, a, b uint64) uint128 {
hi, lo := bits.Mul64(a, b)
lo, c := bits.Add64(lo, v.lo, 0)
hi, _ = bits.Add64(hi, v.hi, c)
return uint128{lo, hi}
}
// shiftRightBy51 returns a >> 51. a is assumed to be at most 115 bits.
func shiftRightBy51(a uint128) uint64 {
return (a.hi << (64 - 51)) | (a.lo >> 51)
}
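// Editor's sketch (not part of the original source): a spot check of the
// helpers on values with an easily computed 128-bit result: 2^32 * 2^32 is
// 2^64, so mul64 yields {lo: 0, hi: 1}, and adding v = 5 via addMul64
// yields {lo: 5, hi: 1}.
func exampleUint128() {
	r := addMul64(uint128{lo: 5}, 1<<32, 1<<32)
	if r.lo != 5 || r.hi != 1 {
		panic("uint128 check failed")
	}
}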
func feMulGeneric(v, a, b *Element) {
a0 := a.l0
a1 := a.l1
a2 := a.l2
a3 := a.l3
a4 := a.l4
b0 := b.l0
b1 := b.l1
b2 := b.l2
b3 := b.l3
b4 := b.l4
// Limb multiplication works like pen-and-paper columnar multiplication, but
// with 51-bit limbs instead of digits.
//
// a4 a3 a2 a1 a0 x
// b4 b3 b2 b1 b0 =
// ------------------------
// a4b0 a3b0 a2b0 a1b0 a0b0 +
// a4b1 a3b1 a2b1 a1b1 a0b1 +
// a4b2 a3b2 a2b2 a1b2 a0b2 +
// a4b3 a3b3 a2b3 a1b3 a0b3 +
// a4b4 a3b4 a2b4 a1b4 a0b4 =
// ----------------------------------------------
// r8 r7 r6 r5 r4 r3 r2 r1 r0
//
// We can then use the reduction identity (a * 2²⁵⁵ + b = a * 19 + b) to
// reduce the limbs that would overflow 255 bits. r5 * 2²⁵⁵ becomes 19 * r5,
// r6 * 2³⁰⁶ becomes 19 * r6 * 2⁵¹, etc.
//
// Reduction can be carried out simultaneously to multiplication. For
// example, we do not compute r5: whenever the result of a multiplication
// belongs to r5, like a1b4, we multiply it by 19 and add the result to r0.
//
// a4b0 a3b0 a2b0 a1b0 a0b0 +
// a3b1 a2b1 a1b1 a0b1 19×a4b1 +
// a2b2 a1b2 a0b2 19×a4b2 19×a3b2 +
// a1b3 a0b3 19×a4b3 19×a3b3 19×a2b3 +
// a0b4 19×a4b4 19×a3b4 19×a2b4 19×a1b4 =
// --------------------------------------
// r4 r3 r2 r1 r0
//
// Finally we add up the columns into wide, overlapping limbs.
a1_19 := a1 * 19
a2_19 := a2 * 19
a3_19 := a3 * 19
a4_19 := a4 * 19
// r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1)
r0 := mul64(a0, b0)
r0 = addMul64(r0, a1_19, b4)
r0 = addMul64(r0, a2_19, b3)
r0 = addMul64(r0, a3_19, b2)
r0 = addMul64(r0, a4_19, b1)
// r1 = a0×b1 + a1×b0 + 19×(a2×b4 + a3×b3 + a4×b2)
r1 := mul64(a0, b1)
r1 = addMul64(r1, a1, b0)
r1 = addMul64(r1, a2_19, b4)
r1 = addMul64(r1, a3_19, b3)
r1 = addMul64(r1, a4_19, b2)
// r2 = a0×b2 + a1×b1 + a2×b0 + 19×(a3×b4 + a4×b3)
r2 := mul64(a0, b2)
r2 = addMul64(r2, a1, b1)
r2 = addMul64(r2, a2, b0)
r2 = addMul64(r2, a3_19, b4)
r2 = addMul64(r2, a4_19, b3)
// r3 = a0×b3 + a1×b2 + a2×b1 + a3×b0 + 19×a4×b4
r3 := mul64(a0, b3)
r3 = addMul64(r3, a1, b2)
r3 = addMul64(r3, a2, b1)
r3 = addMul64(r3, a3, b0)
r3 = addMul64(r3, a4_19, b4)
// r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0
r4 := mul64(a0, b4)
r4 = addMul64(r4, a1, b3)
r4 = addMul64(r4, a2, b2)
r4 = addMul64(r4, a3, b1)
r4 = addMul64(r4, a4, b0)
// After the multiplication, we need to reduce (carry) the five coefficients
// to obtain a result with limbs that are at most slightly larger than 2⁵¹,
// to respect the Element invariant.
//
// Overall, the reduction works the same as carryPropagate, except with
// wider inputs: we take the carry for each coefficient by shifting it right
// by 51, and add it to the limb above it. The top carry is multiplied by 19
// according to the reduction identity and added to the lowest limb.
//
// The largest coefficient (r0) will be at most 111 bits, which guarantees
// that all carries are at most 111 - 51 = 60 bits, which fits in a uint64.
//
// r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1)
// r0 < 2⁵²×2⁵² + 19×(2⁵²×2⁵² + 2⁵²×2⁵² + 2⁵²×2⁵² + 2⁵²×2⁵²)
// r0 < (1 + 19 × 4) × 2⁵² × 2⁵²
// r0 < 2⁷ × 2⁵² × 2⁵²
// r0 < 2¹¹¹
//
// Moreover, the top coefficient (r4) is at most 107 bits, so c4 is at most
// 56 bits, and c4 * 19 is at most 61 bits, which again fits in a uint64 and
// allows us to easily apply the reduction identity.
//
// r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0
// r4 < 5 × 2⁵² × 2⁵²
// r4 < 2¹⁰⁷
//
c0 := shiftRightBy51(r0)
c1 := shiftRightBy51(r1)
c2 := shiftRightBy51(r2)
c3 := shiftRightBy51(r3)
c4 := shiftRightBy51(r4)
rr0 := r0.lo&maskLow51Bits + c4*19
rr1 := r1.lo&maskLow51Bits + c0
rr2 := r2.lo&maskLow51Bits + c1
rr3 := r3.lo&maskLow51Bits + c2
rr4 := r4.lo&maskLow51Bits + c3
// Now all coefficients fit into 64-bit registers but are still too large to
// be passed around as an Element. We therefore do one last carry chain,
// where the carries will be small enough to fit in the wiggle room above 2⁵¹.
*v = Element{rr0, rr1, rr2, rr3, rr4}
v.carryPropagate()
}
func feSquareGeneric(v, a *Element) {
l0 := a.l0
l1 := a.l1
l2 := a.l2
l3 := a.l3
l4 := a.l4
// Squaring works precisely like multiplication above, but thanks to its
// symmetry we get to group a few terms together.
//
// l4 l3 l2 l1 l0 x
// l4 l3 l2 l1 l0 =
// ------------------------
// l4l0 l3l0 l2l0 l1l0 l0l0 +
// l4l1 l3l1 l2l1 l1l1 l0l1 +
// l4l2 l3l2 l2l2 l1l2 l0l2 +
// l4l3 l3l3 l2l3 l1l3 l0l3 +
// l4l4 l3l4 l2l4 l1l4 l0l4 =
// ----------------------------------------------
// r8 r7 r6 r5 r4 r3 r2 r1 r0
//
// l4l0 l3l0 l2l0 l1l0 l0l0 +
// l3l1 l2l1 l1l1 l0l1 19×l4l1 +
// l2l2 l1l2 l0l2 19×l4l2 19×l3l2 +
// l1l3 l0l3 19×l4l3 19×l3l3 19×l2l3 +
// l0l4 19×l4l4 19×l3l4 19×l2l4 19×l1l4 =
// --------------------------------------
// r4 r3 r2 r1 r0
//
// With precomputed 2×, 19×, and 2×19× terms, we can compute each limb with
// only three Mul64 and four Add64, instead of five and eight.
l0_2 := l0 * 2
l1_2 := l1 * 2
l1_38 := l1 * 38
l2_38 := l2 * 38
l3_38 := l3 * 38
l3_19 := l3 * 19
l4_19 := l4 * 19
// r0 = l0×l0 + 19×(l1×l4 + l2×l3 + l3×l2 + l4×l1) = l0×l0 + 19×2×(l1×l4 + l2×l3)
r0 := mul64(l0, l0)
r0 = addMul64(r0, l1_38, l4)
r0 = addMul64(r0, l2_38, l3)
// r1 = l0×l1 + l1×l0 + 19×(l2×l4 + l3×l3 + l4×l2) = 2×l0×l1 + 19×2×l2×l4 + 19×l3×l3
r1 := mul64(l0_2, l1)
r1 = addMul64(r1, l2_38, l4)
r1 = addMul64(r1, l3_19, l3)
// r2 = l0×l2 + l1×l1 + l2×l0 + 19×(l3×l4 + l4×l3) = 2×l0×l2 + l1×l1 + 19×2×l3×l4
r2 := mul64(l0_2, l2)
r2 = addMul64(r2, l1, l1)
r2 = addMul64(r2, l3_38, l4)
// r3 = l0×l3 + l1×l2 + l2×l1 + l3×l0 + 19×l4×l4 = 2×l0×l3 + 2×l1×l2 + 19×l4×l4
r3 := mul64(l0_2, l3)
r3 = addMul64(r3, l1_2, l2)
r3 = addMul64(r3, l4_19, l4)
// r4 = l0×l4 + l1×l3 + l2×l2 + l3×l1 + l4×l0 = 2×l0×l4 + 2×l1×l3 + l2×l2
r4 := mul64(l0_2, l4)
r4 = addMul64(r4, l1_2, l3)
r4 = addMul64(r4, l2, l2)
c0 := shiftRightBy51(r0)
c1 := shiftRightBy51(r1)
c2 := shiftRightBy51(r2)
c3 := shiftRightBy51(r3)
c4 := shiftRightBy51(r4)
rr0 := r0.lo&maskLow51Bits + c4*19
rr1 := r1.lo&maskLow51Bits + c0
rr2 := r2.lo&maskLow51Bits + c1
rr3 := r3.lo&maskLow51Bits + c2
rr4 := r4.lo&maskLow51Bits + c3
*v = Element{rr0, rr1, rr2, rr3, rr4}
v.carryPropagate()
}
// carryPropagateGeneric brings the limbs below 52 bits by applying the reduction
// identity (a * 2²⁵⁵ + b = a * 19 + b) to the l4 carry.
func (v *Element) carryPropagateGeneric() *Element {
c0 := v.l0 >> 51
c1 := v.l1 >> 51
c2 := v.l2 >> 51
c3 := v.l3 >> 51
c4 := v.l4 >> 51
// c4 is at most 64 - 51 = 13 bits, so c4*19 is at most 18 bits, and
// the final l0 will be at most 52 bits. Similarly for the rest.
v.l0 = v.l0&maskLow51Bits + c4*19
v.l1 = v.l1&maskLow51Bits + c0
v.l2 = v.l2&maskLow51Bits + c1
v.l3 = v.l3&maskLow51Bits + c2
v.l4 = v.l4&maskLow51Bits + c3
return v
}
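// A test-style sketch of the reduction identity used above (it assumes
// math/big and would live in a _test.go file): folding a carry a from above
// 2²⁵⁵ back in with weight 19 preserves the value modulo p = 2²⁵⁵ - 19.
func checkReductionIdentity(a, b int64) bool {
	p := new(big.Int).Lsh(big.NewInt(1), 255)
	p.Sub(p, big.NewInt(19)) // p = 2²⁵⁵ - 19
	lhs := new(big.Int).Lsh(big.NewInt(a), 255)
	lhs.Add(lhs, big.NewInt(b)) // a * 2²⁵⁵ + b
	rhs := new(big.Int).Mul(big.NewInt(a), big.NewInt(19))
	rhs.Add(rhs, big.NewInt(b)) // a * 19 + b
	lhs.Mod(lhs, p)
	rhs.Mod(rhs, p)
	return lhs.Cmp(rhs) == 0
}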
// Copyright (c) 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import (
"encoding/binary"
"errors"
)
// A Scalar is an integer modulo
//
// l = 2^252 + 27742317777372353535851937790883648493
//
// which is the prime order of the edwards25519 group.
//
// This type works similarly to math/big.Int, and all arguments and
// receivers are allowed to alias.
//
// The zero value is a valid zero element.
type Scalar struct {
// s is the scalar in the Montgomery domain, in the format of the
// fiat-crypto implementation.
s fiatScalarMontgomeryDomainFieldElement
}
// The field implementation in scalar_fiat.go is generated by the fiat-crypto
// project (https://github.com/mit-plv/fiat-crypto) at version v0.0.9 (23d2dbc)
// from a formally verified model.
//
// fiat-crypto code comes under the following license.
//
// Copyright (c) 2015-2020 The fiat-crypto Authors. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// THIS SOFTWARE IS PROVIDED BY the fiat-crypto authors "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL Berkeley Software Design,
// Inc. BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
// NewScalar returns a new zero Scalar.
func NewScalar() *Scalar {
return &Scalar{}
}
// MultiplyAdd sets s = x * y + z mod l, and returns s. It is equivalent to
// using Multiply and then Add.
func (s *Scalar) MultiplyAdd(x, y, z *Scalar) *Scalar {
// Make a copy of z in case it aliases s.
zCopy := new(Scalar).Set(z)
return s.Multiply(x, y).Add(s, zCopy)
}
// Add sets s = x + y mod l, and returns s.
func (s *Scalar) Add(x, y *Scalar) *Scalar {
// s = 1 * x + y mod l
fiatScalarAdd(&s.s, &x.s, &y.s)
return s
}
// Subtract sets s = x - y mod l, and returns s.
func (s *Scalar) Subtract(x, y *Scalar) *Scalar {
// s = -1 * y + x mod l
fiatScalarSub(&s.s, &x.s, &y.s)
return s
}
// Negate sets s = -x mod l, and returns s.
func (s *Scalar) Negate(x *Scalar) *Scalar {
// s = -1 * x + 0 mod l
fiatScalarOpp(&s.s, &x.s)
return s
}
// Multiply sets s = x * y mod l, and returns s.
func (s *Scalar) Multiply(x, y *Scalar) *Scalar {
// s = x * y + 0 mod l
fiatScalarMul(&s.s, &x.s, &y.s)
return s
}
// Set sets s = x, and returns s.
func (s *Scalar) Set(x *Scalar) *Scalar {
*s = *x
return s
}
// SetUniformBytes sets s = x mod l, where x is a 64-byte little-endian integer.
// If x is not of the right length, SetUniformBytes returns nil and an error,
// and the receiver is unchanged.
//
// SetUniformBytes can be used to set s to a uniformly distributed value given
// 64 uniformly distributed random bytes.
func (s *Scalar) SetUniformBytes(x []byte) (*Scalar, error) {
if len(x) != 64 {
return nil, errors.New("edwards25519: invalid SetUniformBytes input length")
}
// We have a value x of 512 bits, but our fiatScalarFromBytes function
// expects an input lower than l, which is a little over 252 bits.
//
// Instead of writing a reduction function that operates on wider inputs, we
// can interpret x as the sum of three shorter values a, b, and c.
//
// x = a + b * 2^168 + c * 2^336 mod l
//
// We then precompute 2^168 and 2^336 modulo l, and perform the reduction
// with two multiplications and two additions.
s.setShortBytes(x[:21])
t := new(Scalar).setShortBytes(x[21:42])
s.Add(s, t.Multiply(t, scalarTwo168))
t.setShortBytes(x[42:])
s.Add(s, t.Multiply(t, scalarTwo336))
return s, nil
}
// scalarTwo168 and scalarTwo336 are 2^168 and 2^336 modulo l, encoded as a
// fiatScalarMontgomeryDomainFieldElement, which is a little-endian 4-limb value
// in the 2^256 Montgomery domain.
var scalarTwo168 = &Scalar{s: [4]uint64{0x5b8ab432eac74798, 0x38afddd6de59d5d7,
0xa2c131b399411b7c, 0x6329a7ed9ce5a30}}
var scalarTwo336 = &Scalar{s: [4]uint64{0xbd3d108e2b35ecc5, 0x5c3a3718bdf9c90b,
0x63aa97a331b4f2ee, 0x3d217f5be65cb5c}}
// setShortBytes sets s = x mod l, where x is a little-endian integer shorter
// than 32 bytes.
func (s *Scalar) setShortBytes(x []byte) *Scalar {
if len(x) >= 32 {
panic("edwards25519: internal error: setShortBytes called with a long string")
}
var buf [32]byte
copy(buf[:], x)
fiatScalarFromBytes((*[4]uint64)(&s.s), &buf)
fiatScalarToMontgomery(&s.s, (*fiatScalarNonMontgomeryDomainFieldElement)(&s.s))
return s
}
// SetCanonicalBytes sets s = x, where x is a 32-byte little-endian encoding of
// s, and returns s. If x is not a canonical encoding of s, SetCanonicalBytes
// returns nil and an error, and the receiver is unchanged.
func (s *Scalar) SetCanonicalBytes(x []byte) (*Scalar, error) {
if len(x) != 32 {
return nil, errors.New("invalid scalar length")
}
if !isReduced(x) {
return nil, errors.New("invalid scalar encoding")
}
fiatScalarFromBytes((*[4]uint64)(&s.s), (*[32]byte)(x))
fiatScalarToMontgomery(&s.s, (*fiatScalarNonMontgomeryDomainFieldElement)(&s.s))
return s, nil
}
// scalarMinusOneBytes is l - 1 in little endian.
var scalarMinusOneBytes = [32]byte{236, 211, 245, 92, 26, 99, 18, 88, 214, 156, 247, 162, 222, 249, 222, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16}
// isReduced returns whether the given scalar in 32-byte little endian encoded
// form is reduced modulo l.
func isReduced(s []byte) bool {
if len(s) != 32 {
return false
}
for i := len(s) - 1; i >= 0; i-- {
switch {
case s[i] > scalarMinusOneBytes[i]:
return false
case s[i] < scalarMinusOneBytes[i]:
return true
}
}
return true
}
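// Equivalence sketch (assuming math/big): isReduced(b) holds exactly when the
// little-endian value of b is strictly below l.
func checkIsReduced(b [32]byte) bool {
	l := new(big.Int).Lsh(big.NewInt(1), 252)
	c, _ := new(big.Int).SetString("27742317777372353535851937790883648493", 10)
	l.Add(l, c) // l = 2^252 + 27742317777372353535851937790883648493
	be := make([]byte, 32)
	for i := range b {
		be[31-i] = b[i] // reverse to big-endian
	}
	return isReduced(b[:]) == (new(big.Int).SetBytes(be).Cmp(l) < 0)
}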
// SetBytesWithClamping applies the buffer pruning described in RFC 8032,
// Section 5.1.5 (also known as clamping) and sets s to the result. The input
// must be 32 bytes, and it is not modified. If x is not of the right length,
// SetBytesWithClamping returns nil and an error, and the receiver is unchanged.
//
// Note that since Scalar values are always reduced modulo the prime order of
// the curve, the resulting value will not preserve any of the cofactor-clearing
// properties that clamping is meant to provide. It will however work as
// expected as long as it is applied to points on the prime order subgroup, like
// in Ed25519. In fact, it is lost to history why RFC 8032 adopted the
// irrelevant RFC 7748 clamping, but it is now required for compatibility.
func (s *Scalar) SetBytesWithClamping(x []byte) (*Scalar, error) {
// The description above omits the purpose of the high bits of the clamping
// for brevity, but those are also lost to reductions, and are also
// irrelevant to edwards25519 as they protect against a specific
// implementation bug that was once observed in a generic Montgomery ladder.
if len(x) != 32 {
return nil, errors.New("edwards25519: invalid SetBytesWithClamping input length")
}
// We need to use the wide reduction from SetUniformBytes, since clamping
// sets the 2^254 bit, making the value higher than the order.
var wideBytes [64]byte
copy(wideBytes[:], x[:])
wideBytes[0] &= 248
wideBytes[31] &= 63
wideBytes[31] |= 64
return s.SetUniformBytes(wideBytes[:])
}
// Bytes returns the canonical 32-byte little-endian encoding of s.
func (s *Scalar) Bytes() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var encoded [32]byte
return s.bytes(&encoded)
}
func (s *Scalar) bytes(out *[32]byte) []byte {
var ss fiatScalarNonMontgomeryDomainFieldElement
fiatScalarFromMontgomery(&ss, &s.s)
fiatScalarToBytes(out, (*[4]uint64)(&ss))
return out[:]
}
// Equal returns 1 if s and t are equal, and 0 otherwise.
func (s *Scalar) Equal(t *Scalar) int {
var diff fiatScalarMontgomeryDomainFieldElement
fiatScalarSub(&diff, &s.s, &t.s)
var nonzero uint64
fiatScalarNonzero(&nonzero, (*[4]uint64)(&diff))
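// Fold the OR of all bits of nonzero down into bit 0, so that the low bit
// ends up 1 exactly when some limb of the difference was nonzero.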
nonzero |= nonzero >> 32
nonzero |= nonzero >> 16
nonzero |= nonzero >> 8
nonzero |= nonzero >> 4
nonzero |= nonzero >> 2
nonzero |= nonzero >> 1
return int(^nonzero) & 1
}
// nonAdjacentForm computes a width-w non-adjacent form for this scalar.
//
// w must be between 2 and 8, or nonAdjacentForm will panic.
func (s *Scalar) nonAdjacentForm(w uint) [256]int8 {
// This implementation is adapted from the one
// in curve25519-dalek and is documented there:
// https://github.com/dalek-cryptography/curve25519-dalek/blob/f630041af28e9a405255f98a8a93adca18e4315b/src/scalar.rs#L800-L871
b := s.Bytes()
if b[31] > 127 {
panic("scalar has high bit set illegally")
}
if w < 2 {
panic("w must be at least 2 by the definition of NAF")
} else if w > 8 {
panic("NAF digits must fit in int8")
}
var naf [256]int8
var digits [5]uint64
for i := 0; i < 4; i++ {
digits[i] = binary.LittleEndian.Uint64(b[i*8:])
}
width := uint64(1 << w)
windowMask := uint64(width - 1)
pos := uint(0)
carry := uint64(0)
for pos < 256 {
indexU64 := pos / 64
indexBit := pos % 64
var bitBuf uint64
if indexBit < 64-w {
// This window's bits are contained in a single u64
bitBuf = digits[indexU64] >> indexBit
} else {
// Combine the current 64 bits with bits from the next 64
bitBuf = (digits[indexU64] >> indexBit) | (digits[1+indexU64] << (64 - indexBit))
}
// Add carry into the current window
window := carry + (bitBuf & windowMask)
if window&1 == 0 {
// If the window value is even, preserve the carry and continue.
// Why is the carry preserved?
// If carry == 0 and window & 1 == 0,
// then the next carry should be 0
// If carry == 1 and window & 1 == 0,
// then bitBuf & 1 == 1, so the next carry should be 1
pos += 1
continue
}
if window < width/2 {
carry = 0
naf[pos] = int8(window)
} else {
carry = 1
naf[pos] = int8(window) - int8(width)
}
pos += w
}
return naf
}
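// Test-style sketch (assuming math/big): the NAF digits reconstruct the
// scalar, i.e. sum(naf[i] * 2^i) == s, with every nonzero digit odd and of
// magnitude below 2^(w-1).
func checkNAFReconstruction(s *Scalar, w uint) bool {
	naf := s.nonAdjacentForm(w)
	sum := new(big.Int)
	for i := 255; i >= 0; i-- { // Horner's rule from the top digit down
		sum.Lsh(sum, 1)
		sum.Add(sum, big.NewInt(int64(naf[i])))
	}
	le := s.Bytes()
	for i, j := 0, len(le)-1; i < j; i, j = i+1, j-1 {
		le[i], le[j] = le[j], le[i] // little-endian → big-endian
	}
	return sum.Cmp(new(big.Int).SetBytes(le)) == 0
}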
func (s *Scalar) signedRadix16() [64]int8 {
b := s.Bytes()
if b[31] > 127 {
panic("scalar has high bit set illegally")
}
var digits [64]int8
// Compute unsigned radix-16 digits:
for i := 0; i < 32; i++ {
digits[2*i] = int8(b[i] & 15)
digits[2*i+1] = int8((b[i] >> 4) & 15)
}
// Recenter coefficients:
for i := 0; i < 63; i++ {
carry := (digits[i] + 8) >> 4
digits[i] -= carry << 4
digits[i+1] += carry
}
return digits
}
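// Companion sketch (assuming math/big): the signed radix-16 digits
// reconstruct the scalar, sum(digits[i] * 16^i) == s, with each digit in
// [-8, 8] (only the top digit can reach 8, since it is never recentered).
func checkRadix16Reconstruction(s *Scalar) bool {
	digits := s.signedRadix16()
	sum := new(big.Int)
	for i := 63; i >= 0; i-- {
		sum.Lsh(sum, 4) // multiply by 16
		sum.Add(sum, big.NewInt(int64(digits[i])))
	}
	le := s.Bytes()
	for i, j := 0, len(le)-1; i < j; i, j = i+1, j-1 {
		le[i], le[j] = le[j], le[i] // little-endian → big-endian
	}
	return sum.Cmp(new(big.Int).SetBytes(le)) == 0
}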
// Code generated by Fiat Cryptography. DO NOT EDIT.
//
// Autogenerated: word_by_word_montgomery --lang Go --cmovznz-by-mul --relax-primitive-carry-to-bitwidth 32,64 --public-function-case camelCase --public-type-case camelCase --private-function-case camelCase --private-type-case camelCase --doc-text-before-function-name '' --doc-newline-before-package-declaration --doc-prepend-header 'Code generated by Fiat Cryptography. DO NOT EDIT.' --package-name edwards25519 Scalar 64 '2^252 + 27742317777372353535851937790883648493' mul add sub opp nonzero from_montgomery to_montgomery to_bytes from_bytes
//
// curve description: Scalar
//
// machine_wordsize = 64 (from "64")
//
// requested operations: mul, add, sub, opp, nonzero, from_montgomery, to_montgomery, to_bytes, from_bytes
//
// m = 0x1000000000000000000000000000000014def9dea2f79cd65812631a5cf5d3ed (from "2^252 + 27742317777372353535851937790883648493")
//
//
//
// NOTE: In addition to the bounds specified above each function, all
//
// functions synthesized for this Montgomery arithmetic require the
//
// input to be strictly less than the prime modulus (m), and also
//
// require the input to be in the unique saturated representation.
//
// All functions also ensure that these two properties are true of
//
// return values.
//
//
//
// Computed values:
//
// eval z = z[0] + (z[1] << 64) + (z[2] << 128) + (z[3] << 192)
//
// bytes_eval z = z[0] + (z[1] << 8) + (z[2] << 16) + (z[3] << 24) + (z[4] << 32) + (z[5] << 40) + (z[6] << 48) + (z[7] << 56) + (z[8] << 64) + (z[9] << 72) + (z[10] << 80) + (z[11] << 88) + (z[12] << 96) + (z[13] << 104) + (z[14] << 112) + (z[15] << 120) + (z[16] << 128) + (z[17] << 136) + (z[18] << 144) + (z[19] << 152) + (z[20] << 160) + (z[21] << 168) + (z[22] << 176) + (z[23] << 184) + (z[24] << 192) + (z[25] << 200) + (z[26] << 208) + (z[27] << 216) + (z[28] << 224) + (z[29] << 232) + (z[30] << 240) + (z[31] << 248)
//
// twos_complement_eval z = let x1 := z[0] + (z[1] << 64) + (z[2] << 128) + (z[3] << 192) in
//
// if x1 & (2^256-1) < 2^255 then x1 & (2^256-1) else (x1 & (2^256-1)) - 2^256
package edwards25519
import "math/bits"
type fiatScalarUint1 uint64 // We use uint64 instead of a more narrow type for performance reasons; see https://github.com/mit-plv/fiat-crypto/pull/1006#issuecomment-892625927
type fiatScalarInt1 int64 // We use int64 instead of a more narrow type for performance reasons; see https://github.com/mit-plv/fiat-crypto/pull/1006#issuecomment-892625927
// The type fiatScalarMontgomeryDomainFieldElement is a field element in the Montgomery domain.
//
// Bounds: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff]]
type fiatScalarMontgomeryDomainFieldElement [4]uint64
// The type fiatScalarNonMontgomeryDomainFieldElement is a field element NOT in the Montgomery domain.
//
// Bounds: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff]]
type fiatScalarNonMontgomeryDomainFieldElement [4]uint64
// fiatScalarCmovznzU64 is a single-word conditional move.
//
// Postconditions:
//
// out1 = (if arg1 = 0 then arg2 else arg3)
//
// Input Bounds:
//
// arg1: [0x0 ~> 0x1]
// arg2: [0x0 ~> 0xffffffffffffffff]
// arg3: [0x0 ~> 0xffffffffffffffff]
//
// Output Bounds:
//
// out1: [0x0 ~> 0xffffffffffffffff]
func fiatScalarCmovznzU64(out1 *uint64, arg1 fiatScalarUint1, arg2 uint64, arg3 uint64) {
x1 := (uint64(arg1) * 0xffffffffffffffff)
x2 := ((x1 & arg3) | ((^x1) & arg2))
*out1 = x2
}
// fiatScalarMul multiplies two field elements in the Montgomery domain.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
// 0 ≤ eval arg2 < m
//
// Postconditions:
//
// eval (from_montgomery out1) mod m = (eval (from_montgomery arg1) * eval (from_montgomery arg2)) mod m
// 0 ≤ eval out1 < m
func fiatScalarMul(out1 *fiatScalarMontgomeryDomainFieldElement, arg1 *fiatScalarMontgomeryDomainFieldElement, arg2 *fiatScalarMontgomeryDomainFieldElement) {
x1 := arg1[1]
x2 := arg1[2]
x3 := arg1[3]
x4 := arg1[0]
var x5 uint64
var x6 uint64
x6, x5 = bits.Mul64(x4, arg2[3])
var x7 uint64
var x8 uint64
x8, x7 = bits.Mul64(x4, arg2[2])
var x9 uint64
var x10 uint64
x10, x9 = bits.Mul64(x4, arg2[1])
var x11 uint64
var x12 uint64
x12, x11 = bits.Mul64(x4, arg2[0])
var x13 uint64
var x14 uint64
x13, x14 = bits.Add64(x12, x9, uint64(0x0))
var x15 uint64
var x16 uint64
x15, x16 = bits.Add64(x10, x7, uint64(fiatScalarUint1(x14)))
var x17 uint64
var x18 uint64
x17, x18 = bits.Add64(x8, x5, uint64(fiatScalarUint1(x16)))
x19 := (uint64(fiatScalarUint1(x18)) + x6)
var x20 uint64
_, x20 = bits.Mul64(x11, 0xd2b51da312547e1b)
var x22 uint64
var x23 uint64
x23, x22 = bits.Mul64(x20, 0x1000000000000000)
var x24 uint64
var x25 uint64
x25, x24 = bits.Mul64(x20, 0x14def9dea2f79cd6)
var x26 uint64
var x27 uint64
x27, x26 = bits.Mul64(x20, 0x5812631a5cf5d3ed)
var x28 uint64
var x29 uint64
x28, x29 = bits.Add64(x27, x24, uint64(0x0))
x30 := (uint64(fiatScalarUint1(x29)) + x25)
var x32 uint64
_, x32 = bits.Add64(x11, x26, uint64(0x0))
var x33 uint64
var x34 uint64
x33, x34 = bits.Add64(x13, x28, uint64(fiatScalarUint1(x32)))
var x35 uint64
var x36 uint64
x35, x36 = bits.Add64(x15, x30, uint64(fiatScalarUint1(x34)))
var x37 uint64
var x38 uint64
x37, x38 = bits.Add64(x17, x22, uint64(fiatScalarUint1(x36)))
var x39 uint64
var x40 uint64
x39, x40 = bits.Add64(x19, x23, uint64(fiatScalarUint1(x38)))
var x41 uint64
var x42 uint64
x42, x41 = bits.Mul64(x1, arg2[3])
var x43 uint64
var x44 uint64
x44, x43 = bits.Mul64(x1, arg2[2])
var x45 uint64
var x46 uint64
x46, x45 = bits.Mul64(x1, arg2[1])
var x47 uint64
var x48 uint64
x48, x47 = bits.Mul64(x1, arg2[0])
var x49 uint64
var x50 uint64
x49, x50 = bits.Add64(x48, x45, uint64(0x0))
var x51 uint64
var x52 uint64
x51, x52 = bits.Add64(x46, x43, uint64(fiatScalarUint1(x50)))
var x53 uint64
var x54 uint64
x53, x54 = bits.Add64(x44, x41, uint64(fiatScalarUint1(x52)))
x55 := (uint64(fiatScalarUint1(x54)) + x42)
var x56 uint64
var x57 uint64
x56, x57 = bits.Add64(x33, x47, uint64(0x0))
var x58 uint64
var x59 uint64
x58, x59 = bits.Add64(x35, x49, uint64(fiatScalarUint1(x57)))
var x60 uint64
var x61 uint64
x60, x61 = bits.Add64(x37, x51, uint64(fiatScalarUint1(x59)))
var x62 uint64
var x63 uint64
x62, x63 = bits.Add64(x39, x53, uint64(fiatScalarUint1(x61)))
var x64 uint64
var x65 uint64
x64, x65 = bits.Add64(uint64(fiatScalarUint1(x40)), x55, uint64(fiatScalarUint1(x63)))
var x66 uint64
_, x66 = bits.Mul64(x56, 0xd2b51da312547e1b)
var x68 uint64
var x69 uint64
x69, x68 = bits.Mul64(x66, 0x1000000000000000)
var x70 uint64
var x71 uint64
x71, x70 = bits.Mul64(x66, 0x14def9dea2f79cd6)
var x72 uint64
var x73 uint64
x73, x72 = bits.Mul64(x66, 0x5812631a5cf5d3ed)
var x74 uint64
var x75 uint64
x74, x75 = bits.Add64(x73, x70, uint64(0x0))
x76 := (uint64(fiatScalarUint1(x75)) + x71)
var x78 uint64
_, x78 = bits.Add64(x56, x72, uint64(0x0))
var x79 uint64
var x80 uint64
x79, x80 = bits.Add64(x58, x74, uint64(fiatScalarUint1(x78)))
var x81 uint64
var x82 uint64
x81, x82 = bits.Add64(x60, x76, uint64(fiatScalarUint1(x80)))
var x83 uint64
var x84 uint64
x83, x84 = bits.Add64(x62, x68, uint64(fiatScalarUint1(x82)))
var x85 uint64
var x86 uint64
x85, x86 = bits.Add64(x64, x69, uint64(fiatScalarUint1(x84)))
x87 := (uint64(fiatScalarUint1(x86)) + uint64(fiatScalarUint1(x65)))
var x88 uint64
var x89 uint64
x89, x88 = bits.Mul64(x2, arg2[3])
var x90 uint64
var x91 uint64
x91, x90 = bits.Mul64(x2, arg2[2])
var x92 uint64
var x93 uint64
x93, x92 = bits.Mul64(x2, arg2[1])
var x94 uint64
var x95 uint64
x95, x94 = bits.Mul64(x2, arg2[0])
var x96 uint64
var x97 uint64
x96, x97 = bits.Add64(x95, x92, uint64(0x0))
var x98 uint64
var x99 uint64
x98, x99 = bits.Add64(x93, x90, uint64(fiatScalarUint1(x97)))
var x100 uint64
var x101 uint64
x100, x101 = bits.Add64(x91, x88, uint64(fiatScalarUint1(x99)))
x102 := (uint64(fiatScalarUint1(x101)) + x89)
var x103 uint64
var x104 uint64
x103, x104 = bits.Add64(x79, x94, uint64(0x0))
var x105 uint64
var x106 uint64
x105, x106 = bits.Add64(x81, x96, uint64(fiatScalarUint1(x104)))
var x107 uint64
var x108 uint64
x107, x108 = bits.Add64(x83, x98, uint64(fiatScalarUint1(x106)))
var x109 uint64
var x110 uint64
x109, x110 = bits.Add64(x85, x100, uint64(fiatScalarUint1(x108)))
var x111 uint64
var x112 uint64
x111, x112 = bits.Add64(x87, x102, uint64(fiatScalarUint1(x110)))
var x113 uint64
_, x113 = bits.Mul64(x103, 0xd2b51da312547e1b)
var x115 uint64
var x116 uint64
x116, x115 = bits.Mul64(x113, 0x1000000000000000)
var x117 uint64
var x118 uint64
x118, x117 = bits.Mul64(x113, 0x14def9dea2f79cd6)
var x119 uint64
var x120 uint64
x120, x119 = bits.Mul64(x113, 0x5812631a5cf5d3ed)
var x121 uint64
var x122 uint64
x121, x122 = bits.Add64(x120, x117, uint64(0x0))
x123 := (uint64(fiatScalarUint1(x122)) + x118)
var x125 uint64
_, x125 = bits.Add64(x103, x119, uint64(0x0))
var x126 uint64
var x127 uint64
x126, x127 = bits.Add64(x105, x121, uint64(fiatScalarUint1(x125)))
var x128 uint64
var x129 uint64
x128, x129 = bits.Add64(x107, x123, uint64(fiatScalarUint1(x127)))
var x130 uint64
var x131 uint64
x130, x131 = bits.Add64(x109, x115, uint64(fiatScalarUint1(x129)))
var x132 uint64
var x133 uint64
x132, x133 = bits.Add64(x111, x116, uint64(fiatScalarUint1(x131)))
x134 := (uint64(fiatScalarUint1(x133)) + uint64(fiatScalarUint1(x112)))
var x135 uint64
var x136 uint64
x136, x135 = bits.Mul64(x3, arg2[3])
var x137 uint64
var x138 uint64
x138, x137 = bits.Mul64(x3, arg2[2])
var x139 uint64
var x140 uint64
x140, x139 = bits.Mul64(x3, arg2[1])
var x141 uint64
var x142 uint64
x142, x141 = bits.Mul64(x3, arg2[0])
var x143 uint64
var x144 uint64
x143, x144 = bits.Add64(x142, x139, uint64(0x0))
var x145 uint64
var x146 uint64
x145, x146 = bits.Add64(x140, x137, uint64(fiatScalarUint1(x144)))
var x147 uint64
var x148 uint64
x147, x148 = bits.Add64(x138, x135, uint64(fiatScalarUint1(x146)))
x149 := (uint64(fiatScalarUint1(x148)) + x136)
var x150 uint64
var x151 uint64
x150, x151 = bits.Add64(x126, x141, uint64(0x0))
var x152 uint64
var x153 uint64
x152, x153 = bits.Add64(x128, x143, uint64(fiatScalarUint1(x151)))
var x154 uint64
var x155 uint64
x154, x155 = bits.Add64(x130, x145, uint64(fiatScalarUint1(x153)))
var x156 uint64
var x157 uint64
x156, x157 = bits.Add64(x132, x147, uint64(fiatScalarUint1(x155)))
var x158 uint64
var x159 uint64
x158, x159 = bits.Add64(x134, x149, uint64(fiatScalarUint1(x157)))
var x160 uint64
_, x160 = bits.Mul64(x150, 0xd2b51da312547e1b)
var x162 uint64
var x163 uint64
x163, x162 = bits.Mul64(x160, 0x1000000000000000)
var x164 uint64
var x165 uint64
x165, x164 = bits.Mul64(x160, 0x14def9dea2f79cd6)
var x166 uint64
var x167 uint64
x167, x166 = bits.Mul64(x160, 0x5812631a5cf5d3ed)
var x168 uint64
var x169 uint64
x168, x169 = bits.Add64(x167, x164, uint64(0x0))
x170 := (uint64(fiatScalarUint1(x169)) + x165)
var x172 uint64
_, x172 = bits.Add64(x150, x166, uint64(0x0))
var x173 uint64
var x174 uint64
x173, x174 = bits.Add64(x152, x168, uint64(fiatScalarUint1(x172)))
var x175 uint64
var x176 uint64
x175, x176 = bits.Add64(x154, x170, uint64(fiatScalarUint1(x174)))
var x177 uint64
var x178 uint64
x177, x178 = bits.Add64(x156, x162, uint64(fiatScalarUint1(x176)))
var x179 uint64
var x180 uint64
x179, x180 = bits.Add64(x158, x163, uint64(fiatScalarUint1(x178)))
x181 := (uint64(fiatScalarUint1(x180)) + uint64(fiatScalarUint1(x159)))
var x182 uint64
var x183 uint64
x182, x183 = bits.Sub64(x173, 0x5812631a5cf5d3ed, uint64(0x0))
var x184 uint64
var x185 uint64
x184, x185 = bits.Sub64(x175, 0x14def9dea2f79cd6, uint64(fiatScalarUint1(x183)))
var x186 uint64
var x187 uint64
x186, x187 = bits.Sub64(x177, uint64(0x0), uint64(fiatScalarUint1(x185)))
var x188 uint64
var x189 uint64
x188, x189 = bits.Sub64(x179, 0x1000000000000000, uint64(fiatScalarUint1(x187)))
var x191 uint64
_, x191 = bits.Sub64(x181, uint64(0x0), uint64(fiatScalarUint1(x189)))
var x192 uint64
fiatScalarCmovznzU64(&x192, fiatScalarUint1(x191), x182, x173)
var x193 uint64
fiatScalarCmovznzU64(&x193, fiatScalarUint1(x191), x184, x175)
var x194 uint64
fiatScalarCmovznzU64(&x194, fiatScalarUint1(x191), x186, x177)
var x195 uint64
fiatScalarCmovznzU64(&x195, fiatScalarUint1(x191), x188, x179)
out1[0] = x192
out1[1] = x193
out1[2] = x194
out1[3] = x195
}
// fiatScalarAdd adds two field elements in the Montgomery domain.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
// 0 ≤ eval arg2 < m
//
// Postconditions:
//
// eval (from_montgomery out1) mod m = (eval (from_montgomery arg1) + eval (from_montgomery arg2)) mod m
// 0 ≤ eval out1 < m
func fiatScalarAdd(out1 *fiatScalarMontgomeryDomainFieldElement, arg1 *fiatScalarMontgomeryDomainFieldElement, arg2 *fiatScalarMontgomeryDomainFieldElement) {
var x1 uint64
var x2 uint64
x1, x2 = bits.Add64(arg1[0], arg2[0], uint64(0x0))
var x3 uint64
var x4 uint64
x3, x4 = bits.Add64(arg1[1], arg2[1], uint64(fiatScalarUint1(x2)))
var x5 uint64
var x6 uint64
x5, x6 = bits.Add64(arg1[2], arg2[2], uint64(fiatScalarUint1(x4)))
var x7 uint64
var x8 uint64
x7, x8 = bits.Add64(arg1[3], arg2[3], uint64(fiatScalarUint1(x6)))
var x9 uint64
var x10 uint64
x9, x10 = bits.Sub64(x1, 0x5812631a5cf5d3ed, uint64(0x0))
var x11 uint64
var x12 uint64
x11, x12 = bits.Sub64(x3, 0x14def9dea2f79cd6, uint64(fiatScalarUint1(x10)))
var x13 uint64
var x14 uint64
x13, x14 = bits.Sub64(x5, uint64(0x0), uint64(fiatScalarUint1(x12)))
var x15 uint64
var x16 uint64
x15, x16 = bits.Sub64(x7, 0x1000000000000000, uint64(fiatScalarUint1(x14)))
var x18 uint64
_, x18 = bits.Sub64(uint64(fiatScalarUint1(x8)), uint64(0x0), uint64(fiatScalarUint1(x16)))
var x19 uint64
fiatScalarCmovznzU64(&x19, fiatScalarUint1(x18), x9, x1)
var x20 uint64
fiatScalarCmovznzU64(&x20, fiatScalarUint1(x18), x11, x3)
var x21 uint64
fiatScalarCmovznzU64(&x21, fiatScalarUint1(x18), x13, x5)
var x22 uint64
fiatScalarCmovznzU64(&x22, fiatScalarUint1(x18), x15, x7)
out1[0] = x19
out1[1] = x20
out1[2] = x21
out1[3] = x22
}
// fiatScalarSub subtracts two field elements in the Montgomery domain.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
// 0 ≤ eval arg2 < m
//
// Postconditions:
//
// eval (from_montgomery out1) mod m = (eval (from_montgomery arg1) - eval (from_montgomery arg2)) mod m
// 0 ≤ eval out1 < m
func fiatScalarSub(out1 *fiatScalarMontgomeryDomainFieldElement, arg1 *fiatScalarMontgomeryDomainFieldElement, arg2 *fiatScalarMontgomeryDomainFieldElement) {
var x1 uint64
var x2 uint64
x1, x2 = bits.Sub64(arg1[0], arg2[0], uint64(0x0))
var x3 uint64
var x4 uint64
x3, x4 = bits.Sub64(arg1[1], arg2[1], uint64(fiatScalarUint1(x2)))
var x5 uint64
var x6 uint64
x5, x6 = bits.Sub64(arg1[2], arg2[2], uint64(fiatScalarUint1(x4)))
var x7 uint64
var x8 uint64
x7, x8 = bits.Sub64(arg1[3], arg2[3], uint64(fiatScalarUint1(x6)))
var x9 uint64
fiatScalarCmovznzU64(&x9, fiatScalarUint1(x8), uint64(0x0), 0xffffffffffffffff)
var x10 uint64
var x11 uint64
x10, x11 = bits.Add64(x1, (x9 & 0x5812631a5cf5d3ed), uint64(0x0))
var x12 uint64
var x13 uint64
x12, x13 = bits.Add64(x3, (x9 & 0x14def9dea2f79cd6), uint64(fiatScalarUint1(x11)))
var x14 uint64
var x15 uint64
x14, x15 = bits.Add64(x5, uint64(0x0), uint64(fiatScalarUint1(x13)))
var x16 uint64
x16, _ = bits.Add64(x7, (x9 & 0x1000000000000000), uint64(fiatScalarUint1(x15)))
out1[0] = x10
out1[1] = x12
out1[2] = x14
out1[3] = x16
}
// fiatScalarOpp negates a field element in the Montgomery domain.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
//
// Postconditions:
//
// eval (from_montgomery out1) mod m = -eval (from_montgomery arg1) mod m
// 0 ≤ eval out1 < m
func fiatScalarOpp(out1 *fiatScalarMontgomeryDomainFieldElement, arg1 *fiatScalarMontgomeryDomainFieldElement) {
var x1 uint64
var x2 uint64
x1, x2 = bits.Sub64(uint64(0x0), arg1[0], uint64(0x0))
var x3 uint64
var x4 uint64
x3, x4 = bits.Sub64(uint64(0x0), arg1[1], uint64(fiatScalarUint1(x2)))
var x5 uint64
var x6 uint64
x5, x6 = bits.Sub64(uint64(0x0), arg1[2], uint64(fiatScalarUint1(x4)))
var x7 uint64
var x8 uint64
x7, x8 = bits.Sub64(uint64(0x0), arg1[3], uint64(fiatScalarUint1(x6)))
var x9 uint64
fiatScalarCmovznzU64(&x9, fiatScalarUint1(x8), uint64(0x0), 0xffffffffffffffff)
var x10 uint64
var x11 uint64
x10, x11 = bits.Add64(x1, (x9 & 0x5812631a5cf5d3ed), uint64(0x0))
var x12 uint64
var x13 uint64
x12, x13 = bits.Add64(x3, (x9 & 0x14def9dea2f79cd6), uint64(fiatScalarUint1(x11)))
var x14 uint64
var x15 uint64
x14, x15 = bits.Add64(x5, uint64(0x0), uint64(fiatScalarUint1(x13)))
var x16 uint64
x16, _ = bits.Add64(x7, (x9 & 0x1000000000000000), uint64(fiatScalarUint1(x15)))
out1[0] = x10
out1[1] = x12
out1[2] = x14
out1[3] = x16
}
// fiatScalarNonzero outputs a single non-zero word if the input is non-zero and zero otherwise.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
//
// Postconditions:
//
// out1 = 0 ↔ eval (from_montgomery arg1) mod m = 0
//
// Input Bounds:
//
// arg1: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff]]
//
// Output Bounds:
//
// out1: [0x0 ~> 0xffffffffffffffff]
func fiatScalarNonzero(out1 *uint64, arg1 *[4]uint64) {
x1 := (arg1[0] | (arg1[1] | (arg1[2] | arg1[3])))
*out1 = x1
}
// fiatScalarFromMontgomery translates a field element out of the Montgomery domain.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
//
// Postconditions:
//
// eval out1 mod m = (eval arg1 * ((2^64)⁻¹ mod m)^4) mod m
// 0 ≤ eval out1 < m
func fiatScalarFromMontgomery(out1 *fiatScalarNonMontgomeryDomainFieldElement, arg1 *fiatScalarMontgomeryDomainFieldElement) {
x1 := arg1[0]
var x2 uint64
_, x2 = bits.Mul64(x1, 0xd2b51da312547e1b)
var x4 uint64
var x5 uint64
x5, x4 = bits.Mul64(x2, 0x1000000000000000)
var x6 uint64
var x7 uint64
x7, x6 = bits.Mul64(x2, 0x14def9dea2f79cd6)
var x8 uint64
var x9 uint64
x9, x8 = bits.Mul64(x2, 0x5812631a5cf5d3ed)
var x10 uint64
var x11 uint64
x10, x11 = bits.Add64(x9, x6, uint64(0x0))
var x13 uint64
_, x13 = bits.Add64(x1, x8, uint64(0x0))
var x14 uint64
var x15 uint64
x14, x15 = bits.Add64(uint64(0x0), x10, uint64(fiatScalarUint1(x13)))
var x16 uint64
var x17 uint64
x16, x17 = bits.Add64(x14, arg1[1], uint64(0x0))
var x18 uint64
_, x18 = bits.Mul64(x16, 0xd2b51da312547e1b)
var x20 uint64
var x21 uint64
x21, x20 = bits.Mul64(x18, 0x1000000000000000)
var x22 uint64
var x23 uint64
x23, x22 = bits.Mul64(x18, 0x14def9dea2f79cd6)
var x24 uint64
var x25 uint64
x25, x24 = bits.Mul64(x18, 0x5812631a5cf5d3ed)
var x26 uint64
var x27 uint64
x26, x27 = bits.Add64(x25, x22, uint64(0x0))
var x29 uint64
_, x29 = bits.Add64(x16, x24, uint64(0x0))
var x30 uint64
var x31 uint64
x30, x31 = bits.Add64((uint64(fiatScalarUint1(x17)) + (uint64(fiatScalarUint1(x15)) + (uint64(fiatScalarUint1(x11)) + x7))), x26, uint64(fiatScalarUint1(x29)))
var x32 uint64
var x33 uint64
x32, x33 = bits.Add64(x4, (uint64(fiatScalarUint1(x27)) + x23), uint64(fiatScalarUint1(x31)))
var x34 uint64
var x35 uint64
x34, x35 = bits.Add64(x5, x20, uint64(fiatScalarUint1(x33)))
var x36 uint64
var x37 uint64
x36, x37 = bits.Add64(x30, arg1[2], uint64(0x0))
var x38 uint64
var x39 uint64
x38, x39 = bits.Add64(x32, uint64(0x0), uint64(fiatScalarUint1(x37)))
var x40 uint64
var x41 uint64
x40, x41 = bits.Add64(x34, uint64(0x0), uint64(fiatScalarUint1(x39)))
var x42 uint64
_, x42 = bits.Mul64(x36, 0xd2b51da312547e1b)
var x44 uint64
var x45 uint64
x45, x44 = bits.Mul64(x42, 0x1000000000000000)
var x46 uint64
var x47 uint64
x47, x46 = bits.Mul64(x42, 0x14def9dea2f79cd6)
var x48 uint64
var x49 uint64
x49, x48 = bits.Mul64(x42, 0x5812631a5cf5d3ed)
var x50 uint64
var x51 uint64
x50, x51 = bits.Add64(x49, x46, uint64(0x0))
var x53 uint64
_, x53 = bits.Add64(x36, x48, uint64(0x0))
var x54 uint64
var x55 uint64
x54, x55 = bits.Add64(x38, x50, uint64(fiatScalarUint1(x53)))
var x56 uint64
var x57 uint64
x56, x57 = bits.Add64(x40, (uint64(fiatScalarUint1(x51)) + x47), uint64(fiatScalarUint1(x55)))
var x58 uint64
var x59 uint64
x58, x59 = bits.Add64((uint64(fiatScalarUint1(x41)) + (uint64(fiatScalarUint1(x35)) + x21)), x44, uint64(fiatScalarUint1(x57)))
var x60 uint64
var x61 uint64
x60, x61 = bits.Add64(x54, arg1[3], uint64(0x0))
var x62 uint64
var x63 uint64
x62, x63 = bits.Add64(x56, uint64(0x0), uint64(fiatScalarUint1(x61)))
var x64 uint64
var x65 uint64
x64, x65 = bits.Add64(x58, uint64(0x0), uint64(fiatScalarUint1(x63)))
var x66 uint64
_, x66 = bits.Mul64(x60, 0xd2b51da312547e1b)
var x68 uint64
var x69 uint64
x69, x68 = bits.Mul64(x66, 0x1000000000000000)
var x70 uint64
var x71 uint64
x71, x70 = bits.Mul64(x66, 0x14def9dea2f79cd6)
var x72 uint64
var x73 uint64
x73, x72 = bits.Mul64(x66, 0x5812631a5cf5d3ed)
var x74 uint64
var x75 uint64
x74, x75 = bits.Add64(x73, x70, uint64(0x0))
var x77 uint64
_, x77 = bits.Add64(x60, x72, uint64(0x0))
var x78 uint64
var x79 uint64
x78, x79 = bits.Add64(x62, x74, uint64(fiatScalarUint1(x77)))
var x80 uint64
var x81 uint64
x80, x81 = bits.Add64(x64, (uint64(fiatScalarUint1(x75)) + x71), uint64(fiatScalarUint1(x79)))
var x82 uint64
var x83 uint64
x82, x83 = bits.Add64((uint64(fiatScalarUint1(x65)) + (uint64(fiatScalarUint1(x59)) + x45)), x68, uint64(fiatScalarUint1(x81)))
x84 := (uint64(fiatScalarUint1(x83)) + x69)
var x85 uint64
var x86 uint64
x85, x86 = bits.Sub64(x78, 0x5812631a5cf5d3ed, uint64(0x0))
var x87 uint64
var x88 uint64
x87, x88 = bits.Sub64(x80, 0x14def9dea2f79cd6, uint64(fiatScalarUint1(x86)))
var x89 uint64
var x90 uint64
x89, x90 = bits.Sub64(x82, uint64(0x0), uint64(fiatScalarUint1(x88)))
var x91 uint64
var x92 uint64
x91, x92 = bits.Sub64(x84, 0x1000000000000000, uint64(fiatScalarUint1(x90)))
var x94 uint64
_, x94 = bits.Sub64(uint64(0x0), uint64(0x0), uint64(fiatScalarUint1(x92)))
var x95 uint64
fiatScalarCmovznzU64(&x95, fiatScalarUint1(x94), x85, x78)
var x96 uint64
fiatScalarCmovznzU64(&x96, fiatScalarUint1(x94), x87, x80)
var x97 uint64
fiatScalarCmovznzU64(&x97, fiatScalarUint1(x94), x89, x82)
var x98 uint64
fiatScalarCmovznzU64(&x98, fiatScalarUint1(x94), x91, x84)
out1[0] = x95
out1[1] = x96
out1[2] = x97
out1[3] = x98
}
// fiatScalarToMontgomery translates a field element into the Montgomery domain.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
//
// Postconditions:
//
// eval (from_montgomery out1) mod m = eval arg1 mod m
// 0 ≤ eval out1 < m
func fiatScalarToMontgomery(out1 *fiatScalarMontgomeryDomainFieldElement, arg1 *fiatScalarNonMontgomeryDomainFieldElement) {
x1 := arg1[1]
x2 := arg1[2]
x3 := arg1[3]
x4 := arg1[0]
var x5 uint64
var x6 uint64
x6, x5 = bits.Mul64(x4, 0x399411b7c309a3d)
var x7 uint64
var x8 uint64
x8, x7 = bits.Mul64(x4, 0xceec73d217f5be65)
var x9 uint64
var x10 uint64
x10, x9 = bits.Mul64(x4, 0xd00e1ba768859347)
var x11 uint64
var x12 uint64
x12, x11 = bits.Mul64(x4, 0xa40611e3449c0f01)
var x13 uint64
var x14 uint64
x13, x14 = bits.Add64(x12, x9, uint64(0x0))
var x15 uint64
var x16 uint64
x15, x16 = bits.Add64(x10, x7, uint64(fiatScalarUint1(x14)))
var x17 uint64
var x18 uint64
x17, x18 = bits.Add64(x8, x5, uint64(fiatScalarUint1(x16)))
var x19 uint64
_, x19 = bits.Mul64(x11, 0xd2b51da312547e1b)
var x21 uint64
var x22 uint64
x22, x21 = bits.Mul64(x19, 0x1000000000000000)
var x23 uint64
var x24 uint64
x24, x23 = bits.Mul64(x19, 0x14def9dea2f79cd6)
var x25 uint64
var x26 uint64
x26, x25 = bits.Mul64(x19, 0x5812631a5cf5d3ed)
var x27 uint64
var x28 uint64
x27, x28 = bits.Add64(x26, x23, uint64(0x0))
var x30 uint64
_, x30 = bits.Add64(x11, x25, uint64(0x0))
var x31 uint64
var x32 uint64
x31, x32 = bits.Add64(x13, x27, uint64(fiatScalarUint1(x30)))
var x33 uint64
var x34 uint64
x33, x34 = bits.Add64(x15, (uint64(fiatScalarUint1(x28)) + x24), uint64(fiatScalarUint1(x32)))
var x35 uint64
var x36 uint64
x35, x36 = bits.Add64(x17, x21, uint64(fiatScalarUint1(x34)))
var x37 uint64
var x38 uint64
x38, x37 = bits.Mul64(x1, 0x399411b7c309a3d)
var x39 uint64
var x40 uint64
x40, x39 = bits.Mul64(x1, 0xceec73d217f5be65)
var x41 uint64
var x42 uint64
x42, x41 = bits.Mul64(x1, 0xd00e1ba768859347)
var x43 uint64
var x44 uint64
x44, x43 = bits.Mul64(x1, 0xa40611e3449c0f01)
var x45 uint64
var x46 uint64
x45, x46 = bits.Add64(x44, x41, uint64(0x0))
var x47 uint64
var x48 uint64
x47, x48 = bits.Add64(x42, x39, uint64(fiatScalarUint1(x46)))
var x49 uint64
var x50 uint64
x49, x50 = bits.Add64(x40, x37, uint64(fiatScalarUint1(x48)))
var x51 uint64
var x52 uint64
x51, x52 = bits.Add64(x31, x43, uint64(0x0))
var x53 uint64
var x54 uint64
x53, x54 = bits.Add64(x33, x45, uint64(fiatScalarUint1(x52)))
var x55 uint64
var x56 uint64
x55, x56 = bits.Add64(x35, x47, uint64(fiatScalarUint1(x54)))
var x57 uint64
var x58 uint64
x57, x58 = bits.Add64(((uint64(fiatScalarUint1(x36)) + (uint64(fiatScalarUint1(x18)) + x6)) + x22), x49, uint64(fiatScalarUint1(x56)))
var x59 uint64
_, x59 = bits.Mul64(x51, 0xd2b51da312547e1b)
var x61 uint64
var x62 uint64
x62, x61 = bits.Mul64(x59, 0x1000000000000000)
var x63 uint64
var x64 uint64
x64, x63 = bits.Mul64(x59, 0x14def9dea2f79cd6)
var x65 uint64
var x66 uint64
x66, x65 = bits.Mul64(x59, 0x5812631a5cf5d3ed)
var x67 uint64
var x68 uint64
x67, x68 = bits.Add64(x66, x63, uint64(0x0))
var x70 uint64
_, x70 = bits.Add64(x51, x65, uint64(0x0))
var x71 uint64
var x72 uint64
x71, x72 = bits.Add64(x53, x67, uint64(fiatScalarUint1(x70)))
var x73 uint64
var x74 uint64
x73, x74 = bits.Add64(x55, (uint64(fiatScalarUint1(x68)) + x64), uint64(fiatScalarUint1(x72)))
var x75 uint64
var x76 uint64
x75, x76 = bits.Add64(x57, x61, uint64(fiatScalarUint1(x74)))
var x77 uint64
var x78 uint64
x78, x77 = bits.Mul64(x2, 0x399411b7c309a3d)
var x79 uint64
var x80 uint64
x80, x79 = bits.Mul64(x2, 0xceec73d217f5be65)
var x81 uint64
var x82 uint64
x82, x81 = bits.Mul64(x2, 0xd00e1ba768859347)
var x83 uint64
var x84 uint64
x84, x83 = bits.Mul64(x2, 0xa40611e3449c0f01)
var x85 uint64
var x86 uint64
x85, x86 = bits.Add64(x84, x81, uint64(0x0))
var x87 uint64
var x88 uint64
x87, x88 = bits.Add64(x82, x79, uint64(fiatScalarUint1(x86)))
var x89 uint64
var x90 uint64
x89, x90 = bits.Add64(x80, x77, uint64(fiatScalarUint1(x88)))
var x91 uint64
var x92 uint64
x91, x92 = bits.Add64(x71, x83, uint64(0x0))
var x93 uint64
var x94 uint64
x93, x94 = bits.Add64(x73, x85, uint64(fiatScalarUint1(x92)))
var x95 uint64
var x96 uint64
x95, x96 = bits.Add64(x75, x87, uint64(fiatScalarUint1(x94)))
var x97 uint64
var x98 uint64
x97, x98 = bits.Add64(((uint64(fiatScalarUint1(x76)) + (uint64(fiatScalarUint1(x58)) + (uint64(fiatScalarUint1(x50)) + x38))) + x62), x89, uint64(fiatScalarUint1(x96)))
var x99 uint64
_, x99 = bits.Mul64(x91, 0xd2b51da312547e1b)
var x101 uint64
var x102 uint64
x102, x101 = bits.Mul64(x99, 0x1000000000000000)
var x103 uint64
var x104 uint64
x104, x103 = bits.Mul64(x99, 0x14def9dea2f79cd6)
var x105 uint64
var x106 uint64
x106, x105 = bits.Mul64(x99, 0x5812631a5cf5d3ed)
var x107 uint64
var x108 uint64
x107, x108 = bits.Add64(x106, x103, uint64(0x0))
var x110 uint64
_, x110 = bits.Add64(x91, x105, uint64(0x0))
var x111 uint64
var x112 uint64
x111, x112 = bits.Add64(x93, x107, uint64(fiatScalarUint1(x110)))
var x113 uint64
var x114 uint64
x113, x114 = bits.Add64(x95, (uint64(fiatScalarUint1(x108)) + x104), uint64(fiatScalarUint1(x112)))
var x115 uint64
var x116 uint64
x115, x116 = bits.Add64(x97, x101, uint64(fiatScalarUint1(x114)))
var x117 uint64
var x118 uint64
x118, x117 = bits.Mul64(x3, 0x399411b7c309a3d)
var x119 uint64
var x120 uint64
x120, x119 = bits.Mul64(x3, 0xceec73d217f5be65)
var x121 uint64
var x122 uint64
x122, x121 = bits.Mul64(x3, 0xd00e1ba768859347)
var x123 uint64
var x124 uint64
x124, x123 = bits.Mul64(x3, 0xa40611e3449c0f01)
var x125 uint64
var x126 uint64
x125, x126 = bits.Add64(x124, x121, uint64(0x0))
var x127 uint64
var x128 uint64
x127, x128 = bits.Add64(x122, x119, uint64(fiatScalarUint1(x126)))
var x129 uint64
var x130 uint64
x129, x130 = bits.Add64(x120, x117, uint64(fiatScalarUint1(x128)))
var x131 uint64
var x132 uint64
x131, x132 = bits.Add64(x111, x123, uint64(0x0))
var x133 uint64
var x134 uint64
x133, x134 = bits.Add64(x113, x125, uint64(fiatScalarUint1(x132)))
var x135 uint64
var x136 uint64
x135, x136 = bits.Add64(x115, x127, uint64(fiatScalarUint1(x134)))
var x137 uint64
var x138 uint64
x137, x138 = bits.Add64(((uint64(fiatScalarUint1(x116)) + (uint64(fiatScalarUint1(x98)) + (uint64(fiatScalarUint1(x90)) + x78))) + x102), x129, uint64(fiatScalarUint1(x136)))
var x139 uint64
_, x139 = bits.Mul64(x131, 0xd2b51da312547e1b)
var x141 uint64
var x142 uint64
x142, x141 = bits.Mul64(x139, 0x1000000000000000)
var x143 uint64
var x144 uint64
x144, x143 = bits.Mul64(x139, 0x14def9dea2f79cd6)
var x145 uint64
var x146 uint64
x146, x145 = bits.Mul64(x139, 0x5812631a5cf5d3ed)
var x147 uint64
var x148 uint64
x147, x148 = bits.Add64(x146, x143, uint64(0x0))
var x150 uint64
_, x150 = bits.Add64(x131, x145, uint64(0x0))
var x151 uint64
var x152 uint64
x151, x152 = bits.Add64(x133, x147, uint64(fiatScalarUint1(x150)))
var x153 uint64
var x154 uint64
x153, x154 = bits.Add64(x135, (uint64(fiatScalarUint1(x148)) + x144), uint64(fiatScalarUint1(x152)))
var x155 uint64
var x156 uint64
x155, x156 = bits.Add64(x137, x141, uint64(fiatScalarUint1(x154)))
x157 := ((uint64(fiatScalarUint1(x156)) + (uint64(fiatScalarUint1(x138)) + (uint64(fiatScalarUint1(x130)) + x118))) + x142)
var x158 uint64
var x159 uint64
x158, x159 = bits.Sub64(x151, 0x5812631a5cf5d3ed, uint64(0x0))
var x160 uint64
var x161 uint64
x160, x161 = bits.Sub64(x153, 0x14def9dea2f79cd6, uint64(fiatScalarUint1(x159)))
var x162 uint64
var x163 uint64
x162, x163 = bits.Sub64(x155, uint64(0x0), uint64(fiatScalarUint1(x161)))
var x164 uint64
var x165 uint64
x164, x165 = bits.Sub64(x157, 0x1000000000000000, uint64(fiatScalarUint1(x163)))
var x167 uint64
_, x167 = bits.Sub64(uint64(0x0), uint64(0x0), uint64(fiatScalarUint1(x165)))
var x168 uint64
fiatScalarCmovznzU64(&x168, fiatScalarUint1(x167), x158, x151)
var x169 uint64
fiatScalarCmovznzU64(&x169, fiatScalarUint1(x167), x160, x153)
var x170 uint64
fiatScalarCmovznzU64(&x170, fiatScalarUint1(x167), x162, x155)
var x171 uint64
fiatScalarCmovznzU64(&x171, fiatScalarUint1(x167), x164, x157)
out1[0] = x168
out1[1] = x169
out1[2] = x170
out1[3] = x171
}
// fiatScalarToBytes serializes a field element NOT in the Montgomery domain to bytes in little-endian order.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
//
// Postconditions:
//
// out1 = map (λ x, ⌊((eval arg1 mod m) mod 2^(8 * (x + 1))) / 2^(8 * x)⌋) [0..31]
//
// Input Bounds:
//
// arg1: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0x1fffffffffffffff]]
//
// Output Bounds:
//
// out1: [[0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0x1f]]
func fiatScalarToBytes(out1 *[32]uint8, arg1 *[4]uint64) {
x1 := arg1[3]
x2 := arg1[2]
x3 := arg1[1]
x4 := arg1[0]
x5 := (uint8(x4) & 0xff)
x6 := (x4 >> 8)
x7 := (uint8(x6) & 0xff)
x8 := (x6 >> 8)
x9 := (uint8(x8) & 0xff)
x10 := (x8 >> 8)
x11 := (uint8(x10) & 0xff)
x12 := (x10 >> 8)
x13 := (uint8(x12) & 0xff)
x14 := (x12 >> 8)
x15 := (uint8(x14) & 0xff)
x16 := (x14 >> 8)
x17 := (uint8(x16) & 0xff)
x18 := uint8((x16 >> 8))
x19 := (uint8(x3) & 0xff)
x20 := (x3 >> 8)
x21 := (uint8(x20) & 0xff)
x22 := (x20 >> 8)
x23 := (uint8(x22) & 0xff)
x24 := (x22 >> 8)
x25 := (uint8(x24) & 0xff)
x26 := (x24 >> 8)
x27 := (uint8(x26) & 0xff)
x28 := (x26 >> 8)
x29 := (uint8(x28) & 0xff)
x30 := (x28 >> 8)
x31 := (uint8(x30) & 0xff)
x32 := uint8((x30 >> 8))
x33 := (uint8(x2) & 0xff)
x34 := (x2 >> 8)
x35 := (uint8(x34) & 0xff)
x36 := (x34 >> 8)
x37 := (uint8(x36) & 0xff)
x38 := (x36 >> 8)
x39 := (uint8(x38) & 0xff)
x40 := (x38 >> 8)
x41 := (uint8(x40) & 0xff)
x42 := (x40 >> 8)
x43 := (uint8(x42) & 0xff)
x44 := (x42 >> 8)
x45 := (uint8(x44) & 0xff)
x46 := uint8((x44 >> 8))
x47 := (uint8(x1) & 0xff)
x48 := (x1 >> 8)
x49 := (uint8(x48) & 0xff)
x50 := (x48 >> 8)
x51 := (uint8(x50) & 0xff)
x52 := (x50 >> 8)
x53 := (uint8(x52) & 0xff)
x54 := (x52 >> 8)
x55 := (uint8(x54) & 0xff)
x56 := (x54 >> 8)
x57 := (uint8(x56) & 0xff)
x58 := (x56 >> 8)
x59 := (uint8(x58) & 0xff)
x60 := uint8((x58 >> 8))
out1[0] = x5
out1[1] = x7
out1[2] = x9
out1[3] = x11
out1[4] = x13
out1[5] = x15
out1[6] = x17
out1[7] = x18
out1[8] = x19
out1[9] = x21
out1[10] = x23
out1[11] = x25
out1[12] = x27
out1[13] = x29
out1[14] = x31
out1[15] = x32
out1[16] = x33
out1[17] = x35
out1[18] = x37
out1[19] = x39
out1[20] = x41
out1[21] = x43
out1[22] = x45
out1[23] = x46
out1[24] = x47
out1[25] = x49
out1[26] = x51
out1[27] = x53
out1[28] = x55
out1[29] = x57
out1[30] = x59
out1[31] = x60
}
// fiatScalarFromBytes deserializes a field element NOT in the Montgomery domain from bytes in little-endian order.
//
// Preconditions:
//
// 0 ≤ bytes_eval arg1 < m
//
// Postconditions:
//
// eval out1 mod m = bytes_eval arg1 mod m
// 0 ≤ eval out1 < m
//
// Input Bounds:
//
// arg1: [[0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0x1f]]
//
// Output Bounds:
//
// out1: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0x1fffffffffffffff]]
func fiatScalarFromBytes(out1 *[4]uint64, arg1 *[32]uint8) {
x1 := (uint64(arg1[31]) << 56)
x2 := (uint64(arg1[30]) << 48)
x3 := (uint64(arg1[29]) << 40)
x4 := (uint64(arg1[28]) << 32)
x5 := (uint64(arg1[27]) << 24)
x6 := (uint64(arg1[26]) << 16)
x7 := (uint64(arg1[25]) << 8)
x8 := arg1[24]
x9 := (uint64(arg1[23]) << 56)
x10 := (uint64(arg1[22]) << 48)
x11 := (uint64(arg1[21]) << 40)
x12 := (uint64(arg1[20]) << 32)
x13 := (uint64(arg1[19]) << 24)
x14 := (uint64(arg1[18]) << 16)
x15 := (uint64(arg1[17]) << 8)
x16 := arg1[16]
x17 := (uint64(arg1[15]) << 56)
x18 := (uint64(arg1[14]) << 48)
x19 := (uint64(arg1[13]) << 40)
x20 := (uint64(arg1[12]) << 32)
x21 := (uint64(arg1[11]) << 24)
x22 := (uint64(arg1[10]) << 16)
x23 := (uint64(arg1[9]) << 8)
x24 := arg1[8]
x25 := (uint64(arg1[7]) << 56)
x26 := (uint64(arg1[6]) << 48)
x27 := (uint64(arg1[5]) << 40)
x28 := (uint64(arg1[4]) << 32)
x29 := (uint64(arg1[3]) << 24)
x30 := (uint64(arg1[2]) << 16)
x31 := (uint64(arg1[1]) << 8)
x32 := arg1[0]
x33 := (x31 + uint64(x32))
x34 := (x30 + x33)
x35 := (x29 + x34)
x36 := (x28 + x35)
x37 := (x27 + x36)
x38 := (x26 + x37)
x39 := (x25 + x38)
x40 := (x23 + uint64(x24))
x41 := (x22 + x40)
x42 := (x21 + x41)
x43 := (x20 + x42)
x44 := (x19 + x43)
x45 := (x18 + x44)
x46 := (x17 + x45)
x47 := (x15 + uint64(x16))
x48 := (x14 + x47)
x49 := (x13 + x48)
x50 := (x12 + x49)
x51 := (x11 + x50)
x52 := (x10 + x51)
x53 := (x9 + x52)
x54 := (x7 + uint64(x8))
x55 := (x6 + x54)
x56 := (x5 + x55)
x57 := (x4 + x56)
x58 := (x3 + x57)
x59 := (x2 + x58)
x60 := (x1 + x59)
out1[0] = x39
out1[1] = x46
out1[2] = x53
out1[3] = x60
}
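// Test-style sketch (it would live in a _test.go file): for inputs in the
// unique saturated representation strictly below the modulus m, translating
// into the Montgomery domain and back is the identity.
func checkMontgomeryRoundTrip(x fiatScalarNonMontgomeryDomainFieldElement) bool {
	var m fiatScalarMontgomeryDomainFieldElement
	var n fiatScalarNonMontgomeryDomainFieldElement
	fiatScalarToMontgomery(&m, &x)
	fiatScalarFromMontgomery(&n, &m)
	return n == x
}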
// Copyright (c) 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import "sync"
// basepointTable is a set of 32 affineLookupTables, where table i is generated
// from (256^i) * basepoint, that is, the basepoint doubled 8i times. It is
// precomputed the first time it's used.
func basepointTable() *[32]affineLookupTable {
basepointTablePrecomp.initOnce.Do(func() {
p := NewGeneratorPoint()
for i := 0; i < 32; i++ {
basepointTablePrecomp.table[i].FromP3(p)
for j := 0; j < 8; j++ {
p.Add(p, p)
}
}
})
return &basepointTablePrecomp.table
}
var basepointTablePrecomp struct {
table [32]affineLookupTable
initOnce sync.Once
}
// ScalarBaseMult sets v = x * B, where B is the canonical generator, and
// returns v.
//
// The scalar multiplication is done in constant time.
func (v *Point) ScalarBaseMult(x *Scalar) *Point {
basepointTable := basepointTable()
// Write x = sum(x_i * 16^i) so x*B = sum( B*x_i*16^i )
// as described in the Ed25519 paper
//
// Group even and odd coefficients
// x*B = x_0*16^0*B + x_2*16^2*B + ... + x_62*16^62*B
// + x_1*16^1*B + x_3*16^3*B + ... + x_63*16^63*B
// x*B = x_0*16^0*B + x_2*16^2*B + ... + x_62*16^62*B
// + 16*( x_1*16^0*B + x_3*16^2*B + ... + x_63*16^62*B)
//
// We use a lookup table for each i to get x_i*16^(2*i)*B
// and do four doublings to multiply by 16.
digits := x.signedRadix16()
multiple := &affineCached{}
tmp1 := &projP1xP1{}
tmp2 := &projP2{}
// Accumulate the odd components first
v.Set(NewIdentityPoint())
for i := 1; i < 64; i += 2 {
basepointTable[i/2].SelectInto(multiple, digits[i])
tmp1.AddAffine(v, multiple)
v.fromP1xP1(tmp1)
}
// Multiply by 16
tmp2.FromP3(v) // tmp2 = v in P2 coords
tmp1.Double(tmp2) // tmp1 = 2*v in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 2*v in P2 coords
tmp1.Double(tmp2) // tmp1 = 4*v in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 4*v in P2 coords
tmp1.Double(tmp2) // tmp1 = 8*v in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 8*v in P2 coords
tmp1.Double(tmp2) // tmp1 = 16*v in P1xP1 coords
v.fromP1xP1(tmp1) // now v = 16*(odd components)
// Accumulate the even components
for i := 0; i < 64; i += 2 {
basepointTable[i/2].SelectInto(multiple, digits[i])
tmp1.AddAffine(v, multiple)
v.fromP1xP1(tmp1)
}
return v
}
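// Usage sketch: ScalarBaseMult agrees with the generic ScalarMult applied to
// the generator, which is a handy property for tests.
func checkScalarBaseMult(s *Scalar) bool {
	p1 := new(Point).ScalarBaseMult(s)
	p2 := new(Point).ScalarMult(s, NewGeneratorPoint())
	return p1.Equal(p2) == 1
}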
// ScalarMult sets v = x * q, and returns v.
//
// The scalar multiplication is done in constant time.
func (v *Point) ScalarMult(x *Scalar, q *Point) *Point {
checkInitialized(q)
var table projLookupTable
table.FromP3(q)
// Write x = sum(x_i * 16^i)
// so x*Q = sum( Q*x_i*16^i )
// = Q*x_0 + 16*(Q*x_1 + 16*( ... + Q*x_63) ... )
// <------compute inside out---------
//
// We use the lookup table to get the x_i*Q values
// and do four doublings to compute 16*Q
digits := x.signedRadix16()
// Unwrap first loop iteration to save computing 16*identity
multiple := &projCached{}
tmp1 := &projP1xP1{}
tmp2 := &projP2{}
table.SelectInto(multiple, digits[63])
v.Set(NewIdentityPoint())
tmp1.Add(v, multiple) // tmp1 = x_63*Q in P1xP1 coords
for i := 62; i >= 0; i-- {
tmp2.FromP1xP1(tmp1) // tmp2 = (prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 2*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 2*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 4*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 4*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 8*(prev) in P1xP1 coords
tmp2.FromP1xP1(tmp1) // tmp2 = 8*(prev) in P2 coords
tmp1.Double(tmp2) // tmp1 = 16*(prev) in P1xP1 coords
v.fromP1xP1(tmp1) // v = 16*(prev) in P3 coords
table.SelectInto(multiple, digits[i])
tmp1.Add(v, multiple) // tmp1 = x_i*Q + 16*(prev) in P1xP1 coords
}
v.fromP1xP1(tmp1)
return v
}
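// The loop above evaluates Horner's rule in the exponent. A plain-integer
// sketch of the same inside-out order (assuming math/big), where the Lsh
// stands in for the four doublings:
func hornerRadix16(digits [64]int8) *big.Int {
	acc := big.NewInt(int64(digits[63]))
	for i := 62; i >= 0; i-- {
		acc.Lsh(acc, 4)                            // acc = 16 * acc
		acc.Add(acc, big.NewInt(int64(digits[i]))) // acc += x_i
	}
	return acc
}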
// basepointNafTable is the nafLookupTable8 for the basepoint.
// It is precomputed the first time it's used.
func basepointNafTable() *nafLookupTable8 {
basepointNafTablePrecomp.initOnce.Do(func() {
basepointNafTablePrecomp.table.FromP3(NewGeneratorPoint())
})
return &basepointNafTablePrecomp.table
}
var basepointNafTablePrecomp struct {
table nafLookupTable8
initOnce sync.Once
}
// VarTimeDoubleScalarBaseMult sets v = a * A + b * B, where B is the canonical
// generator, and returns v.
//
// Execution time depends on the inputs.
func (v *Point) VarTimeDoubleScalarBaseMult(a *Scalar, A *Point, b *Scalar) *Point {
checkInitialized(A)
// Similarly to the single variable-base approach, we compute
// digits and use them with a lookup table. However, because
// we are allowed to do variable-time operations, we don't
// need constant-time lookups or constant-time digit
// computations.
//
// So we use a non-adjacent form of some width w instead of
// radix 16. This is like a binary representation (one digit
// for each binary place) but we allow the digits to grow in
// magnitude up to 2^{w-1} so that the nonzero digits are as
// sparse as possible. Intuitively, this "condenses" the
// "mass" of the scalar onto sparse coefficients (meaning
// fewer additions).
basepointNafTable := basepointNafTable()
var aTable nafLookupTable5
aTable.FromP3(A)
// Because the basepoint is fixed, we can use a wider NAF
// corresponding to a bigger table.
aNaf := a.nonAdjacentForm(5)
bNaf := b.nonAdjacentForm(8)
// Find the first nonzero coefficient.
i := 255
for j := i; j >= 0; j-- {
if aNaf[j] != 0 || bNaf[j] != 0 {
i = j
break
}
}
multA := &projCached{}
multB := &affineCached{}
tmp1 := &projP1xP1{}
tmp2 := &projP2{}
tmp2.Zero()
// Move from high to low bits, doubling the accumulator
// at each iteration and checking whether there is a nonzero
// coefficient to look up a multiple of.
for ; i >= 0; i-- {
tmp1.Double(tmp2)
// Only update v if we have a nonzero coeff to add in.
if aNaf[i] > 0 {
v.fromP1xP1(tmp1)
aTable.SelectInto(multA, aNaf[i])
tmp1.Add(v, multA)
} else if aNaf[i] < 0 {
v.fromP1xP1(tmp1)
aTable.SelectInto(multA, -aNaf[i])
tmp1.Sub(v, multA)
}
if bNaf[i] > 0 {
v.fromP1xP1(tmp1)
basepointNafTable.SelectInto(multB, bNaf[i])
tmp1.AddAffine(v, multB)
} else if bNaf[i] < 0 {
v.fromP1xP1(tmp1)
basepointNafTable.SelectInto(multB, -bNaf[i])
tmp1.SubAffine(v, multB)
}
tmp2.FromP1xP1(tmp1)
}
v.fromP2(tmp2)
return v
}
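// Cross-check sketch: the variable-time a*A + b*B should match the same value
// assembled from the constant-time ScalarMult, ScalarBaseMult, and Add.
func checkVarTimeDoubleBase(a *Scalar, A *Point, b *Scalar) bool {
	got := new(Point).VarTimeDoubleScalarBaseMult(a, A, b)
	want := new(Point).Add(
		new(Point).ScalarMult(a, A),
		new(Point).ScalarBaseMult(b),
	)
	return got.Equal(want) == 1
}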
// Copyright (c) 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edwards25519
import (
"crypto/subtle"
)
// A dynamic lookup table for variable-base, constant-time scalar muls.
type projLookupTable struct {
points [8]projCached
}
// A precomputed lookup table for fixed-base, constant-time scalar muls.
type affineLookupTable struct {
points [8]affineCached
}
// A dynamic lookup table for variable-base, variable-time scalar muls.
type nafLookupTable5 struct {
points [8]projCached
}
// A precomputed lookup table for fixed-base, variable-time scalar muls.
type nafLookupTable8 struct {
points [64]affineCached
}
// Constructors.
// Builds a lookup table at runtime. Fast.
func (v *projLookupTable) FromP3(q *Point) {
// Goal: v.points[i] = (i+1)*Q, i.e., Q, 2Q, ..., 8Q
// This allows lookup of -8Q, ..., -Q, 0, Q, ..., 8Q
v.points[0].FromP3(q)
tmpP3 := Point{}
tmpP1xP1 := projP1xP1{}
for i := 0; i < 7; i++ {
// Compute (i+1)*Q as Q + i*Q and convert to a projCached
// This is needlessly complicated because the API has explicit
// receivers instead of creating stack objects and relying on RVO
v.points[i+1].FromP3(tmpP3.fromP1xP1(tmpP1xP1.Add(q, &v.points[i])))
}
}
// This is not optimised for speed; fixed-base tables should be precomputed.
func (v *affineLookupTable) FromP3(q *Point) {
// Goal: v.points[i] = (i+1)*Q, i.e., Q, 2Q, ..., 8Q
// This allows lookup of -8Q, ..., -Q, 0, Q, ..., 8Q
v.points[0].FromP3(q)
tmpP3 := Point{}
tmpP1xP1 := projP1xP1{}
for i := 0; i < 7; i++ {
// Compute (i+1)*Q as Q + i*Q and convert to affineCached
v.points[i+1].FromP3(tmpP3.fromP1xP1(tmpP1xP1.AddAffine(q, &v.points[i])))
}
}
// Builds a lookup table at runtime. Fast.
func (v *nafLookupTable5) FromP3(q *Point) {
// Goal: v.points[i] = (2*i+1)*Q, i.e., Q, 3Q, 5Q, ..., 15Q
// This allows lookup of -15Q, ..., -3Q, -Q, 0, Q, 3Q, ..., 15Q
v.points[0].FromP3(q)
q2 := Point{}
q2.Add(q, q)
tmpP3 := Point{}
tmpP1xP1 := projP1xP1{}
for i := 0; i < 7; i++ {
v.points[i+1].FromP3(tmpP3.fromP1xP1(tmpP1xP1.Add(&q2, &v.points[i])))
}
}
// This is not optimised for speed; fixed-base tables should be precomputed.
func (v *nafLookupTable8) FromP3(q *Point) {
v.points[0].FromP3(q)
q2 := Point{}
q2.Add(q, q)
tmpP3 := Point{}
tmpP1xP1 := projP1xP1{}
for i := 0; i < 63; i++ {
v.points[i+1].FromP3(tmpP3.fromP1xP1(tmpP1xP1.AddAffine(&q2, &v.points[i])))
}
}
// Selectors.
// Set dest to x*Q, where -8 <= x <= 8, in constant time.
func (v *projLookupTable) SelectInto(dest *projCached, x int8) {
// Compute xabs = |x|
xmask := x >> 7
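// xmask is 0 if x >= 0 and all ones if x < 0, because >> on a signed type is
// an arithmetic shift; (x + xmask) ^ xmask is then a branchless absolute value.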
xabs := uint8((x + xmask) ^ xmask)
dest.Zero()
for j := 1; j <= 8; j++ {
// Set dest = j*Q if |x| = j
cond := subtle.ConstantTimeByteEq(xabs, uint8(j))
dest.Select(&v.points[j-1], dest, cond)
}
// Now dest = |x|*Q, conditionally negate to get x*Q
dest.CondNeg(int(xmask & 1))
}
// Set dest to x*Q, where -8 <= x <= 8, in constant time.
func (v *affineLookupTable) SelectInto(dest *affineCached, x int8) {
// Compute xabs = |x|
xmask := x >> 7
xabs := uint8((x + xmask) ^ xmask)
dest.Zero()
for j := 1; j <= 8; j++ {
// Set dest = j*Q if |x| = j
cond := subtle.ConstantTimeByteEq(xabs, uint8(j))
dest.Select(&v.points[j-1], dest, cond)
}
// Now dest = |x|*Q, conditionally negate to get x*Q
dest.CondNeg(int(xmask & 1))
}
// Given odd x with 0 < x < 2^4, set dest to x*Q (in variable time).
func (v *nafLookupTable5) SelectInto(dest *projCached, x int8) {
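// For odd x, x/2 truncates to (x-1)/2, and points[i] holds (2*i+1)*Q,
// so this selects exactly x*Q.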
*dest = v.points[x/2]
}
// Given odd x with 0 < x < 2^7, set dest to x*Q (in variable time).
func (v *nafLookupTable8) SelectInto(dest *affineCached, x int8) {
*dest = v.points[x/2]
}
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Code generated by generate.go. DO NOT EDIT.
package fiat
import (
"crypto/subtle"
"errors"
)
// P224Element is an integer modulo 2^224 - 2^96 + 1.
//
// The zero value is a valid zero element.
type P224Element struct {
// Values are represented internally always in the Montgomery domain, and
// converted in Bytes and SetBytes.
x p224MontgomeryDomainFieldElement
}
const p224ElementLen = 28
type p224UntypedFieldElement = [4]uint64
// One sets e = 1, and returns e.
func (e *P224Element) One() *P224Element {
p224SetOne(&e.x)
return e
}
// Equal returns 1 if e == t, and zero otherwise.
func (e *P224Element) Equal(t *P224Element) int {
eBytes := e.Bytes()
tBytes := t.Bytes()
return subtle.ConstantTimeCompare(eBytes, tBytes)
}
// IsZero returns 1 if e == 0, and zero otherwise.
func (e *P224Element) IsZero() int {
zero := make([]byte, p224ElementLen)
eBytes := e.Bytes()
return subtle.ConstantTimeCompare(eBytes, zero)
}
// Set sets e = t, and returns e.
func (e *P224Element) Set(t *P224Element) *P224Element {
e.x = t.x
return e
}
// Bytes returns the 28-byte big-endian encoding of e.
func (e *P224Element) Bytes() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var out [p224ElementLen]byte
return e.bytes(&out)
}
func (e *P224Element) bytes(out *[p224ElementLen]byte) []byte {
var tmp p224NonMontgomeryDomainFieldElement
p224FromMontgomery(&tmp, &e.x)
p224ToBytes(out, (*p224UntypedFieldElement)(&tmp))
p224InvertEndianness(out[:])
return out[:]
}
// SetBytes sets e = v, where v is a big-endian 28-byte encoding, and returns e.
// If v is not 28 bytes or it encodes a value equal to or higher than 2^224 - 2^96 + 1,
// SetBytes returns nil and an error, and e is unchanged.
func (e *P224Element) SetBytes(v []byte) (*P224Element, error) {
if len(v) != p224ElementLen {
return nil, errors.New("invalid P224Element encoding")
}
// Check for non-canonical encodings (p + k, 2p + k, etc.) by comparing to
// the encoding of -1 mod p, so p - 1, the highest canonical encoding.
var minusOneEncoding = new(P224Element).Sub(
new(P224Element), new(P224Element).One()).Bytes()
for i := range v {
if v[i] < minusOneEncoding[i] {
break
}
if v[i] > minusOneEncoding[i] {
return nil, errors.New("invalid P224Element encoding")
}
}
var in [p224ElementLen]byte
copy(in[:], v)
p224InvertEndianness(in[:])
var tmp p224NonMontgomeryDomainFieldElement
p224FromBytes((*p224UntypedFieldElement)(&tmp), &in)
p224ToMontgomery(&e.x, &tmp)
return e, nil
}
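// Usage sketch (hypothetical, not part of the generated file): canonical
// 28-byte encodings round-trip, while the big-endian encoding of the modulus
// itself (numerically above p - 1) is rejected by the comparison against the
// encoding of -1 above.
func exampleP224SetBytes() {
	e := new(P224Element)
	if _, err := e.SetBytes(new(P224Element).One().Bytes()); err != nil {
		panic("canonical encodings must round-trip")
	}
	p := make([]byte, p224ElementLen)
	for i := 0; i < 16; i++ {
		p[i] = 0xff // big-endian p = 2^224 - 2^96 + 1
	}
	p[27] = 0x01
	if _, err := e.SetBytes(p); err == nil {
		panic("the modulus itself must be rejected as non-canonical")
	}
}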
// Add sets e = t1 + t2, and returns e.
func (e *P224Element) Add(t1, t2 *P224Element) *P224Element {
p224Add(&e.x, &t1.x, &t2.x)
return e
}
// Sub sets e = t1 - t2, and returns e.
func (e *P224Element) Sub(t1, t2 *P224Element) *P224Element {
p224Sub(&e.x, &t1.x, &t2.x)
return e
}
// Mul sets e = t1 * t2, and returns e.
func (e *P224Element) Mul(t1, t2 *P224Element) *P224Element {
p224Mul(&e.x, &t1.x, &t2.x)
return e
}
// Square sets e = t * t, and returns e.
func (e *P224Element) Square(t *P224Element) *P224Element {
p224Square(&e.x, &t.x)
return e
}
// Select sets v to a if cond == 1, and to b if cond == 0.
func (v *P224Element) Select(a, b *P224Element, cond int) *P224Element {
p224Selectznz((*p224UntypedFieldElement)(&v.x), p224Uint1(cond),
(*p224UntypedFieldElement)(&b.x), (*p224UntypedFieldElement)(&a.x))
return v
}
func p224InvertEndianness(v []byte) {
for i := 0; i < len(v)/2; i++ {
v[i], v[len(v)-1-i] = v[len(v)-1-i], v[i]
}
}
// Code generated by Fiat Cryptography. DO NOT EDIT.
//
// Autogenerated: word_by_word_montgomery --lang Go --no-wide-int --cmovznz-by-mul --relax-primitive-carry-to-bitwidth 32,64 --internal-static --public-function-case camelCase --public-type-case camelCase --private-function-case camelCase --private-type-case camelCase --doc-text-before-function-name '' --doc-newline-before-package-declaration --doc-prepend-header 'Code generated by Fiat Cryptography. DO NOT EDIT.' --package-name fiat --no-prefix-fiat p224 64 '2^224 - 2^96 + 1' mul square add sub one from_montgomery to_montgomery selectznz to_bytes from_bytes
//
// curve description: p224
//
// machine_wordsize = 64 (from "64")
//
// requested operations: mul, square, add, sub, one, from_montgomery, to_montgomery, selectznz, to_bytes, from_bytes
//
// m = 0xffffffffffffffffffffffffffffffff000000000000000000000001 (from "2^224 - 2^96 + 1")
//
//
//
// NOTE: In addition to the bounds specified above each function, all
//
// functions synthesized for this Montgomery arithmetic require the
//
// input to be strictly less than the prime modulus (m), and also
//
// require the input to be in the unique saturated representation.
//
// All functions also ensure that these two properties are true of
//
// return values.
//
//
//
// Computed values:
//
// eval z = z[0] + (z[1] << 64) + (z[2] << 128) + (z[3] << 192)
//
// bytes_eval z = z[0] + (z[1] << 8) + (z[2] << 16) + (z[3] << 24) + (z[4] << 32) + (z[5] << 40) + (z[6] << 48) + (z[7] << 56) + (z[8] << 64) + (z[9] << 72) + (z[10] << 80) + (z[11] << 88) + (z[12] << 96) + (z[13] << 104) + (z[14] << 112) + (z[15] << 120) + (z[16] << 128) + (z[17] << 136) + (z[18] << 144) + (z[19] << 152) + (z[20] << 160) + (z[21] << 168) + (z[22] << 176) + (z[23] << 184) + (z[24] << 192) + (z[25] << 200) + (z[26] << 208) + (z[27] << 216)
//
// twos_complement_eval z = let x1 := z[0] + (z[1] << 64) + (z[2] << 128) + (z[3] << 192) in
//
// if x1 & (2^256-1) < 2^255 then x1 & (2^256-1) else (x1 & (2^256-1)) - 2^256
package fiat
import "math/bits"
type p224Uint1 uint64 // We use uint64 instead of a more narrow type for performance reasons; see https://github.com/mit-plv/fiat-crypto/pull/1006#issuecomment-892625927
type p224Int1 int64 // We use int64 instead of a more narrow type for performance reasons; see https://github.com/mit-plv/fiat-crypto/pull/1006#issuecomment-892625927
// The type p224MontgomeryDomainFieldElement is a field element in the Montgomery domain.
//
// Bounds: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff]]
type p224MontgomeryDomainFieldElement [4]uint64
// The type p224NonMontgomeryDomainFieldElement is a field element NOT in the Montgomery domain.
//
// Bounds: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff]]
type p224NonMontgomeryDomainFieldElement [4]uint64
// p224CmovznzU64 is a single-word conditional move.
//
// Postconditions:
//
// out1 = (if arg1 = 0 then arg2 else arg3)
//
// Input Bounds:
//
// arg1: [0x0 ~> 0x1]
// arg2: [0x0 ~> 0xffffffffffffffff]
// arg3: [0x0 ~> 0xffffffffffffffff]
//
// Output Bounds:
//
// out1: [0x0 ~> 0xffffffffffffffff]
func p224CmovznzU64(out1 *uint64, arg1 p224Uint1, arg2 uint64, arg3 uint64) {
x1 := (uint64(arg1) * 0xffffffffffffffff)
x2 := ((x1 & arg3) | ((^x1) & arg2))
*out1 = x2
}
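// Sketch (hypothetical helper, not part of the generated file): multiplying
// the 0/1 condition by 2^64 - 1 turns it into an all-zeros or all-ones mask,
// so the selection is two ANDs and an OR with no data-dependent branch:
func exampleCmovznz() (uint64, uint64) {
	var keep, take uint64
	p224CmovznzU64(&keep, 0, 10, 20) // arg1 = 0 selects arg2: keep = 10
	p224CmovznzU64(&take, 1, 10, 20) // arg1 = 1 selects arg3: take = 20
	return keep, take
}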
// p224Mul multiplies two field elements in the Montgomery domain.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
// 0 ≤ eval arg2 < m
//
// Postconditions:
//
// eval (from_montgomery out1) mod m = (eval (from_montgomery arg1) * eval (from_montgomery arg2)) mod m
// 0 ≤ eval out1 < m
func p224Mul(out1 *p224MontgomeryDomainFieldElement, arg1 *p224MontgomeryDomainFieldElement, arg2 *p224MontgomeryDomainFieldElement) {
x1 := arg1[1]
x2 := arg1[2]
x3 := arg1[3]
x4 := arg1[0]
var x5 uint64
var x6 uint64
x6, x5 = bits.Mul64(x4, arg2[3])
var x7 uint64
var x8 uint64
x8, x7 = bits.Mul64(x4, arg2[2])
var x9 uint64
var x10 uint64
x10, x9 = bits.Mul64(x4, arg2[1])
var x11 uint64
var x12 uint64
x12, x11 = bits.Mul64(x4, arg2[0])
var x13 uint64
var x14 uint64
x13, x14 = bits.Add64(x12, x9, uint64(0x0))
var x15 uint64
var x16 uint64
x15, x16 = bits.Add64(x10, x7, uint64(p224Uint1(x14)))
var x17 uint64
var x18 uint64
x17, x18 = bits.Add64(x8, x5, uint64(p224Uint1(x16)))
x19 := (uint64(p224Uint1(x18)) + x6)
var x20 uint64
_, x20 = bits.Mul64(x11, 0xffffffffffffffff)
var x22 uint64
var x23 uint64
x23, x22 = bits.Mul64(x20, 0xffffffff)
var x24 uint64
var x25 uint64
x25, x24 = bits.Mul64(x20, 0xffffffffffffffff)
var x26 uint64
var x27 uint64
x27, x26 = bits.Mul64(x20, 0xffffffff00000000)
var x28 uint64
var x29 uint64
x28, x29 = bits.Add64(x27, x24, uint64(0x0))
var x30 uint64
var x31 uint64
x30, x31 = bits.Add64(x25, x22, uint64(p224Uint1(x29)))
x32 := (uint64(p224Uint1(x31)) + x23)
var x34 uint64
_, x34 = bits.Add64(x11, x20, uint64(0x0))
var x35 uint64
var x36 uint64
x35, x36 = bits.Add64(x13, x26, uint64(p224Uint1(x34)))
var x37 uint64
var x38 uint64
x37, x38 = bits.Add64(x15, x28, uint64(p224Uint1(x36)))
var x39 uint64
var x40 uint64
x39, x40 = bits.Add64(x17, x30, uint64(p224Uint1(x38)))
var x41 uint64
var x42 uint64
x41, x42 = bits.Add64(x19, x32, uint64(p224Uint1(x40)))
var x43 uint64
var x44 uint64
x44, x43 = bits.Mul64(x1, arg2[3])
var x45 uint64
var x46 uint64
x46, x45 = bits.Mul64(x1, arg2[2])
var x47 uint64
var x48 uint64
x48, x47 = bits.Mul64(x1, arg2[1])
var x49 uint64
var x50 uint64
x50, x49 = bits.Mul64(x1, arg2[0])
var x51 uint64
var x52 uint64
x51, x52 = bits.Add64(x50, x47, uint64(0x0))
var x53 uint64
var x54 uint64
x53, x54 = bits.Add64(x48, x45, uint64(p224Uint1(x52)))
var x55 uint64
var x56 uint64
x55, x56 = bits.Add64(x46, x43, uint64(p224Uint1(x54)))
x57 := (uint64(p224Uint1(x56)) + x44)
var x58 uint64
var x59 uint64
x58, x59 = bits.Add64(x35, x49, uint64(0x0))
var x60 uint64
var x61 uint64
x60, x61 = bits.Add64(x37, x51, uint64(p224Uint1(x59)))
var x62 uint64
var x63 uint64
x62, x63 = bits.Add64(x39, x53, uint64(p224Uint1(x61)))
var x64 uint64
var x65 uint64
x64, x65 = bits.Add64(x41, x55, uint64(p224Uint1(x63)))
var x66 uint64
var x67 uint64
x66, x67 = bits.Add64(uint64(p224Uint1(x42)), x57, uint64(p224Uint1(x65)))
var x68 uint64
_, x68 = bits.Mul64(x58, 0xffffffffffffffff)
var x70 uint64
var x71 uint64
x71, x70 = bits.Mul64(x68, 0xffffffff)
var x72 uint64
var x73 uint64
x73, x72 = bits.Mul64(x68, 0xffffffffffffffff)
var x74 uint64
var x75 uint64
x75, x74 = bits.Mul64(x68, 0xffffffff00000000)
var x76 uint64
var x77 uint64
x76, x77 = bits.Add64(x75, x72, uint64(0x0))
var x78 uint64
var x79 uint64
x78, x79 = bits.Add64(x73, x70, uint64(p224Uint1(x77)))
x80 := (uint64(p224Uint1(x79)) + x71)
var x82 uint64
_, x82 = bits.Add64(x58, x68, uint64(0x0))
var x83 uint64
var x84 uint64
x83, x84 = bits.Add64(x60, x74, uint64(p224Uint1(x82)))
var x85 uint64
var x86 uint64
x85, x86 = bits.Add64(x62, x76, uint64(p224Uint1(x84)))
var x87 uint64
var x88 uint64
x87, x88 = bits.Add64(x64, x78, uint64(p224Uint1(x86)))
var x89 uint64
var x90 uint64
x89, x90 = bits.Add64(x66, x80, uint64(p224Uint1(x88)))
x91 := (uint64(p224Uint1(x90)) + uint64(p224Uint1(x67)))
var x92 uint64
var x93 uint64
x93, x92 = bits.Mul64(x2, arg2[3])
var x94 uint64
var x95 uint64
x95, x94 = bits.Mul64(x2, arg2[2])
var x96 uint64
var x97 uint64
x97, x96 = bits.Mul64(x2, arg2[1])
var x98 uint64
var x99 uint64
x99, x98 = bits.Mul64(x2, arg2[0])
var x100 uint64
var x101 uint64
x100, x101 = bits.Add64(x99, x96, uint64(0x0))
var x102 uint64
var x103 uint64
x102, x103 = bits.Add64(x97, x94, uint64(p224Uint1(x101)))
var x104 uint64
var x105 uint64
x104, x105 = bits.Add64(x95, x92, uint64(p224Uint1(x103)))
x106 := (uint64(p224Uint1(x105)) + x93)
var x107 uint64
var x108 uint64
x107, x108 = bits.Add64(x83, x98, uint64(0x0))
var x109 uint64
var x110 uint64
x109, x110 = bits.Add64(x85, x100, uint64(p224Uint1(x108)))
var x111 uint64
var x112 uint64
x111, x112 = bits.Add64(x87, x102, uint64(p224Uint1(x110)))
var x113 uint64
var x114 uint64
x113, x114 = bits.Add64(x89, x104, uint64(p224Uint1(x112)))
var x115 uint64
var x116 uint64
x115, x116 = bits.Add64(x91, x106, uint64(p224Uint1(x114)))
var x117 uint64
_, x117 = bits.Mul64(x107, 0xffffffffffffffff)
var x119 uint64
var x120 uint64
x120, x119 = bits.Mul64(x117, 0xffffffff)
var x121 uint64
var x122 uint64
x122, x121 = bits.Mul64(x117, 0xffffffffffffffff)
var x123 uint64
var x124 uint64
x124, x123 = bits.Mul64(x117, 0xffffffff00000000)
var x125 uint64
var x126 uint64
x125, x126 = bits.Add64(x124, x121, uint64(0x0))
var x127 uint64
var x128 uint64
x127, x128 = bits.Add64(x122, x119, uint64(p224Uint1(x126)))
x129 := (uint64(p224Uint1(x128)) + x120)
var x131 uint64
_, x131 = bits.Add64(x107, x117, uint64(0x0))
var x132 uint64
var x133 uint64
x132, x133 = bits.Add64(x109, x123, uint64(p224Uint1(x131)))
var x134 uint64
var x135 uint64
x134, x135 = bits.Add64(x111, x125, uint64(p224Uint1(x133)))
var x136 uint64
var x137 uint64
x136, x137 = bits.Add64(x113, x127, uint64(p224Uint1(x135)))
var x138 uint64
var x139 uint64
x138, x139 = bits.Add64(x115, x129, uint64(p224Uint1(x137)))
x140 := (uint64(p224Uint1(x139)) + uint64(p224Uint1(x116)))
var x141 uint64
var x142 uint64
x142, x141 = bits.Mul64(x3, arg2[3])
var x143 uint64
var x144 uint64
x144, x143 = bits.Mul64(x3, arg2[2])
var x145 uint64
var x146 uint64
x146, x145 = bits.Mul64(x3, arg2[1])
var x147 uint64
var x148 uint64
x148, x147 = bits.Mul64(x3, arg2[0])
var x149 uint64
var x150 uint64
x149, x150 = bits.Add64(x148, x145, uint64(0x0))
var x151 uint64
var x152 uint64
x151, x152 = bits.Add64(x146, x143, uint64(p224Uint1(x150)))
var x153 uint64
var x154 uint64
x153, x154 = bits.Add64(x144, x141, uint64(p224Uint1(x152)))
x155 := (uint64(p224Uint1(x154)) + x142)
var x156 uint64
var x157 uint64
x156, x157 = bits.Add64(x132, x147, uint64(0x0))
var x158 uint64
var x159 uint64
x158, x159 = bits.Add64(x134, x149, uint64(p224Uint1(x157)))
var x160 uint64
var x161 uint64
x160, x161 = bits.Add64(x136, x151, uint64(p224Uint1(x159)))
var x162 uint64
var x163 uint64
x162, x163 = bits.Add64(x138, x153, uint64(p224Uint1(x161)))
var x164 uint64
var x165 uint64
x164, x165 = bits.Add64(x140, x155, uint64(p224Uint1(x163)))
var x166 uint64
_, x166 = bits.Mul64(x156, 0xffffffffffffffff)
var x168 uint64
var x169 uint64
x169, x168 = bits.Mul64(x166, 0xffffffff)
var x170 uint64
var x171 uint64
x171, x170 = bits.Mul64(x166, 0xffffffffffffffff)
var x172 uint64
var x173 uint64
x173, x172 = bits.Mul64(x166, 0xffffffff00000000)
var x174 uint64
var x175 uint64
x174, x175 = bits.Add64(x173, x170, uint64(0x0))
var x176 uint64
var x177 uint64
x176, x177 = bits.Add64(x171, x168, uint64(p224Uint1(x175)))
x178 := (uint64(p224Uint1(x177)) + x169)
var x180 uint64
_, x180 = bits.Add64(x156, x166, uint64(0x0))
var x181 uint64
var x182 uint64
x181, x182 = bits.Add64(x158, x172, uint64(p224Uint1(x180)))
var x183 uint64
var x184 uint64
x183, x184 = bits.Add64(x160, x174, uint64(p224Uint1(x182)))
var x185 uint64
var x186 uint64
x185, x186 = bits.Add64(x162, x176, uint64(p224Uint1(x184)))
var x187 uint64
var x188 uint64
x187, x188 = bits.Add64(x164, x178, uint64(p224Uint1(x186)))
x189 := (uint64(p224Uint1(x188)) + uint64(p224Uint1(x165)))
var x190 uint64
var x191 uint64
x190, x191 = bits.Sub64(x181, uint64(0x1), uint64(0x0))
var x192 uint64
var x193 uint64
x192, x193 = bits.Sub64(x183, 0xffffffff00000000, uint64(p224Uint1(x191)))
var x194 uint64
var x195 uint64
x194, x195 = bits.Sub64(x185, 0xffffffffffffffff, uint64(p224Uint1(x193)))
var x196 uint64
var x197 uint64
x196, x197 = bits.Sub64(x187, 0xffffffff, uint64(p224Uint1(x195)))
var x199 uint64
_, x199 = bits.Sub64(x189, uint64(0x0), uint64(p224Uint1(x197)))
var x200 uint64
p224CmovznzU64(&x200, p224Uint1(x199), x190, x181)
var x201 uint64
p224CmovznzU64(&x201, p224Uint1(x199), x192, x183)
var x202 uint64
p224CmovznzU64(&x202, p224Uint1(x199), x194, x185)
var x203 uint64
p224CmovznzU64(&x203, p224Uint1(x199), x196, x187)
out1[0] = x200
out1[1] = x201
out1[2] = x202
out1[3] = x203
}
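// Reference sketch (hypothetical, assuming "math/big" is imported): each of
// the four rounds above is one step of word-by-word Montgomery multiplication.
// Because m ≡ 1 (mod 2^64), m' = -m^-1 mod 2^64 is 0xffffffffffffffff, which
// is why each round derives its reduction factor as the low word of the
// accumulator times 0xffffffffffffffff. On Montgomery-domain inputs the
// function computes a*b*R^-1 mod m for R = 2^256, the same value as:
func exampleP224MontMulRef(a, b, m *big.Int) *big.Int {
	r := new(big.Int).Lsh(big.NewInt(1), 256)
	rInv := new(big.Int).ModInverse(r, m)
	out := new(big.Int).Mul(a, b)
	out.Mul(out, rInv)
	return out.Mod(out, m)
}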
// p224Square squares a field element in the Montgomery domain.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
//
// Postconditions:
//
// eval (from_montgomery out1) mod m = (eval (from_montgomery arg1) * eval (from_montgomery arg1)) mod m
// 0 ≤ eval out1 < m
func p224Square(out1 *p224MontgomeryDomainFieldElement, arg1 *p224MontgomeryDomainFieldElement) {
x1 := arg1[1]
x2 := arg1[2]
x3 := arg1[3]
x4 := arg1[0]
var x5 uint64
var x6 uint64
x6, x5 = bits.Mul64(x4, arg1[3])
var x7 uint64
var x8 uint64
x8, x7 = bits.Mul64(x4, arg1[2])
var x9 uint64
var x10 uint64
x10, x9 = bits.Mul64(x4, arg1[1])
var x11 uint64
var x12 uint64
x12, x11 = bits.Mul64(x4, arg1[0])
var x13 uint64
var x14 uint64
x13, x14 = bits.Add64(x12, x9, uint64(0x0))
var x15 uint64
var x16 uint64
x15, x16 = bits.Add64(x10, x7, uint64(p224Uint1(x14)))
var x17 uint64
var x18 uint64
x17, x18 = bits.Add64(x8, x5, uint64(p224Uint1(x16)))
x19 := (uint64(p224Uint1(x18)) + x6)
var x20 uint64
_, x20 = bits.Mul64(x11, 0xffffffffffffffff)
var x22 uint64
var x23 uint64
x23, x22 = bits.Mul64(x20, 0xffffffff)
var x24 uint64
var x25 uint64
x25, x24 = bits.Mul64(x20, 0xffffffffffffffff)
var x26 uint64
var x27 uint64
x27, x26 = bits.Mul64(x20, 0xffffffff00000000)
var x28 uint64
var x29 uint64
x28, x29 = bits.Add64(x27, x24, uint64(0x0))
var x30 uint64
var x31 uint64
x30, x31 = bits.Add64(x25, x22, uint64(p224Uint1(x29)))
x32 := (uint64(p224Uint1(x31)) + x23)
var x34 uint64
_, x34 = bits.Add64(x11, x20, uint64(0x0))
var x35 uint64
var x36 uint64
x35, x36 = bits.Add64(x13, x26, uint64(p224Uint1(x34)))
var x37 uint64
var x38 uint64
x37, x38 = bits.Add64(x15, x28, uint64(p224Uint1(x36)))
var x39 uint64
var x40 uint64
x39, x40 = bits.Add64(x17, x30, uint64(p224Uint1(x38)))
var x41 uint64
var x42 uint64
x41, x42 = bits.Add64(x19, x32, uint64(p224Uint1(x40)))
var x43 uint64
var x44 uint64
x44, x43 = bits.Mul64(x1, arg1[3])
var x45 uint64
var x46 uint64
x46, x45 = bits.Mul64(x1, arg1[2])
var x47 uint64
var x48 uint64
x48, x47 = bits.Mul64(x1, arg1[1])
var x49 uint64
var x50 uint64
x50, x49 = bits.Mul64(x1, arg1[0])
var x51 uint64
var x52 uint64
x51, x52 = bits.Add64(x50, x47, uint64(0x0))
var x53 uint64
var x54 uint64
x53, x54 = bits.Add64(x48, x45, uint64(p224Uint1(x52)))
var x55 uint64
var x56 uint64
x55, x56 = bits.Add64(x46, x43, uint64(p224Uint1(x54)))
x57 := (uint64(p224Uint1(x56)) + x44)
var x58 uint64
var x59 uint64
x58, x59 = bits.Add64(x35, x49, uint64(0x0))
var x60 uint64
var x61 uint64
x60, x61 = bits.Add64(x37, x51, uint64(p224Uint1(x59)))
var x62 uint64
var x63 uint64
x62, x63 = bits.Add64(x39, x53, uint64(p224Uint1(x61)))
var x64 uint64
var x65 uint64
x64, x65 = bits.Add64(x41, x55, uint64(p224Uint1(x63)))
var x66 uint64
var x67 uint64
x66, x67 = bits.Add64(uint64(p224Uint1(x42)), x57, uint64(p224Uint1(x65)))
var x68 uint64
_, x68 = bits.Mul64(x58, 0xffffffffffffffff)
var x70 uint64
var x71 uint64
x71, x70 = bits.Mul64(x68, 0xffffffff)
var x72 uint64
var x73 uint64
x73, x72 = bits.Mul64(x68, 0xffffffffffffffff)
var x74 uint64
var x75 uint64
x75, x74 = bits.Mul64(x68, 0xffffffff00000000)
var x76 uint64
var x77 uint64
x76, x77 = bits.Add64(x75, x72, uint64(0x0))
var x78 uint64
var x79 uint64
x78, x79 = bits.Add64(x73, x70, uint64(p224Uint1(x77)))
x80 := (uint64(p224Uint1(x79)) + x71)
var x82 uint64
_, x82 = bits.Add64(x58, x68, uint64(0x0))
var x83 uint64
var x84 uint64
x83, x84 = bits.Add64(x60, x74, uint64(p224Uint1(x82)))
var x85 uint64
var x86 uint64
x85, x86 = bits.Add64(x62, x76, uint64(p224Uint1(x84)))
var x87 uint64
var x88 uint64
x87, x88 = bits.Add64(x64, x78, uint64(p224Uint1(x86)))
var x89 uint64
var x90 uint64
x89, x90 = bits.Add64(x66, x80, uint64(p224Uint1(x88)))
x91 := (uint64(p224Uint1(x90)) + uint64(p224Uint1(x67)))
var x92 uint64
var x93 uint64
x93, x92 = bits.Mul64(x2, arg1[3])
var x94 uint64
var x95 uint64
x95, x94 = bits.Mul64(x2, arg1[2])
var x96 uint64
var x97 uint64
x97, x96 = bits.Mul64(x2, arg1[1])
var x98 uint64
var x99 uint64
x99, x98 = bits.Mul64(x2, arg1[0])
var x100 uint64
var x101 uint64
x100, x101 = bits.Add64(x99, x96, uint64(0x0))
var x102 uint64
var x103 uint64
x102, x103 = bits.Add64(x97, x94, uint64(p224Uint1(x101)))
var x104 uint64
var x105 uint64
x104, x105 = bits.Add64(x95, x92, uint64(p224Uint1(x103)))
x106 := (uint64(p224Uint1(x105)) + x93)
var x107 uint64
var x108 uint64
x107, x108 = bits.Add64(x83, x98, uint64(0x0))
var x109 uint64
var x110 uint64
x109, x110 = bits.Add64(x85, x100, uint64(p224Uint1(x108)))
var x111 uint64
var x112 uint64
x111, x112 = bits.Add64(x87, x102, uint64(p224Uint1(x110)))
var x113 uint64
var x114 uint64
x113, x114 = bits.Add64(x89, x104, uint64(p224Uint1(x112)))
var x115 uint64
var x116 uint64
x115, x116 = bits.Add64(x91, x106, uint64(p224Uint1(x114)))
var x117 uint64
_, x117 = bits.Mul64(x107, 0xffffffffffffffff)
var x119 uint64
var x120 uint64
x120, x119 = bits.Mul64(x117, 0xffffffff)
var x121 uint64
var x122 uint64
x122, x121 = bits.Mul64(x117, 0xffffffffffffffff)
var x123 uint64
var x124 uint64
x124, x123 = bits.Mul64(x117, 0xffffffff00000000)
var x125 uint64
var x126 uint64
x125, x126 = bits.Add64(x124, x121, uint64(0x0))
var x127 uint64
var x128 uint64
x127, x128 = bits.Add64(x122, x119, uint64(p224Uint1(x126)))
x129 := (uint64(p224Uint1(x128)) + x120)
var x131 uint64
_, x131 = bits.Add64(x107, x117, uint64(0x0))
var x132 uint64
var x133 uint64
x132, x133 = bits.Add64(x109, x123, uint64(p224Uint1(x131)))
var x134 uint64
var x135 uint64
x134, x135 = bits.Add64(x111, x125, uint64(p224Uint1(x133)))
var x136 uint64
var x137 uint64
x136, x137 = bits.Add64(x113, x127, uint64(p224Uint1(x135)))
var x138 uint64
var x139 uint64
x138, x139 = bits.Add64(x115, x129, uint64(p224Uint1(x137)))
x140 := (uint64(p224Uint1(x139)) + uint64(p224Uint1(x116)))
var x141 uint64
var x142 uint64
x142, x141 = bits.Mul64(x3, arg1[3])
var x143 uint64
var x144 uint64
x144, x143 = bits.Mul64(x3, arg1[2])
var x145 uint64
var x146 uint64
x146, x145 = bits.Mul64(x3, arg1[1])
var x147 uint64
var x148 uint64
x148, x147 = bits.Mul64(x3, arg1[0])
var x149 uint64
var x150 uint64
x149, x150 = bits.Add64(x148, x145, uint64(0x0))
var x151 uint64
var x152 uint64
x151, x152 = bits.Add64(x146, x143, uint64(p224Uint1(x150)))
var x153 uint64
var x154 uint64
x153, x154 = bits.Add64(x144, x141, uint64(p224Uint1(x152)))
x155 := (uint64(p224Uint1(x154)) + x142)
var x156 uint64
var x157 uint64
x156, x157 = bits.Add64(x132, x147, uint64(0x0))
var x158 uint64
var x159 uint64
x158, x159 = bits.Add64(x134, x149, uint64(p224Uint1(x157)))
var x160 uint64
var x161 uint64
x160, x161 = bits.Add64(x136, x151, uint64(p224Uint1(x159)))
var x162 uint64
var x163 uint64
x162, x163 = bits.Add64(x138, x153, uint64(p224Uint1(x161)))
var x164 uint64
var x165 uint64
x164, x165 = bits.Add64(x140, x155, uint64(p224Uint1(x163)))
var x166 uint64
_, x166 = bits.Mul64(x156, 0xffffffffffffffff)
var x168 uint64
var x169 uint64
x169, x168 = bits.Mul64(x166, 0xffffffff)
var x170 uint64
var x171 uint64
x171, x170 = bits.Mul64(x166, 0xffffffffffffffff)
var x172 uint64
var x173 uint64
x173, x172 = bits.Mul64(x166, 0xffffffff00000000)
var x174 uint64
var x175 uint64
x174, x175 = bits.Add64(x173, x170, uint64(0x0))
var x176 uint64
var x177 uint64
x176, x177 = bits.Add64(x171, x168, uint64(p224Uint1(x175)))
x178 := (uint64(p224Uint1(x177)) + x169)
var x180 uint64
_, x180 = bits.Add64(x156, x166, uint64(0x0))
var x181 uint64
var x182 uint64
x181, x182 = bits.Add64(x158, x172, uint64(p224Uint1(x180)))
var x183 uint64
var x184 uint64
x183, x184 = bits.Add64(x160, x174, uint64(p224Uint1(x182)))
var x185 uint64
var x186 uint64
x185, x186 = bits.Add64(x162, x176, uint64(p224Uint1(x184)))
var x187 uint64
var x188 uint64
x187, x188 = bits.Add64(x164, x178, uint64(p224Uint1(x186)))
x189 := (uint64(p224Uint1(x188)) + uint64(p224Uint1(x165)))
var x190 uint64
var x191 uint64
x190, x191 = bits.Sub64(x181, uint64(0x1), uint64(0x0))
var x192 uint64
var x193 uint64
x192, x193 = bits.Sub64(x183, 0xffffffff00000000, uint64(p224Uint1(x191)))
var x194 uint64
var x195 uint64
x194, x195 = bits.Sub64(x185, 0xffffffffffffffff, uint64(p224Uint1(x193)))
var x196 uint64
var x197 uint64
x196, x197 = bits.Sub64(x187, 0xffffffff, uint64(p224Uint1(x195)))
var x199 uint64
_, x199 = bits.Sub64(x189, uint64(0x0), uint64(p224Uint1(x197)))
var x200 uint64
p224CmovznzU64(&x200, p224Uint1(x199), x190, x181)
var x201 uint64
p224CmovznzU64(&x201, p224Uint1(x199), x192, x183)
var x202 uint64
p224CmovznzU64(&x202, p224Uint1(x199), x194, x185)
var x203 uint64
p224CmovznzU64(&x203, p224Uint1(x199), x196, x187)
out1[0] = x200
out1[1] = x201
out1[2] = x202
out1[3] = x203
}
// p224Add adds two field elements in the Montgomery domain.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
// 0 ≤ eval arg2 < m
//
// Postconditions:
//
// eval (from_montgomery out1) mod m = (eval (from_montgomery arg1) + eval (from_montgomery arg2)) mod m
// 0 ≤ eval out1 < m
func p224Add(out1 *p224MontgomeryDomainFieldElement, arg1 *p224MontgomeryDomainFieldElement, arg2 *p224MontgomeryDomainFieldElement) {
var x1 uint64
var x2 uint64
x1, x2 = bits.Add64(arg1[0], arg2[0], uint64(0x0))
var x3 uint64
var x4 uint64
x3, x4 = bits.Add64(arg1[1], arg2[1], uint64(p224Uint1(x2)))
var x5 uint64
var x6 uint64
x5, x6 = bits.Add64(arg1[2], arg2[2], uint64(p224Uint1(x4)))
var x7 uint64
var x8 uint64
x7, x8 = bits.Add64(arg1[3], arg2[3], uint64(p224Uint1(x6)))
var x9 uint64
var x10 uint64
x9, x10 = bits.Sub64(x1, uint64(0x1), uint64(0x0))
var x11 uint64
var x12 uint64
x11, x12 = bits.Sub64(x3, 0xffffffff00000000, uint64(p224Uint1(x10)))
var x13 uint64
var x14 uint64
x13, x14 = bits.Sub64(x5, 0xffffffffffffffff, uint64(p224Uint1(x12)))
var x15 uint64
var x16 uint64
x15, x16 = bits.Sub64(x7, 0xffffffff, uint64(p224Uint1(x14)))
var x18 uint64
_, x18 = bits.Sub64(uint64(p224Uint1(x8)), uint64(0x0), uint64(p224Uint1(x16)))
var x19 uint64
p224CmovznzU64(&x19, p224Uint1(x18), x9, x1)
var x20 uint64
p224CmovznzU64(&x20, p224Uint1(x18), x11, x3)
var x21 uint64
p224CmovznzU64(&x21, p224Uint1(x18), x13, x5)
var x22 uint64
p224CmovznzU64(&x22, p224Uint1(x18), x15, x7)
out1[0] = x19
out1[1] = x20
out1[2] = x21
out1[3] = x22
}
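// Reference sketch (hypothetical, assuming "math/big" is imported): the tail
// of p224Add is "add, then conditionally subtract m". The Sub64 chain
// subtracts the little-endian limbs of m (0x1, 0xffffffff00000000,
// 0xffffffffffffffff, 0xffffffff); if that borrows past the carry bit, the
// sum was already below m and the cmovs keep it, otherwise they keep the
// reduced value:
func exampleP224AddRef(a, b, m *big.Int) *big.Int {
	s := new(big.Int).Add(a, b)
	if s.Cmp(m) >= 0 { // done branch-free above via p224CmovznzU64
		s.Sub(s, m)
	}
	return s
}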
// p224Sub subtracts two field elements in the Montgomery domain.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
// 0 ≤ eval arg2 < m
//
// Postconditions:
//
// eval (from_montgomery out1) mod m = (eval (from_montgomery arg1) - eval (from_montgomery arg2)) mod m
// 0 ≤ eval out1 < m
func p224Sub(out1 *p224MontgomeryDomainFieldElement, arg1 *p224MontgomeryDomainFieldElement, arg2 *p224MontgomeryDomainFieldElement) {
var x1 uint64
var x2 uint64
x1, x2 = bits.Sub64(arg1[0], arg2[0], uint64(0x0))
var x3 uint64
var x4 uint64
x3, x4 = bits.Sub64(arg1[1], arg2[1], uint64(p224Uint1(x2)))
var x5 uint64
var x6 uint64
x5, x6 = bits.Sub64(arg1[2], arg2[2], uint64(p224Uint1(x4)))
var x7 uint64
var x8 uint64
x7, x8 = bits.Sub64(arg1[3], arg2[3], uint64(p224Uint1(x6)))
var x9 uint64
p224CmovznzU64(&x9, p224Uint1(x8), uint64(0x0), 0xffffffffffffffff)
var x10 uint64
var x11 uint64
x10, x11 = bits.Add64(x1, uint64((p224Uint1(x9) & 0x1)), uint64(0x0))
var x12 uint64
var x13 uint64
x12, x13 = bits.Add64(x3, (x9 & 0xffffffff00000000), uint64(p224Uint1(x11)))
var x14 uint64
var x15 uint64
x14, x15 = bits.Add64(x5, x9, uint64(p224Uint1(x13)))
var x16 uint64
x16, _ = bits.Add64(x7, (x9 & 0xffffffff), uint64(p224Uint1(x15)))
out1[0] = x10
out1[1] = x12
out1[2] = x14
out1[3] = x16
}
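// Sketch (hypothetical, assuming "math/big" is imported): here the borrow out
// of the limb-wise subtraction is turned into an all-ones mask x9, and m AND
// mask is added back, i.e. "subtract, then conditionally add m":
func exampleP224SubRef(a, b, m *big.Int) *big.Int {
	d := new(big.Int).Sub(a, b)
	if d.Sign() < 0 { // done branch-free above via the x9 mask
		d.Add(d, m)
	}
	return d
}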
// p224SetOne returns the field element one in the Montgomery domain.
//
// Postconditions:
//
// eval (from_montgomery out1) mod m = 1 mod m
// 0 ≤ eval out1 < m
func p224SetOne(out1 *p224MontgomeryDomainFieldElement) {
out1[0] = 0xffffffff00000000
out1[1] = 0xffffffffffffffff
out1[2] = uint64(0x0)
out1[3] = uint64(0x0)
}
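// Verification sketch (hypothetical, assuming "math/big" is imported): the
// limbs above encode R mod m for the Montgomery radix R = 2^256. Indeed
// 2^256 mod (2^224 - 2^96 + 1) = 2^128 - 2^32, whose little-endian 64-bit
// limbs are [0xffffffff00000000, 0xffffffffffffffff, 0, 0]:
func exampleP224OneIsRModM() bool {
	pow2 := func(n uint) *big.Int { return new(big.Int).Lsh(big.NewInt(1), n) }
	m := new(big.Int).Add(new(big.Int).Sub(pow2(224), pow2(96)), big.NewInt(1))
	got := new(big.Int).Mod(pow2(256), m)
	want := new(big.Int).Sub(pow2(128), pow2(32))
	return got.Cmp(want) == 0
}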
// p224FromMontgomery translates a field element out of the Montgomery domain.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
//
// Postconditions:
//
// eval out1 mod m = (eval arg1 * ((2^64)⁻¹ mod m)^4) mod m
// 0 ≤ eval out1 < m
func p224FromMontgomery(out1 *p224NonMontgomeryDomainFieldElement, arg1 *p224MontgomeryDomainFieldElement) {
x1 := arg1[0]
var x2 uint64
_, x2 = bits.Mul64(x1, 0xffffffffffffffff)
var x4 uint64
var x5 uint64
x5, x4 = bits.Mul64(x2, 0xffffffff)
var x6 uint64
var x7 uint64
x7, x6 = bits.Mul64(x2, 0xffffffffffffffff)
var x8 uint64
var x9 uint64
x9, x8 = bits.Mul64(x2, 0xffffffff00000000)
var x10 uint64
var x11 uint64
x10, x11 = bits.Add64(x9, x6, uint64(0x0))
var x12 uint64
var x13 uint64
x12, x13 = bits.Add64(x7, x4, uint64(p224Uint1(x11)))
var x15 uint64
_, x15 = bits.Add64(x1, x2, uint64(0x0))
var x16 uint64
var x17 uint64
x16, x17 = bits.Add64(uint64(0x0), x8, uint64(p224Uint1(x15)))
var x18 uint64
var x19 uint64
x18, x19 = bits.Add64(uint64(0x0), x10, uint64(p224Uint1(x17)))
var x20 uint64
var x21 uint64
x20, x21 = bits.Add64(uint64(0x0), x12, uint64(p224Uint1(x19)))
var x22 uint64
var x23 uint64
x22, x23 = bits.Add64(x16, arg1[1], uint64(0x0))
var x24 uint64
var x25 uint64
x24, x25 = bits.Add64(x18, uint64(0x0), uint64(p224Uint1(x23)))
var x26 uint64
var x27 uint64
x26, x27 = bits.Add64(x20, uint64(0x0), uint64(p224Uint1(x25)))
var x28 uint64
_, x28 = bits.Mul64(x22, 0xffffffffffffffff)
var x30 uint64
var x31 uint64
x31, x30 = bits.Mul64(x28, 0xffffffff)
var x32 uint64
var x33 uint64
x33, x32 = bits.Mul64(x28, 0xffffffffffffffff)
var x34 uint64
var x35 uint64
x35, x34 = bits.Mul64(x28, 0xffffffff00000000)
var x36 uint64
var x37 uint64
x36, x37 = bits.Add64(x35, x32, uint64(0x0))
var x38 uint64
var x39 uint64
x38, x39 = bits.Add64(x33, x30, uint64(p224Uint1(x37)))
var x41 uint64
_, x41 = bits.Add64(x22, x28, uint64(0x0))
var x42 uint64
var x43 uint64
x42, x43 = bits.Add64(x24, x34, uint64(p224Uint1(x41)))
var x44 uint64
var x45 uint64
x44, x45 = bits.Add64(x26, x36, uint64(p224Uint1(x43)))
var x46 uint64
var x47 uint64
x46, x47 = bits.Add64((uint64(p224Uint1(x27)) + (uint64(p224Uint1(x21)) + (uint64(p224Uint1(x13)) + x5))), x38, uint64(p224Uint1(x45)))
var x48 uint64
var x49 uint64
x48, x49 = bits.Add64(x42, arg1[2], uint64(0x0))
var x50 uint64
var x51 uint64
x50, x51 = bits.Add64(x44, uint64(0x0), uint64(p224Uint1(x49)))
var x52 uint64
var x53 uint64
x52, x53 = bits.Add64(x46, uint64(0x0), uint64(p224Uint1(x51)))
var x54 uint64
_, x54 = bits.Mul64(x48, 0xffffffffffffffff)
var x56 uint64
var x57 uint64
x57, x56 = bits.Mul64(x54, 0xffffffff)
var x58 uint64
var x59 uint64
x59, x58 = bits.Mul64(x54, 0xffffffffffffffff)
var x60 uint64
var x61 uint64
x61, x60 = bits.Mul64(x54, 0xffffffff00000000)
var x62 uint64
var x63 uint64
x62, x63 = bits.Add64(x61, x58, uint64(0x0))
var x64 uint64
var x65 uint64
x64, x65 = bits.Add64(x59, x56, uint64(p224Uint1(x63)))
var x67 uint64
_, x67 = bits.Add64(x48, x54, uint64(0x0))
var x68 uint64
var x69 uint64
x68, x69 = bits.Add64(x50, x60, uint64(p224Uint1(x67)))
var x70 uint64
var x71 uint64
x70, x71 = bits.Add64(x52, x62, uint64(p224Uint1(x69)))
var x72 uint64
var x73 uint64
x72, x73 = bits.Add64((uint64(p224Uint1(x53)) + (uint64(p224Uint1(x47)) + (uint64(p224Uint1(x39)) + x31))), x64, uint64(p224Uint1(x71)))
var x74 uint64
var x75 uint64
x74, x75 = bits.Add64(x68, arg1[3], uint64(0x0))
var x76 uint64
var x77 uint64
x76, x77 = bits.Add64(x70, uint64(0x0), uint64(p224Uint1(x75)))
var x78 uint64
var x79 uint64
x78, x79 = bits.Add64(x72, uint64(0x0), uint64(p224Uint1(x77)))
var x80 uint64
_, x80 = bits.Mul64(x74, 0xffffffffffffffff)
var x82 uint64
var x83 uint64
x83, x82 = bits.Mul64(x80, 0xffffffff)
var x84 uint64
var x85 uint64
x85, x84 = bits.Mul64(x80, 0xffffffffffffffff)
var x86 uint64
var x87 uint64
x87, x86 = bits.Mul64(x80, 0xffffffff00000000)
var x88 uint64
var x89 uint64
x88, x89 = bits.Add64(x87, x84, uint64(0x0))
var x90 uint64
var x91 uint64
x90, x91 = bits.Add64(x85, x82, uint64(p224Uint1(x89)))
var x93 uint64
_, x93 = bits.Add64(x74, x80, uint64(0x0))
var x94 uint64
var x95 uint64
x94, x95 = bits.Add64(x76, x86, uint64(p224Uint1(x93)))
var x96 uint64
var x97 uint64
x96, x97 = bits.Add64(x78, x88, uint64(p224Uint1(x95)))
var x98 uint64
var x99 uint64
x98, x99 = bits.Add64((uint64(p224Uint1(x79)) + (uint64(p224Uint1(x73)) + (uint64(p224Uint1(x65)) + x57))), x90, uint64(p224Uint1(x97)))
x100 := (uint64(p224Uint1(x99)) + (uint64(p224Uint1(x91)) + x83))
var x101 uint64
var x102 uint64
x101, x102 = bits.Sub64(x94, uint64(0x1), uint64(0x0))
var x103 uint64
var x104 uint64
x103, x104 = bits.Sub64(x96, 0xffffffff00000000, uint64(p224Uint1(x102)))
var x105 uint64
var x106 uint64
x105, x106 = bits.Sub64(x98, 0xffffffffffffffff, uint64(p224Uint1(x104)))
var x107 uint64
var x108 uint64
x107, x108 = bits.Sub64(x100, 0xffffffff, uint64(p224Uint1(x106)))
var x110 uint64
_, x110 = bits.Sub64(uint64(0x0), uint64(0x0), uint64(p224Uint1(x108)))
var x111 uint64
p224CmovznzU64(&x111, p224Uint1(x110), x101, x94)
var x112 uint64
p224CmovznzU64(&x112, p224Uint1(x110), x103, x96)
var x113 uint64
p224CmovznzU64(&x113, p224Uint1(x110), x105, x98)
var x114 uint64
p224CmovznzU64(&x114, p224Uint1(x110), x107, x100)
out1[0] = x111
out1[1] = x112
out1[2] = x113
out1[3] = x114
}
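// Note (sketch, hypothetical, assuming "math/big" is imported):
// p224FromMontgomery is a Montgomery multiplication by 1, i.e. four reduction
// rounds with no multiply, so it computes arg1 * R^-1 mod m, matching the
// postcondition eval out1 = eval arg1 * ((2^64)^-1)^4 mod m:
func exampleP224FromMontgomeryRef(a, m *big.Int) *big.Int {
	rInv := new(big.Int).ModInverse(new(big.Int).Lsh(big.NewInt(1), 256), m)
	return new(big.Int).Mod(new(big.Int).Mul(a, rInv), m)
}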
// p224ToMontgomery translates a field element into the Montgomery domain.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
//
// Postconditions:
//
// eval (from_montgomery out1) mod m = eval arg1 mod m
// 0 ≤ eval out1 < m
func p224ToMontgomery(out1 *p224MontgomeryDomainFieldElement, arg1 *p224NonMontgomeryDomainFieldElement) {
x1 := arg1[1]
x2 := arg1[2]
x3 := arg1[3]
x4 := arg1[0]
var x5 uint64
var x6 uint64
x6, x5 = bits.Mul64(x4, 0xffffffff)
var x7 uint64
var x8 uint64
x8, x7 = bits.Mul64(x4, 0xfffffffe00000000)
var x9 uint64
var x10 uint64
x10, x9 = bits.Mul64(x4, 0xffffffff00000000)
var x11 uint64
var x12 uint64
x12, x11 = bits.Mul64(x4, 0xffffffff00000001)
var x13 uint64
var x14 uint64
x13, x14 = bits.Add64(x12, x9, uint64(0x0))
var x15 uint64
var x16 uint64
x15, x16 = bits.Add64(x10, x7, uint64(p224Uint1(x14)))
var x17 uint64
var x18 uint64
x17, x18 = bits.Add64(x8, x5, uint64(p224Uint1(x16)))
var x19 uint64
_, x19 = bits.Mul64(x11, 0xffffffffffffffff)
var x21 uint64
var x22 uint64
x22, x21 = bits.Mul64(x19, 0xffffffff)
var x23 uint64
var x24 uint64
x24, x23 = bits.Mul64(x19, 0xffffffffffffffff)
var x25 uint64
var x26 uint64
x26, x25 = bits.Mul64(x19, 0xffffffff00000000)
var x27 uint64
var x28 uint64
x27, x28 = bits.Add64(x26, x23, uint64(0x0))
var x29 uint64
var x30 uint64
x29, x30 = bits.Add64(x24, x21, uint64(p224Uint1(x28)))
var x32 uint64
_, x32 = bits.Add64(x11, x19, uint64(0x0))
var x33 uint64
var x34 uint64
x33, x34 = bits.Add64(x13, x25, uint64(p224Uint1(x32)))
var x35 uint64
var x36 uint64
x35, x36 = bits.Add64(x15, x27, uint64(p224Uint1(x34)))
var x37 uint64
var x38 uint64
x37, x38 = bits.Add64(x17, x29, uint64(p224Uint1(x36)))
var x39 uint64
var x40 uint64
x40, x39 = bits.Mul64(x1, 0xffffffff)
var x41 uint64
var x42 uint64
x42, x41 = bits.Mul64(x1, 0xfffffffe00000000)
var x43 uint64
var x44 uint64
x44, x43 = bits.Mul64(x1, 0xffffffff00000000)
var x45 uint64
var x46 uint64
x46, x45 = bits.Mul64(x1, 0xffffffff00000001)
var x47 uint64
var x48 uint64
x47, x48 = bits.Add64(x46, x43, uint64(0x0))
var x49 uint64
var x50 uint64
x49, x50 = bits.Add64(x44, x41, uint64(p224Uint1(x48)))
var x51 uint64
var x52 uint64
x51, x52 = bits.Add64(x42, x39, uint64(p224Uint1(x50)))
var x53 uint64
var x54 uint64
x53, x54 = bits.Add64(x33, x45, uint64(0x0))
var x55 uint64
var x56 uint64
x55, x56 = bits.Add64(x35, x47, uint64(p224Uint1(x54)))
var x57 uint64
var x58 uint64
x57, x58 = bits.Add64(x37, x49, uint64(p224Uint1(x56)))
var x59 uint64
var x60 uint64
x59, x60 = bits.Add64(((uint64(p224Uint1(x38)) + (uint64(p224Uint1(x18)) + x6)) + (uint64(p224Uint1(x30)) + x22)), x51, uint64(p224Uint1(x58)))
var x61 uint64
_, x61 = bits.Mul64(x53, 0xffffffffffffffff)
var x63 uint64
var x64 uint64
x64, x63 = bits.Mul64(x61, 0xffffffff)
var x65 uint64
var x66 uint64
x66, x65 = bits.Mul64(x61, 0xffffffffffffffff)
var x67 uint64
var x68 uint64
x68, x67 = bits.Mul64(x61, 0xffffffff00000000)
var x69 uint64
var x70 uint64
x69, x70 = bits.Add64(x68, x65, uint64(0x0))
var x71 uint64
var x72 uint64
x71, x72 = bits.Add64(x66, x63, uint64(p224Uint1(x70)))
var x74 uint64
_, x74 = bits.Add64(x53, x61, uint64(0x0))
var x75 uint64
var x76 uint64
x75, x76 = bits.Add64(x55, x67, uint64(p224Uint1(x74)))
var x77 uint64
var x78 uint64
x77, x78 = bits.Add64(x57, x69, uint64(p224Uint1(x76)))
var x79 uint64
var x80 uint64
x79, x80 = bits.Add64(x59, x71, uint64(p224Uint1(x78)))
var x81 uint64
var x82 uint64
x82, x81 = bits.Mul64(x2, 0xffffffff)
var x83 uint64
var x84 uint64
x84, x83 = bits.Mul64(x2, 0xfffffffe00000000)
var x85 uint64
var x86 uint64
x86, x85 = bits.Mul64(x2, 0xffffffff00000000)
var x87 uint64
var x88 uint64
x88, x87 = bits.Mul64(x2, 0xffffffff00000001)
var x89 uint64
var x90 uint64
x89, x90 = bits.Add64(x88, x85, uint64(0x0))
var x91 uint64
var x92 uint64
x91, x92 = bits.Add64(x86, x83, uint64(p224Uint1(x90)))
var x93 uint64
var x94 uint64
x93, x94 = bits.Add64(x84, x81, uint64(p224Uint1(x92)))
var x95 uint64
var x96 uint64
x95, x96 = bits.Add64(x75, x87, uint64(0x0))
var x97 uint64
var x98 uint64
x97, x98 = bits.Add64(x77, x89, uint64(p224Uint1(x96)))
var x99 uint64
var x100 uint64
x99, x100 = bits.Add64(x79, x91, uint64(p224Uint1(x98)))
var x101 uint64
var x102 uint64
x101, x102 = bits.Add64(((uint64(p224Uint1(x80)) + (uint64(p224Uint1(x60)) + (uint64(p224Uint1(x52)) + x40))) + (uint64(p224Uint1(x72)) + x64)), x93, uint64(p224Uint1(x100)))
var x103 uint64
_, x103 = bits.Mul64(x95, 0xffffffffffffffff)
var x105 uint64
var x106 uint64
x106, x105 = bits.Mul64(x103, 0xffffffff)
var x107 uint64
var x108 uint64
x108, x107 = bits.Mul64(x103, 0xffffffffffffffff)
var x109 uint64
var x110 uint64
x110, x109 = bits.Mul64(x103, 0xffffffff00000000)
var x111 uint64
var x112 uint64
x111, x112 = bits.Add64(x110, x107, uint64(0x0))
var x113 uint64
var x114 uint64
x113, x114 = bits.Add64(x108, x105, uint64(p224Uint1(x112)))
var x116 uint64
_, x116 = bits.Add64(x95, x103, uint64(0x0))
var x117 uint64
var x118 uint64
x117, x118 = bits.Add64(x97, x109, uint64(p224Uint1(x116)))
var x119 uint64
var x120 uint64
x119, x120 = bits.Add64(x99, x111, uint64(p224Uint1(x118)))
var x121 uint64
var x122 uint64
x121, x122 = bits.Add64(x101, x113, uint64(p224Uint1(x120)))
var x123 uint64
var x124 uint64
x124, x123 = bits.Mul64(x3, 0xffffffff)
var x125 uint64
var x126 uint64
x126, x125 = bits.Mul64(x3, 0xfffffffe00000000)
var x127 uint64
var x128 uint64
x128, x127 = bits.Mul64(x3, 0xffffffff00000000)
var x129 uint64
var x130 uint64
x130, x129 = bits.Mul64(x3, 0xffffffff00000001)
var x131 uint64
var x132 uint64
x131, x132 = bits.Add64(x130, x127, uint64(0x0))
var x133 uint64
var x134 uint64
x133, x134 = bits.Add64(x128, x125, uint64(p224Uint1(x132)))
var x135 uint64
var x136 uint64
x135, x136 = bits.Add64(x126, x123, uint64(p224Uint1(x134)))
var x137 uint64
var x138 uint64
x137, x138 = bits.Add64(x117, x129, uint64(0x0))
var x139 uint64
var x140 uint64
x139, x140 = bits.Add64(x119, x131, uint64(p224Uint1(x138)))
var x141 uint64
var x142 uint64
x141, x142 = bits.Add64(x121, x133, uint64(p224Uint1(x140)))
var x143 uint64
var x144 uint64
x143, x144 = bits.Add64(((uint64(p224Uint1(x122)) + (uint64(p224Uint1(x102)) + (uint64(p224Uint1(x94)) + x82))) + (uint64(p224Uint1(x114)) + x106)), x135, uint64(p224Uint1(x142)))
var x145 uint64
_, x145 = bits.Mul64(x137, 0xffffffffffffffff)
var x147 uint64
var x148 uint64
x148, x147 = bits.Mul64(x145, 0xffffffff)
var x149 uint64
var x150 uint64
x150, x149 = bits.Mul64(x145, 0xffffffffffffffff)
var x151 uint64
var x152 uint64
x152, x151 = bits.Mul64(x145, 0xffffffff00000000)
var x153 uint64
var x154 uint64
x153, x154 = bits.Add64(x152, x149, uint64(0x0))
var x155 uint64
var x156 uint64
x155, x156 = bits.Add64(x150, x147, uint64(p224Uint1(x154)))
var x158 uint64
_, x158 = bits.Add64(x137, x145, uint64(0x0))
var x159 uint64
var x160 uint64
x159, x160 = bits.Add64(x139, x151, uint64(p224Uint1(x158)))
var x161 uint64
var x162 uint64
x161, x162 = bits.Add64(x141, x153, uint64(p224Uint1(x160)))
var x163 uint64
var x164 uint64
x163, x164 = bits.Add64(x143, x155, uint64(p224Uint1(x162)))
x165 := ((uint64(p224Uint1(x164)) + (uint64(p224Uint1(x144)) + (uint64(p224Uint1(x136)) + x124))) + (uint64(p224Uint1(x156)) + x148))
var x166 uint64
var x167 uint64
x166, x167 = bits.Sub64(x159, uint64(0x1), uint64(0x0))
var x168 uint64
var x169 uint64
x168, x169 = bits.Sub64(x161, 0xffffffff00000000, uint64(p224Uint1(x167)))
var x170 uint64
var x171 uint64
x170, x171 = bits.Sub64(x163, 0xffffffffffffffff, uint64(p224Uint1(x169)))
var x172 uint64
var x173 uint64
x172, x173 = bits.Sub64(x165, 0xffffffff, uint64(p224Uint1(x171)))
var x175 uint64
_, x175 = bits.Sub64(uint64(0x0), uint64(0x0), uint64(p224Uint1(x173)))
var x176 uint64
p224CmovznzU64(&x176, p224Uint1(x175), x166, x159)
var x177 uint64
p224CmovznzU64(&x177, p224Uint1(x175), x168, x161)
var x178 uint64
p224CmovznzU64(&x178, p224Uint1(x175), x170, x163)
var x179 uint64
p224CmovznzU64(&x179, p224Uint1(x175), x172, x165)
out1[0] = x176
out1[1] = x177
out1[2] = x178
out1[3] = x179
}
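// Note (sketch, hypothetical, assuming "math/big" is imported): translating
// into the Montgomery domain is a Montgomery multiplication by R^2 mod m; the
// per-limb constants above (0xffffffff00000001, 0xffffffff00000000,
// 0xfffffffe00000000, 0xffffffff, little-endian) are the limbs of
// 2^512 mod m, so the result is arg1 * R^2 * R^-1 = arg1 * R mod m:
func exampleP224ToMontgomeryRef(a, m *big.Int) *big.Int {
	r := new(big.Int).Lsh(big.NewInt(1), 256)
	return new(big.Int).Mod(new(big.Int).Mul(a, r), m)
}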
// p224Selectznz is a multi-limb conditional select.
//
// Postconditions:
//
// eval out1 = (if arg1 = 0 then eval arg2 else eval arg3)
//
// Input Bounds:
//
// arg1: [0x0 ~> 0x1]
// arg2: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff]]
// arg3: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff]]
//
// Output Bounds:
//
// out1: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff]]
func p224Selectznz(out1 *[4]uint64, arg1 p224Uint1, arg2 *[4]uint64, arg3 *[4]uint64) {
var x1 uint64
p224CmovznzU64(&x1, arg1, arg2[0], arg3[0])
var x2 uint64
p224CmovznzU64(&x2, arg1, arg2[1], arg3[1])
var x3 uint64
p224CmovznzU64(&x3, arg1, arg2[2], arg3[2])
var x4 uint64
p224CmovznzU64(&x4, arg1, arg2[3], arg3[3])
out1[0] = x1
out1[1] = x2
out1[2] = x3
out1[3] = x4
}
// p224ToBytes serializes a field element NOT in the Montgomery domain to bytes in little-endian order.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
//
// Postconditions:
//
// out1 = map (λ x, ⌊((eval arg1 mod m) mod 2^(8 * (x + 1))) / 2^(8 * x)⌋) [0..27]
//
// Input Bounds:
//
// arg1: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffff]]
//
// Output Bounds:
//
// out1: [[0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff]]
func p224ToBytes(out1 *[28]uint8, arg1 *[4]uint64) {
x1 := arg1[3]
x2 := arg1[2]
x3 := arg1[1]
x4 := arg1[0]
x5 := (uint8(x4) & 0xff)
x6 := (x4 >> 8)
x7 := (uint8(x6) & 0xff)
x8 := (x6 >> 8)
x9 := (uint8(x8) & 0xff)
x10 := (x8 >> 8)
x11 := (uint8(x10) & 0xff)
x12 := (x10 >> 8)
x13 := (uint8(x12) & 0xff)
x14 := (x12 >> 8)
x15 := (uint8(x14) & 0xff)
x16 := (x14 >> 8)
x17 := (uint8(x16) & 0xff)
x18 := uint8((x16 >> 8))
x19 := (uint8(x3) & 0xff)
x20 := (x3 >> 8)
x21 := (uint8(x20) & 0xff)
x22 := (x20 >> 8)
x23 := (uint8(x22) & 0xff)
x24 := (x22 >> 8)
x25 := (uint8(x24) & 0xff)
x26 := (x24 >> 8)
x27 := (uint8(x26) & 0xff)
x28 := (x26 >> 8)
x29 := (uint8(x28) & 0xff)
x30 := (x28 >> 8)
x31 := (uint8(x30) & 0xff)
x32 := uint8((x30 >> 8))
x33 := (uint8(x2) & 0xff)
x34 := (x2 >> 8)
x35 := (uint8(x34) & 0xff)
x36 := (x34 >> 8)
x37 := (uint8(x36) & 0xff)
x38 := (x36 >> 8)
x39 := (uint8(x38) & 0xff)
x40 := (x38 >> 8)
x41 := (uint8(x40) & 0xff)
x42 := (x40 >> 8)
x43 := (uint8(x42) & 0xff)
x44 := (x42 >> 8)
x45 := (uint8(x44) & 0xff)
x46 := uint8((x44 >> 8))
x47 := (uint8(x1) & 0xff)
x48 := (x1 >> 8)
x49 := (uint8(x48) & 0xff)
x50 := (x48 >> 8)
x51 := (uint8(x50) & 0xff)
x52 := uint8((x50 >> 8))
out1[0] = x5
out1[1] = x7
out1[2] = x9
out1[3] = x11
out1[4] = x13
out1[5] = x15
out1[6] = x17
out1[7] = x18
out1[8] = x19
out1[9] = x21
out1[10] = x23
out1[11] = x25
out1[12] = x27
out1[13] = x29
out1[14] = x31
out1[15] = x32
out1[16] = x33
out1[17] = x35
out1[18] = x37
out1[19] = x39
out1[20] = x41
out1[21] = x43
out1[22] = x45
out1[23] = x46
out1[24] = x47
out1[25] = x49
out1[26] = x51
out1[27] = x52
}
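// Equivalence sketch (hypothetical, assuming "encoding/binary" is imported):
// the unrolled shifts above are a little-endian serialization of the four
// limbs, truncated to 28 bytes (the top limb is bounded by 0xffffffff, so
// bytes 28..31 would be zero anyway):
func exampleP224ToBytesRef(limbs *[4]uint64) [28]uint8 {
	var buf [32]byte
	for i, l := range limbs {
		binary.LittleEndian.PutUint64(buf[8*i:], l)
	}
	var out [28]uint8
	copy(out[:], buf[:28])
	return out
}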
// p224FromBytes deserializes a field element NOT in the Montgomery domain from bytes in little-endian order.
//
// Preconditions:
//
// 0 ≤ bytes_eval arg1 < m
//
// Postconditions:
//
// eval out1 mod m = bytes_eval arg1 mod m
// 0 ≤ eval out1 < m
//
// Input Bounds:
//
// arg1: [[0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff]]
//
// Output Bounds:
//
// out1: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffff]]
func p224FromBytes(out1 *[4]uint64, arg1 *[28]uint8) {
x1 := (uint64(arg1[27]) << 24)
x2 := (uint64(arg1[26]) << 16)
x3 := (uint64(arg1[25]) << 8)
x4 := arg1[24]
x5 := (uint64(arg1[23]) << 56)
x6 := (uint64(arg1[22]) << 48)
x7 := (uint64(arg1[21]) << 40)
x8 := (uint64(arg1[20]) << 32)
x9 := (uint64(arg1[19]) << 24)
x10 := (uint64(arg1[18]) << 16)
x11 := (uint64(arg1[17]) << 8)
x12 := arg1[16]
x13 := (uint64(arg1[15]) << 56)
x14 := (uint64(arg1[14]) << 48)
x15 := (uint64(arg1[13]) << 40)
x16 := (uint64(arg1[12]) << 32)
x17 := (uint64(arg1[11]) << 24)
x18 := (uint64(arg1[10]) << 16)
x19 := (uint64(arg1[9]) << 8)
x20 := arg1[8]
x21 := (uint64(arg1[7]) << 56)
x22 := (uint64(arg1[6]) << 48)
x23 := (uint64(arg1[5]) << 40)
x24 := (uint64(arg1[4]) << 32)
x25 := (uint64(arg1[3]) << 24)
x26 := (uint64(arg1[2]) << 16)
x27 := (uint64(arg1[1]) << 8)
x28 := arg1[0]
x29 := (x27 + uint64(x28))
x30 := (x26 + x29)
x31 := (x25 + x30)
x32 := (x24 + x31)
x33 := (x23 + x32)
x34 := (x22 + x33)
x35 := (x21 + x34)
x36 := (x19 + uint64(x20))
x37 := (x18 + x36)
x38 := (x17 + x37)
x39 := (x16 + x38)
x40 := (x15 + x39)
x41 := (x14 + x40)
x42 := (x13 + x41)
x43 := (x11 + uint64(x12))
x44 := (x10 + x43)
x45 := (x9 + x44)
x46 := (x8 + x45)
x47 := (x7 + x46)
x48 := (x6 + x47)
x49 := (x5 + x48)
x50 := (x3 + uint64(x4))
x51 := (x2 + x50)
x52 := (x1 + x51)
out1[0] = x35
out1[1] = x42
out1[2] = x49
out1[3] = x52
}
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Code generated by addchain. DO NOT EDIT.
package fiat
// Invert sets e = 1/x, and returns e.
//
// If x == 0, Invert returns e = 0.
func (e *P224Element) Invert(x *P224Element) *P224Element {
// Inversion is implemented as exponentiation with exponent p − 2.
// The sequence of 11 multiplications and 223 squarings is derived from the
// following addition chain generated with github.com/mmcloughlin/addchain v0.4.0.
//
// _10 = 2*1
// _11 = 1 + _10
// _110 = 2*_11
// _111 = 1 + _110
// _111000 = _111 << 3
// _111111 = _111 + _111000
// x12 = _111111 << 6 + _111111
// x14 = x12 << 2 + _11
// x17 = x14 << 3 + _111
// x31 = x17 << 14 + x14
// x48 = x31 << 17 + x17
// x96 = x48 << 48 + x48
// x127 = x96 << 31 + x31
// return x127 << 97 + x96
//
var z = new(P224Element).Set(e)
var t0 = new(P224Element)
var t1 = new(P224Element)
var t2 = new(P224Element)
z.Square(x)
t0.Mul(x, z)
z.Square(t0)
z.Mul(x, z)
t1.Square(z)
for s := 1; s < 3; s++ {
t1.Square(t1)
}
t1.Mul(z, t1)
t2.Square(t1)
for s := 1; s < 6; s++ {
t2.Square(t2)
}
t1.Mul(t1, t2)
for s := 0; s < 2; s++ {
t1.Square(t1)
}
t0.Mul(t0, t1)
t1.Square(t0)
for s := 1; s < 3; s++ {
t1.Square(t1)
}
z.Mul(z, t1)
t1.Square(z)
for s := 1; s < 14; s++ {
t1.Square(t1)
}
t0.Mul(t0, t1)
t1.Square(t0)
for s := 1; s < 17; s++ {
t1.Square(t1)
}
z.Mul(z, t1)
t1.Square(z)
for s := 1; s < 48; s++ {
t1.Square(t1)
}
z.Mul(z, t1)
t1.Square(z)
for s := 1; s < 31; s++ {
t1.Square(t1)
}
t0.Mul(t0, t1)
for s := 0; s < 97; s++ {
t0.Square(t0)
}
z.Mul(z, t0)
return e.Set(z)
}
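// Cross-check sketch (hypothetical, assuming "math/big" is imported): by
// Fermat's little theorem x^(p-2) ≡ x^-1 (mod p) for x ≠ 0, so the chain of
// 223 squarings and 11 multiplications above agrees with a direct modular
// exponentiation:
func exampleP224InvertRef(x, p *big.Int) *big.Int {
	return new(big.Int).Exp(x, new(big.Int).Sub(p, big.NewInt(2)), p)
}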
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Code generated by generate.go. DO NOT EDIT.
package fiat
import (
"crypto/subtle"
"errors"
)
// P256Element is an integer modulo 2^256 - 2^224 + 2^192 + 2^96 - 1.
//
// The zero value is a valid zero element.
type P256Element struct {
// Values are always represented internally in the Montgomery domain, and
// converted in Bytes and SetBytes.
x p256MontgomeryDomainFieldElement
}
const p256ElementLen = 32
type p256UntypedFieldElement = [4]uint64
// One sets e = 1, and returns e.
func (e *P256Element) One() *P256Element {
p256SetOne(&e.x)
return e
}
// Equal returns 1 if e == t, and zero otherwise.
func (e *P256Element) Equal(t *P256Element) int {
eBytes := e.Bytes()
tBytes := t.Bytes()
return subtle.ConstantTimeCompare(eBytes, tBytes)
}
// IsZero returns 1 if e == 0, and zero otherwise.
func (e *P256Element) IsZero() int {
zero := make([]byte, p256ElementLen)
eBytes := e.Bytes()
return subtle.ConstantTimeCompare(eBytes, zero)
}
// Set sets e = t, and returns e.
func (e *P256Element) Set(t *P256Element) *P256Element {
e.x = t.x
return e
}
// Bytes returns the 32-byte big-endian encoding of e.
func (e *P256Element) Bytes() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var out [p256ElementLen]byte
return e.bytes(&out)
}
func (e *P256Element) bytes(out *[p256ElementLen]byte) []byte {
var tmp p256NonMontgomeryDomainFieldElement
p256FromMontgomery(&tmp, &e.x)
p256ToBytes(out, (*p256UntypedFieldElement)(&tmp))
p256InvertEndianness(out[:])
return out[:]
}
// SetBytes sets e = v, where v is a big-endian 32-byte encoding, and returns e.
// If v is not 32 bytes or it encodes a value equal to or higher than 2^256 - 2^224 + 2^192 + 2^96 - 1,
// SetBytes returns nil and an error, and e is unchanged.
func (e *P256Element) SetBytes(v []byte) (*P256Element, error) {
if len(v) != p256ElementLen {
return nil, errors.New("invalid P256Element encoding")
}
// Check for non-canonical encodings (p + k, 2p + k, etc.) by comparing to
// the encoding of -1 mod p, so p - 1, the highest canonical encoding.
var minusOneEncoding = new(P256Element).Sub(
new(P256Element), new(P256Element).One()).Bytes()
for i := range v {
if v[i] < minusOneEncoding[i] {
break
}
if v[i] > minusOneEncoding[i] {
return nil, errors.New("invalid P256Element encoding")
}
}
var in [p256ElementLen]byte
copy(in[:], v)
p256InvertEndianness(in[:])
var tmp p256NonMontgomeryDomainFieldElement
p256FromBytes((*p256UntypedFieldElement)(&tmp), &in)
p256ToMontgomery(&e.x, &tmp)
return e, nil
}
// Add sets e = t1 + t2, and returns e.
func (e *P256Element) Add(t1, t2 *P256Element) *P256Element {
p256Add(&e.x, &t1.x, &t2.x)
return e
}
// Sub sets e = t1 - t2, and returns e.
func (e *P256Element) Sub(t1, t2 *P256Element) *P256Element {
p256Sub(&e.x, &t1.x, &t2.x)
return e
}
// Mul sets e = t1 * t2, and returns e.
func (e *P256Element) Mul(t1, t2 *P256Element) *P256Element {
p256Mul(&e.x, &t1.x, &t2.x)
return e
}
// Square sets e = t * t, and returns e.
func (e *P256Element) Square(t *P256Element) *P256Element {
p256Square(&e.x, &t.x)
return e
}
// Select sets v to a if cond == 1, and to b if cond == 0.
func (v *P256Element) Select(a, b *P256Element, cond int) *P256Element {
p256Selectznz((*p256UntypedFieldElement)(&v.x), p256Uint1(cond),
(*p256UntypedFieldElement)(&b.x), (*p256UntypedFieldElement)(&a.x))
return v
}
func p256InvertEndianness(v []byte) {
for i := 0; i < len(v)/2; i++ {
v[i], v[len(v)-1-i] = v[len(v)-1-i], v[i]
}
}
// Code generated by Fiat Cryptography. DO NOT EDIT.
//
// Autogenerated: word_by_word_montgomery --lang Go --no-wide-int --cmovznz-by-mul --relax-primitive-carry-to-bitwidth 32,64 --internal-static --public-function-case camelCase --public-type-case camelCase --private-function-case camelCase --private-type-case camelCase --doc-text-before-function-name '' --doc-newline-before-package-declaration --doc-prepend-header 'Code generated by Fiat Cryptography. DO NOT EDIT.' --package-name fiat --no-prefix-fiat p256 64 '2^256 - 2^224 + 2^192 + 2^96 - 1' mul square add sub one from_montgomery to_montgomery selectznz to_bytes from_bytes
//
// curve description: p256
//
// machine_wordsize = 64 (from "64")
//
// requested operations: mul, square, add, sub, one, from_montgomery, to_montgomery, selectznz, to_bytes, from_bytes
//
// m = 0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff (from "2^256 - 2^224 + 2^192 + 2^96 - 1")
//
//
//
// NOTE: In addition to the bounds specified above each function, all
//
// functions synthesized for this Montgomery arithmetic require the
//
// input to be strictly less than the prime modulus (m), and also
//
// require the input to be in the unique saturated representation.
//
// All functions also ensure that these two properties are true of
//
// return values.
//
//
//
// Computed values:
//
// eval z = z[0] + (z[1] << 64) + (z[2] << 128) + (z[3] << 192)
//
// bytes_eval z = z[0] + (z[1] << 8) + (z[2] << 16) + (z[3] << 24) + (z[4] << 32) + (z[5] << 40) + (z[6] << 48) + (z[7] << 56) + (z[8] << 64) + (z[9] << 72) + (z[10] << 80) + (z[11] << 88) + (z[12] << 96) + (z[13] << 104) + (z[14] << 112) + (z[15] << 120) + (z[16] << 128) + (z[17] << 136) + (z[18] << 144) + (z[19] << 152) + (z[20] << 160) + (z[21] << 168) + (z[22] << 176) + (z[23] << 184) + (z[24] << 192) + (z[25] << 200) + (z[26] << 208) + (z[27] << 216) + (z[28] << 224) + (z[29] << 232) + (z[30] << 240) + (z[31] << 248)
//
// twos_complement_eval z = let x1 := z[0] + (z[1] << 64) + (z[2] << 128) + (z[3] << 192) in
//
// if x1 & (2^256-1) < 2^255 then x1 & (2^256-1) else (x1 & (2^256-1)) - 2^256
package fiat
import "math/bits"
type p256Uint1 uint64 // We use uint64 instead of a more narrow type for performance reasons; see https://github.com/mit-plv/fiat-crypto/pull/1006#issuecomment-892625927
type p256Int1 int64 // We use int64 instead of a more narrow type for performance reasons; see https://github.com/mit-plv/fiat-crypto/pull/1006#issuecomment-892625927
// The type p256MontgomeryDomainFieldElement is a field element in the Montgomery domain.
//
// Bounds: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff]]
type p256MontgomeryDomainFieldElement [4]uint64
// The type p256NonMontgomeryDomainFieldElement is a field element NOT in the Montgomery domain.
//
// Bounds: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff]]
type p256NonMontgomeryDomainFieldElement [4]uint64
// p256CmovznzU64 is a single-word conditional move.
//
// Postconditions:
//
// out1 = (if arg1 = 0 then arg2 else arg3)
//
// Input Bounds:
//
// arg1: [0x0 ~> 0x1]
// arg2: [0x0 ~> 0xffffffffffffffff]
// arg3: [0x0 ~> 0xffffffffffffffff]
//
// Output Bounds:
//
// out1: [0x0 ~> 0xffffffffffffffff]
func p256CmovznzU64(out1 *uint64, arg1 p256Uint1, arg2 uint64, arg3 uint64) {
x1 := (uint64(arg1) * 0xffffffffffffffff)
x2 := ((x1 & arg3) | ((^x1) & arg2))
*out1 = x2
}
// p256Mul multiplies two field elements in the Montgomery domain.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
// 0 ≤ eval arg2 < m
//
// Postconditions:
//
// eval (from_montgomery out1) mod m = (eval (from_montgomery arg1) * eval (from_montgomery arg2)) mod m
// 0 ≤ eval out1 < m
func p256Mul(out1 *p256MontgomeryDomainFieldElement, arg1 *p256MontgomeryDomainFieldElement, arg2 *p256MontgomeryDomainFieldElement) {
x1 := arg1[1]
x2 := arg1[2]
x3 := arg1[3]
x4 := arg1[0]
var x5 uint64
var x6 uint64
x6, x5 = bits.Mul64(x4, arg2[3])
var x7 uint64
var x8 uint64
x8, x7 = bits.Mul64(x4, arg2[2])
var x9 uint64
var x10 uint64
x10, x9 = bits.Mul64(x4, arg2[1])
var x11 uint64
var x12 uint64
x12, x11 = bits.Mul64(x4, arg2[0])
var x13 uint64
var x14 uint64
x13, x14 = bits.Add64(x12, x9, uint64(0x0))
var x15 uint64
var x16 uint64
x15, x16 = bits.Add64(x10, x7, uint64(p256Uint1(x14)))
var x17 uint64
var x18 uint64
x17, x18 = bits.Add64(x8, x5, uint64(p256Uint1(x16)))
x19 := (uint64(p256Uint1(x18)) + x6)
var x20 uint64
var x21 uint64
x21, x20 = bits.Mul64(x11, 0xffffffff00000001)
var x22 uint64
var x23 uint64
x23, x22 = bits.Mul64(x11, 0xffffffff)
var x24 uint64
var x25 uint64
x25, x24 = bits.Mul64(x11, 0xffffffffffffffff)
var x26 uint64
var x27 uint64
x26, x27 = bits.Add64(x25, x22, uint64(0x0))
x28 := (uint64(p256Uint1(x27)) + x23)
var x30 uint64
_, x30 = bits.Add64(x11, x24, uint64(0x0))
var x31 uint64
var x32 uint64
x31, x32 = bits.Add64(x13, x26, uint64(p256Uint1(x30)))
var x33 uint64
var x34 uint64
x33, x34 = bits.Add64(x15, x28, uint64(p256Uint1(x32)))
var x35 uint64
var x36 uint64
x35, x36 = bits.Add64(x17, x20, uint64(p256Uint1(x34)))
var x37 uint64
var x38 uint64
x37, x38 = bits.Add64(x19, x21, uint64(p256Uint1(x36)))
var x39 uint64
var x40 uint64
x40, x39 = bits.Mul64(x1, arg2[3])
var x41 uint64
var x42 uint64
x42, x41 = bits.Mul64(x1, arg2[2])
var x43 uint64
var x44 uint64
x44, x43 = bits.Mul64(x1, arg2[1])
var x45 uint64
var x46 uint64
x46, x45 = bits.Mul64(x1, arg2[0])
var x47 uint64
var x48 uint64
x47, x48 = bits.Add64(x46, x43, uint64(0x0))
var x49 uint64
var x50 uint64
x49, x50 = bits.Add64(x44, x41, uint64(p256Uint1(x48)))
var x51 uint64
var x52 uint64
x51, x52 = bits.Add64(x42, x39, uint64(p256Uint1(x50)))
x53 := (uint64(p256Uint1(x52)) + x40)
var x54 uint64
var x55 uint64
x54, x55 = bits.Add64(x31, x45, uint64(0x0))
var x56 uint64
var x57 uint64
x56, x57 = bits.Add64(x33, x47, uint64(p256Uint1(x55)))
var x58 uint64
var x59 uint64
x58, x59 = bits.Add64(x35, x49, uint64(p256Uint1(x57)))
var x60 uint64
var x61 uint64
x60, x61 = bits.Add64(x37, x51, uint64(p256Uint1(x59)))
var x62 uint64
var x63 uint64
x62, x63 = bits.Add64(uint64(p256Uint1(x38)), x53, uint64(p256Uint1(x61)))
var x64 uint64
var x65 uint64
x65, x64 = bits.Mul64(x54, 0xffffffff00000001)
var x66 uint64
var x67 uint64
x67, x66 = bits.Mul64(x54, 0xffffffff)
var x68 uint64
var x69 uint64
x69, x68 = bits.Mul64(x54, 0xffffffffffffffff)
var x70 uint64
var x71 uint64
x70, x71 = bits.Add64(x69, x66, uint64(0x0))
x72 := (uint64(p256Uint1(x71)) + x67)
var x74 uint64
_, x74 = bits.Add64(x54, x68, uint64(0x0))
var x75 uint64
var x76 uint64
x75, x76 = bits.Add64(x56, x70, uint64(p256Uint1(x74)))
var x77 uint64
var x78 uint64
x77, x78 = bits.Add64(x58, x72, uint64(p256Uint1(x76)))
var x79 uint64
var x80 uint64
x79, x80 = bits.Add64(x60, x64, uint64(p256Uint1(x78)))
var x81 uint64
var x82 uint64
x81, x82 = bits.Add64(x62, x65, uint64(p256Uint1(x80)))
x83 := (uint64(p256Uint1(x82)) + uint64(p256Uint1(x63)))
var x84 uint64
var x85 uint64
x85, x84 = bits.Mul64(x2, arg2[3])
var x86 uint64
var x87 uint64
x87, x86 = bits.Mul64(x2, arg2[2])
var x88 uint64
var x89 uint64
x89, x88 = bits.Mul64(x2, arg2[1])
var x90 uint64
var x91 uint64
x91, x90 = bits.Mul64(x2, arg2[0])
var x92 uint64
var x93 uint64
x92, x93 = bits.Add64(x91, x88, uint64(0x0))
var x94 uint64
var x95 uint64
x94, x95 = bits.Add64(x89, x86, uint64(p256Uint1(x93)))
var x96 uint64
var x97 uint64
x96, x97 = bits.Add64(x87, x84, uint64(p256Uint1(x95)))
x98 := (uint64(p256Uint1(x97)) + x85)
var x99 uint64
var x100 uint64
x99, x100 = bits.Add64(x75, x90, uint64(0x0))
var x101 uint64
var x102 uint64
x101, x102 = bits.Add64(x77, x92, uint64(p256Uint1(x100)))
var x103 uint64
var x104 uint64
x103, x104 = bits.Add64(x79, x94, uint64(p256Uint1(x102)))
var x105 uint64
var x106 uint64
x105, x106 = bits.Add64(x81, x96, uint64(p256Uint1(x104)))
var x107 uint64
var x108 uint64
x107, x108 = bits.Add64(x83, x98, uint64(p256Uint1(x106)))
var x109 uint64
var x110 uint64
x110, x109 = bits.Mul64(x99, 0xffffffff00000001)
var x111 uint64
var x112 uint64
x112, x111 = bits.Mul64(x99, 0xffffffff)
var x113 uint64
var x114 uint64
x114, x113 = bits.Mul64(x99, 0xffffffffffffffff)
var x115 uint64
var x116 uint64
x115, x116 = bits.Add64(x114, x111, uint64(0x0))
x117 := (uint64(p256Uint1(x116)) + x112)
var x119 uint64
_, x119 = bits.Add64(x99, x113, uint64(0x0))
var x120 uint64
var x121 uint64
x120, x121 = bits.Add64(x101, x115, uint64(p256Uint1(x119)))
var x122 uint64
var x123 uint64
x122, x123 = bits.Add64(x103, x117, uint64(p256Uint1(x121)))
var x124 uint64
var x125 uint64
x124, x125 = bits.Add64(x105, x109, uint64(p256Uint1(x123)))
var x126 uint64
var x127 uint64
x126, x127 = bits.Add64(x107, x110, uint64(p256Uint1(x125)))
x128 := (uint64(p256Uint1(x127)) + uint64(p256Uint1(x108)))
var x129 uint64
var x130 uint64
x130, x129 = bits.Mul64(x3, arg2[3])
var x131 uint64
var x132 uint64
x132, x131 = bits.Mul64(x3, arg2[2])
var x133 uint64
var x134 uint64
x134, x133 = bits.Mul64(x3, arg2[1])
var x135 uint64
var x136 uint64
x136, x135 = bits.Mul64(x3, arg2[0])
var x137 uint64
var x138 uint64
x137, x138 = bits.Add64(x136, x133, uint64(0x0))
var x139 uint64
var x140 uint64
x139, x140 = bits.Add64(x134, x131, uint64(p256Uint1(x138)))
var x141 uint64
var x142 uint64
x141, x142 = bits.Add64(x132, x129, uint64(p256Uint1(x140)))
x143 := (uint64(p256Uint1(x142)) + x130)
var x144 uint64
var x145 uint64
x144, x145 = bits.Add64(x120, x135, uint64(0x0))
var x146 uint64
var x147 uint64
x146, x147 = bits.Add64(x122, x137, uint64(p256Uint1(x145)))
var x148 uint64
var x149 uint64
x148, x149 = bits.Add64(x124, x139, uint64(p256Uint1(x147)))
var x150 uint64
var x151 uint64
x150, x151 = bits.Add64(x126, x141, uint64(p256Uint1(x149)))
var x152 uint64
var x153 uint64
x152, x153 = bits.Add64(x128, x143, uint64(p256Uint1(x151)))
var x154 uint64
var x155 uint64
x155, x154 = bits.Mul64(x144, 0xffffffff00000001)
var x156 uint64
var x157 uint64
x157, x156 = bits.Mul64(x144, 0xffffffff)
var x158 uint64
var x159 uint64
x159, x158 = bits.Mul64(x144, 0xffffffffffffffff)
var x160 uint64
var x161 uint64
x160, x161 = bits.Add64(x159, x156, uint64(0x0))
x162 := (uint64(p256Uint1(x161)) + x157)
var x164 uint64
_, x164 = bits.Add64(x144, x158, uint64(0x0))
var x165 uint64
var x166 uint64
x165, x166 = bits.Add64(x146, x160, uint64(p256Uint1(x164)))
var x167 uint64
var x168 uint64
x167, x168 = bits.Add64(x148, x162, uint64(p256Uint1(x166)))
var x169 uint64
var x170 uint64
x169, x170 = bits.Add64(x150, x154, uint64(p256Uint1(x168)))
var x171 uint64
var x172 uint64
x171, x172 = bits.Add64(x152, x155, uint64(p256Uint1(x170)))
x173 := (uint64(p256Uint1(x172)) + uint64(p256Uint1(x153)))
var x174 uint64
var x175 uint64
x174, x175 = bits.Sub64(x165, 0xffffffffffffffff, uint64(0x0))
var x176 uint64
var x177 uint64
x176, x177 = bits.Sub64(x167, 0xffffffff, uint64(p256Uint1(x175)))
var x178 uint64
var x179 uint64
x178, x179 = bits.Sub64(x169, uint64(0x0), uint64(p256Uint1(x177)))
var x180 uint64
var x181 uint64
x180, x181 = bits.Sub64(x171, 0xffffffff00000001, uint64(p256Uint1(x179)))
var x183 uint64
_, x183 = bits.Sub64(x173, uint64(0x0), uint64(p256Uint1(x181)))
var x184 uint64
p256CmovznzU64(&x184, p256Uint1(x183), x174, x165)
var x185 uint64
p256CmovznzU64(&x185, p256Uint1(x183), x176, x167)
var x186 uint64
p256CmovznzU64(&x186, p256Uint1(x183), x178, x169)
var x187 uint64
p256CmovznzU64(&x187, p256Uint1(x183), x180, x171)
out1[0] = x184
out1[1] = x185
out1[2] = x186
out1[3] = x187
}
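// Editorial note: p256Mul is a word-by-word Montgomery multiplication.
// Each of the four rounds multiplies one limb of arg1 by arg2, then adds
// a multiple of m chosen to zero the lowest limb so the running sum can
// drop one word. The reduction constants 0xffffffffffffffff, 0xffffffff,
// 0 and 0xffffffff00000001 are the little-endian limbs of m itself, and
// because m's low limb is 2^64 - 1 we have -m^-1 = 1 (mod 2^64), so the
// low limb of each round (x11, x54, x99, x144) serves directly as the
// reduction multiplier. The final Sub64/cmovznz sequence conditionally
// subtracts m once, bringing the result below m in constant time.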
// p256Square squares a field element in the Montgomery domain.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
//
// Postconditions:
//
// eval (from_montgomery out1) mod m = (eval (from_montgomery arg1) * eval (from_montgomery arg1)) mod m
// 0 ≤ eval out1 < m
func p256Square(out1 *p256MontgomeryDomainFieldElement, arg1 *p256MontgomeryDomainFieldElement) {
x1 := arg1[1]
x2 := arg1[2]
x3 := arg1[3]
x4 := arg1[0]
var x5 uint64
var x6 uint64
x6, x5 = bits.Mul64(x4, arg1[3])
var x7 uint64
var x8 uint64
x8, x7 = bits.Mul64(x4, arg1[2])
var x9 uint64
var x10 uint64
x10, x9 = bits.Mul64(x4, arg1[1])
var x11 uint64
var x12 uint64
x12, x11 = bits.Mul64(x4, arg1[0])
var x13 uint64
var x14 uint64
x13, x14 = bits.Add64(x12, x9, uint64(0x0))
var x15 uint64
var x16 uint64
x15, x16 = bits.Add64(x10, x7, uint64(p256Uint1(x14)))
var x17 uint64
var x18 uint64
x17, x18 = bits.Add64(x8, x5, uint64(p256Uint1(x16)))
x19 := (uint64(p256Uint1(x18)) + x6)
var x20 uint64
var x21 uint64
x21, x20 = bits.Mul64(x11, 0xffffffff00000001)
var x22 uint64
var x23 uint64
x23, x22 = bits.Mul64(x11, 0xffffffff)
var x24 uint64
var x25 uint64
x25, x24 = bits.Mul64(x11, 0xffffffffffffffff)
var x26 uint64
var x27 uint64
x26, x27 = bits.Add64(x25, x22, uint64(0x0))
x28 := (uint64(p256Uint1(x27)) + x23)
var x30 uint64
_, x30 = bits.Add64(x11, x24, uint64(0x0))
var x31 uint64
var x32 uint64
x31, x32 = bits.Add64(x13, x26, uint64(p256Uint1(x30)))
var x33 uint64
var x34 uint64
x33, x34 = bits.Add64(x15, x28, uint64(p256Uint1(x32)))
var x35 uint64
var x36 uint64
x35, x36 = bits.Add64(x17, x20, uint64(p256Uint1(x34)))
var x37 uint64
var x38 uint64
x37, x38 = bits.Add64(x19, x21, uint64(p256Uint1(x36)))
var x39 uint64
var x40 uint64
x40, x39 = bits.Mul64(x1, arg1[3])
var x41 uint64
var x42 uint64
x42, x41 = bits.Mul64(x1, arg1[2])
var x43 uint64
var x44 uint64
x44, x43 = bits.Mul64(x1, arg1[1])
var x45 uint64
var x46 uint64
x46, x45 = bits.Mul64(x1, arg1[0])
var x47 uint64
var x48 uint64
x47, x48 = bits.Add64(x46, x43, uint64(0x0))
var x49 uint64
var x50 uint64
x49, x50 = bits.Add64(x44, x41, uint64(p256Uint1(x48)))
var x51 uint64
var x52 uint64
x51, x52 = bits.Add64(x42, x39, uint64(p256Uint1(x50)))
x53 := (uint64(p256Uint1(x52)) + x40)
var x54 uint64
var x55 uint64
x54, x55 = bits.Add64(x31, x45, uint64(0x0))
var x56 uint64
var x57 uint64
x56, x57 = bits.Add64(x33, x47, uint64(p256Uint1(x55)))
var x58 uint64
var x59 uint64
x58, x59 = bits.Add64(x35, x49, uint64(p256Uint1(x57)))
var x60 uint64
var x61 uint64
x60, x61 = bits.Add64(x37, x51, uint64(p256Uint1(x59)))
var x62 uint64
var x63 uint64
x62, x63 = bits.Add64(uint64(p256Uint1(x38)), x53, uint64(p256Uint1(x61)))
var x64 uint64
var x65 uint64
x65, x64 = bits.Mul64(x54, 0xffffffff00000001)
var x66 uint64
var x67 uint64
x67, x66 = bits.Mul64(x54, 0xffffffff)
var x68 uint64
var x69 uint64
x69, x68 = bits.Mul64(x54, 0xffffffffffffffff)
var x70 uint64
var x71 uint64
x70, x71 = bits.Add64(x69, x66, uint64(0x0))
x72 := (uint64(p256Uint1(x71)) + x67)
var x74 uint64
_, x74 = bits.Add64(x54, x68, uint64(0x0))
var x75 uint64
var x76 uint64
x75, x76 = bits.Add64(x56, x70, uint64(p256Uint1(x74)))
var x77 uint64
var x78 uint64
x77, x78 = bits.Add64(x58, x72, uint64(p256Uint1(x76)))
var x79 uint64
var x80 uint64
x79, x80 = bits.Add64(x60, x64, uint64(p256Uint1(x78)))
var x81 uint64
var x82 uint64
x81, x82 = bits.Add64(x62, x65, uint64(p256Uint1(x80)))
x83 := (uint64(p256Uint1(x82)) + uint64(p256Uint1(x63)))
var x84 uint64
var x85 uint64
x85, x84 = bits.Mul64(x2, arg1[3])
var x86 uint64
var x87 uint64
x87, x86 = bits.Mul64(x2, arg1[2])
var x88 uint64
var x89 uint64
x89, x88 = bits.Mul64(x2, arg1[1])
var x90 uint64
var x91 uint64
x91, x90 = bits.Mul64(x2, arg1[0])
var x92 uint64
var x93 uint64
x92, x93 = bits.Add64(x91, x88, uint64(0x0))
var x94 uint64
var x95 uint64
x94, x95 = bits.Add64(x89, x86, uint64(p256Uint1(x93)))
var x96 uint64
var x97 uint64
x96, x97 = bits.Add64(x87, x84, uint64(p256Uint1(x95)))
x98 := (uint64(p256Uint1(x97)) + x85)
var x99 uint64
var x100 uint64
x99, x100 = bits.Add64(x75, x90, uint64(0x0))
var x101 uint64
var x102 uint64
x101, x102 = bits.Add64(x77, x92, uint64(p256Uint1(x100)))
var x103 uint64
var x104 uint64
x103, x104 = bits.Add64(x79, x94, uint64(p256Uint1(x102)))
var x105 uint64
var x106 uint64
x105, x106 = bits.Add64(x81, x96, uint64(p256Uint1(x104)))
var x107 uint64
var x108 uint64
x107, x108 = bits.Add64(x83, x98, uint64(p256Uint1(x106)))
var x109 uint64
var x110 uint64
x110, x109 = bits.Mul64(x99, 0xffffffff00000001)
var x111 uint64
var x112 uint64
x112, x111 = bits.Mul64(x99, 0xffffffff)
var x113 uint64
var x114 uint64
x114, x113 = bits.Mul64(x99, 0xffffffffffffffff)
var x115 uint64
var x116 uint64
x115, x116 = bits.Add64(x114, x111, uint64(0x0))
x117 := (uint64(p256Uint1(x116)) + x112)
var x119 uint64
_, x119 = bits.Add64(x99, x113, uint64(0x0))
var x120 uint64
var x121 uint64
x120, x121 = bits.Add64(x101, x115, uint64(p256Uint1(x119)))
var x122 uint64
var x123 uint64
x122, x123 = bits.Add64(x103, x117, uint64(p256Uint1(x121)))
var x124 uint64
var x125 uint64
x124, x125 = bits.Add64(x105, x109, uint64(p256Uint1(x123)))
var x126 uint64
var x127 uint64
x126, x127 = bits.Add64(x107, x110, uint64(p256Uint1(x125)))
x128 := (uint64(p256Uint1(x127)) + uint64(p256Uint1(x108)))
var x129 uint64
var x130 uint64
x130, x129 = bits.Mul64(x3, arg1[3])
var x131 uint64
var x132 uint64
x132, x131 = bits.Mul64(x3, arg1[2])
var x133 uint64
var x134 uint64
x134, x133 = bits.Mul64(x3, arg1[1])
var x135 uint64
var x136 uint64
x136, x135 = bits.Mul64(x3, arg1[0])
var x137 uint64
var x138 uint64
x137, x138 = bits.Add64(x136, x133, uint64(0x0))
var x139 uint64
var x140 uint64
x139, x140 = bits.Add64(x134, x131, uint64(p256Uint1(x138)))
var x141 uint64
var x142 uint64
x141, x142 = bits.Add64(x132, x129, uint64(p256Uint1(x140)))
x143 := (uint64(p256Uint1(x142)) + x130)
var x144 uint64
var x145 uint64
x144, x145 = bits.Add64(x120, x135, uint64(0x0))
var x146 uint64
var x147 uint64
x146, x147 = bits.Add64(x122, x137, uint64(p256Uint1(x145)))
var x148 uint64
var x149 uint64
x148, x149 = bits.Add64(x124, x139, uint64(p256Uint1(x147)))
var x150 uint64
var x151 uint64
x150, x151 = bits.Add64(x126, x141, uint64(p256Uint1(x149)))
var x152 uint64
var x153 uint64
x152, x153 = bits.Add64(x128, x143, uint64(p256Uint1(x151)))
var x154 uint64
var x155 uint64
x155, x154 = bits.Mul64(x144, 0xffffffff00000001)
var x156 uint64
var x157 uint64
x157, x156 = bits.Mul64(x144, 0xffffffff)
var x158 uint64
var x159 uint64
x159, x158 = bits.Mul64(x144, 0xffffffffffffffff)
var x160 uint64
var x161 uint64
x160, x161 = bits.Add64(x159, x156, uint64(0x0))
x162 := (uint64(p256Uint1(x161)) + x157)
var x164 uint64
_, x164 = bits.Add64(x144, x158, uint64(0x0))
var x165 uint64
var x166 uint64
x165, x166 = bits.Add64(x146, x160, uint64(p256Uint1(x164)))
var x167 uint64
var x168 uint64
x167, x168 = bits.Add64(x148, x162, uint64(p256Uint1(x166)))
var x169 uint64
var x170 uint64
x169, x170 = bits.Add64(x150, x154, uint64(p256Uint1(x168)))
var x171 uint64
var x172 uint64
x171, x172 = bits.Add64(x152, x155, uint64(p256Uint1(x170)))
x173 := (uint64(p256Uint1(x172)) + uint64(p256Uint1(x153)))
var x174 uint64
var x175 uint64
x174, x175 = bits.Sub64(x165, 0xffffffffffffffff, uint64(0x0))
var x176 uint64
var x177 uint64
x176, x177 = bits.Sub64(x167, 0xffffffff, uint64(p256Uint1(x175)))
var x178 uint64
var x179 uint64
x178, x179 = bits.Sub64(x169, uint64(0x0), uint64(p256Uint1(x177)))
var x180 uint64
var x181 uint64
x180, x181 = bits.Sub64(x171, 0xffffffff00000001, uint64(p256Uint1(x179)))
var x183 uint64
_, x183 = bits.Sub64(x173, uint64(0x0), uint64(p256Uint1(x181)))
var x184 uint64
p256CmovznzU64(&x184, p256Uint1(x183), x174, x165)
var x185 uint64
p256CmovznzU64(&x185, p256Uint1(x183), x176, x167)
var x186 uint64
p256CmovznzU64(&x186, p256Uint1(x183), x178, x169)
var x187 uint64
p256CmovznzU64(&x187, p256Uint1(x183), x180, x171)
out1[0] = x184
out1[1] = x185
out1[2] = x186
out1[3] = x187
}
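// Editorial note: p256Square follows exactly the same schedule as p256Mul
// with both operands equal (arg2 replaced by arg1); the generator emits no
// dedicated squaring shortcut such as sharing cross products.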
// p256Add adds two field elements in the Montgomery domain.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
// 0 ≤ eval arg2 < m
//
// Postconditions:
//
// eval (from_montgomery out1) mod m = (eval (from_montgomery arg1) + eval (from_montgomery arg2)) mod m
// 0 ≤ eval out1 < m
func p256Add(out1 *p256MontgomeryDomainFieldElement, arg1 *p256MontgomeryDomainFieldElement, arg2 *p256MontgomeryDomainFieldElement) {
var x1 uint64
var x2 uint64
x1, x2 = bits.Add64(arg1[0], arg2[0], uint64(0x0))
var x3 uint64
var x4 uint64
x3, x4 = bits.Add64(arg1[1], arg2[1], uint64(p256Uint1(x2)))
var x5 uint64
var x6 uint64
x5, x6 = bits.Add64(arg1[2], arg2[2], uint64(p256Uint1(x4)))
var x7 uint64
var x8 uint64
x7, x8 = bits.Add64(arg1[3], arg2[3], uint64(p256Uint1(x6)))
var x9 uint64
var x10 uint64
x9, x10 = bits.Sub64(x1, 0xffffffffffffffff, uint64(0x0))
var x11 uint64
var x12 uint64
x11, x12 = bits.Sub64(x3, 0xffffffff, uint64(p256Uint1(x10)))
var x13 uint64
var x14 uint64
x13, x14 = bits.Sub64(x5, uint64(0x0), uint64(p256Uint1(x12)))
var x15 uint64
var x16 uint64
x15, x16 = bits.Sub64(x7, 0xffffffff00000001, uint64(p256Uint1(x14)))
var x18 uint64
_, x18 = bits.Sub64(uint64(p256Uint1(x8)), uint64(0x0), uint64(p256Uint1(x16)))
var x19 uint64
p256CmovznzU64(&x19, p256Uint1(x18), x9, x1)
var x20 uint64
p256CmovznzU64(&x20, p256Uint1(x18), x11, x3)
var x21 uint64
p256CmovznzU64(&x21, p256Uint1(x18), x13, x5)
var x22 uint64
p256CmovznzU64(&x22, p256Uint1(x18), x15, x7)
out1[0] = x19
out1[1] = x20
out1[2] = x21
out1[3] = x22
}
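// Editorial note: the addition is a plain four-limb carry chain followed
// by a trial subtraction of m. If that subtraction borrows (x18 = 1) the
// sum was already below m and the pre-subtraction limbs are kept;
// otherwise the reduced limbs are kept. The choice is made with cmovznz,
// so the timing is independent of the operand values.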
// p256Sub subtracts two field elements in the Montgomery domain.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
// 0 ≤ eval arg2 < m
//
// Postconditions:
//
// eval (from_montgomery out1) mod m = (eval (from_montgomery arg1) - eval (from_montgomery arg2)) mod m
// 0 ≤ eval out1 < m
func p256Sub(out1 *p256MontgomeryDomainFieldElement, arg1 *p256MontgomeryDomainFieldElement, arg2 *p256MontgomeryDomainFieldElement) {
var x1 uint64
var x2 uint64
x1, x2 = bits.Sub64(arg1[0], arg2[0], uint64(0x0))
var x3 uint64
var x4 uint64
x3, x4 = bits.Sub64(arg1[1], arg2[1], uint64(p256Uint1(x2)))
var x5 uint64
var x6 uint64
x5, x6 = bits.Sub64(arg1[2], arg2[2], uint64(p256Uint1(x4)))
var x7 uint64
var x8 uint64
x7, x8 = bits.Sub64(arg1[3], arg2[3], uint64(p256Uint1(x6)))
var x9 uint64
p256CmovznzU64(&x9, p256Uint1(x8), uint64(0x0), 0xffffffffffffffff)
var x10 uint64
var x11 uint64
x10, x11 = bits.Add64(x1, x9, uint64(0x0))
var x12 uint64
var x13 uint64
x12, x13 = bits.Add64(x3, (x9 & 0xffffffff), uint64(p256Uint1(x11)))
var x14 uint64
var x15 uint64
x14, x15 = bits.Add64(x5, uint64(0x0), uint64(p256Uint1(x13)))
var x16 uint64
x16, _ = bits.Add64(x7, (x9 & 0xffffffff00000001), uint64(p256Uint1(x15)))
out1[0] = x10
out1[1] = x12
out1[2] = x14
out1[3] = x16
}
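// Editorial note: after the limbwise subtraction, x8 is the final borrow.
// cmovznz turns it into a mask x9 (all ones iff the subtraction went
// negative), and the masked constants x9, x9 & 0xffffffff, 0 and
// x9 & 0xffffffff00000001 are exactly the limbs of m, so one conditional
// addition of m folds the result back into [0, m) without branching.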
// p256SetOne returns the field element one in the Montgomery domain.
//
// Postconditions:
//
// eval (from_montgomery out1) mod m = 1 mod m
// 0 ≤ eval out1 < m
func p256SetOne(out1 *p256MontgomeryDomainFieldElement) {
out1[0] = uint64(0x1)
out1[1] = 0xffffffff00000000
out1[2] = 0xffffffffffffffff
out1[3] = 0xfffffffe
}
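// Editorial note: these limbs are R mod m with R = 2^256, i.e. the
// Montgomery form of 1, namely
// 0x00000000fffffffe_ffffffffffffffff_ffffffff00000000_0000000000000001
// read as little-endian 64-bit limbs.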
// p256FromMontgomery translates a field element out of the Montgomery domain.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
//
// Postconditions:
//
// eval out1 mod m = (eval arg1 * ((2^64)⁻¹ mod m)^4) mod m
// 0 ≤ eval out1 < m
func p256FromMontgomery(out1 *p256NonMontgomeryDomainFieldElement, arg1 *p256MontgomeryDomainFieldElement) {
x1 := arg1[0]
var x2 uint64
var x3 uint64
x3, x2 = bits.Mul64(x1, 0xffffffff00000001)
var x4 uint64
var x5 uint64
x5, x4 = bits.Mul64(x1, 0xffffffff)
var x6 uint64
var x7 uint64
x7, x6 = bits.Mul64(x1, 0xffffffffffffffff)
var x8 uint64
var x9 uint64
x8, x9 = bits.Add64(x7, x4, uint64(0x0))
var x11 uint64
_, x11 = bits.Add64(x1, x6, uint64(0x0))
var x12 uint64
var x13 uint64
x12, x13 = bits.Add64(uint64(0x0), x8, uint64(p256Uint1(x11)))
var x14 uint64
var x15 uint64
x14, x15 = bits.Add64(x12, arg1[1], uint64(0x0))
var x16 uint64
var x17 uint64
x17, x16 = bits.Mul64(x14, 0xffffffff00000001)
var x18 uint64
var x19 uint64
x19, x18 = bits.Mul64(x14, 0xffffffff)
var x20 uint64
var x21 uint64
x21, x20 = bits.Mul64(x14, 0xffffffffffffffff)
var x22 uint64
var x23 uint64
x22, x23 = bits.Add64(x21, x18, uint64(0x0))
var x25 uint64
_, x25 = bits.Add64(x14, x20, uint64(0x0))
var x26 uint64
var x27 uint64
x26, x27 = bits.Add64((uint64(p256Uint1(x15)) + (uint64(p256Uint1(x13)) + (uint64(p256Uint1(x9)) + x5))), x22, uint64(p256Uint1(x25)))
var x28 uint64
var x29 uint64
x28, x29 = bits.Add64(x2, (uint64(p256Uint1(x23)) + x19), uint64(p256Uint1(x27)))
var x30 uint64
var x31 uint64
x30, x31 = bits.Add64(x3, x16, uint64(p256Uint1(x29)))
var x32 uint64
var x33 uint64
x32, x33 = bits.Add64(x26, arg1[2], uint64(0x0))
var x34 uint64
var x35 uint64
x34, x35 = bits.Add64(x28, uint64(0x0), uint64(p256Uint1(x33)))
var x36 uint64
var x37 uint64
x36, x37 = bits.Add64(x30, uint64(0x0), uint64(p256Uint1(x35)))
var x38 uint64
var x39 uint64
x39, x38 = bits.Mul64(x32, 0xffffffff00000001)
var x40 uint64
var x41 uint64
x41, x40 = bits.Mul64(x32, 0xffffffff)
var x42 uint64
var x43 uint64
x43, x42 = bits.Mul64(x32, 0xffffffffffffffff)
var x44 uint64
var x45 uint64
x44, x45 = bits.Add64(x43, x40, uint64(0x0))
var x47 uint64
_, x47 = bits.Add64(x32, x42, uint64(0x0))
var x48 uint64
var x49 uint64
x48, x49 = bits.Add64(x34, x44, uint64(p256Uint1(x47)))
var x50 uint64
var x51 uint64
x50, x51 = bits.Add64(x36, (uint64(p256Uint1(x45)) + x41), uint64(p256Uint1(x49)))
var x52 uint64
var x53 uint64
x52, x53 = bits.Add64((uint64(p256Uint1(x37)) + (uint64(p256Uint1(x31)) + x17)), x38, uint64(p256Uint1(x51)))
var x54 uint64
var x55 uint64
x54, x55 = bits.Add64(x48, arg1[3], uint64(0x0))
var x56 uint64
var x57 uint64
x56, x57 = bits.Add64(x50, uint64(0x0), uint64(p256Uint1(x55)))
var x58 uint64
var x59 uint64
x58, x59 = bits.Add64(x52, uint64(0x0), uint64(p256Uint1(x57)))
var x60 uint64
var x61 uint64
x61, x60 = bits.Mul64(x54, 0xffffffff00000001)
var x62 uint64
var x63 uint64
x63, x62 = bits.Mul64(x54, 0xffffffff)
var x64 uint64
var x65 uint64
x65, x64 = bits.Mul64(x54, 0xffffffffffffffff)
var x66 uint64
var x67 uint64
x66, x67 = bits.Add64(x65, x62, uint64(0x0))
var x69 uint64
_, x69 = bits.Add64(x54, x64, uint64(0x0))
var x70 uint64
var x71 uint64
x70, x71 = bits.Add64(x56, x66, uint64(p256Uint1(x69)))
var x72 uint64
var x73 uint64
x72, x73 = bits.Add64(x58, (uint64(p256Uint1(x67)) + x63), uint64(p256Uint1(x71)))
var x74 uint64
var x75 uint64
x74, x75 = bits.Add64((uint64(p256Uint1(x59)) + (uint64(p256Uint1(x53)) + x39)), x60, uint64(p256Uint1(x73)))
x76 := (uint64(p256Uint1(x75)) + x61)
var x77 uint64
var x78 uint64
x77, x78 = bits.Sub64(x70, 0xffffffffffffffff, uint64(0x0))
var x79 uint64
var x80 uint64
x79, x80 = bits.Sub64(x72, 0xffffffff, uint64(p256Uint1(x78)))
var x81 uint64
var x82 uint64
x81, x82 = bits.Sub64(x74, uint64(0x0), uint64(p256Uint1(x80)))
var x83 uint64
var x84 uint64
x83, x84 = bits.Sub64(x76, 0xffffffff00000001, uint64(p256Uint1(x82)))
var x86 uint64
_, x86 = bits.Sub64(uint64(0x0), uint64(0x0), uint64(p256Uint1(x84)))
var x87 uint64
p256CmovznzU64(&x87, p256Uint1(x86), x77, x70)
var x88 uint64
p256CmovznzU64(&x88, p256Uint1(x86), x79, x72)
var x89 uint64
p256CmovznzU64(&x89, p256Uint1(x86), x81, x74)
var x90 uint64
p256CmovznzU64(&x90, p256Uint1(x86), x83, x76)
out1[0] = x87
out1[1] = x88
out1[2] = x89
out1[3] = x90
}
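// Editorial note: this is a Montgomery multiplication by 1. Four reduction
// rounds each cancel the lowest limb and shift the value down one word,
// dividing by R = 2^256 overall: the output is arg1 * R^-1 mod m, which
// recovers the plain-integer representation.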
// p256ToMontgomery translates a field element into the Montgomery domain.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
//
// Postconditions:
//
// eval (from_montgomery out1) mod m = eval arg1 mod m
// 0 ≤ eval out1 < m
func p256ToMontgomery(out1 *p256MontgomeryDomainFieldElement, arg1 *p256NonMontgomeryDomainFieldElement) {
x1 := arg1[1]
x2 := arg1[2]
x3 := arg1[3]
x4 := arg1[0]
var x5 uint64
var x6 uint64
x6, x5 = bits.Mul64(x4, 0x4fffffffd)
var x7 uint64
var x8 uint64
x8, x7 = bits.Mul64(x4, 0xfffffffffffffffe)
var x9 uint64
var x10 uint64
x10, x9 = bits.Mul64(x4, 0xfffffffbffffffff)
var x11 uint64
var x12 uint64
x12, x11 = bits.Mul64(x4, 0x3)
var x13 uint64
var x14 uint64
x13, x14 = bits.Add64(x12, x9, uint64(0x0))
var x15 uint64
var x16 uint64
x15, x16 = bits.Add64(x10, x7, uint64(p256Uint1(x14)))
var x17 uint64
var x18 uint64
x17, x18 = bits.Add64(x8, x5, uint64(p256Uint1(x16)))
var x19 uint64
var x20 uint64
x20, x19 = bits.Mul64(x11, 0xffffffff00000001)
var x21 uint64
var x22 uint64
x22, x21 = bits.Mul64(x11, 0xffffffff)
var x23 uint64
var x24 uint64
x24, x23 = bits.Mul64(x11, 0xffffffffffffffff)
var x25 uint64
var x26 uint64
x25, x26 = bits.Add64(x24, x21, uint64(0x0))
var x28 uint64
_, x28 = bits.Add64(x11, x23, uint64(0x0))
var x29 uint64
var x30 uint64
x29, x30 = bits.Add64(x13, x25, uint64(p256Uint1(x28)))
var x31 uint64
var x32 uint64
x31, x32 = bits.Add64(x15, (uint64(p256Uint1(x26)) + x22), uint64(p256Uint1(x30)))
var x33 uint64
var x34 uint64
x33, x34 = bits.Add64(x17, x19, uint64(p256Uint1(x32)))
var x35 uint64
var x36 uint64
x35, x36 = bits.Add64((uint64(p256Uint1(x18)) + x6), x20, uint64(p256Uint1(x34)))
var x37 uint64
var x38 uint64
x38, x37 = bits.Mul64(x1, 0x4fffffffd)
var x39 uint64
var x40 uint64
x40, x39 = bits.Mul64(x1, 0xfffffffffffffffe)
var x41 uint64
var x42 uint64
x42, x41 = bits.Mul64(x1, 0xfffffffbffffffff)
var x43 uint64
var x44 uint64
x44, x43 = bits.Mul64(x1, 0x3)
var x45 uint64
var x46 uint64
x45, x46 = bits.Add64(x44, x41, uint64(0x0))
var x47 uint64
var x48 uint64
x47, x48 = bits.Add64(x42, x39, uint64(p256Uint1(x46)))
var x49 uint64
var x50 uint64
x49, x50 = bits.Add64(x40, x37, uint64(p256Uint1(x48)))
var x51 uint64
var x52 uint64
x51, x52 = bits.Add64(x29, x43, uint64(0x0))
var x53 uint64
var x54 uint64
x53, x54 = bits.Add64(x31, x45, uint64(p256Uint1(x52)))
var x55 uint64
var x56 uint64
x55, x56 = bits.Add64(x33, x47, uint64(p256Uint1(x54)))
var x57 uint64
var x58 uint64
x57, x58 = bits.Add64(x35, x49, uint64(p256Uint1(x56)))
var x59 uint64
var x60 uint64
x60, x59 = bits.Mul64(x51, 0xffffffff00000001)
var x61 uint64
var x62 uint64
x62, x61 = bits.Mul64(x51, 0xffffffff)
var x63 uint64
var x64 uint64
x64, x63 = bits.Mul64(x51, 0xffffffffffffffff)
var x65 uint64
var x66 uint64
x65, x66 = bits.Add64(x64, x61, uint64(0x0))
var x68 uint64
_, x68 = bits.Add64(x51, x63, uint64(0x0))
var x69 uint64
var x70 uint64
x69, x70 = bits.Add64(x53, x65, uint64(p256Uint1(x68)))
var x71 uint64
var x72 uint64
x71, x72 = bits.Add64(x55, (uint64(p256Uint1(x66)) + x62), uint64(p256Uint1(x70)))
var x73 uint64
var x74 uint64
x73, x74 = bits.Add64(x57, x59, uint64(p256Uint1(x72)))
var x75 uint64
var x76 uint64
x75, x76 = bits.Add64(((uint64(p256Uint1(x58)) + uint64(p256Uint1(x36))) + (uint64(p256Uint1(x50)) + x38)), x60, uint64(p256Uint1(x74)))
var x77 uint64
var x78 uint64
x78, x77 = bits.Mul64(x2, 0x4fffffffd)
var x79 uint64
var x80 uint64
x80, x79 = bits.Mul64(x2, 0xfffffffffffffffe)
var x81 uint64
var x82 uint64
x82, x81 = bits.Mul64(x2, 0xfffffffbffffffff)
var x83 uint64
var x84 uint64
x84, x83 = bits.Mul64(x2, 0x3)
var x85 uint64
var x86 uint64
x85, x86 = bits.Add64(x84, x81, uint64(0x0))
var x87 uint64
var x88 uint64
x87, x88 = bits.Add64(x82, x79, uint64(p256Uint1(x86)))
var x89 uint64
var x90 uint64
x89, x90 = bits.Add64(x80, x77, uint64(p256Uint1(x88)))
var x91 uint64
var x92 uint64
x91, x92 = bits.Add64(x69, x83, uint64(0x0))
var x93 uint64
var x94 uint64
x93, x94 = bits.Add64(x71, x85, uint64(p256Uint1(x92)))
var x95 uint64
var x96 uint64
x95, x96 = bits.Add64(x73, x87, uint64(p256Uint1(x94)))
var x97 uint64
var x98 uint64
x97, x98 = bits.Add64(x75, x89, uint64(p256Uint1(x96)))
var x99 uint64
var x100 uint64
x100, x99 = bits.Mul64(x91, 0xffffffff00000001)
var x101 uint64
var x102 uint64
x102, x101 = bits.Mul64(x91, 0xffffffff)
var x103 uint64
var x104 uint64
x104, x103 = bits.Mul64(x91, 0xffffffffffffffff)
var x105 uint64
var x106 uint64
x105, x106 = bits.Add64(x104, x101, uint64(0x0))
var x108 uint64
_, x108 = bits.Add64(x91, x103, uint64(0x0))
var x109 uint64
var x110 uint64
x109, x110 = bits.Add64(x93, x105, uint64(p256Uint1(x108)))
var x111 uint64
var x112 uint64
x111, x112 = bits.Add64(x95, (uint64(p256Uint1(x106)) + x102), uint64(p256Uint1(x110)))
var x113 uint64
var x114 uint64
x113, x114 = bits.Add64(x97, x99, uint64(p256Uint1(x112)))
var x115 uint64
var x116 uint64
x115, x116 = bits.Add64(((uint64(p256Uint1(x98)) + uint64(p256Uint1(x76))) + (uint64(p256Uint1(x90)) + x78)), x100, uint64(p256Uint1(x114)))
var x117 uint64
var x118 uint64
x118, x117 = bits.Mul64(x3, 0x4fffffffd)
var x119 uint64
var x120 uint64
x120, x119 = bits.Mul64(x3, 0xfffffffffffffffe)
var x121 uint64
var x122 uint64
x122, x121 = bits.Mul64(x3, 0xfffffffbffffffff)
var x123 uint64
var x124 uint64
x124, x123 = bits.Mul64(x3, 0x3)
var x125 uint64
var x126 uint64
x125, x126 = bits.Add64(x124, x121, uint64(0x0))
var x127 uint64
var x128 uint64
x127, x128 = bits.Add64(x122, x119, uint64(p256Uint1(x126)))
var x129 uint64
var x130 uint64
x129, x130 = bits.Add64(x120, x117, uint64(p256Uint1(x128)))
var x131 uint64
var x132 uint64
x131, x132 = bits.Add64(x109, x123, uint64(0x0))
var x133 uint64
var x134 uint64
x133, x134 = bits.Add64(x111, x125, uint64(p256Uint1(x132)))
var x135 uint64
var x136 uint64
x135, x136 = bits.Add64(x113, x127, uint64(p256Uint1(x134)))
var x137 uint64
var x138 uint64
x137, x138 = bits.Add64(x115, x129, uint64(p256Uint1(x136)))
var x139 uint64
var x140 uint64
x140, x139 = bits.Mul64(x131, 0xffffffff00000001)
var x141 uint64
var x142 uint64
x142, x141 = bits.Mul64(x131, 0xffffffff)
var x143 uint64
var x144 uint64
x144, x143 = bits.Mul64(x131, 0xffffffffffffffff)
var x145 uint64
var x146 uint64
x145, x146 = bits.Add64(x144, x141, uint64(0x0))
var x148 uint64
_, x148 = bits.Add64(x131, x143, uint64(0x0))
var x149 uint64
var x150 uint64
x149, x150 = bits.Add64(x133, x145, uint64(p256Uint1(x148)))
var x151 uint64
var x152 uint64
x151, x152 = bits.Add64(x135, (uint64(p256Uint1(x146)) + x142), uint64(p256Uint1(x150)))
var x153 uint64
var x154 uint64
x153, x154 = bits.Add64(x137, x139, uint64(p256Uint1(x152)))
var x155 uint64
var x156 uint64
x155, x156 = bits.Add64(((uint64(p256Uint1(x138)) + uint64(p256Uint1(x116))) + (uint64(p256Uint1(x130)) + x118)), x140, uint64(p256Uint1(x154)))
var x157 uint64
var x158 uint64
x157, x158 = bits.Sub64(x149, 0xffffffffffffffff, uint64(0x0))
var x159 uint64
var x160 uint64
x159, x160 = bits.Sub64(x151, 0xffffffff, uint64(p256Uint1(x158)))
var x161 uint64
var x162 uint64
x161, x162 = bits.Sub64(x153, uint64(0x0), uint64(p256Uint1(x160)))
var x163 uint64
var x164 uint64
x163, x164 = bits.Sub64(x155, 0xffffffff00000001, uint64(p256Uint1(x162)))
var x166 uint64
_, x166 = bits.Sub64(uint64(p256Uint1(x156)), uint64(0x0), uint64(p256Uint1(x164)))
var x167 uint64
p256CmovznzU64(&x167, p256Uint1(x166), x157, x149)
var x168 uint64
p256CmovznzU64(&x168, p256Uint1(x166), x159, x151)
var x169 uint64
p256CmovznzU64(&x169, p256Uint1(x166), x161, x153)
var x170 uint64
p256CmovznzU64(&x170, p256Uint1(x166), x163, x155)
out1[0] = x167
out1[1] = x168
out1[2] = x169
out1[3] = x170
}
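// Editorial note: the constants 0x3, 0xfffffffbffffffff, 0xfffffffffffffffe
// and 0x4fffffffd are the little-endian limbs of R^2 mod m. Montgomery-
// multiplying arg1 by R^2 yields arg1 * R^2 * R^-1 = arg1 * R mod m,
// which is exactly the Montgomery form of arg1.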
// p256Selectznz is a multi-limb conditional select.
//
// Postconditions:
//
// eval out1 = (if arg1 = 0 then eval arg2 else eval arg3)
//
// Input Bounds:
//
// arg1: [0x0 ~> 0x1]
// arg2: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff]]
// arg3: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff]]
//
// Output Bounds:
//
// out1: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff]]
func p256Selectznz(out1 *[4]uint64, arg1 p256Uint1, arg2 *[4]uint64, arg3 *[4]uint64) {
var x1 uint64
p256CmovznzU64(&x1, arg1, arg2[0], arg3[0])
var x2 uint64
p256CmovznzU64(&x2, arg1, arg2[1], arg3[1])
var x3 uint64
p256CmovznzU64(&x3, arg1, arg2[2], arg3[2])
var x4 uint64
p256CmovznzU64(&x4, arg1, arg2[3], arg3[3])
out1[0] = x1
out1[1] = x2
out1[2] = x3
out1[3] = x4
}
// p256ToBytes serializes a field element NOT in the Montgomery domain to bytes in little-endian order.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
//
// Postconditions:
//
// out1 = map (λ x, ⌊((eval arg1 mod m) mod 2^(8 * (x + 1))) / 2^(8 * x)⌋) [0..31]
//
// Input Bounds:
//
// arg1: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff]]
//
// Output Bounds:
//
// out1: [[0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff]]
func p256ToBytes(out1 *[32]uint8, arg1 *[4]uint64) {
x1 := arg1[3]
x2 := arg1[2]
x3 := arg1[1]
x4 := arg1[0]
x5 := (uint8(x4) & 0xff)
x6 := (x4 >> 8)
x7 := (uint8(x6) & 0xff)
x8 := (x6 >> 8)
x9 := (uint8(x8) & 0xff)
x10 := (x8 >> 8)
x11 := (uint8(x10) & 0xff)
x12 := (x10 >> 8)
x13 := (uint8(x12) & 0xff)
x14 := (x12 >> 8)
x15 := (uint8(x14) & 0xff)
x16 := (x14 >> 8)
x17 := (uint8(x16) & 0xff)
x18 := uint8((x16 >> 8))
x19 := (uint8(x3) & 0xff)
x20 := (x3 >> 8)
x21 := (uint8(x20) & 0xff)
x22 := (x20 >> 8)
x23 := (uint8(x22) & 0xff)
x24 := (x22 >> 8)
x25 := (uint8(x24) & 0xff)
x26 := (x24 >> 8)
x27 := (uint8(x26) & 0xff)
x28 := (x26 >> 8)
x29 := (uint8(x28) & 0xff)
x30 := (x28 >> 8)
x31 := (uint8(x30) & 0xff)
x32 := uint8((x30 >> 8))
x33 := (uint8(x2) & 0xff)
x34 := (x2 >> 8)
x35 := (uint8(x34) & 0xff)
x36 := (x34 >> 8)
x37 := (uint8(x36) & 0xff)
x38 := (x36 >> 8)
x39 := (uint8(x38) & 0xff)
x40 := (x38 >> 8)
x41 := (uint8(x40) & 0xff)
x42 := (x40 >> 8)
x43 := (uint8(x42) & 0xff)
x44 := (x42 >> 8)
x45 := (uint8(x44) & 0xff)
x46 := uint8((x44 >> 8))
x47 := (uint8(x1) & 0xff)
x48 := (x1 >> 8)
x49 := (uint8(x48) & 0xff)
x50 := (x48 >> 8)
x51 := (uint8(x50) & 0xff)
x52 := (x50 >> 8)
x53 := (uint8(x52) & 0xff)
x54 := (x52 >> 8)
x55 := (uint8(x54) & 0xff)
x56 := (x54 >> 8)
x57 := (uint8(x56) & 0xff)
x58 := (x56 >> 8)
x59 := (uint8(x58) & 0xff)
x60 := uint8((x58 >> 8))
out1[0] = x5
out1[1] = x7
out1[2] = x9
out1[3] = x11
out1[4] = x13
out1[5] = x15
out1[6] = x17
out1[7] = x18
out1[8] = x19
out1[9] = x21
out1[10] = x23
out1[11] = x25
out1[12] = x27
out1[13] = x29
out1[14] = x31
out1[15] = x32
out1[16] = x33
out1[17] = x35
out1[18] = x37
out1[19] = x39
out1[20] = x41
out1[21] = x43
out1[22] = x45
out1[23] = x46
out1[24] = x47
out1[25] = x49
out1[26] = x51
out1[27] = x53
out1[28] = x55
out1[29] = x57
out1[30] = x59
out1[31] = x60
}
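// Editorial note: serialization is purely positional: each 64-bit limb is
// split into eight bytes, least significant first, so out1[0] is the low
// byte of arg1[0] and out1[31] the high byte of arg1[3]. Callers that need
// the conventional big-endian encoding reverse the buffer afterwards (as
// p384InvertEndianness does below for P-384).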
// p256FromBytes deserializes a field element NOT in the Montgomery domain from bytes in little-endian order.
//
// Preconditions:
//
// 0 ≤ bytes_eval arg1 < m
//
// Postconditions:
//
// eval out1 mod m = bytes_eval arg1 mod m
// 0 ≤ eval out1 < m
//
// Input Bounds:
//
// arg1: [[0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff]]
//
// Output Bounds:
//
// out1: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff]]
func p256FromBytes(out1 *[4]uint64, arg1 *[32]uint8) {
x1 := (uint64(arg1[31]) << 56)
x2 := (uint64(arg1[30]) << 48)
x3 := (uint64(arg1[29]) << 40)
x4 := (uint64(arg1[28]) << 32)
x5 := (uint64(arg1[27]) << 24)
x6 := (uint64(arg1[26]) << 16)
x7 := (uint64(arg1[25]) << 8)
x8 := arg1[24]
x9 := (uint64(arg1[23]) << 56)
x10 := (uint64(arg1[22]) << 48)
x11 := (uint64(arg1[21]) << 40)
x12 := (uint64(arg1[20]) << 32)
x13 := (uint64(arg1[19]) << 24)
x14 := (uint64(arg1[18]) << 16)
x15 := (uint64(arg1[17]) << 8)
x16 := arg1[16]
x17 := (uint64(arg1[15]) << 56)
x18 := (uint64(arg1[14]) << 48)
x19 := (uint64(arg1[13]) << 40)
x20 := (uint64(arg1[12]) << 32)
x21 := (uint64(arg1[11]) << 24)
x22 := (uint64(arg1[10]) << 16)
x23 := (uint64(arg1[9]) << 8)
x24 := arg1[8]
x25 := (uint64(arg1[7]) << 56)
x26 := (uint64(arg1[6]) << 48)
x27 := (uint64(arg1[5]) << 40)
x28 := (uint64(arg1[4]) << 32)
x29 := (uint64(arg1[3]) << 24)
x30 := (uint64(arg1[2]) << 16)
x31 := (uint64(arg1[1]) << 8)
x32 := arg1[0]
x33 := (x31 + uint64(x32))
x34 := (x30 + x33)
x35 := (x29 + x34)
x36 := (x28 + x35)
x37 := (x27 + x36)
x38 := (x26 + x37)
x39 := (x25 + x38)
x40 := (x23 + uint64(x24))
x41 := (x22 + x40)
x42 := (x21 + x41)
x43 := (x20 + x42)
x44 := (x19 + x43)
x45 := (x18 + x44)
x46 := (x17 + x45)
x47 := (x15 + uint64(x16))
x48 := (x14 + x47)
x49 := (x13 + x48)
x50 := (x12 + x49)
x51 := (x11 + x50)
x52 := (x10 + x51)
x53 := (x9 + x52)
x54 := (x7 + uint64(x8))
x55 := (x6 + x54)
x56 := (x5 + x55)
x57 := (x4 + x56)
x58 := (x3 + x57)
x59 := (x2 + x58)
x60 := (x1 + x59)
out1[0] = x39
out1[1] = x46
out1[2] = x53
out1[3] = x60
}
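// Editorial sketch (hypothetical helper, not part of the generated file):
// p256FromBytes inverts p256ToBytes on reduced inputs, so a reduced limb
// vector survives a round trip through its 32-byte encoding.
func exampleP256BytesRoundTrip() bool {
in := [4]uint64{0x0123456789abcdef, 0x1, 0x2, 0x3} // well below m
var enc [32]uint8
p256ToBytes(&enc, &in)
var out [4]uint64
p256FromBytes(&out, &enc)
return out == in // arrays compare element-wise in Go
}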
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Code generated by addchain. DO NOT EDIT.
package fiat
// Invert sets e = 1/x, and returns e.
//
// If x == 0, Invert returns e = 0.
func (e *P256Element) Invert(x *P256Element) *P256Element {
// Inversion is implemented as exponentiation with exponent p − 2.
// The sequence of 12 multiplications and 255 squarings is derived from the
// following addition chain generated with github.com/mmcloughlin/addchain v0.4.0.
//
// _10 = 2*1
// _11 = 1 + _10
// _110 = 2*_11
// _111 = 1 + _110
// _111000 = _111 << 3
// _111111 = _111 + _111000
// x12 = _111111 << 6 + _111111
// x15 = x12 << 3 + _111
// x16 = 2*x15 + 1
// x32 = x16 << 16 + x16
// i53 = x32 << 15
// x47 = x15 + i53
// i263 = ((i53 << 17 + 1) << 143 + x47) << 47
// return (x47 + i263) << 2 + 1
//
var z = new(P256Element).Set(e)
var t0 = new(P256Element)
var t1 = new(P256Element)
z.Square(x)
z.Mul(x, z)
z.Square(z)
z.Mul(x, z)
t0.Square(z)
for s := 1; s < 3; s++ {
t0.Square(t0)
}
t0.Mul(z, t0)
t1.Square(t0)
for s := 1; s < 6; s++ {
t1.Square(t1)
}
t0.Mul(t0, t1)
for s := 0; s < 3; s++ {
t0.Square(t0)
}
z.Mul(z, t0)
t0.Square(z)
t0.Mul(x, t0)
t1.Square(t0)
for s := 1; s < 16; s++ {
t1.Square(t1)
}
t0.Mul(t0, t1)
for s := 0; s < 15; s++ {
t0.Square(t0)
}
z.Mul(z, t0)
for s := 0; s < 17; s++ {
t0.Square(t0)
}
t0.Mul(x, t0)
for s := 0; s < 143; s++ {
t0.Square(t0)
}
t0.Mul(z, t0)
for s := 0; s < 47; s++ {
t0.Square(t0)
}
z.Mul(z, t0)
for s := 0; s < 2; s++ {
z.Square(z)
}
z.Mul(x, z)
return e.Set(z)
}
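// Editorial note: by Fermat's little theorem, x^(m-2) = x^-1 (mod m) for
// the prime modulus m, and the chain above evaluates exactly that power,
// which also explains the x == 0 case: 0^(m-2) = 0. A minimal usage
// sketch (hypothetical, assuming the One, Mul and Equal helpers defined
// alongside P256Element):
//
// inv := new(P256Element).Invert(x)
// product := new(P256Element).Mul(x, inv)
// // product.Equal(new(P256Element).One()) == 1 whenever x != 0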
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Code generated by generate.go. DO NOT EDIT.
package fiat
import (
"crypto/subtle"
"errors"
)
// P384Element is an integer modulo 2^384 - 2^128 - 2^96 + 2^32 - 1.
//
// The zero value is a valid zero element.
type P384Element struct {
// Values are represented internally always in the Montgomery domain, and
// converted in Bytes and SetBytes.
x p384MontgomeryDomainFieldElement
}
const p384ElementLen = 48
type p384UntypedFieldElement = [6]uint64
// One sets e = 1, and returns e.
func (e *P384Element) One() *P384Element {
p384SetOne(&e.x)
return e
}
// Equal returns 1 if e == t, and zero otherwise.
func (e *P384Element) Equal(t *P384Element) int {
eBytes := e.Bytes()
tBytes := t.Bytes()
return subtle.ConstantTimeCompare(eBytes, tBytes)
}
// IsZero returns 1 if e == 0, and zero otherwise.
func (e *P384Element) IsZero() int {
zero := make([]byte, p384ElementLen)
eBytes := e.Bytes()
return subtle.ConstantTimeCompare(eBytes, zero)
}
// Set sets e = t, and returns e.
func (e *P384Element) Set(t *P384Element) *P384Element {
e.x = t.x
return e
}
// Bytes returns the 48-byte big-endian encoding of e.
func (e *P384Element) Bytes() []byte {
// This function is outlined so that the out array is allocated in the
// caller's frame, where it can stay on the stack, rather than escaping
// to the heap.
var out [p384ElementLen]byte
return e.bytes(&out)
}
func (e *P384Element) bytes(out *[p384ElementLen]byte) []byte {
var tmp p384NonMontgomeryDomainFieldElement
p384FromMontgomery(&tmp, &e.x)
p384ToBytes(out, (*p384UntypedFieldElement)(&tmp))
p384InvertEndianness(out[:])
return out[:]
}
// SetBytes sets e = v, where v is a big-endian 48-byte encoding, and returns e.
// If v is not 48 bytes or it encodes a value higher than 2^384 - 2^128 - 2^96 + 2^32 - 2
// (that is, higher than p - 1, the highest canonical value), SetBytes returns
// nil and an error, and e is unchanged.
func (e *P384Element) SetBytes(v []byte) (*P384Element, error) {
if len(v) != p384ElementLen {
return nil, errors.New("invalid P384Element encoding")
}
// Check for non-canonical encodings (p + k, 2p + k, etc.) by comparing to
// the encoding of -1 mod p, so p - 1, the highest canonical encoding.
var minusOneEncoding = new(P384Element).Sub(
new(P384Element), new(P384Element).One()).Bytes()
for i := range v {
if v[i] < minusOneEncoding[i] {
break
}
if v[i] > minusOneEncoding[i] {
return nil, errors.New("invalid P384Element encoding")
}
}
var in [p384ElementLen]byte
copy(in[:], v)
p384InvertEndianness(in[:])
var tmp p384NonMontgomeryDomainFieldElement
p384FromBytes((*p384UntypedFieldElement)(&tmp), &in)
p384ToMontgomery(&e.x, &tmp)
return e, nil
}
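// Editorial sketch (hypothetical helper, not part of the generated file):
// the loop above is a big-endian lexicographic comparison against the
// encoding of p - 1, so the highest canonical value round-trips while any
// non-canonical encoding (p + k, 2p + k, ...) is rejected.
func exampleP384SetBytesRoundTrip() (*P384Element, error) {
pMinusOne := new(P384Element).Sub(new(P384Element), new(P384Element).One())
return new(P384Element).SetBytes(pMinusOne.Bytes()) // succeeds: canonical
}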
// Add sets e = t1 + t2, and returns e.
func (e *P384Element) Add(t1, t2 *P384Element) *P384Element {
p384Add(&e.x, &t1.x, &t2.x)
return e
}
// Sub sets e = t1 - t2, and returns e.
func (e *P384Element) Sub(t1, t2 *P384Element) *P384Element {
p384Sub(&e.x, &t1.x, &t2.x)
return e
}
// Mul sets e = t1 * t2, and returns e.
func (e *P384Element) Mul(t1, t2 *P384Element) *P384Element {
p384Mul(&e.x, &t1.x, &t2.x)
return e
}
// Square sets e = t * t, and returns e.
func (e *P384Element) Square(t *P384Element) *P384Element {
p384Square(&e.x, &t.x)
return e
}
// Select sets v to a if cond == 1, and to b if cond == 0.
func (v *P384Element) Select(a, b *P384Element, cond int) *P384Element {
p384Selectznz((*p384UntypedFieldElement)(&v.x), p384Uint1(cond),
(*p384UntypedFieldElement)(&b.x), (*p384UntypedFieldElement)(&a.x))
return v
}
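// Editorial note: p384Selectznz (like p384CmovznzU64) returns its third
// argument when the flag is 1, so Select passes b first and a second to
// get the documented "cond == 1 picks a" behavior.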
func p384InvertEndianness(v []byte) {
for i := 0; i < len(v)/2; i++ {
v[i], v[len(v)-1-i] = v[len(v)-1-i], v[i]
}
}
// Code generated by Fiat Cryptography. DO NOT EDIT.
//
// Autogenerated: word_by_word_montgomery --lang Go --no-wide-int --cmovznz-by-mul --relax-primitive-carry-to-bitwidth 32,64 --internal-static --public-function-case camelCase --public-type-case camelCase --private-function-case camelCase --private-type-case camelCase --doc-text-before-function-name '' --doc-newline-before-package-declaration --doc-prepend-header 'Code generated by Fiat Cryptography. DO NOT EDIT.' --package-name fiat --no-prefix-fiat p384 64 '2^384 - 2^128 - 2^96 + 2^32 - 1' mul square add sub one from_montgomery to_montgomery selectznz to_bytes from_bytes
//
// curve description: p384
//
// machine_wordsize = 64 (from "64")
//
// requested operations: mul, square, add, sub, one, from_montgomery, to_montgomery, selectznz, to_bytes, from_bytes
//
// m = 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffeffffffff0000000000000000ffffffff (from "2^384 - 2^128 - 2^96 + 2^32 - 1")
//
// NOTE: In addition to the bounds specified above each function, all
// functions synthesized for this Montgomery arithmetic require the
// input to be strictly less than the prime modulus (m), and also
// require the input to be in the unique saturated representation.
// All functions also ensure that these two properties are true of
// return values.
//
// Computed values:
//
// eval z = z[0] + (z[1] << 64) + (z[2] << 128) + (z[3] << 192) + (z[4] << 256) + (z[5] << 0x140)
//
// bytes_eval z = z[0] + (z[1] << 8) + (z[2] << 16) + (z[3] << 24) + (z[4] << 32) + (z[5] << 40) + (z[6] << 48) + (z[7] << 56) + (z[8] << 64) + (z[9] << 72) + (z[10] << 80) + (z[11] << 88) + (z[12] << 96) + (z[13] << 104) + (z[14] << 112) + (z[15] << 120) + (z[16] << 128) + (z[17] << 136) + (z[18] << 144) + (z[19] << 152) + (z[20] << 160) + (z[21] << 168) + (z[22] << 176) + (z[23] << 184) + (z[24] << 192) + (z[25] << 200) + (z[26] << 208) + (z[27] << 216) + (z[28] << 224) + (z[29] << 232) + (z[30] << 240) + (z[31] << 248) + (z[32] << 256) + (z[33] << 0x108) + (z[34] << 0x110) + (z[35] << 0x118) + (z[36] << 0x120) + (z[37] << 0x128) + (z[38] << 0x130) + (z[39] << 0x138) + (z[40] << 0x140) + (z[41] << 0x148) + (z[42] << 0x150) + (z[43] << 0x158) + (z[44] << 0x160) + (z[45] << 0x168) + (z[46] << 0x170) + (z[47] << 0x178)
//
// twos_complement_eval z = let x1 := z[0] + (z[1] << 64) + (z[2] << 128) + (z[3] << 192) + (z[4] << 256) + (z[5] << 0x140) in
//                          if x1 & (2^384-1) < 2^383 then x1 & (2^384-1) else (x1 & (2^384-1)) - 2^384
package fiat
import "math/bits"
type p384Uint1 uint64 // We use uint64 instead of a more narrow type for performance reasons; see https://github.com/mit-plv/fiat-crypto/pull/1006#issuecomment-892625927
type p384Int1 int64 // We use a 64-bit type (int64) instead of a more narrow type for performance reasons; see https://github.com/mit-plv/fiat-crypto/pull/1006#issuecomment-892625927
// The type p384MontgomeryDomainFieldElement is a field element in the Montgomery domain.
//
// Bounds: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff]]
type p384MontgomeryDomainFieldElement [6]uint64
// The type p384NonMontgomeryDomainFieldElement is a field element NOT in the Montgomery domain.
//
// Bounds: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff]]
type p384NonMontgomeryDomainFieldElement [6]uint64
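// Editorial note (not part of the generated output): here
// m = 2^384 - 2^128 - 2^96 + 2^32 - 1, the NIST P-384 prime, with six
// little-endian 64-bit limbs and R = 2^384. Unlike P-256, m's low limb is
// 0xffffffff, so the reduction rounds in p384Mul below first multiply the
// running low limb by -m^-1 mod 2^64 = 0x100000001 to obtain the
// reduction multiplier.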
// p384CmovznzU64 is a single-word conditional move.
//
// Postconditions:
//
// out1 = (if arg1 = 0 then arg2 else arg3)
//
// Input Bounds:
//
// arg1: [0x0 ~> 0x1]
// arg2: [0x0 ~> 0xffffffffffffffff]
// arg3: [0x0 ~> 0xffffffffffffffff]
//
// Output Bounds:
//
// out1: [0x0 ~> 0xffffffffffffffff]
func p384CmovznzU64(out1 *uint64, arg1 p384Uint1, arg2 uint64, arg3 uint64) {
x1 := (uint64(arg1) * 0xffffffffffffffff)
x2 := ((x1 & arg3) | ((^x1) & arg2))
*out1 = x2
}
// p384Mul multiplies two field elements in the Montgomery domain.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
// 0 ≤ eval arg2 < m
//
// Postconditions:
//
// eval (from_montgomery out1) mod m = (eval (from_montgomery arg1) * eval (from_montgomery arg2)) mod m
// 0 ≤ eval out1 < m
func p384Mul(out1 *p384MontgomeryDomainFieldElement, arg1 *p384MontgomeryDomainFieldElement, arg2 *p384MontgomeryDomainFieldElement) {
x1 := arg1[1]
x2 := arg1[2]
x3 := arg1[3]
x4 := arg1[4]
x5 := arg1[5]
x6 := arg1[0]
var x7 uint64
var x8 uint64
x8, x7 = bits.Mul64(x6, arg2[5])
var x9 uint64
var x10 uint64
x10, x9 = bits.Mul64(x6, arg2[4])
var x11 uint64
var x12 uint64
x12, x11 = bits.Mul64(x6, arg2[3])
var x13 uint64
var x14 uint64
x14, x13 = bits.Mul64(x6, arg2[2])
var x15 uint64
var x16 uint64
x16, x15 = bits.Mul64(x6, arg2[1])
var x17 uint64
var x18 uint64
x18, x17 = bits.Mul64(x6, arg2[0])
var x19 uint64
var x20 uint64
x19, x20 = bits.Add64(x18, x15, uint64(0x0))
var x21 uint64
var x22 uint64
x21, x22 = bits.Add64(x16, x13, uint64(p384Uint1(x20)))
var x23 uint64
var x24 uint64
x23, x24 = bits.Add64(x14, x11, uint64(p384Uint1(x22)))
var x25 uint64
var x26 uint64
x25, x26 = bits.Add64(x12, x9, uint64(p384Uint1(x24)))
var x27 uint64
var x28 uint64
x27, x28 = bits.Add64(x10, x7, uint64(p384Uint1(x26)))
x29 := (uint64(p384Uint1(x28)) + x8)
var x30 uint64
_, x30 = bits.Mul64(x17, 0x100000001)
var x32 uint64
var x33 uint64
x33, x32 = bits.Mul64(x30, 0xffffffffffffffff)
var x34 uint64
var x35 uint64
x35, x34 = bits.Mul64(x30, 0xffffffffffffffff)
var x36 uint64
var x37 uint64
x37, x36 = bits.Mul64(x30, 0xffffffffffffffff)
var x38 uint64
var x39 uint64
x39, x38 = bits.Mul64(x30, 0xfffffffffffffffe)
var x40 uint64
var x41 uint64
x41, x40 = bits.Mul64(x30, 0xffffffff00000000)
var x42 uint64
var x43 uint64
x43, x42 = bits.Mul64(x30, 0xffffffff)
var x44 uint64
var x45 uint64
x44, x45 = bits.Add64(x43, x40, uint64(0x0))
var x46 uint64
var x47 uint64
x46, x47 = bits.Add64(x41, x38, uint64(p384Uint1(x45)))
var x48 uint64
var x49 uint64
x48, x49 = bits.Add64(x39, x36, uint64(p384Uint1(x47)))
var x50 uint64
var x51 uint64
x50, x51 = bits.Add64(x37, x34, uint64(p384Uint1(x49)))
var x52 uint64
var x53 uint64
x52, x53 = bits.Add64(x35, x32, uint64(p384Uint1(x51)))
x54 := (uint64(p384Uint1(x53)) + x33)
var x56 uint64
_, x56 = bits.Add64(x17, x42, uint64(0x0))
var x57 uint64
var x58 uint64
x57, x58 = bits.Add64(x19, x44, uint64(p384Uint1(x56)))
var x59 uint64
var x60 uint64
x59, x60 = bits.Add64(x21, x46, uint64(p384Uint1(x58)))
var x61 uint64
var x62 uint64
x61, x62 = bits.Add64(x23, x48, uint64(p384Uint1(x60)))
var x63 uint64
var x64 uint64
x63, x64 = bits.Add64(x25, x50, uint64(p384Uint1(x62)))
var x65 uint64
var x66 uint64
x65, x66 = bits.Add64(x27, x52, uint64(p384Uint1(x64)))
var x67 uint64
var x68 uint64
x67, x68 = bits.Add64(x29, x54, uint64(p384Uint1(x66)))
var x69 uint64
var x70 uint64
x70, x69 = bits.Mul64(x1, arg2[5])
var x71 uint64
var x72 uint64
x72, x71 = bits.Mul64(x1, arg2[4])
var x73 uint64
var x74 uint64
x74, x73 = bits.Mul64(x1, arg2[3])
var x75 uint64
var x76 uint64
x76, x75 = bits.Mul64(x1, arg2[2])
var x77 uint64
var x78 uint64
x78, x77 = bits.Mul64(x1, arg2[1])
var x79 uint64
var x80 uint64
x80, x79 = bits.Mul64(x1, arg2[0])
var x81 uint64
var x82 uint64
x81, x82 = bits.Add64(x80, x77, uint64(0x0))
var x83 uint64
var x84 uint64
x83, x84 = bits.Add64(x78, x75, uint64(p384Uint1(x82)))
var x85 uint64
var x86 uint64
x85, x86 = bits.Add64(x76, x73, uint64(p384Uint1(x84)))
var x87 uint64
var x88 uint64
x87, x88 = bits.Add64(x74, x71, uint64(p384Uint1(x86)))
var x89 uint64
var x90 uint64
x89, x90 = bits.Add64(x72, x69, uint64(p384Uint1(x88)))
x91 := (uint64(p384Uint1(x90)) + x70)
var x92 uint64
var x93 uint64
x92, x93 = bits.Add64(x57, x79, uint64(0x0))
var x94 uint64
var x95 uint64
x94, x95 = bits.Add64(x59, x81, uint64(p384Uint1(x93)))
var x96 uint64
var x97 uint64
x96, x97 = bits.Add64(x61, x83, uint64(p384Uint1(x95)))
var x98 uint64
var x99 uint64
x98, x99 = bits.Add64(x63, x85, uint64(p384Uint1(x97)))
var x100 uint64
var x101 uint64
x100, x101 = bits.Add64(x65, x87, uint64(p384Uint1(x99)))
var x102 uint64
var x103 uint64
x102, x103 = bits.Add64(x67, x89, uint64(p384Uint1(x101)))
var x104 uint64
var x105 uint64
x104, x105 = bits.Add64(uint64(p384Uint1(x68)), x91, uint64(p384Uint1(x103)))
var x106 uint64
_, x106 = bits.Mul64(x92, 0x100000001)
var x108 uint64
var x109 uint64
x109, x108 = bits.Mul64(x106, 0xffffffffffffffff)
var x110 uint64
var x111 uint64
x111, x110 = bits.Mul64(x106, 0xffffffffffffffff)
var x112 uint64
var x113 uint64
x113, x112 = bits.Mul64(x106, 0xffffffffffffffff)
var x114 uint64
var x115 uint64
x115, x114 = bits.Mul64(x106, 0xfffffffffffffffe)
var x116 uint64
var x117 uint64
x117, x116 = bits.Mul64(x106, 0xffffffff00000000)
var x118 uint64
var x119 uint64
x119, x118 = bits.Mul64(x106, 0xffffffff)
var x120 uint64
var x121 uint64
x120, x121 = bits.Add64(x119, x116, uint64(0x0))
var x122 uint64
var x123 uint64
x122, x123 = bits.Add64(x117, x114, uint64(p384Uint1(x121)))
var x124 uint64
var x125 uint64
x124, x125 = bits.Add64(x115, x112, uint64(p384Uint1(x123)))
var x126 uint64
var x127 uint64
x126, x127 = bits.Add64(x113, x110, uint64(p384Uint1(x125)))
var x128 uint64
var x129 uint64
x128, x129 = bits.Add64(x111, x108, uint64(p384Uint1(x127)))
x130 := (uint64(p384Uint1(x129)) + x109)
var x132 uint64
_, x132 = bits.Add64(x92, x118, uint64(0x0))
var x133 uint64
var x134 uint64
x133, x134 = bits.Add64(x94, x120, uint64(p384Uint1(x132)))
var x135 uint64
var x136 uint64
x135, x136 = bits.Add64(x96, x122, uint64(p384Uint1(x134)))
var x137 uint64
var x138 uint64
x137, x138 = bits.Add64(x98, x124, uint64(p384Uint1(x136)))
var x139 uint64
var x140 uint64
x139, x140 = bits.Add64(x100, x126, uint64(p384Uint1(x138)))
var x141 uint64
var x142 uint64
x141, x142 = bits.Add64(x102, x128, uint64(p384Uint1(x140)))
var x143 uint64
var x144 uint64
x143, x144 = bits.Add64(x104, x130, uint64(p384Uint1(x142)))
x145 := (uint64(p384Uint1(x144)) + uint64(p384Uint1(x105)))
var x146 uint64
var x147 uint64
x147, x146 = bits.Mul64(x2, arg2[5])
var x148 uint64
var x149 uint64
x149, x148 = bits.Mul64(x2, arg2[4])
var x150 uint64
var x151 uint64
x151, x150 = bits.Mul64(x2, arg2[3])
var x152 uint64
var x153 uint64
x153, x152 = bits.Mul64(x2, arg2[2])
var x154 uint64
var x155 uint64
x155, x154 = bits.Mul64(x2, arg2[1])
var x156 uint64
var x157 uint64
x157, x156 = bits.Mul64(x2, arg2[0])
var x158 uint64
var x159 uint64
x158, x159 = bits.Add64(x157, x154, uint64(0x0))
var x160 uint64
var x161 uint64
x160, x161 = bits.Add64(x155, x152, uint64(p384Uint1(x159)))
var x162 uint64
var x163 uint64
x162, x163 = bits.Add64(x153, x150, uint64(p384Uint1(x161)))
var x164 uint64
var x165 uint64
x164, x165 = bits.Add64(x151, x148, uint64(p384Uint1(x163)))
var x166 uint64
var x167 uint64
x166, x167 = bits.Add64(x149, x146, uint64(p384Uint1(x165)))
x168 := (uint64(p384Uint1(x167)) + x147)
var x169 uint64
var x170 uint64
x169, x170 = bits.Add64(x133, x156, uint64(0x0))
var x171 uint64
var x172 uint64
x171, x172 = bits.Add64(x135, x158, uint64(p384Uint1(x170)))
var x173 uint64
var x174 uint64
x173, x174 = bits.Add64(x137, x160, uint64(p384Uint1(x172)))
var x175 uint64
var x176 uint64
x175, x176 = bits.Add64(x139, x162, uint64(p384Uint1(x174)))
var x177 uint64
var x178 uint64
x177, x178 = bits.Add64(x141, x164, uint64(p384Uint1(x176)))
var x179 uint64
var x180 uint64
x179, x180 = bits.Add64(x143, x166, uint64(p384Uint1(x178)))
var x181 uint64
var x182 uint64
x181, x182 = bits.Add64(x145, x168, uint64(p384Uint1(x180)))
var x183 uint64
_, x183 = bits.Mul64(x169, 0x100000001)
var x185 uint64
var x186 uint64
x186, x185 = bits.Mul64(x183, 0xffffffffffffffff)
var x187 uint64
var x188 uint64
x188, x187 = bits.Mul64(x183, 0xffffffffffffffff)
var x189 uint64
var x190 uint64
x190, x189 = bits.Mul64(x183, 0xffffffffffffffff)
var x191 uint64
var x192 uint64
x192, x191 = bits.Mul64(x183, 0xfffffffffffffffe)
var x193 uint64
var x194 uint64
x194, x193 = bits.Mul64(x183, 0xffffffff00000000)
var x195 uint64
var x196 uint64
x196, x195 = bits.Mul64(x183, 0xffffffff)
var x197 uint64
var x198 uint64
x197, x198 = bits.Add64(x196, x193, uint64(0x0))
var x199 uint64
var x200 uint64
x199, x200 = bits.Add64(x194, x191, uint64(p384Uint1(x198)))
var x201 uint64
var x202 uint64
x201, x202 = bits.Add64(x192, x189, uint64(p384Uint1(x200)))
var x203 uint64
var x204 uint64
x203, x204 = bits.Add64(x190, x187, uint64(p384Uint1(x202)))
var x205 uint64
var x206 uint64
x205, x206 = bits.Add64(x188, x185, uint64(p384Uint1(x204)))
x207 := (uint64(p384Uint1(x206)) + x186)
var x209 uint64
_, x209 = bits.Add64(x169, x195, uint64(0x0))
var x210 uint64
var x211 uint64
x210, x211 = bits.Add64(x171, x197, uint64(p384Uint1(x209)))
var x212 uint64
var x213 uint64
x212, x213 = bits.Add64(x173, x199, uint64(p384Uint1(x211)))
var x214 uint64
var x215 uint64
x214, x215 = bits.Add64(x175, x201, uint64(p384Uint1(x213)))
var x216 uint64
var x217 uint64
x216, x217 = bits.Add64(x177, x203, uint64(p384Uint1(x215)))
var x218 uint64
var x219 uint64
x218, x219 = bits.Add64(x179, x205, uint64(p384Uint1(x217)))
var x220 uint64
var x221 uint64
x220, x221 = bits.Add64(x181, x207, uint64(p384Uint1(x219)))
x222 := (uint64(p384Uint1(x221)) + uint64(p384Uint1(x182)))
var x223 uint64
var x224 uint64
x224, x223 = bits.Mul64(x3, arg2[5])
var x225 uint64
var x226 uint64
x226, x225 = bits.Mul64(x3, arg2[4])
var x227 uint64
var x228 uint64
x228, x227 = bits.Mul64(x3, arg2[3])
var x229 uint64
var x230 uint64
x230, x229 = bits.Mul64(x3, arg2[2])
var x231 uint64
var x232 uint64
x232, x231 = bits.Mul64(x3, arg2[1])
var x233 uint64
var x234 uint64
x234, x233 = bits.Mul64(x3, arg2[0])
var x235 uint64
var x236 uint64
x235, x236 = bits.Add64(x234, x231, uint64(0x0))
var x237 uint64
var x238 uint64
x237, x238 = bits.Add64(x232, x229, uint64(p384Uint1(x236)))
var x239 uint64
var x240 uint64
x239, x240 = bits.Add64(x230, x227, uint64(p384Uint1(x238)))
var x241 uint64
var x242 uint64
x241, x242 = bits.Add64(x228, x225, uint64(p384Uint1(x240)))
var x243 uint64
var x244 uint64
x243, x244 = bits.Add64(x226, x223, uint64(p384Uint1(x242)))
x245 := (uint64(p384Uint1(x244)) + x224)
var x246 uint64
var x247 uint64
x246, x247 = bits.Add64(x210, x233, uint64(0x0))
var x248 uint64
var x249 uint64
x248, x249 = bits.Add64(x212, x235, uint64(p384Uint1(x247)))
var x250 uint64
var x251 uint64
x250, x251 = bits.Add64(x214, x237, uint64(p384Uint1(x249)))
var x252 uint64
var x253 uint64
x252, x253 = bits.Add64(x216, x239, uint64(p384Uint1(x251)))
var x254 uint64
var x255 uint64
x254, x255 = bits.Add64(x218, x241, uint64(p384Uint1(x253)))
var x256 uint64
var x257 uint64
x256, x257 = bits.Add64(x220, x243, uint64(p384Uint1(x255)))
var x258 uint64
var x259 uint64
x258, x259 = bits.Add64(x222, x245, uint64(p384Uint1(x257)))
var x260 uint64
_, x260 = bits.Mul64(x246, 0x100000001)
var x262 uint64
var x263 uint64
x263, x262 = bits.Mul64(x260, 0xffffffffffffffff)
var x264 uint64
var x265 uint64
x265, x264 = bits.Mul64(x260, 0xffffffffffffffff)
var x266 uint64
var x267 uint64
x267, x266 = bits.Mul64(x260, 0xffffffffffffffff)
var x268 uint64
var x269 uint64
x269, x268 = bits.Mul64(x260, 0xfffffffffffffffe)
var x270 uint64
var x271 uint64
x271, x270 = bits.Mul64(x260, 0xffffffff00000000)
var x272 uint64
var x273 uint64
x273, x272 = bits.Mul64(x260, 0xffffffff)
var x274 uint64
var x275 uint64
x274, x275 = bits.Add64(x273, x270, uint64(0x0))
var x276 uint64
var x277 uint64
x276, x277 = bits.Add64(x271, x268, uint64(p384Uint1(x275)))
var x278 uint64
var x279 uint64
x278, x279 = bits.Add64(x269, x266, uint64(p384Uint1(x277)))
var x280 uint64
var x281 uint64
x280, x281 = bits.Add64(x267, x264, uint64(p384Uint1(x279)))
var x282 uint64
var x283 uint64
x282, x283 = bits.Add64(x265, x262, uint64(p384Uint1(x281)))
x284 := (uint64(p384Uint1(x283)) + x263)
var x286 uint64
_, x286 = bits.Add64(x246, x272, uint64(0x0))
var x287 uint64
var x288 uint64
x287, x288 = bits.Add64(x248, x274, uint64(p384Uint1(x286)))
var x289 uint64
var x290 uint64
x289, x290 = bits.Add64(x250, x276, uint64(p384Uint1(x288)))
var x291 uint64
var x292 uint64
x291, x292 = bits.Add64(x252, x278, uint64(p384Uint1(x290)))
var x293 uint64
var x294 uint64
x293, x294 = bits.Add64(x254, x280, uint64(p384Uint1(x292)))
var x295 uint64
var x296 uint64
x295, x296 = bits.Add64(x256, x282, uint64(p384Uint1(x294)))
var x297 uint64
var x298 uint64
x297, x298 = bits.Add64(x258, x284, uint64(p384Uint1(x296)))
x299 := (uint64(p384Uint1(x298)) + uint64(p384Uint1(x259)))
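// Next round: multiply limb x4 of arg1 by each limb of arg2, accumulate, then Montgomery-reduce one word.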
var x300 uint64
var x301 uint64
x301, x300 = bits.Mul64(x4, arg2[5])
var x302 uint64
var x303 uint64
x303, x302 = bits.Mul64(x4, arg2[4])
var x304 uint64
var x305 uint64
x305, x304 = bits.Mul64(x4, arg2[3])
var x306 uint64
var x307 uint64
x307, x306 = bits.Mul64(x4, arg2[2])
var x308 uint64
var x309 uint64
x309, x308 = bits.Mul64(x4, arg2[1])
var x310 uint64
var x311 uint64
x311, x310 = bits.Mul64(x4, arg2[0])
var x312 uint64
var x313 uint64
x312, x313 = bits.Add64(x311, x308, uint64(0x0))
var x314 uint64
var x315 uint64
x314, x315 = bits.Add64(x309, x306, uint64(p384Uint1(x313)))
var x316 uint64
var x317 uint64
x316, x317 = bits.Add64(x307, x304, uint64(p384Uint1(x315)))
var x318 uint64
var x319 uint64
x318, x319 = bits.Add64(x305, x302, uint64(p384Uint1(x317)))
var x320 uint64
var x321 uint64
x320, x321 = bits.Add64(x303, x300, uint64(p384Uint1(x319)))
x322 := (uint64(p384Uint1(x321)) + x301)
var x323 uint64
var x324 uint64
x323, x324 = bits.Add64(x287, x310, uint64(0x0))
var x325 uint64
var x326 uint64
x325, x326 = bits.Add64(x289, x312, uint64(p384Uint1(x324)))
var x327 uint64
var x328 uint64
x327, x328 = bits.Add64(x291, x314, uint64(p384Uint1(x326)))
var x329 uint64
var x330 uint64
x329, x330 = bits.Add64(x293, x316, uint64(p384Uint1(x328)))
var x331 uint64
var x332 uint64
x331, x332 = bits.Add64(x295, x318, uint64(p384Uint1(x330)))
var x333 uint64
var x334 uint64
x333, x334 = bits.Add64(x297, x320, uint64(p384Uint1(x332)))
var x335 uint64
var x336 uint64
x335, x336 = bits.Add64(x299, x322, uint64(p384Uint1(x334)))
var x337 uint64
_, x337 = bits.Mul64(x323, 0x100000001)
var x339 uint64
var x340 uint64
x340, x339 = bits.Mul64(x337, 0xffffffffffffffff)
var x341 uint64
var x342 uint64
x342, x341 = bits.Mul64(x337, 0xffffffffffffffff)
var x343 uint64
var x344 uint64
x344, x343 = bits.Mul64(x337, 0xffffffffffffffff)
var x345 uint64
var x346 uint64
x346, x345 = bits.Mul64(x337, 0xfffffffffffffffe)
var x347 uint64
var x348 uint64
x348, x347 = bits.Mul64(x337, 0xffffffff00000000)
var x349 uint64
var x350 uint64
x350, x349 = bits.Mul64(x337, 0xffffffff)
var x351 uint64
var x352 uint64
x351, x352 = bits.Add64(x350, x347, uint64(0x0))
var x353 uint64
var x354 uint64
x353, x354 = bits.Add64(x348, x345, uint64(p384Uint1(x352)))
var x355 uint64
var x356 uint64
x355, x356 = bits.Add64(x346, x343, uint64(p384Uint1(x354)))
var x357 uint64
var x358 uint64
x357, x358 = bits.Add64(x344, x341, uint64(p384Uint1(x356)))
var x359 uint64
var x360 uint64
x359, x360 = bits.Add64(x342, x339, uint64(p384Uint1(x358)))
x361 := (uint64(p384Uint1(x360)) + x340)
var x363 uint64
_, x363 = bits.Add64(x323, x349, uint64(0x0))
var x364 uint64
var x365 uint64
x364, x365 = bits.Add64(x325, x351, uint64(p384Uint1(x363)))
var x366 uint64
var x367 uint64
x366, x367 = bits.Add64(x327, x353, uint64(p384Uint1(x365)))
var x368 uint64
var x369 uint64
x368, x369 = bits.Add64(x329, x355, uint64(p384Uint1(x367)))
var x370 uint64
var x371 uint64
x370, x371 = bits.Add64(x331, x357, uint64(p384Uint1(x369)))
var x372 uint64
var x373 uint64
x372, x373 = bits.Add64(x333, x359, uint64(p384Uint1(x371)))
var x374 uint64
var x375 uint64
x374, x375 = bits.Add64(x335, x361, uint64(p384Uint1(x373)))
x376 := (uint64(p384Uint1(x375)) + uint64(p384Uint1(x336)))
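// Final round: multiply limb x5 of arg1 by each limb of arg2, accumulate, then Montgomery-reduce one word.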
var x377 uint64
var x378 uint64
x378, x377 = bits.Mul64(x5, arg2[5])
var x379 uint64
var x380 uint64
x380, x379 = bits.Mul64(x5, arg2[4])
var x381 uint64
var x382 uint64
x382, x381 = bits.Mul64(x5, arg2[3])
var x383 uint64
var x384 uint64
x384, x383 = bits.Mul64(x5, arg2[2])
var x385 uint64
var x386 uint64
x386, x385 = bits.Mul64(x5, arg2[1])
var x387 uint64
var x388 uint64
x388, x387 = bits.Mul64(x5, arg2[0])
var x389 uint64
var x390 uint64
x389, x390 = bits.Add64(x388, x385, uint64(0x0))
var x391 uint64
var x392 uint64
x391, x392 = bits.Add64(x386, x383, uint64(p384Uint1(x390)))
var x393 uint64
var x394 uint64
x393, x394 = bits.Add64(x384, x381, uint64(p384Uint1(x392)))
var x395 uint64
var x396 uint64
x395, x396 = bits.Add64(x382, x379, uint64(p384Uint1(x394)))
var x397 uint64
var x398 uint64
x397, x398 = bits.Add64(x380, x377, uint64(p384Uint1(x396)))
x399 := (uint64(p384Uint1(x398)) + x378)
var x400 uint64
var x401 uint64
x400, x401 = bits.Add64(x364, x387, uint64(0x0))
var x402 uint64
var x403 uint64
x402, x403 = bits.Add64(x366, x389, uint64(p384Uint1(x401)))
var x404 uint64
var x405 uint64
x404, x405 = bits.Add64(x368, x391, uint64(p384Uint1(x403)))
var x406 uint64
var x407 uint64
x406, x407 = bits.Add64(x370, x393, uint64(p384Uint1(x405)))
var x408 uint64
var x409 uint64
x408, x409 = bits.Add64(x372, x395, uint64(p384Uint1(x407)))
var x410 uint64
var x411 uint64
x410, x411 = bits.Add64(x374, x397, uint64(p384Uint1(x409)))
var x412 uint64
var x413 uint64
x412, x413 = bits.Add64(x376, x399, uint64(p384Uint1(x411)))
var x414 uint64
_, x414 = bits.Mul64(x400, 0x100000001)
var x416 uint64
var x417 uint64
x417, x416 = bits.Mul64(x414, 0xffffffffffffffff)
var x418 uint64
var x419 uint64
x419, x418 = bits.Mul64(x414, 0xffffffffffffffff)
var x420 uint64
var x421 uint64
x421, x420 = bits.Mul64(x414, 0xffffffffffffffff)
var x422 uint64
var x423 uint64
x423, x422 = bits.Mul64(x414, 0xfffffffffffffffe)
var x424 uint64
var x425 uint64
x425, x424 = bits.Mul64(x414, 0xffffffff00000000)
var x426 uint64
var x427 uint64
x427, x426 = bits.Mul64(x414, 0xffffffff)
var x428 uint64
var x429 uint64
x428, x429 = bits.Add64(x427, x424, uint64(0x0))
var x430 uint64
var x431 uint64
x430, x431 = bits.Add64(x425, x422, uint64(p384Uint1(x429)))
var x432 uint64
var x433 uint64
x432, x433 = bits.Add64(x423, x420, uint64(p384Uint1(x431)))
var x434 uint64
var x435 uint64
x434, x435 = bits.Add64(x421, x418, uint64(p384Uint1(x433)))
var x436 uint64
var x437 uint64
x436, x437 = bits.Add64(x419, x416, uint64(p384Uint1(x435)))
x438 := (uint64(p384Uint1(x437)) + x417)
var x440 uint64
_, x440 = bits.Add64(x400, x426, uint64(0x0))
var x441 uint64
var x442 uint64
x441, x442 = bits.Add64(x402, x428, uint64(p384Uint1(x440)))
var x443 uint64
var x444 uint64
x443, x444 = bits.Add64(x404, x430, uint64(p384Uint1(x442)))
var x445 uint64
var x446 uint64
x445, x446 = bits.Add64(x406, x432, uint64(p384Uint1(x444)))
var x447 uint64
var x448 uint64
x447, x448 = bits.Add64(x408, x434, uint64(p384Uint1(x446)))
var x449 uint64
var x450 uint64
x449, x450 = bits.Add64(x410, x436, uint64(p384Uint1(x448)))
var x451 uint64
var x452 uint64
x451, x452 = bits.Add64(x412, x438, uint64(p384Uint1(x450)))
x453 := (uint64(p384Uint1(x452)) + uint64(p384Uint1(x413)))
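// Final conditional subtraction: subtract m once and, unless the subtraction borrowed, keep the difference, leaving the result in [0, m).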
var x454 uint64
var x455 uint64
x454, x455 = bits.Sub64(x441, 0xffffffff, uint64(0x0))
var x456 uint64
var x457 uint64
x456, x457 = bits.Sub64(x443, 0xffffffff00000000, uint64(p384Uint1(x455)))
var x458 uint64
var x459 uint64
x458, x459 = bits.Sub64(x445, 0xfffffffffffffffe, uint64(p384Uint1(x457)))
var x460 uint64
var x461 uint64
x460, x461 = bits.Sub64(x447, 0xffffffffffffffff, uint64(p384Uint1(x459)))
var x462 uint64
var x463 uint64
x462, x463 = bits.Sub64(x449, 0xffffffffffffffff, uint64(p384Uint1(x461)))
var x464 uint64
var x465 uint64
x464, x465 = bits.Sub64(x451, 0xffffffffffffffff, uint64(p384Uint1(x463)))
var x467 uint64
_, x467 = bits.Sub64(x453, uint64(0x0), uint64(p384Uint1(x465)))
var x468 uint64
p384CmovznzU64(&x468, p384Uint1(x467), x454, x441)
var x469 uint64
p384CmovznzU64(&x469, p384Uint1(x467), x456, x443)
var x470 uint64
p384CmovznzU64(&x470, p384Uint1(x467), x458, x445)
var x471 uint64
p384CmovznzU64(&x471, p384Uint1(x467), x460, x447)
var x472 uint64
p384CmovznzU64(&x472, p384Uint1(x467), x462, x449)
var x473 uint64
p384CmovznzU64(&x473, p384Uint1(x467), x464, x451)
out1[0] = x468
out1[1] = x469
out1[2] = x470
out1[3] = x471
out1[4] = x472
out1[5] = x473
}
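// The loop-free body above is word-by-word Montgomery multiplication: for
// each 64-bit limb of arg1, the six products with the limbs of arg2 are
// accumulated into a running partial sum, and one reduction step then folds
// the lowest word back into the modulus. The reduction multiplier
// 0x100000001 is -m⁻¹ mod 2^64 for the P-384 prime
// m = 2^384 - 2^128 - 2^96 + 2^32 - 1 (check: 0xffffffff · 0x100000001 =
// 2^64 - 1 ≡ -1 mod 2^64), and 0xffffffff, 0xffffffff00000000,
// 0xfffffffffffffffe and 0xffffffffffffffff are the little-endian limbs of
// m, reused by both the per-round reduction and the final conditional
// subtraction. A minimal usage sketch (illustrative variable names):
//
//	var a, b, prod p384MontgomeryDomainFieldElement
//	p384Mul(&prod, &a, &b) // prod represents a·b; all three stay in the Montgomery domain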
// p384Square squares a field element in the Montgomery domain.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
//
// Postconditions:
//
// eval (from_montgomery out1) mod m = (eval (from_montgomery arg1) * eval (from_montgomery arg1)) mod m
// 0 ≤ eval out1 < m
func p384Square(out1 *p384MontgomeryDomainFieldElement, arg1 *p384MontgomeryDomainFieldElement) {
x1 := arg1[1]
x2 := arg1[2]
x3 := arg1[3]
x4 := arg1[4]
x5 := arg1[5]
x6 := arg1[0]
var x7 uint64
var x8 uint64
x8, x7 = bits.Mul64(x6, arg1[5])
var x9 uint64
var x10 uint64
x10, x9 = bits.Mul64(x6, arg1[4])
var x11 uint64
var x12 uint64
x12, x11 = bits.Mul64(x6, arg1[3])
var x13 uint64
var x14 uint64
x14, x13 = bits.Mul64(x6, arg1[2])
var x15 uint64
var x16 uint64
x16, x15 = bits.Mul64(x6, arg1[1])
var x17 uint64
var x18 uint64
x18, x17 = bits.Mul64(x6, arg1[0])
var x19 uint64
var x20 uint64
x19, x20 = bits.Add64(x18, x15, uint64(0x0))
var x21 uint64
var x22 uint64
x21, x22 = bits.Add64(x16, x13, uint64(p384Uint1(x20)))
var x23 uint64
var x24 uint64
x23, x24 = bits.Add64(x14, x11, uint64(p384Uint1(x22)))
var x25 uint64
var x26 uint64
x25, x26 = bits.Add64(x12, x9, uint64(p384Uint1(x24)))
var x27 uint64
var x28 uint64
x27, x28 = bits.Add64(x10, x7, uint64(p384Uint1(x26)))
x29 := (uint64(p384Uint1(x28)) + x8)
var x30 uint64
_, x30 = bits.Mul64(x17, 0x100000001)
var x32 uint64
var x33 uint64
x33, x32 = bits.Mul64(x30, 0xffffffffffffffff)
var x34 uint64
var x35 uint64
x35, x34 = bits.Mul64(x30, 0xffffffffffffffff)
var x36 uint64
var x37 uint64
x37, x36 = bits.Mul64(x30, 0xffffffffffffffff)
var x38 uint64
var x39 uint64
x39, x38 = bits.Mul64(x30, 0xfffffffffffffffe)
var x40 uint64
var x41 uint64
x41, x40 = bits.Mul64(x30, 0xffffffff00000000)
var x42 uint64
var x43 uint64
x43, x42 = bits.Mul64(x30, 0xffffffff)
var x44 uint64
var x45 uint64
x44, x45 = bits.Add64(x43, x40, uint64(0x0))
var x46 uint64
var x47 uint64
x46, x47 = bits.Add64(x41, x38, uint64(p384Uint1(x45)))
var x48 uint64
var x49 uint64
x48, x49 = bits.Add64(x39, x36, uint64(p384Uint1(x47)))
var x50 uint64
var x51 uint64
x50, x51 = bits.Add64(x37, x34, uint64(p384Uint1(x49)))
var x52 uint64
var x53 uint64
x52, x53 = bits.Add64(x35, x32, uint64(p384Uint1(x51)))
x54 := (uint64(p384Uint1(x53)) + x33)
var x56 uint64
_, x56 = bits.Add64(x17, x42, uint64(0x0))
var x57 uint64
var x58 uint64
x57, x58 = bits.Add64(x19, x44, uint64(p384Uint1(x56)))
var x59 uint64
var x60 uint64
x59, x60 = bits.Add64(x21, x46, uint64(p384Uint1(x58)))
var x61 uint64
var x62 uint64
x61, x62 = bits.Add64(x23, x48, uint64(p384Uint1(x60)))
var x63 uint64
var x64 uint64
x63, x64 = bits.Add64(x25, x50, uint64(p384Uint1(x62)))
var x65 uint64
var x66 uint64
x65, x66 = bits.Add64(x27, x52, uint64(p384Uint1(x64)))
var x67 uint64
var x68 uint64
x67, x68 = bits.Add64(x29, x54, uint64(p384Uint1(x66)))
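// Round for limb x1 = arg1[1]: multiply it by every limb of arg1, accumulate, then Montgomery-reduce one word.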
var x69 uint64
var x70 uint64
x70, x69 = bits.Mul64(x1, arg1[5])
var x71 uint64
var x72 uint64
x72, x71 = bits.Mul64(x1, arg1[4])
var x73 uint64
var x74 uint64
x74, x73 = bits.Mul64(x1, arg1[3])
var x75 uint64
var x76 uint64
x76, x75 = bits.Mul64(x1, arg1[2])
var x77 uint64
var x78 uint64
x78, x77 = bits.Mul64(x1, arg1[1])
var x79 uint64
var x80 uint64
x80, x79 = bits.Mul64(x1, arg1[0])
var x81 uint64
var x82 uint64
x81, x82 = bits.Add64(x80, x77, uint64(0x0))
var x83 uint64
var x84 uint64
x83, x84 = bits.Add64(x78, x75, uint64(p384Uint1(x82)))
var x85 uint64
var x86 uint64
x85, x86 = bits.Add64(x76, x73, uint64(p384Uint1(x84)))
var x87 uint64
var x88 uint64
x87, x88 = bits.Add64(x74, x71, uint64(p384Uint1(x86)))
var x89 uint64
var x90 uint64
x89, x90 = bits.Add64(x72, x69, uint64(p384Uint1(x88)))
x91 := (uint64(p384Uint1(x90)) + x70)
var x92 uint64
var x93 uint64
x92, x93 = bits.Add64(x57, x79, uint64(0x0))
var x94 uint64
var x95 uint64
x94, x95 = bits.Add64(x59, x81, uint64(p384Uint1(x93)))
var x96 uint64
var x97 uint64
x96, x97 = bits.Add64(x61, x83, uint64(p384Uint1(x95)))
var x98 uint64
var x99 uint64
x98, x99 = bits.Add64(x63, x85, uint64(p384Uint1(x97)))
var x100 uint64
var x101 uint64
x100, x101 = bits.Add64(x65, x87, uint64(p384Uint1(x99)))
var x102 uint64
var x103 uint64
x102, x103 = bits.Add64(x67, x89, uint64(p384Uint1(x101)))
var x104 uint64
var x105 uint64
x104, x105 = bits.Add64(uint64(p384Uint1(x68)), x91, uint64(p384Uint1(x103)))
var x106 uint64
_, x106 = bits.Mul64(x92, 0x100000001)
var x108 uint64
var x109 uint64
x109, x108 = bits.Mul64(x106, 0xffffffffffffffff)
var x110 uint64
var x111 uint64
x111, x110 = bits.Mul64(x106, 0xffffffffffffffff)
var x112 uint64
var x113 uint64
x113, x112 = bits.Mul64(x106, 0xffffffffffffffff)
var x114 uint64
var x115 uint64
x115, x114 = bits.Mul64(x106, 0xfffffffffffffffe)
var x116 uint64
var x117 uint64
x117, x116 = bits.Mul64(x106, 0xffffffff00000000)
var x118 uint64
var x119 uint64
x119, x118 = bits.Mul64(x106, 0xffffffff)
var x120 uint64
var x121 uint64
x120, x121 = bits.Add64(x119, x116, uint64(0x0))
var x122 uint64
var x123 uint64
x122, x123 = bits.Add64(x117, x114, uint64(p384Uint1(x121)))
var x124 uint64
var x125 uint64
x124, x125 = bits.Add64(x115, x112, uint64(p384Uint1(x123)))
var x126 uint64
var x127 uint64
x126, x127 = bits.Add64(x113, x110, uint64(p384Uint1(x125)))
var x128 uint64
var x129 uint64
x128, x129 = bits.Add64(x111, x108, uint64(p384Uint1(x127)))
x130 := (uint64(p384Uint1(x129)) + x109)
var x132 uint64
_, x132 = bits.Add64(x92, x118, uint64(0x0))
var x133 uint64
var x134 uint64
x133, x134 = bits.Add64(x94, x120, uint64(p384Uint1(x132)))
var x135 uint64
var x136 uint64
x135, x136 = bits.Add64(x96, x122, uint64(p384Uint1(x134)))
var x137 uint64
var x138 uint64
x137, x138 = bits.Add64(x98, x124, uint64(p384Uint1(x136)))
var x139 uint64
var x140 uint64
x139, x140 = bits.Add64(x100, x126, uint64(p384Uint1(x138)))
var x141 uint64
var x142 uint64
x141, x142 = bits.Add64(x102, x128, uint64(p384Uint1(x140)))
var x143 uint64
var x144 uint64
x143, x144 = bits.Add64(x104, x130, uint64(p384Uint1(x142)))
x145 := (uint64(p384Uint1(x144)) + uint64(p384Uint1(x105)))
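// Round for limb x2 = arg1[2]: multiply, accumulate, then Montgomery-reduce one word.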
var x146 uint64
var x147 uint64
x147, x146 = bits.Mul64(x2, arg1[5])
var x148 uint64
var x149 uint64
x149, x148 = bits.Mul64(x2, arg1[4])
var x150 uint64
var x151 uint64
x151, x150 = bits.Mul64(x2, arg1[3])
var x152 uint64
var x153 uint64
x153, x152 = bits.Mul64(x2, arg1[2])
var x154 uint64
var x155 uint64
x155, x154 = bits.Mul64(x2, arg1[1])
var x156 uint64
var x157 uint64
x157, x156 = bits.Mul64(x2, arg1[0])
var x158 uint64
var x159 uint64
x158, x159 = bits.Add64(x157, x154, uint64(0x0))
var x160 uint64
var x161 uint64
x160, x161 = bits.Add64(x155, x152, uint64(p384Uint1(x159)))
var x162 uint64
var x163 uint64
x162, x163 = bits.Add64(x153, x150, uint64(p384Uint1(x161)))
var x164 uint64
var x165 uint64
x164, x165 = bits.Add64(x151, x148, uint64(p384Uint1(x163)))
var x166 uint64
var x167 uint64
x166, x167 = bits.Add64(x149, x146, uint64(p384Uint1(x165)))
x168 := (uint64(p384Uint1(x167)) + x147)
var x169 uint64
var x170 uint64
x169, x170 = bits.Add64(x133, x156, uint64(0x0))
var x171 uint64
var x172 uint64
x171, x172 = bits.Add64(x135, x158, uint64(p384Uint1(x170)))
var x173 uint64
var x174 uint64
x173, x174 = bits.Add64(x137, x160, uint64(p384Uint1(x172)))
var x175 uint64
var x176 uint64
x175, x176 = bits.Add64(x139, x162, uint64(p384Uint1(x174)))
var x177 uint64
var x178 uint64
x177, x178 = bits.Add64(x141, x164, uint64(p384Uint1(x176)))
var x179 uint64
var x180 uint64
x179, x180 = bits.Add64(x143, x166, uint64(p384Uint1(x178)))
var x181 uint64
var x182 uint64
x181, x182 = bits.Add64(x145, x168, uint64(p384Uint1(x180)))
var x183 uint64
_, x183 = bits.Mul64(x169, 0x100000001)
var x185 uint64
var x186 uint64
x186, x185 = bits.Mul64(x183, 0xffffffffffffffff)
var x187 uint64
var x188 uint64
x188, x187 = bits.Mul64(x183, 0xffffffffffffffff)
var x189 uint64
var x190 uint64
x190, x189 = bits.Mul64(x183, 0xffffffffffffffff)
var x191 uint64
var x192 uint64
x192, x191 = bits.Mul64(x183, 0xfffffffffffffffe)
var x193 uint64
var x194 uint64
x194, x193 = bits.Mul64(x183, 0xffffffff00000000)
var x195 uint64
var x196 uint64
x196, x195 = bits.Mul64(x183, 0xffffffff)
var x197 uint64
var x198 uint64
x197, x198 = bits.Add64(x196, x193, uint64(0x0))
var x199 uint64
var x200 uint64
x199, x200 = bits.Add64(x194, x191, uint64(p384Uint1(x198)))
var x201 uint64
var x202 uint64
x201, x202 = bits.Add64(x192, x189, uint64(p384Uint1(x200)))
var x203 uint64
var x204 uint64
x203, x204 = bits.Add64(x190, x187, uint64(p384Uint1(x202)))
var x205 uint64
var x206 uint64
x205, x206 = bits.Add64(x188, x185, uint64(p384Uint1(x204)))
x207 := (uint64(p384Uint1(x206)) + x186)
var x209 uint64
_, x209 = bits.Add64(x169, x195, uint64(0x0))
var x210 uint64
var x211 uint64
x210, x211 = bits.Add64(x171, x197, uint64(p384Uint1(x209)))
var x212 uint64
var x213 uint64
x212, x213 = bits.Add64(x173, x199, uint64(p384Uint1(x211)))
var x214 uint64
var x215 uint64
x214, x215 = bits.Add64(x175, x201, uint64(p384Uint1(x213)))
var x216 uint64
var x217 uint64
x216, x217 = bits.Add64(x177, x203, uint64(p384Uint1(x215)))
var x218 uint64
var x219 uint64
x218, x219 = bits.Add64(x179, x205, uint64(p384Uint1(x217)))
var x220 uint64
var x221 uint64
x220, x221 = bits.Add64(x181, x207, uint64(p384Uint1(x219)))
x222 := (uint64(p384Uint1(x221)) + uint64(p384Uint1(x182)))
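// Round for limb x3 = arg1[3]: multiply, accumulate, then Montgomery-reduce one word.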
var x223 uint64
var x224 uint64
x224, x223 = bits.Mul64(x3, arg1[5])
var x225 uint64
var x226 uint64
x226, x225 = bits.Mul64(x3, arg1[4])
var x227 uint64
var x228 uint64
x228, x227 = bits.Mul64(x3, arg1[3])
var x229 uint64
var x230 uint64
x230, x229 = bits.Mul64(x3, arg1[2])
var x231 uint64
var x232 uint64
x232, x231 = bits.Mul64(x3, arg1[1])
var x233 uint64
var x234 uint64
x234, x233 = bits.Mul64(x3, arg1[0])
var x235 uint64
var x236 uint64
x235, x236 = bits.Add64(x234, x231, uint64(0x0))
var x237 uint64
var x238 uint64
x237, x238 = bits.Add64(x232, x229, uint64(p384Uint1(x236)))
var x239 uint64
var x240 uint64
x239, x240 = bits.Add64(x230, x227, uint64(p384Uint1(x238)))
var x241 uint64
var x242 uint64
x241, x242 = bits.Add64(x228, x225, uint64(p384Uint1(x240)))
var x243 uint64
var x244 uint64
x243, x244 = bits.Add64(x226, x223, uint64(p384Uint1(x242)))
x245 := (uint64(p384Uint1(x244)) + x224)
var x246 uint64
var x247 uint64
x246, x247 = bits.Add64(x210, x233, uint64(0x0))
var x248 uint64
var x249 uint64
x248, x249 = bits.Add64(x212, x235, uint64(p384Uint1(x247)))
var x250 uint64
var x251 uint64
x250, x251 = bits.Add64(x214, x237, uint64(p384Uint1(x249)))
var x252 uint64
var x253 uint64
x252, x253 = bits.Add64(x216, x239, uint64(p384Uint1(x251)))
var x254 uint64
var x255 uint64
x254, x255 = bits.Add64(x218, x241, uint64(p384Uint1(x253)))
var x256 uint64
var x257 uint64
x256, x257 = bits.Add64(x220, x243, uint64(p384Uint1(x255)))
var x258 uint64
var x259 uint64
x258, x259 = bits.Add64(x222, x245, uint64(p384Uint1(x257)))
var x260 uint64
_, x260 = bits.Mul64(x246, 0x100000001)
var x262 uint64
var x263 uint64
x263, x262 = bits.Mul64(x260, 0xffffffffffffffff)
var x264 uint64
var x265 uint64
x265, x264 = bits.Mul64(x260, 0xffffffffffffffff)
var x266 uint64
var x267 uint64
x267, x266 = bits.Mul64(x260, 0xffffffffffffffff)
var x268 uint64
var x269 uint64
x269, x268 = bits.Mul64(x260, 0xfffffffffffffffe)
var x270 uint64
var x271 uint64
x271, x270 = bits.Mul64(x260, 0xffffffff00000000)
var x272 uint64
var x273 uint64
x273, x272 = bits.Mul64(x260, 0xffffffff)
var x274 uint64
var x275 uint64
x274, x275 = bits.Add64(x273, x270, uint64(0x0))
var x276 uint64
var x277 uint64
x276, x277 = bits.Add64(x271, x268, uint64(p384Uint1(x275)))
var x278 uint64
var x279 uint64
x278, x279 = bits.Add64(x269, x266, uint64(p384Uint1(x277)))
var x280 uint64
var x281 uint64
x280, x281 = bits.Add64(x267, x264, uint64(p384Uint1(x279)))
var x282 uint64
var x283 uint64
x282, x283 = bits.Add64(x265, x262, uint64(p384Uint1(x281)))
x284 := (uint64(p384Uint1(x283)) + x263)
var x286 uint64
_, x286 = bits.Add64(x246, x272, uint64(0x0))
var x287 uint64
var x288 uint64
x287, x288 = bits.Add64(x248, x274, uint64(p384Uint1(x286)))
var x289 uint64
var x290 uint64
x289, x290 = bits.Add64(x250, x276, uint64(p384Uint1(x288)))
var x291 uint64
var x292 uint64
x291, x292 = bits.Add64(x252, x278, uint64(p384Uint1(x290)))
var x293 uint64
var x294 uint64
x293, x294 = bits.Add64(x254, x280, uint64(p384Uint1(x292)))
var x295 uint64
var x296 uint64
x295, x296 = bits.Add64(x256, x282, uint64(p384Uint1(x294)))
var x297 uint64
var x298 uint64
x297, x298 = bits.Add64(x258, x284, uint64(p384Uint1(x296)))
x299 := (uint64(p384Uint1(x298)) + uint64(p384Uint1(x259)))
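// Round for limb x4 = arg1[4]: multiply, accumulate, then Montgomery-reduce one word.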
var x300 uint64
var x301 uint64
x301, x300 = bits.Mul64(x4, arg1[5])
var x302 uint64
var x303 uint64
x303, x302 = bits.Mul64(x4, arg1[4])
var x304 uint64
var x305 uint64
x305, x304 = bits.Mul64(x4, arg1[3])
var x306 uint64
var x307 uint64
x307, x306 = bits.Mul64(x4, arg1[2])
var x308 uint64
var x309 uint64
x309, x308 = bits.Mul64(x4, arg1[1])
var x310 uint64
var x311 uint64
x311, x310 = bits.Mul64(x4, arg1[0])
var x312 uint64
var x313 uint64
x312, x313 = bits.Add64(x311, x308, uint64(0x0))
var x314 uint64
var x315 uint64
x314, x315 = bits.Add64(x309, x306, uint64(p384Uint1(x313)))
var x316 uint64
var x317 uint64
x316, x317 = bits.Add64(x307, x304, uint64(p384Uint1(x315)))
var x318 uint64
var x319 uint64
x318, x319 = bits.Add64(x305, x302, uint64(p384Uint1(x317)))
var x320 uint64
var x321 uint64
x320, x321 = bits.Add64(x303, x300, uint64(p384Uint1(x319)))
x322 := (uint64(p384Uint1(x321)) + x301)
var x323 uint64
var x324 uint64
x323, x324 = bits.Add64(x287, x310, uint64(0x0))
var x325 uint64
var x326 uint64
x325, x326 = bits.Add64(x289, x312, uint64(p384Uint1(x324)))
var x327 uint64
var x328 uint64
x327, x328 = bits.Add64(x291, x314, uint64(p384Uint1(x326)))
var x329 uint64
var x330 uint64
x329, x330 = bits.Add64(x293, x316, uint64(p384Uint1(x328)))
var x331 uint64
var x332 uint64
x331, x332 = bits.Add64(x295, x318, uint64(p384Uint1(x330)))
var x333 uint64
var x334 uint64
x333, x334 = bits.Add64(x297, x320, uint64(p384Uint1(x332)))
var x335 uint64
var x336 uint64
x335, x336 = bits.Add64(x299, x322, uint64(p384Uint1(x334)))
var x337 uint64
_, x337 = bits.Mul64(x323, 0x100000001)
var x339 uint64
var x340 uint64
x340, x339 = bits.Mul64(x337, 0xffffffffffffffff)
var x341 uint64
var x342 uint64
x342, x341 = bits.Mul64(x337, 0xffffffffffffffff)
var x343 uint64
var x344 uint64
x344, x343 = bits.Mul64(x337, 0xffffffffffffffff)
var x345 uint64
var x346 uint64
x346, x345 = bits.Mul64(x337, 0xfffffffffffffffe)
var x347 uint64
var x348 uint64
x348, x347 = bits.Mul64(x337, 0xffffffff00000000)
var x349 uint64
var x350 uint64
x350, x349 = bits.Mul64(x337, 0xffffffff)
var x351 uint64
var x352 uint64
x351, x352 = bits.Add64(x350, x347, uint64(0x0))
var x353 uint64
var x354 uint64
x353, x354 = bits.Add64(x348, x345, uint64(p384Uint1(x352)))
var x355 uint64
var x356 uint64
x355, x356 = bits.Add64(x346, x343, uint64(p384Uint1(x354)))
var x357 uint64
var x358 uint64
x357, x358 = bits.Add64(x344, x341, uint64(p384Uint1(x356)))
var x359 uint64
var x360 uint64
x359, x360 = bits.Add64(x342, x339, uint64(p384Uint1(x358)))
x361 := (uint64(p384Uint1(x360)) + x340)
var x363 uint64
_, x363 = bits.Add64(x323, x349, uint64(0x0))
var x364 uint64
var x365 uint64
x364, x365 = bits.Add64(x325, x351, uint64(p384Uint1(x363)))
var x366 uint64
var x367 uint64
x366, x367 = bits.Add64(x327, x353, uint64(p384Uint1(x365)))
var x368 uint64
var x369 uint64
x368, x369 = bits.Add64(x329, x355, uint64(p384Uint1(x367)))
var x370 uint64
var x371 uint64
x370, x371 = bits.Add64(x331, x357, uint64(p384Uint1(x369)))
var x372 uint64
var x373 uint64
x372, x373 = bits.Add64(x333, x359, uint64(p384Uint1(x371)))
var x374 uint64
var x375 uint64
x374, x375 = bits.Add64(x335, x361, uint64(p384Uint1(x373)))
x376 := (uint64(p384Uint1(x375)) + uint64(p384Uint1(x336)))
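// Final round for limb x5 = arg1[5]: multiply, accumulate, then Montgomery-reduce one word.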
var x377 uint64
var x378 uint64
x378, x377 = bits.Mul64(x5, arg1[5])
var x379 uint64
var x380 uint64
x380, x379 = bits.Mul64(x5, arg1[4])
var x381 uint64
var x382 uint64
x382, x381 = bits.Mul64(x5, arg1[3])
var x383 uint64
var x384 uint64
x384, x383 = bits.Mul64(x5, arg1[2])
var x385 uint64
var x386 uint64
x386, x385 = bits.Mul64(x5, arg1[1])
var x387 uint64
var x388 uint64
x388, x387 = bits.Mul64(x5, arg1[0])
var x389 uint64
var x390 uint64
x389, x390 = bits.Add64(x388, x385, uint64(0x0))
var x391 uint64
var x392 uint64
x391, x392 = bits.Add64(x386, x383, uint64(p384Uint1(x390)))
var x393 uint64
var x394 uint64
x393, x394 = bits.Add64(x384, x381, uint64(p384Uint1(x392)))
var x395 uint64
var x396 uint64
x395, x396 = bits.Add64(x382, x379, uint64(p384Uint1(x394)))
var x397 uint64
var x398 uint64
x397, x398 = bits.Add64(x380, x377, uint64(p384Uint1(x396)))
x399 := (uint64(p384Uint1(x398)) + x378)
var x400 uint64
var x401 uint64
x400, x401 = bits.Add64(x364, x387, uint64(0x0))
var x402 uint64
var x403 uint64
x402, x403 = bits.Add64(x366, x389, uint64(p384Uint1(x401)))
var x404 uint64
var x405 uint64
x404, x405 = bits.Add64(x368, x391, uint64(p384Uint1(x403)))
var x406 uint64
var x407 uint64
x406, x407 = bits.Add64(x370, x393, uint64(p384Uint1(x405)))
var x408 uint64
var x409 uint64
x408, x409 = bits.Add64(x372, x395, uint64(p384Uint1(x407)))
var x410 uint64
var x411 uint64
x410, x411 = bits.Add64(x374, x397, uint64(p384Uint1(x409)))
var x412 uint64
var x413 uint64
x412, x413 = bits.Add64(x376, x399, uint64(p384Uint1(x411)))
var x414 uint64
_, x414 = bits.Mul64(x400, 0x100000001)
var x416 uint64
var x417 uint64
x417, x416 = bits.Mul64(x414, 0xffffffffffffffff)
var x418 uint64
var x419 uint64
x419, x418 = bits.Mul64(x414, 0xffffffffffffffff)
var x420 uint64
var x421 uint64
x421, x420 = bits.Mul64(x414, 0xffffffffffffffff)
var x422 uint64
var x423 uint64
x423, x422 = bits.Mul64(x414, 0xfffffffffffffffe)
var x424 uint64
var x425 uint64
x425, x424 = bits.Mul64(x414, 0xffffffff00000000)
var x426 uint64
var x427 uint64
x427, x426 = bits.Mul64(x414, 0xffffffff)
var x428 uint64
var x429 uint64
x428, x429 = bits.Add64(x427, x424, uint64(0x0))
var x430 uint64
var x431 uint64
x430, x431 = bits.Add64(x425, x422, uint64(p384Uint1(x429)))
var x432 uint64
var x433 uint64
x432, x433 = bits.Add64(x423, x420, uint64(p384Uint1(x431)))
var x434 uint64
var x435 uint64
x434, x435 = bits.Add64(x421, x418, uint64(p384Uint1(x433)))
var x436 uint64
var x437 uint64
x436, x437 = bits.Add64(x419, x416, uint64(p384Uint1(x435)))
x438 := (uint64(p384Uint1(x437)) + x417)
var x440 uint64
_, x440 = bits.Add64(x400, x426, uint64(0x0))
var x441 uint64
var x442 uint64
x441, x442 = bits.Add64(x402, x428, uint64(p384Uint1(x440)))
var x443 uint64
var x444 uint64
x443, x444 = bits.Add64(x404, x430, uint64(p384Uint1(x442)))
var x445 uint64
var x446 uint64
x445, x446 = bits.Add64(x406, x432, uint64(p384Uint1(x444)))
var x447 uint64
var x448 uint64
x447, x448 = bits.Add64(x408, x434, uint64(p384Uint1(x446)))
var x449 uint64
var x450 uint64
x449, x450 = bits.Add64(x410, x436, uint64(p384Uint1(x448)))
var x451 uint64
var x452 uint64
x451, x452 = bits.Add64(x412, x438, uint64(p384Uint1(x450)))
x453 := (uint64(p384Uint1(x452)) + uint64(p384Uint1(x413)))
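// Final conditional subtraction: subtract m once and, unless the subtraction borrowed, keep the difference, leaving the result in [0, m).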
var x454 uint64
var x455 uint64
x454, x455 = bits.Sub64(x441, 0xffffffff, uint64(0x0))
var x456 uint64
var x457 uint64
x456, x457 = bits.Sub64(x443, 0xffffffff00000000, uint64(p384Uint1(x455)))
var x458 uint64
var x459 uint64
x458, x459 = bits.Sub64(x445, 0xfffffffffffffffe, uint64(p384Uint1(x457)))
var x460 uint64
var x461 uint64
x460, x461 = bits.Sub64(x447, 0xffffffffffffffff, uint64(p384Uint1(x459)))
var x462 uint64
var x463 uint64
x462, x463 = bits.Sub64(x449, 0xffffffffffffffff, uint64(p384Uint1(x461)))
var x464 uint64
var x465 uint64
x464, x465 = bits.Sub64(x451, 0xffffffffffffffff, uint64(p384Uint1(x463)))
var x467 uint64
_, x467 = bits.Sub64(x453, uint64(0x0), uint64(p384Uint1(x465)))
var x468 uint64
p384CmovznzU64(&x468, p384Uint1(x467), x454, x441)
var x469 uint64
p384CmovznzU64(&x469, p384Uint1(x467), x456, x443)
var x470 uint64
p384CmovznzU64(&x470, p384Uint1(x467), x458, x445)
var x471 uint64
p384CmovznzU64(&x471, p384Uint1(x467), x460, x447)
var x472 uint64
p384CmovznzU64(&x472, p384Uint1(x467), x462, x449)
var x473 uint64
p384CmovznzU64(&x473, p384Uint1(x467), x464, x451)
out1[0] = x468
out1[1] = x469
out1[2] = x470
out1[3] = x471
out1[4] = x472
out1[5] = x473
}
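// p384Square runs the same word-by-word Montgomery ladder as p384Mul with
// both operands equal to arg1; the generated code does not exploit the
// symmetry of squaring. A minimal sketch (illustrative names):
//
//	var x, sq p384MontgomeryDomainFieldElement
//	p384Square(&sq, &x) // sq represents x·x in the Montgomery domain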
// p384Add adds two field elements in the Montgomery domain.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
// 0 ≤ eval arg2 < m
//
// Postconditions:
//
// eval (from_montgomery out1) mod m = (eval (from_montgomery arg1) + eval (from_montgomery arg2)) mod m
// 0 ≤ eval out1 < m
func p384Add(out1 *p384MontgomeryDomainFieldElement, arg1 *p384MontgomeryDomainFieldElement, arg2 *p384MontgomeryDomainFieldElement) {
var x1 uint64
var x2 uint64
x1, x2 = bits.Add64(arg1[0], arg2[0], uint64(0x0))
var x3 uint64
var x4 uint64
x3, x4 = bits.Add64(arg1[1], arg2[1], uint64(p384Uint1(x2)))
var x5 uint64
var x6 uint64
x5, x6 = bits.Add64(arg1[2], arg2[2], uint64(p384Uint1(x4)))
var x7 uint64
var x8 uint64
x7, x8 = bits.Add64(arg1[3], arg2[3], uint64(p384Uint1(x6)))
var x9 uint64
var x10 uint64
x9, x10 = bits.Add64(arg1[4], arg2[4], uint64(p384Uint1(x8)))
var x11 uint64
var x12 uint64
x11, x12 = bits.Add64(arg1[5], arg2[5], uint64(p384Uint1(x10)))
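// Speculatively subtract m; the conditional moves below keep the difference only if it did not borrow.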
var x13 uint64
var x14 uint64
x13, x14 = bits.Sub64(x1, 0xffffffff, uint64(0x0))
var x15 uint64
var x16 uint64
x15, x16 = bits.Sub64(x3, 0xffffffff00000000, uint64(p384Uint1(x14)))
var x17 uint64
var x18 uint64
x17, x18 = bits.Sub64(x5, 0xfffffffffffffffe, uint64(p384Uint1(x16)))
var x19 uint64
var x20 uint64
x19, x20 = bits.Sub64(x7, 0xffffffffffffffff, uint64(p384Uint1(x18)))
var x21 uint64
var x22 uint64
x21, x22 = bits.Sub64(x9, 0xffffffffffffffff, uint64(p384Uint1(x20)))
var x23 uint64
var x24 uint64
x23, x24 = bits.Sub64(x11, 0xffffffffffffffff, uint64(p384Uint1(x22)))
var x26 uint64
_, x26 = bits.Sub64(uint64(p384Uint1(x12)), uint64(0x0), uint64(p384Uint1(x24)))
var x27 uint64
p384CmovznzU64(&x27, p384Uint1(x26), x13, x1)
var x28 uint64
p384CmovznzU64(&x28, p384Uint1(x26), x15, x3)
var x29 uint64
p384CmovznzU64(&x29, p384Uint1(x26), x17, x5)
var x30 uint64
p384CmovznzU64(&x30, p384Uint1(x26), x19, x7)
var x31 uint64
p384CmovznzU64(&x31, p384Uint1(x26), x21, x9)
var x32 uint64
p384CmovznzU64(&x32, p384Uint1(x26), x23, x11)
out1[0] = x27
out1[1] = x28
out1[2] = x29
out1[3] = x30
out1[4] = x31
out1[5] = x32
}
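// The addition computes the full carry chain, speculatively subtracts m, and
// selects the reduced value with p384CmovznzU64 rather than a branch, so the
// function runs in constant time. Doubling is addition with both arguments
// aliased (sketch, illustrative names):
//
//	var x, dbl p384MontgomeryDomainFieldElement
//	p384Add(&dbl, &x, &x) // dbl represents 2·x mod m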
// p384Sub subtracts the second field element from the first in the Montgomery domain.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
// 0 ≤ eval arg2 < m
//
// Postconditions:
//
// eval (from_montgomery out1) mod m = (eval (from_montgomery arg1) - eval (from_montgomery arg2)) mod m
// 0 ≤ eval out1 < m
func p384Sub(out1 *p384MontgomeryDomainFieldElement, arg1 *p384MontgomeryDomainFieldElement, arg2 *p384MontgomeryDomainFieldElement) {
var x1 uint64
var x2 uint64
x1, x2 = bits.Sub64(arg1[0], arg2[0], uint64(0x0))
var x3 uint64
var x4 uint64
x3, x4 = bits.Sub64(arg1[1], arg2[1], uint64(p384Uint1(x2)))
var x5 uint64
var x6 uint64
x5, x6 = bits.Sub64(arg1[2], arg2[2], uint64(p384Uint1(x4)))
var x7 uint64
var x8 uint64
x7, x8 = bits.Sub64(arg1[3], arg2[3], uint64(p384Uint1(x6)))
var x9 uint64
var x10 uint64
x9, x10 = bits.Sub64(arg1[4], arg2[4], uint64(p384Uint1(x8)))
var x11 uint64
var x12 uint64
x11, x12 = bits.Sub64(arg1[5], arg2[5], uint64(p384Uint1(x10)))
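// If the limb-wise subtraction borrowed, x13 becomes an all-ones mask and the masked limbs of m are added back below.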
var x13 uint64
p384CmovznzU64(&x13, p384Uint1(x12), uint64(0x0), 0xffffffffffffffff)
var x14 uint64
var x15 uint64
x14, x15 = bits.Add64(x1, (x13 & 0xffffffff), uint64(0x0))
var x16 uint64
var x17 uint64
x16, x17 = bits.Add64(x3, (x13 & 0xffffffff00000000), uint64(p384Uint1(x15)))
var x18 uint64
var x19 uint64
x18, x19 = bits.Add64(x5, (x13 & 0xfffffffffffffffe), uint64(p384Uint1(x17)))
var x20 uint64
var x21 uint64
x20, x21 = bits.Add64(x7, x13, uint64(p384Uint1(x19)))
var x22 uint64
var x23 uint64
x22, x23 = bits.Add64(x9, x13, uint64(p384Uint1(x21)))
var x24 uint64
x24, _ = bits.Add64(x11, x13, uint64(p384Uint1(x23)))
out1[0] = x14
out1[1] = x16
out1[2] = x18
out1[3] = x20
out1[4] = x22
out1[5] = x24
}
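// The masked constants (x13 & 0xffffffff), (x13 & 0xffffffff00000000),
// (x13 & 0xfffffffffffffffe) and x13 itself reproduce the limbs of m, so a
// borrowed subtraction is corrected by adding m exactly once. Negation is
// subtraction from zero (sketch, illustrative names):
//
//	var zero, x, neg p384MontgomeryDomainFieldElement
//	p384Sub(&neg, &zero, &x) // neg represents -x mod m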
// p384SetOne sets out1 to the field element one in the Montgomery domain.
//
// Postconditions:
//
// eval (from_montgomery out1) mod m = 1 mod m
// 0 ≤ eval out1 < m
func p384SetOne(out1 *p384MontgomeryDomainFieldElement) {
out1[0] = 0xffffffff00000001
out1[1] = 0xffffffff
out1[2] = uint64(0x1)
out1[3] = uint64(0x0)
out1[4] = uint64(0x0)
out1[5] = uint64(0x0)
}
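// One in the Montgomery domain is R mod m with R = 2^384. Since m > 2^383,
// R mod m = 2^384 - m = 2^128 + 2^96 - 2^32 + 1, whose little-endian 64-bit
// limbs are exactly the constants stored above:
// [0xffffffff00000001, 0xffffffff, 0x1, 0, 0, 0].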
// p384FromMontgomery translates a field element out of the Montgomery domain.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
//
// Postconditions:
//
// eval out1 mod m = (eval arg1 * ((2^64)⁻¹ mod m)^6) mod m
// 0 ≤ eval out1 < m
func p384FromMontgomery(out1 *p384NonMontgomeryDomainFieldElement, arg1 *p384MontgomeryDomainFieldElement) {
x1 := arg1[0]
var x2 uint64
_, x2 = bits.Mul64(x1, 0x100000001)
var x4 uint64
var x5 uint64
x5, x4 = bits.Mul64(x2, 0xffffffffffffffff)
var x6 uint64
var x7 uint64
x7, x6 = bits.Mul64(x2, 0xffffffffffffffff)
var x8 uint64
var x9 uint64
x9, x8 = bits.Mul64(x2, 0xffffffffffffffff)
var x10 uint64
var x11 uint64
x11, x10 = bits.Mul64(x2, 0xfffffffffffffffe)
var x12 uint64
var x13 uint64
x13, x12 = bits.Mul64(x2, 0xffffffff00000000)
var x14 uint64
var x15 uint64
x15, x14 = bits.Mul64(x2, 0xffffffff)
var x16 uint64
var x17 uint64
x16, x17 = bits.Add64(x15, x12, uint64(0x0))
var x18 uint64
var x19 uint64
x18, x19 = bits.Add64(x13, x10, uint64(p384Uint1(x17)))
var x20 uint64
var x21 uint64
x20, x21 = bits.Add64(x11, x8, uint64(p384Uint1(x19)))
var x22 uint64
var x23 uint64
x22, x23 = bits.Add64(x9, x6, uint64(p384Uint1(x21)))
var x24 uint64
var x25 uint64
x24, x25 = bits.Add64(x7, x4, uint64(p384Uint1(x23)))
var x27 uint64
_, x27 = bits.Add64(x1, x14, uint64(0x0))
var x28 uint64
var x29 uint64
x28, x29 = bits.Add64(uint64(0x0), x16, uint64(p384Uint1(x27)))
var x30 uint64
var x31 uint64
x30, x31 = bits.Add64(uint64(0x0), x18, uint64(p384Uint1(x29)))
var x32 uint64
var x33 uint64
x32, x33 = bits.Add64(uint64(0x0), x20, uint64(p384Uint1(x31)))
var x34 uint64
var x35 uint64
x34, x35 = bits.Add64(uint64(0x0), x22, uint64(p384Uint1(x33)))
var x36 uint64
var x37 uint64
x36, x37 = bits.Add64(uint64(0x0), x24, uint64(p384Uint1(x35)))
var x38 uint64
var x39 uint64
x38, x39 = bits.Add64(uint64(0x0), (uint64(p384Uint1(x25)) + x5), uint64(p384Uint1(x37)))
var x40 uint64
var x41 uint64
x40, x41 = bits.Add64(x28, arg1[1], uint64(0x0))
var x42 uint64
var x43 uint64
x42, x43 = bits.Add64(x30, uint64(0x0), uint64(p384Uint1(x41)))
var x44 uint64
var x45 uint64
x44, x45 = bits.Add64(x32, uint64(0x0), uint64(p384Uint1(x43)))
var x46 uint64
var x47 uint64
x46, x47 = bits.Add64(x34, uint64(0x0), uint64(p384Uint1(x45)))
var x48 uint64
var x49 uint64
x48, x49 = bits.Add64(x36, uint64(0x0), uint64(p384Uint1(x47)))
var x50 uint64
var x51 uint64
x50, x51 = bits.Add64(x38, uint64(0x0), uint64(p384Uint1(x49)))
var x52 uint64
_, x52 = bits.Mul64(x40, 0x100000001)
var x54 uint64
var x55 uint64
x55, x54 = bits.Mul64(x52, 0xffffffffffffffff)
var x56 uint64
var x57 uint64
x57, x56 = bits.Mul64(x52, 0xffffffffffffffff)
var x58 uint64
var x59 uint64
x59, x58 = bits.Mul64(x52, 0xffffffffffffffff)
var x60 uint64
var x61 uint64
x61, x60 = bits.Mul64(x52, 0xfffffffffffffffe)
var x62 uint64
var x63 uint64
x63, x62 = bits.Mul64(x52, 0xffffffff00000000)
var x64 uint64
var x65 uint64
x65, x64 = bits.Mul64(x52, 0xffffffff)
var x66 uint64
var x67 uint64
x66, x67 = bits.Add64(x65, x62, uint64(0x0))
var x68 uint64
var x69 uint64
x68, x69 = bits.Add64(x63, x60, uint64(p384Uint1(x67)))
var x70 uint64
var x71 uint64
x70, x71 = bits.Add64(x61, x58, uint64(p384Uint1(x69)))
var x72 uint64
var x73 uint64
x72, x73 = bits.Add64(x59, x56, uint64(p384Uint1(x71)))
var x74 uint64
var x75 uint64
x74, x75 = bits.Add64(x57, x54, uint64(p384Uint1(x73)))
var x77 uint64
_, x77 = bits.Add64(x40, x64, uint64(0x0))
var x78 uint64
var x79 uint64
x78, x79 = bits.Add64(x42, x66, uint64(p384Uint1(x77)))
var x80 uint64
var x81 uint64
x80, x81 = bits.Add64(x44, x68, uint64(p384Uint1(x79)))
var x82 uint64
var x83 uint64
x82, x83 = bits.Add64(x46, x70, uint64(p384Uint1(x81)))
var x84 uint64
var x85 uint64
x84, x85 = bits.Add64(x48, x72, uint64(p384Uint1(x83)))
var x86 uint64
var x87 uint64
x86, x87 = bits.Add64(x50, x74, uint64(p384Uint1(x85)))
var x88 uint64
var x89 uint64
x88, x89 = bits.Add64((uint64(p384Uint1(x51)) + uint64(p384Uint1(x39))), (uint64(p384Uint1(x75)) + x55), uint64(p384Uint1(x87)))
var x90 uint64
var x91 uint64
x90, x91 = bits.Add64(x78, arg1[2], uint64(0x0))
var x92 uint64
var x93 uint64
x92, x93 = bits.Add64(x80, uint64(0x0), uint64(p384Uint1(x91)))
var x94 uint64
var x95 uint64
x94, x95 = bits.Add64(x82, uint64(0x0), uint64(p384Uint1(x93)))
var x96 uint64
var x97 uint64
x96, x97 = bits.Add64(x84, uint64(0x0), uint64(p384Uint1(x95)))
var x98 uint64
var x99 uint64
x98, x99 = bits.Add64(x86, uint64(0x0), uint64(p384Uint1(x97)))
var x100 uint64
var x101 uint64
x100, x101 = bits.Add64(x88, uint64(0x0), uint64(p384Uint1(x99)))
var x102 uint64
_, x102 = bits.Mul64(x90, 0x100000001)
var x104 uint64
var x105 uint64
x105, x104 = bits.Mul64(x102, 0xffffffffffffffff)
var x106 uint64
var x107 uint64
x107, x106 = bits.Mul64(x102, 0xffffffffffffffff)
var x108 uint64
var x109 uint64
x109, x108 = bits.Mul64(x102, 0xffffffffffffffff)
var x110 uint64
var x111 uint64
x111, x110 = bits.Mul64(x102, 0xfffffffffffffffe)
var x112 uint64
var x113 uint64
x113, x112 = bits.Mul64(x102, 0xffffffff00000000)
var x114 uint64
var x115 uint64
x115, x114 = bits.Mul64(x102, 0xffffffff)
var x116 uint64
var x117 uint64
x116, x117 = bits.Add64(x115, x112, uint64(0x0))
var x118 uint64
var x119 uint64
x118, x119 = bits.Add64(x113, x110, uint64(p384Uint1(x117)))
var x120 uint64
var x121 uint64
x120, x121 = bits.Add64(x111, x108, uint64(p384Uint1(x119)))
var x122 uint64
var x123 uint64
x122, x123 = bits.Add64(x109, x106, uint64(p384Uint1(x121)))
var x124 uint64
var x125 uint64
x124, x125 = bits.Add64(x107, x104, uint64(p384Uint1(x123)))
var x127 uint64
_, x127 = bits.Add64(x90, x114, uint64(0x0))
var x128 uint64
var x129 uint64
x128, x129 = bits.Add64(x92, x116, uint64(p384Uint1(x127)))
var x130 uint64
var x131 uint64
x130, x131 = bits.Add64(x94, x118, uint64(p384Uint1(x129)))
var x132 uint64
var x133 uint64
x132, x133 = bits.Add64(x96, x120, uint64(p384Uint1(x131)))
var x134 uint64
var x135 uint64
x134, x135 = bits.Add64(x98, x122, uint64(p384Uint1(x133)))
var x136 uint64
var x137 uint64
x136, x137 = bits.Add64(x100, x124, uint64(p384Uint1(x135)))
var x138 uint64
var x139 uint64
x138, x139 = bits.Add64((uint64(p384Uint1(x101)) + uint64(p384Uint1(x89))), (uint64(p384Uint1(x125)) + x105), uint64(p384Uint1(x137)))
var x140 uint64
var x141 uint64
x140, x141 = bits.Add64(x128, arg1[3], uint64(0x0))
var x142 uint64
var x143 uint64
x142, x143 = bits.Add64(x130, uint64(0x0), uint64(p384Uint1(x141)))
var x144 uint64
var x145 uint64
x144, x145 = bits.Add64(x132, uint64(0x0), uint64(p384Uint1(x143)))
var x146 uint64
var x147 uint64
x146, x147 = bits.Add64(x134, uint64(0x0), uint64(p384Uint1(x145)))
var x148 uint64
var x149 uint64
x148, x149 = bits.Add64(x136, uint64(0x0), uint64(p384Uint1(x147)))
var x150 uint64
var x151 uint64
x150, x151 = bits.Add64(x138, uint64(0x0), uint64(p384Uint1(x149)))
var x152 uint64
_, x152 = bits.Mul64(x140, 0x100000001)
var x154 uint64
var x155 uint64
x155, x154 = bits.Mul64(x152, 0xffffffffffffffff)
var x156 uint64
var x157 uint64
x157, x156 = bits.Mul64(x152, 0xffffffffffffffff)
var x158 uint64
var x159 uint64
x159, x158 = bits.Mul64(x152, 0xffffffffffffffff)
var x160 uint64
var x161 uint64
x161, x160 = bits.Mul64(x152, 0xfffffffffffffffe)
var x162 uint64
var x163 uint64
x163, x162 = bits.Mul64(x152, 0xffffffff00000000)
var x164 uint64
var x165 uint64
x165, x164 = bits.Mul64(x152, 0xffffffff)
var x166 uint64
var x167 uint64
x166, x167 = bits.Add64(x165, x162, uint64(0x0))
var x168 uint64
var x169 uint64
x168, x169 = bits.Add64(x163, x160, uint64(p384Uint1(x167)))
var x170 uint64
var x171 uint64
x170, x171 = bits.Add64(x161, x158, uint64(p384Uint1(x169)))
var x172 uint64
var x173 uint64
x172, x173 = bits.Add64(x159, x156, uint64(p384Uint1(x171)))
var x174 uint64
var x175 uint64
x174, x175 = bits.Add64(x157, x154, uint64(p384Uint1(x173)))
var x177 uint64
_, x177 = bits.Add64(x140, x164, uint64(0x0))
var x178 uint64
var x179 uint64
x178, x179 = bits.Add64(x142, x166, uint64(p384Uint1(x177)))
var x180 uint64
var x181 uint64
x180, x181 = bits.Add64(x144, x168, uint64(p384Uint1(x179)))
var x182 uint64
var x183 uint64
x182, x183 = bits.Add64(x146, x170, uint64(p384Uint1(x181)))
var x184 uint64
var x185 uint64
x184, x185 = bits.Add64(x148, x172, uint64(p384Uint1(x183)))
var x186 uint64
var x187 uint64
x186, x187 = bits.Add64(x150, x174, uint64(p384Uint1(x185)))
var x188 uint64
var x189 uint64
x188, x189 = bits.Add64((uint64(p384Uint1(x151)) + uint64(p384Uint1(x139))), (uint64(p384Uint1(x175)) + x155), uint64(p384Uint1(x187)))
var x190 uint64
var x191 uint64
x190, x191 = bits.Add64(x178, arg1[4], uint64(0x0))
var x192 uint64
var x193 uint64
x192, x193 = bits.Add64(x180, uint64(0x0), uint64(p384Uint1(x191)))
var x194 uint64
var x195 uint64
x194, x195 = bits.Add64(x182, uint64(0x0), uint64(p384Uint1(x193)))
var x196 uint64
var x197 uint64
x196, x197 = bits.Add64(x184, uint64(0x0), uint64(p384Uint1(x195)))
var x198 uint64
var x199 uint64
x198, x199 = bits.Add64(x186, uint64(0x0), uint64(p384Uint1(x197)))
var x200 uint64
var x201 uint64
x200, x201 = bits.Add64(x188, uint64(0x0), uint64(p384Uint1(x199)))
var x202 uint64
_, x202 = bits.Mul64(x190, 0x100000001)
var x204 uint64
var x205 uint64
x205, x204 = bits.Mul64(x202, 0xffffffffffffffff)
var x206 uint64
var x207 uint64
x207, x206 = bits.Mul64(x202, 0xffffffffffffffff)
var x208 uint64
var x209 uint64
x209, x208 = bits.Mul64(x202, 0xffffffffffffffff)
var x210 uint64
var x211 uint64
x211, x210 = bits.Mul64(x202, 0xfffffffffffffffe)
var x212 uint64
var x213 uint64
x213, x212 = bits.Mul64(x202, 0xffffffff00000000)
var x214 uint64
var x215 uint64
x215, x214 = bits.Mul64(x202, 0xffffffff)
var x216 uint64
var x217 uint64
x216, x217 = bits.Add64(x215, x212, uint64(0x0))
var x218 uint64
var x219 uint64
x218, x219 = bits.Add64(x213, x210, uint64(p384Uint1(x217)))
var x220 uint64
var x221 uint64
x220, x221 = bits.Add64(x211, x208, uint64(p384Uint1(x219)))
var x222 uint64
var x223 uint64
x222, x223 = bits.Add64(x209, x206, uint64(p384Uint1(x221)))
var x224 uint64
var x225 uint64
x224, x225 = bits.Add64(x207, x204, uint64(p384Uint1(x223)))
var x227 uint64
_, x227 = bits.Add64(x190, x214, uint64(0x0))
var x228 uint64
var x229 uint64
x228, x229 = bits.Add64(x192, x216, uint64(p384Uint1(x227)))
var x230 uint64
var x231 uint64
x230, x231 = bits.Add64(x194, x218, uint64(p384Uint1(x229)))
var x232 uint64
var x233 uint64
x232, x233 = bits.Add64(x196, x220, uint64(p384Uint1(x231)))
var x234 uint64
var x235 uint64
x234, x235 = bits.Add64(x198, x222, uint64(p384Uint1(x233)))
var x236 uint64
var x237 uint64
x236, x237 = bits.Add64(x200, x224, uint64(p384Uint1(x235)))
var x238 uint64
var x239 uint64
x238, x239 = bits.Add64((uint64(p384Uint1(x201)) + uint64(p384Uint1(x189))), (uint64(p384Uint1(x225)) + x205), uint64(p384Uint1(x237)))
var x240 uint64
var x241 uint64
x240, x241 = bits.Add64(x228, arg1[5], uint64(0x0))
var x242 uint64
var x243 uint64
x242, x243 = bits.Add64(x230, uint64(0x0), uint64(p384Uint1(x241)))
var x244 uint64
var x245 uint64
x244, x245 = bits.Add64(x232, uint64(0x0), uint64(p384Uint1(x243)))
var x246 uint64
var x247 uint64
x246, x247 = bits.Add64(x234, uint64(0x0), uint64(p384Uint1(x245)))
var x248 uint64
var x249 uint64
x248, x249 = bits.Add64(x236, uint64(0x0), uint64(p384Uint1(x247)))
var x250 uint64
var x251 uint64
x250, x251 = bits.Add64(x238, uint64(0x0), uint64(p384Uint1(x249)))
var x252 uint64
_, x252 = bits.Mul64(x240, 0x100000001)
var x254 uint64
var x255 uint64
x255, x254 = bits.Mul64(x252, 0xffffffffffffffff)
var x256 uint64
var x257 uint64
x257, x256 = bits.Mul64(x252, 0xffffffffffffffff)
var x258 uint64
var x259 uint64
x259, x258 = bits.Mul64(x252, 0xffffffffffffffff)
var x260 uint64
var x261 uint64
x261, x260 = bits.Mul64(x252, 0xfffffffffffffffe)
var x262 uint64
var x263 uint64
x263, x262 = bits.Mul64(x252, 0xffffffff00000000)
var x264 uint64
var x265 uint64
x265, x264 = bits.Mul64(x252, 0xffffffff)
var x266 uint64
var x267 uint64
x266, x267 = bits.Add64(x265, x262, uint64(0x0))
var x268 uint64
var x269 uint64
x268, x269 = bits.Add64(x263, x260, uint64(p384Uint1(x267)))
var x270 uint64
var x271 uint64
x270, x271 = bits.Add64(x261, x258, uint64(p384Uint1(x269)))
var x272 uint64
var x273 uint64
x272, x273 = bits.Add64(x259, x256, uint64(p384Uint1(x271)))
var x274 uint64
var x275 uint64
x274, x275 = bits.Add64(x257, x254, uint64(p384Uint1(x273)))
var x277 uint64
_, x277 = bits.Add64(x240, x264, uint64(0x0))
var x278 uint64
var x279 uint64
x278, x279 = bits.Add64(x242, x266, uint64(p384Uint1(x277)))
var x280 uint64
var x281 uint64
x280, x281 = bits.Add64(x244, x268, uint64(p384Uint1(x279)))
var x282 uint64
var x283 uint64
x282, x283 = bits.Add64(x246, x270, uint64(p384Uint1(x281)))
var x284 uint64
var x285 uint64
x284, x285 = bits.Add64(x248, x272, uint64(p384Uint1(x283)))
var x286 uint64
var x287 uint64
x286, x287 = bits.Add64(x250, x274, uint64(p384Uint1(x285)))
var x288 uint64
var x289 uint64
x288, x289 = bits.Add64((uint64(p384Uint1(x251)) + uint64(p384Uint1(x239))), (uint64(p384Uint1(x275)) + x255), uint64(p384Uint1(x287)))
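// Final conditional subtraction to bring the result into [0, m).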
var x290 uint64
var x291 uint64
x290, x291 = bits.Sub64(x278, 0xffffffff, uint64(0x0))
var x292 uint64
var x293 uint64
x292, x293 = bits.Sub64(x280, 0xffffffff00000000, uint64(p384Uint1(x291)))
var x294 uint64
var x295 uint64
x294, x295 = bits.Sub64(x282, 0xfffffffffffffffe, uint64(p384Uint1(x293)))
var x296 uint64
var x297 uint64
x296, x297 = bits.Sub64(x284, 0xffffffffffffffff, uint64(p384Uint1(x295)))
var x298 uint64
var x299 uint64
x298, x299 = bits.Sub64(x286, 0xffffffffffffffff, uint64(p384Uint1(x297)))
var x300 uint64
var x301 uint64
x300, x301 = bits.Sub64(x288, 0xffffffffffffffff, uint64(p384Uint1(x299)))
var x303 uint64
_, x303 = bits.Sub64(uint64(p384Uint1(x289)), uint64(0x0), uint64(p384Uint1(x301)))
var x304 uint64
p384CmovznzU64(&x304, p384Uint1(x303), x290, x278)
var x305 uint64
p384CmovznzU64(&x305, p384Uint1(x303), x292, x280)
var x306 uint64
p384CmovznzU64(&x306, p384Uint1(x303), x294, x282)
var x307 uint64
p384CmovznzU64(&x307, p384Uint1(x303), x296, x284)
var x308 uint64
p384CmovznzU64(&x308, p384Uint1(x303), x298, x286)
var x309 uint64
p384CmovznzU64(&x309, p384Uint1(x303), x300, x288)
out1[0] = x304
out1[1] = x305
out1[2] = x306
out1[3] = x307
out1[4] = x308
out1[5] = x309
}
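// Leaving the Montgomery domain is six reduction rounds with no second
// operand: each round folds one limb of arg1 and divides by 2^64 modulo m,
// for an overall factor of ((2^64)⁻¹)^6 = R⁻¹, as the postcondition states.
// Round-tripping through p384ToMontgomery recovers the original value
// (sketch, illustrative names):
//
//	var n p384NonMontgomeryDomainFieldElement
//	var mont p384MontgomeryDomainFieldElement
//	p384ToMontgomery(&mont, &n)
//	p384FromMontgomery(&n, &mont) // n is unchanged mod m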
// p384ToMontgomery translates a field element into the Montgomery domain.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
//
// Postconditions:
//
// eval (from_montgomery out1) mod m = eval arg1 mod m
// 0 ≤ eval out1 < m
func p384ToMontgomery(out1 *p384MontgomeryDomainFieldElement, arg1 *p384NonMontgomeryDomainFieldElement) {
x1 := arg1[1]
x2 := arg1[2]
x3 := arg1[3]
x4 := arg1[4]
x5 := arg1[5]
x6 := arg1[0]
var x7 uint64
var x8 uint64
x8, x7 = bits.Mul64(x6, 0x200000000)
var x9 uint64
var x10 uint64
x10, x9 = bits.Mul64(x6, 0xfffffffe00000000)
var x11 uint64
var x12 uint64
x12, x11 = bits.Mul64(x6, 0x200000000)
var x13 uint64
var x14 uint64
x14, x13 = bits.Mul64(x6, 0xfffffffe00000001)
var x15 uint64
var x16 uint64
x15, x16 = bits.Add64(x14, x11, uint64(0x0))
var x17 uint64
var x18 uint64
x17, x18 = bits.Add64(x12, x9, uint64(p384Uint1(x16)))
var x19 uint64
var x20 uint64
x19, x20 = bits.Add64(x10, x7, uint64(p384Uint1(x18)))
var x21 uint64
var x22 uint64
x21, x22 = bits.Add64(x8, x6, uint64(p384Uint1(x20)))
var x23 uint64
_, x23 = bits.Mul64(x13, 0x100000001)
var x25 uint64
var x26 uint64
x26, x25 = bits.Mul64(x23, 0xffffffffffffffff)
var x27 uint64
var x28 uint64
x28, x27 = bits.Mul64(x23, 0xffffffffffffffff)
var x29 uint64
var x30 uint64
x30, x29 = bits.Mul64(x23, 0xffffffffffffffff)
var x31 uint64
var x32 uint64
x32, x31 = bits.Mul64(x23, 0xfffffffffffffffe)
var x33 uint64
var x34 uint64
x34, x33 = bits.Mul64(x23, 0xffffffff00000000)
var x35 uint64
var x36 uint64
x36, x35 = bits.Mul64(x23, 0xffffffff)
var x37 uint64
var x38 uint64
x37, x38 = bits.Add64(x36, x33, uint64(0x0))
var x39 uint64
var x40 uint64
x39, x40 = bits.Add64(x34, x31, uint64(p384Uint1(x38)))
var x41 uint64
var x42 uint64
x41, x42 = bits.Add64(x32, x29, uint64(p384Uint1(x40)))
var x43 uint64
var x44 uint64
x43, x44 = bits.Add64(x30, x27, uint64(p384Uint1(x42)))
var x45 uint64
var x46 uint64
x45, x46 = bits.Add64(x28, x25, uint64(p384Uint1(x44)))
var x48 uint64
_, x48 = bits.Add64(x13, x35, uint64(0x0))
var x49 uint64
var x50 uint64
x49, x50 = bits.Add64(x15, x37, uint64(p384Uint1(x48)))
var x51 uint64
var x52 uint64
x51, x52 = bits.Add64(x17, x39, uint64(p384Uint1(x50)))
var x53 uint64
var x54 uint64
x53, x54 = bits.Add64(x19, x41, uint64(p384Uint1(x52)))
var x55 uint64
var x56 uint64
x55, x56 = bits.Add64(x21, x43, uint64(p384Uint1(x54)))
var x57 uint64
var x58 uint64
x57, x58 = bits.Add64(uint64(p384Uint1(x22)), x45, uint64(p384Uint1(x56)))
var x59 uint64
var x60 uint64
x59, x60 = bits.Add64(uint64(0x0), (uint64(p384Uint1(x46)) + x26), uint64(p384Uint1(x58)))
var x61 uint64
var x62 uint64
x62, x61 = bits.Mul64(x1, 0x200000000)
var x63 uint64
var x64 uint64
x64, x63 = bits.Mul64(x1, 0xfffffffe00000000)
var x65 uint64
var x66 uint64
x66, x65 = bits.Mul64(x1, 0x200000000)
var x67 uint64
var x68 uint64
x68, x67 = bits.Mul64(x1, 0xfffffffe00000001)
var x69 uint64
var x70 uint64
x69, x70 = bits.Add64(x68, x65, uint64(0x0))
var x71 uint64
var x72 uint64
x71, x72 = bits.Add64(x66, x63, uint64(p384Uint1(x70)))
var x73 uint64
var x74 uint64
x73, x74 = bits.Add64(x64, x61, uint64(p384Uint1(x72)))
var x75 uint64
var x76 uint64
x75, x76 = bits.Add64(x62, x1, uint64(p384Uint1(x74)))
var x77 uint64
var x78 uint64
x77, x78 = bits.Add64(x49, x67, uint64(0x0))
var x79 uint64
var x80 uint64
x79, x80 = bits.Add64(x51, x69, uint64(p384Uint1(x78)))
var x81 uint64
var x82 uint64
x81, x82 = bits.Add64(x53, x71, uint64(p384Uint1(x80)))
var x83 uint64
var x84 uint64
x83, x84 = bits.Add64(x55, x73, uint64(p384Uint1(x82)))
var x85 uint64
var x86 uint64
x85, x86 = bits.Add64(x57, x75, uint64(p384Uint1(x84)))
var x87 uint64
var x88 uint64
x87, x88 = bits.Add64(x59, uint64(p384Uint1(x76)), uint64(p384Uint1(x86)))
var x89 uint64
_, x89 = bits.Mul64(x77, 0x100000001)
var x91 uint64
var x92 uint64
x92, x91 = bits.Mul64(x89, 0xffffffffffffffff)
var x93 uint64
var x94 uint64
x94, x93 = bits.Mul64(x89, 0xffffffffffffffff)
var x95 uint64
var x96 uint64
x96, x95 = bits.Mul64(x89, 0xffffffffffffffff)
var x97 uint64
var x98 uint64
x98, x97 = bits.Mul64(x89, 0xfffffffffffffffe)
var x99 uint64
var x100 uint64
x100, x99 = bits.Mul64(x89, 0xffffffff00000000)
var x101 uint64
var x102 uint64
x102, x101 = bits.Mul64(x89, 0xffffffff)
var x103 uint64
var x104 uint64
x103, x104 = bits.Add64(x102, x99, uint64(0x0))
var x105 uint64
var x106 uint64
x105, x106 = bits.Add64(x100, x97, uint64(p384Uint1(x104)))
var x107 uint64
var x108 uint64
x107, x108 = bits.Add64(x98, x95, uint64(p384Uint1(x106)))
var x109 uint64
var x110 uint64
x109, x110 = bits.Add64(x96, x93, uint64(p384Uint1(x108)))
var x111 uint64
var x112 uint64
x111, x112 = bits.Add64(x94, x91, uint64(p384Uint1(x110)))
var x114 uint64
_, x114 = bits.Add64(x77, x101, uint64(0x0))
var x115 uint64
var x116 uint64
x115, x116 = bits.Add64(x79, x103, uint64(p384Uint1(x114)))
var x117 uint64
var x118 uint64
x117, x118 = bits.Add64(x81, x105, uint64(p384Uint1(x116)))
var x119 uint64
var x120 uint64
x119, x120 = bits.Add64(x83, x107, uint64(p384Uint1(x118)))
var x121 uint64
var x122 uint64
x121, x122 = bits.Add64(x85, x109, uint64(p384Uint1(x120)))
var x123 uint64
var x124 uint64
x123, x124 = bits.Add64(x87, x111, uint64(p384Uint1(x122)))
var x125 uint64
var x126 uint64
x125, x126 = bits.Add64((uint64(p384Uint1(x88)) + uint64(p384Uint1(x60))), (uint64(p384Uint1(x112)) + x92), uint64(p384Uint1(x124)))
var x127 uint64
var x128 uint64
x128, x127 = bits.Mul64(x2, 0x200000000)
var x129 uint64
var x130 uint64
x130, x129 = bits.Mul64(x2, 0xfffffffe00000000)
var x131 uint64
var x132 uint64
x132, x131 = bits.Mul64(x2, 0x200000000)
var x133 uint64
var x134 uint64
x134, x133 = bits.Mul64(x2, 0xfffffffe00000001)
var x135 uint64
var x136 uint64
x135, x136 = bits.Add64(x134, x131, uint64(0x0))
var x137 uint64
var x138 uint64
x137, x138 = bits.Add64(x132, x129, uint64(p384Uint1(x136)))
var x139 uint64
var x140 uint64
x139, x140 = bits.Add64(x130, x127, uint64(p384Uint1(x138)))
var x141 uint64
var x142 uint64
x141, x142 = bits.Add64(x128, x2, uint64(p384Uint1(x140)))
var x143 uint64
var x144 uint64
x143, x144 = bits.Add64(x115, x133, uint64(0x0))
var x145 uint64
var x146 uint64
x145, x146 = bits.Add64(x117, x135, uint64(p384Uint1(x144)))
var x147 uint64
var x148 uint64
x147, x148 = bits.Add64(x119, x137, uint64(p384Uint1(x146)))
var x149 uint64
var x150 uint64
x149, x150 = bits.Add64(x121, x139, uint64(p384Uint1(x148)))
var x151 uint64
var x152 uint64
x151, x152 = bits.Add64(x123, x141, uint64(p384Uint1(x150)))
var x153 uint64
var x154 uint64
x153, x154 = bits.Add64(x125, uint64(p384Uint1(x142)), uint64(p384Uint1(x152)))
var x155 uint64
_, x155 = bits.Mul64(x143, 0x100000001)
var x157 uint64
var x158 uint64
x158, x157 = bits.Mul64(x155, 0xffffffffffffffff)
var x159 uint64
var x160 uint64
x160, x159 = bits.Mul64(x155, 0xffffffffffffffff)
var x161 uint64
var x162 uint64
x162, x161 = bits.Mul64(x155, 0xffffffffffffffff)
var x163 uint64
var x164 uint64
x164, x163 = bits.Mul64(x155, 0xfffffffffffffffe)
var x165 uint64
var x166 uint64
x166, x165 = bits.Mul64(x155, 0xffffffff00000000)
var x167 uint64
var x168 uint64
x168, x167 = bits.Mul64(x155, 0xffffffff)
var x169 uint64
var x170 uint64
x169, x170 = bits.Add64(x168, x165, uint64(0x0))
var x171 uint64
var x172 uint64
x171, x172 = bits.Add64(x166, x163, uint64(p384Uint1(x170)))
var x173 uint64
var x174 uint64
x173, x174 = bits.Add64(x164, x161, uint64(p384Uint1(x172)))
var x175 uint64
var x176 uint64
x175, x176 = bits.Add64(x162, x159, uint64(p384Uint1(x174)))
var x177 uint64
var x178 uint64
x177, x178 = bits.Add64(x160, x157, uint64(p384Uint1(x176)))
var x180 uint64
_, x180 = bits.Add64(x143, x167, uint64(0x0))
var x181 uint64
var x182 uint64
x181, x182 = bits.Add64(x145, x169, uint64(p384Uint1(x180)))
var x183 uint64
var x184 uint64
x183, x184 = bits.Add64(x147, x171, uint64(p384Uint1(x182)))
var x185 uint64
var x186 uint64
x185, x186 = bits.Add64(x149, x173, uint64(p384Uint1(x184)))
var x187 uint64
var x188 uint64
x187, x188 = bits.Add64(x151, x175, uint64(p384Uint1(x186)))
var x189 uint64
var x190 uint64
x189, x190 = bits.Add64(x153, x177, uint64(p384Uint1(x188)))
var x191 uint64
var x192 uint64
x191, x192 = bits.Add64((uint64(p384Uint1(x154)) + uint64(p384Uint1(x126))), (uint64(p384Uint1(x178)) + x158), uint64(p384Uint1(x190)))
var x193 uint64
var x194 uint64
x194, x193 = bits.Mul64(x3, 0x200000000)
var x195 uint64
var x196 uint64
x196, x195 = bits.Mul64(x3, 0xfffffffe00000000)
var x197 uint64
var x198 uint64
x198, x197 = bits.Mul64(x3, 0x200000000)
var x199 uint64
var x200 uint64
x200, x199 = bits.Mul64(x3, 0xfffffffe00000001)
var x201 uint64
var x202 uint64
x201, x202 = bits.Add64(x200, x197, uint64(0x0))
var x203 uint64
var x204 uint64
x203, x204 = bits.Add64(x198, x195, uint64(p384Uint1(x202)))
var x205 uint64
var x206 uint64
x205, x206 = bits.Add64(x196, x193, uint64(p384Uint1(x204)))
var x207 uint64
var x208 uint64
x207, x208 = bits.Add64(x194, x3, uint64(p384Uint1(x206)))
var x209 uint64
var x210 uint64
x209, x210 = bits.Add64(x181, x199, uint64(0x0))
var x211 uint64
var x212 uint64
x211, x212 = bits.Add64(x183, x201, uint64(p384Uint1(x210)))
var x213 uint64
var x214 uint64
x213, x214 = bits.Add64(x185, x203, uint64(p384Uint1(x212)))
var x215 uint64
var x216 uint64
x215, x216 = bits.Add64(x187, x205, uint64(p384Uint1(x214)))
var x217 uint64
var x218 uint64
x217, x218 = bits.Add64(x189, x207, uint64(p384Uint1(x216)))
var x219 uint64
var x220 uint64
x219, x220 = bits.Add64(x191, uint64(p384Uint1(x208)), uint64(p384Uint1(x218)))
var x221 uint64
_, x221 = bits.Mul64(x209, 0x100000001)
var x223 uint64
var x224 uint64
x224, x223 = bits.Mul64(x221, 0xffffffffffffffff)
var x225 uint64
var x226 uint64
x226, x225 = bits.Mul64(x221, 0xffffffffffffffff)
var x227 uint64
var x228 uint64
x228, x227 = bits.Mul64(x221, 0xffffffffffffffff)
var x229 uint64
var x230 uint64
x230, x229 = bits.Mul64(x221, 0xfffffffffffffffe)
var x231 uint64
var x232 uint64
x232, x231 = bits.Mul64(x221, 0xffffffff00000000)
var x233 uint64
var x234 uint64
x234, x233 = bits.Mul64(x221, 0xffffffff)
var x235 uint64
var x236 uint64
x235, x236 = bits.Add64(x234, x231, uint64(0x0))
var x237 uint64
var x238 uint64
x237, x238 = bits.Add64(x232, x229, uint64(p384Uint1(x236)))
var x239 uint64
var x240 uint64
x239, x240 = bits.Add64(x230, x227, uint64(p384Uint1(x238)))
var x241 uint64
var x242 uint64
x241, x242 = bits.Add64(x228, x225, uint64(p384Uint1(x240)))
var x243 uint64
var x244 uint64
x243, x244 = bits.Add64(x226, x223, uint64(p384Uint1(x242)))
var x246 uint64
_, x246 = bits.Add64(x209, x233, uint64(0x0))
var x247 uint64
var x248 uint64
x247, x248 = bits.Add64(x211, x235, uint64(p384Uint1(x246)))
var x249 uint64
var x250 uint64
x249, x250 = bits.Add64(x213, x237, uint64(p384Uint1(x248)))
var x251 uint64
var x252 uint64
x251, x252 = bits.Add64(x215, x239, uint64(p384Uint1(x250)))
var x253 uint64
var x254 uint64
x253, x254 = bits.Add64(x217, x241, uint64(p384Uint1(x252)))
var x255 uint64
var x256 uint64
x255, x256 = bits.Add64(x219, x243, uint64(p384Uint1(x254)))
var x257 uint64
var x258 uint64
x257, x258 = bits.Add64((uint64(p384Uint1(x220)) + uint64(p384Uint1(x192))), (uint64(p384Uint1(x244)) + x224), uint64(p384Uint1(x256)))
var x259 uint64
var x260 uint64
x260, x259 = bits.Mul64(x4, 0x200000000)
var x261 uint64
var x262 uint64
x262, x261 = bits.Mul64(x4, 0xfffffffe00000000)
var x263 uint64
var x264 uint64
x264, x263 = bits.Mul64(x4, 0x200000000)
var x265 uint64
var x266 uint64
x266, x265 = bits.Mul64(x4, 0xfffffffe00000001)
var x267 uint64
var x268 uint64
x267, x268 = bits.Add64(x266, x263, uint64(0x0))
var x269 uint64
var x270 uint64
x269, x270 = bits.Add64(x264, x261, uint64(p384Uint1(x268)))
var x271 uint64
var x272 uint64
x271, x272 = bits.Add64(x262, x259, uint64(p384Uint1(x270)))
var x273 uint64
var x274 uint64
x273, x274 = bits.Add64(x260, x4, uint64(p384Uint1(x272)))
var x275 uint64
var x276 uint64
x275, x276 = bits.Add64(x247, x265, uint64(0x0))
var x277 uint64
var x278 uint64
x277, x278 = bits.Add64(x249, x267, uint64(p384Uint1(x276)))
var x279 uint64
var x280 uint64
x279, x280 = bits.Add64(x251, x269, uint64(p384Uint1(x278)))
var x281 uint64
var x282 uint64
x281, x282 = bits.Add64(x253, x271, uint64(p384Uint1(x280)))
var x283 uint64
var x284 uint64
x283, x284 = bits.Add64(x255, x273, uint64(p384Uint1(x282)))
var x285 uint64
var x286 uint64
x285, x286 = bits.Add64(x257, uint64(p384Uint1(x274)), uint64(p384Uint1(x284)))
var x287 uint64
_, x287 = bits.Mul64(x275, 0x100000001)
var x289 uint64
var x290 uint64
x290, x289 = bits.Mul64(x287, 0xffffffffffffffff)
var x291 uint64
var x292 uint64
x292, x291 = bits.Mul64(x287, 0xffffffffffffffff)
var x293 uint64
var x294 uint64
x294, x293 = bits.Mul64(x287, 0xffffffffffffffff)
var x295 uint64
var x296 uint64
x296, x295 = bits.Mul64(x287, 0xfffffffffffffffe)
var x297 uint64
var x298 uint64
x298, x297 = bits.Mul64(x287, 0xffffffff00000000)
var x299 uint64
var x300 uint64
x300, x299 = bits.Mul64(x287, 0xffffffff)
var x301 uint64
var x302 uint64
x301, x302 = bits.Add64(x300, x297, uint64(0x0))
var x303 uint64
var x304 uint64
x303, x304 = bits.Add64(x298, x295, uint64(p384Uint1(x302)))
var x305 uint64
var x306 uint64
x305, x306 = bits.Add64(x296, x293, uint64(p384Uint1(x304)))
var x307 uint64
var x308 uint64
x307, x308 = bits.Add64(x294, x291, uint64(p384Uint1(x306)))
var x309 uint64
var x310 uint64
x309, x310 = bits.Add64(x292, x289, uint64(p384Uint1(x308)))
var x312 uint64
_, x312 = bits.Add64(x275, x299, uint64(0x0))
var x313 uint64
var x314 uint64
x313, x314 = bits.Add64(x277, x301, uint64(p384Uint1(x312)))
var x315 uint64
var x316 uint64
x315, x316 = bits.Add64(x279, x303, uint64(p384Uint1(x314)))
var x317 uint64
var x318 uint64
x317, x318 = bits.Add64(x281, x305, uint64(p384Uint1(x316)))
var x319 uint64
var x320 uint64
x319, x320 = bits.Add64(x283, x307, uint64(p384Uint1(x318)))
var x321 uint64
var x322 uint64
x321, x322 = bits.Add64(x285, x309, uint64(p384Uint1(x320)))
var x323 uint64
var x324 uint64
x323, x324 = bits.Add64((uint64(p384Uint1(x286)) + uint64(p384Uint1(x258))), (uint64(p384Uint1(x310)) + x290), uint64(p384Uint1(x322)))
var x325 uint64
var x326 uint64
x326, x325 = bits.Mul64(x5, 0x200000000)
var x327 uint64
var x328 uint64
x328, x327 = bits.Mul64(x5, 0xfffffffe00000000)
var x329 uint64
var x330 uint64
x330, x329 = bits.Mul64(x5, 0x200000000)
var x331 uint64
var x332 uint64
x332, x331 = bits.Mul64(x5, 0xfffffffe00000001)
var x333 uint64
var x334 uint64
x333, x334 = bits.Add64(x332, x329, uint64(0x0))
var x335 uint64
var x336 uint64
x335, x336 = bits.Add64(x330, x327, uint64(p384Uint1(x334)))
var x337 uint64
var x338 uint64
x337, x338 = bits.Add64(x328, x325, uint64(p384Uint1(x336)))
var x339 uint64
var x340 uint64
x339, x340 = bits.Add64(x326, x5, uint64(p384Uint1(x338)))
var x341 uint64
var x342 uint64
x341, x342 = bits.Add64(x313, x331, uint64(0x0))
var x343 uint64
var x344 uint64
x343, x344 = bits.Add64(x315, x333, uint64(p384Uint1(x342)))
var x345 uint64
var x346 uint64
x345, x346 = bits.Add64(x317, x335, uint64(p384Uint1(x344)))
var x347 uint64
var x348 uint64
x347, x348 = bits.Add64(x319, x337, uint64(p384Uint1(x346)))
var x349 uint64
var x350 uint64
x349, x350 = bits.Add64(x321, x339, uint64(p384Uint1(x348)))
var x351 uint64
var x352 uint64
x351, x352 = bits.Add64(x323, uint64(p384Uint1(x340)), uint64(p384Uint1(x350)))
var x353 uint64
_, x353 = bits.Mul64(x341, 0x100000001)
var x355 uint64
var x356 uint64
x356, x355 = bits.Mul64(x353, 0xffffffffffffffff)
var x357 uint64
var x358 uint64
x358, x357 = bits.Mul64(x353, 0xffffffffffffffff)
var x359 uint64
var x360 uint64
x360, x359 = bits.Mul64(x353, 0xffffffffffffffff)
var x361 uint64
var x362 uint64
x362, x361 = bits.Mul64(x353, 0xfffffffffffffffe)
var x363 uint64
var x364 uint64
x364, x363 = bits.Mul64(x353, 0xffffffff00000000)
var x365 uint64
var x366 uint64
x366, x365 = bits.Mul64(x353, 0xffffffff)
var x367 uint64
var x368 uint64
x367, x368 = bits.Add64(x366, x363, uint64(0x0))
var x369 uint64
var x370 uint64
x369, x370 = bits.Add64(x364, x361, uint64(p384Uint1(x368)))
var x371 uint64
var x372 uint64
x371, x372 = bits.Add64(x362, x359, uint64(p384Uint1(x370)))
var x373 uint64
var x374 uint64
x373, x374 = bits.Add64(x360, x357, uint64(p384Uint1(x372)))
var x375 uint64
var x376 uint64
x375, x376 = bits.Add64(x358, x355, uint64(p384Uint1(x374)))
var x378 uint64
_, x378 = bits.Add64(x341, x365, uint64(0x0))
var x379 uint64
var x380 uint64
x379, x380 = bits.Add64(x343, x367, uint64(p384Uint1(x378)))
var x381 uint64
var x382 uint64
x381, x382 = bits.Add64(x345, x369, uint64(p384Uint1(x380)))
var x383 uint64
var x384 uint64
x383, x384 = bits.Add64(x347, x371, uint64(p384Uint1(x382)))
var x385 uint64
var x386 uint64
x385, x386 = bits.Add64(x349, x373, uint64(p384Uint1(x384)))
var x387 uint64
var x388 uint64
x387, x388 = bits.Add64(x351, x375, uint64(p384Uint1(x386)))
var x389 uint64
var x390 uint64
x389, x390 = bits.Add64((uint64(p384Uint1(x352)) + uint64(p384Uint1(x324))), (uint64(p384Uint1(x376)) + x356), uint64(p384Uint1(x388)))
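// Final conditional subtraction: subtract the limbs of the P-384 modulus
// (0xffffffff, 0xffffffff00000000, 0xfffffffffffffffe, then all-ones words)
// and, depending on the final borrow, keep either the difference or the
// original value, bringing the result into the canonical range [0, m).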
var x391 uint64
var x392 uint64
x391, x392 = bits.Sub64(x379, 0xffffffff, uint64(0x0))
var x393 uint64
var x394 uint64
x393, x394 = bits.Sub64(x381, 0xffffffff00000000, uint64(p384Uint1(x392)))
var x395 uint64
var x396 uint64
x395, x396 = bits.Sub64(x383, 0xfffffffffffffffe, uint64(p384Uint1(x394)))
var x397 uint64
var x398 uint64
x397, x398 = bits.Sub64(x385, 0xffffffffffffffff, uint64(p384Uint1(x396)))
var x399 uint64
var x400 uint64
x399, x400 = bits.Sub64(x387, 0xffffffffffffffff, uint64(p384Uint1(x398)))
var x401 uint64
var x402 uint64
x401, x402 = bits.Sub64(x389, 0xffffffffffffffff, uint64(p384Uint1(x400)))
var x404 uint64
_, x404 = bits.Sub64(uint64(p384Uint1(x390)), uint64(0x0), uint64(p384Uint1(x402)))
var x405 uint64
p384CmovznzU64(&x405, p384Uint1(x404), x391, x379)
var x406 uint64
p384CmovznzU64(&x406, p384Uint1(x404), x393, x381)
var x407 uint64
p384CmovznzU64(&x407, p384Uint1(x404), x395, x383)
var x408 uint64
p384CmovznzU64(&x408, p384Uint1(x404), x397, x385)
var x409 uint64
p384CmovznzU64(&x409, p384Uint1(x404), x399, x387)
var x410 uint64
p384CmovznzU64(&x410, p384Uint1(x404), x401, x389)
out1[0] = x405
out1[1] = x406
out1[2] = x407
out1[3] = x408
out1[4] = x409
out1[5] = x410
}
// p384Selectznz is a multi-limb conditional select.
//
// Postconditions:
//
// eval out1 = (if arg1 = 0 then eval arg2 else eval arg3)
//
// Input Bounds:
//
// arg1: [0x0 ~> 0x1]
// arg2: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff]]
// arg3: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff]]
//
// Output Bounds:
//
// out1: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff]]
func p384Selectznz(out1 *[6]uint64, arg1 p384Uint1, arg2 *[6]uint64, arg3 *[6]uint64) {
var x1 uint64
p384CmovznzU64(&x1, arg1, arg2[0], arg3[0])
var x2 uint64
p384CmovznzU64(&x2, arg1, arg2[1], arg3[1])
var x3 uint64
p384CmovznzU64(&x3, arg1, arg2[2], arg3[2])
var x4 uint64
p384CmovznzU64(&x4, arg1, arg2[3], arg3[3])
var x5 uint64
p384CmovznzU64(&x5, arg1, arg2[4], arg3[4])
var x6 uint64
p384CmovznzU64(&x6, arg1, arg2[5], arg3[5])
out1[0] = x1
out1[1] = x2
out1[2] = x3
out1[3] = x4
out1[4] = x5
out1[5] = x6
}
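// Illustrative sketch (not part of the generated code): p384Selectznz picks
// one of two limb vectors without a data-dependent branch. A hypothetical
// helper showing the contract stated in the postcondition above:
func p384SelectznzExample() {
	a := [6]uint64{1, 1, 1, 1, 1, 1}
	b := [6]uint64{2, 2, 2, 2, 2, 2}
	var out [6]uint64
	p384Selectznz(&out, 0, &a, &b) // flag 0 selects arg2: out == a
	p384Selectznz(&out, 1, &a, &b) // flag 1 selects arg3: out == b
}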
// p384ToBytes serializes a field element NOT in the Montgomery domain to bytes in little-endian order.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
//
// Postconditions:
//
// out1 = map (λ x, ⌊((eval arg1 mod m) mod 2^(8 * (x + 1))) / 2^(8 * x)⌋) [0..47]
//
// Input Bounds:
//
// arg1: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff]]
//
// Output Bounds:
//
// out1: [[0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff]]
func p384ToBytes(out1 *[48]uint8, arg1 *[6]uint64) {
x1 := arg1[5]
x2 := arg1[4]
x3 := arg1[3]
x4 := arg1[2]
x5 := arg1[1]
x6 := arg1[0]
x7 := (uint8(x6) & 0xff)
x8 := (x6 >> 8)
x9 := (uint8(x8) & 0xff)
x10 := (x8 >> 8)
x11 := (uint8(x10) & 0xff)
x12 := (x10 >> 8)
x13 := (uint8(x12) & 0xff)
x14 := (x12 >> 8)
x15 := (uint8(x14) & 0xff)
x16 := (x14 >> 8)
x17 := (uint8(x16) & 0xff)
x18 := (x16 >> 8)
x19 := (uint8(x18) & 0xff)
x20 := uint8((x18 >> 8))
x21 := (uint8(x5) & 0xff)
x22 := (x5 >> 8)
x23 := (uint8(x22) & 0xff)
x24 := (x22 >> 8)
x25 := (uint8(x24) & 0xff)
x26 := (x24 >> 8)
x27 := (uint8(x26) & 0xff)
x28 := (x26 >> 8)
x29 := (uint8(x28) & 0xff)
x30 := (x28 >> 8)
x31 := (uint8(x30) & 0xff)
x32 := (x30 >> 8)
x33 := (uint8(x32) & 0xff)
x34 := uint8((x32 >> 8))
x35 := (uint8(x4) & 0xff)
x36 := (x4 >> 8)
x37 := (uint8(x36) & 0xff)
x38 := (x36 >> 8)
x39 := (uint8(x38) & 0xff)
x40 := (x38 >> 8)
x41 := (uint8(x40) & 0xff)
x42 := (x40 >> 8)
x43 := (uint8(x42) & 0xff)
x44 := (x42 >> 8)
x45 := (uint8(x44) & 0xff)
x46 := (x44 >> 8)
x47 := (uint8(x46) & 0xff)
x48 := uint8((x46 >> 8))
x49 := (uint8(x3) & 0xff)
x50 := (x3 >> 8)
x51 := (uint8(x50) & 0xff)
x52 := (x50 >> 8)
x53 := (uint8(x52) & 0xff)
x54 := (x52 >> 8)
x55 := (uint8(x54) & 0xff)
x56 := (x54 >> 8)
x57 := (uint8(x56) & 0xff)
x58 := (x56 >> 8)
x59 := (uint8(x58) & 0xff)
x60 := (x58 >> 8)
x61 := (uint8(x60) & 0xff)
x62 := uint8((x60 >> 8))
x63 := (uint8(x2) & 0xff)
x64 := (x2 >> 8)
x65 := (uint8(x64) & 0xff)
x66 := (x64 >> 8)
x67 := (uint8(x66) & 0xff)
x68 := (x66 >> 8)
x69 := (uint8(x68) & 0xff)
x70 := (x68 >> 8)
x71 := (uint8(x70) & 0xff)
x72 := (x70 >> 8)
x73 := (uint8(x72) & 0xff)
x74 := (x72 >> 8)
x75 := (uint8(x74) & 0xff)
x76 := uint8((x74 >> 8))
x77 := (uint8(x1) & 0xff)
x78 := (x1 >> 8)
x79 := (uint8(x78) & 0xff)
x80 := (x78 >> 8)
x81 := (uint8(x80) & 0xff)
x82 := (x80 >> 8)
x83 := (uint8(x82) & 0xff)
x84 := (x82 >> 8)
x85 := (uint8(x84) & 0xff)
x86 := (x84 >> 8)
x87 := (uint8(x86) & 0xff)
x88 := (x86 >> 8)
x89 := (uint8(x88) & 0xff)
x90 := uint8((x88 >> 8))
out1[0] = x7
out1[1] = x9
out1[2] = x11
out1[3] = x13
out1[4] = x15
out1[5] = x17
out1[6] = x19
out1[7] = x20
out1[8] = x21
out1[9] = x23
out1[10] = x25
out1[11] = x27
out1[12] = x29
out1[13] = x31
out1[14] = x33
out1[15] = x34
out1[16] = x35
out1[17] = x37
out1[18] = x39
out1[19] = x41
out1[20] = x43
out1[21] = x45
out1[22] = x47
out1[23] = x48
out1[24] = x49
out1[25] = x51
out1[26] = x53
out1[27] = x55
out1[28] = x57
out1[29] = x59
out1[30] = x61
out1[31] = x62
out1[32] = x63
out1[33] = x65
out1[34] = x67
out1[35] = x69
out1[36] = x71
out1[37] = x73
out1[38] = x75
out1[39] = x76
out1[40] = x77
out1[41] = x79
out1[42] = x81
out1[43] = x83
out1[44] = x85
out1[45] = x87
out1[46] = x89
out1[47] = x90
}
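// Illustrative sketch (not part of the generated code): the output is
// least-significant byte first. The element wrapper's Bytes method is what
// reverses this into the big-endian wire encoding (compare
// p521InvertEndianness later in this file). A hypothetical byte-order check:
func p384ToBytesExample() {
	limbs := [6]uint64{0x0123456789abcdef, 0, 0, 0, 0, 0}
	var out [48]uint8
	p384ToBytes(&out, &limbs)
	// Little-endian: out[0] == 0xef, out[7] == 0x01, out[8:] are all zero.
}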
// p384FromBytes deserializes a field element NOT in the Montgomery domain from bytes in little-endian order.
//
// Preconditions:
//
// 0 ≤ bytes_eval arg1 < m
//
// Postconditions:
//
// eval out1 mod m = bytes_eval arg1 mod m
// 0 ≤ eval out1 < m
//
// Input Bounds:
//
// arg1: [[0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff]]
//
// Output Bounds:
//
// out1: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff]]
func p384FromBytes(out1 *[6]uint64, arg1 *[48]uint8) {
x1 := (uint64(arg1[47]) << 56)
x2 := (uint64(arg1[46]) << 48)
x3 := (uint64(arg1[45]) << 40)
x4 := (uint64(arg1[44]) << 32)
x5 := (uint64(arg1[43]) << 24)
x6 := (uint64(arg1[42]) << 16)
x7 := (uint64(arg1[41]) << 8)
x8 := arg1[40]
x9 := (uint64(arg1[39]) << 56)
x10 := (uint64(arg1[38]) << 48)
x11 := (uint64(arg1[37]) << 40)
x12 := (uint64(arg1[36]) << 32)
x13 := (uint64(arg1[35]) << 24)
x14 := (uint64(arg1[34]) << 16)
x15 := (uint64(arg1[33]) << 8)
x16 := arg1[32]
x17 := (uint64(arg1[31]) << 56)
x18 := (uint64(arg1[30]) << 48)
x19 := (uint64(arg1[29]) << 40)
x20 := (uint64(arg1[28]) << 32)
x21 := (uint64(arg1[27]) << 24)
x22 := (uint64(arg1[26]) << 16)
x23 := (uint64(arg1[25]) << 8)
x24 := arg1[24]
x25 := (uint64(arg1[23]) << 56)
x26 := (uint64(arg1[22]) << 48)
x27 := (uint64(arg1[21]) << 40)
x28 := (uint64(arg1[20]) << 32)
x29 := (uint64(arg1[19]) << 24)
x30 := (uint64(arg1[18]) << 16)
x31 := (uint64(arg1[17]) << 8)
x32 := arg1[16]
x33 := (uint64(arg1[15]) << 56)
x34 := (uint64(arg1[14]) << 48)
x35 := (uint64(arg1[13]) << 40)
x36 := (uint64(arg1[12]) << 32)
x37 := (uint64(arg1[11]) << 24)
x38 := (uint64(arg1[10]) << 16)
x39 := (uint64(arg1[9]) << 8)
x40 := arg1[8]
x41 := (uint64(arg1[7]) << 56)
x42 := (uint64(arg1[6]) << 48)
x43 := (uint64(arg1[5]) << 40)
x44 := (uint64(arg1[4]) << 32)
x45 := (uint64(arg1[3]) << 24)
x46 := (uint64(arg1[2]) << 16)
x47 := (uint64(arg1[1]) << 8)
x48 := arg1[0]
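// Each shifted byte above occupies a disjoint bit range of its 64-bit limb,
// so the plain additions below can never carry; they act as bitwise ORs
// reassembling each limb from its eight bytes.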
x49 := (x47 + uint64(x48))
x50 := (x46 + x49)
x51 := (x45 + x50)
x52 := (x44 + x51)
x53 := (x43 + x52)
x54 := (x42 + x53)
x55 := (x41 + x54)
x56 := (x39 + uint64(x40))
x57 := (x38 + x56)
x58 := (x37 + x57)
x59 := (x36 + x58)
x60 := (x35 + x59)
x61 := (x34 + x60)
x62 := (x33 + x61)
x63 := (x31 + uint64(x32))
x64 := (x30 + x63)
x65 := (x29 + x64)
x66 := (x28 + x65)
x67 := (x27 + x66)
x68 := (x26 + x67)
x69 := (x25 + x68)
x70 := (x23 + uint64(x24))
x71 := (x22 + x70)
x72 := (x21 + x71)
x73 := (x20 + x72)
x74 := (x19 + x73)
x75 := (x18 + x74)
x76 := (x17 + x75)
x77 := (x15 + uint64(x16))
x78 := (x14 + x77)
x79 := (x13 + x78)
x80 := (x12 + x79)
x81 := (x11 + x80)
x82 := (x10 + x81)
x83 := (x9 + x82)
x84 := (x7 + uint64(x8))
x85 := (x6 + x84)
x86 := (x5 + x85)
x87 := (x4 + x86)
x88 := (x3 + x87)
x89 := (x2 + x88)
x90 := (x1 + x89)
out1[0] = x55
out1[1] = x62
out1[2] = x69
out1[3] = x76
out1[4] = x83
out1[5] = x90
}
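// Illustrative sketch (not part of the generated code): p384ToBytes and
// p384FromBytes are exact inverses for inputs satisfying the preconditions
// above (eval arg1 < m). A hypothetical round-trip check:
func p384BytesRoundTrip(limbs *[6]uint64) bool {
	var buf [48]uint8
	p384ToBytes(&buf, limbs)
	var back [6]uint64
	p384FromBytes(&back, &buf)
	return back == *limbs
}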
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Code generated by addchain. DO NOT EDIT.
package fiat
// Invert sets e = 1/x, and returns e.
//
// If x == 0, Invert returns e = 0.
func (e *P384Element) Invert(x *P384Element) *P384Element {
// Inversion is implemented as exponentiation with exponent p − 2.
// The sequence of 15 multiplications and 383 squarings is derived from the
// following addition chain generated with github.com/mmcloughlin/addchain v0.4.0.
//
// _10 = 2*1
// _11 = 1 + _10
// _110 = 2*_11
// _111 = 1 + _110
// _111000 = _111 << 3
// _111111 = _111 + _111000
// x12 = _111111 << 6 + _111111
// x24 = x12 << 12 + x12
// x30 = x24 << 6 + _111111
// x31 = 2*x30 + 1
// x32 = 2*x31 + 1
// x63 = x32 << 31 + x31
// x126 = x63 << 63 + x63
// x252 = x126 << 126 + x126
// x255 = x252 << 3 + _111
// i397 = ((x255 << 33 + x32) << 94 + x30) << 2
// return 1 + i397
//
var z = new(P384Element).Set(e)
var t0 = new(P384Element)
var t1 = new(P384Element)
var t2 = new(P384Element)
var t3 = new(P384Element)
z.Square(x)
z.Mul(x, z)
z.Square(z)
t1.Mul(x, z)
z.Square(t1)
for s := 1; s < 3; s++ {
z.Square(z)
}
z.Mul(t1, z)
t0.Square(z)
for s := 1; s < 6; s++ {
t0.Square(t0)
}
t0.Mul(z, t0)
t2.Square(t0)
for s := 1; s < 12; s++ {
t2.Square(t2)
}
t0.Mul(t0, t2)
for s := 0; s < 6; s++ {
t0.Square(t0)
}
z.Mul(z, t0)
t0.Square(z)
t2.Mul(x, t0)
t0.Square(t2)
t0.Mul(x, t0)
t3.Square(t0)
for s := 1; s < 31; s++ {
t3.Square(t3)
}
t2.Mul(t2, t3)
t3.Square(t2)
for s := 1; s < 63; s++ {
t3.Square(t3)
}
t2.Mul(t2, t3)
t3.Square(t2)
for s := 1; s < 126; s++ {
t3.Square(t3)
}
t2.Mul(t2, t3)
for s := 0; s < 3; s++ {
t2.Square(t2)
}
t1.Mul(t1, t2)
for s := 0; s < 33; s++ {
t1.Square(t1)
}
t0.Mul(t0, t1)
for s := 0; s < 94; s++ {
t0.Square(t0)
}
z.Mul(z, t0)
for s := 0; s < 2; s++ {
z.Square(z)
}
z.Mul(x, z)
return e.Set(z)
}
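// Illustrative sketch (not part of the generated code): since Invert
// computes x^(p-2) and p is prime, x * Invert(x) == 1 for any nonzero x.
// This check assumes the P384Element wrapper exposes the same One, Mul,
// and Equal methods as the P521Element wrapper shown later in this file:
func p384InvertCheck(x *P384Element) bool {
	inv := new(P384Element).Invert(x)
	prod := new(P384Element).Mul(x, inv)
	return prod.Equal(new(P384Element).One()) == 1
}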
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Code generated by generate.go. DO NOT EDIT.
package fiat
import (
"crypto/subtle"
"errors"
)
// P521Element is an integer modulo 2^521 - 1.
//
// The zero value is a valid zero element.
type P521Element struct {
// Values are represented internally always in the Montgomery domain, and
// converted in Bytes and SetBytes.
x p521MontgomeryDomainFieldElement
}
const p521ElementLen = 66
type p521UntypedFieldElement = [9]uint64
// One sets e = 1, and returns e.
func (e *P521Element) One() *P521Element {
p521SetOne(&e.x)
return e
}
// Equal returns 1 if e == t, and zero otherwise.
func (e *P521Element) Equal(t *P521Element) int {
eBytes := e.Bytes()
tBytes := t.Bytes()
return subtle.ConstantTimeCompare(eBytes, tBytes)
}
// IsZero returns 1 if e == 0, and zero otherwise.
func (e *P521Element) IsZero() int {
zero := make([]byte, p521ElementLen)
eBytes := e.Bytes()
return subtle.ConstantTimeCompare(eBytes, zero)
}
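// Illustrative sketch (not part of the generated code): Equal and IsZero
// return int rather than bool so results compose with other constant-time
// checks without branching, e.g.:
func p521EitherZero(a, b *P521Element) int {
	return a.IsZero() | b.IsZero() // 1 if either operand is zero
}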
// Set sets e = t, and returns e.
func (e *P521Element) Set(t *P521Element) *P521Element {
e.x = t.x
return e
}
// Bytes returns the 66-byte big-endian encoding of e.
func (e *P521Element) Bytes() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var out [p521ElementLen]byte
return e.bytes(&out)
}
func (e *P521Element) bytes(out *[p521ElementLen]byte) []byte {
var tmp p521NonMontgomeryDomainFieldElement
p521FromMontgomery(&tmp, &e.x)
p521ToBytes(out, (*p521UntypedFieldElement)(&tmp))
p521InvertEndianness(out[:])
return out[:]
}
// SetBytes sets e = v, where v is a big-endian 66-byte encoding, and returns e.
// If v is not 66 bytes or it encodes a non-canonical value (2^521 - 1, the
// modulus, or higher), SetBytes returns nil and an error, and e is unchanged.
func (e *P521Element) SetBytes(v []byte) (*P521Element, error) {
if len(v) != p521ElementLen {
return nil, errors.New("invalid P521Element encoding")
}
// Check for non-canonical encodings (p + k, 2p + k, etc.) by comparing to
// the encoding of -1 mod p, so p - 1, the highest canonical encoding.
var minusOneEncoding = new(P521Element).Sub(
new(P521Element), new(P521Element).One()).Bytes()
for i := range v {
if v[i] < minusOneEncoding[i] {
break
}
if v[i] > minusOneEncoding[i] {
return nil, errors.New("invalid P521Element encoding")
}
}
var in [p521ElementLen]byte
copy(in[:], v)
p521InvertEndianness(in[:])
var tmp p521NonMontgomeryDomainFieldElement
p521FromBytes((*p521UntypedFieldElement)(&tmp), &in)
p521ToMontgomery(&e.x, &tmp)
return e, nil
}
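// Illustrative sketch (not part of the generated code): SetBytes and Bytes
// round-trip canonical encodings, and callers must handle the error for
// wrong lengths or non-canonical values. A hypothetical decode-and-reencode:
func p521Roundtrip(buf []byte) ([]byte, error) {
	e, err := new(P521Element).SetBytes(buf)
	if err != nil {
		return nil, err // not 66 bytes, or >= 2^521 - 1
	}
	return e.Bytes(), nil // equals buf for canonical inputs
}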
// Add sets e = t1 + t2, and returns e.
func (e *P521Element) Add(t1, t2 *P521Element) *P521Element {
p521Add(&e.x, &t1.x, &t2.x)
return e
}
// Sub sets e = t1 - t2, and returns e.
func (e *P521Element) Sub(t1, t2 *P521Element) *P521Element {
p521Sub(&e.x, &t1.x, &t2.x)
return e
}
// Mul sets e = t1 * t2, and returns e.
func (e *P521Element) Mul(t1, t2 *P521Element) *P521Element {
p521Mul(&e.x, &t1.x, &t2.x)
return e
}
// Square sets e = t * t, and returns e.
func (e *P521Element) Square(t *P521Element) *P521Element {
p521Square(&e.x, &t.x)
return e
}
// Select sets v to a if cond == 1, and to b if cond == 0.
func (v *P521Element) Select(a, b *P521Element, cond int) *P521Element {
p521Selectznz((*p521UntypedFieldElement)(&v.x), p521Uint1(cond),
(*p521UntypedFieldElement)(&b.x), (*p521UntypedFieldElement)(&a.x))
return v
}
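// Illustrative sketch (not part of the generated code): Select composes
// with IsZero for constant-time fallback logic, here picking fallback
// exactly when v is zero (names are hypothetical):
func p521OrElse(v, fallback *P521Element) *P521Element {
	out := new(P521Element)
	return out.Select(fallback, v, v.IsZero())
}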
func p521InvertEndianness(v []byte) {
for i := 0; i < len(v)/2; i++ {
v[i], v[len(v)-1-i] = v[len(v)-1-i], v[i]
}
}
// Code generated by Fiat Cryptography. DO NOT EDIT.
//
// Autogenerated: word_by_word_montgomery --lang Go --no-wide-int --cmovznz-by-mul --relax-primitive-carry-to-bitwidth 32,64 --internal-static --public-function-case camelCase --public-type-case camelCase --private-function-case camelCase --private-type-case camelCase --doc-text-before-function-name '' --doc-newline-before-package-declaration --doc-prepend-header 'Code generated by Fiat Cryptography. DO NOT EDIT.' --package-name fiat --no-prefix-fiat p521 64 '2^521 - 1' mul square add sub one from_montgomery to_montgomery selectznz to_bytes from_bytes
//
// curve description: p521
//
// machine_wordsize = 64 (from "64")
//
// requested operations: mul, square, add, sub, one, from_montgomery, to_montgomery, selectznz, to_bytes, from_bytes
//
// m = 0x1ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff (from "2^521 - 1")
//
//
//
// NOTE: In addition to the bounds specified above each function, all
// functions synthesized for this Montgomery arithmetic require the
// input to be strictly less than the prime modulus (m), and also
// require the input to be in the unique saturated representation.
// All functions also ensure that these two properties are true of
// return values.
//
//
//
// Computed values:
//
// eval z = z[0] + (z[1] << 64) + (z[2] << 128) + (z[3] << 192) + (z[4] << 256) + (z[5] << 0x140) + (z[6] << 0x180) + (z[7] << 0x1c0) + (z[8] << 2^9)
//
// bytes_eval z = z[0] + (z[1] << 8) + (z[2] << 16) + (z[3] << 24) + (z[4] << 32) + (z[5] << 40) + (z[6] << 48) + (z[7] << 56) + (z[8] << 64) + (z[9] << 72) + (z[10] << 80) + (z[11] << 88) + (z[12] << 96) + (z[13] << 104) + (z[14] << 112) + (z[15] << 120) + (z[16] << 128) + (z[17] << 136) + (z[18] << 144) + (z[19] << 152) + (z[20] << 160) + (z[21] << 168) + (z[22] << 176) + (z[23] << 184) + (z[24] << 192) + (z[25] << 200) + (z[26] << 208) + (z[27] << 216) + (z[28] << 224) + (z[29] << 232) + (z[30] << 240) + (z[31] << 248) + (z[32] << 256) + (z[33] << 0x108) + (z[34] << 0x110) + (z[35] << 0x118) + (z[36] << 0x120) + (z[37] << 0x128) + (z[38] << 0x130) + (z[39] << 0x138) + (z[40] << 0x140) + (z[41] << 0x148) + (z[42] << 0x150) + (z[43] << 0x158) + (z[44] << 0x160) + (z[45] << 0x168) + (z[46] << 0x170) + (z[47] << 0x178) + (z[48] << 0x180) + (z[49] << 0x188) + (z[50] << 0x190) + (z[51] << 0x198) + (z[52] << 0x1a0) + (z[53] << 0x1a8) + (z[54] << 0x1b0) + (z[55] << 0x1b8) + (z[56] << 0x1c0) + (z[57] << 0x1c8) + (z[58] << 0x1d0) + (z[59] << 0x1d8) + (z[60] << 0x1e0) + (z[61] << 0x1e8) + (z[62] << 0x1f0) + (z[63] << 0x1f8) + (z[64] << 2^9) + (z[65] << 0x208)
//
// twos_complement_eval z = let x1 := z[0] + (z[1] << 64) + (z[2] << 128) + (z[3] << 192) + (z[4] << 256) + (z[5] << 0x140) + (z[6] << 0x180) + (z[7] << 0x1c0) + (z[8] << 2^9) in
//
// if x1 & (2^576-1) < 2^575 then x1 & (2^576-1) else (x1 & (2^576-1)) - 2^576
package fiat
import "math/bits"
type p521Uint1 uint64 // We use uint64 instead of a more narrow type for performance reasons; see https://github.com/mit-plv/fiat-crypto/pull/1006#issuecomment-892625927
type p521Int1 int64 // We use int64 instead of a more narrow type for performance reasons; see https://github.com/mit-plv/fiat-crypto/pull/1006#issuecomment-892625927
// The type p521MontgomeryDomainFieldElement is a field element in the Montgomery domain.
//
// Bounds: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff]]
type p521MontgomeryDomainFieldElement [9]uint64
// The type p521NonMontgomeryDomainFieldElement is a field element NOT in the Montgomery domain.
//
// Bounds: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff]]
type p521NonMontgomeryDomainFieldElement [9]uint64
// p521CmovznzU64 is a single-word conditional move.
//
// Postconditions:
//
// out1 = (if arg1 = 0 then arg2 else arg3)
//
// Input Bounds:
//
// arg1: [0x0 ~> 0x1]
// arg2: [0x0 ~> 0xffffffffffffffff]
// arg3: [0x0 ~> 0xffffffffffffffff]
//
// Output Bounds:
//
// out1: [0x0 ~> 0xffffffffffffffff]
func p521CmovznzU64(out1 *uint64, arg1 p521Uint1, arg2 uint64, arg3 uint64) {
x1 := (uint64(arg1) * 0xffffffffffffffff)
x2 := ((x1 & arg3) | ((^x1) & arg2))
*out1 = x2
}
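// Illustrative sketch (not part of the generated code): multiplying the 0/1
// flag by 2^64-1 yields an all-zeros or all-ones mask, so the move costs two
// ANDs and an OR with no data-dependent branch:
func p521CmovznzU64Example() {
	var out uint64
	p521CmovznzU64(&out, 0, 7, 9) // flag 0 keeps arg2: out == 7
	p521CmovznzU64(&out, 1, 7, 9) // flag 1 takes arg3: out == 9
}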
// p521Mul multiplies two field elements in the Montgomery domain.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
// 0 ≤ eval arg2 < m
//
// Postconditions:
//
// eval (from_montgomery out1) mod m = (eval (from_montgomery arg1) * eval (from_montgomery arg2)) mod m
// 0 ≤ eval out1 < m
func p521Mul(out1 *p521MontgomeryDomainFieldElement, arg1 *p521MontgomeryDomainFieldElement, arg2 *p521MontgomeryDomainFieldElement) {
x1 := arg1[1]
x2 := arg1[2]
x3 := arg1[3]
x4 := arg1[4]
x5 := arg1[5]
x6 := arg1[6]
x7 := arg1[7]
x8 := arg1[8]
x9 := arg1[0]
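// Word-by-word Montgomery multiplication: each round below multiplies one
// limb of arg1 (starting with the low limb x9) by all of arg2, adds the
// partial product into the accumulator, and then folds in one Montgomery
// reduction step, keeping the intermediate result to roughly nine words.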
var x10 uint64
var x11 uint64
x11, x10 = bits.Mul64(x9, arg2[8])
var x12 uint64
var x13 uint64
x13, x12 = bits.Mul64(x9, arg2[7])
var x14 uint64
var x15 uint64
x15, x14 = bits.Mul64(x9, arg2[6])
var x16 uint64
var x17 uint64
x17, x16 = bits.Mul64(x9, arg2[5])
var x18 uint64
var x19 uint64
x19, x18 = bits.Mul64(x9, arg2[4])
var x20 uint64
var x21 uint64
x21, x20 = bits.Mul64(x9, arg2[3])
var x22 uint64
var x23 uint64
x23, x22 = bits.Mul64(x9, arg2[2])
var x24 uint64
var x25 uint64
x25, x24 = bits.Mul64(x9, arg2[1])
var x26 uint64
var x27 uint64
x27, x26 = bits.Mul64(x9, arg2[0])
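// x26 is the low word of x9*arg2[0]. Since m = 2^521 - 1 ≡ -1 (mod 2^64),
// the Montgomery constant -m^-1 mod 2^64 equals 1, so x26 itself serves as
// this round's reduction multiplier (see the x26*limb products below).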
var x28 uint64
var x29 uint64
x28, x29 = bits.Add64(x27, x24, uint64(0x0))
var x30 uint64
var x31 uint64
x30, x31 = bits.Add64(x25, x22, uint64(p521Uint1(x29)))
var x32 uint64
var x33 uint64
x32, x33 = bits.Add64(x23, x20, uint64(p521Uint1(x31)))
var x34 uint64
var x35 uint64
x34, x35 = bits.Add64(x21, x18, uint64(p521Uint1(x33)))
var x36 uint64
var x37 uint64
x36, x37 = bits.Add64(x19, x16, uint64(p521Uint1(x35)))
var x38 uint64
var x39 uint64
x38, x39 = bits.Add64(x17, x14, uint64(p521Uint1(x37)))
var x40 uint64
var x41 uint64
x40, x41 = bits.Add64(x15, x12, uint64(p521Uint1(x39)))
var x42 uint64
var x43 uint64
x42, x43 = bits.Add64(x13, x10, uint64(p521Uint1(x41)))
x44 := (uint64(p521Uint1(x43)) + x11)
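// Reduction step: x26 times the limbs of m = 2^521 - 1, whose little-endian
// limbs are eight words of 0xffffffffffffffff followed by 0x1ff (2^9 - 1).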
var x45 uint64
var x46 uint64
x46, x45 = bits.Mul64(x26, 0x1ff)
var x47 uint64
var x48 uint64
x48, x47 = bits.Mul64(x26, 0xffffffffffffffff)
var x49 uint64
var x50 uint64
x50, x49 = bits.Mul64(x26, 0xffffffffffffffff)
var x51 uint64
var x52 uint64
x52, x51 = bits.Mul64(x26, 0xffffffffffffffff)
var x53 uint64
var x54 uint64
x54, x53 = bits.Mul64(x26, 0xffffffffffffffff)
var x55 uint64
var x56 uint64
x56, x55 = bits.Mul64(x26, 0xffffffffffffffff)
var x57 uint64
var x58 uint64
x58, x57 = bits.Mul64(x26, 0xffffffffffffffff)
var x59 uint64
var x60 uint64
x60, x59 = bits.Mul64(x26, 0xffffffffffffffff)
var x61 uint64
var x62 uint64
x62, x61 = bits.Mul64(x26, 0xffffffffffffffff)
var x63 uint64
var x64 uint64
x63, x64 = bits.Add64(x62, x59, uint64(0x0))
var x65 uint64
var x66 uint64
x65, x66 = bits.Add64(x60, x57, uint64(p521Uint1(x64)))
var x67 uint64
var x68 uint64
x67, x68 = bits.Add64(x58, x55, uint64(p521Uint1(x66)))
var x69 uint64
var x70 uint64
x69, x70 = bits.Add64(x56, x53, uint64(p521Uint1(x68)))
var x71 uint64
var x72 uint64
x71, x72 = bits.Add64(x54, x51, uint64(p521Uint1(x70)))
var x73 uint64
var x74 uint64
x73, x74 = bits.Add64(x52, x49, uint64(p521Uint1(x72)))
var x75 uint64
var x76 uint64
x75, x76 = bits.Add64(x50, x47, uint64(p521Uint1(x74)))
var x77 uint64
var x78 uint64
x77, x78 = bits.Add64(x48, x45, uint64(p521Uint1(x76)))
x79 := (uint64(p521Uint1(x78)) + x46)
var x81 uint64
_, x81 = bits.Add64(x26, x61, uint64(0x0))
var x82 uint64
var x83 uint64
x82, x83 = bits.Add64(x28, x63, uint64(p521Uint1(x81)))
var x84 uint64
var x85 uint64
x84, x85 = bits.Add64(x30, x65, uint64(p521Uint1(x83)))
var x86 uint64
var x87 uint64
x86, x87 = bits.Add64(x32, x67, uint64(p521Uint1(x85)))
var x88 uint64
var x89 uint64
x88, x89 = bits.Add64(x34, x69, uint64(p521Uint1(x87)))
var x90 uint64
var x91 uint64
x90, x91 = bits.Add64(x36, x71, uint64(p521Uint1(x89)))
var x92 uint64
var x93 uint64
x92, x93 = bits.Add64(x38, x73, uint64(p521Uint1(x91)))
var x94 uint64
var x95 uint64
x94, x95 = bits.Add64(x40, x75, uint64(p521Uint1(x93)))
var x96 uint64
var x97 uint64
x96, x97 = bits.Add64(x42, x77, uint64(p521Uint1(x95)))
var x98 uint64
var x99 uint64
x98, x99 = bits.Add64(x44, x79, uint64(p521Uint1(x97)))
var x100 uint64
var x101 uint64
x101, x100 = bits.Mul64(x1, arg2[8])
var x102 uint64
var x103 uint64
x103, x102 = bits.Mul64(x1, arg2[7])
var x104 uint64
var x105 uint64
x105, x104 = bits.Mul64(x1, arg2[6])
var x106 uint64
var x107 uint64
x107, x106 = bits.Mul64(x1, arg2[5])
var x108 uint64
var x109 uint64
x109, x108 = bits.Mul64(x1, arg2[4])
var x110 uint64
var x111 uint64
x111, x110 = bits.Mul64(x1, arg2[3])
var x112 uint64
var x113 uint64
x113, x112 = bits.Mul64(x1, arg2[2])
var x114 uint64
var x115 uint64
x115, x114 = bits.Mul64(x1, arg2[1])
var x116 uint64
var x117 uint64
x117, x116 = bits.Mul64(x1, arg2[0])
var x118 uint64
var x119 uint64
x118, x119 = bits.Add64(x117, x114, uint64(0x0))
var x120 uint64
var x121 uint64
x120, x121 = bits.Add64(x115, x112, uint64(p521Uint1(x119)))
var x122 uint64
var x123 uint64
x122, x123 = bits.Add64(x113, x110, uint64(p521Uint1(x121)))
var x124 uint64
var x125 uint64
x124, x125 = bits.Add64(x111, x108, uint64(p521Uint1(x123)))
var x126 uint64
var x127 uint64
x126, x127 = bits.Add64(x109, x106, uint64(p521Uint1(x125)))
var x128 uint64
var x129 uint64
x128, x129 = bits.Add64(x107, x104, uint64(p521Uint1(x127)))
var x130 uint64
var x131 uint64
x130, x131 = bits.Add64(x105, x102, uint64(p521Uint1(x129)))
var x132 uint64
var x133 uint64
x132, x133 = bits.Add64(x103, x100, uint64(p521Uint1(x131)))
x134 := (uint64(p521Uint1(x133)) + x101)
var x135 uint64
var x136 uint64
x135, x136 = bits.Add64(x82, x116, uint64(0x0))
var x137 uint64
var x138 uint64
x137, x138 = bits.Add64(x84, x118, uint64(p521Uint1(x136)))
var x139 uint64
var x140 uint64
x139, x140 = bits.Add64(x86, x120, uint64(p521Uint1(x138)))
var x141 uint64
var x142 uint64
x141, x142 = bits.Add64(x88, x122, uint64(p521Uint1(x140)))
var x143 uint64
var x144 uint64
x143, x144 = bits.Add64(x90, x124, uint64(p521Uint1(x142)))
var x145 uint64
var x146 uint64
x145, x146 = bits.Add64(x92, x126, uint64(p521Uint1(x144)))
var x147 uint64
var x148 uint64
x147, x148 = bits.Add64(x94, x128, uint64(p521Uint1(x146)))
var x149 uint64
var x150 uint64
x149, x150 = bits.Add64(x96, x130, uint64(p521Uint1(x148)))
var x151 uint64
var x152 uint64
x151, x152 = bits.Add64(x98, x132, uint64(p521Uint1(x150)))
var x153 uint64
var x154 uint64
x153, x154 = bits.Add64(uint64(p521Uint1(x99)), x134, uint64(p521Uint1(x152)))
var x155 uint64
var x156 uint64
x156, x155 = bits.Mul64(x135, 0x1ff)
var x157 uint64
var x158 uint64
x158, x157 = bits.Mul64(x135, 0xffffffffffffffff)
var x159 uint64
var x160 uint64
x160, x159 = bits.Mul64(x135, 0xffffffffffffffff)
var x161 uint64
var x162 uint64
x162, x161 = bits.Mul64(x135, 0xffffffffffffffff)
var x163 uint64
var x164 uint64
x164, x163 = bits.Mul64(x135, 0xffffffffffffffff)
var x165 uint64
var x166 uint64
x166, x165 = bits.Mul64(x135, 0xffffffffffffffff)
var x167 uint64
var x168 uint64
x168, x167 = bits.Mul64(x135, 0xffffffffffffffff)
var x169 uint64
var x170 uint64
x170, x169 = bits.Mul64(x135, 0xffffffffffffffff)
var x171 uint64
var x172 uint64
x172, x171 = bits.Mul64(x135, 0xffffffffffffffff)
var x173 uint64
var x174 uint64
x173, x174 = bits.Add64(x172, x169, uint64(0x0))
var x175 uint64
var x176 uint64
x175, x176 = bits.Add64(x170, x167, uint64(p521Uint1(x174)))
var x177 uint64
var x178 uint64
x177, x178 = bits.Add64(x168, x165, uint64(p521Uint1(x176)))
var x179 uint64
var x180 uint64
x179, x180 = bits.Add64(x166, x163, uint64(p521Uint1(x178)))
var x181 uint64
var x182 uint64
x181, x182 = bits.Add64(x164, x161, uint64(p521Uint1(x180)))
var x183 uint64
var x184 uint64
x183, x184 = bits.Add64(x162, x159, uint64(p521Uint1(x182)))
var x185 uint64
var x186 uint64
x185, x186 = bits.Add64(x160, x157, uint64(p521Uint1(x184)))
var x187 uint64
var x188 uint64
x187, x188 = bits.Add64(x158, x155, uint64(p521Uint1(x186)))
x189 := (uint64(p521Uint1(x188)) + x156)
var x191 uint64
_, x191 = bits.Add64(x135, x171, uint64(0x0))
var x192 uint64
var x193 uint64
x192, x193 = bits.Add64(x137, x173, uint64(p521Uint1(x191)))
var x194 uint64
var x195 uint64
x194, x195 = bits.Add64(x139, x175, uint64(p521Uint1(x193)))
var x196 uint64
var x197 uint64
x196, x197 = bits.Add64(x141, x177, uint64(p521Uint1(x195)))
var x198 uint64
var x199 uint64
x198, x199 = bits.Add64(x143, x179, uint64(p521Uint1(x197)))
var x200 uint64
var x201 uint64
x200, x201 = bits.Add64(x145, x181, uint64(p521Uint1(x199)))
var x202 uint64
var x203 uint64
x202, x203 = bits.Add64(x147, x183, uint64(p521Uint1(x201)))
var x204 uint64
var x205 uint64
x204, x205 = bits.Add64(x149, x185, uint64(p521Uint1(x203)))
var x206 uint64
var x207 uint64
x206, x207 = bits.Add64(x151, x187, uint64(p521Uint1(x205)))
var x208 uint64
var x209 uint64
x208, x209 = bits.Add64(x153, x189, uint64(p521Uint1(x207)))
x210 := (uint64(p521Uint1(x209)) + uint64(p521Uint1(x154)))
var x211 uint64
var x212 uint64
x212, x211 = bits.Mul64(x2, arg2[8])
var x213 uint64
var x214 uint64
x214, x213 = bits.Mul64(x2, arg2[7])
var x215 uint64
var x216 uint64
x216, x215 = bits.Mul64(x2, arg2[6])
var x217 uint64
var x218 uint64
x218, x217 = bits.Mul64(x2, arg2[5])
var x219 uint64
var x220 uint64
x220, x219 = bits.Mul64(x2, arg2[4])
var x221 uint64
var x222 uint64
x222, x221 = bits.Mul64(x2, arg2[3])
var x223 uint64
var x224 uint64
x224, x223 = bits.Mul64(x2, arg2[2])
var x225 uint64
var x226 uint64
x226, x225 = bits.Mul64(x2, arg2[1])
var x227 uint64
var x228 uint64
x228, x227 = bits.Mul64(x2, arg2[0])
var x229 uint64
var x230 uint64
x229, x230 = bits.Add64(x228, x225, uint64(0x0))
var x231 uint64
var x232 uint64
x231, x232 = bits.Add64(x226, x223, uint64(p521Uint1(x230)))
var x233 uint64
var x234 uint64
x233, x234 = bits.Add64(x224, x221, uint64(p521Uint1(x232)))
var x235 uint64
var x236 uint64
x235, x236 = bits.Add64(x222, x219, uint64(p521Uint1(x234)))
var x237 uint64
var x238 uint64
x237, x238 = bits.Add64(x220, x217, uint64(p521Uint1(x236)))
var x239 uint64
var x240 uint64
x239, x240 = bits.Add64(x218, x215, uint64(p521Uint1(x238)))
var x241 uint64
var x242 uint64
x241, x242 = bits.Add64(x216, x213, uint64(p521Uint1(x240)))
var x243 uint64
var x244 uint64
x243, x244 = bits.Add64(x214, x211, uint64(p521Uint1(x242)))
x245 := (uint64(p521Uint1(x244)) + x212)
var x246 uint64
var x247 uint64
x246, x247 = bits.Add64(x192, x227, uint64(0x0))
var x248 uint64
var x249 uint64
x248, x249 = bits.Add64(x194, x229, uint64(p521Uint1(x247)))
var x250 uint64
var x251 uint64
x250, x251 = bits.Add64(x196, x231, uint64(p521Uint1(x249)))
var x252 uint64
var x253 uint64
x252, x253 = bits.Add64(x198, x233, uint64(p521Uint1(x251)))
var x254 uint64
var x255 uint64
x254, x255 = bits.Add64(x200, x235, uint64(p521Uint1(x253)))
var x256 uint64
var x257 uint64
x256, x257 = bits.Add64(x202, x237, uint64(p521Uint1(x255)))
var x258 uint64
var x259 uint64
x258, x259 = bits.Add64(x204, x239, uint64(p521Uint1(x257)))
var x260 uint64
var x261 uint64
x260, x261 = bits.Add64(x206, x241, uint64(p521Uint1(x259)))
var x262 uint64
var x263 uint64
x262, x263 = bits.Add64(x208, x243, uint64(p521Uint1(x261)))
var x264 uint64
var x265 uint64
x264, x265 = bits.Add64(x210, x245, uint64(p521Uint1(x263)))
var x266 uint64
var x267 uint64
x267, x266 = bits.Mul64(x246, 0x1ff)
var x268 uint64
var x269 uint64
x269, x268 = bits.Mul64(x246, 0xffffffffffffffff)
var x270 uint64
var x271 uint64
x271, x270 = bits.Mul64(x246, 0xffffffffffffffff)
var x272 uint64
var x273 uint64
x273, x272 = bits.Mul64(x246, 0xffffffffffffffff)
var x274 uint64
var x275 uint64
x275, x274 = bits.Mul64(x246, 0xffffffffffffffff)
var x276 uint64
var x277 uint64
x277, x276 = bits.Mul64(x246, 0xffffffffffffffff)
var x278 uint64
var x279 uint64
x279, x278 = bits.Mul64(x246, 0xffffffffffffffff)
var x280 uint64
var x281 uint64
x281, x280 = bits.Mul64(x246, 0xffffffffffffffff)
var x282 uint64
var x283 uint64
x283, x282 = bits.Mul64(x246, 0xffffffffffffffff)
var x284 uint64
var x285 uint64
x284, x285 = bits.Add64(x283, x280, uint64(0x0))
var x286 uint64
var x287 uint64
x286, x287 = bits.Add64(x281, x278, uint64(p521Uint1(x285)))
var x288 uint64
var x289 uint64
x288, x289 = bits.Add64(x279, x276, uint64(p521Uint1(x287)))
var x290 uint64
var x291 uint64
x290, x291 = bits.Add64(x277, x274, uint64(p521Uint1(x289)))
var x292 uint64
var x293 uint64
x292, x293 = bits.Add64(x275, x272, uint64(p521Uint1(x291)))
var x294 uint64
var x295 uint64
x294, x295 = bits.Add64(x273, x270, uint64(p521Uint1(x293)))
var x296 uint64
var x297 uint64
x296, x297 = bits.Add64(x271, x268, uint64(p521Uint1(x295)))
var x298 uint64
var x299 uint64
x298, x299 = bits.Add64(x269, x266, uint64(p521Uint1(x297)))
x300 := (uint64(p521Uint1(x299)) + x267)
var x302 uint64
_, x302 = bits.Add64(x246, x282, uint64(0x0))
var x303 uint64
var x304 uint64
x303, x304 = bits.Add64(x248, x284, uint64(p521Uint1(x302)))
var x305 uint64
var x306 uint64
x305, x306 = bits.Add64(x250, x286, uint64(p521Uint1(x304)))
var x307 uint64
var x308 uint64
x307, x308 = bits.Add64(x252, x288, uint64(p521Uint1(x306)))
var x309 uint64
var x310 uint64
x309, x310 = bits.Add64(x254, x290, uint64(p521Uint1(x308)))
var x311 uint64
var x312 uint64
x311, x312 = bits.Add64(x256, x292, uint64(p521Uint1(x310)))
var x313 uint64
var x314 uint64
x313, x314 = bits.Add64(x258, x294, uint64(p521Uint1(x312)))
var x315 uint64
var x316 uint64
x315, x316 = bits.Add64(x260, x296, uint64(p521Uint1(x314)))
var x317 uint64
var x318 uint64
x317, x318 = bits.Add64(x262, x298, uint64(p521Uint1(x316)))
var x319 uint64
var x320 uint64
x319, x320 = bits.Add64(x264, x300, uint64(p521Uint1(x318)))
x321 := (uint64(p521Uint1(x320)) + uint64(p521Uint1(x265)))
var x322 uint64
var x323 uint64
x323, x322 = bits.Mul64(x3, arg2[8])
var x324 uint64
var x325 uint64
x325, x324 = bits.Mul64(x3, arg2[7])
var x326 uint64
var x327 uint64
x327, x326 = bits.Mul64(x3, arg2[6])
var x328 uint64
var x329 uint64
x329, x328 = bits.Mul64(x3, arg2[5])
var x330 uint64
var x331 uint64
x331, x330 = bits.Mul64(x3, arg2[4])
var x332 uint64
var x333 uint64
x333, x332 = bits.Mul64(x3, arg2[3])
var x334 uint64
var x335 uint64
x335, x334 = bits.Mul64(x3, arg2[2])
var x336 uint64
var x337 uint64
x337, x336 = bits.Mul64(x3, arg2[1])
var x338 uint64
var x339 uint64
x339, x338 = bits.Mul64(x3, arg2[0])
var x340 uint64
var x341 uint64
x340, x341 = bits.Add64(x339, x336, uint64(0x0))
var x342 uint64
var x343 uint64
x342, x343 = bits.Add64(x337, x334, uint64(p521Uint1(x341)))
var x344 uint64
var x345 uint64
x344, x345 = bits.Add64(x335, x332, uint64(p521Uint1(x343)))
var x346 uint64
var x347 uint64
x346, x347 = bits.Add64(x333, x330, uint64(p521Uint1(x345)))
var x348 uint64
var x349 uint64
x348, x349 = bits.Add64(x331, x328, uint64(p521Uint1(x347)))
var x350 uint64
var x351 uint64
x350, x351 = bits.Add64(x329, x326, uint64(p521Uint1(x349)))
var x352 uint64
var x353 uint64
x352, x353 = bits.Add64(x327, x324, uint64(p521Uint1(x351)))
var x354 uint64
var x355 uint64
x354, x355 = bits.Add64(x325, x322, uint64(p521Uint1(x353)))
x356 := (uint64(p521Uint1(x355)) + x323)
var x357 uint64
var x358 uint64
x357, x358 = bits.Add64(x303, x338, uint64(0x0))
var x359 uint64
var x360 uint64
x359, x360 = bits.Add64(x305, x340, uint64(p521Uint1(x358)))
var x361 uint64
var x362 uint64
x361, x362 = bits.Add64(x307, x342, uint64(p521Uint1(x360)))
var x363 uint64
var x364 uint64
x363, x364 = bits.Add64(x309, x344, uint64(p521Uint1(x362)))
var x365 uint64
var x366 uint64
x365, x366 = bits.Add64(x311, x346, uint64(p521Uint1(x364)))
var x367 uint64
var x368 uint64
x367, x368 = bits.Add64(x313, x348, uint64(p521Uint1(x366)))
var x369 uint64
var x370 uint64
x369, x370 = bits.Add64(x315, x350, uint64(p521Uint1(x368)))
var x371 uint64
var x372 uint64
x371, x372 = bits.Add64(x317, x352, uint64(p521Uint1(x370)))
var x373 uint64
var x374 uint64
x373, x374 = bits.Add64(x319, x354, uint64(p521Uint1(x372)))
var x375 uint64
var x376 uint64
x375, x376 = bits.Add64(x321, x356, uint64(p521Uint1(x374)))
var x377 uint64
var x378 uint64
x378, x377 = bits.Mul64(x357, 0x1ff)
var x379 uint64
var x380 uint64
x380, x379 = bits.Mul64(x357, 0xffffffffffffffff)
var x381 uint64
var x382 uint64
x382, x381 = bits.Mul64(x357, 0xffffffffffffffff)
var x383 uint64
var x384 uint64
x384, x383 = bits.Mul64(x357, 0xffffffffffffffff)
var x385 uint64
var x386 uint64
x386, x385 = bits.Mul64(x357, 0xffffffffffffffff)
var x387 uint64
var x388 uint64
x388, x387 = bits.Mul64(x357, 0xffffffffffffffff)
var x389 uint64
var x390 uint64
x390, x389 = bits.Mul64(x357, 0xffffffffffffffff)
var x391 uint64
var x392 uint64
x392, x391 = bits.Mul64(x357, 0xffffffffffffffff)
var x393 uint64
var x394 uint64
x394, x393 = bits.Mul64(x357, 0xffffffffffffffff)
var x395 uint64
var x396 uint64
x395, x396 = bits.Add64(x394, x391, uint64(0x0))
var x397 uint64
var x398 uint64
x397, x398 = bits.Add64(x392, x389, uint64(p521Uint1(x396)))
var x399 uint64
var x400 uint64
x399, x400 = bits.Add64(x390, x387, uint64(p521Uint1(x398)))
var x401 uint64
var x402 uint64
x401, x402 = bits.Add64(x388, x385, uint64(p521Uint1(x400)))
var x403 uint64
var x404 uint64
x403, x404 = bits.Add64(x386, x383, uint64(p521Uint1(x402)))
var x405 uint64
var x406 uint64
x405, x406 = bits.Add64(x384, x381, uint64(p521Uint1(x404)))
var x407 uint64
var x408 uint64
x407, x408 = bits.Add64(x382, x379, uint64(p521Uint1(x406)))
var x409 uint64
var x410 uint64
x409, x410 = bits.Add64(x380, x377, uint64(p521Uint1(x408)))
x411 := (uint64(p521Uint1(x410)) + x378)
var x413 uint64
_, x413 = bits.Add64(x357, x393, uint64(0x0))
var x414 uint64
var x415 uint64
x414, x415 = bits.Add64(x359, x395, uint64(p521Uint1(x413)))
var x416 uint64
var x417 uint64
x416, x417 = bits.Add64(x361, x397, uint64(p521Uint1(x415)))
var x418 uint64
var x419 uint64
x418, x419 = bits.Add64(x363, x399, uint64(p521Uint1(x417)))
var x420 uint64
var x421 uint64
x420, x421 = bits.Add64(x365, x401, uint64(p521Uint1(x419)))
var x422 uint64
var x423 uint64
x422, x423 = bits.Add64(x367, x403, uint64(p521Uint1(x421)))
var x424 uint64
var x425 uint64
x424, x425 = bits.Add64(x369, x405, uint64(p521Uint1(x423)))
var x426 uint64
var x427 uint64
x426, x427 = bits.Add64(x371, x407, uint64(p521Uint1(x425)))
var x428 uint64
var x429 uint64
x428, x429 = bits.Add64(x373, x409, uint64(p521Uint1(x427)))
var x430 uint64
var x431 uint64
x430, x431 = bits.Add64(x375, x411, uint64(p521Uint1(x429)))
x432 := (uint64(p521Uint1(x431)) + uint64(p521Uint1(x376)))
var x433 uint64
var x434 uint64
x434, x433 = bits.Mul64(x4, arg2[8])
var x435 uint64
var x436 uint64
x436, x435 = bits.Mul64(x4, arg2[7])
var x437 uint64
var x438 uint64
x438, x437 = bits.Mul64(x4, arg2[6])
var x439 uint64
var x440 uint64
x440, x439 = bits.Mul64(x4, arg2[5])
var x441 uint64
var x442 uint64
x442, x441 = bits.Mul64(x4, arg2[4])
var x443 uint64
var x444 uint64
x444, x443 = bits.Mul64(x4, arg2[3])
var x445 uint64
var x446 uint64
x446, x445 = bits.Mul64(x4, arg2[2])
var x447 uint64
var x448 uint64
x448, x447 = bits.Mul64(x4, arg2[1])
var x449 uint64
var x450 uint64
x450, x449 = bits.Mul64(x4, arg2[0])
var x451 uint64
var x452 uint64
x451, x452 = bits.Add64(x450, x447, uint64(0x0))
var x453 uint64
var x454 uint64
x453, x454 = bits.Add64(x448, x445, uint64(p521Uint1(x452)))
var x455 uint64
var x456 uint64
x455, x456 = bits.Add64(x446, x443, uint64(p521Uint1(x454)))
var x457 uint64
var x458 uint64
x457, x458 = bits.Add64(x444, x441, uint64(p521Uint1(x456)))
var x459 uint64
var x460 uint64
x459, x460 = bits.Add64(x442, x439, uint64(p521Uint1(x458)))
var x461 uint64
var x462 uint64
x461, x462 = bits.Add64(x440, x437, uint64(p521Uint1(x460)))
var x463 uint64
var x464 uint64
x463, x464 = bits.Add64(x438, x435, uint64(p521Uint1(x462)))
var x465 uint64
var x466 uint64
x465, x466 = bits.Add64(x436, x433, uint64(p521Uint1(x464)))
x467 := (uint64(p521Uint1(x466)) + x434)
var x468 uint64
var x469 uint64
x468, x469 = bits.Add64(x414, x449, uint64(0x0))
var x470 uint64
var x471 uint64
x470, x471 = bits.Add64(x416, x451, uint64(p521Uint1(x469)))
var x472 uint64
var x473 uint64
x472, x473 = bits.Add64(x418, x453, uint64(p521Uint1(x471)))
var x474 uint64
var x475 uint64
x474, x475 = bits.Add64(x420, x455, uint64(p521Uint1(x473)))
var x476 uint64
var x477 uint64
x476, x477 = bits.Add64(x422, x457, uint64(p521Uint1(x475)))
var x478 uint64
var x479 uint64
x478, x479 = bits.Add64(x424, x459, uint64(p521Uint1(x477)))
var x480 uint64
var x481 uint64
x480, x481 = bits.Add64(x426, x461, uint64(p521Uint1(x479)))
var x482 uint64
var x483 uint64
x482, x483 = bits.Add64(x428, x463, uint64(p521Uint1(x481)))
var x484 uint64
var x485 uint64
x484, x485 = bits.Add64(x430, x465, uint64(p521Uint1(x483)))
var x486 uint64
var x487 uint64
x486, x487 = bits.Add64(x432, x467, uint64(p521Uint1(x485)))
var x488 uint64
var x489 uint64
x489, x488 = bits.Mul64(x468, 0x1ff)
var x490 uint64
var x491 uint64
x491, x490 = bits.Mul64(x468, 0xffffffffffffffff)
var x492 uint64
var x493 uint64
x493, x492 = bits.Mul64(x468, 0xffffffffffffffff)
var x494 uint64
var x495 uint64
x495, x494 = bits.Mul64(x468, 0xffffffffffffffff)
var x496 uint64
var x497 uint64
x497, x496 = bits.Mul64(x468, 0xffffffffffffffff)
var x498 uint64
var x499 uint64
x499, x498 = bits.Mul64(x468, 0xffffffffffffffff)
var x500 uint64
var x501 uint64
x501, x500 = bits.Mul64(x468, 0xffffffffffffffff)
var x502 uint64
var x503 uint64
x503, x502 = bits.Mul64(x468, 0xffffffffffffffff)
var x504 uint64
var x505 uint64
x505, x504 = bits.Mul64(x468, 0xffffffffffffffff)
var x506 uint64
var x507 uint64
x506, x507 = bits.Add64(x505, x502, uint64(0x0))
var x508 uint64
var x509 uint64
x508, x509 = bits.Add64(x503, x500, uint64(p521Uint1(x507)))
var x510 uint64
var x511 uint64
x510, x511 = bits.Add64(x501, x498, uint64(p521Uint1(x509)))
var x512 uint64
var x513 uint64
x512, x513 = bits.Add64(x499, x496, uint64(p521Uint1(x511)))
var x514 uint64
var x515 uint64
x514, x515 = bits.Add64(x497, x494, uint64(p521Uint1(x513)))
var x516 uint64
var x517 uint64
x516, x517 = bits.Add64(x495, x492, uint64(p521Uint1(x515)))
var x518 uint64
var x519 uint64
x518, x519 = bits.Add64(x493, x490, uint64(p521Uint1(x517)))
var x520 uint64
var x521 uint64
x520, x521 = bits.Add64(x491, x488, uint64(p521Uint1(x519)))
x522 := (uint64(p521Uint1(x521)) + x489)
var x524 uint64
_, x524 = bits.Add64(x468, x504, uint64(0x0))
var x525 uint64
var x526 uint64
x525, x526 = bits.Add64(x470, x506, uint64(p521Uint1(x524)))
var x527 uint64
var x528 uint64
x527, x528 = bits.Add64(x472, x508, uint64(p521Uint1(x526)))
var x529 uint64
var x530 uint64
x529, x530 = bits.Add64(x474, x510, uint64(p521Uint1(x528)))
var x531 uint64
var x532 uint64
x531, x532 = bits.Add64(x476, x512, uint64(p521Uint1(x530)))
var x533 uint64
var x534 uint64
x533, x534 = bits.Add64(x478, x514, uint64(p521Uint1(x532)))
var x535 uint64
var x536 uint64
x535, x536 = bits.Add64(x480, x516, uint64(p521Uint1(x534)))
var x537 uint64
var x538 uint64
x537, x538 = bits.Add64(x482, x518, uint64(p521Uint1(x536)))
var x539 uint64
var x540 uint64
x539, x540 = bits.Add64(x484, x520, uint64(p521Uint1(x538)))
var x541 uint64
var x542 uint64
x541, x542 = bits.Add64(x486, x522, uint64(p521Uint1(x540)))
x543 := (uint64(p521Uint1(x542)) + uint64(p521Uint1(x487)))
var x544 uint64
var x545 uint64
x545, x544 = bits.Mul64(x5, arg2[8])
var x546 uint64
var x547 uint64
x547, x546 = bits.Mul64(x5, arg2[7])
var x548 uint64
var x549 uint64
x549, x548 = bits.Mul64(x5, arg2[6])
var x550 uint64
var x551 uint64
x551, x550 = bits.Mul64(x5, arg2[5])
var x552 uint64
var x553 uint64
x553, x552 = bits.Mul64(x5, arg2[4])
var x554 uint64
var x555 uint64
x555, x554 = bits.Mul64(x5, arg2[3])
var x556 uint64
var x557 uint64
x557, x556 = bits.Mul64(x5, arg2[2])
var x558 uint64
var x559 uint64
x559, x558 = bits.Mul64(x5, arg2[1])
var x560 uint64
var x561 uint64
x561, x560 = bits.Mul64(x5, arg2[0])
var x562 uint64
var x563 uint64
x562, x563 = bits.Add64(x561, x558, uint64(0x0))
var x564 uint64
var x565 uint64
x564, x565 = bits.Add64(x559, x556, uint64(p521Uint1(x563)))
var x566 uint64
var x567 uint64
x566, x567 = bits.Add64(x557, x554, uint64(p521Uint1(x565)))
var x568 uint64
var x569 uint64
x568, x569 = bits.Add64(x555, x552, uint64(p521Uint1(x567)))
var x570 uint64
var x571 uint64
x570, x571 = bits.Add64(x553, x550, uint64(p521Uint1(x569)))
var x572 uint64
var x573 uint64
x572, x573 = bits.Add64(x551, x548, uint64(p521Uint1(x571)))
var x574 uint64
var x575 uint64
x574, x575 = bits.Add64(x549, x546, uint64(p521Uint1(x573)))
var x576 uint64
var x577 uint64
x576, x577 = bits.Add64(x547, x544, uint64(p521Uint1(x575)))
x578 := (uint64(p521Uint1(x577)) + x545)
var x579 uint64
var x580 uint64
x579, x580 = bits.Add64(x525, x560, uint64(0x0))
var x581 uint64
var x582 uint64
x581, x582 = bits.Add64(x527, x562, uint64(p521Uint1(x580)))
var x583 uint64
var x584 uint64
x583, x584 = bits.Add64(x529, x564, uint64(p521Uint1(x582)))
var x585 uint64
var x586 uint64
x585, x586 = bits.Add64(x531, x566, uint64(p521Uint1(x584)))
var x587 uint64
var x588 uint64
x587, x588 = bits.Add64(x533, x568, uint64(p521Uint1(x586)))
var x589 uint64
var x590 uint64
x589, x590 = bits.Add64(x535, x570, uint64(p521Uint1(x588)))
var x591 uint64
var x592 uint64
x591, x592 = bits.Add64(x537, x572, uint64(p521Uint1(x590)))
var x593 uint64
var x594 uint64
x593, x594 = bits.Add64(x539, x574, uint64(p521Uint1(x592)))
var x595 uint64
var x596 uint64
x595, x596 = bits.Add64(x541, x576, uint64(p521Uint1(x594)))
var x597 uint64
var x598 uint64
x597, x598 = bits.Add64(x543, x578, uint64(p521Uint1(x596)))
var x599 uint64
var x600 uint64
x600, x599 = bits.Mul64(x579, 0x1ff)
var x601 uint64
var x602 uint64
x602, x601 = bits.Mul64(x579, 0xffffffffffffffff)
var x603 uint64
var x604 uint64
x604, x603 = bits.Mul64(x579, 0xffffffffffffffff)
var x605 uint64
var x606 uint64
x606, x605 = bits.Mul64(x579, 0xffffffffffffffff)
var x607 uint64
var x608 uint64
x608, x607 = bits.Mul64(x579, 0xffffffffffffffff)
var x609 uint64
var x610 uint64
x610, x609 = bits.Mul64(x579, 0xffffffffffffffff)
var x611 uint64
var x612 uint64
x612, x611 = bits.Mul64(x579, 0xffffffffffffffff)
var x613 uint64
var x614 uint64
x614, x613 = bits.Mul64(x579, 0xffffffffffffffff)
var x615 uint64
var x616 uint64
x616, x615 = bits.Mul64(x579, 0xffffffffffffffff)
var x617 uint64
var x618 uint64
x617, x618 = bits.Add64(x616, x613, uint64(0x0))
var x619 uint64
var x620 uint64
x619, x620 = bits.Add64(x614, x611, uint64(p521Uint1(x618)))
var x621 uint64
var x622 uint64
x621, x622 = bits.Add64(x612, x609, uint64(p521Uint1(x620)))
var x623 uint64
var x624 uint64
x623, x624 = bits.Add64(x610, x607, uint64(p521Uint1(x622)))
var x625 uint64
var x626 uint64
x625, x626 = bits.Add64(x608, x605, uint64(p521Uint1(x624)))
var x627 uint64
var x628 uint64
x627, x628 = bits.Add64(x606, x603, uint64(p521Uint1(x626)))
var x629 uint64
var x630 uint64
x629, x630 = bits.Add64(x604, x601, uint64(p521Uint1(x628)))
var x631 uint64
var x632 uint64
x631, x632 = bits.Add64(x602, x599, uint64(p521Uint1(x630)))
x633 := (uint64(p521Uint1(x632)) + x600)
var x635 uint64
_, x635 = bits.Add64(x579, x615, uint64(0x0))
var x636 uint64
var x637 uint64
x636, x637 = bits.Add64(x581, x617, uint64(p521Uint1(x635)))
var x638 uint64
var x639 uint64
x638, x639 = bits.Add64(x583, x619, uint64(p521Uint1(x637)))
var x640 uint64
var x641 uint64
x640, x641 = bits.Add64(x585, x621, uint64(p521Uint1(x639)))
var x642 uint64
var x643 uint64
x642, x643 = bits.Add64(x587, x623, uint64(p521Uint1(x641)))
var x644 uint64
var x645 uint64
x644, x645 = bits.Add64(x589, x625, uint64(p521Uint1(x643)))
var x646 uint64
var x647 uint64
x646, x647 = bits.Add64(x591, x627, uint64(p521Uint1(x645)))
var x648 uint64
var x649 uint64
x648, x649 = bits.Add64(x593, x629, uint64(p521Uint1(x647)))
var x650 uint64
var x651 uint64
x650, x651 = bits.Add64(x595, x631, uint64(p521Uint1(x649)))
var x652 uint64
var x653 uint64
x652, x653 = bits.Add64(x597, x633, uint64(p521Uint1(x651)))
x654 := (uint64(p521Uint1(x653)) + uint64(p521Uint1(x598)))
var x655 uint64
var x656 uint64
x656, x655 = bits.Mul64(x6, arg2[8])
var x657 uint64
var x658 uint64
x658, x657 = bits.Mul64(x6, arg2[7])
var x659 uint64
var x660 uint64
x660, x659 = bits.Mul64(x6, arg2[6])
var x661 uint64
var x662 uint64
x662, x661 = bits.Mul64(x6, arg2[5])
var x663 uint64
var x664 uint64
x664, x663 = bits.Mul64(x6, arg2[4])
var x665 uint64
var x666 uint64
x666, x665 = bits.Mul64(x6, arg2[3])
var x667 uint64
var x668 uint64
x668, x667 = bits.Mul64(x6, arg2[2])
var x669 uint64
var x670 uint64
x670, x669 = bits.Mul64(x6, arg2[1])
var x671 uint64
var x672 uint64
x672, x671 = bits.Mul64(x6, arg2[0])
var x673 uint64
var x674 uint64
x673, x674 = bits.Add64(x672, x669, uint64(0x0))
var x675 uint64
var x676 uint64
x675, x676 = bits.Add64(x670, x667, uint64(p521Uint1(x674)))
var x677 uint64
var x678 uint64
x677, x678 = bits.Add64(x668, x665, uint64(p521Uint1(x676)))
var x679 uint64
var x680 uint64
x679, x680 = bits.Add64(x666, x663, uint64(p521Uint1(x678)))
var x681 uint64
var x682 uint64
x681, x682 = bits.Add64(x664, x661, uint64(p521Uint1(x680)))
var x683 uint64
var x684 uint64
x683, x684 = bits.Add64(x662, x659, uint64(p521Uint1(x682)))
var x685 uint64
var x686 uint64
x685, x686 = bits.Add64(x660, x657, uint64(p521Uint1(x684)))
var x687 uint64
var x688 uint64
x687, x688 = bits.Add64(x658, x655, uint64(p521Uint1(x686)))
x689 := (uint64(p521Uint1(x688)) + x656)
var x690 uint64
var x691 uint64
x690, x691 = bits.Add64(x636, x671, uint64(0x0))
var x692 uint64
var x693 uint64
x692, x693 = bits.Add64(x638, x673, uint64(p521Uint1(x691)))
var x694 uint64
var x695 uint64
x694, x695 = bits.Add64(x640, x675, uint64(p521Uint1(x693)))
var x696 uint64
var x697 uint64
x696, x697 = bits.Add64(x642, x677, uint64(p521Uint1(x695)))
var x698 uint64
var x699 uint64
x698, x699 = bits.Add64(x644, x679, uint64(p521Uint1(x697)))
var x700 uint64
var x701 uint64
x700, x701 = bits.Add64(x646, x681, uint64(p521Uint1(x699)))
var x702 uint64
var x703 uint64
x702, x703 = bits.Add64(x648, x683, uint64(p521Uint1(x701)))
var x704 uint64
var x705 uint64
x704, x705 = bits.Add64(x650, x685, uint64(p521Uint1(x703)))
var x706 uint64
var x707 uint64
x706, x707 = bits.Add64(x652, x687, uint64(p521Uint1(x705)))
var x708 uint64
var x709 uint64
x708, x709 = bits.Add64(x654, x689, uint64(p521Uint1(x707)))
var x710 uint64
var x711 uint64
x711, x710 = bits.Mul64(x690, 0x1ff)
var x712 uint64
var x713 uint64
x713, x712 = bits.Mul64(x690, 0xffffffffffffffff)
var x714 uint64
var x715 uint64
x715, x714 = bits.Mul64(x690, 0xffffffffffffffff)
var x716 uint64
var x717 uint64
x717, x716 = bits.Mul64(x690, 0xffffffffffffffff)
var x718 uint64
var x719 uint64
x719, x718 = bits.Mul64(x690, 0xffffffffffffffff)
var x720 uint64
var x721 uint64
x721, x720 = bits.Mul64(x690, 0xffffffffffffffff)
var x722 uint64
var x723 uint64
x723, x722 = bits.Mul64(x690, 0xffffffffffffffff)
var x724 uint64
var x725 uint64
x725, x724 = bits.Mul64(x690, 0xffffffffffffffff)
var x726 uint64
var x727 uint64
x727, x726 = bits.Mul64(x690, 0xffffffffffffffff)
var x728 uint64
var x729 uint64
x728, x729 = bits.Add64(x727, x724, uint64(0x0))
var x730 uint64
var x731 uint64
x730, x731 = bits.Add64(x725, x722, uint64(p521Uint1(x729)))
var x732 uint64
var x733 uint64
x732, x733 = bits.Add64(x723, x720, uint64(p521Uint1(x731)))
var x734 uint64
var x735 uint64
x734, x735 = bits.Add64(x721, x718, uint64(p521Uint1(x733)))
var x736 uint64
var x737 uint64
x736, x737 = bits.Add64(x719, x716, uint64(p521Uint1(x735)))
var x738 uint64
var x739 uint64
x738, x739 = bits.Add64(x717, x714, uint64(p521Uint1(x737)))
var x740 uint64
var x741 uint64
x740, x741 = bits.Add64(x715, x712, uint64(p521Uint1(x739)))
var x742 uint64
var x743 uint64
x742, x743 = bits.Add64(x713, x710, uint64(p521Uint1(x741)))
x744 := (uint64(p521Uint1(x743)) + x711)
var x746 uint64
_, x746 = bits.Add64(x690, x726, uint64(0x0))
var x747 uint64
var x748 uint64
x747, x748 = bits.Add64(x692, x728, uint64(p521Uint1(x746)))
var x749 uint64
var x750 uint64
x749, x750 = bits.Add64(x694, x730, uint64(p521Uint1(x748)))
var x751 uint64
var x752 uint64
x751, x752 = bits.Add64(x696, x732, uint64(p521Uint1(x750)))
var x753 uint64
var x754 uint64
x753, x754 = bits.Add64(x698, x734, uint64(p521Uint1(x752)))
var x755 uint64
var x756 uint64
x755, x756 = bits.Add64(x700, x736, uint64(p521Uint1(x754)))
var x757 uint64
var x758 uint64
x757, x758 = bits.Add64(x702, x738, uint64(p521Uint1(x756)))
var x759 uint64
var x760 uint64
x759, x760 = bits.Add64(x704, x740, uint64(p521Uint1(x758)))
var x761 uint64
var x762 uint64
x761, x762 = bits.Add64(x706, x742, uint64(p521Uint1(x760)))
var x763 uint64
var x764 uint64
x763, x764 = bits.Add64(x708, x744, uint64(p521Uint1(x762)))
x765 := (uint64(p521Uint1(x764)) + uint64(p521Uint1(x709)))
var x766 uint64
var x767 uint64
x767, x766 = bits.Mul64(x7, arg2[8])
var x768 uint64
var x769 uint64
x769, x768 = bits.Mul64(x7, arg2[7])
var x770 uint64
var x771 uint64
x771, x770 = bits.Mul64(x7, arg2[6])
var x772 uint64
var x773 uint64
x773, x772 = bits.Mul64(x7, arg2[5])
var x774 uint64
var x775 uint64
x775, x774 = bits.Mul64(x7, arg2[4])
var x776 uint64
var x777 uint64
x777, x776 = bits.Mul64(x7, arg2[3])
var x778 uint64
var x779 uint64
x779, x778 = bits.Mul64(x7, arg2[2])
var x780 uint64
var x781 uint64
x781, x780 = bits.Mul64(x7, arg2[1])
var x782 uint64
var x783 uint64
x783, x782 = bits.Mul64(x7, arg2[0])
var x784 uint64
var x785 uint64
x784, x785 = bits.Add64(x783, x780, uint64(0x0))
var x786 uint64
var x787 uint64
x786, x787 = bits.Add64(x781, x778, uint64(p521Uint1(x785)))
var x788 uint64
var x789 uint64
x788, x789 = bits.Add64(x779, x776, uint64(p521Uint1(x787)))
var x790 uint64
var x791 uint64
x790, x791 = bits.Add64(x777, x774, uint64(p521Uint1(x789)))
var x792 uint64
var x793 uint64
x792, x793 = bits.Add64(x775, x772, uint64(p521Uint1(x791)))
var x794 uint64
var x795 uint64
x794, x795 = bits.Add64(x773, x770, uint64(p521Uint1(x793)))
var x796 uint64
var x797 uint64
x796, x797 = bits.Add64(x771, x768, uint64(p521Uint1(x795)))
var x798 uint64
var x799 uint64
x798, x799 = bits.Add64(x769, x766, uint64(p521Uint1(x797)))
x800 := (uint64(p521Uint1(x799)) + x767)
var x801 uint64
var x802 uint64
x801, x802 = bits.Add64(x747, x782, uint64(0x0))
var x803 uint64
var x804 uint64
x803, x804 = bits.Add64(x749, x784, uint64(p521Uint1(x802)))
var x805 uint64
var x806 uint64
x805, x806 = bits.Add64(x751, x786, uint64(p521Uint1(x804)))
var x807 uint64
var x808 uint64
x807, x808 = bits.Add64(x753, x788, uint64(p521Uint1(x806)))
var x809 uint64
var x810 uint64
x809, x810 = bits.Add64(x755, x790, uint64(p521Uint1(x808)))
var x811 uint64
var x812 uint64
x811, x812 = bits.Add64(x757, x792, uint64(p521Uint1(x810)))
var x813 uint64
var x814 uint64
x813, x814 = bits.Add64(x759, x794, uint64(p521Uint1(x812)))
var x815 uint64
var x816 uint64
x815, x816 = bits.Add64(x761, x796, uint64(p521Uint1(x814)))
var x817 uint64
var x818 uint64
x817, x818 = bits.Add64(x763, x798, uint64(p521Uint1(x816)))
var x819 uint64
var x820 uint64
x819, x820 = bits.Add64(x765, x800, uint64(p521Uint1(x818)))
var x821 uint64
var x822 uint64
x822, x821 = bits.Mul64(x801, 0x1ff)
var x823 uint64
var x824 uint64
x824, x823 = bits.Mul64(x801, 0xffffffffffffffff)
var x825 uint64
var x826 uint64
x826, x825 = bits.Mul64(x801, 0xffffffffffffffff)
var x827 uint64
var x828 uint64
x828, x827 = bits.Mul64(x801, 0xffffffffffffffff)
var x829 uint64
var x830 uint64
x830, x829 = bits.Mul64(x801, 0xffffffffffffffff)
var x831 uint64
var x832 uint64
x832, x831 = bits.Mul64(x801, 0xffffffffffffffff)
var x833 uint64
var x834 uint64
x834, x833 = bits.Mul64(x801, 0xffffffffffffffff)
var x835 uint64
var x836 uint64
x836, x835 = bits.Mul64(x801, 0xffffffffffffffff)
var x837 uint64
var x838 uint64
x838, x837 = bits.Mul64(x801, 0xffffffffffffffff)
var x839 uint64
var x840 uint64
x839, x840 = bits.Add64(x838, x835, uint64(0x0))
var x841 uint64
var x842 uint64
x841, x842 = bits.Add64(x836, x833, uint64(p521Uint1(x840)))
var x843 uint64
var x844 uint64
x843, x844 = bits.Add64(x834, x831, uint64(p521Uint1(x842)))
var x845 uint64
var x846 uint64
x845, x846 = bits.Add64(x832, x829, uint64(p521Uint1(x844)))
var x847 uint64
var x848 uint64
x847, x848 = bits.Add64(x830, x827, uint64(p521Uint1(x846)))
var x849 uint64
var x850 uint64
x849, x850 = bits.Add64(x828, x825, uint64(p521Uint1(x848)))
var x851 uint64
var x852 uint64
x851, x852 = bits.Add64(x826, x823, uint64(p521Uint1(x850)))
var x853 uint64
var x854 uint64
x853, x854 = bits.Add64(x824, x821, uint64(p521Uint1(x852)))
x855 := (uint64(p521Uint1(x854)) + x822)
var x857 uint64
_, x857 = bits.Add64(x801, x837, uint64(0x0))
var x858 uint64
var x859 uint64
x858, x859 = bits.Add64(x803, x839, uint64(p521Uint1(x857)))
var x860 uint64
var x861 uint64
x860, x861 = bits.Add64(x805, x841, uint64(p521Uint1(x859)))
var x862 uint64
var x863 uint64
x862, x863 = bits.Add64(x807, x843, uint64(p521Uint1(x861)))
var x864 uint64
var x865 uint64
x864, x865 = bits.Add64(x809, x845, uint64(p521Uint1(x863)))
var x866 uint64
var x867 uint64
x866, x867 = bits.Add64(x811, x847, uint64(p521Uint1(x865)))
var x868 uint64
var x869 uint64
x868, x869 = bits.Add64(x813, x849, uint64(p521Uint1(x867)))
var x870 uint64
var x871 uint64
x870, x871 = bits.Add64(x815, x851, uint64(p521Uint1(x869)))
var x872 uint64
var x873 uint64
x872, x873 = bits.Add64(x817, x853, uint64(p521Uint1(x871)))
var x874 uint64
var x875 uint64
x874, x875 = bits.Add64(x819, x855, uint64(p521Uint1(x873)))
x876 := (uint64(p521Uint1(x875)) + uint64(p521Uint1(x820)))
var x877 uint64
var x878 uint64
x878, x877 = bits.Mul64(x8, arg2[8])
var x879 uint64
var x880 uint64
x880, x879 = bits.Mul64(x8, arg2[7])
var x881 uint64
var x882 uint64
x882, x881 = bits.Mul64(x8, arg2[6])
var x883 uint64
var x884 uint64
x884, x883 = bits.Mul64(x8, arg2[5])
var x885 uint64
var x886 uint64
x886, x885 = bits.Mul64(x8, arg2[4])
var x887 uint64
var x888 uint64
x888, x887 = bits.Mul64(x8, arg2[3])
var x889 uint64
var x890 uint64
x890, x889 = bits.Mul64(x8, arg2[2])
var x891 uint64
var x892 uint64
x892, x891 = bits.Mul64(x8, arg2[1])
var x893 uint64
var x894 uint64
x894, x893 = bits.Mul64(x8, arg2[0])
var x895 uint64
var x896 uint64
x895, x896 = bits.Add64(x894, x891, uint64(0x0))
var x897 uint64
var x898 uint64
x897, x898 = bits.Add64(x892, x889, uint64(p521Uint1(x896)))
var x899 uint64
var x900 uint64
x899, x900 = bits.Add64(x890, x887, uint64(p521Uint1(x898)))
var x901 uint64
var x902 uint64
x901, x902 = bits.Add64(x888, x885, uint64(p521Uint1(x900)))
var x903 uint64
var x904 uint64
x903, x904 = bits.Add64(x886, x883, uint64(p521Uint1(x902)))
var x905 uint64
var x906 uint64
x905, x906 = bits.Add64(x884, x881, uint64(p521Uint1(x904)))
var x907 uint64
var x908 uint64
x907, x908 = bits.Add64(x882, x879, uint64(p521Uint1(x906)))
var x909 uint64
var x910 uint64
x909, x910 = bits.Add64(x880, x877, uint64(p521Uint1(x908)))
x911 := (uint64(p521Uint1(x910)) + x878)
var x912 uint64
var x913 uint64
x912, x913 = bits.Add64(x858, x893, uint64(0x0))
var x914 uint64
var x915 uint64
x914, x915 = bits.Add64(x860, x895, uint64(p521Uint1(x913)))
var x916 uint64
var x917 uint64
x916, x917 = bits.Add64(x862, x897, uint64(p521Uint1(x915)))
var x918 uint64
var x919 uint64
x918, x919 = bits.Add64(x864, x899, uint64(p521Uint1(x917)))
var x920 uint64
var x921 uint64
x920, x921 = bits.Add64(x866, x901, uint64(p521Uint1(x919)))
var x922 uint64
var x923 uint64
x922, x923 = bits.Add64(x868, x903, uint64(p521Uint1(x921)))
var x924 uint64
var x925 uint64
x924, x925 = bits.Add64(x870, x905, uint64(p521Uint1(x923)))
var x926 uint64
var x927 uint64
x926, x927 = bits.Add64(x872, x907, uint64(p521Uint1(x925)))
var x928 uint64
var x929 uint64
x928, x929 = bits.Add64(x874, x909, uint64(p521Uint1(x927)))
var x930 uint64
var x931 uint64
x930, x931 = bits.Add64(x876, x911, uint64(p521Uint1(x929)))
var x932 uint64
var x933 uint64
x933, x932 = bits.Mul64(x912, 0x1ff)
var x934 uint64
var x935 uint64
x935, x934 = bits.Mul64(x912, 0xffffffffffffffff)
var x936 uint64
var x937 uint64
x937, x936 = bits.Mul64(x912, 0xffffffffffffffff)
var x938 uint64
var x939 uint64
x939, x938 = bits.Mul64(x912, 0xffffffffffffffff)
var x940 uint64
var x941 uint64
x941, x940 = bits.Mul64(x912, 0xffffffffffffffff)
var x942 uint64
var x943 uint64
x943, x942 = bits.Mul64(x912, 0xffffffffffffffff)
var x944 uint64
var x945 uint64
x945, x944 = bits.Mul64(x912, 0xffffffffffffffff)
var x946 uint64
var x947 uint64
x947, x946 = bits.Mul64(x912, 0xffffffffffffffff)
var x948 uint64
var x949 uint64
x949, x948 = bits.Mul64(x912, 0xffffffffffffffff)
var x950 uint64
var x951 uint64
x950, x951 = bits.Add64(x949, x946, uint64(0x0))
var x952 uint64
var x953 uint64
x952, x953 = bits.Add64(x947, x944, uint64(p521Uint1(x951)))
var x954 uint64
var x955 uint64
x954, x955 = bits.Add64(x945, x942, uint64(p521Uint1(x953)))
var x956 uint64
var x957 uint64
x956, x957 = bits.Add64(x943, x940, uint64(p521Uint1(x955)))
var x958 uint64
var x959 uint64
x958, x959 = bits.Add64(x941, x938, uint64(p521Uint1(x957)))
var x960 uint64
var x961 uint64
x960, x961 = bits.Add64(x939, x936, uint64(p521Uint1(x959)))
var x962 uint64
var x963 uint64
x962, x963 = bits.Add64(x937, x934, uint64(p521Uint1(x961)))
var x964 uint64
var x965 uint64
x964, x965 = bits.Add64(x935, x932, uint64(p521Uint1(x963)))
x966 := (uint64(p521Uint1(x965)) + x933)
var x968 uint64
_, x968 = bits.Add64(x912, x948, uint64(0x0))
var x969 uint64
var x970 uint64
x969, x970 = bits.Add64(x914, x950, uint64(p521Uint1(x968)))
var x971 uint64
var x972 uint64
x971, x972 = bits.Add64(x916, x952, uint64(p521Uint1(x970)))
var x973 uint64
var x974 uint64
x973, x974 = bits.Add64(x918, x954, uint64(p521Uint1(x972)))
var x975 uint64
var x976 uint64
x975, x976 = bits.Add64(x920, x956, uint64(p521Uint1(x974)))
var x977 uint64
var x978 uint64
x977, x978 = bits.Add64(x922, x958, uint64(p521Uint1(x976)))
var x979 uint64
var x980 uint64
x979, x980 = bits.Add64(x924, x960, uint64(p521Uint1(x978)))
var x981 uint64
var x982 uint64
x981, x982 = bits.Add64(x926, x962, uint64(p521Uint1(x980)))
var x983 uint64
var x984 uint64
x983, x984 = bits.Add64(x928, x964, uint64(p521Uint1(x982)))
var x985 uint64
var x986 uint64
x985, x986 = bits.Add64(x930, x966, uint64(p521Uint1(x984)))
x987 := (uint64(p521Uint1(x986)) + uint64(p521Uint1(x931)))
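// Final conditional subtraction: compute the accumulator minus m = 2^521 - 1
// (eight words of 0xffffffffffffffff and a top word of 0x1ff), tracking the
// borrow. The result is selected below in constant time: if the subtraction
// borrowed, the unreduced value was already < m and is kept.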
var x988 uint64
var x989 uint64
x988, x989 = bits.Sub64(x969, 0xffffffffffffffff, uint64(0x0))
var x990 uint64
var x991 uint64
x990, x991 = bits.Sub64(x971, 0xffffffffffffffff, uint64(p521Uint1(x989)))
var x992 uint64
var x993 uint64
x992, x993 = bits.Sub64(x973, 0xffffffffffffffff, uint64(p521Uint1(x991)))
var x994 uint64
var x995 uint64
x994, x995 = bits.Sub64(x975, 0xffffffffffffffff, uint64(p521Uint1(x993)))
var x996 uint64
var x997 uint64
x996, x997 = bits.Sub64(x977, 0xffffffffffffffff, uint64(p521Uint1(x995)))
var x998 uint64
var x999 uint64
x998, x999 = bits.Sub64(x979, 0xffffffffffffffff, uint64(p521Uint1(x997)))
var x1000 uint64
var x1001 uint64
x1000, x1001 = bits.Sub64(x981, 0xffffffffffffffff, uint64(p521Uint1(x999)))
var x1002 uint64
var x1003 uint64
x1002, x1003 = bits.Sub64(x983, 0xffffffffffffffff, uint64(p521Uint1(x1001)))
var x1004 uint64
var x1005 uint64
x1004, x1005 = bits.Sub64(x985, 0x1ff, uint64(p521Uint1(x1003)))
var x1007 uint64
_, x1007 = bits.Sub64(x987, uint64(0x0), uint64(p521Uint1(x1005)))
var x1008 uint64
p521CmovznzU64(&x1008, p521Uint1(x1007), x988, x969)
var x1009 uint64
p521CmovznzU64(&x1009, p521Uint1(x1007), x990, x971)
var x1010 uint64
p521CmovznzU64(&x1010, p521Uint1(x1007), x992, x973)
var x1011 uint64
p521CmovznzU64(&x1011, p521Uint1(x1007), x994, x975)
var x1012 uint64
p521CmovznzU64(&x1012, p521Uint1(x1007), x996, x977)
var x1013 uint64
p521CmovznzU64(&x1013, p521Uint1(x1007), x998, x979)
var x1014 uint64
p521CmovznzU64(&x1014, p521Uint1(x1007), x1000, x981)
var x1015 uint64
p521CmovznzU64(&x1015, p521Uint1(x1007), x1002, x983)
var x1016 uint64
p521CmovznzU64(&x1016, p521Uint1(x1007), x1004, x985)
out1[0] = x1008
out1[1] = x1009
out1[2] = x1010
out1[3] = x1011
out1[4] = x1012
out1[5] = x1013
out1[6] = x1014
out1[7] = x1015
out1[8] = x1016
}
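// The routine above is the Montgomery-domain multiplication for P-521: each
// round multiplies one limb of arg1 by all nine limbs of arg2, folds the
// partial products into the accumulator, and reduces one word. Because
// m = 2^521 - 1 is congruent to -1 modulo 2^64, the Montgomery constant
// m' = -m^-1 mod 2^64 equals 1, so each round's low accumulator word is
// multiplied directly by the limbs of m.
//
// A minimal usage sketch, assuming the conventional generated name p521Mul
// for this routine and that a and b are reduced Montgomery-domain elements:
//
//	var prod p521MontgomeryDomainFieldElement
//	p521Mul(&prod, &a, &b) // prod represents the field product of a and b
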
// p521Square squares a field element in the Montgomery domain.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
//
// Postconditions:
//
// eval (from_montgomery out1) mod m = (eval (from_montgomery arg1) * eval (from_montgomery arg1)) mod m
// 0 ≤ eval out1 < m
func p521Square(out1 *p521MontgomeryDomainFieldElement, arg1 *p521MontgomeryDomainFieldElement) {
x1 := arg1[1]
x2 := arg1[2]
x3 := arg1[3]
x4 := arg1[4]
x5 := arg1[5]
x6 := arg1[6]
x7 := arg1[7]
x8 := arg1[8]
x9 := arg1[0]
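// The limbs of arg1 are cached up front: x1 through x8 hold arg1[1..8] and
// x9 holds arg1[0], which seeds the first multiplication round.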
var x10 uint64
var x11 uint64
x11, x10 = bits.Mul64(x9, arg1[8])
var x12 uint64
var x13 uint64
x13, x12 = bits.Mul64(x9, arg1[7])
var x14 uint64
var x15 uint64
x15, x14 = bits.Mul64(x9, arg1[6])
var x16 uint64
var x17 uint64
x17, x16 = bits.Mul64(x9, arg1[5])
var x18 uint64
var x19 uint64
x19, x18 = bits.Mul64(x9, arg1[4])
var x20 uint64
var x21 uint64
x21, x20 = bits.Mul64(x9, arg1[3])
var x22 uint64
var x23 uint64
x23, x22 = bits.Mul64(x9, arg1[2])
var x24 uint64
var x25 uint64
x25, x24 = bits.Mul64(x9, arg1[1])
var x26 uint64
var x27 uint64
x27, x26 = bits.Mul64(x9, arg1[0])
var x28 uint64
var x29 uint64
x28, x29 = bits.Add64(x27, x24, uint64(0x0))
var x30 uint64
var x31 uint64
x30, x31 = bits.Add64(x25, x22, uint64(p521Uint1(x29)))
var x32 uint64
var x33 uint64
x32, x33 = bits.Add64(x23, x20, uint64(p521Uint1(x31)))
var x34 uint64
var x35 uint64
x34, x35 = bits.Add64(x21, x18, uint64(p521Uint1(x33)))
var x36 uint64
var x37 uint64
x36, x37 = bits.Add64(x19, x16, uint64(p521Uint1(x35)))
var x38 uint64
var x39 uint64
x38, x39 = bits.Add64(x17, x14, uint64(p521Uint1(x37)))
var x40 uint64
var x41 uint64
x40, x41 = bits.Add64(x15, x12, uint64(p521Uint1(x39)))
var x42 uint64
var x43 uint64
x42, x43 = bits.Add64(x13, x10, uint64(p521Uint1(x41)))
x44 := (uint64(p521Uint1(x43)) + x11)
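// Montgomery reduction, round 0: m' = -m^-1 mod 2^64 is 1 for m = 2^521 - 1,
// so the low accumulator word x26 is itself the reduction factor and is
// multiplied directly by the nine limbs of m (a top word of 0x1ff and eight
// words of 0xffffffffffffffff).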
var x45 uint64
var x46 uint64
x46, x45 = bits.Mul64(x26, 0x1ff)
var x47 uint64
var x48 uint64
x48, x47 = bits.Mul64(x26, 0xffffffffffffffff)
var x49 uint64
var x50 uint64
x50, x49 = bits.Mul64(x26, 0xffffffffffffffff)
var x51 uint64
var x52 uint64
x52, x51 = bits.Mul64(x26, 0xffffffffffffffff)
var x53 uint64
var x54 uint64
x54, x53 = bits.Mul64(x26, 0xffffffffffffffff)
var x55 uint64
var x56 uint64
x56, x55 = bits.Mul64(x26, 0xffffffffffffffff)
var x57 uint64
var x58 uint64
x58, x57 = bits.Mul64(x26, 0xffffffffffffffff)
var x59 uint64
var x60 uint64
x60, x59 = bits.Mul64(x26, 0xffffffffffffffff)
var x61 uint64
var x62 uint64
x62, x61 = bits.Mul64(x26, 0xffffffffffffffff)
var x63 uint64
var x64 uint64
x63, x64 = bits.Add64(x62, x59, uint64(0x0))
var x65 uint64
var x66 uint64
x65, x66 = bits.Add64(x60, x57, uint64(p521Uint1(x64)))
var x67 uint64
var x68 uint64
x67, x68 = bits.Add64(x58, x55, uint64(p521Uint1(x66)))
var x69 uint64
var x70 uint64
x69, x70 = bits.Add64(x56, x53, uint64(p521Uint1(x68)))
var x71 uint64
var x72 uint64
x71, x72 = bits.Add64(x54, x51, uint64(p521Uint1(x70)))
var x73 uint64
var x74 uint64
x73, x74 = bits.Add64(x52, x49, uint64(p521Uint1(x72)))
var x75 uint64
var x76 uint64
x75, x76 = bits.Add64(x50, x47, uint64(p521Uint1(x74)))
var x77 uint64
var x78 uint64
x77, x78 = bits.Add64(x48, x45, uint64(p521Uint1(x76)))
x79 := (uint64(p521Uint1(x78)) + x46)
var x81 uint64
_, x81 = bits.Add64(x26, x61, uint64(0x0))
var x82 uint64
var x83 uint64
x82, x83 = bits.Add64(x28, x63, uint64(p521Uint1(x81)))
var x84 uint64
var x85 uint64
x84, x85 = bits.Add64(x30, x65, uint64(p521Uint1(x83)))
var x86 uint64
var x87 uint64
x86, x87 = bits.Add64(x32, x67, uint64(p521Uint1(x85)))
var x88 uint64
var x89 uint64
x88, x89 = bits.Add64(x34, x69, uint64(p521Uint1(x87)))
var x90 uint64
var x91 uint64
x90, x91 = bits.Add64(x36, x71, uint64(p521Uint1(x89)))
var x92 uint64
var x93 uint64
x92, x93 = bits.Add64(x38, x73, uint64(p521Uint1(x91)))
var x94 uint64
var x95 uint64
x94, x95 = bits.Add64(x40, x75, uint64(p521Uint1(x93)))
var x96 uint64
var x97 uint64
x96, x97 = bits.Add64(x42, x77, uint64(p521Uint1(x95)))
var x98 uint64
var x99 uint64
x98, x99 = bits.Add64(x44, x79, uint64(p521Uint1(x97)))
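// Round 1: multiply limb x1 (arg1[1]) by all nine limbs of arg1 and fold the
// partial products into the accumulator; rounds 2 through 8 below repeat the
// same multiply, accumulate, reduce pattern for the remaining limbs.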
var x100 uint64
var x101 uint64
x101, x100 = bits.Mul64(x1, arg1[8])
var x102 uint64
var x103 uint64
x103, x102 = bits.Mul64(x1, arg1[7])
var x104 uint64
var x105 uint64
x105, x104 = bits.Mul64(x1, arg1[6])
var x106 uint64
var x107 uint64
x107, x106 = bits.Mul64(x1, arg1[5])
var x108 uint64
var x109 uint64
x109, x108 = bits.Mul64(x1, arg1[4])
var x110 uint64
var x111 uint64
x111, x110 = bits.Mul64(x1, arg1[3])
var x112 uint64
var x113 uint64
x113, x112 = bits.Mul64(x1, arg1[2])
var x114 uint64
var x115 uint64
x115, x114 = bits.Mul64(x1, arg1[1])
var x116 uint64
var x117 uint64
x117, x116 = bits.Mul64(x1, arg1[0])
var x118 uint64
var x119 uint64
x118, x119 = bits.Add64(x117, x114, uint64(0x0))
var x120 uint64
var x121 uint64
x120, x121 = bits.Add64(x115, x112, uint64(p521Uint1(x119)))
var x122 uint64
var x123 uint64
x122, x123 = bits.Add64(x113, x110, uint64(p521Uint1(x121)))
var x124 uint64
var x125 uint64
x124, x125 = bits.Add64(x111, x108, uint64(p521Uint1(x123)))
var x126 uint64
var x127 uint64
x126, x127 = bits.Add64(x109, x106, uint64(p521Uint1(x125)))
var x128 uint64
var x129 uint64
x128, x129 = bits.Add64(x107, x104, uint64(p521Uint1(x127)))
var x130 uint64
var x131 uint64
x130, x131 = bits.Add64(x105, x102, uint64(p521Uint1(x129)))
var x132 uint64
var x133 uint64
x132, x133 = bits.Add64(x103, x100, uint64(p521Uint1(x131)))
x134 := (uint64(p521Uint1(x133)) + x101)
var x135 uint64
var x136 uint64
x135, x136 = bits.Add64(x82, x116, uint64(0x0))
var x137 uint64
var x138 uint64
x137, x138 = bits.Add64(x84, x118, uint64(p521Uint1(x136)))
var x139 uint64
var x140 uint64
x139, x140 = bits.Add64(x86, x120, uint64(p521Uint1(x138)))
var x141 uint64
var x142 uint64
x141, x142 = bits.Add64(x88, x122, uint64(p521Uint1(x140)))
var x143 uint64
var x144 uint64
x143, x144 = bits.Add64(x90, x124, uint64(p521Uint1(x142)))
var x145 uint64
var x146 uint64
x145, x146 = bits.Add64(x92, x126, uint64(p521Uint1(x144)))
var x147 uint64
var x148 uint64
x147, x148 = bits.Add64(x94, x128, uint64(p521Uint1(x146)))
var x149 uint64
var x150 uint64
x149, x150 = bits.Add64(x96, x130, uint64(p521Uint1(x148)))
var x151 uint64
var x152 uint64
x151, x152 = bits.Add64(x98, x132, uint64(p521Uint1(x150)))
var x153 uint64
var x154 uint64
x153, x154 = bits.Add64(uint64(p521Uint1(x99)), x134, uint64(p521Uint1(x152)))
var x155 uint64
var x156 uint64
x156, x155 = bits.Mul64(x135, 0x1ff)
var x157 uint64
var x158 uint64
x158, x157 = bits.Mul64(x135, 0xffffffffffffffff)
var x159 uint64
var x160 uint64
x160, x159 = bits.Mul64(x135, 0xffffffffffffffff)
var x161 uint64
var x162 uint64
x162, x161 = bits.Mul64(x135, 0xffffffffffffffff)
var x163 uint64
var x164 uint64
x164, x163 = bits.Mul64(x135, 0xffffffffffffffff)
var x165 uint64
var x166 uint64
x166, x165 = bits.Mul64(x135, 0xffffffffffffffff)
var x167 uint64
var x168 uint64
x168, x167 = bits.Mul64(x135, 0xffffffffffffffff)
var x169 uint64
var x170 uint64
x170, x169 = bits.Mul64(x135, 0xffffffffffffffff)
var x171 uint64
var x172 uint64
x172, x171 = bits.Mul64(x135, 0xffffffffffffffff)
var x173 uint64
var x174 uint64
x173, x174 = bits.Add64(x172, x169, uint64(0x0))
var x175 uint64
var x176 uint64
x175, x176 = bits.Add64(x170, x167, uint64(p521Uint1(x174)))
var x177 uint64
var x178 uint64
x177, x178 = bits.Add64(x168, x165, uint64(p521Uint1(x176)))
var x179 uint64
var x180 uint64
x179, x180 = bits.Add64(x166, x163, uint64(p521Uint1(x178)))
var x181 uint64
var x182 uint64
x181, x182 = bits.Add64(x164, x161, uint64(p521Uint1(x180)))
var x183 uint64
var x184 uint64
x183, x184 = bits.Add64(x162, x159, uint64(p521Uint1(x182)))
var x185 uint64
var x186 uint64
x185, x186 = bits.Add64(x160, x157, uint64(p521Uint1(x184)))
var x187 uint64
var x188 uint64
x187, x188 = bits.Add64(x158, x155, uint64(p521Uint1(x186)))
x189 := (uint64(p521Uint1(x188)) + x156)
var x191 uint64
_, x191 = bits.Add64(x135, x171, uint64(0x0))
var x192 uint64
var x193 uint64
x192, x193 = bits.Add64(x137, x173, uint64(p521Uint1(x191)))
var x194 uint64
var x195 uint64
x194, x195 = bits.Add64(x139, x175, uint64(p521Uint1(x193)))
var x196 uint64
var x197 uint64
x196, x197 = bits.Add64(x141, x177, uint64(p521Uint1(x195)))
var x198 uint64
var x199 uint64
x198, x199 = bits.Add64(x143, x179, uint64(p521Uint1(x197)))
var x200 uint64
var x201 uint64
x200, x201 = bits.Add64(x145, x181, uint64(p521Uint1(x199)))
var x202 uint64
var x203 uint64
x202, x203 = bits.Add64(x147, x183, uint64(p521Uint1(x201)))
var x204 uint64
var x205 uint64
x204, x205 = bits.Add64(x149, x185, uint64(p521Uint1(x203)))
var x206 uint64
var x207 uint64
x206, x207 = bits.Add64(x151, x187, uint64(p521Uint1(x205)))
var x208 uint64
var x209 uint64
x208, x209 = bits.Add64(x153, x189, uint64(p521Uint1(x207)))
x210 := (uint64(p521Uint1(x209)) + uint64(p521Uint1(x154)))
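// Round 2: limb x2 (arg1[2]).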
var x211 uint64
var x212 uint64
x212, x211 = bits.Mul64(x2, arg1[8])
var x213 uint64
var x214 uint64
x214, x213 = bits.Mul64(x2, arg1[7])
var x215 uint64
var x216 uint64
x216, x215 = bits.Mul64(x2, arg1[6])
var x217 uint64
var x218 uint64
x218, x217 = bits.Mul64(x2, arg1[5])
var x219 uint64
var x220 uint64
x220, x219 = bits.Mul64(x2, arg1[4])
var x221 uint64
var x222 uint64
x222, x221 = bits.Mul64(x2, arg1[3])
var x223 uint64
var x224 uint64
x224, x223 = bits.Mul64(x2, arg1[2])
var x225 uint64
var x226 uint64
x226, x225 = bits.Mul64(x2, arg1[1])
var x227 uint64
var x228 uint64
x228, x227 = bits.Mul64(x2, arg1[0])
var x229 uint64
var x230 uint64
x229, x230 = bits.Add64(x228, x225, uint64(0x0))
var x231 uint64
var x232 uint64
x231, x232 = bits.Add64(x226, x223, uint64(p521Uint1(x230)))
var x233 uint64
var x234 uint64
x233, x234 = bits.Add64(x224, x221, uint64(p521Uint1(x232)))
var x235 uint64
var x236 uint64
x235, x236 = bits.Add64(x222, x219, uint64(p521Uint1(x234)))
var x237 uint64
var x238 uint64
x237, x238 = bits.Add64(x220, x217, uint64(p521Uint1(x236)))
var x239 uint64
var x240 uint64
x239, x240 = bits.Add64(x218, x215, uint64(p521Uint1(x238)))
var x241 uint64
var x242 uint64
x241, x242 = bits.Add64(x216, x213, uint64(p521Uint1(x240)))
var x243 uint64
var x244 uint64
x243, x244 = bits.Add64(x214, x211, uint64(p521Uint1(x242)))
x245 := (uint64(p521Uint1(x244)) + x212)
var x246 uint64
var x247 uint64
x246, x247 = bits.Add64(x192, x227, uint64(0x0))
var x248 uint64
var x249 uint64
x248, x249 = bits.Add64(x194, x229, uint64(p521Uint1(x247)))
var x250 uint64
var x251 uint64
x250, x251 = bits.Add64(x196, x231, uint64(p521Uint1(x249)))
var x252 uint64
var x253 uint64
x252, x253 = bits.Add64(x198, x233, uint64(p521Uint1(x251)))
var x254 uint64
var x255 uint64
x254, x255 = bits.Add64(x200, x235, uint64(p521Uint1(x253)))
var x256 uint64
var x257 uint64
x256, x257 = bits.Add64(x202, x237, uint64(p521Uint1(x255)))
var x258 uint64
var x259 uint64
x258, x259 = bits.Add64(x204, x239, uint64(p521Uint1(x257)))
var x260 uint64
var x261 uint64
x260, x261 = bits.Add64(x206, x241, uint64(p521Uint1(x259)))
var x262 uint64
var x263 uint64
x262, x263 = bits.Add64(x208, x243, uint64(p521Uint1(x261)))
var x264 uint64
var x265 uint64
x264, x265 = bits.Add64(x210, x245, uint64(p521Uint1(x263)))
var x266 uint64
var x267 uint64
x267, x266 = bits.Mul64(x246, 0x1ff)
var x268 uint64
var x269 uint64
x269, x268 = bits.Mul64(x246, 0xffffffffffffffff)
var x270 uint64
var x271 uint64
x271, x270 = bits.Mul64(x246, 0xffffffffffffffff)
var x272 uint64
var x273 uint64
x273, x272 = bits.Mul64(x246, 0xffffffffffffffff)
var x274 uint64
var x275 uint64
x275, x274 = bits.Mul64(x246, 0xffffffffffffffff)
var x276 uint64
var x277 uint64
x277, x276 = bits.Mul64(x246, 0xffffffffffffffff)
var x278 uint64
var x279 uint64
x279, x278 = bits.Mul64(x246, 0xffffffffffffffff)
var x280 uint64
var x281 uint64
x281, x280 = bits.Mul64(x246, 0xffffffffffffffff)
var x282 uint64
var x283 uint64
x283, x282 = bits.Mul64(x246, 0xffffffffffffffff)
var x284 uint64
var x285 uint64
x284, x285 = bits.Add64(x283, x280, uint64(0x0))
var x286 uint64
var x287 uint64
x286, x287 = bits.Add64(x281, x278, uint64(p521Uint1(x285)))
var x288 uint64
var x289 uint64
x288, x289 = bits.Add64(x279, x276, uint64(p521Uint1(x287)))
var x290 uint64
var x291 uint64
x290, x291 = bits.Add64(x277, x274, uint64(p521Uint1(x289)))
var x292 uint64
var x293 uint64
x292, x293 = bits.Add64(x275, x272, uint64(p521Uint1(x291)))
var x294 uint64
var x295 uint64
x294, x295 = bits.Add64(x273, x270, uint64(p521Uint1(x293)))
var x296 uint64
var x297 uint64
x296, x297 = bits.Add64(x271, x268, uint64(p521Uint1(x295)))
var x298 uint64
var x299 uint64
x298, x299 = bits.Add64(x269, x266, uint64(p521Uint1(x297)))
x300 := (uint64(p521Uint1(x299)) + x267)
var x302 uint64
_, x302 = bits.Add64(x246, x282, uint64(0x0))
var x303 uint64
var x304 uint64
x303, x304 = bits.Add64(x248, x284, uint64(p521Uint1(x302)))
var x305 uint64
var x306 uint64
x305, x306 = bits.Add64(x250, x286, uint64(p521Uint1(x304)))
var x307 uint64
var x308 uint64
x307, x308 = bits.Add64(x252, x288, uint64(p521Uint1(x306)))
var x309 uint64
var x310 uint64
x309, x310 = bits.Add64(x254, x290, uint64(p521Uint1(x308)))
var x311 uint64
var x312 uint64
x311, x312 = bits.Add64(x256, x292, uint64(p521Uint1(x310)))
var x313 uint64
var x314 uint64
x313, x314 = bits.Add64(x258, x294, uint64(p521Uint1(x312)))
var x315 uint64
var x316 uint64
x315, x316 = bits.Add64(x260, x296, uint64(p521Uint1(x314)))
var x317 uint64
var x318 uint64
x317, x318 = bits.Add64(x262, x298, uint64(p521Uint1(x316)))
var x319 uint64
var x320 uint64
x319, x320 = bits.Add64(x264, x300, uint64(p521Uint1(x318)))
x321 := (uint64(p521Uint1(x320)) + uint64(p521Uint1(x265)))
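// Round 3: limb x3 (arg1[3]).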
var x322 uint64
var x323 uint64
x323, x322 = bits.Mul64(x3, arg1[8])
var x324 uint64
var x325 uint64
x325, x324 = bits.Mul64(x3, arg1[7])
var x326 uint64
var x327 uint64
x327, x326 = bits.Mul64(x3, arg1[6])
var x328 uint64
var x329 uint64
x329, x328 = bits.Mul64(x3, arg1[5])
var x330 uint64
var x331 uint64
x331, x330 = bits.Mul64(x3, arg1[4])
var x332 uint64
var x333 uint64
x333, x332 = bits.Mul64(x3, arg1[3])
var x334 uint64
var x335 uint64
x335, x334 = bits.Mul64(x3, arg1[2])
var x336 uint64
var x337 uint64
x337, x336 = bits.Mul64(x3, arg1[1])
var x338 uint64
var x339 uint64
x339, x338 = bits.Mul64(x3, arg1[0])
var x340 uint64
var x341 uint64
x340, x341 = bits.Add64(x339, x336, uint64(0x0))
var x342 uint64
var x343 uint64
x342, x343 = bits.Add64(x337, x334, uint64(p521Uint1(x341)))
var x344 uint64
var x345 uint64
x344, x345 = bits.Add64(x335, x332, uint64(p521Uint1(x343)))
var x346 uint64
var x347 uint64
x346, x347 = bits.Add64(x333, x330, uint64(p521Uint1(x345)))
var x348 uint64
var x349 uint64
x348, x349 = bits.Add64(x331, x328, uint64(p521Uint1(x347)))
var x350 uint64
var x351 uint64
x350, x351 = bits.Add64(x329, x326, uint64(p521Uint1(x349)))
var x352 uint64
var x353 uint64
x352, x353 = bits.Add64(x327, x324, uint64(p521Uint1(x351)))
var x354 uint64
var x355 uint64
x354, x355 = bits.Add64(x325, x322, uint64(p521Uint1(x353)))
x356 := (uint64(p521Uint1(x355)) + x323)
var x357 uint64
var x358 uint64
x357, x358 = bits.Add64(x303, x338, uint64(0x0))
var x359 uint64
var x360 uint64
x359, x360 = bits.Add64(x305, x340, uint64(p521Uint1(x358)))
var x361 uint64
var x362 uint64
x361, x362 = bits.Add64(x307, x342, uint64(p521Uint1(x360)))
var x363 uint64
var x364 uint64
x363, x364 = bits.Add64(x309, x344, uint64(p521Uint1(x362)))
var x365 uint64
var x366 uint64
x365, x366 = bits.Add64(x311, x346, uint64(p521Uint1(x364)))
var x367 uint64
var x368 uint64
x367, x368 = bits.Add64(x313, x348, uint64(p521Uint1(x366)))
var x369 uint64
var x370 uint64
x369, x370 = bits.Add64(x315, x350, uint64(p521Uint1(x368)))
var x371 uint64
var x372 uint64
x371, x372 = bits.Add64(x317, x352, uint64(p521Uint1(x370)))
var x373 uint64
var x374 uint64
x373, x374 = bits.Add64(x319, x354, uint64(p521Uint1(x372)))
var x375 uint64
var x376 uint64
x375, x376 = bits.Add64(x321, x356, uint64(p521Uint1(x374)))
var x377 uint64
var x378 uint64
x378, x377 = bits.Mul64(x357, 0x1ff)
var x379 uint64
var x380 uint64
x380, x379 = bits.Mul64(x357, 0xffffffffffffffff)
var x381 uint64
var x382 uint64
x382, x381 = bits.Mul64(x357, 0xffffffffffffffff)
var x383 uint64
var x384 uint64
x384, x383 = bits.Mul64(x357, 0xffffffffffffffff)
var x385 uint64
var x386 uint64
x386, x385 = bits.Mul64(x357, 0xffffffffffffffff)
var x387 uint64
var x388 uint64
x388, x387 = bits.Mul64(x357, 0xffffffffffffffff)
var x389 uint64
var x390 uint64
x390, x389 = bits.Mul64(x357, 0xffffffffffffffff)
var x391 uint64
var x392 uint64
x392, x391 = bits.Mul64(x357, 0xffffffffffffffff)
var x393 uint64
var x394 uint64
x394, x393 = bits.Mul64(x357, 0xffffffffffffffff)
var x395 uint64
var x396 uint64
x395, x396 = bits.Add64(x394, x391, uint64(0x0))
var x397 uint64
var x398 uint64
x397, x398 = bits.Add64(x392, x389, uint64(p521Uint1(x396)))
var x399 uint64
var x400 uint64
x399, x400 = bits.Add64(x390, x387, uint64(p521Uint1(x398)))
var x401 uint64
var x402 uint64
x401, x402 = bits.Add64(x388, x385, uint64(p521Uint1(x400)))
var x403 uint64
var x404 uint64
x403, x404 = bits.Add64(x386, x383, uint64(p521Uint1(x402)))
var x405 uint64
var x406 uint64
x405, x406 = bits.Add64(x384, x381, uint64(p521Uint1(x404)))
var x407 uint64
var x408 uint64
x407, x408 = bits.Add64(x382, x379, uint64(p521Uint1(x406)))
var x409 uint64
var x410 uint64
x409, x410 = bits.Add64(x380, x377, uint64(p521Uint1(x408)))
x411 := (uint64(p521Uint1(x410)) + x378)
var x413 uint64
_, x413 = bits.Add64(x357, x393, uint64(0x0))
var x414 uint64
var x415 uint64
x414, x415 = bits.Add64(x359, x395, uint64(p521Uint1(x413)))
var x416 uint64
var x417 uint64
x416, x417 = bits.Add64(x361, x397, uint64(p521Uint1(x415)))
var x418 uint64
var x419 uint64
x418, x419 = bits.Add64(x363, x399, uint64(p521Uint1(x417)))
var x420 uint64
var x421 uint64
x420, x421 = bits.Add64(x365, x401, uint64(p521Uint1(x419)))
var x422 uint64
var x423 uint64
x422, x423 = bits.Add64(x367, x403, uint64(p521Uint1(x421)))
var x424 uint64
var x425 uint64
x424, x425 = bits.Add64(x369, x405, uint64(p521Uint1(x423)))
var x426 uint64
var x427 uint64
x426, x427 = bits.Add64(x371, x407, uint64(p521Uint1(x425)))
var x428 uint64
var x429 uint64
x428, x429 = bits.Add64(x373, x409, uint64(p521Uint1(x427)))
var x430 uint64
var x431 uint64
x430, x431 = bits.Add64(x375, x411, uint64(p521Uint1(x429)))
x432 := (uint64(p521Uint1(x431)) + uint64(p521Uint1(x376)))
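// Round 4: limb x4 (arg1[4]).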
var x433 uint64
var x434 uint64
x434, x433 = bits.Mul64(x4, arg1[8])
var x435 uint64
var x436 uint64
x436, x435 = bits.Mul64(x4, arg1[7])
var x437 uint64
var x438 uint64
x438, x437 = bits.Mul64(x4, arg1[6])
var x439 uint64
var x440 uint64
x440, x439 = bits.Mul64(x4, arg1[5])
var x441 uint64
var x442 uint64
x442, x441 = bits.Mul64(x4, arg1[4])
var x443 uint64
var x444 uint64
x444, x443 = bits.Mul64(x4, arg1[3])
var x445 uint64
var x446 uint64
x446, x445 = bits.Mul64(x4, arg1[2])
var x447 uint64
var x448 uint64
x448, x447 = bits.Mul64(x4, arg1[1])
var x449 uint64
var x450 uint64
x450, x449 = bits.Mul64(x4, arg1[0])
var x451 uint64
var x452 uint64
x451, x452 = bits.Add64(x450, x447, uint64(0x0))
var x453 uint64
var x454 uint64
x453, x454 = bits.Add64(x448, x445, uint64(p521Uint1(x452)))
var x455 uint64
var x456 uint64
x455, x456 = bits.Add64(x446, x443, uint64(p521Uint1(x454)))
var x457 uint64
var x458 uint64
x457, x458 = bits.Add64(x444, x441, uint64(p521Uint1(x456)))
var x459 uint64
var x460 uint64
x459, x460 = bits.Add64(x442, x439, uint64(p521Uint1(x458)))
var x461 uint64
var x462 uint64
x461, x462 = bits.Add64(x440, x437, uint64(p521Uint1(x460)))
var x463 uint64
var x464 uint64
x463, x464 = bits.Add64(x438, x435, uint64(p521Uint1(x462)))
var x465 uint64
var x466 uint64
x465, x466 = bits.Add64(x436, x433, uint64(p521Uint1(x464)))
x467 := (uint64(p521Uint1(x466)) + x434)
var x468 uint64
var x469 uint64
x468, x469 = bits.Add64(x414, x449, uint64(0x0))
var x470 uint64
var x471 uint64
x470, x471 = bits.Add64(x416, x451, uint64(p521Uint1(x469)))
var x472 uint64
var x473 uint64
x472, x473 = bits.Add64(x418, x453, uint64(p521Uint1(x471)))
var x474 uint64
var x475 uint64
x474, x475 = bits.Add64(x420, x455, uint64(p521Uint1(x473)))
var x476 uint64
var x477 uint64
x476, x477 = bits.Add64(x422, x457, uint64(p521Uint1(x475)))
var x478 uint64
var x479 uint64
x478, x479 = bits.Add64(x424, x459, uint64(p521Uint1(x477)))
var x480 uint64
var x481 uint64
x480, x481 = bits.Add64(x426, x461, uint64(p521Uint1(x479)))
var x482 uint64
var x483 uint64
x482, x483 = bits.Add64(x428, x463, uint64(p521Uint1(x481)))
var x484 uint64
var x485 uint64
x484, x485 = bits.Add64(x430, x465, uint64(p521Uint1(x483)))
var x486 uint64
var x487 uint64
x486, x487 = bits.Add64(x432, x467, uint64(p521Uint1(x485)))
var x488 uint64
var x489 uint64
x489, x488 = bits.Mul64(x468, 0x1ff)
var x490 uint64
var x491 uint64
x491, x490 = bits.Mul64(x468, 0xffffffffffffffff)
var x492 uint64
var x493 uint64
x493, x492 = bits.Mul64(x468, 0xffffffffffffffff)
var x494 uint64
var x495 uint64
x495, x494 = bits.Mul64(x468, 0xffffffffffffffff)
var x496 uint64
var x497 uint64
x497, x496 = bits.Mul64(x468, 0xffffffffffffffff)
var x498 uint64
var x499 uint64
x499, x498 = bits.Mul64(x468, 0xffffffffffffffff)
var x500 uint64
var x501 uint64
x501, x500 = bits.Mul64(x468, 0xffffffffffffffff)
var x502 uint64
var x503 uint64
x503, x502 = bits.Mul64(x468, 0xffffffffffffffff)
var x504 uint64
var x505 uint64
x505, x504 = bits.Mul64(x468, 0xffffffffffffffff)
var x506 uint64
var x507 uint64
x506, x507 = bits.Add64(x505, x502, uint64(0x0))
var x508 uint64
var x509 uint64
x508, x509 = bits.Add64(x503, x500, uint64(p521Uint1(x507)))
var x510 uint64
var x511 uint64
x510, x511 = bits.Add64(x501, x498, uint64(p521Uint1(x509)))
var x512 uint64
var x513 uint64
x512, x513 = bits.Add64(x499, x496, uint64(p521Uint1(x511)))
var x514 uint64
var x515 uint64
x514, x515 = bits.Add64(x497, x494, uint64(p521Uint1(x513)))
var x516 uint64
var x517 uint64
x516, x517 = bits.Add64(x495, x492, uint64(p521Uint1(x515)))
var x518 uint64
var x519 uint64
x518, x519 = bits.Add64(x493, x490, uint64(p521Uint1(x517)))
var x520 uint64
var x521 uint64
x520, x521 = bits.Add64(x491, x488, uint64(p521Uint1(x519)))
x522 := (uint64(p521Uint1(x521)) + x489)
var x524 uint64
_, x524 = bits.Add64(x468, x504, uint64(0x0))
var x525 uint64
var x526 uint64
x525, x526 = bits.Add64(x470, x506, uint64(p521Uint1(x524)))
var x527 uint64
var x528 uint64
x527, x528 = bits.Add64(x472, x508, uint64(p521Uint1(x526)))
var x529 uint64
var x530 uint64
x529, x530 = bits.Add64(x474, x510, uint64(p521Uint1(x528)))
var x531 uint64
var x532 uint64
x531, x532 = bits.Add64(x476, x512, uint64(p521Uint1(x530)))
var x533 uint64
var x534 uint64
x533, x534 = bits.Add64(x478, x514, uint64(p521Uint1(x532)))
var x535 uint64
var x536 uint64
x535, x536 = bits.Add64(x480, x516, uint64(p521Uint1(x534)))
var x537 uint64
var x538 uint64
x537, x538 = bits.Add64(x482, x518, uint64(p521Uint1(x536)))
var x539 uint64
var x540 uint64
x539, x540 = bits.Add64(x484, x520, uint64(p521Uint1(x538)))
var x541 uint64
var x542 uint64
x541, x542 = bits.Add64(x486, x522, uint64(p521Uint1(x540)))
x543 := (uint64(p521Uint1(x542)) + uint64(p521Uint1(x487)))
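// Round 5: limb x5 (arg1[5]).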
var x544 uint64
var x545 uint64
x545, x544 = bits.Mul64(x5, arg1[8])
var x546 uint64
var x547 uint64
x547, x546 = bits.Mul64(x5, arg1[7])
var x548 uint64
var x549 uint64
x549, x548 = bits.Mul64(x5, arg1[6])
var x550 uint64
var x551 uint64
x551, x550 = bits.Mul64(x5, arg1[5])
var x552 uint64
var x553 uint64
x553, x552 = bits.Mul64(x5, arg1[4])
var x554 uint64
var x555 uint64
x555, x554 = bits.Mul64(x5, arg1[3])
var x556 uint64
var x557 uint64
x557, x556 = bits.Mul64(x5, arg1[2])
var x558 uint64
var x559 uint64
x559, x558 = bits.Mul64(x5, arg1[1])
var x560 uint64
var x561 uint64
x561, x560 = bits.Mul64(x5, arg1[0])
var x562 uint64
var x563 uint64
x562, x563 = bits.Add64(x561, x558, uint64(0x0))
var x564 uint64
var x565 uint64
x564, x565 = bits.Add64(x559, x556, uint64(p521Uint1(x563)))
var x566 uint64
var x567 uint64
x566, x567 = bits.Add64(x557, x554, uint64(p521Uint1(x565)))
var x568 uint64
var x569 uint64
x568, x569 = bits.Add64(x555, x552, uint64(p521Uint1(x567)))
var x570 uint64
var x571 uint64
x570, x571 = bits.Add64(x553, x550, uint64(p521Uint1(x569)))
var x572 uint64
var x573 uint64
x572, x573 = bits.Add64(x551, x548, uint64(p521Uint1(x571)))
var x574 uint64
var x575 uint64
x574, x575 = bits.Add64(x549, x546, uint64(p521Uint1(x573)))
var x576 uint64
var x577 uint64
x576, x577 = bits.Add64(x547, x544, uint64(p521Uint1(x575)))
x578 := (uint64(p521Uint1(x577)) + x545)
var x579 uint64
var x580 uint64
x579, x580 = bits.Add64(x525, x560, uint64(0x0))
var x581 uint64
var x582 uint64
x581, x582 = bits.Add64(x527, x562, uint64(p521Uint1(x580)))
var x583 uint64
var x584 uint64
x583, x584 = bits.Add64(x529, x564, uint64(p521Uint1(x582)))
var x585 uint64
var x586 uint64
x585, x586 = bits.Add64(x531, x566, uint64(p521Uint1(x584)))
var x587 uint64
var x588 uint64
x587, x588 = bits.Add64(x533, x568, uint64(p521Uint1(x586)))
var x589 uint64
var x590 uint64
x589, x590 = bits.Add64(x535, x570, uint64(p521Uint1(x588)))
var x591 uint64
var x592 uint64
x591, x592 = bits.Add64(x537, x572, uint64(p521Uint1(x590)))
var x593 uint64
var x594 uint64
x593, x594 = bits.Add64(x539, x574, uint64(p521Uint1(x592)))
var x595 uint64
var x596 uint64
x595, x596 = bits.Add64(x541, x576, uint64(p521Uint1(x594)))
var x597 uint64
var x598 uint64
x597, x598 = bits.Add64(x543, x578, uint64(p521Uint1(x596)))
var x599 uint64
var x600 uint64
x600, x599 = bits.Mul64(x579, 0x1ff)
var x601 uint64
var x602 uint64
x602, x601 = bits.Mul64(x579, 0xffffffffffffffff)
var x603 uint64
var x604 uint64
x604, x603 = bits.Mul64(x579, 0xffffffffffffffff)
var x605 uint64
var x606 uint64
x606, x605 = bits.Mul64(x579, 0xffffffffffffffff)
var x607 uint64
var x608 uint64
x608, x607 = bits.Mul64(x579, 0xffffffffffffffff)
var x609 uint64
var x610 uint64
x610, x609 = bits.Mul64(x579, 0xffffffffffffffff)
var x611 uint64
var x612 uint64
x612, x611 = bits.Mul64(x579, 0xffffffffffffffff)
var x613 uint64
var x614 uint64
x614, x613 = bits.Mul64(x579, 0xffffffffffffffff)
var x615 uint64
var x616 uint64
x616, x615 = bits.Mul64(x579, 0xffffffffffffffff)
var x617 uint64
var x618 uint64
x617, x618 = bits.Add64(x616, x613, uint64(0x0))
var x619 uint64
var x620 uint64
x619, x620 = bits.Add64(x614, x611, uint64(p521Uint1(x618)))
var x621 uint64
var x622 uint64
x621, x622 = bits.Add64(x612, x609, uint64(p521Uint1(x620)))
var x623 uint64
var x624 uint64
x623, x624 = bits.Add64(x610, x607, uint64(p521Uint1(x622)))
var x625 uint64
var x626 uint64
x625, x626 = bits.Add64(x608, x605, uint64(p521Uint1(x624)))
var x627 uint64
var x628 uint64
x627, x628 = bits.Add64(x606, x603, uint64(p521Uint1(x626)))
var x629 uint64
var x630 uint64
x629, x630 = bits.Add64(x604, x601, uint64(p521Uint1(x628)))
var x631 uint64
var x632 uint64
x631, x632 = bits.Add64(x602, x599, uint64(p521Uint1(x630)))
x633 := (uint64(p521Uint1(x632)) + x600)
var x635 uint64
_, x635 = bits.Add64(x579, x615, uint64(0x0))
var x636 uint64
var x637 uint64
x636, x637 = bits.Add64(x581, x617, uint64(p521Uint1(x635)))
var x638 uint64
var x639 uint64
x638, x639 = bits.Add64(x583, x619, uint64(p521Uint1(x637)))
var x640 uint64
var x641 uint64
x640, x641 = bits.Add64(x585, x621, uint64(p521Uint1(x639)))
var x642 uint64
var x643 uint64
x642, x643 = bits.Add64(x587, x623, uint64(p521Uint1(x641)))
var x644 uint64
var x645 uint64
x644, x645 = bits.Add64(x589, x625, uint64(p521Uint1(x643)))
var x646 uint64
var x647 uint64
x646, x647 = bits.Add64(x591, x627, uint64(p521Uint1(x645)))
var x648 uint64
var x649 uint64
x648, x649 = bits.Add64(x593, x629, uint64(p521Uint1(x647)))
var x650 uint64
var x651 uint64
x650, x651 = bits.Add64(x595, x631, uint64(p521Uint1(x649)))
var x652 uint64
var x653 uint64
x652, x653 = bits.Add64(x597, x633, uint64(p521Uint1(x651)))
x654 := (uint64(p521Uint1(x653)) + uint64(p521Uint1(x598)))
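// Round 6: limb x6 (arg1[6]).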
var x655 uint64
var x656 uint64
x656, x655 = bits.Mul64(x6, arg1[8])
var x657 uint64
var x658 uint64
x658, x657 = bits.Mul64(x6, arg1[7])
var x659 uint64
var x660 uint64
x660, x659 = bits.Mul64(x6, arg1[6])
var x661 uint64
var x662 uint64
x662, x661 = bits.Mul64(x6, arg1[5])
var x663 uint64
var x664 uint64
x664, x663 = bits.Mul64(x6, arg1[4])
var x665 uint64
var x666 uint64
x666, x665 = bits.Mul64(x6, arg1[3])
var x667 uint64
var x668 uint64
x668, x667 = bits.Mul64(x6, arg1[2])
var x669 uint64
var x670 uint64
x670, x669 = bits.Mul64(x6, arg1[1])
var x671 uint64
var x672 uint64
x672, x671 = bits.Mul64(x6, arg1[0])
var x673 uint64
var x674 uint64
x673, x674 = bits.Add64(x672, x669, uint64(0x0))
var x675 uint64
var x676 uint64
x675, x676 = bits.Add64(x670, x667, uint64(p521Uint1(x674)))
var x677 uint64
var x678 uint64
x677, x678 = bits.Add64(x668, x665, uint64(p521Uint1(x676)))
var x679 uint64
var x680 uint64
x679, x680 = bits.Add64(x666, x663, uint64(p521Uint1(x678)))
var x681 uint64
var x682 uint64
x681, x682 = bits.Add64(x664, x661, uint64(p521Uint1(x680)))
var x683 uint64
var x684 uint64
x683, x684 = bits.Add64(x662, x659, uint64(p521Uint1(x682)))
var x685 uint64
var x686 uint64
x685, x686 = bits.Add64(x660, x657, uint64(p521Uint1(x684)))
var x687 uint64
var x688 uint64
x687, x688 = bits.Add64(x658, x655, uint64(p521Uint1(x686)))
x689 := (uint64(p521Uint1(x688)) + x656)
var x690 uint64
var x691 uint64
x690, x691 = bits.Add64(x636, x671, uint64(0x0))
var x692 uint64
var x693 uint64
x692, x693 = bits.Add64(x638, x673, uint64(p521Uint1(x691)))
var x694 uint64
var x695 uint64
x694, x695 = bits.Add64(x640, x675, uint64(p521Uint1(x693)))
var x696 uint64
var x697 uint64
x696, x697 = bits.Add64(x642, x677, uint64(p521Uint1(x695)))
var x698 uint64
var x699 uint64
x698, x699 = bits.Add64(x644, x679, uint64(p521Uint1(x697)))
var x700 uint64
var x701 uint64
x700, x701 = bits.Add64(x646, x681, uint64(p521Uint1(x699)))
var x702 uint64
var x703 uint64
x702, x703 = bits.Add64(x648, x683, uint64(p521Uint1(x701)))
var x704 uint64
var x705 uint64
x704, x705 = bits.Add64(x650, x685, uint64(p521Uint1(x703)))
var x706 uint64
var x707 uint64
x706, x707 = bits.Add64(x652, x687, uint64(p521Uint1(x705)))
var x708 uint64
var x709 uint64
x708, x709 = bits.Add64(x654, x689, uint64(p521Uint1(x707)))
var x710 uint64
var x711 uint64
x711, x710 = bits.Mul64(x690, 0x1ff)
var x712 uint64
var x713 uint64
x713, x712 = bits.Mul64(x690, 0xffffffffffffffff)
var x714 uint64
var x715 uint64
x715, x714 = bits.Mul64(x690, 0xffffffffffffffff)
var x716 uint64
var x717 uint64
x717, x716 = bits.Mul64(x690, 0xffffffffffffffff)
var x718 uint64
var x719 uint64
x719, x718 = bits.Mul64(x690, 0xffffffffffffffff)
var x720 uint64
var x721 uint64
x721, x720 = bits.Mul64(x690, 0xffffffffffffffff)
var x722 uint64
var x723 uint64
x723, x722 = bits.Mul64(x690, 0xffffffffffffffff)
var x724 uint64
var x725 uint64
x725, x724 = bits.Mul64(x690, 0xffffffffffffffff)
var x726 uint64
var x727 uint64
x727, x726 = bits.Mul64(x690, 0xffffffffffffffff)
var x728 uint64
var x729 uint64
x728, x729 = bits.Add64(x727, x724, uint64(0x0))
var x730 uint64
var x731 uint64
x730, x731 = bits.Add64(x725, x722, uint64(p521Uint1(x729)))
var x732 uint64
var x733 uint64
x732, x733 = bits.Add64(x723, x720, uint64(p521Uint1(x731)))
var x734 uint64
var x735 uint64
x734, x735 = bits.Add64(x721, x718, uint64(p521Uint1(x733)))
var x736 uint64
var x737 uint64
x736, x737 = bits.Add64(x719, x716, uint64(p521Uint1(x735)))
var x738 uint64
var x739 uint64
x738, x739 = bits.Add64(x717, x714, uint64(p521Uint1(x737)))
var x740 uint64
var x741 uint64
x740, x741 = bits.Add64(x715, x712, uint64(p521Uint1(x739)))
var x742 uint64
var x743 uint64
x742, x743 = bits.Add64(x713, x710, uint64(p521Uint1(x741)))
x744 := (uint64(p521Uint1(x743)) + x711)
var x746 uint64
_, x746 = bits.Add64(x690, x726, uint64(0x0))
var x747 uint64
var x748 uint64
x747, x748 = bits.Add64(x692, x728, uint64(p521Uint1(x746)))
var x749 uint64
var x750 uint64
x749, x750 = bits.Add64(x694, x730, uint64(p521Uint1(x748)))
var x751 uint64
var x752 uint64
x751, x752 = bits.Add64(x696, x732, uint64(p521Uint1(x750)))
var x753 uint64
var x754 uint64
x753, x754 = bits.Add64(x698, x734, uint64(p521Uint1(x752)))
var x755 uint64
var x756 uint64
x755, x756 = bits.Add64(x700, x736, uint64(p521Uint1(x754)))
var x757 uint64
var x758 uint64
x757, x758 = bits.Add64(x702, x738, uint64(p521Uint1(x756)))
var x759 uint64
var x760 uint64
x759, x760 = bits.Add64(x704, x740, uint64(p521Uint1(x758)))
var x761 uint64
var x762 uint64
x761, x762 = bits.Add64(x706, x742, uint64(p521Uint1(x760)))
var x763 uint64
var x764 uint64
x763, x764 = bits.Add64(x708, x744, uint64(p521Uint1(x762)))
x765 := (uint64(p521Uint1(x764)) + uint64(p521Uint1(x709)))
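// Round 7: limb x7 (arg1[7]).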
var x766 uint64
var x767 uint64
x767, x766 = bits.Mul64(x7, arg1[8])
var x768 uint64
var x769 uint64
x769, x768 = bits.Mul64(x7, arg1[7])
var x770 uint64
var x771 uint64
x771, x770 = bits.Mul64(x7, arg1[6])
var x772 uint64
var x773 uint64
x773, x772 = bits.Mul64(x7, arg1[5])
var x774 uint64
var x775 uint64
x775, x774 = bits.Mul64(x7, arg1[4])
var x776 uint64
var x777 uint64
x777, x776 = bits.Mul64(x7, arg1[3])
var x778 uint64
var x779 uint64
x779, x778 = bits.Mul64(x7, arg1[2])
var x780 uint64
var x781 uint64
x781, x780 = bits.Mul64(x7, arg1[1])
var x782 uint64
var x783 uint64
x783, x782 = bits.Mul64(x7, arg1[0])
var x784 uint64
var x785 uint64
x784, x785 = bits.Add64(x783, x780, uint64(0x0))
var x786 uint64
var x787 uint64
x786, x787 = bits.Add64(x781, x778, uint64(p521Uint1(x785)))
var x788 uint64
var x789 uint64
x788, x789 = bits.Add64(x779, x776, uint64(p521Uint1(x787)))
var x790 uint64
var x791 uint64
x790, x791 = bits.Add64(x777, x774, uint64(p521Uint1(x789)))
var x792 uint64
var x793 uint64
x792, x793 = bits.Add64(x775, x772, uint64(p521Uint1(x791)))
var x794 uint64
var x795 uint64
x794, x795 = bits.Add64(x773, x770, uint64(p521Uint1(x793)))
var x796 uint64
var x797 uint64
x796, x797 = bits.Add64(x771, x768, uint64(p521Uint1(x795)))
var x798 uint64
var x799 uint64
x798, x799 = bits.Add64(x769, x766, uint64(p521Uint1(x797)))
x800 := (uint64(p521Uint1(x799)) + x767)
var x801 uint64
var x802 uint64
x801, x802 = bits.Add64(x747, x782, uint64(0x0))
var x803 uint64
var x804 uint64
x803, x804 = bits.Add64(x749, x784, uint64(p521Uint1(x802)))
var x805 uint64
var x806 uint64
x805, x806 = bits.Add64(x751, x786, uint64(p521Uint1(x804)))
var x807 uint64
var x808 uint64
x807, x808 = bits.Add64(x753, x788, uint64(p521Uint1(x806)))
var x809 uint64
var x810 uint64
x809, x810 = bits.Add64(x755, x790, uint64(p521Uint1(x808)))
var x811 uint64
var x812 uint64
x811, x812 = bits.Add64(x757, x792, uint64(p521Uint1(x810)))
var x813 uint64
var x814 uint64
x813, x814 = bits.Add64(x759, x794, uint64(p521Uint1(x812)))
var x815 uint64
var x816 uint64
x815, x816 = bits.Add64(x761, x796, uint64(p521Uint1(x814)))
var x817 uint64
var x818 uint64
x817, x818 = bits.Add64(x763, x798, uint64(p521Uint1(x816)))
var x819 uint64
var x820 uint64
x819, x820 = bits.Add64(x765, x800, uint64(p521Uint1(x818)))
var x821 uint64
var x822 uint64
x822, x821 = bits.Mul64(x801, 0x1ff)
var x823 uint64
var x824 uint64
x824, x823 = bits.Mul64(x801, 0xffffffffffffffff)
var x825 uint64
var x826 uint64
x826, x825 = bits.Mul64(x801, 0xffffffffffffffff)
var x827 uint64
var x828 uint64
x828, x827 = bits.Mul64(x801, 0xffffffffffffffff)
var x829 uint64
var x830 uint64
x830, x829 = bits.Mul64(x801, 0xffffffffffffffff)
var x831 uint64
var x832 uint64
x832, x831 = bits.Mul64(x801, 0xffffffffffffffff)
var x833 uint64
var x834 uint64
x834, x833 = bits.Mul64(x801, 0xffffffffffffffff)
var x835 uint64
var x836 uint64
x836, x835 = bits.Mul64(x801, 0xffffffffffffffff)
var x837 uint64
var x838 uint64
x838, x837 = bits.Mul64(x801, 0xffffffffffffffff)
var x839 uint64
var x840 uint64
x839, x840 = bits.Add64(x838, x835, uint64(0x0))
var x841 uint64
var x842 uint64
x841, x842 = bits.Add64(x836, x833, uint64(p521Uint1(x840)))
var x843 uint64
var x844 uint64
x843, x844 = bits.Add64(x834, x831, uint64(p521Uint1(x842)))
var x845 uint64
var x846 uint64
x845, x846 = bits.Add64(x832, x829, uint64(p521Uint1(x844)))
var x847 uint64
var x848 uint64
x847, x848 = bits.Add64(x830, x827, uint64(p521Uint1(x846)))
var x849 uint64
var x850 uint64
x849, x850 = bits.Add64(x828, x825, uint64(p521Uint1(x848)))
var x851 uint64
var x852 uint64
x851, x852 = bits.Add64(x826, x823, uint64(p521Uint1(x850)))
var x853 uint64
var x854 uint64
x853, x854 = bits.Add64(x824, x821, uint64(p521Uint1(x852)))
x855 := (uint64(p521Uint1(x854)) + x822)
var x857 uint64
_, x857 = bits.Add64(x801, x837, uint64(0x0))
var x858 uint64
var x859 uint64
x858, x859 = bits.Add64(x803, x839, uint64(p521Uint1(x857)))
var x860 uint64
var x861 uint64
x860, x861 = bits.Add64(x805, x841, uint64(p521Uint1(x859)))
var x862 uint64
var x863 uint64
x862, x863 = bits.Add64(x807, x843, uint64(p521Uint1(x861)))
var x864 uint64
var x865 uint64
x864, x865 = bits.Add64(x809, x845, uint64(p521Uint1(x863)))
var x866 uint64
var x867 uint64
x866, x867 = bits.Add64(x811, x847, uint64(p521Uint1(x865)))
var x868 uint64
var x869 uint64
x868, x869 = bits.Add64(x813, x849, uint64(p521Uint1(x867)))
var x870 uint64
var x871 uint64
x870, x871 = bits.Add64(x815, x851, uint64(p521Uint1(x869)))
var x872 uint64
var x873 uint64
x872, x873 = bits.Add64(x817, x853, uint64(p521Uint1(x871)))
var x874 uint64
var x875 uint64
x874, x875 = bits.Add64(x819, x855, uint64(p521Uint1(x873)))
x876 := (uint64(p521Uint1(x875)) + uint64(p521Uint1(x820)))
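// Round 8: limb x8 (arg1[8]), the final multiply-accumulate-reduce round.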
var x877 uint64
var x878 uint64
x878, x877 = bits.Mul64(x8, arg1[8])
var x879 uint64
var x880 uint64
x880, x879 = bits.Mul64(x8, arg1[7])
var x881 uint64
var x882 uint64
x882, x881 = bits.Mul64(x8, arg1[6])
var x883 uint64
var x884 uint64
x884, x883 = bits.Mul64(x8, arg1[5])
var x885 uint64
var x886 uint64
x886, x885 = bits.Mul64(x8, arg1[4])
var x887 uint64
var x888 uint64
x888, x887 = bits.Mul64(x8, arg1[3])
var x889 uint64
var x890 uint64
x890, x889 = bits.Mul64(x8, arg1[2])
var x891 uint64
var x892 uint64
x892, x891 = bits.Mul64(x8, arg1[1])
var x893 uint64
var x894 uint64
x894, x893 = bits.Mul64(x8, arg1[0])
var x895 uint64
var x896 uint64
x895, x896 = bits.Add64(x894, x891, uint64(0x0))
var x897 uint64
var x898 uint64
x897, x898 = bits.Add64(x892, x889, uint64(p521Uint1(x896)))
var x899 uint64
var x900 uint64
x899, x900 = bits.Add64(x890, x887, uint64(p521Uint1(x898)))
var x901 uint64
var x902 uint64
x901, x902 = bits.Add64(x888, x885, uint64(p521Uint1(x900)))
var x903 uint64
var x904 uint64
x903, x904 = bits.Add64(x886, x883, uint64(p521Uint1(x902)))
var x905 uint64
var x906 uint64
x905, x906 = bits.Add64(x884, x881, uint64(p521Uint1(x904)))
var x907 uint64
var x908 uint64
x907, x908 = bits.Add64(x882, x879, uint64(p521Uint1(x906)))
var x909 uint64
var x910 uint64
x909, x910 = bits.Add64(x880, x877, uint64(p521Uint1(x908)))
x911 := (uint64(p521Uint1(x910)) + x878)
var x912 uint64
var x913 uint64
x912, x913 = bits.Add64(x858, x893, uint64(0x0))
var x914 uint64
var x915 uint64
x914, x915 = bits.Add64(x860, x895, uint64(p521Uint1(x913)))
var x916 uint64
var x917 uint64
x916, x917 = bits.Add64(x862, x897, uint64(p521Uint1(x915)))
var x918 uint64
var x919 uint64
x918, x919 = bits.Add64(x864, x899, uint64(p521Uint1(x917)))
var x920 uint64
var x921 uint64
x920, x921 = bits.Add64(x866, x901, uint64(p521Uint1(x919)))
var x922 uint64
var x923 uint64
x922, x923 = bits.Add64(x868, x903, uint64(p521Uint1(x921)))
var x924 uint64
var x925 uint64
x924, x925 = bits.Add64(x870, x905, uint64(p521Uint1(x923)))
var x926 uint64
var x927 uint64
x926, x927 = bits.Add64(x872, x907, uint64(p521Uint1(x925)))
var x928 uint64
var x929 uint64
x928, x929 = bits.Add64(x874, x909, uint64(p521Uint1(x927)))
var x930 uint64
var x931 uint64
x930, x931 = bits.Add64(x876, x911, uint64(p521Uint1(x929)))
var x932 uint64
var x933 uint64
x933, x932 = bits.Mul64(x912, 0x1ff)
var x934 uint64
var x935 uint64
x935, x934 = bits.Mul64(x912, 0xffffffffffffffff)
var x936 uint64
var x937 uint64
x937, x936 = bits.Mul64(x912, 0xffffffffffffffff)
var x938 uint64
var x939 uint64
x939, x938 = bits.Mul64(x912, 0xffffffffffffffff)
var x940 uint64
var x941 uint64
x941, x940 = bits.Mul64(x912, 0xffffffffffffffff)
var x942 uint64
var x943 uint64
x943, x942 = bits.Mul64(x912, 0xffffffffffffffff)
var x944 uint64
var x945 uint64
x945, x944 = bits.Mul64(x912, 0xffffffffffffffff)
var x946 uint64
var x947 uint64
x947, x946 = bits.Mul64(x912, 0xffffffffffffffff)
var x948 uint64
var x949 uint64
x949, x948 = bits.Mul64(x912, 0xffffffffffffffff)
var x950 uint64
var x951 uint64
x950, x951 = bits.Add64(x949, x946, uint64(0x0))
var x952 uint64
var x953 uint64
x952, x953 = bits.Add64(x947, x944, uint64(p521Uint1(x951)))
var x954 uint64
var x955 uint64
x954, x955 = bits.Add64(x945, x942, uint64(p521Uint1(x953)))
var x956 uint64
var x957 uint64
x956, x957 = bits.Add64(x943, x940, uint64(p521Uint1(x955)))
var x958 uint64
var x959 uint64
x958, x959 = bits.Add64(x941, x938, uint64(p521Uint1(x957)))
var x960 uint64
var x961 uint64
x960, x961 = bits.Add64(x939, x936, uint64(p521Uint1(x959)))
var x962 uint64
var x963 uint64
x962, x963 = bits.Add64(x937, x934, uint64(p521Uint1(x961)))
var x964 uint64
var x965 uint64
x964, x965 = bits.Add64(x935, x932, uint64(p521Uint1(x963)))
x966 := (uint64(p521Uint1(x965)) + x933)
var x968 uint64
_, x968 = bits.Add64(x912, x948, uint64(0x0))
var x969 uint64
var x970 uint64
x969, x970 = bits.Add64(x914, x950, uint64(p521Uint1(x968)))
var x971 uint64
var x972 uint64
x971, x972 = bits.Add64(x916, x952, uint64(p521Uint1(x970)))
var x973 uint64
var x974 uint64
x973, x974 = bits.Add64(x918, x954, uint64(p521Uint1(x972)))
var x975 uint64
var x976 uint64
x975, x976 = bits.Add64(x920, x956, uint64(p521Uint1(x974)))
var x977 uint64
var x978 uint64
x977, x978 = bits.Add64(x922, x958, uint64(p521Uint1(x976)))
var x979 uint64
var x980 uint64
x979, x980 = bits.Add64(x924, x960, uint64(p521Uint1(x978)))
var x981 uint64
var x982 uint64
x981, x982 = bits.Add64(x926, x962, uint64(p521Uint1(x980)))
var x983 uint64
var x984 uint64
x983, x984 = bits.Add64(x928, x964, uint64(p521Uint1(x982)))
var x985 uint64
var x986 uint64
x985, x986 = bits.Add64(x930, x966, uint64(p521Uint1(x984)))
x987 := (uint64(p521Uint1(x986)) + uint64(p521Uint1(x931)))
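// Final conditional subtraction, as in the multiplication above: subtract m
// and keep the difference only if no borrow occurred, selecting in constant
// time below.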
var x988 uint64
var x989 uint64
x988, x989 = bits.Sub64(x969, 0xffffffffffffffff, uint64(0x0))
var x990 uint64
var x991 uint64
x990, x991 = bits.Sub64(x971, 0xffffffffffffffff, uint64(p521Uint1(x989)))
var x992 uint64
var x993 uint64
x992, x993 = bits.Sub64(x973, 0xffffffffffffffff, uint64(p521Uint1(x991)))
var x994 uint64
var x995 uint64
x994, x995 = bits.Sub64(x975, 0xffffffffffffffff, uint64(p521Uint1(x993)))
var x996 uint64
var x997 uint64
x996, x997 = bits.Sub64(x977, 0xffffffffffffffff, uint64(p521Uint1(x995)))
var x998 uint64
var x999 uint64
x998, x999 = bits.Sub64(x979, 0xffffffffffffffff, uint64(p521Uint1(x997)))
var x1000 uint64
var x1001 uint64
x1000, x1001 = bits.Sub64(x981, 0xffffffffffffffff, uint64(p521Uint1(x999)))
var x1002 uint64
var x1003 uint64
x1002, x1003 = bits.Sub64(x983, 0xffffffffffffffff, uint64(p521Uint1(x1001)))
var x1004 uint64
var x1005 uint64
x1004, x1005 = bits.Sub64(x985, 0x1ff, uint64(p521Uint1(x1003)))
var x1007 uint64
_, x1007 = bits.Sub64(x987, uint64(0x0), uint64(p521Uint1(x1005)))
var x1008 uint64
p521CmovznzU64(&x1008, p521Uint1(x1007), x988, x969)
var x1009 uint64
p521CmovznzU64(&x1009, p521Uint1(x1007), x990, x971)
var x1010 uint64
p521CmovznzU64(&x1010, p521Uint1(x1007), x992, x973)
var x1011 uint64
p521CmovznzU64(&x1011, p521Uint1(x1007), x994, x975)
var x1012 uint64
p521CmovznzU64(&x1012, p521Uint1(x1007), x996, x977)
var x1013 uint64
p521CmovznzU64(&x1013, p521Uint1(x1007), x998, x979)
var x1014 uint64
p521CmovznzU64(&x1014, p521Uint1(x1007), x1000, x981)
var x1015 uint64
p521CmovznzU64(&x1015, p521Uint1(x1007), x1002, x983)
var x1016 uint64
p521CmovznzU64(&x1016, p521Uint1(x1007), x1004, x985)
out1[0] = x1008
out1[1] = x1009
out1[2] = x1010
out1[3] = x1011
out1[4] = x1012
out1[5] = x1013
out1[6] = x1014
out1[7] = x1015
out1[8] = x1016
}
// p521Add adds two field elements in the Montgomery domain.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
// 0 ≤ eval arg2 < m
//
// Postconditions:
//
// eval (from_montgomery out1) mod m = (eval (from_montgomery arg1) + eval (from_montgomery arg2)) mod m
// 0 ≤ eval out1 < m
func p521Add(out1 *p521MontgomeryDomainFieldElement, arg1 *p521MontgomeryDomainFieldElement, arg2 *p521MontgomeryDomainFieldElement) {
var x1 uint64
var x2 uint64
x1, x2 = bits.Add64(arg1[0], arg2[0], uint64(0x0))
var x3 uint64
var x4 uint64
x3, x4 = bits.Add64(arg1[1], arg2[1], uint64(p521Uint1(x2)))
var x5 uint64
var x6 uint64
x5, x6 = bits.Add64(arg1[2], arg2[2], uint64(p521Uint1(x4)))
var x7 uint64
var x8 uint64
x7, x8 = bits.Add64(arg1[3], arg2[3], uint64(p521Uint1(x6)))
var x9 uint64
var x10 uint64
x9, x10 = bits.Add64(arg1[4], arg2[4], uint64(p521Uint1(x8)))
var x11 uint64
var x12 uint64
x11, x12 = bits.Add64(arg1[5], arg2[5], uint64(p521Uint1(x10)))
var x13 uint64
var x14 uint64
x13, x14 = bits.Add64(arg1[6], arg2[6], uint64(p521Uint1(x12)))
var x15 uint64
var x16 uint64
x15, x16 = bits.Add64(arg1[7], arg2[7], uint64(p521Uint1(x14)))
var x17 uint64
var x18 uint64
x17, x18 = bits.Add64(arg1[8], arg2[8], uint64(p521Uint1(x16)))
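// x1, x3, ..., x17 now hold the raw limb-wise sum (final carry x18), which
// may equal or exceed m; the chain below trial-subtracts m.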
var x19 uint64
var x20 uint64
x19, x20 = bits.Sub64(x1, 0xffffffffffffffff, uint64(0x0))
var x21 uint64
var x22 uint64
x21, x22 = bits.Sub64(x3, 0xffffffffffffffff, uint64(p521Uint1(x20)))
var x23 uint64
var x24 uint64
x23, x24 = bits.Sub64(x5, 0xffffffffffffffff, uint64(p521Uint1(x22)))
var x25 uint64
var x26 uint64
x25, x26 = bits.Sub64(x7, 0xffffffffffffffff, uint64(p521Uint1(x24)))
var x27 uint64
var x28 uint64
x27, x28 = bits.Sub64(x9, 0xffffffffffffffff, uint64(p521Uint1(x26)))
var x29 uint64
var x30 uint64
x29, x30 = bits.Sub64(x11, 0xffffffffffffffff, uint64(p521Uint1(x28)))
var x31 uint64
var x32 uint64
x31, x32 = bits.Sub64(x13, 0xffffffffffffffff, uint64(p521Uint1(x30)))
var x33 uint64
var x34 uint64
x33, x34 = bits.Sub64(x15, 0xffffffffffffffff, uint64(p521Uint1(x32)))
var x35 uint64
var x36 uint64
x35, x36 = bits.Sub64(x17, 0x1ff, uint64(p521Uint1(x34)))
var x38 uint64
_, x38 = bits.Sub64(uint64(p521Uint1(x18)), uint64(0x0), uint64(p521Uint1(x36)))
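// x38 is the final borrow of the trial subtraction: nonzero means the sum was
// already below m, so the pre-subtraction limbs are kept; either way the
// selection below is branch-free.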
var x39 uint64
p521CmovznzU64(&x39, p521Uint1(x38), x19, x1)
var x40 uint64
p521CmovznzU64(&x40, p521Uint1(x38), x21, x3)
var x41 uint64
p521CmovznzU64(&x41, p521Uint1(x38), x23, x5)
var x42 uint64
p521CmovznzU64(&x42, p521Uint1(x38), x25, x7)
var x43 uint64
p521CmovznzU64(&x43, p521Uint1(x38), x27, x9)
var x44 uint64
p521CmovznzU64(&x44, p521Uint1(x38), x29, x11)
var x45 uint64
p521CmovznzU64(&x45, p521Uint1(x38), x31, x13)
var x46 uint64
p521CmovznzU64(&x46, p521Uint1(x38), x33, x15)
var x47 uint64
p521CmovznzU64(&x47, p521Uint1(x38), x35, x17)
out1[0] = x39
out1[1] = x40
out1[2] = x41
out1[3] = x42
out1[4] = x43
out1[5] = x44
out1[6] = x45
out1[7] = x46
out1[8] = x47
}
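// Illustrative sketch (not part of the generated code): operands and result
// live in the Montgomery domain, and aliasing out1 with an input is safe here
// because every input limb is read before out1 is written. The variable names
// below are hypothetical.
//
//	var a, b, c p521MontgomeryDomainFieldElement
//	p521Add(&c, &a, &b) // c = a + b mod m
//	p521Sub(&c, &c, &b) // c = a mod m again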
// p521Sub subtracts two field elements in the Montgomery domain.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
// 0 ≤ eval arg2 < m
//
// Postconditions:
//
// eval (from_montgomery out1) mod m = (eval (from_montgomery arg1) - eval (from_montgomery arg2)) mod m
// 0 ≤ eval out1 < m
func p521Sub(out1 *p521MontgomeryDomainFieldElement, arg1 *p521MontgomeryDomainFieldElement, arg2 *p521MontgomeryDomainFieldElement) {
var x1 uint64
var x2 uint64
x1, x2 = bits.Sub64(arg1[0], arg2[0], uint64(0x0))
var x3 uint64
var x4 uint64
x3, x4 = bits.Sub64(arg1[1], arg2[1], uint64(p521Uint1(x2)))
var x5 uint64
var x6 uint64
x5, x6 = bits.Sub64(arg1[2], arg2[2], uint64(p521Uint1(x4)))
var x7 uint64
var x8 uint64
x7, x8 = bits.Sub64(arg1[3], arg2[3], uint64(p521Uint1(x6)))
var x9 uint64
var x10 uint64
x9, x10 = bits.Sub64(arg1[4], arg2[4], uint64(p521Uint1(x8)))
var x11 uint64
var x12 uint64
x11, x12 = bits.Sub64(arg1[5], arg2[5], uint64(p521Uint1(x10)))
var x13 uint64
var x14 uint64
x13, x14 = bits.Sub64(arg1[6], arg2[6], uint64(p521Uint1(x12)))
var x15 uint64
var x16 uint64
x15, x16 = bits.Sub64(arg1[7], arg2[7], uint64(p521Uint1(x14)))
var x17 uint64
var x18 uint64
x17, x18 = bits.Sub64(arg1[8], arg2[8], uint64(p521Uint1(x16)))
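// x18 is the borrow out of the limb-wise subtraction. The cmov below expands
// it into an all-zero or all-one mask x19; the additions then add back
// (mask & m): the low eight limbs of m are all ones, so x19 is added as-is,
// and the top limb adds (x19 & 0x1ff). This re-normalizes a negative result
// without branching.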
var x19 uint64
p521CmovznzU64(&x19, p521Uint1(x18), uint64(0x0), 0xffffffffffffffff)
var x20 uint64
var x21 uint64
x20, x21 = bits.Add64(x1, x19, uint64(0x0))
var x22 uint64
var x23 uint64
x22, x23 = bits.Add64(x3, x19, uint64(p521Uint1(x21)))
var x24 uint64
var x25 uint64
x24, x25 = bits.Add64(x5, x19, uint64(p521Uint1(x23)))
var x26 uint64
var x27 uint64
x26, x27 = bits.Add64(x7, x19, uint64(p521Uint1(x25)))
var x28 uint64
var x29 uint64
x28, x29 = bits.Add64(x9, x19, uint64(p521Uint1(x27)))
var x30 uint64
var x31 uint64
x30, x31 = bits.Add64(x11, x19, uint64(p521Uint1(x29)))
var x32 uint64
var x33 uint64
x32, x33 = bits.Add64(x13, x19, uint64(p521Uint1(x31)))
var x34 uint64
var x35 uint64
x34, x35 = bits.Add64(x15, x19, uint64(p521Uint1(x33)))
var x36 uint64
x36, _ = bits.Add64(x17, (x19 & 0x1ff), uint64(p521Uint1(x35)))
out1[0] = x20
out1[1] = x22
out1[2] = x24
out1[3] = x26
out1[4] = x28
out1[5] = x30
out1[6] = x32
out1[7] = x34
out1[8] = x36
}
// p521SetOne returns the field element one in the Montgomery domain.
//
// Postconditions:
//
// eval (from_montgomery out1) mod m = 1 mod m
// 0 ≤ eval out1 < m
func p521SetOne(out1 *p521MontgomeryDomainFieldElement) {
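// One in the Montgomery domain is R mod m with R = 2^(64*9) = 2^576 and
// m = 2^521 - 1; since 2^521 ≡ 1 (mod m), R mod m = 2^55 = 0x80000000000000.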
out1[0] = 0x80000000000000
out1[1] = uint64(0x0)
out1[2] = uint64(0x0)
out1[3] = uint64(0x0)
out1[4] = uint64(0x0)
out1[5] = uint64(0x0)
out1[6] = uint64(0x0)
out1[7] = uint64(0x0)
out1[8] = uint64(0x0)
}
// p521FromMontgomery translates a field element out of the Montgomery domain.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
//
// Postconditions:
//
// eval out1 mod m = (eval arg1 * ((2^64)⁻¹ mod m)^9) mod m
// 0 ≤ eval out1 < m
func p521FromMontgomery(out1 *p521NonMontgomeryDomainFieldElement, arg1 *p521MontgomeryDomainFieldElement) {
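// Word-by-word Montgomery reduction. Because m ≡ -1 (mod 2^64), the usual
// per-word multiplier m' = -m⁻¹ mod 2^64 is 1, so each round multiplies the
// current low limb directly by the limbs of m (0x1ff and 0xffffffffffffffff),
// folds the product in, and shifts one limb out.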
x1 := arg1[0]
var x2 uint64
var x3 uint64
x3, x2 = bits.Mul64(x1, 0x1ff)
var x4 uint64
var x5 uint64
x5, x4 = bits.Mul64(x1, 0xffffffffffffffff)
var x6 uint64
var x7 uint64
x7, x6 = bits.Mul64(x1, 0xffffffffffffffff)
var x8 uint64
var x9 uint64
x9, x8 = bits.Mul64(x1, 0xffffffffffffffff)
var x10 uint64
var x11 uint64
x11, x10 = bits.Mul64(x1, 0xffffffffffffffff)
var x12 uint64
var x13 uint64
x13, x12 = bits.Mul64(x1, 0xffffffffffffffff)
var x14 uint64
var x15 uint64
x15, x14 = bits.Mul64(x1, 0xffffffffffffffff)
var x16 uint64
var x17 uint64
x17, x16 = bits.Mul64(x1, 0xffffffffffffffff)
var x18 uint64
var x19 uint64
x19, x18 = bits.Mul64(x1, 0xffffffffffffffff)
var x20 uint64
var x21 uint64
x20, x21 = bits.Add64(x19, x16, uint64(0x0))
var x22 uint64
var x23 uint64
x22, x23 = bits.Add64(x17, x14, uint64(p521Uint1(x21)))
var x24 uint64
var x25 uint64
x24, x25 = bits.Add64(x15, x12, uint64(p521Uint1(x23)))
var x26 uint64
var x27 uint64
x26, x27 = bits.Add64(x13, x10, uint64(p521Uint1(x25)))
var x28 uint64
var x29 uint64
x28, x29 = bits.Add64(x11, x8, uint64(p521Uint1(x27)))
var x30 uint64
var x31 uint64
x30, x31 = bits.Add64(x9, x6, uint64(p521Uint1(x29)))
var x32 uint64
var x33 uint64
x32, x33 = bits.Add64(x7, x4, uint64(p521Uint1(x31)))
var x34 uint64
var x35 uint64
x34, x35 = bits.Add64(x5, x2, uint64(p521Uint1(x33)))
var x37 uint64
_, x37 = bits.Add64(x1, x18, uint64(0x0))
var x38 uint64
var x39 uint64
x38, x39 = bits.Add64(uint64(0x0), x20, uint64(p521Uint1(x37)))
var x40 uint64
var x41 uint64
x40, x41 = bits.Add64(uint64(0x0), x22, uint64(p521Uint1(x39)))
var x42 uint64
var x43 uint64
x42, x43 = bits.Add64(uint64(0x0), x24, uint64(p521Uint1(x41)))
var x44 uint64
var x45 uint64
x44, x45 = bits.Add64(uint64(0x0), x26, uint64(p521Uint1(x43)))
var x46 uint64
var x47 uint64
x46, x47 = bits.Add64(uint64(0x0), x28, uint64(p521Uint1(x45)))
var x48 uint64
var x49 uint64
x48, x49 = bits.Add64(uint64(0x0), x30, uint64(p521Uint1(x47)))
var x50 uint64
var x51 uint64
x50, x51 = bits.Add64(uint64(0x0), x32, uint64(p521Uint1(x49)))
var x52 uint64
var x53 uint64
x52, x53 = bits.Add64(uint64(0x0), x34, uint64(p521Uint1(x51)))
var x54 uint64
var x55 uint64
x54, x55 = bits.Add64(x38, arg1[1], uint64(0x0))
var x56 uint64
var x57 uint64
x56, x57 = bits.Add64(x40, uint64(0x0), uint64(p521Uint1(x55)))
var x58 uint64
var x59 uint64
x58, x59 = bits.Add64(x42, uint64(0x0), uint64(p521Uint1(x57)))
var x60 uint64
var x61 uint64
x60, x61 = bits.Add64(x44, uint64(0x0), uint64(p521Uint1(x59)))
var x62 uint64
var x63 uint64
x62, x63 = bits.Add64(x46, uint64(0x0), uint64(p521Uint1(x61)))
var x64 uint64
var x65 uint64
x64, x65 = bits.Add64(x48, uint64(0x0), uint64(p521Uint1(x63)))
var x66 uint64
var x67 uint64
x66, x67 = bits.Add64(x50, uint64(0x0), uint64(p521Uint1(x65)))
var x68 uint64
var x69 uint64
x68, x69 = bits.Add64(x52, uint64(0x0), uint64(p521Uint1(x67)))
var x70 uint64
var x71 uint64
x71, x70 = bits.Mul64(x54, 0x1ff)
var x72 uint64
var x73 uint64
x73, x72 = bits.Mul64(x54, 0xffffffffffffffff)
var x74 uint64
var x75 uint64
x75, x74 = bits.Mul64(x54, 0xffffffffffffffff)
var x76 uint64
var x77 uint64
x77, x76 = bits.Mul64(x54, 0xffffffffffffffff)
var x78 uint64
var x79 uint64
x79, x78 = bits.Mul64(x54, 0xffffffffffffffff)
var x80 uint64
var x81 uint64
x81, x80 = bits.Mul64(x54, 0xffffffffffffffff)
var x82 uint64
var x83 uint64
x83, x82 = bits.Mul64(x54, 0xffffffffffffffff)
var x84 uint64
var x85 uint64
x85, x84 = bits.Mul64(x54, 0xffffffffffffffff)
var x86 uint64
var x87 uint64
x87, x86 = bits.Mul64(x54, 0xffffffffffffffff)
var x88 uint64
var x89 uint64
x88, x89 = bits.Add64(x87, x84, uint64(0x0))
var x90 uint64
var x91 uint64
x90, x91 = bits.Add64(x85, x82, uint64(p521Uint1(x89)))
var x92 uint64
var x93 uint64
x92, x93 = bits.Add64(x83, x80, uint64(p521Uint1(x91)))
var x94 uint64
var x95 uint64
x94, x95 = bits.Add64(x81, x78, uint64(p521Uint1(x93)))
var x96 uint64
var x97 uint64
x96, x97 = bits.Add64(x79, x76, uint64(p521Uint1(x95)))
var x98 uint64
var x99 uint64
x98, x99 = bits.Add64(x77, x74, uint64(p521Uint1(x97)))
var x100 uint64
var x101 uint64
x100, x101 = bits.Add64(x75, x72, uint64(p521Uint1(x99)))
var x102 uint64
var x103 uint64
x102, x103 = bits.Add64(x73, x70, uint64(p521Uint1(x101)))
var x105 uint64
_, x105 = bits.Add64(x54, x86, uint64(0x0))
var x106 uint64
var x107 uint64
x106, x107 = bits.Add64(x56, x88, uint64(p521Uint1(x105)))
var x108 uint64
var x109 uint64
x108, x109 = bits.Add64(x58, x90, uint64(p521Uint1(x107)))
var x110 uint64
var x111 uint64
x110, x111 = bits.Add64(x60, x92, uint64(p521Uint1(x109)))
var x112 uint64
var x113 uint64
x112, x113 = bits.Add64(x62, x94, uint64(p521Uint1(x111)))
var x114 uint64
var x115 uint64
x114, x115 = bits.Add64(x64, x96, uint64(p521Uint1(x113)))
var x116 uint64
var x117 uint64
x116, x117 = bits.Add64(x66, x98, uint64(p521Uint1(x115)))
var x118 uint64
var x119 uint64
x118, x119 = bits.Add64(x68, x100, uint64(p521Uint1(x117)))
var x120 uint64
var x121 uint64
x120, x121 = bits.Add64((uint64(p521Uint1(x69)) + (uint64(p521Uint1(x53)) + (uint64(p521Uint1(x35)) + x3))), x102, uint64(p521Uint1(x119)))
var x122 uint64
var x123 uint64
x122, x123 = bits.Add64(x106, arg1[2], uint64(0x0))
var x124 uint64
var x125 uint64
x124, x125 = bits.Add64(x108, uint64(0x0), uint64(p521Uint1(x123)))
var x126 uint64
var x127 uint64
x126, x127 = bits.Add64(x110, uint64(0x0), uint64(p521Uint1(x125)))
var x128 uint64
var x129 uint64
x128, x129 = bits.Add64(x112, uint64(0x0), uint64(p521Uint1(x127)))
var x130 uint64
var x131 uint64
x130, x131 = bits.Add64(x114, uint64(0x0), uint64(p521Uint1(x129)))
var x132 uint64
var x133 uint64
x132, x133 = bits.Add64(x116, uint64(0x0), uint64(p521Uint1(x131)))
var x134 uint64
var x135 uint64
x134, x135 = bits.Add64(x118, uint64(0x0), uint64(p521Uint1(x133)))
var x136 uint64
var x137 uint64
x136, x137 = bits.Add64(x120, uint64(0x0), uint64(p521Uint1(x135)))
var x138 uint64
var x139 uint64
x139, x138 = bits.Mul64(x122, 0x1ff)
var x140 uint64
var x141 uint64
x141, x140 = bits.Mul64(x122, 0xffffffffffffffff)
var x142 uint64
var x143 uint64
x143, x142 = bits.Mul64(x122, 0xffffffffffffffff)
var x144 uint64
var x145 uint64
x145, x144 = bits.Mul64(x122, 0xffffffffffffffff)
var x146 uint64
var x147 uint64
x147, x146 = bits.Mul64(x122, 0xffffffffffffffff)
var x148 uint64
var x149 uint64
x149, x148 = bits.Mul64(x122, 0xffffffffffffffff)
var x150 uint64
var x151 uint64
x151, x150 = bits.Mul64(x122, 0xffffffffffffffff)
var x152 uint64
var x153 uint64
x153, x152 = bits.Mul64(x122, 0xffffffffffffffff)
var x154 uint64
var x155 uint64
x155, x154 = bits.Mul64(x122, 0xffffffffffffffff)
var x156 uint64
var x157 uint64
x156, x157 = bits.Add64(x155, x152, uint64(0x0))
var x158 uint64
var x159 uint64
x158, x159 = bits.Add64(x153, x150, uint64(p521Uint1(x157)))
var x160 uint64
var x161 uint64
x160, x161 = bits.Add64(x151, x148, uint64(p521Uint1(x159)))
var x162 uint64
var x163 uint64
x162, x163 = bits.Add64(x149, x146, uint64(p521Uint1(x161)))
var x164 uint64
var x165 uint64
x164, x165 = bits.Add64(x147, x144, uint64(p521Uint1(x163)))
var x166 uint64
var x167 uint64
x166, x167 = bits.Add64(x145, x142, uint64(p521Uint1(x165)))
var x168 uint64
var x169 uint64
x168, x169 = bits.Add64(x143, x140, uint64(p521Uint1(x167)))
var x170 uint64
var x171 uint64
x170, x171 = bits.Add64(x141, x138, uint64(p521Uint1(x169)))
var x173 uint64
_, x173 = bits.Add64(x122, x154, uint64(0x0))
var x174 uint64
var x175 uint64
x174, x175 = bits.Add64(x124, x156, uint64(p521Uint1(x173)))
var x176 uint64
var x177 uint64
x176, x177 = bits.Add64(x126, x158, uint64(p521Uint1(x175)))
var x178 uint64
var x179 uint64
x178, x179 = bits.Add64(x128, x160, uint64(p521Uint1(x177)))
var x180 uint64
var x181 uint64
x180, x181 = bits.Add64(x130, x162, uint64(p521Uint1(x179)))
var x182 uint64
var x183 uint64
x182, x183 = bits.Add64(x132, x164, uint64(p521Uint1(x181)))
var x184 uint64
var x185 uint64
x184, x185 = bits.Add64(x134, x166, uint64(p521Uint1(x183)))
var x186 uint64
var x187 uint64
x186, x187 = bits.Add64(x136, x168, uint64(p521Uint1(x185)))
var x188 uint64
var x189 uint64
x188, x189 = bits.Add64((uint64(p521Uint1(x137)) + (uint64(p521Uint1(x121)) + (uint64(p521Uint1(x103)) + x71))), x170, uint64(p521Uint1(x187)))
var x190 uint64
var x191 uint64
x190, x191 = bits.Add64(x174, arg1[3], uint64(0x0))
var x192 uint64
var x193 uint64
x192, x193 = bits.Add64(x176, uint64(0x0), uint64(p521Uint1(x191)))
var x194 uint64
var x195 uint64
x194, x195 = bits.Add64(x178, uint64(0x0), uint64(p521Uint1(x193)))
var x196 uint64
var x197 uint64
x196, x197 = bits.Add64(x180, uint64(0x0), uint64(p521Uint1(x195)))
var x198 uint64
var x199 uint64
x198, x199 = bits.Add64(x182, uint64(0x0), uint64(p521Uint1(x197)))
var x200 uint64
var x201 uint64
x200, x201 = bits.Add64(x184, uint64(0x0), uint64(p521Uint1(x199)))
var x202 uint64
var x203 uint64
x202, x203 = bits.Add64(x186, uint64(0x0), uint64(p521Uint1(x201)))
var x204 uint64
var x205 uint64
x204, x205 = bits.Add64(x188, uint64(0x0), uint64(p521Uint1(x203)))
var x206 uint64
var x207 uint64
x207, x206 = bits.Mul64(x190, 0x1ff)
var x208 uint64
var x209 uint64
x209, x208 = bits.Mul64(x190, 0xffffffffffffffff)
var x210 uint64
var x211 uint64
x211, x210 = bits.Mul64(x190, 0xffffffffffffffff)
var x212 uint64
var x213 uint64
x213, x212 = bits.Mul64(x190, 0xffffffffffffffff)
var x214 uint64
var x215 uint64
x215, x214 = bits.Mul64(x190, 0xffffffffffffffff)
var x216 uint64
var x217 uint64
x217, x216 = bits.Mul64(x190, 0xffffffffffffffff)
var x218 uint64
var x219 uint64
x219, x218 = bits.Mul64(x190, 0xffffffffffffffff)
var x220 uint64
var x221 uint64
x221, x220 = bits.Mul64(x190, 0xffffffffffffffff)
var x222 uint64
var x223 uint64
x223, x222 = bits.Mul64(x190, 0xffffffffffffffff)
var x224 uint64
var x225 uint64
x224, x225 = bits.Add64(x223, x220, uint64(0x0))
var x226 uint64
var x227 uint64
x226, x227 = bits.Add64(x221, x218, uint64(p521Uint1(x225)))
var x228 uint64
var x229 uint64
x228, x229 = bits.Add64(x219, x216, uint64(p521Uint1(x227)))
var x230 uint64
var x231 uint64
x230, x231 = bits.Add64(x217, x214, uint64(p521Uint1(x229)))
var x232 uint64
var x233 uint64
x232, x233 = bits.Add64(x215, x212, uint64(p521Uint1(x231)))
var x234 uint64
var x235 uint64
x234, x235 = bits.Add64(x213, x210, uint64(p521Uint1(x233)))
var x236 uint64
var x237 uint64
x236, x237 = bits.Add64(x211, x208, uint64(p521Uint1(x235)))
var x238 uint64
var x239 uint64
x238, x239 = bits.Add64(x209, x206, uint64(p521Uint1(x237)))
var x241 uint64
_, x241 = bits.Add64(x190, x222, uint64(0x0))
var x242 uint64
var x243 uint64
x242, x243 = bits.Add64(x192, x224, uint64(p521Uint1(x241)))
var x244 uint64
var x245 uint64
x244, x245 = bits.Add64(x194, x226, uint64(p521Uint1(x243)))
var x246 uint64
var x247 uint64
x246, x247 = bits.Add64(x196, x228, uint64(p521Uint1(x245)))
var x248 uint64
var x249 uint64
x248, x249 = bits.Add64(x198, x230, uint64(p521Uint1(x247)))
var x250 uint64
var x251 uint64
x250, x251 = bits.Add64(x200, x232, uint64(p521Uint1(x249)))
var x252 uint64
var x253 uint64
x252, x253 = bits.Add64(x202, x234, uint64(p521Uint1(x251)))
var x254 uint64
var x255 uint64
x254, x255 = bits.Add64(x204, x236, uint64(p521Uint1(x253)))
var x256 uint64
var x257 uint64
x256, x257 = bits.Add64((uint64(p521Uint1(x205)) + (uint64(p521Uint1(x189)) + (uint64(p521Uint1(x171)) + x139))), x238, uint64(p521Uint1(x255)))
var x258 uint64
var x259 uint64
x258, x259 = bits.Add64(x242, arg1[4], uint64(0x0))
var x260 uint64
var x261 uint64
x260, x261 = bits.Add64(x244, uint64(0x0), uint64(p521Uint1(x259)))
var x262 uint64
var x263 uint64
x262, x263 = bits.Add64(x246, uint64(0x0), uint64(p521Uint1(x261)))
var x264 uint64
var x265 uint64
x264, x265 = bits.Add64(x248, uint64(0x0), uint64(p521Uint1(x263)))
var x266 uint64
var x267 uint64
x266, x267 = bits.Add64(x250, uint64(0x0), uint64(p521Uint1(x265)))
var x268 uint64
var x269 uint64
x268, x269 = bits.Add64(x252, uint64(0x0), uint64(p521Uint1(x267)))
var x270 uint64
var x271 uint64
x270, x271 = bits.Add64(x254, uint64(0x0), uint64(p521Uint1(x269)))
var x272 uint64
var x273 uint64
x272, x273 = bits.Add64(x256, uint64(0x0), uint64(p521Uint1(x271)))
var x274 uint64
var x275 uint64
x275, x274 = bits.Mul64(x258, 0x1ff)
var x276 uint64
var x277 uint64
x277, x276 = bits.Mul64(x258, 0xffffffffffffffff)
var x278 uint64
var x279 uint64
x279, x278 = bits.Mul64(x258, 0xffffffffffffffff)
var x280 uint64
var x281 uint64
x281, x280 = bits.Mul64(x258, 0xffffffffffffffff)
var x282 uint64
var x283 uint64
x283, x282 = bits.Mul64(x258, 0xffffffffffffffff)
var x284 uint64
var x285 uint64
x285, x284 = bits.Mul64(x258, 0xffffffffffffffff)
var x286 uint64
var x287 uint64
x287, x286 = bits.Mul64(x258, 0xffffffffffffffff)
var x288 uint64
var x289 uint64
x289, x288 = bits.Mul64(x258, 0xffffffffffffffff)
var x290 uint64
var x291 uint64
x291, x290 = bits.Mul64(x258, 0xffffffffffffffff)
var x292 uint64
var x293 uint64
x292, x293 = bits.Add64(x291, x288, uint64(0x0))
var x294 uint64
var x295 uint64
x294, x295 = bits.Add64(x289, x286, uint64(p521Uint1(x293)))
var x296 uint64
var x297 uint64
x296, x297 = bits.Add64(x287, x284, uint64(p521Uint1(x295)))
var x298 uint64
var x299 uint64
x298, x299 = bits.Add64(x285, x282, uint64(p521Uint1(x297)))
var x300 uint64
var x301 uint64
x300, x301 = bits.Add64(x283, x280, uint64(p521Uint1(x299)))
var x302 uint64
var x303 uint64
x302, x303 = bits.Add64(x281, x278, uint64(p521Uint1(x301)))
var x304 uint64
var x305 uint64
x304, x305 = bits.Add64(x279, x276, uint64(p521Uint1(x303)))
var x306 uint64
var x307 uint64
x306, x307 = bits.Add64(x277, x274, uint64(p521Uint1(x305)))
var x309 uint64
_, x309 = bits.Add64(x258, x290, uint64(0x0))
var x310 uint64
var x311 uint64
x310, x311 = bits.Add64(x260, x292, uint64(p521Uint1(x309)))
var x312 uint64
var x313 uint64
x312, x313 = bits.Add64(x262, x294, uint64(p521Uint1(x311)))
var x314 uint64
var x315 uint64
x314, x315 = bits.Add64(x264, x296, uint64(p521Uint1(x313)))
var x316 uint64
var x317 uint64
x316, x317 = bits.Add64(x266, x298, uint64(p521Uint1(x315)))
var x318 uint64
var x319 uint64
x318, x319 = bits.Add64(x268, x300, uint64(p521Uint1(x317)))
var x320 uint64
var x321 uint64
x320, x321 = bits.Add64(x270, x302, uint64(p521Uint1(x319)))
var x322 uint64
var x323 uint64
x322, x323 = bits.Add64(x272, x304, uint64(p521Uint1(x321)))
var x324 uint64
var x325 uint64
x324, x325 = bits.Add64((uint64(p521Uint1(x273)) + (uint64(p521Uint1(x257)) + (uint64(p521Uint1(x239)) + x207))), x306, uint64(p521Uint1(x323)))
var x326 uint64
var x327 uint64
x326, x327 = bits.Add64(x310, arg1[5], uint64(0x0))
var x328 uint64
var x329 uint64
x328, x329 = bits.Add64(x312, uint64(0x0), uint64(p521Uint1(x327)))
var x330 uint64
var x331 uint64
x330, x331 = bits.Add64(x314, uint64(0x0), uint64(p521Uint1(x329)))
var x332 uint64
var x333 uint64
x332, x333 = bits.Add64(x316, uint64(0x0), uint64(p521Uint1(x331)))
var x334 uint64
var x335 uint64
x334, x335 = bits.Add64(x318, uint64(0x0), uint64(p521Uint1(x333)))
var x336 uint64
var x337 uint64
x336, x337 = bits.Add64(x320, uint64(0x0), uint64(p521Uint1(x335)))
var x338 uint64
var x339 uint64
x338, x339 = bits.Add64(x322, uint64(0x0), uint64(p521Uint1(x337)))
var x340 uint64
var x341 uint64
x340, x341 = bits.Add64(x324, uint64(0x0), uint64(p521Uint1(x339)))
var x342 uint64
var x343 uint64
x343, x342 = bits.Mul64(x326, 0x1ff)
var x344 uint64
var x345 uint64
x345, x344 = bits.Mul64(x326, 0xffffffffffffffff)
var x346 uint64
var x347 uint64
x347, x346 = bits.Mul64(x326, 0xffffffffffffffff)
var x348 uint64
var x349 uint64
x349, x348 = bits.Mul64(x326, 0xffffffffffffffff)
var x350 uint64
var x351 uint64
x351, x350 = bits.Mul64(x326, 0xffffffffffffffff)
var x352 uint64
var x353 uint64
x353, x352 = bits.Mul64(x326, 0xffffffffffffffff)
var x354 uint64
var x355 uint64
x355, x354 = bits.Mul64(x326, 0xffffffffffffffff)
var x356 uint64
var x357 uint64
x357, x356 = bits.Mul64(x326, 0xffffffffffffffff)
var x358 uint64
var x359 uint64
x359, x358 = bits.Mul64(x326, 0xffffffffffffffff)
var x360 uint64
var x361 uint64
x360, x361 = bits.Add64(x359, x356, uint64(0x0))
var x362 uint64
var x363 uint64
x362, x363 = bits.Add64(x357, x354, uint64(p521Uint1(x361)))
var x364 uint64
var x365 uint64
x364, x365 = bits.Add64(x355, x352, uint64(p521Uint1(x363)))
var x366 uint64
var x367 uint64
x366, x367 = bits.Add64(x353, x350, uint64(p521Uint1(x365)))
var x368 uint64
var x369 uint64
x368, x369 = bits.Add64(x351, x348, uint64(p521Uint1(x367)))
var x370 uint64
var x371 uint64
x370, x371 = bits.Add64(x349, x346, uint64(p521Uint1(x369)))
var x372 uint64
var x373 uint64
x372, x373 = bits.Add64(x347, x344, uint64(p521Uint1(x371)))
var x374 uint64
var x375 uint64
x374, x375 = bits.Add64(x345, x342, uint64(p521Uint1(x373)))
var x377 uint64
_, x377 = bits.Add64(x326, x358, uint64(0x0))
var x378 uint64
var x379 uint64
x378, x379 = bits.Add64(x328, x360, uint64(p521Uint1(x377)))
var x380 uint64
var x381 uint64
x380, x381 = bits.Add64(x330, x362, uint64(p521Uint1(x379)))
var x382 uint64
var x383 uint64
x382, x383 = bits.Add64(x332, x364, uint64(p521Uint1(x381)))
var x384 uint64
var x385 uint64
x384, x385 = bits.Add64(x334, x366, uint64(p521Uint1(x383)))
var x386 uint64
var x387 uint64
x386, x387 = bits.Add64(x336, x368, uint64(p521Uint1(x385)))
var x388 uint64
var x389 uint64
x388, x389 = bits.Add64(x338, x370, uint64(p521Uint1(x387)))
var x390 uint64
var x391 uint64
x390, x391 = bits.Add64(x340, x372, uint64(p521Uint1(x389)))
var x392 uint64
var x393 uint64
x392, x393 = bits.Add64((uint64(p521Uint1(x341)) + (uint64(p521Uint1(x325)) + (uint64(p521Uint1(x307)) + x275))), x374, uint64(p521Uint1(x391)))
var x394 uint64
var x395 uint64
x394, x395 = bits.Add64(x378, arg1[6], uint64(0x0))
var x396 uint64
var x397 uint64
x396, x397 = bits.Add64(x380, uint64(0x0), uint64(p521Uint1(x395)))
var x398 uint64
var x399 uint64
x398, x399 = bits.Add64(x382, uint64(0x0), uint64(p521Uint1(x397)))
var x400 uint64
var x401 uint64
x400, x401 = bits.Add64(x384, uint64(0x0), uint64(p521Uint1(x399)))
var x402 uint64
var x403 uint64
x402, x403 = bits.Add64(x386, uint64(0x0), uint64(p521Uint1(x401)))
var x404 uint64
var x405 uint64
x404, x405 = bits.Add64(x388, uint64(0x0), uint64(p521Uint1(x403)))
var x406 uint64
var x407 uint64
x406, x407 = bits.Add64(x390, uint64(0x0), uint64(p521Uint1(x405)))
var x408 uint64
var x409 uint64
x408, x409 = bits.Add64(x392, uint64(0x0), uint64(p521Uint1(x407)))
var x410 uint64
var x411 uint64
x411, x410 = bits.Mul64(x394, 0x1ff)
var x412 uint64
var x413 uint64
x413, x412 = bits.Mul64(x394, 0xffffffffffffffff)
var x414 uint64
var x415 uint64
x415, x414 = bits.Mul64(x394, 0xffffffffffffffff)
var x416 uint64
var x417 uint64
x417, x416 = bits.Mul64(x394, 0xffffffffffffffff)
var x418 uint64
var x419 uint64
x419, x418 = bits.Mul64(x394, 0xffffffffffffffff)
var x420 uint64
var x421 uint64
x421, x420 = bits.Mul64(x394, 0xffffffffffffffff)
var x422 uint64
var x423 uint64
x423, x422 = bits.Mul64(x394, 0xffffffffffffffff)
var x424 uint64
var x425 uint64
x425, x424 = bits.Mul64(x394, 0xffffffffffffffff)
var x426 uint64
var x427 uint64
x427, x426 = bits.Mul64(x394, 0xffffffffffffffff)
var x428 uint64
var x429 uint64
x428, x429 = bits.Add64(x427, x424, uint64(0x0))
var x430 uint64
var x431 uint64
x430, x431 = bits.Add64(x425, x422, uint64(p521Uint1(x429)))
var x432 uint64
var x433 uint64
x432, x433 = bits.Add64(x423, x420, uint64(p521Uint1(x431)))
var x434 uint64
var x435 uint64
x434, x435 = bits.Add64(x421, x418, uint64(p521Uint1(x433)))
var x436 uint64
var x437 uint64
x436, x437 = bits.Add64(x419, x416, uint64(p521Uint1(x435)))
var x438 uint64
var x439 uint64
x438, x439 = bits.Add64(x417, x414, uint64(p521Uint1(x437)))
var x440 uint64
var x441 uint64
x440, x441 = bits.Add64(x415, x412, uint64(p521Uint1(x439)))
var x442 uint64
var x443 uint64
x442, x443 = bits.Add64(x413, x410, uint64(p521Uint1(x441)))
var x445 uint64
_, x445 = bits.Add64(x394, x426, uint64(0x0))
var x446 uint64
var x447 uint64
x446, x447 = bits.Add64(x396, x428, uint64(p521Uint1(x445)))
var x448 uint64
var x449 uint64
x448, x449 = bits.Add64(x398, x430, uint64(p521Uint1(x447)))
var x450 uint64
var x451 uint64
x450, x451 = bits.Add64(x400, x432, uint64(p521Uint1(x449)))
var x452 uint64
var x453 uint64
x452, x453 = bits.Add64(x402, x434, uint64(p521Uint1(x451)))
var x454 uint64
var x455 uint64
x454, x455 = bits.Add64(x404, x436, uint64(p521Uint1(x453)))
var x456 uint64
var x457 uint64
x456, x457 = bits.Add64(x406, x438, uint64(p521Uint1(x455)))
var x458 uint64
var x459 uint64
x458, x459 = bits.Add64(x408, x440, uint64(p521Uint1(x457)))
var x460 uint64
var x461 uint64
x460, x461 = bits.Add64((uint64(p521Uint1(x409)) + (uint64(p521Uint1(x393)) + (uint64(p521Uint1(x375)) + x343))), x442, uint64(p521Uint1(x459)))
var x462 uint64
var x463 uint64
x462, x463 = bits.Add64(x446, arg1[7], uint64(0x0))
var x464 uint64
var x465 uint64
x464, x465 = bits.Add64(x448, uint64(0x0), uint64(p521Uint1(x463)))
var x466 uint64
var x467 uint64
x466, x467 = bits.Add64(x450, uint64(0x0), uint64(p521Uint1(x465)))
var x468 uint64
var x469 uint64
x468, x469 = bits.Add64(x452, uint64(0x0), uint64(p521Uint1(x467)))
var x470 uint64
var x471 uint64
x470, x471 = bits.Add64(x454, uint64(0x0), uint64(p521Uint1(x469)))
var x472 uint64
var x473 uint64
x472, x473 = bits.Add64(x456, uint64(0x0), uint64(p521Uint1(x471)))
var x474 uint64
var x475 uint64
x474, x475 = bits.Add64(x458, uint64(0x0), uint64(p521Uint1(x473)))
var x476 uint64
var x477 uint64
x476, x477 = bits.Add64(x460, uint64(0x0), uint64(p521Uint1(x475)))
var x478 uint64
var x479 uint64
x479, x478 = bits.Mul64(x462, 0x1ff)
var x480 uint64
var x481 uint64
x481, x480 = bits.Mul64(x462, 0xffffffffffffffff)
var x482 uint64
var x483 uint64
x483, x482 = bits.Mul64(x462, 0xffffffffffffffff)
var x484 uint64
var x485 uint64
x485, x484 = bits.Mul64(x462, 0xffffffffffffffff)
var x486 uint64
var x487 uint64
x487, x486 = bits.Mul64(x462, 0xffffffffffffffff)
var x488 uint64
var x489 uint64
x489, x488 = bits.Mul64(x462, 0xffffffffffffffff)
var x490 uint64
var x491 uint64
x491, x490 = bits.Mul64(x462, 0xffffffffffffffff)
var x492 uint64
var x493 uint64
x493, x492 = bits.Mul64(x462, 0xffffffffffffffff)
var x494 uint64
var x495 uint64
x495, x494 = bits.Mul64(x462, 0xffffffffffffffff)
var x496 uint64
var x497 uint64
x496, x497 = bits.Add64(x495, x492, uint64(0x0))
var x498 uint64
var x499 uint64
x498, x499 = bits.Add64(x493, x490, uint64(p521Uint1(x497)))
var x500 uint64
var x501 uint64
x500, x501 = bits.Add64(x491, x488, uint64(p521Uint1(x499)))
var x502 uint64
var x503 uint64
x502, x503 = bits.Add64(x489, x486, uint64(p521Uint1(x501)))
var x504 uint64
var x505 uint64
x504, x505 = bits.Add64(x487, x484, uint64(p521Uint1(x503)))
var x506 uint64
var x507 uint64
x506, x507 = bits.Add64(x485, x482, uint64(p521Uint1(x505)))
var x508 uint64
var x509 uint64
x508, x509 = bits.Add64(x483, x480, uint64(p521Uint1(x507)))
var x510 uint64
var x511 uint64
x510, x511 = bits.Add64(x481, x478, uint64(p521Uint1(x509)))
var x513 uint64
_, x513 = bits.Add64(x462, x494, uint64(0x0))
var x514 uint64
var x515 uint64
x514, x515 = bits.Add64(x464, x496, uint64(p521Uint1(x513)))
var x516 uint64
var x517 uint64
x516, x517 = bits.Add64(x466, x498, uint64(p521Uint1(x515)))
var x518 uint64
var x519 uint64
x518, x519 = bits.Add64(x468, x500, uint64(p521Uint1(x517)))
var x520 uint64
var x521 uint64
x520, x521 = bits.Add64(x470, x502, uint64(p521Uint1(x519)))
var x522 uint64
var x523 uint64
x522, x523 = bits.Add64(x472, x504, uint64(p521Uint1(x521)))
var x524 uint64
var x525 uint64
x524, x525 = bits.Add64(x474, x506, uint64(p521Uint1(x523)))
var x526 uint64
var x527 uint64
x526, x527 = bits.Add64(x476, x508, uint64(p521Uint1(x525)))
var x528 uint64
var x529 uint64
x528, x529 = bits.Add64((uint64(p521Uint1(x477)) + (uint64(p521Uint1(x461)) + (uint64(p521Uint1(x443)) + x411))), x510, uint64(p521Uint1(x527)))
var x530 uint64
var x531 uint64
x530, x531 = bits.Add64(x514, arg1[8], uint64(0x0))
var x532 uint64
var x533 uint64
x532, x533 = bits.Add64(x516, uint64(0x0), uint64(p521Uint1(x531)))
var x534 uint64
var x535 uint64
x534, x535 = bits.Add64(x518, uint64(0x0), uint64(p521Uint1(x533)))
var x536 uint64
var x537 uint64
x536, x537 = bits.Add64(x520, uint64(0x0), uint64(p521Uint1(x535)))
var x538 uint64
var x539 uint64
x538, x539 = bits.Add64(x522, uint64(0x0), uint64(p521Uint1(x537)))
var x540 uint64
var x541 uint64
x540, x541 = bits.Add64(x524, uint64(0x0), uint64(p521Uint1(x539)))
var x542 uint64
var x543 uint64
x542, x543 = bits.Add64(x526, uint64(0x0), uint64(p521Uint1(x541)))
var x544 uint64
var x545 uint64
x544, x545 = bits.Add64(x528, uint64(0x0), uint64(p521Uint1(x543)))
var x546 uint64
var x547 uint64
x547, x546 = bits.Mul64(x530, 0x1ff)
var x548 uint64
var x549 uint64
x549, x548 = bits.Mul64(x530, 0xffffffffffffffff)
var x550 uint64
var x551 uint64
x551, x550 = bits.Mul64(x530, 0xffffffffffffffff)
var x552 uint64
var x553 uint64
x553, x552 = bits.Mul64(x530, 0xffffffffffffffff)
var x554 uint64
var x555 uint64
x555, x554 = bits.Mul64(x530, 0xffffffffffffffff)
var x556 uint64
var x557 uint64
x557, x556 = bits.Mul64(x530, 0xffffffffffffffff)
var x558 uint64
var x559 uint64
x559, x558 = bits.Mul64(x530, 0xffffffffffffffff)
var x560 uint64
var x561 uint64
x561, x560 = bits.Mul64(x530, 0xffffffffffffffff)
var x562 uint64
var x563 uint64
x563, x562 = bits.Mul64(x530, 0xffffffffffffffff)
var x564 uint64
var x565 uint64
x564, x565 = bits.Add64(x563, x560, uint64(0x0))
var x566 uint64
var x567 uint64
x566, x567 = bits.Add64(x561, x558, uint64(p521Uint1(x565)))
var x568 uint64
var x569 uint64
x568, x569 = bits.Add64(x559, x556, uint64(p521Uint1(x567)))
var x570 uint64
var x571 uint64
x570, x571 = bits.Add64(x557, x554, uint64(p521Uint1(x569)))
var x572 uint64
var x573 uint64
x572, x573 = bits.Add64(x555, x552, uint64(p521Uint1(x571)))
var x574 uint64
var x575 uint64
x574, x575 = bits.Add64(x553, x550, uint64(p521Uint1(x573)))
var x576 uint64
var x577 uint64
x576, x577 = bits.Add64(x551, x548, uint64(p521Uint1(x575)))
var x578 uint64
var x579 uint64
x578, x579 = bits.Add64(x549, x546, uint64(p521Uint1(x577)))
var x581 uint64
_, x581 = bits.Add64(x530, x562, uint64(0x0))
var x582 uint64
var x583 uint64
x582, x583 = bits.Add64(x532, x564, uint64(p521Uint1(x581)))
var x584 uint64
var x585 uint64
x584, x585 = bits.Add64(x534, x566, uint64(p521Uint1(x583)))
var x586 uint64
var x587 uint64
x586, x587 = bits.Add64(x536, x568, uint64(p521Uint1(x585)))
var x588 uint64
var x589 uint64
x588, x589 = bits.Add64(x538, x570, uint64(p521Uint1(x587)))
var x590 uint64
var x591 uint64
x590, x591 = bits.Add64(x540, x572, uint64(p521Uint1(x589)))
var x592 uint64
var x593 uint64
x592, x593 = bits.Add64(x542, x574, uint64(p521Uint1(x591)))
var x594 uint64
var x595 uint64
x594, x595 = bits.Add64(x544, x576, uint64(p521Uint1(x593)))
var x596 uint64
var x597 uint64
x596, x597 = bits.Add64((uint64(p521Uint1(x545)) + (uint64(p521Uint1(x529)) + (uint64(p521Uint1(x511)) + x479))), x578, uint64(p521Uint1(x595)))
x598 := (uint64(p521Uint1(x597)) + (uint64(p521Uint1(x579)) + x547))
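// Final conditional subtraction of m, as above, bringing the result into [0, m).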
var x599 uint64
var x600 uint64
x599, x600 = bits.Sub64(x582, 0xffffffffffffffff, uint64(0x0))
var x601 uint64
var x602 uint64
x601, x602 = bits.Sub64(x584, 0xffffffffffffffff, uint64(p521Uint1(x600)))
var x603 uint64
var x604 uint64
x603, x604 = bits.Sub64(x586, 0xffffffffffffffff, uint64(p521Uint1(x602)))
var x605 uint64
var x606 uint64
x605, x606 = bits.Sub64(x588, 0xffffffffffffffff, uint64(p521Uint1(x604)))
var x607 uint64
var x608 uint64
x607, x608 = bits.Sub64(x590, 0xffffffffffffffff, uint64(p521Uint1(x606)))
var x609 uint64
var x610 uint64
x609, x610 = bits.Sub64(x592, 0xffffffffffffffff, uint64(p521Uint1(x608)))
var x611 uint64
var x612 uint64
x611, x612 = bits.Sub64(x594, 0xffffffffffffffff, uint64(p521Uint1(x610)))
var x613 uint64
var x614 uint64
x613, x614 = bits.Sub64(x596, 0xffffffffffffffff, uint64(p521Uint1(x612)))
var x615 uint64
var x616 uint64
x615, x616 = bits.Sub64(x598, 0x1ff, uint64(p521Uint1(x614)))
var x618 uint64
_, x618 = bits.Sub64(uint64(0x0), uint64(0x0), uint64(p521Uint1(x616)))
var x619 uint64
p521CmovznzU64(&x619, p521Uint1(x618), x599, x582)
var x620 uint64
p521CmovznzU64(&x620, p521Uint1(x618), x601, x584)
var x621 uint64
p521CmovznzU64(&x621, p521Uint1(x618), x603, x586)
var x622 uint64
p521CmovznzU64(&x622, p521Uint1(x618), x605, x588)
var x623 uint64
p521CmovznzU64(&x623, p521Uint1(x618), x607, x590)
var x624 uint64
p521CmovznzU64(&x624, p521Uint1(x618), x609, x592)
var x625 uint64
p521CmovznzU64(&x625, p521Uint1(x618), x611, x594)
var x626 uint64
p521CmovznzU64(&x626, p521Uint1(x618), x613, x596)
var x627 uint64
p521CmovznzU64(&x627, p521Uint1(x618), x615, x598)
out1[0] = x619
out1[1] = x620
out1[2] = x621
out1[3] = x622
out1[4] = x623
out1[5] = x624
out1[6] = x625
out1[7] = x626
out1[8] = x627
}
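// Illustrative sketch (not part of the generated code): a round trip through
// the Montgomery domain is the identity on canonical values in [0, m). The
// names below are hypothetical.
//
//	var n, back p521NonMontgomeryDomainFieldElement
//	var mont p521MontgomeryDomainFieldElement
//	n[0] = 42
//	p521ToMontgomery(&mont, &n)
//	p521FromMontgomery(&back, &mont)
//	// back == n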
// p521ToMontgomery translates a field element into the Montgomery domain.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
//
// Postconditions:
//
// eval (from_montgomery out1) mod m = eval arg1 mod m
// 0 ≤ eval out1 < m
func p521ToMontgomery(out1 *p521MontgomeryDomainFieldElement, arg1 *p521NonMontgomeryDomainFieldElement) {
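// Conversion multiplies by R² mod m and Montgomery-reduces. For m = 2^521 - 1,
// R² mod m = 2^1152 mod m = 2^110, whose only nonzero limb is limb 1 =
// 2^46 = 0x400000000000; each arg1[i] is therefore multiplied by that single
// constant, with the partial products folded in between reduction rounds.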
var x1 uint64
var x2 uint64
x2, x1 = bits.Mul64(arg1[0], 0x400000000000)
var x3 uint64
var x4 uint64
x4, x3 = bits.Mul64(arg1[1], 0x400000000000)
var x5 uint64
var x6 uint64
x5, x6 = bits.Add64(x2, x3, uint64(0x0))
var x7 uint64
var x8 uint64
x8, x7 = bits.Mul64(x1, 0x1ff)
var x9 uint64
var x10 uint64
x10, x9 = bits.Mul64(x1, 0xffffffffffffffff)
var x11 uint64
var x12 uint64
x12, x11 = bits.Mul64(x1, 0xffffffffffffffff)
var x13 uint64
var x14 uint64
x14, x13 = bits.Mul64(x1, 0xffffffffffffffff)
var x15 uint64
var x16 uint64
x16, x15 = bits.Mul64(x1, 0xffffffffffffffff)
var x17 uint64
var x18 uint64
x18, x17 = bits.Mul64(x1, 0xffffffffffffffff)
var x19 uint64
var x20 uint64
x20, x19 = bits.Mul64(x1, 0xffffffffffffffff)
var x21 uint64
var x22 uint64
x22, x21 = bits.Mul64(x1, 0xffffffffffffffff)
var x23 uint64
var x24 uint64
x24, x23 = bits.Mul64(x1, 0xffffffffffffffff)
var x25 uint64
var x26 uint64
x25, x26 = bits.Add64(x24, x21, uint64(0x0))
var x27 uint64
var x28 uint64
x27, x28 = bits.Add64(x22, x19, uint64(p521Uint1(x26)))
var x29 uint64
var x30 uint64
x29, x30 = bits.Add64(x20, x17, uint64(p521Uint1(x28)))
var x31 uint64
var x32 uint64
x31, x32 = bits.Add64(x18, x15, uint64(p521Uint1(x30)))
var x33 uint64
var x34 uint64
x33, x34 = bits.Add64(x16, x13, uint64(p521Uint1(x32)))
var x35 uint64
var x36 uint64
x35, x36 = bits.Add64(x14, x11, uint64(p521Uint1(x34)))
var x37 uint64
var x38 uint64
x37, x38 = bits.Add64(x12, x9, uint64(p521Uint1(x36)))
var x39 uint64
var x40 uint64
x39, x40 = bits.Add64(x10, x7, uint64(p521Uint1(x38)))
var x42 uint64
_, x42 = bits.Add64(x1, x23, uint64(0x0))
var x43 uint64
var x44 uint64
x43, x44 = bits.Add64(x5, x25, uint64(p521Uint1(x42)))
var x45 uint64
var x46 uint64
x45, x46 = bits.Add64((uint64(p521Uint1(x6)) + x4), x27, uint64(p521Uint1(x44)))
var x47 uint64
var x48 uint64
x47, x48 = bits.Add64(uint64(0x0), x29, uint64(p521Uint1(x46)))
var x49 uint64
var x50 uint64
x49, x50 = bits.Add64(uint64(0x0), x31, uint64(p521Uint1(x48)))
var x51 uint64
var x52 uint64
x51, x52 = bits.Add64(uint64(0x0), x33, uint64(p521Uint1(x50)))
var x53 uint64
var x54 uint64
x53, x54 = bits.Add64(uint64(0x0), x35, uint64(p521Uint1(x52)))
var x55 uint64
var x56 uint64
x55, x56 = bits.Add64(uint64(0x0), x37, uint64(p521Uint1(x54)))
var x57 uint64
var x58 uint64
x57, x58 = bits.Add64(uint64(0x0), x39, uint64(p521Uint1(x56)))
var x59 uint64
var x60 uint64
x60, x59 = bits.Mul64(arg1[2], 0x400000000000)
var x61 uint64
var x62 uint64
x61, x62 = bits.Add64(x45, x59, uint64(0x0))
var x63 uint64
var x64 uint64
x63, x64 = bits.Add64(x47, x60, uint64(p521Uint1(x62)))
var x65 uint64
var x66 uint64
x65, x66 = bits.Add64(x49, uint64(0x0), uint64(p521Uint1(x64)))
var x67 uint64
var x68 uint64
x67, x68 = bits.Add64(x51, uint64(0x0), uint64(p521Uint1(x66)))
var x69 uint64
var x70 uint64
x69, x70 = bits.Add64(x53, uint64(0x0), uint64(p521Uint1(x68)))
var x71 uint64
var x72 uint64
x71, x72 = bits.Add64(x55, uint64(0x0), uint64(p521Uint1(x70)))
var x73 uint64
var x74 uint64
x73, x74 = bits.Add64(x57, uint64(0x0), uint64(p521Uint1(x72)))
var x75 uint64
var x76 uint64
x76, x75 = bits.Mul64(x43, 0x1ff)
var x77 uint64
var x78 uint64
x78, x77 = bits.Mul64(x43, 0xffffffffffffffff)
var x79 uint64
var x80 uint64
x80, x79 = bits.Mul64(x43, 0xffffffffffffffff)
var x81 uint64
var x82 uint64
x82, x81 = bits.Mul64(x43, 0xffffffffffffffff)
var x83 uint64
var x84 uint64
x84, x83 = bits.Mul64(x43, 0xffffffffffffffff)
var x85 uint64
var x86 uint64
x86, x85 = bits.Mul64(x43, 0xffffffffffffffff)
var x87 uint64
var x88 uint64
x88, x87 = bits.Mul64(x43, 0xffffffffffffffff)
var x89 uint64
var x90 uint64
x90, x89 = bits.Mul64(x43, 0xffffffffffffffff)
var x91 uint64
var x92 uint64
x92, x91 = bits.Mul64(x43, 0xffffffffffffffff)
var x93 uint64
var x94 uint64
x93, x94 = bits.Add64(x92, x89, uint64(0x0))
var x95 uint64
var x96 uint64
x95, x96 = bits.Add64(x90, x87, uint64(p521Uint1(x94)))
var x97 uint64
var x98 uint64
x97, x98 = bits.Add64(x88, x85, uint64(p521Uint1(x96)))
var x99 uint64
var x100 uint64
x99, x100 = bits.Add64(x86, x83, uint64(p521Uint1(x98)))
var x101 uint64
var x102 uint64
x101, x102 = bits.Add64(x84, x81, uint64(p521Uint1(x100)))
var x103 uint64
var x104 uint64
x103, x104 = bits.Add64(x82, x79, uint64(p521Uint1(x102)))
var x105 uint64
var x106 uint64
x105, x106 = bits.Add64(x80, x77, uint64(p521Uint1(x104)))
var x107 uint64
var x108 uint64
x107, x108 = bits.Add64(x78, x75, uint64(p521Uint1(x106)))
var x110 uint64
_, x110 = bits.Add64(x43, x91, uint64(0x0))
var x111 uint64
var x112 uint64
x111, x112 = bits.Add64(x61, x93, uint64(p521Uint1(x110)))
var x113 uint64
var x114 uint64
x113, x114 = bits.Add64(x63, x95, uint64(p521Uint1(x112)))
var x115 uint64
var x116 uint64
x115, x116 = bits.Add64(x65, x97, uint64(p521Uint1(x114)))
var x117 uint64
var x118 uint64
x117, x118 = bits.Add64(x67, x99, uint64(p521Uint1(x116)))
var x119 uint64
var x120 uint64
x119, x120 = bits.Add64(x69, x101, uint64(p521Uint1(x118)))
var x121 uint64
var x122 uint64
x121, x122 = bits.Add64(x71, x103, uint64(p521Uint1(x120)))
var x123 uint64
var x124 uint64
x123, x124 = bits.Add64(x73, x105, uint64(p521Uint1(x122)))
var x125 uint64
var x126 uint64
x125, x126 = bits.Add64((uint64(p521Uint1(x74)) + (uint64(p521Uint1(x58)) + (uint64(p521Uint1(x40)) + x8))), x107, uint64(p521Uint1(x124)))
var x127 uint64
var x128 uint64
x128, x127 = bits.Mul64(arg1[3], 0x400000000000)
var x129 uint64
var x130 uint64
x129, x130 = bits.Add64(x113, x127, uint64(0x0))
var x131 uint64
var x132 uint64
x131, x132 = bits.Add64(x115, x128, uint64(p521Uint1(x130)))
var x133 uint64
var x134 uint64
x133, x134 = bits.Add64(x117, uint64(0x0), uint64(p521Uint1(x132)))
var x135 uint64
var x136 uint64
x135, x136 = bits.Add64(x119, uint64(0x0), uint64(p521Uint1(x134)))
var x137 uint64
var x138 uint64
x137, x138 = bits.Add64(x121, uint64(0x0), uint64(p521Uint1(x136)))
var x139 uint64
var x140 uint64
x139, x140 = bits.Add64(x123, uint64(0x0), uint64(p521Uint1(x138)))
var x141 uint64
var x142 uint64
x141, x142 = bits.Add64(x125, uint64(0x0), uint64(p521Uint1(x140)))
var x143 uint64
var x144 uint64
x144, x143 = bits.Mul64(x111, 0x1ff)
var x145 uint64
var x146 uint64
x146, x145 = bits.Mul64(x111, 0xffffffffffffffff)
var x147 uint64
var x148 uint64
x148, x147 = bits.Mul64(x111, 0xffffffffffffffff)
var x149 uint64
var x150 uint64
x150, x149 = bits.Mul64(x111, 0xffffffffffffffff)
var x151 uint64
var x152 uint64
x152, x151 = bits.Mul64(x111, 0xffffffffffffffff)
var x153 uint64
var x154 uint64
x154, x153 = bits.Mul64(x111, 0xffffffffffffffff)
var x155 uint64
var x156 uint64
x156, x155 = bits.Mul64(x111, 0xffffffffffffffff)
var x157 uint64
var x158 uint64
x158, x157 = bits.Mul64(x111, 0xffffffffffffffff)
var x159 uint64
var x160 uint64
x160, x159 = bits.Mul64(x111, 0xffffffffffffffff)
var x161 uint64
var x162 uint64
x161, x162 = bits.Add64(x160, x157, uint64(0x0))
var x163 uint64
var x164 uint64
x163, x164 = bits.Add64(x158, x155, uint64(p521Uint1(x162)))
var x165 uint64
var x166 uint64
x165, x166 = bits.Add64(x156, x153, uint64(p521Uint1(x164)))
var x167 uint64
var x168 uint64
x167, x168 = bits.Add64(x154, x151, uint64(p521Uint1(x166)))
var x169 uint64
var x170 uint64
x169, x170 = bits.Add64(x152, x149, uint64(p521Uint1(x168)))
var x171 uint64
var x172 uint64
x171, x172 = bits.Add64(x150, x147, uint64(p521Uint1(x170)))
var x173 uint64
var x174 uint64
x173, x174 = bits.Add64(x148, x145, uint64(p521Uint1(x172)))
var x175 uint64
var x176 uint64
x175, x176 = bits.Add64(x146, x143, uint64(p521Uint1(x174)))
var x178 uint64
_, x178 = bits.Add64(x111, x159, uint64(0x0))
var x179 uint64
var x180 uint64
x179, x180 = bits.Add64(x129, x161, uint64(p521Uint1(x178)))
var x181 uint64
var x182 uint64
x181, x182 = bits.Add64(x131, x163, uint64(p521Uint1(x180)))
var x183 uint64
var x184 uint64
x183, x184 = bits.Add64(x133, x165, uint64(p521Uint1(x182)))
var x185 uint64
var x186 uint64
x185, x186 = bits.Add64(x135, x167, uint64(p521Uint1(x184)))
var x187 uint64
var x188 uint64
x187, x188 = bits.Add64(x137, x169, uint64(p521Uint1(x186)))
var x189 uint64
var x190 uint64
x189, x190 = bits.Add64(x139, x171, uint64(p521Uint1(x188)))
var x191 uint64
var x192 uint64
x191, x192 = bits.Add64(x141, x173, uint64(p521Uint1(x190)))
var x193 uint64
var x194 uint64
x193, x194 = bits.Add64((uint64(p521Uint1(x142)) + (uint64(p521Uint1(x126)) + (uint64(p521Uint1(x108)) + x76))), x175, uint64(p521Uint1(x192)))
var x195 uint64
var x196 uint64
x196, x195 = bits.Mul64(arg1[4], 0x400000000000)
var x197 uint64
var x198 uint64
x197, x198 = bits.Add64(x181, x195, uint64(0x0))
var x199 uint64
var x200 uint64
x199, x200 = bits.Add64(x183, x196, uint64(p521Uint1(x198)))
var x201 uint64
var x202 uint64
x201, x202 = bits.Add64(x185, uint64(0x0), uint64(p521Uint1(x200)))
var x203 uint64
var x204 uint64
x203, x204 = bits.Add64(x187, uint64(0x0), uint64(p521Uint1(x202)))
var x205 uint64
var x206 uint64
x205, x206 = bits.Add64(x189, uint64(0x0), uint64(p521Uint1(x204)))
var x207 uint64
var x208 uint64
x207, x208 = bits.Add64(x191, uint64(0x0), uint64(p521Uint1(x206)))
var x209 uint64
var x210 uint64
x209, x210 = bits.Add64(x193, uint64(0x0), uint64(p521Uint1(x208)))
var x211 uint64
var x212 uint64
x212, x211 = bits.Mul64(x179, 0x1ff)
var x213 uint64
var x214 uint64
x214, x213 = bits.Mul64(x179, 0xffffffffffffffff)
var x215 uint64
var x216 uint64
x216, x215 = bits.Mul64(x179, 0xffffffffffffffff)
var x217 uint64
var x218 uint64
x218, x217 = bits.Mul64(x179, 0xffffffffffffffff)
var x219 uint64
var x220 uint64
x220, x219 = bits.Mul64(x179, 0xffffffffffffffff)
var x221 uint64
var x222 uint64
x222, x221 = bits.Mul64(x179, 0xffffffffffffffff)
var x223 uint64
var x224 uint64
x224, x223 = bits.Mul64(x179, 0xffffffffffffffff)
var x225 uint64
var x226 uint64
x226, x225 = bits.Mul64(x179, 0xffffffffffffffff)
var x227 uint64
var x228 uint64
x228, x227 = bits.Mul64(x179, 0xffffffffffffffff)
var x229 uint64
var x230 uint64
x229, x230 = bits.Add64(x228, x225, uint64(0x0))
var x231 uint64
var x232 uint64
x231, x232 = bits.Add64(x226, x223, uint64(p521Uint1(x230)))
var x233 uint64
var x234 uint64
x233, x234 = bits.Add64(x224, x221, uint64(p521Uint1(x232)))
var x235 uint64
var x236 uint64
x235, x236 = bits.Add64(x222, x219, uint64(p521Uint1(x234)))
var x237 uint64
var x238 uint64
x237, x238 = bits.Add64(x220, x217, uint64(p521Uint1(x236)))
var x239 uint64
var x240 uint64
x239, x240 = bits.Add64(x218, x215, uint64(p521Uint1(x238)))
var x241 uint64
var x242 uint64
x241, x242 = bits.Add64(x216, x213, uint64(p521Uint1(x240)))
var x243 uint64
var x244 uint64
x243, x244 = bits.Add64(x214, x211, uint64(p521Uint1(x242)))
var x246 uint64
_, x246 = bits.Add64(x179, x227, uint64(0x0))
var x247 uint64
var x248 uint64
x247, x248 = bits.Add64(x197, x229, uint64(p521Uint1(x246)))
var x249 uint64
var x250 uint64
x249, x250 = bits.Add64(x199, x231, uint64(p521Uint1(x248)))
var x251 uint64
var x252 uint64
x251, x252 = bits.Add64(x201, x233, uint64(p521Uint1(x250)))
var x253 uint64
var x254 uint64
x253, x254 = bits.Add64(x203, x235, uint64(p521Uint1(x252)))
var x255 uint64
var x256 uint64
x255, x256 = bits.Add64(x205, x237, uint64(p521Uint1(x254)))
var x257 uint64
var x258 uint64
x257, x258 = bits.Add64(x207, x239, uint64(p521Uint1(x256)))
var x259 uint64
var x260 uint64
x259, x260 = bits.Add64(x209, x241, uint64(p521Uint1(x258)))
var x261 uint64
var x262 uint64
x261, x262 = bits.Add64((uint64(p521Uint1(x210)) + (uint64(p521Uint1(x194)) + (uint64(p521Uint1(x176)) + x144))), x243, uint64(p521Uint1(x260)))
var x263 uint64
var x264 uint64
x264, x263 = bits.Mul64(arg1[5], 0x400000000000)
var x265 uint64
var x266 uint64
x265, x266 = bits.Add64(x249, x263, uint64(0x0))
var x267 uint64
var x268 uint64
x267, x268 = bits.Add64(x251, x264, uint64(p521Uint1(x266)))
var x269 uint64
var x270 uint64
x269, x270 = bits.Add64(x253, uint64(0x0), uint64(p521Uint1(x268)))
var x271 uint64
var x272 uint64
x271, x272 = bits.Add64(x255, uint64(0x0), uint64(p521Uint1(x270)))
var x273 uint64
var x274 uint64
x273, x274 = bits.Add64(x257, uint64(0x0), uint64(p521Uint1(x272)))
var x275 uint64
var x276 uint64
x275, x276 = bits.Add64(x259, uint64(0x0), uint64(p521Uint1(x274)))
var x277 uint64
var x278 uint64
x277, x278 = bits.Add64(x261, uint64(0x0), uint64(p521Uint1(x276)))
var x279 uint64
var x280 uint64
x280, x279 = bits.Mul64(x247, 0x1ff)
var x281 uint64
var x282 uint64
x282, x281 = bits.Mul64(x247, 0xffffffffffffffff)
var x283 uint64
var x284 uint64
x284, x283 = bits.Mul64(x247, 0xffffffffffffffff)
var x285 uint64
var x286 uint64
x286, x285 = bits.Mul64(x247, 0xffffffffffffffff)
var x287 uint64
var x288 uint64
x288, x287 = bits.Mul64(x247, 0xffffffffffffffff)
var x289 uint64
var x290 uint64
x290, x289 = bits.Mul64(x247, 0xffffffffffffffff)
var x291 uint64
var x292 uint64
x292, x291 = bits.Mul64(x247, 0xffffffffffffffff)
var x293 uint64
var x294 uint64
x294, x293 = bits.Mul64(x247, 0xffffffffffffffff)
var x295 uint64
var x296 uint64
x296, x295 = bits.Mul64(x247, 0xffffffffffffffff)
var x297 uint64
var x298 uint64
x297, x298 = bits.Add64(x296, x293, uint64(0x0))
var x299 uint64
var x300 uint64
x299, x300 = bits.Add64(x294, x291, uint64(p521Uint1(x298)))
var x301 uint64
var x302 uint64
x301, x302 = bits.Add64(x292, x289, uint64(p521Uint1(x300)))
var x303 uint64
var x304 uint64
x303, x304 = bits.Add64(x290, x287, uint64(p521Uint1(x302)))
var x305 uint64
var x306 uint64
x305, x306 = bits.Add64(x288, x285, uint64(p521Uint1(x304)))
var x307 uint64
var x308 uint64
x307, x308 = bits.Add64(x286, x283, uint64(p521Uint1(x306)))
var x309 uint64
var x310 uint64
x309, x310 = bits.Add64(x284, x281, uint64(p521Uint1(x308)))
var x311 uint64
var x312 uint64
x311, x312 = bits.Add64(x282, x279, uint64(p521Uint1(x310)))
var x314 uint64
_, x314 = bits.Add64(x247, x295, uint64(0x0))
var x315 uint64
var x316 uint64
x315, x316 = bits.Add64(x265, x297, uint64(p521Uint1(x314)))
var x317 uint64
var x318 uint64
x317, x318 = bits.Add64(x267, x299, uint64(p521Uint1(x316)))
var x319 uint64
var x320 uint64
x319, x320 = bits.Add64(x269, x301, uint64(p521Uint1(x318)))
var x321 uint64
var x322 uint64
x321, x322 = bits.Add64(x271, x303, uint64(p521Uint1(x320)))
var x323 uint64
var x324 uint64
x323, x324 = bits.Add64(x273, x305, uint64(p521Uint1(x322)))
var x325 uint64
var x326 uint64
x325, x326 = bits.Add64(x275, x307, uint64(p521Uint1(x324)))
var x327 uint64
var x328 uint64
x327, x328 = bits.Add64(x277, x309, uint64(p521Uint1(x326)))
var x329 uint64
var x330 uint64
x329, x330 = bits.Add64((uint64(p521Uint1(x278)) + (uint64(p521Uint1(x262)) + (uint64(p521Uint1(x244)) + x212))), x311, uint64(p521Uint1(x328)))
var x331 uint64
var x332 uint64
x332, x331 = bits.Mul64(arg1[6], 0x400000000000)
var x333 uint64
var x334 uint64
x333, x334 = bits.Add64(x317, x331, uint64(0x0))
var x335 uint64
var x336 uint64
x335, x336 = bits.Add64(x319, x332, uint64(p521Uint1(x334)))
var x337 uint64
var x338 uint64
x337, x338 = bits.Add64(x321, uint64(0x0), uint64(p521Uint1(x336)))
var x339 uint64
var x340 uint64
x339, x340 = bits.Add64(x323, uint64(0x0), uint64(p521Uint1(x338)))
var x341 uint64
var x342 uint64
x341, x342 = bits.Add64(x325, uint64(0x0), uint64(p521Uint1(x340)))
var x343 uint64
var x344 uint64
x343, x344 = bits.Add64(x327, uint64(0x0), uint64(p521Uint1(x342)))
var x345 uint64
var x346 uint64
x345, x346 = bits.Add64(x329, uint64(0x0), uint64(p521Uint1(x344)))
var x347 uint64
var x348 uint64
x348, x347 = bits.Mul64(x315, 0x1ff)
var x349 uint64
var x350 uint64
x350, x349 = bits.Mul64(x315, 0xffffffffffffffff)
var x351 uint64
var x352 uint64
x352, x351 = bits.Mul64(x315, 0xffffffffffffffff)
var x353 uint64
var x354 uint64
x354, x353 = bits.Mul64(x315, 0xffffffffffffffff)
var x355 uint64
var x356 uint64
x356, x355 = bits.Mul64(x315, 0xffffffffffffffff)
var x357 uint64
var x358 uint64
x358, x357 = bits.Mul64(x315, 0xffffffffffffffff)
var x359 uint64
var x360 uint64
x360, x359 = bits.Mul64(x315, 0xffffffffffffffff)
var x361 uint64
var x362 uint64
x362, x361 = bits.Mul64(x315, 0xffffffffffffffff)
var x363 uint64
var x364 uint64
x364, x363 = bits.Mul64(x315, 0xffffffffffffffff)
var x365 uint64
var x366 uint64
x365, x366 = bits.Add64(x364, x361, uint64(0x0))
var x367 uint64
var x368 uint64
x367, x368 = bits.Add64(x362, x359, uint64(p521Uint1(x366)))
var x369 uint64
var x370 uint64
x369, x370 = bits.Add64(x360, x357, uint64(p521Uint1(x368)))
var x371 uint64
var x372 uint64
x371, x372 = bits.Add64(x358, x355, uint64(p521Uint1(x370)))
var x373 uint64
var x374 uint64
x373, x374 = bits.Add64(x356, x353, uint64(p521Uint1(x372)))
var x375 uint64
var x376 uint64
x375, x376 = bits.Add64(x354, x351, uint64(p521Uint1(x374)))
var x377 uint64
var x378 uint64
x377, x378 = bits.Add64(x352, x349, uint64(p521Uint1(x376)))
var x379 uint64
var x380 uint64
x379, x380 = bits.Add64(x350, x347, uint64(p521Uint1(x378)))
var x382 uint64
_, x382 = bits.Add64(x315, x363, uint64(0x0))
var x383 uint64
var x384 uint64
x383, x384 = bits.Add64(x333, x365, uint64(p521Uint1(x382)))
var x385 uint64
var x386 uint64
x385, x386 = bits.Add64(x335, x367, uint64(p521Uint1(x384)))
var x387 uint64
var x388 uint64
x387, x388 = bits.Add64(x337, x369, uint64(p521Uint1(x386)))
var x389 uint64
var x390 uint64
x389, x390 = bits.Add64(x339, x371, uint64(p521Uint1(x388)))
var x391 uint64
var x392 uint64
x391, x392 = bits.Add64(x341, x373, uint64(p521Uint1(x390)))
var x393 uint64
var x394 uint64
x393, x394 = bits.Add64(x343, x375, uint64(p521Uint1(x392)))
var x395 uint64
var x396 uint64
x395, x396 = bits.Add64(x345, x377, uint64(p521Uint1(x394)))
var x397 uint64
var x398 uint64
x397, x398 = bits.Add64((uint64(p521Uint1(x346)) + (uint64(p521Uint1(x330)) + (uint64(p521Uint1(x312)) + x280))), x379, uint64(p521Uint1(x396)))
var x399 uint64
var x400 uint64
x400, x399 = bits.Mul64(arg1[7], 0x400000000000)
var x401 uint64
var x402 uint64
x401, x402 = bits.Add64(x385, x399, uint64(0x0))
var x403 uint64
var x404 uint64
x403, x404 = bits.Add64(x387, x400, uint64(p521Uint1(x402)))
var x405 uint64
var x406 uint64
x405, x406 = bits.Add64(x389, uint64(0x0), uint64(p521Uint1(x404)))
var x407 uint64
var x408 uint64
x407, x408 = bits.Add64(x391, uint64(0x0), uint64(p521Uint1(x406)))
var x409 uint64
var x410 uint64
x409, x410 = bits.Add64(x393, uint64(0x0), uint64(p521Uint1(x408)))
var x411 uint64
var x412 uint64
x411, x412 = bits.Add64(x395, uint64(0x0), uint64(p521Uint1(x410)))
var x413 uint64
var x414 uint64
x413, x414 = bits.Add64(x397, uint64(0x0), uint64(p521Uint1(x412)))
var x415 uint64
var x416 uint64
x416, x415 = bits.Mul64(x383, 0x1ff)
var x417 uint64
var x418 uint64
x418, x417 = bits.Mul64(x383, 0xffffffffffffffff)
var x419 uint64
var x420 uint64
x420, x419 = bits.Mul64(x383, 0xffffffffffffffff)
var x421 uint64
var x422 uint64
x422, x421 = bits.Mul64(x383, 0xffffffffffffffff)
var x423 uint64
var x424 uint64
x424, x423 = bits.Mul64(x383, 0xffffffffffffffff)
var x425 uint64
var x426 uint64
x426, x425 = bits.Mul64(x383, 0xffffffffffffffff)
var x427 uint64
var x428 uint64
x428, x427 = bits.Mul64(x383, 0xffffffffffffffff)
var x429 uint64
var x430 uint64
x430, x429 = bits.Mul64(x383, 0xffffffffffffffff)
var x431 uint64
var x432 uint64
x432, x431 = bits.Mul64(x383, 0xffffffffffffffff)
var x433 uint64
var x434 uint64
x433, x434 = bits.Add64(x432, x429, uint64(0x0))
var x435 uint64
var x436 uint64
x435, x436 = bits.Add64(x430, x427, uint64(p521Uint1(x434)))
var x437 uint64
var x438 uint64
x437, x438 = bits.Add64(x428, x425, uint64(p521Uint1(x436)))
var x439 uint64
var x440 uint64
x439, x440 = bits.Add64(x426, x423, uint64(p521Uint1(x438)))
var x441 uint64
var x442 uint64
x441, x442 = bits.Add64(x424, x421, uint64(p521Uint1(x440)))
var x443 uint64
var x444 uint64
x443, x444 = bits.Add64(x422, x419, uint64(p521Uint1(x442)))
var x445 uint64
var x446 uint64
x445, x446 = bits.Add64(x420, x417, uint64(p521Uint1(x444)))
var x447 uint64
var x448 uint64
x447, x448 = bits.Add64(x418, x415, uint64(p521Uint1(x446)))
var x450 uint64
_, x450 = bits.Add64(x383, x431, uint64(0x0))
var x451 uint64
var x452 uint64
x451, x452 = bits.Add64(x401, x433, uint64(p521Uint1(x450)))
var x453 uint64
var x454 uint64
x453, x454 = bits.Add64(x403, x435, uint64(p521Uint1(x452)))
var x455 uint64
var x456 uint64
x455, x456 = bits.Add64(x405, x437, uint64(p521Uint1(x454)))
var x457 uint64
var x458 uint64
x457, x458 = bits.Add64(x407, x439, uint64(p521Uint1(x456)))
var x459 uint64
var x460 uint64
x459, x460 = bits.Add64(x409, x441, uint64(p521Uint1(x458)))
var x461 uint64
var x462 uint64
x461, x462 = bits.Add64(x411, x443, uint64(p521Uint1(x460)))
var x463 uint64
var x464 uint64
x463, x464 = bits.Add64(x413, x445, uint64(p521Uint1(x462)))
var x465 uint64
var x466 uint64
x465, x466 = bits.Add64((uint64(p521Uint1(x414)) + (uint64(p521Uint1(x398)) + (uint64(p521Uint1(x380)) + x348))), x447, uint64(p521Uint1(x464)))
var x467 uint64
var x468 uint64
x468, x467 = bits.Mul64(arg1[8], 0x400000000000)
var x469 uint64
var x470 uint64
x469, x470 = bits.Add64(x453, x467, uint64(0x0))
var x471 uint64
var x472 uint64
x471, x472 = bits.Add64(x455, x468, uint64(p521Uint1(x470)))
var x473 uint64
var x474 uint64
x473, x474 = bits.Add64(x457, uint64(0x0), uint64(p521Uint1(x472)))
var x475 uint64
var x476 uint64
x475, x476 = bits.Add64(x459, uint64(0x0), uint64(p521Uint1(x474)))
var x477 uint64
var x478 uint64
x477, x478 = bits.Add64(x461, uint64(0x0), uint64(p521Uint1(x476)))
var x479 uint64
var x480 uint64
x479, x480 = bits.Add64(x463, uint64(0x0), uint64(p521Uint1(x478)))
var x481 uint64
var x482 uint64
x481, x482 = bits.Add64(x465, uint64(0x0), uint64(p521Uint1(x480)))
var x483 uint64
var x484 uint64
x484, x483 = bits.Mul64(x451, 0x1ff)
var x485 uint64
var x486 uint64
x486, x485 = bits.Mul64(x451, 0xffffffffffffffff)
var x487 uint64
var x488 uint64
x488, x487 = bits.Mul64(x451, 0xffffffffffffffff)
var x489 uint64
var x490 uint64
x490, x489 = bits.Mul64(x451, 0xffffffffffffffff)
var x491 uint64
var x492 uint64
x492, x491 = bits.Mul64(x451, 0xffffffffffffffff)
var x493 uint64
var x494 uint64
x494, x493 = bits.Mul64(x451, 0xffffffffffffffff)
var x495 uint64
var x496 uint64
x496, x495 = bits.Mul64(x451, 0xffffffffffffffff)
var x497 uint64
var x498 uint64
x498, x497 = bits.Mul64(x451, 0xffffffffffffffff)
var x499 uint64
var x500 uint64
x500, x499 = bits.Mul64(x451, 0xffffffffffffffff)
var x501 uint64
var x502 uint64
x501, x502 = bits.Add64(x500, x497, uint64(0x0))
var x503 uint64
var x504 uint64
x503, x504 = bits.Add64(x498, x495, uint64(p521Uint1(x502)))
var x505 uint64
var x506 uint64
x505, x506 = bits.Add64(x496, x493, uint64(p521Uint1(x504)))
var x507 uint64
var x508 uint64
x507, x508 = bits.Add64(x494, x491, uint64(p521Uint1(x506)))
var x509 uint64
var x510 uint64
x509, x510 = bits.Add64(x492, x489, uint64(p521Uint1(x508)))
var x511 uint64
var x512 uint64
x511, x512 = bits.Add64(x490, x487, uint64(p521Uint1(x510)))
var x513 uint64
var x514 uint64
x513, x514 = bits.Add64(x488, x485, uint64(p521Uint1(x512)))
var x515 uint64
var x516 uint64
x515, x516 = bits.Add64(x486, x483, uint64(p521Uint1(x514)))
var x518 uint64
_, x518 = bits.Add64(x451, x499, uint64(0x0))
var x519 uint64
var x520 uint64
x519, x520 = bits.Add64(x469, x501, uint64(p521Uint1(x518)))
var x521 uint64
var x522 uint64
x521, x522 = bits.Add64(x471, x503, uint64(p521Uint1(x520)))
var x523 uint64
var x524 uint64
x523, x524 = bits.Add64(x473, x505, uint64(p521Uint1(x522)))
var x525 uint64
var x526 uint64
x525, x526 = bits.Add64(x475, x507, uint64(p521Uint1(x524)))
var x527 uint64
var x528 uint64
x527, x528 = bits.Add64(x477, x509, uint64(p521Uint1(x526)))
var x529 uint64
var x530 uint64
x529, x530 = bits.Add64(x479, x511, uint64(p521Uint1(x528)))
var x531 uint64
var x532 uint64
x531, x532 = bits.Add64(x481, x513, uint64(p521Uint1(x530)))
var x533 uint64
var x534 uint64
x533, x534 = bits.Add64((uint64(p521Uint1(x482)) + (uint64(p521Uint1(x466)) + (uint64(p521Uint1(x448)) + x416))), x515, uint64(p521Uint1(x532)))
x535 := (uint64(p521Uint1(x534)) + (uint64(p521Uint1(x516)) + x484))
var x536 uint64
var x537 uint64
x536, x537 = bits.Sub64(x519, 0xffffffffffffffff, uint64(0x0))
var x538 uint64
var x539 uint64
x538, x539 = bits.Sub64(x521, 0xffffffffffffffff, uint64(p521Uint1(x537)))
var x540 uint64
var x541 uint64
x540, x541 = bits.Sub64(x523, 0xffffffffffffffff, uint64(p521Uint1(x539)))
var x542 uint64
var x543 uint64
x542, x543 = bits.Sub64(x525, 0xffffffffffffffff, uint64(p521Uint1(x541)))
var x544 uint64
var x545 uint64
x544, x545 = bits.Sub64(x527, 0xffffffffffffffff, uint64(p521Uint1(x543)))
var x546 uint64
var x547 uint64
x546, x547 = bits.Sub64(x529, 0xffffffffffffffff, uint64(p521Uint1(x545)))
var x548 uint64
var x549 uint64
x548, x549 = bits.Sub64(x531, 0xffffffffffffffff, uint64(p521Uint1(x547)))
var x550 uint64
var x551 uint64
x550, x551 = bits.Sub64(x533, 0xffffffffffffffff, uint64(p521Uint1(x549)))
var x552 uint64
var x553 uint64
x552, x553 = bits.Sub64(x535, 0x1ff, uint64(p521Uint1(x551)))
var x555 uint64
_, x555 = bits.Sub64(uint64(0x0), uint64(0x0), uint64(p521Uint1(x553)))
var x556 uint64
p521CmovznzU64(&x556, p521Uint1(x555), x536, x519)
var x557 uint64
p521CmovznzU64(&x557, p521Uint1(x555), x538, x521)
var x558 uint64
p521CmovznzU64(&x558, p521Uint1(x555), x540, x523)
var x559 uint64
p521CmovznzU64(&x559, p521Uint1(x555), x542, x525)
var x560 uint64
p521CmovznzU64(&x560, p521Uint1(x555), x544, x527)
var x561 uint64
p521CmovznzU64(&x561, p521Uint1(x555), x546, x529)
var x562 uint64
p521CmovznzU64(&x562, p521Uint1(x555), x548, x531)
var x563 uint64
p521CmovznzU64(&x563, p521Uint1(x555), x550, x533)
var x564 uint64
p521CmovznzU64(&x564, p521Uint1(x555), x552, x535)
out1[0] = x556
out1[1] = x557
out1[2] = x558
out1[3] = x559
out1[4] = x560
out1[5] = x561
out1[6] = x562
out1[7] = x563
out1[8] = x564
}
// p521Selectznz is a multi-limb conditional select.
//
// Postconditions:
//
// eval out1 = (if arg1 = 0 then eval arg2 else eval arg3)
//
// Input Bounds:
//
// arg1: [0x0 ~> 0x1]
// arg2: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff]]
// arg3: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff]]
//
// Output Bounds:
//
// out1: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff]]
func p521Selectznz(out1 *[9]uint64, arg1 p521Uint1, arg2 *[9]uint64, arg3 *[9]uint64) {
var x1 uint64
p521CmovznzU64(&x1, arg1, arg2[0], arg3[0])
var x2 uint64
p521CmovznzU64(&x2, arg1, arg2[1], arg3[1])
var x3 uint64
p521CmovznzU64(&x3, arg1, arg2[2], arg3[2])
var x4 uint64
p521CmovznzU64(&x4, arg1, arg2[3], arg3[3])
var x5 uint64
p521CmovznzU64(&x5, arg1, arg2[4], arg3[4])
var x6 uint64
p521CmovznzU64(&x6, arg1, arg2[5], arg3[5])
var x7 uint64
p521CmovznzU64(&x7, arg1, arg2[6], arg3[6])
var x8 uint64
p521CmovznzU64(&x8, arg1, arg2[7], arg3[7])
var x9 uint64
p521CmovznzU64(&x9, arg1, arg2[8], arg3[8])
out1[0] = x1
out1[1] = x2
out1[2] = x3
out1[3] = x4
out1[4] = x5
out1[5] = x6
out1[6] = x7
out1[7] = x8
out1[8] = x9
}
// p521ToBytes serializes a field element NOT in the Montgomery domain to bytes in little-endian order.
//
// Preconditions:
//
// 0 ≤ eval arg1 < m
//
// Postconditions:
//
// out1 = map (λ x, ⌊((eval arg1 mod m) mod 2^(8 * (x + 1))) / 2^(8 * x)⌋) [0..65]
//
// Input Bounds:
//
// arg1: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0x1ff]]
//
// Output Bounds:
//
// out1: [[0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0x1]]
func p521ToBytes(out1 *[66]uint8, arg1 *[9]uint64) {
x1 := arg1[8]
x2 := arg1[7]
x3 := arg1[6]
x4 := arg1[5]
x5 := arg1[4]
x6 := arg1[3]
x7 := arg1[2]
x8 := arg1[1]
x9 := arg1[0]
x10 := (uint8(x9) & 0xff)
x11 := (x9 >> 8)
x12 := (uint8(x11) & 0xff)
x13 := (x11 >> 8)
x14 := (uint8(x13) & 0xff)
x15 := (x13 >> 8)
x16 := (uint8(x15) & 0xff)
x17 := (x15 >> 8)
x18 := (uint8(x17) & 0xff)
x19 := (x17 >> 8)
x20 := (uint8(x19) & 0xff)
x21 := (x19 >> 8)
x22 := (uint8(x21) & 0xff)
x23 := uint8((x21 >> 8))
x24 := (uint8(x8) & 0xff)
x25 := (x8 >> 8)
x26 := (uint8(x25) & 0xff)
x27 := (x25 >> 8)
x28 := (uint8(x27) & 0xff)
x29 := (x27 >> 8)
x30 := (uint8(x29) & 0xff)
x31 := (x29 >> 8)
x32 := (uint8(x31) & 0xff)
x33 := (x31 >> 8)
x34 := (uint8(x33) & 0xff)
x35 := (x33 >> 8)
x36 := (uint8(x35) & 0xff)
x37 := uint8((x35 >> 8))
x38 := (uint8(x7) & 0xff)
x39 := (x7 >> 8)
x40 := (uint8(x39) & 0xff)
x41 := (x39 >> 8)
x42 := (uint8(x41) & 0xff)
x43 := (x41 >> 8)
x44 := (uint8(x43) & 0xff)
x45 := (x43 >> 8)
x46 := (uint8(x45) & 0xff)
x47 := (x45 >> 8)
x48 := (uint8(x47) & 0xff)
x49 := (x47 >> 8)
x50 := (uint8(x49) & 0xff)
x51 := uint8((x49 >> 8))
x52 := (uint8(x6) & 0xff)
x53 := (x6 >> 8)
x54 := (uint8(x53) & 0xff)
x55 := (x53 >> 8)
x56 := (uint8(x55) & 0xff)
x57 := (x55 >> 8)
x58 := (uint8(x57) & 0xff)
x59 := (x57 >> 8)
x60 := (uint8(x59) & 0xff)
x61 := (x59 >> 8)
x62 := (uint8(x61) & 0xff)
x63 := (x61 >> 8)
x64 := (uint8(x63) & 0xff)
x65 := uint8((x63 >> 8))
x66 := (uint8(x5) & 0xff)
x67 := (x5 >> 8)
x68 := (uint8(x67) & 0xff)
x69 := (x67 >> 8)
x70 := (uint8(x69) & 0xff)
x71 := (x69 >> 8)
x72 := (uint8(x71) & 0xff)
x73 := (x71 >> 8)
x74 := (uint8(x73) & 0xff)
x75 := (x73 >> 8)
x76 := (uint8(x75) & 0xff)
x77 := (x75 >> 8)
x78 := (uint8(x77) & 0xff)
x79 := uint8((x77 >> 8))
x80 := (uint8(x4) & 0xff)
x81 := (x4 >> 8)
x82 := (uint8(x81) & 0xff)
x83 := (x81 >> 8)
x84 := (uint8(x83) & 0xff)
x85 := (x83 >> 8)
x86 := (uint8(x85) & 0xff)
x87 := (x85 >> 8)
x88 := (uint8(x87) & 0xff)
x89 := (x87 >> 8)
x90 := (uint8(x89) & 0xff)
x91 := (x89 >> 8)
x92 := (uint8(x91) & 0xff)
x93 := uint8((x91 >> 8))
x94 := (uint8(x3) & 0xff)
x95 := (x3 >> 8)
x96 := (uint8(x95) & 0xff)
x97 := (x95 >> 8)
x98 := (uint8(x97) & 0xff)
x99 := (x97 >> 8)
x100 := (uint8(x99) & 0xff)
x101 := (x99 >> 8)
x102 := (uint8(x101) & 0xff)
x103 := (x101 >> 8)
x104 := (uint8(x103) & 0xff)
x105 := (x103 >> 8)
x106 := (uint8(x105) & 0xff)
x107 := uint8((x105 >> 8))
x108 := (uint8(x2) & 0xff)
x109 := (x2 >> 8)
x110 := (uint8(x109) & 0xff)
x111 := (x109 >> 8)
x112 := (uint8(x111) & 0xff)
x113 := (x111 >> 8)
x114 := (uint8(x113) & 0xff)
x115 := (x113 >> 8)
x116 := (uint8(x115) & 0xff)
x117 := (x115 >> 8)
x118 := (uint8(x117) & 0xff)
x119 := (x117 >> 8)
x120 := (uint8(x119) & 0xff)
x121 := uint8((x119 >> 8))
x122 := (uint8(x1) & 0xff)
x123 := p521Uint1((x1 >> 8))
out1[0] = x10
out1[1] = x12
out1[2] = x14
out1[3] = x16
out1[4] = x18
out1[5] = x20
out1[6] = x22
out1[7] = x23
out1[8] = x24
out1[9] = x26
out1[10] = x28
out1[11] = x30
out1[12] = x32
out1[13] = x34
out1[14] = x36
out1[15] = x37
out1[16] = x38
out1[17] = x40
out1[18] = x42
out1[19] = x44
out1[20] = x46
out1[21] = x48
out1[22] = x50
out1[23] = x51
out1[24] = x52
out1[25] = x54
out1[26] = x56
out1[27] = x58
out1[28] = x60
out1[29] = x62
out1[30] = x64
out1[31] = x65
out1[32] = x66
out1[33] = x68
out1[34] = x70
out1[35] = x72
out1[36] = x74
out1[37] = x76
out1[38] = x78
out1[39] = x79
out1[40] = x80
out1[41] = x82
out1[42] = x84
out1[43] = x86
out1[44] = x88
out1[45] = x90
out1[46] = x92
out1[47] = x93
out1[48] = x94
out1[49] = x96
out1[50] = x98
out1[51] = x100
out1[52] = x102
out1[53] = x104
out1[54] = x106
out1[55] = x107
out1[56] = x108
out1[57] = x110
out1[58] = x112
out1[59] = x114
out1[60] = x116
out1[61] = x118
out1[62] = x120
out1[63] = x121
out1[64] = x122
out1[65] = uint8(x123)
}
// p521FromBytes deserializes a field element NOT in the Montgomery domain from bytes in little-endian order.
//
// Preconditions:
//
// 0 ≤ bytes_eval arg1 < m
//
// Postconditions:
//
// eval out1 mod m = bytes_eval arg1 mod m
// 0 ≤ eval out1 < m
//
// Input Bounds:
//
// arg1: [[0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0xff], [0x0 ~> 0x1]]
//
// Output Bounds:
//
// out1: [[0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0xffffffffffffffff], [0x0 ~> 0x1ff]]
func p521FromBytes(out1 *[9]uint64, arg1 *[66]uint8) {
x1 := (uint64(p521Uint1(arg1[65])) << 8)
x2 := arg1[64]
x3 := (uint64(arg1[63]) << 56)
x4 := (uint64(arg1[62]) << 48)
x5 := (uint64(arg1[61]) << 40)
x6 := (uint64(arg1[60]) << 32)
x7 := (uint64(arg1[59]) << 24)
x8 := (uint64(arg1[58]) << 16)
x9 := (uint64(arg1[57]) << 8)
x10 := arg1[56]
x11 := (uint64(arg1[55]) << 56)
x12 := (uint64(arg1[54]) << 48)
x13 := (uint64(arg1[53]) << 40)
x14 := (uint64(arg1[52]) << 32)
x15 := (uint64(arg1[51]) << 24)
x16 := (uint64(arg1[50]) << 16)
x17 := (uint64(arg1[49]) << 8)
x18 := arg1[48]
x19 := (uint64(arg1[47]) << 56)
x20 := (uint64(arg1[46]) << 48)
x21 := (uint64(arg1[45]) << 40)
x22 := (uint64(arg1[44]) << 32)
x23 := (uint64(arg1[43]) << 24)
x24 := (uint64(arg1[42]) << 16)
x25 := (uint64(arg1[41]) << 8)
x26 := arg1[40]
x27 := (uint64(arg1[39]) << 56)
x28 := (uint64(arg1[38]) << 48)
x29 := (uint64(arg1[37]) << 40)
x30 := (uint64(arg1[36]) << 32)
x31 := (uint64(arg1[35]) << 24)
x32 := (uint64(arg1[34]) << 16)
x33 := (uint64(arg1[33]) << 8)
x34 := arg1[32]
x35 := (uint64(arg1[31]) << 56)
x36 := (uint64(arg1[30]) << 48)
x37 := (uint64(arg1[29]) << 40)
x38 := (uint64(arg1[28]) << 32)
x39 := (uint64(arg1[27]) << 24)
x40 := (uint64(arg1[26]) << 16)
x41 := (uint64(arg1[25]) << 8)
x42 := arg1[24]
x43 := (uint64(arg1[23]) << 56)
x44 := (uint64(arg1[22]) << 48)
x45 := (uint64(arg1[21]) << 40)
x46 := (uint64(arg1[20]) << 32)
x47 := (uint64(arg1[19]) << 24)
x48 := (uint64(arg1[18]) << 16)
x49 := (uint64(arg1[17]) << 8)
x50 := arg1[16]
x51 := (uint64(arg1[15]) << 56)
x52 := (uint64(arg1[14]) << 48)
x53 := (uint64(arg1[13]) << 40)
x54 := (uint64(arg1[12]) << 32)
x55 := (uint64(arg1[11]) << 24)
x56 := (uint64(arg1[10]) << 16)
x57 := (uint64(arg1[9]) << 8)
x58 := arg1[8]
x59 := (uint64(arg1[7]) << 56)
x60 := (uint64(arg1[6]) << 48)
x61 := (uint64(arg1[5]) << 40)
x62 := (uint64(arg1[4]) << 32)
x63 := (uint64(arg1[3]) << 24)
x64 := (uint64(arg1[2]) << 16)
x65 := (uint64(arg1[1]) << 8)
x66 := arg1[0]
x67 := (x65 + uint64(x66))
x68 := (x64 + x67)
x69 := (x63 + x68)
x70 := (x62 + x69)
x71 := (x61 + x70)
x72 := (x60 + x71)
x73 := (x59 + x72)
x74 := (x57 + uint64(x58))
x75 := (x56 + x74)
x76 := (x55 + x75)
x77 := (x54 + x76)
x78 := (x53 + x77)
x79 := (x52 + x78)
x80 := (x51 + x79)
x81 := (x49 + uint64(x50))
x82 := (x48 + x81)
x83 := (x47 + x82)
x84 := (x46 + x83)
x85 := (x45 + x84)
x86 := (x44 + x85)
x87 := (x43 + x86)
x88 := (x41 + uint64(x42))
x89 := (x40 + x88)
x90 := (x39 + x89)
x91 := (x38 + x90)
x92 := (x37 + x91)
x93 := (x36 + x92)
x94 := (x35 + x93)
x95 := (x33 + uint64(x34))
x96 := (x32 + x95)
x97 := (x31 + x96)
x98 := (x30 + x97)
x99 := (x29 + x98)
x100 := (x28 + x99)
x101 := (x27 + x100)
x102 := (x25 + uint64(x26))
x103 := (x24 + x102)
x104 := (x23 + x103)
x105 := (x22 + x104)
x106 := (x21 + x105)
x107 := (x20 + x106)
x108 := (x19 + x107)
x109 := (x17 + uint64(x18))
x110 := (x16 + x109)
x111 := (x15 + x110)
x112 := (x14 + x111)
x113 := (x13 + x112)
x114 := (x12 + x113)
x115 := (x11 + x114)
x116 := (x9 + uint64(x10))
x117 := (x8 + x116)
x118 := (x7 + x117)
x119 := (x6 + x118)
x120 := (x5 + x119)
x121 := (x4 + x120)
x122 := (x3 + x121)
x123 := (x1 + uint64(x2))
out1[0] = x73
out1[1] = x80
out1[2] = x87
out1[3] = x94
out1[4] = x101
out1[5] = x108
out1[6] = x115
out1[7] = x122
out1[8] = x123
}
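// p521ExampleRoundTrip is a minimal sketch (not part of the generated code)
// showing that p521ToBytes and p521FromBytes are inverses for a reduced
// element: nine little-endian limbs serialize to 66 bytes (the top byte
// carries a single bit) and deserialize back to the same limbs.
func p521ExampleRoundTrip(limbs *[9]uint64) [9]uint64 {
var b [66]uint8
p521ToBytes(&b, limbs)
var out [9]uint64
p521FromBytes(&out, &b)
return out
}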
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Code generated by addchain. DO NOT EDIT.
package fiat
// Invert sets e = 1/x, and returns e.
//
// If x == 0, Invert returns e = 0.
func (e *P521Element) Invert(x *P521Element) *P521Element {
// Inversion is implemented as exponentiation with exponent p − 2.
// The sequence of 13 multiplications and 520 squarings is derived from the
// following addition chain generated with github.com/mmcloughlin/addchain v0.4.0.
//
// _10 = 2*1
// _11 = 1 + _10
// _1100 = _11 << 2
// _1111 = _11 + _1100
// _11110000 = _1111 << 4
// _11111111 = _1111 + _11110000
// x16 = _11111111 << 8 + _11111111
// x32 = x16 << 16 + x16
// x64 = x32 << 32 + x32
// x65 = 2*x64 + 1
// x129 = x65 << 64 + x64
// x130 = 2*x129 + 1
// x259 = x130 << 129 + x129
// x260 = 2*x259 + 1
// x519 = x260 << 259 + x259
// return x519 << 2 + 1
//
var z = new(P521Element).Set(e)
var t0 = new(P521Element)
z.Square(x)
z.Mul(x, z)
t0.Square(z)
for s := 1; s < 2; s++ {
t0.Square(t0)
}
z.Mul(z, t0)
t0.Square(z)
for s := 1; s < 4; s++ {
t0.Square(t0)
}
z.Mul(z, t0)
t0.Square(z)
for s := 1; s < 8; s++ {
t0.Square(t0)
}
z.Mul(z, t0)
t0.Square(z)
for s := 1; s < 16; s++ {
t0.Square(t0)
}
z.Mul(z, t0)
t0.Square(z)
for s := 1; s < 32; s++ {
t0.Square(t0)
}
z.Mul(z, t0)
t0.Square(z)
t0.Mul(x, t0)
for s := 0; s < 64; s++ {
t0.Square(t0)
}
z.Mul(z, t0)
t0.Square(z)
t0.Mul(x, t0)
for s := 0; s < 129; s++ {
t0.Square(t0)
}
z.Mul(z, t0)
t0.Square(z)
t0.Mul(x, t0)
for s := 0; s < 259; s++ {
t0.Square(t0)
}
z.Mul(z, t0)
for s := 0; s < 2; s++ {
z.Square(z)
}
z.Mul(x, z)
return e.Set(z)
}
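// p521ExampleDiv is a minimal sketch (not part of the generated code) of the
// usual use of Invert: computing n/d mod p as n × d⁻¹. Because Invert maps
// zero to zero, dividing by zero yields zero rather than panicking.
func p521ExampleDiv(q, n, d *P521Element) *P521Element {
inv := new(P521Element).Invert(d)
return q.Mul(n, inv)
}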
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Code generated by generate.go. DO NOT EDIT.
package nistec
import (
"crypto/internal/nistec/fiat"
"crypto/subtle"
"errors"
"sync"
)
// p224ElementLength is the length of an element of the base or scalar field,
// which have the same byte length for all NIST P curves.
const p224ElementLength = 28
// P224Point is a P224 point. The zero value is NOT valid.
type P224Point struct {
// The point is represented in projective coordinates (X:Y:Z),
// where x = X/Z and y = Y/Z.
x, y, z *fiat.P224Element
}
// NewP224Point returns a new P224Point representing the point at infinity.
func NewP224Point() *P224Point {
return &P224Point{
x: new(fiat.P224Element),
y: new(fiat.P224Element).One(),
z: new(fiat.P224Element),
}
}
// SetGenerator sets p to the canonical generator and returns p.
func (p *P224Point) SetGenerator() *P224Point {
p.x.SetBytes([]byte{0xb7, 0xe, 0xc, 0xbd, 0x6b, 0xb4, 0xbf, 0x7f, 0x32, 0x13, 0x90, 0xb9, 0x4a, 0x3, 0xc1, 0xd3, 0x56, 0xc2, 0x11, 0x22, 0x34, 0x32, 0x80, 0xd6, 0x11, 0x5c, 0x1d, 0x21})
p.y.SetBytes([]byte{0xbd, 0x37, 0x63, 0x88, 0xb5, 0xf7, 0x23, 0xfb, 0x4c, 0x22, 0xdf, 0xe6, 0xcd, 0x43, 0x75, 0xa0, 0x5a, 0x7, 0x47, 0x64, 0x44, 0xd5, 0x81, 0x99, 0x85, 0x0, 0x7e, 0x34})
p.z.One()
return p
}
// Set sets p = q and returns p.
func (p *P224Point) Set(q *P224Point) *P224Point {
p.x.Set(q.x)
p.y.Set(q.y)
p.z.Set(q.z)
return p
}
// SetBytes sets p to the compressed, uncompressed, or infinity value encoded in
// b, as specified in SEC 1, Version 2.0, Section 2.3.4. If the point is not on
// the curve, it returns nil and an error, and the receiver is unchanged.
// Otherwise, it returns p.
func (p *P224Point) SetBytes(b []byte) (*P224Point, error) {
switch {
// Point at infinity.
case len(b) == 1 && b[0] == 0:
return p.Set(NewP224Point()), nil
// Uncompressed form.
case len(b) == 1+2*p224ElementLength && b[0] == 4:
x, err := new(fiat.P224Element).SetBytes(b[1 : 1+p224ElementLength])
if err != nil {
return nil, err
}
y, err := new(fiat.P224Element).SetBytes(b[1+p224ElementLength:])
if err != nil {
return nil, err
}
if err := p224CheckOnCurve(x, y); err != nil {
return nil, err
}
p.x.Set(x)
p.y.Set(y)
p.z.One()
return p, nil
// Compressed form.
case len(b) == 1+p224ElementLength && (b[0] == 2 || b[0] == 3):
x, err := new(fiat.P224Element).SetBytes(b[1:])
if err != nil {
return nil, err
}
// y² = x³ - 3x + b
y := p224Polynomial(new(fiat.P224Element), x)
if !p224Sqrt(y, y) {
return nil, errors.New("invalid P224 compressed point encoding")
}
// Select the positive or negative root, as indicated by the least
// significant bit, based on the encoding type byte.
otherRoot := new(fiat.P224Element)
otherRoot.Sub(otherRoot, y)
cond := y.Bytes()[p224ElementLength-1]&1 ^ b[0]&1
y.Select(otherRoot, y, int(cond))
p.x.Set(x)
p.y.Set(y)
p.z.One()
return p, nil
default:
return nil, errors.New("invalid P224 point encoding")
}
}
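// p224ExampleRoundTrip is a minimal sketch (not part of the generated code)
// showing the SetBytes/Bytes round-trip: the generator encodes to
// 0x04 || X || Y (57 bytes) and decodes back to the same point.
func p224ExampleRoundTrip() (*P224Point, error) {
p := NewP224Point().SetGenerator()
return NewP224Point().SetBytes(p.Bytes())
}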
var _p224B *fiat.P224Element
var _p224BOnce sync.Once
func p224B() *fiat.P224Element {
_p224BOnce.Do(func() {
_p224B, _ = new(fiat.P224Element).SetBytes([]byte{0xb4, 0x5, 0xa, 0x85, 0xc, 0x4, 0xb3, 0xab, 0xf5, 0x41, 0x32, 0x56, 0x50, 0x44, 0xb0, 0xb7, 0xd7, 0xbf, 0xd8, 0xba, 0x27, 0xb, 0x39, 0x43, 0x23, 0x55, 0xff, 0xb4})
})
return _p224B
}
// p224Polynomial sets y2 to x³ - 3x + b, and returns y2.
func p224Polynomial(y2, x *fiat.P224Element) *fiat.P224Element {
y2.Square(x)
y2.Mul(y2, x)
threeX := new(fiat.P224Element).Add(x, x)
threeX.Add(threeX, x)
y2.Sub(y2, threeX)
return y2.Add(y2, p224B())
}
func p224CheckOnCurve(x, y *fiat.P224Element) error {
// y² = x³ - 3x + b
rhs := p224Polynomial(new(fiat.P224Element), x)
lhs := new(fiat.P224Element).Square(y)
if rhs.Equal(lhs) != 1 {
return errors.New("P224 point not on curve")
}
return nil
}
// Bytes returns the uncompressed or infinity encoding of p, as specified in
// SEC 1, Version 2.0, Section 2.3.3. Note that the encoding of the point at
// infinity is shorter than all other encodings.
func (p *P224Point) Bytes() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var out [1 + 2*p224ElementLength]byte
return p.bytes(&out)
}
func (p *P224Point) bytes(out *[1 + 2*p224ElementLength]byte) []byte {
if p.z.IsZero() == 1 {
return append(out[:0], 0)
}
zinv := new(fiat.P224Element).Invert(p.z)
x := new(fiat.P224Element).Mul(p.x, zinv)
y := new(fiat.P224Element).Mul(p.y, zinv)
buf := append(out[:0], 4)
buf = append(buf, x.Bytes()...)
buf = append(buf, y.Bytes()...)
return buf
}
// BytesX returns the encoding of the x-coordinate of p, as specified in SEC 1,
// Version 2.0, Section 2.3.5, or an error if p is the point at infinity.
func (p *P224Point) BytesX() ([]byte, error) {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var out [p224ElementLength]byte
return p.bytesX(&out)
}
func (p *P224Point) bytesX(out *[p224ElementLength]byte) ([]byte, error) {
if p.z.IsZero() == 1 {
return nil, errors.New("P224 point is the point at infinity")
}
zinv := new(fiat.P224Element).Invert(p.z)
x := new(fiat.P224Element).Mul(p.x, zinv)
return append(out[:0], x.Bytes()...), nil
}
// BytesCompressed returns the compressed or infinity encoding of p, as
// specified in SEC 1, Version 2.0, Section 2.3.3. Note that the encoding of the
// point at infinity is shorter than all other encodings.
func (p *P224Point) BytesCompressed() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var out [1 + p224ElementLength]byte
return p.bytesCompressed(&out)
}
func (p *P224Point) bytesCompressed(out *[1 + p224ElementLength]byte) []byte {
if p.z.IsZero() == 1 {
return append(out[:0], 0)
}
zinv := new(fiat.P224Element).Invert(p.z)
x := new(fiat.P224Element).Mul(p.x, zinv)
y := new(fiat.P224Element).Mul(p.y, zinv)
// Encode the sign of the y coordinate (indicated by the least significant
// bit) as the encoding type (2 or 3).
buf := append(out[:0], 2)
buf[0] |= y.Bytes()[p224ElementLength-1] & 1
buf = append(buf, x.Bytes()...)
return buf
}
// Add sets q = p1 + p2, and returns q. The points may overlap.
func (q *P224Point) Add(p1, p2 *P224Point) *P224Point {
// Complete addition formula for a = -3 from "Complete addition formulas for
// prime order elliptic curves" (https://eprint.iacr.org/2015/1060), §A.2.
t0 := new(fiat.P224Element).Mul(p1.x, p2.x) // t0 := X1 * X2
t1 := new(fiat.P224Element).Mul(p1.y, p2.y) // t1 := Y1 * Y2
t2 := new(fiat.P224Element).Mul(p1.z, p2.z) // t2 := Z1 * Z2
t3 := new(fiat.P224Element).Add(p1.x, p1.y) // t3 := X1 + Y1
t4 := new(fiat.P224Element).Add(p2.x, p2.y) // t4 := X2 + Y2
t3.Mul(t3, t4) // t3 := t3 * t4
t4.Add(t0, t1) // t4 := t0 + t1
t3.Sub(t3, t4) // t3 := t3 - t4
t4.Add(p1.y, p1.z) // t4 := Y1 + Z1
x3 := new(fiat.P224Element).Add(p2.y, p2.z) // X3 := Y2 + Z2
t4.Mul(t4, x3) // t4 := t4 * X3
x3.Add(t1, t2) // X3 := t1 + t2
t4.Sub(t4, x3) // t4 := t4 - X3
x3.Add(p1.x, p1.z) // X3 := X1 + Z1
y3 := new(fiat.P224Element).Add(p2.x, p2.z) // Y3 := X2 + Z2
x3.Mul(x3, y3) // X3 := X3 * Y3
y3.Add(t0, t2) // Y3 := t0 + t2
y3.Sub(x3, y3) // Y3 := X3 - Y3
z3 := new(fiat.P224Element).Mul(p224B(), t2) // Z3 := b * t2
x3.Sub(y3, z3) // X3 := Y3 - Z3
z3.Add(x3, x3) // Z3 := X3 + X3
x3.Add(x3, z3) // X3 := X3 + Z3
z3.Sub(t1, x3) // Z3 := t1 - X3
x3.Add(t1, x3) // X3 := t1 + X3
y3.Mul(p224B(), y3) // Y3 := b * Y3
t1.Add(t2, t2) // t1 := t2 + t2
t2.Add(t1, t2) // t2 := t1 + t2
y3.Sub(y3, t2) // Y3 := Y3 - t2
y3.Sub(y3, t0) // Y3 := Y3 - t0
t1.Add(y3, y3) // t1 := Y3 + Y3
y3.Add(t1, y3) // Y3 := t1 + Y3
t1.Add(t0, t0) // t1 := t0 + t0
t0.Add(t1, t0) // t0 := t1 + t0
t0.Sub(t0, t2) // t0 := t0 - t2
t1.Mul(t4, y3) // t1 := t4 * Y3
t2.Mul(t0, y3) // t2 := t0 * Y3
y3.Mul(x3, z3) // Y3 := X3 * Z3
y3.Add(y3, t2) // Y3 := Y3 + t2
x3.Mul(t3, x3) // X3 := t3 * X3
x3.Sub(x3, t1) // X3 := X3 - t1
z3.Mul(t4, z3) // Z3 := t4 * Z3
t1.Mul(t3, t0) // t1 := t3 * t0
z3.Add(z3, t1) // Z3 := Z3 + t1
q.x.Set(x3)
q.y.Set(y3)
q.z.Set(z3)
return q
}
// Double sets q = p + p, and returns q. The points may overlap.
func (q *P224Point) Double(p *P224Point) *P224Point {
// Complete addition formula for a = -3 from "Complete addition formulas for
// prime order elliptic curves" (https://eprint.iacr.org/2015/1060), §A.2.
t0 := new(fiat.P224Element).Square(p.x) // t0 := X ^ 2
t1 := new(fiat.P224Element).Square(p.y) // t1 := Y ^ 2
t2 := new(fiat.P224Element).Square(p.z) // t2 := Z ^ 2
t3 := new(fiat.P224Element).Mul(p.x, p.y) // t3 := X * Y
t3.Add(t3, t3) // t3 := t3 + t3
z3 := new(fiat.P224Element).Mul(p.x, p.z) // Z3 := X * Z
z3.Add(z3, z3) // Z3 := Z3 + Z3
y3 := new(fiat.P224Element).Mul(p224B(), t2) // Y3 := b * t2
y3.Sub(y3, z3) // Y3 := Y3 - Z3
x3 := new(fiat.P224Element).Add(y3, y3) // X3 := Y3 + Y3
y3.Add(x3, y3) // Y3 := X3 + Y3
x3.Sub(t1, y3) // X3 := t1 - Y3
y3.Add(t1, y3) // Y3 := t1 + Y3
y3.Mul(x3, y3) // Y3 := X3 * Y3
x3.Mul(x3, t3) // X3 := X3 * t3
t3.Add(t2, t2) // t3 := t2 + t2
t2.Add(t2, t3) // t2 := t2 + t3
z3.Mul(p224B(), z3) // Z3 := b * Z3
z3.Sub(z3, t2) // Z3 := Z3 - t2
z3.Sub(z3, t0) // Z3 := Z3 - t0
t3.Add(z3, z3) // t3 := Z3 + Z3
z3.Add(z3, t3) // Z3 := Z3 + t3
t3.Add(t0, t0) // t3 := t0 + t0
t0.Add(t3, t0) // t0 := t3 + t0
t0.Sub(t0, t2) // t0 := t0 - t2
t0.Mul(t0, z3) // t0 := t0 * Z3
y3.Add(y3, t0) // Y3 := Y3 + t0
t0.Mul(p.y, p.z) // t0 := Y * Z
t0.Add(t0, t0) // t0 := t0 + t0
z3.Mul(t0, z3) // Z3 := t0 * Z3
x3.Sub(x3, z3) // X3 := X3 - Z3
z3.Mul(t0, t1) // Z3 := t0 * t1
z3.Add(z3, z3) // Z3 := Z3 + Z3
z3.Add(z3, z3) // Z3 := Z3 + Z3
q.x.Set(x3)
q.y.Set(y3)
q.z.Set(z3)
return q
}
// Select sets q to p1 if cond == 1, and to p2 if cond == 0.
func (q *P224Point) Select(p1, p2 *P224Point, cond int) *P224Point {
q.x.Select(p1.x, p2.x, cond)
q.y.Select(p1.y, p2.y, cond)
q.z.Select(p1.z, p2.z, cond)
return q
}
// A p224Table holds the first 15 multiples of a point at offset -1, so [1]P
// is at table[0], [15]P is at table[14], and [0]P is implicitly the identity
// point.
type p224Table [15]*P224Point
// Select selects the n-th multiple of the table base point into p. It works in
// constant time by iterating over every entry of the table. n must be in [0, 15].
func (table *p224Table) Select(p *P224Point, n uint8) {
if n >= 16 {
panic("nistec: internal error: p224Table called with out-of-bounds value")
}
p.Set(NewP224Point())
for i := uint8(1); i < 16; i++ {
cond := subtle.ConstantTimeByteEq(i, n)
p.Select(table[i-1], p, cond)
}
}
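// p224ExampleTable is a minimal sketch (not part of the generated code) of
// how the offset-by-one layout is populated and read: entry i-1 holds [i]P,
// and Select with n = 0 leaves p as the point at infinity.
func p224ExampleTable(q *P224Point) *P224Point {
var table p224Table
table[0] = NewP224Point().Set(q) // [1]P
for i := 1; i < 15; i++ {
table[i] = NewP224Point().Add(table[i-1], q) // [i+1]P
}
p := NewP224Point()
table.Select(p, 3) // constant-time load of [3]P
return p
}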
// ScalarMult sets p = scalar * q, and returns p.
func (p *P224Point) ScalarMult(q *P224Point, scalar []byte) (*P224Point, error) {
// Compute a p224Table for the base point q. The explicit NewP224Point
// calls get inlined, letting the allocations live on the stack.
var table = p224Table{NewP224Point(), NewP224Point(), NewP224Point(),
NewP224Point(), NewP224Point(), NewP224Point(), NewP224Point(),
NewP224Point(), NewP224Point(), NewP224Point(), NewP224Point(),
NewP224Point(), NewP224Point(), NewP224Point(), NewP224Point()}
table[0].Set(q)
for i := 1; i < 15; i += 2 {
table[i].Double(table[i/2])
table[i+1].Add(table[i], q)
}
// Instead of doing the classic double-and-add chain, we do it with a
// four-bit window: we double four times, and then add [0-15]P.
t := NewP224Point()
p.Set(NewP224Point())
for i, byte := range scalar {
// No need to double on the first iteration, as p is the identity at
// this point, and [N]∞ = ∞.
if i != 0 {
p.Double(p)
p.Double(p)
p.Double(p)
p.Double(p)
}
windowValue := byte >> 4
table.Select(t, windowValue)
p.Add(p, t)
p.Double(p)
p.Double(p)
p.Double(p)
p.Double(p)
windowValue = byte & 0b1111
table.Select(t, windowValue)
p.Add(p, t)
}
return p, nil
}
var p224GeneratorTable *[p224ElementLength * 2]p224Table
var p224GeneratorTableOnce sync.Once
// generatorTable returns a sequence of p224Tables. The first table contains
// multiples of G. Each successive table is the previous table doubled four
// times.
func (p *P224Point) generatorTable() *[p224ElementLength * 2]p224Table {
p224GeneratorTableOnce.Do(func() {
p224GeneratorTable = new([p224ElementLength * 2]p224Table)
base := NewP224Point().SetGenerator()
for i := 0; i < p224ElementLength*2; i++ {
p224GeneratorTable[i][0] = NewP224Point().Set(base)
for j := 1; j < 15; j++ {
p224GeneratorTable[i][j] = NewP224Point().Add(p224GeneratorTable[i][j-1], base)
}
base.Double(base)
base.Double(base)
base.Double(base)
base.Double(base)
}
})
return p224GeneratorTable
}
// ScalarBaseMult sets p = scalar * B, where B is the canonical generator, and
// returns p.
func (p *P224Point) ScalarBaseMult(scalar []byte) (*P224Point, error) {
if len(scalar) != p224ElementLength {
return nil, errors.New("invalid scalar length")
}
tables := p.generatorTable()
// This is also a scalar multiplication with a four-bit window like in
// ScalarMult, but in this case the doublings are precomputed. The value
// [windowValue]G added at iteration k would normally get doubled
// (totIterations-k)×4 times, but with a larger precomputation we can
// instead add [2^((totIterations-k)×4)][windowValue]G and avoid the
// doublings between iterations.
t := NewP224Point()
p.Set(NewP224Point())
tableIndex := len(tables) - 1
for _, byte := range scalar {
windowValue := byte >> 4
tables[tableIndex].Select(t, windowValue)
p.Add(p, t)
tableIndex--
windowValue = byte & 0b1111
tables[tableIndex].Select(t, windowValue)
p.Add(p, t)
tableIndex--
}
return p, nil
}
// p224Sqrt sets e to a square root of x. If x is not a square, p224Sqrt returns
// false and e is unchanged. e and x can overlap.
func p224Sqrt(e, x *fiat.P224Element) (isSquare bool) {
candidate := new(fiat.P224Element)
p224SqrtCandidate(candidate, x)
square := new(fiat.P224Element).Square(candidate)
if square.Equal(x) != 1 {
return false
}
e.Set(candidate)
return true
}
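// p224ExampleRecoverY is a minimal sketch (not part of the generated code) of
// point decompression: p224Polynomial evaluates the curve's right-hand side
// and p224Sqrt takes its root; the caller then fixes the root's sign from the
// encoding type byte, as SetBytes does above.
func p224ExampleRecoverY(x *fiat.P224Element) (*fiat.P224Element, bool) {
y := p224Polynomial(new(fiat.P224Element), x)
ok := p224Sqrt(y, y)
return y, ok
}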
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package nistec
import (
"crypto/internal/nistec/fiat"
"sync"
)
var p224GG *[96]fiat.P224Element
var p224GGOnce sync.Once
// p224SqrtCandidate sets r to a square root candidate for x. r and x must not overlap.
func p224SqrtCandidate(r, x *fiat.P224Element) {
// Since p = 1 mod 4, we can't use the exponentiation by (p + 1) / 4 like
// for the other primes. Instead, implement a variation of Tonelli–Shanks.
// The constant-time implementation is adapted from Thomas Pornin's ecGFp5.
//
// https://github.com/pornin/ecgfp5/blob/82325b965/rust/src/field.rs#L337-L385
// p = q*2^n + 1 with q odd -> q = 2^128 - 1 and n = 96
// g^(2^n) = 1 -> g = 11 ^ q (where 11 is the smallest non-square)
// GG[j] = g^(2^j) for j = 0 to n-1
p224GGOnce.Do(func() {
p224GG = new([96]fiat.P224Element)
for i := range p224GG {
if i == 0 {
p224GG[i].SetBytes([]byte{0x6a, 0x0f, 0xec, 0x67,
0x85, 0x98, 0xa7, 0x92, 0x0c, 0x55, 0xb2, 0xd4,
0x0b, 0x2d, 0x6f, 0xfb, 0xbe, 0xa3, 0xd8, 0xce,
0xf3, 0xfb, 0x36, 0x32, 0xdc, 0x69, 0x1b, 0x74})
} else {
p224GG[i].Square(&p224GG[i-1])
}
}
})
// r <- x^((q+1)/2) = x^(2^127)
// v <- x^q = x^(2^128-1)
// Compute x^(2^127-1) first.
//
// The sequence of 10 multiplications and 126 squarings is derived from the
// following addition chain generated with github.com/mmcloughlin/addchain v0.4.0.
//
// _10 = 2*1
// _11 = 1 + _10
// _110 = 2*_11
// _111 = 1 + _110
// _111000 = _111 << 3
// _111111 = _111 + _111000
// _1111110 = 2*_111111
// _1111111 = 1 + _1111110
// x12 = _1111110 << 5 + _111111
// x24 = x12 << 12 + x12
// i36 = x24 << 7
// x31 = _1111111 + i36
// x48 = i36 << 17 + x24
// x96 = x48 << 48 + x48
// return x96 << 31 + x31
//
var t0 = new(fiat.P224Element)
var t1 = new(fiat.P224Element)
r.Square(x)
r.Mul(x, r)
r.Square(r)
r.Mul(x, r)
t0.Square(r)
for s := 1; s < 3; s++ {
t0.Square(t0)
}
t0.Mul(r, t0)
t1.Square(t0)
r.Mul(x, t1)
for s := 0; s < 5; s++ {
t1.Square(t1)
}
t0.Mul(t0, t1)
t1.Square(t0)
for s := 1; s < 12; s++ {
t1.Square(t1)
}
t0.Mul(t0, t1)
t1.Square(t0)
for s := 1; s < 7; s++ {
t1.Square(t1)
}
r.Mul(r, t1)
for s := 0; s < 17; s++ {
t1.Square(t1)
}
t0.Mul(t0, t1)
t1.Square(t0)
for s := 1; s < 48; s++ {
t1.Square(t1)
}
t0.Mul(t0, t1)
for s := 0; s < 31; s++ {
t0.Square(t0)
}
r.Mul(r, t0)
// v = x^(2^127-1)^2 * x
v := new(fiat.P224Element).Square(r)
v.Mul(v, x)
// r = x^(2^127-1) * x
r.Mul(r, x)
// for i = n-1 down to 1:
// w = v^(2^(i-1))
// if w == -1 then:
// v <- v*GG[n-i]
// r <- r*GG[n-i-1]
var p224MinusOne = new(fiat.P224Element).Sub(
new(fiat.P224Element), new(fiat.P224Element).One())
for i := 96 - 1; i >= 1; i-- {
w := new(fiat.P224Element).Set(v)
for j := 0; j < i-1; j++ {
w.Square(w)
}
cond := w.Equal(p224MinusOne)
v.Select(t0.Mul(v, &p224GG[96-i]), v, cond)
r.Select(t0.Mul(r, &p224GG[96-i-1]), r, cond)
}
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file contains the Go wrapper for the constant-time, 64-bit assembly
// implementation of P256. The optimizations performed here are described in
// detail in:
// S.Gueron and V.Krasnov, "Fast prime field elliptic-curve cryptography with
// 256-bit primes"
// https://link.springer.com/article/10.1007%2Fs13389-014-0090-x
// https://eprint.iacr.org/2013/816.pdf
//go:build amd64 || arm64 || ppc64le || s390x
package nistec
import (
_ "embed"
"encoding/binary"
"errors"
"math/bits"
"runtime"
"unsafe"
)
// p256Element is a P-256 base field element in [0, P-1] in the Montgomery
// domain (with R 2²⁵⁶) as four limbs in little-endian order.
type p256Element [4]uint64
// p256One is one in the Montgomery domain.
var p256One = p256Element{0x0000000000000001, 0xffffffff00000000,
0xffffffffffffffff, 0x00000000fffffffe}
var p256Zero = p256Element{}
// p256P is 2²⁵⁶ - 2²²⁴ + 2¹⁹² + 2⁹⁶ - 1 in the Montgomery domain.
var p256P = p256Element{0xffffffffffffffff, 0x00000000ffffffff,
0x0000000000000000, 0xffffffff00000001}
// P256Point is a P-256 point. The zero value should not be assumed to be valid
// (although it is in this implementation).
type P256Point struct {
// (X:Y:Z) are Jacobian coordinates where x = X/Z² and y = Y/Z³. The point
// at infinity can be represented by any set of coordinates with Z = 0.
x, y, z p256Element
}
// NewP256Point returns a new P256Point representing the point at infinity.
func NewP256Point() *P256Point {
return &P256Point{
x: p256One, y: p256One, z: p256Zero,
}
}
// SetGenerator sets p to the canonical generator and returns p.
func (p *P256Point) SetGenerator() *P256Point {
p.x = p256Element{0x79e730d418a9143c, 0x75ba95fc5fedb601,
0x79fb732b77622510, 0x18905f76a53755c6}
p.y = p256Element{0xddf25357ce95560a, 0x8b4ab8e4ba19e45c,
0xd2e88688dd21f325, 0x8571ff1825885d85}
p.z = p256One
return p
}
// Set sets p = q and returns p.
func (p *P256Point) Set(q *P256Point) *P256Point {
p.x, p.y, p.z = q.x, q.y, q.z
return p
}
const p256ElementLength = 32
const p256UncompressedLength = 1 + 2*p256ElementLength
const p256CompressedLength = 1 + p256ElementLength
// SetBytes sets p to the compressed, uncompressed, or infinity value encoded in
// b, as specified in SEC 1, Version 2.0, Section 2.3.4. If the point is not on
// the curve, it returns nil and an error, and the receiver is unchanged.
// Otherwise, it returns p.
func (p *P256Point) SetBytes(b []byte) (*P256Point, error) {
// p256Mul operates in the Montgomery domain with R = 2²⁵⁶ mod p. Thus rr
// here is R in the Montgomery domain, or R×R mod p. See comment in
// P256OrdInverse about how this is used.
rr := p256Element{0x0000000000000003, 0xfffffffbffffffff,
0xfffffffffffffffe, 0x00000004fffffffd}
switch {
// Point at infinity.
case len(b) == 1 && b[0] == 0:
return p.Set(NewP256Point()), nil
// Uncompressed form.
case len(b) == p256UncompressedLength && b[0] == 4:
var r P256Point
p256BigToLittle(&r.x, (*[32]byte)(b[1:33]))
p256BigToLittle(&r.y, (*[32]byte)(b[33:65]))
if p256LessThanP(&r.x) == 0 || p256LessThanP(&r.y) == 0 {
return nil, errors.New("invalid P256 element encoding")
}
p256Mul(&r.x, &r.x, &rr)
p256Mul(&r.y, &r.y, &rr)
if err := p256CheckOnCurve(&r.x, &r.y); err != nil {
return nil, err
}
r.z = p256One
return p.Set(&r), nil
// Compressed form.
case len(b) == p256CompressedLength && (b[0] == 2 || b[0] == 3):
var r P256Point
p256BigToLittle(&r.x, (*[32]byte)(b[1:33]))
if p256LessThanP(&r.x) == 0 {
return nil, errors.New("invalid P256 element encoding")
}
p256Mul(&r.x, &r.x, &rr)
// y² = x³ - 3x + b
p256Polynomial(&r.y, &r.x)
if !p256Sqrt(&r.y, &r.y) {
return nil, errors.New("invalid P256 compressed point encoding")
}
// Select the positive or negative root, as indicated by the least
// significant bit, based on the encoding type byte.
yy := new(p256Element)
p256FromMont(yy, &r.y)
cond := int(yy[0]&1) ^ int(b[0]&1)
p256NegCond(&r.y, cond)
r.z = p256One
return p.Set(&r), nil
default:
return nil, errors.New("invalid P256 point encoding")
}
}
// p256Polynomial sets y2 to x³ - 3x + b, and returns y2.
func p256Polynomial(y2, x *p256Element) *p256Element {
x3 := new(p256Element)
p256Sqr(x3, x, 1)
p256Mul(x3, x3, x)
threeX := new(p256Element)
p256Add(threeX, x, x)
p256Add(threeX, threeX, x)
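// Negate 3x unconditionally (cond = 1) to form the -3x term of the polynomial.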
p256NegCond(threeX, 1)
p256B := &p256Element{0xd89cdf6229c4bddf, 0xacf005cd78843090,
0xe5a220abf7212ed6, 0xdc30061d04874834}
p256Add(x3, x3, threeX)
p256Add(x3, x3, p256B)
*y2 = *x3
return y2
}
func p256CheckOnCurve(x, y *p256Element) error {
// y² = x³ - 3x + b
rhs := p256Polynomial(new(p256Element), x)
lhs := new(p256Element)
p256Sqr(lhs, y, 1)
if p256Equal(lhs, rhs) != 1 {
return errors.New("P256 point not on curve")
}
return nil
}
// p256LessThanP returns 1 if x < p, and 0 otherwise. Note that a p256Element is
// not allowed to be equal to or greater than p, so if this function returns 0
// then x is invalid.
func p256LessThanP(x *p256Element) int {
var b uint64
_, b = bits.Sub64(x[0], p256P[0], b)
_, b = bits.Sub64(x[1], p256P[1], b)
_, b = bits.Sub64(x[2], p256P[2], b)
_, b = bits.Sub64(x[3], p256P[3], b)
return int(b)
}
// p256Add sets res = x + y.
func p256Add(res, x, y *p256Element) {
var c, b uint64
t1 := make([]uint64, 4)
t1[0], c = bits.Add64(x[0], y[0], 0)
t1[1], c = bits.Add64(x[1], y[1], c)
t1[2], c = bits.Add64(x[2], y[2], c)
t1[3], c = bits.Add64(x[3], y[3], c)
t2 := make([]uint64, 4)
t2[0], b = bits.Sub64(t1[0], p256P[0], 0)
t2[1], b = bits.Sub64(t1[1], p256P[1], b)
t2[2], b = bits.Sub64(t1[2], p256P[2], b)
t2[3], b = bits.Sub64(t1[3], p256P[3], b)
// Three options:
// - a+b < p
// then c is 0, b is 1, and t1 is correct
// - p <= a+b < 2^256
// then c is 0, b is 0, and t2 is correct
// - 2^256 <= a+b
// then c is 1, b is 1, and t2 is correct
t2Mask := (c ^ b) - 1
res[0] = (t1[0] & ^t2Mask) | (t2[0] & t2Mask)
res[1] = (t1[1] & ^t2Mask) | (t2[1] & t2Mask)
res[2] = (t1[2] & ^t2Mask) | (t2[2] & t2Mask)
res[3] = (t1[3] & ^t2Mask) | (t2[3] & t2Mask)
}
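// ctSelect64 is a minimal standalone sketch (not used by the code above) of
// the constant-time select idiom p256Add applies limb by limb: with mask
// equal to 0 or ^uint64(0), it returns a when mask is zero and b otherwise.
func ctSelect64(mask, a, b uint64) uint64 {
return (a & ^mask) | (b & mask)
}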
// p256Sqrt sets e to a square root of x. If x is not a square, p256Sqrt returns
// false and e is unchanged. e and x can overlap.
func p256Sqrt(e, x *p256Element) (isSquare bool) {
t0, t1 := new(p256Element), new(p256Element)
// Since p = 3 mod 4, exponentiation by (p + 1) / 4 yields a square root candidate.
//
// The sequence of 7 multiplications and 253 squarings is derived from the
// following addition chain generated with github.com/mmcloughlin/addchain v0.4.0.
//
// _10 = 2*1
// _11 = 1 + _10
// _1100 = _11 << 2
// _1111 = _11 + _1100
// _11110000 = _1111 << 4
// _11111111 = _1111 + _11110000
// x16 = _11111111 << 8 + _11111111
// x32 = x16 << 16 + x16
// return ((x32 << 32 + 1) << 96 + 1) << 94
//
p256Sqr(t0, x, 1)
p256Mul(t0, x, t0)
p256Sqr(t1, t0, 2)
p256Mul(t0, t0, t1)
p256Sqr(t1, t0, 4)
p256Mul(t0, t0, t1)
p256Sqr(t1, t0, 8)
p256Mul(t0, t0, t1)
p256Sqr(t1, t0, 16)
p256Mul(t0, t0, t1)
p256Sqr(t0, t0, 32)
p256Mul(t0, x, t0)
p256Sqr(t0, t0, 96)
p256Mul(t0, x, t0)
p256Sqr(t0, t0, 94)
p256Sqr(t1, t0, 1)
if p256Equal(t1, x) != 1 {
return false
}
*e = *t0
return true
}
// The following assembly functions are implemented in p256_asm_*.s
// Montgomery multiplication. Sets res = in1 * in2 * R⁻¹ mod p.
//
//go:noescape
func p256Mul(res, in1, in2 *p256Element)
// Montgomery square, repeated n times (n >= 1).
//
//go:noescape
func p256Sqr(res, in *p256Element, n int)
// Montgomery multiplication by R⁻¹, or 1 outside the domain.
// Sets res = in * R⁻¹, bringing res out of the Montgomery domain.
//
//go:noescape
func p256FromMont(res, in *p256Element)
// If cond is not 0, sets val = -val mod p.
//
//go:noescape
func p256NegCond(val *p256Element, cond int)
// If cond is 0, sets res = b, otherwise sets res = a.
//
//go:noescape
func p256MovCond(res, a, b *P256Point, cond int)
//go:noescape
func p256BigToLittle(res *p256Element, in *[32]byte)
//go:noescape
func p256LittleToBig(res *[32]byte, in *p256Element)
//go:noescape
func p256OrdBigToLittle(res *p256OrdElement, in *[32]byte)
//go:noescape
func p256OrdLittleToBig(res *[32]byte, in *p256OrdElement)
// p256Table is a table of the first 16 multiples of a point. Points are stored
// at an index offset of -1 so [8]P is at index 7, P is at 0, and [16]P is at 15.
// [0]P is the point at infinity and it's not stored.
type p256Table [16]P256Point
// p256Select sets res to the point at index idx in the table.
// idx must be in [0, 15]. It executes in constant time.
//
//go:noescape
func p256Select(res *P256Point, table *p256Table, idx int)
// p256AffinePoint is a point in affine coordinates (x, y). x and y are still
// Montgomery domain elements. The point can't be the point at infinity.
type p256AffinePoint struct {
x, y p256Element
}
// p256AffineTable is a table of the first 32 multiples of a point. Points are
// stored at an index offset of -1 like in p256Table, and [0]P is not stored.
type p256AffineTable [32]p256AffinePoint
// p256Precomputed is a series of precomputed multiples of G, the canonical
// generator. The first p256AffineTable contains multiples of G. The second one
// multiples of [2⁶]G, the third one of [2¹²]G, and so on, where each successive
// table is the previous table doubled six times. Six is the width of the
// sliding window used in p256BaseMult, and having each table already
// pre-doubled lets us avoid the doublings between windows entirely. This table
// MUST NOT be modified, as it aliases into p256PrecomputedEmbed below.
var p256Precomputed *[43]p256AffineTable
//go:embed p256_asm_table.bin
var p256PrecomputedEmbed string
func init() {
p256PrecomputedPtr := (*unsafe.Pointer)(unsafe.Pointer(&p256PrecomputedEmbed))
if runtime.GOARCH == "s390x" {
var newTable [43 * 32 * 2 * 4]uint64
for i, x := range (*[43 * 32 * 2 * 4][8]byte)(*p256PrecomputedPtr) {
newTable[i] = binary.LittleEndian.Uint64(x[:])
}
newTablePtr := unsafe.Pointer(&newTable)
p256PrecomputedPtr = &newTablePtr
}
p256Precomputed = (*[43]p256AffineTable)(*p256PrecomputedPtr)
}
// p256SelectAffine sets res to the point at index idx in the table.
// idx must be in [0, 31]. It executes in constant time.
//
//go:noescape
func p256SelectAffine(res *p256AffinePoint, table *p256AffineTable, idx int)
// Point addition with an affine point and constant-time conditions.
// If zero is 0, sets res = in2. If sel is 0, sets res = in1.
// If sign is not 0, sets res = in1 + -in2. Otherwise, sets res = in1 + in2.
//
//go:noescape
func p256PointAddAffineAsm(res, in1 *P256Point, in2 *p256AffinePoint, sign, sel, zero int)
// Point addition. Sets res = in1 + in2. Returns one if the two input points
// were equal and zero otherwise. If in1 or in2 are the point at infinity, res
// and the return value are undefined.
//
//go:noescape
func p256PointAddAsm(res, in1, in2 *P256Point) int
// Point doubling. Sets res = in + in. in can be the point at infinity.
//
//go:noescape
func p256PointDoubleAsm(res, in *P256Point)
// p256OrdElement is a P-256 scalar field element in [0, ord(G)-1] in the
// Montgomery domain (with R 2²⁵⁶) as four uint64 limbs in little-endian order.
type p256OrdElement [4]uint64
// p256OrdReduce ensures s is in the range [0, ord(G)-1].
func p256OrdReduce(s *p256OrdElement) {
// Since 2 * ord(G) > 2²⁵⁶, we can just conditionally subtract ord(G),
// keeping the result if it doesn't underflow.
t0, b := bits.Sub64(s[0], 0xf3b9cac2fc632551, 0)
t1, b := bits.Sub64(s[1], 0xbce6faada7179e84, b)
t2, b := bits.Sub64(s[2], 0xffffffffffffffff, b)
t3, b := bits.Sub64(s[3], 0xffffffff00000000, b)
tMask := b - 1 // zero if subtraction underflowed
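// XOR-select: when tMask is all ones, s takes the reduced value t; when it
// is zero, s is left unchanged.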
s[0] ^= (t0 ^ s[0]) & tMask
s[1] ^= (t1 ^ s[1]) & tMask
s[2] ^= (t2 ^ s[2]) & tMask
s[3] ^= (t3 ^ s[3]) & tMask
}
// Add sets q = r1 + r2, and returns q. The points may overlap.
func (q *P256Point) Add(r1, r2 *P256Point) *P256Point {
var sum, double P256Point
r1IsInfinity := r1.isInfinity()
r2IsInfinity := r2.isInfinity()
pointsEqual := p256PointAddAsm(&sum, r1, r2)
p256PointDoubleAsm(&double, r1)
p256MovCond(&sum, &double, &sum, pointsEqual)
p256MovCond(&sum, r1, &sum, r2IsInfinity)
p256MovCond(&sum, r2, &sum, r1IsInfinity)
return q.Set(&sum)
}
// Double sets q = p + p, and returns q. The points may overlap.
func (q *P256Point) Double(p *P256Point) *P256Point {
var double P256Point
p256PointDoubleAsm(&double, p)
return q.Set(&double)
}
// ScalarBaseMult sets r = scalar * generator, where scalar is a 32-byte
// big-endian value, and returns r. If scalar is not 32 bytes long,
// ScalarBaseMult returns an error and the receiver is unchanged.
func (r *P256Point) ScalarBaseMult(scalar []byte) (*P256Point, error) {
if len(scalar) != 32 {
return nil, errors.New("invalid scalar length")
}
scalarReversed := new(p256OrdElement)
p256OrdBigToLittle(scalarReversed, (*[32]byte)(scalar))
p256OrdReduce(scalarReversed)
r.p256BaseMult(scalarReversed)
return r, nil
}
// ScalarMult sets r = scalar * q, where scalar is a 32-byte big-endian value,
// and returns r. If scalar is not 32 bytes long, ScalarMult returns an
// error and the receiver is unchanged.
func (r *P256Point) ScalarMult(q *P256Point, scalar []byte) (*P256Point, error) {
if len(scalar) != 32 {
return nil, errors.New("invalid scalar length")
}
scalarReversed := new(p256OrdElement)
p256OrdBigToLittle(scalarReversed, (*[32]byte)(scalar))
p256OrdReduce(scalarReversed)
r.Set(q).p256ScalarMult(scalarReversed)
return r, nil
}
// uint64IsZero returns 1 if x is zero and zero otherwise.
func uint64IsZero(x uint64) int {
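// Fold all bits of ^x down into bit 0: the low bit ends up 1 exactly when
// every bit of x was 0. Branch-free, so it runs in constant time.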
x = ^x
x &= x >> 32
x &= x >> 16
x &= x >> 8
x &= x >> 4
x &= x >> 2
x &= x >> 1
return int(x & 1)
}
// p256Equal returns 1 if a and b are equal and 0 otherwise.
func p256Equal(a, b *p256Element) int {
var acc uint64
for i := range a {
acc |= a[i] ^ b[i]
}
return uint64IsZero(acc)
}
// isInfinity returns 1 if p is the point at infinity and 0 otherwise.
func (p *P256Point) isInfinity() int {
return p256Equal(&p.z, &p256Zero)
}
// Bytes returns the uncompressed or infinity encoding of p, as specified in
// SEC 1, Version 2.0, Section 2.3.3. Note that the encoding of the point at
// infinity is shorter than all other encodings.
func (p *P256Point) Bytes() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var out [p256UncompressedLength]byte
return p.bytes(&out)
}
func (p *P256Point) bytes(out *[p256UncompressedLength]byte) []byte {
// The proper representation of the point at infinity is a single zero byte.
if p.isInfinity() == 1 {
return append(out[:0], 0)
}
x, y := new(p256Element), new(p256Element)
p.affineFromMont(x, y)
out[0] = 4 // Uncompressed form.
p256LittleToBig((*[32]byte)(out[1:33]), x)
p256LittleToBig((*[32]byte)(out[33:65]), y)
return out[:]
}
// affineFromMont sets (x, y) to the affine coordinates of p, converted out of the
// Montgomery domain.
func (p *P256Point) affineFromMont(x, y *p256Element) {
p256Inverse(y, &p.z) // y = z⁻¹
p256Sqr(x, y, 1)     // x = z⁻²
p256Mul(y, y, x)     // y = z⁻³
p256Mul(x, &p.x, x)  // x = X × z⁻², the affine x
p256Mul(y, &p.y, y)  // y = Y × z⁻³, the affine y
p256FromMont(x, x)
p256FromMont(y, y)
}
// BytesX returns the encoding of the x-coordinate of p, as specified in SEC 1,
// Version 2.0, Section 2.3.5, or an error if p is the point at infinity.
func (p *P256Point) BytesX() ([]byte, error) {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var out [p256ElementLength]byte
return p.bytesX(&out)
}
func (p *P256Point) bytesX(out *[p256ElementLength]byte) ([]byte, error) {
if p.isInfinity() == 1 {
return nil, errors.New("P256 point is the point at infinity")
}
x := new(p256Element)
p256Inverse(x, &p.z)
p256Sqr(x, x, 1)
p256Mul(x, &p.x, x)
p256FromMont(x, x)
p256LittleToBig((*[32]byte)(out[:]), x)
return out[:], nil
}
// BytesCompressed returns the compressed or infinity encoding of p, as
// specified in SEC 1, Version 2.0, Section 2.3.3. Note that the encoding of the
// point at infinity is shorter than all other encodings.
func (p *P256Point) BytesCompressed() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var out [p256CompressedLength]byte
return p.bytesCompressed(&out)
}
func (p *P256Point) bytesCompressed(out *[p256CompressedLength]byte) []byte {
if p.isInfinity() == 1 {
return append(out[:0], 0)
}
x, y := new(p256Element), new(p256Element)
p.affineFromMont(x, y)
out[0] = 2 | byte(y[0]&1)
p256LittleToBig((*[32]byte)(out[1:33]), x)
return out[:]
}
// Select sets q to p1 if cond == 1, and to p2 if cond == 0.
func (q *P256Point) Select(p1, p2 *P256Point, cond int) *P256Point {
p256MovCond(q, p1, p2, cond)
return q
}
// p256Inverse sets out to in⁻¹ mod p. If in is zero, out will be zero.
func p256Inverse(out, in *p256Element) {
// Inversion is calculated through exponentiation by p - 2, per Fermat's
// little theorem.
//
// The sequence of 12 multiplications and 255 squarings is derived from the
// following addition chain generated with github.com/mmcloughlin/addchain
// v0.4.0.
//
// _10 = 2*1
// _11 = 1 + _10
// _110 = 2*_11
// _111 = 1 + _110
// _111000 = _111 << 3
// _111111 = _111 + _111000
// x12 = _111111 << 6 + _111111
// x15 = x12 << 3 + _111
// x16 = 2*x15 + 1
// x32 = x16 << 16 + x16
// i53 = x32 << 15
// x47 = x15 + i53
// i263 = ((i53 << 17 + 1) << 143 + x47) << 47
// return (x47 + i263) << 2 + 1
//
var z = new(p256Element)
var t0 = new(p256Element)
var t1 = new(p256Element)
p256Sqr(z, in, 1)
p256Mul(z, in, z)
p256Sqr(z, z, 1)
p256Mul(z, in, z)
p256Sqr(t0, z, 3)
p256Mul(t0, z, t0)
p256Sqr(t1, t0, 6)
p256Mul(t0, t0, t1)
p256Sqr(t0, t0, 3)
p256Mul(z, z, t0)
p256Sqr(t0, z, 1)
p256Mul(t0, in, t0)
p256Sqr(t1, t0, 16)
p256Mul(t0, t0, t1)
p256Sqr(t0, t0, 15)
p256Mul(z, z, t0)
p256Sqr(t0, t0, 17)
p256Mul(t0, in, t0)
p256Sqr(t0, t0, 143)
p256Mul(t0, z, t0)
p256Sqr(t0, t0, 47)
p256Mul(z, z, t0)
p256Sqr(z, z, 2)
p256Mul(out, in, z)
}
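// boothW5 recodes a six-bit window (five scalar bits plus the top bit of the
// window below) as a signed Booth digit, returning the digit's absolute value
// and a sign flag (1 for negative). Signed digits let the tables store only
// positive multiples, since -[d]P is cheap to derive from [d]P.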
func boothW5(in uint) (int, int) {
var s uint = ^((in >> 5) - 1)
var d uint = (1 << 6) - in - 1
d = (d & s) | (in & (^s))
d = (d >> 1) + (d & 1)
return int(d), int(s & 1)
}
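// boothW6 is the width-six variant of boothW5, used by p256BaseMult with the
// precomputed affine tables.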
func boothW6(in uint) (int, int) {
var s uint = ^((in >> 6) - 1)
var d uint = (1 << 7) - in - 1
d = (d & s) | (in & (^s))
d = (d >> 1) + (d & 1)
return int(d), int(s & 1)
}
func (p *P256Point) p256BaseMult(scalar *p256OrdElement) {
var t0 p256AffinePoint
wvalue := (scalar[0] << 1) & 0x7f
sel, sign := boothW6(uint(wvalue))
p256SelectAffine(&t0, &p256Precomputed[0], sel)
p.x, p.y, p.z = t0.x, t0.y, p256One
p256NegCond(&p.y, sign)
index := uint(5)
zero := sel
for i := 1; i < 43; i++ {
if index < 192 {
wvalue = ((scalar[index/64] >> (index % 64)) + (scalar[index/64+1] << (64 - (index % 64)))) & 0x7f
} else {
wvalue = (scalar[index/64] >> (index % 64)) & 0x7f
}
index += 6
sel, sign = boothW6(uint(wvalue))
p256SelectAffine(&t0, &p256Precomputed[i], sel)
p256PointAddAffineAsm(p, p, &t0, sign, sel, zero)
zero |= sel
}
// If the whole scalar was zero, set to the point at infinity.
p256MovCond(p, p, NewP256Point(), zero)
}
func (p *P256Point) p256ScalarMult(scalar *p256OrdElement) {
// precomp is a table of precomputed points that stores powers of p
// from p^1 to p^16.
var precomp p256Table
var t0, t1, t2, t3 P256Point
// Prepare the table
precomp[0] = *p // 1
p256PointDoubleAsm(&t0, p)
p256PointDoubleAsm(&t1, &t0)
p256PointDoubleAsm(&t2, &t1)
p256PointDoubleAsm(&t3, &t2)
precomp[1] = t0 // 2
precomp[3] = t1 // 4
precomp[7] = t2 // 8
precomp[15] = t3 // 16
p256PointAddAsm(&t0, &t0, p)
p256PointAddAsm(&t1, &t1, p)
p256PointAddAsm(&t2, &t2, p)
precomp[2] = t0 // 3
precomp[4] = t1 // 5
precomp[8] = t2 // 9
p256PointDoubleAsm(&t0, &t0)
p256PointDoubleAsm(&t1, &t1)
precomp[5] = t0 // 6
precomp[9] = t1 // 10
p256PointAddAsm(&t2, &t0, p)
p256PointAddAsm(&t1, &t1, p)
precomp[6] = t2 // 7
precomp[10] = t1 // 11
p256PointDoubleAsm(&t0, &t0)
p256PointDoubleAsm(&t2, &t2)
precomp[11] = t0 // 12
precomp[13] = t2 // 14
p256PointAddAsm(&t0, &t0, p)
p256PointAddAsm(&t2, &t2, p)
precomp[12] = t0 // 13
precomp[14] = t2 // 15
// Start scanning the window from top bit
index := uint(254)
var sel, sign int
wvalue := (scalar[index/64] >> (index % 64)) & 0x3f
sel, _ = boothW5(uint(wvalue))
p256Select(p, &precomp, sel)
zero := sel
for index > 4 {
index -= 5
p256PointDoubleAsm(p, p)
p256PointDoubleAsm(p, p)
p256PointDoubleAsm(p, p)
p256PointDoubleAsm(p, p)
p256PointDoubleAsm(p, p)
if index < 192 {
wvalue = ((scalar[index/64] >> (index % 64)) + (scalar[index/64+1] << (64 - (index % 64)))) & 0x3f
} else {
wvalue = (scalar[index/64] >> (index % 64)) & 0x3f
}
sel, sign = boothW5(uint(wvalue))
p256Select(&t0, &precomp, sel)
p256NegCond(&t0.y, sign)
p256PointAddAsm(&t1, p, &t0)
p256MovCond(&t1, &t1, p, sel)
p256MovCond(p, &t1, &t0, zero)
zero |= sel
}
p256PointDoubleAsm(p, p)
p256PointDoubleAsm(p, p)
p256PointDoubleAsm(p, p)
p256PointDoubleAsm(p, p)
p256PointDoubleAsm(p, p)
wvalue = (scalar[0] << 1) & 0x3f
sel, sign = boothW5(uint(wvalue))
p256Select(&t0, &precomp, sel)
p256NegCond(&t0.y, sign)
p256PointAddAsm(&t1, p, &t0)
p256MovCond(&t1, &t1, p, sel)
p256MovCond(p, &t1, &t0, zero)
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build amd64 || arm64
package nistec
import "errors"
// Montgomery multiplication modulo ord(G). Sets res = in1 * in2 * R⁻¹.
//
//go:noescape
func p256OrdMul(res, in1, in2 *p256OrdElement)
// Montgomery square modulo ord(G), repeated n times (n >= 1).
//
//go:noescape
func p256OrdSqr(res, in *p256OrdElement, n int)
func P256OrdInverse(k []byte) ([]byte, error) {
if len(k) != 32 {
return nil, errors.New("invalid scalar length")
}
x := new(p256OrdElement)
p256OrdBigToLittle(x, (*[32]byte)(k))
p256OrdReduce(x)
// Inversion is implemented as exponentiation by n - 2, per Fermat's little theorem.
//
// The sequence of 38 multiplications and 254 squarings is derived from
// https://briansmith.org/ecc-inversion-addition-chains-01#p256_scalar_inversion
_1 := new(p256OrdElement)
_11 := new(p256OrdElement)
_101 := new(p256OrdElement)
_111 := new(p256OrdElement)
_1111 := new(p256OrdElement)
_10101 := new(p256OrdElement)
_101111 := new(p256OrdElement)
t := new(p256OrdElement)
// This code operates in the Montgomery domain where R = 2²⁵⁶ mod n and n is
// the order of the scalar field. Elements in the Montgomery domain take the
// form a×R and p256OrdMul calculates (a × b × R⁻¹) mod n. RR is R in the
// domain, or R×R mod n, thus p256OrdMul(x, RR) gives x×R, i.e. converts x
// into the Montgomery domain.
RR := &p256OrdElement{0x83244c95be79eea2, 0x4699799c49bd6fa6,
0x2845b2392b6bec59, 0x66e12d94f3d95620}
p256OrdMul(_1, x, RR) // _1
p256OrdSqr(x, _1, 1) // _10
p256OrdMul(_11, x, _1) // _11
p256OrdMul(_101, x, _11) // _101
p256OrdMul(_111, x, _101) // _111
p256OrdSqr(x, _101, 1) // _1010
p256OrdMul(_1111, _101, x) // _1111
p256OrdSqr(t, x, 1) // _10100
p256OrdMul(_10101, t, _1) // _10101
p256OrdSqr(x, _10101, 1) // _101010
p256OrdMul(_101111, _101, x) // _101111
p256OrdMul(x, _10101, x) // _111111 = x6
p256OrdSqr(t, x, 2) // _11111100
p256OrdMul(t, t, _11) // _11111111 = x8
p256OrdSqr(x, t, 8) // _ff00
p256OrdMul(x, x, t) // _ffff = x16
p256OrdSqr(t, x, 16) // _ffff0000
p256OrdMul(t, t, x) // _ffffffff = x32
p256OrdSqr(x, t, 64)
p256OrdMul(x, x, t)
p256OrdSqr(x, x, 32)
p256OrdMul(x, x, t)
sqrs := []int{
6, 5, 4, 5, 5,
4, 3, 3, 5, 9,
6, 2, 5, 6, 5,
4, 5, 5, 3, 10,
2, 5, 5, 3, 7, 6}
muls := []*p256OrdElement{
_101111, _111, _11, _1111, _10101,
_101, _101, _101, _111, _101111,
_1111, _1, _1, _1111, _111,
_111, _111, _101, _11, _101111,
_11, _11, _11, _1, _10101, _1111}
for i, s := range sqrs {
p256OrdSqr(x, x, s)
p256OrdMul(x, x, muls[i])
}
// Montgomery multiplication by the plain value 1 computes x×1×R⁻¹, stripping
// the R factor and converting the value out of the Montgomery domain.
one := &p256OrdElement{1}
p256OrdMul(x, x, one)
var xOut [32]byte
p256OrdLittleToBig(&xOut, x)
return xOut[:], nil
}
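// exampleOrdInverse is a hypothetical sketch of the contract above: for a
// 32-byte big-endian scalar k, P256OrdInverse returns k⁻¹ (mod ord(G)), also
// as 32 big-endian bytes, as needed when verifying ECDSA signatures.
func exampleOrdInverse(k [32]byte) ([]byte, error) {
return P256OrdInverse(k[:]) // cannot fail here: the length is fixed at 32
}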
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Code generated by generate.go. DO NOT EDIT.
package nistec
import (
"crypto/internal/nistec/fiat"
"crypto/subtle"
"errors"
"sync"
)
// p384ElementLength is the length of an element of the base or scalar field,
// which have the same byte length for all NIST P curves.
const p384ElementLength = 48
// P384Point is a P384 point. The zero value is NOT valid.
type P384Point struct {
// The point is represented in projective coordinates (X:Y:Z),
// where x = X/Z and y = Y/Z.
x, y, z *fiat.P384Element
}
// NewP384Point returns a new P384Point representing the point at infinity.
func NewP384Point() *P384Point {
return &P384Point{
x: new(fiat.P384Element),
y: new(fiat.P384Element).One(),
z: new(fiat.P384Element),
}
}
// SetGenerator sets p to the canonical generator and returns p.
func (p *P384Point) SetGenerator() *P384Point {
p.x.SetBytes([]byte{0xaa, 0x87, 0xca, 0x22, 0xbe, 0x8b, 0x5, 0x37, 0x8e, 0xb1, 0xc7, 0x1e, 0xf3, 0x20, 0xad, 0x74, 0x6e, 0x1d, 0x3b, 0x62, 0x8b, 0xa7, 0x9b, 0x98, 0x59, 0xf7, 0x41, 0xe0, 0x82, 0x54, 0x2a, 0x38, 0x55, 0x2, 0xf2, 0x5d, 0xbf, 0x55, 0x29, 0x6c, 0x3a, 0x54, 0x5e, 0x38, 0x72, 0x76, 0xa, 0xb7})
p.y.SetBytes([]byte{0x36, 0x17, 0xde, 0x4a, 0x96, 0x26, 0x2c, 0x6f, 0x5d, 0x9e, 0x98, 0xbf, 0x92, 0x92, 0xdc, 0x29, 0xf8, 0xf4, 0x1d, 0xbd, 0x28, 0x9a, 0x14, 0x7c, 0xe9, 0xda, 0x31, 0x13, 0xb5, 0xf0, 0xb8, 0xc0, 0xa, 0x60, 0xb1, 0xce, 0x1d, 0x7e, 0x81, 0x9d, 0x7a, 0x43, 0x1d, 0x7c, 0x90, 0xea, 0xe, 0x5f})
p.z.One()
return p
}
// Set sets p = q and returns p.
func (p *P384Point) Set(q *P384Point) *P384Point {
p.x.Set(q.x)
p.y.Set(q.y)
p.z.Set(q.z)
return p
}
// SetBytes sets p to the compressed, uncompressed, or infinity value encoded in
// b, as specified in SEC 1, Version 2.0, Section 2.3.4. If the point is not on
// the curve, it returns nil and an error, and the receiver is unchanged.
// Otherwise, it returns p.
func (p *P384Point) SetBytes(b []byte) (*P384Point, error) {
switch {
// Point at infinity.
case len(b) == 1 && b[0] == 0:
return p.Set(NewP384Point()), nil
// Uncompressed form.
case len(b) == 1+2*p384ElementLength && b[0] == 4:
x, err := new(fiat.P384Element).SetBytes(b[1 : 1+p384ElementLength])
if err != nil {
return nil, err
}
y, err := new(fiat.P384Element).SetBytes(b[1+p384ElementLength:])
if err != nil {
return nil, err
}
if err := p384CheckOnCurve(x, y); err != nil {
return nil, err
}
p.x.Set(x)
p.y.Set(y)
p.z.One()
return p, nil
// Compressed form.
case len(b) == 1+p384ElementLength && (b[0] == 2 || b[0] == 3):
x, err := new(fiat.P384Element).SetBytes(b[1:])
if err != nil {
return nil, err
}
// y² = x³ - 3x + b
y := p384Polynomial(new(fiat.P384Element), x)
if !p384Sqrt(y, y) {
return nil, errors.New("invalid P384 compressed point encoding")
}
// Select the positive or negative root, as indicated by the least
// significant bit, based on the encoding type byte.
otherRoot := new(fiat.P384Element)
otherRoot.Sub(otherRoot, y)
cond := y.Bytes()[p384ElementLength-1]&1 ^ b[0]&1
y.Select(otherRoot, y, int(cond))
p.x.Set(x)
p.y.Set(y)
p.z.One()
return p, nil
default:
return nil, errors.New("invalid P384 point encoding")
}
}
var _p384B *fiat.P384Element
var _p384BOnce sync.Once
func p384B() *fiat.P384Element {
_p384BOnce.Do(func() {
_p384B, _ = new(fiat.P384Element).SetBytes([]byte{0xb3, 0x31, 0x2f, 0xa7, 0xe2, 0x3e, 0xe7, 0xe4, 0x98, 0x8e, 0x5, 0x6b, 0xe3, 0xf8, 0x2d, 0x19, 0x18, 0x1d, 0x9c, 0x6e, 0xfe, 0x81, 0x41, 0x12, 0x3, 0x14, 0x8, 0x8f, 0x50, 0x13, 0x87, 0x5a, 0xc6, 0x56, 0x39, 0x8d, 0x8a, 0x2e, 0xd1, 0x9d, 0x2a, 0x85, 0xc8, 0xed, 0xd3, 0xec, 0x2a, 0xef})
})
return _p384B
}
// p384Polynomial sets y2 to x³ - 3x + b, and returns y2.
func p384Polynomial(y2, x *fiat.P384Element) *fiat.P384Element {
y2.Square(x)
y2.Mul(y2, x)
threeX := new(fiat.P384Element).Add(x, x)
threeX.Add(threeX, x)
y2.Sub(y2, threeX)
return y2.Add(y2, p384B())
}
func p384CheckOnCurve(x, y *fiat.P384Element) error {
// y² = x³ - 3x + b
rhs := p384Polynomial(new(fiat.P384Element), x)
lhs := new(fiat.P384Element).Square(y)
if rhs.Equal(lhs) != 1 {
return errors.New("P384 point not on curve")
}
return nil
}
// Bytes returns the uncompressed or infinity encoding of p, as specified in
// SEC 1, Version 2.0, Section 2.3.3. Note that the encoding of the point at
// infinity is shorter than all other encodings.
func (p *P384Point) Bytes() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var out [1 + 2*p384ElementLength]byte
return p.bytes(&out)
}
func (p *P384Point) bytes(out *[1 + 2*p384ElementLength]byte) []byte {
if p.z.IsZero() == 1 {
return append(out[:0], 0)
}
zinv := new(fiat.P384Element).Invert(p.z)
x := new(fiat.P384Element).Mul(p.x, zinv)
y := new(fiat.P384Element).Mul(p.y, zinv)
buf := append(out[:0], 4)
buf = append(buf, x.Bytes()...)
buf = append(buf, y.Bytes()...)
return buf
}
// BytesX returns the encoding of the x-coordinate of p, as specified in SEC 1,
// Version 2.0, Section 2.3.5, or an error if p is the point at infinity.
func (p *P384Point) BytesX() ([]byte, error) {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var out [p384ElementLength]byte
return p.bytesX(&out)
}
func (p *P384Point) bytesX(out *[p384ElementLength]byte) ([]byte, error) {
if p.z.IsZero() == 1 {
return nil, errors.New("P384 point is the point at infinity")
}
zinv := new(fiat.P384Element).Invert(p.z)
x := new(fiat.P384Element).Mul(p.x, zinv)
return append(out[:0], x.Bytes()...), nil
}
// BytesCompressed returns the compressed or infinity encoding of p, as
// specified in SEC 1, Version 2.0, Section 2.3.3. Note that the encoding of the
// point at infinity is shorter than all other encodings.
func (p *P384Point) BytesCompressed() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var out [1 + p384ElementLength]byte
return p.bytesCompressed(&out)
}
func (p *P384Point) bytesCompressed(out *[1 + p384ElementLength]byte) []byte {
if p.z.IsZero() == 1 {
return append(out[:0], 0)
}
zinv := new(fiat.P384Element).Invert(p.z)
x := new(fiat.P384Element).Mul(p.x, zinv)
y := new(fiat.P384Element).Mul(p.y, zinv)
// Encode the sign of the y coordinate (indicated by the least significant
// bit) as the encoding type (2 or 3).
buf := append(out[:0], 2)
buf[0] |= y.Bytes()[p384ElementLength-1] & 1
buf = append(buf, x.Bytes()...)
return buf
}
// Add sets q = p1 + p2, and returns q. The points may overlap.
func (q *P384Point) Add(p1, p2 *P384Point) *P384Point {
// Complete addition formula for a = -3 from "Complete addition formulas for
// prime order elliptic curves" (https://eprint.iacr.org/2015/1060), §A.2.
t0 := new(fiat.P384Element).Mul(p1.x, p2.x) // t0 := X1 * X2
t1 := new(fiat.P384Element).Mul(p1.y, p2.y) // t1 := Y1 * Y2
t2 := new(fiat.P384Element).Mul(p1.z, p2.z) // t2 := Z1 * Z2
t3 := new(fiat.P384Element).Add(p1.x, p1.y) // t3 := X1 + Y1
t4 := new(fiat.P384Element).Add(p2.x, p2.y) // t4 := X2 + Y2
t3.Mul(t3, t4) // t3 := t3 * t4
t4.Add(t0, t1) // t4 := t0 + t1
t3.Sub(t3, t4) // t3 := t3 - t4
t4.Add(p1.y, p1.z) // t4 := Y1 + Z1
x3 := new(fiat.P384Element).Add(p2.y, p2.z) // X3 := Y2 + Z2
t4.Mul(t4, x3) // t4 := t4 * X3
x3.Add(t1, t2) // X3 := t1 + t2
t4.Sub(t4, x3) // t4 := t4 - X3
x3.Add(p1.x, p1.z) // X3 := X1 + Z1
y3 := new(fiat.P384Element).Add(p2.x, p2.z) // Y3 := X2 + Z2
x3.Mul(x3, y3) // X3 := X3 * Y3
y3.Add(t0, t2) // Y3 := t0 + t2
y3.Sub(x3, y3) // Y3 := X3 - Y3
z3 := new(fiat.P384Element).Mul(p384B(), t2) // Z3 := b * t2
x3.Sub(y3, z3) // X3 := Y3 - Z3
z3.Add(x3, x3) // Z3 := X3 + X3
x3.Add(x3, z3) // X3 := X3 + Z3
z3.Sub(t1, x3) // Z3 := t1 - X3
x3.Add(t1, x3) // X3 := t1 + X3
y3.Mul(p384B(), y3) // Y3 := b * Y3
t1.Add(t2, t2) // t1 := t2 + t2
t2.Add(t1, t2) // t2 := t1 + t2
y3.Sub(y3, t2) // Y3 := Y3 - t2
y3.Sub(y3, t0) // Y3 := Y3 - t0
t1.Add(y3, y3) // t1 := Y3 + Y3
y3.Add(t1, y3) // Y3 := t1 + Y3
t1.Add(t0, t0) // t1 := t0 + t0
t0.Add(t1, t0) // t0 := t1 + t0
t0.Sub(t0, t2) // t0 := t0 - t2
t1.Mul(t4, y3) // t1 := t4 * Y3
t2.Mul(t0, y3) // t2 := t0 * Y3
y3.Mul(x3, z3) // Y3 := X3 * Z3
y3.Add(y3, t2) // Y3 := Y3 + t2
x3.Mul(t3, x3) // X3 := t3 * X3
x3.Sub(x3, t1) // X3 := X3 - t1
z3.Mul(t4, z3) // Z3 := t4 * Z3
t1.Mul(t3, t0) // t1 := t3 * t0
z3.Add(z3, t1) // Z3 := Z3 + t1
q.x.Set(x3)
q.y.Set(y3)
q.z.Set(z3)
return q
}
// Double sets q = p + p, and returns q. The points may overlap.
func (q *P384Point) Double(p *P384Point) *P384Point {
// Complete addition formula for a = -3 from "Complete addition formulas for
// prime order elliptic curves" (https://eprint.iacr.org/2015/1060), §A.2.
t0 := new(fiat.P384Element).Square(p.x) // t0 := X ^ 2
t1 := new(fiat.P384Element).Square(p.y) // t1 := Y ^ 2
t2 := new(fiat.P384Element).Square(p.z) // t2 := Z ^ 2
t3 := new(fiat.P384Element).Mul(p.x, p.y) // t3 := X * Y
t3.Add(t3, t3) // t3 := t3 + t3
z3 := new(fiat.P384Element).Mul(p.x, p.z) // Z3 := X * Z
z3.Add(z3, z3) // Z3 := Z3 + Z3
y3 := new(fiat.P384Element).Mul(p384B(), t2) // Y3 := b * t2
y3.Sub(y3, z3) // Y3 := Y3 - Z3
x3 := new(fiat.P384Element).Add(y3, y3) // X3 := Y3 + Y3
y3.Add(x3, y3) // Y3 := X3 + Y3
x3.Sub(t1, y3) // X3 := t1 - Y3
y3.Add(t1, y3) // Y3 := t1 + Y3
y3.Mul(x3, y3) // Y3 := X3 * Y3
x3.Mul(x3, t3) // X3 := X3 * t3
t3.Add(t2, t2) // t3 := t2 + t2
t2.Add(t2, t3) // t2 := t2 + t3
z3.Mul(p384B(), z3) // Z3 := b * Z3
z3.Sub(z3, t2) // Z3 := Z3 - t2
z3.Sub(z3, t0) // Z3 := Z3 - t0
t3.Add(z3, z3) // t3 := Z3 + Z3
z3.Add(z3, t3) // Z3 := Z3 + t3
t3.Add(t0, t0) // t3 := t0 + t0
t0.Add(t3, t0) // t0 := t3 + t0
t0.Sub(t0, t2) // t0 := t0 - t2
t0.Mul(t0, z3) // t0 := t0 * Z3
y3.Add(y3, t0) // Y3 := Y3 + t0
t0.Mul(p.y, p.z) // t0 := Y * Z
t0.Add(t0, t0) // t0 := t0 + t0
z3.Mul(t0, z3) // Z3 := t0 * Z3
x3.Sub(x3, z3) // X3 := X3 - Z3
z3.Mul(t0, t1) // Z3 := t0 * t1
z3.Add(z3, z3) // Z3 := Z3 + Z3
z3.Add(z3, z3) // Z3 := Z3 + Z3
q.x.Set(x3)
q.y.Set(y3)
q.z.Set(z3)
return q
}
// Select sets q to p1 if cond == 1, and to p2 if cond == 0.
func (q *P384Point) Select(p1, p2 *P384Point, cond int) *P384Point {
q.x.Select(p1.x, p2.x, cond)
q.y.Select(p1.y, p2.y, cond)
q.z.Select(p1.z, p2.z, cond)
return q
}
// A p384Table holds the first 15 multiples of a point at offset -1, so [1]P
// is at table[0], [15]P is at table[14], and [0]P is implicitly the identity
// point.
type p384Table [15]*P384Point
// Select selects the n-th multiple of the table base point into p. It works in
// constant time by iterating over every entry of the table. n must be in [0, 15].
func (table *p384Table) Select(p *P384Point, n uint8) {
if n >= 16 {
panic("nistec: internal error: p384Table called with out-of-bounds value")
}
p.Set(NewP384Point())
for i := uint8(1); i < 16; i++ {
cond := subtle.ConstantTimeByteEq(i, n)
p.Select(table[i-1], p, cond)
}
}
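// selectInt is a hypothetical scalar model of the constant-time pattern
// above: every table entry is read and a mask-based select keeps the match,
// so the memory access pattern is independent of the secret index n.
func selectInt(table *[15]int, n uint8) int {
out := 0 // n == 0 leaves the implicit zero entry selected
for i := uint8(1); i < 16; i++ {
cond := subtle.ConstantTimeByteEq(i, n) // 1 iff i == n, without branching
out = subtle.ConstantTimeSelect(cond, table[i-1], out)
}
return out
}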
// ScalarMult sets p = scalar * q, and returns p.
func (p *P384Point) ScalarMult(q *P384Point, scalar []byte) (*P384Point, error) {
// Compute a p384Table for the base point q. The explicit NewP384Point
// calls get inlined, letting the allocations live on the stack.
var table = p384Table{NewP384Point(), NewP384Point(), NewP384Point(),
NewP384Point(), NewP384Point(), NewP384Point(), NewP384Point(),
NewP384Point(), NewP384Point(), NewP384Point(), NewP384Point(),
NewP384Point(), NewP384Point(), NewP384Point(), NewP384Point()}
table[0].Set(q)
for i := 1; i < 15; i += 2 {
table[i].Double(table[i/2])
table[i+1].Add(table[i], q)
}
// Instead of doing the classic double-and-add chain, we do it with a
// four-bit window: we double four times, and then add [0-15]P.
t := NewP384Point()
p.Set(NewP384Point())
for i, byte := range scalar {
// No need to double on the first iteration, as p is the identity at
// this point, and [N]∞ = ∞.
if i != 0 {
p.Double(p)
p.Double(p)
p.Double(p)
p.Double(p)
}
windowValue := byte >> 4
table.Select(t, windowValue)
p.Add(p, t)
p.Double(p)
p.Double(p)
p.Double(p)
p.Double(p)
windowValue = byte & 0b1111
table.Select(t, windowValue)
p.Add(p, t)
}
return p, nil
}
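// scalarFromWindows is a hypothetical integer model of the loop above: each
// byte contributes two four-bit windows, and "double four times, then add a
// window" is a base-16 shift-and-add. For scalars short enough to fit in a
// uint64 it reconstructs the big-endian value, which is why the point
// analogue computes [scalar]q.
func scalarFromWindows(scalar []byte) uint64 {
var acc uint64
for _, b := range scalar {
acc = acc<<4 + uint64(b>>4)     // four doublings, then the high window
acc = acc<<4 + uint64(b&0b1111) // four doublings, then the low window
}
return acc
}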
var p384GeneratorTable *[p384ElementLength * 2]p384Table
var p384GeneratorTableOnce sync.Once
// generatorTable returns a sequence of p384Tables. The first table contains
// multiples of G. Each successive table is the previous table doubled four
// times.
func (p *P384Point) generatorTable() *[p384ElementLength * 2]p384Table {
p384GeneratorTableOnce.Do(func() {
p384GeneratorTable = new([p384ElementLength * 2]p384Table)
base := NewP384Point().SetGenerator()
for i := 0; i < p384ElementLength*2; i++ {
p384GeneratorTable[i][0] = NewP384Point().Set(base)
for j := 1; j < 15; j++ {
p384GeneratorTable[i][j] = NewP384Point().Add(p384GeneratorTable[i][j-1], base)
}
base.Double(base)
base.Double(base)
base.Double(base)
base.Double(base)
}
})
return p384GeneratorTable
}
// ScalarBaseMult sets p = scalar * B, where B is the canonical generator, and
// returns p.
func (p *P384Point) ScalarBaseMult(scalar []byte) (*P384Point, error) {
if len(scalar) != p384ElementLength {
return nil, errors.New("invalid scalar length")
}
tables := p.generatorTable()
// This is also a scalar multiplication with a four-bit window like in
// ScalarMult, but in this case the doublings are precomputed. The value
// [windowValue]G added at iteration k would normally get doubled
// (totIterations-k)×4 times, but with a larger precomputation we can
// instead add [2^((totIterations-k)×4)][windowValue]G and avoid the
// doublings between iterations.
t := NewP384Point()
p.Set(NewP384Point())
tableIndex := len(tables) - 1
for _, byte := range scalar {
windowValue := byte >> 4
tables[tableIndex].Select(t, windowValue)
p.Add(p, t)
tableIndex--
windowValue = byte & 0b1111
tables[tableIndex].Select(t, windowValue)
p.Add(p, t)
tableIndex--
}
return p, nil
}
// p384Sqrt sets e to a square root of x. If x is not a square, p384Sqrt returns
// false and e is unchanged. e and x can overlap.
func p384Sqrt(e, x *fiat.P384Element) (isSquare bool) {
candidate := new(fiat.P384Element)
p384SqrtCandidate(candidate, x)
square := new(fiat.P384Element).Square(candidate)
if square.Equal(x) != 1 {
return false
}
e.Set(candidate)
return true
}
// p384SqrtCandidate sets z to a square root candidate for x. z and x must not overlap.
func p384SqrtCandidate(z, x *fiat.P384Element) {
// Since p = 3 mod 4, exponentiation by (p + 1) / 4 yields a square root candidate.
//
// The sequence of 14 multiplications and 381 squarings is derived from the
// following addition chain generated with github.com/mmcloughlin/addchain v0.4.0.
//
// _10 = 2*1
// _11 = 1 + _10
// _110 = 2*_11
// _111 = 1 + _110
// _111000 = _111 << 3
// _111111 = _111 + _111000
// _1111110 = 2*_111111
// _1111111 = 1 + _1111110
// x12 = _1111110 << 5 + _111111
// x24 = x12 << 12 + x12
// x31 = x24 << 7 + _1111111
// x32 = 2*x31 + 1
// x63 = x32 << 31 + x31
// x126 = x63 << 63 + x63
// x252 = x126 << 126 + x126
// x255 = x252 << 3 + _111
// return ((x255 << 33 + x32) << 64 + 1) << 30
//
var t0 = new(fiat.P384Element)
var t1 = new(fiat.P384Element)
var t2 = new(fiat.P384Element)
z.Square(x)
z.Mul(x, z)
z.Square(z)
t0.Mul(x, z)
z.Square(t0)
for s := 1; s < 3; s++ {
z.Square(z)
}
t1.Mul(t0, z)
t2.Square(t1)
z.Mul(x, t2)
for s := 0; s < 5; s++ {
t2.Square(t2)
}
t1.Mul(t1, t2)
t2.Square(t1)
for s := 1; s < 12; s++ {
t2.Square(t2)
}
t1.Mul(t1, t2)
for s := 0; s < 7; s++ {
t1.Square(t1)
}
t1.Mul(z, t1)
z.Square(t1)
z.Mul(x, z)
t2.Square(z)
for s := 1; s < 31; s++ {
t2.Square(t2)
}
t1.Mul(t1, t2)
t2.Square(t1)
for s := 1; s < 63; s++ {
t2.Square(t2)
}
t1.Mul(t1, t2)
t2.Square(t1)
for s := 1; s < 126; s++ {
t2.Square(t2)
}
t1.Mul(t1, t2)
for s := 0; s < 3; s++ {
t1.Square(t1)
}
t0.Mul(t0, t1)
for s := 0; s < 33; s++ {
t0.Square(t0)
}
z.Mul(z, t0)
for s := 0; s < 64; s++ {
z.Square(z)
}
z.Mul(x, z)
for s := 0; s < 30; s++ {
z.Square(z)
}
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Code generated by generate.go. DO NOT EDIT.
package nistec
import (
"crypto/internal/nistec/fiat"
"crypto/subtle"
"errors"
"sync"
)
// p521ElementLength is the length of an element of the base or scalar field,
// which have the same byte length for all NIST P curves.
const p521ElementLength = 66
// P521Point is a P521 point. The zero value is NOT valid.
type P521Point struct {
// The point is represented in projective coordinates (X:Y:Z),
// where x = X/Z and y = Y/Z.
x, y, z *fiat.P521Element
}
// NewP521Point returns a new P521Point representing the point at infinity.
func NewP521Point() *P521Point {
return &P521Point{
x: new(fiat.P521Element),
y: new(fiat.P521Element).One(),
z: new(fiat.P521Element),
}
}
// SetGenerator sets p to the canonical generator and returns p.
func (p *P521Point) SetGenerator() *P521Point {
p.x.SetBytes([]byte{0x0, 0xc6, 0x85, 0x8e, 0x6, 0xb7, 0x4, 0x4, 0xe9, 0xcd, 0x9e, 0x3e, 0xcb, 0x66, 0x23, 0x95, 0xb4, 0x42, 0x9c, 0x64, 0x81, 0x39, 0x5, 0x3f, 0xb5, 0x21, 0xf8, 0x28, 0xaf, 0x60, 0x6b, 0x4d, 0x3d, 0xba, 0xa1, 0x4b, 0x5e, 0x77, 0xef, 0xe7, 0x59, 0x28, 0xfe, 0x1d, 0xc1, 0x27, 0xa2, 0xff, 0xa8, 0xde, 0x33, 0x48, 0xb3, 0xc1, 0x85, 0x6a, 0x42, 0x9b, 0xf9, 0x7e, 0x7e, 0x31, 0xc2, 0xe5, 0xbd, 0x66})
p.y.SetBytes([]byte{0x1, 0x18, 0x39, 0x29, 0x6a, 0x78, 0x9a, 0x3b, 0xc0, 0x4, 0x5c, 0x8a, 0x5f, 0xb4, 0x2c, 0x7d, 0x1b, 0xd9, 0x98, 0xf5, 0x44, 0x49, 0x57, 0x9b, 0x44, 0x68, 0x17, 0xaf, 0xbd, 0x17, 0x27, 0x3e, 0x66, 0x2c, 0x97, 0xee, 0x72, 0x99, 0x5e, 0xf4, 0x26, 0x40, 0xc5, 0x50, 0xb9, 0x1, 0x3f, 0xad, 0x7, 0x61, 0x35, 0x3c, 0x70, 0x86, 0xa2, 0x72, 0xc2, 0x40, 0x88, 0xbe, 0x94, 0x76, 0x9f, 0xd1, 0x66, 0x50})
p.z.One()
return p
}
// Set sets p = q and returns p.
func (p *P521Point) Set(q *P521Point) *P521Point {
p.x.Set(q.x)
p.y.Set(q.y)
p.z.Set(q.z)
return p
}
// SetBytes sets p to the compressed, uncompressed, or infinity value encoded in
// b, as specified in SEC 1, Version 2.0, Section 2.3.4. If the point is not on
// the curve, it returns nil and an error, and the receiver is unchanged.
// Otherwise, it returns p.
func (p *P521Point) SetBytes(b []byte) (*P521Point, error) {
switch {
// Point at infinity.
case len(b) == 1 && b[0] == 0:
return p.Set(NewP521Point()), nil
// Uncompressed form.
case len(b) == 1+2*p521ElementLength && b[0] == 4:
x, err := new(fiat.P521Element).SetBytes(b[1 : 1+p521ElementLength])
if err != nil {
return nil, err
}
y, err := new(fiat.P521Element).SetBytes(b[1+p521ElementLength:])
if err != nil {
return nil, err
}
if err := p521CheckOnCurve(x, y); err != nil {
return nil, err
}
p.x.Set(x)
p.y.Set(y)
p.z.One()
return p, nil
// Compressed form.
case len(b) == 1+p521ElementLength && (b[0] == 2 || b[0] == 3):
x, err := new(fiat.P521Element).SetBytes(b[1:])
if err != nil {
return nil, err
}
// y² = x³ - 3x + b
y := p521Polynomial(new(fiat.P521Element), x)
if !p521Sqrt(y, y) {
return nil, errors.New("invalid P521 compressed point encoding")
}
// Select the positive or negative root, as indicated by the least
// significant bit, based on the encoding type byte.
otherRoot := new(fiat.P521Element)
otherRoot.Sub(otherRoot, y)
cond := y.Bytes()[p521ElementLength-1]&1 ^ b[0]&1
y.Select(otherRoot, y, int(cond))
p.x.Set(x)
p.y.Set(y)
p.z.One()
return p, nil
default:
return nil, errors.New("invalid P521 point encoding")
}
}
var _p521B *fiat.P521Element
var _p521BOnce sync.Once
func p521B() *fiat.P521Element {
_p521BOnce.Do(func() {
_p521B, _ = new(fiat.P521Element).SetBytes([]byte{0x0, 0x51, 0x95, 0x3e, 0xb9, 0x61, 0x8e, 0x1c, 0x9a, 0x1f, 0x92, 0x9a, 0x21, 0xa0, 0xb6, 0x85, 0x40, 0xee, 0xa2, 0xda, 0x72, 0x5b, 0x99, 0xb3, 0x15, 0xf3, 0xb8, 0xb4, 0x89, 0x91, 0x8e, 0xf1, 0x9, 0xe1, 0x56, 0x19, 0x39, 0x51, 0xec, 0x7e, 0x93, 0x7b, 0x16, 0x52, 0xc0, 0xbd, 0x3b, 0xb1, 0xbf, 0x7, 0x35, 0x73, 0xdf, 0x88, 0x3d, 0x2c, 0x34, 0xf1, 0xef, 0x45, 0x1f, 0xd4, 0x6b, 0x50, 0x3f, 0x0})
})
return _p521B
}
// p521Polynomial sets y2 to x³ - 3x + b, and returns y2.
func p521Polynomial(y2, x *fiat.P521Element) *fiat.P521Element {
y2.Square(x)
y2.Mul(y2, x)
threeX := new(fiat.P521Element).Add(x, x)
threeX.Add(threeX, x)
y2.Sub(y2, threeX)
return y2.Add(y2, p521B())
}
func p521CheckOnCurve(x, y *fiat.P521Element) error {
// y² = x³ - 3x + b
rhs := p521Polynomial(new(fiat.P521Element), x)
lhs := new(fiat.P521Element).Square(y)
if rhs.Equal(lhs) != 1 {
return errors.New("P521 point not on curve")
}
return nil
}
// Bytes returns the uncompressed or infinity encoding of p, as specified in
// SEC 1, Version 2.0, Section 2.3.3. Note that the encoding of the point at
// infinity is shorter than all other encodings.
func (p *P521Point) Bytes() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var out [1 + 2*p521ElementLength]byte
return p.bytes(&out)
}
func (p *P521Point) bytes(out *[1 + 2*p521ElementLength]byte) []byte {
if p.z.IsZero() == 1 {
return append(out[:0], 0)
}
zinv := new(fiat.P521Element).Invert(p.z)
x := new(fiat.P521Element).Mul(p.x, zinv)
y := new(fiat.P521Element).Mul(p.y, zinv)
buf := append(out[:0], 4)
buf = append(buf, x.Bytes()...)
buf = append(buf, y.Bytes()...)
return buf
}
// BytesX returns the encoding of the x-coordinate of p, as specified in SEC 1,
// Version 2.0, Section 2.3.5, or an error if p is the point at infinity.
func (p *P521Point) BytesX() ([]byte, error) {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var out [p521ElementLength]byte
return p.bytesX(&out)
}
func (p *P521Point) bytesX(out *[p521ElementLength]byte) ([]byte, error) {
if p.z.IsZero() == 1 {
return nil, errors.New("P521 point is the point at infinity")
}
zinv := new(fiat.P521Element).Invert(p.z)
x := new(fiat.P521Element).Mul(p.x, zinv)
return append(out[:0], x.Bytes()...), nil
}
// BytesCompressed returns the compressed or infinity encoding of p, as
// specified in SEC 1, Version 2.0, Section 2.3.3. Note that the encoding of the
// point at infinity is shorter than all other encodings.
func (p *P521Point) BytesCompressed() []byte {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var out [1 + p521ElementLength]byte
return p.bytesCompressed(&out)
}
func (p *P521Point) bytesCompressed(out *[1 + p521ElementLength]byte) []byte {
if p.z.IsZero() == 1 {
return append(out[:0], 0)
}
zinv := new(fiat.P521Element).Invert(p.z)
x := new(fiat.P521Element).Mul(p.x, zinv)
y := new(fiat.P521Element).Mul(p.y, zinv)
// Encode the sign of the y coordinate (indicated by the least significant
// bit) as the encoding type (2 or 3).
buf := append(out[:0], 2)
buf[0] |= y.Bytes()[p521ElementLength-1] & 1
buf = append(buf, x.Bytes()...)
return buf
}
// Add sets q = p1 + p2, and returns q. The points may overlap.
func (q *P521Point) Add(p1, p2 *P521Point) *P521Point {
// Complete addition formula for a = -3 from "Complete addition formulas for
// prime order elliptic curves" (https://eprint.iacr.org/2015/1060), §A.2.
t0 := new(fiat.P521Element).Mul(p1.x, p2.x) // t0 := X1 * X2
t1 := new(fiat.P521Element).Mul(p1.y, p2.y) // t1 := Y1 * Y2
t2 := new(fiat.P521Element).Mul(p1.z, p2.z) // t2 := Z1 * Z2
t3 := new(fiat.P521Element).Add(p1.x, p1.y) // t3 := X1 + Y1
t4 := new(fiat.P521Element).Add(p2.x, p2.y) // t4 := X2 + Y2
t3.Mul(t3, t4) // t3 := t3 * t4
t4.Add(t0, t1) // t4 := t0 + t1
t3.Sub(t3, t4) // t3 := t3 - t4
t4.Add(p1.y, p1.z) // t4 := Y1 + Z1
x3 := new(fiat.P521Element).Add(p2.y, p2.z) // X3 := Y2 + Z2
t4.Mul(t4, x3) // t4 := t4 * X3
x3.Add(t1, t2) // X3 := t1 + t2
t4.Sub(t4, x3) // t4 := t4 - X3
x3.Add(p1.x, p1.z) // X3 := X1 + Z1
y3 := new(fiat.P521Element).Add(p2.x, p2.z) // Y3 := X2 + Z2
x3.Mul(x3, y3) // X3 := X3 * Y3
y3.Add(t0, t2) // Y3 := t0 + t2
y3.Sub(x3, y3) // Y3 := X3 - Y3
z3 := new(fiat.P521Element).Mul(p521B(), t2) // Z3 := b * t2
x3.Sub(y3, z3) // X3 := Y3 - Z3
z3.Add(x3, x3) // Z3 := X3 + X3
x3.Add(x3, z3) // X3 := X3 + Z3
z3.Sub(t1, x3) // Z3 := t1 - X3
x3.Add(t1, x3) // X3 := t1 + X3
y3.Mul(p521B(), y3) // Y3 := b * Y3
t1.Add(t2, t2) // t1 := t2 + t2
t2.Add(t1, t2) // t2 := t1 + t2
y3.Sub(y3, t2) // Y3 := Y3 - t2
y3.Sub(y3, t0) // Y3 := Y3 - t0
t1.Add(y3, y3) // t1 := Y3 + Y3
y3.Add(t1, y3) // Y3 := t1 + Y3
t1.Add(t0, t0) // t1 := t0 + t0
t0.Add(t1, t0) // t0 := t1 + t0
t0.Sub(t0, t2) // t0 := t0 - t2
t1.Mul(t4, y3) // t1 := t4 * Y3
t2.Mul(t0, y3) // t2 := t0 * Y3
y3.Mul(x3, z3) // Y3 := X3 * Z3
y3.Add(y3, t2) // Y3 := Y3 + t2
x3.Mul(t3, x3) // X3 := t3 * X3
x3.Sub(x3, t1) // X3 := X3 - t1
z3.Mul(t4, z3) // Z3 := t4 * Z3
t1.Mul(t3, t0) // t1 := t3 * t0
z3.Add(z3, t1) // Z3 := Z3 + t1
q.x.Set(x3)
q.y.Set(y3)
q.z.Set(z3)
return q
}
// Double sets q = p + p, and returns q. The points may overlap.
func (q *P521Point) Double(p *P521Point) *P521Point {
// Complete addition formula for a = -3 from "Complete addition formulas for
// prime order elliptic curves" (https://eprint.iacr.org/2015/1060), §A.2.
t0 := new(fiat.P521Element).Square(p.x) // t0 := X ^ 2
t1 := new(fiat.P521Element).Square(p.y) // t1 := Y ^ 2
t2 := new(fiat.P521Element).Square(p.z) // t2 := Z ^ 2
t3 := new(fiat.P521Element).Mul(p.x, p.y) // t3 := X * Y
t3.Add(t3, t3) // t3 := t3 + t3
z3 := new(fiat.P521Element).Mul(p.x, p.z) // Z3 := X * Z
z3.Add(z3, z3) // Z3 := Z3 + Z3
y3 := new(fiat.P521Element).Mul(p521B(), t2) // Y3 := b * t2
y3.Sub(y3, z3) // Y3 := Y3 - Z3
x3 := new(fiat.P521Element).Add(y3, y3) // X3 := Y3 + Y3
y3.Add(x3, y3) // Y3 := X3 + Y3
x3.Sub(t1, y3) // X3 := t1 - Y3
y3.Add(t1, y3) // Y3 := t1 + Y3
y3.Mul(x3, y3) // Y3 := X3 * Y3
x3.Mul(x3, t3) // X3 := X3 * t3
t3.Add(t2, t2) // t3 := t2 + t2
t2.Add(t2, t3) // t2 := t2 + t3
z3.Mul(p521B(), z3) // Z3 := b * Z3
z3.Sub(z3, t2) // Z3 := Z3 - t2
z3.Sub(z3, t0) // Z3 := Z3 - t0
t3.Add(z3, z3) // t3 := Z3 + Z3
z3.Add(z3, t3) // Z3 := Z3 + t3
t3.Add(t0, t0) // t3 := t0 + t0
t0.Add(t3, t0) // t0 := t3 + t0
t0.Sub(t0, t2) // t0 := t0 - t2
t0.Mul(t0, z3) // t0 := t0 * Z3
y3.Add(y3, t0) // Y3 := Y3 + t0
t0.Mul(p.y, p.z) // t0 := Y * Z
t0.Add(t0, t0) // t0 := t0 + t0
z3.Mul(t0, z3) // Z3 := t0 * Z3
x3.Sub(x3, z3) // X3 := X3 - Z3
z3.Mul(t0, t1) // Z3 := t0 * t1
z3.Add(z3, z3) // Z3 := Z3 + Z3
z3.Add(z3, z3) // Z3 := Z3 + Z3
q.x.Set(x3)
q.y.Set(y3)
q.z.Set(z3)
return q
}
// Select sets q to p1 if cond == 1, and to p2 if cond == 0.
func (q *P521Point) Select(p1, p2 *P521Point, cond int) *P521Point {
q.x.Select(p1.x, p2.x, cond)
q.y.Select(p1.y, p2.y, cond)
q.z.Select(p1.z, p2.z, cond)
return q
}
// A p521Table holds the first 15 multiples of a point at offset -1, so [1]P
// is at table[0], [15]P is at table[14], and [0]P is implicitly the identity
// point.
type p521Table [15]*P521Point
// Select selects the n-th multiple of the table base point into p. It works in
// constant time by iterating over every entry of the table. n must be in [0, 15].
func (table *p521Table) Select(p *P521Point, n uint8) {
if n >= 16 {
panic("nistec: internal error: p521Table called with out-of-bounds value")
}
p.Set(NewP521Point())
for i := uint8(1); i < 16; i++ {
cond := subtle.ConstantTimeByteEq(i, n)
p.Select(table[i-1], p, cond)
}
}
// ScalarMult sets p = scalar * q, and returns p.
func (p *P521Point) ScalarMult(q *P521Point, scalar []byte) (*P521Point, error) {
// Compute a p521Table for the base point q. The explicit NewP521Point
// calls get inlined, letting the allocations live on the stack.
var table = p521Table{NewP521Point(), NewP521Point(), NewP521Point(),
NewP521Point(), NewP521Point(), NewP521Point(), NewP521Point(),
NewP521Point(), NewP521Point(), NewP521Point(), NewP521Point(),
NewP521Point(), NewP521Point(), NewP521Point(), NewP521Point()}
table[0].Set(q)
for i := 1; i < 15; i += 2 {
table[i].Double(table[i/2])
table[i+1].Add(table[i], q)
}
// Instead of doing the classic double-and-add chain, we do it with a
// four-bit window: we double four times, and then add [0-15]P.
t := NewP521Point()
p.Set(NewP521Point())
for i, byte := range scalar {
// No need to double on the first iteration, as p is the identity at
// this point, and [N]∞ = ∞.
if i != 0 {
p.Double(p)
p.Double(p)
p.Double(p)
p.Double(p)
}
windowValue := byte >> 4
table.Select(t, windowValue)
p.Add(p, t)
p.Double(p)
p.Double(p)
p.Double(p)
p.Double(p)
windowValue = byte & 0b1111
table.Select(t, windowValue)
p.Add(p, t)
}
return p, nil
}
var p521GeneratorTable *[p521ElementLength * 2]p521Table
var p521GeneratorTableOnce sync.Once
// generatorTable returns a sequence of p521Tables. The first table contains
// multiples of G. Each successive table is the previous table doubled four
// times.
func (p *P521Point) generatorTable() *[p521ElementLength * 2]p521Table {
p521GeneratorTableOnce.Do(func() {
p521GeneratorTable = new([p521ElementLength * 2]p521Table)
base := NewP521Point().SetGenerator()
for i := 0; i < p521ElementLength*2; i++ {
p521GeneratorTable[i][0] = NewP521Point().Set(base)
for j := 1; j < 15; j++ {
p521GeneratorTable[i][j] = NewP521Point().Add(p521GeneratorTable[i][j-1], base)
}
base.Double(base)
base.Double(base)
base.Double(base)
base.Double(base)
}
})
return p521GeneratorTable
}
// ScalarBaseMult sets p = scalar * B, where B is the canonical generator, and
// returns p.
func (p *P521Point) ScalarBaseMult(scalar []byte) (*P521Point, error) {
if len(scalar) != p521ElementLength {
return nil, errors.New("invalid scalar length")
}
tables := p.generatorTable()
// This is also a scalar multiplication with a four-bit window like in
// ScalarMult, but in this case the doublings are precomputed. The value
// [windowValue]G added at iteration k would normally get doubled
// (totIterations-k)×4 times, but with a larger precomputation we can
// instead add [2^((totIterations-k)×4)][windowValue]G and avoid the
// doublings between iterations.
t := NewP521Point()
p.Set(NewP521Point())
tableIndex := len(tables) - 1
for _, byte := range scalar {
windowValue := byte >> 4
tables[tableIndex].Select(t, windowValue)
p.Add(p, t)
tableIndex--
windowValue = byte & 0b1111
tables[tableIndex].Select(t, windowValue)
p.Add(p, t)
tableIndex--
}
return p, nil
}
// p521Sqrt sets e to a square root of x. If x is not a square, p521Sqrt returns
// false and e is unchanged. e and x can overlap.
func p521Sqrt(e, x *fiat.P521Element) (isSquare bool) {
candidate := new(fiat.P521Element)
p521SqrtCandidate(candidate, x)
square := new(fiat.P521Element).Square(candidate)
if square.Equal(x) != 1 {
return false
}
e.Set(candidate)
return true
}
// p521SqrtCandidate sets z to a square root candidate for x. z and x must not overlap.
func p521SqrtCandidate(z, x *fiat.P521Element) {
// Since p = 3 mod 4, exponentiation by (p + 1) / 4 yields a square root candidate.
//
// For P-521, p = 2⁵²¹ - 1, so (p + 1) / 4 = 2⁵¹⁹ and the candidate is simply
// x^(2⁵¹⁹): 519 squarings and no multiplications. The (trivial) addition
// chain below was generated with github.com/mmcloughlin/addchain v0.4.0.
//
// return 1 << 519
//
z.Square(x)
for s := 1; s < 519; s++ {
z.Square(z)
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:generate go run gen.go -output md5block.go
// Package md5 implements the MD5 hash algorithm as defined in RFC 1321.
//
// MD5 is cryptographically broken and should not be used for secure
// applications.
package md5
import (
"crypto"
"encoding/binary"
"errors"
"hash"
)
func init() {
crypto.RegisterHash(crypto.MD5, New)
}
// The size of an MD5 checksum in bytes.
const Size = 16
// The blocksize of MD5 in bytes.
const BlockSize = 64
const (
init0 = 0x67452301
init1 = 0xEFCDAB89
init2 = 0x98BADCFE
init3 = 0x10325476
)
// digest represents the partial evaluation of a checksum.
type digest struct {
s [4]uint32
x [BlockSize]byte
nx int
len uint64
}
func (d *digest) Reset() {
d.s[0] = init0
d.s[1] = init1
d.s[2] = init2
d.s[3] = init3
d.nx = 0
d.len = 0
}
const (
magic = "md5\x01"
marshaledSize = len(magic) + 4*4 + BlockSize + 8
)
func (d *digest) MarshalBinary() ([]byte, error) {
b := make([]byte, 0, marshaledSize)
b = append(b, magic...)
b = binary.BigEndian.AppendUint32(b, d.s[0])
b = binary.BigEndian.AppendUint32(b, d.s[1])
b = binary.BigEndian.AppendUint32(b, d.s[2])
b = binary.BigEndian.AppendUint32(b, d.s[3])
b = append(b, d.x[:d.nx]...)
b = b[:len(b)+len(d.x)-d.nx] // already zero
b = binary.BigEndian.AppendUint64(b, d.len)
return b, nil
}
func (d *digest) UnmarshalBinary(b []byte) error {
if len(b) < len(magic) || string(b[:len(magic)]) != magic {
return errors.New("crypto/md5: invalid hash state identifier")
}
if len(b) != marshaledSize {
return errors.New("crypto/md5: invalid hash state size")
}
b = b[len(magic):]
b, d.s[0] = consumeUint32(b)
b, d.s[1] = consumeUint32(b)
b, d.s[2] = consumeUint32(b)
b, d.s[3] = consumeUint32(b)
b = b[copy(d.x[:], b):]
b, d.len = consumeUint64(b)
d.nx = int(d.len % BlockSize)
return nil
}
func consumeUint64(b []byte) ([]byte, uint64) {
return b[8:], binary.BigEndian.Uint64(b[0:8])
}
func consumeUint32(b []byte) ([]byte, uint32) {
return b[4:], binary.BigEndian.Uint32(b[0:4])
}
// New returns a new hash.Hash computing the MD5 checksum. The Hash also
// implements encoding.BinaryMarshaler and encoding.BinaryUnmarshaler to
// marshal and unmarshal the internal state of the hash.
func New() hash.Hash {
d := new(digest)
d.Reset()
return d
}
func (d *digest) Size() int { return Size }
func (d *digest) BlockSize() int { return BlockSize }
func (d *digest) Write(p []byte) (nn int, err error) {
// Note that we currently call block or blockGeneric
// directly (guarded using haveAsm) because this allows
// escape analysis to see that p and d don't escape.
nn = len(p)
d.len += uint64(nn)
if d.nx > 0 {
n := copy(d.x[d.nx:], p)
d.nx += n
if d.nx == BlockSize {
if haveAsm {
block(d, d.x[:])
} else {
blockGeneric(d, d.x[:])
}
d.nx = 0
}
p = p[n:]
}
if len(p) >= BlockSize {
n := len(p) &^ (BlockSize - 1)
if haveAsm {
block(d, p[:n])
} else {
blockGeneric(d, p[:n])
}
p = p[n:]
}
if len(p) > 0 {
d.nx = copy(d.x[:], p)
}
return
}
func (d *digest) Sum(in []byte) []byte {
// Make a copy of d so that caller can keep writing and summing.
d0 := *d
hash := d0.checkSum()
return append(in, hash[:]...)
}
func (d *digest) checkSum() [Size]byte {
// Append 0x80 to the end of the message and then append zeros
// until the length is a multiple of 56 bytes. Finally append
// 8 bytes representing the message length in bits.
//
// 1 byte end marker :: 0-63 padding bytes :: 8 byte length
tmp := [1 + 63 + 8]byte{0x80}
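// Note that d.len is a uint64, so 55-d.len wraps modulo 2⁶⁴ when d.len > 55;
// since 64 divides 2⁶⁴, the % 64 below still yields the unique pad in [0, 63]
// with len+1+pad ≡ 56 (mod 64).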
pad := (55 - d.len) % 64 // calculate number of padding bytes
binary.LittleEndian.PutUint64(tmp[1+pad:], d.len<<3) // append length in bits
d.Write(tmp[:1+pad+8])
// The previous write ensures that a whole number of
// blocks (i.e. a multiple of 64 bytes) have been hashed.
if d.nx != 0 {
panic("d.nx != 0")
}
var digest [Size]byte
binary.LittleEndian.PutUint32(digest[0:], d.s[0])
binary.LittleEndian.PutUint32(digest[4:], d.s[1])
binary.LittleEndian.PutUint32(digest[8:], d.s[2])
binary.LittleEndian.PutUint32(digest[12:], d.s[3])
return digest
}
// Sum returns the MD5 checksum of the data.
func Sum(data []byte) [Size]byte {
var d digest
d.Reset()
d.Write(data)
return d.checkSum()
}
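// exampleSum is a hypothetical sketch showing that the streaming API and the
// one-shot helper agree: Write may be called any number of times before Sum.
func exampleSum(data []byte) bool {
h := New()
h.Write(data) // hash.Hash writes never return an error
streamed := h.Sum(nil)
oneShot := Sum(data)
for i := range oneShot {
if streamed[i] != oneShot[i] {
return false
}
}
return true
}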
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Code generated by go run gen.go -output md5block.go; DO NOT EDIT.
package md5
import (
"encoding/binary"
"math/bits"
)
func blockGeneric(dig *digest, p []byte) {
// load state
a, b, c, d := dig.s[0], dig.s[1], dig.s[2], dig.s[3]
for i := 0; i <= len(p)-BlockSize; i += BlockSize {
// eliminate bounds checks on p
q := p[i:]
q = q[:BlockSize:BlockSize]
// save current state
aa, bb, cc, dd := a, b, c, d
// load input block
x0 := binary.LittleEndian.Uint32(q[4*0x0:])
x1 := binary.LittleEndian.Uint32(q[4*0x1:])
x2 := binary.LittleEndian.Uint32(q[4*0x2:])
x3 := binary.LittleEndian.Uint32(q[4*0x3:])
x4 := binary.LittleEndian.Uint32(q[4*0x4:])
x5 := binary.LittleEndian.Uint32(q[4*0x5:])
x6 := binary.LittleEndian.Uint32(q[4*0x6:])
x7 := binary.LittleEndian.Uint32(q[4*0x7:])
x8 := binary.LittleEndian.Uint32(q[4*0x8:])
x9 := binary.LittleEndian.Uint32(q[4*0x9:])
xa := binary.LittleEndian.Uint32(q[4*0xa:])
xb := binary.LittleEndian.Uint32(q[4*0xb:])
xc := binary.LittleEndian.Uint32(q[4*0xc:])
xd := binary.LittleEndian.Uint32(q[4*0xd:])
xe := binary.LittleEndian.Uint32(q[4*0xe:])
xf := binary.LittleEndian.Uint32(q[4*0xf:])
// round 1
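// Round 1 uses F(b,c,d) = (b AND c) OR (NOT b AND d); the branch-free form
// (((c^d)&b)^d) below is equivalent and needs one operation fewer.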
a = b + bits.RotateLeft32((((c^d)&b)^d)+a+x0+0xd76aa478, 7)
d = a + bits.RotateLeft32((((b^c)&a)^c)+d+x1+0xe8c7b756, 12)
c = d + bits.RotateLeft32((((a^b)&d)^b)+c+x2+0x242070db, 17)
b = c + bits.RotateLeft32((((d^a)&c)^a)+b+x3+0xc1bdceee, 22)
a = b + bits.RotateLeft32((((c^d)&b)^d)+a+x4+0xf57c0faf, 7)
d = a + bits.RotateLeft32((((b^c)&a)^c)+d+x5+0x4787c62a, 12)
c = d + bits.RotateLeft32((((a^b)&d)^b)+c+x6+0xa8304613, 17)
b = c + bits.RotateLeft32((((d^a)&c)^a)+b+x7+0xfd469501, 22)
a = b + bits.RotateLeft32((((c^d)&b)^d)+a+x8+0x698098d8, 7)
d = a + bits.RotateLeft32((((b^c)&a)^c)+d+x9+0x8b44f7af, 12)
c = d + bits.RotateLeft32((((a^b)&d)^b)+c+xa+0xffff5bb1, 17)
b = c + bits.RotateLeft32((((d^a)&c)^a)+b+xb+0x895cd7be, 22)
a = b + bits.RotateLeft32((((c^d)&b)^d)+a+xc+0x6b901122, 7)
d = a + bits.RotateLeft32((((b^c)&a)^c)+d+xd+0xfd987193, 12)
c = d + bits.RotateLeft32((((a^b)&d)^b)+c+xe+0xa679438e, 17)
b = c + bits.RotateLeft32((((d^a)&c)^a)+b+xf+0x49b40821, 22)
// round 2
a = b + bits.RotateLeft32((((b^c)&d)^c)+a+x1+0xf61e2562, 5)
d = a + bits.RotateLeft32((((a^b)&c)^b)+d+x6+0xc040b340, 9)
c = d + bits.RotateLeft32((((d^a)&b)^a)+c+xb+0x265e5a51, 14)
b = c + bits.RotateLeft32((((c^d)&a)^d)+b+x0+0xe9b6c7aa, 20)
a = b + bits.RotateLeft32((((b^c)&d)^c)+a+x5+0xd62f105d, 5)
d = a + bits.RotateLeft32((((a^b)&c)^b)+d+xa+0x02441453, 9)
c = d + bits.RotateLeft32((((d^a)&b)^a)+c+xf+0xd8a1e681, 14)
b = c + bits.RotateLeft32((((c^d)&a)^d)+b+x4+0xe7d3fbc8, 20)
a = b + bits.RotateLeft32((((b^c)&d)^c)+a+x9+0x21e1cde6, 5)
d = a + bits.RotateLeft32((((a^b)&c)^b)+d+xe+0xc33707d6, 9)
c = d + bits.RotateLeft32((((d^a)&b)^a)+c+x3+0xf4d50d87, 14)
b = c + bits.RotateLeft32((((c^d)&a)^d)+b+x8+0x455a14ed, 20)
a = b + bits.RotateLeft32((((b^c)&d)^c)+a+xd+0xa9e3e905, 5)
d = a + bits.RotateLeft32((((a^b)&c)^b)+d+x2+0xfcefa3f8, 9)
c = d + bits.RotateLeft32((((d^a)&b)^a)+c+x7+0x676f02d9, 14)
b = c + bits.RotateLeft32((((c^d)&a)^d)+b+xc+0x8d2a4c8a, 20)
// round 3
a = b + bits.RotateLeft32((b^c^d)+a+x5+0xfffa3942, 4)
d = a + bits.RotateLeft32((a^b^c)+d+x8+0x8771f681, 11)
c = d + bits.RotateLeft32((d^a^b)+c+xb+0x6d9d6122, 16)
b = c + bits.RotateLeft32((c^d^a)+b+xe+0xfde5380c, 23)
a = b + bits.RotateLeft32((b^c^d)+a+x1+0xa4beea44, 4)
d = a + bits.RotateLeft32((a^b^c)+d+x4+0x4bdecfa9, 11)
c = d + bits.RotateLeft32((d^a^b)+c+x7+0xf6bb4b60, 16)
b = c + bits.RotateLeft32((c^d^a)+b+xa+0xbebfbc70, 23)
a = b + bits.RotateLeft32((b^c^d)+a+xd+0x289b7ec6, 4)
d = a + bits.RotateLeft32((a^b^c)+d+x0+0xeaa127fa, 11)
c = d + bits.RotateLeft32((d^a^b)+c+x3+0xd4ef3085, 16)
b = c + bits.RotateLeft32((c^d^a)+b+x6+0x04881d05, 23)
a = b + bits.RotateLeft32((b^c^d)+a+x9+0xd9d4d039, 4)
d = a + bits.RotateLeft32((a^b^c)+d+xc+0xe6db99e5, 11)
c = d + bits.RotateLeft32((d^a^b)+c+xf+0x1fa27cf8, 16)
b = c + bits.RotateLeft32((c^d^a)+b+x2+0xc4ac5665, 23)
// round 4
a = b + bits.RotateLeft32((c^(b|^d))+a+x0+0xf4292244, 6)
d = a + bits.RotateLeft32((b^(a|^c))+d+x7+0x432aff97, 10)
c = d + bits.RotateLeft32((a^(d|^b))+c+xe+0xab9423a7, 15)
b = c + bits.RotateLeft32((d^(c|^a))+b+x5+0xfc93a039, 21)
a = b + bits.RotateLeft32((c^(b|^d))+a+xc+0x655b59c3, 6)
d = a + bits.RotateLeft32((b^(a|^c))+d+x3+0x8f0ccc92, 10)
c = d + bits.RotateLeft32((a^(d|^b))+c+xa+0xffeff47d, 15)
b = c + bits.RotateLeft32((d^(c|^a))+b+x1+0x85845dd1, 21)
a = b + bits.RotateLeft32((c^(b|^d))+a+x8+0x6fa87e4f, 6)
d = a + bits.RotateLeft32((b^(a|^c))+d+xf+0xfe2ce6e0, 10)
c = d + bits.RotateLeft32((a^(d|^b))+c+x6+0xa3014314, 15)
b = c + bits.RotateLeft32((d^(c|^a))+b+xd+0x4e0811a1, 21)
a = b + bits.RotateLeft32((c^(b|^d))+a+x4+0xf7537e82, 6)
d = a + bits.RotateLeft32((b^(a|^c))+d+xb+0xbd3af235, 10)
c = d + bits.RotateLeft32((a^(d|^b))+c+x2+0x2ad7d2bb, 15)
b = c + bits.RotateLeft32((d^(c|^a))+b+x9+0xeb86d391, 21)
// add saved state
a += aa
b += bb
c += cc
d += dd
}
// save state
dig.s[0], dig.s[1], dig.s[2], dig.s[3] = a, b, c, d
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package rand implements a cryptographically secure
// random number generator.
package rand
import "io"
// Reader is a global, shared instance of a cryptographically
// secure random number generator.
//
// On Linux, FreeBSD, Dragonfly, NetBSD and Solaris, Reader uses getrandom(2) if
// available, /dev/urandom otherwise.
// On OpenBSD and macOS, Reader uses getentropy(2).
// On other Unix-like systems, Reader reads from /dev/urandom.
// On Windows systems, Reader uses the RtlGenRandom API.
// On Wasm, Reader uses the Web Crypto API.
var Reader io.Reader
// Read is a helper function that calls Reader.Read using io.ReadFull.
// On return, n == len(b) if and only if err == nil.
func Read(b []byte) (n int, err error) {
return io.ReadFull(Reader, b)
}
// batched returns a function that calls f to populate a []byte by chunking it
// into subslices of, at most, readMax bytes.
func batched(f func([]byte) error, readMax int) func([]byte) error {
return func(out []byte) error {
for len(out) > 0 {
read := len(out)
if read > readMax {
read = readMax
}
if err := f(out[:read]); err != nil {
return err
}
out = out[read:]
}
return nil
}
}
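// exampleBatched is a hypothetical sketch: wrapping a source that can only
// service bounded requests (here at most 256 bytes per call, an arbitrary
// figure) so that buffers of any size are filled in chunks.
func exampleBatched(fill func([]byte) error, buf []byte) error {
chunked := batched(fill, 256)
return chunked(buf) // issues ⌈len(buf)/256⌉ calls to fill
}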
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build dragonfly || freebsd || linux || netbsd || solaris
package rand
import (
"internal/syscall/unix"
"runtime"
"syscall"
)
func init() {
var maxGetRandomRead int
switch runtime.GOOS {
case "linux", "android":
// Per the manpage:
// When reading from the urandom source, a maximum of 33554431 bytes
// is returned by a single call to getrandom() on systems where int
// has a size of 32 bits.
maxGetRandomRead = (1 << 25) - 1
case "dragonfly", "freebsd", "illumos", "netbsd", "solaris":
maxGetRandomRead = 1 << 8
default:
panic("no maximum specified for GetRandom")
}
altGetRandom = batched(getRandom, maxGetRandomRead)
}
// If the kernel is too old to support the getrandom(2) syscall,
// unix.GetRandom will immediately return ENOSYS and we will then fall back to
// reading from /dev/urandom in rand_unix.go. unix.GetRandom caches the ENOSYS
// result so we only suffer the syscall overhead once in this case.
// If the kernel supports the getrandom() syscall, unix.GetRandom will block
// until the kernel has sufficient randomness (as we don't use GRND_NONBLOCK).
// In this case, unix.GetRandom will not return an error.
func getRandom(p []byte) error {
n, err := unix.GetRandom(p, 0)
if err != nil {
return err
}
if n != len(p) {
return syscall.EIO
}
return nil
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix
// Unix cryptographically secure pseudorandom number
// generator.
package rand
import (
"crypto/internal/boring"
"errors"
"io"
"os"
"sync"
"sync/atomic"
"syscall"
"time"
)
const urandomDevice = "/dev/urandom"
func init() {
if boring.Enabled {
Reader = boring.RandReader
return
}
Reader = &reader{}
}
// A reader satisfies reads by reading from urandomDevice.
type reader struct {
f io.Reader
mu sync.Mutex
used atomic.Uint32 // Atomic: 0 - never used, 1 - used, but f == nil, 2 - used, and f != nil
}
// altGetRandom if non-nil specifies an OS-specific function to get
// urandom-style randomness.
var altGetRandom func([]byte) (err error)
func warnBlocked() {
println("crypto/rand: blocked for 60 seconds waiting to read random data from the kernel")
}
func (r *reader) Read(b []byte) (n int, err error) {
boring.Unreachable()
if r.used.CompareAndSwap(0, 1) {
// First use of randomness. Start timer to warn about
// being blocked on entropy not being available.
t := time.AfterFunc(time.Minute, warnBlocked)
defer t.Stop()
}
if altGetRandom != nil && altGetRandom(b) == nil {
return len(b), nil
}
if r.used.Load() != 2 {
r.mu.Lock()
if r.used.Load() != 2 {
f, err := os.Open(urandomDevice)
if err != nil {
r.mu.Unlock()
return 0, err
}
r.f = hideAgainReader{f}
r.used.Store(2)
}
r.mu.Unlock()
}
return io.ReadFull(r.f, b)
}
// hideAgainReader masks EAGAIN reads from /dev/urandom.
// See golang.org/issue/9205
type hideAgainReader struct {
r io.Reader
}
func (hr hideAgainReader) Read(p []byte) (n int, err error) {
n, err = hr.r.Read(p)
if errors.Is(err, syscall.EAGAIN) {
err = nil
}
return
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package rand
import (
"crypto/internal/randutil"
"errors"
"io"
"math/big"
)
// Prime returns a number of the given bit length that is prime with high probability.
// Prime will return an error for any error returned by rand.Read or if bits < 2.
func Prime(rand io.Reader, bits int) (*big.Int, error) {
if bits < 2 {
return nil, errors.New("crypto/rand: prime size must be at least 2-bit")
}
randutil.MaybeReadByte(rand)
b := uint(bits % 8)
if b == 0 {
b = 8
}
bytes := make([]byte, (bits+7)/8)
p := new(big.Int)
for {
if _, err := io.ReadFull(rand, bytes); err != nil {
return nil, err
}
// Clear bits in the first byte to make sure the candidate has a size <= bits.
bytes[0] &= uint8(int(1<<b) - 1)
// Don't let the value be too small, i.e., set the most significant two bits.
// Setting the top two bits, rather than just the top bit,
// means that when two of these values are multiplied together,
// the result isn't ever one bit short.
if b >= 2 {
bytes[0] |= 3 << (b - 2)
} else {
// Here b==1, because b cannot be zero.
bytes[0] |= 1
if len(bytes) > 1 {
bytes[1] |= 0x80
}
}
// Make the value odd since an even number this large certainly isn't prime.
bytes[len(bytes)-1] |= 1
p.SetBytes(bytes)
if p.ProbablyPrime(20) {
return p, nil
}
}
}
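// examplePrime is a hypothetical sketch: because Prime sets the two most
// significant bits, the product of two bits-sized primes always has exactly
// 2×bits bits, as an RSA modulus requires.
func examplePrime(bits int) (*big.Int, error) {
p, err := Prime(Reader, bits)
if err != nil {
return nil, err
}
q, err := Prime(Reader, bits)
if err != nil {
return nil, err
}
return new(big.Int).Mul(p, q), nil
}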
// Int returns a uniform random value in [0, max). It panics if max <= 0.
func Int(rand io.Reader, max *big.Int) (n *big.Int, err error) {
if max.Sign() <= 0 {
panic("crypto/rand: argument to Int is <= 0")
}
n = new(big.Int)
n.Sub(max, n.SetUint64(1))
// bitLen is the maximum bit length needed to encode a value < max.
bitLen := n.BitLen()
if bitLen == 0 {
// the only valid result is 0
return
}
// k is the maximum byte length needed to encode a value < max.
k := (bitLen + 7) / 8
// b is the number of bits in the most significant byte of max-1.
b := uint(bitLen % 8)
if b == 0 {
b = 8
}
bytes := make([]byte, k)
for {
_, err = io.ReadFull(rand, bytes)
if err != nil {
return nil, err
}
// Clear bits in the first byte to increase the probability
// that the candidate is < max.
bytes[0] &= uint8(int(1<<b) - 1)
n.SetBytes(bytes)
if n.Cmp(max) < 0 {
return
}
}
}
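// exampleInt is a hypothetical sketch of typical use. Because the top byte is
// masked above, each loop iteration succeeds with probability greater than
// 1/2, so the expected number of reads is below two.
func exampleInt() (*big.Int, error) {
// Uniform in [0, 6); add one for a standard six-sided die.
return Int(Reader, big.NewInt(6))
}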
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package rc4 implements RC4 encryption, as defined in Bruce Schneier's
// Applied Cryptography.
//
// RC4 is cryptographically broken and should not be used for secure
// applications.
package rc4
import (
"crypto/internal/alias"
"strconv"
)
// A Cipher is an instance of RC4 using a particular key.
type Cipher struct {
s [256]uint32
i, j uint8
}
type KeySizeError int
func (k KeySizeError) Error() string {
return "crypto/rc4: invalid key size " + strconv.Itoa(int(k))
}
// NewCipher creates and returns a new Cipher. The key argument should be the
// RC4 key, at least 1 byte and at most 256 bytes.
func NewCipher(key []byte) (*Cipher, error) {
k := len(key)
if k < 1 || k > 256 {
return nil, KeySizeError(k)
}
var c Cipher
for i := 0; i < 256; i++ {
c.s[i] = uint32(i)
}
var j uint8 = 0
for i := 0; i < 256; i++ {
j += uint8(c.s[i]) + key[i%k]
c.s[i], c.s[j] = c.s[j], c.s[i]
}
return &c, nil
}
// Reset zeros the key data and makes the Cipher unusable.
//
// Deprecated: Reset can't guarantee that the key will be entirely removed from
// the process's memory.
func (c *Cipher) Reset() {
for i := range c.s {
c.s[i] = 0
}
c.i, c.j = 0, 0
}
// XORKeyStream sets dst to the result of XORing src with the key stream.
// Dst and src must overlap entirely or not at all.
func (c *Cipher) XORKeyStream(dst, src []byte) {
if len(src) == 0 {
return
}
if alias.InexactOverlap(dst[:len(src)], src) {
panic("crypto/rc4: invalid buffer overlap")
}
i, j := c.i, c.j
_ = dst[len(src)-1]
dst = dst[:len(src)] // eliminate bounds check from loop
for k, v := range src {
i += 1
x := c.s[i]
j += uint8(x)
y := c.s[j]
c.s[i], c.s[j] = y, x
dst[k] = v ^ uint8(c.s[uint8(x+y)])
}
c.i, c.j = i, j
}
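// exampleRC4 is a hypothetical usage sketch. RC4 is a stream cipher, so
// decryption is the same operation performed with a cipher in the same
// starting state.
func exampleRC4(key, plaintext []byte) ([]byte, error) {
c, err := NewCipher(key)
if err != nil {
return nil, err // the key must be between 1 and 256 bytes
}
ciphertext := make([]byte, len(plaintext))
c.XORKeyStream(ciphertext, plaintext)
return ciphertext, nil
}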
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !boringcrypto
package rsa
import "crypto/internal/boring"
func boringPublicKey(*PublicKey) (*boring.PublicKeyRSA, error) {
panic("boringcrypto: not available")
}
func boringPrivateKey(*PrivateKey) (*boring.PrivateKeyRSA, error) {
panic("boringcrypto: not available")
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package rsa
import (
"crypto"
"crypto/internal/boring"
"crypto/internal/randutil"
"crypto/subtle"
"errors"
"io"
)
// This file implements encryption and decryption using PKCS #1 v1.5 padding.
// PKCS1v15DecryptOptions is for passing options to PKCS #1 v1.5 decryption using
// the crypto.Decrypter interface.
type PKCS1v15DecryptOptions struct {
// SessionKeyLen is the length of the session key that is being
// decrypted. If not zero, then a padding error during decryption will
// cause a random plaintext of this length to be returned rather than
// an error. These alternatives happen in constant time.
SessionKeyLen int
}
// EncryptPKCS1v15 encrypts the given message with RSA and the padding
// scheme from PKCS #1 v1.5. The message must be no longer than the
// length of the public modulus minus 11 bytes.
//
// The random parameter is used as a source of entropy to ensure that
// encrypting the same message twice doesn't result in the same
// ciphertext.
//
// WARNING: use of this function to encrypt plaintexts other than
// session keys is dangerous. Use RSA OAEP in new protocols.
func EncryptPKCS1v15(random io.Reader, pub *PublicKey, msg []byte) ([]byte, error) {
randutil.MaybeReadByte(random)
if err := checkPub(pub); err != nil {
return nil, err
}
k := pub.Size()
if len(msg) > k-11 {
return nil, ErrMessageTooLong
}
if boring.Enabled && random == boring.RandReader {
bkey, err := boringPublicKey(pub)
if err != nil {
return nil, err
}
return boring.EncryptRSAPKCS1(bkey, msg)
}
boring.UnreachableExceptTests()
// EM = 0x00 || 0x02 || PS || 0x00 || M
em := make([]byte, k)
em[1] = 2
ps, mm := em[2:len(em)-len(msg)-1], em[len(em)-len(msg):]
err := nonZeroRandomBytes(ps, random)
if err != nil {
return nil, err
}
em[len(em)-len(msg)-1] = 0
copy(mm, msg)
if boring.Enabled {
var bkey *boring.PublicKeyRSA
bkey, err = boringPublicKey(pub)
if err != nil {
return nil, err
}
return boring.EncryptRSANoPadding(bkey, em)
}
return encrypt(pub, em)
}
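// maxPKCS1v15MessageLen is a hypothetical helper making the limit above
// explicit: a k-byte modulus spends 3 bytes on framing (0x00, 0x02, and the
// 0x00 separator) plus at least 8 bytes of PS, leaving k-11 for the message.
func maxPKCS1v15MessageLen(pub *PublicKey) int {
return pub.Size() - 11
}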
// DecryptPKCS1v15 decrypts a plaintext using RSA and the padding scheme from PKCS #1 v1.5.
// The random parameter is legacy and ignored, and it can be nil.
//
// Note that whether this function returns an error or not discloses secret
// information. If an attacker can cause this function to run repeatedly and
// learn whether each instance returned an error then they can decrypt and
// forge signatures as if they had the private key. See
// DecryptPKCS1v15SessionKey for a way of solving this problem.
func DecryptPKCS1v15(random io.Reader, priv *PrivateKey, ciphertext []byte) ([]byte, error) {
if err := checkPub(&priv.PublicKey); err != nil {
return nil, err
}
if boring.Enabled {
bkey, err := boringPrivateKey(priv)
if err != nil {
return nil, err
}
out, err := boring.DecryptRSAPKCS1(bkey, ciphertext)
if err != nil {
return nil, ErrDecryption
}
return out, nil
}
valid, out, index, err := decryptPKCS1v15(priv, ciphertext)
if err != nil {
return nil, err
}
if valid == 0 {
return nil, ErrDecryption
}
return out[index:], nil
}
// DecryptPKCS1v15SessionKey decrypts a session key using RSA and the padding scheme from PKCS #1 v1.5.
// The random parameter is legacy and ignored, and it can be nil.
// It returns an error if the ciphertext is the wrong length or if the
// ciphertext is greater than the public modulus. Otherwise, no error is
// returned. If the padding is valid, the resulting plaintext message is copied
// into key. Otherwise, key is unchanged. These alternatives occur in constant
// time. It is intended that the user of this function generate a random
// session key beforehand and continue the protocol with the resulting value.
// This will remove any possibility that an attacker can learn any information
// about the plaintext.
// See “Chosen Ciphertext Attacks Against Protocols Based on the RSA
// Encryption Standard PKCS #1”, Daniel Bleichenbacher, Advances in Cryptology
// (Crypto '98).
//
// Note that if the session key is too small then it may be possible for an
// attacker to brute-force it. If they can do that then they can learn whether
// a random value was used (because it'll be different for the same ciphertext)
// and thus whether the padding was correct. This defeats the point of this
// function. Using at least a 16-byte key will protect against this attack.
func DecryptPKCS1v15SessionKey(random io.Reader, priv *PrivateKey, ciphertext []byte, key []byte) error {
if err := checkPub(&priv.PublicKey); err != nil {
return err
}
k := priv.Size()
if k-(len(key)+3+8) < 0 {
return ErrDecryption
}
valid, em, index, err := decryptPKCS1v15(priv, ciphertext)
if err != nil {
return err
}
if len(em) != k {
// This should be impossible because decryptPKCS1v15 always
// returns the full slice.
return ErrDecryption
}
valid &= subtle.ConstantTimeEq(int32(len(em)-index), int32(len(key)))
subtle.ConstantTimeCopy(valid, key, em[len(em)-len(key):])
return nil
}
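// Illustrative sketch of the pattern described above (editorial addition,
// not part of the original source), assuming crypto/rand, crypto/rsa, and io
// are imported and priv and ciphertext are in scope: generate a random
// fallback key first, then let DecryptPKCS1v15SessionKey overwrite it only
// when the padding is valid, so a padding failure is indistinguishable from
// success.
//
//	key := make([]byte, 32)
//	if _, err := io.ReadFull(rand.Reader, key); err != nil {
//		panic(err)
//	}
//	if err := rsa.DecryptPKCS1v15SessionKey(nil, priv, ciphertext, key); err != nil {
//		panic(err) // only for malformed input, never for bad padding
//	}
//	// key now holds either the real session key or the random fallback.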
// decryptPKCS1v15 decrypts ciphertext using priv. It returns one or zero in
// valid that indicates whether the plaintext was correctly structured.
// In either case, the plaintext is returned in em so that it may be read
// independently of whether it was valid in order to maintain constant memory
// access patterns. If the plaintext was valid then index contains the index of
// the original message in em, to allow constant time padding removal.
func decryptPKCS1v15(priv *PrivateKey, ciphertext []byte) (valid int, em []byte, index int, err error) {
k := priv.Size()
if k < 11 {
err = ErrDecryption
return
}
if boring.Enabled {
var bkey *boring.PrivateKeyRSA
bkey, err = boringPrivateKey(priv)
if err != nil {
return
}
em, err = boring.DecryptRSANoPadding(bkey, ciphertext)
if err != nil {
return
}
} else {
em, err = decrypt(priv, ciphertext, noCheck)
if err != nil {
return
}
}
firstByteIsZero := subtle.ConstantTimeByteEq(em[0], 0)
secondByteIsTwo := subtle.ConstantTimeByteEq(em[1], 2)
// The remainder of the plaintext must be a string of non-zero random
// octets, followed by a 0, followed by the message.
// lookingForIndex: 1 iff we are still looking for the zero.
// index: the offset of the first zero byte.
lookingForIndex := 1
for i := 2; i < len(em); i++ {
equals0 := subtle.ConstantTimeByteEq(em[i], 0)
index = subtle.ConstantTimeSelect(lookingForIndex&equals0, i, index)
lookingForIndex = subtle.ConstantTimeSelect(equals0, 0, lookingForIndex)
}
// The PS padding must be at least 8 bytes long, and it starts two
// bytes into em.
validPS := subtle.ConstantTimeLessOrEq(2+8, index)
valid = firstByteIsZero & secondByteIsTwo & (^lookingForIndex & 1) & validPS
index = subtle.ConstantTimeSelect(valid, index+1, 0)
return valid, em, index, nil
}
// nonZeroRandomBytes fills the given slice with non-zero random octets.
func nonZeroRandomBytes(s []byte, random io.Reader) (err error) {
_, err = io.ReadFull(random, s)
if err != nil {
return
}
for i := 0; i < len(s); i++ {
for s[i] == 0 {
_, err = io.ReadFull(random, s[i:i+1])
if err != nil {
return
}
// In tests, the PRNG may return all zeros so we do
// this to break the loop.
s[i] ^= 0x42
}
}
return
}
// These are ASN1 DER structures:
//
// DigestInfo ::= SEQUENCE {
// digestAlgorithm AlgorithmIdentifier,
// digest OCTET STRING
// }
//
// For performance, we don't use the generic ASN1 encoder. Rather, we
// precompute a prefix of the digest value that makes a valid ASN1 DER string
// with the correct contents.
var hashPrefixes = map[crypto.Hash][]byte{
crypto.MD5: {0x30, 0x20, 0x30, 0x0c, 0x06, 0x08, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x02, 0x05, 0x05, 0x00, 0x04, 0x10},
crypto.SHA1: {0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b, 0x0e, 0x03, 0x02, 0x1a, 0x05, 0x00, 0x04, 0x14},
crypto.SHA224: {0x30, 0x2d, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x04, 0x05, 0x00, 0x04, 0x1c},
crypto.SHA256: {0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05, 0x00, 0x04, 0x20},
crypto.SHA384: {0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02, 0x05, 0x00, 0x04, 0x30},
crypto.SHA512: {0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03, 0x05, 0x00, 0x04, 0x40},
crypto.MD5SHA1: {}, // A special TLS case which doesn't use an ASN1 prefix.
crypto.RIPEMD160: {0x30, 0x20, 0x30, 0x08, 0x06, 0x06, 0x28, 0xcf, 0x06, 0x03, 0x00, 0x31, 0x04, 0x14},
}
// SignPKCS1v15 calculates the signature of hashed using
// RSASSA-PKCS1-V1_5-SIGN from RSA PKCS #1 v1.5. Note that hashed must
// be the result of hashing the input message using the given hash
// function. If hash is zero, hashed is signed directly. This isn't
// advisable except for interoperability.
//
// The random parameter is legacy and ignored, and it can be nil.
//
// This function is deterministic. Thus, if the set of possible
// messages is small, an attacker may be able to build a map from
// messages to signatures and identify the signed messages. As ever,
// signatures provide authenticity, not confidentiality.
func SignPKCS1v15(random io.Reader, priv *PrivateKey, hash crypto.Hash, hashed []byte) ([]byte, error) {
hashLen, prefix, err := pkcs1v15HashInfo(hash, len(hashed))
if err != nil {
return nil, err
}
tLen := len(prefix) + hashLen
k := priv.Size()
if k < tLen+11 {
return nil, ErrMessageTooLong
}
if boring.Enabled {
bkey, err := boringPrivateKey(priv)
if err != nil {
return nil, err
}
return boring.SignRSAPKCS1v15(bkey, hash, hashed)
}
// EM = 0x00 || 0x01 || PS || 0x00 || T
em := make([]byte, k)
em[1] = 1
for i := 2; i < k-tLen-1; i++ {
em[i] = 0xff
}
copy(em[k-tLen:k-hashLen], prefix)
copy(em[k-hashLen:k], hashed)
return decrypt(priv, em, withCheck)
}
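// Illustrative usage sketch (editorial addition, not part of the original
// source), assuming crypto, crypto/rsa, and crypto/sha256 are imported and
// priv is a *rsa.PrivateKey: hashing a message and signing the digest.
//
//	digest := sha256.Sum256(message)
//	sig, err := rsa.SignPKCS1v15(nil, priv, crypto.SHA256, digest[:])
//	if err != nil {
//		panic(err)
//	}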
// VerifyPKCS1v15 verifies an RSA PKCS #1 v1.5 signature.
// hashed is the result of hashing the input message using the given hash
// function and sig is the signature. A valid signature is indicated by
// returning a nil error. If hash is zero then hashed is used directly. This
// isn't advisable except for interoperability.
func VerifyPKCS1v15(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte) error {
if boring.Enabled {
bkey, err := boringPublicKey(pub)
if err != nil {
return err
}
if err := boring.VerifyRSAPKCS1v15(bkey, hash, hashed, sig); err != nil {
return ErrVerification
}
return nil
}
hashLen, prefix, err := pkcs1v15HashInfo(hash, len(hashed))
if err != nil {
return err
}
tLen := len(prefix) + hashLen
k := pub.Size()
if k < tLen+11 {
return ErrVerification
}
// RFC 8017 Section 8.2.2: If the length of the signature S is not k
// octets (where k is the length in octets of the RSA modulus n), output
// "invalid signature" and stop.
if k != len(sig) {
return ErrVerification
}
em, err := encrypt(pub, sig)
if err != nil {
return ErrVerification
}
// EM = 0x00 || 0x01 || PS || 0x00 || T
ok := subtle.ConstantTimeByteEq(em[0], 0)
ok &= subtle.ConstantTimeByteEq(em[1], 1)
ok &= subtle.ConstantTimeCompare(em[k-hashLen:k], hashed)
ok &= subtle.ConstantTimeCompare(em[k-tLen:k-hashLen], prefix)
ok &= subtle.ConstantTimeByteEq(em[k-tLen-1], 0)
for i := 2; i < k-tLen-1; i++ {
ok &= subtle.ConstantTimeByteEq(em[i], 0xff)
}
if ok != 1 {
return ErrVerification
}
return nil
}
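// Illustrative usage sketch (editorial addition, not part of the original
// source), verifying the signature produced in the SignPKCS1v15 sketch
// above:
//
//	digest := sha256.Sum256(message)
//	if err := rsa.VerifyPKCS1v15(&priv.PublicKey, crypto.SHA256, digest[:], sig); err != nil {
//		// signature is invalid
//	}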
func pkcs1v15HashInfo(hash crypto.Hash, inLen int) (hashLen int, prefix []byte, err error) {
// Special case: crypto.Hash(0) is used to indicate that the data is
// signed directly.
if hash == 0 {
return inLen, nil, nil
}
hashLen = hash.Size()
if inLen != hashLen {
return 0, nil, errors.New("crypto/rsa: input must be hashed message")
}
prefix, ok := hashPrefixes[hash]
if !ok {
return 0, nil, errors.New("crypto/rsa: unsupported hash function")
}
return
}
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package rsa
// This file implements the RSASSA-PSS signature scheme according to RFC 8017.
import (
"bytes"
"crypto"
"crypto/internal/boring"
"errors"
"hash"
"io"
)
// Per RFC 8017, Section 9.1
//
// EM = MGF1 xor DB || H( 8*0x00 || mHash || salt ) || 0xbc
//
// where
//
// DB = PS || 0x01 || salt
//
// and PS can be empty so
//
// emLen = dbLen + hLen + 1 = psLen + sLen + hLen + 2
//
func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byte, error) {
// See RFC 8017, Section 9.1.1.
hLen := hash.Size()
sLen := len(salt)
emLen := (emBits + 7) / 8
// 1. If the length of M is greater than the input limitation for the
// hash function (2^61 - 1 octets for SHA-1), output "message too
// long" and stop.
//
// 2. Let mHash = Hash(M), an octet string of length hLen.
if len(mHash) != hLen {
return nil, errors.New("crypto/rsa: input must be hashed with given hash")
}
// 3. If emLen < hLen + sLen + 2, output "encoding error" and stop.
if emLen < hLen+sLen+2 {
return nil, ErrMessageTooLong
}
em := make([]byte, emLen)
psLen := emLen - sLen - hLen - 2
db := em[:psLen+1+sLen]
h := em[psLen+1+sLen : emLen-1]
// 4. Generate a random octet string salt of length sLen; if sLen = 0,
// then salt is the empty string.
//
// 5. Let
// M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt;
//
// M' is an octet string of length 8 + hLen + sLen with eight
// initial zero octets.
//
// 6. Let H = Hash(M'), an octet string of length hLen.
var prefix [8]byte
hash.Write(prefix[:])
hash.Write(mHash)
hash.Write(salt)
h = hash.Sum(h[:0])
hash.Reset()
// 7. Generate an octet string PS consisting of emLen - sLen - hLen - 2
// zero octets. The length of PS may be 0.
//
// 8. Let DB = PS || 0x01 || salt; DB is an octet string of length
// emLen - hLen - 1.
db[psLen] = 0x01
copy(db[psLen+1:], salt)
// 9. Let dbMask = MGF(H, emLen - hLen - 1).
//
// 10. Let maskedDB = DB \xor dbMask.
mgf1XOR(db, hash, h)
// 11. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in
// maskedDB to zero.
db[0] &= 0xff >> (8*emLen - emBits)
// 12. Let EM = maskedDB || H || 0xbc.
em[emLen-1] = 0xbc
// 13. Output EM.
return em, nil
}
func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error {
// See RFC 8017, Section 9.1.2.
hLen := hash.Size()
if sLen == PSSSaltLengthEqualsHash {
sLen = hLen
}
emLen := (emBits + 7) / 8
if emLen != len(em) {
return errors.New("rsa: internal error: inconsistent length")
}
// 1. If the length of M is greater than the input limitation for the
// hash function (2^61 - 1 octets for SHA-1), output "inconsistent"
// and stop.
//
// 2. Let mHash = Hash(M), an octet string of length hLen.
if hLen != len(mHash) {
return ErrVerification
}
// 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop.
if emLen < hLen+sLen+2 {
return ErrVerification
}
// 4. If the rightmost octet of EM does not have hexadecimal value
// 0xbc, output "inconsistent" and stop.
if em[emLen-1] != 0xbc {
return ErrVerification
}
// 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and
// let H be the next hLen octets.
db := em[:emLen-hLen-1]
h := em[emLen-hLen-1 : emLen-1]
// 6. If the leftmost 8 * emLen - emBits bits of the leftmost octet in
// maskedDB are not all equal to zero, output "inconsistent" and
// stop.
var bitMask byte = 0xff >> (8*emLen - emBits)
if em[0] & ^bitMask != 0 {
return ErrVerification
}
// 7. Let dbMask = MGF(H, emLen - hLen - 1).
//
// 8. Let DB = maskedDB \xor dbMask.
mgf1XOR(db, hash, h)
// 9. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB
// to zero.
db[0] &= bitMask
// If we don't know the salt length, look for the 0x01 delimiter.
if sLen == PSSSaltLengthAuto {
psLen := bytes.IndexByte(db, 0x01)
if psLen < 0 {
return ErrVerification
}
sLen = len(db) - psLen - 1
}
// 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero
// or if the octet at position emLen - hLen - sLen - 1 (the leftmost
// position is "position 1") does not have hexadecimal value 0x01,
// output "inconsistent" and stop.
psLen := emLen - hLen - sLen - 2
for _, e := range db[:psLen] {
if e != 0x00 {
return ErrVerification
}
}
if db[psLen] != 0x01 {
return ErrVerification
}
// 11. Let salt be the last sLen octets of DB.
salt := db[len(db)-sLen:]
// 12. Let
// M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ;
// M' is an octet string of length 8 + hLen + sLen with eight
// initial zero octets.
//
// 13. Let H' = Hash(M'), an octet string of length hLen.
var prefix [8]byte
hash.Write(prefix[:])
hash.Write(mHash)
hash.Write(salt)
h0 := hash.Sum(nil)
// 14. If H = H', output "consistent." Otherwise, output "inconsistent."
if !bytes.Equal(h0, h) { // TODO: constant time?
return ErrVerification
}
return nil
}
// signPSSWithSalt calculates the signature of hashed using PSS with specified salt.
// Note that hashed must be the result of hashing the input message using the
// given hash function. salt is a random sequence of bytes whose length will be
// later used to verify the signature.
func signPSSWithSalt(priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) ([]byte, error) {
emBits := priv.N.BitLen() - 1
em, err := emsaPSSEncode(hashed, emBits, salt, hash.New())
if err != nil {
return nil, err
}
if boring.Enabled {
bkey, err := boringPrivateKey(priv)
if err != nil {
return nil, err
}
// Note: BoringCrypto always does decrypt "withCheck".
// (It's not just decrypt.)
s, err := boring.DecryptRSANoPadding(bkey, em)
if err != nil {
return nil, err
}
return s, nil
}
// RFC 8017: "Note that the octet length of EM will be one less than k if
// modBits - 1 is divisible by 8 and equal to k otherwise, where k is the
// length in octets of the RSA modulus n." 🙄
//
// This is extremely annoying, as all other encrypt and decrypt inputs are
// always the exact same size as the modulus. Since it only happens for
// weird modulus sizes, fix it by padding inefficiently.
if emLen, k := len(em), priv.Size(); emLen < k {
emNew := make([]byte, k)
copy(emNew[k-emLen:], em)
em = emNew
}
return decrypt(priv, em, withCheck)
}
const (
// PSSSaltLengthAuto causes the salt in a PSS signature to be as large
// as possible when signing, and to be auto-detected when verifying.
PSSSaltLengthAuto = 0
// PSSSaltLengthEqualsHash causes the salt length to equal the length
// of the hash used in the signature.
PSSSaltLengthEqualsHash = -1
)
// PSSOptions contains options for creating and verifying PSS signatures.
type PSSOptions struct {
// SaltLength controls the length of the salt used in the PSS signature. It
// can either be a positive number of bytes, or one of the special
// PSSSaltLength constants.
SaltLength int
// Hash is the hash function used to generate the message digest. If not
// zero, it overrides the hash function passed to SignPSS. It's required
// when using PrivateKey.Sign.
Hash crypto.Hash
}
// HashFunc returns opts.Hash so that PSSOptions implements crypto.SignerOpts.
func (opts *PSSOptions) HashFunc() crypto.Hash {
return opts.Hash
}
func (opts *PSSOptions) saltLength() int {
if opts == nil {
return PSSSaltLengthAuto
}
return opts.SaltLength
}
var invalidSaltLenErr = errors.New("crypto/rsa: PSSOptions.SaltLength cannot be negative")
// SignPSS calculates the signature of digest using PSS.
//
// digest must be the result of hashing the input message using the given hash
// function. The opts argument may be nil, in which case sensible defaults are
// used. If opts.Hash is set, it overrides hash.
func SignPSS(rand io.Reader, priv *PrivateKey, hash crypto.Hash, digest []byte, opts *PSSOptions) ([]byte, error) {
if boring.Enabled && rand == boring.RandReader {
bkey, err := boringPrivateKey(priv)
if err != nil {
return nil, err
}
return boring.SignRSAPSS(bkey, hash, digest, opts.saltLength())
}
boring.UnreachableExceptTests()
if opts != nil && opts.Hash != 0 {
hash = opts.Hash
}
saltLength := opts.saltLength()
switch saltLength {
case PSSSaltLengthAuto:
saltLength = (priv.N.BitLen()-1+7)/8 - 2 - hash.Size()
if saltLength < 0 {
return nil, ErrMessageTooLong
}
case PSSSaltLengthEqualsHash:
saltLength = hash.Size()
default:
// If we get here saltLength is either > 0 or < -1, in the
// latter case we fail out.
if saltLength <= 0 {
return nil, invalidSaltLenErr
}
}
salt := make([]byte, saltLength)
if _, err := io.ReadFull(rand, salt); err != nil {
return nil, err
}
return signPSSWithSalt(priv, hash, digest, salt)
}
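// Illustrative usage sketch (editorial addition, not part of the original
// source), assuming crypto, crypto/rand, crypto/rsa, and crypto/sha256 are
// imported:
//
//	digest := sha256.Sum256(message)
//	sig, err := rsa.SignPSS(rand.Reader, priv, crypto.SHA256, digest[:],
//		&rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash})
//	if err != nil {
//		panic(err)
//	}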
// VerifyPSS verifies a PSS signature.
//
// A valid signature is indicated by returning a nil error. digest must be the
// result of hashing the input message using the given hash function. The opts
// argument may be nil, in which case sensible defaults are used. opts.Hash is
// ignored.
func VerifyPSS(pub *PublicKey, hash crypto.Hash, digest []byte, sig []byte, opts *PSSOptions) error {
if boring.Enabled {
bkey, err := boringPublicKey(pub)
if err != nil {
return err
}
if err := boring.VerifyRSAPSS(bkey, hash, digest, sig, opts.saltLength()); err != nil {
return ErrVerification
}
return nil
}
if len(sig) != pub.Size() {
return ErrVerification
}
// Salt length must be either one of the special constants (-1 or 0)
// or otherwise positive. If it is < PSSSaltLengthEqualsHash (-1)
// we return an error.
if opts.saltLength() < PSSSaltLengthEqualsHash {
return invalidSaltLenErr
}
emBits := pub.N.BitLen() - 1
emLen := (emBits + 7) / 8
em, err := encrypt(pub, sig)
if err != nil {
return ErrVerification
}
// Like in signPSSWithSalt, deal with mismatches between emLen and the size
// of the modulus. The spec would have us wire emLen into the encoding
// function, but we'd rather always encode to the size of the modulus and
// then strip leading zeroes if necessary. This only happens for weird
// modulus sizes anyway.
for len(em) > emLen && len(em) > 0 {
if em[0] != 0 {
return ErrVerification
}
em = em[1:]
}
return emsaPSSVerify(digest, em, emBits, opts.saltLength(), hash.New())
}
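// Illustrative usage sketch (editorial addition, not part of the original
// source), verifying the PSS signature from the SignPSS sketch above;
// PSSSaltLengthAuto lets the verifier detect the salt length from the
// signature itself:
//
//	err := rsa.VerifyPSS(&priv.PublicKey, crypto.SHA256, digest[:], sig,
//		&rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthAuto})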
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package rsa implements RSA encryption as specified in PKCS #1 and RFC 8017.
//
// RSA is a single, fundamental operation that is used in this package to
// implement either public-key encryption or public-key signatures.
//
// The original specification for encryption and signatures with RSA is PKCS #1
// and the terms "RSA encryption" and "RSA signatures" by default refer to
// PKCS #1 version 1.5. However, that specification has flaws and new designs
// should use version 2, usually called just OAEP and PSS, where possible.
//
// Two sets of interfaces are included in this package. When a more abstract
// interface isn't necessary, there are functions for encrypting/decrypting
// with v1.5/OAEP and signing/verifying with v1.5/PSS. If one needs to abstract
// over the public key primitive, the PrivateKey type implements the
// Decrypter and Signer interfaces from the crypto package.
//
// Operations in this package are implemented using constant-time algorithms,
// except for [GenerateKey], [PrivateKey.Precompute], and [PrivateKey.Validate].
// Every other operation only leaks the bit size of the involved values, which
// all depend on the selected key size.
package rsa
import (
"crypto"
"crypto/internal/bigmod"
"crypto/internal/boring"
"crypto/internal/boring/bbig"
"crypto/internal/randutil"
"crypto/rand"
"crypto/subtle"
"encoding/binary"
"errors"
"hash"
"io"
"math"
"math/big"
)
var bigOne = big.NewInt(1)
// A PublicKey represents the public part of an RSA key.
type PublicKey struct {
N *big.Int // modulus
E int // public exponent
}
// Any methods implemented on PublicKey might need to also be implemented on
// PrivateKey, as the latter embeds the former and will expose its methods.
// Size returns the modulus size in bytes. Raw signatures and ciphertexts
// for or by this public key will have the same size.
func (pub *PublicKey) Size() int {
return (pub.N.BitLen() + 7) / 8
}
// Equal reports whether pub and x have the same value.
func (pub *PublicKey) Equal(x crypto.PublicKey) bool {
xx, ok := x.(*PublicKey)
if !ok {
return false
}
return pub.N.Cmp(xx.N) == 0 && pub.E == xx.E
}
// OAEPOptions is used to pass options to OAEP decryption through the
// crypto.Decrypter interface.
type OAEPOptions struct {
// Hash is the hash function that will be used when generating the mask.
Hash crypto.Hash
// MGFHash is the hash function used for MGF1.
// If zero, Hash is used instead.
MGFHash crypto.Hash
// Label is an arbitrary byte string that must be equal to the value
// used when encrypting.
Label []byte
}
var (
errPublicModulus = errors.New("crypto/rsa: missing public modulus")
errPublicExponentSmall = errors.New("crypto/rsa: public exponent too small")
errPublicExponentLarge = errors.New("crypto/rsa: public exponent too large")
)
// checkPub sanity checks the public key before we use it.
// We require pub.E to fit into a 32-bit integer so that we
// do not have different behavior depending on whether
// int is 32 or 64 bits. See also
// https://www.imperialviolet.org/2012/03/16/rsae.html.
func checkPub(pub *PublicKey) error {
if pub.N == nil {
return errPublicModulus
}
if pub.E < 2 {
return errPublicExponentSmall
}
if pub.E > 1<<31-1 {
return errPublicExponentLarge
}
return nil
}
// A PrivateKey represents an RSA key.
type PrivateKey struct {
PublicKey // public part.
D *big.Int // private exponent
Primes []*big.Int // prime factors of N, has >= 2 elements.
// Precomputed contains precomputed values that speed up RSA operations,
// if available. It must be generated by calling PrivateKey.Precompute and
// must not be modified.
Precomputed PrecomputedValues
}
// Public returns the public key corresponding to priv.
func (priv *PrivateKey) Public() crypto.PublicKey {
return &priv.PublicKey
}
// Equal reports whether priv and x have equivalent values. It ignores
// Precomputed values.
func (priv *PrivateKey) Equal(x crypto.PrivateKey) bool {
xx, ok := x.(*PrivateKey)
if !ok {
return false
}
if !priv.PublicKey.Equal(&xx.PublicKey) || priv.D.Cmp(xx.D) != 0 {
return false
}
if len(priv.Primes) != len(xx.Primes) {
return false
}
for i := range priv.Primes {
if priv.Primes[i].Cmp(xx.Primes[i]) != 0 {
return false
}
}
return true
}
// Sign signs digest with priv, reading randomness from rand. If opts is a
// *PSSOptions then the PSS algorithm will be used, otherwise PKCS #1 v1.5 will
// be used. digest must be the result of hashing the input message using
// opts.HashFunc().
//
// This method implements crypto.Signer, which is an interface to support keys
// where the private part is kept in, for example, a hardware module. Common
// uses should use the Sign* functions in this package directly.
func (priv *PrivateKey) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
if pssOpts, ok := opts.(*PSSOptions); ok {
return SignPSS(rand, priv, pssOpts.Hash, digest, pssOpts)
}
return SignPKCS1v15(rand, priv, opts.HashFunc(), digest)
}
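// Illustrative sketch of the crypto.Signer usage described above (editorial
// addition, not part of the original source); note that Hash must be set in
// PSSOptions when signing through this interface:
//
//	var signer crypto.Signer = priv
//	digest := sha256.Sum256(message)
//	sig, err := signer.Sign(rand.Reader, digest[:], &rsa.PSSOptions{
//		SaltLength: rsa.PSSSaltLengthEqualsHash,
//		Hash:       crypto.SHA256,
//	})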
// Decrypt decrypts ciphertext with priv. If opts is nil or of type
// *PKCS1v15DecryptOptions then PKCS #1 v1.5 decryption is performed. Otherwise
// opts must have type *OAEPOptions and OAEP decryption is done.
func (priv *PrivateKey) Decrypt(rand io.Reader, ciphertext []byte, opts crypto.DecrypterOpts) (plaintext []byte, err error) {
if opts == nil {
return DecryptPKCS1v15(rand, priv, ciphertext)
}
switch opts := opts.(type) {
case *OAEPOptions:
if opts.MGFHash == 0 {
return decryptOAEP(opts.Hash.New(), opts.Hash.New(), rand, priv, ciphertext, opts.Label)
} else {
return decryptOAEP(opts.Hash.New(), opts.MGFHash.New(), rand, priv, ciphertext, opts.Label)
}
case *PKCS1v15DecryptOptions:
if l := opts.SessionKeyLen; l > 0 {
plaintext = make([]byte, l)
if _, err := io.ReadFull(rand, plaintext); err != nil {
return nil, err
}
if err := DecryptPKCS1v15SessionKey(rand, priv, ciphertext, plaintext); err != nil {
return nil, err
}
return plaintext, nil
} else {
return DecryptPKCS1v15(rand, priv, ciphertext)
}
default:
return nil, errors.New("crypto/rsa: invalid options for Decrypt")
}
}
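// Illustrative sketch of the crypto.Decrypter usage described above
// (editorial addition, not part of the original source): OAEP decryption
// with SHA-256 selected via OAEPOptions. The rand argument is ignored for
// OAEP, so nil is fine.
//
//	var decrypter crypto.Decrypter = priv
//	plaintext, err := decrypter.Decrypt(nil, ciphertext,
//		&rsa.OAEPOptions{Hash: crypto.SHA256})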
type PrecomputedValues struct {
Dp, Dq *big.Int // D mod (P-1) (or mod Q-1)
Qinv *big.Int // Q^-1 mod P
// CRTValues is used for the 3rd and subsequent primes. Due to a
// historical accident, the CRT for the first two primes is handled
// differently in PKCS #1 and interoperability is sufficiently
// important that we mirror this.
//
// Note: these values are still filled in by Precompute for
// backwards compatibility but are not used. Multi-prime RSA is very rare,
// and is implemented by this package without CRT optimizations to limit
// complexity.
CRTValues []CRTValue
n, p, q *bigmod.Modulus // moduli for CRT with Montgomery precomputed constants
}
// CRTValue contains the precomputed Chinese remainder theorem values.
type CRTValue struct {
Exp *big.Int // D mod (prime-1).
Coeff *big.Int // R·Coeff ≡ 1 mod Prime.
R *big.Int // product of primes prior to this (inc p and q).
}
// Validate performs basic sanity checks on the key.
// It returns nil if the key is valid, or else an error describing a problem.
func (priv *PrivateKey) Validate() error {
if err := checkPub(&priv.PublicKey); err != nil {
return err
}
// Check that Πprimes == n.
modulus := new(big.Int).Set(bigOne)
for _, prime := range priv.Primes {
// Any primes ≤ 1 will cause divide-by-zero panics later.
if prime.Cmp(bigOne) <= 0 {
return errors.New("crypto/rsa: invalid prime value")
}
modulus.Mul(modulus, prime)
}
if modulus.Cmp(priv.N) != 0 {
return errors.New("crypto/rsa: invalid modulus")
}
// Check that de ≡ 1 mod p-1, for each prime.
// This implies that e is coprime to each p-1 as e has a multiplicative
// inverse. Therefore e is coprime to lcm(p-1,q-1,r-1,...) =
// exponent(ℤ/nℤ). It also implies that a^de ≡ a mod p as a^(p-1) ≡ 1
// mod p. Thus a^de ≡ a mod n for all a coprime to n, as required.
congruence := new(big.Int)
de := new(big.Int).SetInt64(int64(priv.E))
de.Mul(de, priv.D)
for _, prime := range priv.Primes {
pminus1 := new(big.Int).Sub(prime, bigOne)
congruence.Mod(de, pminus1)
if congruence.Cmp(bigOne) != 0 {
return errors.New("crypto/rsa: invalid exponents")
}
}
return nil
}
// GenerateKey generates an RSA keypair of the given bit size using the
// random source random (for example, crypto/rand.Reader).
func GenerateKey(random io.Reader, bits int) (*PrivateKey, error) {
return GenerateMultiPrimeKey(random, 2, bits)
}
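// Illustrative usage sketch (editorial addition, not part of the original
// source), assuming crypto/rand and crypto/rsa are imported:
//
//	priv, err := rsa.GenerateKey(rand.Reader, 2048)
//	if err != nil {
//		panic(err)
//	}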
// GenerateMultiPrimeKey generates a multi-prime RSA keypair of the given bit
// size and the given random source.
//
// Table 1 in "[On the Security of Multi-prime RSA]" suggests maximum numbers of
// primes for a given bit size.
//
// Although the public keys are compatible with (indeed, indistinguishable
// from) the 2-prime case, the private keys are not. Thus it may not be
// possible to export multi-prime private keys in certain formats or to
// subsequently import them into other code.
//
// This package does not implement CRT optimizations for multi-prime RSA, so
// keys with more than two primes will have worse performance.
//
// Note: The use of this function with a number of primes different from
// two is not recommended for the above security, compatibility, and performance
// reasons. Use GenerateKey instead.
//
// [On the Security of Multi-prime RSA]: http://www.cacr.math.uwaterloo.ca/techreports/2006/cacr2006-16.pdf
func GenerateMultiPrimeKey(random io.Reader, nprimes int, bits int) (*PrivateKey, error) {
randutil.MaybeReadByte(random)
if boring.Enabled && random == boring.RandReader && nprimes == 2 && (bits == 2048 || bits == 3072) {
bN, bE, bD, bP, bQ, bDp, bDq, bQinv, err := boring.GenerateKeyRSA(bits)
if err != nil {
return nil, err
}
N := bbig.Dec(bN)
E := bbig.Dec(bE)
D := bbig.Dec(bD)
P := bbig.Dec(bP)
Q := bbig.Dec(bQ)
Dp := bbig.Dec(bDp)
Dq := bbig.Dec(bDq)
Qinv := bbig.Dec(bQinv)
e64 := E.Int64()
if !E.IsInt64() || int64(int(e64)) != e64 {
return nil, errors.New("crypto/rsa: generated key exponent too large")
}
key := &PrivateKey{
PublicKey: PublicKey{
N: N,
E: int(e64),
},
D: D,
Primes: []*big.Int{P, Q},
Precomputed: PrecomputedValues{
Dp: Dp,
Dq: Dq,
Qinv: Qinv,
CRTValues: make([]CRTValue, 0), // non-nil, to match Precompute
n: bigmod.NewModulusFromBig(N),
p: bigmod.NewModulusFromBig(P),
q: bigmod.NewModulusFromBig(Q),
},
}
return key, nil
}
priv := new(PrivateKey)
priv.E = 65537
if nprimes < 2 {
return nil, errors.New("crypto/rsa: GenerateMultiPrimeKey: nprimes must be >= 2")
}
if bits < 64 {
primeLimit := float64(uint64(1) << uint(bits/nprimes))
// pi approximates the number of primes less than primeLimit
pi := primeLimit / (math.Log(primeLimit) - 1)
// Generated primes start with 11 (in binary) so we can only
// use a quarter of them.
pi /= 4
// Use a factor of two to ensure that key generation terminates
// in a reasonable amount of time.
pi /= 2
if pi <= float64(nprimes) {
return nil, errors.New("crypto/rsa: too few primes of given length to generate an RSA key")
}
}
primes := make([]*big.Int, nprimes)
NextSetOfPrimes:
for {
todo := bits
// crypto/rand should set the top two bits in each prime.
// Thus each prime has the form
// p_i = 2^bitlen(p_i) × 0.11... (in base 2).
// And the product is:
// P = 2^todo × α
// where α is the product of nprimes numbers of the form 0.11...
//
// If α < 1/2 (which can happen for nprimes > 2), we need to
// shift todo to compensate for lost bits: the mean value of 0.11...
// is 7/8, so todo + shift - nprimes * log2(7/8) ~= bits - 1/2
// will give good results.
if nprimes >= 7 {
todo += (nprimes - 2) / 5
}
for i := 0; i < nprimes; i++ {
var err error
primes[i], err = rand.Prime(random, todo/(nprimes-i))
if err != nil {
return nil, err
}
todo -= primes[i].BitLen()
}
// Make sure that primes is pairwise unequal.
for i, prime := range primes {
for j := 0; j < i; j++ {
if prime.Cmp(primes[j]) == 0 {
continue NextSetOfPrimes
}
}
}
n := new(big.Int).Set(bigOne)
totient := new(big.Int).Set(bigOne)
pminus1 := new(big.Int)
for _, prime := range primes {
n.Mul(n, prime)
pminus1.Sub(prime, bigOne)
totient.Mul(totient, pminus1)
}
if n.BitLen() != bits {
// This should never happen for nprimes == 2 because
// crypto/rand should set the top two bits in each prime.
// For nprimes > 2 we hope it does not happen often.
continue NextSetOfPrimes
}
priv.D = new(big.Int)
e := big.NewInt(int64(priv.E))
ok := priv.D.ModInverse(e, totient)
if ok != nil {
priv.Primes = primes
priv.N = n
break
}
}
priv.Precompute()
return priv, nil
}
// incCounter increments a four byte, big-endian counter.
func incCounter(c *[4]byte) {
if c[3]++; c[3] != 0 {
return
}
if c[2]++; c[2] != 0 {
return
}
if c[1]++; c[1] != 0 {
return
}
c[0]++
}
// mgf1XOR XORs the bytes in out with a mask generated using the MGF1 function
// specified in PKCS #1 v2.1.
func mgf1XOR(out []byte, hash hash.Hash, seed []byte) {
var counter [4]byte
var digest []byte
done := 0
for done < len(out) {
hash.Write(seed)
hash.Write(counter[0:4])
digest = hash.Sum(digest[:0])
hash.Reset()
for i := 0; i < len(digest) && done < len(out); i++ {
out[done] ^= digest[i]
done++
}
incCounter(&counter)
}
}
// ErrMessageTooLong is returned when attempting to encrypt or sign a message
// which is too large for the size of the key. When using SignPSS, this can also
// be returned if the size of the salt is too large.
var ErrMessageTooLong = errors.New("crypto/rsa: message too long for RSA key size")
func encrypt(pub *PublicKey, plaintext []byte) ([]byte, error) {
boring.Unreachable()
N := bigmod.NewModulusFromBig(pub.N)
m, err := bigmod.NewNat().SetBytes(plaintext, N)
if err != nil {
return nil, err
}
e := intToBytes(pub.E)
return bigmod.NewNat().Exp(m, e, N).Bytes(N), nil
}
// intToBytes returns i as a big-endian slice of bytes with no leading zeroes,
// leaking only the bit size of i through timing side-channels.
func intToBytes(i int) []byte {
b := make([]byte, 8)
binary.BigEndian.PutUint64(b, uint64(i))
for len(b) > 1 && b[0] == 0 {
b = b[1:]
}
return b
}
// EncryptOAEP encrypts the given message with RSA-OAEP.
//
// OAEP is parameterised by a hash function that is used as a random oracle.
// Encryption and decryption of a given message must use the same hash function
// and sha256.New() is a reasonable choice.
//
// The random parameter is used as a source of entropy to ensure that
// encrypting the same message twice doesn't result in the same ciphertext.
//
// The label parameter may contain arbitrary data that will not be encrypted,
// but which gives important context to the message. For example, if a given
// public key is used to encrypt two types of messages then distinct label
// values could be used to ensure that a ciphertext for one purpose cannot be
// used for another by an attacker. If not required it can be empty.
//
// The message must be no longer than the length of the public modulus minus
// twice the hash length, minus a further 2.
func EncryptOAEP(hash hash.Hash, random io.Reader, pub *PublicKey, msg []byte, label []byte) ([]byte, error) {
if err := checkPub(pub); err != nil {
return nil, err
}
hash.Reset()
k := pub.Size()
if len(msg) > k-2*hash.Size()-2 {
return nil, ErrMessageTooLong
}
if boring.Enabled && random == boring.RandReader {
bkey, err := boringPublicKey(pub)
if err != nil {
return nil, err
}
return boring.EncryptRSAOAEP(hash, hash, bkey, msg, label)
}
boring.UnreachableExceptTests()
hash.Write(label)
lHash := hash.Sum(nil)
hash.Reset()
em := make([]byte, k)
seed := em[1 : 1+hash.Size()]
db := em[1+hash.Size():]
copy(db[0:hash.Size()], lHash)
db[len(db)-len(msg)-1] = 1
copy(db[len(db)-len(msg):], msg)
_, err := io.ReadFull(random, seed)
if err != nil {
return nil, err
}
mgf1XOR(db, hash, seed)
mgf1XOR(seed, hash, db)
if boring.Enabled {
var bkey *boring.PublicKeyRSA
bkey, err = boringPublicKey(pub)
if err != nil {
return nil, err
}
return boring.EncryptRSANoPadding(bkey, em)
}
return encrypt(pub, em)
}
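// Illustrative usage sketch (editorial addition, not part of the original
// source), assuming crypto/rand, crypto/rsa, and crypto/sha256 are imported;
// the label "orders" is an arbitrary example value:
//
//	ciphertext, err := rsa.EncryptOAEP(sha256.New(), rand.Reader,
//		&priv.PublicKey, []byte("secret"), []byte("orders"))
//	if err != nil {
//		panic(err)
//	}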
// ErrDecryption represents a failure to decrypt a message.
// It is deliberately vague to avoid adaptive attacks.
var ErrDecryption = errors.New("crypto/rsa: decryption error")
// ErrVerification represents a failure to verify a signature.
// It is deliberately vague to avoid adaptive attacks.
var ErrVerification = errors.New("crypto/rsa: verification error")
// Precompute performs some calculations that speed up private key operations
// in the future.
func (priv *PrivateKey) Precompute() {
if priv.Precomputed.n == nil && len(priv.Primes) == 2 {
priv.Precomputed.n = bigmod.NewModulusFromBig(priv.N)
priv.Precomputed.p = bigmod.NewModulusFromBig(priv.Primes[0])
priv.Precomputed.q = bigmod.NewModulusFromBig(priv.Primes[1])
}
// Fill in the backwards-compatibility *big.Int values.
if priv.Precomputed.Dp != nil {
return
}
priv.Precomputed.Dp = new(big.Int).Sub(priv.Primes[0], bigOne)
priv.Precomputed.Dp.Mod(priv.D, priv.Precomputed.Dp)
priv.Precomputed.Dq = new(big.Int).Sub(priv.Primes[1], bigOne)
priv.Precomputed.Dq.Mod(priv.D, priv.Precomputed.Dq)
priv.Precomputed.Qinv = new(big.Int).ModInverse(priv.Primes[1], priv.Primes[0])
r := new(big.Int).Mul(priv.Primes[0], priv.Primes[1])
priv.Precomputed.CRTValues = make([]CRTValue, len(priv.Primes)-2)
for i := 2; i < len(priv.Primes); i++ {
prime := priv.Primes[i]
values := &priv.Precomputed.CRTValues[i-2]
values.Exp = new(big.Int).Sub(prime, bigOne)
values.Exp.Mod(priv.D, values.Exp)
values.R = new(big.Int).Set(r)
values.Coeff = new(big.Int).ModInverse(r, prime)
r.Mul(r, prime)
}
}
const withCheck = true
const noCheck = false
// decrypt performs an RSA decryption of ciphertext into out. If check is true,
// m^e is calculated and compared with ciphertext, in order to defend against
// errors in the CRT computation.
func decrypt(priv *PrivateKey, ciphertext []byte, check bool) ([]byte, error) {
if len(priv.Primes) <= 2 {
boring.Unreachable()
}
var (
err error
m, c *bigmod.Nat
N *bigmod.Modulus
t0 = bigmod.NewNat()
)
if priv.Precomputed.n == nil {
N = bigmod.NewModulusFromBig(priv.N)
c, err = bigmod.NewNat().SetBytes(ciphertext, N)
if err != nil {
return nil, ErrDecryption
}
m = bigmod.NewNat().Exp(c, priv.D.Bytes(), N)
} else {
N = priv.Precomputed.n
P, Q := priv.Precomputed.p, priv.Precomputed.q
Qinv, err := bigmod.NewNat().SetBytes(priv.Precomputed.Qinv.Bytes(), P)
if err != nil {
return nil, ErrDecryption
}
c, err = bigmod.NewNat().SetBytes(ciphertext, N)
if err != nil {
return nil, ErrDecryption
}
// m = c ^ Dp mod p
m = bigmod.NewNat().Exp(t0.Mod(c, P), priv.Precomputed.Dp.Bytes(), P)
// m2 = c ^ Dq mod q
m2 := bigmod.NewNat().Exp(t0.Mod(c, Q), priv.Precomputed.Dq.Bytes(), Q)
// m = m - m2 mod p
m.Sub(t0.Mod(m2, P), P)
// m = m * Qinv mod p
m.Mul(Qinv, P)
// m = m * q mod N
m.ExpandFor(N).Mul(t0.Mod(Q.Nat(), N), N)
// m = m + m2 mod N
m.Add(m2.ExpandFor(N), N)
}
if check {
c1 := bigmod.NewNat().Exp(m, intToBytes(priv.E), N)
if c1.Equal(c) != 1 {
return nil, ErrDecryption
}
}
return m.Bytes(N), nil
}
// DecryptOAEP decrypts ciphertext using RSA-OAEP.
//
// OAEP is parameterised by a hash function that is used as a random oracle.
// Encryption and decryption of a given message must use the same hash function
// and sha256.New() is a reasonable choice.
//
// The random parameter is legacy and ignored, and it can be nil.
//
// The label parameter must match the value given when encrypting. See
// EncryptOAEP for details.
func DecryptOAEP(hash hash.Hash, random io.Reader, priv *PrivateKey, ciphertext []byte, label []byte) ([]byte, error) {
return decryptOAEP(hash, hash, random, priv, ciphertext, label)
}
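// Illustrative usage sketch (editorial addition, not part of the original
// source), decrypting the ciphertext from the EncryptOAEP sketch above; the
// hash and label must match the values used when encrypting:
//
//	plaintext, err := rsa.DecryptOAEP(sha256.New(), nil, priv,
//		ciphertext, []byte("orders"))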
func decryptOAEP(hash, mgfHash hash.Hash, random io.Reader, priv *PrivateKey, ciphertext []byte, label []byte) ([]byte, error) {
if err := checkPub(&priv.PublicKey); err != nil {
return nil, err
}
k := priv.Size()
if len(ciphertext) > k ||
k < hash.Size()*2+2 {
return nil, ErrDecryption
}
if boring.Enabled {
bkey, err := boringPrivateKey(priv)
if err != nil {
return nil, err
}
out, err := boring.DecryptRSAOAEP(hash, mgfHash, bkey, ciphertext, label)
if err != nil {
return nil, ErrDecryption
}
return out, nil
}
em, err := decrypt(priv, ciphertext, noCheck)
if err != nil {
return nil, err
}
hash.Write(label)
lHash := hash.Sum(nil)
hash.Reset()
firstByteIsZero := subtle.ConstantTimeByteEq(em[0], 0)
seed := em[1 : hash.Size()+1]
db := em[hash.Size()+1:]
mgf1XOR(seed, mgfHash, db)
mgf1XOR(db, mgfHash, seed)
lHash2 := db[0:hash.Size()]
// We have to validate the plaintext in constant time in order to avoid
// attacks like: J. Manger. A Chosen Ciphertext Attack on RSA Optimal
// Asymmetric Encryption Padding (OAEP) as Standardized in PKCS #1
// v2.0. In J. Kilian, editor, Advances in Cryptology.
lHash2Good := subtle.ConstantTimeCompare(lHash, lHash2)
// The remainder of the plaintext must be zero or more 0x00, followed
// by 0x01, followed by the message.
// lookingForIndex: 1 iff we are still looking for the 0x01
// index: the offset of the first 0x01 byte
// invalid: 1 iff we saw a non-zero byte before the 0x01.
var lookingForIndex, index, invalid int
lookingForIndex = 1
rest := db[hash.Size():]
for i := 0; i < len(rest); i++ {
equals0 := subtle.ConstantTimeByteEq(rest[i], 0)
equals1 := subtle.ConstantTimeByteEq(rest[i], 1)
index = subtle.ConstantTimeSelect(lookingForIndex&equals1, i, index)
lookingForIndex = subtle.ConstantTimeSelect(equals1, 0, lookingForIndex)
invalid = subtle.ConstantTimeSelect(lookingForIndex&^equals0, 1, invalid)
}
if firstByteIsZero&lHash2Good&^invalid&^lookingForIndex != 1 {
return nil, ErrDecryption
}
return rest[index+1:], nil
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Extra indirection here so that when building go_bootstrap
// cmd/internal/boring is not even imported, so that we don't
// have to maintain changes to cmd/dist's deps graph.
//go:build !cmd_go_bootstrap && cgo
// +build !cmd_go_bootstrap,cgo
package sha1
import (
"crypto/internal/boring"
"hash"
)
const boringEnabled = boring.Enabled
func boringNewSHA1() hash.Hash { return boring.NewSHA1() }
func boringUnreachable() { boring.Unreachable() }
func boringSHA1(p []byte) [20]byte { return boring.SHA1(p) }
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package sha1 implements the SHA-1 hash algorithm as defined in RFC 3174.
//
// SHA-1 is cryptographically broken and should not be used for secure
// applications.
package sha1
import (
"crypto"
"encoding/binary"
"errors"
"hash"
)
func init() {
crypto.RegisterHash(crypto.SHA1, New)
}
// The size of a SHA-1 checksum in bytes.
const Size = 20
// The blocksize of SHA-1 in bytes.
const BlockSize = 64
const (
chunk = 64
init0 = 0x67452301
init1 = 0xEFCDAB89
init2 = 0x98BADCFE
init3 = 0x10325476
init4 = 0xC3D2E1F0
)
// digest represents the partial evaluation of a checksum.
type digest struct {
h [5]uint32
x [chunk]byte
nx int
len uint64
}
const (
magic = "sha\x01"
marshaledSize = len(magic) + 5*4 + chunk + 8
)
func (d *digest) MarshalBinary() ([]byte, error) {
b := make([]byte, 0, marshaledSize)
b = append(b, magic...)
b = binary.BigEndian.AppendUint32(b, d.h[0])
b = binary.BigEndian.AppendUint32(b, d.h[1])
b = binary.BigEndian.AppendUint32(b, d.h[2])
b = binary.BigEndian.AppendUint32(b, d.h[3])
b = binary.BigEndian.AppendUint32(b, d.h[4])
b = append(b, d.x[:d.nx]...)
b = b[:len(b)+len(d.x)-d.nx] // already zero
b = binary.BigEndian.AppendUint64(b, d.len)
return b, nil
}
func (d *digest) UnmarshalBinary(b []byte) error {
if len(b) < len(magic) || string(b[:len(magic)]) != magic {
return errors.New("crypto/sha1: invalid hash state identifier")
}
if len(b) != marshaledSize {
return errors.New("crypto/sha1: invalid hash state size")
}
b = b[len(magic):]
b, d.h[0] = consumeUint32(b)
b, d.h[1] = consumeUint32(b)
b, d.h[2] = consumeUint32(b)
b, d.h[3] = consumeUint32(b)
b, d.h[4] = consumeUint32(b)
b = b[copy(d.x[:], b):]
b, d.len = consumeUint64(b)
d.nx = int(d.len % chunk)
return nil
}
func consumeUint64(b []byte) ([]byte, uint64) {
_ = b[7]
x := uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
return b[8:], x
}
func consumeUint32(b []byte) ([]byte, uint32) {
_ = b[3]
x := uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24
return b[4:], x
}
func (d *digest) Reset() {
d.h[0] = init0
d.h[1] = init1
d.h[2] = init2
d.h[3] = init3
d.h[4] = init4
d.nx = 0
d.len = 0
}
// New returns a new hash.Hash computing the SHA1 checksum. The Hash also
// implements encoding.BinaryMarshaler and encoding.BinaryUnmarshaler to
// marshal and unmarshal the internal state of the hash.
func New() hash.Hash {
if boringEnabled {
return boringNewSHA1()
}
d := new(digest)
d.Reset()
return d
}
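// Illustrative usage sketch (editorial addition, not part of the original
// source), assuming crypto/sha1, fmt, and io are imported: incremental
// hashing with the hash.Hash returned by New.
//
//	h := sha1.New()
//	io.WriteString(h, "hello, ")
//	io.WriteString(h, "world")
//	fmt.Printf("%x\n", h.Sum(nil)) // 20-byte digest, hex-encoded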
func (d *digest) Size() int { return Size }
func (d *digest) BlockSize() int { return BlockSize }
func (d *digest) Write(p []byte) (nn int, err error) {
boringUnreachable()
nn = len(p)
d.len += uint64(nn)
if d.nx > 0 {
n := copy(d.x[d.nx:], p)
d.nx += n
if d.nx == chunk {
block(d, d.x[:])
d.nx = 0
}
p = p[n:]
}
if len(p) >= chunk {
n := len(p) &^ (chunk - 1)
block(d, p[:n])
p = p[n:]
}
if len(p) > 0 {
d.nx = copy(d.x[:], p)
}
return
}
func (d *digest) Sum(in []byte) []byte {
boringUnreachable()
// Make a copy of d so that caller can keep writing and summing.
d0 := *d
hash := d0.checkSum()
return append(in, hash[:]...)
}
func (d *digest) checkSum() [Size]byte {
len := d.len
// Padding. Add a 1 bit and 0 bits until 56 bytes mod 64.
var tmp [64 + 8]byte // padding + length buffer
tmp[0] = 0x80
var t uint64
if len%64 < 56 {
t = 56 - len%64
} else {
t = 64 + 56 - len%64
}
// Length in bits.
len <<= 3
padlen := tmp[:t+8]
binary.BigEndian.PutUint64(padlen[t:], len)
d.Write(padlen)
if d.nx != 0 {
panic("d.nx != 0")
}
var digest [Size]byte
binary.BigEndian.PutUint32(digest[0:], d.h[0])
binary.BigEndian.PutUint32(digest[4:], d.h[1])
binary.BigEndian.PutUint32(digest[8:], d.h[2])
binary.BigEndian.PutUint32(digest[12:], d.h[3])
binary.BigEndian.PutUint32(digest[16:], d.h[4])
return digest
}
// ConstantTimeSum computes the same result as Sum() but in constant time
func (d *digest) ConstantTimeSum(in []byte) []byte {
d0 := *d
hash := d0.constSum()
return append(in, hash[:]...)
}
func (d *digest) constSum() [Size]byte {
var length [8]byte
l := d.len << 3
for i := uint(0); i < 8; i++ {
length[i] = byte(l >> (56 - 8*i))
}
nx := byte(d.nx)
t := nx - 56 // if nx < 56 then the MSB of t is one
mask1b := byte(int8(t) >> 7) // mask1b is 0xFF iff one block is enough
separator := byte(0x80) // gets reset to 0x00 once used
for i := byte(0); i < chunk; i++ {
mask := byte(int8(i-nx) >> 7) // 0x00 after the end of data
// if we reached the end of the data, replace with 0x80 or 0x00
d.x[i] = (^mask & separator) | (mask & d.x[i])
// zero the separator once used
separator &= mask
if i >= 56 {
// we might have to write the length here if all fit in one block
d.x[i] |= mask1b & length[i-56]
}
}
// compress, and only keep the digest if all fit in one block
block(d, d.x[:])
var digest [Size]byte
for i, s := range d.h {
digest[i*4] = mask1b & byte(s>>24)
digest[i*4+1] = mask1b & byte(s>>16)
digest[i*4+2] = mask1b & byte(s>>8)
digest[i*4+3] = mask1b & byte(s)
}
for i := byte(0); i < chunk; i++ {
// second block, it's always past the end of data, might start with 0x80
if i < 56 {
d.x[i] = separator
separator = 0
} else {
d.x[i] = length[i-56]
}
}
// compress, and only keep the digest if we actually needed the second block
block(d, d.x[:])
for i, s := range d.h {
digest[i*4] |= ^mask1b & byte(s>>24)
digest[i*4+1] |= ^mask1b & byte(s>>16)
digest[i*4+2] |= ^mask1b & byte(s>>8)
digest[i*4+3] |= ^mask1b & byte(s)
}
return digest
}
// Sum returns the SHA-1 checksum of the data.
func Sum(data []byte) [Size]byte {
if boringEnabled {
return boringSHA1(data)
}
var d digest
d.Reset()
d.Write(data)
return d.checkSum()
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sha1
import (
"math/bits"
)
const (
_K0 = 0x5A827999
_K1 = 0x6ED9EBA1
_K2 = 0x8F1BBCDC
_K3 = 0xCA62C1D6
)
// blockGeneric is a portable, pure Go version of the SHA-1 block step.
// It's used by sha1block_generic.go and tests.
func blockGeneric(dig *digest, p []byte) {
var w [16]uint32
h0, h1, h2, h3, h4 := dig.h[0], dig.h[1], dig.h[2], dig.h[3], dig.h[4]
for len(p) >= chunk {
// Can interlace the computation of w with the
// rounds below if needed for speed.
for i := 0; i < 16; i++ {
j := i * 4
w[i] = uint32(p[j])<<24 | uint32(p[j+1])<<16 | uint32(p[j+2])<<8 | uint32(p[j+3])
}
a, b, c, d, e := h0, h1, h2, h3, h4
// Each of the four 20-iteration rounds
// differs only in the computation of f and
// the choice of K (_K0, _K1, etc).
i := 0
for ; i < 16; i++ {
f := b&c | (^b)&d
t := bits.RotateLeft32(a, 5) + f + e + w[i&0xf] + _K0
a, b, c, d, e = t, a, bits.RotateLeft32(b, 30), c, d
}
for ; i < 20; i++ {
tmp := w[(i-3)&0xf] ^ w[(i-8)&0xf] ^ w[(i-14)&0xf] ^ w[(i)&0xf]
w[i&0xf] = bits.RotateLeft32(tmp, 1)
f := b&c | (^b)&d
t := bits.RotateLeft32(a, 5) + f + e + w[i&0xf] + _K0
a, b, c, d, e = t, a, bits.RotateLeft32(b, 30), c, d
}
for ; i < 40; i++ {
tmp := w[(i-3)&0xf] ^ w[(i-8)&0xf] ^ w[(i-14)&0xf] ^ w[(i)&0xf]
w[i&0xf] = bits.RotateLeft32(tmp, 1)
f := b ^ c ^ d
t := bits.RotateLeft32(a, 5) + f + e + w[i&0xf] + _K1
a, b, c, d, e = t, a, bits.RotateLeft32(b, 30), c, d
}
for ; i < 60; i++ {
tmp := w[(i-3)&0xf] ^ w[(i-8)&0xf] ^ w[(i-14)&0xf] ^ w[(i)&0xf]
w[i&0xf] = bits.RotateLeft32(tmp, 1)
f := ((b | c) & d) | (b & c)
t := bits.RotateLeft32(a, 5) + f + e + w[i&0xf] + _K2
a, b, c, d, e = t, a, bits.RotateLeft32(b, 30), c, d
}
for ; i < 80; i++ {
tmp := w[(i-3)&0xf] ^ w[(i-8)&0xf] ^ w[(i-14)&0xf] ^ w[(i)&0xf]
w[i&0xf] = bits.RotateLeft32(tmp, 1)
f := b ^ c ^ d
t := bits.RotateLeft32(a, 5) + f + e + w[i&0xf] + _K3
a, b, c, d, e = t, a, bits.RotateLeft32(b, 30), c, d
}
h0 += a
h1 += b
h2 += c
h3 += d
h4 += e
p = p[chunk:]
}
dig.h[0], dig.h[1], dig.h[2], dig.h[3], dig.h[4] = h0, h1, h2, h3, h4
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sha1
import "internal/cpu"
//go:noescape
func blockAVX2(dig *digest, p []byte)
//go:noescape
func blockAMD64(dig *digest, p []byte)
var useAVX2 = cpu.X86.HasAVX2 && cpu.X86.HasBMI1 && cpu.X86.HasBMI2
func block(dig *digest, p []byte) {
if useAVX2 && len(p) >= 256 {
// blockAVX2 calculates SHA-1 for two blocks per iteration and also
// interleaves the precalculation for the next block, so it may read up to
// 192 bytes past the end of p. We could add bounds checks inside blockAVX2,
// but that would just turn it into a copy of blockAMD64, so call it
// directly instead.
safeLen := len(p) - 128
if safeLen%128 != 0 {
safeLen -= 64
}
blockAVX2(dig, p[:safeLen])
blockAMD64(dig, p[safeLen:])
} else {
blockAMD64(dig, p)
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package sha256 implements the SHA224 and SHA256 hash algorithms as defined
// in FIPS 180-4.
package sha256
import (
"crypto"
"crypto/internal/boring"
"encoding/binary"
"errors"
"hash"
)
func init() {
crypto.RegisterHash(crypto.SHA224, New224)
crypto.RegisterHash(crypto.SHA256, New)
}
// The size of a SHA256 checksum in bytes.
const Size = 32
// The size of a SHA224 checksum in bytes.
const Size224 = 28
// The blocksize of SHA256 and SHA224 in bytes.
const BlockSize = 64
const (
chunk = 64
init0 = 0x6A09E667
init1 = 0xBB67AE85
init2 = 0x3C6EF372
init3 = 0xA54FF53A
init4 = 0x510E527F
init5 = 0x9B05688C
init6 = 0x1F83D9AB
init7 = 0x5BE0CD19
init0_224 = 0xC1059ED8
init1_224 = 0x367CD507
init2_224 = 0x3070DD17
init3_224 = 0xF70E5939
init4_224 = 0xFFC00B31
init5_224 = 0x68581511
init6_224 = 0x64F98FA7
init7_224 = 0xBEFA4FA4
)
// digest represents the partial evaluation of a checksum.
type digest struct {
h [8]uint32
x [chunk]byte
nx int
len uint64
is224 bool // mark if this digest is SHA-224
}
const (
magic224 = "sha\x02"
magic256 = "sha\x03"
marshaledSize = len(magic256) + 8*4 + chunk + 8
)
func (d *digest) MarshalBinary() ([]byte, error) {
b := make([]byte, 0, marshaledSize)
if d.is224 {
b = append(b, magic224...)
} else {
b = append(b, magic256...)
}
b = binary.BigEndian.AppendUint32(b, d.h[0])
b = binary.BigEndian.AppendUint32(b, d.h[1])
b = binary.BigEndian.AppendUint32(b, d.h[2])
b = binary.BigEndian.AppendUint32(b, d.h[3])
b = binary.BigEndian.AppendUint32(b, d.h[4])
b = binary.BigEndian.AppendUint32(b, d.h[5])
b = binary.BigEndian.AppendUint32(b, d.h[6])
b = binary.BigEndian.AppendUint32(b, d.h[7])
b = append(b, d.x[:d.nx]...)
b = b[:len(b)+len(d.x)-d.nx] // already zero
b = binary.BigEndian.AppendUint64(b, d.len)
return b, nil
}
func (d *digest) UnmarshalBinary(b []byte) error {
if len(b) < len(magic224) || (d.is224 && string(b[:len(magic224)]) != magic224) || (!d.is224 && string(b[:len(magic256)]) != magic256) {
return errors.New("crypto/sha256: invalid hash state identifier")
}
if len(b) != marshaledSize {
return errors.New("crypto/sha256: invalid hash state size")
}
b = b[len(magic224):]
b, d.h[0] = consumeUint32(b)
b, d.h[1] = consumeUint32(b)
b, d.h[2] = consumeUint32(b)
b, d.h[3] = consumeUint32(b)
b, d.h[4] = consumeUint32(b)
b, d.h[5] = consumeUint32(b)
b, d.h[6] = consumeUint32(b)
b, d.h[7] = consumeUint32(b)
b = b[copy(d.x[:], b):]
b, d.len = consumeUint64(b)
d.nx = int(d.len % chunk)
return nil
}
func consumeUint64(b []byte) ([]byte, uint64) {
_ = b[7]
x := uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
return b[8:], x
}
func consumeUint32(b []byte) ([]byte, uint32) {
_ = b[3]
x := uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24
return b[4:], x
}
func (d *digest) Reset() {
if !d.is224 {
d.h[0] = init0
d.h[1] = init1
d.h[2] = init2
d.h[3] = init3
d.h[4] = init4
d.h[5] = init5
d.h[6] = init6
d.h[7] = init7
} else {
d.h[0] = init0_224
d.h[1] = init1_224
d.h[2] = init2_224
d.h[3] = init3_224
d.h[4] = init4_224
d.h[5] = init5_224
d.h[6] = init6_224
d.h[7] = init7_224
}
d.nx = 0
d.len = 0
}
// New returns a new hash.Hash computing the SHA256 checksum. The Hash
// also implements encoding.BinaryMarshaler and
// encoding.BinaryUnmarshaler to marshal and unmarshal the internal
// state of the hash.
func New() hash.Hash {
if boring.Enabled {
return boring.NewSHA256()
}
d := new(digest)
d.Reset()
return d
}
// New224 returns a new hash.Hash computing the SHA224 checksum.
func New224() hash.Hash {
if boring.Enabled {
return boring.NewSHA224()
}
d := new(digest)
d.is224 = true
d.Reset()
return d
}
func (d *digest) Size() int {
if !d.is224 {
return Size
}
return Size224
}
func (d *digest) BlockSize() int { return BlockSize }
func (d *digest) Write(p []byte) (nn int, err error) {
boring.Unreachable()
nn = len(p)
d.len += uint64(nn)
if d.nx > 0 {
n := copy(d.x[d.nx:], p)
d.nx += n
if d.nx == chunk {
block(d, d.x[:])
d.nx = 0
}
p = p[n:]
}
if len(p) >= chunk {
n := len(p) &^ (chunk - 1)
block(d, p[:n])
p = p[n:]
}
if len(p) > 0 {
d.nx = copy(d.x[:], p)
}
return
}
func (d *digest) Sum(in []byte) []byte {
boring.Unreachable()
// Make a copy of d so that caller can keep writing and summing.
d0 := *d
hash := d0.checkSum()
if d0.is224 {
return append(in, hash[:Size224]...)
}
return append(in, hash[:]...)
}
func (d *digest) checkSum() [Size]byte {
len := d.len
// Padding. Add a 1 bit and 0 bits until 56 bytes mod 64.
var tmp [64 + 8]byte // padding + length buffer
tmp[0] = 0x80
var t uint64
if len%64 < 56 {
t = 56 - len%64
} else {
t = 64 + 56 - len%64
}
// Length in bits.
len <<= 3
padlen := tmp[:t+8]
binary.BigEndian.PutUint64(padlen[t+0:], len)
d.Write(padlen)
if d.nx != 0 {
panic("d.nx != 0")
}
var digest [Size]byte
binary.BigEndian.PutUint32(digest[0:], d.h[0])
binary.BigEndian.PutUint32(digest[4:], d.h[1])
binary.BigEndian.PutUint32(digest[8:], d.h[2])
binary.BigEndian.PutUint32(digest[12:], d.h[3])
binary.BigEndian.PutUint32(digest[16:], d.h[4])
binary.BigEndian.PutUint32(digest[20:], d.h[5])
binary.BigEndian.PutUint32(digest[24:], d.h[6])
if !d.is224 {
binary.BigEndian.PutUint32(digest[28:], d.h[7])
}
return digest
}
// Sum256 returns the SHA256 checksum of the data.
func Sum256(data []byte) [Size]byte {
if boring.Enabled {
return boring.SHA256(data)
}
var d digest
d.Reset()
d.Write(data)
return d.checkSum()
}
// Sum224 returns the SHA224 checksum of the data.
func Sum224(data []byte) [Size224]byte {
if boring.Enabled {
return boring.SHA224(data)
}
var d digest
d.is224 = true
d.Reset()
d.Write(data)
sum := d.checkSum()
ap := (*[Size224]byte)(sum[:])
return *ap
}
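// Illustrative usage sketch (editorial addition, not part of the original
// source), assuming crypto/sha256 and fmt are imported: one-shot hashing
// with Sum256.
//
//	sum := sha256.Sum256([]byte("hello world"))
//	fmt.Printf("%x\n", sum)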
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// SHA256 block step.
// In its own file so that a faster assembly or C version
// can be substituted easily.
package sha256
import "math/bits"
var _K = []uint32{
0x428a2f98,
0x71374491,
0xb5c0fbcf,
0xe9b5dba5,
0x3956c25b,
0x59f111f1,
0x923f82a4,
0xab1c5ed5,
0xd807aa98,
0x12835b01,
0x243185be,
0x550c7dc3,
0x72be5d74,
0x80deb1fe,
0x9bdc06a7,
0xc19bf174,
0xe49b69c1,
0xefbe4786,
0x0fc19dc6,
0x240ca1cc,
0x2de92c6f,
0x4a7484aa,
0x5cb0a9dc,
0x76f988da,
0x983e5152,
0xa831c66d,
0xb00327c8,
0xbf597fc7,
0xc6e00bf3,
0xd5a79147,
0x06ca6351,
0x14292967,
0x27b70a85,
0x2e1b2138,
0x4d2c6dfc,
0x53380d13,
0x650a7354,
0x766a0abb,
0x81c2c92e,
0x92722c85,
0xa2bfe8a1,
0xa81a664b,
0xc24b8b70,
0xc76c51a3,
0xd192e819,
0xd6990624,
0xf40e3585,
0x106aa070,
0x19a4c116,
0x1e376c08,
0x2748774c,
0x34b0bcb5,
0x391c0cb3,
0x4ed8aa4a,
0x5b9cca4f,
0x682e6ff3,
0x748f82ee,
0x78a5636f,
0x84c87814,
0x8cc70208,
0x90befffa,
0xa4506ceb,
0xbef9a3f7,
0xc67178f2,
}
func blockGeneric(dig *digest, p []byte) {
var w [64]uint32
h0, h1, h2, h3, h4, h5, h6, h7 := dig.h[0], dig.h[1], dig.h[2], dig.h[3], dig.h[4], dig.h[5], dig.h[6], dig.h[7]
for len(p) >= chunk {
// Can interlace the computation of w with the
// rounds below if needed for speed.
for i := 0; i < 16; i++ {
j := i * 4
w[i] = uint32(p[j])<<24 | uint32(p[j+1])<<16 | uint32(p[j+2])<<8 | uint32(p[j+3])
}
for i := 16; i < 64; i++ {
v1 := w[i-2]
t1 := (bits.RotateLeft32(v1, -17)) ^ (bits.RotateLeft32(v1, -19)) ^ (v1 >> 10)
v2 := w[i-15]
t2 := (bits.RotateLeft32(v2, -7)) ^ (bits.RotateLeft32(v2, -18)) ^ (v2 >> 3)
w[i] = t1 + w[i-7] + t2 + w[i-16]
}
a, b, c, d, e, f, g, h := h0, h1, h2, h3, h4, h5, h6, h7
for i := 0; i < 64; i++ {
t1 := h + ((bits.RotateLeft32(e, -6)) ^ (bits.RotateLeft32(e, -11)) ^ (bits.RotateLeft32(e, -25))) + ((e & f) ^ (^e & g)) + _K[i] + w[i]
t2 := ((bits.RotateLeft32(a, -2)) ^ (bits.RotateLeft32(a, -13)) ^ (bits.RotateLeft32(a, -22))) + ((a & b) ^ (a & c) ^ (b & c))
h = g
g = f
f = e
e = d + t1
d = c
c = b
b = a
a = t1 + t2
}
h0 += a
h1 += b
h2 += c
h3 += d
h4 += e
h5 += f
h6 += g
h7 += h
p = p[chunk:]
}
dig.h[0], dig.h[1], dig.h[2], dig.h[3], dig.h[4], dig.h[5], dig.h[6], dig.h[7] = h0, h1, h2, h3, h4, h5, h6, h7
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package sha512 implements the SHA-384, SHA-512, SHA-512/224, and SHA-512/256
// hash algorithms as defined in FIPS 180-4.
//
// All the hash.Hash implementations returned by this package also
// implement encoding.BinaryMarshaler and encoding.BinaryUnmarshaler to
// marshal and unmarshal the internal state of the hash.
package sha512
import (
"crypto"
"crypto/internal/boring"
"encoding/binary"
"errors"
"hash"
)
func init() {
crypto.RegisterHash(crypto.SHA384, New384)
crypto.RegisterHash(crypto.SHA512, New)
crypto.RegisterHash(crypto.SHA512_224, New512_224)
crypto.RegisterHash(crypto.SHA512_256, New512_256)
}
const (
// Size is the size, in bytes, of a SHA-512 checksum.
Size = 64
// Size224 is the size, in bytes, of a SHA-512/224 checksum.
Size224 = 28
// Size256 is the size, in bytes, of a SHA-512/256 checksum.
Size256 = 32
// Size384 is the size, in bytes, of a SHA-384 checksum.
Size384 = 48
// BlockSize is the block size, in bytes, of the SHA-512/224,
// SHA-512/256, SHA-384 and SHA-512 hash functions.
BlockSize = 128
)
const (
chunk = 128
init0 = 0x6a09e667f3bcc908
init1 = 0xbb67ae8584caa73b
init2 = 0x3c6ef372fe94f82b
init3 = 0xa54ff53a5f1d36f1
init4 = 0x510e527fade682d1
init5 = 0x9b05688c2b3e6c1f
init6 = 0x1f83d9abfb41bd6b
init7 = 0x5be0cd19137e2179
init0_224 = 0x8c3d37c819544da2
init1_224 = 0x73e1996689dcd4d6
init2_224 = 0x1dfab7ae32ff9c82
init3_224 = 0x679dd514582f9fcf
init4_224 = 0x0f6d2b697bd44da8
init5_224 = 0x77e36f7304c48942
init6_224 = 0x3f9d85a86a1d36c8
init7_224 = 0x1112e6ad91d692a1
init0_256 = 0x22312194fc2bf72c
init1_256 = 0x9f555fa3c84c64c2
init2_256 = 0x2393b86b6f53b151
init3_256 = 0x963877195940eabd
init4_256 = 0x96283ee2a88effe3
init5_256 = 0xbe5e1e2553863992
init6_256 = 0x2b0199fc2c85b8aa
init7_256 = 0x0eb72ddc81c52ca2
init0_384 = 0xcbbb9d5dc1059ed8
init1_384 = 0x629a292a367cd507
init2_384 = 0x9159015a3070dd17
init3_384 = 0x152fecd8f70e5939
init4_384 = 0x67332667ffc00b31
init5_384 = 0x8eb44a8768581511
init6_384 = 0xdb0c2e0d64f98fa7
init7_384 = 0x47b5481dbefa4fa4
)
// digest represents the partial evaluation of a checksum.
type digest struct {
h [8]uint64
x [chunk]byte
nx int
len uint64
function crypto.Hash
}
func (d *digest) Reset() {
switch d.function {
case crypto.SHA384:
d.h[0] = init0_384
d.h[1] = init1_384
d.h[2] = init2_384
d.h[3] = init3_384
d.h[4] = init4_384
d.h[5] = init5_384
d.h[6] = init6_384
d.h[7] = init7_384
case crypto.SHA512_224:
d.h[0] = init0_224
d.h[1] = init1_224
d.h[2] = init2_224
d.h[3] = init3_224
d.h[4] = init4_224
d.h[5] = init5_224
d.h[6] = init6_224
d.h[7] = init7_224
case crypto.SHA512_256:
d.h[0] = init0_256
d.h[1] = init1_256
d.h[2] = init2_256
d.h[3] = init3_256
d.h[4] = init4_256
d.h[5] = init5_256
d.h[6] = init6_256
d.h[7] = init7_256
default:
d.h[0] = init0
d.h[1] = init1
d.h[2] = init2
d.h[3] = init3
d.h[4] = init4
d.h[5] = init5
d.h[6] = init6
d.h[7] = init7
}
d.nx = 0
d.len = 0
}
const (
magic384 = "sha\x04"
magic512_224 = "sha\x05"
magic512_256 = "sha\x06"
magic512 = "sha\x07"
marshaledSize = len(magic512) + 8*8 + chunk + 8
)
func (d *digest) MarshalBinary() ([]byte, error) {
b := make([]byte, 0, marshaledSize)
switch d.function {
case crypto.SHA384:
b = append(b, magic384...)
case crypto.SHA512_224:
b = append(b, magic512_224...)
case crypto.SHA512_256:
b = append(b, magic512_256...)
case crypto.SHA512:
b = append(b, magic512...)
default:
return nil, errors.New("crypto/sha512: invalid hash function")
}
b = binary.BigEndian.AppendUint64(b, d.h[0])
b = binary.BigEndian.AppendUint64(b, d.h[1])
b = binary.BigEndian.AppendUint64(b, d.h[2])
b = binary.BigEndian.AppendUint64(b, d.h[3])
b = binary.BigEndian.AppendUint64(b, d.h[4])
b = binary.BigEndian.AppendUint64(b, d.h[5])
b = binary.BigEndian.AppendUint64(b, d.h[6])
b = binary.BigEndian.AppendUint64(b, d.h[7])
b = append(b, d.x[:d.nx]...)
b = b[:len(b)+len(d.x)-d.nx] // already zero
b = binary.BigEndian.AppendUint64(b, d.len)
return b, nil
}
func (d *digest) UnmarshalBinary(b []byte) error {
if len(b) < len(magic512) {
return errors.New("crypto/sha512: invalid hash state identifier")
}
switch {
case d.function == crypto.SHA384 && string(b[:len(magic384)]) == magic384:
case d.function == crypto.SHA512_224 && string(b[:len(magic512_224)]) == magic512_224:
case d.function == crypto.SHA512_256 && string(b[:len(magic512_256)]) == magic512_256:
case d.function == crypto.SHA512 && string(b[:len(magic512)]) == magic512:
default:
return errors.New("crypto/sha512: invalid hash state identifier")
}
if len(b) != marshaledSize {
return errors.New("crypto/sha512: invalid hash state size")
}
b = b[len(magic512):]
b, d.h[0] = consumeUint64(b)
b, d.h[1] = consumeUint64(b)
b, d.h[2] = consumeUint64(b)
b, d.h[3] = consumeUint64(b)
b, d.h[4] = consumeUint64(b)
b, d.h[5] = consumeUint64(b)
b, d.h[6] = consumeUint64(b)
b, d.h[7] = consumeUint64(b)
b = b[copy(d.x[:], b):]
b, d.len = consumeUint64(b)
d.nx = int(d.len % chunk)
return nil
}
func consumeUint64(b []byte) ([]byte, uint64) {
_ = b[7]
x := uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
return b[8:], x
}
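// The MarshalBinary/UnmarshalBinary pair above lets a caller checkpoint a
// running hash and resume it later. A hypothetical sketch (external caller
// code; imports crypto/sha512 and encoding):
//
//	h := sha512.New()
//	h.Write([]byte("part one, "))
//	state, _ := h.(encoding.BinaryMarshaler).MarshalBinary()
//
//	h2 := sha512.New()
//	_ = h2.(encoding.BinaryUnmarshaler).UnmarshalBinary(state)
//	h2.Write([]byte("part two"))
//	// h2.Sum(nil) equals hashing "part one, part two" in a single pass.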
// New returns a new hash.Hash computing the SHA-512 checksum.
func New() hash.Hash {
if boring.Enabled {
return boring.NewSHA512()
}
d := &digest{function: crypto.SHA512}
d.Reset()
return d
}
// New512_224 returns a new hash.Hash computing the SHA-512/224 checksum.
func New512_224() hash.Hash {
d := &digest{function: crypto.SHA512_224}
d.Reset()
return d
}
// New512_256 returns a new hash.Hash computing the SHA-512/256 checksum.
func New512_256() hash.Hash {
d := &digest{function: crypto.SHA512_256}
d.Reset()
return d
}
// New384 returns a new hash.Hash computing the SHA-384 checksum.
func New384() hash.Hash {
if boring.Enabled {
return boring.NewSHA384()
}
d := &digest{function: crypto.SHA384}
d.Reset()
return d
}
func (d *digest) Size() int {
switch d.function {
case crypto.SHA512_224:
return Size224
case crypto.SHA512_256:
return Size256
case crypto.SHA384:
return Size384
default:
return Size
}
}
func (d *digest) BlockSize() int { return BlockSize }
func (d *digest) Write(p []byte) (nn int, err error) {
if d.function != crypto.SHA512_224 && d.function != crypto.SHA512_256 {
boring.Unreachable()
}
nn = len(p)
d.len += uint64(nn)
if d.nx > 0 {
n := copy(d.x[d.nx:], p)
d.nx += n
if d.nx == chunk {
block(d, d.x[:])
d.nx = 0
}
p = p[n:]
}
if len(p) >= chunk {
n := len(p) &^ (chunk - 1)
block(d, p[:n])
p = p[n:]
}
if len(p) > 0 {
d.nx = copy(d.x[:], p)
}
return
}
func (d *digest) Sum(in []byte) []byte {
if d.function != crypto.SHA512_224 && d.function != crypto.SHA512_256 {
boring.Unreachable()
}
// Make a copy of d so that caller can keep writing and summing.
d0 := new(digest)
*d0 = *d
hash := d0.checkSum()
switch d0.function {
case crypto.SHA384:
return append(in, hash[:Size384]...)
case crypto.SHA512_224:
return append(in, hash[:Size224]...)
case crypto.SHA512_256:
return append(in, hash[:Size256]...)
default:
return append(in, hash[:]...)
}
}
func (d *digest) checkSum() [Size]byte {
// Padding. Add a 1 bit and 0 bits until 112 bytes mod 128.
len := d.len
var tmp [128 + 16]byte // padding + length buffer
tmp[0] = 0x80
var t uint64
if len%128 < 112 {
t = 112 - len%128
} else {
t = 128 + 112 - len%128
}
// Length in bits.
len <<= 3
padlen := tmp[:t+16]
// Upper 64 bits are always zero, because len variable has type uint64,
// and tmp is already zeroed at that index, so we can skip updating it.
// binary.BigEndian.PutUint64(padlen[t+0:], 0)
binary.BigEndian.PutUint64(padlen[t+8:], len)
d.Write(padlen)
if d.nx != 0 {
panic("d.nx != 0")
}
var digest [Size]byte
binary.BigEndian.PutUint64(digest[0:], d.h[0])
binary.BigEndian.PutUint64(digest[8:], d.h[1])
binary.BigEndian.PutUint64(digest[16:], d.h[2])
binary.BigEndian.PutUint64(digest[24:], d.h[3])
binary.BigEndian.PutUint64(digest[32:], d.h[4])
binary.BigEndian.PutUint64(digest[40:], d.h[5])
if d.function != crypto.SHA384 {
binary.BigEndian.PutUint64(digest[48:], d.h[6])
binary.BigEndian.PutUint64(digest[56:], d.h[7])
}
return digest
}
// Sum512 returns the SHA512 checksum of the data.
func Sum512(data []byte) [Size]byte {
if boring.Enabled {
return boring.SHA512(data)
}
d := digest{function: crypto.SHA512}
d.Reset()
d.Write(data)
return d.checkSum()
}
// Sum384 returns the SHA384 checksum of the data.
func Sum384(data []byte) [Size384]byte {
if boring.Enabled {
return boring.SHA384(data)
}
d := digest{function: crypto.SHA384}
d.Reset()
d.Write(data)
sum := d.checkSum()
ap := (*[Size384]byte)(sum[:])
return *ap
}
// Sum512_224 returns the SHA-512/224 checksum of the data.
func Sum512_224(data []byte) [Size224]byte {
d := digest{function: crypto.SHA512_224}
d.Reset()
d.Write(data)
sum := d.checkSum()
ap := (*[Size224]byte)(sum[:])
return *ap
}
// Sum512_256 returns the SHA-512/256 checksum of the data.
func Sum512_256(data []byte) [Size256]byte {
d := digest{function: crypto.SHA512_256}
d.Reset()
d.Write(data)
sum := d.checkSum()
ap := (*[Size256]byte)(sum[:])
return *ap
}
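// The streaming (New*) and one-shot (Sum*) APIs produce identical digests.
// A hypothetical sketch (external caller code; imports crypto/sha512, io,
// and strings):
//
//	h := sha512.New384()
//	io.Copy(h, strings.NewReader("abc"))
//	streamed := h.Sum(nil)                  // 48-byte SHA-384 digest
//	oneShot := sha512.Sum384([]byte("abc"))
//	// bytes.Equal(streamed, oneShot[:]) == true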
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// SHA512 block step.
// In its own file so that a faster assembly or C version
// can be substituted easily.
package sha512
import "math/bits"
var _K = []uint64{
0x428a2f98d728ae22,
0x7137449123ef65cd,
0xb5c0fbcfec4d3b2f,
0xe9b5dba58189dbbc,
0x3956c25bf348b538,
0x59f111f1b605d019,
0x923f82a4af194f9b,
0xab1c5ed5da6d8118,
0xd807aa98a3030242,
0x12835b0145706fbe,
0x243185be4ee4b28c,
0x550c7dc3d5ffb4e2,
0x72be5d74f27b896f,
0x80deb1fe3b1696b1,
0x9bdc06a725c71235,
0xc19bf174cf692694,
0xe49b69c19ef14ad2,
0xefbe4786384f25e3,
0x0fc19dc68b8cd5b5,
0x240ca1cc77ac9c65,
0x2de92c6f592b0275,
0x4a7484aa6ea6e483,
0x5cb0a9dcbd41fbd4,
0x76f988da831153b5,
0x983e5152ee66dfab,
0xa831c66d2db43210,
0xb00327c898fb213f,
0xbf597fc7beef0ee4,
0xc6e00bf33da88fc2,
0xd5a79147930aa725,
0x06ca6351e003826f,
0x142929670a0e6e70,
0x27b70a8546d22ffc,
0x2e1b21385c26c926,
0x4d2c6dfc5ac42aed,
0x53380d139d95b3df,
0x650a73548baf63de,
0x766a0abb3c77b2a8,
0x81c2c92e47edaee6,
0x92722c851482353b,
0xa2bfe8a14cf10364,
0xa81a664bbc423001,
0xc24b8b70d0f89791,
0xc76c51a30654be30,
0xd192e819d6ef5218,
0xd69906245565a910,
0xf40e35855771202a,
0x106aa07032bbd1b8,
0x19a4c116b8d2d0c8,
0x1e376c085141ab53,
0x2748774cdf8eeb99,
0x34b0bcb5e19b48a8,
0x391c0cb3c5c95a63,
0x4ed8aa4ae3418acb,
0x5b9cca4f7763e373,
0x682e6ff3d6b2b8a3,
0x748f82ee5defb2fc,
0x78a5636f43172f60,
0x84c87814a1f0ab72,
0x8cc702081a6439ec,
0x90befffa23631e28,
0xa4506cebde82bde9,
0xbef9a3f7b2c67915,
0xc67178f2e372532b,
0xca273eceea26619c,
0xd186b8c721c0c207,
0xeada7dd6cde0eb1e,
0xf57d4f7fee6ed178,
0x06f067aa72176fba,
0x0a637dc5a2c898a6,
0x113f9804bef90dae,
0x1b710b35131c471b,
0x28db77f523047d84,
0x32caab7b40c72493,
0x3c9ebe0a15c9bebc,
0x431d67c49c100d4c,
0x4cc5d4becb3e42b6,
0x597f299cfc657e2a,
0x5fcb6fab3ad6faec,
0x6c44198c4a475817,
}
func blockGeneric(dig *digest, p []byte) {
var w [80]uint64
h0, h1, h2, h3, h4, h5, h6, h7 := dig.h[0], dig.h[1], dig.h[2], dig.h[3], dig.h[4], dig.h[5], dig.h[6], dig.h[7]
for len(p) >= chunk {
for i := 0; i < 16; i++ {
j := i * 8
w[i] = uint64(p[j])<<56 | uint64(p[j+1])<<48 | uint64(p[j+2])<<40 | uint64(p[j+3])<<32 |
uint64(p[j+4])<<24 | uint64(p[j+5])<<16 | uint64(p[j+6])<<8 | uint64(p[j+7])
}
for i := 16; i < 80; i++ {
v1 := w[i-2]
t1 := bits.RotateLeft64(v1, -19) ^ bits.RotateLeft64(v1, -61) ^ (v1 >> 6)
v2 := w[i-15]
t2 := bits.RotateLeft64(v2, -1) ^ bits.RotateLeft64(v2, -8) ^ (v2 >> 7)
w[i] = t1 + w[i-7] + t2 + w[i-16]
}
a, b, c, d, e, f, g, h := h0, h1, h2, h3, h4, h5, h6, h7
for i := 0; i < 80; i++ {
t1 := h + (bits.RotateLeft64(e, -14) ^ bits.RotateLeft64(e, -18) ^ bits.RotateLeft64(e, -41)) + ((e & f) ^ (^e & g)) + _K[i] + w[i]
t2 := (bits.RotateLeft64(a, -28) ^ bits.RotateLeft64(a, -34) ^ bits.RotateLeft64(a, -39)) + ((a & b) ^ (a & c) ^ (b & c))
h = g
g = f
f = e
e = d + t1
d = c
c = b
b = a
a = t1 + t2
}
h0 += a
h1 += b
h2 += c
h3 += d
h4 += e
h5 += f
h6 += g
h7 += h
p = p[chunk:]
}
dig.h[0], dig.h[1], dig.h[2], dig.h[3], dig.h[4], dig.h[5], dig.h[6], dig.h[7] = h0, h1, h2, h3, h4, h5, h6, h7
}
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build amd64
package sha512
import "internal/cpu"
//go:noescape
func blockAVX2(dig *digest, p []byte)
//go:noescape
func blockAMD64(dig *digest, p []byte)
var useAVX2 = cpu.X86.HasAVX2 && cpu.X86.HasBMI1 && cpu.X86.HasBMI2
func block(dig *digest, p []byte) {
if useAVX2 {
blockAVX2(dig, p)
} else {
blockAMD64(dig, p)
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package subtle implements functions that are often useful in cryptographic
// code but require careful thought to use correctly.
package subtle
// ConstantTimeCompare returns 1 if the two slices, x and y, have equal contents
// and 0 otherwise. The time taken is a function of the length of the slices and
// is independent of the contents. If the lengths of x and y do not match it
// returns 0 immediately.
func ConstantTimeCompare(x, y []byte) int {
if len(x) != len(y) {
return 0
}
var v byte
for i := 0; i < len(x); i++ {
v |= x[i] ^ y[i]
}
return ConstantTimeByteEq(v, 0)
}
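// ConstantTimeCompare is the right tool for comparing secret-derived values
// such as MAC tags, where an early-exit bytes.Equal could leak a timing
// signal. A hypothetical sketch (external caller code; key, message, and
// receivedTag are assumed inputs):
//
//	mac := hmac.New(sha256.New, key)
//	mac.Write(message)
//	expected := mac.Sum(nil)
//	if subtle.ConstantTimeCompare(expected, receivedTag) == 1 {
//		// tag is valid
//	}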
// ConstantTimeSelect returns x if v == 1 and y if v == 0.
// Its behavior is undefined if v takes any other value.
func ConstantTimeSelect(v, x, y int) int { return ^(v-1)&x | (v-1)&y }
// ConstantTimeByteEq returns 1 if x == y and 0 otherwise.
func ConstantTimeByteEq(x, y uint8) int {
return int((uint32(x^y) - 1) >> 31)
}
// ConstantTimeEq returns 1 if x == y and 0 otherwise.
func ConstantTimeEq(x, y int32) int {
return int((uint64(uint32(x^y)) - 1) >> 63)
}
// ConstantTimeCopy copies the contents of y into x (a slice of equal length)
// if v == 1. If v == 0, x is left unchanged. Its behavior is undefined if v
// takes any other value.
func ConstantTimeCopy(v int, x, y []byte) {
if len(x) != len(y) {
panic("subtle: slices have different lengths")
}
xmask := byte(v - 1)
ymask := byte(^(v - 1))
for i := 0; i < len(x); i++ {
x[i] = x[i]&xmask | y[i]&ymask
}
}
// ConstantTimeLessOrEq returns 1 if x <= y and 0 otherwise.
// Its behavior is undefined if x or y are negative or > 2**31 - 1.
func ConstantTimeLessOrEq(x, y int) int {
x32 := int32(x)
y32 := int32(y)
return int(((x32 - y32 - 1) >> 31) & 1)
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package subtle
// XORBytes sets dst[i] = x[i] ^ y[i] for all i < n = min(len(x), len(y)),
// returning n, the number of bytes written to dst.
// If dst does not have length at least n,
// XORBytes panics without writing anything to dst.
func XORBytes(dst, x, y []byte) int {
n := len(x)
if len(y) < n {
n = len(y)
}
if n == 0 {
return 0
}
if n > len(dst) {
panic("subtle.XORBytes: dst too short")
}
xorBytes(&dst[0], &x[0], &y[0], n) // arch-specific
return n
}
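// A hypothetical sketch applying XORBytes, e.g. to combine a keystream with
// a plaintext buffer (pt and ks are assumed, equal-length slices):
//
//	ct := make([]byte, len(pt))
//	n := subtle.XORBytes(ct, pt, ks)
//	// n == len(pt); ct now holds pt XOR ks.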
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import "strconv"
type alert uint8
const (
// alert level
alertLevelWarning = 1
alertLevelError = 2
)
const (
alertCloseNotify alert = 0
alertUnexpectedMessage alert = 10
alertBadRecordMAC alert = 20
alertDecryptionFailed alert = 21
alertRecordOverflow alert = 22
alertDecompressionFailure alert = 30
alertHandshakeFailure alert = 40
alertBadCertificate alert = 42
alertUnsupportedCertificate alert = 43
alertCertificateRevoked alert = 44
alertCertificateExpired alert = 45
alertCertificateUnknown alert = 46
alertIllegalParameter alert = 47
alertUnknownCA alert = 48
alertAccessDenied alert = 49
alertDecodeError alert = 50
alertDecryptError alert = 51
alertExportRestriction alert = 60
alertProtocolVersion alert = 70
alertInsufficientSecurity alert = 71
alertInternalError alert = 80
alertInappropriateFallback alert = 86
alertUserCanceled alert = 90
alertNoRenegotiation alert = 100
alertMissingExtension alert = 109
alertUnsupportedExtension alert = 110
alertCertificateUnobtainable alert = 111
alertUnrecognizedName alert = 112
alertBadCertificateStatusResponse alert = 113
alertBadCertificateHashValue alert = 114
alertUnknownPSKIdentity alert = 115
alertCertificateRequired alert = 116
alertNoApplicationProtocol alert = 120
)
var alertText = map[alert]string{
alertCloseNotify: "close notify",
alertUnexpectedMessage: "unexpected message",
alertBadRecordMAC: "bad record MAC",
alertDecryptionFailed: "decryption failed",
alertRecordOverflow: "record overflow",
alertDecompressionFailure: "decompression failure",
alertHandshakeFailure: "handshake failure",
alertBadCertificate: "bad certificate",
alertUnsupportedCertificate: "unsupported certificate",
alertCertificateRevoked: "revoked certificate",
alertCertificateExpired: "expired certificate",
alertCertificateUnknown: "unknown certificate",
alertIllegalParameter: "illegal parameter",
alertUnknownCA: "unknown certificate authority",
alertAccessDenied: "access denied",
alertDecodeError: "error decoding message",
alertDecryptError: "error decrypting message",
alertExportRestriction: "export restriction",
alertProtocolVersion: "protocol version not supported",
alertInsufficientSecurity: "insufficient security level",
alertInternalError: "internal error",
alertInappropriateFallback: "inappropriate fallback",
alertUserCanceled: "user canceled",
alertNoRenegotiation: "no renegotiation",
alertMissingExtension: "missing extension",
alertUnsupportedExtension: "unsupported extension",
alertCertificateUnobtainable: "certificate unobtainable",
alertUnrecognizedName: "unrecognized name",
alertBadCertificateStatusResponse: "bad certificate status response",
alertBadCertificateHashValue: "bad certificate hash value",
alertUnknownPSKIdentity: "unknown PSK identity",
alertCertificateRequired: "certificate required",
alertNoApplicationProtocol: "no application protocol",
}
func (e alert) String() string {
s, ok := alertText[e]
if ok {
return "tls: " + s
}
return "tls: alert(" + strconv.Itoa(int(e)) + ")"
}
func (e alert) Error() string {
return e.String()
}
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rsa"
"errors"
"fmt"
"hash"
"io"
)
// verifyHandshakeSignature verifies a signature against pre-hashed
// (if required) handshake contents.
func verifyHandshakeSignature(sigType uint8, pubkey crypto.PublicKey, hashFunc crypto.Hash, signed, sig []byte) error {
switch sigType {
case signatureECDSA:
pubKey, ok := pubkey.(*ecdsa.PublicKey)
if !ok {
return fmt.Errorf("expected an ECDSA public key, got %T", pubkey)
}
if !ecdsa.VerifyASN1(pubKey, signed, sig) {
return errors.New("ECDSA verification failure")
}
case signatureEd25519:
pubKey, ok := pubkey.(ed25519.PublicKey)
if !ok {
return fmt.Errorf("expected an Ed25519 public key, got %T", pubkey)
}
if !ed25519.Verify(pubKey, signed, sig) {
return errors.New("Ed25519 verification failure")
}
case signaturePKCS1v15:
pubKey, ok := pubkey.(*rsa.PublicKey)
if !ok {
return fmt.Errorf("expected an RSA public key, got %T", pubkey)
}
if err := rsa.VerifyPKCS1v15(pubKey, hashFunc, signed, sig); err != nil {
return err
}
case signatureRSAPSS:
pubKey, ok := pubkey.(*rsa.PublicKey)
if !ok {
return fmt.Errorf("expected an RSA public key, got %T", pubkey)
}
signOpts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash}
if err := rsa.VerifyPSS(pubKey, hashFunc, signed, sig, signOpts); err != nil {
return err
}
default:
return errors.New("internal error: unknown signature type")
}
return nil
}
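// An illustrative exercise of the Ed25519 branch above (hypothetical
// test-style code within this package; imports crypto/rand. directSigning
// signals that signed is used as-is, with no pre-hashing):
//
//	pub, priv, _ := ed25519.GenerateKey(rand.Reader)
//	msg := []byte("handshake transcript bytes")
//	sig := ed25519.Sign(priv, msg)
//	err := verifyHandshakeSignature(signatureEd25519, pub, directSigning, msg, sig)
//	// err == nil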
const (
serverSignatureContext = "TLS 1.3, server CertificateVerify\x00"
clientSignatureContext = "TLS 1.3, client CertificateVerify\x00"
)
var signaturePadding = []byte{
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
}
// signedMessage returns the pre-hashed (if necessary) message to be signed by
// certificate keys in TLS 1.3. See RFC 8446, Section 4.4.3.
func signedMessage(sigHash crypto.Hash, context string, transcript hash.Hash) []byte {
if sigHash == directSigning {
b := &bytes.Buffer{}
b.Write(signaturePadding)
io.WriteString(b, context)
b.Write(transcript.Sum(nil))
return b.Bytes()
}
h := sigHash.New()
h.Write(signaturePadding)
io.WriteString(h, context)
h.Write(transcript.Sum(nil))
return h.Sum(nil)
}
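// Concretely, for a direct-signing key (Ed25519) the message built above is
// the byte concatenation: 64 bytes of 0x20, then the context string
// (including its trailing NUL), then the raw transcript hash. A hypothetical
// sketch within this package (handshakeMessages is an assumed byte slice):
//
//	transcript := sha256.New()
//	transcript.Write(handshakeMessages)
//	msg := signedMessage(directSigning, serverSignatureContext, transcript)
//	// msg == signaturePadding || serverSignatureContext || transcript.Sum(nil)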
// typeAndHashFromSignatureScheme returns the corresponding signature type and
// crypto.Hash for a given TLS SignatureScheme.
func typeAndHashFromSignatureScheme(signatureAlgorithm SignatureScheme) (sigType uint8, hash crypto.Hash, err error) {
switch signatureAlgorithm {
case PKCS1WithSHA1, PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512:
sigType = signaturePKCS1v15
case PSSWithSHA256, PSSWithSHA384, PSSWithSHA512:
sigType = signatureRSAPSS
case ECDSAWithSHA1, ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512:
sigType = signatureECDSA
case Ed25519:
sigType = signatureEd25519
default:
return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm)
}
switch signatureAlgorithm {
case PKCS1WithSHA1, ECDSAWithSHA1:
hash = crypto.SHA1
case PKCS1WithSHA256, PSSWithSHA256, ECDSAWithP256AndSHA256:
hash = crypto.SHA256
case PKCS1WithSHA384, PSSWithSHA384, ECDSAWithP384AndSHA384:
hash = crypto.SHA384
case PKCS1WithSHA512, PSSWithSHA512, ECDSAWithP521AndSHA512:
hash = crypto.SHA512
case Ed25519:
hash = directSigning
default:
return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm)
}
return sigType, hash, nil
}
// legacyTypeAndHashFromPublicKey returns the fixed signature type and crypto.Hash for
// a given public key used with TLS 1.0 and 1.1, before the introduction of
// signature algorithm negotiation.
func legacyTypeAndHashFromPublicKey(pub crypto.PublicKey) (sigType uint8, hash crypto.Hash, err error) {
switch pub.(type) {
case *rsa.PublicKey:
return signaturePKCS1v15, crypto.MD5SHA1, nil
case *ecdsa.PublicKey:
return signatureECDSA, crypto.SHA1, nil
case ed25519.PublicKey:
// RFC 8422 specifies support for Ed25519 in TLS 1.0 and 1.1,
// but it requires holding on to a handshake transcript to do a
// full signature, and not even OpenSSL bothers with the
// complexity, so we can't even test it properly.
return 0, 0, fmt.Errorf("tls: Ed25519 public keys are not supported before TLS 1.2")
default:
return 0, 0, fmt.Errorf("tls: unsupported public key: %T", pub)
}
}
var rsaSignatureSchemes = []struct {
scheme SignatureScheme
minModulusBytes int
maxVersion uint16
}{
// RSA-PSS is used with PSSSaltLengthEqualsHash, and requires
// emLen >= hLen + sLen + 2
{PSSWithSHA256, crypto.SHA256.Size()*2 + 2, VersionTLS13},
{PSSWithSHA384, crypto.SHA384.Size()*2 + 2, VersionTLS13},
{PSSWithSHA512, crypto.SHA512.Size()*2 + 2, VersionTLS13},
// PKCS #1 v1.5 uses prefixes from hashPrefixes in crypto/rsa, and requires
// emLen >= len(prefix) + hLen + 11
// TLS 1.3 dropped support for PKCS #1 v1.5 in favor of RSA-PSS.
{PKCS1WithSHA256, 19 + crypto.SHA256.Size() + 11, VersionTLS12},
{PKCS1WithSHA384, 19 + crypto.SHA384.Size() + 11, VersionTLS12},
{PKCS1WithSHA512, 19 + crypto.SHA512.Size() + 11, VersionTLS12},
{PKCS1WithSHA1, 15 + crypto.SHA1.Size() + 11, VersionTLS12},
}
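// Working the PSS bound for PSSWithSHA256: hLen = sLen = 32, so the minimum
// emLen is 32*2 + 2 = 66 bytes, i.e. a modulus of at least 528 bits. A
// hypothetical check (rsaKey is an assumed *rsa.PublicKey):
//
//	min := crypto.SHA256.Size()*2 + 2 // 66 bytes
//	ok := rsaKey.Size() >= min        // Size reports the modulus length in bytes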
// signatureSchemesForCertificate returns the list of supported SignatureSchemes
// for a given certificate, based on the public key and the protocol version,
// and optionally filtered by its explicit SupportedSignatureAlgorithms.
//
// This function must be kept in sync with supportedSignatureAlgorithms.
// FIPS filtering is applied in the caller, selectSignatureScheme.
func signatureSchemesForCertificate(version uint16, cert *Certificate) []SignatureScheme {
priv, ok := cert.PrivateKey.(crypto.Signer)
if !ok {
return nil
}
var sigAlgs []SignatureScheme
switch pub := priv.Public().(type) {
case *ecdsa.PublicKey:
if version != VersionTLS13 {
// In TLS 1.2 and earlier, ECDSA algorithms are not
// constrained to a single curve.
sigAlgs = []SignatureScheme{
ECDSAWithP256AndSHA256,
ECDSAWithP384AndSHA384,
ECDSAWithP521AndSHA512,
ECDSAWithSHA1,
}
break
}
switch pub.Curve {
case elliptic.P256():
sigAlgs = []SignatureScheme{ECDSAWithP256AndSHA256}
case elliptic.P384():
sigAlgs = []SignatureScheme{ECDSAWithP384AndSHA384}
case elliptic.P521():
sigAlgs = []SignatureScheme{ECDSAWithP521AndSHA512}
default:
return nil
}
case *rsa.PublicKey:
size := pub.Size()
sigAlgs = make([]SignatureScheme, 0, len(rsaSignatureSchemes))
for _, candidate := range rsaSignatureSchemes {
if size >= candidate.minModulusBytes && version <= candidate.maxVersion {
sigAlgs = append(sigAlgs, candidate.scheme)
}
}
case ed25519.PublicKey:
sigAlgs = []SignatureScheme{Ed25519}
default:
return nil
}
if cert.SupportedSignatureAlgorithms != nil {
var filteredSigAlgs []SignatureScheme
for _, sigAlg := range sigAlgs {
if isSupportedSignatureAlgorithm(sigAlg, cert.SupportedSignatureAlgorithms) {
filteredSigAlgs = append(filteredSigAlgs, sigAlg)
}
}
return filteredSigAlgs
}
return sigAlgs
}
// selectSignatureScheme picks a SignatureScheme from the peer's preference list
// that works with the selected certificate. It's only called for protocol
// versions that support signature algorithms, so TLS 1.2 and 1.3.
func selectSignatureScheme(vers uint16, c *Certificate, peerAlgs []SignatureScheme) (SignatureScheme, error) {
supportedAlgs := signatureSchemesForCertificate(vers, c)
if len(supportedAlgs) == 0 {
return 0, unsupportedCertificateError(c)
}
if len(peerAlgs) == 0 && vers == VersionTLS12 {
// For TLS 1.2, if the client didn't send signature_algorithms then we
// can assume that it supports SHA1. See RFC 5246, Section 7.4.1.4.1.
peerAlgs = []SignatureScheme{PKCS1WithSHA1, ECDSAWithSHA1}
}
// Pick signature scheme in the peer's preference order, as our
// preference order is not configurable.
for _, preferredAlg := range peerAlgs {
if needFIPS() && !isSupportedSignatureAlgorithm(preferredAlg, fipsSupportedSignatureAlgorithms) {
continue
}
if isSupportedSignatureAlgorithm(preferredAlg, supportedAlgs) {
return preferredAlg, nil
}
}
return 0, errors.New("tls: peer doesn't support any of the certificate's signature algorithms")
}
// unsupportedCertificateError returns a helpful error for certificates with
// an unsupported private key.
func unsupportedCertificateError(cert *Certificate) error {
switch cert.PrivateKey.(type) {
case rsa.PrivateKey, ecdsa.PrivateKey:
return fmt.Errorf("tls: unsupported certificate: private key is %T, expected *%T",
cert.PrivateKey, cert.PrivateKey)
case *ed25519.PrivateKey:
return fmt.Errorf("tls: unsupported certificate: private key is *ed25519.PrivateKey, expected ed25519.PrivateKey")
}
signer, ok := cert.PrivateKey.(crypto.Signer)
if !ok {
return fmt.Errorf("tls: certificate private key (%T) does not implement crypto.Signer",
cert.PrivateKey)
}
switch pub := signer.Public().(type) {
case *ecdsa.PublicKey:
switch pub.Curve {
case elliptic.P256():
case elliptic.P384():
case elliptic.P521():
default:
return fmt.Errorf("tls: unsupported certificate curve (%s)", pub.Curve.Params().Name)
}
case *rsa.PublicKey:
return fmt.Errorf("tls: certificate RSA key size too small for supported signature algorithms")
case ed25519.PublicKey:
default:
return fmt.Errorf("tls: unsupported certificate key (%T)", pub)
}
if cert.SupportedSignatureAlgorithms != nil {
return fmt.Errorf("tls: peer doesn't support the certificate custom signature algorithms")
}
return fmt.Errorf("tls: internal error: unsupported key (%T)", cert.PrivateKey)
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"crypto/x509"
"runtime"
"sync"
"sync/atomic"
)
type cacheEntry struct {
refs atomic.Int64
cert *x509.Certificate
}
// certCache implements an intern table for reference counted x509.Certificates,
// implemented in a similar fashion to BoringSSL's CRYPTO_BUFFER_POOL. This
// allows for a single x509.Certificate to be kept in memory and referenced from
// multiple Conns. Returned references should not be mutated by callers. Certificates
// are still safe to use after they are removed from the cache.
//
// Certificates are returned wrapped in an activeCert struct that should be held by
// the caller. When references to the activeCert are freed, the number of references
// to the certificate in the cache is decremented. Once the number of references
// reaches zero, the entry is evicted from the cache.
//
// The main difference between this implementation and CRYPTO_BUFFER_POOL is that
// CRYPTO_BUFFER_POOL is a more generic structure which supports blobs of data,
// rather than specific structures. Since we only care about x509.Certificates,
// certCache is implemented as a specific cache, rather than a generic one.
//
// See https://boringssl.googlesource.com/boringssl/+/master/include/openssl/pool.h
// and https://boringssl.googlesource.com/boringssl/+/master/crypto/pool/pool.c
// for the BoringSSL reference.
type certCache struct {
sync.Map
}
var clientCertCache = new(certCache)
// activeCert is a handle to a certificate held in the cache. Once there are
// no alive activeCerts for a given certificate, the certificate is removed
// from the cache by a finalizer.
type activeCert struct {
cert *x509.Certificate
}
// active increments the number of references to the entry, wraps the
// certificate in the entry in an activeCert, and sets the finalizer.
//
// Note that there is a race between active and the finalizer set on the
// returned activeCert, triggered if active is called after the ref count is
// decremented such that refs may be > 0 when evict is called. We consider this
// safe, since the caller holding an activeCert for an entry that is no longer
// in the cache is fine, with the only side effect being the memory overhead of
// there being more than one distinct reference to a certificate alive at once.
func (cc *certCache) active(e *cacheEntry) *activeCert {
e.refs.Add(1)
a := &activeCert{e.cert}
runtime.SetFinalizer(a, func(_ *activeCert) {
if e.refs.Add(-1) == 0 {
cc.evict(e)
}
})
return a
}
// evict removes a cacheEntry from the cache.
func (cc *certCache) evict(e *cacheEntry) {
cc.Delete(string(e.cert.Raw))
}
// newCert returns an x509.Certificate parsed from der. If there is already a copy
// of the certificate in the cache, a reference to the existing certificate will
// be returned. Otherwise, a fresh certificate will be added to the cache, and
// the reference returned. The returned reference should not be mutated.
func (cc *certCache) newCert(der []byte) (*activeCert, error) {
if entry, ok := cc.Load(string(der)); ok {
return cc.active(entry.(*cacheEntry)), nil
}
cert, err := x509.ParseCertificate(der)
if err != nil {
return nil, err
}
entry := &cacheEntry{cert: cert}
if entry, loaded := cc.LoadOrStore(string(der), entry); loaded {
return cc.active(entry.(*cacheEntry)), nil
}
return cc.active(entry), nil
}
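// A hypothetical sketch of the interning behavior (test-style code within
// this package; der is an assumed DER-encoded certificate):
//
//	a, _ := clientCertCache.newCert(der)
//	b, _ := clientCertCache.newCert(der)
//	// a.cert == b.cert: both handles share one parsed *x509.Certificate.
//	// The cache entry is evicted once all handles are garbage collected.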
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/des"
"crypto/hmac"
"crypto/internal/boring"
"crypto/rc4"
"crypto/sha1"
"crypto/sha256"
"fmt"
"hash"
"internal/cpu"
"runtime"
"golang.org/x/crypto/chacha20poly1305"
)
// CipherSuite is a TLS cipher suite. Note that most functions in this package
// accept and expose cipher suite IDs instead of this type.
type CipherSuite struct {
ID uint16
Name string
// SupportedVersions is the list of TLS protocol versions that can
// negotiate this cipher suite.
SupportedVersions []uint16
// Insecure is true if the cipher suite has known security issues
// due to its primitives, design, or implementation.
Insecure bool
}
var (
supportedUpToTLS12 = []uint16{VersionTLS10, VersionTLS11, VersionTLS12}
supportedOnlyTLS12 = []uint16{VersionTLS12}
supportedOnlyTLS13 = []uint16{VersionTLS13}
)
// CipherSuites returns a list of cipher suites currently implemented by this
// package, excluding those with security issues, which are returned by
// InsecureCipherSuites.
//
// The list is sorted by ID. Note that the default cipher suites selected by
// this package might depend on logic that can't be captured by a static list,
// and might not match those returned by this function.
func CipherSuites() []*CipherSuite {
return []*CipherSuite{
{TLS_RSA_WITH_AES_128_CBC_SHA, "TLS_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false},
{TLS_RSA_WITH_AES_256_CBC_SHA, "TLS_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false},
{TLS_RSA_WITH_AES_128_GCM_SHA256, "TLS_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false},
{TLS_RSA_WITH_AES_256_GCM_SHA384, "TLS_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false},
{TLS_AES_128_GCM_SHA256, "TLS_AES_128_GCM_SHA256", supportedOnlyTLS13, false},
{TLS_AES_256_GCM_SHA384, "TLS_AES_256_GCM_SHA384", supportedOnlyTLS13, false},
{TLS_CHACHA20_POLY1305_SHA256, "TLS_CHACHA20_POLY1305_SHA256", supportedOnlyTLS13, false},
{TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false},
{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false},
{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false},
{TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false},
{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false},
{TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false},
{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false},
{TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false},
{TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false},
{TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false},
}
}
// InsecureCipherSuites returns a list of cipher suites currently implemented by
// this package and which have security issues.
//
// Most applications should not use the cipher suites in this list, and should
// only use those returned by CipherSuites.
func InsecureCipherSuites() []*CipherSuite {
// This list includes RC4, CBC_SHA256, and 3DES cipher suites. See
// cipherSuitesPreferenceOrder for details.
return []*CipherSuite{
{TLS_RSA_WITH_RC4_128_SHA, "TLS_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true},
{TLS_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, true},
{TLS_RSA_WITH_AES_128_CBC_SHA256, "TLS_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true},
{TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA", supportedUpToTLS12, true},
{TLS_ECDHE_RSA_WITH_RC4_128_SHA, "TLS_ECDHE_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true},
{TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, true},
{TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true},
{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true},
}
}
// CipherSuiteName returns the standard name for the passed cipher suite ID
// (e.g. "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"), or a fallback representation
// of the ID value if the cipher suite is not implemented by this package.
func CipherSuiteName(id uint16) string {
for _, c := range CipherSuites() {
if c.ID == id {
return c.Name
}
}
for _, c := range InsecureCipherSuites() {
if c.ID == id {
return c.Name
}
}
return fmt.Sprintf("0x%04X", id)
}
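// A hypothetical sketch (external caller code):
//
//	fmt.Println(tls.CipherSuiteName(0x1301)) // "TLS_AES_128_GCM_SHA256"
//	fmt.Println(tls.CipherSuiteName(0xabcd)) // "0xABCD" (not implemented)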
const (
// suiteECDHE indicates that the cipher suite involves elliptic curve
// Diffie-Hellman. This means that it should only be selected when the
// client indicates that it supports ECC with a curve and point format
// that we're happy with.
suiteECDHE = 1 << iota
// suiteECSign indicates that the cipher suite involves an ECDSA or
// EdDSA signature and therefore may only be selected when the server's
// certificate is ECDSA or EdDSA. If this is not set then the cipher suite
// is RSA based.
suiteECSign
// suiteTLS12 indicates that the cipher suite should only be advertised
// and accepted when using TLS 1.2.
suiteTLS12
// suiteSHA384 indicates that the cipher suite uses SHA384 as the
// handshake hash.
suiteSHA384
)
// A cipherSuite is a TLS 1.0–1.2 cipher suite, and defines the key exchange
// mechanism, as well as the cipher+MAC pair or the AEAD.
type cipherSuite struct {
id uint16
// the lengths, in bytes, of the key material needed for each component.
keyLen int
macLen int
ivLen int
ka func(version uint16) keyAgreement
// flags is a bitmask of the suite* values, above.
flags int
cipher func(key, iv []byte, isRead bool) any
mac func(key []byte) hash.Hash
aead func(key, fixedNonce []byte) aead
}
var cipherSuites = []*cipherSuite{ // TODO: replace with a map, since the order doesn't matter.
{TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadChaCha20Poly1305},
{TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadChaCha20Poly1305},
{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadAESGCM},
{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadAESGCM},
{TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM},
{TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM},
{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheRSAKA, suiteECDHE | suiteTLS12, cipherAES, macSHA256, nil},
{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil},
{TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, cipherAES, macSHA256, nil},
{TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil},
{TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil},
{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil},
{TLS_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, rsaKA, suiteTLS12, nil, nil, aeadAESGCM},
{TLS_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, rsaKA, suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM},
{TLS_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, rsaKA, suiteTLS12, cipherAES, macSHA256, nil},
{TLS_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil},
{TLS_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil},
{TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, ecdheRSAKA, suiteECDHE, cipher3DES, macSHA1, nil},
{TLS_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, rsaKA, 0, cipher3DES, macSHA1, nil},
{TLS_RSA_WITH_RC4_128_SHA, 16, 20, 0, rsaKA, 0, cipherRC4, macSHA1, nil},
{TLS_ECDHE_RSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheRSAKA, suiteECDHE, cipherRC4, macSHA1, nil},
{TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherRC4, macSHA1, nil},
}
// selectCipherSuite returns the first TLS 1.0–1.2 cipher suite from ids which
// is also in supportedIDs and passes the ok filter.
func selectCipherSuite(ids, supportedIDs []uint16, ok func(*cipherSuite) bool) *cipherSuite {
for _, id := range ids {
candidate := cipherSuiteByID(id)
if candidate == nil || !ok(candidate) {
continue
}
for _, suppID := range supportedIDs {
if id == suppID {
return candidate
}
}
}
return nil
}
// A cipherSuiteTLS13 defines only the pair of the AEAD algorithm and hash
// algorithm to be used with HKDF. See RFC 8446, Appendix B.4.
type cipherSuiteTLS13 struct {
id uint16
keyLen int
aead func(key, fixedNonce []byte) aead
hash crypto.Hash
}
var cipherSuitesTLS13 = []*cipherSuiteTLS13{ // TODO: replace with a map.
{TLS_AES_128_GCM_SHA256, 16, aeadAESGCMTLS13, crypto.SHA256},
{TLS_CHACHA20_POLY1305_SHA256, 32, aeadChaCha20Poly1305, crypto.SHA256},
{TLS_AES_256_GCM_SHA384, 32, aeadAESGCMTLS13, crypto.SHA384},
}
// cipherSuitesPreferenceOrder is the order in which we'll select (on the
// server) or advertise (on the client) TLS 1.0–1.2 cipher suites.
//
// Cipher suites are filtered but not reordered based on the application and
// peer's preferences, meaning we'll never select a suite lower in this list if
// any higher one is available. This makes it more defensible to keep weaker
// cipher suites enabled, especially on the server side where we get the last
// word, since there are no known downgrade attacks on cipher suite selection.
//
// The list is sorted by applying the following priority rules, stopping at the
// first (most important) applicable one:
//
// - Anything else comes before RC4
//
// RC4 has practically exploitable biases. See https://www.rc4nomore.com.
//
// - Anything else comes before CBC_SHA256
//
// SHA-256 variants of the CBC ciphersuites don't implement any Lucky13
// countermeasures. See http://www.isg.rhul.ac.uk/tls/Lucky13.html and
// https://www.imperialviolet.org/2013/02/04/luckythirteen.html.
//
// - Anything else comes before 3DES
//
// 3DES has 64-bit blocks, which makes it fundamentally susceptible to
// birthday attacks. See https://sweet32.info.
//
// - ECDHE comes before anything else
//
// Once the broken stuff is out of the way, the most important
// property a cipher suite can have is forward secrecy. We don't
// implement FFDHE, so that means ECDHE.
//
// - AEADs come before CBC ciphers
//
// Even with Lucky13 countermeasures, MAC-then-Encrypt CBC cipher suites
// are fundamentally fragile, and suffered from an endless sequence of
// padding oracle attacks. See https://eprint.iacr.org/2015/1129,
// https://www.imperialviolet.org/2014/12/08/poodleagain.html, and
// https://blog.cloudflare.com/yet-another-padding-oracle-in-openssl-cbc-ciphersuites/.
//
// - AES comes before ChaCha20
//
// When AES hardware is available, AES-128-GCM and AES-256-GCM are faster
// than ChaCha20Poly1305.
//
// When AES hardware is not available, AES-128-GCM is one or more of: much
// slower, way more complex, and less safe (because not constant time)
// than ChaCha20Poly1305.
//
// We use this list if we think both peers have AES hardware, and
// cipherSuitesPreferenceOrderNoAES otherwise.
//
// - AES-128 comes before AES-256
//
// The only potential advantages of AES-256 are better multi-target
// margins, and hypothetical post-quantum properties. Neither apply to
// TLS, and AES-256 is slower due to its four extra rounds (which don't
// contribute to the advantages above).
//
// - ECDSA comes before RSA
//
// The relative order of ECDSA and RSA cipher suites doesn't matter,
// as they depend on the certificate. Pick one to get a stable order.
var cipherSuitesPreferenceOrder = []uint16{
// AEADs w/ ECDHE
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
// CBC w/ ECDHE
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
// AEADs w/o ECDHE
TLS_RSA_WITH_AES_128_GCM_SHA256,
TLS_RSA_WITH_AES_256_GCM_SHA384,
// CBC w/o ECDHE
TLS_RSA_WITH_AES_128_CBC_SHA,
TLS_RSA_WITH_AES_256_CBC_SHA,
// 3DES
TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
TLS_RSA_WITH_3DES_EDE_CBC_SHA,
// CBC_SHA256
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
TLS_RSA_WITH_AES_128_CBC_SHA256,
// RC4
TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA,
TLS_RSA_WITH_RC4_128_SHA,
}
var cipherSuitesPreferenceOrderNoAES = []uint16{
// ChaCha20Poly1305
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
// AES-GCM w/ ECDHE
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
// The rest of cipherSuitesPreferenceOrder.
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
TLS_RSA_WITH_AES_128_GCM_SHA256,
TLS_RSA_WITH_AES_256_GCM_SHA384,
TLS_RSA_WITH_AES_128_CBC_SHA,
TLS_RSA_WITH_AES_256_CBC_SHA,
TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
TLS_RSA_WITH_3DES_EDE_CBC_SHA,
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
TLS_RSA_WITH_AES_128_CBC_SHA256,
TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA,
TLS_RSA_WITH_RC4_128_SHA,
}
// disabledCipherSuites are not used unless explicitly listed in
// Config.CipherSuites. They MUST be at the end of cipherSuitesPreferenceOrder.
var disabledCipherSuites = []uint16{
// CBC_SHA256
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
TLS_RSA_WITH_AES_128_CBC_SHA256,
// RC4
TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA,
TLS_RSA_WITH_RC4_128_SHA,
}
var (
defaultCipherSuitesLen = len(cipherSuitesPreferenceOrder) - len(disabledCipherSuites)
defaultCipherSuites = cipherSuitesPreferenceOrder[:defaultCipherSuitesLen]
)
// defaultCipherSuitesTLS13 is also the preference order, since there are no
// disabled by default TLS 1.3 cipher suites. The same AES vs ChaCha20 logic as
// cipherSuitesPreferenceOrder applies.
var defaultCipherSuitesTLS13 = []uint16{
TLS_AES_128_GCM_SHA256,
TLS_AES_256_GCM_SHA384,
TLS_CHACHA20_POLY1305_SHA256,
}
var defaultCipherSuitesTLS13NoAES = []uint16{
TLS_CHACHA20_POLY1305_SHA256,
TLS_AES_128_GCM_SHA256,
TLS_AES_256_GCM_SHA384,
}
var (
hasGCMAsmAMD64 = cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ
hasGCMAsmARM64 = cpu.ARM64.HasAES && cpu.ARM64.HasPMULL
// Keep in sync with crypto/aes/cipher_s390x.go.
hasGCMAsmS390X = cpu.S390X.HasAES && cpu.S390X.HasAESCBC && cpu.S390X.HasAESCTR &&
(cpu.S390X.HasGHASH || cpu.S390X.HasAESGCM)
hasAESGCMHardwareSupport = runtime.GOARCH == "amd64" && hasGCMAsmAMD64 ||
runtime.GOARCH == "arm64" && hasGCMAsmARM64 ||
runtime.GOARCH == "s390x" && hasGCMAsmS390X
)
var aesgcmCiphers = map[uint16]bool{
// TLS 1.2
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: true,
TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: true,
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: true,
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: true,
// TLS 1.3
TLS_AES_128_GCM_SHA256: true,
TLS_AES_256_GCM_SHA384: true,
}
var nonAESGCMAEADCiphers = map[uint16]bool{
// TLS 1.2
TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305: true,
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305: true,
// TLS 1.3
TLS_CHACHA20_POLY1305_SHA256: true,
}
// aesgcmPreferred returns whether the first known cipher in the preference list
// is an AES-GCM cipher, implying the peer has hardware support for it.
func aesgcmPreferred(ciphers []uint16) bool {
for _, cID := range ciphers {
if c := cipherSuiteByID(cID); c != nil {
return aesgcmCiphers[cID]
}
if c := cipherSuiteTLS13ByID(cID); c != nil {
return aesgcmCiphers[cID]
}
}
return false
}
func cipherRC4(key, iv []byte, isRead bool) any {
cipher, _ := rc4.NewCipher(key)
return cipher
}
func cipher3DES(key, iv []byte, isRead bool) any {
block, _ := des.NewTripleDESCipher(key)
if isRead {
return cipher.NewCBCDecrypter(block, iv)
}
return cipher.NewCBCEncrypter(block, iv)
}
func cipherAES(key, iv []byte, isRead bool) any {
block, _ := aes.NewCipher(key)
if isRead {
return cipher.NewCBCDecrypter(block, iv)
}
return cipher.NewCBCEncrypter(block, iv)
}
// macSHA1 returns a SHA-1 based constant time MAC.
func macSHA1(key []byte) hash.Hash {
h := sha1.New
// The BoringCrypto SHA1 does not have a constant-time
// checksum function, so don't try to use it.
if !boring.Enabled {
h = newConstantTimeHash(h)
}
return hmac.New(h, key)
}
// macSHA256 returns a SHA-256 based MAC. This is only supported in TLS 1.2 and
// is currently only used in disabled-by-default cipher suites.
func macSHA256(key []byte) hash.Hash {
return hmac.New(sha256.New, key)
}
type aead interface {
cipher.AEAD
// explicitNonceLen returns the number of bytes of explicit nonce
// included in each record. This is eight for older AEADs and
// zero for modern ones.
explicitNonceLen() int
}
const (
aeadNonceLength = 12
noncePrefixLength = 4
)
// prefixNonceAEAD wraps an AEAD and prefixes a fixed portion of the nonce to
// each call.
type prefixNonceAEAD struct {
// nonce contains the fixed part of the nonce in the first four bytes.
nonce [aeadNonceLength]byte
aead cipher.AEAD
}
func (f *prefixNonceAEAD) NonceSize() int { return aeadNonceLength - noncePrefixLength }
func (f *prefixNonceAEAD) Overhead() int { return f.aead.Overhead() }
func (f *prefixNonceAEAD) explicitNonceLen() int { return f.NonceSize() }
func (f *prefixNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte {
copy(f.nonce[4:], nonce)
return f.aead.Seal(out, f.nonce[:], plaintext, additionalData)
}
func (f *prefixNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) {
copy(f.nonce[4:], nonce)
return f.aead.Open(out, f.nonce[:], ciphertext, additionalData)
}
// xorNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce
// before each call.
type xorNonceAEAD struct {
nonceMask [aeadNonceLength]byte
aead cipher.AEAD
}
func (f *xorNonceAEAD) NonceSize() int { return 8 } // 64-bit sequence number
func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() }
func (f *xorNonceAEAD) explicitNonceLen() int { return 0 }
func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte {
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
result := f.aead.Seal(out, f.nonceMask[:], plaintext, additionalData)
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
return result
}
func (f *xorNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) {
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
result, err := f.aead.Open(out, f.nonceMask[:], ciphertext, additionalData)
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
return result, err
}
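// The mask scheme above implements the per-record nonce of RFC 8446 (and
// RFC 7905): the 8-byte record sequence number is XORed into the last 8
// bytes of the 12-byte IV, then XORed out again after the call to restore
// the mask. A hypothetical sketch of the derivation for one record (iv and
// seq are assumed 12- and 8-byte slices):
//
//	var nonce [12]byte
//	copy(nonce[:], iv)
//	for i, b := range seq {
//		nonce[4+i] ^= b
//	}
//	// nonce is now the per-record AEAD nonce.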
func aeadAESGCM(key, noncePrefix []byte) aead {
if len(noncePrefix) != noncePrefixLength {
panic("tls: internal error: wrong nonce length")
}
aes, err := aes.NewCipher(key)
if err != nil {
panic(err)
}
var aead cipher.AEAD
if boring.Enabled {
aead, err = boring.NewGCMTLS(aes)
} else {
boring.Unreachable()
aead, err = cipher.NewGCM(aes)
}
if err != nil {
panic(err)
}
ret := &prefixNonceAEAD{aead: aead}
copy(ret.nonce[:], noncePrefix)
return ret
}
func aeadAESGCMTLS13(key, nonceMask []byte) aead {
if len(nonceMask) != aeadNonceLength {
panic("tls: internal error: wrong nonce length")
}
aes, err := aes.NewCipher(key)
if err != nil {
panic(err)
}
aead, err := cipher.NewGCM(aes)
if err != nil {
panic(err)
}
ret := &xorNonceAEAD{aead: aead}
copy(ret.nonceMask[:], nonceMask)
return ret
}
func aeadChaCha20Poly1305(key, nonceMask []byte) aead {
if len(nonceMask) != aeadNonceLength {
panic("tls: internal error: wrong nonce length")
}
aead, err := chacha20poly1305.New(key)
if err != nil {
panic(err)
}
ret := &xorNonceAEAD{aead: aead}
copy(ret.nonceMask[:], nonceMask)
return ret
}
type constantTimeHash interface {
hash.Hash
ConstantTimeSum(b []byte) []byte
}
// cthWrapper wraps any hash.Hash that implements ConstantTimeSum, replacing
// all calls to Sum with calls to ConstantTimeSum. It's used to obtain a
// ConstantTimeSum-based HMAC.
type cthWrapper struct {
h constantTimeHash
}
func (c *cthWrapper) Size() int { return c.h.Size() }
func (c *cthWrapper) BlockSize() int { return c.h.BlockSize() }
func (c *cthWrapper) Reset() { c.h.Reset() }
func (c *cthWrapper) Write(p []byte) (int, error) { return c.h.Write(p) }
func (c *cthWrapper) Sum(b []byte) []byte { return c.h.ConstantTimeSum(b) }
func newConstantTimeHash(h func() hash.Hash) func() hash.Hash {
boring.Unreachable()
return func() hash.Hash {
return &cthWrapper{h().(constantTimeHash)}
}
}
// tls10MAC implements the TLS 1.0 MAC function. RFC 2246, Section 6.2.3.
func tls10MAC(h hash.Hash, out, seq, header, data, extra []byte) []byte {
h.Reset()
h.Write(seq)
h.Write(header)
h.Write(data)
res := h.Sum(out)
if extra != nil {
h.Write(extra)
}
return res
}
func rsaKA(version uint16) keyAgreement {
return rsaKeyAgreement{}
}
func ecdheECDSAKA(version uint16) keyAgreement {
return &ecdheKeyAgreement{
isRSA: false,
version: version,
}
}
func ecdheRSAKA(version uint16) keyAgreement {
return &ecdheKeyAgreement{
isRSA: true,
version: version,
}
}
// mutualCipherSuite returns a cipherSuite given a list of supported
// ciphersuites and the id requested by the peer.
func mutualCipherSuite(have []uint16, want uint16) *cipherSuite {
for _, id := range have {
if id == want {
return cipherSuiteByID(id)
}
}
return nil
}
func cipherSuiteByID(id uint16) *cipherSuite {
for _, cipherSuite := range cipherSuites {
if cipherSuite.id == id {
return cipherSuite
}
}
return nil
}
func mutualCipherSuiteTLS13(have []uint16, want uint16) *cipherSuiteTLS13 {
for _, id := range have {
if id == want {
return cipherSuiteTLS13ByID(id)
}
}
return nil
}
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 {
for _, cipherSuite := range cipherSuitesTLS13 {
if cipherSuite.id == id {
return cipherSuite
}
}
return nil
}
// A list of cipher suite IDs that are, or have been, implemented by this
// package.
//
// See https://www.iana.org/assignments/tls-parameters/tls-parameters.xml
const (
// TLS 1.0 - 1.2 cipher suites.
TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005
TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000a
TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002f
TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035
TLS_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003c
TLS_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009c
TLS_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009d
TLS_ECDHE_ECDSA_WITH_RC4_128_SHA uint16 = 0xc007
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xc009
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xc00a
TLS_ECDHE_RSA_WITH_RC4_128_SHA uint16 = 0xc011
TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xc012
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0xc013
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0xc014
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc023
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc027
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02f
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02b
TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc030
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc02c
TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca8
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca9
// TLS 1.3 cipher suites.
TLS_AES_128_GCM_SHA256 uint16 = 0x1301
TLS_AES_256_GCM_SHA384 uint16 = 0x1302
TLS_CHACHA20_POLY1305_SHA256 uint16 = 0x1303
// TLS_FALLBACK_SCSV isn't a standard cipher suite but an indicator
// that the client is doing version fallback. See RFC 7507.
TLS_FALLBACK_SCSV uint16 = 0x5600
// Legacy names for the corresponding cipher suites with the correct _SHA256
// suffix, retained for backward compatibility.
TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256
)
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"bytes"
"container/list"
"context"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/sha512"
"crypto/x509"
"errors"
"fmt"
"io"
"net"
"strings"
"sync"
"time"
)
const (
VersionTLS10 = 0x0301
VersionTLS11 = 0x0302
VersionTLS12 = 0x0303
VersionTLS13 = 0x0304
// Deprecated: SSLv3 is cryptographically broken, and is no longer
// supported by this package. See golang.org/issue/32716.
VersionSSL30 = 0x0300
)
const (
maxPlaintext = 16384 // maximum plaintext payload length
maxCiphertext = 16384 + 2048 // maximum ciphertext payload length
maxCiphertextTLS13 = 16384 + 256 // maximum ciphertext length in TLS 1.3
recordHeaderLen = 5 // record header length
maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB)
maxUselessRecords = 16 // maximum number of consecutive non-advancing records
)
// TLS record types.
type recordType uint8
const (
recordTypeChangeCipherSpec recordType = 20
recordTypeAlert recordType = 21
recordTypeHandshake recordType = 22
recordTypeApplicationData recordType = 23
)
// TLS handshake message types.
const (
typeHelloRequest uint8 = 0
typeClientHello uint8 = 1
typeServerHello uint8 = 2
typeNewSessionTicket uint8 = 4
typeEndOfEarlyData uint8 = 5
typeEncryptedExtensions uint8 = 8
typeCertificate uint8 = 11
typeServerKeyExchange uint8 = 12
typeCertificateRequest uint8 = 13
typeServerHelloDone uint8 = 14
typeCertificateVerify uint8 = 15
typeClientKeyExchange uint8 = 16
typeFinished uint8 = 20
typeCertificateStatus uint8 = 22
typeKeyUpdate uint8 = 24
typeNextProtocol uint8 = 67 // Not IANA assigned
typeMessageHash uint8 = 254 // synthetic message
)
// TLS compression types.
const (
compressionNone uint8 = 0
)
// TLS extension numbers
const (
extensionServerName uint16 = 0
extensionStatusRequest uint16 = 5
extensionSupportedCurves uint16 = 10 // supported_groups in TLS 1.3, see RFC 8446, Section 4.2.7
extensionSupportedPoints uint16 = 11
extensionSignatureAlgorithms uint16 = 13
extensionALPN uint16 = 16
extensionSCT uint16 = 18
extensionSessionTicket uint16 = 35
extensionPreSharedKey uint16 = 41
extensionEarlyData uint16 = 42
extensionSupportedVersions uint16 = 43
extensionCookie uint16 = 44
extensionPSKModes uint16 = 45
extensionCertificateAuthorities uint16 = 47
extensionSignatureAlgorithmsCert uint16 = 50
extensionKeyShare uint16 = 51
extensionRenegotiationInfo uint16 = 0xff01
)
// TLS signaling cipher suite values
const (
scsvRenegotiation uint16 = 0x00ff
)
// CurveID is the type of a TLS identifier for an elliptic curve. See
// https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-8.
//
// In TLS 1.3, this type is called NamedGroup, but at this time this library
// only supports Elliptic Curve based groups. See RFC 8446, Section 4.2.7.
type CurveID uint16
const (
CurveP256 CurveID = 23
CurveP384 CurveID = 24
CurveP521 CurveID = 25
X25519 CurveID = 29
)
// TLS 1.3 Key Share. See RFC 8446, Section 4.2.8.
type keyShare struct {
group CurveID
data []byte
}
// TLS 1.3 PSK Key Exchange Modes. See RFC 8446, Section 4.2.9.
const (
pskModePlain uint8 = 0
pskModeDHE uint8 = 1
)
// TLS 1.3 PSK Identity. Can be a Session Ticket, or a reference to a saved
// session. See RFC 8446, Section 4.2.11.
type pskIdentity struct {
label []byte
obfuscatedTicketAge uint32
}
// TLS Elliptic Curve Point Formats
// https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-9
const (
pointFormatUncompressed uint8 = 0
)
// TLS CertificateStatusType (RFC 3546)
const (
statusTypeOCSP uint8 = 1
)
// Certificate types (for certificateRequestMsg)
const (
certTypeRSASign = 1
certTypeECDSASign = 64 // ECDSA or EdDSA keys, see RFC 8422, Section 3.
)
// Signature algorithms (for internal signaling use). Starting at 225 to avoid overlap with
// TLS 1.2 codepoints (RFC 5246, Appendix A.4.1), with which these have nothing to do.
const (
signaturePKCS1v15 uint8 = iota + 225
signatureRSAPSS
signatureECDSA
signatureEd25519
)
// directSigning is a standard Hash value that signals that no pre-hashing
// should be performed, and that the input should be signed directly. It is the
// hash function associated with the Ed25519 signature scheme.
var directSigning crypto.Hash = 0
// defaultSupportedSignatureAlgorithms contains the signature and hash algorithms that
// the code advertises as supported in a TLS 1.2+ ClientHello and in a TLS 1.2+
// CertificateRequest. The two fields are merged to match with TLS 1.3.
// Note that in TLS 1.2, the ECDSA algorithms are not constrained to P-256, etc.
var defaultSupportedSignatureAlgorithms = []SignatureScheme{
PSSWithSHA256,
ECDSAWithP256AndSHA256,
Ed25519,
PSSWithSHA384,
PSSWithSHA512,
PKCS1WithSHA256,
PKCS1WithSHA384,
PKCS1WithSHA512,
ECDSAWithP384AndSHA384,
ECDSAWithP521AndSHA512,
PKCS1WithSHA1,
ECDSAWithSHA1,
}
// helloRetryRequestRandom is set as the Random value of a ServerHello
// to signal that the message is actually a HelloRetryRequest.
var helloRetryRequestRandom = []byte{ // See RFC 8446, Section 4.1.3.
0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11,
0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91,
0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E,
0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C,
}
const (
// downgradeCanaryTLS12 or downgradeCanaryTLS11 is embedded in the server
// random as a downgrade protection if the server would be capable of
// negotiating a higher version. See RFC 8446, Section 4.1.3.
downgradeCanaryTLS12 = "DOWNGRD\x01"
downgradeCanaryTLS11 = "DOWNGRD\x00"
)
// testingOnlyForceDowngradeCanary is set in tests to force the server side to
// include downgrade canaries even if it's using its highest supported version.
var testingOnlyForceDowngradeCanary bool
// ConnectionState records basic TLS details about the connection.
type ConnectionState struct {
// Version is the TLS version used by the connection (e.g. VersionTLS12).
Version uint16
// HandshakeComplete is true if the handshake has concluded.
HandshakeComplete bool
// DidResume is true if this connection was successfully resumed from a
// previous session with a session ticket or similar mechanism.
DidResume bool
// CipherSuite is the cipher suite negotiated for the connection (e.g.
// TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_AES_128_GCM_SHA256).
CipherSuite uint16
// NegotiatedProtocol is the application protocol negotiated with ALPN.
NegotiatedProtocol string
// NegotiatedProtocolIsMutual used to indicate a mutual NPN negotiation.
//
// Deprecated: this value is always true.
NegotiatedProtocolIsMutual bool
// ServerName is the value of the Server Name Indication extension sent by
// the client. It's available both on the server and on the client side.
ServerName string
// PeerCertificates are the parsed certificates sent by the peer, in the
// order in which they were sent. The first element is the leaf certificate
// that the connection is verified against.
//
// On the client side, it can't be empty. On the server side, it can be
// empty if Config.ClientAuth is not RequireAnyClientCert or
// RequireAndVerifyClientCert.
//
// PeerCertificates and its contents should not be modified.
PeerCertificates []*x509.Certificate
// VerifiedChains is a list of one or more chains where the first element is
// PeerCertificates[0] and the last element is from Config.RootCAs (on the
// client side) or Config.ClientCAs (on the server side).
//
// On the client side, it's set if Config.InsecureSkipVerify is false. On
// the server side, it's set if Config.ClientAuth is VerifyClientCertIfGiven
// (and the peer provided a certificate) or RequireAndVerifyClientCert.
//
// VerifiedChains and its contents should not be modified.
VerifiedChains [][]*x509.Certificate
// SignedCertificateTimestamps is a list of SCTs provided by the peer
// through the TLS handshake for the leaf certificate, if any.
SignedCertificateTimestamps [][]byte
// OCSPResponse is a stapled Online Certificate Status Protocol (OCSP)
// response provided by the peer for the leaf certificate, if any.
OCSPResponse []byte
// TLSUnique contains the "tls-unique" channel binding value (see RFC 5929,
// Section 3). This value will be nil for TLS 1.3 connections and for all
// resumed connections.
//
// Deprecated: there are conditions in which this value might not be unique
// to a connection. See the Security Considerations sections of RFC 5705 and
// RFC 7627, and https://mitls.org/pages/attacks/3SHAKE#channelbindings.
TLSUnique []byte
// ekm is a closure exposed via ExportKeyingMaterial.
ekm func(label string, context []byte, length int) ([]byte, error)
}
// ExportKeyingMaterial returns length bytes of exported key material in a new
// slice as defined in RFC 5705. If context is nil, it is not used as part of
// the seed. If the connection was set to allow renegotiation via
// Config.Renegotiation, this function will return an error.
func (cs *ConnectionState) ExportKeyingMaterial(label string, context []byte, length int) ([]byte, error) {
return cs.ekm(label, context, length)
}
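// exampleKeyingMaterial is an illustrative sketch added by this edit, not
// part of the original source: deriving 32 bytes of RFC 5705 keying material
// bound to an established connection. The label is a hypothetical,
// application-chosen string; a nil context is omitted from the seed.
func exampleKeyingMaterial(conn *Conn) ([]byte, error) {
	cs := conn.ConnectionState()
	return cs.ExportKeyingMaterial("EXPERIMENTAL example-binding", nil, 32)
}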
// ClientAuthType declares the policy the server will follow for
// TLS Client Authentication.
type ClientAuthType int
const (
// NoClientCert indicates that no client certificate should be requested
// during the handshake, and if any certificates are sent they will not
// be verified.
NoClientCert ClientAuthType = iota
// RequestClientCert indicates that a client certificate should be requested
// during the handshake, but does not require that the client send any
// certificates.
RequestClientCert
// RequireAnyClientCert indicates that a client certificate should be requested
// during the handshake, and that at least one certificate is required to be
// sent by the client, but that certificate is not required to be valid.
RequireAnyClientCert
// VerifyClientCertIfGiven indicates that a client certificate should be requested
// during the handshake, but does not require that the client send a
// certificate. If the client does send a certificate it is required to be
// valid.
VerifyClientCertIfGiven
// RequireAndVerifyClientCert indicates that a client certificate should be requested
// during the handshake, and that at least one valid certificate is required
// to be sent by the client.
RequireAndVerifyClientCert
)
// requiresClientCert reports whether the ClientAuthType requires a client
// certificate to be provided.
func requiresClientCert(c ClientAuthType) bool {
switch c {
case RequireAnyClientCert, RequireAndVerifyClientCert:
return true
default:
return false
}
}
// ClientSessionState contains the state needed by clients to resume TLS
// sessions.
type ClientSessionState struct {
sessionTicket []uint8 // Encrypted ticket used for session resumption with server
vers uint16 // TLS version negotiated for the session
cipherSuite uint16 // Ciphersuite negotiated for the session
masterSecret []byte // Full handshake MasterSecret, or TLS 1.3 resumption_master_secret
serverCertificates []*x509.Certificate // Certificate chain presented by the server
verifiedChains [][]*x509.Certificate // Certificate chains we built for verification
receivedAt time.Time // When the session ticket was received from the server
ocspResponse []byte // Stapled OCSP response presented by the server
scts [][]byte // SCTs presented by the server
// TLS 1.3 fields.
nonce []byte // Ticket nonce sent by the server, to derive PSK
useBy time.Time // Expiration of the ticket lifetime as set by the server
ageAdd uint32 // Random obfuscation factor for sending the ticket age
}
// ClientSessionCache is a cache of ClientSessionState objects that can be used
// by a client to resume a TLS session with a given server. ClientSessionCache
// implementations should expect to be called concurrently from different
// goroutines. Up to TLS 1.2, only ticket-based resumption is supported, not
// SessionID-based resumption. In TLS 1.3 they were merged into PSK modes, which
// are supported via this interface.
type ClientSessionCache interface {
// Get searches for a ClientSessionState associated with the given key.
// On return, ok is true if one was found.
Get(sessionKey string) (session *ClientSessionState, ok bool)
// Put adds the ClientSessionState to the cache with the given key. It might
// get called multiple times in a connection if a TLS 1.3 server provides
// more than one session ticket. If called with a nil *ClientSessionState,
// it should remove the cache entry.
Put(sessionKey string, cs *ClientSessionState)
}
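// exampleResumptionConfig is an illustrative sketch added by this edit, not
// part of the original source: enabling client-side session resumption by
// installing the LRU ClientSessionCache constructed by
// NewLRUClientSessionCache (defined later in this package).
func exampleResumptionConfig() *Config {
	return &Config{
		ClientSessionCache: NewLRUClientSessionCache(128),
	}
}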
//go:generate stringer -type=SignatureScheme,CurveID,ClientAuthType -output=common_string.go
// SignatureScheme identifies a signature algorithm supported by TLS. See
// RFC 8446, Section 4.2.3.
type SignatureScheme uint16
const (
// RSASSA-PKCS1-v1_5 algorithms.
PKCS1WithSHA256 SignatureScheme = 0x0401
PKCS1WithSHA384 SignatureScheme = 0x0501
PKCS1WithSHA512 SignatureScheme = 0x0601
// RSASSA-PSS algorithms with public key OID rsaEncryption.
PSSWithSHA256 SignatureScheme = 0x0804
PSSWithSHA384 SignatureScheme = 0x0805
PSSWithSHA512 SignatureScheme = 0x0806
// ECDSA algorithms. Only constrained to a specific curve in TLS 1.3.
ECDSAWithP256AndSHA256 SignatureScheme = 0x0403
ECDSAWithP384AndSHA384 SignatureScheme = 0x0503
ECDSAWithP521AndSHA512 SignatureScheme = 0x0603
// EdDSA algorithms.
Ed25519 SignatureScheme = 0x0807
// Legacy signature and hash algorithms for TLS 1.2.
PKCS1WithSHA1 SignatureScheme = 0x0201
ECDSAWithSHA1 SignatureScheme = 0x0203
)
// ClientHelloInfo contains information from a ClientHello message in order to
// guide application logic in the GetCertificate and GetConfigForClient callbacks.
type ClientHelloInfo struct {
// CipherSuites lists the CipherSuites supported by the client (e.g.
// TLS_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256).
CipherSuites []uint16
// ServerName indicates the name of the server requested by the client
// in order to support virtual hosting. ServerName is only set if the
// client is using SNI (see RFC 4366, Section 3.1).
ServerName string
// SupportedCurves lists the elliptic curves supported by the client.
// SupportedCurves is set only if the Supported Elliptic Curves
// Extension is being used (see RFC 4492, Section 5.1.1).
SupportedCurves []CurveID
// SupportedPoints lists the point formats supported by the client.
// SupportedPoints is set only if the Supported Point Formats Extension
// is being used (see RFC 4492, Section 5.1.2).
SupportedPoints []uint8
// SignatureSchemes lists the signature and hash schemes that the client
// is willing to verify. SignatureSchemes is set only if the Signature
// Algorithms Extension is being used (see RFC 5246, Section 7.4.1.4.1).
SignatureSchemes []SignatureScheme
// SupportedProtos lists the application protocols supported by the client.
// SupportedProtos is set only if the Application-Layer Protocol
// Negotiation Extension is being used (see RFC 7301, Section 3.1).
//
// Servers can select a protocol by setting Config.NextProtos in a
// GetConfigForClient return value.
SupportedProtos []string
// SupportedVersions lists the TLS versions supported by the client.
// For TLS versions less than 1.3, this is extrapolated from the max
// version advertised by the client, so values other than the greatest
// might be rejected if used.
SupportedVersions []uint16
// Conn is the underlying net.Conn for the connection. Do not read
// from, or write to, this connection; that will cause the TLS
// connection to fail.
Conn net.Conn
// config is embedded by the GetCertificate or GetConfigForClient caller,
// for use with SupportsCertificate.
config *Config
// ctx is the context of the handshake that is in progress.
ctx context.Context
}
// Context returns the context of the handshake that is in progress.
// This context is a child of the context passed to HandshakeContext,
// if any, and is canceled when the handshake concludes.
func (c *ClientHelloInfo) Context() context.Context {
return c.ctx
}
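// exampleSNISelector is an illustrative sketch added by this edit, not part
// of the original source: a Config.GetCertificate callback that selects a
// certificate from the ServerName in the ClientHelloInfo. The certificate
// arguments and host name are hypothetical.
func exampleSNISelector(def, alt *Certificate) func(*ClientHelloInfo) (*Certificate, error) {
	return func(chi *ClientHelloInfo) (*Certificate, error) {
		if chi.ServerName == "alt.example.com" {
			return alt, nil
		}
		return def, nil
	}
}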
// CertificateRequestInfo contains information from a server's
// CertificateRequest message, which is used to demand a certificate and proof
// of control from a client.
type CertificateRequestInfo struct {
// AcceptableCAs contains zero or more, DER-encoded, X.501
// Distinguished Names. These are the names of root or intermediate CAs
// that the server wishes the returned certificate to be signed by. An
// empty slice indicates that the server has no preference.
AcceptableCAs [][]byte
// SignatureSchemes lists the signature schemes that the server is
// willing to verify.
SignatureSchemes []SignatureScheme
// Version is the TLS version that was negotiated for this connection.
Version uint16
// ctx is the context of the handshake that is in progress.
ctx context.Context
}
// Context returns the context of the handshake that is in progress.
// This context is a child of the context passed to HandshakeContext,
// if any, and is canceled when the handshake concludes.
func (c *CertificateRequestInfo) Context() context.Context {
return c.ctx
}
// RenegotiationSupport enumerates the different levels of support for TLS
// renegotiation. TLS renegotiation is the act of performing subsequent
// handshakes on a connection after the first. This significantly complicates
// the state machine and has been the source of numerous, subtle security
// issues. Initiating a renegotiation is not supported, but support for
// accepting renegotiation requests may be enabled.
//
// Even when enabled, the server may not change its identity between handshakes
// (i.e. the leaf certificate must be the same). Additionally, concurrent
// handshake and application data flow is not permitted so renegotiation can
// only be used with protocols that synchronise with the renegotiation, such as
// HTTPS.
//
// Renegotiation is not defined in TLS 1.3.
type RenegotiationSupport int
const (
// RenegotiateNever disables renegotiation.
RenegotiateNever RenegotiationSupport = iota
// RenegotiateOnceAsClient allows a remote server to request
// renegotiation once per connection.
RenegotiateOnceAsClient
// RenegotiateFreelyAsClient allows a remote server to repeatedly
// request renegotiation.
RenegotiateFreelyAsClient
)
// A Config structure is used to configure a TLS client or server.
// After one has been passed to a TLS function it must not be
// modified. A Config may be reused; the tls package will also not
// modify it.
type Config struct {
// Rand provides the source of entropy for nonces and RSA blinding.
// If Rand is nil, TLS uses the cryptographic random reader in package
// crypto/rand.
// The Reader must be safe for use by multiple goroutines.
Rand io.Reader
// Time returns the current time.
// If Time is nil, TLS uses time.Now.
Time func() time.Time
// Certificates contains one or more certificate chains to present to the
// other side of the connection. The first certificate compatible with the
// peer's requirements is selected automatically.
//
// Server configurations must set one of Certificates, GetCertificate or
// GetConfigForClient. Clients doing client-authentication may set either
// Certificates or GetClientCertificate.
//
// Note: if there are multiple Certificates, and they don't have the
// optional field Leaf set, certificate selection will incur a significant
// per-handshake performance cost.
Certificates []Certificate
// NameToCertificate maps from a certificate name to an element of
// Certificates. Note that a certificate name can be of the form
// '*.example.com' and so doesn't have to be a domain name as such.
//
// Deprecated: NameToCertificate only allows associating a single
// certificate with a given name. Leave this field nil to let the library
// select the first compatible chain from Certificates.
NameToCertificate map[string]*Certificate
// GetCertificate returns a Certificate based on the given
// ClientHelloInfo. It will only be called if the client supplies SNI
// information or if Certificates is empty.
//
// If GetCertificate is nil or returns nil, then the certificate is
// retrieved from NameToCertificate. If NameToCertificate is nil, the
// best element of Certificates will be used.
//
// Once a Certificate is returned it should not be modified.
GetCertificate func(*ClientHelloInfo) (*Certificate, error)
// GetClientCertificate, if not nil, is called when a server requests a
// certificate from a client. If set, the contents of Certificates will
// be ignored.
//
// If GetClientCertificate returns an error, the handshake will be
// aborted and that error will be returned. Otherwise
// GetClientCertificate must return a non-nil Certificate. If
// Certificate.Certificate is empty then no certificate will be sent to
// the server. If this is unacceptable to the server then it may abort
// the handshake.
//
// GetClientCertificate may be called multiple times for the same
// connection if renegotiation occurs or if TLS 1.3 is in use.
//
// Once a Certificate is returned it should not be modified.
GetClientCertificate func(*CertificateRequestInfo) (*Certificate, error)
// GetConfigForClient, if not nil, is called after a ClientHello is
// received from a client. It may return a non-nil Config in order to
// change the Config that will be used to handle this connection. If
// the returned Config is nil, the original Config will be used. The
// Config returned by this callback may not be subsequently modified.
//
// If GetConfigForClient is nil, the Config passed to Server() will be
// used for all connections.
//
// If SessionTicketKey was explicitly set on the returned Config, or if
// SetSessionTicketKeys was called on the returned Config, those keys will
// be used. Otherwise, the original Config keys will be used (and possibly
// rotated if they are automatically managed).
GetConfigForClient func(*ClientHelloInfo) (*Config, error)
// VerifyPeerCertificate, if not nil, is called after normal
// certificate verification by either a TLS client or server. It
// receives the raw ASN.1 certificates provided by the peer and also
// any verified chains that normal processing found. If it returns a
// non-nil error, the handshake is aborted and that error results.
//
// If normal verification fails then the handshake will abort before
// considering this callback. If normal verification is disabled by
// setting InsecureSkipVerify, or (for a server) when ClientAuth is
// RequestClientCert or RequireAnyClientCert, then this callback will
// be considered but the verifiedChains argument will always be nil.
//
// verifiedChains and its contents should not be modified.
VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
// VerifyConnection, if not nil, is called after normal certificate
// verification and after VerifyPeerCertificate by either a TLS client
// or server. If it returns a non-nil error, the handshake is aborted
// and that error results.
//
// If normal verification fails then the handshake will abort before
// considering this callback. This callback will run for all connections
// regardless of InsecureSkipVerify or ClientAuth settings.
VerifyConnection func(ConnectionState) error
// RootCAs defines the set of root certificate authorities
// that clients use when verifying server certificates.
// If RootCAs is nil, TLS uses the host's root CA set.
RootCAs *x509.CertPool
// NextProtos is a list of supported application level protocols, in
// order of preference. If both peers support ALPN, the selected
// protocol will be one from this list, and the connection will fail
// if there is no mutually supported protocol. If NextProtos is empty
// or the peer doesn't support ALPN, the connection will succeed and
// ConnectionState.NegotiatedProtocol will be empty.
NextProtos []string
// ServerName is used to verify the hostname on the returned
// certificates unless InsecureSkipVerify is given. It is also included
// in the client's handshake to support virtual hosting unless it is
// an IP address.
ServerName string
// ClientAuth determines the server's policy for
// TLS Client Authentication. The default is NoClientCert.
ClientAuth ClientAuthType
// ClientCAs defines the set of root certificate authorities
// that servers use if required to verify a client certificate
// by the policy in ClientAuth.
ClientCAs *x509.CertPool
// InsecureSkipVerify controls whether a client verifies the server's
// certificate chain and host name. If InsecureSkipVerify is true, crypto/tls
// accepts any certificate presented by the server and any host name in that
// certificate. In this mode, TLS is susceptible to machine-in-the-middle
// attacks unless custom verification is used. This should be used only for
// testing or in combination with VerifyConnection or VerifyPeerCertificate.
InsecureSkipVerify bool
// CipherSuites is a list of enabled TLS 1.0–1.2 cipher suites. The order of
// the list is ignored. Note that TLS 1.3 ciphersuites are not configurable.
//
// If CipherSuites is nil, a safe default list is used. The default cipher
// suites might change over time.
CipherSuites []uint16
// PreferServerCipherSuites is a legacy field and has no effect.
//
// It used to control whether the server would follow the client's or the
// server's preference. Servers now select the best mutually supported
// cipher suite based on logic that takes into account inferred client
// hardware, server hardware, and security.
//
// Deprecated: PreferServerCipherSuites is ignored.
PreferServerCipherSuites bool
// SessionTicketsDisabled may be set to true to disable session ticket and
// PSK (resumption) support. Note that on clients, session ticket support is
// also disabled if ClientSessionCache is nil.
SessionTicketsDisabled bool
// SessionTicketKey is used by TLS servers to provide session resumption.
// See RFC 5077 and the PSK mode of RFC 8446. If zero, it will be filled
// with random data before the first server handshake.
//
// Deprecated: if this field is left at zero, session ticket keys will be
// automatically rotated every day and dropped after seven days. For
// customizing the rotation schedule or synchronizing servers that are
// terminating connections for the same host, use SetSessionTicketKeys.
SessionTicketKey [32]byte
// ClientSessionCache is a cache of ClientSessionState entries for TLS
// session resumption. It is only used by clients.
ClientSessionCache ClientSessionCache
// MinVersion contains the minimum TLS version that is acceptable.
//
// By default, TLS 1.2 is currently used as the minimum when acting as a
// client, and TLS 1.0 when acting as a server. TLS 1.0 is the minimum
// supported by this package, both as a client and as a server.
//
// The client-side default can temporarily be reverted to TLS 1.0 by
// including the value "x509sha1=1" in the GODEBUG environment variable.
// Note that this option will be removed in Go 1.19 (but it will still be
// possible to set this field to VersionTLS10 explicitly).
MinVersion uint16
// MaxVersion contains the maximum TLS version that is acceptable.
//
// By default, the maximum version supported by this package is used,
// which is currently TLS 1.3.
MaxVersion uint16
// CurvePreferences contains the elliptic curves that will be used in
// an ECDHE handshake, in preference order. If empty, the default will
// be used. The client will use the first preference as the type for
// its key share in TLS 1.3. This may change in the future.
CurvePreferences []CurveID
// DynamicRecordSizingDisabled disables adaptive sizing of TLS records.
// When true, the largest possible TLS record size is always used. When
// false, the size of TLS records may be adjusted in an attempt to
// improve latency.
DynamicRecordSizingDisabled bool
// Renegotiation controls what types of renegotiation are supported.
// The default, none, is correct for the vast majority of applications.
Renegotiation RenegotiationSupport
// KeyLogWriter optionally specifies a destination for TLS master secrets
// in NSS key log format that can be used to allow external programs
// such as Wireshark to decrypt TLS connections.
// See https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format.
// Use of KeyLogWriter compromises security and should only be
// used for debugging.
KeyLogWriter io.Writer
// mutex protects sessionTicketKeys and autoSessionTicketKeys.
mutex sync.RWMutex
// sessionTicketKeys contains zero or more ticket keys. If set, it means
// the keys were set with SessionTicketKey or SetSessionTicketKeys. The
// first key is used for new tickets and any subsequent keys can be used to
// decrypt old tickets. The slice contents are not protected by the mutex
// and are immutable.
sessionTicketKeys []ticketKey
// autoSessionTicketKeys is like sessionTicketKeys but is owned by the
// auto-rotation logic. See Config.ticketKeys.
autoSessionTicketKeys []ticketKey
}
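// exampleCustomVerify is an illustrative sketch added by this edit, not part
// of the original source: pairing InsecureSkipVerify with VerifyConnection,
// as suggested by the field docs above, so that chain and host name checks
// still run but against a caller-supplied root pool.
func exampleCustomVerify(roots *x509.CertPool) *Config {
	return &Config{
		InsecureSkipVerify: true,
		VerifyConnection: func(cs ConnectionState) error {
			opts := x509.VerifyOptions{
				Roots:         roots,
				DNSName:       cs.ServerName,
				Intermediates: x509.NewCertPool(),
			}
			// The first certificate is the leaf; the rest are treated as
			// intermediates when building the chain.
			for _, cert := range cs.PeerCertificates[1:] {
				opts.Intermediates.AddCert(cert)
			}
			_, err := cs.PeerCertificates[0].Verify(opts)
			return err
		},
	}
}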
const (
// ticketKeyNameLen is the number of bytes of identifier that is prepended to
// an encrypted session ticket in order to identify the key used to encrypt it.
ticketKeyNameLen = 16
// ticketKeyLifetime is how long a ticket key remains valid and can be used to
// resume a client connection.
ticketKeyLifetime = 7 * 24 * time.Hour // 7 days
// ticketKeyRotation is how often the server should rotate the session ticket key
// that is used for new tickets.
ticketKeyRotation = 24 * time.Hour
)
// ticketKey is the internal representation of a session ticket key.
type ticketKey struct {
// keyName is an opaque byte string that serves to identify the session
// ticket key. It's exposed as plaintext in every session ticket.
keyName [ticketKeyNameLen]byte
aesKey [16]byte
hmacKey [16]byte
// created is the time at which this ticket key was created. See Config.ticketKeys.
created time.Time
}
// ticketKeyFromBytes converts from the external representation of a session
// ticket key to a ticketKey. Externally, session ticket keys are 32 random
// bytes and this function expands that into sufficient name and key material.
func (c *Config) ticketKeyFromBytes(b [32]byte) (key ticketKey) {
hashed := sha512.Sum512(b[:])
copy(key.keyName[:], hashed[:ticketKeyNameLen])
copy(key.aesKey[:], hashed[ticketKeyNameLen:ticketKeyNameLen+16])
copy(key.hmacKey[:], hashed[ticketKeyNameLen+16:ticketKeyNameLen+32])
key.created = c.time()
return key
}
// maxSessionTicketLifetime is the maximum allowed lifetime of a TLS 1.3 session
// ticket, and the lifetime we set for tickets we send.
const maxSessionTicketLifetime = 7 * 24 * time.Hour
// Clone returns a shallow clone of c or nil if c is nil. It is safe to clone a Config that is
// being used concurrently by a TLS client or server.
func (c *Config) Clone() *Config {
if c == nil {
return nil
}
c.mutex.RLock()
defer c.mutex.RUnlock()
return &Config{
Rand: c.Rand,
Time: c.Time,
Certificates: c.Certificates,
NameToCertificate: c.NameToCertificate,
GetCertificate: c.GetCertificate,
GetClientCertificate: c.GetClientCertificate,
GetConfigForClient: c.GetConfigForClient,
VerifyPeerCertificate: c.VerifyPeerCertificate,
VerifyConnection: c.VerifyConnection,
RootCAs: c.RootCAs,
NextProtos: c.NextProtos,
ServerName: c.ServerName,
ClientAuth: c.ClientAuth,
ClientCAs: c.ClientCAs,
InsecureSkipVerify: c.InsecureSkipVerify,
CipherSuites: c.CipherSuites,
PreferServerCipherSuites: c.PreferServerCipherSuites,
SessionTicketsDisabled: c.SessionTicketsDisabled,
SessionTicketKey: c.SessionTicketKey,
ClientSessionCache: c.ClientSessionCache,
MinVersion: c.MinVersion,
MaxVersion: c.MaxVersion,
CurvePreferences: c.CurvePreferences,
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
Renegotiation: c.Renegotiation,
KeyLogWriter: c.KeyLogWriter,
sessionTicketKeys: c.sessionTicketKeys,
autoSessionTicketKeys: c.autoSessionTicketKeys,
}
}
// deprecatedSessionTicketKey is set as the prefix of SessionTicketKey if it was
// randomized for backwards compatibility but is not in use.
var deprecatedSessionTicketKey = []byte("DEPRECATED")
// initLegacySessionTicketKeyRLocked ensures the legacy SessionTicketKey field is
// randomized if empty, and that sessionTicketKeys is populated from it otherwise.
func (c *Config) initLegacySessionTicketKeyRLocked() {
// Don't write if SessionTicketKey is already defined as our deprecated string,
// or if it is defined by the user but sessionTicketKeys is already set.
if c.SessionTicketKey != [32]byte{} &&
(bytes.HasPrefix(c.SessionTicketKey[:], deprecatedSessionTicketKey) || len(c.sessionTicketKeys) > 0) {
return
}
// We need to write some data, so get an exclusive lock and re-check any conditions.
c.mutex.RUnlock()
defer c.mutex.RLock()
c.mutex.Lock()
defer c.mutex.Unlock()
if c.SessionTicketKey == [32]byte{} {
if _, err := io.ReadFull(c.rand(), c.SessionTicketKey[:]); err != nil {
panic(fmt.Sprintf("tls: unable to generate random session ticket key: %v", err))
}
// Write the deprecated prefix at the beginning so we know we created
// it. This key with the DEPRECATED prefix isn't used as an actual
// session ticket key, and is only randomized in case the application
// reuses it for some reason.
copy(c.SessionTicketKey[:], deprecatedSessionTicketKey)
} else if !bytes.HasPrefix(c.SessionTicketKey[:], deprecatedSessionTicketKey) && len(c.sessionTicketKeys) == 0 {
c.sessionTicketKeys = []ticketKey{c.ticketKeyFromBytes(c.SessionTicketKey)}
}
}
// ticketKeys returns the ticketKeys for this connection.
// If configForClient has explicitly set keys, those will
// be returned. Otherwise, the keys on c will be used and
// may be rotated if auto-managed.
// During rotation, any expired session ticket keys are deleted from
// c.sessionTicketKeys. If the session ticket key that is currently
// encrypting tickets (i.e. the first ticketKey in c.sessionTicketKeys)
// is not fresh, then a new session ticket key will be
// created and prepended to c.sessionTicketKeys.
func (c *Config) ticketKeys(configForClient *Config) []ticketKey {
// If the ConfigForClient callback returned a Config with explicitly set
// keys, use those, otherwise just use the original Config.
if configForClient != nil {
configForClient.mutex.RLock()
if configForClient.SessionTicketsDisabled {
return nil
}
configForClient.initLegacySessionTicketKeyRLocked()
if len(configForClient.sessionTicketKeys) != 0 {
ret := configForClient.sessionTicketKeys
configForClient.mutex.RUnlock()
return ret
}
configForClient.mutex.RUnlock()
}
c.mutex.RLock()
defer c.mutex.RUnlock()
if c.SessionTicketsDisabled {
return nil
}
c.initLegacySessionTicketKeyRLocked()
if len(c.sessionTicketKeys) != 0 {
return c.sessionTicketKeys
}
// Fast path for the common case where the key is fresh enough.
if len(c.autoSessionTicketKeys) > 0 && c.time().Sub(c.autoSessionTicketKeys[0].created) < ticketKeyRotation {
return c.autoSessionTicketKeys
}
// autoSessionTicketKeys are managed by auto-rotation.
c.mutex.RUnlock()
defer c.mutex.RLock()
c.mutex.Lock()
defer c.mutex.Unlock()
// Re-check the condition in case it changed since obtaining the new lock.
if len(c.autoSessionTicketKeys) == 0 || c.time().Sub(c.autoSessionTicketKeys[0].created) >= ticketKeyRotation {
var newKey [32]byte
if _, err := io.ReadFull(c.rand(), newKey[:]); err != nil {
panic(fmt.Sprintf("unable to generate random session ticket key: %v", err))
}
valid := make([]ticketKey, 0, len(c.autoSessionTicketKeys)+1)
valid = append(valid, c.ticketKeyFromBytes(newKey))
for _, k := range c.autoSessionTicketKeys {
// While rotating the current key, also remove any expired ones.
if c.time().Sub(k.created) < ticketKeyLifetime {
valid = append(valid, k)
}
}
c.autoSessionTicketKeys = valid
}
return c.autoSessionTicketKeys
}
// SetSessionTicketKeys updates the session ticket keys for a server.
//
// The first key will be used when creating new tickets, while all keys can be
// used for decrypting tickets. It is safe to call this function while the
// server is running in order to rotate the session ticket keys. The function
// will panic if keys is empty.
//
// Calling this function will turn off automatic session ticket key rotation.
//
// If multiple servers are terminating connections for the same host they should
// all have the same session ticket keys. If the session ticket keys leak,
// previously recorded and future TLS connections using those keys might be
// compromised.
func (c *Config) SetSessionTicketKeys(keys [][32]byte) {
if len(keys) == 0 {
panic("tls: keys must have at least one key")
}
newKeys := make([]ticketKey, len(keys))
for i, bytes := range keys {
newKeys[i] = c.ticketKeyFromBytes(bytes)
}
c.mutex.Lock()
c.sessionTicketKeys = newKeys
c.mutex.Unlock()
}
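// exampleRotateTicketKeys is an illustrative sketch added by this edit, not
// part of the original source: prepending a freshly generated key while
// retaining old keys so outstanding tickets can still be decrypted. Note
// that calling SetSessionTicketKeys disables automatic key rotation.
func exampleRotateTicketKeys(c *Config, oldKeys [][32]byte) error {
	var newKey [32]byte
	if _, err := io.ReadFull(rand.Reader, newKey[:]); err != nil {
		return err
	}
	c.SetSessionTicketKeys(append([][32]byte{newKey}, oldKeys...))
	return nil
}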
func (c *Config) rand() io.Reader {
r := c.Rand
if r == nil {
return rand.Reader
}
return r
}
func (c *Config) time() time.Time {
t := c.Time
if t == nil {
t = time.Now
}
return t()
}
func (c *Config) cipherSuites() []uint16 {
if needFIPS() {
return fipsCipherSuites(c)
}
if c.CipherSuites != nil {
return c.CipherSuites
}
return defaultCipherSuites
}
var supportedVersions = []uint16{
VersionTLS13,
VersionTLS12,
VersionTLS11,
VersionTLS10,
}
// roleClient and roleServer are passed to functions like supportedVersions to
// make the boolean argument readable at the call site.
const roleClient = true
const roleServer = false
func (c *Config) supportedVersions(isClient bool) []uint16 {
versions := make([]uint16, 0, len(supportedVersions))
for _, v := range supportedVersions {
if needFIPS() && (v < fipsMinVersion(c) || v > fipsMaxVersion(c)) {
continue
}
if (c == nil || c.MinVersion == 0) &&
isClient && v < VersionTLS12 {
continue
}
if c != nil && c.MinVersion != 0 && v < c.MinVersion {
continue
}
if c != nil && c.MaxVersion != 0 && v > c.MaxVersion {
continue
}
versions = append(versions, v)
}
return versions
}
func (c *Config) maxSupportedVersion(isClient bool) uint16 {
supportedVersions := c.supportedVersions(isClient)
if len(supportedVersions) == 0 {
return 0
}
return supportedVersions[0]
}
// supportedVersionsFromMax returns a list of supported versions derived from a
// legacy maximum version value. Note that only versions supported by this
// library are returned. Any newer peer will use supportedVersions anyway.
func supportedVersionsFromMax(maxVersion uint16) []uint16 {
versions := make([]uint16, 0, len(supportedVersions))
for _, v := range supportedVersions {
if v > maxVersion {
continue
}
versions = append(versions, v)
}
return versions
}
var defaultCurvePreferences = []CurveID{X25519, CurveP256, CurveP384, CurveP521}
func (c *Config) curvePreferences() []CurveID {
if needFIPS() {
return fipsCurvePreferences(c)
}
if c == nil || len(c.CurvePreferences) == 0 {
return defaultCurvePreferences
}
return c.CurvePreferences
}
func (c *Config) supportsCurve(curve CurveID) bool {
for _, cc := range c.curvePreferences() {
if cc == curve {
return true
}
}
return false
}
// mutualVersion returns the protocol version to use given the advertised
// versions of the peer. Priority is given to the peer preference order.
func (c *Config) mutualVersion(isClient bool, peerVersions []uint16) (uint16, bool) {
supportedVersions := c.supportedVersions(isClient)
for _, peerVersion := range peerVersions {
for _, v := range supportedVersions {
if v == peerVersion {
return v, true
}
}
}
return 0, false
}
var errNoCertificates = errors.New("tls: no certificates configured")
// getCertificate returns the best certificate for the given ClientHelloInfo,
// defaulting to the first element of c.Certificates.
func (c *Config) getCertificate(clientHello *ClientHelloInfo) (*Certificate, error) {
if c.GetCertificate != nil &&
(len(c.Certificates) == 0 || len(clientHello.ServerName) > 0) {
cert, err := c.GetCertificate(clientHello)
if cert != nil || err != nil {
return cert, err
}
}
if len(c.Certificates) == 0 {
return nil, errNoCertificates
}
if len(c.Certificates) == 1 {
// There's only one choice, so no point doing any work.
return &c.Certificates[0], nil
}
if c.NameToCertificate != nil {
name := strings.ToLower(clientHello.ServerName)
if cert, ok := c.NameToCertificate[name]; ok {
return cert, nil
}
if len(name) > 0 {
labels := strings.Split(name, ".")
labels[0] = "*"
wildcardName := strings.Join(labels, ".")
if cert, ok := c.NameToCertificate[wildcardName]; ok {
return cert, nil
}
}
}
for _, cert := range c.Certificates {
if err := clientHello.SupportsCertificate(&cert); err == nil {
return &cert, nil
}
}
// If nothing matches, return the first certificate.
return &c.Certificates[0], nil
}
// SupportsCertificate returns nil if the provided certificate is supported by
// the client that sent the ClientHello. Otherwise, it returns an error
// describing the reason for the incompatibility.
//
// If this ClientHelloInfo was passed to a GetConfigForClient or GetCertificate
// callback, this method will take into account the associated Config. Note that
// if GetConfigForClient returns a different Config, the change can't be
// accounted for by this method.
//
// This function will call x509.ParseCertificate unless c.Leaf is set, which can
// incur a significant performance cost.
func (chi *ClientHelloInfo) SupportsCertificate(c *Certificate) error {
// Note we don't currently support certificate_authorities nor
// signature_algorithms_cert, and don't check the algorithms of the
// signatures on the chain (which anyway are a SHOULD, see RFC 8446,
// Section 4.4.2.2).
config := chi.config
if config == nil {
config = &Config{}
}
vers, ok := config.mutualVersion(roleServer, chi.SupportedVersions)
if !ok {
return errors.New("no mutually supported protocol versions")
}
// If the client specified the name they are trying to connect to, the
// certificate needs to be valid for it.
if chi.ServerName != "" {
x509Cert, err := c.leaf()
if err != nil {
return fmt.Errorf("failed to parse certificate: %w", err)
}
if err := x509Cert.VerifyHostname(chi.ServerName); err != nil {
return fmt.Errorf("certificate is not valid for requested server name: %w", err)
}
}
// supportsRSAFallback returns nil if the certificate and connection support
// the static RSA key exchange, and unsupported otherwise. The logic for
// supporting static RSA is completely disjoint from the logic for
// supporting signed key exchanges, so we just check it as a fallback.
supportsRSAFallback := func(unsupported error) error {
// TLS 1.3 dropped support for the static RSA key exchange.
if vers == VersionTLS13 {
return unsupported
}
// The static RSA key exchange works by decrypting a challenge with the
// RSA private key, not by signing, so check the PrivateKey implements
// crypto.Decrypter, like *rsa.PrivateKey does.
if priv, ok := c.PrivateKey.(crypto.Decrypter); ok {
if _, ok := priv.Public().(*rsa.PublicKey); !ok {
return unsupported
}
} else {
return unsupported
}
// Finally, there needs to be a mutual cipher suite that uses the static
// RSA key exchange instead of ECDHE.
rsaCipherSuite := selectCipherSuite(chi.CipherSuites, config.cipherSuites(), func(c *cipherSuite) bool {
if c.flags&suiteECDHE != 0 {
return false
}
if vers < VersionTLS12 && c.flags&suiteTLS12 != 0 {
return false
}
return true
})
if rsaCipherSuite == nil {
return unsupported
}
return nil
}
// If the client sent the signature_algorithms extension, ensure it supports
// schemes we can use with this certificate and TLS version.
if len(chi.SignatureSchemes) > 0 {
if _, err := selectSignatureScheme(vers, c, chi.SignatureSchemes); err != nil {
return supportsRSAFallback(err)
}
}
// In TLS 1.3 we are done because supported_groups is only relevant to the
// ECDHE computation, point format negotiation is removed, cipher suites are
// only relevant to the AEAD choice, and static RSA does not exist.
if vers == VersionTLS13 {
return nil
}
// The only signed key exchange we support is ECDHE.
if !supportsECDHE(config, chi.SupportedCurves, chi.SupportedPoints) {
return supportsRSAFallback(errors.New("client doesn't support ECDHE, can only use legacy RSA key exchange"))
}
var ecdsaCipherSuite bool
if priv, ok := c.PrivateKey.(crypto.Signer); ok {
switch pub := priv.Public().(type) {
case *ecdsa.PublicKey:
var curve CurveID
switch pub.Curve {
case elliptic.P256():
curve = CurveP256
case elliptic.P384():
curve = CurveP384
case elliptic.P521():
curve = CurveP521
default:
return supportsRSAFallback(unsupportedCertificateError(c))
}
var curveOk bool
for _, c := range chi.SupportedCurves {
if c == curve && config.supportsCurve(c) {
curveOk = true
break
}
}
if !curveOk {
return errors.New("client doesn't support certificate curve")
}
ecdsaCipherSuite = true
case ed25519.PublicKey:
if vers < VersionTLS12 || len(chi.SignatureSchemes) == 0 {
return errors.New("connection doesn't support Ed25519")
}
ecdsaCipherSuite = true
case *rsa.PublicKey:
default:
return supportsRSAFallback(unsupportedCertificateError(c))
}
} else {
return supportsRSAFallback(unsupportedCertificateError(c))
}
// Make sure that there is a mutually supported cipher suite that works with
// this certificate. Cipher suite selection will then apply the logic in
// reverse to pick it. See also serverHandshakeState.cipherSuiteOk.
cipherSuite := selectCipherSuite(chi.CipherSuites, config.cipherSuites(), func(c *cipherSuite) bool {
if c.flags&suiteECDHE == 0 {
return false
}
if c.flags&suiteECSign != 0 {
if !ecdsaCipherSuite {
return false
}
} else {
if ecdsaCipherSuite {
return false
}
}
if vers < VersionTLS12 && c.flags&suiteTLS12 != 0 {
return false
}
return true
})
if cipherSuite == nil {
return supportsRSAFallback(errors.New("client doesn't support any cipher suites compatible with the certificate"))
}
return nil
}
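// exampleCertFallback is an illustrative sketch added by this edit, not part
// of the original source: using SupportsCertificate inside a GetCertificate
// callback to serve a preferred certificate only when the client can accept
// it. Both certificate arguments are hypothetical.
func exampleCertFallback(preferred, fallback *Certificate) func(*ClientHelloInfo) (*Certificate, error) {
	return func(chi *ClientHelloInfo) (*Certificate, error) {
		if err := chi.SupportsCertificate(preferred); err == nil {
			return preferred, nil
		}
		return fallback, nil
	}
}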
// SupportsCertificate returns nil if the provided certificate is supported by
// the server that sent the CertificateRequest. Otherwise, it returns an error
// describing the reason for the incompatibility.
func (cri *CertificateRequestInfo) SupportsCertificate(c *Certificate) error {
if _, err := selectSignatureScheme(cri.Version, c, cri.SignatureSchemes); err != nil {
return err
}
if len(cri.AcceptableCAs) == 0 {
return nil
}
for j, cert := range c.Certificate {
x509Cert := c.Leaf
// Parse the certificate if this isn't the leaf node, or if
// c.Leaf was nil.
if j != 0 || x509Cert == nil {
var err error
if x509Cert, err = x509.ParseCertificate(cert); err != nil {
return fmt.Errorf("failed to parse certificate #%d in the chain: %w", j, err)
}
}
for _, ca := range cri.AcceptableCAs {
if bytes.Equal(x509Cert.RawIssuer, ca) {
return nil
}
}
}
return errors.New("chain is not signed by an acceptable CA")
}
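// exampleClientCertCallback is an illustrative sketch added by this edit, not
// part of the original source: a GetClientCertificate callback that offers a
// certificate only when the server's CertificateRequest is compatible with
// it, and otherwise returns a non-nil empty Certificate so that no
// certificate is sent (as the GetClientCertificate docs require).
func exampleClientCertCallback(cert *Certificate) func(*CertificateRequestInfo) (*Certificate, error) {
	return func(cri *CertificateRequestInfo) (*Certificate, error) {
		if err := cri.SupportsCertificate(cert); err == nil {
			return cert, nil
		}
		return &Certificate{}, nil
	}
}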
// BuildNameToCertificate parses c.Certificates and builds c.NameToCertificate
// from the CommonName and Subject Alternative Name fields of each of the leaf
// certificates.
//
// Deprecated: NameToCertificate only allows associating a single certificate
// with a given name. Leave that field nil to let the library select the first
// compatible chain from Certificates.
func (c *Config) BuildNameToCertificate() {
c.NameToCertificate = make(map[string]*Certificate)
for i := range c.Certificates {
cert := &c.Certificates[i]
x509Cert, err := cert.leaf()
if err != nil {
continue
}
// If SANs are *not* present, some clients will consider the certificate
// valid for the name in the Common Name.
if x509Cert.Subject.CommonName != "" && len(x509Cert.DNSNames) == 0 {
c.NameToCertificate[x509Cert.Subject.CommonName] = cert
}
for _, san := range x509Cert.DNSNames {
c.NameToCertificate[san] = cert
}
}
}
const (
keyLogLabelTLS12 = "CLIENT_RANDOM"
keyLogLabelClientHandshake = "CLIENT_HANDSHAKE_TRAFFIC_SECRET"
keyLogLabelServerHandshake = "SERVER_HANDSHAKE_TRAFFIC_SECRET"
keyLogLabelClientTraffic = "CLIENT_TRAFFIC_SECRET_0"
keyLogLabelServerTraffic = "SERVER_TRAFFIC_SECRET_0"
)
func (c *Config) writeKeyLog(label string, clientRandom, secret []byte) error {
if c.KeyLogWriter == nil {
return nil
}
logLine := fmt.Appendf(nil, "%s %x %x\n", label, clientRandom, secret)
writerMutex.Lock()
_, err := c.KeyLogWriter.Write(logLine)
writerMutex.Unlock()
return err
}
// writerMutex protects all KeyLogWriters globally. It is rarely enabled,
// and is only for debugging, so a global mutex saves space.
var writerMutex sync.Mutex
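// exampleKeyLogConfig is an illustrative sketch added by this edit, not part
// of the original source: wiring KeyLogWriter to a caller-opened writer (for
// instance, a file named by the conventional SSLKEYLOGFILE environment
// variable) so tools such as Wireshark can decrypt captures. Debugging only:
// this deliberately compromises connection secrecy.
func exampleKeyLogConfig(w io.Writer) *Config {
	return &Config{KeyLogWriter: w}
}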
// A Certificate is a chain of one or more certificates, leaf first.
type Certificate struct {
Certificate [][]byte
// PrivateKey contains the private key corresponding to the public key in
// Leaf. This must implement crypto.Signer with an RSA, ECDSA or Ed25519 PublicKey.
// For a server up to TLS 1.2, it can also implement crypto.Decrypter with
// an RSA PublicKey.
PrivateKey crypto.PrivateKey
// SupportedSignatureAlgorithms is an optional list restricting what
// signature algorithms the PrivateKey can be used for.
SupportedSignatureAlgorithms []SignatureScheme
// OCSPStaple contains an optional OCSP response which will be served
// to clients that request it.
OCSPStaple []byte
// SignedCertificateTimestamps contains an optional list of Signed
// Certificate Timestamps which will be served to clients that request it.
SignedCertificateTimestamps [][]byte
// Leaf is the parsed form of the leaf certificate, which may be initialized
// using x509.ParseCertificate to reduce per-handshake processing. If nil,
// the leaf certificate will be parsed as needed.
Leaf *x509.Certificate
}
// leaf returns the parsed leaf certificate, either from c.Leaf or by parsing
// the corresponding c.Certificate[0].
func (c *Certificate) leaf() (*x509.Certificate, error) {
if c.Leaf != nil {
return c.Leaf, nil
}
return x509.ParseCertificate(c.Certificate[0])
}
type handshakeMessage interface {
marshal() ([]byte, error)
unmarshal([]byte) bool
}
// lruSessionCache is a ClientSessionCache implementation that uses an LRU
// caching strategy.
type lruSessionCache struct {
sync.Mutex
m map[string]*list.Element
q *list.List
capacity int
}
type lruSessionCacheEntry struct {
sessionKey string
state *ClientSessionState
}
// NewLRUClientSessionCache returns a ClientSessionCache with the given
// capacity that uses an LRU strategy. If capacity is < 1, a default capacity
// is used instead.
func NewLRUClientSessionCache(capacity int) ClientSessionCache {
const defaultSessionCacheCapacity = 64
if capacity < 1 {
capacity = defaultSessionCacheCapacity
}
return &lruSessionCache{
m: make(map[string]*list.Element),
q: list.New(),
capacity: capacity,
}
}
// Put adds the provided (sessionKey, cs) pair to the cache. If cs is nil, the entry
// corresponding to sessionKey is removed from the cache instead.
func (c *lruSessionCache) Put(sessionKey string, cs *ClientSessionState) {
c.Lock()
defer c.Unlock()
if elem, ok := c.m[sessionKey]; ok {
if cs == nil {
c.q.Remove(elem)
delete(c.m, sessionKey)
} else {
entry := elem.Value.(*lruSessionCacheEntry)
entry.state = cs
c.q.MoveToFront(elem)
}
return
}
if c.q.Len() < c.capacity {
entry := &lruSessionCacheEntry{sessionKey, cs}
c.m[sessionKey] = c.q.PushFront(entry)
return
}
elem := c.q.Back()
entry := elem.Value.(*lruSessionCacheEntry)
delete(c.m, entry.sessionKey)
entry.sessionKey = sessionKey
entry.state = cs
c.q.MoveToFront(elem)
c.m[sessionKey] = elem
}
// Get returns the ClientSessionState value associated with a given key. It
// returns (nil, false) if no value is found.
func (c *lruSessionCache) Get(sessionKey string) (*ClientSessionState, bool) {
c.Lock()
defer c.Unlock()
if elem, ok := c.m[sessionKey]; ok {
c.q.MoveToFront(elem)
return elem.Value.(*lruSessionCacheEntry).state, true
}
return nil, false
}
var emptyConfig Config
func defaultConfig() *Config {
return &emptyConfig
}
func unexpectedMessageError(wanted, got any) error {
return fmt.Errorf("tls: received unexpected handshake message of type %T when waiting for %T", got, wanted)
}
func isSupportedSignatureAlgorithm(sigAlg SignatureScheme, supportedSignatureAlgorithms []SignatureScheme) bool {
for _, s := range supportedSignatureAlgorithms {
if s == sigAlg {
return true
}
}
return false
}
// CertificateVerificationError is returned when certificate verification fails during the handshake.
type CertificateVerificationError struct {
// UnverifiedCertificates and its contents should not be modified.
UnverifiedCertificates []*x509.Certificate
Err error
}
func (e *CertificateVerificationError) Error() string {
return fmt.Sprintf("tls: failed to verify certificate: %s", e.Err)
}
func (e *CertificateVerificationError) Unwrap() error {
return e.Err
}
// Code generated by "stringer -type=SignatureScheme,CurveID,ClientAuthType -output=common_string.go"; DO NOT EDIT.
package tls
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[PKCS1WithSHA256-1025]
_ = x[PKCS1WithSHA384-1281]
_ = x[PKCS1WithSHA512-1537]
_ = x[PSSWithSHA256-2052]
_ = x[PSSWithSHA384-2053]
_ = x[PSSWithSHA512-2054]
_ = x[ECDSAWithP256AndSHA256-1027]
_ = x[ECDSAWithP384AndSHA384-1283]
_ = x[ECDSAWithP521AndSHA512-1539]
_ = x[Ed25519-2055]
_ = x[PKCS1WithSHA1-513]
_ = x[ECDSAWithSHA1-515]
}
const (
_SignatureScheme_name_0 = "PKCS1WithSHA1"
_SignatureScheme_name_1 = "ECDSAWithSHA1"
_SignatureScheme_name_2 = "PKCS1WithSHA256"
_SignatureScheme_name_3 = "ECDSAWithP256AndSHA256"
_SignatureScheme_name_4 = "PKCS1WithSHA384"
_SignatureScheme_name_5 = "ECDSAWithP384AndSHA384"
_SignatureScheme_name_6 = "PKCS1WithSHA512"
_SignatureScheme_name_7 = "ECDSAWithP521AndSHA512"
_SignatureScheme_name_8 = "PSSWithSHA256PSSWithSHA384PSSWithSHA512Ed25519"
)
var (
_SignatureScheme_index_8 = [...]uint8{0, 13, 26, 39, 46}
)
func (i SignatureScheme) String() string {
switch {
case i == 513:
return _SignatureScheme_name_0
case i == 515:
return _SignatureScheme_name_1
case i == 1025:
return _SignatureScheme_name_2
case i == 1027:
return _SignatureScheme_name_3
case i == 1281:
return _SignatureScheme_name_4
case i == 1283:
return _SignatureScheme_name_5
case i == 1537:
return _SignatureScheme_name_6
case i == 1539:
return _SignatureScheme_name_7
case 2052 <= i && i <= 2055:
i -= 2052
return _SignatureScheme_name_8[_SignatureScheme_index_8[i]:_SignatureScheme_index_8[i+1]]
default:
return "SignatureScheme(" + strconv.FormatInt(int64(i), 10) + ")"
}
}
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[CurveP256-23]
_ = x[CurveP384-24]
_ = x[CurveP521-25]
_ = x[X25519-29]
}
const (
_CurveID_name_0 = "CurveP256CurveP384CurveP521"
_CurveID_name_1 = "X25519"
)
var (
_CurveID_index_0 = [...]uint8{0, 9, 18, 27}
)
func (i CurveID) String() string {
switch {
case 23 <= i && i <= 25:
i -= 23
return _CurveID_name_0[_CurveID_index_0[i]:_CurveID_index_0[i+1]]
case i == 29:
return _CurveID_name_1
default:
return "CurveID(" + strconv.FormatInt(int64(i), 10) + ")"
}
}
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[NoClientCert-0]
_ = x[RequestClientCert-1]
_ = x[RequireAnyClientCert-2]
_ = x[VerifyClientCertIfGiven-3]
_ = x[RequireAndVerifyClientCert-4]
}
const _ClientAuthType_name = "NoClientCertRequestClientCertRequireAnyClientCertVerifyClientCertIfGivenRequireAndVerifyClientCert"
var _ClientAuthType_index = [...]uint8{0, 12, 29, 49, 72, 98}
func (i ClientAuthType) String() string {
if i < 0 || i >= ClientAuthType(len(_ClientAuthType_index)-1) {
return "ClientAuthType(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _ClientAuthType_name[_ClientAuthType_index[i]:_ClientAuthType_index[i+1]]
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// TLS low level connection and record layer
package tls
import (
"bytes"
"context"
"crypto/cipher"
"crypto/subtle"
"crypto/x509"
"errors"
"fmt"
"hash"
"io"
"net"
"sync"
"sync/atomic"
"time"
)
// A Conn represents a secured connection.
// It implements the net.Conn interface.
type Conn struct {
// constant
conn net.Conn
isClient bool
handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake
// isHandshakeComplete is true if the connection is currently transferring
// application data (i.e. is not currently processing a handshake).
// If isHandshakeComplete is true, then handshakeErr == nil.
isHandshakeComplete atomic.Bool
// constant after handshake; protected by handshakeMutex
handshakeMutex sync.Mutex
handshakeErr error // error resulting from handshake
vers uint16 // TLS version
haveVers bool // version has been negotiated
config *Config // configuration passed to constructor
// handshakes counts the number of handshakes performed on the
// connection so far. If renegotiation is disabled then this is either
// zero or one.
handshakes int
didResume bool // whether this connection was a session resumption
cipherSuite uint16
ocspResponse []byte // stapled OCSP response
scts [][]byte // signed certificate timestamps from server
peerCertificates []*x509.Certificate
// activeCertHandles contains the cache handles to certificates in
// peerCertificates that are used to track active references.
activeCertHandles []*activeCert
// verifiedChains contains the certificate chains that we built, as
// opposed to the ones presented by the server.
verifiedChains [][]*x509.Certificate
// serverName contains the server name indicated by the client, if any.
serverName string
// secureRenegotiation is true if the server echoed the secure
// renegotiation extension. (This is meaningless as a server because
// renegotiation is not supported in that case.)
secureRenegotiation bool
// ekm is a closure for exporting keying material.
ekm func(label string, context []byte, length int) ([]byte, error)
// resumptionSecret is the resumption_master_secret for handling
// NewSessionTicket messages. nil if config.SessionTicketsDisabled.
resumptionSecret []byte
// ticketKeys is the set of active session ticket keys for this
// connection. The first one is used to encrypt new tickets and
// all are tried to decrypt tickets.
ticketKeys []ticketKey
// clientFinishedIsFirst is true if the client sent the first Finished
// message during the most recent handshake. This is recorded because
// the first transmitted Finished message is the tls-unique
// channel-binding value.
clientFinishedIsFirst bool
// closeNotifyErr is any error from sending the alertCloseNotify record.
closeNotifyErr error
// closeNotifySent is true if the Conn attempted to send an
// alertCloseNotify record.
closeNotifySent bool
// clientFinished and serverFinished contain the Finished message sent
// by the client or server in the most recent handshake. This is
// retained to support the renegotiation extension and tls-unique
// channel-binding.
clientFinished [12]byte
serverFinished [12]byte
// clientProtocol is the negotiated ALPN protocol.
clientProtocol string
// input/output
in, out halfConn
rawInput bytes.Buffer // raw input, starting with a record header
input bytes.Reader // application data waiting to be read, from rawInput.Next
hand bytes.Buffer // handshake data waiting to be read
buffering bool // whether records are buffered in sendBuf
sendBuf []byte // a buffer of records waiting to be sent
// bytesSent counts the bytes of application data sent.
// packetsSent counts packets.
bytesSent int64
packetsSent int64
// retryCount counts the number of consecutive non-advancing records
// received by Conn.readRecord. That is, records that neither advance the
// handshake, nor deliver application data. Protected by in.Mutex.
retryCount int
// activeCall indicates whether Close has been called in the low bit.
// The rest of the bits count the number of goroutines in Conn.Write.
activeCall atomic.Int32
tmp [16]byte
}
// Access to net.Conn methods.
// Cannot just embed net.Conn because that would
// export the struct field too.
// LocalAddr returns the local network address.
func (c *Conn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
// RemoteAddr returns the remote network address.
func (c *Conn) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
}
// SetDeadline sets the read and write deadlines associated with the connection.
// A zero value for t means Read and Write will not time out.
// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
func (c *Conn) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t)
}
// SetReadDeadline sets the read deadline on the underlying connection.
// A zero value for t means Read will not time out.
func (c *Conn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
// SetWriteDeadline sets the write deadline on the underlying connection.
// A zero value for t means Write will not time out.
// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
func (c *Conn) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}
// NetConn returns the underlying connection that is wrapped by c.
// Note that writing to or reading from this connection directly will corrupt the
// TLS session.
func (c *Conn) NetConn() net.Conn {
return c.conn
}
// A halfConn represents one direction of the record layer
// connection, either sending or receiving.
type halfConn struct {
sync.Mutex
err error // first permanent error
version uint16 // protocol version
cipher any // cipher algorithm
mac hash.Hash
seq [8]byte // 64-bit sequence number
scratchBuf [13]byte // to avoid allocs; interface method args escape
nextCipher any // next encryption state
nextMac hash.Hash // next MAC algorithm
trafficSecret []byte // current TLS 1.3 traffic secret
}
type permanentError struct {
err net.Error
}
func (e *permanentError) Error() string { return e.err.Error() }
func (e *permanentError) Unwrap() error { return e.err }
func (e *permanentError) Timeout() bool { return e.err.Timeout() }
func (e *permanentError) Temporary() bool { return false }
func (hc *halfConn) setErrorLocked(err error) error {
if e, ok := err.(net.Error); ok {
hc.err = &permanentError{err: e}
} else {
hc.err = err
}
return hc.err
}
// prepareCipherSpec sets the encryption and MAC states
// that a subsequent changeCipherSpec will use.
func (hc *halfConn) prepareCipherSpec(version uint16, cipher any, mac hash.Hash) {
hc.version = version
hc.nextCipher = cipher
hc.nextMac = mac
}
// changeCipherSpec changes the encryption and MAC states
// to the ones previously passed to prepareCipherSpec.
func (hc *halfConn) changeCipherSpec() error {
if hc.nextCipher == nil || hc.version == VersionTLS13 {
return alertInternalError
}
hc.cipher = hc.nextCipher
hc.mac = hc.nextMac
hc.nextCipher = nil
hc.nextMac = nil
for i := range hc.seq {
hc.seq[i] = 0
}
return nil
}
func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, secret []byte) {
hc.trafficSecret = secret
key, iv := suite.trafficKey(secret)
hc.cipher = suite.aead(key, iv)
for i := range hc.seq {
hc.seq[i] = 0
}
}
// incSeq increments the sequence number.
func (hc *halfConn) incSeq() {
for i := 7; i >= 0; i-- {
hc.seq[i]++
if hc.seq[i] != 0 {
return
}
}
// Not allowed to let sequence number wrap.
// Instead, must renegotiate before it does.
// Not likely enough to bother.
panic("TLS: sequence number wraparound")
}
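// seqAsUint64 is an illustrative sketch, not used by the record layer: it
// makes explicit that hc.seq is simply a 64-bit big-endian counter, which
// incSeq advances by one and which decrypt/encrypt feed into MACs and AEAD
// nonces. The method name is hypothetical.
func (hc *halfConn) seqAsUint64() uint64 {
// Fold the 8 big-endian bytes into a single integer.
var n uint64
for _, b := range hc.seq {
n = n<<8 | uint64(b)
}
return n
}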
// explicitNonceLen returns the number of bytes of explicit nonce or IV included
// in each record. Explicit nonces are present only in CBC modes after TLS 1.0
// and in certain AEAD modes in TLS 1.2.
func (hc *halfConn) explicitNonceLen() int {
if hc.cipher == nil {
return 0
}
switch c := hc.cipher.(type) {
case cipher.Stream:
return 0
case aead:
return c.explicitNonceLen()
case cbcMode:
// TLS 1.1 introduced a per-record explicit IV to fix the BEAST attack.
if hc.version >= VersionTLS11 {
return c.BlockSize()
}
return 0
default:
panic("unknown cipher type")
}
}
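// For the cipher suites in this package this works out to: 0 for stream
// ciphers and for CBC in TLS 1.0; the cipher block size (16 bytes for AES)
// for CBC in TLS 1.1 and 1.2; 8 bytes for AES-GCM in TLS 1.2 (RFC 5288);
// and 0 for TLS 1.3 AEADs, which derive each record nonce from hc.seq.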
// extractPadding returns, in constant time, the length of the padding to remove
// from the end of payload. It also returns a byte which is equal to 255 if the
// padding was valid and 0 otherwise. See RFC 2246, Section 6.2.3.2.
func extractPadding(payload []byte) (toRemove int, good byte) {
if len(payload) < 1 {
return 0, 0
}
paddingLen := payload[len(payload)-1]
t := uint(len(payload)-1) - uint(paddingLen)
// if len(payload) > paddingLen then the MSB of t is zero
good = byte(int32(^t) >> 31)
// The maximum possible padding length plus the actual length field
toCheck := 256
// The length of the padded data is public, so we can use an if here
if toCheck > len(payload) {
toCheck = len(payload)
}
for i := 0; i < toCheck; i++ {
t := uint(paddingLen) - uint(i)
// if i <= paddingLen then the MSB of t is zero
mask := byte(int32(^t) >> 31)
b := payload[len(payload)-1-i]
good &^= mask&paddingLen ^ mask&b
}
// We AND together the bits of good and replicate the result across
// all the bits.
good &= good << 4
good &= good << 2
good &= good << 1
good = uint8(int8(good) >> 7)
// Zero the padding length on error. This ensures any unchecked bytes
// are included in the MAC. Otherwise, an attacker that could
// distinguish MAC failures from padding failures could mount an attack
// similar to POODLE in SSL 3.0: given a good ciphertext that uses a
// full block's worth of padding, replace the final block with another
// block. If the MAC check passed but the padding check failed, the
// last byte of that block decrypted to the block size.
//
// See also macAndPaddingGood logic below.
paddingLen &= good
toRemove = int(paddingLen) + 1
return
}
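// A worked example of the above (a sketch, not part of the constant-time
// path): for payload = data || {3, 3, 3, 3}, the final byte declares
// paddingLen = 3, the three bytes before it all equal 3, so good == 255 and
// toRemove == 4, the padding bytes plus the length byte itself.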
func roundUp(a, b int) int {
return a + (b-a%b)%b
}
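// For example, roundUp(21, 16) == 32 while roundUp(32, 16) == 32; the CBC
// paths use it to compute the smallest block-aligned size that fits the MAC
// plus at least one padding byte.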
// cbcMode is an interface for block ciphers using cipher block chaining.
type cbcMode interface {
cipher.BlockMode
SetIV([]byte)
}
// decrypt authenticates and decrypts the record if protection is active at
// this stage. The returned plaintext might overlap with the input.
func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) {
var plaintext []byte
typ := recordType(record[0])
payload := record[recordHeaderLen:]
// In TLS 1.3, change_cipher_spec messages are to be ignored without being
// decrypted. See RFC 8446, Appendix D.4.
if hc.version == VersionTLS13 && typ == recordTypeChangeCipherSpec {
return payload, typ, nil
}
paddingGood := byte(255)
paddingLen := 0
explicitNonceLen := hc.explicitNonceLen()
if hc.cipher != nil {
switch c := hc.cipher.(type) {
case cipher.Stream:
c.XORKeyStream(payload, payload)
case aead:
if len(payload) < explicitNonceLen {
return nil, 0, alertBadRecordMAC
}
nonce := payload[:explicitNonceLen]
if len(nonce) == 0 {
nonce = hc.seq[:]
}
payload = payload[explicitNonceLen:]
var additionalData []byte
if hc.version == VersionTLS13 {
additionalData = record[:recordHeaderLen]
} else {
additionalData = append(hc.scratchBuf[:0], hc.seq[:]...)
additionalData = append(additionalData, record[:3]...)
n := len(payload) - c.Overhead()
additionalData = append(additionalData, byte(n>>8), byte(n))
}
var err error
plaintext, err = c.Open(payload[:0], nonce, payload, additionalData)
if err != nil {
return nil, 0, alertBadRecordMAC
}
case cbcMode:
blockSize := c.BlockSize()
minPayload := explicitNonceLen + roundUp(hc.mac.Size()+1, blockSize)
if len(payload)%blockSize != 0 || len(payload) < minPayload {
return nil, 0, alertBadRecordMAC
}
if explicitNonceLen > 0 {
c.SetIV(payload[:explicitNonceLen])
payload = payload[explicitNonceLen:]
}
c.CryptBlocks(payload, payload)
// In a limited attempt to protect against CBC padding oracles like
// Lucky13, the data past paddingLen (which is secret) is passed to
// the MAC function as extra data, to be fed into the HMAC after
// computing the digest. This makes the MAC roughly constant time as
// long as the digest computation is constant time and does not
// affect the subsequent write, modulo cache effects.
paddingLen, paddingGood = extractPadding(payload)
default:
panic("unknown cipher type")
}
if hc.version == VersionTLS13 {
if typ != recordTypeApplicationData {
return nil, 0, alertUnexpectedMessage
}
if len(plaintext) > maxPlaintext+1 {
return nil, 0, alertRecordOverflow
}
// Remove padding and find the ContentType by scanning from the end.
for i := len(plaintext) - 1; i >= 0; i-- {
if plaintext[i] != 0 {
typ = recordType(plaintext[i])
plaintext = plaintext[:i]
break
}
if i == 0 {
return nil, 0, alertUnexpectedMessage
}
}
}
} else {
plaintext = payload
}
if hc.mac != nil {
macSize := hc.mac.Size()
if len(payload) < macSize {
return nil, 0, alertBadRecordMAC
}
n := len(payload) - macSize - paddingLen
n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n) // if n < 0 { n = 0 }
record[3] = byte(n >> 8)
record[4] = byte(n)
remoteMAC := payload[n : n+macSize]
localMAC := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload[:n], payload[n+macSize:])
// This is equivalent to checking the MACs and paddingGood
// separately, but in constant-time to prevent distinguishing
// padding failures from MAC failures. Depending on what value
// of paddingLen was returned on bad padding, distinguishing
// bad MAC from bad padding can lead to an attack.
//
// See also the logic at the end of extractPadding.
macAndPaddingGood := subtle.ConstantTimeCompare(localMAC, remoteMAC) & int(paddingGood)
if macAndPaddingGood != 1 {
return nil, 0, alertBadRecordMAC
}
plaintext = payload[:n]
}
hc.incSeq()
return plaintext, typ, nil
}
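// For reference, the TLS 1.3 inner plaintext handled above has the shape
// (RFC 8446, Section 5.2)
//
//	TLSInnerPlaintext := content || ContentType(1 byte) || zeros(padding)
//
// which is why the loop scans backwards for the first non-zero byte: that
// byte is the real record type and everything after it is padding.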
// sliceForAppend extends the input slice by n bytes. head is the full extended
// slice, while tail is the appended part. If the original slice has sufficient
// capacity no allocation is performed.
func sliceForAppend(in []byte, n int) (head, tail []byte) {
if total := len(in) + n; cap(in) >= total {
head = in[:total]
} else {
head = make([]byte, total)
copy(head, in)
}
tail = head[len(in):]
return
}
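// A minimal usage sketch (the names below are illustrative):
//
//	buf, hdr := sliceForAppend(buf, recordHeaderLen)
//	hdr[0] = byte(recordTypeHandshake) // writes land in buf's new tail
//
// Reusing buf across records therefore amortizes the allocation, which is
// what writeRecordLocked does with the outBufPool buffers below.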
// encrypt encrypts payload, adding the appropriate nonce and/or MAC, and
// appends it to record, which must already contain the record header.
func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) {
if hc.cipher == nil {
return append(record, payload...), nil
}
var explicitNonce []byte
if explicitNonceLen := hc.explicitNonceLen(); explicitNonceLen > 0 {
record, explicitNonce = sliceForAppend(record, explicitNonceLen)
if _, isCBC := hc.cipher.(cbcMode); !isCBC && explicitNonceLen < 16 {
// The AES-GCM construction in TLS has an explicit nonce so that the
// nonce can be random. However, the nonce is only 8 bytes which is
// too small for a secure, random nonce. Therefore we use the
// sequence number as the nonce. The 3DES-CBC construction also has
// an 8-byte nonce but its nonces must be unpredictable (see RFC
// 5246, Appendix F.3), forcing us to use randomness. That's not
// 3DES' biggest problem anyway because the birthday bound on block
// collision is reached first due to its similarly small block size
// (see the Sweet32 attack).
copy(explicitNonce, hc.seq[:])
} else {
if _, err := io.ReadFull(rand, explicitNonce); err != nil {
return nil, err
}
}
}
var dst []byte
switch c := hc.cipher.(type) {
case cipher.Stream:
mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
record, dst = sliceForAppend(record, len(payload)+len(mac))
c.XORKeyStream(dst[:len(payload)], payload)
c.XORKeyStream(dst[len(payload):], mac)
case aead:
nonce := explicitNonce
if len(nonce) == 0 {
nonce = hc.seq[:]
}
if hc.version == VersionTLS13 {
record = append(record, payload...)
// Encrypt the actual ContentType and replace the plaintext one.
record = append(record, record[0])
record[0] = byte(recordTypeApplicationData)
n := len(payload) + 1 + c.Overhead()
record[3] = byte(n >> 8)
record[4] = byte(n)
record = c.Seal(record[:recordHeaderLen],
nonce, record[recordHeaderLen:], record[:recordHeaderLen])
} else {
additionalData := append(hc.scratchBuf[:0], hc.seq[:]...)
additionalData = append(additionalData, record[:recordHeaderLen]...)
record = c.Seal(record, nonce, payload, additionalData)
}
case cbcMode:
mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
blockSize := c.BlockSize()
plaintextLen := len(payload) + len(mac)
paddingLen := blockSize - plaintextLen%blockSize
record, dst = sliceForAppend(record, plaintextLen+paddingLen)
copy(dst, payload)
copy(dst[len(payload):], mac)
for i := plaintextLen; i < len(dst); i++ {
dst[i] = byte(paddingLen - 1)
}
if len(explicitNonce) > 0 {
c.SetIV(explicitNonce)
}
c.CryptBlocks(dst, dst)
default:
panic("unknown cipher type")
}
// Update length to include nonce, MAC and any block padding needed.
n := len(record) - recordHeaderLen
record[3] = byte(n >> 8)
record[4] = byte(n)
hc.incSeq()
return record, nil
}
// RecordHeaderError is returned when a TLS record header is invalid.
type RecordHeaderError struct {
// Msg contains a human-readable string that describes the error.
Msg string
// RecordHeader contains the five bytes of TLS record header that
// triggered the error.
RecordHeader [5]byte
// Conn provides the underlying net.Conn in the case that a client
// sent an initial handshake that didn't look like TLS.
// It is nil if there's already been a handshake or a TLS alert has
// been written to the connection.
Conn net.Conn
}
func (e RecordHeaderError) Error() string { return "tls: " + e.Msg }
func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err RecordHeaderError) {
err.Msg = msg
err.Conn = conn
copy(err.RecordHeader[:], c.rawInput.Bytes())
return err
}
func (c *Conn) readRecord() error {
return c.readRecordOrCCS(false)
}
func (c *Conn) readChangeCipherSpec() error {
return c.readRecordOrCCS(true)
}
// readRecordOrCCS reads one or more TLS records from the connection and
// updates the record layer state. Some invariants:
// - c.in must be locked
// - c.input must be empty
//
// During the handshake one and only one of the following will happen:
// - c.hand grows
// - c.in.changeCipherSpec is called
// - an error is returned
//
// After the handshake one and only one of the following will happen:
// - c.hand grows
// - c.input is set
// - an error is returned
func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
if c.in.err != nil {
return c.in.err
}
handshakeComplete := c.isHandshakeComplete.Load()
// This function modifies c.rawInput, which owns the c.input memory.
if c.input.Len() != 0 {
return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with pending application data"))
}
c.input.Reset(nil)
// Read header, payload.
if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil {
// RFC 8446, Section 6.1 suggests that EOF without an alertCloseNotify
// is an error, but popular web sites seem to do this, so we accept it
// if and only if it occurs at a record boundary.
if err == io.ErrUnexpectedEOF && c.rawInput.Len() == 0 {
err = io.EOF
}
if e, ok := err.(net.Error); !ok || !e.Temporary() {
c.in.setErrorLocked(err)
}
return err
}
hdr := c.rawInput.Bytes()[:recordHeaderLen]
typ := recordType(hdr[0])
// No valid TLS record has a type of 0x80, however SSLv2 handshakes
// start with a uint16 length where the MSB is set and the first record
// is always < 256 bytes long. Therefore typ == 0x80 strongly suggests
// an SSLv2 client.
if !handshakeComplete && typ == 0x80 {
c.sendAlert(alertProtocolVersion)
return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received"))
}
vers := uint16(hdr[1])<<8 | uint16(hdr[2])
n := int(hdr[3])<<8 | int(hdr[4])
if c.haveVers && c.vers != VersionTLS13 && vers != c.vers {
c.sendAlert(alertProtocolVersion)
msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers)
return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
}
if !c.haveVers {
// First message, be extra suspicious: this might not be a TLS
// client. Bail out before reading a full 'body', if possible.
// The current max version is 3.3 so if the version is >= 16.0,
// it's probably not real.
if (typ != recordTypeAlert && typ != recordTypeHandshake) || vers >= 0x1000 {
return c.in.setErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake"))
}
}
if c.vers == VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext {
c.sendAlert(alertRecordOverflow)
msg := fmt.Sprintf("oversized record received with length %d", n)
return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
}
if err := c.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
if e, ok := err.(net.Error); !ok || !e.Temporary() {
c.in.setErrorLocked(err)
}
return err
}
// Process message.
record := c.rawInput.Next(recordHeaderLen + n)
data, typ, err := c.in.decrypt(record)
if err != nil {
return c.in.setErrorLocked(c.sendAlert(err.(alert)))
}
if len(data) > maxPlaintext {
return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow))
}
// Application Data messages are always protected.
if c.in.cipher == nil && typ == recordTypeApplicationData {
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 {
// This is a state-advancing message: reset the retry count.
c.retryCount = 0
}
// Handshake messages MUST NOT be interleaved with other record types in TLS 1.3.
if c.vers == VersionTLS13 && typ != recordTypeHandshake && c.hand.Len() > 0 {
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
switch typ {
default:
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
case recordTypeAlert:
if len(data) != 2 {
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
if alert(data[1]) == alertCloseNotify {
return c.in.setErrorLocked(io.EOF)
}
if c.vers == VersionTLS13 {
return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
}
switch data[0] {
case alertLevelWarning:
// Drop the record on the floor and retry.
return c.retryReadRecord(expectChangeCipherSpec)
case alertLevelError:
return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
default:
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
case recordTypeChangeCipherSpec:
if len(data) != 1 || data[0] != 1 {
return c.in.setErrorLocked(c.sendAlert(alertDecodeError))
}
// Handshake messages are not allowed to fragment across the CCS.
if c.hand.Len() > 0 {
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
// In TLS 1.3, change_cipher_spec records are ignored until the
// Finished. See RFC 8446, Appendix D.4. Note that according to Section
// 5, a server can send a ChangeCipherSpec before its ServerHello, when
// c.vers is still unset. That's not useful, though, and it would be
// suspicious if the server then selected a lower protocol version, so
// don't allow that.
if c.vers == VersionTLS13 {
return c.retryReadRecord(expectChangeCipherSpec)
}
if !expectChangeCipherSpec {
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
if err := c.in.changeCipherSpec(); err != nil {
return c.in.setErrorLocked(c.sendAlert(err.(alert)))
}
case recordTypeApplicationData:
if !handshakeComplete || expectChangeCipherSpec {
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
// Some OpenSSL servers send empty records in order to randomize the
// CBC IV. Ignore a limited number of empty records.
if len(data) == 0 {
return c.retryReadRecord(expectChangeCipherSpec)
}
// Note that data is owned by c.rawInput, following the Next call above,
// to avoid copying the plaintext. This is safe because c.rawInput is
// not read from or written to until c.input is drained.
c.input.Reset(data)
case recordTypeHandshake:
if len(data) == 0 || expectChangeCipherSpec {
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
c.hand.Write(data)
}
return nil
}
// retryReadRecord recurses into readRecordOrCCS to drop a non-advancing record, like
// a warning alert, empty application_data, or a change_cipher_spec in TLS 1.3.
func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error {
c.retryCount++
if c.retryCount > maxUselessRecords {
c.sendAlert(alertUnexpectedMessage)
return c.in.setErrorLocked(errors.New("tls: too many ignored records"))
}
return c.readRecordOrCCS(expectChangeCipherSpec)
}
// atLeastReader reads from R, stopping with EOF once at least N bytes have been
// read. It is different from an io.LimitedReader in that it doesn't cut short
// the last Read call, and in that it considers an early EOF an error.
type atLeastReader struct {
R io.Reader
N int64
}
func (r *atLeastReader) Read(p []byte) (int, error) {
if r.N <= 0 {
return 0, io.EOF
}
n, err := r.R.Read(p)
r.N -= int64(n) // won't underflow unless len(p) >= n > 9223372036854775809
if r.N > 0 && err == io.EOF {
return n, io.ErrUnexpectedEOF
}
if r.N <= 0 && err == nil {
return n, io.EOF
}
return n, err
}
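// An illustrative contrast with io.LimitedReader (assumed example, not part
// of this file): wrapping a 3-byte source with N = 5 surfaces the shortfall
// as an error instead of a quiet EOF.
//
//	r := &atLeastReader{strings.NewReader("abc"), 5}
//	_, err := io.ReadAll(r) // err == io.ErrUnexpectedEOF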
// readFromUntil reads from r into c.rawInput until c.rawInput contains
// at least n bytes or else returns an error.
func (c *Conn) readFromUntil(r io.Reader, n int) error {
if c.rawInput.Len() >= n {
return nil
}
needs := n - c.rawInput.Len()
// There might be extra input waiting on the wire. Make a best effort
// attempt to fetch it so that it can be used in (*Conn).Read to
// "predict" closeNotify alerts.
c.rawInput.Grow(needs + bytes.MinRead)
_, err := c.rawInput.ReadFrom(&atLeastReader{r, int64(needs)})
return err
}
// sendAlertLocked sends a TLS alert message.
func (c *Conn) sendAlertLocked(err alert) error {
switch err {
case alertNoRenegotiation, alertCloseNotify:
c.tmp[0] = alertLevelWarning
default:
c.tmp[0] = alertLevelError
}
c.tmp[1] = byte(err)
_, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2])
if err == alertCloseNotify {
// closeNotify is a special case in that it isn't an error.
return writeErr
}
return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
}
// sendAlert sends a TLS alert message.
func (c *Conn) sendAlert(err alert) error {
c.out.Lock()
defer c.out.Unlock()
return c.sendAlertLocked(err)
}
const (
// tcpMSSEstimate is a conservative estimate of the TCP maximum segment
// size (MSS). A constant is used, rather than querying the kernel for
// the actual MSS, to avoid complexity. The value here is the IPv6
// minimum MTU (1280 bytes) minus the overhead of an IPv6 header (40
// bytes) and a TCP header with timestamps (32 bytes).
tcpMSSEstimate = 1208
// recordSizeBoostThreshold is the number of bytes of application data
// sent after which the TLS record size will be increased to the
// maximum.
recordSizeBoostThreshold = 128 * 1024
)
// maxPayloadSizeForWrite returns the maximum TLS payload size to use for the
// next application data record. There is the following trade-off:
//
// - For latency-sensitive applications, such as web browsing, each TLS
// record should fit in one TCP segment.
// - For throughput-sensitive applications, such as large file transfers,
// larger TLS records better amortize framing and encryption overheads.
//
// A simple heuristic that works well in practice is to use small records for
// the first 1MB of data, then use larger records for subsequent data, and
// reset back to smaller records after the connection becomes idle. See "High
// Performance Web Networking", Chapter 4, or:
// https://www.igvita.com/2013/10/24/optimizing-tls-record-size-and-buffering-latency/
//
// In the interests of simplicity and determinism, this code does not attempt
// to reset the record size once the connection is idle, however.
func (c *Conn) maxPayloadSizeForWrite(typ recordType) int {
if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData {
return maxPlaintext
}
if c.bytesSent >= recordSizeBoostThreshold {
return maxPlaintext
}
// Subtract TLS overheads to get the maximum payload size.
payloadBytes := tcpMSSEstimate - recordHeaderLen - c.out.explicitNonceLen()
if c.out.cipher != nil {
switch ciph := c.out.cipher.(type) {
case cipher.Stream:
payloadBytes -= c.out.mac.Size()
case cipher.AEAD:
payloadBytes -= ciph.Overhead()
case cbcMode:
blockSize := ciph.BlockSize()
// The payload must fit in a multiple of blockSize, with
// room for at least one padding byte.
payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1
// The MAC is appended before padding so affects the
// payload size directly.
payloadBytes -= c.out.mac.Size()
default:
panic("unknown cipher type")
}
}
if c.vers == VersionTLS13 {
payloadBytes-- // encrypted ContentType
}
// Allow packet growth in arithmetic progression up to max.
pkt := c.packetsSent
c.packetsSent++
if pkt > 1000 {
return maxPlaintext // avoid overflow in multiply below
}
n := payloadBytes * int(pkt+1)
if n > maxPlaintext {
n = maxPlaintext
}
return n
}
func (c *Conn) write(data []byte) (int, error) {
if c.buffering {
c.sendBuf = append(c.sendBuf, data...)
return len(data), nil
}
n, err := c.conn.Write(data)
c.bytesSent += int64(n)
return n, err
}
func (c *Conn) flush() (int, error) {
if len(c.sendBuf) == 0 {
return 0, nil
}
n, err := c.conn.Write(c.sendBuf)
c.bytesSent += int64(n)
c.sendBuf = nil
c.buffering = false
return n, err
}
// outBufPool pools the record-sized scratch buffers used by writeRecordLocked.
var outBufPool = sync.Pool{
New: func() any {
return new([]byte)
},
}
// writeRecordLocked writes a TLS record with the given type and payload to the
// connection and updates the record layer state.
func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
outBufPtr := outBufPool.Get().(*[]byte)
outBuf := *outBufPtr
defer func() {
// You might be tempted to simplify this by just passing &outBuf to Put,
// but that would make the local copy of the outBuf slice header escape
// to the heap, causing an allocation. Instead, we keep around the
// pointer to the slice header returned by Get, which is already on the
// heap, and overwrite and return that.
*outBufPtr = outBuf
outBufPool.Put(outBufPtr)
}()
var n int
for len(data) > 0 {
m := len(data)
if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload {
m = maxPayload
}
_, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen)
outBuf[0] = byte(typ)
vers := c.vers
if vers == 0 {
// Some TLS servers fail if the record version is
// greater than TLS 1.0 for the initial ClientHello.
vers = VersionTLS10
} else if vers == VersionTLS13 {
// TLS 1.3 froze the record layer version to 1.2.
// See RFC 8446, Section 5.1.
vers = VersionTLS12
}
outBuf[1] = byte(vers >> 8)
outBuf[2] = byte(vers)
outBuf[3] = byte(m >> 8)
outBuf[4] = byte(m)
var err error
outBuf, err = c.out.encrypt(outBuf, data[:m], c.config.rand())
if err != nil {
return n, err
}
if _, err := c.write(outBuf); err != nil {
return n, err
}
n += m
data = data[m:]
}
if typ == recordTypeChangeCipherSpec && c.vers != VersionTLS13 {
if err := c.out.changeCipherSpec(); err != nil {
return n, c.sendAlertLocked(err.(alert))
}
}
return n, nil
}
// writeHandshakeRecord writes a handshake message to the connection and updates
// the record layer state. If transcript is non-nil the marshalled message is
// written to it.
func (c *Conn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash) (int, error) {
c.out.Lock()
defer c.out.Unlock()
data, err := msg.marshal()
if err != nil {
return 0, err
}
if transcript != nil {
transcript.Write(data)
}
return c.writeRecordLocked(recordTypeHandshake, data)
}
// writeChangeCipherRecord writes a ChangeCipherSpec message to the connection and
// updates the record layer state.
func (c *Conn) writeChangeCipherRecord() error {
c.out.Lock()
defer c.out.Unlock()
_, err := c.writeRecordLocked(recordTypeChangeCipherSpec, []byte{1})
return err
}
// readHandshake reads the next handshake message from
// the record layer. If transcript is non-nil, the message
// is written to the passed transcriptHash.
func (c *Conn) readHandshake(transcript transcriptHash) (any, error) {
for c.hand.Len() < 4 {
if err := c.readRecord(); err != nil {
return nil, err
}
}
data := c.hand.Bytes()
n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
if n > maxHandshake {
c.sendAlertLocked(alertInternalError)
return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake))
}
for c.hand.Len() < 4+n {
if err := c.readRecord(); err != nil {
return nil, err
}
}
data = c.hand.Next(4 + n)
var m handshakeMessage
switch data[0] {
case typeHelloRequest:
m = new(helloRequestMsg)
case typeClientHello:
m = new(clientHelloMsg)
case typeServerHello:
m = new(serverHelloMsg)
case typeNewSessionTicket:
if c.vers == VersionTLS13 {
m = new(newSessionTicketMsgTLS13)
} else {
m = new(newSessionTicketMsg)
}
case typeCertificate:
if c.vers == VersionTLS13 {
m = new(certificateMsgTLS13)
} else {
m = new(certificateMsg)
}
case typeCertificateRequest:
if c.vers == VersionTLS13 {
m = new(certificateRequestMsgTLS13)
} else {
m = &certificateRequestMsg{
hasSignatureAlgorithm: c.vers >= VersionTLS12,
}
}
case typeCertificateStatus:
m = new(certificateStatusMsg)
case typeServerKeyExchange:
m = new(serverKeyExchangeMsg)
case typeServerHelloDone:
m = new(serverHelloDoneMsg)
case typeClientKeyExchange:
m = new(clientKeyExchangeMsg)
case typeCertificateVerify:
m = &certificateVerifyMsg{
hasSignatureAlgorithm: c.vers >= VersionTLS12,
}
case typeFinished:
m = new(finishedMsg)
case typeEncryptedExtensions:
m = new(encryptedExtensionsMsg)
case typeEndOfEarlyData:
m = new(endOfEarlyDataMsg)
case typeKeyUpdate:
m = new(keyUpdateMsg)
default:
return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
// The handshake message unmarshalers
// expect to be able to keep references to data,
// so pass in a fresh copy that won't be overwritten.
data = append([]byte(nil), data...)
if !m.unmarshal(data) {
return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
if transcript != nil {
transcript.Write(data)
}
return m, nil
}
var (
errShutdown = errors.New("tls: protocol is shutdown")
)
// Write writes data to the connection.
//
// As Write calls Handshake, in order to prevent indefinite blocking a deadline
// must be set for both Read and Write before Write is called when the handshake
// has not yet completed. See SetDeadline, SetReadDeadline, and
// SetWriteDeadline.
func (c *Conn) Write(b []byte) (int, error) {
// interlock with Close below
for {
x := c.activeCall.Load()
if x&1 != 0 {
return 0, net.ErrClosed
}
if c.activeCall.CompareAndSwap(x, x+2) {
break
}
}
defer c.activeCall.Add(-2)
if err := c.Handshake(); err != nil {
return 0, err
}
c.out.Lock()
defer c.out.Unlock()
if err := c.out.err; err != nil {
return 0, err
}
if !c.isHandshakeComplete.Load() {
return 0, alertInternalError
}
if c.closeNotifySent {
return 0, errShutdown
}
// TLS 1.0 is susceptible to a chosen-plaintext
// attack when using block mode ciphers due to predictable IVs.
// This can be prevented by splitting each Application Data
// record into two records, effectively randomizing the IV.
//
// https://www.openssl.org/~bodo/tls-cbc.txt
// https://bugzilla.mozilla.org/show_bug.cgi?id=665814
// https://www.imperialviolet.org/2012/01/15/beastfollowup.html
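// The split below is the well-known "1/n-1" mitigation: the first byte of
// application data goes out in its own record, so the CBC IV of the record
// carrying the remaining n-1 bytes is no longer predictable.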
var m int
if len(b) > 1 && c.vers == VersionTLS10 {
if _, ok := c.out.cipher.(cipher.BlockMode); ok {
n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1])
if err != nil {
return n, c.out.setErrorLocked(err)
}
m, b = 1, b[1:]
}
}
n, err := c.writeRecordLocked(recordTypeApplicationData, b)
return n + m, c.out.setErrorLocked(err)
}
// handleRenegotiation processes a HelloRequest handshake message.
func (c *Conn) handleRenegotiation() error {
if c.vers == VersionTLS13 {
return errors.New("tls: internal error: unexpected renegotiation")
}
msg, err := c.readHandshake(nil)
if err != nil {
return err
}
helloReq, ok := msg.(*helloRequestMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(helloReq, msg)
}
if !c.isClient {
return c.sendAlert(alertNoRenegotiation)
}
switch c.config.Renegotiation {
case RenegotiateNever:
return c.sendAlert(alertNoRenegotiation)
case RenegotiateOnceAsClient:
if c.handshakes > 1 {
return c.sendAlert(alertNoRenegotiation)
}
case RenegotiateFreelyAsClient:
// Ok.
default:
c.sendAlert(alertInternalError)
return errors.New("tls: unknown Renegotiation value")
}
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
c.isHandshakeComplete.Store(false)
if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil {
c.handshakes++
}
return c.handshakeErr
}
// handlePostHandshakeMessage processes a handshake message that arrived
// after the handshake is complete. Up to TLS 1.2, it indicates the start of
// a renegotiation.
func (c *Conn) handlePostHandshakeMessage() error {
if c.vers != VersionTLS13 {
return c.handleRenegotiation()
}
msg, err := c.readHandshake(nil)
if err != nil {
return err
}
c.retryCount++
if c.retryCount > maxUselessRecords {
c.sendAlert(alertUnexpectedMessage)
return c.in.setErrorLocked(errors.New("tls: too many non-advancing records"))
}
switch msg := msg.(type) {
case *newSessionTicketMsgTLS13:
return c.handleNewSessionTicket(msg)
case *keyUpdateMsg:
return c.handleKeyUpdate(msg)
default:
c.sendAlert(alertUnexpectedMessage)
return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
}
}
func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
if cipherSuite == nil {
return c.in.setErrorLocked(c.sendAlert(alertInternalError))
}
newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret)
c.in.setTrafficSecret(cipherSuite, newSecret)
if keyUpdate.updateRequested {
c.out.Lock()
defer c.out.Unlock()
msg := &keyUpdateMsg{}
msgBytes, err := msg.marshal()
if err != nil {
return err
}
_, err = c.writeRecordLocked(recordTypeHandshake, msgBytes)
if err != nil {
// Surface the error at the next write.
c.out.setErrorLocked(err)
return nil
}
newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret)
c.out.setTrafficSecret(cipherSuite, newSecret)
}
return nil
}
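// Each rotation applies RFC 8446, Section 7.2:
//
//	traffic_secret_N+1 = HKDF-Expand-Label(traffic_secret_N, "traffic upd", "", Hash.length)
//
// so the inbound and outbound directions ratchet independently, once per
// KeyUpdate received or (when updateRequested) sent.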
// Read reads data from the connection.
//
// As Read calls Handshake, in order to prevent indefinite blocking a deadline
// must be set for both Read and Write before Read is called when the handshake
// has not yet completed. See SetDeadline, SetReadDeadline, and
// SetWriteDeadline.
func (c *Conn) Read(b []byte) (int, error) {
if err := c.Handshake(); err != nil {
return 0, err
}
if len(b) == 0 {
// Put this after Handshake, in case people were calling
// Read(nil) for the side effect of the Handshake.
return 0, nil
}
c.in.Lock()
defer c.in.Unlock()
for c.input.Len() == 0 {
if err := c.readRecord(); err != nil {
return 0, err
}
for c.hand.Len() > 0 {
if err := c.handlePostHandshakeMessage(); err != nil {
return 0, err
}
}
}
n, _ := c.input.Read(b)
// If a close-notify alert is waiting, read it so that we can return (n,
// EOF) instead of (n, nil), to signal to the HTTP response reading
// goroutine that the connection is now closed. This eliminates a race
// where the HTTP response reading goroutine would otherwise not observe
// the EOF until its next read, by which time a client goroutine might
// have already tried to reuse the HTTP connection for a new request.
// See https://golang.org/cl/76400046 and https://golang.org/issue/3514
if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
if err := c.readRecord(); err != nil {
return n, err // will be io.EOF on closeNotify
}
}
return n, nil
}
// Close closes the connection.
func (c *Conn) Close() error {
// Interlock with Conn.Write above.
var x int32
for {
x = c.activeCall.Load()
if x&1 != 0 {
return net.ErrClosed
}
if c.activeCall.CompareAndSwap(x, x|1) {
break
}
}
if x != 0 {
// io.Writer and io.Closer should not be used concurrently.
// If Close is called while a Write is currently in-flight,
// interpret that as a sign that this Close is really just
// being used to break the Write and/or clean up resources and
// avoid sending the alertCloseNotify, which may block
// waiting on handshakeMutex or the c.out mutex.
return c.conn.Close()
}
var alertErr error
if c.isHandshakeComplete.Load() {
if err := c.closeNotify(); err != nil {
alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err)
}
}
if err := c.conn.Close(); err != nil {
return err
}
return alertErr
}
var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete")
// CloseWrite shuts down the writing side of the connection. It should only be
// called once the handshake has completed and does not call CloseWrite on the
// underlying connection. Most callers should just use Close.
func (c *Conn) CloseWrite() error {
if !c.isHandshakeComplete.Load() {
return errEarlyCloseWrite
}
return c.closeNotify()
}
func (c *Conn) closeNotify() error {
c.out.Lock()
defer c.out.Unlock()
if !c.closeNotifySent {
// Set a Write Deadline to prevent possibly blocking forever.
c.SetWriteDeadline(time.Now().Add(time.Second * 5))
c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify)
c.closeNotifySent = true
// Any subsequent writes will fail.
c.SetWriteDeadline(time.Now())
}
return c.closeNotifyErr
}
// Handshake runs the client or server handshake
// protocol if it has not yet been run.
//
// Most uses of this package need not call Handshake explicitly: the
// first Read or Write will call it automatically.
//
// For control over canceling or setting a timeout on a handshake, use
// HandshakeContext or the Dialer's DialContext method instead.
func (c *Conn) Handshake() error {
return c.HandshakeContext(context.Background())
}
// HandshakeContext runs the client or server handshake
// protocol if it has not yet been run.
//
// The provided Context must be non-nil. If the context is canceled before
// the handshake is complete, the handshake is interrupted and an error is returned.
// Once the handshake has completed, cancellation of the context will not affect the
// connection.
//
// Most uses of this package need not call HandshakeContext explicitly: the
// first Read or Write will call it automatically.
func (c *Conn) HandshakeContext(ctx context.Context) error {
// Delegate to an unexported method for the named return value
// without confusing the documented signature.
return c.handshakeContext(ctx)
}
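// A typical use, bounding the handshake without affecting the established
// connection (illustrative sketch; the timeout value is an assumption):
//
//	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
//	defer cancel()
//	if err := conn.HandshakeContext(ctx); err != nil {
//		// handshake failed or timed out; conn is unusable
//	}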
func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
// Fast sync/atomic-based exit if there is no handshake in flight and the
// last one succeeded without an error. Avoids the expensive context setup
// and mutex for most Read and Write calls.
if c.isHandshakeComplete.Load() {
return nil
}
handshakeCtx, cancel := context.WithCancel(ctx)
// Note: defer this before starting the "interrupter" goroutine
// so that we can tell the difference between the input being canceled and
// this cancellation. In the former case, we need to close the connection.
defer cancel()
// Start the "interrupter" goroutine, if this context might be canceled.
// (The background context cannot).
//
// The interrupter goroutine waits for the input context to be done and
// closes the connection if this happens before the function returns.
if ctx.Done() != nil {
done := make(chan struct{})
interruptRes := make(chan error, 1)
defer func() {
close(done)
if ctxErr := <-interruptRes; ctxErr != nil {
// Return context error to user.
ret = ctxErr
}
}()
go func() {
select {
case <-handshakeCtx.Done():
// Close the connection, discarding the error
_ = c.conn.Close()
interruptRes <- handshakeCtx.Err()
case <-done:
interruptRes <- nil
}
}()
}
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
if err := c.handshakeErr; err != nil {
return err
}
if c.isHandshakeComplete.Load() {
return nil
}
c.in.Lock()
defer c.in.Unlock()
c.handshakeErr = c.handshakeFn(handshakeCtx)
if c.handshakeErr == nil {
c.handshakes++
} else {
// If an error occurred during the handshake try to flush the
// alert that might be left in the buffer.
c.flush()
}
if c.handshakeErr == nil && !c.isHandshakeComplete.Load() {
c.handshakeErr = errors.New("tls: internal error: handshake should have had a result")
}
if c.handshakeErr != nil && c.isHandshakeComplete.Load() {
panic("tls: internal error: handshake returned an error but is marked successful")
}
return c.handshakeErr
}
// ConnectionState returns basic TLS details about the connection.
func (c *Conn) ConnectionState() ConnectionState {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
return c.connectionStateLocked()
}
func (c *Conn) connectionStateLocked() ConnectionState {
var state ConnectionState
state.HandshakeComplete = c.isHandshakeComplete.Load()
state.Version = c.vers
state.NegotiatedProtocol = c.clientProtocol
state.DidResume = c.didResume
state.NegotiatedProtocolIsMutual = true
state.ServerName = c.serverName
state.CipherSuite = c.cipherSuite
state.PeerCertificates = c.peerCertificates
state.VerifiedChains = c.verifiedChains
state.SignedCertificateTimestamps = c.scts
state.OCSPResponse = c.ocspResponse
if !c.didResume && c.vers != VersionTLS13 {
if c.clientFinishedIsFirst {
state.TLSUnique = c.clientFinished[:]
} else {
state.TLSUnique = c.serverFinished[:]
}
}
if c.config.Renegotiation != RenegotiateNever {
state.ekm = noExportedKeyingMaterial
} else {
state.ekm = c.ekm
}
return state
}
// OCSPResponse returns the stapled OCSP response from the TLS server, if
// any. (Only valid for client connections.)
func (c *Conn) OCSPResponse() []byte {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
return c.ocspResponse
}
// VerifyHostname checks that the peer certificate chain is valid for
// connecting to host. If so, it returns nil; if not, it returns an error
// describing the problem.
func (c *Conn) VerifyHostname(host string) error {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
if !c.isClient {
return errors.New("tls: VerifyHostname called on TLS server connection")
}
if !c.isHandshakeComplete.Load() {
return errors.New("tls: handshake has not yet been performed")
}
if len(c.verifiedChains) == 0 {
return errors.New("tls: handshake did not verify certificate chain")
}
return c.peerCertificates[0].VerifyHostname(host)
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"bytes"
"context"
"crypto"
"crypto/ecdh"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/subtle"
"crypto/x509"
"errors"
"fmt"
"hash"
"io"
"net"
"strings"
"time"
)
type clientHandshakeState struct {
c *Conn
ctx context.Context
serverHello *serverHelloMsg
hello *clientHelloMsg
suite *cipherSuite
finishedHash finishedHash
masterSecret []byte
session *ClientSessionState
}
var testingOnlyForceClientHelloSignatureAlgorithms []SignatureScheme
func (c *Conn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) {
config := c.config
if len(config.ServerName) == 0 && !config.InsecureSkipVerify {
return nil, nil, errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config")
}
nextProtosLength := 0
for _, proto := range config.NextProtos {
if l := len(proto); l == 0 || l > 255 {
return nil, nil, errors.New("tls: invalid NextProtos value")
} else {
nextProtosLength += 1 + l
}
}
if nextProtosLength > 0xffff {
return nil, nil, errors.New("tls: NextProtos values too large")
}
supportedVersions := config.supportedVersions(roleClient)
if len(supportedVersions) == 0 {
return nil, nil, errors.New("tls: no supported versions satisfy MinVersion and MaxVersion")
}
clientHelloVersion := config.maxSupportedVersion(roleClient)
// The version at the beginning of the ClientHello was capped at TLS 1.2
// for compatibility reasons. The supported_versions extension is used
// to negotiate versions now. See RFC 8446, Section 4.2.1.
if clientHelloVersion > VersionTLS12 {
clientHelloVersion = VersionTLS12
}
hello := &clientHelloMsg{
vers: clientHelloVersion,
compressionMethods: []uint8{compressionNone},
random: make([]byte, 32),
sessionId: make([]byte, 32),
ocspStapling: true,
scts: true,
serverName: hostnameInSNI(config.ServerName),
supportedCurves: config.curvePreferences(),
supportedPoints: []uint8{pointFormatUncompressed},
secureRenegotiationSupported: true,
alpnProtocols: config.NextProtos,
supportedVersions: supportedVersions,
}
if c.handshakes > 0 {
hello.secureRenegotiation = c.clientFinished[:]
}
preferenceOrder := cipherSuitesPreferenceOrder
if !hasAESGCMHardwareSupport {
preferenceOrder = cipherSuitesPreferenceOrderNoAES
}
configCipherSuites := config.cipherSuites()
hello.cipherSuites = make([]uint16, 0, len(configCipherSuites))
for _, suiteId := range preferenceOrder {
suite := mutualCipherSuite(configCipherSuites, suiteId)
if suite == nil {
continue
}
// Don't advertise TLS 1.2-only cipher suites unless
// we're attempting TLS 1.2.
if hello.vers < VersionTLS12 && suite.flags&suiteTLS12 != 0 {
continue
}
hello.cipherSuites = append(hello.cipherSuites, suiteId)
}
_, err := io.ReadFull(config.rand(), hello.random)
if err != nil {
return nil, nil, errors.New("tls: short read from Rand: " + err.Error())
}
// A random session ID is used to detect when the server accepted a ticket
// and is resuming a session (see RFC 5077). In TLS 1.3, it's always set as
// a compatibility measure (see RFC 8446, Section 4.1.2).
if _, err := io.ReadFull(config.rand(), hello.sessionId); err != nil {
return nil, nil, errors.New("tls: short read from Rand: " + err.Error())
}
if hello.vers >= VersionTLS12 {
hello.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
}
if testingOnlyForceClientHelloSignatureAlgorithms != nil {
hello.supportedSignatureAlgorithms = testingOnlyForceClientHelloSignatureAlgorithms
}
var key *ecdh.PrivateKey
if hello.supportedVersions[0] == VersionTLS13 {
if hasAESGCMHardwareSupport {
hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13...)
} else {
hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13NoAES...)
}
curveID := config.curvePreferences()[0]
if _, ok := curveForCurveID(curveID); !ok {
return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve")
}
key, err = generateECDHEKey(config.rand(), curveID)
if err != nil {
return nil, nil, err
}
hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}}
}
return hello, key, nil
}
func (c *Conn) clientHandshake(ctx context.Context) (err error) {
if c.config == nil {
c.config = defaultConfig()
}
// This may be a renegotiation handshake, in which case some fields
// need to be reset.
c.didResume = false
hello, ecdheKey, err := c.makeClientHello()
if err != nil {
return err
}
c.serverName = hello.serverName
cacheKey, session, earlySecret, binderKey, err := c.loadSession(hello)
if err != nil {
return err
}
if cacheKey != "" && session != nil {
defer func() {
// If we got a handshake failure when resuming a session, throw away
// the session ticket. See RFC 5077, Section 3.2.
//
// RFC 8446 makes no mention of dropping tickets on failure, but it
// does require servers to abort on invalid binders, so we need to
// delete tickets to recover from a corrupted PSK.
if err != nil {
c.config.ClientSessionCache.Put(cacheKey, nil)
}
}()
}
if _, err := c.writeHandshakeRecord(hello, nil); err != nil {
return err
}
// serverHelloMsg is not included in the transcript
msg, err := c.readHandshake(nil)
if err != nil {
return err
}
serverHello, ok := msg.(*serverHelloMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(serverHello, msg)
}
if err := c.pickTLSVersion(serverHello); err != nil {
return err
}
// If we are negotiating a protocol version that's lower than what we
// support, check for the server downgrade canaries.
// See RFC 8446, Section 4.1.3.
maxVers := c.config.maxSupportedVersion(roleClient)
tls12Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS12
tls11Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS11
if maxVers == VersionTLS13 && c.vers <= VersionTLS12 && (tls12Downgrade || tls11Downgrade) ||
maxVers == VersionTLS12 && c.vers <= VersionTLS11 && tls11Downgrade {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: downgrade attempt detected, possibly due to a MitM attack or a broken middlebox")
}
if c.vers == VersionTLS13 {
hs := &clientHandshakeStateTLS13{
c: c,
ctx: ctx,
serverHello: serverHello,
hello: hello,
ecdheKey: ecdheKey,
session: session,
earlySecret: earlySecret,
binderKey: binderKey,
}
// In TLS 1.3, session tickets are delivered after the handshake.
return hs.handshake()
}
hs := &clientHandshakeState{
c: c,
ctx: ctx,
serverHello: serverHello,
hello: hello,
session: session,
}
if err := hs.handshake(); err != nil {
return err
}
// If we had a successful handshake and hs.session differs from the one
// already cached, cache the new one.
if cacheKey != "" && hs.session != nil && session != hs.session {
c.config.ClientSessionCache.Put(cacheKey, hs.session)
}
return nil
}
func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string,
session *ClientSessionState, earlySecret, binderKey []byte, err error) {
if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil {
return "", nil, nil, nil, nil
}
hello.ticketSupported = true
if hello.supportedVersions[0] == VersionTLS13 {
// Require DHE on resumption as it guarantees forward secrecy against
// compromise of the session ticket key. See RFC 8446, Section 4.2.9.
hello.pskModes = []uint8{pskModeDHE}
}
// Session resumption is not allowed if renegotiating because
// renegotiation is primarily used to allow a client to send a client
// certificate, which would be skipped if session resumption occurred.
if c.handshakes != 0 {
return "", nil, nil, nil, nil
}
// Try to resume a previously negotiated TLS session, if available.
cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config)
session, ok := c.config.ClientSessionCache.Get(cacheKey)
if !ok || session == nil {
return cacheKey, nil, nil, nil, nil
}
// Check that version used for the previous session is still valid.
versOk := false
for _, v := range hello.supportedVersions {
if v == session.vers {
versOk = true
break
}
}
if !versOk {
return cacheKey, nil, nil, nil, nil
}
// Check that the cached server certificate is not expired, and that it's
// valid for the ServerName. This should be ensured by the cache key, but
// protect the application from a faulty ClientSessionCache implementation.
if !c.config.InsecureSkipVerify {
if len(session.verifiedChains) == 0 {
// The original connection had InsecureSkipVerify, while this doesn't.
return cacheKey, nil, nil, nil, nil
}
serverCert := session.serverCertificates[0]
if c.config.time().After(serverCert.NotAfter) {
// Expired certificate, delete the entry.
c.config.ClientSessionCache.Put(cacheKey, nil)
return cacheKey, nil, nil, nil, nil
}
if err := serverCert.VerifyHostname(c.config.ServerName); err != nil {
return cacheKey, nil, nil, nil, nil
}
}
if session.vers != VersionTLS13 {
// In TLS 1.2 the cipher suite must match the resumed session. Ensure we
// are still offering it.
if mutualCipherSuite(hello.cipherSuites, session.cipherSuite) == nil {
return cacheKey, nil, nil, nil, nil
}
hello.sessionTicket = session.sessionTicket
return
}
// Check that the session ticket is not expired.
if c.config.time().After(session.useBy) {
c.config.ClientSessionCache.Put(cacheKey, nil)
return cacheKey, nil, nil, nil, nil
}
// In TLS 1.3 the KDF hash must match the resumed session. Ensure we
// offer at least one cipher suite with that hash.
cipherSuite := cipherSuiteTLS13ByID(session.cipherSuite)
if cipherSuite == nil {
return cacheKey, nil, nil, nil, nil
}
cipherSuiteOk := false
for _, offeredID := range hello.cipherSuites {
offeredSuite := cipherSuiteTLS13ByID(offeredID)
if offeredSuite != nil && offeredSuite.hash == cipherSuite.hash {
cipherSuiteOk = true
break
}
}
if !cipherSuiteOk {
return cacheKey, nil, nil, nil, nil
}
// Set the pre_shared_key extension. See RFC 8446, Section 4.2.11.1.
ticketAge := uint32(c.config.time().Sub(session.receivedAt) / time.Millisecond)
identity := pskIdentity{
label: session.sessionTicket,
obfuscatedTicketAge: ticketAge + session.ageAdd,
}
hello.pskIdentities = []pskIdentity{identity}
hello.pskBinders = [][]byte{make([]byte, cipherSuite.hash.Size())}
// Compute the PSK binders. See RFC 8446, Section 4.2.11.2.
psk := cipherSuite.expandLabel(session.masterSecret, "resumption",
session.nonce, cipherSuite.hash.Size())
earlySecret = cipherSuite.extract(psk, nil)
binderKey = cipherSuite.deriveSecret(earlySecret, resumptionBinderLabel, nil)
transcript := cipherSuite.hash.New()
helloBytes, err := hello.marshalWithoutBinders()
if err != nil {
return "", nil, nil, nil, err
}
transcript.Write(helloBytes)
pskBinders := [][]byte{cipherSuite.finishedHash(binderKey, transcript)}
if err := hello.updateBinders(pskBinders); err != nil {
return "", nil, nil, nil, err
}
return
}
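// For reference, the binder computed above follows RFC 8446, Section
// 4.2.11.2:
//
//	binder = HMAC(finished_key, Transcript-Hash(Truncate(ClientHello)))
//
// where Truncate removes the binders themselves, which is why the transcript
// hashes the output of marshalWithoutBinders before updateBinders fills in
// the real values.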
func (c *Conn) pickTLSVersion(serverHello *serverHelloMsg) error {
peerVersion := serverHello.vers
if serverHello.supportedVersion != 0 {
peerVersion = serverHello.supportedVersion
}
vers, ok := c.config.mutualVersion(roleClient, []uint16{peerVersion})
if !ok {
c.sendAlert(alertProtocolVersion)
return fmt.Errorf("tls: server selected unsupported protocol version %x", peerVersion)
}
c.vers = vers
c.haveVers = true
c.in.version = vers
c.out.version = vers
return nil
}
// handshake performs the handshake, either a full one or a resumption of an
// old session. It requires hs.c, hs.hello, hs.serverHello, and, optionally,
// hs.session to be set.
func (hs *clientHandshakeState) handshake() error {
c := hs.c
isResume, err := hs.processServerHello()
if err != nil {
return err
}
hs.finishedHash = newFinishedHash(c.vers, hs.suite)
// No signatures of the handshake are needed in a resumption.
// Otherwise, in a full handshake, if we don't have any certificates
// configured then we will never send a CertificateVerify message and
// thus no signatures are needed in that case either.
if isResume || (len(c.config.Certificates) == 0 && c.config.GetClientCertificate == nil) {
hs.finishedHash.discardHandshakeBuffer()
}
if err := transcriptMsg(hs.hello, &hs.finishedHash); err != nil {
return err
}
if err := transcriptMsg(hs.serverHello, &hs.finishedHash); err != nil {
return err
}
c.buffering = true
c.didResume = isResume
if isResume {
if err := hs.establishKeys(); err != nil {
return err
}
if err := hs.readSessionTicket(); err != nil {
return err
}
if err := hs.readFinished(c.serverFinished[:]); err != nil {
return err
}
c.clientFinishedIsFirst = false
// Make sure the connection is still being verified whether or not this
// is a resumption. Resumptions currently don't reverify certificates so
// they don't call verifyServerCertificate. See Issue 31641.
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
if err := hs.sendFinished(c.clientFinished[:]); err != nil {
return err
}
if _, err := c.flush(); err != nil {
return err
}
} else {
if err := hs.doFullHandshake(); err != nil {
return err
}
if err := hs.establishKeys(); err != nil {
return err
}
if err := hs.sendFinished(c.clientFinished[:]); err != nil {
return err
}
if _, err := c.flush(); err != nil {
return err
}
c.clientFinishedIsFirst = true
if err := hs.readSessionTicket(); err != nil {
return err
}
if err := hs.readFinished(c.serverFinished[:]); err != nil {
return err
}
}
c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random)
c.isHandshakeComplete.Store(true)
return nil
}
func (hs *clientHandshakeState) pickCipherSuite() error {
if hs.suite = mutualCipherSuite(hs.hello.cipherSuites, hs.serverHello.cipherSuite); hs.suite == nil {
hs.c.sendAlert(alertHandshakeFailure)
return errors.New("tls: server chose an unconfigured cipher suite")
}
hs.c.cipherSuite = hs.suite.id
return nil
}
func (hs *clientHandshakeState) doFullHandshake() error {
c := hs.c
msg, err := c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
certMsg, ok := msg.(*certificateMsg)
if !ok || len(certMsg.certificates) == 0 {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certMsg, msg)
}
msg, err = c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
cs, ok := msg.(*certificateStatusMsg)
if ok {
// RFC 4366 on Certificate Status Request:
// The server MAY return a "certificate_status" message.
if !hs.serverHello.ocspStapling {
// If a server returns a "CertificateStatus" message, then the
// server MUST have included an extension of type "status_request"
// with empty "extension_data" in the extended server hello.
c.sendAlert(alertUnexpectedMessage)
return errors.New("tls: received unexpected CertificateStatus message")
}
c.ocspResponse = cs.response
msg, err = c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
}
if c.handshakes == 0 {
// If this is the first handshake on a connection, process and
// (optionally) verify the server's certificates.
if err := c.verifyServerCertificate(certMsg.certificates); err != nil {
return err
}
} else {
// This is a renegotiation handshake. We require that the
// server's identity (i.e. leaf certificate) is unchanged and
// thus any previous trust decision is still valid.
//
// See https://mitls.org/pages/attacks/3SHAKE for the
// motivation behind this requirement.
if !bytes.Equal(c.peerCertificates[0].Raw, certMsg.certificates[0]) {
c.sendAlert(alertBadCertificate)
return errors.New("tls: server's identity changed during renegotiation")
}
}
keyAgreement := hs.suite.ka(c.vers)
skx, ok := msg.(*serverKeyExchangeMsg)
if ok {
err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, c.peerCertificates[0], skx)
if err != nil {
c.sendAlert(alertUnexpectedMessage)
return err
}
msg, err = c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
}
var chainToSend *Certificate
var certRequested bool
certReq, ok := msg.(*certificateRequestMsg)
if ok {
certRequested = true
cri := certificateRequestInfoFromMsg(hs.ctx, c.vers, certReq)
if chainToSend, err = c.getClientCertificate(cri); err != nil {
c.sendAlert(alertInternalError)
return err
}
msg, err = c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
}
shd, ok := msg.(*serverHelloDoneMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(shd, msg)
}
// If the server requested a certificate then we have to send a
// Certificate message, even if it's empty because we don't have a
// certificate to send.
if certRequested {
certMsg = new(certificateMsg)
certMsg.certificates = chainToSend.Certificate
if _, err := hs.c.writeHandshakeRecord(certMsg, &hs.finishedHash); err != nil {
return err
}
}
preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, c.peerCertificates[0])
if err != nil {
c.sendAlert(alertInternalError)
return err
}
if ckx != nil {
if _, err := hs.c.writeHandshakeRecord(ckx, &hs.finishedHash); err != nil {
return err
}
}
if chainToSend != nil && len(chainToSend.Certificate) > 0 {
certVerify := &certificateVerifyMsg{}
key, ok := chainToSend.PrivateKey.(crypto.Signer)
if !ok {
c.sendAlert(alertInternalError)
return fmt.Errorf("tls: client certificate private key of type %T does not implement crypto.Signer", chainToSend.PrivateKey)
}
var sigType uint8
var sigHash crypto.Hash
if c.vers >= VersionTLS12 {
signatureAlgorithm, err := selectSignatureScheme(c.vers, chainToSend, certReq.supportedSignatureAlgorithms)
if err != nil {
c.sendAlert(alertIllegalParameter)
return err
}
sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm)
if err != nil {
return c.sendAlert(alertInternalError)
}
certVerify.hasSignatureAlgorithm = true
certVerify.signatureAlgorithm = signatureAlgorithm
} else {
sigType, sigHash, err = legacyTypeAndHashFromPublicKey(key.Public())
if err != nil {
c.sendAlert(alertIllegalParameter)
return err
}
}
signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash)
signOpts := crypto.SignerOpts(sigHash)
if sigType == signatureRSAPSS {
signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
}
certVerify.signature, err = key.Sign(c.config.rand(), signed, signOpts)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
if _, err := hs.c.writeHandshakeRecord(certVerify, &hs.finishedHash); err != nil {
return err
}
}
hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.hello.random, hs.serverHello.random)
if err := c.config.writeKeyLog(keyLogLabelTLS12, hs.hello.random, hs.masterSecret); err != nil {
c.sendAlert(alertInternalError)
return errors.New("tls: failed to write to key log: " + err.Error())
}
hs.finishedHash.discardHandshakeBuffer()
return nil
}
func (hs *clientHandshakeState) establishKeys() error {
c := hs.c
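// keysFromMasterSecret expands the master secret into the TLS 1.2 key
// block, which is partitioned in a fixed order: client and server MAC
// keys, then write keys, then IVs. See RFC 5246, Section 6.3.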
clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV :=
keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen)
var clientCipher, serverCipher any
var clientHash, serverHash hash.Hash
if hs.suite.cipher != nil {
clientCipher = hs.suite.cipher(clientKey, clientIV, false /* not for reading */)
clientHash = hs.suite.mac(clientMAC)
serverCipher = hs.suite.cipher(serverKey, serverIV, true /* for reading */)
serverHash = hs.suite.mac(serverMAC)
} else {
clientCipher = hs.suite.aead(clientKey, clientIV)
serverCipher = hs.suite.aead(serverKey, serverIV)
}
c.in.prepareCipherSpec(c.vers, serverCipher, serverHash)
c.out.prepareCipherSpec(c.vers, clientCipher, clientHash)
return nil
}
func (hs *clientHandshakeState) serverResumedSession() bool {
// If the server responded with the same sessionId then it means the
// sessionTicket is being used to resume a TLS session.
return hs.session != nil && hs.hello.sessionId != nil &&
bytes.Equal(hs.serverHello.sessionId, hs.hello.sessionId)
}
func (hs *clientHandshakeState) processServerHello() (bool, error) {
c := hs.c
if err := hs.pickCipherSuite(); err != nil {
return false, err
}
if hs.serverHello.compressionMethod != compressionNone {
c.sendAlert(alertUnexpectedMessage)
return false, errors.New("tls: server selected unsupported compression format")
}
if c.handshakes == 0 && hs.serverHello.secureRenegotiationSupported {
c.secureRenegotiation = true
if len(hs.serverHello.secureRenegotiation) != 0 {
c.sendAlert(alertHandshakeFailure)
return false, errors.New("tls: initial handshake had non-empty renegotiation extension")
}
}
if c.handshakes > 0 && c.secureRenegotiation {
var expectedSecureRenegotiation [24]byte
copy(expectedSecureRenegotiation[:], c.clientFinished[:])
copy(expectedSecureRenegotiation[12:], c.serverFinished[:])
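// Per RFC 5746, Section 3.5, a renegotiated ServerHello must echo the
// concatenation of the previous client and server verify_data values,
// 12 bytes each, hence the 24-byte buffer.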
if !bytes.Equal(hs.serverHello.secureRenegotiation, expectedSecureRenegotiation[:]) {
c.sendAlert(alertHandshakeFailure)
return false, errors.New("tls: incorrect renegotiation extension contents")
}
}
if err := checkALPN(hs.hello.alpnProtocols, hs.serverHello.alpnProtocol); err != nil {
c.sendAlert(alertUnsupportedExtension)
return false, err
}
c.clientProtocol = hs.serverHello.alpnProtocol
c.scts = hs.serverHello.scts
if !hs.serverResumedSession() {
return false, nil
}
if hs.session.vers != c.vers {
c.sendAlert(alertHandshakeFailure)
return false, errors.New("tls: server resumed a session with a different version")
}
if hs.session.cipherSuite != hs.suite.id {
c.sendAlert(alertHandshakeFailure)
return false, errors.New("tls: server resumed a session with a different cipher suite")
}
// Restore the master secret, peer certificates, verified chains, and OCSP
// response from the previous session's state.
hs.masterSecret = hs.session.masterSecret
c.peerCertificates = hs.session.serverCertificates
c.verifiedChains = hs.session.verifiedChains
c.ocspResponse = hs.session.ocspResponse
// Let the ServerHello SCTs override the session SCTs from the original
// connection, if any are provided.
if len(c.scts) == 0 && len(hs.session.scts) != 0 {
c.scts = hs.session.scts
}
return true, nil
}
// checkALPN ensures that the server's choice of ALPN protocol is compatible
// with the protocols that we advertised in the ClientHello.
func checkALPN(clientProtos []string, serverProto string) error {
if serverProto == "" {
return nil
}
if len(clientProtos) == 0 {
return errors.New("tls: server advertised unrequested ALPN extension")
}
for _, proto := range clientProtos {
if proto == serverProto {
return nil
}
}
return errors.New("tls: server selected unadvertised ALPN protocol")
}
func (hs *clientHandshakeState) readFinished(out []byte) error {
c := hs.c
if err := c.readChangeCipherSpec(); err != nil {
return err
}
// finishedMsg is included in the transcript, but not until after we
// verify it, since the state before this message was sent is used
// during verification.
msg, err := c.readHandshake(nil)
if err != nil {
return err
}
serverFinished, ok := msg.(*finishedMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(serverFinished, msg)
}
verify := hs.finishedHash.serverSum(hs.masterSecret)
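// Compare in constant time to avoid leaking how much of the verify data
// matched, even though a mismatch aborts the handshake anyway.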
if len(verify) != len(serverFinished.verifyData) ||
subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: server's Finished message was incorrect")
}
if err := transcriptMsg(serverFinished, &hs.finishedHash); err != nil {
return err
}
copy(out, verify)
return nil
}
func (hs *clientHandshakeState) readSessionTicket() error {
if !hs.serverHello.ticketSupported {
return nil
}
c := hs.c
msg, err := c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
sessionTicketMsg, ok := msg.(*newSessionTicketMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(sessionTicketMsg, msg)
}
hs.session = &ClientSessionState{
sessionTicket: sessionTicketMsg.ticket,
vers: c.vers,
cipherSuite: hs.suite.id,
masterSecret: hs.masterSecret,
serverCertificates: c.peerCertificates,
verifiedChains: c.verifiedChains,
receivedAt: c.config.time(),
ocspResponse: c.ocspResponse,
scts: c.scts,
}
return nil
}
func (hs *clientHandshakeState) sendFinished(out []byte) error {
c := hs.c
if err := c.writeChangeCipherRecord(); err != nil {
return err
}
finished := new(finishedMsg)
finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret)
if _, err := hs.c.writeHandshakeRecord(finished, &hs.finishedHash); err != nil {
return err
}
copy(out, finished.verifyData)
return nil
}
// verifyServerCertificate parses and verifies the provided chain, setting
// c.verifiedChains and c.peerCertificates or sending the appropriate alert.
func (c *Conn) verifyServerCertificate(certificates [][]byte) error {
activeHandles := make([]*activeCert, len(certificates))
certs := make([]*x509.Certificate, len(certificates))
for i, asn1Data := range certificates {
cert, err := clientCertCache.newCert(asn1Data)
if err != nil {
c.sendAlert(alertBadCertificate)
return errors.New("tls: failed to parse certificate from server: " + err.Error())
}
activeHandles[i] = cert
certs[i] = cert.cert
}
if !c.config.InsecureSkipVerify {
opts := x509.VerifyOptions{
Roots: c.config.RootCAs,
CurrentTime: c.config.time(),
DNSName: c.config.ServerName,
Intermediates: x509.NewCertPool(),
}
for _, cert := range certs[1:] {
opts.Intermediates.AddCert(cert)
}
var err error
c.verifiedChains, err = certs[0].Verify(opts)
if err != nil {
c.sendAlert(alertBadCertificate)
return &CertificateVerificationError{UnverifiedCertificates: certs, Err: err}
}
}
switch certs[0].PublicKey.(type) {
case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey:
break
default:
c.sendAlert(alertUnsupportedCertificate)
return fmt.Errorf("tls: server's certificate contains an unsupported type of public key: %T", certs[0].PublicKey)
}
c.activeCertHandles = activeHandles
c.peerCertificates = certs
if c.config.VerifyPeerCertificate != nil {
if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
return nil
}
// certificateRequestInfoFromMsg generates a CertificateRequestInfo from a TLS
// <= 1.2 CertificateRequest, making an effort to fill in missing information.
func certificateRequestInfoFromMsg(ctx context.Context, vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo {
cri := &CertificateRequestInfo{
AcceptableCAs: certReq.certificateAuthorities,
Version: vers,
ctx: ctx,
}
var rsaAvail, ecAvail bool
for _, certType := range certReq.certificateTypes {
switch certType {
case certTypeRSASign:
rsaAvail = true
case certTypeECDSASign:
ecAvail = true
}
}
if !certReq.hasSignatureAlgorithm {
// Prior to TLS 1.2, signature schemes did not exist. In this case we
// make up a list based on the acceptable certificate types, to help
// GetClientCertificate and SupportsCertificate select the right certificate.
// The hash part of the SignatureScheme is a lie here, because
// TLS 1.0 and 1.1 always use MD5+SHA1 for RSA and SHA1 for ECDSA.
switch {
case rsaAvail && ecAvail:
cri.SignatureSchemes = []SignatureScheme{
ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512,
PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512, PKCS1WithSHA1,
}
case rsaAvail:
cri.SignatureSchemes = []SignatureScheme{
PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512, PKCS1WithSHA1,
}
case ecAvail:
cri.SignatureSchemes = []SignatureScheme{
ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512,
}
}
return cri
}
// Filter the signature schemes based on the certificate types.
// See RFC 5246, Section 7.4.4 (where it calls this "somewhat complicated").
cri.SignatureSchemes = make([]SignatureScheme, 0, len(certReq.supportedSignatureAlgorithms))
for _, sigScheme := range certReq.supportedSignatureAlgorithms {
sigType, _, err := typeAndHashFromSignatureScheme(sigScheme)
if err != nil {
continue
}
switch sigType {
case signatureECDSA, signatureEd25519:
if ecAvail {
cri.SignatureSchemes = append(cri.SignatureSchemes, sigScheme)
}
case signatureRSAPSS, signaturePKCS1v15:
if rsaAvail {
cri.SignatureSchemes = append(cri.SignatureSchemes, sigScheme)
}
}
}
return cri
}
func (c *Conn) getClientCertificate(cri *CertificateRequestInfo) (*Certificate, error) {
if c.config.GetClientCertificate != nil {
return c.config.GetClientCertificate(cri)
}
for _, chain := range c.config.Certificates {
if err := cri.SupportsCertificate(&chain); err != nil {
continue
}
return &chain, nil
}
// No acceptable certificate found. Don't send a certificate.
return new(Certificate), nil
}
// clientSessionCacheKey returns a key used to cache sessionTickets that could
// be used to resume previously negotiated TLS sessions with a server.
func clientSessionCacheKey(serverAddr net.Addr, config *Config) string {
if len(config.ServerName) > 0 {
return config.ServerName
}
return serverAddr.String()
}
// hostnameInSNI converts name into an appropriate hostname for SNI.
// Literal IP addresses and absolute FQDNs are not permitted as SNI values.
// See RFC 6066, Section 3.
func hostnameInSNI(name string) string {
host := name
if len(host) > 0 && host[0] == '[' && host[len(host)-1] == ']' {
host = host[1 : len(host)-1]
}
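// A "%" introduces an IPv6 zone identifier (e.g. "fe80::1%eth0"), which
// must be stripped for net.ParseIP to recognize the address literal below.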
if i := strings.LastIndex(host, "%"); i > 0 {
host = host[:i]
}
if net.ParseIP(host) != nil {
return ""
}
for len(name) > 0 && name[len(name)-1] == '.' {
name = name[:len(name)-1]
}
return name
}
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"bytes"
"context"
"crypto"
"crypto/ecdh"
"crypto/hmac"
"crypto/rsa"
"errors"
"hash"
"time"
)
type clientHandshakeStateTLS13 struct {
c *Conn
ctx context.Context
serverHello *serverHelloMsg
hello *clientHelloMsg
ecdheKey *ecdh.PrivateKey
session *ClientSessionState
earlySecret []byte
binderKey []byte
certReq *certificateRequestMsgTLS13
usingPSK bool
sentDummyCCS bool
suite *cipherSuiteTLS13
transcript hash.Hash
masterSecret []byte
trafficSecret []byte // client_application_traffic_secret_0
}
// handshake requires hs.c, hs.hello, hs.serverHello, hs.ecdheKey, and,
// optionally, hs.session, hs.earlySecret, and hs.binderKey to be set.
func (hs *clientHandshakeStateTLS13) handshake() error {
c := hs.c
if needFIPS() {
return errors.New("tls: internal error: TLS 1.3 reached in FIPS mode")
}
// The server must not select TLS 1.3 in a renegotiation. See RFC 8446,
// sections 4.1.2 and 4.1.3.
if c.handshakes > 0 {
c.sendAlert(alertProtocolVersion)
return errors.New("tls: server selected TLS 1.3 in a renegotiation")
}
// Consistency check on the presence of a keyShare and its parameters.
if hs.ecdheKey == nil || len(hs.hello.keyShares) != 1 {
return c.sendAlert(alertInternalError)
}
if err := hs.checkServerHelloOrHRR(); err != nil {
return err
}
hs.transcript = hs.suite.hash.New()
if err := transcriptMsg(hs.hello, hs.transcript); err != nil {
return err
}
if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) {
if err := hs.sendDummyChangeCipherSpec(); err != nil {
return err
}
if err := hs.processHelloRetryRequest(); err != nil {
return err
}
}
if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil {
return err
}
c.buffering = true
if err := hs.processServerHello(); err != nil {
return err
}
if err := hs.sendDummyChangeCipherSpec(); err != nil {
return err
}
if err := hs.establishHandshakeKeys(); err != nil {
return err
}
if err := hs.readServerParameters(); err != nil {
return err
}
if err := hs.readServerCertificate(); err != nil {
return err
}
if err := hs.readServerFinished(); err != nil {
return err
}
if err := hs.sendClientCertificate(); err != nil {
return err
}
if err := hs.sendClientFinished(); err != nil {
return err
}
if _, err := c.flush(); err != nil {
return err
}
c.isHandshakeComplete.Store(true)
return nil
}
// checkServerHelloOrHRR does validity checks that apply to both ServerHello and
// HelloRetryRequest messages. It sets hs.suite.
func (hs *clientHandshakeStateTLS13) checkServerHelloOrHRR() error {
c := hs.c
if hs.serverHello.supportedVersion == 0 {
c.sendAlert(alertMissingExtension)
return errors.New("tls: server selected TLS 1.3 using the legacy version field")
}
if hs.serverHello.supportedVersion != VersionTLS13 {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected an invalid version after a HelloRetryRequest")
}
if hs.serverHello.vers != VersionTLS12 {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server sent an incorrect legacy version")
}
if hs.serverHello.ocspStapling ||
hs.serverHello.ticketSupported ||
hs.serverHello.secureRenegotiationSupported ||
len(hs.serverHello.secureRenegotiation) != 0 ||
len(hs.serverHello.alpnProtocol) != 0 ||
len(hs.serverHello.scts) != 0 {
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: server sent a ServerHello extension forbidden in TLS 1.3")
}
if !bytes.Equal(hs.hello.sessionId, hs.serverHello.sessionId) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server did not echo the legacy session ID")
}
if hs.serverHello.compressionMethod != compressionNone {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected unsupported compression format")
}
selectedSuite := mutualCipherSuiteTLS13(hs.hello.cipherSuites, hs.serverHello.cipherSuite)
if hs.suite != nil && selectedSuite != hs.suite {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server changed cipher suite after a HelloRetryRequest")
}
if selectedSuite == nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server chose an unconfigured cipher suite")
}
hs.suite = selectedSuite
c.cipherSuite = hs.suite.id
return nil
}
// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility
// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4.
func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error {
if hs.sentDummyCCS {
return nil
}
hs.sentDummyCCS = true
return hs.c.writeChangeCipherRecord()
}
// processHelloRetryRequest handles the HRR in hs.serverHello, modifies and
// resends hs.hello, and reads the new ServerHello into hs.serverHello.
func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
c := hs.c
// The first ClientHello gets double-hashed into the transcript upon a
// HelloRetryRequest. (The idea is that the server might offload transcript
// storage to the client in the cookie.) See RFC 8446, Section 4.4.1.
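// The synthetic "message_hash" message uses the usual handshake framing:
// a one-byte type followed by a uint24 length, which is why the hash
// length below is written as three bytes with the top two zero.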
chHash := hs.transcript.Sum(nil)
hs.transcript.Reset()
hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
hs.transcript.Write(chHash)
if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil {
return err
}
// The only HelloRetryRequest extensions we support are key_share and
// cookie, and clients must abort the handshake if the HRR would not result
// in any change in the ClientHello.
if hs.serverHello.selectedGroup == 0 && hs.serverHello.cookie == nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server sent an unnecessary HelloRetryRequest message")
}
if hs.serverHello.cookie != nil {
hs.hello.cookie = hs.serverHello.cookie
}
if hs.serverHello.serverShare.group != 0 {
c.sendAlert(alertDecodeError)
return errors.New("tls: received malformed key_share extension")
}
// If the server sent a key_share extension selecting a group, ensure it's
// a group we advertised but did not send a key share for, and send a key
// share for it this time.
if curveID := hs.serverHello.selectedGroup; curveID != 0 {
curveOK := false
for _, id := range hs.hello.supportedCurves {
if id == curveID {
curveOK = true
break
}
}
if !curveOK {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected unsupported group")
}
if sentID, _ := curveIDForCurve(hs.ecdheKey.Curve()); sentID == curveID {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share")
}
if _, ok := curveForCurveID(curveID); !ok {
c.sendAlert(alertInternalError)
return errors.New("tls: CurvePreferences includes unsupported curve")
}
key, err := generateECDHEKey(c.config.rand(), curveID)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
hs.ecdheKey = key
hs.hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}}
}
hs.hello.raw = nil
if len(hs.hello.pskIdentities) > 0 {
pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite)
if pskSuite == nil {
return c.sendAlert(alertInternalError)
}
if pskSuite.hash == hs.suite.hash {
// Update binders and obfuscated_ticket_age.
ticketAge := uint32(c.config.time().Sub(hs.session.receivedAt) / time.Millisecond)
hs.hello.pskIdentities[0].obfuscatedTicketAge = ticketAge + hs.session.ageAdd
transcript := hs.suite.hash.New()
transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
transcript.Write(chHash)
if err := transcriptMsg(hs.serverHello, transcript); err != nil {
return err
}
helloBytes, err := hs.hello.marshalWithoutBinders()
if err != nil {
return err
}
transcript.Write(helloBytes)
pskBinders := [][]byte{hs.suite.finishedHash(hs.binderKey, transcript)}
if err := hs.hello.updateBinders(pskBinders); err != nil {
return err
}
} else {
// Server selected a cipher suite incompatible with the PSK.
hs.hello.pskIdentities = nil
hs.hello.pskBinders = nil
}
}
if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil {
return err
}
// serverHelloMsg is not included in the transcript
msg, err := c.readHandshake(nil)
if err != nil {
return err
}
serverHello, ok := msg.(*serverHelloMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(serverHello, msg)
}
hs.serverHello = serverHello
if err := hs.checkServerHelloOrHRR(); err != nil {
return err
}
return nil
}
func (hs *clientHandshakeStateTLS13) processServerHello() error {
c := hs.c
if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) {
c.sendAlert(alertUnexpectedMessage)
return errors.New("tls: server sent two HelloRetryRequest messages")
}
if len(hs.serverHello.cookie) != 0 {
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: server sent a cookie in a normal ServerHello")
}
if hs.serverHello.selectedGroup != 0 {
c.sendAlert(alertDecodeError)
return errors.New("tls: malformed key_share extension")
}
if hs.serverHello.serverShare.group == 0 {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server did not send a key share")
}
if sentID, _ := curveIDForCurve(hs.ecdheKey.Curve()); hs.serverHello.serverShare.group != sentID {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected unsupported group")
}
if !hs.serverHello.selectedIdentityPresent {
return nil
}
if int(hs.serverHello.selectedIdentity) >= len(hs.hello.pskIdentities) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected an invalid PSK")
}
if len(hs.hello.pskIdentities) != 1 || hs.session == nil {
return c.sendAlert(alertInternalError)
}
pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite)
if pskSuite == nil {
return c.sendAlert(alertInternalError)
}
if pskSuite.hash != hs.suite.hash {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected an invalid PSK and cipher suite pair")
}
hs.usingPSK = true
c.didResume = true
c.peerCertificates = hs.session.serverCertificates
c.verifiedChains = hs.session.verifiedChains
c.ocspResponse = hs.session.ocspResponse
c.scts = hs.session.scts
return nil
}
func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error {
c := hs.c
peerKey, err := hs.ecdheKey.Curve().NewPublicKey(hs.serverHello.serverShare.data)
if err != nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: invalid server key share")
}
sharedKey, err := hs.ecdheKey.ECDH(peerKey)
if err != nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: invalid server key share")
}
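// The derivations below follow the key schedule of RFC 8446, Section 7.1
// (HKDF-Extract takes the salt first, then the input keying material):
//
//   Early Secret     = HKDF-Extract(0, PSK or 0)
//   Handshake Secret = HKDF-Extract(Derive-Secret(Early, "derived"), ECDHE)
//   Master Secret    = HKDF-Extract(Derive-Secret(Handshake, "derived"), 0)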
earlySecret := hs.earlySecret
if !hs.usingPSK {
earlySecret = hs.suite.extract(nil, nil)
}
handshakeSecret := hs.suite.extract(sharedKey,
hs.suite.deriveSecret(earlySecret, "derived", nil))
clientSecret := hs.suite.deriveSecret(handshakeSecret,
clientHandshakeTrafficLabel, hs.transcript)
c.out.setTrafficSecret(hs.suite, clientSecret)
serverSecret := hs.suite.deriveSecret(handshakeSecret,
serverHandshakeTrafficLabel, hs.transcript)
c.in.setTrafficSecret(hs.suite, serverSecret)
err = c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.hello.random, serverSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
hs.masterSecret = hs.suite.extract(nil,
hs.suite.deriveSecret(handshakeSecret, "derived", nil))
return nil
}
func (hs *clientHandshakeStateTLS13) readServerParameters() error {
c := hs.c
msg, err := c.readHandshake(hs.transcript)
if err != nil {
return err
}
encryptedExtensions, ok := msg.(*encryptedExtensionsMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(encryptedExtensions, msg)
}
if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol); err != nil {
c.sendAlert(alertUnsupportedExtension)
return err
}
c.clientProtocol = encryptedExtensions.alpnProtocol
return nil
}
func (hs *clientHandshakeStateTLS13) readServerCertificate() error {
c := hs.c
// Either a PSK or a certificate is always used, but not both.
// See RFC 8446, Section 4.1.1.
if hs.usingPSK {
// Make sure the connection is still being verified whether or not this
// is a resumption. Resumptions currently don't reverify certificates so
// they don't call verifyServerCertificate. See Issue 31641.
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
return nil
}
msg, err := c.readHandshake(hs.transcript)
if err != nil {
return err
}
certReq, ok := msg.(*certificateRequestMsgTLS13)
if ok {
hs.certReq = certReq
msg, err = c.readHandshake(hs.transcript)
if err != nil {
return err
}
}
certMsg, ok := msg.(*certificateMsgTLS13)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certMsg, msg)
}
if len(certMsg.certificate.Certificate) == 0 {
c.sendAlert(alertDecodeError)
return errors.New("tls: received empty certificates message")
}
c.scts = certMsg.certificate.SignedCertificateTimestamps
c.ocspResponse = certMsg.certificate.OCSPStaple
if err := c.verifyServerCertificate(certMsg.certificate.Certificate); err != nil {
return err
}
// certificateVerifyMsg is included in the transcript, but not until
// after we verify the handshake signature, since the state before
// this message was sent is used.
msg, err = c.readHandshake(nil)
if err != nil {
return err
}
certVerify, ok := msg.(*certificateVerifyMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certVerify, msg)
}
// See RFC 8446, Section 4.4.3.
if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms()) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: certificate used with invalid signature algorithm")
}
sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm)
if err != nil {
return c.sendAlert(alertInternalError)
}
if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: certificate used with invalid signature algorithm")
}
signed := signedMessage(sigHash, serverSignatureContext, hs.transcript)
if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey,
sigHash, signed, certVerify.signature); err != nil {
c.sendAlert(alertDecryptError)
return errors.New("tls: invalid signature by the server certificate: " + err.Error())
}
if err := transcriptMsg(certVerify, hs.transcript); err != nil {
return err
}
return nil
}
func (hs *clientHandshakeStateTLS13) readServerFinished() error {
c := hs.c
// finishedMsg is included in the transcript, but not until after we
// verify it, since the state before this message was sent is used
// during verification.
msg, err := c.readHandshake(nil)
if err != nil {
return err
}
finished, ok := msg.(*finishedMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(finished, msg)
}
expectedMAC := hs.suite.finishedHash(c.in.trafficSecret, hs.transcript)
if !hmac.Equal(expectedMAC, finished.verifyData) {
c.sendAlert(alertDecryptError)
return errors.New("tls: invalid server finished hash")
}
if err := transcriptMsg(finished, hs.transcript); err != nil {
return err
}
// Derive secrets that take context through the server Finished.
hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret,
clientApplicationTrafficLabel, hs.transcript)
serverSecret := hs.suite.deriveSecret(hs.masterSecret,
serverApplicationTrafficLabel, hs.transcript)
c.in.setTrafficSecret(hs.suite, serverSecret)
err = c.config.writeKeyLog(keyLogLabelClientTraffic, hs.hello.random, hs.trafficSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.hello.random, serverSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript)
return nil
}
func (hs *clientHandshakeStateTLS13) sendClientCertificate() error {
c := hs.c
if hs.certReq == nil {
return nil
}
cert, err := c.getClientCertificate(&CertificateRequestInfo{
AcceptableCAs: hs.certReq.certificateAuthorities,
SignatureSchemes: hs.certReq.supportedSignatureAlgorithms,
Version: c.vers,
ctx: hs.ctx,
})
if err != nil {
return err
}
certMsg := new(certificateMsgTLS13)
certMsg.certificate = *cert
certMsg.scts = hs.certReq.scts && len(cert.SignedCertificateTimestamps) > 0
certMsg.ocspStapling = hs.certReq.ocspStapling && len(cert.OCSPStaple) > 0
if _, err := hs.c.writeHandshakeRecord(certMsg, hs.transcript); err != nil {
return err
}
// If we sent an empty certificate message, skip the CertificateVerify.
if len(cert.Certificate) == 0 {
return nil
}
certVerifyMsg := new(certificateVerifyMsg)
certVerifyMsg.hasSignatureAlgorithm = true
certVerifyMsg.signatureAlgorithm, err = selectSignatureScheme(c.vers, cert, hs.certReq.supportedSignatureAlgorithms)
if err != nil {
// getClientCertificate returned a certificate incompatible with the
// CertificateRequestInfo supported signature algorithms.
c.sendAlert(alertHandshakeFailure)
return err
}
sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerifyMsg.signatureAlgorithm)
if err != nil {
return c.sendAlert(alertInternalError)
}
signed := signedMessage(sigHash, clientSignatureContext, hs.transcript)
signOpts := crypto.SignerOpts(sigHash)
if sigType == signatureRSAPSS {
signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
}
sig, err := cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts)
if err != nil {
c.sendAlert(alertInternalError)
return errors.New("tls: failed to sign handshake: " + err.Error())
}
certVerifyMsg.signature = sig
if _, err := hs.c.writeHandshakeRecord(certVerifyMsg, hs.transcript); err != nil {
return err
}
return nil
}
func (hs *clientHandshakeStateTLS13) sendClientFinished() error {
c := hs.c
finished := &finishedMsg{
verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript),
}
if _, err := hs.c.writeHandshakeRecord(finished, hs.transcript); err != nil {
return err
}
c.out.setTrafficSecret(hs.suite, hs.trafficSecret)
if !c.config.SessionTicketsDisabled && c.config.ClientSessionCache != nil {
c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret,
resumptionLabel, hs.transcript)
}
return nil
}
func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error {
if !c.isClient {
c.sendAlert(alertUnexpectedMessage)
return errors.New("tls: received new session ticket from a client")
}
if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil {
return nil
}
// See RFC 8446, Section 4.6.1.
if msg.lifetime == 0 {
return nil
}
lifetime := time.Duration(msg.lifetime) * time.Second
if lifetime > maxSessionTicketLifetime {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: received a session ticket with invalid lifetime")
}
cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
if cipherSuite == nil || c.resumptionSecret == nil {
return c.sendAlert(alertInternalError)
}
// Save the resumption_master_secret and nonce instead of deriving the PSK
// to do the least amount of work on NewSessionTicket messages before we
// know if the ticket will be used. Forward secrecy of resumed connections
// is guaranteed by the requirement for pskModeDHE.
session := &ClientSessionState{
sessionTicket: msg.label,
vers: c.vers,
cipherSuite: c.cipherSuite,
masterSecret: c.resumptionSecret,
serverCertificates: c.peerCertificates,
verifiedChains: c.verifiedChains,
receivedAt: c.config.time(),
nonce: msg.nonce,
useBy: c.config.time().Add(lifetime),
ageAdd: msg.ageAdd,
ocspResponse: c.ocspResponse,
scts: c.scts,
}
cacheKey := clientSessionCacheKey(c.conn.RemoteAddr(), c.config)
c.config.ClientSessionCache.Put(cacheKey, session)
return nil
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"errors"
"fmt"
"strings"
"golang.org/x/crypto/cryptobyte"
)
// The marshalingFunction type is an adapter to allow the use of ordinary
// functions as cryptobyte.MarshalingValue.
type marshalingFunction func(b *cryptobyte.Builder) error
func (f marshalingFunction) Marshal(b *cryptobyte.Builder) error {
return f(b)
}
// addBytesWithLength appends a sequence of bytes to the cryptobyte.Builder. If
// the length of the sequence is not the value specified, it produces an error.
func addBytesWithLength(b *cryptobyte.Builder, v []byte, n int) {
b.AddValue(marshalingFunction(func(b *cryptobyte.Builder) error {
if len(v) != n {
return fmt.Errorf("invalid value length: expected %d, got %d", n, len(v))
}
b.AddBytes(v)
return nil
}))
}
// addUint64 appends a big-endian, 64-bit value to the cryptobyte.Builder.
func addUint64(b *cryptobyte.Builder, v uint64) {
b.AddUint32(uint32(v >> 32))
b.AddUint32(uint32(v))
}
// readUint64 decodes a big-endian, 64-bit value into out and advances over it.
// It reports whether the read was successful.
func readUint64(s *cryptobyte.String, out *uint64) bool {
var hi, lo uint32
if !s.ReadUint32(&hi) || !s.ReadUint32(&lo) {
return false
}
*out = uint64(hi)<<32 | uint64(lo)
return true
}
// readUint8LengthPrefixed acts like s.ReadUint8LengthPrefixed, but targets a
// []byte instead of a cryptobyte.String.
func readUint8LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
return s.ReadUint8LengthPrefixed((*cryptobyte.String)(out))
}
// readUint16LengthPrefixed acts like s.ReadUint16LengthPrefixed, but targets a
// []byte instead of a cryptobyte.String.
func readUint16LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
return s.ReadUint16LengthPrefixed((*cryptobyte.String)(out))
}
// readUint24LengthPrefixed acts like s.ReadUint24LengthPrefixed, but targets a
// []byte instead of a cryptobyte.String.
func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
return s.ReadUint24LengthPrefixed((*cryptobyte.String)(out))
}
type clientHelloMsg struct {
raw []byte
vers uint16
random []byte
sessionId []byte
cipherSuites []uint16
compressionMethods []uint8
serverName string
ocspStapling bool
supportedCurves []CurveID
supportedPoints []uint8
ticketSupported bool
sessionTicket []uint8
supportedSignatureAlgorithms []SignatureScheme
supportedSignatureAlgorithmsCert []SignatureScheme
secureRenegotiationSupported bool
secureRenegotiation []byte
alpnProtocols []string
scts bool
supportedVersions []uint16
cookie []byte
keyShares []keyShare
earlyData bool
pskModes []uint8
pskIdentities []pskIdentity
pskBinders [][]byte
}
func (m *clientHelloMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var exts cryptobyte.Builder
if len(m.serverName) > 0 {
// RFC 6066, Section 3
exts.AddUint16(extensionServerName)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint8(0) // name_type = host_name
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes([]byte(m.serverName))
})
})
})
}
if m.ocspStapling {
// RFC 4366, Section 3.6
exts.AddUint16(extensionStatusRequest)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint8(1) // status_type = ocsp
exts.AddUint16(0) // empty responder_id_list
exts.AddUint16(0) // empty request_extensions
})
}
if len(m.supportedCurves) > 0 {
// RFC 4492, Section 5.1.1, and RFC 8446, Section 4.2.7
exts.AddUint16(extensionSupportedCurves)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
for _, curve := range m.supportedCurves {
exts.AddUint16(uint16(curve))
}
})
})
}
if len(m.supportedPoints) > 0 {
// RFC 4492, Section 5.1.2
exts.AddUint16(extensionSupportedPoints)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes(m.supportedPoints)
})
})
}
if m.ticketSupported {
// RFC 5077, Section 3.2
exts.AddUint16(extensionSessionTicket)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes(m.sessionTicket)
})
}
if len(m.supportedSignatureAlgorithms) > 0 {
// RFC 5246, Section 7.4.1.4.1
exts.AddUint16(extensionSignatureAlgorithms)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
for _, sigAlgo := range m.supportedSignatureAlgorithms {
exts.AddUint16(uint16(sigAlgo))
}
})
})
}
if len(m.supportedSignatureAlgorithmsCert) > 0 {
// RFC 8446, Section 4.2.3
exts.AddUint16(extensionSignatureAlgorithmsCert)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
for _, sigAlgo := range m.supportedSignatureAlgorithmsCert {
exts.AddUint16(uint16(sigAlgo))
}
})
})
}
if m.secureRenegotiationSupported {
// RFC 5746, Section 3.2
exts.AddUint16(extensionRenegotiationInfo)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes(m.secureRenegotiation)
})
})
}
if len(m.alpnProtocols) > 0 {
// RFC 7301, Section 3.1
exts.AddUint16(extensionALPN)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
for _, proto := range m.alpnProtocols {
exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes([]byte(proto))
})
}
})
})
}
if m.scts {
// RFC 6962, Section 3.3.1
exts.AddUint16(extensionSCT)
exts.AddUint16(0) // empty extension_data
}
if len(m.supportedVersions) > 0 {
// RFC 8446, Section 4.2.1
exts.AddUint16(extensionSupportedVersions)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
for _, vers := range m.supportedVersions {
exts.AddUint16(vers)
}
})
})
}
if len(m.cookie) > 0 {
// RFC 8446, Section 4.2.2
exts.AddUint16(extensionCookie)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes(m.cookie)
})
})
}
if len(m.keyShares) > 0 {
// RFC 8446, Section 4.2.8
exts.AddUint16(extensionKeyShare)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
for _, ks := range m.keyShares {
exts.AddUint16(uint16(ks.group))
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes(ks.data)
})
}
})
})
}
if m.earlyData {
// RFC 8446, Section 4.2.10
exts.AddUint16(extensionEarlyData)
exts.AddUint16(0) // empty extension_data
}
if len(m.pskModes) > 0 {
// RFC 8446, Section 4.2.9
exts.AddUint16(extensionPSKModes)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes(m.pskModes)
})
})
}
if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension
// RFC 8446, Section 4.2.11
exts.AddUint16(extensionPreSharedKey)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
for _, psk := range m.pskIdentities {
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes(psk.label)
})
exts.AddUint32(psk.obfuscatedTicketAge)
}
})
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
for _, binder := range m.pskBinders {
exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes(binder)
})
}
})
})
}
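// The extensions were accumulated in a separate builder so that the
// extensions block can be omitted entirely, rather than encoded as empty,
// when no extension is present.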
extBytes, err := exts.Bytes()
if err != nil {
return nil, err
}
var b cryptobyte.Builder
b.AddUint8(typeClientHello)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint16(m.vers)
addBytesWithLength(b, m.random, 32)
b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.sessionId)
})
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
for _, suite := range m.cipherSuites {
b.AddUint16(suite)
}
})
b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.compressionMethods)
})
if len(extBytes) > 0 {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(extBytes)
})
}
})
m.raw, err = b.Bytes()
return m.raw, err
}
// marshalWithoutBinders returns the ClientHello through the
// PreSharedKeyExtension.identities field, according to RFC 8446, Section
// 4.2.11.2. Note that m.pskBinders must be set to slices of the correct length.
func (m *clientHelloMsg) marshalWithoutBinders() ([]byte, error) {
bindersLen := 2 // uint16 length prefix
for _, binder := range m.pskBinders {
bindersLen += 1 // uint8 length prefix
bindersLen += len(binder)
}
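// Since pre_shared_key must be the last extension, the binders list is
// the final bindersLen bytes of the marshaled message, and truncating
// them yields exactly the transcript prefix defined in RFC 8446,
// Section 4.2.11.2.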
fullMessage, err := m.marshal()
if err != nil {
return nil, err
}
return fullMessage[:len(fullMessage)-bindersLen], nil
}
// updateBinders updates the m.pskBinders field, if necessary updating the
// cached marshaled representation. The supplied binders must have the same
// length as the current m.pskBinders.
func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) error {
if len(pskBinders) != len(m.pskBinders) {
return errors.New("tls: internal error: pskBinders length mismatch")
}
for i := range m.pskBinders {
if len(pskBinders[i]) != len(m.pskBinders[i]) {
return errors.New("tls: internal error: pskBinders length mismatch")
}
}
m.pskBinders = pskBinders
if m.raw != nil {
helloBytes, err := m.marshalWithoutBinders()
if err != nil {
return err
}
lenWithoutBinders := len(helloBytes)
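// NewFixedBuilder appends into the existing backing array and errors if
// its capacity would be exceeded, so a successful build of exactly
// len(m.raw) bytes rewrites the binders in place.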
b := cryptobyte.NewFixedBuilder(m.raw[:lenWithoutBinders])
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
for _, binder := range m.pskBinders {
b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(binder)
})
}
})
if out, err := b.Bytes(); err != nil || len(out) != len(m.raw) {
return errors.New("tls: internal error: failed to update binders")
}
}
return nil
}
func (m *clientHelloMsg) unmarshal(data []byte) bool {
*m = clientHelloMsg{raw: data}
s := cryptobyte.String(data)
if !s.Skip(4) || // message type and uint24 length field
!s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) ||
!readUint8LengthPrefixed(&s, &m.sessionId) {
return false
}
var cipherSuites cryptobyte.String
if !s.ReadUint16LengthPrefixed(&cipherSuites) {
return false
}
m.cipherSuites = []uint16{}
m.secureRenegotiationSupported = false
for !cipherSuites.Empty() {
var suite uint16
if !cipherSuites.ReadUint16(&suite) {
return false
}
if suite == scsvRenegotiation {
m.secureRenegotiationSupported = true
}
m.cipherSuites = append(m.cipherSuites, suite)
}
if !readUint8LengthPrefixed(&s, &m.compressionMethods) {
return false
}
if s.Empty() {
// ClientHello is optionally followed by extension data
return true
}
var extensions cryptobyte.String
if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
return false
}
seenExts := make(map[uint16]bool)
for !extensions.Empty() {
var extension uint16
var extData cryptobyte.String
if !extensions.ReadUint16(&extension) ||
!extensions.ReadUint16LengthPrefixed(&extData) {
return false
}
if seenExts[extension] {
return false
}
seenExts[extension] = true
switch extension {
case extensionServerName:
// RFC 6066, Section 3
var nameList cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() {
return false
}
for !nameList.Empty() {
var nameType uint8
var serverName cryptobyte.String
if !nameList.ReadUint8(&nameType) ||
!nameList.ReadUint16LengthPrefixed(&serverName) ||
serverName.Empty() {
return false
}
if nameType != 0 {
continue
}
if len(m.serverName) != 0 {
// Multiple names of the same name_type are prohibited.
return false
}
m.serverName = string(serverName)
// An SNI value may not include a trailing dot.
if strings.HasSuffix(m.serverName, ".") {
return false
}
}
case extensionStatusRequest:
// RFC 4366, Section 3.6
var statusType uint8
var ignored cryptobyte.String
if !extData.ReadUint8(&statusType) ||
!extData.ReadUint16LengthPrefixed(&ignored) ||
!extData.ReadUint16LengthPrefixed(&ignored) {
return false
}
m.ocspStapling = statusType == statusTypeOCSP
case extensionSupportedCurves:
// RFC 4492, Section 5.1.1, and RFC 8446, Section 4.2.7
var curves cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&curves) || curves.Empty() {
return false
}
for !curves.Empty() {
var curve uint16
if !curves.ReadUint16(&curve) {
return false
}
m.supportedCurves = append(m.supportedCurves, CurveID(curve))
}
case extensionSupportedPoints:
// RFC 4492, Section 5.1.2
if !readUint8LengthPrefixed(&extData, &m.supportedPoints) ||
len(m.supportedPoints) == 0 {
return false
}
case extensionSessionTicket:
// RFC 5077, Section 3.2
m.ticketSupported = true
extData.ReadBytes(&m.sessionTicket, len(extData))
case extensionSignatureAlgorithms:
// RFC 5246, Section 7.4.1.4.1
var sigAndAlgs cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
return false
}
for !sigAndAlgs.Empty() {
var sigAndAlg uint16
if !sigAndAlgs.ReadUint16(&sigAndAlg) {
return false
}
m.supportedSignatureAlgorithms = append(
m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg))
}
case extensionSignatureAlgorithmsCert:
// RFC 8446, Section 4.2.3
var sigAndAlgs cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
return false
}
for !sigAndAlgs.Empty() {
var sigAndAlg uint16
if !sigAndAlgs.ReadUint16(&sigAndAlg) {
return false
}
m.supportedSignatureAlgorithmsCert = append(
m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg))
}
case extensionRenegotiationInfo:
// RFC 5746, Section 3.2
if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) {
return false
}
m.secureRenegotiationSupported = true
case extensionALPN:
// RFC 7301, Section 3.1
var protoList cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
return false
}
for !protoList.Empty() {
var proto cryptobyte.String
if !protoList.ReadUint8LengthPrefixed(&proto) || proto.Empty() {
return false
}
m.alpnProtocols = append(m.alpnProtocols, string(proto))
}
case extensionSCT:
// RFC 6962, Section 3.3.1
m.scts = true
case extensionSupportedVersions:
// RFC 8446, Section 4.2.1
var versList cryptobyte.String
if !extData.ReadUint8LengthPrefixed(&versList) || versList.Empty() {
return false
}
for !versList.Empty() {
var vers uint16
if !versList.ReadUint16(&vers) {
return false
}
m.supportedVersions = append(m.supportedVersions, vers)
}
case extensionCookie:
// RFC 8446, Section 4.2.2
if !readUint16LengthPrefixed(&extData, &m.cookie) ||
len(m.cookie) == 0 {
return false
}
case extensionKeyShare:
// RFC 8446, Section 4.2.8
var clientShares cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&clientShares) {
return false
}
for !clientShares.Empty() {
var ks keyShare
if !clientShares.ReadUint16((*uint16)(&ks.group)) ||
!readUint16LengthPrefixed(&clientShares, &ks.data) ||
len(ks.data) == 0 {
return false
}
m.keyShares = append(m.keyShares, ks)
}
case extensionEarlyData:
// RFC 8446, Section 4.2.10
m.earlyData = true
case extensionPSKModes:
// RFC 8446, Section 4.2.9
if !readUint8LengthPrefixed(&extData, &m.pskModes) {
return false
}
case extensionPreSharedKey:
// RFC 8446, Section 4.2.11
if !extensions.Empty() {
return false // pre_shared_key must be the last extension
}
var identities cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&identities) || identities.Empty() {
return false
}
for !identities.Empty() {
var psk pskIdentity
if !readUint16LengthPrefixed(&identities, &psk.label) ||
!identities.ReadUint32(&psk.obfuscatedTicketAge) ||
len(psk.label) == 0 {
return false
}
m.pskIdentities = append(m.pskIdentities, psk)
}
var binders cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&binders) || binders.Empty() {
return false
}
for !binders.Empty() {
var binder []byte
if !readUint8LengthPrefixed(&binders, &binder) ||
len(binder) == 0 {
return false
}
m.pskBinders = append(m.pskBinders, binder)
}
default:
// Ignore unknown extensions.
continue
}
if !extData.Empty() {
return false
}
}
return true
}
type serverHelloMsg struct {
raw []byte
vers uint16
random []byte
sessionId []byte
cipherSuite uint16
compressionMethod uint8
ocspStapling bool
ticketSupported bool
secureRenegotiationSupported bool
secureRenegotiation []byte
alpnProtocol string
scts [][]byte
supportedVersion uint16
serverShare keyShare
selectedIdentityPresent bool
selectedIdentity uint16
supportedPoints []uint8
// HelloRetryRequest extensions
cookie []byte
selectedGroup CurveID
}
func (m *serverHelloMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var exts cryptobyte.Builder
if m.ocspStapling {
exts.AddUint16(extensionStatusRequest)
exts.AddUint16(0) // empty extension_data
}
if m.ticketSupported {
exts.AddUint16(extensionSessionTicket)
exts.AddUint16(0) // empty extension_data
}
if m.secureRenegotiationSupported {
exts.AddUint16(extensionRenegotiationInfo)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes(m.secureRenegotiation)
})
})
}
if len(m.alpnProtocol) > 0 {
exts.AddUint16(extensionALPN)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes([]byte(m.alpnProtocol))
})
})
})
}
if len(m.scts) > 0 {
exts.AddUint16(extensionSCT)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
for _, sct := range m.scts {
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes(sct)
})
}
})
})
}
if m.supportedVersion != 0 {
exts.AddUint16(extensionSupportedVersions)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint16(m.supportedVersion)
})
}
if m.serverShare.group != 0 {
exts.AddUint16(extensionKeyShare)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint16(uint16(m.serverShare.group))
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes(m.serverShare.data)
})
})
}
if m.selectedIdentityPresent {
exts.AddUint16(extensionPreSharedKey)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint16(m.selectedIdentity)
})
}
if len(m.cookie) > 0 {
exts.AddUint16(extensionCookie)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes(m.cookie)
})
})
}
if m.selectedGroup != 0 {
exts.AddUint16(extensionKeyShare)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint16(uint16(m.selectedGroup))
})
}
if len(m.supportedPoints) > 0 {
exts.AddUint16(extensionSupportedPoints)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes(m.supportedPoints)
})
})
}
extBytes, err := exts.Bytes()
if err != nil {
return nil, err
}
var b cryptobyte.Builder
b.AddUint8(typeServerHello)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint16(m.vers)
addBytesWithLength(b, m.random, 32)
b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.sessionId)
})
b.AddUint16(m.cipherSuite)
b.AddUint8(m.compressionMethod)
if len(extBytes) > 0 {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(extBytes)
})
}
})
m.raw, err = b.Bytes()
return m.raw, err
}
func (m *serverHelloMsg) unmarshal(data []byte) bool {
*m = serverHelloMsg{raw: data}
s := cryptobyte.String(data)
if !s.Skip(4) || // message type and uint24 length field
!s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) ||
!readUint8LengthPrefixed(&s, &m.sessionId) ||
!s.ReadUint16(&m.cipherSuite) ||
!s.ReadUint8(&m.compressionMethod) {
return false
}
if s.Empty() {
// ServerHello is optionally followed by extension data
return true
}
var extensions cryptobyte.String
if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
return false
}
seenExts := make(map[uint16]bool)
for !extensions.Empty() {
var extension uint16
var extData cryptobyte.String
if !extensions.ReadUint16(&extension) ||
!extensions.ReadUint16LengthPrefixed(&extData) {
return false
}
if seenExts[extension] {
return false
}
seenExts[extension] = true
switch extension {
case extensionStatusRequest:
m.ocspStapling = true
case extensionSessionTicket:
m.ticketSupported = true
case extensionRenegotiationInfo:
if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) {
return false
}
m.secureRenegotiationSupported = true
case extensionALPN:
var protoList cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
return false
}
var proto cryptobyte.String
if !protoList.ReadUint8LengthPrefixed(&proto) ||
proto.Empty() || !protoList.Empty() {
return false
}
m.alpnProtocol = string(proto)
case extensionSCT:
var sctList cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() {
return false
}
for !sctList.Empty() {
var sct []byte
if !readUint16LengthPrefixed(&sctList, &sct) ||
len(sct) == 0 {
return false
}
m.scts = append(m.scts, sct)
}
case extensionSupportedVersions:
if !extData.ReadUint16(&m.supportedVersion) {
return false
}
case extensionCookie:
if !readUint16LengthPrefixed(&extData, &m.cookie) ||
len(m.cookie) == 0 {
return false
}
case extensionKeyShare:
// This extension has different formats in SH and HRR; accept either
// and let the handshake logic decide. See RFC 8446, Section 4.2.8.
if len(extData) == 2 {
if !extData.ReadUint16((*uint16)(&m.selectedGroup)) {
return false
}
} else {
if !extData.ReadUint16((*uint16)(&m.serverShare.group)) ||
!readUint16LengthPrefixed(&extData, &m.serverShare.data) {
return false
}
}
case extensionPreSharedKey:
m.selectedIdentityPresent = true
if !extData.ReadUint16(&m.selectedIdentity) {
return false
}
case extensionSupportedPoints:
// RFC 4492, Section 5.1.2
if !readUint8LengthPrefixed(&extData, &m.supportedPoints) ||
len(m.supportedPoints) == 0 {
return false
}
default:
// Ignore unknown extensions.
continue
}
if !extData.Empty() {
return false
}
}
return true
}
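// parseRawExtensions is a minimal illustrative sketch, not used by the
// handshake code and not part of the original API: it shows the
// extension-walking pattern shared by the unmarshal functions in this
// file. Each extension is a uint16 type followed by a
// uint16-length-prefixed body, and, as in serverHelloMsg.unmarshal
// above, duplicates are rejected.
func parseRawExtensions(exts cryptobyte.String) (map[uint16][]byte, bool) {
parsed := make(map[uint16][]byte)
for !exts.Empty() {
var extension uint16
var extData cryptobyte.String
if !exts.ReadUint16(&extension) ||
!exts.ReadUint16LengthPrefixed(&extData) {
return nil, false
}
if _, ok := parsed[extension]; ok {
// Reject duplicate extensions, mirroring the seenExts check above.
return nil, false
}
parsed[extension] = []byte(extData)
}
return parsed, true
}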
type encryptedExtensionsMsg struct {
raw []byte
alpnProtocol string
}
func (m *encryptedExtensionsMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var b cryptobyte.Builder
b.AddUint8(typeEncryptedExtensions)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
if len(m.alpnProtocol) > 0 {
b.AddUint16(extensionALPN)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes([]byte(m.alpnProtocol))
})
})
})
}
})
})
var err error
m.raw, err = b.Bytes()
return m.raw, err
}
func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
*m = encryptedExtensionsMsg{raw: data}
s := cryptobyte.String(data)
var extensions cryptobyte.String
if !s.Skip(4) || // message type and uint24 length field
!s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
return false
}
for !extensions.Empty() {
var extension uint16
var extData cryptobyte.String
if !extensions.ReadUint16(&extension) ||
!extensions.ReadUint16LengthPrefixed(&extData) {
return false
}
switch extension {
case extensionALPN:
var protoList cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
return false
}
var proto cryptobyte.String
if !protoList.ReadUint8LengthPrefixed(&proto) ||
proto.Empty() || !protoList.Empty() {
return false
}
m.alpnProtocol = string(proto)
default:
// Ignore unknown extensions.
continue
}
if !extData.Empty() {
return false
}
}
return true
}
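// endOfEarlyDataMsg is the TLS 1.3 EndOfEarlyData message (RFC 8446,
// Section 4.5). It has an empty body, so marshal emits only the
// four-byte handshake header.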
type endOfEarlyDataMsg struct{}
func (m *endOfEarlyDataMsg) marshal() ([]byte, error) {
x := make([]byte, 4)
x[0] = typeEndOfEarlyData
return x, nil
}
func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool {
return len(data) == 4
}
type keyUpdateMsg struct {
raw []byte
updateRequested bool
}
func (m *keyUpdateMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var b cryptobyte.Builder
b.AddUint8(typeKeyUpdate)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
if m.updateRequested {
b.AddUint8(1)
} else {
b.AddUint8(0)
}
})
var err error
m.raw, err = b.Bytes()
return m.raw, err
}
func (m *keyUpdateMsg) unmarshal(data []byte) bool {
m.raw = data
s := cryptobyte.String(data)
var updateRequested uint8
if !s.Skip(4) || // message type and uint24 length field
!s.ReadUint8(&updateRequested) || !s.Empty() {
return false
}
switch updateRequested {
case 0:
m.updateRequested = false
case 1:
m.updateRequested = true
default:
return false
}
return true
}
type newSessionTicketMsgTLS13 struct {
raw []byte
lifetime uint32
ageAdd uint32
nonce []byte
label []byte
maxEarlyData uint32
}
func (m *newSessionTicketMsgTLS13) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var b cryptobyte.Builder
b.AddUint8(typeNewSessionTicket)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint32(m.lifetime)
b.AddUint32(m.ageAdd)
b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.nonce)
})
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.label)
})
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
if m.maxEarlyData > 0 {
b.AddUint16(extensionEarlyData)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint32(m.maxEarlyData)
})
}
})
})
var err error
m.raw, err = b.Bytes()
return m.raw, err
}
func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool {
*m = newSessionTicketMsgTLS13{raw: data}
s := cryptobyte.String(data)
var extensions cryptobyte.String
if !s.Skip(4) || // message type and uint24 length field
!s.ReadUint32(&m.lifetime) ||
!s.ReadUint32(&m.ageAdd) ||
!readUint8LengthPrefixed(&s, &m.nonce) ||
!readUint16LengthPrefixed(&s, &m.label) ||
!s.ReadUint16LengthPrefixed(&extensions) ||
!s.Empty() {
return false
}
for !extensions.Empty() {
var extension uint16
var extData cryptobyte.String
if !extensions.ReadUint16(&extension) ||
!extensions.ReadUint16LengthPrefixed(&extData) {
return false
}
switch extension {
case extensionEarlyData:
if !extData.ReadUint32(&m.maxEarlyData) {
return false
}
default:
// Ignore unknown extensions.
continue
}
if !extData.Empty() {
return false
}
}
return true
}
type certificateRequestMsgTLS13 struct {
raw []byte
ocspStapling bool
scts bool
supportedSignatureAlgorithms []SignatureScheme
supportedSignatureAlgorithmsCert []SignatureScheme
certificateAuthorities [][]byte
}
func (m *certificateRequestMsgTLS13) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var b cryptobyte.Builder
b.AddUint8(typeCertificateRequest)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
// certificate_request_context (SHALL be zero length unless used for
// post-handshake authentication)
b.AddUint8(0)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
if m.ocspStapling {
b.AddUint16(extensionStatusRequest)
b.AddUint16(0) // empty extension_data
}
if m.scts {
// RFC 8446, Section 4.4.2.1 makes no mention of
// signed_certificate_timestamp in CertificateRequest, but
// "Extensions in the Certificate message from the client MUST
// correspond to extensions in the CertificateRequest message
// from the server." and it appears in the table in Section 4.2.
b.AddUint16(extensionSCT)
b.AddUint16(0) // empty extension_data
}
if len(m.supportedSignatureAlgorithms) > 0 {
b.AddUint16(extensionSignatureAlgorithms)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
for _, sigAlgo := range m.supportedSignatureAlgorithms {
b.AddUint16(uint16(sigAlgo))
}
})
})
}
if len(m.supportedSignatureAlgorithmsCert) > 0 {
b.AddUint16(extensionSignatureAlgorithmsCert)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
for _, sigAlgo := range m.supportedSignatureAlgorithmsCert {
b.AddUint16(uint16(sigAlgo))
}
})
})
}
if len(m.certificateAuthorities) > 0 {
b.AddUint16(extensionCertificateAuthorities)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
for _, ca := range m.certificateAuthorities {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(ca)
})
}
})
})
}
})
})
var err error
m.raw, err = b.Bytes()
return m.raw, err
}
func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool {
*m = certificateRequestMsgTLS13{raw: data}
s := cryptobyte.String(data)
var context, extensions cryptobyte.String
if !s.Skip(4) || // message type and uint24 length field
!s.ReadUint8LengthPrefixed(&context) || !context.Empty() ||
!s.ReadUint16LengthPrefixed(&extensions) ||
!s.Empty() {
return false
}
for !extensions.Empty() {
var extension uint16
var extData cryptobyte.String
if !extensions.ReadUint16(&extension) ||
!extensions.ReadUint16LengthPrefixed(&extData) {
return false
}
switch extension {
case extensionStatusRequest:
m.ocspStapling = true
case extensionSCT:
m.scts = true
case extensionSignatureAlgorithms:
var sigAndAlgs cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
return false
}
for !sigAndAlgs.Empty() {
var sigAndAlg uint16
if !sigAndAlgs.ReadUint16(&sigAndAlg) {
return false
}
m.supportedSignatureAlgorithms = append(
m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg))
}
case extensionSignatureAlgorithmsCert:
var sigAndAlgs cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
return false
}
for !sigAndAlgs.Empty() {
var sigAndAlg uint16
if !sigAndAlgs.ReadUint16(&sigAndAlg) {
return false
}
m.supportedSignatureAlgorithmsCert = append(
m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg))
}
case extensionCertificateAuthorities:
var auths cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&auths) || auths.Empty() {
return false
}
for !auths.Empty() {
var ca []byte
if !readUint16LengthPrefixed(&auths, &ca) || len(ca) == 0 {
return false
}
m.certificateAuthorities = append(m.certificateAuthorities, ca)
}
default:
// Ignore unknown extensions.
continue
}
if !extData.Empty() {
return false
}
}
return true
}
type certificateMsg struct {
raw []byte
certificates [][]byte
}
func (m *certificateMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var i int
for _, slice := range m.certificates {
i += len(slice)
}
length := 3 + 3*len(m.certificates) + i
x := make([]byte, 4+length)
x[0] = typeCertificate
x[1] = uint8(length >> 16)
x[2] = uint8(length >> 8)
x[3] = uint8(length)
certificateOctets := length - 3
x[4] = uint8(certificateOctets >> 16)
x[5] = uint8(certificateOctets >> 8)
x[6] = uint8(certificateOctets)
y := x[7:]
for _, slice := range m.certificates {
y[0] = uint8(len(slice) >> 16)
y[1] = uint8(len(slice) >> 8)
y[2] = uint8(len(slice))
copy(y[3:], slice)
y = y[3+len(slice):]
}
m.raw = x
return m.raw, nil
}
func (m *certificateMsg) unmarshal(data []byte) bool {
if len(data) < 7 {
return false
}
m.raw = data
certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
if uint32(len(data)) != certsLen+7 {
return false
}
numCerts := 0
d := data[7:]
for certsLen > 0 {
if len(d) < 4 {
return false
}
certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
if uint32(len(d)) < 3+certLen {
return false
}
d = d[3+certLen:]
certsLen -= 3 + certLen
numCerts++
}
m.certificates = make([][]byte, numCerts)
d = data[7:]
for i := 0; i < numCerts; i++ {
certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
m.certificates[i] = d[3 : 3+certLen]
d = d[3+certLen:]
}
return true
}
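// putUint24Example is an illustrative helper (hypothetical, not part of
// this package) making the 3-byte big-endian length encoding explicit;
// certificateMsg.marshal above inlines the same shifts for both the
// message body and each certificate entry.
func putUint24Example(b []byte, v int) {
b[0] = uint8(v >> 16)
b[1] = uint8(v >> 8)
b[2] = uint8(v)
}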
type certificateMsgTLS13 struct {
raw []byte
certificate Certificate
ocspStapling bool
scts bool
}
func (m *certificateMsgTLS13) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var b cryptobyte.Builder
b.AddUint8(typeCertificate)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint8(0) // certificate_request_context
certificate := m.certificate
if !m.ocspStapling {
certificate.OCSPStaple = nil
}
if !m.scts {
certificate.SignedCertificateTimestamps = nil
}
marshalCertificate(b, certificate)
})
var err error
m.raw, err = b.Bytes()
return m.raw, err
}
func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) {
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
for i, cert := range certificate.Certificate {
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(cert)
})
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
if i > 0 {
// This library only supports OCSP and SCT for leaf certificates.
return
}
if certificate.OCSPStaple != nil {
b.AddUint16(extensionStatusRequest)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint8(statusTypeOCSP)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(certificate.OCSPStaple)
})
})
}
if certificate.SignedCertificateTimestamps != nil {
b.AddUint16(extensionSCT)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
for _, sct := range certificate.SignedCertificateTimestamps {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(sct)
})
}
})
})
}
})
}
})
}
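// For reference, marshalCertificate emits the RFC 8446, Section 4.4.2
// layout: a uint24-length-prefixed certificate_list in which each
// CertificateEntry is
//
//	opaque cert_data<1..2^24-1>;
//	Extension extensions<0..2^16-1>;
//
// with the status_request and SCT extensions present only on the first
// (leaf) entry.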
func (m *certificateMsgTLS13) unmarshal(data []byte) bool {
*m = certificateMsgTLS13{raw: data}
s := cryptobyte.String(data)
var context cryptobyte.String
if !s.Skip(4) || // message type and uint24 length field
!s.ReadUint8LengthPrefixed(&context) || !context.Empty() ||
!unmarshalCertificate(&s, &m.certificate) ||
!s.Empty() {
return false
}
m.scts = m.certificate.SignedCertificateTimestamps != nil
m.ocspStapling = m.certificate.OCSPStaple != nil
return true
}
func unmarshalCertificate(s *cryptobyte.String, certificate *Certificate) bool {
var certList cryptobyte.String
if !s.ReadUint24LengthPrefixed(&certList) {
return false
}
for !certList.Empty() {
var cert []byte
var extensions cryptobyte.String
if !readUint24LengthPrefixed(&certList, &cert) ||
!certList.ReadUint16LengthPrefixed(&extensions) {
return false
}
certificate.Certificate = append(certificate.Certificate, cert)
for !extensions.Empty() {
var extension uint16
var extData cryptobyte.String
if !extensions.ReadUint16(&extension) ||
!extensions.ReadUint16LengthPrefixed(&extData) {
return false
}
if len(certificate.Certificate) > 1 {
// This library only supports OCSP and SCT for leaf certificates.
continue
}
switch extension {
case extensionStatusRequest:
var statusType uint8
if !extData.ReadUint8(&statusType) || statusType != statusTypeOCSP ||
!readUint24LengthPrefixed(&extData, &certificate.OCSPStaple) ||
len(certificate.OCSPStaple) == 0 {
return false
}
case extensionSCT:
var sctList cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() {
return false
}
for !sctList.Empty() {
var sct []byte
if !readUint16LengthPrefixed(&sctList, &sct) ||
len(sct) == 0 {
return false
}
certificate.SignedCertificateTimestamps = append(
certificate.SignedCertificateTimestamps, sct)
}
default:
// Ignore unknown extensions.
continue
}
if !extData.Empty() {
return false
}
}
}
return true
}
type serverKeyExchangeMsg struct {
raw []byte
key []byte
}
func (m *serverKeyExchangeMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
length := len(m.key)
x := make([]byte, length+4)
x[0] = typeServerKeyExchange
x[1] = uint8(length >> 16)
x[2] = uint8(length >> 8)
x[3] = uint8(length)
copy(x[4:], m.key)
m.raw = x
return x, nil
}
func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
m.raw = data
if len(data) < 4 {
return false
}
m.key = data[4:]
return true
}
type certificateStatusMsg struct {
raw []byte
response []byte
}
func (m *certificateStatusMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var b cryptobyte.Builder
b.AddUint8(typeCertificateStatus)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint8(statusTypeOCSP)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.response)
})
})
var err error
m.raw, err = b.Bytes()
return m.raw, err
}
func (m *certificateStatusMsg) unmarshal(data []byte) bool {
m.raw = data
s := cryptobyte.String(data)
var statusType uint8
if !s.Skip(4) || // message type and uint24 length field
!s.ReadUint8(&statusType) || statusType != statusTypeOCSP ||
!readUint24LengthPrefixed(&s, &m.response) ||
len(m.response) == 0 || !s.Empty() {
return false
}
return true
}
type serverHelloDoneMsg struct{}
func (m *serverHelloDoneMsg) marshal() ([]byte, error) {
x := make([]byte, 4)
x[0] = typeServerHelloDone
return x, nil
}
func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
return len(data) == 4
}
type clientKeyExchangeMsg struct {
raw []byte
ciphertext []byte
}
func (m *clientKeyExchangeMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
length := len(m.ciphertext)
x := make([]byte, length+4)
x[0] = typeClientKeyExchange
x[1] = uint8(length >> 16)
x[2] = uint8(length >> 8)
x[3] = uint8(length)
copy(x[4:], m.ciphertext)
m.raw = x
return x, nil
}
func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
m.raw = data
if len(data) < 4 {
return false
}
l := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
if l != len(data)-4 {
return false
}
m.ciphertext = data[4:]
return true
}
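// finishedMsg is the Finished message: verifyData is computed over the
// handshake transcript, using the PRF in TLS 1.2 (RFC 5246, Section
// 7.4.9) and an HMAC in TLS 1.3 (RFC 8446, Section 4.4.4).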
type finishedMsg struct {
raw []byte
verifyData []byte
}
func (m *finishedMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var b cryptobyte.Builder
b.AddUint8(typeFinished)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.verifyData)
})
var err error
m.raw, err = b.Bytes()
return m.raw, err
}
func (m *finishedMsg) unmarshal(data []byte) bool {
m.raw = data
s := cryptobyte.String(data)
return s.Skip(1) &&
readUint24LengthPrefixed(&s, &m.verifyData) &&
s.Empty()
}
type certificateRequestMsg struct {
raw []byte
// hasSignatureAlgorithm indicates whether this message includes a list of
// supported signature algorithms. This change was introduced with TLS 1.2.
hasSignatureAlgorithm bool
certificateTypes []byte
supportedSignatureAlgorithms []SignatureScheme
certificateAuthorities [][]byte
}
func (m *certificateRequestMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
// See RFC 4346, Section 7.4.4.
length := 1 + len(m.certificateTypes) + 2
casLength := 0
for _, ca := range m.certificateAuthorities {
casLength += 2 + len(ca)
}
length += casLength
if m.hasSignatureAlgorithm {
length += 2 + 2*len(m.supportedSignatureAlgorithms)
}
x := make([]byte, 4+length)
x[0] = typeCertificateRequest
x[1] = uint8(length >> 16)
x[2] = uint8(length >> 8)
x[3] = uint8(length)
x[4] = uint8(len(m.certificateTypes))
copy(x[5:], m.certificateTypes)
y := x[5+len(m.certificateTypes):]
if m.hasSignatureAlgorithm {
n := len(m.supportedSignatureAlgorithms) * 2
y[0] = uint8(n >> 8)
y[1] = uint8(n)
y = y[2:]
for _, sigAlgo := range m.supportedSignatureAlgorithms {
y[0] = uint8(sigAlgo >> 8)
y[1] = uint8(sigAlgo)
y = y[2:]
}
}
y[0] = uint8(casLength >> 8)
y[1] = uint8(casLength)
y = y[2:]
for _, ca := range m.certificateAuthorities {
y[0] = uint8(len(ca) >> 8)
y[1] = uint8(len(ca))
y = y[2:]
copy(y, ca)
y = y[len(ca):]
}
m.raw = x
return m.raw, nil
}
func (m *certificateRequestMsg) unmarshal(data []byte) bool {
m.raw = data
if len(data) < 5 {
return false
}
length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
if uint32(len(data))-4 != length {
return false
}
numCertTypes := int(data[4])
data = data[5:]
if numCertTypes == 0 || len(data) <= numCertTypes {
return false
}
m.certificateTypes = make([]byte, numCertTypes)
if copy(m.certificateTypes, data) != numCertTypes {
return false
}
data = data[numCertTypes:]
if m.hasSignatureAlgorithm {
if len(data) < 2 {
return false
}
sigAndHashLen := uint16(data[0])<<8 | uint16(data[1])
data = data[2:]
if sigAndHashLen&1 != 0 {
return false
}
if len(data) < int(sigAndHashLen) {
return false
}
numSigAlgos := sigAndHashLen / 2
m.supportedSignatureAlgorithms = make([]SignatureScheme, numSigAlgos)
for i := range m.supportedSignatureAlgorithms {
m.supportedSignatureAlgorithms[i] = SignatureScheme(data[0])<<8 | SignatureScheme(data[1])
data = data[2:]
}
}
if len(data) < 2 {
return false
}
casLength := uint16(data[0])<<8 | uint16(data[1])
data = data[2:]
if len(data) < int(casLength) {
return false
}
cas := make([]byte, casLength)
copy(cas, data)
data = data[casLength:]
m.certificateAuthorities = nil
for len(cas) > 0 {
if len(cas) < 2 {
return false
}
caLen := uint16(cas[0])<<8 | uint16(cas[1])
cas = cas[2:]
if len(cas) < int(caLen) {
return false
}
m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
cas = cas[caLen:]
}
return len(data) == 0
}
type certificateVerifyMsg struct {
raw []byte
hasSignatureAlgorithm bool // format change introduced in TLS 1.2
signatureAlgorithm SignatureScheme
signature []byte
}
func (m *certificateVerifyMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var b cryptobyte.Builder
b.AddUint8(typeCertificateVerify)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
if m.hasSignatureAlgorithm {
b.AddUint16(uint16(m.signatureAlgorithm))
}
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.signature)
})
})
var err error
m.raw, err = b.Bytes()
return m.raw, err
}
func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
m.raw = data
s := cryptobyte.String(data)
if !s.Skip(4) { // message type and uint24 length field
return false
}
if m.hasSignatureAlgorithm {
if !s.ReadUint16((*uint16)(&m.signatureAlgorithm)) {
return false
}
}
return readUint16LengthPrefixed(&s, &m.signature) && s.Empty()
}
type newSessionTicketMsg struct {
raw []byte
ticket []byte
}
func (m *newSessionTicketMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
// See RFC 5077, Section 3.3.
ticketLen := len(m.ticket)
length := 2 + 4 + ticketLen
x := make([]byte, 4+length)
x[0] = typeNewSessionTicket
x[1] = uint8(length >> 16)
x[2] = uint8(length >> 8)
x[3] = uint8(length)
x[8] = uint8(ticketLen >> 8)
x[9] = uint8(ticketLen)
copy(x[10:], m.ticket)
m.raw = x
return m.raw, nil
}
func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
m.raw = data
if len(data) < 10 {
return false
}
length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
if uint32(len(data))-4 != length {
return false
}
ticketLen := int(data[8])<<8 + int(data[9])
if len(data)-10 != ticketLen {
return false
}
m.ticket = data[10:]
return true
}
type helloRequestMsg struct {
}
func (*helloRequestMsg) marshal() ([]byte, error) {
return []byte{typeHelloRequest, 0, 0, 0}, nil
}
func (*helloRequestMsg) unmarshal(data []byte) bool {
return len(data) == 4
}
type transcriptHash interface {
Write([]byte) (int, error)
}
// transcriptMsg is a helper used to marshal and hash messages that typically
// are not written to the wire, and as such aren't hashed during Conn.writeRecord.
func transcriptMsg(msg handshakeMessage, h transcriptHash) error {
data, err := msg.marshal()
if err != nil {
return err
}
h.Write(data)
return nil
}
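// Illustrative usage, assuming a SHA-256 transcript (the hash used by
// most TLS 1.3 suites in this package):
//
//	transcript := sha256.New()
//	if err := transcriptMsg(msg, transcript); err != nil {
//		return err
//	}
//
// which is equivalent to marshaling msg and writing the bytes to the
// hash by hand.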
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/subtle"
"crypto/x509"
"errors"
"fmt"
"hash"
"io"
"time"
)
// serverHandshakeState contains details of a server handshake in progress.
// It's discarded once the handshake has completed.
type serverHandshakeState struct {
c *Conn
ctx context.Context
clientHello *clientHelloMsg
hello *serverHelloMsg
suite *cipherSuite
ecdheOk bool
ecSignOk bool
rsaDecryptOk bool
rsaSignOk bool
sessionState *sessionState
finishedHash finishedHash
masterSecret []byte
cert *Certificate
}
// serverHandshake performs a TLS handshake as a server.
func (c *Conn) serverHandshake(ctx context.Context) error {
clientHello, err := c.readClientHello(ctx)
if err != nil {
return err
}
if c.vers == VersionTLS13 {
hs := serverHandshakeStateTLS13{
c: c,
ctx: ctx,
clientHello: clientHello,
}
return hs.handshake()
}
hs := serverHandshakeState{
c: c,
ctx: ctx,
clientHello: clientHello,
}
return hs.handshake()
}
func (hs *serverHandshakeState) handshake() error {
c := hs.c
if err := hs.processClientHello(); err != nil {
return err
}
// For an overview of TLS handshaking, see RFC 5246, Section 7.3.
c.buffering = true
if hs.checkForResumption() {
// The client has included a session ticket and so we do an abbreviated handshake.
c.didResume = true
if err := hs.doResumeHandshake(); err != nil {
return err
}
if err := hs.establishKeys(); err != nil {
return err
}
if err := hs.sendSessionTicket(); err != nil {
return err
}
if err := hs.sendFinished(c.serverFinished[:]); err != nil {
return err
}
if _, err := c.flush(); err != nil {
return err
}
c.clientFinishedIsFirst = false
if err := hs.readFinished(nil); err != nil {
return err
}
} else {
// The client didn't include a session ticket, or it wasn't
// valid, so we do a full handshake.
if err := hs.pickCipherSuite(); err != nil {
return err
}
if err := hs.doFullHandshake(); err != nil {
return err
}
if err := hs.establishKeys(); err != nil {
return err
}
if err := hs.readFinished(c.clientFinished[:]); err != nil {
return err
}
c.clientFinishedIsFirst = true
c.buffering = true
if err := hs.sendSessionTicket(); err != nil {
return err
}
if err := hs.sendFinished(nil); err != nil {
return err
}
if _, err := c.flush(); err != nil {
return err
}
}
c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random)
c.isHandshakeComplete.Store(true)
return nil
}
// readClientHello reads a ClientHello message and selects the protocol version.
func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) {
// clientHelloMsg is included in the transcript, but we haven't initialized
// it yet. The respective handshake functions will record it themselves.
msg, err := c.readHandshake(nil)
if err != nil {
return nil, err
}
clientHello, ok := msg.(*clientHelloMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return nil, unexpectedMessageError(clientHello, msg)
}
var configForClient *Config
originalConfig := c.config
if c.config.GetConfigForClient != nil {
chi := clientHelloInfo(ctx, c, clientHello)
if configForClient, err = c.config.GetConfigForClient(chi); err != nil {
c.sendAlert(alertInternalError)
return nil, err
} else if configForClient != nil {
c.config = configForClient
}
}
c.ticketKeys = originalConfig.ticketKeys(configForClient)
clientVersions := clientHello.supportedVersions
if len(clientHello.supportedVersions) == 0 {
clientVersions = supportedVersionsFromMax(clientHello.vers)
}
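// For example, a pre-TLS 1.3 ClientHello carries no supported_versions
// extension, only the legacy version field: for clientHello.vers ==
// VersionTLS12, supportedVersionsFromMax synthesizes the list of
// supported versions at or below TLS 1.2 for mutualVersion below to
// intersect with the server's configuration.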
c.vers, ok = c.config.mutualVersion(roleServer, clientVersions)
if !ok {
c.sendAlert(alertProtocolVersion)
return nil, fmt.Errorf("tls: client offered only unsupported versions: %x", clientVersions)
}
c.haveVers = true
c.in.version = c.vers
c.out.version = c.vers
return clientHello, nil
}
func (hs *serverHandshakeState) processClientHello() error {
c := hs.c
hs.hello = new(serverHelloMsg)
hs.hello.vers = c.vers
foundCompression := false
// We only support null compression, so check that the client offered it.
for _, compression := range hs.clientHello.compressionMethods {
if compression == compressionNone {
foundCompression = true
break
}
}
if !foundCompression {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: client does not support uncompressed connections")
}
hs.hello.random = make([]byte, 32)
serverRandom := hs.hello.random
// Downgrade protection canaries. See RFC 8446, Section 4.1.3.
maxVers := c.config.maxSupportedVersion(roleServer)
if maxVers >= VersionTLS12 && c.vers < maxVers || testingOnlyForceDowngradeCanary {
if c.vers == VersionTLS12 {
copy(serverRandom[24:], downgradeCanaryTLS12)
} else {
copy(serverRandom[24:], downgradeCanaryTLS11)
}
serverRandom = serverRandom[:24]
}
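// For example, a server that supports up to TLS 1.3 but negotiates
// TLS 1.2 ends its 32-byte random with the sentinel "DOWNGRD\x01"; a
// TLS 1.3 client that finds the sentinel aborts the handshake,
// defeating downgrade attacks on version negotiation.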
_, err := io.ReadFull(c.config.rand(), serverRandom)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
if len(hs.clientHello.secureRenegotiation) != 0 {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: initial handshake had non-empty renegotiation extension")
}
hs.hello.secureRenegotiationSupported = hs.clientHello.secureRenegotiationSupported
hs.hello.compressionMethod = compressionNone
if len(hs.clientHello.serverName) > 0 {
c.serverName = hs.clientHello.serverName
}
selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols)
if err != nil {
c.sendAlert(alertNoApplicationProtocol)
return err
}
hs.hello.alpnProtocol = selectedProto
c.clientProtocol = selectedProto
hs.cert, err = c.config.getCertificate(clientHelloInfo(hs.ctx, c, hs.clientHello))
if err != nil {
if err == errNoCertificates {
c.sendAlert(alertUnrecognizedName)
} else {
c.sendAlert(alertInternalError)
}
return err
}
if hs.clientHello.scts {
hs.hello.scts = hs.cert.SignedCertificateTimestamps
}
hs.ecdheOk = supportsECDHE(c.config, hs.clientHello.supportedCurves, hs.clientHello.supportedPoints)
if hs.ecdheOk && len(hs.clientHello.supportedPoints) > 0 {
// Although omitting the ec_point_formats extension is permitted, some
// old OpenSSL versions will refuse to handshake if it is not present.
//
// Per RFC 4492, section 5.1.2, implementations MUST support the
// uncompressed point format. See golang.org/issue/31943.
hs.hello.supportedPoints = []uint8{pointFormatUncompressed}
}
if priv, ok := hs.cert.PrivateKey.(crypto.Signer); ok {
switch priv.Public().(type) {
case *ecdsa.PublicKey:
hs.ecSignOk = true
case ed25519.PublicKey:
hs.ecSignOk = true
case *rsa.PublicKey:
hs.rsaSignOk = true
default:
c.sendAlert(alertInternalError)
return fmt.Errorf("tls: unsupported signing key type (%T)", priv.Public())
}
}
if priv, ok := hs.cert.PrivateKey.(crypto.Decrypter); ok {
switch priv.Public().(type) {
case *rsa.PublicKey:
hs.rsaDecryptOk = true
default:
c.sendAlert(alertInternalError)
return fmt.Errorf("tls: unsupported decryption key type (%T)", priv.Public())
}
}
return nil
}
// negotiateALPN picks a shared ALPN protocol that both sides support in server
// preference order. If ALPN is not configured or the peer doesn't support it,
// it returns "" and no error.
func negotiateALPN(serverProtos, clientProtos []string) (string, error) {
if len(serverProtos) == 0 || len(clientProtos) == 0 {
return "", nil
}
var http11fallback bool
for _, s := range serverProtos {
for _, c := range clientProtos {
if s == c {
return s, nil
}
if s == "h2" && c == "http/1.1" {
http11fallback = true
}
}
}
// As a special case, let http/1.1 clients connect to h2 servers as if they
// didn't support ALPN. We historically did not enforce protocol overlap, so
// over time a number of HTTP servers were configured with only "h2", but
// were still expected to accept connections from "http/1.1" clients. See
// Issue 46310.
if http11fallback {
return "", nil
}
return "", fmt.Errorf("tls: client requested unsupported application protocols (%s)", clientProtos)
}
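// Worked example: with serverProtos = []string{"h2", "http/1.1"} and
// clientProtos = []string{"http/1.1", "h2"}, the server's preference
// wins and negotiateALPN returns "h2". With serverProtos =
// []string{"h2"} and clientProtos = []string{"http/1.1"}, the special
// case above returns "" and no error instead of failing the handshake.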
// supportsECDHE returns whether ECDHE key exchanges can be used with this
// pre-TLS 1.3 client.
func supportsECDHE(c *Config, supportedCurves []CurveID, supportedPoints []uint8) bool {
supportsCurve := false
for _, curve := range supportedCurves {
if c.supportsCurve(curve) {
supportsCurve = true
break
}
}
supportsPointFormat := false
for _, pointFormat := range supportedPoints {
if pointFormat == pointFormatUncompressed {
supportsPointFormat = true
break
}
}
// Per RFC 8422, Section 5.1.2, if the Supported Point Formats extension is
// missing, uncompressed points are supported. If supportedPoints is empty,
// the extension must be missing, as an empty extension body is rejected by
// the parser. See https://go.dev/issue/49126.
if len(supportedPoints) == 0 {
supportsPointFormat = true
}
return supportsCurve && supportsPointFormat
}
func (hs *serverHandshakeState) pickCipherSuite() error {
c := hs.c
preferenceOrder := cipherSuitesPreferenceOrder
if !hasAESGCMHardwareSupport || !aesgcmPreferred(hs.clientHello.cipherSuites) {
preferenceOrder = cipherSuitesPreferenceOrderNoAES
}
configCipherSuites := c.config.cipherSuites()
preferenceList := make([]uint16, 0, len(configCipherSuites))
for _, suiteID := range preferenceOrder {
for _, id := range configCipherSuites {
if id == suiteID {
preferenceList = append(preferenceList, id)
break
}
}
}
hs.suite = selectCipherSuite(preferenceList, hs.clientHello.cipherSuites, hs.cipherSuiteOk)
if hs.suite == nil {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: no cipher suite supported by both client and server")
}
c.cipherSuite = hs.suite.id
for _, id := range hs.clientHello.cipherSuites {
if id == TLS_FALLBACK_SCSV {
// The client is doing a fallback connection. See RFC 7507.
if hs.clientHello.vers < c.config.maxSupportedVersion(roleServer) {
c.sendAlert(alertInappropriateFallback)
return errors.New("tls: client using inappropriate protocol fallback")
}
break
}
}
return nil
}
func (hs *serverHandshakeState) cipherSuiteOk(c *cipherSuite) bool {
if c.flags&suiteECDHE != 0 {
if !hs.ecdheOk {
return false
}
if c.flags&suiteECSign != 0 {
if !hs.ecSignOk {
return false
}
} else if !hs.rsaSignOk {
return false
}
} else if !hs.rsaDecryptOk {
return false
}
if hs.c.vers < VersionTLS12 && c.flags&suiteTLS12 != 0 {
return false
}
return true
}
// checkForResumption reports whether we should perform resumption on this connection.
func (hs *serverHandshakeState) checkForResumption() bool {
c := hs.c
if c.config.SessionTicketsDisabled {
return false
}
plaintext, usedOldKey := c.decryptTicket(hs.clientHello.sessionTicket)
if plaintext == nil {
return false
}
hs.sessionState = &sessionState{usedOldKey: usedOldKey}
ok := hs.sessionState.unmarshal(plaintext)
if !ok {
return false
}
createdAt := time.Unix(int64(hs.sessionState.createdAt), 0)
if c.config.time().Sub(createdAt) > maxSessionTicketLifetime {
return false
}
// Never resume a session for a different TLS version.
if c.vers != hs.sessionState.vers {
return false
}
cipherSuiteOk := false
// Check that the client is still offering the ciphersuite in the session.
for _, id := range hs.clientHello.cipherSuites {
if id == hs.sessionState.cipherSuite {
cipherSuiteOk = true
break
}
}
if !cipherSuiteOk {
return false
}
// Check that we also support the ciphersuite from the session.
hs.suite = selectCipherSuite([]uint16{hs.sessionState.cipherSuite},
c.config.cipherSuites(), hs.cipherSuiteOk)
if hs.suite == nil {
return false
}
sessionHasClientCerts := len(hs.sessionState.certificates) != 0
needClientCerts := requiresClientCert(c.config.ClientAuth)
if needClientCerts && !sessionHasClientCerts {
return false
}
if sessionHasClientCerts && c.config.ClientAuth == NoClientCert {
return false
}
return true
}
func (hs *serverHandshakeState) doResumeHandshake() error {
c := hs.c
hs.hello.cipherSuite = hs.suite.id
c.cipherSuite = hs.suite.id
// We echo the client's session ID in the ServerHello to let it know
// that we're doing a resumption.
hs.hello.sessionId = hs.clientHello.sessionId
hs.hello.ticketSupported = hs.sessionState.usedOldKey
hs.finishedHash = newFinishedHash(c.vers, hs.suite)
hs.finishedHash.discardHandshakeBuffer()
if err := transcriptMsg(hs.clientHello, &hs.finishedHash); err != nil {
return err
}
if _, err := hs.c.writeHandshakeRecord(hs.hello, &hs.finishedHash); err != nil {
return err
}
if err := c.processCertsFromClient(Certificate{
Certificate: hs.sessionState.certificates,
}); err != nil {
return err
}
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
hs.masterSecret = hs.sessionState.masterSecret
return nil
}
func (hs *serverHandshakeState) doFullHandshake() error {
c := hs.c
if hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 {
hs.hello.ocspStapling = true
}
hs.hello.ticketSupported = hs.clientHello.ticketSupported && !c.config.SessionTicketsDisabled
hs.hello.cipherSuite = hs.suite.id
hs.finishedHash = newFinishedHash(hs.c.vers, hs.suite)
if c.config.ClientAuth == NoClientCert {
// No need to keep a full record of the handshake if client
// certificates won't be used.
hs.finishedHash.discardHandshakeBuffer()
}
if err := transcriptMsg(hs.clientHello, &hs.finishedHash); err != nil {
return err
}
if _, err := hs.c.writeHandshakeRecord(hs.hello, &hs.finishedHash); err != nil {
return err
}
certMsg := new(certificateMsg)
certMsg.certificates = hs.cert.Certificate
if _, err := hs.c.writeHandshakeRecord(certMsg, &hs.finishedHash); err != nil {
return err
}
if hs.hello.ocspStapling {
certStatus := new(certificateStatusMsg)
certStatus.response = hs.cert.OCSPStaple
if _, err := hs.c.writeHandshakeRecord(certStatus, &hs.finishedHash); err != nil {
return err
}
}
keyAgreement := hs.suite.ka(c.vers)
skx, err := keyAgreement.generateServerKeyExchange(c.config, hs.cert, hs.clientHello, hs.hello)
if err != nil {
c.sendAlert(alertHandshakeFailure)
return err
}
if skx != nil {
if _, err := hs.c.writeHandshakeRecord(skx, &hs.finishedHash); err != nil {
return err
}
}
var certReq *certificateRequestMsg
if c.config.ClientAuth >= RequestClientCert {
// Request a client certificate
certReq = new(certificateRequestMsg)
certReq.certificateTypes = []byte{
byte(certTypeRSASign),
byte(certTypeECDSASign),
}
if c.vers >= VersionTLS12 {
certReq.hasSignatureAlgorithm = true
certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
}
// An empty certificateAuthorities list signals to the client that it
// may send any certificate in response to our request. When we know
// which CAs we trust, we can send that list down so that the client
// can choose an appropriate certificate to present.
if c.config.ClientCAs != nil {
certReq.certificateAuthorities = c.config.ClientCAs.Subjects()
}
if _, err := hs.c.writeHandshakeRecord(certReq, &hs.finishedHash); err != nil {
return err
}
}
helloDone := new(serverHelloDoneMsg)
if _, err := hs.c.writeHandshakeRecord(helloDone, &hs.finishedHash); err != nil {
return err
}
if _, err := c.flush(); err != nil {
return err
}
var pub crypto.PublicKey // public key for client auth, if any
msg, err := c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
// If we requested a client certificate, then the client must send a
// certificate message, even if it's empty.
if c.config.ClientAuth >= RequestClientCert {
certMsg, ok := msg.(*certificateMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certMsg, msg)
}
if err := c.processCertsFromClient(Certificate{
Certificate: certMsg.certificates,
}); err != nil {
return err
}
if len(certMsg.certificates) != 0 {
pub = c.peerCertificates[0].PublicKey
}
msg, err = c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
}
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
// Get client key exchange
ckx, ok := msg.(*clientKeyExchangeMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(ckx, msg)
}
preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers)
if err != nil {
c.sendAlert(alertHandshakeFailure)
return err
}
hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random)
if err := c.config.writeKeyLog(keyLogLabelTLS12, hs.clientHello.random, hs.masterSecret); err != nil {
c.sendAlert(alertInternalError)
return err
}
// If we received a client cert in response to our certificate request message,
// the client will send us a certificateVerifyMsg immediately after the
// clientKeyExchangeMsg. This message is a digest of all preceding
// handshake-layer messages that is signed using the private key corresponding
// to the client's certificate. This allows us to verify that the client is in
// possession of the private key of the certificate.
if len(c.peerCertificates) > 0 {
// certificateVerifyMsg is included in the transcript, but not until
// after we verify the handshake signature, since the state before
// this message was sent is used.
msg, err = c.readHandshake(nil)
if err != nil {
return err
}
certVerify, ok := msg.(*certificateVerifyMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certVerify, msg)
}
var sigType uint8
var sigHash crypto.Hash
if c.vers >= VersionTLS12 {
if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, certReq.supportedSignatureAlgorithms) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: client certificate used with invalid signature algorithm")
}
sigType, sigHash, err = typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm)
if err != nil {
return c.sendAlert(alertInternalError)
}
} else {
sigType, sigHash, err = legacyTypeAndHashFromPublicKey(pub)
if err != nil {
c.sendAlert(alertIllegalParameter)
return err
}
}
signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash)
if err := verifyHandshakeSignature(sigType, pub, sigHash, signed, certVerify.signature); err != nil {
c.sendAlert(alertDecryptError)
return errors.New("tls: invalid signature by the client certificate: " + err.Error())
}
if err := transcriptMsg(certVerify, &hs.finishedHash); err != nil {
return err
}
}
hs.finishedHash.discardHandshakeBuffer()
return nil
}
func (hs *serverHandshakeState) establishKeys() error {
c := hs.c
clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV :=
keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen)
var clientCipher, serverCipher any
var clientHash, serverHash hash.Hash
if hs.suite.aead == nil {
clientCipher = hs.suite.cipher(clientKey, clientIV, true /* for reading */)
clientHash = hs.suite.mac(clientMAC)
serverCipher = hs.suite.cipher(serverKey, serverIV, false /* not for reading */)
serverHash = hs.suite.mac(serverMAC)
} else {
clientCipher = hs.suite.aead(clientKey, clientIV)
serverCipher = hs.suite.aead(serverKey, serverIV)
}
c.in.prepareCipherSpec(c.vers, clientCipher, clientHash)
c.out.prepareCipherSpec(c.vers, serverCipher, serverHash)
return nil
}
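// As a concrete sketch of the key block split: for
// TLS_RSA_WITH_AES_128_CBC_SHA, macLen is 20 (SHA-1), keyLen is 16 and
// ivLen is 16, so keysFromMasterSecret carves 2*(20+16+16) = 104 bytes
// out of the PRF output: MAC keys, then cipher keys, then IVs, client
// before server in each pair (RFC 5246, Section 6.3).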
func (hs *serverHandshakeState) readFinished(out []byte) error {
c := hs.c
if err := c.readChangeCipherSpec(); err != nil {
return err
}
// finishedMsg is included in the transcript, but not until after we
// verify it, since the state before this message was sent is used
// during verification.
msg, err := c.readHandshake(nil)
if err != nil {
return err
}
clientFinished, ok := msg.(*finishedMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(clientFinished, msg)
}
verify := hs.finishedHash.clientSum(hs.masterSecret)
if len(verify) != len(clientFinished.verifyData) ||
subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: client's Finished message is incorrect")
}
if err := transcriptMsg(clientFinished, &hs.finishedHash); err != nil {
return err
}
copy(out, verify)
return nil
}
func (hs *serverHandshakeState) sendSessionTicket() error {
// ticketSupported is set in a resumption handshake if the
// ticket from the client was encrypted with an old session
// ticket key and thus a refreshed ticket should be sent.
if !hs.hello.ticketSupported {
return nil
}
c := hs.c
m := new(newSessionTicketMsg)
createdAt := uint64(c.config.time().Unix())
if hs.sessionState != nil {
// If this is re-wrapping an old key, then keep
// the original time it was created.
createdAt = hs.sessionState.createdAt
}
var certsFromClient [][]byte
for _, cert := range c.peerCertificates {
certsFromClient = append(certsFromClient, cert.Raw)
}
state := sessionState{
vers: c.vers,
cipherSuite: hs.suite.id,
createdAt: createdAt,
masterSecret: hs.masterSecret,
certificates: certsFromClient,
}
stateBytes, err := state.marshal()
if err != nil {
return err
}
m.ticket, err = c.encryptTicket(stateBytes)
if err != nil {
return err
}
if _, err := hs.c.writeHandshakeRecord(m, &hs.finishedHash); err != nil {
return err
}
return nil
}
func (hs *serverHandshakeState) sendFinished(out []byte) error {
c := hs.c
if err := c.writeChangeCipherRecord(); err != nil {
return err
}
finished := new(finishedMsg)
finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret)
if _, err := hs.c.writeHandshakeRecord(finished, &hs.finishedHash); err != nil {
return err
}
copy(out, finished.verifyData)
return nil
}
// processCertsFromClient takes a chain of client certificates either from a
// Certificates message or from a sessionState and verifies them. It returns
// the public key of the leaf certificate.
func (c *Conn) processCertsFromClient(certificate Certificate) error {
certificates := certificate.Certificate
certs := make([]*x509.Certificate, len(certificates))
var err error
for i, asn1Data := range certificates {
if certs[i], err = x509.ParseCertificate(asn1Data); err != nil {
c.sendAlert(alertBadCertificate)
return errors.New("tls: failed to parse client certificate: " + err.Error())
}
}
if len(certs) == 0 && requiresClientCert(c.config.ClientAuth) {
if c.vers == VersionTLS13 {
c.sendAlert(alertCertificateRequired)
} else {
c.sendAlert(alertBadCertificate)
}
return errors.New("tls: client didn't provide a certificate")
}
if c.config.ClientAuth >= VerifyClientCertIfGiven && len(certs) > 0 {
opts := x509.VerifyOptions{
Roots: c.config.ClientCAs,
CurrentTime: c.config.time(),
Intermediates: x509.NewCertPool(),
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
}
for _, cert := range certs[1:] {
opts.Intermediates.AddCert(cert)
}
chains, err := certs[0].Verify(opts)
if err != nil {
var errCertificateInvalid x509.CertificateInvalidError
if errors.As(err, &x509.UnknownAuthorityError{}) {
c.sendAlert(alertUnknownCA)
} else if errors.As(err, &errCertificateInvalid) && errCertificateInvalid.Reason == x509.Expired {
c.sendAlert(alertCertificateExpired)
} else {
c.sendAlert(alertBadCertificate)
}
return &CertificateVerificationError{UnverifiedCertificates: certs, Err: err}
}
c.verifiedChains = chains
}
c.peerCertificates = certs
c.ocspResponse = certificate.OCSPStaple
c.scts = certificate.SignedCertificateTimestamps
if len(certs) > 0 {
switch certs[0].PublicKey.(type) {
case *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey:
default:
c.sendAlert(alertUnsupportedCertificate)
return fmt.Errorf("tls: client certificate contains an unsupported public key of type %T", certs[0].PublicKey)
}
}
if c.config.VerifyPeerCertificate != nil {
if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
return nil
}
func clientHelloInfo(ctx context.Context, c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo {
supportedVersions := clientHello.supportedVersions
if len(clientHello.supportedVersions) == 0 {
supportedVersions = supportedVersionsFromMax(clientHello.vers)
}
return &ClientHelloInfo{
CipherSuites: clientHello.cipherSuites,
ServerName: clientHello.serverName,
SupportedCurves: clientHello.supportedCurves,
SupportedPoints: clientHello.supportedPoints,
SignatureSchemes: clientHello.supportedSignatureAlgorithms,
SupportedProtos: clientHello.alpnProtocols,
SupportedVersions: supportedVersions,
Conn: c.conn,
config: c.config,
ctx: ctx,
}
}
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"bytes"
"context"
"crypto"
"crypto/hmac"
"crypto/rsa"
"encoding/binary"
"errors"
"hash"
"io"
"time"
)
// maxClientPSKIdentities is the number of client PSK identities the server
// will attempt to validate. It ignores the rest so that cheap ClientHello
// messages can't cause too much work in session ticket decryption attempts.
const maxClientPSKIdentities = 5
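// serverHandshakeStateTLS13 contains details of a TLS 1.3 server
// handshake in progress. It's discarded once the handshake has
// completed.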
type serverHandshakeStateTLS13 struct {
c *Conn
ctx context.Context
clientHello *clientHelloMsg
hello *serverHelloMsg
sentDummyCCS bool
usingPSK bool
suite *cipherSuiteTLS13
cert *Certificate
sigAlg SignatureScheme
earlySecret []byte
sharedKey []byte
handshakeSecret []byte
masterSecret []byte
trafficSecret []byte // client_application_traffic_secret_0
transcript hash.Hash
clientFinished []byte
}
func (hs *serverHandshakeStateTLS13) handshake() error {
c := hs.c
if needFIPS() {
return errors.New("tls: internal error: TLS 1.3 reached in FIPS mode")
}
// For an overview of the TLS 1.3 handshake, see RFC 8446, Section 2.
if err := hs.processClientHello(); err != nil {
return err
}
if err := hs.checkForResumption(); err != nil {
return err
}
if err := hs.pickCertificate(); err != nil {
return err
}
c.buffering = true
if err := hs.sendServerParameters(); err != nil {
return err
}
if err := hs.sendServerCertificate(); err != nil {
return err
}
if err := hs.sendServerFinished(); err != nil {
return err
}
// Note that at this point we could start sending application data without
// waiting for the client's second flight, but the application might not
// expect the lack of replay protection of the ClientHello parameters.
if _, err := c.flush(); err != nil {
return err
}
if err := hs.readClientCertificate(); err != nil {
return err
}
if err := hs.readClientFinished(); err != nil {
return err
}
c.isHandshakeComplete.Store(true)
return nil
}
func (hs *serverHandshakeStateTLS13) processClientHello() error {
c := hs.c
hs.hello = new(serverHelloMsg)
// TLS 1.3 froze the ServerHello.legacy_version field, and uses
// supported_versions instead. See RFC 8446, sections 4.1.3 and 4.2.1.
hs.hello.vers = VersionTLS12
hs.hello.supportedVersion = c.vers
if len(hs.clientHello.supportedVersions) == 0 {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: client used the legacy version field to negotiate TLS 1.3")
}
// Abort if the client is doing a fallback and landing lower than what we
// support. See RFC 7507, which however does not specify the interaction
// with supported_versions. The only difference is that with
// supported_versions a client has a chance to attempt a [TLS 1.2, TLS 1.4]
// handshake in case TLS 1.3 is broken but 1.2 is not. Alas, in that case,
// it will have to drop the TLS_FALLBACK_SCSV protection if it falls back to
// TLS 1.2, because a TLS 1.3 server would abort here. The situation before
// supported_versions was not better because there was just no way to do a
// TLS 1.4 handshake without risking the server selecting TLS 1.3.
for _, id := range hs.clientHello.cipherSuites {
if id == TLS_FALLBACK_SCSV {
// Use c.vers instead of max(supported_versions) because an attacker
// could defeat this by adding an arbitrary high version otherwise.
if c.vers < c.config.maxSupportedVersion(roleServer) {
c.sendAlert(alertInappropriateFallback)
return errors.New("tls: client using inappropriate protocol fallback")
}
break
}
}
if len(hs.clientHello.compressionMethods) != 1 ||
hs.clientHello.compressionMethods[0] != compressionNone {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: TLS 1.3 client supports illegal compression methods")
}
hs.hello.random = make([]byte, 32)
if _, err := io.ReadFull(c.config.rand(), hs.hello.random); err != nil {
c.sendAlert(alertInternalError)
return err
}
if len(hs.clientHello.secureRenegotiation) != 0 {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: initial handshake had non-empty renegotiation extension")
}
if hs.clientHello.earlyData {
// See RFC 8446, Section 4.2.10 for the complicated behavior required
// here. The scenario is that a different server at our address offered
// to accept early data in the past, which we can't handle. For now, all
// 0-RTT enabled session tickets need to expire before a Go server can
// replace a server or join a pool. That's the same requirement that
// applies to mixing or replacing with any TLS 1.2 server.
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: client sent unexpected early data")
}
hs.hello.sessionId = hs.clientHello.sessionId
hs.hello.compressionMethod = compressionNone
preferenceList := defaultCipherSuitesTLS13
if !hasAESGCMHardwareSupport || !aesgcmPreferred(hs.clientHello.cipherSuites) {
preferenceList = defaultCipherSuitesTLS13NoAES
}
for _, suiteID := range preferenceList {
hs.suite = mutualCipherSuiteTLS13(hs.clientHello.cipherSuites, suiteID)
if hs.suite != nil {
break
}
}
if hs.suite == nil {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: no cipher suite supported by both client and server")
}
c.cipherSuite = hs.suite.id
hs.hello.cipherSuite = hs.suite.id
hs.transcript = hs.suite.hash.New()
// Pick the ECDHE group in server preference order, but give priority to
// groups with a key share, to avoid a HelloRetryRequest round-trip.
var selectedGroup CurveID
var clientKeyShare *keyShare
GroupSelection:
for _, preferredGroup := range c.config.curvePreferences() {
for _, ks := range hs.clientHello.keyShares {
if ks.group == preferredGroup {
selectedGroup = ks.group
clientKeyShare = &ks
break GroupSelection
}
}
if selectedGroup != 0 {
continue
}
for _, group := range hs.clientHello.supportedCurves {
if group == preferredGroup {
selectedGroup = group
break
}
}
}
if selectedGroup == 0 {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: no ECDHE curve supported by both client and server")
}
if clientKeyShare == nil {
if err := hs.doHelloRetryRequest(selectedGroup); err != nil {
return err
}
clientKeyShare = &hs.clientHello.keyShares[0]
}
if _, ok := curveForCurveID(selectedGroup); !ok {
c.sendAlert(alertInternalError)
return errors.New("tls: CurvePreferences includes unsupported curve")
}
key, err := generateECDHEKey(c.config.rand(), selectedGroup)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
hs.hello.serverShare = keyShare{group: selectedGroup, data: key.PublicKey().Bytes()}
peerKey, err := key.Curve().NewPublicKey(clientKeyShare.data)
if err != nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: invalid client key share")
}
hs.sharedKey, err = key.ECDH(peerKey)
if err != nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: invalid client key share")
}
c.serverName = hs.clientHello.serverName
return nil
}
func (hs *serverHandshakeStateTLS13) checkForResumption() error {
c := hs.c
if c.config.SessionTicketsDisabled {
return nil
}
modeOK := false
for _, mode := range hs.clientHello.pskModes {
if mode == pskModeDHE {
modeOK = true
break
}
}
if !modeOK {
return nil
}
if len(hs.clientHello.pskIdentities) != len(hs.clientHello.pskBinders) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: invalid or missing PSK binders")
}
if len(hs.clientHello.pskIdentities) == 0 {
return nil
}
for i, identity := range hs.clientHello.pskIdentities {
if i >= maxClientPSKIdentities {
break
}
plaintext, _ := c.decryptTicket(identity.label)
if plaintext == nil {
continue
}
sessionState := new(sessionStateTLS13)
if ok := sessionState.unmarshal(plaintext); !ok {
continue
}
createdAt := time.Unix(int64(sessionState.createdAt), 0)
if c.config.time().Sub(createdAt) > maxSessionTicketLifetime {
continue
}
// We don't check the obfuscated ticket age because it's affected by
// clock skew and it's only a freshness signal useful for shrinking the
// window for replay attacks, which don't affect us as we don't do 0-RTT.
pskSuite := cipherSuiteTLS13ByID(sessionState.cipherSuite)
if pskSuite == nil || pskSuite.hash != hs.suite.hash {
continue
}
// PSK connections don't re-establish client certificates, but carry
// them over in the session ticket. Ensure the presence of client certs
// in the ticket is consistent with the configured requirements.
sessionHasClientCerts := len(sessionState.certificate.Certificate) != 0
needClientCerts := requiresClientCert(c.config.ClientAuth)
if needClientCerts && !sessionHasClientCerts {
continue
}
if sessionHasClientCerts && c.config.ClientAuth == NoClientCert {
continue
}
psk := hs.suite.expandLabel(sessionState.resumptionSecret, "resumption",
nil, hs.suite.hash.Size())
hs.earlySecret = hs.suite.extract(psk, nil)
binderKey := hs.suite.deriveSecret(hs.earlySecret, resumptionBinderLabel, nil)
// Clone the transcript in case a HelloRetryRequest was recorded.
transcript := cloneHash(hs.transcript, hs.suite.hash)
if transcript == nil {
c.sendAlert(alertInternalError)
return errors.New("tls: internal error: failed to clone hash")
}
clientHelloBytes, err := hs.clientHello.marshalWithoutBinders()
if err != nil {
c.sendAlert(alertInternalError)
return err
}
transcript.Write(clientHelloBytes)
pskBinder := hs.suite.finishedHash(binderKey, transcript)
if !hmac.Equal(hs.clientHello.pskBinders[i], pskBinder) {
c.sendAlert(alertDecryptError)
return errors.New("tls: invalid PSK binder")
}
c.didResume = true
if err := c.processCertsFromClient(sessionState.certificate); err != nil {
return err
}
hs.hello.selectedIdentityPresent = true
hs.hello.selectedIdentity = uint16(i)
hs.usingPSK = true
return nil
}
return nil
}
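// The binder check above follows RFC 8446, Section 4.2.11.2:
//
//	binder_key = Derive-Secret(Early Secret, "res binder", "")
//	binder = HMAC(finished_key, Transcript-Hash(ClientHello without binders))
//
// where finished_key is derived from binder_key, and
// hs.suite.finishedHash performs the final derivation and HMAC.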
// cloneHash uses the encoding.BinaryMarshaler and encoding.BinaryUnmarshaler
// interfaces implemented by standard library hashes to clone the state of in
// to a new instance of h. It returns nil if the operation fails.
func cloneHash(in hash.Hash, h crypto.Hash) hash.Hash {
// Recreate the interface to avoid importing encoding.
type binaryMarshaler interface {
MarshalBinary() (data []byte, err error)
UnmarshalBinary(data []byte) error
}
marshaler, ok := in.(binaryMarshaler)
if !ok {
return nil
}
state, err := marshaler.MarshalBinary()
if err != nil {
return nil
}
out := h.New()
unmarshaler, ok := out.(binaryMarshaler)
if !ok {
return nil
}
if err := unmarshaler.UnmarshalBinary(state); err != nil {
return nil
}
return out
}
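// Illustrative usage (the standard library's SHA-2 digests implement
// encoding.BinaryMarshaler, which is what makes this work):
//
//	h := sha256.New()
//	h.Write(firstFlight) // firstFlight is a placeholder
//	snapshot := cloneHash(h, crypto.SHA256)
//
// snapshot now evolves independently of h.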
func (hs *serverHandshakeStateTLS13) pickCertificate() error {
c := hs.c
// Only one of PSK and certificates is used at a time.
if hs.usingPSK {
return nil
}
// signature_algorithms is required in TLS 1.3. See RFC 8446, Section 4.2.3.
if len(hs.clientHello.supportedSignatureAlgorithms) == 0 {
return c.sendAlert(alertMissingExtension)
}
certificate, err := c.config.getCertificate(clientHelloInfo(hs.ctx, c, hs.clientHello))
if err != nil {
if err == errNoCertificates {
c.sendAlert(alertUnrecognizedName)
} else {
c.sendAlert(alertInternalError)
}
return err
}
hs.sigAlg, err = selectSignatureScheme(c.vers, certificate, hs.clientHello.supportedSignatureAlgorithms)
if err != nil {
// getCertificate returned a certificate that is unsupported or
// incompatible with the client's signature algorithms.
c.sendAlert(alertHandshakeFailure)
return err
}
hs.cert = certificate
return nil
}
// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility
// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4.
func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error {
if hs.sentDummyCCS {
return nil
}
hs.sentDummyCCS = true
return hs.c.writeChangeCipherRecord()
}
func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) error {
c := hs.c
// The first ClientHello gets double-hashed into the transcript upon a
// HelloRetryRequest. See RFC 8446, Section 4.4.1.
if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil {
return err
}
chHash := hs.transcript.Sum(nil)
hs.transcript.Reset()
hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
hs.transcript.Write(chHash)
helloRetryRequest := &serverHelloMsg{
vers: hs.hello.vers,
random: helloRetryRequestRandom,
sessionId: hs.hello.sessionId,
cipherSuite: hs.hello.cipherSuite,
compressionMethod: hs.hello.compressionMethod,
supportedVersion: hs.hello.supportedVersion,
selectedGroup: selectedGroup,
}
if _, err := hs.c.writeHandshakeRecord(helloRetryRequest, hs.transcript); err != nil {
return err
}
if err := hs.sendDummyChangeCipherSpec(); err != nil {
return err
}
// clientHelloMsg is not included in the transcript.
msg, err := c.readHandshake(nil)
if err != nil {
return err
}
clientHello, ok := msg.(*clientHelloMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(clientHello, msg)
}
if len(clientHello.keyShares) != 1 || clientHello.keyShares[0].group != selectedGroup {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: client sent invalid key share in second ClientHello")
}
if clientHello.earlyData {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: client indicated early data in second ClientHello")
}
if illegalClientHelloChange(clientHello, hs.clientHello) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: client illegally modified second ClientHello")
}
hs.clientHello = clientHello
return nil
}
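// For reference, RFC 8446, Section 4.4.1 replaces the initial ClientHello in
// the transcript with a synthetic message_hash message, which is what the
// four-byte prefix written above encodes:
//
//	Transcript = message_hash || 00 00 || Hash.length || Hash(ClientHello1)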
// illegalClientHelloChange reports whether the two ClientHello messages are
// different, with the exception of the changes allowed before and after a
// HelloRetryRequest. See RFC 8446, Section 4.1.2.
func illegalClientHelloChange(ch, ch1 *clientHelloMsg) bool {
if len(ch.supportedVersions) != len(ch1.supportedVersions) ||
len(ch.cipherSuites) != len(ch1.cipherSuites) ||
len(ch.supportedCurves) != len(ch1.supportedCurves) ||
len(ch.supportedSignatureAlgorithms) != len(ch1.supportedSignatureAlgorithms) ||
len(ch.supportedSignatureAlgorithmsCert) != len(ch1.supportedSignatureAlgorithmsCert) ||
len(ch.alpnProtocols) != len(ch1.alpnProtocols) {
return true
}
for i := range ch.supportedVersions {
if ch.supportedVersions[i] != ch1.supportedVersions[i] {
return true
}
}
for i := range ch.cipherSuites {
if ch.cipherSuites[i] != ch1.cipherSuites[i] {
return true
}
}
for i := range ch.supportedCurves {
if ch.supportedCurves[i] != ch1.supportedCurves[i] {
return true
}
}
for i := range ch.supportedSignatureAlgorithms {
if ch.supportedSignatureAlgorithms[i] != ch1.supportedSignatureAlgorithms[i] {
return true
}
}
for i := range ch.supportedSignatureAlgorithmsCert {
if ch.supportedSignatureAlgorithmsCert[i] != ch1.supportedSignatureAlgorithmsCert[i] {
return true
}
}
for i := range ch.alpnProtocols {
if ch.alpnProtocols[i] != ch1.alpnProtocols[i] {
return true
}
}
return ch.vers != ch1.vers ||
!bytes.Equal(ch.random, ch1.random) ||
!bytes.Equal(ch.sessionId, ch1.sessionId) ||
!bytes.Equal(ch.compressionMethods, ch1.compressionMethods) ||
ch.serverName != ch1.serverName ||
ch.ocspStapling != ch1.ocspStapling ||
!bytes.Equal(ch.supportedPoints, ch1.supportedPoints) ||
ch.ticketSupported != ch1.ticketSupported ||
!bytes.Equal(ch.sessionTicket, ch1.sessionTicket) ||
ch.secureRenegotiationSupported != ch1.secureRenegotiationSupported ||
!bytes.Equal(ch.secureRenegotiation, ch1.secureRenegotiation) ||
ch.scts != ch1.scts ||
!bytes.Equal(ch.cookie, ch1.cookie) ||
!bytes.Equal(ch.pskModes, ch1.pskModes)
}
func (hs *serverHandshakeStateTLS13) sendServerParameters() error {
c := hs.c
if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil {
return err
}
if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil {
return err
}
if err := hs.sendDummyChangeCipherSpec(); err != nil {
return err
}
earlySecret := hs.earlySecret
if earlySecret == nil {
earlySecret = hs.suite.extract(nil, nil)
}
hs.handshakeSecret = hs.suite.extract(hs.sharedKey,
hs.suite.deriveSecret(earlySecret, "derived", nil))
clientSecret := hs.suite.deriveSecret(hs.handshakeSecret,
clientHandshakeTrafficLabel, hs.transcript)
c.in.setTrafficSecret(hs.suite, clientSecret)
serverSecret := hs.suite.deriveSecret(hs.handshakeSecret,
serverHandshakeTrafficLabel, hs.transcript)
c.out.setTrafficSecret(hs.suite, serverSecret)
err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.clientHello.random, clientSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.clientHello.random, serverSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
encryptedExtensions := new(encryptedExtensionsMsg)
selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols)
if err != nil {
c.sendAlert(alertNoApplicationProtocol)
return err
}
encryptedExtensions.alpnProtocol = selectedProto
c.clientProtocol = selectedProto
if _, err := hs.c.writeHandshakeRecord(encryptedExtensions, hs.transcript); err != nil {
return err
}
return nil
}
func (hs *serverHandshakeStateTLS13) requestClientCert() bool {
return hs.c.config.ClientAuth >= RequestClientCert && !hs.usingPSK
}
func (hs *serverHandshakeStateTLS13) sendServerCertificate() error {
c := hs.c
// Only one of PSK and certificates is used at a time.
if hs.usingPSK {
return nil
}
if hs.requestClientCert() {
// Request a client certificate
certReq := new(certificateRequestMsgTLS13)
certReq.ocspStapling = true
certReq.scts = true
certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
if c.config.ClientCAs != nil {
certReq.certificateAuthorities = c.config.ClientCAs.Subjects()
}
if _, err := hs.c.writeHandshakeRecord(certReq, hs.transcript); err != nil {
return err
}
}
certMsg := new(certificateMsgTLS13)
certMsg.certificate = *hs.cert
certMsg.scts = hs.clientHello.scts && len(hs.cert.SignedCertificateTimestamps) > 0
certMsg.ocspStapling = hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0
if _, err := hs.c.writeHandshakeRecord(certMsg, hs.transcript); err != nil {
return err
}
certVerifyMsg := new(certificateVerifyMsg)
certVerifyMsg.hasSignatureAlgorithm = true
certVerifyMsg.signatureAlgorithm = hs.sigAlg
sigType, sigHash, err := typeAndHashFromSignatureScheme(hs.sigAlg)
if err != nil {
return c.sendAlert(alertInternalError)
}
signed := signedMessage(sigHash, serverSignatureContext, hs.transcript)
signOpts := crypto.SignerOpts(sigHash)
if sigType == signatureRSAPSS {
signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
}
sig, err := hs.cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts)
if err != nil {
public := hs.cert.PrivateKey.(crypto.Signer).Public()
if rsaKey, ok := public.(*rsa.PublicKey); ok && sigType == signatureRSAPSS &&
rsaKey.N.BitLen()/8 < sigHash.Size()*2+2 { // key too small for RSA-PSS
c.sendAlert(alertHandshakeFailure)
} else {
c.sendAlert(alertInternalError)
}
return errors.New("tls: failed to sign handshake: " + err.Error())
}
certVerifyMsg.signature = sig
if _, err := hs.c.writeHandshakeRecord(certVerifyMsg, hs.transcript); err != nil {
return err
}
return nil
}
func (hs *serverHandshakeStateTLS13) sendServerFinished() error {
c := hs.c
finished := &finishedMsg{
verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript),
}
if _, err := hs.c.writeHandshakeRecord(finished, hs.transcript); err != nil {
return err
}
// Derive secrets that take context through the server Finished.
hs.masterSecret = hs.suite.extract(nil,
hs.suite.deriveSecret(hs.handshakeSecret, "derived", nil))
hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret,
clientApplicationTrafficLabel, hs.transcript)
serverSecret := hs.suite.deriveSecret(hs.masterSecret,
serverApplicationTrafficLabel, hs.transcript)
c.out.setTrafficSecret(hs.suite, serverSecret)
err := c.config.writeKeyLog(keyLogLabelClientTraffic, hs.clientHello.random, hs.trafficSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.clientHello.random, serverSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript)
// If we did not request client certificates, at this point we can
// precompute the client finished and roll the transcript forward to send
// session tickets in our first flight.
if !hs.requestClientCert() {
if err := hs.sendSessionTickets(); err != nil {
return err
}
}
return nil
}
func (hs *serverHandshakeStateTLS13) shouldSendSessionTickets() bool {
if hs.c.config.SessionTicketsDisabled {
return false
}
// Don't send tickets the client wouldn't use. See RFC 8446, Section 4.2.9.
for _, pskMode := range hs.clientHello.pskModes {
if pskMode == pskModeDHE {
return true
}
}
return false
}
func (hs *serverHandshakeStateTLS13) sendSessionTickets() error {
c := hs.c
hs.clientFinished = hs.suite.finishedHash(c.in.trafficSecret, hs.transcript)
finishedMsg := &finishedMsg{
verifyData: hs.clientFinished,
}
if err := transcriptMsg(finishedMsg, hs.transcript); err != nil {
return err
}
if !hs.shouldSendSessionTickets() {
return nil
}
resumptionSecret := hs.suite.deriveSecret(hs.masterSecret,
resumptionLabel, hs.transcript)
m := new(newSessionTicketMsgTLS13)
var certsFromClient [][]byte
for _, cert := range c.peerCertificates {
certsFromClient = append(certsFromClient, cert.Raw)
}
state := sessionStateTLS13{
cipherSuite: hs.suite.id,
createdAt: uint64(c.config.time().Unix()),
resumptionSecret: resumptionSecret,
certificate: Certificate{
Certificate: certsFromClient,
OCSPStaple: c.ocspResponse,
SignedCertificateTimestamps: c.scts,
},
}
stateBytes, err := state.marshal()
if err != nil {
c.sendAlert(alertInternalError)
return err
}
m.label, err = c.encryptTicket(stateBytes)
if err != nil {
return err
}
m.lifetime = uint32(maxSessionTicketLifetime / time.Second)
// ticket_age_add is a random 32-bit value. See RFC 8446, Section 4.6.1.
// The value is not stored anywhere; we never need to check the ticket age
// because 0-RTT is not supported.
ageAdd := make([]byte, 4)
_, err = hs.c.config.rand().Read(ageAdd)
if err != nil {
return err
}
m.ageAdd = binary.LittleEndian.Uint32(ageAdd)
// ticket_nonce, which must be unique per connection, is always left at
// zero because we only ever send one ticket per connection.
if _, err := c.writeHandshakeRecord(m, nil); err != nil {
return err
}
return nil
}
func (hs *serverHandshakeStateTLS13) readClientCertificate() error {
c := hs.c
if !hs.requestClientCert() {
// Make sure the connection is still being verified whether or not
// the server requested a client certificate.
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
return nil
}
// If we requested a client certificate, then the client must send a
// certificate message. If it's empty, no CertificateVerify is sent.
msg, err := c.readHandshake(hs.transcript)
if err != nil {
return err
}
certMsg, ok := msg.(*certificateMsgTLS13)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certMsg, msg)
}
if err := c.processCertsFromClient(certMsg.certificate); err != nil {
return err
}
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
if len(certMsg.certificate.Certificate) != 0 {
// certificateVerifyMsg is included in the transcript, but not until
// after we verify the handshake signature, since the state before
// this message was sent is used.
msg, err = c.readHandshake(nil)
if err != nil {
return err
}
certVerify, ok := msg.(*certificateVerifyMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certVerify, msg)
}
// See RFC 8446, Section 4.4.3.
if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms()) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: client certificate used with invalid signature algorithm")
}
sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm)
if err != nil {
return c.sendAlert(alertInternalError)
}
if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: client certificate used with invalid signature algorithm")
}
signed := signedMessage(sigHash, clientSignatureContext, hs.transcript)
if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey,
sigHash, signed, certVerify.signature); err != nil {
c.sendAlert(alertDecryptError)
return errors.New("tls: invalid signature by the client certificate: " + err.Error())
}
if err := transcriptMsg(certVerify, hs.transcript); err != nil {
return err
}
}
// If we waited until the client certificates to send session tickets, we
// are ready to do it now.
if err := hs.sendSessionTickets(); err != nil {
return err
}
return nil
}
func (hs *serverHandshakeStateTLS13) readClientFinished() error {
c := hs.c
// finishedMsg is not included in the transcript.
msg, err := c.readHandshake(nil)
if err != nil {
return err
}
finished, ok := msg.(*finishedMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(finished, msg)
}
if !hmac.Equal(hs.clientFinished, finished.verifyData) {
c.sendAlert(alertDecryptError)
return errors.New("tls: invalid client finished hash")
}
c.in.setTrafficSecret(hs.suite, hs.trafficSecret)
return nil
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"crypto"
"crypto/ecdh"
"crypto/md5"
"crypto/rsa"
"crypto/sha1"
"crypto/x509"
"errors"
"fmt"
"io"
)
// A keyAgreement implements the client and server side of a TLS key agreement
// protocol by generating and processing key exchange messages.
type keyAgreement interface {
// On the server side, the first two methods are called in order.
// In the case that the key agreement protocol doesn't use a
// ServerKeyExchange message, generateServerKeyExchange can return nil,
// nil.
generateServerKeyExchange(*Config, *Certificate, *clientHelloMsg, *serverHelloMsg) (*serverKeyExchangeMsg, error)
processClientKeyExchange(*Config, *Certificate, *clientKeyExchangeMsg, uint16) ([]byte, error)
// On the client side, the next two methods are called in order.
// processServerKeyExchange may not be called if the server doesn't send a
// ServerKeyExchange message.
processServerKeyExchange(*Config, *clientHelloMsg, *serverHelloMsg, *x509.Certificate, *serverKeyExchangeMsg) error
generateClientKeyExchange(*Config, *clientHelloMsg, *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error)
}
var errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message")
var errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message")
// rsaKeyAgreement implements the standard TLS key agreement where the client
// encrypts the pre-master secret to the server's public key.
type rsaKeyAgreement struct{}
func (ka rsaKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
return nil, nil
}
func (ka rsaKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
if len(ckx.ciphertext) < 2 {
return nil, errClientKeyExchange
}
ciphertextLen := int(ckx.ciphertext[0])<<8 | int(ckx.ciphertext[1])
if ciphertextLen != len(ckx.ciphertext)-2 {
return nil, errClientKeyExchange
}
ciphertext := ckx.ciphertext[2:]
priv, ok := cert.PrivateKey.(crypto.Decrypter)
if !ok {
return nil, errors.New("tls: certificate private key does not implement crypto.Decrypter")
}
// Perform constant time RSA PKCS #1 v1.5 decryption
preMasterSecret, err := priv.Decrypt(config.rand(), ciphertext, &rsa.PKCS1v15DecryptOptions{SessionKeyLen: 48})
if err != nil {
return nil, err
}
// We don't check the version number in the premaster secret. For one,
// by checking it, we would leak information about the validity of the
// encrypted pre-master secret. Secondly, it provides only a small
// benefit against a downgrade attack and some implementations send the
// wrong version anyway. See the discussion at the end of section
// 7.4.7.1 of RFC 4346.
return preMasterSecret, nil
}
func (ka rsaKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
return errors.New("tls: unexpected ServerKeyExchange")
}
func (ka rsaKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
preMasterSecret := make([]byte, 48)
preMasterSecret[0] = byte(clientHello.vers >> 8)
preMasterSecret[1] = byte(clientHello.vers)
_, err := io.ReadFull(config.rand(), preMasterSecret[2:])
if err != nil {
return nil, nil, err
}
rsaKey, ok := cert.PublicKey.(*rsa.PublicKey)
if !ok {
return nil, nil, errors.New("tls: server certificate contains incorrect key type for selected ciphersuite")
}
encrypted, err := rsa.EncryptPKCS1v15(config.rand(), rsaKey, preMasterSecret)
if err != nil {
return nil, nil, err
}
ckx := new(clientKeyExchangeMsg)
ckx.ciphertext = make([]byte, len(encrypted)+2)
ckx.ciphertext[0] = byte(len(encrypted) >> 8)
ckx.ciphertext[1] = byte(len(encrypted))
copy(ckx.ciphertext[2:], encrypted)
return preMasterSecret, ckx, nil
}
// sha1Hash calculates a SHA1 hash over the given byte slices.
func sha1Hash(slices [][]byte) []byte {
hsha1 := sha1.New()
for _, slice := range slices {
hsha1.Write(slice)
}
return hsha1.Sum(nil)
}
// md5SHA1Hash implements TLS 1.0's hybrid hash function which consists of the
// concatenation of an MD5 and SHA1 hash.
func md5SHA1Hash(slices [][]byte) []byte {
md5sha1 := make([]byte, md5.Size+sha1.Size)
hmd5 := md5.New()
for _, slice := range slices {
hmd5.Write(slice)
}
copy(md5sha1, hmd5.Sum(nil))
copy(md5sha1[md5.Size:], sha1Hash(slices))
return md5sha1
}
// hashForServerKeyExchange hashes the given slices and returns their digest
// using the given hash function (for >= TLS 1.2) or using a default based on
// the sigType (for earlier TLS versions). For Ed25519 signatures, which don't
// do pre-hashing, it returns the concatenation of the slices.
func hashForServerKeyExchange(sigType uint8, hashFunc crypto.Hash, version uint16, slices ...[]byte) []byte {
if sigType == signatureEd25519 {
var signed []byte
for _, slice := range slices {
signed = append(signed, slice...)
}
return signed
}
if version >= VersionTLS12 {
h := hashFunc.New()
for _, slice := range slices {
h.Write(slice)
}
digest := h.Sum(nil)
return digest
}
if sigType == signatureECDSA {
return sha1Hash(slices)
}
return md5SHA1Hash(slices)
}
// ecdheKeyAgreement implements a TLS key agreement where the server
// generates an ephemeral EC public/private key pair and signs it. The
// pre-master secret is then calculated using ECDH. The signature may
// be ECDSA, Ed25519 or RSA.
type ecdheKeyAgreement struct {
version uint16
isRSA bool
key *ecdh.PrivateKey
// ckx and preMasterSecret are generated in processServerKeyExchange
// and returned in generateClientKeyExchange.
ckx *clientKeyExchangeMsg
preMasterSecret []byte
}
func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
var curveID CurveID
for _, c := range clientHello.supportedCurves {
if config.supportsCurve(c) {
curveID = c
break
}
}
if curveID == 0 {
return nil, errors.New("tls: no supported elliptic curves offered")
}
if _, ok := curveForCurveID(curveID); !ok {
return nil, errors.New("tls: CurvePreferences includes unsupported curve")
}
key, err := generateECDHEKey(config.rand(), curveID)
if err != nil {
return nil, err
}
ka.key = key
// See RFC 4492, Section 5.4.
ecdhePublic := key.PublicKey().Bytes()
serverECDHEParams := make([]byte, 1+2+1+len(ecdhePublic))
serverECDHEParams[0] = 3 // named curve
serverECDHEParams[1] = byte(curveID >> 8)
serverECDHEParams[2] = byte(curveID)
serverECDHEParams[3] = byte(len(ecdhePublic))
copy(serverECDHEParams[4:], ecdhePublic)
priv, ok := cert.PrivateKey.(crypto.Signer)
if !ok {
return nil, fmt.Errorf("tls: certificate private key of type %T does not implement crypto.Signer", cert.PrivateKey)
}
var signatureAlgorithm SignatureScheme
var sigType uint8
var sigHash crypto.Hash
if ka.version >= VersionTLS12 {
signatureAlgorithm, err = selectSignatureScheme(ka.version, cert, clientHello.supportedSignatureAlgorithms)
if err != nil {
return nil, err
}
sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm)
if err != nil {
return nil, err
}
} else {
sigType, sigHash, err = legacyTypeAndHashFromPublicKey(priv.Public())
if err != nil {
return nil, err
}
}
if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA {
return nil, errors.New("tls: certificate cannot be used with the selected cipher suite")
}
signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, hello.random, serverECDHEParams)
signOpts := crypto.SignerOpts(sigHash)
if sigType == signatureRSAPSS {
signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
}
sig, err := priv.Sign(config.rand(), signed, signOpts)
if err != nil {
return nil, errors.New("tls: failed to sign ECDHE parameters: " + err.Error())
}
skx := new(serverKeyExchangeMsg)
sigAndHashLen := 0
if ka.version >= VersionTLS12 {
sigAndHashLen = 2
}
skx.key = make([]byte, len(serverECDHEParams)+sigAndHashLen+2+len(sig))
copy(skx.key, serverECDHEParams)
k := skx.key[len(serverECDHEParams):]
if ka.version >= VersionTLS12 {
k[0] = byte(signatureAlgorithm >> 8)
k[1] = byte(signatureAlgorithm)
k = k[2:]
}
k[0] = byte(len(sig) >> 8)
k[1] = byte(len(sig))
copy(k[2:], sig)
return skx, nil
}
func (ka *ecdheKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 {
return nil, errClientKeyExchange
}
peerKey, err := ka.key.Curve().NewPublicKey(ckx.ciphertext[1:])
if err != nil {
return nil, errClientKeyExchange
}
preMasterSecret, err := ka.key.ECDH(peerKey)
if err != nil {
return nil, errClientKeyExchange
}
return preMasterSecret, nil
}
func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
if len(skx.key) < 4 {
return errServerKeyExchange
}
if skx.key[0] != 3 { // named curve
return errors.New("tls: server selected unsupported curve")
}
curveID := CurveID(skx.key[1])<<8 | CurveID(skx.key[2])
publicLen := int(skx.key[3])
if publicLen+4 > len(skx.key) {
return errServerKeyExchange
}
serverECDHEParams := skx.key[:4+publicLen]
publicKey := serverECDHEParams[4:]
sig := skx.key[4+publicLen:]
if len(sig) < 2 {
return errServerKeyExchange
}
if _, ok := curveForCurveID(curveID); !ok {
return errors.New("tls: server selected unsupported curve")
}
key, err := generateECDHEKey(config.rand(), curveID)
if err != nil {
return err
}
ka.key = key
peerKey, err := key.Curve().NewPublicKey(publicKey)
if err != nil {
return errServerKeyExchange
}
ka.preMasterSecret, err = key.ECDH(peerKey)
if err != nil {
return errServerKeyExchange
}
ourPublicKey := key.PublicKey().Bytes()
ka.ckx = new(clientKeyExchangeMsg)
ka.ckx.ciphertext = make([]byte, 1+len(ourPublicKey))
ka.ckx.ciphertext[0] = byte(len(ourPublicKey))
copy(ka.ckx.ciphertext[1:], ourPublicKey)
var sigType uint8
var sigHash crypto.Hash
if ka.version >= VersionTLS12 {
signatureAlgorithm := SignatureScheme(sig[0])<<8 | SignatureScheme(sig[1])
sig = sig[2:]
if len(sig) < 2 {
return errServerKeyExchange
}
if !isSupportedSignatureAlgorithm(signatureAlgorithm, clientHello.supportedSignatureAlgorithms) {
return errors.New("tls: certificate used with invalid signature algorithm")
}
sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm)
if err != nil {
return err
}
} else {
sigType, sigHash, err = legacyTypeAndHashFromPublicKey(cert.PublicKey)
if err != nil {
return err
}
}
if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA {
return errServerKeyExchange
}
sigLen := int(sig[0])<<8 | int(sig[1])
if sigLen+2 != len(sig) {
return errServerKeyExchange
}
sig = sig[2:]
signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, serverHello.random, serverECDHEParams)
if err := verifyHandshakeSignature(sigType, cert.PublicKey, sigHash, signed, sig); err != nil {
return errors.New("tls: invalid signature by the server certificate: " + err.Error())
}
return nil
}
func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
if ka.ckx == nil {
return nil, nil, errors.New("tls: missing ServerKeyExchange message")
}
return ka.preMasterSecret, ka.ckx, nil
}
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"crypto/ecdh"
"crypto/hmac"
"errors"
"fmt"
"hash"
"io"
"golang.org/x/crypto/cryptobyte"
"golang.org/x/crypto/hkdf"
)
// This file contains the functions necessary to compute the TLS 1.3 key
// schedule. See RFC 8446, Section 7.
const (
resumptionBinderLabel = "res binder"
clientHandshakeTrafficLabel = "c hs traffic"
serverHandshakeTrafficLabel = "s hs traffic"
clientApplicationTrafficLabel = "c ap traffic"
serverApplicationTrafficLabel = "s ap traffic"
exporterLabel = "exp master"
resumptionLabel = "res master"
trafficUpdateLabel = "traffic upd"
)
// expandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1.
func (c *cipherSuiteTLS13) expandLabel(secret []byte, label string, context []byte, length int) []byte {
var hkdfLabel cryptobyte.Builder
hkdfLabel.AddUint16(uint16(length))
hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes([]byte("tls13 "))
b.AddBytes([]byte(label))
})
hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(context)
})
hkdfLabelBytes, err := hkdfLabel.Bytes()
if err != nil {
// Rather than calling BytesOrPanic, we explicitly handle this error, in
// order to provide a reasonable error message. It should be basically
// impossible for this to panic, and routing errors back through the
// tree rooted in this function is quite painful. The labels are fixed
// size, and the context is either a fixed-length computed hash, or
// parsed from a field which has the same length limitation. As such, an
// error here is likely to only be caused during development.
//
// NOTE: another reasonable approach here might be to return a
// randomized slice if we encounter an error, which would break the
// connection, but avoid panicking. This would perhaps be safer but
// significantly more confusing to users.
panic(fmt.Errorf("failed to construct HKDF label: %s", err))
}
out := make([]byte, length)
n, err := hkdf.Expand(c.hash.New, secret, hkdfLabelBytes).Read(out)
if err != nil || n != length {
panic("tls: HKDF-Expand-Label invocation failed unexpectedly")
}
return out
}
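// For reference, the structure encoded above is HkdfLabel from RFC 8446,
// Section 7.1:
//
//	struct {
//	    uint16 length = Length;
//	    opaque label<7..255> = "tls13 " + Label;
//	    opaque context<0..255> = Context;
//	} HkdfLabel;
//
// For example, deriving a 32-byte "c hs traffic" secret encodes length 32,
// the 18-byte label "tls13 c hs traffic", and the transcript hash as context.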
// deriveSecret implements Derive-Secret from RFC 8446, Section 7.1.
func (c *cipherSuiteTLS13) deriveSecret(secret []byte, label string, transcript hash.Hash) []byte {
if transcript == nil {
transcript = c.hash.New()
}
return c.expandLabel(secret, label, transcript.Sum(nil), c.hash.Size())
}
// extract implements HKDF-Extract with the cipher suite hash.
func (c *cipherSuiteTLS13) extract(newSecret, currentSecret []byte) []byte {
if newSecret == nil {
newSecret = make([]byte, c.hash.Size())
}
return hkdf.Extract(c.hash.New, newSecret, currentSecret)
}
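// For reference, extract and deriveSecret chain into the three-stage key
// schedule of RFC 8446, Section 7.1, as used by the handshake code above:
//
//	earlySecret     = extract(PSK, nil)
//	handshakeSecret = extract(ECDHE, deriveSecret(earlySecret, "derived", nil))
//	masterSecret    = extract(nil, deriveSecret(handshakeSecret, "derived", nil))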
// nextTrafficSecret generates the next traffic secret, given the current one,
// according to RFC 8446, Section 7.2.
func (c *cipherSuiteTLS13) nextTrafficSecret(trafficSecret []byte) []byte {
return c.expandLabel(trafficSecret, trafficUpdateLabel, nil, c.hash.Size())
}
// trafficKey generates traffic keys according to RFC 8446, Section 7.3.
func (c *cipherSuiteTLS13) trafficKey(trafficSecret []byte) (key, iv []byte) {
key = c.expandLabel(trafficSecret, "key", nil, c.keyLen)
iv = c.expandLabel(trafficSecret, "iv", nil, aeadNonceLength)
return
}
// finishedHash generates the Finished verify_data or PskBinderEntry according
// to RFC 8446, Section 4.4.4. See sections 4.4 and 4.2.11.2 for the baseKey
// selection.
func (c *cipherSuiteTLS13) finishedHash(baseKey []byte, transcript hash.Hash) []byte {
finishedKey := c.expandLabel(baseKey, "finished", nil, c.hash.Size())
verifyData := hmac.New(c.hash.New, finishedKey)
verifyData.Write(transcript.Sum(nil))
return verifyData.Sum(nil)
}
// exportKeyingMaterial implements RFC5705 exporters for TLS 1.3 according to
// RFC 8446, Section 7.5.
func (c *cipherSuiteTLS13) exportKeyingMaterial(masterSecret []byte, transcript hash.Hash) func(string, []byte, int) ([]byte, error) {
expMasterSecret := c.deriveSecret(masterSecret, exporterLabel, transcript)
return func(label string, context []byte, length int) ([]byte, error) {
secret := c.deriveSecret(expMasterSecret, label, nil)
h := c.hash.New()
h.Write(context)
return c.expandLabel(secret, "exporter", h.Sum(nil), length), nil
}
}
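// For reference, the exporter above follows RFC 8446, Section 7.5:
//
//	exporterSecret = Derive-Secret(master_secret, "exp master", transcript)
//	secret         = Derive-Secret(exporterSecret, label, "")
//	ekm            = HKDF-Expand-Label(secret, "exporter", Hash(context), length)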
// generateECDHEKey returns a PrivateKey that implements Diffie-Hellman
// according to RFC 8446, Section 4.2.8.2.
func generateECDHEKey(rand io.Reader, curveID CurveID) (*ecdh.PrivateKey, error) {
curve, ok := curveForCurveID(curveID)
if !ok {
return nil, errors.New("tls: internal error: unsupported curve")
}
return curve.GenerateKey(rand)
}
func curveForCurveID(id CurveID) (ecdh.Curve, bool) {
switch id {
case X25519:
return ecdh.X25519(), true
case CurveP256:
return ecdh.P256(), true
case CurveP384:
return ecdh.P384(), true
case CurveP521:
return ecdh.P521(), true
default:
return nil, false
}
}
func curveIDForCurve(curve ecdh.Curve) (CurveID, bool) {
switch curve {
case ecdh.X25519():
return X25519, true
case ecdh.P256():
return CurveP256, true
case ecdh.P384():
return CurveP384, true
case ecdh.P521():
return CurveP521, true
default:
return 0, false
}
}
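// A minimal sketch of the ECDH exchange these helpers support, using
// crypto/ecdh and crypto/rand directly (hypothetical snippet):
//
//	alice, _ := ecdh.X25519().GenerateKey(rand.Reader)
//	bob, _ := ecdh.X25519().GenerateKey(rand.Reader)
//	s1, _ := alice.ECDH(bob.PublicKey())
//	s2, _ := bob.ECDH(alice.PublicKey())
//	// s1 and s2 hold the same shared secret.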
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !boringcrypto
package tls
func needFIPS() bool { return false }
func supportedSignatureAlgorithms() []SignatureScheme {
return defaultSupportedSignatureAlgorithms
}
func fipsMinVersion(c *Config) uint16 { panic("fipsMinVersion") }
func fipsMaxVersion(c *Config) uint16 { panic("fipsMaxVersion") }
func fipsCurvePreferences(c *Config) []CurveID { panic("fipsCurvePreferences") }
func fipsCipherSuites(c *Config) []uint16 { panic("fipsCipherSuites") }
var fipsSupportedSignatureAlgorithms []SignatureScheme
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"crypto"
"crypto/hmac"
"crypto/md5"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"errors"
"fmt"
"hash"
)
// Split a premaster secret in two as specified in RFC 4346, Section 5.
func splitPreMasterSecret(secret []byte) (s1, s2 []byte) {
s1 = secret[0 : (len(secret)+1)/2]
s2 = secret[len(secret)/2:]
return
}
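// For example, a 5-byte secret splits into s1 = secret[0:3] and
// s2 = secret[2:5]: with an odd length the two halves share the middle
// byte, as RFC 4346, Section 5 requires.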
// pHash implements the P_hash function, as defined in RFC 4346, Section 5.
func pHash(result, secret, seed []byte, hash func() hash.Hash) {
h := hmac.New(hash, secret)
h.Write(seed)
a := h.Sum(nil)
j := 0
for j < len(result) {
h.Reset()
h.Write(a)
h.Write(seed)
b := h.Sum(nil)
copy(result[j:], b)
j += len(b)
h.Reset()
h.Write(a)
a = h.Sum(nil)
}
}
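// For reference, the expansion implemented above is P_hash from RFC 4346,
// Section 5:
//
//	A(0) = seed
//	A(i) = HMAC_hash(secret, A(i-1))
//	P_hash(secret, seed) = HMAC_hash(secret, A(1) + seed) +
//	                       HMAC_hash(secret, A(2) + seed) + ...
//
// truncated to len(result) bytes.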
// prf10 implements the TLS 1.0 pseudo-random function, as defined in RFC 2246, Section 5.
func prf10(result, secret, label, seed []byte) {
hashSHA1 := sha1.New
hashMD5 := md5.New
labelAndSeed := make([]byte, len(label)+len(seed))
copy(labelAndSeed, label)
copy(labelAndSeed[len(label):], seed)
s1, s2 := splitPreMasterSecret(secret)
pHash(result, s1, labelAndSeed, hashMD5)
result2 := make([]byte, len(result))
pHash(result2, s2, labelAndSeed, hashSHA1)
for i, b := range result2 {
result[i] ^= b
}
}
// prf12 implements the TLS 1.2 pseudo-random function, as defined in RFC 5246, Section 5.
func prf12(hashFunc func() hash.Hash) func(result, secret, label, seed []byte) {
return func(result, secret, label, seed []byte) {
labelAndSeed := make([]byte, len(label)+len(seed))
copy(labelAndSeed, label)
copy(labelAndSeed[len(label):], seed)
pHash(result, secret, labelAndSeed, hashFunc)
}
}
const (
masterSecretLength = 48 // Length of a master secret in TLS 1.1.
finishedVerifyLength = 12 // Length of verify_data in a Finished message.
)
var masterSecretLabel = []byte("master secret")
var keyExpansionLabel = []byte("key expansion")
var clientFinishedLabel = []byte("client finished")
var serverFinishedLabel = []byte("server finished")
func prfAndHashForVersion(version uint16, suite *cipherSuite) (func(result, secret, label, seed []byte), crypto.Hash) {
switch version {
case VersionTLS10, VersionTLS11:
return prf10, crypto.Hash(0)
case VersionTLS12:
if suite.flags&suiteSHA384 != 0 {
return prf12(sha512.New384), crypto.SHA384
}
return prf12(sha256.New), crypto.SHA256
default:
panic("unknown version")
}
}
func prfForVersion(version uint16, suite *cipherSuite) func(result, secret, label, seed []byte) {
prf, _ := prfAndHashForVersion(version, suite)
return prf
}
// masterFromPreMasterSecret generates the master secret from the pre-master
// secret. See RFC 5246, Section 8.1.
func masterFromPreMasterSecret(version uint16, suite *cipherSuite, preMasterSecret, clientRandom, serverRandom []byte) []byte {
seed := make([]byte, 0, len(clientRandom)+len(serverRandom))
seed = append(seed, clientRandom...)
seed = append(seed, serverRandom...)
masterSecret := make([]byte, masterSecretLength)
prfForVersion(version, suite)(masterSecret, preMasterSecret, masterSecretLabel, seed)
return masterSecret
}
// keysFromMasterSecret generates the connection keys from the master
// secret, given the lengths of the MAC key, cipher key and IV, as defined in
// RFC 2246, Section 6.3.
func keysFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte) {
seed := make([]byte, 0, len(serverRandom)+len(clientRandom))
seed = append(seed, serverRandom...)
seed = append(seed, clientRandom...)
n := 2*macLen + 2*keyLen + 2*ivLen
keyMaterial := make([]byte, n)
prfForVersion(version, suite)(keyMaterial, masterSecret, keyExpansionLabel, seed)
clientMAC = keyMaterial[:macLen]
keyMaterial = keyMaterial[macLen:]
serverMAC = keyMaterial[:macLen]
keyMaterial = keyMaterial[macLen:]
clientKey = keyMaterial[:keyLen]
keyMaterial = keyMaterial[keyLen:]
serverKey = keyMaterial[:keyLen]
keyMaterial = keyMaterial[keyLen:]
clientIV = keyMaterial[:ivLen]
keyMaterial = keyMaterial[ivLen:]
serverIV = keyMaterial[:ivLen]
return
}
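// For example, for a suite with macLen=20, keyLen=16, and ivLen=16 (such as
// an AES-128-CBC/SHA1 suite) the 104-byte key block is consumed in order as
//
//	clientMAC[20] | serverMAC[20] | clientKey[16] | serverKey[16] | clientIV[16] | serverIV[16]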
func newFinishedHash(version uint16, cipherSuite *cipherSuite) finishedHash {
var buffer []byte
if version >= VersionTLS12 {
buffer = []byte{}
}
prf, hash := prfAndHashForVersion(version, cipherSuite)
if hash != 0 {
return finishedHash{hash.New(), hash.New(), nil, nil, buffer, version, prf}
}
return finishedHash{sha1.New(), sha1.New(), md5.New(), md5.New(), buffer, version, prf}
}
// A finishedHash calculates the hash of a set of handshake messages suitable
// for including in a Finished message.
type finishedHash struct {
client hash.Hash
server hash.Hash
// Prior to TLS 1.2, an additional MD5 hash is required.
clientMD5 hash.Hash
serverMD5 hash.Hash
// In TLS 1.2, a full buffer is sadly required.
buffer []byte
version uint16
prf func(result, secret, label, seed []byte)
}
func (h *finishedHash) Write(msg []byte) (n int, err error) {
h.client.Write(msg)
h.server.Write(msg)
if h.version < VersionTLS12 {
h.clientMD5.Write(msg)
h.serverMD5.Write(msg)
}
if h.buffer != nil {
h.buffer = append(h.buffer, msg...)
}
return len(msg), nil
}
func (h finishedHash) Sum() []byte {
if h.version >= VersionTLS12 {
return h.client.Sum(nil)
}
out := make([]byte, 0, md5.Size+sha1.Size)
out = h.clientMD5.Sum(out)
return h.client.Sum(out)
}
// clientSum returns the contents of the verify_data member of a client's
// Finished message.
func (h finishedHash) clientSum(masterSecret []byte) []byte {
out := make([]byte, finishedVerifyLength)
h.prf(out, masterSecret, clientFinishedLabel, h.Sum())
return out
}
// serverSum returns the contents of the verify_data member of a server's
// Finished message.
func (h finishedHash) serverSum(masterSecret []byte) []byte {
out := make([]byte, finishedVerifyLength)
h.prf(out, masterSecret, serverFinishedLabel, h.Sum())
return out
}
// hashForClientCertificate returns the handshake messages so far, pre-hashed if
// necessary, suitable for signing by a TLS client certificate.
func (h finishedHash) hashForClientCertificate(sigType uint8, hashAlg crypto.Hash) []byte {
if (h.version >= VersionTLS12 || sigType == signatureEd25519) && h.buffer == nil {
panic("tls: handshake hash for a client certificate requested after discarding the handshake buffer")
}
if sigType == signatureEd25519 {
return h.buffer
}
if h.version >= VersionTLS12 {
hash := hashAlg.New()
hash.Write(h.buffer)
return hash.Sum(nil)
}
if sigType == signatureECDSA {
return h.server.Sum(nil)
}
return h.Sum()
}
// discardHandshakeBuffer is called when there is no more need to
// buffer the entirety of the handshake messages.
func (h *finishedHash) discardHandshakeBuffer() {
h.buffer = nil
}
// noExportedKeyingMaterial is used as a value of
// ConnectionState.ekm when renegotiation is enabled and thus
// we wish to fail all key-material export requests.
func noExportedKeyingMaterial(label string, context []byte, length int) ([]byte, error) {
return nil, errors.New("crypto/tls: ExportKeyingMaterial is unavailable when renegotiation is enabled")
}
// ekmFromMasterSecret generates exported keying material as defined in RFC 5705.
func ekmFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte) func(string, []byte, int) ([]byte, error) {
return func(label string, context []byte, length int) ([]byte, error) {
switch label {
case "client finished", "server finished", "master secret", "key expansion":
// These values are reserved and may not be used.
return nil, fmt.Errorf("crypto/tls: reserved ExportKeyingMaterial label: %s", label)
}
seedLen := len(serverRandom) + len(clientRandom)
if context != nil {
seedLen += 2 + len(context)
}
seed := make([]byte, 0, seedLen)
seed = append(seed, clientRandom...)
seed = append(seed, serverRandom...)
if context != nil {
if len(context) >= 1<<16 {
return nil, fmt.Errorf("crypto/tls: ExportKeyingMaterial context too long")
}
seed = append(seed, byte(len(context)>>8), byte(len(context)))
seed = append(seed, context...)
}
keyMaterial := make([]byte, length)
prfForVersion(version, suite)(keyMaterial, masterSecret, []byte(label), seed)
return keyMaterial, nil
}
}
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/sha256"
"crypto/subtle"
"errors"
"io"
"golang.org/x/crypto/cryptobyte"
)
// sessionState contains the information that is serialized into a session
// ticket in order to later resume a connection.
type sessionState struct {
vers uint16
cipherSuite uint16
createdAt uint64
masterSecret []byte // opaque master_secret<1..2^16-1>;
// struct { opaque certificate<1..2^24-1> } Certificate;
certificates [][]byte // Certificate certificate_list<0..2^24-1>;
// usedOldKey is true if the ticket from which this session came
// was encrypted with an older key and thus should be refreshed.
usedOldKey bool
}
func (m *sessionState) marshal() ([]byte, error) {
var b cryptobyte.Builder
b.AddUint16(m.vers)
b.AddUint16(m.cipherSuite)
addUint64(&b, m.createdAt)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.masterSecret)
})
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
for _, cert := range m.certificates {
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(cert)
})
}
})
return b.Bytes()
}
func (m *sessionState) unmarshal(data []byte) bool {
*m = sessionState{usedOldKey: m.usedOldKey}
s := cryptobyte.String(data)
if ok := s.ReadUint16(&m.vers) &&
s.ReadUint16(&m.cipherSuite) &&
readUint64(&s, &m.createdAt) &&
readUint16LengthPrefixed(&s, &m.masterSecret) &&
len(m.masterSecret) != 0; !ok {
return false
}
var certList cryptobyte.String
if !s.ReadUint24LengthPrefixed(&certList) {
return false
}
for !certList.Empty() {
var cert []byte
if !readUint24LengthPrefixed(&certList, &cert) {
return false
}
m.certificates = append(m.certificates, cert)
}
return s.Empty()
}
// sessionStateTLS13 is the content of a TLS 1.3 session ticket. Its first
// version (revision = 0) doesn't carry any of the information needed for 0-RTT
// validation and the nonce is always empty.
type sessionStateTLS13 struct {
// uint8 version = 0x0304;
// uint8 revision = 0;
cipherSuite uint16
createdAt uint64
resumptionSecret []byte // opaque resumption_master_secret<1..2^8-1>;
certificate Certificate // CertificateEntry certificate_list<0..2^24-1>;
}
func (m *sessionStateTLS13) marshal() ([]byte, error) {
var b cryptobyte.Builder
b.AddUint16(VersionTLS13)
b.AddUint8(0) // revision
b.AddUint16(m.cipherSuite)
addUint64(&b, m.createdAt)
b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.resumptionSecret)
})
marshalCertificate(&b, m.certificate)
return b.Bytes()
}
func (m *sessionStateTLS13) unmarshal(data []byte) bool {
*m = sessionStateTLS13{}
s := cryptobyte.String(data)
var version uint16
var revision uint8
return s.ReadUint16(&version) &&
version == VersionTLS13 &&
s.ReadUint8(&revision) &&
revision == 0 &&
s.ReadUint16(&m.cipherSuite) &&
readUint64(&s, &m.createdAt) &&
readUint8LengthPrefixed(&s, &m.resumptionSecret) &&
len(m.resumptionSecret) != 0 &&
unmarshalCertificate(&s, &m.certificate) &&
s.Empty()
}
func (c *Conn) encryptTicket(state []byte) ([]byte, error) {
if len(c.ticketKeys) == 0 {
return nil, errors.New("tls: internal error: session ticket keys unavailable")
}
encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(state)+sha256.Size)
keyName := encrypted[:ticketKeyNameLen]
iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize]
macBytes := encrypted[len(encrypted)-sha256.Size:]
if _, err := io.ReadFull(c.config.rand(), iv); err != nil {
return nil, err
}
key := c.ticketKeys[0]
copy(keyName, key.keyName[:])
block, err := aes.NewCipher(key.aesKey[:])
if err != nil {
return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error())
}
cipher.NewCTR(block, iv).XORKeyStream(encrypted[ticketKeyNameLen+aes.BlockSize:], state)
mac := hmac.New(sha256.New, key.hmacKey[:])
mac.Write(encrypted[:len(encrypted)-sha256.Size])
mac.Sum(macBytes[:0])
return encrypted, nil
}
func (c *Conn) decryptTicket(encrypted []byte) (plaintext []byte, usedOldKey bool) {
if len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size {
return nil, false
}
keyName := encrypted[:ticketKeyNameLen]
iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize]
macBytes := encrypted[len(encrypted)-sha256.Size:]
ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size]
keyIndex := -1
for i, candidateKey := range c.ticketKeys {
if bytes.Equal(keyName, candidateKey.keyName[:]) {
keyIndex = i
break
}
}
if keyIndex == -1 {
return nil, false
}
key := &c.ticketKeys[keyIndex]
mac := hmac.New(sha256.New, key.hmacKey[:])
mac.Write(encrypted[:len(encrypted)-sha256.Size])
expected := mac.Sum(nil)
if subtle.ConstantTimeCompare(macBytes, expected) != 1 {
return nil, false
}
block, err := aes.NewCipher(key.aesKey[:])
if err != nil {
return nil, false
}
plaintext = make([]byte, len(ciphertext))
cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext)
return plaintext, keyIndex > 0
}
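// For reference, the ticket produced by encryptTicket and parsed by
// decryptTicket has the layout
//
//	keyName[ticketKeyNameLen] | iv[aes.BlockSize] | AES-CTR(state) | HMAC-SHA256 tag[sha256.Size]
//
// with the MAC computed over all preceding bytes.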
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package tls partially implements TLS 1.2, as specified in RFC 5246,
// and TLS 1.3, as specified in RFC 8446.
package tls
// BUG(agl): The crypto/tls package only implements some countermeasures
// against Lucky13 attacks on CBC-mode encryption, and only on SHA1
// variants. See http://www.isg.rhul.ac.uk/tls/TLStiming.pdf and
// https://www.imperialviolet.org/2013/02/04/luckythirteen.html.
import (
"bytes"
"context"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"net"
"os"
"strings"
)
// Server returns a new TLS server side connection
// using conn as the underlying transport.
// The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate.
func Server(conn net.Conn, config *Config) *Conn {
c := &Conn{
conn: conn,
config: config,
}
c.handshakeFn = c.serverHandshake
return c
}
// Client returns a new TLS client side connection
// using conn as the underlying transport.
// The config cannot be nil: users must set either ServerName or
// InsecureSkipVerify in the config.
func Client(conn net.Conn, config *Config) *Conn {
c := &Conn{
conn: conn,
config: config,
isClient: true,
}
c.handshakeFn = c.clientHandshake
return c
}
// A listener implements a network listener (net.Listener) for TLS connections.
type listener struct {
net.Listener
config *Config
}
// Accept waits for and returns the next incoming TLS connection.
// The returned connection is of type *Conn.
func (l *listener) Accept() (net.Conn, error) {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}
return Server(c, l.config), nil
}
// NewListener creates a Listener which accepts connections from an inner
// Listener and wraps each connection with Server.
// The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate.
func NewListener(inner net.Listener, config *Config) net.Listener {
l := new(listener)
l.Listener = inner
l.config = config
return l
}
// Listen creates a TLS listener accepting connections on the
// given network address using net.Listen.
// The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate.
func Listen(network, laddr string, config *Config) (net.Listener, error) {
if config == nil || len(config.Certificates) == 0 &&
config.GetCertificate == nil && config.GetConfigForClient == nil {
return nil, errors.New("tls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config")
}
l, err := net.Listen(network, laddr)
if err != nil {
return nil, err
}
return NewListener(l, config), nil
}
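// A minimal usage sketch, assuming hypothetical file names, address, log
// import, and handle function (error handling abbreviated):
//
//	cert, err := tls.LoadX509KeyPair("server.crt", "server.key")
//	if err != nil {
//		log.Fatal(err)
//	}
//	ln, err := tls.Listen("tcp", ":8443", &tls.Config{
//		Certificates: []tls.Certificate{cert},
//	})
//	if err != nil {
//		log.Fatal(err)
//	}
//	for {
//		conn, err := ln.Accept()
//		if err != nil {
//			log.Print(err)
//			continue
//		}
//		go handle(conn) // handle is a hypothetical per-connection handler
//	}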
type timeoutError struct{}
func (timeoutError) Error() string { return "tls: DialWithDialer timed out" }
func (timeoutError) Timeout() bool { return true }
func (timeoutError) Temporary() bool { return true }
// DialWithDialer connects to the given network address using dialer.Dial and
// then initiates a TLS handshake, returning the resulting TLS connection. Any
// timeout or deadline given in the dialer applies to the connection and TLS
// handshake as a whole.
//
// DialWithDialer interprets a nil configuration as equivalent to the zero
// configuration; see the documentation of Config for the defaults.
//
// DialWithDialer uses context.Background internally; to specify the context,
// use Dialer.DialContext with NetDialer set to the desired dialer.
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
return dial(context.Background(), dialer, network, addr, config)
}
func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
if netDialer.Timeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout)
defer cancel()
}
if !netDialer.Deadline.IsZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(ctx, netDialer.Deadline)
defer cancel()
}
rawConn, err := netDialer.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
colonPos := strings.LastIndex(addr, ":")
if colonPos == -1 {
colonPos = len(addr)
}
hostname := addr[:colonPos]
if config == nil {
config = defaultConfig()
}
// If no ServerName is set, infer the ServerName
// from the hostname we're connecting to.
if config.ServerName == "" {
// Make a copy to avoid polluting argument or default.
c := config.Clone()
c.ServerName = hostname
config = c
}
conn := Client(rawConn, config)
if err := conn.HandshakeContext(ctx); err != nil {
rawConn.Close()
return nil, err
}
return conn, nil
}
// Dial connects to the given network address using net.Dial
// and then initiates a TLS handshake, returning the resulting
// TLS connection.
// Dial interprets a nil configuration as equivalent to
// the zero configuration; see the documentation of Config
// for the defaults.
func Dial(network, addr string, config *Config) (*Conn, error) {
return DialWithDialer(new(net.Dialer), network, addr, config)
}
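// A minimal usage sketch, assuming a hypothetical address and a log import;
// a nil Config gets the defaults, and ServerName is inferred from addr:
//
//	conn, err := tls.Dial("tcp", "example.com:443", nil)
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer conn.Close()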
// Dialer dials TLS connections given a configuration and a Dialer for the
// underlying connection.
type Dialer struct {
// NetDialer is the optional dialer to use for the TLS connections'
// underlying TCP connections.
// A nil NetDialer is equivalent to the net.Dialer zero value.
NetDialer *net.Dialer
// Config is the TLS configuration to use for new connections.
// A nil configuration is equivalent to the zero
// configuration; see the documentation of Config for the
// defaults.
Config *Config
}
// Dial connects to the given network address and initiates a TLS
// handshake, returning the resulting TLS connection.
//
// The returned Conn, if any, will always be of type *Conn.
//
// Dial uses context.Background internally; to specify the context,
// use DialContext.
func (d *Dialer) Dial(network, addr string) (net.Conn, error) {
return d.DialContext(context.Background(), network, addr)
}
func (d *Dialer) netDialer() *net.Dialer {
if d.NetDialer != nil {
return d.NetDialer
}
return new(net.Dialer)
}
// DialContext connects to the given network address and initiates a TLS
// handshake, returning the resulting TLS connection.
//
// The provided Context must be non-nil. If the context expires before
// the connection is complete, an error is returned. Once successfully
// connected, any expiration of the context will not affect the
// connection.
//
// The returned Conn, if any, will always be of type *Conn.
func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
c, err := dial(ctx, d.netDialer(), network, addr, d.Config)
if err != nil {
// Don't return c (a typed nil) in an interface.
return nil, err
}
return c, nil
}
// LoadX509KeyPair reads and parses a public/private key pair from a pair
// of files. The files must contain PEM encoded data. The certificate file
// may contain intermediate certificates following the leaf certificate to
// form a certificate chain. On successful return, Certificate.Leaf will
// be nil because the parsed form of the certificate is not retained.
func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) {
certPEMBlock, err := os.ReadFile(certFile)
if err != nil {
return Certificate{}, err
}
keyPEMBlock, err := os.ReadFile(keyFile)
if err != nil {
return Certificate{}, err
}
return X509KeyPair(certPEMBlock, keyPEMBlock)
}
// X509KeyPair parses a public/private key pair from a pair of
// PEM encoded data. On successful return, Certificate.Leaf will be nil because
// the parsed form of the certificate is not retained.
func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
fail := func(err error) (Certificate, error) { return Certificate{}, err }
var cert Certificate
var skippedBlockTypes []string
for {
var certDERBlock *pem.Block
certDERBlock, certPEMBlock = pem.Decode(certPEMBlock)
if certDERBlock == nil {
break
}
if certDERBlock.Type == "CERTIFICATE" {
cert.Certificate = append(cert.Certificate, certDERBlock.Bytes)
} else {
skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type)
}
}
if len(cert.Certificate) == 0 {
if len(skippedBlockTypes) == 0 {
return fail(errors.New("tls: failed to find any PEM data in certificate input"))
}
if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") {
return fail(errors.New("tls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched"))
}
return fail(fmt.Errorf("tls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
}
skippedBlockTypes = skippedBlockTypes[:0]
var keyDERBlock *pem.Block
for {
keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock)
if keyDERBlock == nil {
if len(skippedBlockTypes) == 0 {
return fail(errors.New("tls: failed to find any PEM data in key input"))
}
if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" {
return fail(errors.New("tls: found a certificate rather than a key in the PEM for the private key"))
}
return fail(fmt.Errorf("tls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
}
if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") {
break
}
skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type)
}
// We don't need to parse the public key for TLS, but we do so anyway
// to check that it looks sane and matches the private key.
x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
if err != nil {
return fail(err)
}
cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes)
if err != nil {
return fail(err)
}
switch pub := x509Cert.PublicKey.(type) {
case *rsa.PublicKey:
priv, ok := cert.PrivateKey.(*rsa.PrivateKey)
if !ok {
return fail(errors.New("tls: private key type does not match public key type"))
}
if pub.N.Cmp(priv.N) != 0 {
return fail(errors.New("tls: private key does not match public key"))
}
case *ecdsa.PublicKey:
priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey)
if !ok {
return fail(errors.New("tls: private key type does not match public key type"))
}
if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 {
return fail(errors.New("tls: private key does not match public key"))
}
case ed25519.PublicKey:
priv, ok := cert.PrivateKey.(ed25519.PrivateKey)
if !ok {
return fail(errors.New("tls: private key type does not match public key type"))
}
if !bytes.Equal(priv.Public().(ed25519.PublicKey), pub) {
return fail(errors.New("tls: private key does not match public key"))
}
default:
return fail(errors.New("tls: unknown public key algorithm"))
}
return cert, nil
}
// Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates
// PKCS #1 private keys by default, while OpenSSL 1.0.0 generates PKCS #8 keys.
// OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three.
func parsePrivateKey(der []byte) (crypto.PrivateKey, error) {
if key, err := x509.ParsePKCS1PrivateKey(der); err == nil {
return key, nil
}
if key, err := x509.ParsePKCS8PrivateKey(der); err == nil {
switch key := key.(type) {
case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey:
return key, nil
default:
return nil, errors.New("tls: found unknown private key type in PKCS#8 wrapping")
}
}
if key, err := x509.ParseECPrivateKey(der); err == nil {
return key, nil
}
return nil, errors.New("tls: failed to parse private key")
}
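// A minimal usage sketch of the pair-loading path above, as it might
// appear in an external test file (package tls_test). The imports
// "crypto/tls", "log", and "os" are assumed; the file paths are
// hypothetical.
func ExampleX509KeyPair() {
	certPEM, err := os.ReadFile("server.crt") // hypothetical path
	if err != nil {
		log.Fatal(err)
	}
	keyPEM, err := os.ReadFile("server.key") // hypothetical path
	if err != nil {
		log.Fatal(err)
	}
	// X509KeyPair pairs the leaf certificate with the private key and
	// checks that the two match, as implemented above.
	cert, err := tls.X509KeyPair(certPEM, keyPEM)
	if err != nil {
		log.Fatal(err)
	}
	_ = &tls.Config{Certificates: []tls.Certificate{cert}}
}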
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package x509
import (
"bytes"
"crypto/sha256"
"encoding/pem"
"sync"
)
type sum224 [sha256.Size224]byte
// CertPool is a set of certificates.
type CertPool struct {
byName map[string][]int // cert.RawSubject => index into lazyCerts
// lazyCerts contains funcs that return a certificate,
// lazily parsing/decompressing it as needed.
lazyCerts []lazyCert
// haveSum maps from sum224(cert.Raw) to true. It's used only
// for AddCert duplicate detection, to avoid CertPool.contains
// calls in the AddCert path (because the contains method can
// call getCert and otherwise negate savings from lazy getCert
// funcs).
haveSum map[sum224]bool
// systemPool indicates whether this is a special pool derived from the
// system roots. If it includes additional roots, it requires doing two
// verifications, one using the roots provided by the caller, and one using
// the system platform verifier.
systemPool bool
}
// lazyCert is minimal metadata about a Cert and a func to retrieve it
// in its normal expanded *Certificate form.
type lazyCert struct {
// rawSubject is the Certificate.RawSubject value.
// It's the same as the CertPool.byName key, but in []byte
// form to make CertPool.Subjects (as used by crypto/tls) do
// fewer allocations.
rawSubject []byte
// getCert returns the certificate.
//
// It is not meant to do network operations or anything else
// where a failure is likely; the func is meant to lazily
// parse/decompress data that is already known to be good. The
// error in the signature is primarily meant for the case where a
// cert file that existed on local disk when the program started up
// is deleted later, before it's read.
getCert func() (*Certificate, error)
}
// NewCertPool returns a new, empty CertPool.
func NewCertPool() *CertPool {
return &CertPool{
byName: make(map[string][]int),
haveSum: make(map[sum224]bool),
}
}
// len returns the number of certs in the set.
// A nil set is a valid empty set.
func (s *CertPool) len() int {
if s == nil {
return 0
}
return len(s.lazyCerts)
}
// cert returns cert index n in s.
func (s *CertPool) cert(n int) (*Certificate, error) {
return s.lazyCerts[n].getCert()
}
// Clone returns a copy of s.
func (s *CertPool) Clone() *CertPool {
p := &CertPool{
byName: make(map[string][]int, len(s.byName)),
lazyCerts: make([]lazyCert, len(s.lazyCerts)),
haveSum: make(map[sum224]bool, len(s.haveSum)),
systemPool: s.systemPool,
}
for k, v := range s.byName {
indexes := make([]int, len(v))
copy(indexes, v)
p.byName[k] = indexes
}
for k := range s.haveSum {
p.haveSum[k] = true
}
copy(p.lazyCerts, s.lazyCerts)
return p
}
// SystemCertPool returns a copy of the system cert pool.
//
// On Unix systems other than macOS the environment variables SSL_CERT_FILE and
// SSL_CERT_DIR can be used to override the system default locations for the SSL
// certificate file and SSL certificate files directory, respectively. The
// latter can be a colon-separated list.
//
// Any mutations to the returned pool are not written to disk and do not affect
// any other pool returned by SystemCertPool.
//
// New changes in the system cert pool might not be reflected in subsequent calls.
func SystemCertPool() (*CertPool, error) {
if sysRoots := systemRootsPool(); sysRoots != nil {
return sysRoots.Clone(), nil
}
return loadSystemRoots()
}
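// A short sketch of SystemCertPool in use (package x509_test style):
// clone the system roots and extend the copy with a private CA. The
// imports "crypto/x509", "log", and "os" are assumed; the bundle path
// is hypothetical.
func ExampleSystemCertPool() {
	pool, err := x509.SystemCertPool()
	if err != nil {
		log.Fatal(err)
	}
	extra, err := os.ReadFile("internal-ca.pem") // hypothetical path
	if err != nil {
		log.Fatal(err)
	}
	if !pool.AppendCertsFromPEM(extra) {
		log.Fatal("no certificates appended")
	}
	// Mutating the returned copy does not affect other callers.
	_ = pool
}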
// findPotentialParents returns the certificates in s which might
// have signed cert.
func (s *CertPool) findPotentialParents(cert *Certificate) []*Certificate {
if s == nil {
return nil
}
// Consider all candidates where cert.Issuer matches the candidate's
// Subject. When picking possible candidates, the list is built in
// order of match plausibility so as to save cycles in buildChains:
// AKID and SKID match
// AKID present, SKID missing / AKID missing, SKID present
// AKID and SKID don't match
var matchingKeyID, oneKeyID, mismatchKeyID []*Certificate
for _, c := range s.byName[string(cert.RawIssuer)] {
candidate, err := s.cert(c)
if err != nil {
continue
}
kidMatch := bytes.Equal(candidate.SubjectKeyId, cert.AuthorityKeyId)
switch {
case kidMatch:
matchingKeyID = append(matchingKeyID, candidate)
case (len(candidate.SubjectKeyId) == 0 && len(cert.AuthorityKeyId) > 0) ||
(len(candidate.SubjectKeyId) > 0 && len(cert.AuthorityKeyId) == 0):
oneKeyID = append(oneKeyID, candidate)
default:
mismatchKeyID = append(mismatchKeyID, candidate)
}
}
found := len(matchingKeyID) + len(oneKeyID) + len(mismatchKeyID)
if found == 0 {
return nil
}
candidates := make([]*Certificate, 0, found)
candidates = append(candidates, matchingKeyID...)
candidates = append(candidates, oneKeyID...)
candidates = append(candidates, mismatchKeyID...)
return candidates
}
func (s *CertPool) contains(cert *Certificate) bool {
if s == nil {
return false
}
return s.haveSum[sha256.Sum224(cert.Raw)]
}
// AddCert adds a certificate to a pool.
func (s *CertPool) AddCert(cert *Certificate) {
if cert == nil {
panic("adding nil Certificate to CertPool")
}
s.addCertFunc(sha256.Sum224(cert.Raw), string(cert.RawSubject), func() (*Certificate, error) {
return cert, nil
})
}
// addCertFunc adds metadata about a certificate to a pool, along with
// a func to fetch that certificate later when needed.
//
// The rawSubject is Certificate.RawSubject and must be non-empty.
// The getCert func may be called 0 or more times.
func (s *CertPool) addCertFunc(rawSum224 sum224, rawSubject string, getCert func() (*Certificate, error)) {
if getCert == nil {
panic("getCert can't be nil")
}
// Check that the certificate isn't being added twice.
if s.haveSum[rawSum224] {
return
}
s.haveSum[rawSum224] = true
s.lazyCerts = append(s.lazyCerts, lazyCert{
rawSubject: []byte(rawSubject),
getCert: getCert,
})
s.byName[rawSubject] = append(s.byName[rawSubject], len(s.lazyCerts)-1)
}
// AppendCertsFromPEM attempts to parse a series of PEM encoded certificates.
// It appends any certificates found to s and reports whether any certificates
// were successfully parsed.
//
// On many Linux systems, /etc/ssl/cert.pem will contain the system wide set
// of root CAs in a format suitable for this function.
func (s *CertPool) AppendCertsFromPEM(pemCerts []byte) (ok bool) {
for len(pemCerts) > 0 {
var block *pem.Block
block, pemCerts = pem.Decode(pemCerts)
if block == nil {
break
}
if block.Type != "CERTIFICATE" || len(block.Headers) != 0 {
continue
}
certBytes := block.Bytes
cert, err := ParseCertificate(certBytes)
if err != nil {
continue
}
var lazyCert struct {
sync.Once
v *Certificate
}
s.addCertFunc(sha256.Sum224(cert.Raw), string(cert.RawSubject), func() (*Certificate, error) {
lazyCert.Do(func() {
// This can't fail, as the same bytes were already parsed above.
lazyCert.v, _ = ParseCertificate(certBytes)
certBytes = nil
})
return lazyCert.v, nil
})
ok = true
}
return ok
}
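// A sketch of the common pattern built on AppendCertsFromPEM (package
// x509_test style): load a PEM bundle into a fresh pool and use it as
// the root set for TLS verification. The imports "crypto/tls",
// "crypto/x509", "log", and "os" are assumed; the bundle path is one
// common location and may vary by system.
func ExampleCertPool_AppendCertsFromPEM() {
	pemBytes, err := os.ReadFile("/etc/ssl/cert.pem") // location varies
	if err != nil {
		log.Fatal(err)
	}
	roots := x509.NewCertPool()
	if !roots.AppendCertsFromPEM(pemBytes) {
		log.Fatal("failed to parse any root certificates")
	}
	_ = &tls.Config{RootCAs: roots}
}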
// Subjects returns a list of the DER-encoded subjects of
// all of the certificates in the pool.
//
// Deprecated: if s was returned by SystemCertPool, Subjects
// will not include the system roots.
func (s *CertPool) Subjects() [][]byte {
res := make([][]byte, s.len())
for i, lc := range s.lazyCerts {
res[i] = lc.rawSubject
}
return res
}
// Equal reports whether s and other are equal.
func (s *CertPool) Equal(other *CertPool) bool {
if s == nil || other == nil {
return s == other
}
if s.systemPool != other.systemPool || len(s.haveSum) != len(other.haveSum) {
return false
}
for h := range s.haveSum {
if !other.haveSum[h] {
return false
}
}
return true
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !boringcrypto
package x509
func boringAllowCert(c *Certificate) bool { return true }
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package x509
import (
"bytes"
"crypto/dsa"
"crypto/ecdh"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rsa"
"crypto/x509/pkix"
"encoding/asn1"
"errors"
"fmt"
"math/big"
"net"
"net/url"
"strconv"
"strings"
"time"
"unicode/utf16"
"unicode/utf8"
"golang.org/x/crypto/cryptobyte"
cryptobyte_asn1 "golang.org/x/crypto/cryptobyte/asn1"
)
// isPrintable reports whether the given b is in the ASN.1 PrintableString set.
// This is a simplified version of encoding/asn1.isPrintable.
func isPrintable(b byte) bool {
return 'a' <= b && b <= 'z' ||
'A' <= b && b <= 'Z' ||
'0' <= b && b <= '9' ||
'\'' <= b && b <= ')' ||
'+' <= b && b <= '/' ||
b == ' ' ||
b == ':' ||
b == '=' ||
b == '?' ||
// This is technically not allowed in a PrintableString.
// However, x509 certificates with wildcard strings don't
// always use the correct string type so we permit it.
b == '*' ||
// This is not technically allowed either. However, not
// only is it relatively common, but there are also a
// handful of CA certificates that contain it. At least
// one of which will not expire until 2027.
b == '&'
}
// parseASN1String parses the ASN.1 string types T61String, PrintableString,
// UTF8String, BMPString, IA5String, and NumericString. This is mostly copied
// from the respective encoding/asn1.parse... methods, rather than just
// increasing the API surface of that package.
func parseASN1String(tag cryptobyte_asn1.Tag, value []byte) (string, error) {
switch tag {
case cryptobyte_asn1.T61String:
return string(value), nil
case cryptobyte_asn1.PrintableString:
for _, b := range value {
if !isPrintable(b) {
return "", errors.New("invalid PrintableString")
}
}
return string(value), nil
case cryptobyte_asn1.UTF8String:
if !utf8.Valid(value) {
return "", errors.New("invalid UTF-8 string")
}
return string(value), nil
case cryptobyte_asn1.Tag(asn1.TagBMPString):
if len(value)%2 != 0 {
return "", errors.New("invalid BMPString")
}
// Strip terminator if present.
if l := len(value); l >= 2 && value[l-1] == 0 && value[l-2] == 0 {
value = value[:l-2]
}
s := make([]uint16, 0, len(value)/2)
for len(value) > 0 {
s = append(s, uint16(value[0])<<8+uint16(value[1]))
value = value[2:]
}
return string(utf16.Decode(s)), nil
case cryptobyte_asn1.IA5String:
s := string(value)
if isIA5String(s) != nil {
return "", errors.New("invalid IA5String")
}
return s, nil
case cryptobyte_asn1.Tag(asn1.TagNumericString):
for _, b := range value {
if !('0' <= b && b <= '9' || b == ' ') {
return "", errors.New("invalid NumericString")
}
}
return string(value), nil
}
return "", fmt.Errorf("unsupported string type: %v", tag)
}
// parseName parses a DER encoded Name as defined in RFC 5280. We may
// want to export this function in the future for use in crypto/tls.
func parseName(raw cryptobyte.String) (*pkix.RDNSequence, error) {
if !raw.ReadASN1(&raw, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: invalid RDNSequence")
}
var rdnSeq pkix.RDNSequence
for !raw.Empty() {
var rdnSet pkix.RelativeDistinguishedNameSET
var set cryptobyte.String
if !raw.ReadASN1(&set, cryptobyte_asn1.SET) {
return nil, errors.New("x509: invalid RDNSequence")
}
for !set.Empty() {
var atav cryptobyte.String
if !set.ReadASN1(&atav, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: invalid RDNSequence: invalid attribute")
}
var attr pkix.AttributeTypeAndValue
if !atav.ReadASN1ObjectIdentifier(&attr.Type) {
return nil, errors.New("x509: invalid RDNSequence: invalid attribute type")
}
var rawValue cryptobyte.String
var valueTag cryptobyte_asn1.Tag
if !atav.ReadAnyASN1(&rawValue, &valueTag) {
return nil, errors.New("x509: invalid RDNSequence: invalid attribute value")
}
var err error
attr.Value, err = parseASN1String(valueTag, rawValue)
if err != nil {
return nil, fmt.Errorf("x509: invalid RDNSequence: invalid attribute value: %s", err)
}
rdnSet = append(rdnSet, attr)
}
rdnSeq = append(rdnSeq, rdnSet)
}
return &rdnSeq, nil
}
func parseAI(der cryptobyte.String) (pkix.AlgorithmIdentifier, error) {
ai := pkix.AlgorithmIdentifier{}
if !der.ReadASN1ObjectIdentifier(&ai.Algorithm) {
return ai, errors.New("x509: malformed OID")
}
if der.Empty() {
return ai, nil
}
var params cryptobyte.String
var tag cryptobyte_asn1.Tag
if !der.ReadAnyASN1Element(&params, &tag) {
return ai, errors.New("x509: malformed parameters")
}
ai.Parameters.Tag = int(tag)
ai.Parameters.FullBytes = params
return ai, nil
}
func parseTime(der *cryptobyte.String) (time.Time, error) {
var t time.Time
switch {
case der.PeekASN1Tag(cryptobyte_asn1.UTCTime):
if !der.ReadASN1UTCTime(&t) {
return t, errors.New("x509: malformed UTCTime")
}
case der.PeekASN1Tag(cryptobyte_asn1.GeneralizedTime):
if !der.ReadASN1GeneralizedTime(&t) {
return t, errors.New("x509: malformed GeneralizedTime")
}
default:
return t, errors.New("x509: unsupported time format")
}
return t, nil
}
func parseValidity(der cryptobyte.String) (time.Time, time.Time, error) {
notBefore, err := parseTime(&der)
if err != nil {
return time.Time{}, time.Time{}, err
}
notAfter, err := parseTime(&der)
if err != nil {
return time.Time{}, time.Time{}, err
}
return notBefore, notAfter, nil
}
func parseExtension(der cryptobyte.String) (pkix.Extension, error) {
var ext pkix.Extension
if !der.ReadASN1ObjectIdentifier(&ext.Id) {
return ext, errors.New("x509: malformed extension OID field")
}
if der.PeekASN1Tag(cryptobyte_asn1.BOOLEAN) {
if !der.ReadASN1Boolean(&ext.Critical) {
return ext, errors.New("x509: malformed extension critical field")
}
}
var val cryptobyte.String
if !der.ReadASN1(&val, cryptobyte_asn1.OCTET_STRING) {
return ext, errors.New("x509: malformed extension value field")
}
ext.Value = val
return ext, nil
}
func parsePublicKey(keyData *publicKeyInfo) (any, error) {
oid := keyData.Algorithm.Algorithm
params := keyData.Algorithm.Parameters
der := cryptobyte.String(keyData.PublicKey.RightAlign())
switch {
case oid.Equal(oidPublicKeyRSA):
// RSA public keys must have a NULL in the parameters.
// See RFC 3279, Section 2.3.1.
if !bytes.Equal(params.FullBytes, asn1.NullBytes) {
return nil, errors.New("x509: RSA key missing NULL parameters")
}
p := &pkcs1PublicKey{N: new(big.Int)}
if !der.ReadASN1(&der, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: invalid RSA public key")
}
if !der.ReadASN1Integer(p.N) {
return nil, errors.New("x509: invalid RSA modulus")
}
if !der.ReadASN1Integer(&p.E) {
return nil, errors.New("x509: invalid RSA public exponent")
}
if p.N.Sign() <= 0 {
return nil, errors.New("x509: RSA modulus is not a positive number")
}
if p.E <= 0 {
return nil, errors.New("x509: RSA public exponent is not a positive number")
}
pub := &rsa.PublicKey{
E: p.E,
N: p.N,
}
return pub, nil
case oid.Equal(oidPublicKeyECDSA):
paramsDer := cryptobyte.String(params.FullBytes)
namedCurveOID := new(asn1.ObjectIdentifier)
if !paramsDer.ReadASN1ObjectIdentifier(namedCurveOID) {
return nil, errors.New("x509: invalid ECDSA parameters")
}
namedCurve := namedCurveFromOID(*namedCurveOID)
if namedCurve == nil {
return nil, errors.New("x509: unsupported elliptic curve")
}
x, y := elliptic.Unmarshal(namedCurve, der)
if x == nil {
return nil, errors.New("x509: failed to unmarshal elliptic curve point")
}
pub := &ecdsa.PublicKey{
Curve: namedCurve,
X: x,
Y: y,
}
return pub, nil
case oid.Equal(oidPublicKeyEd25519):
// RFC 8410, Section 3
// > For all of the OIDs, the parameters MUST be absent.
if len(params.FullBytes) != 0 {
return nil, errors.New("x509: Ed25519 key encoded with illegal parameters")
}
if len(der) != ed25519.PublicKeySize {
return nil, errors.New("x509: wrong Ed25519 public key size")
}
return ed25519.PublicKey(der), nil
case oid.Equal(oidPublicKeyX25519):
// RFC 8410, Section 3
// > For all of the OIDs, the parameters MUST be absent.
if len(params.FullBytes) != 0 {
return nil, errors.New("x509: X25519 key encoded with illegal parameters")
}
return ecdh.X25519().NewPublicKey(der)
case oid.Equal(oidPublicKeyDSA):
y := new(big.Int)
if !der.ReadASN1Integer(y) {
return nil, errors.New("x509: invalid DSA public key")
}
pub := &dsa.PublicKey{
Y: y,
Parameters: dsa.Parameters{
P: new(big.Int),
Q: new(big.Int),
G: new(big.Int),
},
}
paramsDer := cryptobyte.String(params.FullBytes)
if !paramsDer.ReadASN1(&paramsDer, cryptobyte_asn1.SEQUENCE) ||
!paramsDer.ReadASN1Integer(pub.Parameters.P) ||
!paramsDer.ReadASN1Integer(pub.Parameters.Q) ||
!paramsDer.ReadASN1Integer(pub.Parameters.G) {
return nil, errors.New("x509: invalid DSA parameters")
}
if pub.Y.Sign() <= 0 || pub.Parameters.P.Sign() <= 0 ||
pub.Parameters.Q.Sign() <= 0 || pub.Parameters.G.Sign() <= 0 {
return nil, errors.New("x509: zero or negative DSA parameter")
}
return pub, nil
default:
return nil, errors.New("x509: unknown public key algorithm")
}
}
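// parsePublicKey is unexported; its exported counterpart is
// ParsePKIXPublicKey. A sketch of the usual type switch over its result
// (package x509_test style); imports "crypto/ecdsa", "crypto/ed25519",
// "crypto/rsa", "crypto/x509", "encoding/pem", "fmt", and "log" are
// assumed, and publicKeyPEM is a hypothetical PEM-encoded input.
func ExampleParsePKIXPublicKey() {
	block, _ := pem.Decode(publicKeyPEM) // hypothetical []byte
	if block == nil || block.Type != "PUBLIC KEY" {
		log.Fatal("failed to decode PEM block containing public key")
	}
	pub, err := x509.ParsePKIXPublicKey(block.Bytes)
	if err != nil {
		log.Fatal(err)
	}
	switch pub := pub.(type) {
	case *rsa.PublicKey:
		fmt.Println("RSA modulus bits:", pub.N.BitLen())
	case *ecdsa.PublicKey:
		fmt.Println("ECDSA curve:", pub.Curve.Params().Name)
	case ed25519.PublicKey:
		fmt.Println("Ed25519 key, length", len(pub))
	}
}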
func parseKeyUsageExtension(der cryptobyte.String) (KeyUsage, error) {
var usageBits asn1.BitString
if !der.ReadASN1BitString(&usageBits) {
return 0, errors.New("x509: invalid key usage")
}
var usage int
for i := 0; i < 9; i++ {
if usageBits.At(i) != 0 {
usage |= 1 << uint(i)
}
}
return KeyUsage(usage), nil
}
func parseBasicConstraintsExtension(der cryptobyte.String) (bool, int, error) {
var isCA bool
if !der.ReadASN1(&der, cryptobyte_asn1.SEQUENCE) {
return false, 0, errors.New("x509: invalid basic constraints a")
}
if der.PeekASN1Tag(cryptobyte_asn1.BOOLEAN) {
if !der.ReadASN1Boolean(&isCA) {
return false, 0, errors.New("x509: invalid basic constraints b")
}
}
maxPathLen := -1
if !der.Empty() && der.PeekASN1Tag(cryptobyte_asn1.INTEGER) {
if !der.ReadASN1Integer(&maxPathLen) {
return false, 0, errors.New("x509: invalid basic constraints c")
}
}
// TODO: map out.MaxPathLen to 0 if it has the -1 default value? (Issue 19285)
return isCA, maxPathLen, nil
}
func forEachSAN(der cryptobyte.String, callback func(tag int, data []byte) error) error {
if !der.ReadASN1(&der, cryptobyte_asn1.SEQUENCE) {
return errors.New("x509: invalid subject alternative names")
}
for !der.Empty() {
var san cryptobyte.String
var tag cryptobyte_asn1.Tag
if !der.ReadAnyASN1(&san, &tag) {
return errors.New("x509: invalid subject alternative name")
}
if err := callback(int(tag^0x80), san); err != nil {
return err
}
}
return nil
}
func parseSANExtension(der cryptobyte.String) (dnsNames, emailAddresses []string, ipAddresses []net.IP, uris []*url.URL, err error) {
err = forEachSAN(der, func(tag int, data []byte) error {
switch tag {
case nameTypeEmail:
email := string(data)
if err := isIA5String(email); err != nil {
return errors.New("x509: SAN rfc822Name is malformed")
}
emailAddresses = append(emailAddresses, email)
case nameTypeDNS:
name := string(data)
if err := isIA5String(name); err != nil {
return errors.New("x509: SAN dNSName is malformed")
}
dnsNames = append(dnsNames, name)
case nameTypeURI:
uriStr := string(data)
if err := isIA5String(uriStr); err != nil {
return errors.New("x509: SAN uniformResourceIdentifier is malformed")
}
uri, err := url.Parse(uriStr)
if err != nil {
return fmt.Errorf("x509: cannot parse URI %q: %s", uriStr, err)
}
if len(uri.Host) > 0 {
if _, ok := domainToReverseLabels(uri.Host); !ok {
return fmt.Errorf("x509: cannot parse URI %q: invalid domain", uriStr)
}
}
uris = append(uris, uri)
case nameTypeIP:
switch len(data) {
case net.IPv4len, net.IPv6len:
ipAddresses = append(ipAddresses, data)
default:
return errors.New("x509: cannot parse IP address of length " + strconv.Itoa(len(data)))
}
}
return nil
})
return
}
func parseExtKeyUsageExtension(der cryptobyte.String) ([]ExtKeyUsage, []asn1.ObjectIdentifier, error) {
var extKeyUsages []ExtKeyUsage
var unknownUsages []asn1.ObjectIdentifier
if !der.ReadASN1(&der, cryptobyte_asn1.SEQUENCE) {
return nil, nil, errors.New("x509: invalid extended key usages")
}
for !der.Empty() {
var eku asn1.ObjectIdentifier
if !der.ReadASN1ObjectIdentifier(&eku) {
return nil, nil, errors.New("x509: invalid extended key usages")
}
if extKeyUsage, ok := extKeyUsageFromOID(eku); ok {
extKeyUsages = append(extKeyUsages, extKeyUsage)
} else {
unknownUsages = append(unknownUsages, eku)
}
}
return extKeyUsages, unknownUsages, nil
}
func parseCertificatePoliciesExtension(der cryptobyte.String) ([]asn1.ObjectIdentifier, error) {
var oids []asn1.ObjectIdentifier
if !der.ReadASN1(&der, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: invalid certificate policies")
}
for !der.Empty() {
var cp cryptobyte.String
if !der.ReadASN1(&cp, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: invalid certificate policies")
}
var oid asn1.ObjectIdentifier
if !cp.ReadASN1ObjectIdentifier(&oid) {
return nil, errors.New("x509: invalid certificate policies")
}
oids = append(oids, oid)
}
return oids, nil
}
// isValidIPMask reports whether mask consists of zero or more 1 bits, followed by zero bits.
func isValidIPMask(mask []byte) bool {
seenZero := false
for _, b := range mask {
if seenZero {
if b != 0 {
return false
}
continue
}
switch b {
case 0x00, 0x80, 0xc0, 0xe0, 0xf0, 0xf8, 0xfc, 0xfe:
seenZero = true
case 0xff:
default:
return false
}
}
return true
}
func parseNameConstraintsExtension(out *Certificate, e pkix.Extension) (unhandled bool, err error) {
// RFC 5280, 4.2.1.10
// NameConstraints ::= SEQUENCE {
// permittedSubtrees [0] GeneralSubtrees OPTIONAL,
// excludedSubtrees [1] GeneralSubtrees OPTIONAL }
//
// GeneralSubtrees ::= SEQUENCE SIZE (1..MAX) OF GeneralSubtree
//
// GeneralSubtree ::= SEQUENCE {
// base GeneralName,
// minimum [0] BaseDistance DEFAULT 0,
// maximum [1] BaseDistance OPTIONAL }
//
// BaseDistance ::= INTEGER (0..MAX)
outer := cryptobyte.String(e.Value)
var toplevel, permitted, excluded cryptobyte.String
var havePermitted, haveExcluded bool
if !outer.ReadASN1(&toplevel, cryptobyte_asn1.SEQUENCE) ||
!outer.Empty() ||
!toplevel.ReadOptionalASN1(&permitted, &havePermitted, cryptobyte_asn1.Tag(0).ContextSpecific().Constructed()) ||
!toplevel.ReadOptionalASN1(&excluded, &haveExcluded, cryptobyte_asn1.Tag(1).ContextSpecific().Constructed()) ||
!toplevel.Empty() {
return false, errors.New("x509: invalid NameConstraints extension")
}
if !havePermitted && !haveExcluded || len(permitted) == 0 && len(excluded) == 0 {
// From RFC 5280, Section 4.2.1.10:
// “either the permittedSubtrees field
// or the excludedSubtrees MUST be
// present”
return false, errors.New("x509: empty name constraints extension")
}
getValues := func(subtrees cryptobyte.String) (dnsNames []string, ips []*net.IPNet, emails, uriDomains []string, err error) {
for !subtrees.Empty() {
var seq, value cryptobyte.String
var tag cryptobyte_asn1.Tag
if !subtrees.ReadASN1(&seq, cryptobyte_asn1.SEQUENCE) ||
!seq.ReadAnyASN1(&value, &tag) {
return nil, nil, nil, nil, fmt.Errorf("x509: invalid NameConstraints extension")
}
var (
dnsTag = cryptobyte_asn1.Tag(2).ContextSpecific()
emailTag = cryptobyte_asn1.Tag(1).ContextSpecific()
ipTag = cryptobyte_asn1.Tag(7).ContextSpecific()
uriTag = cryptobyte_asn1.Tag(6).ContextSpecific()
)
switch tag {
case dnsTag:
domain := string(value)
if err := isIA5String(domain); err != nil {
return nil, nil, nil, nil, errors.New("x509: invalid constraint value: " + err.Error())
}
trimmedDomain := domain
if len(trimmedDomain) > 0 && trimmedDomain[0] == '.' {
// constraints can have a leading
// period to exclude the domain
// itself, but that's not valid in a
// normal domain name.
trimmedDomain = trimmedDomain[1:]
}
if _, ok := domainToReverseLabels(trimmedDomain); !ok {
return nil, nil, nil, nil, fmt.Errorf("x509: failed to parse dnsName constraint %q", domain)
}
dnsNames = append(dnsNames, domain)
case ipTag:
l := len(value)
var ip, mask []byte
switch l {
case 8:
ip = value[:4]
mask = value[4:]
case 32:
ip = value[:16]
mask = value[16:]
default:
return nil, nil, nil, nil, fmt.Errorf("x509: IP constraint contained value of length %d", l)
}
if !isValidIPMask(mask) {
return nil, nil, nil, nil, fmt.Errorf("x509: IP constraint contained invalid mask %x", mask)
}
ips = append(ips, &net.IPNet{IP: net.IP(ip), Mask: net.IPMask(mask)})
case emailTag:
constraint := string(value)
if err := isIA5String(constraint); err != nil {
return nil, nil, nil, nil, errors.New("x509: invalid constraint value: " + err.Error())
}
// If the constraint contains an @ then
// it specifies an exact mailbox name.
if strings.Contains(constraint, "@") {
if _, ok := parseRFC2821Mailbox(constraint); !ok {
return nil, nil, nil, nil, fmt.Errorf("x509: failed to parse rfc822Name constraint %q", constraint)
}
} else {
// Otherwise it's a domain name.
domain := constraint
if len(domain) > 0 && domain[0] == '.' {
domain = domain[1:]
}
if _, ok := domainToReverseLabels(domain); !ok {
return nil, nil, nil, nil, fmt.Errorf("x509: failed to parse rfc822Name constraint %q", constraint)
}
}
emails = append(emails, constraint)
case uriTag:
domain := string(value)
if err := isIA5String(domain); err != nil {
return nil, nil, nil, nil, errors.New("x509: invalid constraint value: " + err.Error())
}
if net.ParseIP(domain) != nil {
return nil, nil, nil, nil, fmt.Errorf("x509: failed to parse URI constraint %q: cannot be IP address", domain)
}
trimmedDomain := domain
if len(trimmedDomain) > 0 && trimmedDomain[0] == '.' {
// constraints can have a leading
// period to exclude the domain itself,
// but that's not valid in a normal
// domain name.
trimmedDomain = trimmedDomain[1:]
}
if _, ok := domainToReverseLabels(trimmedDomain); !ok {
return nil, nil, nil, nil, fmt.Errorf("x509: failed to parse URI constraint %q", domain)
}
uriDomains = append(uriDomains, domain)
default:
unhandled = true
}
}
return dnsNames, ips, emails, uriDomains, nil
}
if out.PermittedDNSDomains, out.PermittedIPRanges, out.PermittedEmailAddresses, out.PermittedURIDomains, err = getValues(permitted); err != nil {
return false, err
}
if out.ExcludedDNSDomains, out.ExcludedIPRanges, out.ExcludedEmailAddresses, out.ExcludedURIDomains, err = getValues(excluded); err != nil {
return false, err
}
out.PermittedDNSDomainsCritical = e.Critical
return unhandled, nil
}
func processExtensions(out *Certificate) error {
var err error
for _, e := range out.Extensions {
unhandled := false
if len(e.Id) == 4 && e.Id[0] == 2 && e.Id[1] == 5 && e.Id[2] == 29 {
switch e.Id[3] {
case 15:
out.KeyUsage, err = parseKeyUsageExtension(e.Value)
if err != nil {
return err
}
case 19:
out.IsCA, out.MaxPathLen, err = parseBasicConstraintsExtension(e.Value)
if err != nil {
return err
}
out.BasicConstraintsValid = true
out.MaxPathLenZero = out.MaxPathLen == 0
case 17:
out.DNSNames, out.EmailAddresses, out.IPAddresses, out.URIs, err = parseSANExtension(e.Value)
if err != nil {
return err
}
if len(out.DNSNames) == 0 && len(out.EmailAddresses) == 0 && len(out.IPAddresses) == 0 && len(out.URIs) == 0 {
// If we didn't parse anything then we do the critical check, below.
unhandled = true
}
case 30:
unhandled, err = parseNameConstraintsExtension(out, e)
if err != nil {
return err
}
case 31:
// RFC 5280, 4.2.1.13
// CRLDistributionPoints ::= SEQUENCE SIZE (1..MAX) OF DistributionPoint
//
// DistributionPoint ::= SEQUENCE {
// distributionPoint [0] DistributionPointName OPTIONAL,
// reasons [1] ReasonFlags OPTIONAL,
// cRLIssuer [2] GeneralNames OPTIONAL }
//
// DistributionPointName ::= CHOICE {
// fullName [0] GeneralNames,
// nameRelativeToCRLIssuer [1] RelativeDistinguishedName }
val := cryptobyte.String(e.Value)
if !val.ReadASN1(&val, cryptobyte_asn1.SEQUENCE) {
return errors.New("x509: invalid CRL distribution points")
}
for !val.Empty() {
var dpDER cryptobyte.String
if !val.ReadASN1(&dpDER, cryptobyte_asn1.SEQUENCE) {
return errors.New("x509: invalid CRL distribution point")
}
var dpNameDER cryptobyte.String
var dpNamePresent bool
if !dpDER.ReadOptionalASN1(&dpNameDER, &dpNamePresent, cryptobyte_asn1.Tag(0).Constructed().ContextSpecific()) {
return errors.New("x509: invalid CRL distribution point")
}
if !dpNamePresent {
continue
}
if !dpNameDER.ReadASN1(&dpNameDER, cryptobyte_asn1.Tag(0).Constructed().ContextSpecific()) {
return errors.New("x509: invalid CRL distribution point")
}
for !dpNameDER.Empty() {
if !dpNameDER.PeekASN1Tag(cryptobyte_asn1.Tag(6).ContextSpecific()) {
break
}
var uri cryptobyte.String
if !dpNameDER.ReadASN1(&uri, cryptobyte_asn1.Tag(6).ContextSpecific()) {
return errors.New("x509: invalid CRL distribution point")
}
out.CRLDistributionPoints = append(out.CRLDistributionPoints, string(uri))
}
}
case 35:
// RFC 5280, 4.2.1.1
val := cryptobyte.String(e.Value)
var akid cryptobyte.String
if !val.ReadASN1(&akid, cryptobyte_asn1.SEQUENCE) {
return errors.New("x509: invalid authority key identifier")
}
if akid.PeekASN1Tag(cryptobyte_asn1.Tag(0).ContextSpecific()) {
if !akid.ReadASN1(&akid, cryptobyte_asn1.Tag(0).ContextSpecific()) {
return errors.New("x509: invalid authority key identifier")
}
out.AuthorityKeyId = akid
}
case 37:
out.ExtKeyUsage, out.UnknownExtKeyUsage, err = parseExtKeyUsageExtension(e.Value)
if err != nil {
return err
}
case 14:
// RFC 5280, 4.2.1.2
val := cryptobyte.String(e.Value)
var skid cryptobyte.String
if !val.ReadASN1(&skid, cryptobyte_asn1.OCTET_STRING) {
return errors.New("x509: invalid subject key identifier")
}
out.SubjectKeyId = skid
case 32:
out.PolicyIdentifiers, err = parseCertificatePoliciesExtension(e.Value)
if err != nil {
return err
}
default:
// Unknown extensions are recorded if critical.
unhandled = true
}
} else if e.Id.Equal(oidExtensionAuthorityInfoAccess) {
// RFC 5280 4.2.2.1: Authority Information Access
val := cryptobyte.String(e.Value)
if !val.ReadASN1(&val, cryptobyte_asn1.SEQUENCE) {
return errors.New("x509: invalid authority info access")
}
for !val.Empty() {
var aiaDER cryptobyte.String
if !val.ReadASN1(&aiaDER, cryptobyte_asn1.SEQUENCE) {
return errors.New("x509: invalid authority info access")
}
var method asn1.ObjectIdentifier
if !aiaDER.ReadASN1ObjectIdentifier(&method) {
return errors.New("x509: invalid authority info access")
}
if !aiaDER.PeekASN1Tag(cryptobyte_asn1.Tag(6).ContextSpecific()) {
continue
}
if !aiaDER.ReadASN1(&aiaDER, cryptobyte_asn1.Tag(6).ContextSpecific()) {
return errors.New("x509: invalid authority info access")
}
switch {
case method.Equal(oidAuthorityInfoAccessOcsp):
out.OCSPServer = append(out.OCSPServer, string(aiaDER))
case method.Equal(oidAuthorityInfoAccessIssuers):
out.IssuingCertificateURL = append(out.IssuingCertificateURL, string(aiaDER))
}
}
} else {
// Unknown extensions are recorded if critical.
unhandled = true
}
if e.Critical && unhandled {
out.UnhandledCriticalExtensions = append(out.UnhandledCriticalExtensions, e.Id)
}
}
return nil
}
func parseCertificate(der []byte) (*Certificate, error) {
cert := &Certificate{}
input := cryptobyte.String(der)
// we read the SEQUENCE including length and tag bytes so that
// we can populate Certificate.Raw, before unwrapping the
// SEQUENCE so it can be operated on
if !input.ReadASN1Element(&input, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed certificate")
}
cert.Raw = input
if !input.ReadASN1(&input, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed certificate")
}
var tbs cryptobyte.String
// do the same trick again as above to extract the raw
// bytes for Certificate.RawTBSCertificate
if !input.ReadASN1Element(&tbs, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed tbs certificate")
}
cert.RawTBSCertificate = tbs
if !tbs.ReadASN1(&tbs, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed tbs certificate")
}
if !tbs.ReadOptionalASN1Integer(&cert.Version, cryptobyte_asn1.Tag(0).Constructed().ContextSpecific(), 0) {
return nil, errors.New("x509: malformed version")
}
if cert.Version < 0 {
return nil, errors.New("x509: malformed version")
}
// for backwards compatibility reasons Version is one-indexed,
// rather than zero-indexed as defined in RFC 5280
cert.Version++
if cert.Version > 3 {
return nil, errors.New("x509: invalid version")
}
serial := new(big.Int)
if !tbs.ReadASN1Integer(serial) {
return nil, errors.New("x509: malformed serial number")
}
// we ignore the presence of negative serial numbers because
// of their prevalence, despite them being invalid
// TODO(rolandshoemaker): revisit this decision, there are currently
// only 10 trusted certificates with negative serial numbers
// according to censys.io.
cert.SerialNumber = serial
var sigAISeq cryptobyte.String
if !tbs.ReadASN1(&sigAISeq, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed signature algorithm identifier")
}
// Before parsing the inner algorithm identifier, extract
// the outer algorithm identifier and make sure that they
// match.
var outerSigAISeq cryptobyte.String
if !input.ReadASN1(&outerSigAISeq, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed algorithm identifier")
}
if !bytes.Equal(outerSigAISeq, sigAISeq) {
return nil, errors.New("x509: inner and outer signature algorithm identifiers don't match")
}
sigAI, err := parseAI(sigAISeq)
if err != nil {
return nil, err
}
cert.SignatureAlgorithm = getSignatureAlgorithmFromAI(sigAI)
var issuerSeq cryptobyte.String
if !tbs.ReadASN1Element(&issuerSeq, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed issuer")
}
cert.RawIssuer = issuerSeq
issuerRDNs, err := parseName(issuerSeq)
if err != nil {
return nil, err
}
cert.Issuer.FillFromRDNSequence(issuerRDNs)
var validity cryptobyte.String
if !tbs.ReadASN1(&validity, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed validity")
}
cert.NotBefore, cert.NotAfter, err = parseValidity(validity)
if err != nil {
return nil, err
}
var subjectSeq cryptobyte.String
if !tbs.ReadASN1Element(&subjectSeq, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed issuer")
}
cert.RawSubject = subjectSeq
subjectRDNs, err := parseName(subjectSeq)
if err != nil {
return nil, err
}
cert.Subject.FillFromRDNSequence(subjectRDNs)
var spki cryptobyte.String
if !tbs.ReadASN1Element(&spki, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed spki")
}
cert.RawSubjectPublicKeyInfo = spki
if !spki.ReadASN1(&spki, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed spki")
}
var pkAISeq cryptobyte.String
if !spki.ReadASN1(&pkAISeq, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed public key algorithm identifier")
}
pkAI, err := parseAI(pkAISeq)
if err != nil {
return nil, err
}
cert.PublicKeyAlgorithm = getPublicKeyAlgorithmFromOID(pkAI.Algorithm)
var spk asn1.BitString
if !spki.ReadASN1BitString(&spk) {
return nil, errors.New("x509: malformed subjectPublicKey")
}
if cert.PublicKeyAlgorithm != UnknownPublicKeyAlgorithm {
cert.PublicKey, err = parsePublicKey(&publicKeyInfo{
Algorithm: pkAI,
PublicKey: spk,
})
if err != nil {
return nil, err
}
}
if cert.Version > 1 {
if !tbs.SkipOptionalASN1(cryptobyte_asn1.Tag(1).ContextSpecific()) {
return nil, errors.New("x509: malformed issuerUniqueID")
}
if !tbs.SkipOptionalASN1(cryptobyte_asn1.Tag(2).ContextSpecific()) {
return nil, errors.New("x509: malformed subjectUniqueID")
}
if cert.Version == 3 {
var extensions cryptobyte.String
var present bool
if !tbs.ReadOptionalASN1(&extensions, &present, cryptobyte_asn1.Tag(3).Constructed().ContextSpecific()) {
return nil, errors.New("x509: malformed extensions")
}
if present {
seenExts := make(map[string]bool)
if !extensions.ReadASN1(&extensions, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed extensions")
}
for !extensions.Empty() {
var extension cryptobyte.String
if !extensions.ReadASN1(&extension, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed extension")
}
ext, err := parseExtension(extension)
if err != nil {
return nil, err
}
oidStr := ext.Id.String()
if seenExts[oidStr] {
return nil, errors.New("x509: certificate contains duplicate extensions")
}
seenExts[oidStr] = true
cert.Extensions = append(cert.Extensions, ext)
}
err = processExtensions(cert)
if err != nil {
return nil, err
}
}
}
}
var signature asn1.BitString
if !input.ReadASN1BitString(&signature) {
return nil, errors.New("x509: malformed signature")
}
cert.Signature = signature.RightAlign()
return cert, nil
}
// ParseCertificate parses a single certificate from the given ASN.1 DER data.
func ParseCertificate(der []byte) (*Certificate, error) {
cert, err := parseCertificate(der)
if err != nil {
return nil, err
}
if len(der) != len(cert.Raw) {
return nil, errors.New("x509: trailing data")
}
return cert, err
}
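// A sketch of ParseCertificate on PEM input (package x509_test style):
// strip the PEM armor first, then hand the DER bytes to the parser.
// Imports "crypto/x509", "encoding/pem", "fmt", and "log" are assumed;
// certPEM is a hypothetical PEM-encoded input.
func ExampleParseCertificate() {
	block, _ := pem.Decode(certPEM) // hypothetical []byte
	if block == nil || block.Type != "CERTIFICATE" {
		log.Fatal("failed to decode PEM block containing certificate")
	}
	cert, err := x509.ParseCertificate(block.Bytes)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(cert.Subject.CommonName, cert.NotBefore, cert.NotAfter)
}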
// ParseCertificates parses one or more certificates from the given ASN.1 DER
// data. The certificates must be concatenated with no intermediate padding.
func ParseCertificates(der []byte) ([]*Certificate, error) {
var certs []*Certificate
for len(der) > 0 {
cert, err := parseCertificate(der)
if err != nil {
return nil, err
}
certs = append(certs, cert)
der = der[len(cert.Raw):]
}
return certs, nil
}
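// ParseCertificates suits DER chains with no PEM armor, such as the raw
// chain a TLS peer presents. A sketch (package x509_test style), with
// imports "crypto/x509", "fmt", and "log" assumed; rawChain is a
// hypothetical [][]byte of DER certificates (for instance the
// Certificate field of a tls.Certificate).
func ExampleParseCertificates() {
	var der []byte
	for _, c := range rawChain { // hypothetical input
		der = append(der, c...)
	}
	certs, err := x509.ParseCertificates(der)
	if err != nil {
		log.Fatal(err)
	}
	for _, c := range certs {
		fmt.Println(c.Subject.CommonName)
	}
}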
// The X.509 standards confusingly 1-indexed the version names, but 0-indexed
// the actual encoded version, so the version for X.509v2 is 1.
const x509v2Version = 1
// ParseRevocationList parses an X.509 v2 Certificate Revocation List from the given
// ASN.1 DER data.
func ParseRevocationList(der []byte) (*RevocationList, error) {
rl := &RevocationList{}
input := cryptobyte.String(der)
// we read the SEQUENCE including length and tag bytes so that
// we can populate RevocationList.Raw, before unwrapping the
// SEQUENCE so it can be operated on
if !input.ReadASN1Element(&input, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed crl")
}
rl.Raw = input
if !input.ReadASN1(&input, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed crl")
}
var tbs cryptobyte.String
// do the same trick again as above to extract the raw
// bytes for RevocationList.RawTBSRevocationList
if !input.ReadASN1Element(&tbs, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed tbs crl")
}
rl.RawTBSRevocationList = tbs
if !tbs.ReadASN1(&tbs, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed tbs crl")
}
var version int
if !tbs.PeekASN1Tag(cryptobyte_asn1.INTEGER) {
return nil, errors.New("x509: unsupported crl version")
}
if !tbs.ReadASN1Integer(&version) {
return nil, errors.New("x509: malformed crl")
}
if version != x509v2Version {
return nil, fmt.Errorf("x509: unsupported crl version: %d", version)
}
var sigAISeq cryptobyte.String
if !tbs.ReadASN1(&sigAISeq, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed signature algorithm identifier")
}
// Before parsing the inner algorithm identifier, extract
// the outer algorithm identifier and make sure that they
// match.
var outerSigAISeq cryptobyte.String
if !input.ReadASN1(&outerSigAISeq, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed algorithm identifier")
}
if !bytes.Equal(outerSigAISeq, sigAISeq) {
return nil, errors.New("x509: inner and outer signature algorithm identifiers don't match")
}
sigAI, err := parseAI(sigAISeq)
if err != nil {
return nil, err
}
rl.SignatureAlgorithm = getSignatureAlgorithmFromAI(sigAI)
var signature asn1.BitString
if !input.ReadASN1BitString(&signature) {
return nil, errors.New("x509: malformed signature")
}
rl.Signature = signature.RightAlign()
var issuerSeq cryptobyte.String
if !tbs.ReadASN1Element(&issuerSeq, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed issuer")
}
rl.RawIssuer = issuerSeq
issuerRDNs, err := parseName(issuerSeq)
if err != nil {
return nil, err
}
rl.Issuer.FillFromRDNSequence(issuerRDNs)
rl.ThisUpdate, err = parseTime(&tbs)
if err != nil {
return nil, err
}
if tbs.PeekASN1Tag(cryptobyte_asn1.GeneralizedTime) || tbs.PeekASN1Tag(cryptobyte_asn1.UTCTime) {
rl.NextUpdate, err = parseTime(&tbs)
if err != nil {
return nil, err
}
}
if tbs.PeekASN1Tag(cryptobyte_asn1.SEQUENCE) {
var revokedSeq cryptobyte.String
if !tbs.ReadASN1(&revokedSeq, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed crl")
}
for !revokedSeq.Empty() {
var certSeq cryptobyte.String
if !revokedSeq.ReadASN1(&certSeq, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed crl")
}
rc := pkix.RevokedCertificate{}
rc.SerialNumber = new(big.Int)
if !certSeq.ReadASN1Integer(rc.SerialNumber) {
return nil, errors.New("x509: malformed serial number")
}
rc.RevocationTime, err = parseTime(&certSeq)
if err != nil {
return nil, err
}
var extensions cryptobyte.String
var present bool
if !certSeq.ReadOptionalASN1(&extensions, &present, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed extensions")
}
if present {
for !extensions.Empty() {
var extension cryptobyte.String
if !extensions.ReadASN1(&extension, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed extension")
}
ext, err := parseExtension(extension)
if err != nil {
return nil, err
}
rc.Extensions = append(rc.Extensions, ext)
}
}
rl.RevokedCertificates = append(rl.RevokedCertificates, rc)
}
}
var extensions cryptobyte.String
var present bool
if !tbs.ReadOptionalASN1(&extensions, &present, cryptobyte_asn1.Tag(0).Constructed().ContextSpecific()) {
return nil, errors.New("x509: malformed extensions")
}
if present {
if !extensions.ReadASN1(&extensions, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed extensions")
}
for !extensions.Empty() {
var extension cryptobyte.String
if !extensions.ReadASN1(&extension, cryptobyte_asn1.SEQUENCE) {
return nil, errors.New("x509: malformed extension")
}
ext, err := parseExtension(extension)
if err != nil {
return nil, err
}
if ext.Id.Equal(oidExtensionAuthorityKeyId) {
rl.AuthorityKeyId = ext.Value
} else if ext.Id.Equal(oidExtensionCRLNumber) {
value := cryptobyte.String(ext.Value)
rl.Number = new(big.Int)
if !value.ReadASN1Integer(rl.Number) {
return nil, errors.New("x509: malformed crl number")
}
}
rl.Extensions = append(rl.Extensions, ext)
}
}
return rl, nil
}
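// A sketch of ParseRevocationList on a PEM-armored CRL (package
// x509_test style). Imports "crypto/x509", "encoding/pem", "fmt", and
// "log" are assumed; crlPEM is a hypothetical PEM-encoded input.
func ExampleParseRevocationList() {
	block, _ := pem.Decode(crlPEM) // hypothetical []byte
	if block == nil || block.Type != "X509 CRL" {
		log.Fatal("failed to decode PEM block containing CRL")
	}
	rl, err := x509.ParseRevocationList(block.Bytes)
	if err != nil {
		log.Fatal(err)
	}
	for _, rc := range rl.RevokedCertificates {
		fmt.Println("revoked serial:", rc.SerialNumber)
	}
}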
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package x509
// RFC 1423 describes the encryption of PEM blocks. The algorithm used to
// generate a key from the password was derived by looking at the OpenSSL
// implementation.
import (
"crypto/aes"
"crypto/cipher"
"crypto/des"
"crypto/md5"
"encoding/hex"
"encoding/pem"
"errors"
"io"
"strings"
)
type PEMCipher int
// Possible values for the EncryptPEMBlock encryption algorithm.
const (
_ PEMCipher = iota
PEMCipherDES
PEMCipher3DES
PEMCipherAES128
PEMCipherAES192
PEMCipherAES256
)
// rfc1423Algo holds a method for enciphering a PEM block.
type rfc1423Algo struct {
cipher PEMCipher
name string
cipherFunc func(key []byte) (cipher.Block, error)
keySize int
blockSize int
}
// rfc1423Algos holds a slice of the possible ways to encrypt a PEM
// block. The ivSize numbers were taken from the OpenSSL source.
var rfc1423Algos = []rfc1423Algo{{
cipher: PEMCipherDES,
name: "DES-CBC",
cipherFunc: des.NewCipher,
keySize: 8,
blockSize: des.BlockSize,
}, {
cipher: PEMCipher3DES,
name: "DES-EDE3-CBC",
cipherFunc: des.NewTripleDESCipher,
keySize: 24,
blockSize: des.BlockSize,
}, {
cipher: PEMCipherAES128,
name: "AES-128-CBC",
cipherFunc: aes.NewCipher,
keySize: 16,
blockSize: aes.BlockSize,
}, {
cipher: PEMCipherAES192,
name: "AES-192-CBC",
cipherFunc: aes.NewCipher,
keySize: 24,
blockSize: aes.BlockSize,
}, {
cipher: PEMCipherAES256,
name: "AES-256-CBC",
cipherFunc: aes.NewCipher,
keySize: 32,
blockSize: aes.BlockSize,
},
}
// deriveKey uses a key derivation function to stretch the password into a key
// with the number of bits our cipher requires. This algorithm was derived from
// the OpenSSL source.
func (c rfc1423Algo) deriveKey(password, salt []byte) []byte {
hash := md5.New()
out := make([]byte, c.keySize)
var digest []byte
for i := 0; i < len(out); i += len(digest) {
hash.Reset()
hash.Write(digest)
hash.Write(password)
hash.Write(salt)
digest = hash.Sum(digest[:0])
copy(out[i:], digest)
}
return out
}
// IsEncryptedPEMBlock returns whether the PEM block is password encrypted
// according to RFC 1423.
//
// Deprecated: Legacy PEM encryption as specified in RFC 1423 is insecure by
// design. Since it does not authenticate the ciphertext, it is vulnerable to
// padding oracle attacks that can let an attacker recover the plaintext.
func IsEncryptedPEMBlock(b *pem.Block) bool {
_, ok := b.Headers["DEK-Info"]
return ok
}
// IncorrectPasswordError is returned when an incorrect password is detected.
var IncorrectPasswordError = errors.New("x509: decryption password incorrect")
// DecryptPEMBlock takes a PEM block encrypted according to RFC 1423 and the
// password used to encrypt it and returns a slice of decrypted DER encoded
// bytes. It inspects the DEK-Info header to determine the algorithm used for
// decryption. If no DEK-Info header is present, an error is returned. If an
// incorrect password is detected an IncorrectPasswordError is returned. Because
// of deficiencies in the format, it's not always possible to detect an
// incorrect password. In these cases no error will be returned but the
// decrypted DER bytes will be random noise.
//
// Deprecated: Legacy PEM encryption as specified in RFC 1423 is insecure by
// design. Since it does not authenticate the ciphertext, it is vulnerable to
// padding oracle attacks that can let an attacker recover the plaintext.
func DecryptPEMBlock(b *pem.Block, password []byte) ([]byte, error) {
dek, ok := b.Headers["DEK-Info"]
if !ok {
return nil, errors.New("x509: no DEK-Info header in block")
}
mode, hexIV, ok := strings.Cut(dek, ",")
if !ok {
return nil, errors.New("x509: malformed DEK-Info header")
}
ciph := cipherByName(mode)
if ciph == nil {
return nil, errors.New("x509: unknown encryption mode")
}
iv, err := hex.DecodeString(hexIV)
if err != nil {
return nil, err
}
if len(iv) != ciph.blockSize {
return nil, errors.New("x509: incorrect IV size")
}
// Based on the OpenSSL implementation. The salt is the first 8 bytes
// of the initialization vector.
key := ciph.deriveKey(password, iv[:8])
block, err := ciph.cipherFunc(key)
if err != nil {
return nil, err
}
if len(b.Bytes)%block.BlockSize() != 0 {
return nil, errors.New("x509: encrypted PEM data is not a multiple of the block size")
}
data := make([]byte, len(b.Bytes))
dec := cipher.NewCBCDecrypter(block, iv)
dec.CryptBlocks(data, b.Bytes)
// Blocks are padded using a scheme where the last n bytes of padding are all
// equal to n. It can pad from 1 to blocksize bytes inclusive. See RFC 1423.
// For example:
// [x y z 2 2]
// [x y 7 7 7 7 7 7 7]
// If we detect a bad padding, we assume it is an invalid password.
dlen := len(data)
if dlen == 0 || dlen%ciph.blockSize != 0 {
return nil, errors.New("x509: invalid padding")
}
last := int(data[dlen-1])
if dlen < last {
return nil, IncorrectPasswordError
}
if last == 0 || last > ciph.blockSize {
return nil, IncorrectPasswordError
}
for _, val := range data[dlen-last:] {
if int(val) != last {
return nil, IncorrectPasswordError
}
}
return data[:dlen-last], nil
}
// EncryptPEMBlock returns a PEM block of the specified type holding the
// given DER encoded data encrypted with the specified algorithm and
// password according to RFC 1423.
//
// Deprecated: Legacy PEM encryption as specified in RFC 1423 is insecure by
// design. Since it does not authenticate the ciphertext, it is vulnerable to
// padding oracle attacks that can let an attacker recover the plaintext.
func EncryptPEMBlock(rand io.Reader, blockType string, data, password []byte, alg PEMCipher) (*pem.Block, error) {
ciph := cipherByKey(alg)
if ciph == nil {
return nil, errors.New("x509: unknown encryption mode")
}
iv := make([]byte, ciph.blockSize)
if _, err := io.ReadFull(rand, iv); err != nil {
return nil, errors.New("x509: cannot generate IV: " + err.Error())
}
// The salt is the first 8 bytes of the initialization vector,
// matching the key derivation in DecryptPEMBlock.
key := ciph.deriveKey(password, iv[:8])
block, err := ciph.cipherFunc(key)
if err != nil {
return nil, err
}
enc := cipher.NewCBCEncrypter(block, iv)
pad := ciph.blockSize - len(data)%ciph.blockSize
encrypted := make([]byte, len(data), len(data)+pad)
// We could save this copy by encrypting all the whole blocks in
// the data separately, but it doesn't seem worth the additional
// code.
copy(encrypted, data)
// See RFC 1423, Section 1.1.
for i := 0; i < pad; i++ {
encrypted = append(encrypted, byte(pad))
}
enc.CryptBlocks(encrypted, encrypted)
return &pem.Block{
Type: blockType,
Headers: map[string]string{
"Proc-Type": "4,ENCRYPTED",
"DEK-Info": ciph.name + "," + hex.EncodeToString(iv),
},
Bytes: encrypted,
}, nil
}
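// A round-trip sketch of the deprecated RFC 1423 helpers above (package
// x509_test style), shown only to illustrate the headers they produce.
// Imports "bytes", "crypto/rand", "crypto/x509", and "log" are assumed;
// derBytes and password are hypothetical inputs.
func ExampleEncryptPEMBlock() {
	block, err := x509.EncryptPEMBlock(rand.Reader, "RSA PRIVATE KEY",
		derBytes, password, x509.PEMCipherAES256) // hypothetical inputs
	if err != nil {
		log.Fatal(err)
	}
	// block.Headers now carries "Proc-Type: 4,ENCRYPTED" and a DEK-Info
	// entry naming the cipher and IV, as built above.
	plain, err := x509.DecryptPEMBlock(block, password)
	if err != nil {
		log.Fatal(err)
	}
	if !bytes.Equal(plain, derBytes) {
		log.Fatal("round trip mismatch")
	}
}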
func cipherByName(name string) *rfc1423Algo {
for i := range rfc1423Algos {
alg := &rfc1423Algos[i]
if alg.name == name {
return alg
}
}
return nil
}
func cipherByKey(key PEMCipher) *rfc1423Algo {
for i := range rfc1423Algos {
alg := &rfc1423Algos[i]
if alg.cipher == key {
return alg
}
}
return nil
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package x509
import (
"crypto/rsa"
"encoding/asn1"
"errors"
"math/big"
)
// pkcs1PrivateKey is a structure which mirrors the PKCS #1 ASN.1 for an RSA private key.
type pkcs1PrivateKey struct {
Version int
N *big.Int
E int
D *big.Int
P *big.Int
Q *big.Int
// We ignore these values, if present, because rsa will calculate them.
Dp *big.Int `asn1:"optional"`
Dq *big.Int `asn1:"optional"`
Qinv *big.Int `asn1:"optional"`
AdditionalPrimes []pkcs1AdditionalRSAPrime `asn1:"optional,omitempty"`
}
type pkcs1AdditionalRSAPrime struct {
Prime *big.Int
// We ignore these values because rsa will calculate them.
Exp *big.Int
Coeff *big.Int
}
// pkcs1PublicKey reflects the ASN.1 structure of a PKCS #1 public key.
type pkcs1PublicKey struct {
N *big.Int
E int
}
// ParsePKCS1PrivateKey parses an RSA private key in PKCS #1, ASN.1 DER form.
//
// This kind of key is commonly encoded in PEM blocks of type "RSA PRIVATE KEY".
func ParsePKCS1PrivateKey(der []byte) (*rsa.PrivateKey, error) {
var priv pkcs1PrivateKey
rest, err := asn1.Unmarshal(der, &priv)
if len(rest) > 0 {
return nil, asn1.SyntaxError{Msg: "trailing data"}
}
if err != nil {
if _, err := asn1.Unmarshal(der, &ecPrivateKey{}); err == nil {
return nil, errors.New("x509: failed to parse private key (use ParseECPrivateKey instead for this key format)")
}
if _, err := asn1.Unmarshal(der, &pkcs8{}); err == nil {
return nil, errors.New("x509: failed to parse private key (use ParsePKCS8PrivateKey instead for this key format)")
}
return nil, err
}
if priv.Version > 1 {
return nil, errors.New("x509: unsupported private key version")
}
if priv.N.Sign() <= 0 || priv.D.Sign() <= 0 || priv.P.Sign() <= 0 || priv.Q.Sign() <= 0 {
return nil, errors.New("x509: private key contains zero or negative value")
}
key := new(rsa.PrivateKey)
key.PublicKey = rsa.PublicKey{
E: priv.E,
N: priv.N,
}
key.D = priv.D
key.Primes = make([]*big.Int, 2+len(priv.AdditionalPrimes))
key.Primes[0] = priv.P
key.Primes[1] = priv.Q
for i, a := range priv.AdditionalPrimes {
if a.Prime.Sign() <= 0 {
return nil, errors.New("x509: private key contains zero or negative prime")
}
key.Primes[i+2] = a.Prime
// We ignore the other two values because rsa will calculate
// them as needed.
}
err = key.Validate()
if err != nil {
return nil, err
}
key.Precompute()
return key, nil
}
// MarshalPKCS1PrivateKey converts an RSA private key to PKCS #1, ASN.1 DER form.
//
// This kind of key is commonly encoded in PEM blocks of type "RSA PRIVATE KEY".
// For a more flexible key format which is not RSA specific, use
// MarshalPKCS8PrivateKey.
func MarshalPKCS1PrivateKey(key *rsa.PrivateKey) []byte {
key.Precompute()
version := 0
if len(key.Primes) > 2 {
version = 1
}
priv := pkcs1PrivateKey{
Version: version,
N: key.N,
E: key.PublicKey.E,
D: key.D,
P: key.Primes[0],
Q: key.Primes[1],
Dp: key.Precomputed.Dp,
Dq: key.Precomputed.Dq,
Qinv: key.Precomputed.Qinv,
}
priv.AdditionalPrimes = make([]pkcs1AdditionalRSAPrime, len(key.Precomputed.CRTValues))
for i, values := range key.Precomputed.CRTValues {
priv.AdditionalPrimes[i].Prime = key.Primes[2+i]
priv.AdditionalPrimes[i].Exp = values.Exp
priv.AdditionalPrimes[i].Coeff = values.Coeff
}
b, _ := asn1.Marshal(priv)
return b
}
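// A round-trip sketch for the PKCS #1 codec above (package x509_test
// style). Imports "crypto/rand", "crypto/rsa", "crypto/x509", and "log"
// are assumed.
func ExampleMarshalPKCS1PrivateKey() {
	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		log.Fatal(err)
	}
	der := x509.MarshalPKCS1PrivateKey(key)
	parsed, err := x509.ParsePKCS1PrivateKey(der)
	if err != nil {
		log.Fatal(err)
	}
	if parsed.N.Cmp(key.N) != 0 {
		log.Fatal("round trip mismatch")
	}
}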
// ParsePKCS1PublicKey parses an RSA public key in PKCS #1, ASN.1 DER form.
//
// This kind of key is commonly encoded in PEM blocks of type "RSA PUBLIC KEY".
func ParsePKCS1PublicKey(der []byte) (*rsa.PublicKey, error) {
var pub pkcs1PublicKey
rest, err := asn1.Unmarshal(der, &pub)
if err != nil {
if _, err := asn1.Unmarshal(der, &publicKeyInfo{}); err == nil {
return nil, errors.New("x509: failed to parse public key (use ParsePKIXPublicKey instead for this key format)")
}
return nil, err
}
if len(rest) > 0 {
return nil, asn1.SyntaxError{Msg: "trailing data"}
}
if pub.N.Sign() <= 0 || pub.E <= 0 {
return nil, errors.New("x509: public key contains zero or negative value")
}
if pub.E > 1<<31-1 {
return nil, errors.New("x509: public key contains large public exponent")
}
return &rsa.PublicKey{
E: pub.E,
N: pub.N,
}, nil
}
// MarshalPKCS1PublicKey converts an RSA public key to PKCS #1, ASN.1 DER form.
//
// This kind of key is commonly encoded in PEM blocks of type "RSA PUBLIC KEY".
func MarshalPKCS1PublicKey(key *rsa.PublicKey) []byte {
derBytes, _ := asn1.Marshal(pkcs1PublicKey{
N: key.N,
E: key.E,
})
return derBytes
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package x509
import (
"crypto/ecdh"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/x509/pkix"
"encoding/asn1"
"errors"
"fmt"
)
// pkcs8 reflects an ASN.1, PKCS #8 PrivateKey. See
// ftp://ftp.rsasecurity.com/pub/pkcs/pkcs-8/pkcs-8v1_2.asn
// and RFC 5208.
type pkcs8 struct {
Version int
Algo pkix.AlgorithmIdentifier
PrivateKey []byte
// optional attributes omitted.
}
// ParsePKCS8PrivateKey parses an unencrypted private key in PKCS #8, ASN.1 DER form.
//
// It returns a *rsa.PrivateKey, an *ecdsa.PrivateKey, an ed25519.PrivateKey (not
// a pointer), or an *ecdh.PrivateKey (for X25519). More types might be supported
// in the future.
//
// This kind of key is commonly encoded in PEM blocks of type "PRIVATE KEY".
func ParsePKCS8PrivateKey(der []byte) (key any, err error) {
var privKey pkcs8
if _, err := asn1.Unmarshal(der, &privKey); err != nil {
if _, err := asn1.Unmarshal(der, &ecPrivateKey{}); err == nil {
return nil, errors.New("x509: failed to parse private key (use ParseECPrivateKey instead for this key format)")
}
if _, err := asn1.Unmarshal(der, &pkcs1PrivateKey{}); err == nil {
return nil, errors.New("x509: failed to parse private key (use ParsePKCS1PrivateKey instead for this key format)")
}
return nil, err
}
switch {
case privKey.Algo.Algorithm.Equal(oidPublicKeyRSA):
key, err = ParsePKCS1PrivateKey(privKey.PrivateKey)
if err != nil {
return nil, errors.New("x509: failed to parse RSA private key embedded in PKCS#8: " + err.Error())
}
return key, nil
case privKey.Algo.Algorithm.Equal(oidPublicKeyECDSA):
bytes := privKey.Algo.Parameters.FullBytes
namedCurveOID := new(asn1.ObjectIdentifier)
if _, err := asn1.Unmarshal(bytes, namedCurveOID); err != nil {
namedCurveOID = nil
}
key, err = parseECPrivateKey(namedCurveOID, privKey.PrivateKey)
if err != nil {
return nil, errors.New("x509: failed to parse EC private key embedded in PKCS#8: " + err.Error())
}
return key, nil
case privKey.Algo.Algorithm.Equal(oidPublicKeyEd25519):
if l := len(privKey.Algo.Parameters.FullBytes); l != 0 {
return nil, errors.New("x509: invalid Ed25519 private key parameters")
}
var curvePrivateKey []byte
if _, err := asn1.Unmarshal(privKey.PrivateKey, &curvePrivateKey); err != nil {
return nil, fmt.Errorf("x509: invalid Ed25519 private key: %v", err)
}
if l := len(curvePrivateKey); l != ed25519.SeedSize {
return nil, fmt.Errorf("x509: invalid Ed25519 private key length: %d", l)
}
return ed25519.NewKeyFromSeed(curvePrivateKey), nil
case privKey.Algo.Algorithm.Equal(oidPublicKeyX25519):
if l := len(privKey.Algo.Parameters.FullBytes); l != 0 {
return nil, errors.New("x509: invalid X25519 private key parameters")
}
var curvePrivateKey []byte
if _, err := asn1.Unmarshal(privKey.PrivateKey, &curvePrivateKey); err != nil {
return nil, fmt.Errorf("x509: invalid X25519 private key: %v", err)
}
return ecdh.X25519().NewPrivateKey(curvePrivateKey)
default:
return nil, fmt.Errorf("x509: PKCS#8 wrapping contained private key with unknown algorithm: %v", privKey.Algo.Algorithm)
}
}
// MarshalPKCS8PrivateKey converts a private key to PKCS #8, ASN.1 DER form.
//
// The following key types are currently supported: *rsa.PrivateKey,
// *ecdsa.PrivateKey, ed25519.PrivateKey (not a pointer), and *ecdh.PrivateKey.
// Unsupported key types result in an error.
//
// This kind of key is commonly encoded in PEM blocks of type "PRIVATE KEY".
func MarshalPKCS8PrivateKey(key any) ([]byte, error) {
var privKey pkcs8
switch k := key.(type) {
case *rsa.PrivateKey:
privKey.Algo = pkix.AlgorithmIdentifier{
Algorithm: oidPublicKeyRSA,
Parameters: asn1.NullRawValue,
}
privKey.PrivateKey = MarshalPKCS1PrivateKey(k)
case *ecdsa.PrivateKey:
oid, ok := oidFromNamedCurve(k.Curve)
if !ok {
return nil, errors.New("x509: unknown curve while marshaling to PKCS#8")
}
oidBytes, err := asn1.Marshal(oid)
if err != nil {
return nil, errors.New("x509: failed to marshal curve OID: " + err.Error())
}
privKey.Algo = pkix.AlgorithmIdentifier{
Algorithm: oidPublicKeyECDSA,
Parameters: asn1.RawValue{
FullBytes: oidBytes,
},
}
if privKey.PrivateKey, err = marshalECPrivateKeyWithOID(k, nil); err != nil {
return nil, errors.New("x509: failed to marshal EC private key while building PKCS#8: " + err.Error())
}
case ed25519.PrivateKey:
privKey.Algo = pkix.AlgorithmIdentifier{
Algorithm: oidPublicKeyEd25519,
}
curvePrivateKey, err := asn1.Marshal(k.Seed())
if err != nil {
return nil, fmt.Errorf("x509: failed to marshal private key: %v", err)
}
privKey.PrivateKey = curvePrivateKey
case *ecdh.PrivateKey:
if k.Curve() == ecdh.X25519() {
privKey.Algo = pkix.AlgorithmIdentifier{
Algorithm: oidPublicKeyX25519,
}
var err error
if privKey.PrivateKey, err = asn1.Marshal(k.Bytes()); err != nil {
return nil, fmt.Errorf("x509: failed to marshal private key: %v", err)
}
} else {
oid, ok := oidFromECDHCurve(k.Curve())
if !ok {
return nil, errors.New("x509: unknown curve while marshaling to PKCS#8")
}
oidBytes, err := asn1.Marshal(oid)
if err != nil {
return nil, errors.New("x509: failed to marshal curve OID: " + err.Error())
}
privKey.Algo = pkix.AlgorithmIdentifier{
Algorithm: oidPublicKeyECDSA,
Parameters: asn1.RawValue{
FullBytes: oidBytes,
},
}
if privKey.PrivateKey, err = marshalECDHPrivateKey(k); err != nil {
return nil, errors.New("x509: failed to marshal EC private key while building PKCS#8: " + err.Error())
}
}
default:
return nil, fmt.Errorf("x509: unknown key type while marshaling PKCS#8: %T", key)
}
return asn1.Marshal(privKey)
}
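// A minimal sketch (illustrative): generate an ECDSA key and wrap it in a
// PKCS #8 "PRIVATE KEY" PEM block. Assumes crypto/rand, crypto/elliptic,
// and encoding/pem are imported by the caller.
//
//	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
//	if err != nil {
//		return err
//	}
//	der, err := MarshalPKCS8PrivateKey(priv)
//	if err != nil {
//		return err
//	}
//	pemBytes := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: der})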
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package x509
import (
"internal/godebug"
"sync"
)
var (
once sync.Once
systemRootsMu sync.RWMutex
systemRoots *CertPool
systemRootsErr error
fallbacksSet bool
)
func systemRootsPool() *CertPool {
once.Do(initSystemRoots)
systemRootsMu.RLock()
defer systemRootsMu.RUnlock()
return systemRoots
}
func initSystemRoots() {
systemRootsMu.Lock()
defer systemRootsMu.Unlock()
systemRoots, systemRootsErr = loadSystemRoots()
if systemRootsErr != nil {
systemRoots = nil
}
}
var x509usefallbackroots = godebug.New("x509usefallbackroots")
// SetFallbackRoots sets the roots to use during certificate verification, if no
// custom roots are specified and a platform verifier or a system certificate
// pool is not available (for instance in a container which does not have a root
// certificate bundle). SetFallbackRoots will panic if roots is nil.
//
// SetFallbackRoots may only be called once, if called multiple times it will
// panic.
//
// The fallback behavior can be forced on all platforms, even when there is a
// system certificate pool, by setting GODEBUG=x509usefallbackroots=1 (note that
// on Windows and macOS this will disable usage of the platform verification
// APIs and cause the pure Go verifier to be used). Setting
// x509usefallbackroots=1 without calling SetFallbackRoots has no effect.
func SetFallbackRoots(roots *CertPool) {
if roots == nil {
panic("roots must be non-nil")
}
// trigger initSystemRoots if it hasn't already been called before we
// take the lock
_ = systemRootsPool()
systemRootsMu.Lock()
defer systemRootsMu.Unlock()
if fallbacksSet {
panic("SetFallbackRoots has already been called")
}
fallbacksSet = true
if systemRoots != nil && (systemRoots.len() > 0 || systemRoots.systemPool) {
if x509usefallbackroots.Value() != "1" {
return
}
x509usefallbackroots.IncNonDefault()
}
systemRoots, systemRootsErr = roots, nil
}
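// A minimal sketch (illustrative): install an embedded CA bundle as the
// fallback root set during startup. bundledPEM is a placeholder, e.g. a
// bundle embedded with go:embed.
//
//	fallback := NewCertPool()
//	if !fallback.AppendCertsFromPEM(bundledPEM) {
//		panic("no usable certificates in bundled roots")
//	}
//	SetFallbackRoots(fallback)
//
// Since SetFallbackRoots panics when called twice, it is best invoked
// exactly once, from main or an init function.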
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix || dragonfly || freebsd || (js && wasm) || linux || netbsd || openbsd || solaris
package x509
import (
"io/fs"
"os"
"path/filepath"
"strings"
)
const (
// certFileEnv is the environment variable which identifies where to locate
// the SSL certificate file. If set this overrides the system default.
certFileEnv = "SSL_CERT_FILE"
// certDirEnv is the environment variable which identifies which directory
// to check for SSL certificate files. If set this overrides the system default.
// It is a colon separated list of directories.
// See https://www.openssl.org/docs/man1.0.2/man1/c_rehash.html.
certDirEnv = "SSL_CERT_DIR"
)
func (c *Certificate) systemVerify(opts *VerifyOptions) (chains [][]*Certificate, err error) {
return nil, nil
}
func loadSystemRoots() (*CertPool, error) {
roots := NewCertPool()
files := certFiles
if f := os.Getenv(certFileEnv); f != "" {
files = []string{f}
}
var firstErr error
for _, file := range files {
data, err := os.ReadFile(file)
if err == nil {
roots.AppendCertsFromPEM(data)
break
}
if firstErr == nil && !os.IsNotExist(err) {
firstErr = err
}
}
dirs := certDirectories
if d := os.Getenv(certDirEnv); d != "" {
// OpenSSL and BoringSSL both use ":" as the SSL_CERT_DIR separator.
// See:
// * https://golang.org/issue/35325
// * https://www.openssl.org/docs/man1.0.2/man1/c_rehash.html
dirs = strings.Split(d, ":")
}
for _, directory := range dirs {
fis, err := readUniqueDirectoryEntries(directory)
if err != nil {
if firstErr == nil && !os.IsNotExist(err) {
firstErr = err
}
continue
}
for _, fi := range fis {
data, err := os.ReadFile(directory + "/" + fi.Name())
if err == nil {
roots.AppendCertsFromPEM(data)
}
}
}
if roots.len() > 0 || firstErr == nil {
return roots, nil
}
return nil, firstErr
}
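// Sketch of the environment overrides above (paths are placeholders).
// Because systemRootsPool caches the result via sync.Once, the variables
// must be set before the process first needs the system roots.
//
//	os.Setenv("SSL_CERT_FILE", "/etc/custom/ca-bundle.pem")
//	os.Setenv("SSL_CERT_DIR", "/etc/custom/certs:/opt/extra-certs")
//	pool, err := SystemCertPool() // consults the variables via loadSystemRoots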
// readUniqueDirectoryEntries is like os.ReadDir but omits
// symlinks that point within the directory.
func readUniqueDirectoryEntries(dir string) ([]fs.DirEntry, error) {
files, err := os.ReadDir(dir)
if err != nil {
return nil, err
}
uniq := files[:0]
for _, f := range files {
if !isSameDirSymlink(f, dir) {
uniq = append(uniq, f)
}
}
return uniq, nil
}
// isSameDirSymlink reports whether f in dir is a symlink with a
// target not containing a slash.
func isSameDirSymlink(f fs.DirEntry, dir string) bool {
if f.Type()&fs.ModeSymlink == 0 {
return false
}
target, err := os.Readlink(filepath.Join(dir, f.Name()))
return err == nil && !strings.Contains(target, "/")
}
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package x509
import (
"crypto/ecdh"
"crypto/ecdsa"
"crypto/elliptic"
"encoding/asn1"
"errors"
"fmt"
"math/big"
)
const ecPrivKeyVersion = 1
// ecPrivateKey reflects an ASN.1 Elliptic Curve Private Key Structure.
// References:
//
// RFC 5915
// SEC1 - http://www.secg.org/sec1-v2.pdf
//
// Per RFC 5915 the NamedCurveOID is marked as ASN.1 OPTIONAL, but in
// practice it is almost always present.
type ecPrivateKey struct {
Version int
PrivateKey []byte
NamedCurveOID asn1.ObjectIdentifier `asn1:"optional,explicit,tag:0"`
PublicKey asn1.BitString `asn1:"optional,explicit,tag:1"`
}
// ParseECPrivateKey parses an EC private key in SEC 1, ASN.1 DER form.
//
// This kind of key is commonly encoded in PEM blocks of type "EC PRIVATE KEY".
func ParseECPrivateKey(der []byte) (*ecdsa.PrivateKey, error) {
return parseECPrivateKey(nil, der)
}
// MarshalECPrivateKey converts an EC private key to SEC 1, ASN.1 DER form.
//
// This kind of key is commonly encoded in PEM blocks of type "EC PRIVATE KEY".
// For a more flexible key format which is not EC specific, use
// MarshalPKCS8PrivateKey.
func MarshalECPrivateKey(key *ecdsa.PrivateKey) ([]byte, error) {
oid, ok := oidFromNamedCurve(key.Curve)
if !ok {
return nil, errors.New("x509: unknown elliptic curve")
}
return marshalECPrivateKeyWithOID(key, oid)
}
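// A minimal sketch (illustrative): encode a fresh P-256 key as an
// "EC PRIVATE KEY" PEM block. Assumes crypto/rand, crypto/elliptic, and
// encoding/pem are imported by the caller.
//
//	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
//	if err != nil {
//		return err
//	}
//	der, err := MarshalECPrivateKey(priv)
//	if err != nil {
//		return err
//	}
//	pemBytes := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: der})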
// marshalECPrivateKeyWithOID marshals an EC private key into ASN.1, DER format and
// sets the curve ID to the given OID, or omits it if OID is nil.
func marshalECPrivateKeyWithOID(key *ecdsa.PrivateKey, oid asn1.ObjectIdentifier) ([]byte, error) {
if !key.Curve.IsOnCurve(key.X, key.Y) {
return nil, errors.New("invalid elliptic key public key")
}
privateKey := make([]byte, (key.Curve.Params().N.BitLen()+7)/8)
return asn1.Marshal(ecPrivateKey{
Version: 1,
PrivateKey: key.D.FillBytes(privateKey),
NamedCurveOID: oid,
PublicKey: asn1.BitString{Bytes: elliptic.Marshal(key.Curve, key.X, key.Y)},
})
}
// marshalECDHPrivateKey marshals an EC private key into ASN.1, DER format
// suitable for NIST curves.
func marshalECDHPrivateKey(key *ecdh.PrivateKey) ([]byte, error) {
return asn1.Marshal(ecPrivateKey{
Version: 1,
PrivateKey: key.Bytes(),
PublicKey: asn1.BitString{Bytes: key.PublicKey().Bytes()},
})
}
// parseECPrivateKey parses an ASN.1 Elliptic Curve Private Key Structure.
// The OID for the named curve may be provided from another source (such as
// the PKCS8 container) - if it is provided then use this instead of the OID
// that may exist in the EC private key structure.
func parseECPrivateKey(namedCurveOID *asn1.ObjectIdentifier, der []byte) (key *ecdsa.PrivateKey, err error) {
var privKey ecPrivateKey
if _, err := asn1.Unmarshal(der, &privKey); err != nil {
if _, err := asn1.Unmarshal(der, &pkcs8{}); err == nil {
return nil, errors.New("x509: failed to parse private key (use ParsePKCS8PrivateKey instead for this key format)")
}
if _, err := asn1.Unmarshal(der, &pkcs1PrivateKey{}); err == nil {
return nil, errors.New("x509: failed to parse private key (use ParsePKCS1PrivateKey instead for this key format)")
}
return nil, errors.New("x509: failed to parse EC private key: " + err.Error())
}
if privKey.Version != ecPrivKeyVersion {
return nil, fmt.Errorf("x509: unknown EC private key version %d", privKey.Version)
}
var curve elliptic.Curve
if namedCurveOID != nil {
curve = namedCurveFromOID(*namedCurveOID)
} else {
curve = namedCurveFromOID(privKey.NamedCurveOID)
}
if curve == nil {
return nil, errors.New("x509: unknown elliptic curve")
}
k := new(big.Int).SetBytes(privKey.PrivateKey)
curveOrder := curve.Params().N
if k.Cmp(curveOrder) >= 0 {
return nil, errors.New("x509: invalid elliptic curve private key value")
}
priv := new(ecdsa.PrivateKey)
priv.Curve = curve
priv.D = k
privateKey := make([]byte, (curveOrder.BitLen()+7)/8)
// Some private keys have leading zero padding. This is invalid
// according to [SEC1], but this code will ignore it.
for len(privKey.PrivateKey) > len(privateKey) {
if privKey.PrivateKey[0] != 0 {
return nil, errors.New("x509: invalid private key length")
}
privKey.PrivateKey = privKey.PrivateKey[1:]
}
// Some private keys remove all leading zeros; this is also invalid
// according to [SEC1] but since OpenSSL used to do this, we ignore
// this too.
copy(privateKey[len(privateKey)-len(privKey.PrivateKey):], privKey.PrivateKey)
priv.X, priv.Y = curve.ScalarBaseMult(privateKey)
return priv, nil
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package x509
import (
"bytes"
"crypto"
"crypto/x509/pkix"
"errors"
"fmt"
"net"
"net/url"
"reflect"
"runtime"
"strings"
"time"
"unicode/utf8"
)
type InvalidReason int
const (
// NotAuthorizedToSign results when a certificate is signed by another
// which isn't marked as a CA certificate.
NotAuthorizedToSign InvalidReason = iota
// Expired results when a certificate has expired, based on the time
// given in the VerifyOptions.
Expired
// CANotAuthorizedForThisName results when an intermediate or root
// certificate has a name constraint which doesn't permit a DNS or
// other name (including IP address) in the leaf certificate.
CANotAuthorizedForThisName
// TooManyIntermediates results when a path length constraint is
// violated.
TooManyIntermediates
// IncompatibleUsage results when the certificate's key usage indicates
// that it may only be used for a different purpose.
IncompatibleUsage
// NameMismatch results when the subject name of a parent certificate
// does not match the issuer name in the child.
NameMismatch
// NameConstraintsWithoutSANs is a legacy error and is no longer returned.
NameConstraintsWithoutSANs
// UnconstrainedName results when a CA certificate contains permitted
// name constraints, but the leaf certificate contains a name of an
// unsupported or unconstrained type.
UnconstrainedName
// TooManyConstraints results when the number of comparison operations
// needed to check a certificate exceeds the limit set by
// VerifyOptions.MaxConstraintComparisions. This limit exists to
// prevent pathological certificates from consuming excessive amounts of
// CPU time to verify.
TooManyConstraints
// CANotAuthorizedForExtKeyUsage results when an intermediate or root
// certificate does not permit a requested extended key usage.
CANotAuthorizedForExtKeyUsage
)
// CertificateInvalidError results when a certificate fails validation for one
// of the reasons enumerated by InvalidReason. Users of this library probably
// want to handle all these errors uniformly.
type CertificateInvalidError struct {
Cert *Certificate
Reason InvalidReason
Detail string
}
func (e CertificateInvalidError) Error() string {
switch e.Reason {
case NotAuthorizedToSign:
return "x509: certificate is not authorized to sign other certificates"
case Expired:
return "x509: certificate has expired or is not yet valid: " + e.Detail
case CANotAuthorizedForThisName:
return "x509: a root or intermediate certificate is not authorized to sign for this name: " + e.Detail
case CANotAuthorizedForExtKeyUsage:
return "x509: a root or intermediate certificate is not authorized for an extended key usage: " + e.Detail
case TooManyIntermediates:
return "x509: too many intermediates for path length constraint"
case IncompatibleUsage:
return "x509: certificate specifies an incompatible key usage"
case NameMismatch:
return "x509: issuer name does not match subject from issuing certificate"
case NameConstraintsWithoutSANs:
return "x509: issuer has name constraints but leaf doesn't have a SAN extension"
case UnconstrainedName:
return "x509: issuer has name constraints but leaf contains unknown or unconstrained name: " + e.Detail
}
return "x509: unknown error"
}
// HostnameError results when the set of authorized names doesn't match the
// requested name.
type HostnameError struct {
Certificate *Certificate
Host string
}
func (h HostnameError) Error() string {
c := h.Certificate
if !c.hasSANExtension() && matchHostnames(c.Subject.CommonName, h.Host) {
return "x509: certificate relies on legacy Common Name field, use SANs instead"
}
var valid string
if ip := net.ParseIP(h.Host); ip != nil {
// Trying to validate an IP
if len(c.IPAddresses) == 0 {
return "x509: cannot validate certificate for " + h.Host + " because it doesn't contain any IP SANs"
}
for _, san := range c.IPAddresses {
if len(valid) > 0 {
valid += ", "
}
valid += san.String()
}
} else {
valid = strings.Join(c.DNSNames, ", ")
}
if len(valid) == 0 {
return "x509: certificate is not valid for any names, but wanted to match " + h.Host
}
return "x509: certificate is valid for " + valid + ", not " + h.Host
}
// UnknownAuthorityError results when the certificate issuer is unknown.
type UnknownAuthorityError struct {
Cert *Certificate
// hintErr contains an error that may be helpful in determining why an
// authority wasn't found.
hintErr error
// hintCert contains a possible authority certificate that was rejected
// because of the error in hintErr.
hintCert *Certificate
}
func (e UnknownAuthorityError) Error() string {
s := "x509: certificate signed by unknown authority"
if e.hintErr != nil {
certName := e.hintCert.Subject.CommonName
if len(certName) == 0 {
if len(e.hintCert.Subject.Organization) > 0 {
certName = e.hintCert.Subject.Organization[0]
} else {
certName = "serial:" + e.hintCert.SerialNumber.String()
}
}
s += fmt.Sprintf(" (possibly because of %q while trying to verify candidate authority certificate %q)", e.hintErr, certName)
}
return s
}
// SystemRootsError results when we fail to load the system root certificates.
type SystemRootsError struct {
Err error
}
func (se SystemRootsError) Error() string {
msg := "x509: failed to load system roots and no roots provided"
if se.Err != nil {
return msg + "; " + se.Err.Error()
}
return msg
}
func (se SystemRootsError) Unwrap() error { return se.Err }
// errNotParsed is returned when a certificate without ASN.1 contents is
// verified. Platform-specific verification needs the ASN.1 contents.
var errNotParsed = errors.New("x509: missing ASN.1 contents; use ParseCertificate")
// VerifyOptions contains parameters for Certificate.Verify.
type VerifyOptions struct {
// DNSName, if set, is checked against the leaf certificate with
// Certificate.VerifyHostname or the platform verifier.
DNSName string
// Intermediates is an optional pool of certificates that are not trust
// anchors, but can be used to form a chain from the leaf certificate to a
// root certificate.
Intermediates *CertPool
// Roots is the set of trusted root certificates the leaf certificate needs
// to chain up to. If nil, the system roots or the platform verifier are used.
Roots *CertPool
// CurrentTime is used to check the validity of all certificates in the
// chain. If zero, the current time is used.
CurrentTime time.Time
// KeyUsages specifies which Extended Key Usage values are acceptable. A
// chain is accepted if it allows any of the listed values. An empty list
// means ExtKeyUsageServerAuth. To accept any key usage, include ExtKeyUsageAny.
KeyUsages []ExtKeyUsage
// MaxConstraintComparisions is the maximum number of comparisons to
// perform when checking a given certificate's name constraints. If
// zero, a sensible default is used. This limit prevents pathological
// certificates from consuming excessive amounts of CPU time when
// validating. It does not apply to the platform verifier.
MaxConstraintComparisions int
}
const (
leafCertificate = iota
intermediateCertificate
rootCertificate
)
// rfc2821Mailbox represents a “mailbox” (which is an email address to most
// people) by breaking it into the “local” (i.e. before the '@') and “domain”
// parts.
type rfc2821Mailbox struct {
local, domain string
}
// parseRFC2821Mailbox parses an email address into local and domain parts,
// based on the ABNF for a “Mailbox” from RFC 2821. According to RFC 5280,
// Section 4.2.1.6 that's correct for an rfc822Name from a certificate: “The
// format of an rfc822Name is a "Mailbox" as defined in RFC 2821, Section 4.1.2”.
func parseRFC2821Mailbox(in string) (mailbox rfc2821Mailbox, ok bool) {
if len(in) == 0 {
return mailbox, false
}
localPartBytes := make([]byte, 0, len(in)/2)
if in[0] == '"' {
// Quoted-string = DQUOTE *qcontent DQUOTE
// non-whitespace-control = %d1-8 / %d11 / %d12 / %d14-31 / %d127
// qcontent = qtext / quoted-pair
// qtext = non-whitespace-control /
// %d33 / %d35-91 / %d93-126
// quoted-pair = ("\" text) / obs-qp
// text = %d1-9 / %d11 / %d12 / %d14-127 / obs-text
//
// (Names beginning with “obs-” are the obsolete syntax from RFC 2822,
// Section 4. Since it has been 16 years, we no longer accept that.)
in = in[1:]
QuotedString:
for {
if len(in) == 0 {
return mailbox, false
}
c := in[0]
in = in[1:]
switch {
case c == '"':
break QuotedString
case c == '\\':
// quoted-pair
if len(in) == 0 {
return mailbox, false
}
if in[0] == 11 ||
in[0] == 12 ||
(1 <= in[0] && in[0] <= 9) ||
(14 <= in[0] && in[0] <= 127) {
localPartBytes = append(localPartBytes, in[0])
in = in[1:]
} else {
return mailbox, false
}
case c == 11 ||
c == 12 ||
// Space (char 32) is not allowed based on the
// BNF, but RFC 3696 gives an example that
// assumes that it is. Several “verified”
// errata continue to argue about this point.
// We choose to accept it.
c == 32 ||
c == 33 ||
c == 127 ||
(1 <= c && c <= 8) ||
(14 <= c && c <= 31) ||
(35 <= c && c <= 91) ||
(93 <= c && c <= 126):
// qtext
localPartBytes = append(localPartBytes, c)
default:
return mailbox, false
}
}
} else {
// Atom ("." Atom)*
NextChar:
for len(in) > 0 {
// atext from RFC 2822, Section 3.2.4
c := in[0]
switch {
case c == '\\':
// Examples given in RFC 3696 suggest that
// escaped characters can appear outside of a
// quoted string. Several “verified” errata
// continue to argue the point. We choose to
// accept it.
in = in[1:]
if len(in) == 0 {
return mailbox, false
}
fallthrough
case ('0' <= c && c <= '9') ||
('a' <= c && c <= 'z') ||
('A' <= c && c <= 'Z') ||
c == '!' || c == '#' || c == '$' || c == '%' ||
c == '&' || c == '\'' || c == '*' || c == '+' ||
c == '-' || c == '/' || c == '=' || c == '?' ||
c == '^' || c == '_' || c == '`' || c == '{' ||
c == '|' || c == '}' || c == '~' || c == '.':
localPartBytes = append(localPartBytes, in[0])
in = in[1:]
default:
break NextChar
}
}
if len(localPartBytes) == 0 {
return mailbox, false
}
// From RFC 3696, Section 3:
// “period (".") may also appear, but may not be used to start
// or end the local part, nor may two or more consecutive
// periods appear.”
twoDots := []byte{'.', '.'}
if localPartBytes[0] == '.' ||
localPartBytes[len(localPartBytes)-1] == '.' ||
bytes.Contains(localPartBytes, twoDots) {
return mailbox, false
}
}
if len(in) == 0 || in[0] != '@' {
return mailbox, false
}
in = in[1:]
// The RFC specifies a format for domains, but that's known to be
// violated in practice so we accept that anything after an '@' is the
// domain part.
if _, ok := domainToReverseLabels(in); !ok {
return mailbox, false
}
mailbox.local = string(localPartBytes)
mailbox.domain = in
return mailbox, true
}
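// Behavior sketch (hypothetical inputs):
//
//	m, ok := parseRFC2821Mailbox("gopher@example.com")
//	// ok == true, m.local == "gopher", m.domain == "example.com"
//	_, ok = parseRFC2821Mailbox(".bad@example.com")
//	// ok == false: the local part may not begin with a period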
// domainToReverseLabels converts a textual domain name like foo.example.com to
// the list of labels in reverse order, e.g. ["com", "example", "foo"].
func domainToReverseLabels(domain string) (reverseLabels []string, ok bool) {
for len(domain) > 0 {
if i := strings.LastIndexByte(domain, '.'); i == -1 {
reverseLabels = append(reverseLabels, domain)
domain = ""
} else {
reverseLabels = append(reverseLabels, domain[i+1:])
domain = domain[:i]
}
}
if len(reverseLabels) > 0 && len(reverseLabels[0]) == 0 {
// An empty label at the end indicates an absolute (fully qualified) domain name.
return nil, false
}
for _, label := range reverseLabels {
if len(label) == 0 {
// Empty labels are otherwise invalid.
return nil, false
}
for _, c := range label {
if c < 33 || c > 126 {
// Invalid character.
return nil, false
}
}
}
return reverseLabels, true
}
func matchEmailConstraint(mailbox rfc2821Mailbox, constraint string) (bool, error) {
// If the constraint contains an @, then it specifies an exact mailbox
// name.
if strings.Contains(constraint, "@") {
constraintMailbox, ok := parseRFC2821Mailbox(constraint)
if !ok {
return false, fmt.Errorf("x509: internal error: cannot parse constraint %q", constraint)
}
return mailbox.local == constraintMailbox.local && strings.EqualFold(mailbox.domain, constraintMailbox.domain), nil
}
// Otherwise the constraint is like a DNS constraint of the domain part
// of the mailbox.
return matchDomainConstraint(mailbox.domain, constraint)
}
func matchURIConstraint(uri *url.URL, constraint string) (bool, error) {
// From RFC 5280, Section 4.2.1.10:
// “a uniformResourceIdentifier that does not include an authority
// component with a host name specified as a fully qualified domain
// name (e.g., if the URI either does not include an authority
// component or includes an authority component in which the host name
// is specified as an IP address), then the application MUST reject the
// certificate.”
host := uri.Host
if len(host) == 0 {
return false, fmt.Errorf("URI with empty host (%q) cannot be matched against constraints", uri.String())
}
if strings.Contains(host, ":") && !strings.HasSuffix(host, "]") {
var err error
host, _, err = net.SplitHostPort(uri.Host)
if err != nil {
return false, err
}
}
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") ||
net.ParseIP(host) != nil {
return false, fmt.Errorf("URI with IP (%q) cannot be matched against constraints", uri.String())
}
return matchDomainConstraint(host, constraint)
}
func matchIPConstraint(ip net.IP, constraint *net.IPNet) (bool, error) {
if len(ip) != len(constraint.IP) {
return false, nil
}
for i := range ip {
if mask := constraint.Mask[i]; ip[i]&mask != constraint.IP[i]&mask {
return false, nil
}
}
return true, nil
}
func matchDomainConstraint(domain, constraint string) (bool, error) {
// The meaning of zero length constraints is not specified, but this
// code follows NSS and accepts them as matching everything.
if len(constraint) == 0 {
return true, nil
}
domainLabels, ok := domainToReverseLabels(domain)
if !ok {
return false, fmt.Errorf("x509: internal error: cannot parse domain %q", domain)
}
// RFC 5280 says that a leading period in a domain name means that at
// least one label must be prepended, but only for URI and email
// constraints, not DNS constraints. The code also supports that
// behaviour for DNS constraints.
mustHaveSubdomains := false
if constraint[0] == '.' {
mustHaveSubdomains = true
constraint = constraint[1:]
}
constraintLabels, ok := domainToReverseLabels(constraint)
if !ok {
return false, fmt.Errorf("x509: internal error: cannot parse domain %q", constraint)
}
if len(domainLabels) < len(constraintLabels) ||
(mustHaveSubdomains && len(domainLabels) == len(constraintLabels)) {
return false, nil
}
for i, constraintLabel := range constraintLabels {
if !strings.EqualFold(constraintLabel, domainLabels[i]) {
return false, nil
}
}
return true, nil
}
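// Behavior sketch (hypothetical inputs):
//
//	matchDomainConstraint("mail.example.com", "example.com")  // true: constraint labels are a suffix
//	matchDomainConstraint("example.com", ".example.com")      // false: the leading period demands an extra label
//	matchDomainConstraint("mail.example.com", ".example.com") // true
//	matchDomainConstraint("anything.test", "")                // true: empty constraints match everything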
// checkNameConstraints checks that c permits a child certificate to claim the
// given name, of type nameType. The argument parsedName contains the parsed
// form of name, suitable for passing to the match function. The total number
// of comparisons is tracked in the given count and should not exceed the given
// limit.
func (c *Certificate) checkNameConstraints(count *int,
maxConstraintComparisons int,
nameType string,
name string,
parsedName any,
match func(parsedName, constraint any) (match bool, err error),
permitted, excluded any) error {
excludedValue := reflect.ValueOf(excluded)
*count += excludedValue.Len()
if *count > maxConstraintComparisons {
return CertificateInvalidError{c, TooManyConstraints, ""}
}
for i := 0; i < excludedValue.Len(); i++ {
constraint := excludedValue.Index(i).Interface()
match, err := match(parsedName, constraint)
if err != nil {
return CertificateInvalidError{c, CANotAuthorizedForThisName, err.Error()}
}
if match {
return CertificateInvalidError{c, CANotAuthorizedForThisName, fmt.Sprintf("%s %q is excluded by constraint %q", nameType, name, constraint)}
}
}
permittedValue := reflect.ValueOf(permitted)
*count += permittedValue.Len()
if *count > maxConstraintComparisons {
return CertificateInvalidError{c, TooManyConstraints, ""}
}
ok := true
for i := 0; i < permittedValue.Len(); i++ {
constraint := permittedValue.Index(i).Interface()
var err error
if ok, err = match(parsedName, constraint); err != nil {
return CertificateInvalidError{c, CANotAuthorizedForThisName, err.Error()}
}
if ok {
break
}
}
if !ok {
return CertificateInvalidError{c, CANotAuthorizedForThisName, fmt.Sprintf("%s %q is not permitted by any constraint", nameType, name)}
}
return nil
}
// isValid performs validity checks on c given that it is a candidate to append
// to the chain in currentChain.
func (c *Certificate) isValid(certType int, currentChain []*Certificate, opts *VerifyOptions) error {
if len(c.UnhandledCriticalExtensions) > 0 {
return UnhandledCriticalExtension{}
}
if len(currentChain) > 0 {
child := currentChain[len(currentChain)-1]
if !bytes.Equal(child.RawIssuer, c.RawSubject) {
return CertificateInvalidError{c, NameMismatch, ""}
}
}
now := opts.CurrentTime
if now.IsZero() {
now = time.Now()
}
if now.Before(c.NotBefore) {
return CertificateInvalidError{
Cert: c,
Reason: Expired,
Detail: fmt.Sprintf("current time %s is before %s", now.Format(time.RFC3339), c.NotBefore.Format(time.RFC3339)),
}
} else if now.After(c.NotAfter) {
return CertificateInvalidError{
Cert: c,
Reason: Expired,
Detail: fmt.Sprintf("current time %s is after %s", now.Format(time.RFC3339), c.NotAfter.Format(time.RFC3339)),
}
}
maxConstraintComparisons := opts.MaxConstraintComparisions
if maxConstraintComparisons == 0 {
maxConstraintComparisons = 250000
}
comparisonCount := 0
var leaf *Certificate
if certType == intermediateCertificate || certType == rootCertificate {
if len(currentChain) == 0 {
return errors.New("x509: internal error: empty chain when appending CA cert")
}
leaf = currentChain[0]
}
if (certType == intermediateCertificate || certType == rootCertificate) &&
c.hasNameConstraints() {
toCheck := []*Certificate{}
if leaf.hasSANExtension() {
toCheck = append(toCheck, leaf)
}
if c.hasSANExtension() {
toCheck = append(toCheck, c)
}
for _, sanCert := range toCheck {
err := forEachSAN(sanCert.getSANExtension(), func(tag int, data []byte) error {
switch tag {
case nameTypeEmail:
name := string(data)
mailbox, ok := parseRFC2821Mailbox(name)
if !ok {
return fmt.Errorf("x509: cannot parse rfc822Name %q", mailbox)
}
if err := c.checkNameConstraints(&comparisonCount, maxConstraintComparisons, "email address", name, mailbox,
func(parsedName, constraint any) (bool, error) {
return matchEmailConstraint(parsedName.(rfc2821Mailbox), constraint.(string))
}, c.PermittedEmailAddresses, c.ExcludedEmailAddresses); err != nil {
return err
}
case nameTypeDNS:
name := string(data)
if _, ok := domainToReverseLabels(name); !ok {
return fmt.Errorf("x509: cannot parse dnsName %q", name)
}
if err := c.checkNameConstraints(&comparisonCount, maxConstraintComparisons, "DNS name", name, name,
func(parsedName, constraint any) (bool, error) {
return matchDomainConstraint(parsedName.(string), constraint.(string))
}, c.PermittedDNSDomains, c.ExcludedDNSDomains); err != nil {
return err
}
case nameTypeURI:
name := string(data)
uri, err := url.Parse(name)
if err != nil {
return fmt.Errorf("x509: internal error: URI SAN %q failed to parse", name)
}
if err := c.checkNameConstraints(&comparisonCount, maxConstraintComparisons, "URI", name, uri,
func(parsedName, constraint any) (bool, error) {
return matchURIConstraint(parsedName.(*url.URL), constraint.(string))
}, c.PermittedURIDomains, c.ExcludedURIDomains); err != nil {
return err
}
case nameTypeIP:
ip := net.IP(data)
if l := len(ip); l != net.IPv4len && l != net.IPv6len {
return fmt.Errorf("x509: internal error: IP SAN %x failed to parse", data)
}
if err := c.checkNameConstraints(&comparisonCount, maxConstraintComparisons, "IP address", ip.String(), ip,
func(parsedName, constraint any) (bool, error) {
return matchIPConstraint(parsedName.(net.IP), constraint.(*net.IPNet))
}, c.PermittedIPRanges, c.ExcludedIPRanges); err != nil {
return err
}
default:
// Unknown SAN types are ignored.
}
return nil
})
if err != nil {
return err
}
}
}
// KeyUsage status flags are ignored. From Engineering Security, Peter
// Gutmann: A European government CA marked its signing certificates as
// being valid for encryption only, but no-one noticed. Another
// European CA marked its signature keys as not being valid for
// signatures. A different CA marked its own trusted root certificate
// as being invalid for certificate signing. Another national CA
// distributed a certificate to be used to encrypt data for the
// country’s tax authority that was marked as only being usable for
// digital signatures but not for encryption. Yet another CA reversed
// the order of the bit flags in the keyUsage due to confusion over
// encoding endianness, essentially setting a random keyUsage in
// certificates that it issued. Another CA created a self-invalidating
// certificate by adding a certificate policy statement stipulating
// that the certificate had to be used strictly as specified in the
// keyUsage, and a keyUsage containing a flag indicating that the RSA
// encryption key could only be used for Diffie-Hellman key agreement.
if certType == intermediateCertificate && (!c.BasicConstraintsValid || !c.IsCA) {
return CertificateInvalidError{c, NotAuthorizedToSign, ""}
}
if c.BasicConstraintsValid && c.MaxPathLen >= 0 {
numIntermediates := len(currentChain) - 1
if numIntermediates > c.MaxPathLen {
return CertificateInvalidError{c, TooManyIntermediates, ""}
}
}
if !boringAllowCert(c) {
// IncompatibleUsage is not quite right here,
// but it's also the "no chains found" error
// and is close enough.
return CertificateInvalidError{c, IncompatibleUsage, ""}
}
return nil
}
// Verify attempts to verify c by building one or more chains from c to a
// certificate in opts.Roots, using certificates in opts.Intermediates if
// needed. If successful, it returns one or more chains where the first
// element of the chain is c and the last element is from opts.Roots.
//
// If opts.Roots is nil, the platform verifier might be used, and
// verification details might differ from what is described below. If system
// roots are unavailable the returned error will be of type SystemRootsError.
//
// Name constraints in the intermediates will be applied to all names claimed
// in the chain, not just opts.DNSName. Thus it is invalid for a leaf to claim
// example.com if an intermediate doesn't permit it, even if example.com is not
// the name being validated. Note that DirectoryName constraints are not
// supported.
//
// Name constraint validation follows the rules from RFC 5280, with the
// addition that DNS name constraints may use the leading period format
// defined for emails and URIs. When a constraint has a leading period
// it indicates that at least one additional label must be prepended to
// the constrained name to be considered valid.
//
// Extended Key Usage values are enforced nested down a chain, so an intermediate
// or root that enumerates EKUs prevents a leaf from asserting an EKU not in that
// list. (While this is not specified, it is common practice in order to limit
// the types of certificates a CA can issue.)
//
// Certificates that use SHA1WithRSA and ECDSAWithSHA1 signatures are not supported,
// and will not be used to build chains.
//
// Certificates other than c in the returned chains should not be modified.
//
// WARNING: this function doesn't do any revocation checking.
func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err error) {
// Platform-specific verification needs the ASN.1 contents so
// this makes the behavior consistent across platforms.
if len(c.Raw) == 0 {
return nil, errNotParsed
}
for i := 0; i < opts.Intermediates.len(); i++ {
c, err := opts.Intermediates.cert(i)
if err != nil {
return nil, fmt.Errorf("crypto/x509: error fetching intermediate: %w", err)
}
if len(c.Raw) == 0 {
return nil, errNotParsed
}
}
// Use platform verifiers, where available, if Roots is from SystemCertPool.
if runtime.GOOS == "windows" || runtime.GOOS == "darwin" || runtime.GOOS == "ios" {
// Don't use the system verifier if the system pool was replaced with a non-system pool,
// i.e. if SetFallbackRoots was called with x509usefallbackroots=1.
systemPool := systemRootsPool()
if opts.Roots == nil && (systemPool == nil || systemPool.systemPool) {
return c.systemVerify(&opts)
}
if opts.Roots != nil && opts.Roots.systemPool {
platformChains, err := c.systemVerify(&opts)
// If the platform verifier succeeded, or there are no additional
// roots, return the platform verifier result. Otherwise, continue
// with the Go verifier.
if err == nil || opts.Roots.len() == 0 {
return platformChains, err
}
}
}
if opts.Roots == nil {
opts.Roots = systemRootsPool()
if opts.Roots == nil {
return nil, SystemRootsError{systemRootsErr}
}
}
err = c.isValid(leafCertificate, nil, &opts)
if err != nil {
return
}
if len(opts.DNSName) > 0 {
err = c.VerifyHostname(opts.DNSName)
if err != nil {
return
}
}
var candidateChains [][]*Certificate
if opts.Roots.contains(c) {
candidateChains = [][]*Certificate{{c}}
} else {
candidateChains, err = c.buildChains([]*Certificate{c}, nil, &opts)
if err != nil {
return nil, err
}
}
if len(opts.KeyUsages) == 0 {
opts.KeyUsages = []ExtKeyUsage{ExtKeyUsageServerAuth}
}
for _, eku := range opts.KeyUsages {
if eku == ExtKeyUsageAny {
// If any key usage is acceptable, no need to check the chain for
// key usages.
return candidateChains, nil
}
}
chains = make([][]*Certificate, 0, len(candidateChains))
for _, candidate := range candidateChains {
if checkChainForKeyUsage(candidate, opts.KeyUsages) {
chains = append(chains, candidate)
}
}
if len(chains) == 0 {
return nil, CertificateInvalidError{c, IncompatibleUsage, ""}
}
return chains, nil
}
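// A minimal usage sketch (illustrative): verify a parsed leaf certificate
// against an explicit root pool for a given DNS name. rootPEM and leafDER
// are placeholders for caller-supplied data.
//
//	roots := NewCertPool()
//	if !roots.AppendCertsFromPEM(rootPEM) {
//		return errors.New("failed to parse root certificate")
//	}
//	leaf, err := ParseCertificate(leafDER)
//	if err != nil {
//		return err
//	}
//	chains, err := leaf.Verify(VerifyOptions{
//		DNSName: "example.com",
//		Roots:   roots,
//	})
//	if err != nil {
//		return err
//	}
//	// chains[0][0] == leaf; the last element of each chain is from roots.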
func appendToFreshChain(chain []*Certificate, cert *Certificate) []*Certificate {
n := make([]*Certificate, len(chain)+1)
copy(n, chain)
n[len(chain)] = cert
return n
}
// alreadyInChain checks whether a candidate certificate is present in a chain.
// Rather than doing a direct byte for byte equivalency check, we check if the
// subject, public key, and SAN, if present, are equal. This prevents loops that
// are created by mutual cross-signatures, or other cross-signature bridge
// oddities.
func alreadyInChain(candidate *Certificate, chain []*Certificate) bool {
type pubKeyEqual interface {
Equal(crypto.PublicKey) bool
}
var candidateSAN *pkix.Extension
for _, ext := range candidate.Extensions {
if ext.Id.Equal(oidExtensionSubjectAltName) {
candidateSAN = &ext
break
}
}
for _, cert := range chain {
if !bytes.Equal(candidate.RawSubject, cert.RawSubject) {
continue
}
if !candidate.PublicKey.(pubKeyEqual).Equal(cert.PublicKey) {
continue
}
var certSAN *pkix.Extension
for _, ext := range cert.Extensions {
if ext.Id.Equal(oidExtensionSubjectAltName) {
certSAN = &ext
break
}
}
if candidateSAN == nil && certSAN == nil {
return true
} else if candidateSAN == nil || certSAN == nil {
return false
}
if bytes.Equal(candidateSAN.Value, certSAN.Value) {
return true
}
}
return false
}
// maxChainSignatureChecks is the maximum number of CheckSignatureFrom calls
// that an invocation of buildChains will (transitively) make. Most chains are
// less than 15 certificates long, so this leaves space for multiple chains and
// for failed checks due to different intermediates having the same Subject.
const maxChainSignatureChecks = 100
func (c *Certificate) buildChains(currentChain []*Certificate, sigChecks *int, opts *VerifyOptions) (chains [][]*Certificate, err error) {
var (
hintErr error
hintCert *Certificate
)
considerCandidate := func(certType int, candidate *Certificate) {
if alreadyInChain(candidate, currentChain) {
return
}
if sigChecks == nil {
sigChecks = new(int)
}
*sigChecks++
if *sigChecks > maxChainSignatureChecks {
err = errors.New("x509: signature check attempts limit reached while verifying certificate chain")
return
}
if err := c.CheckSignatureFrom(candidate); err != nil {
if hintErr == nil {
hintErr = err
hintCert = candidate
}
return
}
err = candidate.isValid(certType, currentChain, opts)
if err != nil {
if hintErr == nil {
hintErr = err
hintCert = candidate
}
return
}
switch certType {
case rootCertificate:
chains = append(chains, appendToFreshChain(currentChain, candidate))
case intermediateCertificate:
var childChains [][]*Certificate
childChains, err = candidate.buildChains(appendToFreshChain(currentChain, candidate), sigChecks, opts)
chains = append(chains, childChains...)
}
}
for _, root := range opts.Roots.findPotentialParents(c) {
considerCandidate(rootCertificate, root)
}
for _, intermediate := range opts.Intermediates.findPotentialParents(c) {
considerCandidate(intermediateCertificate, intermediate)
}
if len(chains) > 0 {
err = nil
}
if len(chains) == 0 && err == nil {
err = UnknownAuthorityError{c, hintErr, hintCert}
}
return
}
func validHostnamePattern(host string) bool { return validHostname(host, true) }
func validHostnameInput(host string) bool { return validHostname(host, false) }
// validHostname reports whether host is a valid hostname that can be matched or
// matched against according to RFC 6125 2.2, with some leniency to accommodate
// legacy values.
func validHostname(host string, isPattern bool) bool {
if !isPattern {
host = strings.TrimSuffix(host, ".")
}
if len(host) == 0 {
return false
}
for i, part := range strings.Split(host, ".") {
if part == "" {
// Empty label.
return false
}
if isPattern && i == 0 && part == "*" {
// Only allow full left-most wildcards, as those are the only ones
// we match, and matching literal '*' characters is probably never
// the expected behavior.
continue
}
for j, c := range part {
if 'a' <= c && c <= 'z' {
continue
}
if '0' <= c && c <= '9' {
continue
}
if 'A' <= c && c <= 'Z' {
continue
}
if c == '-' && j != 0 {
continue
}
if c == '_' {
// Not a valid character in hostnames, but commonly
// found in deployments outside the WebPKI.
continue
}
return false
}
}
return true
}
func matchExactly(hostA, hostB string) bool {
if hostA == "" || hostA == "." || hostB == "" || hostB == "." {
return false
}
return toLowerCaseASCII(hostA) == toLowerCaseASCII(hostB)
}
func matchHostnames(pattern, host string) bool {
pattern = toLowerCaseASCII(pattern)
host = toLowerCaseASCII(strings.TrimSuffix(host, "."))
if len(pattern) == 0 || len(host) == 0 {
return false
}
patternParts := strings.Split(pattern, ".")
hostParts := strings.Split(host, ".")
if len(patternParts) != len(hostParts) {
return false
}
for i, patternPart := range patternParts {
if i == 0 && patternPart == "*" {
continue
}
if patternPart != hostParts[i] {
return false
}
}
return true
}
// toLowerCaseASCII returns a lower-case version of in. See RFC 6125 6.4.1. We use
// an explicitly ASCII function to avoid any sharp corners resulting from
// performing Unicode operations on DNS labels.
func toLowerCaseASCII(in string) string {
// If the string is already lower-case then there's nothing to do.
isAlreadyLowerCase := true
for _, c := range in {
if c == utf8.RuneError {
// If we get a UTF-8 error then there might be
// upper-case ASCII bytes in the invalid sequence.
isAlreadyLowerCase = false
break
}
if 'A' <= c && c <= 'Z' {
isAlreadyLowerCase = false
break
}
}
if isAlreadyLowerCase {
return in
}
out := []byte(in)
for i, c := range out {
if 'A' <= c && c <= 'Z' {
out[i] += 'a' - 'A'
}
}
return string(out)
}
// VerifyHostname returns nil if c is a valid certificate for the named host.
// Otherwise it returns an error describing the mismatch.
//
// IP addresses can be optionally enclosed in square brackets and are checked
// against the IPAddresses field. Other names are checked case insensitively
// against the DNSNames field. If the names are valid hostnames, the certificate
// fields can have a wildcard as the left-most label.
//
// Note that the legacy Common Name field is ignored.
func (c *Certificate) VerifyHostname(h string) error {
// IP addresses may be written in [ ].
candidateIP := h
if len(h) >= 3 && h[0] == '[' && h[len(h)-1] == ']' {
candidateIP = h[1 : len(h)-1]
}
if ip := net.ParseIP(candidateIP); ip != nil {
// We only match IP addresses against IP SANs.
// See RFC 6125, Appendix B.2.
for _, candidate := range c.IPAddresses {
if ip.Equal(candidate) {
return nil
}
}
return HostnameError{c, candidateIP}
}
candidateName := toLowerCaseASCII(h) // Save allocations inside the loop.
validCandidateName := validHostnameInput(candidateName)
for _, match := range c.DNSNames {
// Ideally, we'd only match valid hostnames according to RFC 6125 like
// browsers (more or less) do, but in practice Go is used in a wider
// array of contexts and can't even assume DNS resolution. Instead,
// always allow perfect matches, and only apply wildcard and trailing
// dot processing to valid hostnames.
if validCandidateName && validHostnamePattern(match) {
if matchHostnames(match, candidateName) {
return nil
}
} else {
if matchExactly(match, candidateName) {
return nil
}
}
}
return HostnameError{c, h}
}
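// Usage sketch (illustrative; cert is a parsed *Certificate):
//
//	if err := cert.VerifyHostname("mail.example.com"); err != nil {
//		// no matching DNS SAN; err describes the mismatch
//	}
//	// IP addresses may be bracketed, as in URLs:
//	_ = cert.VerifyHostname("[2001:db8::1]")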
func checkChainForKeyUsage(chain []*Certificate, keyUsages []ExtKeyUsage) bool {
usages := make([]ExtKeyUsage, len(keyUsages))
copy(usages, keyUsages)
if len(chain) == 0 {
return false
}
usagesRemaining := len(usages)
// We walk down the list and cross out any usages that aren't supported
// by each certificate. If we cross out all the usages, then the chain
// is unacceptable.
NextCert:
for i := len(chain) - 1; i >= 0; i-- {
cert := chain[i]
if len(cert.ExtKeyUsage) == 0 && len(cert.UnknownExtKeyUsage) == 0 {
// The certificate doesn't have any extended key usage specified.
continue
}
for _, usage := range cert.ExtKeyUsage {
if usage == ExtKeyUsageAny {
// The certificate is explicitly good for any usage.
continue NextCert
}
}
const invalidUsage ExtKeyUsage = -1
NextRequestedUsage:
for i, requestedUsage := range usages {
if requestedUsage == invalidUsage {
continue
}
for _, usage := range cert.ExtKeyUsage {
if requestedUsage == usage {
continue NextRequestedUsage
}
}
usages[i] = invalidUsage
usagesRemaining--
if usagesRemaining == 0 {
return false
}
}
}
return true
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package x509 implements a subset of the X.509 standard.
//
// It allows parsing and generating certificates, certificate signing
// requests, certificate revocation lists, and encoded public and private keys.
// It provides a certificate verifier, complete with a chain builder.
//
// The package targets the X.509 technical profile defined by the IETF (RFC
// 2459/3280/5280), and as further restricted by the CA/Browser Forum Baseline
// Requirements. There is minimal support for features outside of these
// profiles, as the primary goal of the package is to provide compatibility
// with the publicly trusted TLS certificate ecosystem and its policies and
// constraints.
//
// On macOS and Windows, certificate verification is handled by system APIs, but
// the package aims to apply consistent validation rules across operating
// systems.
package x509
import (
"bytes"
"crypto"
"crypto/ecdh"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rsa"
"crypto/sha1"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/pem"
"errors"
"fmt"
"internal/godebug"
"io"
"math/big"
"net"
"net/url"
"strconv"
"time"
"unicode"
// Explicitly import these for their crypto.RegisterHash init side-effects.
// Keep these as blank imports, even if they're imported above.
_ "crypto/sha1"
_ "crypto/sha256"
_ "crypto/sha512"
"golang.org/x/crypto/cryptobyte"
cryptobyte_asn1 "golang.org/x/crypto/cryptobyte/asn1"
)
// pkixPublicKey reflects a PKIX public key structure. See SubjectPublicKeyInfo
// in RFC 3280.
type pkixPublicKey struct {
Algo pkix.AlgorithmIdentifier
BitString asn1.BitString
}
// ParsePKIXPublicKey parses a public key in PKIX, ASN.1 DER form. The encoded
// public key is a SubjectPublicKeyInfo structure (see RFC 5280, Section 4.1).
//
// It returns a *rsa.PublicKey, *dsa.PublicKey, *ecdsa.PublicKey,
// ed25519.PublicKey (not a pointer), or *ecdh.PublicKey (for X25519).
// More types might be supported in the future.
//
// This kind of key is commonly encoded in PEM blocks of type "PUBLIC KEY".
func ParsePKIXPublicKey(derBytes []byte) (pub any, err error) {
var pki publicKeyInfo
if rest, err := asn1.Unmarshal(derBytes, &pki); err != nil {
if _, err := asn1.Unmarshal(derBytes, &pkcs1PublicKey{}); err == nil {
return nil, errors.New("x509: failed to parse public key (use ParsePKCS1PublicKey instead for this key format)")
}
return nil, err
} else if len(rest) != 0 {
return nil, errors.New("x509: trailing data after ASN.1 of public-key")
}
return parsePublicKey(&pki)
}
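// A minimal usage sketch (illustrative): decode a "PUBLIC KEY" PEM block
// and branch on the returned type. pubPEM is a placeholder for
// caller-supplied data.
//
//	block, _ := pem.Decode(pubPEM)
//	if block == nil || block.Type != "PUBLIC KEY" {
//		return errors.New("no PUBLIC KEY block found")
//	}
//	pub, err := ParsePKIXPublicKey(block.Bytes)
//	if err != nil {
//		return err
//	}
//	switch pub := pub.(type) {
//	case *rsa.PublicKey:
//		_ = pub
//	case *ecdsa.PublicKey:
//		_ = pub
//	case ed25519.PublicKey:
//		_ = pub
//	}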
func marshalPublicKey(pub any) (publicKeyBytes []byte, publicKeyAlgorithm pkix.AlgorithmIdentifier, err error) {
switch pub := pub.(type) {
case *rsa.PublicKey:
publicKeyBytes, err = asn1.Marshal(pkcs1PublicKey{
N: pub.N,
E: pub.E,
})
if err != nil {
return nil, pkix.AlgorithmIdentifier{}, err
}
publicKeyAlgorithm.Algorithm = oidPublicKeyRSA
// This is a NULL parameters value which is required by
// RFC 3279, Section 2.3.1.
publicKeyAlgorithm.Parameters = asn1.NullRawValue
case *ecdsa.PublicKey:
oid, ok := oidFromNamedCurve(pub.Curve)
if !ok {
return nil, pkix.AlgorithmIdentifier{}, errors.New("x509: unsupported elliptic curve")
}
if !pub.Curve.IsOnCurve(pub.X, pub.Y) {
return nil, pkix.AlgorithmIdentifier{}, errors.New("x509: invalid elliptic curve public key")
}
publicKeyBytes = elliptic.Marshal(pub.Curve, pub.X, pub.Y)
publicKeyAlgorithm.Algorithm = oidPublicKeyECDSA
var paramBytes []byte
paramBytes, err = asn1.Marshal(oid)
if err != nil {
return
}
publicKeyAlgorithm.Parameters.FullBytes = paramBytes
case ed25519.PublicKey:
publicKeyBytes = pub
publicKeyAlgorithm.Algorithm = oidPublicKeyEd25519
case *ecdh.PublicKey:
publicKeyBytes = pub.Bytes()
if pub.Curve() == ecdh.X25519() {
publicKeyAlgorithm.Algorithm = oidPublicKeyX25519
} else {
oid, ok := oidFromECDHCurve(pub.Curve())
if !ok {
return nil, pkix.AlgorithmIdentifier{}, errors.New("x509: unsupported elliptic curve")
}
publicKeyAlgorithm.Algorithm = oidPublicKeyECDSA
var paramBytes []byte
paramBytes, err = asn1.Marshal(oid)
if err != nil {
return
}
publicKeyAlgorithm.Parameters.FullBytes = paramBytes
}
default:
return nil, pkix.AlgorithmIdentifier{}, fmt.Errorf("x509: unsupported public key type: %T", pub)
}
return publicKeyBytes, publicKeyAlgorithm, nil
}
// MarshalPKIXPublicKey converts a public key to PKIX, ASN.1 DER form.
// The encoded public key is a SubjectPublicKeyInfo structure
// (see RFC 5280, Section 4.1).
//
// The following key types are currently supported: *rsa.PublicKey,
// *ecdsa.PublicKey, ed25519.PublicKey (not a pointer), and *ecdh.PublicKey.
// Unsupported key types result in an error.
//
// This kind of key is commonly encoded in PEM blocks of type "PUBLIC KEY".
func MarshalPKIXPublicKey(pub any) ([]byte, error) {
var publicKeyBytes []byte
var publicKeyAlgorithm pkix.AlgorithmIdentifier
var err error
if publicKeyBytes, publicKeyAlgorithm, err = marshalPublicKey(pub); err != nil {
return nil, err
}
pkix := pkixPublicKey{
Algo: publicKeyAlgorithm,
BitString: asn1.BitString{
Bytes: publicKeyBytes,
BitLength: 8 * len(publicKeyBytes),
},
}
ret, _ := asn1.Marshal(pkix)
return ret, nil
}
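// A minimal sketch (illustrative): emit the public half of an ECDSA key as
// a "PUBLIC KEY" PEM block. Assumes priv is an *ecdsa.PrivateKey and that
// encoding/pem is imported by the caller.
//
//	der, err := MarshalPKIXPublicKey(&priv.PublicKey)
//	if err != nil {
//		return err
//	}
//	pemBytes := pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: der})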
// These structures reflect the ASN.1 structure of X.509 certificates:
type certificate struct {
TBSCertificate tbsCertificate
SignatureAlgorithm pkix.AlgorithmIdentifier
SignatureValue asn1.BitString
}
type tbsCertificate struct {
Raw asn1.RawContent
Version int `asn1:"optional,explicit,default:0,tag:0"`
SerialNumber *big.Int
SignatureAlgorithm pkix.AlgorithmIdentifier
Issuer asn1.RawValue
Validity validity
Subject asn1.RawValue
PublicKey publicKeyInfo
UniqueId asn1.BitString `asn1:"optional,tag:1"`
SubjectUniqueId asn1.BitString `asn1:"optional,tag:2"`
Extensions []pkix.Extension `asn1:"omitempty,optional,explicit,tag:3"`
}
type dsaAlgorithmParameters struct {
P, Q, G *big.Int
}
type validity struct {
NotBefore, NotAfter time.Time
}
type publicKeyInfo struct {
Raw asn1.RawContent
Algorithm pkix.AlgorithmIdentifier
PublicKey asn1.BitString
}
// RFC 5280, 4.2.1.1
type authKeyId struct {
Id []byte `asn1:"optional,tag:0"`
}
type SignatureAlgorithm int
const (
UnknownSignatureAlgorithm SignatureAlgorithm = iota
MD2WithRSA // Unsupported.
MD5WithRSA // Only supported for signing, not verification.
SHA1WithRSA // Only supported for signing, and verification of CRLs, CSRs, and OCSP responses.
SHA256WithRSA
SHA384WithRSA
SHA512WithRSA
DSAWithSHA1 // Unsupported.
DSAWithSHA256 // Unsupported.
ECDSAWithSHA1 // Only supported for signing, and verification of CRLs, CSRs, and OCSP responses.
ECDSAWithSHA256
ECDSAWithSHA384
ECDSAWithSHA512
SHA256WithRSAPSS
SHA384WithRSAPSS
SHA512WithRSAPSS
PureEd25519
)
func (algo SignatureAlgorithm) isRSAPSS() bool {
switch algo {
case SHA256WithRSAPSS, SHA384WithRSAPSS, SHA512WithRSAPSS:
return true
default:
return false
}
}
func (algo SignatureAlgorithm) String() string {
for _, details := range signatureAlgorithmDetails {
if details.algo == algo {
return details.name
}
}
return strconv.Itoa(int(algo))
}
type PublicKeyAlgorithm int
const (
UnknownPublicKeyAlgorithm PublicKeyAlgorithm = iota
RSA
DSA // Only supported for parsing.
ECDSA
Ed25519
)
var publicKeyAlgoName = [...]string{
RSA: "RSA",
DSA: "DSA",
ECDSA: "ECDSA",
Ed25519: "Ed25519",
}
func (algo PublicKeyAlgorithm) String() string {
if 0 < algo && int(algo) < len(publicKeyAlgoName) {
return publicKeyAlgoName[algo]
}
return strconv.Itoa(int(algo))
}
// OIDs for signature algorithms
//
// pkcs-1 OBJECT IDENTIFIER ::= {
// iso(1) member-body(2) us(840) rsadsi(113549) pkcs(1) 1 }
//
// RFC 3279 2.2.1 RSA Signature Algorithms
//
// md2WithRSAEncryption OBJECT IDENTIFIER ::= { pkcs-1 2 }
//
// md5WithRSAEncryption OBJECT IDENTIFIER ::= { pkcs-1 4 }
//
// sha-1WithRSAEncryption OBJECT IDENTIFIER ::= { pkcs-1 5 }
//
// dsaWithSha1 OBJECT IDENTIFIER ::= {
// iso(1) member-body(2) us(840) x9-57(10040) x9cm(4) 3 }
//
// RFC 3279 2.2.3 ECDSA Signature Algorithm
//
// ecdsa-with-SHA1 OBJECT IDENTIFIER ::= {
// iso(1) member-body(2) us(840) ansi-x962(10045)
// signatures(4) ecdsa-with-SHA1(1)}
//
// RFC 4055 5 PKCS #1 Version 1.5
//
// sha256WithRSAEncryption OBJECT IDENTIFIER ::= { pkcs-1 11 }
//
// sha384WithRSAEncryption OBJECT IDENTIFIER ::= { pkcs-1 12 }
//
// sha512WithRSAEncryption OBJECT IDENTIFIER ::= { pkcs-1 13 }
//
// RFC 5758 3.1 DSA Signature Algorithms
//
// dsaWithSha256 OBJECT IDENTIFIER ::= {
// joint-iso-ccitt(2) country(16) us(840) organization(1) gov(101)
// csor(3) algorithms(4) id-dsa-with-sha2(3) 2}
//
// RFC 5758 3.2 ECDSA Signature Algorithm
//
// ecdsa-with-SHA256 OBJECT IDENTIFIER ::= { iso(1) member-body(2)
// us(840) ansi-X9-62(10045) signatures(4) ecdsa-with-SHA2(3) 2 }
//
// ecdsa-with-SHA384 OBJECT IDENTIFIER ::= { iso(1) member-body(2)
// us(840) ansi-X9-62(10045) signatures(4) ecdsa-with-SHA2(3) 3 }
//
// ecdsa-with-SHA512 OBJECT IDENTIFIER ::= { iso(1) member-body(2)
// us(840) ansi-X9-62(10045) signatures(4) ecdsa-with-SHA2(3) 4 }
//
// RFC 8410 3 Curve25519 and Curve448 Algorithm Identifiers
//
// id-Ed25519 OBJECT IDENTIFIER ::= { 1 3 101 112 }
var (
oidSignatureMD2WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 2}
oidSignatureMD5WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 4}
oidSignatureSHA1WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 5}
oidSignatureSHA256WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 11}
oidSignatureSHA384WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 12}
oidSignatureSHA512WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 13}
oidSignatureRSAPSS = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 10}
oidSignatureDSAWithSHA1 = asn1.ObjectIdentifier{1, 2, 840, 10040, 4, 3}
oidSignatureDSAWithSHA256 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 3, 2}
oidSignatureECDSAWithSHA1 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 1}
oidSignatureECDSAWithSHA256 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 3, 2}
oidSignatureECDSAWithSHA384 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 3, 3}
oidSignatureECDSAWithSHA512 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 3, 4}
oidSignatureEd25519 = asn1.ObjectIdentifier{1, 3, 101, 112}
oidSHA256 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 1}
oidSHA384 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 2}
oidSHA512 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 3}
oidMGF1 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 8}
// oidISOSignatureSHA1WithRSA means the same as oidSignatureSHA1WithRSA
// but it's specified by ISO. Microsoft's makecert.exe has been known
// to produce certificates with this OID.
oidISOSignatureSHA1WithRSA = asn1.ObjectIdentifier{1, 3, 14, 3, 2, 29}
)
var signatureAlgorithmDetails = []struct {
algo SignatureAlgorithm
name string
oid asn1.ObjectIdentifier
pubKeyAlgo PublicKeyAlgorithm
hash crypto.Hash
}{
{MD2WithRSA, "MD2-RSA", oidSignatureMD2WithRSA, RSA, crypto.Hash(0) /* no value for MD2 */},
{MD5WithRSA, "MD5-RSA", oidSignatureMD5WithRSA, RSA, crypto.MD5},
{SHA1WithRSA, "SHA1-RSA", oidSignatureSHA1WithRSA, RSA, crypto.SHA1},
{SHA1WithRSA, "SHA1-RSA", oidISOSignatureSHA1WithRSA, RSA, crypto.SHA1},
{SHA256WithRSA, "SHA256-RSA", oidSignatureSHA256WithRSA, RSA, crypto.SHA256},
{SHA384WithRSA, "SHA384-RSA", oidSignatureSHA384WithRSA, RSA, crypto.SHA384},
{SHA512WithRSA, "SHA512-RSA", oidSignatureSHA512WithRSA, RSA, crypto.SHA512},
{SHA256WithRSAPSS, "SHA256-RSAPSS", oidSignatureRSAPSS, RSA, crypto.SHA256},
{SHA384WithRSAPSS, "SHA384-RSAPSS", oidSignatureRSAPSS, RSA, crypto.SHA384},
{SHA512WithRSAPSS, "SHA512-RSAPSS", oidSignatureRSAPSS, RSA, crypto.SHA512},
{DSAWithSHA1, "DSA-SHA1", oidSignatureDSAWithSHA1, DSA, crypto.SHA1},
{DSAWithSHA256, "DSA-SHA256", oidSignatureDSAWithSHA256, DSA, crypto.SHA256},
{ECDSAWithSHA1, "ECDSA-SHA1", oidSignatureECDSAWithSHA1, ECDSA, crypto.SHA1},
{ECDSAWithSHA256, "ECDSA-SHA256", oidSignatureECDSAWithSHA256, ECDSA, crypto.SHA256},
{ECDSAWithSHA384, "ECDSA-SHA384", oidSignatureECDSAWithSHA384, ECDSA, crypto.SHA384},
{ECDSAWithSHA512, "ECDSA-SHA512", oidSignatureECDSAWithSHA512, ECDSA, crypto.SHA512},
{PureEd25519, "Ed25519", oidSignatureEd25519, Ed25519, crypto.Hash(0) /* no pre-hashing */},
}
// hashToPSSParameters contains the DER encoded RSA PSS parameters for the
// SHA256, SHA384, and SHA512 hashes as defined in RFC 3447, Appendix A.2.3.
// The parameters contain the following values:
// - hashAlgorithm contains the associated hash identifier with NULL parameters
// - maskGenAlgorithm always contains the default mgf1SHA1 identifier
// - saltLength contains the length of the associated hash
// - trailerField always contains the default trailerFieldBC value
var hashToPSSParameters = map[crypto.Hash]asn1.RawValue{
crypto.SHA256: asn1.RawValue{FullBytes: []byte{48, 52, 160, 15, 48, 13, 6, 9, 96, 134, 72, 1, 101, 3, 4, 2, 1, 5, 0, 161, 28, 48, 26, 6, 9, 42, 134, 72, 134, 247, 13, 1, 1, 8, 48, 13, 6, 9, 96, 134, 72, 1, 101, 3, 4, 2, 1, 5, 0, 162, 3, 2, 1, 32}},
crypto.SHA384: asn1.RawValue{FullBytes: []byte{48, 52, 160, 15, 48, 13, 6, 9, 96, 134, 72, 1, 101, 3, 4, 2, 2, 5, 0, 161, 28, 48, 26, 6, 9, 42, 134, 72, 134, 247, 13, 1, 1, 8, 48, 13, 6, 9, 96, 134, 72, 1, 101, 3, 4, 2, 2, 5, 0, 162, 3, 2, 1, 48}},
crypto.SHA512: asn1.RawValue{FullBytes: []byte{48, 52, 160, 15, 48, 13, 6, 9, 96, 134, 72, 1, 101, 3, 4, 2, 3, 5, 0, 161, 28, 48, 26, 6, 9, 42, 134, 72, 134, 247, 13, 1, 1, 8, 48, 13, 6, 9, 96, 134, 72, 1, 101, 3, 4, 2, 3, 5, 0, 162, 3, 2, 1, 64}},
}
// pssParameters reflects the parameters in an AlgorithmIdentifier that
// specifies RSA PSS. See RFC 3447, Appendix A.2.3.
type pssParameters struct {
// The following three fields are not marked as
// optional because the default values specify SHA-1,
// which is no longer suitable for use in signatures.
Hash pkix.AlgorithmIdentifier `asn1:"explicit,tag:0"`
MGF pkix.AlgorithmIdentifier `asn1:"explicit,tag:1"`
SaltLength int `asn1:"explicit,tag:2"`
TrailerField int `asn1:"optional,explicit,tag:3,default:1"`
}
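// Illustrative sketch, not part of the original source: the pre-computed
// SHA256 blob in hashToPSSParameters decodes back into a pssParameters value
// whose salt length matches the hash size and whose MGF is mgf1. The helper
// name is hypothetical.
func examplePSSParameters() bool {
	var p pssParameters
	_, err := asn1.Unmarshal(hashToPSSParameters[crypto.SHA256].FullBytes, &p)
	return err == nil && p.SaltLength == 32 && p.MGF.Algorithm.Equal(oidMGF1)
}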
func getSignatureAlgorithmFromAI(ai pkix.AlgorithmIdentifier) SignatureAlgorithm {
if ai.Algorithm.Equal(oidSignatureEd25519) {
// RFC 8410, Section 3
// > For all of the OIDs, the parameters MUST be absent.
if len(ai.Parameters.FullBytes) != 0 {
return UnknownSignatureAlgorithm
}
}
if !ai.Algorithm.Equal(oidSignatureRSAPSS) {
for _, details := range signatureAlgorithmDetails {
if ai.Algorithm.Equal(details.oid) {
return details.algo
}
}
return UnknownSignatureAlgorithm
}
// RSA PSS is special because it encodes important parameters
// in the Parameters.
var params pssParameters
if _, err := asn1.Unmarshal(ai.Parameters.FullBytes, &params); err != nil {
return UnknownSignatureAlgorithm
}
var mgf1HashFunc pkix.AlgorithmIdentifier
if _, err := asn1.Unmarshal(params.MGF.Parameters.FullBytes, &mgf1HashFunc); err != nil {
return UnknownSignatureAlgorithm
}
// PSS is greatly overburdened with options. This code forces them into
// three buckets by requiring that the MGF1 hash function always match the
// message hash function (as recommended in RFC 3447, Section 8.1), that the
// salt length matches the hash length, and that the trailer field has the
// default value.
if (len(params.Hash.Parameters.FullBytes) != 0 && !bytes.Equal(params.Hash.Parameters.FullBytes, asn1.NullBytes)) ||
!params.MGF.Algorithm.Equal(oidMGF1) ||
!mgf1HashFunc.Algorithm.Equal(params.Hash.Algorithm) ||
(len(mgf1HashFunc.Parameters.FullBytes) != 0 && !bytes.Equal(mgf1HashFunc.Parameters.FullBytes, asn1.NullBytes)) ||
params.TrailerField != 1 {
return UnknownSignatureAlgorithm
}
switch {
case params.Hash.Algorithm.Equal(oidSHA256) && params.SaltLength == 32:
return SHA256WithRSAPSS
case params.Hash.Algorithm.Equal(oidSHA384) && params.SaltLength == 48:
return SHA384WithRSAPSS
case params.Hash.Algorithm.Equal(oidSHA512) && params.SaltLength == 64:
return SHA512WithRSAPSS
}
return UnknownSignatureAlgorithm
}
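// Illustrative sketch, not part of the original source: a plain (non-PSS,
// non-Ed25519) AlgorithmIdentifier resolves directly through the
// signatureAlgorithmDetails table. The helper name is hypothetical.
func exampleSignatureAlgorithmFromAI() bool {
	ai := pkix.AlgorithmIdentifier{Algorithm: oidSignatureECDSAWithSHA256}
	return getSignatureAlgorithmFromAI(ai) == ECDSAWithSHA256
}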
var (
// RFC 3279, 2.3 Public Key Algorithms
//
// pkcs-1 OBJECT IDENTIFIER ::= { iso(1) member-body(2) us(840)
// rsadsi(113549) pkcs(1) 1 }
//
// rsaEncryption OBJECT IDENTIFIER ::= { pkcs-1 1 }
//
// id-dsa OBJECT IDENTIFIER ::= { iso(1) member-body(2) us(840)
// x9-57(10040) x9cm(4) 1 }
oidPublicKeyRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 1}
oidPublicKeyDSA = asn1.ObjectIdentifier{1, 2, 840, 10040, 4, 1}
// RFC 5480, 2.1.1 Unrestricted Algorithm Identifier and Parameters
//
// id-ecPublicKey OBJECT IDENTIFIER ::= {
// iso(1) member-body(2) us(840) ansi-X9-62(10045) keyType(2) 1 }
oidPublicKeyECDSA = asn1.ObjectIdentifier{1, 2, 840, 10045, 2, 1}
// RFC 8410, Section 3
//
// id-X25519 OBJECT IDENTIFIER ::= { 1 3 101 110 }
// id-Ed25519 OBJECT IDENTIFIER ::= { 1 3 101 112 }
oidPublicKeyX25519 = asn1.ObjectIdentifier{1, 3, 101, 110}
oidPublicKeyEd25519 = asn1.ObjectIdentifier{1, 3, 101, 112}
)
// getPublicKeyAlgorithmFromOID returns the exposed PublicKeyAlgorithm
// identifier for public key types supported in certificates and CSRs. Marshal
// and Parse functions may support a different set of public key types.
func getPublicKeyAlgorithmFromOID(oid asn1.ObjectIdentifier) PublicKeyAlgorithm {
switch {
case oid.Equal(oidPublicKeyRSA):
return RSA
case oid.Equal(oidPublicKeyDSA):
return DSA
case oid.Equal(oidPublicKeyECDSA):
return ECDSA
case oid.Equal(oidPublicKeyEd25519):
return Ed25519
}
return UnknownPublicKeyAlgorithm
}
// RFC 5480, 2.1.1.1. Named Curve
//
// secp224r1 OBJECT IDENTIFIER ::= {
// iso(1) identified-organization(3) certicom(132) curve(0) 33 }
//
// secp256r1 OBJECT IDENTIFIER ::= {
// iso(1) member-body(2) us(840) ansi-X9-62(10045) curves(3)
// prime(1) 7 }
//
// secp384r1 OBJECT IDENTIFIER ::= {
// iso(1) identified-organization(3) certicom(132) curve(0) 34 }
//
// secp521r1 OBJECT IDENTIFIER ::= {
// iso(1) identified-organization(3) certicom(132) curve(0) 35 }
//
// NB: secp256r1 is equivalent to prime256v1
var (
oidNamedCurveP224 = asn1.ObjectIdentifier{1, 3, 132, 0, 33}
oidNamedCurveP256 = asn1.ObjectIdentifier{1, 2, 840, 10045, 3, 1, 7}
oidNamedCurveP384 = asn1.ObjectIdentifier{1, 3, 132, 0, 34}
oidNamedCurveP521 = asn1.ObjectIdentifier{1, 3, 132, 0, 35}
)
func namedCurveFromOID(oid asn1.ObjectIdentifier) elliptic.Curve {
switch {
case oid.Equal(oidNamedCurveP224):
return elliptic.P224()
case oid.Equal(oidNamedCurveP256):
return elliptic.P256()
case oid.Equal(oidNamedCurveP384):
return elliptic.P384()
case oid.Equal(oidNamedCurveP521):
return elliptic.P521()
}
return nil
}
func oidFromNamedCurve(curve elliptic.Curve) (asn1.ObjectIdentifier, bool) {
switch curve {
case elliptic.P224():
return oidNamedCurveP224, true
case elliptic.P256():
return oidNamedCurveP256, true
case elliptic.P384():
return oidNamedCurveP384, true
case elliptic.P521():
return oidNamedCurveP521, true
}
return nil, false
}
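// Illustrative sketch, not part of the original source: namedCurveFromOID and
// oidFromNamedCurve are inverses for the supported NIST curves, since the
// elliptic.P* constructors return package-level singletons. The helper name
// is hypothetical.
func exampleCurveOIDRoundTrip() bool {
	oid, ok := oidFromNamedCurve(elliptic.P384())
	return ok && namedCurveFromOID(oid) == elliptic.P384()
}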
func oidFromECDHCurve(curve ecdh.Curve) (asn1.ObjectIdentifier, bool) {
switch curve {
case ecdh.X25519():
return oidPublicKeyX25519, true
case ecdh.P256():
return oidNamedCurveP256, true
case ecdh.P384():
return oidNamedCurveP384, true
case ecdh.P521():
return oidNamedCurveP521, true
}
return nil, false
}
// KeyUsage represents the set of actions that are valid for a given key. It's
// a bitmap of the KeyUsage* constants.
type KeyUsage int
const (
KeyUsageDigitalSignature KeyUsage = 1 << iota
KeyUsageContentCommitment
KeyUsageKeyEncipherment
KeyUsageDataEncipherment
KeyUsageKeyAgreement
KeyUsageCertSign
KeyUsageCRLSign
KeyUsageEncipherOnly
KeyUsageDecipherOnly
)
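// Illustrative sketch, not part of the original source: KeyUsage values
// combine with bitwise OR and are tested with bitwise AND, as with any Go
// bit set. The helper name is hypothetical.
func exampleKeyUsageBitmap() bool {
	ku := KeyUsageDigitalSignature | KeyUsageKeyEncipherment
	return ku&KeyUsageKeyEncipherment != 0 && ku&KeyUsageCertSign == 0
}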
// RFC 5280, 4.2.1.12 Extended Key Usage
//
// anyExtendedKeyUsage OBJECT IDENTIFIER ::= { id-ce-extKeyUsage 0 }
//
// id-kp OBJECT IDENTIFIER ::= { id-pkix 3 }
//
// id-kp-serverAuth OBJECT IDENTIFIER ::= { id-kp 1 }
// id-kp-clientAuth OBJECT IDENTIFIER ::= { id-kp 2 }
// id-kp-codeSigning OBJECT IDENTIFIER ::= { id-kp 3 }
// id-kp-emailProtection OBJECT IDENTIFIER ::= { id-kp 4 }
// id-kp-timeStamping OBJECT IDENTIFIER ::= { id-kp 8 }
// id-kp-OCSPSigning OBJECT IDENTIFIER ::= { id-kp 9 }
var (
oidExtKeyUsageAny = asn1.ObjectIdentifier{2, 5, 29, 37, 0}
oidExtKeyUsageServerAuth = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 1}
oidExtKeyUsageClientAuth = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 2}
oidExtKeyUsageCodeSigning = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 3}
oidExtKeyUsageEmailProtection = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 4}
oidExtKeyUsageIPSECEndSystem = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 5}
oidExtKeyUsageIPSECTunnel = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 6}
oidExtKeyUsageIPSECUser = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 7}
oidExtKeyUsageTimeStamping = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 8}
oidExtKeyUsageOCSPSigning = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 9}
oidExtKeyUsageMicrosoftServerGatedCrypto = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 311, 10, 3, 3}
oidExtKeyUsageNetscapeServerGatedCrypto = asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 4, 1}
oidExtKeyUsageMicrosoftCommercialCodeSigning = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 311, 2, 1, 22}
oidExtKeyUsageMicrosoftKernelCodeSigning = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 311, 61, 1, 1}
)
// ExtKeyUsage represents an extended set of actions that are valid for a given key.
// Each of the ExtKeyUsage* constants define a unique action.
type ExtKeyUsage int
const (
ExtKeyUsageAny ExtKeyUsage = iota
ExtKeyUsageServerAuth
ExtKeyUsageClientAuth
ExtKeyUsageCodeSigning
ExtKeyUsageEmailProtection
ExtKeyUsageIPSECEndSystem
ExtKeyUsageIPSECTunnel
ExtKeyUsageIPSECUser
ExtKeyUsageTimeStamping
ExtKeyUsageOCSPSigning
ExtKeyUsageMicrosoftServerGatedCrypto
ExtKeyUsageNetscapeServerGatedCrypto
ExtKeyUsageMicrosoftCommercialCodeSigning
ExtKeyUsageMicrosoftKernelCodeSigning
)
// extKeyUsageOIDs contains the mapping between an ExtKeyUsage and its OID.
var extKeyUsageOIDs = []struct {
extKeyUsage ExtKeyUsage
oid asn1.ObjectIdentifier
}{
{ExtKeyUsageAny, oidExtKeyUsageAny},
{ExtKeyUsageServerAuth, oidExtKeyUsageServerAuth},
{ExtKeyUsageClientAuth, oidExtKeyUsageClientAuth},
{ExtKeyUsageCodeSigning, oidExtKeyUsageCodeSigning},
{ExtKeyUsageEmailProtection, oidExtKeyUsageEmailProtection},
{ExtKeyUsageIPSECEndSystem, oidExtKeyUsageIPSECEndSystem},
{ExtKeyUsageIPSECTunnel, oidExtKeyUsageIPSECTunnel},
{ExtKeyUsageIPSECUser, oidExtKeyUsageIPSECUser},
{ExtKeyUsageTimeStamping, oidExtKeyUsageTimeStamping},
{ExtKeyUsageOCSPSigning, oidExtKeyUsageOCSPSigning},
{ExtKeyUsageMicrosoftServerGatedCrypto, oidExtKeyUsageMicrosoftServerGatedCrypto},
{ExtKeyUsageNetscapeServerGatedCrypto, oidExtKeyUsageNetscapeServerGatedCrypto},
{ExtKeyUsageMicrosoftCommercialCodeSigning, oidExtKeyUsageMicrosoftCommercialCodeSigning},
{ExtKeyUsageMicrosoftKernelCodeSigning, oidExtKeyUsageMicrosoftKernelCodeSigning},
}
func extKeyUsageFromOID(oid asn1.ObjectIdentifier) (eku ExtKeyUsage, ok bool) {
for _, pair := range extKeyUsageOIDs {
if oid.Equal(pair.oid) {
return pair.extKeyUsage, true
}
}
return
}
func oidFromExtKeyUsage(eku ExtKeyUsage) (oid asn1.ObjectIdentifier, ok bool) {
for _, pair := range extKeyUsageOIDs {
if eku == pair.extKeyUsage {
return pair.oid, true
}
}
return
}
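// Illustrative sketch, not part of the original source: the two lookups above
// are inverses over the extKeyUsageOIDs table. The helper name is
// hypothetical.
func exampleExtKeyUsageRoundTrip() bool {
	oid, ok := oidFromExtKeyUsage(ExtKeyUsageServerAuth)
	if !ok {
		return false
	}
	eku, ok := extKeyUsageFromOID(oid)
	return ok && eku == ExtKeyUsageServerAuth
}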
// A Certificate represents an X.509 certificate.
type Certificate struct {
Raw []byte // Complete ASN.1 DER content (certificate, signature algorithm and signature).
RawTBSCertificate []byte // Certificate part of raw ASN.1 DER content.
RawSubjectPublicKeyInfo []byte // DER encoded SubjectPublicKeyInfo.
RawSubject []byte // DER encoded Subject
RawIssuer []byte // DER encoded Issuer
Signature []byte
SignatureAlgorithm SignatureAlgorithm
PublicKeyAlgorithm PublicKeyAlgorithm
PublicKey any
Version int
SerialNumber *big.Int
Issuer pkix.Name
Subject pkix.Name
NotBefore, NotAfter time.Time // Validity bounds.
KeyUsage KeyUsage
// Extensions contains raw X.509 extensions. When parsing certificates,
// this can be used to extract non-critical extensions that are not
// parsed by this package. When marshaling certificates, the Extensions
// field is ignored, see ExtraExtensions.
Extensions []pkix.Extension
// ExtraExtensions contains extensions to be copied, raw, into any
// marshaled certificates. Values override any extensions that would
// otherwise be produced based on the other fields. The ExtraExtensions
// field is not populated when parsing certificates, see Extensions.
ExtraExtensions []pkix.Extension
// UnhandledCriticalExtensions contains a list of extension IDs that
// were not (fully) processed when parsing. Verify will fail if this
// slice is non-empty, unless verification is delegated to an OS
// library which understands all the critical extensions.
//
// Users can access these extensions using Extensions and can remove
// elements from this slice if they believe that they have been
// handled.
UnhandledCriticalExtensions []asn1.ObjectIdentifier
ExtKeyUsage []ExtKeyUsage // Sequence of extended key usages.
UnknownExtKeyUsage []asn1.ObjectIdentifier // Encountered extended key usages unknown to this package.
// BasicConstraintsValid indicates whether IsCA, MaxPathLen,
// and MaxPathLenZero are valid.
BasicConstraintsValid bool
IsCA bool
// MaxPathLen and MaxPathLenZero indicate the presence and
// value of the BasicConstraints' "pathLenConstraint".
//
// When parsing a certificate, a positive MaxPathLen means that
// the field was specified, -1 means it was unset, and
// MaxPathLenZero being true means that the field was explicitly
// set to zero. The case of MaxPathLen==0 with MaxPathLenZero==false
// should be treated as equivalent to -1 (unset).
//
// When generating a certificate, an unset pathLenConstraint
// can be requested with either MaxPathLen == -1 or using the
// zero value for both MaxPathLen and MaxPathLenZero.
MaxPathLen int
// MaxPathLenZero indicates that BasicConstraintsValid==true
// and MaxPathLen==0 should be interpreted as an actual
// maximum path length of zero. Otherwise, that combination is
// interpreted as MaxPathLen not being set.
MaxPathLenZero bool
SubjectKeyId []byte
AuthorityKeyId []byte
// RFC 5280, 4.2.2.1 (Authority Information Access)
OCSPServer []string
IssuingCertificateURL []string
// Subject Alternative Name values. (Note that these values may not be valid
// if invalid values were contained within a parsed certificate. For
// example, an element of DNSNames may not be a valid DNS domain name.)
DNSNames []string
EmailAddresses []string
IPAddresses []net.IP
URIs []*url.URL
// Name constraints
PermittedDNSDomainsCritical bool // if true then the name constraints are marked critical.
PermittedDNSDomains []string
ExcludedDNSDomains []string
PermittedIPRanges []*net.IPNet
ExcludedIPRanges []*net.IPNet
PermittedEmailAddresses []string
ExcludedEmailAddresses []string
PermittedURIDomains []string
ExcludedURIDomains []string
// CRL Distribution Points
CRLDistributionPoints []string
PolicyIdentifiers []asn1.ObjectIdentifier
}
// ErrUnsupportedAlgorithm results from attempting to perform an operation that
// involves algorithms that are not currently implemented.
var ErrUnsupportedAlgorithm = errors.New("x509: cannot verify signature: algorithm unimplemented")
// An InsecureAlgorithmError indicates that the SignatureAlgorithm used to
// generate the signature is not secure, and the signature has been rejected.
//
// To temporarily restore support for SHA-1 signatures, include the value
// "x509sha1=1" in the GODEBUG environment variable. Note that this option will
// be removed in a future release.
type InsecureAlgorithmError SignatureAlgorithm
func (e InsecureAlgorithmError) Error() string {
var override string
if SignatureAlgorithm(e) == SHA1WithRSA || SignatureAlgorithm(e) == ECDSAWithSHA1 {
override = " (temporarily override with GODEBUG=x509sha1=1)"
}
return fmt.Sprintf("x509: cannot verify signature: insecure algorithm %v", SignatureAlgorithm(e)) + override
}
// ConstraintViolationError results when a requested usage is not permitted by
// a certificate. For example: checking a signature when the public key isn't a
// certificate signing key.
type ConstraintViolationError struct{}
func (ConstraintViolationError) Error() string {
return "x509: invalid signature: parent certificate cannot sign this kind of certificate"
}
func (c *Certificate) Equal(other *Certificate) bool {
if c == nil || other == nil {
return c == other
}
return bytes.Equal(c.Raw, other.Raw)
}
func (c *Certificate) hasSANExtension() bool {
return oidInExtensions(oidExtensionSubjectAltName, c.Extensions)
}
// CheckSignatureFrom verifies that the signature on c is a valid signature from parent.
//
// This is a low-level API that performs very limited checks, and not a full
// path verifier. Most users should use [Certificate.Verify] instead.
func (c *Certificate) CheckSignatureFrom(parent *Certificate) error {
// RFC 5280, 4.2.1.9:
// "If the basic constraints extension is not present in a version 3
// certificate, or the extension is present but the cA boolean is not
// asserted, then the certified public key MUST NOT be used to verify
// certificate signatures."
if parent.Version == 3 && !parent.BasicConstraintsValid ||
parent.BasicConstraintsValid && !parent.IsCA {
return ConstraintViolationError{}
}
if parent.KeyUsage != 0 && parent.KeyUsage&KeyUsageCertSign == 0 {
return ConstraintViolationError{}
}
if parent.PublicKeyAlgorithm == UnknownPublicKeyAlgorithm {
return ErrUnsupportedAlgorithm
}
return checkSignature(c.SignatureAlgorithm, c.RawTBSCertificate, c.Signature, parent.PublicKey, false)
}
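// Illustrative sketch, not part of the original source: verifying one
// DER-encoded certificate against its issuer with the low-level check above.
// The helper name and its parameters are hypothetical.
func exampleCheckSignatureFrom(childDER, parentDER []byte) error {
	child, err := ParseCertificate(childDER)
	if err != nil {
		return err
	}
	parent, err := ParseCertificate(parentDER)
	if err != nil {
		return err
	}
	return child.CheckSignatureFrom(parent)
}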
// CheckSignature verifies that signature is a valid signature over signed from
// c's public key.
//
// This is a low-level API that performs no validity checks on the certificate.
//
// [MD5WithRSA] signatures are rejected, while [SHA1WithRSA] and [ECDSAWithSHA1]
// signatures are currently accepted.
func (c *Certificate) CheckSignature(algo SignatureAlgorithm, signed, signature []byte) error {
return checkSignature(algo, signed, signature, c.PublicKey, true)
}
func (c *Certificate) hasNameConstraints() bool {
return oidInExtensions(oidExtensionNameConstraints, c.Extensions)
}
func (c *Certificate) getSANExtension() []byte {
for _, e := range c.Extensions {
if e.Id.Equal(oidExtensionSubjectAltName) {
return e.Value
}
}
return nil
}
func signaturePublicKeyAlgoMismatchError(expectedPubKeyAlgo PublicKeyAlgorithm, pubKey any) error {
return fmt.Errorf("x509: signature algorithm specifies an %s public key, but have public key of type %T", expectedPubKeyAlgo.String(), pubKey)
}
var x509sha1 = godebug.New("x509sha1")
// checkSignature verifies that signature is a valid signature over signed from
// a crypto.PublicKey.
func checkSignature(algo SignatureAlgorithm, signed, signature []byte, publicKey crypto.PublicKey, allowSHA1 bool) (err error) {
var hashType crypto.Hash
var pubKeyAlgo PublicKeyAlgorithm
for _, details := range signatureAlgorithmDetails {
if details.algo == algo {
hashType = details.hash
pubKeyAlgo = details.pubKeyAlgo
}
}
switch hashType {
case crypto.Hash(0):
if pubKeyAlgo != Ed25519 {
return ErrUnsupportedAlgorithm
}
case crypto.MD5:
return InsecureAlgorithmError(algo)
case crypto.SHA1:
// SHA-1 signatures are mostly disabled. See go.dev/issue/41682.
if !allowSHA1 {
if x509sha1.Value() != "1" {
return InsecureAlgorithmError(algo)
}
x509sha1.IncNonDefault()
}
fallthrough
default:
if !hashType.Available() {
return ErrUnsupportedAlgorithm
}
h := hashType.New()
h.Write(signed)
signed = h.Sum(nil)
}
switch pub := publicKey.(type) {
case *rsa.PublicKey:
if pubKeyAlgo != RSA {
return signaturePublicKeyAlgoMismatchError(pubKeyAlgo, pub)
}
if algo.isRSAPSS() {
return rsa.VerifyPSS(pub, hashType, signed, signature, &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash})
} else {
return rsa.VerifyPKCS1v15(pub, hashType, signed, signature)
}
case *ecdsa.PublicKey:
if pubKeyAlgo != ECDSA {
return signaturePublicKeyAlgoMismatchError(pubKeyAlgo, pub)
}
if !ecdsa.VerifyASN1(pub, signed, signature) {
return errors.New("x509: ECDSA verification failure")
}
return
case ed25519.PublicKey:
if pubKeyAlgo != Ed25519 {
return signaturePublicKeyAlgoMismatchError(pubKeyAlgo, pub)
}
if !ed25519.Verify(pub, signed, signature) {
return errors.New("x509: Ed25519 verification failure")
}
return
}
return ErrUnsupportedAlgorithm
}
// CheckCRLSignature checks that the signature in crl is from c.
//
// Deprecated: Use RevocationList.CheckSignatureFrom instead.
func (c *Certificate) CheckCRLSignature(crl *pkix.CertificateList) error {
algo := getSignatureAlgorithmFromAI(crl.SignatureAlgorithm)
return c.CheckSignature(algo, crl.TBSCertList.Raw, crl.SignatureValue.RightAlign())
}
type UnhandledCriticalExtension struct{}
func (h UnhandledCriticalExtension) Error() string {
return "x509: unhandled critical extension"
}
type basicConstraints struct {
IsCA bool `asn1:"optional"`
MaxPathLen int `asn1:"optional,default:-1"`
}
// RFC 5280 4.2.1.4
type policyInformation struct {
Policy asn1.ObjectIdentifier
// policyQualifiers omitted
}
const (
nameTypeEmail = 1
nameTypeDNS = 2
nameTypeURI = 6
nameTypeIP = 7
)
// RFC 5280, 4.2.2.1
type authorityInfoAccess struct {
Method asn1.ObjectIdentifier
Location asn1.RawValue
}
// RFC 5280, 4.2.1.14
type distributionPoint struct {
DistributionPoint distributionPointName `asn1:"optional,tag:0"`
Reason asn1.BitString `asn1:"optional,tag:1"`
CRLIssuer asn1.RawValue `asn1:"optional,tag:2"`
}
type distributionPointName struct {
FullName []asn1.RawValue `asn1:"optional,tag:0"`
RelativeName pkix.RDNSequence `asn1:"optional,tag:1"`
}
func reverseBitsInAByte(in byte) byte {
b1 := in>>4 | in<<4
b2 := b1>>2&0x33 | b1<<2&0xcc
b3 := b2>>1&0x55 | b2<<1&0xaa
return b3
}
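// Illustrative sketch, not part of the original source: the three swap steps
// above mirror a byte end-to-end, so the least-significant bit becomes the
// most-significant one, e.g. 0b0000_0001 becomes 0b1000_0000. The helper
// name is hypothetical.
func exampleReverseBits() bool {
	return reverseBitsInAByte(0x01) == 0x80 && reverseBitsInAByte(0xc0) == 0x03
}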
// asn1BitLength returns the bit-length of bitString by considering the
// most-significant bit in a byte to be the "first" bit. This convention
// matches ASN.1, but differs from almost everything else.
func asn1BitLength(bitString []byte) int {
bitLen := len(bitString) * 8
for i := range bitString {
b := bitString[len(bitString)-i-1]
for bit := uint(0); bit < 8; bit++ {
if (b>>bit)&1 == 1 {
return bitLen
}
bitLen--
}
}
return 0
}
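// Illustrative sketch, not part of the original source: under the ASN.1
// convention above, a single 0x80 byte encodes a one-bit string and 0xFF an
// eight-bit string; trailing zero bits are not counted. The helper name is
// hypothetical.
func exampleASN1BitLength() bool {
	return asn1BitLength([]byte{0x80}) == 1 && asn1BitLength([]byte{0xff}) == 8
}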
var (
oidExtensionSubjectKeyId = []int{2, 5, 29, 14}
oidExtensionKeyUsage = []int{2, 5, 29, 15}
oidExtensionExtendedKeyUsage = []int{2, 5, 29, 37}
oidExtensionAuthorityKeyId = []int{2, 5, 29, 35}
oidExtensionBasicConstraints = []int{2, 5, 29, 19}
oidExtensionSubjectAltName = []int{2, 5, 29, 17}
oidExtensionCertificatePolicies = []int{2, 5, 29, 32}
oidExtensionNameConstraints = []int{2, 5, 29, 30}
oidExtensionCRLDistributionPoints = []int{2, 5, 29, 31}
oidExtensionAuthorityInfoAccess = []int{1, 3, 6, 1, 5, 5, 7, 1, 1}
oidExtensionCRLNumber = []int{2, 5, 29, 20}
)
var (
oidAuthorityInfoAccessOcsp = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 48, 1}
oidAuthorityInfoAccessIssuers = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 48, 2}
)
// oidInExtensions reports whether an extension with the given oid exists in
// extensions.
func oidInExtensions(oid asn1.ObjectIdentifier, extensions []pkix.Extension) bool {
for _, e := range extensions {
if e.Id.Equal(oid) {
return true
}
}
return false
}
// marshalSANs marshals a list of addresses into the contents of an X.509
// SubjectAlternativeName extension.
func marshalSANs(dnsNames, emailAddresses []string, ipAddresses []net.IP, uris []*url.URL) (derBytes []byte, err error) {
var rawValues []asn1.RawValue
for _, name := range dnsNames {
if err := isIA5String(name); err != nil {
return nil, err
}
rawValues = append(rawValues, asn1.RawValue{Tag: nameTypeDNS, Class: 2, Bytes: []byte(name)})
}
for _, email := range emailAddresses {
if err := isIA5String(email); err != nil {
return nil, err
}
rawValues = append(rawValues, asn1.RawValue{Tag: nameTypeEmail, Class: 2, Bytes: []byte(email)})
}
for _, rawIP := range ipAddresses {
// If possible, we always want to encode IPv4 addresses in 4 bytes.
ip := rawIP.To4()
if ip == nil {
ip = rawIP
}
rawValues = append(rawValues, asn1.RawValue{Tag: nameTypeIP, Class: 2, Bytes: ip})
}
for _, uri := range uris {
uriStr := uri.String()
if err := isIA5String(uriStr); err != nil {
return nil, err
}
rawValues = append(rawValues, asn1.RawValue{Tag: nameTypeURI, Class: 2, Bytes: []byte(uriStr)})
}
return asn1.Marshal(rawValues)
}
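// Illustrative sketch, not part of the original source: a single dNSName
// serializes as a GeneralNames SEQUENCE (DER tag 0x30) holding one
// context-specific [2] element. The helper name is hypothetical.
func exampleMarshalSANs() bool {
	der, err := marshalSANs([]string{"example.com"}, nil, nil, nil)
	return err == nil && len(der) > 0 && der[0] == 0x30
}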
func isIA5String(s string) error {
for _, r := range s {
// Per RFC 5280, "IA5String is limited to the set of ASCII characters".
if r > unicode.MaxASCII {
return fmt.Errorf("x509: %q cannot be encoded as an IA5String", s)
}
}
return nil
}
func buildCertExtensions(template *Certificate, subjectIsEmpty bool, authorityKeyId []byte, subjectKeyId []byte) (ret []pkix.Extension, err error) {
ret = make([]pkix.Extension, 10 /* maximum number of elements. */)
n := 0
if template.KeyUsage != 0 &&
!oidInExtensions(oidExtensionKeyUsage, template.ExtraExtensions) {
ret[n], err = marshalKeyUsage(template.KeyUsage)
if err != nil {
return nil, err
}
n++
}
if (len(template.ExtKeyUsage) > 0 || len(template.UnknownExtKeyUsage) > 0) &&
!oidInExtensions(oidExtensionExtendedKeyUsage, template.ExtraExtensions) {
ret[n], err = marshalExtKeyUsage(template.ExtKeyUsage, template.UnknownExtKeyUsage)
if err != nil {
return nil, err
}
n++
}
if template.BasicConstraintsValid && !oidInExtensions(oidExtensionBasicConstraints, template.ExtraExtensions) {
ret[n], err = marshalBasicConstraints(template.IsCA, template.MaxPathLen, template.MaxPathLenZero)
if err != nil {
return nil, err
}
n++
}
if len(subjectKeyId) > 0 && !oidInExtensions(oidExtensionSubjectKeyId, template.ExtraExtensions) {
ret[n].Id = oidExtensionSubjectKeyId
ret[n].Value, err = asn1.Marshal(subjectKeyId)
if err != nil {
return
}
n++
}
if len(authorityKeyId) > 0 && !oidInExtensions(oidExtensionAuthorityKeyId, template.ExtraExtensions) {
ret[n].Id = oidExtensionAuthorityKeyId
ret[n].Value, err = asn1.Marshal(authKeyId{authorityKeyId})
if err != nil {
return
}
n++
}
if (len(template.OCSPServer) > 0 || len(template.IssuingCertificateURL) > 0) &&
!oidInExtensions(oidExtensionAuthorityInfoAccess, template.ExtraExtensions) {
ret[n].Id = oidExtensionAuthorityInfoAccess
var aiaValues []authorityInfoAccess
for _, name := range template.OCSPServer {
aiaValues = append(aiaValues, authorityInfoAccess{
Method: oidAuthorityInfoAccessOcsp,
Location: asn1.RawValue{Tag: 6, Class: 2, Bytes: []byte(name)},
})
}
for _, name := range template.IssuingCertificateURL {
aiaValues = append(aiaValues, authorityInfoAccess{
Method: oidAuthorityInfoAccessIssuers,
Location: asn1.RawValue{Tag: 6, Class: 2, Bytes: []byte(name)},
})
}
ret[n].Value, err = asn1.Marshal(aiaValues)
if err != nil {
return
}
n++
}
if (len(template.DNSNames) > 0 || len(template.EmailAddresses) > 0 || len(template.IPAddresses) > 0 || len(template.URIs) > 0) &&
!oidInExtensions(oidExtensionSubjectAltName, template.ExtraExtensions) {
ret[n].Id = oidExtensionSubjectAltName
// From RFC 5280, Section 4.2.1.6:
// “If the subject field contains an empty sequence ... then
// subjectAltName extension ... is marked as critical”
ret[n].Critical = subjectIsEmpty
ret[n].Value, err = marshalSANs(template.DNSNames, template.EmailAddresses, template.IPAddresses, template.URIs)
if err != nil {
return
}
n++
}
if len(template.PolicyIdentifiers) > 0 &&
!oidInExtensions(oidExtensionCertificatePolicies, template.ExtraExtensions) {
ret[n], err = marshalCertificatePolicies(template.PolicyIdentifiers)
if err != nil {
return nil, err
}
n++
}
if (len(template.PermittedDNSDomains) > 0 || len(template.ExcludedDNSDomains) > 0 ||
len(template.PermittedIPRanges) > 0 || len(template.ExcludedIPRanges) > 0 ||
len(template.PermittedEmailAddresses) > 0 || len(template.ExcludedEmailAddresses) > 0 ||
len(template.PermittedURIDomains) > 0 || len(template.ExcludedURIDomains) > 0) &&
!oidInExtensions(oidExtensionNameConstraints, template.ExtraExtensions) {
ret[n].Id = oidExtensionNameConstraints
ret[n].Critical = template.PermittedDNSDomainsCritical
ipAndMask := func(ipNet *net.IPNet) []byte {
maskedIP := ipNet.IP.Mask(ipNet.Mask)
ipAndMask := make([]byte, 0, len(maskedIP)+len(ipNet.Mask))
ipAndMask = append(ipAndMask, maskedIP...)
ipAndMask = append(ipAndMask, ipNet.Mask...)
return ipAndMask
}
serialiseConstraints := func(dns []string, ips []*net.IPNet, emails []string, uriDomains []string) (der []byte, err error) {
var b cryptobyte.Builder
for _, name := range dns {
if err = isIA5String(name); err != nil {
return nil, err
}
b.AddASN1(cryptobyte_asn1.SEQUENCE, func(b *cryptobyte.Builder) {
b.AddASN1(cryptobyte_asn1.Tag(2).ContextSpecific(), func(b *cryptobyte.Builder) {
b.AddBytes([]byte(name))
})
})
}
for _, ipNet := range ips {
b.AddASN1(cryptobyte_asn1.SEQUENCE, func(b *cryptobyte.Builder) {
b.AddASN1(cryptobyte_asn1.Tag(7).ContextSpecific(), func(b *cryptobyte.Builder) {
b.AddBytes(ipAndMask(ipNet))
})
})
}
for _, email := range emails {
if err = isIA5String(email); err != nil {
return nil, err
}
b.AddASN1(cryptobyte_asn1.SEQUENCE, func(b *cryptobyte.Builder) {
b.AddASN1(cryptobyte_asn1.Tag(1).ContextSpecific(), func(b *cryptobyte.Builder) {
b.AddBytes([]byte(email))
})
})
}
for _, uriDomain := range uriDomains {
if err = isIA5String(uriDomain); err != nil {
return nil, err
}
b.AddASN1(cryptobyte_asn1.SEQUENCE, func(b *cryptobyte.Builder) {
b.AddASN1(cryptobyte_asn1.Tag(6).ContextSpecific(), func(b *cryptobyte.Builder) {
b.AddBytes([]byte(uriDomain))
})
})
}
return b.Bytes()
}
permitted, err := serialiseConstraints(template.PermittedDNSDomains, template.PermittedIPRanges, template.PermittedEmailAddresses, template.PermittedURIDomains)
if err != nil {
return nil, err
}
excluded, err := serialiseConstraints(template.ExcludedDNSDomains, template.ExcludedIPRanges, template.ExcludedEmailAddresses, template.ExcludedURIDomains)
if err != nil {
return nil, err
}
var b cryptobyte.Builder
b.AddASN1(cryptobyte_asn1.SEQUENCE, func(b *cryptobyte.Builder) {
if len(permitted) > 0 {
b.AddASN1(cryptobyte_asn1.Tag(0).ContextSpecific().Constructed(), func(b *cryptobyte.Builder) {
b.AddBytes(permitted)
})
}
if len(excluded) > 0 {
b.AddASN1(cryptobyte_asn1.Tag(1).ContextSpecific().Constructed(), func(b *cryptobyte.Builder) {
b.AddBytes(excluded)
})
}
})
ret[n].Value, err = b.Bytes()
if err != nil {
return nil, err
}
n++
}
if len(template.CRLDistributionPoints) > 0 &&
!oidInExtensions(oidExtensionCRLDistributionPoints, template.ExtraExtensions) {
ret[n].Id = oidExtensionCRLDistributionPoints
var crlDp []distributionPoint
for _, name := range template.CRLDistributionPoints {
dp := distributionPoint{
DistributionPoint: distributionPointName{
FullName: []asn1.RawValue{
{Tag: 6, Class: 2, Bytes: []byte(name)},
},
},
}
crlDp = append(crlDp, dp)
}
ret[n].Value, err = asn1.Marshal(crlDp)
if err != nil {
return
}
n++
}
// Adding another extension here? Remember to update the maximum number
// of elements in the make() at the top of the function and the list of
// template fields used in CreateCertificate documentation.
return append(ret[:n], template.ExtraExtensions...), nil
}
func marshalKeyUsage(ku KeyUsage) (pkix.Extension, error) {
ext := pkix.Extension{Id: oidExtensionKeyUsage, Critical: true}
var a [2]byte
a[0] = reverseBitsInAByte(byte(ku))
a[1] = reverseBitsInAByte(byte(ku >> 8))
l := 1
if a[1] != 0 {
l = 2
}
bitString := a[:l]
var err error
ext.Value, err = asn1.Marshal(asn1.BitString{Bytes: bitString, BitLength: asn1BitLength(bitString)})
return ext, err
}
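// Illustrative sketch, not part of the original source: for
// KeyUsageDigitalSignature (bit 0 of the Go bitmap) the extension value is
// the DER BIT STRING 03 02 07 80, i.e. a single significant bit at the
// most-significant position, matching the reversal above. The helper name is
// hypothetical.
func exampleMarshalKeyUsage() bool {
	ext, err := marshalKeyUsage(KeyUsageDigitalSignature)
	return err == nil && bytes.Equal(ext.Value, []byte{0x03, 0x02, 0x07, 0x80})
}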
func marshalExtKeyUsage(extUsages []ExtKeyUsage, unknownUsages []asn1.ObjectIdentifier) (pkix.Extension, error) {
ext := pkix.Extension{Id: oidExtensionExtendedKeyUsage}
oids := make([]asn1.ObjectIdentifier, len(extUsages)+len(unknownUsages))
for i, u := range extUsages {
if oid, ok := oidFromExtKeyUsage(u); ok {
oids[i] = oid
} else {
return ext, errors.New("x509: unknown extended key usage")
}
}
copy(oids[len(extUsages):], unknownUsages)
var err error
ext.Value, err = asn1.Marshal(oids)
return ext, err
}
func marshalBasicConstraints(isCA bool, maxPathLen int, maxPathLenZero bool) (pkix.Extension, error) {
ext := pkix.Extension{Id: oidExtensionBasicConstraints, Critical: true}
// Leaving MaxPathLen as zero indicates that no maximum path
// length is desired, unless MaxPathLenZero is set. A value of
// -1 causes encoding/asn1 to omit the value as desired.
if maxPathLen == 0 && !maxPathLenZero {
maxPathLen = -1
}
var err error
ext.Value, err = asn1.Marshal(basicConstraints{isCA, maxPathLen})
return ext, err
}
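// Illustrative sketch, not part of the original source: a CA with no path
// length limit encodes as SEQUENCE { BOOLEAN TRUE }; encoding/asn1 omits the
// -1 MaxPathLen because of the default:-1 struct tag. The helper name is
// hypothetical.
func exampleBasicConstraints() bool {
	ext, err := marshalBasicConstraints(true, -1, false)
	return err == nil && bytes.Equal(ext.Value, []byte{0x30, 0x03, 0x01, 0x01, 0xff})
}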
func marshalCertificatePolicies(policyIdentifiers []asn1.ObjectIdentifier) (pkix.Extension, error) {
ext := pkix.Extension{Id: oidExtensionCertificatePolicies}
policies := make([]policyInformation, len(policyIdentifiers))
for i, policy := range policyIdentifiers {
policies[i].Policy = policy
}
var err error
ext.Value, err = asn1.Marshal(policies)
return ext, err
}
func buildCSRExtensions(template *CertificateRequest) ([]pkix.Extension, error) {
var ret []pkix.Extension
if (len(template.DNSNames) > 0 || len(template.EmailAddresses) > 0 || len(template.IPAddresses) > 0 || len(template.URIs) > 0) &&
!oidInExtensions(oidExtensionSubjectAltName, template.ExtraExtensions) {
sanBytes, err := marshalSANs(template.DNSNames, template.EmailAddresses, template.IPAddresses, template.URIs)
if err != nil {
return nil, err
}
ret = append(ret, pkix.Extension{
Id: oidExtensionSubjectAltName,
Value: sanBytes,
})
}
return append(ret, template.ExtraExtensions...), nil
}
func subjectBytes(cert *Certificate) ([]byte, error) {
if len(cert.RawSubject) > 0 {
return cert.RawSubject, nil
}
return asn1.Marshal(cert.Subject.ToRDNSequence())
}
// signingParamsForPublicKey returns the parameters to use for signing with
// priv. If requestedSigAlgo is not zero then it overrides the default
// signature algorithm.
func signingParamsForPublicKey(pub any, requestedSigAlgo SignatureAlgorithm) (hashFunc crypto.Hash, sigAlgo pkix.AlgorithmIdentifier, err error) {
var pubType PublicKeyAlgorithm
switch pub := pub.(type) {
case *rsa.PublicKey:
pubType = RSA
hashFunc = crypto.SHA256
sigAlgo.Algorithm = oidSignatureSHA256WithRSA
sigAlgo.Parameters = asn1.NullRawValue
case *ecdsa.PublicKey:
pubType = ECDSA
switch pub.Curve {
case elliptic.P224(), elliptic.P256():
hashFunc = crypto.SHA256
sigAlgo.Algorithm = oidSignatureECDSAWithSHA256
case elliptic.P384():
hashFunc = crypto.SHA384
sigAlgo.Algorithm = oidSignatureECDSAWithSHA384
case elliptic.P521():
hashFunc = crypto.SHA512
sigAlgo.Algorithm = oidSignatureECDSAWithSHA512
default:
err = errors.New("x509: unknown elliptic curve")
}
case ed25519.PublicKey:
pubType = Ed25519
sigAlgo.Algorithm = oidSignatureEd25519
default:
err = errors.New("x509: only RSA, ECDSA and Ed25519 keys supported")
}
if err != nil {
return
}
if requestedSigAlgo == 0 {
return
}
found := false
for _, details := range signatureAlgorithmDetails {
if details.algo == requestedSigAlgo {
if details.pubKeyAlgo != pubType {
err = errors.New("x509: requested SignatureAlgorithm does not match private key type")
return
}
sigAlgo.Algorithm, hashFunc = details.oid, details.hash
if hashFunc == 0 && pubType != Ed25519 {
err = errors.New("x509: cannot sign with hash function requested")
return
}
if hashFunc == crypto.MD5 {
err = errors.New("x509: signing with MD5 is not supported")
return
}
if requestedSigAlgo.isRSAPSS() {
sigAlgo.Parameters = hashToPSSParameters[hashFunc]
}
found = true
break
}
}
if !found {
err = errors.New("x509: unknown SignatureAlgorithm")
}
return
}
// emptyASN1Subject is the ASN.1 DER encoding of an empty Subject, which is
// just an empty SEQUENCE.
var emptyASN1Subject = []byte{0x30, 0}
// CreateCertificate creates a new X.509 v3 certificate based on a template.
// The following members of template are currently used:
//
// - AuthorityKeyId
// - BasicConstraintsValid
// - CRLDistributionPoints
// - DNSNames
// - EmailAddresses
// - ExcludedDNSDomains
// - ExcludedEmailAddresses
// - ExcludedIPRanges
// - ExcludedURIDomains
// - ExtKeyUsage
// - ExtraExtensions
// - IPAddresses
// - IsCA
// - IssuingCertificateURL
// - KeyUsage
// - MaxPathLen
// - MaxPathLenZero
// - NotAfter
// - NotBefore
// - OCSPServer
// - PermittedDNSDomains
// - PermittedDNSDomainsCritical
// - PermittedEmailAddresses
// - PermittedIPRanges
// - PermittedURIDomains
// - PolicyIdentifiers
// - SerialNumber
// - SignatureAlgorithm
// - Subject
// - SubjectKeyId
// - URIs
// - UnknownExtKeyUsage
//
// The certificate is signed by parent. If parent is equal to template then the
// certificate is self-signed. The parameter pub is the public key of the
// certificate to be generated and priv is the private key of the signer.
//
// The returned slice is the certificate in DER encoding.
//
// The currently supported key types are *rsa.PublicKey, *ecdsa.PublicKey and
// ed25519.PublicKey. pub must be a supported key type, and priv must be a
// crypto.Signer with a supported public key.
//
// The AuthorityKeyId will be taken from the SubjectKeyId of parent, if any,
// unless the resulting certificate is self-signed. Otherwise the value from
// template will be used.
//
// If SubjectKeyId from template is empty and the template is a CA, SubjectKeyId
// will be generated from the hash of the public key.
func CreateCertificate(rand io.Reader, template, parent *Certificate, pub, priv any) ([]byte, error) {
key, ok := priv.(crypto.Signer)
if !ok {
return nil, errors.New("x509: certificate private key does not implement crypto.Signer")
}
if template.SerialNumber == nil {
return nil, errors.New("x509: no SerialNumber given")
}
// RFC 5280 Section 4.1.2.2: serial number must be positive
//
// We _should_ also restrict serials to <= 20 octets, but it turns out a lot of people
// get this wrong, in part because the encoding can itself alter the length of the
// serial. For now we accept these non-conformant serials.
if template.SerialNumber.Sign() == -1 {
return nil, errors.New("x509: serial number must be positive")
}
if template.BasicConstraintsValid && !template.IsCA && template.MaxPathLen != -1 && (template.MaxPathLen != 0 || template.MaxPathLenZero) {
return nil, errors.New("x509: only CAs are allowed to specify MaxPathLen")
}
hashFunc, signatureAlgorithm, err := signingParamsForPublicKey(key.Public(), template.SignatureAlgorithm)
if err != nil {
return nil, err
}
publicKeyBytes, publicKeyAlgorithm, err := marshalPublicKey(pub)
if err != nil {
return nil, err
}
if getPublicKeyAlgorithmFromOID(publicKeyAlgorithm.Algorithm) == UnknownPublicKeyAlgorithm {
return nil, fmt.Errorf("x509: unsupported public key type: %T", pub)
}
asn1Issuer, err := subjectBytes(parent)
if err != nil {
return nil, err
}
asn1Subject, err := subjectBytes(template)
if err != nil {
return nil, err
}
authorityKeyId := template.AuthorityKeyId
if !bytes.Equal(asn1Issuer, asn1Subject) && len(parent.SubjectKeyId) > 0 {
authorityKeyId = parent.SubjectKeyId
}
subjectKeyId := template.SubjectKeyId
if len(subjectKeyId) == 0 && template.IsCA {
// SubjectKeyId generated using method 1 in RFC 5280, Section 4.2.1.2:
// (1) The keyIdentifier is composed of the 160-bit SHA-1 hash of the
// value of the BIT STRING subjectPublicKey (excluding the tag,
// length, and number of unused bits).
h := sha1.Sum(publicKeyBytes)
subjectKeyId = h[:]
}
// Check that the signer's public key matches the private key, if available.
type privateKey interface {
Equal(crypto.PublicKey) bool
}
if privPub, ok := key.Public().(privateKey); !ok {
return nil, errors.New("x509: internal error: supported public key does not implement Equal")
} else if parent.PublicKey != nil && !privPub.Equal(parent.PublicKey) {
return nil, errors.New("x509: provided PrivateKey doesn't match parent's PublicKey")
}
extensions, err := buildCertExtensions(template, bytes.Equal(asn1Subject, emptyASN1Subject), authorityKeyId, subjectKeyId)
if err != nil {
return nil, err
}
encodedPublicKey := asn1.BitString{BitLength: len(publicKeyBytes) * 8, Bytes: publicKeyBytes}
c := tbsCertificate{
Version: 2,
SerialNumber: template.SerialNumber,
SignatureAlgorithm: signatureAlgorithm,
Issuer: asn1.RawValue{FullBytes: asn1Issuer},
Validity: validity{template.NotBefore.UTC(), template.NotAfter.UTC()},
Subject: asn1.RawValue{FullBytes: asn1Subject},
PublicKey: publicKeyInfo{nil, publicKeyAlgorithm, encodedPublicKey},
Extensions: extensions,
}
tbsCertContents, err := asn1.Marshal(c)
if err != nil {
return nil, err
}
c.Raw = tbsCertContents
signed := tbsCertContents
if hashFunc != 0 {
h := hashFunc.New()
h.Write(signed)
signed = h.Sum(nil)
}
var signerOpts crypto.SignerOpts = hashFunc
if template.SignatureAlgorithm != 0 && template.SignatureAlgorithm.isRSAPSS() {
signerOpts = &rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthEqualsHash,
Hash: hashFunc,
}
}
var signature []byte
signature, err = key.Sign(rand, signed, signerOpts)
if err != nil {
return nil, err
}
signedCert, err := asn1.Marshal(certificate{
c,
signatureAlgorithm,
asn1.BitString{Bytes: signature, BitLength: len(signature) * 8},
})
if err != nil {
return nil, err
}
// Check the signature to ensure the crypto.Signer behaved correctly.
if err := checkSignature(getSignatureAlgorithmFromAI(signatureAlgorithm), c.Raw, signature, key.Public(), true); err != nil {
return nil, fmt.Errorf("x509: signature over certificate returned by signer is invalid: %w", err)
}
return signedCert, nil
}
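// Illustrative sketch, not part of the original source: a minimal self-signed
// certificate with an ECDSA P-256 key. Passing the same *Certificate as both
// template and parent makes the result self-signed. The helper name is
// hypothetical and assumes crypto/rand is imported as cryptorand.
func exampleSelfSigned() ([]byte, error) {
	priv, err := ecdsa.GenerateKey(elliptic.P256(), cryptorand.Reader)
	if err != nil {
		return nil, err
	}
	tmpl := &Certificate{
		SerialNumber: big.NewInt(1),
		Subject:      pkix.Name{CommonName: "example"},
		NotBefore:    time.Now(),
		NotAfter:     time.Now().Add(24 * time.Hour),
		KeyUsage:     KeyUsageDigitalSignature,
	}
	return CreateCertificate(cryptorand.Reader, tmpl, tmpl, &priv.PublicKey, priv)
}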
// pemCRLPrefix is the magic string that indicates that we have a PEM encoded
// CRL.
var pemCRLPrefix = []byte("-----BEGIN X509 CRL")
// pemType is the type of a PEM encoded CRL.
var pemType = "X509 CRL"
// ParseCRL parses a CRL from the given bytes. It's often the case that PEM
// encoded CRLs will appear where they should be DER encoded, so this function
// will transparently handle PEM encoding as long as there isn't any leading
// garbage.
//
// Deprecated: Use ParseRevocationList instead.
func ParseCRL(crlBytes []byte) (*pkix.CertificateList, error) {
if bytes.HasPrefix(crlBytes, pemCRLPrefix) {
block, _ := pem.Decode(crlBytes)
if block != nil && block.Type == pemType {
crlBytes = block.Bytes
}
}
return ParseDERCRL(crlBytes)
}
// ParseDERCRL parses a DER encoded CRL from the given bytes.
//
// Deprecated: Use ParseRevocationList instead.
func ParseDERCRL(derBytes []byte) (*pkix.CertificateList, error) {
certList := new(pkix.CertificateList)
if rest, err := asn1.Unmarshal(derBytes, certList); err != nil {
return nil, err
} else if len(rest) != 0 {
return nil, errors.New("x509: trailing data after CRL")
}
return certList, nil
}
// CreateCRL returns a DER encoded CRL, signed by this Certificate, that
// contains the given list of revoked certificates.
//
// Deprecated: this method does not generate an RFC 5280 conformant X.509 v2 CRL.
// To generate a standards compliant CRL, use CreateRevocationList instead.
func (c *Certificate) CreateCRL(rand io.Reader, priv any, revokedCerts []pkix.RevokedCertificate, now, expiry time.Time) (crlBytes []byte, err error) {
key, ok := priv.(crypto.Signer)
if !ok {
return nil, errors.New("x509: certificate private key does not implement crypto.Signer")
}
hashFunc, signatureAlgorithm, err := signingParamsForPublicKey(key.Public(), 0)
if err != nil {
return nil, err
}
// Force revocation times to UTC per RFC 5280.
revokedCertsUTC := make([]pkix.RevokedCertificate, len(revokedCerts))
for i, rc := range revokedCerts {
rc.RevocationTime = rc.RevocationTime.UTC()
revokedCertsUTC[i] = rc
}
tbsCertList := pkix.TBSCertificateList{
Version: 1,
Signature: signatureAlgorithm,
Issuer: c.Subject.ToRDNSequence(),
ThisUpdate: now.UTC(),
NextUpdate: expiry.UTC(),
RevokedCertificates: revokedCertsUTC,
}
// Authority Key Id
if len(c.SubjectKeyId) > 0 {
var aki pkix.Extension
aki.Id = oidExtensionAuthorityKeyId
aki.Value, err = asn1.Marshal(authKeyId{Id: c.SubjectKeyId})
if err != nil {
return
}
tbsCertList.Extensions = append(tbsCertList.Extensions, aki)
}
tbsCertListContents, err := asn1.Marshal(tbsCertList)
if err != nil {
return
}
signed := tbsCertListContents
if hashFunc != 0 {
h := hashFunc.New()
h.Write(signed)
signed = h.Sum(nil)
}
var signature []byte
signature, err = key.Sign(rand, signed, hashFunc)
if err != nil {
return
}
return asn1.Marshal(pkix.CertificateList{
TBSCertList: tbsCertList,
SignatureAlgorithm: signatureAlgorithm,
SignatureValue: asn1.BitString{Bytes: signature, BitLength: len(signature) * 8},
})
}
// CertificateRequest represents a PKCS #10 certificate signature request.
type CertificateRequest struct {
Raw []byte // Complete ASN.1 DER content (CSR, signature algorithm and signature).
RawTBSCertificateRequest []byte // Certificate request info part of raw ASN.1 DER content.
RawSubjectPublicKeyInfo []byte // DER encoded SubjectPublicKeyInfo.
RawSubject []byte // DER encoded Subject.
Version int
Signature []byte
SignatureAlgorithm SignatureAlgorithm
PublicKeyAlgorithm PublicKeyAlgorithm
PublicKey any
Subject pkix.Name
// Attributes contains the CSR attributes that can be parsed as
// pkix.AttributeTypeAndValueSET.
//
// Deprecated: Use Extensions and ExtraExtensions instead for parsing and
// generating the requestedExtensions attribute.
Attributes []pkix.AttributeTypeAndValueSET
// Extensions contains all requested extensions, in raw form. When parsing
// CSRs, this can be used to extract extensions that are not parsed by this
// package.
Extensions []pkix.Extension
// ExtraExtensions contains extensions to be copied, raw, into any CSR
// marshaled by CreateCertificateRequest. Values override any extensions
// that would otherwise be produced based on the other fields but are
// overridden by any extensions specified in Attributes.
//
// The ExtraExtensions field is not populated by ParseCertificateRequest,
// see Extensions instead.
ExtraExtensions []pkix.Extension
// Subject Alternative Name values.
DNSNames []string
EmailAddresses []string
IPAddresses []net.IP
URIs []*url.URL
}
// These structures reflect the ASN.1 structure of X.509 certificate
// signature requests (see RFC 2986):
type tbsCertificateRequest struct {
Raw asn1.RawContent
Version int
Subject asn1.RawValue
PublicKey publicKeyInfo
RawAttributes []asn1.RawValue `asn1:"tag:0"`
}
type certificateRequest struct {
Raw asn1.RawContent
TBSCSR tbsCertificateRequest
SignatureAlgorithm pkix.AlgorithmIdentifier
SignatureValue asn1.BitString
}
// oidExtensionRequest is a PKCS #9 OBJECT IDENTIFIER that indicates requested
// extensions in a CSR.
var oidExtensionRequest = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 9, 14}
// newRawAttributes converts AttributeTypeAndValueSETs from a template
// CertificateRequest's Attributes into tbsCertificateRequest RawAttributes.
func newRawAttributes(attributes []pkix.AttributeTypeAndValueSET) ([]asn1.RawValue, error) {
var rawAttributes []asn1.RawValue
b, err := asn1.Marshal(attributes)
if err != nil {
return nil, err
}
rest, err := asn1.Unmarshal(b, &rawAttributes)
if err != nil {
return nil, err
}
if len(rest) != 0 {
return nil, errors.New("x509: failed to unmarshal raw CSR Attributes")
}
return rawAttributes, nil
}
// parseRawAttributes unmarshals RawAttributes into AttributeTypeAndValueSETs.
func parseRawAttributes(rawAttributes []asn1.RawValue) []pkix.AttributeTypeAndValueSET {
var attributes []pkix.AttributeTypeAndValueSET
for _, rawAttr := range rawAttributes {
var attr pkix.AttributeTypeAndValueSET
rest, err := asn1.Unmarshal(rawAttr.FullBytes, &attr)
// Ignore attributes that don't parse into pkix.AttributeTypeAndValueSET
// (i.e.: challengePassword or unstructuredName).
if err == nil && len(rest) == 0 {
attributes = append(attributes, attr)
}
}
return attributes
}
// parseCSRExtensions parses the attributes from a CSR and extracts any
// requested extensions.
func parseCSRExtensions(rawAttributes []asn1.RawValue) ([]pkix.Extension, error) {
// pkcs10Attribute reflects the Attribute structure from RFC 2986, Section 4.1.
type pkcs10Attribute struct {
Id asn1.ObjectIdentifier
Values []asn1.RawValue `asn1:"set"`
}
var ret []pkix.Extension
requestedExts := make(map[string]bool)
for _, rawAttr := range rawAttributes {
var attr pkcs10Attribute
if rest, err := asn1.Unmarshal(rawAttr.FullBytes, &attr); err != nil || len(rest) != 0 || len(attr.Values) == 0 {
// Ignore attributes that don't parse.
continue
}
if !attr.Id.Equal(oidExtensionRequest) {
continue
}
var extensions []pkix.Extension
if _, err := asn1.Unmarshal(attr.Values[0].FullBytes, &extensions); err != nil {
return nil, err
}
for _, ext := range extensions {
oidStr := ext.Id.String()
if requestedExts[oidStr] {
return nil, errors.New("x509: certificate request contains duplicate requested extensions")
}
requestedExts[oidStr] = true
}
ret = append(ret, extensions...)
}
return ret, nil
}
// CreateCertificateRequest creates a new certificate request based on a
// template. The following members of template are used:
//
// - SignatureAlgorithm
// - Subject
// - DNSNames
// - EmailAddresses
// - IPAddresses
// - URIs
// - ExtraExtensions
// - Attributes (deprecated)
//
// priv is the private key to sign the CSR with, and the corresponding public
// key will be included in the CSR. It must implement crypto.Signer and its
// Public() method must return a *rsa.PublicKey, a *ecdsa.PublicKey, or an
// ed25519.PublicKey. (A *rsa.PrivateKey, *ecdsa.PrivateKey or
// ed25519.PrivateKey satisfies this.)
//
// The returned slice is the certificate request in DER encoding.
func CreateCertificateRequest(rand io.Reader, template *CertificateRequest, priv any) (csr []byte, err error) {
key, ok := priv.(crypto.Signer)
if !ok {
return nil, errors.New("x509: certificate private key does not implement crypto.Signer")
}
var hashFunc crypto.Hash
var sigAlgo pkix.AlgorithmIdentifier
hashFunc, sigAlgo, err = signingParamsForPublicKey(key.Public(), template.SignatureAlgorithm)
if err != nil {
return nil, err
}
var publicKeyBytes []byte
var publicKeyAlgorithm pkix.AlgorithmIdentifier
publicKeyBytes, publicKeyAlgorithm, err = marshalPublicKey(key.Public())
if err != nil {
return nil, err
}
extensions, err := buildCSRExtensions(template)
if err != nil {
return nil, err
}
// Make a copy of template.Attributes because we may alter it below.
attributes := make([]pkix.AttributeTypeAndValueSET, 0, len(template.Attributes))
for _, attr := range template.Attributes {
values := make([][]pkix.AttributeTypeAndValue, len(attr.Value))
copy(values, attr.Value)
attributes = append(attributes, pkix.AttributeTypeAndValueSET{
Type: attr.Type,
Value: values,
})
}
extensionsAppended := false
if len(extensions) > 0 {
// Append the extensions to an existing attribute if possible.
for _, atvSet := range attributes {
if !atvSet.Type.Equal(oidExtensionRequest) || len(atvSet.Value) == 0 {
continue
}
// specifiedExtensions contains all the extensions that we
// found specified via template.Attributes.
specifiedExtensions := make(map[string]bool)
for _, atvs := range atvSet.Value {
for _, atv := range atvs {
specifiedExtensions[atv.Type.String()] = true
}
}
newValue := make([]pkix.AttributeTypeAndValue, 0, len(atvSet.Value[0])+len(extensions))
newValue = append(newValue, atvSet.Value[0]...)
for _, e := range extensions {
if specifiedExtensions[e.Id.String()] {
// Attributes already contained a value for
// this extension and it takes priority.
continue
}
newValue = append(newValue, pkix.AttributeTypeAndValue{
// There is no place for the critical
// flag in an AttributeTypeAndValue.
Type: e.Id,
Value: e.Value,
})
}
atvSet.Value[0] = newValue
extensionsAppended = true
break
}
}
rawAttributes, err := newRawAttributes(attributes)
if err != nil {
return
}
// If not included in attributes, add a new attribute for the
// extensions.
if len(extensions) > 0 && !extensionsAppended {
attr := struct {
Type asn1.ObjectIdentifier
Value [][]pkix.Extension `asn1:"set"`
}{
Type: oidExtensionRequest,
Value: [][]pkix.Extension{extensions},
}
b, err := asn1.Marshal(attr)
if err != nil {
return nil, errors.New("x509: failed to serialise extensions attribute: " + err.Error())
}
var rawValue asn1.RawValue
if _, err := asn1.Unmarshal(b, &rawValue); err != nil {
return nil, err
}
rawAttributes = append(rawAttributes, rawValue)
}
asn1Subject := template.RawSubject
if len(asn1Subject) == 0 {
asn1Subject, err = asn1.Marshal(template.Subject.ToRDNSequence())
if err != nil {
return nil, err
}
}
tbsCSR := tbsCertificateRequest{
Version: 0, // PKCS #10, RFC 2986
Subject: asn1.RawValue{FullBytes: asn1Subject},
PublicKey: publicKeyInfo{
Algorithm: publicKeyAlgorithm,
PublicKey: asn1.BitString{
Bytes: publicKeyBytes,
BitLength: len(publicKeyBytes) * 8,
},
},
RawAttributes: rawAttributes,
}
tbsCSRContents, err := asn1.Marshal(tbsCSR)
if err != nil {
return
}
tbsCSR.Raw = tbsCSRContents
signed := tbsCSRContents
if hashFunc != 0 {
h := hashFunc.New()
h.Write(signed)
signed = h.Sum(nil)
}
var signature []byte
signature, err = key.Sign(rand, signed, hashFunc)
if err != nil {
return
}
return asn1.Marshal(certificateRequest{
TBSCSR: tbsCSR,
SignatureAlgorithm: sigAlgo,
SignatureValue: asn1.BitString{
Bytes: signature,
BitLength: len(signature) * 8,
},
})
}
// ParseCertificateRequest parses a single certificate request from the
// given ASN.1 DER data.
func ParseCertificateRequest(asn1Data []byte) (*CertificateRequest, error) {
var csr certificateRequest
rest, err := asn1.Unmarshal(asn1Data, &csr)
if err != nil {
return nil, err
} else if len(rest) != 0 {
return nil, asn1.SyntaxError{Msg: "trailing data"}
}
return parseCertificateRequest(&csr)
}
func parseCertificateRequest(in *certificateRequest) (*CertificateRequest, error) {
out := &CertificateRequest{
Raw: in.Raw,
RawTBSCertificateRequest: in.TBSCSR.Raw,
RawSubjectPublicKeyInfo: in.TBSCSR.PublicKey.Raw,
RawSubject: in.TBSCSR.Subject.FullBytes,
Signature: in.SignatureValue.RightAlign(),
SignatureAlgorithm: getSignatureAlgorithmFromAI(in.SignatureAlgorithm),
PublicKeyAlgorithm: getPublicKeyAlgorithmFromOID(in.TBSCSR.PublicKey.Algorithm.Algorithm),
Version: in.TBSCSR.Version,
Attributes: parseRawAttributes(in.TBSCSR.RawAttributes),
}
var err error
if out.PublicKeyAlgorithm != UnknownPublicKeyAlgorithm {
out.PublicKey, err = parsePublicKey(&in.TBSCSR.PublicKey)
if err != nil {
return nil, err
}
}
var subject pkix.RDNSequence
if rest, err := asn1.Unmarshal(in.TBSCSR.Subject.FullBytes, &subject); err != nil {
return nil, err
} else if len(rest) != 0 {
return nil, errors.New("x509: trailing data after X.509 Subject")
}
out.Subject.FillFromRDNSequence(&subject)
if out.Extensions, err = parseCSRExtensions(in.TBSCSR.RawAttributes); err != nil {
return nil, err
}
for _, extension := range out.Extensions {
switch {
case extension.Id.Equal(oidExtensionSubjectAltName):
out.DNSNames, out.EmailAddresses, out.IPAddresses, out.URIs, err = parseSANExtension(extension.Value)
if err != nil {
return nil, err
}
}
}
return out, nil
}
// CheckSignature reports whether the signature on c is valid.
func (c *CertificateRequest) CheckSignature() error {
return checkSignature(c.SignatureAlgorithm, c.RawTBSCertificateRequest, c.Signature, c.PublicKey, true)
}
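// Illustrative sketch, not part of the original source: creating a CSR with
// an Ed25519 key, re-parsing it, and verifying its self-signature with
// CheckSignature above. The helper name is hypothetical and assumes
// crypto/rand is imported as cryptorand.
func exampleCSRRoundTrip() error {
	_, priv, err := ed25519.GenerateKey(cryptorand.Reader)
	if err != nil {
		return err
	}
	tmpl := &CertificateRequest{
		Subject:  pkix.Name{CommonName: "example"},
		DNSNames: []string{"example.com"},
	}
	der, err := CreateCertificateRequest(cryptorand.Reader, tmpl, priv)
	if err != nil {
		return err
	}
	csr, err := ParseCertificateRequest(der)
	if err != nil {
		return err
	}
	return csr.CheckSignature()
}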
// RevocationList contains the fields used to create an X.509 v2 Certificate
// Revocation list with CreateRevocationList.
type RevocationList struct {
// Raw contains the complete ASN.1 DER content of the CRL (tbsCertList,
// signatureAlgorithm, and signatureValue.)
Raw []byte
// RawTBSRevocationList contains just the tbsCertList portion of the ASN.1
// DER.
RawTBSRevocationList []byte
// RawIssuer contains the DER encoded Issuer.
RawIssuer []byte
// Issuer contains the DN of the issuing certificate.
Issuer pkix.Name
// AuthorityKeyId is used to identify the public key associated with the
// issuing certificate. It is populated from the authorityKeyIdentifier
// extension when parsing a CRL. It is ignored when creating a CRL; the
// extension is populated from the issuing certificate itself.
AuthorityKeyId []byte
Signature []byte
// SignatureAlgorithm is used to determine the signature algorithm to be
// used when signing the CRL. If 0 the default algorithm for the signing
// key will be used.
SignatureAlgorithm SignatureAlgorithm
// RevokedCertificates is used to populate the revokedCertificates
// sequence in the CRL; it may be empty. RevokedCertificates may be nil,
// in which case an empty CRL will be created.
RevokedCertificates []pkix.RevokedCertificate
// Number is used to populate the X.509 v2 cRLNumber extension in the CRL,
// which should be a monotonically increasing sequence number for a given
// CRL scope and CRL issuer. It is also populated from the cRLNumber
// extension when parsing a CRL.
Number *big.Int
// ThisUpdate is used to populate the thisUpdate field in the CRL, which
// indicates the issuance date of the CRL.
ThisUpdate time.Time
// NextUpdate is used to populate the nextUpdate field in the CRL, which
// indicates the date by which the next CRL will be issued. NextUpdate
// must be greater than ThisUpdate.
NextUpdate time.Time
// Extensions contains raw X.509 extensions. When creating a CRL,
// the Extensions field is ignored, see ExtraExtensions.
Extensions []pkix.Extension
// ExtraExtensions contains any additional extensions to add directly to
// the CRL.
ExtraExtensions []pkix.Extension
}
// These structures reflect the ASN.1 structure of X.509 CRLs better than
// the existing crypto/x509/pkix variants do. These mirror the existing
// certificate structs in this file.
//
// Notably, we include issuer as an asn1.RawValue, mirroring the behavior of
// tbsCertificate and allowing raw (unparsed) subjects to be passed cleanly.
type certificateList struct {
TBSCertList tbsCertificateList
SignatureAlgorithm pkix.AlgorithmIdentifier
SignatureValue asn1.BitString
}
type tbsCertificateList struct {
Raw asn1.RawContent
Version int `asn1:"optional,default:0"`
Signature pkix.AlgorithmIdentifier
Issuer asn1.RawValue
ThisUpdate time.Time
NextUpdate time.Time `asn1:"optional"`
RevokedCertificates []pkix.RevokedCertificate `asn1:"optional"`
Extensions []pkix.Extension `asn1:"tag:0,optional,explicit"`
}
// CreateRevocationList creates a new X.509 v2 Certificate Revocation List,
// according to RFC 5280, based on template.
//
// The CRL is signed by priv which should be the private key associated with
// the public key in the issuer certificate.
//
// The issuer may not be nil, and the crlSign bit must be set in KeyUsage in
// order to use it as a CRL issuer.
//
// The issuer distinguished name CRL field and authority key identifier
// extension are populated using the issuer certificate. issuer must have
// SubjectKeyId set.
func CreateRevocationList(rand io.Reader, template *RevocationList, issuer *Certificate, priv crypto.Signer) ([]byte, error) {
if template == nil {
return nil, errors.New("x509: template can not be nil")
}
if issuer == nil {
return nil, errors.New("x509: issuer can not be nil")
}
if (issuer.KeyUsage & KeyUsageCRLSign) == 0 {
return nil, errors.New("x509: issuer must have the crlSign key usage bit set")
}
if len(issuer.SubjectKeyId) == 0 {
return nil, errors.New("x509: issuer certificate doesn't contain a subject key identifier")
}
if template.NextUpdate.Before(template.ThisUpdate) {
return nil, errors.New("x509: template.ThisUpdate is after template.NextUpdate")
}
if template.Number == nil {
return nil, errors.New("x509: template contains nil Number field")
}
hashFunc, signatureAlgorithm, err := signingParamsForPublicKey(priv.Public(), template.SignatureAlgorithm)
if err != nil {
return nil, err
}
// Force revocation times to UTC per RFC 5280.
revokedCertsUTC := make([]pkix.RevokedCertificate, len(template.RevokedCertificates))
for i, rc := range template.RevokedCertificates {
rc.RevocationTime = rc.RevocationTime.UTC()
revokedCertsUTC[i] = rc
}
aki, err := asn1.Marshal(authKeyId{Id: issuer.SubjectKeyId})
if err != nil {
return nil, err
}
if numBytes := template.Number.Bytes(); len(numBytes) > 20 || (len(numBytes) == 20 && numBytes[0]&0x80 != 0) {
return nil, errors.New("x509: CRL number exceeds 20 octets")
}
crlNum, err := asn1.Marshal(template.Number)
if err != nil {
return nil, err
}
// Correctly use the issuer's subject sequence if one is specified.
issuerSubject, err := subjectBytes(issuer)
if err != nil {
return nil, err
}
tbsCertList := tbsCertificateList{
Version: 1, // v2
Signature: signatureAlgorithm,
Issuer: asn1.RawValue{FullBytes: issuerSubject},
ThisUpdate: template.ThisUpdate.UTC(),
NextUpdate: template.NextUpdate.UTC(),
Extensions: []pkix.Extension{
{
Id: oidExtensionAuthorityKeyId,
Value: aki,
},
{
Id: oidExtensionCRLNumber,
Value: crlNum,
},
},
}
if len(revokedCertsUTC) > 0 {
tbsCertList.RevokedCertificates = revokedCertsUTC
}
if len(template.ExtraExtensions) > 0 {
tbsCertList.Extensions = append(tbsCertList.Extensions, template.ExtraExtensions...)
}
tbsCertListContents, err := asn1.Marshal(tbsCertList)
if err != nil {
return nil, err
}
// Optimization: marshal this struct only once, reusing the bytes both
// when signing and when embedding in certificateList below.
tbsCertList.Raw = tbsCertListContents
input := tbsCertListContents
if hashFunc != 0 {
h := hashFunc.New()
h.Write(tbsCertListContents)
input = h.Sum(nil)
}
var signerOpts crypto.SignerOpts = hashFunc
if template.SignatureAlgorithm.isRSAPSS() {
signerOpts = &rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthEqualsHash,
Hash: hashFunc,
}
}
signature, err := priv.Sign(rand, input, signerOpts)
if err != nil {
return nil, err
}
return asn1.Marshal(certificateList{
TBSCertList: tbsCertList,
SignatureAlgorithm: signatureAlgorithm,
SignatureValue: asn1.BitString{Bytes: signature, BitLength: len(signature) * 8},
})
}
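// A minimal usage sketch, assuming caCert is an issuing *Certificate with
// the crlSign key usage and SubjectKeyId set, and caKey is its
// crypto.Signer:
//
//	tmpl := &x509.RevocationList{
//		Number:     big.NewInt(42),
//		ThisUpdate: time.Now(),
//		NextUpdate: time.Now().Add(24 * time.Hour),
//		RevokedCertificates: []pkix.RevokedCertificate{
//			{SerialNumber: big.NewInt(7), RevocationTime: time.Now()},
//		},
//	}
//	crlDER, err := x509.CreateRevocationList(rand.Reader, tmpl, caCert, caKey)
//	if err != nil {
//		log.Fatal(err)
//	}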
// CheckSignatureFrom verifies that the signature on rl is a valid signature
// from issuer.
func (rl *RevocationList) CheckSignatureFrom(parent *Certificate) error {
if parent.Version == 3 && !parent.BasicConstraintsValid ||
parent.BasicConstraintsValid && !parent.IsCA {
return ConstraintViolationError{}
}
if parent.KeyUsage != 0 && parent.KeyUsage&KeyUsageCRLSign == 0 {
return ConstraintViolationError{}
}
if parent.PublicKeyAlgorithm == UnknownPublicKeyAlgorithm {
return ErrUnsupportedAlgorithm
}
return parent.CheckSignature(rl.SignatureAlgorithm, rl.RawTBSRevocationList, rl.Signature)
}
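// A minimal verification sketch, assuming crlDER holds a DER-encoded CRL
// and issuerCert is the issuing certificate:
//
//	rl, err := x509.ParseRevocationList(crlDER)
//	if err != nil {
//		log.Fatal(err)
//	}
//	if err := rl.CheckSignatureFrom(issuerCert); err != nil {
//		log.Fatal("CRL not signed by issuer: ", err)
//	}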
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Type conversions for Scan.
package sql
import (
"bytes"
"database/sql/driver"
"errors"
"fmt"
"reflect"
"strconv"
"time"
"unicode"
"unicode/utf8"
)
var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error
func describeNamedValue(nv *driver.NamedValue) string {
if len(nv.Name) == 0 {
return fmt.Sprintf("$%d", nv.Ordinal)
}
return fmt.Sprintf("with name %q", nv.Name)
}
func validateNamedValueName(name string) error {
if len(name) == 0 {
return nil
}
r, _ := utf8.DecodeRuneInString(name)
if unicode.IsLetter(r) {
return nil
}
return fmt.Errorf("name %q does not begin with a letter", name)
}
// ccChecker wraps the driver.ColumnConverter and allows it to be used
// as if it were a NamedValueChecker. If the driver ColumnConverter
// is not present then the NamedValueChecker will return driver.ErrSkip.
type ccChecker struct {
cci driver.ColumnConverter
want int
}
func (c ccChecker) CheckNamedValue(nv *driver.NamedValue) error {
if c.cci == nil {
return driver.ErrSkip
}
// The column converter shouldn't be called on any index
// it isn't expecting. The final error will be returned
// from the argument converter loop.
index := nv.Ordinal - 1
if c.want <= index {
return nil
}
// First, see if the value itself knows how to convert
// itself to a driver type. For example, a NullString
// struct changing into a string or nil.
if vr, ok := nv.Value.(driver.Valuer); ok {
sv, err := callValuerValue(vr)
if err != nil {
return err
}
if !driver.IsValue(sv) {
return fmt.Errorf("non-subset type %T returned from Value", sv)
}
nv.Value = sv
}
// Second, ask the column to sanity check itself. For
// example, drivers might use this to make sure that
// an int64 value being inserted into a 16-bit
// integer field is in range (before getting
// truncated), or that a nil can't go into a NOT NULL
// column before going across the network to get the
// same error.
var err error
arg := nv.Value
nv.Value, err = c.cci.ColumnConverter(index).ConvertValue(arg)
if err != nil {
return err
}
if !driver.IsValue(nv.Value) {
return fmt.Errorf("driver ColumnConverter error converted %T to unsupported type %T", arg, nv.Value)
}
return nil
}
// defaultCheckNamedValue wraps the default ColumnConverter to have the same
// function signature as the CheckNamedValue in the driver.NamedValueChecker
// interface.
func defaultCheckNamedValue(nv *driver.NamedValue) (err error) {
nv.Value, err = driver.DefaultParameterConverter.ConvertValue(nv.Value)
return err
}
// driverArgsConnLocked converts arguments from callers of Stmt.Exec and
// Stmt.Query into driver Values.
//
// The statement ds may be nil, if no statement is available.
//
// ci must be locked.
func driverArgsConnLocked(ci driver.Conn, ds *driverStmt, args []any) ([]driver.NamedValue, error) {
nvargs := make([]driver.NamedValue, len(args))
// -1 means the driver doesn't know how to count the number of
// placeholders, so we won't sanity check input here and instead let the
// driver deal with errors.
want := -1
var si driver.Stmt
var cc ccChecker
if ds != nil {
si = ds.si
want = ds.si.NumInput()
cc.want = want
}
// Check all types of interfaces from the start.
// Drivers may opt to use the NamedValueChecker for special
// argument types, then return driver.ErrSkip to pass it along
// to the column converter.
nvc, ok := si.(driver.NamedValueChecker)
if !ok {
nvc, ok = ci.(driver.NamedValueChecker)
}
cci, ok := si.(driver.ColumnConverter)
if ok {
cc.cci = cci
}
// Loop through all the arguments, checking each one.
// If no error is returned simply increment the index
// and continue. However if driver.ErrRemoveArgument
// is returned the argument is not included in the query
// argument list.
var err error
var n int
for _, arg := range args {
nv := &nvargs[n]
if np, ok := arg.(NamedArg); ok {
if err = validateNamedValueName(np.Name); err != nil {
return nil, err
}
arg = np.Value
nv.Name = np.Name
}
nv.Ordinal = n + 1
nv.Value = arg
// Checking sequence has four routes:
// A: 1. Default
// B: 1. NamedValueChecker 2. Column Converter 3. Default
// C: 1. NamedValueChecker 2. Default
// D: 1. Column Converter 2. Default
//
// A Column Converter is only ever called first, or immediately after
// the NamedValueChecker. If it is first, it is handled before the
// nextCheck label. Thus the Column Converter is used in a retry only
// when the NamedValueChecker was selected first and returned
// driver.ErrSkip.
checker := defaultCheckNamedValue
nextCC := false
switch {
case nvc != nil:
nextCC = cci != nil
checker = nvc.CheckNamedValue
case cci != nil:
checker = cc.CheckNamedValue
}
nextCheck:
err = checker(nv)
switch err {
case nil:
n++
continue
case driver.ErrRemoveArgument:
nvargs = nvargs[:len(nvargs)-1]
continue
case driver.ErrSkip:
if nextCC {
nextCC = false
checker = cc.CheckNamedValue
} else {
checker = defaultCheckNamedValue
}
goto nextCheck
default:
return nil, fmt.Errorf("sql: converting argument %s type: %v", describeNamedValue(nv), err)
}
}
// Check the length of arguments after conversion to allow for omitted
// arguments.
if want != -1 && len(nvargs) != want {
return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(nvargs))
}
return nvargs, nil
}
// convertAssign is the same as convertAssignRows, but without the optional
// rows argument.
func convertAssign(dest, src any) error {
return convertAssignRows(dest, src, nil)
}
// convertAssignRows copies to dest the value in src, converting it if possible.
// An error is returned if the copy would result in loss of information.
// dest should be a pointer type. If rows is passed in, the rows will
// be used as the parent for any cursor values converted from a
// driver.Rows to a *Rows.
func convertAssignRows(dest, src any, rows *Rows) error {
// Common cases, without reflect.
switch s := src.(type) {
case string:
switch d := dest.(type) {
case *string:
if d == nil {
return errNilPtr
}
*d = s
return nil
case *[]byte:
if d == nil {
return errNilPtr
}
*d = []byte(s)
return nil
case *RawBytes:
if d == nil {
return errNilPtr
}
*d = append((*d)[:0], s...)
return nil
}
case []byte:
switch d := dest.(type) {
case *string:
if d == nil {
return errNilPtr
}
*d = string(s)
return nil
case *any:
if d == nil {
return errNilPtr
}
*d = bytes.Clone(s)
return nil
case *[]byte:
if d == nil {
return errNilPtr
}
*d = bytes.Clone(s)
return nil
case *RawBytes:
if d == nil {
return errNilPtr
}
*d = s
return nil
}
case time.Time:
switch d := dest.(type) {
case *time.Time:
*d = s
return nil
case *string:
*d = s.Format(time.RFC3339Nano)
return nil
case *[]byte:
if d == nil {
return errNilPtr
}
*d = []byte(s.Format(time.RFC3339Nano))
return nil
case *RawBytes:
if d == nil {
return errNilPtr
}
*d = s.AppendFormat((*d)[:0], time.RFC3339Nano)
return nil
}
case decimalDecompose:
switch d := dest.(type) {
case decimalCompose:
return d.Compose(s.Decompose(nil))
}
case nil:
switch d := dest.(type) {
case *any:
if d == nil {
return errNilPtr
}
*d = nil
return nil
case *[]byte:
if d == nil {
return errNilPtr
}
*d = nil
return nil
case *RawBytes:
if d == nil {
return errNilPtr
}
*d = nil
return nil
}
// The driver is returning a cursor the client may iterate over.
case driver.Rows:
switch d := dest.(type) {
case *Rows:
if d == nil {
return errNilPtr
}
if rows == nil {
return errors.New("invalid context to convert cursor rows, missing parent *Rows")
}
rows.closemu.Lock()
*d = Rows{
dc: rows.dc,
releaseConn: func(error) {},
rowsi: s,
}
// Chain the cancel function.
parentCancel := rows.cancel
rows.cancel = func() {
// When Rows.cancel is called, the closemu will be locked as well.
// So we can safely access rows.lasterr.
d.close(rows.lasterr)
if parentCancel != nil {
parentCancel()
}
}
rows.closemu.Unlock()
return nil
}
}
var sv reflect.Value
switch d := dest.(type) {
case *string:
sv = reflect.ValueOf(src)
switch sv.Kind() {
case reflect.Bool,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64:
*d = asString(src)
return nil
}
case *[]byte:
sv = reflect.ValueOf(src)
if b, ok := asBytes(nil, sv); ok {
*d = b
return nil
}
case *RawBytes:
sv = reflect.ValueOf(src)
if b, ok := asBytes([]byte(*d)[:0], sv); ok {
*d = RawBytes(b)
return nil
}
case *bool:
bv, err := driver.Bool.ConvertValue(src)
if err == nil {
*d = bv.(bool)
}
return err
case *any:
*d = src
return nil
}
if scanner, ok := dest.(Scanner); ok {
return scanner.Scan(src)
}
dpv := reflect.ValueOf(dest)
if dpv.Kind() != reflect.Pointer {
return errors.New("destination not a pointer")
}
if dpv.IsNil() {
return errNilPtr
}
if !sv.IsValid() {
sv = reflect.ValueOf(src)
}
dv := reflect.Indirect(dpv)
if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) {
switch b := src.(type) {
case []byte:
dv.Set(reflect.ValueOf(bytes.Clone(b)))
default:
dv.Set(sv)
}
return nil
}
if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) {
dv.Set(sv.Convert(dv.Type()))
return nil
}
// The following conversions use a string value as an intermediate representation
// to convert between various numeric types.
//
// This also allows scanning into user defined types such as "type Int int64".
// For symmetry, also check for string destination types.
switch dv.Kind() {
case reflect.Pointer:
if src == nil {
dv.Set(reflect.Zero(dv.Type()))
return nil
}
dv.Set(reflect.New(dv.Type().Elem()))
return convertAssignRows(dv.Interface(), src, rows)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if src == nil {
return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
}
s := asString(src)
i64, err := strconv.ParseInt(s, 10, dv.Type().Bits())
if err != nil {
err = strconvErr(err)
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
}
dv.SetInt(i64)
return nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if src == nil {
return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
}
s := asString(src)
u64, err := strconv.ParseUint(s, 10, dv.Type().Bits())
if err != nil {
err = strconvErr(err)
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
}
dv.SetUint(u64)
return nil
case reflect.Float32, reflect.Float64:
if src == nil {
return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
}
s := asString(src)
f64, err := strconv.ParseFloat(s, dv.Type().Bits())
if err != nil {
err = strconvErr(err)
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
}
dv.SetFloat(f64)
return nil
case reflect.String:
if src == nil {
return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
}
switch v := src.(type) {
case string:
dv.SetString(v)
return nil
case []byte:
dv.SetString(string(v))
return nil
}
}
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest)
}
func strconvErr(err error) error {
if ne, ok := err.(*strconv.NumError); ok {
return ne.Err
}
return err
}
func asString(src any) string {
switch v := src.(type) {
case string:
return v
case []byte:
return string(v)
}
rv := reflect.ValueOf(src)
switch rv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return strconv.FormatInt(rv.Int(), 10)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return strconv.FormatUint(rv.Uint(), 10)
case reflect.Float64:
return strconv.FormatFloat(rv.Float(), 'g', -1, 64)
case reflect.Float32:
return strconv.FormatFloat(rv.Float(), 'g', -1, 32)
case reflect.Bool:
return strconv.FormatBool(rv.Bool())
}
return fmt.Sprintf("%v", src)
}
func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
switch rv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return strconv.AppendInt(buf, rv.Int(), 10), true
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return strconv.AppendUint(buf, rv.Uint(), 10), true
case reflect.Float32:
return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 32), true
case reflect.Float64:
return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 64), true
case reflect.Bool:
return strconv.AppendBool(buf, rv.Bool()), true
case reflect.String:
s := rv.String()
return append(buf, s...), true
}
return
}
var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
// callValuerValue returns vr.Value(), with one exception:
// If vr.Value is an auto-generated method on a pointer type and the
// pointer is nil, it would panic at runtime in the panicwrap
// method. Treat it like nil instead.
// Issue 8415.
//
// This is so people can implement driver.Value on value types and
// still use nil pointers to those types to mean nil/NULL, just like
// string/*string.
//
// This function is mirrored in the database/sql/driver package.
func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Pointer &&
rv.IsNil() &&
rv.Type().Elem().Implements(valuerReflectType) {
return nil, nil
}
return vr.Value()
}
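// An illustrative sketch of the nil-pointer case this guards against,
// using a hypothetical value-type Valuer:
//
//	type Point struct{ X, Y int }
//
//	func (p Point) Value() (driver.Value, error) {
//		return fmt.Sprintf("(%d,%d)", p.X, p.Y), nil
//	}
//
// A nil *Point still satisfies driver.Valuer via the auto-generated
// pointer method; callValuerValue returns nil/NULL for it instead of
// panicking:
//
//	var p *Point
//	v, _ := callValuerValue(p) // v == nil, no panic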
// decimal composes or decomposes a decimal value to and from individual parts.
// There are four parts: a boolean negative flag, a form byte with three possible states
// (finite=0, infinite=1, NaN=2), a base-2 big-endian integer
// coefficient (also known as a significand) as a []byte, and an int32 exponent.
// These are composed into a final value as "decimal = (neg) (form=finite) coefficient * 10 ^ exponent".
// A zero length coefficient is a zero value.
// The big-endian integer coefficient stores the most significant byte first (at coefficient[0]).
// If the form is not finite the coefficient and exponent should be ignored.
// The negative parameter may be set to true for any form, although implementations are not required
// to respect the negative parameter in the non-finite form.
//
// Implementations may choose to set the negative parameter to true on a zero or NaN value,
// but implementations that do not differentiate between negative and positive
// zero or NaN values should ignore the negative parameter without error.
// If an implementation does not support Infinity it may be converted into a NaN without error.
// If a value is set that is larger than what is supported by an implementation,
// an error must be returned.
// Implementations must return an error if a NaN or Infinity is attempted to be set while neither
// are supported.
//
// NOTE(kardianos): This is an experimental interface. See https://golang.org/issue/30870
type decimal interface {
decimalDecompose
decimalCompose
}
type decimalDecompose interface {
// Decompose returns the internal decimal state in parts.
// If the provided buf has sufficient capacity, buf may be returned as the coefficient with
// the value set and length set as appropriate.
Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32)
}
type decimalCompose interface {
// Compose sets the internal decimal value from parts. If the value cannot be
// represented then an error should be returned.
Compose(form byte, negative bool, coefficient []byte, exponent int32) error
}
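// An illustrative sketch of the decompose contract using a hypothetical
// fixed-point type that stores a value in hundredths (so the exponent is
// -2). big.Int (math/big, assumed imported) supplies the base-2 big-endian
// coefficient bytes:
//
//	type Cents int64 // e.g. 1234 represents 12.34
//
//	func (c Cents) Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32) {
//		n := int64(c)
//		if n < 0 {
//			negative = true
//			n = -n
//		}
//		// form 0 = finite; value = (neg) coefficient * 10^-2.
//		return 0, negative, big.NewInt(n).Bytes(), -2
//	}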
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sql
import (
"context"
"database/sql/driver"
"errors"
)
func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver.Stmt, error) {
if ciCtx, is := ci.(driver.ConnPrepareContext); is {
return ciCtx.PrepareContext(ctx, query)
}
si, err := ci.Prepare(query)
if err == nil {
select {
default:
case <-ctx.Done():
si.Close()
return nil, ctx.Err()
}
}
return si, err
}
func ctxDriverExec(ctx context.Context, execerCtx driver.ExecerContext, execer driver.Execer, query string, nvdargs []driver.NamedValue) (driver.Result, error) {
if execerCtx != nil {
return execerCtx.ExecContext(ctx, query, nvdargs)
}
dargs, err := namedValueToValue(nvdargs)
if err != nil {
return nil, err
}
select {
default:
case <-ctx.Done():
return nil, ctx.Err()
}
return execer.Exec(query, dargs)
}
func ctxDriverQuery(ctx context.Context, queryerCtx driver.QueryerContext, queryer driver.Queryer, query string, nvdargs []driver.NamedValue) (driver.Rows, error) {
if queryerCtx != nil {
return queryerCtx.QueryContext(ctx, query, nvdargs)
}
dargs, err := namedValueToValue(nvdargs)
if err != nil {
return nil, err
}
select {
default:
case <-ctx.Done():
return nil, ctx.Err()
}
return queryer.Query(query, dargs)
}
func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Result, error) {
if siCtx, is := si.(driver.StmtExecContext); is {
return siCtx.ExecContext(ctx, nvdargs)
}
dargs, err := namedValueToValue(nvdargs)
if err != nil {
return nil, err
}
select {
default:
case <-ctx.Done():
return nil, ctx.Err()
}
return si.Exec(dargs)
}
func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Rows, error) {
if siCtx, is := si.(driver.StmtQueryContext); is {
return siCtx.QueryContext(ctx, nvdargs)
}
dargs, err := namedValueToValue(nvdargs)
if err != nil {
return nil, err
}
select {
default:
case <-ctx.Done():
return nil, ctx.Err()
}
return si.Query(dargs)
}
func ctxDriverBegin(ctx context.Context, opts *TxOptions, ci driver.Conn) (driver.Tx, error) {
if ciCtx, is := ci.(driver.ConnBeginTx); is {
dopts := driver.TxOptions{}
if opts != nil {
dopts.Isolation = driver.IsolationLevel(opts.Isolation)
dopts.ReadOnly = opts.ReadOnly
}
return ciCtx.BeginTx(ctx, dopts)
}
if opts != nil {
// Check the transaction level. If the transaction level is non-default
// then return an error here as the BeginTx driver value is not supported.
if opts.Isolation != LevelDefault {
return nil, errors.New("sql: driver does not support non-default isolation level")
}
// If a read-only transaction is requested return an error as the
// BeginTx driver value is not supported.
if opts.ReadOnly {
return nil, errors.New("sql: driver does not support read-only transactions")
}
}
if ctx.Done() == nil {
return ci.Begin()
}
txi, err := ci.Begin()
if err == nil {
select {
default:
case <-ctx.Done():
txi.Rollback()
return nil, ctx.Err()
}
}
return txi, err
}
func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
dargs := make([]driver.Value, len(named))
for n, param := range named {
if len(param.Name) > 0 {
return nil, errors.New("sql: driver does not support the use of Named Parameters")
}
dargs[n] = param.Value
}
return dargs, nil
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package driver defines interfaces to be implemented by database
// drivers as used by package sql.
//
// Most code should use package sql.
//
// The driver interface has evolved over time. Drivers should implement
// Connector and DriverContext interfaces.
// The Connector.Connect and Driver.Open methods should never return ErrBadConn.
// ErrBadConn should only be returned from Validator, SessionResetter, or
// a query method if the connection is already in an invalid (e.g. closed) state.
//
// All Conn implementations should implement the following interfaces:
// Pinger, SessionResetter, and Validator.
//
// If named parameters or context are supported, the driver's Conn should implement:
// ExecerContext, QueryerContext, ConnPrepareContext, and ConnBeginTx.
//
// To support custom data types, implement NamedValueChecker. NamedValueChecker
// also allows queries to accept per-query options as a parameter by returning
// ErrRemoveArgument from CheckNamedValue.
//
// If multiple result sets are supported, Rows should implement RowsNextResultSet.
// If the driver knows how to describe the types present in the returned result
// it should implement the following interfaces: RowsColumnTypeScanType,
// RowsColumnTypeDatabaseTypeName, RowsColumnTypeLength, RowsColumnTypeNullable,
// and RowsColumnTypePrecisionScale. A given row value may also return a Rows
// type, which may represent a database cursor value.
//
// Before a connection is returned to the connection pool after use, IsValid is
// called if implemented. Before a connection is reused for another query,
// ResetSession is called if implemented. If a connection is never returned to the
// connection pool but immediately reused, then ResetSession is called prior to
// reuse but IsValid is not called.
package driver
import (
"context"
"errors"
"reflect"
)
// Value is a value that drivers must be able to handle.
// It is either nil, a type handled by a database driver's NamedValueChecker
// interface, or an instance of one of these types:
//
// int64
// float64
// bool
// []byte
// string
// time.Time
//
// If the driver supports cursors, a returned Value may also implement the Rows interface
// in this package. This is used, for example, when a user selects a cursor
// such as "select cursor(select * from my_table) from dual". If the Rows
// from the select is closed, the cursor Rows will also be closed.
type Value any
// NamedValue holds both the value name and value.
type NamedValue struct {
// If the Name is not empty it should be used for the parameter identifier and
// not the ordinal position.
//
// Name will not have a symbol prefix.
Name string
// Ordinal is the position of the parameter, starting from one;
// it is always set.
Ordinal int
// Value is the parameter value.
Value Value
}
// Driver is the interface that must be implemented by a database
// driver.
//
// Database drivers may implement DriverContext for access
// to contexts and to parse the name only once for a pool of connections,
// instead of once per connection.
type Driver interface {
// Open returns a new connection to the database.
// The name is a string in a driver-specific format.
//
// Open may return a cached connection (one previously
// closed), but doing so is unnecessary; the sql package
// maintains a pool of idle connections for efficient re-use.
//
// The returned connection is only used by one goroutine at a
// time.
Open(name string) (Conn, error)
}
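// An illustrative skeleton of a minimal driver (all names hypothetical):
//
//	type fakeDriver struct{}
//
//	func (fakeDriver) Open(name string) (Conn, error) {
//		// Parse name (e.g. a DSN) and dial the database here.
//		return &fakeConn{dsn: name}, nil
//	}
//
// where fakeConn implements Conn (and ideally Pinger, SessionResetter,
// Validator, and the *Context interfaces described in the package comment).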
// If a Driver implements DriverContext, then sql.DB will call
// OpenConnector to obtain a Connector and then invoke
// that Connector's Connect method to obtain each needed connection,
// instead of invoking the Driver's Open method for each connection.
// The two-step sequence allows drivers to parse the name just once
// and also provides access to per-Conn contexts.
type DriverContext interface {
// OpenConnector must parse the name in the same format that Driver.Open
// parses the name parameter.
OpenConnector(name string) (Connector, error)
}
// A Connector represents a driver in a fixed configuration
// and can create any number of equivalent Conns for use
// by multiple goroutines.
//
// A Connector can be passed to sql.OpenDB, to allow drivers
// to implement their own sql.DB constructors, or returned by
// DriverContext's OpenConnector method, to allow drivers
// access to context and to avoid repeated parsing of driver
// configuration.
//
// If a Connector implements io.Closer, the sql package's DB.Close
// method will call Close and return error (if any).
type Connector interface {
// Connect returns a connection to the database.
// Connect may return a cached connection (one previously
// closed), but doing so is unnecessary; the sql package
// maintains a pool of idle connections for efficient re-use.
//
// The provided context.Context is for dialing purposes only
// (see net.DialContext) and should not be stored or used for
// other purposes. A default timeout should still be used
// when dialing as a connection pool may call Connect
// asynchronously to any query.
//
// The returned connection is only used by one goroutine at a
// time.
Connect(context.Context) (Conn, error)
// Driver returns the underlying Driver of the Connector,
// mainly to maintain compatibility with the Driver method
// on sql.DB.
Driver() Driver
}
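// Drivers that provide a Connector let callers skip DSN handling in
// sql.Open. A sketch of the caller side, assuming NewConnector is a
// driver-supplied constructor:
//
//	c, err := mydriver.NewConnector("host=db1 user=app")
//	if err != nil {
//		log.Fatal(err)
//	}
//	db := sql.OpenDB(c)
//	defer db.Close()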
// ErrSkip may be returned by some optional interfaces' methods to
// indicate at runtime that the fast path is unavailable and the sql
// package should continue as if the optional interface was not
// implemented. ErrSkip is only supported where explicitly
// documented.
var ErrSkip = errors.New("driver: skip fast-path; continue as if unimplemented")
// ErrBadConn should be returned by a driver to signal to the sql
// package that a driver.Conn is in a bad state (such as the server
// having earlier closed the connection) and the sql package should
// retry on a new connection.
//
// To prevent duplicate operations, ErrBadConn should NOT be returned
// if there's a possibility that the database server might have
// performed the operation. Even if the server sends back an error,
// you shouldn't return ErrBadConn.
//
// Errors will be checked using errors.Is. An error may
// wrap ErrBadConn or implement the Is(error) bool method.
var ErrBadConn = errors.New("driver: bad connection")
// Pinger is an optional interface that may be implemented by a Conn.
//
// If a Conn does not implement Pinger, the sql package's DB.Ping and
// DB.PingContext will check if there is at least one Conn available.
//
// If Conn.Ping returns ErrBadConn, DB.Ping and DB.PingContext will remove
// the Conn from pool.
type Pinger interface {
Ping(ctx context.Context) error
}
// Execer is an optional interface that may be implemented by a Conn.
//
// If a Conn implements neither ExecerContext nor Execer,
// the sql package's DB.Exec will first prepare a query, execute the statement,
// and then close the statement.
//
// Exec may return ErrSkip.
//
// Deprecated: Drivers should implement ExecerContext instead.
type Execer interface {
Exec(query string, args []Value) (Result, error)
}
// ExecerContext is an optional interface that may be implemented by a Conn.
//
// If a Conn does not implement ExecerContext, the sql package's DB.Exec
// will fall back to Execer; if the Conn does not implement Execer either,
// DB.Exec will first prepare a query, execute the statement, and then
// close the statement.
//
// ExecContext may return ErrSkip.
//
// ExecContext must honor the context timeout and return when the context is canceled.
type ExecerContext interface {
ExecContext(ctx context.Context, query string, args []NamedValue) (Result, error)
}
// Queryer is an optional interface that may be implemented by a Conn.
//
// If a Conn implements neither QueryerContext nor Queryer,
// the sql package's DB.Query will first prepare a query, execute the statement,
// and then close the statement.
//
// Query may return ErrSkip.
//
// Deprecated: Drivers should implement QueryerContext instead.
type Queryer interface {
Query(query string, args []Value) (Rows, error)
}
// QueryerContext is an optional interface that may be implemented by a Conn.
//
// If a Conn does not implement QueryerContext, the sql package's DB.Query
// will fall back to Queryer; if the Conn does not implement Queryer either,
// DB.Query will first prepare a query, execute the statement, and then
// close the statement.
//
// QueryContext may return ErrSkip.
//
// QueryContext must honor the context timeout and return when the context is canceled.
type QueryerContext interface {
QueryContext(ctx context.Context, query string, args []NamedValue) (Rows, error)
}
// Conn is a connection to a database. It is not used concurrently
// by multiple goroutines.
//
// Conn is assumed to be stateful.
type Conn interface {
// Prepare returns a prepared statement, bound to this connection.
Prepare(query string) (Stmt, error)
// Close invalidates and potentially stops any current
// prepared statements and transactions, marking this
// connection as no longer in use.
//
// Because the sql package maintains a free pool of
// connections and only calls Close when there's a surplus of
// idle connections, it shouldn't be necessary for drivers to
// do their own connection caching.
//
// Drivers must ensure all network calls made by Close
// do not block indefinitely (e.g. apply a timeout).
Close() error
// Begin starts and returns a new transaction.
//
// Deprecated: Drivers should implement ConnBeginTx instead (or additionally).
Begin() (Tx, error)
}
// ConnPrepareContext enhances the Conn interface with context.
type ConnPrepareContext interface {
// PrepareContext returns a prepared statement, bound to this connection.
// The provided context is for the preparation of the statement only;
// the implementation must not store the context within the statement itself.
PrepareContext(ctx context.Context, query string) (Stmt, error)
}
// IsolationLevel is the transaction isolation level stored in TxOptions.
//
// This type should be considered identical to sql.IsolationLevel along
// with any values defined on it.
type IsolationLevel int
// TxOptions holds the transaction options.
//
// This type should be considered identical to sql.TxOptions.
type TxOptions struct {
Isolation IsolationLevel
ReadOnly bool
}
// ConnBeginTx enhances the Conn interface with context and TxOptions.
type ConnBeginTx interface {
// BeginTx starts and returns a new transaction.
// If the context is canceled by the user the sql package will
// call Tx.Rollback before discarding and closing the connection.
//
// This must check opts.Isolation to determine if there is a set
// isolation level. If a non-default isolation level is set and the
// driver does not support it, an error must be returned.
//
// This must also check opts.ReadOnly: if it is true, either set the
// read-only transaction property if supported, or return an error if
// it is not.
BeginTx(ctx context.Context, opts TxOptions) (Tx, error)
}
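// A sketch of the required option checks, for a hypothetical connection
// type that supports read-only transactions but only the default
// isolation level:
//
//	func (c *fakeConn) BeginTx(ctx context.Context, opts TxOptions) (Tx, error) {
//		if opts.Isolation != 0 {
//			return nil, errors.New("fake: non-default isolation levels are unsupported")
//		}
//		if opts.ReadOnly {
//			// Set the session read-only here if the database supports it.
//		}
//		return c.begin(ctx)
//	}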
// SessionResetter may be implemented by Conn to allow drivers to reset the
// session state associated with the connection and to signal a bad connection.
type SessionResetter interface {
// ResetSession is called prior to executing a query on the connection
// if the connection has been used before. If the driver returns ErrBadConn
// the connection is discarded.
ResetSession(ctx context.Context) error
}
// Validator may be implemented by Conn to allow drivers to
// signal if a connection is valid or if it should be discarded.
//
// If implemented, drivers may return the underlying error from queries,
// even if the connection should be discarded by the connection pool.
type Validator interface {
// IsValid is called prior to placing the connection into the
// connection pool. The connection will be discarded if false is returned.
IsValid() bool
}
// Result is the result of a query execution.
type Result interface {
// LastInsertId returns the database's auto-generated ID
// after, for example, an INSERT into a table with primary
// key.
LastInsertId() (int64, error)
// RowsAffected returns the number of rows affected by the
// query.
RowsAffected() (int64, error)
}
// Stmt is a prepared statement. It is bound to a Conn and not
// used by multiple goroutines concurrently.
type Stmt interface {
// Close closes the statement.
//
// As of Go 1.1, a Stmt will not be closed if it's in use
// by any queries.
//
// Drivers must ensure all network calls made by Close
// do not block indefinitely (e.g. apply a timeout).
Close() error
// NumInput returns the number of placeholder parameters.
//
// If NumInput returns >= 0, the sql package will sanity check
// argument counts from callers and return errors to the caller
// before the statement's Exec or Query methods are called.
//
// NumInput may also return -1, if the driver doesn't know
// its number of placeholders. In that case, the sql package
// will not sanity check Exec or Query argument counts.
NumInput() int
// Exec executes a query that doesn't return rows, such
// as an INSERT or UPDATE.
//
// Deprecated: Drivers should implement StmtExecContext instead (or additionally).
Exec(args []Value) (Result, error)
// Query executes a query that may return rows, such as a
// SELECT.
//
// Deprecated: Drivers should implement StmtQueryContext instead (or additionally).
Query(args []Value) (Rows, error)
}
// StmtExecContext enhances the Stmt interface by providing Exec with context.
type StmtExecContext interface {
// ExecContext executes a query that doesn't return rows, such
// as an INSERT or UPDATE.
//
// ExecContext must honor the context timeout and return when it is canceled.
ExecContext(ctx context.Context, args []NamedValue) (Result, error)
}
// StmtQueryContext enhances the Stmt interface by providing Query with context.
type StmtQueryContext interface {
// QueryContext executes a query that may return rows, such as a
// SELECT.
//
// QueryContext must honor the context timeout and return when it is canceled.
QueryContext(ctx context.Context, args []NamedValue) (Rows, error)
}
// ErrRemoveArgument may be returned from NamedValueChecker to instruct the
// sql package to not pass the argument to the driver query interface.
// Return it when accepting query-specific options or structures that
// aren't SQL query arguments.
var ErrRemoveArgument = errors.New("driver: remove argument from query")
// NamedValueChecker may be optionally implemented by Conn or Stmt. It provides
// the driver more control to handle Go and database types beyond the default
// Values types allowed.
//
// The sql package checks for value checkers in the following order,
// stopping at the first found match: Stmt.NamedValueChecker, Conn.NamedValueChecker,
// Stmt.ColumnConverter, DefaultParameterConverter.
//
// If CheckNamedValue returns ErrRemoveArgument, the NamedValue will not be included in
// the final query arguments. This may be used to pass special options to
// the query itself.
//
// If ErrSkip is returned, the column converter error-checking
// path is used for the argument. Drivers may wish to return ErrSkip after
// they have exhausted their own special cases.
type NamedValueChecker interface {
// CheckNamedValue is called before passing arguments to the driver
// and is called in place of any ColumnConverter. CheckNamedValue must do type
// validation and conversion as appropriate for the driver.
CheckNamedValue(*NamedValue) error
}
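// A sketch of a CheckNamedValue that consumes a query-specific option
// (hypothetical type queryTimeout) via ErrRemoveArgument and defers all
// other arguments to the default converter:
//
//	func (s *fakeStmt) CheckNamedValue(nv *NamedValue) error {
//		if t, ok := nv.Value.(queryTimeout); ok {
//			s.timeout = time.Duration(t) // consume the option
//			return ErrRemoveArgument     // and drop it from the argument list
//		}
//		var err error
//		nv.Value, err = DefaultParameterConverter.ConvertValue(nv.Value)
//		return err
//	}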
// ColumnConverter may be optionally implemented by Stmt if the
// statement is aware of its own columns' types and can convert from
// any type to a driver Value.
//
// Deprecated: Drivers should implement NamedValueChecker.
type ColumnConverter interface {
// ColumnConverter returns a ValueConverter for the provided
// column index. If the type of a specific column isn't known
// or shouldn't be handled specially, DefaultValueConverter
// can be returned.
ColumnConverter(idx int) ValueConverter
}
// Rows is an iterator over an executed query's results.
type Rows interface {
// Columns returns the names of the columns. The number of
// columns of the result is inferred from the length of the
// slice. If a particular column name isn't known, an empty
// string should be returned for that entry.
Columns() []string
// Close closes the rows iterator.
Close() error
// Next is called to populate the next row of data into
// the provided slice. The provided slice will be the same
// size as the Columns() are wide.
//
// Next should return io.EOF when there are no more rows.
//
// The dest should not be written to outside of Next. Care
// should be taken when closing Rows not to modify
// a buffer held in dest.
Next(dest []Value) error
}
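// An illustrative in-memory Rows implementation (all names hypothetical;
// io.EOF signals the end of the result set):
//
//	type memRows struct {
//		cols []string
//		data [][]Value
//		pos  int
//	}
//
//	func (r *memRows) Columns() []string { return r.cols }
//	func (r *memRows) Close() error      { return nil }
//
//	func (r *memRows) Next(dest []Value) error {
//		if r.pos >= len(r.data) {
//			return io.EOF
//		}
//		copy(dest, r.data[r.pos])
//		r.pos++
//		return nil
//	}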
// RowsNextResultSet extends the Rows interface by providing a way to signal
// the driver to advance to the next result set.
type RowsNextResultSet interface {
Rows
// HasNextResultSet is called at the end of the current result set and
// reports whether there is another result set after the current one.
HasNextResultSet() bool
// NextResultSet advances the driver to the next result set even
// if there are remaining rows in the current result set.
//
// NextResultSet should return io.EOF when there are no more result sets.
NextResultSet() error
}
// RowsColumnTypeScanType may be implemented by Rows. It should return
// the value type that can be used to scan types into. For example, for the
// database column type "bigint" this should return "reflect.TypeOf(int64(0))".
type RowsColumnTypeScanType interface {
Rows
ColumnTypeScanType(index int) reflect.Type
}
// RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the
// database system type name without the length. Type names should be uppercase.
// Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT",
// "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML",
// "TIMESTAMP".
type RowsColumnTypeDatabaseTypeName interface {
Rows
ColumnTypeDatabaseTypeName(index int) string
}
// RowsColumnTypeLength may be implemented by Rows. It should return the length
// of the column type if the column is a variable length type. If the column is
// not a variable length type, ok should be false.
// If length is not limited other than system limits, it should return math.MaxInt64.
// The following are examples of returned values for various types:
//
// TEXT (math.MaxInt64, true)
// varchar(10) (10, true)
// nvarchar(10) (10, true)
// decimal (0, false)
// int (0, false)
// bytea(30) (30, true)
type RowsColumnTypeLength interface {
Rows
ColumnTypeLength(index int) (length int64, ok bool)
}
// RowsColumnTypeNullable may be implemented by Rows. The nullable value should
// be true if it is known the column may be null, or false if the column is known
// to be not nullable.
// If the column nullability is unknown, ok should be false.
type RowsColumnTypeNullable interface {
Rows
ColumnTypeNullable(index int) (nullable, ok bool)
}
// RowsColumnTypePrecisionScale may be implemented by Rows. It should return
// the precision and scale for decimal types. If not applicable, ok should be false.
// The following are examples of returned values for various types:
//
// decimal(38, 4) (38, 4, true)
// int (0, 0, false)
// decimal (math.MaxInt64, math.MaxInt64, true)
type RowsColumnTypePrecisionScale interface {
Rows
ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool)
}
// Tx is a transaction.
type Tx interface {
Commit() error
Rollback() error
}
// RowsAffected implements Result for an INSERT or UPDATE operation
// which mutates a number of rows.
type RowsAffected int64
var _ Result = RowsAffected(0)
func (RowsAffected) LastInsertId() (int64, error) {
return 0, errors.New("LastInsertId is not supported by this driver")
}
func (v RowsAffected) RowsAffected() (int64, error) {
return int64(v), nil
}
// ResultNoRows is a pre-defined Result for drivers to return when a DDL
// command (such as a CREATE TABLE) succeeds. It returns an error for both
// LastInsertId and RowsAffected.
var ResultNoRows noRows
type noRows struct{}
var _ Result = noRows{}
func (noRows) LastInsertId() (int64, error) {
return 0, errors.New("no LastInsertId available after DDL statement")
}
func (noRows) RowsAffected() (int64, error) {
return 0, errors.New("no RowsAffected available after DDL statement")
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package driver
import (
"fmt"
"reflect"
"strconv"
"time"
)
// ValueConverter is the interface providing the ConvertValue method.
//
// Various implementations of ValueConverter are provided by the
// driver package to provide consistent implementations of conversions
// between drivers. The ValueConverters have several uses:
//
// - converting from the Value types as provided by the sql package
// into a database table's specific column type and making sure it
// fits, such as making sure a particular int64 fits in a
// table's uint16 column.
//
// - converting a value as given from the database into one of the
// driver Value types.
//
// - by the sql package, for converting from a driver's Value type
// to a user's type in a scan.
type ValueConverter interface {
// ConvertValue converts a value to a driver Value.
ConvertValue(v any) (Value, error)
}
// Valuer is the interface providing the Value method.
//
// Types implementing Valuer interface are able to convert
// themselves to a driver Value.
type Valuer interface {
// Value returns a driver Value.
// Value must not panic.
Value() (Value, error)
}
// Bool is a ValueConverter that converts input values to bools.
//
// The conversion rules are:
// - booleans are returned unchanged
// - for integer types,
// 1 is true
// 0 is false,
// other integers are an error
// - for strings and []byte, same rules as strconv.ParseBool
// - all other types are an error
var Bool boolType
type boolType struct{}
var _ ValueConverter = boolType{}
func (boolType) String() string { return "Bool" }
func (boolType) ConvertValue(src any) (Value, error) {
switch s := src.(type) {
case bool:
return s, nil
case string:
b, err := strconv.ParseBool(s)
if err != nil {
return nil, fmt.Errorf("sql/driver: couldn't convert %q into type bool", s)
}
return b, nil
case []byte:
b, err := strconv.ParseBool(string(s))
if err != nil {
return nil, fmt.Errorf("sql/driver: couldn't convert %q into type bool", s)
}
return b, nil
}
sv := reflect.ValueOf(src)
switch sv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
iv := sv.Int()
if iv == 1 || iv == 0 {
return iv == 1, nil
}
return nil, fmt.Errorf("sql/driver: couldn't convert %d into type bool", iv)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
uv := sv.Uint()
if uv == 1 || uv == 0 {
return uv == 1, nil
}
return nil, fmt.Errorf("sql/driver: couldn't convert %d into type bool", uv)
}
return nil, fmt.Errorf("sql/driver: couldn't convert %v (%T) into type bool", src, src)
}
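// Usage sketch:
//
//	v, err := Bool.ConvertValue("true") // v == true, err == nil
//	_, err = Bool.ConvertValue(2)       // err: couldn't convert 2 into type bool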
// Int32 is a ValueConverter that converts input values to int64,
// respecting the limits of an int32 value.
var Int32 int32Type
type int32Type struct{}
var _ ValueConverter = int32Type{}
func (int32Type) ConvertValue(v any) (Value, error) {
rv := reflect.ValueOf(v)
switch rv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
i64 := rv.Int()
if i64 > (1<<31)-1 || i64 < -(1<<31) {
return nil, fmt.Errorf("sql/driver: value %d overflows int32", v)
}
return i64, nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
u64 := rv.Uint()
if u64 > (1<<31)-1 {
return nil, fmt.Errorf("sql/driver: value %d overflows int32", v)
}
return int64(u64), nil
case reflect.String:
i, err := strconv.Atoi(rv.String())
if err != nil {
return nil, fmt.Errorf("sql/driver: value %q can't be converted to int32", v)
}
return int64(i), nil
}
return nil, fmt.Errorf("sql/driver: unsupported value %v (type %T) converting to int32", v, v)
}
// String is a ValueConverter that converts its input to a string.
// If the value is already a string or []byte, it's unchanged.
// If the value is of another type, conversion to string is done
// with fmt.Sprintf("%v", v).
var String stringType
type stringType struct{}
func (stringType) ConvertValue(v any) (Value, error) {
switch v.(type) {
case string, []byte:
return v, nil
}
return fmt.Sprintf("%v", v), nil
}
// Null is a type that implements ValueConverter by allowing nil
// values but otherwise delegating to another ValueConverter.
type Null struct {
Converter ValueConverter
}
func (n Null) ConvertValue(v any) (Value, error) {
if v == nil {
return nil, nil
}
return n.Converter.ConvertValue(v)
}
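// Usage sketch: a converter for a nullable int32 column.
//
//	nullableInt32 := Null{Converter: Int32}
//	v, err := nullableInt32.ConvertValue(nil)    // v == nil, err == nil
//	v, err = nullableInt32.ConvertValue(int8(7)) // v == int64(7)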
// NotNull is a type that implements ValueConverter by disallowing nil
// values but otherwise delegating to another ValueConverter.
type NotNull struct {
Converter ValueConverter
}
func (n NotNull) ConvertValue(v any) (Value, error) {
if v == nil {
return nil, fmt.Errorf("nil value not allowed")
}
return n.Converter.ConvertValue(v)
}
// IsValue reports whether v is a valid Value parameter type.
func IsValue(v any) bool {
if v == nil {
return true
}
switch v.(type) {
case []byte, bool, float64, int64, string, time.Time:
return true
case decimalDecompose:
return true
}
return false
}
// IsScanValue is equivalent to IsValue.
// It exists for compatibility.
func IsScanValue(v any) bool {
return IsValue(v)
}
// DefaultParameterConverter is the default implementation of
// ValueConverter that's used when a Stmt doesn't implement
// ColumnConverter.
//
// DefaultParameterConverter returns its argument directly if
// IsValue(arg). Otherwise, if the argument implements Valuer, its
// Value method is used to return a Value. As a fallback, the provided
// argument's underlying type is used to convert it to a Value:
// underlying integer types are converted to int64, floats to float64,
// bool, string, and []byte to themselves. If the argument is a nil
// pointer, ConvertValue returns a nil Value. If the argument is a
// non-nil pointer, it is dereferenced and ConvertValue is called
// recursively. Other types are an error.
var DefaultParameterConverter defaultConverter
type defaultConverter struct{}
var _ ValueConverter = defaultConverter{}
var valuerReflectType = reflect.TypeOf((*Valuer)(nil)).Elem()
// callValuerValue returns vr.Value(), with one exception:
// If vr.Value is an auto-generated method on a pointer type and the
// pointer is nil, it would panic at runtime in the panicwrap
// method. Treat it like nil instead.
// Issue 8415.
//
// This is so people can implement driver.Value on value types and
// still use nil pointers to those types to mean nil/NULL, just like
// string/*string.
//
// This function is mirrored in the database/sql package.
func callValuerValue(vr Valuer) (v Value, err error) {
if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Pointer &&
rv.IsNil() &&
rv.Type().Elem().Implements(valuerReflectType) {
return nil, nil
}
return vr.Value()
}
func (defaultConverter) ConvertValue(v any) (Value, error) {
if IsValue(v) {
return v, nil
}
switch vr := v.(type) {
case Valuer:
sv, err := callValuerValue(vr)
if err != nil {
return nil, err
}
if !IsValue(sv) {
return nil, fmt.Errorf("non-Value type %T returned from Value", sv)
}
return sv, nil
// For now, continue to prefer the Valuer interface over the decimal decompose interface.
case decimalDecompose:
return vr, nil
}
rv := reflect.ValueOf(v)
switch rv.Kind() {
case reflect.Pointer:
// indirect pointers
if rv.IsNil() {
return nil, nil
} else {
return defaultConverter{}.ConvertValue(rv.Elem().Interface())
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return rv.Int(), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
return int64(rv.Uint()), nil
case reflect.Uint64:
u64 := rv.Uint()
if u64 >= 1<<63 {
return nil, fmt.Errorf("uint64 values with high bit set are not supported")
}
return int64(u64), nil
case reflect.Float32, reflect.Float64:
return rv.Float(), nil
case reflect.Bool:
return rv.Bool(), nil
case reflect.Slice:
ek := rv.Type().Elem().Kind()
if ek == reflect.Uint8 {
return rv.Bytes(), nil
}
return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek)
case reflect.String:
return rv.String(), nil
}
return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
}
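// Usage sketch of the fallback conversions described above:
//
//	v, _ := DefaultParameterConverter.ConvertValue(int32(7))  // int64(7)
//	v, _ = DefaultParameterConverter.ConvertValue(float32(1)) // float64(1)
//	var p *string
//	v, _ = DefaultParameterConverter.ConvertValue(p) // nil (nil pointer)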
type decimalDecompose interface {
// Decompose returns the internal decimal state in parts.
// If the provided buf has sufficient capacity, buf may be returned as the coefficient with
// the value set and length set as appropriate.
Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32)
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package sql provides a generic interface around SQL (or SQL-like)
// databases.
//
// The sql package must be used in conjunction with a database driver.
// See https://golang.org/s/sqldrivers for a list of drivers.
//
// Drivers that do not support context cancellation will not return until
// after the query is completed.
//
// For usage examples, see the wiki page at
// https://golang.org/s/sqlwiki.
package sql
import (
"context"
"database/sql/driver"
"errors"
"fmt"
"io"
"reflect"
"runtime"
"sort"
"strconv"
"sync"
"sync/atomic"
"time"
)
var (
driversMu sync.RWMutex
drivers = make(map[string]driver.Driver)
)
// nowFunc returns the current time; it's overridden in tests.
var nowFunc = time.Now
// Register makes a database driver available by the provided name.
// If Register is called twice with the same name or if driver is nil,
// it panics.
func Register(name string, driver driver.Driver) {
driversMu.Lock()
defer driversMu.Unlock()
if driver == nil {
panic("sql: Register driver is nil")
}
if _, dup := drivers[name]; dup {
panic("sql: Register called twice for driver " + name)
}
drivers[name] = driver
}
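// Drivers typically register themselves from an init function, so
// importing the driver package for side effects is enough to make it
// available (the driver name and import path below are illustrative):
//
//	import _ "example.com/fakedb" // registers "fakedb" via Register in init
//
//	db, err := sql.Open("fakedb", "dsn...")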
func unregisterAllDrivers() {
driversMu.Lock()
defer driversMu.Unlock()
// For tests.
drivers = make(map[string]driver.Driver)
}
// Drivers returns a sorted list of the names of the registered drivers.
func Drivers() []string {
driversMu.RLock()
defer driversMu.RUnlock()
list := make([]string, 0, len(drivers))
for name := range drivers {
list = append(list, name)
}
sort.Strings(list)
return list
}
// A NamedArg is a named argument. NamedArg values may be used as
// arguments to Query or Exec and bind to the corresponding named
// parameter in the SQL statement.
//
// For a more concise way to create NamedArg values, see
// the Named function.
type NamedArg struct {
_NamedFieldsRequired struct{}
// Name is the name of the parameter placeholder.
//
// If empty, the ordinal position in the argument list will be
// used.
//
// Name must omit any symbol prefix.
Name string
// Value is the value of the parameter.
// It may be assigned the same value types as the query
// arguments.
Value any
}
// Named provides a more concise way to create NamedArg values.
//
// Example usage:
//
// db.ExecContext(ctx, `
// delete from Invoice
// where
// TimeCreated < @end
// and TimeCreated >= @start;`,
// sql.Named("start", startTime),
// sql.Named("end", endTime),
// )
func Named(name string, value any) NamedArg {
// This method exists because the go1compat promise
// doesn't guarantee that structs don't grow more fields,
// so unkeyed struct literals are a vet error. Thus, we don't
// want to allow sql.NamedArg{name, value}.
return NamedArg{Name: name, Value: value}
}
// IsolationLevel is the transaction isolation level used in TxOptions.
type IsolationLevel int
// Various isolation levels that drivers may support in BeginTx.
// If a driver does not support a given isolation level an error may be returned.
//
// See https://en.wikipedia.org/wiki/Isolation_(database_systems)#Isolation_levels.
const (
LevelDefault IsolationLevel = iota
LevelReadUncommitted
LevelReadCommitted
LevelWriteCommitted
LevelRepeatableRead
LevelSnapshot
LevelSerializable
LevelLinearizable
)
// String returns the name of the transaction isolation level.
func (i IsolationLevel) String() string {
switch i {
case LevelDefault:
return "Default"
case LevelReadUncommitted:
return "Read Uncommitted"
case LevelReadCommitted:
return "Read Committed"
case LevelWriteCommitted:
return "Write Committed"
case LevelRepeatableRead:
return "Repeatable Read"
case LevelSnapshot:
return "Snapshot"
case LevelSerializable:
return "Serializable"
case LevelLinearizable:
return "Linearizable"
default:
return "IsolationLevel(" + strconv.Itoa(int(i)) + ")"
}
}
var _ fmt.Stringer = LevelDefault
// TxOptions holds the transaction options to be used in DB.BeginTx.
type TxOptions struct {
// Isolation is the transaction isolation level.
// If zero, the driver or database's default level is used.
Isolation IsolationLevel
ReadOnly bool
}
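// An illustrative sketch of passing TxOptions to DB.BeginTx (db and ctx are
// assumed; the driver must support the requested isolation level):
//
//	tx, err := db.BeginTx(ctx, &sql.TxOptions{
//		Isolation: sql.LevelSerializable,
//		ReadOnly:  true,
//	})
//	...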
// RawBytes is a byte slice that holds a reference to memory owned by
// the database itself. After a Scan into a RawBytes, the slice is only
// valid until the next call to Next, Scan, or Close.
type RawBytes []byte
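// An illustrative sketch of the retention rule (db, the query, and error
// handling are assumed): data must be copied out before the next Next, Scan,
// or Close invalidates the slice.
//
//	var raw sql.RawBytes
//	rows, err := db.Query("SELECT blob FROM t")
//	...
//	for rows.Next() {
//		if err := rows.Scan(&raw); err != nil {
//			...
//		}
//		saved := make([]byte, len(raw))
//		copy(saved, raw) // raw is invalidated by the next Next/Scan/Close
//	}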
// NullString represents a string that may be null.
// NullString implements the Scanner interface so
// it can be used as a scan destination:
//
// var s NullString
// err := db.QueryRow("SELECT name FROM foo WHERE id=?", id).Scan(&s)
// ...
// if s.Valid {
// // use s.String
// } else {
// // NULL value
// }
type NullString struct {
String string
Valid bool // Valid is true if String is not NULL
}
// Scan implements the Scanner interface.
func (ns *NullString) Scan(value any) error {
if value == nil {
ns.String, ns.Valid = "", false
return nil
}
ns.Valid = true
return convertAssign(&ns.String, value)
}
// Value implements the driver Valuer interface.
func (ns NullString) Value() (driver.Value, error) {
if !ns.Valid {
return nil, nil
}
return ns.String, nil
}
// NullInt64 represents an int64 that may be null.
// NullInt64 implements the Scanner interface so
// it can be used as a scan destination, similar to NullString.
type NullInt64 struct {
Int64 int64
Valid bool // Valid is true if Int64 is not NULL
}
// Scan implements the Scanner interface.
func (n *NullInt64) Scan(value any) error {
if value == nil {
n.Int64, n.Valid = 0, false
return nil
}
n.Valid = true
return convertAssign(&n.Int64, value)
}
// Value implements the driver Valuer interface.
func (n NullInt64) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Int64, nil
}
// NullInt32 represents an int32 that may be null.
// NullInt32 implements the Scanner interface so
// it can be used as a scan destination, similar to NullString.
type NullInt32 struct {
Int32 int32
Valid bool // Valid is true if Int32 is not NULL
}
// Scan implements the Scanner interface.
func (n *NullInt32) Scan(value any) error {
if value == nil {
n.Int32, n.Valid = 0, false
return nil
}
n.Valid = true
return convertAssign(&n.Int32, value)
}
// Value implements the driver Valuer interface.
func (n NullInt32) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return int64(n.Int32), nil
}
// NullInt16 represents an int16 that may be null.
// NullInt16 implements the Scanner interface so
// it can be used as a scan destination, similar to NullString.
type NullInt16 struct {
Int16 int16
Valid bool // Valid is true if Int16 is not NULL
}
// Scan implements the Scanner interface.
func (n *NullInt16) Scan(value any) error {
if value == nil {
n.Int16, n.Valid = 0, false
return nil
}
err := convertAssign(&n.Int16, value)
n.Valid = err == nil
return err
}
// Value implements the driver Valuer interface.
func (n NullInt16) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return int64(n.Int16), nil
}
// NullByte represents a byte that may be null.
// NullByte implements the Scanner interface so
// it can be used as a scan destination, similar to NullString.
type NullByte struct {
Byte byte
Valid bool // Valid is true if Byte is not NULL
}
// Scan implements the Scanner interface.
func (n *NullByte) Scan(value any) error {
if value == nil {
n.Byte, n.Valid = 0, false
return nil
}
err := convertAssign(&n.Byte, value)
n.Valid = err == nil
return err
}
// Value implements the driver Valuer interface.
func (n NullByte) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return int64(n.Byte), nil
}
// NullFloat64 represents a float64 that may be null.
// NullFloat64 implements the Scanner interface so
// it can be used as a scan destination, similar to NullString.
type NullFloat64 struct {
Float64 float64
Valid bool // Valid is true if Float64 is not NULL
}
// Scan implements the Scanner interface.
func (n *NullFloat64) Scan(value any) error {
if value == nil {
n.Float64, n.Valid = 0, false
return nil
}
n.Valid = true
return convertAssign(&n.Float64, value)
}
// Value implements the driver Valuer interface.
func (n NullFloat64) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Float64, nil
}
// NullBool represents a bool that may be null.
// NullBool implements the Scanner interface so
// it can be used as a scan destination, similar to NullString.
type NullBool struct {
Bool bool
Valid bool // Valid is true if Bool is not NULL
}
// Scan implements the Scanner interface.
func (n *NullBool) Scan(value any) error {
if value == nil {
n.Bool, n.Valid = false, false
return nil
}
n.Valid = true
return convertAssign(&n.Bool, value)
}
// Value implements the driver Valuer interface.
func (n NullBool) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Bool, nil
}
// NullTime represents a time.Time that may be null.
// NullTime implements the Scanner interface so
// it can be used as a scan destination, similar to NullString.
type NullTime struct {
Time time.Time
Valid bool // Valid is true if Time is not NULL
}
// Scan implements the Scanner interface.
func (n *NullTime) Scan(value any) error {
if value == nil {
n.Time, n.Valid = time.Time{}, false
return nil
}
n.Valid = true
return convertAssign(&n.Time, value)
}
// Value implements the driver Valuer interface.
func (n NullTime) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Time, nil
}
// Scanner is an interface used by Scan.
type Scanner interface {
// Scan assigns a value from a database driver.
//
// The src value will be of one of the following types:
//
// int64
// float64
// bool
// []byte
// string
// time.Time
// nil - for NULL values
//
// An error should be returned if the value cannot be stored
// without loss of information.
//
// Reference types such as []byte are only valid until the next call to Scan
// and should not be retained. Their underlying memory is owned by the driver.
// If retention is necessary, copy their values before the next call to Scan.
Scan(src any) error
}
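// An illustrative sketch of a custom Scanner (the Color type and its text
// encoding are hypothetical):
//
//	type Color struct{ R, G, B uint8 }
//
//	// Scan parses values stored as "#rrggbb" strings.
//	func (c *Color) Scan(src any) error {
//		s, ok := src.(string)
//		if !ok {
//			return fmt.Errorf("Color: cannot scan %T", src)
//		}
//		_, err := fmt.Sscanf(s, "#%02x%02x%02x", &c.R, &c.G, &c.B)
//		return err
//	}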
// Out may be used to retrieve OUTPUT value parameters from stored procedures.
//
// Not all drivers and databases support OUTPUT value parameters.
//
// Example usage:
//
// var outArg string
// _, err := db.ExecContext(ctx, "ProcName", sql.Named("Arg1", sql.Out{Dest: &outArg}))
type Out struct {
_NamedFieldsRequired struct{}
// Dest is a pointer to the value that will be set to the result of the
// stored procedure's OUTPUT parameter.
Dest any
// In is whether the parameter is an INOUT parameter. If so, the input value to the stored
// procedure is the dereferenced value of Dest's pointer, which is then replaced with
// the output value.
In bool
}
// ErrNoRows is returned by Scan when QueryRow doesn't return a
// row. In such a case, QueryRow returns a placeholder *Row value that
// defers this error until a Scan.
var ErrNoRows = errors.New("sql: no rows in result set")
// DB is a database handle representing a pool of zero or more
// underlying connections. It's safe for concurrent use by multiple
// goroutines.
//
// The sql package creates and frees connections automatically; it
// also maintains a free pool of idle connections. If the database has
// a concept of per-connection state, such state can be reliably observed
// within a transaction (Tx) or connection (Conn). Once DB.Begin is called, the
// returned Tx is bound to a single connection. Once Commit or
// Rollback is called on the transaction, that transaction's
// connection is returned to DB's idle connection pool. The pool size
// can be controlled with SetMaxIdleConns.
type DB struct {
// Total time waited for new connections.
waitDuration atomic.Int64
connector driver.Connector
// numClosed is an atomic counter which represents a total number of
// closed connections. Stmt.openStmt checks it before cleaning closed
// connections in Stmt.css.
numClosed atomic.Uint64
mu sync.Mutex // protects following fields
freeConn []*driverConn // free connections ordered by returnedAt oldest to newest
connRequests map[uint64]chan connRequest
nextRequest uint64 // Next key to use in connRequests.
numOpen int // number of opened and pending open connections
// Used to signal the need for new connections
// a goroutine running connectionOpener() reads on this chan and
// maybeOpenNewConnections sends on the chan (one send per needed connection)
// It is closed during db.Close(). The close tells the connectionOpener
// goroutine to exit.
openerCh chan struct{}
closed bool
dep map[finalCloser]depSet
lastPut map[*driverConn]string // stacktrace of last conn's put; debug only
maxIdleCount int // zero means defaultMaxIdleConns; negative means 0
maxOpen int // <= 0 means unlimited
maxLifetime time.Duration // maximum amount of time a connection may be reused
maxIdleTime time.Duration // maximum amount of time a connection may be idle before being closed
cleanerCh chan struct{}
waitCount int64 // Total number of connections waited for.
maxIdleClosed int64 // Total number of connections closed due to idle count.
maxIdleTimeClosed int64 // Total number of connections closed due to idle time.
maxLifetimeClosed int64 // Total number of connections closed due to max connection lifetime limit.
stop func() // stop cancels the connection opener.
}
// connReuseStrategy determines how (*DB).conn returns database connections.
type connReuseStrategy uint8
const (
// alwaysNewConn forces a new connection to the database.
alwaysNewConn connReuseStrategy = iota
// cachedOrNewConn returns a cached connection, if available, else waits
// for one to become available (if MaxOpenConns has been reached) or
// creates a new database connection.
cachedOrNewConn
)
// driverConn wraps a driver.Conn with a mutex, to
// be held during all calls into the Conn (including any calls onto
// interfaces returned via that Conn, such as calls on Tx, Stmt,
// Result, Rows).
type driverConn struct {
db *DB
createdAt time.Time
sync.Mutex // guards following
ci driver.Conn
needReset bool // The connection session should be reset before use if true.
closed bool
finalClosed bool // ci.Close has been called
openStmt map[*driverStmt]bool
// guarded by db.mu
inUse bool
returnedAt time.Time // Time the connection was created or returned.
onPut []func() // code (with db.mu held) run when conn is next returned
dbmuClosed bool // same as closed, but guarded by db.mu, for removeClosedStmtLocked
}
func (dc *driverConn) releaseConn(err error) {
dc.db.putConn(dc, err, true)
}
func (dc *driverConn) removeOpenStmt(ds *driverStmt) {
dc.Lock()
defer dc.Unlock()
delete(dc.openStmt, ds)
}
func (dc *driverConn) expired(timeout time.Duration) bool {
if timeout <= 0 {
return false
}
return dc.createdAt.Add(timeout).Before(nowFunc())
}
// resetSession checks if the driver connection needs the
// session to be reset and, if required, resets it.
func (dc *driverConn) resetSession(ctx context.Context) error {
dc.Lock()
defer dc.Unlock()
if !dc.needReset {
return nil
}
if cr, ok := dc.ci.(driver.SessionResetter); ok {
return cr.ResetSession(ctx)
}
return nil
}
// validateConnection checks if the connection is valid and can
// still be used. It also marks the session for reset if required.
func (dc *driverConn) validateConnection(needsReset bool) bool {
dc.Lock()
defer dc.Unlock()
if needsReset {
dc.needReset = true
}
if cv, ok := dc.ci.(driver.Validator); ok {
return cv.IsValid()
}
return true
}
// prepareLocked prepares the query on dc. When cg == nil the dc must keep track of
// the prepared statements in a pool.
func (dc *driverConn) prepareLocked(ctx context.Context, cg stmtConnGrabber, query string) (*driverStmt, error) {
si, err := ctxDriverPrepare(ctx, dc.ci, query)
if err != nil {
return nil, err
}
ds := &driverStmt{Locker: dc, si: si}
// No need to manage open statements if there is a single connection grabber.
if cg != nil {
return ds, nil
}
// Track each driverConn's open statements, so we can close them
// before closing the conn.
//
// Wrap all driver.Stmt in a *driverStmt to ensure they are only closed once.
if dc.openStmt == nil {
dc.openStmt = make(map[*driverStmt]bool)
}
dc.openStmt[ds] = true
return ds, nil
}
// closeDBLocked is called while the dc.db's Mutex is held.
func (dc *driverConn) closeDBLocked() func() error {
dc.Lock()
defer dc.Unlock()
if dc.closed {
return func() error { return errors.New("sql: duplicate driverConn close") }
}
dc.closed = true
return dc.db.removeDepLocked(dc, dc)
}
func (dc *driverConn) Close() error {
dc.Lock()
if dc.closed {
dc.Unlock()
return errors.New("sql: duplicate driverConn close")
}
dc.closed = true
dc.Unlock() // not defer; removeDep finalClose calls may need to lock
// And now updates that require holding dc.db.mu.
dc.db.mu.Lock()
dc.dbmuClosed = true
fn := dc.db.removeDepLocked(dc, dc)
dc.db.mu.Unlock()
return fn()
}
func (dc *driverConn) finalClose() error {
var err error
// Each *driverStmt has a lock to the dc. Copy the list out of the dc
// before calling close on each stmt.
var openStmt []*driverStmt
withLock(dc, func() {
openStmt = make([]*driverStmt, 0, len(dc.openStmt))
for ds := range dc.openStmt {
openStmt = append(openStmt, ds)
}
dc.openStmt = nil
})
for _, ds := range openStmt {
ds.Close()
}
withLock(dc, func() {
dc.finalClosed = true
err = dc.ci.Close()
dc.ci = nil
})
dc.db.mu.Lock()
dc.db.numOpen--
dc.db.maybeOpenNewConnections()
dc.db.mu.Unlock()
dc.db.numClosed.Add(1)
return err
}
// driverStmt associates a driver.Stmt with the
// *driverConn from which it came, so the driverConn's lock can be
// held during calls.
type driverStmt struct {
sync.Locker // the *driverConn
si driver.Stmt
closed bool
closeErr error // return value of previous Close call
}
// Close ensures driver.Stmt is only closed once and always returns the same
// result.
func (ds *driverStmt) Close() error {
ds.Lock()
defer ds.Unlock()
if ds.closed {
return ds.closeErr
}
ds.closed = true
ds.closeErr = ds.si.Close()
return ds.closeErr
}
// depSet is a finalCloser's outstanding dependencies
type depSet map[any]bool // set of true bools
// The finalCloser interface is used by (*DB).addDep and related
// dependency reference counting.
type finalCloser interface {
// finalClose is called when the reference count of an object
// goes to zero. (*DB).mu is not held while calling it.
finalClose() error
}
// addDep notes that x now depends on dep, and x's finalClose won't be
// called until all of x's dependencies are removed with removeDep.
func (db *DB) addDep(x finalCloser, dep any) {
db.mu.Lock()
defer db.mu.Unlock()
db.addDepLocked(x, dep)
}
func (db *DB) addDepLocked(x finalCloser, dep any) {
if db.dep == nil {
db.dep = make(map[finalCloser]depSet)
}
xdep := db.dep[x]
if xdep == nil {
xdep = make(depSet)
db.dep[x] = xdep
}
xdep[dep] = true
}
// removeDep notes that x no longer depends on dep.
// If x still has dependencies, nil is returned.
// If x no longer has any dependencies, its finalClose method will be
// called and its error value will be returned.
func (db *DB) removeDep(x finalCloser, dep any) error {
db.mu.Lock()
fn := db.removeDepLocked(x, dep)
db.mu.Unlock()
return fn()
}
func (db *DB) removeDepLocked(x finalCloser, dep any) func() error {
xdep, ok := db.dep[x]
if !ok {
panic(fmt.Sprintf("unpaired removeDep: no deps for %T", x))
}
l0 := len(xdep)
delete(xdep, dep)
switch len(xdep) {
case l0:
// Nothing removed. Shouldn't happen.
panic(fmt.Sprintf("unpaired removeDep: no %T dep on %T", dep, x))
case 0:
// No more dependencies.
delete(db.dep, x)
return x.finalClose
default:
// Dependencies remain.
return func() error { return nil }
}
}
// This is the size of the connectionOpener request chan (DB.openerCh).
// This value should be larger than the maximum typical value
// used for db.maxOpen. If maxOpen is significantly larger than
// connectionRequestQueueSize then it is possible for ALL calls into the *DB
// to block until the connectionOpener can satisfy the backlog of requests.
var connectionRequestQueueSize = 1000000
type dsnConnector struct {
dsn string
driver driver.Driver
}
func (t dsnConnector) Connect(_ context.Context) (driver.Conn, error) {
return t.driver.Open(t.dsn)
}
func (t dsnConnector) Driver() driver.Driver {
return t.driver
}
// OpenDB opens a database using a Connector, allowing drivers to
// bypass a string based data source name.
//
// Most users will open a database via a driver-specific connection
// helper function that returns a *DB. No database drivers are included
// in the Go standard library. See https://golang.org/s/sqldrivers for
// a list of third-party drivers.
//
// OpenDB may just validate its arguments without creating a connection
// to the database. To verify that the data source name is valid, call
// Ping.
//
// The returned DB is safe for concurrent use by multiple goroutines
// and maintains its own pool of idle connections. Thus, the OpenDB
// function should be called just once. It is rarely necessary to
// close a DB.
func OpenDB(c driver.Connector) *DB {
ctx, cancel := context.WithCancel(context.Background())
db := &DB{
connector: c,
openerCh: make(chan struct{}, connectionRequestQueueSize),
lastPut: make(map[*driverConn]string),
connRequests: make(map[uint64]chan connRequest),
stop: cancel,
}
go db.connectionOpener(ctx)
return db
}
// Open opens a database specified by its database driver name and a
// driver-specific data source name, usually consisting of at least a
// database name and connection information.
//
// Most users will open a database via a driver-specific connection
// helper function that returns a *DB. No database drivers are included
// in the Go standard library. See https://golang.org/s/sqldrivers for
// a list of third-party drivers.
//
// Open may just validate its arguments without creating a connection
// to the database. To verify that the data source name is valid, call
// Ping.
//
// The returned DB is safe for concurrent use by multiple goroutines
// and maintains its own pool of idle connections. Thus, the Open
// function should be called just once. It is rarely necessary to
// close a DB.
func Open(driverName, dataSourceName string) (*DB, error) {
driversMu.RLock()
driveri, ok := drivers[driverName]
driversMu.RUnlock()
if !ok {
return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName)
}
if driverCtx, ok := driveri.(driver.DriverContext); ok {
connector, err := driverCtx.OpenConnector(dataSourceName)
if err != nil {
return nil, err
}
return OpenDB(connector), nil
}
return OpenDB(dsnConnector{dsn: dataSourceName, driver: driveri}), nil
}
func (db *DB) pingDC(ctx context.Context, dc *driverConn, release func(error)) error {
var err error
if pinger, ok := dc.ci.(driver.Pinger); ok {
withLock(dc, func() {
err = pinger.Ping(ctx)
})
}
release(err)
return err
}
// PingContext verifies a connection to the database is still alive,
// establishing a connection if necessary.
func (db *DB) PingContext(ctx context.Context) error {
var dc *driverConn
var err error
err = db.retry(func(strategy connReuseStrategy) error {
dc, err = db.conn(ctx, strategy)
return err
})
if err != nil {
return err
}
return db.pingDC(ctx, dc, dc.releaseConn)
}
// Ping verifies a connection to the database is still alive,
// establishing a connection if necessary.
//
// Ping uses context.Background internally; to specify the context, use
// PingContext.
func (db *DB) Ping() error {
return db.PingContext(context.Background())
}
// Close closes the database and prevents new queries from starting.
// Close then waits for all queries that have started processing on the server
// to finish.
//
// It is rare to Close a DB, as the DB handle is meant to be
// long-lived and shared between many goroutines.
func (db *DB) Close() error {
db.mu.Lock()
if db.closed { // Make DB.Close idempotent
db.mu.Unlock()
return nil
}
if db.cleanerCh != nil {
close(db.cleanerCh)
}
var err error
fns := make([]func() error, 0, len(db.freeConn))
for _, dc := range db.freeConn {
fns = append(fns, dc.closeDBLocked())
}
db.freeConn = nil
db.closed = true
for _, req := range db.connRequests {
close(req)
}
db.mu.Unlock()
for _, fn := range fns {
err1 := fn()
if err1 != nil {
err = err1
}
}
db.stop()
if c, ok := db.connector.(io.Closer); ok {
err1 := c.Close()
if err1 != nil {
err = err1
}
}
return err
}
const defaultMaxIdleConns = 2
func (db *DB) maxIdleConnsLocked() int {
n := db.maxIdleCount
switch {
case n == 0:
// TODO(bradfitz): ask driver, if supported, for its default preference
return defaultMaxIdleConns
case n < 0:
return 0
default:
return n
}
}
func (db *DB) shortestIdleTimeLocked() time.Duration {
if db.maxIdleTime <= 0 {
return db.maxLifetime
}
if db.maxLifetime <= 0 {
return db.maxIdleTime
}
min := db.maxIdleTime
if min > db.maxLifetime {
min = db.maxLifetime
}
return min
}
// SetMaxIdleConns sets the maximum number of connections in the idle
// connection pool.
//
// If MaxOpenConns is greater than 0 but less than the new MaxIdleConns,
// then the new MaxIdleConns will be reduced to match the MaxOpenConns limit.
//
// If n <= 0, no idle connections are retained.
//
// The default max idle connections is currently 2. This may change in
// a future release.
func (db *DB) SetMaxIdleConns(n int) {
db.mu.Lock()
if n > 0 {
db.maxIdleCount = n
} else {
// No idle connections.
db.maxIdleCount = -1
}
// Make sure maxIdle doesn't exceed maxOpen
if db.maxOpen > 0 && db.maxIdleConnsLocked() > db.maxOpen {
db.maxIdleCount = db.maxOpen
}
var closing []*driverConn
idleCount := len(db.freeConn)
maxIdle := db.maxIdleConnsLocked()
if idleCount > maxIdle {
closing = db.freeConn[maxIdle:]
db.freeConn = db.freeConn[:maxIdle]
}
db.maxIdleClosed += int64(len(closing))
db.mu.Unlock()
for _, c := range closing {
c.Close()
}
}
// SetMaxOpenConns sets the maximum number of open connections to the database.
//
// If MaxIdleConns is greater than 0 and the new MaxOpenConns is less than
// MaxIdleConns, then MaxIdleConns will be reduced to match the new
// MaxOpenConns limit.
//
// If n <= 0, then there is no limit on the number of open connections.
// The default is 0 (unlimited).
func (db *DB) SetMaxOpenConns(n int) {
db.mu.Lock()
db.maxOpen = n
if n < 0 {
db.maxOpen = 0
}
syncMaxIdle := db.maxOpen > 0 && db.maxIdleConnsLocked() > db.maxOpen
db.mu.Unlock()
if syncMaxIdle {
db.SetMaxIdleConns(n)
}
}
// SetConnMaxLifetime sets the maximum amount of time a connection may be reused.
//
// Expired connections may be closed lazily before reuse.
//
// If d <= 0, connections are not closed due to a connection's age.
func (db *DB) SetConnMaxLifetime(d time.Duration) {
if d < 0 {
d = 0
}
db.mu.Lock()
// Wake cleaner up when lifetime is shortened.
if d > 0 && d < db.maxLifetime && db.cleanerCh != nil {
select {
case db.cleanerCh <- struct{}{}:
default:
}
}
db.maxLifetime = d
db.startCleanerLocked()
db.mu.Unlock()
}
// SetConnMaxIdleTime sets the maximum amount of time a connection may be idle.
//
// Expired connections may be closed lazily before reuse.
//
// If d <= 0, connections are not closed due to a connection's idle time.
func (db *DB) SetConnMaxIdleTime(d time.Duration) {
if d < 0 {
d = 0
}
db.mu.Lock()
defer db.mu.Unlock()
// Wake cleaner up when idle time is shortened.
if d > 0 && d < db.maxIdleTime && db.cleanerCh != nil {
select {
case db.cleanerCh <- struct{}{}:
default:
}
}
db.maxIdleTime = d
db.startCleanerLocked()
}
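// An illustrative sketch of configuring the pool limits together (the values
// are arbitrary and workload-dependent):
//
//	db.SetMaxOpenConns(25)
//	db.SetMaxIdleConns(10)
//	db.SetConnMaxLifetime(30 * time.Minute)
//	db.SetConnMaxIdleTime(5 * time.Minute)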
// startCleanerLocked starts connectionCleaner if needed.
func (db *DB) startCleanerLocked() {
if (db.maxLifetime > 0 || db.maxIdleTime > 0) && db.numOpen > 0 && db.cleanerCh == nil {
db.cleanerCh = make(chan struct{}, 1)
go db.connectionCleaner(db.shortestIdleTimeLocked())
}
}
func (db *DB) connectionCleaner(d time.Duration) {
const minInterval = time.Second
if d < minInterval {
d = minInterval
}
t := time.NewTimer(d)
for {
select {
case <-t.C:
case <-db.cleanerCh: // maxLifetime was changed or db was closed.
}
db.mu.Lock()
d = db.shortestIdleTimeLocked()
if db.closed || db.numOpen == 0 || d <= 0 {
db.cleanerCh = nil
db.mu.Unlock()
return
}
d, closing := db.connectionCleanerRunLocked(d)
db.mu.Unlock()
for _, c := range closing {
c.Close()
}
if d < minInterval {
d = minInterval
}
if !t.Stop() {
select {
case <-t.C:
default:
}
}
t.Reset(d)
}
}
// connectionCleanerRunLocked removes connections that should be closed from
// freeConn and returns them, alongside an updated duration until the next
// check, shortened when a quicker check is required to close connections
// promptly.
func (db *DB) connectionCleanerRunLocked(d time.Duration) (time.Duration, []*driverConn) {
var idleClosing int64
var closing []*driverConn
if db.maxIdleTime > 0 {
// As freeConn is ordered by returnedAt, process
// in reverse order to minimise the work needed.
idleSince := nowFunc().Add(-db.maxIdleTime)
last := len(db.freeConn) - 1
for i := last; i >= 0; i-- {
c := db.freeConn[i]
if c.returnedAt.Before(idleSince) {
i++
closing = db.freeConn[:i:i]
db.freeConn = db.freeConn[i:]
idleClosing = int64(len(closing))
db.maxIdleTimeClosed += idleClosing
break
}
}
if len(db.freeConn) > 0 {
c := db.freeConn[0]
if d2 := c.returnedAt.Sub(idleSince); d2 < d {
// Ensure idle connections are cleaned up as soon as
// possible.
d = d2
}
}
}
if db.maxLifetime > 0 {
expiredSince := nowFunc().Add(-db.maxLifetime)
for i := 0; i < len(db.freeConn); i++ {
c := db.freeConn[i]
if c.createdAt.Before(expiredSince) {
closing = append(closing, c)
last := len(db.freeConn) - 1
// Use slow delete as order is required to ensure
// connections are reused least idle time first.
copy(db.freeConn[i:], db.freeConn[i+1:])
db.freeConn[last] = nil
db.freeConn = db.freeConn[:last]
i--
} else if d2 := c.createdAt.Sub(expiredSince); d2 < d {
// Prevent connections from sitting in freeConn after they
// have expired by updating our next deadline d.
d = d2
}
}
db.maxLifetimeClosed += int64(len(closing)) - idleClosing
}
return d, closing
}
// DBStats contains database statistics.
type DBStats struct {
MaxOpenConnections int // Maximum number of open connections to the database.
// Pool Status
OpenConnections int // The number of established connections both in use and idle.
InUse int // The number of connections currently in use.
Idle int // The number of idle connections.
// Counters
WaitCount int64 // The total number of connections waited for.
WaitDuration time.Duration // The total time blocked waiting for a new connection.
MaxIdleClosed int64 // The total number of connections closed due to SetMaxIdleConns.
MaxIdleTimeClosed int64 // The total number of connections closed due to SetConnMaxIdleTime.
MaxLifetimeClosed int64 // The total number of connections closed due to SetConnMaxLifetime.
}
// Stats returns database statistics.
func (db *DB) Stats() DBStats {
wait := db.waitDuration.Load()
db.mu.Lock()
defer db.mu.Unlock()
stats := DBStats{
MaxOpenConnections: db.maxOpen,
Idle: len(db.freeConn),
OpenConnections: db.numOpen,
InUse: db.numOpen - len(db.freeConn),
WaitCount: db.waitCount,
WaitDuration: time.Duration(wait),
MaxIdleClosed: db.maxIdleClosed,
MaxIdleTimeClosed: db.maxIdleTimeClosed,
MaxLifetimeClosed: db.maxLifetimeClosed,
}
return stats
}
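// An illustrative sketch of inspecting pool health (db is assumed to be an
// open *DB):
//
//	s := db.Stats()
//	fmt.Printf("open=%d inuse=%d idle=%d waited=%v\n",
//		s.OpenConnections, s.InUse, s.Idle, s.WaitDuration)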
// Assumes db.mu is locked.
// If there are connRequests and the connection limit hasn't been reached,
// then tell the connectionOpener to open new connections.
func (db *DB) maybeOpenNewConnections() {
numRequests := len(db.connRequests)
if db.maxOpen > 0 {
numCanOpen := db.maxOpen - db.numOpen
if numRequests > numCanOpen {
numRequests = numCanOpen
}
}
for numRequests > 0 {
db.numOpen++ // optimistically
numRequests--
if db.closed {
return
}
db.openerCh <- struct{}{}
}
}
// Runs in a separate goroutine, opens new connections when requested.
func (db *DB) connectionOpener(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
case <-db.openerCh:
db.openNewConnection(ctx)
}
}
}
// Open one new connection
func (db *DB) openNewConnection(ctx context.Context) {
// maybeOpenNewConnections has already executed db.numOpen++ before it sent
// on db.openerCh. This function must execute db.numOpen-- if the
// connection fails or is closed before returning.
ci, err := db.connector.Connect(ctx)
db.mu.Lock()
defer db.mu.Unlock()
if db.closed {
if err == nil {
ci.Close()
}
db.numOpen--
return
}
if err != nil {
db.numOpen--
db.putConnDBLocked(nil, err)
db.maybeOpenNewConnections()
return
}
dc := &driverConn{
db: db,
createdAt: nowFunc(),
returnedAt: nowFunc(),
ci: ci,
}
if db.putConnDBLocked(dc, err) {
db.addDepLocked(dc, dc)
} else {
db.numOpen--
ci.Close()
}
}
// connRequest represents one request for a new connection.
// When there are no idle connections available, DB.conn will create
// a new connRequest and put it on the db.connRequests list.
type connRequest struct {
conn *driverConn
err error
}
var errDBClosed = errors.New("sql: database is closed")
// nextRequestKeyLocked returns the next connection request key.
// It is assumed that nextRequest will not overflow.
func (db *DB) nextRequestKeyLocked() uint64 {
next := db.nextRequest
db.nextRequest++
return next
}
// conn returns a newly-opened or cached *driverConn.
func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn, error) {
db.mu.Lock()
if db.closed {
db.mu.Unlock()
return nil, errDBClosed
}
// Check if the context is expired.
select {
default:
case <-ctx.Done():
db.mu.Unlock()
return nil, ctx.Err()
}
lifetime := db.maxLifetime
// Prefer a free connection, if possible.
last := len(db.freeConn) - 1
if strategy == cachedOrNewConn && last >= 0 {
// Reuse the lowest idle time connection so we can close
// connections which remain idle as soon as possible.
conn := db.freeConn[last]
db.freeConn = db.freeConn[:last]
conn.inUse = true
if conn.expired(lifetime) {
db.maxLifetimeClosed++
db.mu.Unlock()
conn.Close()
return nil, driver.ErrBadConn
}
db.mu.Unlock()
// Reset the session if required.
if err := conn.resetSession(ctx); errors.Is(err, driver.ErrBadConn) {
conn.Close()
return nil, err
}
return conn, nil
}
// Out of free connections or we were asked not to use one. If we're not
// allowed to open any more connections, make a request and wait.
if db.maxOpen > 0 && db.numOpen >= db.maxOpen {
// Make the connRequest channel. It's buffered so that the
// connectionOpener doesn't block while waiting for the req to be read.
req := make(chan connRequest, 1)
reqKey := db.nextRequestKeyLocked()
db.connRequests[reqKey] = req
db.waitCount++
db.mu.Unlock()
waitStart := nowFunc()
// Timeout the connection request with the context.
select {
case <-ctx.Done():
// Remove the connection request and ensure no value has been sent
// on it after removing.
db.mu.Lock()
delete(db.connRequests, reqKey)
db.mu.Unlock()
db.waitDuration.Add(int64(time.Since(waitStart)))
select {
default:
case ret, ok := <-req:
if ok && ret.conn != nil {
db.putConn(ret.conn, ret.err, false)
}
}
return nil, ctx.Err()
case ret, ok := <-req:
db.waitDuration.Add(int64(time.Since(waitStart)))
if !ok {
return nil, errDBClosed
}
// Only check if the connection is expired if the strategy is cachedOrNewConn.
// If we require a new connection, just re-use the connection without looking
// at the expiry time. If it is expired, it will be checked when it is placed
// back into the connection pool.
// This prioritizes giving a valid connection to a client over the exact connection
// lifetime, which could expire exactly after this point anyway.
if strategy == cachedOrNewConn && ret.err == nil && ret.conn.expired(lifetime) {
db.mu.Lock()
db.maxLifetimeClosed++
db.mu.Unlock()
ret.conn.Close()
return nil, driver.ErrBadConn
}
if ret.conn == nil {
return nil, ret.err
}
// Reset the session if required.
if err := ret.conn.resetSession(ctx); errors.Is(err, driver.ErrBadConn) {
ret.conn.Close()
return nil, err
}
return ret.conn, ret.err
}
}
db.numOpen++ // optimistically
db.mu.Unlock()
ci, err := db.connector.Connect(ctx)
if err != nil {
db.mu.Lock()
db.numOpen-- // correct for earlier optimism
db.maybeOpenNewConnections()
db.mu.Unlock()
return nil, err
}
db.mu.Lock()
dc := &driverConn{
db: db,
createdAt: nowFunc(),
returnedAt: nowFunc(),
ci: ci,
inUse: true,
}
db.addDepLocked(dc, dc)
db.mu.Unlock()
return dc, nil
}
// putConnHook is a hook for testing.
var putConnHook func(*DB, *driverConn)
// noteUnusedDriverStatement notes that ds is no longer used and should
// be closed whenever possible (when c is next not in use), unless c is
// already closed.
func (db *DB) noteUnusedDriverStatement(c *driverConn, ds *driverStmt) {
db.mu.Lock()
defer db.mu.Unlock()
if c.inUse {
c.onPut = append(c.onPut, func() {
ds.Close()
})
} else {
c.Lock()
fc := c.finalClosed
c.Unlock()
if !fc {
ds.Close()
}
}
}
// debugGetPut determines whether getConn & putConn calls' stack traces
// are returned for more verbose crashes.
const debugGetPut = false
// putConn adds a connection to the db's free pool.
// err is optionally the last error that occurred on this connection.
func (db *DB) putConn(dc *driverConn, err error, resetSession bool) {
if !errors.Is(err, driver.ErrBadConn) {
if !dc.validateConnection(resetSession) {
err = driver.ErrBadConn
}
}
db.mu.Lock()
if !dc.inUse {
db.mu.Unlock()
if debugGetPut {
fmt.Printf("putConn(%v) DUPLICATE was: %s\n\nPREVIOUS was: %s", dc, stack(), db.lastPut[dc])
}
panic("sql: connection returned that was never out")
}
if !errors.Is(err, driver.ErrBadConn) && dc.expired(db.maxLifetime) {
db.maxLifetimeClosed++
err = driver.ErrBadConn
}
if debugGetPut {
db.lastPut[dc] = stack()
}
dc.inUse = false
dc.returnedAt = nowFunc()
for _, fn := range dc.onPut {
fn()
}
dc.onPut = nil
if errors.Is(err, driver.ErrBadConn) {
// Don't reuse bad connections.
// Since the conn is considered bad and is being discarded, treat it
// as closed. Don't decrement the open count here, finalClose will
// take care of that.
db.maybeOpenNewConnections()
db.mu.Unlock()
dc.Close()
return
}
if putConnHook != nil {
putConnHook(db, dc)
}
added := db.putConnDBLocked(dc, nil)
db.mu.Unlock()
if !added {
dc.Close()
return
}
}
// putConnDBLocked will satisfy a connRequest if there is one, or it will
// return the *driverConn to the freeConn list if err == nil and the idle
// connection limit would not be exceeded.
// If err != nil, the value of dc is ignored.
// If err == nil, then dc must not equal nil.
// If a connRequest was fulfilled or the *driverConn was placed in the
// freeConn list, then true is returned, otherwise false is returned.
func (db *DB) putConnDBLocked(dc *driverConn, err error) bool {
if db.closed {
return false
}
if db.maxOpen > 0 && db.numOpen > db.maxOpen {
return false
}
if c := len(db.connRequests); c > 0 {
var req chan connRequest
var reqKey uint64
for reqKey, req = range db.connRequests {
break
}
delete(db.connRequests, reqKey) // Remove from pending requests.
if err == nil {
dc.inUse = true
}
req <- connRequest{
conn: dc,
err: err,
}
return true
} else if err == nil && !db.closed {
if db.maxIdleConnsLocked() > len(db.freeConn) {
db.freeConn = append(db.freeConn, dc)
db.startCleanerLocked()
return true
}
db.maxIdleClosed++
}
return false
}
// maxBadConnRetries is the maximum number of retries if the driver returns
// driver.ErrBadConn to signal a broken connection before forcing a new
// connection to be opened.
const maxBadConnRetries = 2
func (db *DB) retry(fn func(strategy connReuseStrategy) error) error {
for i := int64(0); i < maxBadConnRetries; i++ {
err := fn(cachedOrNewConn)
// retry if err is driver.ErrBadConn
if err == nil || !errors.Is(err, driver.ErrBadConn) {
return err
}
}
return fn(alwaysNewConn)
}
// PrepareContext creates a prepared statement for later queries or executions.
// Multiple queries or executions may be run concurrently from the
// returned statement.
// The caller must call the statement's Close method
// when the statement is no longer needed.
//
// The provided context is used for the preparation of the statement, not for the
// execution of the statement.
func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
var stmt *Stmt
var err error
err = db.retry(func(strategy connReuseStrategy) error {
stmt, err = db.prepare(ctx, query, strategy)
return err
})
return stmt, err
}
// Prepare creates a prepared statement for later queries or executions.
// Multiple queries or executions may be run concurrently from the
// returned statement.
// The caller must call the statement's Close method
// when the statement is no longer needed.
//
// Prepare uses context.Background internally; to specify the context, use
// PrepareContext.
func (db *DB) Prepare(query string) (*Stmt, error) {
return db.PrepareContext(context.Background(), query)
}
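// An illustrative sketch of a prepared statement (the table, column, and
// placeholder syntax are driver-dependent assumptions):
//
//	stmt, err := db.PrepareContext(ctx, "SELECT name FROM users WHERE id = ?")
//	...
//	defer stmt.Close()
//	var name string
//	err = stmt.QueryRowContext(ctx, 42).Scan(&name)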
func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrategy) (*Stmt, error) {
// TODO: check if db.driver supports an optional
// driver.Preparer interface and call that instead, if so,
// otherwise we make a prepared statement that's bound
// to a connection, and to execute this prepared statement
// we either need to use this connection (if it's free), else
// get a new connection + re-prepare + execute on that one.
dc, err := db.conn(ctx, strategy)
if err != nil {
return nil, err
}
return db.prepareDC(ctx, dc, dc.releaseConn, nil, query)
}
// prepareDC prepares a query on the driverConn and calls release before
// returning. When cg == nil it implies that a connection pool is used, and
// when cg != nil only a single driver connection is used.
func (db *DB) prepareDC(ctx context.Context, dc *driverConn, release func(error), cg stmtConnGrabber, query string) (*Stmt, error) {
var ds *driverStmt
var err error
defer func() {
release(err)
}()
withLock(dc, func() {
ds, err = dc.prepareLocked(ctx, cg, query)
})
if err != nil {
return nil, err
}
stmt := &Stmt{
db: db,
query: query,
cg: cg,
cgds: ds,
}
// When cg == nil this statement will need to keep track of the various
// connections it is prepared on and record the stmt dependency on
// the DB.
if cg == nil {
stmt.css = []connStmt{{dc, ds}}
stmt.lastNumClosed = db.numClosed.Load()
db.addDep(stmt, stmt)
}
return stmt, nil
}
// ExecContext executes a query without returning any rows.
// The args are for any placeholder parameters in the query.
func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (Result, error) {
var res Result
var err error
err = db.retry(func(strategy connReuseStrategy) error {
res, err = db.exec(ctx, query, args, strategy)
return err
})
return res, err
}
// Exec executes a query without returning any rows.
// The args are for any placeholder parameters in the query.
//
// Exec uses context.Background internally; to specify the context, use
// ExecContext.
func (db *DB) Exec(query string, args ...any) (Result, error) {
return db.ExecContext(context.Background(), query, args...)
}
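// An illustrative sketch (the SQL and placeholder syntax are driver-dependent
// assumptions):
//
//	res, err := db.ExecContext(ctx, "UPDATE users SET active = ? WHERE id = ?", true, 42)
//	...
//	n, err := res.RowsAffected() // rows changed by the UPDATE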
func (db *DB) exec(ctx context.Context, query string, args []any, strategy connReuseStrategy) (Result, error) {
dc, err := db.conn(ctx, strategy)
if err != nil {
return nil, err
}
return db.execDC(ctx, dc, dc.releaseConn, query, args)
}
func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), query string, args []any) (res Result, err error) {
defer func() {
release(err)
}()
execerCtx, ok := dc.ci.(driver.ExecerContext)
var execer driver.Execer
if !ok {
execer, ok = dc.ci.(driver.Execer)
}
if ok {
var nvdargs []driver.NamedValue
var resi driver.Result
withLock(dc, func() {
nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
if err != nil {
return
}
resi, err = ctxDriverExec(ctx, execerCtx, execer, query, nvdargs)
})
if err != driver.ErrSkip {
if err != nil {
return nil, err
}
return driverResult{dc, resi}, nil
}
}
var si driver.Stmt
withLock(dc, func() {
si, err = ctxDriverPrepare(ctx, dc.ci, query)
})
if err != nil {
return nil, err
}
ds := &driverStmt{Locker: dc, si: si}
defer ds.Close()
return resultFromStatement(ctx, dc.ci, ds, args...)
}
// QueryContext executes a query that returns rows, typically a SELECT.
// The args are for any placeholder parameters in the query.
func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) {
var rows *Rows
var err error
err = db.retry(func(strategy connReuseStrategy) error {
rows, err = db.query(ctx, query, args, strategy)
return err
})
return rows, err
}
// Query executes a query that returns rows, typically a SELECT.
// The args are for any placeholder parameters in the query.
//
// Query uses context.Background internally; to specify the context, use
// QueryContext.
func (db *DB) Query(query string, args ...any) (*Rows, error) {
return db.QueryContext(context.Background(), query, args...)
}
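// An illustrative sketch of consuming a *Rows (the SQL and error handling are
// assumed):
//
//	rows, err := db.QueryContext(ctx, "SELECT id, name FROM users")
//	...
//	defer rows.Close()
//	for rows.Next() {
//		var id int64
//		var name string
//		if err := rows.Scan(&id, &name); err != nil {
//			...
//		}
//	}
//	if err := rows.Err(); err != nil {
//		...
//	}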
func (db *DB) query(ctx context.Context, query string, args []any, strategy connReuseStrategy) (*Rows, error) {
dc, err := db.conn(ctx, strategy)
if err != nil {
return nil, err
}
return db.queryDC(ctx, nil, dc, dc.releaseConn, query, args)
}
// queryDC executes a query on the given connection.
// The connection gets released by the releaseConn function.
// The ctx context is from a query method and the txctx context is from an
// optional transaction context.
func (db *DB) queryDC(ctx, txctx context.Context, dc *driverConn, releaseConn func(error), query string, args []any) (*Rows, error) {
queryerCtx, ok := dc.ci.(driver.QueryerContext)
var queryer driver.Queryer
if !ok {
queryer, ok = dc.ci.(driver.Queryer)
}
if ok {
var nvdargs []driver.NamedValue
var rowsi driver.Rows
var err error
withLock(dc, func() {
nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
if err != nil {
return
}
rowsi, err = ctxDriverQuery(ctx, queryerCtx, queryer, query, nvdargs)
})
if err != driver.ErrSkip {
if err != nil {
releaseConn(err)
return nil, err
}
// Note: ownership of dc passes to the *Rows, to be freed
// with releaseConn.
rows := &Rows{
dc: dc,
releaseConn: releaseConn,
rowsi: rowsi,
}
rows.initContextClose(ctx, txctx)
return rows, nil
}
}
var si driver.Stmt
var err error
withLock(dc, func() {
si, err = ctxDriverPrepare(ctx, dc.ci, query)
})
if err != nil {
releaseConn(err)
return nil, err
}
ds := &driverStmt{Locker: dc, si: si}
rowsi, err := rowsiFromStatement(ctx, dc.ci, ds, args...)
if err != nil {
ds.Close()
releaseConn(err)
return nil, err
}
// Note: ownership of dc passes to the *Rows, to be freed
// with releaseConn.
rows := &Rows{
dc: dc,
releaseConn: releaseConn,
rowsi: rowsi,
closeStmt: ds,
}
rows.initContextClose(ctx, txctx)
return rows, nil
}
// QueryRowContext executes a query that is expected to return at most one row.
// QueryRowContext always returns a non-nil value. Errors are deferred until
// Row's Scan method is called.
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
// Otherwise, the *Row's Scan scans the first selected row and discards
// the rest.
func (db *DB) QueryRowContext(ctx context.Context, query string, args ...any) *Row {
rows, err := db.QueryContext(ctx, query, args...)
return &Row{rows: rows, err: err}
}
// QueryRow executes a query that is expected to return at most one row.
// QueryRow always returns a non-nil value. Errors are deferred until
// Row's Scan method is called.
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
// Otherwise, the *Row's Scan scans the first selected row and discards
// the rest.
//
// QueryRow uses context.Background internally; to specify the context, use
// QueryRowContext.
func (db *DB) QueryRow(query string, args ...any) *Row {
return db.QueryRowContext(context.Background(), query, args...)
}
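// An illustrative sketch of the ErrNoRows pattern (the SQL is an assumption):
//
//	var name string
//	err := db.QueryRowContext(ctx, "SELECT name FROM users WHERE id = ?", 42).Scan(&name)
//	switch {
//	case errors.Is(err, sql.ErrNoRows):
//		// no matching row
//	case err != nil:
//		// query or scan failure
//	}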
// BeginTx starts a transaction.
//
// The provided context is used until the transaction is committed or rolled back.
// If the context is canceled, the sql package will roll back
// the transaction. Tx.Commit will return an error if the context provided to
// BeginTx is canceled.
//
// The provided TxOptions is optional and may be nil if defaults should be used.
// If a non-default isolation level is used that the driver doesn't support,
// an error will be returned.
func (db *DB) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) {
var tx *Tx
var err error
err = db.retry(func(strategy connReuseStrategy) error {
tx, err = db.begin(ctx, opts, strategy)
return err
})
return tx, err
}
// Begin starts a transaction. The default isolation level is dependent on
// the driver.
//
// Begin uses context.Background internally; to specify the context, use
// BeginTx.
func (db *DB) Begin() (*Tx, error) {
return db.BeginTx(context.Background(), nil)
}
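// An illustrative sketch of the usual transaction shape (the SQL is an
// assumption; the deferred Rollback is harmless after a successful Commit,
// returning ErrTxDone, which is deliberately ignored):
//
//	tx, err := db.BeginTx(ctx, nil)
//	...
//	defer tx.Rollback()
//	if _, err := tx.ExecContext(ctx, "INSERT INTO audit(msg) VALUES (?)", msg); err != nil {
//		return err
//	}
//	return tx.Commit()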
func (db *DB) begin(ctx context.Context, opts *TxOptions, strategy connReuseStrategy) (tx *Tx, err error) {
dc, err := db.conn(ctx, strategy)
if err != nil {
return nil, err
}
return db.beginDC(ctx, dc, dc.releaseConn, opts)
}
// beginDC starts a transaction. The provided dc must be valid and ready to use.
func (db *DB) beginDC(ctx context.Context, dc *driverConn, release func(error), opts *TxOptions) (tx *Tx, err error) {
var txi driver.Tx
keepConnOnRollback := false
withLock(dc, func() {
_, hasSessionResetter := dc.ci.(driver.SessionResetter)
_, hasConnectionValidator := dc.ci.(driver.Validator)
keepConnOnRollback = hasSessionResetter && hasConnectionValidator
txi, err = ctxDriverBegin(ctx, opts, dc.ci)
})
if err != nil {
release(err)
return nil, err
}
// Schedule the transaction to rollback when the context is canceled.
// The cancel function in Tx will be called after done is set to true.
ctx, cancel := context.WithCancel(ctx)
tx = &Tx{
db: db,
dc: dc,
releaseConn: release,
txi: txi,
cancel: cancel,
keepConnOnRollback: keepConnOnRollback,
ctx: ctx,
}
go tx.awaitDone()
return tx, nil
}
// Driver returns the database's underlying driver.
func (db *DB) Driver() driver.Driver {
return db.connector.Driver()
}
// ErrConnDone is returned by any operation that is performed on a connection
// that has already been returned to the connection pool.
var ErrConnDone = errors.New("sql: connection is already closed")
// Conn returns a single connection by either opening a new connection
// or returning an existing connection from the connection pool. Conn will
// block until either a connection is returned or ctx is canceled.
// Queries run on the same Conn will be run in the same database session.
//
// Every Conn must be returned to the database pool after use by
// calling Conn.Close.
func (db *DB) Conn(ctx context.Context) (*Conn, error) {
var dc *driverConn
var err error
err = db.retry(func(strategy connReuseStrategy) error {
dc, err = db.conn(ctx, strategy)
return err
})
if err != nil {
return nil, err
}
conn := &Conn{
db: db,
dc: dc,
}
return conn, nil
}
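// An illustrative sketch of pinning work to one session (the SET statement is
// a driver/database-dependent assumption):
//
//	conn, err := db.Conn(ctx)
//	...
//	defer conn.Close()
//	if _, err := conn.ExecContext(ctx, "SET search_path TO myschema"); err != nil {
//		...
//	}
//	// Further queries on conn observe the per-session state set above.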
type releaseConn func(error)
// Conn represents a single database connection rather than a pool of database
// connections. Prefer running queries from DB unless there is a specific
// need for a continuous single database connection.
//
// A Conn must call Close to return the connection to the database pool
// and may do so concurrently with a running query.
//
// After a call to Close, all operations on the
// connection fail with ErrConnDone.
type Conn struct {
db *DB
// closemu prevents the connection from closing while there
// is an active query. It is held for read during queries
// and exclusively during close.
closemu sync.RWMutex
// dc is owned until close, at which point
// it's returned to the connection pool.
dc *driverConn
// done transitions from 0 to 1 exactly once, on close.
// Once done, all operations fail with ErrConnDone.
// Use atomic operations when checking the value.
done int32
}
// grabConn takes a context to implement stmtConnGrabber,
// but the context is not used.
func (c *Conn) grabConn(context.Context) (*driverConn, releaseConn, error) {
if atomic.LoadInt32(&c.done) != 0 {
return nil, nil, ErrConnDone
}
c.closemu.RLock()
return c.dc, c.closemuRUnlockCondReleaseConn, nil
}
// PingContext verifies the connection to the database is still alive.
func (c *Conn) PingContext(ctx context.Context) error {
dc, release, err := c.grabConn(ctx)
if err != nil {
return err
}
return c.db.pingDC(ctx, dc, release)
}
// ExecContext executes a query without returning any rows.
// The args are for any placeholder parameters in the query.
func (c *Conn) ExecContext(ctx context.Context, query string, args ...any) (Result, error) {
dc, release, err := c.grabConn(ctx)
if err != nil {
return nil, err
}
return c.db.execDC(ctx, dc, release, query, args)
}
// QueryContext executes a query that returns rows, typically a SELECT.
// The args are for any placeholder parameters in the query.
func (c *Conn) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) {
dc, release, err := c.grabConn(ctx)
if err != nil {
return nil, err
}
return c.db.queryDC(ctx, nil, dc, release, query, args)
}
// QueryRowContext executes a query that is expected to return at most one row.
// QueryRowContext always returns a non-nil value. Errors are deferred until
// Row's Scan method is called.
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
// Otherwise, the *Row's Scan scans the first selected row and discards
// the rest.
func (c *Conn) QueryRowContext(ctx context.Context, query string, args ...any) *Row {
rows, err := c.QueryContext(ctx, query, args...)
return &Row{rows: rows, err: err}
}
// PrepareContext creates a prepared statement for later queries or executions.
// Multiple queries or executions may be run concurrently from the
// returned statement.
// The caller must call the statement's Close method
// when the statement is no longer needed.
//
// The provided context is used for the preparation of the statement, not for the
// execution of the statement.
func (c *Conn) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
dc, release, err := c.grabConn(ctx)
if err != nil {
return nil, err
}
return c.db.prepareDC(ctx, dc, release, c, query)
}
// Raw executes f exposing the underlying driver connection for the
// duration of f. The driverConn must not be used outside of f.
//
// Once f returns and err is not driver.ErrBadConn, the Conn will continue to be usable
// until Conn.Close is called.
func (c *Conn) Raw(f func(driverConn any) error) (err error) {
var dc *driverConn
var release releaseConn
// grabConn takes a context to implement stmtConnGrabber, but the context is not used.
dc, release, err = c.grabConn(nil)
if err != nil {
return
}
fPanic := true
dc.Mutex.Lock()
defer func() {
dc.Mutex.Unlock()
// If f panics fPanic will remain true.
// Ensure an error is passed to release so the connection
// may be discarded.
if fPanic {
err = driver.ErrBadConn
}
release(err)
}()
err = f(dc.ci)
fPanic = false
return
}
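// An illustrative sketch (conn is an open *Conn; mydriver.Conn and
// SpecialOperation are hypothetical driver-specific names):
//
//	err := conn.Raw(func(driverConn any) error {
//		dc, ok := driverConn.(*mydriver.Conn) // hypothetical concrete type
//		if !ok {
//			return errors.New("unexpected driver connection type")
//		}
//		return dc.SpecialOperation() // hypothetical driver-specific method
//	})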
// BeginTx starts a transaction.
//
// The provided context is used until the transaction is committed or rolled back.
// If the context is canceled, the sql package will roll back
// the transaction. Tx.Commit will return an error if the context provided to
// BeginTx is canceled.
//
// The provided TxOptions is optional and may be nil if defaults should be used.
// If a non-default isolation level is used that the driver doesn't support,
// an error will be returned.
func (c *Conn) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) {
dc, release, err := c.grabConn(ctx)
if err != nil {
return nil, err
}
return c.db.beginDC(ctx, dc, release, opts)
}
// closemuRUnlockCondReleaseConn read unlocks closemu
// as the sql operation is done with the dc.
func (c *Conn) closemuRUnlockCondReleaseConn(err error) {
c.closemu.RUnlock()
if errors.Is(err, driver.ErrBadConn) {
c.close(err)
}
}
func (c *Conn) txCtx() context.Context {
return nil
}
func (c *Conn) close(err error) error {
if !atomic.CompareAndSwapInt32(&c.done, 0, 1) {
return ErrConnDone
}
// Lock around releasing the driver connection
// to ensure all queries have been stopped before doing so.
c.closemu.Lock()
defer c.closemu.Unlock()
c.dc.releaseConn(err)
c.dc = nil
c.db = nil
return err
}
// Close returns the connection to the connection pool.
// All operations after a Close will return with ErrConnDone.
// Close is safe to call concurrently with other operations and will
// block until all other operations finish. It may be useful to first
// cancel any used context and then call Close directly after.
func (c *Conn) Close() error {
return c.close(nil)
}
// Tx is an in-progress database transaction.
//
// A transaction must end with a call to Commit or Rollback.
//
// After a call to Commit or Rollback, all operations on the
// transaction fail with ErrTxDone.
//
// The statements prepared for a transaction by calling
// the transaction's Prepare or Stmt methods are closed
// by the call to Commit or Rollback.
type Tx struct {
db *DB
// closemu prevents the transaction from closing while there
// is an active query. It is held for read during queries
// and exclusively during close.
closemu sync.RWMutex
// dc is owned exclusively until Commit or Rollback, at which point
// it's returned with putConn.
dc *driverConn
txi driver.Tx
// releaseConn is called once the Tx is closed to release
// any held driverConn back to the pool.
releaseConn func(error)
// done transitions from false to true exactly once, on Commit
// or Rollback. Once done, all operations fail with
// ErrTxDone.
done atomic.Bool
// keepConnOnRollback is true if the driver knows
// how to reset the connection's session and if need be discard
// the connection.
keepConnOnRollback bool
// All Stmts prepared for this transaction. These will be closed after the
// transaction has been committed or rolled back.
stmts struct {
sync.Mutex
v []*Stmt
}
// cancel is called after done transitions from false to true.
cancel func()
// ctx lives for the life of the transaction.
ctx context.Context
}
// awaitDone blocks until the context in Tx is canceled and rolls back
// the transaction if it's not already done.
func (tx *Tx) awaitDone() {
// Wait for either the transaction to be committed or rolled
// back, or for the associated context to be closed.
<-tx.ctx.Done()
// Discard and close the connection used to ensure the
// transaction is closed and the resources are released. This
// rollback does nothing if the transaction has already been
// committed or rolled back.
// Do not discard the connection if the connection knows
// how to reset the session.
discardConnection := !tx.keepConnOnRollback
tx.rollback(discardConnection)
}
func (tx *Tx) isDone() bool {
return tx.done.Load()
}
// ErrTxDone is returned by any operation that is performed on a transaction
// that has already been committed or rolled back.
var ErrTxDone = errors.New("sql: transaction has already been committed or rolled back")
// close returns the connection to the pool and
// must only be called by Tx.rollback or Tx.Commit while
// tx is already canceled and won't be executed concurrently.
func (tx *Tx) close(err error) {
tx.releaseConn(err)
tx.dc = nil
tx.txi = nil
}
// hookTxGrabConn specifies an optional hook to be called on
// a successful call to (*Tx).grabConn. For tests.
var hookTxGrabConn func()
func (tx *Tx) grabConn(ctx context.Context) (*driverConn, releaseConn, error) {
select {
default:
case <-ctx.Done():
return nil, nil, ctx.Err()
}
// closemu.RLock must come before the check for isDone to prevent the Tx from
// closing while a query is executing.
tx.closemu.RLock()
if tx.isDone() {
tx.closemu.RUnlock()
return nil, nil, ErrTxDone
}
if hookTxGrabConn != nil { // test hook
hookTxGrabConn()
}
return tx.dc, tx.closemuRUnlockRelease, nil
}
func (tx *Tx) txCtx() context.Context {
return tx.ctx
}
// closemuRUnlockRelease is used as a func(error) method value in
// ExecContext and QueryContext. Unlocking in the releaseConn keeps
// the driver conn from being returned to the connection pool until
// the Rows has been closed.
func (tx *Tx) closemuRUnlockRelease(error) {
tx.closemu.RUnlock()
}
// Closes all Stmts prepared for this transaction.
func (tx *Tx) closePrepared() {
tx.stmts.Lock()
defer tx.stmts.Unlock()
for _, stmt := range tx.stmts.v {
stmt.Close()
}
}
// Commit commits the transaction.
func (tx *Tx) Commit() error {
// Check context first to avoid transaction leak.
// If we put this check behind the tx.done CompareAndSwap statement, we can't
// ensure consistency between tx.done and the real COMMIT operation.
select {
default:
case <-tx.ctx.Done():
if tx.done.Load() {
return ErrTxDone
}
return tx.ctx.Err()
}
if !tx.done.CompareAndSwap(false, true) {
return ErrTxDone
}
// Cancel the Tx to release any active R-closemu locks.
// This is safe to do because tx.done has already transitioned
// from false to true. Hold the W-closemu lock prior to committing
// to ensure no other connection has an active query.
tx.cancel()
tx.closemu.Lock()
tx.closemu.Unlock()
var err error
withLock(tx.dc, func() {
err = tx.txi.Commit()
})
if !errors.Is(err, driver.ErrBadConn) {
tx.closePrepared()
}
tx.close(err)
return err
}
var rollbackHook func()
// rollback aborts the transaction and optionally forces the pool to discard
// the connection.
func (tx *Tx) rollback(discardConn bool) error {
if !tx.done.CompareAndSwap(false, true) {
return ErrTxDone
}
if rollbackHook != nil {
rollbackHook()
}
// Cancel the Tx to release any active R-closemu locks.
// This is safe to do because tx.done has already transitioned
// from 0 to 1. Hold the W-closemu lock prior to rollback
// to ensure no other connection has an active query.
tx.cancel()
tx.closemu.Lock()
tx.closemu.Unlock()
var err error
withLock(tx.dc, func() {
err = tx.txi.Rollback()
})
if !errors.Is(err, driver.ErrBadConn) {
tx.closePrepared()
}
if discardConn {
err = driver.ErrBadConn
}
tx.close(err)
return err
}
// Rollback aborts the transaction.
func (tx *Tx) Rollback() error {
return tx.rollback(false)
}
// PrepareContext creates a prepared statement for use within a transaction.
//
// The returned statement operates within the transaction and will be closed
// when the transaction has been committed or rolled back.
//
// To use an existing prepared statement on this transaction, see Tx.Stmt.
//
// The provided context will be used for the preparation of the statement, not
// for the execution of the returned statement. The returned statement
// will run in the transaction context.
func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
dc, release, err := tx.grabConn(ctx)
if err != nil {
return nil, err
}
stmt, err := tx.db.prepareDC(ctx, dc, release, tx, query)
if err != nil {
return nil, err
}
tx.stmts.Lock()
tx.stmts.v = append(tx.stmts.v, stmt)
tx.stmts.Unlock()
return stmt, nil
}
// Prepare creates a prepared statement for use within a transaction.
//
// The returned statement operates within the transaction and will be closed
// when the transaction has been committed or rolled back.
//
// To use an existing prepared statement on this transaction, see Tx.Stmt.
//
// Prepare uses context.Background internally; to specify the context, use
// PrepareContext.
func (tx *Tx) Prepare(query string) (*Stmt, error) {
return tx.PrepareContext(context.Background(), query)
}
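// Illustrative sketch (editor's example): preparing a statement once inside
// the transaction and executing it repeatedly; the statement is closed
// automatically when the transaction is committed or rolled back. The query
// is hypothetical.
//
// stmt, err := tx.PrepareContext(ctx, "INSERT INTO t(v) VALUES (?)")
// if err != nil {
// 	return err
// }
// for _, v := range values {
// 	if _, err := stmt.ExecContext(ctx, v); err != nil {
// 		return err
// 	}
// }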
// StmtContext returns a transaction-specific prepared statement from
// an existing statement.
//
// Example:
//
// updateMoney, err := db.Prepare("UPDATE balance SET money=money+? WHERE id=?")
// ...
// tx, err := db.Begin()
// ...
// res, err := tx.StmtContext(ctx, updateMoney).Exec(123.45, 98293203)
//
// The provided context is used for the preparation of the statement, not for the
// execution of the statement.
//
// The returned statement operates within the transaction and will be closed
// when the transaction has been committed or rolled back.
func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
dc, release, err := tx.grabConn(ctx)
if err != nil {
return &Stmt{stickyErr: err}
}
defer release(nil)
if tx.db != stmt.db {
return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}
}
var si driver.Stmt
var parentStmt *Stmt
stmt.mu.Lock()
if stmt.closed || stmt.cg != nil {
// If the statement has been closed or already belongs to a
// transaction, we can't reuse it in this connection.
// Since tx.StmtContext should never need to be called with a
// Stmt already belonging to tx, we ignore this edge case and
// simply re-prepare the statement here. No need to add code
// complexity for this.
stmt.mu.Unlock()
withLock(dc, func() {
si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query)
})
if err != nil {
return &Stmt{stickyErr: err}
}
} else {
stmt.removeClosedStmtLocked()
// See if the statement has already been prepared on this connection,
// and reuse it if possible.
for _, v := range stmt.css {
if v.dc == dc {
si = v.ds.si
break
}
}
stmt.mu.Unlock()
if si == nil {
var ds *driverStmt
withLock(dc, func() {
ds, err = stmt.prepareOnConnLocked(ctx, dc)
})
if err != nil {
return &Stmt{stickyErr: err}
}
si = ds.si
}
parentStmt = stmt
}
txs := &Stmt{
db: tx.db,
cg: tx,
cgds: &driverStmt{
Locker: dc,
si: si,
},
parentStmt: parentStmt,
query: stmt.query,
}
if parentStmt != nil {
tx.db.addDep(parentStmt, txs)
}
tx.stmts.Lock()
tx.stmts.v = append(tx.stmts.v, txs)
tx.stmts.Unlock()
return txs
}
// Stmt returns a transaction-specific prepared statement from
// an existing statement.
//
// Example:
//
// updateMoney, err := db.Prepare("UPDATE balance SET money=money+? WHERE id=?")
// ...
// tx, err := db.Begin()
// ...
// res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203)
//
// The returned statement operates within the transaction and will be closed
// when the transaction has been committed or rolled back.
//
// Stmt uses context.Background internally; to specify the context, use
// StmtContext.
func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
return tx.StmtContext(context.Background(), stmt)
}
// ExecContext executes a query that doesn't return rows.
// For example: an INSERT or UPDATE.
func (tx *Tx) ExecContext(ctx context.Context, query string, args ...any) (Result, error) {
dc, release, err := tx.grabConn(ctx)
if err != nil {
return nil, err
}
return tx.db.execDC(ctx, dc, release, query, args)
}
// Exec executes a query that doesn't return rows.
// For example: an INSERT or UPDATE.
//
// Exec uses context.Background internally; to specify the context, use
// ExecContext.
func (tx *Tx) Exec(query string, args ...any) (Result, error) {
return tx.ExecContext(context.Background(), query, args...)
}
// QueryContext executes a query that returns rows, typically a SELECT.
func (tx *Tx) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) {
dc, release, err := tx.grabConn(ctx)
if err != nil {
return nil, err
}
return tx.db.queryDC(ctx, tx.ctx, dc, release, query, args)
}
// Query executes a query that returns rows, typically a SELECT.
//
// Query uses context.Background internally; to specify the context, use
// QueryContext.
func (tx *Tx) Query(query string, args ...any) (*Rows, error) {
return tx.QueryContext(context.Background(), query, args...)
}
// QueryRowContext executes a query that is expected to return at most one row.
// QueryRowContext always returns a non-nil value. Errors are deferred until
// Row's Scan method is called.
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
// Otherwise, the *Row's Scan scans the first selected row and discards
// the rest.
func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...any) *Row {
rows, err := tx.QueryContext(ctx, query, args...)
return &Row{rows: rows, err: err}
}
// QueryRow executes a query that is expected to return at most one row.
// QueryRow always returns a non-nil value. Errors are deferred until
// Row's Scan method is called.
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
// Otherwise, the *Row's Scan scans the first selected row and discards
// the rest.
//
// QueryRow uses context.Background internally; to specify the context, use
// QueryRowContext.
func (tx *Tx) QueryRow(query string, args ...any) *Row {
return tx.QueryRowContext(context.Background(), query, args...)
}
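// Illustrative sketch (editor's example): reading a single value inside the
// transaction. Note that ErrNoRows surfaces from Scan, not from
// QueryRowContext itself. The table and column names are hypothetical.
//
// var money int64
// err := tx.QueryRowContext(ctx, "SELECT money FROM balance WHERE id=?", id).Scan(&money)
// if errors.Is(err, ErrNoRows) {
// 	// no matching row
// }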
// connStmt is a prepared statement on a particular connection.
type connStmt struct {
dc *driverConn
ds *driverStmt
}
// stmtConnGrabber represents a Tx or Conn that will return the underlying
// driverConn and release function.
type stmtConnGrabber interface {
// grabConn returns the driverConn and the associated release function
// that must be called when the operation completes.
grabConn(context.Context) (*driverConn, releaseConn, error)
// txCtx returns the transaction context if available.
// The returned context should be selected on along with
// any query context when awaiting a cancel.
txCtx() context.Context
}
var (
_ stmtConnGrabber = &Tx{}
_ stmtConnGrabber = &Conn{}
)
// Stmt is a prepared statement.
// A Stmt is safe for concurrent use by multiple goroutines.
//
// If a Stmt is prepared on a Tx or Conn, it will be bound to a single
// underlying connection forever. If the Tx or Conn closes, the Stmt will
// become unusable and all operations will return an error.
// If a Stmt is prepared on a DB, it will remain usable for the lifetime of the
// DB. When the Stmt needs to execute on a new underlying connection, it will
// prepare itself on the new connection automatically.
type Stmt struct {
// Immutable:
db *DB // where we came from
query string // that created the Stmt
stickyErr error // if non-nil, this error is returned for all operations
closemu sync.RWMutex // held exclusively during close, for read otherwise.
// If Stmt is prepared on a Tx or Conn then cg is present and will
// only ever grab a connection from cg.
// If cg is nil then the Stmt must grab an arbitrary connection
// from db and determine if it must prepare the stmt again by
// inspecting css.
cg stmtConnGrabber
cgds *driverStmt
// parentStmt is set when a transaction-specific statement
// is requested from an identical statement prepared on the same
// conn. parentStmt is used to track the dependency of this statement
// on its originating ("parent") statement so that parentStmt may
// be closed by the user without them having to know whether or not
// any transactions are still using it.
parentStmt *Stmt
mu sync.Mutex // protects the rest of the fields
closed bool
// css is a list of underlying driver statement interfaces
// that are valid on particular connections. This is only
// used if cg == nil and one is found that has idle
// connections. If cg != nil, cgds is always used.
css []connStmt
// lastNumClosed is copied from db.numClosed when Stmt is created
// without a tx, and again whenever closed connections in css are removed.
lastNumClosed uint64
}
// ExecContext executes a prepared statement with the given arguments and
// returns a Result summarizing the effect of the statement.
func (s *Stmt) ExecContext(ctx context.Context, args ...any) (Result, error) {
s.closemu.RLock()
defer s.closemu.RUnlock()
var res Result
err := s.db.retry(func(strategy connReuseStrategy) error {
dc, releaseConn, ds, err := s.connStmt(ctx, strategy)
if err != nil {
return err
}
res, err = resultFromStatement(ctx, dc.ci, ds, args...)
releaseConn(err)
return err
})
return res, err
}
// Exec executes a prepared statement with the given arguments and
// returns a Result summarizing the effect of the statement.
//
// Exec uses context.Background internally; to specify the context, use
// ExecContext.
func (s *Stmt) Exec(args ...any) (Result, error) {
return s.ExecContext(context.Background(), args...)
}
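// Illustrative sketch (editor's example): a Stmt prepared directly on a DB
// is safe for concurrent use; each call grabs its own connection and the
// statement is re-prepared transparently on new connections as needed. The
// query is hypothetical.
//
// stmt, err := db.Prepare("INSERT INTO logs(msg) VALUES (?)")
// ...
// go func() { stmt.Exec("from goroutine A") }()
// go func() { stmt.Exec("from goroutine B") }()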
func resultFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...any) (Result, error) {
ds.Lock()
defer ds.Unlock()
dargs, err := driverArgsConnLocked(ci, ds, args)
if err != nil {
return nil, err
}
resi, err := ctxDriverStmtExec(ctx, ds.si, dargs)
if err != nil {
return nil, err
}
return driverResult{ds.Locker, resi}, nil
}
// removeClosedStmtLocked removes closed conns in s.css.
//
// To avoid lock contention on DB.mu, we do it only when
// s.db.numClosed - s.lastNumClosed is large enough.
func (s *Stmt) removeClosedStmtLocked() {
t := len(s.css)/2 + 1
if t > 10 {
t = 10
}
dbClosed := s.db.numClosed.Load()
if dbClosed-s.lastNumClosed < uint64(t) {
return
}
s.db.mu.Lock()
for i := 0; i < len(s.css); i++ {
if s.css[i].dc.dbmuClosed {
s.css[i] = s.css[len(s.css)-1]
s.css = s.css[:len(s.css)-1]
i--
}
}
s.db.mu.Unlock()
s.lastNumClosed = dbClosed
}
// connStmt returns a free driver connection on which to execute the
// statement, a function to call to release the connection, and a
// statement bound to that connection.
func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (dc *driverConn, releaseConn func(error), ds *driverStmt, err error) {
if err = s.stickyErr; err != nil {
return
}
s.mu.Lock()
if s.closed {
s.mu.Unlock()
err = errors.New("sql: statement is closed")
return
}
// In a transaction or connection, we always use the connection that the
// stmt was created on.
if s.cg != nil {
s.mu.Unlock()
dc, releaseConn, err = s.cg.grabConn(ctx) // blocks, waiting for the connection.
if err != nil {
return
}
return dc, releaseConn, s.cgds, nil
}
s.removeClosedStmtLocked()
s.mu.Unlock()
dc, err = s.db.conn(ctx, strategy)
if err != nil {
return nil, nil, nil, err
}
s.mu.Lock()
for _, v := range s.css {
if v.dc == dc {
s.mu.Unlock()
return dc, dc.releaseConn, v.ds, nil
}
}
s.mu.Unlock()
// No luck; we need to prepare the statement on this connection
withLock(dc, func() {
ds, err = s.prepareOnConnLocked(ctx, dc)
})
if err != nil {
dc.releaseConn(err)
return nil, nil, nil, err
}
return dc, dc.releaseConn, ds, nil
}
// prepareOnConnLocked prepares the query in Stmt s on dc and adds it to the list of
// open connStmt on the statement. It assumes the caller is holding the lock on dc.
func (s *Stmt) prepareOnConnLocked(ctx context.Context, dc *driverConn) (*driverStmt, error) {
si, err := dc.prepareLocked(ctx, s.cg, s.query)
if err != nil {
return nil, err
}
cs := connStmt{dc, si}
s.mu.Lock()
s.css = append(s.css, cs)
s.mu.Unlock()
return cs.ds, nil
}
// QueryContext executes a prepared query statement with the given arguments
// and returns the query results as a *Rows.
func (s *Stmt) QueryContext(ctx context.Context, args ...any) (*Rows, error) {
s.closemu.RLock()
defer s.closemu.RUnlock()
var rowsi driver.Rows
var rows *Rows
err := s.db.retry(func(strategy connReuseStrategy) error {
dc, releaseConn, ds, err := s.connStmt(ctx, strategy)
if err != nil {
return err
}
rowsi, err = rowsiFromStatement(ctx, dc.ci, ds, args...)
if err == nil {
// Note: ownership of ci passes to the *Rows, to be freed
// with releaseConn.
rows = &Rows{
dc: dc,
rowsi: rowsi,
// releaseConn set below
}
// addDep must be added before initContextClose or it could attempt
// to removeDep before it has been added.
s.db.addDep(s, rows)
// releaseConn must be set before initContextClose or it could
// release the connection before it is set.
rows.releaseConn = func(err error) {
releaseConn(err)
s.db.removeDep(s, rows)
}
var txctx context.Context
if s.cg != nil {
txctx = s.cg.txCtx()
}
rows.initContextClose(ctx, txctx)
return nil
}
releaseConn(err)
return err
})
return rows, err
}
// Query executes a prepared query statement with the given arguments
// and returns the query results as a *Rows.
//
// Query uses context.Background internally; to specify the context, use
// QueryContext.
func (s *Stmt) Query(args ...any) (*Rows, error) {
return s.QueryContext(context.Background(), args...)
}
func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...any) (driver.Rows, error) {
ds.Lock()
defer ds.Unlock()
dargs, err := driverArgsConnLocked(ci, ds, args)
if err != nil {
return nil, err
}
return ctxDriverStmtQuery(ctx, ds.si, dargs)
}
// QueryRowContext executes a prepared query statement with the given arguments.
// If an error occurs during the execution of the statement, that error will
// be returned by a call to Scan on the returned *Row, which is always non-nil.
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
// Otherwise, the *Row's Scan scans the first selected row and discards
// the rest.
func (s *Stmt) QueryRowContext(ctx context.Context, args ...any) *Row {
rows, err := s.QueryContext(ctx, args...)
if err != nil {
return &Row{err: err}
}
return &Row{rows: rows}
}
// QueryRow executes a prepared query statement with the given arguments.
// If an error occurs during the execution of the statement, that error will
// be returned by a call to Scan on the returned *Row, which is always non-nil.
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
// Otherwise, the *Row's Scan scans the first selected row and discards
// the rest.
//
// Example usage:
//
// var name string
// err := nameByUseridStmt.QueryRow(id).Scan(&name)
//
// QueryRow uses context.Background internally; to specify the context, use
// QueryRowContext.
func (s *Stmt) QueryRow(args ...any) *Row {
return s.QueryRowContext(context.Background(), args...)
}
// Close closes the statement.
func (s *Stmt) Close() error {
s.closemu.Lock()
defer s.closemu.Unlock()
if s.stickyErr != nil {
return s.stickyErr
}
s.mu.Lock()
if s.closed {
s.mu.Unlock()
return nil
}
s.closed = true
txds := s.cgds
s.cgds = nil
s.mu.Unlock()
if s.cg == nil {
return s.db.removeDep(s, s)
}
if s.parentStmt != nil {
// If parentStmt is set, we must not close txds since it's stored
// in the css array of the parentStmt.
return s.db.removeDep(s.parentStmt, s)
}
return txds.Close()
}
func (s *Stmt) finalClose() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.css != nil {
for _, v := range s.css {
s.db.noteUnusedDriverStatement(v.dc, v.ds)
v.dc.removeOpenStmt(v.ds)
}
s.css = nil
}
return nil
}
// Rows is the result of a query. Its cursor starts before the first row
// of the result set. Use Next to advance from row to row.
type Rows struct {
dc *driverConn // owned; must call releaseConn when closed to release
releaseConn func(error)
rowsi driver.Rows
cancel func() // called when Rows is closed, may be nil.
closeStmt *driverStmt // if non-nil, statement to Close on close
// closemu prevents Rows from closing while there
// is an active streaming result. It is held for read during non-close operations
// and exclusively during close.
//
// closemu guards lasterr and closed.
closemu sync.RWMutex
closed bool
lasterr error // non-nil only if closed is true
// lastcols is only used in Scan, Next, and NextResultSet which are expected
// not to be called concurrently.
lastcols []driver.Value
}
// lasterrOrErrLocked returns either lasterr or the provided err.
// rs.closemu must be read-locked.
func (rs *Rows) lasterrOrErrLocked(err error) error {
if rs.lasterr != nil && rs.lasterr != io.EOF {
return rs.lasterr
}
return err
}
// bypassRowsAwaitDone is only used for testing.
// If true, it will not close the Rows automatically from the context.
var bypassRowsAwaitDone = false
func (rs *Rows) initContextClose(ctx, txctx context.Context) {
if ctx.Done() == nil && (txctx == nil || txctx.Done() == nil) {
return
}
if bypassRowsAwaitDone {
return
}
ctx, rs.cancel = context.WithCancel(ctx)
go rs.awaitDone(ctx, txctx)
}
// awaitDone blocks until either ctx or txctx is canceled. The ctx is provided
// from the query context and is canceled when the query Rows is closed.
// If the query was issued in a transaction, the transaction's context
// is also provided in txctx to ensure Rows is closed if the Tx is closed.
func (rs *Rows) awaitDone(ctx, txctx context.Context) {
var txctxDone <-chan struct{}
if txctx != nil {
txctxDone = txctx.Done()
}
select {
case <-ctx.Done():
case <-txctxDone:
}
rs.close(ctx.Err())
}
// Next prepares the next result row for reading with the Scan method. It
// returns true on success, or false if there is no next result row or an error
// happened while preparing it. Err should be consulted to distinguish between
// the two cases.
//
// Every call to Scan, even the first one, must be preceded by a call to Next.
func (rs *Rows) Next() bool {
var doClose, ok bool
withLock(rs.closemu.RLocker(), func() {
doClose, ok = rs.nextLocked()
})
if doClose {
rs.Close()
}
return ok
}
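// Illustrative sketch (editor's example): the canonical iteration loop.
// Next must precede every Scan, and Err must be checked after the loop to
// distinguish normal end-of-rows from an iteration error. The query is
// hypothetical.
//
// rows, err := db.QueryContext(ctx, "SELECT id, name FROM users")
// if err != nil {
// 	return err
// }
// defer rows.Close()
// for rows.Next() {
// 	var id int64
// 	var name string
// 	if err := rows.Scan(&id, &name); err != nil {
// 		return err
// 	}
// }
// return rows.Err()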
func (rs *Rows) nextLocked() (doClose, ok bool) {
if rs.closed {
return false, false
}
// Lock the driver connection before calling the driver interface
// rowsi to prevent a Tx from rolling back the connection at the same time.
rs.dc.Lock()
defer rs.dc.Unlock()
if rs.lastcols == nil {
rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns()))
}
rs.lasterr = rs.rowsi.Next(rs.lastcols)
if rs.lasterr != nil {
// Close the connection if there is a driver error.
if rs.lasterr != io.EOF {
return true, false
}
nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
if !ok {
return true, false
}
// The driver is at the end of the current result set.
// Test to see if there is another result set after the current one.
// Only close Rows if there is no further result sets to read.
if !nextResultSet.HasNextResultSet() {
doClose = true
}
return doClose, false
}
return false, true
}
// NextResultSet prepares the next result set for reading. It reports whether
// there are further result sets, returning false if there is no further
// result set or if there is an error advancing to it. The Err method should
// be consulted to distinguish between the two cases.
//
// After calling NextResultSet, the Next method should always be called before
// scanning. If there are further result sets they may not have rows in the result
// set.
func (rs *Rows) NextResultSet() bool {
var doClose bool
defer func() {
if doClose {
rs.Close()
}
}()
rs.closemu.RLock()
defer rs.closemu.RUnlock()
if rs.closed {
return false
}
rs.lastcols = nil
nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
if !ok {
doClose = true
return false
}
// Lock the driver connection before calling the driver interface
// rowsi to prevent a Tx from rolling back the connection at the same time.
rs.dc.Lock()
defer rs.dc.Unlock()
rs.lasterr = nextResultSet.NextResultSet()
if rs.lasterr != nil {
doClose = true
return false
}
return true
}
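// Illustrative sketch (editor's example): draining several result sets from
// one query (for drivers that support it, e.g. via a stored procedure). Each
// result set is read to completion before advancing to the next.
//
// for {
// 	for rows.Next() {
// 		// ... rows.Scan(...)
// 	}
// 	if !rows.NextResultSet() {
// 		break
// 	}
// }
// if err := rows.Err(); err != nil {
// 	return err
// }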
// Err returns the error, if any, that was encountered during iteration.
// Err may be called after an explicit or implicit Close.
func (rs *Rows) Err() error {
rs.closemu.RLock()
defer rs.closemu.RUnlock()
return rs.lasterrOrErrLocked(nil)
}
var errRowsClosed = errors.New("sql: Rows are closed")
var errNoRows = errors.New("sql: no Rows available")
// Columns returns the column names.
// Columns returns an error if the rows are closed.
func (rs *Rows) Columns() ([]string, error) {
rs.closemu.RLock()
defer rs.closemu.RUnlock()
if rs.closed {
return nil, rs.lasterrOrErrLocked(errRowsClosed)
}
if rs.rowsi == nil {
return nil, rs.lasterrOrErrLocked(errNoRows)
}
rs.dc.Lock()
defer rs.dc.Unlock()
return rs.rowsi.Columns(), nil
}
// ColumnTypes returns column information such as column type, length,
// and nullable. Some information may not be available from some drivers.
func (rs *Rows) ColumnTypes() ([]*ColumnType, error) {
rs.closemu.RLock()
defer rs.closemu.RUnlock()
if rs.closed {
return nil, rs.lasterrOrErrLocked(errRowsClosed)
}
if rs.rowsi == nil {
return nil, rs.lasterrOrErrLocked(errNoRows)
}
rs.dc.Lock()
defer rs.dc.Unlock()
return rowsColumnInfoSetupConnLocked(rs.rowsi), nil
}
// ColumnType contains the name and type of a column.
type ColumnType struct {
name string
hasNullable bool
hasLength bool
hasPrecisionScale bool
nullable bool
length int64
databaseType string
precision int64
scale int64
scanType reflect.Type
}
// Name returns the name or alias of the column.
func (ci *ColumnType) Name() string {
return ci.name
}
// Length returns the column type length for variable length column types such
// as text and binary field types. If the type length is unbounded the value will
// be math.MaxInt64 (any database limits will still apply).
// If the column type is not variable length, such as an int, or if not supported
// by the driver, ok is false.
func (ci *ColumnType) Length() (length int64, ok bool) {
return ci.length, ci.hasLength
}
// DecimalSize returns the scale and precision of a decimal type.
// If not applicable or if not supported ok is false.
func (ci *ColumnType) DecimalSize() (precision, scale int64, ok bool) {
return ci.precision, ci.scale, ci.hasPrecisionScale
}
// ScanType returns a Go type suitable for scanning into using Rows.Scan.
// If a driver does not support this property ScanType will return
// the type of an empty interface.
func (ci *ColumnType) ScanType() reflect.Type {
return ci.scanType
}
// Nullable reports whether the column may be null.
// If a driver does not support this property ok will be false.
func (ci *ColumnType) Nullable() (nullable, ok bool) {
return ci.nullable, ci.hasNullable
}
// DatabaseTypeName returns the database system name of the column type. If an empty
// string is returned, then the driver type name is not supported.
// Consult your driver documentation for a list of driver data types. Length specifiers
// are not included.
// Common type names include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL",
// "INT", and "BIGINT".
func (ci *ColumnType) DatabaseTypeName() string {
return ci.databaseType
}
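// Illustrative sketch (editor's example): inspecting column metadata before
// scanning. Which properties are populated depends on the driver; the ok
// results report whether each value is available.
//
// cts, err := rows.ColumnTypes()
// if err != nil {
// 	return err
// }
// for _, ct := range cts {
// 	if nullable, ok := ct.Nullable(); ok && nullable {
// 		// pick a Null* destination type for column ct.Name()
// 	}
// }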
func rowsColumnInfoSetupConnLocked(rowsi driver.Rows) []*ColumnType {
names := rowsi.Columns()
list := make([]*ColumnType, len(names))
for i := range list {
ci := &ColumnType{
name: names[i],
}
list[i] = ci
if prop, ok := rowsi.(driver.RowsColumnTypeScanType); ok {
ci.scanType = prop.ColumnTypeScanType(i)
} else {
ci.scanType = reflect.TypeOf(new(any)).Elem()
}
if prop, ok := rowsi.(driver.RowsColumnTypeDatabaseTypeName); ok {
ci.databaseType = prop.ColumnTypeDatabaseTypeName(i)
}
if prop, ok := rowsi.(driver.RowsColumnTypeLength); ok {
ci.length, ci.hasLength = prop.ColumnTypeLength(i)
}
if prop, ok := rowsi.(driver.RowsColumnTypeNullable); ok {
ci.nullable, ci.hasNullable = prop.ColumnTypeNullable(i)
}
if prop, ok := rowsi.(driver.RowsColumnTypePrecisionScale); ok {
ci.precision, ci.scale, ci.hasPrecisionScale = prop.ColumnTypePrecisionScale(i)
}
}
return list
}
// Scan copies the columns in the current row into the values pointed
// at by dest. The number of values in dest must be the same as the
// number of columns in Rows.
//
// Scan converts columns read from the database into the following
// common Go types and special types provided by the sql package:
//
// *string
// *[]byte
// *int, *int8, *int16, *int32, *int64
// *uint, *uint8, *uint16, *uint32, *uint64
// *bool
// *float32, *float64
// *interface{}
// *RawBytes
// *Rows (cursor value)
// any type implementing Scanner (see Scanner docs)
//
// In the simplest case, if the type of the value from the source
// column is an integer, bool or string type T and dest is of type *T,
// Scan simply assigns the value through the pointer.
//
// Scan also converts between string and numeric types, as long as no
// information would be lost. While Scan stringifies all numbers
// scanned from numeric database columns into *string, scans into
// numeric types are checked for overflow. For example, a float64 with
// value 300 or a string with value "300" can scan into a uint16, but
// not into a uint8, though float64(255) or "255" can scan into a
// uint8. One exception is that scans of some float64 numbers to
// strings may lose information when stringifying. In general, scan
// floating point columns into *float64.
//
// If a dest argument has type *[]byte, Scan saves in that argument a
// copy of the corresponding data. The copy is owned by the caller and
// can be modified and held indefinitely. The copy can be avoided by
// using an argument of type *RawBytes instead; see the documentation
// for RawBytes for restrictions on its use.
//
// If an argument has type *interface{}, Scan copies the value
// provided by the underlying driver without conversion. When scanning
// from a source value of type []byte to *interface{}, a copy of the
// slice is made and the caller owns the result.
//
// Source values of type time.Time may be scanned into values of type
// *time.Time, *interface{}, *string, or *[]byte. When converting to
// the latter two, time.RFC3339Nano is used.
//
// Source values of type bool may be scanned into types *bool,
// *interface{}, *string, *[]byte, or *RawBytes.
//
// For scanning into *bool, the source may be true, false, 1, 0, or
// string inputs parseable by strconv.ParseBool.
//
// Scan can also convert a cursor returned from a query, such as
// "select cursor(select * from my_table) from dual", into a
// *Rows value that can itself be scanned from. The parent
// select query will close any cursor *Rows if the parent *Rows is closed.
//
// If any of the first arguments implementing Scanner returns an error,
// that error will be wrapped in the returned error.
func (rs *Rows) Scan(dest ...any) error {
rs.closemu.RLock()
if rs.lasterr != nil && rs.lasterr != io.EOF {
rs.closemu.RUnlock()
return rs.lasterr
}
if rs.closed {
err := rs.lasterrOrErrLocked(errRowsClosed)
rs.closemu.RUnlock()
return err
}
rs.closemu.RUnlock()
if rs.lastcols == nil {
return errors.New("sql: Scan called without calling Next")
}
if len(dest) != len(rs.lastcols) {
return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest))
}
for i, sv := range rs.lastcols {
err := convertAssignRows(dest[i], sv, rs)
if err != nil {
return fmt.Errorf(`sql: Scan error on column index %d, name %q: %w`, i, rs.rowsi.Columns()[i], err)
}
}
return nil
}
// rowsCloseHook returns a function so tests may install the
// hook through a test only mutex.
var rowsCloseHook = func() func(*Rows, *error) { return nil }
// Close closes the Rows, preventing further enumeration. If Next is called
// and returns false and there are no further result sets,
// the Rows are closed automatically and it will suffice to check the
// result of Err. Close is idempotent and does not affect the result of Err.
func (rs *Rows) Close() error {
return rs.close(nil)
}
func (rs *Rows) close(err error) error {
rs.closemu.Lock()
defer rs.closemu.Unlock()
if rs.closed {
return nil
}
rs.closed = true
if rs.lasterr == nil {
rs.lasterr = err
}
withLock(rs.dc, func() {
err = rs.rowsi.Close()
})
if fn := rowsCloseHook(); fn != nil {
fn(rs, &err)
}
if rs.cancel != nil {
rs.cancel()
}
if rs.closeStmt != nil {
rs.closeStmt.Close()
}
rs.releaseConn(err)
rs.lasterr = rs.lasterrOrErrLocked(err)
return err
}
// Row is the result of calling QueryRow to select a single row.
type Row struct {
// One of these two will be non-nil:
err error // deferred error for easy chaining
rows *Rows
}
// Scan copies the columns from the matched row into the values
// pointed at by dest. See the documentation on Rows.Scan for details.
// If more than one row matches the query,
// Scan uses the first row and discards the rest. If no row matches
// the query, Scan returns ErrNoRows.
func (r *Row) Scan(dest ...any) error {
if r.err != nil {
return r.err
}
// TODO(bradfitz): for now we need to defensively clone all
// []byte that the driver returned (not permitting
// *RawBytes in Rows.Scan), since we're about to close
// the Rows in our defer, when we return from this function.
// The contract with the driver.Next(...) interface is that it
// can return slices into read-only temporary memory that's
// only valid until the next Scan/Close. But the TODO is that
// for a lot of drivers, this copy will be unnecessary. We
// should provide an optional interface for drivers to
// implement to say, "don't worry, the []bytes that I return
// from Next will not be modified again" (for instance, if
// they were obtained from the network anyway). But for now we
// don't care.
defer r.rows.Close()
for _, dp := range dest {
if _, ok := dp.(*RawBytes); ok {
return errors.New("sql: RawBytes isn't allowed on Row.Scan")
}
}
if !r.rows.Next() {
if err := r.rows.Err(); err != nil {
return err
}
return ErrNoRows
}
err := r.rows.Scan(dest...)
if err != nil {
return err
}
// Make sure the query can be processed to completion with no errors.
return r.rows.Close()
}
// Err provides a way for wrapping packages to check for
// query errors without calling Scan.
// Err returns the error, if any, that was encountered while running the query.
// If this error is not nil, this error will also be returned from Scan.
func (r *Row) Err() error {
return r.err
}
// A Result summarizes an executed SQL command.
type Result interface {
// LastInsertId returns the integer generated by the database
// in response to a command. Typically this will be from an
// "auto increment" column when inserting a new row. Not all
// databases support this feature, and the syntax of such
// statements varies.
LastInsertId() (int64, error)
// RowsAffected returns the number of rows affected by an
// update, insert, or delete. Not every database or database
// driver may support this.
RowsAffected() (int64, error)
}
type driverResult struct {
sync.Locker // the *driverConn
resi driver.Result
}
func (dr driverResult) LastInsertId() (int64, error) {
dr.Lock()
defer dr.Unlock()
return dr.resi.LastInsertId()
}
func (dr driverResult) RowsAffected() (int64, error) {
dr.Lock()
defer dr.Unlock()
return dr.resi.RowsAffected()
}
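// Illustrative sketch (editor's example): consuming a Result. Either
// accessor may return an error on databases or drivers that do not support
// the feature. The query is hypothetical.
//
// res, err := db.Exec("DELETE FROM sessions WHERE expired = 1")
// if err != nil {
// 	return err
// }
// if n, err := res.RowsAffected(); err == nil {
// 	log.Printf("deleted %d rows", n)
// }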
func stack() string {
var buf [2 << 10]byte
return string(buf[:runtime.Stack(buf[:], false)])
}
// withLock runs while holding lk.
func withLock(lk sync.Locker, fn func()) {
lk.Lock()
defer lk.Unlock() // in case fn panics
fn()
}
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package buildinfo provides access to information embedded in a Go binary
// about how it was built. This includes the Go toolchain version, and the
// set of modules used (for binaries built in module mode).
//
// Build information is available for the currently running binary in
// runtime/debug.ReadBuildInfo.
package buildinfo
import (
"bytes"
"debug/elf"
"debug/macho"
"debug/pe"
"debug/plan9obj"
"encoding/binary"
"errors"
"fmt"
"internal/xcoff"
"io"
"io/fs"
"os"
"runtime/debug"
)
// Type alias for build info. We cannot move the types here, since
// runtime/debug would need to import this package, which would make it
// a much larger dependency.
type BuildInfo = debug.BuildInfo
var (
// errUnrecognizedFormat is returned when a given executable file doesn't
// appear to be in a known format, or it breaks the rules of that format,
// or when there are I/O errors reading the file.
errUnrecognizedFormat = errors.New("unrecognized file format")
// errNotGoExe is returned when a given executable file is valid but does
// not contain Go build information.
errNotGoExe = errors.New("not a Go executable")
// The build info blob left by the linker is identified by
// a 16-byte header, consisting of buildInfoMagic (14 bytes),
// the binary's pointer size (1 byte),
// and whether the binary is big endian (1 byte).
buildInfoMagic = []byte("\xff Go buildinf:")
)
// ReadFile returns build information embedded in a Go binary
// file at the given path. Most information is only available for binaries built
// with module support.
func ReadFile(name string) (info *BuildInfo, err error) {
defer func() {
if pathErr := (*fs.PathError)(nil); errors.As(err, &pathErr) {
err = fmt.Errorf("could not read Go build info: %w", err)
} else if err != nil {
err = fmt.Errorf("could not read Go build info from %s: %w", name, err)
}
}()
f, err := os.Open(name)
if err != nil {
return nil, err
}
defer f.Close()
return Read(f)
}
// Read returns build information embedded in a Go binary file
// accessed through the given ReaderAt. Most information is only available for
// binaries built with module support.
func Read(r io.ReaderAt) (*BuildInfo, error) {
vers, mod, err := readRawBuildInfo(r)
if err != nil {
return nil, err
}
bi, err := debug.ParseBuildInfo(mod)
if err != nil {
return nil, err
}
bi.GoVersion = vers
return bi, nil
}
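// Illustrative sketch (editor's example): reading build information out of a
// compiled Go binary on disk. The path is hypothetical.
//
// info, err := buildinfo.ReadFile("/usr/local/bin/mytool")
// if err != nil {
// 	log.Fatal(err)
// }
// fmt.Println(info.GoVersion, info.Path)
// for _, dep := range info.Deps {
// 	fmt.Println(dep.Path, dep.Version)
// }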
type exe interface {
// ReadData reads and returns up to size bytes starting at virtual address addr.
ReadData(addr, size uint64) ([]byte, error)
// DataStart returns the virtual address of the segment or section that
// should contain build information. This is either a specially named section
// or the first writable non-zero data segment.
DataStart() uint64
}
// readRawBuildInfo extracts the Go toolchain version and module information
// strings from a Go binary. On success, vers should be non-empty. mod
// is empty if the binary was not built with modules enabled.
func readRawBuildInfo(r io.ReaderAt) (vers, mod string, err error) {
// Read the first bytes of the file to identify the format, then delegate to
// a format-specific function to load segment and section headers.
ident := make([]byte, 16)
if n, err := r.ReadAt(ident, 0); n < len(ident) || err != nil {
return "", "", errUnrecognizedFormat
}
var x exe
switch {
case bytes.HasPrefix(ident, []byte("\x7FELF")):
f, err := elf.NewFile(r)
if err != nil {
return "", "", errUnrecognizedFormat
}
x = &elfExe{f}
case bytes.HasPrefix(ident, []byte("MZ")):
f, err := pe.NewFile(r)
if err != nil {
return "", "", errUnrecognizedFormat
}
x = &peExe{f}
case bytes.HasPrefix(ident, []byte("\xFE\xED\xFA")) || bytes.HasPrefix(ident[1:], []byte("\xFA\xED\xFE")):
f, err := macho.NewFile(r)
if err != nil {
return "", "", errUnrecognizedFormat
}
x = &machoExe{f}
case bytes.HasPrefix(ident, []byte{0x01, 0xDF}) || bytes.HasPrefix(ident, []byte{0x01, 0xF7}):
f, err := xcoff.NewFile(r)
if err != nil {
return "", "", errUnrecognizedFormat
}
x = &xcoffExe{f}
case hasPlan9Magic(ident):
f, err := plan9obj.NewFile(r)
if err != nil {
return "", "", errUnrecognizedFormat
}
x = &plan9objExe{f}
default:
return "", "", errUnrecognizedFormat
}
// Read the first 64kB starting at dataAddr to find the build info blob.
// On some platforms, the blob will be in its own section, and DataStart
// returns the address of that section. On others, it's somewhere in the
// data segment; the linker puts it near the beginning.
// See cmd/link/internal/ld.Link.buildinfo.
dataAddr := x.DataStart()
data, err := x.ReadData(dataAddr, 64*1024)
if err != nil {
return "", "", err
}
const (
buildInfoAlign = 16
buildInfoSize = 32
)
for {
i := bytes.Index(data, buildInfoMagic)
if i < 0 || len(data)-i < buildInfoSize {
return "", "", errNotGoExe
}
if i%buildInfoAlign == 0 && len(data)-i >= buildInfoSize {
data = data[i:]
break
}
data = data[(i+buildInfoAlign-1)&^(buildInfoAlign-1):]
}
// Decode the blob.
// The first 14 bytes are buildInfoMagic.
// The next two bytes indicate pointer size in bytes (4 or 8) and endianness
// (0 for little, 1 for big).
// Two virtual addresses to Go strings follow that: runtime.buildVersion,
// and runtime.modinfo.
// On 32-bit platforms, the last 8 bytes are unused.
// If the endianness has the 2 bit set, then the pointers are zero
// and the 32-byte header is followed by varint-prefixed string data
// for the two string values we care about.
ptrSize := int(data[14])
if data[15]&2 != 0 {
vers, data = decodeString(data[32:])
mod, data = decodeString(data)
} else {
bigEndian := data[15] != 0
var bo binary.ByteOrder
if bigEndian {
bo = binary.BigEndian
} else {
bo = binary.LittleEndian
}
var readPtr func([]byte) uint64
if ptrSize == 4 {
readPtr = func(b []byte) uint64 { return uint64(bo.Uint32(b)) }
} else if ptrSize == 8 {
readPtr = bo.Uint64
} else {
return "", "", errNotGoExe
}
vers = readString(x, ptrSize, readPtr, readPtr(data[16:]))
mod = readString(x, ptrSize, readPtr, readPtr(data[16+ptrSize:]))
}
if vers == "" {
return "", "", errNotGoExe
}
if len(mod) >= 33 && mod[len(mod)-17] == '\n' {
// Strip module framing: sentinel strings delimiting the module info.
// These are cmd/go/internal/modload.infoStart and infoEnd.
mod = mod[16 : len(mod)-16]
} else {
mod = ""
}
return vers, mod, nil
}
func hasPlan9Magic(magic []byte) bool {
if len(magic) >= 4 {
m := binary.BigEndian.Uint32(magic)
switch m {
case plan9obj.Magic386, plan9obj.MagicAMD64, plan9obj.MagicARM:
return true
}
}
return false
}
func decodeString(data []byte) (s string, rest []byte) {
u, n := binary.Uvarint(data)
if n <= 0 || u >= uint64(len(data)-n) {
return "", nil
}
return string(data[n : uint64(n)+u]), data[uint64(n)+u:]
}
// readString returns the string at address addr in the executable x.
func readString(x exe, ptrSize int, readPtr func([]byte) uint64, addr uint64) string {
hdr, err := x.ReadData(addr, uint64(2*ptrSize))
if err != nil || len(hdr) < 2*ptrSize {
return ""
}
dataAddr := readPtr(hdr)
dataLen := readPtr(hdr[ptrSize:])
data, err := x.ReadData(dataAddr, dataLen)
if err != nil || uint64(len(data)) < dataLen {
return ""
}
return string(data)
}
// elfExe is the ELF implementation of the exe interface.
type elfExe struct {
f *elf.File
}
func (x *elfExe) ReadData(addr, size uint64) ([]byte, error) {
for _, prog := range x.f.Progs {
if prog.Vaddr <= addr && addr <= prog.Vaddr+prog.Filesz-1 {
n := prog.Vaddr + prog.Filesz - addr
if n > size {
n = size
}
data := make([]byte, n)
_, err := prog.ReadAt(data, int64(addr-prog.Vaddr))
if err != nil {
return nil, err
}
return data, nil
}
}
return nil, errUnrecognizedFormat
}
func (x *elfExe) DataStart() uint64 {
for _, s := range x.f.Sections {
if s.Name == ".go.buildinfo" {
return s.Addr
}
}
for _, p := range x.f.Progs {
if p.Type == elf.PT_LOAD && p.Flags&(elf.PF_X|elf.PF_W) == elf.PF_W {
return p.Vaddr
}
}
return 0
}
// peExe is the PE (Windows Portable Executable) implementation of the exe interface.
type peExe struct {
f *pe.File
}
func (x *peExe) imageBase() uint64 {
switch oh := x.f.OptionalHeader.(type) {
case *pe.OptionalHeader32:
return uint64(oh.ImageBase)
case *pe.OptionalHeader64:
return oh.ImageBase
}
return 0
}
func (x *peExe) ReadData(addr, size uint64) ([]byte, error) {
addr -= x.imageBase()
for _, sect := range x.f.Sections {
if uint64(sect.VirtualAddress) <= addr && addr <= uint64(sect.VirtualAddress+sect.Size-1) {
n := uint64(sect.VirtualAddress+sect.Size) - addr
if n > size {
n = size
}
data := make([]byte, n)
_, err := sect.ReadAt(data, int64(addr-uint64(sect.VirtualAddress)))
if err != nil {
return nil, errUnrecognizedFormat
}
return data, nil
}
}
return nil, errUnrecognizedFormat
}
func (x *peExe) DataStart() uint64 {
// Assume data is the first writable section.
const (
IMAGE_SCN_CNT_CODE = 0x00000020
IMAGE_SCN_CNT_INITIALIZED_DATA = 0x00000040
IMAGE_SCN_CNT_UNINITIALIZED_DATA = 0x00000080
IMAGE_SCN_MEM_EXECUTE = 0x20000000
IMAGE_SCN_MEM_READ = 0x40000000
IMAGE_SCN_MEM_WRITE = 0x80000000
IMAGE_SCN_MEM_DISCARDABLE = 0x2000000
IMAGE_SCN_LNK_NRELOC_OVFL = 0x1000000
IMAGE_SCN_ALIGN_32BYTES = 0x600000
)
for _, sect := range x.f.Sections {
if sect.VirtualAddress != 0 && sect.Size != 0 &&
sect.Characteristics&^IMAGE_SCN_ALIGN_32BYTES == IMAGE_SCN_CNT_INITIALIZED_DATA|IMAGE_SCN_MEM_READ|IMAGE_SCN_MEM_WRITE {
return uint64(sect.VirtualAddress) + x.imageBase()
}
}
return 0
}
// machoExe is the Mach-O (Apple macOS/iOS) implementation of the exe interface.
type machoExe struct {
f *macho.File
}
func (x *machoExe) ReadData(addr, size uint64) ([]byte, error) {
for _, load := range x.f.Loads {
seg, ok := load.(*macho.Segment)
if !ok {
continue
}
if seg.Addr <= addr && addr <= seg.Addr+seg.Filesz-1 {
if seg.Name == "__PAGEZERO" {
continue
}
n := seg.Addr + seg.Filesz - addr
if n > size {
n = size
}
data := make([]byte, n)
_, err := seg.ReadAt(data, int64(addr-seg.Addr))
if err != nil {
return nil, err
}
return data, nil
}
}
return nil, errUnrecognizedFormat
}
func (x *machoExe) DataStart() uint64 {
// Look for section named "__go_buildinfo".
for _, sec := range x.f.Sections {
if sec.Name == "__go_buildinfo" {
return sec.Addr
}
}
// Try the first non-empty writable segment.
const RW = 3
for _, load := range x.f.Loads {
seg, ok := load.(*macho.Segment)
if ok && seg.Addr != 0 && seg.Filesz != 0 && seg.Prot == RW && seg.Maxprot == RW {
return seg.Addr
}
}
return 0
}
// xcoffExe is the XCOFF (AIX eXtended COFF) implementation of the exe interface.
type xcoffExe struct {
f *xcoff.File
}
func (x *xcoffExe) ReadData(addr, size uint64) ([]byte, error) {
for _, sect := range x.f.Sections {
if sect.VirtualAddress <= addr && addr <= sect.VirtualAddress+sect.Size-1 {
n := sect.VirtualAddress + sect.Size - addr
if n > size {
n = size
}
data := make([]byte, n)
_, err := sect.ReadAt(data, int64(addr-sect.VirtualAddress))
if err != nil {
return nil, err
}
return data, nil
}
}
return nil, errors.New("address not mapped")
}
func (x *xcoffExe) DataStart() uint64 {
if s := x.f.SectionByType(xcoff.STYP_DATA); s != nil {
return s.VirtualAddress
}
return 0
}
// plan9objExe is the Plan 9 a.out implementation of the exe interface.
type plan9objExe struct {
f *plan9obj.File
}
func (x *plan9objExe) DataStart() uint64 {
if s := x.f.Section("data"); s != nil {
return uint64(s.Offset)
}
return 0
}
func (x *plan9objExe) ReadData(addr, size uint64) ([]byte, error) {
for _, sect := range x.f.Sections {
if uint64(sect.Offset) <= addr && addr <= uint64(sect.Offset+sect.Size-1) {
n := uint64(sect.Offset+sect.Size) - addr
if n > size {
n = size
}
data := make([]byte, n)
_, err := sect.ReadAt(data, int64(addr-uint64(sect.Offset)))
if err != nil {
return nil, err
}
return data, nil
}
}
return nil, errors.New("address not mapped")
}
// Code generated by "stringer -type Attr -trimprefix=Attr"; DO NOT EDIT.
package dwarf
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[AttrSibling-1]
_ = x[AttrLocation-2]
_ = x[AttrName-3]
_ = x[AttrOrdering-9]
_ = x[AttrByteSize-11]
_ = x[AttrBitOffset-12]
_ = x[AttrBitSize-13]
_ = x[AttrStmtList-16]
_ = x[AttrLowpc-17]
_ = x[AttrHighpc-18]
_ = x[AttrLanguage-19]
_ = x[AttrDiscr-21]
_ = x[AttrDiscrValue-22]
_ = x[AttrVisibility-23]
_ = x[AttrImport-24]
_ = x[AttrStringLength-25]
_ = x[AttrCommonRef-26]
_ = x[AttrCompDir-27]
_ = x[AttrConstValue-28]
_ = x[AttrContainingType-29]
_ = x[AttrDefaultValue-30]
_ = x[AttrInline-32]
_ = x[AttrIsOptional-33]
_ = x[AttrLowerBound-34]
_ = x[AttrProducer-37]
_ = x[AttrPrototyped-39]
_ = x[AttrReturnAddr-42]
_ = x[AttrStartScope-44]
_ = x[AttrStrideSize-46]
_ = x[AttrUpperBound-47]
_ = x[AttrAbstractOrigin-49]
_ = x[AttrAccessibility-50]
_ = x[AttrAddrClass-51]
_ = x[AttrArtificial-52]
_ = x[AttrBaseTypes-53]
_ = x[AttrCalling-54]
_ = x[AttrCount-55]
_ = x[AttrDataMemberLoc-56]
_ = x[AttrDeclColumn-57]
_ = x[AttrDeclFile-58]
_ = x[AttrDeclLine-59]
_ = x[AttrDeclaration-60]
_ = x[AttrDiscrList-61]
_ = x[AttrEncoding-62]
_ = x[AttrExternal-63]
_ = x[AttrFrameBase-64]
_ = x[AttrFriend-65]
_ = x[AttrIdentifierCase-66]
_ = x[AttrMacroInfo-67]
_ = x[AttrNamelistItem-68]
_ = x[AttrPriority-69]
_ = x[AttrSegment-70]
_ = x[AttrSpecification-71]
_ = x[AttrStaticLink-72]
_ = x[AttrType-73]
_ = x[AttrUseLocation-74]
_ = x[AttrVarParam-75]
_ = x[AttrVirtuality-76]
_ = x[AttrVtableElemLoc-77]
_ = x[AttrAllocated-78]
_ = x[AttrAssociated-79]
_ = x[AttrDataLocation-80]
_ = x[AttrStride-81]
_ = x[AttrEntrypc-82]
_ = x[AttrUseUTF8-83]
_ = x[AttrExtension-84]
_ = x[AttrRanges-85]
_ = x[AttrTrampoline-86]
_ = x[AttrCallColumn-87]
_ = x[AttrCallFile-88]
_ = x[AttrCallLine-89]
_ = x[AttrDescription-90]
_ = x[AttrBinaryScale-91]
_ = x[AttrDecimalScale-92]
_ = x[AttrSmall-93]
_ = x[AttrDecimalSign-94]
_ = x[AttrDigitCount-95]
_ = x[AttrPictureString-96]
_ = x[AttrMutable-97]
_ = x[AttrThreadsScaled-98]
_ = x[AttrExplicit-99]
_ = x[AttrObjectPointer-100]
_ = x[AttrEndianity-101]
_ = x[AttrElemental-102]
_ = x[AttrPure-103]
_ = x[AttrRecursive-104]
_ = x[AttrSignature-105]
_ = x[AttrMainSubprogram-106]
_ = x[AttrDataBitOffset-107]
_ = x[AttrConstExpr-108]
_ = x[AttrEnumClass-109]
_ = x[AttrLinkageName-110]
_ = x[AttrStringLengthBitSize-111]
_ = x[AttrStringLengthByteSize-112]
_ = x[AttrRank-113]
_ = x[AttrStrOffsetsBase-114]
_ = x[AttrAddrBase-115]
_ = x[AttrRnglistsBase-116]
_ = x[AttrDwoName-118]
_ = x[AttrReference-119]
_ = x[AttrRvalueReference-120]
_ = x[AttrMacros-121]
_ = x[AttrCallAllCalls-122]
_ = x[AttrCallAllSourceCalls-123]
_ = x[AttrCallAllTailCalls-124]
_ = x[AttrCallReturnPC-125]
_ = x[AttrCallValue-126]
_ = x[AttrCallOrigin-127]
_ = x[AttrCallParameter-128]
_ = x[AttrCallPC-129]
_ = x[AttrCallTailCall-130]
_ = x[AttrCallTarget-131]
_ = x[AttrCallTargetClobbered-132]
_ = x[AttrCallDataLocation-133]
_ = x[AttrCallDataValue-134]
_ = x[AttrNoreturn-135]
_ = x[AttrAlignment-136]
_ = x[AttrExportSymbols-137]
_ = x[AttrDeleted-138]
_ = x[AttrDefaulted-139]
_ = x[AttrLoclistsBase-140]
}
const _Attr_name = "SiblingLocationNameOrderingByteSizeBitOffsetBitSizeStmtListLowpcHighpcLanguageDiscrDiscrValueVisibilityImportStringLengthCommonRefCompDirConstValueContainingTypeDefaultValueInlineIsOptionalLowerBoundProducerPrototypedReturnAddrStartScopeStrideSizeUpperBoundAbstractOriginAccessibilityAddrClassArtificialBaseTypesCallingCountDataMemberLocDeclColumnDeclFileDeclLineDeclarationDiscrListEncodingExternalFrameBaseFriendIdentifierCaseMacroInfoNamelistItemPrioritySegmentSpecificationStaticLinkTypeUseLocationVarParamVirtualityVtableElemLocAllocatedAssociatedDataLocationStrideEntrypcUseUTF8ExtensionRangesTrampolineCallColumnCallFileCallLineDescriptionBinaryScaleDecimalScaleSmallDecimalSignDigitCountPictureStringMutableThreadsScaledExplicitObjectPointerEndianityElementalPureRecursiveSignatureMainSubprogramDataBitOffsetConstExprEnumClassLinkageNameStringLengthBitSizeStringLengthByteSizeRankStrOffsetsBaseAddrBaseRnglistsBaseDwoNameReferenceRvalueReferenceMacrosCallAllCallsCallAllSourceCallsCallAllTailCallsCallReturnPCCallValueCallOriginCallParameterCallPCCallTailCallCallTargetCallTargetClobberedCallDataLocationCallDataValueNoreturnAlignmentExportSymbolsDeletedDefaultedLoclistsBase"
var _Attr_map = map[Attr]string{
1: _Attr_name[0:7],
2: _Attr_name[7:15],
3: _Attr_name[15:19],
9: _Attr_name[19:27],
11: _Attr_name[27:35],
12: _Attr_name[35:44],
13: _Attr_name[44:51],
16: _Attr_name[51:59],
17: _Attr_name[59:64],
18: _Attr_name[64:70],
19: _Attr_name[70:78],
21: _Attr_name[78:83],
22: _Attr_name[83:93],
23: _Attr_name[93:103],
24: _Attr_name[103:109],
25: _Attr_name[109:121],
26: _Attr_name[121:130],
27: _Attr_name[130:137],
28: _Attr_name[137:147],
29: _Attr_name[147:161],
30: _Attr_name[161:173],
32: _Attr_name[173:179],
33: _Attr_name[179:189],
34: _Attr_name[189:199],
37: _Attr_name[199:207],
39: _Attr_name[207:217],
42: _Attr_name[217:227],
44: _Attr_name[227:237],
46: _Attr_name[237:247],
47: _Attr_name[247:257],
49: _Attr_name[257:271],
50: _Attr_name[271:284],
51: _Attr_name[284:293],
52: _Attr_name[293:303],
53: _Attr_name[303:312],
54: _Attr_name[312:319],
55: _Attr_name[319:324],
56: _Attr_name[324:337],
57: _Attr_name[337:347],
58: _Attr_name[347:355],
59: _Attr_name[355:363],
60: _Attr_name[363:374],
61: _Attr_name[374:383],
62: _Attr_name[383:391],
63: _Attr_name[391:399],
64: _Attr_name[399:408],
65: _Attr_name[408:414],
66: _Attr_name[414:428],
67: _Attr_name[428:437],
68: _Attr_name[437:449],
69: _Attr_name[449:457],
70: _Attr_name[457:464],
71: _Attr_name[464:477],
72: _Attr_name[477:487],
73: _Attr_name[487:491],
74: _Attr_name[491:502],
75: _Attr_name[502:510],
76: _Attr_name[510:520],
77: _Attr_name[520:533],
78: _Attr_name[533:542],
79: _Attr_name[542:552],
80: _Attr_name[552:564],
81: _Attr_name[564:570],
82: _Attr_name[570:577],
83: _Attr_name[577:584],
84: _Attr_name[584:593],
85: _Attr_name[593:599],
86: _Attr_name[599:609],
87: _Attr_name[609:619],
88: _Attr_name[619:627],
89: _Attr_name[627:635],
90: _Attr_name[635:646],
91: _Attr_name[646:657],
92: _Attr_name[657:669],
93: _Attr_name[669:674],
94: _Attr_name[674:685],
95: _Attr_name[685:695],
96: _Attr_name[695:708],
97: _Attr_name[708:715],
98: _Attr_name[715:728],
99: _Attr_name[728:736],
100: _Attr_name[736:749],
101: _Attr_name[749:758],
102: _Attr_name[758:767],
103: _Attr_name[767:771],
104: _Attr_name[771:780],
105: _Attr_name[780:789],
106: _Attr_name[789:803],
107: _Attr_name[803:816],
108: _Attr_name[816:825],
109: _Attr_name[825:834],
110: _Attr_name[834:845],
111: _Attr_name[845:864],
112: _Attr_name[864:884],
113: _Attr_name[884:888],
114: _Attr_name[888:902],
115: _Attr_name[902:910],
116: _Attr_name[910:922],
118: _Attr_name[922:929],
119: _Attr_name[929:938],
120: _Attr_name[938:953],
121: _Attr_name[953:959],
122: _Attr_name[959:971],
123: _Attr_name[971:989],
124: _Attr_name[989:1005],
125: _Attr_name[1005:1017],
126: _Attr_name[1017:1026],
127: _Attr_name[1026:1036],
128: _Attr_name[1036:1049],
129: _Attr_name[1049:1055],
130: _Attr_name[1055:1067],
131: _Attr_name[1067:1077],
132: _Attr_name[1077:1096],
133: _Attr_name[1096:1112],
134: _Attr_name[1112:1125],
135: _Attr_name[1125:1133],
136: _Attr_name[1133:1142],
137: _Attr_name[1142:1155],
138: _Attr_name[1155:1162],
139: _Attr_name[1162:1171],
140: _Attr_name[1171:1183],
}
func (i Attr) String() string {
if str, ok := _Attr_map[i]; ok {
return str
}
return "Attr(" + strconv.FormatInt(int64(i), 10) + ")"
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Buffered reading and decoding of DWARF data streams.
package dwarf
import (
"bytes"
"encoding/binary"
"strconv"
)
// Data buffer being decoded.
type buf struct {
dwarf *Data
order binary.ByteOrder
format dataFormat
name string
off Offset
data []byte
err error
}
// Data format, other than byte order. This affects the handling of
// certain field formats.
type dataFormat interface {
// DWARF version number. Zero means unknown.
version() int
// 64-bit DWARF format?
dwarf64() (dwarf64 bool, isKnown bool)
// Size of an address, in bytes. Zero means unknown.
addrsize() int
}
// Some parts of DWARF have no data format, e.g., abbrevs.
type unknownFormat struct{}
func (u unknownFormat) version() int {
return 0
}
func (u unknownFormat) dwarf64() (bool, bool) {
return false, false
}
func (u unknownFormat) addrsize() int {
return 0
}
func makeBuf(d *Data, format dataFormat, name string, off Offset, data []byte) buf {
return buf{d, d.order, format, name, off, data, nil}
}
func (b *buf) uint8() uint8 {
if len(b.data) < 1 {
b.error("underflow")
return 0
}
val := b.data[0]
b.data = b.data[1:]
b.off++
return val
}
func (b *buf) bytes(n int) []byte {
if n < 0 || len(b.data) < n {
b.error("underflow")
return nil
}
data := b.data[0:n]
b.data = b.data[n:]
b.off += Offset(n)
return data
}
func (b *buf) skip(n int) { b.bytes(n) }
func (b *buf) string() string {
i := bytes.IndexByte(b.data, 0)
if i < 0 {
b.error("underflow")
return ""
}
s := string(b.data[0:i])
b.data = b.data[i+1:]
b.off += Offset(i + 1)
return s
}
func (b *buf) uint16() uint16 {
a := b.bytes(2)
if a == nil {
return 0
}
return b.order.Uint16(a)
}
func (b *buf) uint24() uint32 {
a := b.bytes(3)
if a == nil {
return 0
}
if b.dwarf.bigEndian {
return uint32(a[2]) | uint32(a[1])<<8 | uint32(a[0])<<16
} else {
return uint32(a[0]) | uint32(a[1])<<8 | uint32(a[2])<<16
}
}
func (b *buf) uint32() uint32 {
a := b.bytes(4)
if a == nil {
return 0
}
return b.order.Uint32(a)
}
func (b *buf) uint64() uint64 {
a := b.bytes(8)
if a == nil {
return 0
}
return b.order.Uint64(a)
}
// Read a varint, which is 7 bits per byte, little endian.
// The 0x80 bit means read another byte.
func (b *buf) varint() (c uint64, bits uint) {
for i := 0; i < len(b.data); i++ {
byte := b.data[i]
c |= uint64(byte&0x7F) << bits
bits += 7
if byte&0x80 == 0 {
b.off += Offset(i + 1)
b.data = b.data[i+1:]
return c, bits
}
}
return 0, 0
}
// Unsigned int is just a varint.
func (b *buf) uint() uint64 {
x, _ := b.varint()
return x
}
// Signed int is a sign-extended varint.
func (b *buf) int() int64 {
ux, bits := b.varint()
x := int64(ux)
if x&(1<<(bits-1)) != 0 {
x |= -1 << bits
}
return x
}
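// Worked example (editor's note): the single byte 0x7F decodes as ux=0x7F
// with bits=7; bit 6 is set, so x |= -1<<7 and the result is -1. The two
// bytes 0xFF 0x00 decode as ux=0x7F with bits=14 and the sign bit (1<<13)
// clear, so the value remains +127.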
// Address-sized uint.
func (b *buf) addr() uint64 {
switch b.format.addrsize() {
case 1:
return uint64(b.uint8())
case 2:
return uint64(b.uint16())
case 4:
return uint64(b.uint32())
case 8:
return b.uint64()
}
b.error("unknown address size")
return 0
}
func (b *buf) unitLength() (length Offset, dwarf64 bool) {
length = Offset(b.uint32())
if length == 0xffffffff {
dwarf64 = true
length = Offset(b.uint64())
} else if length >= 0xfffffff0 {
b.error("unit length has reserved value")
}
return
}
func (b *buf) error(s string) {
if b.err == nil {
b.data = nil
b.err = DecodeError{b.name, b.off, s}
}
}
type DecodeError struct {
Name string
Offset Offset
Err string
}
func (e DecodeError) Error() string {
return "decoding dwarf section " + e.Name + " at offset 0x" + strconv.FormatInt(int64(e.Offset), 16) + ": " + e.Err
}
// Code generated by "stringer -type=Class"; DO NOT EDIT.
package dwarf
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[ClassUnknown-0]
_ = x[ClassAddress-1]
_ = x[ClassBlock-2]
_ = x[ClassConstant-3]
_ = x[ClassExprLoc-4]
_ = x[ClassFlag-5]
_ = x[ClassLinePtr-6]
_ = x[ClassLocListPtr-7]
_ = x[ClassMacPtr-8]
_ = x[ClassRangeListPtr-9]
_ = x[ClassReference-10]
_ = x[ClassReferenceSig-11]
_ = x[ClassString-12]
_ = x[ClassReferenceAlt-13]
_ = x[ClassStringAlt-14]
}
const _Class_name = "ClassUnknownClassAddressClassBlockClassConstantClassExprLocClassFlagClassLinePtrClassLocListPtrClassMacPtrClassRangeListPtrClassReferenceClassReferenceSigClassStringClassReferenceAltClassStringAlt"
var _Class_index = [...]uint8{0, 12, 24, 34, 47, 59, 68, 80, 95, 106, 123, 137, 154, 165, 182, 196}
func (i Class) String() string {
if i < 0 || i >= Class(len(_Class_index)-1) {
return "Class(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _Class_name[_Class_index[i]:_Class_index[i+1]]
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Constants
package dwarf
//go:generate stringer -type Attr -trimprefix=Attr
// An Attr identifies the attribute type in a DWARF Entry's Field.
type Attr uint32
const (
AttrSibling Attr = 0x01
AttrLocation Attr = 0x02
AttrName Attr = 0x03
AttrOrdering Attr = 0x09
AttrByteSize Attr = 0x0B
AttrBitOffset Attr = 0x0C
AttrBitSize Attr = 0x0D
AttrStmtList Attr = 0x10
AttrLowpc Attr = 0x11
AttrHighpc Attr = 0x12
AttrLanguage Attr = 0x13
AttrDiscr Attr = 0x15
AttrDiscrValue Attr = 0x16
AttrVisibility Attr = 0x17
AttrImport Attr = 0x18
AttrStringLength Attr = 0x19
AttrCommonRef Attr = 0x1A
AttrCompDir Attr = 0x1B
AttrConstValue Attr = 0x1C
AttrContainingType Attr = 0x1D
AttrDefaultValue Attr = 0x1E
AttrInline Attr = 0x20
AttrIsOptional Attr = 0x21
AttrLowerBound Attr = 0x22
AttrProducer Attr = 0x25
AttrPrototyped Attr = 0x27
AttrReturnAddr Attr = 0x2A
AttrStartScope Attr = 0x2C
AttrStrideSize Attr = 0x2E
AttrUpperBound Attr = 0x2F
AttrAbstractOrigin Attr = 0x31
AttrAccessibility Attr = 0x32
AttrAddrClass Attr = 0x33
AttrArtificial Attr = 0x34
AttrBaseTypes Attr = 0x35
AttrCalling Attr = 0x36
AttrCount Attr = 0x37
AttrDataMemberLoc Attr = 0x38
AttrDeclColumn Attr = 0x39
AttrDeclFile Attr = 0x3A
AttrDeclLine Attr = 0x3B
AttrDeclaration Attr = 0x3C
AttrDiscrList Attr = 0x3D
AttrEncoding Attr = 0x3E
AttrExternal Attr = 0x3F
AttrFrameBase Attr = 0x40
AttrFriend Attr = 0x41
AttrIdentifierCase Attr = 0x42
AttrMacroInfo Attr = 0x43
AttrNamelistItem Attr = 0x44
AttrPriority Attr = 0x45
AttrSegment Attr = 0x46
AttrSpecification Attr = 0x47
AttrStaticLink Attr = 0x48
AttrType Attr = 0x49
AttrUseLocation Attr = 0x4A
AttrVarParam Attr = 0x4B
AttrVirtuality Attr = 0x4C
AttrVtableElemLoc Attr = 0x4D
// The following are new in DWARF 3.
AttrAllocated Attr = 0x4E
AttrAssociated Attr = 0x4F
AttrDataLocation Attr = 0x50
AttrStride Attr = 0x51
AttrEntrypc Attr = 0x52
AttrUseUTF8 Attr = 0x53
AttrExtension Attr = 0x54
AttrRanges Attr = 0x55
AttrTrampoline Attr = 0x56
AttrCallColumn Attr = 0x57
AttrCallFile Attr = 0x58
AttrCallLine Attr = 0x59
AttrDescription Attr = 0x5A
AttrBinaryScale Attr = 0x5B
AttrDecimalScale Attr = 0x5C
AttrSmall Attr = 0x5D
AttrDecimalSign Attr = 0x5E
AttrDigitCount Attr = 0x5F
AttrPictureString Attr = 0x60
AttrMutable Attr = 0x61
AttrThreadsScaled Attr = 0x62
AttrExplicit Attr = 0x63
AttrObjectPointer Attr = 0x64
AttrEndianity Attr = 0x65
AttrElemental Attr = 0x66
AttrPure Attr = 0x67
AttrRecursive Attr = 0x68
// The following are new in DWARF 4.
AttrSignature Attr = 0x69
AttrMainSubprogram Attr = 0x6A
AttrDataBitOffset Attr = 0x6B
AttrConstExpr Attr = 0x6C
AttrEnumClass Attr = 0x6D
AttrLinkageName Attr = 0x6E
// The following are new in DWARF 5.
AttrStringLengthBitSize Attr = 0x6F
AttrStringLengthByteSize Attr = 0x70
AttrRank Attr = 0x71
AttrStrOffsetsBase Attr = 0x72
AttrAddrBase Attr = 0x73
AttrRnglistsBase Attr = 0x74
AttrDwoName Attr = 0x76
AttrReference Attr = 0x77
AttrRvalueReference Attr = 0x78
AttrMacros Attr = 0x79
AttrCallAllCalls Attr = 0x7A
AttrCallAllSourceCalls Attr = 0x7B
AttrCallAllTailCalls Attr = 0x7C
AttrCallReturnPC Attr = 0x7D
AttrCallValue Attr = 0x7E
AttrCallOrigin Attr = 0x7F
AttrCallParameter Attr = 0x80
AttrCallPC Attr = 0x81
AttrCallTailCall Attr = 0x82
AttrCallTarget Attr = 0x83
AttrCallTargetClobbered Attr = 0x84
AttrCallDataLocation Attr = 0x85
AttrCallDataValue Attr = 0x86
AttrNoreturn Attr = 0x87
AttrAlignment Attr = 0x88
AttrExportSymbols Attr = 0x89
AttrDeleted Attr = 0x8A
AttrDefaulted Attr = 0x8B
AttrLoclistsBase Attr = 0x8C
)
func (a Attr) GoString() string {
if str, ok := _Attr_map[a]; ok {
return "dwarf.Attr" + str
}
return "dwarf." + a.String()
}
// A format is a DWARF data encoding format.
type format uint32
const (
// value formats
formAddr format = 0x01
formDwarfBlock2 format = 0x03
formDwarfBlock4 format = 0x04
formData2 format = 0x05
formData4 format = 0x06
formData8 format = 0x07
formString format = 0x08
formDwarfBlock format = 0x09
formDwarfBlock1 format = 0x0A
formData1 format = 0x0B
formFlag format = 0x0C
formSdata format = 0x0D
formStrp format = 0x0E
formUdata format = 0x0F
formRefAddr format = 0x10
formRef1 format = 0x11
formRef2 format = 0x12
formRef4 format = 0x13
formRef8 format = 0x14
formRefUdata format = 0x15
formIndirect format = 0x16
// The following are new in DWARF 4.
formSecOffset format = 0x17
formExprloc format = 0x18
formFlagPresent format = 0x19
formRefSig8 format = 0x20
// The following are new in DWARF 5.
formStrx format = 0x1A
formAddrx format = 0x1B
formRefSup4 format = 0x1C
formStrpSup format = 0x1D
formData16 format = 0x1E
formLineStrp format = 0x1F
formImplicitConst format = 0x21
formLoclistx format = 0x22
formRnglistx format = 0x23
formRefSup8 format = 0x24
formStrx1 format = 0x25
formStrx2 format = 0x26
formStrx3 format = 0x27
formStrx4 format = 0x28
formAddrx1 format = 0x29
formAddrx2 format = 0x2A
formAddrx3 format = 0x2B
formAddrx4 format = 0x2C
// Extensions for multi-file compression (.dwz)
// http://www.dwarfstd.org/ShowIssue.php?issue=120604.1
formGnuRefAlt format = 0x1f20
formGnuStrpAlt format = 0x1f21
)
//go:generate stringer -type Tag -trimprefix=Tag
// A Tag is the classification (the type) of an Entry.
type Tag uint32
const (
TagArrayType Tag = 0x01
TagClassType Tag = 0x02
TagEntryPoint Tag = 0x03
TagEnumerationType Tag = 0x04
TagFormalParameter Tag = 0x05
TagImportedDeclaration Tag = 0x08
TagLabel Tag = 0x0A
TagLexDwarfBlock Tag = 0x0B
TagMember Tag = 0x0D
TagPointerType Tag = 0x0F
TagReferenceType Tag = 0x10
TagCompileUnit Tag = 0x11
TagStringType Tag = 0x12
TagStructType Tag = 0x13
TagSubroutineType Tag = 0x15
TagTypedef Tag = 0x16
TagUnionType Tag = 0x17
TagUnspecifiedParameters Tag = 0x18
TagVariant Tag = 0x19
TagCommonDwarfBlock Tag = 0x1A
TagCommonInclusion Tag = 0x1B
TagInheritance Tag = 0x1C
TagInlinedSubroutine Tag = 0x1D
TagModule Tag = 0x1E
TagPtrToMemberType Tag = 0x1F
TagSetType Tag = 0x20
TagSubrangeType Tag = 0x21
TagWithStmt Tag = 0x22
TagAccessDeclaration Tag = 0x23
TagBaseType Tag = 0x24
TagCatchDwarfBlock Tag = 0x25
TagConstType Tag = 0x26
TagConstant Tag = 0x27
TagEnumerator Tag = 0x28
TagFileType Tag = 0x29
TagFriend Tag = 0x2A
TagNamelist Tag = 0x2B
TagNamelistItem Tag = 0x2C
TagPackedType Tag = 0x2D
TagSubprogram Tag = 0x2E
TagTemplateTypeParameter Tag = 0x2F
TagTemplateValueParameter Tag = 0x30
TagThrownType Tag = 0x31
TagTryDwarfBlock Tag = 0x32
TagVariantPart Tag = 0x33
TagVariable Tag = 0x34
TagVolatileType Tag = 0x35
// The following are new in DWARF 3.
TagDwarfProcedure Tag = 0x36
TagRestrictType Tag = 0x37
TagInterfaceType Tag = 0x38
TagNamespace Tag = 0x39
TagImportedModule Tag = 0x3A
TagUnspecifiedType Tag = 0x3B
TagPartialUnit Tag = 0x3C
TagImportedUnit Tag = 0x3D
TagMutableType Tag = 0x3E // Later removed from DWARF.
TagCondition Tag = 0x3F
TagSharedType Tag = 0x40
// The following are new in DWARF 4.
TagTypeUnit Tag = 0x41
TagRvalueReferenceType Tag = 0x42
TagTemplateAlias Tag = 0x43
// The following are new in DWARF 5.
TagCoarrayType Tag = 0x44
TagGenericSubrange Tag = 0x45
TagDynamicType Tag = 0x46
TagAtomicType Tag = 0x47
TagCallSite Tag = 0x48
TagCallSiteParameter Tag = 0x49
TagSkeletonUnit Tag = 0x4A
TagImmutableType Tag = 0x4B
)
func (t Tag) GoString() string {
if t <= TagTemplateAlias {
return "dwarf.Tag" + t.String()
}
return "dwarf." + t.String()
}
// Location expression operators.
// The debug info encodes value locations like 8(R3)
// as a sequence of these op codes.
// This package does not implement full expressions;
// the opPlusUconst operator is expected by the type parser.
const (
opAddr = 0x03 /* 1 op, const addr */
opDeref = 0x06
opConst1u = 0x08 /* 1 op, 1 byte const */
opConst1s = 0x09 /* " signed */
opConst2u = 0x0A /* 1 op, 2 byte const */
opConst2s = 0x0B /* " signed */
opConst4u = 0x0C /* 1 op, 4 byte const */
opConst4s = 0x0D /* " signed */
opConst8u = 0x0E /* 1 op, 8 byte const */
opConst8s = 0x0F /* " signed */
opConstu = 0x10 /* 1 op, LEB128 const */
opConsts = 0x11 /* " signed */
opDup = 0x12
opDrop = 0x13
opOver = 0x14
opPick = 0x15 /* 1 op, 1 byte stack index */
opSwap = 0x16
opRot = 0x17
opXderef = 0x18
opAbs = 0x19
opAnd = 0x1A
opDiv = 0x1B
opMinus = 0x1C
opMod = 0x1D
opMul = 0x1E
opNeg = 0x1F
opNot = 0x20
opOr = 0x21
opPlus = 0x22
opPlusUconst = 0x23 /* 1 op, ULEB128 addend */
opShl = 0x24
opShr = 0x25
opShra = 0x26
opXor = 0x27
opSkip = 0x2F /* 1 op, signed 2-byte constant */
opBra = 0x28 /* 1 op, signed 2-byte constant */
opEq = 0x29
opGe = 0x2A
opGt = 0x2B
opLe = 0x2C
opLt = 0x2D
opNe = 0x2E
opLit0 = 0x30
/* OpLitN = OpLit0 + N for N = 0..31 */
opReg0 = 0x50
/* OpRegN = OpReg0 + N for N = 0..31 */
opBreg0 = 0x70 /* 1 op, signed LEB128 constant */
/* OpBregN = OpBreg0 + N for N = 0..31 */
opRegx = 0x90 /* 1 op, ULEB128 register */
opFbreg = 0x91 /* 1 op, SLEB128 offset */
opBregx = 0x92 /* 2 op, ULEB128 reg; SLEB128 off */
opPiece = 0x93 /* 1 op, ULEB128 size of piece */
opDerefSize = 0x94 /* 1-byte size of data retrieved */
opXderefSize = 0x95 /* 1-byte size of data retrieved */
opNop = 0x96
// The following are new in DWARF 3.
opPushObjAddr = 0x97
opCall2 = 0x98 /* 2-byte offset of DIE */
opCall4 = 0x99 /* 4-byte offset of DIE */
opCallRef = 0x9A /* 4- or 8- byte offset of DIE */
opFormTLSAddress = 0x9B
opCallFrameCFA = 0x9C
opBitPiece = 0x9D
// The following are new in DWARF 4.
opImplicitValue = 0x9E
opStackValue = 0x9F
// The following are new in DWARF 5.
opImplicitPointer = 0xA0
opAddrx = 0xA1
opConstx = 0xA2
opEntryValue = 0xA3
opConstType = 0xA4
opRegvalType = 0xA5
opDerefType = 0xA6
opXderefType = 0xA7
opConvert = 0xA8
opReinterpret = 0xA9
/* 0xE0-0xFF reserved for user-specific */
)
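// As an illustration (not part of the original source), the location
// 8(R3) mentioned above would be encoded as the two bytes 0x73 0x08:
// opBreg0+3 (0x73) selects register 3, followed by the SLEB128-encoded
// offset 8. opPlusUconst is similar, consuming a single ULEB128 addend.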
// Basic type encodings -- the value for AttrEncoding in a TagBaseType Entry.
const (
encAddress = 0x01
encBoolean = 0x02
encComplexFloat = 0x03
encFloat = 0x04
encSigned = 0x05
encSignedChar = 0x06
encUnsigned = 0x07
encUnsignedChar = 0x08
// The following are new in DWARF 3.
encImaginaryFloat = 0x09
encPackedDecimal = 0x0A
encNumericString = 0x0B
encEdited = 0x0C
encSignedFixed = 0x0D
encUnsignedFixed = 0x0E
encDecimalFloat = 0x0F
// The following are new in DWARF 4.
encUTF = 0x10
// The following are new in DWARF 5.
encUCS = 0x11
encASCII = 0x12
)
// Statement program standard opcode encodings.
const (
lnsCopy = 1
lnsAdvancePC = 2
lnsAdvanceLine = 3
lnsSetFile = 4
lnsSetColumn = 5
lnsNegateStmt = 6
lnsSetBasicBlock = 7
lnsConstAddPC = 8
lnsFixedAdvancePC = 9
// DWARF 3
lnsSetPrologueEnd = 10
lnsSetEpilogueBegin = 11
lnsSetISA = 12
)
// Statement program extended opcode encodings.
const (
lneEndSequence = 1
lneSetAddress = 2
lneDefineFile = 3
// DWARF 4
lneSetDiscriminator = 4
)
// Line table directory and file name entry formats.
// These are new in DWARF 5.
const (
lnctPath = 0x01
lnctDirectoryIndex = 0x02
lnctTimestamp = 0x03
lnctSize = 0x04
lnctMD5 = 0x05
)
// Location list entry codes.
// These are new in DWARF 5.
const (
lleEndOfList = 0x00
lleBaseAddressx = 0x01
lleStartxEndx = 0x02
lleStartxLength = 0x03
lleOffsetPair = 0x04
lleDefaultLocation = 0x05
lleBaseAddress = 0x06
lleStartEnd = 0x07
lleStartLength = 0x08
)
// Unit header unit type encodings.
// These are new in DWARF 5.
const (
utCompile = 0x01
utType = 0x02
utPartial = 0x03
utSkeleton = 0x04
utSplitCompile = 0x05
utSplitType = 0x06
)
// Opcodes for DWARFv5 debug_rnglists section.
const (
rleEndOfList = 0x0
rleBaseAddressx = 0x1
rleStartxEndx = 0x2
rleStartxLength = 0x3
rleOffsetPair = 0x4
rleBaseAddress = 0x5
rleStartEnd = 0x6
rleStartLength = 0x7
)
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// DWARF debug information entry parser.
// An entry is a sequence of data items of a given format.
// The first word in the entry is an index into what DWARF
// calls the “abbreviation table.” An abbreviation is really
// just a type descriptor: it's an array of attribute tag/value format pairs.
package dwarf
import (
"encoding/binary"
"errors"
"fmt"
"strconv"
)
// a single entry's description: a sequence of attributes
type abbrev struct {
tag Tag
children bool
field []afield
}
type afield struct {
attr Attr
fmt format
class Class
val int64 // for formImplicitConst
}
// a map from entry format ids to their descriptions
type abbrevTable map[uint32]abbrev
// parseAbbrev returns the abbreviation table that starts at byte off
// in the .debug_abbrev section.
func (d *Data) parseAbbrev(off uint64, vers int) (abbrevTable, error) {
if m, ok := d.abbrevCache[off]; ok {
return m, nil
}
data := d.abbrev
if off > uint64(len(data)) {
data = nil
} else {
data = data[off:]
}
b := makeBuf(d, unknownFormat{}, "abbrev", 0, data)
// Error handling is simplified by the buf getters
// returning an endless stream of 0s after an error.
m := make(abbrevTable)
for {
// Table ends with id == 0.
id := uint32(b.uint())
if id == 0 {
break
}
// Walk over attributes, counting.
n := 0
b1 := b // Read from copy of b.
b1.uint()
b1.uint8()
for {
tag := b1.uint()
fmt := b1.uint()
if tag == 0 && fmt == 0 {
break
}
if format(fmt) == formImplicitConst {
b1.int()
}
n++
}
if b1.err != nil {
return nil, b1.err
}
// Walk over attributes again, this time writing them down.
var a abbrev
a.tag = Tag(b.uint())
a.children = b.uint8() != 0
a.field = make([]afield, n)
for i := range a.field {
a.field[i].attr = Attr(b.uint())
a.field[i].fmt = format(b.uint())
a.field[i].class = formToClass(a.field[i].fmt, a.field[i].attr, vers, &b)
if a.field[i].fmt == formImplicitConst {
a.field[i].val = b.int()
}
}
// Consume the terminating (0, 0) attribute/format pair.
b.uint()
b.uint()
m[id] = a
}
if b.err != nil {
return nil, b.err
}
d.abbrevCache[off] = m
return m, nil
}
// attrIsExprloc indicates attributes that allow exprloc values that
// are encoded as block values in DWARF 2 and 3. See DWARF 4, Figure
// 20.
var attrIsExprloc = map[Attr]bool{
AttrLocation: true,
AttrByteSize: true,
AttrBitOffset: true,
AttrBitSize: true,
AttrStringLength: true,
AttrLowerBound: true,
AttrReturnAddr: true,
AttrStrideSize: true,
AttrUpperBound: true,
AttrCount: true,
AttrDataMemberLoc: true,
AttrFrameBase: true,
AttrSegment: true,
AttrStaticLink: true,
AttrUseLocation: true,
AttrVtableElemLoc: true,
AttrAllocated: true,
AttrAssociated: true,
AttrDataLocation: true,
AttrStride: true,
}
// attrPtrClass indicates the *ptr class of attributes that have
// encoding formSecOffset in DWARF 4 or formData* in DWARF 2 and 3.
var attrPtrClass = map[Attr]Class{
AttrLocation: ClassLocListPtr,
AttrStmtList: ClassLinePtr,
AttrStringLength: ClassLocListPtr,
AttrReturnAddr: ClassLocListPtr,
AttrStartScope: ClassRangeListPtr,
AttrDataMemberLoc: ClassLocListPtr,
AttrFrameBase: ClassLocListPtr,
AttrMacroInfo: ClassMacPtr,
AttrSegment: ClassLocListPtr,
AttrStaticLink: ClassLocListPtr,
AttrUseLocation: ClassLocListPtr,
AttrVtableElemLoc: ClassLocListPtr,
AttrRanges: ClassRangeListPtr,
// The following are new in DWARF 5.
AttrStrOffsetsBase: ClassStrOffsetsPtr,
AttrAddrBase: ClassAddrPtr,
AttrRnglistsBase: ClassRngListsPtr,
AttrLoclistsBase: ClassLocListPtr,
}
// formToClass returns the DWARF 4 Class for the given form. If the
// DWARF version is less than 4, it will disambiguate some forms
// depending on the attribute.
func formToClass(form format, attr Attr, vers int, b *buf) Class {
switch form {
default:
b.error("cannot determine class of unknown attribute form")
return 0
case formIndirect:
return ClassUnknown
case formAddr, formAddrx, formAddrx1, formAddrx2, formAddrx3, formAddrx4:
return ClassAddress
case formDwarfBlock1, formDwarfBlock2, formDwarfBlock4, formDwarfBlock:
// In DWARF 2 and 3, ClassExprLoc was encoded as a
// block. DWARF 4 distinguishes ClassBlock and
// ClassExprLoc, but there are no attributes that can
// be both, so we also promote ClassBlock values in
// DWARF 4 that should be ClassExprLoc in case
// producers get this wrong.
if attrIsExprloc[attr] {
return ClassExprLoc
}
return ClassBlock
case formData1, formData2, formData4, formData8, formSdata, formUdata, formData16, formImplicitConst:
// In DWARF 2 and 3, ClassPtr was encoded as a
// constant. Unlike ClassExprLoc/ClassBlock, some
// DWARF 4 attributes need to distinguish Class*Ptr
// from ClassConstant, so we only do this promotion
// for versions 2 and 3.
if class, ok := attrPtrClass[attr]; vers < 4 && ok {
return class
}
return ClassConstant
case formFlag, formFlagPresent:
return ClassFlag
case formRefAddr, formRef1, formRef2, formRef4, formRef8, formRefUdata, formRefSup4, formRefSup8:
return ClassReference
case formRefSig8:
return ClassReferenceSig
case formString, formStrp, formStrx, formStrpSup, formLineStrp, formStrx1, formStrx2, formStrx3, formStrx4:
return ClassString
case formSecOffset:
// DWARF 4 defines four *ptr classes, but doesn't
// distinguish them in the encoding. Disambiguate
// these classes using the attribute.
if class, ok := attrPtrClass[attr]; ok {
return class
}
return ClassUnknown
case formExprloc:
return ClassExprLoc
case formGnuRefAlt:
return ClassReferenceAlt
case formGnuStrpAlt:
return ClassStringAlt
case formLoclistx:
return ClassLocList
case formRnglistx:
return ClassRngList
}
}
// An entry is a sequence of attribute/value pairs.
type Entry struct {
Offset Offset // offset of Entry in DWARF info
Tag Tag // tag (kind of Entry)
Children bool // whether Entry is followed by children
Field []Field
}
// A Field is a single attribute/value pair in an Entry.
//
// A value can be one of several "attribute classes" defined by DWARF.
// The Go types corresponding to each class are:
//
//	DWARF class       Go type        Class
//	-----------       -------        -----
//	address           uint64         ClassAddress
//	block             []byte         ClassBlock
//	constant          int64          ClassConstant
//	flag              bool           ClassFlag
//	reference
//	 to info          dwarf.Offset   ClassReference
//	 to type unit     uint64         ClassReferenceSig
//	string            string         ClassString
//	exprloc           []byte         ClassExprLoc
//	lineptr           int64          ClassLinePtr
//	loclistptr        int64          ClassLocListPtr
//	macptr            int64          ClassMacPtr
//	rangelistptr      int64          ClassRangeListPtr
//
// For unrecognized or vendor-defined attributes, Class may be
// ClassUnknown.
type Field struct {
Attr Attr
Val any
Class Class
}
// A Class is the DWARF 4 class of an attribute value.
//
// In general, a given attribute's value may take on one of several
// possible classes defined by DWARF, each of which leads to a
// slightly different interpretation of the attribute.
//
// DWARF version 4 distinguishes attribute value classes more finely
// than previous versions of DWARF. The reader will disambiguate
// coarser classes from earlier versions of DWARF into the appropriate
// DWARF 4 class. For example, DWARF 2 uses "constant" for constants
// as well as all types of section offsets, but the reader will
// canonicalize attributes in DWARF 2 files that refer to section
// offsets to one of the Class*Ptr classes, even though these classes
// were only defined in DWARF 3.
type Class int
const (
// ClassUnknown represents values of unknown DWARF class.
ClassUnknown Class = iota
// ClassAddress represents values of type uint64 that are
// addresses on the target machine.
ClassAddress
// ClassBlock represents values of type []byte whose
// interpretation depends on the attribute.
ClassBlock
// ClassConstant represents values of type int64 that are
// constants. The interpretation of this constant depends on
// the attribute.
ClassConstant
// ClassExprLoc represents values of type []byte that contain
// an encoded DWARF expression or location description.
ClassExprLoc
// ClassFlag represents values of type bool.
ClassFlag
// ClassLinePtr represents values that are an int64 offset
// into the "line" section.
ClassLinePtr
// ClassLocListPtr represents values that are an int64 offset
// into the "loclist" section.
ClassLocListPtr
// ClassMacPtr represents values that are an int64 offset into
// the "mac" section.
ClassMacPtr
// ClassRangeListPtr represents values that are an int64 offset into
// the "rangelist" section.
ClassRangeListPtr
// ClassReference represents values that are an Offset offset
// of an Entry in the info section (for use with Reader.Seek).
// The DWARF specification combines ClassReference and
// ClassReferenceSig into class "reference".
ClassReference
// ClassReferenceSig represents values that are a uint64 type
// signature referencing a type Entry.
ClassReferenceSig
// ClassString represents values that are strings. If the
// compilation unit specifies the AttrUseUTF8 flag (strongly
// recommended), the string value will be encoded in UTF-8.
// Otherwise, the encoding is unspecified.
ClassString
// ClassReferenceAlt represents values of type int64 that are
// an offset into the DWARF "info" section of an alternate
// object file.
ClassReferenceAlt
// ClassStringAlt represents values of type int64 that are an
// offset into the DWARF string section of an alternate object
// file.
ClassStringAlt
// ClassAddrPtr represents values that are an int64 offset
// into the "addr" section.
ClassAddrPtr
// ClassLocList represents values that are an int64 offset
// into the "loclists" section.
ClassLocList
// ClassRngList represents values that are a uint64 offset
// from the base of the "rnglists" section.
ClassRngList
// ClassRngListsPtr represents values that are an int64 offset
// into the "rnglists" section. These are used as the base for
// ClassRngList values.
ClassRngListsPtr
// ClassStrOffsetsPtr represents values that are an int64
// offset into the "str_offsets" section.
ClassStrOffsetsPtr
)
//go:generate stringer -type=Class
func (i Class) GoString() string {
return "dwarf." + i.String()
}
// Val returns the value associated with attribute Attr in Entry,
// or nil if there is no such attribute.
//
// A common idiom is to merge the check for nil return with
// the check that the value has the expected dynamic type, as in:
//
// v, ok := e.Val(AttrSibling).(int64)
func (e *Entry) Val(a Attr) any {
if f := e.AttrField(a); f != nil {
return f.Val
}
return nil
}
// AttrField returns the Field associated with attribute Attr in
// Entry, or nil if there is no such attribute.
func (e *Entry) AttrField(a Attr) *Field {
for i, f := range e.Field {
if f.Attr == a {
return &e.Field[i]
}
}
return nil
}
// An Offset represents the location of an Entry within the DWARF info.
// (See Reader.Seek.)
type Offset uint32
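// An entry's Offset can be recorded and used to revisit the entry later
// with a Reader (a minimal sketch, not part of the original source):
//
//	off := e.Offset
//	// ... later ...
//	r.Seek(off)
//	e2, err := r.Next() // re-reads the entry at off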
// entry reads a single entry from buf, decoding
// according to the given abbreviation table.
func (b *buf) entry(cu *Entry, atab abbrevTable, ubase Offset, vers int) *Entry {
off := b.off
id := uint32(b.uint())
if id == 0 {
return &Entry{}
}
a, ok := atab[id]
if !ok {
b.error("unknown abbreviation table index")
return nil
}
e := &Entry{
Offset: off,
Tag: a.tag,
Children: a.children,
Field: make([]Field, len(a.field)),
}
// If we are currently parsing the compilation unit,
// we can't evaluate Addrx or Strx until we've seen the
// relevant base entry.
type delayed struct {
idx int
off uint64
fmt format
}
var delay []delayed
resolveStrx := func(strBase, off uint64) string {
off += strBase
if uint64(int(off)) != off {
b.error("DW_FORM_strx offset out of range")
}
b1 := makeBuf(b.dwarf, b.format, "str_offsets", 0, b.dwarf.strOffsets)
b1.skip(int(off))
is64, _ := b.format.dwarf64()
if is64 {
off = b1.uint64()
} else {
off = uint64(b1.uint32())
}
if b1.err != nil {
b.err = b1.err
return ""
}
if uint64(int(off)) != off {
b.error("DW_FORM_strx indirect offset out of range")
}
b1 = makeBuf(b.dwarf, b.format, "str", 0, b.dwarf.str)
b1.skip(int(off))
val := b1.string()
if b1.err != nil {
b.err = b1.err
}
return val
}
resolveRnglistx := func(rnglistsBase, off uint64) uint64 {
is64, _ := b.format.dwarf64()
if is64 {
off *= 8
} else {
off *= 4
}
off += rnglistsBase
if uint64(int(off)) != off {
b.error("DW_FORM_rnglistx offset out of range")
}
b1 := makeBuf(b.dwarf, b.format, "rnglists", 0, b.dwarf.rngLists)
b1.skip(int(off))
if is64 {
off = b1.uint64()
} else {
off = uint64(b1.uint32())
}
if b1.err != nil {
b.err = b1.err
return 0
}
if uint64(int(off)) != off {
b.error("DW_FORM_rnglistx indirect offset out of range")
}
return rnglistsBase + off
}
for i := range e.Field {
e.Field[i].Attr = a.field[i].attr
e.Field[i].Class = a.field[i].class
fmt := a.field[i].fmt
if fmt == formIndirect {
fmt = format(b.uint())
e.Field[i].Class = formToClass(fmt, a.field[i].attr, vers, b)
}
var val any
switch fmt {
default:
b.error("unknown entry attr format 0x" + strconv.FormatInt(int64(fmt), 16))
// address
case formAddr:
val = b.addr()
case formAddrx, formAddrx1, formAddrx2, formAddrx3, formAddrx4:
var off uint64
switch fmt {
case formAddrx:
off = b.uint()
case formAddrx1:
off = uint64(b.uint8())
case formAddrx2:
off = uint64(b.uint16())
case formAddrx3:
off = uint64(b.uint24())
case formAddrx4:
off = uint64(b.uint32())
}
if b.dwarf.addr == nil {
b.error("DW_FORM_addrx with no .debug_addr section")
}
if b.err != nil {
return nil
}
// We have to adjust by the offset of the
// compilation unit. This won't work if the
// program uses Reader.Seek to skip over the
// unit. Not much we can do about that.
var addrBase int64
if cu != nil {
addrBase, _ = cu.Val(AttrAddrBase).(int64)
} else if a.tag == TagCompileUnit {
delay = append(delay, delayed{i, off, formAddrx})
break
}
var err error
val, err = b.dwarf.debugAddr(b.format, uint64(addrBase), off)
if err != nil {
if b.err == nil {
b.err = err
}
return nil
}
// block
case formDwarfBlock1:
val = b.bytes(int(b.uint8()))
case formDwarfBlock2:
val = b.bytes(int(b.uint16()))
case formDwarfBlock4:
val = b.bytes(int(b.uint32()))
case formDwarfBlock:
val = b.bytes(int(b.uint()))
// constant
case formData1:
val = int64(b.uint8())
case formData2:
val = int64(b.uint16())
case formData4:
val = int64(b.uint32())
case formData8:
val = int64(b.uint64())
case formData16:
val = b.bytes(16)
case formSdata:
val = int64(b.int())
case formUdata:
val = int64(b.uint())
case formImplicitConst:
val = a.field[i].val
// flag
case formFlag:
val = b.uint8() == 1
// New in DWARF 4.
case formFlagPresent:
// The attribute is implicitly indicated as present, and no value is
// encoded in the debugging information entry itself.
val = true
// reference to other entry
case formRefAddr:
vers := b.format.version()
if vers == 0 {
b.error("unknown version for DW_FORM_ref_addr")
} else if vers == 2 {
val = Offset(b.addr())
} else {
is64, known := b.format.dwarf64()
if !known {
b.error("unknown size for DW_FORM_ref_addr")
} else if is64 {
val = Offset(b.uint64())
} else {
val = Offset(b.uint32())
}
}
case formRef1:
val = Offset(b.uint8()) + ubase
case formRef2:
val = Offset(b.uint16()) + ubase
case formRef4:
val = Offset(b.uint32()) + ubase
case formRef8:
val = Offset(b.uint64()) + ubase
case formRefUdata:
val = Offset(b.uint()) + ubase
// string
case formString:
val = b.string()
case formStrp, formLineStrp:
var off uint64 // offset into .debug_str
is64, known := b.format.dwarf64()
if !known {
b.error("unknown size for DW_FORM_strp/line_strp")
} else if is64 {
off = b.uint64()
} else {
off = uint64(b.uint32())
}
if uint64(int(off)) != off {
b.error("DW_FORM_strp/line_strp offset out of range")
}
if b.err != nil {
return nil
}
var b1 buf
if fmt == formStrp {
b1 = makeBuf(b.dwarf, b.format, "str", 0, b.dwarf.str)
} else {
if len(b.dwarf.lineStr) == 0 {
b.error("DW_FORM_line_strp with no .debug_line_str section")
return nil
}
b1 = makeBuf(b.dwarf, b.format, "line_str", 0, b.dwarf.lineStr)
}
b1.skip(int(off))
val = b1.string()
if b1.err != nil {
b.err = b1.err
return nil
}
case formStrx, formStrx1, formStrx2, formStrx3, formStrx4:
var off uint64
switch fmt {
case formStrx:
off = b.uint()
case formStrx1:
off = uint64(b.uint8())
case formStrx2:
off = uint64(b.uint16())
case formStrx3:
off = uint64(b.uint24())
case formStrx4:
off = uint64(b.uint32())
}
if len(b.dwarf.strOffsets) == 0 {
b.error("DW_FORM_strx with no .debug_str_offsets section")
}
is64, known := b.format.dwarf64()
if !known {
b.error("unknown offset size for DW_FORM_strx")
}
if b.err != nil {
return nil
}
if is64 {
off *= 8
} else {
off *= 4
}
// We have to adjust by the offset of the
// compilation unit. This won't work if the
// program uses Reader.Seek to skip over the
// unit. Not much we can do about that.
var strBase int64
if cu != nil {
strBase, _ = cu.Val(AttrStrOffsetsBase).(int64)
} else if a.tag == TagCompileUnit {
delay = append(delay, delayed{i, off, formStrx})
break
}
val = resolveStrx(uint64(strBase), off)
case formStrpSup:
is64, known := b.format.dwarf64()
if !known {
b.error("unknown size for DW_FORM_strp_sup")
} else if is64 {
val = b.uint64()
} else {
val = b.uint32()
}
// lineptr, loclistptr, macptr, rangelistptr
// New in DWARF 4, but clang can generate them with -gdwarf-2.
// Section reference, replacing use of formData4 and formData8.
case formSecOffset, formGnuRefAlt, formGnuStrpAlt:
is64, known := b.format.dwarf64()
if !known {
b.error("unknown size for form 0x" + strconv.FormatInt(int64(fmt), 16))
} else if is64 {
val = int64(b.uint64())
} else {
val = int64(b.uint32())
}
// exprloc
// New in DWARF 4.
case formExprloc:
val = b.bytes(int(b.uint()))
// reference
// New in DWARF 4.
case formRefSig8:
// 64-bit type signature.
val = b.uint64()
case formRefSup4:
val = b.uint32()
case formRefSup8:
val = b.uint64()
// loclist
case formLoclistx:
val = b.uint()
// rnglist
case formRnglistx:
off := b.uint()
// We have to adjust by the rnglists_base of
// the compilation unit. This won't work if
// the program uses Reader.Seek to skip over
// the unit. Not much we can do about that.
var rnglistsBase int64
if cu != nil {
rnglistsBase, _ = cu.Val(AttrRnglistsBase).(int64)
} else if a.tag == TagCompileUnit {
delay = append(delay, delayed{i, off, formRnglistx})
break
}
val = resolveRnglistx(uint64(rnglistsBase), off)
}
e.Field[i].Val = val
}
if b.err != nil {
return nil
}
for _, del := range delay {
switch del.fmt {
case formAddrx:
addrBase, _ := e.Val(AttrAddrBase).(int64)
val, err := b.dwarf.debugAddr(b.format, uint64(addrBase), del.off)
if err != nil {
b.err = err
return nil
}
e.Field[del.idx].Val = val
case formStrx:
strBase, _ := e.Val(AttrStrOffsetsBase).(int64)
e.Field[del.idx].Val = resolveStrx(uint64(strBase), del.off)
if b.err != nil {
return nil
}
case formRnglistx:
rnglistsBase, _ := e.Val(AttrRnglistsBase).(int64)
e.Field[del.idx].Val = resolveRnglistx(uint64(rnglistsBase), del.off)
if b.err != nil {
return nil
}
}
}
return e
}
// A Reader allows reading Entry structures from a DWARF “info” section.
// The Entry structures are arranged in a tree. The Reader's Next function
// returns successive entries from a pre-order traversal of the tree.
// If an entry has children, its Children field will be true, and the children
// follow, terminated by an Entry with Tag 0.
type Reader struct {
b buf
d *Data
err error
unit int
lastUnit bool // set if last entry returned by Next is TagCompileUnit/TagPartialUnit
lastChildren bool // .Children of last entry returned by Next
lastSibling Offset // .Val(AttrSibling) of last entry returned by Next
cu *Entry // current compilation unit
}
// Reader returns a new Reader for Data.
// The reader is positioned at byte offset 0 in the DWARF “info” section.
func (d *Data) Reader() *Reader {
r := &Reader{d: d}
r.Seek(0)
return r
}
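// A minimal sketch of walking the whole entry tree (illustrative only;
// assumes d is a *Data obtained from, e.g., debug/elf):
//
//	r := d.Reader()
//	for {
//		e, err := r.Next()
//		if err != nil {
//			break // malformed data
//		}
//		if e == nil {
//			break // end of the info section
//		}
//		if e.Tag == TagSubprogram {
//			name, _ := e.Val(AttrName).(string)
//			_ = name
//		}
//	}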
// AddressSize returns the size in bytes of addresses in the current compilation
// unit.
func (r *Reader) AddressSize() int {
return r.d.unit[r.unit].asize
}
// ByteOrder returns the byte order in the current compilation unit.
func (r *Reader) ByteOrder() binary.ByteOrder {
return r.b.order
}
// Seek positions the Reader at offset off in the encoded entry stream.
// Offset 0 can be used to denote the first entry.
func (r *Reader) Seek(off Offset) {
d := r.d
r.err = nil
r.lastChildren = false
if off == 0 {
if len(d.unit) == 0 {
return
}
u := &d.unit[0]
r.unit = 0
r.b = makeBuf(r.d, u, "info", u.off, u.data)
r.cu = nil
return
}
i := d.offsetToUnit(off)
if i == -1 {
r.err = errors.New("offset out of range")
return
}
if i != r.unit {
r.cu = nil
}
u := &d.unit[i]
r.unit = i
r.b = makeBuf(r.d, u, "info", off, u.data[off-u.off:])
}
// maybeNextUnit advances to the next unit if this one is finished.
func (r *Reader) maybeNextUnit() {
for len(r.b.data) == 0 && r.unit+1 < len(r.d.unit) {
r.nextUnit()
}
}
// nextUnit advances to the next unit.
func (r *Reader) nextUnit() {
r.unit++
u := &r.d.unit[r.unit]
r.b = makeBuf(r.d, u, "info", u.off, u.data)
r.cu = nil
}
// Next reads the next entry from the encoded entry stream.
// It returns nil, nil when it reaches the end of the section.
// It returns an error if the current offset is invalid or the data at the
// offset cannot be decoded as a valid Entry.
func (r *Reader) Next() (*Entry, error) {
if r.err != nil {
return nil, r.err
}
r.maybeNextUnit()
if len(r.b.data) == 0 {
return nil, nil
}
u := &r.d.unit[r.unit]
e := r.b.entry(r.cu, u.atable, u.base, u.vers)
if r.b.err != nil {
r.err = r.b.err
return nil, r.err
}
r.lastUnit = false
if e != nil {
r.lastChildren = e.Children
if r.lastChildren {
r.lastSibling, _ = e.Val(AttrSibling).(Offset)
}
if e.Tag == TagCompileUnit || e.Tag == TagPartialUnit {
r.lastUnit = true
r.cu = e
}
} else {
r.lastChildren = false
}
return e, nil
}
// SkipChildren skips over the child entries associated with
// the last Entry returned by Next. If that Entry did not have
// children or Next has not been called, SkipChildren is a no-op.
func (r *Reader) SkipChildren() {
if r.err != nil || !r.lastChildren {
return
}
// If the last entry had a sibling attribute,
// that attribute gives the offset of the next
// sibling, so we can avoid decoding the
// child subtrees.
if r.lastSibling >= r.b.off {
r.Seek(r.lastSibling)
return
}
if r.lastUnit && r.unit+1 < len(r.d.unit) {
r.nextUnit()
return
}
for {
e, err := r.Next()
if err != nil || e == nil || e.Tag == 0 {
break
}
if e.Children {
r.SkipChildren()
}
}
}
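// For example (illustrative, not from the original source), after Next
// has returned an entry whose Children field is true, this visits only
// that entry's immediate children:
//
//	for {
//		e, err := r.Next()
//		if err != nil || e == nil || e.Tag == 0 {
//			break // error, end of section, or end of the children
//		}
//		// e is a direct child; skip e's own subtree.
//		r.SkipChildren()
//	}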
// clone returns a copy of the reader. This is used by the typeReader
// interface.
func (r *Reader) clone() typeReader {
return r.d.Reader()
}
// offset returns the current buffer offset. This is used by the
// typeReader interface.
func (r *Reader) offset() Offset {
return r.b.off
}
// SeekPC returns the Entry for the compilation unit that includes pc,
// and positions the reader to read the children of that unit. If pc
// is not covered by any unit, SeekPC returns ErrUnknownPC and the
// position of the reader is undefined.
//
// Because compilation units can describe multiple regions of the
// executable, in the worst case SeekPC must search through all the
// ranges in all the compilation units. Each call to SeekPC starts the
// search at the compilation unit of the last call, so in general
// looking up a series of PCs will be faster if they are sorted. If
// the caller wishes to do repeated fast PC lookups, it should build
// an appropriate index using the Ranges method.
func (r *Reader) SeekPC(pc uint64) (*Entry, error) {
unit := r.unit
for i := 0; i < len(r.d.unit); i++ {
if unit >= len(r.d.unit) {
unit = 0
}
r.err = nil
r.lastChildren = false
r.unit = unit
r.cu = nil
u := &r.d.unit[unit]
r.b = makeBuf(r.d, u, "info", u.off, u.data)
e, err := r.Next()
if err != nil || e == nil || e.Tag == 0 {
return nil, err
}
ranges, err := r.d.Ranges(e)
if err != nil {
return nil, err
}
for _, pcs := range ranges {
if pcs[0] <= pc && pc < pcs[1] {
return e, nil
}
}
unit++
}
return nil, ErrUnknownPC
}
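// Illustrative usage (not part of the original source), assuming pc is a
// program counter of interest:
//
//	cu, err := r.SeekPC(pc)
//	if err == ErrUnknownPC {
//		// pc is not covered by any compilation unit
//	} else if err == nil {
//		// cu is the TagCompileUnit entry covering pc; its line
//		// table, if any, is available via d.LineReader(cu).
//	}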
// Ranges returns the PC ranges covered by e, a slice of [low,high) pairs.
// Only some entry types, such as TagCompileUnit or TagSubprogram, have PC
// ranges; for others, this will return nil with no error.
func (d *Data) Ranges(e *Entry) ([][2]uint64, error) {
var ret [][2]uint64
low, lowOK := e.Val(AttrLowpc).(uint64)
var high uint64
var highOK bool
highField := e.AttrField(AttrHighpc)
if highField != nil {
switch highField.Class {
case ClassAddress:
high, highOK = highField.Val.(uint64)
case ClassConstant:
off, ok := highField.Val.(int64)
if ok {
high = low + uint64(off)
highOK = true
}
}
}
if lowOK && highOK {
ret = append(ret, [2]uint64{low, high})
}
var u *unit
if uidx := d.offsetToUnit(e.Offset); uidx >= 0 && uidx < len(d.unit) {
u = &d.unit[uidx]
}
if u != nil && u.vers >= 5 && d.rngLists != nil {
// DWARF version 5 and later
field := e.AttrField(AttrRanges)
if field == nil {
return ret, nil
}
switch field.Class {
case ClassRangeListPtr:
ranges, rangesOK := field.Val.(int64)
if !rangesOK {
return ret, nil
}
cu, base, err := d.baseAddressForEntry(e)
if err != nil {
return nil, err
}
return d.dwarf5Ranges(u, cu, base, ranges, ret)
case ClassRngList:
rnglist, ok := field.Val.(uint64)
if !ok {
return ret, nil
}
cu, base, err := d.baseAddressForEntry(e)
if err != nil {
return nil, err
}
return d.dwarf5Ranges(u, cu, base, int64(rnglist), ret)
default:
return ret, nil
}
}
// DWARF version 2 through 4
ranges, rangesOK := e.Val(AttrRanges).(int64)
if rangesOK && d.ranges != nil {
_, base, err := d.baseAddressForEntry(e)
if err != nil {
return nil, err
}
return d.dwarf2Ranges(u, base, ranges, ret)
}
return ret, nil
}
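// Illustrative usage (not part of the original source): testing whether a
// pc falls within the ranges of an entry e returned by a Reader:
//
//	rngs, err := d.Ranges(e)
//	if err == nil {
//		for _, rg := range rngs {
//			if rg[0] <= pc && pc < rg[1] {
//				// pc is covered by e
//			}
//		}
//	}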
// baseAddressForEntry returns the initial base address to be used when
// looking up the range list of entry e.
// DWARF specifies that this should be the lowpc attribute of the enclosing
// compilation unit; however, comments in gdb/dwarf2read.c say that some
// versions of GCC use the entrypc attribute, so we check that too.
func (d *Data) baseAddressForEntry(e *Entry) (*Entry, uint64, error) {
var cu *Entry
if e.Tag == TagCompileUnit {
cu = e
} else {
i := d.offsetToUnit(e.Offset)
if i == -1 {
return nil, 0, errors.New("no unit for entry")
}
u := &d.unit[i]
b := makeBuf(d, u, "info", u.off, u.data)
cu = b.entry(nil, u.atable, u.base, u.vers)
if b.err != nil {
return nil, 0, b.err
}
}
if cuEntry, cuEntryOK := cu.Val(AttrEntrypc).(uint64); cuEntryOK {
return cu, cuEntry, nil
} else if cuLow, cuLowOK := cu.Val(AttrLowpc).(uint64); cuLowOK {
return cu, cuLow, nil
}
return cu, 0, nil
}
func (d *Data) dwarf2Ranges(u *unit, base uint64, ranges int64, ret [][2]uint64) ([][2]uint64, error) {
if ranges < 0 || ranges > int64(len(d.ranges)) {
return nil, fmt.Errorf("invalid range offset %d (max %d)", ranges, len(d.ranges))
}
buf := makeBuf(d, u, "ranges", Offset(ranges), d.ranges[ranges:])
for len(buf.data) > 0 {
low := buf.addr()
high := buf.addr()
if low == 0 && high == 0 {
break
}
if low == ^uint64(0)>>uint((8-u.addrsize())*8) {
base = high
} else {
ret = append(ret, [2]uint64{base + low, base + high})
}
}
return ret, nil
}
// dwarf5Ranges interprets a debug_rnglists sequence, see DWARFv5 section
// 2.17.3 (page 53).
func (d *Data) dwarf5Ranges(u *unit, cu *Entry, base uint64, ranges int64, ret [][2]uint64) ([][2]uint64, error) {
if ranges < 0 || ranges > int64(len(d.rngLists)) {
return nil, fmt.Errorf("invalid rnglist offset %d (max %d)", ranges, len(d.ranges))
}
var addrBase int64
if cu != nil {
addrBase, _ = cu.Val(AttrAddrBase).(int64)
}
buf := makeBuf(d, u, "rnglists", 0, d.rngLists)
buf.skip(int(ranges))
for {
opcode := buf.uint8()
switch opcode {
case rleEndOfList:
if buf.err != nil {
return nil, buf.err
}
return ret, nil
case rleBaseAddressx:
baseIdx := buf.uint()
var err error
base, err = d.debugAddr(u, uint64(addrBase), baseIdx)
if err != nil {
return nil, err
}
case rleStartxEndx:
startIdx := buf.uint()
endIdx := buf.uint()
start, err := d.debugAddr(u, uint64(addrBase), startIdx)
if err != nil {
return nil, err
}
end, err := d.debugAddr(u, uint64(addrBase), endIdx)
if err != nil {
return nil, err
}
ret = append(ret, [2]uint64{start, end})
case rleStartxLength:
startIdx := buf.uint()
len := buf.uint()
start, err := d.debugAddr(u, uint64(addrBase), startIdx)
if err != nil {
return nil, err
}
ret = append(ret, [2]uint64{start, start + len})
case rleOffsetPair:
off1 := buf.uint()
off2 := buf.uint()
ret = append(ret, [2]uint64{base + off1, base + off2})
case rleBaseAddress:
base = buf.addr()
case rleStartEnd:
start := buf.addr()
end := buf.addr()
ret = append(ret, [2]uint64{start, end})
case rleStartLength:
start := buf.addr()
len := buf.uint()
ret = append(ret, [2]uint64{start, start + len})
}
}
}
// debugAddr returns the address at idx in debug_addr.
func (d *Data) debugAddr(format dataFormat, addrBase, idx uint64) (uint64, error) {
off := idx*uint64(format.addrsize()) + addrBase
if uint64(int(off)) != off {
return 0, errors.New("offset out of range")
}
b := makeBuf(d, format, "addr", 0, d.addr)
b.skip(int(off))
val := b.addr()
if b.err != nil {
return 0, b.err
}
return val, nil
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dwarf
import (
"errors"
"fmt"
"io"
"path"
"strings"
)
// A LineReader reads a sequence of LineEntry structures from a DWARF
// "line" section for a single compilation unit. LineEntries occur in
// order of increasing PC and each LineEntry gives metadata for the
// instructions from that LineEntry's PC to just before the next
// LineEntry's PC. The last entry will have its EndSequence field set.
type LineReader struct {
buf buf
// Original .debug_line section data. Used by Seek.
section []byte
str []byte // .debug_str
lineStr []byte // .debug_line_str
// Header information
version uint16
addrsize int
segmentSelectorSize int
minInstructionLength int
maxOpsPerInstruction int
defaultIsStmt bool
lineBase int
lineRange int
opcodeBase int
opcodeLengths []int
directories []string
fileEntries []*LineFile
programOffset Offset // section offset of line number program
endOffset Offset // section offset of byte following program
initialFileEntries int // initial length of fileEntries
// Current line number program state machine registers
state LineEntry // public state
fileIndex int // private state
}
// A LineEntry is a row in a DWARF line table.
type LineEntry struct {
// Address is the program-counter value of a machine
// instruction generated by the compiler. This LineEntry
// applies to each instruction from Address to just before the
// Address of the next LineEntry.
Address uint64
// OpIndex is the index of an operation within a VLIW
// instruction. The index of the first operation is 0. For
// non-VLIW architectures, it will always be 0. Address and
// OpIndex together form an operation pointer that can
// reference any individual operation within the instruction
// stream.
OpIndex int
// File is the source file corresponding to these
// instructions.
File *LineFile
// Line is the source code line number corresponding to these
// instructions. Lines are numbered beginning at 1. It may be
// 0 if these instructions cannot be attributed to any source
// line.
Line int
// Column is the column number within the source line of these
// instructions. Columns are numbered beginning at 1. It may
// be 0 to indicate the "left edge" of the line.
Column int
// IsStmt indicates that Address is a recommended breakpoint
// location, such as the beginning of a line, statement, or a
// distinct subpart of a statement.
IsStmt bool
// BasicBlock indicates that Address is the beginning of a
// basic block.
BasicBlock bool
// PrologueEnd indicates that Address is one (of possibly
// many) PCs where execution should be suspended for a
// breakpoint on entry to the containing function.
//
// Added in DWARF 3.
PrologueEnd bool
// EpilogueBegin indicates that Address is one (of possibly
// many) PCs where execution should be suspended for a
// breakpoint on exit from this function.
//
// Added in DWARF 3.
EpilogueBegin bool
// ISA is the instruction set architecture for these
// instructions. Possible ISA values should be defined by the
// applicable ABI specification.
//
// Added in DWARF 3.
ISA int
// Discriminator is an arbitrary integer indicating the block
// to which these instructions belong. It serves to
// distinguish among multiple blocks that may all be associated
// with the same source file, line, and column. Where only one
// block exists for a given source position, it should be 0.
//
// Added in DWARF 4.
Discriminator int
// EndSequence indicates that Address is the first byte after
// the end of a sequence of target machine instructions. If it
// is set, only this and the Address field are meaningful. A
// line number table may contain information for multiple
// potentially disjoint instruction sequences. The last entry
// in a line table should always have EndSequence set.
EndSequence bool
}
// A LineFile is a source file referenced by a DWARF line table entry.
type LineFile struct {
Name string
Mtime uint64 // Implementation defined modification time, or 0 if unknown
Length int // File length, or 0 if unknown
}
// LineReader returns a new reader for the line table of compilation
// unit cu, which must be an Entry with tag TagCompileUnit.
//
// If this compilation unit has no line table, it returns nil, nil.
func (d *Data) LineReader(cu *Entry) (*LineReader, error) {
if d.line == nil {
// No line tables available.
return nil, nil
}
// Get line table information from cu.
off, ok := cu.Val(AttrStmtList).(int64)
if !ok {
// cu has no line table.
return nil, nil
}
if off < 0 || off > int64(len(d.line)) {
return nil, errors.New("AttrStmtList value out of range")
}
// AttrCompDir is optional if all file names are absolute. Use
// the empty string if it's not present.
compDir, _ := cu.Val(AttrCompDir).(string)
// Create the LineReader.
u := &d.unit[d.offsetToUnit(cu.Offset)]
buf := makeBuf(d, u, "line", Offset(off), d.line[off:])
// The compilation directory is implicitly directories[0].
r := LineReader{
buf: buf,
section: d.line,
str: d.str,
lineStr: d.lineStr,
}
// Read the header.
if err := r.readHeader(compDir); err != nil {
return nil, err
}
// Initialize line reader state.
r.Reset()
return &r, nil
}
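// A minimal sketch of iterating a line table (illustrative only; cu is
// assumed to be a TagCompileUnit entry returned by a Reader over d):
//
//	lr, err := d.LineReader(cu)
//	if err != nil || lr == nil {
//		return // malformed or missing line table
//	}
//	var le LineEntry
//	for {
//		if err := lr.Next(&le); err != nil {
//			break // io.EOF at a properly terminated table
//		}
//		// use le.Address, le.Line, and (if non-nil) le.File
//	}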
// readHeader reads the line number program header from r.buf and sets
// all of the header fields in r.
func (r *LineReader) readHeader(compDir string) error {
buf := &r.buf
// Read basic header fields [DWARF2 6.2.4].
hdrOffset := buf.off
unitLength, dwarf64 := buf.unitLength()
r.endOffset = buf.off + unitLength
if r.endOffset > buf.off+Offset(len(buf.data)) {
return DecodeError{"line", hdrOffset, fmt.Sprintf("line table end %d exceeds section size %d", r.endOffset, buf.off+Offset(len(buf.data)))}
}
r.version = buf.uint16()
if buf.err == nil && (r.version < 2 || r.version > 5) {
// DWARF goes to all this effort to make new opcodes
// backward-compatible, and then adds fields right in
// the middle of the header in new versions, so we're
// picky about only supporting known line table
// versions.
return DecodeError{"line", hdrOffset, fmt.Sprintf("unknown line table version %d", r.version)}
}
if r.version >= 5 {
r.addrsize = int(buf.uint8())
r.segmentSelectorSize = int(buf.uint8())
} else {
r.addrsize = buf.format.addrsize()
r.segmentSelectorSize = 0
}
var headerLength Offset
if dwarf64 {
headerLength = Offset(buf.uint64())
} else {
headerLength = Offset(buf.uint32())
}
programOffset := buf.off + headerLength
if programOffset > r.endOffset {
return DecodeError{"line", hdrOffset, fmt.Sprintf("malformed line table: program offset %d exceeds end offset %d", programOffset, r.endOffset)}
}
r.programOffset = programOffset
r.minInstructionLength = int(buf.uint8())
if r.version >= 4 {
// [DWARF4 6.2.4]
r.maxOpsPerInstruction = int(buf.uint8())
} else {
r.maxOpsPerInstruction = 1
}
r.defaultIsStmt = buf.uint8() != 0
r.lineBase = int(int8(buf.uint8()))
r.lineRange = int(buf.uint8())
// Validate header.
if buf.err != nil {
return buf.err
}
if r.maxOpsPerInstruction == 0 {
return DecodeError{"line", hdrOffset, "invalid maximum operations per instruction: 0"}
}
if r.lineRange == 0 {
return DecodeError{"line", hdrOffset, "invalid line range: 0"}
}
// Read standard opcode length table. This table starts with opcode 1.
r.opcodeBase = int(buf.uint8())
r.opcodeLengths = make([]int, r.opcodeBase)
for i := 1; i < r.opcodeBase; i++ {
r.opcodeLengths[i] = int(buf.uint8())
}
// Validate opcode lengths.
if buf.err != nil {
return buf.err
}
for i, length := range r.opcodeLengths {
if known, ok := knownOpcodeLengths[i]; ok && known != length {
return DecodeError{"line", hdrOffset, fmt.Sprintf("opcode %d expected to have length %d, but has length %d", i, known, length)}
}
}
if r.version < 5 {
// Read include directories table.
r.directories = []string{compDir}
for {
directory := buf.string()
if buf.err != nil {
return buf.err
}
if len(directory) == 0 {
break
}
if !pathIsAbs(directory) {
// Relative paths are implicitly relative to
// the compilation directory.
directory = pathJoin(compDir, directory)
}
r.directories = append(r.directories, directory)
}
// Read file name list. File numbering starts with 1,
// so leave the first entry nil.
r.fileEntries = make([]*LineFile, 1)
for {
if done, err := r.readFileEntry(); err != nil {
return err
} else if done {
break
}
}
} else {
dirFormat := r.readLNCTFormat()
c := buf.uint()
r.directories = make([]string, c)
for i := range r.directories {
dir, _, _, err := r.readLNCT(dirFormat, dwarf64)
if err != nil {
return err
}
r.directories[i] = dir
}
fileFormat := r.readLNCTFormat()
c = buf.uint()
r.fileEntries = make([]*LineFile, c)
for i := range r.fileEntries {
name, mtime, size, err := r.readLNCT(fileFormat, dwarf64)
if err != nil {
return err
}
r.fileEntries[i] = &LineFile{name, mtime, int(size)}
}
}
r.initialFileEntries = len(r.fileEntries)
return buf.err
}
// lnctForm is a pair of an LNCT code and a form. This represents an
// entry in the directory name or file name description in the DWARF 5
// line number program header.
type lnctForm struct {
lnct int
form format
}
// readLNCTFormat reads an LNCT format description.
func (r *LineReader) readLNCTFormat() []lnctForm {
c := r.buf.uint8()
ret := make([]lnctForm, c)
for i := range ret {
ret[i].lnct = int(r.buf.uint())
ret[i].form = format(r.buf.uint())
}
return ret
}
// readLNCT reads a sequence of LNCT entries and returns path information.
func (r *LineReader) readLNCT(s []lnctForm, dwarf64 bool) (path string, mtime uint64, size uint64, err error) {
var dir string
for _, lf := range s {
var str string
var val uint64
switch lf.form {
case formString:
str = r.buf.string()
case formStrp, formLineStrp:
var off uint64
if dwarf64 {
off = r.buf.uint64()
} else {
off = uint64(r.buf.uint32())
}
if uint64(int(off)) != off {
return "", 0, 0, DecodeError{"line", r.buf.off, "strp/line_strp offset out of range"}
}
var b1 buf
if lf.form == formStrp {
b1 = makeBuf(r.buf.dwarf, r.buf.format, "str", 0, r.str)
} else {
b1 = makeBuf(r.buf.dwarf, r.buf.format, "line_str", 0, r.lineStr)
}
b1.skip(int(off))
str = b1.string()
if b1.err != nil {
return "", 0, 0, DecodeError{"line", r.buf.off, b1.err.Error()}
}
case formStrpSup:
// Supplemental sections not yet supported.
if dwarf64 {
r.buf.uint64()
} else {
r.buf.uint32()
}
case formStrx:
// .debug_line.dwo sections not yet supported.
r.buf.uint()
case formStrx1:
r.buf.uint8()
case formStrx2:
r.buf.uint16()
case formStrx3:
r.buf.uint24()
case formStrx4:
r.buf.uint32()
case formData1:
val = uint64(r.buf.uint8())
case formData2:
val = uint64(r.buf.uint16())
case formData4:
val = uint64(r.buf.uint32())
case formData8:
val = r.buf.uint64()
case formData16:
r.buf.bytes(16)
case formDwarfBlock:
r.buf.bytes(int(r.buf.uint()))
case formUdata:
val = r.buf.uint()
}
switch lf.lnct {
case lnctPath:
path = str
case lnctDirectoryIndex:
if val >= uint64(len(r.directories)) {
return "", 0, 0, DecodeError{"line", r.buf.off, "directory index out of range"}
}
dir = r.directories[val]
case lnctTimestamp:
mtime = val
case lnctSize:
size = val
case lnctMD5:
// Ignored.
}
}
if dir != "" && path != "" {
path = pathJoin(dir, path)
}
return path, mtime, size, nil
}
// readFileEntry reads a file entry from either the header or a
// DW_LNE_define_file extended opcode and adds it to r.fileEntries. A
// true return value indicates that there are no more entries to read.
func (r *LineReader) readFileEntry() (bool, error) {
name := r.buf.string()
if r.buf.err != nil {
return false, r.buf.err
}
if len(name) == 0 {
return true, nil
}
off := r.buf.off
dirIndex := int(r.buf.uint())
if !pathIsAbs(name) {
if dirIndex >= len(r.directories) {
return false, DecodeError{"line", off, "directory index too large"}
}
name = pathJoin(r.directories[dirIndex], name)
}
mtime := r.buf.uint()
length := int(r.buf.uint())
// If this is a dynamically added path and the cursor was
// backed up, we may have already added this entry. Avoid
// updating existing line table entries in this case. This
// avoids an allocation and potential racy access to the slice
// backing store if the user called Files.
if len(r.fileEntries) < cap(r.fileEntries) {
fe := r.fileEntries[:len(r.fileEntries)+1]
if fe[len(fe)-1] != nil {
// We already processed this addition.
r.fileEntries = fe
return false, nil
}
}
r.fileEntries = append(r.fileEntries, &LineFile{name, mtime, length})
return false, nil
}
// updateFile updates r.state.File after r.fileIndex has
// changed or r.fileEntries has changed.
func (r *LineReader) updateFile() {
if r.fileIndex < len(r.fileEntries) {
r.state.File = r.fileEntries[r.fileIndex]
} else {
r.state.File = nil
}
}
// Next sets *entry to the next row in this line table and moves to
// the next row. If there are no more entries and the line table is
// properly terminated, it returns io.EOF.
//
// Rows are always in order of increasing entry.Address, but
// entry.Line may go forward or backward.
func (r *LineReader) Next(entry *LineEntry) error {
if r.buf.err != nil {
return r.buf.err
}
// Execute opcodes until we reach an opcode that emits a line
// table entry.
for {
if len(r.buf.data) == 0 {
return io.EOF
}
emit := r.step(entry)
if r.buf.err != nil {
return r.buf.err
}
if emit {
return nil
}
}
}
// knownOpcodeLengths gives the opcode lengths (in varint arguments)
// of known standard opcodes.
var knownOpcodeLengths = map[int]int{
lnsCopy: 0,
lnsAdvancePC: 1,
lnsAdvanceLine: 1,
lnsSetFile: 1,
lnsNegateStmt: 0,
lnsSetBasicBlock: 0,
lnsConstAddPC: 0,
lnsSetPrologueEnd: 0,
lnsSetEpilogueBegin: 0,
lnsSetISA: 1,
// lnsFixedAdvancePC takes a uint8 rather than a varint; it's
// unclear what length the header is supposed to claim, so
// ignore it.
}
// step processes the next opcode and updates r.state. If the opcode
// emits a row in the line table, this updates *entry and returns
// true.
func (r *LineReader) step(entry *LineEntry) bool {
opcode := int(r.buf.uint8())
if opcode >= r.opcodeBase {
// Special opcode [DWARF2 6.2.5.1, DWARF4 6.2.5.1]
adjustedOpcode := opcode - r.opcodeBase
r.advancePC(adjustedOpcode / r.lineRange)
lineDelta := r.lineBase + adjustedOpcode%r.lineRange
r.state.Line += lineDelta
goto emit
}
switch opcode {
case 0:
// Extended opcode [DWARF2 6.2.5.3]
length := Offset(r.buf.uint())
startOff := r.buf.off
opcode := r.buf.uint8()
switch opcode {
case lneEndSequence:
r.state.EndSequence = true
*entry = r.state
r.resetState()
case lneSetAddress:
switch r.addrsize {
case 1:
r.state.Address = uint64(r.buf.uint8())
case 2:
r.state.Address = uint64(r.buf.uint16())
case 4:
r.state.Address = uint64(r.buf.uint32())
case 8:
r.state.Address = r.buf.uint64()
default:
r.buf.error("unknown address size")
}
case lneDefineFile:
if done, err := r.readFileEntry(); err != nil {
r.buf.err = err
return false
} else if done {
r.buf.err = DecodeError{"line", startOff, "malformed DW_LNE_define_file operation"}
return false
}
r.updateFile()
case lneSetDiscriminator:
// [DWARF4 6.2.5.3]
r.state.Discriminator = int(r.buf.uint())
}
// Skip any unconsumed bytes of the extended opcode.
r.buf.skip(int(startOff + length - r.buf.off))
if opcode == lneEndSequence {
return true
}
// Standard opcodes [DWARF2 6.2.5.2]
case lnsCopy:
goto emit
case lnsAdvancePC:
r.advancePC(int(r.buf.uint()))
case lnsAdvanceLine:
r.state.Line += int(r.buf.int())
case lnsSetFile:
r.fileIndex = int(r.buf.uint())
r.updateFile()
case lnsSetColumn:
r.state.Column = int(r.buf.uint())
case lnsNegateStmt:
r.state.IsStmt = !r.state.IsStmt
case lnsSetBasicBlock:
r.state.BasicBlock = true
case lnsConstAddPC:
r.advancePC((255 - r.opcodeBase) / r.lineRange)
case lnsFixedAdvancePC:
r.state.Address += uint64(r.buf.uint16())
// DWARF3 standard opcodes [DWARF3 6.2.5.2]
case lnsSetPrologueEnd:
r.state.PrologueEnd = true
case lnsSetEpilogueBegin:
r.state.EpilogueBegin = true
case lnsSetISA:
r.state.ISA = int(r.buf.uint())
default:
// Unhandled standard opcode. Skip the number of
// arguments that the prologue says this opcode has.
for i := 0; i < r.opcodeLengths[opcode]; i++ {
r.buf.uint()
}
}
return false
emit:
*entry = r.state
r.state.BasicBlock = false
r.state.PrologueEnd = false
r.state.EpilogueBegin = false
r.state.Discriminator = 0
return true
}
// advancePC advances the "operation pointer" (the combination of Address
// and OpIndex) in r.state by opAdvance steps.
func (r *LineReader) advancePC(opAdvance int) {
opIndex := r.state.OpIndex + opAdvance
r.state.Address += uint64(r.minInstructionLength * (opIndex / r.maxOpsPerInstruction))
r.state.OpIndex = opIndex % r.maxOpsPerInstruction
}
// A LineReaderPos represents a position in a line table.
type LineReaderPos struct {
// off is the current offset in the DWARF line section.
off Offset
// numFileEntries is the length of fileEntries.
numFileEntries int
// state and fileIndex are the statement machine state at
// offset off.
state LineEntry
fileIndex int
}
// Tell returns the current position in the line table.
func (r *LineReader) Tell() LineReaderPos {
return LineReaderPos{r.buf.off, len(r.fileEntries), r.state, r.fileIndex}
}
// Seek restores the line table reader to a position returned by Tell.
//
// The argument pos must have been returned by a call to Tell on this
// line table.
func (r *LineReader) Seek(pos LineReaderPos) {
r.buf.off = pos.off
r.buf.data = r.section[r.buf.off:r.endOffset]
r.fileEntries = r.fileEntries[:pos.numFileEntries]
r.state = pos.state
r.fileIndex = pos.fileIndex
}
// Reset repositions the line table reader at the beginning of the
// line table.
func (r *LineReader) Reset() {
// Reset buffer to the line number program offset.
r.buf.off = r.programOffset
r.buf.data = r.section[r.buf.off:r.endOffset]
// Reset file entries list.
r.fileEntries = r.fileEntries[:r.initialFileEntries]
// Reset line number program state.
r.resetState()
}
// resetState resets r.state to its default values.
func (r *LineReader) resetState() {
// Reset the state machine registers to the defaults given in
// [DWARF4 6.2.2].
r.state = LineEntry{
Address: 0,
OpIndex: 0,
File: nil,
Line: 1,
Column: 0,
IsStmt: r.defaultIsStmt,
BasicBlock: false,
PrologueEnd: false,
EpilogueBegin: false,
ISA: 0,
Discriminator: 0,
}
r.fileIndex = 1
r.updateFile()
}
// Files returns the file name table of this compilation unit as of
// the current position in the line table. The file name table may be
// referenced from attributes in this compilation unit such as
// AttrDeclFile.
//
// Entry 0 is always nil, since file index 0 represents "no file".
//
// The file name table of a compilation unit is not fixed. Files
// returns the file table as of the current position in the line
// table. This may contain more entries than the file table at an
// earlier position in the line table, though existing entries never
// change.
func (r *LineReader) Files() []*LineFile {
return r.fileEntries
}
// ErrUnknownPC is the error returned by LineReader.SeekPC when the
// seek PC is not covered by any entry in the line table.
var ErrUnknownPC = errors.New("ErrUnknownPC")
// SeekPC sets *entry to the LineEntry that includes pc and positions
// the reader on the next entry in the line table. If necessary, this
// will seek backwards to find pc.
//
// If pc is not covered by any entry in this line table, SeekPC
// returns ErrUnknownPC. In this case, *entry and the final seek
// position are unspecified.
//
// Note that DWARF line tables only permit sequential, forward scans.
// Hence, in the worst case, this takes time linear in the size of the
// line table. If the caller wishes to do repeated fast PC lookups, it
// should build an appropriate index of the line table.
func (r *LineReader) SeekPC(pc uint64, entry *LineEntry) error {
if err := r.Next(entry); err != nil {
return err
}
if entry.Address > pc {
// We're too far. Start at the beginning of the table.
r.Reset()
if err := r.Next(entry); err != nil {
return err
}
if entry.Address > pc {
// The whole table starts after pc.
r.Reset()
return ErrUnknownPC
}
}
// Scan until we pass pc, then back up one.
for {
var next LineEntry
pos := r.Tell()
if err := r.Next(&next); err != nil {
if err == io.EOF {
return ErrUnknownPC
}
return err
}
if next.Address > pc {
if entry.EndSequence {
// pc is in a hole in the table.
return ErrUnknownPC
}
// entry is the desired entry. Back up the
// cursor to "next" and return success.
r.Seek(pos)
return nil
}
*entry = next
}
}
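// A minimal client-side sketch of PC-to-line lookup (hedged; error handling
// elided and "./a.out" is a hypothetical path):
//
//	f, _ := elf.Open("./a.out")
//	d, _ := f.DWARF()
//	cu, _ := d.Reader().SeekPC(pc) // compile unit containing pc
//	lr, _ := d.LineReader(cu)
//	var le dwarf.LineEntry
//	switch err := lr.SeekPC(pc, &le); err {
//	case nil:
//		fmt.Println(le.File.Name, le.Line)
//	case dwarf.ErrUnknownPC:
//		// pc is in a hole in this line table.
//	}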
// pathIsAbs reports whether path is an absolute path (or "full path
// name" in DWARF parlance). This is in "whatever form makes sense for
// the host system", so this accepts both UNIX-style and DOS-style
// absolute paths. We avoid the filepath package because we want this
// to behave the same regardless of our host system and because we
// don't know what system the paths came from.
func pathIsAbs(path string) bool {
_, path = splitDrive(path)
return len(path) > 0 && (path[0] == '/' || path[0] == '\\')
}
// pathJoin joins dirname and filename. filename must be relative.
// DWARF paths can be UNIX-style or DOS-style, so this handles both.
func pathJoin(dirname, filename string) string {
if len(dirname) == 0 {
return filename
}
// dirname should be absolute, which means we can determine
// whether it's a DOS path reasonably reliably by looking for
// a drive letter or UNC path.
drive, dirname := splitDrive(dirname)
if drive == "" {
// UNIX-style path.
return path.Join(dirname, filename)
}
// DOS-style path.
drive2, filename := splitDrive(filename)
if drive2 != "" {
if !strings.EqualFold(drive, drive2) {
// Different drives. There's not much we can
// do here, so just ignore the directory.
return drive2 + filename
}
// Drives are the same. Ignore drive on filename.
}
if !(strings.HasSuffix(dirname, "/") || strings.HasSuffix(dirname, `\`)) && dirname != "" {
sep := `\`
if strings.HasPrefix(dirname, "/") {
sep = `/`
}
dirname += sep
}
return drive + dirname + filename
}
// splitDrive splits the DOS drive letter or UNC share point from
// path, if any. path == drive + rest
func splitDrive(path string) (drive, rest string) {
if len(path) >= 2 && path[1] == ':' {
if c := path[0]; 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' {
return path[:2], path[2:]
}
}
if len(path) > 3 && (path[0] == '\\' || path[0] == '/') && (path[1] == '\\' || path[1] == '/') {
// Normalize the path so we can search for just \ below.
npath := strings.Replace(path, "/", `\`, -1)
// Get the host part, which must be non-empty.
slash1 := strings.IndexByte(npath[2:], '\\') + 2
if slash1 > 2 {
// Get the mount-point part, which must be non-empty.
slash2 := strings.IndexByte(npath[slash1+1:], '\\') + slash1 + 1
if slash2 > slash1 {
return path[:slash2], path[slash2:]
}
}
}
return "", path
}
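// A hedged illustration of the rules above (expected results shown in
// trailing comments):
//
//	pathJoin("/usr/include", "stdio.h") // "/usr/include/stdio.h"
//	pathJoin(`C:\src`, "main.c")        // `C:\src\main.c`
//	pathJoin(`C:\src`, `D:\main.c`)     // `D:\main.c` (different drives; dirname ignored)
//	splitDrive(`\\host\share\dir`)      // drive `\\host\share`, rest `\dir`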
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package dwarf provides access to DWARF debugging information loaded from
executable files, as defined in the DWARF 2.0 Standard at
http://dwarfstd.org/doc/dwarf-2.0.0.pdf.
# Security
This package is not designed to be hardened against adversarial inputs, and is
outside the scope of https://go.dev/security/policy. In particular, only basic
validation is done when parsing object files. As such, care should be taken when
parsing untrusted inputs, as parsing malformed files may consume significant
resources, or cause panics.
*/
package dwarf
import (
"encoding/binary"
"errors"
)
// Data represents the DWARF debugging information
// loaded from an executable file (for example, an ELF or Mach-O executable).
type Data struct {
// raw data
abbrev []byte
aranges []byte
frame []byte
info []byte
line []byte
pubnames []byte
ranges []byte
str []byte
// New sections added in DWARF 5.
addr []byte
lineStr []byte
strOffsets []byte
rngLists []byte
// parsed data
abbrevCache map[uint64]abbrevTable
bigEndian bool
order binary.ByteOrder
typeCache map[Offset]Type
typeSigs map[uint64]*typeUnit
unit []unit
}
var errSegmentSelector = errors.New("non-zero segment_selector size not supported")
// New returns a new Data object initialized from the given parameters.
// Rather than calling this function directly, clients should typically use
// the DWARF method of the File type of the appropriate package debug/elf,
// debug/macho, or debug/pe.
//
// The []byte arguments are the data from the corresponding debug section
// in the object file; for example, for an ELF object, abbrev is the contents of
// the ".debug_abbrev" section.
func New(abbrev, aranges, frame, info, line, pubnames, ranges, str []byte) (*Data, error) {
d := &Data{
abbrev: abbrev,
aranges: aranges,
frame: frame,
info: info,
line: line,
pubnames: pubnames,
ranges: ranges,
str: str,
abbrevCache: make(map[uint64]abbrevTable),
typeCache: make(map[Offset]Type),
typeSigs: make(map[uint64]*typeUnit),
}
// Sniff .debug_info to figure out byte order.
// 32-bit DWARF: 4 byte length, 2 byte version.
// 64-bit DWARF: 4 bytes of 0xff, 8 byte length, 2 byte version.
if len(d.info) < 6 {
return nil, DecodeError{"info", Offset(len(d.info)), "too short"}
}
offset := 4
if d.info[0] == 0xff && d.info[1] == 0xff && d.info[2] == 0xff && d.info[3] == 0xff {
if len(d.info) < 14 {
return nil, DecodeError{"info", Offset(len(d.info)), "too short"}
}
offset = 12
}
// Fetch the version, a tiny 16-bit number (1, 2, 3, 4, 5).
x, y := d.info[offset], d.info[offset+1]
switch {
case x == 0 && y == 0:
return nil, DecodeError{"info", 4, "unsupported version 0"}
case x == 0:
d.bigEndian = true
d.order = binary.BigEndian
case y == 0:
d.bigEndian = false
d.order = binary.LittleEndian
default:
return nil, DecodeError{"info", 4, "cannot determine byte order"}
}
u, err := d.parseUnits()
if err != nil {
return nil, err
}
d.unit = u
return d, nil
}
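// A minimal construction sketch from a client package (hedged; error
// handling elided and "./a.out" is a hypothetical path):
//
//	f, _ := elf.Open("./a.out")
//	d, _ := f.DWARF() // calls New with the .debug_* section contents
//	r := d.Reader()
//	for {
//		e, err := r.Next()
//		if err != nil || e == nil {
//			break
//		}
//		if e.Tag == dwarf.TagCompileUnit {
//			fmt.Println(e.Val(dwarf.AttrName))
//		}
//	}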
// AddTypes will add one .debug_types section to the DWARF data. A
// typical object with DWARF version 4 debug info will have multiple
// .debug_types sections. The name is used for error reporting only,
// and serves to distinguish one .debug_types section from another.
func (d *Data) AddTypes(name string, types []byte) error {
return d.parseTypes(name, types)
}
// AddSection adds another DWARF section by name. The name should be a
// DWARF section name such as ".debug_addr", ".debug_str_offsets", and
// so forth. This approach is used for new DWARF sections added in
// DWARF 5 and later.
func (d *Data) AddSection(name string, contents []byte) error {
var err error
switch name {
case ".debug_addr":
d.addr = contents
case ".debug_line_str":
d.lineStr = contents
case ".debug_str_offsets":
d.strOffsets = contents
case ".debug_rnglists":
d.rngLists = contents
}
// Just ignore names that we don't yet support.
return err
}
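// A hedged usage sketch (contents is assumed to hold the raw bytes of the
// named section, read from the object file):
//
//	if err := d.AddSection(".debug_str_offsets", contents); err != nil {
//		// Unreachable today: err is currently always nil, even for
//		// unrecognized section names.
//	}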
// Code generated by "stringer -type Tag -trimprefix=Tag"; DO NOT EDIT.
package dwarf
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[TagArrayType-1]
_ = x[TagClassType-2]
_ = x[TagEntryPoint-3]
_ = x[TagEnumerationType-4]
_ = x[TagFormalParameter-5]
_ = x[TagImportedDeclaration-8]
_ = x[TagLabel-10]
_ = x[TagLexDwarfBlock-11]
_ = x[TagMember-13]
_ = x[TagPointerType-15]
_ = x[TagReferenceType-16]
_ = x[TagCompileUnit-17]
_ = x[TagStringType-18]
_ = x[TagStructType-19]
_ = x[TagSubroutineType-21]
_ = x[TagTypedef-22]
_ = x[TagUnionType-23]
_ = x[TagUnspecifiedParameters-24]
_ = x[TagVariant-25]
_ = x[TagCommonDwarfBlock-26]
_ = x[TagCommonInclusion-27]
_ = x[TagInheritance-28]
_ = x[TagInlinedSubroutine-29]
_ = x[TagModule-30]
_ = x[TagPtrToMemberType-31]
_ = x[TagSetType-32]
_ = x[TagSubrangeType-33]
_ = x[TagWithStmt-34]
_ = x[TagAccessDeclaration-35]
_ = x[TagBaseType-36]
_ = x[TagCatchDwarfBlock-37]
_ = x[TagConstType-38]
_ = x[TagConstant-39]
_ = x[TagEnumerator-40]
_ = x[TagFileType-41]
_ = x[TagFriend-42]
_ = x[TagNamelist-43]
_ = x[TagNamelistItem-44]
_ = x[TagPackedType-45]
_ = x[TagSubprogram-46]
_ = x[TagTemplateTypeParameter-47]
_ = x[TagTemplateValueParameter-48]
_ = x[TagThrownType-49]
_ = x[TagTryDwarfBlock-50]
_ = x[TagVariantPart-51]
_ = x[TagVariable-52]
_ = x[TagVolatileType-53]
_ = x[TagDwarfProcedure-54]
_ = x[TagRestrictType-55]
_ = x[TagInterfaceType-56]
_ = x[TagNamespace-57]
_ = x[TagImportedModule-58]
_ = x[TagUnspecifiedType-59]
_ = x[TagPartialUnit-60]
_ = x[TagImportedUnit-61]
_ = x[TagMutableType-62]
_ = x[TagCondition-63]
_ = x[TagSharedType-64]
_ = x[TagTypeUnit-65]
_ = x[TagRvalueReferenceType-66]
_ = x[TagTemplateAlias-67]
_ = x[TagCoarrayType-68]
_ = x[TagGenericSubrange-69]
_ = x[TagDynamicType-70]
_ = x[TagAtomicType-71]
_ = x[TagCallSite-72]
_ = x[TagCallSiteParameter-73]
_ = x[TagSkeletonUnit-74]
_ = x[TagImmutableType-75]
}
const (
_Tag_name_0 = "ArrayTypeClassTypeEntryPointEnumerationTypeFormalParameter"
_Tag_name_1 = "ImportedDeclaration"
_Tag_name_2 = "LabelLexDwarfBlock"
_Tag_name_3 = "Member"
_Tag_name_4 = "PointerTypeReferenceTypeCompileUnitStringTypeStructType"
_Tag_name_5 = "SubroutineTypeTypedefUnionTypeUnspecifiedParametersVariantCommonDwarfBlockCommonInclusionInheritanceInlinedSubroutineModulePtrToMemberTypeSetTypeSubrangeTypeWithStmtAccessDeclarationBaseTypeCatchDwarfBlockConstTypeConstantEnumeratorFileTypeFriendNamelistNamelistItemPackedTypeSubprogramTemplateTypeParameterTemplateValueParameterThrownTypeTryDwarfBlockVariantPartVariableVolatileTypeDwarfProcedureRestrictTypeInterfaceTypeNamespaceImportedModuleUnspecifiedTypePartialUnitImportedUnitMutableTypeConditionSharedTypeTypeUnitRvalueReferenceTypeTemplateAliasCoarrayTypeGenericSubrangeDynamicTypeAtomicTypeCallSiteCallSiteParameterSkeletonUnitImmutableType"
)
var (
_Tag_index_0 = [...]uint8{0, 9, 18, 28, 43, 58}
_Tag_index_2 = [...]uint8{0, 5, 18}
_Tag_index_4 = [...]uint8{0, 11, 24, 35, 45, 55}
_Tag_index_5 = [...]uint16{0, 14, 21, 30, 51, 58, 74, 89, 100, 117, 123, 138, 145, 157, 165, 182, 190, 205, 214, 222, 232, 240, 246, 254, 266, 276, 286, 307, 329, 339, 352, 363, 371, 383, 397, 409, 422, 431, 445, 460, 471, 483, 494, 503, 513, 521, 540, 553, 564, 579, 590, 600, 608, 625, 637, 650}
)
func (i Tag) String() string {
switch {
case 1 <= i && i <= 5:
i -= 1
return _Tag_name_0[_Tag_index_0[i]:_Tag_index_0[i+1]]
case i == 8:
return _Tag_name_1
case 10 <= i && i <= 11:
i -= 10
return _Tag_name_2[_Tag_index_2[i]:_Tag_index_2[i+1]]
case i == 13:
return _Tag_name_3
case 15 <= i && i <= 19:
i -= 15
return _Tag_name_4[_Tag_index_4[i]:_Tag_index_4[i+1]]
case 21 <= i && i <= 75:
i -= 21
return _Tag_name_5[_Tag_index_5[i]:_Tag_index_5[i+1]]
default:
return "Tag(" + strconv.FormatInt(int64(i), 10) + ")"
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// DWARF type information structures.
// The format is heavily biased toward C, but for simplicity
// the String methods use a pseudo-Go syntax.
package dwarf
import "strconv"
// A Type conventionally represents a pointer to any of the
// specific Type structures (CharType, StructType, etc.).
type Type interface {
Common() *CommonType
String() string
Size() int64
}
// A CommonType holds fields common to multiple types.
// If a field is not known or not applicable for a given type,
// the zero value is used.
type CommonType struct {
ByteSize int64 // size of value of this type, in bytes
Name string // name that can be used to refer to type
}
func (c *CommonType) Common() *CommonType { return c }
func (c *CommonType) Size() int64 { return c.ByteSize }
// Basic types
// A BasicType holds fields common to all basic types.
//
// See the documentation for StructField for more info on the interpretation of
// the BitSize/BitOffset/DataBitOffset fields.
type BasicType struct {
CommonType
BitSize int64
BitOffset int64
DataBitOffset int64
}
func (b *BasicType) Basic() *BasicType { return b }
func (t *BasicType) String() string {
if t.Name != "" {
return t.Name
}
return "?"
}
// A CharType represents a signed character type.
type CharType struct {
BasicType
}
// A UcharType represents an unsigned character type.
type UcharType struct {
BasicType
}
// An IntType represents a signed integer type.
type IntType struct {
BasicType
}
// A UintType represents an unsigned integer type.
type UintType struct {
BasicType
}
// A FloatType represents a floating point type.
type FloatType struct {
BasicType
}
// A ComplexType represents a complex floating point type.
type ComplexType struct {
BasicType
}
// A BoolType represents a boolean type.
type BoolType struct {
BasicType
}
// An AddrType represents a machine address type.
type AddrType struct {
BasicType
}
// An UnspecifiedType represents an implicit, unknown, ambiguous or nonexistent type.
type UnspecifiedType struct {
BasicType
}
// qualifiers
// A QualType represents a type that has the C/C++ "const", "restrict", or "volatile" qualifier.
type QualType struct {
CommonType
Qual string
Type Type
}
func (t *QualType) String() string { return t.Qual + " " + t.Type.String() }
func (t *QualType) Size() int64 { return t.Type.Size() }
// An ArrayType represents a fixed size array type.
type ArrayType struct {
CommonType
Type Type
StrideBitSize int64 // if > 0, number of bits to hold each element
Count int64 // if == -1, an incomplete array, like char x[].
}
func (t *ArrayType) String() string {
return "[" + strconv.FormatInt(t.Count, 10) + "]" + t.Type.String()
}
func (t *ArrayType) Size() int64 {
if t.Count == -1 {
return 0
}
return t.Count * t.Type.Size()
}
// A VoidType represents the C void type.
type VoidType struct {
CommonType
}
func (t *VoidType) String() string { return "void" }
// A PtrType represents a pointer type.
type PtrType struct {
CommonType
Type Type
}
func (t *PtrType) String() string { return "*" + t.Type.String() }
// A StructType represents a struct, union, or C++ class type.
type StructType struct {
CommonType
StructName string
Kind string // "struct", "union", or "class".
Field []*StructField
Incomplete bool // if true, struct, union, class is declared but not defined
}
// A StructField represents a field in a struct, union, or C++ class type.
//
// # Bit Fields
//
// The BitSize, BitOffset, and DataBitOffset fields describe the bit
// size and offset of data members declared as bit fields in C/C++
// struct/union/class types.
//
// BitSize is the number of bits in the bit field.
//
// DataBitOffset, if non-zero, is the number of bits from the start of
// the enclosing entity (e.g. containing struct/class/union) to the
// start of the bit field. This corresponds to the DW_AT_data_bit_offset
// DWARF attribute that was introduced in DWARF 4.
//
// BitOffset, if non-zero, is the number of bits between the most
// significant bit of the storage unit holding the bit field to the
// most significant bit of the bit field. Here "storage unit" is the
// type name before the bit field (for a field "unsigned x:17", the
// storage unit is "unsigned"). BitOffset values can vary depending on
// the endianness of the system. BitOffset corresponds to the
// DW_AT_bit_offset DWARF attribute that was deprecated in DWARF 4 and
// removed in DWARF 5.
//
// At most one of DataBitOffset and BitOffset will be non-zero;
// DataBitOffset/BitOffset will only be non-zero if BitSize is
// non-zero. Whether a C compiler uses one or the other
// will depend on compiler vintage and command line options.
//
// Here is an example of C/C++ bit field use, along with what to
// expect in terms of DWARF bit offset info. Consider this code:
//
// struct S {
// int q;
// int j:5;
// int k:6;
// int m:5;
// int n:8;
// } s;
//
// For the code above, one would expect to see the following for
// DW_AT_bit_offset values (using GCC 8):
//
// Little | Big
// Endian | Endian
// |
// "j": 27 | 0
// "k": 21 | 5
// "m": 16 | 11
// "n": 8 | 16
//
// Note that in the above the offsets are purely with respect to the
// containing storage unit for j/k/m/n -- these values won't vary based
// on the size of prior data members in the containing struct.
//
// If the compiler emits DW_AT_data_bit_offset, the expected values
// would be:
//
// "j": 32
// "k": 37
// "m": 43
// "n": 48
//
// Here the value 32 for "j" reflects the fact that the bit field is
// preceded by other data members (recall that DW_AT_data_bit_offset
// values are relative to the start of the containing struct). Hence
// DW_AT_data_bit_offset values can be quite large for structs with
// many fields.
//
// DWARF also allows for the possibility of base types that have
// non-zero bit size and bit offset, so this information is also
// captured for base types, but it is worth noting that it is not
// possible to trigger this behavior using mainstream languages.
type StructField struct {
Name string
Type Type
ByteOffset int64
ByteSize int64 // usually zero; use Type.Size() for normal fields
BitOffset int64
DataBitOffset int64
BitSize int64 // zero if not a bit field
}
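// A hedged sketch of inspecting bit fields, as seen from a client package
// (typ is assumed to come from Data.Type):
//
//	if st, ok := typ.(*dwarf.StructType); ok {
//		for _, f := range st.Field {
//			if f.BitSize == 0 {
//				continue // not a bit field
//			}
//			// Per the rules above, at most one of the two offsets is set.
//			fmt.Println(f.Name, f.BitSize, f.BitOffset, f.DataBitOffset)
//		}
//	}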
func (t *StructType) String() string {
if t.StructName != "" {
return t.Kind + " " + t.StructName
}
return t.Defn()
}
func (f *StructField) bitOffset() int64 {
if f.BitOffset != 0 {
return f.BitOffset
}
return f.DataBitOffset
}
func (t *StructType) Defn() string {
s := t.Kind
if t.StructName != "" {
s += " " + t.StructName
}
if t.Incomplete {
s += " /*incomplete*/"
return s
}
s += " {"
for i, f := range t.Field {
if i > 0 {
s += "; "
}
s += f.Name + " " + f.Type.String()
s += "@" + strconv.FormatInt(f.ByteOffset, 10)
if f.BitSize > 0 {
s += " : " + strconv.FormatInt(f.BitSize, 10)
s += "@" + strconv.FormatInt(f.bitOffset(), 10)
}
}
s += "}"
return s
}
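// For illustration (hedged; assuming 4-byte ints), a C declaration such as
//
//	struct point { int x; int y; };
//
// renders through Defn in the package's pseudo-Go syntax roughly as:
//
//	struct point {x int@0; y int@4}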
// An EnumType represents an enumerated type.
// The only indication of its native integer type is its ByteSize
// (inside CommonType).
type EnumType struct {
CommonType
EnumName string
Val []*EnumValue
}
// An EnumValue represents a single enumeration value.
type EnumValue struct {
Name string
Val int64
}
func (t *EnumType) String() string {
s := "enum"
if t.EnumName != "" {
s += " " + t.EnumName
}
s += " {"
for i, v := range t.Val {
if i > 0 {
s += "; "
}
s += v.Name + "=" + strconv.FormatInt(v.Val, 10)
}
s += "}"
return s
}
// A FuncType represents a function type.
type FuncType struct {
CommonType
ReturnType Type
ParamType []Type
}
func (t *FuncType) String() string {
s := "func("
for i, t := range t.ParamType {
if i > 0 {
s += ", "
}
s += t.String()
}
s += ")"
if t.ReturnType != nil {
s += " " + t.ReturnType.String()
}
return s
}
// A DotDotDotType represents the variadic ... function parameter.
type DotDotDotType struct {
CommonType
}
func (t *DotDotDotType) String() string { return "..." }
// A TypedefType represents a named type.
type TypedefType struct {
CommonType
Type Type
}
func (t *TypedefType) String() string { return t.Name }
func (t *TypedefType) Size() int64 { return t.Type.Size() }
// An UnsupportedType is a placeholder returned in situations where we
// encounter a type that isn't supported.
type UnsupportedType struct {
CommonType
Tag Tag
}
func (t *UnsupportedType) String() string {
if t.Name != "" {
return t.Name
}
return t.Name + "(unsupported type " + t.Tag.String() + ")"
}
// typeReader is used to read from either the info section or the
// types section.
type typeReader interface {
Seek(Offset)
Next() (*Entry, error)
clone() typeReader
offset() Offset
// AddressSize returns the size in bytes of addresses in the current
// compilation unit.
AddressSize() int
}
// Type reads the type at off in the DWARF “info” section.
func (d *Data) Type(off Offset) (Type, error) {
return d.readType("info", d.Reader(), off, d.typeCache, nil)
}
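// A minimal client-side sketch (hedged; e is assumed to be an *Entry from a
// Reader, and d a *Data): given an entry with an AttrType attribute, resolve
// the referenced type.
//
//	if off, ok := e.Val(dwarf.AttrType).(dwarf.Offset); ok {
//		t, err := d.Type(off)
//		if err == nil {
//			fmt.Println(t.String(), t.Size())
//		}
//	}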
type typeFixer struct {
typedefs []*TypedefType
arraytypes []*Type
}
func (tf *typeFixer) recordArrayType(t *Type) {
if t == nil {
return
}
_, ok := (*t).(*ArrayType)
if ok {
tf.arraytypes = append(tf.arraytypes, t)
}
}
func (tf *typeFixer) apply() {
for _, t := range tf.typedefs {
t.Common().ByteSize = t.Type.Size()
}
for _, t := range tf.arraytypes {
zeroArray(t)
}
}
// readType reads a type from r at off in the section called name. It
// adds types to the type cache, records new typedef types so their
// sizes can be resolved later, and computes the sizes of types.
// Callers should pass nil for fixups; a non-nil value is used for
// internal recursion.
func (d *Data) readType(name string, r typeReader, off Offset, typeCache map[Offset]Type, fixups *typeFixer) (Type, error) {
if t, ok := typeCache[off]; ok {
return t, nil
}
r.Seek(off)
e, err := r.Next()
if err != nil {
return nil, err
}
addressSize := r.AddressSize()
if e == nil || e.Offset != off {
return nil, DecodeError{name, off, "no type at offset"}
}
// If this is the root of the recursion, prepare to resolve
// typedef sizes and perform other fixups once the recursion is
// done. This must be done after the type graph is constructed
// because it may need to resolve cycles in a different order than
// readType encounters them.
if fixups == nil {
var fixer typeFixer
defer func() {
fixer.apply()
}()
fixups = &fixer
}
// Parse type from Entry.
// Must always set typeCache[off] before calling
// d.readType recursively, to handle circular types correctly.
var typ Type
nextDepth := 0
// Get next child; set err if error happens.
next := func() *Entry {
if !e.Children {
return nil
}
// Only return direct children.
// Skip over composite entries that happen to be nested
// inside this one. Most DWARF generators wouldn't generate
// such a thing, but clang does.
// See golang.org/issue/6472.
for {
kid, err1 := r.Next()
if err1 != nil {
err = err1
return nil
}
if kid == nil {
err = DecodeError{name, r.offset(), "unexpected end of DWARF entries"}
return nil
}
if kid.Tag == 0 {
if nextDepth > 0 {
nextDepth--
continue
}
return nil
}
if kid.Children {
nextDepth++
}
if nextDepth > 0 {
continue
}
return kid
}
}
// Get Type referred to by Entry's AttrType field.
// Set err if error happens. Not having a type is an error.
typeOf := func(e *Entry) Type {
tval := e.Val(AttrType)
var t Type
switch toff := tval.(type) {
case Offset:
if t, err = d.readType(name, r.clone(), toff, typeCache, fixups); err != nil {
return nil
}
case uint64:
if t, err = d.sigToType(toff); err != nil {
return nil
}
default:
// It appears that no Type means "void".
return new(VoidType)
}
return t
}
switch e.Tag {
case TagArrayType:
// Multi-dimensional array. (DWARF v2 §5.4)
// Attributes:
// AttrType:subtype [required]
// AttrStrideSize: size in bits of each element of the array
// AttrByteSize: size of entire array
// Children:
// TagSubrangeType or TagEnumerationType giving one dimension.
// dimensions are in left to right order.
t := new(ArrayType)
typ = t
typeCache[off] = t
if t.Type = typeOf(e); err != nil {
goto Error
}
t.StrideBitSize, _ = e.Val(AttrStrideSize).(int64)
// Accumulate dimensions.
var dims []int64
for kid := next(); kid != nil; kid = next() {
// TODO(rsc): Can also be TagEnumerationType
// but haven't seen that in the wild yet.
switch kid.Tag {
case TagSubrangeType:
count, ok := kid.Val(AttrCount).(int64)
if !ok {
// Old binaries may have an upper bound instead.
count, ok = kid.Val(AttrUpperBound).(int64)
if ok {
count++ // Length is one more than upper bound.
} else if len(dims) == 0 {
count = -1 // As in x[].
}
}
dims = append(dims, count)
case TagEnumerationType:
err = DecodeError{name, kid.Offset, "cannot handle enumeration type as array bound"}
goto Error
}
}
if len(dims) == 0 {
// LLVM generates this for x[].
dims = []int64{-1}
}
t.Count = dims[0]
for i := len(dims) - 1; i >= 1; i-- {
t.Type = &ArrayType{Type: t.Type, Count: dims[i]}
}
case TagBaseType:
// Basic type. (DWARF v2 §5.1)
// Attributes:
// AttrName: name of base type in programming language of the compilation unit [required]
// AttrEncoding: encoding value for type (encFloat etc) [required]
// AttrByteSize: size of type in bytes [required]
// AttrBitOffset: bit offset of value within containing storage unit
// AttrDataBitOffset: bit offset of value within containing storage unit
// AttrBitSize: size in bits
//
// For most languages BitOffset/DataBitOffset/BitSize will not be present
// for base types.
name, _ := e.Val(AttrName).(string)
enc, ok := e.Val(AttrEncoding).(int64)
if !ok {
err = DecodeError{name, e.Offset, "missing encoding attribute for " + name}
goto Error
}
switch enc {
default:
err = DecodeError{name, e.Offset, "unrecognized encoding attribute value"}
goto Error
case encAddress:
typ = new(AddrType)
case encBoolean:
typ = new(BoolType)
case encComplexFloat:
typ = new(ComplexType)
if name == "complex" {
// clang writes out 'complex' instead of 'complex float' or 'complex double'.
// clang also writes out a byte size that we can use to distinguish.
// See issue 8694.
switch byteSize, _ := e.Val(AttrByteSize).(int64); byteSize {
case 8:
name = "complex float"
case 16:
name = "complex double"
}
}
case encFloat:
typ = new(FloatType)
case encSigned:
typ = new(IntType)
case encUnsigned:
typ = new(UintType)
case encSignedChar:
typ = new(CharType)
case encUnsignedChar:
typ = new(UcharType)
}
typeCache[off] = typ
t := typ.(interface {
Basic() *BasicType
}).Basic()
t.Name = name
t.BitSize, _ = e.Val(AttrBitSize).(int64)
haveBitOffset := false
haveDataBitOffset := false
t.BitOffset, haveBitOffset = e.Val(AttrBitOffset).(int64)
t.DataBitOffset, haveDataBitOffset = e.Val(AttrDataBitOffset).(int64)
if haveBitOffset && haveDataBitOffset {
err = DecodeError{name, e.Offset, "duplicate bit offset attributes"}
goto Error
}
case TagClassType, TagStructType, TagUnionType:
// Structure, union, or class type. (DWARF v2 §5.5)
// Attributes:
// AttrName: name of struct, union, or class
// AttrByteSize: byte size [required]
// AttrDeclaration: if true, struct/union/class is incomplete
// Children:
// TagMember to describe one member.
// AttrName: name of member [required]
// AttrType: type of member [required]
// AttrByteSize: size in bytes
// AttrBitOffset: bit offset within bytes for bit fields
// AttrDataBitOffset: field bit offset relative to struct start
// AttrBitSize: bit size for bit fields
// AttrDataMemberLoc: location within struct [required for struct, class]
// There is much more to handle C++, all ignored for now.
t := new(StructType)
typ = t
typeCache[off] = t
switch e.Tag {
case TagClassType:
t.Kind = "class"
case TagStructType:
t.Kind = "struct"
case TagUnionType:
t.Kind = "union"
}
t.StructName, _ = e.Val(AttrName).(string)
t.Incomplete = e.Val(AttrDeclaration) != nil
t.Field = make([]*StructField, 0, 8)
var lastFieldType *Type
var lastFieldBitSize int64
var lastFieldByteOffset int64
for kid := next(); kid != nil; kid = next() {
if kid.Tag != TagMember {
continue
}
f := new(StructField)
if f.Type = typeOf(kid); err != nil {
goto Error
}
switch loc := kid.Val(AttrDataMemberLoc).(type) {
case []byte:
// TODO: Should have original compilation
// unit here, not unknownFormat.
b := makeBuf(d, unknownFormat{}, "location", 0, loc)
if b.uint8() != opPlusUconst {
err = DecodeError{name, kid.Offset, "unexpected opcode"}
goto Error
}
f.ByteOffset = int64(b.uint())
if b.err != nil {
err = b.err
goto Error
}
case int64:
f.ByteOffset = loc
}
f.Name, _ = kid.Val(AttrName).(string)
f.ByteSize, _ = kid.Val(AttrByteSize).(int64)
haveBitOffset := false
haveDataBitOffset := false
f.BitOffset, haveBitOffset = kid.Val(AttrBitOffset).(int64)
f.DataBitOffset, haveDataBitOffset = kid.Val(AttrDataBitOffset).(int64)
if haveBitOffset && haveDataBitOffset {
err = DecodeError{name, e.Offset, "duplicate bit offset attributes"}
goto Error
}
f.BitSize, _ = kid.Val(AttrBitSize).(int64)
t.Field = append(t.Field, f)
if lastFieldBitSize == 0 && lastFieldByteOffset == f.ByteOffset && t.Kind != "union" {
// Last field was zero width. Fix array length.
// (DWARF writes out 0-length arrays as if they were 1-length arrays.)
fixups.recordArrayType(lastFieldType)
}
lastFieldType = &f.Type
lastFieldByteOffset = f.ByteOffset
lastFieldBitSize = f.BitSize
}
if t.Kind != "union" {
b, ok := e.Val(AttrByteSize).(int64)
if ok && b == lastFieldByteOffset {
// Final field must be zero width. Fix array length.
fixups.recordArrayType(lastFieldType)
}
}
case TagConstType, TagVolatileType, TagRestrictType:
// Type modifier (DWARF v2 §5.2)
// Attributes:
// AttrType: subtype
t := new(QualType)
typ = t
typeCache[off] = t
if t.Type = typeOf(e); err != nil {
goto Error
}
switch e.Tag {
case TagConstType:
t.Qual = "const"
case TagRestrictType:
t.Qual = "restrict"
case TagVolatileType:
t.Qual = "volatile"
}
case TagEnumerationType:
// Enumeration type (DWARF v2 §5.6)
// Attributes:
// AttrName: enum name if any
// AttrByteSize: bytes required to represent largest value
// Children:
// TagEnumerator:
// AttrName: name of constant
// AttrConstValue: value of constant
t := new(EnumType)
typ = t
typeCache[off] = t
t.EnumName, _ = e.Val(AttrName).(string)
t.Val = make([]*EnumValue, 0, 8)
for kid := next(); kid != nil; kid = next() {
if kid.Tag == TagEnumerator {
f := new(EnumValue)
f.Name, _ = kid.Val(AttrName).(string)
f.Val, _ = kid.Val(AttrConstValue).(int64)
n := len(t.Val)
if n >= cap(t.Val) {
val := make([]*EnumValue, n, n*2)
copy(val, t.Val)
t.Val = val
}
t.Val = t.Val[0 : n+1]
t.Val[n] = f
}
}
case TagPointerType:
// Type modifier (DWARF v2 §5.2)
// Attributes:
// AttrType: subtype [not required! void* has no AttrType]
// AttrAddrClass: address class [ignored]
t := new(PtrType)
typ = t
typeCache[off] = t
if e.Val(AttrType) == nil {
t.Type = &VoidType{}
break
}
t.Type = typeOf(e)
case TagSubroutineType:
// Subroutine type. (DWARF v2 §5.7)
// Attributes:
// AttrType: type of return value if any
// AttrName: possible name of type [ignored]
// AttrPrototyped: whether used ANSI C prototype [ignored]
// Children:
// TagFormalParameter: typed parameter
// AttrType: type of parameter
// TagUnspecifiedParameter: final ...
t := new(FuncType)
typ = t
typeCache[off] = t
if t.ReturnType = typeOf(e); err != nil {
goto Error
}
t.ParamType = make([]Type, 0, 8)
for kid := next(); kid != nil; kid = next() {
var tkid Type
switch kid.Tag {
default:
continue
case TagFormalParameter:
if tkid = typeOf(kid); err != nil {
goto Error
}
case TagUnspecifiedParameters:
tkid = &DotDotDotType{}
}
t.ParamType = append(t.ParamType, tkid)
}
case TagTypedef:
// Typedef (DWARF v2 §5.3)
// Attributes:
// AttrName: name [required]
// AttrType: type definition [required]
t := new(TypedefType)
typ = t
typeCache[off] = t
t.Name, _ = e.Val(AttrName).(string)
t.Type = typeOf(e)
case TagUnspecifiedType:
// Unspecified type (DWARF v3 §5.2)
// Attributes:
// AttrName: name
t := new(UnspecifiedType)
typ = t
typeCache[off] = t
t.Name, _ = e.Val(AttrName).(string)
default:
// This is some other type DIE that we're currently not
// equipped to handle. Return an abstract "unsupported type"
// object in such cases.
t := new(UnsupportedType)
typ = t
typeCache[off] = t
t.Tag = e.Tag
t.Name, _ = e.Val(AttrName).(string)
}
if err != nil {
goto Error
}
{
b, ok := e.Val(AttrByteSize).(int64)
if !ok {
b = -1
switch t := typ.(type) {
case *TypedefType:
// Record that we need to resolve this
// type's size once the type graph is
// constructed.
fixups.typedefs = append(fixups.typedefs, t)
case *PtrType:
b = int64(addressSize)
}
}
typ.Common().ByteSize = b
}
return typ, nil
Error:
// If the parse fails, take the type out of the cache
// so that the next call with this offset doesn't hit
// the cache and return success.
delete(typeCache, off)
return nil, err
}
func zeroArray(t *Type) {
at := (*t).(*ArrayType)
if at.Type.Size() == 0 {
return
}
// Make a copy to avoid invalidating typeCache.
tt := *at
tt.Count = 0
*t = &tt
}
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dwarf
import (
"fmt"
"strconv"
)
// Parse the type units stored in a DWARF4 .debug_types section. Each
// type unit defines a single primary type and an 8-byte signature.
// Other sections may then use formRefSig8 to refer to the type.
// The typeUnit format is a single type with a signature. It holds
// the same data as a compilation unit.
type typeUnit struct {
unit
toff Offset // Offset to signature type within data.
name string // Name of .debug_type section.
cache Type // Cache the type, nil to start.
}
// Parse a .debug_types section.
func (d *Data) parseTypes(name string, types []byte) error {
b := makeBuf(d, unknownFormat{}, name, 0, types)
for len(b.data) > 0 {
base := b.off
n, dwarf64 := b.unitLength()
if n != Offset(uint32(n)) {
b.error("type unit length overflow")
return b.err
}
hdroff := b.off
vers := int(b.uint16())
if vers != 4 {
b.error("unsupported DWARF version " + strconv.Itoa(vers))
return b.err
}
var ao uint64
if !dwarf64 {
ao = uint64(b.uint32())
} else {
ao = b.uint64()
}
atable, err := d.parseAbbrev(ao, vers)
if err != nil {
return err
}
asize := b.uint8()
sig := b.uint64()
var toff uint32
if !dwarf64 {
toff = b.uint32()
} else {
to64 := b.uint64()
if to64 != uint64(uint32(to64)) {
b.error("type unit type offset overflow")
return b.err
}
toff = uint32(to64)
}
boff := b.off
d.typeSigs[sig] = &typeUnit{
unit: unit{
base: base,
off: boff,
data: b.bytes(int(n - (b.off - hdroff))),
atable: atable,
asize: int(asize),
vers: vers,
is64: dwarf64,
},
toff: Offset(toff),
name: name,
}
if b.err != nil {
return b.err
}
}
return nil
}
// Return the type for a type signature.
func (d *Data) sigToType(sig uint64) (Type, error) {
tu := d.typeSigs[sig]
if tu == nil {
return nil, fmt.Errorf("no type unit with signature %v", sig)
}
if tu.cache != nil {
return tu.cache, nil
}
b := makeBuf(d, tu, tu.name, tu.off, tu.data)
r := &typeUnitReader{d: d, tu: tu, b: b}
t, err := d.readType(tu.name, r, tu.toff, make(map[Offset]Type), nil)
if err != nil {
return nil, err
}
tu.cache = t
return t, nil
}
// typeUnitReader is a typeReader for a typeUnit.
type typeUnitReader struct {
d *Data
tu *typeUnit
b buf
err error
}
// Seek to a new position in the type unit.
func (tur *typeUnitReader) Seek(off Offset) {
tur.err = nil
doff := off - tur.tu.off
if doff < 0 || doff >= Offset(len(tur.tu.data)) {
tur.err = fmt.Errorf("%s: offset %d out of range; max %d", tur.tu.name, doff, len(tur.tu.data))
return
}
tur.b = makeBuf(tur.d, tur.tu, tur.tu.name, off, tur.tu.data[doff:])
}
// AddressSize returns the size in bytes of addresses in the current type unit.
func (tur *typeUnitReader) AddressSize() int {
return tur.tu.unit.asize
}
// Next reads the next Entry from the type unit.
func (tur *typeUnitReader) Next() (*Entry, error) {
if tur.err != nil {
return nil, tur.err
}
if len(tur.tu.data) == 0 {
return nil, nil
}
e := tur.b.entry(nil, tur.tu.atable, tur.tu.base, tur.tu.vers)
if tur.b.err != nil {
tur.err = tur.b.err
return nil, tur.err
}
return e, nil
}
// clone returns a new reader for the type unit.
func (tur *typeUnitReader) clone() typeReader {
return &typeUnitReader{
d: tur.d,
tu: tur.tu,
b: makeBuf(tur.d, tur.tu, tur.tu.name, tur.tu.off, tur.tu.data),
}
}
// offset returns the current offset.
func (tur *typeUnitReader) offset() Offset {
return tur.b.off
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dwarf
import (
"sort"
"strconv"
)
// DWARF debug info is split into a sequence of compilation units.
// Each unit has its own abbreviation table and address size.
type unit struct {
base Offset // byte offset of header within the aggregate info
off Offset // byte offset of data within the aggregate info
data []byte
atable abbrevTable
asize int
vers int
utype uint8 // DWARF 5 unit type
is64 bool // True for 64-bit DWARF format
}
// Implement the dataFormat interface.
func (u *unit) version() int {
return u.vers
}
func (u *unit) dwarf64() (bool, bool) {
return u.is64, true
}
func (u *unit) addrsize() int {
return u.asize
}
func (d *Data) parseUnits() ([]unit, error) {
// Count units.
nunit := 0
b := makeBuf(d, unknownFormat{}, "info", 0, d.info)
for len(b.data) > 0 {
len, _ := b.unitLength()
if len != Offset(uint32(len)) {
b.error("unit length overflow")
break
}
b.skip(int(len))
if len > 0 {
nunit++
}
}
if b.err != nil {
return nil, b.err
}
// Again, this time writing them down.
b = makeBuf(d, unknownFormat{}, "info", 0, d.info)
units := make([]unit, nunit)
for i := range units {
u := &units[i]
u.base = b.off
var n Offset
if b.err != nil {
return nil, b.err
}
for n == 0 {
n, u.is64 = b.unitLength()
}
dataOff := b.off
vers := b.uint16()
if vers < 2 || vers > 5 {
b.error("unsupported DWARF version " + strconv.Itoa(int(vers)))
break
}
u.vers = int(vers)
if vers >= 5 {
u.utype = b.uint8()
u.asize = int(b.uint8())
}
var abbrevOff uint64
if u.is64 {
abbrevOff = b.uint64()
} else {
abbrevOff = uint64(b.uint32())
}
atable, err := d.parseAbbrev(abbrevOff, u.vers)
if err != nil {
if b.err == nil {
b.err = err
}
break
}
u.atable = atable
if vers < 5 {
u.asize = int(b.uint8())
}
switch u.utype {
case utSkeleton, utSplitCompile:
b.uint64() // unit ID
case utType, utSplitType:
b.uint64() // type signature
if u.is64 { // type offset
b.uint64()
} else {
b.uint32()
}
}
u.off = b.off
u.data = b.bytes(int(n - (b.off - dataOff)))
}
if b.err != nil {
return nil, b.err
}
return units, nil
}
// offsetToUnit returns the index of the unit containing offset off.
// It returns -1 if no unit contains this offset.
func (d *Data) offsetToUnit(off Offset) int {
// Find the unit after off
next := sort.Search(len(d.unit), func(i int) bool {
return d.unit[i].off > off
})
if next == 0 {
return -1
}
u := &d.unit[next-1]
if u.off <= off && off < u.off+Offset(len(u.data)) {
return next - 1
}
return -1
}
/*
* ELF constants and data structures
*
* Derived from:
* $FreeBSD: src/sys/sys/elf32.h,v 1.8.14.1 2005/12/30 22:13:58 marcel Exp $
* $FreeBSD: src/sys/sys/elf64.h,v 1.10.14.1 2005/12/30 22:13:58 marcel Exp $
* $FreeBSD: src/sys/sys/elf_common.h,v 1.15.8.1 2005/12/30 22:13:58 marcel Exp $
* $FreeBSD: src/sys/alpha/include/elf.h,v 1.14 2003/09/25 01:10:22 peter Exp $
* $FreeBSD: src/sys/amd64/include/elf.h,v 1.18 2004/08/03 08:21:48 dfr Exp $
* $FreeBSD: src/sys/arm/include/elf.h,v 1.5.2.1 2006/06/30 21:42:52 cognet Exp $
* $FreeBSD: src/sys/i386/include/elf.h,v 1.16 2004/08/02 19:12:17 dfr Exp $
* $FreeBSD: src/sys/powerpc/include/elf.h,v 1.7 2004/11/02 09:47:01 ssouhlal Exp $
* $FreeBSD: src/sys/sparc64/include/elf.h,v 1.12 2003/09/25 01:10:26 peter Exp $
* "System V ABI" (http://www.sco.com/developers/gabi/latest/ch4.eheader.html)
* "ELF for the ARM® 64-bit Architecture (AArch64)" (ARM IHI 0056B)
* "RISC-V ELF psABI specification" (https://github.com/riscv/riscv-elf-psabi-doc/blob/master/riscv-elf.md)
* llvm/BinaryFormat/ELF.h - ELF constants and structures
*
* Copyright (c) 1996-1998 John D. Polstra. All rights reserved.
* Copyright (c) 2001 David E. O'Brien
* Portions Copyright 2009 The Go Authors. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
* OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
* HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
* LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
* OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
* SUCH DAMAGE.
*/
package elf
import "strconv"
/*
* Constants
*/
// Indexes into the Header.Ident array.
const (
EI_CLASS = 4 /* Class of machine. */
EI_DATA = 5 /* Data format. */
EI_VERSION = 6 /* ELF format version. */
EI_OSABI = 7 /* Operating system / ABI identification */
EI_ABIVERSION = 8 /* ABI version */
EI_PAD = 9 /* Start of padding (per SVR4 ABI). */
EI_NIDENT = 16 /* Size of e_ident array. */
)
// Initial magic number for ELF files.
const ELFMAG = "\177ELF"
// Version is found in Header.Ident[EI_VERSION] and Header.Version.
type Version byte
const (
EV_NONE Version = 0
EV_CURRENT Version = 1
)
var versionStrings = []intName{
{0, "EV_NONE"},
{1, "EV_CURRENT"},
}
func (i Version) String() string { return stringName(uint32(i), versionStrings, false) }
func (i Version) GoString() string { return stringName(uint32(i), versionStrings, true) }
// Class is found in Header.Ident[EI_CLASS] and Header.Class.
type Class byte
const (
ELFCLASSNONE Class = 0 /* Unknown class. */
ELFCLASS32 Class = 1 /* 32-bit architecture. */
ELFCLASS64 Class = 2 /* 64-bit architecture. */
)
var classStrings = []intName{
{0, "ELFCLASSNONE"},
{1, "ELFCLASS32"},
{2, "ELFCLASS64"},
}
func (i Class) String() string { return stringName(uint32(i), classStrings, false) }
func (i Class) GoString() string { return stringName(uint32(i), classStrings, true) }
// Data is found in Header.Ident[EI_DATA] and Header.Data.
type Data byte
const (
ELFDATANONE Data = 0 /* Unknown data format. */
ELFDATA2LSB Data = 1 /* 2's complement little-endian. */
ELFDATA2MSB Data = 2 /* 2's complement big-endian. */
)
var dataStrings = []intName{
{0, "ELFDATANONE"},
{1, "ELFDATA2LSB"},
{2, "ELFDATA2MSB"},
}
func (i Data) String() string { return stringName(uint32(i), dataStrings, false) }
func (i Data) GoString() string { return stringName(uint32(i), dataStrings, true) }
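// A hedged decoding sketch: ident is assumed to hold the first EI_NIDENT
// bytes of an ELF file.
//
//	class := elf.Class(ident[elf.EI_CLASS])
//	data := elf.Data(ident[elf.EI_DATA])
//	fmt.Println(class, data) // e.g. "ELFCLASS64 ELFDATA2LSB"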
// OSABI is found in Header.Ident[EI_OSABI] and Header.OSABI.
type OSABI byte
const (
ELFOSABI_NONE OSABI = 0 /* UNIX System V ABI */
ELFOSABI_HPUX OSABI = 1 /* HP-UX operating system */
ELFOSABI_NETBSD OSABI = 2 /* NetBSD */
ELFOSABI_LINUX OSABI = 3 /* Linux */
ELFOSABI_HURD OSABI = 4 /* Hurd */
ELFOSABI_86OPEN OSABI = 5 /* 86Open common IA32 ABI */
ELFOSABI_SOLARIS OSABI = 6 /* Solaris */
ELFOSABI_AIX OSABI = 7 /* AIX */
ELFOSABI_IRIX OSABI = 8 /* IRIX */
ELFOSABI_FREEBSD OSABI = 9 /* FreeBSD */
ELFOSABI_TRU64 OSABI = 10 /* TRU64 UNIX */
ELFOSABI_MODESTO OSABI = 11 /* Novell Modesto */
ELFOSABI_OPENBSD OSABI = 12 /* OpenBSD */
ELFOSABI_OPENVMS OSABI = 13 /* Open VMS */
ELFOSABI_NSK OSABI = 14 /* HP Non-Stop Kernel */
ELFOSABI_AROS OSABI = 15 /* Amiga Research OS */
ELFOSABI_FENIXOS OSABI = 16 /* The FenixOS highly scalable multi-core OS */
ELFOSABI_CLOUDABI OSABI = 17 /* Nuxi CloudABI */
ELFOSABI_ARM OSABI = 97 /* ARM */
ELFOSABI_STANDALONE OSABI = 255 /* Standalone (embedded) application */
)
var osabiStrings = []intName{
{0, "ELFOSABI_NONE"},
{1, "ELFOSABI_HPUX"},
{2, "ELFOSABI_NETBSD"},
{3, "ELFOSABI_LINUX"},
{4, "ELFOSABI_HURD"},
{5, "ELFOSABI_86OPEN"},
{6, "ELFOSABI_SOLARIS"},
{7, "ELFOSABI_AIX"},
{8, "ELFOSABI_IRIX"},
{9, "ELFOSABI_FREEBSD"},
{10, "ELFOSABI_TRU64"},
{11, "ELFOSABI_MODESTO"},
{12, "ELFOSABI_OPENBSD"},
{13, "ELFOSABI_OPENVMS"},
{14, "ELFOSABI_NSK"},
{15, "ELFOSABI_AROS"},
{16, "ELFOSABI_FENIXOS"},
{17, "ELFOSABI_CLOUDABI"},
{97, "ELFOSABI_ARM"},
{255, "ELFOSABI_STANDALONE"},
}
func (i OSABI) String() string { return stringName(uint32(i), osabiStrings, false) }
func (i OSABI) GoString() string { return stringName(uint32(i), osabiStrings, true) }
// Type is found in Header.Type.
type Type uint16
const (
ET_NONE Type = 0 /* Unknown type. */
ET_REL Type = 1 /* Relocatable. */
ET_EXEC Type = 2 /* Executable. */
ET_DYN Type = 3 /* Shared object. */
ET_CORE Type = 4 /* Core file. */
ET_LOOS Type = 0xfe00 /* First operating system specific. */
ET_HIOS Type = 0xfeff /* Last operating system-specific. */
ET_LOPROC Type = 0xff00 /* First processor-specific. */
ET_HIPROC Type = 0xffff /* Last processor-specific. */
)
var typeStrings = []intName{
{0, "ET_NONE"},
{1, "ET_REL"},
{2, "ET_EXEC"},
{3, "ET_DYN"},
{4, "ET_CORE"},
{0xfe00, "ET_LOOS"},
{0xfeff, "ET_HIOS"},
{0xff00, "ET_LOPROC"},
{0xffff, "ET_HIPROC"},
}
func (i Type) String() string { return stringName(uint32(i), typeStrings, false) }
func (i Type) GoString() string { return stringName(uint32(i), typeStrings, true) }
// Machine is found in Header.Machine.
type Machine uint16
const (
EM_NONE Machine = 0 /* Unknown machine. */
EM_M32 Machine = 1 /* AT&T WE32100. */
EM_SPARC Machine = 2 /* Sun SPARC. */
EM_386 Machine = 3 /* Intel i386. */
EM_68K Machine = 4 /* Motorola 68000. */
EM_88K Machine = 5 /* Motorola 88000. */
EM_860 Machine = 7 /* Intel i860. */
EM_MIPS Machine = 8 /* MIPS R3000 Big-Endian only. */
EM_S370 Machine = 9 /* IBM System/370. */
EM_MIPS_RS3_LE Machine = 10 /* MIPS R3000 Little-Endian. */
EM_PARISC Machine = 15 /* HP PA-RISC. */
EM_VPP500 Machine = 17 /* Fujitsu VPP500. */
EM_SPARC32PLUS Machine = 18 /* SPARC v8plus. */
EM_960 Machine = 19 /* Intel 80960. */
EM_PPC Machine = 20 /* PowerPC 32-bit. */
EM_PPC64 Machine = 21 /* PowerPC 64-bit. */
EM_S390 Machine = 22 /* IBM System/390. */
EM_V800 Machine = 36 /* NEC V800. */
EM_FR20 Machine = 37 /* Fujitsu FR20. */
EM_RH32 Machine = 38 /* TRW RH-32. */
EM_RCE Machine = 39 /* Motorola RCE. */
EM_ARM Machine = 40 /* ARM. */
EM_SH Machine = 42 /* Hitachi SH. */
EM_SPARCV9 Machine = 43 /* SPARC v9 64-bit. */
EM_TRICORE Machine = 44 /* Siemens TriCore embedded processor. */
EM_ARC Machine = 45 /* Argonaut RISC Core. */
EM_H8_300 Machine = 46 /* Hitachi H8/300. */
EM_H8_300H Machine = 47 /* Hitachi H8/300H. */
EM_H8S Machine = 48 /* Hitachi H8S. */
EM_H8_500 Machine = 49 /* Hitachi H8/500. */
EM_IA_64 Machine = 50 /* Intel IA-64 Processor. */
EM_MIPS_X Machine = 51 /* Stanford MIPS-X. */
EM_COLDFIRE Machine = 52 /* Motorola ColdFire. */
EM_68HC12 Machine = 53 /* Motorola M68HC12. */
EM_MMA Machine = 54 /* Fujitsu MMA. */
EM_PCP Machine = 55 /* Siemens PCP. */
EM_NCPU Machine = 56 /* Sony nCPU. */
EM_NDR1 Machine = 57 /* Denso NDR1 microprocessor. */
EM_STARCORE Machine = 58 /* Motorola Star*Core processor. */
EM_ME16 Machine = 59 /* Toyota ME16 processor. */
EM_ST100 Machine = 60 /* STMicroelectronics ST100 processor. */
EM_TINYJ Machine = 61 /* Advanced Logic Corp. TinyJ processor. */
EM_X86_64 Machine = 62 /* Advanced Micro Devices x86-64 */
EM_PDSP Machine = 63 /* Sony DSP Processor */
EM_PDP10 Machine = 64 /* Digital Equipment Corp. PDP-10 */
EM_PDP11 Machine = 65 /* Digital Equipment Corp. PDP-11 */
EM_FX66 Machine = 66 /* Siemens FX66 microcontroller */
EM_ST9PLUS Machine = 67 /* STMicroelectronics ST9+ 8/16 bit microcontroller */
EM_ST7 Machine = 68 /* STMicroelectronics ST7 8-bit microcontroller */
EM_68HC16 Machine = 69 /* Motorola MC68HC16 Microcontroller */
EM_68HC11 Machine = 70 /* Motorola MC68HC11 Microcontroller */
EM_68HC08 Machine = 71 /* Motorola MC68HC08 Microcontroller */
EM_68HC05 Machine = 72 /* Motorola MC68HC05 Microcontroller */
EM_SVX Machine = 73 /* Silicon Graphics SVx */
EM_ST19 Machine = 74 /* STMicroelectronics ST19 8-bit microcontroller */
EM_VAX Machine = 75 /* Digital VAX */
EM_CRIS Machine = 76 /* Axis Communications 32-bit embedded processor */
EM_JAVELIN Machine = 77 /* Infineon Technologies 32-bit embedded processor */
EM_FIREPATH Machine = 78 /* Element 14 64-bit DSP Processor */
EM_ZSP Machine = 79 /* LSI Logic 16-bit DSP Processor */
EM_MMIX Machine = 80 /* Donald Knuth's educational 64-bit processor */
EM_HUANY Machine = 81 /* Harvard University machine-independent object files */
EM_PRISM Machine = 82 /* SiTera Prism */
EM_AVR Machine = 83 /* Atmel AVR 8-bit microcontroller */
EM_FR30 Machine = 84 /* Fujitsu FR30 */
EM_D10V Machine = 85 /* Mitsubishi D10V */
EM_D30V Machine = 86 /* Mitsubishi D30V */
EM_V850 Machine = 87 /* NEC v850 */
EM_M32R Machine = 88 /* Mitsubishi M32R */
EM_MN10300 Machine = 89 /* Matsushita MN10300 */
EM_MN10200 Machine = 90 /* Matsushita MN10200 */
EM_PJ Machine = 91 /* picoJava */
EM_OPENRISC Machine = 92 /* OpenRISC 32-bit embedded processor */
EM_ARC_COMPACT Machine = 93 /* ARC International ARCompact processor (old spelling/synonym: EM_ARC_A5) */
EM_XTENSA Machine = 94 /* Tensilica Xtensa Architecture */
EM_VIDEOCORE Machine = 95 /* Alphamosaic VideoCore processor */
EM_TMM_GPP Machine = 96 /* Thompson Multimedia General Purpose Processor */
EM_NS32K Machine = 97 /* National Semiconductor 32000 series */
EM_TPC Machine = 98 /* Tenor Network TPC processor */
EM_SNP1K Machine = 99 /* Trebia SNP 1000 processor */
EM_ST200 Machine = 100 /* STMicroelectronics (www.st.com) ST200 microcontroller */
EM_IP2K Machine = 101 /* Ubicom IP2xxx microcontroller family */
EM_MAX Machine = 102 /* MAX Processor */
EM_CR Machine = 103 /* National Semiconductor CompactRISC microprocessor */
EM_F2MC16 Machine = 104 /* Fujitsu F2MC16 */
EM_MSP430 Machine = 105 /* Texas Instruments embedded microcontroller msp430 */
EM_BLACKFIN Machine = 106 /* Analog Devices Blackfin (DSP) processor */
EM_SE_C33 Machine = 107 /* S1C33 Family of Seiko Epson processors */
EM_SEP Machine = 108 /* Sharp embedded microprocessor */
EM_ARCA Machine = 109 /* Arca RISC Microprocessor */
EM_UNICORE Machine = 110 /* Microprocessor series from PKU-Unity Ltd. and MPRC of Peking University */
EM_EXCESS Machine = 111 /* eXcess: 16/32/64-bit configurable embedded CPU */
EM_DXP Machine = 112 /* Icera Semiconductor Inc. Deep Execution Processor */
EM_ALTERA_NIOS2 Machine = 113 /* Altera Nios II soft-core processor */
EM_CRX Machine = 114 /* National Semiconductor CompactRISC CRX microprocessor */
EM_XGATE Machine = 115 /* Motorola XGATE embedded processor */
EM_C166 Machine = 116 /* Infineon C16x/XC16x processor */
EM_M16C Machine = 117 /* Renesas M16C series microprocessors */
EM_DSPIC30F Machine = 118 /* Microchip Technology dsPIC30F Digital Signal Controller */
EM_CE Machine = 119 /* Freescale Communication Engine RISC core */
EM_M32C Machine = 120 /* Renesas M32C series microprocessors */
EM_TSK3000 Machine = 131 /* Altium TSK3000 core */
EM_RS08 Machine = 132 /* Freescale RS08 embedded processor */
EM_SHARC Machine = 133 /* Analog Devices SHARC family of 32-bit DSP processors */
EM_ECOG2 Machine = 134 /* Cyan Technology eCOG2 microprocessor */
EM_SCORE7 Machine = 135 /* Sunplus S+core7 RISC processor */
EM_DSP24 Machine = 136 /* New Japan Radio (NJR) 24-bit DSP Processor */
EM_VIDEOCORE3 Machine = 137 /* Broadcom VideoCore III processor */
EM_LATTICEMICO32 Machine = 138 /* RISC processor for Lattice FPGA architecture */
EM_SE_C17 Machine = 139 /* Seiko Epson C17 family */
EM_TI_C6000 Machine = 140 /* The Texas Instruments TMS320C6000 DSP family */
EM_TI_C2000 Machine = 141 /* The Texas Instruments TMS320C2000 DSP family */
EM_TI_C5500 Machine = 142 /* The Texas Instruments TMS320C55x DSP family */
EM_TI_ARP32 Machine = 143 /* Texas Instruments Application Specific RISC Processor, 32bit fetch */
EM_TI_PRU Machine = 144 /* Texas Instruments Programmable Realtime Unit */
EM_MMDSP_PLUS Machine = 160 /* STMicroelectronics 64bit VLIW Data Signal Processor */
EM_CYPRESS_M8C Machine = 161 /* Cypress M8C microprocessor */
EM_R32C Machine = 162 /* Renesas R32C series microprocessors */
EM_TRIMEDIA Machine = 163 /* NXP Semiconductors TriMedia architecture family */
EM_QDSP6 Machine = 164 /* QUALCOMM DSP6 Processor */
EM_8051 Machine = 165 /* Intel 8051 and variants */
EM_STXP7X Machine = 166 /* STMicroelectronics STxP7x family of configurable and extensible RISC processors */
EM_NDS32 Machine = 167 /* Andes Technology compact code size embedded RISC processor family */
EM_ECOG1 Machine = 168 /* Cyan Technology eCOG1X family */
EM_ECOG1X Machine = 168 /* Cyan Technology eCOG1X family */
EM_MAXQ30 Machine = 169 /* Dallas Semiconductor MAXQ30 Core Micro-controllers */
EM_XIMO16 Machine = 170 /* New Japan Radio (NJR) 16-bit DSP Processor */
EM_MANIK Machine = 171 /* M2000 Reconfigurable RISC Microprocessor */
EM_CRAYNV2 Machine = 172 /* Cray Inc. NV2 vector architecture */
EM_RX Machine = 173 /* Renesas RX family */
EM_METAG Machine = 174 /* Imagination Technologies META processor architecture */
EM_MCST_ELBRUS Machine = 175 /* MCST Elbrus general purpose hardware architecture */
EM_ECOG16 Machine = 176 /* Cyan Technology eCOG16 family */
EM_CR16 Machine = 177 /* National Semiconductor CompactRISC CR16 16-bit microprocessor */
EM_ETPU Machine = 178 /* Freescale Extended Time Processing Unit */
EM_SLE9X Machine = 179 /* Infineon Technologies SLE9X core */
EM_L10M Machine = 180 /* Intel L10M */
EM_K10M Machine = 181 /* Intel K10M */
EM_AARCH64 Machine = 183 /* ARM 64-bit Architecture (AArch64) */
EM_AVR32 Machine = 185 /* Atmel Corporation 32-bit microprocessor family */
EM_STM8 Machine = 186 /* STMicroelectronics STM8 8-bit microcontroller */
EM_TILE64 Machine = 187 /* Tilera TILE64 multicore architecture family */
EM_TILEPRO Machine = 188 /* Tilera TILEPro multicore architecture family */
EM_MICROBLAZE Machine = 189 /* Xilinx MicroBlaze 32-bit RISC soft processor core */
EM_CUDA Machine = 190 /* NVIDIA CUDA architecture */
EM_TILEGX Machine = 191 /* Tilera TILE-Gx multicore architecture family */
EM_CLOUDSHIELD Machine = 192 /* CloudShield architecture family */
EM_COREA_1ST Machine = 193 /* KIPO-KAIST Core-A 1st generation processor family */
EM_COREA_2ND Machine = 194 /* KIPO-KAIST Core-A 2nd generation processor family */
EM_ARC_COMPACT2 Machine = 195 /* Synopsys ARCompact V2 */
EM_OPEN8 Machine = 196 /* Open8 8-bit RISC soft processor core */
EM_RL78 Machine = 197 /* Renesas RL78 family */
EM_VIDEOCORE5 Machine = 198 /* Broadcom VideoCore V processor */
EM_78KOR Machine = 199 /* Renesas 78KOR family */
EM_56800EX Machine = 200 /* Freescale 56800EX Digital Signal Controller (DSC) */
EM_BA1 Machine = 201 /* Beyond BA1 CPU architecture */
EM_BA2 Machine = 202 /* Beyond BA2 CPU architecture */
EM_XCORE Machine = 203 /* XMOS xCORE processor family */
EM_MCHP_PIC Machine = 204 /* Microchip 8-bit PIC(r) family */
EM_INTEL205 Machine = 205 /* Reserved by Intel */
EM_INTEL206 Machine = 206 /* Reserved by Intel */
EM_INTEL207 Machine = 207 /* Reserved by Intel */
EM_INTEL208 Machine = 208 /* Reserved by Intel */
EM_INTEL209 Machine = 209 /* Reserved by Intel */
EM_KM32 Machine = 210 /* KM211 KM32 32-bit processor */
EM_KMX32 Machine = 211 /* KM211 KMX32 32-bit processor */
EM_KMX16 Machine = 212 /* KM211 KMX16 16-bit processor */
EM_KMX8 Machine = 213 /* KM211 KMX8 8-bit processor */
EM_KVARC Machine = 214 /* KM211 KVARC processor */
EM_CDP Machine = 215 /* Paneve CDP architecture family */
EM_COGE Machine = 216 /* Cognitive Smart Memory Processor */
EM_COOL Machine = 217 /* Bluechip Systems CoolEngine */
EM_NORC Machine = 218 /* Nanoradio Optimized RISC */
EM_CSR_KALIMBA Machine = 219 /* CSR Kalimba architecture family */
EM_Z80 Machine = 220 /* Zilog Z80 */
EM_VISIUM Machine = 221 /* Controls and Data Services VISIUMcore processor */
EM_FT32 Machine = 222 /* FTDI Chip FT32 high performance 32-bit RISC architecture */
EM_MOXIE Machine = 223 /* Moxie processor family */
EM_AMDGPU Machine = 224 /* AMD GPU architecture */
EM_RISCV Machine = 243 /* RISC-V */
EM_LANAI Machine = 244 /* Lanai 32-bit processor */
EM_BPF Machine = 247 /* Linux BPF – in-kernel virtual machine */
EM_LOONGARCH Machine = 258 /* LoongArch */
/* Non-standard or deprecated. */
EM_486 Machine = 6 /* Intel i486. */
EM_MIPS_RS4_BE Machine = 10 /* MIPS R4000 Big-Endian */
EM_ALPHA_STD Machine = 41 /* Digital Alpha (standard value). */
EM_ALPHA Machine = 0x9026 /* Alpha (written in the absence of an ABI) */
)
var machineStrings = []intName{
{0, "EM_NONE"},
{1, "EM_M32"},
{2, "EM_SPARC"},
{3, "EM_386"},
{4, "EM_68K"},
{5, "EM_88K"},
{7, "EM_860"},
{8, "EM_MIPS"},
{9, "EM_S370"},
{10, "EM_MIPS_RS3_LE"},
{15, "EM_PARISC"},
{17, "EM_VPP500"},
{18, "EM_SPARC32PLUS"},
{19, "EM_960"},
{20, "EM_PPC"},
{21, "EM_PPC64"},
{22, "EM_S390"},
{36, "EM_V800"},
{37, "EM_FR20"},
{38, "EM_RH32"},
{39, "EM_RCE"},
{40, "EM_ARM"},
{42, "EM_SH"},
{43, "EM_SPARCV9"},
{44, "EM_TRICORE"},
{45, "EM_ARC"},
{46, "EM_H8_300"},
{47, "EM_H8_300H"},
{48, "EM_H8S"},
{49, "EM_H8_500"},
{50, "EM_IA_64"},
{51, "EM_MIPS_X"},
{52, "EM_COLDFIRE"},
{53, "EM_68HC12"},
{54, "EM_MMA"},
{55, "EM_PCP"},
{56, "EM_NCPU"},
{57, "EM_NDR1"},
{58, "EM_STARCORE"},
{59, "EM_ME16"},
{60, "EM_ST100"},
{61, "EM_TINYJ"},
{62, "EM_X86_64"},
{63, "EM_PDSP"},
{64, "EM_PDP10"},
{65, "EM_PDP11"},
{66, "EM_FX66"},
{67, "EM_ST9PLUS"},
{68, "EM_ST7"},
{69, "EM_68HC16"},
{70, "EM_68HC11"},
{71, "EM_68HC08"},
{72, "EM_68HC05"},
{73, "EM_SVX"},
{74, "EM_ST19"},
{75, "EM_VAX"},
{76, "EM_CRIS"},
{77, "EM_JAVELIN"},
{78, "EM_FIREPATH"},
{79, "EM_ZSP"},
{80, "EM_MMIX"},
{81, "EM_HUANY"},
{82, "EM_PRISM"},
{83, "EM_AVR"},
{84, "EM_FR30"},
{85, "EM_D10V"},
{86, "EM_D30V"},
{87, "EM_V850"},
{88, "EM_M32R"},
{89, "EM_MN10300"},
{90, "EM_MN10200"},
{91, "EM_PJ"},
{92, "EM_OPENRISC"},
{93, "EM_ARC_COMPACT"},
{94, "EM_XTENSA"},
{95, "EM_VIDEOCORE"},
{96, "EM_TMM_GPP"},
{97, "EM_NS32K"},
{98, "EM_TPC"},
{99, "EM_SNP1K"},
{100, "EM_ST200"},
{101, "EM_IP2K"},
{102, "EM_MAX"},
{103, "EM_CR"},
{104, "EM_F2MC16"},
{105, "EM_MSP430"},
{106, "EM_BLACKFIN"},
{107, "EM_SE_C33"},
{108, "EM_SEP"},
{109, "EM_ARCA"},
{110, "EM_UNICORE"},
{111, "EM_EXCESS"},
{112, "EM_DXP"},
{113, "EM_ALTERA_NIOS2"},
{114, "EM_CRX"},
{115, "EM_XGATE"},
{116, "EM_C166"},
{117, "EM_M16C"},
{118, "EM_DSPIC30F"},
{119, "EM_CE"},
{120, "EM_M32C"},
{131, "EM_TSK3000"},
{132, "EM_RS08"},
{133, "EM_SHARC"},
{134, "EM_ECOG2"},
{135, "EM_SCORE7"},
{136, "EM_DSP24"},
{137, "EM_VIDEOCORE3"},
{138, "EM_LATTICEMICO32"},
{139, "EM_SE_C17"},
{140, "EM_TI_C6000"},
{141, "EM_TI_C2000"},
{142, "EM_TI_C5500"},
{143, "EM_TI_ARP32"},
{144, "EM_TI_PRU"},
{160, "EM_MMDSP_PLUS"},
{161, "EM_CYPRESS_M8C"},
{162, "EM_R32C"},
{163, "EM_TRIMEDIA"},
{164, "EM_QDSP6"},
{165, "EM_8051"},
{166, "EM_STXP7X"},
{167, "EM_NDS32"},
{168, "EM_ECOG1"},
{168, "EM_ECOG1X"},
{169, "EM_MAXQ30"},
{170, "EM_XIMO16"},
{171, "EM_MANIK"},
{172, "EM_CRAYNV2"},
{173, "EM_RX"},
{174, "EM_METAG"},
{175, "EM_MCST_ELBRUS"},
{176, "EM_ECOG16"},
{177, "EM_CR16"},
{178, "EM_ETPU"},
{179, "EM_SLE9X"},
{180, "EM_L10M"},
{181, "EM_K10M"},
{183, "EM_AARCH64"},
{185, "EM_AVR32"},
{186, "EM_STM8"},
{187, "EM_TILE64"},
{188, "EM_TILEPRO"},
{189, "EM_MICROBLAZE"},
{190, "EM_CUDA"},
{191, "EM_TILEGX"},
{192, "EM_CLOUDSHIELD"},
{193, "EM_COREA_1ST"},
{194, "EM_COREA_2ND"},
{195, "EM_ARC_COMPACT2"},
{196, "EM_OPEN8"},
{197, "EM_RL78"},
{198, "EM_VIDEOCORE5"},
{199, "EM_78KOR"},
{200, "EM_56800EX"},
{201, "EM_BA1"},
{202, "EM_BA2"},
{203, "EM_XCORE"},
{204, "EM_MCHP_PIC"},
{205, "EM_INTEL205"},
{206, "EM_INTEL206"},
{207, "EM_INTEL207"},
{208, "EM_INTEL208"},
{209, "EM_INTEL209"},
{210, "EM_KM32"},
{211, "EM_KMX32"},
{212, "EM_KMX16"},
{213, "EM_KMX8"},
{214, "EM_KVARC"},
{215, "EM_CDP"},
{216, "EM_COGE"},
{217, "EM_COOL"},
{218, "EM_NORC"},
{219, "EM_CSR_KALIMBA "},
{220, "EM_Z80 "},
{221, "EM_VISIUM "},
{222, "EM_FT32 "},
{223, "EM_MOXIE"},
{224, "EM_AMDGPU"},
{243, "EM_RISCV"},
{244, "EM_LANAI"},
{247, "EM_BPF"},
{258, "EM_LOONGARCH"},
/* Non-standard or deprecated. */
{6, "EM_486"},
{10, "EM_MIPS_RS4_BE"},
{41, "EM_ALPHA_STD"},
{0x9026, "EM_ALPHA"},
}
func (i Machine) String() string { return stringName(uint32(i), machineStrings, false) }
func (i Machine) GoString() string { return stringName(uint32(i), machineStrings, true) }
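// A minimal usage sketch (assuming stringName, defined elsewhere in this
// package, maps known values to their names and adds the "elf." package
// prefix for Go syntax):
//
//	m := Machine(62)
//	fmt.Println(m)         // "EM_X86_64"
//	fmt.Printf("%#v\n", m) // "elf.EM_X86_64" via GoString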
// Special section indices.
type SectionIndex int
const (
SHN_UNDEF SectionIndex = 0 /* Undefined, missing, irrelevant. */
SHN_LORESERVE SectionIndex = 0xff00 /* First of reserved range. */
SHN_LOPROC SectionIndex = 0xff00 /* First processor-specific. */
SHN_HIPROC SectionIndex = 0xff1f /* Last processor-specific. */
SHN_LOOS SectionIndex = 0xff20 /* First operating system-specific. */
SHN_HIOS SectionIndex = 0xff3f /* Last operating system-specific. */
SHN_ABS SectionIndex = 0xfff1 /* Absolute values. */
SHN_COMMON SectionIndex = 0xfff2 /* Common data. */
SHN_XINDEX SectionIndex = 0xffff /* Escape; index stored elsewhere. */
SHN_HIRESERVE SectionIndex = 0xffff /* Last of reserved range. */
)
var shnStrings = []intName{
{0, "SHN_UNDEF"},
{0xff00, "SHN_LOPROC"},
{0xff20, "SHN_LOOS"},
{0xfff1, "SHN_ABS"},
{0xfff2, "SHN_COMMON"},
{0xffff, "SHN_XINDEX"},
}
func (i SectionIndex) String() string { return stringName(uint32(i), shnStrings, false) }
func (i SectionIndex) GoString() string { return stringName(uint32(i), shnStrings, true) }
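// Values at or above SHN_LORESERVE never name a real section. A hedged
// sketch of the SHN_XINDEX escape described above, assuming shndx is a
// symbol's raw st_shndx field:
//
//	if SectionIndex(shndx) == SHN_XINDEX {
//		// The real section index is too large for st_shndx and is
//		// stored in the corresponding SHT_SYMTAB_SHNDX section.
//	}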
// Section type.
type SectionType uint32
const (
SHT_NULL SectionType = 0 /* inactive */
SHT_PROGBITS SectionType = 1 /* program defined information */
SHT_SYMTAB SectionType = 2 /* symbol table section */
SHT_STRTAB SectionType = 3 /* string table section */
SHT_RELA SectionType = 4 /* relocation section with addends */
SHT_HASH SectionType = 5 /* symbol hash table section */
SHT_DYNAMIC SectionType = 6 /* dynamic section */
SHT_NOTE SectionType = 7 /* note section */
SHT_NOBITS SectionType = 8 /* no space section */
SHT_REL SectionType = 9 /* relocation section - no addends */
SHT_SHLIB SectionType = 10 /* reserved - purpose unknown */
SHT_DYNSYM SectionType = 11 /* dynamic symbol table section */
SHT_INIT_ARRAY SectionType = 14 /* Initialization function pointers. */
SHT_FINI_ARRAY SectionType = 15 /* Termination function pointers. */
SHT_PREINIT_ARRAY SectionType = 16 /* Pre-initialization function ptrs. */
SHT_GROUP SectionType = 17 /* Section group. */
SHT_SYMTAB_SHNDX SectionType = 18 /* Section indexes (see SHN_XINDEX). */
SHT_LOOS SectionType = 0x60000000 /* First of OS specific semantics */
SHT_GNU_ATTRIBUTES SectionType = 0x6ffffff5 /* GNU object attributes */
SHT_GNU_HASH SectionType = 0x6ffffff6 /* GNU hash table */
SHT_GNU_LIBLIST SectionType = 0x6ffffff7 /* GNU prelink library list */
SHT_GNU_VERDEF SectionType = 0x6ffffffd /* GNU version definition section */
SHT_GNU_VERNEED SectionType = 0x6ffffffe /* GNU version needs section */
SHT_GNU_VERSYM SectionType = 0x6fffffff /* GNU version symbol table */
SHT_HIOS SectionType = 0x6fffffff /* Last of OS specific semantics */
SHT_LOPROC SectionType = 0x70000000 /* reserved range for processor */
SHT_MIPS_ABIFLAGS SectionType = 0x7000002a /* .MIPS.abiflags */
SHT_HIPROC SectionType = 0x7fffffff /* specific section header types */
SHT_LOUSER SectionType = 0x80000000 /* reserved range for application */
SHT_HIUSER SectionType = 0xffffffff /* specific indexes */
)
var shtStrings = []intName{
{0, "SHT_NULL"},
{1, "SHT_PROGBITS"},
{2, "SHT_SYMTAB"},
{3, "SHT_STRTAB"},
{4, "SHT_RELA"},
{5, "SHT_HASH"},
{6, "SHT_DYNAMIC"},
{7, "SHT_NOTE"},
{8, "SHT_NOBITS"},
{9, "SHT_REL"},
{10, "SHT_SHLIB"},
{11, "SHT_DYNSYM"},
{14, "SHT_INIT_ARRAY"},
{15, "SHT_FINI_ARRAY"},
{16, "SHT_PREINIT_ARRAY"},
{17, "SHT_GROUP"},
{18, "SHT_SYMTAB_SHNDX"},
{0x60000000, "SHT_LOOS"},
{0x6ffffff5, "SHT_GNU_ATTRIBUTES"},
{0x6ffffff6, "SHT_GNU_HASH"},
{0x6ffffff7, "SHT_GNU_LIBLIST"},
{0x6ffffffd, "SHT_GNU_VERDEF"},
{0x6ffffffe, "SHT_GNU_VERNEED"},
{0x6fffffff, "SHT_GNU_VERSYM"},
{0x70000000, "SHT_LOPROC"},
{0x7000002a, "SHT_MIPS_ABIFLAGS"},
{0x7fffffff, "SHT_HIPROC"},
{0x80000000, "SHT_LOUSER"},
{0xffffffff, "SHT_HIUSER"},
}
func (i SectionType) String() string { return stringName(uint32(i), shtStrings, false) }
func (i SectionType) GoString() string { return stringName(uint32(i), shtStrings, true) }
// Section flags.
type SectionFlag uint32
const (
SHF_WRITE SectionFlag = 0x1 /* Section contains writable data. */
SHF_ALLOC SectionFlag = 0x2 /* Section occupies memory. */
SHF_EXECINSTR SectionFlag = 0x4 /* Section contains instructions. */
SHF_MERGE SectionFlag = 0x10 /* Section may be merged. */
SHF_STRINGS SectionFlag = 0x20 /* Section contains strings. */
SHF_INFO_LINK SectionFlag = 0x40 /* sh_info holds section index. */
SHF_LINK_ORDER SectionFlag = 0x80 /* Special ordering requirements. */
SHF_OS_NONCONFORMING SectionFlag = 0x100 /* OS-specific processing required. */
SHF_GROUP SectionFlag = 0x200 /* Member of section group. */
SHF_TLS SectionFlag = 0x400 /* Section contains TLS data. */
SHF_COMPRESSED SectionFlag = 0x800 /* Section is compressed. */
SHF_MASKOS SectionFlag = 0x0ff00000 /* OS-specific semantics. */
SHF_MASKPROC SectionFlag = 0xf0000000 /* Processor-specific semantics. */
)
var shfStrings = []intName{
{0x1, "SHF_WRITE"},
{0x2, "SHF_ALLOC"},
{0x4, "SHF_EXECINSTR"},
{0x10, "SHF_MERGE"},
{0x20, "SHF_STRINGS"},
{0x40, "SHF_INFO_LINK"},
{0x80, "SHF_LINK_ORDER"},
{0x100, "SHF_OS_NONCONFORMING"},
{0x200, "SHF_GROUP"},
{0x400, "SHF_TLS"},
{0x800, "SHF_COMPRESSED"},
}
func (i SectionFlag) String() string { return flagName(uint32(i), shfStrings, false) }
func (i SectionFlag) GoString() string { return flagName(uint32(i), shfStrings, true) }
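// SectionFlag values combine as a bitmask, and flagName (defined elsewhere
// in this package) joins the named bits with "+". A small sketch under
// that assumption:
//
//	flags := SHF_ALLOC | SHF_EXECINSTR // typical .text flags
//	fmt.Println(flags)                 // "SHF_ALLOC+SHF_EXECINSTR"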
// Section compression type.
type CompressionType int
const (
COMPRESS_ZLIB CompressionType = 1 /* ZLIB compression. */
COMPRESS_LOOS CompressionType = 0x60000000 /* First OS-specific. */
COMPRESS_HIOS CompressionType = 0x6fffffff /* Last OS-specific. */
COMPRESS_LOPROC CompressionType = 0x70000000 /* First processor-specific type. */
COMPRESS_HIPROC CompressionType = 0x7fffffff /* Last processor-specific type. */
)
var compressionStrings = []intName{
{1, "COMPRESS_ZLIB"},
{0x60000000, "COMPRESS_LOOS"},
{0x6fffffff, "COMPRESS_HIOS"},
{0x70000000, "COMPRESS_LOPROC"},
{0x7fffffff, "COMPRESS_HIPROC"},
}
func (i CompressionType) String() string { return stringName(uint32(i), compressionStrings, false) }
func (i CompressionType) GoString() string { return stringName(uint32(i), compressionStrings, true) }
// Prog.Type
type ProgType int
const (
PT_NULL ProgType = 0 /* Unused entry. */
PT_LOAD ProgType = 1 /* Loadable segment. */
PT_DYNAMIC ProgType = 2 /* Dynamic linking information segment. */
PT_INTERP ProgType = 3 /* Pathname of interpreter. */
PT_NOTE ProgType = 4 /* Auxiliary information. */
PT_SHLIB ProgType = 5 /* Reserved (not used). */
PT_PHDR ProgType = 6 /* Location of program header itself. */
PT_TLS ProgType = 7 /* Thread local storage segment */
PT_LOOS ProgType = 0x60000000 /* First OS-specific. */
PT_GNU_EH_FRAME ProgType = 0x6474e550 /* Frame unwind information */
PT_GNU_STACK ProgType = 0x6474e551 /* Stack flags */
PT_GNU_RELRO ProgType = 0x6474e552 /* Read only after relocs */
PT_GNU_PROPERTY ProgType = 0x6474e553 /* GNU property */
PT_GNU_MBIND_LO ProgType = 0x6474e555 /* Mbind segments start */
PT_GNU_MBIND_HI ProgType = 0x6474f554 /* Mbind segments finish */
PT_PAX_FLAGS ProgType = 0x65041580 /* PAX flags */
PT_OPENBSD_RANDOMIZE ProgType = 0x65a3dbe6 /* Random data */
PT_OPENBSD_WXNEEDED ProgType = 0x65a3dbe7 /* W^X violations */
PT_OPENBSD_BOOTDATA ProgType = 0x65a41be6 /* Boot arguments */
PT_SUNW_EH_FRAME ProgType = 0x6474e550 /* Frame unwind information */
PT_SUNWSTACK ProgType = 0x6ffffffb /* Stack segment */
PT_HIOS ProgType = 0x6fffffff /* Last OS-specific. */
PT_LOPROC ProgType = 0x70000000 /* First processor-specific type. */
PT_ARM_ARCHEXT ProgType = 0x70000000 /* Architecture compatibility */
PT_ARM_EXIDX ProgType = 0x70000001 /* Exception unwind tables */
PT_AARCH64_ARCHEXT ProgType = 0x70000000 /* Architecture compatibility */
PT_AARCH64_UNWIND ProgType = 0x70000001 /* Exception unwind tables */
PT_MIPS_REGINFO ProgType = 0x70000000 /* Register usage */
PT_MIPS_RTPROC ProgType = 0x70000001 /* Runtime procedures */
PT_MIPS_OPTIONS ProgType = 0x70000002 /* Options */
PT_MIPS_ABIFLAGS ProgType = 0x70000003 /* ABI flags */
PT_S390_PGSTE ProgType = 0x70000000 /* 4k page table size */
PT_HIPROC ProgType = 0x7fffffff /* Last processor-specific type. */
)
var ptStrings = []intName{
{0, "PT_NULL"},
{1, "PT_LOAD"},
{2, "PT_DYNAMIC"},
{3, "PT_INTERP"},
{4, "PT_NOTE"},
{5, "PT_SHLIB"},
{6, "PT_PHDR"},
{7, "PT_TLS"},
{0x60000000, "PT_LOOS"},
{0x6474e550, "PT_GNU_EH_FRAME"},
{0x6474e551, "PT_GNU_STACK"},
{0x6474e552, "PT_GNU_RELRO"},
{0x6474e553, "PT_GNU_PROPERTY"},
{0x65041580, "PT_PAX_FLAGS"},
{0x65a3dbe6, "PT_OPENBSD_RANDOMIZE"},
{0x65a3dbe7, "PT_OPENBSD_WXNEEDED"},
{0x65a41be6, "PT_OPENBSD_BOOTDATA"},
{0x6ffffffb, "PT_SUNWSTACK"},
{0x6fffffff, "PT_HIOS"},
{0x70000000, "PT_LOPROC"},
// We don't list the processor-dependent ProgTypes,
// as the values overlap.
{0x7fffffff, "PT_HIPROC"},
}
func (i ProgType) String() string { return stringName(uint32(i), ptStrings, false) }
func (i ProgType) GoString() string { return stringName(uint32(i), ptStrings, true) }
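// A minimal sketch of scanning program headers for a given ProgType,
// assuming this package's File type (defined elsewhere) exposes Progs
// entries with a Type field:
//
//	for _, p := range f.Progs {
//		if p.Type == PT_INTERP {
//			// segment holding the interpreter pathname
//		}
//	}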
// Prog.Flag
type ProgFlag uint32
const (
PF_X ProgFlag = 0x1 /* Executable. */
PF_W ProgFlag = 0x2 /* Writable. */
PF_R ProgFlag = 0x4 /* Readable. */
PF_MASKOS ProgFlag = 0x0ff00000 /* Operating system-specific. */
PF_MASKPROC ProgFlag = 0xf0000000 /* Processor-specific. */
)
var pfStrings = []intName{
{0x1, "PF_X"},
{0x2, "PF_W"},
{0x4, "PF_R"},
}
func (i ProgFlag) String() string { return flagName(uint32(i), pfStrings, false) }
func (i ProgFlag) GoString() string { return flagName(uint32(i), pfStrings, true) }
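// ProgFlag values are also a bitmask; a typical executable text segment
// carries PF_R|PF_X. A small sketch:
//
//	perms := PF_R | PF_X
//	fmt.Println(perms)          // "PF_X+PF_R"
//	writable := perms&PF_W != 0 // false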
// Dyn.Tag
type DynTag int
const (
DT_NULL DynTag = 0 /* Terminating entry. */
DT_NEEDED DynTag = 1 /* String table offset of a needed shared library. */
DT_PLTRELSZ DynTag = 2 /* Total size in bytes of PLT relocations. */
DT_PLTGOT DynTag = 3 /* Processor-dependent address. */
DT_HASH DynTag = 4 /* Address of symbol hash table. */
DT_STRTAB DynTag = 5 /* Address of string table. */
DT_SYMTAB DynTag = 6 /* Address of symbol table. */
DT_RELA DynTag = 7 /* Address of ElfNN_Rela relocations. */
DT_RELASZ DynTag = 8 /* Total size of ElfNN_Rela relocations. */
DT_RELAENT DynTag = 9 /* Size of each ElfNN_Rela relocation entry. */
DT_STRSZ DynTag = 10 /* Size of string table. */
DT_SYMENT DynTag = 11 /* Size of each symbol table entry. */
DT_INIT DynTag = 12 /* Address of initialization function. */
DT_FINI DynTag = 13 /* Address of finalization function. */
DT_SONAME DynTag = 14 /* String table offset of shared object name. */
DT_RPATH DynTag = 15 /* String table offset of library path. [sup] */
DT_SYMBOLIC DynTag = 16 /* Indicates "symbolic" linking. [sup] */
DT_REL DynTag = 17 /* Address of ElfNN_Rel relocations. */
DT_RELSZ DynTag = 18 /* Total size of ElfNN_Rel relocations. */
DT_RELENT DynTag = 19 /* Size of each ElfNN_Rel relocation. */
DT_PLTREL DynTag = 20 /* Type of relocation used for PLT. */
DT_DEBUG DynTag = 21 /* Reserved (not used). */
DT_TEXTREL DynTag = 22 /* Indicates there may be relocations in non-writable segments. [sup] */
DT_JMPREL DynTag = 23 /* Address of PLT relocations. */
DT_BIND_NOW DynTag = 24 /* [sup] */
DT_INIT_ARRAY DynTag = 25 /* Address of the array of pointers to initialization functions */
DT_FINI_ARRAY DynTag = 26 /* Address of the array of pointers to termination functions */
DT_INIT_ARRAYSZ DynTag = 27 /* Size in bytes of the array of initialization functions. */
DT_FINI_ARRAYSZ DynTag = 28 /* Size in bytes of the array of termination functions. */
DT_RUNPATH DynTag = 29 /* String table offset of a null-terminated library search path string. */
DT_FLAGS DynTag = 30 /* Object specific flag values. */
DT_ENCODING DynTag = 32 /* Values greater than or equal to DT_ENCODING
and less than DT_LOOS follow the rules for
the interpretation of the d_un union
as follows: even == 'd_ptr', odd == 'd_val'
or none */
DT_PREINIT_ARRAY DynTag = 32 /* Address of the array of pointers to pre-initialization functions. */
DT_PREINIT_ARRAYSZ DynTag = 33 /* Size in bytes of the array of pre-initialization functions. */
DT_SYMTAB_SHNDX DynTag = 34 /* Address of SHT_SYMTAB_SHNDX section. */
DT_LOOS DynTag = 0x6000000d /* First OS-specific */
DT_HIOS DynTag = 0x6ffff000 /* Last OS-specific */
DT_VALRNGLO DynTag = 0x6ffffd00
DT_GNU_PRELINKED DynTag = 0x6ffffdf5
DT_GNU_CONFLICTSZ DynTag = 0x6ffffdf6
DT_GNU_LIBLISTSZ DynTag = 0x6ffffdf7
DT_CHECKSUM DynTag = 0x6ffffdf8
DT_PLTPADSZ DynTag = 0x6ffffdf9
DT_MOVEENT DynTag = 0x6ffffdfa
DT_MOVESZ DynTag = 0x6ffffdfb
DT_FEATURE DynTag = 0x6ffffdfc
DT_POSFLAG_1 DynTag = 0x6ffffdfd
DT_SYMINSZ DynTag = 0x6ffffdfe
DT_SYMINENT DynTag = 0x6ffffdff
DT_VALRNGHI DynTag = 0x6ffffdff
DT_ADDRRNGLO DynTag = 0x6ffffe00
DT_GNU_HASH DynTag = 0x6ffffef5
DT_TLSDESC_PLT DynTag = 0x6ffffef6
DT_TLSDESC_GOT DynTag = 0x6ffffef7
DT_GNU_CONFLICT DynTag = 0x6ffffef8
DT_GNU_LIBLIST DynTag = 0x6ffffef9
DT_CONFIG DynTag = 0x6ffffefa
DT_DEPAUDIT DynTag = 0x6ffffefb
DT_AUDIT DynTag = 0x6ffffefc
DT_PLTPAD DynTag = 0x6ffffefd
DT_MOVETAB DynTag = 0x6ffffefe
DT_SYMINFO DynTag = 0x6ffffeff
DT_ADDRRNGHI DynTag = 0x6ffffeff
DT_VERSYM DynTag = 0x6ffffff0
DT_RELACOUNT DynTag = 0x6ffffff9
DT_RELCOUNT DynTag = 0x6ffffffa
DT_FLAGS_1 DynTag = 0x6ffffffb
DT_VERDEF DynTag = 0x6ffffffc
DT_VERDEFNUM DynTag = 0x6ffffffd
DT_VERNEED DynTag = 0x6ffffffe
DT_VERNEEDNUM DynTag = 0x6fffffff
DT_LOPROC DynTag = 0x70000000 /* First processor-specific type. */
DT_MIPS_RLD_VERSION DynTag = 0x70000001
DT_MIPS_TIME_STAMP DynTag = 0x70000002
DT_MIPS_ICHECKSUM DynTag = 0x70000003
DT_MIPS_IVERSION DynTag = 0x70000004
DT_MIPS_FLAGS DynTag = 0x70000005
DT_MIPS_BASE_ADDRESS DynTag = 0x70000006
DT_MIPS_MSYM DynTag = 0x70000007
DT_MIPS_CONFLICT DynTag = 0x70000008
DT_MIPS_LIBLIST DynTag = 0x70000009
DT_MIPS_LOCAL_GOTNO DynTag = 0x7000000a
DT_MIPS_CONFLICTNO DynTag = 0x7000000b
DT_MIPS_LIBLISTNO DynTag = 0x70000010
DT_MIPS_SYMTABNO DynTag = 0x70000011
DT_MIPS_UNREFEXTNO DynTag = 0x70000012
DT_MIPS_GOTSYM DynTag = 0x70000013
DT_MIPS_HIPAGENO DynTag = 0x70000014
DT_MIPS_RLD_MAP DynTag = 0x70000016
DT_MIPS_DELTA_CLASS DynTag = 0x70000017
DT_MIPS_DELTA_CLASS_NO DynTag = 0x70000018
DT_MIPS_DELTA_INSTANCE DynTag = 0x70000019
DT_MIPS_DELTA_INSTANCE_NO DynTag = 0x7000001a
DT_MIPS_DELTA_RELOC DynTag = 0x7000001b
DT_MIPS_DELTA_RELOC_NO DynTag = 0x7000001c
DT_MIPS_DELTA_SYM DynTag = 0x7000001d
DT_MIPS_DELTA_SYM_NO DynTag = 0x7000001e
DT_MIPS_DELTA_CLASSSYM DynTag = 0x70000020
DT_MIPS_DELTA_CLASSSYM_NO DynTag = 0x70000021
DT_MIPS_CXX_FLAGS DynTag = 0x70000022
DT_MIPS_PIXIE_INIT DynTag = 0x70000023
DT_MIPS_SYMBOL_LIB DynTag = 0x70000024
DT_MIPS_LOCALPAGE_GOTIDX DynTag = 0x70000025
DT_MIPS_LOCAL_GOTIDX DynTag = 0x70000026
DT_MIPS_HIDDEN_GOTIDX DynTag = 0x70000027
DT_MIPS_PROTECTED_GOTIDX DynTag = 0x70000028
DT_MIPS_OPTIONS DynTag = 0x70000029
DT_MIPS_INTERFACE DynTag = 0x7000002a
DT_MIPS_DYNSTR_ALIGN DynTag = 0x7000002b
DT_MIPS_INTERFACE_SIZE DynTag = 0x7000002c
DT_MIPS_RLD_TEXT_RESOLVE_ADDR DynTag = 0x7000002d
DT_MIPS_PERF_SUFFIX DynTag = 0x7000002e
DT_MIPS_COMPACT_SIZE DynTag = 0x7000002f
DT_MIPS_GP_VALUE DynTag = 0x70000030
DT_MIPS_AUX_DYNAMIC DynTag = 0x70000031
DT_MIPS_PLTGOT DynTag = 0x70000032
DT_MIPS_RWPLT DynTag = 0x70000034
DT_MIPS_RLD_MAP_REL DynTag = 0x70000035
DT_PPC_GOT DynTag = 0x70000000
DT_PPC_OPT DynTag = 0x70000001
DT_PPC64_GLINK DynTag = 0x70000000
DT_PPC64_OPD DynTag = 0x70000001
DT_PPC64_OPDSZ DynTag = 0x70000002
DT_PPC64_OPT DynTag = 0x70000003
DT_SPARC_REGISTER DynTag = 0x70000001
DT_AUXILIARY DynTag = 0x7ffffffd
DT_USED DynTag = 0x7ffffffe
DT_FILTER DynTag = 0x7fffffff
DT_HIPROC DynTag = 0x7fffffff /* Last processor-specific type. */
)
var dtStrings = []intName{
{0, "DT_NULL"},
{1, "DT_NEEDED"},
{2, "DT_PLTRELSZ"},
{3, "DT_PLTGOT"},
{4, "DT_HASH"},
{5, "DT_STRTAB"},
{6, "DT_SYMTAB"},
{7, "DT_RELA"},
{8, "DT_RELASZ"},
{9, "DT_RELAENT"},
{10, "DT_STRSZ"},
{11, "DT_SYMENT"},
{12, "DT_INIT"},
{13, "DT_FINI"},
{14, "DT_SONAME"},
{15, "DT_RPATH"},
{16, "DT_SYMBOLIC"},
{17, "DT_REL"},
{18, "DT_RELSZ"},
{19, "DT_RELENT"},
{20, "DT_PLTREL"},
{21, "DT_DEBUG"},
{22, "DT_TEXTREL"},
{23, "DT_JMPREL"},
{24, "DT_BIND_NOW"},
{25, "DT_INIT_ARRAY"},
{26, "DT_FINI_ARRAY"},
{27, "DT_INIT_ARRAYSZ"},
{28, "DT_FINI_ARRAYSZ"},
{29, "DT_RUNPATH"},
{30, "DT_FLAGS"},
{32, "DT_ENCODING"},
{32, "DT_PREINIT_ARRAY"},
{33, "DT_PREINIT_ARRAYSZ"},
{34, "DT_SYMTAB_SHNDX"},
{0x6000000d, "DT_LOOS"},
{0x6ffff000, "DT_HIOS"},
{0x6ffffd00, "DT_VALRNGLO"},
{0x6ffffdf5, "DT_GNU_PRELINKED"},
{0x6ffffdf6, "DT_GNU_CONFLICTSZ"},
{0x6ffffdf7, "DT_GNU_LIBLISTSZ"},
{0x6ffffdf8, "DT_CHECKSUM"},
{0x6ffffdf9, "DT_PLTPADSZ"},
{0x6ffffdfa, "DT_MOVEENT"},
{0x6ffffdfb, "DT_MOVESZ"},
{0x6ffffdfc, "DT_FEATURE"},
{0x6ffffdfd, "DT_POSFLAG_1"},
{0x6ffffdfe, "DT_SYMINSZ"},
{0x6ffffdff, "DT_SYMINENT"},
{0x6ffffdff, "DT_VALRNGHI"},
{0x6ffffe00, "DT_ADDRRNGLO"},
{0x6ffffef5, "DT_GNU_HASH"},
{0x6ffffef6, "DT_TLSDESC_PLT"},
{0x6ffffef7, "DT_TLSDESC_GOT"},
{0x6ffffef8, "DT_GNU_CONFLICT"},
{0x6ffffef9, "DT_GNU_LIBLIST"},
{0x6ffffefa, "DT_CONFIG"},
{0x6ffffefb, "DT_DEPAUDIT"},
{0x6ffffefc, "DT_AUDIT"},
{0x6ffffefd, "DT_PLTPAD"},
{0x6ffffefe, "DT_MOVETAB"},
{0x6ffffeff, "DT_SYMINFO"},
{0x6ffffeff, "DT_ADDRRNGHI"},
{0x6ffffff0, "DT_VERSYM"},
{0x6ffffff9, "DT_RELACOUNT"},
{0x6ffffffa, "DT_RELCOUNT"},
{0x6ffffffb, "DT_FLAGS_1"},
{0x6ffffffc, "DT_VERDEF"},
{0x6ffffffd, "DT_VERDEFNUM"},
{0x6ffffffe, "DT_VERNEED"},
{0x6fffffff, "DT_VERNEEDNUM"},
{0x70000000, "DT_LOPROC"},
// We don't list the processor-dependent DynTags,
// as the values overlap.
{0x7ffffffd, "DT_AUXILIARY"},
{0x7ffffffe, "DT_USED"},
{0x7fffffff, "DT_FILTER"},
}
func (i DynTag) String() string { return stringName(uint32(i), dtStrings, false) }
func (i DynTag) GoString() string { return stringName(uint32(i), dtStrings, true) }
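// A hedged sketch of the DT_ENCODING rule noted above: for tags in
// [DT_ENCODING, DT_LOOS), an even tag interprets d_un as d_ptr (an
// address) and an odd tag as d_val (an integer):
//
//	func dUnIsPointer(t DynTag) bool {
//		return t >= DT_ENCODING && t < DT_LOOS && t%2 == 0
//	}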
// DT_FLAGS values.
type DynFlag int
const (
DF_ORIGIN DynFlag = 0x0001 /* Indicates that the object being loaded may
make reference to the
$ORIGIN substitution string */
DF_SYMBOLIC DynFlag = 0x0002 /* Indicates "symbolic" linking. */
DF_TEXTREL DynFlag = 0x0004 /* Indicates there may be relocations in non-writable segments. */
DF_BIND_NOW DynFlag = 0x0008 /* Indicates that the dynamic linker should
process all relocations for the object
containing this entry before transferring
control to the program. */
DF_STATIC_TLS DynFlag = 0x0010 /* Indicates that the shared object or
executable contains code using a static
thread-local storage scheme. */
)
var dflagStrings = []intName{
{0x0001, "DF_ORIGIN"},
{0x0002, "DF_SYMBOLIC"},
{0x0004, "DF_TEXTREL"},
{0x0008, "DF_BIND_NOW"},
{0x0010, "DF_STATIC_TLS"},
}
func (i DynFlag) String() string { return flagName(uint32(i), dflagStrings, false) }
func (i DynFlag) GoString() string { return flagName(uint32(i), dflagStrings, true) }
// DT_FLAGS_1 values.
type DynFlag1 uint32
const (
// Indicates that all relocations for this object must be processed before
// returning control to the program.
DF_1_NOW DynFlag1 = 0x00000001
// Unused.
DF_1_GLOBAL DynFlag1 = 0x00000002
// Indicates that the object is a member of a group.
DF_1_GROUP DynFlag1 = 0x00000004
// Indicates that the object cannot be deleted from a process.
DF_1_NODELETE DynFlag1 = 0x00000008
// Meaningful only for filters. Indicates that all associated filtees are
// to be processed immediately.
DF_1_LOADFLTR DynFlag1 = 0x00000010
// Indicates that this object's initialization section is to be run before
// those of any other objects loaded.
DF_1_INITFIRST DynFlag1 = 0x00000020
// Indicates that the object cannot be added to a running process with dlopen.
DF_1_NOOPEN DynFlag1 = 0x00000040
// Indicates the object requires $ORIGIN processing.
DF_1_ORIGIN DynFlag1 = 0x00000080
// Indicates that the object should use direct binding information.
DF_1_DIRECT DynFlag1 = 0x00000100
// Unused.
DF_1_TRANS DynFlag1 = 0x00000200
// Indicates that the object's symbol table is to interpose before all
// symbols except the primary load object, which is typically the executable.
DF_1_INTERPOSE DynFlag1 = 0x00000400
// Indicates that the search for dependencies of this object ignores any
// default library search paths.
DF_1_NODEFLIB DynFlag1 = 0x00000800
// Indicates that this object is not dumped by dldump. Candidates are objects
// with no relocations that might get included when generating alternative
// objects.
DF_1_NODUMP DynFlag1 = 0x00001000
// Identifies this object as a configuration alternative object generated by
// crle. Triggers the runtime linker to search for a configuration file $ORIGIN/ld.config.app-name.
DF_1_CONFALT DynFlag1 = 0x00002000
// Meaningful only for filtees. Terminates a filter's search for any
// further filtees.
DF_1_ENDFILTEE DynFlag1 = 0x00004000
// Indicates that this object has displacement relocations applied.
DF_1_DISPRELDNE DynFlag1 = 0x00008000
// Indicates that this object has displacement relocations pending.
DF_1_DISPRELPND DynFlag1 = 0x00010000
// Indicates that this object contains symbols that cannot be directly
// bound to.
DF_1_NODIRECT DynFlag1 = 0x00020000
// Reserved for internal use by the kernel runtime-linker.
DF_1_IGNMULDEF DynFlag1 = 0x00040000
// Reserved for internal use by the kernel runtime-linker.
DF_1_NOKSYMS DynFlag1 = 0x00080000
// Reserved for internal use by the kernel runtime-linker.
DF_1_NOHDR DynFlag1 = 0x00100000
// Indicates that this object has been edited or modified since its
// original construction by the link-editor.
DF_1_EDITED DynFlag1 = 0x00200000
// Reserved for internal use by the kernel runtime-linker.
DF_1_NORELOC DynFlag1 = 0x00400000
// Indicates that the object contains individual symbols that should interpose
// before all symbols except the primary load object, which is typically the
// executable.
DF_1_SYMINTPOSE DynFlag1 = 0x00800000
// Indicates that the executable requires global auditing.
DF_1_GLOBAUDIT DynFlag1 = 0x01000000
// Indicates that the object defines, or makes reference to, singleton symbols.
DF_1_SINGLETON DynFlag1 = 0x02000000
// Indicates that the object is a stub.
DF_1_STUB DynFlag1 = 0x04000000
// Indicates that the object is a position-independent executable.
DF_1_PIE DynFlag1 = 0x08000000
// Indicates that the object is a kernel module.
DF_1_KMOD DynFlag1 = 0x10000000
// Indicates that the object is a weak standard filter.
DF_1_WEAKFILTER DynFlag1 = 0x20000000
// Unused.
DF_1_NOCOMMON DynFlag1 = 0x40000000
)
var dflag1Strings = []intName{
{0x00000001, "DF_1_NOW"},
{0x00000002, "DF_1_GLOBAL"},
{0x00000004, "DF_1_GROUP"},
{0x00000008, "DF_1_NODELETE"},
{0x00000010, "DF_1_LOADFLTR"},
{0x00000020, "DF_1_INITFIRST"},
{0x00000040, "DF_1_NOOPEN"},
{0x00000080, "DF_1_ORIGIN"},
{0x00000100, "DF_1_DIRECT"},
{0x00000200, "DF_1_TRANS"},
{0x00000400, "DF_1_INTERPOSE"},
{0x00000800, "DF_1_NODEFLIB"},
{0x00001000, "DF_1_NODUMP"},
{0x00002000, "DF_1_CONFALT"},
{0x00004000, "DF_1_ENDFILTEE"},
{0x00008000, "DF_1_DISPRELDNE"},
{0x00010000, "DF_1_DISPRELPND"},
{0x00020000, "DF_1_NODIRECT"},
{0x00040000, "DF_1_IGNMULDEF"},
{0x00080000, "DF_1_NOKSYMS"},
{0x00100000, "DF_1_NOHDR"},
{0x00200000, "DF_1_EDITED"},
{0x00400000, "DF_1_NORELOC"},
{0x00800000, "DF_1_SYMINTPOSE"},
{0x01000000, "DF_1_GLOBAUDIT"},
{0x02000000, "DF_1_SINGLETON"},
{0x04000000, "DF_1_STUB"},
{0x08000000, "DF_1_PIE"},
{0x10000000, "DF_1_KMOD"},
{0x20000000, "DF_1_WEAKFILTER"},
{0x40000000, "DF_1_NOCOMMON"},
}
func (i DynFlag1) String() string { return flagName(uint32(i), dflag1Strings, false) }
func (i DynFlag1) GoString() string { return flagName(uint32(i), dflag1Strings, true) }
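// A small sketch of testing DT_FLAGS_1 bits, assuming raw is the tag's
// d_val as read from the dynamic section:
//
//	flags1 := DynFlag1(raw)
//	isPIE := flags1&DF_1_PIE != 0
//	noOpen := flags1&DF_1_NOOPEN != 0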
// NType values; used in core files.
type NType int
const (
NT_PRSTATUS NType = 1 /* Process status. */
NT_FPREGSET NType = 2 /* Floating point registers. */
NT_PRPSINFO NType = 3 /* Process state info. */
)
var ntypeStrings = []intName{
{1, "NT_PRSTATUS"},
{2, "NT_FPREGSET"},
{3, "NT_PRPSINFO"},
}
func (i NType) String() string { return stringName(uint32(i), ntypeStrings, false) }
func (i NType) GoString() string { return stringName(uint32(i), ntypeStrings, true) }
/* Symbol Binding - ELFNN_ST_BIND - st_info */
type SymBind int
const (
STB_LOCAL SymBind = 0 /* Local symbol */
STB_GLOBAL SymBind = 1 /* Global symbol */
STB_WEAK SymBind = 2 /* like global - lower precedence */
STB_LOOS SymBind = 10 /* Reserved range for operating system */
STB_HIOS SymBind = 12 /* specific semantics. */
STB_LOPROC SymBind = 13 /* reserved range for processor */
STB_HIPROC SymBind = 15 /* specific semantics. */
)
var stbStrings = []intName{
{0, "STB_LOCAL"},
{1, "STB_GLOBAL"},
{2, "STB_WEAK"},
{10, "STB_LOOS"},
{12, "STB_HIOS"},
{13, "STB_LOPROC"},
{15, "STB_HIPROC"},
}
func (i SymBind) String() string { return stringName(uint32(i), stbStrings, false) }
func (i SymBind) GoString() string { return stringName(uint32(i), stbStrings, true) }
/* Symbol type - ELFNN_ST_TYPE - st_info */
type SymType int
const (
STT_NOTYPE SymType = 0 /* Unspecified type. */
STT_OBJECT SymType = 1 /* Data object. */
STT_FUNC SymType = 2 /* Function. */
STT_SECTION SymType = 3 /* Section. */
STT_FILE SymType = 4 /* Source file. */
STT_COMMON SymType = 5 /* Uninitialized common block. */
STT_TLS SymType = 6 /* TLS object. */
STT_LOOS SymType = 10 /* Reserved range for operating system */
STT_HIOS SymType = 12 /* specific semantics. */
STT_LOPROC SymType = 13 /* reserved range for processor */
STT_HIPROC SymType = 15 /* specific semantics. */
)
var sttStrings = []intName{
{0, "STT_NOTYPE"},
{1, "STT_OBJECT"},
{2, "STT_FUNC"},
{3, "STT_SECTION"},
{4, "STT_FILE"},
{5, "STT_COMMON"},
{6, "STT_TLS"},
{10, "STT_LOOS"},
{12, "STT_HIOS"},
{13, "STT_LOPROC"},
{15, "STT_HIPROC"},
}
func (i SymType) String() string { return stringName(uint32(i), sttStrings, false) }
func (i SymType) GoString() string { return stringName(uint32(i), sttStrings, true) }
/* Symbol visibility - ELFNN_ST_VISIBILITY - st_other */
type SymVis int
const (
STV_DEFAULT SymVis = 0x0 /* Default visibility (see binding). */
STV_INTERNAL SymVis = 0x1 /* Special meaning in relocatable objects. */
STV_HIDDEN SymVis = 0x2 /* Not visible. */
STV_PROTECTED SymVis = 0x3 /* Visible but not preemptible. */
)
var stvStrings = []intName{
{0x0, "STV_DEFAULT"},
{0x1, "STV_INTERNAL"},
{0x2, "STV_HIDDEN"},
{0x3, "STV_PROTECTED"},
}
func (i SymVis) String() string { return stringName(uint32(i), stvStrings, false) }
func (i SymVis) GoString() string { return stringName(uint32(i), stvStrings, true) }
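// Binding, type, and visibility are packed into each symbol table entry.
// A hedged sketch, assuming info and other are the raw st_info and
// st_other bytes:
//
//	bind := SymBind(info >> 4) // high four bits of st_info
//	typ := SymType(info & 0xf) // low four bits of st_info
//	vis := SymVis(other & 0x3) // low two bits of st_other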
/*
* Relocation types.
*/
// Relocation types for x86-64.
type R_X86_64 int
const (
R_X86_64_NONE R_X86_64 = 0 /* No relocation. */
R_X86_64_64 R_X86_64 = 1 /* Add 64 bit symbol value. */
R_X86_64_PC32 R_X86_64 = 2 /* PC-relative 32 bit signed sym value. */
R_X86_64_GOT32 R_X86_64 = 3 /* PC-relative 32 bit GOT offset. */
R_X86_64_PLT32 R_X86_64 = 4 /* PC-relative 32 bit PLT offset. */
R_X86_64_COPY R_X86_64 = 5 /* Copy data from shared object. */
R_X86_64_GLOB_DAT R_X86_64 = 6 /* Set GOT entry to data address. */
R_X86_64_JMP_SLOT R_X86_64 = 7 /* Set GOT entry to code address. */
R_X86_64_RELATIVE R_X86_64 = 8 /* Add load address of shared object. */
R_X86_64_GOTPCREL R_X86_64 = 9 /* Add 32 bit signed pcrel offset to GOT. */
R_X86_64_32 R_X86_64 = 10 /* Add 32 bit zero extended symbol value */
R_X86_64_32S R_X86_64 = 11 /* Add 32 bit sign extended symbol value */
R_X86_64_16 R_X86_64 = 12 /* Add 16 bit zero extended symbol value */
R_X86_64_PC16 R_X86_64 = 13 /* Add 16 bit signed extended pc relative symbol value */
R_X86_64_8 R_X86_64 = 14 /* Add 8 bit zero extended symbol value */
R_X86_64_PC8 R_X86_64 = 15 /* Add 8 bit signed extended pc relative symbol value */
R_X86_64_DTPMOD64 R_X86_64 = 16 /* ID of module containing symbol */
R_X86_64_DTPOFF64 R_X86_64 = 17 /* Offset in TLS block */
R_X86_64_TPOFF64 R_X86_64 = 18 /* Offset in static TLS block */
R_X86_64_TLSGD R_X86_64 = 19 /* PC relative offset to GD GOT entry */
R_X86_64_TLSLD R_X86_64 = 20 /* PC relative offset to LD GOT entry */
R_X86_64_DTPOFF32 R_X86_64 = 21 /* Offset in TLS block */
R_X86_64_GOTTPOFF R_X86_64 = 22 /* PC relative offset to IE GOT entry */
R_X86_64_TPOFF32 R_X86_64 = 23 /* Offset in static TLS block */
R_X86_64_PC64 R_X86_64 = 24 /* PC relative 64-bit sign extended symbol value. */
R_X86_64_GOTOFF64 R_X86_64 = 25
R_X86_64_GOTPC32 R_X86_64 = 26
R_X86_64_GOT64 R_X86_64 = 27
R_X86_64_GOTPCREL64 R_X86_64 = 28
R_X86_64_GOTPC64 R_X86_64 = 29
R_X86_64_GOTPLT64 R_X86_64 = 30
R_X86_64_PLTOFF64 R_X86_64 = 31
R_X86_64_SIZE32 R_X86_64 = 32
R_X86_64_SIZE64 R_X86_64 = 33
R_X86_64_GOTPC32_TLSDESC R_X86_64 = 34
R_X86_64_TLSDESC_CALL R_X86_64 = 35
R_X86_64_TLSDESC R_X86_64 = 36
R_X86_64_IRELATIVE R_X86_64 = 37
R_X86_64_RELATIVE64 R_X86_64 = 38
R_X86_64_PC32_BND R_X86_64 = 39
R_X86_64_PLT32_BND R_X86_64 = 40
R_X86_64_GOTPCRELX R_X86_64 = 41
R_X86_64_REX_GOTPCRELX R_X86_64 = 42
)
var rx86_64Strings = []intName{
{0, "R_X86_64_NONE"},
{1, "R_X86_64_64"},
{2, "R_X86_64_PC32"},
{3, "R_X86_64_GOT32"},
{4, "R_X86_64_PLT32"},
{5, "R_X86_64_COPY"},
{6, "R_X86_64_GLOB_DAT"},
{7, "R_X86_64_JMP_SLOT"},
{8, "R_X86_64_RELATIVE"},
{9, "R_X86_64_GOTPCREL"},
{10, "R_X86_64_32"},
{11, "R_X86_64_32S"},
{12, "R_X86_64_16"},
{13, "R_X86_64_PC16"},
{14, "R_X86_64_8"},
{15, "R_X86_64_PC8"},
{16, "R_X86_64_DTPMOD64"},
{17, "R_X86_64_DTPOFF64"},
{18, "R_X86_64_TPOFF64"},
{19, "R_X86_64_TLSGD"},
{20, "R_X86_64_TLSLD"},
{21, "R_X86_64_DTPOFF32"},
{22, "R_X86_64_GOTTPOFF"},
{23, "R_X86_64_TPOFF32"},
{24, "R_X86_64_PC64"},
{25, "R_X86_64_GOTOFF64"},
{26, "R_X86_64_GOTPC32"},
{27, "R_X86_64_GOT64"},
{28, "R_X86_64_GOTPCREL64"},
{29, "R_X86_64_GOTPC64"},
{30, "R_X86_64_GOTPLT64"},
{31, "R_X86_64_PLTOFF64"},
{32, "R_X86_64_SIZE32"},
{33, "R_X86_64_SIZE64"},
{34, "R_X86_64_GOTPC32_TLSDESC"},
{35, "R_X86_64_TLSDESC_CALL"},
{36, "R_X86_64_TLSDESC"},
{37, "R_X86_64_IRELATIVE"},
{38, "R_X86_64_RELATIVE64"},
{39, "R_X86_64_PC32_BND"},
{40, "R_X86_64_PLT32_BND"},
{41, "R_X86_64_GOTPCRELX"},
{42, "R_X86_64_REX_GOTPCRELX"},
}
func (i R_X86_64) String() string { return stringName(uint32(i), rx86_64Strings, false) }
func (i R_X86_64) GoString() string { return stringName(uint32(i), rx86_64Strings, true) }
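// In ELF64 relocation entries, the relocation type lives in the low 32
// bits of r_info and the symbol index in the high 32 bits. A minimal
// sketch, assuming rInfo is a raw 64-bit r_info value:
//
//	typ := R_X86_64(rInfo & 0xffffffff)
//	sym := uint32(rInfo >> 32)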
// Relocation types for AArch64 (aka arm64).
type R_AARCH64 int
const (
R_AARCH64_NONE R_AARCH64 = 0
R_AARCH64_P32_ABS32 R_AARCH64 = 1
R_AARCH64_P32_ABS16 R_AARCH64 = 2
R_AARCH64_P32_PREL32 R_AARCH64 = 3
R_AARCH64_P32_PREL16 R_AARCH64 = 4
R_AARCH64_P32_MOVW_UABS_G0 R_AARCH64 = 5
R_AARCH64_P32_MOVW_UABS_G0_NC R_AARCH64 = 6
R_AARCH64_P32_MOVW_UABS_G1 R_AARCH64 = 7
R_AARCH64_P32_MOVW_SABS_G0 R_AARCH64 = 8
R_AARCH64_P32_LD_PREL_LO19 R_AARCH64 = 9
R_AARCH64_P32_ADR_PREL_LO21 R_AARCH64 = 10
R_AARCH64_P32_ADR_PREL_PG_HI21 R_AARCH64 = 11
R_AARCH64_P32_ADD_ABS_LO12_NC R_AARCH64 = 12
R_AARCH64_P32_LDST8_ABS_LO12_NC R_AARCH64 = 13
R_AARCH64_P32_LDST16_ABS_LO12_NC R_AARCH64 = 14
R_AARCH64_P32_LDST32_ABS_LO12_NC R_AARCH64 = 15
R_AARCH64_P32_LDST64_ABS_LO12_NC R_AARCH64 = 16
R_AARCH64_P32_LDST128_ABS_LO12_NC R_AARCH64 = 17
R_AARCH64_P32_TSTBR14 R_AARCH64 = 18
R_AARCH64_P32_CONDBR19 R_AARCH64 = 19
R_AARCH64_P32_JUMP26 R_AARCH64 = 20
R_AARCH64_P32_CALL26 R_AARCH64 = 21
R_AARCH64_P32_GOT_LD_PREL19 R_AARCH64 = 25
R_AARCH64_P32_ADR_GOT_PAGE R_AARCH64 = 26
R_AARCH64_P32_LD32_GOT_LO12_NC R_AARCH64 = 27
R_AARCH64_P32_TLSGD_ADR_PAGE21 R_AARCH64 = 81
R_AARCH64_P32_TLSGD_ADD_LO12_NC R_AARCH64 = 82
R_AARCH64_P32_TLSIE_ADR_GOTTPREL_PAGE21 R_AARCH64 = 103
R_AARCH64_P32_TLSIE_LD32_GOTTPREL_LO12_NC R_AARCH64 = 104
R_AARCH64_P32_TLSIE_LD_GOTTPREL_PREL19 R_AARCH64 = 105
R_AARCH64_P32_TLSLE_MOVW_TPREL_G1 R_AARCH64 = 106
R_AARCH64_P32_TLSLE_MOVW_TPREL_G0 R_AARCH64 = 107
R_AARCH64_P32_TLSLE_MOVW_TPREL_G0_NC R_AARCH64 = 108
R_AARCH64_P32_TLSLE_ADD_TPREL_HI12 R_AARCH64 = 109
R_AARCH64_P32_TLSLE_ADD_TPREL_LO12 R_AARCH64 = 110
R_AARCH64_P32_TLSLE_ADD_TPREL_LO12_NC R_AARCH64 = 111
R_AARCH64_P32_TLSDESC_LD_PREL19 R_AARCH64 = 122
R_AARCH64_P32_TLSDESC_ADR_PREL21 R_AARCH64 = 123
R_AARCH64_P32_TLSDESC_ADR_PAGE21 R_AARCH64 = 124
R_AARCH64_P32_TLSDESC_LD32_LO12_NC R_AARCH64 = 125
R_AARCH64_P32_TLSDESC_ADD_LO12_NC R_AARCH64 = 126
R_AARCH64_P32_TLSDESC_CALL R_AARCH64 = 127
R_AARCH64_P32_COPY R_AARCH64 = 180
R_AARCH64_P32_GLOB_DAT R_AARCH64 = 181
R_AARCH64_P32_JUMP_SLOT R_AARCH64 = 182
R_AARCH64_P32_RELATIVE R_AARCH64 = 183
R_AARCH64_P32_TLS_DTPMOD R_AARCH64 = 184
R_AARCH64_P32_TLS_DTPREL R_AARCH64 = 185
R_AARCH64_P32_TLS_TPREL R_AARCH64 = 186
R_AARCH64_P32_TLSDESC R_AARCH64 = 187
R_AARCH64_P32_IRELATIVE R_AARCH64 = 188
R_AARCH64_NULL R_AARCH64 = 256
R_AARCH64_ABS64 R_AARCH64 = 257
R_AARCH64_ABS32 R_AARCH64 = 258
R_AARCH64_ABS16 R_AARCH64 = 259
R_AARCH64_PREL64 R_AARCH64 = 260
R_AARCH64_PREL32 R_AARCH64 = 261
R_AARCH64_PREL16 R_AARCH64 = 262
R_AARCH64_MOVW_UABS_G0 R_AARCH64 = 263
R_AARCH64_MOVW_UABS_G0_NC R_AARCH64 = 264
R_AARCH64_MOVW_UABS_G1 R_AARCH64 = 265
R_AARCH64_MOVW_UABS_G1_NC R_AARCH64 = 266
R_AARCH64_MOVW_UABS_G2 R_AARCH64 = 267
R_AARCH64_MOVW_UABS_G2_NC R_AARCH64 = 268
R_AARCH64_MOVW_UABS_G3 R_AARCH64 = 269
R_AARCH64_MOVW_SABS_G0 R_AARCH64 = 270
R_AARCH64_MOVW_SABS_G1 R_AARCH64 = 271
R_AARCH64_MOVW_SABS_G2 R_AARCH64 = 272
R_AARCH64_LD_PREL_LO19 R_AARCH64 = 273
R_AARCH64_ADR_PREL_LO21 R_AARCH64 = 274
R_AARCH64_ADR_PREL_PG_HI21 R_AARCH64 = 275
R_AARCH64_ADR_PREL_PG_HI21_NC R_AARCH64 = 276
R_AARCH64_ADD_ABS_LO12_NC R_AARCH64 = 277
R_AARCH64_LDST8_ABS_LO12_NC R_AARCH64 = 278
R_AARCH64_TSTBR14 R_AARCH64 = 279
R_AARCH64_CONDBR19 R_AARCH64 = 280
R_AARCH64_JUMP26 R_AARCH64 = 282
R_AARCH64_CALL26 R_AARCH64 = 283
R_AARCH64_LDST16_ABS_LO12_NC R_AARCH64 = 284
R_AARCH64_LDST32_ABS_LO12_NC R_AARCH64 = 285
R_AARCH64_LDST64_ABS_LO12_NC R_AARCH64 = 286
R_AARCH64_LDST128_ABS_LO12_NC R_AARCH64 = 299
R_AARCH64_GOT_LD_PREL19 R_AARCH64 = 309
R_AARCH64_LD64_GOTOFF_LO15 R_AARCH64 = 310
R_AARCH64_ADR_GOT_PAGE R_AARCH64 = 311
R_AARCH64_LD64_GOT_LO12_NC R_AARCH64 = 312
R_AARCH64_LD64_GOTPAGE_LO15 R_AARCH64 = 313
R_AARCH64_TLSGD_ADR_PREL21 R_AARCH64 = 512
R_AARCH64_TLSGD_ADR_PAGE21 R_AARCH64 = 513
R_AARCH64_TLSGD_ADD_LO12_NC R_AARCH64 = 514
R_AARCH64_TLSGD_MOVW_G1 R_AARCH64 = 515
R_AARCH64_TLSGD_MOVW_G0_NC R_AARCH64 = 516
R_AARCH64_TLSLD_ADR_PREL21 R_AARCH64 = 517
R_AARCH64_TLSLD_ADR_PAGE21 R_AARCH64 = 518
R_AARCH64_TLSIE_MOVW_GOTTPREL_G1 R_AARCH64 = 539
R_AARCH64_TLSIE_MOVW_GOTTPREL_G0_NC R_AARCH64 = 540
R_AARCH64_TLSIE_ADR_GOTTPREL_PAGE21 R_AARCH64 = 541
R_AARCH64_TLSIE_LD64_GOTTPREL_LO12_NC R_AARCH64 = 542
R_AARCH64_TLSIE_LD_GOTTPREL_PREL19 R_AARCH64 = 543
R_AARCH64_TLSLE_MOVW_TPREL_G2 R_AARCH64 = 544
R_AARCH64_TLSLE_MOVW_TPREL_G1 R_AARCH64 = 545
R_AARCH64_TLSLE_MOVW_TPREL_G1_NC R_AARCH64 = 546
R_AARCH64_TLSLE_MOVW_TPREL_G0 R_AARCH64 = 547
R_AARCH64_TLSLE_MOVW_TPREL_G0_NC R_AARCH64 = 548
R_AARCH64_TLSLE_ADD_TPREL_HI12 R_AARCH64 = 549
R_AARCH64_TLSLE_ADD_TPREL_LO12 R_AARCH64 = 550
R_AARCH64_TLSLE_ADD_TPREL_LO12_NC R_AARCH64 = 551
R_AARCH64_TLSDESC_LD_PREL19 R_AARCH64 = 560
R_AARCH64_TLSDESC_ADR_PREL21 R_AARCH64 = 561
R_AARCH64_TLSDESC_ADR_PAGE21 R_AARCH64 = 562
R_AARCH64_TLSDESC_LD64_LO12_NC R_AARCH64 = 563
R_AARCH64_TLSDESC_ADD_LO12_NC R_AARCH64 = 564
R_AARCH64_TLSDESC_OFF_G1 R_AARCH64 = 565
R_AARCH64_TLSDESC_OFF_G0_NC R_AARCH64 = 566
R_AARCH64_TLSDESC_LDR R_AARCH64 = 567
R_AARCH64_TLSDESC_ADD R_AARCH64 = 568
R_AARCH64_TLSDESC_CALL R_AARCH64 = 569
R_AARCH64_TLSLE_LDST128_TPREL_LO12 R_AARCH64 = 570
R_AARCH64_TLSLE_LDST128_TPREL_LO12_NC R_AARCH64 = 571
R_AARCH64_TLSLD_LDST128_DTPREL_LO12 R_AARCH64 = 572
R_AARCH64_TLSLD_LDST128_DTPREL_LO12_NC R_AARCH64 = 573
R_AARCH64_COPY R_AARCH64 = 1024
R_AARCH64_GLOB_DAT R_AARCH64 = 1025
R_AARCH64_JUMP_SLOT R_AARCH64 = 1026
R_AARCH64_RELATIVE R_AARCH64 = 1027
R_AARCH64_TLS_DTPMOD64 R_AARCH64 = 1028
R_AARCH64_TLS_DTPREL64 R_AARCH64 = 1029
R_AARCH64_TLS_TPREL64 R_AARCH64 = 1030
R_AARCH64_TLSDESC R_AARCH64 = 1031
R_AARCH64_IRELATIVE R_AARCH64 = 1032
)
var raarch64Strings = []intName{
{0, "R_AARCH64_NONE"},
{1, "R_AARCH64_P32_ABS32"},
{2, "R_AARCH64_P32_ABS16"},
{3, "R_AARCH64_P32_PREL32"},
{4, "R_AARCH64_P32_PREL16"},
{5, "R_AARCH64_P32_MOVW_UABS_G0"},
{6, "R_AARCH64_P32_MOVW_UABS_G0_NC"},
{7, "R_AARCH64_P32_MOVW_UABS_G1"},
{8, "R_AARCH64_P32_MOVW_SABS_G0"},
{9, "R_AARCH64_P32_LD_PREL_LO19"},
{10, "R_AARCH64_P32_ADR_PREL_LO21"},
{11, "R_AARCH64_P32_ADR_PREL_PG_HI21"},
{12, "R_AARCH64_P32_ADD_ABS_LO12_NC"},
{13, "R_AARCH64_P32_LDST8_ABS_LO12_NC"},
{14, "R_AARCH64_P32_LDST16_ABS_LO12_NC"},
{15, "R_AARCH64_P32_LDST32_ABS_LO12_NC"},
{16, "R_AARCH64_P32_LDST64_ABS_LO12_NC"},
{17, "R_AARCH64_P32_LDST128_ABS_LO12_NC"},
{18, "R_AARCH64_P32_TSTBR14"},
{19, "R_AARCH64_P32_CONDBR19"},
{20, "R_AARCH64_P32_JUMP26"},
{21, "R_AARCH64_P32_CALL26"},
{25, "R_AARCH64_P32_GOT_LD_PREL19"},
{26, "R_AARCH64_P32_ADR_GOT_PAGE"},
{27, "R_AARCH64_P32_LD32_GOT_LO12_NC"},
{81, "R_AARCH64_P32_TLSGD_ADR_PAGE21"},
{82, "R_AARCH64_P32_TLSGD_ADD_LO12_NC"},
{103, "R_AARCH64_P32_TLSIE_ADR_GOTTPREL_PAGE21"},
{104, "R_AARCH64_P32_TLSIE_LD32_GOTTPREL_LO12_NC"},
{105, "R_AARCH64_P32_TLSIE_LD_GOTTPREL_PREL19"},
{106, "R_AARCH64_P32_TLSLE_MOVW_TPREL_G1"},
{107, "R_AARCH64_P32_TLSLE_MOVW_TPREL_G0"},
{108, "R_AARCH64_P32_TLSLE_MOVW_TPREL_G0_NC"},
{109, "R_AARCH64_P32_TLSLE_ADD_TPREL_HI12"},
{110, "R_AARCH64_P32_TLSLE_ADD_TPREL_LO12"},
{111, "R_AARCH64_P32_TLSLE_ADD_TPREL_LO12_NC"},
{122, "R_AARCH64_P32_TLSDESC_LD_PREL19"},
{123, "R_AARCH64_P32_TLSDESC_ADR_PREL21"},
{124, "R_AARCH64_P32_TLSDESC_ADR_PAGE21"},
{125, "R_AARCH64_P32_TLSDESC_LD32_LO12_NC"},
{126, "R_AARCH64_P32_TLSDESC_ADD_LO12_NC"},
{127, "R_AARCH64_P32_TLSDESC_CALL"},
{180, "R_AARCH64_P32_COPY"},
{181, "R_AARCH64_P32_GLOB_DAT"},
{182, "R_AARCH64_P32_JUMP_SLOT"},
{183, "R_AARCH64_P32_RELATIVE"},
{184, "R_AARCH64_P32_TLS_DTPMOD"},
{185, "R_AARCH64_P32_TLS_DTPREL"},
{186, "R_AARCH64_P32_TLS_TPREL"},
{187, "R_AARCH64_P32_TLSDESC"},
{188, "R_AARCH64_P32_IRELATIVE"},
{256, "R_AARCH64_NULL"},
{257, "R_AARCH64_ABS64"},
{258, "R_AARCH64_ABS32"},
{259, "R_AARCH64_ABS16"},
{260, "R_AARCH64_PREL64"},
{261, "R_AARCH64_PREL32"},
{262, "R_AARCH64_PREL16"},
{263, "R_AARCH64_MOVW_UABS_G0"},
{264, "R_AARCH64_MOVW_UABS_G0_NC"},
{265, "R_AARCH64_MOVW_UABS_G1"},
{266, "R_AARCH64_MOVW_UABS_G1_NC"},
{267, "R_AARCH64_MOVW_UABS_G2"},
{268, "R_AARCH64_MOVW_UABS_G2_NC"},
{269, "R_AARCH64_MOVW_UABS_G3"},
{270, "R_AARCH64_MOVW_SABS_G0"},
{271, "R_AARCH64_MOVW_SABS_G1"},
{272, "R_AARCH64_MOVW_SABS_G2"},
{273, "R_AARCH64_LD_PREL_LO19"},
{274, "R_AARCH64_ADR_PREL_LO21"},
{275, "R_AARCH64_ADR_PREL_PG_HI21"},
{276, "R_AARCH64_ADR_PREL_PG_HI21_NC"},
{277, "R_AARCH64_ADD_ABS_LO12_NC"},
{278, "R_AARCH64_LDST8_ABS_LO12_NC"},
{279, "R_AARCH64_TSTBR14"},
{280, "R_AARCH64_CONDBR19"},
{282, "R_AARCH64_JUMP26"},
{283, "R_AARCH64_CALL26"},
{284, "R_AARCH64_LDST16_ABS_LO12_NC"},
{285, "R_AARCH64_LDST32_ABS_LO12_NC"},
{286, "R_AARCH64_LDST64_ABS_LO12_NC"},
{299, "R_AARCH64_LDST128_ABS_LO12_NC"},
{309, "R_AARCH64_GOT_LD_PREL19"},
{310, "R_AARCH64_LD64_GOTOFF_LO15"},
{311, "R_AARCH64_ADR_GOT_PAGE"},
{312, "R_AARCH64_LD64_GOT_LO12_NC"},
{313, "R_AARCH64_LD64_GOTPAGE_LO15"},
{512, "R_AARCH64_TLSGD_ADR_PREL21"},
{513, "R_AARCH64_TLSGD_ADR_PAGE21"},
{514, "R_AARCH64_TLSGD_ADD_LO12_NC"},
{515, "R_AARCH64_TLSGD_MOVW_G1"},
{516, "R_AARCH64_TLSGD_MOVW_G0_NC"},
{517, "R_AARCH64_TLSLD_ADR_PREL21"},
{518, "R_AARCH64_TLSLD_ADR_PAGE21"},
{539, "R_AARCH64_TLSIE_MOVW_GOTTPREL_G1"},
{540, "R_AARCH64_TLSIE_MOVW_GOTTPREL_G0_NC"},
{541, "R_AARCH64_TLSIE_ADR_GOTTPREL_PAGE21"},
{542, "R_AARCH64_TLSIE_LD64_GOTTPREL_LO12_NC"},
{543, "R_AARCH64_TLSIE_LD_GOTTPREL_PREL19"},
{544, "R_AARCH64_TLSLE_MOVW_TPREL_G2"},
{545, "R_AARCH64_TLSLE_MOVW_TPREL_G1"},
{546, "R_AARCH64_TLSLE_MOVW_TPREL_G1_NC"},
{547, "R_AARCH64_TLSLE_MOVW_TPREL_G0"},
{548, "R_AARCH64_TLSLE_MOVW_TPREL_G0_NC"},
{549, "R_AARCH64_TLSLE_ADD_TPREL_HI12"},
{550, "R_AARCH64_TLSLE_ADD_TPREL_LO12"},
{551, "R_AARCH64_TLSLE_ADD_TPREL_LO12_NC"},
{560, "R_AARCH64_TLSDESC_LD_PREL19"},
{561, "R_AARCH64_TLSDESC_ADR_PREL21"},
{562, "R_AARCH64_TLSDESC_ADR_PAGE21"},
{563, "R_AARCH64_TLSDESC_LD64_LO12_NC"},
{564, "R_AARCH64_TLSDESC_ADD_LO12_NC"},
{565, "R_AARCH64_TLSDESC_OFF_G1"},
{566, "R_AARCH64_TLSDESC_OFF_G0_NC"},
{567, "R_AARCH64_TLSDESC_LDR"},
{568, "R_AARCH64_TLSDESC_ADD"},
{569, "R_AARCH64_TLSDESC_CALL"},
{570, "R_AARCH64_TLSLE_LDST128_TPREL_LO12"},
{571, "R_AARCH64_TLSLE_LDST128_TPREL_LO12_NC"},
{572, "R_AARCH64_TLSLD_LDST128_DTPREL_LO12"},
{573, "R_AARCH64_TLSLD_LDST128_DTPREL_LO12_NC"},
{1024, "R_AARCH64_COPY"},
{1025, "R_AARCH64_GLOB_DAT"},
{1026, "R_AARCH64_JUMP_SLOT"},
{1027, "R_AARCH64_RELATIVE"},
{1028, "R_AARCH64_TLS_DTPMOD64"},
{1029, "R_AARCH64_TLS_DTPREL64"},
{1030, "R_AARCH64_TLS_TPREL64"},
{1031, "R_AARCH64_TLSDESC"},
{1032, "R_AARCH64_IRELATIVE"},
}
func (i R_AARCH64) String() string { return stringName(uint32(i), raarch64Strings, false) }
func (i R_AARCH64) GoString() string { return stringName(uint32(i), raarch64Strings, true) }
// Relocation types for Alpha.
type R_ALPHA int
const (
R_ALPHA_NONE R_ALPHA = 0 /* No reloc */
R_ALPHA_REFLONG R_ALPHA = 1 /* Direct 32 bit */
R_ALPHA_REFQUAD R_ALPHA = 2 /* Direct 64 bit */
R_ALPHA_GPREL32 R_ALPHA = 3 /* GP relative 32 bit */
R_ALPHA_LITERAL R_ALPHA = 4 /* GP relative 16 bit w/optimization */
R_ALPHA_LITUSE R_ALPHA = 5 /* Optimization hint for LITERAL */
R_ALPHA_GPDISP R_ALPHA = 6 /* Add displacement to GP */
R_ALPHA_BRADDR R_ALPHA = 7 /* PC+4 relative 23 bit shifted */
R_ALPHA_HINT R_ALPHA = 8 /* PC+4 relative 16 bit shifted */
R_ALPHA_SREL16 R_ALPHA = 9 /* PC relative 16 bit */
R_ALPHA_SREL32 R_ALPHA = 10 /* PC relative 32 bit */
R_ALPHA_SREL64 R_ALPHA = 11 /* PC relative 64 bit */
R_ALPHA_OP_PUSH R_ALPHA = 12 /* OP stack push */
R_ALPHA_OP_STORE R_ALPHA = 13 /* OP stack pop and store */
R_ALPHA_OP_PSUB R_ALPHA = 14 /* OP stack subtract */
R_ALPHA_OP_PRSHIFT R_ALPHA = 15 /* OP stack right shift */
R_ALPHA_GPVALUE R_ALPHA = 16
R_ALPHA_GPRELHIGH R_ALPHA = 17
R_ALPHA_GPRELLOW R_ALPHA = 18
R_ALPHA_IMMED_GP_16 R_ALPHA = 19
R_ALPHA_IMMED_GP_HI32 R_ALPHA = 20
R_ALPHA_IMMED_SCN_HI32 R_ALPHA = 21
R_ALPHA_IMMED_BR_HI32 R_ALPHA = 22
R_ALPHA_IMMED_LO32 R_ALPHA = 23
R_ALPHA_COPY R_ALPHA = 24 /* Copy symbol at runtime */
R_ALPHA_GLOB_DAT R_ALPHA = 25 /* Create GOT entry */
R_ALPHA_JMP_SLOT R_ALPHA = 26 /* Create PLT entry */
R_ALPHA_RELATIVE R_ALPHA = 27 /* Adjust by program base */
)
var ralphaStrings = []intName{
{0, "R_ALPHA_NONE"},
{1, "R_ALPHA_REFLONG"},
{2, "R_ALPHA_REFQUAD"},
{3, "R_ALPHA_GPREL32"},
{4, "R_ALPHA_LITERAL"},
{5, "R_ALPHA_LITUSE"},
{6, "R_ALPHA_GPDISP"},
{7, "R_ALPHA_BRADDR"},
{8, "R_ALPHA_HINT"},
{9, "R_ALPHA_SREL16"},
{10, "R_ALPHA_SREL32"},
{11, "R_ALPHA_SREL64"},
{12, "R_ALPHA_OP_PUSH"},
{13, "R_ALPHA_OP_STORE"},
{14, "R_ALPHA_OP_PSUB"},
{15, "R_ALPHA_OP_PRSHIFT"},
{16, "R_ALPHA_GPVALUE"},
{17, "R_ALPHA_GPRELHIGH"},
{18, "R_ALPHA_GPRELLOW"},
{19, "R_ALPHA_IMMED_GP_16"},
{20, "R_ALPHA_IMMED_GP_HI32"},
{21, "R_ALPHA_IMMED_SCN_HI32"},
{22, "R_ALPHA_IMMED_BR_HI32"},
{23, "R_ALPHA_IMMED_LO32"},
{24, "R_ALPHA_COPY"},
{25, "R_ALPHA_GLOB_DAT"},
{26, "R_ALPHA_JMP_SLOT"},
{27, "R_ALPHA_RELATIVE"},
}
func (i R_ALPHA) String() string { return stringName(uint32(i), ralphaStrings, false) }
func (i R_ALPHA) GoString() string { return stringName(uint32(i), ralphaStrings, true) }
// Relocation types for ARM.
type R_ARM int
const (
R_ARM_NONE R_ARM = 0 /* No relocation. */
R_ARM_PC24 R_ARM = 1
R_ARM_ABS32 R_ARM = 2
R_ARM_REL32 R_ARM = 3
R_ARM_PC13 R_ARM = 4
R_ARM_ABS16 R_ARM = 5
R_ARM_ABS12 R_ARM = 6
R_ARM_THM_ABS5 R_ARM = 7
R_ARM_ABS8 R_ARM = 8
R_ARM_SBREL32 R_ARM = 9
R_ARM_THM_PC22 R_ARM = 10
R_ARM_THM_PC8 R_ARM = 11
R_ARM_AMP_VCALL9 R_ARM = 12
R_ARM_SWI24 R_ARM = 13
R_ARM_THM_SWI8 R_ARM = 14
R_ARM_XPC25 R_ARM = 15
R_ARM_THM_XPC22 R_ARM = 16
R_ARM_TLS_DTPMOD32 R_ARM = 17
R_ARM_TLS_DTPOFF32 R_ARM = 18
R_ARM_TLS_TPOFF32 R_ARM = 19
R_ARM_COPY R_ARM = 20 /* Copy data from shared object. */
R_ARM_GLOB_DAT R_ARM = 21 /* Set GOT entry to data address. */
R_ARM_JUMP_SLOT R_ARM = 22 /* Set GOT entry to code address. */
R_ARM_RELATIVE R_ARM = 23 /* Add load address of shared object. */
R_ARM_GOTOFF R_ARM = 24 /* Add GOT-relative symbol address. */
R_ARM_GOTPC R_ARM = 25 /* Add PC-relative GOT table address. */
R_ARM_GOT32 R_ARM = 26 /* Add PC-relative GOT offset. */
R_ARM_PLT32 R_ARM = 27 /* Add PC-relative PLT offset. */
R_ARM_CALL R_ARM = 28
R_ARM_JUMP24 R_ARM = 29
R_ARM_THM_JUMP24 R_ARM = 30
R_ARM_BASE_ABS R_ARM = 31
R_ARM_ALU_PCREL_7_0 R_ARM = 32
R_ARM_ALU_PCREL_15_8 R_ARM = 33
R_ARM_ALU_PCREL_23_15 R_ARM = 34
R_ARM_LDR_SBREL_11_10_NC R_ARM = 35
R_ARM_ALU_SBREL_19_12_NC R_ARM = 36
R_ARM_ALU_SBREL_27_20_CK R_ARM = 37
R_ARM_TARGET1 R_ARM = 38
R_ARM_SBREL31 R_ARM = 39
R_ARM_V4BX R_ARM = 40
R_ARM_TARGET2 R_ARM = 41
R_ARM_PREL31 R_ARM = 42
R_ARM_MOVW_ABS_NC R_ARM = 43
R_ARM_MOVT_ABS R_ARM = 44
R_ARM_MOVW_PREL_NC R_ARM = 45
R_ARM_MOVT_PREL R_ARM = 46
R_ARM_THM_MOVW_ABS_NC R_ARM = 47
R_ARM_THM_MOVT_ABS R_ARM = 48
R_ARM_THM_MOVW_PREL_NC R_ARM = 49
R_ARM_THM_MOVT_PREL R_ARM = 50
R_ARM_THM_JUMP19 R_ARM = 51
R_ARM_THM_JUMP6 R_ARM = 52
R_ARM_THM_ALU_PREL_11_0 R_ARM = 53
R_ARM_THM_PC12 R_ARM = 54
R_ARM_ABS32_NOI R_ARM = 55
R_ARM_REL32_NOI R_ARM = 56
R_ARM_ALU_PC_G0_NC R_ARM = 57
R_ARM_ALU_PC_G0 R_ARM = 58
R_ARM_ALU_PC_G1_NC R_ARM = 59
R_ARM_ALU_PC_G1 R_ARM = 60
R_ARM_ALU_PC_G2 R_ARM = 61
R_ARM_LDR_PC_G1 R_ARM = 62
R_ARM_LDR_PC_G2 R_ARM = 63
R_ARM_LDRS_PC_G0 R_ARM = 64
R_ARM_LDRS_PC_G1 R_ARM = 65
R_ARM_LDRS_PC_G2 R_ARM = 66
R_ARM_LDC_PC_G0 R_ARM = 67
R_ARM_LDC_PC_G1 R_ARM = 68
R_ARM_LDC_PC_G2 R_ARM = 69
R_ARM_ALU_SB_G0_NC R_ARM = 70
R_ARM_ALU_SB_G0 R_ARM = 71
R_ARM_ALU_SB_G1_NC R_ARM = 72
R_ARM_ALU_SB_G1 R_ARM = 73
R_ARM_ALU_SB_G2 R_ARM = 74
R_ARM_LDR_SB_G0 R_ARM = 75
R_ARM_LDR_SB_G1 R_ARM = 76
R_ARM_LDR_SB_G2 R_ARM = 77
R_ARM_LDRS_SB_G0 R_ARM = 78
R_ARM_LDRS_SB_G1 R_ARM = 79
R_ARM_LDRS_SB_G2 R_ARM = 80
R_ARM_LDC_SB_G0 R_ARM = 81
R_ARM_LDC_SB_G1 R_ARM = 82
R_ARM_LDC_SB_G2 R_ARM = 83
R_ARM_MOVW_BREL_NC R_ARM = 84
R_ARM_MOVT_BREL R_ARM = 85
R_ARM_MOVW_BREL R_ARM = 86
R_ARM_THM_MOVW_BREL_NC R_ARM = 87
R_ARM_THM_MOVT_BREL R_ARM = 88
R_ARM_THM_MOVW_BREL R_ARM = 89
R_ARM_TLS_GOTDESC R_ARM = 90
R_ARM_TLS_CALL R_ARM = 91
R_ARM_TLS_DESCSEQ R_ARM = 92
R_ARM_THM_TLS_CALL R_ARM = 93
R_ARM_PLT32_ABS R_ARM = 94
R_ARM_GOT_ABS R_ARM = 95
R_ARM_GOT_PREL R_ARM = 96
R_ARM_GOT_BREL12 R_ARM = 97
R_ARM_GOTOFF12 R_ARM = 98
R_ARM_GOTRELAX R_ARM = 99
R_ARM_GNU_VTENTRY R_ARM = 100
R_ARM_GNU_VTINHERIT R_ARM = 101
R_ARM_THM_JUMP11 R_ARM = 102
R_ARM_THM_JUMP8 R_ARM = 103
R_ARM_TLS_GD32 R_ARM = 104
R_ARM_TLS_LDM32 R_ARM = 105
R_ARM_TLS_LDO32 R_ARM = 106
R_ARM_TLS_IE32 R_ARM = 107
R_ARM_TLS_LE32 R_ARM = 108
R_ARM_TLS_LDO12 R_ARM = 109
R_ARM_TLS_LE12 R_ARM = 110
R_ARM_TLS_IE12GP R_ARM = 111
R_ARM_PRIVATE_0 R_ARM = 112
R_ARM_PRIVATE_1 R_ARM = 113
R_ARM_PRIVATE_2 R_ARM = 114
R_ARM_PRIVATE_3 R_ARM = 115
R_ARM_PRIVATE_4 R_ARM = 116
R_ARM_PRIVATE_5 R_ARM = 117
R_ARM_PRIVATE_6 R_ARM = 118
R_ARM_PRIVATE_7 R_ARM = 119
R_ARM_PRIVATE_8 R_ARM = 120
R_ARM_PRIVATE_9 R_ARM = 121
R_ARM_PRIVATE_10 R_ARM = 122
R_ARM_PRIVATE_11 R_ARM = 123
R_ARM_PRIVATE_12 R_ARM = 124
R_ARM_PRIVATE_13 R_ARM = 125
R_ARM_PRIVATE_14 R_ARM = 126
R_ARM_PRIVATE_15 R_ARM = 127
R_ARM_ME_TOO R_ARM = 128
R_ARM_THM_TLS_DESCSEQ16 R_ARM = 129
R_ARM_THM_TLS_DESCSEQ32 R_ARM = 130
R_ARM_THM_GOT_BREL12 R_ARM = 131
R_ARM_THM_ALU_ABS_G0_NC R_ARM = 132
R_ARM_THM_ALU_ABS_G1_NC R_ARM = 133
R_ARM_THM_ALU_ABS_G2_NC R_ARM = 134
R_ARM_THM_ALU_ABS_G3 R_ARM = 135
R_ARM_IRELATIVE R_ARM = 160
R_ARM_RXPC25 R_ARM = 249
R_ARM_RSBREL32 R_ARM = 250
R_ARM_THM_RPC22 R_ARM = 251
R_ARM_RREL32 R_ARM = 252
R_ARM_RABS32 R_ARM = 253
R_ARM_RPC24 R_ARM = 254
R_ARM_RBASE R_ARM = 255
)
var rarmStrings = []intName{
{0, "R_ARM_NONE"},
{1, "R_ARM_PC24"},
{2, "R_ARM_ABS32"},
{3, "R_ARM_REL32"},
{4, "R_ARM_PC13"},
{5, "R_ARM_ABS16"},
{6, "R_ARM_ABS12"},
{7, "R_ARM_THM_ABS5"},
{8, "R_ARM_ABS8"},
{9, "R_ARM_SBREL32"},
{10, "R_ARM_THM_PC22"},
{11, "R_ARM_THM_PC8"},
{12, "R_ARM_AMP_VCALL9"},
{13, "R_ARM_SWI24"},
{14, "R_ARM_THM_SWI8"},
{15, "R_ARM_XPC25"},
{16, "R_ARM_THM_XPC22"},
{17, "R_ARM_TLS_DTPMOD32"},
{18, "R_ARM_TLS_DTPOFF32"},
{19, "R_ARM_TLS_TPOFF32"},
{20, "R_ARM_COPY"},
{21, "R_ARM_GLOB_DAT"},
{22, "R_ARM_JUMP_SLOT"},
{23, "R_ARM_RELATIVE"},
{24, "R_ARM_GOTOFF"},
{25, "R_ARM_GOTPC"},
{26, "R_ARM_GOT32"},
{27, "R_ARM_PLT32"},
{28, "R_ARM_CALL"},
{29, "R_ARM_JUMP24"},
{30, "R_ARM_THM_JUMP24"},
{31, "R_ARM_BASE_ABS"},
{32, "R_ARM_ALU_PCREL_7_0"},
{33, "R_ARM_ALU_PCREL_15_8"},
{34, "R_ARM_ALU_PCREL_23_15"},
{35, "R_ARM_LDR_SBREL_11_10_NC"},
{36, "R_ARM_ALU_SBREL_19_12_NC"},
{37, "R_ARM_ALU_SBREL_27_20_CK"},
{38, "R_ARM_TARGET1"},
{39, "R_ARM_SBREL31"},
{40, "R_ARM_V4BX"},
{41, "R_ARM_TARGET2"},
{42, "R_ARM_PREL31"},
{43, "R_ARM_MOVW_ABS_NC"},
{44, "R_ARM_MOVT_ABS"},
{45, "R_ARM_MOVW_PREL_NC"},
{46, "R_ARM_MOVT_PREL"},
{47, "R_ARM_THM_MOVW_ABS_NC"},
{48, "R_ARM_THM_MOVT_ABS"},
{49, "R_ARM_THM_MOVW_PREL_NC"},
{50, "R_ARM_THM_MOVT_PREL"},
{51, "R_ARM_THM_JUMP19"},
{52, "R_ARM_THM_JUMP6"},
{53, "R_ARM_THM_ALU_PREL_11_0"},
{54, "R_ARM_THM_PC12"},
{55, "R_ARM_ABS32_NOI"},
{56, "R_ARM_REL32_NOI"},
{57, "R_ARM_ALU_PC_G0_NC"},
{58, "R_ARM_ALU_PC_G0"},
{59, "R_ARM_ALU_PC_G1_NC"},
{60, "R_ARM_ALU_PC_G1"},
{61, "R_ARM_ALU_PC_G2"},
{62, "R_ARM_LDR_PC_G1"},
{63, "R_ARM_LDR_PC_G2"},
{64, "R_ARM_LDRS_PC_G0"},
{65, "R_ARM_LDRS_PC_G1"},
{66, "R_ARM_LDRS_PC_G2"},
{67, "R_ARM_LDC_PC_G0"},
{68, "R_ARM_LDC_PC_G1"},
{69, "R_ARM_LDC_PC_G2"},
{70, "R_ARM_ALU_SB_G0_NC"},
{71, "R_ARM_ALU_SB_G0"},
{72, "R_ARM_ALU_SB_G1_NC"},
{73, "R_ARM_ALU_SB_G1"},
{74, "R_ARM_ALU_SB_G2"},
{75, "R_ARM_LDR_SB_G0"},
{76, "R_ARM_LDR_SB_G1"},
{77, "R_ARM_LDR_SB_G2"},
{78, "R_ARM_LDRS_SB_G0"},
{79, "R_ARM_LDRS_SB_G1"},
{80, "R_ARM_LDRS_SB_G2"},
{81, "R_ARM_LDC_SB_G0"},
{82, "R_ARM_LDC_SB_G1"},
{83, "R_ARM_LDC_SB_G2"},
{84, "R_ARM_MOVW_BREL_NC"},
{85, "R_ARM_MOVT_BREL"},
{86, "R_ARM_MOVW_BREL"},
{87, "R_ARM_THM_MOVW_BREL_NC"},
{88, "R_ARM_THM_MOVT_BREL"},
{89, "R_ARM_THM_MOVW_BREL"},
{90, "R_ARM_TLS_GOTDESC"},
{91, "R_ARM_TLS_CALL"},
{92, "R_ARM_TLS_DESCSEQ"},
{93, "R_ARM_THM_TLS_CALL"},
{94, "R_ARM_PLT32_ABS"},
{95, "R_ARM_GOT_ABS"},
{96, "R_ARM_GOT_PREL"},
{97, "R_ARM_GOT_BREL12"},
{98, "R_ARM_GOTOFF12"},
{99, "R_ARM_GOTRELAX"},
{100, "R_ARM_GNU_VTENTRY"},
{101, "R_ARM_GNU_VTINHERIT"},
{102, "R_ARM_THM_JUMP11"},
{103, "R_ARM_THM_JUMP8"},
{104, "R_ARM_TLS_GD32"},
{105, "R_ARM_TLS_LDM32"},
{106, "R_ARM_TLS_LDO32"},
{107, "R_ARM_TLS_IE32"},
{108, "R_ARM_TLS_LE32"},
{109, "R_ARM_TLS_LDO12"},
{110, "R_ARM_TLS_LE12"},
{111, "R_ARM_TLS_IE12GP"},
{112, "R_ARM_PRIVATE_0"},
{113, "R_ARM_PRIVATE_1"},
{114, "R_ARM_PRIVATE_2"},
{115, "R_ARM_PRIVATE_3"},
{116, "R_ARM_PRIVATE_4"},
{117, "R_ARM_PRIVATE_5"},
{118, "R_ARM_PRIVATE_6"},
{119, "R_ARM_PRIVATE_7"},
{120, "R_ARM_PRIVATE_8"},
{121, "R_ARM_PRIVATE_9"},
{122, "R_ARM_PRIVATE_10"},
{123, "R_ARM_PRIVATE_11"},
{124, "R_ARM_PRIVATE_12"},
{125, "R_ARM_PRIVATE_13"},
{126, "R_ARM_PRIVATE_14"},
{127, "R_ARM_PRIVATE_15"},
{128, "R_ARM_ME_TOO"},
{129, "R_ARM_THM_TLS_DESCSEQ16"},
{130, "R_ARM_THM_TLS_DESCSEQ32"},
{131, "R_ARM_THM_GOT_BREL12"},
{132, "R_ARM_THM_ALU_ABS_G0_NC"},
{133, "R_ARM_THM_ALU_ABS_G1_NC"},
{134, "R_ARM_THM_ALU_ABS_G2_NC"},
{135, "R_ARM_THM_ALU_ABS_G3"},
{160, "R_ARM_IRELATIVE"},
{249, "R_ARM_RXPC25"},
{250, "R_ARM_RSBREL32"},
{251, "R_ARM_THM_RPC22"},
{252, "R_ARM_RREL32"},
{253, "R_ARM_RABS32"},
{254, "R_ARM_RPC24"},
{255, "R_ARM_RBASE"},
}
func (i R_ARM) String() string { return stringName(uint32(i), rarmStrings, false) }
func (i R_ARM) GoString() string { return stringName(uint32(i), rarmStrings, true) }
// Relocation types for 386.
type R_386 int
const (
R_386_NONE R_386 = 0 /* No relocation. */
R_386_32 R_386 = 1 /* Add symbol value. */
R_386_PC32 R_386 = 2 /* Add PC-relative symbol value. */
R_386_GOT32 R_386 = 3 /* Add PC-relative GOT offset. */
R_386_PLT32 R_386 = 4 /* Add PC-relative PLT offset. */
R_386_COPY R_386 = 5 /* Copy data from shared object. */
R_386_GLOB_DAT R_386 = 6 /* Set GOT entry to data address. */
R_386_JMP_SLOT R_386 = 7 /* Set GOT entry to code address. */
R_386_RELATIVE R_386 = 8 /* Add load address of shared object. */
R_386_GOTOFF R_386 = 9 /* Add GOT-relative symbol address. */
R_386_GOTPC R_386 = 10 /* Add PC-relative GOT table address. */
R_386_32PLT R_386 = 11
R_386_TLS_TPOFF R_386 = 14 /* Negative offset in static TLS block */
R_386_TLS_IE R_386 = 15 /* Absolute address of GOT for -ve static TLS */
R_386_TLS_GOTIE R_386 = 16 /* GOT entry for negative static TLS block */
R_386_TLS_LE R_386 = 17 /* Negative offset relative to static TLS */
R_386_TLS_GD R_386 = 18 /* 32 bit offset to GOT (index,off) pair */
R_386_TLS_LDM R_386 = 19 /* 32 bit offset to GOT (index,zero) pair */
R_386_16 R_386 = 20
R_386_PC16 R_386 = 21
R_386_8 R_386 = 22
R_386_PC8 R_386 = 23
R_386_TLS_GD_32 R_386 = 24 /* 32 bit offset to GOT (index,off) pair */
R_386_TLS_GD_PUSH R_386 = 25 /* pushl instruction for Sun ABI GD sequence */
R_386_TLS_GD_CALL R_386 = 26 /* call instruction for Sun ABI GD sequence */
R_386_TLS_GD_POP R_386 = 27 /* popl instruction for Sun ABI GD sequence */
R_386_TLS_LDM_32 R_386 = 28 /* 32 bit offset to GOT (index,zero) pair */
R_386_TLS_LDM_PUSH R_386 = 29 /* pushl instruction for Sun ABI LD sequence */
R_386_TLS_LDM_CALL R_386 = 30 /* call instruction for Sun ABI LD sequence */
R_386_TLS_LDM_POP R_386 = 31 /* popl instruction for Sun ABI LD sequence */
R_386_TLS_LDO_32 R_386 = 32 /* 32 bit offset from start of TLS block */
R_386_TLS_IE_32 R_386 = 33 /* 32 bit offset to GOT static TLS offset entry */
R_386_TLS_LE_32 R_386 = 34 /* 32 bit offset within static TLS block */
R_386_TLS_DTPMOD32 R_386 = 35 /* GOT entry containing TLS index */
R_386_TLS_DTPOFF32 R_386 = 36 /* GOT entry containing TLS offset */
R_386_TLS_TPOFF32 R_386 = 37 /* GOT entry of -ve static TLS offset */
R_386_SIZE32 R_386 = 38
R_386_TLS_GOTDESC R_386 = 39
R_386_TLS_DESC_CALL R_386 = 40
R_386_TLS_DESC R_386 = 41
R_386_IRELATIVE R_386 = 42
R_386_GOT32X R_386 = 43
)
var r386Strings = []intName{
{0, "R_386_NONE"},
{1, "R_386_32"},
{2, "R_386_PC32"},
{3, "R_386_GOT32"},
{4, "R_386_PLT32"},
{5, "R_386_COPY"},
{6, "R_386_GLOB_DAT"},
{7, "R_386_JMP_SLOT"},
{8, "R_386_RELATIVE"},
{9, "R_386_GOTOFF"},
{10, "R_386_GOTPC"},
{11, "R_386_32PLT"},
{14, "R_386_TLS_TPOFF"},
{15, "R_386_TLS_IE"},
{16, "R_386_TLS_GOTIE"},
{17, "R_386_TLS_LE"},
{18, "R_386_TLS_GD"},
{19, "R_386_TLS_LDM"},
{20, "R_386_16"},
{21, "R_386_PC16"},
{22, "R_386_8"},
{23, "R_386_PC8"},
{24, "R_386_TLS_GD_32"},
{25, "R_386_TLS_GD_PUSH"},
{26, "R_386_TLS_GD_CALL"},
{27, "R_386_TLS_GD_POP"},
{28, "R_386_TLS_LDM_32"},
{29, "R_386_TLS_LDM_PUSH"},
{30, "R_386_TLS_LDM_CALL"},
{31, "R_386_TLS_LDM_POP"},
{32, "R_386_TLS_LDO_32"},
{33, "R_386_TLS_IE_32"},
{34, "R_386_TLS_LE_32"},
{35, "R_386_TLS_DTPMOD32"},
{36, "R_386_TLS_DTPOFF32"},
{37, "R_386_TLS_TPOFF32"},
{38, "R_386_SIZE32"},
{39, "R_386_TLS_GOTDESC"},
{40, "R_386_TLS_DESC_CALL"},
{41, "R_386_TLS_DESC"},
{42, "R_386_IRELATIVE"},
{43, "R_386_GOT32X"},
}
func (i R_386) String() string { return stringName(uint32(i), r386Strings, false) }
func (i R_386) GoString() string { return stringName(uint32(i), r386Strings, true) }
// Relocation types for MIPS.
type R_MIPS int
const (
R_MIPS_NONE R_MIPS = 0
R_MIPS_16 R_MIPS = 1
R_MIPS_32 R_MIPS = 2
R_MIPS_REL32 R_MIPS = 3
R_MIPS_26 R_MIPS = 4
R_MIPS_HI16 R_MIPS = 5 /* high 16 bits of symbol value */
R_MIPS_LO16 R_MIPS = 6 /* low 16 bits of symbol value */
R_MIPS_GPREL16 R_MIPS = 7 /* GP-relative reference */
R_MIPS_LITERAL R_MIPS = 8 /* Reference to literal section */
R_MIPS_GOT16 R_MIPS = 9 /* Reference to global offset table */
R_MIPS_PC16 R_MIPS = 10 /* 16 bit PC relative reference */
R_MIPS_CALL16 R_MIPS = 11 /* 16 bit call through glbl offset tbl */
R_MIPS_GPREL32 R_MIPS = 12
R_MIPS_SHIFT5 R_MIPS = 16
R_MIPS_SHIFT6 R_MIPS = 17
R_MIPS_64 R_MIPS = 18
R_MIPS_GOT_DISP R_MIPS = 19
R_MIPS_GOT_PAGE R_MIPS = 20
R_MIPS_GOT_OFST R_MIPS = 21
R_MIPS_GOT_HI16 R_MIPS = 22
R_MIPS_GOT_LO16 R_MIPS = 23
R_MIPS_SUB R_MIPS = 24
R_MIPS_INSERT_A R_MIPS = 25
R_MIPS_INSERT_B R_MIPS = 26
R_MIPS_DELETE R_MIPS = 27
R_MIPS_HIGHER R_MIPS = 28
R_MIPS_HIGHEST R_MIPS = 29
R_MIPS_CALL_HI16 R_MIPS = 30
R_MIPS_CALL_LO16 R_MIPS = 31
R_MIPS_SCN_DISP R_MIPS = 32
R_MIPS_REL16 R_MIPS = 33
R_MIPS_ADD_IMMEDIATE R_MIPS = 34
R_MIPS_PJUMP R_MIPS = 35
R_MIPS_RELGOT R_MIPS = 36
R_MIPS_JALR R_MIPS = 37
R_MIPS_TLS_DTPMOD32 R_MIPS = 38 /* Module number 32 bit */
R_MIPS_TLS_DTPREL32 R_MIPS = 39 /* Module-relative offset 32 bit */
R_MIPS_TLS_DTPMOD64 R_MIPS = 40 /* Module number 64 bit */
R_MIPS_TLS_DTPREL64 R_MIPS = 41 /* Module-relative offset 64 bit */
R_MIPS_TLS_GD R_MIPS = 42 /* 16 bit GOT offset for GD */
R_MIPS_TLS_LDM R_MIPS = 43 /* 16 bit GOT offset for LDM */
R_MIPS_TLS_DTPREL_HI16 R_MIPS = 44 /* Module-relative offset, high 16 bits */
R_MIPS_TLS_DTPREL_LO16 R_MIPS = 45 /* Module-relative offset, low 16 bits */
R_MIPS_TLS_GOTTPREL R_MIPS = 46 /* 16 bit GOT offset for IE */
R_MIPS_TLS_TPREL32 R_MIPS = 47 /* TP-relative offset, 32 bit */
R_MIPS_TLS_TPREL64 R_MIPS = 48 /* TP-relative offset, 64 bit */
R_MIPS_TLS_TPREL_HI16 R_MIPS = 49 /* TP-relative offset, high 16 bits */
R_MIPS_TLS_TPREL_LO16 R_MIPS = 50 /* TP-relative offset, low 16 bits */
)
var rmipsStrings = []intName{
{0, "R_MIPS_NONE"},
{1, "R_MIPS_16"},
{2, "R_MIPS_32"},
{3, "R_MIPS_REL32"},
{4, "R_MIPS_26"},
{5, "R_MIPS_HI16"},
{6, "R_MIPS_LO16"},
{7, "R_MIPS_GPREL16"},
{8, "R_MIPS_LITERAL"},
{9, "R_MIPS_GOT16"},
{10, "R_MIPS_PC16"},
{11, "R_MIPS_CALL16"},
{12, "R_MIPS_GPREL32"},
{16, "R_MIPS_SHIFT5"},
{17, "R_MIPS_SHIFT6"},
{18, "R_MIPS_64"},
{19, "R_MIPS_GOT_DISP"},
{20, "R_MIPS_GOT_PAGE"},
{21, "R_MIPS_GOT_OFST"},
{22, "R_MIPS_GOT_HI16"},
{23, "R_MIPS_GOT_LO16"},
{24, "R_MIPS_SUB"},
{25, "R_MIPS_INSERT_A"},
{26, "R_MIPS_INSERT_B"},
{27, "R_MIPS_DELETE"},
{28, "R_MIPS_HIGHER"},
{29, "R_MIPS_HIGHEST"},
{30, "R_MIPS_CALL_HI16"},
{31, "R_MIPS_CALL_LO16"},
{32, "R_MIPS_SCN_DISP"},
{33, "R_MIPS_REL16"},
{34, "R_MIPS_ADD_IMMEDIATE"},
{35, "R_MIPS_PJUMP"},
{36, "R_MIPS_RELGOT"},
{37, "R_MIPS_JALR"},
{38, "R_MIPS_TLS_DTPMOD32"},
{39, "R_MIPS_TLS_DTPREL32"},
{40, "R_MIPS_TLS_DTPMOD64"},
{41, "R_MIPS_TLS_DTPREL64"},
{42, "R_MIPS_TLS_GD"},
{43, "R_MIPS_TLS_LDM"},
{44, "R_MIPS_TLS_DTPREL_HI16"},
{45, "R_MIPS_TLS_DTPREL_LO16"},
{46, "R_MIPS_TLS_GOTTPREL"},
{47, "R_MIPS_TLS_TPREL32"},
{48, "R_MIPS_TLS_TPREL64"},
{49, "R_MIPS_TLS_TPREL_HI16"},
{50, "R_MIPS_TLS_TPREL_LO16"},
}
func (i R_MIPS) String() string { return stringName(uint32(i), rmipsStrings, false) }
func (i R_MIPS) GoString() string { return stringName(uint32(i), rmipsStrings, true) }
// Relocation types for LoongArch.
type R_LARCH int
const (
R_LARCH_NONE R_LARCH = 0
R_LARCH_32 R_LARCH = 1
R_LARCH_64 R_LARCH = 2
R_LARCH_RELATIVE R_LARCH = 3
R_LARCH_COPY R_LARCH = 4
R_LARCH_JUMP_SLOT R_LARCH = 5
R_LARCH_TLS_DTPMOD32 R_LARCH = 6
R_LARCH_TLS_DTPMOD64 R_LARCH = 7
R_LARCH_TLS_DTPREL32 R_LARCH = 8
R_LARCH_TLS_DTPREL64 R_LARCH = 9
R_LARCH_TLS_TPREL32 R_LARCH = 10
R_LARCH_TLS_TPREL64 R_LARCH = 11
R_LARCH_IRELATIVE R_LARCH = 12
R_LARCH_MARK_LA R_LARCH = 20
R_LARCH_MARK_PCREL R_LARCH = 21
R_LARCH_SOP_PUSH_PCREL R_LARCH = 22
R_LARCH_SOP_PUSH_ABSOLUTE R_LARCH = 23
R_LARCH_SOP_PUSH_DUP R_LARCH = 24
R_LARCH_SOP_PUSH_GPREL R_LARCH = 25
R_LARCH_SOP_PUSH_TLS_TPREL R_LARCH = 26
R_LARCH_SOP_PUSH_TLS_GOT R_LARCH = 27
R_LARCH_SOP_PUSH_TLS_GD R_LARCH = 28
R_LARCH_SOP_PUSH_PLT_PCREL R_LARCH = 29
R_LARCH_SOP_ASSERT R_LARCH = 30
R_LARCH_SOP_NOT R_LARCH = 31
R_LARCH_SOP_SUB R_LARCH = 32
R_LARCH_SOP_SL R_LARCH = 33
R_LARCH_SOP_SR R_LARCH = 34
R_LARCH_SOP_ADD R_LARCH = 35
R_LARCH_SOP_AND R_LARCH = 36
R_LARCH_SOP_IF_ELSE R_LARCH = 37
R_LARCH_SOP_POP_32_S_10_5 R_LARCH = 38
R_LARCH_SOP_POP_32_U_10_12 R_LARCH = 39
R_LARCH_SOP_POP_32_S_10_12 R_LARCH = 40
R_LARCH_SOP_POP_32_S_10_16 R_LARCH = 41
R_LARCH_SOP_POP_32_S_10_16_S2 R_LARCH = 42
R_LARCH_SOP_POP_32_S_5_20 R_LARCH = 43
R_LARCH_SOP_POP_32_S_0_5_10_16_S2 R_LARCH = 44
R_LARCH_SOP_POP_32_S_0_10_10_16_S2 R_LARCH = 45
R_LARCH_SOP_POP_32_U R_LARCH = 46
R_LARCH_ADD8 R_LARCH = 47
R_LARCH_ADD16 R_LARCH = 48
R_LARCH_ADD24 R_LARCH = 49
R_LARCH_ADD32 R_LARCH = 50
R_LARCH_ADD64 R_LARCH = 51
R_LARCH_SUB8 R_LARCH = 52
R_LARCH_SUB16 R_LARCH = 53
R_LARCH_SUB24 R_LARCH = 54
R_LARCH_SUB32 R_LARCH = 55
R_LARCH_SUB64 R_LARCH = 56
R_LARCH_GNU_VTINHERIT R_LARCH = 57
R_LARCH_GNU_VTENTRY R_LARCH = 58
R_LARCH_B16 R_LARCH = 64
R_LARCH_B21 R_LARCH = 65
R_LARCH_B26 R_LARCH = 66
R_LARCH_ABS_HI20 R_LARCH = 67
R_LARCH_ABS_LO12 R_LARCH = 68
R_LARCH_ABS64_LO20 R_LARCH = 69
R_LARCH_ABS64_HI12 R_LARCH = 70
R_LARCH_PCALA_HI20 R_LARCH = 71
R_LARCH_PCALA_LO12 R_LARCH = 72
R_LARCH_PCALA64_LO20 R_LARCH = 73
R_LARCH_PCALA64_HI12 R_LARCH = 74
R_LARCH_GOT_PC_HI20 R_LARCH = 75
R_LARCH_GOT_PC_LO12 R_LARCH = 76
R_LARCH_GOT64_PC_LO20 R_LARCH = 77
R_LARCH_GOT64_PC_HI12 R_LARCH = 78
R_LARCH_GOT_HI20 R_LARCH = 79
R_LARCH_GOT_LO12 R_LARCH = 80
R_LARCH_GOT64_LO20 R_LARCH = 81
R_LARCH_GOT64_HI12 R_LARCH = 82
R_LARCH_TLS_LE_HI20 R_LARCH = 83
R_LARCH_TLS_LE_LO12 R_LARCH = 84
R_LARCH_TLS_LE64_LO20 R_LARCH = 85
R_LARCH_TLS_LE64_HI12 R_LARCH = 86
R_LARCH_TLS_IE_PC_HI20 R_LARCH = 87
R_LARCH_TLS_IE_PC_LO12 R_LARCH = 88
R_LARCH_TLS_IE64_PC_LO20 R_LARCH = 89
R_LARCH_TLS_IE64_PC_HI12 R_LARCH = 90
R_LARCH_TLS_IE_HI20 R_LARCH = 91
R_LARCH_TLS_IE_LO12 R_LARCH = 92
R_LARCH_TLS_IE64_LO20 R_LARCH = 93
R_LARCH_TLS_IE64_HI12 R_LARCH = 94
R_LARCH_TLS_LD_PC_HI20 R_LARCH = 95
R_LARCH_TLS_LD_HI20 R_LARCH = 96
R_LARCH_TLS_GD_PC_HI20 R_LARCH = 97
R_LARCH_TLS_GD_HI20 R_LARCH = 98
R_LARCH_32_PCREL R_LARCH = 99
R_LARCH_RELAX R_LARCH = 100
)
var rlarchStrings = []intName{
{0, "R_LARCH_NONE"},
{1, "R_LARCH_32"},
{2, "R_LARCH_64"},
{3, "R_LARCH_RELATIVE"},
{4, "R_LARCH_COPY"},
{5, "R_LARCH_JUMP_SLOT"},
{6, "R_LARCH_TLS_DTPMOD32"},
{7, "R_LARCH_TLS_DTPMOD64"},
{8, "R_LARCH_TLS_DTPREL32"},
{9, "R_LARCH_TLS_DTPREL64"},
{10, "R_LARCH_TLS_TPREL32"},
{11, "R_LARCH_TLS_TPREL64"},
{12, "R_LARCH_IRELATIVE"},
{20, "R_LARCH_MARK_LA"},
{21, "R_LARCH_MARK_PCREL"},
{22, "R_LARCH_SOP_PUSH_PCREL"},
{23, "R_LARCH_SOP_PUSH_ABSOLUTE"},
{24, "R_LARCH_SOP_PUSH_DUP"},
{25, "R_LARCH_SOP_PUSH_GPREL"},
{26, "R_LARCH_SOP_PUSH_TLS_TPREL"},
{27, "R_LARCH_SOP_PUSH_TLS_GOT"},
{28, "R_LARCH_SOP_PUSH_TLS_GD"},
{29, "R_LARCH_SOP_PUSH_PLT_PCREL"},
{30, "R_LARCH_SOP_ASSERT"},
{31, "R_LARCH_SOP_NOT"},
{32, "R_LARCH_SOP_SUB"},
{33, "R_LARCH_SOP_SL"},
{34, "R_LARCH_SOP_SR"},
{35, "R_LARCH_SOP_ADD"},
{36, "R_LARCH_SOP_AND"},
{37, "R_LARCH_SOP_IF_ELSE"},
{38, "R_LARCH_SOP_POP_32_S_10_5"},
{39, "R_LARCH_SOP_POP_32_U_10_12"},
{40, "R_LARCH_SOP_POP_32_S_10_12"},
{41, "R_LARCH_SOP_POP_32_S_10_16"},
{42, "R_LARCH_SOP_POP_32_S_10_16_S2"},
{43, "R_LARCH_SOP_POP_32_S_5_20"},
{44, "R_LARCH_SOP_POP_32_S_0_5_10_16_S2"},
{45, "R_LARCH_SOP_POP_32_S_0_10_10_16_S2"},
{46, "R_LARCH_SOP_POP_32_U"},
{47, "R_LARCH_ADD8"},
{48, "R_LARCH_ADD16"},
{49, "R_LARCH_ADD24"},
{50, "R_LARCH_ADD32"},
{51, "R_LARCH_ADD64"},
{52, "R_LARCH_SUB8"},
{53, "R_LARCH_SUB16"},
{54, "R_LARCH_SUB24"},
{55, "R_LARCH_SUB32"},
{56, "R_LARCH_SUB64"},
{57, "R_LARCH_GNU_VTINHERIT"},
{58, "R_LARCH_GNU_VTENTRY"},
{64, "R_LARCH_B16"},
{65, "R_LARCH_B21"},
{66, "R_LARCH_B26"},
{67, "R_LARCH_ABS_HI20"},
{68, "R_LARCH_ABS_LO12"},
{69, "R_LARCH_ABS64_LO20"},
{70, "R_LARCH_ABS64_HI12"},
{71, "R_LARCH_PCALA_HI20"},
{72, "R_LARCH_PCALA_LO12"},
{73, "R_LARCH_PCALA64_LO20"},
{74, "R_LARCH_PCALA64_HI12"},
{75, "R_LARCH_GOT_PC_HI20"},
{76, "R_LARCH_GOT_PC_LO12"},
{77, "R_LARCH_GOT64_PC_LO20"},
{78, "R_LARCH_GOT64_PC_HI12"},
{79, "R_LARCH_GOT_HI20"},
{80, "R_LARCH_GOT_LO12"},
{81, "R_LARCH_GOT64_LO20"},
{82, "R_LARCH_GOT64_HI12"},
{83, "R_LARCH_TLS_LE_HI20"},
{84, "R_LARCH_TLS_LE_LO12"},
{85, "R_LARCH_TLS_LE64_LO20"},
{86, "R_LARCH_TLS_LE64_HI12"},
{87, "R_LARCH_TLS_IE_PC_HI20"},
{88, "R_LARCH_TLS_IE_PC_LO12"},
{89, "R_LARCH_TLS_IE64_PC_LO20"},
{90, "R_LARCH_TLS_IE64_PC_HI12"},
{91, "R_LARCH_TLS_IE_HI20"},
{92, "R_LARCH_TLS_IE_LO12"},
{93, "R_LARCH_TLS_IE64_LO20"},
{94, "R_LARCH_TLS_IE64_HI12"},
{95, "R_LARCH_TLS_LD_PC_HI20"},
{96, "R_LARCH_TLS_LD_HI20"},
{97, "R_LARCH_TLS_GD_PC_HI20"},
{98, "R_LARCH_TLS_GD_HI20"},
{99, "R_LARCH_32_PCREL"},
{100, "R_LARCH_RELAX"},
}
func (i R_LARCH) String() string { return stringName(uint32(i), rlarchStrings, false) }
func (i R_LARCH) GoString() string { return stringName(uint32(i), rlarchStrings, true) }
// Relocation types for PowerPC.
//
// Values that are shared by both R_PPC and R_PPC64 are prefixed with
// R_POWERPC_ in the ELF standard. For the R_PPC type, the relevant
// shared relocations have been renamed with the prefix R_PPC_.
// The original name follows the value in a comment.
type R_PPC int
const (
R_PPC_NONE R_PPC = 0 // R_POWERPC_NONE
R_PPC_ADDR32 R_PPC = 1 // R_POWERPC_ADDR32
R_PPC_ADDR24 R_PPC = 2 // R_POWERPC_ADDR24
R_PPC_ADDR16 R_PPC = 3 // R_POWERPC_ADDR16
R_PPC_ADDR16_LO R_PPC = 4 // R_POWERPC_ADDR16_LO
R_PPC_ADDR16_HI R_PPC = 5 // R_POWERPC_ADDR16_HI
R_PPC_ADDR16_HA R_PPC = 6 // R_POWERPC_ADDR16_HA
R_PPC_ADDR14 R_PPC = 7 // R_POWERPC_ADDR14
R_PPC_ADDR14_BRTAKEN R_PPC = 8 // R_POWERPC_ADDR14_BRTAKEN
R_PPC_ADDR14_BRNTAKEN R_PPC = 9 // R_POWERPC_ADDR14_BRNTAKEN
R_PPC_REL24 R_PPC = 10 // R_POWERPC_REL24
R_PPC_REL14 R_PPC = 11 // R_POWERPC_REL14
R_PPC_REL14_BRTAKEN R_PPC = 12 // R_POWERPC_REL14_BRTAKEN
R_PPC_REL14_BRNTAKEN R_PPC = 13 // R_POWERPC_REL14_BRNTAKEN
R_PPC_GOT16 R_PPC = 14 // R_POWERPC_GOT16
R_PPC_GOT16_LO R_PPC = 15 // R_POWERPC_GOT16_LO
R_PPC_GOT16_HI R_PPC = 16 // R_POWERPC_GOT16_HI
R_PPC_GOT16_HA R_PPC = 17 // R_POWERPC_GOT16_HA
R_PPC_PLTREL24 R_PPC = 18
R_PPC_COPY R_PPC = 19 // R_POWERPC_COPY
R_PPC_GLOB_DAT R_PPC = 20 // R_POWERPC_GLOB_DAT
R_PPC_JMP_SLOT R_PPC = 21 // R_POWERPC_JMP_SLOT
R_PPC_RELATIVE R_PPC = 22 // R_POWERPC_RELATIVE
R_PPC_LOCAL24PC R_PPC = 23
R_PPC_UADDR32 R_PPC = 24 // R_POWERPC_UADDR32
R_PPC_UADDR16 R_PPC = 25 // R_POWERPC_UADDR16
R_PPC_REL32 R_PPC = 26 // R_POWERPC_REL32
R_PPC_PLT32 R_PPC = 27 // R_POWERPC_PLT32
R_PPC_PLTREL32 R_PPC = 28 // R_POWERPC_PLTREL32
R_PPC_PLT16_LO R_PPC = 29 // R_POWERPC_PLT16_LO
R_PPC_PLT16_HI R_PPC = 30 // R_POWERPC_PLT16_HI
R_PPC_PLT16_HA R_PPC = 31 // R_POWERPC_PLT16_HA
R_PPC_SDAREL16 R_PPC = 32
R_PPC_SECTOFF R_PPC = 33 // R_POWERPC_SECTOFF
R_PPC_SECTOFF_LO R_PPC = 34 // R_POWERPC_SECTOFF_LO
R_PPC_SECTOFF_HI R_PPC = 35 // R_POWERPC_SECTOFF_HI
R_PPC_SECTOFF_HA R_PPC = 36 // R_POWERPC_SECTOFF_HA
R_PPC_TLS R_PPC = 67 // R_POWERPC_TLS
R_PPC_DTPMOD32 R_PPC = 68 // R_POWERPC_DTPMOD32
R_PPC_TPREL16 R_PPC = 69 // R_POWERPC_TPREL16
R_PPC_TPREL16_LO R_PPC = 70 // R_POWERPC_TPREL16_LO
R_PPC_TPREL16_HI R_PPC = 71 // R_POWERPC_TPREL16_HI
R_PPC_TPREL16_HA R_PPC = 72 // R_POWERPC_TPREL16_HA
R_PPC_TPREL32 R_PPC = 73 // R_POWERPC_TPREL32
R_PPC_DTPREL16 R_PPC = 74 // R_POWERPC_DTPREL16
R_PPC_DTPREL16_LO R_PPC = 75 // R_POWERPC_DTPREL16_LO
R_PPC_DTPREL16_HI R_PPC = 76 // R_POWERPC_DTPREL16_HI
R_PPC_DTPREL16_HA R_PPC = 77 // R_POWERPC_DTPREL16_HA
R_PPC_DTPREL32 R_PPC = 78 // R_POWERPC_DTPREL32
R_PPC_GOT_TLSGD16 R_PPC = 79 // R_POWERPC_GOT_TLSGD16
R_PPC_GOT_TLSGD16_LO R_PPC = 80 // R_POWERPC_GOT_TLSGD16_LO
R_PPC_GOT_TLSGD16_HI R_PPC = 81 // R_POWERPC_GOT_TLSGD16_HI
R_PPC_GOT_TLSGD16_HA R_PPC = 82 // R_POWERPC_GOT_TLSGD16_HA
R_PPC_GOT_TLSLD16 R_PPC = 83 // R_POWERPC_GOT_TLSLD16
R_PPC_GOT_TLSLD16_LO R_PPC = 84 // R_POWERPC_GOT_TLSLD16_LO
R_PPC_GOT_TLSLD16_HI R_PPC = 85 // R_POWERPC_GOT_TLSLD16_HI
R_PPC_GOT_TLSLD16_HA R_PPC = 86 // R_POWERPC_GOT_TLSLD16_HA
R_PPC_GOT_TPREL16 R_PPC = 87 // R_POWERPC_GOT_TPREL16
R_PPC_GOT_TPREL16_LO R_PPC = 88 // R_POWERPC_GOT_TPREL16_LO
R_PPC_GOT_TPREL16_HI R_PPC = 89 // R_POWERPC_GOT_TPREL16_HI
R_PPC_GOT_TPREL16_HA R_PPC = 90 // R_POWERPC_GOT_TPREL16_HA
R_PPC_EMB_NADDR32 R_PPC = 101
R_PPC_EMB_NADDR16 R_PPC = 102
R_PPC_EMB_NADDR16_LO R_PPC = 103
R_PPC_EMB_NADDR16_HI R_PPC = 104
R_PPC_EMB_NADDR16_HA R_PPC = 105
R_PPC_EMB_SDAI16 R_PPC = 106
R_PPC_EMB_SDA2I16 R_PPC = 107
R_PPC_EMB_SDA2REL R_PPC = 108
R_PPC_EMB_SDA21 R_PPC = 109
R_PPC_EMB_MRKREF R_PPC = 110
R_PPC_EMB_RELSEC16 R_PPC = 111
R_PPC_EMB_RELST_LO R_PPC = 112
R_PPC_EMB_RELST_HI R_PPC = 113
R_PPC_EMB_RELST_HA R_PPC = 114
R_PPC_EMB_BIT_FLD R_PPC = 115
R_PPC_EMB_RELSDA R_PPC = 116
)
var rppcStrings = []intName{
{0, "R_PPC_NONE"},
{1, "R_PPC_ADDR32"},
{2, "R_PPC_ADDR24"},
{3, "R_PPC_ADDR16"},
{4, "R_PPC_ADDR16_LO"},
{5, "R_PPC_ADDR16_HI"},
{6, "R_PPC_ADDR16_HA"},
{7, "R_PPC_ADDR14"},
{8, "R_PPC_ADDR14_BRTAKEN"},
{9, "R_PPC_ADDR14_BRNTAKEN"},
{10, "R_PPC_REL24"},
{11, "R_PPC_REL14"},
{12, "R_PPC_REL14_BRTAKEN"},
{13, "R_PPC_REL14_BRNTAKEN"},
{14, "R_PPC_GOT16"},
{15, "R_PPC_GOT16_LO"},
{16, "R_PPC_GOT16_HI"},
{17, "R_PPC_GOT16_HA"},
{18, "R_PPC_PLTREL24"},
{19, "R_PPC_COPY"},
{20, "R_PPC_GLOB_DAT"},
{21, "R_PPC_JMP_SLOT"},
{22, "R_PPC_RELATIVE"},
{23, "R_PPC_LOCAL24PC"},
{24, "R_PPC_UADDR32"},
{25, "R_PPC_UADDR16"},
{26, "R_PPC_REL32"},
{27, "R_PPC_PLT32"},
{28, "R_PPC_PLTREL32"},
{29, "R_PPC_PLT16_LO"},
{30, "R_PPC_PLT16_HI"},
{31, "R_PPC_PLT16_HA"},
{32, "R_PPC_SDAREL16"},
{33, "R_PPC_SECTOFF"},
{34, "R_PPC_SECTOFF_LO"},
{35, "R_PPC_SECTOFF_HI"},
{36, "R_PPC_SECTOFF_HA"},
{67, "R_PPC_TLS"},
{68, "R_PPC_DTPMOD32"},
{69, "R_PPC_TPREL16"},
{70, "R_PPC_TPREL16_LO"},
{71, "R_PPC_TPREL16_HI"},
{72, "R_PPC_TPREL16_HA"},
{73, "R_PPC_TPREL32"},
{74, "R_PPC_DTPREL16"},
{75, "R_PPC_DTPREL16_LO"},
{76, "R_PPC_DTPREL16_HI"},
{77, "R_PPC_DTPREL16_HA"},
{78, "R_PPC_DTPREL32"},
{79, "R_PPC_GOT_TLSGD16"},
{80, "R_PPC_GOT_TLSGD16_LO"},
{81, "R_PPC_GOT_TLSGD16_HI"},
{82, "R_PPC_GOT_TLSGD16_HA"},
{83, "R_PPC_GOT_TLSLD16"},
{84, "R_PPC_GOT_TLSLD16_LO"},
{85, "R_PPC_GOT_TLSLD16_HI"},
{86, "R_PPC_GOT_TLSLD16_HA"},
{87, "R_PPC_GOT_TPREL16"},
{88, "R_PPC_GOT_TPREL16_LO"},
{89, "R_PPC_GOT_TPREL16_HI"},
{90, "R_PPC_GOT_TPREL16_HA"},
{101, "R_PPC_EMB_NADDR32"},
{102, "R_PPC_EMB_NADDR16"},
{103, "R_PPC_EMB_NADDR16_LO"},
{104, "R_PPC_EMB_NADDR16_HI"},
{105, "R_PPC_EMB_NADDR16_HA"},
{106, "R_PPC_EMB_SDAI16"},
{107, "R_PPC_EMB_SDA2I16"},
{108, "R_PPC_EMB_SDA2REL"},
{109, "R_PPC_EMB_SDA21"},
{110, "R_PPC_EMB_MRKREF"},
{111, "R_PPC_EMB_RELSEC16"},
{112, "R_PPC_EMB_RELST_LO"},
{113, "R_PPC_EMB_RELST_HI"},
{114, "R_PPC_EMB_RELST_HA"},
{115, "R_PPC_EMB_BIT_FLD"},
{116, "R_PPC_EMB_RELSDA"},
}
func (i R_PPC) String() string { return stringName(uint32(i), rppcStrings, false) }
func (i R_PPC) GoString() string { return stringName(uint32(i), rppcStrings, true) }
// Relocation types for 64-bit PowerPC or Power Architecture processors.
//
// Values that are shared by both R_PPC and R_PPC64 are prefixed with
// R_POWERPC_ in the ELF standard. For the R_PPC64 type, the relevant
// shared relocations have been renamed with the prefix R_PPC64_.
// The original name follows the value in a comment.
type R_PPC64 int
const (
R_PPC64_NONE R_PPC64 = 0 // R_POWERPC_NONE
R_PPC64_ADDR32 R_PPC64 = 1 // R_POWERPC_ADDR32
R_PPC64_ADDR24 R_PPC64 = 2 // R_POWERPC_ADDR24
R_PPC64_ADDR16 R_PPC64 = 3 // R_POWERPC_ADDR16
R_PPC64_ADDR16_LO R_PPC64 = 4 // R_POWERPC_ADDR16_LO
R_PPC64_ADDR16_HI R_PPC64 = 5 // R_POWERPC_ADDR16_HI
R_PPC64_ADDR16_HA R_PPC64 = 6 // R_POWERPC_ADDR16_HA
R_PPC64_ADDR14 R_PPC64 = 7 // R_POWERPC_ADDR14
R_PPC64_ADDR14_BRTAKEN R_PPC64 = 8 // R_POWERPC_ADDR14_BRTAKEN
R_PPC64_ADDR14_BRNTAKEN R_PPC64 = 9 // R_POWERPC_ADDR14_BRNTAKEN
R_PPC64_REL24 R_PPC64 = 10 // R_POWERPC_REL24
R_PPC64_REL14 R_PPC64 = 11 // R_POWERPC_REL14
R_PPC64_REL14_BRTAKEN R_PPC64 = 12 // R_POWERPC_REL14_BRTAKEN
R_PPC64_REL14_BRNTAKEN R_PPC64 = 13 // R_POWERPC_REL14_BRNTAKEN
R_PPC64_GOT16 R_PPC64 = 14 // R_POWERPC_GOT16
R_PPC64_GOT16_LO R_PPC64 = 15 // R_POWERPC_GOT16_LO
R_PPC64_GOT16_HI R_PPC64 = 16 // R_POWERPC_GOT16_HI
R_PPC64_GOT16_HA R_PPC64 = 17 // R_POWERPC_GOT16_HA
R_PPC64_COPY R_PPC64 = 19 // R_POWERPC_COPY
R_PPC64_GLOB_DAT R_PPC64 = 20 // R_POWERPC_GLOB_DAT
R_PPC64_JMP_SLOT R_PPC64 = 21 // R_POWERPC_JMP_SLOT
R_PPC64_RELATIVE R_PPC64 = 22 // R_POWERPC_RELATIVE
R_PPC64_UADDR32 R_PPC64 = 24 // R_POWERPC_UADDR32
R_PPC64_UADDR16 R_PPC64 = 25 // R_POWERPC_UADDR16
R_PPC64_REL32 R_PPC64 = 26 // R_POWERPC_REL32
R_PPC64_PLT32 R_PPC64 = 27 // R_POWERPC_PLT32
R_PPC64_PLTREL32 R_PPC64 = 28 // R_POWERPC_PLTREL32
R_PPC64_PLT16_LO R_PPC64 = 29 // R_POWERPC_PLT16_LO
R_PPC64_PLT16_HI R_PPC64 = 30 // R_POWERPC_PLT16_HI
R_PPC64_PLT16_HA R_PPC64 = 31 // R_POWERPC_PLT16_HA
R_PPC64_SECTOFF R_PPC64 = 33 // R_POWERPC_SECTOFF
R_PPC64_SECTOFF_LO R_PPC64 = 34 // R_POWERPC_SECTOFF_LO
R_PPC64_SECTOFF_HI R_PPC64 = 35 // R_POWERPC_SECTOFF_HI
R_PPC64_SECTOFF_HA R_PPC64 = 36 // R_POWERPC_SECTOFF_HA
R_PPC64_REL30 R_PPC64 = 37 // R_POWERPC_ADDR30
R_PPC64_ADDR64 R_PPC64 = 38
R_PPC64_ADDR16_HIGHER R_PPC64 = 39
R_PPC64_ADDR16_HIGHERA R_PPC64 = 40
R_PPC64_ADDR16_HIGHEST R_PPC64 = 41
R_PPC64_ADDR16_HIGHESTA R_PPC64 = 42
R_PPC64_UADDR64 R_PPC64 = 43
R_PPC64_REL64 R_PPC64 = 44
R_PPC64_PLT64 R_PPC64 = 45
R_PPC64_PLTREL64 R_PPC64 = 46
R_PPC64_TOC16 R_PPC64 = 47
R_PPC64_TOC16_LO R_PPC64 = 48
R_PPC64_TOC16_HI R_PPC64 = 49
R_PPC64_TOC16_HA R_PPC64 = 50
R_PPC64_TOC R_PPC64 = 51
R_PPC64_PLTGOT16 R_PPC64 = 52
R_PPC64_PLTGOT16_LO R_PPC64 = 53
R_PPC64_PLTGOT16_HI R_PPC64 = 54
R_PPC64_PLTGOT16_HA R_PPC64 = 55
R_PPC64_ADDR16_DS R_PPC64 = 56
R_PPC64_ADDR16_LO_DS R_PPC64 = 57
R_PPC64_GOT16_DS R_PPC64 = 58
R_PPC64_GOT16_LO_DS R_PPC64 = 59
R_PPC64_PLT16_LO_DS R_PPC64 = 60
R_PPC64_SECTOFF_DS R_PPC64 = 61
R_PPC64_SECTOFF_LO_DS R_PPC64 = 62
R_PPC64_TOC16_DS R_PPC64 = 63
R_PPC64_TOC16_LO_DS R_PPC64 = 64
R_PPC64_PLTGOT16_DS R_PPC64 = 65
R_PPC64_PLTGOT_LO_DS R_PPC64 = 66
R_PPC64_TLS R_PPC64 = 67 // R_POWERPC_TLS
R_PPC64_DTPMOD64 R_PPC64 = 68 // R_POWERPC_DTPMOD64
R_PPC64_TPREL16 R_PPC64 = 69 // R_POWERPC_TPREL16
R_PPC64_TPREL16_LO R_PPC64 = 70 // R_POWERPC_TPREL16_LO
R_PPC64_TPREL16_HI R_PPC64 = 71 // R_POWERPC_TPREL16_HI
R_PPC64_TPREL16_HA R_PPC64 = 72 // R_POWERPC_TPREL16_HA
R_PPC64_TPREL64 R_PPC64 = 73 // R_POWERPC_TPREL64
R_PPC64_DTPREL16 R_PPC64 = 74 // R_POWERPC_DTPREL16
R_PPC64_DTPREL16_LO R_PPC64 = 75 // R_POWERPC_DTPREL16_LO
R_PPC64_DTPREL16_HI R_PPC64 = 76 // R_POWERPC_DTPREL16_HI
R_PPC64_DTPREL16_HA R_PPC64 = 77 // R_POWERPC_DTPREL16_HA
R_PPC64_DTPREL64 R_PPC64 = 78 // R_POWERPC_DTPREL64
R_PPC64_GOT_TLSGD16 R_PPC64 = 79 // R_POWERPC_GOT_TLSGD16
R_PPC64_GOT_TLSGD16_LO R_PPC64 = 80 // R_POWERPC_GOT_TLSGD16_LO
R_PPC64_GOT_TLSGD16_HI R_PPC64 = 81 // R_POWERPC_GOT_TLSGD16_HI
R_PPC64_GOT_TLSGD16_HA R_PPC64 = 82 // R_POWERPC_GOT_TLSGD16_HA
R_PPC64_GOT_TLSLD16 R_PPC64 = 83 // R_POWERPC_GOT_TLSLD16
R_PPC64_GOT_TLSLD16_LO R_PPC64 = 84 // R_POWERPC_GOT_TLSLD16_LO
R_PPC64_GOT_TLSLD16_HI R_PPC64 = 85 // R_POWERPC_GOT_TLSLD16_HI
R_PPC64_GOT_TLSLD16_HA R_PPC64 = 86 // R_POWERPC_GOT_TLSLD16_HA
R_PPC64_GOT_TPREL16_DS R_PPC64 = 87 // R_POWERPC_GOT_TPREL16_DS
R_PPC64_GOT_TPREL16_LO_DS R_PPC64 = 88 // R_POWERPC_GOT_TPREL16_LO_DS
R_PPC64_GOT_TPREL16_HI R_PPC64 = 89 // R_POWERPC_GOT_TPREL16_HI
R_PPC64_GOT_TPREL16_HA R_PPC64 = 90 // R_POWERPC_GOT_TPREL16_HA
R_PPC64_GOT_DTPREL16_DS R_PPC64 = 91 // R_POWERPC_GOT_DTPREL16_DS
R_PPC64_GOT_DTPREL16_LO_DS R_PPC64 = 92 // R_POWERPC_GOT_DTPREL16_LO_DS
R_PPC64_GOT_DTPREL16_HI R_PPC64 = 93 // R_POWERPC_GOT_DTPREL16_HI
R_PPC64_GOT_DTPREL16_HA R_PPC64 = 94 // R_POWERPC_GOT_DTPREL16_HA
R_PPC64_TPREL16_DS R_PPC64 = 95
R_PPC64_TPREL16_LO_DS R_PPC64 = 96
R_PPC64_TPREL16_HIGHER R_PPC64 = 97
R_PPC64_TPREL16_HIGHERA R_PPC64 = 98
R_PPC64_TPREL16_HIGHEST R_PPC64 = 99
R_PPC64_TPREL16_HIGHESTA R_PPC64 = 100
R_PPC64_DTPREL16_DS R_PPC64 = 101
R_PPC64_DTPREL16_LO_DS R_PPC64 = 102
R_PPC64_DTPREL16_HIGHER R_PPC64 = 103
R_PPC64_DTPREL16_HIGHERA R_PPC64 = 104
R_PPC64_DTPREL16_HIGHEST R_PPC64 = 105
R_PPC64_DTPREL16_HIGHESTA R_PPC64 = 106
R_PPC64_TLSGD R_PPC64 = 107
R_PPC64_TLSLD R_PPC64 = 108
R_PPC64_TOCSAVE R_PPC64 = 109
R_PPC64_ADDR16_HIGH R_PPC64 = 110
R_PPC64_ADDR16_HIGHA R_PPC64 = 111
R_PPC64_TPREL16_HIGH R_PPC64 = 112
R_PPC64_TPREL16_HIGHA R_PPC64 = 113
R_PPC64_DTPREL16_HIGH R_PPC64 = 114
R_PPC64_DTPREL16_HIGHA R_PPC64 = 115
R_PPC64_REL24_NOTOC R_PPC64 = 116
R_PPC64_ADDR64_LOCAL R_PPC64 = 117
R_PPC64_ENTRY R_PPC64 = 118
R_PPC64_PLTSEQ R_PPC64 = 119
R_PPC64_PLTCALL R_PPC64 = 120
R_PPC64_PLTSEQ_NOTOC R_PPC64 = 121
R_PPC64_PLTCALL_NOTOC R_PPC64 = 122
R_PPC64_PCREL_OPT R_PPC64 = 123
R_PPC64_D34 R_PPC64 = 128
R_PPC64_D34_LO R_PPC64 = 129
R_PPC64_D34_HI30 R_PPC64 = 130
R_PPC64_D34_HA30 R_PPC64 = 131
R_PPC64_PCREL34 R_PPC64 = 132
R_PPC64_GOT_PCREL34 R_PPC64 = 133
R_PPC64_PLT_PCREL34 R_PPC64 = 134
R_PPC64_PLT_PCREL34_NOTOC R_PPC64 = 135
R_PPC64_ADDR16_HIGHER34 R_PPC64 = 136
R_PPC64_ADDR16_HIGHERA34 R_PPC64 = 137
R_PPC64_ADDR16_HIGHEST34 R_PPC64 = 138
R_PPC64_ADDR16_HIGHESTA34 R_PPC64 = 139
R_PPC64_REL16_HIGHER34 R_PPC64 = 140
R_PPC64_REL16_HIGHERA34 R_PPC64 = 141
R_PPC64_REL16_HIGHEST34 R_PPC64 = 142
R_PPC64_REL16_HIGHESTA34 R_PPC64 = 143
R_PPC64_D28 R_PPC64 = 144
R_PPC64_PCREL28 R_PPC64 = 145
R_PPC64_TPREL34 R_PPC64 = 146
R_PPC64_DTPREL34 R_PPC64 = 147
R_PPC64_GOT_TLSGD_PCREL34 R_PPC64 = 148
R_PPC64_GOT_TLSLD_PCREL34 R_PPC64 = 149
R_PPC64_GOT_TPREL_PCREL34 R_PPC64 = 150
R_PPC64_GOT_DTPREL_PCREL34 R_PPC64 = 151
R_PPC64_REL16_HIGH R_PPC64 = 240
R_PPC64_REL16_HIGHA R_PPC64 = 241
R_PPC64_REL16_HIGHER R_PPC64 = 242
R_PPC64_REL16_HIGHERA R_PPC64 = 243
R_PPC64_REL16_HIGHEST R_PPC64 = 244
R_PPC64_REL16_HIGHESTA R_PPC64 = 245
R_PPC64_REL16DX_HA R_PPC64 = 246 // R_POWERPC_REL16DX_HA
R_PPC64_JMP_IREL R_PPC64 = 247
R_PPC64_IRELATIVE R_PPC64 = 248 // R_POWERPC_IRELATIVE
R_PPC64_REL16 R_PPC64 = 249 // R_POWERPC_REL16
R_PPC64_REL16_LO R_PPC64 = 250 // R_POWERPC_REL16_LO
R_PPC64_REL16_HI R_PPC64 = 251 // R_POWERPC_REL16_HI
R_PPC64_REL16_HA R_PPC64 = 252 // R_POWERPC_REL16_HA
R_PPC64_GNU_VTINHERIT R_PPC64 = 253
R_PPC64_GNU_VTENTRY R_PPC64 = 254
)
var rppc64Strings = []intName{
{0, "R_PPC64_NONE"},
{1, "R_PPC64_ADDR32"},
{2, "R_PPC64_ADDR24"},
{3, "R_PPC64_ADDR16"},
{4, "R_PPC64_ADDR16_LO"},
{5, "R_PPC64_ADDR16_HI"},
{6, "R_PPC64_ADDR16_HA"},
{7, "R_PPC64_ADDR14"},
{8, "R_PPC64_ADDR14_BRTAKEN"},
{9, "R_PPC64_ADDR14_BRNTAKEN"},
{10, "R_PPC64_REL24"},
{11, "R_PPC64_REL14"},
{12, "R_PPC64_REL14_BRTAKEN"},
{13, "R_PPC64_REL14_BRNTAKEN"},
{14, "R_PPC64_GOT16"},
{15, "R_PPC64_GOT16_LO"},
{16, "R_PPC64_GOT16_HI"},
{17, "R_PPC64_GOT16_HA"},
{19, "R_PPC64_COPY"},
{20, "R_PPC64_GLOB_DAT"},
{21, "R_PPC64_JMP_SLOT"},
{22, "R_PPC64_RELATIVE"},
{24, "R_PPC64_UADDR32"},
{25, "R_PPC64_UADDR16"},
{26, "R_PPC64_REL32"},
{27, "R_PPC64_PLT32"},
{28, "R_PPC64_PLTREL32"},
{29, "R_PPC64_PLT16_LO"},
{30, "R_PPC64_PLT16_HI"},
{31, "R_PPC64_PLT16_HA"},
{33, "R_PPC64_SECTOFF"},
{34, "R_PPC64_SECTOFF_LO"},
{35, "R_PPC64_SECTOFF_HI"},
{36, "R_PPC64_SECTOFF_HA"},
{37, "R_PPC64_REL30"},
{38, "R_PPC64_ADDR64"},
{39, "R_PPC64_ADDR16_HIGHER"},
{40, "R_PPC64_ADDR16_HIGHERA"},
{41, "R_PPC64_ADDR16_HIGHEST"},
{42, "R_PPC64_ADDR16_HIGHESTA"},
{43, "R_PPC64_UADDR64"},
{44, "R_PPC64_REL64"},
{45, "R_PPC64_PLT64"},
{46, "R_PPC64_PLTREL64"},
{47, "R_PPC64_TOC16"},
{48, "R_PPC64_TOC16_LO"},
{49, "R_PPC64_TOC16_HI"},
{50, "R_PPC64_TOC16_HA"},
{51, "R_PPC64_TOC"},
{52, "R_PPC64_PLTGOT16"},
{53, "R_PPC64_PLTGOT16_LO"},
{54, "R_PPC64_PLTGOT16_HI"},
{55, "R_PPC64_PLTGOT16_HA"},
{56, "R_PPC64_ADDR16_DS"},
{57, "R_PPC64_ADDR16_LO_DS"},
{58, "R_PPC64_GOT16_DS"},
{59, "R_PPC64_GOT16_LO_DS"},
{60, "R_PPC64_PLT16_LO_DS"},
{61, "R_PPC64_SECTOFF_DS"},
{62, "R_PPC64_SECTOFF_LO_DS"},
{63, "R_PPC64_TOC16_DS"},
{64, "R_PPC64_TOC16_LO_DS"},
{65, "R_PPC64_PLTGOT16_DS"},
{66, "R_PPC64_PLTGOT_LO_DS"},
{67, "R_PPC64_TLS"},
{68, "R_PPC64_DTPMOD64"},
{69, "R_PPC64_TPREL16"},
{70, "R_PPC64_TPREL16_LO"},
{71, "R_PPC64_TPREL16_HI"},
{72, "R_PPC64_TPREL16_HA"},
{73, "R_PPC64_TPREL64"},
{74, "R_PPC64_DTPREL16"},
{75, "R_PPC64_DTPREL16_LO"},
{76, "R_PPC64_DTPREL16_HI"},
{77, "R_PPC64_DTPREL16_HA"},
{78, "R_PPC64_DTPREL64"},
{79, "R_PPC64_GOT_TLSGD16"},
{80, "R_PPC64_GOT_TLSGD16_LO"},
{81, "R_PPC64_GOT_TLSGD16_HI"},
{82, "R_PPC64_GOT_TLSGD16_HA"},
{83, "R_PPC64_GOT_TLSLD16"},
{84, "R_PPC64_GOT_TLSLD16_LO"},
{85, "R_PPC64_GOT_TLSLD16_HI"},
{86, "R_PPC64_GOT_TLSLD16_HA"},
{87, "R_PPC64_GOT_TPREL16_DS"},
{88, "R_PPC64_GOT_TPREL16_LO_DS"},
{89, "R_PPC64_GOT_TPREL16_HI"},
{90, "R_PPC64_GOT_TPREL16_HA"},
{91, "R_PPC64_GOT_DTPREL16_DS"},
{92, "R_PPC64_GOT_DTPREL16_LO_DS"},
{93, "R_PPC64_GOT_DTPREL16_HI"},
{94, "R_PPC64_GOT_DTPREL16_HA"},
{95, "R_PPC64_TPREL16_DS"},
{96, "R_PPC64_TPREL16_LO_DS"},
{97, "R_PPC64_TPREL16_HIGHER"},
{98, "R_PPC64_TPREL16_HIGHERA"},
{99, "R_PPC64_TPREL16_HIGHEST"},
{100, "R_PPC64_TPREL16_HIGHESTA"},
{101, "R_PPC64_DTPREL16_DS"},
{102, "R_PPC64_DTPREL16_LO_DS"},
{103, "R_PPC64_DTPREL16_HIGHER"},
{104, "R_PPC64_DTPREL16_HIGHERA"},
{105, "R_PPC64_DTPREL16_HIGHEST"},
{106, "R_PPC64_DTPREL16_HIGHESTA"},
{107, "R_PPC64_TLSGD"},
{108, "R_PPC64_TLSLD"},
{109, "R_PPC64_TOCSAVE"},
{110, "R_PPC64_ADDR16_HIGH"},
{111, "R_PPC64_ADDR16_HIGHA"},
{112, "R_PPC64_TPREL16_HIGH"},
{113, "R_PPC64_TPREL16_HIGHA"},
{114, "R_PPC64_DTPREL16_HIGH"},
{115, "R_PPC64_DTPREL16_HIGHA"},
{116, "R_PPC64_REL24_NOTOC"},
{117, "R_PPC64_ADDR64_LOCAL"},
{118, "R_PPC64_ENTRY"},
{119, "R_PPC64_PLTSEQ"},
{120, "R_PPC64_PLTCALL"},
{121, "R_PPC64_PLTSEQ_NOTOC"},
{122, "R_PPC64_PLTCALL_NOTOC"},
{123, "R_PPC64_PCREL_OPT"},
{128, "R_PPC64_D34"},
{129, "R_PPC64_D34_LO"},
{130, "R_PPC64_D34_HI30"},
{131, "R_PPC64_D34_HA30"},
{132, "R_PPC64_PCREL34"},
{133, "R_PPC64_GOT_PCREL34"},
{134, "R_PPC64_PLT_PCREL34"},
{135, "R_PPC64_PLT_PCREL34_NOTOC"},
{136, "R_PPC64_ADDR16_HIGHER34"},
{137, "R_PPC64_ADDR16_HIGHERA34"},
{138, "R_PPC64_ADDR16_HIGHEST34"},
{139, "R_PPC64_ADDR16_HIGHESTA34"},
{140, "R_PPC64_REL16_HIGHER34"},
{141, "R_PPC64_REL16_HIGHERA34"},
{142, "R_PPC64_REL16_HIGHEST34"},
{143, "R_PPC64_REL16_HIGHESTA34"},
{144, "R_PPC64_D28"},
{145, "R_PPC64_PCREL28"},
{146, "R_PPC64_TPREL34"},
{147, "R_PPC64_DTPREL34"},
{148, "R_PPC64_GOT_TLSGD_PCREL34"},
{149, "R_PPC64_GOT_TLSLD_PCREL34"},
{150, "R_PPC64_GOT_TPREL_PCREL34"},
{151, "R_PPC64_GOT_DTPREL_PCREL34"},
{240, "R_PPC64_REL16_HIGH"},
{241, "R_PPC64_REL16_HIGHA"},
{242, "R_PPC64_REL16_HIGHER"},
{243, "R_PPC64_REL16_HIGHERA"},
{244, "R_PPC64_REL16_HIGHEST"},
{245, "R_PPC64_REL16_HIGHESTA"},
{246, "R_PPC64_REL16DX_HA"},
{247, "R_PPC64_JMP_IREL"},
{248, "R_PPC64_IRELATIVE"},
{249, "R_PPC64_REL16"},
{250, "R_PPC64_REL16_LO"},
{251, "R_PPC64_REL16_HI"},
{252, "R_PPC64_REL16_HA"},
{253, "R_PPC64_GNU_VTINHERIT"},
{254, "R_PPC64_GNU_VTENTRY"},
}
func (i R_PPC64) String() string { return stringName(uint32(i), rppc64Strings, false) }
func (i R_PPC64) GoString() string { return stringName(uint32(i), rppc64Strings, true) }
// Relocation types for RISC-V processors.
type R_RISCV int
const (
R_RISCV_NONE R_RISCV = 0 /* No relocation. */
R_RISCV_32 R_RISCV = 1 /* Add 32 bit zero extended symbol value */
R_RISCV_64 R_RISCV = 2 /* Add 64 bit symbol value. */
R_RISCV_RELATIVE R_RISCV = 3 /* Add load address of shared object. */
R_RISCV_COPY R_RISCV = 4 /* Copy data from shared object. */
R_RISCV_JUMP_SLOT R_RISCV = 5 /* Set GOT entry to code address. */
R_RISCV_TLS_DTPMOD32 R_RISCV = 6 /* 32 bit ID of module containing symbol */
R_RISCV_TLS_DTPMOD64 R_RISCV = 7 /* ID of module containing symbol */
R_RISCV_TLS_DTPREL32 R_RISCV = 8 /* 32 bit relative offset in TLS block */
R_RISCV_TLS_DTPREL64 R_RISCV = 9 /* Relative offset in TLS block */
R_RISCV_TLS_TPREL32 R_RISCV = 10 /* 32 bit relative offset in static TLS block */
R_RISCV_TLS_TPREL64 R_RISCV = 11 /* Relative offset in static TLS block */
R_RISCV_BRANCH R_RISCV = 16 /* PC-relative branch */
R_RISCV_JAL R_RISCV = 17 /* PC-relative jump */
R_RISCV_CALL R_RISCV = 18 /* PC-relative call */
R_RISCV_CALL_PLT R_RISCV = 19 /* PC-relative call (PLT) */
R_RISCV_GOT_HI20 R_RISCV = 20 /* PC-relative GOT reference */
R_RISCV_TLS_GOT_HI20 R_RISCV = 21 /* PC-relative TLS IE GOT offset */
R_RISCV_TLS_GD_HI20 R_RISCV = 22 /* PC-relative TLS GD reference */
R_RISCV_PCREL_HI20 R_RISCV = 23 /* PC-relative reference */
R_RISCV_PCREL_LO12_I R_RISCV = 24 /* PC-relative reference */
R_RISCV_PCREL_LO12_S R_RISCV = 25 /* PC-relative reference */
R_RISCV_HI20 R_RISCV = 26 /* Absolute address */
R_RISCV_LO12_I R_RISCV = 27 /* Absolute address */
R_RISCV_LO12_S R_RISCV = 28 /* Absolute address */
R_RISCV_TPREL_HI20 R_RISCV = 29 /* TLS LE thread offset */
R_RISCV_TPREL_LO12_I R_RISCV = 30 /* TLS LE thread offset */
R_RISCV_TPREL_LO12_S R_RISCV = 31 /* TLS LE thread offset */
R_RISCV_TPREL_ADD R_RISCV = 32 /* TLS LE thread usage */
R_RISCV_ADD8 R_RISCV = 33 /* 8-bit label addition */
R_RISCV_ADD16 R_RISCV = 34 /* 16-bit label addition */
R_RISCV_ADD32 R_RISCV = 35 /* 32-bit label addition */
R_RISCV_ADD64 R_RISCV = 36 /* 64-bit label addition */
R_RISCV_SUB8 R_RISCV = 37 /* 8-bit label subtraction */
R_RISCV_SUB16 R_RISCV = 38 /* 16-bit label subtraction */
R_RISCV_SUB32 R_RISCV = 39 /* 32-bit label subtraction */
R_RISCV_SUB64 R_RISCV = 40 /* 64-bit label subtraction */
R_RISCV_GNU_VTINHERIT R_RISCV = 41 /* GNU C++ vtable hierarchy */
R_RISCV_GNU_VTENTRY R_RISCV = 42 /* GNU C++ vtable member usage */
R_RISCV_ALIGN R_RISCV = 43 /* Alignment statement */
R_RISCV_RVC_BRANCH R_RISCV = 44 /* PC-relative branch offset */
R_RISCV_RVC_JUMP R_RISCV = 45 /* PC-relative jump offset */
R_RISCV_RVC_LUI R_RISCV = 46 /* Absolute address */
R_RISCV_GPREL_I R_RISCV = 47 /* GP-relative reference */
R_RISCV_GPREL_S R_RISCV = 48 /* GP-relative reference */
R_RISCV_TPREL_I R_RISCV = 49 /* TP-relative TLS LE load */
R_RISCV_TPREL_S R_RISCV = 50 /* TP-relative TLS LE store */
R_RISCV_RELAX R_RISCV = 51 /* Instruction pair can be relaxed */
R_RISCV_SUB6 R_RISCV = 52 /* Local label subtraction */
R_RISCV_SET6 R_RISCV = 53 /* Local label assignment */
R_RISCV_SET8 R_RISCV = 54 /* Local label assignment */
R_RISCV_SET16 R_RISCV = 55 /* Local label assignment */
R_RISCV_SET32 R_RISCV = 56 /* Local label assignment */
R_RISCV_32_PCREL R_RISCV = 57 /* 32-bit PC relative */
)
var rriscvStrings = []intName{
{0, "R_RISCV_NONE"},
{1, "R_RISCV_32"},
{2, "R_RISCV_64"},
{3, "R_RISCV_RELATIVE"},
{4, "R_RISCV_COPY"},
{5, "R_RISCV_JUMP_SLOT"},
{6, "R_RISCV_TLS_DTPMOD32"},
{7, "R_RISCV_TLS_DTPMOD64"},
{8, "R_RISCV_TLS_DTPREL32"},
{9, "R_RISCV_TLS_DTPREL64"},
{10, "R_RISCV_TLS_TPREL32"},
{11, "R_RISCV_TLS_TPREL64"},
{16, "R_RISCV_BRANCH"},
{17, "R_RISCV_JAL"},
{18, "R_RISCV_CALL"},
{19, "R_RISCV_CALL_PLT"},
{20, "R_RISCV_GOT_HI20"},
{21, "R_RISCV_TLS_GOT_HI20"},
{22, "R_RISCV_TLS_GD_HI20"},
{23, "R_RISCV_PCREL_HI20"},
{24, "R_RISCV_PCREL_LO12_I"},
{25, "R_RISCV_PCREL_LO12_S"},
{26, "R_RISCV_HI20"},
{27, "R_RISCV_LO12_I"},
{28, "R_RISCV_LO12_S"},
{29, "R_RISCV_TPREL_HI20"},
{30, "R_RISCV_TPREL_LO12_I"},
{31, "R_RISCV_TPREL_LO12_S"},
{32, "R_RISCV_TPREL_ADD"},
{33, "R_RISCV_ADD8"},
{34, "R_RISCV_ADD16"},
{35, "R_RISCV_ADD32"},
{36, "R_RISCV_ADD64"},
{37, "R_RISCV_SUB8"},
{38, "R_RISCV_SUB16"},
{39, "R_RISCV_SUB32"},
{40, "R_RISCV_SUB64"},
{41, "R_RISCV_GNU_VTINHERIT"},
{42, "R_RISCV_GNU_VTENTRY"},
{43, "R_RISCV_ALIGN"},
{44, "R_RISCV_RVC_BRANCH"},
{45, "R_RISCV_RVC_JUMP"},
{46, "R_RISCV_RVC_LUI"},
{47, "R_RISCV_GPREL_I"},
{48, "R_RISCV_GPREL_S"},
{49, "R_RISCV_TPREL_I"},
{50, "R_RISCV_TPREL_S"},
{51, "R_RISCV_RELAX"},
{52, "R_RISCV_SUB6"},
{53, "R_RISCV_SET6"},
{54, "R_RISCV_SET8"},
{55, "R_RISCV_SET16"},
{56, "R_RISCV_SET32"},
{57, "R_RISCV_32_PCREL"},
}
func (i R_RISCV) String() string { return stringName(uint32(i), rriscvStrings, false) }
func (i R_RISCV) GoString() string { return stringName(uint32(i), rriscvStrings, true) }
// Relocation types for s390x processors.
type R_390 int
const (
R_390_NONE R_390 = 0
R_390_8 R_390 = 1
R_390_12 R_390 = 2
R_390_16 R_390 = 3
R_390_32 R_390 = 4
R_390_PC32 R_390 = 5
R_390_GOT12 R_390 = 6
R_390_GOT32 R_390 = 7
R_390_PLT32 R_390 = 8
R_390_COPY R_390 = 9
R_390_GLOB_DAT R_390 = 10
R_390_JMP_SLOT R_390 = 11
R_390_RELATIVE R_390 = 12
R_390_GOTOFF R_390 = 13
R_390_GOTPC R_390 = 14
R_390_GOT16 R_390 = 15
R_390_PC16 R_390 = 16
R_390_PC16DBL R_390 = 17
R_390_PLT16DBL R_390 = 18
R_390_PC32DBL R_390 = 19
R_390_PLT32DBL R_390 = 20
R_390_GOTPCDBL R_390 = 21
R_390_64 R_390 = 22
R_390_PC64 R_390 = 23
R_390_GOT64 R_390 = 24
R_390_PLT64 R_390 = 25
R_390_GOTENT R_390 = 26
R_390_GOTOFF16 R_390 = 27
R_390_GOTOFF64 R_390 = 28
R_390_GOTPLT12 R_390 = 29
R_390_GOTPLT16 R_390 = 30
R_390_GOTPLT32 R_390 = 31
R_390_GOTPLT64 R_390 = 32
R_390_GOTPLTENT R_390 = 33
R_390_GOTPLTOFF16 R_390 = 34
R_390_GOTPLTOFF32 R_390 = 35
R_390_GOTPLTOFF64 R_390 = 36
R_390_TLS_LOAD R_390 = 37
R_390_TLS_GDCALL R_390 = 38
R_390_TLS_LDCALL R_390 = 39
R_390_TLS_GD32 R_390 = 40
R_390_TLS_GD64 R_390 = 41
R_390_TLS_GOTIE12 R_390 = 42
R_390_TLS_GOTIE32 R_390 = 43
R_390_TLS_GOTIE64 R_390 = 44
R_390_TLS_LDM32 R_390 = 45
R_390_TLS_LDM64 R_390 = 46
R_390_TLS_IE32 R_390 = 47
R_390_TLS_IE64 R_390 = 48
R_390_TLS_IEENT R_390 = 49
R_390_TLS_LE32 R_390 = 50
R_390_TLS_LE64 R_390 = 51
R_390_TLS_LDO32 R_390 = 52
R_390_TLS_LDO64 R_390 = 53
R_390_TLS_DTPMOD R_390 = 54
R_390_TLS_DTPOFF R_390 = 55
R_390_TLS_TPOFF R_390 = 56
R_390_20 R_390 = 57
R_390_GOT20 R_390 = 58
R_390_GOTPLT20 R_390 = 59
R_390_TLS_GOTIE20 R_390 = 60
)
var r390Strings = []intName{
{0, "R_390_NONE"},
{1, "R_390_8"},
{2, "R_390_12"},
{3, "R_390_16"},
{4, "R_390_32"},
{5, "R_390_PC32"},
{6, "R_390_GOT12"},
{7, "R_390_GOT32"},
{8, "R_390_PLT32"},
{9, "R_390_COPY"},
{10, "R_390_GLOB_DAT"},
{11, "R_390_JMP_SLOT"},
{12, "R_390_RELATIVE"},
{13, "R_390_GOTOFF"},
{14, "R_390_GOTPC"},
{15, "R_390_GOT16"},
{16, "R_390_PC16"},
{17, "R_390_PC16DBL"},
{18, "R_390_PLT16DBL"},
{19, "R_390_PC32DBL"},
{20, "R_390_PLT32DBL"},
{21, "R_390_GOTPCDBL"},
{22, "R_390_64"},
{23, "R_390_PC64"},
{24, "R_390_GOT64"},
{25, "R_390_PLT64"},
{26, "R_390_GOTENT"},
{27, "R_390_GOTOFF16"},
{28, "R_390_GOTOFF64"},
{29, "R_390_GOTPLT12"},
{30, "R_390_GOTPLT16"},
{31, "R_390_GOTPLT32"},
{32, "R_390_GOTPLT64"},
{33, "R_390_GOTPLTENT"},
{34, "R_390_GOTPLTOFF16"},
{35, "R_390_GOTPLTOFF32"},
{36, "R_390_GOTPLTOFF64"},
{37, "R_390_TLS_LOAD"},
{38, "R_390_TLS_GDCALL"},
{39, "R_390_TLS_LDCALL"},
{40, "R_390_TLS_GD32"},
{41, "R_390_TLS_GD64"},
{42, "R_390_TLS_GOTIE12"},
{43, "R_390_TLS_GOTIE32"},
{44, "R_390_TLS_GOTIE64"},
{45, "R_390_TLS_LDM32"},
{46, "R_390_TLS_LDM64"},
{47, "R_390_TLS_IE32"},
{48, "R_390_TLS_IE64"},
{49, "R_390_TLS_IEENT"},
{50, "R_390_TLS_LE32"},
{51, "R_390_TLS_LE64"},
{52, "R_390_TLS_LDO32"},
{53, "R_390_TLS_LDO64"},
{54, "R_390_TLS_DTPMOD"},
{55, "R_390_TLS_DTPOFF"},
{56, "R_390_TLS_TPOFF"},
{57, "R_390_20"},
{58, "R_390_GOT20"},
{59, "R_390_GOTPLT20"},
{60, "R_390_TLS_GOTIE20"},
}
func (i R_390) String() string { return stringName(uint32(i), r390Strings, false) }
func (i R_390) GoString() string { return stringName(uint32(i), r390Strings, true) }
// Relocation types for SPARC.
type R_SPARC int
const (
R_SPARC_NONE R_SPARC = 0
R_SPARC_8 R_SPARC = 1
R_SPARC_16 R_SPARC = 2
R_SPARC_32 R_SPARC = 3
R_SPARC_DISP8 R_SPARC = 4
R_SPARC_DISP16 R_SPARC = 5
R_SPARC_DISP32 R_SPARC = 6
R_SPARC_WDISP30 R_SPARC = 7
R_SPARC_WDISP22 R_SPARC = 8
R_SPARC_HI22 R_SPARC = 9
R_SPARC_22 R_SPARC = 10
R_SPARC_13 R_SPARC = 11
R_SPARC_LO10 R_SPARC = 12
R_SPARC_GOT10 R_SPARC = 13
R_SPARC_GOT13 R_SPARC = 14
R_SPARC_GOT22 R_SPARC = 15
R_SPARC_PC10 R_SPARC = 16
R_SPARC_PC22 R_SPARC = 17
R_SPARC_WPLT30 R_SPARC = 18
R_SPARC_COPY R_SPARC = 19
R_SPARC_GLOB_DAT R_SPARC = 20
R_SPARC_JMP_SLOT R_SPARC = 21
R_SPARC_RELATIVE R_SPARC = 22
R_SPARC_UA32 R_SPARC = 23
R_SPARC_PLT32 R_SPARC = 24
R_SPARC_HIPLT22 R_SPARC = 25
R_SPARC_LOPLT10 R_SPARC = 26
R_SPARC_PCPLT32 R_SPARC = 27
R_SPARC_PCPLT22 R_SPARC = 28
R_SPARC_PCPLT10 R_SPARC = 29
R_SPARC_10 R_SPARC = 30
R_SPARC_11 R_SPARC = 31
R_SPARC_64 R_SPARC = 32
R_SPARC_OLO10 R_SPARC = 33
R_SPARC_HH22 R_SPARC = 34
R_SPARC_HM10 R_SPARC = 35
R_SPARC_LM22 R_SPARC = 36
R_SPARC_PC_HH22 R_SPARC = 37
R_SPARC_PC_HM10 R_SPARC = 38
R_SPARC_PC_LM22 R_SPARC = 39
R_SPARC_WDISP16 R_SPARC = 40
R_SPARC_WDISP19 R_SPARC = 41
R_SPARC_GLOB_JMP R_SPARC = 42
R_SPARC_7 R_SPARC = 43
R_SPARC_5 R_SPARC = 44
R_SPARC_6 R_SPARC = 45
R_SPARC_DISP64 R_SPARC = 46
R_SPARC_PLT64 R_SPARC = 47
R_SPARC_HIX22 R_SPARC = 48
R_SPARC_LOX10 R_SPARC = 49
R_SPARC_H44 R_SPARC = 50
R_SPARC_M44 R_SPARC = 51
R_SPARC_L44 R_SPARC = 52
R_SPARC_REGISTER R_SPARC = 53
R_SPARC_UA64 R_SPARC = 54
R_SPARC_UA16 R_SPARC = 55
)
var rsparcStrings = []intName{
{0, "R_SPARC_NONE"},
{1, "R_SPARC_8"},
{2, "R_SPARC_16"},
{3, "R_SPARC_32"},
{4, "R_SPARC_DISP8"},
{5, "R_SPARC_DISP16"},
{6, "R_SPARC_DISP32"},
{7, "R_SPARC_WDISP30"},
{8, "R_SPARC_WDISP22"},
{9, "R_SPARC_HI22"},
{10, "R_SPARC_22"},
{11, "R_SPARC_13"},
{12, "R_SPARC_LO10"},
{13, "R_SPARC_GOT10"},
{14, "R_SPARC_GOT13"},
{15, "R_SPARC_GOT22"},
{16, "R_SPARC_PC10"},
{17, "R_SPARC_PC22"},
{18, "R_SPARC_WPLT30"},
{19, "R_SPARC_COPY"},
{20, "R_SPARC_GLOB_DAT"},
{21, "R_SPARC_JMP_SLOT"},
{22, "R_SPARC_RELATIVE"},
{23, "R_SPARC_UA32"},
{24, "R_SPARC_PLT32"},
{25, "R_SPARC_HIPLT22"},
{26, "R_SPARC_LOPLT10"},
{27, "R_SPARC_PCPLT32"},
{28, "R_SPARC_PCPLT22"},
{29, "R_SPARC_PCPLT10"},
{30, "R_SPARC_10"},
{31, "R_SPARC_11"},
{32, "R_SPARC_64"},
{33, "R_SPARC_OLO10"},
{34, "R_SPARC_HH22"},
{35, "R_SPARC_HM10"},
{36, "R_SPARC_LM22"},
{37, "R_SPARC_PC_HH22"},
{38, "R_SPARC_PC_HM10"},
{39, "R_SPARC_PC_LM22"},
{40, "R_SPARC_WDISP16"},
{41, "R_SPARC_WDISP19"},
{42, "R_SPARC_GLOB_JMP"},
{43, "R_SPARC_7"},
{44, "R_SPARC_5"},
{45, "R_SPARC_6"},
{46, "R_SPARC_DISP64"},
{47, "R_SPARC_PLT64"},
{48, "R_SPARC_HIX22"},
{49, "R_SPARC_LOX10"},
{50, "R_SPARC_H44"},
{51, "R_SPARC_M44"},
{52, "R_SPARC_L44"},
{53, "R_SPARC_REGISTER"},
{54, "R_SPARC_UA64"},
{55, "R_SPARC_UA16"},
}
func (i R_SPARC) String() string { return stringName(uint32(i), rsparcStrings, false) }
func (i R_SPARC) GoString() string { return stringName(uint32(i), rsparcStrings, true) }
// Magic number for the ELF trampoline, chosen to be an immediate value.
const ARM_MAGIC_TRAMP_NUMBER = 0x5c000003
// ELF32 File header.
type Header32 struct {
Ident [EI_NIDENT]byte /* File identification. */
Type uint16 /* File type. */
Machine uint16 /* Machine architecture. */
Version uint32 /* ELF format version. */
Entry uint32 /* Entry point. */
Phoff uint32 /* Program header file offset. */
Shoff uint32 /* Section header file offset. */
Flags uint32 /* Architecture-specific flags. */
Ehsize uint16 /* Size of ELF header in bytes. */
Phentsize uint16 /* Size of program header entry. */
Phnum uint16 /* Number of program header entries. */
Shentsize uint16 /* Size of section header entry. */
Shnum uint16 /* Number of section header entries. */
Shstrndx uint16 /* Section name strings section. */
}
// ELF32 Section header.
type Section32 struct {
Name uint32 /* Section name (index into the section header string table). */
Type uint32 /* Section type. */
Flags uint32 /* Section flags. */
Addr uint32 /* Address in memory image. */
Off uint32 /* Offset in file. */
Size uint32 /* Size in bytes. */
Link uint32 /* Index of a related section. */
Info uint32 /* Depends on section type. */
Addralign uint32 /* Alignment in bytes. */
Entsize uint32 /* Size of each entry in section. */
}
// ELF32 Program header.
type Prog32 struct {
Type uint32 /* Entry type. */
Off uint32 /* File offset of contents. */
Vaddr uint32 /* Virtual address in memory image. */
Paddr uint32 /* Physical address (not used). */
Filesz uint32 /* Size of contents in file. */
Memsz uint32 /* Size of contents in memory. */
Flags uint32 /* Access permission flags. */
Align uint32 /* Alignment in memory and file. */
}
// ELF32 Dynamic structure. The ".dynamic" section contains an array of them.
type Dyn32 struct {
Tag int32 /* Entry type. */
Val uint32 /* Integer/Address value. */
}
// ELF32 Compression header.
type Chdr32 struct {
Type uint32
Size uint32
Addralign uint32
}
/*
* Relocation entries.
*/
// ELF32 Relocations that don't need an addend field.
type Rel32 struct {
Off uint32 /* Location to be relocated. */
Info uint32 /* Relocation type and symbol index. */
}
// ELF32 Relocations that need an addend field.
type Rela32 struct {
Off uint32 /* Location to be relocated. */
Info uint32 /* Relocation type and symbol index. */
Addend int32 /* Addend. */
}
func R_SYM32(info uint32) uint32 { return info >> 8 }
func R_TYPE32(info uint32) uint32 { return info & 0xff }
func R_INFO32(sym, typ uint32) uint32 { return sym<<8 | typ }
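// Illustrative sketch (not part of the package): Info packs the symbol
// table index in the upper 24 bits and the relocation type in the low 8
// bits, so the three helpers above round-trip. Values here are hypothetical.
//
//	info := R_INFO32(5, uint32(R_386_GOT32)) // symbol index 5, type 3
//	sym := R_SYM32(info)                     // == 5
//	typ := R_TYPE32(info)                    // == 3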
// ELF32 Symbol.
type Sym32 struct {
Name uint32
Value uint32
Size uint32
Info uint8
Other uint8
Shndx uint16
}
const Sym32Size = 16
func ST_BIND(info uint8) SymBind { return SymBind(info >> 4) }
func ST_TYPE(info uint8) SymType { return SymType(info & 0xF) }
func ST_INFO(bind SymBind, typ SymType) uint8 {
return uint8(bind)<<4 | uint8(typ)&0xf
}
func ST_VISIBILITY(other uint8) SymVis { return SymVis(other & 3) }
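// Illustrative sketch (not part of the package): Info packs the binding in
// the high nibble and the type in the low nibble, while the low two bits of
// Other carry visibility. Uses the SymBind/SymType/SymVis constants defined
// earlier in this package.
//
//	info := ST_INFO(STB_GLOBAL, STT_FUNC) // 0x12: a global function symbol
//	bind := ST_BIND(info)                 // == STB_GLOBAL
//	typ := ST_TYPE(info)                  // == STT_FUNC
//	vis := ST_VISIBILITY(0)               // == STV_DEFAULT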
/*
* ELF64
*/
// ELF64 file header.
type Header64 struct {
Ident [EI_NIDENT]byte /* File identification. */
Type uint16 /* File type. */
Machine uint16 /* Machine architecture. */
Version uint32 /* ELF format version. */
Entry uint64 /* Entry point. */
Phoff uint64 /* Program header file offset. */
Shoff uint64 /* Section header file offset. */
Flags uint32 /* Architecture-specific flags. */
Ehsize uint16 /* Size of ELF header in bytes. */
Phentsize uint16 /* Size of program header entry. */
Phnum uint16 /* Number of program header entries. */
Shentsize uint16 /* Size of section header entry. */
Shnum uint16 /* Number of section header entries. */
Shstrndx uint16 /* Section name strings section. */
}
// ELF64 Section header.
type Section64 struct {
Name uint32 /* Section name (index into the section header string table). */
Type uint32 /* Section type. */
Flags uint64 /* Section flags. */
Addr uint64 /* Address in memory image. */
Off uint64 /* Offset in file. */
Size uint64 /* Size in bytes. */
Link uint32 /* Index of a related section. */
Info uint32 /* Depends on section type. */
Addralign uint64 /* Alignment in bytes. */
Entsize uint64 /* Size of each entry in section. */
}
// ELF64 Program header.
type Prog64 struct {
Type uint32 /* Entry type. */
Flags uint32 /* Access permission flags. */
Off uint64 /* File offset of contents. */
Vaddr uint64 /* Virtual address in memory image. */
Paddr uint64 /* Physical address (not used). */
Filesz uint64 /* Size of contents in file. */
Memsz uint64 /* Size of contents in memory. */
Align uint64 /* Alignment in memory and file. */
}
// ELF64 Dynamic structure. The ".dynamic" section contains an array of them.
type Dyn64 struct {
Tag int64 /* Entry type. */
Val uint64 /* Integer/address value */
}
// ELF64 Compression header.
type Chdr64 struct {
Type uint32
_ uint32 /* Reserved. */
Size uint64
Addralign uint64
}
/*
* Relocation entries.
*/
/* ELF64 relocations that don't need an addend field. */
type Rel64 struct {
Off uint64 /* Location to be relocated. */
Info uint64 /* Relocation type and symbol index. */
}
/* ELF64 relocations that need an addend field. */
type Rela64 struct {
Off uint64 /* Location to be relocated. */
Info uint64 /* Relocation type and symbol index. */
Addend int64 /* Addend. */
}
func R_SYM64(info uint64) uint32 { return uint32(info >> 32) }
func R_TYPE64(info uint64) uint32 { return uint32(info) }
func R_INFO(sym, typ uint32) uint64 { return uint64(sym)<<32 | uint64(typ) }
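// Illustrative sketch (not part of the package): the ELF64 Info field gives
// the symbol index and relocation type 32 bits each. Hypothetical values:
//
//	info := R_INFO(7, uint32(R_PPC64_ADDR64)) // symbol index 7, type 38
//	sym := R_SYM64(info)                      // == 7
//	typ := R_TYPE64(info)                     // == 38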
// ELF64 symbol table entries.
type Sym64 struct {
Name uint32 /* String table index of name. */
Info uint8 /* Type and binding information. */
Other uint8 /* Visibility in low two bits (see ST_VISIBILITY); remaining bits reserved. */
Shndx uint16 /* Section index of symbol. */
Value uint64 /* Symbol value. */
Size uint64 /* Size of associated object. */
}
const Sym64Size = 24
// An intName pairs a constant value with its name; stringName and flagName
// use tables of them to render constants for String and GoString methods.
type intName struct {
i uint32
s string
}
// stringName renders i using the name table. An exact match returns the
// name; otherwise the nearest smaller named value is used with a decimal
// offset, falling back to plain decimal. goSyntax prefixes "elf.".
func stringName(i uint32, names []intName, goSyntax bool) string {
for _, n := range names {
if n.i == i {
if goSyntax {
return "elf." + n.s
}
return n.s
}
}
// Second pass: find the largest named value below i and render i as
// that name plus a decimal offset. Assumes names is sorted by value.
for j := len(names) - 1; j >= 0; j-- {
n := names[j]
if n.i < i {
s := n.s
if goSyntax {
s = "elf." + s
}
return s + "+" + strconv.FormatUint(uint64(i-n.i), 10)
}
}
return strconv.FormatUint(uint64(i), 10)
}
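// Illustrative sketch (not part of the package): stringName backs the
// String and GoString methods above.
//
//	stringName(3, r386Strings, false)  // "R_386_GOT32" (exact match)
//	stringName(12, r386Strings, false) // "R_386_32PLT+1" (no entry for 12)
//	stringName(3, r386Strings, true)   // "elf.R_386_GOT32"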
// flagName renders i as a "+"-joined list of the flag names whose bits are
// all set in i, with any remaining bits appended in hexadecimal. goSyntax
// prefixes "elf.".
func flagName(i uint32, names []intName, goSyntax bool) string {
s := ""
for _, n := range names {
if n.i&i == n.i {
if len(s) > 0 {
s += "+"
}
if goSyntax {
s += "elf."
}
s += n.s
i -= n.i
}
}
if len(s) == 0 {
return "0x" + strconv.FormatUint(uint64(i), 16)
}
if i != 0 {
s += "+0x" + strconv.FormatUint(uint64(i), 16)
}
return s
}
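// Illustrative sketch (not part of the package): given a hypothetical table
// names = []intName{{0x1, "FLAG_A"}, {0x4, "FLAG_B"}}:
//
//	flagName(0x5, names, false) // "FLAG_A+FLAG_B"
//	flagName(0x7, names, false) // "FLAG_A+FLAG_B+0x2" (leftover bit in hex)
//	flagName(0x0, names, false) // "0x0"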
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package elf implements access to ELF object files.
# Security
This package is not designed to be hardened against adversarial inputs, and is
outside the scope of https://go.dev/security/policy. In particular, only basic
validation is done when parsing object files. As such, care should be taken when
parsing untrusted inputs, as parsing malformed files may consume significant
resources, or cause panics.
*/
package elf
import (
"bytes"
"compress/zlib"
"debug/dwarf"
"encoding/binary"
"errors"
"fmt"
"internal/saferio"
"io"
"os"
"strings"
)
// seekStart, seekCurrent, seekEnd are copies of
// io.SeekStart, io.SeekCurrent, and io.SeekEnd.
// We can't use the ones from package io because
// we want this code to build with Go 1.4 during
// cmd/dist bootstrap.
const (
seekStart int = 0
seekCurrent int = 1
seekEnd int = 2
)
// TODO: error reporting detail
/*
* Internal ELF representation
*/
// A FileHeader represents an ELF file header.
type FileHeader struct {
Class Class
Data Data
Version Version
OSABI OSABI
ABIVersion uint8
ByteOrder binary.ByteOrder
Type Type
Machine Machine
Entry uint64
}
// A File represents an open ELF file.
type File struct {
FileHeader
Sections []*Section
Progs []*Prog
closer io.Closer
gnuNeed []verneed
gnuVersym []byte
}
// A SectionHeader represents a single ELF section header.
type SectionHeader struct {
Name string
Type SectionType
Flags SectionFlag
Addr uint64
Offset uint64
Size uint64
Link uint32
Info uint32
Addralign uint64
Entsize uint64
// FileSize is the size of this section in the file in bytes.
// If a section is compressed, FileSize is the size of the
// compressed data, while Size (above) is the size of the
// uncompressed data.
FileSize uint64
}
// A Section represents a single section in an ELF file.
type Section struct {
SectionHeader
// Embed ReaderAt for ReadAt method.
// Do not embed SectionReader directly
// to avoid having Read and Seek.
// If a client wants Read and Seek it must use
// Open() to avoid fighting over the seek offset
// with other clients.
//
// ReaderAt may be nil if the section is not easily available
// in a random-access form. For example, a compressed section
// may have a nil ReaderAt.
io.ReaderAt
sr *io.SectionReader
compressionType CompressionType
compressionOffset int64
}
// Data reads and returns the contents of the ELF section.
// Even if the section is stored compressed in the ELF file,
// Data returns uncompressed data.
//
// For an SHT_NOBITS section, Data always returns a non-nil error.
func (s *Section) Data() ([]byte, error) {
return saferio.ReadData(s.Open(), s.Size)
}
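// Illustrative sketch (not part of the package), assuming f is a *File and
// using the (*File).Section name lookup defined later in this package:
//
//	if s := f.Section(".text"); s != nil {
//		data, err := s.Data() // uncompressed bytes, even for compressed sections
//		_, _ = data, err
//	}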
// stringTable reads and returns the string table given by the
// specified link value.
func (f *File) stringTable(link uint32) ([]byte, error) {
if link == 0 || link >= uint32(len(f.Sections)) {
return nil, errors.New("section has invalid string table link")
}
return f.Sections[link].Data()
}
// Open returns a new ReadSeeker reading the ELF section.
// Even if the section is stored compressed in the ELF file,
// the ReadSeeker reads uncompressed data.
//
// For an SHT_NOBITS section, all calls to the opened reader
// will return a non-nil error.
func (s *Section) Open() io.ReadSeeker {
if s.Type == SHT_NOBITS {
return io.NewSectionReader(&nobitsSectionReader{}, 0, int64(s.Size))
}
if s.Flags&SHF_COMPRESSED == 0 {
return io.NewSectionReader(s.sr, 0, 1<<63-1)
}
if s.compressionType == COMPRESS_ZLIB {
return &readSeekerFromReader{
reset: func() (io.Reader, error) {
fr := io.NewSectionReader(s.sr, s.compressionOffset, int64(s.FileSize)-s.compressionOffset)
return zlib.NewReader(fr)
},
size: int64(s.Size),
}
}
err := &FormatError{int64(s.Offset), "unknown compression type", s.compressionType}
return errorReader{err}
}
// A ProgHeader represents a single ELF program header.
type ProgHeader struct {
Type ProgType
Flags ProgFlag
Off uint64
Vaddr uint64
Paddr uint64
Filesz uint64
Memsz uint64
Align uint64
}
// A Prog represents a single ELF program header in an ELF binary.
type Prog struct {
ProgHeader
// Embed ReaderAt for ReadAt method.
// Do not embed SectionReader directly
// to avoid having Read and Seek.
// If a client wants Read and Seek it must use
// Open() to avoid fighting over the seek offset
// with other clients.
io.ReaderAt
sr *io.SectionReader
}
// Open returns a new ReadSeeker reading the ELF program body.
func (p *Prog) Open() io.ReadSeeker { return io.NewSectionReader(p.sr, 0, 1<<63-1) }
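// Illustrative sketch (not part of the package): streaming a segment's file
// contents, assuming p is a *Prog:
//
//	body, err := io.ReadAll(p.Open()) // reads Filesz bytes from the file
//	_, _ = body, err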
// A Symbol represents an entry in an ELF symbol table section.
type Symbol struct {
Name string
Info, Other byte
Section SectionIndex
Value, Size uint64
// Version and Library are present only for the dynamic symbol
// table.
Version string
Library string
}
/*
* ELF reader
*/
// A FormatError describes a malformed or inconsistent record encountered
// while parsing an ELF file.
type FormatError struct {
off int64
msg string
val any
}
func (e *FormatError) Error() string {
msg := e.msg
if e.val != nil {
msg += fmt.Sprintf(" '%v' ", e.val)
}
msg += fmt.Sprintf("in record at byte %#x", e.off)
return msg
}
// Open opens the named file using os.Open and prepares it for use as an ELF binary.
func Open(name string) (*File, error) {
f, err := os.Open(name)
if err != nil {
return nil, err
}
ff, err := NewFile(f)
if err != nil {
f.Close()
return nil, err
}
ff.closer = f
return ff, nil
}
// Close closes the File.
// If the File was created using NewFile directly instead of Open,
// Close has no effect.
func (f *File) Close() error {
var err error
if f.closer != nil {
err = f.closer.Close()
f.closer = nil
}
return err
}
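// Illustrative sketch (not part of the package): Open and Close pair in the
// usual way. The path is hypothetical.
//
//	f, err := Open("/bin/ls")
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer f.Close()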
// SectionByType returns the first section in f with the
// given type, or nil if there is no such section.
func (f *File) SectionByType(typ SectionType) *Section {
for _, s := range f.Sections {
if s.Type == typ {
return s
}
}
return nil
}
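// Illustrative sketch (not part of the package): finding the dynamic symbol
// table section by type, assuming f is a *File:
//
//	if dynsym := f.SectionByType(SHT_DYNSYM); dynsym != nil {
//		fmt.Println(dynsym.Name, dynsym.Size)
//	}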
// NewFile creates a new File for accessing an ELF binary in an underlying reader.
// The ELF binary is expected to start at position 0 in the ReaderAt.
func NewFile(r io.ReaderAt) (*File, error) {
sr := io.NewSectionReader(r, 0, 1<<63-1)
// Read and decode ELF identifier
var ident [16]uint8
if _, err := r.ReadAt(ident[0:], 0); err != nil {
return nil, err
}
if ident[0] != '\x7f' || ident[1] != 'E' || ident[2] != 'L' || ident[3] != 'F' {
return nil, &FormatError{0, "bad magic number", ident[0:4]}
}
f := new(File)
f.Class = Class(ident[EI_CLASS])
switch f.Class {
case ELFCLASS32:
case ELFCLASS64:
// ok
default:
return nil, &FormatError{0, "unknown ELF class", f.Class}
}
f.Data = Data(ident[EI_DATA])
switch f.Data {
case ELFDATA2LSB:
f.ByteOrder = binary.LittleEndian
case ELFDATA2MSB:
f.ByteOrder = binary.BigEndian
default:
return nil, &FormatError{0, "unknown ELF data encoding", f.Data}
}
f.Version = Version(ident[EI_VERSION])
if f.Version != EV_CURRENT {
return nil, &FormatError{0, "unknown ELF version", f.Version}
}
f.OSABI = OSABI(ident[EI_OSABI])
f.ABIVersion = ident[EI_ABIVERSION]
// Read ELF file header
var phoff int64
var phentsize, phnum int
var shoff int64
var shentsize, shnum, shstrndx int
switch f.Class {
case ELFCLASS32:
hdr := new(Header32)
sr.Seek(0, seekStart)
if err := binary.Read(sr, f.ByteOrder, hdr); err != nil {
return nil, err
}
f.Type = Type(hdr.Type)
f.Machine = Machine(hdr.Machine)
f.Entry = uint64(hdr.Entry)
if v := Version(hdr.Version); v != f.Version {
return nil, &FormatError{0, "mismatched ELF version", v}
}
phoff = int64(hdr.Phoff)
phentsize = int(hdr.Phentsize)
phnum = int(hdr.Phnum)
shoff = int64(hdr.Shoff)
shentsize = int(hdr.Shentsize)
shnum = int(hdr.Shnum)
shstrndx = int(hdr.Shstrndx)
case ELFCLASS64:
hdr := new(Header64)
sr.Seek(0, seekStart)
if err := binary.Read(sr, f.ByteOrder, hdr); err != nil {
return nil, err
}
f.Type = Type(hdr.Type)
f.Machine = Machine(hdr.Machine)
f.Entry = hdr.Entry
if v := Version(hdr.Version); v != f.Version {
return nil, &FormatError{0, "mismatched ELF version", v}
}
phoff = int64(hdr.Phoff)
phentsize = int(hdr.Phentsize)
phnum = int(hdr.Phnum)
shoff = int64(hdr.Shoff)
shentsize = int(hdr.Shentsize)
shnum = int(hdr.Shnum)
shstrndx = int(hdr.Shstrndx)
}
if shoff < 0 {
return nil, &FormatError{0, "invalid shoff", shoff}
}
if phoff < 0 {
return nil, &FormatError{0, "invalid phoff", phoff}
}
if shoff == 0 && shnum != 0 {
return nil, &FormatError{0, "invalid ELF shnum for shoff=0", shnum}
}
if shnum > 0 && shstrndx >= shnum {
return nil, &FormatError{0, "invalid ELF shstrndx", shstrndx}
}
var wantPhentsize, wantShentsize int
switch f.Class {
case ELFCLASS32:
wantPhentsize = 8 * 4
wantShentsize = 10 * 4
case ELFCLASS64:
wantPhentsize = 2*4 + 6*8
wantShentsize = 4*4 + 6*8
}
if phnum > 0 && phentsize < wantPhentsize {
return nil, &FormatError{0, "invalid ELF phentsize", phentsize}
}
// Read program headers
f.Progs = make([]*Prog, phnum)
for i := 0; i < phnum; i++ {
off := phoff + int64(i)*int64(phentsize)
sr.Seek(off, seekStart)
p := new(Prog)
switch f.Class {
case ELFCLASS32:
ph := new(Prog32)
if err := binary.Read(sr, f.ByteOrder, ph); err != nil {
return nil, err
}
p.ProgHeader = ProgHeader{
Type: ProgType(ph.Type),
Flags: ProgFlag(ph.Flags),
Off: uint64(ph.Off),
Vaddr: uint64(ph.Vaddr),
Paddr: uint64(ph.Paddr),
Filesz: uint64(ph.Filesz),
Memsz: uint64(ph.Memsz),
Align: uint64(ph.Align),
}
case ELFCLASS64:
ph := new(Prog64)
if err := binary.Read(sr, f.ByteOrder, ph); err != nil {
return nil, err
}
p.ProgHeader = ProgHeader{
Type: ProgType(ph.Type),
Flags: ProgFlag(ph.Flags),
Off: ph.Off,
Vaddr: ph.Vaddr,
Paddr: ph.Paddr,
Filesz: ph.Filesz,
Memsz: ph.Memsz,
Align: ph.Align,
}
}
if int64(p.Off) < 0 {
return nil, &FormatError{off, "invalid program header offset", p.Off}
}
if int64(p.Filesz) < 0 {
return nil, &FormatError{off, "invalid program header file size", p.Filesz}
}
p.sr = io.NewSectionReader(r, int64(p.Off), int64(p.Filesz))
p.ReaderAt = p.sr
f.Progs[i] = p
}
// If the number of sections is greater than or equal to SHN_LORESERVE
// (0xff00), shnum has the value zero and the actual number of section
// header table entries is contained in the sh_size field of the section
// header at index 0.
if shoff > 0 && shnum == 0 {
var typ, link uint32
sr.Seek(shoff, seekStart)
switch f.Class {
case ELFCLASS32:
sh := new(Section32)
if err := binary.Read(sr, f.ByteOrder, sh); err != nil {
return nil, err
}
shnum = int(sh.Size)
typ = sh.Type
link = sh.Link
case ELFCLASS64:
sh := new(Section64)
if err := binary.Read(sr, f.ByteOrder, sh); err != nil {
return nil, err
}
shnum = int(sh.Size)
typ = sh.Type
link = sh.Link
}
if SectionType(typ) != SHT_NULL {
return nil, &FormatError{shoff, "invalid type of the initial section", SectionType(typ)}
}
if shnum < int(SHN_LORESERVE) {
return nil, &FormatError{shoff, "invalid ELF shnum contained in sh_size", shnum}
}
// If the section name string table section index is greater than or
// equal to SHN_LORESERVE (0xff00), this member has the value
// SHN_XINDEX (0xffff) and the actual index of the section name
// string table section is contained in the sh_link field of the
// section header at index 0.
if shstrndx == int(SHN_XINDEX) {
shstrndx = int(link)
if shstrndx < int(SHN_LORESERVE) {
return nil, &FormatError{shoff, "invalid ELF shstrndx contained in sh_link", shstrndx}
}
}
}
if shnum > 0 && shentsize < wantShentsize {
return nil, &FormatError{0, "invalid ELF shentsize", shentsize}
}
// Read section headers
c := saferio.SliceCap((*Section)(nil), uint64(shnum))
if c < 0 {
return nil, &FormatError{0, "too many sections", shnum}
}
f.Sections = make([]*Section, 0, c)
names := make([]uint32, 0, c)
for i := 0; i < shnum; i++ {
off := shoff + int64(i)*int64(shentsize)
sr.Seek(off, seekStart)
s := new(Section)
switch f.Class {
case ELFCLASS32:
sh := new(Section32)
if err := binary.Read(sr, f.ByteOrder, sh); err != nil {
return nil, err
}
names = append(names, sh.Name)
s.SectionHeader = SectionHeader{
Type: SectionType(sh.Type),
Flags: SectionFlag(sh.Flags),
Addr: uint64(sh.Addr),
Offset: uint64(sh.Off),
FileSize: uint64(sh.Size),
Link: sh.Link,
Info: sh.Info,
Addralign: uint64(sh.Addralign),
Entsize: uint64(sh.Entsize),
}
case ELFCLASS64:
sh := new(Section64)
if err := binary.Read(sr, f.ByteOrder, sh); err != nil {
return nil, err
}
names = append(names, sh.Name)
s.SectionHeader = SectionHeader{
Type: SectionType(sh.Type),
Flags: SectionFlag(sh.Flags),
Offset: sh.Off,
FileSize: sh.Size,
Addr: sh.Addr,
Link: sh.Link,
Info: sh.Info,
Addralign: sh.Addralign,
Entsize: sh.Entsize,
}
}
if int64(s.Offset) < 0 {
return nil, &FormatError{off, "invalid section offset", int64(s.Offset)}
}
if int64(s.FileSize) < 0 {
return nil, &FormatError{off, "invalid section size", int64(s.FileSize)}
}
s.sr = io.NewSectionReader(r, int64(s.Offset), int64(s.FileSize))
if s.Flags&SHF_COMPRESSED == 0 {
s.ReaderAt = s.sr
s.Size = s.FileSize
} else {
// Read the compression header.
switch f.Class {
case ELFCLASS32:
ch := new(Chdr32)
if err := binary.Read(s.sr, f.ByteOrder, ch); err != nil {
return nil, err
}
s.compressionType = CompressionType(ch.Type)
s.Size = uint64(ch.Size)
s.Addralign = uint64(ch.Addralign)
s.compressionOffset = int64(binary.Size(ch))
case ELFCLASS64:
ch := new(Chdr64)
if err := binary.Read(s.sr, f.ByteOrder, ch); err != nil {
return nil, err
}
s.compressionType = CompressionType(ch.Type)
s.Size = ch.Size
s.Addralign = ch.Addralign
s.compressionOffset = int64(binary.Size(ch))
}
}
f.Sections = append(f.Sections, s)
}
if len(f.Sections) == 0 {
return f, nil
}
// Load section header string table.
if shstrndx == 0 {
// If the file has no section name string table,
// shstrndx holds the value SHN_UNDEF (0).
return f, nil
}
shstr := f.Sections[shstrndx]
if shstr.Type != SHT_STRTAB {
return nil, &FormatError{shoff + int64(shstrndx*shentsize), "invalid ELF section name string table type", shstr.Type}
}
shstrtab, err := shstr.Data()
if err != nil {
return nil, err
}
for i, s := range f.Sections {
var ok bool
s.Name, ok = getString(shstrtab, int(names[i]))
if !ok {
return nil, &FormatError{shoff + int64(i*shentsize), "bad section name index", names[i]}
}
}
return f, nil
}
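// Usage sketch for NewFile (illustrative; "a.out" is a hypothetical path, and
// the example assumes the os, fmt, and log imports). Note that NewFile does
// not take ownership of r, so the caller closes it:
//
//	r, err := os.Open("a.out")
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer r.Close()
//	f, err := NewFile(r)
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println(f.Class, f.Data, f.Machine, len(f.Sections))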
// getSymbols returns a slice of Symbols from parsing the symbol table
// with the given type, along with the associated string table.
func (f *File) getSymbols(typ SectionType) ([]Symbol, []byte, error) {
switch f.Class {
case ELFCLASS64:
return f.getSymbols64(typ)
case ELFCLASS32:
return f.getSymbols32(typ)
}
return nil, nil, errors.New("not implemented")
}
// ErrNoSymbols is returned by File.Symbols and File.DynamicSymbols
// if there is no such section in the File.
var ErrNoSymbols = errors.New("no symbol section")
func (f *File) getSymbols32(typ SectionType) ([]Symbol, []byte, error) {
symtabSection := f.SectionByType(typ)
if symtabSection == nil {
return nil, nil, ErrNoSymbols
}
data, err := symtabSection.Data()
if err != nil {
return nil, nil, fmt.Errorf("cannot load symbol section: %w", err)
}
symtab := bytes.NewReader(data)
if symtab.Len()%Sym32Size != 0 {
return nil, nil, errors.New("length of symbol section is not a multiple of SymSize")
}
strdata, err := f.stringTable(symtabSection.Link)
if err != nil {
return nil, nil, fmt.Errorf("cannot load string table section: %w", err)
}
// The first entry is all zeros.
var skip [Sym32Size]byte
symtab.Read(skip[:])
symbols := make([]Symbol, symtab.Len()/Sym32Size)
i := 0
var sym Sym32
for symtab.Len() > 0 {
binary.Read(symtab, f.ByteOrder, &sym)
str, _ := getString(strdata, int(sym.Name))
symbols[i].Name = str
symbols[i].Info = sym.Info
symbols[i].Other = sym.Other
symbols[i].Section = SectionIndex(sym.Shndx)
symbols[i].Value = uint64(sym.Value)
symbols[i].Size = uint64(sym.Size)
i++
}
return symbols, strdata, nil
}
func (f *File) getSymbols64(typ SectionType) ([]Symbol, []byte, error) {
symtabSection := f.SectionByType(typ)
if symtabSection == nil {
return nil, nil, ErrNoSymbols
}
data, err := symtabSection.Data()
if err != nil {
return nil, nil, fmt.Errorf("cannot load symbol section: %w", err)
}
symtab := bytes.NewReader(data)
if symtab.Len()%Sym64Size != 0 {
return nil, nil, errors.New("length of symbol section is not a multiple of Sym64Size")
}
strdata, err := f.stringTable(symtabSection.Link)
if err != nil {
return nil, nil, fmt.Errorf("cannot load string table section: %w", err)
}
// The first entry is all zeros.
var skip [Sym64Size]byte
symtab.Read(skip[:])
symbols := make([]Symbol, symtab.Len()/Sym64Size)
i := 0
var sym Sym64
for symtab.Len() > 0 {
binary.Read(symtab, f.ByteOrder, &sym)
str, _ := getString(strdata, int(sym.Name))
symbols[i].Name = str
symbols[i].Info = sym.Info
symbols[i].Other = sym.Other
symbols[i].Section = SectionIndex(sym.Shndx)
symbols[i].Value = sym.Value
symbols[i].Size = sym.Size
i++
}
return symbols, strdata, nil
}
// getString extracts a string from an ELF string table.
func getString(section []byte, start int) (string, bool) {
if start < 0 || start >= len(section) {
return "", false
}
for end := start; end < len(section); end++ {
if section[end] == 0 {
return string(section[start:end]), true
}
}
return "", false
}
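// Worked example (illustrative): for the string table
// tab := []byte("\x00.text\x00.data\x00"), getString(tab, 1) returns
// (".text", true), getString(tab, 7) returns (".data", true), and any start
// index outside the table returns ("", false).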
// Section returns a section with the given name, or nil if no such
// section exists.
func (f *File) Section(name string) *Section {
for _, s := range f.Sections {
if s.Name == name {
return s
}
}
return nil
}
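// Usage sketch (illustrative; assumes f is a *File, the binary actually
// contains a .text section, and fmt is imported):
//
//	if text := f.Section(".text"); text != nil {
//		code, err := text.Data()
//		if err == nil {
//			fmt.Printf("%d bytes of machine code\n", len(code))
//		}
//	}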
// applyRelocations applies relocations to dst. rels is a relocations section
// in REL or RELA format.
func (f *File) applyRelocations(dst []byte, rels []byte) error {
switch {
case f.Class == ELFCLASS64 && f.Machine == EM_X86_64:
return f.applyRelocationsAMD64(dst, rels)
case f.Class == ELFCLASS32 && f.Machine == EM_386:
return f.applyRelocations386(dst, rels)
case f.Class == ELFCLASS32 && f.Machine == EM_ARM:
return f.applyRelocationsARM(dst, rels)
case f.Class == ELFCLASS64 && f.Machine == EM_AARCH64:
return f.applyRelocationsARM64(dst, rels)
case f.Class == ELFCLASS32 && f.Machine == EM_PPC:
return f.applyRelocationsPPC(dst, rels)
case f.Class == ELFCLASS64 && f.Machine == EM_PPC64:
return f.applyRelocationsPPC64(dst, rels)
case f.Class == ELFCLASS32 && f.Machine == EM_MIPS:
return f.applyRelocationsMIPS(dst, rels)
case f.Class == ELFCLASS64 && f.Machine == EM_MIPS:
return f.applyRelocationsMIPS64(dst, rels)
case f.Class == ELFCLASS64 && f.Machine == EM_LOONGARCH:
return f.applyRelocationsLOONG64(dst, rels)
case f.Class == ELFCLASS64 && f.Machine == EM_RISCV:
return f.applyRelocationsRISCV64(dst, rels)
case f.Class == ELFCLASS64 && f.Machine == EM_S390:
return f.applyRelocationss390x(dst, rels)
case f.Class == ELFCLASS64 && f.Machine == EM_SPARCV9:
return f.applyRelocationsSPARC64(dst, rels)
default:
return errors.New("applyRelocations: not implemented")
}
}
// canApplyRelocation reports whether we should try to apply a
// relocation to a DWARF data section, given a pointer to the symbol
// targeted by the relocation.
// Most relocations in DWARF data tend to be section-relative, but
// some target non-section symbols (for example, low_PC attrs on
// subprogram or compilation unit DIEs that target function symbols).
func canApplyRelocation(sym *Symbol) bool {
return sym.Section != SHN_UNDEF && sym.Section < SHN_LORESERVE
}
func (f *File) applyRelocationsAMD64(dst []byte, rels []byte) error {
// 24 is the size of Rela64.
if len(rels)%24 != 0 {
return errors.New("length of relocation section is not a multiple of 24")
}
symbols, _, err := f.getSymbols(SHT_SYMTAB)
if err != nil {
return err
}
b := bytes.NewReader(rels)
var rela Rela64
for b.Len() > 0 {
binary.Read(b, f.ByteOrder, &rela)
symNo := rela.Info >> 32
t := R_X86_64(rela.Info & 0xffff)
if symNo == 0 || symNo > uint64(len(symbols)) {
continue
}
sym := &symbols[symNo-1]
if !canApplyRelocation(sym) {
continue
}
// There are relocations, so this must be a normal
// object file. The code below handles only basic relocations
// of the form S + A (symbol plus addend).
switch t {
case R_X86_64_64:
if rela.Off+8 >= uint64(len(dst)) || rela.Addend < 0 {
continue
}
val64 := sym.Value + uint64(rela.Addend)
f.ByteOrder.PutUint64(dst[rela.Off:rela.Off+8], val64)
case R_X86_64_32:
if rela.Off+4 >= uint64(len(dst)) || rela.Addend < 0 {
continue
}
val32 := uint32(sym.Value) + uint32(rela.Addend)
f.ByteOrder.PutUint32(dst[rela.Off:rela.Off+4], val32)
}
}
return nil
}
func (f *File) applyRelocations386(dst []byte, rels []byte) error {
// 8 is the size of Rel32.
if len(rels)%8 != 0 {
return errors.New("length of relocation section is not a multiple of 8")
}
symbols, _, err := f.getSymbols(SHT_SYMTAB)
if err != nil {
return err
}
b := bytes.NewReader(rels)
var rel Rel32
for b.Len() > 0 {
binary.Read(b, f.ByteOrder, &rel)
symNo := rel.Info >> 8
t := R_386(rel.Info & 0xff)
if symNo == 0 || symNo > uint32(len(symbols)) {
continue
}
sym := &symbols[symNo-1]
if t == R_386_32 {
if rel.Off+4 >= uint32(len(dst)) {
continue
}
val := f.ByteOrder.Uint32(dst[rel.Off : rel.Off+4])
val += uint32(sym.Value)
f.ByteOrder.PutUint32(dst[rel.Off:rel.Off+4], val)
}
}
return nil
}
func (f *File) applyRelocationsARM(dst []byte, rels []byte) error {
// 8 is the size of Rel32.
if len(rels)%8 != 0 {
return errors.New("length of relocation section is not a multiple of 8")
}
symbols, _, err := f.getSymbols(SHT_SYMTAB)
if err != nil {
return err
}
b := bytes.NewReader(rels)
var rel Rel32
for b.Len() > 0 {
binary.Read(b, f.ByteOrder, &rel)
symNo := rel.Info >> 8
t := R_ARM(rel.Info & 0xff)
if symNo == 0 || symNo > uint32(len(symbols)) {
continue
}
sym := &symbols[symNo-1]
switch t {
case R_ARM_ABS32:
if rel.Off+4 >= uint32(len(dst)) {
continue
}
val := f.ByteOrder.Uint32(dst[rel.Off : rel.Off+4])
val += uint32(sym.Value)
f.ByteOrder.PutUint32(dst[rel.Off:rel.Off+4], val)
}
}
return nil
}
func (f *File) applyRelocationsARM64(dst []byte, rels []byte) error {
// 24 is the size of Rela64.
if len(rels)%24 != 0 {
return errors.New("length of relocation section is not a multiple of 24")
}
symbols, _, err := f.getSymbols(SHT_SYMTAB)
if err != nil {
return err
}
b := bytes.NewReader(rels)
var rela Rela64
for b.Len() > 0 {
binary.Read(b, f.ByteOrder, &rela)
symNo := rela.Info >> 32
t := R_AARCH64(rela.Info & 0xffff)
if symNo == 0 || symNo > uint64(len(symbols)) {
continue
}
sym := &symbols[symNo-1]
if !canApplyRelocation(sym) {
continue
}
// There are relocations, so this must be a normal
// object file. The code below handles only basic relocations
// of the form S + A (symbol plus addend).
switch t {
case R_AARCH64_ABS64:
if rela.Off+8 >= uint64(len(dst)) || rela.Addend < 0 {
continue
}
val64 := sym.Value + uint64(rela.Addend)
f.ByteOrder.PutUint64(dst[rela.Off:rela.Off+8], val64)
case R_AARCH64_ABS32:
if rela.Off+4 >= uint64(len(dst)) || rela.Addend < 0 {
continue
}
val32 := uint32(sym.Value) + uint32(rela.Addend)
f.ByteOrder.PutUint32(dst[rela.Off:rela.Off+4], val32)
}
}
return nil
}
func (f *File) applyRelocationsPPC(dst []byte, rels []byte) error {
// 12 is the size of Rela32.
if len(rels)%12 != 0 {
return errors.New("length of relocation section is not a multiple of 12")
}
symbols, _, err := f.getSymbols(SHT_SYMTAB)
if err != nil {
return err
}
b := bytes.NewReader(rels)
var rela Rela32
for b.Len() > 0 {
binary.Read(b, f.ByteOrder, &rela)
symNo := rela.Info >> 8
t := R_PPC(rela.Info & 0xff)
if symNo == 0 || symNo > uint32(len(symbols)) {
continue
}
sym := &symbols[symNo-1]
if !canApplyRelocation(sym) {
continue
}
switch t {
case R_PPC_ADDR32:
if rela.Off+4 >= uint32(len(dst)) || rela.Addend < 0 {
continue
}
val32 := uint32(sym.Value) + uint32(rela.Addend)
f.ByteOrder.PutUint32(dst[rela.Off:rela.Off+4], val32)
}
}
return nil
}
func (f *File) applyRelocationsPPC64(dst []byte, rels []byte) error {
// 24 is the size of Rela64.
if len(rels)%24 != 0 {
return errors.New("length of relocation section is not a multiple of 24")
}
symbols, _, err := f.getSymbols(SHT_SYMTAB)
if err != nil {
return err
}
b := bytes.NewReader(rels)
var rela Rela64
for b.Len() > 0 {
binary.Read(b, f.ByteOrder, &rela)
symNo := rela.Info >> 32
t := R_PPC64(rela.Info & 0xffff)
if symNo == 0 || symNo > uint64(len(symbols)) {
continue
}
sym := &symbols[symNo-1]
if !canApplyRelocation(sym) {
continue
}
switch t {
case R_PPC64_ADDR64:
if rela.Off+8 >= uint64(len(dst)) || rela.Addend < 0 {
continue
}
val64 := sym.Value + uint64(rela.Addend)
f.ByteOrder.PutUint64(dst[rela.Off:rela.Off+8], val64)
case R_PPC64_ADDR32:
if rela.Off+4 >= uint64(len(dst)) || rela.Addend < 0 {
continue
}
val32 := uint32(sym.Value) + uint32(rela.Addend)
f.ByteOrder.PutUint32(dst[rela.Off:rela.Off+4], val32)
}
}
return nil
}
func (f *File) applyRelocationsMIPS(dst []byte, rels []byte) error {
// 8 is the size of Rel32.
if len(rels)%8 != 0 {
return errors.New("length of relocation section is not a multiple of 8")
}
symbols, _, err := f.getSymbols(SHT_SYMTAB)
if err != nil {
return err
}
b := bytes.NewReader(rels)
var rel Rel32
for b.Len() > 0 {
binary.Read(b, f.ByteOrder, &rel)
symNo := rel.Info >> 8
t := R_MIPS(rel.Info & 0xff)
if symNo == 0 || symNo > uint32(len(symbols)) {
continue
}
sym := &symbols[symNo-1]
switch t {
case R_MIPS_32:
if rel.Off+4 >= uint32(len(dst)) {
continue
}
val := f.ByteOrder.Uint32(dst[rel.Off : rel.Off+4])
val += uint32(sym.Value)
f.ByteOrder.PutUint32(dst[rel.Off:rel.Off+4], val)
}
}
return nil
}
func (f *File) applyRelocationsMIPS64(dst []byte, rels []byte) error {
// 24 is the size of Rela64.
if len(rels)%24 != 0 {
return errors.New("length of relocation section is not a multiple of 24")
}
symbols, _, err := f.getSymbols(SHT_SYMTAB)
if err != nil {
return err
}
b := bytes.NewReader(rels)
var rela Rela64
for b.Len() > 0 {
binary.Read(b, f.ByteOrder, &rela)
var symNo uint64
var t R_MIPS
if f.ByteOrder == binary.BigEndian {
symNo = rela.Info >> 32
t = R_MIPS(rela.Info & 0xff)
} else {
symNo = rela.Info & 0xffffffff
t = R_MIPS(rela.Info >> 56)
}
if symNo == 0 || symNo > uint64(len(symbols)) {
continue
}
sym := &symbols[symNo-1]
if !canApplyRelocation(sym) {
continue
}
switch t {
case R_MIPS_64:
if rela.Off+8 >= uint64(len(dst)) || rela.Addend < 0 {
continue
}
val64 := sym.Value + uint64(rela.Addend)
f.ByteOrder.PutUint64(dst[rela.Off:rela.Off+8], val64)
case R_MIPS_32:
if rela.Off+4 >= uint64(len(dst)) || rela.Addend < 0 {
continue
}
val32 := uint32(sym.Value) + uint32(rela.Addend)
f.ByteOrder.PutUint32(dst[rela.Off:rela.Off+4], val32)
}
}
return nil
}
func (f *File) applyRelocationsLOONG64(dst []byte, rels []byte) error {
// 24 is the size of Rela64.
if len(rels)%24 != 0 {
return errors.New("length of relocation section is not a multiple of 24")
}
symbols, _, err := f.getSymbols(SHT_SYMTAB)
if err != nil {
return err
}
b := bytes.NewReader(rels)
var rela Rela64
for b.Len() > 0 {
binary.Read(b, f.ByteOrder, &rela)
symNo := rela.Info >> 32
t := R_LARCH(rela.Info & 0xffff)
if symNo == 0 || symNo > uint64(len(symbols)) {
continue
}
sym := &symbols[symNo-1]
if !canApplyRelocation(sym) {
continue
}
switch t {
case R_LARCH_64:
if rela.Off+8 >= uint64(len(dst)) || rela.Addend < 0 {
continue
}
val64 := sym.Value + uint64(rela.Addend)
f.ByteOrder.PutUint64(dst[rela.Off:rela.Off+8], val64)
case R_LARCH_32:
if rela.Off+4 >= uint64(len(dst)) || rela.Addend < 0 {
continue
}
val32 := uint32(sym.Value) + uint32(rela.Addend)
f.ByteOrder.PutUint32(dst[rela.Off:rela.Off+4], val32)
}
}
return nil
}
func (f *File) applyRelocationsRISCV64(dst []byte, rels []byte) error {
// 24 is the size of Rela64.
if len(rels)%24 != 0 {
return errors.New("length of relocation section is not a multiple of 24")
}
symbols, _, err := f.getSymbols(SHT_SYMTAB)
if err != nil {
return err
}
b := bytes.NewReader(rels)
var rela Rela64
for b.Len() > 0 {
binary.Read(b, f.ByteOrder, &rela)
symNo := rela.Info >> 32
t := R_RISCV(rela.Info & 0xffff)
if symNo == 0 || symNo > uint64(len(symbols)) {
continue
}
sym := &symbols[symNo-1]
if !canApplyRelocation(sym) {
continue
}
switch t {
case R_RISCV_64:
if rela.Off+8 >= uint64(len(dst)) || rela.Addend < 0 {
continue
}
val64 := sym.Value + uint64(rela.Addend)
f.ByteOrder.PutUint64(dst[rela.Off:rela.Off+8], val64)
case R_RISCV_32:
if rela.Off+4 >= uint64(len(dst)) || rela.Addend < 0 {
continue
}
val32 := uint32(sym.Value) + uint32(rela.Addend)
f.ByteOrder.PutUint32(dst[rela.Off:rela.Off+4], val32)
}
}
return nil
}
func (f *File) applyRelocationss390x(dst []byte, rels []byte) error {
// 24 is the size of Rela64.
if len(rels)%24 != 0 {
return errors.New("length of relocation section is not a multiple of 24")
}
symbols, _, err := f.getSymbols(SHT_SYMTAB)
if err != nil {
return err
}
b := bytes.NewReader(rels)
var rela Rela64
for b.Len() > 0 {
binary.Read(b, f.ByteOrder, &rela)
symNo := rela.Info >> 32
t := R_390(rela.Info & 0xffff)
if symNo == 0 || symNo > uint64(len(symbols)) {
continue
}
sym := &symbols[symNo-1]
if !canApplyRelocation(sym) {
continue
}
switch t {
case R_390_64:
if rela.Off+8 >= uint64(len(dst)) || rela.Addend < 0 {
continue
}
val64 := sym.Value + uint64(rela.Addend)
f.ByteOrder.PutUint64(dst[rela.Off:rela.Off+8], val64)
case R_390_32:
if rela.Off+4 >= uint64(len(dst)) || rela.Addend < 0 {
continue
}
val32 := uint32(sym.Value) + uint32(rela.Addend)
f.ByteOrder.PutUint32(dst[rela.Off:rela.Off+4], val32)
}
}
return nil
}
func (f *File) applyRelocationsSPARC64(dst []byte, rels []byte) error {
// 24 is the size of Rela64.
if len(rels)%24 != 0 {
return errors.New("length of relocation section is not a multiple of 24")
}
symbols, _, err := f.getSymbols(SHT_SYMTAB)
if err != nil {
return err
}
b := bytes.NewReader(rels)
var rela Rela64
for b.Len() > 0 {
binary.Read(b, f.ByteOrder, &rela)
symNo := rela.Info >> 32
t := R_SPARC(rela.Info & 0xff)
if symNo == 0 || symNo > uint64(len(symbols)) {
continue
}
sym := &symbols[symNo-1]
if !canApplyRelocation(sym) {
continue
}
switch t {
case R_SPARC_64, R_SPARC_UA64:
if rela.Off+8 >= uint64(len(dst)) || rela.Addend < 0 {
continue
}
val64 := sym.Value + uint64(rela.Addend)
f.ByteOrder.PutUint64(dst[rela.Off:rela.Off+8], val64)
case R_SPARC_32, R_SPARC_UA32:
if rela.Off+4 >= uint64(len(dst)) || rela.Addend < 0 {
continue
}
val32 := uint32(sym.Value) + uint32(rela.Addend)
f.ByteOrder.PutUint32(dst[rela.Off:rela.Off+4], val32)
}
}
return nil
}
func (f *File) DWARF() (*dwarf.Data, error) {
dwarfSuffix := func(s *Section) string {
switch {
case strings.HasPrefix(s.Name, ".debug_"):
return s.Name[7:]
case strings.HasPrefix(s.Name, ".zdebug_"):
return s.Name[8:]
default:
return ""
}
}
// sectionData gets the data for s, checks its size, and
// applies any applicable relocations.
sectionData := func(i int, s *Section) ([]byte, error) {
b, err := s.Data()
if err != nil && uint64(len(b)) < s.Size {
return nil, err
}
var dlen uint64
if len(b) >= 12 && string(b[:4]) == "ZLIB" {
dlen = binary.BigEndian.Uint64(b[4:12])
s.compressionOffset = 12
}
if dlen == 0 && len(b) >= 12 && s.Flags&SHF_COMPRESSED != 0 &&
s.Flags&SHF_ALLOC == 0 &&
f.FileHeader.ByteOrder.Uint32(b[:]) == uint32(COMPRESS_ZLIB) {
s.compressionType = COMPRESS_ZLIB
switch f.FileHeader.Class {
case ELFCLASS32:
// Chdr32.Size offset
dlen = uint64(f.FileHeader.ByteOrder.Uint32(b[4:]))
s.compressionOffset = 12
case ELFCLASS64:
if len(b) < 24 {
return nil, errors.New("invalid compress header 64")
}
// Chdr64.Size offset
dlen = f.FileHeader.ByteOrder.Uint64(b[8:])
s.compressionOffset = 24
default:
return nil, fmt.Errorf("unsupported compress header:%s", f.FileHeader.Class)
}
}
if dlen > 0 {
r, err := zlib.NewReader(bytes.NewBuffer(b[s.compressionOffset:]))
if err != nil {
return nil, err
}
b, err = saferio.ReadData(r, dlen)
if err != nil {
return nil, err
}
if err := r.Close(); err != nil {
return nil, err
}
}
if f.Type == ET_EXEC {
// Do not apply relocations to DWARF sections for ET_EXEC binaries.
// Relocations should already be applied, and .rela sections may
// contain incorrect data.
return b, nil
}
for _, r := range f.Sections {
if r.Type != SHT_RELA && r.Type != SHT_REL {
continue
}
if int(r.Info) != i {
continue
}
rd, err := r.Data()
if err != nil {
return nil, err
}
err = f.applyRelocations(b, rd)
if err != nil {
return nil, err
}
}
return b, nil
}
// There are many DWARF sections, but these are the ones
// the debug/dwarf package started with.
var dat = map[string][]byte{"abbrev": nil, "info": nil, "str": nil, "line": nil, "ranges": nil}
for i, s := range f.Sections {
suffix := dwarfSuffix(s)
if suffix == "" {
continue
}
if _, ok := dat[suffix]; !ok {
continue
}
b, err := sectionData(i, s)
if err != nil {
return nil, err
}
dat[suffix] = b
}
d, err := dwarf.New(dat["abbrev"], nil, nil, dat["info"], dat["line"], nil, dat["ranges"], dat["str"])
if err != nil {
return nil, err
}
// Look for DWARF4 .debug_types sections and DWARF5 sections.
for i, s := range f.Sections {
suffix := dwarfSuffix(s)
if suffix == "" {
continue
}
if _, ok := dat[suffix]; ok {
// Already handled.
continue
}
b, err := sectionData(i, s)
if err != nil {
return nil, err
}
if suffix == "types" {
if err := d.AddTypes(fmt.Sprintf("types-%d", i), b); err != nil {
return nil, err
}
} else {
if err := d.AddSection(".debug_"+suffix, b); err != nil {
return nil, err
}
}
}
return d, nil
}
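// Usage sketch (illustrative; assumes the binary was built with DWARF debug
// info, and that debug/dwarf is imported along with fmt and log). The reader
// walks the debug entries in order:
//
//	d, err := f.DWARF()
//	if err != nil {
//		log.Fatal(err)
//	}
//	r := d.Reader()
//	for {
//		e, err := r.Next()
//		if err != nil || e == nil {
//			break
//		}
//		fmt.Println(e.Tag)
//	}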
// Symbols returns the symbol table for f. The symbols will be listed in the order
// they appear in f.
//
// For compatibility with Go 1.0, Symbols omits the null symbol at index 0.
// After retrieving the symbols as symtab, an externally supplied index x
// corresponds to symtab[x-1], not symtab[x].
func (f *File) Symbols() ([]Symbol, error) {
sym, _, err := f.getSymbols(SHT_SYMTAB)
return sym, err
}
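// Usage sketch (illustrative; assumes fmt and log are imported). Remember the
// off-by-one documented above: an ELF symbol index x maps to syms[x-1]:
//
//	syms, err := f.Symbols()
//	if err != nil {
//		log.Fatal(err)
//	}
//	for _, s := range syms {
//		fmt.Printf("%#x %s\n", s.Value, s.Name)
//	}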
// DynamicSymbols returns the dynamic symbol table for f. The symbols
// will be listed in the order they appear in f.
//
// If f has a symbol version table, the returned Symbols will have
// initialized Version and Library fields.
//
// For compatibility with Symbols, DynamicSymbols omits the null symbol at index 0.
// After retrieving the symbols as symtab, an externally supplied index x
// corresponds to symtab[x-1], not symtab[x].
func (f *File) DynamicSymbols() ([]Symbol, error) {
sym, str, err := f.getSymbols(SHT_DYNSYM)
if err != nil {
return nil, err
}
if f.gnuVersionInit(str) {
for i := range sym {
sym[i].Library, sym[i].Version = f.gnuVersion(i)
}
}
return sym, nil
}
type ImportedSymbol struct {
Name string
Version string
Library string
}
// ImportedSymbols returns the names of all symbols
// referred to by the binary f that are expected to be
// satisfied by other libraries at dynamic load time.
// It does not return weak symbols.
func (f *File) ImportedSymbols() ([]ImportedSymbol, error) {
sym, str, err := f.getSymbols(SHT_DYNSYM)
if err != nil {
return nil, err
}
f.gnuVersionInit(str)
var all []ImportedSymbol
for i, s := range sym {
if ST_BIND(s.Info) == STB_GLOBAL && s.Section == SHN_UNDEF {
all = append(all, ImportedSymbol{Name: s.Name})
sym := &all[len(all)-1]
sym.Library, sym.Version = f.gnuVersion(i)
}
}
return all, nil
}
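// Usage sketch (illustrative; assumes fmt and log are imported). This prints
// the dynamic imports, similar in spirit to the undefined symbols reported by
// nm -D:
//
//	imps, err := f.ImportedSymbols()
//	if err != nil {
//		log.Fatal(err)
//	}
//	for _, s := range imps {
//		fmt.Printf("%s (%s %s)\n", s.Name, s.Library, s.Version)
//	}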
type verneed struct {
File string
Name string
}
// gnuVersionInit parses the GNU version tables
// for use by calls to gnuVersion.
func (f *File) gnuVersionInit(str []byte) bool {
if f.gnuNeed != nil {
// Already initialized
return true
}
// Accumulate verneed information.
vn := f.SectionByType(SHT_GNU_VERNEED)
if vn == nil {
return false
}
d, _ := vn.Data()
var need []verneed
i := 0
for {
if i+16 > len(d) {
break
}
vers := f.ByteOrder.Uint16(d[i : i+2])
if vers != 1 {
break
}
cnt := f.ByteOrder.Uint16(d[i+2 : i+4])
fileoff := f.ByteOrder.Uint32(d[i+4 : i+8])
aux := f.ByteOrder.Uint32(d[i+8 : i+12])
next := f.ByteOrder.Uint32(d[i+12 : i+16])
file, _ := getString(str, int(fileoff))
var name string
j := i + int(aux)
for c := 0; c < int(cnt); c++ {
if j+16 > len(d) {
break
}
// hash := f.ByteOrder.Uint32(d[j:j+4])
// flags := f.ByteOrder.Uint16(d[j+4:j+6])
other := f.ByteOrder.Uint16(d[j+6 : j+8])
nameoff := f.ByteOrder.Uint32(d[j+8 : j+12])
next := f.ByteOrder.Uint32(d[j+12 : j+16])
name, _ = getString(str, int(nameoff))
ndx := int(other)
if ndx >= len(need) {
a := make([]verneed, 2*(ndx+1))
copy(a, need)
need = a
}
need[ndx] = verneed{file, name}
if next == 0 {
break
}
j += int(next)
}
if next == 0 {
break
}
i += int(next)
}
// Versym parallels symbol table, indexing into verneed.
vs := f.SectionByType(SHT_GNU_VERSYM)
if vs == nil {
return false
}
d, _ = vs.Data()
f.gnuNeed = need
f.gnuVersym = d
return true
}
// gnuVersion returns the library and version information for the
// symbol at index i of the dynamic symbol table.
func (f *File) gnuVersion(i int) (library string, version string) {
// Each entry is two bytes; skip undef entry at beginning.
i = (i + 1) * 2
if i >= len(f.gnuVersym) {
return
}
s := f.gnuVersym[i:]
if len(s) < 2 {
return
}
j := int(f.ByteOrder.Uint16(s))
if j < 2 || j >= len(f.gnuNeed) {
return
}
n := &f.gnuNeed[j]
return n.File, n.Name
}
// ImportedLibraries returns the names of all libraries
// referred to by the binary f that are expected to be
// linked with the binary at dynamic link time.
func (f *File) ImportedLibraries() ([]string, error) {
return f.DynString(DT_NEEDED)
}
// DynString returns the strings listed for the given tag in the file's dynamic
// section.
//
// The tag must be one that takes string values: DT_NEEDED, DT_SONAME, DT_RPATH, or
// DT_RUNPATH.
func (f *File) DynString(tag DynTag) ([]string, error) {
switch tag {
case DT_NEEDED, DT_SONAME, DT_RPATH, DT_RUNPATH:
default:
return nil, fmt.Errorf("non-string-valued tag %v", tag)
}
ds := f.SectionByType(SHT_DYNAMIC)
if ds == nil {
// not dynamic, so no libraries
return nil, nil
}
d, err := ds.Data()
if err != nil {
return nil, err
}
str, err := f.stringTable(ds.Link)
if err != nil {
return nil, err
}
var all []string
for len(d) > 0 {
var t DynTag
var v uint64
switch f.Class {
case ELFCLASS32:
t = DynTag(f.ByteOrder.Uint32(d[0:4]))
v = uint64(f.ByteOrder.Uint32(d[4:8]))
d = d[8:]
case ELFCLASS64:
t = DynTag(f.ByteOrder.Uint64(d[0:8]))
v = f.ByteOrder.Uint64(d[8:16])
d = d[16:]
}
if t == tag {
s, ok := getString(str, int(v))
if ok {
all = append(all, s)
}
}
}
return all, nil
}
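// Usage sketch (illustrative; assumes fmt is imported and errors are handled
// by the caller). A shared library typically carries DT_SONAME, while
// DT_NEEDED lists its dependencies, as ImportedLibraries does:
//
//	soname, _ := f.DynString(DT_SONAME)
//	needed, _ := f.DynString(DT_NEEDED)
//	fmt.Println(soname, needed)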
// nobitsSectionReader backs sections of type SHT_NOBITS, which occupy no
// space in the file; any attempt to read such a section is an error.
type nobitsSectionReader struct{}
func (*nobitsSectionReader) ReadAt(p []byte, off int64) (n int, err error) {
return 0, errors.New("unexpected read from SHT_NOBITS section")
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package elf
import (
"io"
"os"
)
// errorReader returns its wrapped error from all operations.
type errorReader struct {
error
}
func (r errorReader) Read(p []byte) (n int, err error) {
return 0, r.error
}
func (r errorReader) ReadAt(p []byte, off int64) (n int, err error) {
return 0, r.error
}
func (r errorReader) Seek(offset int64, whence int) (int64, error) {
return 0, r.error
}
func (r errorReader) Close() error {
return r.error
}
// readSeekerFromReader converts an io.Reader into an io.ReadSeeker.
// In general Seek may not be efficient, but it is optimized for
// common cases such as seeking to the end to find the length of the
// data.
type readSeekerFromReader struct {
reset func() (io.Reader, error)
r io.Reader
size int64
offset int64
}
func (r *readSeekerFromReader) start() {
x, err := r.reset()
if err != nil {
r.r = errorReader{err}
} else {
r.r = x
}
r.offset = 0
}
func (r *readSeekerFromReader) Read(p []byte) (n int, err error) {
if r.r == nil {
r.start()
}
n, err = r.r.Read(p)
r.offset += int64(n)
return n, err
}
func (r *readSeekerFromReader) Seek(offset int64, whence int) (int64, error) {
var newOffset int64
switch whence {
case seekStart:
newOffset = offset
case seekCurrent:
newOffset = r.offset + offset
case seekEnd:
newOffset = r.size + offset
default:
return 0, os.ErrInvalid
}
switch {
case newOffset == r.offset:
return newOffset, nil
case newOffset < 0, newOffset > r.size:
return 0, os.ErrInvalid
case newOffset == 0:
r.r = nil
case newOffset == r.size:
r.r = errorReader{io.EOF}
default:
if newOffset < r.offset {
// Restart at the beginning.
r.start()
}
// Read until we reach offset.
var buf [512]byte
for r.offset < newOffset {
b := buf[:]
if newOffset-r.offset < int64(len(buf)) {
b = buf[:newOffset-r.offset]
}
if _, err := r.Read(b); err != nil {
return 0, err
}
}
}
r.offset = newOffset
return r.offset, nil
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
* Line tables
*/
package gosym
import (
"bytes"
"encoding/binary"
"sort"
"sync"
)
// version of the pclntab
type version int
const (
verUnknown version = iota
ver11
ver12
ver116
ver118
ver120
)
// A LineTable is a data structure mapping program counters to line numbers.
//
// In Go 1.1 and earlier, each function (represented by a Func) had its own LineTable,
// and the line number corresponded to a numbering of all source lines in the
// program, across all files. That absolute line number would then have to be
// converted separately to a file name and line number within the file.
//
// In Go 1.2, the format of the data changed so that there is a single LineTable
// for the entire program, shared by all Funcs, and there are no absolute line
// numbers, just line numbers within specific files.
//
// For the most part, LineTable's methods should be treated as an internal
// detail of the package; callers should use the methods on Table instead.
type LineTable struct {
Data []byte
PC uint64
Line int
// This mutex is used to keep parsing of pclntab synchronous.
mu sync.Mutex
// Contains the version of the pclntab section.
version version
// Go 1.2/1.16/1.18 state
binary binary.ByteOrder
quantum uint32
ptrsize uint32
textStart uint64 // address of runtime.text symbol (1.18+)
funcnametab []byte
cutab []byte
funcdata []byte
functab []byte
nfunctab uint32
filetab []byte
pctab []byte // points to the pctables.
nfiletab uint32
funcNames map[uint32]string // cache the function names
strings map[uint32]string // interned substrings of Data, keyed by offset
// fileMap varies depending on the version of the object file.
// For ver12, it maps the name to the index in the file table.
// For ver116, it maps the name to the offset in filetab.
fileMap map[string]uint32
}
// NOTE(rsc): This is wrong for GOARCH=arm, which uses a quantum of 4,
// but we have no idea whether we're using arm or not. This only
// matters in the old (pre-Go 1.2) symbol table format, so it's not worth
// fixing.
const oldQuantum = 1
func (t *LineTable) parse(targetPC uint64, targetLine int) (b []byte, pc uint64, line int) {
// The PC/line table can be thought of as a sequence of
// <pc update>* <line update>
// batches. Each update batch results in a (pc, line) pair,
// where line applies to every PC from pc up to but not
// including the pc of the next pair.
//
// Here we process each update individually, which simplifies
// the code, but makes the corner cases more confusing.
b, pc, line = t.Data, t.PC, t.Line
for pc <= targetPC && line != targetLine && len(b) > 0 {
code := b[0]
b = b[1:]
switch {
case code == 0:
if len(b) < 4 {
b = b[0:0]
break
}
val := binary.BigEndian.Uint32(b)
b = b[4:]
line += int(val)
case code <= 64:
line += int(code)
case code <= 128:
line -= int(code - 64)
default:
pc += oldQuantum * uint64(code-128)
continue
}
pc += oldQuantum
}
return b, pc, line
}
func (t *LineTable) slice(pc uint64) *LineTable {
data, pc, line := t.parse(pc, -1)
return &LineTable{Data: data, PC: pc, Line: line}
}
// PCToLine returns the line number for the given program counter.
//
// Deprecated: Use Table's PCToLine method instead.
func (t *LineTable) PCToLine(pc uint64) int {
if t.isGo12() {
return t.go12PCToLine(pc)
}
_, _, line := t.parse(pc, -1)
return line
}
// LineToPC returns the program counter for the given line number,
// considering only program counters before maxpc.
//
// Deprecated: Use Table's LineToPC method instead.
func (t *LineTable) LineToPC(line int, maxpc uint64) uint64 {
if t.isGo12() {
return 0
}
_, pc, line1 := t.parse(maxpc, line)
if line1 != line {
return 0
}
// Subtract quantum from PC to account for post-line increment
return pc - oldQuantum
}
// NewLineTable returns a new PC/line table
// corresponding to the encoded data.
// Text must be the start address of the
// corresponding text segment.
func NewLineTable(data []byte, text uint64) *LineTable {
return &LineTable{Data: data, PC: text, Line: 0, funcNames: make(map[uint32]string), strings: make(map[uint32]string)}
}
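// Combined usage sketch with debug/elf (illustrative; assumes an unstripped
// Go ELF binary at the hypothetical path "prog", so the .gopclntab,
// .gosymtab, and .text sections all exist, and somePC is a program counter
// of interest; error handling is elided):
//
//	f, _ := elf.Open("prog")
//	pclndat, _ := f.Section(".gopclntab").Data()
//	symdat, _ := f.Section(".gosymtab").Data()
//	pcln := NewLineTable(pclndat, f.Section(".text").Addr)
//	tab, _ := NewTable(symdat, pcln)
//	file, line, fn := tab.PCToLine(somePC)
//	fmt.Println(file, line, fn.Name)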
// Go 1.2 symbol table format.
// See golang.org/s/go12symtab.
//
// A general note about the methods here: rather than try to avoid
// index out of bounds errors, we trust Go to detect them, and then
// we recover from the panics and treat them as indicative of a malformed
// or incomplete table.
//
// The methods called by symtab.go, which begin with "go12" prefixes,
// are expected to have that recovery logic.
// isGo12 reports whether this is a Go 1.2 (or later) symbol table.
func (t *LineTable) isGo12() bool {
t.parsePclnTab()
return t.version >= ver12
}
const (
go12magic = 0xfffffffb
go116magic = 0xfffffffa
go118magic = 0xfffffff0
go120magic = 0xfffffff1
)
// uintptr returns the pointer-sized value encoded at b.
// The pointer size is dictated by the table being read.
func (t *LineTable) uintptr(b []byte) uint64 {
if t.ptrsize == 4 {
return uint64(t.binary.Uint32(b))
}
return t.binary.Uint64(b)
}
// parsePclnTab parses the pclntab, setting the version.
func (t *LineTable) parsePclnTab() {
t.mu.Lock()
defer t.mu.Unlock()
if t.version != verUnknown {
return
}
// Note that during this function, setting the version is the last thing we do.
// If we set the version too early, and parsing failed (likely as a panic on
// slice lookups), we'd have a mistaken version.
//
// Error paths through this code will default the version to 1.1.
t.version = ver11
if !disableRecover {
defer func() {
// If we panic parsing, assume it's a Go 1.1 pclntab.
recover()
}()
}
// Check header: 4-byte magic, two zeros, pc quantum, pointer size.
if len(t.Data) < 16 || t.Data[4] != 0 || t.Data[5] != 0 ||
(t.Data[6] != 1 && t.Data[6] != 2 && t.Data[6] != 4) || // pc quantum
(t.Data[7] != 4 && t.Data[7] != 8) { // pointer size
return
}
var possibleVersion version
leMagic := binary.LittleEndian.Uint32(t.Data)
beMagic := binary.BigEndian.Uint32(t.Data)
switch {
case leMagic == go12magic:
t.binary, possibleVersion = binary.LittleEndian, ver12
case beMagic == go12magic:
t.binary, possibleVersion = binary.BigEndian, ver12
case leMagic == go116magic:
t.binary, possibleVersion = binary.LittleEndian, ver116
case beMagic == go116magic:
t.binary, possibleVersion = binary.BigEndian, ver116
case leMagic == go118magic:
t.binary, possibleVersion = binary.LittleEndian, ver118
case beMagic == go118magic:
t.binary, possibleVersion = binary.BigEndian, ver118
case leMagic == go120magic:
t.binary, possibleVersion = binary.LittleEndian, ver120
case beMagic == go120magic:
t.binary, possibleVersion = binary.BigEndian, ver120
default:
return
}
t.version = possibleVersion
// quantum and ptrSize are the same between 1.2, 1.16, and 1.18
t.quantum = uint32(t.Data[6])
t.ptrsize = uint32(t.Data[7])
offset := func(word uint32) uint64 {
return t.uintptr(t.Data[8+word*t.ptrsize:])
}
data := func(word uint32) []byte {
return t.Data[offset(word):]
}
switch possibleVersion {
case ver118, ver120:
t.nfunctab = uint32(offset(0))
t.nfiletab = uint32(offset(1))
t.textStart = t.PC // use the start PC instead of reading from the table, which may be unrelocated
t.funcnametab = data(3)
t.cutab = data(4)
t.filetab = data(5)
t.pctab = data(6)
t.funcdata = data(7)
t.functab = data(7)
functabsize := (int(t.nfunctab)*2 + 1) * t.functabFieldSize()
t.functab = t.functab[:functabsize]
case ver116:
t.nfunctab = uint32(offset(0))
t.nfiletab = uint32(offset(1))
t.funcnametab = data(2)
t.cutab = data(3)
t.filetab = data(4)
t.pctab = data(5)
t.funcdata = data(6)
t.functab = data(6)
functabsize := (int(t.nfunctab)*2 + 1) * t.functabFieldSize()
t.functab = t.functab[:functabsize]
case ver12:
t.nfunctab = uint32(t.uintptr(t.Data[8:]))
t.funcdata = t.Data
t.funcnametab = t.Data
t.functab = t.Data[8+t.ptrsize:]
t.pctab = t.Data
functabsize := (int(t.nfunctab)*2 + 1) * t.functabFieldSize()
fileoff := t.binary.Uint32(t.functab[functabsize:])
t.functab = t.functab[:functabsize]
t.filetab = t.Data[fileoff:]
t.nfiletab = t.binary.Uint32(t.filetab)
t.filetab = t.filetab[:t.nfiletab*4]
default:
panic("unreachable")
}
}
// go12Funcs returns a slice of Funcs derived from the Go 1.2+ pcln table.
func (t *LineTable) go12Funcs() []Func {
// Assume it is malformed and return nil on error.
if !disableRecover {
defer func() {
recover()
}()
}
ft := t.funcTab()
funcs := make([]Func, ft.Count())
syms := make([]Sym, len(funcs))
for i := range funcs {
f := &funcs[i]
f.Entry = ft.pc(i)
f.End = ft.pc(i + 1)
info := t.funcData(uint32(i))
f.LineTable = t
f.FrameSize = int(info.deferreturn())
syms[i] = Sym{
Value: f.Entry,
Type: 'T',
Name: t.funcName(info.nameOff()),
GoType: 0,
Func: f,
goVersion: t.version,
}
f.Sym = &syms[i]
}
return funcs
}
// findFunc returns the funcData corresponding to the given program counter.
func (t *LineTable) findFunc(pc uint64) funcData {
ft := t.funcTab()
if pc < ft.pc(0) || pc >= ft.pc(ft.Count()) {
return funcData{}
}
idx := sort.Search(int(t.nfunctab), func(i int) bool {
return ft.pc(i) > pc
})
idx--
return t.funcData(uint32(idx))
}
// readvarint reads, removes, and returns a varint from *pp.
func (t *LineTable) readvarint(pp *[]byte) uint32 {
var v, shift uint32
p := *pp
for shift = 0; ; shift += 7 {
b := p[0]
p = p[1:]
v |= (uint32(b) & 0x7F) << shift
if b&0x80 == 0 {
break
}
}
*pp = p
return v
}
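// Worked example (illustrative): the bytes {0xAC, 0x02} decode as
// (0xAC & 0x7F) | (0x02 << 7) = 44 + 256 = 300. The set high bit of 0xAC
// marks a continuation byte; the clear high bit of 0x02 ends the varint.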
// funcName returns the name of the function found at off.
func (t *LineTable) funcName(off uint32) string {
if s, ok := t.funcNames[off]; ok {
return s
}
i := bytes.IndexByte(t.funcnametab[off:], 0)
s := string(t.funcnametab[off : off+uint32(i)])
t.funcNames[off] = s
return s
}
// stringFrom returns the NUL-terminated Go string found at offset off in arr.
func (t *LineTable) stringFrom(arr []byte, off uint32) string {
if s, ok := t.strings[off]; ok {
return s
}
i := bytes.IndexByte(arr[off:], 0)
s := string(arr[off : off+uint32(i)])
t.strings[off] = s
return s
}
// string returns a Go string found at off.
func (t *LineTable) string(off uint32) string {
return t.stringFrom(t.funcdata, off)
}
// functabFieldSize returns the size in bytes of a single functab field.
func (t *LineTable) functabFieldSize() int {
if t.version >= ver118 {
return 4
}
return int(t.ptrsize)
}
// funcTab returns t's funcTab.
func (t *LineTable) funcTab() funcTab {
return funcTab{LineTable: t, sz: t.functabFieldSize()}
}
// funcTab is memory corresponding to a slice of functab structs, followed by an invalid PC.
// A functab struct is a PC and a func offset.
type funcTab struct {
*LineTable
sz int // cached result of t.functabFieldSize
}
// Count returns the number of func entries in f.
func (f funcTab) Count() int {
return int(f.nfunctab)
}
// pc returns the PC of the i'th func in f.
func (f funcTab) pc(i int) uint64 {
u := f.uint(f.functab[2*i*f.sz:])
if f.version >= ver118 {
u += f.textStart
}
return u
}
// funcOff returns the funcdata offset of the i'th func in f.
func (f funcTab) funcOff(i int) uint64 {
return f.uint(f.functab[(2*i+1)*f.sz:])
}
// uint returns the uint stored at b.
func (f funcTab) uint(b []byte) uint64 {
if f.sz == 4 {
return uint64(f.binary.Uint32(b))
}
return f.binary.Uint64(b)
}
// funcData is memory corresponding to an _func struct.
type funcData struct {
t *LineTable // LineTable this data is a part of
data []byte // raw memory for the function
}
// funcData returns the ith funcData in t.functab.
func (t *LineTable) funcData(i uint32) funcData {
data := t.funcdata[t.funcTab().funcOff(int(i)):]
return funcData{t: t, data: data}
}
// IsZero reports whether f is the zero value.
func (f funcData) IsZero() bool {
return f.t == nil && f.data == nil
}
// entryPC returns the func's entry PC.
func (f *funcData) entryPC() uint64 {
// In Go 1.18, the first field of _func changed
// from a uintptr entry PC to a uint32 entry offset.
if f.t.version >= ver118 {
// TODO: support multiple text sections.
// See runtime/symtab.go:(*moduledata).textAddr.
return uint64(f.t.binary.Uint32(f.data)) + f.t.textStart
}
return f.t.uintptr(f.data)
}
func (f funcData) nameOff() uint32 { return f.field(1) }
func (f funcData) deferreturn() uint32 { return f.field(3) }
func (f funcData) pcfile() uint32 { return f.field(5) }
func (f funcData) pcln() uint32 { return f.field(6) }
func (f funcData) cuOffset() uint32 { return f.field(8) }
// field returns the nth field of the _func struct.
// It panics if n == 0 or n > 9; for n == 0, call f.entryPC.
// Most callers should use a named field accessor (just above).
func (f funcData) field(n uint32) uint32 {
if n == 0 || n > 9 {
panic("bad funcdata field")
}
// In Go 1.18, the first field of _func changed
// from a uintptr entry PC to a uint32 entry offset.
sz0 := f.t.ptrsize
if f.t.version >= ver118 {
sz0 = 4
}
off := sz0 + (n-1)*4 // subsequent fields are 4 bytes each
data := f.data[off:]
return f.t.binary.Uint32(data)
}
// step advances to the next pc, value pair in the encoded table.
func (t *LineTable) step(p *[]byte, pc *uint64, val *int32, first bool) bool {
uvdelta := t.readvarint(p)
if uvdelta == 0 && !first {
return false
}
if uvdelta&1 != 0 {
uvdelta = ^(uvdelta >> 1)
} else {
uvdelta >>= 1
}
vdelta := int32(uvdelta)
pcdelta := t.readvarint(p) * t.quantum
*pc += uint64(pcdelta)
*val += vdelta
return true
}
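// Worked example of the zig-zag decoding above (illustrative): raw uvdelta
// values 0, 1, 2, 3, 4 decode to value deltas 0, -1, 1, -2, 2, so small
// deltas of either sign stay small in the varint encoding.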
// pcvalue reports the value associated with the target pc.
// off is the offset to the beginning of the pc-value table,
// and entry is the start PC for the corresponding function.
func (t *LineTable) pcvalue(off uint32, entry, targetpc uint64) int32 {
p := t.pctab[off:]
val := int32(-1)
pc := entry
for t.step(&p, &pc, &val, pc == entry) {
if targetpc < pc {
return val
}
}
return -1
}
// findFileLine scans one function in the binary looking for a
// program counter in the given file on the given line.
// It does so by running the pc-value tables mapping program counter
// to file number. Since most functions come from a single file, these
// are usually short and quick to scan. If a file match is found, then the
// code goes to the expense of looking for a simultaneous line number match.
func (t *LineTable) findFileLine(entry uint64, filetab, linetab uint32, filenum, line int32, cutab []byte) uint64 {
if filetab == 0 || linetab == 0 {
return 0
}
fp := t.pctab[filetab:]
fl := t.pctab[linetab:]
fileVal := int32(-1)
filePC := entry
lineVal := int32(-1)
linePC := entry
fileStartPC := filePC
for t.step(&fp, &filePC, &fileVal, filePC == entry) {
fileIndex := fileVal
if t.version == ver116 || t.version == ver118 || t.version == ver120 {
fileIndex = int32(t.binary.Uint32(cutab[fileVal*4:]))
}
if fileIndex == filenum && fileStartPC < filePC {
// fileIndex is in effect starting at fileStartPC up to
// but not including filePC, and it's the file we want.
// Run the PC table looking for a matching line number
// or until we reach filePC.
lineStartPC := linePC
for linePC < filePC && t.step(&fl, &linePC, &lineVal, linePC == entry) {
// lineVal is in effect until linePC, and lineStartPC < filePC.
if lineVal == line {
if fileStartPC <= lineStartPC {
return lineStartPC
}
if fileStartPC < linePC {
return fileStartPC
}
}
lineStartPC = linePC
}
}
fileStartPC = filePC
}
return 0
}
// go12PCToLine maps program counter to line number for the Go 1.2+ pcln table.
func (t *LineTable) go12PCToLine(pc uint64) (line int) {
defer func() {
if !disableRecover && recover() != nil {
line = -1
}
}()
f := t.findFunc(pc)
if f.IsZero() {
return -1
}
entry := f.entryPC()
linetab := f.pcln()
return int(t.pcvalue(linetab, entry, pc))
}
// go12PCToFile maps program counter to file name for the Go 1.2+ pcln table.
func (t *LineTable) go12PCToFile(pc uint64) (file string) {
defer func() {
if !disableRecover && recover() != nil {
file = ""
}
}()
f := t.findFunc(pc)
if f.IsZero() {
return ""
}
entry := f.entryPC()
filetab := f.pcfile()
fno := t.pcvalue(filetab, entry, pc)
if t.version == ver12 {
if fno <= 0 {
return ""
}
return t.string(t.binary.Uint32(t.filetab[4*fno:]))
}
// Go ≥ 1.16
if fno < 0 { // 0 is valid for ≥ 1.16
return ""
}
cuoff := f.cuOffset()
if fnoff := t.binary.Uint32(t.cutab[(cuoff+uint32(fno))*4:]); fnoff != ^uint32(0) {
return t.stringFrom(t.filetab, fnoff)
}
return ""
}
// go12LineToPC maps a (file, line) pair to a program counter for the Go 1.2+ pcln table.
func (t *LineTable) go12LineToPC(file string, line int) (pc uint64) {
defer func() {
if !disableRecover && recover() != nil {
pc = 0
}
}()
t.initFileMap()
filenum, ok := t.fileMap[file]
if !ok {
return 0
}
// Scan all functions.
// If this turns out to be a bottleneck, we could build a map[int32][]int32
// mapping file number to a list of functions with code from that file.
var cutab []byte
for i := uint32(0); i < t.nfunctab; i++ {
f := t.funcData(i)
entry := f.entryPC()
filetab := f.pcfile()
linetab := f.pcln()
if t.version == ver116 || t.version == ver118 || t.version == ver120 {
if f.cuOffset() == ^uint32(0) {
// skip functions without a compilation unit (not a real function, or linker-generated)
continue
}
cutab = t.cutab[f.cuOffset()*4:]
}
pc := t.findFileLine(entry, filetab, linetab, int32(filenum), int32(line), cutab)
if pc != 0 {
return pc
}
}
return 0
}
// initFileMap initializes the map from file name to file number.
func (t *LineTable) initFileMap() {
t.mu.Lock()
defer t.mu.Unlock()
if t.fileMap != nil {
return
}
m := make(map[string]uint32)
if t.version == ver12 {
for i := uint32(1); i < t.nfiletab; i++ {
s := t.string(t.binary.Uint32(t.filetab[4*i:]))
m[s] = i
}
} else {
var pos uint32
for i := uint32(0); i < t.nfiletab; i++ {
s := t.stringFrom(t.filetab, pos)
m[s] = pos
pos += uint32(len(s) + 1)
}
}
t.fileMap = m
}
// go12MapFiles adds to m a key for every file in the Go 1.2 LineTable.
// Every key maps to obj. That's not a very interesting map, but it provides
// a way for callers to obtain the list of files in the program.
func (t *LineTable) go12MapFiles(m map[string]*Obj, obj *Obj) {
if !disableRecover {
defer func() {
recover()
}()
}
t.initFileMap()
for file := range t.fileMap {
m[file] = obj
}
}
// disableRecover causes this package not to swallow panics.
// This is useful when making changes.
const disableRecover = false
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package gosym implements access to the Go symbol
// and line number tables embedded in Go binaries generated
// by the gc compilers.
package gosym
import (
"bytes"
"encoding/binary"
"fmt"
"strconv"
"strings"
)
/*
* Symbols
*/
// A Sym represents a single symbol table entry.
type Sym struct {
Value uint64
Type byte
Name string
GoType uint64
// If this symbol is a function symbol, the corresponding Func
Func *Func
goVersion version
}
// Static reports whether this symbol is static (not visible outside its file).
func (s *Sym) Static() bool { return s.Type >= 'a' }
// nameWithoutInst returns s.Name if s.Name has no brackets (does not reference an
// instantiated type, function, or method). If s.Name contains brackets, then it
// returns s.Name with all the contents between (and including) the outermost left
// and right bracket removed. This is useful to ignore any extra slashes or dots
// inside the brackets from the string searches below, where needed.
func (s *Sym) nameWithoutInst() string {
start := strings.Index(s.Name, "[")
if start < 0 {
return s.Name
}
end := strings.LastIndex(s.Name, "]")
if end < 0 {
// Malformed name, should contain closing bracket too.
return s.Name
}
return s.Name[0:start] + s.Name[end+1:]
}
// PackageName returns the package part of the symbol name,
// or the empty string if there is none.
func (s *Sym) PackageName() string {
name := s.nameWithoutInst()
// Since go1.20, a prefix of "type:" and "go:" is a compiler-generated symbol,
// they do not belong to any package.
//
// See cmd/compile/internal/base/link.go:ReservedImports variable.
if s.goVersion >= ver120 && (strings.HasPrefix(name, "go:") || strings.HasPrefix(name, "type:")) {
return ""
}
// For go1.18 and below, the prefix are "type." and "go." instead.
if s.goVersion <= ver118 && (strings.HasPrefix(name, "go.") || strings.HasPrefix(name, "type.")) {
return ""
}
pathend := strings.LastIndex(name, "/")
if pathend < 0 {
pathend = 0
}
if i := strings.Index(name[pathend:], "."); i != -1 {
return name[:pathend+i]
}
return ""
}
// ReceiverName returns the receiver type name of this symbol,
// or the empty string if there is none. A receiver name is only detected in
// the case that s.Name is fully-specified with a package name.
func (s *Sym) ReceiverName() string {
name := s.nameWithoutInst()
// If we find a slash in name, it should precede any bracketed expression
// that was removed, so pathend will apply correctly to name and s.Name.
pathend := strings.LastIndex(name, "/")
if pathend < 0 {
pathend = 0
}
// Find the first dot after pathend (or from the beginning, if there was
// no slash in name).
l := strings.Index(name[pathend:], ".")
// Find the last dot after pathend (or the beginning).
r := strings.LastIndex(name[pathend:], ".")
if l == -1 || r == -1 || l == r {
// There is no receiver if we didn't find two distinct dots after pathend.
return ""
}
// Given there is a trailing '.' that is in name, find it now in s.Name.
// pathend+l should apply to s.Name, because it should be the dot in the
// package name.
r = strings.LastIndex(s.Name[pathend:], ".")
return s.Name[pathend+l+1 : pathend+r]
}
// BaseName returns the symbol name without the package or receiver name.
func (s *Sym) BaseName() string {
name := s.nameWithoutInst()
if i := strings.LastIndex(name, "."); i != -1 {
if s.Name != name {
brack := strings.Index(s.Name, "[")
if i > brack {
// BaseName is a method name after the brackets, so
// recalculate for s.Name. Otherwise, i applies
// correctly to s.Name, since it is before the
// brackets.
i = strings.LastIndex(s.Name, ".")
}
}
return s.Name[i+1:]
}
return s.Name
}
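// Worked example for the three accessors above (illustrative): for a symbol
// named "github.com/user/pkg.(*Type).Method", PackageName returns
// "github.com/user/pkg", ReceiverName returns "(*Type)", and BaseName
// returns "Method".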
// A Func collects information about a single function.
type Func struct {
Entry uint64
*Sym
End uint64
Params []*Sym // nil for Go 1.3 and later binaries
Locals []*Sym // nil for Go 1.3 and later binaries
FrameSize int
LineTable *LineTable
Obj *Obj
}
// An Obj represents a collection of functions in a symbol table.
//
// The exact method of division of a binary into separate Objs is an internal detail
// of the symbol table format.
//
// In early versions of Go each source file became a different Obj.
//
// In Go 1 and Go 1.1, each package produced one Obj for all Go sources
// and one Obj per C source file.
//
// In Go 1.2, there is a single Obj for the entire program.
type Obj struct {
// Funcs is a list of functions in the Obj.
Funcs []Func
// In Go 1.1 and earlier, Paths is a list of symbols corresponding
// to the source file names that produced the Obj.
// In Go 1.2, Paths is nil.
// Use the keys of Table.Files to obtain a list of source files.
Paths []Sym // meta
}
/*
* Symbol tables
*/
// Table represents a Go symbol table. It stores all of the
// symbols decoded from the program and provides methods to translate
// between symbols, names, and addresses.
type Table struct {
Syms []Sym // nil for Go 1.3 and later binaries
Funcs []Func
Files map[string]*Obj // for Go 1.2 and later all files map to one Obj
Objs []Obj // for Go 1.2 and later only one Obj in slice
go12line *LineTable // Go 1.2 line number table
}
type sym struct {
value uint64
gotype uint64
typ byte
name []byte
}
var (
littleEndianSymtab = []byte{0xFD, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00}
bigEndianSymtab = []byte{0xFF, 0xFF, 0xFF, 0xFD, 0x00, 0x00, 0x00}
oldLittleEndianSymtab = []byte{0xFE, 0xFF, 0xFF, 0xFF, 0x00, 0x00}
)
func walksymtab(data []byte, fn func(sym) error) error {
if len(data) == 0 { // missing symtab is okay
return nil
}
var order binary.ByteOrder = binary.BigEndian
newTable := false
switch {
case bytes.HasPrefix(data, oldLittleEndianSymtab):
// Same as Go 1.0, but little endian.
// Format was used during interim development between Go 1.0 and Go 1.1.
// Should not be widespread, but easy to support.
data = data[6:]
order = binary.LittleEndian
case bytes.HasPrefix(data, bigEndianSymtab):
newTable = true
case bytes.HasPrefix(data, littleEndianSymtab):
newTable = true
order = binary.LittleEndian
}
var ptrsz int
if newTable {
if len(data) < 8 {
return &DecodingError{len(data), "unexpected EOF", nil}
}
ptrsz = int(data[7])
if ptrsz != 4 && ptrsz != 8 {
return &DecodingError{7, "invalid pointer size", ptrsz}
}
data = data[8:]
}
var s sym
p := data
for len(p) >= 4 {
var typ byte
if newTable {
// Symbol type, value, Go type.
typ = p[0] & 0x3F
wideValue := p[0]&0x40 != 0
goType := p[0]&0x80 != 0
if typ < 26 {
typ += 'A'
} else {
typ += 'a' - 26
}
s.typ = typ
p = p[1:]
if wideValue {
if len(p) < ptrsz {
return &DecodingError{len(data), "unexpected EOF", nil}
}
// fixed-width value
if ptrsz == 8 {
s.value = order.Uint64(p[0:8])
p = p[8:]
} else {
s.value = uint64(order.Uint32(p[0:4]))
p = p[4:]
}
} else {
// varint value
s.value = 0
shift := uint(0)
for len(p) > 0 && p[0]&0x80 != 0 {
s.value |= uint64(p[0]&0x7F) << shift
shift += 7
p = p[1:]
}
if len(p) == 0 {
return &DecodingError{len(data), "unexpected EOF", nil}
}
s.value |= uint64(p[0]) << shift
p = p[1:]
}
if goType {
if len(p) < ptrsz {
return &DecodingError{len(data), "unexpected EOF", nil}
}
// fixed-width go type
if ptrsz == 8 {
s.gotype = order.Uint64(p[0:8])
p = p[8:]
} else {
s.gotype = uint64(order.Uint32(p[0:4]))
p = p[4:]
}
}
} else {
// Value, symbol type.
s.value = uint64(order.Uint32(p[0:4]))
if len(p) < 5 {
return &DecodingError{len(data), "unexpected EOF", nil}
}
typ = p[4]
if typ&0x80 == 0 {
return &DecodingError{len(data) - len(p) + 4, "bad symbol type", typ}
}
typ &^= 0x80
s.typ = typ
p = p[5:]
}
// Name.
var i int
var nnul int
for i = 0; i < len(p); i++ {
if p[i] == 0 {
nnul = 1
break
}
}
switch typ {
case 'z', 'Z':
p = p[i+nnul:]
for i = 0; i+2 <= len(p); i += 2 {
if p[i] == 0 && p[i+1] == 0 {
nnul = 2
break
}
}
}
if len(p) < i+nnul {
return &DecodingError{len(data), "unexpected EOF", nil}
}
s.name = p[0:i]
i += nnul
p = p[i:]
if !newTable {
if len(p) < 4 {
return &DecodingError{len(data), "unexpected EOF", nil}
}
// Go type.
s.gotype = uint64(order.Uint32(p[:4]))
p = p[4:]
}
if err := fn(s); err != nil {
return err
}
}
return nil
}
// NewTable decodes the Go symbol table (the ".gosymtab" section in ELF),
// returning an in-memory representation.
// Starting with Go 1.3, the Go symbol table no longer includes symbol data.
func NewTable(symtab []byte, pcln *LineTable) (*Table, error) {
var n int
err := walksymtab(symtab, func(s sym) error {
n++
return nil
})
if err != nil {
return nil, err
}
var t Table
if pcln.isGo12() {
t.go12line = pcln
}
fname := make(map[uint16]string)
t.Syms = make([]Sym, 0, n)
nf := 0
nz := 0
lasttyp := uint8(0)
err = walksymtab(symtab, func(s sym) error {
n := len(t.Syms)
t.Syms = t.Syms[0 : n+1]
ts := &t.Syms[n]
ts.Type = s.typ
ts.Value = s.value
ts.GoType = s.gotype
ts.goVersion = pcln.version
switch s.typ {
default:
// rewrite name to use . instead of · (c2 b7)
w := 0
b := s.name
for i := 0; i < len(b); i++ {
if b[i] == 0xc2 && i+1 < len(b) && b[i+1] == 0xb7 {
i++
b[i] = '.'
}
b[w] = b[i]
w++
}
ts.Name = string(s.name[0:w])
case 'z', 'Z':
if lasttyp != 'z' && lasttyp != 'Z' {
nz++
}
for i := 0; i < len(s.name); i += 2 {
eltIdx := binary.BigEndian.Uint16(s.name[i : i+2])
elt, ok := fname[eltIdx]
if !ok {
return &DecodingError{-1, "bad filename code", eltIdx}
}
if n := len(ts.Name); n > 0 && ts.Name[n-1] != '/' {
ts.Name += "/"
}
ts.Name += elt
}
}
switch s.typ {
case 'T', 't', 'L', 'l':
nf++
case 'f':
fname[uint16(s.value)] = ts.Name
}
lasttyp = s.typ
return nil
})
if err != nil {
return nil, err
}
t.Funcs = make([]Func, 0, nf)
t.Files = make(map[string]*Obj)
var obj *Obj
if t.go12line != nil {
// Put all functions into one Obj.
t.Objs = make([]Obj, 1)
obj = &t.Objs[0]
t.go12line.go12MapFiles(t.Files, obj)
} else {
t.Objs = make([]Obj, 0, nz)
}
// Count text symbols and attach frame sizes, parameters, and
// locals to them. Also, find object file boundaries.
lastf := 0
for i := 0; i < len(t.Syms); i++ {
sym := &t.Syms[i]
switch sym.Type {
case 'Z', 'z': // path symbol
if t.go12line != nil {
// Go 1.2 binaries have the file information elsewhere. Ignore.
break
}
// Finish the current object
if obj != nil {
obj.Funcs = t.Funcs[lastf:]
}
lastf = len(t.Funcs)
// Start new object
n := len(t.Objs)
t.Objs = t.Objs[0 : n+1]
obj = &t.Objs[n]
// Count & copy path symbols
var end int
for end = i + 1; end < len(t.Syms); end++ {
if c := t.Syms[end].Type; c != 'Z' && c != 'z' {
break
}
}
obj.Paths = t.Syms[i:end]
i = end - 1 // loop will i++
// Record file names
depth := 0
for j := range obj.Paths {
s := &obj.Paths[j]
if s.Name == "" {
depth--
} else {
if depth == 0 {
t.Files[s.Name] = obj
}
depth++
}
}
case 'T', 't', 'L', 'l': // text symbol
if n := len(t.Funcs); n > 0 {
t.Funcs[n-1].End = sym.Value
}
if sym.Name == "runtime.etext" || sym.Name == "etext" {
continue
}
// Count parameter and local (auto) syms
var np, na int
var end int
countloop:
for end = i + 1; end < len(t.Syms); end++ {
switch t.Syms[end].Type {
case 'T', 't', 'L', 'l', 'Z', 'z':
break countloop
case 'p':
np++
case 'a':
na++
}
}
// Fill in the function symbol
n := len(t.Funcs)
t.Funcs = t.Funcs[0 : n+1]
fn := &t.Funcs[n]
sym.Func = fn
fn.Params = make([]*Sym, 0, np)
fn.Locals = make([]*Sym, 0, na)
fn.Sym = sym
fn.Entry = sym.Value
fn.Obj = obj
if t.go12line != nil {
// All functions share the same line table.
// It knows how to narrow down to a specific
// function quickly.
fn.LineTable = t.go12line
} else if pcln != nil {
fn.LineTable = pcln.slice(fn.Entry)
pcln = fn.LineTable
}
for j := i; j < end; j++ {
s := &t.Syms[j]
switch s.Type {
case 'm':
fn.FrameSize = int(s.Value)
case 'p':
n := len(fn.Params)
fn.Params = fn.Params[0 : n+1]
fn.Params[n] = s
case 'a':
n := len(fn.Locals)
fn.Locals = fn.Locals[0 : n+1]
fn.Locals[n] = s
}
}
i = end - 1 // loop will i++
}
}
if t.go12line != nil && nf == 0 {
t.Funcs = t.go12line.go12Funcs()
}
if obj != nil {
obj.Funcs = t.Funcs[lastf:]
}
return &t, nil
}
// PCToFunc returns the function containing the program counter pc,
// or nil if there is no such function.
func (t *Table) PCToFunc(pc uint64) *Func {
funcs := t.Funcs
for len(funcs) > 0 {
m := len(funcs) / 2
fn := &funcs[m]
switch {
case pc < fn.Entry:
funcs = funcs[0:m]
case fn.Entry <= pc && pc < fn.End:
return fn
default:
funcs = funcs[m+1:]
}
}
return nil
}
// PCToLine looks up line number information for a program counter.
// If there is no information, it returns fn == nil.
func (t *Table) PCToLine(pc uint64) (file string, line int, fn *Func) {
if fn = t.PCToFunc(pc); fn == nil {
return
}
if t.go12line != nil {
file = t.go12line.go12PCToFile(pc)
line = t.go12line.go12PCToLine(pc)
} else {
file, line = fn.Obj.lineFromAline(fn.LineTable.PCToLine(pc))
}
return
}
// LineToPC looks up the first program counter on the given line in
// the named file. It returns UnknownFileError or UnknownLineError if
// there is an error looking up this line.
func (t *Table) LineToPC(file string, line int) (pc uint64, fn *Func, err error) {
obj, ok := t.Files[file]
if !ok {
return 0, nil, UnknownFileError(file)
}
if t.go12line != nil {
pc := t.go12line.go12LineToPC(file, line)
if pc == 0 {
return 0, nil, &UnknownLineError{file, line}
}
return pc, t.PCToFunc(pc), nil
}
abs, err := obj.alineFromLine(file, line)
if err != nil {
return
}
for i := range obj.Funcs {
f := &obj.Funcs[i]
pc := f.LineTable.LineToPC(abs, f.End)
if pc != 0 {
return pc, f, nil
}
}
return 0, nil, &UnknownLineError{file, line}
}
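// Illustrative sketch (not part of this package): given a *Table tab built
// as above, PCToLine and LineToPC map between addresses and source
// positions. The function name "main.main" is a placeholder.
//
//	fn := tab.LookupFunc("main.main")
//	if fn == nil { /* not found */ }
//	file, line, _ := tab.PCToLine(fn.Entry)
//	pc, _, err := tab.LineToPC(file, line)
//	if err != nil { /* handle error */ }
//	// pc is the first program counter on that line of file.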
// LookupSym returns the text, data, or bss symbol with the given name,
// or nil if no such symbol is found.
func (t *Table) LookupSym(name string) *Sym {
// TODO(austin) Maybe make a map
for i := range t.Syms {
s := &t.Syms[i]
switch s.Type {
case 'T', 't', 'L', 'l', 'D', 'd', 'B', 'b':
if s.Name == name {
return s
}
}
}
return nil
}
// LookupFunc returns the function with the given name,
// or nil if no such function is found.
func (t *Table) LookupFunc(name string) *Func {
for i := range t.Funcs {
f := &t.Funcs[i]
if f.Sym.Name == name {
return f
}
}
return nil
}
// SymByAddr returns the text, data, or bss symbol starting at the given address.
func (t *Table) SymByAddr(addr uint64) *Sym {
for i := range t.Syms {
s := &t.Syms[i]
switch s.Type {
case 'T', 't', 'L', 'l', 'D', 'd', 'B', 'b':
if s.Value == addr {
return s
}
}
}
return nil
}
/*
* Object files
*/
// This is legacy code for Go 1.1 and earlier, which used the
// Plan 9 format for pc-line tables. This code was never quite
// correct. It's probably very close, and it's usually correct, but
// we never quite found all the corner cases.
//
// Go 1.2 and later use a simpler format, documented at golang.org/s/go12symtab.
func (o *Obj) lineFromAline(aline int) (string, int) {
type stackEnt struct {
path string
start int
offset int
prev *stackEnt
}
noPath := &stackEnt{"", 0, 0, nil}
tos := noPath
pathloop:
for _, s := range o.Paths {
val := int(s.Value)
switch {
case val > aline:
break pathloop
case val == 1:
// Start a new stack
tos = &stackEnt{s.Name, val, 0, noPath}
case s.Name == "":
// Pop
if tos == noPath {
return "<malformed symbol table>", 0
}
tos.prev.offset += val - tos.start
tos = tos.prev
default:
// Push
tos = &stackEnt{s.Name, val, 0, tos}
}
}
if tos == noPath {
return "", 0
}
return tos.path, aline - tos.start - tos.offset + 1
}
func (o *Obj) alineFromLine(path string, line int) (int, error) {
if line < 1 {
return 0, &UnknownLineError{path, line}
}
for i, s := range o.Paths {
// Find this path
if s.Name != path {
continue
}
// Find this line at this stack level
depth := 0
var incstart int
line += int(s.Value)
pathloop:
for _, s := range o.Paths[i:] {
val := int(s.Value)
switch {
case depth == 1 && val >= line:
return line - 1, nil
case s.Name == "":
depth--
if depth == 0 {
break pathloop
} else if depth == 1 {
line += val - incstart
}
default:
if depth == 1 {
incstart = val
}
depth++
}
}
return 0, &UnknownLineError{path, line}
}
return 0, UnknownFileError(path)
}
/*
* Errors
*/
// UnknownFileError represents a failure to find the specific file in
// the symbol table.
type UnknownFileError string
func (e UnknownFileError) Error() string { return "unknown file: " + string(e) }
// UnknownLineError represents a failure to map a line to a program
// counter, either because the line is beyond the bounds of the file
// or because there is no code on the given line.
type UnknownLineError struct {
File string
Line int
}
func (e *UnknownLineError) Error() string {
return "no code at " + e.File + ":" + strconv.Itoa(e.Line)
}
// DecodingError represents an error during the decoding of
// the symbol table.
type DecodingError struct {
off int
msg string
val any
}
func (e *DecodingError) Error() string {
msg := e.msg
if e.val != nil {
msg += fmt.Sprintf(" '%v'", e.val)
}
msg += fmt.Sprintf(" at byte %#x", e.off)
return msg
}
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package macho
import (
"encoding/binary"
"fmt"
"internal/saferio"
"io"
"os"
)
// A FatFile is a Mach-O universal binary that contains at least one architecture.
type FatFile struct {
Magic uint32
Arches []FatArch
closer io.Closer
}
// A FatArchHeader represents a fat header for a specific image architecture.
type FatArchHeader struct {
Cpu Cpu
SubCpu uint32
Offset uint32
Size uint32
Align uint32
}
const fatArchHeaderSize = 5 * 4
// A FatArch is a Mach-O File inside a FatFile.
type FatArch struct {
FatArchHeader
*File
}
// ErrNotFat is returned from NewFatFile or OpenFat when the file is not a
// universal binary but may be a thin binary, based on its magic number.
var ErrNotFat = &FormatError{0, "not a fat Mach-O file", nil}
// NewFatFile creates a new FatFile for accessing all the Mach-O images in a
// universal binary. The Mach-O binary is expected to start at position 0 in
// the ReaderAt.
func NewFatFile(r io.ReaderAt) (*FatFile, error) {
var ff FatFile
sr := io.NewSectionReader(r, 0, 1<<63-1)
// Read the fat_header struct, which is always in big endian.
// Start with the magic number.
err := binary.Read(sr, binary.BigEndian, &ff.Magic)
if err != nil {
return nil, &FormatError{0, "error reading magic number", nil}
} else if ff.Magic != MagicFat {
// See if this is a Mach-O file via its magic number. The magic
// must be converted to little endian first though.
var buf [4]byte
binary.BigEndian.PutUint32(buf[:], ff.Magic)
leMagic := binary.LittleEndian.Uint32(buf[:])
if leMagic == Magic32 || leMagic == Magic64 {
return nil, ErrNotFat
} else {
return nil, &FormatError{0, "invalid magic number", nil}
}
}
offset := int64(4)
// Read the number of FatArchHeaders that come after the fat_header.
var narch uint32
err = binary.Read(sr, binary.BigEndian, &narch)
if err != nil {
return nil, &FormatError{offset, "invalid fat_header", nil}
}
offset += 4
if narch < 1 {
return nil, &FormatError{offset, "file contains no images", nil}
}
// Combine the Cpu and SubCpu (both uint32) into a uint64 so that
// duplicate architectures can be detected.
seenArches := make(map[uint64]bool)
// Make sure that all images are for the same MH_ type.
var machoType Type
// Following the fat_header comes narch fat_arch structs that index
// Mach-O images further in the file.
c := saferio.SliceCap((*FatArch)(nil), uint64(narch))
if c < 0 {
return nil, &FormatError{offset, "too many images", nil}
}
ff.Arches = make([]FatArch, 0, c)
for i := uint32(0); i < narch; i++ {
var fa FatArch
err = binary.Read(sr, binary.BigEndian, &fa.FatArchHeader)
if err != nil {
return nil, &FormatError{offset, "invalid fat_arch header", nil}
}
offset += fatArchHeaderSize
fr := io.NewSectionReader(r, int64(fa.Offset), int64(fa.Size))
fa.File, err = NewFile(fr)
if err != nil {
return nil, err
}
// Make sure the architecture for this image is not a duplicate.
seenArch := (uint64(fa.Cpu) << 32) | uint64(fa.SubCpu)
if seenArches[seenArch] {
return nil, &FormatError{offset, fmt.Sprintf("duplicate architecture cpu=%v, subcpu=%#x", fa.Cpu, fa.SubCpu), nil}
}
seenArches[seenArch] = true
// Make sure the Mach-O type matches that of the first image.
if i == 0 {
machoType = fa.Type
} else {
if fa.Type != machoType {
return nil, &FormatError{offset, fmt.Sprintf("Mach-O type for architecture #%d (type=%#x) does not match first (type=%#x)", i, fa.Type, machoType), nil}
}
}
ff.Arches = append(ff.Arches, fa)
}
return &ff, nil
}
// OpenFat opens the named file using os.Open and prepares it for use as a Mach-O
// universal binary.
func OpenFat(name string) (*FatFile, error) {
f, err := os.Open(name)
if err != nil {
return nil, err
}
ff, err := NewFatFile(f)
if err != nil {
f.Close()
return nil, err
}
ff.closer = f
return ff, nil
}
func (ff *FatFile) Close() error {
var err error
if ff.closer != nil {
err = ff.closer.Close()
ff.closer = nil
}
return err
}
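// Illustrative sketch (not part of this package): enumerating the images in
// a universal binary. The path is a placeholder; ErrNotFat distinguishes a
// thin Mach-O file from genuinely invalid input.
//
//	ff, err := OpenFat("/usr/bin/true")
//	if err == ErrNotFat {
//		// Fall back to Open for a single-architecture file.
//	} else if err != nil {
//		/* handle error */
//	}
//	defer ff.Close()
//	for _, arch := range ff.Arches {
//		fmt.Printf("cpu=%v offset=%#x size=%d\n", arch.Cpu, arch.Offset, arch.Size)
//	}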
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package macho implements access to Mach-O object files.
# Security
This package is not designed to be hardened against adversarial inputs, and is
outside the scope of https://go.dev/security/policy. In particular, only basic
validation is done when parsing object files. As such, care should be taken when
parsing untrusted inputs, as parsing malformed files may consume significant
resources, or cause panics.
*/
package macho
// High level access to low level data structures.
import (
"bytes"
"compress/zlib"
"debug/dwarf"
"encoding/binary"
"fmt"
"internal/saferio"
"io"
"os"
"strings"
)
// A File represents an open Mach-O file.
type File struct {
FileHeader
ByteOrder binary.ByteOrder
Loads []Load
Sections []*Section
Symtab *Symtab
Dysymtab *Dysymtab
closer io.Closer
}
// A Load represents any Mach-O load command.
type Load interface {
Raw() []byte
}
// A LoadBytes is the uninterpreted bytes of a Mach-O load command.
type LoadBytes []byte
func (b LoadBytes) Raw() []byte { return b }
// A SegmentHeader is the header for a Mach-O 32-bit or 64-bit load segment command.
type SegmentHeader struct {
Cmd LoadCmd
Len uint32
Name string
Addr uint64
Memsz uint64
Offset uint64
Filesz uint64
Maxprot uint32
Prot uint32
Nsect uint32
Flag uint32
}
// A Segment represents a Mach-O 32-bit or 64-bit load segment command.
type Segment struct {
LoadBytes
SegmentHeader
// Embed ReaderAt for ReadAt method.
// Do not embed SectionReader directly
// to avoid having Read and Seek.
// If a client wants Read and Seek it must use
// Open() to avoid fighting over the seek offset
// with other clients.
io.ReaderAt
sr *io.SectionReader
}
// Data reads and returns the contents of the segment.
func (s *Segment) Data() ([]byte, error) {
return saferio.ReadDataAt(s.sr, s.Filesz, 0)
}
// Open returns a new ReadSeeker reading the segment.
func (s *Segment) Open() io.ReadSeeker { return io.NewSectionReader(s.sr, 0, 1<<63-1) }
type SectionHeader struct {
Name string
Seg string
Addr uint64
Size uint64
Offset uint32
Align uint32
Reloff uint32
Nreloc uint32
Flags uint32
}
// A Reloc represents a Mach-O relocation.
type Reloc struct {
Addr uint32
// When Scattered == false && Extern == true, Value is the symbol number.
// When Scattered == false && Extern == false, Value is the section number.
// When Scattered == true, Value is the value that this reloc refers to.
Value uint32
Type uint8
Len uint8 // 0=byte, 1=word, 2=long, 3=quad
Pcrel bool
Extern bool // valid if Scattered == false
Scattered bool
}
type Section struct {
SectionHeader
Relocs []Reloc
// Embed ReaderAt for ReadAt method.
// Do not embed SectionReader directly
// to avoid having Read and Seek.
// If a client wants Read and Seek it must use
// Open() to avoid fighting over the seek offset
// with other clients.
io.ReaderAt
sr *io.SectionReader
}
// Data reads and returns the contents of the Mach-O section.
func (s *Section) Data() ([]byte, error) {
return saferio.ReadDataAt(s.sr, s.Size, 0)
}
// Open returns a new ReadSeeker reading the Mach-O section.
func (s *Section) Open() io.ReadSeeker { return io.NewSectionReader(s.sr, 0, 1<<63-1) }
// A Dylib represents a Mach-O load dynamic library command.
type Dylib struct {
LoadBytes
Name string
Time uint32
CurrentVersion uint32
CompatVersion uint32
}
// A Symtab represents a Mach-O symbol table command.
type Symtab struct {
LoadBytes
SymtabCmd
Syms []Symbol
}
// A Dysymtab represents a Mach-O dynamic symbol table command.
type Dysymtab struct {
LoadBytes
DysymtabCmd
IndirectSyms []uint32 // indices into Symtab.Syms
}
// A Rpath represents a Mach-O rpath command.
type Rpath struct {
LoadBytes
Path string
}
// A Symbol is a Mach-O 32-bit or 64-bit symbol table entry.
type Symbol struct {
Name string
Type uint8
Sect uint8
Desc uint16
Value uint64
}
/*
* Mach-O reader
*/
// FormatError is returned by some operations if the data does
// not have the correct format for an object file.
type FormatError struct {
off int64
msg string
val any
}
func (e *FormatError) Error() string {
msg := e.msg
if e.val != nil {
msg += fmt.Sprintf(" '%v'", e.val)
}
msg += fmt.Sprintf(" in record at byte %#x", e.off)
return msg
}
// Open opens the named file using os.Open and prepares it for use as a Mach-O binary.
func Open(name string) (*File, error) {
f, err := os.Open(name)
if err != nil {
return nil, err
}
ff, err := NewFile(f)
if err != nil {
f.Close()
return nil, err
}
ff.closer = f
return ff, nil
}
// Close closes the File.
// If the File was created using NewFile directly instead of Open,
// Close has no effect.
func (f *File) Close() error {
var err error
if f.closer != nil {
err = f.closer.Close()
f.closer = nil
}
return err
}
// NewFile creates a new File for accessing a Mach-O binary in an underlying reader.
// The Mach-O binary is expected to start at position 0 in the ReaderAt.
func NewFile(r io.ReaderAt) (*File, error) {
f := new(File)
sr := io.NewSectionReader(r, 0, 1<<63-1)
// Read and decode Mach magic to determine byte order, size.
// Magic32 and Magic64 differ only in the bottom bit.
var ident [4]byte
if _, err := r.ReadAt(ident[0:], 0); err != nil {
return nil, err
}
be := binary.BigEndian.Uint32(ident[0:])
le := binary.LittleEndian.Uint32(ident[0:])
switch Magic32 &^ 1 {
case be &^ 1:
f.ByteOrder = binary.BigEndian
f.Magic = be
case le &^ 1:
f.ByteOrder = binary.LittleEndian
f.Magic = le
default:
return nil, &FormatError{0, "invalid magic number", nil}
}
// Read entire file header.
if err := binary.Read(sr, f.ByteOrder, &f.FileHeader); err != nil {
return nil, err
}
// Then load commands.
offset := int64(fileHeaderSize32)
if f.Magic == Magic64 {
offset = fileHeaderSize64
}
dat, err := saferio.ReadDataAt(r, uint64(f.Cmdsz), offset)
if err != nil {
return nil, err
}
c := saferio.SliceCap((*Load)(nil), uint64(f.Ncmd))
if c < 0 {
return nil, &FormatError{offset, "too many load commands", nil}
}
f.Loads = make([]Load, 0, c)
bo := f.ByteOrder
for i := uint32(0); i < f.Ncmd; i++ {
// Each load command begins with uint32 command and length.
if len(dat) < 8 {
return nil, &FormatError{offset, "command block too small", nil}
}
cmd, siz := LoadCmd(bo.Uint32(dat[0:4])), bo.Uint32(dat[4:8])
if siz < 8 || siz > uint32(len(dat)) {
return nil, &FormatError{offset, "invalid command block size", nil}
}
var cmddat []byte
cmddat, dat = dat[0:siz], dat[siz:]
offset += int64(siz)
var s *Segment
switch cmd {
default:
f.Loads = append(f.Loads, LoadBytes(cmddat))
case LoadCmdRpath:
var hdr RpathCmd
b := bytes.NewReader(cmddat)
if err := binary.Read(b, bo, &hdr); err != nil {
return nil, err
}
l := new(Rpath)
if hdr.Path >= uint32(len(cmddat)) {
return nil, &FormatError{offset, "invalid path in rpath command", hdr.Path}
}
l.Path = cstring(cmddat[hdr.Path:])
l.LoadBytes = LoadBytes(cmddat)
f.Loads = append(f.Loads, l)
case LoadCmdDylib:
var hdr DylibCmd
b := bytes.NewReader(cmddat)
if err := binary.Read(b, bo, &hdr); err != nil {
return nil, err
}
l := new(Dylib)
if hdr.Name >= uint32(len(cmddat)) {
return nil, &FormatError{offset, "invalid name in dynamic library command", hdr.Name}
}
l.Name = cstring(cmddat[hdr.Name:])
l.Time = hdr.Time
l.CurrentVersion = hdr.CurrentVersion
l.CompatVersion = hdr.CompatVersion
l.LoadBytes = LoadBytes(cmddat)
f.Loads = append(f.Loads, l)
case LoadCmdSymtab:
var hdr SymtabCmd
b := bytes.NewReader(cmddat)
if err := binary.Read(b, bo, &hdr); err != nil {
return nil, err
}
strtab, err := saferio.ReadDataAt(r, uint64(hdr.Strsize), int64(hdr.Stroff))
if err != nil {
return nil, err
}
var symsz int
if f.Magic == Magic64 {
symsz = 16
} else {
symsz = 12
}
symdat, err := saferio.ReadDataAt(r, uint64(hdr.Nsyms)*uint64(symsz), int64(hdr.Symoff))
if err != nil {
return nil, err
}
st, err := f.parseSymtab(symdat, strtab, cmddat, &hdr, offset)
if err != nil {
return nil, err
}
f.Loads = append(f.Loads, st)
f.Symtab = st
case LoadCmdDysymtab:
var hdr DysymtabCmd
b := bytes.NewReader(cmddat)
if err := binary.Read(b, bo, &hdr); err != nil {
return nil, err
}
if f.Symtab == nil {
return nil, &FormatError{offset, "dynamic symbol table seen before any ordinary symbol table", nil}
} else if hdr.Iundefsym > uint32(len(f.Symtab.Syms)) {
return nil, &FormatError{offset, fmt.Sprintf(
"undefined symbols index in dynamic symbol table command is greater than symbol table length (%d > %d)",
hdr.Iundefsym, len(f.Symtab.Syms)), nil}
} else if hdr.Iundefsym+hdr.Nundefsym > uint32(len(f.Symtab.Syms)) {
return nil, &FormatError{offset, fmt.Sprintf(
"number of undefined symbols after index in dynamic symbol table command is greater than symbol table length (%d > %d)",
hdr.Iundefsym+hdr.Nundefsym, len(f.Symtab.Syms)), nil}
}
dat, err := saferio.ReadDataAt(r, uint64(hdr.Nindirectsyms)*4, int64(hdr.Indirectsymoff))
if err != nil {
return nil, err
}
x := make([]uint32, hdr.Nindirectsyms)
if err := binary.Read(bytes.NewReader(dat), bo, x); err != nil {
return nil, err
}
st := new(Dysymtab)
st.LoadBytes = LoadBytes(cmddat)
st.DysymtabCmd = hdr
st.IndirectSyms = x
f.Loads = append(f.Loads, st)
f.Dysymtab = st
case LoadCmdSegment:
var seg32 Segment32
b := bytes.NewReader(cmddat)
if err := binary.Read(b, bo, &seg32); err != nil {
return nil, err
}
s = new(Segment)
s.LoadBytes = cmddat
s.Cmd = cmd
s.Len = siz
s.Name = cstring(seg32.Name[0:])
s.Addr = uint64(seg32.Addr)
s.Memsz = uint64(seg32.Memsz)
s.Offset = uint64(seg32.Offset)
s.Filesz = uint64(seg32.Filesz)
s.Maxprot = seg32.Maxprot
s.Prot = seg32.Prot
s.Nsect = seg32.Nsect
s.Flag = seg32.Flag
f.Loads = append(f.Loads, s)
for i := 0; i < int(s.Nsect); i++ {
var sh32 Section32
if err := binary.Read(b, bo, &sh32); err != nil {
return nil, err
}
sh := new(Section)
sh.Name = cstring(sh32.Name[0:])
sh.Seg = cstring(sh32.Seg[0:])
sh.Addr = uint64(sh32.Addr)
sh.Size = uint64(sh32.Size)
sh.Offset = sh32.Offset
sh.Align = sh32.Align
sh.Reloff = sh32.Reloff
sh.Nreloc = sh32.Nreloc
sh.Flags = sh32.Flags
if err := f.pushSection(sh, r); err != nil {
return nil, err
}
}
case LoadCmdSegment64:
var seg64 Segment64
b := bytes.NewReader(cmddat)
if err := binary.Read(b, bo, &seg64); err != nil {
return nil, err
}
s = new(Segment)
s.LoadBytes = cmddat
s.Cmd = cmd
s.Len = siz
s.Name = cstring(seg64.Name[0:])
s.Addr = seg64.Addr
s.Memsz = seg64.Memsz
s.Offset = seg64.Offset
s.Filesz = seg64.Filesz
s.Maxprot = seg64.Maxprot
s.Prot = seg64.Prot
s.Nsect = seg64.Nsect
s.Flag = seg64.Flag
f.Loads = append(f.Loads, s)
for i := 0; i < int(s.Nsect); i++ {
var sh64 Section64
if err := binary.Read(b, bo, &sh64); err != nil {
return nil, err
}
sh := new(Section)
sh.Name = cstring(sh64.Name[0:])
sh.Seg = cstring(sh64.Seg[0:])
sh.Addr = sh64.Addr
sh.Size = sh64.Size
sh.Offset = sh64.Offset
sh.Align = sh64.Align
sh.Reloff = sh64.Reloff
sh.Nreloc = sh64.Nreloc
sh.Flags = sh64.Flags
if err := f.pushSection(sh, r); err != nil {
return nil, err
}
}
}
if s != nil {
if int64(s.Offset) < 0 {
return nil, &FormatError{offset, "invalid section offset", s.Offset}
}
if int64(s.Filesz) < 0 {
return nil, &FormatError{offset, "invalid section file size", s.Filesz}
}
s.sr = io.NewSectionReader(r, int64(s.Offset), int64(s.Filesz))
s.ReaderAt = s.sr
}
}
return f, nil
}
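// Illustrative sketch (not part of this package): Loads holds the concrete
// command types behind the Load interface, so callers typically use a type
// switch to pick out the commands they care about. Assumes f is a *File
// returned by NewFile or Open.
//
//	for _, l := range f.Loads {
//		switch cmd := l.(type) {
//		case *Segment:
//			fmt.Println("segment", cmd.Name)
//		case *Dylib:
//			fmt.Println("dylib", cmd.Name)
//		case LoadBytes:
//			// A command this package does not interpret; raw bytes only.
//		}
//	}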
func (f *File) parseSymtab(symdat, strtab, cmddat []byte, hdr *SymtabCmd, offset int64) (*Symtab, error) {
bo := f.ByteOrder
c := saferio.SliceCap((*Symbol)(nil), uint64(hdr.Nsyms))
if c < 0 {
return nil, &FormatError{offset, "too many symbols", nil}
}
symtab := make([]Symbol, 0, c)
b := bytes.NewReader(symdat)
for i := 0; i < int(hdr.Nsyms); i++ {
var n Nlist64
if f.Magic == Magic64 {
if err := binary.Read(b, bo, &n); err != nil {
return nil, err
}
} else {
var n32 Nlist32
if err := binary.Read(b, bo, &n32); err != nil {
return nil, err
}
n.Name = n32.Name
n.Type = n32.Type
n.Sect = n32.Sect
n.Desc = n32.Desc
n.Value = uint64(n32.Value)
}
if n.Name >= uint32(len(strtab)) {
return nil, &FormatError{offset, "invalid name in symbol table", n.Name}
}
// We add "_" to Go symbols. Strip it here. See issue 33808.
name := cstring(strtab[n.Name:])
if strings.Contains(name, ".") && name[0] == '_' {
name = name[1:]
}
symtab = append(symtab, Symbol{
Name: name,
Type: n.Type,
Sect: n.Sect,
Desc: n.Desc,
Value: n.Value,
})
}
st := new(Symtab)
st.LoadBytes = LoadBytes(cmddat)
st.Syms = symtab
return st, nil
}
type relocInfo struct {
Addr uint32
Symnum uint32
}
func (f *File) pushSection(sh *Section, r io.ReaderAt) error {
f.Sections = append(f.Sections, sh)
sh.sr = io.NewSectionReader(r, int64(sh.Offset), int64(sh.Size))
sh.ReaderAt = sh.sr
if sh.Nreloc > 0 {
reldat, err := saferio.ReadDataAt(r, uint64(sh.Nreloc)*8, int64(sh.Reloff))
if err != nil {
return err
}
b := bytes.NewReader(reldat)
bo := f.ByteOrder
sh.Relocs = make([]Reloc, sh.Nreloc)
for i := range sh.Relocs {
rel := &sh.Relocs[i]
var ri relocInfo
if err := binary.Read(b, bo, &ri); err != nil {
return err
}
if ri.Addr&(1<<31) != 0 { // scattered
rel.Addr = ri.Addr & (1<<24 - 1)
rel.Type = uint8((ri.Addr >> 24) & (1<<4 - 1))
rel.Len = uint8((ri.Addr >> 28) & (1<<2 - 1))
rel.Pcrel = ri.Addr&(1<<30) != 0
rel.Value = ri.Symnum
rel.Scattered = true
} else {
switch bo {
case binary.LittleEndian:
rel.Addr = ri.Addr
rel.Value = ri.Symnum & (1<<24 - 1)
rel.Pcrel = ri.Symnum&(1<<24) != 0
rel.Len = uint8((ri.Symnum >> 25) & (1<<2 - 1))
rel.Extern = ri.Symnum&(1<<27) != 0
rel.Type = uint8((ri.Symnum >> 28) & (1<<4 - 1))
case binary.BigEndian:
rel.Addr = ri.Addr
rel.Value = ri.Symnum >> 8
rel.Pcrel = ri.Symnum&(1<<7) != 0
rel.Len = uint8((ri.Symnum >> 5) & (1<<2 - 1))
rel.Extern = ri.Symnum&(1<<4) != 0
rel.Type = uint8(ri.Symnum & (1<<4 - 1))
default:
panic("unreachable")
}
}
}
}
return nil
}
func cstring(b []byte) string {
i := bytes.IndexByte(b, 0)
if i == -1 {
i = len(b)
}
return string(b[0:i])
}
// Segment returns the first Segment with the given name, or nil if no such segment exists.
func (f *File) Segment(name string) *Segment {
for _, l := range f.Loads {
if s, ok := l.(*Segment); ok && s.Name == name {
return s
}
}
return nil
}
// Section returns the first section with the given name, or nil if no such
// section exists.
func (f *File) Section(name string) *Section {
for _, s := range f.Sections {
if s.Name == name {
return s
}
}
return nil
}
// DWARF returns the DWARF debug information for the Mach-O file.
func (f *File) DWARF() (*dwarf.Data, error) {
dwarfSuffix := func(s *Section) string {
switch {
case strings.HasPrefix(s.Name, "__debug_"):
return s.Name[8:]
case strings.HasPrefix(s.Name, "__zdebug_"):
return s.Name[9:]
default:
return ""
}
}
sectionData := func(s *Section) ([]byte, error) {
b, err := s.Data()
if err != nil && uint64(len(b)) < s.Size {
return nil, err
}
if len(b) >= 12 && string(b[:4]) == "ZLIB" {
dlen := binary.BigEndian.Uint64(b[4:12])
dbuf := make([]byte, dlen)
r, err := zlib.NewReader(bytes.NewBuffer(b[12:]))
if err != nil {
return nil, err
}
if _, err := io.ReadFull(r, dbuf); err != nil {
return nil, err
}
if err := r.Close(); err != nil {
return nil, err
}
b = dbuf
}
return b, nil
}
// There are many other DWARF sections, but these
// are the ones the debug/dwarf package uses.
// Don't bother loading others.
var dat = map[string][]byte{"abbrev": nil, "info": nil, "str": nil, "line": nil, "ranges": nil}
for _, s := range f.Sections {
suffix := dwarfSuffix(s)
if suffix == "" {
continue
}
if _, ok := dat[suffix]; !ok {
continue
}
b, err := sectionData(s)
if err != nil {
return nil, err
}
dat[suffix] = b
}
d, err := dwarf.New(dat["abbrev"], nil, nil, dat["info"], dat["line"], nil, dat["ranges"], dat["str"])
if err != nil {
return nil, err
}
// Look for DWARF4 .debug_types sections and DWARF5 sections.
for i, s := range f.Sections {
suffix := dwarfSuffix(s)
if suffix == "" {
continue
}
if _, ok := dat[suffix]; ok {
// Already handled.
continue
}
b, err := sectionData(s)
if err != nil {
return nil, err
}
if suffix == "types" {
err = d.AddTypes(fmt.Sprintf("types-%d", i), b)
} else {
err = d.AddSection(".debug_"+suffix, b)
}
if err != nil {
return nil, err
}
}
return d, nil
}
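// Illustrative sketch (not part of this package): walking the compile units
// of the returned dwarf.Data with the standard debug/dwarf reader API.
//
//	d, err := f.DWARF()
//	if err != nil { /* handle error */ }
//	r := d.Reader()
//	for {
//		e, err := r.Next()
//		if err != nil || e == nil {
//			break
//		}
//		if e.Tag == dwarf.TagCompileUnit {
//			name, _ := e.Val(dwarf.AttrName).(string)
//			fmt.Println("compile unit:", name)
//		}
//	}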
// ImportedSymbols returns the names of all symbols
// referred to by the binary f that are expected to be
// satisfied by other libraries at dynamic load time.
func (f *File) ImportedSymbols() ([]string, error) {
if f.Dysymtab == nil || f.Symtab == nil {
return nil, &FormatError{0, "missing symbol table", nil}
}
st := f.Symtab
dt := f.Dysymtab
var all []string
for _, s := range st.Syms[dt.Iundefsym : dt.Iundefsym+dt.Nundefsym] {
all = append(all, s.Name)
}
return all, nil
}
// ImportedLibraries returns the paths of all libraries
// referred to by the binary f that are expected to be
// linked with the binary at dynamic link time.
func (f *File) ImportedLibraries() ([]string, error) {
var all []string
for _, l := range f.Loads {
if lib, ok := l.(*Dylib); ok {
all = append(all, lib.Name)
}
}
return all, nil
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Mach-O header data structures
// Originally at:
// http://developer.apple.com/mac/library/documentation/DeveloperTools/Conceptual/MachORuntime/Reference/reference.html (since deleted by Apple)
// Archived copy at:
// https://web.archive.org/web/20090819232456/http://developer.apple.com/documentation/DeveloperTools/Conceptual/MachORuntime/index.html
// For cloned PDF see:
// https://github.com/aidansteele/osx-abi-macho-file-format-reference
package macho
import "strconv"
// A FileHeader represents a Mach-O file header.
type FileHeader struct {
Magic uint32
Cpu Cpu
SubCpu uint32
Type Type
Ncmd uint32
Cmdsz uint32
Flags uint32
}
const (
fileHeaderSize32 = 7 * 4
fileHeaderSize64 = 8 * 4
)
const (
Magic32 uint32 = 0xfeedface
Magic64 uint32 = 0xfeedfacf
MagicFat uint32 = 0xcafebabe
)
// A Type is the Mach-O file type, e.g. an object file, executable, or dynamic library.
type Type uint32
const (
TypeObj Type = 1
TypeExec Type = 2
TypeDylib Type = 6
TypeBundle Type = 8
)
var typeStrings = []intName{
{uint32(TypeObj), "Obj"},
{uint32(TypeExec), "Exec"},
{uint32(TypeDylib), "Dylib"},
{uint32(TypeBundle), "Bundle"},
}
func (t Type) String() string { return stringName(uint32(t), typeStrings, false) }
func (t Type) GoString() string { return stringName(uint32(t), typeStrings, true) }
// A Cpu is a Mach-O cpu type.
type Cpu uint32
const cpuArch64 = 0x01000000
const (
Cpu386 Cpu = 7
CpuAmd64 Cpu = Cpu386 | cpuArch64
CpuArm Cpu = 12
CpuArm64 Cpu = CpuArm | cpuArch64
CpuPpc Cpu = 18
CpuPpc64 Cpu = CpuPpc | cpuArch64
)
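// Note: each 64-bit variant is its 32-bit counterpart with the cpuArch64
// flag bit set, e.g. CpuAmd64 == 7 | 0x01000000 == 0x01000007 and
// CpuArm64 == 12 | 0x01000000 == 0x0100000c.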
var cpuStrings = []intName{
{uint32(Cpu386), "Cpu386"},
{uint32(CpuAmd64), "CpuAmd64"},
{uint32(CpuArm), "CpuArm"},
{uint32(CpuArm64), "CpuArm64"},
{uint32(CpuPpc), "CpuPpc"},
{uint32(CpuPpc64), "CpuPpc64"},
}
func (i Cpu) String() string { return stringName(uint32(i), cpuStrings, false) }
func (i Cpu) GoString() string { return stringName(uint32(i), cpuStrings, true) }
// A LoadCmd is a Mach-O load command.
type LoadCmd uint32
const (
LoadCmdSegment LoadCmd = 0x1
LoadCmdSymtab LoadCmd = 0x2
LoadCmdThread LoadCmd = 0x4
LoadCmdUnixThread LoadCmd = 0x5 // thread+stack
LoadCmdDysymtab LoadCmd = 0xb
LoadCmdDylib LoadCmd = 0xc // load dylib command
LoadCmdDylinker LoadCmd = 0xf // id dylinker command (not load dylinker command)
LoadCmdSegment64 LoadCmd = 0x19
LoadCmdRpath LoadCmd = 0x8000001c
)
var cmdStrings = []intName{
{uint32(LoadCmdSegment), "LoadCmdSegment"},
{uint32(LoadCmdThread), "LoadCmdThread"},
{uint32(LoadCmdUnixThread), "LoadCmdUnixThread"},
{uint32(LoadCmdDylib), "LoadCmdDylib"},
{uint32(LoadCmdSegment64), "LoadCmdSegment64"},
{uint32(LoadCmdRpath), "LoadCmdRpath"},
}
func (i LoadCmd) String() string { return stringName(uint32(i), cmdStrings, false) }
func (i LoadCmd) GoString() string { return stringName(uint32(i), cmdStrings, true) }
type (
// A Segment32 is a 32-bit Mach-O segment load command.
Segment32 struct {
Cmd LoadCmd
Len uint32
Name [16]byte
Addr uint32
Memsz uint32
Offset uint32
Filesz uint32
Maxprot uint32
Prot uint32
Nsect uint32
Flag uint32
}
// A Segment64 is a 64-bit Mach-O segment load command.
Segment64 struct {
Cmd LoadCmd
Len uint32
Name [16]byte
Addr uint64
Memsz uint64
Offset uint64
Filesz uint64
Maxprot uint32
Prot uint32
Nsect uint32
Flag uint32
}
// A SymtabCmd is a Mach-O symbol table command.
SymtabCmd struct {
Cmd LoadCmd
Len uint32
Symoff uint32
Nsyms uint32
Stroff uint32
Strsize uint32
}
// A DysymtabCmd is a Mach-O dynamic symbol table command.
DysymtabCmd struct {
Cmd LoadCmd
Len uint32
Ilocalsym uint32
Nlocalsym uint32
Iextdefsym uint32
Nextdefsym uint32
Iundefsym uint32
Nundefsym uint32
Tocoffset uint32
Ntoc uint32
Modtaboff uint32
Nmodtab uint32
Extrefsymoff uint32
Nextrefsyms uint32
Indirectsymoff uint32
Nindirectsyms uint32
Extreloff uint32
Nextrel uint32
Locreloff uint32
Nlocrel uint32
}
// A DylibCmd is a Mach-O load dynamic library command.
DylibCmd struct {
Cmd LoadCmd
Len uint32
Name uint32
Time uint32
CurrentVersion uint32
CompatVersion uint32
}
// A RpathCmd is a Mach-O rpath command.
RpathCmd struct {
Cmd LoadCmd
Len uint32
Path uint32
}
// A Thread is a Mach-O thread state command.
Thread struct {
Cmd LoadCmd
Len uint32
Type uint32
Data []uint32
}
)
const (
FlagNoUndefs uint32 = 0x1
FlagIncrLink uint32 = 0x2
FlagDyldLink uint32 = 0x4
FlagBindAtLoad uint32 = 0x8
FlagPrebound uint32 = 0x10
FlagSplitSegs uint32 = 0x20
FlagLazyInit uint32 = 0x40
FlagTwoLevel uint32 = 0x80
FlagForceFlat uint32 = 0x100
FlagNoMultiDefs uint32 = 0x200
FlagNoFixPrebinding uint32 = 0x400
FlagPrebindable uint32 = 0x800
FlagAllModsBound uint32 = 0x1000
FlagSubsectionsViaSymbols uint32 = 0x2000
FlagCanonical uint32 = 0x4000
FlagWeakDefines uint32 = 0x8000
FlagBindsToWeak uint32 = 0x10000
FlagAllowStackExecution uint32 = 0x20000
FlagRootSafe uint32 = 0x40000
FlagSetuidSafe uint32 = 0x80000
FlagNoReexportedDylibs uint32 = 0x100000
FlagPIE uint32 = 0x200000
FlagDeadStrippableDylib uint32 = 0x400000
FlagHasTLVDescriptors uint32 = 0x800000
FlagNoHeapExecution uint32 = 0x1000000
FlagAppExtensionSafe uint32 = 0x2000000
)
// A Section32 is a 32-bit Mach-O section header.
type Section32 struct {
Name [16]byte
Seg [16]byte
Addr uint32
Size uint32
Offset uint32
Align uint32
Reloff uint32
Nreloc uint32
Flags uint32
Reserve1 uint32
Reserve2 uint32
}
// A Section64 is a 64-bit Mach-O section header.
type Section64 struct {
Name [16]byte
Seg [16]byte
Addr uint64
Size uint64
Offset uint32
Align uint32
Reloff uint32
Nreloc uint32
Flags uint32
Reserve1 uint32
Reserve2 uint32
Reserve3 uint32
}
// An Nlist32 is a Mach-O 32-bit symbol table entry.
type Nlist32 struct {
Name uint32
Type uint8
Sect uint8
Desc uint16
Value uint32
}
// An Nlist64 is a Mach-O 64-bit symbol table entry.
type Nlist64 struct {
Name uint32
Type uint8
Sect uint8
Desc uint16
Value uint64
}
// Regs386 is the Mach-O 386 register structure.
type Regs386 struct {
AX uint32
BX uint32
CX uint32
DX uint32
DI uint32
SI uint32
BP uint32
SP uint32
SS uint32
FLAGS uint32
IP uint32
CS uint32
DS uint32
ES uint32
FS uint32
GS uint32
}
// RegsAMD64 is the Mach-O AMD64 register structure.
type RegsAMD64 struct {
AX uint64
BX uint64
CX uint64
DX uint64
DI uint64
SI uint64
BP uint64
SP uint64
R8 uint64
R9 uint64
R10 uint64
R11 uint64
R12 uint64
R13 uint64
R14 uint64
R15 uint64
IP uint64
FLAGS uint64
CS uint64
FS uint64
GS uint64
}
type intName struct {
i uint32
s string
}
func stringName(i uint32, names []intName, goSyntax bool) string {
for _, n := range names {
if n.i == i {
if goSyntax {
return "macho." + n.s
}
return n.s
}
}
return strconv.FormatUint(uint64(i), 10)
}
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package macho
//go:generate stringer -type=RelocTypeGeneric,RelocTypeX86_64,RelocTypeARM,RelocTypeARM64 -output reloctype_string.go
type RelocTypeGeneric int
const (
GENERIC_RELOC_VANILLA RelocTypeGeneric = 0
GENERIC_RELOC_PAIR RelocTypeGeneric = 1
GENERIC_RELOC_SECTDIFF RelocTypeGeneric = 2
GENERIC_RELOC_PB_LA_PTR RelocTypeGeneric = 3
GENERIC_RELOC_LOCAL_SECTDIFF RelocTypeGeneric = 4
GENERIC_RELOC_TLV RelocTypeGeneric = 5
)
func (r RelocTypeGeneric) GoString() string { return "macho." + r.String() }
type RelocTypeX86_64 int
const (
X86_64_RELOC_UNSIGNED RelocTypeX86_64 = 0
X86_64_RELOC_SIGNED RelocTypeX86_64 = 1
X86_64_RELOC_BRANCH RelocTypeX86_64 = 2
X86_64_RELOC_GOT_LOAD RelocTypeX86_64 = 3
X86_64_RELOC_GOT RelocTypeX86_64 = 4
X86_64_RELOC_SUBTRACTOR RelocTypeX86_64 = 5
X86_64_RELOC_SIGNED_1 RelocTypeX86_64 = 6
X86_64_RELOC_SIGNED_2 RelocTypeX86_64 = 7
X86_64_RELOC_SIGNED_4 RelocTypeX86_64 = 8
X86_64_RELOC_TLV RelocTypeX86_64 = 9
)
func (r RelocTypeX86_64) GoString() string { return "macho." + r.String() }
type RelocTypeARM int
const (
ARM_RELOC_VANILLA RelocTypeARM = 0
ARM_RELOC_PAIR RelocTypeARM = 1
ARM_RELOC_SECTDIFF RelocTypeARM = 2
ARM_RELOC_LOCAL_SECTDIFF RelocTypeARM = 3
ARM_RELOC_PB_LA_PTR RelocTypeARM = 4
ARM_RELOC_BR24 RelocTypeARM = 5
ARM_THUMB_RELOC_BR22 RelocTypeARM = 6
ARM_THUMB_32BIT_BRANCH RelocTypeARM = 7
ARM_RELOC_HALF RelocTypeARM = 8
ARM_RELOC_HALF_SECTDIFF RelocTypeARM = 9
)
func (r RelocTypeARM) GoString() string { return "macho." + r.String() }
type RelocTypeARM64 int
const (
ARM64_RELOC_UNSIGNED RelocTypeARM64 = 0
ARM64_RELOC_SUBTRACTOR RelocTypeARM64 = 1
ARM64_RELOC_BRANCH26 RelocTypeARM64 = 2
ARM64_RELOC_PAGE21 RelocTypeARM64 = 3
ARM64_RELOC_PAGEOFF12 RelocTypeARM64 = 4
ARM64_RELOC_GOT_LOAD_PAGE21 RelocTypeARM64 = 5
ARM64_RELOC_GOT_LOAD_PAGEOFF12 RelocTypeARM64 = 6
ARM64_RELOC_POINTER_TO_GOT RelocTypeARM64 = 7
ARM64_RELOC_TLVP_LOAD_PAGE21 RelocTypeARM64 = 8
ARM64_RELOC_TLVP_LOAD_PAGEOFF12 RelocTypeARM64 = 9
ARM64_RELOC_ADDEND RelocTypeARM64 = 10
)
func (r RelocTypeARM64) GoString() string { return "macho." + r.String() }
// Code generated by "stringer -type=RelocTypeGeneric,RelocTypeX86_64,RelocTypeARM,RelocTypeARM64 -output reloctype_string.go"; DO NOT EDIT.
package macho
import "strconv"
const _RelocTypeGeneric_name = "GENERIC_RELOC_VANILLAGENERIC_RELOC_PAIRGENERIC_RELOC_SECTDIFFGENERIC_RELOC_PB_LA_PTRGENERIC_RELOC_LOCAL_SECTDIFFGENERIC_RELOC_TLV"
var _RelocTypeGeneric_index = [...]uint8{0, 21, 39, 61, 84, 112, 129}
func (i RelocTypeGeneric) String() string {
if i < 0 || i >= RelocTypeGeneric(len(_RelocTypeGeneric_index)-1) {
return "RelocTypeGeneric(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _RelocTypeGeneric_name[_RelocTypeGeneric_index[i]:_RelocTypeGeneric_index[i+1]]
}
const _RelocTypeX86_64_name = "X86_64_RELOC_UNSIGNEDX86_64_RELOC_SIGNEDX86_64_RELOC_BRANCHX86_64_RELOC_GOT_LOADX86_64_RELOC_GOTX86_64_RELOC_SUBTRACTORX86_64_RELOC_SIGNED_1X86_64_RELOC_SIGNED_2X86_64_RELOC_SIGNED_4X86_64_RELOC_TLV"
var _RelocTypeX86_64_index = [...]uint8{0, 21, 40, 59, 80, 96, 119, 140, 161, 182, 198}
func (i RelocTypeX86_64) String() string {
if i < 0 || i >= RelocTypeX86_64(len(_RelocTypeX86_64_index)-1) {
return "RelocTypeX86_64(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _RelocTypeX86_64_name[_RelocTypeX86_64_index[i]:_RelocTypeX86_64_index[i+1]]
}
const _RelocTypeARM_name = "ARM_RELOC_VANILLAARM_RELOC_PAIRARM_RELOC_SECTDIFFARM_RELOC_LOCAL_SECTDIFFARM_RELOC_PB_LA_PTRARM_RELOC_BR24ARM_THUMB_RELOC_BR22ARM_THUMB_32BIT_BRANCHARM_RELOC_HALFARM_RELOC_HALF_SECTDIFF"
var _RelocTypeARM_index = [...]uint8{0, 17, 31, 49, 73, 92, 106, 126, 148, 162, 185}
func (i RelocTypeARM) String() string {
if i < 0 || i >= RelocTypeARM(len(_RelocTypeARM_index)-1) {
return "RelocTypeARM(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _RelocTypeARM_name[_RelocTypeARM_index[i]:_RelocTypeARM_index[i+1]]
}
const _RelocTypeARM64_name = "ARM64_RELOC_UNSIGNEDARM64_RELOC_SUBTRACTORARM64_RELOC_BRANCH26ARM64_RELOC_PAGE21ARM64_RELOC_PAGEOFF12ARM64_RELOC_GOT_LOAD_PAGE21ARM64_RELOC_GOT_LOAD_PAGEOFF12ARM64_RELOC_POINTER_TO_GOTARM64_RELOC_TLVP_LOAD_PAGE21ARM64_RELOC_TLVP_LOAD_PAGEOFF12ARM64_RELOC_ADDEND"
var _RelocTypeARM64_index = [...]uint16{0, 20, 42, 62, 80, 101, 128, 158, 184, 212, 243, 261}
func (i RelocTypeARM64) String() string {
if i < 0 || i >= RelocTypeARM64(len(_RelocTypeARM64_index)-1) {
return "RelocTypeARM64(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _RelocTypeARM64_name[_RelocTypeARM64_index[i]:_RelocTypeARM64_index[i+1]]
}
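// Note on the generated tables above: each _name constant is the
// concatenation of all value names, and each _index array records the byte
// offset at which each name starts (plus a final end offset). String simply
// slices the shared constant. For example, for RelocTypeARM64(2):
//
//	_RelocTypeARM64_name[_RelocTypeARM64_index[2]:_RelocTypeARM64_index[3]]
//	// == _RelocTypeARM64_name[42:62] == "ARM64_RELOC_BRANCH26"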
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package pe implements access to PE (Microsoft Windows Portable Executable) files.
# Security
This package is not designed to be hardened against adversarial inputs, and is
outside the scope of https://go.dev/security/policy. In particular, only basic
validation is done when parsing object files. As such, care should be taken when
parsing untrusted inputs, as parsing malformed files may consume significant
resources, or cause panics.
*/
package pe
import (
"bytes"
"compress/zlib"
"debug/dwarf"
"encoding/binary"
"fmt"
"io"
"os"
"strings"
)
// Avoid use of post-Go 1.4 io features, to make safe for toolchain bootstrap.
const seekStart = 0
// A File represents an open PE file.
type File struct {
FileHeader
OptionalHeader any // of type *OptionalHeader32 or *OptionalHeader64
Sections []*Section
Symbols []*Symbol // COFF symbols with auxiliary symbol records removed
COFFSymbols []COFFSymbol // all COFF symbols (including auxiliary symbol records)
StringTable StringTable
closer io.Closer
}
// Open opens the named file using os.Open and prepares it for use as a PE binary.
func Open(name string) (*File, error) {
f, err := os.Open(name)
if err != nil {
return nil, err
}
ff, err := NewFile(f)
if err != nil {
f.Close()
return nil, err
}
ff.closer = f
return ff, nil
}
// Close closes the File.
// If the File was created using NewFile directly instead of Open,
// Close has no effect.
func (f *File) Close() error {
var err error
if f.closer != nil {
err = f.closer.Close()
f.closer = nil
}
return err
}
// TODO(brainman): add Load function, as a replacement for NewFile, that does not call removeAuxSymbols (for performance)
// NewFile creates a new File for accessing a PE binary in an underlying reader.
func NewFile(r io.ReaderAt) (*File, error) {
f := new(File)
sr := io.NewSectionReader(r, 0, 1<<63-1)
var dosheader [96]byte
if _, err := r.ReadAt(dosheader[0:], 0); err != nil {
return nil, err
}
var base int64
if dosheader[0] == 'M' && dosheader[1] == 'Z' {
signoff := int64(binary.LittleEndian.Uint32(dosheader[0x3c:]))
var sign [4]byte
r.ReadAt(sign[:], signoff)
if !(sign[0] == 'P' && sign[1] == 'E' && sign[2] == 0 && sign[3] == 0) {
return nil, fmt.Errorf("invalid PE file signature: % x", sign)
}
base = signoff + 4
} else {
base = int64(0)
}
sr.Seek(base, seekStart)
if err := binary.Read(sr, binary.LittleEndian, &f.FileHeader); err != nil {
return nil, err
}
switch f.FileHeader.Machine {
case IMAGE_FILE_MACHINE_AMD64,
IMAGE_FILE_MACHINE_ARM64,
IMAGE_FILE_MACHINE_ARMNT,
IMAGE_FILE_MACHINE_I386,
IMAGE_FILE_MACHINE_RISCV32,
IMAGE_FILE_MACHINE_RISCV64,
IMAGE_FILE_MACHINE_RISCV128,
IMAGE_FILE_MACHINE_UNKNOWN:
// ok
default:
return nil, fmt.Errorf("unrecognized PE machine: %#x", f.FileHeader.Machine)
}
var err error
// Read string table.
f.StringTable, err = readStringTable(&f.FileHeader, sr)
if err != nil {
return nil, err
}
// Read symbol table.
f.COFFSymbols, err = readCOFFSymbols(&f.FileHeader, sr)
if err != nil {
return nil, err
}
f.Symbols, err = removeAuxSymbols(f.COFFSymbols, f.StringTable)
if err != nil {
return nil, err
}
// Seek past file header.
_, err = sr.Seek(base+int64(binary.Size(f.FileHeader)), seekStart)
if err != nil {
return nil, err
}
// Read optional header.
f.OptionalHeader, err = readOptionalHeader(sr, f.FileHeader.SizeOfOptionalHeader)
if err != nil {
return nil, err
}
// Process sections.
f.Sections = make([]*Section, f.FileHeader.NumberOfSections)
for i := 0; i < int(f.FileHeader.NumberOfSections); i++ {
sh := new(SectionHeader32)
if err := binary.Read(sr, binary.LittleEndian, sh); err != nil {
return nil, err
}
name, err := sh.fullName(f.StringTable)
if err != nil {
return nil, err
}
s := new(Section)
s.SectionHeader = SectionHeader{
Name: name,
VirtualSize: sh.VirtualSize,
VirtualAddress: sh.VirtualAddress,
Size: sh.SizeOfRawData,
Offset: sh.PointerToRawData,
PointerToRelocations: sh.PointerToRelocations,
PointerToLineNumbers: sh.PointerToLineNumbers,
NumberOfRelocations: sh.NumberOfRelocations,
NumberOfLineNumbers: sh.NumberOfLineNumbers,
Characteristics: sh.Characteristics,
}
r2 := r
if sh.PointerToRawData == 0 { // .bss must have all 0s
r2 = zeroReaderAt{}
}
s.sr = io.NewSectionReader(r2, int64(s.SectionHeader.Offset), int64(s.SectionHeader.Size))
s.ReaderAt = s.sr
f.Sections[i] = s
}
for i := range f.Sections {
var err error
f.Sections[i].Relocs, err = readRelocs(&f.Sections[i].SectionHeader, sr)
if err != nil {
return nil, err
}
}
return f, nil
}
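// Illustrative sketch (not part of this package): opening a PE binary and
// listing its sections. The path is a placeholder.
//
//	f, err := Open(`C:\Windows\System32\notepad.exe`)
//	if err != nil { /* handle error */ }
//	defer f.Close()
//	for _, s := range f.Sections {
//		fmt.Printf("%-8s va=%#x size=%d\n", s.Name, s.VirtualAddress, s.Size)
//	}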
// zeroReaderAt is a ReaderAt that reads only zero bytes.
type zeroReaderAt struct{}
// ReadAt writes len(p) 0s into p.
func (w zeroReaderAt) ReadAt(p []byte, off int64) (n int, err error) {
for i := range p {
p[i] = 0
}
return len(p), nil
}
// getString extracts the NUL-terminated string that starts at
// offset start of section.
func getString(section []byte, start int) (string, bool) {
if start < 0 || start >= len(section) {
return "", false
}
for end := start; end < len(section); end++ {
if section[end] == 0 {
return string(section[start:end]), true
}
}
return "", false
}
// Section returns the first section with the given name, or nil if no such
// section exists.
func (f *File) Section(name string) *Section {
for _, s := range f.Sections {
if s.Name == name {
return s
}
}
return nil
}
func (f *File) DWARF() (*dwarf.Data, error) {
dwarfSuffix := func(s *Section) string {
switch {
case strings.HasPrefix(s.Name, ".debug_"):
return s.Name[7:]
case strings.HasPrefix(s.Name, ".zdebug_"):
return s.Name[8:]
default:
return ""
}
}
// sectionData gets the data for s and checks its size.
sectionData := func(s *Section) ([]byte, error) {
b, err := s.Data()
if err != nil && uint32(len(b)) < s.Size {
return nil, err
}
if 0 < s.VirtualSize && s.VirtualSize < s.Size {
b = b[:s.VirtualSize]
}
if len(b) >= 12 && string(b[:4]) == "ZLIB" {
dlen := binary.BigEndian.Uint64(b[4:12])
dbuf := make([]byte, dlen)
r, err := zlib.NewReader(bytes.NewBuffer(b[12:]))
if err != nil {
return nil, err
}
if _, err := io.ReadFull(r, dbuf); err != nil {
return nil, err
}
if err := r.Close(); err != nil {
return nil, err
}
b = dbuf
}
return b, nil
}
// There are many other DWARF sections, but these
// are the ones the debug/dwarf package uses.
// Don't bother loading others.
var dat = map[string][]byte{"abbrev": nil, "info": nil, "str": nil, "line": nil, "ranges": nil}
for _, s := range f.Sections {
suffix := dwarfSuffix(s)
if suffix == "" {
continue
}
if _, ok := dat[suffix]; !ok {
continue
}
b, err := sectionData(s)
if err != nil {
return nil, err
}
dat[suffix] = b
}
d, err := dwarf.New(dat["abbrev"], nil, nil, dat["info"], dat["line"], nil, dat["ranges"], dat["str"])
if err != nil {
return nil, err
}
// Look for DWARF4 .debug_types sections and DWARF5 sections.
for i, s := range f.Sections {
suffix := dwarfSuffix(s)
if suffix == "" {
continue
}
if _, ok := dat[suffix]; ok {
// Already handled.
continue
}
b, err := sectionData(s)
if err != nil {
return nil, err
}
if suffix == "types" {
err = d.AddTypes(fmt.Sprintf("types-%d", i), b)
} else {
err = d.AddSection(".debug_"+suffix, b)
}
if err != nil {
return nil, err
}
}
return d, nil
}
// TODO(brainman): document ImportDirectory once we decide what to do with it.
type ImportDirectory struct {
OriginalFirstThunk uint32
TimeDateStamp uint32
ForwarderChain uint32
Name uint32
FirstThunk uint32
dll string
}
// ImportedSymbols returns the names of all symbols
// referred to by the binary f that are expected to be
// satisfied by other libraries at dynamic load time.
// It does not return weak symbols.
func (f *File) ImportedSymbols() ([]string, error) {
if f.OptionalHeader == nil {
return nil, nil
}
_, pe64 := f.OptionalHeader.(*OptionalHeader64)
// grab the number of data directory entries
var dd_length uint32
if pe64 {
dd_length = f.OptionalHeader.(*OptionalHeader64).NumberOfRvaAndSizes
} else {
dd_length = f.OptionalHeader.(*OptionalHeader32).NumberOfRvaAndSizes
}
// check that the length of data directory entries is large
// enough to include the imports directory.
if dd_length < IMAGE_DIRECTORY_ENTRY_IMPORT+1 {
return nil, nil
}
// grab the import data directory entry
var idd DataDirectory
if pe64 {
idd = f.OptionalHeader.(*OptionalHeader64).DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT]
} else {
idd = f.OptionalHeader.(*OptionalHeader32).DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT]
}
// figure out which section contains the import directory table
var ds *Section
for _, s := range f.Sections {
// We are using distance between s.VirtualAddress and idd.VirtualAddress
// to avoid potential overflow of uint32 caused by addition of s.VirtualSize
// to s.VirtualAddress.
if s.VirtualAddress <= idd.VirtualAddress && idd.VirtualAddress-s.VirtualAddress < s.VirtualSize {
ds = s
break
}
}
// didn't find a section, so no import libraries were found
if ds == nil {
return nil, nil
}
d, err := ds.Data()
if err != nil {
return nil, err
}
// seek to the virtual address specified in the import data directory
d = d[idd.VirtualAddress-ds.VirtualAddress:]
// start decoding the import directory
var ida []ImportDirectory
for len(d) >= 20 {
var dt ImportDirectory
dt.OriginalFirstThunk = binary.LittleEndian.Uint32(d[0:4])
dt.TimeDateStamp = binary.LittleEndian.Uint32(d[4:8])
dt.ForwarderChain = binary.LittleEndian.Uint32(d[8:12])
dt.Name = binary.LittleEndian.Uint32(d[12:16])
dt.FirstThunk = binary.LittleEndian.Uint32(d[16:20])
d = d[20:]
if dt.OriginalFirstThunk == 0 {
break
}
ida = append(ida, dt)
}
// TODO(brainman): this needs to be rewritten.
// ds.Data() returns the contents of the section containing the import
// table, so why is it stored in a variable called "names"?
// Why is it retrieved a second time? We already have it in "d", and it
// is not modified anywhere.
// Why is ds.Data() called again and again inside the loop?
// Needs a test before any rewrite.
names, _ := ds.Data()
var all []string
for _, dt := range ida {
dt.dll, _ = getString(names, int(dt.Name-ds.VirtualAddress))
d, _ = ds.Data()
// seek to OriginalFirstThunk
d = d[dt.OriginalFirstThunk-ds.VirtualAddress:]
for len(d) > 0 {
if pe64 { // 64bit
va := binary.LittleEndian.Uint64(d[0:8])
d = d[8:]
if va == 0 {
break
}
if va&0x8000000000000000 > 0 { // is Ordinal
// TODO add dynimport ordinal support.
} else {
fn, _ := getString(names, int(uint32(va)-ds.VirtualAddress+2))
all = append(all, fn+":"+dt.dll)
}
} else { // 32bit
va := binary.LittleEndian.Uint32(d[0:4])
d = d[4:]
if va == 0 {
break
}
if va&0x80000000 > 0 { // is Ordinal
// TODO add dynimport ordinal support.
//ord := va&0x0000FFFF
} else {
fn, _ := getString(names, int(va-ds.VirtualAddress+2))
all = append(all, fn+":"+dt.dll)
}
}
}
}
return all, nil
}
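// Note: each entry returned above has the form "symbol:dll", e.g.
// "GetProcAddress:KERNEL32.dll". The +2 offsets in the lookups skip the
// two-byte ordinal hint that precedes the NUL-terminated name in the
// IMAGE_IMPORT_BY_NAME structure each thunk points at.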
// ImportedLibraries returns the names of all libraries
// referred to by the binary f that are expected to be
// linked with the binary at dynamic link time.
func (f *File) ImportedLibraries() ([]string, error) {
// TODO
// cgo -dynimport don't use this for windows PE, so just return.
return nil, nil
}
// FormatError is unused.
// The type is retained for compatibility.
type FormatError struct {
}
func (e *FormatError) Error() string {
return "unknown error"
}
// readOptionalHeader accepts an io.ReadSeeker pointing to the optional header
// in the PE file, together with the header size recorded in the file header.
// It parses that many bytes and returns the optional header, inferring whether
// the bytes describe the 32-bit or the 64-bit form of the header.
func readOptionalHeader(r io.ReadSeeker, sz uint16) (any, error) {
// If the optional header size is 0, the file has no optional header.
if sz == 0 {
return nil, nil
}
var (
// The first two bytes of the optional header encode its type.
// We must read them first to determine the type and
// validity of the optional header.
ohMagic uint16
ohMagicSz = binary.Size(ohMagic)
)
// If optional header size is greater than 0 but less than its magic size, return error.
if sz < uint16(ohMagicSz) {
return nil, fmt.Errorf("optional header size is less than optional header magic size")
}
// read reads from the io.ReadSeeker r into data.
var err error
read := func(data any) bool {
err = binary.Read(r, binary.LittleEndian, data)
return err == nil
}
if !read(&ohMagic) {
return nil, fmt.Errorf("failure to read optional header magic: %v", err)
}
switch ohMagic {
case 0x10b: // PE32
var (
oh32 OptionalHeader32
// There can be 0 or more data directories. So the minimum size of optional
// header is calculated by subtracting oh32.DataDirectory size from oh32 size.
oh32MinSz = binary.Size(oh32) - binary.Size(oh32.DataDirectory)
)
if sz < uint16(oh32MinSz) {
return nil, fmt.Errorf("optional header size(%d) is less minimum size (%d) of PE32 optional header", sz, oh32MinSz)
}
// Init oh32 fields
oh32.Magic = ohMagic
if !read(&oh32.MajorLinkerVersion) ||
!read(&oh32.MinorLinkerVersion) ||
!read(&oh32.SizeOfCode) ||
!read(&oh32.SizeOfInitializedData) ||
!read(&oh32.SizeOfUninitializedData) ||
!read(&oh32.AddressOfEntryPoint) ||
!read(&oh32.BaseOfCode) ||
!read(&oh32.BaseOfData) ||
!read(&oh32.ImageBase) ||
!read(&oh32.SectionAlignment) ||
!read(&oh32.FileAlignment) ||
!read(&oh32.MajorOperatingSystemVersion) ||
!read(&oh32.MinorOperatingSystemVersion) ||
!read(&oh32.MajorImageVersion) ||
!read(&oh32.MinorImageVersion) ||
!read(&oh32.MajorSubsystemVersion) ||
!read(&oh32.MinorSubsystemVersion) ||
!read(&oh32.Win32VersionValue) ||
!read(&oh32.SizeOfImage) ||
!read(&oh32.SizeOfHeaders) ||
!read(&oh32.CheckSum) ||
!read(&oh32.Subsystem) ||
!read(&oh32.DllCharacteristics) ||
!read(&oh32.SizeOfStackReserve) ||
!read(&oh32.SizeOfStackCommit) ||
!read(&oh32.SizeOfHeapReserve) ||
!read(&oh32.SizeOfHeapCommit) ||
!read(&oh32.LoaderFlags) ||
!read(&oh32.NumberOfRvaAndSizes) {
return nil, fmt.Errorf("failure to read PE32 optional header: %v", err)
}
dd, err := readDataDirectories(r, sz-uint16(oh32MinSz), oh32.NumberOfRvaAndSizes)
if err != nil {
return nil, err
}
copy(oh32.DataDirectory[:], dd)
return &oh32, nil
case 0x20b: // PE32+
var (
oh64 OptionalHeader64
// There can be 0 or more data directories. So the minimum size of optional
// header is calculated by subtracting oh64.DataDirectory size from oh64 size.
oh64MinSz = binary.Size(oh64) - binary.Size(oh64.DataDirectory)
)
if sz < uint16(oh64MinSz) {
return nil, fmt.Errorf("optional header size(%d) is less minimum size (%d) for PE32+ optional header", sz, oh64MinSz)
}
// Init oh64 fields
oh64.Magic = ohMagic
if !read(&oh64.MajorLinkerVersion) ||
!read(&oh64.MinorLinkerVersion) ||
!read(&oh64.SizeOfCode) ||
!read(&oh64.SizeOfInitializedData) ||
!read(&oh64.SizeOfUninitializedData) ||
!read(&oh64.AddressOfEntryPoint) ||
!read(&oh64.BaseOfCode) ||
!read(&oh64.ImageBase) ||
!read(&oh64.SectionAlignment) ||
!read(&oh64.FileAlignment) ||
!read(&oh64.MajorOperatingSystemVersion) ||
!read(&oh64.MinorOperatingSystemVersion) ||
!read(&oh64.MajorImageVersion) ||
!read(&oh64.MinorImageVersion) ||
!read(&oh64.MajorSubsystemVersion) ||
!read(&oh64.MinorSubsystemVersion) ||
!read(&oh64.Win32VersionValue) ||
!read(&oh64.SizeOfImage) ||
!read(&oh64.SizeOfHeaders) ||
!read(&oh64.CheckSum) ||
!read(&oh64.Subsystem) ||
!read(&oh64.DllCharacteristics) ||
!read(&oh64.SizeOfStackReserve) ||
!read(&oh64.SizeOfStackCommit) ||
!read(&oh64.SizeOfHeapReserve) ||
!read(&oh64.SizeOfHeapCommit) ||
!read(&oh64.LoaderFlags) ||
!read(&oh64.NumberOfRvaAndSizes) {
return nil, fmt.Errorf("failure to read PE32+ optional header: %v", err)
}
dd, err := readDataDirectories(r, sz-uint16(oh64MinSz), oh64.NumberOfRvaAndSizes)
if err != nil {
return nil, err
}
copy(oh64.DataDirectory[:], dd)
return &oh64, nil
default:
return nil, fmt.Errorf("optional header has unexpected Magic of 0x%x", ohMagic)
}
}
// readDataDirectories accepts an io.ReadSeeker positioned at the data
// directories in the PE file, their total size in bytes, and the number of
// data directories recorded in the optional header.
// It parses sz bytes and returns n data directories.
func readDataDirectories(r io.ReadSeeker, sz uint16, n uint32) ([]DataDirectory, error) {
ddSz := uint64(binary.Size(DataDirectory{}))
if uint64(sz) != uint64(n)*ddSz {
return nil, fmt.Errorf("size of data directories(%d) is inconsistent with number of data directories(%d)", sz, n)
}
dd := make([]DataDirectory, n)
if err := binary.Read(r, binary.LittleEndian, dd); err != nil {
return nil, fmt.Errorf("failure to read data directories: %v", err)
}
return dd, nil
}
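// A minimal client-side sketch (the file name is hypothetical): the parsed
// optional header surfaces through the public debug/pe API as either
// *OptionalHeader32 or *OptionalHeader64, so callers type-switch on it.
//
//	f, err := pe.Open("app.exe") // hypothetical path
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer f.Close()
//	switch oh := f.OptionalHeader.(type) {
//	case *pe.OptionalHeader32:
//		fmt.Printf("PE32 entry point: %#x\n", oh.AddressOfEntryPoint)
//	case *pe.OptionalHeader64:
//		fmt.Printf("PE32+ entry point: %#x\n", oh.AddressOfEntryPoint)
//	}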
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pe
import (
"encoding/binary"
"fmt"
"internal/saferio"
"io"
"strconv"
)
// SectionHeader32 represents a real PE COFF section header.
type SectionHeader32 struct {
Name [8]uint8
VirtualSize uint32
VirtualAddress uint32
SizeOfRawData uint32
PointerToRawData uint32
PointerToRelocations uint32
PointerToLineNumbers uint32
NumberOfRelocations uint16
NumberOfLineNumbers uint16
Characteristics uint32
}
// fullName returns the real name of section sh. Normally the name is stored
// in sh.Name, but if it is longer than 8 characters, it is stored
// in the COFF string table st instead.
func (sh *SectionHeader32) fullName(st StringTable) (string, error) {
if sh.Name[0] != '/' {
return cstring(sh.Name[:]), nil
}
i, err := strconv.Atoi(cstring(sh.Name[1:]))
if err != nil {
return "", err
}
return st.String(uint32(i))
}
// TODO(brainman): copy all IMAGE_REL_* consts from ldpe.go here
// Reloc represents a PE COFF relocation.
// Each section contains its own relocation list.
type Reloc struct {
VirtualAddress uint32
SymbolTableIndex uint32
Type uint16
}
func readRelocs(sh *SectionHeader, r io.ReadSeeker) ([]Reloc, error) {
if sh.NumberOfRelocations <= 0 {
return nil, nil
}
_, err := r.Seek(int64(sh.PointerToRelocations), seekStart)
if err != nil {
return nil, fmt.Errorf("fail to seek to %q section relocations: %v", sh.Name, err)
}
relocs := make([]Reloc, sh.NumberOfRelocations)
err = binary.Read(r, binary.LittleEndian, relocs)
if err != nil {
return nil, fmt.Errorf("fail to read section relocations: %v", err)
}
return relocs, nil
}
// SectionHeader is similar to SectionHeader32 with Name
// field replaced by Go string.
type SectionHeader struct {
Name string
VirtualSize uint32
VirtualAddress uint32
Size uint32
Offset uint32
PointerToRelocations uint32
PointerToLineNumbers uint32
NumberOfRelocations uint16
NumberOfLineNumbers uint16
Characteristics uint32
}
// Section provides access to a PE COFF section.
type Section struct {
SectionHeader
Relocs []Reloc
// Embed ReaderAt for ReadAt method.
// Do not embed SectionReader directly
// to avoid having Read and Seek.
// If a client wants Read and Seek it must use
// Open() to avoid fighting over the seek offset
// with other clients.
io.ReaderAt
sr *io.SectionReader
}
// Data reads and returns the contents of the PE section s.
func (s *Section) Data() ([]byte, error) {
return saferio.ReadDataAt(s.sr, uint64(s.Size), 0)
}
// Open returns a new ReadSeeker reading the PE section s.
func (s *Section) Open() io.ReadSeeker {
return io.NewSectionReader(s.sr, 0, 1<<63-1)
}
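// A client-side sketch (hypothetical binary name) reading one section's raw
// bytes through the public API:
//
//	f, err := pe.Open("app.exe")
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer f.Close()
//	if s := f.Section(".text"); s != nil {
//		data, err := s.Data() // whole section contents
//		if err != nil {
//			log.Fatal(err)
//		}
//		fmt.Println(len(data), "bytes")
//	}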
// Section characteristics flags.
const (
IMAGE_SCN_CNT_CODE = 0x00000020
IMAGE_SCN_CNT_INITIALIZED_DATA = 0x00000040
IMAGE_SCN_CNT_UNINITIALIZED_DATA = 0x00000080
IMAGE_SCN_LNK_COMDAT = 0x00001000
IMAGE_SCN_MEM_DISCARDABLE = 0x02000000
IMAGE_SCN_MEM_EXECUTE = 0x20000000
IMAGE_SCN_MEM_READ = 0x40000000
IMAGE_SCN_MEM_WRITE = 0x80000000
)
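// For example, a client can test these flags to find executable sections
// (a sketch; assumes f is an open *pe.File):
//
//	for _, s := range f.Sections {
//		if s.Characteristics&pe.IMAGE_SCN_MEM_EXECUTE != 0 {
//			fmt.Printf("%s is executable\n", s.Name)
//		}
//	}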
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pe
import (
"bytes"
"encoding/binary"
"fmt"
"internal/saferio"
"io"
)
// cstring converts the ASCII byte sequence b to a string.
// It stops at the first 0 byte or at the end of b.
func cstring(b []byte) string {
i := bytes.IndexByte(b, 0)
if i == -1 {
i = len(b)
}
return string(b[:i])
}
// StringTable is a COFF string table.
type StringTable []byte
func readStringTable(fh *FileHeader, r io.ReadSeeker) (StringTable, error) {
// COFF string table is located right after COFF symbol table.
if fh.PointerToSymbolTable <= 0 {
return nil, nil
}
offset := fh.PointerToSymbolTable + COFFSymbolSize*fh.NumberOfSymbols
_, err := r.Seek(int64(offset), seekStart)
if err != nil {
return nil, fmt.Errorf("fail to seek to string table: %v", err)
}
var l uint32
err = binary.Read(r, binary.LittleEndian, &l)
if err != nil {
return nil, fmt.Errorf("fail to read string table length: %v", err)
}
// string table length includes itself
if l <= 4 {
return nil, nil
}
l -= 4
buf, err := saferio.ReadData(r, uint64(l))
if err != nil {
return nil, fmt.Errorf("fail to read string table: %v", err)
}
return StringTable(buf), nil
}
// TODO(brainman): decide if start parameter should be int instead of uint32
// String extracts string from COFF string table st at offset start.
func (st StringTable) String(start uint32) (string, error) {
// start includes 4 bytes of string table length
if start < 4 {
return "", fmt.Errorf("offset %d is before the start of string table", start)
}
start -= 4
if int(start) > len(st) {
return "", fmt.Errorf("offset %d is beyond the end of string table", start)
}
return cstring(st[start:]), nil
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pe
import (
"encoding/binary"
"errors"
"fmt"
"internal/saferio"
"io"
"unsafe"
)
const COFFSymbolSize = 18
// COFFSymbol represents a single COFF symbol table record.
type COFFSymbol struct {
Name [8]uint8
Value uint32
SectionNumber int16
Type uint16
StorageClass uint8
NumberOfAuxSymbols uint8
}
// readCOFFSymbols reads in the symbol table for a PE file, returning
// a slice of COFFSymbol objects. The PE format includes both primary
// symbols (whose fields are described by COFFSymbol above) and
// auxiliary symbols; all symbols are 18 bytes in size. The auxiliary
// symbols for a given primary symbol are placed following it in the
// array, e.g.
//
// ...
// k+0: regular sym k
// k+1: 1st aux symbol for k
// k+2: 2nd aux symbol for k
// k+3: regular sym k+3
// k+4: 1st aux symbol for k+3
// k+5: regular sym k+5
// k+6: regular sym k+6
//
// The PE format allows for several possible aux symbol formats. For
// more info see:
//
// https://docs.microsoft.com/en-us/windows/win32/debug/pe-format#auxiliary-symbol-records
//
// At the moment this package only provides APIs for looking at
// aux symbols of format 5 (associated with section definition symbols).
func readCOFFSymbols(fh *FileHeader, r io.ReadSeeker) ([]COFFSymbol, error) {
if fh.PointerToSymbolTable == 0 {
return nil, nil
}
if fh.NumberOfSymbols <= 0 {
return nil, nil
}
_, err := r.Seek(int64(fh.PointerToSymbolTable), seekStart)
if err != nil {
return nil, fmt.Errorf("fail to seek to symbol table: %v", err)
}
c := saferio.SliceCap((*COFFSymbol)(nil), uint64(fh.NumberOfSymbols))
if c < 0 {
return nil, errors.New("too many symbols; file may be corrupt")
}
syms := make([]COFFSymbol, 0, c)
naux := 0
for k := uint32(0); k < fh.NumberOfSymbols; k++ {
var sym COFFSymbol
if naux == 0 {
// Read a primary symbol.
err = binary.Read(r, binary.LittleEndian, &sym)
if err != nil {
return nil, fmt.Errorf("fail to read symbol table: %v", err)
}
// Record how many auxiliary symbols it has.
naux = int(sym.NumberOfAuxSymbols)
} else {
// Read an aux symbol. At the moment we assume all
// aux symbols are format 5 (obviously this doesn't always
// hold; more cases will be needed below if more aux formats
// are supported in the future).
naux--
aux := (*COFFSymbolAuxFormat5)(unsafe.Pointer(&sym))
err = binary.Read(r, binary.LittleEndian, aux)
if err != nil {
return nil, fmt.Errorf("fail to read symbol table: %v", err)
}
}
syms = append(syms, sym)
}
if naux != 0 {
return nil, fmt.Errorf("fail to read symbol table: %d aux symbols unread", naux)
}
return syms, nil
}
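// A client-side sketch of walking the raw symbol table while skipping aux
// records (assumes f is an open *pe.File):
//
//	for i := 0; i < len(f.COFFSymbols); i++ {
//		sym := &f.COFFSymbols[i]
//		name, err := sym.FullName(f.StringTable)
//		if err != nil {
//			log.Fatal(err)
//		}
//		fmt.Println(name, sym.Value)
//		i += int(sym.NumberOfAuxSymbols) // jump over the aux records
//	}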
// isSymNameOffset reports whether the symbol name is encoded as an offset
// into the string table, returning that offset when it is.
func isSymNameOffset(name [8]byte) (bool, uint32) {
if name[0] == 0 && name[1] == 0 && name[2] == 0 && name[3] == 0 {
return true, binary.LittleEndian.Uint32(name[4:])
}
return false, 0
}
// FullName returns the real name of symbol sym. Normally the name is stored
// in sym.Name, but if it is longer than 8 characters, it is stored
// in the COFF string table st instead.
func (sym *COFFSymbol) FullName(st StringTable) (string, error) {
if ok, offset := isSymNameOffset(sym.Name); ok {
return st.String(offset)
}
return cstring(sym.Name[:]), nil
}
func removeAuxSymbols(allsyms []COFFSymbol, st StringTable) ([]*Symbol, error) {
if len(allsyms) == 0 {
return nil, nil
}
syms := make([]*Symbol, 0)
aux := uint8(0)
for _, sym := range allsyms {
if aux > 0 {
aux--
continue
}
name, err := sym.FullName(st)
if err != nil {
return nil, err
}
aux = sym.NumberOfAuxSymbols
s := &Symbol{
Name: name,
Value: sym.Value,
SectionNumber: sym.SectionNumber,
Type: sym.Type,
StorageClass: sym.StorageClass,
}
syms = append(syms, s)
}
return syms, nil
}
// Symbol is similar to COFFSymbol with Name field replaced
// by Go string. Symbol also does not have NumberOfAuxSymbols.
type Symbol struct {
Name string
Value uint32
SectionNumber int16
Type uint16
StorageClass uint8
}
// COFFSymbolAuxFormat5 describes the expected form of an aux symbol
// attached to a section definition symbol. The PE format defines a
// number of different aux symbol formats: format 1 for function
// definitions, format 2 for .be and .ef symbols, and so on. Format 5
// holds extra info associated with a section definition, including
// number of relocations + line numbers, as well as COMDAT info. See
// https://docs.microsoft.com/en-us/windows/win32/debug/pe-format#auxiliary-format-5-section-definitions
// for more on what's going on here.
type COFFSymbolAuxFormat5 struct {
Size uint32
NumRelocs uint16
NumLineNumbers uint16
Checksum uint32
SecNum uint16
Selection uint8
_ [3]uint8 // padding
}
// These constants make up the possible values for the 'Selection'
// field in an AuxFormat5.
const (
IMAGE_COMDAT_SELECT_NODUPLICATES = 1
IMAGE_COMDAT_SELECT_ANY = 2
IMAGE_COMDAT_SELECT_SAME_SIZE = 3
IMAGE_COMDAT_SELECT_EXACT_MATCH = 4
IMAGE_COMDAT_SELECT_ASSOCIATIVE = 5
IMAGE_COMDAT_SELECT_LARGEST = 6
)
// COFFSymbolReadSectionDefAux returns a blob of auxiliary information
// (including COMDAT info) for a section definition symbol. Here 'idx'
// is the index of a section symbol in the main COFFSymbol array for
// the File. Return value is a pointer to the appropriate aux symbol
// struct. For more info, see:
//
// auxiliary symbols: https://docs.microsoft.com/en-us/windows/win32/debug/pe-format#auxiliary-symbol-records
// COMDAT sections: https://docs.microsoft.com/en-us/windows/win32/debug/pe-format#comdat-sections-object-only
// auxiliary info for section definitions: https://docs.microsoft.com/en-us/windows/win32/debug/pe-format#auxiliary-format-5-section-definitions
func (f *File) COFFSymbolReadSectionDefAux(idx int) (*COFFSymbolAuxFormat5, error) {
var rv *COFFSymbolAuxFormat5
if idx < 0 || idx >= len(f.COFFSymbols) {
return rv, fmt.Errorf("invalid symbol index")
}
pesym := &f.COFFSymbols[idx]
const IMAGE_SYM_CLASS_STATIC = 3
if pesym.StorageClass != uint8(IMAGE_SYM_CLASS_STATIC) {
return rv, fmt.Errorf("incorrect symbol storage class")
}
if pesym.NumberOfAuxSymbols == 0 || idx+1 >= len(f.COFFSymbols) {
return rv, fmt.Errorf("aux symbol unavailable")
}
// Locate and return a pointer to the successor aux symbol.
pesymn := &f.COFFSymbols[idx+1]
rv = (*COFFSymbolAuxFormat5)(unsafe.Pointer(pesymn))
return rv, nil
}
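// Usage sketch (assumes f is an open *pe.File and idx is the index of a
// section definition symbol in f.COFFSymbols):
//
//	aux, err := f.COFFSymbolReadSectionDefAux(idx)
//	if err != nil {
//		log.Fatal(err)
//	}
//	if aux.Selection == pe.IMAGE_COMDAT_SELECT_ANY {
//		fmt.Println("pick-any COMDAT section")
//	}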
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package plan9obj implements access to Plan 9 a.out object files.
# Security
This package is not designed to be hardened against adversarial inputs, and is
outside the scope of https://go.dev/security/policy. In particular, only basic
validation is done when parsing object files. As such, care should be taken when
parsing untrusted inputs, as parsing malformed files may consume significant
resources, or cause panics.
*/
package plan9obj
import (
"encoding/binary"
"errors"
"fmt"
"internal/saferio"
"io"
"os"
)
// A FileHeader represents a Plan 9 a.out file header.
type FileHeader struct {
Magic uint32
Bss uint32
Entry uint64
PtrSize int
LoadAddress uint64
HdrSize uint64
}
// A File represents an open Plan 9 a.out file.
type File struct {
FileHeader
Sections []*Section
closer io.Closer
}
// A SectionHeader represents a single Plan 9 a.out section header.
// This structure doesn't exist on-disk, but eases navigation
// through the object file.
type SectionHeader struct {
Name string
Size uint32
Offset uint32
}
// A Section represents a single section in a Plan 9 a.out file.
type Section struct {
SectionHeader
// Embed ReaderAt for ReadAt method.
// Do not embed SectionReader directly
// to avoid having Read and Seek.
// If a client wants Read and Seek it must use
// Open() to avoid fighting over the seek offset
// with other clients.
io.ReaderAt
sr *io.SectionReader
}
// Data reads and returns the contents of the Plan 9 a.out section.
func (s *Section) Data() ([]byte, error) {
return saferio.ReadDataAt(s.sr, uint64(s.Size), 0)
}
// Open returns a new ReadSeeker reading the Plan 9 a.out section.
func (s *Section) Open() io.ReadSeeker { return io.NewSectionReader(s.sr, 0, 1<<63-1) }
// A Sym represents an entry in a Plan 9 a.out symbol table section.
type Sym struct {
Value uint64
Type rune
Name string
}
/*
* Plan 9 a.out reader
*/
// formatError is returned by some operations if the data does
// not have the correct format for an object file.
type formatError struct {
off int
msg string
val any
}
func (e *formatError) Error() string {
msg := e.msg
if e.val != nil {
msg += fmt.Sprintf(" '%v'", e.val)
}
msg += fmt.Sprintf(" in record at byte %#x", e.off)
return msg
}
// Open opens the named file using os.Open and prepares it for use as a Plan 9 a.out binary.
func Open(name string) (*File, error) {
f, err := os.Open(name)
if err != nil {
return nil, err
}
ff, err := NewFile(f)
if err != nil {
f.Close()
return nil, err
}
ff.closer = f
return ff, nil
}
// Close closes the File.
// If the File was created using NewFile directly instead of Open,
// Close has no effect.
func (f *File) Close() error {
var err error
if f.closer != nil {
err = f.closer.Close()
f.closer = nil
}
return err
}
func parseMagic(magic []byte) (uint32, error) {
m := binary.BigEndian.Uint32(magic)
switch m {
case Magic386, MagicAMD64, MagicARM:
return m, nil
}
return 0, &formatError{0, "bad magic number", magic}
}
// NewFile creates a new File for accessing a Plan 9 binary in an underlying reader.
// The Plan 9 binary is expected to start at position 0 in the ReaderAt.
func NewFile(r io.ReaderAt) (*File, error) {
sr := io.NewSectionReader(r, 0, 1<<63-1)
// Read and decode Plan 9 magic
var magic [4]byte
if _, err := r.ReadAt(magic[:], 0); err != nil {
return nil, err
}
_, err := parseMagic(magic[:])
if err != nil {
return nil, err
}
ph := new(prog)
if err := binary.Read(sr, binary.BigEndian, ph); err != nil {
return nil, err
}
f := &File{FileHeader: FileHeader{
Magic: ph.Magic,
Bss: ph.Bss,
Entry: uint64(ph.Entry),
PtrSize: 4,
LoadAddress: 0x1000,
HdrSize: 4 * 8,
}}
if ph.Magic&Magic64 != 0 {
if err := binary.Read(sr, binary.BigEndian, &f.Entry); err != nil {
return nil, err
}
f.PtrSize = 8
f.LoadAddress = 0x200000
f.HdrSize += 8
}
var sects = []struct {
name string
size uint32
}{
{"text", ph.Text},
{"data", ph.Data},
{"syms", ph.Syms},
{"spsz", ph.Spsz},
{"pcsz", ph.Pcsz},
}
f.Sections = make([]*Section, 5)
off := uint32(f.HdrSize)
for i, sect := range sects {
s := new(Section)
s.SectionHeader = SectionHeader{
Name: sect.name,
Size: sect.size,
Offset: off,
}
off += sect.size
s.sr = io.NewSectionReader(r, int64(s.Offset), int64(s.Size))
s.ReaderAt = s.sr
f.Sections[i] = s
}
return f, nil
}
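// A minimal client sketch (hypothetical path):
//
//	f, err := plan9obj.Open("a.out")
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer f.Close()
//	for _, s := range f.Sections {
//		fmt.Println(s.Name, s.Size, s.Offset)
//	}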
func walksymtab(data []byte, ptrsz int, fn func(sym) error) error {
var order binary.ByteOrder = binary.BigEndian
var s sym
p := data
for len(p) >= 4 {
// Symbol type, value.
if len(p) < ptrsz {
return &formatError{len(data), "unexpected EOF", nil}
}
// fixed-width value
if ptrsz == 8 {
s.value = order.Uint64(p[0:8])
p = p[8:]
} else {
s.value = uint64(order.Uint32(p[0:4]))
p = p[4:]
}
if len(p) < 1 {
return &formatError{len(data), "unexpected EOF", nil}
}
typ := p[0] & 0x7F
s.typ = typ
p = p[1:]
// Name.
var i int
var nnul int
for i = 0; i < len(p); i++ {
if p[i] == 0 {
nnul = 1
break
}
}
switch typ {
case 'z', 'Z':
p = p[i+nnul:]
for i = 0; i+2 <= len(p); i += 2 {
if p[i] == 0 && p[i+1] == 0 {
nnul = 2
break
}
}
}
if len(p) < i+nnul {
return &formatError{len(data), "unexpected EOF", nil}
}
s.name = p[0:i]
i += nnul
p = p[i:]
if err := fn(s); err != nil {
return err
}
}
return nil
}
// newTable decodes the Go symbol table in symtab,
// returning an in-memory representation.
func newTable(symtab []byte, ptrsz int) ([]Sym, error) {
var n int
err := walksymtab(symtab, ptrsz, func(s sym) error {
n++
return nil
})
if err != nil {
return nil, err
}
fname := make(map[uint16]string)
syms := make([]Sym, 0, n)
err = walksymtab(symtab, ptrsz, func(s sym) error {
n := len(syms)
syms = syms[0 : n+1]
ts := &syms[n]
ts.Type = rune(s.typ)
ts.Value = s.value
switch s.typ {
default:
ts.Name = string(s.name)
case 'z', 'Z':
for i := 0; i < len(s.name); i += 2 {
eltIdx := binary.BigEndian.Uint16(s.name[i : i+2])
elt, ok := fname[eltIdx]
if !ok {
return &formatError{-1, "bad filename code", eltIdx}
}
if n := len(ts.Name); n > 0 && ts.Name[n-1] != '/' {
ts.Name += "/"
}
ts.Name += elt
}
}
switch s.typ {
case 'f':
fname[uint16(s.value)] = ts.Name
}
return nil
})
if err != nil {
return nil, err
}
return syms, nil
}
// ErrNoSymbols is returned by File.Symbols if there is no such section
// in the File.
var ErrNoSymbols = errors.New("no symbol section")
// Symbols returns the symbol table for f.
func (f *File) Symbols() ([]Sym, error) {
symtabSection := f.Section("syms")
if symtabSection == nil {
return nil, ErrNoSymbols
}
symtab, err := symtabSection.Data()
if err != nil {
return nil, errors.New("cannot load symbol section")
}
return newTable(symtab, f.PtrSize)
}
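// For example (a client sketch; assumes f is an open *plan9obj.File):
//
//	syms, err := f.Symbols()
//	if err != nil {
//		log.Fatal(err)
//	}
//	for _, s := range syms {
//		fmt.Printf("%c %#x %s\n", s.Type, s.Value, s.Name)
//	}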
// Section returns a section with the given name, or nil if no such
// section exists.
func (f *File) Section(name string) *Section {
for _, s := range f.Sections {
if s.Name == name {
return s
}
}
return nil
}
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package embed provides access to files embedded in the running Go program.
//
// Go source files that import "embed" can use the //go:embed directive
// to initialize a variable of type string, []byte, or FS with the contents of
// files read from the package directory or subdirectories at compile time.
//
// For example, here are three ways to embed a file named hello.txt
// and then print its contents at run time.
//
// Embedding one file into a string:
//
// import _ "embed"
//
// //go:embed hello.txt
// var s string
// print(s)
//
// Embedding one file into a slice of bytes:
//
// import _ "embed"
//
// //go:embed hello.txt
// var b []byte
// print(string(b))
//
// Embedding one or more files into a file system:
//
// import "embed"
//
// //go:embed hello.txt
// var f embed.FS
// data, _ := f.ReadFile("hello.txt")
// print(string(data))
//
// # Directives
//
// A //go:embed directive above a variable declaration specifies which files to embed,
// using one or more path.Match patterns.
//
// The directive must immediately precede a line containing the declaration of a single variable.
// Only blank lines and ‘//’ line comments are permitted between the directive and the declaration.
//
// The type of the variable must be a string type, or a slice of a byte type,
// or FS (or an alias of FS).
//
// For example:
//
// package server
//
// import "embed"
//
// // content holds our static web server content.
// //go:embed image/* template/*
// //go:embed html/index.html
// var content embed.FS
//
// The Go build system will recognize the directives and arrange for the declared variable
// (in the example above, content) to be populated with the matching files from the file system.
//
// The //go:embed directive accepts multiple space-separated patterns for
// brevity, but it can also be repeated, to avoid very long lines when there are
// many patterns. The patterns are interpreted relative to the package directory
// containing the source file. The path separator is a forward slash, even on
// Windows systems. Patterns may not contain ‘.’ or ‘..’ or empty path elements,
// nor may they begin or end with a slash. To match everything in the current
// directory, use ‘*’ instead of ‘.’. To allow for naming files with spaces in
// their names, patterns can be written as Go double-quoted or back-quoted
// string literals.
//
// If a pattern names a directory, all files in the subtree rooted at that directory are
// embedded (recursively), except that files with names beginning with ‘.’ or ‘_’
// are excluded. So the variable in the above example is almost equivalent to:
//
// // content is our static web server content.
// //go:embed image template html/index.html
// var content embed.FS
//
// The difference is that ‘image/*’ embeds ‘image/.tempfile’ while ‘image’ does not.
// Neither embeds ‘image/dir/.tempfile’.
//
// If a pattern begins with the prefix ‘all:’, then the rule for walking directories is changed
// to include those files beginning with ‘.’ or ‘_’. For example, ‘all:image’ embeds
// both ‘image/.tempfile’ and ‘image/dir/.tempfile’.
//
// The //go:embed directive can be used with both exported and unexported variables,
// depending on whether the package wants to make the data available to other packages.
// It can only be used with variables at package scope, not with local variables.
//
// Patterns must not match files outside the package's module, such as ‘.git/*’ or symbolic links.
// Patterns must not match files whose names include the special punctuation characters " * < > ? ` ' | / \ and :.
// Matches for empty directories are ignored. After that, each pattern in a //go:embed line
// must match at least one file or non-empty directory.
//
// If any patterns are invalid or have invalid matches, the build will fail.
//
// # Strings and Bytes
//
// The //go:embed line for a variable of type string or []byte can have only a single pattern,
// and that pattern can match only a single file. The string or []byte is initialized with
// the contents of that file.
//
// The //go:embed directive requires importing "embed", even when using a string or []byte.
// In source files that don't refer to embed.FS, use a blank import (import _ "embed").
//
// # File Systems
//
// For embedding a single file, a variable of type string or []byte is often best.
// The FS type enables embedding a tree of files, such as a directory of static
// web server content, as in the example above.
//
// FS implements the io/fs package's FS interface, so it can be used with any package that
// understands file systems, including net/http, text/template, and html/template.
//
// For example, given the content variable in the example above, we can write:
//
// http.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.FS(content))))
//
// template.ParseFS(content, "*.tmpl")
//
// # Tools
//
// To support tools that analyze Go packages, the patterns found in //go:embed lines
// are available in “go list” output. See the EmbedPatterns, TestEmbedPatterns,
// and XTestEmbedPatterns fields in the “go help list” output.
package embed
import (
"errors"
"io"
"io/fs"
"time"
)
// An FS is a read-only collection of files, usually initialized with a //go:embed directive.
// When declared without a //go:embed directive, an FS is an empty file system.
//
// An FS is a read-only value, so it is safe to use from multiple goroutines
// simultaneously and also safe to assign values of type FS to each other.
//
// FS implements fs.FS, so it can be used with any package that understands
// file system interfaces, including net/http, text/template, and html/template.
//
// See the package documentation for more details about initializing an FS.
type FS struct {
// The compiler knows the layout of this struct.
// See cmd/compile/internal/staticdata's WriteEmbed.
//
// The files list is sorted by name but not by simple string comparison.
// Instead, each file's name takes the form "dir/elem" or "dir/elem/".
// The optional trailing slash indicates that the file is itself a directory.
// The files list is sorted first by dir (if dir is missing, it is taken to be ".")
// and then by base, so this list of files:
//
// p
// q/
// q/r
// q/s/
// q/s/t
// q/s/u
// q/v
// w
//
// is actually sorted as:
//
// p # dir=. elem=p
// q/ # dir=. elem=q
// w # dir=. elem=w
// q/r # dir=q elem=r
// q/s/ # dir=q elem=s
// q/v # dir=q elem=v
// q/s/t # dir=q/s elem=t
// q/s/u # dir=q/s elem=u
//
// This order brings directory contents together in contiguous sections
// of the list, allowing a directory read to use binary search to find
// the relevant sequence of entries.
files *[]file
}
// split splits the name into dir and elem as described in the
// comment in the FS struct above. isDir reports whether the
// final trailing slash was present, indicating that name is a directory.
func split(name string) (dir, elem string, isDir bool) {
if name[len(name)-1] == '/' {
isDir = true
name = name[:len(name)-1]
}
i := len(name) - 1
for i >= 0 && name[i] != '/' {
i--
}
if i < 0 {
return ".", name, isDir
}
return name[:i], name[i+1:], isDir
}
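// For illustration: split("q/s/t") returns ("q/s", "t", false),
// split("q/s/") returns ("q", "s", true), and split("p") returns
// (".", "p", false).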
// trimSlash trims a trailing slash from name, if present,
// returning the possibly shortened name.
func trimSlash(name string) string {
if len(name) > 0 && name[len(name)-1] == '/' {
return name[:len(name)-1]
}
return name
}
var (
_ fs.ReadDirFS = FS{}
_ fs.ReadFileFS = FS{}
)
// A file is a single file in the FS.
// It implements fs.FileInfo and fs.DirEntry.
type file struct {
// The compiler knows the layout of this struct.
// See cmd/compile/internal/staticdata's WriteEmbed.
name string
data string
hash [16]byte // truncated SHA256 hash
}
var (
_ fs.FileInfo = (*file)(nil)
_ fs.DirEntry = (*file)(nil)
)
func (f *file) Name() string { _, elem, _ := split(f.name); return elem }
func (f *file) Size() int64 { return int64(len(f.data)) }
func (f *file) ModTime() time.Time { return time.Time{} }
func (f *file) IsDir() bool { _, _, isDir := split(f.name); return isDir }
func (f *file) Sys() any { return nil }
func (f *file) Type() fs.FileMode { return f.Mode().Type() }
func (f *file) Info() (fs.FileInfo, error) { return f, nil }
func (f *file) Mode() fs.FileMode {
if f.IsDir() {
return fs.ModeDir | 0555
}
return 0444
}
// dotFile is a file for the root directory,
// which is omitted from the files list in a FS.
var dotFile = &file{name: "./"}
// lookup returns the named file, or nil if it is not present.
func (f FS) lookup(name string) *file {
if !fs.ValidPath(name) {
// The compiler should never emit a file with an invalid name,
// so this check is not strictly necessary (if name is invalid,
// we shouldn't find a match below), but it's a good backstop anyway.
return nil
}
if name == "." {
return dotFile
}
if f.files == nil {
return nil
}
// Binary search to find where name would be in the list,
// and then check if name is at that position.
dir, elem, _ := split(name)
files := *f.files
i := sortSearch(len(files), func(i int) bool {
idir, ielem, _ := split(files[i].name)
return idir > dir || idir == dir && ielem >= elem
})
if i < len(files) && trimSlash(files[i].name) == name {
return &files[i]
}
return nil
}
// readDir returns the list of files corresponding to the directory dir.
func (f FS) readDir(dir string) []file {
if f.files == nil {
return nil
}
// Binary search to find where dir starts and ends in the list
// and then return that slice of the list.
files := *f.files
i := sortSearch(len(files), func(i int) bool {
idir, _, _ := split(files[i].name)
return idir >= dir
})
j := sortSearch(len(files), func(j int) bool {
jdir, _, _ := split(files[j].name)
return jdir > dir
})
return files[i:j]
}
// Open opens the named file for reading and returns it as an fs.File.
//
// The returned file implements io.Seeker when the file is not a directory.
func (f FS) Open(name string) (fs.File, error) {
file := f.lookup(name)
if file == nil {
return nil, &fs.PathError{Op: "open", Path: name, Err: fs.ErrNotExist}
}
if file.IsDir() {
return &openDir{file, f.readDir(name), 0}, nil
}
return &openFile{file, 0}, nil
}
// ReadDir reads and returns the entire named directory.
func (f FS) ReadDir(name string) ([]fs.DirEntry, error) {
file, err := f.Open(name)
if err != nil {
return nil, err
}
dir, ok := file.(*openDir)
if !ok {
return nil, &fs.PathError{Op: "read", Path: name, Err: errors.New("not a directory")}
}
list := make([]fs.DirEntry, len(dir.files))
for i := range list {
list[i] = &dir.files[i]
}
return list, nil
}
// ReadFile reads and returns the content of the named file.
func (f FS) ReadFile(name string) ([]byte, error) {
file, err := f.Open(name)
if err != nil {
return nil, err
}
ofile, ok := file.(*openFile)
if !ok {
return nil, &fs.PathError{Op: "read", Path: name, Err: errors.New("is a directory")}
}
return []byte(ofile.f.data), nil
}
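// A client-side sketch (the embedded file name is hypothetical):
//
//	//go:embed hello.txt
//	var content embed.FS
//
//	func show() {
//		data, err := content.ReadFile("hello.txt")
//		if err != nil {
//			log.Fatal(err)
//		}
//		fmt.Print(string(data))
//	}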
// An openFile is a regular file open for reading.
type openFile struct {
f *file // the file itself
offset int64 // current read offset
}
var (
_ io.Seeker = (*openFile)(nil)
)
func (f *openFile) Close() error { return nil }
func (f *openFile) Stat() (fs.FileInfo, error) { return f.f, nil }
func (f *openFile) Read(b []byte) (int, error) {
if f.offset >= int64(len(f.f.data)) {
return 0, io.EOF
}
if f.offset < 0 {
return 0, &fs.PathError{Op: "read", Path: f.f.name, Err: fs.ErrInvalid}
}
n := copy(b, f.f.data[f.offset:])
f.offset += int64(n)
return n, nil
}
func (f *openFile) Seek(offset int64, whence int) (int64, error) {
switch whence {
case 0:
// offset += 0
case 1:
offset += f.offset
case 2:
offset += int64(len(f.f.data))
}
if offset < 0 || offset > int64(len(f.f.data)) {
return 0, &fs.PathError{Op: "seek", Path: f.f.name, Err: fs.ErrInvalid}
}
f.offset = offset
return offset, nil
}
// An openDir is a directory open for reading.
type openDir struct {
f *file // the directory file itself
files []file // the directory contents
offset int // the read offset, an index into the files slice
}
func (d *openDir) Close() error { return nil }
func (d *openDir) Stat() (fs.FileInfo, error) { return d.f, nil }
func (d *openDir) Read([]byte) (int, error) {
return 0, &fs.PathError{Op: "read", Path: d.f.name, Err: errors.New("is a directory")}
}
func (d *openDir) ReadDir(count int) ([]fs.DirEntry, error) {
n := len(d.files) - d.offset
if n == 0 {
if count <= 0 {
return nil, nil
}
return nil, io.EOF
}
if count > 0 && n > count {
n = count
}
list := make([]fs.DirEntry, n)
for i := range list {
list[i] = &d.files[d.offset+i]
}
d.offset += n
return list, nil
}
// sortSearch is like sort.Search, avoiding an import.
func sortSearch(n int, f func(int) bool) int {
// Define f(-1) == false and f(n) == true.
// Invariant: f(i-1) == false, f(j) == true.
i, j := 0, n
for i < j {
h := int(uint(i+j) >> 1) // avoid overflow when computing h
// i ≤ h < j
if !f(h) {
i = h + 1 // preserves f(i-1) == false
} else {
j = h // preserves f(j) == true
}
}
// i == j, f(i-1) == false, and f(j) (= f(i)) == true => answer is i.
return i
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package ascii85 implements the ascii85 data encoding
// as used in the btoa tool and Adobe's PostScript and PDF document formats.
package ascii85
import (
"io"
"strconv"
)
/*
* Encoder
*/
// Encode encodes src into at most MaxEncodedLen(len(src))
// bytes of dst, returning the actual number of bytes written.
//
// The encoding handles 4-byte chunks, using a special encoding
// for the last fragment, so Encode is not appropriate for use on
// individual blocks of a large data stream. Use NewEncoder() instead.
//
// Often, ascii85-encoded data is wrapped in <~ and ~> symbols.
// Encode does not add these.
func Encode(dst, src []byte) int {
if len(src) == 0 {
return 0
}
n := 0
for len(src) > 0 {
dst[0] = 0
dst[1] = 0
dst[2] = 0
dst[3] = 0
dst[4] = 0
// Unpack 4 bytes into uint32 to repack into base 85 5-byte.
var v uint32
switch len(src) {
default:
v |= uint32(src[3])
fallthrough
case 3:
v |= uint32(src[2]) << 8
fallthrough
case 2:
v |= uint32(src[1]) << 16
fallthrough
case 1:
v |= uint32(src[0]) << 24
}
// Special case: zero (!!!!!) shortens to z.
if v == 0 && len(src) >= 4 {
dst[0] = 'z'
dst = dst[1:]
src = src[4:]
n++
continue
}
// Otherwise, 5 base 85 digits starting at !.
for i := 4; i >= 0; i-- {
dst[i] = '!' + byte(v%85)
v /= 85
}
// If src was short, discard the low destination bytes.
m := 5
if len(src) < 4 {
m -= 4 - len(src)
src = nil
} else {
src = src[4:]
}
dst = dst[m:]
n += m
}
return n
}
// MaxEncodedLen returns the maximum length of an encoding of n source bytes.
func MaxEncodedLen(n int) int { return (n + 3) / 4 * 5 }
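// A one-shot client sketch pairing Encode with MaxEncodedLen:
//
//	src := []byte("Hello, world")
//	dst := make([]byte, ascii85.MaxEncodedLen(len(src)))
//	n := ascii85.Encode(dst, src)
//	fmt.Printf("%s\n", dst[:n]) // encoded text, without <~ and ~> wrappers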
// NewEncoder returns a new ascii85 stream encoder. Data written to
// the returned writer will be encoded and then written to w.
// Ascii85 encodings operate in 32-bit blocks; when finished
// writing, the caller must Close the returned encoder to flush any
// trailing partial block.
func NewEncoder(w io.Writer) io.WriteCloser { return &encoder{w: w} }
type encoder struct {
err error
w io.Writer
buf [4]byte // buffered data waiting to be encoded
nbuf int // number of bytes in buf
out [1024]byte // output buffer
}
func (e *encoder) Write(p []byte) (n int, err error) {
if e.err != nil {
return 0, e.err
}
// Leading fringe.
if e.nbuf > 0 {
var i int
for i = 0; i < len(p) && e.nbuf < 4; i++ {
e.buf[e.nbuf] = p[i]
e.nbuf++
}
n += i
p = p[i:]
if e.nbuf < 4 {
return
}
nout := Encode(e.out[0:], e.buf[0:])
if _, e.err = e.w.Write(e.out[0:nout]); e.err != nil {
return n, e.err
}
e.nbuf = 0
}
// Large interior chunks.
for len(p) >= 4 {
nn := len(e.out) / 5 * 4
if nn > len(p) {
nn = len(p)
}
nn -= nn % 4
if nn > 0 {
nout := Encode(e.out[0:], p[0:nn])
if _, e.err = e.w.Write(e.out[0:nout]); e.err != nil {
return n, e.err
}
}
n += nn
p = p[nn:]
}
// Trailing fringe.
copy(e.buf[:], p)
e.nbuf = len(p)
n += len(p)
return
}
// Close flushes any pending output from the encoder.
// It is an error to call Write after calling Close.
func (e *encoder) Close() error {
// If there's anything left in the buffer, flush it out
if e.err == nil && e.nbuf > 0 {
nout := Encode(e.out[0:], e.buf[0:e.nbuf])
e.nbuf = 0
_, e.err = e.w.Write(e.out[0:nout])
}
return e.err
}
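// Streaming client sketch; Close must be called to flush the final
// partial block:
//
//	var buf bytes.Buffer
//	w := ascii85.NewEncoder(&buf)
//	if _, err := w.Write([]byte("some data")); err != nil {
//		log.Fatal(err)
//	}
//	if err := w.Close(); err != nil {
//		log.Fatal(err)
//	}
//	// buf now holds the complete encoding.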
/*
* Decoder
*/
type CorruptInputError int64
func (e CorruptInputError) Error() string {
return "illegal ascii85 data at input byte " + strconv.FormatInt(int64(e), 10)
}
// Decode decodes src into dst, returning both the number
// of bytes written to dst and the number consumed from src.
// If src contains invalid ascii85 data, Decode will return the
// number of bytes successfully written and a CorruptInputError.
// Decode ignores space and control characters in src.
// Often, ascii85-encoded data is wrapped in <~ and ~> symbols.
// Decode expects these to have been stripped by the caller.
//
// If flush is true, Decode assumes that src represents the
// end of the input stream and processes it completely rather
// than wait for the completion of another 32-bit block.
//
// NewDecoder wraps an io.Reader interface around Decode.
func Decode(dst, src []byte, flush bool) (ndst, nsrc int, err error) {
var v uint32
var nb int
for i, b := range src {
if len(dst)-ndst < 4 {
return
}
switch {
case b <= ' ':
continue
case b == 'z' && nb == 0:
nb = 5
v = 0
case '!' <= b && b <= 'u':
v = v*85 + uint32(b-'!')
nb++
default:
return 0, 0, CorruptInputError(i)
}
if nb == 5 {
nsrc = i + 1
dst[ndst] = byte(v >> 24)
dst[ndst+1] = byte(v >> 16)
dst[ndst+2] = byte(v >> 8)
dst[ndst+3] = byte(v)
ndst += 4
nb = 0
v = 0
}
}
if flush {
nsrc = len(src)
if nb > 0 {
// The number of output bytes in the last fragment
// is the number of leftover input bytes - 1:
// the extra byte provides enough bits to cover
// the inefficiency of the encoding for the block.
if nb == 1 {
return 0, 0, CorruptInputError(len(src))
}
for i := nb; i < 5; i++ {
// The short encoding truncated the output value.
// We have to assume the worst case values (digit 84)
// in order to ensure that the top bits are correct.
v = v*85 + 84
}
for i := 0; i < nb-1; i++ {
dst[ndst] = byte(v >> 24)
v <<= 8
ndst++
}
}
}
return
}
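// A round-trip client sketch with flush=true (input complete, any <~ ~>
// wrappers already stripped):
//
//	src := []byte("arbitrary payload")
//	enc := make([]byte, ascii85.MaxEncodedLen(len(src)))
//	ne := ascii85.Encode(enc, src)
//	dec := make([]byte, len(src)+4) // headroom for 4-byte block writes
//	ndst, _, err := ascii85.Decode(dec, enc[:ne], true)
//	if err != nil {
//		log.Fatal(err)
//	}
//	// dec[:ndst] equals src.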
// NewDecoder constructs a new ascii85 stream decoder.
func NewDecoder(r io.Reader) io.Reader { return &decoder{r: r} }
type decoder struct {
err error
readErr error
r io.Reader
buf [1024]byte // leftover input
nbuf int
out []byte // leftover decoded output
outbuf [1024]byte
}
func (d *decoder) Read(p []byte) (n int, err error) {
if len(p) == 0 {
return 0, nil
}
if d.err != nil {
return 0, d.err
}
for {
// Copy leftover output from last decode.
if len(d.out) > 0 {
n = copy(p, d.out)
d.out = d.out[n:]
return
}
// Decode leftover input from last read.
var nn, nsrc, ndst int
if d.nbuf > 0 {
ndst, nsrc, d.err = Decode(d.outbuf[0:], d.buf[0:d.nbuf], d.readErr != nil)
if ndst > 0 {
d.out = d.outbuf[0:ndst]
d.nbuf = copy(d.buf[0:], d.buf[nsrc:d.nbuf])
continue // copy out and return
}
if ndst == 0 && d.err == nil {
// Special case: input buffer is mostly filled with non-data bytes.
// Filter out such bytes to make room for more input.
off := 0
for i := 0; i < d.nbuf; i++ {
if d.buf[i] > ' ' {
d.buf[off] = d.buf[i]
off++
}
}
d.nbuf = off
}
}
// Out of input, out of decoded output. Check errors.
if d.err != nil {
return 0, d.err
}
if d.readErr != nil {
d.err = d.readErr
return 0, d.err
}
// Read more data.
nn, d.readErr = d.r.Read(d.buf[d.nbuf:])
d.nbuf += nn
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package asn1 implements parsing of DER-encoded ASN.1 data structures,
// as defined in ITU-T Rec X.690.
//
// See also “A Layman's Guide to a Subset of ASN.1, BER, and DER,”
// http://luca.ntop.org/Teaching/Appunti/asn1.html.
package asn1
// ASN.1 is a syntax for specifying abstract objects and BER, DER, PER, XER etc
// are different encoding formats for those objects. Here, we'll be dealing
// with DER, the Distinguished Encoding Rules. DER is used in X.509 because
// it's fast to parse and, unlike BER, has a unique encoding for every object.
// When calculating hashes over objects, it's important that the resulting
// bytes be the same at both ends and DER removes this margin of error.
//
// ASN.1 is very complex and this package doesn't attempt to implement
// everything by any means.
import (
"errors"
"fmt"
"math"
"math/big"
"reflect"
"strconv"
"time"
"unicode/utf16"
"unicode/utf8"
)
// A StructuralError suggests that the ASN.1 data is valid, but the Go type
// which is receiving it doesn't match.
type StructuralError struct {
Msg string
}
func (e StructuralError) Error() string { return "asn1: structure error: " + e.Msg }
// A SyntaxError suggests that the ASN.1 data is invalid.
type SyntaxError struct {
Msg string
}
func (e SyntaxError) Error() string { return "asn1: syntax error: " + e.Msg }
// We start by dealing with each of the primitive types in turn.
// BOOLEAN
func parseBool(bytes []byte) (ret bool, err error) {
if len(bytes) != 1 {
err = SyntaxError{"invalid boolean"}
return
}
// DER demands that "If the encoding represents the boolean value TRUE,
// its single contents octet shall have all eight bits set to one."
// Thus only 0 and 255 are valid encoded values.
switch bytes[0] {
case 0:
ret = false
case 0xff:
ret = true
default:
err = SyntaxError{"invalid boolean"}
}
return
}
// INTEGER
// checkInteger returns nil if the given bytes are a valid DER-encoded
// INTEGER and an error otherwise.
func checkInteger(bytes []byte) error {
if len(bytes) == 0 {
return StructuralError{"empty integer"}
}
if len(bytes) == 1 {
return nil
}
if (bytes[0] == 0 && bytes[1]&0x80 == 0) || (bytes[0] == 0xff && bytes[1]&0x80 == 0x80) {
return StructuralError{"integer not minimally-encoded"}
}
return nil
}
// parseInt64 treats the given bytes as a big-endian, signed integer and
// returns the result.
func parseInt64(bytes []byte) (ret int64, err error) {
err = checkInteger(bytes)
if err != nil {
return
}
if len(bytes) > 8 {
// We'll overflow an int64 in this case.
err = StructuralError{"integer too large"}
return
}
for bytesRead := 0; bytesRead < len(bytes); bytesRead++ {
ret <<= 8
ret |= int64(bytes[bytesRead])
}
// Shift up and down in order to sign extend the result.
ret <<= 64 - uint8(len(bytes))*8
ret >>= 64 - uint8(len(bytes))*8
return
}
// parseInt32 treats the given bytes as a big-endian, signed integer and returns
// the result.
func parseInt32(bytes []byte) (int32, error) {
if err := checkInteger(bytes); err != nil {
return 0, err
}
ret64, err := parseInt64(bytes)
if err != nil {
return 0, err
}
if ret64 != int64(int32(ret64)) {
return 0, StructuralError{"integer too large"}
}
return int32(ret64), nil
}
var bigOne = big.NewInt(1)
// parseBigInt treats the given bytes as a big-endian, signed integer and returns
// the result.
func parseBigInt(bytes []byte) (*big.Int, error) {
if err := checkInteger(bytes); err != nil {
return nil, err
}
ret := new(big.Int)
if len(bytes) > 0 && bytes[0]&0x80 == 0x80 {
// This is a negative number.
notBytes := make([]byte, len(bytes))
for i := range notBytes {
notBytes[i] = ^bytes[i]
}
ret.SetBytes(notBytes)
ret.Add(ret, bigOne)
ret.Neg(ret)
return ret, nil
}
ret.SetBytes(bytes)
return ret, nil
}
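// Worked example (illustrative): the DER INTEGER contents {0xff, 0x7f}
// decode to -129. The set top bit marks the value negative; complementing
// gives {0x00, 0x80} = 128, and -(128 + 1) = -129.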
// BIT STRING
// BitString is the structure to use when you want an ASN.1 BIT STRING type. A
// bit string is padded up to the nearest byte in memory and the number of
// valid bits is recorded. Padding bits will be zero.
type BitString struct {
Bytes []byte // bits packed into bytes.
BitLength int // length in bits.
}
// At returns the bit at the given index. If the index is out of range it
// returns 0.
func (b BitString) At(i int) int {
if i < 0 || i >= b.BitLength {
return 0
}
x := i / 8
y := 7 - uint(i%8)
return int(b.Bytes[x]>>y) & 1
}
// RightAlign returns a slice where the padding bits are at the beginning. The
// slice may share memory with the BitString.
func (b BitString) RightAlign() []byte {
shift := uint(8 - (b.BitLength % 8))
if shift == 8 || len(b.Bytes) == 0 {
return b.Bytes
}
a := make([]byte, len(b.Bytes))
a[0] = b.Bytes[0] >> shift
for i := 1; i < len(b.Bytes); i++ {
a[i] = b.Bytes[i-1] << (8 - shift)
a[i] |= b.Bytes[i] >> shift
}
return a
}
// parseBitString parses an ASN.1 bit string from the given byte slice and returns it.
func parseBitString(bytes []byte) (ret BitString, err error) {
if len(bytes) == 0 {
err = SyntaxError{"zero length BIT STRING"}
return
}
paddingBits := int(bytes[0])
if paddingBits > 7 ||
len(bytes) == 1 && paddingBits > 0 ||
bytes[len(bytes)-1]&((1<<bytes[0])-1) != 0 {
err = SyntaxError{"invalid padding bits in BIT STRING"}
return
}
ret.BitLength = (len(bytes)-1)*8 - paddingBits
ret.Bytes = bytes[1:]
return
}
// NULL
// NullRawValue is a RawValue with its Tag set to the ASN.1 NULL type tag (5).
var NullRawValue = RawValue{Tag: TagNull}
// NullBytes contains bytes representing the DER-encoded ASN.1 NULL type.
var NullBytes = []byte{TagNull, 0}
// OBJECT IDENTIFIER
// An ObjectIdentifier represents an ASN.1 OBJECT IDENTIFIER.
type ObjectIdentifier []int
// Equal reports whether oi and other represent the same identifier.
func (oi ObjectIdentifier) Equal(other ObjectIdentifier) bool {
if len(oi) != len(other) {
return false
}
for i := 0; i < len(oi); i++ {
if oi[i] != other[i] {
return false
}
}
return true
}
func (oi ObjectIdentifier) String() string {
var s string
for i, v := range oi {
if i > 0 {
s += "."
}
s += strconv.Itoa(v)
}
return s
}
// parseObjectIdentifier parses an OBJECT IDENTIFIER from the given bytes and
// returns it. An object identifier is a sequence of variable length integers
// that are assigned in a hierarchy.
func parseObjectIdentifier(bytes []byte) (s ObjectIdentifier, err error) {
if len(bytes) == 0 {
err = SyntaxError{"zero length OBJECT IDENTIFIER"}
return
}
// In the worst case, we get two elements from the first byte (which is
// encoded differently) and then every varint is a single byte long.
s = make([]int, len(bytes)+1)
// The first varint is 40*value1 + value2:
// According to this packing, value1 can take the values 0, 1 and 2 only.
// When value1 = 0 or value1 = 1, then value2 is <= 39. When value1 = 2,
// then there are no restrictions on value2.
v, offset, err := parseBase128Int(bytes, 0)
if err != nil {
return
}
if v < 80 {
s[0] = v / 40
s[1] = v % 40
} else {
s[0] = 2
s[1] = v - 80
}
i := 2
for ; offset < len(bytes); i++ {
v, offset, err = parseBase128Int(bytes, offset)
if err != nil {
return
}
s[i] = v
}
s = s[0:i]
return
}
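// Worked example (illustrative): the contents 0x2a 0x86 0x48 decode to the
// OID 1.2.840. The first octet 0x2a = 42 = 40*1 + 2 yields components 1
// and 2; the base-128 pair 0x86 0x48 = (6<<7)|0x48 = 840 is the next
// component.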
// ENUMERATED
// An Enumerated is represented as a plain int.
type Enumerated int
// FLAG
// A Flag accepts any data and is set to true if present.
type Flag bool
// parseBase128Int parses a base-128 encoded int from the given offset in the
// given byte slice. It returns the value and the new offset.
func parseBase128Int(bytes []byte, initOffset int) (ret, offset int, err error) {
offset = initOffset
var ret64 int64
for shifted := 0; offset < len(bytes); shifted++ {
// 5 * 7 bits per byte == 35 bits of data
// Thus the representation is either non-minimal or too large for an int32
if shifted == 5 {
err = StructuralError{"base 128 integer too large"}
return
}
ret64 <<= 7
b := bytes[offset]
// integers should be minimally encoded, so the leading octet should
// never be 0x80
if shifted == 0 && b == 0x80 {
err = SyntaxError{"integer is not minimally encoded"}
return
}
ret64 |= int64(b & 0x7f)
offset++
if b&0x80 == 0 {
ret = int(ret64)
// Ensure that the returned value fits in an int on all platforms
if ret64 > math.MaxInt32 {
err = StructuralError{"base 128 integer too large"}
}
return
}
}
err = SyntaxError{"truncated base 128 integer"}
return
}
// UTCTime
func parseUTCTime(bytes []byte) (ret time.Time, err error) {
s := string(bytes)
formatStr := "0601021504Z0700"
ret, err = time.Parse(formatStr, s)
if err != nil {
formatStr = "060102150405Z0700"
ret, err = time.Parse(formatStr, s)
}
if err != nil {
return
}
if serialized := ret.Format(formatStr); serialized != s {
err = fmt.Errorf("asn1: time did not serialize back to the original value and may be invalid: given %q, but serialized as %q", s, serialized)
return
}
if ret.Year() >= 2050 {
// UTCTime only encodes times prior to 2050. See https://tools.ietf.org/html/rfc5280#section-4.1.2.5.1
ret = ret.AddDate(-100, 0, 0)
}
return
}
// parseGeneralizedTime parses the GeneralizedTime from the given byte slice
// and returns the resulting time.
func parseGeneralizedTime(bytes []byte) (ret time.Time, err error) {
const formatStr = "20060102150405Z0700"
s := string(bytes)
if ret, err = time.Parse(formatStr, s); err != nil {
return
}
if serialized := ret.Format(formatStr); serialized != s {
err = fmt.Errorf("asn1: time did not serialize back to the original value and may be invalid: given %q, but serialized as %q", s, serialized)
}
return
}
// NumericString
// parseNumericString parses an ASN.1 NumericString from the given byte array
// and returns it.
func parseNumericString(bytes []byte) (ret string, err error) {
for _, b := range bytes {
if !isNumeric(b) {
return "", SyntaxError{"NumericString contains invalid character"}
}
}
return string(bytes), nil
}
// isNumeric reports whether the given b is in the ASN.1 NumericString set.
func isNumeric(b byte) bool {
return '0' <= b && b <= '9' ||
b == ' '
}
// PrintableString
// parsePrintableString parses an ASN.1 PrintableString from the given byte
// array and returns it.
func parsePrintableString(bytes []byte) (ret string, err error) {
for _, b := range bytes {
if !isPrintable(b, allowAsterisk, allowAmpersand) {
err = SyntaxError{"PrintableString contains invalid character"}
return
}
}
ret = string(bytes)
return
}
type asteriskFlag bool
type ampersandFlag bool
const (
allowAsterisk asteriskFlag = true
rejectAsterisk asteriskFlag = false
allowAmpersand ampersandFlag = true
rejectAmpersand ampersandFlag = false
)
// isPrintable reports whether the given b is in the ASN.1 PrintableString set.
// If asterisk is allowAsterisk then '*' is also allowed, reflecting existing
// practice. If ampersand is allowAmpersand then '&' is allowed as well.
func isPrintable(b byte, asterisk asteriskFlag, ampersand ampersandFlag) bool {
return 'a' <= b && b <= 'z' ||
'A' <= b && b <= 'Z' ||
'0' <= b && b <= '9' ||
'\'' <= b && b <= ')' ||
'+' <= b && b <= '/' ||
b == ' ' ||
b == ':' ||
b == '=' ||
b == '?' ||
// This is technically not allowed in a PrintableString.
// However, x509 certificates with wildcard strings don't
// always use the correct string type so we permit it.
(bool(asterisk) && b == '*') ||
// This is not technically allowed either. However, not
// only is it relatively common, but there are also a
// handful of CA certificates that contain it. At least
// one of which will not expire until 2027.
(bool(ampersand) && b == '&')
}
// IA5String
// parseIA5String parses an ASN.1 IA5String (ASCII string) from the given
// byte slice and returns it.
func parseIA5String(bytes []byte) (ret string, err error) {
for _, b := range bytes {
if b >= utf8.RuneSelf {
err = SyntaxError{"IA5String contains invalid character"}
return
}
}
ret = string(bytes)
return
}
// T61String
// parseT61String parses an ASN.1 T61String (8-bit clean string) from the given
// byte slice and returns it.
func parseT61String(bytes []byte) (ret string, err error) {
return string(bytes), nil
}
// UTF8String
// parseUTF8String parses an ASN.1 UTF8String (raw UTF-8) from the given byte
// array and returns it.
func parseUTF8String(bytes []byte) (ret string, err error) {
if !utf8.Valid(bytes) {
return "", errors.New("asn1: invalid UTF-8 string")
}
return string(bytes), nil
}
// BMPString
// parseBMPString parses an ASN.1 BMPString (Basic Multilingual Plane of
// ISO/IEC/ITU 10646-1) from the given byte slice and returns it.
func parseBMPString(bmpString []byte) (string, error) {
if len(bmpString)%2 != 0 {
return "", errors.New("pkcs12: odd-length BMP string")
}
// Strip terminator if present.
if l := len(bmpString); l >= 2 && bmpString[l-1] == 0 && bmpString[l-2] == 0 {
bmpString = bmpString[:l-2]
}
s := make([]uint16, 0, len(bmpString)/2)
for len(bmpString) > 0 {
s = append(s, uint16(bmpString[0])<<8+uint16(bmpString[1]))
bmpString = bmpString[2:]
}
return string(utf16.Decode(s)), nil
}
// A RawValue represents an undecoded ASN.1 object.
type RawValue struct {
Class, Tag int
IsCompound bool
Bytes []byte
FullBytes []byte // includes the tag and length
}
// RawContent is used to signal that the undecoded, DER data needs to be
// preserved for a struct. To use it, the first field of the struct must have
// this type. It's an error for any of the other fields to have this type.
type RawContent []byte
// Tagging
// parseTagAndLength parses an ASN.1 tag and length pair from the given offset
// into a byte slice. It returns the parsed data and the new offset. SET and
// SET OF (tag 17) are mapped to SEQUENCE and SEQUENCE OF (tag 16) since we
// don't distinguish between ordered and unordered objects in this code.
func parseTagAndLength(bytes []byte, initOffset int) (ret tagAndLength, offset int, err error) {
offset = initOffset
// parseTagAndLength should not be called without at least a single
// byte to read. Thus this check is for robustness:
if offset >= len(bytes) {
err = errors.New("asn1: internal error in parseTagAndLength")
return
}
b := bytes[offset]
offset++
ret.class = int(b >> 6)
ret.isCompound = b&0x20 == 0x20
ret.tag = int(b & 0x1f)
// If the bottom five bits are set, then the tag number is actually base 128
// encoded afterwards
if ret.tag == 0x1f {
ret.tag, offset, err = parseBase128Int(bytes, offset)
if err != nil {
return
}
// Tags should be encoded in minimal form.
if ret.tag < 0x1f {
err = SyntaxError{"non-minimal tag"}
return
}
}
if offset >= len(bytes) {
err = SyntaxError{"truncated tag or length"}
return
}
b = bytes[offset]
offset++
if b&0x80 == 0 {
// The length is encoded in the bottom 7 bits.
ret.length = int(b & 0x7f)
} else {
// Bottom 7 bits give the number of length bytes to follow.
numBytes := int(b & 0x7f)
if numBytes == 0 {
err = SyntaxError{"indefinite length found (not DER)"}
return
}
ret.length = 0
for i := 0; i < numBytes; i++ {
if offset >= len(bytes) {
err = SyntaxError{"truncated tag or length"}
return
}
b = bytes[offset]
offset++
if ret.length >= 1<<23 {
// We can't shift ret.length up without
// overflowing.
err = StructuralError{"length too large"}
return
}
ret.length <<= 8
ret.length |= int(b)
if ret.length == 0 {
// DER requires that lengths be minimal.
err = StructuralError{"superfluous leading zeros in length"}
return
}
}
// Short lengths must be encoded in short form.
if ret.length < 0x80 {
err = StructuralError{"non-minimal length"}
return
}
}
return
}
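// Worked example (illustrative): the header bytes 0x30 0x82 0x01 0x0f
// parse as class 0 (universal), compound, tag 16 (SEQUENCE), followed by
// a long-form length: 0x82 announces two length octets, and 0x01 0x0f =
// 271.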
// parseSequenceOf is used for SEQUENCE OF and SET OF values. It tries to parse
// a number of ASN.1 values from the given byte slice and returns them as a
// slice of Go values of the given type.
func parseSequenceOf(bytes []byte, sliceType reflect.Type, elemType reflect.Type) (ret reflect.Value, err error) {
matchAny, expectedTag, compoundType, ok := getUniversalType(elemType)
if !ok {
err = StructuralError{"unknown Go type for slice"}
return
}
// First we iterate over the input and count the number of elements,
// checking that the types are correct in each case.
numElements := 0
for offset := 0; offset < len(bytes); {
var t tagAndLength
t, offset, err = parseTagAndLength(bytes, offset)
if err != nil {
return
}
switch t.tag {
case TagIA5String, TagGeneralString, TagT61String, TagUTF8String, TagNumericString, TagBMPString:
// We pretend that various other string types are
// PRINTABLE STRINGs so that a sequence of them can be
// parsed into a []string.
t.tag = TagPrintableString
case TagGeneralizedTime, TagUTCTime:
// Likewise, both time types are treated the same.
t.tag = TagUTCTime
}
if !matchAny && (t.class != ClassUniversal || t.isCompound != compoundType || t.tag != expectedTag) {
err = StructuralError{"sequence tag mismatch"}
return
}
if invalidLength(offset, t.length, len(bytes)) {
err = SyntaxError{"truncated sequence"}
return
}
offset += t.length
numElements++
}
ret = reflect.MakeSlice(sliceType, numElements, numElements)
params := fieldParameters{}
offset := 0
for i := 0; i < numElements; i++ {
offset, err = parseField(ret.Index(i), bytes, offset, params)
if err != nil {
return
}
}
return
}
var (
bitStringType = reflect.TypeOf(BitString{})
objectIdentifierType = reflect.TypeOf(ObjectIdentifier{})
enumeratedType = reflect.TypeOf(Enumerated(0))
flagType = reflect.TypeOf(Flag(false))
timeType = reflect.TypeOf(time.Time{})
rawValueType = reflect.TypeOf(RawValue{})
rawContentsType = reflect.TypeOf(RawContent(nil))
bigIntType = reflect.TypeOf((*big.Int)(nil))
)
// invalidLength reports whether offset + length > sliceLength, or if the
// addition would overflow.
func invalidLength(offset, length, sliceLength int) bool {
return offset+length < offset || offset+length > sliceLength
}
// parseField is the main parsing function. Given a byte slice and an offset
// into the array, it will try to parse a suitable ASN.1 value out and store it
// in the given Value.
func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParameters) (offset int, err error) {
offset = initOffset
fieldType := v.Type()
// If we have run out of data, it may be that there are optional elements at the end.
if offset == len(bytes) {
if !setDefaultValue(v, params) {
err = SyntaxError{"sequence truncated"}
}
return
}
// Deal with the ANY type.
if ifaceType := fieldType; ifaceType.Kind() == reflect.Interface && ifaceType.NumMethod() == 0 {
var t tagAndLength
t, offset, err = parseTagAndLength(bytes, offset)
if err != nil {
return
}
if invalidLength(offset, t.length, len(bytes)) {
err = SyntaxError{"data truncated"}
return
}
var result any
if !t.isCompound && t.class == ClassUniversal {
innerBytes := bytes[offset : offset+t.length]
switch t.tag {
case TagPrintableString:
result, err = parsePrintableString(innerBytes)
case TagNumericString:
result, err = parseNumericString(innerBytes)
case TagIA5String:
result, err = parseIA5String(innerBytes)
case TagT61String:
result, err = parseT61String(innerBytes)
case TagUTF8String:
result, err = parseUTF8String(innerBytes)
case TagInteger:
result, err = parseInt64(innerBytes)
case TagBitString:
result, err = parseBitString(innerBytes)
case TagOID:
result, err = parseObjectIdentifier(innerBytes)
case TagUTCTime:
result, err = parseUTCTime(innerBytes)
case TagGeneralizedTime:
result, err = parseGeneralizedTime(innerBytes)
case TagOctetString:
result = innerBytes
case TagBMPString:
result, err = parseBMPString(innerBytes)
default:
// If we don't know how to handle the type, we just leave Value as nil.
}
}
offset += t.length
if err != nil {
return
}
if result != nil {
v.Set(reflect.ValueOf(result))
}
return
}
t, offset, err := parseTagAndLength(bytes, offset)
if err != nil {
return
}
if params.explicit {
expectedClass := ClassContextSpecific
if params.application {
expectedClass = ClassApplication
}
if offset == len(bytes) {
err = StructuralError{"explicit tag has no child"}
return
}
if t.class == expectedClass && t.tag == *params.tag && (t.length == 0 || t.isCompound) {
if fieldType == rawValueType {
// The inner element should not be parsed for RawValues.
} else if t.length > 0 {
t, offset, err = parseTagAndLength(bytes, offset)
if err != nil {
return
}
} else {
if fieldType != flagType {
err = StructuralError{"zero length explicit tag was not an asn1.Flag"}
return
}
v.SetBool(true)
return
}
} else {
// The tags didn't match, it might be an optional element.
ok := setDefaultValue(v, params)
if ok {
offset = initOffset
} else {
err = StructuralError{"explicitly tagged member didn't match"}
}
return
}
}
matchAny, universalTag, compoundType, ok1 := getUniversalType(fieldType)
if !ok1 {
err = StructuralError{fmt.Sprintf("unknown Go type: %v", fieldType)}
return
}
// Special case for strings: all the ASN.1 string types map to the Go
// type string. getUniversalType returns the tag for PrintableString
// when it sees a string, so if we see a different string type on the
// wire, we change the universal type to match.
if universalTag == TagPrintableString {
if t.class == ClassUniversal {
switch t.tag {
case TagIA5String, TagGeneralString, TagT61String, TagUTF8String, TagNumericString, TagBMPString:
universalTag = t.tag
}
} else if params.stringType != 0 {
universalTag = params.stringType
}
}
// Special case for time: UTCTime and GeneralizedTime both map to the
// Go type time.Time.
if universalTag == TagUTCTime && t.tag == TagGeneralizedTime && t.class == ClassUniversal {
universalTag = TagGeneralizedTime
}
if params.set {
universalTag = TagSet
}
matchAnyClassAndTag := matchAny
expectedClass := ClassUniversal
expectedTag := universalTag
if !params.explicit && params.tag != nil {
expectedClass = ClassContextSpecific
expectedTag = *params.tag
matchAnyClassAndTag = false
}
if !params.explicit && params.application && params.tag != nil {
expectedClass = ClassApplication
expectedTag = *params.tag
matchAnyClassAndTag = false
}
if !params.explicit && params.private && params.tag != nil {
expectedClass = ClassPrivate
expectedTag = *params.tag
matchAnyClassAndTag = false
}
// We have unwrapped any explicit tagging at this point.
if !matchAnyClassAndTag && (t.class != expectedClass || t.tag != expectedTag) ||
(!matchAny && t.isCompound != compoundType) {
// Tags don't match. Again, it could be an optional element.
ok := setDefaultValue(v, params)
if ok {
offset = initOffset
} else {
err = StructuralError{fmt.Sprintf("tags don't match (%d vs %+v) %+v %s @%d", expectedTag, t, params, fieldType.Name(), offset)}
}
return
}
if invalidLength(offset, t.length, len(bytes)) {
err = SyntaxError{"data truncated"}
return
}
innerBytes := bytes[offset : offset+t.length]
offset += t.length
// We deal with the structures defined in this package first.
switch v := v.Addr().Interface().(type) {
case *RawValue:
*v = RawValue{t.class, t.tag, t.isCompound, innerBytes, bytes[initOffset:offset]}
return
case *ObjectIdentifier:
*v, err = parseObjectIdentifier(innerBytes)
return
case *BitString:
*v, err = parseBitString(innerBytes)
return
case *time.Time:
if universalTag == TagUTCTime {
*v, err = parseUTCTime(innerBytes)
return
}
*v, err = parseGeneralizedTime(innerBytes)
return
case *Enumerated:
parsedInt, err1 := parseInt32(innerBytes)
if err1 == nil {
*v = Enumerated(parsedInt)
}
err = err1
return
case *Flag:
*v = true
return
case **big.Int:
parsedInt, err1 := parseBigInt(innerBytes)
if err1 == nil {
*v = parsedInt
}
err = err1
return
}
switch val := v; val.Kind() {
case reflect.Bool:
parsedBool, err1 := parseBool(innerBytes)
if err1 == nil {
val.SetBool(parsedBool)
}
err = err1
return
case reflect.Int, reflect.Int32, reflect.Int64:
if val.Type().Size() == 4 {
parsedInt, err1 := parseInt32(innerBytes)
if err1 == nil {
val.SetInt(int64(parsedInt))
}
err = err1
} else {
parsedInt, err1 := parseInt64(innerBytes)
if err1 == nil {
val.SetInt(parsedInt)
}
err = err1
}
return
// TODO(dfc) Add support for the remaining integer types
case reflect.Struct:
structType := fieldType
for i := 0; i < structType.NumField(); i++ {
if !structType.Field(i).IsExported() {
err = StructuralError{"struct contains unexported fields"}
return
}
}
if structType.NumField() > 0 &&
structType.Field(0).Type == rawContentsType {
bytes := bytes[initOffset:offset]
val.Field(0).Set(reflect.ValueOf(RawContent(bytes)))
}
innerOffset := 0
for i := 0; i < structType.NumField(); i++ {
field := structType.Field(i)
if i == 0 && field.Type == rawContentsType {
continue
}
innerOffset, err = parseField(val.Field(i), innerBytes, innerOffset, parseFieldParameters(field.Tag.Get("asn1")))
if err != nil {
return
}
}
// We allow extra bytes at the end of the SEQUENCE because
// adding elements to the end has been used in X.509 as the
// version numbers have increased.
return
case reflect.Slice:
sliceType := fieldType
if sliceType.Elem().Kind() == reflect.Uint8 {
val.Set(reflect.MakeSlice(sliceType, len(innerBytes), len(innerBytes)))
reflect.Copy(val, reflect.ValueOf(innerBytes))
return
}
newSlice, err1 := parseSequenceOf(innerBytes, sliceType, sliceType.Elem())
if err1 == nil {
val.Set(newSlice)
}
err = err1
return
case reflect.String:
var v string
switch universalTag {
case TagPrintableString:
v, err = parsePrintableString(innerBytes)
case TagNumericString:
v, err = parseNumericString(innerBytes)
case TagIA5String:
v, err = parseIA5String(innerBytes)
case TagT61String:
v, err = parseT61String(innerBytes)
case TagUTF8String:
v, err = parseUTF8String(innerBytes)
case TagGeneralString:
// GeneralString is specified in ISO-2022/ECMA-35.
// A brief review suggests that it includes structures
// that allow the encoding to change midstring and
// such. We give up and pass it as an 8-bit string.
v, err = parseT61String(innerBytes)
case TagBMPString:
v, err = parseBMPString(innerBytes)
default:
err = SyntaxError{fmt.Sprintf("internal error: unknown string type %d", universalTag)}
}
if err == nil {
val.SetString(v)
}
return
}
err = StructuralError{"unsupported: " + v.Type().String()}
return
}
// canHaveDefaultValue reports whether k is a Kind that we will set a default
// value for. (A signed integer, essentially.)
func canHaveDefaultValue(k reflect.Kind) bool {
switch k {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return true
}
return false
}
// setDefaultValue is used to install a default value, from a tag string, into
// a Value. It is successful if the field was optional, even if a default value
// wasn't provided or it failed to install it into the Value.
func setDefaultValue(v reflect.Value, params fieldParameters) (ok bool) {
if !params.optional {
return
}
ok = true
if params.defaultValue == nil {
return
}
if canHaveDefaultValue(v.Kind()) {
v.SetInt(*params.defaultValue)
}
return
}
// Unmarshal parses the DER-encoded ASN.1 data structure b
// and uses the reflect package to fill in an arbitrary value pointed at by val.
// Because Unmarshal uses the reflect package, the structs
// being written to must use upper case field names. If val
// is nil or not a pointer, Unmarshal returns an error.
//
// After parsing b, any bytes that were leftover and not used to fill
// val will be returned in rest. When parsing a SEQUENCE into a struct,
// any trailing elements of the SEQUENCE that do not have matching
// fields in val will not be included in rest, as these are considered
// valid elements of the SEQUENCE and not trailing data.
//
// An ASN.1 INTEGER can be written to an int, int32, int64,
// or *big.Int (from the math/big package).
// If the encoded value does not fit in the Go type,
// Unmarshal returns a parse error.
//
// An ASN.1 BIT STRING can be written to a BitString.
//
// An ASN.1 OCTET STRING can be written to a []byte.
//
// An ASN.1 OBJECT IDENTIFIER can be written to an
// ObjectIdentifier.
//
// An ASN.1 ENUMERATED can be written to an Enumerated.
//
// An ASN.1 UTCTIME or GENERALIZEDTIME can be written to a time.Time.
//
// An ASN.1 PrintableString, IA5String, or NumericString can be written to a string.
//
// Any of the above ASN.1 values can be written to an interface{}.
// The value stored in the interface has the corresponding Go type.
// For integers, that type is int64.
//
// An ASN.1 SEQUENCE OF x or SET OF x can be written
// to a slice if an x can be written to the slice's element type.
//
// An ASN.1 SEQUENCE or SET can be written to a struct
// if each of the elements in the sequence can be
// written to the corresponding element in the struct.
//
// The following tags on struct fields have special meaning to Unmarshal:
//
// application specifies that an APPLICATION tag is used
// private specifies that a PRIVATE tag is used
// default:x sets the default value for optional integer fields (only used if optional is also present)
// explicit specifies that an additional, explicit tag wraps the implicit one
// optional marks the field as ASN.1 OPTIONAL
// set causes a SET, rather than a SEQUENCE type to be expected
// tag:x specifies the ASN.1 tag number; implies ASN.1 CONTEXT SPECIFIC
//
// When decoding an ASN.1 value with an IMPLICIT tag into a string field,
// Unmarshal will default to a PrintableString, which doesn't support
// characters such as '@' and '&'. To force other encodings, use the following
// tags:
//
// ia5 causes strings to be unmarshaled as ASN.1 IA5String values
// numeric causes strings to be unmarshaled as ASN.1 NumericString values
// utf8 causes strings to be unmarshaled as ASN.1 UTF8String values
//
// If the type of the first field of a structure is RawContent then the raw
// ASN.1 contents of the struct will be stored in it.
//
// If the name of a slice type ends with "SET" then it's treated as if
// the "set" tag was set on it. This results in interpreting the type as a
// SET OF x rather than a SEQUENCE OF x. This can be used with nested slices
// where a struct tag cannot be given.
//
// Other ASN.1 types are not supported; if one is encountered,
// Unmarshal returns a parse error.
func Unmarshal(b []byte, val any) (rest []byte, err error) {
return UnmarshalWithParams(b, val, "")
}
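// An illustrative sketch of typical use (hypothetical values, not part of
// this file's tests): the DER bytes below encode SEQUENCE { INTEGER 7,
// INTEGER 42 } and decode into a two-field struct.
//
//	var p struct {
//		X, Y int
//	}
//	der := []byte{0x30, 0x06, 0x02, 0x01, 0x07, 0x02, 0x01, 0x2a}
//	rest, err := Unmarshal(der, &p)
//	// err == nil, len(rest) == 0, p.X == 7, p.Y == 42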
// An invalidUnmarshalError describes an invalid argument passed to Unmarshal.
// (The argument to Unmarshal must be a non-nil pointer.)
type invalidUnmarshalError struct {
Type reflect.Type
}
func (e *invalidUnmarshalError) Error() string {
if e.Type == nil {
return "asn1: Unmarshal recipient value is nil"
}
if e.Type.Kind() != reflect.Pointer {
return "asn1: Unmarshal recipient value is non-pointer " + e.Type.String()
}
return "asn1: Unmarshal recipient value is nil " + e.Type.String()
}
// UnmarshalWithParams allows field parameters to be specified for the
// top-level element. The form of the params is the same as the field tags.
func UnmarshalWithParams(b []byte, val any, params string) (rest []byte, err error) {
v := reflect.ValueOf(val)
if v.Kind() != reflect.Pointer || v.IsNil() {
return nil, &invalidUnmarshalError{reflect.TypeOf(val)}
}
offset, err := parseField(v.Elem(), b, 0, parseFieldParameters(params))
if err != nil {
return nil, err
}
return b[offset:], nil
}
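// For example (an illustrative sketch, hypothetical bytes): decoding an
// INTEGER wrapped in an [0] EXPLICIT context-specific tag.
//
//	var n int
//	der := []byte{0xa0, 0x03, 0x02, 0x01, 0x07}
//	_, err := UnmarshalWithParams(der, &n, "explicit,tag:0")
//	// err == nil, n == 7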
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package asn1
import (
"reflect"
"strconv"
"strings"
)
// ASN.1 objects have metadata preceding them:
// the tag: the type of the object
// a flag denoting if this object is compound or not
// the class type: the namespace of the tag
// the length of the object, in bytes
// Here are some standard tags and classes
// ASN.1 tags represent the type of the following object.
const (
TagBoolean = 1
TagInteger = 2
TagBitString = 3
TagOctetString = 4
TagNull = 5
TagOID = 6
TagEnum = 10
TagUTF8String = 12
TagSequence = 16
TagSet = 17
TagNumericString = 18
TagPrintableString = 19
TagT61String = 20
TagIA5String = 22
TagUTCTime = 23
TagGeneralizedTime = 24
TagGeneralString = 27
TagBMPString = 30
)
// ASN.1 class types represent the namespace of the tag.
const (
ClassUniversal = 0
ClassApplication = 1
ClassContextSpecific = 2
ClassPrivate = 3
)
type tagAndLength struct {
class, tag, length int
isCompound bool
}
// ASN.1 has IMPLICIT and EXPLICIT tags, which can be translated as "instead
// of" and "in addition to". When not specified, every primitive type has a
// default tag in the UNIVERSAL class.
//
// For example: a BIT STRING is tagged [UNIVERSAL 3] by default (although ASN.1
// doesn't actually have a UNIVERSAL keyword). However, by saying [IMPLICIT
// CONTEXT-SPECIFIC 42], that means that the tag is replaced by another.
//
// On the other hand, if it said [EXPLICIT CONTEXT-SPECIFIC 10], then an
// /additional/ tag would wrap the default tag. This explicit tag will have the
// compound flag set.
//
// (This is used in order to remove ambiguity with optional elements.)
//
// You can layer EXPLICIT and IMPLICIT tags to an arbitrary depth; however, we
// don't support that here. We support a single layer of EXPLICIT or IMPLICIT
// tagging with tag strings on the fields of a structure.
// fieldParameters is the parsed representation of the tag string from a
// structure field.
type fieldParameters struct {
optional bool // true iff the field is OPTIONAL
explicit bool // true iff an EXPLICIT tag is in use.
application bool // true iff an APPLICATION tag is in use.
private bool // true iff a PRIVATE tag is in use.
defaultValue *int64 // a default value for INTEGER typed fields (maybe nil).
tag *int // the EXPLICIT or IMPLICIT tag (maybe nil).
stringType int // the string tag to use when marshaling.
timeType int // the time tag to use when marshaling.
set bool // true iff this should be encoded as a SET
omitEmpty bool // true iff this should be omitted if empty when marshaling.
// Invariants:
// if explicit is set, tag is non-nil.
}
// Given a tag string with the format specified in the package comment,
// parseFieldParameters will parse it into a fieldParameters structure,
// ignoring unknown parts of the string.
func parseFieldParameters(str string) (ret fieldParameters) {
var part string
for len(str) > 0 {
part, str, _ = strings.Cut(str, ",")
switch {
case part == "optional":
ret.optional = true
case part == "explicit":
ret.explicit = true
if ret.tag == nil {
ret.tag = new(int)
}
case part == "generalized":
ret.timeType = TagGeneralizedTime
case part == "utc":
ret.timeType = TagUTCTime
case part == "ia5":
ret.stringType = TagIA5String
case part == "printable":
ret.stringType = TagPrintableString
case part == "numeric":
ret.stringType = TagNumericString
case part == "utf8":
ret.stringType = TagUTF8String
case strings.HasPrefix(part, "default:"):
i, err := strconv.ParseInt(part[8:], 10, 64)
if err == nil {
ret.defaultValue = new(int64)
*ret.defaultValue = i
}
case strings.HasPrefix(part, "tag:"):
i, err := strconv.Atoi(part[4:])
if err == nil {
ret.tag = new(int)
*ret.tag = i
}
case part == "set":
ret.set = true
case part == "application":
ret.application = true
if ret.tag == nil {
ret.tag = new(int)
}
case part == "private":
ret.private = true
if ret.tag == nil {
ret.tag = new(int)
}
case part == "omitempty":
ret.omitEmpty = true
}
}
return
}
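// For example (illustrative, not part of the original source):
//
//	p := parseFieldParameters("optional,explicit,default:1,tag:2")
//	// p.optional == true, p.explicit == true,
//	// *p.defaultValue == 1, *p.tag == 2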
// Given a reflected Go type, getUniversalType returns the default tag number
// and expected compound flag.
func getUniversalType(t reflect.Type) (matchAny bool, tagNumber int, isCompound, ok bool) {
switch t {
case rawValueType:
return true, -1, false, true
case objectIdentifierType:
return false, TagOID, false, true
case bitStringType:
return false, TagBitString, false, true
case timeType:
return false, TagUTCTime, false, true
case enumeratedType:
return false, TagEnum, false, true
case bigIntType:
return false, TagInteger, false, true
}
switch t.Kind() {
case reflect.Bool:
return false, TagBoolean, false, true
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return false, TagInteger, false, true
case reflect.Struct:
return false, TagSequence, true, true
case reflect.Slice:
if t.Elem().Kind() == reflect.Uint8 {
return false, TagOctetString, false, true
}
if strings.HasSuffix(t.Name(), "SET") {
return false, TagSet, true, true
}
return false, TagSequence, true, true
case reflect.String:
return false, TagPrintableString, false, true
}
return false, 0, false, false
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package asn1
import (
"bytes"
"errors"
"fmt"
"math/big"
"reflect"
"sort"
"time"
"unicode/utf8"
)
var (
byte00Encoder encoder = byteEncoder(0x00)
byteFFEncoder encoder = byteEncoder(0xff)
)
// encoder represents an ASN.1 element that is waiting to be marshaled.
type encoder interface {
// Len returns the number of bytes needed to marshal this element.
Len() int
// Encode encodes this element by writing Len() bytes to dst.
Encode(dst []byte)
}
type byteEncoder byte
func (c byteEncoder) Len() int {
return 1
}
func (c byteEncoder) Encode(dst []byte) {
dst[0] = byte(c)
}
type bytesEncoder []byte
func (b bytesEncoder) Len() int {
return len(b)
}
func (b bytesEncoder) Encode(dst []byte) {
if copy(dst, b) != len(b) {
panic("internal error")
}
}
type stringEncoder string
func (s stringEncoder) Len() int {
return len(s)
}
func (s stringEncoder) Encode(dst []byte) {
if copy(dst, s) != len(s) {
panic("internal error")
}
}
type multiEncoder []encoder
func (m multiEncoder) Len() int {
var size int
for _, e := range m {
size += e.Len()
}
return size
}
func (m multiEncoder) Encode(dst []byte) {
var off int
for _, e := range m {
e.Encode(dst[off:])
off += e.Len()
}
}
type setEncoder []encoder
func (s setEncoder) Len() int {
var size int
for _, e := range s {
size += e.Len()
}
return size
}
func (s setEncoder) Encode(dst []byte) {
// Per X.690 Section 11.6: The encodings of the component values of a
// set-of value shall appear in ascending order, the encodings being
// compared as octet strings with the shorter components being padded
// at their trailing end with 0-octets.
//
// First we encode each element to its TLV encoding and then sort the
// encodings byte-wise to get the ordering expected by the X.690 DER
// rules before writing them out to dst.
l := make([][]byte, len(s))
for i, e := range s {
l[i] = make([]byte, e.Len())
e.Encode(l[i])
}
sort.Slice(l, func(i, j int) bool {
// Since we are using bytes.Compare to compare TLV encodings we
// don't need to right pad l[i] and l[j] to the same length as
// suggested in X.690. If len(l[i]) < len(l[j]) the length octet of
// l[i], which is the first determining byte, will inherently be
// smaller than the length octet of l[j]. This lets us skip the
// padding step.
return bytes.Compare(l[i], l[j]) < 0
})
var off int
for _, b := range l {
copy(dst[off:], b)
off += len(b)
}
}
type taggedEncoder struct {
// scratch contains temporary space for encoding the tag and length of
// an element in order to avoid extra allocations.
scratch [8]byte
tag encoder
body encoder
}
func (t *taggedEncoder) Len() int {
return t.tag.Len() + t.body.Len()
}
func (t *taggedEncoder) Encode(dst []byte) {
t.tag.Encode(dst)
t.body.Encode(dst[t.tag.Len():])
}
type int64Encoder int64
func (i int64Encoder) Len() int {
n := 1
for i > 127 {
n++
i >>= 8
}
for i < -128 {
n++
i >>= 8
}
return n
}
func (i int64Encoder) Encode(dst []byte) {
n := i.Len()
for j := 0; j < n; j++ {
dst[j] = byte(i >> uint((n-1-j)*8))
}
}
func base128IntLength(n int64) int {
if n == 0 {
return 1
}
l := 0
for i := n; i > 0; i >>= 7 {
l++
}
return l
}
func appendBase128Int(dst []byte, n int64) []byte {
l := base128IntLength(n)
for i := l - 1; i >= 0; i-- {
o := byte(n >> uint(i*7))
o &= 0x7f
if i != 0 {
o |= 0x80
}
dst = append(dst, o)
}
return dst
}
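// For example, 201 = 1*128 + 73, so appendBase128Int(nil, 201) yields
// []byte{0x81, 0x49}: every octet but the last has its high bit set.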
func makeBigInt(n *big.Int) (encoder, error) {
if n == nil {
return nil, StructuralError{"empty integer"}
}
if n.Sign() < 0 {
// A negative number has to be converted to two's-complement
// form. So we'll invert and subtract 1. If the
// most-significant-bit isn't set then we'll need to pad the
// beginning with 0xff in order to keep the number negative.
nMinus1 := new(big.Int).Neg(n)
nMinus1.Sub(nMinus1, bigOne)
bytes := nMinus1.Bytes()
for i := range bytes {
bytes[i] ^= 0xff
}
if len(bytes) == 0 || bytes[0]&0x80 == 0 {
return multiEncoder([]encoder{byteFFEncoder, bytesEncoder(bytes)}), nil
}
return bytesEncoder(bytes), nil
} else if n.Sign() == 0 {
// Zero is written as a single zero byte rather than no bytes.
return byte00Encoder, nil
} else {
bytes := n.Bytes()
if len(bytes) > 0 && bytes[0]&0x80 != 0 {
// We'll have to pad this with 0x00 in order to stop it
// looking like a negative number.
return multiEncoder([]encoder{byte00Encoder, bytesEncoder(bytes)}), nil
}
return bytesEncoder(bytes), nil
}
}
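// As a worked example of the negative branch: for n = -129, neg(n)-1 = 128
// = 0x80, which inverts to 0x7f; its top bit is clear, so a 0xff octet is
// prepended, giving the two's-complement encoding 0xff 0x7f.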
func appendLength(dst []byte, i int) []byte {
n := lengthLength(i)
for ; n > 0; n-- {
dst = append(dst, byte(i>>uint((n-1)*8)))
}
return dst
}
func lengthLength(i int) (numBytes int) {
numBytes = 1
for i > 255 {
numBytes++
i >>= 8
}
return
}
func appendTagAndLength(dst []byte, t tagAndLength) []byte {
b := uint8(t.class) << 6
if t.isCompound {
b |= 0x20
}
if t.tag >= 31 {
b |= 0x1f
dst = append(dst, b)
dst = appendBase128Int(dst, int64(t.tag))
} else {
b |= uint8(t.tag)
dst = append(dst, b)
}
if t.length >= 128 {
l := lengthLength(t.length)
dst = append(dst, 0x80|byte(l))
dst = appendLength(dst, t.length)
} else {
dst = append(dst, byte(t.length))
}
return dst
}
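// Two illustrative encodings (sketches, not from the original source):
//
//	appendTagAndLength(nil, tagAndLength{ClassContextSpecific, 0, 5, true})
//	// []byte{0xa0, 0x05}: class 2<<6 | compound 0x20 | tag 0; short-form length
//	appendTagAndLength(nil, tagAndLength{ClassUniversal, TagSequence, 300, true})
//	// []byte{0x30, 0x82, 0x01, 0x2c}: long-form length, 300 = 0x012c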
type bitStringEncoder BitString
func (b bitStringEncoder) Len() int {
return len(b.Bytes) + 1
}
func (b bitStringEncoder) Encode(dst []byte) {
dst[0] = byte((8 - b.BitLength%8) % 8)
if copy(dst[1:], b.Bytes) != len(b.Bytes) {
panic("internal error")
}
}
type oidEncoder []int
func (oid oidEncoder) Len() int {
l := base128IntLength(int64(oid[0]*40 + oid[1]))
for i := 2; i < len(oid); i++ {
l += base128IntLength(int64(oid[i]))
}
return l
}
func (oid oidEncoder) Encode(dst []byte) {
dst = appendBase128Int(dst[:0], int64(oid[0]*40+oid[1]))
for i := 2; i < len(oid); i++ {
dst = appendBase128Int(dst, int64(oid[i]))
}
}
func makeObjectIdentifier(oid []int) (e encoder, err error) {
if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) {
return nil, StructuralError{"invalid object identifier"}
}
return oidEncoder(oid), nil
}
func makePrintableString(s string) (e encoder, err error) {
for i := 0; i < len(s); i++ {
// The asterisk is often used in PrintableString, even though
// it is invalid. If a PrintableString was specifically
// requested then the asterisk is permitted by this code.
// Ampersand is allowed in parsing due to a handful of CA
// certificates; however, when making new certificates
// it is rejected.
if !isPrintable(s[i], allowAsterisk, rejectAmpersand) {
return nil, StructuralError{"PrintableString contains invalid character"}
}
}
return stringEncoder(s), nil
}
func makeIA5String(s string) (e encoder, err error) {
for i := 0; i < len(s); i++ {
if s[i] > 127 {
return nil, StructuralError{"IA5String contains invalid character"}
}
}
return stringEncoder(s), nil
}
func makeNumericString(s string) (e encoder, err error) {
for i := 0; i < len(s); i++ {
if !isNumeric(s[i]) {
return nil, StructuralError{"NumericString contains invalid character"}
}
}
return stringEncoder(s), nil
}
func makeUTF8String(s string) encoder {
return stringEncoder(s)
}
func appendTwoDigits(dst []byte, v int) []byte {
return append(dst, byte('0'+(v/10)%10), byte('0'+v%10))
}
func appendFourDigits(dst []byte, v int) []byte {
var bytes [4]byte
for i := range bytes {
bytes[3-i] = '0' + byte(v%10)
v /= 10
}
return append(dst, bytes[:]...)
}
func outsideUTCRange(t time.Time) bool {
year := t.Year()
return year < 1950 || year >= 2050
}
func makeUTCTime(t time.Time) (e encoder, err error) {
dst := make([]byte, 0, 18)
dst, err = appendUTCTime(dst, t)
if err != nil {
return nil, err
}
return bytesEncoder(dst), nil
}
func makeGeneralizedTime(t time.Time) (e encoder, err error) {
dst := make([]byte, 0, 20)
dst, err = appendGeneralizedTime(dst, t)
if err != nil {
return nil, err
}
return bytesEncoder(dst), nil
}
func appendUTCTime(dst []byte, t time.Time) (ret []byte, err error) {
year := t.Year()
switch {
case 1950 <= year && year < 2000:
dst = appendTwoDigits(dst, year-1900)
case 2000 <= year && year < 2050:
dst = appendTwoDigits(dst, year-2000)
default:
return nil, StructuralError{"cannot represent time as UTCTime"}
}
return appendTimeCommon(dst, t), nil
}
func appendGeneralizedTime(dst []byte, t time.Time) (ret []byte, err error) {
year := t.Year()
if year < 0 || year > 9999 {
return nil, StructuralError{"cannot represent time as GeneralizedTime"}
}
dst = appendFourDigits(dst, year)
return appendTimeCommon(dst, t), nil
}
func appendTimeCommon(dst []byte, t time.Time) []byte {
_, month, day := t.Date()
dst = appendTwoDigits(dst, int(month))
dst = appendTwoDigits(dst, day)
hour, min, sec := t.Clock()
dst = appendTwoDigits(dst, hour)
dst = appendTwoDigits(dst, min)
dst = appendTwoDigits(dst, sec)
_, offset := t.Zone()
switch {
case offset/60 == 0:
return append(dst, 'Z')
case offset > 0:
dst = append(dst, '+')
case offset < 0:
dst = append(dst, '-')
}
offsetMinutes := offset / 60
if offsetMinutes < 0 {
offsetMinutes = -offsetMinutes
}
dst = appendTwoDigits(dst, offsetMinutes/60)
dst = appendTwoDigits(dst, offsetMinutes%60)
return dst
}
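// For example (illustrative): a UTC time of 2009-11-10 23:00:00 encodes
// via appendUTCTime as "091110230000Z"; the same instant rendered by
// appendGeneralizedTime is "20091110230000Z", with a four-digit year.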
func stripTagAndLength(in []byte) []byte {
_, offset, err := parseTagAndLength(in, 0)
if err != nil {
return in
}
return in[offset:]
}
func makeBody(value reflect.Value, params fieldParameters) (e encoder, err error) {
switch value.Type() {
case flagType:
return bytesEncoder(nil), nil
case timeType:
t := value.Interface().(time.Time)
if params.timeType == TagGeneralizedTime || outsideUTCRange(t) {
return makeGeneralizedTime(t)
}
return makeUTCTime(t)
case bitStringType:
return bitStringEncoder(value.Interface().(BitString)), nil
case objectIdentifierType:
return makeObjectIdentifier(value.Interface().(ObjectIdentifier))
case bigIntType:
return makeBigInt(value.Interface().(*big.Int))
}
switch v := value; v.Kind() {
case reflect.Bool:
if v.Bool() {
return byteFFEncoder, nil
}
return byte00Encoder, nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return int64Encoder(v.Int()), nil
case reflect.Struct:
t := v.Type()
for i := 0; i < t.NumField(); i++ {
if !t.Field(i).IsExported() {
return nil, StructuralError{"struct contains unexported fields"}
}
}
startingField := 0
n := t.NumField()
if n == 0 {
return bytesEncoder(nil), nil
}
// If the first element of the structure is a non-empty
// RawContents, then we don't bother serializing the rest.
if t.Field(0).Type == rawContentsType {
s := v.Field(0)
if s.Len() > 0 {
bytes := s.Bytes()
/* The RawContents will contain the tag and
* length fields but we'll also be writing
* those ourselves, so we strip them out of
* bytes */
return bytesEncoder(stripTagAndLength(bytes)), nil
}
startingField = 1
}
switch n1 := n - startingField; n1 {
case 0:
return bytesEncoder(nil), nil
case 1:
return makeField(v.Field(startingField), parseFieldParameters(t.Field(startingField).Tag.Get("asn1")))
default:
m := make([]encoder, n1)
for i := 0; i < n1; i++ {
m[i], err = makeField(v.Field(i+startingField), parseFieldParameters(t.Field(i+startingField).Tag.Get("asn1")))
if err != nil {
return nil, err
}
}
return multiEncoder(m), nil
}
case reflect.Slice:
sliceType := v.Type()
if sliceType.Elem().Kind() == reflect.Uint8 {
return bytesEncoder(v.Bytes()), nil
}
var fp fieldParameters
switch l := v.Len(); l {
case 0:
return bytesEncoder(nil), nil
case 1:
return makeField(v.Index(0), fp)
default:
m := make([]encoder, l)
for i := 0; i < l; i++ {
m[i], err = makeField(v.Index(i), fp)
if err != nil {
return nil, err
}
}
if params.set {
return setEncoder(m), nil
}
return multiEncoder(m), nil
}
case reflect.String:
switch params.stringType {
case TagIA5String:
return makeIA5String(v.String())
case TagPrintableString:
return makePrintableString(v.String())
case TagNumericString:
return makeNumericString(v.String())
default:
return makeUTF8String(v.String()), nil
}
}
return nil, StructuralError{"unknown Go type"}
}
func makeField(v reflect.Value, params fieldParameters) (e encoder, err error) {
if !v.IsValid() {
return nil, fmt.Errorf("asn1: cannot marshal nil value")
}
// If the field is an interface{} then recurse into it.
if v.Kind() == reflect.Interface && v.Type().NumMethod() == 0 {
return makeField(v.Elem(), params)
}
if v.Kind() == reflect.Slice && v.Len() == 0 && params.omitEmpty {
return bytesEncoder(nil), nil
}
if params.optional && params.defaultValue != nil && canHaveDefaultValue(v.Kind()) {
defaultValue := reflect.New(v.Type()).Elem()
defaultValue.SetInt(*params.defaultValue)
if reflect.DeepEqual(v.Interface(), defaultValue.Interface()) {
return bytesEncoder(nil), nil
}
}
// If no default value is given then the zero value for the type is
// assumed to be the default value. This isn't obviously the correct
// behavior, but it's what Go has traditionally done.
if params.optional && params.defaultValue == nil {
if reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) {
return bytesEncoder(nil), nil
}
}
if v.Type() == rawValueType {
rv := v.Interface().(RawValue)
if len(rv.FullBytes) != 0 {
return bytesEncoder(rv.FullBytes), nil
}
t := new(taggedEncoder)
t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound}))
t.body = bytesEncoder(rv.Bytes)
return t, nil
}
matchAny, tag, isCompound, ok := getUniversalType(v.Type())
if !ok || matchAny {
return nil, StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())}
}
if params.timeType != 0 && tag != TagUTCTime {
return nil, StructuralError{"explicit time type given to non-time member"}
}
if params.stringType != 0 && tag != TagPrintableString {
return nil, StructuralError{"explicit string type given to non-string member"}
}
switch tag {
case TagPrintableString:
if params.stringType == 0 {
// This is a string without an explicit string type. We'll use
// a PrintableString if the character set in the string is
// sufficiently limited, otherwise we'll use a UTF8String.
for _, r := range v.String() {
if r >= utf8.RuneSelf || !isPrintable(byte(r), rejectAsterisk, rejectAmpersand) {
if !utf8.ValidString(v.String()) {
return nil, errors.New("asn1: string not valid UTF-8")
}
tag = TagUTF8String
break
}
}
} else {
tag = params.stringType
}
case TagUTCTime:
if params.timeType == TagGeneralizedTime || outsideUTCRange(v.Interface().(time.Time)) {
tag = TagGeneralizedTime
}
}
if params.set {
if tag != TagSequence {
return nil, StructuralError{"non sequence tagged as set"}
}
tag = TagSet
}
// makeField can be called for a slice that should be treated as a SET
// but doesn't have params.set set, for instance when using a slice
// with the SET type name suffix. In this case getUniversalType returns
// TagSet, but makeBody doesn't know about that so will treat the slice
// as a sequence. To work around this we set params.set.
if tag == TagSet && !params.set {
params.set = true
}
t := new(taggedEncoder)
t.body, err = makeBody(v, params)
if err != nil {
return nil, err
}
bodyLen := t.body.Len()
class := ClassUniversal
if params.tag != nil {
if params.application {
class = ClassApplication
} else if params.private {
class = ClassPrivate
} else {
class = ClassContextSpecific
}
if params.explicit {
t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{ClassUniversal, tag, bodyLen, isCompound}))
tt := new(taggedEncoder)
tt.body = t
tt.tag = bytesEncoder(appendTagAndLength(tt.scratch[:0], tagAndLength{
class: class,
tag: *params.tag,
length: bodyLen + t.tag.Len(),
isCompound: true,
}))
return tt, nil
}
// implicit tag.
tag = *params.tag
}
t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{class, tag, bodyLen, isCompound}))
return t, nil
}
// Marshal returns the ASN.1 encoding of val.
//
// In addition to the struct tags recognised by Unmarshal, the following can be
// used:
//
// ia5: causes strings to be marshaled as ASN.1 IA5String values
// omitempty: causes empty slices to be skipped
// printable: causes strings to be marshaled as ASN.1 PrintableString values
// utf8: causes strings to be marshaled as ASN.1 UTF8String values
// utc: causes time.Time to be marshaled as ASN.1 UTCTime values
// generalized: causes time.Time to be marshaled as ASN.1 GeneralizedTime values
func Marshal(val any) ([]byte, error) {
return MarshalWithParams(val, "")
}
// MarshalWithParams allows field parameters to be specified for the
// top-level element. The form of the params is the same as the field tags.
func MarshalWithParams(val any, params string) ([]byte, error) {
e, err := makeField(reflect.ValueOf(val), parseFieldParameters(params))
if err != nil {
return nil, err
}
b := make([]byte, e.Len())
e.Encode(b)
return b, nil
}
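// An illustrative sketch of typical use (hypothetical values): marshaling
// a two-field struct produces a DER SEQUENCE of two INTEGERs.
//
//	b, err := Marshal(struct{ A, B int }{1, 2})
//	// err == nil
//	// b == []byte{0x30, 0x06, 0x02, 0x01, 0x01, 0x02, 0x01, 0x02}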
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package base32 implements base32 encoding as specified by RFC 4648.
package base32
import (
"io"
"strconv"
)
/*
* Encodings
*/
// An Encoding is a radix 32 encoding/decoding scheme, defined by a
// 32-character alphabet. The most common is the "base32" encoding
// introduced for SASL GSSAPI and standardized in RFC 4648.
// The alternate "base32hex" encoding is used in DNSSEC.
type Encoding struct {
encode [32]byte
decodeMap [256]byte
padChar rune
}
const (
StdPadding rune = '=' // Standard padding character
NoPadding rune = -1 // No padding
decodeMapInitialize = "" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
)
const encodeStd = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"
const encodeHex = "0123456789ABCDEFGHIJKLMNOPQRSTUV"
// NewEncoding returns a new Encoding defined by the given alphabet,
// which must be a 32-byte string.
func NewEncoding(encoder string) *Encoding {
if len(encoder) != 32 {
panic("encoding alphabet is not 32-bytes long")
}
e := new(Encoding)
e.padChar = StdPadding
copy(e.encode[:], encoder)
copy(e.decodeMap[:], decodeMapInitialize)
for i := 0; i < len(encoder); i++ {
e.decodeMap[encoder[i]] = byte(i)
}
return e
}
// StdEncoding is the standard base32 encoding, as defined in
// RFC 4648.
var StdEncoding = NewEncoding(encodeStd)
// HexEncoding is the “Extended Hex Alphabet” defined in RFC 4648.
// It is typically used in DNS.
var HexEncoding = NewEncoding(encodeHex)
// WithPadding creates a new encoding identical to enc except
// with a specified padding character, or NoPadding to disable padding.
// The padding character must not be '\r' or '\n', must not
// be contained in the encoding's alphabet, and must be a rune equal to or
// less than '\xff'.
func (enc Encoding) WithPadding(padding rune) *Encoding {
if padding == '\r' || padding == '\n' || padding > 0xff {
panic("invalid padding")
}
for i := 0; i < len(enc.encode); i++ {
if rune(enc.encode[i]) == padding {
panic("padding contained in alphabet")
}
}
enc.padChar = padding
return &enc
}
/*
* Encoder
*/
// Encode encodes src using the encoding enc, writing
// EncodedLen(len(src)) bytes to dst.
//
// The encoding pads the output to a multiple of 8 bytes,
// so Encode is not appropriate for use on individual blocks
// of a large data stream. Use NewEncoder() instead.
func (enc *Encoding) Encode(dst, src []byte) {
for len(src) > 0 {
var b [8]byte
// Unpack 8x 5-bit source blocks into a 5 byte
// destination quantum
switch len(src) {
default:
b[7] = src[4] & 0x1F
b[6] = src[4] >> 5
fallthrough
case 4:
b[6] |= (src[3] << 3) & 0x1F
b[5] = (src[3] >> 2) & 0x1F
b[4] = src[3] >> 7
fallthrough
case 3:
b[4] |= (src[2] << 1) & 0x1F
b[3] = (src[2] >> 4) & 0x1F
fallthrough
case 2:
b[3] |= (src[1] << 4) & 0x1F
b[2] = (src[1] >> 1) & 0x1F
b[1] = (src[1] >> 6) & 0x1F
fallthrough
case 1:
b[1] |= (src[0] << 2) & 0x1F
b[0] = src[0] >> 3
}
// Encode 5-bit blocks using the base32 alphabet
size := len(dst)
if size >= 8 {
// Common case, unrolled for extra performance
dst[0] = enc.encode[b[0]&31]
dst[1] = enc.encode[b[1]&31]
dst[2] = enc.encode[b[2]&31]
dst[3] = enc.encode[b[3]&31]
dst[4] = enc.encode[b[4]&31]
dst[5] = enc.encode[b[5]&31]
dst[6] = enc.encode[b[6]&31]
dst[7] = enc.encode[b[7]&31]
} else {
for i := 0; i < size; i++ {
dst[i] = enc.encode[b[i]&31]
}
}
// Pad the final quantum
if len(src) < 5 {
if enc.padChar == NoPadding {
break
}
dst[7] = byte(enc.padChar)
if len(src) < 4 {
dst[6] = byte(enc.padChar)
dst[5] = byte(enc.padChar)
if len(src) < 3 {
dst[4] = byte(enc.padChar)
if len(src) < 2 {
dst[3] = byte(enc.padChar)
dst[2] = byte(enc.padChar)
}
}
}
break
}
src = src[5:]
dst = dst[8:]
}
}
// EncodeToString returns the base32 encoding of src.
func (enc *Encoding) EncodeToString(src []byte) string {
buf := make([]byte, enc.EncodedLen(len(src)))
enc.Encode(buf, src)
return string(buf)
}
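// For example, using the RFC 4648 test vector for "foo":
//
//	StdEncoding.EncodeToString([]byte("foo")) // "MZXW6==="
//
// Three input bytes fill five of the eight 5-bit groups, so the final
// quantum is padded with three '=' characters.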
type encoder struct {
err error
enc *Encoding
w io.Writer
buf [5]byte // buffered data waiting to be encoded
nbuf int // number of bytes in buf
out [1024]byte // output buffer
}
func (e *encoder) Write(p []byte) (n int, err error) {
if e.err != nil {
return 0, e.err
}
// Leading fringe.
if e.nbuf > 0 {
var i int
for i = 0; i < len(p) && e.nbuf < 5; i++ {
e.buf[e.nbuf] = p[i]
e.nbuf++
}
n += i
p = p[i:]
if e.nbuf < 5 {
return
}
e.enc.Encode(e.out[0:], e.buf[0:])
if _, e.err = e.w.Write(e.out[0:8]); e.err != nil {
return n, e.err
}
e.nbuf = 0
}
// Large interior chunks.
for len(p) >= 5 {
nn := len(e.out) / 8 * 5
if nn > len(p) {
nn = len(p)
nn -= nn % 5
}
e.enc.Encode(e.out[0:], p[0:nn])
if _, e.err = e.w.Write(e.out[0 : nn/5*8]); e.err != nil {
return n, e.err
}
n += nn
p = p[nn:]
}
// Trailing fringe.
copy(e.buf[:], p)
e.nbuf = len(p)
n += len(p)
return
}
// Close flushes any pending output from the encoder.
// It is an error to call Write after calling Close.
func (e *encoder) Close() error {
// If there's anything left in the buffer, flush it out
if e.err == nil && e.nbuf > 0 {
e.enc.Encode(e.out[0:], e.buf[0:e.nbuf])
encodedLen := e.enc.EncodedLen(e.nbuf)
e.nbuf = 0
_, e.err = e.w.Write(e.out[0:encodedLen])
}
return e.err
}
// NewEncoder returns a new base32 stream encoder. Data written to
// the returned writer will be encoded using enc and then written to w.
// Base32 encodings operate in 5-byte blocks; when finished
// writing, the caller must Close the returned encoder to flush any
// partially written blocks.
func NewEncoder(enc *Encoding, w io.Writer) io.WriteCloser {
return &encoder{enc: enc, w: w}
}
// EncodedLen returns the length in bytes of the base32 encoding
// of an input buffer of length n.
func (enc *Encoding) EncodedLen(n int) int {
if enc.padChar == NoPadding {
return (n*8 + 4) / 5
}
return (n + 4) / 5 * 8
}
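// For example, EncodedLen(5) == 8 and EncodedLen(6) == 16 with the default
// padding, while StdEncoding.WithPadding(NoPadding).EncodedLen(6) ==
// (6*8+4)/5 == 10.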
/*
* Decoder
*/
type CorruptInputError int64
func (e CorruptInputError) Error() string {
return "illegal base32 data at input byte " + strconv.FormatInt(int64(e), 10)
}
// decode is like Decode but returns an additional 'end' value, which
// indicates if end-of-message padding was encountered and thus any
// additional data is an error. This method assumes that src has been
// stripped of all supported whitespace ('\r' and '\n').
func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
// Lift the nil check outside of the loop.
_ = enc.decodeMap
dsti := 0
olen := len(src)
for len(src) > 0 && !end {
// Decode quantum using the base32 alphabet
var dbuf [8]byte
dlen := 8
for j := 0; j < 8; {
if len(src) == 0 {
if enc.padChar != NoPadding {
// We have reached the end and are missing padding
return n, false, CorruptInputError(olen - len(src) - j)
}
// We have reached the end and are not expecting any padding
dlen, end = j, true
break
}
in := src[0]
src = src[1:]
if in == byte(enc.padChar) && j >= 2 && len(src) < 8 {
// We've reached the end and there's padding
if len(src)+j < 8-1 {
// not enough padding
return n, false, CorruptInputError(olen)
}
for k := 0; k < 8-1-j; k++ {
if len(src) > k && src[k] != byte(enc.padChar) {
// incorrect padding
return n, false, CorruptInputError(olen - len(src) + k - 1)
}
}
dlen, end = j, true
// 7, 5 and 2 are not valid padding lengths, and so 1, 3 and 6 are not
// valid dlen values. See RFC 4648 Section 6 "Base 32 Encoding" listing
// the five valid padding lengths, and Section 9 "Illustrations and
// Examples" for an illustration for how the 1st, 3rd and 6th base32
// src bytes do not yield enough information to decode a dst byte.
if dlen == 1 || dlen == 3 || dlen == 6 {
return n, false, CorruptInputError(olen - len(src) - 1)
}
break
}
dbuf[j] = enc.decodeMap[in]
if dbuf[j] == 0xFF {
return n, false, CorruptInputError(olen - len(src) - 1)
}
j++
}
// Pack 8x 5-bit source blocks into 5 byte destination
// quantum
switch dlen {
case 8:
dst[dsti+4] = dbuf[6]<<5 | dbuf[7]
n++
fallthrough
case 7:
dst[dsti+3] = dbuf[4]<<7 | dbuf[5]<<2 | dbuf[6]>>3
n++
fallthrough
case 5:
dst[dsti+2] = dbuf[3]<<4 | dbuf[4]>>1
n++
fallthrough
case 4:
dst[dsti+1] = dbuf[1]<<6 | dbuf[2]<<1 | dbuf[3]>>4
n++
fallthrough
case 2:
dst[dsti+0] = dbuf[0]<<3 | dbuf[1]>>2
n++
}
dsti += 5
}
return n, end, nil
}
// Decode decodes src using the encoding enc. It writes at most
// DecodedLen(len(src)) bytes to dst and returns the number of bytes
// written. If src contains invalid base32 data, it will return the
// number of bytes successfully written and CorruptInputError.
// New line characters (\r and \n) are ignored.
func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
buf := make([]byte, len(src))
l := stripNewlines(buf, src)
n, _, err = enc.decode(dst, buf[:l])
return
}
// DecodeString returns the bytes represented by the base32 string s.
func (enc *Encoding) DecodeString(s string) ([]byte, error) {
buf := []byte(s)
l := stripNewlines(buf, buf)
n, _, err := enc.decode(buf, buf[:l])
return buf[:n], err
}
type decoder struct {
err error
enc *Encoding
r io.Reader
end bool // saw end of message
buf [1024]byte // leftover input
nbuf int
out []byte // leftover decoded output
outbuf [1024 / 8 * 5]byte
}
func readEncodedData(r io.Reader, buf []byte, min int, expectsPadding bool) (n int, err error) {
for n < min && err == nil {
var nn int
nn, err = r.Read(buf[n:])
n += nn
}
// Some data was read, but fewer than min bytes; treat EOF as unexpected.
if n < min && n > 0 && err == io.EOF {
err = io.ErrUnexpectedEOF
}
// No data was read and the buffer already holds a partial quantum.
// With padding expected this is an unexpected EOF; when padding is
// disabled it is not an error, as the message can be of any length.
if expectsPadding && min < 8 && n == 0 && err == io.EOF {
err = io.ErrUnexpectedEOF
}
return
}
func (d *decoder) Read(p []byte) (n int, err error) {
// Use leftover decoded output from last read.
if len(d.out) > 0 {
n = copy(p, d.out)
d.out = d.out[n:]
if len(d.out) == 0 {
return n, d.err
}
return n, nil
}
if d.err != nil {
return 0, d.err
}
// Read a chunk.
nn := len(p) / 5 * 8
if nn < 8 {
nn = 8
}
if nn > len(d.buf) {
nn = len(d.buf)
}
// Minimum amount of bytes that needs to be read each cycle
var min int
var expectsPadding bool
if d.enc.padChar == NoPadding {
min = 1
expectsPadding = false
} else {
min = 8 - d.nbuf
expectsPadding = true
}
nn, d.err = readEncodedData(d.r, d.buf[d.nbuf:nn], min, expectsPadding)
d.nbuf += nn
if d.nbuf < min {
return 0, d.err
}
if nn > 0 && d.end {
return 0, CorruptInputError(0)
}
// Decode chunk into p, or d.out and then p if p is too small.
var nr int
if d.enc.padChar == NoPadding {
nr = d.nbuf
} else {
nr = d.nbuf / 8 * 8
}
nw := d.enc.DecodedLen(d.nbuf)
if nw > len(p) {
nw, d.end, err = d.enc.decode(d.outbuf[0:], d.buf[0:nr])
d.out = d.outbuf[0:nw]
n = copy(p, d.out)
d.out = d.out[n:]
} else {
n, d.end, err = d.enc.decode(p, d.buf[0:nr])
}
d.nbuf -= nr
for i := 0; i < d.nbuf; i++ {
d.buf[i] = d.buf[i+nr]
}
if err != nil && (d.err == nil || d.err == io.EOF) {
d.err = err
}
if len(d.out) > 0 {
// We cannot return all the decoded bytes to the caller in this
// invocation of Read, so we return a nil error to ensure that Read
// will be called again. The error stored in d.err, if any, will be
// returned with the last set of decoded bytes.
return n, nil
}
return n, d.err
}
type newlineFilteringReader struct {
wrapped io.Reader
}
// stripNewlines removes newline characters and returns the number
// of non-newline characters copied to dst.
func stripNewlines(dst, src []byte) int {
offset := 0
for _, b := range src {
if b == '\r' || b == '\n' {
continue
}
dst[offset] = b
offset++
}
return offset
}
func (r *newlineFilteringReader) Read(p []byte) (int, error) {
n, err := r.wrapped.Read(p)
for n > 0 {
s := p[0:n]
offset := stripNewlines(s, s)
if err != nil || offset > 0 {
return offset, err
}
// Previous buffer entirely whitespace, read again
n, err = r.wrapped.Read(p)
}
return n, err
}
// NewDecoder constructs a new base32 stream decoder.
func NewDecoder(enc *Encoding, r io.Reader) io.Reader {
return &decoder{enc: enc, r: &newlineFilteringReader{r}}
}
// DecodedLen returns the maximum length in bytes of the decoded data
// corresponding to n bytes of base32-encoded data.
func (enc *Encoding) DecodedLen(n int) int {
if enc.padChar == NoPadding {
return n * 5 / 8
}
return n / 8 * 5
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package base64 implements base64 encoding as specified by RFC 4648.
package base64
import (
"encoding/binary"
"io"
"strconv"
)
/*
* Encodings
*/
// An Encoding is a radix 64 encoding/decoding scheme, defined by a
// 64-character alphabet. The most common encoding is the "base64"
// encoding defined in RFC 4648 and used in MIME (RFC 2045) and PEM
// (RFC 1421). RFC 4648 also defines an alternate encoding, which is
// the standard encoding with - and _ substituted for + and /.
type Encoding struct {
encode [64]byte
decodeMap [256]byte
padChar rune
strict bool
}
const (
StdPadding rune = '=' // Standard padding character
NoPadding rune = -1 // No padding
decodeMapInitialize = "" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
)
const encodeStd = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
const encodeURL = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
// NewEncoding returns a new padded Encoding defined by the given alphabet,
// which must be a 64-byte string that does not contain the padding character
// or CR / LF ('\r', '\n').
// The resulting Encoding uses the default padding character ('='),
// which may be changed or disabled via WithPadding.
func NewEncoding(encoder string) *Encoding {
if len(encoder) != 64 {
panic("encoding alphabet is not 64-bytes long")
}
for i := 0; i < len(encoder); i++ {
if encoder[i] == '\n' || encoder[i] == '\r' {
panic("encoding alphabet contains newline character")
}
}
e := new(Encoding)
e.padChar = StdPadding
copy(e.encode[:], encoder)
copy(e.decodeMap[:], decodeMapInitialize)
for i := 0; i < len(encoder); i++ {
e.decodeMap[encoder[i]] = byte(i)
}
return e
}
// WithPadding creates a new encoding identical to enc except
// with a specified padding character, or NoPadding to disable padding.
// The padding character must not be '\r' or '\n', must not
// be contained in the encoding's alphabet, and must be a rune equal to or
// less than '\xff'.
func (enc Encoding) WithPadding(padding rune) *Encoding {
if padding == '\r' || padding == '\n' || padding > 0xff {
panic("invalid padding")
}
for i := 0; i < len(enc.encode); i++ {
if rune(enc.encode[i]) == padding {
panic("padding contained in alphabet")
}
}
enc.padChar = padding
return &enc
}
// Strict creates a new encoding identical to enc except with
// strict decoding enabled. In this mode, the decoder requires that
// trailing padding bits are zero, as described in RFC 4648 section 3.5.
//
// Note that the input is still malleable, as new line characters
// (CR and LF) are still ignored.
func (enc Encoding) Strict() *Encoding {
enc.strict = true
return &enc
}
// StdEncoding is the standard base64 encoding, as defined in
// RFC 4648.
var StdEncoding = NewEncoding(encodeStd)
// URLEncoding is the alternate base64 encoding defined in RFC 4648.
// It is typically used in URLs and file names.
var URLEncoding = NewEncoding(encodeURL)
// RawStdEncoding is the standard raw, unpadded base64 encoding,
// as defined in RFC 4648 section 3.2.
// This is the same as StdEncoding but omits padding characters.
var RawStdEncoding = StdEncoding.WithPadding(NoPadding)
// RawURLEncoding is the unpadded alternate base64 encoding defined in RFC 4648.
// It is typically used in URLs and file names.
// This is the same as URLEncoding but omits padding characters.
var RawURLEncoding = URLEncoding.WithPadding(NoPadding)
/*
* Encoder
*/
// Encode encodes src using the encoding enc, writing
// EncodedLen(len(src)) bytes to dst.
//
// The encoding pads the output to a multiple of 4 bytes,
// so Encode is not appropriate for use on individual blocks
// of a large data stream. Use NewEncoder() instead.
func (enc *Encoding) Encode(dst, src []byte) {
if len(src) == 0 {
return
}
// enc is a pointer receiver, so the use of enc.encode within the hot
// loop below means a nil check at every operation. Lift that nil check
// outside of the loop to speed up the encoder.
_ = enc.encode
di, si := 0, 0
n := (len(src) / 3) * 3
for si < n {
// Convert 3x 8bit source bytes into 4 bytes
val := uint(src[si+0])<<16 | uint(src[si+1])<<8 | uint(src[si+2])
dst[di+0] = enc.encode[val>>18&0x3F]
dst[di+1] = enc.encode[val>>12&0x3F]
dst[di+2] = enc.encode[val>>6&0x3F]
dst[di+3] = enc.encode[val&0x3F]
si += 3
di += 4
}
remain := len(src) - si
if remain == 0 {
return
}
// Add the remaining small block
val := uint(src[si+0]) << 16
if remain == 2 {
val |= uint(src[si+1]) << 8
}
dst[di+0] = enc.encode[val>>18&0x3F]
dst[di+1] = enc.encode[val>>12&0x3F]
switch remain {
case 2:
dst[di+2] = enc.encode[val>>6&0x3F]
if enc.padChar != NoPadding {
dst[di+3] = byte(enc.padChar)
}
case 1:
if enc.padChar != NoPadding {
dst[di+2] = byte(enc.padChar)
dst[di+3] = byte(enc.padChar)
}
}
}
// EncodeToString returns the base64 encoding of src.
func (enc *Encoding) EncodeToString(src []byte) string {
buf := make([]byte, enc.EncodedLen(len(src)))
enc.Encode(buf, src)
return string(buf)
}
type encoder struct {
err error
enc *Encoding
w io.Writer
buf [3]byte // buffered data waiting to be encoded
nbuf int // number of bytes in buf
out [1024]byte // output buffer
}
func (e *encoder) Write(p []byte) (n int, err error) {
if e.err != nil {
return 0, e.err
}
// Leading fringe.
if e.nbuf > 0 {
var i int
for i = 0; i < len(p) && e.nbuf < 3; i++ {
e.buf[e.nbuf] = p[i]
e.nbuf++
}
n += i
p = p[i:]
if e.nbuf < 3 {
return
}
e.enc.Encode(e.out[:], e.buf[:])
if _, e.err = e.w.Write(e.out[:4]); e.err != nil {
return n, e.err
}
e.nbuf = 0
}
// Large interior chunks.
for len(p) >= 3 {
nn := len(e.out) / 4 * 3
if nn > len(p) {
nn = len(p)
nn -= nn % 3
}
e.enc.Encode(e.out[:], p[:nn])
if _, e.err = e.w.Write(e.out[0 : nn/3*4]); e.err != nil {
return n, e.err
}
n += nn
p = p[nn:]
}
// Trailing fringe.
copy(e.buf[:], p)
e.nbuf = len(p)
n += len(p)
return
}
// Close flushes any pending output from the encoder.
// It is an error to call Write after calling Close.
func (e *encoder) Close() error {
// If there's anything left in the buffer, flush it out
if e.err == nil && e.nbuf > 0 {
e.enc.Encode(e.out[:], e.buf[:e.nbuf])
_, e.err = e.w.Write(e.out[:e.enc.EncodedLen(e.nbuf)])
e.nbuf = 0
}
return e.err
}
// NewEncoder returns a new base64 stream encoder. Data written to
// the returned writer will be encoded using enc and then written to w.
// Base64 encodings operate in 4-byte blocks; when finished
// writing, the caller must Close the returned encoder to flush any
// partially written blocks.
func NewEncoder(enc *Encoding, w io.Writer) io.WriteCloser {
return &encoder{enc: enc, w: w}
}
// EncodedLen returns the length in bytes of the base64 encoding
// of an input buffer of length n.
func (enc *Encoding) EncodedLen(n int) int {
if enc.padChar == NoPadding {
return (n*8 + 5) / 6 // minimum # chars at 6 bits per char
}
return (n + 2) / 3 * 4 // minimum # 4-char quanta, 3 bytes each
}
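// For example, StdEncoding.EncodedLen(5) == (5+2)/3*4 == 8, while
// RawStdEncoding.EncodedLen(5) == (5*8+5)/6 == 7: unpadded output drops
// the trailing '=' characters but still needs seven 6-bit characters to
// carry 40 bits.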
/*
* Decoder
*/
type CorruptInputError int64
func (e CorruptInputError) Error() string {
return "illegal base64 data at input byte " + strconv.FormatInt(int64(e), 10)
}
// decodeQuantum decodes up to 4 base64 bytes. The received parameters are
// the destination buffer dst, the source buffer src and an index in the
// source buffer si.
// It returns the number of bytes read from src, the number of bytes written
// to dst, and an error, if any.
func (enc *Encoding) decodeQuantum(dst, src []byte, si int) (nsi, n int, err error) {
// Decode quantum using the base64 alphabet
var dbuf [4]byte
dlen := 4
// Lift the nil check outside of the loop.
_ = enc.decodeMap
for j := 0; j < len(dbuf); j++ {
if len(src) == si {
switch {
case j == 0:
return si, 0, nil
case j == 1, enc.padChar != NoPadding:
return si, 0, CorruptInputError(si - j)
}
dlen = j
break
}
in := src[si]
si++
out := enc.decodeMap[in]
if out != 0xff {
dbuf[j] = out
continue
}
if in == '\n' || in == '\r' {
j--
continue
}
if rune(in) != enc.padChar {
return si, 0, CorruptInputError(si - 1)
}
// We've reached the end and there's padding
switch j {
case 0, 1:
// incorrect padding
return si, 0, CorruptInputError(si - 1)
case 2:
// "==" is expected, the first "=" is already consumed.
// skip over newlines
for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
si++
}
if si == len(src) {
// not enough padding
return si, 0, CorruptInputError(len(src))
}
if rune(src[si]) != enc.padChar {
// incorrect padding
return si, 0, CorruptInputError(si - 1)
}
si++
}
// skip over newlines
for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
si++
}
if si < len(src) {
// trailing garbage
err = CorruptInputError(si)
}
dlen = j
break
}
// Convert 4x 6bit source bytes into 3 bytes
val := uint(dbuf[0])<<18 | uint(dbuf[1])<<12 | uint(dbuf[2])<<6 | uint(dbuf[3])
dbuf[2], dbuf[1], dbuf[0] = byte(val>>0), byte(val>>8), byte(val>>16)
switch dlen {
case 4:
dst[2] = dbuf[2]
dbuf[2] = 0
fallthrough
case 3:
dst[1] = dbuf[1]
if enc.strict && dbuf[2] != 0 {
return si, 0, CorruptInputError(si - 1)
}
dbuf[1] = 0
fallthrough
case 2:
dst[0] = dbuf[0]
if enc.strict && (dbuf[1] != 0 || dbuf[2] != 0) {
return si, 0, CorruptInputError(si - 2)
}
}
return si, dlen - 1, err
}
// DecodeString returns the bytes represented by the base64 string s.
func (enc *Encoding) DecodeString(s string) ([]byte, error) {
dbuf := make([]byte, enc.DecodedLen(len(s)))
n, err := enc.Decode(dbuf, []byte(s))
return dbuf[:n], err
}
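// decodeStringSketch is a minimal, hypothetical sketch; the name is
// illustrative. Malformed input surfaces as a CorruptInputError carrying
// the offset of the offending byte.
func decodeStringSketch() ([]byte, error) {
	b, err := StdEncoding.DecodeString("aGVsbG8=")
	if err != nil {
		return nil, err
	}
	return b, nil // []byte("hello")
}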
type decoder struct {
err error
readErr error // error from r.Read
enc *Encoding
r io.Reader
buf [1024]byte // leftover input
nbuf int
out []byte // leftover decoded output
outbuf [1024 / 4 * 3]byte
}
func (d *decoder) Read(p []byte) (n int, err error) {
// Use leftover decoded output from last read.
if len(d.out) > 0 {
n = copy(p, d.out)
d.out = d.out[n:]
return n, nil
}
if d.err != nil {
return 0, d.err
}
// This code assumes that d.r strips supported whitespace ('\r' and '\n').
// Refill buffer.
for d.nbuf < 4 && d.readErr == nil {
nn := len(p) / 3 * 4
if nn < 4 {
nn = 4
}
if nn > len(d.buf) {
nn = len(d.buf)
}
nn, d.readErr = d.r.Read(d.buf[d.nbuf:nn])
d.nbuf += nn
}
if d.nbuf < 4 {
if d.enc.padChar == NoPadding && d.nbuf > 0 {
// Decode final fragment, without padding.
var nw int
nw, d.err = d.enc.Decode(d.outbuf[:], d.buf[:d.nbuf])
d.nbuf = 0
d.out = d.outbuf[:nw]
n = copy(p, d.out)
d.out = d.out[n:]
if n > 0 || len(p) == 0 && len(d.out) > 0 {
return n, nil
}
if d.err != nil {
return 0, d.err
}
}
d.err = d.readErr
if d.err == io.EOF && d.nbuf > 0 {
d.err = io.ErrUnexpectedEOF
}
return 0, d.err
}
// Decode chunk into p, or d.out and then p if p is too small.
nr := d.nbuf / 4 * 4
nw := d.nbuf / 4 * 3
if nw > len(p) {
nw, d.err = d.enc.Decode(d.outbuf[:], d.buf[:nr])
d.out = d.outbuf[:nw]
n = copy(p, d.out)
d.out = d.out[n:]
} else {
n, d.err = d.enc.Decode(p, d.buf[:nr])
}
d.nbuf -= nr
copy(d.buf[:d.nbuf], d.buf[nr:])
return n, d.err
}
// Decode decodes src using the encoding enc. It writes at most
// DecodedLen(len(src)) bytes to dst and returns the number of bytes
// written. If src contains invalid base64 data, it will return the
// number of bytes successfully written and CorruptInputError.
// New line characters (\r and \n) are ignored.
func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
if len(src) == 0 {
return 0, nil
}
// Lift the nil check outside of the loop. enc.decodeMap is directly
// used later in this function, to let the compiler know that the
// receiver can't be nil.
_ = enc.decodeMap
si := 0
for strconv.IntSize >= 64 && len(src)-si >= 8 && len(dst)-n >= 8 {
src2 := src[si : si+8]
if dn, ok := assemble64(
enc.decodeMap[src2[0]],
enc.decodeMap[src2[1]],
enc.decodeMap[src2[2]],
enc.decodeMap[src2[3]],
enc.decodeMap[src2[4]],
enc.decodeMap[src2[5]],
enc.decodeMap[src2[6]],
enc.decodeMap[src2[7]],
); ok {
binary.BigEndian.PutUint64(dst[n:], dn)
n += 6
si += 8
} else {
var ninc int
si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
n += ninc
if err != nil {
return n, err
}
}
}
for len(src)-si >= 4 && len(dst)-n >= 4 {
src2 := src[si : si+4]
if dn, ok := assemble32(
enc.decodeMap[src2[0]],
enc.decodeMap[src2[1]],
enc.decodeMap[src2[2]],
enc.decodeMap[src2[3]],
); ok {
binary.BigEndian.PutUint32(dst[n:], dn)
n += 3
si += 4
} else {
var ninc int
si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
n += ninc
if err != nil {
return n, err
}
}
}
for si < len(src) {
var ninc int
si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
n += ninc
if err != nil {
return n, err
}
}
return n, err
}
// assemble32 assembles 4 base64 digits into 3 bytes.
// Each digit comes from the decode map, and will be 0xff
// if it came from an invalid character.
func assemble32(n1, n2, n3, n4 byte) (dn uint32, ok bool) {
// Check that all the digits are valid. If any of them was 0xff, their
// bitwise OR will be 0xff.
if n1|n2|n3|n4 == 0xff {
return 0, false
}
return uint32(n1)<<26 |
uint32(n2)<<20 |
uint32(n3)<<14 |
uint32(n4)<<8,
true
}
// assemble64 assembles 8 base64 digits into 6 bytes.
// Each digit comes from the decode map, and will be 0xff
// if it came from an invalid character.
func assemble64(n1, n2, n3, n4, n5, n6, n7, n8 byte) (dn uint64, ok bool) {
// Check that all the digits are valid. If any of them was 0xff, their
// bitwise OR will be 0xff.
if n1|n2|n3|n4|n5|n6|n7|n8 == 0xff {
return 0, false
}
return uint64(n1)<<58 |
uint64(n2)<<52 |
uint64(n3)<<46 |
uint64(n4)<<40 |
uint64(n5)<<34 |
uint64(n6)<<28 |
uint64(n7)<<22 |
uint64(n8)<<16,
true
}
type newlineFilteringReader struct {
wrapped io.Reader
}
func (r *newlineFilteringReader) Read(p []byte) (int, error) {
n, err := r.wrapped.Read(p)
for n > 0 {
offset := 0
for i, b := range p[:n] {
if b != '\r' && b != '\n' {
if i != offset {
p[offset] = b
}
offset++
}
}
if offset > 0 {
return offset, err
}
// Previous buffer entirely whitespace, read again
n, err = r.wrapped.Read(p)
}
return n, err
}
// NewDecoder constructs a new base64 stream decoder.
func NewDecoder(enc *Encoding, r io.Reader) io.Reader {
return &decoder{enc: enc, r: &newlineFilteringReader{r}}
}
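// newDecoderSketch is a minimal, hypothetical sketch; the name is
// illustrative. Because NewDecoder wraps r in a newlineFilteringReader,
// line-wrapped base64 (as in MIME bodies) decodes transparently.
func newDecoderSketch(r io.Reader, w io.Writer) (int64, error) {
	return io.Copy(w, NewDecoder(StdEncoding, r))
}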
// DecodedLen returns the maximum length in bytes of the decoded data
// corresponding to n bytes of base64-encoded data.
func (enc *Encoding) DecodedLen(n int) int {
if enc.padChar == NoPadding {
// Unpadded data may end with partial block of 2-3 characters.
return n * 6 / 8
}
// Padded base64 should always be a multiple of 4 characters in length.
return n / 4 * 3
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package binary implements simple translation between numbers and byte
// sequences and encoding and decoding of varints.
//
// Numbers are translated by reading and writing fixed-size values.
// A fixed-size value is either a fixed-size arithmetic
// type (bool, int8, uint8, int16, float32, complex64, ...)
// or an array or struct containing only fixed-size values.
//
// The varint functions encode and decode single integer values using
// a variable-length encoding; smaller values require fewer bytes.
// For a specification, see
// https://developers.google.com/protocol-buffers/docs/encoding.
//
// This package favors simplicity over efficiency. Clients that require
// high-performance serialization, especially for large data structures,
// should look at more advanced solutions such as the encoding/gob
// package or protocol buffers.
package binary
import (
"errors"
"io"
"math"
"reflect"
"sync"
)
// A ByteOrder specifies how to convert byte slices into
// 16-, 32-, or 64-bit unsigned integers.
type ByteOrder interface {
Uint16([]byte) uint16
Uint32([]byte) uint32
Uint64([]byte) uint64
PutUint16([]byte, uint16)
PutUint32([]byte, uint32)
PutUint64([]byte, uint64)
String() string
}
// AppendByteOrder specifies how to append 16-, 32-, or 64-bit unsigned integers
// into a byte slice.
type AppendByteOrder interface {
AppendUint16([]byte, uint16) []byte
AppendUint32([]byte, uint32) []byte
AppendUint64([]byte, uint64) []byte
String() string
}
// LittleEndian is the little-endian implementation of ByteOrder and AppendByteOrder.
var LittleEndian littleEndian
// BigEndian is the big-endian implementation of ByteOrder and AppendByteOrder.
var BigEndian bigEndian
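// byteOrderSketch is a minimal, hypothetical sketch; the name is
// illustrative. PutUint32 and Uint32 are inverses for a given ByteOrder,
// and the two orders lay the same value out in reverse byte order.
func byteOrderSketch() {
	var b [4]byte
	BigEndian.PutUint32(b[:], 0x01020304)    // b == [4]byte{0x01, 0x02, 0x03, 0x04}
	_ = BigEndian.Uint32(b[:])               // 0x01020304 again
	LittleEndian.PutUint32(b[:], 0x01020304) // b == [4]byte{0x04, 0x03, 0x02, 0x01}
}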
type littleEndian struct{}
func (littleEndian) Uint16(b []byte) uint16 {
_ = b[1] // bounds check hint to compiler; see golang.org/issue/14808
return uint16(b[0]) | uint16(b[1])<<8
}
func (littleEndian) PutUint16(b []byte, v uint16) {
_ = b[1] // early bounds check to guarantee safety of writes below
b[0] = byte(v)
b[1] = byte(v >> 8)
}
func (littleEndian) AppendUint16(b []byte, v uint16) []byte {
return append(b,
byte(v),
byte(v>>8),
)
}
func (littleEndian) Uint32(b []byte) uint32 {
_ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
}
func (littleEndian) PutUint32(b []byte, v uint32) {
_ = b[3] // early bounds check to guarantee safety of writes below
b[0] = byte(v)
b[1] = byte(v >> 8)
b[2] = byte(v >> 16)
b[3] = byte(v >> 24)
}
func (littleEndian) AppendUint32(b []byte, v uint32) []byte {
return append(b,
byte(v),
byte(v>>8),
byte(v>>16),
byte(v>>24),
)
}
func (littleEndian) Uint64(b []byte) uint64 {
_ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 |
uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56
}
func (littleEndian) PutUint64(b []byte, v uint64) {
_ = b[7] // early bounds check to guarantee safety of writes below
b[0] = byte(v)
b[1] = byte(v >> 8)
b[2] = byte(v >> 16)
b[3] = byte(v >> 24)
b[4] = byte(v >> 32)
b[5] = byte(v >> 40)
b[6] = byte(v >> 48)
b[7] = byte(v >> 56)
}
func (littleEndian) AppendUint64(b []byte, v uint64) []byte {
return append(b,
byte(v),
byte(v>>8),
byte(v>>16),
byte(v>>24),
byte(v>>32),
byte(v>>40),
byte(v>>48),
byte(v>>56),
)
}
func (littleEndian) String() string { return "LittleEndian" }
func (littleEndian) GoString() string { return "binary.LittleEndian" }
type bigEndian struct{}
func (bigEndian) Uint16(b []byte) uint16 {
_ = b[1] // bounds check hint to compiler; see golang.org/issue/14808
return uint16(b[1]) | uint16(b[0])<<8
}
func (bigEndian) PutUint16(b []byte, v uint16) {
_ = b[1] // early bounds check to guarantee safety of writes below
b[0] = byte(v >> 8)
b[1] = byte(v)
}
func (bigEndian) AppendUint16(b []byte, v uint16) []byte {
return append(b,
byte(v>>8),
byte(v),
)
}
func (bigEndian) Uint32(b []byte) uint32 {
_ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
return uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24
}
func (bigEndian) PutUint32(b []byte, v uint32) {
_ = b[3] // early bounds check to guarantee safety of writes below
b[0] = byte(v >> 24)
b[1] = byte(v >> 16)
b[2] = byte(v >> 8)
b[3] = byte(v)
}
func (bigEndian) AppendUint32(b []byte, v uint32) []byte {
return append(b,
byte(v>>24),
byte(v>>16),
byte(v>>8),
byte(v),
)
}
func (bigEndian) Uint64(b []byte) uint64 {
_ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
}
func (bigEndian) PutUint64(b []byte, v uint64) {
_ = b[7] // early bounds check to guarantee safety of writes below
b[0] = byte(v >> 56)
b[1] = byte(v >> 48)
b[2] = byte(v >> 40)
b[3] = byte(v >> 32)
b[4] = byte(v >> 24)
b[5] = byte(v >> 16)
b[6] = byte(v >> 8)
b[7] = byte(v)
}
func (bigEndian) AppendUint64(b []byte, v uint64) []byte {
return append(b,
byte(v>>56),
byte(v>>48),
byte(v>>40),
byte(v>>32),
byte(v>>24),
byte(v>>16),
byte(v>>8),
byte(v),
)
}
func (bigEndian) String() string { return "BigEndian" }
func (bigEndian) GoString() string { return "binary.BigEndian" }
func (nativeEndian) String() string { return "NativeEndian" }
func (nativeEndian) GoString() string { return "binary.NativeEndian" }
// Read reads structured binary data from r into data.
// Data must be a pointer to a fixed-size value or a slice
// of fixed-size values.
// Bytes read from r are decoded using the specified byte order
// and written to successive fields of the data.
// When decoding boolean values, a zero byte is decoded as false, and
// any other non-zero byte is decoded as true.
// When reading into structs, the field data for fields with
// blank (_) field names is skipped; i.e., blank field names
// may be used for padding.
// When reading into a struct, all non-blank fields must be exported
// or Read may panic.
//
// The error is EOF only if no bytes were read.
// If an EOF happens after reading some but not all the bytes,
// Read returns ErrUnexpectedEOF.
func Read(r io.Reader, order ByteOrder, data any) error {
// Fast path for basic types and slices.
if n := intDataSize(data); n != 0 {
bs := make([]byte, n)
if _, err := io.ReadFull(r, bs); err != nil {
return err
}
switch data := data.(type) {
case *bool:
*data = bs[0] != 0
case *int8:
*data = int8(bs[0])
case *uint8:
*data = bs[0]
case *int16:
*data = int16(order.Uint16(bs))
case *uint16:
*data = order.Uint16(bs)
case *int32:
*data = int32(order.Uint32(bs))
case *uint32:
*data = order.Uint32(bs)
case *int64:
*data = int64(order.Uint64(bs))
case *uint64:
*data = order.Uint64(bs)
case *float32:
*data = math.Float32frombits(order.Uint32(bs))
case *float64:
*data = math.Float64frombits(order.Uint64(bs))
case []bool:
for i, x := range bs { // Easier to loop over the input for 8-bit values.
data[i] = x != 0
}
case []int8:
for i, x := range bs {
data[i] = int8(x)
}
case []uint8:
copy(data, bs)
case []int16:
for i := range data {
data[i] = int16(order.Uint16(bs[2*i:]))
}
case []uint16:
for i := range data {
data[i] = order.Uint16(bs[2*i:])
}
case []int32:
for i := range data {
data[i] = int32(order.Uint32(bs[4*i:]))
}
case []uint32:
for i := range data {
data[i] = order.Uint32(bs[4*i:])
}
case []int64:
for i := range data {
data[i] = int64(order.Uint64(bs[8*i:]))
}
case []uint64:
for i := range data {
data[i] = order.Uint64(bs[8*i:])
}
case []float32:
for i := range data {
data[i] = math.Float32frombits(order.Uint32(bs[4*i:]))
}
case []float64:
for i := range data {
data[i] = math.Float64frombits(order.Uint64(bs[8*i:]))
}
default:
n = 0 // fast path doesn't apply
}
if n != 0 {
return nil
}
}
// Fallback to reflect-based decoding.
v := reflect.ValueOf(data)
size := -1
switch v.Kind() {
case reflect.Pointer:
v = v.Elem()
size = dataSize(v)
case reflect.Slice:
size = dataSize(v)
}
if size < 0 {
return errors.New("binary.Read: invalid type " + reflect.TypeOf(data).String())
}
d := &decoder{order: order, buf: make([]byte, size)}
if _, err := io.ReadFull(r, d.buf); err != nil {
return err
}
d.value(v)
return nil
}
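// readSketch is a minimal, hypothetical sketch of Read; the name and the
// header layout are illustrative. The struct holds only fixed-size fields,
// so it decodes via the reflect path; a bare *uint32 would take the fast
// path above instead.
func readSketch(r io.Reader) error {
	var hdr struct {
		Magic uint32
		Count uint16
		_     uint16 // blank field: two bytes are read and discarded
	}
	return Read(r, BigEndian, &hdr)
}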
// Write writes the binary representation of data into w.
// Data must be a fixed-size value or a slice of fixed-size
// values, or a pointer to such data.
// Boolean values encode as one byte: 1 for true, and 0 for false.
// Bytes written to w are encoded using the specified byte order
// and read from successive fields of the data.
// When writing structs, zero values are written for fields
// with blank (_) field names.
func Write(w io.Writer, order ByteOrder, data any) error {
// Fast path for basic types and slices.
if n := intDataSize(data); n != 0 {
bs := make([]byte, n)
switch v := data.(type) {
case *bool:
if *v {
bs[0] = 1
} else {
bs[0] = 0
}
case bool:
if v {
bs[0] = 1
} else {
bs[0] = 0
}
case []bool:
for i, x := range v {
if x {
bs[i] = 1
} else {
bs[i] = 0
}
}
case *int8:
bs[0] = byte(*v)
case int8:
bs[0] = byte(v)
case []int8:
for i, x := range v {
bs[i] = byte(x)
}
case *uint8:
bs[0] = *v
case uint8:
bs[0] = v
case []uint8:
bs = v
case *int16:
order.PutUint16(bs, uint16(*v))
case int16:
order.PutUint16(bs, uint16(v))
case []int16:
for i, x := range v {
order.PutUint16(bs[2*i:], uint16(x))
}
case *uint16:
order.PutUint16(bs, *v)
case uint16:
order.PutUint16(bs, v)
case []uint16:
for i, x := range v {
order.PutUint16(bs[2*i:], x)
}
case *int32:
order.PutUint32(bs, uint32(*v))
case int32:
order.PutUint32(bs, uint32(v))
case []int32:
for i, x := range v {
order.PutUint32(bs[4*i:], uint32(x))
}
case *uint32:
order.PutUint32(bs, *v)
case uint32:
order.PutUint32(bs, v)
case []uint32:
for i, x := range v {
order.PutUint32(bs[4*i:], x)
}
case *int64:
order.PutUint64(bs, uint64(*v))
case int64:
order.PutUint64(bs, uint64(v))
case []int64:
for i, x := range v {
order.PutUint64(bs[8*i:], uint64(x))
}
case *uint64:
order.PutUint64(bs, *v)
case uint64:
order.PutUint64(bs, v)
case []uint64:
for i, x := range v {
order.PutUint64(bs[8*i:], x)
}
case *float32:
order.PutUint32(bs, math.Float32bits(*v))
case float32:
order.PutUint32(bs, math.Float32bits(v))
case []float32:
for i, x := range v {
order.PutUint32(bs[4*i:], math.Float32bits(x))
}
case *float64:
order.PutUint64(bs, math.Float64bits(*v))
case float64:
order.PutUint64(bs, math.Float64bits(v))
case []float64:
for i, x := range v {
order.PutUint64(bs[8*i:], math.Float64bits(x))
}
}
_, err := w.Write(bs)
return err
}
// Fallback to reflect-based encoding.
v := reflect.Indirect(reflect.ValueOf(data))
size := dataSize(v)
if size < 0 {
return errors.New("binary.Write: invalid type " + reflect.TypeOf(data).String())
}
buf := make([]byte, size)
e := &encoder{order: order, buf: buf}
e.value(v)
_, err := w.Write(buf)
return err
}
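// writeSketch is a minimal, hypothetical counterpart to the Read sketch
// above; the name is illustrative. Both calls hit the fast path.
func writeSketch(w io.Writer) error {
	if err := Write(w, BigEndian, uint32(0xCAFEBABE)); err != nil {
		return err
	}
	return Write(w, BigEndian, []uint16{1, 2, 3}) // six bytes, big-endian
}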
// Size returns how many bytes Write would generate to encode the value v, which
// must be a fixed-size value or a slice of fixed-size values, or a pointer to such data.
// If v is neither of these, Size returns -1.
func Size(v any) int {
return dataSize(reflect.Indirect(reflect.ValueOf(v)))
}
var structSize sync.Map // map[reflect.Type]int
// dataSize returns the number of bytes the actual data represented by v occupies in memory.
// For compound structures, it sums the sizes of the elements. Thus, for instance, for a slice
// it returns the length of the slice times the element size and does not count the memory
// occupied by the header. If the type of v is not acceptable, dataSize returns -1.
func dataSize(v reflect.Value) int {
switch v.Kind() {
case reflect.Slice:
if s := sizeof(v.Type().Elem()); s >= 0 {
return s * v.Len()
}
return -1
case reflect.Struct:
t := v.Type()
if size, ok := structSize.Load(t); ok {
return size.(int)
}
size := sizeof(t)
structSize.Store(t, size)
return size
default:
return sizeof(v.Type())
}
}
// sizeof returns the size >= 0 of variables for the given type or -1 if the type is not acceptable.
func sizeof(t reflect.Type) int {
switch t.Kind() {
case reflect.Array:
if s := sizeof(t.Elem()); s >= 0 {
return s * t.Len()
}
case reflect.Struct:
sum := 0
for i, n := 0, t.NumField(); i < n; i++ {
s := sizeof(t.Field(i).Type)
if s < 0 {
return -1
}
sum += s
}
return sum
case reflect.Bool,
reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
return int(t.Size())
}
return -1
}
type coder struct {
order ByteOrder
buf []byte
offset int
}
type decoder coder
type encoder coder
func (d *decoder) bool() bool {
x := d.buf[d.offset]
d.offset++
return x != 0
}
func (e *encoder) bool(x bool) {
if x {
e.buf[e.offset] = 1
} else {
e.buf[e.offset] = 0
}
e.offset++
}
func (d *decoder) uint8() uint8 {
x := d.buf[d.offset]
d.offset++
return x
}
func (e *encoder) uint8(x uint8) {
e.buf[e.offset] = x
e.offset++
}
func (d *decoder) uint16() uint16 {
x := d.order.Uint16(d.buf[d.offset : d.offset+2])
d.offset += 2
return x
}
func (e *encoder) uint16(x uint16) {
e.order.PutUint16(e.buf[e.offset:e.offset+2], x)
e.offset += 2
}
func (d *decoder) uint32() uint32 {
x := d.order.Uint32(d.buf[d.offset : d.offset+4])
d.offset += 4
return x
}
func (e *encoder) uint32(x uint32) {
e.order.PutUint32(e.buf[e.offset:e.offset+4], x)
e.offset += 4
}
func (d *decoder) uint64() uint64 {
x := d.order.Uint64(d.buf[d.offset : d.offset+8])
d.offset += 8
return x
}
func (e *encoder) uint64(x uint64) {
e.order.PutUint64(e.buf[e.offset:e.offset+8], x)
e.offset += 8
}
func (d *decoder) int8() int8 { return int8(d.uint8()) }
func (e *encoder) int8(x int8) { e.uint8(uint8(x)) }
func (d *decoder) int16() int16 { return int16(d.uint16()) }
func (e *encoder) int16(x int16) { e.uint16(uint16(x)) }
func (d *decoder) int32() int32 { return int32(d.uint32()) }
func (e *encoder) int32(x int32) { e.uint32(uint32(x)) }
func (d *decoder) int64() int64 { return int64(d.uint64()) }
func (e *encoder) int64(x int64) { e.uint64(uint64(x)) }
func (d *decoder) value(v reflect.Value) {
switch v.Kind() {
case reflect.Array:
l := v.Len()
for i := 0; i < l; i++ {
d.value(v.Index(i))
}
case reflect.Struct:
t := v.Type()
l := v.NumField()
for i := 0; i < l; i++ {
// Note: Calling v.CanSet() below is an optimization.
// It would be sufficient to check the field name,
// but creating the StructField info for each field is
// costly (run "go test -bench=ReadStruct" and compare
// results when making changes to this code).
if v := v.Field(i); v.CanSet() || t.Field(i).Name != "_" {
d.value(v)
} else {
d.skip(v)
}
}
case reflect.Slice:
l := v.Len()
for i := 0; i < l; i++ {
d.value(v.Index(i))
}
case reflect.Bool:
v.SetBool(d.bool())
case reflect.Int8:
v.SetInt(int64(d.int8()))
case reflect.Int16:
v.SetInt(int64(d.int16()))
case reflect.Int32:
v.SetInt(int64(d.int32()))
case reflect.Int64:
v.SetInt(d.int64())
case reflect.Uint8:
v.SetUint(uint64(d.uint8()))
case reflect.Uint16:
v.SetUint(uint64(d.uint16()))
case reflect.Uint32:
v.SetUint(uint64(d.uint32()))
case reflect.Uint64:
v.SetUint(d.uint64())
case reflect.Float32:
v.SetFloat(float64(math.Float32frombits(d.uint32())))
case reflect.Float64:
v.SetFloat(math.Float64frombits(d.uint64()))
case reflect.Complex64:
v.SetComplex(complex(
float64(math.Float32frombits(d.uint32())),
float64(math.Float32frombits(d.uint32())),
))
case reflect.Complex128:
v.SetComplex(complex(
math.Float64frombits(d.uint64()),
math.Float64frombits(d.uint64()),
))
}
}
func (e *encoder) value(v reflect.Value) {
switch v.Kind() {
case reflect.Array:
l := v.Len()
for i := 0; i < l; i++ {
e.value(v.Index(i))
}
case reflect.Struct:
t := v.Type()
l := v.NumField()
for i := 0; i < l; i++ {
// see comment for corresponding code in decoder.value()
if v := v.Field(i); v.CanSet() || t.Field(i).Name != "_" {
e.value(v)
} else {
e.skip(v)
}
}
case reflect.Slice:
l := v.Len()
for i := 0; i < l; i++ {
e.value(v.Index(i))
}
case reflect.Bool:
e.bool(v.Bool())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
switch v.Type().Kind() {
case reflect.Int8:
e.int8(int8(v.Int()))
case reflect.Int16:
e.int16(int16(v.Int()))
case reflect.Int32:
e.int32(int32(v.Int()))
case reflect.Int64:
e.int64(v.Int())
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
switch v.Type().Kind() {
case reflect.Uint8:
e.uint8(uint8(v.Uint()))
case reflect.Uint16:
e.uint16(uint16(v.Uint()))
case reflect.Uint32:
e.uint32(uint32(v.Uint()))
case reflect.Uint64:
e.uint64(v.Uint())
}
case reflect.Float32, reflect.Float64:
switch v.Type().Kind() {
case reflect.Float32:
e.uint32(math.Float32bits(float32(v.Float())))
case reflect.Float64:
e.uint64(math.Float64bits(v.Float()))
}
case reflect.Complex64, reflect.Complex128:
switch v.Type().Kind() {
case reflect.Complex64:
x := v.Complex()
e.uint32(math.Float32bits(float32(real(x))))
e.uint32(math.Float32bits(float32(imag(x))))
case reflect.Complex128:
x := v.Complex()
e.uint64(math.Float64bits(real(x)))
e.uint64(math.Float64bits(imag(x)))
}
}
}
func (d *decoder) skip(v reflect.Value) {
d.offset += dataSize(v)
}
func (e *encoder) skip(v reflect.Value) {
n := dataSize(v)
zero := e.buf[e.offset : e.offset+n]
for i := range zero {
zero[i] = 0
}
e.offset += n
}
// intDataSize returns the size of the data required to represent the data when encoded.
// It returns zero if the type cannot be implemented by the fast path in Read or Write.
func intDataSize(data any) int {
switch data := data.(type) {
case bool, int8, uint8, *bool, *int8, *uint8:
return 1
case []bool:
return len(data)
case []int8:
return len(data)
case []uint8:
return len(data)
case int16, uint16, *int16, *uint16:
return 2
case []int16:
return 2 * len(data)
case []uint16:
return 2 * len(data)
case int32, uint32, *int32, *uint32:
return 4
case []int32:
return 4 * len(data)
case []uint32:
return 4 * len(data)
case int64, uint64, *int64, *uint64:
return 8
case []int64:
return 8 * len(data)
case []uint64:
return 8 * len(data)
case float32, *float32:
return 4
case float64, *float64:
return 8
case []float32:
return 4 * len(data)
case []float64:
return 8 * len(data)
}
return 0
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package binary
// This file implements "varint" encoding of 64-bit integers.
// The encoding is:
// - unsigned integers are serialized 7 bits at a time, starting with the
// least significant bits
// - the most significant bit (msb) in each output byte indicates if there
// is a continuation byte (msb = 1)
// - signed integers are mapped to unsigned integers using "zig-zag"
// encoding: Positive values x are written as 2*x + 0, negative values
// are written as 2*(^x) + 1; that is, negative numbers are complemented
// and whether to complement is encoded in bit 0.
//
// Design note:
// At most 10 bytes are needed for 64-bit values. The encoding could
// be more dense: a full 64-bit value needs an extra byte just to hold bit 63.
// Instead, the msb of the previous byte could be used to hold bit 63 since we
// know there can't be more than 64 bits. This is a trivial improvement and
// would reduce the maximum encoding length to 9 bytes. However, it breaks the
// invariant that the msb is always the "continuation bit" and thus makes the
// format incompatible with a varint encoding for larger numbers (say 128-bit).
import (
"errors"
"io"
)
// MaxVarintLenN is the maximum length of a varint-encoded N-bit integer.
const (
MaxVarintLen16 = 3
MaxVarintLen32 = 5
MaxVarintLen64 = 10
)
// AppendUvarint appends the varint-encoded form of x,
// as generated by PutUvarint, to buf and returns the extended buffer.
func AppendUvarint(buf []byte, x uint64) []byte {
for x >= 0x80 {
buf = append(buf, byte(x)|0x80)
x >>= 7
}
return append(buf, byte(x))
}
// PutUvarint encodes a uint64 into buf and returns the number of bytes written.
// If the buffer is too small, PutUvarint will panic.
func PutUvarint(buf []byte, x uint64) int {
i := 0
for x >= 0x80 {
buf[i] = byte(x) | 0x80
x >>= 7
i++
}
buf[i] = byte(x)
return i + 1
}
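// uvarintSketch is a minimal, hypothetical sketch; the name is illustrative.
// 300 = 0b1_0010_1100 splits into the 7-bit groups 0101100 and 0000010
// (least significant first), so it encodes as 0xAC (0x2C with the
// continuation bit set) followed by 0x02.
func uvarintSketch() {
	buf := make([]byte, MaxVarintLen64)
	n := PutUvarint(buf, 300) // n == 2; buf[:2] == []byte{0xAC, 0x02}
	x, m := Uvarint(buf[:n])  // x == 300, m == 2
	_, _ = x, m
}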
// Uvarint decodes a uint64 from buf and returns that value and the
// number of bytes read (> 0). If an error occurred, the value is 0
// and the number of bytes n is <= 0 meaning:
//
// n == 0: buf too small
// n < 0: value larger than 64 bits (overflow)
// and -n is the number of bytes read
func Uvarint(buf []byte) (uint64, int) {
var x uint64
var s uint
for i, b := range buf {
if i == MaxVarintLen64 {
// Catch byte reads past MaxVarintLen64.
// See issue https://golang.org/issues/41185
return 0, -(i + 1) // overflow
}
if b < 0x80 {
if i == MaxVarintLen64-1 && b > 1 {
return 0, -(i + 1) // overflow
}
return x | uint64(b)<<s, i + 1
}
x |= uint64(b&0x7f) << s
s += 7
}
return 0, 0
}
// AppendVarint appends the varint-encoded form of x,
// as generated by PutVarint, to buf and returns the extended buffer.
func AppendVarint(buf []byte, x int64) []byte {
ux := uint64(x) << 1
if x < 0 {
ux = ^ux
}
return AppendUvarint(buf, ux)
}
// PutVarint encodes an int64 into buf and returns the number of bytes written.
// If the buffer is too small, PutVarint will panic.
func PutVarint(buf []byte, x int64) int {
ux := uint64(x) << 1
if x < 0 {
ux = ^ux
}
return PutUvarint(buf, ux)
}
// Varint decodes an int64 from buf and returns that value and the
// number of bytes read (> 0). If an error occurred, the value is 0
// and the number of bytes n is <= 0 with the following meaning:
//
// n == 0: buf too small
// n < 0: value larger than 64 bits (overflow)
// and -n is the number of bytes read
func Varint(buf []byte) (int64, int) {
ux, n := Uvarint(buf) // ok to continue in presence of error
x := int64(ux >> 1)
if ux&1 != 0 {
x = ^x
}
return x, n
}
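// varintSketch is a minimal, hypothetical sketch of the zig-zag mapping;
// the name is illustrative. 0, -1, 1, -2, 2, ... map to 0, 1, 2, 3, 4, ...,
// so small magnitudes of either sign stay short: for x = -3,
// ux = ^(uint64(-3) << 1) = 5, a single encoded byte.
func varintSketch() {
	buf := make([]byte, MaxVarintLen64)
	n := PutVarint(buf, -3) // n == 1; buf[0] == 0x05
	x, _ := Varint(buf[:n]) // x == -3
	_ = x
}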
var errOverflow = errors.New("binary: varint overflows a 64-bit integer")
// ReadUvarint reads an encoded unsigned integer from r and returns it as a uint64.
// The error is EOF only if no bytes were read.
// If an EOF happens after reading some but not all the bytes,
// ReadUvarint returns io.ErrUnexpectedEOF.
func ReadUvarint(r io.ByteReader) (uint64, error) {
var x uint64
var s uint
for i := 0; i < MaxVarintLen64; i++ {
b, err := r.ReadByte()
if err != nil {
if i > 0 && err == io.EOF {
err = io.ErrUnexpectedEOF
}
return x, err
}
if b < 0x80 {
if i == MaxVarintLen64-1 && b > 1 {
return x, errOverflow
}
return x | uint64(b)<<s, nil
}
x |= uint64(b&0x7f) << s
s += 7
}
return x, errOverflow
}
// ReadVarint reads an encoded signed integer from r and returns it as an int64.
// The error is EOF only if no bytes were read.
// If an EOF happens after reading some but not all the bytes,
// ReadVarint returns io.ErrUnexpectedEOF.
func ReadVarint(r io.ByteReader) (int64, error) {
ux, err := ReadUvarint(r) // ok to continue in presence of error
x := int64(ux >> 1)
if ux&1 != 0 {
x = ^x
}
return x, err
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package csv reads and writes comma-separated values (CSV) files.
// There are many kinds of CSV files; this package supports the format
// described in RFC 4180.
//
// A csv file contains zero or more records of one or more fields per record.
// Each record is separated by the newline character. The final record may
// optionally be followed by a newline character.
//
// field1,field2,field3
//
// White space is considered part of a field.
//
// Carriage returns before newline characters are silently removed.
//
// Blank lines are ignored. A line with only whitespace characters (excluding
// the ending newline character) is not considered a blank line.
//
// Fields which start and stop with the quote character " are called
// quoted-fields. The beginning and ending quote are not part of the
// field.
//
// The source:
//
// normal string,"quoted-field"
//
// results in the fields
//
// {`normal string`, `quoted-field`}
//
// Within a quoted-field a quote character followed by a second quote
// character is considered a single quote.
//
// "the ""word"" is true","a ""quoted-field"""
//
// results in
//
// {`the "word" is true`, `a "quoted-field"`}
//
// Newlines and commas may be included in a quoted-field
//
// "Multi-line
// field","comma is ,"
//
// results in
//
// {`Multi-line
// field`, `comma is ,`}
package csv
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"unicode"
"unicode/utf8"
)
// A ParseError is returned for parsing errors.
// Line and column numbers are 1-indexed.
type ParseError struct {
StartLine int // Line where the record starts
Line int // Line where the error occurred
Column int // Column (1-based byte index) where the error occurred
Err error // The actual error
}
func (e *ParseError) Error() string {
if e.Err == ErrFieldCount {
return fmt.Sprintf("record on line %d: %v", e.Line, e.Err)
}
if e.StartLine != e.Line {
return fmt.Sprintf("record on line %d; parse error on line %d, column %d: %v", e.StartLine, e.Line, e.Column, e.Err)
}
return fmt.Sprintf("parse error on line %d, column %d: %v", e.Line, e.Column, e.Err)
}
func (e *ParseError) Unwrap() error { return e.Err }
// These are the errors that can be returned in ParseError.Err.
var (
ErrBareQuote = errors.New("bare \" in non-quoted-field")
ErrQuote = errors.New("extraneous or missing \" in quoted-field")
ErrFieldCount = errors.New("wrong number of fields")
// Deprecated: ErrTrailingComma is no longer used.
ErrTrailingComma = errors.New("extra delimiter at end of line")
)
var errInvalidDelim = errors.New("csv: invalid field or comment delimiter")
func validDelim(r rune) bool {
return r != 0 && r != '"' && r != '\r' && r != '\n' && utf8.ValidRune(r) && r != utf8.RuneError
}
// A Reader reads records from a CSV-encoded file.
//
// As returned by NewReader, a Reader expects input conforming to RFC 4180.
// The exported fields can be changed to customize the details before the
// first call to Read or ReadAll.
//
// The Reader converts all \r\n sequences in its input to plain \n,
// including in multiline field values, so that the returned data does
// not depend on which line-ending convention an input file uses.
type Reader struct {
// Comma is the field delimiter.
// It is set to comma (',') by NewReader.
// Comma must be a valid rune and must not be \r, \n,
// or the Unicode replacement character (0xFFFD).
Comma rune
// Comment, if not 0, is the comment character. Lines beginning with the
// Comment character without preceding whitespace are ignored.
// With leading whitespace the Comment character becomes part of the
// field, even if TrimLeadingSpace is true.
// Comment must be a valid rune and must not be \r, \n,
// or the Unicode replacement character (0xFFFD).
// It must also not be equal to Comma.
Comment rune
// FieldsPerRecord is the number of expected fields per record.
// If FieldsPerRecord is positive, Read requires each record to
// have the given number of fields. If FieldsPerRecord is 0, Read sets it to
// the number of fields in the first record, so that future records must
// have the same field count. If FieldsPerRecord is negative, no check is
// made and records may have a variable number of fields.
FieldsPerRecord int
// If LazyQuotes is true, a quote may appear in an unquoted field and a
// non-doubled quote may appear in a quoted field.
LazyQuotes bool
// If TrimLeadingSpace is true, leading white space in a field is ignored.
// This is done even if the field delimiter, Comma, is white space.
TrimLeadingSpace bool
// ReuseRecord controls whether calls to Read may return a slice sharing
// the backing array of the previous call's returned slice for performance.
// By default, each call to Read returns newly allocated memory owned by the caller.
ReuseRecord bool
// Deprecated: TrailingComma is no longer used.
TrailingComma bool
r *bufio.Reader
// numLine is the current line being read in the CSV file.
numLine int
// offset is the input stream byte offset of the current reader position.
offset int64
// rawBuffer is a line buffer only used by the readLine method.
rawBuffer []byte
// recordBuffer holds the unescaped fields, one after another.
// The fields can be accessed by using the indexes in fieldIndexes.
// E.g., For the row `a,"b","c""d",e`, recordBuffer will contain `abc"de`
// and fieldIndexes will contain the indexes [1, 2, 5, 6].
recordBuffer []byte
// fieldIndexes is an index of fields inside recordBuffer.
// The i'th field ends at offset fieldIndexes[i] in recordBuffer.
fieldIndexes []int
// fieldPositions is an index of field positions for the
// last record returned by Read.
fieldPositions []position
// lastRecord is a record cache and only used when ReuseRecord == true.
lastRecord []string
}
// NewReader returns a new Reader that reads from r.
func NewReader(r io.Reader) *Reader {
return &Reader{
Comma: ',',
r: bufio.NewReader(r),
}
}
// Read reads one record (a slice of fields) from r.
// If the record has an unexpected number of fields,
// Read returns the record along with the error ErrFieldCount.
// Except for that case, Read always returns either a non-nil
// record or a non-nil error, but not both.
// If there is no data left to be read, Read returns nil, io.EOF.
// If ReuseRecord is true, the returned slice may be shared
// between multiple calls to Read.
func (r *Reader) Read() (record []string, err error) {
if r.ReuseRecord {
record, err = r.readRecord(r.lastRecord)
r.lastRecord = record
} else {
record, err = r.readRecord(nil)
}
return record, err
}
// FieldPos returns the line and column corresponding to
// the start of the field with the given index in the slice most recently
// returned by Read. Numbering of lines and columns starts at 1;
// columns are counted in bytes, not runes.
//
// If this is called with an out-of-bounds index, it panics.
func (r *Reader) FieldPos(field int) (line, column int) {
if field < 0 || field >= len(r.fieldPositions) {
panic("out of range index passed to FieldPos")
}
p := &r.fieldPositions[field]
return p.line, p.col
}
// InputOffset returns the input stream byte offset of the current reader
// position. The offset gives the location of the end of the most recently
// read row and the beginning of the next row.
func (r *Reader) InputOffset() int64 {
return r.offset
}
// position holds the position of a field in the current line.
type position struct {
line, col int
}
// ReadAll reads all the remaining records from r.
// Each record is a slice of fields.
// A successful call returns err == nil, not err == io.EOF. Because ReadAll is
// defined to read until EOF, it does not treat end of file as an error to be
// reported.
func (r *Reader) ReadAll() (records [][]string, err error) {
for {
record, err := r.readRecord(nil)
if err == io.EOF {
return records, nil
}
if err != nil {
return nil, err
}
records = append(records, record)
}
}
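// readAllSketch is a minimal, hypothetical sketch; the name is illustrative.
// The quoted comma stays inside its field, and ReadAll maps io.EOF to a
// nil error.
func readAllSketch() ([][]string, error) {
	in := []byte("name,note\nalice,\"hello, world\"\n")
	return NewReader(bytes.NewReader(in)).ReadAll()
	// -> [][]string{{"name", "note"}, {"alice", "hello, world"}}
}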
// readLine reads the next line (with the trailing endline).
// If EOF is hit without a trailing endline, it will be omitted.
// If some bytes were read, then the error is never io.EOF.
// The result is only valid until the next call to readLine.
func (r *Reader) readLine() ([]byte, error) {
line, err := r.r.ReadSlice('\n')
if err == bufio.ErrBufferFull {
r.rawBuffer = append(r.rawBuffer[:0], line...)
for err == bufio.ErrBufferFull {
line, err = r.r.ReadSlice('\n')
r.rawBuffer = append(r.rawBuffer, line...)
}
line = r.rawBuffer
}
readSize := len(line)
if readSize > 0 && err == io.EOF {
err = nil
// For backwards compatibility, drop trailing \r before EOF.
if line[readSize-1] == '\r' {
line = line[:readSize-1]
}
}
r.numLine++
r.offset += int64(readSize)
// Normalize \r\n to \n on all input lines.
if n := len(line); n >= 2 && line[n-2] == '\r' && line[n-1] == '\n' {
line[n-2] = '\n'
line = line[:n-1]
}
return line, err
}
// lengthNL reports the number of bytes for the trailing \n.
func lengthNL(b []byte) int {
if len(b) > 0 && b[len(b)-1] == '\n' {
return 1
}
return 0
}
// nextRune returns the next rune in b or utf8.RuneError.
func nextRune(b []byte) rune {
r, _ := utf8.DecodeRune(b)
return r
}
func (r *Reader) readRecord(dst []string) ([]string, error) {
if r.Comma == r.Comment || !validDelim(r.Comma) || (r.Comment != 0 && !validDelim(r.Comment)) {
return nil, errInvalidDelim
}
// Read line (automatically skipping past empty lines and any comments).
var line []byte
var errRead error
for errRead == nil {
line, errRead = r.readLine()
if r.Comment != 0 && nextRune(line) == r.Comment {
line = nil
continue // Skip comment lines
}
if errRead == nil && len(line) == lengthNL(line) {
line = nil
continue // Skip empty lines
}
break
}
if errRead == io.EOF {
return nil, errRead
}
// Parse each field in the record.
var err error
const quoteLen = len(`"`)
commaLen := utf8.RuneLen(r.Comma)
recLine := r.numLine // Starting line for record
r.recordBuffer = r.recordBuffer[:0]
r.fieldIndexes = r.fieldIndexes[:0]
r.fieldPositions = r.fieldPositions[:0]
pos := position{line: r.numLine, col: 1}
parseField:
for {
if r.TrimLeadingSpace {
i := bytes.IndexFunc(line, func(r rune) bool {
return !unicode.IsSpace(r)
})
if i < 0 {
i = len(line)
pos.col -= lengthNL(line)
}
line = line[i:]
pos.col += i
}
if len(line) == 0 || line[0] != '"' {
// Non-quoted string field
i := bytes.IndexRune(line, r.Comma)
field := line
if i >= 0 {
field = field[:i]
} else {
field = field[:len(field)-lengthNL(field)]
}
// Check to make sure a quote does not appear in field.
if !r.LazyQuotes {
if j := bytes.IndexByte(field, '"'); j >= 0 {
col := pos.col + j
err = &ParseError{StartLine: recLine, Line: r.numLine, Column: col, Err: ErrBareQuote}
break parseField
}
}
r.recordBuffer = append(r.recordBuffer, field...)
r.fieldIndexes = append(r.fieldIndexes, len(r.recordBuffer))
r.fieldPositions = append(r.fieldPositions, pos)
if i >= 0 {
line = line[i+commaLen:]
pos.col += i + commaLen
continue parseField
}
break parseField
} else {
// Quoted string field
fieldPos := pos
line = line[quoteLen:]
pos.col += quoteLen
for {
i := bytes.IndexByte(line, '"')
if i >= 0 {
// Hit next quote.
r.recordBuffer = append(r.recordBuffer, line[:i]...)
line = line[i+quoteLen:]
pos.col += i + quoteLen
switch rn := nextRune(line); {
case rn == '"':
// `""` sequence (append quote).
r.recordBuffer = append(r.recordBuffer, '"')
line = line[quoteLen:]
pos.col += quoteLen
case rn == r.Comma:
// `",` sequence (end of field).
line = line[commaLen:]
pos.col += commaLen
r.fieldIndexes = append(r.fieldIndexes, len(r.recordBuffer))
r.fieldPositions = append(r.fieldPositions, fieldPos)
continue parseField
case lengthNL(line) == len(line):
// `"\n` sequence (end of line).
r.fieldIndexes = append(r.fieldIndexes, len(r.recordBuffer))
r.fieldPositions = append(r.fieldPositions, fieldPos)
break parseField
case r.LazyQuotes:
// `"` sequence (bare quote).
r.recordBuffer = append(r.recordBuffer, '"')
default:
// `"*` sequence (invalid non-escaped quote).
err = &ParseError{StartLine: recLine, Line: r.numLine, Column: pos.col - quoteLen, Err: ErrQuote}
break parseField
}
} else if len(line) > 0 {
// Hit end of line (copy all data so far).
r.recordBuffer = append(r.recordBuffer, line...)
if errRead != nil {
break parseField
}
pos.col += len(line)
line, errRead = r.readLine()
if len(line) > 0 {
pos.line++
pos.col = 1
}
if errRead == io.EOF {
errRead = nil
}
} else {
// Abrupt end of file (EOF or error).
if !r.LazyQuotes && errRead == nil {
err = &ParseError{StartLine: recLine, Line: pos.line, Column: pos.col, Err: ErrQuote}
break parseField
}
r.fieldIndexes = append(r.fieldIndexes, len(r.recordBuffer))
r.fieldPositions = append(r.fieldPositions, fieldPos)
break parseField
}
}
}
}
if err == nil {
err = errRead
}
// Create a single string and create slices out of it.
// This pins the memory of the fields together, but allocates once.
str := string(r.recordBuffer) // Convert to string once to batch allocations
dst = dst[:0]
if cap(dst) < len(r.fieldIndexes) {
dst = make([]string, len(r.fieldIndexes))
}
dst = dst[:len(r.fieldIndexes)]
var preIdx int
for i, idx := range r.fieldIndexes {
dst[i] = str[preIdx:idx]
preIdx = idx
}
// Check or update the expected fields per record.
if r.FieldsPerRecord > 0 {
if len(dst) != r.FieldsPerRecord && err == nil {
err = &ParseError{
StartLine: recLine,
Line: recLine,
Column: 1,
Err: ErrFieldCount,
}
}
} else if r.FieldsPerRecord == 0 {
r.FieldsPerRecord = len(dst)
}
return dst, err
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package csv
import (
"bufio"
"io"
"strings"
"unicode"
"unicode/utf8"
)
// A Writer writes records using CSV encoding.
//
// As returned by NewWriter, a Writer writes records terminated by a
// newline and uses ',' as the field delimiter. The exported fields can be
// changed to customize the details before the first call to Write or WriteAll.
//
// Comma is the field delimiter.
//
// If UseCRLF is true, the Writer ends each output line with \r\n instead of \n.
//
// The writes of individual records are buffered.
// After all data has been written, the client should call the
// Flush method to guarantee all data has been forwarded to
// the underlying io.Writer. Any errors that occurred should
// be checked by calling the Error method.
type Writer struct {
Comma rune // Field delimiter (set to ',' by NewWriter)
UseCRLF bool // True to use \r\n as the line terminator
w *bufio.Writer
}
// NewWriter returns a new Writer that writes to w.
func NewWriter(w io.Writer) *Writer {
return &Writer{
Comma: ',',
w: bufio.NewWriter(w),
}
}
// Write writes a single CSV record to w along with any necessary quoting.
// A record is a slice of strings with each string being one field.
// Writes are buffered, so Flush must eventually be called to ensure
// that the record is written to the underlying io.Writer.
func (w *Writer) Write(record []string) error {
if !validDelim(w.Comma) {
return errInvalidDelim
}
for n, field := range record {
if n > 0 {
if _, err := w.w.WriteRune(w.Comma); err != nil {
return err
}
}
// If we don't have to have a quoted field then just
// write out the field and continue to the next field.
if !w.fieldNeedsQuotes(field) {
if _, err := w.w.WriteString(field); err != nil {
return err
}
continue
}
if err := w.w.WriteByte('"'); err != nil {
return err
}
for len(field) > 0 {
// Search for special characters.
i := strings.IndexAny(field, "\"\r\n")
if i < 0 {
i = len(field)
}
// Copy verbatim everything before the special character.
if _, err := w.w.WriteString(field[:i]); err != nil {
return err
}
field = field[i:]
// Encode the special character.
if len(field) > 0 {
var err error
switch field[0] {
case '"':
_, err = w.w.WriteString(`""`)
case '\r':
if !w.UseCRLF {
err = w.w.WriteByte('\r')
}
case '\n':
if w.UseCRLF {
_, err = w.w.WriteString("\r\n")
} else {
err = w.w.WriteByte('\n')
}
}
field = field[1:]
if err != nil {
return err
}
}
}
if err := w.w.WriteByte('"'); err != nil {
return err
}
}
var err error
if w.UseCRLF {
_, err = w.w.WriteString("\r\n")
} else {
err = w.w.WriteByte('\n')
}
return err
}
// Flush writes any buffered data to the underlying io.Writer.
// To check if an error occurred during the Flush, call Error.
func (w *Writer) Flush() {
w.w.Flush()
}
// Error reports any error that has occurred during a previous Write or Flush.
func (w *Writer) Error() error {
_, err := w.w.Write(nil)
return err
}
// WriteAll writes multiple CSV records to w using Write and then calls Flush,
// returning any error from the Flush.
func (w *Writer) WriteAll(records [][]string) error {
for _, record := range records {
err := w.Write(record)
if err != nil {
return err
}
}
return w.w.Flush()
}
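// writeAllSketch is a minimal, hypothetical sketch; the name is illustrative.
// The embedded comma forces quoting, and the literal quotes are doubled.
func writeAllSketch(w io.Writer) error {
	records := [][]string{
		{"name", "note"},
		{"alice", `say "hi", then wave`}, // emitted as "say ""hi"", then wave"
	}
	return NewWriter(w).WriteAll(records) // WriteAll also flushes
}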
// fieldNeedsQuotes reports whether our field must be enclosed in quotes.
// Fields with a Comma, fields with a quote or newline, and
// fields which start with a space must be enclosed in quotes.
// We used to quote empty strings, but we do not anymore (as of Go 1.4).
// The two representations should be equivalent, but Postgres distinguishes
// quoted vs non-quoted empty string during database imports, and it has
// an option to force the quoted behavior for non-quoted CSV but it has
// no option to force the non-quoted behavior for quoted CSV, making
// CSV with quoted empty strings strictly less useful.
// Not quoting the empty string also makes this package match the behavior
// of Microsoft Excel and Google Drive.
// For Postgres, quote the data terminating string `\.`.
func (w *Writer) fieldNeedsQuotes(field string) bool {
if field == "" {
return false
}
if field == `\.` {
return true
}
if w.Comma < utf8.RuneSelf {
for i := 0; i < len(field); i++ {
c := field[i]
if c == '\n' || c == '\r' || c == '"' || c == byte(w.Comma) {
return true
}
}
} else {
if strings.ContainsRune(field, w.Comma) || strings.ContainsAny(field, "\"\r\n") {
return true
}
}
r1, _ := utf8.DecodeRuneInString(field)
return unicode.IsSpace(r1)
}
// Code generated by go run decgen.go -output dec_helpers.go; DO NOT EDIT.
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gob
import (
"math"
"reflect"
)
var decArrayHelper = map[reflect.Kind]decHelper{
reflect.Bool: decBoolArray,
reflect.Complex64: decComplex64Array,
reflect.Complex128: decComplex128Array,
reflect.Float32: decFloat32Array,
reflect.Float64: decFloat64Array,
reflect.Int: decIntArray,
reflect.Int16: decInt16Array,
reflect.Int32: decInt32Array,
reflect.Int64: decInt64Array,
reflect.Int8: decInt8Array,
reflect.String: decStringArray,
reflect.Uint: decUintArray,
reflect.Uint16: decUint16Array,
reflect.Uint32: decUint32Array,
reflect.Uint64: decUint64Array,
reflect.Uintptr: decUintptrArray,
}
var decSliceHelper = map[reflect.Kind]decHelper{
reflect.Bool: decBoolSlice,
reflect.Complex64: decComplex64Slice,
reflect.Complex128: decComplex128Slice,
reflect.Float32: decFloat32Slice,
reflect.Float64: decFloat64Slice,
reflect.Int: decIntSlice,
reflect.Int16: decInt16Slice,
reflect.Int32: decInt32Slice,
reflect.Int64: decInt64Slice,
reflect.Int8: decInt8Slice,
reflect.String: decStringSlice,
reflect.Uint: decUintSlice,
reflect.Uint16: decUint16Slice,
reflect.Uint32: decUint32Slice,
reflect.Uint64: decUint64Slice,
reflect.Uintptr: decUintptrSlice,
}
func decBoolArray(state *decoderState, v reflect.Value, length int, ovfl error) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return decBoolSlice(state, v.Slice(0, v.Len()), length, ovfl)
}
func decBoolSlice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
slice, ok := v.Interface().([]bool)
if !ok {
// It is kind bool but not type bool. TODO: We can handle this unsafely.
return false
}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding bool array or slice: length exceeds input size (%d elements)", length)
}
if i >= len(slice) {
// This is a slice that we only partially allocated.
growSlice(v, &slice, length)
}
slice[i] = state.decodeUint() != 0
}
return true
}
func decComplex64Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return decComplex64Slice(state, v.Slice(0, v.Len()), length, ovfl)
}
func decComplex64Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
slice, ok := v.Interface().([]complex64)
if !ok {
// It is kind complex64 but not type complex64. TODO: We can handle this unsafely.
return false
}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding complex64 array or slice: length exceeds input size (%d elements)", length)
}
if i >= len(slice) {
// This is a slice that we only partially allocated.
growSlice(v, &slice, length)
}
real := float32FromBits(state.decodeUint(), ovfl)
imag := float32FromBits(state.decodeUint(), ovfl)
slice[i] = complex(float32(real), float32(imag))
}
return true
}
func decComplex128Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return decComplex128Slice(state, v.Slice(0, v.Len()), length, ovfl)
}
func decComplex128Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
slice, ok := v.Interface().([]complex128)
if !ok {
// It is kind complex128 but not type complex128. TODO: We can handle this unsafely.
return false
}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding complex128 array or slice: length exceeds input size (%d elements)", length)
}
if i >= len(slice) {
// This is a slice that we only partially allocated.
growSlice(v, &slice, length)
}
real := float64FromBits(state.decodeUint())
imag := float64FromBits(state.decodeUint())
slice[i] = complex(real, imag)
}
return true
}
func decFloat32Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return decFloat32Slice(state, v.Slice(0, v.Len()), length, ovfl)
}
func decFloat32Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
slice, ok := v.Interface().([]float32)
if !ok {
// It is kind float32 but not type float32. TODO: We can handle this unsafely.
return false
}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding float32 array or slice: length exceeds input size (%d elements)", length)
}
if i >= len(slice) {
// This is a slice that we only partially allocated.
growSlice(v, &slice, length)
}
slice[i] = float32(float32FromBits(state.decodeUint(), ovfl))
}
return true
}
func decFloat64Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return decFloat64Slice(state, v.Slice(0, v.Len()), length, ovfl)
}
func decFloat64Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
slice, ok := v.Interface().([]float64)
if !ok {
// It is kind float64 but not type float64. TODO: We can handle this unsafely.
return false
}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding float64 array or slice: length exceeds input size (%d elements)", length)
}
if i >= len(slice) {
// This is a slice that we only partially allocated.
growSlice(v, &slice, length)
}
slice[i] = float64FromBits(state.decodeUint())
}
return true
}
func decIntArray(state *decoderState, v reflect.Value, length int, ovfl error) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return decIntSlice(state, v.Slice(0, v.Len()), length, ovfl)
}
func decIntSlice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
slice, ok := v.Interface().([]int)
if !ok {
// It is kind int but not type int. TODO: We can handle this unsafely.
return false
}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding int array or slice: length exceeds input size (%d elements)", length)
}
if i >= len(slice) {
// This is a slice that we only partially allocated.
growSlice(v, &slice, length)
}
x := state.decodeInt()
// MinInt and MaxInt
if x < ^int64(^uint(0)>>1) || int64(^uint(0)>>1) < x {
error_(ovfl)
}
slice[i] = int(x)
}
return true
}
func decInt16Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return decInt16Slice(state, v.Slice(0, v.Len()), length, ovfl)
}
func decInt16Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
slice, ok := v.Interface().([]int16)
if !ok {
// It is kind int16 but not type int16. TODO: We can handle this unsafely.
return false
}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding int16 array or slice: length exceeds input size (%d elements)", length)
}
if i >= len(slice) {
// This is a slice that we only partially allocated.
growSlice(v, &slice, length)
}
x := state.decodeInt()
if x < math.MinInt16 || math.MaxInt16 < x {
error_(ovfl)
}
slice[i] = int16(x)
}
return true
}
func decInt32Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return decInt32Slice(state, v.Slice(0, v.Len()), length, ovfl)
}
func decInt32Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
slice, ok := v.Interface().([]int32)
if !ok {
// It is kind int32 but not type int32. TODO: We can handle this unsafely.
return false
}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding int32 array or slice: length exceeds input size (%d elements)", length)
}
if i >= len(slice) {
// This is a slice that we only partially allocated.
growSlice(v, &slice, length)
}
x := state.decodeInt()
if x < math.MinInt32 || math.MaxInt32 < x {
error_(ovfl)
}
slice[i] = int32(x)
}
return true
}
func decInt64Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return decInt64Slice(state, v.Slice(0, v.Len()), length, ovfl)
}
func decInt64Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
slice, ok := v.Interface().([]int64)
if !ok {
// It is kind int64 but not type int64. TODO: We can handle this unsafely.
return false
}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding int64 array or slice: length exceeds input size (%d elements)", length)
}
if i >= len(slice) {
// This is a slice that we only partially allocated.
growSlice(v, &slice, length)
}
slice[i] = state.decodeInt()
}
return true
}
func decInt8Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return decInt8Slice(state, v.Slice(0, v.Len()), length, ovfl)
}
func decInt8Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
slice, ok := v.Interface().([]int8)
if !ok {
// It is kind int8 but not type int8. TODO: We can handle this unsafely.
return false
}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding int8 array or slice: length exceeds input size (%d elements)", length)
}
if i >= len(slice) {
// This is a slice that we only partially allocated.
growSlice(v, &slice, length)
}
x := state.decodeInt()
if x < math.MinInt8 || math.MaxInt8 < x {
error_(ovfl)
}
slice[i] = int8(x)
}
return true
}
func decStringArray(state *decoderState, v reflect.Value, length int, ovfl error) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return decStringSlice(state, v.Slice(0, v.Len()), length, ovfl)
}
func decStringSlice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
slice, ok := v.Interface().([]string)
if !ok {
// It is kind string but not type string. TODO: We can handle this unsafely.
return false
}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding string array or slice: length exceeds input size (%d elements)", length)
}
if i >= len(slice) {
// This is a slice that we only partially allocated.
growSlice(v, &slice, length)
}
u := state.decodeUint()
n := int(u)
if n < 0 || uint64(n) != u || n > state.b.Len() {
errorf("length of string exceeds input size (%d bytes)", u)
}
// Read the data.
data := state.b.Bytes()
if len(data) < n {
errorf("invalid string length %d: exceeds input size %d", n, len(data))
}
slice[i] = string(data[:n])
state.b.Drop(n)
}
return true
}
func decUintArray(state *decoderState, v reflect.Value, length int, ovfl error) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return decUintSlice(state, v.Slice(0, v.Len()), length, ovfl)
}
func decUintSlice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
slice, ok := v.Interface().([]uint)
if !ok {
// It is kind uint but not type uint. TODO: We can handle this unsafely.
return false
}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding uint array or slice: length exceeds input size (%d elements)", length)
}
if i >= len(slice) {
// This is a slice that we only partially allocated.
growSlice(v, &slice, length)
}
x := state.decodeUint()
// ^uint(0) is the platform's MaxUint, so this check only fires on 32-bit systems.
if x > uint64(^uint(0)) {
error_(ovfl)
}
slice[i] = uint(x)
}
return true
}
func decUint16Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return decUint16Slice(state, v.Slice(0, v.Len()), length, ovfl)
}
func decUint16Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
slice, ok := v.Interface().([]uint16)
if !ok {
// It is kind uint16 but not type uint16. TODO: We can handle this unsafely.
return false
}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding uint16 array or slice: length exceeds input size (%d elements)", length)
}
if i >= len(slice) {
// This is a slice that we only partially allocated.
growSlice(v, &slice, length)
}
x := state.decodeUint()
if math.MaxUint16 < x {
error_(ovfl)
}
slice[i] = uint16(x)
}
return true
}
func decUint32Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return decUint32Slice(state, v.Slice(0, v.Len()), length, ovfl)
}
func decUint32Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
slice, ok := v.Interface().([]uint32)
if !ok {
// It is kind uint32 but not type uint32. TODO: We can handle this unsafely.
return false
}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding uint32 array or slice: length exceeds input size (%d elements)", length)
}
if i >= len(slice) {
// This is a slice that we only partially allocated.
growSlice(v, &slice, length)
}
x := state.decodeUint()
if math.MaxUint32 < x {
error_(ovfl)
}
slice[i] = uint32(x)
}
return true
}
func decUint64Array(state *decoderState, v reflect.Value, length int, ovfl error) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return decUint64Slice(state, v.Slice(0, v.Len()), length, ovfl)
}
func decUint64Slice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
slice, ok := v.Interface().([]uint64)
if !ok {
// It is kind uint64 but not type uint64. TODO: We can handle this unsafely.
return false
}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding uint64 array or slice: length exceeds input size (%d elements)", length)
}
if i >= len(slice) {
// This is a slice that we only partially allocated.
growSlice(v, &slice, length)
}
slice[i] = state.decodeUint()
}
return true
}
func decUintptrArray(state *decoderState, v reflect.Value, length int, ovfl error) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return decUintptrSlice(state, v.Slice(0, v.Len()), length, ovfl)
}
func decUintptrSlice(state *decoderState, v reflect.Value, length int, ovfl error) bool {
slice, ok := v.Interface().([]uintptr)
if !ok {
// It is kind uintptr but not type uintptr. TODO: We can handle this unsafely.
return false
}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding uintptr array or slice: length exceeds input size (%d elements)", length)
}
if i >= len(slice) {
// This is a slice that we only partially allocated.
growSlice(v, &slice, length)
}
x := state.decodeUint()
if uint64(^uintptr(0)) < x {
error_(ovfl)
}
slice[i] = uintptr(x)
}
return true
}
// growSlice is called for a slice that we only partially allocated,
// to grow it up to length.
func growSlice[E any](v reflect.Value, ps *[]E, length int) {
var zero E
s := *ps
s = append(s, zero)
cp := cap(s)
if cp > length {
cp = length
}
s = s[:cp]
v.Set(reflect.ValueOf(s))
*ps = s
}
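// growSliceExample is an illustrative sketch (not used by the package) of how
// growSlice extends a partially allocated slice toward a target length: a
// single append forces growth, then the result is clipped to at most length.
func growSliceExample() []int {
s := make([]int, 0, 2)
v := reflect.ValueOf(&s).Elem() // addressable so growSlice can Set it
growSlice(v, &s, 5)
return s // len(s) is now min(capacity after append, 5)
}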
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:generate go run decgen.go -output dec_helpers.go
package gob
import (
"encoding"
"errors"
"internal/saferio"
"io"
"math"
"math/bits"
"reflect"
)
var (
errBadUint = errors.New("gob: encoded unsigned integer out of range")
errBadType = errors.New("gob: unknown type id or corrupted data")
errRange = errors.New("gob: bad data: field numbers out of bounds")
)
type decHelper func(state *decoderState, v reflect.Value, length int, ovfl error) bool
// decoderState is the execution state of an instance of the decoder. A new state
// is created for nested objects.
type decoderState struct {
dec *Decoder
// The buffer is stored with an extra indirection because it may be replaced
// if we load a type during decode (when reading an interface value).
b *decBuffer
fieldnum int // the last field number read.
next *decoderState // for free list
}
// decBuffer is an extremely simple, fast implementation of a read-only byte buffer.
// It is initialized by calling SetBytes with the data to be decoded.
type decBuffer struct {
data []byte
offset int // Read offset.
}
func (d *decBuffer) Read(p []byte) (int, error) {
n := copy(p, d.data[d.offset:])
if n == 0 && len(p) != 0 {
return 0, io.EOF
}
d.offset += n
return n, nil
}
func (d *decBuffer) Drop(n int) {
if n > d.Len() {
panic("drop")
}
d.offset += n
}
func (d *decBuffer) ReadByte() (byte, error) {
if d.offset >= len(d.data) {
return 0, io.EOF
}
c := d.data[d.offset]
d.offset++
return c, nil
}
func (d *decBuffer) Len() int {
return len(d.data) - d.offset
}
func (d *decBuffer) Bytes() []byte {
return d.data[d.offset:]
}
// SetBytes sets the buffer to the bytes, discarding any existing data.
func (d *decBuffer) SetBytes(data []byte) {
d.data = data
d.offset = 0
}
func (d *decBuffer) Reset() {
d.data = d.data[0:0]
d.offset = 0
}
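// decBufferExample is an illustrative sketch (not used by the package) of the
// read/drop discipline used throughout the decoder: SetBytes installs a
// message, ReadByte and Drop consume it, and Len reports what remains.
func decBufferExample() int {
var d decBuffer
d.SetBytes([]byte{0x01, 0x02, 0x03})
d.ReadByte() // consumes 0x01
d.Drop(1)    // skips 0x02
return d.Len() // 1: only 0x03 remains
}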
// We pass the bytes.Buffer separately for easier testing of the infrastructure
// without requiring a full Decoder.
func (dec *Decoder) newDecoderState(buf *decBuffer) *decoderState {
d := dec.freeList
if d == nil {
d = new(decoderState)
d.dec = dec
} else {
dec.freeList = d.next
}
d.b = buf
return d
}
func (dec *Decoder) freeDecoderState(d *decoderState) {
d.next = dec.freeList
dec.freeList = d
}
func overflow(name string) error {
return errors.New(`value for "` + name + `" out of range`)
}
// decodeUintReader reads an encoded unsigned integer from an io.Reader.
// Used only by the Decoder to read the message length.
func decodeUintReader(r io.Reader, buf []byte) (x uint64, width int, err error) {
width = 1
n, err := io.ReadFull(r, buf[0:width])
if n == 0 {
return
}
b := buf[0]
if b <= 0x7f {
return uint64(b), width, nil
}
n = -int(int8(b))
if n > uint64Size {
err = errBadUint
return
}
width, err = io.ReadFull(r, buf[0:n])
if err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return
}
// Could check that the high byte is zero but it's not worth it.
for _, b := range buf[0:width] {
x = x<<8 | uint64(b)
}
width++ // +1 for length byte
return
}
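// decodeUintReaderExample is an illustrative sketch (not used by the package)
// of the wire format handled above: values <= 0x7f are a single byte, while
// the uint 256 is sent as the count byte 0xFE (-2 as an int8, meaning two
// payload bytes follow) plus the big-endian payload 0x01 0x00. decBuffer
// doubles as the io.Reader here.
func decodeUintReaderExample() (uint64, int, error) {
var d decBuffer
d.SetBytes([]byte{0xFE, 0x01, 0x00})
buf := make([]byte, 9) // counts may need up to 9 bytes
return decodeUintReader(&d, buf) // 256, 3 (payload plus length byte), nil
}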
// decodeUint reads an encoded unsigned integer from state.b.
// Does not check for overflow.
func (state *decoderState) decodeUint() (x uint64) {
b, err := state.b.ReadByte()
if err != nil {
error_(err)
}
if b <= 0x7f {
return uint64(b)
}
n := -int(int8(b))
if n > uint64Size {
error_(errBadUint)
}
buf := state.b.Bytes()
if len(buf) < n {
errorf("invalid uint data length %d: exceeds input size %d", n, len(buf))
}
// Don't need to check error; it's safe to loop regardless.
// Could check that the high byte is zero but it's not worth it.
for _, b := range buf[0:n] {
x = x<<8 | uint64(b)
}
state.b.Drop(n)
return x
}
// decodeInt reads an encoded signed integer from state.b.
// Does not check for overflow.
func (state *decoderState) decodeInt() int64 {
x := state.decodeUint()
if x&1 != 0 {
return ^int64(x >> 1)
}
return int64(x >> 1)
}
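// decodeIntExample is an illustrative sketch (not used by the package) of the
// sign encoding inverted above: bit 0 carries the sign, so even u decodes to
// int64(u>>1) and odd u decodes to ^int64(u>>1). The single byte 0x05 thus
// decodes to ^2 = -3.
func decodeIntExample() int64 {
var d decBuffer
d.SetBytes([]byte{0x05})
state := &decoderState{b: &d}
return state.decodeInt() // -3
}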
// getLength decodes the next uint and makes sure it is a possible
// size for a data item that follows, which means it must fit in a
// non-negative int and fit in the buffer.
func (state *decoderState) getLength() (int, bool) {
n := int(state.decodeUint())
if n < 0 || state.b.Len() < n || tooBig <= n {
return 0, false
}
return n, true
}
// decOp is the signature of a decoding operator for a given type.
type decOp func(i *decInstr, state *decoderState, v reflect.Value)
// The 'instructions' of the decoding machine
type decInstr struct {
op decOp
field int // field number of the wire type
index []int // field access indices for destination type
ovfl error // error message for overflow/underflow (for arrays, of the elements)
}
// ignoreUint discards a uint value with no destination.
func ignoreUint(i *decInstr, state *decoderState, v reflect.Value) {
state.decodeUint()
}
// ignoreTwoUints discards two uint values with no destination. It's used to skip
// complex values.
func ignoreTwoUints(i *decInstr, state *decoderState, v reflect.Value) {
state.decodeUint()
state.decodeUint()
}
// Since the encoder writes no zeros, if we arrive at a decoder we have
// a value to extract and store. The field number has already been read
// (it's how we knew to call this decoder).
// Each decoder is responsible for handling any indirections associated
// with the data structure. If any pointer so reached is nil, allocation must
// be done.
// decAlloc takes a value and returns a settable value that can
// be assigned to. If the value is a pointer, decAlloc guarantees it points to storage.
// The callers to the individual decoders are expected to have used decAlloc.
// The individual decoders don't need to do it themselves.
func decAlloc(v reflect.Value) reflect.Value {
for v.Kind() == reflect.Pointer {
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
v = v.Elem()
}
return v
}
// decBool decodes a uint and stores it as a boolean in value.
func decBool(i *decInstr, state *decoderState, value reflect.Value) {
value.SetBool(state.decodeUint() != 0)
}
// decInt8 decodes an integer and stores it as an int8 in value.
func decInt8(i *decInstr, state *decoderState, value reflect.Value) {
v := state.decodeInt()
if v < math.MinInt8 || math.MaxInt8 < v {
error_(i.ovfl)
}
value.SetInt(v)
}
// decUint8 decodes an unsigned integer and stores it as a uint8 in value.
func decUint8(i *decInstr, state *decoderState, value reflect.Value) {
v := state.decodeUint()
if math.MaxUint8 < v {
error_(i.ovfl)
}
value.SetUint(v)
}
// decInt16 decodes an integer and stores it as an int16 in value.
func decInt16(i *decInstr, state *decoderState, value reflect.Value) {
v := state.decodeInt()
if v < math.MinInt16 || math.MaxInt16 < v {
error_(i.ovfl)
}
value.SetInt(v)
}
// decUint16 decodes an unsigned integer and stores it as a uint16 in value.
func decUint16(i *decInstr, state *decoderState, value reflect.Value) {
v := state.decodeUint()
if math.MaxUint16 < v {
error_(i.ovfl)
}
value.SetUint(v)
}
// decInt32 decodes an integer and stores it as an int32 in value.
func decInt32(i *decInstr, state *decoderState, value reflect.Value) {
v := state.decodeInt()
if v < math.MinInt32 || math.MaxInt32 < v {
error_(i.ovfl)
}
value.SetInt(v)
}
// decUint32 decodes an unsigned integer and stores it as a uint32 in value.
func decUint32(i *decInstr, state *decoderState, value reflect.Value) {
v := state.decodeUint()
if math.MaxUint32 < v {
error_(i.ovfl)
}
value.SetUint(v)
}
// decInt64 decodes an integer and stores it as an int64 in value.
func decInt64(i *decInstr, state *decoderState, value reflect.Value) {
v := state.decodeInt()
value.SetInt(v)
}
// decUint64 decodes an unsigned integer and stores it as a uint64 in value.
func decUint64(i *decInstr, state *decoderState, value reflect.Value) {
v := state.decodeUint()
value.SetUint(v)
}
// Floating-point numbers are transmitted as uint64s holding the bits
// of the underlying representation. They are sent byte-reversed, with
// the exponent end coming out first, so integer floating point numbers
// (for example) transmit more compactly. This routine does the
// unswizzling.
func float64FromBits(u uint64) float64 {
v := bits.ReverseBytes64(u)
return math.Float64frombits(v)
}
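// floatWireExample is an illustrative sketch (not used by the package):
// 17.0 has IEEE-754 bits 0x4031000000000000, which travel byte-reversed as
// 0x3140, so the length-prefixed integer encoding needs three bytes on the
// wire instead of nine.
func floatWireExample() float64 {
return float64FromBits(0x3140) // 17.0
}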
// float32FromBits decodes an unsigned integer, treats it as a 32-bit floating-point
// number, and returns it. It's a helper function for float32 and complex64.
// It returns a float64 because that's what reflection needs, but its return
// value is known to be accurately representable in a float32.
func float32FromBits(u uint64, ovfl error) float64 {
v := float64FromBits(u)
av := v
if av < 0 {
av = -av
}
// +Inf is OK in both 32- and 64-bit floats. Underflow is always OK.
if math.MaxFloat32 < av && av <= math.MaxFloat64 {
error_(ovfl)
}
return v
}
// decFloat32 decodes an unsigned integer, treats it as a 32-bit floating-point
// number, and stores it in value.
func decFloat32(i *decInstr, state *decoderState, value reflect.Value) {
value.SetFloat(float32FromBits(state.decodeUint(), i.ovfl))
}
// decFloat64 decodes an unsigned integer, treats it as a 64-bit floating-point
// number, and stores it in value.
func decFloat64(i *decInstr, state *decoderState, value reflect.Value) {
value.SetFloat(float64FromBits(state.decodeUint()))
}
// decComplex64 decodes a pair of unsigned integers, treats them as a
// pair of floating point numbers, and stores them as a complex64 in value.
// The real part comes first.
func decComplex64(i *decInstr, state *decoderState, value reflect.Value) {
real := float32FromBits(state.decodeUint(), i.ovfl)
imag := float32FromBits(state.decodeUint(), i.ovfl)
value.SetComplex(complex(real, imag))
}
// decComplex128 decodes a pair of unsigned integers, treats them as a
// pair of floating point numbers, and stores them as a complex128 in value.
// The real part comes first.
func decComplex128(i *decInstr, state *decoderState, value reflect.Value) {
real := float64FromBits(state.decodeUint())
imag := float64FromBits(state.decodeUint())
value.SetComplex(complex(real, imag))
}
// decUint8Slice decodes a byte slice and stores in value a slice header
// describing the data.
// uint8 slices are encoded as an unsigned count followed by the raw bytes.
func decUint8Slice(i *decInstr, state *decoderState, value reflect.Value) {
n, ok := state.getLength()
if !ok {
errorf("bad %s slice length: %d", value.Type(), n)
}
if value.Cap() < n {
safe := saferio.SliceCap((*byte)(nil), uint64(n))
if safe < 0 {
errorf("%s slice too big: %d elements", value.Type(), n)
}
value.Set(reflect.MakeSlice(value.Type(), safe, safe))
ln := safe
i := 0
for i < n {
if i >= ln {
// We didn't allocate the entire slice,
// due to using saferio.SliceCap.
// Append a value to grow the slice.
// The slice is full, so this should
// bump up the capacity.
value.Set(reflect.Append(value, reflect.Zero(value.Type().Elem())))
}
// Copy into s up to the capacity or n,
// whichever is less.
ln = value.Cap()
if ln > n {
ln = n
}
value.SetLen(ln)
sub := value.Slice(i, ln)
if _, err := state.b.Read(sub.Bytes()); err != nil {
errorf("error decoding []byte at %d: %s", err, i)
}
i = ln
}
} else {
value.SetLen(n)
if _, err := state.b.Read(value.Bytes()); err != nil {
errorf("error decoding []byte: %s", err)
}
}
}
// decString decodes a byte array and stores in value a string header
// describing the data.
// Strings are encoded as an unsigned count followed by the raw bytes.
func decString(i *decInstr, state *decoderState, value reflect.Value) {
n, ok := state.getLength()
if !ok {
errorf("bad %s slice length: %d", value.Type(), n)
}
// Read the data.
data := state.b.Bytes()
if len(data) < n {
errorf("invalid string length %d: exceeds input size %d", n, len(data))
}
s := string(data[:n])
state.b.Drop(n)
value.SetString(s)
}
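// decStringWireExample is an illustrative sketch (not used by the package):
// the string "hi" travels as its byte count (2) followed by the raw bytes,
// exactly the form decString consumes above.
func decStringWireExample() string {
var d decBuffer
d.SetBytes([]byte{0x02, 'h', 'i'})
state := &decoderState{b: &d}
var s string
decString(nil, state, reflect.ValueOf(&s).Elem()) // the instruction argument is unused here
return s // "hi"
}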
// ignoreUint8Array skips over the data for a byte slice value with no destination.
func ignoreUint8Array(i *decInstr, state *decoderState, value reflect.Value) {
n, ok := state.getLength()
if !ok {
errorf("slice length too large")
}
bn := state.b.Len()
if bn < n {
errorf("invalid slice length %d: exceeds input size %d", n, bn)
}
state.b.Drop(n)
}
// Execution engine
// The decoder engine is an array of instructions indexed by field number of the incoming
// decoder. It is executed with random access according to field number.
type decEngine struct {
instr []decInstr
numInstr int // the number of active instructions
}
// decodeSingle decodes a top-level value that is not a struct and stores it in value.
// Such values are preceded by a zero, making them have the memory layout of a
// struct field (although with an illegal field number).
func (dec *Decoder) decodeSingle(engine *decEngine, value reflect.Value) {
state := dec.newDecoderState(&dec.buf)
defer dec.freeDecoderState(state)
state.fieldnum = singletonField
if state.decodeUint() != 0 {
errorf("decode: corrupted data: non-zero delta for singleton")
}
instr := &engine.instr[singletonField]
instr.op(instr, state, value)
}
// decodeStruct decodes a top-level struct and stores it in value.
// Indir is for the value, not the type. At the time of the call it may
// differ from ut.indir, which was computed when the engine was built.
// This state cannot arise for decodeSingle, which is called directly
// from the user's value, not from the innards of an engine.
func (dec *Decoder) decodeStruct(engine *decEngine, value reflect.Value) {
state := dec.newDecoderState(&dec.buf)
defer dec.freeDecoderState(state)
state.fieldnum = -1
for state.b.Len() > 0 {
delta := int(state.decodeUint())
if delta < 0 {
errorf("decode: corrupted data: negative delta")
}
if delta == 0 { // struct terminator is zero delta fieldnum
break
}
if state.fieldnum >= len(engine.instr)-delta { // subtract to compare without overflow
error_(errRange)
}
fieldnum := state.fieldnum + delta
instr := &engine.instr[fieldnum]
var field reflect.Value
if instr.index != nil {
// Otherwise the field is unknown to us and instr.op is an ignore op.
field = value.FieldByIndex(instr.index)
if field.Kind() == reflect.Pointer {
field = decAlloc(field)
}
}
instr.op(instr, state, field)
state.fieldnum = fieldnum
}
}
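// Worked example (illustrative): for a struct whose fields 0 and 2 were
// non-zero on the encoding side, the wire carries delta 1 (select field 0),
// the value, then delta 2 (select field 2), the value, then delta 0 as the
// terminator. fieldnum starts at -1, so the first delta of 1 lands on field 0.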
var noValue reflect.Value
// ignoreStruct discards the data for a struct with no destination.
func (dec *Decoder) ignoreStruct(engine *decEngine) {
state := dec.newDecoderState(&dec.buf)
defer dec.freeDecoderState(state)
state.fieldnum = -1
for state.b.Len() > 0 {
delta := int(state.decodeUint())
if delta < 0 {
errorf("ignore decode: corrupted data: negative delta")
}
if delta == 0 { // struct terminator is zero delta fieldnum
break
}
fieldnum := state.fieldnum + delta
if fieldnum >= len(engine.instr) {
error_(errRange)
}
instr := &engine.instr[fieldnum]
instr.op(instr, state, noValue)
state.fieldnum = fieldnum
}
}
// ignoreSingle discards the data for a top-level non-struct value with no
// destination. It's used when calling Decode with a nil value.
func (dec *Decoder) ignoreSingle(engine *decEngine) {
state := dec.newDecoderState(&dec.buf)
defer dec.freeDecoderState(state)
state.fieldnum = singletonField
delta := int(state.decodeUint())
if delta != 0 {
errorf("decode: corrupted data: non-zero delta for singleton")
}
instr := &engine.instr[singletonField]
instr.op(instr, state, noValue)
}
// decodeArrayHelper does the work for decoding arrays and slices.
func (dec *Decoder) decodeArrayHelper(state *decoderState, value reflect.Value, elemOp decOp, length int, ovfl error, helper decHelper) {
if helper != nil && helper(state, value, length, ovfl) {
return
}
instr := &decInstr{elemOp, 0, nil, ovfl}
isPtr := value.Type().Elem().Kind() == reflect.Pointer
ln := value.Len()
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding array or slice: length exceeds input size (%d elements)", length)
}
if i >= ln {
// This is a slice that we only partially allocated.
// Grow it using append, up to length.
value.Set(reflect.Append(value, reflect.Zero(value.Type().Elem())))
cp := value.Cap()
if cp > length {
cp = length
}
value.SetLen(cp)
ln = cp
}
v := value.Index(i)
if isPtr {
v = decAlloc(v)
}
elemOp(instr, state, v)
}
}
// decodeArray decodes an array and stores it in value.
// The length is an unsigned integer preceding the elements. Even though the length is redundant
// (it's part of the type), it's a useful check and is included in the encoding.
func (dec *Decoder) decodeArray(state *decoderState, value reflect.Value, elemOp decOp, length int, ovfl error, helper decHelper) {
if n := state.decodeUint(); n != uint64(length) {
errorf("length mismatch in decodeArray")
}
dec.decodeArrayHelper(state, value, elemOp, length, ovfl, helper)
}
// decodeIntoValue is a helper for map decoding.
func decodeIntoValue(state *decoderState, op decOp, isPtr bool, value reflect.Value, instr *decInstr) reflect.Value {
v := value
if isPtr {
v = decAlloc(value)
}
op(instr, state, v)
return value
}
// decodeMap decodes a map and stores it in value.
// Maps are encoded as a length followed by key:value pairs.
// Because the internals of maps are not visible to us, we must
// use reflection rather than pointer magic.
func (dec *Decoder) decodeMap(mtyp reflect.Type, state *decoderState, value reflect.Value, keyOp, elemOp decOp, ovfl error) {
n := int(state.decodeUint())
if value.IsNil() {
value.Set(reflect.MakeMapWithSize(mtyp, n))
}
keyIsPtr := mtyp.Key().Kind() == reflect.Pointer
elemIsPtr := mtyp.Elem().Kind() == reflect.Pointer
keyInstr := &decInstr{keyOp, 0, nil, ovfl}
elemInstr := &decInstr{elemOp, 0, nil, ovfl}
keyP := reflect.New(mtyp.Key())
elemP := reflect.New(mtyp.Elem())
for i := 0; i < n; i++ {
key := decodeIntoValue(state, keyOp, keyIsPtr, keyP.Elem(), keyInstr)
elem := decodeIntoValue(state, elemOp, elemIsPtr, elemP.Elem(), elemInstr)
value.SetMapIndex(key, elem)
keyP.Elem().SetZero()
elemP.Elem().SetZero()
}
}
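// Wire sketch (illustrative): map[string]int{"a": 1} arrives as the count 1
// followed by one key/value pair ("a", then 1). Keys and elements reuse the
// ordinary decode ops, so nested and pointer types need no special handling here.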
// ignoreArrayHelper does the work for discarding arrays and slices.
func (dec *Decoder) ignoreArrayHelper(state *decoderState, elemOp decOp, length int) {
instr := &decInstr{elemOp, 0, nil, errors.New("no error")}
for i := 0; i < length; i++ {
if state.b.Len() == 0 {
errorf("decoding array or slice: length exceeds input size (%d elements)", length)
}
elemOp(instr, state, noValue)
}
}
// ignoreArray discards the data for an array value with no destination.
func (dec *Decoder) ignoreArray(state *decoderState, elemOp decOp, length int) {
if n := state.decodeUint(); n != uint64(length) {
errorf("length mismatch in ignoreArray")
}
dec.ignoreArrayHelper(state, elemOp, length)
}
// ignoreMap discards the data for a map value with no destination.
func (dec *Decoder) ignoreMap(state *decoderState, keyOp, elemOp decOp) {
n := int(state.decodeUint())
keyInstr := &decInstr{keyOp, 0, nil, errors.New("no error")}
elemInstr := &decInstr{elemOp, 0, nil, errors.New("no error")}
for i := 0; i < n; i++ {
keyOp(keyInstr, state, noValue)
elemOp(elemInstr, state, noValue)
}
}
// decodeSlice decodes a slice and stores it in value.
// Slices are encoded as an unsigned length followed by the elements.
func (dec *Decoder) decodeSlice(state *decoderState, value reflect.Value, elemOp decOp, ovfl error, helper decHelper) {
u := state.decodeUint()
typ := value.Type()
size := uint64(typ.Elem().Size())
nBytes := u * size
n := int(u)
// Take care with overflow in this calculation.
if n < 0 || uint64(n) != u || nBytes > tooBig || (size > 0 && nBytes/size != u) {
// We don't check n against buffer length here because if it's a slice
// of interfaces, there will be buffer reloads.
errorf("%s slice too big: %d elements of %d bytes", typ.Elem(), u, size)
}
if value.Cap() < n {
safe := saferio.SliceCap(reflect.Zero(reflect.PointerTo(typ.Elem())).Interface(), uint64(n))
if safe < 0 {
errorf("%s slice too big: %d elements of %d bytes", typ.Elem(), u, size)
}
value.Set(reflect.MakeSlice(typ, safe, safe))
} else {
value.SetLen(n)
}
dec.decodeArrayHelper(state, value, elemOp, n, ovfl, helper)
}
// ignoreSlice skips over the data for a slice value with no destination.
func (dec *Decoder) ignoreSlice(state *decoderState, elemOp decOp) {
dec.ignoreArrayHelper(state, elemOp, int(state.decodeUint()))
}
// decodeInterface decodes an interface value and stores it in value.
// Interfaces are encoded as the name of a concrete type followed by a value.
// If the name is empty, the value is nil and no value is sent.
func (dec *Decoder) decodeInterface(ityp reflect.Type, state *decoderState, value reflect.Value) {
// Read the name of the concrete type.
nr := state.decodeUint()
if nr > 1<<31 { // zero is permissible for anonymous types
errorf("invalid type name length %d", nr)
}
if nr > uint64(state.b.Len()) {
errorf("invalid type name length %d: exceeds input size", nr)
}
n := int(nr)
name := state.b.Bytes()[:n]
state.b.Drop(n)
// Allocate the destination interface value.
if len(name) == 0 {
// Copy the nil interface value to the target.
value.SetZero()
return
}
if len(name) > 1024 {
errorf("name too long (%d bytes): %.20q...", len(name), name)
}
// The concrete type must be registered.
typi, ok := nameToConcreteType.Load(string(name))
if !ok {
errorf("name not registered for interface: %q", name)
}
typ := typi.(reflect.Type)
// Read the type id of the concrete value.
concreteId := dec.decodeTypeSequence(true)
if concreteId < 0 {
error_(dec.err)
}
// Byte count of value is next; we don't care what it is (it's there
// in case we want to ignore the value by skipping it completely).
state.decodeUint()
// Read the concrete value.
v := allocValue(typ)
dec.decodeValue(concreteId, v)
if dec.err != nil {
error_(dec.err)
}
// Assign the concrete value to the interface.
// Tread carefully; it might not satisfy the interface.
if !typ.AssignableTo(ityp) {
errorf("%s is not assignable to type %s", typ, ityp)
}
// Copy the interface value to the target.
value.Set(v)
}
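// Wire sketch (illustrative, using a hypothetical registered type named
// "main.Point"): the value arrives as the name length and bytes
// ("main.Point"), any pending type definitions, the concrete type id,
// a byte count for the value, and finally the value itself.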
// ignoreInterface discards the data for an interface value with no destination.
func (dec *Decoder) ignoreInterface(state *decoderState) {
// Read the name of the concrete type.
n, ok := state.getLength()
if !ok {
errorf("bad interface encoding: name too large for buffer")
}
bn := state.b.Len()
if bn < n {
errorf("invalid interface value length %d: exceeds input size %d", n, bn)
}
state.b.Drop(n)
id := dec.decodeTypeSequence(true)
if id < 0 {
error_(dec.err)
}
// At this point, the decoder buffer contains a delimited value. Just toss it.
n, ok = state.getLength()
if !ok {
errorf("bad interface encoding: data length too large for buffer")
}
state.b.Drop(n)
}
// decodeGobDecoder decodes something implementing the GobDecoder interface.
// The data is encoded as a byte slice.
func (dec *Decoder) decodeGobDecoder(ut *userTypeInfo, state *decoderState, value reflect.Value) {
// Read the bytes for the value.
n, ok := state.getLength()
if !ok {
errorf("GobDecoder: length too large for buffer")
}
b := state.b.Bytes()
if len(b) < n {
errorf("GobDecoder: invalid data length %d: exceeds input size %d", n, len(b))
}
b = b[:n]
state.b.Drop(n)
var err error
// We know it's one of these.
switch ut.externalDec {
case xGob:
err = value.Interface().(GobDecoder).GobDecode(b)
case xBinary:
err = value.Interface().(encoding.BinaryUnmarshaler).UnmarshalBinary(b)
case xText:
err = value.Interface().(encoding.TextUnmarshaler).UnmarshalText(b)
}
if err != nil {
error_(err)
}
}
// ignoreGobDecoder discards the data for a GobDecoder value with no destination.
func (dec *Decoder) ignoreGobDecoder(state *decoderState) {
// Read the bytes for the value.
n, ok := state.getLength()
if !ok {
errorf("GobDecoder: length too large for buffer")
}
bn := state.b.Len()
if bn < n {
errorf("GobDecoder: invalid data length %d: exceeds input size %d", n, bn)
}
state.b.Drop(n)
}
// Index by Go types.
var decOpTable = [...]decOp{
reflect.Bool: decBool,
reflect.Int8: decInt8,
reflect.Int16: decInt16,
reflect.Int32: decInt32,
reflect.Int64: decInt64,
reflect.Uint8: decUint8,
reflect.Uint16: decUint16,
reflect.Uint32: decUint32,
reflect.Uint64: decUint64,
reflect.Float32: decFloat32,
reflect.Float64: decFloat64,
reflect.Complex64: decComplex64,
reflect.Complex128: decComplex128,
reflect.String: decString,
}
// Indexed by gob types. tComplex will be added during type.init().
var decIgnoreOpMap = map[typeId]decOp{
tBool: ignoreUint,
tInt: ignoreUint,
tUint: ignoreUint,
tFloat: ignoreUint,
tBytes: ignoreUint8Array,
tString: ignoreUint8Array,
tComplex: ignoreTwoUints,
}
// decOpFor returns the decoding op for the base type under rt and
// the indirection count to reach it.
func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProgress map[reflect.Type]*decOp) *decOp {
ut := userType(rt)
// If the type implements GobEncoder, we handle it without further processing.
if ut.externalDec != 0 {
return dec.gobDecodeOpFor(ut)
}
// If this type is already in progress, it's a recursive type (e.g. map[string]*T).
// Return the pointer to the op we're already building.
if opPtr := inProgress[rt]; opPtr != nil {
return opPtr
}
typ := ut.base
var op decOp
k := typ.Kind()
if int(k) < len(decOpTable) {
op = decOpTable[k]
}
if op == nil {
inProgress[rt] = &op
// Special cases
switch t := typ; t.Kind() {
case reflect.Array:
name = "element of " + name
elemId := dec.wireType[wireId].ArrayT.Elem
elemOp := dec.decOpFor(elemId, t.Elem(), name, inProgress)
ovfl := overflow(name)
helper := decArrayHelper[t.Elem().Kind()]
op = func(i *decInstr, state *decoderState, value reflect.Value) {
state.dec.decodeArray(state, value, *elemOp, t.Len(), ovfl, helper)
}
case reflect.Map:
keyId := dec.wireType[wireId].MapT.Key
elemId := dec.wireType[wireId].MapT.Elem
keyOp := dec.decOpFor(keyId, t.Key(), "key of "+name, inProgress)
elemOp := dec.decOpFor(elemId, t.Elem(), "element of "+name, inProgress)
ovfl := overflow(name)
op = func(i *decInstr, state *decoderState, value reflect.Value) {
state.dec.decodeMap(t, state, value, *keyOp, *elemOp, ovfl)
}
case reflect.Slice:
name = "element of " + name
if t.Elem().Kind() == reflect.Uint8 {
op = decUint8Slice
break
}
var elemId typeId
if tt := builtinIdToType(wireId); tt != nil {
elemId = tt.(*sliceType).Elem
} else {
elemId = dec.wireType[wireId].SliceT.Elem
}
elemOp := dec.decOpFor(elemId, t.Elem(), name, inProgress)
ovfl := overflow(name)
helper := decSliceHelper[t.Elem().Kind()]
op = func(i *decInstr, state *decoderState, value reflect.Value) {
state.dec.decodeSlice(state, value, *elemOp, ovfl, helper)
}
case reflect.Struct:
// Generate a closure that calls out to the engine for the nested type.
ut := userType(typ)
enginePtr, err := dec.getDecEnginePtr(wireId, ut)
if err != nil {
error_(err)
}
op = func(i *decInstr, state *decoderState, value reflect.Value) {
// indirect through enginePtr to delay evaluation for recursive structs.
dec.decodeStruct(*enginePtr, value)
}
case reflect.Interface:
op = func(i *decInstr, state *decoderState, value reflect.Value) {
state.dec.decodeInterface(t, state, value)
}
}
}
if op == nil {
errorf("decode can't handle type %s", rt)
}
return &op
}
var maxIgnoreNestingDepth = 10000
// decIgnoreOpFor returns the decoding op for a field that has no destination.
func (dec *Decoder) decIgnoreOpFor(wireId typeId, inProgress map[typeId]*decOp, depth int) *decOp {
if depth > maxIgnoreNestingDepth {
error_(errors.New("invalid nesting depth"))
}
// If this type is already in progress, it's a recursive type (e.g. map[string]*T).
// Return the pointer to the op we're already building.
if opPtr := inProgress[wireId]; opPtr != nil {
return opPtr
}
op, ok := decIgnoreOpMap[wireId]
if !ok {
inProgress[wireId] = &op
if wireId == tInterface {
// Special case because it's a method: the ignored item might
// define types and we need to record their state in the decoder.
op = func(i *decInstr, state *decoderState, value reflect.Value) {
state.dec.ignoreInterface(state)
}
return &op
}
// Special cases
wire := dec.wireType[wireId]
switch {
case wire == nil:
errorf("bad data: undefined type %s", wireId.string())
case wire.ArrayT != nil:
elemId := wire.ArrayT.Elem
elemOp := dec.decIgnoreOpFor(elemId, inProgress, depth+1)
op = func(i *decInstr, state *decoderState, value reflect.Value) {
state.dec.ignoreArray(state, *elemOp, wire.ArrayT.Len)
}
case wire.MapT != nil:
keyId := dec.wireType[wireId].MapT.Key
elemId := dec.wireType[wireId].MapT.Elem
keyOp := dec.decIgnoreOpFor(keyId, inProgress, depth+1)
elemOp := dec.decIgnoreOpFor(elemId, inProgress, depth+1)
op = func(i *decInstr, state *decoderState, value reflect.Value) {
state.dec.ignoreMap(state, *keyOp, *elemOp)
}
case wire.SliceT != nil:
elemId := wire.SliceT.Elem
elemOp := dec.decIgnoreOpFor(elemId, inProgress, depth+1)
op = func(i *decInstr, state *decoderState, value reflect.Value) {
state.dec.ignoreSlice(state, *elemOp)
}
case wire.StructT != nil:
// Generate a closure that calls out to the engine for the nested type.
enginePtr, err := dec.getIgnoreEnginePtr(wireId)
if err != nil {
error_(err)
}
op = func(i *decInstr, state *decoderState, value reflect.Value) {
// indirect through enginePtr to delay evaluation for recursive structs
state.dec.ignoreStruct(*enginePtr)
}
case wire.GobEncoderT != nil, wire.BinaryMarshalerT != nil, wire.TextMarshalerT != nil:
op = func(i *decInstr, state *decoderState, value reflect.Value) {
state.dec.ignoreGobDecoder(state)
}
}
}
if op == nil {
errorf("bad data: ignore can't handle type %s", wireId.string())
}
return &op
}
// gobDecodeOpFor returns the op for a type that is known to implement
// GobDecoder.
func (dec *Decoder) gobDecodeOpFor(ut *userTypeInfo) *decOp {
rcvrType := ut.user
if ut.decIndir == -1 {
rcvrType = reflect.PointerTo(rcvrType)
} else if ut.decIndir > 0 {
for i := int8(0); i < ut.decIndir; i++ {
rcvrType = rcvrType.Elem()
}
}
var op decOp
op = func(i *decInstr, state *decoderState, value reflect.Value) {
// We now have the base type. We need its address if the receiver is a pointer.
if value.Kind() != reflect.Pointer && rcvrType.Kind() == reflect.Pointer {
value = value.Addr()
}
state.dec.decodeGobDecoder(ut, state, value)
}
return &op
}
// compatibleType asks: Are these two gob Types compatible?
// Answers the question for basic types, arrays, maps and slices, plus
// GobEncoder/Decoder pairs.
// Structs are considered ok; fields will be checked later.
func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId, inProgress map[reflect.Type]typeId) bool {
if rhs, ok := inProgress[fr]; ok {
return rhs == fw
}
inProgress[fr] = fw
ut := userType(fr)
wire, ok := dec.wireType[fw]
// If wire was encoded with an encoding method, fr must have that method.
// And if not, it must not.
// At most one of the booleans in ut is set.
// We could possibly relax this constraint in the future in order to
// choose the decoding method using the data in the wireType.
// The parentheses look odd but are correct.
if (ut.externalDec == xGob) != (ok && wire.GobEncoderT != nil) ||
(ut.externalDec == xBinary) != (ok && wire.BinaryMarshalerT != nil) ||
(ut.externalDec == xText) != (ok && wire.TextMarshalerT != nil) {
return false
}
if ut.externalDec != 0 { // This test trumps all others.
return true
}
switch t := ut.base; t.Kind() {
default:
// chan, etc: cannot handle.
return false
case reflect.Bool:
return fw == tBool
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return fw == tInt
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return fw == tUint
case reflect.Float32, reflect.Float64:
return fw == tFloat
case reflect.Complex64, reflect.Complex128:
return fw == tComplex
case reflect.String:
return fw == tString
case reflect.Interface:
return fw == tInterface
case reflect.Array:
if !ok || wire.ArrayT == nil {
return false
}
array := wire.ArrayT
return t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem, inProgress)
case reflect.Map:
if !ok || wire.MapT == nil {
return false
}
MapType := wire.MapT
return dec.compatibleType(t.Key(), MapType.Key, inProgress) && dec.compatibleType(t.Elem(), MapType.Elem, inProgress)
case reflect.Slice:
// Is it an array of bytes?
if t.Elem().Kind() == reflect.Uint8 {
return fw == tBytes
}
// Extract and compare element types.
var sw *sliceType
if tt := builtinIdToType(fw); tt != nil {
sw, _ = tt.(*sliceType)
} else if wire != nil {
sw = wire.SliceT
}
elem := userType(t.Elem()).base
return sw != nil && dec.compatibleType(elem, sw.Elem, inProgress)
case reflect.Struct:
return true
}
}
// typeString returns a human-readable description of the type identified by remoteId.
func (dec *Decoder) typeString(remoteId typeId) string {
typeLock.Lock()
defer typeLock.Unlock()
if t := idToType[remoteId]; t != nil {
// globally known type.
return t.string()
}
return dec.wireType[remoteId].string()
}
// compileSingle compiles the decoder engine for a non-struct top-level value, including
// GobDecoders.
func (dec *Decoder) compileSingle(remoteId typeId, ut *userTypeInfo) (engine *decEngine, err error) {
rt := ut.user
engine = new(decEngine)
engine.instr = make([]decInstr, 1) // one item
name := rt.String() // best we can do
if !dec.compatibleType(rt, remoteId, make(map[reflect.Type]typeId)) {
remoteType := dec.typeString(remoteId)
// Common confusing case: local interface type, remote concrete type.
if ut.base.Kind() == reflect.Interface && remoteId != tInterface {
return nil, errors.New("gob: local interface type " + name + " can only be decoded from remote interface type; received concrete type " + remoteType)
}
return nil, errors.New("gob: decoding into local type " + name + ", received remote type " + remoteType)
}
op := dec.decOpFor(remoteId, rt, name, make(map[reflect.Type]*decOp))
ovfl := errors.New(`value for "` + name + `" out of range`)
engine.instr[singletonField] = decInstr{*op, singletonField, nil, ovfl}
engine.numInstr = 1
return
}
// compileIgnoreSingle compiles the decoder engine for a non-struct top-level value that will be discarded.
func (dec *Decoder) compileIgnoreSingle(remoteId typeId) *decEngine {
engine := new(decEngine)
engine.instr = make([]decInstr, 1) // one item
op := dec.decIgnoreOpFor(remoteId, make(map[typeId]*decOp), 0)
ovfl := overflow(dec.typeString(remoteId))
engine.instr[0] = decInstr{*op, 0, nil, ovfl}
engine.numInstr = 1
return engine
}
// compileDec compiles the decoder engine for a value. If the value is not a struct,
// it calls out to compileSingle.
func (dec *Decoder) compileDec(remoteId typeId, ut *userTypeInfo) (engine *decEngine, err error) {
defer catchError(&err)
rt := ut.base
srt := rt
if srt.Kind() != reflect.Struct || ut.externalDec != 0 {
return dec.compileSingle(remoteId, ut)
}
var wireStruct *structType
// Builtin types can come from global pool; the rest must be defined by the decoder.
// Also we know we're decoding a struct now, so the client must have sent one.
if t := builtinIdToType(remoteId); t != nil {
wireStruct, _ = t.(*structType)
} else {
wire := dec.wireType[remoteId]
if wire == nil {
error_(errBadType)
}
wireStruct = wire.StructT
}
if wireStruct == nil {
errorf("type mismatch in decoder: want struct type %s; got non-struct", rt)
}
engine = new(decEngine)
engine.instr = make([]decInstr, len(wireStruct.Field))
seen := make(map[reflect.Type]*decOp)
// Loop over the fields of the wire type.
for fieldnum := 0; fieldnum < len(wireStruct.Field); fieldnum++ {
wireField := wireStruct.Field[fieldnum]
if wireField.Name == "" {
errorf("empty name for remote field of type %s", wireStruct.Name)
}
ovfl := overflow(wireField.Name)
// Find the field of the local type with the same name.
localField, present := srt.FieldByName(wireField.Name)
// TODO(r): anonymous names
if !present || !isExported(wireField.Name) {
op := dec.decIgnoreOpFor(wireField.Id, make(map[typeId]*decOp), 0)
engine.instr[fieldnum] = decInstr{*op, fieldnum, nil, ovfl}
continue
}
if !dec.compatibleType(localField.Type, wireField.Id, make(map[reflect.Type]typeId)) {
errorf("wrong type (%s) for received field %s.%s", localField.Type, wireStruct.Name, wireField.Name)
}
op := dec.decOpFor(wireField.Id, localField.Type, localField.Name, seen)
engine.instr[fieldnum] = decInstr{*op, fieldnum, localField.Index, ovfl}
engine.numInstr++
}
return
}
// getDecEnginePtr returns the engine for the specified type.
func (dec *Decoder) getDecEnginePtr(remoteId typeId, ut *userTypeInfo) (enginePtr **decEngine, err error) {
rt := ut.user
decoderMap, ok := dec.decoderCache[rt]
if !ok {
decoderMap = make(map[typeId]**decEngine)
dec.decoderCache[rt] = decoderMap
}
if enginePtr, ok = decoderMap[remoteId]; !ok {
// To handle recursive types, mark this engine as underway before compiling.
enginePtr = new(*decEngine)
decoderMap[remoteId] = enginePtr
*enginePtr, err = dec.compileDec(remoteId, ut)
if err != nil {
delete(decoderMap, remoteId)
}
}
return
}
// emptyStruct is the type we compile into when ignoring a struct value.
type emptyStruct struct{}
var emptyStructType = reflect.TypeOf((*emptyStruct)(nil)).Elem()
// getIgnoreEnginePtr returns the engine for the specified type when the value is to be discarded.
func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, err error) {
var ok bool
if enginePtr, ok = dec.ignorerCache[wireId]; !ok {
// To handle recursive types, mark this engine as underway before compiling.
enginePtr = new(*decEngine)
dec.ignorerCache[wireId] = enginePtr
wire := dec.wireType[wireId]
if wire != nil && wire.StructT != nil {
*enginePtr, err = dec.compileDec(wireId, userType(emptyStructType))
} else {
*enginePtr = dec.compileIgnoreSingle(wireId)
}
if err != nil {
delete(dec.ignorerCache, wireId)
}
}
return
}
// decodeValue decodes the data stream representing a value and stores it in value.
func (dec *Decoder) decodeValue(wireId typeId, value reflect.Value) {
defer catchError(&dec.err)
// If the value is nil, it means we should just ignore this item.
if !value.IsValid() {
dec.decodeIgnoredValue(wireId)
return
}
// Dereference down to the underlying type.
ut := userType(value.Type())
base := ut.base
var enginePtr **decEngine
enginePtr, dec.err = dec.getDecEnginePtr(wireId, ut)
if dec.err != nil {
return
}
value = decAlloc(value)
engine := *enginePtr
if st := base; st.Kind() == reflect.Struct && ut.externalDec == 0 {
wt := dec.wireType[wireId]
if engine.numInstr == 0 && st.NumField() > 0 &&
wt != nil && len(wt.StructT.Field) > 0 {
name := base.Name()
errorf("type mismatch: no fields matched compiling decoder for %s", name)
}
dec.decodeStruct(engine, value)
} else {
dec.decodeSingle(engine, value)
}
}
// decodeIgnoredValue decodes the data stream representing a value of the specified type and discards it.
func (dec *Decoder) decodeIgnoredValue(wireId typeId) {
var enginePtr **decEngine
enginePtr, dec.err = dec.getIgnoreEnginePtr(wireId)
if dec.err != nil {
return
}
wire := dec.wireType[wireId]
if wire != nil && wire.StructT != nil {
dec.ignoreStruct(*enginePtr)
} else {
dec.ignoreSingle(*enginePtr)
}
}
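// intBits and uintptrBits compute the platform word size: ^uint(0)>>63 is 1
// on 64-bit systems and 0 on 32-bit systems, yielding 64 or 32 respectively.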
const (
intBits = 32 << (^uint(0) >> 63)
uintptrBits = 32 << (^uintptr(0) >> 63)
)
func init() {
var iop, uop decOp
switch intBits {
case 32:
iop = decInt32
uop = decUint32
case 64:
iop = decInt64
uop = decUint64
default:
panic("gob: unknown size of int/uint")
}
decOpTable[reflect.Int] = iop
decOpTable[reflect.Uint] = uop
// Finally uintptr
switch uintptrBits {
case 32:
uop = decUint32
case 64:
uop = decUint64
default:
panic("gob: unknown size of uintptr")
}
decOpTable[reflect.Uintptr] = uop
}
// Gob depends on being able to take the address
// of zeroed Values it creates, so use this wrapper instead
// of the standard reflect.Zero.
// Each call allocates once.
func allocValue(t reflect.Type) reflect.Value {
return reflect.New(t).Elem()
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gob
import (
"bufio"
"errors"
"internal/saferio"
"io"
"reflect"
"sync"
)
// tooBig provides a sanity check for sizes; used in several places. The upper
// limit is 1GB on 32-bit systems and 8GB on 64-bit systems (^uint(0)>>62 is 0
// or 3, respectively), allowing room to grow a little without overflow.
const tooBig = (1 << 30) << (^uint(0) >> 62)
// A Decoder manages the receipt of type and data information read from the
// remote side of a connection. It is safe for concurrent use by multiple
// goroutines.
//
// The Decoder does only basic sanity checking on decoded input sizes,
// and its limits are not configurable. Take caution when decoding gob data
// from untrusted sources.
type Decoder struct {
mutex sync.Mutex // each item must be received atomically
r io.Reader // source of the data
buf decBuffer // buffer for more efficient i/o from r
wireType map[typeId]*wireType // map from remote ID to local description
decoderCache map[reflect.Type]map[typeId]**decEngine // cache of compiled engines
ignorerCache map[typeId]**decEngine // ditto for ignored objects
freeList *decoderState // list of free decoderStates; avoids reallocation
countBuf []byte // used for decoding integers while parsing messages
err error
}
// NewDecoder returns a new decoder that reads from the io.Reader.
// If r does not also implement io.ByteReader, it will be wrapped in a
// bufio.Reader.
func NewDecoder(r io.Reader) *Decoder {
dec := new(Decoder)
// We use the ability to read bytes as a plausible surrogate for buffering.
if _, ok := r.(io.ByteReader); !ok {
r = bufio.NewReader(r)
}
dec.r = r
dec.wireType = make(map[typeId]*wireType)
dec.decoderCache = make(map[reflect.Type]map[typeId]**decEngine)
dec.ignorerCache = make(map[typeId]**decEngine)
dec.countBuf = make([]byte, 9) // counts may be uint64s (unlikely!), require 9 bytes
return dec
}
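// Typical usage (sketch; Point is a hypothetical type that a matching
// Encoder wrote on the other end of conn):
//
//	dec := NewDecoder(conn)
//	var p Point
//	if err := dec.Decode(&p); err != nil {
//		log.Fatal(err)
//	}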
// recvType loads the definition of a type.
func (dec *Decoder) recvType(id typeId) {
// Have we already seen this type? That's an error.
if id < firstUserId || dec.wireType[id] != nil {
dec.err = errors.New("gob: duplicate type received")
return
}
// Type:
wire := new(wireType)
dec.decodeValue(tWireType, reflect.ValueOf(wire))
if dec.err != nil {
return
}
// Remember we've seen this type.
dec.wireType[id] = wire
}
var errBadCount = errors.New("invalid message length")
// recvMessage reads the next count-delimited item from the input. It is the converse
// of Encoder.writeMessage. It returns false on EOF or other error reading the message.
func (dec *Decoder) recvMessage() bool {
// Read a count.
nbytes, _, err := decodeUintReader(dec.r, dec.countBuf)
if err != nil {
dec.err = err
return false
}
if nbytes >= tooBig {
dec.err = errBadCount
return false
}
dec.readMessage(int(nbytes))
return dec.err == nil
}
// readMessage reads the next nbytes bytes from the input.
func (dec *Decoder) readMessage(nbytes int) {
if dec.buf.Len() != 0 {
// The buffer should always be empty now.
panic("non-empty decoder buffer")
}
// Read the data
var buf []byte
buf, dec.err = saferio.ReadData(dec.r, uint64(nbytes))
dec.buf.SetBytes(buf)
if dec.err == io.EOF {
dec.err = io.ErrUnexpectedEOF
}
}
// toInt turns an encoded uint64 into an int64, according to the marshaling rules.
func toInt(x uint64) int64 {
i := int64(x >> 1)
if x&1 != 0 {
i = ^i
}
return i
}
func (dec *Decoder) nextInt() int64 {
n, _, err := decodeUintReader(&dec.buf, dec.countBuf)
if err != nil {
dec.err = err
}
return toInt(n)
}
func (dec *Decoder) nextUint() uint64 {
n, _, err := decodeUintReader(&dec.buf, dec.countBuf)
if err != nil {
dec.err = err
}
return n
}
// decodeTypeSequence parses:
// TypeSequence
//
// (TypeDefinition DelimitedTypeDefinition*)?
//
// and returns the type id of the next value. It returns -1 at
// EOF. Upon return, the remainder of dec.buf is the value to be
// decoded. If this is an interface value, it can be ignored by
// resetting that buffer.
func (dec *Decoder) decodeTypeSequence(isInterface bool) typeId {
firstMessage := true
for dec.err == nil {
if dec.buf.Len() == 0 {
if !dec.recvMessage() {
// We can only return io.EOF if the input was empty.
// If we read one or more type spec messages,
// require a data item message to follow.
// If we hit an EOF before that, then give ErrUnexpectedEOF.
if !firstMessage && dec.err == io.EOF {
dec.err = io.ErrUnexpectedEOF
}
break
}
}
// Receive a type id.
id := typeId(dec.nextInt())
if id >= 0 {
// Value follows.
return id
}
// Type definition for (-id) follows.
dec.recvType(-id)
if dec.err != nil {
break
}
// When decoding an interface, after a type there may be a
// DelimitedValue still in the buffer. Skip its count.
// (Alternatively, the buffer is empty and the byte count
// will be absorbed by recvMessage.)
if dec.buf.Len() > 0 {
if !isInterface {
dec.err = errors.New("extra data in buffer")
break
}
dec.nextUint()
}
firstMessage = false
}
return -1
}
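// Stream sketch (illustrative): the first value of a user-defined type
// typically arrives as (-id, type definition) messages followed by
// (id, value); negative ids introduce definitions, and the first
// non-negative id names the value that follows.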
// Decode reads the next value from the input stream and stores
// it in the data represented by the empty interface value.
// If e is nil, the value will be discarded. Otherwise,
// the value underlying e must be a pointer to the
// correct type for the next data item received.
// If the input is at EOF, Decode returns io.EOF and
// does not modify e.
func (dec *Decoder) Decode(e any) error {
if e == nil {
return dec.DecodeValue(reflect.Value{})
}
value := reflect.ValueOf(e)
// If e represents a value as opposed to a pointer, the answer won't
// get back to the caller. Make sure it's a pointer.
if value.Type().Kind() != reflect.Pointer {
dec.err = errors.New("gob: attempt to decode into a non-pointer")
return dec.err
}
return dec.DecodeValue(value)
}
// DecodeValue reads the next value from the input stream.
// If v is the zero reflect.Value (v.Kind() == Invalid), DecodeValue discards the value.
// Otherwise, it stores the value into v. In that case, v must represent
// a non-nil pointer to data or be an assignable reflect.Value (v.CanSet()).
// If the input is at EOF, DecodeValue returns io.EOF and
// does not modify v.
func (dec *Decoder) DecodeValue(v reflect.Value) error {
if v.IsValid() {
if v.Kind() == reflect.Pointer && !v.IsNil() {
// That's okay, we'll store through the pointer.
} else if !v.CanSet() {
return errors.New("gob: DecodeValue of unassignable value")
}
}
// Make sure we're single-threaded through here.
dec.mutex.Lock()
defer dec.mutex.Unlock()
dec.buf.Reset() // In case data lingers from previous invocation.
dec.err = nil
id := dec.decodeTypeSequence(false)
if dec.err == nil {
dec.decodeValue(id, v)
}
return dec.err
}
// If debug.go is compiled into the program, debugFunc prints a human-readable
// representation of the gob data read from r by calling that file's Debug function.
// Otherwise it is nil.
var debugFunc func(io.Reader)
// Code generated by go run encgen.go -output enc_helpers.go; DO NOT EDIT.
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gob
import (
"reflect"
)
var encArrayHelper = map[reflect.Kind]encHelper{
reflect.Bool: encBoolArray,
reflect.Complex64: encComplex64Array,
reflect.Complex128: encComplex128Array,
reflect.Float32: encFloat32Array,
reflect.Float64: encFloat64Array,
reflect.Int: encIntArray,
reflect.Int16: encInt16Array,
reflect.Int32: encInt32Array,
reflect.Int64: encInt64Array,
reflect.Int8: encInt8Array,
reflect.String: encStringArray,
reflect.Uint: encUintArray,
reflect.Uint16: encUint16Array,
reflect.Uint32: encUint32Array,
reflect.Uint64: encUint64Array,
reflect.Uintptr: encUintptrArray,
}
var encSliceHelper = map[reflect.Kind]encHelper{
reflect.Bool: encBoolSlice,
reflect.Complex64: encComplex64Slice,
reflect.Complex128: encComplex128Slice,
reflect.Float32: encFloat32Slice,
reflect.Float64: encFloat64Slice,
reflect.Int: encIntSlice,
reflect.Int16: encInt16Slice,
reflect.Int32: encInt32Slice,
reflect.Int64: encInt64Slice,
reflect.Int8: encInt8Slice,
reflect.String: encStringSlice,
reflect.Uint: encUintSlice,
reflect.Uint16: encUint16Slice,
reflect.Uint32: encUint32Slice,
reflect.Uint64: encUint64Slice,
reflect.Uintptr: encUintptrSlice,
}
func encBoolArray(state *encoderState, v reflect.Value) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return encBoolSlice(state, v.Slice(0, v.Len()))
}
func encBoolSlice(state *encoderState, v reflect.Value) bool {
slice, ok := v.Interface().([]bool)
if !ok {
// It is kind bool but not type bool. TODO: We can handle this unsafely.
return false
}
for _, x := range slice {
if x != false || state.sendZero {
if x {
state.encodeUint(1)
} else {
state.encodeUint(0)
}
}
}
return true
}
func encComplex64Array(state *encoderState, v reflect.Value) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return encComplex64Slice(state, v.Slice(0, v.Len()))
}
func encComplex64Slice(state *encoderState, v reflect.Value) bool {
slice, ok := v.Interface().([]complex64)
if !ok {
// It is kind complex64 but not type complex64. TODO: We can handle this unsafely.
return false
}
for _, x := range slice {
if x != 0+0i || state.sendZero {
rpart := floatBits(float64(real(x)))
ipart := floatBits(float64(imag(x)))
state.encodeUint(rpart)
state.encodeUint(ipart)
}
}
return true
}
func encComplex128Array(state *encoderState, v reflect.Value) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return encComplex128Slice(state, v.Slice(0, v.Len()))
}
func encComplex128Slice(state *encoderState, v reflect.Value) bool {
slice, ok := v.Interface().([]complex128)
if !ok {
// It is kind complex128 but not type complex128. TODO: We can handle this unsafely.
return false
}
for _, x := range slice {
if x != 0+0i || state.sendZero {
rpart := floatBits(real(x))
ipart := floatBits(imag(x))
state.encodeUint(rpart)
state.encodeUint(ipart)
}
}
return true
}
func encFloat32Array(state *encoderState, v reflect.Value) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return encFloat32Slice(state, v.Slice(0, v.Len()))
}
func encFloat32Slice(state *encoderState, v reflect.Value) bool {
slice, ok := v.Interface().([]float32)
if !ok {
// It is kind float32 but not type float32. TODO: We can handle this unsafely.
return false
}
for _, x := range slice {
if x != 0 || state.sendZero {
bits := floatBits(float64(x))
state.encodeUint(bits)
}
}
return true
}
func encFloat64Array(state *encoderState, v reflect.Value) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return encFloat64Slice(state, v.Slice(0, v.Len()))
}
func encFloat64Slice(state *encoderState, v reflect.Value) bool {
slice, ok := v.Interface().([]float64)
if !ok {
// It is kind float64 but not type float64. TODO: We can handle this unsafely.
return false
}
for _, x := range slice {
if x != 0 || state.sendZero {
bits := floatBits(x)
state.encodeUint(bits)
}
}
return true
}
func encIntArray(state *encoderState, v reflect.Value) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return encIntSlice(state, v.Slice(0, v.Len()))
}
func encIntSlice(state *encoderState, v reflect.Value) bool {
slice, ok := v.Interface().([]int)
if !ok {
// It is kind int but not type int. TODO: We can handle this unsafely.
return false
}
for _, x := range slice {
if x != 0 || state.sendZero {
state.encodeInt(int64(x))
}
}
return true
}
func encInt16Array(state *encoderState, v reflect.Value) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return encInt16Slice(state, v.Slice(0, v.Len()))
}
func encInt16Slice(state *encoderState, v reflect.Value) bool {
slice, ok := v.Interface().([]int16)
if !ok {
// It is kind int16 but not type int16. TODO: We can handle this unsafely.
return false
}
for _, x := range slice {
if x != 0 || state.sendZero {
state.encodeInt(int64(x))
}
}
return true
}
func encInt32Array(state *encoderState, v reflect.Value) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return encInt32Slice(state, v.Slice(0, v.Len()))
}
func encInt32Slice(state *encoderState, v reflect.Value) bool {
slice, ok := v.Interface().([]int32)
if !ok {
// It is kind int32 but not type int32. TODO: We can handle this unsafely.
return false
}
for _, x := range slice {
if x != 0 || state.sendZero {
state.encodeInt(int64(x))
}
}
return true
}
func encInt64Array(state *encoderState, v reflect.Value) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return encInt64Slice(state, v.Slice(0, v.Len()))
}
func encInt64Slice(state *encoderState, v reflect.Value) bool {
slice, ok := v.Interface().([]int64)
if !ok {
// It is kind int64 but not type int64. TODO: We can handle this unsafely.
return false
}
for _, x := range slice {
if x != 0 || state.sendZero {
state.encodeInt(x)
}
}
return true
}
func encInt8Array(state *encoderState, v reflect.Value) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return encInt8Slice(state, v.Slice(0, v.Len()))
}
func encInt8Slice(state *encoderState, v reflect.Value) bool {
slice, ok := v.Interface().([]int8)
if !ok {
// It is kind int8 but not type int8. TODO: We can handle this unsafely.
return false
}
for _, x := range slice {
if x != 0 || state.sendZero {
state.encodeInt(int64(x))
}
}
return true
}
func encStringArray(state *encoderState, v reflect.Value) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return encStringSlice(state, v.Slice(0, v.Len()))
}
func encStringSlice(state *encoderState, v reflect.Value) bool {
slice, ok := v.Interface().([]string)
if !ok {
// It is kind string but not type string. TODO: We can handle this unsafely.
return false
}
for _, x := range slice {
if x != "" || state.sendZero {
state.encodeUint(uint64(len(x)))
state.b.WriteString(x)
}
}
return true
}
func encUintArray(state *encoderState, v reflect.Value) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return encUintSlice(state, v.Slice(0, v.Len()))
}
func encUintSlice(state *encoderState, v reflect.Value) bool {
slice, ok := v.Interface().([]uint)
if !ok {
// It is kind uint but not type uint. TODO: We can handle this unsafely.
return false
}
for _, x := range slice {
if x != 0 || state.sendZero {
state.encodeUint(uint64(x))
}
}
return true
}
func encUint16Array(state *encoderState, v reflect.Value) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return encUint16Slice(state, v.Slice(0, v.Len()))
}
func encUint16Slice(state *encoderState, v reflect.Value) bool {
slice, ok := v.Interface().([]uint16)
if !ok {
// It is kind uint16 but not type uint16. TODO: We can handle this unsafely.
return false
}
for _, x := range slice {
if x != 0 || state.sendZero {
state.encodeUint(uint64(x))
}
}
return true
}
func encUint32Array(state *encoderState, v reflect.Value) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return encUint32Slice(state, v.Slice(0, v.Len()))
}
func encUint32Slice(state *encoderState, v reflect.Value) bool {
slice, ok := v.Interface().([]uint32)
if !ok {
// It is kind uint32 but not type uint32. TODO: We can handle this unsafely.
return false
}
for _, x := range slice {
if x != 0 || state.sendZero {
state.encodeUint(uint64(x))
}
}
return true
}
func encUint64Array(state *encoderState, v reflect.Value) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return encUint64Slice(state, v.Slice(0, v.Len()))
}
func encUint64Slice(state *encoderState, v reflect.Value) bool {
slice, ok := v.Interface().([]uint64)
if !ok {
// It is kind uint64 but not type uint64. TODO: We can handle this unsafely.
return false
}
for _, x := range slice {
if x != 0 || state.sendZero {
state.encodeUint(x)
}
}
return true
}
func encUintptrArray(state *encoderState, v reflect.Value) bool {
// Can only slice if it is addressable.
if !v.CanAddr() {
return false
}
return encUintptrSlice(state, v.Slice(0, v.Len()))
}
func encUintptrSlice(state *encoderState, v reflect.Value) bool {
slice, ok := v.Interface().([]uintptr)
if !ok {
// It is kind uintptr but not type uintptr. TODO: We can handle this unsafely.
return false
}
for _, x := range slice {
if x != 0 || state.sendZero {
state.encodeUint(uint64(x))
}
}
return true
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:generate go run encgen.go -output enc_helpers.go
package gob
import (
"encoding"
"encoding/binary"
"math"
"math/bits"
"reflect"
"sync"
)
const uint64Size = 8
type encHelper func(state *encoderState, v reflect.Value) bool
// encoderState is the global execution state of an instance of the encoder.
// Field numbers are delta encoded and always increase. The field
// number is initialized to -1 so 0 comes out as delta(1). A delta of
// 0 terminates the structure.
type encoderState struct {
enc *Encoder
b *encBuffer
sendZero bool // encoding an array element or map key/value pair; send zero values
fieldnum int // the last field number written.
buf [1 + uint64Size]byte // buffer used by the encoder; here to avoid allocation.
next *encoderState // for free list
}
// encBuffer is an extremely simple, fast implementation of a write-only byte buffer.
// It never returns a non-nil error, but Write returns an error value so it matches io.Writer.
type encBuffer struct {
data []byte
scratch [64]byte
}
var encBufferPool = sync.Pool{
New: func() any {
e := new(encBuffer)
e.data = e.scratch[0:0]
return e
},
}
func (e *encBuffer) writeByte(c byte) {
e.data = append(e.data, c)
}
func (e *encBuffer) Write(p []byte) (int, error) {
e.data = append(e.data, p...)
return len(p), nil
}
func (e *encBuffer) WriteString(s string) {
e.data = append(e.data, s...)
}
func (e *encBuffer) Len() int {
return len(e.data)
}
func (e *encBuffer) Bytes() []byte {
return e.data
}
func (e *encBuffer) Reset() {
if len(e.data) >= tooBig {
e.data = e.scratch[0:0]
} else {
e.data = e.data[0:0]
}
}
func (enc *Encoder) newEncoderState(b *encBuffer) *encoderState {
e := enc.freeList
if e == nil {
e = new(encoderState)
e.enc = enc
} else {
enc.freeList = e.next
}
e.sendZero = false
e.fieldnum = 0
e.b = b
if len(b.data) == 0 {
b.data = b.scratch[0:0]
}
return e
}
func (enc *Encoder) freeEncoderState(e *encoderState) {
e.next = enc.freeList
enc.freeList = e
}
// Unsigned integers have a two-state encoding. If the number is less
// than 128 (0 through 0x7F), its value is written directly.
// Otherwise the value is written in big-endian byte order preceded
// by the byte length, negated.
// encodeUint writes an encoded unsigned integer to state.b.
func (state *encoderState) encodeUint(x uint64) {
if x <= 0x7F {
state.b.writeByte(uint8(x))
return
}
binary.BigEndian.PutUint64(state.buf[1:], x)
bc := bits.LeadingZeros64(x) >> 3 // 8 - bytelen(x)
state.buf[bc] = uint8(bc - uint64Size) // and then we subtract 8 to get -bytelen(x)
state.b.Write(state.buf[bc : uint64Size+1])
}
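// Illustrative sketch (not part of the original source): the same two-state
// encoding written as a standalone function. A value like 0x42 encodes as
// itself in one byte, while 256 encodes as FE 01 00: a negated byte count
// (0xFE is -2 as a uint8) followed by the two big-endian bytes of the value.
func encodeUintSketch(x uint64) []byte {
if x <= 0x7F {
return []byte{byte(x)}
}
var tmp [8]byte
binary.BigEndian.PutUint64(tmp[:], x)
n := 8 - bits.LeadingZeros64(x)/8 // byte length of x
out := make([]byte, 0, 1+n)
out = append(out, byte(-n)) // the negated count
return append(out, tmp[8-n:]...)
}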
// encodeInt writes an encoded signed integer to state.b.
// The low bit of the encoding says whether to bit complement the (other bits of the)
// uint to recover the int.
func (state *encoderState) encodeInt(i int64) {
var x uint64
if i < 0 {
x = uint64(^i<<1) | 1
} else {
x = uint64(i << 1)
}
state.encodeUint(x)
}
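// Illustrative sketch (not part of the original source): the inverse mapping.
// encodeInt sends 3 as 6 (3 shifted left) and -3 as 5 (^-3 == 2, shifted left
// with the complement bit set); decodeIntSketch recovers the original values.
func decodeIntSketch(x uint64) int64 {
if x&1 != 0 {
return ^int64(x >> 1) // complement bit set: undo the bit complement
}
return int64(x >> 1)
}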
// encOp is the signature of an encoding operator for a given type.
type encOp func(i *encInstr, state *encoderState, v reflect.Value)
// The 'instructions' of the encoding machine
type encInstr struct {
op encOp
field int // field number in input
index []int // struct index
indir int // how many pointer indirections to reach the value in the struct
}
// update emits a field number and updates the state to record its value for delta encoding.
// If the instruction pointer is nil, it does nothing.
func (state *encoderState) update(instr *encInstr) {
if instr != nil {
state.encodeUint(uint64(instr.field - state.fieldnum))
state.fieldnum = instr.field
}
}
// Each encoder for a composite is responsible for handling any
// indirections associated with the elements of the data structure.
// If any pointer so reached is nil, no bytes are written. If the
// data item is zero, no bytes are written. Single values - ints,
// strings etc. - are indirected before calling their encoders.
// Otherwise, the output (for a scalar) is the field number, as an
// encoded integer, followed by the field data in its appropriate
// format.
// encIndirect dereferences pv indir times and returns the result.
func encIndirect(pv reflect.Value, indir int) reflect.Value {
for ; indir > 0; indir-- {
if pv.IsNil() {
break
}
pv = pv.Elem()
}
return pv
}
// encBool encodes the bool referenced by v as an unsigned 0 or 1.
func encBool(i *encInstr, state *encoderState, v reflect.Value) {
b := v.Bool()
if b || state.sendZero {
state.update(i)
if b {
state.encodeUint(1)
} else {
state.encodeUint(0)
}
}
}
// encInt encodes the signed integer (int int8 int16 int32 int64) referenced by v.
func encInt(i *encInstr, state *encoderState, v reflect.Value) {
value := v.Int()
if value != 0 || state.sendZero {
state.update(i)
state.encodeInt(value)
}
}
// encUint encodes the unsigned integer (uint uint8 uint16 uint32 uint64 uintptr) referenced by v.
func encUint(i *encInstr, state *encoderState, v reflect.Value) {
value := v.Uint()
if value != 0 || state.sendZero {
state.update(i)
state.encodeUint(value)
}
}
// floatBits returns a uint64 holding the bits of a floating-point number.
// Floating-point numbers are transmitted as uint64s holding the bits
// of the underlying representation. They are sent byte-reversed, with
// the exponent end coming out first, so integer floating point numbers
// (for example) transmit more compactly. This routine does the
// swizzling.
func floatBits(f float64) uint64 {
u := math.Float64bits(f)
return bits.ReverseBytes64(u)
}
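// Worked example (illustrative, not part of the original source): small
// integer-valued floats become small uints after the swap, so they encode
// compactly.
func floatBitsSketch() uint64 {
u := math.Float64bits(17.0) // 0x4031000000000000
return bits.ReverseBytes64(u) // 0x3140; encodeUint sends just FE 31 40
}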
// encFloat encodes the floating point value (float32 float64) referenced by v.
func encFloat(i *encInstr, state *encoderState, v reflect.Value) {
f := v.Float()
if f != 0 || state.sendZero {
bits := floatBits(f)
state.update(i)
state.encodeUint(bits)
}
}
// encComplex encodes the complex value (complex64 complex128) referenced by v.
// Complex numbers are just a pair of floating-point numbers, real part first.
func encComplex(i *encInstr, state *encoderState, v reflect.Value) {
c := v.Complex()
if c != 0+0i || state.sendZero {
rpart := floatBits(real(c))
ipart := floatBits(imag(c))
state.update(i)
state.encodeUint(rpart)
state.encodeUint(ipart)
}
}
// encUint8Array encodes the byte array referenced by v.
// Byte arrays are encoded as an unsigned count followed by the raw bytes.
func encUint8Array(i *encInstr, state *encoderState, v reflect.Value) {
b := v.Bytes()
if len(b) > 0 || state.sendZero {
state.update(i)
state.encodeUint(uint64(len(b)))
state.b.Write(b)
}
}
// encString encodes the string referenced by v.
// Strings are encoded as an unsigned count followed by the raw bytes.
func encString(i *encInstr, state *encoderState, v reflect.Value) {
s := v.String()
if len(s) > 0 || state.sendZero {
state.update(i)
state.encodeUint(uint64(len(s)))
state.b.WriteString(s)
}
}
// encStructTerminator encodes the end of an encoded struct
// as a delta field number of 0.
func encStructTerminator(i *encInstr, state *encoderState, v reflect.Value) {
state.encodeUint(0)
}
// Execution engine
// encEngine is an array of instructions indexed by field number of the encoding
// data, typically a struct. It is executed top to bottom, walking the struct.
type encEngine struct {
instr []encInstr
}
const singletonField = 0
// valid reports whether the value is valid and a non-nil pointer.
// (Slices, maps, and chans take care of themselves.)
func valid(v reflect.Value) bool {
switch v.Kind() {
case reflect.Invalid:
return false
case reflect.Pointer:
return !v.IsNil()
}
return true
}
// encodeSingle encodes a single top-level non-struct value.
func (enc *Encoder) encodeSingle(b *encBuffer, engine *encEngine, value reflect.Value) {
state := enc.newEncoderState(b)
defer enc.freeEncoderState(state)
state.fieldnum = singletonField
// There is no surrounding struct to frame the transmission, so we must
// generate data even if the item is zero. To do this, set sendZero.
state.sendZero = true
instr := &engine.instr[singletonField]
if instr.indir > 0 {
value = encIndirect(value, instr.indir)
}
if valid(value) {
instr.op(instr, state, value)
}
}
// encodeStruct encodes a single struct value.
func (enc *Encoder) encodeStruct(b *encBuffer, engine *encEngine, value reflect.Value) {
if !valid(value) {
return
}
state := enc.newEncoderState(b)
defer enc.freeEncoderState(state)
state.fieldnum = -1
for i := 0; i < len(engine.instr); i++ {
instr := &engine.instr[i]
if i >= value.NumField() {
// encStructTerminator
instr.op(instr, state, reflect.Value{})
break
}
field := value.FieldByIndex(instr.index)
if instr.indir > 0 {
field = encIndirect(field, instr.indir)
// TODO: Is field guaranteed valid? If so we could avoid this check.
if !valid(field) {
continue
}
}
instr.op(instr, state, field)
}
}
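// Worked example (illustrative, not part of the original source): encoding
// struct{ A, C int }{A: 0, C: 7} writes nothing for A (zero values are
// omitted), then for C a field delta of 2 (fieldnum moves from -1 to 1), the
// value 7 (sent as 14 by encodeInt's shift-and-flag scheme), and the closing
// delta of 0:
//
//	02 0E 00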
// encodeArray encodes an array.
func (enc *Encoder) encodeArray(b *encBuffer, value reflect.Value, op encOp, elemIndir int, length int, helper encHelper) {
state := enc.newEncoderState(b)
defer enc.freeEncoderState(state)
state.fieldnum = -1
state.sendZero = true
state.encodeUint(uint64(length))
if helper != nil && helper(state, value) {
return
}
for i := 0; i < length; i++ {
elem := value.Index(i)
if elemIndir > 0 {
elem = encIndirect(elem, elemIndir)
// TODO: Is elem guaranteed valid? If so we could avoid this check.
if !valid(elem) {
errorf("encodeArray: nil element")
}
}
op(nil, state, elem)
}
}
// encodeReflectValue is a helper for maps. It encodes the value v.
func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir int) {
for i := 0; i < indir && v.IsValid(); i++ {
v = reflect.Indirect(v)
}
if !v.IsValid() {
errorf("encodeReflectValue: nil element")
}
op(nil, state, v)
}
// encodeMap encodes a map as unsigned count followed by key:value pairs.
func (enc *Encoder) encodeMap(b *encBuffer, mv reflect.Value, keyOp, elemOp encOp, keyIndir, elemIndir int) {
state := enc.newEncoderState(b)
state.fieldnum = -1
state.sendZero = true
state.encodeUint(uint64(mv.Len()))
mi := mv.MapRange()
for mi.Next() {
encodeReflectValue(state, mi.Key(), keyOp, keyIndir)
encodeReflectValue(state, mi.Value(), elemOp, elemIndir)
}
enc.freeEncoderState(state)
}
// encodeInterface encodes the interface value iv.
// To send an interface, we send a string identifying the concrete type, followed
// by the type identifier (which might require defining that type right now), followed
// by the concrete value. A nil value gets sent as the empty string for the name,
// followed by no value.
func (enc *Encoder) encodeInterface(b *encBuffer, iv reflect.Value) {
// Gobs can encode nil interface values but not typed interface
// values holding nil pointers, since nil pointers point to no value.
elem := iv.Elem()
if elem.Kind() == reflect.Pointer && elem.IsNil() {
errorf("gob: cannot encode nil pointer of type %s inside interface", iv.Elem().Type())
}
state := enc.newEncoderState(b)
state.fieldnum = -1
state.sendZero = true
if iv.IsNil() {
state.encodeUint(0)
return
}
ut := userType(iv.Elem().Type())
namei, ok := concreteTypeToName.Load(ut.base)
if !ok {
errorf("type not registered for interface: %s", ut.base)
}
name := namei.(string)
// Send the name.
state.encodeUint(uint64(len(name)))
state.b.WriteString(name)
// Define the type id if necessary.
enc.sendTypeDescriptor(enc.writer(), state, ut)
// Send the type id.
enc.sendTypeId(state, ut)
// Encode the value into a new buffer. Any nested type definitions
// should be written to b, before the encoded value.
enc.pushWriter(b)
data := encBufferPool.Get().(*encBuffer)
data.Write(spaceForLength)
enc.encode(data, elem, ut)
if enc.err != nil {
error_(enc.err)
}
enc.popWriter()
enc.writeMessage(b, data)
data.Reset()
encBufferPool.Put(data)
if enc.err != nil {
error_(enc.err)
}
enc.freeEncoderState(state)
}
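// Worked example (illustrative, not part of the original source): sending an
// interface value holding int(7), whose concrete type is registered under the
// name "int", emits the name ("\x03int"), then the type id, and finally the
// concrete value in its own length-prefixed message.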
// isZero reports whether the value is the zero of its type.
func isZero(val reflect.Value) bool {
switch val.Kind() {
case reflect.Array:
for i := 0; i < val.Len(); i++ {
if !isZero(val.Index(i)) {
return false
}
}
return true
case reflect.Map, reflect.Slice, reflect.String:
return val.Len() == 0
case reflect.Bool:
return !val.Bool()
case reflect.Complex64, reflect.Complex128:
return val.Complex() == 0
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Pointer:
return val.IsNil()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return val.Int() == 0
case reflect.Float32, reflect.Float64:
return val.Float() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return val.Uint() == 0
case reflect.Struct:
for i := 0; i < val.NumField(); i++ {
if !isZero(val.Field(i)) {
return false
}
}
return true
}
panic("unknown type in isZero " + val.Type().String())
}
// encodeGobEncoder encodes a value that implements the GobEncoder interface.
// The data is sent as a byte array.
func (enc *Encoder) encodeGobEncoder(b *encBuffer, ut *userTypeInfo, v reflect.Value) {
// TODO: should we catch panics from the called method?
var data []byte
var err error
// We know it's one of these.
switch ut.externalEnc {
case xGob:
data, err = v.Interface().(GobEncoder).GobEncode()
case xBinary:
data, err = v.Interface().(encoding.BinaryMarshaler).MarshalBinary()
case xText:
data, err = v.Interface().(encoding.TextMarshaler).MarshalText()
}
if err != nil {
error_(err)
}
state := enc.newEncoderState(b)
state.fieldnum = -1
state.encodeUint(uint64(len(data)))
state.b.Write(data)
enc.freeEncoderState(state)
}
var encOpTable = [...]encOp{
reflect.Bool: encBool,
reflect.Int: encInt,
reflect.Int8: encInt,
reflect.Int16: encInt,
reflect.Int32: encInt,
reflect.Int64: encInt,
reflect.Uint: encUint,
reflect.Uint8: encUint,
reflect.Uint16: encUint,
reflect.Uint32: encUint,
reflect.Uint64: encUint,
reflect.Uintptr: encUint,
reflect.Float32: encFloat,
reflect.Float64: encFloat,
reflect.Complex64: encComplex,
reflect.Complex128: encComplex,
reflect.String: encString,
}
// encOpFor returns (a pointer to) the encoding op for the base type under rt and
// the indirection count to reach it.
func encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp, building map[*typeInfo]bool) (*encOp, int) {
ut := userType(rt)
// If the type implements GobEncoder, we handle it without further processing.
if ut.externalEnc != 0 {
return gobEncodeOpFor(ut)
}
// If this type is already in progress, it's a recursive type (e.g. map[string]*T).
// Return the pointer to the op we're already building.
if opPtr := inProgress[rt]; opPtr != nil {
return opPtr, ut.indir
}
typ := ut.base
indir := ut.indir
k := typ.Kind()
var op encOp
if int(k) < len(encOpTable) {
op = encOpTable[k]
}
if op == nil {
inProgress[rt] = &op
// Special cases
switch t := typ; t.Kind() {
case reflect.Slice:
if t.Elem().Kind() == reflect.Uint8 {
op = encUint8Array
break
}
// Slices have a header; we decode it to find the underlying array.
elemOp, elemIndir := encOpFor(t.Elem(), inProgress, building)
helper := encSliceHelper[t.Elem().Kind()]
op = func(i *encInstr, state *encoderState, slice reflect.Value) {
if !state.sendZero && slice.Len() == 0 {
return
}
state.update(i)
state.enc.encodeArray(state.b, slice, *elemOp, elemIndir, slice.Len(), helper)
}
case reflect.Array:
// True arrays have size in the type.
elemOp, elemIndir := encOpFor(t.Elem(), inProgress, building)
helper := encArrayHelper[t.Elem().Kind()]
op = func(i *encInstr, state *encoderState, array reflect.Value) {
state.update(i)
state.enc.encodeArray(state.b, array, *elemOp, elemIndir, array.Len(), helper)
}
case reflect.Map:
keyOp, keyIndir := encOpFor(t.Key(), inProgress, building)
elemOp, elemIndir := encOpFor(t.Elem(), inProgress, building)
op = func(i *encInstr, state *encoderState, mv reflect.Value) {
// We send zero-length (but non-nil) maps because the
// receiver might want to use the map. (Maps don't use append.)
if !state.sendZero && mv.IsNil() {
return
}
state.update(i)
state.enc.encodeMap(state.b, mv, *keyOp, *elemOp, keyIndir, elemIndir)
}
case reflect.Struct:
// Generate a closure that calls out to the engine for the nested type.
getEncEngine(userType(typ), building)
info := mustGetTypeInfo(typ)
op = func(i *encInstr, state *encoderState, sv reflect.Value) {
state.update(i)
// indirect through info to delay evaluation for recursive structs
enc := info.encoder.Load()
state.enc.encodeStruct(state.b, enc, sv)
}
case reflect.Interface:
op = func(i *encInstr, state *encoderState, iv reflect.Value) {
if !state.sendZero && (!iv.IsValid() || iv.IsNil()) {
return
}
state.update(i)
state.enc.encodeInterface(state.b, iv)
}
}
}
if op == nil {
errorf("can't happen: encode type %s", rt)
}
return &op, indir
}
// gobEncodeOpFor returns the op for a type that is known to implement GobEncoder.
func gobEncodeOpFor(ut *userTypeInfo) (*encOp, int) {
rt := ut.user
if ut.encIndir == -1 {
rt = reflect.PointerTo(rt)
} else if ut.encIndir > 0 {
for i := int8(0); i < ut.encIndir; i++ {
rt = rt.Elem()
}
}
var op encOp
op = func(i *encInstr, state *encoderState, v reflect.Value) {
if ut.encIndir == -1 {
// Need to climb up one level to turn value into pointer.
if !v.CanAddr() {
errorf("unaddressable value of type %s", rt)
}
v = v.Addr()
}
if !state.sendZero && isZero(v) {
return
}
state.update(i)
state.enc.encodeGobEncoder(state.b, ut, v)
}
return &op, int(ut.encIndir) // encIndir: op will get called with p == address of receiver.
}
// compileEnc returns the engine to compile the type.
func compileEnc(ut *userTypeInfo, building map[*typeInfo]bool) *encEngine {
srt := ut.base
engine := new(encEngine)
seen := make(map[reflect.Type]*encOp)
rt := ut.base
if ut.externalEnc != 0 {
rt = ut.user
}
if ut.externalEnc == 0 && srt.Kind() == reflect.Struct {
for fieldNum, wireFieldNum := 0, 0; fieldNum < srt.NumField(); fieldNum++ {
f := srt.Field(fieldNum)
if !isSent(&f) {
continue
}
op, indir := encOpFor(f.Type, seen, building)
engine.instr = append(engine.instr, encInstr{*op, wireFieldNum, f.Index, indir})
wireFieldNum++
}
if srt.NumField() > 0 && len(engine.instr) == 0 {
errorf("type %s has no exported fields", rt)
}
engine.instr = append(engine.instr, encInstr{encStructTerminator, 0, nil, 0})
} else {
engine.instr = make([]encInstr, 1)
op, indir := encOpFor(rt, seen, building)
engine.instr[0] = encInstr{*op, singletonField, nil, indir}
}
return engine
}
// getEncEngine returns the engine to compile the type.
func getEncEngine(ut *userTypeInfo, building map[*typeInfo]bool) *encEngine {
info, err := getTypeInfo(ut)
if err != nil {
error_(err)
}
enc := info.encoder.Load()
if enc == nil {
enc = buildEncEngine(info, ut, building)
}
return enc
}
func buildEncEngine(info *typeInfo, ut *userTypeInfo, building map[*typeInfo]bool) *encEngine {
// Check for recursive types.
if building != nil && building[info] {
return nil
}
info.encInit.Lock()
defer info.encInit.Unlock()
enc := info.encoder.Load()
if enc == nil {
if building == nil {
building = make(map[*typeInfo]bool)
}
building[info] = true
enc = compileEnc(ut, building)
info.encoder.Store(enc)
}
return enc
}
func (enc *Encoder) encode(b *encBuffer, value reflect.Value, ut *userTypeInfo) {
defer catchError(&enc.err)
engine := getEncEngine(ut, nil)
indir := ut.indir
if ut.externalEnc != 0 {
indir = int(ut.encIndir)
}
for i := 0; i < indir; i++ {
value = reflect.Indirect(value)
}
if ut.externalEnc == 0 && value.Type().Kind() == reflect.Struct {
enc.encodeStruct(b, engine, value)
} else {
enc.encodeSingle(b, engine, value)
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gob
import (
"errors"
"io"
"reflect"
"sync"
)
// An Encoder manages the transmission of type and data information to the
// other side of a connection. It is safe for concurrent use by multiple
// goroutines.
type Encoder struct {
mutex sync.Mutex // each item must be sent atomically
w []io.Writer // where to send the data
sent map[reflect.Type]typeId // which types we've already sent
countState *encoderState // stage for writing counts
freeList *encoderState // list of free encoderStates; avoids reallocation
byteBuf encBuffer // buffer for top-level encoderState
err error
}
// Before we encode a message, we reserve space at the head of the
// buffer in which to encode its length. This means we can use the
// buffer to assemble the message without another allocation.
const maxLength = 9 // Maximum size of an encoded length.
var spaceForLength = make([]byte, maxLength)
// NewEncoder returns a new encoder that will transmit on the io.Writer.
func NewEncoder(w io.Writer) *Encoder {
enc := new(Encoder)
enc.w = []io.Writer{w}
enc.sent = make(map[reflect.Type]typeId)
enc.countState = enc.newEncoderState(new(encBuffer))
return enc
}
// writer returns the innermost writer the encoder is using.
func (enc *Encoder) writer() io.Writer {
return enc.w[len(enc.w)-1]
}
// pushWriter adds a writer to the encoder.
func (enc *Encoder) pushWriter(w io.Writer) {
enc.w = append(enc.w, w)
}
// popWriter pops the innermost writer.
func (enc *Encoder) popWriter() {
enc.w = enc.w[0 : len(enc.w)-1]
}
func (enc *Encoder) setError(err error) {
if enc.err == nil { // remember the first.
enc.err = err
}
}
// writeMessage sends the data item preceded by an unsigned count of its length.
func (enc *Encoder) writeMessage(w io.Writer, b *encBuffer) {
// Space has been reserved for the length at the head of the message.
// This is a little dirty: we grab the slice from the encBuffer and massage
// it by hand.
message := b.Bytes()
messageLen := len(message) - maxLength
// Length cannot be bigger than the decoder can handle.
if messageLen >= tooBig {
enc.setError(errors.New("gob: encoder: message too big"))
return
}
// Encode the length.
enc.countState.b.Reset()
enc.countState.encodeUint(uint64(messageLen))
// Copy the length to be a prefix of the message.
offset := maxLength - enc.countState.b.Len()
copy(message[offset:], enc.countState.b.Bytes())
// Write the data.
_, err := w.Write(message[offset:])
// Drain the buffer and restore the space at the front for the count of the next message.
b.Reset()
b.Write(spaceForLength)
if err != nil {
enc.setError(err)
}
}
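// Worked example (illustrative, not part of the original source): for a
// three-byte payload the length 3 encodes as the single byte 03, so offset
// is maxLength-1 == 8 and only message[8:] goes out on the wire:
//
//	03 <three payload bytes>
//
// The other eight reserved bytes at the head of the buffer are never sent.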
// sendActualType sends the requested type, without further investigation, unless
// it's been sent before.
func (enc *Encoder) sendActualType(w io.Writer, state *encoderState, ut *userTypeInfo, actual reflect.Type) (sent bool) {
if _, alreadySent := enc.sent[actual]; alreadySent {
return false
}
info, err := getTypeInfo(ut)
if err != nil {
enc.setError(err)
return
}
// Send the pair (-id, type)
// Id:
state.encodeInt(-int64(info.id))
// Type:
enc.encode(state.b, reflect.ValueOf(info.wire), wireTypeUserInfo)
enc.writeMessage(w, state.b)
if enc.err != nil {
return
}
// Remember we've sent this type, both what the user gave us and the base type.
enc.sent[ut.base] = info.id
if ut.user != ut.base {
enc.sent[ut.user] = info.id
}
// Now send the inner types
switch st := actual; st.Kind() {
case reflect.Struct:
for i := 0; i < st.NumField(); i++ {
if isExported(st.Field(i).Name) {
enc.sendType(w, state, st.Field(i).Type)
}
}
case reflect.Array, reflect.Slice:
enc.sendType(w, state, st.Elem())
case reflect.Map:
enc.sendType(w, state, st.Key())
enc.sendType(w, state, st.Elem())
}
return true
}
// sendType sends the type info to the other side, if necessary.
func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Type) (sent bool) {
ut := userType(origt)
if ut.externalEnc != 0 {
// The rules are different: regardless of the underlying type's representation,
// we need to tell the other side that the base type is a GobEncoder.
return enc.sendActualType(w, state, ut, ut.base)
}
// It's a concrete value, so drill down to the base type.
switch rt := ut.base; rt.Kind() {
default:
// Basic types and interfaces do not need to be described.
return
case reflect.Slice:
// If it's []uint8, don't send; it's considered basic.
if rt.Elem().Kind() == reflect.Uint8 {
return
}
// Otherwise we do send.
break
case reflect.Array:
// arrays must be sent so we know their lengths and element types.
break
case reflect.Map:
// maps must be sent so we know their lengths and key/value types.
break
case reflect.Struct:
// structs must be sent so we know their fields.
break
case reflect.Chan, reflect.Func:
// If we get here, it's a field of a struct; ignore it.
return
}
return enc.sendActualType(w, state, ut, ut.base)
}
// Encode transmits the data item represented by the empty interface value,
// guaranteeing that all necessary type information has been transmitted first.
// Passing a nil pointer to Encoder will panic, as nil pointers cannot be transmitted by gob.
func (enc *Encoder) Encode(e any) error {
return enc.EncodeValue(reflect.ValueOf(e))
}
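// Typical usage (illustrative sketch, assuming an in-memory bytes.Buffer as
// the transport and a throwaway payload type):
//
//	var network bytes.Buffer
//	enc := gob.NewEncoder(&network)
//	if err := enc.Encode(struct{ A, B int }{1, 2}); err != nil {
//		log.Fatal("encode:", err)
//	}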
// sendTypeDescriptor makes sure the remote side knows about this type.
// It will send a descriptor if this is the first time the type has been
// sent.
func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, ut *userTypeInfo) {
// Make sure the type is known to the other side.
// First, have we already sent this type?
rt := ut.base
if ut.externalEnc != 0 {
rt = ut.user
}
if _, alreadySent := enc.sent[rt]; !alreadySent {
// No, so send it.
sent := enc.sendType(w, state, rt)
if enc.err != nil {
return
}
// If the type info has still not been transmitted, it means we have
// a singleton basic type (int, []byte etc.) at top level. We don't
// need to send the type info but we do need to update enc.sent.
if !sent {
info, err := getTypeInfo(ut)
if err != nil {
enc.setError(err)
return
}
enc.sent[rt] = info.id
}
}
}
// sendTypeId sends the id, which must have already been defined.
func (enc *Encoder) sendTypeId(state *encoderState, ut *userTypeInfo) {
// Identify the type of this top-level value.
state.encodeInt(int64(enc.sent[ut.base]))
}
// EncodeValue transmits the data item represented by the reflection value,
// guaranteeing that all necessary type information has been transmitted first.
// Passing a nil pointer to EncodeValue will panic, as nil pointers cannot be transmitted by gob.
func (enc *Encoder) EncodeValue(value reflect.Value) error {
if value.Kind() == reflect.Invalid {
return errors.New("gob: cannot encode nil value")
}
if value.Kind() == reflect.Pointer && value.IsNil() {
panic("gob: cannot encode nil pointer of type " + value.Type().String())
}
// Make sure we're single-threaded through here, so multiple
// goroutines can share an encoder.
enc.mutex.Lock()
defer enc.mutex.Unlock()
// Remove any nested writers remaining due to previous errors.
enc.w = enc.w[0:1]
ut, err := validUserType(value.Type())
if err != nil {
return err
}
enc.err = nil
enc.byteBuf.Reset()
enc.byteBuf.Write(spaceForLength)
state := enc.newEncoderState(&enc.byteBuf)
enc.sendTypeDescriptor(enc.writer(), state, ut)
enc.sendTypeId(state, ut)
if enc.err != nil {
return enc.err
}
// Encode the object.
enc.encode(state.b, value, ut)
if enc.err == nil {
enc.writeMessage(enc.writer(), state.b)
}
enc.freeEncoderState(state)
return enc.err
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gob
import "fmt"
// Errors in decoding and encoding are handled using panic and recover.
// Panics caused by user error (that is, everything except run-time panics
// such as "index out of bounds" errors) do not leave the file that caused
// them, but are instead turned into plain error returns. Encoding and
// decoding functions and methods that do not return an error either use
// panic to report an error or are guaranteed error-free.
// A gobError is used to distinguish errors (panics) generated in this package.
type gobError struct {
err error
}
// errorf is like error_ but takes Printf-style arguments to construct an error.
// It always prefixes the message with "gob: ".
func errorf(format string, args ...any) {
error_(fmt.Errorf("gob: "+format, args...))
}
// error_ wraps the argument error and uses it as the argument to panic.
func error_(err error) {
panic(gobError{err})
}
// catchError is meant to be used as a deferred function to turn a panic(gobError) into a
// plain error. It overwrites the error return of the function that deferred its call.
func catchError(err *error) {
if e := recover(); e != nil {
ge, ok := e.(gobError)
if !ok {
panic(e)
}
*err = ge.err
}
}
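// Illustrative sketch (not part of the original source) of how the pieces
// compose: a function defers catchError over its named error result, and any
// errorf below it unwinds into a plain error return instead of a panic.
func catchErrorSketch() (err error) {
defer catchError(&err)
errorf("sketch failure: %d", 42) // panics with a gobError
return nil // unreachable; err is set by catchError
}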
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gob
import (
"encoding"
"errors"
"fmt"
"os"
"reflect"
"sync"
"sync/atomic"
"unicode"
"unicode/utf8"
)
// userTypeInfo stores the information associated with a type the user has handed
// to the package. It's computed once and stored in a map keyed by reflection
// type.
type userTypeInfo struct {
user reflect.Type // the type the user handed us
base reflect.Type // the base type after all indirections
indir int // number of indirections to reach the base type
externalEnc int // xGob, xBinary, or xText
externalDec int // xGob, xBinary or xText
encIndir int8 // number of indirections to reach the receiver type; may be negative
decIndir int8 // number of indirections to reach the receiver type; may be negative
}
// externalEncoding bits
const (
xGob = 1 + iota // GobEncoder or GobDecoder
xBinary // encoding.BinaryMarshaler or encoding.BinaryUnmarshaler
xText // encoding.TextMarshaler or encoding.TextUnmarshaler
)
var userTypeCache sync.Map // map[reflect.Type]*userTypeInfo
// validUserType returns, and saves, the information associated with user-provided type rt.
// If the user type is not valid, err will be non-nil. To be used when the error handler
// is not set up.
func validUserType(rt reflect.Type) (*userTypeInfo, error) {
if ui, ok := userTypeCache.Load(rt); ok {
return ui.(*userTypeInfo), nil
}
// Construct a new userTypeInfo and atomically add it to the userTypeCache.
// If we lose the race, we'll waste a little CPU and create a little garbage
// but return the existing value anyway.
ut := new(userTypeInfo)
ut.base = rt
ut.user = rt
// A type that is just a cycle of pointers (such as type T *T) cannot
// be represented in gobs, which need some concrete data. We use a
// cycle detection algorithm from Knuth, Vol 2, Section 3.1, Ex 6,
// pp 539-540. As we step through indirections, run another type at
// half speed. If they meet up, there's a cycle.
slowpoke := ut.base // walks half as fast as ut.base
for {
pt := ut.base
if pt.Kind() != reflect.Pointer {
break
}
ut.base = pt.Elem()
if ut.base == slowpoke { // ut.base lapped slowpoke
// recursive pointer type.
return nil, errors.New("can't represent recursive pointer type " + ut.base.String())
}
if ut.indir%2 == 0 {
slowpoke = slowpoke.Elem()
}
ut.indir++
}
if ok, indir := implementsInterface(ut.user, gobEncoderInterfaceType); ok {
ut.externalEnc, ut.encIndir = xGob, indir
} else if ok, indir := implementsInterface(ut.user, binaryMarshalerInterfaceType); ok {
ut.externalEnc, ut.encIndir = xBinary, indir
}
// NOTE(rsc): Would like to allow MarshalText here, but results in incompatibility
// with older encodings for net.IP. See golang.org/issue/6760.
// } else if ok, indir := implementsInterface(ut.user, textMarshalerInterfaceType); ok {
// ut.externalEnc, ut.encIndir = xText, indir
// }
if ok, indir := implementsInterface(ut.user, gobDecoderInterfaceType); ok {
ut.externalDec, ut.decIndir = xGob, indir
} else if ok, indir := implementsInterface(ut.user, binaryUnmarshalerInterfaceType); ok {
ut.externalDec, ut.decIndir = xBinary, indir
}
// See note above.
// } else if ok, indir := implementsInterface(ut.user, textUnmarshalerInterfaceType); ok {
// ut.externalDec, ut.decIndir = xText, indir
// }
ui, _ := userTypeCache.LoadOrStore(rt, ut)
return ui.(*userTypeInfo), nil
}
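// Worked example (illustrative, not part of the original source): for
//
//	type T *T
//
// the fast walker laps slowpoke and validUserType reports the recursive
// pointer type as an error, while a plain **int terminates normally with
// base int and indir == 2.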
var (
gobEncoderInterfaceType = reflect.TypeOf((*GobEncoder)(nil)).Elem()
gobDecoderInterfaceType = reflect.TypeOf((*GobDecoder)(nil)).Elem()
binaryMarshalerInterfaceType = reflect.TypeOf((*encoding.BinaryMarshaler)(nil)).Elem()
binaryUnmarshalerInterfaceType = reflect.TypeOf((*encoding.BinaryUnmarshaler)(nil)).Elem()
textMarshalerInterfaceType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
textUnmarshalerInterfaceType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
)
// implementsInterface reports whether the type implements the
// gobEncoder/gobDecoder interface.
// It also returns the number of indirections required to get to the
// implementation.
func implementsInterface(typ, gobEncDecType reflect.Type) (success bool, indir int8) {
if typ == nil {
return
}
rt := typ
// The type might be a pointer and we need to keep
// dereferencing to the base type until we find an implementation.
for {
if rt.Implements(gobEncDecType) {
return true, indir
}
if p := rt; p.Kind() == reflect.Pointer {
indir++
if indir > 100 { // insane number of indirections
return false, 0
}
rt = p.Elem()
continue
}
break
}
// No luck yet, but if this is a base type (non-pointer), the pointer might satisfy.
if typ.Kind() != reflect.Pointer {
// Not a pointer, but does the pointer work?
if reflect.PointerTo(typ).Implements(gobEncDecType) {
return true, -1
}
}
return false, 0
}
// userType returns, and saves, the information associated with user-provided type rt.
// If the user type is not valid, it calls error_.
func userType(rt reflect.Type) *userTypeInfo {
ut, err := validUserType(rt)
if err != nil {
error_(err)
}
return ut
}
// A typeId represents a gob Type as an integer that can be passed on the wire.
// Internally, typeIds are used as keys to a map to recover the underlying type info.
type typeId int32
var typeLock sync.Mutex // set while building a type
const firstUserId = 64 // lowest id number granted to user
type gobType interface {
id() typeId
setId(id typeId)
name() string
string() string // not public; only for debugging
safeString(seen map[typeId]bool) string
}
var types = make(map[reflect.Type]gobType, 32)
var idToType = make([]gobType, 1, firstUserId)
var builtinIdToTypeSlice [firstUserId]gobType // set in init() after builtins are established
func builtinIdToType(id typeId) gobType {
if id < 0 || int(id) >= len(builtinIdToTypeSlice) {
return nil
}
return builtinIdToTypeSlice[id]
}
func setTypeId(typ gobType) {
// When building recursive types, someone may get there before us.
if typ.id() != 0 {
return
}
nextId := typeId(len(idToType))
typ.setId(nextId)
idToType = append(idToType, typ)
}
func (t typeId) gobType() gobType {
if t == 0 {
return nil
}
return idToType[t]
}
// string returns the string representation of the type associated with the typeId.
func (t typeId) string() string {
if t.gobType() == nil {
return "<nil>"
}
return t.gobType().string()
}
// name returns the name of the type associated with the typeId.
func (t typeId) name() string {
if t.gobType() == nil {
return "<nil>"
}
return t.gobType().name()
}
// CommonType holds elements of all types.
// It is a historical artifact, kept for binary compatibility and exported
// only for the benefit of the package's encoding of type descriptors. It is
// not intended for direct use by clients.
type CommonType struct {
Name string
Id typeId
}
func (t *CommonType) id() typeId { return t.Id }
func (t *CommonType) setId(id typeId) { t.Id = id }
func (t *CommonType) string() string { return t.Name }
func (t *CommonType) safeString(seen map[typeId]bool) string {
return t.Name
}
func (t *CommonType) name() string { return t.Name }
// Create and check predefined types
// The string for tBytes is "bytes" not "[]byte" to signify its specialness.
var (
// Primordial types, needed during initialization.
// Always passed as pointers so the interface{} type
// goes through without losing its interfaceness.
tBool = bootstrapType("bool", (*bool)(nil))
tInt = bootstrapType("int", (*int)(nil))
tUint = bootstrapType("uint", (*uint)(nil))
tFloat = bootstrapType("float", (*float64)(nil))
tBytes = bootstrapType("bytes", (*[]byte)(nil))
tString = bootstrapType("string", (*string)(nil))
tComplex = bootstrapType("complex", (*complex128)(nil))
tInterface = bootstrapType("interface", (*any)(nil))
// Reserve some Ids for compatible expansion
tReserved7 = bootstrapType("_reserved1", (*struct{ r7 int })(nil))
tReserved6 = bootstrapType("_reserved1", (*struct{ r6 int })(nil))
tReserved5 = bootstrapType("_reserved1", (*struct{ r5 int })(nil))
tReserved4 = bootstrapType("_reserved1", (*struct{ r4 int })(nil))
tReserved3 = bootstrapType("_reserved1", (*struct{ r3 int })(nil))
tReserved2 = bootstrapType("_reserved1", (*struct{ r2 int })(nil))
tReserved1 = bootstrapType("_reserved1", (*struct{ r1 int })(nil))
)
// Predefined because it's needed by the Decoder
var tWireType = mustGetTypeInfo(reflect.TypeOf((*wireType)(nil)).Elem()).id
var wireTypeUserInfo *userTypeInfo // userTypeInfo of (*wireType)
func init() {
// Some magic numbers to make sure there are no surprises.
checkId(16, tWireType)
checkId(17, mustGetTypeInfo(reflect.TypeOf((*arrayType)(nil)).Elem()).id)
checkId(18, mustGetTypeInfo(reflect.TypeOf((*CommonType)(nil)).Elem()).id)
checkId(19, mustGetTypeInfo(reflect.TypeOf((*sliceType)(nil)).Elem()).id)
checkId(20, mustGetTypeInfo(reflect.TypeOf((*structType)(nil)).Elem()).id)
checkId(21, mustGetTypeInfo(reflect.TypeOf((*fieldType)(nil)).Elem()).id)
checkId(23, mustGetTypeInfo(reflect.TypeOf((*mapType)(nil)).Elem()).id)
copy(builtinIdToTypeSlice[:], idToType)
// Move the id space upwards to allow for growth in the predefined world
// without breaking existing files.
if nextId := len(idToType); nextId > firstUserId {
panic(fmt.Sprintln("nextId too large:", nextId))
}
idToType = idToType[:firstUserId]
registerBasics()
wireTypeUserInfo = userType(reflect.TypeOf((*wireType)(nil)))
}
// Array type
type arrayType struct {
CommonType
Elem typeId
Len int
}
func newArrayType(name string) *arrayType {
a := &arrayType{CommonType{Name: name}, 0, 0}
return a
}
func (a *arrayType) init(elem gobType, len int) {
// Set our type id before evaluating the element's, in case it's our own.
setTypeId(a)
a.Elem = elem.id()
a.Len = len
}
func (a *arrayType) safeString(seen map[typeId]bool) string {
if seen[a.Id] {
return a.Name
}
seen[a.Id] = true
return fmt.Sprintf("[%d]%s", a.Len, a.Elem.gobType().safeString(seen))
}
func (a *arrayType) string() string { return a.safeString(make(map[typeId]bool)) }
// GobEncoder type (something that implements the GobEncoder interface)
type gobEncoderType struct {
CommonType
}
func newGobEncoderType(name string) *gobEncoderType {
g := &gobEncoderType{CommonType{Name: name}}
setTypeId(g)
return g
}
func (g *gobEncoderType) safeString(seen map[typeId]bool) string {
return g.Name
}
func (g *gobEncoderType) string() string { return g.Name }
// Map type
type mapType struct {
CommonType
Key typeId
Elem typeId
}
func newMapType(name string) *mapType {
m := &mapType{CommonType{Name: name}, 0, 0}
return m
}
func (m *mapType) init(key, elem gobType) {
// Set our type id before evaluating the element's, in case it's our own.
setTypeId(m)
m.Key = key.id()
m.Elem = elem.id()
}
func (m *mapType) safeString(seen map[typeId]bool) string {
if seen[m.Id] {
return m.Name
}
seen[m.Id] = true
key := m.Key.gobType().safeString(seen)
elem := m.Elem.gobType().safeString(seen)
return fmt.Sprintf("map[%s]%s", key, elem)
}
func (m *mapType) string() string { return m.safeString(make(map[typeId]bool)) }
// Slice type
type sliceType struct {
CommonType
Elem typeId
}
func newSliceType(name string) *sliceType {
s := &sliceType{CommonType{Name: name}, 0}
return s
}
func (s *sliceType) init(elem gobType) {
// Set our type id before evaluating the element's, in case it's our own.
setTypeId(s)
// See the comments about ids in newTypeObject. Only slices and
// structs have mutual recursion.
if elem.id() == 0 {
setTypeId(elem)
}
s.Elem = elem.id()
}
func (s *sliceType) safeString(seen map[typeId]bool) string {
if seen[s.Id] {
return s.Name
}
seen[s.Id] = true
return fmt.Sprintf("[]%s", s.Elem.gobType().safeString(seen))
}
func (s *sliceType) string() string { return s.safeString(make(map[typeId]bool)) }
// Struct type
type fieldType struct {
Name string
Id typeId
}
type structType struct {
CommonType
Field []*fieldType
}
func (s *structType) safeString(seen map[typeId]bool) string {
if s == nil {
return "<nil>"
}
if _, ok := seen[s.Id]; ok {
return s.Name
}
seen[s.Id] = true
str := s.Name + " = struct { "
for _, f := range s.Field {
str += fmt.Sprintf("%s %s; ", f.Name, f.Id.gobType().safeString(seen))
}
str += "}"
return str
}
func (s *structType) string() string { return s.safeString(make(map[typeId]bool)) }
func newStructType(name string) *structType {
s := &structType{CommonType{Name: name}, nil}
// For historical reasons we set the id here rather than init.
// See the comment in newTypeObject for details.
setTypeId(s)
return s
}
// newTypeObject allocates a gobType for the reflection type rt.
// Unless ut represents a GobEncoder, rt should be the base type
// of ut.
// This is only called from the encoding side. The decoding side
// works through typeIds and userTypeInfos alone.
func newTypeObject(name string, ut *userTypeInfo, rt reflect.Type) (gobType, error) {
// Does this type implement GobEncoder?
if ut.externalEnc != 0 {
return newGobEncoderType(name), nil
}
var err error
var type0, type1 gobType
defer func() {
if err != nil {
delete(types, rt)
}
}()
// Install the top-level type before the subtypes (e.g. struct before
// fields) so recursive types can be constructed safely.
switch t := rt; t.Kind() {
// All basic types are easy: they are predefined.
case reflect.Bool:
return tBool.gobType(), nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return tInt.gobType(), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return tUint.gobType(), nil
case reflect.Float32, reflect.Float64:
return tFloat.gobType(), nil
case reflect.Complex64, reflect.Complex128:
return tComplex.gobType(), nil
case reflect.String:
return tString.gobType(), nil
case reflect.Interface:
return tInterface.gobType(), nil
case reflect.Array:
at := newArrayType(name)
types[rt] = at
type0, err = getBaseType("", t.Elem())
if err != nil {
return nil, err
}
// Historical aside:
// For arrays, maps, and slices, we set the type id after the elements
// are constructed. This is to retain the order of type id allocation after
// a fix made to handle recursive types, which changed the order in
// which types are built. Delaying the setting in this way preserves
// type ids while allowing recursive types to be described. Structs,
// done below, were already handling recursion correctly so they
// assign the top-level id before those of the field.
at.init(type0, t.Len())
return at, nil
case reflect.Map:
mt := newMapType(name)
types[rt] = mt
type0, err = getBaseType("", t.Key())
if err != nil {
return nil, err
}
type1, err = getBaseType("", t.Elem())
if err != nil {
return nil, err
}
mt.init(type0, type1)
return mt, nil
case reflect.Slice:
// []byte == []uint8 is a special case
if t.Elem().Kind() == reflect.Uint8 {
return tBytes.gobType(), nil
}
st := newSliceType(name)
types[rt] = st
type0, err = getBaseType(t.Elem().Name(), t.Elem())
if err != nil {
return nil, err
}
st.init(type0)
return st, nil
case reflect.Struct:
st := newStructType(name)
types[rt] = st
idToType[st.id()] = st
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
if !isSent(&f) {
continue
}
typ := userType(f.Type).base
tname := typ.Name()
if tname == "" {
t := userType(f.Type).base
tname = t.String()
}
gt, err := getBaseType(tname, f.Type)
if err != nil {
return nil, err
}
// Some mutually recursive types can cause us to be here while
// still defining the element. Fix the element type id here.
// We could do this more neatly by setting the id at the start of
// building every type, but that would break binary compatibility.
if gt.id() == 0 {
setTypeId(gt)
}
st.Field = append(st.Field, &fieldType{f.Name, gt.id()})
}
return st, nil
default:
return nil, errors.New("gob NewTypeObject can't handle type: " + rt.String())
}
}
// isExported reports whether this is an exported - upper case - name.
func isExported(name string) bool {
rune, _ := utf8.DecodeRuneInString(name)
return unicode.IsUpper(rune)
}
// isSent reports whether this struct field is to be transmitted.
// It will be transmitted only if it is exported and not a chan or func field
// or pointer to chan or func.
func isSent(field *reflect.StructField) bool {
if !isExported(field.Name) {
return false
}
// If the field is a chan or func or pointer thereto, don't send it.
// That is, treat it like an unexported field.
typ := field.Type
for typ.Kind() == reflect.Pointer {
typ = typ.Elem()
}
if typ.Kind() == reflect.Chan || typ.Kind() == reflect.Func {
return false
}
return true
}
// getBaseType returns the Gob type describing the given reflect.Type's base type.
// typeLock must be held.
func getBaseType(name string, rt reflect.Type) (gobType, error) {
ut := userType(rt)
return getType(name, ut, ut.base)
}
// getType returns the Gob type describing the given reflect.Type.
// Should be called only when handling GobEncoders/Decoders,
// which may be pointers. All other types are handled through the
// base type, never a pointer.
// typeLock must be held.
func getType(name string, ut *userTypeInfo, rt reflect.Type) (gobType, error) {
typ, present := types[rt]
if present {
return typ, nil
}
typ, err := newTypeObject(name, ut, rt)
if err == nil {
types[rt] = typ
}
return typ, err
}
func checkId(want, got typeId) {
if want != got {
fmt.Fprintf(os.Stderr, "checkId: %d should be %d\n", int(got), int(want))
panic("bootstrap type wrong id: " + got.name() + " " + got.string() + " not " + want.string())
}
}
// bootstrapType is used for building the basic types; called only from init().
// The incoming interface always refers to a pointer.
func bootstrapType(name string, e any) typeId {
rt := reflect.TypeOf(e).Elem()
_, present := types[rt]
if present {
panic("bootstrap type already present: " + name + ", " + rt.String())
}
typ := &CommonType{Name: name}
types[rt] = typ
setTypeId(typ)
userType(rt) // might as well cache it now
return typ.id()
}
// Representation of the information we send and receive about this type.
// Each value we send is preceded by its type definition: an encoded int.
// However, the very first time we send the value, we first send the pair
// (-id, wireType).
// For bootstrapping purposes, we assume that the recipient knows how
// to decode a wireType; it is exactly the wireType struct here, interpreted
// using the gob rules for sending a structure, except that we assume the
// ids for wireType and structType etc. are known. The relevant pieces
// are built in encode.go's init() function.
// To maintain binary compatibility, if you extend this type, always put
// the new fields last.
type wireType struct {
ArrayT *arrayType
SliceT *sliceType
StructT *structType
MapT *mapType
GobEncoderT *gobEncoderType
BinaryMarshalerT *gobEncoderType
TextMarshalerT *gobEncoderType
}
func (w *wireType) string() string {
const unknown = "unknown type"
if w == nil {
return unknown
}
switch {
case w.ArrayT != nil:
return w.ArrayT.Name
case w.SliceT != nil:
return w.SliceT.Name
case w.StructT != nil:
return w.StructT.Name
case w.MapT != nil:
return w.MapT.Name
case w.GobEncoderT != nil:
return w.GobEncoderT.Name
case w.BinaryMarshalerT != nil:
return w.BinaryMarshalerT.Name
case w.TextMarshalerT != nil:
return w.TextMarshalerT.Name
}
return unknown
}
type typeInfo struct {
id typeId
encInit sync.Mutex // protects creation of encoder
encoder atomic.Pointer[encEngine]
wire *wireType
}
// typeInfoMap is an atomic pointer to map[reflect.Type]*typeInfo.
// It's updated copy-on-write. Readers just do an atomic load
// to get the current version of the map. Writers make a full copy of
// the map and atomically update the pointer to point to the new map.
// Under heavy read contention, this is significantly faster than a map
// protected by a mutex.
var typeInfoMap atomic.Value
// typeInfoMapInit is used instead of typeInfoMap during init time,
// as types are registered sequentially during init and we can save
// the overhead of making map copies.
// It is saved to typeInfoMap and set to nil before init finishes.
var typeInfoMapInit = make(map[reflect.Type]*typeInfo, 16)
func lookupTypeInfo(rt reflect.Type) *typeInfo {
if m := typeInfoMapInit; m != nil {
return m[rt]
}
m, _ := typeInfoMap.Load().(map[reflect.Type]*typeInfo)
return m[rt]
}
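// Illustrative sketch (not part of the original source): the write side of
// the copy-on-write pattern in miniature, the same shape buildTypeInfo uses
// below; readers meanwhile issue a single atomic load and never block.
func storeTypeInfoSketch(rt reflect.Type, info *typeInfo) {
m, _ := typeInfoMap.Load().(map[reflect.Type]*typeInfo)
newm := make(map[reflect.Type]*typeInfo, len(m)+1)
for k, v := range m {
newm[k] = v // copy the old entries
}
newm[rt] = info // add the new one
typeInfoMap.Store(newm) // republish atomically
}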
func getTypeInfo(ut *userTypeInfo) (*typeInfo, error) {
rt := ut.base
if ut.externalEnc != 0 {
// We want the user type, not the base type.
rt = ut.user
}
if info := lookupTypeInfo(rt); info != nil {
return info, nil
}
return buildTypeInfo(ut, rt)
}
// buildTypeInfo constructs the type information for the type
// and stores it in the type info map.
func buildTypeInfo(ut *userTypeInfo, rt reflect.Type) (*typeInfo, error) {
typeLock.Lock()
defer typeLock.Unlock()
if info := lookupTypeInfo(rt); info != nil {
return info, nil
}
gt, err := getBaseType(rt.Name(), rt)
if err != nil {
return nil, err
}
info := &typeInfo{id: gt.id()}
if ut.externalEnc != 0 {
userType, err := getType(rt.Name(), ut, rt)
if err != nil {
return nil, err
}
gt := userType.id().gobType().(*gobEncoderType)
switch ut.externalEnc {
case xGob:
info.wire = &wireType{GobEncoderT: gt}
case xBinary:
info.wire = &wireType{BinaryMarshalerT: gt}
case xText:
info.wire = &wireType{TextMarshalerT: gt}
}
rt = ut.user
} else {
t := info.id.gobType()
switch typ := rt; typ.Kind() {
case reflect.Array:
info.wire = &wireType{ArrayT: t.(*arrayType)}
case reflect.Map:
info.wire = &wireType{MapT: t.(*mapType)}
case reflect.Slice:
// []byte == []uint8 is a special case handled separately
if typ.Elem().Kind() != reflect.Uint8 {
info.wire = &wireType{SliceT: t.(*sliceType)}
}
case reflect.Struct:
info.wire = &wireType{StructT: t.(*structType)}
}
}
if m := typeInfoMapInit; m != nil {
m[rt] = info
return info, nil
}
// Create new map with old contents plus new entry.
m, _ := typeInfoMap.Load().(map[reflect.Type]*typeInfo)
newm := make(map[reflect.Type]*typeInfo, len(m))
for k, v := range m {
newm[k] = v
}
newm[rt] = info
typeInfoMap.Store(newm)
return info, nil
}
// mustGetTypeInfo is called only when a panic is acceptable and an error is unexpected.
func mustGetTypeInfo(rt reflect.Type) *typeInfo {
t, err := getTypeInfo(userType(rt))
if err != nil {
panic("getTypeInfo: " + err.Error())
}
return t
}
// GobEncoder is the interface describing data that provides its own
// representation for encoding values for transmission to a GobDecoder.
// A type that implements GobEncoder and GobDecoder has complete
// control over the representation of its data and may therefore
// contain things such as private fields, channels, and functions,
// which are not usually transmissible in gob streams.
//
// Note: Since gobs can be stored permanently, it is good design
// to guarantee the encoding used by a GobEncoder is stable as the
// software evolves. For instance, it might make sense for GobEncode
// to include a version number in the encoding.
type GobEncoder interface {
// GobEncode returns a byte slice representing the encoding of the
// receiver for transmission to a GobDecoder, usually of the same
// concrete type.
GobEncode() ([]byte, error)
}
// GobDecoder is the interface describing data that provides its own
// routine for decoding transmitted values sent by a GobEncoder.
type GobDecoder interface {
// GobDecode overwrites the receiver, which must be a pointer,
// with the value represented by the byte slice, which was written
// by GobEncode, usually for the same concrete type.
GobDecode([]byte) error
}
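// Illustrative sketch (not part of the original source): a type that takes
// over its own encoding, following the stability advice above by tagging the
// encoding with a version prefix.
//
//	type Celsius float64
//
//	func (c Celsius) GobEncode() ([]byte, error) {
//		return []byte(fmt.Sprintf("v1:%g", float64(c))), nil
//	}
//
//	func (c *Celsius) GobDecode(b []byte) error {
//		var v float64
//		if _, err := fmt.Sscanf(string(b), "v1:%g", &v); err != nil {
//			return err
//		}
//		*c = Celsius(v)
//		return nil
//	}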
var (
nameToConcreteType sync.Map // map[string]reflect.Type
concreteTypeToName sync.Map // map[reflect.Type]string
)
// RegisterName is like Register but uses the provided name rather than the
// type's default.
func RegisterName(name string, value any) {
if name == "" {
// reserved for nil
panic("attempt to register empty name")
}
ut := userType(reflect.TypeOf(value))
// Check for incompatible duplicates. The name must refer to the
// same user type, and vice versa.
// Store the name and type provided by the user....
if t, dup := nameToConcreteType.LoadOrStore(name, reflect.TypeOf(value)); dup && t != ut.user {
panic(fmt.Sprintf("gob: registering duplicate types for %q: %s != %s", name, t, ut.user))
}
// but the flattened type in the type table, since that's what decode needs.
if n, dup := concreteTypeToName.LoadOrStore(ut.base, name); dup && n != name {
nameToConcreteType.Delete(name)
panic(fmt.Sprintf("gob: registering duplicate names for %s: %q != %q", ut.user, n, name))
}
}
// Register records a type, identified by a value for that type, under its
// internal type name. That name will identify the concrete type of a value
// sent or received as an interface variable. Only types that will be
// transferred as implementations of interface values need to be registered.
// Expecting to be used only during initialization, it panics if the mapping
// between types and names is not a bijection.
func Register(value any) {
// Default to printed representation for unnamed types
rt := reflect.TypeOf(value)
name := rt.String()
// But for named types (or pointers to them), qualify with import path (but see inner comment).
// Dereference one pointer looking for a named type.
star := ""
if rt.Name() == "" {
if pt := rt; pt.Kind() == reflect.Pointer {
star = "*"
// NOTE: The following line should be rt = pt.Elem() to implement
// what the comment above claims, but fixing it would break compatibility
// with existing gobs.
//
// Given package p imported as "full/p" with these definitions:
// package p
// type T1 struct { ... }
// this table shows the intended and actual strings used by gob to
// name the types:
//
// Type Correct string Actual string
//
// T1 full/p.T1 full/p.T1
// *T1 *full/p.T1 *p.T1
//
// The missing full path cannot be fixed without breaking existing gob decoders.
rt = pt
}
}
if rt.Name() != "" {
if rt.PkgPath() == "" {
name = star + rt.Name()
} else {
name = star + rt.PkgPath() + "." + rt.Name()
}
}
RegisterName(name, value)
}
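// A minimal usage sketch (illustrative; Shape and Circle are hypothetical
// names, and the math package is assumed): concrete types transmitted
// through interface-typed values must be registered, typically in an init
// function, on both the sending and the receiving side:
//
// type Shape interface{ Area() float64 }
// type Circle struct{ R float64 }
//
// func (c Circle) Area() float64 { return math.Pi * c.R * c.R }
//
// func init() {
//   Register(Circle{}) // registered under a name like "path/to/pkg.Circle"
// }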
func registerBasics() {
Register(int(0))
Register(int8(0))
Register(int16(0))
Register(int32(0))
Register(int64(0))
Register(uint(0))
Register(uint8(0))
Register(uint16(0))
Register(uint32(0))
Register(uint64(0))
Register(float32(0))
Register(float64(0))
Register(complex64(0i))
Register(complex128(0i))
Register(uintptr(0))
Register(false)
Register("")
Register([]byte(nil))
Register([]int(nil))
Register([]int8(nil))
Register([]int16(nil))
Register([]int32(nil))
Register([]int64(nil))
Register([]uint(nil))
Register([]uint8(nil))
Register([]uint16(nil))
Register([]uint32(nil))
Register([]uint64(nil))
Register([]float32(nil))
Register([]float64(nil))
Register([]complex64(nil))
Register([]complex128(nil))
Register([]uintptr(nil))
Register([]bool(nil))
Register([]string(nil))
}
func init() {
typeInfoMap.Store(typeInfoMapInit)
typeInfoMapInit = nil
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package hex implements hexadecimal encoding and decoding.
package hex
import (
"errors"
"fmt"
"io"
"strings"
)
const (
hextable = "0123456789abcdef"
reverseHexTable = "" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\xff\xff\xff\xff\xff\xff" +
"\xff\x0a\x0b\x0c\x0d\x0e\x0f\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\x0a\x0b\x0c\x0d\x0e\x0f\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
)
// EncodedLen returns the length of an encoding of n source bytes.
// Specifically, it returns n * 2.
func EncodedLen(n int) int { return n * 2 }
// Encode encodes src into EncodedLen(len(src))
// bytes of dst. As a convenience, it returns the number
// of bytes written to dst, but this value is always EncodedLen(len(src)).
// Encode implements hexadecimal encoding.
func Encode(dst, src []byte) int {
j := 0
for _, v := range src {
dst[j] = hextable[v>>4]
dst[j+1] = hextable[v&0x0f]
j += 2
}
return len(src) * 2
}
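// For example (an illustrative sketch, not part of the original source):
//
// src := []byte{0xde, 0xad, 0xbe, 0xef}
// dst := make([]byte, EncodedLen(len(src)))
// Encode(dst, src)
// // dst now holds []byte("deadbeef")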
// ErrLength reports an attempt to decode an odd-length input
// using Decode or DecodeString.
// The stream-based Decoder returns io.ErrUnexpectedEOF instead of ErrLength.
var ErrLength = errors.New("encoding/hex: odd length hex string")
// InvalidByteError values describe errors resulting from an invalid byte in a hex string.
type InvalidByteError byte
func (e InvalidByteError) Error() string {
return fmt.Sprintf("encoding/hex: invalid byte: %#U", rune(e))
}
// DecodedLen returns the length of a decoding of x source bytes.
// Specifically, it returns x / 2.
func DecodedLen(x int) int { return x / 2 }
// Decode decodes src into DecodedLen(len(src)) bytes,
// returning the actual number of bytes written to dst.
//
// Decode expects that src contains only hexadecimal
// characters and that src has even length.
// If the input is malformed, Decode returns the number
// of bytes decoded before the error.
func Decode(dst, src []byte) (int, error) {
i, j := 0, 1
for ; j < len(src); j += 2 {
p := src[j-1]
q := src[j]
a := reverseHexTable[p]
b := reverseHexTable[q]
if a > 0x0f {
return i, InvalidByteError(p)
}
if b > 0x0f {
return i, InvalidByteError(q)
}
dst[i] = (a << 4) | b
i++
}
if len(src)%2 == 1 {
// Check for invalid char before reporting bad length,
// since the invalid char (if present) is an earlier problem.
if reverseHexTable[src[j-1]] > 0x0f {
return i, InvalidByteError(src[j-1])
}
return i, ErrLength
}
return i, nil
}
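// For example (illustrative):
//
// src := []byte("48656c6c6f")
// dst := make([]byte, DecodedLen(len(src)))
// n, err := Decode(dst, src)
// // n == 5, err == nil, and dst[:n] holds []byte("Hello")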
// EncodeToString returns the hexadecimal encoding of src.
func EncodeToString(src []byte) string {
dst := make([]byte, EncodedLen(len(src)))
Encode(dst, src)
return string(dst)
}
// DecodeString returns the bytes represented by the hexadecimal string s.
//
// DecodeString expects that src contains only hexadecimal
// characters and that src has even length.
// If the input is malformed, DecodeString returns
// the bytes decoded before the error.
func DecodeString(s string) ([]byte, error) {
src := []byte(s)
// We can use the source slice itself as the destination
// because the decode loop writes only one output byte for every two input
// bytes consumed, so a 'seen' input byte is never needed again.
n, err := Decode(src, src)
return src[:n], err
}
// Dump returns a string that contains a hex dump of the given data. The format
// of the hex dump matches the output of `hexdump -C` on the command line.
func Dump(data []byte) string {
if len(data) == 0 {
return ""
}
var buf strings.Builder
// Dumper will write 79 bytes per complete 16 byte chunk, and at least
// 64 bytes for whatever remains. Round the allocation up, since only a
// maximum of 15 bytes will be wasted.
buf.Grow((1 + ((len(data) - 1) / 16)) * 79)
dumper := Dumper(&buf)
dumper.Write(data)
dumper.Close()
return buf.String()
}
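// For example (an illustrative sketch of the output format; fmt is assumed
// to be imported):
//
// fmt.Print(Dump([]byte("Go is an open source programming language.")))
//
// prints
//
// 00000000  47 6f 20 69 73 20 61 6e  20 6f 70 65 6e 20 73 6f  |Go is an open so|
// 00000010  75 72 63 65 20 70 72 6f  67 72 61 6d 6d 69 6e 67  |urce programming|
// 00000020  20 6c 61 6e 67 75 61 67  65 2e                    | language.|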
// bufferSize is the number of hexadecimal characters to buffer in encoder and decoder.
const bufferSize = 1024
type encoder struct {
w io.Writer
err error
out [bufferSize]byte // output buffer
}
// NewEncoder returns an io.Writer that writes lowercase hexadecimal characters to w.
func NewEncoder(w io.Writer) io.Writer {
return &encoder{w: w}
}
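// For example (illustrative):
//
// var sb strings.Builder
// enc := NewEncoder(&sb)
// enc.Write([]byte("Go"))
// // sb.String() == "476f"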
func (e *encoder) Write(p []byte) (n int, err error) {
for len(p) > 0 && e.err == nil {
chunkSize := bufferSize / 2
if len(p) < chunkSize {
chunkSize = len(p)
}
var written int
encoded := Encode(e.out[:], p[:chunkSize])
written, e.err = e.w.Write(e.out[:encoded])
n += written / 2
p = p[chunkSize:]
}
return n, e.err
}
type decoder struct {
r io.Reader
err error
in []byte // input buffer (encoded form)
arr [bufferSize]byte // backing array for in
}
// NewDecoder returns an io.Reader that decodes hexadecimal characters from r.
// NewDecoder expects that r contain only an even number of hexadecimal characters.
func NewDecoder(r io.Reader) io.Reader {
return &decoder{r: r}
}
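// For example (an illustrative sketch; the os package is assumed):
//
// r := strings.NewReader("48656c6c6f20476f7068657221")
// io.Copy(os.Stdout, NewDecoder(r))
// // prints "Hello Gopher!"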
func (d *decoder) Read(p []byte) (n int, err error) {
// Fill internal buffer with sufficient bytes to decode
if len(d.in) < 2 && d.err == nil {
var numCopy, numRead int
numCopy = copy(d.arr[:], d.in) // Copies either 0 or 1 bytes
numRead, d.err = d.r.Read(d.arr[numCopy:])
d.in = d.arr[:numCopy+numRead]
if d.err == io.EOF && len(d.in)%2 != 0 {
if a := reverseHexTable[d.in[len(d.in)-1]]; a > 0x0f {
d.err = InvalidByteError(d.in[len(d.in)-1])
} else {
d.err = io.ErrUnexpectedEOF
}
}
}
// Decode internal buffer into output buffer
if numAvail := len(d.in) / 2; len(p) > numAvail {
p = p[:numAvail]
}
numDec, err := Decode(p, d.in[:len(p)*2])
d.in = d.in[2*numDec:]
if err != nil {
d.in, d.err = nil, err // Decode error; discard input remainder
}
if len(d.in) < 2 {
return numDec, d.err // Only expose errors when buffer fully consumed
}
return numDec, nil
}
// Dumper returns a WriteCloser that writes a hex dump of all written data to
// w. The format of the dump matches the output of `hexdump -C` on the command
// line.
func Dumper(w io.Writer) io.WriteCloser {
return &dumper{w: w}
}
type dumper struct {
w io.Writer
rightChars [18]byte
buf [14]byte
used int // number of bytes in the current line
n uint // number of bytes, total
closed bool
}
func toChar(b byte) byte {
if b < 32 || b > 126 {
return '.'
}
return b
}
func (h *dumper) Write(data []byte) (n int, err error) {
if h.closed {
return 0, errors.New("encoding/hex: dumper closed")
}
// Output lines look like:
// 00000010 2e 2f 30 31 32 33 34 35 36 37 38 39 3a 3b 3c 3d |./0123456789:;<=|
// ^ offset ^ extra space ^ ASCII of line.
for i := range data {
if h.used == 0 {
// At the beginning of a line we print the current
// offset in hex.
h.buf[0] = byte(h.n >> 24)
h.buf[1] = byte(h.n >> 16)
h.buf[2] = byte(h.n >> 8)
h.buf[3] = byte(h.n)
Encode(h.buf[4:], h.buf[:4])
h.buf[12] = ' '
h.buf[13] = ' '
_, err = h.w.Write(h.buf[4:])
if err != nil {
return
}
}
Encode(h.buf[:], data[i:i+1])
h.buf[2] = ' '
l := 3
if h.used == 7 {
// There's an additional space after the 8th byte.
h.buf[3] = ' '
l = 4
} else if h.used == 15 {
// At the end of the line there's an extra space and
// the bar for the right column.
h.buf[3] = ' '
h.buf[4] = '|'
l = 5
}
_, err = h.w.Write(h.buf[:l])
if err != nil {
return
}
n++
h.rightChars[h.used] = toChar(data[i])
h.used++
h.n++
if h.used == 16 {
h.rightChars[16] = '|'
h.rightChars[17] = '\n'
_, err = h.w.Write(h.rightChars[:])
if err != nil {
return
}
h.used = 0
}
}
return
}
func (h *dumper) Close() (err error) {
// See the comments in Write() for the details of this format.
if h.closed {
return
}
h.closed = true
if h.used == 0 {
return
}
h.buf[0] = ' '
h.buf[1] = ' '
h.buf[2] = ' '
h.buf[3] = ' '
h.buf[4] = '|'
nBytes := h.used
for h.used < 16 {
l := 3
if h.used == 7 {
l = 4
} else if h.used == 15 {
l = 5
}
_, err = h.w.Write(h.buf[:l])
if err != nil {
return
}
h.used++
}
h.rightChars[nBytes] = '|'
h.rightChars[nBytes+1] = '\n'
_, err = h.w.Write(h.rightChars[:nBytes+2])
return
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Represents JSON data structure using native Go types: booleans, floats,
// strings, arrays, and maps.
package json
import (
"encoding"
"encoding/base64"
"fmt"
"reflect"
"strconv"
"strings"
"unicode"
"unicode/utf16"
"unicode/utf8"
)
// Unmarshal parses the JSON-encoded data and stores the result
// in the value pointed to by v. If v is nil or not a pointer,
// Unmarshal returns an InvalidUnmarshalError.
//
// Unmarshal uses the inverse of the encodings that
// Marshal uses, allocating maps, slices, and pointers as necessary,
// with the following additional rules:
//
// To unmarshal JSON into a pointer, Unmarshal first handles the case of
// the JSON being the JSON literal null. In that case, Unmarshal sets
// the pointer to nil. Otherwise, Unmarshal unmarshals the JSON into
// the value pointed at by the pointer. If the pointer is nil, Unmarshal
// allocates a new value for it to point to.
//
// To unmarshal JSON into a value implementing the Unmarshaler interface,
// Unmarshal calls that value's UnmarshalJSON method, including
// when the input is a JSON null.
// Otherwise, if the value implements encoding.TextUnmarshaler
// and the input is a JSON quoted string, Unmarshal calls that value's
// UnmarshalText method with the unquoted form of the string.
//
// To unmarshal JSON into a struct, Unmarshal matches incoming object
// keys to the keys used by Marshal (either the struct field name or its tag),
// preferring an exact match but also accepting a case-insensitive match. By
// default, object keys which don't have a corresponding struct field are
// ignored (see Decoder.DisallowUnknownFields for an alternative).
//
// To unmarshal JSON into an interface value,
// Unmarshal stores one of these in the interface value:
//
// bool, for JSON booleans
// float64, for JSON numbers
// string, for JSON strings
// []interface{}, for JSON arrays
// map[string]interface{}, for JSON objects
// nil for JSON null
//
// To unmarshal a JSON array into a slice, Unmarshal resets the slice length
// to zero and then appends each element to the slice.
// As a special case, to unmarshal an empty JSON array into a slice,
// Unmarshal replaces the slice with a new empty slice.
//
// To unmarshal a JSON array into a Go array, Unmarshal decodes
// JSON array elements into corresponding Go array elements.
// If the Go array is smaller than the JSON array,
// the additional JSON array elements are discarded.
// If the JSON array is smaller than the Go array,
// the additional Go array elements are set to zero values.
//
// To unmarshal a JSON object into a map, Unmarshal first establishes a map to
// use. If the map is nil, Unmarshal allocates a new map. Otherwise Unmarshal
// reuses the existing map, keeping existing entries. Unmarshal then stores
// key-value pairs from the JSON object into the map. The map's key type must
// either be any string type or any integer type, or must implement
// json.Unmarshaler or encoding.TextUnmarshaler.
//
// If the JSON-encoded data contain a syntax error, Unmarshal returns a SyntaxError.
//
// If a JSON value is not appropriate for a given target type,
// or if a JSON number overflows the target type, Unmarshal
// skips that field and completes the unmarshaling as best it can.
// If no more serious errors are encountered, Unmarshal returns
// an UnmarshalTypeError describing the earliest such error. In any
// case, it's not guaranteed that all the remaining fields following
// the problematic one will be unmarshaled into the target object.
//
// The JSON null value unmarshals into an interface, map, pointer, or slice
// by setting that Go value to nil. Because null is often used in JSON to mean
// “not present,” unmarshaling a JSON null into any other Go type has no effect
// on the value and produces no error.
//
// When unmarshaling quoted strings, invalid UTF-8 or
// invalid UTF-16 surrogate pairs are not treated as an error.
// Instead, they are replaced by the Unicode replacement
// character U+FFFD.
func Unmarshal(data []byte, v any) error {
// Check for well-formedness.
// Avoids filling out half a data structure
// before discovering a JSON syntax error.
var d decodeState
err := checkValid(data, &d.scan)
if err != nil {
return err
}
d.init(data)
return d.unmarshal(v)
}
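// A minimal usage sketch (illustrative; Animal is a hypothetical type)
// showing the case-insensitive key matching described above:
//
// type Animal struct {
//   Name  string
//   Order string
// }
//
// var a Animal
// err := Unmarshal([]byte(`{"name":"Platypus","order":"Monotremata"}`), &a)
// // err == nil; a.Name == "Platypus" and a.Order == "Monotremata"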
// Unmarshaler is the interface implemented by types
// that can unmarshal a JSON description of themselves.
// The input can be assumed to be a valid encoding of
// a JSON value. UnmarshalJSON must copy the JSON data
// if it wishes to retain the data after returning.
//
// By convention, to approximate the behavior of Unmarshal itself,
// Unmarshalers implement UnmarshalJSON([]byte("null")) as a no-op.
type Unmarshaler interface {
UnmarshalJSON([]byte) error
}
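// An illustrative sketch of the null-as-no-op convention (Celsius is a
// hypothetical type):
//
// type Celsius float64
//
// func (c *Celsius) UnmarshalJSON(b []byte) error {
//   if string(b) == "null" {
//     return nil // by convention, leave the value unchanged
//   }
//   return Unmarshal(b, (*float64)(c))
// }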
// An UnmarshalTypeError describes a JSON value that was
// not appropriate for a value of a specific Go type.
type UnmarshalTypeError struct {
Value string // description of JSON value - "bool", "array", "number -5"
Type reflect.Type // type of Go value it could not be assigned to
Offset int64 // error occurred after reading Offset bytes
Struct string // name of the struct type containing the field
Field string // the full path from root node to the field
}
func (e *UnmarshalTypeError) Error() string {
if e.Struct != "" || e.Field != "" {
return "json: cannot unmarshal " + e.Value + " into Go struct field " + e.Struct + "." + e.Field + " of type " + e.Type.String()
}
return "json: cannot unmarshal " + e.Value + " into Go value of type " + e.Type.String()
}
// An UnmarshalFieldError describes a JSON object key that
// led to an unexported (and therefore unwritable) struct field.
//
// Deprecated: No longer used; kept for compatibility.
type UnmarshalFieldError struct {
Key string
Type reflect.Type
Field reflect.StructField
}
func (e *UnmarshalFieldError) Error() string {
return "json: cannot unmarshal object key " + strconv.Quote(e.Key) + " into unexported field " + e.Field.Name + " of type " + e.Type.String()
}
// An InvalidUnmarshalError describes an invalid argument passed to Unmarshal.
// (The argument to Unmarshal must be a non-nil pointer.)
type InvalidUnmarshalError struct {
Type reflect.Type
}
func (e *InvalidUnmarshalError) Error() string {
if e.Type == nil {
return "json: Unmarshal(nil)"
}
if e.Type.Kind() != reflect.Pointer {
return "json: Unmarshal(non-pointer " + e.Type.String() + ")"
}
return "json: Unmarshal(nil " + e.Type.String() + ")"
}
func (d *decodeState) unmarshal(v any) error {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Pointer || rv.IsNil() {
return &InvalidUnmarshalError{reflect.TypeOf(v)}
}
d.scan.reset()
d.scanWhile(scanSkipSpace)
// We decode rv not rv.Elem because the Unmarshaler interface
// test must be applied at the top level of the value.
err := d.value(rv)
if err != nil {
return d.addErrorContext(err)
}
return d.savedError
}
// A Number represents a JSON number literal.
type Number string
// String returns the literal text of the number.
func (n Number) String() string { return string(n) }
// Float64 returns the number as a float64.
func (n Number) Float64() (float64, error) {
return strconv.ParseFloat(string(n), 64)
}
// Int64 returns the number as an int64.
func (n Number) Int64() (int64, error) {
return strconv.ParseInt(string(n), 10, 64)
}
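// An illustrative sketch of why Number matters (Decoder and its UseNumber
// method are defined elsewhere in this package): integers beyond 2^53
// survive decoding without float64 rounding:
//
// var v any
// dec := NewDecoder(strings.NewReader(`{"id":9007199254740993}`))
// dec.UseNumber()
// dec.Decode(&v)
// id, _ := v.(map[string]any)["id"].(Number).Int64()
// // id == 9007199254740993; decoding to float64 would round this value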
// An errorContext provides context for type errors during decoding.
type errorContext struct {
Struct reflect.Type
FieldStack []string
}
// decodeState represents the state while decoding a JSON value.
type decodeState struct {
data []byte
off int // next read offset in data
opcode int // last read result
scan scanner
errorContext *errorContext
savedError error
useNumber bool
disallowUnknownFields bool
}
// readIndex returns the position of the last byte read.
func (d *decodeState) readIndex() int {
return d.off - 1
}
// phasePanicMsg is used as a panic message when we end up with something that
// shouldn't happen. It can indicate a bug in the JSON decoder, or that
// something is editing the data slice while the decoder executes.
const phasePanicMsg = "JSON decoder out of sync - data changing underfoot?"
func (d *decodeState) init(data []byte) *decodeState {
d.data = data
d.off = 0
d.savedError = nil
if d.errorContext != nil {
d.errorContext.Struct = nil
// Reuse the allocated space for the FieldStack slice.
d.errorContext.FieldStack = d.errorContext.FieldStack[:0]
}
return d
}
// saveError saves the first err it is called with,
// for reporting at the end of the unmarshal.
func (d *decodeState) saveError(err error) {
if d.savedError == nil {
d.savedError = d.addErrorContext(err)
}
}
// addErrorContext returns a new error enhanced with information from d.errorContext
func (d *decodeState) addErrorContext(err error) error {
if d.errorContext != nil && (d.errorContext.Struct != nil || len(d.errorContext.FieldStack) > 0) {
switch err := err.(type) {
case *UnmarshalTypeError:
err.Struct = d.errorContext.Struct.Name()
err.Field = strings.Join(d.errorContext.FieldStack, ".")
}
}
return err
}
// skip scans to the end of what was started.
func (d *decodeState) skip() {
s, data, i := &d.scan, d.data, d.off
depth := len(s.parseState)
for {
op := s.step(s, data[i])
i++
if len(s.parseState) < depth {
d.off = i
d.opcode = op
return
}
}
}
// scanNext processes the byte at d.data[d.off].
func (d *decodeState) scanNext() {
if d.off < len(d.data) {
d.opcode = d.scan.step(&d.scan, d.data[d.off])
d.off++
} else {
d.opcode = d.scan.eof()
d.off = len(d.data) + 1 // mark processed EOF with len+1
}
}
// scanWhile processes bytes in d.data[d.off:] until it
// receives a scan code not equal to op.
func (d *decodeState) scanWhile(op int) {
s, data, i := &d.scan, d.data, d.off
for i < len(data) {
newOp := s.step(s, data[i])
i++
if newOp != op {
d.opcode = newOp
d.off = i
return
}
}
d.off = len(data) + 1 // mark processed EOF with len+1
d.opcode = d.scan.eof()
}
// rescanLiteral is similar to scanWhile(scanContinue), but it specialises the
// common case where we're decoding a literal. The decoder scans the input
// twice: once to check for syntax errors and determine the length of the
// value, and a second time to perform the decoding.
//
// Only in the second step do we use decodeState to tokenize literals, so we
// know there aren't any syntax errors. We can take advantage of that knowledge,
// and scan a literal's bytes much more quickly.
func (d *decodeState) rescanLiteral() {
data, i := d.data, d.off
Switch:
switch data[i-1] {
case '"': // string
for ; i < len(data); i++ {
switch data[i] {
case '\\':
i++ // escaped char
case '"':
i++ // tokenize the closing quote too
break Switch
}
}
case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-': // number
for ; i < len(data); i++ {
switch data[i] {
case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
'.', 'e', 'E', '+', '-':
default:
break Switch
}
}
case 't': // true
i += len("rue")
case 'f': // false
i += len("alse")
case 'n': // null
i += len("ull")
}
if i < len(data) {
d.opcode = stateEndValue(&d.scan, data[i])
} else {
d.opcode = scanEnd
}
d.off = i + 1
}
// value consumes a JSON value from d.data[d.off-1:], decoding into v, and
// reads the following byte ahead. If v is invalid, the value is discarded.
// The first byte of the value has been read already.
func (d *decodeState) value(v reflect.Value) error {
switch d.opcode {
default:
panic(phasePanicMsg)
case scanBeginArray:
if v.IsValid() {
if err := d.array(v); err != nil {
return err
}
} else {
d.skip()
}
d.scanNext()
case scanBeginObject:
if v.IsValid() {
if err := d.object(v); err != nil {
return err
}
} else {
d.skip()
}
d.scanNext()
case scanBeginLiteral:
// All bytes inside literal return scanContinue op code.
start := d.readIndex()
d.rescanLiteral()
if v.IsValid() {
if err := d.literalStore(d.data[start:d.readIndex()], v, false); err != nil {
return err
}
}
}
return nil
}
type unquotedValue struct{}
// valueQuoted is like value but decodes a
// quoted string literal or literal null into an interface value.
// If it finds anything other than a quoted string literal or null,
// valueQuoted returns unquotedValue{}.
func (d *decodeState) valueQuoted() any {
switch d.opcode {
default:
panic(phasePanicMsg)
case scanBeginArray, scanBeginObject:
d.skip()
d.scanNext()
case scanBeginLiteral:
v := d.literalInterface()
switch v.(type) {
case nil, string:
return v
}
}
return unquotedValue{}
}
// indirect walks down v allocating pointers as needed,
// until it gets to a non-pointer.
// If it encounters an Unmarshaler, indirect stops and returns that.
// If decodingNull is true, indirect stops at the first settable pointer so it
// can be set to nil.
func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) {
// Issue #24153 indicates that it is generally not a guaranteed property
// that you may round-trip a reflect.Value by calling Value.Addr().Elem()
// and expect the value to still be settable for values derived from
// unexported embedded struct fields.
//
// The logic below effectively does this when it first addresses the value
// (to satisfy possible pointer methods) and continues to dereference
// subsequent pointers as necessary.
//
// After the first round-trip, we set v back to the original value to
// preserve the original RW flags contained in reflect.Value.
v0 := v
haveAddr := false
// If v is a named type and is addressable,
// start with its address, so that if the type has pointer methods,
// we find them.
if v.Kind() != reflect.Pointer && v.Type().Name() != "" && v.CanAddr() {
haveAddr = true
v = v.Addr()
}
for {
// Load value from interface, but only if the result will be
// usefully addressable.
if v.Kind() == reflect.Interface && !v.IsNil() {
e := v.Elem()
if e.Kind() == reflect.Pointer && !e.IsNil() && (!decodingNull || e.Elem().Kind() == reflect.Pointer) {
haveAddr = false
v = e
continue
}
}
if v.Kind() != reflect.Pointer {
break
}
if decodingNull && v.CanSet() {
break
}
// Prevent infinite loop if v is an interface pointing to its own address:
// var v interface{}
// v = &v
if v.Elem().Kind() == reflect.Interface && v.Elem().Elem() == v {
v = v.Elem()
break
}
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
if v.Type().NumMethod() > 0 && v.CanInterface() {
if u, ok := v.Interface().(Unmarshaler); ok {
return u, nil, reflect.Value{}
}
if !decodingNull {
if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
return nil, u, reflect.Value{}
}
}
}
if haveAddr {
v = v0 // restore original value after round-trip Value.Addr().Elem()
haveAddr = false
} else {
v = v.Elem()
}
}
return nil, nil, v
}
// array consumes an array from d.data[d.off-1:], decoding into v.
// The first byte of the array ('[') has been read already.
func (d *decodeState) array(v reflect.Value) error {
// Check for unmarshaler.
u, ut, pv := indirect(v, false)
if u != nil {
start := d.readIndex()
d.skip()
return u.UnmarshalJSON(d.data[start:d.off])
}
if ut != nil {
d.saveError(&UnmarshalTypeError{Value: "array", Type: v.Type(), Offset: int64(d.off)})
d.skip()
return nil
}
v = pv
// Check type of target.
switch v.Kind() {
case reflect.Interface:
if v.NumMethod() == 0 {
// Decoding into nil interface? Switch to non-reflect code.
ai := d.arrayInterface()
v.Set(reflect.ValueOf(ai))
return nil
}
// Otherwise it's invalid.
fallthrough
default:
d.saveError(&UnmarshalTypeError{Value: "array", Type: v.Type(), Offset: int64(d.off)})
d.skip()
return nil
case reflect.Array, reflect.Slice:
break
}
i := 0
for {
// Look ahead for ] - can only happen on first iteration.
d.scanWhile(scanSkipSpace)
if d.opcode == scanEndArray {
break
}
// Expand slice length, growing the slice if necessary.
if v.Kind() == reflect.Slice {
if i >= v.Cap() {
v.Grow(1)
}
if i >= v.Len() {
v.SetLen(i + 1)
}
}
if i < v.Len() {
// Decode into element.
if err := d.value(v.Index(i)); err != nil {
return err
}
} else {
// Ran out of fixed array: skip.
if err := d.value(reflect.Value{}); err != nil {
return err
}
}
i++
// Next token must be , or ].
if d.opcode == scanSkipSpace {
d.scanWhile(scanSkipSpace)
}
if d.opcode == scanEndArray {
break
}
if d.opcode != scanArrayValue {
panic(phasePanicMsg)
}
}
if i < v.Len() {
if v.Kind() == reflect.Array {
for ; i < v.Len(); i++ {
v.Index(i).SetZero() // zero remainder of array
}
} else {
v.SetLen(i) // truncate the slice
}
}
if i == 0 && v.Kind() == reflect.Slice {
v.Set(reflect.MakeSlice(v.Type(), 0, 0))
}
return nil
}
var nullLiteral = []byte("null")
var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
// object consumes an object from d.data[d.off-1:], decoding into v.
// The first byte ('{') of the object has been read already.
func (d *decodeState) object(v reflect.Value) error {
// Check for unmarshaler.
u, ut, pv := indirect(v, false)
if u != nil {
start := d.readIndex()
d.skip()
return u.UnmarshalJSON(d.data[start:d.off])
}
if ut != nil {
d.saveError(&UnmarshalTypeError{Value: "object", Type: v.Type(), Offset: int64(d.off)})
d.skip()
return nil
}
v = pv
t := v.Type()
// Decoding into nil interface? Switch to non-reflect code.
if v.Kind() == reflect.Interface && v.NumMethod() == 0 {
oi := d.objectInterface()
v.Set(reflect.ValueOf(oi))
return nil
}
var fields structFields
// Check type of target:
// struct or
// map[T1]T2 where T1 is string, an integer type,
// or an encoding.TextUnmarshaler
switch v.Kind() {
case reflect.Map:
// Map key must either have string kind, have an integer kind,
// or be an encoding.TextUnmarshaler.
switch t.Key().Kind() {
case reflect.String,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
default:
if !reflect.PointerTo(t.Key()).Implements(textUnmarshalerType) {
d.saveError(&UnmarshalTypeError{Value: "object", Type: t, Offset: int64(d.off)})
d.skip()
return nil
}
}
if v.IsNil() {
v.Set(reflect.MakeMap(t))
}
case reflect.Struct:
fields = cachedTypeFields(t)
// ok
default:
d.saveError(&UnmarshalTypeError{Value: "object", Type: t, Offset: int64(d.off)})
d.skip()
return nil
}
var mapElem reflect.Value
var origErrorContext errorContext
if d.errorContext != nil {
origErrorContext = *d.errorContext
}
for {
// Read opening " of string key or closing }.
d.scanWhile(scanSkipSpace)
if d.opcode == scanEndObject {
// closing } - can only happen on first iteration.
break
}
if d.opcode != scanBeginLiteral {
panic(phasePanicMsg)
}
// Read key.
start := d.readIndex()
d.rescanLiteral()
item := d.data[start:d.readIndex()]
key, ok := unquoteBytes(item)
if !ok {
panic(phasePanicMsg)
}
// Figure out field corresponding to key.
var subv reflect.Value
destring := false // whether the value is wrapped in a string to be decoded first
if v.Kind() == reflect.Map {
elemType := t.Elem()
if !mapElem.IsValid() {
mapElem = reflect.New(elemType).Elem()
} else {
mapElem.SetZero()
}
subv = mapElem
} else {
f := fields.byExactName[string(key)]
if f == nil {
f = fields.byFoldedName[string(foldName(key))]
}
if f != nil {
subv = v
destring = f.quoted
for _, i := range f.index {
if subv.Kind() == reflect.Pointer {
if subv.IsNil() {
// If a struct embeds a pointer to an unexported type,
// it is not possible to set a newly allocated value
// since the field is unexported.
//
// See https://golang.org/issue/21357
if !subv.CanSet() {
d.saveError(fmt.Errorf("json: cannot set embedded pointer to unexported struct: %v", subv.Type().Elem()))
// Invalidate subv to ensure d.value(subv) skips over
// the JSON value without assigning it to subv.
subv = reflect.Value{}
destring = false
break
}
subv.Set(reflect.New(subv.Type().Elem()))
}
subv = subv.Elem()
}
subv = subv.Field(i)
}
if d.errorContext == nil {
d.errorContext = new(errorContext)
}
d.errorContext.FieldStack = append(d.errorContext.FieldStack, f.name)
d.errorContext.Struct = t
} else if d.disallowUnknownFields {
d.saveError(fmt.Errorf("json: unknown field %q", key))
}
}
// Read : before value.
if d.opcode == scanSkipSpace {
d.scanWhile(scanSkipSpace)
}
if d.opcode != scanObjectKey {
panic(phasePanicMsg)
}
d.scanWhile(scanSkipSpace)
if destring {
switch qv := d.valueQuoted().(type) {
case nil:
if err := d.literalStore(nullLiteral, subv, false); err != nil {
return err
}
case string:
if err := d.literalStore([]byte(qv), subv, true); err != nil {
return err
}
default:
d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal unquoted value into %v", subv.Type()))
}
} else {
if err := d.value(subv); err != nil {
return err
}
}
// Write value back to map;
// if using struct, subv points into struct already.
if v.Kind() == reflect.Map {
kt := t.Key()
var kv reflect.Value
switch {
case reflect.PointerTo(kt).Implements(textUnmarshalerType):
kv = reflect.New(kt)
if err := d.literalStore(item, kv, true); err != nil {
return err
}
kv = kv.Elem()
case kt.Kind() == reflect.String:
kv = reflect.ValueOf(key).Convert(kt)
default:
switch kt.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
s := string(key)
n, err := strconv.ParseInt(s, 10, 64)
if err != nil || reflect.Zero(kt).OverflowInt(n) {
d.saveError(&UnmarshalTypeError{Value: "number " + s, Type: kt, Offset: int64(start + 1)})
break
}
kv = reflect.ValueOf(n).Convert(kt)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
s := string(key)
n, err := strconv.ParseUint(s, 10, 64)
if err != nil || reflect.Zero(kt).OverflowUint(n) {
d.saveError(&UnmarshalTypeError{Value: "number " + s, Type: kt, Offset: int64(start + 1)})
break
}
kv = reflect.ValueOf(n).Convert(kt)
default:
panic("json: Unexpected key type") // should never occur
}
}
if kv.IsValid() {
v.SetMapIndex(kv, subv)
}
}
// Next token must be , or }.
if d.opcode == scanSkipSpace {
d.scanWhile(scanSkipSpace)
}
if d.errorContext != nil {
// Reset errorContext to its original state.
// Keep the same underlying array for FieldStack, to reuse the
// space and avoid unnecessary allocs.
d.errorContext.FieldStack = d.errorContext.FieldStack[:len(origErrorContext.FieldStack)]
d.errorContext.Struct = origErrorContext.Struct
}
if d.opcode == scanEndObject {
break
}
if d.opcode != scanObjectValue {
panic(phasePanicMsg)
}
}
return nil
}
// convertNumber converts the number literal s to a float64 or a Number
// depending on the setting of d.useNumber.
func (d *decodeState) convertNumber(s string) (any, error) {
if d.useNumber {
return Number(s), nil
}
f, err := strconv.ParseFloat(s, 64)
if err != nil {
return nil, &UnmarshalTypeError{Value: "number " + s, Type: reflect.TypeOf(0.0), Offset: int64(d.off)}
}
return f, nil
}
var numberType = reflect.TypeOf(Number(""))
// literalStore decodes a literal stored in item into v.
//
// fromQuoted indicates whether this literal came from unwrapping a
// string from the ",string" struct tag option. This is used only to
// produce more helpful error messages.
func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool) error {
// Check for unmarshaler.
if len(item) == 0 {
// Empty string given.
d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
return nil
}
isNull := item[0] == 'n' // null
u, ut, pv := indirect(v, isNull)
if u != nil {
return u.UnmarshalJSON(item)
}
if ut != nil {
if item[0] != '"' {
if fromQuoted {
d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
return nil
}
val := "number"
switch item[0] {
case 'n':
val = "null"
case 't', 'f':
val = "bool"
}
d.saveError(&UnmarshalTypeError{Value: val, Type: v.Type(), Offset: int64(d.readIndex())})
return nil
}
s, ok := unquoteBytes(item)
if !ok {
if fromQuoted {
return fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())
}
panic(phasePanicMsg)
}
return ut.UnmarshalText(s)
}
v = pv
switch c := item[0]; c {
case 'n': // null
// The main parser checks that only true and false can reach here,
// but if this was a quoted string input, it could be anything.
if fromQuoted && string(item) != "null" {
d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
break
}
switch v.Kind() {
case reflect.Interface, reflect.Pointer, reflect.Map, reflect.Slice:
v.SetZero()
// otherwise, ignore null for primitives/string
}
case 't', 'f': // true, false
value := item[0] == 't'
// The main parser checks that only true and false can reach here,
// but if this was a quoted string input, it could be anything.
if fromQuoted && string(item) != "true" && string(item) != "false" {
d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
break
}
switch v.Kind() {
default:
if fromQuoted {
d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
} else {
d.saveError(&UnmarshalTypeError{Value: "bool", Type: v.Type(), Offset: int64(d.readIndex())})
}
case reflect.Bool:
v.SetBool(value)
case reflect.Interface:
if v.NumMethod() == 0 {
v.Set(reflect.ValueOf(value))
} else {
d.saveError(&UnmarshalTypeError{Value: "bool", Type: v.Type(), Offset: int64(d.readIndex())})
}
}
case '"': // string
s, ok := unquoteBytes(item)
if !ok {
if fromQuoted {
return fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())
}
panic(phasePanicMsg)
}
switch v.Kind() {
default:
d.saveError(&UnmarshalTypeError{Value: "string", Type: v.Type(), Offset: int64(d.readIndex())})
case reflect.Slice:
if v.Type().Elem().Kind() != reflect.Uint8 {
d.saveError(&UnmarshalTypeError{Value: "string", Type: v.Type(), Offset: int64(d.readIndex())})
break
}
b := make([]byte, base64.StdEncoding.DecodedLen(len(s)))
n, err := base64.StdEncoding.Decode(b, s)
if err != nil {
d.saveError(err)
break
}
v.SetBytes(b[:n])
case reflect.String:
if v.Type() == numberType && !isValidNumber(string(s)) {
return fmt.Errorf("json: invalid number literal, trying to unmarshal %q into Number", item)
}
v.SetString(string(s))
case reflect.Interface:
if v.NumMethod() == 0 {
v.Set(reflect.ValueOf(string(s)))
} else {
d.saveError(&UnmarshalTypeError{Value: "string", Type: v.Type(), Offset: int64(d.readIndex())})
}
}
default: // number
if c != '-' && (c < '0' || c > '9') {
if fromQuoted {
return fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())
}
panic(phasePanicMsg)
}
s := string(item)
switch v.Kind() {
default:
if v.Kind() == reflect.String && v.Type() == numberType {
// s must be a valid number, because it's
// already been tokenized.
v.SetString(s)
break
}
if fromQuoted {
return fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())
}
d.saveError(&UnmarshalTypeError{Value: "number", Type: v.Type(), Offset: int64(d.readIndex())})
case reflect.Interface:
n, err := d.convertNumber(s)
if err != nil {
d.saveError(err)
break
}
if v.NumMethod() != 0 {
d.saveError(&UnmarshalTypeError{Value: "number", Type: v.Type(), Offset: int64(d.readIndex())})
break
}
v.Set(reflect.ValueOf(n))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
n, err := strconv.ParseInt(s, 10, 64)
if err != nil || v.OverflowInt(n) {
d.saveError(&UnmarshalTypeError{Value: "number " + s, Type: v.Type(), Offset: int64(d.readIndex())})
break
}
v.SetInt(n)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
n, err := strconv.ParseUint(s, 10, 64)
if err != nil || v.OverflowUint(n) {
d.saveError(&UnmarshalTypeError{Value: "number " + s, Type: v.Type(), Offset: int64(d.readIndex())})
break
}
v.SetUint(n)
case reflect.Float32, reflect.Float64:
n, err := strconv.ParseFloat(s, v.Type().Bits())
if err != nil || v.OverflowFloat(n) {
d.saveError(&UnmarshalTypeError{Value: "number " + s, Type: v.Type(), Offset: int64(d.readIndex())})
break
}
v.SetFloat(n)
}
}
return nil
}
// The xxxInterface routines build up a value to be stored
// in an empty interface. They are not strictly necessary,
// but they avoid the weight of reflection in this common case.
// valueInterface is like value but returns interface{}.
func (d *decodeState) valueInterface() (val any) {
switch d.opcode {
default:
panic(phasePanicMsg)
case scanBeginArray:
val = d.arrayInterface()
d.scanNext()
case scanBeginObject:
val = d.objectInterface()
d.scanNext()
case scanBeginLiteral:
val = d.literalInterface()
}
return
}
// arrayInterface is like array but returns []interface{}.
func (d *decodeState) arrayInterface() []any {
var v = make([]any, 0)
for {
// Look ahead for ] - can only happen on first iteration.
d.scanWhile(scanSkipSpace)
if d.opcode == scanEndArray {
break
}
v = append(v, d.valueInterface())
// Next token must be , or ].
if d.opcode == scanSkipSpace {
d.scanWhile(scanSkipSpace)
}
if d.opcode == scanEndArray {
break
}
if d.opcode != scanArrayValue {
panic(phasePanicMsg)
}
}
return v
}
// objectInterface is like object but returns map[string]interface{}.
func (d *decodeState) objectInterface() map[string]any {
m := make(map[string]any)
for {
// Read opening " of string key or closing }.
d.scanWhile(scanSkipSpace)
if d.opcode == scanEndObject {
// closing } - can only happen on first iteration.
break
}
if d.opcode != scanBeginLiteral {
panic(phasePanicMsg)
}
// Read string key.
start := d.readIndex()
d.rescanLiteral()
item := d.data[start:d.readIndex()]
key, ok := unquote(item)
if !ok {
panic(phasePanicMsg)
}
// Read : before value.
if d.opcode == scanSkipSpace {
d.scanWhile(scanSkipSpace)
}
if d.opcode != scanObjectKey {
panic(phasePanicMsg)
}
d.scanWhile(scanSkipSpace)
// Read value.
m[key] = d.valueInterface()
// Next token must be , or }.
if d.opcode == scanSkipSpace {
d.scanWhile(scanSkipSpace)
}
if d.opcode == scanEndObject {
break
}
if d.opcode != scanObjectValue {
panic(phasePanicMsg)
}
}
return m
}
// literalInterface consumes and returns a literal from d.data[d.off-1:],
// reading the following byte ahead. The first byte of the literal has been
// read already (that's how the caller knows it's a literal).
func (d *decodeState) literalInterface() any {
// All bytes inside literal return scanContinue op code.
start := d.readIndex()
d.rescanLiteral()
item := d.data[start:d.readIndex()]
switch c := item[0]; c {
case 'n': // null
return nil
case 't', 'f': // true, false
return c == 't'
case '"': // string
s, ok := unquote(item)
if !ok {
panic(phasePanicMsg)
}
return s
default: // number
if c != '-' && (c < '0' || c > '9') {
panic(phasePanicMsg)
}
n, err := d.convertNumber(string(item))
if err != nil {
d.saveError(err)
}
return n
}
}
// getu4 decodes \uXXXX from the beginning of s, returning the hex value,
// or -1 if s does not begin with a valid \uXXXX escape.
func getu4(s []byte) rune {
if len(s) < 6 || s[0] != '\\' || s[1] != 'u' {
return -1
}
var r rune
for _, c := range s[2:6] {
switch {
case '0' <= c && c <= '9':
c = c - '0'
case 'a' <= c && c <= 'f':
c = c - 'a' + 10
case 'A' <= c && c <= 'F':
c = c - 'A' + 10
default:
return -1
}
r = r*16 + rune(c)
}
return r
}
// unquote converts a quoted JSON string literal s into an actual string t.
// The rules differ from Go string literals, so we cannot use strconv.Unquote.
func unquote(s []byte) (t string, ok bool) {
s, ok = unquoteBytes(s)
t = string(s)
return
}
func unquoteBytes(s []byte) (t []byte, ok bool) {
if len(s) < 2 || s[0] != '"' || s[len(s)-1] != '"' {
return
}
s = s[1 : len(s)-1]
// Check for unusual characters. If there are none,
// then no unquoting is needed, so return a slice of the
// original bytes.
r := 0
for r < len(s) {
c := s[r]
if c == '\\' || c == '"' || c < ' ' {
break
}
if c < utf8.RuneSelf {
r++
continue
}
rr, size := utf8.DecodeRune(s[r:])
if rr == utf8.RuneError && size == 1 {
break
}
r += size
}
if r == len(s) {
return s, true
}
b := make([]byte, len(s)+2*utf8.UTFMax)
w := copy(b, s[0:r])
for r < len(s) {
// Out of room? Can only happen if s is full of
// malformed UTF-8 and we're replacing each
// byte with RuneError.
if w >= len(b)-2*utf8.UTFMax {
nb := make([]byte, (len(b)+utf8.UTFMax)*2)
copy(nb, b[0:w])
b = nb
}
switch c := s[r]; {
case c == '\\':
r++
if r >= len(s) {
return
}
switch s[r] {
default:
return
case '"', '\\', '/', '\'':
b[w] = s[r]
r++
w++
case 'b':
b[w] = '\b'
r++
w++
case 'f':
b[w] = '\f'
r++
w++
case 'n':
b[w] = '\n'
r++
w++
case 'r':
b[w] = '\r'
r++
w++
case 't':
b[w] = '\t'
r++
w++
case 'u':
r--
rr := getu4(s[r:])
if rr < 0 {
return
}
r += 6
if utf16.IsSurrogate(rr) {
rr1 := getu4(s[r:])
if dec := utf16.DecodeRune(rr, rr1); dec != unicode.ReplacementChar {
// A valid pair; consume.
r += 6
w += utf8.EncodeRune(b[w:], dec)
break
}
// Invalid surrogate; fall back to replacement rune.
rr = unicode.ReplacementChar
}
w += utf8.EncodeRune(b[w:], rr)
}
// Quote and control characters are invalid.
case c == '"', c < ' ':
return
// ASCII
case c < utf8.RuneSelf:
b[w] = c
r++
w++
// Coerce to well-formed UTF-8.
default:
rr, size := utf8.DecodeRune(s[r:])
r += size
w += utf8.EncodeRune(b[w:], rr)
}
}
return b[0:w], true
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package json implements encoding and decoding of JSON as defined in
// RFC 7159. The mapping between JSON and Go values is described
// in the documentation for the Marshal and Unmarshal functions.
//
// See "JSON and Go" for an introduction to this package:
// https://golang.org/doc/articles/json_and_go.html
package json
import (
"bytes"
"encoding"
"encoding/base64"
"fmt"
"math"
"reflect"
"sort"
"strconv"
"strings"
"sync"
"unicode"
"unicode/utf8"
)
// Marshal returns the JSON encoding of v.
//
// Marshal traverses the value v recursively.
// If an encountered value implements the Marshaler interface
// and is not a nil pointer, Marshal calls its MarshalJSON method
// to produce JSON. If no MarshalJSON method is present but the
// value implements encoding.TextMarshaler instead, Marshal calls
// its MarshalText method and encodes the result as a JSON string.
// The nil pointer exception is not strictly necessary
// but mimics a similar, necessary exception in the behavior of
// UnmarshalJSON.
//
// Otherwise, Marshal uses the following type-dependent default encodings:
//
// Boolean values encode as JSON booleans.
//
// Floating point, integer, and Number values encode as JSON numbers.
//
// String values encode as JSON strings coerced to valid UTF-8,
// replacing invalid bytes with the Unicode replacement rune.
// So that the JSON will be safe to embed inside HTML <script> tags,
// the string is encoded using HTMLEscape,
// which replaces "<", ">", "&", U+2028, and U+2029
// with "\u003c", "\u003e", "\u0026", "\u2028", and "\u2029".
// This replacement can be disabled when using an Encoder,
// by calling SetEscapeHTML(false).
//
// Array and slice values encode as JSON arrays, except that
// []byte encodes as a base64-encoded string, and a nil slice
// encodes as the null JSON value.
//
// Struct values encode as JSON objects.
// Each exported struct field becomes a member of the object, using the
// field name as the object key, unless the field is omitted for one of the
// reasons given below.
//
// The encoding of each struct field can be customized by the format string
// stored under the "json" key in the struct field's tag.
// The format string gives the name of the field, possibly followed by a
// comma-separated list of options. The name may be empty in order to
// specify options without overriding the default field name.
//
// The "omitempty" option specifies that the field should be omitted
// from the encoding if the field has an empty value, defined as
// false, 0, a nil pointer, a nil interface value, and any empty array,
// slice, map, or string.
//
// As a special case, if the field tag is "-", the field is always omitted.
// Note that a field with name "-" can still be generated using the tag "-,".
//
// Examples of struct field tags and their meanings:
//
// // Field appears in JSON as key "myName".
// Field int `json:"myName"`
//
// // Field appears in JSON as key "myName" and
// // the field is omitted from the object if its value is empty,
// // as defined above.
// Field int `json:"myName,omitempty"`
//
// // Field appears in JSON as key "Field" (the default), but
// // the field is skipped if empty.
// // Note the leading comma.
// Field int `json:",omitempty"`
//
// // Field is ignored by this package.
// Field int `json:"-"`
//
// // Field appears in JSON as key "-".
// Field int `json:"-,"`
//
// The "string" option signals that a field is stored as JSON inside a
// JSON-encoded string. It applies only to fields of string, floating point,
// integer, or boolean types. This extra level of encoding is sometimes used
// when communicating with JavaScript programs:
//
// Int64String int64 `json:",string"`
//
// The key name will be used if it's a non-empty string consisting of
// only Unicode letters, digits, and ASCII punctuation except quotation
// marks, backslash, and comma.
//
// Anonymous struct fields are usually marshaled as if their inner exported fields
// were fields in the outer struct, subject to the usual Go visibility rules amended
// as described in the next paragraph.
// An anonymous struct field with a name given in its JSON tag is treated as
// having that name, rather than being anonymous.
// An anonymous struct field of interface type is treated the same as having
// that type as its name, rather than being anonymous.
//
// The Go visibility rules for struct fields are amended for JSON when
// deciding which field to marshal or unmarshal. If there are
// multiple fields at the same level, and that level is the least
// nested (and would therefore be the nesting level selected by the
// usual Go rules), the following extra rules apply:
//
// 1) Of those fields, if any are JSON-tagged, only tagged fields are considered,
// even if there are multiple untagged fields that would otherwise conflict.
//
// 2) If there is exactly one field (tagged or not according to the first rule), that is selected.
//
// 3) Otherwise there are multiple fields, and all are ignored; no error occurs.
//
// Handling of anonymous struct fields is new in Go 1.1.
// Prior to Go 1.1, anonymous struct fields were ignored. To force ignoring of
// an anonymous struct field in both current and earlier versions, give the field
// a JSON tag of "-".
//
// Map values encode as JSON objects. The map's key type must either be a
// string, an integer type, or implement encoding.TextMarshaler. The map keys
// are sorted and used as JSON object keys by applying the following rules,
// subject to the UTF-8 coercion described for string values above:
// - keys of any string type are used directly
// - encoding.TextMarshalers are marshaled
// - integer keys are converted to strings
//
// Pointer values encode as the value pointed to.
// A nil pointer encodes as the null JSON value.
//
// Interface values encode as the value contained in the interface.
// A nil interface value encodes as the null JSON value.
//
// Channel, complex, and function values cannot be encoded in JSON.
// Attempting to encode such a value causes Marshal to return
// an UnsupportedTypeError.
//
// JSON cannot represent cyclic data structures and Marshal does not
// handle them. Passing cyclic structures to Marshal will result in
// an error.
func Marshal(v any) ([]byte, error) {
e := newEncodeState()
defer encodeStatePool.Put(e)
err := e.marshal(v, encOpts{escapeHTML: true})
if err != nil {
return nil, err
}
buf := append([]byte(nil), e.Bytes()...)
return buf, nil
}
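// A minimal usage sketch (illustrative; ColorGroup is a hypothetical type):
//
// type ColorGroup struct {
//   ID     int
//   Name   string
//   Colors []string
// }
//
// b, err := Marshal(ColorGroup{ID: 1, Name: "Reds", Colors: []string{"Crimson", "Red"}})
// // err == nil
// // string(b) == `{"ID":1,"Name":"Reds","Colors":["Crimson","Red"]}`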
// MarshalIndent is like Marshal but applies Indent to format the output.
// Each JSON element in the output will begin on a new line beginning with prefix
// followed by one or more copies of indent according to the indentation nesting.
func MarshalIndent(v any, prefix, indent string) ([]byte, error) {
b, err := Marshal(v)
if err != nil {
return nil, err
}
b2 := make([]byte, 0, indentGrowthFactor*len(b))
b2, err = appendIndent(b2, b, prefix, indent)
if err != nil {
return nil, err
}
return b2, nil
}
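// For example (illustrative):
//
// b, _ := MarshalIndent(map[string]int{"a": 1}, "", "\t")
//
// yields
//
// {
//   "a": 1
// }
//
// where the inner line is indented by one tab.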
// Marshaler is the interface implemented by types that
// can marshal themselves into valid JSON.
type Marshaler interface {
MarshalJSON() ([]byte, error)
}
// An UnsupportedTypeError is returned by Marshal when attempting
// to encode an unsupported value type.
type UnsupportedTypeError struct {
Type reflect.Type
}
func (e *UnsupportedTypeError) Error() string {
return "json: unsupported type: " + e.Type.String()
}
// An UnsupportedValueError is returned by Marshal when attempting
// to encode an unsupported value.
type UnsupportedValueError struct {
Value reflect.Value
Str string
}
func (e *UnsupportedValueError) Error() string {
return "json: unsupported value: " + e.Str
}
// Before Go 1.2, an InvalidUTF8Error was returned by Marshal when
// attempting to encode a string value with invalid UTF-8 sequences.
// As of Go 1.2, Marshal instead coerces the string to valid UTF-8 by
// replacing invalid bytes with the Unicode replacement rune U+FFFD.
//
// Deprecated: No longer used; kept for compatibility.
type InvalidUTF8Error struct {
S string // the whole string value that caused the error
}
func (e *InvalidUTF8Error) Error() string {
return "json: invalid UTF-8 in string: " + strconv.Quote(e.S)
}
// A MarshalerError represents an error from calling a MarshalJSON or MarshalText method.
type MarshalerError struct {
Type reflect.Type
Err error
sourceFunc string
}
func (e *MarshalerError) Error() string {
srcFunc := e.sourceFunc
if srcFunc == "" {
srcFunc = "MarshalJSON"
}
return "json: error calling " + srcFunc +
" for type " + e.Type.String() +
": " + e.Err.Error()
}
// Unwrap returns the underlying error.
func (e *MarshalerError) Unwrap() error { return e.Err }
var hex = "0123456789abcdef"
// An encodeState encodes JSON into a bytes.Buffer.
type encodeState struct {
bytes.Buffer // accumulated output
// Keep track of what pointers we've seen in the current recursive call
// path, to avoid cycles that could lead to a stack overflow. Only do
// the relatively expensive map operations if ptrLevel is larger than
// startDetectingCyclesAfter, so that we skip the work for reasonably
// shallow levels of pointer nesting.
ptrLevel uint
ptrSeen map[any]struct{}
}
func (e *encodeState) AvailableBuffer() []byte {
return availableBuffer(&e.Buffer)
}
const startDetectingCyclesAfter = 1000
var encodeStatePool sync.Pool
func newEncodeState() *encodeState {
if v := encodeStatePool.Get(); v != nil {
e := v.(*encodeState)
e.Reset()
if len(e.ptrSeen) > 0 {
panic("ptrEncoder.encode should have emptied ptrSeen via defers")
}
e.ptrLevel = 0
return e
}
return &encodeState{ptrSeen: make(map[any]struct{})}
}
// jsonError is an error wrapper type for internal use only.
// Panics with errors are wrapped in jsonError so that the top-level recover
// can distinguish panics raised deliberately by this package from other panics.
type jsonError struct{ error }
func (e *encodeState) marshal(v any, opts encOpts) (err error) {
defer func() {
if r := recover(); r != nil {
if je, ok := r.(jsonError); ok {
err = je.error
} else {
panic(r)
}
}
}()
e.reflectValue(reflect.ValueOf(v), opts)
return nil
}
// error aborts the encoding by panicking with err wrapped in jsonError.
func (e *encodeState) error(err error) {
panic(jsonError{err})
}
func isEmptyValue(v reflect.Value) bool {
switch v.Kind() {
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
return v.Len() == 0
case reflect.Bool:
return !v.Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return v.Uint() == 0
case reflect.Float32, reflect.Float64:
return v.Float() == 0
case reflect.Interface, reflect.Pointer:
return v.IsNil()
}
return false
}
func (e *encodeState) reflectValue(v reflect.Value, opts encOpts) {
valueEncoder(v)(e, v, opts)
}
type encOpts struct {
// quoted causes primitive fields to be encoded inside JSON strings.
quoted bool
// escapeHTML causes '<', '>', and '&' to be escaped in JSON strings.
escapeHTML bool
}
type encoderFunc func(e *encodeState, v reflect.Value, opts encOpts)
var encoderCache sync.Map // map[reflect.Type]encoderFunc
func valueEncoder(v reflect.Value) encoderFunc {
if !v.IsValid() {
return invalidValueEncoder
}
return typeEncoder(v.Type())
}
func typeEncoder(t reflect.Type) encoderFunc {
if fi, ok := encoderCache.Load(t); ok {
return fi.(encoderFunc)
}
// To deal with recursive types, populate the map with an
// indirect func before we build it. This placeholder waits for the
// real func (f) to be ready and then calls it. This indirect
// func is only used for recursive types.
var (
wg sync.WaitGroup
f encoderFunc
)
wg.Add(1)
fi, loaded := encoderCache.LoadOrStore(t, encoderFunc(func(e *encodeState, v reflect.Value, opts encOpts) {
wg.Wait()
f(e, v, opts)
}))
if loaded {
return fi.(encoderFunc)
}
// Compute the real encoder and replace the indirect func with it.
f = newTypeEncoder(t, true)
wg.Done()
encoderCache.Store(t, f)
return f
}
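// For example (illustrative; Node is a hypothetical type): a recursive type
// such as
//
// type Node struct {
//   Next *Node
// }
//
// makes newTypeEncoder for Node request the encoder for *Node, which in turn
// requests Node again; that second request finds the placeholder already in
// the cache and returns immediately, breaking the cycle, while wg.Wait
// guarantees the placeholder is never run before the real encoder f has been
// assigned.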
var (
marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()
textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
)
// newTypeEncoder constructs an encoderFunc for a type.
// The returned encoder only checks CanAddr when allowAddr is true.
func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc {
// If we have a non-pointer value whose type implements
// Marshaler with a value receiver, then we're better off taking
// the address of the value - otherwise we end up with an
// allocation as we cast the value to an interface.
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(marshalerType) {
return newCondAddrEncoder(addrMarshalerEncoder, newTypeEncoder(t, false))
}
if t.Implements(marshalerType) {
return marshalerEncoder
}
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(textMarshalerType) {
return newCondAddrEncoder(addrTextMarshalerEncoder, newTypeEncoder(t, false))
}
if t.Implements(textMarshalerType) {
return textMarshalerEncoder
}
switch t.Kind() {
case reflect.Bool:
return boolEncoder
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return intEncoder
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return uintEncoder
case reflect.Float32:
return float32Encoder
case reflect.Float64:
return float64Encoder
case reflect.String:
return stringEncoder
case reflect.Interface:
return interfaceEncoder
case reflect.Struct:
return newStructEncoder(t)
case reflect.Map:
return newMapEncoder(t)
case reflect.Slice:
return newSliceEncoder(t)
case reflect.Array:
return newArrayEncoder(t)
case reflect.Pointer:
return newPtrEncoder(t)
default:
return unsupportedTypeEncoder
}
}
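// A sketch of why the addressability dance above matters: a Marshaler
// implemented on a pointer receiver is still used for a non-pointer value,
// provided that value is addressable during encoding:
//
//	type C struct{ v int }
//	func (c *C) MarshalJSON() ([]byte, error) { return []byte(`"custom"`), nil }
//
//	s := struct{ C C }{}
//	b, _ := json.Marshal(&s) // {"C":"custom"} (s.C is addressable through &s)
//	b, _ = json.Marshal(s)   // {"C":{}} (unaddressable copy, fallback encoder)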
func invalidValueEncoder(e *encodeState, v reflect.Value, _ encOpts) {
e.WriteString("null")
}
func marshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
if v.Kind() == reflect.Pointer && v.IsNil() {
e.WriteString("null")
return
}
m, ok := v.Interface().(Marshaler)
if !ok {
e.WriteString("null")
return
}
b, err := m.MarshalJSON()
if err == nil {
e.Grow(len(b))
out := availableBuffer(&e.Buffer)
out, err = appendCompact(out, b, opts.escapeHTML)
e.Buffer.Write(out)
}
if err != nil {
e.error(&MarshalerError{v.Type(), err, "MarshalJSON"})
}
}
func addrMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
va := v.Addr()
if va.IsNil() {
e.WriteString("null")
return
}
m := va.Interface().(Marshaler)
b, err := m.MarshalJSON()
if err == nil {
e.Grow(len(b))
out := availableBuffer(&e.Buffer)
out, err = appendCompact(out, b, opts.escapeHTML)
e.Buffer.Write(out)
}
if err != nil {
e.error(&MarshalerError{v.Type(), err, "MarshalJSON"})
}
}
func textMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
if v.Kind() == reflect.Pointer && v.IsNil() {
e.WriteString("null")
return
}
m, ok := v.Interface().(encoding.TextMarshaler)
if !ok {
e.WriteString("null")
return
}
b, err := m.MarshalText()
if err != nil {
e.error(&MarshalerError{v.Type(), err, "MarshalText"})
}
e.Write(appendString(e.AvailableBuffer(), b, opts.escapeHTML))
}
func addrTextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
va := v.Addr()
if va.IsNil() {
e.WriteString("null")
return
}
m := va.Interface().(encoding.TextMarshaler)
b, err := m.MarshalText()
if err != nil {
e.error(&MarshalerError{v.Type(), err, "MarshalText"})
}
e.Write(appendString(e.AvailableBuffer(), b, opts.escapeHTML))
}
func boolEncoder(e *encodeState, v reflect.Value, opts encOpts) {
b := e.AvailableBuffer()
b = mayAppendQuote(b, opts.quoted)
b = strconv.AppendBool(b, v.Bool())
b = mayAppendQuote(b, opts.quoted)
e.Write(b)
}
func intEncoder(e *encodeState, v reflect.Value, opts encOpts) {
b := e.AvailableBuffer()
b = mayAppendQuote(b, opts.quoted)
b = strconv.AppendInt(b, v.Int(), 10)
b = mayAppendQuote(b, opts.quoted)
e.Write(b)
}
func uintEncoder(e *encodeState, v reflect.Value, opts encOpts) {
b := e.AvailableBuffer()
b = mayAppendQuote(b, opts.quoted)
b = strconv.AppendUint(b, v.Uint(), 10)
b = mayAppendQuote(b, opts.quoted)
e.Write(b)
}
type floatEncoder int // number of bits
func (bits floatEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
f := v.Float()
if math.IsInf(f, 0) || math.IsNaN(f) {
e.error(&UnsupportedValueError{v, strconv.FormatFloat(f, 'g', -1, int(bits))})
}
// Convert as if by ES6 number to string conversion.
// This matches most other JSON generators.
// See golang.org/issue/6384 and golang.org/issue/14135.
// Like fmt %g, but the exponent cutoffs are different
// and exponents themselves are not padded to two digits.
b := e.AvailableBuffer()
b = mayAppendQuote(b, opts.quoted)
abs := math.Abs(f)
fmt := byte('f')
// Note: Must use float32 comparisons for underlying float32 value to get precise cutoffs right.
if abs != 0 {
if bits == 64 && (abs < 1e-6 || abs >= 1e21) || bits == 32 && (float32(abs) < 1e-6 || float32(abs) >= 1e21) {
fmt = 'e'
}
}
b = strconv.AppendFloat(b, f, fmt, -1, int(bits))
if fmt == 'e' {
// clean up e-09 to e-9
n := len(b)
if n >= 4 && b[n-4] == 'e' && b[n-3] == '-' && b[n-2] == '0' {
b[n-2] = b[n-1]
b = b[:n-1]
}
}
b = mayAppendQuote(b, opts.quoted)
e.Write(b)
}
var (
float32Encoder = (floatEncoder(32)).encode
float64Encoder = (floatEncoder(64)).encode
)
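// Observable effects of the formatting rules above, as seen through Marshal:
//
//	b, _ := json.Marshal(1e20)         // 100000000000000000000
//	b, _ = json.Marshal(1e21)          // 1e+21 (at the exponent cutoff)
//	b, _ = json.Marshal(1e-7)          // 1e-7 (e-07 cleaned up to e-7)
//	_, err := json.Marshal(math.NaN()) // *UnsupportedValueError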
func stringEncoder(e *encodeState, v reflect.Value, opts encOpts) {
if v.Type() == numberType {
numStr := v.String()
// In Go 1.5 the empty string encoded to "0". While this is not a
// valid number literal, we keep compatibility, so check validity
// after this.
if numStr == "" {
numStr = "0" // Number's zero-val
}
if !isValidNumber(numStr) {
e.error(fmt.Errorf("json: invalid number literal %q", numStr))
}
b := e.AvailableBuffer()
b = mayAppendQuote(b, opts.quoted)
b = append(b, numStr...)
b = mayAppendQuote(b, opts.quoted)
e.Write(b)
return
}
if opts.quoted {
b := appendString(nil, v.String(), opts.escapeHTML)
e.Write(appendString(e.AvailableBuffer(), b, false)) // no need to escape again since it is already escaped
} else {
e.Write(appendString(e.AvailableBuffer(), v.String(), opts.escapeHTML))
}
}
// isValidNumber reports whether s is a valid JSON number literal.
func isValidNumber(s string) bool {
// This function implements the JSON numbers grammar.
// See https://tools.ietf.org/html/rfc7159#section-6
// and https://www.json.org/img/number.png
if s == "" {
return false
}
// Optional -
if s[0] == '-' {
s = s[1:]
if s == "" {
return false
}
}
// Digits
switch {
default:
return false
case s[0] == '0':
s = s[1:]
case '1' <= s[0] && s[0] <= '9':
s = s[1:]
for len(s) > 0 && '0' <= s[0] && s[0] <= '9' {
s = s[1:]
}
}
// . followed by 1 or more digits.
if len(s) >= 2 && s[0] == '.' && '0' <= s[1] && s[1] <= '9' {
s = s[2:]
for len(s) > 0 && '0' <= s[0] && s[0] <= '9' {
s = s[1:]
}
}
// e or E followed by an optional - or + and
// 1 or more digits.
if len(s) >= 2 && (s[0] == 'e' || s[0] == 'E') {
s = s[1:]
if s[0] == '+' || s[0] == '-' {
s = s[1:]
if s == "" {
return false
}
}
for len(s) > 0 && '0' <= s[0] && s[0] <= '9' {
s = s[1:]
}
}
// Make sure we are at the end.
return s == ""
}
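// A few inputs and the verdict of the grammar above, for orientation
// (isValidNumber is unexported, so these calls are only possible within
// the package):
//
//	isValidNumber("0")     // true
//	isValidNumber("-12.5") // true
//	isValidNumber("1e+9")  // true
//	isValidNumber("01")    // false (leading zero)
//	isValidNumber(".5")    // false (integer part required)
//	isValidNumber("1.")    // false (fraction needs at least one digit)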
func interfaceEncoder(e *encodeState, v reflect.Value, opts encOpts) {
if v.IsNil() {
e.WriteString("null")
return
}
e.reflectValue(v.Elem(), opts)
}
func unsupportedTypeEncoder(e *encodeState, v reflect.Value, _ encOpts) {
e.error(&UnsupportedTypeError{v.Type()})
}
type structEncoder struct {
fields structFields
}
type structFields struct {
list []field
byExactName map[string]*field
byFoldedName map[string]*field
}
func (se structEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
next := byte('{')
FieldLoop:
for i := range se.fields.list {
f := &se.fields.list[i]
// Find the nested struct field by following f.index.
fv := v
for _, i := range f.index {
if fv.Kind() == reflect.Pointer {
if fv.IsNil() {
continue FieldLoop
}
fv = fv.Elem()
}
fv = fv.Field(i)
}
if f.omitEmpty && isEmptyValue(fv) {
continue
}
e.WriteByte(next)
next = ','
if opts.escapeHTML {
e.WriteString(f.nameEscHTML)
} else {
e.WriteString(f.nameNonEsc)
}
opts.quoted = f.quoted
f.encoder(e, fv, opts)
}
if next == '{' {
e.WriteString("{}")
} else {
e.WriteByte('}')
}
}
func newStructEncoder(t reflect.Type) encoderFunc {
se := structEncoder{fields: cachedTypeFields(t)}
return se.encode
}
type mapEncoder struct {
elemEnc encoderFunc
}
func (me mapEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
if v.IsNil() {
e.WriteString("null")
return
}
if e.ptrLevel++; e.ptrLevel > startDetectingCyclesAfter {
// We're a large number of nested ptrEncoder.encode calls deep;
// start checking if we've run into a pointer cycle.
ptr := v.UnsafePointer()
if _, ok := e.ptrSeen[ptr]; ok {
e.error(&UnsupportedValueError{v, fmt.Sprintf("encountered a cycle via %s", v.Type())})
}
e.ptrSeen[ptr] = struct{}{}
defer delete(e.ptrSeen, ptr)
}
e.WriteByte('{')
// Extract and sort the keys.
sv := make([]reflectWithString, v.Len())
mi := v.MapRange()
for i := 0; mi.Next(); i++ {
sv[i].k = mi.Key()
sv[i].v = mi.Value()
if err := sv[i].resolve(); err != nil {
e.error(fmt.Errorf("json: encoding error for type %q: %q", v.Type().String(), err.Error()))
}
}
sort.Slice(sv, func(i, j int) bool { return sv[i].ks < sv[j].ks })
for i, kv := range sv {
if i > 0 {
e.WriteByte(',')
}
e.Write(appendString(e.AvailableBuffer(), kv.ks, opts.escapeHTML))
e.WriteByte(':')
me.elemEnc(e, kv.v, opts)
}
e.WriteByte('}')
e.ptrLevel--
}
func newMapEncoder(t reflect.Type) encoderFunc {
switch t.Key().Kind() {
case reflect.String,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
default:
if !t.Key().Implements(textMarshalerType) {
return unsupportedTypeEncoder
}
}
me := mapEncoder{typeEncoder(t.Elem())}
return me.encode
}
func encodeByteSlice(e *encodeState, v reflect.Value, _ encOpts) {
if v.IsNil() {
e.WriteString("null")
return
}
s := v.Bytes()
encodedLen := base64.StdEncoding.EncodedLen(len(s))
e.Grow(len(`"`) + encodedLen + len(`"`))
// TODO(https://go.dev/issue/53693): Use base64.Encoding.AppendEncode.
b := e.AvailableBuffer()
b = append(b, '"')
base64.StdEncoding.Encode(b[len(b):][:encodedLen], s)
b = b[:len(b)+encodedLen]
b = append(b, '"')
e.Write(b)
}
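// Observable behavior of encodeByteSlice: a []byte marshals to a base64
// JSON string, while a nil slice marshals to null:
//
//	b, _ := json.Marshal([]byte("hello")) // "aGVsbG8="
//	b, _ = json.Marshal([]byte(nil))      // null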
// sliceEncoder just wraps an arrayEncoder, checking to make sure the value isn't nil.
type sliceEncoder struct {
arrayEnc encoderFunc
}
func (se sliceEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
if v.IsNil() {
e.WriteString("null")
return
}
if e.ptrLevel++; e.ptrLevel > startDetectingCyclesAfter {
// We're a large number of nested ptrEncoder.encode calls deep;
// start checking if we've run into a pointer cycle.
// Here we use a struct to record the pointer to the first element of the slice
// and its length.
ptr := struct {
ptr any // always an unsafe.Pointer, but avoids a dependency on package unsafe
len int
}{v.UnsafePointer(), v.Len()}
if _, ok := e.ptrSeen[ptr]; ok {
e.error(&UnsupportedValueError{v, fmt.Sprintf("encountered a cycle via %s", v.Type())})
}
e.ptrSeen[ptr] = struct{}{}
defer delete(e.ptrSeen, ptr)
}
se.arrayEnc(e, v, opts)
e.ptrLevel--
}
func newSliceEncoder(t reflect.Type) encoderFunc {
// Byte slices get special treatment; arrays don't.
if t.Elem().Kind() == reflect.Uint8 {
p := reflect.PointerTo(t.Elem())
if !p.Implements(marshalerType) && !p.Implements(textMarshalerType) {
return encodeByteSlice
}
}
enc := sliceEncoder{newArrayEncoder(t)}
return enc.encode
}
type arrayEncoder struct {
elemEnc encoderFunc
}
func (ae arrayEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
e.WriteByte('[')
n := v.Len()
for i := 0; i < n; i++ {
if i > 0 {
e.WriteByte(',')
}
ae.elemEnc(e, v.Index(i), opts)
}
e.WriteByte(']')
}
func newArrayEncoder(t reflect.Type) encoderFunc {
enc := arrayEncoder{typeEncoder(t.Elem())}
return enc.encode
}
type ptrEncoder struct {
elemEnc encoderFunc
}
func (pe ptrEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
if v.IsNil() {
e.WriteString("null")
return
}
if e.ptrLevel++; e.ptrLevel > startDetectingCyclesAfter {
// We're a large number of nested ptrEncoder.encode calls deep;
// start checking if we've run into a pointer cycle.
ptr := v.Interface()
if _, ok := e.ptrSeen[ptr]; ok {
e.error(&UnsupportedValueError{v, fmt.Sprintf("encountered a cycle via %s", v.Type())})
}
e.ptrSeen[ptr] = struct{}{}
defer delete(e.ptrSeen, ptr)
}
pe.elemEnc(e, v.Elem(), opts)
e.ptrLevel--
}
func newPtrEncoder(t reflect.Type) encoderFunc {
enc := ptrEncoder{typeEncoder(t.Elem())}
return enc.encode
}
type condAddrEncoder struct {
canAddrEnc, elseEnc encoderFunc
}
func (ce condAddrEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
if v.CanAddr() {
ce.canAddrEnc(e, v, opts)
} else {
ce.elseEnc(e, v, opts)
}
}
// newCondAddrEncoder returns an encoder that checks whether its value
// CanAddr and delegates to canAddrEnc if so, else to elseEnc.
func newCondAddrEncoder(canAddrEnc, elseEnc encoderFunc) encoderFunc {
enc := condAddrEncoder{canAddrEnc: canAddrEnc, elseEnc: elseEnc}
return enc.encode
}
func isValidTag(s string) bool {
if s == "" {
return false
}
for _, c := range s {
switch {
case strings.ContainsRune("!#$%&()*+-./:;<=>?@[]^_{|}~ ", c):
// Backslash and quote chars are reserved, but
// otherwise any punctuation chars are allowed
// in a tag name.
case !unicode.IsLetter(c) && !unicode.IsDigit(c):
return false
}
}
return true
}
func typeByIndex(t reflect.Type, index []int) reflect.Type {
for _, i := range index {
if t.Kind() == reflect.Pointer {
t = t.Elem()
}
t = t.Field(i).Type
}
return t
}
type reflectWithString struct {
k reflect.Value
v reflect.Value
ks string
}
func (w *reflectWithString) resolve() error {
if w.k.Kind() == reflect.String {
w.ks = w.k.String()
return nil
}
if tm, ok := w.k.Interface().(encoding.TextMarshaler); ok {
if w.k.Kind() == reflect.Pointer && w.k.IsNil() {
return nil
}
buf, err := tm.MarshalText()
w.ks = string(buf)
return err
}
switch w.k.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
w.ks = strconv.FormatInt(w.k.Int(), 10)
return nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
w.ks = strconv.FormatUint(w.k.Uint(), 10)
return nil
}
panic("unexpected map key type")
}
func appendString[Bytes []byte | string](dst []byte, src Bytes, escapeHTML bool) []byte {
dst = append(dst, '"')
start := 0
for i := 0; i < len(src); {
if b := src[i]; b < utf8.RuneSelf {
if htmlSafeSet[b] || (!escapeHTML && safeSet[b]) {
i++
continue
}
dst = append(dst, src[start:i]...)
switch b {
case '\\', '"':
dst = append(dst, '\\', b)
case '\n':
dst = append(dst, '\\', 'n')
case '\r':
dst = append(dst, '\\', 'r')
case '\t':
dst = append(dst, '\\', 't')
default:
// This encodes bytes < 0x20 except for \t, \n and \r.
// If escapeHTML is set, it also escapes <, >, and &
// because they can lead to security holes when
// user-controlled strings are rendered into JSON
// and served to some browsers.
dst = append(dst, '\\', 'u', '0', '0', hex[b>>4], hex[b&0xF])
}
i++
start = i
continue
}
// TODO(https://go.dev/issue/56948): Use generic utf8 functionality.
// For now, cast only a small portion of byte slices to a string
// so that it can be stack allocated. This slows down []byte slightly
// due to the extra copy, but keeps string performance roughly the same.
n := len(src) - i
if n > utf8.UTFMax {
n = utf8.UTFMax
}
c, size := utf8.DecodeRuneInString(string(src[i : i+n]))
if c == utf8.RuneError && size == 1 {
dst = append(dst, src[start:i]...)
dst = append(dst, `\ufffd`...)
i += size
start = i
continue
}
// U+2028 is LINE SEPARATOR.
// U+2029 is PARAGRAPH SEPARATOR.
// They are both technically valid characters in JSON strings,
// but don't work in JSONP, which has to be evaluated as JavaScript,
// and can lead to security holes there. It is valid JSON to
// escape them, so we do so unconditionally.
// See http://timelessrepo.com/json-isnt-a-javascript-subset for discussion.
if c == '\u2028' || c == '\u2029' {
dst = append(dst, src[start:i]...)
dst = append(dst, '\\', 'u', '2', '0', '2', hex[c&0xF])
i += size
start = i
continue
}
i += size
}
dst = append(dst, src[start:]...)
dst = append(dst, '"')
return dst
}
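// A sketch of the escaping rules above, as seen through Marshal (which sets
// escapeHTML by default):
//
//	b, _ := json.Marshal("<b> & \n")
//	// string(b) == `"\u003cb\u003e \u0026 \n"`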
// A field represents a single field found in a struct.
type field struct {
name string
nameBytes []byte // []byte(name)
nameNonEsc string // `"` + name + `":`
nameEscHTML string // `"` + HTMLEscape(name) + `":`
tag bool
index []int
typ reflect.Type
omitEmpty bool
quoted bool
encoder encoderFunc
}
// byIndex sorts field by index sequence.
type byIndex []field
func (x byIndex) Len() int { return len(x) }
func (x byIndex) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
func (x byIndex) Less(i, j int) bool {
for k, xik := range x[i].index {
if k >= len(x[j].index) {
return false
}
if xik != x[j].index[k] {
return xik < x[j].index[k]
}
}
return len(x[i].index) < len(x[j].index)
}
// typeFields returns a list of fields that JSON should recognize for the given type.
// The algorithm is breadth-first search over the set of structs to include - the top struct
// and then any reachable anonymous structs.
func typeFields(t reflect.Type) structFields {
// Anonymous fields to explore at the current level and the next.
current := []field{}
next := []field{{typ: t}}
// Count of queued names for current level and the next.
var count, nextCount map[reflect.Type]int
// Types already visited at an earlier level.
visited := map[reflect.Type]bool{}
// Fields found.
var fields []field
// Buffer to run appendHTMLEscape on field names.
var nameEscBuf []byte
for len(next) > 0 {
current, next = next, current[:0]
count, nextCount = nextCount, map[reflect.Type]int{}
for _, f := range current {
if visited[f.typ] {
continue
}
visited[f.typ] = true
// Scan f.typ for fields to include.
for i := 0; i < f.typ.NumField(); i++ {
sf := f.typ.Field(i)
if sf.Anonymous {
t := sf.Type
if t.Kind() == reflect.Pointer {
t = t.Elem()
}
if !sf.IsExported() && t.Kind() != reflect.Struct {
// Ignore embedded fields of unexported non-struct types.
continue
}
// Do not ignore embedded fields of unexported struct types
// since they may have exported fields.
} else if !sf.IsExported() {
// Ignore unexported non-embedded fields.
continue
}
tag := sf.Tag.Get("json")
if tag == "-" {
continue
}
name, opts := parseTag(tag)
if !isValidTag(name) {
name = ""
}
index := make([]int, len(f.index)+1)
copy(index, f.index)
index[len(f.index)] = i
ft := sf.Type
if ft.Name() == "" && ft.Kind() == reflect.Pointer {
// Follow pointer.
ft = ft.Elem()
}
// Only strings, floats, integers, and booleans can be quoted.
quoted := false
if opts.Contains("string") {
switch ft.Kind() {
case reflect.Bool,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
reflect.Float32, reflect.Float64,
reflect.String:
quoted = true
}
}
// Record found field and index sequence.
if name != "" || !sf.Anonymous || ft.Kind() != reflect.Struct {
tagged := name != ""
if name == "" {
name = sf.Name
}
field := field{
name: name,
tag: tagged,
index: index,
typ: ft,
omitEmpty: opts.Contains("omitempty"),
quoted: quoted,
}
field.nameBytes = []byte(field.name)
// Build nameEscHTML and nameNonEsc ahead of time.
nameEscBuf = appendHTMLEscape(nameEscBuf[:0], field.nameBytes)
field.nameEscHTML = `"` + string(nameEscBuf) + `":`
field.nameNonEsc = `"` + field.name + `":`
fields = append(fields, field)
if count[f.typ] > 1 {
// If there were multiple instances, add a second,
// so that the annihilation code will see a duplicate.
// It only cares about the distinction between 1 or 2,
// so don't bother generating any more copies.
fields = append(fields, fields[len(fields)-1])
}
continue
}
// Record new anonymous struct to explore in next round.
nextCount[ft]++
if nextCount[ft] == 1 {
next = append(next, field{name: ft.Name(), index: index, typ: ft})
}
}
}
}
sort.Slice(fields, func(i, j int) bool {
x := fields
// sort field by name, breaking ties with depth, then
// breaking ties with "name came from json tag", then
// breaking ties with index sequence.
if x[i].name != x[j].name {
return x[i].name < x[j].name
}
if len(x[i].index) != len(x[j].index) {
return len(x[i].index) < len(x[j].index)
}
if x[i].tag != x[j].tag {
return x[i].tag
}
return byIndex(x).Less(i, j)
})
// Delete all fields that are hidden by the Go rules for embedded fields,
// except that fields with JSON tags are promoted.
// The fields are sorted in primary order of name, secondary order
// of field index length. Loop over names; for each name, delete
// hidden fields by choosing the one dominant field that survives.
out := fields[:0]
for advance, i := 0, 0; i < len(fields); i += advance {
// One iteration per name.
// Find the sequence of fields with the name of this first field.
fi := fields[i]
name := fi.name
for advance = 1; i+advance < len(fields); advance++ {
fj := fields[i+advance]
if fj.name != name {
break
}
}
if advance == 1 { // Only one field with this name
out = append(out, fi)
continue
}
dominant, ok := dominantField(fields[i : i+advance])
if ok {
out = append(out, dominant)
}
}
fields = out
sort.Sort(byIndex(fields))
for i := range fields {
f := &fields[i]
f.encoder = typeEncoder(typeByIndex(t, f.index))
}
exactNameIndex := make(map[string]*field, len(fields))
foldedNameIndex := make(map[string]*field, len(fields))
for i, field := range fields {
exactNameIndex[field.name] = &fields[i]
// For historical reasons, first folded match takes precedence.
if _, ok := foldedNameIndex[string(foldName(field.nameBytes))]; !ok {
foldedNameIndex[string(foldName(field.nameBytes))] = &fields[i]
}
}
return structFields{fields, exactNameIndex, foldedNameIndex}
}
// dominantField looks through the fields, all of which are known to
// have the same name, to find the single field that dominates the
// others using Go's embedding rules, modified by the presence of
// JSON tags. If there are multiple top-level fields, the boolean
// will be false: This condition is an error in Go and we skip all
// the fields.
func dominantField(fields []field) (field, bool) {
// The fields are sorted in increasing index-length order, then by presence of tag.
// That means that the first field is the dominant one. We need only check
// for error cases: two fields at top level, either both tagged or neither tagged.
if len(fields) > 1 && len(fields[0].index) == len(fields[1].index) && fields[0].tag == fields[1].tag {
return field{}, false
}
return fields[0], true
}
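// A sketch of the annihilation case handled above: two untagged fields with
// the same name at the same depth cancel out rather than producing output:
//
//	type A struct{ X int }
//	type B struct{ X int }
//	type T struct {
//		A
//		B
//	}
//	b, _ := json.Marshal(T{}) // {} (neither X survives the tie)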
var fieldCache sync.Map // map[reflect.Type]structFields
// cachedTypeFields is like typeFields but uses a cache to avoid repeated work.
func cachedTypeFields(t reflect.Type) structFields {
if f, ok := fieldCache.Load(t); ok {
return f.(structFields)
}
f, _ := fieldCache.LoadOrStore(t, typeFields(t))
return f.(structFields)
}
func mayAppendQuote(b []byte, quoted bool) []byte {
if quoted {
b = append(b, '"')
}
return b
}
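// mayAppendQuote serves the `,string` tag option. A usage sketch:
//
//	type T struct {
//		N int `json:"n,string"`
//	}
//	b, _ := json.Marshal(T{N: 42}) // {"n":"42"}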
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"unicode"
"unicode/utf8"
)
// foldName returns a folded string such that foldName(x) == foldName(y)
// is identical to bytes.EqualFold(x, y).
func foldName(in []byte) []byte {
// This is inlinable to take advantage of "function outlining".
var arr [32]byte // large enough for most JSON names
return appendFoldedName(arr[:0], in)
}
func appendFoldedName(out, in []byte) []byte {
for i := 0; i < len(in); {
// Handle single-byte ASCII.
if c := in[i]; c < utf8.RuneSelf {
if 'a' <= c && c <= 'z' {
c -= 'a' - 'A'
}
out = append(out, c)
i++
continue
}
// Handle multi-byte Unicode.
r, n := utf8.DecodeRune(in[i:])
out = utf8.AppendRune(out, foldRune(r))
i += n
}
return out
}
// foldRune returns the smallest rune for all runes in the same fold set.
func foldRune(r rune) rune {
for {
r2 := unicode.SimpleFold(r)
if r2 <= r {
return r2
}
r = r2
}
}
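// The folding above is what gives Unmarshal its case-insensitive field
// matching when no exact name matches. A sketch:
//
//	type T struct{ UserName string }
//	var t T
//	json.Unmarshal([]byte(`{"username":"gopher"}`), &t)
//	// t.UserName == "gopher"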
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import "bytes"
// TODO(https://go.dev/issue/53685): Use bytes.Buffer.AvailableBuffer instead.
func availableBuffer(b *bytes.Buffer) []byte {
return b.Bytes()[b.Len():]
}
// HTMLEscape appends to dst the JSON-encoded src with <, >, &, U+2028 and U+2029
// characters inside string literals changed to \u003c, \u003e, \u0026, \u2028, \u2029
// so that the JSON will be safe to embed inside HTML <script> tags.
// For historical reasons, web browsers don't honor standard HTML
// escaping within <script> tags, so an alternative JSON encoding must be used.
func HTMLEscape(dst *bytes.Buffer, src []byte) {
dst.Grow(len(src))
dst.Write(appendHTMLEscape(availableBuffer(dst), src))
}
func appendHTMLEscape(dst, src []byte) []byte {
// The characters can only appear in string literals,
// so just scan the string one byte at a time.
start := 0
for i, c := range src {
if c == '<' || c == '>' || c == '&' {
dst = append(dst, src[start:i]...)
dst = append(dst, '\\', 'u', '0', '0', hex[c>>4], hex[c&0xF])
start = i + 1
}
// Convert U+2028 and U+2029 (E2 80 A8 and E2 80 A9).
if c == 0xE2 && i+2 < len(src) && src[i+1] == 0x80 && src[i+2]&^1 == 0xA8 {
dst = append(dst, src[start:i]...)
dst = append(dst, '\\', 'u', '2', '0', '2', hex[src[i+2]&0xF])
start = i + len("\u2029")
}
}
return append(dst, src[start:]...)
}
// Compact appends to dst the JSON-encoded src with
// insignificant space characters elided.
func Compact(dst *bytes.Buffer, src []byte) error {
dst.Grow(len(src))
b := availableBuffer(dst)
b, err := appendCompact(b, src, false)
dst.Write(b)
return err
}
func appendCompact(dst, src []byte, escape bool) ([]byte, error) {
origLen := len(dst)
scan := newScanner()
defer freeScanner(scan)
start := 0
for i, c := range src {
if escape && (c == '<' || c == '>' || c == '&') {
dst = append(dst, src[start:i]...)
dst = append(dst, '\\', 'u', '0', '0', hex[c>>4], hex[c&0xF])
start = i + 1
}
// Convert U+2028 and U+2029 (E2 80 A8 and E2 80 A9).
if escape && c == 0xE2 && i+2 < len(src) && src[i+1] == 0x80 && src[i+2]&^1 == 0xA8 {
dst = append(dst, src[start:i]...)
dst = append(dst, '\\', 'u', '2', '0', '2', hex[src[i+2]&0xF])
start = i + len("\u2029")
}
v := scan.step(scan, c)
if v >= scanSkipSpace {
if v == scanError {
break
}
dst = append(dst, src[start:i]...)
start = i + 1
}
}
if scan.eof() == scanError {
return dst[:origLen], scan.err
}
dst = append(dst, src[start:]...)
return dst, nil
}
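// A usage sketch for Compact:
//
//	var buf bytes.Buffer
//	if err := json.Compact(&buf, []byte(`{ "a" : 1 }`)); err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println(buf.String()) // {"a":1}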
func appendNewline(dst []byte, prefix, indent string, depth int) []byte {
dst = append(dst, '\n')
dst = append(dst, prefix...)
for i := 0; i < depth; i++ {
dst = append(dst, indent...)
}
return dst
}
// indentGrowthFactor specifies the growth factor of indenting JSON input.
// Empirically, the growth factor was measured to be between 1.4x and 1.8x
// for some set of compacted JSON with the indent being a single tab.
// Specify a growth factor slightly larger than what is observed
// to reduce probability of allocation in appendIndent.
// A factor no higher than 2 ensures that wasted space never exceeds 50%.
const indentGrowthFactor = 2
// Indent appends to dst an indented form of the JSON-encoded src.
// Each element in a JSON object or array begins on a new,
// indented line beginning with prefix followed by one or more
// copies of indent according to the indentation nesting.
// The data appended to dst does not begin with the prefix nor
// any indentation, to make it easier to embed inside other formatted JSON data.
// Although leading space characters (space, tab, carriage return, newline)
// at the beginning of src are dropped, trailing space characters
// at the end of src are preserved and copied to dst.
// For example, if src has no trailing spaces, neither will dst;
// if src ends in a trailing newline, so will dst.
func Indent(dst *bytes.Buffer, src []byte, prefix, indent string) error {
dst.Grow(indentGrowthFactor * len(src))
b := availableBuffer(dst)
b, err := appendIndent(b, src, prefix, indent)
dst.Write(b)
return err
}
func appendIndent(dst, src []byte, prefix, indent string) ([]byte, error) {
origLen := len(dst)
scan := newScanner()
defer freeScanner(scan)
needIndent := false
depth := 0
for _, c := range src {
scan.bytes++
v := scan.step(scan, c)
if v == scanSkipSpace {
continue
}
if v == scanError {
break
}
if needIndent && v != scanEndObject && v != scanEndArray {
needIndent = false
depth++
dst = appendNewline(dst, prefix, indent, depth)
}
// Emit semantically uninteresting bytes
// (in particular, punctuation in strings) unmodified.
if v == scanContinue {
dst = append(dst, c)
continue
}
// Add spacing around real punctuation.
switch c {
case '{', '[':
// delay indent so that empty object and array are formatted as {} and [].
needIndent = true
dst = append(dst, c)
case ',':
dst = append(dst, c)
dst = appendNewline(dst, prefix, indent, depth)
case ':':
dst = append(dst, c, ' ')
case '}', ']':
if needIndent {
// suppress indent in empty object/array
needIndent = false
} else {
depth--
dst = appendNewline(dst, prefix, indent, depth)
}
dst = append(dst, c)
default:
dst = append(dst, c)
}
}
if scan.eof() == scanError {
return dst[:origLen], scan.err
}
return dst, nil
}
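// A usage sketch for Indent:
//
//	var buf bytes.Buffer
//	json.Indent(&buf, []byte(`{"a":1,"b":[2,3]}`), "", "\t")
//	fmt.Println(buf.String())
//	// {
//	// 	"a": 1,
//	// 	"b": [
//	// 		2,
//	// 		3
//	// 	]
//	// }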
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
// JSON value parser state machine.
// Just about at the limit of what is reasonable to write by hand.
// Some parts are a bit tedious, but overall it nicely factors out the
// otherwise common code from the multiple scanning functions
// in this package (Compact, Indent, checkValid, etc).
//
// This file starts with two simple examples using the scanner
// before diving into the scanner itself.
import (
"strconv"
"sync"
)
// Valid reports whether data is a valid JSON encoding.
func Valid(data []byte) bool {
scan := newScanner()
defer freeScanner(scan)
return checkValid(data, scan) == nil
}
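// A usage sketch for Valid:
//
//	json.Valid([]byte(`{"ok":true}`)) // true
//	json.Valid([]byte(`{"ok":`))      // false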
// checkValid verifies that data is valid JSON-encoded data.
// scan is passed in for use by checkValid to avoid an allocation.
// checkValid returns nil or a SyntaxError.
func checkValid(data []byte, scan *scanner) error {
scan.reset()
for _, c := range data {
scan.bytes++
if scan.step(scan, c) == scanError {
return scan.err
}
}
if scan.eof() == scanError {
return scan.err
}
return nil
}
// A SyntaxError is a description of a JSON syntax error.
// Unmarshal will return a SyntaxError if the JSON can't be parsed.
type SyntaxError struct {
msg string // description of error
Offset int64 // error occurred after reading Offset bytes
}
func (e *SyntaxError) Error() string { return e.msg }
// A scanner is a JSON scanning state machine.
// Callers call scan.reset and then pass bytes in one at a time
// by calling scan.step(&scan, c) for each byte.
// The return value, referred to as an opcode, tells the
// caller about significant parsing events like beginning
// and ending literals, objects, and arrays, so that the
// caller can follow along if it wishes.
// The return value scanEnd indicates that a single top-level
// JSON value has been completed, *before* the byte that
// just got passed in. (The indication must be delayed in order
// to recognize the end of numbers: is 123 a whole value or
// the beginning of 12345e+6?).
type scanner struct {
// The step is a func to be called to execute the next transition.
// Also tried using an integer constant and a single func
// with a switch, but using the func directly was 10% faster
// on a 64-bit Mac Mini, and it's nicer to read.
step func(*scanner, byte) int
// Reached end of top-level value.
endTop bool
// Stack of what we're in the middle of - array values, object keys, object values.
parseState []int
// Error that happened, if any.
err error
// total bytes consumed, updated by decoder.Decode (and deliberately
// not set to zero by scan.reset)
bytes int64
}
var scannerPool = sync.Pool{
New: func() any {
return &scanner{}
},
}
func newScanner() *scanner {
scan := scannerPool.Get().(*scanner)
// scan.reset by design doesn't set bytes to zero
scan.bytes = 0
scan.reset()
return scan
}
func freeScanner(scan *scanner) {
// Avoid hanging on to too much memory in extreme cases.
if len(scan.parseState) > 1024 {
scan.parseState = nil
}
scannerPool.Put(scan)
}
// These values are returned by the state transition functions
// assigned to scanner.state and the method scanner.eof.
// They give details about the current state of the scan that
// callers might be interested to know about.
// It is okay to ignore the return value of any particular
// call to scanner.state: if one call returns scanError,
// every subsequent call will return scanError too.
const (
// Continue.
scanContinue = iota // uninteresting byte
scanBeginLiteral // end implied by next result != scanContinue
scanBeginObject // begin object
scanObjectKey // just finished object key (string)
scanObjectValue // just finished non-last object value
scanEndObject // end object (implies scanObjectValue if possible)
scanBeginArray // begin array
scanArrayValue // just finished array value
scanEndArray // end array (implies scanArrayValue if possible)
scanSkipSpace // space byte; can skip; known to be last "continue" result
// Stop.
scanEnd // top-level value ended *before* this byte; known to be first "stop" result
scanError // hit an error, scanner.err.
)
// These values are stored in the parseState stack.
// They give the current state of a composite value
// being scanned. If the parser is inside a nested value
// the parseState describes the nested state, outermost at entry 0.
const (
parseObjectKey = iota // parsing object key (before colon)
parseObjectValue // parsing object value (after colon)
parseArrayValue // parsing array value
)
// This limits the max nesting depth to prevent stack overflow.
// This is permitted by https://tools.ietf.org/html/rfc7159#section-9
const maxNestingDepth = 10000
// reset prepares the scanner for use.
// It must be called before calling s.step.
func (s *scanner) reset() {
s.step = stateBeginValue
s.parseState = s.parseState[0:0]
s.err = nil
s.endTop = false
}
// eof tells the scanner that the end of input has been reached.
// It returns a scan status just as s.step does.
func (s *scanner) eof() int {
if s.err != nil {
return scanError
}
if s.endTop {
return scanEnd
}
s.step(s, ' ')
if s.endTop {
return scanEnd
}
if s.err == nil {
s.err = &SyntaxError{"unexpected end of JSON input", s.bytes}
}
return scanError
}
// pushParseState pushes a new parse state p onto the parse stack.
// An error state is returned if maxNestingDepth was exceeded; otherwise successState is returned.
func (s *scanner) pushParseState(c byte, newParseState int, successState int) int {
s.parseState = append(s.parseState, newParseState)
if len(s.parseState) <= maxNestingDepth {
return successState
}
return s.error(c, "exceeded max depth")
}
// popParseState pops a parse state (already obtained) off the stack
// and updates s.step accordingly.
func (s *scanner) popParseState() {
n := len(s.parseState) - 1
s.parseState = s.parseState[0:n]
if n == 0 {
s.step = stateEndTop
s.endTop = true
} else {
s.step = stateEndValue
}
}
func isSpace(c byte) bool {
return c <= ' ' && (c == ' ' || c == '\t' || c == '\r' || c == '\n')
}
// stateBeginValueOrEmpty is the state after reading `[`.
func stateBeginValueOrEmpty(s *scanner, c byte) int {
if isSpace(c) {
return scanSkipSpace
}
if c == ']' {
return stateEndValue(s, c)
}
return stateBeginValue(s, c)
}
// stateBeginValue is the state at the beginning of the input.
func stateBeginValue(s *scanner, c byte) int {
if isSpace(c) {
return scanSkipSpace
}
switch c {
case '{':
s.step = stateBeginStringOrEmpty
return s.pushParseState(c, parseObjectKey, scanBeginObject)
case '[':
s.step = stateBeginValueOrEmpty
return s.pushParseState(c, parseArrayValue, scanBeginArray)
case '"':
s.step = stateInString
return scanBeginLiteral
case '-':
s.step = stateNeg
return scanBeginLiteral
case '0': // beginning of 0.123
s.step = state0
return scanBeginLiteral
case 't': // beginning of true
s.step = stateT
return scanBeginLiteral
case 'f': // beginning of false
s.step = stateF
return scanBeginLiteral
case 'n': // beginning of null
s.step = stateN
return scanBeginLiteral
}
if '1' <= c && c <= '9' { // beginning of 1234.5
s.step = state1
return scanBeginLiteral
}
return s.error(c, "looking for beginning of value")
}
// stateBeginStringOrEmpty is the state after reading `{`.
func stateBeginStringOrEmpty(s *scanner, c byte) int {
if isSpace(c) {
return scanSkipSpace
}
if c == '}' {
n := len(s.parseState)
s.parseState[n-1] = parseObjectValue
return stateEndValue(s, c)
}
return stateBeginString(s, c)
}
// stateBeginString is the state after reading `{"key": value,`.
func stateBeginString(s *scanner, c byte) int {
if isSpace(c) {
return scanSkipSpace
}
if c == '"' {
s.step = stateInString
return scanBeginLiteral
}
return s.error(c, "looking for beginning of object key string")
}
// stateEndValue is the state after completing a value,
// such as after reading `{}` or `true` or `["x"`.
func stateEndValue(s *scanner, c byte) int {
n := len(s.parseState)
if n == 0 {
// Completed top-level before the current byte.
s.step = stateEndTop
s.endTop = true
return stateEndTop(s, c)
}
if isSpace(c) {
s.step = stateEndValue
return scanSkipSpace
}
ps := s.parseState[n-1]
switch ps {
case parseObjectKey:
if c == ':' {
s.parseState[n-1] = parseObjectValue
s.step = stateBeginValue
return scanObjectKey
}
return s.error(c, "after object key")
case parseObjectValue:
if c == ',' {
s.parseState[n-1] = parseObjectKey
s.step = stateBeginString
return scanObjectValue
}
if c == '}' {
s.popParseState()
return scanEndObject
}
return s.error(c, "after object key:value pair")
case parseArrayValue:
if c == ',' {
s.step = stateBeginValue
return scanArrayValue
}
if c == ']' {
s.popParseState()
return scanEndArray
}
return s.error(c, "after array element")
}
return s.error(c, "")
}
// stateEndTop is the state after finishing the top-level value,
// such as after reading `{}` or `[1,2,3]`.
// Only space characters should be seen now.
func stateEndTop(s *scanner, c byte) int {
if !isSpace(c) {
// Complain about non-space byte on next call.
s.error(c, "after top-level value")
}
return scanEnd
}
// stateInString is the state after reading `"`.
func stateInString(s *scanner, c byte) int {
if c == '"' {
s.step = stateEndValue
return scanContinue
}
if c == '\\' {
s.step = stateInStringEsc
return scanContinue
}
if c < 0x20 {
return s.error(c, "in string literal")
}
return scanContinue
}
// stateInStringEsc is the state after reading `"\` during a quoted string.
func stateInStringEsc(s *scanner, c byte) int {
switch c {
case 'b', 'f', 'n', 'r', 't', '\\', '/', '"':
s.step = stateInString
return scanContinue
case 'u':
s.step = stateInStringEscU
return scanContinue
}
return s.error(c, "in string escape code")
}
// stateInStringEscU is the state after reading `"\u` during a quoted string.
func stateInStringEscU(s *scanner, c byte) int {
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
s.step = stateInStringEscU1
return scanContinue
}
// not a hex digit
return s.error(c, "in \\u hexadecimal character escape")
}
// stateInStringEscU1 is the state after reading `"\u1` during a quoted string.
func stateInStringEscU1(s *scanner, c byte) int {
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
s.step = stateInStringEscU12
return scanContinue
}
// not a hex digit
return s.error(c, "in \\u hexadecimal character escape")
}
// stateInStringEscU12 is the state after reading `"\u12` during a quoted string.
func stateInStringEscU12(s *scanner, c byte) int {
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
s.step = stateInStringEscU123
return scanContinue
}
// not a hex digit
return s.error(c, "in \\u hexadecimal character escape")
}
// stateInStringEscU123 is the state after reading `"\u123` during a quoted string.
func stateInStringEscU123(s *scanner, c byte) int {
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
s.step = stateInString
return scanContinue
}
// not a hex digit
return s.error(c, "in \\u hexadecimal character escape")
}
// stateNeg is the state after reading `-` during a number.
func stateNeg(s *scanner, c byte) int {
if c == '0' {
s.step = state0
return scanContinue
}
if '1' <= c && c <= '9' {
s.step = state1
return scanContinue
}
return s.error(c, "in numeric literal")
}
// state1 is the state after reading a non-zero integer during a number,
// such as after reading `1` or `100` but not `0`.
func state1(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
s.step = state1
return scanContinue
}
return state0(s, c)
}
// state0 is the state after reading `0` during a number.
func state0(s *scanner, c byte) int {
if c == '.' {
s.step = stateDot
return scanContinue
}
if c == 'e' || c == 'E' {
s.step = stateE
return scanContinue
}
return stateEndValue(s, c)
}
// stateDot is the state after reading the integer and decimal point in a number,
// such as after reading `1.`.
func stateDot(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
s.step = stateDot0
return scanContinue
}
return s.error(c, "after decimal point in numeric literal")
}
// stateDot0 is the state after reading the integer, decimal point, and subsequent
// digits of a number, such as after reading `3.14`.
func stateDot0(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
return scanContinue
}
if c == 'e' || c == 'E' {
s.step = stateE
return scanContinue
}
return stateEndValue(s, c)
}
// stateE is the state after reading the mantissa and e in a number,
// such as after reading `314e` or `0.314e`.
func stateE(s *scanner, c byte) int {
if c == '+' || c == '-' {
s.step = stateESign
return scanContinue
}
return stateESign(s, c)
}
// stateESign is the state after reading the mantissa, e, and sign in a number,
// such as after reading `314e-` or `0.314e+`.
func stateESign(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
s.step = stateE0
return scanContinue
}
return s.error(c, "in exponent of numeric literal")
}
// stateE0 is the state after reading the mantissa, e, optional sign,
// and at least one digit of the exponent in a number,
// such as after reading `314e-2` or `0.314e+1` or `3.14e0`.
func stateE0(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
return scanContinue
}
return stateEndValue(s, c)
}
// stateT is the state after reading `t`.
func stateT(s *scanner, c byte) int {
if c == 'r' {
s.step = stateTr
return scanContinue
}
return s.error(c, "in literal true (expecting 'r')")
}
// stateTr is the state after reading `tr`.
func stateTr(s *scanner, c byte) int {
if c == 'u' {
s.step = stateTru
return scanContinue
}
return s.error(c, "in literal true (expecting 'u')")
}
// stateTru is the state after reading `tru`.
func stateTru(s *scanner, c byte) int {
if c == 'e' {
s.step = stateEndValue
return scanContinue
}
return s.error(c, "in literal true (expecting 'e')")
}
// stateF is the state after reading `f`.
func stateF(s *scanner, c byte) int {
if c == 'a' {
s.step = stateFa
return scanContinue
}
return s.error(c, "in literal false (expecting 'a')")
}
// stateFa is the state after reading `fa`.
func stateFa(s *scanner, c byte) int {
if c == 'l' {
s.step = stateFal
return scanContinue
}
return s.error(c, "in literal false (expecting 'l')")
}
// stateFal is the state after reading `fal`.
func stateFal(s *scanner, c byte) int {
if c == 's' {
s.step = stateFals
return scanContinue
}
return s.error(c, "in literal false (expecting 's')")
}
// stateFals is the state after reading `fals`.
func stateFals(s *scanner, c byte) int {
if c == 'e' {
s.step = stateEndValue
return scanContinue
}
return s.error(c, "in literal false (expecting 'e')")
}
// stateN is the state after reading `n`.
func stateN(s *scanner, c byte) int {
if c == 'u' {
s.step = stateNu
return scanContinue
}
return s.error(c, "in literal null (expecting 'u')")
}
// stateNu is the state after reading `nu`.
func stateNu(s *scanner, c byte) int {
if c == 'l' {
s.step = stateNul
return scanContinue
}
return s.error(c, "in literal null (expecting 'l')")
}
// stateNul is the state after reading `nul`.
func stateNul(s *scanner, c byte) int {
if c == 'l' {
s.step = stateEndValue
return scanContinue
}
return s.error(c, "in literal null (expecting 'l')")
}
// stateError is the state after reaching a syntax error,
// such as after reading `[1}` or `5.1.2`.
func stateError(s *scanner, c byte) int {
return scanError
}
// error records an error and switches to the error state.
func (s *scanner) error(c byte, context string) int {
s.step = stateError
s.err = &SyntaxError{"invalid character " + quoteChar(c) + " " + context, s.bytes}
return scanError
}
// quoteChar formats c as a quoted character literal.
func quoteChar(c byte) string {
// special cases - different from quoted strings
if c == '\'' {
return `'\''`
}
if c == '"' {
return `'"'`
}
// use quoted string with different quotation marks
s := strconv.Quote(string(c))
return "'" + s[1:len(s)-1] + "'"
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"bytes"
"errors"
"io"
)
// A Decoder reads and decodes JSON values from an input stream.
type Decoder struct {
r io.Reader
buf []byte
d decodeState
scanp int // start of unread data in buf
scanned int64 // amount of data already scanned
scan scanner
err error
tokenState int
tokenStack []int
}
// NewDecoder returns a new decoder that reads from r.
//
// The decoder introduces its own buffering and may
// read data from r beyond the JSON values requested.
func NewDecoder(r io.Reader) *Decoder {
return &Decoder{r: r}
}
// UseNumber causes the Decoder to unmarshal a number into an interface{} as a
// Number instead of as a float64.
func (dec *Decoder) UseNumber() { dec.d.useNumber = true }
// DisallowUnknownFields causes the Decoder to return an error when the destination
// is a struct and the input contains object keys which do not match any
// non-ignored, exported fields in the destination.
func (dec *Decoder) DisallowUnknownFields() { dec.d.disallowUnknownFields = true }
// Decode reads the next JSON-encoded value from its
// input and stores it in the value pointed to by v.
//
// See the documentation for Unmarshal for details about
// the conversion of JSON into a Go value.
func (dec *Decoder) Decode(v any) error {
if dec.err != nil {
return dec.err
}
if err := dec.tokenPrepareForDecode(); err != nil {
return err
}
if !dec.tokenValueAllowed() {
return &SyntaxError{msg: "not at beginning of value", Offset: dec.InputOffset()}
}
// Read whole value into buffer.
n, err := dec.readValue()
if err != nil {
return err
}
dec.d.init(dec.buf[dec.scanp : dec.scanp+n])
dec.scanp += n
// Don't save err from unmarshal into dec.err:
// the connection is still usable since we read a complete JSON
// object from it before the error happened.
err = dec.d.unmarshal(v)
// fixup token streaming state
dec.tokenValueEnd()
return err
}
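// A sketch of streaming with Decode: values are read one at a time, and the
// decoder may buffer input beyond the value just returned:
//
//	dec := json.NewDecoder(strings.NewReader(`{"a":1} {"a":2}`))
//	for {
//		var v map[string]int
//		if err := dec.Decode(&v); err == io.EOF {
//			break
//		} else if err != nil {
//			log.Fatal(err)
//		}
//		fmt.Println(v["a"]) // 1, then 2
//	}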
// Buffered returns a reader of the data remaining in the Decoder's
// buffer. The reader is valid until the next call to Decode.
func (dec *Decoder) Buffered() io.Reader {
return bytes.NewReader(dec.buf[dec.scanp:])
}
// readValue reads a JSON value into dec.buf.
// It returns the length of the encoding.
func (dec *Decoder) readValue() (int, error) {
dec.scan.reset()
scanp := dec.scanp
var err error
Input:
// help the compiler see that scanp is never negative, so it can remove
// some bounds checks below.
for scanp >= 0 {
// Look in the buffer for a new value.
for ; scanp < len(dec.buf); scanp++ {
c := dec.buf[scanp]
dec.scan.bytes++
switch dec.scan.step(&dec.scan, c) {
case scanEnd:
// scanEnd is delayed one byte so we decrement
// the scanner bytes count by 1 to ensure that
// this value is correct in the next call of Decode.
dec.scan.bytes--
break Input
case scanEndObject, scanEndArray:
// scanEnd is delayed one byte.
// We might block trying to get that byte from src,
// so instead invent a space byte.
if stateEndValue(&dec.scan, ' ') == scanEnd {
scanp++
break Input
}
case scanError:
dec.err = dec.scan.err
return 0, dec.scan.err
}
}
// Did the last read have an error?
// Delayed until now to allow buffer scan.
if err != nil {
if err == io.EOF {
if dec.scan.step(&dec.scan, ' ') == scanEnd {
break Input
}
if nonSpace(dec.buf) {
err = io.ErrUnexpectedEOF
}
}
dec.err = err
return 0, err
}
n := scanp - dec.scanp
err = dec.refill()
scanp = dec.scanp + n
}
return scanp - dec.scanp, nil
}
func (dec *Decoder) refill() error {
// Make room to read more into the buffer.
// First slide down data already consumed.
if dec.scanp > 0 {
dec.scanned += int64(dec.scanp)
n := copy(dec.buf, dec.buf[dec.scanp:])
dec.buf = dec.buf[:n]
dec.scanp = 0
}
// Grow buffer if not large enough.
const minRead = 512
if cap(dec.buf)-len(dec.buf) < minRead {
newBuf := make([]byte, len(dec.buf), 2*cap(dec.buf)+minRead)
copy(newBuf, dec.buf)
dec.buf = newBuf
}
// Read. Delay error for next iteration (after scan).
n, err := dec.r.Read(dec.buf[len(dec.buf):cap(dec.buf)])
dec.buf = dec.buf[0 : len(dec.buf)+n]
return err
}
func nonSpace(b []byte) bool {
for _, c := range b {
if !isSpace(c) {
return true
}
}
return false
}
// An Encoder writes JSON values to an output stream.
type Encoder struct {
w io.Writer
err error
escapeHTML bool
indentBuf []byte
indentPrefix string
indentValue string
}
// NewEncoder returns a new encoder that writes to w.
func NewEncoder(w io.Writer) *Encoder {
return &Encoder{w: w, escapeHTML: true}
}
// Encode writes the JSON encoding of v to the stream,
// followed by a newline character.
//
// See the documentation for Marshal for details about the
// conversion of Go values to JSON.
func (enc *Encoder) Encode(v any) error {
if enc.err != nil {
return enc.err
}
e := newEncodeState()
defer encodeStatePool.Put(e)
err := e.marshal(v, encOpts{escapeHTML: enc.escapeHTML})
if err != nil {
return err
}
// Terminate each value with a newline.
// This makes the output look a little nicer
// when debugging, and some kind of space
// is required if the encoded value was a number,
// so that the reader knows there aren't more
// digits coming.
e.WriteByte('\n')
b := e.Bytes()
if enc.indentPrefix != "" || enc.indentValue != "" {
enc.indentBuf, err = appendIndent(enc.indentBuf[:0], b, enc.indentPrefix, enc.indentValue)
if err != nil {
return err
}
b = enc.indentBuf
}
if _, err = enc.w.Write(b); err != nil {
enc.err = err
}
return err
}
// SetIndent instructs the encoder to format each subsequent encoded
// value as if indented by the package-level function Indent(dst, src, prefix, indent).
// Calling SetIndent("", "") disables indentation.
func (enc *Encoder) SetIndent(prefix, indent string) {
enc.indentPrefix = prefix
enc.indentValue = indent
}
// SetEscapeHTML specifies whether problematic HTML characters
// should be escaped inside JSON quoted strings.
// The default behavior is to escape &, <, and > to \u0026, \u003c, and \u003e
// to avoid certain safety problems that can arise when embedding JSON in HTML.
//
// In non-HTML settings where the escaping interferes with the readability
// of the output, SetEscapeHTML(false) disables this behavior.
func (enc *Encoder) SetEscapeHTML(on bool) {
enc.escapeHTML = on
}
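// A sketch combining the Encoder knobs above:
//
//	enc := json.NewEncoder(os.Stdout)
//	enc.SetEscapeHTML(false)
//	enc.SetIndent("", "  ")
//	enc.Encode(map[string]string{"link": "<a>"})
//	// {
//	//   "link": "<a>"
//	// }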
// RawMessage is a raw encoded JSON value.
// It implements Marshaler and Unmarshaler and can
// be used to delay JSON decoding or precompute a JSON encoding.
type RawMessage []byte
// MarshalJSON returns m as the JSON encoding of m.
func (m RawMessage) MarshalJSON() ([]byte, error) {
if m == nil {
return []byte("null"), nil
}
return m, nil
}
// UnmarshalJSON sets *m to a copy of data.
func (m *RawMessage) UnmarshalJSON(data []byte) error {
if m == nil {
return errors.New("json.RawMessage: UnmarshalJSON on nil pointer")
}
*m = append((*m)[0:0], data...)
return nil
}
var _ Marshaler = (*RawMessage)(nil)
var _ Unmarshaler = (*RawMessage)(nil)
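// A sketch of delayed decoding with RawMessage: decode an envelope first,
// then decode the payload once its type is known:
//
//	var env struct {
//		Kind    string
//		Payload json.RawMessage
//	}
//	json.Unmarshal([]byte(`{"Kind":"point","Payload":{"X":1,"Y":2}}`), &env)
//	if env.Kind == "point" {
//		var p struct{ X, Y int }
//		json.Unmarshal(env.Payload, &p)
//	}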
// A Token holds a value of one of these types:
//
// Delim, for the four JSON delimiters [ ] { }
// bool, for JSON booleans
// float64, for JSON numbers
// Number, for JSON numbers
// string, for JSON string literals
// nil, for JSON null
type Token any
const (
tokenTopValue = iota
tokenArrayStart
tokenArrayValue
tokenArrayComma
tokenObjectStart
tokenObjectKey
tokenObjectColon
tokenObjectValue
tokenObjectComma
)
// advance the token state from a separator state to a value state
func (dec *Decoder) tokenPrepareForDecode() error {
// Note: Not calling peek before switch, to avoid
// putting peek into the standard Decode path.
// peek is only called when using the Token API.
switch dec.tokenState {
case tokenArrayComma:
c, err := dec.peek()
if err != nil {
return err
}
if c != ',' {
return &SyntaxError{"expected comma after array element", dec.InputOffset()}
}
dec.scanp++
dec.tokenState = tokenArrayValue
case tokenObjectColon:
c, err := dec.peek()
if err != nil {
return err
}
if c != ':' {
return &SyntaxError{"expected colon after object key", dec.InputOffset()}
}
dec.scanp++
dec.tokenState = tokenObjectValue
}
return nil
}
func (dec *Decoder) tokenValueAllowed() bool {
switch dec.tokenState {
case tokenTopValue, tokenArrayStart, tokenArrayValue, tokenObjectValue:
return true
}
return false
}
func (dec *Decoder) tokenValueEnd() {
switch dec.tokenState {
case tokenArrayStart, tokenArrayValue:
dec.tokenState = tokenArrayComma
case tokenObjectValue:
dec.tokenState = tokenObjectComma
}
}
// A Delim is a JSON array or object delimiter, one of [ ] { or }.
type Delim rune
func (d Delim) String() string {
return string(d)
}
// Token returns the next JSON token in the input stream.
// At the end of the input stream, Token returns nil, io.EOF.
//
// Token guarantees that the delimiters [ ] { } it returns are
// properly nested and matched: if Token encounters an unexpected
// delimiter in the input, it will return an error.
//
// The input stream consists of basic JSON values—bool, string,
// number, and null—along with delimiters [ ] { } of type Delim
// to mark the start and end of arrays and objects.
// Commas and colons are elided.
func (dec *Decoder) Token() (Token, error) {
for {
c, err := dec.peek()
if err != nil {
return nil, err
}
switch c {
case '[':
if !dec.tokenValueAllowed() {
return dec.tokenError(c)
}
dec.scanp++
dec.tokenStack = append(dec.tokenStack, dec.tokenState)
dec.tokenState = tokenArrayStart
return Delim('['), nil
case ']':
if dec.tokenState != tokenArrayStart && dec.tokenState != tokenArrayComma {
return dec.tokenError(c)
}
dec.scanp++
dec.tokenState = dec.tokenStack[len(dec.tokenStack)-1]
dec.tokenStack = dec.tokenStack[:len(dec.tokenStack)-1]
dec.tokenValueEnd()
return Delim(']'), nil
case '{':
if !dec.tokenValueAllowed() {
return dec.tokenError(c)
}
dec.scanp++
dec.tokenStack = append(dec.tokenStack, dec.tokenState)
dec.tokenState = tokenObjectStart
return Delim('{'), nil
case '}':
if dec.tokenState != tokenObjectStart && dec.tokenState != tokenObjectComma {
return dec.tokenError(c)
}
dec.scanp++
dec.tokenState = dec.tokenStack[len(dec.tokenStack)-1]
dec.tokenStack = dec.tokenStack[:len(dec.tokenStack)-1]
dec.tokenValueEnd()
return Delim('}'), nil
case ':':
if dec.tokenState != tokenObjectColon {
return dec.tokenError(c)
}
dec.scanp++
dec.tokenState = tokenObjectValue
continue
case ',':
if dec.tokenState == tokenArrayComma {
dec.scanp++
dec.tokenState = tokenArrayValue
continue
}
if dec.tokenState == tokenObjectComma {
dec.scanp++
dec.tokenState = tokenObjectKey
continue
}
return dec.tokenError(c)
case '"':
if dec.tokenState == tokenObjectStart || dec.tokenState == tokenObjectKey {
var x string
old := dec.tokenState
dec.tokenState = tokenTopValue
err := dec.Decode(&x)
dec.tokenState = old
if err != nil {
return nil, err
}
dec.tokenState = tokenObjectColon
return x, nil
}
fallthrough
default:
if !dec.tokenValueAllowed() {
return dec.tokenError(c)
}
var x any
if err := dec.Decode(&x); err != nil {
return nil, err
}
return x, nil
}
}
}
func (dec *Decoder) tokenError(c byte) (Token, error) {
var context string
switch dec.tokenState {
case tokenTopValue:
context = " looking for beginning of value"
case tokenArrayStart, tokenArrayValue, tokenObjectValue:
context = " looking for beginning of value"
case tokenArrayComma:
context = " after array element"
case tokenObjectKey:
context = " looking for beginning of object key string"
case tokenObjectColon:
context = " after object key"
case tokenObjectComma:
context = " after object key:value pair"
}
return nil, &SyntaxError{"invalid character " + quoteChar(c) + context, dec.InputOffset()}
}
// More reports whether there is another element in the
// current array or object being parsed.
func (dec *Decoder) More() bool {
c, err := dec.peek()
return err == nil && c != ']' && c != '}'
}
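// A sketch of walking a value with Token and More, without decoding the
// whole structure at once:
//
//	dec := json.NewDecoder(strings.NewReader(`[1,"two",null]`))
//	dec.Token() // Delim('[')
//	for dec.More() {
//		t, _ := dec.Token() // float64(1), then "two", then nil
//		_ = t
//	}
//	dec.Token() // Delim(']')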
func (dec *Decoder) peek() (byte, error) {
var err error
for {
for i := dec.scanp; i < len(dec.buf); i++ {
c := dec.buf[i]
if isSpace(c) {
continue
}
dec.scanp = i
return c, nil
}
// buffer has been scanned, now report any error
if err != nil {
return 0, err
}
err = dec.refill()
}
}
// InputOffset returns the input stream byte offset of the current decoder position.
// The offset gives the location of the end of the most recently returned token
// and the beginning of the next token.
func (dec *Decoder) InputOffset() int64 {
return dec.scanned + int64(dec.scanp)
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"strings"
)
// tagOptions is the string following a comma in a struct field's "json"
// tag, or the empty string. It does not include the leading comma.
type tagOptions string
// parseTag splits a struct field's json tag into its name and
// comma-separated options.
func parseTag(tag string) (string, tagOptions) {
tag, opt, _ := strings.Cut(tag, ",")
return tag, tagOptions(opt)
}
// Contains reports whether a comma-separated list of options
// contains the given optionName flag. optionName must be surrounded by a
// string boundary or commas.
func (o tagOptions) Contains(optionName string) bool {
if len(o) == 0 {
return false
}
s := string(o)
for s != "" {
var name string
name, s, _ = strings.Cut(s, ",")
if name == optionName {
return true
}
}
return false
}
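// exampleParseTag is an illustrative sketch, not part of the original
// source: it shows how parseTag and Contains cooperate on a struct tag
// value, and that options match whole names only. The tag literal and
// function name are invented; the "fmt" import is assumed.
func exampleParseTag() {
	name, opts := parseTag("field_name,omitempty,string")
	fmt.Println(name)                       // field_name
	fmt.Println(opts.Contains("omitempty")) // true
	fmt.Println(opts.Contains("omit"))      // false: no partial matches
}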
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package pem implements the PEM data encoding, which originated in Privacy
// Enhanced Mail. The most common use of PEM encoding today is in TLS keys and
// certificates. See RFC 1421.
package pem
import (
"bytes"
"encoding/base64"
"errors"
"io"
"sort"
"strings"
)
// A Block represents a PEM encoded structure.
//
// The encoded form is:
//
// -----BEGIN Type-----
// Headers
// base64-encoded Bytes
// -----END Type-----
//
// where Headers is a possibly empty sequence of Key: Value lines.
type Block struct {
Type string // The type, taken from the preamble (e.g. "RSA PRIVATE KEY").
Headers map[string]string // Optional headers.
Bytes []byte // The decoded bytes of the contents. Typically a DER encoded ASN.1 structure.
}
// getLine returns the first \r\n or \n delimited line from the given byte
// array. The line does not include trailing whitespace or the trailing new
// line bytes. The remainder of the byte array (also not including the new line
// bytes) is also returned and this will always be smaller than the original
// argument.
func getLine(data []byte) (line, rest []byte) {
i := bytes.IndexByte(data, '\n')
var j int
if i < 0 {
i = len(data)
j = i
} else {
j = i + 1
if i > 0 && data[i-1] == '\r' {
i--
}
}
return bytes.TrimRight(data[0:i], " \t"), data[j:]
}
// removeSpacesAndTabs returns a copy of its input with all spaces and tabs
// removed, if there were any. Otherwise, the input is returned unchanged.
//
// The base64 decoder already skips newline characters, so we don't need to
// filter them out here.
func removeSpacesAndTabs(data []byte) []byte {
if !bytes.ContainsAny(data, " \t") {
// Fast path; most base64 data within PEM contains newlines, but
// no spaces nor tabs. Skip the extra alloc and work.
return data
}
result := make([]byte, len(data))
n := 0
for _, b := range data {
if b == ' ' || b == '\t' {
continue
}
result[n] = b
n++
}
return result[0:n]
}
var pemStart = []byte("\n-----BEGIN ")
var pemEnd = []byte("\n-----END ")
var pemEndOfLine = []byte("-----")
var colon = []byte(":")
// Decode will find the next PEM formatted block (certificate, private key,
// etc.) in the input. It returns that block and the remainder of the input. If
// no PEM data is found, p is nil and the whole of the input is returned in
// rest.
func Decode(data []byte) (p *Block, rest []byte) {
// pemStart begins with a newline. However, at the very beginning of
// the byte array, we'll accept the start string without it.
rest = data
for {
if bytes.HasPrefix(rest, pemStart[1:]) {
rest = rest[len(pemStart)-1:]
} else if _, after, ok := bytes.Cut(rest, pemStart); ok {
rest = after
} else {
return nil, data
}
var typeLine []byte
typeLine, rest = getLine(rest)
if !bytes.HasSuffix(typeLine, pemEndOfLine) {
continue
}
typeLine = typeLine[0 : len(typeLine)-len(pemEndOfLine)]
p = &Block{
Headers: make(map[string]string),
Type: string(typeLine),
}
for {
// This loop terminates because getLine's second result is
// always smaller than its argument.
if len(rest) == 0 {
return nil, data
}
line, next := getLine(rest)
key, val, ok := bytes.Cut(line, colon)
if !ok {
break
}
// TODO(agl): need to cope with values that spread across lines.
key = bytes.TrimSpace(key)
val = bytes.TrimSpace(val)
p.Headers[string(key)] = string(val)
rest = next
}
var endIndex, endTrailerIndex int
// If there were no headers, the END line might occur
// immediately, without a leading newline.
if len(p.Headers) == 0 && bytes.HasPrefix(rest, pemEnd[1:]) {
endIndex = 0
endTrailerIndex = len(pemEnd) - 1
} else {
endIndex = bytes.Index(rest, pemEnd)
endTrailerIndex = endIndex + len(pemEnd)
}
if endIndex < 0 {
continue
}
// After the "-----" of the ending line, there should be the same type
// and then a final five dashes.
endTrailer := rest[endTrailerIndex:]
endTrailerLen := len(typeLine) + len(pemEndOfLine)
if len(endTrailer) < endTrailerLen {
continue
}
restOfEndLine := endTrailer[endTrailerLen:]
endTrailer = endTrailer[:endTrailerLen]
if !bytes.HasPrefix(endTrailer, typeLine) ||
!bytes.HasSuffix(endTrailer, pemEndOfLine) {
continue
}
// The line must end with only whitespace.
if s, _ := getLine(restOfEndLine); len(s) != 0 {
continue
}
base64Data := removeSpacesAndTabs(rest[:endIndex])
p.Bytes = make([]byte, base64.StdEncoding.DecodedLen(len(base64Data)))
n, err := base64.StdEncoding.Decode(p.Bytes, base64Data)
if err != nil {
continue
}
p.Bytes = p.Bytes[:n]
// the -1 is because we might have only matched pemEnd without the
// leading newline if the PEM block was empty.
_, rest = getLine(rest[endIndex+len(pemEnd)-1:])
return p, rest
}
}
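// exampleDecodeAll is an illustrative sketch, not part of the original
// source: it walks every PEM block in a bundle, as callers of Decode
// typically do. The function name is invented; the "fmt" import is assumed.
func exampleDecodeAll(bundle []byte) {
	for {
		block, rest := Decode(bundle)
		if block == nil {
			break // rest holds the non-PEM remainder
		}
		fmt.Printf("%s: %d header(s), %d payload byte(s)\n",
			block.Type, len(block.Headers), len(block.Bytes))
		bundle = rest
	}
}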
const pemLineLength = 64
type lineBreaker struct {
line [pemLineLength]byte
used int
out io.Writer
}
var nl = []byte{'\n'}
func (l *lineBreaker) Write(b []byte) (n int, err error) {
if l.used+len(b) < pemLineLength {
copy(l.line[l.used:], b)
l.used += len(b)
return len(b), nil
}
n, err = l.out.Write(l.line[0:l.used])
if err != nil {
return
}
excess := pemLineLength - l.used
l.used = 0
n, err = l.out.Write(b[0:excess])
if err != nil {
return
}
n, err = l.out.Write(nl)
if err != nil {
return
}
return l.Write(b[excess:])
}
func (l *lineBreaker) Close() (err error) {
if l.used > 0 {
_, err = l.out.Write(l.line[0:l.used])
if err != nil {
return
}
_, err = l.out.Write(nl)
}
return
}
func writeHeader(out io.Writer, k, v string) error {
_, err := out.Write([]byte(k + ": " + v + "\n"))
return err
}
// Encode writes the PEM encoding of b to out.
func Encode(out io.Writer, b *Block) error {
// Check for invalid block before writing any output.
for k := range b.Headers {
if strings.Contains(k, ":") {
return errors.New("pem: cannot encode a header key that contains a colon")
}
}
// All errors below are relayed from the underlying io.Writer,
// so it is now safe to write data.
if _, err := out.Write(pemStart[1:]); err != nil {
return err
}
if _, err := out.Write([]byte(b.Type + "-----\n")); err != nil {
return err
}
if len(b.Headers) > 0 {
const procType = "Proc-Type"
h := make([]string, 0, len(b.Headers))
hasProcType := false
for k := range b.Headers {
if k == procType {
hasProcType = true
continue
}
h = append(h, k)
}
// The Proc-Type header must be written first.
// See RFC 1421, section 4.6.1.1
if hasProcType {
if err := writeHeader(out, procType, b.Headers[procType]); err != nil {
return err
}
}
// For consistency of output, write other headers sorted by key.
sort.Strings(h)
for _, k := range h {
if err := writeHeader(out, k, b.Headers[k]); err != nil {
return err
}
}
if _, err := out.Write(nl); err != nil {
return err
}
}
var breaker lineBreaker
breaker.out = out
b64 := base64.NewEncoder(base64.StdEncoding, &breaker)
if _, err := b64.Write(b.Bytes); err != nil {
return err
}
b64.Close()
breaker.Close()
if _, err := out.Write(pemEnd[1:]); err != nil {
return err
}
_, err := out.Write([]byte(b.Type + "-----\n"))
return err
}
// EncodeToMemory returns the PEM encoding of b.
//
// If b has invalid headers and cannot be encoded,
// EncodeToMemory returns nil. If it is important to
// report details about this error case, use Encode instead.
func EncodeToMemory(b *Block) []byte {
var buf bytes.Buffer
if err := Encode(&buf, b); err != nil {
return nil
}
return buf.Bytes()
}
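// exampleEncode is an illustrative sketch, not part of the original
// source: it builds a Block and renders it with EncodeToMemory. The block
// contents and function name are invented; the "fmt" import is assumed.
func exampleEncode() {
	out := EncodeToMemory(&Block{
		Type:    "MESSAGE",
		Headers: map[string]string{"Animal": "Gopher"},
		Bytes:   []byte("test"),
	})
	fmt.Print(string(out))
	// -----BEGIN MESSAGE-----
	// Animal: Gopher
	//
	// dGVzdA==
	// -----END MESSAGE-----
}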
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xml
import (
"bufio"
"bytes"
"encoding"
"errors"
"fmt"
"io"
"reflect"
"strconv"
"strings"
)
const (
// Header is a generic XML header suitable for use with the output of Marshal.
// This is not automatically added to any output of this package;
// it is provided as a convenience.
Header = `<?xml version="1.0" encoding="UTF-8"?>` + "\n"
)
// Marshal returns the XML encoding of v.
//
// Marshal handles an array or slice by marshaling each of the elements.
// Marshal handles a pointer by marshaling the value it points at or, if the
// pointer is nil, by writing nothing. Marshal handles an interface value by
// marshaling the value it contains or, if the interface value is nil, by
// writing nothing. Marshal handles all other data by writing one or more XML
// elements containing the data.
//
// The name for the XML elements is taken from, in order of preference:
// - the tag on the XMLName field, if the data is a struct
// - the value of the XMLName field of type Name
// - the tag of the struct field used to obtain the data
// - the name of the struct field used to obtain the data
// - the name of the marshaled type
//
// The XML element for a struct contains marshaled elements for each of the
// exported fields of the struct, with these exceptions:
// - the XMLName field, described above, is omitted.
// - a field with tag "-" is omitted.
// - a field with tag "name,attr" becomes an attribute with
// the given name in the XML element.
// - a field with tag ",attr" becomes an attribute with the
// field name in the XML element.
// - a field with tag ",chardata" is written as character data,
// not as an XML element.
// - a field with tag ",cdata" is written as character data
// wrapped in one or more <![CDATA[ ... ]]> tags, not as an XML element.
// - a field with tag ",innerxml" is written verbatim, not subject
// to the usual marshaling procedure.
// - a field with tag ",comment" is written as an XML comment, not
// subject to the usual marshaling procedure. It must not contain
// the "--" string within it.
// - a field with a tag including the "omitempty" option is omitted
// if the field value is empty. The empty values are false, 0, any
// nil pointer or interface value, and any array, slice, map, or
// string of length zero.
// - an anonymous struct field is handled as if the fields of its
// value were part of the outer struct.
// - a field implementing Marshaler is written by calling its MarshalXML
// method.
// - a field implementing encoding.TextMarshaler is written by encoding the
// result of its MarshalText method as text.
//
// If a field uses a tag "a>b>c", then the element c will be nested inside
// parent elements a and b. Fields that appear next to each other that name
// the same parent will be enclosed in one XML element.
//
// If the XML name for a struct field is defined by both the field tag and the
// struct's XMLName field, the names must match.
//
// See MarshalIndent for an example.
//
// Marshal will return an error if asked to marshal a channel, function, or map.
func Marshal(v any) ([]byte, error) {
var b bytes.Buffer
enc := NewEncoder(&b)
if err := enc.Encode(v); err != nil {
return nil, err
}
if err := enc.Close(); err != nil {
return nil, err
}
return b.Bytes(), nil
}
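// exampleMarshal is an illustrative sketch, not part of the original
// source: a struct whose tags exercise the attribute, element, and
// omitempty rules described above. The type, values, and function name
// are invented.
func exampleMarshal() {
	type Person struct {
		XMLName Name   `xml:"person"`
		Id      int    `xml:"id,attr"`
		Name    string `xml:"name"`
		Note    string `xml:"note,omitempty"`
	}
	out, err := Marshal(Person{Id: 7, Name: "Gopher"})
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Println(string(out))
	// <person id="7"><name>Gopher</name></person>
}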
// Marshaler is the interface implemented by objects that can marshal
// themselves into valid XML elements.
//
// MarshalXML encodes the receiver as zero or more XML elements.
// By convention, arrays or slices are typically encoded as a sequence
// of elements, one per entry.
// Using start as the element tag is not required, but doing so
// will enable Unmarshal to match the XML elements to the correct
// struct field.
// One common implementation strategy is to construct a separate
// value with a layout corresponding to the desired XML and then
// to encode it using e.EncodeElement.
// Another common strategy is to use repeated calls to e.EncodeToken
// to generate the XML output one token at a time.
// The sequence of encoded tokens must make up zero or more valid
// XML elements.
type Marshaler interface {
MarshalXML(e *Encoder, start StartElement) error
}
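// Animal is an illustrative sketch, not part of the original source: a
// Marshaler that encodes itself through e.EncodeElement, keeping the
// caller-provided start element so Unmarshal can still match the field.
// The type is invented.
type Animal int

func (a Animal) MarshalXML(e *Encoder, start StartElement) error {
	var s string
	switch a {
	case 1:
		s = "gopher"
	default:
		s = "unknown"
	}
	// Delegating to EncodeElement satisfies the requirement that the
	// encoded tokens form zero or more complete XML elements.
	return e.EncodeElement(s, start)
}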
// MarshalerAttr is the interface implemented by objects that can marshal
// themselves into valid XML attributes.
//
// MarshalXMLAttr returns an XML attribute with the encoded value of the receiver.
// Using name as the attribute name is not required, but doing so
// will enable Unmarshal to match the attribute to the correct
// struct field.
// If MarshalXMLAttr returns the zero attribute Attr{}, no attribute
// will be generated in the output.
// MarshalXMLAttr is used only for struct fields with the
// "attr" option in the field tag.
type MarshalerAttr interface {
MarshalXMLAttr(name Name) (Attr, error)
}
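// boolAttr is an illustrative sketch, not part of the original source: a
// MarshalerAttr that renders as "yes"/"no". Returning the zero Attr{}
// instead would suppress the attribute entirely. The type is invented.
type boolAttr bool

func (b boolAttr) MarshalXMLAttr(name Name) (Attr, error) {
	v := "no"
	if b {
		v = "yes"
	}
	return Attr{Name: name, Value: v}, nil
}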
// MarshalIndent works like Marshal, but each XML element begins on a new
// indented line that starts with prefix and is followed by one or more
// copies of indent according to the nesting depth.
func MarshalIndent(v any, prefix, indent string) ([]byte, error) {
var b bytes.Buffer
enc := NewEncoder(&b)
enc.Indent(prefix, indent)
if err := enc.Encode(v); err != nil {
return nil, err
}
if err := enc.Close(); err != nil {
return nil, err
}
return b.Bytes(), nil
}
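// exampleMarshalIndent is an illustrative sketch, not part of the original
// source: the same conversion as Marshal, pretty-printed with a two-space
// indent. The type, values, and function name are invented.
func exampleMarshalIndent() {
	type Item struct {
		Name  string `xml:"name"`
		Price int    `xml:"price"`
	}
	out, err := MarshalIndent(Item{Name: "pen", Price: 2}, "", "  ")
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Println(string(out))
	// <Item>
	//   <name>pen</name>
	//   <price>2</price>
	// </Item>
}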
// An Encoder writes XML data to an output stream.
type Encoder struct {
p printer
}
// NewEncoder returns a new encoder that writes to w.
func NewEncoder(w io.Writer) *Encoder {
e := &Encoder{printer{w: bufio.NewWriter(w)}}
e.p.encoder = e
return e
}
// Indent sets the encoder to generate XML in which each element
// begins on a new indented line that starts with prefix and is followed by
// one or more copies of indent according to the nesting depth.
func (enc *Encoder) Indent(prefix, indent string) {
enc.p.prefix = prefix
enc.p.indent = indent
}
// Encode writes the XML encoding of v to the stream.
//
// See the documentation for Marshal for details about the conversion
// of Go values to XML.
//
// Encode calls Flush before returning.
func (enc *Encoder) Encode(v any) error {
err := enc.p.marshalValue(reflect.ValueOf(v), nil, nil)
if err != nil {
return err
}
return enc.p.w.Flush()
}
// EncodeElement writes the XML encoding of v to the stream,
// using start as the outermost tag in the encoding.
//
// See the documentation for Marshal for details about the conversion
// of Go values to XML.
//
// EncodeElement calls Flush before returning.
func (enc *Encoder) EncodeElement(v any, start StartElement) error {
err := enc.p.marshalValue(reflect.ValueOf(v), nil, &start)
if err != nil {
return err
}
return enc.p.w.Flush()
}
var (
begComment = []byte("<!--")
endComment = []byte("-->")
endProcInst = []byte("?>")
)
// EncodeToken writes the given XML token to the stream.
// It returns an error if StartElement and EndElement tokens are not properly matched.
//
// EncodeToken does not call Flush, because usually it is part of a larger operation
// such as Encode or EncodeElement (or a custom Marshaler's MarshalXML invoked
// during those), and those will call Flush when finished.
// Callers that create an Encoder and then invoke EncodeToken directly, without
// using Encode or EncodeElement, need to call Flush when finished to ensure
// that the XML is written to the underlying writer.
//
// EncodeToken allows writing a ProcInst with Target set to "xml" only as the first token
// in the stream.
func (enc *Encoder) EncodeToken(t Token) error {
p := &enc.p
switch t := t.(type) {
case StartElement:
if err := p.writeStart(&t); err != nil {
return err
}
case EndElement:
if err := p.writeEnd(t.Name); err != nil {
return err
}
case CharData:
escapeText(p, t, false)
case Comment:
if bytes.Contains(t, endComment) {
return fmt.Errorf("xml: EncodeToken of Comment containing --> marker")
}
p.WriteString("<!--")
p.Write(t)
p.WriteString("-->")
return p.cachedWriteError()
case ProcInst:
// A ProcInst with the target "xml" is an XML declaration, and an XML
// declaration is only allowed as the very first token in the stream.
if t.Target == "xml" && p.w.Buffered() != 0 {
return fmt.Errorf("xml: EncodeToken of ProcInst xml target only valid for xml declaration, first token encoded")
}
if !isNameString(t.Target) {
return fmt.Errorf("xml: EncodeToken of ProcInst with invalid Target")
}
if bytes.Contains(t.Inst, endProcInst) {
return fmt.Errorf("xml: EncodeToken of ProcInst containing ?> marker")
}
p.WriteString("<?")
p.WriteString(t.Target)
if len(t.Inst) > 0 {
p.WriteByte(' ')
p.Write(t.Inst)
}
p.WriteString("?>")
case Directive:
if !isValidDirective(t) {
return fmt.Errorf("xml: EncodeToken of Directive containing wrong < or > markers")
}
p.WriteString("<!")
p.Write(t)
p.WriteString(">")
default:
return fmt.Errorf("xml: EncodeToken of invalid token type")
}
return p.cachedWriteError()
}
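// exampleEncodeToken is an illustrative sketch, not part of the original
// source: emitting XML one token at a time, with the explicit Flush that
// the EncodeToken documentation calls for. The element name and function
// name are invented.
func exampleEncodeToken() {
	var buf bytes.Buffer
	enc := NewEncoder(&buf)
	start := StartElement{Name: Name{Local: "greeting"}}
	for _, tok := range []Token{start, CharData("hello"), start.End()} {
		if err := enc.EncodeToken(tok); err != nil {
			fmt.Println("error:", err)
			return
		}
	}
	// EncodeToken does not flush; without this, output may stay buffered.
	if err := enc.Flush(); err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Println(buf.String()) // <greeting>hello</greeting>
}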
// isValidDirective reports whether dir is a valid directive text,
// meaning angle brackets are matched, ignoring comments and strings.
func isValidDirective(dir Directive) bool {
var (
depth int
inquote uint8
incomment bool
)
for i, c := range dir {
switch {
case incomment:
if c == '>' {
if n := 1 + i - len(endComment); n >= 0 && bytes.Equal(dir[n:i+1], endComment) {
incomment = false
}
}
// Just ignore anything in comment
case inquote != 0:
if c == inquote {
inquote = 0
}
// Just ignore anything within quotes
case c == '\'' || c == '"':
inquote = c
case c == '<':
if i+len(begComment) < len(dir) && bytes.Equal(dir[i:i+len(begComment)], begComment) {
incomment = true
} else {
depth++
}
case c == '>':
if depth == 0 {
return false
}
depth--
}
}
return depth == 0 && inquote == 0 && !incomment
}
// Flush flushes any buffered XML to the underlying writer.
// See the EncodeToken documentation for details about when it is necessary.
func (enc *Encoder) Flush() error {
return enc.p.w.Flush()
}
// Close closes the Encoder, indicating that no more data will be written. It flushes
// any buffered XML to the underlying writer and returns an error if the
// written XML is invalid (e.g. by containing unclosed elements).
func (enc *Encoder) Close() error {
return enc.p.Close()
}
type printer struct {
w *bufio.Writer
encoder *Encoder
seq int
indent string
prefix string
depth int
indentedIn bool
putNewline bool
attrNS map[string]string // map prefix -> name space
attrPrefix map[string]string // map name space -> prefix
prefixes []string
tags []Name
closed bool
err error
}
// createAttrPrefix finds the name space prefix attribute to use for the given name space,
// defining a new prefix if necessary. It returns the prefix.
func (p *printer) createAttrPrefix(url string) string {
if prefix := p.attrPrefix[url]; prefix != "" {
return prefix
}
// The "http://www.w3.org/XML/1998/namespace" name space is predefined as "xml"
// and must be referred to that way.
// (The "http://www.w3.org/2000/xmlns/" name space is also predefined as "xmlns",
// but users should not be trying to use that one directly - that's our job.)
if url == xmlURL {
return xmlPrefix
}
// Need to define a new name space.
if p.attrPrefix == nil {
p.attrPrefix = make(map[string]string)
p.attrNS = make(map[string]string)
}
// Pick a name. We try to use the final element of the path
// but fall back to _.
prefix := strings.TrimRight(url, "/")
if i := strings.LastIndex(prefix, "/"); i >= 0 {
prefix = prefix[i+1:]
}
if prefix == "" || !isName([]byte(prefix)) || strings.Contains(prefix, ":") {
prefix = "_"
}
// xmlanything is reserved and any variant of it regardless of
// case should be matched, so:
// (('X'|'x') ('M'|'m') ('L'|'l'))
// See Section 2.3 of https://www.w3.org/TR/REC-xml/
if len(prefix) >= 3 && strings.EqualFold(prefix[:3], "xml") {
prefix = "_" + prefix
}
if p.attrNS[prefix] != "" {
// Name is taken. Find a better one.
for p.seq++; ; p.seq++ {
if id := prefix + "_" + strconv.Itoa(p.seq); p.attrNS[id] == "" {
prefix = id
break
}
}
}
p.attrPrefix[url] = prefix
p.attrNS[prefix] = url
p.WriteString(`xmlns:`)
p.WriteString(prefix)
p.WriteString(`="`)
EscapeText(p, []byte(url))
p.WriteString(`" `)
p.prefixes = append(p.prefixes, prefix)
return prefix
}
// deleteAttrPrefix removes an attribute name space prefix.
func (p *printer) deleteAttrPrefix(prefix string) {
delete(p.attrPrefix, p.attrNS[prefix])
delete(p.attrNS, prefix)
}
func (p *printer) markPrefix() {
p.prefixes = append(p.prefixes, "")
}
func (p *printer) popPrefix() {
for len(p.prefixes) > 0 {
prefix := p.prefixes[len(p.prefixes)-1]
p.prefixes = p.prefixes[:len(p.prefixes)-1]
if prefix == "" {
break
}
p.deleteAttrPrefix(prefix)
}
}
var (
marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()
marshalerAttrType = reflect.TypeOf((*MarshalerAttr)(nil)).Elem()
textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
)
// marshalValue writes one or more XML elements representing val.
// If val was obtained from a struct field, finfo must have its details.
func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplate *StartElement) error {
if startTemplate != nil && startTemplate.Name.Local == "" {
return fmt.Errorf("xml: EncodeElement of StartElement with missing name")
}
if !val.IsValid() {
return nil
}
if finfo != nil && finfo.flags&fOmitEmpty != 0 && isEmptyValue(val) {
return nil
}
// Drill into interfaces and pointers.
// This can turn into an infinite loop given a cyclic chain,
// but it matches the Go 1 behavior.
for val.Kind() == reflect.Interface || val.Kind() == reflect.Pointer {
if val.IsNil() {
return nil
}
val = val.Elem()
}
kind := val.Kind()
typ := val.Type()
// Check for marshaler.
if val.CanInterface() && typ.Implements(marshalerType) {
return p.marshalInterface(val.Interface().(Marshaler), defaultStart(typ, finfo, startTemplate))
}
if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(marshalerType) {
return p.marshalInterface(pv.Interface().(Marshaler), defaultStart(pv.Type(), finfo, startTemplate))
}
}
// Check for text marshaler.
if val.CanInterface() && typ.Implements(textMarshalerType) {
return p.marshalTextInterface(val.Interface().(encoding.TextMarshaler), defaultStart(typ, finfo, startTemplate))
}
if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
return p.marshalTextInterface(pv.Interface().(encoding.TextMarshaler), defaultStart(pv.Type(), finfo, startTemplate))
}
}
// Slices and arrays iterate over the elements. They do not have an enclosing tag.
if (kind == reflect.Slice || kind == reflect.Array) && typ.Elem().Kind() != reflect.Uint8 {
for i, n := 0, val.Len(); i < n; i++ {
if err := p.marshalValue(val.Index(i), finfo, startTemplate); err != nil {
return err
}
}
return nil
}
tinfo, err := getTypeInfo(typ)
if err != nil {
return err
}
// Create start element.
// Precedence for the XML element name is:
// 0. startTemplate
// 1. XMLName field in underlying struct;
// 2. field name/tag in the struct field; and
// 3. type name
var start StartElement
if startTemplate != nil {
start.Name = startTemplate.Name
start.Attr = append(start.Attr, startTemplate.Attr...)
} else if tinfo.xmlname != nil {
xmlname := tinfo.xmlname
if xmlname.name != "" {
start.Name.Space, start.Name.Local = xmlname.xmlns, xmlname.name
} else {
fv := xmlname.value(val, dontInitNilPointers)
if v, ok := fv.Interface().(Name); ok && v.Local != "" {
start.Name = v
}
}
}
if start.Name.Local == "" && finfo != nil {
start.Name.Space, start.Name.Local = finfo.xmlns, finfo.name
}
if start.Name.Local == "" {
name := typ.Name()
if i := strings.IndexByte(name, '['); i >= 0 {
// Truncate generic instantiation name. See issue 48318.
name = name[:i]
}
if name == "" {
return &UnsupportedTypeError{typ}
}
start.Name.Local = name
}
// Attributes
for i := range tinfo.fields {
finfo := &tinfo.fields[i]
if finfo.flags&fAttr == 0 {
continue
}
fv := finfo.value(val, dontInitNilPointers)
if finfo.flags&fOmitEmpty != 0 && (!fv.IsValid() || isEmptyValue(fv)) {
continue
}
if fv.Kind() == reflect.Interface && fv.IsNil() {
continue
}
name := Name{Space: finfo.xmlns, Local: finfo.name}
if err := p.marshalAttr(&start, name, fv); err != nil {
return err
}
}
// If an XMLName field supplied a name but no name space, while the
// enclosing element has one, reset the inherited default name space
// with an empty xmlns attribute.
if tinfo.xmlname != nil && start.Name.Space == "" &&
len(p.tags) != 0 && p.tags[len(p.tags)-1].Space != "" {
start.Attr = append(start.Attr, Attr{Name{"", xmlnsPrefix}, ""})
}
if err := p.writeStart(&start); err != nil {
return err
}
if val.Kind() == reflect.Struct {
err = p.marshalStruct(tinfo, val)
} else {
s, b, err1 := p.marshalSimple(typ, val)
if err1 != nil {
err = err1
} else if b != nil {
EscapeText(p, b)
} else {
p.EscapeString(s)
}
}
if err != nil {
return err
}
if err := p.writeEnd(start.Name); err != nil {
return err
}
return p.cachedWriteError()
}
// marshalAttr marshals an attribute with the given name and value, adding to start.Attr.
func (p *printer) marshalAttr(start *StartElement, name Name, val reflect.Value) error {
if val.CanInterface() && val.Type().Implements(marshalerAttrType) {
attr, err := val.Interface().(MarshalerAttr).MarshalXMLAttr(name)
if err != nil {
return err
}
if attr.Name.Local != "" {
start.Attr = append(start.Attr, attr)
}
return nil
}
if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(marshalerAttrType) {
attr, err := pv.Interface().(MarshalerAttr).MarshalXMLAttr(name)
if err != nil {
return err
}
if attr.Name.Local != "" {
start.Attr = append(start.Attr, attr)
}
return nil
}
}
if val.CanInterface() && val.Type().Implements(textMarshalerType) {
text, err := val.Interface().(encoding.TextMarshaler).MarshalText()
if err != nil {
return err
}
start.Attr = append(start.Attr, Attr{name, string(text)})
return nil
}
if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
text, err := pv.Interface().(encoding.TextMarshaler).MarshalText()
if err != nil {
return err
}
start.Attr = append(start.Attr, Attr{name, string(text)})
return nil
}
}
// Dereference or skip nil pointer, interface values.
switch val.Kind() {
case reflect.Pointer, reflect.Interface:
if val.IsNil() {
return nil
}
val = val.Elem()
}
// Walk slices.
if val.Kind() == reflect.Slice && val.Type().Elem().Kind() != reflect.Uint8 {
n := val.Len()
for i := 0; i < n; i++ {
if err := p.marshalAttr(start, name, val.Index(i)); err != nil {
return err
}
}
return nil
}
if val.Type() == attrType {
start.Attr = append(start.Attr, val.Interface().(Attr))
return nil
}
s, b, err := p.marshalSimple(val.Type(), val)
if err != nil {
return err
}
if b != nil {
s = string(b)
}
start.Attr = append(start.Attr, Attr{name, s})
return nil
}
// defaultStart returns the default start element to use,
// given the reflect type, field info, and start template.
func defaultStart(typ reflect.Type, finfo *fieldInfo, startTemplate *StartElement) StartElement {
var start StartElement
// Precedence for the XML element name is as above,
// except that we do not look inside structs for the first field.
if startTemplate != nil {
start.Name = startTemplate.Name
start.Attr = append(start.Attr, startTemplate.Attr...)
} else if finfo != nil && finfo.name != "" {
start.Name.Local = finfo.name
start.Name.Space = finfo.xmlns
} else if typ.Name() != "" {
start.Name.Local = typ.Name()
} else {
// Must be a pointer to a named type,
// since it has the Marshaler methods.
start.Name.Local = typ.Elem().Name()
}
return start
}
// marshalInterface marshals a Marshaler interface value.
func (p *printer) marshalInterface(val Marshaler, start StartElement) error {
// Push a marker onto the tag stack so that MarshalXML
// cannot close the XML tags that it did not open.
p.tags = append(p.tags, Name{})
n := len(p.tags)
err := val.MarshalXML(p.encoder, start)
if err != nil {
return err
}
// Make sure MarshalXML closed all its tags. p.tags[n-1] is the mark.
if len(p.tags) > n {
return fmt.Errorf("xml: %s.MarshalXML wrote invalid XML: <%s> not closed", receiverType(val), p.tags[len(p.tags)-1].Local)
}
p.tags = p.tags[:n-1]
return nil
}
// marshalTextInterface marshals a TextMarshaler interface value.
func (p *printer) marshalTextInterface(val encoding.TextMarshaler, start StartElement) error {
if err := p.writeStart(&start); err != nil {
return err
}
text, err := val.MarshalText()
if err != nil {
return err
}
EscapeText(p, text)
return p.writeEnd(start.Name)
}
// writeStart writes the given start element.
func (p *printer) writeStart(start *StartElement) error {
if start.Name.Local == "" {
return fmt.Errorf("xml: start tag with no name")
}
p.tags = append(p.tags, start.Name)
p.markPrefix()
p.writeIndent(1)
p.WriteByte('<')
p.WriteString(start.Name.Local)
if start.Name.Space != "" {
p.WriteString(` xmlns="`)
p.EscapeString(start.Name.Space)
p.WriteByte('"')
}
// Attributes
for _, attr := range start.Attr {
name := attr.Name
if name.Local == "" {
continue
}
p.WriteByte(' ')
if name.Space != "" {
p.WriteString(p.createAttrPrefix(name.Space))
p.WriteByte(':')
}
p.WriteString(name.Local)
p.WriteString(`="`)
p.EscapeString(attr.Value)
p.WriteByte('"')
}
p.WriteByte('>')
return nil
}
func (p *printer) writeEnd(name Name) error {
if name.Local == "" {
return fmt.Errorf("xml: end tag with no name")
}
if len(p.tags) == 0 || p.tags[len(p.tags)-1].Local == "" {
return fmt.Errorf("xml: end tag </%s> without start tag", name.Local)
}
if top := p.tags[len(p.tags)-1]; top != name {
if top.Local != name.Local {
return fmt.Errorf("xml: end tag </%s> does not match start tag <%s>", name.Local, top.Local)
}
return fmt.Errorf("xml: end tag </%s> in namespace %s does not match start tag <%s> in namespace %s", name.Local, name.Space, top.Local, top.Space)
}
p.tags = p.tags[:len(p.tags)-1]
p.writeIndent(-1)
p.WriteByte('<')
p.WriteByte('/')
p.WriteString(name.Local)
p.WriteByte('>')
p.popPrefix()
return nil
}
func (p *printer) marshalSimple(typ reflect.Type, val reflect.Value) (string, []byte, error) {
switch val.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return strconv.FormatInt(val.Int(), 10), nil, nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return strconv.FormatUint(val.Uint(), 10), nil, nil
case reflect.Float32, reflect.Float64:
return strconv.FormatFloat(val.Float(), 'g', -1, val.Type().Bits()), nil, nil
case reflect.String:
return val.String(), nil, nil
case reflect.Bool:
return strconv.FormatBool(val.Bool()), nil, nil
case reflect.Array:
if typ.Elem().Kind() != reflect.Uint8 {
break
}
// [...]byte
var bytes []byte
if val.CanAddr() {
bytes = val.Slice(0, val.Len()).Bytes()
} else {
bytes = make([]byte, val.Len())
reflect.Copy(reflect.ValueOf(bytes), val)
}
return "", bytes, nil
case reflect.Slice:
if typ.Elem().Kind() != reflect.Uint8 {
break
}
// []byte
return "", val.Bytes(), nil
}
return "", nil, &UnsupportedTypeError{typ}
}
var ddBytes = []byte("--")
// indirect drills into interfaces and pointers, returning the pointed-at value.
// If it encounters a nil interface or pointer, indirect returns that nil value.
// This can turn into an infinite loop given a cyclic chain,
// but it matches the Go 1 behavior.
func indirect(vf reflect.Value) reflect.Value {
for vf.Kind() == reflect.Interface || vf.Kind() == reflect.Pointer {
if vf.IsNil() {
return vf
}
vf = vf.Elem()
}
return vf
}
func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
s := parentStack{p: p}
for i := range tinfo.fields {
finfo := &tinfo.fields[i]
if finfo.flags&fAttr != 0 {
continue
}
vf := finfo.value(val, dontInitNilPointers)
if !vf.IsValid() {
// The field is behind an anonymous struct field that's
// nil. Skip it.
continue
}
switch finfo.flags & fMode {
case fCDATA, fCharData:
emit := EscapeText
if finfo.flags&fMode == fCDATA {
emit = emitCDATA
}
if err := s.trim(finfo.parents); err != nil {
return err
}
if vf.CanInterface() && vf.Type().Implements(textMarshalerType) {
data, err := vf.Interface().(encoding.TextMarshaler).MarshalText()
if err != nil {
return err
}
if err := emit(p, data); err != nil {
return err
}
continue
}
if vf.CanAddr() {
pv := vf.Addr()
if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
data, err := pv.Interface().(encoding.TextMarshaler).MarshalText()
if err != nil {
return err
}
if err := emit(p, data); err != nil {
return err
}
continue
}
}
var scratch [64]byte
vf = indirect(vf)
switch vf.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if err := emit(p, strconv.AppendInt(scratch[:0], vf.Int(), 10)); err != nil {
return err
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
if err := emit(p, strconv.AppendUint(scratch[:0], vf.Uint(), 10)); err != nil {
return err
}
case reflect.Float32, reflect.Float64:
if err := emit(p, strconv.AppendFloat(scratch[:0], vf.Float(), 'g', -1, vf.Type().Bits())); err != nil {
return err
}
case reflect.Bool:
if err := emit(p, strconv.AppendBool(scratch[:0], vf.Bool())); err != nil {
return err
}
case reflect.String:
if err := emit(p, []byte(vf.String())); err != nil {
return err
}
case reflect.Slice:
if elem, ok := vf.Interface().([]byte); ok {
if err := emit(p, elem); err != nil {
return err
}
}
}
continue
case fComment:
if err := s.trim(finfo.parents); err != nil {
return err
}
vf = indirect(vf)
k := vf.Kind()
if !(k == reflect.String || k == reflect.Slice && vf.Type().Elem().Kind() == reflect.Uint8) {
return fmt.Errorf("xml: bad type for comment field of %s", val.Type())
}
if vf.Len() == 0 {
continue
}
p.writeIndent(0)
p.WriteString("<!--")
dashDash := false
dashLast := false
switch k {
case reflect.String:
s := vf.String()
dashDash = strings.Contains(s, "--")
dashLast = s[len(s)-1] == '-'
if !dashDash {
p.WriteString(s)
}
case reflect.Slice:
b := vf.Bytes()
dashDash = bytes.Contains(b, ddBytes)
dashLast = b[len(b)-1] == '-'
if !dashDash {
p.Write(b)
}
default:
panic("can't happen")
}
if dashDash {
return fmt.Errorf(`xml: comments must not contain "--"`)
}
if dashLast {
// "--->" is invalid grammar. Make it "- -->"
p.WriteByte(' ')
}
p.WriteString("-->")
continue
case fInnerXML:
vf = indirect(vf)
iface := vf.Interface()
switch raw := iface.(type) {
case []byte:
p.Write(raw)
continue
case string:
p.WriteString(raw)
continue
}
case fElement, fElement | fAny:
if err := s.trim(finfo.parents); err != nil {
return err
}
if len(finfo.parents) > len(s.stack) {
if vf.Kind() != reflect.Pointer && vf.Kind() != reflect.Interface || !vf.IsNil() {
if err := s.push(finfo.parents[len(s.stack):]); err != nil {
return err
}
}
}
}
if err := p.marshalValue(vf, finfo, nil); err != nil {
return err
}
}
s.trim(nil)
return p.cachedWriteError()
}
// Write implements io.Writer.
func (p *printer) Write(b []byte) (n int, err error) {
if p.closed && p.err == nil {
p.err = errors.New("use of closed Encoder")
}
if p.err == nil {
n, p.err = p.w.Write(b)
}
return n, p.err
}
// WriteString implements io.StringWriter.
func (p *printer) WriteString(s string) (n int, err error) {
if p.closed && p.err == nil {
p.err = errors.New("use of closed Encoder")
}
if p.err == nil {
n, p.err = p.w.WriteString(s)
}
return n, p.err
}
// WriteByte implements io.ByteWriter.
func (p *printer) WriteByte(c byte) error {
if p.closed && p.err == nil {
p.err = errors.New("use of closed Encoder")
}
if p.err == nil {
p.err = p.w.WriteByte(c)
}
return p.err
}
// Close closes the printer on behalf of the Encoder, indicating that no more
// data will be written. It flushes any buffered XML to the underlying writer
// and returns an error if the written XML is invalid (e.g. by containing
// unclosed elements).
func (p *printer) Close() error {
if p.closed {
return nil
}
p.closed = true
if err := p.w.Flush(); err != nil {
return err
}
if len(p.tags) > 0 {
return fmt.Errorf("unclosed tag <%s>", p.tags[len(p.tags)-1].Local)
}
return nil
}
// cachedWriteError returns the bufio.Writer's cached write error.
func (p *printer) cachedWriteError() error {
_, err := p.Write(nil)
return err
}
func (p *printer) writeIndent(depthDelta int) {
if len(p.prefix) == 0 && len(p.indent) == 0 {
return
}
if depthDelta < 0 {
p.depth--
if p.indentedIn {
p.indentedIn = false
return
}
p.indentedIn = false
}
if p.putNewline {
p.WriteByte('\n')
} else {
p.putNewline = true
}
if len(p.prefix) > 0 {
p.WriteString(p.prefix)
}
if len(p.indent) > 0 {
for i := 0; i < p.depth; i++ {
p.WriteString(p.indent)
}
}
if depthDelta > 0 {
p.depth++
p.indentedIn = true
}
}
type parentStack struct {
p *printer
stack []string
}
// trim updates the XML context to match the longest common prefix of the stack
// and the given parents. A closing tag will be written for every parent
// popped. Passing a zero slice or nil will close all the elements.
func (s *parentStack) trim(parents []string) error {
split := 0
for ; split < len(parents) && split < len(s.stack); split++ {
if parents[split] != s.stack[split] {
break
}
}
for i := len(s.stack) - 1; i >= split; i-- {
if err := s.p.writeEnd(Name{Local: s.stack[i]}); err != nil {
return err
}
}
s.stack = s.stack[:split]
return nil
}
// push adds parent elements to the stack and writes open tags.
func (s *parentStack) push(parents []string) error {
for i := 0; i < len(parents); i++ {
if err := s.p.writeStart(&StartElement{Name: Name{Local: parents[i]}}); err != nil {
return err
}
}
s.stack = append(s.stack, parents...)
return nil
}
// UnsupportedTypeError is returned when Marshal encounters a type
// that cannot be converted into XML.
type UnsupportedTypeError struct {
Type reflect.Type
}
func (e *UnsupportedTypeError) Error() string {
return "xml: unsupported type: " + e.Type.String()
}
func isEmptyValue(v reflect.Value) bool {
switch v.Kind() {
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
return v.Len() == 0
case reflect.Bool:
return !v.Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return v.Uint() == 0
case reflect.Float32, reflect.Float64:
return v.Float() == 0
case reflect.Interface, reflect.Pointer:
return v.IsNil()
}
return false
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xml
import (
"bytes"
"encoding"
"errors"
"fmt"
"reflect"
"runtime"
"strconv"
"strings"
)
// BUG(rsc): Mapping between XML elements and data structures is inherently flawed:
// an XML element is an order-dependent collection of anonymous
// values, while a data structure is an order-independent collection
// of named values.
// See package json for a textual representation more suitable
// to data structures.
// Unmarshal parses the XML-encoded data and stores the result in
// the value pointed to by v, which must be an arbitrary struct,
// slice, or string. Well-formed data that does not fit into v is
// discarded.
//
// Because Unmarshal uses the reflect package, it can only assign
// to exported (upper case) fields. Unmarshal uses a case-sensitive
// comparison to match XML element names to tag values and struct
// field names.
//
// Unmarshal maps an XML element to a struct using the following rules.
// In the rules, the tag of a field refers to the value associated with the
// key 'xml' in the struct field's tag (see the example above).
//
// - If the struct has a field of type []byte or string with tag
// ",innerxml", Unmarshal accumulates the raw XML nested inside the
// element in that field. The rest of the rules still apply.
//
// - If the struct has a field named XMLName of type Name,
// Unmarshal records the element name in that field.
//
// - If the XMLName field has an associated tag of the form
// "name" or "namespace-URL name", the XML element must have
// the given name (and, optionally, name space) or else Unmarshal
// returns an error.
//
// - If the XML element has an attribute whose name matches a
// struct field name with an associated tag containing ",attr" or
// the explicit name in a struct field tag of the form "name,attr",
// Unmarshal records the attribute value in that field.
//
// - If the XML element has an attribute not handled by the previous
// rule and the struct has a field with an associated tag containing
// ",any,attr", Unmarshal records the attribute value in the first
// such field.
//
// - If the XML element contains character data, that data is
// accumulated in the first struct field that has tag ",chardata".
// The struct field may have type []byte or string.
// If there is no such field, the character data is discarded.
//
// - If the XML element contains comments, they are accumulated in
// the first struct field that has tag ",comment". The struct
// field may have type []byte or string. If there is no such
// field, the comments are discarded.
//
// - If the XML element contains a sub-element whose name matches
// the prefix of a tag formatted as "a" or "a>b>c", unmarshal
// will descend into the XML structure looking for elements with the
// given names, and will map the innermost elements to that struct
// field. A tag starting with ">" is equivalent to one starting
// with the field name followed by ">".
//
// - If the XML element contains a sub-element whose name matches
// a struct field's XMLName tag and the struct field has no
// explicit name tag as per the previous rule, unmarshal maps
// the sub-element to that struct field.
//
// - If the XML element contains a sub-element whose name matches a
// field without any mode flags (",attr", ",chardata", etc), Unmarshal
// maps the sub-element to that struct field.
//
// - If the XML element contains a sub-element that hasn't matched any
// of the above rules and the struct has a field with tag ",any",
// unmarshal maps the sub-element to that struct field.
//
// - An anonymous struct field is handled as if the fields of its
// value were part of the outer struct.
//
// - A struct field with tag "-" is never unmarshaled into.
//
// If Unmarshal encounters a field type that implements the Unmarshaler
// interface, Unmarshal calls its UnmarshalXML method to produce the value from
// the XML element. Otherwise, if the value implements
// encoding.TextUnmarshaler, Unmarshal calls that value's UnmarshalText method.
//
// Unmarshal maps an XML element to a string or []byte by saving the
// concatenation of that element's character data in the string or
// []byte. The saved []byte is never nil.
//
// Unmarshal maps an attribute value to a string or []byte by saving
// the value in the string or slice.
//
// Unmarshal maps an attribute value to an Attr by saving the attribute,
// including its name, in the Attr.
//
// Unmarshal maps an XML element or attribute value to a slice by
// extending the length of the slice and mapping the element or attribute
// to the newly created value.
//
// Unmarshal maps an XML element or attribute value to a bool by
// setting it to the boolean value represented by the string. Whitespace
// is trimmed and ignored.
//
// Unmarshal maps an XML element or attribute value to an integer or
// floating-point field by setting the field to the result of
// interpreting the string value in decimal. There is no check for
// overflow. Whitespace is trimmed and ignored.
//
// Unmarshal maps an XML element to a Name by recording the element
// name.
//
// Unmarshal maps an XML element to a pointer by setting the pointer
// to a freshly allocated value and then mapping the element to that value.
//
// A missing element or empty attribute value will be unmarshaled as a zero value.
// If the field is a slice, a zero value will be appended to the field. Otherwise, the
// field will be set to its zero value.
func Unmarshal(data []byte, v any) error {
return NewDecoder(bytes.NewReader(data)).Decode(v)
}
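// exampleUnmarshal is an illustrative sketch, not part of the original
// source: it applies the field-matching rules above to a small document.
// The type, input, and function name are invented.
func exampleUnmarshal() {
	var email struct {
		XMLName Name   `xml:"email"`
		To      string `xml:"to,attr"`
		Body    string `xml:"body"`
	}
	data := []byte(`<email to="gopher"><body>hi</body></email>`)
	if err := Unmarshal(data, &email); err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Println(email.To, email.Body) // gopher hi
}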
// Decode works like Unmarshal, except it reads the decoder
// stream to find the start element.
func (d *Decoder) Decode(v any) error {
return d.DecodeElement(v, nil)
}
// DecodeElement works like Unmarshal except that it takes
// a pointer to the start XML element to decode into v.
// It is useful when a client reads some raw XML tokens itself
// but also wants to defer to Unmarshal for some elements.
func (d *Decoder) DecodeElement(v any, start *StartElement) error {
val := reflect.ValueOf(v)
if val.Kind() != reflect.Pointer {
return errors.New("non-pointer passed to Unmarshal")
}
if val.IsNil() {
return errors.New("nil pointer passed to Unmarshal")
}
return d.unmarshal(val.Elem(), start, 0)
}
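// exampleDecodeElement is an illustrative sketch, not part of the original
// source: client code reads raw tokens itself and defers to DecodeElement
// for the elements it cares about. The input, element name, and function
// name are invented.
func exampleDecodeElement() {
	d := NewDecoder(strings.NewReader(`<list><item>x</item><item>y</item></list>`))
	var items []string
	for {
		tok, err := d.Token()
		if err != nil {
			break // io.EOF ends the stream
		}
		if se, ok := tok.(StartElement); ok && se.Name.Local == "item" {
			var s string
			if err := d.DecodeElement(&s, &se); err != nil {
				fmt.Println("error:", err)
				return
			}
			items = append(items, s)
		}
	}
	fmt.Println(items) // [x y]
}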
// An UnmarshalError represents an error in the unmarshaling process.
type UnmarshalError string
func (e UnmarshalError) Error() string { return string(e) }
// Unmarshaler is the interface implemented by objects that can unmarshal
// an XML element description of themselves.
//
// UnmarshalXML decodes a single XML element
// beginning with the given start element.
// If it returns an error, the outer call to Unmarshal stops and
// returns that error.
// UnmarshalXML must consume exactly one XML element.
// One common implementation strategy is to unmarshal into
// a separate value with a layout matching the expected XML
// using d.DecodeElement, and then to copy the data from
// that value into the receiver.
// Another common strategy is to use d.Token to process the
// XML object one token at a time.
// UnmarshalXML may not use d.RawToken.
type Unmarshaler interface {
UnmarshalXML(d *Decoder, start StartElement) error
}
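// yesNo is an illustrative sketch, not part of the original source: an
// Unmarshaler that decodes into an auxiliary value with d.DecodeElement,
// which consumes exactly the one element required by the contract above.
// The type is invented.
type yesNo bool

func (y *yesNo) UnmarshalXML(d *Decoder, start StartElement) error {
	var s string
	if err := d.DecodeElement(&s, &start); err != nil {
		return err
	}
	*y = s == "yes"
	return nil
}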
// UnmarshalerAttr is the interface implemented by objects that can unmarshal
// an XML attribute description of themselves.
//
// UnmarshalXMLAttr decodes a single XML attribute.
// If it returns an error, the outer call to Unmarshal stops and
// returns that error.
// UnmarshalXMLAttr is used only for struct fields with the
// "attr" option in the field tag.
type UnmarshalerAttr interface {
UnmarshalXMLAttr(attr Attr) error
}
// receiverType returns the receiver type to use in an expression like "%s.MethodName".
func receiverType(val any) string {
t := reflect.TypeOf(val)
if t.Name() != "" {
return t.String()
}
return "(" + t.String() + ")"
}
// unmarshalInterface unmarshals a single XML element into val.
// start is the opening tag of the element.
func (d *Decoder) unmarshalInterface(val Unmarshaler, start *StartElement) error {
// Record that decoder must stop at end tag corresponding to start.
d.pushEOF()
d.unmarshalDepth++
err := val.UnmarshalXML(d, *start)
d.unmarshalDepth--
if err != nil {
d.popEOF()
return err
}
if !d.popEOF() {
return fmt.Errorf("xml: %s.UnmarshalXML did not consume entire <%s> element", receiverType(val), start.Name.Local)
}
return nil
}
// unmarshalTextInterface unmarshals a single XML element into val.
// The chardata contained in the element (but not its children)
// is passed to the text unmarshaler.
func (d *Decoder) unmarshalTextInterface(val encoding.TextUnmarshaler) error {
var buf []byte
depth := 1
for depth > 0 {
t, err := d.Token()
if err != nil {
return err
}
switch t := t.(type) {
case CharData:
if depth == 1 {
buf = append(buf, t...)
}
case StartElement:
depth++
case EndElement:
depth--
}
}
return val.UnmarshalText(buf)
}
// unmarshalAttr unmarshals a single XML attribute into val.
func (d *Decoder) unmarshalAttr(val reflect.Value, attr Attr) error {
if val.Kind() == reflect.Pointer {
if val.IsNil() {
val.Set(reflect.New(val.Type().Elem()))
}
val = val.Elem()
}
if val.CanInterface() && val.Type().Implements(unmarshalerAttrType) {
// This is an unmarshaler with a non-pointer receiver,
// so it's likely to be incorrect, but we do what we're told.
return val.Interface().(UnmarshalerAttr).UnmarshalXMLAttr(attr)
}
if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(unmarshalerAttrType) {
return pv.Interface().(UnmarshalerAttr).UnmarshalXMLAttr(attr)
}
}
// Not an UnmarshalerAttr; try encoding.TextUnmarshaler.
if val.CanInterface() && val.Type().Implements(textUnmarshalerType) {
// This is an unmarshaler with a non-pointer receiver,
// so it's likely to be incorrect, but we do what we're told.
return val.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(attr.Value))
}
if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) {
return pv.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(attr.Value))
}
}
if val.Type().Kind() == reflect.Slice && val.Type().Elem().Kind() != reflect.Uint8 {
// Slice of element values.
// Grow slice.
n := val.Len()
val.Set(reflect.Append(val, reflect.Zero(val.Type().Elem())))
// Recur to read element into slice.
if err := d.unmarshalAttr(val.Index(n), attr); err != nil {
val.SetLen(n)
return err
}
return nil
}
if val.Type() == attrType {
val.Set(reflect.ValueOf(attr))
return nil
}
return copyValue(val, []byte(attr.Value))
}
var (
attrType = reflect.TypeOf(Attr{})
unmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
unmarshalerAttrType = reflect.TypeOf((*UnmarshalerAttr)(nil)).Elem()
textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
)
const (
maxUnmarshalDepth = 10000
maxUnmarshalDepthWasm = 5000 // go.dev/issue/56498
)
var errUnmarshalDepth = errors.New("exceeded max depth")
// unmarshal parses a single XML element into val.
func (d *Decoder) unmarshal(val reflect.Value, start *StartElement, depth int) error {
if depth >= maxUnmarshalDepth || runtime.GOARCH == "wasm" && depth >= maxUnmarshalDepthWasm {
return errUnmarshalDepth
}
// Find start element if we need it.
if start == nil {
for {
tok, err := d.Token()
if err != nil {
return err
}
if t, ok := tok.(StartElement); ok {
start = &t
break
}
}
}
// Load value from interface, but only if the result will be
// usefully addressable.
if val.Kind() == reflect.Interface && !val.IsNil() {
e := val.Elem()
if e.Kind() == reflect.Pointer && !e.IsNil() {
val = e
}
}
if val.Kind() == reflect.Pointer {
if val.IsNil() {
val.Set(reflect.New(val.Type().Elem()))
}
val = val.Elem()
}
if val.CanInterface() && val.Type().Implements(unmarshalerType) {
// This is an unmarshaler with a non-pointer receiver,
// so it's likely to be incorrect, but we do what we're told.
return d.unmarshalInterface(val.Interface().(Unmarshaler), start)
}
if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(unmarshalerType) {
return d.unmarshalInterface(pv.Interface().(Unmarshaler), start)
}
}
if val.CanInterface() && val.Type().Implements(textUnmarshalerType) {
return d.unmarshalTextInterface(val.Interface().(encoding.TextUnmarshaler))
}
if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) {
return d.unmarshalTextInterface(pv.Interface().(encoding.TextUnmarshaler))
}
}
var (
data []byte
saveData reflect.Value
comment []byte
saveComment reflect.Value
saveXML reflect.Value
saveXMLIndex int
saveXMLData []byte
saveAny reflect.Value
sv reflect.Value
tinfo *typeInfo
err error
)
switch v := val; v.Kind() {
default:
return errors.New("unknown type " + v.Type().String())
case reflect.Interface:
// TODO: For now, simply ignore the field. In the near
// future we may choose to unmarshal the start
// element on it, if not nil.
return d.Skip()
case reflect.Slice:
typ := v.Type()
if typ.Elem().Kind() == reflect.Uint8 {
// []byte
saveData = v
break
}
// Slice of element values.
// Grow slice.
n := v.Len()
v.Set(reflect.Append(val, reflect.Zero(v.Type().Elem())))
// Recur to read element into slice.
if err := d.unmarshal(v.Index(n), start, depth+1); err != nil {
v.SetLen(n)
return err
}
return nil
case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.String:
saveData = v
case reflect.Struct:
typ := v.Type()
if typ == nameType {
v.Set(reflect.ValueOf(start.Name))
break
}
sv = v
tinfo, err = getTypeInfo(typ)
if err != nil {
return err
}
// Validate and assign element name.
if tinfo.xmlname != nil {
finfo := tinfo.xmlname
if finfo.name != "" && finfo.name != start.Name.Local {
return UnmarshalError("expected element type <" + finfo.name + "> but have <" + start.Name.Local + ">")
}
if finfo.xmlns != "" && finfo.xmlns != start.Name.Space {
e := "expected element <" + finfo.name + "> in name space " + finfo.xmlns + " but have "
if start.Name.Space == "" {
e += "no name space"
} else {
e += start.Name.Space
}
return UnmarshalError(e)
}
fv := finfo.value(sv, initNilPointers)
if _, ok := fv.Interface().(Name); ok {
fv.Set(reflect.ValueOf(start.Name))
}
}
// Assign attributes.
for _, a := range start.Attr {
handled := false
any := -1
for i := range tinfo.fields {
finfo := &tinfo.fields[i]
switch finfo.flags & fMode {
case fAttr:
strv := finfo.value(sv, initNilPointers)
if a.Name.Local == finfo.name && (finfo.xmlns == "" || finfo.xmlns == a.Name.Space) {
if err := d.unmarshalAttr(strv, a); err != nil {
return err
}
handled = true
}
case fAny | fAttr:
if any == -1 {
any = i
}
}
}
if !handled && any >= 0 {
finfo := &tinfo.fields[any]
strv := finfo.value(sv, initNilPointers)
if err := d.unmarshalAttr(strv, a); err != nil {
return err
}
}
}
// Determine whether we need to save character data or comments.
for i := range tinfo.fields {
finfo := &tinfo.fields[i]
switch finfo.flags & fMode {
case fCDATA, fCharData:
if !saveData.IsValid() {
saveData = finfo.value(sv, initNilPointers)
}
case fComment:
if !saveComment.IsValid() {
saveComment = finfo.value(sv, initNilPointers)
}
case fAny, fAny | fElement:
if !saveAny.IsValid() {
saveAny = finfo.value(sv, initNilPointers)
}
case fInnerXML:
if !saveXML.IsValid() {
saveXML = finfo.value(sv, initNilPointers)
if d.saved == nil {
saveXMLIndex = 0
d.saved = new(bytes.Buffer)
} else {
saveXMLIndex = d.savedOffset()
}
}
}
}
}
// Find end element.
// Process sub-elements along the way.
Loop:
for {
var savedOffset int
if saveXML.IsValid() {
savedOffset = d.savedOffset()
}
tok, err := d.Token()
if err != nil {
return err
}
switch t := tok.(type) {
case StartElement:
consumed := false
if sv.IsValid() {
// unmarshalPath can call unmarshal, so we need to pass the depth through so that
// we can continue to enforce the maximum recursion limit.
consumed, err = d.unmarshalPath(tinfo, sv, nil, &t, depth)
if err != nil {
return err
}
if !consumed && saveAny.IsValid() {
consumed = true
if err := d.unmarshal(saveAny, &t, depth+1); err != nil {
return err
}
}
}
if !consumed {
if err := d.Skip(); err != nil {
return err
}
}
case EndElement:
if saveXML.IsValid() {
saveXMLData = d.saved.Bytes()[saveXMLIndex:savedOffset]
if saveXMLIndex == 0 {
d.saved = nil
}
}
break Loop
case CharData:
if saveData.IsValid() {
data = append(data, t...)
}
case Comment:
if saveComment.IsValid() {
comment = append(comment, t...)
}
}
}
if saveData.IsValid() && saveData.CanInterface() && saveData.Type().Implements(textUnmarshalerType) {
if err := saveData.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil {
return err
}
saveData = reflect.Value{}
}
if saveData.IsValid() && saveData.CanAddr() {
pv := saveData.Addr()
if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) {
if err := pv.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil {
return err
}
saveData = reflect.Value{}
}
}
if err := copyValue(saveData, data); err != nil {
return err
}
switch t := saveComment; t.Kind() {
case reflect.String:
t.SetString(string(comment))
case reflect.Slice:
t.Set(reflect.ValueOf(comment))
}
switch t := saveXML; t.Kind() {
case reflect.String:
t.SetString(string(saveXMLData))
case reflect.Slice:
if t.Type().Elem().Kind() == reflect.Uint8 {
t.Set(reflect.ValueOf(saveXMLData))
}
}
return nil
}
func copyValue(dst reflect.Value, src []byte) (err error) {
dst0 := dst
if dst.Kind() == reflect.Pointer {
if dst.IsNil() {
dst.Set(reflect.New(dst.Type().Elem()))
}
dst = dst.Elem()
}
// Save accumulated data.
switch dst.Kind() {
case reflect.Invalid:
// Probably a comment.
default:
return errors.New("cannot unmarshal into " + dst0.Type().String())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if len(src) == 0 {
dst.SetInt(0)
return nil
}
itmp, err := strconv.ParseInt(strings.TrimSpace(string(src)), 10, dst.Type().Bits())
if err != nil {
return err
}
dst.SetInt(itmp)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
if len(src) == 0 {
dst.SetUint(0)
return nil
}
utmp, err := strconv.ParseUint(strings.TrimSpace(string(src)), 10, dst.Type().Bits())
if err != nil {
return err
}
dst.SetUint(utmp)
case reflect.Float32, reflect.Float64:
if len(src) == 0 {
dst.SetFloat(0)
return nil
}
ftmp, err := strconv.ParseFloat(strings.TrimSpace(string(src)), dst.Type().Bits())
if err != nil {
return err
}
dst.SetFloat(ftmp)
case reflect.Bool:
if len(src) == 0 {
dst.SetBool(false)
return nil
}
value, err := strconv.ParseBool(strings.TrimSpace(string(src)))
if err != nil {
return err
}
dst.SetBool(value)
case reflect.String:
dst.SetString(string(src))
case reflect.Slice:
if len(src) == 0 {
// non-nil to flag presence
src = []byte{}
}
dst.SetBytes(src)
}
return nil
}
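// Illustration (added; not part of the original source): copyValue is what
// lets a ",chardata" field with a numeric type parse its text content. The
// struct and input below are hypothetical:
//
//	type Item struct {
//		Count int `xml:",chardata"`
//	}
//
//	var it Item
//	_ = Unmarshal([]byte("<Item> 42 </Item>"), &it)
//	// it.Count == 42: the text is whitespace-trimmed before
//	// strconv.ParseInt runs.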
// unmarshalPath walks down an XML structure looking for wanted
// paths, and calls unmarshal on them.
// The consumed result tells whether XML elements have been consumed
// from the Decoder until start's matching end element, or if it's
// still untouched because start is uninteresting for sv's fields.
func (d *Decoder) unmarshalPath(tinfo *typeInfo, sv reflect.Value, parents []string, start *StartElement, depth int) (consumed bool, err error) {
recurse := false
Loop:
for i := range tinfo.fields {
finfo := &tinfo.fields[i]
if finfo.flags&fElement == 0 || len(finfo.parents) < len(parents) || finfo.xmlns != "" && finfo.xmlns != start.Name.Space {
continue
}
for j := range parents {
if parents[j] != finfo.parents[j] {
continue Loop
}
}
if len(finfo.parents) == len(parents) && finfo.name == start.Name.Local {
// It's a perfect match, unmarshal the field.
return true, d.unmarshal(finfo.value(sv, initNilPointers), start, depth+1)
}
if len(finfo.parents) > len(parents) && finfo.parents[len(parents)] == start.Name.Local {
// It's a prefix for the field. Break and recurse
// since it's not ok for one field path to be itself
// the prefix for another field path.
recurse = true
// We can reuse the same slice as long as we
// don't try to append to it.
parents = finfo.parents[:len(parents)+1]
break
}
}
if !recurse {
// We have no business with this element.
return false, nil
}
// The element is not a perfect match for any field, but one
// or more fields have the path to this element as a parent
// prefix. Recurse and attempt to match these.
for {
var tok Token
tok, err = d.Token()
if err != nil {
return true, err
}
switch t := tok.(type) {
case StartElement:
// the recursion depth of unmarshalPath is limited to the path length specified
// by the struct field tag, so we don't increment the depth here.
consumed2, err := d.unmarshalPath(tinfo, sv, parents, &t, depth)
if err != nil {
return true, err
}
if !consumed2 {
if err := d.Skip(); err != nil {
return true, err
}
}
case EndElement:
return true, nil
}
}
}
// Skip reads tokens until it has consumed the end element
// matching the most recent start element already consumed,
// skipping nested structures.
// It returns nil if it finds an end element matching the start
// element; otherwise it returns an error describing the problem.
func (d *Decoder) Skip() error {
var depth int64
for {
tok, err := d.Token()
if err != nil {
return err
}
switch tok.(type) {
case StartElement:
depth++
case EndElement:
if depth == 0 {
return nil
}
depth--
}
}
}
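// A minimal usage sketch (added; the reader r and the "item" filter are
// hypothetical): when iterating tokens by hand, Skip discards an entire
// uninteresting element while keeping the stream balanced.
//
//	d := NewDecoder(r)
//	for {
//		tok, err := d.Token()
//		if err == io.EOF {
//			break
//		} else if err != nil {
//			return err
//		}
//		if se, ok := tok.(StartElement); ok && se.Name.Local != "item" {
//			if err := d.Skip(); err != nil { // consume through the matching end tag
//				return err
//			}
//		}
//	}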
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xml
import (
"fmt"
"reflect"
"strings"
"sync"
)
// typeInfo holds details for the xml representation of a type.
type typeInfo struct {
xmlname *fieldInfo
fields []fieldInfo
}
// fieldInfo holds details for the xml representation of a single field.
type fieldInfo struct {
idx []int
name string
xmlns string
flags fieldFlags
parents []string
}
type fieldFlags int
const (
fElement fieldFlags = 1 << iota
fAttr
fCDATA
fCharData
fInnerXML
fComment
fAny
fOmitEmpty
fMode = fElement | fAttr | fCDATA | fCharData | fInnerXML | fComment | fAny
xmlName = "XMLName"
)
var tinfoMap sync.Map // map[reflect.Type]*typeInfo
var nameType = reflect.TypeOf(Name{})
// getTypeInfo returns the typeInfo structure with details necessary
// for marshaling and unmarshaling typ.
func getTypeInfo(typ reflect.Type) (*typeInfo, error) {
if ti, ok := tinfoMap.Load(typ); ok {
return ti.(*typeInfo), nil
}
tinfo := &typeInfo{}
if typ.Kind() == reflect.Struct && typ != nameType {
n := typ.NumField()
for i := 0; i < n; i++ {
f := typ.Field(i)
if (!f.IsExported() && !f.Anonymous) || f.Tag.Get("xml") == "-" {
continue // Private field
}
// For embedded structs, embed its fields.
if f.Anonymous {
t := f.Type
if t.Kind() == reflect.Pointer {
t = t.Elem()
}
if t.Kind() == reflect.Struct {
inner, err := getTypeInfo(t)
if err != nil {
return nil, err
}
if tinfo.xmlname == nil {
tinfo.xmlname = inner.xmlname
}
for _, finfo := range inner.fields {
finfo.idx = append([]int{i}, finfo.idx...)
if err := addFieldInfo(typ, tinfo, &finfo); err != nil {
return nil, err
}
}
continue
}
}
finfo, err := structFieldInfo(typ, &f)
if err != nil {
return nil, err
}
if f.Name == xmlName {
tinfo.xmlname = finfo
continue
}
// Add the field if it doesn't conflict with other fields.
if err := addFieldInfo(typ, tinfo, finfo); err != nil {
return nil, err
}
}
}
ti, _ := tinfoMap.LoadOrStore(typ, tinfo)
return ti.(*typeInfo), nil
}
// structFieldInfo builds and returns a fieldInfo for f.
func structFieldInfo(typ reflect.Type, f *reflect.StructField) (*fieldInfo, error) {
finfo := &fieldInfo{idx: f.Index}
// Split the tag from the xml namespace if necessary.
tag := f.Tag.Get("xml")
if ns, t, ok := strings.Cut(tag, " "); ok {
finfo.xmlns, tag = ns, t
}
// Parse flags.
tokens := strings.Split(tag, ",")
if len(tokens) == 1 {
finfo.flags = fElement
} else {
tag = tokens[0]
for _, flag := range tokens[1:] {
switch flag {
case "attr":
finfo.flags |= fAttr
case "cdata":
finfo.flags |= fCDATA
case "chardata":
finfo.flags |= fCharData
case "innerxml":
finfo.flags |= fInnerXML
case "comment":
finfo.flags |= fComment
case "any":
finfo.flags |= fAny
case "omitempty":
finfo.flags |= fOmitEmpty
}
}
// Validate the flags used.
valid := true
switch mode := finfo.flags & fMode; mode {
case 0:
finfo.flags |= fElement
case fAttr, fCDATA, fCharData, fInnerXML, fComment, fAny, fAny | fAttr:
if f.Name == xmlName || tag != "" && mode != fAttr {
valid = false
}
default:
// This will also catch multiple modes in a single field.
valid = false
}
if finfo.flags&fMode == fAny {
finfo.flags |= fElement
}
if finfo.flags&fOmitEmpty != 0 && finfo.flags&(fElement|fAttr) == 0 {
valid = false
}
if !valid {
return nil, fmt.Errorf("xml: invalid tag in field %s of type %s: %q",
f.Name, typ, f.Tag.Get("xml"))
}
}
// Use of xmlns without a name is not allowed.
if finfo.xmlns != "" && tag == "" {
return nil, fmt.Errorf("xml: namespace without name in field %s of type %s: %q",
f.Name, typ, f.Tag.Get("xml"))
}
if f.Name == xmlName {
// The XMLName field records the XML element name. Don't
// process it as usual because its name should default to
// empty rather than to the field name.
finfo.name = tag
return finfo, nil
}
if tag == "" {
// If the name part of the tag is completely empty, get
// default from XMLName of underlying struct if feasible,
// or field name otherwise.
if xmlname := lookupXMLName(f.Type); xmlname != nil {
finfo.xmlns, finfo.name = xmlname.xmlns, xmlname.name
} else {
finfo.name = f.Name
}
return finfo, nil
}
// Prepare field name and parents.
parents := strings.Split(tag, ">")
if parents[0] == "" {
parents[0] = f.Name
}
if parents[len(parents)-1] == "" {
return nil, fmt.Errorf("xml: trailing '>' in field %s of type %s", f.Name, typ)
}
finfo.name = parents[len(parents)-1]
if len(parents) > 1 {
if (finfo.flags & fElement) == 0 {
return nil, fmt.Errorf("xml: %s chain not valid with %s flag", tag, strings.Join(tokens[1:], ","))
}
finfo.parents = parents[:len(parents)-1]
}
// If the field type has an XMLName field, the names must match
// so that the behavior of both marshaling and unmarshaling
// is straightforward and unambiguous.
if finfo.flags&fElement != 0 {
ftyp := f.Type
xmlname := lookupXMLName(ftyp)
if xmlname != nil && xmlname.name != finfo.name {
return nil, fmt.Errorf("xml: name %q in tag of %s.%s conflicts with name %q in %s.XMLName",
finfo.name, typ, f.Name, xmlname.name, ftyp)
}
}
return finfo, nil
}
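// For reference (an added summary mirroring the package documentation), the
// tag shapes structFieldInfo accepts look like:
//
//	F string `xml:"name"`           // element <name>
//	F string `xml:"ns name"`        // element <name> in name space ns
//	F string `xml:"name,attr"`      // attribute name="..."
//	F string `xml:"a>b>name"`       // element nested as <a><b><name>
//	F string `xml:",chardata"`      // character data
//	F string `xml:",cdata"`         // character data, emitted as CDATA
//	F string `xml:",innerxml"`      // raw inner XML, verbatim
//	F string `xml:",comment"`       // XML comments
//	F string `xml:"name,omitempty"` // omitted when empty on marshal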
// lookupXMLName returns the fieldInfo for typ's XMLName field
// in case it exists and has a valid xml field tag, otherwise
// it returns nil.
func lookupXMLName(typ reflect.Type) (xmlname *fieldInfo) {
for typ.Kind() == reflect.Pointer {
typ = typ.Elem()
}
if typ.Kind() != reflect.Struct {
return nil
}
for i, n := 0, typ.NumField(); i < n; i++ {
f := typ.Field(i)
if f.Name != xmlName {
continue
}
finfo, err := structFieldInfo(typ, &f)
if err == nil && finfo.name != "" {
return finfo
}
// Also consider errors as a non-existent field tag
// and let getTypeInfo itself report the error.
break
}
return nil
}
func min(a, b int) int {
if a <= b {
return a
}
return b
}
// addFieldInfo adds finfo to tinfo.fields if there are no
// conflicts, or if conflicts arise from previous fields that were
// obtained from deeper embedded structures than finfo. In the latter
// case, the conflicting entries are dropped.
// A conflict occurs when the path (parent + name) to a field is
// itself a prefix of another path, or when two paths match exactly.
// It is okay for field paths to share a common, shorter prefix.
func addFieldInfo(typ reflect.Type, tinfo *typeInfo, newf *fieldInfo) error {
var conflicts []int
Loop:
// First, figure all conflicts. Most working code will have none.
for i := range tinfo.fields {
oldf := &tinfo.fields[i]
if oldf.flags&fMode != newf.flags&fMode {
continue
}
if oldf.xmlns != "" && newf.xmlns != "" && oldf.xmlns != newf.xmlns {
continue
}
minl := min(len(newf.parents), len(oldf.parents))
for p := 0; p < minl; p++ {
if oldf.parents[p] != newf.parents[p] {
continue Loop
}
}
if len(oldf.parents) > len(newf.parents) {
if oldf.parents[len(newf.parents)] == newf.name {
conflicts = append(conflicts, i)
}
} else if len(oldf.parents) < len(newf.parents) {
if newf.parents[len(oldf.parents)] == oldf.name {
conflicts = append(conflicts, i)
}
} else {
if newf.name == oldf.name && newf.xmlns == oldf.xmlns {
conflicts = append(conflicts, i)
}
}
}
// Without conflicts, add the new field and return.
if conflicts == nil {
tinfo.fields = append(tinfo.fields, *newf)
return nil
}
// If any conflict is shallower, ignore the new field.
// This matches the Go field resolution on embedding.
for _, i := range conflicts {
if len(tinfo.fields[i].idx) < len(newf.idx) {
return nil
}
}
// Otherwise, if any of them is at the same depth level, it's an error.
for _, i := range conflicts {
oldf := &tinfo.fields[i]
if len(oldf.idx) == len(newf.idx) {
f1 := typ.FieldByIndex(oldf.idx)
f2 := typ.FieldByIndex(newf.idx)
return &TagPathError{typ, f1.Name, f1.Tag.Get("xml"), f2.Name, f2.Tag.Get("xml")}
}
}
// Otherwise, the new field is shallower, and thus takes precedence,
// so drop the conflicting fields from tinfo and append the new one.
for c := len(conflicts) - 1; c >= 0; c-- {
i := conflicts[c]
copy(tinfo.fields[i:], tinfo.fields[i+1:])
tinfo.fields = tinfo.fields[:len(tinfo.fields)-1]
}
tinfo.fields = append(tinfo.fields, *newf)
return nil
}
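// A concrete (hypothetical) example of the conflict addFieldInfo rejects:
// F1's full path "a>b" is a prefix of F2's path, and both fields sit at the
// same embedding depth, so getTypeInfo returns a *TagPathError for this type:
//
//	type T struct {
//		F1 string `xml:"a>b"`
//		F2 string `xml:"a>b>c"`
//	}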
// A TagPathError represents an error in the unmarshaling process
// caused by the use of field tags with conflicting paths.
type TagPathError struct {
Struct reflect.Type
Field1, Tag1 string
Field2, Tag2 string
}
func (e *TagPathError) Error() string {
return fmt.Sprintf("%s field %q with tag %q conflicts with field %q with tag %q", e.Struct, e.Field1, e.Tag1, e.Field2, e.Tag2)
}
const (
initNilPointers = true
dontInitNilPointers = false
)
// value returns v's field value corresponding to finfo.
// It's equivalent to v.FieldByIndex(finfo.idx), but when passed
// initNilPointers, it initializes and dereferences pointers as necessary.
// When passed dontInitNilPointers and a nil pointer is reached, the function
// returns a zero reflect.Value.
func (finfo *fieldInfo) value(v reflect.Value, shouldInitNilPointers bool) reflect.Value {
for i, x := range finfo.idx {
if i > 0 {
t := v.Type()
if t.Kind() == reflect.Pointer && t.Elem().Kind() == reflect.Struct {
if v.IsNil() {
if !shouldInitNilPointers {
return reflect.Value{}
}
v.Set(reflect.New(v.Type().Elem()))
}
v = v.Elem()
}
}
v = v.Field(x)
}
return v
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package xml implements a simple XML 1.0 parser that
// understands XML name spaces.
package xml
// References:
// Annotated XML spec: https://www.xml.com/axml/testaxml.htm
// XML name spaces: https://www.w3.org/TR/REC-xml-names/
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"strconv"
"strings"
"unicode"
"unicode/utf8"
)
// A SyntaxError represents a syntax error in the XML input stream.
type SyntaxError struct {
Msg string
Line int
}
func (e *SyntaxError) Error() string {
return "XML syntax error on line " + strconv.Itoa(e.Line) + ": " + e.Msg
}
// A Name represents an XML name (Local) annotated
// with a name space identifier (Space).
// In tokens returned by Decoder.Token, the Space identifier
// is given as a canonical URL, not the short prefix used
// in the document being parsed.
type Name struct {
Space, Local string
}
// An Attr represents an attribute in an XML element (Name=Value).
type Attr struct {
Name Name
Value string
}
// A Token is an interface holding one of the token types:
// StartElement, EndElement, CharData, Comment, ProcInst, or Directive.
type Token any
// A StartElement represents an XML start element.
type StartElement struct {
Name Name
Attr []Attr
}
// Copy creates a new copy of StartElement.
func (e StartElement) Copy() StartElement {
attrs := make([]Attr, len(e.Attr))
copy(attrs, e.Attr)
e.Attr = attrs
return e
}
// End returns the corresponding XML end element.
func (e StartElement) End() EndElement {
return EndElement{e.Name}
}
// An EndElement represents an XML end element.
type EndElement struct {
Name Name
}
// A CharData represents XML character data (raw text),
// in which XML escape sequences have been replaced by
// the characters they represent.
type CharData []byte
// Copy creates a new copy of CharData.
func (c CharData) Copy() CharData { return CharData(bytes.Clone(c)) }
// A Comment represents an XML comment of the form <!--comment-->.
// The bytes do not include the <!-- and --> comment markers.
type Comment []byte
// Copy creates a new copy of Comment.
func (c Comment) Copy() Comment { return Comment(bytes.Clone(c)) }
// A ProcInst represents an XML processing instruction of the form <?target inst?>
type ProcInst struct {
Target string
Inst []byte
}
// Copy creates a new copy of ProcInst.
func (p ProcInst) Copy() ProcInst {
p.Inst = bytes.Clone(p.Inst)
return p
}
// A Directive represents an XML directive of the form <!text>.
// The bytes do not include the <! and > markers.
type Directive []byte
// Copy creates a new copy of Directive.
func (d Directive) Copy() Directive { return Directive(bytes.Clone(d)) }
// CopyToken returns a copy of a Token.
func CopyToken(t Token) Token {
switch v := t.(type) {
case CharData:
return v.Copy()
case Comment:
return v.Copy()
case Directive:
return v.Copy()
case ProcInst:
return v.Copy()
case StartElement:
return v.Copy()
}
return t
}
// A TokenReader is anything that can decode a stream of XML tokens, including a
// Decoder.
//
// When Token encounters an error or end-of-file condition after successfully
// reading a token, it returns the token. It may return the (non-nil) error from
// the same call or return the error (and a nil token) from a subsequent call.
// An instance of this general case is that a TokenReader returning a non-nil
// token at the end of the token stream may return either io.EOF or a nil error.
// The next Read should return nil, io.EOF.
//
// Implementations of Token are discouraged from returning a nil token with a
// nil error. Callers should treat a return of nil, nil as indicating that
// nothing happened; in particular it does not indicate EOF.
type TokenReader interface {
Token() (Token, error)
}
// A Decoder represents an XML parser reading a particular input stream.
// The parser assumes that its input is encoded in UTF-8.
type Decoder struct {
// Strict defaults to true, enforcing the requirements
// of the XML specification.
// If set to false, the parser allows input containing common
// mistakes:
// * If an element is missing an end tag, the parser invents
// end tags as necessary to keep the return values from Token
// properly balanced.
// * In attribute values and character data, unknown or malformed
// character entities (sequences beginning with &) are left alone.
//
// Setting:
//
// d.Strict = false
// d.AutoClose = xml.HTMLAutoClose
// d.Entity = xml.HTMLEntity
//
// creates a parser that can handle typical HTML.
//
// Strict mode does not enforce the requirements of the XML name spaces TR.
// In particular it does not reject name space tags using undefined prefixes.
// Such tags are recorded with the unknown prefix as the name space URL.
Strict bool
// When Strict == false, AutoClose indicates a set of elements to
// consider closed immediately after they are opened, regardless
// of whether an end element is present.
AutoClose []string
// Entity can be used to map non-standard entity names to string replacements.
// The parser behaves as if these standard mappings are present in the map,
// regardless of the actual map content:
//
// "lt": "<",
// "gt": ">",
// "amp": "&",
// "apos": "'",
// "quot": `"`,
Entity map[string]string
// CharsetReader, if non-nil, defines a function to generate
// charset-conversion readers, converting from the provided
// non-UTF-8 charset into UTF-8. If CharsetReader is nil or
// returns an error, parsing stops with an error. One of the
// CharsetReader's result values must be non-nil.
CharsetReader func(charset string, input io.Reader) (io.Reader, error)
// DefaultSpace sets the default name space used for unadorned tags,
// as if the entire XML stream were wrapped in an element containing
// the attribute xmlns="DefaultSpace".
DefaultSpace string
r io.ByteReader
t TokenReader
buf bytes.Buffer
saved *bytes.Buffer
stk *stack
free *stack
needClose bool
toClose Name
nextToken Token
nextByte int
ns map[string]string
err error
line int
linestart int64
offset int64
unmarshalDepth int
}
// NewDecoder creates a new XML parser reading from r.
// If r does not implement io.ByteReader, NewDecoder will
// do its own buffering.
func NewDecoder(r io.Reader) *Decoder {
d := &Decoder{
ns: make(map[string]string),
nextByte: -1,
line: 1,
Strict: true,
}
d.switchToReader(r)
return d
}
// NewTokenDecoder creates a new XML parser using an underlying token stream.
func NewTokenDecoder(t TokenReader) *Decoder {
// Is it already a Decoder?
if d, ok := t.(*Decoder); ok {
return d
}
d := &Decoder{
ns: make(map[string]string),
t: t,
nextByte: -1,
line: 1,
Strict: true,
}
return d
}
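// A short sketch (added; the noComments type is hypothetical) of feeding
// NewTokenDecoder a custom TokenReader: this filter drops Comment tokens
// before the Decoder sees them.
//
//	type noComments struct{ tr TokenReader }
//
//	func (f noComments) Token() (Token, error) {
//		for {
//			tok, err := f.tr.Token()
//			if err != nil {
//				return nil, err
//			}
//			if _, isComment := tok.(Comment); isComment {
//				continue
//			}
//			return tok, nil
//		}
//	}
//
//	d := NewTokenDecoder(noComments{tr: NewDecoder(r)})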
// Token returns the next XML token in the input stream.
// At the end of the input stream, Token returns nil, io.EOF.
//
// Slices of bytes in the returned token data refer to the
// parser's internal buffer and remain valid only until the next
// call to Token. To acquire a copy of the bytes, call CopyToken
// or the token's Copy method.
//
// Token expands self-closing elements such as <br>
// into separate start and end elements returned by successive calls.
//
// Token guarantees that the StartElement and EndElement
// tokens it returns are properly nested and matched:
// if Token encounters an unexpected end element
// or EOF before all expected end elements,
// it will return an error.
//
// Token implements XML name spaces as described by
// https://www.w3.org/TR/REC-xml-names/. Each of the
// Name structures contained in the Token has the Space
// set to the URL identifying its name space when known.
// If Token encounters an unrecognized name space prefix,
// it uses the prefix as the Space rather than report an error.
func (d *Decoder) Token() (Token, error) {
var t Token
var err error
if d.stk != nil && d.stk.kind == stkEOF {
return nil, io.EOF
}
if d.nextToken != nil {
t = d.nextToken
d.nextToken = nil
} else {
if t, err = d.rawToken(); t == nil && err != nil {
if err == io.EOF && d.stk != nil && d.stk.kind != stkEOF {
err = d.syntaxError("unexpected EOF")
}
return nil, err
}
// We still have a token to process, so clear any
// errors (e.g. EOF) and proceed.
err = nil
}
if !d.Strict {
if t1, ok := d.autoClose(t); ok {
d.nextToken = t
t = t1
}
}
switch t1 := t.(type) {
case StartElement:
// In XML name spaces, the translations listed in the
// attributes apply to the element name and
// to the other attribute names, so process
// the translations first.
for _, a := range t1.Attr {
if a.Name.Space == xmlnsPrefix {
v, ok := d.ns[a.Name.Local]
d.pushNs(a.Name.Local, v, ok)
d.ns[a.Name.Local] = a.Value
}
if a.Name.Space == "" && a.Name.Local == xmlnsPrefix {
// Default space for untagged names
v, ok := d.ns[""]
d.pushNs("", v, ok)
d.ns[""] = a.Value
}
}
d.pushElement(t1.Name)
d.translate(&t1.Name, true)
for i := range t1.Attr {
d.translate(&t1.Attr[i].Name, false)
}
t = t1
case EndElement:
if !d.popElement(&t1) {
return nil, d.err
}
t = t1
}
return t, err
}
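// Illustration (added; hypothetical input) of the name space resolution
// described above: for the document
//
//	<x xmlns:ns="http://example.com/ns"><ns:y/></x>
//
// the StartElement for y carries Name{Space: "http://example.com/ns",
// Local: "y"} rather than the short prefix "ns".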
const (
xmlURL = "http://www.w3.org/XML/1998/namespace"
xmlnsPrefix = "xmlns"
xmlPrefix = "xml"
)
// Apply name space translation to name n.
// The default name space (for Space=="")
// applies only to element names, not to attribute names.
func (d *Decoder) translate(n *Name, isElementName bool) {
switch {
case n.Space == xmlnsPrefix:
return
case n.Space == "" && !isElementName:
return
case n.Space == xmlPrefix:
n.Space = xmlURL
case n.Space == "" && n.Local == xmlnsPrefix:
return
}
if v, ok := d.ns[n.Space]; ok {
n.Space = v
} else if n.Space == "" {
n.Space = d.DefaultSpace
}
}
func (d *Decoder) switchToReader(r io.Reader) {
// Get efficient byte at a time reader.
// Assume that if reader has its own
// ReadByte, it's efficient enough.
// Otherwise, use bufio.
if rb, ok := r.(io.ByteReader); ok {
d.r = rb
} else {
d.r = bufio.NewReader(r)
}
}
// Parsing state - stack holds old name space translations
// and the current set of open elements. The translations to pop when
// ending a given tag are *below* it on the stack, which is
// more work but forced on us by XML.
type stack struct {
next *stack
kind int
name Name
ok bool
}
const (
stkStart = iota
stkNs
stkEOF
)
func (d *Decoder) push(kind int) *stack {
s := d.free
if s != nil {
d.free = s.next
} else {
s = new(stack)
}
s.next = d.stk
s.kind = kind
d.stk = s
return s
}
func (d *Decoder) pop() *stack {
s := d.stk
if s != nil {
d.stk = s.next
s.next = d.free
d.free = s
}
return s
}
// Record that after the current element is finished
// (that element is already pushed on the stack)
// Token should return EOF until popEOF is called.
func (d *Decoder) pushEOF() {
// Walk down stack to find Start.
// It might not be the top, because there might be stkNs
// entries above it.
start := d.stk
for start.kind != stkStart {
start = start.next
}
// The stkNs entries below a start are associated with that
// element too; skip over them.
for start.next != nil && start.next.kind == stkNs {
start = start.next
}
s := d.free
if s != nil {
d.free = s.next
} else {
s = new(stack)
}
s.kind = stkEOF
s.next = start.next
start.next = s
}
// Undo a pushEOF.
// The element must have been finished, so the EOF should be at the top of the stack.
func (d *Decoder) popEOF() bool {
if d.stk == nil || d.stk.kind != stkEOF {
return false
}
d.pop()
return true
}
// Record that we are starting an element with the given name.
func (d *Decoder) pushElement(name Name) {
s := d.push(stkStart)
s.name = name
}
// Record that we are changing the value of ns[local].
// The old value is url, ok.
func (d *Decoder) pushNs(local string, url string, ok bool) {
s := d.push(stkNs)
s.name.Local = local
s.name.Space = url
s.ok = ok
}
// Creates a SyntaxError with the current line number.
func (d *Decoder) syntaxError(msg string) error {
return &SyntaxError{Msg: msg, Line: d.line}
}
// Record that we are ending an element with the given name.
// The name must match the record at the top of the stack,
// which must be a pushElement record.
// After popping the element, apply any undo records from
// the stack to restore the name translations that existed
// before we saw this element.
func (d *Decoder) popElement(t *EndElement) bool {
s := d.pop()
name := t.Name
switch {
case s == nil || s.kind != stkStart:
d.err = d.syntaxError("unexpected end element </" + name.Local + ">")
return false
case s.name.Local != name.Local:
if !d.Strict {
d.needClose = true
d.toClose = t.Name
t.Name = s.name
return true
}
d.err = d.syntaxError("element <" + s.name.Local + "> closed by </" + name.Local + ">")
return false
case s.name.Space != name.Space:
d.err = d.syntaxError("element <" + s.name.Local + "> in space " + s.name.Space +
" closed by </" + name.Local + "> in space " + name.Space)
return false
}
d.translate(&t.Name, true)
// Pop stack until a Start or EOF is on the top, undoing the
// translations that were associated with the element we just closed.
for d.stk != nil && d.stk.kind != stkStart && d.stk.kind != stkEOF {
s := d.pop()
if s.ok {
d.ns[s.name.Local] = s.name.Space
} else {
delete(d.ns, s.name.Local)
}
}
return true
}
// If the top element on the stack is autoclosing and
// t is not the end tag, invent the end tag.
func (d *Decoder) autoClose(t Token) (Token, bool) {
if d.stk == nil || d.stk.kind != stkStart {
return nil, false
}
for _, s := range d.AutoClose {
if strings.EqualFold(s, d.stk.name.Local) {
// This one should be auto closed if t doesn't close it.
et, ok := t.(EndElement)
if !ok || !strings.EqualFold(et.Name.Local, d.stk.name.Local) {
return EndElement{d.stk.name}, true
}
break
}
}
return nil, false
}
var errRawToken = errors.New("xml: cannot use RawToken from UnmarshalXML method")
// RawToken is like Token but does not verify that
// start and end elements match and does not translate
// name space prefixes to their corresponding URLs.
func (d *Decoder) RawToken() (Token, error) {
if d.unmarshalDepth > 0 {
return nil, errRawToken
}
return d.rawToken()
}
func (d *Decoder) rawToken() (Token, error) {
if d.t != nil {
return d.t.Token()
}
if d.err != nil {
return nil, d.err
}
if d.needClose {
// The last element we read was self-closing and
// we returned just the StartElement half.
// Return the EndElement half now.
d.needClose = false
return EndElement{d.toClose}, nil
}
b, ok := d.getc()
if !ok {
return nil, d.err
}
if b != '<' {
// Text section.
d.ungetc(b)
data := d.text(-1, false)
if data == nil {
return nil, d.err
}
return CharData(data), nil
}
if b, ok = d.mustgetc(); !ok {
return nil, d.err
}
switch b {
case '/':
// </: End element
var name Name
if name, ok = d.nsname(); !ok {
if d.err == nil {
d.err = d.syntaxError("expected element name after </")
}
return nil, d.err
}
d.space()
if b, ok = d.mustgetc(); !ok {
return nil, d.err
}
if b != '>' {
d.err = d.syntaxError("invalid characters between </" + name.Local + " and >")
return nil, d.err
}
return EndElement{name}, nil
case '?':
// <?: Processing instruction.
var target string
if target, ok = d.name(); !ok {
if d.err == nil {
d.err = d.syntaxError("expected target name after <?")
}
return nil, d.err
}
d.space()
d.buf.Reset()
var b0 byte
for {
if b, ok = d.mustgetc(); !ok {
return nil, d.err
}
d.buf.WriteByte(b)
if b0 == '?' && b == '>' {
break
}
b0 = b
}
data := d.buf.Bytes()
data = data[0 : len(data)-2] // chop ?>
if target == "xml" {
content := string(data)
ver := procInst("version", content)
if ver != "" && ver != "1.0" {
d.err = fmt.Errorf("xml: unsupported version %q; only version 1.0 is supported", ver)
return nil, d.err
}
enc := procInst("encoding", content)
if enc != "" && enc != "utf-8" && enc != "UTF-8" && !strings.EqualFold(enc, "utf-8") {
if d.CharsetReader == nil {
d.err = fmt.Errorf("xml: encoding %q declared but Decoder.CharsetReader is nil", enc)
return nil, d.err
}
newr, err := d.CharsetReader(enc, d.r.(io.Reader))
if err != nil {
d.err = fmt.Errorf("xml: opening charset %q: %v", enc, err)
return nil, d.err
}
if newr == nil {
panic("CharsetReader returned a nil Reader for charset " + enc)
}
d.switchToReader(newr)
}
}
return ProcInst{target, data}, nil
case '!':
// <!: Maybe comment, maybe CDATA.
if b, ok = d.mustgetc(); !ok {
return nil, d.err
}
switch b {
case '-': // <!-
// Probably <!-- for a comment.
if b, ok = d.mustgetc(); !ok {
return nil, d.err
}
if b != '-' {
d.err = d.syntaxError("invalid sequence <!- not part of <!--")
return nil, d.err
}
// Look for terminator.
d.buf.Reset()
var b0, b1 byte
for {
if b, ok = d.mustgetc(); !ok {
return nil, d.err
}
d.buf.WriteByte(b)
if b0 == '-' && b1 == '-' {
if b != '>' {
d.err = d.syntaxError(
`invalid sequence "--" not allowed in comments`)
return nil, d.err
}
break
}
b0, b1 = b1, b
}
data := d.buf.Bytes()
data = data[0 : len(data)-3] // chop -->
return Comment(data), nil
case '[': // <![
// Probably <![CDATA[.
for i := 0; i < 6; i++ {
if b, ok = d.mustgetc(); !ok {
return nil, d.err
}
if b != "CDATA["[i] {
d.err = d.syntaxError("invalid <![ sequence")
return nil, d.err
}
}
// Have <![CDATA[. Read text until ]]>.
data := d.text(-1, true)
if data == nil {
return nil, d.err
}
return CharData(data), nil
}
// Probably a directive: <!DOCTYPE ...>, <!ENTITY ...>, etc.
// We don't care, but accumulate for caller. Quoted angle
// brackets do not count for nesting.
d.buf.Reset()
d.buf.WriteByte(b)
inquote := uint8(0)
depth := 0
for {
if b, ok = d.mustgetc(); !ok {
return nil, d.err
}
if inquote == 0 && b == '>' && depth == 0 {
break
}
HandleB:
d.buf.WriteByte(b)
switch {
case b == inquote:
inquote = 0
case inquote != 0:
// in quotes, no special action
case b == '\'' || b == '"':
inquote = b
case b == '>' && inquote == 0:
depth--
case b == '<' && inquote == 0:
// Look for <!-- to begin comment.
s := "!--"
for i := 0; i < len(s); i++ {
if b, ok = d.mustgetc(); !ok {
return nil, d.err
}
if b != s[i] {
for j := 0; j < i; j++ {
d.buf.WriteByte(s[j])
}
depth++
goto HandleB
}
}
// Remove < that was written above.
d.buf.Truncate(d.buf.Len() - 1)
// Look for terminator.
var b0, b1 byte
for {
if b, ok = d.mustgetc(); !ok {
return nil, d.err
}
if b0 == '-' && b1 == '-' && b == '>' {
break
}
b0, b1 = b1, b
}
// Replace the comment with a space in the returned Directive
// body, so that markup parts that were separated by the comment
// (like a "<" and a "!") don't get joined when re-encoding the
// Directive, taking new semantic meaning.
d.buf.WriteByte(' ')
}
}
return Directive(d.buf.Bytes()), nil
}
// Must be an open element like <a href="foo">
d.ungetc(b)
var (
name Name
empty bool
attr []Attr
)
if name, ok = d.nsname(); !ok {
if d.err == nil {
d.err = d.syntaxError("expected element name after <")
}
return nil, d.err
}
attr = []Attr{}
for {
d.space()
if b, ok = d.mustgetc(); !ok {
return nil, d.err
}
if b == '/' {
empty = true
if b, ok = d.mustgetc(); !ok {
return nil, d.err
}
if b != '>' {
d.err = d.syntaxError("expected /> in element")
return nil, d.err
}
break
}
if b == '>' {
break
}
d.ungetc(b)
a := Attr{}
if a.Name, ok = d.nsname(); !ok {
if d.err == nil {
d.err = d.syntaxError("expected attribute name in element")
}
return nil, d.err
}
d.space()
if b, ok = d.mustgetc(); !ok {
return nil, d.err
}
if b != '=' {
if d.Strict {
d.err = d.syntaxError("attribute name without = in element")
return nil, d.err
}
d.ungetc(b)
a.Value = a.Name.Local
} else {
d.space()
data := d.attrval()
if data == nil {
return nil, d.err
}
a.Value = string(data)
}
attr = append(attr, a)
}
if empty {
d.needClose = true
d.toClose = name
}
return StartElement{name, attr}, nil
}
func (d *Decoder) attrval() []byte {
b, ok := d.mustgetc()
if !ok {
return nil
}
// Handle quoted attribute values
if b == '"' || b == '\'' {
return d.text(int(b), false)
}
// Handle unquoted attribute values for strict parsers
if d.Strict {
d.err = d.syntaxError("unquoted or missing attribute value in element")
return nil
}
// Handle unquoted attribute values for unstrict parsers
d.ungetc(b)
d.buf.Reset()
for {
b, ok = d.mustgetc()
if !ok {
return nil
}
// https://www.w3.org/TR/REC-html40/intro/sgmltut.html#h-3.2.2
if 'a' <= b && b <= 'z' || 'A' <= b && b <= 'Z' ||
'0' <= b && b <= '9' || b == '_' || b == ':' || b == '-' {
d.buf.WriteByte(b)
} else {
d.ungetc(b)
break
}
}
return d.buf.Bytes()
}
// Skip spaces if any
func (d *Decoder) space() {
for {
b, ok := d.getc()
if !ok {
return
}
switch b {
case ' ', '\r', '\n', '\t':
default:
d.ungetc(b)
return
}
}
}
// Read a single byte.
// If there is no byte to read, return ok==false
// and leave the error in d.err.
// Maintain line number.
func (d *Decoder) getc() (b byte, ok bool) {
if d.err != nil {
return 0, false
}
if d.nextByte >= 0 {
b = byte(d.nextByte)
d.nextByte = -1
} else {
b, d.err = d.r.ReadByte()
if d.err != nil {
return 0, false
}
if d.saved != nil {
d.saved.WriteByte(b)
}
}
if b == '\n' {
d.line++
d.linestart = d.offset + 1
}
d.offset++
return b, true
}
// InputOffset returns the input stream byte offset of the current decoder position.
// The offset gives the location of the end of the most recently returned token
// and the beginning of the next token.
func (d *Decoder) InputOffset() int64 {
return d.offset
}
// InputPos returns the line of the current decoder position and the 1 based
// input position of the line. The position gives the location of the end of the
// most recently returned token.
func (d *Decoder) InputPos() (line, column int) {
return d.line, int(d.offset-d.linestart) + 1
}
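// A small sketch (added) of using these positions for caller-side error
// reporting while scanning tokens:
//
//	tok, err := d.Token()
//	if err != nil && err != io.EOF {
//		line, col := d.InputPos()
//		return fmt.Errorf("parse error at %d:%d (byte offset %d): %w",
//			line, col, d.InputOffset(), err)
//	}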
// Return saved offset.
// If we did ungetc (nextByte >= 0), have to back up one.
func (d *Decoder) savedOffset() int {
n := d.saved.Len()
if d.nextByte >= 0 {
n--
}
return n
}
// Must read a single byte.
// If there is no byte to read,
// set d.err to SyntaxError("unexpected EOF")
// and return ok==false
func (d *Decoder) mustgetc() (b byte, ok bool) {
if b, ok = d.getc(); !ok {
if d.err == io.EOF {
d.err = d.syntaxError("unexpected EOF")
}
}
return
}
// Unread a single byte.
func (d *Decoder) ungetc(b byte) {
if b == '\n' {
d.line--
}
d.nextByte = int(b)
d.offset--
}
var entity = map[string]rune{
"lt": '<',
"gt": '>',
"amp": '&',
"apos": '\'',
"quot": '"',
}
// Read plain text section (XML calls it character data).
// If quote >= 0, we are in a quoted string and need to find the matching quote.
// If cdata == true, we are in a <![CDATA[ section and need to find ]]>.
// On failure return nil and leave the error in d.err.
func (d *Decoder) text(quote int, cdata bool) []byte {
var b0, b1 byte
var trunc int
d.buf.Reset()
Input:
for {
b, ok := d.getc()
if !ok {
if cdata {
if d.err == io.EOF {
d.err = d.syntaxError("unexpected EOF in CDATA section")
}
return nil
}
break Input
}
// <![CDATA[ section ends with ]]>.
// It is an error for ]]> to appear in ordinary text.
if b0 == ']' && b1 == ']' && b == '>' {
if cdata {
trunc = 2
break Input
}
d.err = d.syntaxError("unescaped ]]> not in CDATA section")
return nil
}
// Stop reading text if we see a <.
if b == '<' && !cdata {
if quote >= 0 {
d.err = d.syntaxError("unescaped < inside quoted string")
return nil
}
d.ungetc('<')
break Input
}
if quote >= 0 && b == byte(quote) {
break Input
}
if b == '&' && !cdata {
// Read escaped character expression up to semicolon.
// XML in all its glory allows a document to define and use
// its own character names with <!ENTITY ...> directives.
// Parsers are required to recognize lt, gt, amp, apos, and quot
// even if they have not been declared.
before := d.buf.Len()
d.buf.WriteByte('&')
var ok bool
var text string
var haveText bool
if b, ok = d.mustgetc(); !ok {
return nil
}
if b == '#' {
d.buf.WriteByte(b)
if b, ok = d.mustgetc(); !ok {
return nil
}
base := 10
if b == 'x' {
base = 16
d.buf.WriteByte(b)
if b, ok = d.mustgetc(); !ok {
return nil
}
}
start := d.buf.Len()
for '0' <= b && b <= '9' ||
base == 16 && 'a' <= b && b <= 'f' ||
base == 16 && 'A' <= b && b <= 'F' {
d.buf.WriteByte(b)
if b, ok = d.mustgetc(); !ok {
return nil
}
}
if b != ';' {
d.ungetc(b)
} else {
s := string(d.buf.Bytes()[start:])
d.buf.WriteByte(';')
n, err := strconv.ParseUint(s, base, 64)
if err == nil && n <= unicode.MaxRune {
text = string(rune(n))
haveText = true
}
}
} else {
d.ungetc(b)
if !d.readName() {
if d.err != nil {
return nil
}
}
if b, ok = d.mustgetc(); !ok {
return nil
}
if b != ';' {
d.ungetc(b)
} else {
name := d.buf.Bytes()[before+1:]
d.buf.WriteByte(';')
if isName(name) {
s := string(name)
if r, ok := entity[s]; ok {
text = string(r)
haveText = true
} else if d.Entity != nil {
text, haveText = d.Entity[s]
}
}
}
}
if haveText {
d.buf.Truncate(before)
d.buf.WriteString(text)
b0, b1 = 0, 0
continue Input
}
if !d.Strict {
b0, b1 = 0, 0
continue Input
}
ent := string(d.buf.Bytes()[before:])
if ent[len(ent)-1] != ';' {
ent += " (no semicolon)"
}
d.err = d.syntaxError("invalid character entity " + ent)
return nil
}
// We must rewrite unescaped \r and \r\n into \n.
if b == '\r' {
d.buf.WriteByte('\n')
} else if b1 == '\r' && b == '\n' {
// Skip \r\n--we already wrote \n.
} else {
d.buf.WriteByte(b)
}
b0, b1 = b1, b
}
data := d.buf.Bytes()
data = data[0 : len(data)-trunc]
// Inspect each rune for being a disallowed character.
buf := data
for len(buf) > 0 {
r, size := utf8.DecodeRune(buf)
if r == utf8.RuneError && size == 1 {
d.err = d.syntaxError("invalid UTF-8")
return nil
}
buf = buf[size:]
if !isInCharacterRange(r) {
d.err = d.syntaxError(fmt.Sprintf("illegal character code %U", r))
return nil
}
}
return data
}
// Decide whether the given rune is in the XML Character Range, per
// the Char production of https://www.xml.com/axml/testaxml.htm,
// Section 2.2 Characters.
func isInCharacterRange(r rune) (inrange bool) {
return r == 0x09 ||
r == 0x0A ||
r == 0x0D ||
r >= 0x20 && r <= 0xD7FF ||
r >= 0xE000 && r <= 0xFFFD ||
r >= 0x10000 && r <= 0x10FFFF
}
// Get name space name: name with a : stuck in the middle.
// The part before the : is the name space identifier.
func (d *Decoder) nsname() (name Name, ok bool) {
s, ok := d.name()
if !ok {
return
}
if strings.Count(s, ":") > 1 {
return name, false
} else if space, local, ok := strings.Cut(s, ":"); !ok || space == "" || local == "" {
name.Local = s
} else {
name.Space = space
name.Local = local
}
return name, true
}
// Get name: /first(first|second)*/
// Do not set d.err if the name is missing (unless unexpected EOF is received):
// let the caller provide better context.
func (d *Decoder) name() (s string, ok bool) {
d.buf.Reset()
if !d.readName() {
return "", false
}
// Now we check the characters.
b := d.buf.Bytes()
if !isName(b) {
d.err = d.syntaxError("invalid XML name: " + string(b))
return "", false
}
return string(b), true
}
// Read a name and append its bytes to d.buf.
// The name is delimited by any single-byte character not valid in names.
// All multi-byte characters are accepted; the caller must check their validity.
func (d *Decoder) readName() (ok bool) {
var b byte
if b, ok = d.mustgetc(); !ok {
return
}
if b < utf8.RuneSelf && !isNameByte(b) {
d.ungetc(b)
return false
}
d.buf.WriteByte(b)
for {
if b, ok = d.mustgetc(); !ok {
return
}
if b < utf8.RuneSelf && !isNameByte(b) {
d.ungetc(b)
break
}
d.buf.WriteByte(b)
}
return true
}
func isNameByte(c byte) bool {
return 'A' <= c && c <= 'Z' ||
'a' <= c && c <= 'z' ||
'0' <= c && c <= '9' ||
c == '_' || c == ':' || c == '.' || c == '-'
}
func isName(s []byte) bool {
if len(s) == 0 {
return false
}
c, n := utf8.DecodeRune(s)
if c == utf8.RuneError && n == 1 {
return false
}
if !unicode.Is(first, c) {
return false
}
for n < len(s) {
s = s[n:]
c, n = utf8.DecodeRune(s)
if c == utf8.RuneError && n == 1 {
return false
}
if !unicode.Is(first, c) && !unicode.Is(second, c) {
return false
}
}
return true
}
func isNameString(s string) bool {
if len(s) == 0 {
return false
}
c, n := utf8.DecodeRuneInString(s)
if c == utf8.RuneError && n == 1 {
return false
}
if !unicode.Is(first, c) {
return false
}
for n < len(s) {
s = s[n:]
c, n = utf8.DecodeRuneInString(s)
if c == utf8.RuneError && n == 1 {
return false
}
if !unicode.Is(first, c) && !unicode.Is(second, c) {
return false
}
}
return true
}
// These tables were generated by cut and paste from Appendix B of
// the XML spec at https://www.xml.com/axml/testaxml.htm
// and then reformatting. First corresponds to (Letter | '_' | ':')
// and second corresponds to NameChar.
var first = &unicode.RangeTable{
R16: []unicode.Range16{
{0x003A, 0x003A, 1},
{0x0041, 0x005A, 1},
{0x005F, 0x005F, 1},
{0x0061, 0x007A, 1},
{0x00C0, 0x00D6, 1},
{0x00D8, 0x00F6, 1},
{0x00F8, 0x00FF, 1},
{0x0100, 0x0131, 1},
{0x0134, 0x013E, 1},
{0x0141, 0x0148, 1},
{0x014A, 0x017E, 1},
{0x0180, 0x01C3, 1},
{0x01CD, 0x01F0, 1},
{0x01F4, 0x01F5, 1},
{0x01FA, 0x0217, 1},
{0x0250, 0x02A8, 1},
{0x02BB, 0x02C1, 1},
{0x0386, 0x0386, 1},
{0x0388, 0x038A, 1},
{0x038C, 0x038C, 1},
{0x038E, 0x03A1, 1},
{0x03A3, 0x03CE, 1},
{0x03D0, 0x03D6, 1},
{0x03DA, 0x03E0, 2},
{0x03E2, 0x03F3, 1},
{0x0401, 0x040C, 1},
{0x040E, 0x044F, 1},
{0x0451, 0x045C, 1},
{0x045E, 0x0481, 1},
{0x0490, 0x04C4, 1},
{0x04C7, 0x04C8, 1},
{0x04CB, 0x04CC, 1},
{0x04D0, 0x04EB, 1},
{0x04EE, 0x04F5, 1},
{0x04F8, 0x04F9, 1},
{0x0531, 0x0556, 1},
{0x0559, 0x0559, 1},
{0x0561, 0x0586, 1},
{0x05D0, 0x05EA, 1},
{0x05F0, 0x05F2, 1},
{0x0621, 0x063A, 1},
{0x0641, 0x064A, 1},
{0x0671, 0x06B7, 1},
{0x06BA, 0x06BE, 1},
{0x06C0, 0x06CE, 1},
{0x06D0, 0x06D3, 1},
{0x06D5, 0x06D5, 1},
{0x06E5, 0x06E6, 1},
{0x0905, 0x0939, 1},
{0x093D, 0x093D, 1},
{0x0958, 0x0961, 1},
{0x0985, 0x098C, 1},
{0x098F, 0x0990, 1},
{0x0993, 0x09A8, 1},
{0x09AA, 0x09B0, 1},
{0x09B2, 0x09B2, 1},
{0x09B6, 0x09B9, 1},
{0x09DC, 0x09DD, 1},
{0x09DF, 0x09E1, 1},
{0x09F0, 0x09F1, 1},
{0x0A05, 0x0A0A, 1},
{0x0A0F, 0x0A10, 1},
{0x0A13, 0x0A28, 1},
{0x0A2A, 0x0A30, 1},
{0x0A32, 0x0A33, 1},
{0x0A35, 0x0A36, 1},
{0x0A38, 0x0A39, 1},
{0x0A59, 0x0A5C, 1},
{0x0A5E, 0x0A5E, 1},
{0x0A72, 0x0A74, 1},
{0x0A85, 0x0A8B, 1},
{0x0A8D, 0x0A8D, 1},
{0x0A8F, 0x0A91, 1},
{0x0A93, 0x0AA8, 1},
{0x0AAA, 0x0AB0, 1},
{0x0AB2, 0x0AB3, 1},
{0x0AB5, 0x0AB9, 1},
{0x0ABD, 0x0AE0, 0x23},
{0x0B05, 0x0B0C, 1},
{0x0B0F, 0x0B10, 1},
{0x0B13, 0x0B28, 1},
{0x0B2A, 0x0B30, 1},
{0x0B32, 0x0B33, 1},
{0x0B36, 0x0B39, 1},
{0x0B3D, 0x0B3D, 1},
{0x0B5C, 0x0B5D, 1},
{0x0B5F, 0x0B61, 1},
{0x0B85, 0x0B8A, 1},
{0x0B8E, 0x0B90, 1},
{0x0B92, 0x0B95, 1},
{0x0B99, 0x0B9A, 1},
{0x0B9C, 0x0B9C, 1},
{0x0B9E, 0x0B9F, 1},
{0x0BA3, 0x0BA4, 1},
{0x0BA8, 0x0BAA, 1},
{0x0BAE, 0x0BB5, 1},
{0x0BB7, 0x0BB9, 1},
{0x0C05, 0x0C0C, 1},
{0x0C0E, 0x0C10, 1},
{0x0C12, 0x0C28, 1},
{0x0C2A, 0x0C33, 1},
{0x0C35, 0x0C39, 1},
{0x0C60, 0x0C61, 1},
{0x0C85, 0x0C8C, 1},
{0x0C8E, 0x0C90, 1},
{0x0C92, 0x0CA8, 1},
{0x0CAA, 0x0CB3, 1},
{0x0CB5, 0x0CB9, 1},
{0x0CDE, 0x0CDE, 1},
{0x0CE0, 0x0CE1, 1},
{0x0D05, 0x0D0C, 1},
{0x0D0E, 0x0D10, 1},
{0x0D12, 0x0D28, 1},
{0x0D2A, 0x0D39, 1},
{0x0D60, 0x0D61, 1},
{0x0E01, 0x0E2E, 1},
{0x0E30, 0x0E30, 1},
{0x0E32, 0x0E33, 1},
{0x0E40, 0x0E45, 1},
{0x0E81, 0x0E82, 1},
{0x0E84, 0x0E84, 1},
{0x0E87, 0x0E88, 1},
{0x0E8A, 0x0E8D, 3},
{0x0E94, 0x0E97, 1},
{0x0E99, 0x0E9F, 1},
{0x0EA1, 0x0EA3, 1},
{0x0EA5, 0x0EA7, 2},
{0x0EAA, 0x0EAB, 1},
{0x0EAD, 0x0EAE, 1},
{0x0EB0, 0x0EB0, 1},
{0x0EB2, 0x0EB3, 1},
{0x0EBD, 0x0EBD, 1},
{0x0EC0, 0x0EC4, 1},
{0x0F40, 0x0F47, 1},
{0x0F49, 0x0F69, 1},
{0x10A0, 0x10C5, 1},
{0x10D0, 0x10F6, 1},
{0x1100, 0x1100, 1},
{0x1102, 0x1103, 1},
{0x1105, 0x1107, 1},
{0x1109, 0x1109, 1},
{0x110B, 0x110C, 1},
{0x110E, 0x1112, 1},
{0x113C, 0x1140, 2},
{0x114C, 0x1150, 2},
{0x1154, 0x1155, 1},
{0x1159, 0x1159, 1},
{0x115F, 0x1161, 1},
{0x1163, 0x1169, 2},
{0x116D, 0x116E, 1},
{0x1172, 0x1173, 1},
{0x1175, 0x119E, 0x119E - 0x1175},
{0x11A8, 0x11AB, 0x11AB - 0x11A8},
{0x11AE, 0x11AF, 1},
{0x11B7, 0x11B8, 1},
{0x11BA, 0x11BA, 1},
{0x11BC, 0x11C2, 1},
{0x11EB, 0x11F0, 0x11F0 - 0x11EB},
{0x11F9, 0x11F9, 1},
{0x1E00, 0x1E9B, 1},
{0x1EA0, 0x1EF9, 1},
{0x1F00, 0x1F15, 1},
{0x1F18, 0x1F1D, 1},
{0x1F20, 0x1F45, 1},
{0x1F48, 0x1F4D, 1},
{0x1F50, 0x1F57, 1},
{0x1F59, 0x1F5B, 0x1F5B - 0x1F59},
{0x1F5D, 0x1F5D, 1},
{0x1F5F, 0x1F7D, 1},
{0x1F80, 0x1FB4, 1},
{0x1FB6, 0x1FBC, 1},
{0x1FBE, 0x1FBE, 1},
{0x1FC2, 0x1FC4, 1},
{0x1FC6, 0x1FCC, 1},
{0x1FD0, 0x1FD3, 1},
{0x1FD6, 0x1FDB, 1},
{0x1FE0, 0x1FEC, 1},
{0x1FF2, 0x1FF4, 1},
{0x1FF6, 0x1FFC, 1},
{0x2126, 0x2126, 1},
{0x212A, 0x212B, 1},
{0x212E, 0x212E, 1},
{0x2180, 0x2182, 1},
{0x3007, 0x3007, 1},
{0x3021, 0x3029, 1},
{0x3041, 0x3094, 1},
{0x30A1, 0x30FA, 1},
{0x3105, 0x312C, 1},
{0x4E00, 0x9FA5, 1},
{0xAC00, 0xD7A3, 1},
},
}
var second = &unicode.RangeTable{
R16: []unicode.Range16{
{0x002D, 0x002E, 1},
{0x0030, 0x0039, 1},
{0x00B7, 0x00B7, 1},
{0x02D0, 0x02D1, 1},
{0x0300, 0x0345, 1},
{0x0360, 0x0361, 1},
{0x0387, 0x0387, 1},
{0x0483, 0x0486, 1},
{0x0591, 0x05A1, 1},
{0x05A3, 0x05B9, 1},
{0x05BB, 0x05BD, 1},
{0x05BF, 0x05BF, 1},
{0x05C1, 0x05C2, 1},
{0x05C4, 0x0640, 0x0640 - 0x05C4},
{0x064B, 0x0652, 1},
{0x0660, 0x0669, 1},
{0x0670, 0x0670, 1},
{0x06D6, 0x06DC, 1},
{0x06DD, 0x06DF, 1},
{0x06E0, 0x06E4, 1},
{0x06E7, 0x06E8, 1},
{0x06EA, 0x06ED, 1},
{0x06F0, 0x06F9, 1},
{0x0901, 0x0903, 1},
{0x093C, 0x093C, 1},
{0x093E, 0x094C, 1},
{0x094D, 0x094D, 1},
{0x0951, 0x0954, 1},
{0x0962, 0x0963, 1},
{0x0966, 0x096F, 1},
{0x0981, 0x0983, 1},
{0x09BC, 0x09BC, 1},
{0x09BE, 0x09BF, 1},
{0x09C0, 0x09C4, 1},
{0x09C7, 0x09C8, 1},
{0x09CB, 0x09CD, 1},
{0x09D7, 0x09D7, 1},
{0x09E2, 0x09E3, 1},
{0x09E6, 0x09EF, 1},
{0x0A02, 0x0A3C, 0x3A},
{0x0A3E, 0x0A3F, 1},
{0x0A40, 0x0A42, 1},
{0x0A47, 0x0A48, 1},
{0x0A4B, 0x0A4D, 1},
{0x0A66, 0x0A6F, 1},
{0x0A70, 0x0A71, 1},
{0x0A81, 0x0A83, 1},
{0x0ABC, 0x0ABC, 1},
{0x0ABE, 0x0AC5, 1},
{0x0AC7, 0x0AC9, 1},
{0x0ACB, 0x0ACD, 1},
{0x0AE6, 0x0AEF, 1},
{0x0B01, 0x0B03, 1},
{0x0B3C, 0x0B3C, 1},
{0x0B3E, 0x0B43, 1},
{0x0B47, 0x0B48, 1},
{0x0B4B, 0x0B4D, 1},
{0x0B56, 0x0B57, 1},
{0x0B66, 0x0B6F, 1},
{0x0B82, 0x0B83, 1},
{0x0BBE, 0x0BC2, 1},
{0x0BC6, 0x0BC8, 1},
{0x0BCA, 0x0BCD, 1},
{0x0BD7, 0x0BD7, 1},
{0x0BE7, 0x0BEF, 1},
{0x0C01, 0x0C03, 1},
{0x0C3E, 0x0C44, 1},
{0x0C46, 0x0C48, 1},
{0x0C4A, 0x0C4D, 1},
{0x0C55, 0x0C56, 1},
{0x0C66, 0x0C6F, 1},
{0x0C82, 0x0C83, 1},
{0x0CBE, 0x0CC4, 1},
{0x0CC6, 0x0CC8, 1},
{0x0CCA, 0x0CCD, 1},
{0x0CD5, 0x0CD6, 1},
{0x0CE6, 0x0CEF, 1},
{0x0D02, 0x0D03, 1},
{0x0D3E, 0x0D43, 1},
{0x0D46, 0x0D48, 1},
{0x0D4A, 0x0D4D, 1},
{0x0D57, 0x0D57, 1},
{0x0D66, 0x0D6F, 1},
{0x0E31, 0x0E31, 1},
{0x0E34, 0x0E3A, 1},
{0x0E46, 0x0E46, 1},
{0x0E47, 0x0E4E, 1},
{0x0E50, 0x0E59, 1},
{0x0EB1, 0x0EB1, 1},
{0x0EB4, 0x0EB9, 1},
{0x0EBB, 0x0EBC, 1},
{0x0EC6, 0x0EC6, 1},
{0x0EC8, 0x0ECD, 1},
{0x0ED0, 0x0ED9, 1},
{0x0F18, 0x0F19, 1},
{0x0F20, 0x0F29, 1},
{0x0F35, 0x0F39, 2},
{0x0F3E, 0x0F3F, 1},
{0x0F71, 0x0F84, 1},
{0x0F86, 0x0F8B, 1},
{0x0F90, 0x0F95, 1},
{0x0F97, 0x0F97, 1},
{0x0F99, 0x0FAD, 1},
{0x0FB1, 0x0FB7, 1},
{0x0FB9, 0x0FB9, 1},
{0x20D0, 0x20DC, 1},
{0x20E1, 0x3005, 0x3005 - 0x20E1},
{0x302A, 0x302F, 1},
{0x3031, 0x3035, 1},
{0x3099, 0x309A, 1},
{0x309D, 0x309E, 1},
{0x30FC, 0x30FE, 1},
},
}
// HTMLEntity is an entity map containing translations for the
// standard HTML entity characters.
//
// See the Decoder.Strict and Decoder.Entity fields' documentation.
var HTMLEntity map[string]string = htmlEntity
var htmlEntity = map[string]string{
/*
hget http://www.w3.org/TR/html4/sgml/entities.html |
ssam '
,y /\>/ x/\<(.|\n)+/ s/\n/ /g
,x v/^\<!ENTITY/d
,s/\<!ENTITY ([^ ]+) .*U\+([0-9A-F][0-9A-F][0-9A-F][0-9A-F]) .+/ "\1": "\\u\2",/g
'
*/
"nbsp": "\u00A0",
"iexcl": "\u00A1",
"cent": "\u00A2",
"pound": "\u00A3",
"curren": "\u00A4",
"yen": "\u00A5",
"brvbar": "\u00A6",
"sect": "\u00A7",
"uml": "\u00A8",
"copy": "\u00A9",
"ordf": "\u00AA",
"laquo": "\u00AB",
"not": "\u00AC",
"shy": "\u00AD",
"reg": "\u00AE",
"macr": "\u00AF",
"deg": "\u00B0",
"plusmn": "\u00B1",
"sup2": "\u00B2",
"sup3": "\u00B3",
"acute": "\u00B4",
"micro": "\u00B5",
"para": "\u00B6",
"middot": "\u00B7",
"cedil": "\u00B8",
"sup1": "\u00B9",
"ordm": "\u00BA",
"raquo": "\u00BB",
"frac14": "\u00BC",
"frac12": "\u00BD",
"frac34": "\u00BE",
"iquest": "\u00BF",
"Agrave": "\u00C0",
"Aacute": "\u00C1",
"Acirc": "\u00C2",
"Atilde": "\u00C3",
"Auml": "\u00C4",
"Aring": "\u00C5",
"AElig": "\u00C6",
"Ccedil": "\u00C7",
"Egrave": "\u00C8",
"Eacute": "\u00C9",
"Ecirc": "\u00CA",
"Euml": "\u00CB",
"Igrave": "\u00CC",
"Iacute": "\u00CD",
"Icirc": "\u00CE",
"Iuml": "\u00CF",
"ETH": "\u00D0",
"Ntilde": "\u00D1",
"Ograve": "\u00D2",
"Oacute": "\u00D3",
"Ocirc": "\u00D4",
"Otilde": "\u00D5",
"Ouml": "\u00D6",
"times": "\u00D7",
"Oslash": "\u00D8",
"Ugrave": "\u00D9",
"Uacute": "\u00DA",
"Ucirc": "\u00DB",
"Uuml": "\u00DC",
"Yacute": "\u00DD",
"THORN": "\u00DE",
"szlig": "\u00DF",
"agrave": "\u00E0",
"aacute": "\u00E1",
"acirc": "\u00E2",
"atilde": "\u00E3",
"auml": "\u00E4",
"aring": "\u00E5",
"aelig": "\u00E6",
"ccedil": "\u00E7",
"egrave": "\u00E8",
"eacute": "\u00E9",
"ecirc": "\u00EA",
"euml": "\u00EB",
"igrave": "\u00EC",
"iacute": "\u00ED",
"icirc": "\u00EE",
"iuml": "\u00EF",
"eth": "\u00F0",
"ntilde": "\u00F1",
"ograve": "\u00F2",
"oacute": "\u00F3",
"ocirc": "\u00F4",
"otilde": "\u00F5",
"ouml": "\u00F6",
"divide": "\u00F7",
"oslash": "\u00F8",
"ugrave": "\u00F9",
"uacute": "\u00FA",
"ucirc": "\u00FB",
"uuml": "\u00FC",
"yacute": "\u00FD",
"thorn": "\u00FE",
"yuml": "\u00FF",
"fnof": "\u0192",
"Alpha": "\u0391",
"Beta": "\u0392",
"Gamma": "\u0393",
"Delta": "\u0394",
"Epsilon": "\u0395",
"Zeta": "\u0396",
"Eta": "\u0397",
"Theta": "\u0398",
"Iota": "\u0399",
"Kappa": "\u039A",
"Lambda": "\u039B",
"Mu": "\u039C",
"Nu": "\u039D",
"Xi": "\u039E",
"Omicron": "\u039F",
"Pi": "\u03A0",
"Rho": "\u03A1",
"Sigma": "\u03A3",
"Tau": "\u03A4",
"Upsilon": "\u03A5",
"Phi": "\u03A6",
"Chi": "\u03A7",
"Psi": "\u03A8",
"Omega": "\u03A9",
"alpha": "\u03B1",
"beta": "\u03B2",
"gamma": "\u03B3",
"delta": "\u03B4",
"epsilon": "\u03B5",
"zeta": "\u03B6",
"eta": "\u03B7",
"theta": "\u03B8",
"iota": "\u03B9",
"kappa": "\u03BA",
"lambda": "\u03BB",
"mu": "\u03BC",
"nu": "\u03BD",
"xi": "\u03BE",
"omicron": "\u03BF",
"pi": "\u03C0",
"rho": "\u03C1",
"sigmaf": "\u03C2",
"sigma": "\u03C3",
"tau": "\u03C4",
"upsilon": "\u03C5",
"phi": "\u03C6",
"chi": "\u03C7",
"psi": "\u03C8",
"omega": "\u03C9",
"thetasym": "\u03D1",
"upsih": "\u03D2",
"piv": "\u03D6",
"bull": "\u2022",
"hellip": "\u2026",
"prime": "\u2032",
"Prime": "\u2033",
"oline": "\u203E",
"frasl": "\u2044",
"weierp": "\u2118",
"image": "\u2111",
"real": "\u211C",
"trade": "\u2122",
"alefsym": "\u2135",
"larr": "\u2190",
"uarr": "\u2191",
"rarr": "\u2192",
"darr": "\u2193",
"harr": "\u2194",
"crarr": "\u21B5",
"lArr": "\u21D0",
"uArr": "\u21D1",
"rArr": "\u21D2",
"dArr": "\u21D3",
"hArr": "\u21D4",
"forall": "\u2200",
"part": "\u2202",
"exist": "\u2203",
"empty": "\u2205",
"nabla": "\u2207",
"isin": "\u2208",
"notin": "\u2209",
"ni": "\u220B",
"prod": "\u220F",
"sum": "\u2211",
"minus": "\u2212",
"lowast": "\u2217",
"radic": "\u221A",
"prop": "\u221D",
"infin": "\u221E",
"ang": "\u2220",
"and": "\u2227",
"or": "\u2228",
"cap": "\u2229",
"cup": "\u222A",
"int": "\u222B",
"there4": "\u2234",
"sim": "\u223C",
"cong": "\u2245",
"asymp": "\u2248",
"ne": "\u2260",
"equiv": "\u2261",
"le": "\u2264",
"ge": "\u2265",
"sub": "\u2282",
"sup": "\u2283",
"nsub": "\u2284",
"sube": "\u2286",
"supe": "\u2287",
"oplus": "\u2295",
"otimes": "\u2297",
"perp": "\u22A5",
"sdot": "\u22C5",
"lceil": "\u2308",
"rceil": "\u2309",
"lfloor": "\u230A",
"rfloor": "\u230B",
"lang": "\u2329",
"rang": "\u232A",
"loz": "\u25CA",
"spades": "\u2660",
"clubs": "\u2663",
"hearts": "\u2665",
"diams": "\u2666",
"quot": "\u0022",
"amp": "\u0026",
"lt": "\u003C",
"gt": "\u003E",
"OElig": "\u0152",
"oelig": "\u0153",
"Scaron": "\u0160",
"scaron": "\u0161",
"Yuml": "\u0178",
"circ": "\u02C6",
"tilde": "\u02DC",
"ensp": "\u2002",
"emsp": "\u2003",
"thinsp": "\u2009",
"zwnj": "\u200C",
"zwj": "\u200D",
"lrm": "\u200E",
"rlm": "\u200F",
"ndash": "\u2013",
"mdash": "\u2014",
"lsquo": "\u2018",
"rsquo": "\u2019",
"sbquo": "\u201A",
"ldquo": "\u201C",
"rdquo": "\u201D",
"bdquo": "\u201E",
"dagger": "\u2020",
"Dagger": "\u2021",
"permil": "\u2030",
"lsaquo": "\u2039",
"rsaquo": "\u203A",
"euro": "\u20AC",
}
// HTMLAutoClose is the set of HTML elements that
// should be considered to close automatically.
//
// See the Decoder.Strict and Decoder.Entity fields' documentation.
var HTMLAutoClose []string = htmlAutoClose
var htmlAutoClose = []string{
/*
hget http://www.w3.org/TR/html4/loose.dtd |
9 sed -n 's/<!ELEMENT ([^ ]*) +- O EMPTY.+/ "\1",/p' | tr A-Z a-z
*/
"basefont",
"br",
"area",
"link",
"img",
"param",
"hr",
"input",
"col",
"frame",
"isindex",
"base",
"meta",
}
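// Tying the pieces together (an added sketch; the input is hypothetical):
// HTMLEntity and HTMLAutoClose are intended to be installed on a non-strict
// Decoder, per the Decoder.Strict documentation.
//
//	d := NewDecoder(strings.NewReader("<p>caf&eacute;<br>ok"))
//	d.Strict = false
//	d.AutoClose = HTMLAutoClose
//	d.Entity = HTMLEntity
//	for {
//		tok, err := d.Token()
//		if err != nil {
//			break // io.EOF ends the loop
//		}
//		_ = tok // <br> yields a StartElement and then an invented EndElement
//	}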
var (
escQuot = []byte("&#34;") // shorter than "&quot;"
escApos = []byte("&#39;") // shorter than "&apos;"
escAmp = []byte("&amp;")
escLT = []byte("&lt;")
escGT = []byte("&gt;")
escTab = []byte("&#x9;")
escNL = []byte("&#xA;")
escCR = []byte("&#xD;")
escFFFD = []byte("\uFFFD") // Unicode replacement character
)
// EscapeText writes to w the properly escaped XML equivalent
// of the plain text data s.
func EscapeText(w io.Writer, s []byte) error {
return escapeText(w, s, true)
}
// escapeText writes to w the properly escaped XML equivalent
// of the plain text data s. If escapeNewline is true, newline
// characters will be escaped.
func escapeText(w io.Writer, s []byte, escapeNewline bool) error {
var esc []byte
last := 0
for i := 0; i < len(s); {
r, width := utf8.DecodeRune(s[i:])
i += width
switch r {
case '"':
esc = escQuot
case '\'':
esc = escApos
case '&':
esc = escAmp
case '<':
esc = escLT
case '>':
esc = escGT
case '\t':
esc = escTab
case '\n':
if !escapeNewline {
continue
}
esc = escNL
case '\r':
esc = escCR
default:
if !isInCharacterRange(r) || (r == 0xFFFD && width == 1) {
esc = escFFFD
break
}
continue
}
if _, err := w.Write(s[last : i-width]); err != nil {
return err
}
if _, err := w.Write(esc); err != nil {
return err
}
last = i
}
_, err := w.Write(s[last:])
return err
}
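// For example (added): EscapeText(w, []byte(`a<b & "c"`)) writes
//
//	a&lt;b &amp; &#34;c&#34;
//
// to w, passing the unremarkable bytes through untouched.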
// EscapeString writes to p the properly escaped XML equivalent
// of the plain text data s.
func (p *printer) EscapeString(s string) {
var esc []byte
last := 0
for i := 0; i < len(s); {
r, width := utf8.DecodeRuneInString(s[i:])
i += width
switch r {
case '"':
esc = escQuot
case '\'':
esc = escApos
case '&':
esc = escAmp
case '<':
esc = escLT
case '>':
esc = escGT
case '\t':
esc = escTab
case '\n':
esc = escNL
case '\r':
esc = escCR
default:
if !isInCharacterRange(r) || (r == 0xFFFD && width == 1) {
esc = escFFFD
break
}
continue
}
p.WriteString(s[last : i-width])
p.Write(esc)
last = i
}
p.WriteString(s[last:])
}
// Escape is like EscapeText but omits the error return value.
// It is provided for backwards compatibility with Go 1.0.
// Code targeting Go 1.1 or later should use EscapeText.
func Escape(w io.Writer, s []byte) {
EscapeText(w, s)
}
var (
cdataStart = []byte("<![CDATA[")
cdataEnd = []byte("]]>")
cdataEscape = []byte("]]]]><![CDATA[>")
)
// emitCDATA writes to w the CDATA-wrapped plain text data s.
// It escapes CDATA directives nested in s.
func emitCDATA(w io.Writer, s []byte) error {
if len(s) == 0 {
return nil
}
if _, err := w.Write(cdataStart); err != nil {
return err
}
for {
before, after, ok := bytes.Cut(s, cdataEnd)
if !ok {
break
}
// Found a nested CDATA directive end.
if _, err := w.Write(before); err != nil {
return err
}
if _, err := w.Write(cdataEscape); err != nil {
return err
}
s = after
}
if _, err := w.Write(s); err != nil {
return err
}
_, err := w.Write(cdataEnd)
return err
}
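// Example (added): for s = []byte("a]]>b"), emitCDATA writes
//
//	<![CDATA[a]]]]><![CDATA[>b]]>
//
// splitting the embedded "]]>" so that no CDATA section in the output is
// terminated early.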
// procInst parses the `param="..."` or `param='...'`
// value out of the provided string, returning "" if not found.
func procInst(param, s string) string {
// TODO: this parsing is somewhat lame and not exact.
// It works for all actual cases, though.
param = param + "="
_, v, _ := strings.Cut(s, param)
if v == "" {
return ""
}
if v[0] != '\'' && v[0] != '"' {
return ""
}
unquote, _, ok := strings.Cut(v[1:], v[:1])
if !ok {
return ""
}
return unquote
}
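// For instance (added): given the processing-instruction body
//
//	version="1.0" encoding="UTF-8"
//
// procInst("version", s) returns "1.0" and procInst("encoding", s) returns
// "UTF-8"; for a missing or unquoted parameter it returns "".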
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package errors implements functions to manipulate errors.
//
// The New function creates errors whose only content is a text message.
//
// An error e wraps another error if e's type has one of the methods
//
// Unwrap() error
// Unwrap() []error
//
// If e.Unwrap() returns a non-nil error w or a slice containing w,
// then we say that e wraps w. A nil error returned from e.Unwrap()
// indicates that e does not wrap any error. It is invalid for an
// Unwrap method to return an []error containing a nil error value.
//
// An easy way to create wrapped errors is to call fmt.Errorf and apply
// the %w verb to the error argument:
//
// wrapsErr := fmt.Errorf("... %w ...", ..., err, ...)
//
// Successive unwrapping of an error creates a tree. The Is and As
// functions inspect an error's tree by examining first the error
// itself followed by the tree of each of its children in turn
// (pre-order, depth-first traversal).
//
// Is examines the tree of its first argument looking for an error that
// matches the second. It reports whether it finds a match. It should be
// used in preference to simple equality checks:
//
// if errors.Is(err, fs.ErrExist)
//
// is preferable to
//
// if err == fs.ErrExist
//
// because the former will succeed if err wraps fs.ErrExist.
//
// As examines the tree of its first argument looking for an error that can be
// assigned to its second argument, which must be a pointer. If it succeeds, it
// performs the assignment and returns true. Otherwise, it returns false. The form
//
// var perr *fs.PathError
// if errors.As(err, &perr) {
// fmt.Println(perr.Path)
// }
//
// is preferable to
//
// if perr, ok := err.(*fs.PathError); ok {
// fmt.Println(perr.Path)
// }
//
// because the former will succeed if err wraps an *fs.PathError.
package errors
// New returns an error that formats as the given text.
// Each call to New returns a distinct error value even if the text is identical.
func New(text string) error {
return &errorString{text}
}
// errorString is a trivial implementation of error.
type errorString struct {
s string
}
func (e *errorString) Error() string {
return e.s
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package errors
// Join returns an error that wraps the given errors.
// Any nil error values are discarded.
// Join returns nil if errs contains no non-nil values.
// The error formats as the concatenation of the strings obtained
// by calling the Error method of each element of errs, with a newline
// between each string.
func Join(errs ...error) error {
n := 0
for _, err := range errs {
if err != nil {
n++
}
}
if n == 0 {
return nil
}
e := &joinError{
errs: make([]error, 0, n),
}
for _, err := range errs {
if err != nil {
e.errs = append(e.errs, err)
}
}
return e
}
type joinError struct {
errs []error
}
func (e *joinError) Error() string {
var b []byte
for i, err := range e.errs {
if i > 0 {
b = append(b, '\n')
}
b = append(b, err.Error()...)
}
return string(b)
}
func (e *joinError) Unwrap() []error {
return e.errs
}
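// A minimal usage sketch (not part of the original source):
//
//	errA := New("a")
//	errB := New("b")
//	err := Join(errA, nil, errB)
//	// err.Error() == "a\nb"; the nil is discarded.
//	// Is(err, errA) and Is(err, errB) both report true, because Is
//	// walks the []error returned by Unwrap.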
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package errors
import (
"internal/reflectlite"
)
// Unwrap returns the result of calling the Unwrap method on err, if err's
// type contains an Unwrap method returning error.
// Otherwise, Unwrap returns nil.
//
// Unwrap returns nil if the Unwrap method returns []error.
func Unwrap(err error) error {
u, ok := err.(interface {
Unwrap() error
})
if !ok {
return nil
}
return u.Unwrap()
}
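// For example (a sketch, assuming errFile was wrapped with fmt.Errorf's %w verb):
//
//	wrapped := fmt.Errorf("open config: %w", errFile)
//	Unwrap(wrapped) // errFile
//	Unwrap(errFile) // nil: errFile itself has no Unwrap method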
// Is reports whether any error in err's tree matches target.
//
// The tree consists of err itself, followed by the errors obtained by repeatedly
// calling Unwrap. When err wraps multiple errors, Is examines err followed by a
// depth-first traversal of its children.
//
// An error is considered to match a target if it is equal to that target or if
// it implements a method Is(error) bool such that Is(target) returns true.
//
// An error type might provide an Is method so it can be treated as equivalent
// to an existing error. For example, if MyError defines
//
// func (m MyError) Is(target error) bool { return target == fs.ErrExist }
//
// then Is(MyError{}, fs.ErrExist) returns true. See syscall.Errno.Is for
// an example in the standard library. An Is method should only shallowly
// compare err and the target and not call Unwrap on either.
func Is(err, target error) bool {
if target == nil {
return err == target
}
isComparable := reflectlite.TypeOf(target).Comparable()
for {
if isComparable && err == target {
return true
}
if x, ok := err.(interface{ Is(error) bool }); ok && x.Is(target) {
return true
}
switch x := err.(type) {
case interface{ Unwrap() error }:
err = x.Unwrap()
if err == nil {
return false
}
case interface{ Unwrap() []error }:
for _, err := range x.Unwrap() {
if Is(err, target) {
return true
}
}
return false
default:
return false
}
}
}
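// A typical use (a sketch): os.Open failures wrap fs.ErrNotExist, so the
// sentinel check succeeds even though err itself is a *fs.PathError:
//
//	if _, err := os.Open("no-such-file"); Is(err, fs.ErrNotExist) {
//		// the file does not exist
//	}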
// As finds the first error in err's tree that matches target, and if one is found, sets
// target to that error value and returns true. Otherwise, it returns false.
//
// The tree consists of err itself, followed by the errors obtained by repeatedly
// calling Unwrap. When err wraps multiple errors, As examines err followed by a
// depth-first traversal of its children.
//
// An error matches target if the error's concrete value is assignable to the value
// pointed to by target, or if the error has a method As(interface{}) bool such that
// As(target) returns true. In the latter case, the As method is responsible for
// setting target.
//
// An error type might provide an As method so it can be treated as if it were a
// different error type.
//
// As panics if target is not a non-nil pointer to either a type that implements
// error, or to any interface type.
func As(err error, target any) bool {
if err == nil {
return false
}
if target == nil {
panic("errors: target cannot be nil")
}
val := reflectlite.ValueOf(target)
typ := val.Type()
if typ.Kind() != reflectlite.Ptr || val.IsNil() {
panic("errors: target must be a non-nil pointer")
}
targetType := typ.Elem()
if targetType.Kind() != reflectlite.Interface && !targetType.Implements(errorType) {
panic("errors: *target must be interface or implement error")
}
for {
if reflectlite.TypeOf(err).AssignableTo(targetType) {
val.Elem().Set(reflectlite.ValueOf(err))
return true
}
if x, ok := err.(interface{ As(any) bool }); ok && x.As(target) {
return true
}
switch x := err.(type) {
case interface{ Unwrap() error }:
err = x.Unwrap()
if err == nil {
return false
}
case interface{ Unwrap() []error }:
for _, err := range x.Unwrap() {
if As(err, target) {
return true
}
}
return false
default:
return false
}
}
}
var errorType = reflectlite.TypeOf((*error)(nil)).Elem()
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package expvar provides a standardized interface to public variables, such
// as operation counters in servers. It exposes these variables via HTTP at
// /debug/vars in JSON format.
//
// Operations to set or modify these public variables are atomic.
//
// In addition to adding the HTTP handler, this package registers the
// following variables:
//
// cmdline os.Args
// memstats runtime.MemStats
//
// The package is sometimes only imported for the side effect of
// registering its HTTP handler and the above variables. To use it
// this way, link this package into your program:
//
// import _ "expvar"
package expvar
import (
"encoding/json"
"fmt"
"log"
"math"
"net/http"
"os"
"runtime"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
)
// Var is an abstract type for all exported variables.
type Var interface {
// String returns a valid JSON value for the variable.
// Types with String methods that do not return valid JSON
// (such as time.Time) must not be used as a Var.
String() string
}
// Int is a 64-bit integer variable that satisfies the Var interface.
type Int struct {
i int64
}
// Value returns the value of v.
func (v *Int) Value() int64 {
return atomic.LoadInt64(&v.i)
}
// String returns v's value formatted as a decimal JSON number.
func (v *Int) String() string {
return strconv.FormatInt(atomic.LoadInt64(&v.i), 10)
}
// Add adds delta to v.
func (v *Int) Add(delta int64) {
atomic.AddInt64(&v.i, delta)
}
// Set sets v to value.
func (v *Int) Set(value int64) {
atomic.StoreInt64(&v.i, value)
}
// Float is a 64-bit float variable that satisfies the Var interface.
type Float struct {
f atomic.Uint64
}
func (v *Float) Value() float64 {
return math.Float64frombits(v.f.Load())
}
func (v *Float) String() string {
return strconv.FormatFloat(
math.Float64frombits(v.f.Load()), 'g', -1, 64)
}
// Add adds delta to v.
func (v *Float) Add(delta float64) {
for {
cur := v.f.Load()
curVal := math.Float64frombits(cur)
nxtVal := curVal + delta
nxt := math.Float64bits(nxtVal)
if v.f.CompareAndSwap(cur, nxt) {
return
}
}
}
// Set sets v to value.
func (v *Float) Set(value float64) {
v.f.Store(math.Float64bits(value))
}
// Map is a string-to-Var map variable that satisfies the Var interface.
type Map struct {
m sync.Map // map[string]Var
keysMu sync.RWMutex
keys []string // sorted
}
// KeyValue represents a single entry in a Map.
type KeyValue struct {
Key string
Value Var
}
func (v *Map) String() string {
var b strings.Builder
fmt.Fprintf(&b, "{")
first := true
v.Do(func(kv KeyValue) {
if !first {
fmt.Fprintf(&b, ", ")
}
fmt.Fprintf(&b, "%q: ", kv.Key)
if kv.Value != nil {
fmt.Fprintf(&b, "%v", kv.Value)
} else {
fmt.Fprint(&b, "null")
}
first = false
})
fmt.Fprintf(&b, "}")
return b.String()
}
// Init removes all keys from the map.
func (v *Map) Init() *Map {
v.keysMu.Lock()
defer v.keysMu.Unlock()
v.keys = v.keys[:0]
v.m.Range(func(k, _ any) bool {
v.m.Delete(k)
return true
})
return v
}
// addKey updates the sorted list of keys in v.keys.
func (v *Map) addKey(key string) {
v.keysMu.Lock()
defer v.keysMu.Unlock()
// Using insertion sort to place key into the already-sorted v.keys.
if i := sort.SearchStrings(v.keys, key); i >= len(v.keys) {
v.keys = append(v.keys, key)
} else if v.keys[i] != key {
v.keys = append(v.keys, "")
copy(v.keys[i+1:], v.keys[i:])
v.keys[i] = key
}
}
// Get returns the Var stored under the given key, or nil if none is present.
func (v *Map) Get(key string) Var {
i, _ := v.m.Load(key)
av, _ := i.(Var)
return av
}
func (v *Map) Set(key string, av Var) {
// Before we store the value, check to see whether the key is new. Try a Load
// before LoadOrStore: LoadOrStore causes the key interface to escape even on
// the Load path.
if _, ok := v.m.Load(key); !ok {
if _, dup := v.m.LoadOrStore(key, av); !dup {
v.addKey(key)
return
}
}
v.m.Store(key, av)
}
// Add adds delta to the *Int value stored under the given map key.
func (v *Map) Add(key string, delta int64) {
i, ok := v.m.Load(key)
if !ok {
var dup bool
i, dup = v.m.LoadOrStore(key, new(Int))
if !dup {
v.addKey(key)
}
}
// Add to Int; ignore otherwise.
if iv, ok := i.(*Int); ok {
iv.Add(delta)
}
}
// AddFloat adds delta to the *Float value stored under the given map key.
func (v *Map) AddFloat(key string, delta float64) {
i, ok := v.m.Load(key)
if !ok {
var dup bool
i, dup = v.m.LoadOrStore(key, new(Float))
if !dup {
v.addKey(key)
}
}
// Add to Float; ignore otherwise.
if iv, ok := i.(*Float); ok {
iv.Add(delta)
}
}
// Delete deletes the given key from the map.
func (v *Map) Delete(key string) {
v.keysMu.Lock()
defer v.keysMu.Unlock()
i := sort.SearchStrings(v.keys, key)
if i < len(v.keys) && key == v.keys[i] {
v.keys = append(v.keys[:i], v.keys[i+1:]...)
v.m.Delete(key)
}
}
// Do calls f for each entry in the map.
// The map is locked during the iteration,
// but existing entries may be concurrently updated.
func (v *Map) Do(f func(KeyValue)) {
v.keysMu.RLock()
defer v.keysMu.RUnlock()
for _, k := range v.keys {
i, _ := v.m.Load(k)
val, _ := i.(Var)
f(KeyValue{k, val})
}
}
// String is a string variable that satisfies the Var interface.
type String struct {
s atomic.Value // string
}
func (v *String) Value() string {
p, _ := v.s.Load().(string)
return p
}
// String implements the Var interface. To get the unquoted string,
// use Value.
func (v *String) String() string {
s := v.Value()
b, _ := json.Marshal(s)
return string(b)
}
// Set sets v to value.
func (v *String) Set(value string) {
v.s.Store(value)
}
// Func implements Var by calling the function
// and formatting the returned value using JSON.
type Func func() any
func (f Func) Value() any {
return f()
}
func (f Func) String() string {
v, _ := json.Marshal(f())
return string(v)
}
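// A sketch of exposing a computed value (hypothetical "start" variable;
// not part of the original source):
//
//	var start = time.Now()
//	Publish("uptime.seconds", Func(func() any {
//		return time.Since(start).Seconds()
//	}))
//
// The function runs each time the variable is serialized, e.g. on every
// request to /debug/vars.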
// All published variables.
var (
vars sync.Map // map[string]Var
varKeysMu sync.RWMutex
varKeys []string // sorted
)
// Publish declares a named exported variable. This should be called from a
// package's init function when it creates its Vars. If the name is already
// registered then this will log.Panic.
func Publish(name string, v Var) {
if _, dup := vars.LoadOrStore(name, v); dup {
log.Panicln("Reuse of exported var name:", name)
}
varKeysMu.Lock()
defer varKeysMu.Unlock()
varKeys = append(varKeys, name)
sort.Strings(varKeys)
}
// Get retrieves a named exported variable. It returns nil if the name has
// not been registered.
func Get(name string) Var {
i, _ := vars.Load(name)
v, _ := i.(Var)
return v
}
// Convenience functions for creating new exported variables.
func NewInt(name string) *Int {
v := new(Int)
Publish(name, v)
return v
}
func NewFloat(name string) *Float {
v := new(Float)
Publish(name, v)
return v
}
func NewMap(name string) *Map {
v := new(Map).Init()
Publish(name, v)
return v
}
func NewString(name string) *String {
v := new(String)
Publish(name, v)
return v
}
// Do calls f for each exported variable.
// The global variable map is locked during the iteration,
// but existing entries may be concurrently updated.
func Do(f func(KeyValue)) {
varKeysMu.RLock()
defer varKeysMu.RUnlock()
for _, k := range varKeys {
val, _ := vars.Load(k)
f(KeyValue{k, val.(Var)})
}
}
func expvarHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
fmt.Fprintf(w, "{\n")
first := true
Do(func(kv KeyValue) {
if !first {
fmt.Fprintf(w, ",\n")
}
first = false
fmt.Fprintf(w, "%q: %s", kv.Key, kv.Value)
})
fmt.Fprintf(w, "\n}\n")
}
// Handler returns the expvar HTTP Handler.
//
// This is only needed to install the handler in a non-standard location.
func Handler() http.Handler {
return http.HandlerFunc(expvarHandler)
}
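// For example, to serve the variables somewhere other than /debug/vars
// (a sketch with a hypothetical mux):
//
//	mux := http.NewServeMux()
//	mux.Handle("/metrics", Handler())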
func cmdline() any {
return os.Args
}
func memstats() any {
stats := new(runtime.MemStats)
runtime.ReadMemStats(stats)
return *stats
}
func init() {
http.HandleFunc("/debug/vars", expvarHandler)
Publish("cmdline", Func(cmdline))
Publish("memstats", Func(memstats))
}
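// An end-to-end sketch (hypothetical names): counting requests per path.
//
//	var hits = NewMap("hits")
//
//	func handler(w http.ResponseWriter, r *http.Request) {
//		hits.Add(r.URL.Path, 1)
//	}
//
// Because init above registers /debug/vars, the counters are visible there
// without further setup.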
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package flag implements command-line flag parsing.
# Usage
Define flags using flag.String(), Bool(), Int(), etc.
This declares an integer flag, -n, stored in the pointer nFlag, with type *int:
import "flag"
var nFlag = flag.Int("n", 1234, "help message for flag n")
If you like, you can bind the flag to a variable using the Var() functions.
var flagvar int
func init() {
flag.IntVar(&flagvar, "flagname", 1234, "help message for flagname")
}
Or you can create custom flags that satisfy the Value interface (with
pointer receivers) and couple them to flag parsing by
flag.Var(&flagVal, "name", "help message for flagname")
For such flags, the default value is just the initial value of the variable.
After all flags are defined, call
flag.Parse()
to parse the command line into the defined flags.
Flags may then be used directly. If you're using the flags themselves,
they are all pointers; if you bind to variables, they're values.
fmt.Println("ip has value ", *ip)
fmt.Println("flagvar has value ", flagvar)
After parsing, the arguments following the flags are available as the
slice flag.Args() or individually as flag.Arg(i).
The arguments are indexed from 0 through flag.NArg()-1.
# Command line flag syntax
The following forms are permitted:
-flag
--flag // double dashes are also permitted
-flag=x
-flag x // non-boolean flags only
One or two dashes may be used; they are equivalent.
The last form is not permitted for boolean flags because the
meaning of the command
cmd -x *
where * is a Unix shell wildcard, will change if there is a file
called 0, false, etc. You must use the -flag=false form to turn
off a boolean flag.
Flag parsing stops just before the first non-flag argument
("-" is a non-flag argument) or after the terminator "--".
Integer flags accept 1234, 0664, 0x1234 and may be negative.
Boolean flags may be:
1, 0, t, f, T, F, true, false, TRUE, FALSE, True, False
Duration flags accept any input valid for time.ParseDuration.
The default set of command-line flags is controlled by
top-level functions. The FlagSet type allows one to define
independent sets of flags, such as to implement subcommands
in a command-line interface. The methods of FlagSet are
analogous to the top-level functions for the command-line
flag set.
*/
package flag
import (
"encoding"
"errors"
"fmt"
"io"
"os"
"reflect"
"sort"
"strconv"
"strings"
"time"
)
// ErrHelp is the error returned if the -help or -h flag is invoked
// but no such flag is defined.
var ErrHelp = errors.New("flag: help requested")
// errParse is returned by Set if a flag's value fails to parse, such as with an invalid integer for Int.
// It then gets wrapped through failf to provide more information.
var errParse = errors.New("parse error")
// errRange is returned by Set if a flag's value is out of range.
// It then gets wrapped through failf to provide more information.
var errRange = errors.New("value out of range")
func numError(err error) error {
ne, ok := err.(*strconv.NumError)
if !ok {
return err
}
if ne.Err == strconv.ErrSyntax {
return errParse
}
if ne.Err == strconv.ErrRange {
return errRange
}
return err
}
// -- bool Value
type boolValue bool
func newBoolValue(val bool, p *bool) *boolValue {
*p = val
return (*boolValue)(p)
}
func (b *boolValue) Set(s string) error {
v, err := strconv.ParseBool(s)
if err != nil {
err = errParse
}
*b = boolValue(v)
return err
}
func (b *boolValue) Get() any { return bool(*b) }
func (b *boolValue) String() string { return strconv.FormatBool(bool(*b)) }
func (b *boolValue) IsBoolFlag() bool { return true }
// boolFlag is an optional interface to indicate boolean flags that can be
// supplied without "=value" text.
type boolFlag interface {
Value
IsBoolFlag() bool
}
// -- int Value
type intValue int
func newIntValue(val int, p *int) *intValue {
*p = val
return (*intValue)(p)
}
func (i *intValue) Set(s string) error {
v, err := strconv.ParseInt(s, 0, strconv.IntSize)
if err != nil {
err = numError(err)
}
*i = intValue(v)
return err
}
func (i *intValue) Get() any { return int(*i) }
func (i *intValue) String() string { return strconv.Itoa(int(*i)) }
// -- int64 Value
type int64Value int64
func newInt64Value(val int64, p *int64) *int64Value {
*p = val
return (*int64Value)(p)
}
func (i *int64Value) Set(s string) error {
v, err := strconv.ParseInt(s, 0, 64)
if err != nil {
err = numError(err)
}
*i = int64Value(v)
return err
}
func (i *int64Value) Get() any { return int64(*i) }
func (i *int64Value) String() string { return strconv.FormatInt(int64(*i), 10) }
// -- uint Value
type uintValue uint
func newUintValue(val uint, p *uint) *uintValue {
*p = val
return (*uintValue)(p)
}
func (i *uintValue) Set(s string) error {
v, err := strconv.ParseUint(s, 0, strconv.IntSize)
if err != nil {
err = numError(err)
}
*i = uintValue(v)
return err
}
func (i *uintValue) Get() any { return uint(*i) }
func (i *uintValue) String() string { return strconv.FormatUint(uint64(*i), 10) }
// -- uint64 Value
type uint64Value uint64
func newUint64Value(val uint64, p *uint64) *uint64Value {
*p = val
return (*uint64Value)(p)
}
func (i *uint64Value) Set(s string) error {
v, err := strconv.ParseUint(s, 0, 64)
if err != nil {
err = numError(err)
}
*i = uint64Value(v)
return err
}
func (i *uint64Value) Get() any { return uint64(*i) }
func (i *uint64Value) String() string { return strconv.FormatUint(uint64(*i), 10) }
// -- string Value
type stringValue string
func newStringValue(val string, p *string) *stringValue {
*p = val
return (*stringValue)(p)
}
func (s *stringValue) Set(val string) error {
*s = stringValue(val)
return nil
}
func (s *stringValue) Get() any { return string(*s) }
func (s *stringValue) String() string { return string(*s) }
// -- float64 Value
type float64Value float64
func newFloat64Value(val float64, p *float64) *float64Value {
*p = val
return (*float64Value)(p)
}
func (f *float64Value) Set(s string) error {
v, err := strconv.ParseFloat(s, 64)
if err != nil {
err = numError(err)
}
*f = float64Value(v)
return err
}
func (f *float64Value) Get() any { return float64(*f) }
func (f *float64Value) String() string { return strconv.FormatFloat(float64(*f), 'g', -1, 64) }
// -- time.Duration Value
type durationValue time.Duration
func newDurationValue(val time.Duration, p *time.Duration) *durationValue {
*p = val
return (*durationValue)(p)
}
func (d *durationValue) Set(s string) error {
v, err := time.ParseDuration(s)
if err != nil {
err = errParse
}
*d = durationValue(v)
return err
}
func (d *durationValue) Get() any { return time.Duration(*d) }
func (d *durationValue) String() string { return (*time.Duration)(d).String() }
// -- encoding.TextUnmarshaler Value
type textValue struct{ p encoding.TextUnmarshaler }
func newTextValue(val encoding.TextMarshaler, p encoding.TextUnmarshaler) textValue {
ptrVal := reflect.ValueOf(p)
if ptrVal.Kind() != reflect.Ptr {
panic("variable value type must be a pointer")
}
defVal := reflect.ValueOf(val)
if defVal.Kind() == reflect.Ptr {
defVal = defVal.Elem()
}
if defVal.Type() != ptrVal.Type().Elem() {
panic(fmt.Sprintf("default type does not match variable type: %v != %v", defVal.Type(), ptrVal.Type().Elem()))
}
ptrVal.Elem().Set(defVal)
return textValue{p}
}
func (v textValue) Set(s string) error {
return v.p.UnmarshalText([]byte(s))
}
func (v textValue) Get() any {
return v.p
}
func (v textValue) String() string {
if m, ok := v.p.(encoding.TextMarshaler); ok {
if b, err := m.MarshalText(); err == nil {
return string(b)
}
}
return ""
}
// -- func Value
type funcValue func(string) error
func (f funcValue) Set(s string) error { return f(s) }
func (f funcValue) String() string { return "" }
// Value is the interface to the dynamic value stored in a flag.
// (The default value is represented as a string.)
//
// If a Value has an IsBoolFlag() bool method returning true,
// the command-line parser makes -name equivalent to -name=true
// rather than using the next command-line argument.
//
// Set is called once, in command line order, for each flag present.
// The flag package may call the String method with a zero-valued receiver,
// such as a nil pointer.
type Value interface {
String() string
Set(string) error
}
// Getter is an interface that allows the contents of a Value to be retrieved.
// It wraps the Value interface, rather than being part of it, because it
// appeared after Go 1 and its compatibility rules. All Value types provided
// by this package satisfy the Getter interface, except the type used by Func.
type Getter interface {
Value
Get() any
}
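// A sketch of reading a typed value back through Getter (hypothetical flag
// name "n", defined with Int or IntVar):
//
//	if f := CommandLine.Lookup("n"); f != nil {
//		if g, ok := f.Value.(Getter); ok {
//			n := g.Get().(int)
//			_ = n
//		}
//	}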
// ErrorHandling defines how FlagSet.Parse behaves if the parse fails.
type ErrorHandling int
// These constants cause FlagSet.Parse to behave as described if the parse fails.
const (
ContinueOnError ErrorHandling = iota // Return a descriptive error.
ExitOnError // Call os.Exit(2) or for -h/-help Exit(0).
PanicOnError // Call panic with a descriptive error.
)
// A FlagSet represents a set of defined flags. The zero value of a FlagSet
// has no name and has ContinueOnError error handling.
//
// Flag names must be unique within a FlagSet. An attempt to define a flag whose
// name is already in use will cause a panic.
type FlagSet struct {
// Usage is the function called when an error occurs while parsing flags.
// The field is a function (not a method) that may be changed to point to
// a custom error handler. What happens after Usage is called depends
// on the ErrorHandling setting; for the command line, this defaults
// to ExitOnError, which exits the program after calling Usage.
Usage func()
name string
parsed bool
actual map[string]*Flag
formal map[string]*Flag
args []string // arguments after flags
errorHandling ErrorHandling
output io.Writer // nil means stderr; use Output() accessor
}
// A Flag represents the state of a flag.
type Flag struct {
Name string // name as it appears on command line
Usage string // help message
Value Value // value as set
DefValue string // default value (as text); for usage message
}
// sortFlags returns the flags as a slice in lexicographical sorted order.
func sortFlags(flags map[string]*Flag) []*Flag {
result := make([]*Flag, len(flags))
i := 0
for _, f := range flags {
result[i] = f
i++
}
sort.Slice(result, func(i, j int) bool {
return result[i].Name < result[j].Name
})
return result
}
// Output returns the destination for usage and error messages. os.Stderr is returned if
// output was not set or was set to nil.
func (f *FlagSet) Output() io.Writer {
if f.output == nil {
return os.Stderr
}
return f.output
}
// Name returns the name of the flag set.
func (f *FlagSet) Name() string {
return f.name
}
// ErrorHandling returns the error handling behavior of the flag set.
func (f *FlagSet) ErrorHandling() ErrorHandling {
return f.errorHandling
}
// SetOutput sets the destination for usage and error messages.
// If output is nil, os.Stderr is used.
func (f *FlagSet) SetOutput(output io.Writer) {
f.output = output
}
// VisitAll visits the flags in lexicographical order, calling fn for each.
// It visits all flags, even those not set.
func (f *FlagSet) VisitAll(fn func(*Flag)) {
for _, flag := range sortFlags(f.formal) {
fn(flag)
}
}
// VisitAll visits the command-line flags in lexicographical order, calling
// fn for each. It visits all flags, even those not set.
func VisitAll(fn func(*Flag)) {
CommandLine.VisitAll(fn)
}
// Visit visits the flags in lexicographical order, calling fn for each.
// It visits only those flags that have been set.
func (f *FlagSet) Visit(fn func(*Flag)) {
for _, flag := range sortFlags(f.actual) {
fn(flag)
}
}
// Visit visits the command-line flags in lexicographical order, calling fn
// for each. It visits only those flags that have been set.
func Visit(fn func(*Flag)) {
CommandLine.Visit(fn)
}
// Lookup returns the Flag structure of the named flag, returning nil if none exists.
func (f *FlagSet) Lookup(name string) *Flag {
return f.formal[name]
}
// Lookup returns the Flag structure of the named command-line flag,
// returning nil if none exists.
func Lookup(name string) *Flag {
return CommandLine.formal[name]
}
// Set sets the value of the named flag.
func (f *FlagSet) Set(name, value string) error {
flag, ok := f.formal[name]
if !ok {
return fmt.Errorf("no such flag -%v", name)
}
err := flag.Value.Set(value)
if err != nil {
return err
}
if f.actual == nil {
f.actual = make(map[string]*Flag)
}
f.actual[name] = flag
return nil
}
// Set sets the value of the named command-line flag.
func Set(name, value string) error {
return CommandLine.Set(name, value)
}
// isZeroValue determines whether the string represents the zero
// value for a flag.
func isZeroValue(flag *Flag, value string) (ok bool, err error) {
// Build a zero value of the flag's Value type, and see if the
// result of calling its String method equals the value passed in.
// This works unless the Value type is itself an interface type.
typ := reflect.TypeOf(flag.Value)
var z reflect.Value
if typ.Kind() == reflect.Pointer {
z = reflect.New(typ.Elem())
} else {
z = reflect.Zero(typ)
}
// Catch panics calling the String method, which shouldn't prevent the
// usage message from being printed, but which we should report to the
// user so that they know to fix their code.
defer func() {
if e := recover(); e != nil {
if typ.Kind() == reflect.Pointer {
typ = typ.Elem()
}
err = fmt.Errorf("panic calling String method on zero %v for flag %s: %v", typ, flag.Name, e)
}
}()
return value == z.Interface().(Value).String(), nil
}
// UnquoteUsage extracts a back-quoted name from the usage
// string for a flag and returns it and the un-quoted usage.
// Given "a `name` to show" it returns ("name", "a name to show").
// If there are no back quotes, the name is an educated guess of the
// type of the flag's value, or the empty string if the flag is boolean.
func UnquoteUsage(flag *Flag) (name string, usage string) {
// Look for a back-quoted name, but avoid the strings package.
usage = flag.Usage
for i := 0; i < len(usage); i++ {
if usage[i] == '`' {
for j := i + 1; j < len(usage); j++ {
if usage[j] == '`' {
name = usage[i+1 : j]
usage = usage[:i] + name + usage[j+1:]
return name, usage
}
}
break // Only one back quote; use type name.
}
}
// No explicit name, so use type if we can find one.
name = "value"
switch fv := flag.Value.(type) {
case boolFlag:
if fv.IsBoolFlag() {
name = ""
}
case *durationValue:
name = "duration"
case *float64Value:
name = "float"
case *intValue, *int64Value:
name = "int"
case *stringValue:
name = "string"
case *uintValue, *uint64Value:
name = "uint"
}
return
}
// PrintDefaults prints, to standard error unless configured otherwise, the
// default values of all defined command-line flags in the set. See the
// documentation for the global function PrintDefaults for more information.
func (f *FlagSet) PrintDefaults() {
var isZeroValueErrs []error
f.VisitAll(func(flag *Flag) {
var b strings.Builder
fmt.Fprintf(&b, " -%s", flag.Name) // Two spaces before -; see next two comments.
name, usage := UnquoteUsage(flag)
if len(name) > 0 {
b.WriteString(" ")
b.WriteString(name)
}
// Boolean flags of one ASCII letter are so common we
// treat them specially, putting their usage on the same line.
if b.Len() <= 4 { // space, space, '-', 'x'.
b.WriteString("\t")
} else {
// Four spaces before the tab triggers good alignment
// for both 4- and 8-space tab stops.
b.WriteString("\n \t")
}
b.WriteString(strings.ReplaceAll(usage, "\n", "\n \t"))
// Print the default value only if it differs from the zero value
// for this flag type.
if isZero, err := isZeroValue(flag, flag.DefValue); err != nil {
isZeroValueErrs = append(isZeroValueErrs, err)
} else if !isZero {
if _, ok := flag.Value.(*stringValue); ok {
// put quotes on the value
fmt.Fprintf(&b, " (default %q)", flag.DefValue)
} else {
fmt.Fprintf(&b, " (default %v)", flag.DefValue)
}
}
fmt.Fprint(f.Output(), b.String(), "\n")
})
// If calling String on any zero flag.Values triggered a panic, print
// the messages after the full set of defaults so that the programmer
// knows to fix the panic.
if errs := isZeroValueErrs; len(errs) > 0 {
fmt.Fprintln(f.Output())
for _, err := range errs {
fmt.Fprintln(f.Output(), err)
}
}
}
// PrintDefaults prints, to standard error unless configured otherwise,
// a usage message showing the default settings of all defined
// command-line flags.
// For an integer valued flag x, the default output has the form
//
// -x int
// usage-message-for-x (default 7)
//
// The usage message will appear on a separate line for anything but
// a bool flag with a one-byte name. For bool flags, the type is
// omitted and if the flag name is one byte the usage message appears
// on the same line. The parenthetical default is omitted if the
// default is the zero value for the type. The listed type, here int,
// can be changed by placing a back-quoted name in the flag's usage
// string; the first such item in the message is taken to be a parameter
// name to show in the message and the back quotes are stripped from
// the message when displayed. For instance, given
//
// flag.String("I", "", "search `directory` for include files")
//
// the output will be
//
// -I directory
// search directory for include files.
//
// To change the destination for flag messages, call CommandLine.SetOutput.
func PrintDefaults() {
CommandLine.PrintDefaults()
}
// defaultUsage is the default function to print a usage message.
func (f *FlagSet) defaultUsage() {
if f.name == "" {
fmt.Fprintf(f.Output(), "Usage:\n")
} else {
fmt.Fprintf(f.Output(), "Usage of %s:\n", f.name)
}
f.PrintDefaults()
}
// NOTE: Usage is not just defaultUsage(CommandLine)
// because it serves (via godoc flag Usage) as the example
// for how to write your own usage function.
// Usage prints a usage message documenting all defined command-line flags
// to CommandLine's output, which by default is os.Stderr.
// It is called when an error occurs while parsing flags.
// The function is a variable that may be changed to point to a custom function.
// By default it prints a simple header and calls PrintDefaults; for details about the
// format of the output and how to control it, see the documentation for PrintDefaults.
// Custom usage functions may choose to exit the program; by default exiting
// happens anyway as the command line's error handling strategy is set to
// ExitOnError.
var Usage = func() {
fmt.Fprintf(CommandLine.Output(), "Usage of %s:\n", os.Args[0])
PrintDefaults()
}
// NFlag returns the number of flags that have been set.
func (f *FlagSet) NFlag() int { return len(f.actual) }
// NFlag returns the number of command-line flags that have been set.
func NFlag() int { return len(CommandLine.actual) }
// Arg returns the i'th argument. Arg(0) is the first remaining argument
// after flags have been processed. Arg returns an empty string if the
// requested element does not exist.
func (f *FlagSet) Arg(i int) string {
if i < 0 || i >= len(f.args) {
return ""
}
return f.args[i]
}
// Arg returns the i'th command-line argument. Arg(0) is the first remaining argument
// after flags have been processed. Arg returns an empty string if the
// requested element does not exist.
func Arg(i int) string {
return CommandLine.Arg(i)
}
// NArg is the number of arguments remaining after flags have been processed.
func (f *FlagSet) NArg() int { return len(f.args) }
// NArg is the number of arguments remaining after flags have been processed.
func NArg() int { return len(CommandLine.args) }
// Args returns the non-flag arguments.
func (f *FlagSet) Args() []string { return f.args }
// Args returns the non-flag command-line arguments.
func Args() []string { return CommandLine.args }
// BoolVar defines a bool flag with specified name, default value, and usage string.
// The argument p points to a bool variable in which to store the value of the flag.
func (f *FlagSet) BoolVar(p *bool, name string, value bool, usage string) {
f.Var(newBoolValue(value, p), name, usage)
}
// BoolVar defines a bool flag with specified name, default value, and usage string.
// The argument p points to a bool variable in which to store the value of the flag.
func BoolVar(p *bool, name string, value bool, usage string) {
CommandLine.Var(newBoolValue(value, p), name, usage)
}
// Bool defines a bool flag with specified name, default value, and usage string.
// The return value is the address of a bool variable that stores the value of the flag.
func (f *FlagSet) Bool(name string, value bool, usage string) *bool {
p := new(bool)
f.BoolVar(p, name, value, usage)
return p
}
// Bool defines a bool flag with specified name, default value, and usage string.
// The return value is the address of a bool variable that stores the value of the flag.
func Bool(name string, value bool, usage string) *bool {
return CommandLine.Bool(name, value, usage)
}
// IntVar defines an int flag with specified name, default value, and usage string.
// The argument p points to an int variable in which to store the value of the flag.
func (f *FlagSet) IntVar(p *int, name string, value int, usage string) {
f.Var(newIntValue(value, p), name, usage)
}
// IntVar defines an int flag with specified name, default value, and usage string.
// The argument p points to an int variable in which to store the value of the flag.
func IntVar(p *int, name string, value int, usage string) {
CommandLine.Var(newIntValue(value, p), name, usage)
}
// Int defines an int flag with specified name, default value, and usage string.
// The return value is the address of an int variable that stores the value of the flag.
func (f *FlagSet) Int(name string, value int, usage string) *int {
p := new(int)
f.IntVar(p, name, value, usage)
return p
}
// Int defines an int flag with specified name, default value, and usage string.
// The return value is the address of an int variable that stores the value of the flag.
func Int(name string, value int, usage string) *int {
return CommandLine.Int(name, value, usage)
}
// Int64Var defines an int64 flag with specified name, default value, and usage string.
// The argument p points to an int64 variable in which to store the value of the flag.
func (f *FlagSet) Int64Var(p *int64, name string, value int64, usage string) {
f.Var(newInt64Value(value, p), name, usage)
}
// Int64Var defines an int64 flag with specified name, default value, and usage string.
// The argument p points to an int64 variable in which to store the value of the flag.
func Int64Var(p *int64, name string, value int64, usage string) {
CommandLine.Var(newInt64Value(value, p), name, usage)
}
// Int64 defines an int64 flag with specified name, default value, and usage string.
// The return value is the address of an int64 variable that stores the value of the flag.
func (f *FlagSet) Int64(name string, value int64, usage string) *int64 {
p := new(int64)
f.Int64Var(p, name, value, usage)
return p
}
// Int64 defines an int64 flag with specified name, default value, and usage string.
// The return value is the address of an int64 variable that stores the value of the flag.
func Int64(name string, value int64, usage string) *int64 {
return CommandLine.Int64(name, value, usage)
}
// UintVar defines a uint flag with specified name, default value, and usage string.
// The argument p points to a uint variable in which to store the value of the flag.
func (f *FlagSet) UintVar(p *uint, name string, value uint, usage string) {
f.Var(newUintValue(value, p), name, usage)
}
// UintVar defines a uint flag with specified name, default value, and usage string.
// The argument p points to a uint variable in which to store the value of the flag.
func UintVar(p *uint, name string, value uint, usage string) {
CommandLine.Var(newUintValue(value, p), name, usage)
}
// Uint defines a uint flag with specified name, default value, and usage string.
// The return value is the address of a uint variable that stores the value of the flag.
func (f *FlagSet) Uint(name string, value uint, usage string) *uint {
p := new(uint)
f.UintVar(p, name, value, usage)
return p
}
// Uint defines a uint flag with specified name, default value, and usage string.
// The return value is the address of a uint variable that stores the value of the flag.
func Uint(name string, value uint, usage string) *uint {
return CommandLine.Uint(name, value, usage)
}
// Uint64Var defines a uint64 flag with specified name, default value, and usage string.
// The argument p points to a uint64 variable in which to store the value of the flag.
func (f *FlagSet) Uint64Var(p *uint64, name string, value uint64, usage string) {
f.Var(newUint64Value(value, p), name, usage)
}
// Uint64Var defines a uint64 flag with specified name, default value, and usage string.
// The argument p points to a uint64 variable in which to store the value of the flag.
func Uint64Var(p *uint64, name string, value uint64, usage string) {
CommandLine.Var(newUint64Value(value, p), name, usage)
}
// Uint64 defines a uint64 flag with specified name, default value, and usage string.
// The return value is the address of a uint64 variable that stores the value of the flag.
func (f *FlagSet) Uint64(name string, value uint64, usage string) *uint64 {
p := new(uint64)
f.Uint64Var(p, name, value, usage)
return p
}
// Uint64 defines a uint64 flag with specified name, default value, and usage string.
// The return value is the address of a uint64 variable that stores the value of the flag.
func Uint64(name string, value uint64, usage string) *uint64 {
return CommandLine.Uint64(name, value, usage)
}
// StringVar defines a string flag with specified name, default value, and usage string.
// The argument p points to a string variable in which to store the value of the flag.
func (f *FlagSet) StringVar(p *string, name string, value string, usage string) {
f.Var(newStringValue(value, p), name, usage)
}
// StringVar defines a string flag with specified name, default value, and usage string.
// The argument p points to a string variable in which to store the value of the flag.
func StringVar(p *string, name string, value string, usage string) {
CommandLine.Var(newStringValue(value, p), name, usage)
}
// String defines a string flag with specified name, default value, and usage string.
// The return value is the address of a string variable that stores the value of the flag.
func (f *FlagSet) String(name string, value string, usage string) *string {
p := new(string)
f.StringVar(p, name, value, usage)
return p
}
// String defines a string flag with specified name, default value, and usage string.
// The return value is the address of a string variable that stores the value of the flag.
func String(name string, value string, usage string) *string {
return CommandLine.String(name, value, usage)
}
// Float64Var defines a float64 flag with specified name, default value, and usage string.
// The argument p points to a float64 variable in which to store the value of the flag.
func (f *FlagSet) Float64Var(p *float64, name string, value float64, usage string) {
f.Var(newFloat64Value(value, p), name, usage)
}
// Float64Var defines a float64 flag with specified name, default value, and usage string.
// The argument p points to a float64 variable in which to store the value of the flag.
func Float64Var(p *float64, name string, value float64, usage string) {
CommandLine.Var(newFloat64Value(value, p), name, usage)
}
// Float64 defines a float64 flag with specified name, default value, and usage string.
// The return value is the address of a float64 variable that stores the value of the flag.
func (f *FlagSet) Float64(name string, value float64, usage string) *float64 {
p := new(float64)
f.Float64Var(p, name, value, usage)
return p
}
// Float64 defines a float64 flag with specified name, default value, and usage string.
// The return value is the address of a float64 variable that stores the value of the flag.
func Float64(name string, value float64, usage string) *float64 {
return CommandLine.Float64(name, value, usage)
}
// DurationVar defines a time.Duration flag with specified name, default value, and usage string.
// The argument p points to a time.Duration variable in which to store the value of the flag.
// The flag accepts a value acceptable to time.ParseDuration.
func (f *FlagSet) DurationVar(p *time.Duration, name string, value time.Duration, usage string) {
f.Var(newDurationValue(value, p), name, usage)
}
// DurationVar defines a time.Duration flag with specified name, default value, and usage string.
// The argument p points to a time.Duration variable in which to store the value of the flag.
// The flag accepts a value acceptable to time.ParseDuration.
func DurationVar(p *time.Duration, name string, value time.Duration, usage string) {
CommandLine.Var(newDurationValue(value, p), name, usage)
}
// Duration defines a time.Duration flag with specified name, default value, and usage string.
// The return value is the address of a time.Duration variable that stores the value of the flag.
// The flag accepts a value acceptable to time.ParseDuration.
func (f *FlagSet) Duration(name string, value time.Duration, usage string) *time.Duration {
p := new(time.Duration)
f.DurationVar(p, name, value, usage)
return p
}
// Duration defines a time.Duration flag with specified name, default value, and usage string.
// The return value is the address of a time.Duration variable that stores the value of the flag.
// The flag accepts a value acceptable to time.ParseDuration.
func Duration(name string, value time.Duration, usage string) *time.Duration {
return CommandLine.Duration(name, value, usage)
}
// TextVar defines a flag with a specified name, default value, and usage string.
// The argument p must be a pointer to a variable that will hold the value
// of the flag, and p must implement encoding.TextUnmarshaler.
// If the flag is used, the flag value will be passed to p's UnmarshalText method.
// The type of the default value must be the same as the type of p.
func (f *FlagSet) TextVar(p encoding.TextUnmarshaler, name string, value encoding.TextMarshaler, usage string) {
f.Var(newTextValue(value, p), name, usage)
}
// TextVar defines a flag with a specified name, default value, and usage string.
// The argument p must be a pointer to a variable that will hold the value
// of the flag, and p must implement encoding.TextUnmarshaler.
// If the flag is used, the flag value will be passed to p's UnmarshalText method.
// The type of the default value must be the same as the type of p.
func TextVar(p encoding.TextUnmarshaler, name string, value encoding.TextMarshaler, usage string) {
CommandLine.Var(newTextValue(value, p), name, usage)
}
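// A sketch using time.Time, which implements both encoding.TextMarshaler
// and (via *time.Time) encoding.TextUnmarshaler:
//
//	var deadline time.Time
//	TextVar(&deadline, "deadline", time.Now(), "deadline in RFC 3339 form")
//
// A value such as -deadline=2006-01-02T15:04:05Z is then parsed by
// time.Time's UnmarshalText.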
// Func defines a flag with the specified name and usage string.
// Each time the flag is seen, fn is called with the value of the flag.
// If fn returns a non-nil error, it will be treated as a flag value parsing error.
func (f *FlagSet) Func(name, usage string, fn func(string) error) {
f.Var(funcValue(fn), name, usage)
}
// Func defines a flag with the specified name and usage string.
// Each time the flag is seen, fn is called with the value of the flag.
// If fn returns a non-nil error, it will be treated as a flag value parsing error.
func Func(name, usage string, fn func(string) error) {
CommandLine.Func(name, usage, fn)
}
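// A sketch of a repeatable flag built on Func (hypothetical name "host"):
//
//	var hosts []string
//	Func("host", "host to contact; may be repeated", func(s string) error {
//		hosts = append(hosts, s)
//		return nil
//	})
//
// Because fn is called once per occurrence, repeated -host flags accumulate.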
// Var defines a flag with the specified name and usage string. The type and
// value of the flag are represented by the first argument, of type Value, which
// typically holds a user-defined implementation of Value. For instance, the
// caller could create a flag that turns a comma-separated string into a slice
// of strings by giving the slice the methods of Value; in particular, Set would
// decompose the comma-separated string into the slice.
func (f *FlagSet) Var(value Value, name string, usage string) {
// Flag must not begin "-" or contain "=".
if strings.HasPrefix(name, "-") {
panic(f.sprintf("flag %q begins with -", name))
} else if strings.Contains(name, "=") {
panic(f.sprintf("flag %q contains =", name))
}
// Remember the default value as a string; it won't change.
flag := &Flag{name, usage, value, value.String()}
_, alreadythere := f.formal[name]
if alreadythere {
var msg string
if f.name == "" {
msg = f.sprintf("flag redefined: %s", name)
} else {
msg = f.sprintf("%s flag redefined: %s", f.name, name)
}
panic(msg) // Happens only if flags are declared with identical names
}
if f.formal == nil {
f.formal = make(map[string]*Flag)
}
f.formal[name] = flag
}
// Var defines a flag with the specified name and usage string. The type and
// value of the flag are represented by the first argument, of type Value, which
// typically holds a user-defined implementation of Value. For instance, the
// caller could create a flag that turns a comma-separated string into a slice
// of strings by giving the slice the methods of Value; in particular, Set would
// decompose the comma-separated string into the slice.
func Var(value Value, name string, usage string) {
CommandLine.Var(value, name, usage)
}
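// A sketch of the comma-separated-list flag described above (illustrative,
// hypothetical names):
//
//	type listValue []string
//
//	func (l *listValue) String() string { return strings.Join(*l, ",") }
//
//	func (l *listValue) Set(s string) error {
//		*l = strings.Split(s, ",")
//		return nil
//	}
//
//	var dirs listValue
//	func init() { Var(&dirs, "dirs", "comma-separated `list` of directories") }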
// sprintf formats the message, prints it to output, and returns it.
func (f *FlagSet) sprintf(format string, a ...any) string {
msg := fmt.Sprintf(format, a...)
fmt.Fprintln(f.Output(), msg)
return msg
}
// failf prints to standard error a formatted error and usage message and
// returns the error.
func (f *FlagSet) failf(format string, a ...any) error {
msg := f.sprintf(format, a...)
f.usage()
return errors.New(msg)
}
// usage calls the Usage method for the flag set if one is specified,
// or the appropriate default usage function otherwise.
func (f *FlagSet) usage() {
if f.Usage == nil {
f.defaultUsage()
} else {
f.Usage()
}
}
// parseOne parses one flag. It reports whether a flag was seen.
func (f *FlagSet) parseOne() (bool, error) {
if len(f.args) == 0 {
return false, nil
}
s := f.args[0]
if len(s) < 2 || s[0] != '-' {
return false, nil
}
numMinuses := 1
if s[1] == '-' {
numMinuses++
if len(s) == 2 { // "--" terminates the flags
f.args = f.args[1:]
return false, nil
}
}
name := s[numMinuses:]
if len(name) == 0 || name[0] == '-' || name[0] == '=' {
return false, f.failf("bad flag syntax: %s", s)
}
// it's a flag. does it have an argument?
f.args = f.args[1:]
hasValue := false
value := ""
for i := 1; i < len(name); i++ { // equals cannot be first
if name[i] == '=' {
value = name[i+1:]
hasValue = true
name = name[0:i]
break
}
}
flag, ok := f.formal[name]
if !ok {
if name == "help" || name == "h" { // special case for nice help message.
f.usage()
return false, ErrHelp
}
return false, f.failf("flag provided but not defined: -%s", name)
}
if fv, ok := flag.Value.(boolFlag); ok && fv.IsBoolFlag() { // special case: doesn't need an arg
if hasValue {
if err := fv.Set(value); err != nil {
return false, f.failf("invalid boolean value %q for -%s: %v", value, name, err)
}
} else {
if err := fv.Set("true"); err != nil {
return false, f.failf("invalid boolean flag %s: %v", name, err)
}
}
} else {
// It must have a value, which might be the next argument.
if !hasValue && len(f.args) > 0 {
// value is the next arg
hasValue = true
value, f.args = f.args[0], f.args[1:]
}
if !hasValue {
return false, f.failf("flag needs an argument: -%s", name)
}
if err := flag.Value.Set(value); err != nil {
return false, f.failf("invalid value %q for flag -%s: %v", value, name, err)
}
}
if f.actual == nil {
f.actual = make(map[string]*Flag)
}
f.actual[name] = flag
return true, nil
}
// Parse parses flag definitions from the argument list, which should not
// include the command name. Must be called after all flags in the FlagSet
// are defined and before flags are accessed by the program.
// The return value will be ErrHelp if -help or -h were set but not defined.
func (f *FlagSet) Parse(arguments []string) error {
f.parsed = true
f.args = arguments
for {
seen, err := f.parseOne()
if seen {
continue
}
if err == nil {
break
}
switch f.errorHandling {
case ContinueOnError:
return err
case ExitOnError:
if err == ErrHelp {
os.Exit(0)
}
os.Exit(2)
case PanicOnError:
panic(err)
}
}
return nil
}
// Parsed reports whether f.Parse has been called.
func (f *FlagSet) Parsed() bool {
return f.parsed
}
// Parse parses the command-line flags from os.Args[1:]. Must be called
// after all flags are defined and before flags are accessed by the program.
func Parse() {
// Ignore errors; CommandLine is set for ExitOnError.
CommandLine.Parse(os.Args[1:])
}
// Parsed reports whether the command-line flags have been parsed.
func Parsed() bool {
return CommandLine.Parsed()
}
// CommandLine is the default set of command-line flags, parsed from os.Args.
// The top-level functions such as BoolVar, Arg, and so on are wrappers for the
// methods of CommandLine.
var CommandLine = NewFlagSet(os.Args[0], ExitOnError)
func init() {
// Override generic FlagSet default Usage with call to global Usage.
// Note: This is not CommandLine.Usage = Usage,
// because we want any eventual call to use any updated value of Usage,
// not the value it has when this line is run.
CommandLine.Usage = commandLineUsage
}
func commandLineUsage() {
Usage()
}
// NewFlagSet returns a new, empty flag set with the specified name and
// error handling property. If the name is not empty, it will be printed
// in the default usage message and in error messages.
func NewFlagSet(name string, errorHandling ErrorHandling) *FlagSet {
f := &FlagSet{
name: name,
errorHandling: errorHandling,
}
f.Usage = f.defaultUsage
return f
}
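// A sketch of the subcommand pattern mentioned in the package comment
// (hypothetical command layout "prog add [flags]"):
//
//	addCmd := NewFlagSet("add", ContinueOnError)
//	force := addCmd.Bool("force", false, "overwrite existing entries")
//	if err := addCmd.Parse(os.Args[2:]); err != nil {
//		// ContinueOnError: report err and return
//	}
//	_ = *force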
// Init sets the name and error handling property for a flag set.
// By default, the zero FlagSet uses an empty name and the
// ContinueOnError error handling policy.
func (f *FlagSet) Init(name string, errorHandling ErrorHandling) {
f.name = name
f.errorHandling = errorHandling
}
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fmt
import (
"errors"
"sort"
)
// Errorf formats according to a format specifier and returns the string as a
// value that satisfies error.
//
// If the format specifier includes a %w verb with an error operand,
// the returned error will implement an Unwrap method returning the operand.
// If there is more than one %w verb, the returned error will implement an
// Unwrap method returning a []error containing all the %w operands in the
// order they appear in the arguments.
// It is invalid to supply the %w verb with an operand that does not implement
// the error interface. The %w verb is otherwise a synonym for %v.
func Errorf(format string, a ...any) error {
p := newPrinter()
p.wrapErrs = true
p.doPrintf(format, a)
s := string(p.buf)
var err error
switch len(p.wrappedErrs) {
case 0:
err = errors.New(s)
case 1:
w := &wrapError{msg: s}
w.err, _ = a[p.wrappedErrs[0]].(error)
err = w
default:
if p.reordered {
sort.Ints(p.wrappedErrs)
}
var errs []error
for i, argNum := range p.wrappedErrs {
if i > 0 && p.wrappedErrs[i-1] == argNum {
continue
}
if e, ok := a[argNum].(error); ok {
errs = append(errs, e)
}
}
err = &wrapErrors{s, errs}
}
p.free()
return err
}
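// A sketch of the multiple-%w case (hypothetical errSetup and errCleanup):
//
//	err := Errorf("setup failed: %w; cleanup failed: %w", errSetup, errCleanup)
//	// err has an Unwrap() []error method, so
//	// errors.Is(err, errSetup) and errors.Is(err, errCleanup) both report true.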
type wrapError struct {
msg string
err error
}
func (e *wrapError) Error() string {
return e.msg
}
func (e *wrapError) Unwrap() error {
return e.err
}
type wrapErrors struct {
msg string
errs []error
}
func (e *wrapErrors) Error() string {
return e.msg
}
func (e *wrapErrors) Unwrap() []error {
return e.errs
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fmt
import (
"strconv"
"unicode/utf8"
)
const (
ldigits = "0123456789abcdefx"
udigits = "0123456789ABCDEFX"
)
const (
signed = true
unsigned = false
)
// flags placed in a separate struct for easy clearing.
type fmtFlags struct {
widPresent bool
precPresent bool
minus bool
plus bool
sharp bool
space bool
zero bool
// For the formats %+v %#v, we set the plusV/sharpV flags
// and clear the plus/sharp flags since %+v and %#v are in effect
// different, flagless formats set at the top level.
plusV bool
sharpV bool
}
// A fmt is the raw formatter used by Printf etc.
// It prints into a buffer that must be set up separately.
type fmt struct {
buf *buffer
fmtFlags
wid int // width
prec int // precision
// intbuf is large enough to store %b of an int64 with a sign and
// avoids padding at the end of the struct on 32 bit architectures.
intbuf [68]byte
}
func (f *fmt) clearflags() {
f.fmtFlags = fmtFlags{}
}
func (f *fmt) init(buf *buffer) {
f.buf = buf
f.clearflags()
}
// writePadding generates n bytes of padding.
func (f *fmt) writePadding(n int) {
if n <= 0 { // No padding bytes needed.
return
}
buf := *f.buf
oldLen := len(buf)
newLen := oldLen + n
// Make enough room for padding.
if newLen > cap(buf) {
buf = make(buffer, cap(buf)*2+n)
copy(buf, *f.buf)
}
// Decide which byte the padding should be filled with.
padByte := byte(' ')
if f.zero {
padByte = byte('0')
}
// Fill padding with padByte.
padding := buf[oldLen:newLen]
for i := range padding {
padding[i] = padByte
}
*f.buf = buf[:newLen]
}
// pad appends b to f.buf, padded on left (!f.minus) or right (f.minus).
func (f *fmt) pad(b []byte) {
if !f.widPresent || f.wid == 0 {
f.buf.write(b)
return
}
width := f.wid - utf8.RuneCount(b)
if !f.minus {
// left padding
f.writePadding(width)
f.buf.write(b)
} else {
// right padding
f.buf.write(b)
f.writePadding(width)
}
}
// padString appends s to f.buf, padded on left (!f.minus) or right (f.minus).
func (f *fmt) padString(s string) {
if !f.widPresent || f.wid == 0 {
f.buf.writeString(s)
return
}
width := f.wid - utf8.RuneCountInString(s)
if !f.minus {
// left padding
f.writePadding(width)
f.buf.writeString(s)
} else {
// right padding
f.buf.writeString(s)
f.writePadding(width)
}
}
// fmtBoolean formats a boolean.
func (f *fmt) fmtBoolean(v bool) {
if v {
f.padString("true")
} else {
f.padString("false")
}
}
// fmtUnicode formats a uint64 as "U+0078" or with f.sharp set as "U+0078 'x'".
func (f *fmt) fmtUnicode(u uint64) {
buf := f.intbuf[0:]
// With default precision set the maximum needed buf length is 18
// for formatting -1 with %#U ("U+FFFFFFFFFFFFFFFF") which fits
// into the already allocated intbuf with a capacity of 68 bytes.
prec := 4
if f.precPresent && f.prec > 4 {
prec = f.prec
// Compute space needed for "U+" , number, " '", character, "'".
width := 2 + prec + 2 + utf8.UTFMax + 1
if width > len(buf) {
buf = make([]byte, width)
}
}
// Format into buf, ending at buf[i]. Formatting numbers is easier right-to-left.
i := len(buf)
// For %#U we want to add a space and a quoted character at the end of the buffer.
if f.sharp && u <= utf8.MaxRune && strconv.IsPrint(rune(u)) {
i--
buf[i] = '\''
i -= utf8.RuneLen(rune(u))
utf8.EncodeRune(buf[i:], rune(u))
i--
buf[i] = '\''
i--
buf[i] = ' '
}
// Format the Unicode code point u as a hexadecimal number.
for u >= 16 {
i--
buf[i] = udigits[u&0xF]
prec--
u >>= 4
}
i--
buf[i] = udigits[u]
prec--
// Add zeros in front of the number until requested precision is reached.
for prec > 0 {
i--
buf[i] = '0'
prec--
}
// Add a leading "U+".
i--
buf[i] = '+'
i--
buf[i] = 'U'
oldZero := f.zero
f.zero = false
f.pad(buf[i:])
f.zero = oldZero
}
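// Sketch of the %U path above (hypothetical helper, for illustration only):
// with the '#' flag set, fmtUnicode appends the printable character after
// the code point.
func exampleUnicodeVerb() string {
	var b buffer
	var f fmt
	f.init(&b)
	f.sharp = true
	f.fmtUnicode(uint64('x'))
	return string(b) // "U+0078 'x'"
}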
// fmtInteger formats signed and unsigned integers.
func (f *fmt) fmtInteger(u uint64, base int, isSigned bool, verb rune, digits string) {
negative := isSigned && int64(u) < 0
if negative {
u = -u
}
buf := f.intbuf[0:]
// The already allocated f.intbuf with a capacity of 68 bytes
// is large enough for integer formatting when no precision or width is set.
if f.widPresent || f.precPresent {
// Account 3 extra bytes for possible addition of a sign and "0x".
width := 3 + f.wid + f.prec // wid and prec are always positive.
if width > len(buf) {
// We're going to need a bigger boat.
buf = make([]byte, width)
}
}
// Two ways to ask for extra leading zero digits: %.3d or %03d.
// If both are specified the f.zero flag is ignored and
// padding with spaces is used instead.
prec := 0
if f.precPresent {
prec = f.prec
// A precision of 0 with a value of 0 means "print nothing", though padding is still applied.
if prec == 0 && u == 0 {
oldZero := f.zero
f.zero = false
f.writePadding(f.wid)
f.zero = oldZero
return
}
} else if f.zero && f.widPresent {
prec = f.wid
if negative || f.plus || f.space {
prec-- // leave room for sign
}
}
// Because printing is easier right-to-left: format u into buf, ending at buf[i].
// We could make things marginally faster by splitting the 32-bit case out
// into a separate block but it's not worth the duplication, so u has 64 bits.
i := len(buf)
// Use constants for the division and modulo for more efficient code.
// Switch cases ordered by popularity.
switch base {
case 10:
for u >= 10 {
i--
next := u / 10
buf[i] = byte('0' + u - next*10)
u = next
}
case 16:
for u >= 16 {
i--
buf[i] = digits[u&0xF]
u >>= 4
}
case 8:
for u >= 8 {
i--
buf[i] = byte('0' + u&7)
u >>= 3
}
case 2:
for u >= 2 {
i--
buf[i] = byte('0' + u&1)
u >>= 1
}
default:
panic("fmt: unknown base; can't happen")
}
i--
buf[i] = digits[u]
for i > 0 && prec > len(buf)-i {
i--
buf[i] = '0'
}
// Various prefixes: 0x, -, etc.
if f.sharp {
switch base {
case 2:
// Add a leading 0b.
i--
buf[i] = 'b'
i--
buf[i] = '0'
case 8:
if buf[i] != '0' {
i--
buf[i] = '0'
}
case 16:
// Add a leading 0x or 0X.
i--
buf[i] = digits[16]
i--
buf[i] = '0'
}
}
if verb == 'O' {
i--
buf[i] = 'o'
i--
buf[i] = '0'
}
if negative {
i--
buf[i] = '-'
} else if f.plus {
i--
buf[i] = '+'
} else if f.space {
i--
buf[i] = ' '
}
// Left padding with zeros has already been handled like precision earlier
// or the f.zero flag is ignored due to an explicitly set precision.
oldZero := f.zero
f.zero = false
f.pad(buf[i:])
f.zero = oldZero
}
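// Sketch (illustrative, not part of the package): fmtInteger with base 16,
// the '#' flag, and the lowercase digit set yields a 0x-prefixed number.
func exampleHexInteger() string {
	var b buffer
	var f fmt
	f.init(&b)
	f.sharp = true
	f.fmtInteger(42, 16, unsigned, 'x', ldigits)
	return string(b) // "0x2a"
}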
// truncateString truncates the string s to the specified precision, if present.
func (f *fmt) truncateString(s string) string {
if f.precPresent {
n := f.prec
for i := range s {
n--
if n < 0 {
return s[:i]
}
}
}
return s
}
// truncate truncates the byte slice b as a string of the specified precision, if present.
func (f *fmt) truncate(b []byte) []byte {
if f.precPresent {
n := f.prec
for i := 0; i < len(b); {
n--
if n < 0 {
return b[:i]
}
wid := 1
if b[i] >= utf8.RuneSelf {
_, wid = utf8.DecodeRune(b[i:])
}
i += wid
}
}
return b
}
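// Sketch: both truncation helpers count runes, not bytes, so a multi-byte
// UTF-8 sequence is never split. Hypothetical helper for illustration.
func exampleTruncate() (string, string) {
	var f fmt
	f.precPresent = true
	f.prec = 1
	a := f.truncateString("héllo") // "h" (1 rune)
	f.prec = 2
	b := f.truncateString("héllo") // "hé" (2 runes, 3 bytes)
	return a, b
}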
// fmtS formats a string.
func (f *fmt) fmtS(s string) {
s = f.truncateString(s)
f.padString(s)
}
// fmtBs formats the byte slice b as if it were a string formatted with fmtS.
func (f *fmt) fmtBs(b []byte) {
b = f.truncate(b)
f.pad(b)
}
// fmtSbx formats a string or byte slice as a hexadecimal encoding of its bytes.
func (f *fmt) fmtSbx(s string, b []byte, digits string) {
length := len(b)
if b == nil {
// No byte slice present. Assume string s should be encoded.
length = len(s)
}
// Set length to not process more bytes than the precision demands.
if f.precPresent && f.prec < length {
length = f.prec
}
// Compute width of the encoding taking into account the f.sharp and f.space flag.
width := 2 * length
if width > 0 {
if f.space {
// Each element encoded by two hexadecimals will get a leading 0x or 0X.
if f.sharp {
width *= 2
}
// Elements will be separated by a space.
width += length - 1
} else if f.sharp {
// Only a leading 0x or 0X will be added for the whole string.
width += 2
}
} else { // The byte slice or string that should be encoded is empty.
if f.widPresent {
f.writePadding(f.wid)
}
return
}
// Handle padding to the left.
if f.widPresent && f.wid > width && !f.minus {
f.writePadding(f.wid - width)
}
// Write the encoding directly into the output buffer.
buf := *f.buf
if f.sharp {
// Add leading 0x or 0X.
buf = append(buf, '0', digits[16])
}
var c byte
for i := 0; i < length; i++ {
if f.space && i > 0 {
// Separate elements with a space.
buf = append(buf, ' ')
if f.sharp {
// Add leading 0x or 0X for each element.
buf = append(buf, '0', digits[16])
}
}
if b != nil {
c = b[i] // Take a byte from the input byte slice.
} else {
c = s[i] // Take a byte from the input string.
}
// Encode each byte as two hexadecimal digits.
buf = append(buf, digits[c>>4], digits[c&0xF])
}
*f.buf = buf
// Handle padding to the right.
if f.widPresent && f.wid > width && f.minus {
f.writePadding(f.wid - width)
}
}
// fmtSx formats a string as a hexadecimal encoding of its bytes.
func (f *fmt) fmtSx(s, digits string) {
f.fmtSbx(s, nil, digits)
}
// fmtBx formats a byte slice as a hexadecimal encoding of its bytes.
func (f *fmt) fmtBx(b []byte, digits string) {
f.fmtSbx("", b, digits)
}
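// Sketch of the width accounting in fmtSbx (hypothetical helper): with both
// the ' ' and '#' flags, every byte gets its own 0x prefix and a separating
// space, matching the widths the function reserves above.
func exampleHexDump() string {
	var b buffer
	var f fmt
	f.init(&b)
	f.space, f.sharp = true, true
	f.fmtBx([]byte{0xde, 0xad}, ldigits)
	return string(b) // "0xde 0xad"
}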
// fmtQ formats a string as a double-quoted, escaped Go string constant.
// If f.sharp is set a raw (backquoted) string may be returned instead
// if the string does not contain any control characters other than tab.
func (f *fmt) fmtQ(s string) {
s = f.truncateString(s)
if f.sharp && strconv.CanBackquote(s) {
f.padString("`" + s + "`")
return
}
buf := f.intbuf[:0]
if f.plus {
f.pad(strconv.AppendQuoteToASCII(buf, s))
} else {
f.pad(strconv.AppendQuote(buf, s))
}
}
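// Sketch (illustrative): with the '#' flag, fmtQ emits a raw backquoted
// literal whenever strconv.CanBackquote allows it.
func exampleQuote() string {
	var b buffer
	var f fmt
	f.init(&b)
	f.sharp = true
	f.fmtQ(`a "b"`)
	return string(b) // a "b" surrounded by backquotes, no escaping
}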
// fmtC formats an integer as a Unicode character.
// If the character is not valid Unicode, it will print '\ufffd'.
func (f *fmt) fmtC(c uint64) {
// Explicitly check whether c exceeds utf8.MaxRune since the conversion
// of a uint64 to a rune may lose precision that indicates an overflow.
r := rune(c)
if c > utf8.MaxRune {
r = utf8.RuneError
}
buf := f.intbuf[:0]
f.pad(utf8.AppendRune(buf, r))
}
// fmtQc formats an integer as a single-quoted, escaped Go character constant.
// If the character is not valid Unicode, it will print '\ufffd'.
func (f *fmt) fmtQc(c uint64) {
r := rune(c)
if c > utf8.MaxRune {
r = utf8.RuneError
}
buf := f.intbuf[:0]
if f.plus {
f.pad(strconv.AppendQuoteRuneToASCII(buf, r))
} else {
f.pad(strconv.AppendQuoteRune(buf, r))
}
}
// fmtFloat formats a float64. It assumes that verb is a valid format specifier
// for strconv.AppendFloat and therefore fits into a byte.
func (f *fmt) fmtFloat(v float64, size int, verb rune, prec int) {
// Explicit precision in format specifier overrules default precision.
if f.precPresent {
prec = f.prec
}
// Format number, reserving space for leading + sign if needed.
num := strconv.AppendFloat(f.intbuf[:1], v, byte(verb), prec, size)
if num[1] == '-' || num[1] == '+' {
num = num[1:]
} else {
num[0] = '+'
}
// f.space means to add a leading space instead of a "+" sign unless
// the sign is explicitly asked for by f.plus.
if f.space && num[0] == '+' && !f.plus {
num[0] = ' '
}
// Special handling for infinities and NaN,
// which don't look like a number so shouldn't be padded with zeros.
if num[1] == 'I' || num[1] == 'N' {
oldZero := f.zero
f.zero = false
// Remove sign before NaN if not asked for.
if num[1] == 'N' && !f.space && !f.plus {
num = num[1:]
}
f.pad(num)
f.zero = oldZero
return
}
// The sharp flag forces printing a decimal point for non-binary formats
// and retains trailing zeros, which we may need to restore.
if f.sharp && verb != 'b' {
digits := 0
switch verb {
case 'v', 'g', 'G', 'x':
digits = prec
// If no precision is set explicitly use a precision of 6.
if digits == -1 {
digits = 6
}
}
// Buffer pre-allocated with enough room for
// exponent notations of the form "e+123" or "p-1023".
var tailBuf [6]byte
tail := tailBuf[:0]
hasDecimalPoint := false
sawNonzeroDigit := false
// Starting from i = 1 to skip sign at num[0].
for i := 1; i < len(num); i++ {
switch num[i] {
case '.':
hasDecimalPoint = true
case 'p', 'P':
tail = append(tail, num[i:]...)
num = num[:i]
case 'e', 'E':
if verb != 'x' && verb != 'X' {
tail = append(tail, num[i:]...)
num = num[:i]
break
}
fallthrough
default:
if num[i] != '0' {
sawNonzeroDigit = true
}
// Count significant digits after the first non-zero digit.
if sawNonzeroDigit {
digits--
}
}
}
if !hasDecimalPoint {
// Leading digit 0 should contribute once to digits.
if len(num) == 2 && num[1] == '0' {
digits--
}
num = append(num, '.')
}
for digits > 0 {
num = append(num, '0')
digits--
}
num = append(num, tail...)
}
// We want a sign if asked for and if the sign is not positive.
if f.plus || num[0] != '+' {
// If we're zero padding to the left we want the sign before the leading zeros.
// Achieve this by writing the sign out and then padding the unsigned number.
if f.zero && f.widPresent && f.wid > len(num) {
f.buf.writeByte(num[0])
f.writePadding(f.wid - len(num))
f.buf.write(num[1:])
return
}
f.pad(num)
return
}
// No sign to show and the number is positive; just print the unsigned number.
f.pad(num[1:])
}
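// Sketch of the '#' flag behavior above (hypothetical helper): for %#g the
// trailing zeros normally dropped by strconv are restored up to the default
// six significant digits.
func exampleSharpFloat() string {
	var b buffer
	var f fmt
	f.init(&b)
	f.sharp = true
	f.fmtFloat(1.5, 64, 'g', -1)
	return string(b) // "1.50000"
}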
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fmt
import (
"internal/fmtsort"
"io"
"os"
"reflect"
"strconv"
"sync"
"unicode/utf8"
)
// Strings for use with buffer.writeString.
// This is less overhead than using buffer.write with byte arrays.
const (
commaSpaceString = ", "
nilAngleString = "<nil>"
nilParenString = "(nil)"
nilString = "nil"
mapString = "map["
percentBangString = "%!"
missingString = "(MISSING)"
badIndexString = "(BADINDEX)"
panicString = "(PANIC="
extraString = "%!(EXTRA "
badWidthString = "%!(BADWIDTH)"
badPrecString = "%!(BADPREC)"
noVerbString = "%!(NOVERB)"
invReflectString = "<invalid reflect.Value>"
)
// State represents the printer state passed to custom formatters.
// It provides access to the io.Writer interface plus information about
// the flags and options for the operand's format specifier.
type State interface {
// Write is the function to call to emit formatted output to be printed.
Write(b []byte) (n int, err error)
// Width returns the value of the width option and whether it has been set.
Width() (wid int, ok bool)
// Precision returns the value of the precision option and whether it has been set.
Precision() (prec int, ok bool)
// Flag reports whether the flag c, a character, has been set.
Flag(c int) bool
}
// Formatter is implemented by any value that has a Format method.
// The implementation controls how State and rune are interpreted,
// and may call Sprint(f) or Fprint(f) etc. to generate its output.
type Formatter interface {
Format(f State, verb rune)
}
// Stringer is implemented by any value that has a String method,
// which defines the “native” format for that value.
// The String method is used to print values passed as an operand
// to any format that accepts a string or to an unformatted printer
// such as Print.
type Stringer interface {
String() string
}
// GoStringer is implemented by any value that has a GoString method,
// which defines the Go syntax for that value.
// The GoString method is used to print values passed as an operand
// to a %#v format.
type GoStringer interface {
GoString() string
}
// FormatString returns a string representing the fully qualified formatting
// directive captured by the State, followed by the argument verb. (State does not
// itself contain the verb.) The result has a leading percent sign followed by any
// flags, the width, and the precision. Missing flags, width, and precision are
// omitted. This function allows a Formatter to reconstruct the original
// directive triggering the call to Format.
func FormatString(state State, verb rune) string {
var tmp [16]byte // Use a local buffer.
b := append(tmp[:0], '%')
for _, c := range " +-#0" { // All known flags
if state.Flag(int(c)) { // The argument is an int for historical reasons.
b = append(b, byte(c))
}
}
if w, ok := state.Width(); ok {
b = strconv.AppendInt(b, int64(w), 10)
}
if p, ok := state.Precision(); ok {
b = append(b, '.')
b = strconv.AppendInt(b, int64(p), 10)
}
b = utf8.AppendRune(b, verb)
return string(b)
}
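// Sketch (hypothetical type, for illustration only): a Formatter can use
// FormatString to rebuild the directive it was invoked with and delegate,
// preserving all flags, width, and precision.
type exampleWrapped struct{ v int }

func (w exampleWrapped) Format(st State, verb rune) {
	// For Printf("%+6d", exampleWrapped{7}) this re-runs "%+6d" on w.v.
	Fprintf(st, FormatString(st, verb), w.v)
}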
// Use simple []byte instead of bytes.Buffer to avoid large dependency.
type buffer []byte
func (b *buffer) write(p []byte) {
*b = append(*b, p...)
}
func (b *buffer) writeString(s string) {
*b = append(*b, s...)
}
func (b *buffer) writeByte(c byte) {
*b = append(*b, c)
}
func (bp *buffer) writeRune(r rune) {
*bp = utf8.AppendRune(*bp, r)
}
// pp is used to store a printer's state and is reused with sync.Pool to avoid allocations.
type pp struct {
buf buffer
// arg holds the current item, as an interface{}.
arg any
// value is used instead of arg for reflect values.
value reflect.Value
// fmt is used to format basic items such as integers or strings.
fmt fmt
// reordered records whether the format string used argument reordering.
reordered bool
// goodArgNum records whether the most recent reordering directive was valid.
goodArgNum bool
// panicking is set by catchPanic to avoid infinite panic, recover, panic, ... recursion.
panicking bool
// erroring is set when printing an error string to guard against calling handleMethods.
erroring bool
// wrapErrs is set when the format string may contain a %w verb.
wrapErrs bool
// wrappedErrs records the targets of the %w verb.
wrappedErrs []int
}
var ppFree = sync.Pool{
New: func() any { return new(pp) },
}
// newPrinter allocates a new pp struct or grabs a cached one.
func newPrinter() *pp {
p := ppFree.Get().(*pp)
p.panicking = false
p.erroring = false
p.wrapErrs = false
p.fmt.init(&p.buf)
return p
}
// free saves used pp structs in ppFree; avoids an allocation per invocation.
func (p *pp) free() {
// Proper usage of a sync.Pool requires each entry to have approximately
// the same memory cost. To obtain this property when the stored type
// contains a variably-sized buffer, we add a hard limit on the maximum
// buffer to place back in the pool. If the buffer is larger than the
// limit, we drop the buffer and recycle just the printer.
//
// See https://golang.org/issue/23199.
if cap(p.buf) > 64*1024 {
p.buf = nil
} else {
p.buf = p.buf[:0]
}
if cap(p.wrappedErrs) > 8 {
p.wrappedErrs = nil
}
p.arg = nil
p.value = reflect.Value{}
p.wrappedErrs = p.wrappedErrs[:0]
ppFree.Put(p)
}
func (p *pp) Width() (wid int, ok bool) { return p.fmt.wid, p.fmt.widPresent }
func (p *pp) Precision() (prec int, ok bool) { return p.fmt.prec, p.fmt.precPresent }
func (p *pp) Flag(b int) bool {
switch b {
case '-':
return p.fmt.minus
case '+':
return p.fmt.plus || p.fmt.plusV
case '#':
return p.fmt.sharp || p.fmt.sharpV
case ' ':
return p.fmt.space
case '0':
return p.fmt.zero
}
return false
}
// Implement Write so we can call Fprintf on a pp (through State), for
// recursive use in custom verbs.
func (p *pp) Write(b []byte) (ret int, err error) {
p.buf.write(b)
return len(b), nil
}
// Implement WriteString so that we can call io.WriteString
// on a pp (through state), for efficiency.
func (p *pp) WriteString(s string) (ret int, err error) {
p.buf.writeString(s)
return len(s), nil
}
// These routines end in 'f' and take a format string.
// Fprintf formats according to a format specifier and writes to w.
// It returns the number of bytes written and any write error encountered.
func Fprintf(w io.Writer, format string, a ...any) (n int, err error) {
p := newPrinter()
p.doPrintf(format, a)
n, err = w.Write(p.buf)
p.free()
return
}
// Printf formats according to a format specifier and writes to standard output.
// It returns the number of bytes written and any write error encountered.
func Printf(format string, a ...any) (n int, err error) {
return Fprintf(os.Stdout, format, a...)
}
// Sprintf formats according to a format specifier and returns the resulting string.
func Sprintf(format string, a ...any) string {
p := newPrinter()
p.doPrintf(format, a)
s := string(p.buf)
p.free()
return s
}
// Appendf formats according to a format specifier, appends the result to the byte
// slice, and returns the updated slice.
func Appendf(b []byte, format string, a ...any) []byte {
p := newPrinter()
p.doPrintf(format, a)
b = append(b, p.buf...)
p.free()
return b
}
// These routines do not take a format string
// Fprint formats using the default formats for its operands and writes to w.
// Spaces are added between operands when neither is a string.
// It returns the number of bytes written and any write error encountered.
func Fprint(w io.Writer, a ...any) (n int, err error) {
p := newPrinter()
p.doPrint(a)
n, err = w.Write(p.buf)
p.free()
return
}
// Print formats using the default formats for its operands and writes to standard output.
// Spaces are added between operands when neither is a string.
// It returns the number of bytes written and any write error encountered.
func Print(a ...any) (n int, err error) {
return Fprint(os.Stdout, a...)
}
// Sprint formats using the default formats for its operands and returns the resulting string.
// Spaces are added between operands when neither is a string.
func Sprint(a ...any) string {
p := newPrinter()
p.doPrint(a)
s := string(p.buf)
p.free()
return s
}
// Append formats using the default formats for its operands, appends the result to
// the byte slice, and returns the updated slice.
func Append(b []byte, a ...any) []byte {
p := newPrinter()
p.doPrint(a)
b = append(b, p.buf...)
p.free()
return b
}
// These routines end in 'ln', do not take a format string,
// always add spaces between operands, and add a newline
// after the last operand.
// Fprintln formats using the default formats for its operands and writes to w.
// Spaces are always added between operands and a newline is appended.
// It returns the number of bytes written and any write error encountered.
func Fprintln(w io.Writer, a ...any) (n int, err error) {
p := newPrinter()
p.doPrintln(a)
n, err = w.Write(p.buf)
p.free()
return
}
// Println formats using the default formats for its operands and writes to standard output.
// Spaces are always added between operands and a newline is appended.
// It returns the number of bytes written and any write error encountered.
func Println(a ...any) (n int, err error) {
return Fprintln(os.Stdout, a...)
}
// Sprintln formats using the default formats for its operands and returns the resulting string.
// Spaces are always added between operands and a newline is appended.
func Sprintln(a ...any) string {
p := newPrinter()
p.doPrintln(a)
s := string(p.buf)
p.free()
return s
}
// Appendln formats using the default formats for its operands, appends the result
// to the byte slice, and returns the updated slice. Spaces are always added
// between operands and a newline is appended.
func Appendln(b []byte, a ...any) []byte {
p := newPrinter()
p.doPrintln(a)
b = append(b, p.buf...)
p.free()
return b
}
// getField gets the i'th field of the struct value.
// If the field itself is an interface, return a value for
// the thing inside the interface, not the interface itself.
func getField(v reflect.Value, i int) reflect.Value {
val := v.Field(i)
if val.Kind() == reflect.Interface && !val.IsNil() {
val = val.Elem()
}
return val
}
// tooLarge reports whether the magnitude of the integer is
// too large to be used as a formatting width or precision.
func tooLarge(x int) bool {
const max int = 1e6
return x > max || x < -max
}
// parsenum converts ASCII to integer. num is 0 (and isnum is false) if no number present.
func parsenum(s string, start, end int) (num int, isnum bool, newi int) {
if start >= end {
return 0, false, end
}
for newi = start; newi < end && '0' <= s[newi] && s[newi] <= '9'; newi++ {
if tooLarge(num) {
return 0, false, end // Overflow; crazy long number most likely.
}
num = num*10 + int(s[newi]-'0')
isnum = true
}
return
}
func (p *pp) unknownType(v reflect.Value) {
if !v.IsValid() {
p.buf.writeString(nilAngleString)
return
}
p.buf.writeByte('?')
p.buf.writeString(v.Type().String())
p.buf.writeByte('?')
}
func (p *pp) badVerb(verb rune) {
p.erroring = true
p.buf.writeString(percentBangString)
p.buf.writeRune(verb)
p.buf.writeByte('(')
switch {
case p.arg != nil:
p.buf.writeString(reflect.TypeOf(p.arg).String())
p.buf.writeByte('=')
p.printArg(p.arg, 'v')
case p.value.IsValid():
p.buf.writeString(p.value.Type().String())
p.buf.writeByte('=')
p.printValue(p.value, 'v', 0)
default:
p.buf.writeString(nilAngleString)
}
p.buf.writeByte(')')
p.erroring = false
}
func (p *pp) fmtBool(v bool, verb rune) {
switch verb {
case 't', 'v':
p.fmt.fmtBoolean(v)
default:
p.badVerb(verb)
}
}
// fmt0x64 formats a uint64 in hexadecimal and prefixes it with 0x or
// not, as requested, by temporarily setting the sharp flag.
func (p *pp) fmt0x64(v uint64, leading0x bool) {
sharp := p.fmt.sharp
p.fmt.sharp = leading0x
p.fmt.fmtInteger(v, 16, unsigned, 'v', ldigits)
p.fmt.sharp = sharp
}
// fmtInteger formats a signed or unsigned integer.
func (p *pp) fmtInteger(v uint64, isSigned bool, verb rune) {
switch verb {
case 'v':
if p.fmt.sharpV && !isSigned {
p.fmt0x64(v, true)
} else {
p.fmt.fmtInteger(v, 10, isSigned, verb, ldigits)
}
case 'd':
p.fmt.fmtInteger(v, 10, isSigned, verb, ldigits)
case 'b':
p.fmt.fmtInteger(v, 2, isSigned, verb, ldigits)
case 'o', 'O':
p.fmt.fmtInteger(v, 8, isSigned, verb, ldigits)
case 'x':
p.fmt.fmtInteger(v, 16, isSigned, verb, ldigits)
case 'X':
p.fmt.fmtInteger(v, 16, isSigned, verb, udigits)
case 'c':
p.fmt.fmtC(v)
case 'q':
p.fmt.fmtQc(v)
case 'U':
p.fmt.fmtUnicode(v)
default:
p.badVerb(verb)
}
}
// fmtFloat formats a float. The default precision for each verb
// is specified as the last argument in the call to p.fmt.fmtFloat.
func (p *pp) fmtFloat(v float64, size int, verb rune) {
switch verb {
case 'v':
p.fmt.fmtFloat(v, size, 'g', -1)
case 'b', 'g', 'G', 'x', 'X':
p.fmt.fmtFloat(v, size, verb, -1)
case 'f', 'e', 'E':
p.fmt.fmtFloat(v, size, verb, 6)
case 'F':
p.fmt.fmtFloat(v, size, 'f', 6)
default:
p.badVerb(verb)
}
}
// fmtComplex formats a complex number v with
// r = real(v) and j = imag(v) as (r+ji) using
// fmtFloat for r and j formatting.
func (p *pp) fmtComplex(v complex128, size int, verb rune) {
// Make sure any unsupported verbs are found before the
// calls to fmtFloat to not generate an incorrect error string.
switch verb {
case 'v', 'b', 'g', 'G', 'x', 'X', 'f', 'F', 'e', 'E':
oldPlus := p.fmt.plus
p.buf.writeByte('(')
p.fmtFloat(real(v), size/2, verb)
// Imaginary part always has a sign.
p.fmt.plus = true
p.fmtFloat(imag(v), size/2, verb)
p.buf.writeString("i)")
p.fmt.plus = oldPlus
default:
p.badVerb(verb)
}
}
func (p *pp) fmtString(v string, verb rune) {
switch verb {
case 'v':
if p.fmt.sharpV {
p.fmt.fmtQ(v)
} else {
p.fmt.fmtS(v)
}
case 's':
p.fmt.fmtS(v)
case 'x':
p.fmt.fmtSx(v, ldigits)
case 'X':
p.fmt.fmtSx(v, udigits)
case 'q':
p.fmt.fmtQ(v)
default:
p.badVerb(verb)
}
}
func (p *pp) fmtBytes(v []byte, verb rune, typeString string) {
switch verb {
case 'v', 'd':
if p.fmt.sharpV {
p.buf.writeString(typeString)
if v == nil {
p.buf.writeString(nilParenString)
return
}
p.buf.writeByte('{')
for i, c := range v {
if i > 0 {
p.buf.writeString(commaSpaceString)
}
p.fmt0x64(uint64(c), true)
}
p.buf.writeByte('}')
} else {
p.buf.writeByte('[')
for i, c := range v {
if i > 0 {
p.buf.writeByte(' ')
}
p.fmt.fmtInteger(uint64(c), 10, unsigned, verb, ldigits)
}
p.buf.writeByte(']')
}
case 's':
p.fmt.fmtBs(v)
case 'x':
p.fmt.fmtBx(v, ldigits)
case 'X':
p.fmt.fmtBx(v, udigits)
case 'q':
p.fmt.fmtQ(string(v))
default:
p.printValue(reflect.ValueOf(v), verb, 0)
}
}
func (p *pp) fmtPointer(value reflect.Value, verb rune) {
var u uintptr
switch value.Kind() {
case reflect.Chan, reflect.Func, reflect.Map, reflect.Pointer, reflect.Slice, reflect.UnsafePointer:
u = value.Pointer()
default:
p.badVerb(verb)
return
}
switch verb {
case 'v':
if p.fmt.sharpV {
p.buf.writeByte('(')
p.buf.writeString(value.Type().String())
p.buf.writeString(")(")
if u == 0 {
p.buf.writeString(nilString)
} else {
p.fmt0x64(uint64(u), true)
}
p.buf.writeByte(')')
} else {
if u == 0 {
p.fmt.padString(nilAngleString)
} else {
p.fmt0x64(uint64(u), !p.fmt.sharp)
}
}
case 'p':
p.fmt0x64(uint64(u), !p.fmt.sharp)
case 'b', 'o', 'd', 'x', 'X':
p.fmtInteger(uint64(u), unsigned, verb)
default:
p.badVerb(verb)
}
}
func (p *pp) catchPanic(arg any, verb rune, method string) {
if err := recover(); err != nil {
// If it's a nil pointer, just say "<nil>". The likeliest causes are a
// Stringer that fails to guard against nil or a nil pointer for a
// value receiver, and in either case, "<nil>" is a nice result.
if v := reflect.ValueOf(arg); v.Kind() == reflect.Pointer && v.IsNil() {
p.buf.writeString(nilAngleString)
return
}
// Otherwise print a concise panic message. Most of the time the panic
// value will print itself nicely.
if p.panicking {
// Nested panics; the recursion in printArg cannot succeed.
panic(err)
}
oldFlags := p.fmt.fmtFlags
// For this output we want default behavior.
p.fmt.clearflags()
p.buf.writeString(percentBangString)
p.buf.writeRune(verb)
p.buf.writeString(panicString)
p.buf.writeString(method)
p.buf.writeString(" method: ")
p.panicking = true
p.printArg(err, 'v')
p.panicking = false
p.buf.writeByte(')')
p.fmt.fmtFlags = oldFlags
}
}
func (p *pp) handleMethods(verb rune) (handled bool) {
if p.erroring {
return
}
if verb == 'w' {
// Using %w outside of Errorf, or with a non-error argument, is invalid.
_, ok := p.arg.(error)
if !ok || !p.wrapErrs {
p.badVerb(verb)
return true
}
// If the arg is a Formatter, pass 'v' as the verb to it.
verb = 'v'
}
// Is it a Formatter?
if formatter, ok := p.arg.(Formatter); ok {
handled = true
defer p.catchPanic(p.arg, verb, "Format")
formatter.Format(p, verb)
return
}
// If we're doing Go syntax and the argument knows how to supply it, take care of it now.
if p.fmt.sharpV {
if stringer, ok := p.arg.(GoStringer); ok {
handled = true
defer p.catchPanic(p.arg, verb, "GoString")
// Print the result of GoString unadorned.
p.fmt.fmtS(stringer.GoString())
return
}
} else {
// If a string is acceptable according to the format, see if
// the value satisfies one of the string-valued interfaces.
// Println etc. set verb to %v, which is "stringable".
switch verb {
case 'v', 's', 'x', 'X', 'q':
// Is it an error or Stringer?
// The duplication in the bodies is necessary:
// setting handled and deferring catchPanic
// must happen before calling the method.
switch v := p.arg.(type) {
case error:
handled = true
defer p.catchPanic(p.arg, verb, "Error")
p.fmtString(v.Error(), verb)
return
case Stringer:
handled = true
defer p.catchPanic(p.arg, verb, "String")
p.fmtString(v.String(), verb)
return
}
}
}
return false
}
func (p *pp) printArg(arg any, verb rune) {
p.arg = arg
p.value = reflect.Value{}
if arg == nil {
switch verb {
case 'T', 'v':
p.fmt.padString(nilAngleString)
default:
p.badVerb(verb)
}
return
}
// Special processing considerations.
// %T (the value's type) and %p (its address) are special; we always do them first.
switch verb {
case 'T':
p.fmt.fmtS(reflect.TypeOf(arg).String())
return
case 'p':
p.fmtPointer(reflect.ValueOf(arg), 'p')
return
}
// Some types can be done without reflection.
switch f := arg.(type) {
case bool:
p.fmtBool(f, verb)
case float32:
p.fmtFloat(float64(f), 32, verb)
case float64:
p.fmtFloat(f, 64, verb)
case complex64:
p.fmtComplex(complex128(f), 64, verb)
case complex128:
p.fmtComplex(f, 128, verb)
case int:
p.fmtInteger(uint64(f), signed, verb)
case int8:
p.fmtInteger(uint64(f), signed, verb)
case int16:
p.fmtInteger(uint64(f), signed, verb)
case int32:
p.fmtInteger(uint64(f), signed, verb)
case int64:
p.fmtInteger(uint64(f), signed, verb)
case uint:
p.fmtInteger(uint64(f), unsigned, verb)
case uint8:
p.fmtInteger(uint64(f), unsigned, verb)
case uint16:
p.fmtInteger(uint64(f), unsigned, verb)
case uint32:
p.fmtInteger(uint64(f), unsigned, verb)
case uint64:
p.fmtInteger(f, unsigned, verb)
case uintptr:
p.fmtInteger(uint64(f), unsigned, verb)
case string:
p.fmtString(f, verb)
case []byte:
p.fmtBytes(f, verb, "[]byte")
case reflect.Value:
// Handle extractable values with special methods
// since printValue does not handle them at depth 0.
if f.IsValid() && f.CanInterface() {
p.arg = f.Interface()
if p.handleMethods(verb) {
return
}
}
p.printValue(f, verb, 0)
default:
// If the type is not simple, it might have methods.
if !p.handleMethods(verb) {
// Need to use reflection, since the type had no
// interface methods that could be used for formatting.
p.printValue(reflect.ValueOf(f), verb, 0)
}
}
}
// printValue is similar to printArg but starts with a reflect value, not an interface{} value.
// It does not handle 'p' and 'T' verbs because these should already have been handled by printArg.
func (p *pp) printValue(value reflect.Value, verb rune, depth int) {
// Handle values with special methods if not already handled by printArg (depth == 0).
if depth > 0 && value.IsValid() && value.CanInterface() {
p.arg = value.Interface()
if p.handleMethods(verb) {
return
}
}
p.arg = nil
p.value = value
switch f := value; value.Kind() {
case reflect.Invalid:
if depth == 0 {
p.buf.writeString(invReflectString)
} else {
switch verb {
case 'v':
p.buf.writeString(nilAngleString)
default:
p.badVerb(verb)
}
}
case reflect.Bool:
p.fmtBool(f.Bool(), verb)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
p.fmtInteger(uint64(f.Int()), signed, verb)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
p.fmtInteger(f.Uint(), unsigned, verb)
case reflect.Float32:
p.fmtFloat(f.Float(), 32, verb)
case reflect.Float64:
p.fmtFloat(f.Float(), 64, verb)
case reflect.Complex64:
p.fmtComplex(f.Complex(), 64, verb)
case reflect.Complex128:
p.fmtComplex(f.Complex(), 128, verb)
case reflect.String:
p.fmtString(f.String(), verb)
case reflect.Map:
if p.fmt.sharpV {
p.buf.writeString(f.Type().String())
if f.IsNil() {
p.buf.writeString(nilParenString)
return
}
p.buf.writeByte('{')
} else {
p.buf.writeString(mapString)
}
sorted := fmtsort.Sort(f)
for i, key := range sorted.Key {
if i > 0 {
if p.fmt.sharpV {
p.buf.writeString(commaSpaceString)
} else {
p.buf.writeByte(' ')
}
}
p.printValue(key, verb, depth+1)
p.buf.writeByte(':')
p.printValue(sorted.Value[i], verb, depth+1)
}
if p.fmt.sharpV {
p.buf.writeByte('}')
} else {
p.buf.writeByte(']')
}
case reflect.Struct:
if p.fmt.sharpV {
p.buf.writeString(f.Type().String())
}
p.buf.writeByte('{')
for i := 0; i < f.NumField(); i++ {
if i > 0 {
if p.fmt.sharpV {
p.buf.writeString(commaSpaceString)
} else {
p.buf.writeByte(' ')
}
}
if p.fmt.plusV || p.fmt.sharpV {
if name := f.Type().Field(i).Name; name != "" {
p.buf.writeString(name)
p.buf.writeByte(':')
}
}
p.printValue(getField(f, i), verb, depth+1)
}
p.buf.writeByte('}')
case reflect.Interface:
value := f.Elem()
if !value.IsValid() {
if p.fmt.sharpV {
p.buf.writeString(f.Type().String())
p.buf.writeString(nilParenString)
} else {
p.buf.writeString(nilAngleString)
}
} else {
p.printValue(value, verb, depth+1)
}
case reflect.Array, reflect.Slice:
switch verb {
case 's', 'q', 'x', 'X':
// Handle byte and uint8 slices and arrays specially for the above verbs.
t := f.Type()
if t.Elem().Kind() == reflect.Uint8 {
var bytes []byte
if f.Kind() == reflect.Slice {
bytes = f.Bytes()
} else if f.CanAddr() {
bytes = f.Slice(0, f.Len()).Bytes()
} else {
// We have an array, but we cannot Slice() a non-addressable array,
// so we build a slice by hand. This is a rare case but it would be nice
// if reflection could help a little more.
bytes = make([]byte, f.Len())
for i := range bytes {
bytes[i] = byte(f.Index(i).Uint())
}
}
p.fmtBytes(bytes, verb, t.String())
return
}
}
if p.fmt.sharpV {
p.buf.writeString(f.Type().String())
if f.Kind() == reflect.Slice && f.IsNil() {
p.buf.writeString(nilParenString)
return
}
p.buf.writeByte('{')
for i := 0; i < f.Len(); i++ {
if i > 0 {
p.buf.writeString(commaSpaceString)
}
p.printValue(f.Index(i), verb, depth+1)
}
p.buf.writeByte('}')
} else {
p.buf.writeByte('[')
for i := 0; i < f.Len(); i++ {
if i > 0 {
p.buf.writeByte(' ')
}
p.printValue(f.Index(i), verb, depth+1)
}
p.buf.writeByte(']')
}
case reflect.Pointer:
// pointer to array or slice or struct? ok at top level
// but not embedded (avoid loops)
if depth == 0 && f.Pointer() != 0 {
switch a := f.Elem(); a.Kind() {
case reflect.Array, reflect.Slice, reflect.Struct, reflect.Map:
p.buf.writeByte('&')
p.printValue(a, verb, depth+1)
return
}
}
fallthrough
case reflect.Chan, reflect.Func, reflect.UnsafePointer:
p.fmtPointer(f, verb)
default:
p.unknownType(f)
}
}
// intFromArg gets the argNumth element of a. On return, isInt reports whether the argument has integer type.
func intFromArg(a []any, argNum int) (num int, isInt bool, newArgNum int) {
newArgNum = argNum
if argNum < len(a) {
num, isInt = a[argNum].(int) // Almost always OK.
if !isInt {
// Work harder.
switch v := reflect.ValueOf(a[argNum]); v.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
n := v.Int()
if int64(int(n)) == n {
num = int(n)
isInt = true
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
n := v.Uint()
if int64(n) >= 0 && uint64(int(n)) == n {
num = int(n)
isInt = true
}
default:
// Already 0, false.
}
}
newArgNum = argNum + 1
if tooLarge(num) {
num = 0
isInt = false
}
}
return
}
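// Sketch: a '*' in the format consumes one operand through intFromArg and
// uses it as the width. Hypothetical helper for illustration.
func exampleStarWidth() string {
	return Sprintf("%*d", 5, 42) // "   42"
}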
// parseArgNumber returns the value of the bracketed number, minus 1
// (explicit argument numbers are one-indexed but we want zero-indexed).
// The opening bracket is known to be present at format[0].
// The returned values are the index, the number of bytes to consume
// up to the closing bracket, if present, and whether the number parsed
// ok. The bytes to consume will be 1 if no closing bracket is present.
func parseArgNumber(format string) (index int, wid int, ok bool) {
// There must be at least 3 bytes: [n].
if len(format) < 3 {
return 0, 1, false
}
// Find closing bracket.
for i := 1; i < len(format); i++ {
if format[i] == ']' {
width, ok, newi := parsenum(format, 1, i)
if !ok || newi != i {
return 0, i + 1, false
}
return width - 1, i + 1, true // arg numbers are one-indexed and skip the bracket.
}
}
return 0, 1, false
}
// argNumber returns the next argument to evaluate, which is either the value of the passed-in
// argNum or the value of the bracketed integer that begins format[i:]. It also returns
// the new value of i, that is, the index of the next byte of the format to process.
func (p *pp) argNumber(argNum int, format string, i int, numArgs int) (newArgNum, newi int, found bool) {
if len(format) <= i || format[i] != '[' {
return argNum, i, false
}
p.reordered = true
index, wid, ok := parseArgNumber(format[i:])
if ok && 0 <= index && index < numArgs {
return index, i + wid, true
}
p.goodArgNum = false
return argNum, i + wid, ok
}
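// Sketch: explicit one-indexed argument numbers, as parsed by parseArgNumber
// and applied by argNumber. Hypothetical helper for illustration.
func exampleArgIndex() string {
	return Sprintf("%[2]d %[1]d", 11, 22) // "22 11"
}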
func (p *pp) badArgNum(verb rune) {
p.buf.writeString(percentBangString)
p.buf.writeRune(verb)
p.buf.writeString(badIndexString)
}
func (p *pp) missingArg(verb rune) {
p.buf.writeString(percentBangString)
p.buf.writeRune(verb)
p.buf.writeString(missingString)
}
func (p *pp) doPrintf(format string, a []any) {
end := len(format)
argNum := 0 // we process one argument per non-trivial format
afterIndex := false // previous item in format was an index like [3].
p.reordered = false
formatLoop:
for i := 0; i < end; {
p.goodArgNum = true
lasti := i
for i < end && format[i] != '%' {
i++
}
if i > lasti {
p.buf.writeString(format[lasti:i])
}
if i >= end {
// done processing format string
break
}
// Process one verb
i++
// Do we have flags?
p.fmt.clearflags()
simpleFormat:
for ; i < end; i++ {
c := format[i]
switch c {
case '#':
p.fmt.sharp = true
case '0':
p.fmt.zero = !p.fmt.minus // Only allow zero padding to the left.
case '+':
p.fmt.plus = true
case '-':
p.fmt.minus = true
p.fmt.zero = false // Do not pad with zeros to the right.
case ' ':
p.fmt.space = true
default:
// Fast path for the common case of ASCII lowercase simple verbs
// without precision or width or argument indices.
if 'a' <= c && c <= 'z' && argNum < len(a) {
switch c {
case 'w':
p.wrappedErrs = append(p.wrappedErrs, argNum)
fallthrough
case 'v':
// Go syntax
p.fmt.sharpV = p.fmt.sharp
p.fmt.sharp = false
// Struct-field syntax
p.fmt.plusV = p.fmt.plus
p.fmt.plus = false
}
p.printArg(a[argNum], rune(c))
argNum++
i++
continue formatLoop
}
// Format is more complex than simple flags and a verb or is malformed.
break simpleFormat
}
}
// Do we have an explicit argument index?
argNum, i, afterIndex = p.argNumber(argNum, format, i, len(a))
// Do we have width?
if i < end && format[i] == '*' {
i++
p.fmt.wid, p.fmt.widPresent, argNum = intFromArg(a, argNum)
if !p.fmt.widPresent {
p.buf.writeString(badWidthString)
}
// We have a negative width, so take its value and ensure
// that the minus flag is set
if p.fmt.wid < 0 {
p.fmt.wid = -p.fmt.wid
p.fmt.minus = true
p.fmt.zero = false // Do not pad with zeros to the right.
}
afterIndex = false
} else {
p.fmt.wid, p.fmt.widPresent, i = parsenum(format, i, end)
if afterIndex && p.fmt.widPresent { // "%[3]2d"
p.goodArgNum = false
}
}
// Do we have precision?
if i+1 < end && format[i] == '.' {
i++
if afterIndex { // "%[3].2d"
p.goodArgNum = false
}
argNum, i, afterIndex = p.argNumber(argNum, format, i, len(a))
if i < end && format[i] == '*' {
i++
p.fmt.prec, p.fmt.precPresent, argNum = intFromArg(a, argNum)
// Negative precision arguments don't make sense
if p.fmt.prec < 0 {
p.fmt.prec = 0
p.fmt.precPresent = false
}
if !p.fmt.precPresent {
p.buf.writeString(badPrecString)
}
afterIndex = false
} else {
p.fmt.prec, p.fmt.precPresent, i = parsenum(format, i, end)
if !p.fmt.precPresent {
p.fmt.prec = 0
p.fmt.precPresent = true
}
}
}
if !afterIndex {
argNum, i, afterIndex = p.argNumber(argNum, format, i, len(a))
}
if i >= end {
p.buf.writeString(noVerbString)
break
}
verb, size := rune(format[i]), 1
if verb >= utf8.RuneSelf {
verb, size = utf8.DecodeRuneInString(format[i:])
}
i += size
switch {
case verb == '%': // Percent does not absorb operands and ignores f.wid and f.prec.
p.buf.writeByte('%')
case !p.goodArgNum:
p.badArgNum(verb)
case argNum >= len(a): // No argument left over to print for the current verb.
p.missingArg(verb)
case verb == 'w':
p.wrappedErrs = append(p.wrappedErrs, argNum)
fallthrough
case verb == 'v':
// Go syntax
p.fmt.sharpV = p.fmt.sharp
p.fmt.sharp = false
// Struct-field syntax
p.fmt.plusV = p.fmt.plus
p.fmt.plus = false
fallthrough
default:
p.printArg(a[argNum], verb)
argNum++
}
}
// Check for extra arguments unless the call accessed the arguments
// out of order, in which case it's too expensive to detect if they've all
// been used and arguably OK if they're not.
if !p.reordered && argNum < len(a) {
p.fmt.clearflags()
p.buf.writeString(extraString)
for i, arg := range a[argNum:] {
if i > 0 {
p.buf.writeString(commaSpaceString)
}
if arg == nil {
p.buf.writeString(nilAngleString)
} else {
p.buf.writeString(reflect.TypeOf(arg).String())
p.buf.writeByte('=')
p.printArg(arg, 'v')
}
}
p.buf.writeByte(')')
}
}
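// Sketch: the leftover-operand diagnostic produced at the end of doPrintf
// when the format was not reordered. Hypothetical helper for illustration.
func exampleExtra() string {
	return Sprintf("%d", 1, 2) // "1%!(EXTRA int=2)"
}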
func (p *pp) doPrint(a []any) {
prevString := false
for argNum, arg := range a {
isString := arg != nil && reflect.TypeOf(arg).Kind() == reflect.String
// Add a space between two non-string arguments.
if argNum > 0 && !isString && !prevString {
p.buf.writeByte(' ')
}
p.printArg(arg, 'v')
prevString = isString
}
}
// doPrintln is like doPrint but always adds a space between arguments
// and a newline after the last argument.
func (p *pp) doPrintln(a []any) {
for argNum, arg := range a {
if argNum > 0 {
p.buf.writeByte(' ')
}
p.printArg(arg, 'v')
}
p.buf.writeByte('\n')
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fmt
import (
"errors"
"io"
"math"
"os"
"reflect"
"strconv"
"sync"
"unicode/utf8"
)
// ScanState represents the scanner state passed to custom scanners.
// Scanners may do rune-at-a-time scanning or ask the ScanState
// to discover the next space-delimited token.
type ScanState interface {
// ReadRune reads the next rune (Unicode code point) from the input.
// If invoked during Scanln, Fscanln, or Sscanln, ReadRune() will
// return EOF after returning the first '\n' or when reading beyond
// the specified width.
ReadRune() (r rune, size int, err error)
// UnreadRune causes the next call to ReadRune to return the same rune.
UnreadRune() error
// SkipSpace skips space in the input. Newlines are treated appropriately
// for the operation being performed; see the package documentation
// for more information.
SkipSpace()
// Token skips space in the input if skipSpace is true, then returns the
// run of Unicode code points c satisfying f(c). If f is nil,
// !unicode.IsSpace(c) is used; that is, the token will hold non-space
// characters. Newlines are treated appropriately for the operation being
// performed; see the package documentation for more information.
// The returned slice points to shared data that may be overwritten
// by the next call to Token, a call to a Scan function using the ScanState
// as input, or when the calling Scan method returns.
Token(skipSpace bool, f func(rune) bool) (token []byte, err error)
// Width returns the value of the width option and whether it has been set.
// The unit is Unicode code points.
Width() (wid int, ok bool)
// Because ReadRune is implemented by the interface, Read should never be
// called by the scanning routines and a valid implementation of
// ScanState may choose always to return an error from Read.
Read(buf []byte) (n int, err error)
}
// Scanner is implemented by any value that has a Scan method, which scans
// the input for the representation of a value and stores the result in the
// receiver, which must be a pointer to be useful. The Scan method is called
// for any argument to Scan, Scanf, or Scanln that implements it.
type Scanner interface {
Scan(state ScanState, verb rune) error
}
// Scan scans text read from standard input, storing successive
// space-separated values into successive arguments. Newlines count
// as space. It returns the number of items successfully scanned.
// If that is less than the number of arguments, err will report why.
func Scan(a ...any) (n int, err error) {
return Fscan(os.Stdin, a...)
}
// Scanln is similar to Scan, but stops scanning at a newline and
// after the final item there must be a newline or EOF.
func Scanln(a ...any) (n int, err error) {
return Fscanln(os.Stdin, a...)
}
// Scanf scans text read from standard input, storing successive
// space-separated values into successive arguments as determined by
// the format. It returns the number of items successfully scanned.
// If that is less than the number of arguments, err will report why.
// Newlines in the input must match newlines in the format.
// The one exception: the verb %c always scans the next rune in the
// input, even if it is a space (or tab etc.) or newline.
func Scanf(format string, a ...any) (n int, err error) {
return Fscanf(os.Stdin, format, a...)
}
type stringReader string
func (r *stringReader) Read(b []byte) (n int, err error) {
n = copy(b, *r)
*r = (*r)[n:]
if n == 0 {
err = io.EOF
}
return
}
// Sscan scans the argument string, storing successive space-separated
// values into successive arguments. Newlines count as space. It
// returns the number of items successfully scanned. If that is less
// than the number of arguments, err will report why.
func Sscan(str string, a ...any) (n int, err error) {
return Fscan((*stringReader)(&str), a...)
}
// Sscanln is similar to Sscan, but stops scanning at a newline and
// after the final item there must be a newline or EOF.
func Sscanln(str string, a ...any) (n int, err error) {
return Fscanln((*stringReader)(&str), a...)
}
// Sscanf scans the argument string, storing successive space-separated
// values into successive arguments as determined by the format. It
// returns the number of items successfully parsed.
// Newlines in the input must match newlines in the format.
func Sscanf(str string, format string, a ...any) (n int, err error) {
return Fscanf((*stringReader)(&str), format, a...)
}
// Fscan scans text read from r, storing successive space-separated
// values into successive arguments. Newlines count as space. It
// returns the number of items successfully scanned. If that is less
// than the number of arguments, err will report why.
func Fscan(r io.Reader, a ...any) (n int, err error) {
s, old := newScanState(r, true, false)
n, err = s.doScan(a)
s.free(old)
return
}
// Fscanln is similar to Fscan, but stops scanning at a newline and
// after the final item there must be a newline or EOF.
func Fscanln(r io.Reader, a ...any) (n int, err error) {
s, old := newScanState(r, false, true)
n, err = s.doScan(a)
s.free(old)
return
}
// Fscanf scans text read from r, storing successive space-separated
// values into successive arguments as determined by the format. It
// returns the number of items successfully parsed.
// Newlines in the input must match newlines in the format.
func Fscanf(r io.Reader, format string, a ...any) (n int, err error) {
s, old := newScanState(r, false, false)
n, err = s.doScanf(format, a)
s.free(old)
return
}
// scanError represents an error generated by the scanning software.
// It's used as a unique signature to identify such errors when recovering.
type scanError struct {
err error
}
const eof = -1
// ss is the internal implementation of ScanState.
type ss struct {
rs io.RuneScanner // where to read input
buf buffer // token accumulator
count int // runes consumed so far.
atEOF bool // already read EOF
ssave
}
// ssave holds the parts of ss that need to be
// saved and restored on recursive scans.
type ssave struct {
validSave bool // is or was a part of an actual ss.
nlIsEnd bool // whether newline terminates scan
nlIsSpace bool // whether newline counts as white space
argLimit int // max value of ss.count for this arg; argLimit <= limit
limit int // max value of ss.count.
maxWid int // width of this arg.
}
// The Read method is only in ScanState so that ScanState
// satisfies io.Reader. It will never be called when used as
// intended, so there is no need to make it actually work.
func (s *ss) Read(buf []byte) (n int, err error) {
return 0, errors.New("ScanState's Read should not be called. Use ReadRune")
}
func (s *ss) ReadRune() (r rune, size int, err error) {
if s.atEOF || s.count >= s.argLimit {
err = io.EOF
return
}
r, size, err = s.rs.ReadRune()
if err == nil {
s.count++
if s.nlIsEnd && r == '\n' {
s.atEOF = true
}
} else if err == io.EOF {
s.atEOF = true
}
return
}
func (s *ss) Width() (wid int, ok bool) {
if s.maxWid == hugeWid {
return 0, false
}
return s.maxWid, true
}
// The public method returns an error; this private one panics.
// If getRune reaches EOF, the return value is EOF (-1).
func (s *ss) getRune() (r rune) {
r, _, err := s.ReadRune()
if err != nil {
if err == io.EOF {
return eof
}
s.error(err)
}
return
}
// mustReadRune turns io.EOF into a panic(io.ErrUnexpectedEOF).
// It is called in cases such as string scanning where an EOF is a
// syntax error.
func (s *ss) mustReadRune() (r rune) {
r = s.getRune()
if r == eof {
s.error(io.ErrUnexpectedEOF)
}
return
}
func (s *ss) UnreadRune() error {
s.rs.UnreadRune()
s.atEOF = false
s.count--
return nil
}
func (s *ss) error(err error) {
panic(scanError{err})
}
func (s *ss) errorString(err string) {
panic(scanError{errors.New(err)})
}
func (s *ss) Token(skipSpace bool, f func(rune) bool) (tok []byte, err error) {
defer func() {
if e := recover(); e != nil {
if se, ok := e.(scanError); ok {
err = se.err
} else {
panic(e)
}
}
}()
if f == nil {
f = notSpace
}
s.buf = s.buf[:0]
tok = s.token(skipSpace, f)
return
}
// space is a copy of the unicode.White_Space ranges,
// to avoid depending on package unicode.
var space = [][2]uint16{
{0x0009, 0x000d},
{0x0020, 0x0020},
{0x0085, 0x0085},
{0x00a0, 0x00a0},
{0x1680, 0x1680},
{0x2000, 0x200a},
{0x2028, 0x2029},
{0x202f, 0x202f},
{0x205f, 0x205f},
{0x3000, 0x3000},
}
func isSpace(r rune) bool {
if r >= 1<<16 {
return false
}
rx := uint16(r)
for _, rng := range space {
if rx < rng[0] {
return false
}
if rx <= rng[1] {
return true
}
}
return false
}
// notSpace is the default scanning function used in Token.
func notSpace(r rune) bool {
return !isSpace(r)
}
// readRune is a structure to enable reading UTF-8 encoded code points
// from an io.Reader. It is used if the Reader given to the scanner does
// not already implement io.RuneScanner.
type readRune struct {
reader io.Reader
buf [utf8.UTFMax]byte // used only inside ReadRune
pending int // number of bytes in pendBuf; only >0 for bad UTF-8
pendBuf [utf8.UTFMax]byte // bytes left over
peekRune rune // if >=0 next rune; when <0 is ^(previous rune)
}
// readByte returns the next byte from the input, which may be
// left over from a previous read if the UTF-8 was ill-formed.
func (r *readRune) readByte() (b byte, err error) {
if r.pending > 0 {
b = r.pendBuf[0]
copy(r.pendBuf[0:], r.pendBuf[1:])
r.pending--
return
}
n, err := io.ReadFull(r.reader, r.pendBuf[:1])
if n != 1 {
return 0, err
}
return r.pendBuf[0], err
}
// ReadRune returns the next UTF-8 encoded code point from the
// io.Reader inside r.
func (r *readRune) ReadRune() (rr rune, size int, err error) {
if r.peekRune >= 0 {
rr = r.peekRune
r.peekRune = ^r.peekRune
size = utf8.RuneLen(rr)
return
}
r.buf[0], err = r.readByte()
if err != nil {
return
}
if r.buf[0] < utf8.RuneSelf { // fast check for common ASCII case
rr = rune(r.buf[0])
size = 1 // Known to be 1.
// Flip the bits of the rune so it's available to UnreadRune.
r.peekRune = ^rr
return
}
var n int
for n = 1; !utf8.FullRune(r.buf[:n]); n++ {
r.buf[n], err = r.readByte()
if err != nil {
if err == io.EOF {
err = nil
break
}
return
}
}
rr, size = utf8.DecodeRune(r.buf[:n])
if size < n { // an error, save the bytes for the next read
copy(r.pendBuf[r.pending:], r.buf[size:n])
r.pending += n - size
}
// Flip the bits of the rune so it's available to UnreadRune.
r.peekRune = ^rr
return
}
func (r *readRune) UnreadRune() error {
if r.peekRune >= 0 {
return errors.New("fmt: scanning called UnreadRune with no rune available")
}
// Reverse bit flip of previously read rune to obtain valid >=0 state.
r.peekRune = ^r.peekRune
return nil
}
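// Sketch (hypothetical helper): peekRune does double duty. A non-negative
// value is a rune waiting to be re-delivered; a negative value is the bitwise
// complement of the last rune read, which UnreadRune flips back.
func examplePeekRoundTrip() rune {
	sr := stringReader("ab")
	r := &readRune{reader: &sr, peekRune: -1}
	c, _, _ := r.ReadRune() // 'a'; peekRune is now ^'a'
	r.UnreadRune()          // peekRune is 'a' again
	c, _, _ = r.ReadRune()  // 'a' once more, served from peekRune
	return c
}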
var ssFree = sync.Pool{
New: func() any { return new(ss) },
}
// newScanState allocates a new ss struct or grabs a cached one.
func newScanState(r io.Reader, nlIsSpace, nlIsEnd bool) (s *ss, old ssave) {
s = ssFree.Get().(*ss)
if rs, ok := r.(io.RuneScanner); ok {
s.rs = rs
} else {
s.rs = &readRune{reader: r, peekRune: -1}
}
s.nlIsSpace = nlIsSpace
s.nlIsEnd = nlIsEnd
s.atEOF = false
s.limit = hugeWid
s.argLimit = hugeWid
s.maxWid = hugeWid
s.validSave = true
s.count = 0
return
}
// free saves used ss structs in ssFree; avoids an allocation per invocation.
func (s *ss) free(old ssave) {
// If it was used recursively, just restore the old state.
if old.validSave {
s.ssave = old
return
}
// Don't hold on to ss structs with large buffers.
if cap(s.buf) > 1024 {
return
}
s.buf = s.buf[:0]
s.rs = nil
ssFree.Put(s)
}
// SkipSpace provides Scan methods the ability to skip space and newline
// characters in keeping with the current scanning mode set by format strings
// and Scan/Scanln.
func (s *ss) SkipSpace() {
for {
r := s.getRune()
if r == eof {
return
}
if r == '\r' && s.peek("\n") {
continue
}
if r == '\n' {
if s.nlIsSpace {
continue
}
s.errorString("unexpected newline")
return
}
if !isSpace(r) {
s.UnreadRune()
break
}
}
}
// token returns the next space-delimited string from the input. It
// skips white space. For Scanln, it stops at newlines. For Scan,
// newlines are treated as spaces.
func (s *ss) token(skipSpace bool, f func(rune) bool) []byte {
if skipSpace {
s.SkipSpace()
}
// read until white space or newline
for {
r := s.getRune()
if r == eof {
break
}
if !f(r) {
s.UnreadRune()
break
}
s.buf.writeRune(r)
}
return s.buf
}
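// Sketch: token is what splits Fscan input into words; for Scan-style calls
// newlines count as space. Hypothetical helper for illustration.
func exampleScanWords() (string, string, error) {
	var a, b string
	_, err := Sscan("hello\nworld", &a, &b)
	return a, b, err // "hello", "world", nil
}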
var errComplex = errors.New("syntax error scanning complex number")
var errBool = errors.New("syntax error scanning boolean")
func indexRune(s string, r rune) int {
for i, c := range s {
if c == r {
return i
}
}
return -1
}
// consume reads the next rune in the input and reports whether it is in the ok string.
// If accept is true, it puts the character into the input token.
func (s *ss) consume(ok string, accept bool) bool {
r := s.getRune()
if r == eof {
return false
}
if indexRune(ok, r) >= 0 {
if accept {
s.buf.writeRune(r)
}
return true
}
if r != eof && accept {
s.UnreadRune()
}
return false
}
// peek reports whether the next character is in the ok string, without consuming it.
func (s *ss) peek(ok string) bool {
r := s.getRune()
if r != eof {
s.UnreadRune()
}
return indexRune(ok, r) >= 0
}
func (s *ss) notEOF() {
// Guarantee there is data to be read.
if r := s.getRune(); r == eof {
panic(io.EOF)
}
s.UnreadRune()
}
// accept checks the next rune in the input. If it's a byte (sic) in the string, it puts it in the
// buffer and returns true. Otherwise it returns false.
func (s *ss) accept(ok string) bool {
return s.consume(ok, true)
}
// okVerb verifies that the verb is present in the list, setting s.err appropriately if not.
func (s *ss) okVerb(verb rune, okVerbs, typ string) bool {
for _, v := range okVerbs {
if v == verb {
return true
}
}
s.errorString("bad verb '%" + string(verb) + "' for " + typ)
return false
}
// scanBool returns the value of the boolean represented by the next token.
func (s *ss) scanBool(verb rune) bool {
s.SkipSpace()
s.notEOF()
if !s.okVerb(verb, "tv", "boolean") {
return false
}
// Syntax-checking a boolean is annoying. We're not fastidious about case.
switch s.getRune() {
case '0':
return false
case '1':
return true
case 't', 'T':
if s.accept("rR") && (!s.accept("uU") || !s.accept("eE")) {
s.error(errBool)
}
return true
case 'f', 'F':
if s.accept("aA") && (!s.accept("lL") || !s.accept("sS") || !s.accept("eE")) {
s.error(errBool)
}
return false
}
return false
}
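// Sketch: scanBool accepts 0, 1, and case-insensitive spellings of the words,
// as the accept chains above allow. Hypothetical helper for illustration.
func exampleScanBool() (bool, error) {
	var v bool
	_, err := Sscan("TRUE", &v)
	return v, err // true, nil
}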
// Numerical elements
const (
binaryDigits = "01"
octalDigits = "01234567"
decimalDigits = "0123456789"
hexadecimalDigits = "0123456789aAbBcCdDeEfF"
sign = "+-"
period = "."
exponent = "eEpP"
)
// getBase returns the numeric base represented by the verb and its digit string.
func (s *ss) getBase(verb rune) (base int, digits string) {
s.okVerb(verb, "bdoUxXv", "integer") // sets s.err
base = 10
digits = decimalDigits
switch verb {
case 'b':
base = 2
digits = binaryDigits
case 'o':
base = 8
digits = octalDigits
case 'x', 'X', 'U':
base = 16
digits = hexadecimalDigits
}
return
}
// scanNumber returns the numerical string with specified digits starting here.
func (s *ss) scanNumber(digits string, haveDigits bool) string {
if !haveDigits {
s.notEOF()
if !s.accept(digits) {
s.errorString("expected integer")
}
}
for s.accept(digits) {
}
return string(s.buf)
}
// scanRune returns the next rune value in the input.
func (s *ss) scanRune(bitSize int) int64 {
s.notEOF()
r := s.getRune()
n := uint(bitSize)
x := (int64(r) << (64 - n)) >> (64 - n)
if x != int64(r) {
s.errorString("overflow on character value " + string(r))
}
return int64(r)
}
// scanBasePrefix reports whether the integer begins with a base prefix
// and returns the base, digit string, and whether a zero was found.
// It is called only if the verb is %v.
func (s *ss) scanBasePrefix() (base int, digits string, zeroFound bool) {
if !s.peek("0") {
return 0, decimalDigits + "_", false
}
s.accept("0")
// Special cases for 0, 0b, 0o, 0x.
switch {
case s.peek("bB"):
s.consume("bB", true)
return 0, binaryDigits + "_", true
case s.peek("oO"):
s.consume("oO", true)
return 0, octalDigits + "_", true
case s.peek("xX"):
s.consume("xX", true)
return 0, hexadecimalDigits + "_", true
default:
return 0, octalDigits + "_", true
}
}
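// Sketch: with %v, scanBasePrefix keeps the prefix in the token and reports
// base 0, so strconv.ParseInt infers the base itself. Hypothetical helper.
func exampleScanPrefixed() (int, int, error) {
	var x, y int
	_, err := Sscanf("0x2a 0b101", "%v %v", &x, &y)
	return x, y, err // 42, 5, nil
}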
// scanInt returns the value of the integer represented by the next
// token, checking for overflow. Any error is stored in s.err.
func (s *ss) scanInt(verb rune, bitSize int) int64 {
if verb == 'c' {
return s.scanRune(bitSize)
}
s.SkipSpace()
s.notEOF()
base, digits := s.getBase(verb)
haveDigits := false
if verb == 'U' {
if !s.consume("U", false) || !s.consume("+", false) {
s.errorString("bad unicode format ")
}
} else {
s.accept(sign) // If there's a sign, it will be left in the token buffer.
if verb == 'v' {
base, digits, haveDigits = s.scanBasePrefix()
}
}
tok := s.scanNumber(digits, haveDigits)
i, err := strconv.ParseInt(tok, base, 64)
if err != nil {
s.error(err)
}
n := uint(bitSize)
x := (i << (64 - n)) >> (64 - n)
if x != i {
s.errorString("integer overflow on token " + tok)
}
return i
}
// scanUint returns the value of the unsigned integer represented
// by the next token, checking for overflow. Any error is stored in s.err.
func (s *ss) scanUint(verb rune, bitSize int) uint64 {
if verb == 'c' {
return uint64(s.scanRune(bitSize))
}
s.SkipSpace()
s.notEOF()
base, digits := s.getBase(verb)
haveDigits := false
if verb == 'U' {
if !s.consume("U", false) || !s.consume("+", false) {
s.errorString("bad unicode format ")
}
} else if verb == 'v' {
base, digits, haveDigits = s.scanBasePrefix()
}
tok := s.scanNumber(digits, haveDigits)
i, err := strconv.ParseUint(tok, base, 64)
if err != nil {
s.error(err)
}
n := uint(bitSize)
x := (i << (64 - n)) >> (64 - n)
if x != i {
s.errorString("unsigned integer overflow on token " + tok)
}
return i
}
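// Sketch: the shift round-trip above rejects values that do not fit the
// destination's bit size (assumed separate example file):
//
//	var i8 int8
//	_, err := fmt.Sscanf("200", "%d", &i8)
//	// err != nil: "integer overflow on token 200", since 200 > math.MaxInt8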
// floatToken returns the floating-point number starting here, no longer than
// the specified maximum width, if any. It's not rigorous about syntax because
// it doesn't check that we have at least some digits, but Atof will do that.
func (s *ss) floatToken() string {
s.buf = s.buf[:0]
// NaN?
if s.accept("nN") && s.accept("aA") && s.accept("nN") {
return string(s.buf)
}
// leading sign?
s.accept(sign)
// Inf?
if s.accept("iI") && s.accept("nN") && s.accept("fF") {
return string(s.buf)
}
digits := decimalDigits + "_"
exp := exponent
if s.accept("0") && s.accept("xX") {
digits = hexadecimalDigits + "_"
exp = "pP"
}
// digits?
for s.accept(digits) {
}
// decimal point?
if s.accept(period) {
// fraction?
for s.accept(digits) {
}
}
// exponent?
if s.accept(exp) {
// leading sign?
s.accept(sign)
// digits?
for s.accept(decimalDigits + "_") {
}
}
return string(s.buf)
}
// complexTokens returns the real and imaginary parts of the complex number starting here.
// The number might be parenthesized and has the format (N+Ni) where N is a floating-point
// number and there are no spaces within.
func (s *ss) complexTokens() (real, imag string) {
// TODO: accept N and Ni independently?
parens := s.accept("(")
real = s.floatToken()
s.buf = s.buf[:0]
// Must now have a sign.
if !s.accept("+-") {
s.error(errComplex)
}
// Sign is now in buffer
imagSign := string(s.buf)
imag = s.floatToken()
if !s.accept("i") {
s.error(errComplex)
}
if parens && !s.accept(")") {
s.error(errComplex)
}
return real, imagSign + imag
}
func hasX(s string) bool {
for i := 0; i < len(s); i++ {
if s[i] == 'x' || s[i] == 'X' {
return true
}
}
return false
}
// convertFloat converts the string to a float64 value.
func (s *ss) convertFloat(str string, n int) float64 {
// strconv.ParseFloat will handle "+0x1.fp+2",
// but we have to implement our non-standard
// decimal+binary exponent mix (1.2p4) ourselves.
if p := indexRune(str, 'p'); p >= 0 && !hasX(str) {
// Atof doesn't handle power-of-2 exponents,
// but they're easy to evaluate.
f, err := strconv.ParseFloat(str[:p], n)
if err != nil {
// Put full string into error.
if e, ok := err.(*strconv.NumError); ok {
e.Num = str
}
s.error(err)
}
m, err := strconv.Atoi(str[p+1:])
if err != nil {
// Put full string into error.
if e, ok := err.(*strconv.NumError); ok {
e.Num = str
}
s.error(err)
}
return math.Ldexp(f, m)
}
f, err := strconv.ParseFloat(str, n)
if err != nil {
s.error(err)
}
return f
}
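// Sketch of the decimal-mantissa/binary-exponent form handled above
// (assumed separate example file):
//
//	var f float64
//	fmt.Sscan("1.5p3", &f) // f == 12: strconv parses 1.5, math.Ldexp scales by 2**3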
// scanComplex converts the next token to a complex128 value.
// The n argument is the total size in bits of the complex value (64 or 128);
// each part is converted at n/2 bits so that complex64 parts are parsed with
// float32 precision, avoiding a copy of this code for each complex type.
func (s *ss) scanComplex(verb rune, n int) complex128 {
if !s.okVerb(verb, floatVerbs, "complex") {
return 0
}
s.SkipSpace()
s.notEOF()
sreal, simag := s.complexTokens()
real := s.convertFloat(sreal, n/2)
imag := s.convertFloat(simag, n/2)
return complex(real, imag)
}
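// Sketch: complex input has the form N+Ni, optionally parenthesized, with
// no spaces inside (assumed separate example file):
//
//	var c complex128
//	fmt.Sscan("(2.5-1i)", &c) // c == complex(2.5, -1)
//	_, err := fmt.Sscan("2.5 - 1i", &c)
//	// err != nil: the sign must immediately follow the real part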
// convertString returns the string represented by the next input characters.
// The format of the input is determined by the verb.
func (s *ss) convertString(verb rune) (str string) {
if !s.okVerb(verb, "svqxX", "string") {
return ""
}
s.SkipSpace()
s.notEOF()
switch verb {
case 'q':
str = s.quotedString()
case 'x', 'X':
str = s.hexString()
default:
str = string(s.token(true, notSpace)) // %s and %v just return the next word
}
return
}
// quotedString returns the double- or back-quoted string represented by the next input characters.
func (s *ss) quotedString() string {
s.notEOF()
quote := s.getRune()
switch quote {
case '`':
// Back-quoted: Anything goes until EOF or back quote.
for {
r := s.mustReadRune()
if r == quote {
break
}
s.buf.writeRune(r)
}
return string(s.buf)
case '"':
// Double-quoted: Include the quotes and let strconv.Unquote do the backslash escapes.
s.buf.writeByte('"')
for {
r := s.mustReadRune()
s.buf.writeRune(r)
if r == '\\' {
// In a legal backslash escape, no matter how long, only the character
// immediately after the escape can itself be a backslash or quote.
// Thus we only need to protect the first character after the backslash.
s.buf.writeRune(s.mustReadRune())
} else if r == '"' {
break
}
}
result, err := strconv.Unquote(string(s.buf))
if err != nil {
s.error(err)
}
return result
default:
s.errorString("expected quoted string")
}
return ""
}
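// Sketch: %q accepts either quoting form (assumed separate example file):
//
//	var s string
//	fmt.Sscanf(`"a\tb"`, "%q", &s) // s == "a\tb": strconv.Unquote expands the escape
//	fmt.Sscanf("`raw`", "%q", &s)  // s == "raw": back-quoted, taken literally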
// hexDigit returns the value of the hexadecimal digit.
func hexDigit(d rune) (int, bool) {
digit := int(d)
switch digit {
case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
return digit - '0', true
case 'a', 'b', 'c', 'd', 'e', 'f':
return 10 + digit - 'a', true
case 'A', 'B', 'C', 'D', 'E', 'F':
return 10 + digit - 'A', true
}
return -1, false
}
// hexByte returns the next hex-encoded (two-character) byte from the input.
// It returns ok==false if the next bytes in the input do not encode a hex byte.
// If the first byte is hex and the second is not, processing stops.
func (s *ss) hexByte() (b byte, ok bool) {
rune1 := s.getRune()
if rune1 == eof {
return
}
value1, ok := hexDigit(rune1)
if !ok {
s.UnreadRune()
return
}
value2, ok := hexDigit(s.mustReadRune())
if !ok {
s.errorString("illegal hex digit")
return
}
return byte(value1<<4 | value2), true
}
// hexString returns the space-delimited hexpair-encoded string.
func (s *ss) hexString() string {
s.notEOF()
for {
b, ok := s.hexByte()
if !ok {
break
}
s.buf.writeByte(b)
}
if len(s.buf) == 0 {
s.errorString("no hex data for %x string")
return ""
}
return string(s.buf)
}
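// Sketch: %x applied to a string collects hex-digit pairs until the first
// non-hex rune (assumed separate example file):
//
//	var s string
//	fmt.Sscanf("476f", "%x", &s) // s == "Go": bytes 0x47 'G' and 0x6f 'o'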
const (
floatVerbs = "beEfFgGv"
hugeWid = 1 << 30
intBits = 32 << (^uint(0) >> 63)
uintptrBits = 32 << (^uintptr(0) >> 63)
)
// scanPercent scans a literal percent character.
func (s *ss) scanPercent() {
s.SkipSpace()
s.notEOF()
if !s.accept("%") {
s.errorString("missing literal %")
}
}
// scanOne scans a single value, deriving the scanner from the type of the argument.
func (s *ss) scanOne(verb rune, arg any) {
s.buf = s.buf[:0]
var err error
// If the parameter has its own Scan method, use that.
if v, ok := arg.(Scanner); ok {
err = v.Scan(s, verb)
if err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
s.error(err)
}
return
}
switch v := arg.(type) {
case *bool:
*v = s.scanBool(verb)
case *complex64:
*v = complex64(s.scanComplex(verb, 64))
case *complex128:
*v = s.scanComplex(verb, 128)
case *int:
*v = int(s.scanInt(verb, intBits))
case *int8:
*v = int8(s.scanInt(verb, 8))
case *int16:
*v = int16(s.scanInt(verb, 16))
case *int32:
*v = int32(s.scanInt(verb, 32))
case *int64:
*v = s.scanInt(verb, 64)
case *uint:
*v = uint(s.scanUint(verb, intBits))
case *uint8:
*v = uint8(s.scanUint(verb, 8))
case *uint16:
*v = uint16(s.scanUint(verb, 16))
case *uint32:
*v = uint32(s.scanUint(verb, 32))
case *uint64:
*v = s.scanUint(verb, 64)
case *uintptr:
*v = uintptr(s.scanUint(verb, uintptrBits))
// Floats are tricky because you want to scan in the precision of the result, not
// scan in high precision and convert, in order to preserve the correct error condition.
case *float32:
if s.okVerb(verb, floatVerbs, "float32") {
s.SkipSpace()
s.notEOF()
*v = float32(s.convertFloat(s.floatToken(), 32))
}
case *float64:
if s.okVerb(verb, floatVerbs, "float64") {
s.SkipSpace()
s.notEOF()
*v = s.convertFloat(s.floatToken(), 64)
}
case *string:
*v = s.convertString(verb)
case *[]byte:
// We scan to string and convert so we get a copy of the data.
// If we scanned to bytes, the slice would point at the buffer.
*v = []byte(s.convertString(verb))
default:
val := reflect.ValueOf(v)
ptr := val
if ptr.Kind() != reflect.Pointer {
s.errorString("type not a pointer: " + val.Type().String())
return
}
switch v := ptr.Elem(); v.Kind() {
case reflect.Bool:
v.SetBool(s.scanBool(verb))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
v.SetInt(s.scanInt(verb, v.Type().Bits()))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
v.SetUint(s.scanUint(verb, v.Type().Bits()))
case reflect.String:
v.SetString(s.convertString(verb))
case reflect.Slice:
// For now, can only handle (renamed) []byte.
typ := v.Type()
if typ.Elem().Kind() != reflect.Uint8 {
s.errorString("can't scan type: " + val.Type().String())
}
str := s.convertString(verb)
v.Set(reflect.MakeSlice(typ, len(str), len(str)))
for i := 0; i < len(str); i++ {
v.Index(i).SetUint(uint64(str[i]))
}
case reflect.Float32, reflect.Float64:
s.SkipSpace()
s.notEOF()
v.SetFloat(s.convertFloat(s.floatToken(), v.Type().Bits()))
case reflect.Complex64, reflect.Complex128:
v.SetComplex(s.scanComplex(verb, v.Type().Bits()))
default:
s.errorString("can't scan type: " + val.Type().String())
}
}
}
// errorHandler turns local panics into error returns.
func errorHandler(errp *error) {
if e := recover(); e != nil {
if se, ok := e.(scanError); ok { // catch local error
*errp = se.err
} else if eof, ok := e.(error); ok && eof == io.EOF { // out of input
*errp = eof
} else {
panic(e)
}
}
}
// doScan does the real work for scanning without a format string.
func (s *ss) doScan(a []any) (numProcessed int, err error) {
defer errorHandler(&err)
for _, arg := range a {
s.scanOne('v', arg)
numProcessed++
}
// Check for newline (or EOF) if required (Scanln etc.).
if s.nlIsEnd {
for {
r := s.getRune()
if r == '\n' || r == eof {
break
}
if !isSpace(r) {
s.errorString("expected newline")
break
}
}
}
return
}
// advance determines whether the next characters in the input match
// those of the format. It returns the number of bytes (sic) consumed
// in the format. All runs of space characters in either input or
// format behave as a single space. Newlines are special, though:
// newlines in the format must match those in the input and vice versa.
// This routine also handles the %% case. If the return value is zero,
// either format starts with a % (with no following %) or the input
// is empty. If it is negative, the input did not match the string.
func (s *ss) advance(format string) (i int) {
for i < len(format) {
fmtc, w := utf8.DecodeRuneInString(format[i:])
// Space processing.
// In the rest of this comment "space" means spaces other than newline.
// Newline in the format matches input of zero or more spaces and then newline or end-of-input.
// Spaces in the format before the newline are collapsed into the newline.
// Spaces in the format after the newline match zero or more spaces after the corresponding input newline.
// Other spaces in the format match input of one or more spaces or end-of-input.
if isSpace(fmtc) {
newlines := 0
trailingSpace := false
for isSpace(fmtc) && i < len(format) {
if fmtc == '\n' {
newlines++
trailingSpace = false
} else {
trailingSpace = true
}
i += w
fmtc, w = utf8.DecodeRuneInString(format[i:])
}
for j := 0; j < newlines; j++ {
inputc := s.getRune()
for isSpace(inputc) && inputc != '\n' {
inputc = s.getRune()
}
if inputc != '\n' && inputc != eof {
s.errorString("newline in format does not match input")
}
}
if trailingSpace {
inputc := s.getRune()
if newlines == 0 {
// If the trailing space stood alone (did not follow a newline),
// it must find at least one space to consume.
if !isSpace(inputc) && inputc != eof {
s.errorString("expected space in input to match format")
}
if inputc == '\n' {
s.errorString("newline in input does not match format")
}
}
for isSpace(inputc) && inputc != '\n' {
inputc = s.getRune()
}
if inputc != eof {
s.UnreadRune()
}
}
continue
}
// Verbs.
if fmtc == '%' {
// % at end of string is an error.
if i+w == len(format) {
s.errorString("missing verb: % at end of format string")
}
// %% acts like a real percent
nextc, _ := utf8.DecodeRuneInString(format[i+w:]) // will not match % if string is empty
if nextc != '%' {
return
}
i += w // skip the first %
}
// Literals.
inputc := s.mustReadRune()
if fmtc != inputc {
s.UnreadRune()
return -1
}
i += w
}
return
}
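// Sketch of the matching rules above (assumed separate example file):
//
//	var a, b int
//	fmt.Sscanf("1 \n 2", "%d\n%d", &a, &b) // ok: spaces around the input newline collapse
//	_, err := fmt.Sscanf("1 2", "%d\n%d", &a, &b)
//	// err != nil: a newline in the format must match a newline in the input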
// doScanf does the real work when scanning with a format string.
// At the moment, it handles only pointers to basic types.
func (s *ss) doScanf(format string, a []any) (numProcessed int, err error) {
defer errorHandler(&err)
end := len(format) - 1
// We process one item per non-trivial format
for i := 0; i <= end; {
w := s.advance(format[i:])
if w > 0 {
i += w
continue
}
// Either we failed to advance, we have a percent character, or we ran out of input.
if format[i] != '%' {
// Can't advance format. Why not?
if w < 0 {
s.errorString("input does not match format")
}
// Otherwise at EOF; "too many operands" error handled below
break
}
i++ // % is one byte
// Do we have a width (e.g. the 20 in "%20s")?
var widPresent bool
s.maxWid, widPresent, i = parsenum(format, i, end)
if !widPresent {
s.maxWid = hugeWid
}
c, w := utf8.DecodeRuneInString(format[i:])
i += w
if c != 'c' {
s.SkipSpace()
}
if c == '%' {
s.scanPercent()
continue // Do not consume an argument.
}
s.argLimit = s.limit
if f := s.count + s.maxWid; f < s.argLimit {
s.argLimit = f
}
if numProcessed >= len(a) { // out of operands
s.errorString("too few operands for format '%" + format[i-w:] + "'")
break
}
arg := a[numProcessed]
s.scanOne(c, arg)
numProcessed++
s.argLimit = s.limit
}
if numProcessed < len(a) {
s.errorString("too many operands")
}
return
}
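// Sketch: a width caps how much input a single verb may consume, via the
// argLimit computation above (assumed separate example file):
//
//	var a, b int
//	fmt.Sscanf("123456", "%2d%4d", &a, &b) // a == 12, b == 3456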
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package ast declares the types used to represent syntax trees for Go
// packages.
package ast
import (
"go/token"
"strings"
)
// ----------------------------------------------------------------------------
// Interfaces
//
// There are 3 main classes of nodes: expression and type nodes,
// statement nodes, and declaration nodes. The node names usually
// match the names of the corresponding Go spec productions.
// The node fields correspond to the individual parts
// of the respective productions.
//
// All nodes contain position information marking the beginning of
// the corresponding source text segment; it is accessible via the
// Pos accessor method. Nodes may contain additional position info
// for language constructs where comments may be found between parts
// of the construct (typically any larger, parenthesized subpart).
// That position information is needed to properly position comments
// when printing the construct.
// All node types implement the Node interface.
type Node interface {
Pos() token.Pos // position of first character belonging to the node
End() token.Pos // position of first character immediately after the node
}
// All expression nodes implement the Expr interface.
type Expr interface {
Node
exprNode()
}
// All statement nodes implement the Stmt interface.
type Stmt interface {
Node
stmtNode()
}
// All declaration nodes implement the Decl interface.
type Decl interface {
Node
declNode()
}
// ----------------------------------------------------------------------------
// Comments
// A Comment node represents a single //-style or /*-style comment.
//
// The Text field contains the comment text without carriage returns (\r) that
// may have been present in the source. Because a comment's end position is
// computed using len(Text), the position reported by End() does not match the
// true source end position for comments containing carriage returns.
type Comment struct {
Slash token.Pos // position of "/" starting the comment
Text string // comment text (excluding '\n' for //-style comments)
}
func (c *Comment) Pos() token.Pos { return c.Slash }
func (c *Comment) End() token.Pos { return token.Pos(int(c.Slash) + len(c.Text)) }
// A CommentGroup represents a sequence of comments
// with no other tokens and no empty lines between.
type CommentGroup struct {
List []*Comment // len(List) > 0
}
func (g *CommentGroup) Pos() token.Pos { return g.List[0].Pos() }
func (g *CommentGroup) End() token.Pos { return g.List[len(g.List)-1].End() }
func isWhitespace(ch byte) bool { return ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' }
func stripTrailingWhitespace(s string) string {
i := len(s)
for i > 0 && isWhitespace(s[i-1]) {
i--
}
return s[0:i]
}
// Text returns the text of the comment.
// Comment markers (//, /*, and */), the first space of a line comment, and
// leading and trailing empty lines are removed.
// Comment directives like "//line" and "//go:noinline" are also removed.
// Multiple empty lines are reduced to one, and trailing space on lines is trimmed.
// Unless the result is empty, it is newline-terminated.
func (g *CommentGroup) Text() string {
if g == nil {
return ""
}
comments := make([]string, len(g.List))
for i, c := range g.List {
comments[i] = c.Text
}
lines := make([]string, 0, 10) // most comments are less than 10 lines
for _, c := range comments {
// Remove comment markers.
// The parser has given us exactly the comment text.
switch c[1] {
case '/':
//-style comment (no newline at the end)
c = c[2:]
if len(c) == 0 {
// empty line
break
}
if c[0] == ' ' {
// strip first space - required for Example tests
c = c[1:]
break
}
if isDirective(c) {
// Ignore //go:noinline, //line, and so on.
continue
}
case '*':
/*-style comment */
c = c[2 : len(c)-2]
}
// Split on newlines.
cl := strings.Split(c, "\n")
// Walk lines, stripping trailing white space and adding to list.
for _, l := range cl {
lines = append(lines, stripTrailingWhitespace(l))
}
}
// Remove leading blank lines; convert runs of
// interior blank lines to a single blank line.
n := 0
for _, line := range lines {
if line != "" || n > 0 && lines[n-1] != "" {
lines[n] = line
n++
}
}
lines = lines[0:n]
// Add final "" entry to get trailing newline from Join.
if n > 0 && lines[n-1] != "" {
lines = append(lines, "")
}
return strings.Join(lines, "\n")
}
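// Sketch of the cleanup performed by Text (assumed to live in a separate
// example file that imports fmt and go/ast):
//
//	g := &ast.CommentGroup{List: []*ast.Comment{
//		{Text: "// Reports the sum."},
//		{Text: "//go:noinline"},
//	}}
//	fmt.Print(g.Text()) // "Reports the sum.\n": the directive is dropped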
// isDirective reports whether c is a comment directive.
// This code is also in go/printer.
func isDirective(c string) bool {
// "//line " is a line directive.
// "//extern " is for gccgo.
// "//export " is for cgo.
// (The // has been removed.)
if strings.HasPrefix(c, "line ") || strings.HasPrefix(c, "extern ") || strings.HasPrefix(c, "export ") {
return true
}
// "//[a-z0-9]+:[a-z0-9]"
// (The // has been removed.)
colon := strings.Index(c, ":")
if colon <= 0 || colon+1 >= len(c) {
return false
}
for i := 0; i <= colon+1; i++ {
if i == colon {
continue
}
b := c[i]
if !('a' <= b && b <= 'z' || '0' <= b && b <= '9') {
return false
}
}
return true
}
// ----------------------------------------------------------------------------
// Expressions and types
// A Field represents a Field declaration list in a struct type,
// a method list in an interface type, or a parameter/result declaration
// in a signature.
// Field.Names is nil for unnamed parameters (parameter lists which only contain types)
// and embedded struct fields. In the latter case, the field name is the type name.
type Field struct {
Doc *CommentGroup // associated documentation; or nil
Names []*Ident // field/method/(type) parameter names; or nil
Type Expr // field/method/parameter type; or nil
Tag *BasicLit // field tag; or nil
Comment *CommentGroup // line comments; or nil
}
func (f *Field) Pos() token.Pos {
if len(f.Names) > 0 {
return f.Names[0].Pos()
}
if f.Type != nil {
return f.Type.Pos()
}
return token.NoPos
}
func (f *Field) End() token.Pos {
if f.Tag != nil {
return f.Tag.End()
}
if f.Type != nil {
return f.Type.End()
}
if len(f.Names) > 0 {
return f.Names[len(f.Names)-1].End()
}
return token.NoPos
}
// A FieldList represents a list of Fields, enclosed by parentheses,
// curly braces, or square brackets.
type FieldList struct {
Opening token.Pos // position of opening parenthesis/brace/bracket, if any
List []*Field // field list; or nil
Closing token.Pos // position of closing parenthesis/brace/bracket, if any
}
func (f *FieldList) Pos() token.Pos {
if f.Opening.IsValid() {
return f.Opening
}
// the list should not be empty in this case;
// be conservative and guard against bad ASTs
if len(f.List) > 0 {
return f.List[0].Pos()
}
return token.NoPos
}
func (f *FieldList) End() token.Pos {
if f.Closing.IsValid() {
return f.Closing + 1
}
// the list should not be empty in this case;
// be conservative and guard against bad ASTs
if n := len(f.List); n > 0 {
return f.List[n-1].End()
}
return token.NoPos
}
// NumFields returns the number of parameters or struct fields represented by a FieldList.
func (f *FieldList) NumFields() int {
n := 0
if f != nil {
for _, g := range f.List {
m := len(g.Names)
if m == 0 {
m = 1
}
n += m
}
}
return n
}
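// Sketch: grouped names count individually, and an unnamed parameter counts
// as one (assumed separate example file):
//
//	params := &ast.FieldList{List: []*ast.Field{
//		{Names: []*ast.Ident{ast.NewIdent("a"), ast.NewIdent("b")}, Type: ast.NewIdent("int")},
//		{Type: ast.NewIdent("string")}, // unnamed parameter
//	}}
//	_ = params.NumFields() // 3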
// An expression is represented by a tree consisting of one
// or more of the following concrete expression nodes.
type (
// A BadExpr node is a placeholder for an expression containing
// syntax errors for which a correct expression node cannot be
// created.
//
BadExpr struct {
From, To token.Pos // position range of bad expression
}
// An Ident node represents an identifier.
Ident struct {
NamePos token.Pos // identifier position
Name string // identifier name
Obj *Object // denoted object; or nil
}
// An Ellipsis node stands for the "..." type in a
// parameter list or the "..." length in an array type.
//
Ellipsis struct {
Ellipsis token.Pos // position of "..."
Elt Expr // ellipsis element type (parameter lists only); or nil
}
// A BasicLit node represents a literal of basic type.
BasicLit struct {
ValuePos token.Pos // literal position
Kind token.Token // token.INT, token.FLOAT, token.IMAG, token.CHAR, or token.STRING
Value string // literal string; e.g. 42, 0x7f, 3.14, 1e-9, 2.4i, 'a', '\x7f', "foo" or `\m\n\o`
}
// A FuncLit node represents a function literal.
FuncLit struct {
Type *FuncType // function type
Body *BlockStmt // function body
}
// A CompositeLit node represents a composite literal.
CompositeLit struct {
Type Expr // literal type; or nil
Lbrace token.Pos // position of "{"
Elts []Expr // list of composite elements; or nil
Rbrace token.Pos // position of "}"
Incomplete bool // true if (source) expressions are missing in the Elts list
}
// A ParenExpr node represents a parenthesized expression.
ParenExpr struct {
Lparen token.Pos // position of "("
X Expr // parenthesized expression
Rparen token.Pos // position of ")"
}
// A SelectorExpr node represents an expression followed by a selector.
SelectorExpr struct {
X Expr // expression
Sel *Ident // field selector
}
// An IndexExpr node represents an expression followed by an index.
IndexExpr struct {
X Expr // expression
Lbrack token.Pos // position of "["
Index Expr // index expression
Rbrack token.Pos // position of "]"
}
// An IndexListExpr node represents an expression followed by multiple
// indices.
IndexListExpr struct {
X Expr // expression
Lbrack token.Pos // position of "["
Indices []Expr // index expressions
Rbrack token.Pos // position of "]"
}
// A SliceExpr node represents an expression followed by slice indices.
SliceExpr struct {
X Expr // expression
Lbrack token.Pos // position of "["
Low Expr // begin of slice range; or nil
High Expr // end of slice range; or nil
Max Expr // maximum capacity of slice; or nil
Slice3 bool // true if 3-index slice (2 colons present)
Rbrack token.Pos // position of "]"
}
// A TypeAssertExpr node represents an expression followed by a
// type assertion.
//
TypeAssertExpr struct {
X Expr // expression
Lparen token.Pos // position of "("
Type Expr // asserted type; nil means type switch X.(type)
Rparen token.Pos // position of ")"
}
// A CallExpr node represents an expression followed by an argument list.
CallExpr struct {
Fun Expr // function expression
Lparen token.Pos // position of "("
Args []Expr // function arguments; or nil
Ellipsis token.Pos // position of "..." (token.NoPos if there is no "...")
Rparen token.Pos // position of ")"
}
// A StarExpr node represents an expression of the form "*" Expression.
// Semantically it could be a unary "*" expression, or a pointer type.
//
StarExpr struct {
Star token.Pos // position of "*"
X Expr // operand
}
// A UnaryExpr node represents a unary expression.
// Unary "*" expressions are represented via StarExpr nodes.
//
UnaryExpr struct {
OpPos token.Pos // position of Op
Op token.Token // operator
X Expr // operand
}
// A BinaryExpr node represents a binary expression.
BinaryExpr struct {
X Expr // left operand
OpPos token.Pos // position of Op
Op token.Token // operator
Y Expr // right operand
}
// A KeyValueExpr node represents (key : value) pairs
// in composite literals.
//
KeyValueExpr struct {
Key Expr
Colon token.Pos // position of ":"
Value Expr
}
)
// The direction of a channel type is indicated by a bit
// mask including one or both of the following constants.
type ChanDir int
const (
SEND ChanDir = 1 << iota
RECV
)
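// Sketch: a bidirectional channel type carries both bits (assumed separate
// example file; positions are left at token.NoPos for brevity):
//
//	bidi := &ast.ChanType{Dir: ast.SEND | ast.RECV, Value: ast.NewIdent("int")} // chan int
//	recv := &ast.ChanType{Dir: ast.RECV, Value: ast.NewIdent("int")}            // <-chan int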
// A type is represented by a tree consisting of one
// or more of the following type-specific expression
// nodes.
type (
// An ArrayType node represents an array or slice type.
ArrayType struct {
Lbrack token.Pos // position of "["
Len Expr // Ellipsis node for [...]T array types, nil for slice types
Elt Expr // element type
}
// A StructType node represents a struct type.
StructType struct {
Struct token.Pos // position of "struct" keyword
Fields *FieldList // list of field declarations
Incomplete bool // true if (source) fields are missing in the Fields list
}
// Pointer types are represented via StarExpr nodes.
// A FuncType node represents a function type.
FuncType struct {
Func token.Pos // position of "func" keyword (token.NoPos if there is no "func")
TypeParams *FieldList // type parameters; or nil
Params *FieldList // (incoming) parameters; non-nil
Results *FieldList // (outgoing) results; or nil
}
// An InterfaceType node represents an interface type.
InterfaceType struct {
Interface token.Pos // position of "interface" keyword
Methods *FieldList // list of embedded interfaces, methods, or types
Incomplete bool // true if (source) methods or types are missing in the Methods list
}
// A MapType node represents a map type.
MapType struct {
Map token.Pos // position of "map" keyword
Key Expr
Value Expr
}
// A ChanType node represents a channel type.
ChanType struct {
Begin token.Pos // position of "chan" keyword or "<-" (whichever comes first)
Arrow token.Pos // position of "<-" (token.NoPos if there is no "<-")
Dir ChanDir // channel direction
Value Expr // value type
}
)
// Pos and End implementations for expression/type nodes.
func (x *BadExpr) Pos() token.Pos { return x.From }
func (x *Ident) Pos() token.Pos { return x.NamePos }
func (x *Ellipsis) Pos() token.Pos { return x.Ellipsis }
func (x *BasicLit) Pos() token.Pos { return x.ValuePos }
func (x *FuncLit) Pos() token.Pos { return x.Type.Pos() }
func (x *CompositeLit) Pos() token.Pos {
if x.Type != nil {
return x.Type.Pos()
}
return x.Lbrace
}
func (x *ParenExpr) Pos() token.Pos { return x.Lparen }
func (x *SelectorExpr) Pos() token.Pos { return x.X.Pos() }
func (x *IndexExpr) Pos() token.Pos { return x.X.Pos() }
func (x *IndexListExpr) Pos() token.Pos { return x.X.Pos() }
func (x *SliceExpr) Pos() token.Pos { return x.X.Pos() }
func (x *TypeAssertExpr) Pos() token.Pos { return x.X.Pos() }
func (x *CallExpr) Pos() token.Pos { return x.Fun.Pos() }
func (x *StarExpr) Pos() token.Pos { return x.Star }
func (x *UnaryExpr) Pos() token.Pos { return x.OpPos }
func (x *BinaryExpr) Pos() token.Pos { return x.X.Pos() }
func (x *KeyValueExpr) Pos() token.Pos { return x.Key.Pos() }
func (x *ArrayType) Pos() token.Pos { return x.Lbrack }
func (x *StructType) Pos() token.Pos { return x.Struct }
func (x *FuncType) Pos() token.Pos {
if x.Func.IsValid() || x.Params == nil { // see issue 3870
return x.Func
}
return x.Params.Pos() // interface method declarations have no "func" keyword
}
func (x *InterfaceType) Pos() token.Pos { return x.Interface }
func (x *MapType) Pos() token.Pos { return x.Map }
func (x *ChanType) Pos() token.Pos { return x.Begin }
func (x *BadExpr) End() token.Pos { return x.To }
func (x *Ident) End() token.Pos { return token.Pos(int(x.NamePos) + len(x.Name)) }
func (x *Ellipsis) End() token.Pos {
if x.Elt != nil {
return x.Elt.End()
}
return x.Ellipsis + 3 // len("...")
}
func (x *BasicLit) End() token.Pos { return token.Pos(int(x.ValuePos) + len(x.Value)) }
func (x *FuncLit) End() token.Pos { return x.Body.End() }
func (x *CompositeLit) End() token.Pos { return x.Rbrace + 1 }
func (x *ParenExpr) End() token.Pos { return x.Rparen + 1 }
func (x *SelectorExpr) End() token.Pos { return x.Sel.End() }
func (x *IndexExpr) End() token.Pos { return x.Rbrack + 1 }
func (x *IndexListExpr) End() token.Pos { return x.Rbrack + 1 }
func (x *SliceExpr) End() token.Pos { return x.Rbrack + 1 }
func (x *TypeAssertExpr) End() token.Pos { return x.Rparen + 1 }
func (x *CallExpr) End() token.Pos { return x.Rparen + 1 }
func (x *StarExpr) End() token.Pos { return x.X.End() }
func (x *UnaryExpr) End() token.Pos { return x.X.End() }
func (x *BinaryExpr) End() token.Pos { return x.Y.End() }
func (x *KeyValueExpr) End() token.Pos { return x.Value.End() }
func (x *ArrayType) End() token.Pos { return x.Elt.End() }
func (x *StructType) End() token.Pos { return x.Fields.End() }
func (x *FuncType) End() token.Pos {
if x.Results != nil {
return x.Results.End()
}
return x.Params.End()
}
func (x *InterfaceType) End() token.Pos { return x.Methods.End() }
func (x *MapType) End() token.Pos { return x.Value.End() }
func (x *ChanType) End() token.Pos { return x.Value.End() }
// exprNode() ensures that only expression/type nodes can be
// assigned to an Expr.
func (*BadExpr) exprNode() {}
func (*Ident) exprNode() {}
func (*Ellipsis) exprNode() {}
func (*BasicLit) exprNode() {}
func (*FuncLit) exprNode() {}
func (*CompositeLit) exprNode() {}
func (*ParenExpr) exprNode() {}
func (*SelectorExpr) exprNode() {}
func (*IndexExpr) exprNode() {}
func (*IndexListExpr) exprNode() {}
func (*SliceExpr) exprNode() {}
func (*TypeAssertExpr) exprNode() {}
func (*CallExpr) exprNode() {}
func (*StarExpr) exprNode() {}
func (*UnaryExpr) exprNode() {}
func (*BinaryExpr) exprNode() {}
func (*KeyValueExpr) exprNode() {}
func (*ArrayType) exprNode() {}
func (*StructType) exprNode() {}
func (*FuncType) exprNode() {}
func (*InterfaceType) exprNode() {}
func (*MapType) exprNode() {}
func (*ChanType) exprNode() {}
// ----------------------------------------------------------------------------
// Convenience functions for Idents
// NewIdent creates a new Ident without position.
// Useful for ASTs generated by code other than the Go parser.
func NewIdent(name string) *Ident { return &Ident{token.NoPos, name, nil} }
// IsExported reports whether name starts with an upper-case letter.
func IsExported(name string) bool { return token.IsExported(name) }
// IsExported reports whether id starts with an upper-case letter.
func (id *Ident) IsExported() bool { return token.IsExported(id.Name) }
func (id *Ident) String() string {
if id != nil {
return id.Name
}
return "<nil>"
}
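// Sketch (assumed separate example file):
//
//	ast.IsExported("Println")  // true
//	ast.IsExported("println")  // false
//	ast.NewIdent("x").String() // "x"
//	(*ast.Ident)(nil).String() // "<nil>"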
// ----------------------------------------------------------------------------
// Statements
// A statement is represented by a tree consisting of one
// or more of the following concrete statement nodes.
type (
// A BadStmt node is a placeholder for statements containing
// syntax errors for which no correct statement nodes can be
// created.
//
BadStmt struct {
From, To token.Pos // position range of bad statement
}
// A DeclStmt node represents a declaration in a statement list.
DeclStmt struct {
Decl Decl // *GenDecl with CONST, TYPE, or VAR token
}
// An EmptyStmt node represents an empty statement.
// The "position" of the empty statement is the position
// of the immediately following (explicit or implicit) semicolon.
//
EmptyStmt struct {
Semicolon token.Pos // position of following ";"
Implicit bool // if set, ";" was omitted in the source
}
// A LabeledStmt node represents a labeled statement.
LabeledStmt struct {
Label *Ident
Colon token.Pos // position of ":"
Stmt Stmt
}
// An ExprStmt node represents a (stand-alone) expression
// in a statement list.
//
ExprStmt struct {
X Expr // expression
}
// A SendStmt node represents a send statement.
SendStmt struct {
Chan Expr
Arrow token.Pos // position of "<-"
Value Expr
}
// An IncDecStmt node represents an increment or decrement statement.
IncDecStmt struct {
X Expr
TokPos token.Pos // position of Tok
Tok token.Token // INC or DEC
}
// An AssignStmt node represents an assignment or
// a short variable declaration.
//
AssignStmt struct {
Lhs []Expr
TokPos token.Pos // position of Tok
Tok token.Token // assignment token, DEFINE
Rhs []Expr
}
// A GoStmt node represents a go statement.
GoStmt struct {
Go token.Pos // position of "go" keyword
Call *CallExpr
}
// A DeferStmt node represents a defer statement.
DeferStmt struct {
Defer token.Pos // position of "defer" keyword
Call *CallExpr
}
// A ReturnStmt node represents a return statement.
ReturnStmt struct {
Return token.Pos // position of "return" keyword
Results []Expr // result expressions; or nil
}
// A BranchStmt node represents a break, continue, goto,
// or fallthrough statement.
//
BranchStmt struct {
TokPos token.Pos // position of Tok
Tok token.Token // keyword token (BREAK, CONTINUE, GOTO, FALLTHROUGH)
Label *Ident // label name; or nil
}
// A BlockStmt node represents a braced statement list.
BlockStmt struct {
Lbrace token.Pos // position of "{"
List []Stmt
Rbrace token.Pos // position of "}", if any (may be absent due to syntax error)
}
// An IfStmt node represents an if statement.
IfStmt struct {
If token.Pos // position of "if" keyword
Init Stmt // initialization statement; or nil
Cond Expr // condition
Body *BlockStmt
Else Stmt // else branch; or nil
}
// A CaseClause represents a case of an expression or type switch statement.
CaseClause struct {
Case token.Pos // position of "case" or "default" keyword
List []Expr // list of expressions or types; nil means default case
Colon token.Pos // position of ":"
Body []Stmt // statement list; or nil
}
// A SwitchStmt node represents an expression switch statement.
SwitchStmt struct {
Switch token.Pos // position of "switch" keyword
Init Stmt // initialization statement; or nil
Tag Expr // tag expression; or nil
Body *BlockStmt // CaseClauses only
}
// A TypeSwitchStmt node represents a type switch statement.
TypeSwitchStmt struct {
Switch token.Pos // position of "switch" keyword
Init Stmt // initialization statement; or nil
Assign Stmt // x := y.(type) or y.(type)
Body *BlockStmt // CaseClauses only
}
// A CommClause node represents a case of a select statement.
CommClause struct {
Case token.Pos // position of "case" or "default" keyword
Comm Stmt // send or receive statement; nil means default case
Colon token.Pos // position of ":"
Body []Stmt // statement list; or nil
}
// A SelectStmt node represents a select statement.
SelectStmt struct {
Select token.Pos // position of "select" keyword
Body *BlockStmt // CommClauses only
}
// A ForStmt represents a for statement.
ForStmt struct {
For token.Pos // position of "for" keyword
Init Stmt // initialization statement; or nil
Cond Expr // condition; or nil
Post Stmt // post iteration statement; or nil
Body *BlockStmt
}
// A RangeStmt represents a for statement with a range clause.
RangeStmt struct {
For token.Pos // position of "for" keyword
Key, Value Expr // Key, Value may be nil
TokPos token.Pos // position of Tok; invalid if Key == nil
Tok token.Token // ILLEGAL if Key == nil, ASSIGN, DEFINE
Range token.Pos // position of "range" keyword
X Expr // value to range over
Body *BlockStmt
}
)
// Pos and End implementations for statement nodes.
func (s *BadStmt) Pos() token.Pos { return s.From }
func (s *DeclStmt) Pos() token.Pos { return s.Decl.Pos() }
func (s *EmptyStmt) Pos() token.Pos { return s.Semicolon }
func (s *LabeledStmt) Pos() token.Pos { return s.Label.Pos() }
func (s *ExprStmt) Pos() token.Pos { return s.X.Pos() }
func (s *SendStmt) Pos() token.Pos { return s.Chan.Pos() }
func (s *IncDecStmt) Pos() token.Pos { return s.X.Pos() }
func (s *AssignStmt) Pos() token.Pos { return s.Lhs[0].Pos() }
func (s *GoStmt) Pos() token.Pos { return s.Go }
func (s *DeferStmt) Pos() token.Pos { return s.Defer }
func (s *ReturnStmt) Pos() token.Pos { return s.Return }
func (s *BranchStmt) Pos() token.Pos { return s.TokPos }
func (s *BlockStmt) Pos() token.Pos { return s.Lbrace }
func (s *IfStmt) Pos() token.Pos { return s.If }
func (s *CaseClause) Pos() token.Pos { return s.Case }
func (s *SwitchStmt) Pos() token.Pos { return s.Switch }
func (s *TypeSwitchStmt) Pos() token.Pos { return s.Switch }
func (s *CommClause) Pos() token.Pos { return s.Case }
func (s *SelectStmt) Pos() token.Pos { return s.Select }
func (s *ForStmt) Pos() token.Pos { return s.For }
func (s *RangeStmt) Pos() token.Pos { return s.For }
func (s *BadStmt) End() token.Pos { return s.To }
func (s *DeclStmt) End() token.Pos { return s.Decl.End() }
func (s *EmptyStmt) End() token.Pos {
if s.Implicit {
return s.Semicolon
}
return s.Semicolon + 1 /* len(";") */
}
func (s *LabeledStmt) End() token.Pos { return s.Stmt.End() }
func (s *ExprStmt) End() token.Pos { return s.X.End() }
func (s *SendStmt) End() token.Pos { return s.Value.End() }
func (s *IncDecStmt) End() token.Pos {
return s.TokPos + 2 /* len("++") */
}
func (s *AssignStmt) End() token.Pos { return s.Rhs[len(s.Rhs)-1].End() }
func (s *GoStmt) End() token.Pos { return s.Call.End() }
func (s *DeferStmt) End() token.Pos { return s.Call.End() }
func (s *ReturnStmt) End() token.Pos {
if n := len(s.Results); n > 0 {
return s.Results[n-1].End()
}
return s.Return + 6 // len("return")
}
func (s *BranchStmt) End() token.Pos {
if s.Label != nil {
return s.Label.End()
}
return token.Pos(int(s.TokPos) + len(s.Tok.String()))
}
func (s *BlockStmt) End() token.Pos {
if s.Rbrace.IsValid() {
return s.Rbrace + 1
}
if n := len(s.List); n > 0 {
return s.List[n-1].End()
}
return s.Lbrace + 1
}
func (s *IfStmt) End() token.Pos {
if s.Else != nil {
return s.Else.End()
}
return s.Body.End()
}
func (s *CaseClause) End() token.Pos {
if n := len(s.Body); n > 0 {
return s.Body[n-1].End()
}
return s.Colon + 1
}
func (s *SwitchStmt) End() token.Pos { return s.Body.End() }
func (s *TypeSwitchStmt) End() token.Pos { return s.Body.End() }
func (s *CommClause) End() token.Pos {
if n := len(s.Body); n > 0 {
return s.Body[n-1].End()
}
return s.Colon + 1
}
func (s *SelectStmt) End() token.Pos { return s.Body.End() }
func (s *ForStmt) End() token.Pos { return s.Body.End() }
func (s *RangeStmt) End() token.Pos { return s.Body.End() }
// stmtNode() ensures that only statement nodes can be
// assigned to a Stmt.
func (*BadStmt) stmtNode() {}
func (*DeclStmt) stmtNode() {}
func (*EmptyStmt) stmtNode() {}
func (*LabeledStmt) stmtNode() {}
func (*ExprStmt) stmtNode() {}
func (*SendStmt) stmtNode() {}
func (*IncDecStmt) stmtNode() {}
func (*AssignStmt) stmtNode() {}
func (*GoStmt) stmtNode() {}
func (*DeferStmt) stmtNode() {}
func (*ReturnStmt) stmtNode() {}
func (*BranchStmt) stmtNode() {}
func (*BlockStmt) stmtNode() {}
func (*IfStmt) stmtNode() {}
func (*CaseClause) stmtNode() {}
func (*SwitchStmt) stmtNode() {}
func (*TypeSwitchStmt) stmtNode() {}
func (*CommClause) stmtNode() {}
func (*SelectStmt) stmtNode() {}
func (*ForStmt) stmtNode() {}
func (*RangeStmt) stmtNode() {}
// ----------------------------------------------------------------------------
// Declarations
// A Spec node represents a single (non-parenthesized) import,
// constant, type, or variable declaration.
type (
// The Spec type stands for any of *ImportSpec, *ValueSpec, and *TypeSpec.
Spec interface {
Node
specNode()
}
// An ImportSpec node represents a single package import.
ImportSpec struct {
Doc *CommentGroup // associated documentation; or nil
Name *Ident // local package name (including "."); or nil
Path *BasicLit // import path
Comment *CommentGroup // line comments; or nil
EndPos token.Pos // end of spec (overrides Path.Pos if nonzero)
}
// A ValueSpec node represents a constant or variable declaration
// (ConstSpec or VarSpec production).
//
ValueSpec struct {
Doc *CommentGroup // associated documentation; or nil
Names []*Ident // value names (len(Names) > 0)
Type Expr // value type; or nil
Values []Expr // initial values; or nil
Comment *CommentGroup // line comments; or nil
}
// A TypeSpec node represents a type declaration (TypeSpec production).
TypeSpec struct {
Doc *CommentGroup // associated documentation; or nil
Name *Ident // type name
TypeParams *FieldList // type parameters; or nil
Assign token.Pos // position of '=', if any
Type Expr // *Ident, *ParenExpr, *SelectorExpr, *StarExpr, or any of the *XxxTypes
Comment *CommentGroup // line comments; or nil
}
)
// Pos and End implementations for spec nodes.
func (s *ImportSpec) Pos() token.Pos {
if s.Name != nil {
return s.Name.Pos()
}
return s.Path.Pos()
}
func (s *ValueSpec) Pos() token.Pos { return s.Names[0].Pos() }
func (s *TypeSpec) Pos() token.Pos { return s.Name.Pos() }
func (s *ImportSpec) End() token.Pos {
if s.EndPos != 0 {
return s.EndPos
}
return s.Path.End()
}
func (s *ValueSpec) End() token.Pos {
if n := len(s.Values); n > 0 {
return s.Values[n-1].End()
}
if s.Type != nil {
return s.Type.End()
}
return s.Names[len(s.Names)-1].End()
}
func (s *TypeSpec) End() token.Pos { return s.Type.End() }
// specNode() ensures that only spec nodes can be
// assigned to a Spec.
func (*ImportSpec) specNode() {}
func (*ValueSpec) specNode() {}
func (*TypeSpec) specNode() {}
// A declaration is represented by one of the following declaration nodes.
type (
// A BadDecl node is a placeholder for a declaration containing
// syntax errors for which a correct declaration node cannot be
// created.
//
BadDecl struct {
From, To token.Pos // position range of bad declaration
}
// A GenDecl node (generic declaration node) represents an import,
// constant, type or variable declaration. A valid Lparen position
// (Lparen.IsValid()) indicates a parenthesized declaration.
//
// Relationship between Tok value and Specs element type:
//
// token.IMPORT *ImportSpec
// token.CONST *ValueSpec
// token.TYPE *TypeSpec
// token.VAR *ValueSpec
//
GenDecl struct {
Doc *CommentGroup // associated documentation; or nil
TokPos token.Pos // position of Tok
Tok token.Token // IMPORT, CONST, TYPE, or VAR
Lparen token.Pos // position of '(', if any
Specs []Spec
Rparen token.Pos // position of ')', if any
}
// A FuncDecl node represents a function declaration.
FuncDecl struct {
Doc *CommentGroup // associated documentation; or nil
Recv *FieldList // receiver (methods); or nil (functions)
Name *Ident // function/method name
Type *FuncType // function signature: type and value parameters, results, and position of "func" keyword
Body *BlockStmt // function body; or nil for external (non-Go) function
}
)
// Pos and End implementations for declaration nodes.
func (d *BadDecl) Pos() token.Pos { return d.From }
func (d *GenDecl) Pos() token.Pos { return d.TokPos }
func (d *FuncDecl) Pos() token.Pos { return d.Type.Pos() }
func (d *BadDecl) End() token.Pos { return d.To }
func (d *GenDecl) End() token.Pos {
if d.Rparen.IsValid() {
return d.Rparen + 1
}
return d.Specs[0].End()
}
func (d *FuncDecl) End() token.Pos {
if d.Body != nil {
return d.Body.End()
}
return d.Type.End()
}
// declNode() ensures that only declaration nodes can be
// assigned to a Decl.
func (*BadDecl) declNode() {}
func (*GenDecl) declNode() {}
func (*FuncDecl) declNode() {}
// ----------------------------------------------------------------------------
// Files and packages
// A File node represents a Go source file.
//
// The Comments list contains all comments in the source file in order of
// appearance, including the comments that are pointed to from other nodes
// via Doc and Comment fields.
//
// For correct printing of source code containing comments (using packages
// go/format and go/printer), special care must be taken to update comments
// when a File's syntax tree is modified: For printing, comments are interspersed
// between tokens based on their position. If syntax tree nodes are
// removed or moved, relevant comments in their vicinity must also be removed
// (from the File.Comments list) or moved accordingly (by updating their
// positions). A CommentMap may be used to facilitate some of these operations.
//
// Whether and how a comment is associated with a node depends on the
// interpretation of the syntax tree by the manipulating program: Except for Doc
// and Comment comments directly associated with nodes, the remaining comments
// are "free-floating" (see also issues #18593, #20744).
type File struct {
Doc *CommentGroup // associated documentation; or nil
Package token.Pos // position of "package" keyword
Name *Ident // package name
Decls []Decl // top-level declarations; or nil
FileStart, FileEnd token.Pos // start and end of entire file
Scope *Scope // package scope (this file only)
Imports []*ImportSpec // imports in this file
Unresolved []*Ident // unresolved identifiers in this file
Comments []*CommentGroup // list of all comments in the source file
}
// Pos returns the position of the package declaration.
// (Use FileStart for the start of the entire file.)
func (f *File) Pos() token.Pos { return f.Package }
// End returns the end of the last declaration in the file.
// (Use FileEnd for the end of the entire file.)
func (f *File) End() token.Pos {
if n := len(f.Decls); n > 0 {
return f.Decls[n-1].End()
}
return f.Name.End()
}
// A Package node represents a set of source files
// collectively building a Go package.
type Package struct {
Name string // package name
Scope *Scope // package scope across all files
Imports map[string]*Object // map of package id -> package object
Files map[string]*File // Go source files by filename
}
func (p *Package) Pos() token.Pos { return token.NoPos }
func (p *Package) End() token.Pos { return token.NoPos }
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ast
import (
"bytes"
"fmt"
"go/token"
"sort"
"strings"
)
type byPos []*CommentGroup
func (a byPos) Len() int { return len(a) }
func (a byPos) Less(i, j int) bool { return a[i].Pos() < a[j].Pos() }
func (a byPos) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
// sortComments sorts the list of comment groups in source order.
func sortComments(list []*CommentGroup) {
// TODO(gri): Does it make sense to check for sorted-ness
// first (because we know that sorted-ness is
// very likely)?
if orderedList := byPos(list); !sort.IsSorted(orderedList) {
sort.Sort(orderedList)
}
}
// A CommentMap maps an AST node to a list of comment groups
// associated with it. See NewCommentMap for a description of
// the association.
type CommentMap map[Node][]*CommentGroup
func (cmap CommentMap) addComment(n Node, c *CommentGroup) {
list := cmap[n]
if len(list) == 0 {
list = []*CommentGroup{c}
} else {
list = append(list, c)
}
cmap[n] = list
}
type byInterval []Node
func (a byInterval) Len() int { return len(a) }
func (a byInterval) Less(i, j int) bool {
pi, pj := a[i].Pos(), a[j].Pos()
return pi < pj || pi == pj && a[i].End() > a[j].End()
}
func (a byInterval) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
// nodeList returns the list of nodes of the AST n in source order.
func nodeList(n Node) []Node {
var list []Node
Inspect(n, func(n Node) bool {
// don't collect comments
switch n.(type) {
case nil, *CommentGroup, *Comment:
return false
}
list = append(list, n)
return true
})
// Note: The current implementation assumes that Inspect traverses the
// AST in depth-first and thus _source_ order. If AST traversal
// does not follow source order, the sorting call below will be
// required.
// sort.Sort(byInterval(list))
return list
}
// A commentListReader helps iterate through a list of comment groups.
type commentListReader struct {
fset *token.FileSet
list []*CommentGroup
index int
comment *CommentGroup // comment group at current index
pos, end token.Position // source interval of comment group at current index
}
func (r *commentListReader) eol() bool {
return r.index >= len(r.list)
}
func (r *commentListReader) next() {
if !r.eol() {
r.comment = r.list[r.index]
r.pos = r.fset.Position(r.comment.Pos())
r.end = r.fset.Position(r.comment.End())
r.index++
}
}
// A nodeStack keeps track of nested nodes.
// A node lower on the stack lexically contains the nodes higher on the stack.
type nodeStack []Node
// push pops all nodes that appear lexically before n
// and then pushes n on the stack.
func (s *nodeStack) push(n Node) {
s.pop(n.Pos())
*s = append((*s), n)
}
// pop pops all nodes that appear lexically before pos
// (i.e., whose lexical extent has ended before or at pos).
// It returns the last node popped.
func (s *nodeStack) pop(pos token.Pos) (top Node) {
i := len(*s)
for i > 0 && (*s)[i-1].End() <= pos {
top = (*s)[i-1]
i--
}
*s = (*s)[0:i]
return top
}
// NewCommentMap creates a new comment map by associating comment groups
// of the comments list with the nodes of the AST specified by node.
//
// A comment group g is associated with a node n if:
//
// - g starts on the same line as n ends
// - g starts on the line immediately following n, and there is
// at least one empty line after g and before the next node
// - g starts before n and is not associated with the node before n
// via the previous rules
//
// NewCommentMap tries to associate a comment group to the "largest"
// node possible: For instance, if the comment is a line comment
// trailing an assignment, the comment is associated with the entire
// assignment rather than just the last operand in the assignment.
func NewCommentMap(fset *token.FileSet, node Node, comments []*CommentGroup) CommentMap {
if len(comments) == 0 {
return nil // no comments to map
}
cmap := make(CommentMap)
// set up comment reader r
tmp := make([]*CommentGroup, len(comments))
copy(tmp, comments) // don't change incoming comments
sortComments(tmp)
r := commentListReader{fset: fset, list: tmp} // !r.eol() because len(comments) > 0
r.next()
// create node list in lexical order
nodes := nodeList(node)
nodes = append(nodes, nil) // append sentinel
// set up iteration variables
var (
p Node // previous node
pend token.Position // end of p
pg Node // previous node group (enclosing nodes of "importance")
pgend token.Position // end of pg
stack nodeStack // stack of node groups
)
for _, q := range nodes {
var qpos token.Position
if q != nil {
qpos = fset.Position(q.Pos()) // current node position
} else {
// set fake sentinel position to infinity so that
// all comments get processed before the sentinel
const infinity = 1 << 30
qpos.Offset = infinity
qpos.Line = infinity
}
// process comments before current node
for r.end.Offset <= qpos.Offset {
// determine recent node group
if top := stack.pop(r.comment.Pos()); top != nil {
pg = top
pgend = fset.Position(pg.End())
}
// Try to associate a comment first with a node group
// (i.e., a node of "importance" such as a declaration);
// if that fails, try to associate it with the most recent
// node.
// TODO(gri) try to simplify the logic below
var assoc Node
switch {
case pg != nil &&
(pgend.Line == r.pos.Line ||
pgend.Line+1 == r.pos.Line && r.end.Line+1 < qpos.Line):
// 1) comment starts on same line as previous node group ends, or
// 2) comment starts on the line immediately after the
// previous node group and there is an empty line before
// the current node
// => associate comment with previous node group
assoc = pg
case p != nil &&
(pend.Line == r.pos.Line ||
pend.Line+1 == r.pos.Line && r.end.Line+1 < qpos.Line ||
q == nil):
// same rules apply as above for p rather than pg,
// but also associate with p if we are at the end (q == nil)
assoc = p
default:
// otherwise, associate comment with current node
if q == nil {
// we can only reach here if there was no p
// which would imply that there were no nodes
panic("internal error: no comments should be associated with sentinel")
}
assoc = q
}
cmap.addComment(assoc, r.comment)
if r.eol() {
return cmap
}
r.next()
}
// update previous node
p = q
pend = fset.Position(p.End())
// update previous node group if we see an "important" node
switch q.(type) {
case *File, *Field, Decl, Spec, Stmt:
stack.push(q)
}
}
return cmap
}
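// Sketch: building a map for a parsed file (assumed to live in a separate
// example file importing fmt, go/ast, go/parser, and go/token; src holds
// the source text):
//
//	fset := token.NewFileSet()
//	f, err := parser.ParseFile(fset, "x.go", src, parser.ParseComments)
//	if err != nil {
//		// handle the parse error
//	}
//	cmap := ast.NewCommentMap(fset, f, f.Comments)
//	for n, groups := range cmap {
//		fmt.Printf("%T: %d comment group(s)\n", n, len(groups))
//	}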
// Update replaces an old node in the comment map with the new node
// and returns the new node. Comments that were associated with the
// old node are associated with the new node.
func (cmap CommentMap) Update(old, new Node) Node {
if list := cmap[old]; len(list) > 0 {
delete(cmap, old)
cmap[new] = append(cmap[new], list...)
}
return new
}
// Filter returns a new comment map consisting of only those
// entries of cmap for which a corresponding node exists in
// the AST specified by node.
func (cmap CommentMap) Filter(node Node) CommentMap {
umap := make(CommentMap)
Inspect(node, func(n Node) bool {
if g := cmap[n]; len(g) > 0 {
umap[n] = g
}
return true
})
return umap
}
// Comments returns the list of comment groups in the comment map.
// The result is sorted in source order.
func (cmap CommentMap) Comments() []*CommentGroup {
list := make([]*CommentGroup, 0, len(cmap))
for _, e := range cmap {
list = append(list, e...)
}
sortComments(list)
return list
}
func summary(list []*CommentGroup) string {
const maxLen = 40
var buf bytes.Buffer
// collect comments text
loop:
for _, group := range list {
// Note: CommentGroup.Text() does too much work for what we
// need and would only replace this innermost loop.
// Just do it explicitly.
for _, comment := range group.List {
if buf.Len() >= maxLen {
break loop
}
buf.WriteString(comment.Text)
}
}
// truncate if too long
if buf.Len() > maxLen {
buf.Truncate(maxLen - 3)
buf.WriteString("...")
}
// replace any invisibles with blanks
bytes := buf.Bytes()
for i, b := range bytes {
switch b {
case '\t', '\n', '\r':
bytes[i] = ' '
}
}
return string(bytes)
}
func (cmap CommentMap) String() string {
// print map entries in sorted order
var nodes []Node
for node := range cmap {
nodes = append(nodes, node)
}
sort.Sort(byInterval(nodes))
var buf strings.Builder
fmt.Fprintln(&buf, "CommentMap {")
for _, node := range nodes {
comment := cmap[node]
// print name of identifiers; print node type for other nodes
var s string
if ident, ok := node.(*Ident); ok {
s = ident.Name
} else {
s = fmt.Sprintf("%T", node)
}
fmt.Fprintf(&buf, "\t%p %20s: %s\n", node, s, summary(comment))
}
fmt.Fprintln(&buf, "}")
return buf.String()
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ast
import (
"go/token"
"sort"
)
// ----------------------------------------------------------------------------
// Export filtering
// exportFilter is a special filter function to extract exported nodes.
func exportFilter(name string) bool {
return IsExported(name)
}
// FileExports trims the AST for a Go source file in place such that
// only exported nodes remain: all top-level identifiers which are not exported
// and their associated information (such as type, initial value, or function
// body) are removed. Non-exported fields and methods of exported types are
// stripped. The File.Comments list is not changed.
//
// FileExports reports whether there are exported declarations.
func FileExports(src *File) bool {
return filterFile(src, exportFilter, true)
}
// PackageExports trims the AST for a Go package in place such that
// only exported nodes remain. The pkg.Files list is not changed, so that
// file names and top-level package comments don't get lost.
//
// PackageExports reports whether there are exported declarations.
func PackageExports(pkg *Package) bool {
return filterPackage(pkg, exportFilter, true)
}
// ----------------------------------------------------------------------------
// General filtering
type Filter func(string) bool
func filterIdentList(list []*Ident, f Filter) []*Ident {
j := 0
for _, x := range list {
if f(x.Name) {
list[j] = x
j++
}
}
return list[0:j]
}
// fieldName assumes that x is the type of an anonymous field and
// returns the corresponding field name. If x is not an acceptable
// anonymous field, the result is nil.
func fieldName(x Expr) *Ident {
switch t := x.(type) {
case *Ident:
return t
case *SelectorExpr:
if _, ok := t.X.(*Ident); ok {
return t.Sel
}
case *StarExpr:
return fieldName(t.X)
}
return nil
}
func filterFieldList(fields *FieldList, filter Filter, export bool) (removedFields bool) {
if fields == nil {
return false
}
list := fields.List
j := 0
for _, f := range list {
keepField := false
if len(f.Names) == 0 {
// anonymous field
name := fieldName(f.Type)
keepField = name != nil && filter(name.Name)
} else {
n := len(f.Names)
f.Names = filterIdentList(f.Names, filter)
if len(f.Names) < n {
removedFields = true
}
keepField = len(f.Names) > 0
}
if keepField {
if export {
filterType(f.Type, filter, export)
}
list[j] = f
j++
}
}
if j < len(list) {
removedFields = true
}
fields.List = list[0:j]
return
}
func filterCompositeLit(lit *CompositeLit, filter Filter, export bool) {
n := len(lit.Elts)
lit.Elts = filterExprList(lit.Elts, filter, export)
if len(lit.Elts) < n {
lit.Incomplete = true
}
}
func filterExprList(list []Expr, filter Filter, export bool) []Expr {
j := 0
for _, exp := range list {
switch x := exp.(type) {
case *CompositeLit:
filterCompositeLit(x, filter, export)
case *KeyValueExpr:
if x, ok := x.Key.(*Ident); ok && !filter(x.Name) {
continue
}
if x, ok := x.Value.(*CompositeLit); ok {
filterCompositeLit(x, filter, export)
}
}
list[j] = exp
j++
}
return list[0:j]
}
func filterParamList(fields *FieldList, filter Filter, export bool) bool {
if fields == nil {
return false
}
var b bool
for _, f := range fields.List {
if filterType(f.Type, filter, export) {
b = true
}
}
return b
}
func filterType(typ Expr, f Filter, export bool) bool {
switch t := typ.(type) {
case *Ident:
return f(t.Name)
case *ParenExpr:
return filterType(t.X, f, export)
case *ArrayType:
return filterType(t.Elt, f, export)
case *StructType:
if filterFieldList(t.Fields, f, export) {
t.Incomplete = true
}
return len(t.Fields.List) > 0
case *FuncType:
b1 := filterParamList(t.Params, f, export)
b2 := filterParamList(t.Results, f, export)
return b1 || b2
case *InterfaceType:
if filterFieldList(t.Methods, f, export) {
t.Incomplete = true
}
return len(t.Methods.List) > 0
case *MapType:
b1 := filterType(t.Key, f, export)
b2 := filterType(t.Value, f, export)
return b1 || b2
case *ChanType:
return filterType(t.Value, f, export)
}
return false
}
func filterSpec(spec Spec, f Filter, export bool) bool {
switch s := spec.(type) {
case *ValueSpec:
s.Names = filterIdentList(s.Names, f)
s.Values = filterExprList(s.Values, f, export)
if len(s.Names) > 0 {
if export {
filterType(s.Type, f, export)
}
return true
}
case *TypeSpec:
if f(s.Name.Name) {
if export {
filterType(s.Type, f, export)
}
return true
}
if !export {
// For general filtering (not just exports),
// filter type even if name is not filtered
// out.
// If the type contains filtered elements,
// keep the declaration.
return filterType(s.Type, f, export)
}
}
return false
}
func filterSpecList(list []Spec, f Filter, export bool) []Spec {
j := 0
for _, s := range list {
if filterSpec(s, f, export) {
list[j] = s
j++
}
}
return list[0:j]
}
// FilterDecl trims the AST for a Go declaration in place by removing
// all names (including struct field and interface method names, but
// not from parameter lists) that don't pass through the filter f.
//
// FilterDecl reports whether there are any declared names left after
// filtering.
func FilterDecl(decl Decl, f Filter) bool {
return filterDecl(decl, f, false)
}
func filterDecl(decl Decl, f Filter, export bool) bool {
switch d := decl.(type) {
case *GenDecl:
d.Specs = filterSpecList(d.Specs, f, export)
return len(d.Specs) > 0
case *FuncDecl:
return f(d.Name.Name)
}
return false
}
// FilterFile trims the AST for a Go file in place by removing all
// names from top-level declarations (including struct field and
// interface method names, but not from parameter lists) that don't
// pass through the filter f. If the declaration is empty afterwards,
// the declaration is removed from the AST. Import declarations are
// always removed. The File.Comments list is not changed.
//
// FilterFile reports whether there are any top-level declarations
// left after filtering.
func FilterFile(src *File, f Filter) bool {
return filterFile(src, f, false)
}
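// Illustrative sketch (not part of the original source): a custom Filter
// that keeps only names with a hypothetical "Test" prefix; f is an assumed,
// already-parsed *File.
//
//	keepTests := func(name string) bool { return strings.HasPrefix(name, "Test") }
//	if FilterFile(f, keepTests) {
//		// f now contains only declarations whose names start with "Test".
//	}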
func filterFile(src *File, f Filter, export bool) bool {
j := 0
for _, d := range src.Decls {
if filterDecl(d, f, export) {
src.Decls[j] = d
j++
}
}
src.Decls = src.Decls[0:j]
return j > 0
}
// FilterPackage trims the AST for a Go package in place by removing
// all names from top-level declarations (including struct field and
// interface method names, but not from parameter lists) that don't
// pass through the filter f. If the declaration is empty afterwards,
// the declaration is removed from the AST. The pkg.Files list is not
// changed, so that file names and top-level package comments don't get
// lost.
//
// FilterPackage reports whether there are any top-level declarations
// left after filtering.
func FilterPackage(pkg *Package, f Filter) bool {
return filterPackage(pkg, f, false)
}
func filterPackage(pkg *Package, f Filter, export bool) bool {
hasDecls := false
for _, src := range pkg.Files {
if filterFile(src, f, export) {
hasDecls = true
}
}
return hasDecls
}
// ----------------------------------------------------------------------------
// Merging of package files
// The MergeMode flags control the behavior of MergePackageFiles.
type MergeMode uint
const (
// If set, duplicate function declarations are excluded.
FilterFuncDuplicates MergeMode = 1 << iota
// If set, comments that are not associated with a specific
// AST node (as Doc or Comment) are excluded.
FilterUnassociatedComments
// If set, duplicate import declarations are excluded.
FilterImportDuplicates
)
// nameOf returns the function (foo) or method name (foo.bar) for
// the given function declaration. If the AST is incorrect for the
// receiver, it assumes a function instead.
func nameOf(f *FuncDecl) string {
if r := f.Recv; r != nil && len(r.List) == 1 {
// looks like a correct receiver declaration
t := r.List[0].Type
// dereference pointer receiver types
if p, _ := t.(*StarExpr); p != nil {
t = p.X
}
// the receiver type must be a type name
if p, _ := t.(*Ident); p != nil {
return p.Name + "." + f.Name.Name
}
// otherwise assume a function instead
}
return f.Name.Name
}
// separator is an empty //-style comment that is interspersed between
// different comment groups when they are concatenated into a single group
var separator = &Comment{token.NoPos, "//"}
// MergePackageFiles creates a file AST by merging the ASTs of the
// files belonging to a package. The mode flags control merging behavior.
func MergePackageFiles(pkg *Package, mode MergeMode) *File {
// Count the number of package docs, comments and declarations across
// all package files. Also, compute sorted list of filenames, so that
// subsequent iterations can always iterate in the same order.
ndocs := 0
ncomments := 0
ndecls := 0
filenames := make([]string, len(pkg.Files))
var minPos, maxPos token.Pos
i := 0
for filename, f := range pkg.Files {
filenames[i] = filename
i++
if f.Doc != nil {
ndocs += len(f.Doc.List) + 1 // +1 for separator
}
ncomments += len(f.Comments)
ndecls += len(f.Decls)
if i == 1 || f.FileStart < minPos { // i was already incremented; i == 1 marks the first file
minPos = f.FileStart
}
if i == 1 || f.FileEnd > maxPos {
maxPos = f.FileEnd
}
}
sort.Strings(filenames)
// Collect package comments from all package files into a single
// CommentGroup - the collected package documentation. In general
// there should be only one file with a package comment; but it's
// better to collect extra comments than drop them on the floor.
var doc *CommentGroup
var pos token.Pos
if ndocs > 0 {
list := make([]*Comment, ndocs-1) // -1: no separator before first group
i := 0
for _, filename := range filenames {
f := pkg.Files[filename]
if f.Doc != nil {
if i > 0 {
// not the first group - add separator
list[i] = separator
i++
}
for _, c := range f.Doc.List {
list[i] = c
i++
}
if f.Package > pos {
// Keep the maximum package clause position as
// position for the package clause of the merged
// files.
pos = f.Package
}
}
}
doc = &CommentGroup{list}
}
// Collect declarations from all package files.
var decls []Decl
if ndecls > 0 {
decls = make([]Decl, ndecls)
funcs := make(map[string]int) // map of func name -> decls index
i := 0 // current index
n := 0 // number of filtered entries
for _, filename := range filenames {
f := pkg.Files[filename]
for _, d := range f.Decls {
if mode&FilterFuncDuplicates != 0 {
// A language entity may be declared multiple
// times in different package files; declarations
// must be unique only at build time.
// For now, exclude multiple declarations of
// functions - keep the one with documentation.
//
// TODO(gri): Expand this filtering to other
// entities (const, type, vars) if
// multiple declarations are common.
if f, isFun := d.(*FuncDecl); isFun {
name := nameOf(f)
if j, exists := funcs[name]; exists {
// function declared already
if decls[j] != nil && decls[j].(*FuncDecl).Doc == nil {
// existing declaration has no documentation;
// ignore the existing declaration
decls[j] = nil
} else {
// ignore the new declaration
d = nil
}
n++ // filtered an entry
} else {
funcs[name] = i
}
}
}
decls[i] = d
i++
}
}
// Eliminate nil entries from the decls list if entries were
// filtered. We do this using a 2nd pass in order to not disturb
// the original declaration order in the source (otherwise, this
// would also invalidate the monotonically increasing position
// info within a single file).
if n > 0 {
i = 0
for _, d := range decls {
if d != nil {
decls[i] = d
i++
}
}
decls = decls[0:i]
}
}
// Collect import specs from all package files.
var imports []*ImportSpec
if mode&FilterImportDuplicates != 0 {
seen := make(map[string]bool)
for _, filename := range filenames {
f := pkg.Files[filename]
for _, imp := range f.Imports {
if path := imp.Path.Value; !seen[path] {
// TODO: consider handling cases where:
// - 2 imports exist with the same import path but
// have different local names (one should probably
// keep both of them)
// - 2 imports exist but only one has a comment
// - 2 imports exist and they both have (possibly
// different) comments
imports = append(imports, imp)
seen[path] = true
}
}
}
} else {
// Iterate over filenames for deterministic order.
for _, filename := range filenames {
f := pkg.Files[filename]
imports = append(imports, f.Imports...)
}
}
// Collect comments from all package files.
var comments []*CommentGroup
if mode&FilterUnassociatedComments == 0 {
comments = make([]*CommentGroup, ncomments)
i := 0
for _, filename := range filenames {
f := pkg.Files[filename]
i += copy(comments[i:], f.Comments)
}
}
// TODO(gri) need to compute unresolved identifiers!
return &File{doc, pos, NewIdent(pkg.Name), decls, minPos, maxPos, pkg.Scope, imports, nil, comments}
}
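// Illustrative sketch (not part of the original source): merging the files
// of a package obtained from parser.ParseDir, deduplicating functions and
// imports. The directory path is assumed; error handling elided.
//
//	fset := token.NewFileSet()
//	pkgs, _ := parser.ParseDir(fset, "./mypkg", nil, parser.ParseComments)
//	for _, pkg := range pkgs {
//		merged := MergePackageFiles(pkg, FilterFuncDuplicates|FilterImportDuplicates)
//		_ = merged // one synthetic *File representing the whole package
//	}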
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ast
import (
"go/token"
"sort"
"strconv"
)
// SortImports sorts runs of consecutive import lines in import blocks in f.
// It also removes duplicate imports when it is possible to do so without data loss.
func SortImports(fset *token.FileSet, f *File) {
for _, d := range f.Decls {
d, ok := d.(*GenDecl)
if !ok || d.Tok != token.IMPORT {
// Not an import declaration, so we're done.
// Imports are always first.
break
}
if !d.Lparen.IsValid() {
// Not a block: sorted by default.
continue
}
// Identify and sort runs of specs on successive lines.
i := 0
specs := d.Specs[:0]
for j, s := range d.Specs {
if j > i && lineAt(fset, s.Pos()) > 1+lineAt(fset, d.Specs[j-1].End()) {
// j begins a new run. End this one.
specs = append(specs, sortSpecs(fset, f, d.Specs[i:j])...)
i = j
}
}
specs = append(specs, sortSpecs(fset, f, d.Specs[i:])...)
d.Specs = specs
// Deduping can leave a blank line before the rparen; clean that up.
if len(d.Specs) > 0 {
lastSpec := d.Specs[len(d.Specs)-1]
lastLine := lineAt(fset, lastSpec.Pos())
rParenLine := lineAt(fset, d.Rparen)
for rParenLine > lastLine+1 {
rParenLine--
fset.File(d.Rparen).MergeLine(rParenLine)
}
}
}
}
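// Illustrative sketch (not part of the original source): sorting the import
// blocks of a parsed file and printing the result; src is an assumed source
// string.
//
//	fset := token.NewFileSet()
//	f, _ := parser.ParseFile(fset, "x.go", src, parser.ParseComments)
//	SortImports(fset, f)
//	printer.Fprint(os.Stdout, fset, f)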
func lineAt(fset *token.FileSet, pos token.Pos) int {
return fset.PositionFor(pos, false).Line
}
func importPath(s Spec) string {
t, err := strconv.Unquote(s.(*ImportSpec).Path.Value)
if err == nil {
return t
}
return ""
}
func importName(s Spec) string {
n := s.(*ImportSpec).Name
if n == nil {
return ""
}
return n.Name
}
func importComment(s Spec) string {
c := s.(*ImportSpec).Comment
if c == nil {
return ""
}
return c.Text()
}
// collapse indicates whether prev may be removed, leaving only next.
func collapse(prev, next Spec) bool {
if importPath(next) != importPath(prev) || importName(next) != importName(prev) {
return false
}
return prev.(*ImportSpec).Comment == nil
}
type posSpan struct {
Start token.Pos
End token.Pos
}
type cgPos struct {
left bool // true if comment is to the left of the spec, false otherwise.
cg *CommentGroup
}
func sortSpecs(fset *token.FileSet, f *File, specs []Spec) []Spec {
// Can't short-circuit here even if specs are already sorted,
// since they might yet need deduplication.
// A lone import, however, may be safely ignored.
if len(specs) <= 1 {
return specs
}
// Record positions for specs.
pos := make([]posSpan, len(specs))
for i, s := range specs {
pos[i] = posSpan{s.Pos(), s.End()}
}
// Identify comments in this range.
begSpecs := pos[0].Start
endSpecs := pos[len(pos)-1].End
beg := fset.File(begSpecs).LineStart(lineAt(fset, begSpecs))
endLine := lineAt(fset, endSpecs)
endFile := fset.File(endSpecs)
var end token.Pos
if endLine == endFile.LineCount() {
end = endSpecs
} else {
end = endFile.LineStart(endLine + 1) // beginning of next line
}
first := len(f.Comments)
last := -1
for i, g := range f.Comments {
if g.End() >= end {
break
}
// g.End() < end
if beg <= g.Pos() {
// comment is within the range [beg, end) of import declarations
if i < first {
first = i
}
if i > last {
last = i
}
}
}
var comments []*CommentGroup
if last >= 0 {
comments = f.Comments[first : last+1]
}
// Assign each comment to the import spec on the same line.
importComments := map[*ImportSpec][]cgPos{}
specIndex := 0
for _, g := range comments {
for specIndex+1 < len(specs) && pos[specIndex+1].Start <= g.Pos() {
specIndex++
}
var left bool
// A block comment can appear before the first import spec.
if specIndex == 0 && pos[specIndex].Start > g.Pos() {
left = true
} else if specIndex+1 < len(specs) && // Or it can appear on the left of an import spec.
lineAt(fset, pos[specIndex].Start)+1 == lineAt(fset, g.Pos()) {
specIndex++
left = true
}
s := specs[specIndex].(*ImportSpec)
importComments[s] = append(importComments[s], cgPos{left: left, cg: g})
}
// Sort the import specs by import path.
// Remove duplicates, when possible without data loss.
// Reassign the import paths to have the same position sequence.
// Reassign each comment to the spec on the same line.
// Sort the comments by new position.
sort.Slice(specs, func(i, j int) bool {
ipath := importPath(specs[i])
jpath := importPath(specs[j])
if ipath != jpath {
return ipath < jpath
}
iname := importName(specs[i])
jname := importName(specs[j])
if iname != jname {
return iname < jname
}
return importComment(specs[i]) < importComment(specs[j])
})
// Dedup. Thanks to our sorting, we can just consider
// adjacent pairs of imports.
deduped := specs[:0]
for i, s := range specs {
if i == len(specs)-1 || !collapse(s, specs[i+1]) {
deduped = append(deduped, s)
} else {
p := s.Pos()
fset.File(p).MergeLine(lineAt(fset, p))
}
}
specs = deduped
// Fix up comment positions
for i, s := range specs {
s := s.(*ImportSpec)
if s.Name != nil {
s.Name.NamePos = pos[i].Start
}
s.Path.ValuePos = pos[i].Start
s.EndPos = pos[i].End
for _, g := range importComments[s] {
for _, c := range g.cg.List {
if g.left {
c.Slash = pos[i].Start - 1
} else {
// An import spec can have both a block comment and a line comment
// to its right. In that case, both of them will have the same pos.
// But while formatting the AST, the line comment gets moved to
// after the block comment.
c.Slash = pos[i].End
}
}
}
}
sort.Slice(comments, func(i, j int) bool {
return comments[i].Pos() < comments[j].Pos()
})
return specs
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file contains printing support for ASTs.
package ast
import (
"fmt"
"go/token"
"io"
"os"
"reflect"
)
// A FieldFilter may be provided to Fprint to control the output.
type FieldFilter func(name string, value reflect.Value) bool
// NotNilFilter returns true for field values that are not nil;
// it returns false otherwise.
func NotNilFilter(_ string, v reflect.Value) bool {
switch v.Kind() {
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
return !v.IsNil()
}
return true
}
// Fprint prints the (sub-)tree starting at AST node x to w.
// If fset != nil, position information is interpreted relative
// to that file set. Otherwise positions are printed as integer
// values (file set specific offsets).
//
// A non-nil FieldFilter f may be provided to control the output:
// struct fields for which f(fieldname, fieldvalue) is true are
// printed; all others are filtered from the output. Unexported
// struct fields are never printed.
func Fprint(w io.Writer, fset *token.FileSet, x any, f FieldFilter) error {
return fprint(w, fset, x, f)
}
func fprint(w io.Writer, fset *token.FileSet, x any, f FieldFilter) (err error) {
// setup printer
p := printer{
output: w,
fset: fset,
filter: f,
ptrmap: make(map[any]int),
last: '\n', // force printing of line number on first line
}
// install error handler
defer func() {
if e := recover(); e != nil {
err = e.(localError).err // re-panics if it's not a localError
}
}()
// print x
if x == nil {
p.printf("nil\n")
return
}
p.print(reflect.ValueOf(x))
p.printf("\n")
return
}
// Print prints x to standard output, skipping nil fields.
// Print(fset, x) is the same as Fprint(os.Stdout, fset, x, NotNilFilter).
func Print(fset *token.FileSet, x any) error {
return Fprint(os.Stdout, fset, x, NotNilFilter)
}
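// Illustrative sketch (not part of the original source): dumping the AST of
// a single expression. A nil file set makes positions print as offsets.
//
//	expr, _ := parser.ParseExpr(`f(1, "two")`)
//	Print(nil, expr)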
type printer struct {
output io.Writer
fset *token.FileSet
filter FieldFilter
ptrmap map[any]int // *T -> line number
indent int // current indentation level
last byte // the last byte processed by Write
line int // current line number
}
var indent = []byte(". ")
func (p *printer) Write(data []byte) (n int, err error) {
var m int
for i, b := range data {
// invariant: data[0:n] has been written
if b == '\n' {
m, err = p.output.Write(data[n : i+1])
n += m
if err != nil {
return
}
p.line++
} else if p.last == '\n' {
_, err = fmt.Fprintf(p.output, "%6d ", p.line)
if err != nil {
return
}
for j := p.indent; j > 0; j-- {
_, err = p.output.Write(indent)
if err != nil {
return
}
}
}
p.last = b
}
if len(data) > n {
m, err = p.output.Write(data[n:])
n += m
}
return
}
// localError wraps locally caught errors so we can distinguish
// them from genuine panics, which we don't want to return as errors.
type localError struct {
err error
}
// printf is a convenience wrapper that takes care of print errors.
func (p *printer) printf(format string, args ...any) {
if _, err := fmt.Fprintf(p, format, args...); err != nil {
panic(localError{err})
}
}
// Implementation note: Print is written for AST nodes but could be
// used to print arbitrary data structures; such a version should
// probably be in a different package.
//
// Note: This code detects (some) cycles created via pointers but
// not cycles that are created via slices or maps containing the
// same slice or map. Code for general data structures probably
// should catch those as well.
func (p *printer) print(x reflect.Value) {
if !NotNilFilter("", x) {
p.printf("nil")
return
}
switch x.Kind() {
case reflect.Interface:
p.print(x.Elem())
case reflect.Map:
p.printf("%s (len = %d) {", x.Type(), x.Len())
if x.Len() > 0 {
p.indent++
p.printf("\n")
for _, key := range x.MapKeys() {
p.print(key)
p.printf(": ")
p.print(x.MapIndex(key))
p.printf("\n")
}
p.indent--
}
p.printf("}")
case reflect.Pointer:
p.printf("*")
// type-checked ASTs may contain cycles - use ptrmap
// to keep track of objects that have been printed
// already and print the respective line number instead
ptr := x.Interface()
if line, exists := p.ptrmap[ptr]; exists {
p.printf("(obj @ %d)", line)
} else {
p.ptrmap[ptr] = p.line
p.print(x.Elem())
}
case reflect.Array:
p.printf("%s {", x.Type())
if x.Len() > 0 {
p.indent++
p.printf("\n")
for i, n := 0, x.Len(); i < n; i++ {
p.printf("%d: ", i)
p.print(x.Index(i))
p.printf("\n")
}
p.indent--
}
p.printf("}")
case reflect.Slice:
if s, ok := x.Interface().([]byte); ok {
p.printf("%#q", s)
return
}
p.printf("%s (len = %d) {", x.Type(), x.Len())
if x.Len() > 0 {
p.indent++
p.printf("\n")
for i, n := 0, x.Len(); i < n; i++ {
p.printf("%d: ", i)
p.print(x.Index(i))
p.printf("\n")
}
p.indent--
}
p.printf("}")
case reflect.Struct:
t := x.Type()
p.printf("%s {", t)
p.indent++
first := true
for i, n := 0, t.NumField(); i < n; i++ {
// exclude non-exported fields because their
// values cannot be accessed via reflection
if name := t.Field(i).Name; IsExported(name) {
value := x.Field(i)
if p.filter == nil || p.filter(name, value) {
if first {
p.printf("\n")
first = false
}
p.printf("%s: ", name)
p.print(value)
p.printf("\n")
}
}
}
p.indent--
p.printf("}")
default:
v := x.Interface()
switch v := v.(type) {
case string:
// print strings in quotes
p.printf("%q", v)
return
case token.Pos:
// position values can be printed nicely if we have a file set
if p.fset != nil {
p.printf("%s", p.fset.Position(v))
return
}
}
// default
p.printf("%v", v)
}
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements NewPackage.
package ast
import (
"fmt"
"go/scanner"
"go/token"
"strconv"
)
type pkgBuilder struct {
fset *token.FileSet
errors scanner.ErrorList
}
func (p *pkgBuilder) error(pos token.Pos, msg string) {
p.errors.Add(p.fset.Position(pos), msg)
}
func (p *pkgBuilder) errorf(pos token.Pos, format string, args ...any) {
p.error(pos, fmt.Sprintf(format, args...))
}
func (p *pkgBuilder) declare(scope, altScope *Scope, obj *Object) {
alt := scope.Insert(obj)
if alt == nil && altScope != nil {
// see if there is a conflicting declaration in altScope
alt = altScope.Lookup(obj.Name)
}
if alt != nil {
prevDecl := ""
if pos := alt.Pos(); pos.IsValid() {
prevDecl = fmt.Sprintf("\n\tprevious declaration at %s", p.fset.Position(pos))
}
p.error(obj.Pos(), fmt.Sprintf("%s redeclared in this block%s", obj.Name, prevDecl))
}
}
func resolve(scope *Scope, ident *Ident) bool {
for ; scope != nil; scope = scope.Outer {
if obj := scope.Lookup(ident.Name); obj != nil {
ident.Obj = obj
return true
}
}
return false
}
// An Importer resolves import paths to package Objects.
// The imports map records the packages already imported,
// indexed by package id (canonical import path).
// An Importer must determine the canonical import path and
// check whether it is already present in the imports map.
// If so, the Importer can return the map entry. Otherwise, the
// Importer should load the package data for the given path into
// a new *Object (pkg), record pkg in the imports map, and then
// return pkg.
type Importer func(imports map[string]*Object, path string) (pkg *Object, err error)
// NewPackage creates a new Package node from a set of File nodes. It resolves
// unresolved identifiers across files and updates each file's Unresolved list
// accordingly. If a non-nil importer and universe scope are provided, they are
// used to resolve identifiers not declared in any of the package files. Any
// remaining unresolved identifiers are reported as undeclared. If the files
// belong to different packages, one package name is selected and files with
// different package names are reported and then ignored.
// The result is a package node and a scanner.ErrorList if there were errors.
func NewPackage(fset *token.FileSet, files map[string]*File, importer Importer, universe *Scope) (*Package, error) {
var p pkgBuilder
p.fset = fset
// complete package scope
pkgName := ""
pkgScope := NewScope(universe)
for _, file := range files {
// package names must match
switch name := file.Name.Name; {
case pkgName == "":
pkgName = name
case name != pkgName:
p.errorf(file.Package, "package %s; expected %s", name, pkgName)
continue // ignore this file
}
// collect top-level file objects in package scope
for _, obj := range file.Scope.Objects {
p.declare(pkgScope, nil, obj)
}
}
// package global mapping of imported package ids to package objects
imports := make(map[string]*Object)
// complete file scopes with imports and resolve identifiers
for _, file := range files {
// ignore file if it belongs to a different package
// (error has already been reported)
if file.Name.Name != pkgName {
continue
}
// build file scope by processing all imports
importErrors := false
fileScope := NewScope(pkgScope)
for _, spec := range file.Imports {
if importer == nil {
importErrors = true
continue
}
path, _ := strconv.Unquote(spec.Path.Value)
pkg, err := importer(imports, path)
if err != nil {
p.errorf(spec.Path.Pos(), "could not import %s (%s)", path, err)
importErrors = true
continue
}
// TODO(gri) If a local package name != "." is provided,
// global identifier resolution could proceed even if the
// import failed. Consider adjusting the logic here a bit.
// local name overrides imported package name
name := pkg.Name
if spec.Name != nil {
name = spec.Name.Name
}
// add import to file scope
if name == "." {
// merge imported scope with file scope
for _, obj := range pkg.Data.(*Scope).Objects {
p.declare(fileScope, pkgScope, obj)
}
} else if name != "_" {
// declare imported package object in file scope
// (do not re-use pkg in the file scope but create
// a new object instead; the Decl field is different
// for different files)
obj := NewObj(Pkg, name)
obj.Decl = spec
obj.Data = pkg.Data
p.declare(fileScope, pkgScope, obj)
}
}
// resolve identifiers
if importErrors {
// don't use the universe scope without correct imports
// (objects in the universe may be shadowed by imports;
// with missing imports, identifiers might get resolved
// incorrectly to universe objects)
pkgScope.Outer = nil
}
i := 0
for _, ident := range file.Unresolved {
if !resolve(fileScope, ident) {
p.errorf(ident.Pos(), "undeclared name: %s", ident.Name)
file.Unresolved[i] = ident
i++
}
}
file.Unresolved = file.Unresolved[0:i]
pkgScope.Outer = universe // reset universe scope
}
p.errors.Sort()
return &Package{pkgName, pkgScope, imports, files}, p.errors.Err()
}
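// Illustrative sketch (not part of the original source): resolving a package
// from two already-parsed files fa and fb (fset, fa, and fb are assumed).
// With a nil importer and nil universe scope, identifiers that would need an
// import or a predeclared name are reported as undeclared.
//
//	files := map[string]*File{"a.go": fa, "b.go": fb}
//	pkg, err := NewPackage(fset, files, nil, nil)
//	_, _ = pkg, err // partial package information is returned even on error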
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements scopes and the objects they contain.
package ast
import (
"fmt"
"go/token"
"strings"
)
// A Scope maintains the set of named language entities declared
// in the scope and a link to the immediately surrounding (outer)
// scope.
type Scope struct {
Outer *Scope
Objects map[string]*Object
}
// NewScope creates a new scope nested in the outer scope.
func NewScope(outer *Scope) *Scope {
const n = 4 // initial scope capacity
return &Scope{outer, make(map[string]*Object, n)}
}
// Lookup returns the object with the given name if it is
// found in scope s; otherwise it returns nil. Outer scopes
// are ignored.
func (s *Scope) Lookup(name string) *Object {
return s.Objects[name]
}
// Insert attempts to insert a named object obj into the scope s.
// If the scope already contains an object alt with the same name,
// Insert leaves the scope unchanged and returns alt. Otherwise
// it inserts obj and returns nil.
func (s *Scope) Insert(obj *Object) (alt *Object) {
if alt = s.Objects[obj.Name]; alt == nil {
s.Objects[obj.Name] = obj
}
return
}
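// Illustrative sketch (not part of the original source): declaring an object
// in a scope and observing that a second insert of the same name is rejected.
//
//	s := NewScope(nil)
//	x := NewObj(Var, "x")
//	if alt := s.Insert(x); alt == nil {
//		// x is now declared in s; s.Lookup("x") == x
//	}
//	dup := NewObj(Var, "x")
//	alt := s.Insert(dup) // returns x; the scope is unchanged
//	_ = alt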
// Debugging support
func (s *Scope) String() string {
var buf strings.Builder
fmt.Fprintf(&buf, "scope %p {", s)
if s != nil && len(s.Objects) > 0 {
fmt.Fprintln(&buf)
for _, obj := range s.Objects {
fmt.Fprintf(&buf, "\t%s %s\n", obj.Kind, obj.Name)
}
}
fmt.Fprintf(&buf, "}\n")
return buf.String()
}
// ----------------------------------------------------------------------------
// Objects
// An Object describes a named language entity such as a package,
// constant, type, variable, function (incl. methods), or label.
//
// The Data field contains object-specific data:
//
// Kind Data type Data value
// Pkg *Scope package scope
// Con int iota for the respective declaration
type Object struct {
Kind ObjKind
Name string // declared name
Decl any // corresponding Field, XxxSpec, FuncDecl, LabeledStmt, AssignStmt, Scope; or nil
Data any // object-specific data; or nil
Type any // placeholder for type information; may be nil
}
// NewObj creates a new object of a given kind and name.
func NewObj(kind ObjKind, name string) *Object {
return &Object{Kind: kind, Name: name}
}
// Pos computes the source position of the declaration of an object name.
// The result may be an invalid position if it cannot be computed
// (obj.Decl may be nil or not correct).
func (obj *Object) Pos() token.Pos {
name := obj.Name
switch d := obj.Decl.(type) {
case *Field:
for _, n := range d.Names {
if n.Name == name {
return n.Pos()
}
}
case *ImportSpec:
if d.Name != nil && d.Name.Name == name {
return d.Name.Pos()
}
return d.Path.Pos()
case *ValueSpec:
for _, n := range d.Names {
if n.Name == name {
return n.Pos()
}
}
case *TypeSpec:
if d.Name.Name == name {
return d.Name.Pos()
}
case *FuncDecl:
if d.Name.Name == name {
return d.Name.Pos()
}
case *LabeledStmt:
if d.Label.Name == name {
return d.Label.Pos()
}
case *AssignStmt:
for _, x := range d.Lhs {
if ident, isIdent := x.(*Ident); isIdent && ident.Name == name {
return ident.Pos()
}
}
case *Scope:
// predeclared object - nothing to do for now
}
return token.NoPos
}
// ObjKind describes what an object represents.
type ObjKind int
// The list of possible Object kinds.
const (
Bad ObjKind = iota // for error handling
Pkg // package
Con // constant
Typ // type
Var // variable
Fun // function or method
Lbl // label
)
var objKindStrings = [...]string{
Bad: "bad",
Pkg: "package",
Con: "const",
Typ: "type",
Var: "var",
Fun: "func",
Lbl: "label",
}
func (kind ObjKind) String() string { return objKindStrings[kind] }
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ast
import "fmt"
// A Visitor's Visit method is invoked for each node encountered by Walk.
// If the result visitor w is not nil, Walk visits each of the children
// of node with the visitor w, followed by a call of w.Visit(nil).
type Visitor interface {
Visit(node Node) (w Visitor)
}
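// Illustrative sketch (not part of the original source): a minimal Visitor
// that counts every node it encounters during a Walk; f is an assumed *File.
//
//	type counter int
//
//	func (c *counter) Visit(n Node) Visitor {
//		if n != nil {
//			*c++
//		}
//		return c // keep walking into children
//	}
//
//	var c counter
//	Walk(&c, f) // c now holds the node count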
// Helper functions for common node lists. They may be empty.
func walkIdentList(v Visitor, list []*Ident) {
for _, x := range list {
Walk(v, x)
}
}
func walkExprList(v Visitor, list []Expr) {
for _, x := range list {
Walk(v, x)
}
}
func walkStmtList(v Visitor, list []Stmt) {
for _, x := range list {
Walk(v, x)
}
}
func walkDeclList(v Visitor, list []Decl) {
for _, x := range list {
Walk(v, x)
}
}
// TODO(gri): Investigate if providing a closure to Walk leads to
// simpler use (and may help eliminate Inspect in turn).
// Walk traverses an AST in depth-first order: It starts by calling
// v.Visit(node); node must not be nil. If the visitor w returned by
// v.Visit(node) is not nil, Walk is invoked recursively with visitor
// w for each of the non-nil children of node, followed by a call of
// w.Visit(nil).
func Walk(v Visitor, node Node) {
if v = v.Visit(node); v == nil {
return
}
// walk children
// (the order of the cases matches the order
// of the corresponding node types in ast.go)
switch n := node.(type) {
// Comments and fields
case *Comment:
// nothing to do
case *CommentGroup:
for _, c := range n.List {
Walk(v, c)
}
case *Field:
if n.Doc != nil {
Walk(v, n.Doc)
}
walkIdentList(v, n.Names)
if n.Type != nil {
Walk(v, n.Type)
}
if n.Tag != nil {
Walk(v, n.Tag)
}
if n.Comment != nil {
Walk(v, n.Comment)
}
case *FieldList:
for _, f := range n.List {
Walk(v, f)
}
// Expressions
case *BadExpr, *Ident, *BasicLit:
// nothing to do
case *Ellipsis:
if n.Elt != nil {
Walk(v, n.Elt)
}
case *FuncLit:
Walk(v, n.Type)
Walk(v, n.Body)
case *CompositeLit:
if n.Type != nil {
Walk(v, n.Type)
}
walkExprList(v, n.Elts)
case *ParenExpr:
Walk(v, n.X)
case *SelectorExpr:
Walk(v, n.X)
Walk(v, n.Sel)
case *IndexExpr:
Walk(v, n.X)
Walk(v, n.Index)
case *IndexListExpr:
Walk(v, n.X)
for _, index := range n.Indices {
Walk(v, index)
}
case *SliceExpr:
Walk(v, n.X)
if n.Low != nil {
Walk(v, n.Low)
}
if n.High != nil {
Walk(v, n.High)
}
if n.Max != nil {
Walk(v, n.Max)
}
case *TypeAssertExpr:
Walk(v, n.X)
if n.Type != nil {
Walk(v, n.Type)
}
case *CallExpr:
Walk(v, n.Fun)
walkExprList(v, n.Args)
case *StarExpr:
Walk(v, n.X)
case *UnaryExpr:
Walk(v, n.X)
case *BinaryExpr:
Walk(v, n.X)
Walk(v, n.Y)
case *KeyValueExpr:
Walk(v, n.Key)
Walk(v, n.Value)
// Types
case *ArrayType:
if n.Len != nil {
Walk(v, n.Len)
}
Walk(v, n.Elt)
case *StructType:
Walk(v, n.Fields)
case *FuncType:
if n.TypeParams != nil {
Walk(v, n.TypeParams)
}
if n.Params != nil {
Walk(v, n.Params)
}
if n.Results != nil {
Walk(v, n.Results)
}
case *InterfaceType:
Walk(v, n.Methods)
case *MapType:
Walk(v, n.Key)
Walk(v, n.Value)
case *ChanType:
Walk(v, n.Value)
// Statements
case *BadStmt:
// nothing to do
case *DeclStmt:
Walk(v, n.Decl)
case *EmptyStmt:
// nothing to do
case *LabeledStmt:
Walk(v, n.Label)
Walk(v, n.Stmt)
case *ExprStmt:
Walk(v, n.X)
case *SendStmt:
Walk(v, n.Chan)
Walk(v, n.Value)
case *IncDecStmt:
Walk(v, n.X)
case *AssignStmt:
walkExprList(v, n.Lhs)
walkExprList(v, n.Rhs)
case *GoStmt:
Walk(v, n.Call)
case *DeferStmt:
Walk(v, n.Call)
case *ReturnStmt:
walkExprList(v, n.Results)
case *BranchStmt:
if n.Label != nil {
Walk(v, n.Label)
}
case *BlockStmt:
walkStmtList(v, n.List)
case *IfStmt:
if n.Init != nil {
Walk(v, n.Init)
}
Walk(v, n.Cond)
Walk(v, n.Body)
if n.Else != nil {
Walk(v, n.Else)
}
case *CaseClause:
walkExprList(v, n.List)
walkStmtList(v, n.Body)
case *SwitchStmt:
if n.Init != nil {
Walk(v, n.Init)
}
if n.Tag != nil {
Walk(v, n.Tag)
}
Walk(v, n.Body)
case *TypeSwitchStmt:
if n.Init != nil {
Walk(v, n.Init)
}
Walk(v, n.Assign)
Walk(v, n.Body)
case *CommClause:
if n.Comm != nil {
Walk(v, n.Comm)
}
walkStmtList(v, n.Body)
case *SelectStmt:
Walk(v, n.Body)
case *ForStmt:
if n.Init != nil {
Walk(v, n.Init)
}
if n.Cond != nil {
Walk(v, n.Cond)
}
if n.Post != nil {
Walk(v, n.Post)
}
Walk(v, n.Body)
case *RangeStmt:
if n.Key != nil {
Walk(v, n.Key)
}
if n.Value != nil {
Walk(v, n.Value)
}
Walk(v, n.X)
Walk(v, n.Body)
// Declarations
case *ImportSpec:
if n.Doc != nil {
Walk(v, n.Doc)
}
if n.Name != nil {
Walk(v, n.Name)
}
Walk(v, n.Path)
if n.Comment != nil {
Walk(v, n.Comment)
}
case *ValueSpec:
if n.Doc != nil {
Walk(v, n.Doc)
}
walkIdentList(v, n.Names)
if n.Type != nil {
Walk(v, n.Type)
}
walkExprList(v, n.Values)
if n.Comment != nil {
Walk(v, n.Comment)
}
case *TypeSpec:
if n.Doc != nil {
Walk(v, n.Doc)
}
Walk(v, n.Name)
if n.TypeParams != nil {
Walk(v, n.TypeParams)
}
Walk(v, n.Type)
if n.Comment != nil {
Walk(v, n.Comment)
}
case *BadDecl:
// nothing to do
case *GenDecl:
if n.Doc != nil {
Walk(v, n.Doc)
}
for _, s := range n.Specs {
Walk(v, s)
}
case *FuncDecl:
if n.Doc != nil {
Walk(v, n.Doc)
}
if n.Recv != nil {
Walk(v, n.Recv)
}
Walk(v, n.Name)
Walk(v, n.Type)
if n.Body != nil {
Walk(v, n.Body)
}
// Files and packages
case *File:
if n.Doc != nil {
Walk(v, n.Doc)
}
Walk(v, n.Name)
walkDeclList(v, n.Decls)
// don't walk n.Comments - they have been
// visited already through the individual
// nodes
case *Package:
for _, f := range n.Files {
Walk(v, f)
}
default:
panic(fmt.Sprintf("ast.Walk: unexpected node type %T", n))
}
v.Visit(nil)
}
type inspector func(Node) bool
func (f inspector) Visit(node Node) Visitor {
if f(node) {
return f
}
return nil
}
// Inspect traverses an AST in depth-first order: It starts by calling
// f(node); node must not be nil. If f returns true, Inspect invokes f
// recursively for each of the non-nil children of node, followed by a
// call of f(nil).
func Inspect(node Node, f func(Node) bool) {
Walk(inspector(f), node)
}
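// Illustrative sketch (not part of the original source): using Inspect to
// print every identifier in a parsed file f (assumed).
//
//	Inspect(f, func(n Node) bool {
//		if ident, ok := n.(*Ident); ok {
//			fmt.Println(ident.Name)
//		}
//		return true // continue into children
//	})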
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package build
import (
"bytes"
"errors"
"fmt"
"go/ast"
"go/build/constraint"
"go/doc"
"go/token"
"internal/buildcfg"
"internal/godebug"
"internal/goroot"
"internal/goversion"
"io"
"io/fs"
"os"
"os/exec"
pathpkg "path"
"path/filepath"
"runtime"
"sort"
"strconv"
"strings"
"unicode"
"unicode/utf8"
)
// A Context specifies the supporting context for a build.
type Context struct {
GOARCH string // target architecture
GOOS string // target operating system
GOROOT string // Go root
GOPATH string // Go paths
// Dir is the caller's working directory, or the empty string to use
// the current directory of the running process. In module mode, this is used
// to locate the main module.
//
// If Dir is non-empty, directories passed to Import and ImportDir must
// be absolute.
Dir string
CgoEnabled bool // whether cgo files are included
UseAllFiles bool // use files regardless of go:build lines, file names
Compiler string // compiler to assume when computing target paths
// The build, tool, and release tags specify build constraints
// that should be considered satisfied when processing go:build lines.
// Clients creating a new context may customize BuildTags, which
// defaults to empty, but it is usually an error to customize ToolTags or ReleaseTags.
// ToolTags defaults to build tags appropriate to the current Go toolchain configuration.
// ReleaseTags defaults to the list of Go releases the current release is compatible with.
// BuildTags is not set for the Default build Context.
// In addition to the BuildTags, ToolTags, and ReleaseTags, build constraints
// consider the values of GOARCH and GOOS as satisfied tags.
// The last element in ReleaseTags is assumed to be the current release.
BuildTags []string
ToolTags []string
ReleaseTags []string
// The install suffix specifies a suffix to use in the name of the installation
// directory. By default it is empty, but custom builds that need to keep
// their outputs separate can set InstallSuffix to do so. For example, when
// using the race detector, the go command uses InstallSuffix = "race", so
// that on a Linux/386 system, packages are written to a directory named
// "linux_386_race" instead of the usual "linux_386".
InstallSuffix string
// By default, Import uses the operating system's file system calls
// to read directories and files. To read from other sources,
// callers can set the following functions. They all have default
// behaviors that use the local file system, so clients need only set
// the functions whose behaviors they wish to change.
// JoinPath joins the sequence of path fragments into a single path.
// If JoinPath is nil, Import uses filepath.Join.
JoinPath func(elem ...string) string
// SplitPathList splits the path list into a slice of individual paths.
// If SplitPathList is nil, Import uses filepath.SplitList.
SplitPathList func(list string) []string
// IsAbsPath reports whether path is an absolute path.
// If IsAbsPath is nil, Import uses filepath.IsAbs.
IsAbsPath func(path string) bool
// IsDir reports whether the path names a directory.
// If IsDir is nil, Import calls os.Stat and uses the result's IsDir method.
IsDir func(path string) bool
// HasSubdir reports whether dir is lexically a subdirectory of
// root, perhaps multiple levels below. It does not try to check
// whether dir exists.
// If dir is such a subdirectory, HasSubdir sets rel to a slash-separated
// path that can be joined to root to produce a path equivalent to dir.
// If HasSubdir is nil, Import uses an implementation built on
// filepath.EvalSymlinks.
HasSubdir func(root, dir string) (rel string, ok bool)
// ReadDir returns a slice of fs.FileInfo, sorted by Name,
// describing the content of the named directory.
// If ReadDir is nil, Import uses os.ReadDir.
ReadDir func(dir string) ([]fs.FileInfo, error)
// OpenFile opens a file (not a directory) for reading.
// If OpenFile is nil, Import uses os.Open.
OpenFile func(path string) (io.ReadCloser, error)
}
// joinPath calls ctxt.JoinPath (if not nil) or else filepath.Join.
func (ctxt *Context) joinPath(elem ...string) string {
if f := ctxt.JoinPath; f != nil {
return f(elem...)
}
return filepath.Join(elem...)
}
// splitPathList calls ctxt.SplitPathList (if not nil) or else filepath.SplitList.
func (ctxt *Context) splitPathList(s string) []string {
if f := ctxt.SplitPathList; f != nil {
return f(s)
}
return filepath.SplitList(s)
}
// isAbsPath calls ctxt.IsAbsPath (if not nil) or else filepath.IsAbs.
func (ctxt *Context) isAbsPath(path string) bool {
if f := ctxt.IsAbsPath; f != nil {
return f(path)
}
return filepath.IsAbs(path)
}
// isDir calls ctxt.IsDir (if not nil) or else uses os.Stat.
func (ctxt *Context) isDir(path string) bool {
if f := ctxt.IsDir; f != nil {
return f(path)
}
fi, err := os.Stat(path)
return err == nil && fi.IsDir()
}
// hasSubdir calls ctxt.HasSubdir (if not nil) or else uses
// the local file system to answer the question.
func (ctxt *Context) hasSubdir(root, dir string) (rel string, ok bool) {
if f := ctxt.HasSubdir; f != nil {
return f(root, dir)
}
// Try using paths we received.
if rel, ok = hasSubdir(root, dir); ok {
return
}
// Try expanding symlinks and comparing
// expanded against unexpanded and
// expanded against expanded.
rootSym, _ := filepath.EvalSymlinks(root)
dirSym, _ := filepath.EvalSymlinks(dir)
if rel, ok = hasSubdir(rootSym, dir); ok {
return
}
if rel, ok = hasSubdir(root, dirSym); ok {
return
}
return hasSubdir(rootSym, dirSym)
}
// hasSubdir reports whether dir is within root, using lexical analysis only.
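// For example, on a Unix system hasSubdir("/go/src", "/go/src/fmt") is
// ("fmt", true), while hasSubdir("/go/src", "/go/pkg") is ("", false).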
func hasSubdir(root, dir string) (rel string, ok bool) {
const sep = string(filepath.Separator)
root = filepath.Clean(root)
if !strings.HasSuffix(root, sep) {
root += sep
}
dir = filepath.Clean(dir)
after, found := strings.CutPrefix(dir, root)
if !found {
return "", false
}
return filepath.ToSlash(after), true
}
// readDir calls ctxt.ReadDir (if not nil) or else os.ReadDir.
func (ctxt *Context) readDir(path string) ([]fs.DirEntry, error) {
// TODO: add a fs.DirEntry version of Context.ReadDir
if f := ctxt.ReadDir; f != nil {
fis, err := f(path)
if err != nil {
return nil, err
}
des := make([]fs.DirEntry, len(fis))
for i, fi := range fis {
des[i] = fs.FileInfoToDirEntry(fi)
}
return des, nil
}
return os.ReadDir(path)
}
// openFile calls ctxt.OpenFile (if not nil) or else os.Open.
func (ctxt *Context) openFile(path string) (io.ReadCloser, error) {
if fn := ctxt.OpenFile; fn != nil {
return fn(path)
}
f, err := os.Open(path)
if err != nil {
return nil, err // nil interface
}
return f, nil
}
// isFile determines whether path is a file by trying to open it.
// It reuses openFile instead of adding another function to the
// list in Context.
func (ctxt *Context) isFile(path string) bool {
f, err := ctxt.openFile(path)
if err != nil {
return false
}
f.Close()
return true
}
// gopath returns the list of Go path directories.
func (ctxt *Context) gopath() []string {
var all []string
for _, p := range ctxt.splitPathList(ctxt.GOPATH) {
if p == "" || p == ctxt.GOROOT {
// Empty paths are uninteresting.
// If the path is the GOROOT, ignore it.
// People sometimes set GOPATH=$GOROOT.
// Do not get confused by this common mistake.
continue
}
if strings.HasPrefix(p, "~") {
// Path segments starting with ~ on Unix are almost always
// users who have incorrectly quoted ~ while setting GOPATH,
// preventing it from expanding to $HOME.
// The situation is made more confusing by the fact that
// bash allows quoted ~ in $PATH (most shells do not).
// Do not get confused by this, and do not try to use the path.
// It does not exist, and printing errors about it confuses
// those users even more, because they think "sure ~ exists!".
// The go command diagnoses this situation and prints a
// useful error.
// On Windows, ~ is used in short names, such as c:\progra~1
// for c:\program files.
continue
}
all = append(all, p)
}
return all
}
// SrcDirs returns a list of package source root directories.
// It draws from the current Go root and Go path but omits directories
// that do not exist.
func (ctxt *Context) SrcDirs() []string {
var all []string
if ctxt.GOROOT != "" && ctxt.Compiler != "gccgo" {
dir := ctxt.joinPath(ctxt.GOROOT, "src")
if ctxt.isDir(dir) {
all = append(all, dir)
}
}
for _, p := range ctxt.gopath() {
dir := ctxt.joinPath(p, "src")
if ctxt.isDir(dir) {
all = append(all, dir)
}
}
return all
}
// Default is the default Context for builds.
// It uses the GOARCH, GOOS, GOROOT, and GOPATH environment variables
// if set, or else the compiled code's GOARCH, GOOS, and GOROOT.
var Default Context = defaultContext()
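// Illustrative sketch (not part of the original source): customizing a copy
// of the Default context before importing. The "integration" build tag is
// hypothetical.
//
//	ctxt := Default // Context is a value; copying it is safe
//	ctxt.BuildTags = []string{"integration"}
//	p, err := ctxt.ImportDir(".", 0)
//	_, _ = p, err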
func defaultGOPATH() string {
env := "HOME"
if runtime.GOOS == "windows" {
env = "USERPROFILE"
} else if runtime.GOOS == "plan9" {
env = "home"
}
if home := os.Getenv(env); home != "" {
def := filepath.Join(home, "go")
if filepath.Clean(def) == filepath.Clean(runtime.GOROOT()) {
// Don't set the default GOPATH to GOROOT,
// as that will trigger warnings from the go tool.
return ""
}
return def
}
return ""
}
var defaultToolTags, defaultReleaseTags []string
func defaultContext() Context {
var c Context
c.GOARCH = buildcfg.GOARCH
c.GOOS = buildcfg.GOOS
if goroot := runtime.GOROOT(); goroot != "" {
c.GOROOT = filepath.Clean(goroot)
}
c.GOPATH = envOr("GOPATH", defaultGOPATH())
c.Compiler = runtime.Compiler
c.ToolTags = append(c.ToolTags, buildcfg.ToolTags...)
defaultToolTags = append([]string{}, c.ToolTags...) // our own private copy
// Each major Go release in the Go 1.x series adds a new
// "go1.x" release tag. That is, the go1.x tag is present in
// all releases >= Go 1.x. Code that requires Go 1.x or later
// should say "go:build go1.x", and code that should only be
// built before Go 1.x (perhaps it is the stub to use in that
// case) should say "go:build !go1.x".
// The last element in ReleaseTags is the current release.
for i := 1; i <= goversion.Version; i++ {
c.ReleaseTags = append(c.ReleaseTags, "go1."+strconv.Itoa(i))
}
defaultReleaseTags = append([]string{}, c.ReleaseTags...) // our own private copy
env := os.Getenv("CGO_ENABLED")
if env == "" {
env = defaultCGO_ENABLED
}
switch env {
case "1":
c.CgoEnabled = true
case "0":
c.CgoEnabled = false
default:
// cgo must be explicitly enabled for cross compilation builds
if runtime.GOARCH == c.GOARCH && runtime.GOOS == c.GOOS {
c.CgoEnabled = cgoEnabled[c.GOOS+"/"+c.GOARCH]
break
}
c.CgoEnabled = false
}
return c
}
func envOr(name, def string) string {
s := os.Getenv(name)
if s == "" {
return def
}
return s
}
// An ImportMode controls the behavior of the Import method.
type ImportMode uint
const (
// If FindOnly is set, Import stops after locating the directory
// that should contain the sources for a package. It does not
// read any files in the directory.
FindOnly ImportMode = 1 << iota
// If AllowBinary is set, Import can be satisfied by a compiled
// package object without corresponding sources.
//
// Deprecated:
// The supported way to create a compiled-only package is to
// write source code containing a //go:binary-only-package comment at
// the top of the file. Such a package will be recognized
// regardless of this flag setting (because it has source code)
// and will have BinaryOnly set to true in the returned Package.
AllowBinary
// If ImportComment is set, parse import comments on package statements.
// Import returns an error if it finds a comment it cannot understand
// or finds conflicting comments in multiple source files.
// See golang.org/s/go14customimport for more information.
ImportComment
// By default, Import searches vendor directories
// that apply in the given source directory before searching
// the GOROOT and GOPATH roots.
// If an Import finds and returns a package using a vendor
// directory, the resulting ImportPath is the complete path
// to the package, including the path elements leading up
// to and including "vendor".
// For example, if Import("y", "x/subdir", 0) finds
// "x/vendor/y", the returned package's ImportPath is "x/vendor/y",
// not plain "y".
// See golang.org/s/go15vendor for more information.
//
// Setting IgnoreVendor ignores vendor directories.
//
// In contrast to the package's ImportPath,
// the returned package's Imports, TestImports, and XTestImports
// are always the exact import paths from the source files:
// Import makes no attempt to resolve or check those paths.
IgnoreVendor
)
// A Package describes the Go package found in a directory.
type Package struct {
Dir string // directory containing package sources
Name string // package name
ImportComment string // path in import comment on package statement
Doc string // documentation synopsis
ImportPath string // import path of package ("" if unknown)
Root string // root of Go tree where this package lives
SrcRoot string // package source root directory ("" if unknown)
PkgRoot string // package install root directory ("" if unknown)
PkgTargetRoot string // architecture dependent install root directory ("" if unknown)
BinDir string // command install directory ("" if unknown)
Goroot bool // package found in Go root
PkgObj string // installed .a file
AllTags []string // tags that can influence file selection in this directory
ConflictDir string // this directory shadows Dir in $GOPATH
BinaryOnly bool // cannot be rebuilt from source (has //go:binary-only-package comment)
// Source files
GoFiles []string // .go source files (excluding CgoFiles, TestGoFiles, XTestGoFiles)
CgoFiles []string // .go source files that import "C"
IgnoredGoFiles []string // .go source files ignored for this build (including ignored _test.go files)
InvalidGoFiles []string // .go source files with detected problems (parse error, wrong package name, and so on)
IgnoredOtherFiles []string // non-.go source files ignored for this build
CFiles []string // .c source files
CXXFiles []string // .cc, .cpp and .cxx source files
MFiles []string // .m (Objective-C) source files
HFiles []string // .h, .hh, .hpp and .hxx source files
FFiles []string // .f, .F, .for and .f90 Fortran source files
SFiles []string // .s source files
SwigFiles []string // .swig files
SwigCXXFiles []string // .swigcxx files
SysoFiles []string // .syso system object files to add to archive
// Cgo directives
CgoCFLAGS []string // Cgo CFLAGS directives
CgoCPPFLAGS []string // Cgo CPPFLAGS directives
CgoCXXFLAGS []string // Cgo CXXFLAGS directives
CgoFFLAGS []string // Cgo FFLAGS directives
CgoLDFLAGS []string // Cgo LDFLAGS directives
CgoPkgConfig []string // Cgo pkg-config directives
// Test information
TestGoFiles []string // _test.go files in package
XTestGoFiles []string // _test.go files outside package
// Go directive comments (//go:zzz...) found in source files.
Directives []Directive
TestDirectives []Directive
XTestDirectives []Directive
// Dependency information
Imports []string // import paths from GoFiles, CgoFiles
ImportPos map[string][]token.Position // line information for Imports
TestImports []string // import paths from TestGoFiles
TestImportPos map[string][]token.Position // line information for TestImports
XTestImports []string // import paths from XTestGoFiles
XTestImportPos map[string][]token.Position // line information for XTestImports
// //go:embed patterns found in Go source files
// For example, if a source file says
// //go:embed a* b.c
// then the list will contain those two strings as separate entries.
// (See package embed for more details about //go:embed.)
EmbedPatterns []string // patterns from GoFiles, CgoFiles
EmbedPatternPos map[string][]token.Position // line information for EmbedPatterns
TestEmbedPatterns []string // patterns from TestGoFiles
TestEmbedPatternPos map[string][]token.Position // line information for TestEmbedPatterns
XTestEmbedPatterns []string // patterns from XTestGoFiles
XTestEmbedPatternPos map[string][]token.Position // line information for XTestEmbedPatterns
}
// A Directive is a Go directive comment (//go:zzz...) found in a source file.
type Directive struct {
Text string // full line comment including leading slashes
Pos token.Position // position of comment
}
// IsCommand reports whether the package is considered a
// command to be installed (not just a library).
// Packages named "main" are treated as commands.
func (p *Package) IsCommand() bool {
return p.Name == "main"
}
// ImportDir is like Import but processes the Go package found in
// the named directory.
func (ctxt *Context) ImportDir(dir string, mode ImportMode) (*Package, error) {
return ctxt.Import(".", dir, mode)
}
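// Illustrative sketch (not part of the original source): locating the source
// directory of a standard package without reading its files.
//
//	p, err := Default.Import("fmt", "", FindOnly)
//	if err == nil {
//		fmt.Println(p.Dir) // directory containing the fmt sources
//	}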
// NoGoError is the error used by Import to describe a directory
// containing no buildable Go source files. (It may still contain
// test files, files hidden by build tags, and so on.)
type NoGoError struct {
Dir string
}
func (e *NoGoError) Error() string {
return "no buildable Go source files in " + e.Dir
}
// MultiplePackageError describes a directory containing
// multiple buildable Go source files for multiple packages.
type MultiplePackageError struct {
Dir string // directory containing files
Packages []string // package names found
Files []string // corresponding files: Files[i] declares package Packages[i]
}
func (e *MultiplePackageError) Error() string {
// Error string limited to two entries for compatibility.
return fmt.Sprintf("found packages %s (%s) and %s (%s) in %s", e.Packages[0], e.Files[0], e.Packages[1], e.Files[1], e.Dir)
}
func nameExt(name string) string {
i := strings.LastIndex(name, ".")
if i < 0 {
return ""
}
return name[i:]
}
var installgoroot = godebug.New("installgoroot")
// Import returns details about the Go package named by the import path,
// interpreting local import paths relative to the srcDir directory.
// If the path is a local import path naming a package that can be imported
// using a standard import path, the returned package will set p.ImportPath
// to that path.
//
// In the directory containing the package, .go, .c, .h, and .s files are
// considered part of the package except for:
//
// - .go files in package documentation
// - files starting with _ or . (likely editor temporary files)
// - files with build constraints not satisfied by the context
//
// If an error occurs, Import returns a non-nil error and a non-nil
// *Package containing partial information.
func (ctxt *Context) Import(path string, srcDir string, mode ImportMode) (*Package, error) {
p := &Package{
ImportPath: path,
}
if path == "" {
return p, fmt.Errorf("import %q: invalid import path", path)
}
var pkgtargetroot string
var pkga string
var pkgerr error
suffix := ""
if ctxt.InstallSuffix != "" {
suffix = "_" + ctxt.InstallSuffix
}
switch ctxt.Compiler {
case "gccgo":
pkgtargetroot = "pkg/gccgo_" + ctxt.GOOS + "_" + ctxt.GOARCH + suffix
case "gc":
pkgtargetroot = "pkg/" + ctxt.GOOS + "_" + ctxt.GOARCH + suffix
default:
// Save error for end of function.
pkgerr = fmt.Errorf("import %q: unknown compiler %q", path, ctxt.Compiler)
}
setPkga := func() {
switch ctxt.Compiler {
case "gccgo":
dir, elem := pathpkg.Split(p.ImportPath)
pkga = pkgtargetroot + "/" + dir + "lib" + elem + ".a"
case "gc":
pkga = pkgtargetroot + "/" + p.ImportPath + ".a"
}
}
setPkga()
binaryOnly := false
if IsLocalImport(path) {
pkga = "" // local imports have no installed path
if srcDir == "" {
return p, fmt.Errorf("import %q: import relative to unknown directory", path)
}
if !ctxt.isAbsPath(path) {
p.Dir = ctxt.joinPath(srcDir, path)
}
// p.Dir directory may or may not exist. Gather partial information first, check if it exists later.
// Determine canonical import path, if any.
// Exclude results where the import path would include /testdata/.
inTestdata := func(sub string) bool {
return strings.Contains(sub, "/testdata/") || strings.HasSuffix(sub, "/testdata") || strings.HasPrefix(sub, "testdata/") || sub == "testdata"
}
if ctxt.GOROOT != "" {
root := ctxt.joinPath(ctxt.GOROOT, "src")
if sub, ok := ctxt.hasSubdir(root, p.Dir); ok && !inTestdata(sub) {
p.Goroot = true
p.ImportPath = sub
p.Root = ctxt.GOROOT
setPkga() // p.ImportPath changed
goto Found
}
}
all := ctxt.gopath()
for i, root := range all {
rootsrc := ctxt.joinPath(root, "src")
if sub, ok := ctxt.hasSubdir(rootsrc, p.Dir); ok && !inTestdata(sub) {
// We found a potential import path for dir,
// but check that using it wouldn't find something
// else first.
if ctxt.GOROOT != "" && ctxt.Compiler != "gccgo" {
if dir := ctxt.joinPath(ctxt.GOROOT, "src", sub); ctxt.isDir(dir) {
p.ConflictDir = dir
goto Found
}
}
for _, earlyRoot := range all[:i] {
if dir := ctxt.joinPath(earlyRoot, "src", sub); ctxt.isDir(dir) {
p.ConflictDir = dir
goto Found
}
}
// sub would not name some other directory instead of this one.
// Record it.
p.ImportPath = sub
p.Root = root
setPkga() // p.ImportPath changed
goto Found
}
}
// It's okay that we didn't find a root containing dir.
// Keep going with the information we have.
} else {
if strings.HasPrefix(path, "/") {
return p, fmt.Errorf("import %q: cannot import absolute path", path)
}
if err := ctxt.importGo(p, path, srcDir, mode); err == nil {
goto Found
} else if err != errNoModules {
return p, err
}
gopath := ctxt.gopath() // needed twice below; avoid computing many times
// tried records the locations of unsuccessful package lookups
var tried struct {
vendor []string
goroot string
gopath []string
}
// Vendor directories get first chance to satisfy import.
if mode&IgnoreVendor == 0 && srcDir != "" {
searchVendor := func(root string, isGoroot bool) bool {
sub, ok := ctxt.hasSubdir(root, srcDir)
if !ok || !strings.HasPrefix(sub, "src/") || strings.Contains(sub, "/testdata/") {
return false
}
for {
vendor := ctxt.joinPath(root, sub, "vendor")
if ctxt.isDir(vendor) {
dir := ctxt.joinPath(vendor, path)
if ctxt.isDir(dir) && hasGoFiles(ctxt, dir) {
p.Dir = dir
p.ImportPath = strings.TrimPrefix(pathpkg.Join(sub, "vendor", path), "src/")
p.Goroot = isGoroot
p.Root = root
setPkga() // p.ImportPath changed
return true
}
tried.vendor = append(tried.vendor, dir)
}
i := strings.LastIndex(sub, "/")
if i < 0 {
break
}
sub = sub[:i]
}
return false
}
if ctxt.Compiler != "gccgo" && ctxt.GOROOT != "" && searchVendor(ctxt.GOROOT, true) {
goto Found
}
for _, root := range gopath {
if searchVendor(root, false) {
goto Found
}
}
}
// Determine directory from import path.
if ctxt.GOROOT != "" {
// If the package path starts with "vendor/", only search GOROOT before
// GOPATH if the importer is also within GOROOT. That way, if the user has
// vendored in a package that is subsequently included in the standard
// distribution, they'll continue to pick up their own vendored copy.
gorootFirst := srcDir == "" || !strings.HasPrefix(path, "vendor/")
if !gorootFirst {
_, gorootFirst = ctxt.hasSubdir(ctxt.GOROOT, srcDir)
}
if gorootFirst {
dir := ctxt.joinPath(ctxt.GOROOT, "src", path)
if ctxt.Compiler != "gccgo" {
isDir := ctxt.isDir(dir)
binaryOnly = !isDir && mode&AllowBinary != 0 && pkga != "" && ctxt.isFile(ctxt.joinPath(ctxt.GOROOT, pkga))
if isDir || binaryOnly {
p.Dir = dir
p.Goroot = true
p.Root = ctxt.GOROOT
goto Found
}
}
tried.goroot = dir
}
if ctxt.Compiler == "gccgo" && goroot.IsStandardPackage(ctxt.GOROOT, ctxt.Compiler, path) {
// TODO(bcmills): Setting p.Dir here is misleading, because gccgo
// doesn't actually load its standard-library packages from this
// directory. See if we can leave it unset.
p.Dir = ctxt.joinPath(ctxt.GOROOT, "src", path)
p.Goroot = true
p.Root = ctxt.GOROOT
goto Found
}
}
for _, root := range gopath {
dir := ctxt.joinPath(root, "src", path)
isDir := ctxt.isDir(dir)
binaryOnly = !isDir && mode&AllowBinary != 0 && pkga != "" && ctxt.isFile(ctxt.joinPath(root, pkga))
if isDir || binaryOnly {
p.Dir = dir
p.Root = root
goto Found
}
tried.gopath = append(tried.gopath, dir)
}
// If we tried GOPATH first due to a "vendor/" prefix, fall back to GOROOT.
// That way, the user can still get useful results from 'go list' for
// standard-vendored paths passed on the command line.
if ctxt.GOROOT != "" && tried.goroot == "" {
dir := ctxt.joinPath(ctxt.GOROOT, "src", path)
if ctxt.Compiler != "gccgo" {
isDir := ctxt.isDir(dir)
binaryOnly = !isDir && mode&AllowBinary != 0 && pkga != "" && ctxt.isFile(ctxt.joinPath(ctxt.GOROOT, pkga))
if isDir || binaryOnly {
p.Dir = dir
p.Goroot = true
p.Root = ctxt.GOROOT
goto Found
}
}
tried.goroot = dir
}
// package was not found
var paths []string
format := "\t%s (vendor tree)"
for _, dir := range tried.vendor {
paths = append(paths, fmt.Sprintf(format, dir))
format = "\t%s"
}
if tried.goroot != "" {
paths = append(paths, fmt.Sprintf("\t%s (from $GOROOT)", tried.goroot))
} else {
paths = append(paths, "\t($GOROOT not set)")
}
format = "\t%s (from $GOPATH)"
for _, dir := range tried.gopath {
paths = append(paths, fmt.Sprintf(format, dir))
format = "\t%s"
}
if len(tried.gopath) == 0 {
paths = append(paths, "\t($GOPATH not set. For more details see: 'go help gopath')")
}
return p, fmt.Errorf("cannot find package %q in any of:\n%s", path, strings.Join(paths, "\n"))
}
Found:
if p.Root != "" {
p.SrcRoot = ctxt.joinPath(p.Root, "src")
p.PkgRoot = ctxt.joinPath(p.Root, "pkg")
p.BinDir = ctxt.joinPath(p.Root, "bin")
if pkga != "" {
// Always set PkgTargetRoot. It might be used when building in shared
// mode.
p.PkgTargetRoot = ctxt.joinPath(p.Root, pkgtargetroot)
// Set the install target if applicable.
if !p.Goroot || (installgoroot.Value() == "all" && p.ImportPath != "unsafe" && p.ImportPath != "builtin") {
if p.Goroot {
installgoroot.IncNonDefault()
}
p.PkgObj = ctxt.joinPath(p.Root, pkga)
}
}
}
// If it's a local import path, by the time we get here, we still haven't checked
// that p.Dir directory exists. This is the right time to do that check.
// We can't do it earlier, because we want to gather partial information for the
// non-nil *Package returned when an error occurs.
// We need to do this before we return early on FindOnly flag.
if IsLocalImport(path) && !ctxt.isDir(p.Dir) {
if ctxt.Compiler == "gccgo" && p.Goroot {
// gccgo has no sources for GOROOT packages.
return p, nil
}
// package was not found
return p, fmt.Errorf("cannot find package %q in:\n\t%s", p.ImportPath, p.Dir)
}
if mode&FindOnly != 0 {
return p, pkgerr
}
if binaryOnly && (mode&AllowBinary) != 0 {
return p, pkgerr
}
if ctxt.Compiler == "gccgo" && p.Goroot {
// gccgo has no sources for GOROOT packages.
return p, nil
}
dirs, err := ctxt.readDir(p.Dir)
if err != nil {
return p, err
}
var badGoError error
badGoFiles := make(map[string]bool)
badGoFile := func(name string, err error) {
if badGoError == nil {
badGoError = err
}
if !badGoFiles[name] {
p.InvalidGoFiles = append(p.InvalidGoFiles, name)
badGoFiles[name] = true
}
}
var Sfiles []string // files with ".S" (capital S) or ".sx" (the capital-S equivalent for case-insensitive filesystems)
var firstFile, firstCommentFile string
embedPos := make(map[string][]token.Position)
testEmbedPos := make(map[string][]token.Position)
xTestEmbedPos := make(map[string][]token.Position)
importPos := make(map[string][]token.Position)
testImportPos := make(map[string][]token.Position)
xTestImportPos := make(map[string][]token.Position)
allTags := make(map[string]bool)
fset := token.NewFileSet()
for _, d := range dirs {
if d.IsDir() {
continue
}
if d.Type() == fs.ModeSymlink {
if ctxt.isDir(ctxt.joinPath(p.Dir, d.Name())) {
// Symlinks to directories are not source files.
continue
}
}
name := d.Name()
ext := nameExt(name)
info, err := ctxt.matchFile(p.Dir, name, allTags, &p.BinaryOnly, fset)
if err != nil && strings.HasSuffix(name, ".go") {
badGoFile(name, err)
continue
}
if info == nil {
if strings.HasPrefix(name, "_") || strings.HasPrefix(name, ".") {
// not due to build constraints - don't report
} else if ext == ".go" {
p.IgnoredGoFiles = append(p.IgnoredGoFiles, name)
} else if fileListForExt(p, ext) != nil {
p.IgnoredOtherFiles = append(p.IgnoredOtherFiles, name)
}
continue
}
// Going to save the file. For non-Go files, can stop here.
switch ext {
case ".go":
// keep going
case ".S", ".sx":
// special case for cgo, handled at end
Sfiles = append(Sfiles, name)
continue
default:
if list := fileListForExt(p, ext); list != nil {
*list = append(*list, name)
}
continue
}
data, filename := info.header, info.name
if info.parseErr != nil {
badGoFile(name, info.parseErr)
// Fall through: we might still have a partial AST in info.parsed,
// and we want to list files with parse errors anyway.
}
var pkg string
if info.parsed != nil {
pkg = info.parsed.Name.Name
if pkg == "documentation" {
p.IgnoredGoFiles = append(p.IgnoredGoFiles, name)
continue
}
}
isTest := strings.HasSuffix(name, "_test.go")
isXTest := false
if isTest && strings.HasSuffix(pkg, "_test") && p.Name != pkg {
isXTest = true
pkg = pkg[:len(pkg)-len("_test")]
}
if p.Name == "" {
p.Name = pkg
firstFile = name
} else if pkg != p.Name {
// TODO(#45999): The choice of p.Name is arbitrary based on file iteration
// order. Instead of resolving p.Name arbitrarily, we should clear out the
// existing name and mark the existing files as also invalid.
badGoFile(name, &MultiplePackageError{
Dir: p.Dir,
Packages: []string{p.Name, pkg},
Files: []string{firstFile, name},
})
}
// Grab the first package comment as docs, provided it is not from a test file.
if info.parsed != nil && info.parsed.Doc != nil && p.Doc == "" && !isTest && !isXTest {
p.Doc = doc.Synopsis(info.parsed.Doc.Text())
}
if mode&ImportComment != 0 {
qcom, line := findImportComment(data)
if line != 0 {
com, err := strconv.Unquote(qcom)
if err != nil {
badGoFile(name, fmt.Errorf("%s:%d: cannot parse import comment", filename, line))
} else if p.ImportComment == "" {
p.ImportComment = com
firstCommentFile = name
} else if p.ImportComment != com {
badGoFile(name, fmt.Errorf("found import comments %q (%s) and %q (%s) in %s", p.ImportComment, firstCommentFile, com, name, p.Dir))
}
}
}
// Record imports and information about cgo.
isCgo := false
for _, imp := range info.imports {
if imp.path == "C" {
if isTest {
badGoFile(name, fmt.Errorf("use of cgo in test %s not supported", filename))
continue
}
isCgo = true
if imp.doc != nil {
if err := ctxt.saveCgo(filename, p, imp.doc); err != nil {
badGoFile(name, err)
}
}
}
}
var fileList *[]string
var importMap, embedMap map[string][]token.Position
var directives *[]Directive
switch {
case isCgo:
allTags["cgo"] = true
if ctxt.CgoEnabled {
fileList = &p.CgoFiles
importMap = importPos
embedMap = embedPos
directives = &p.Directives
} else {
// Ignore imports and embeds from cgo files if cgo is disabled.
fileList = &p.IgnoredGoFiles
}
case isXTest:
fileList = &p.XTestGoFiles
importMap = xTestImportPos
embedMap = xTestEmbedPos
directives = &p.XTestDirectives
case isTest:
fileList = &p.TestGoFiles
importMap = testImportPos
embedMap = testEmbedPos
directives = &p.TestDirectives
default:
fileList = &p.GoFiles
importMap = importPos
embedMap = embedPos
directives = &p.Directives
}
*fileList = append(*fileList, name)
if importMap != nil {
for _, imp := range info.imports {
importMap[imp.path] = append(importMap[imp.path], fset.Position(imp.pos))
}
}
if embedMap != nil {
for _, emb := range info.embeds {
embedMap[emb.pattern] = append(embedMap[emb.pattern], emb.pos)
}
}
if directives != nil {
*directives = append(*directives, info.directives...)
}
}
for tag := range allTags {
p.AllTags = append(p.AllTags, tag)
}
sort.Strings(p.AllTags)
p.EmbedPatterns, p.EmbedPatternPos = cleanDecls(embedPos)
p.TestEmbedPatterns, p.TestEmbedPatternPos = cleanDecls(testEmbedPos)
p.XTestEmbedPatterns, p.XTestEmbedPatternPos = cleanDecls(xTestEmbedPos)
p.Imports, p.ImportPos = cleanDecls(importPos)
p.TestImports, p.TestImportPos = cleanDecls(testImportPos)
p.XTestImports, p.XTestImportPos = cleanDecls(xTestImportPos)
// add the .S/.sx files only if we are using cgo
// (which means gcc will compile them).
// The standard assemblers expect .s files.
if len(p.CgoFiles) > 0 {
p.SFiles = append(p.SFiles, Sfiles...)
sort.Strings(p.SFiles)
} else {
p.IgnoredOtherFiles = append(p.IgnoredOtherFiles, Sfiles...)
sort.Strings(p.IgnoredOtherFiles)
}
if badGoError != nil {
return p, badGoError
}
if len(p.GoFiles)+len(p.CgoFiles)+len(p.TestGoFiles)+len(p.XTestGoFiles) == 0 {
return p, &NoGoError{p.Dir}
}
return p, pkgerr
}
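// For illustration, a minimal sketch of calling Import through the
// package-level shorthand; the import path "fmt" and the FindOnly mode are
// arbitrary choices, not part of the original source:
//
//	package main
//
//	import (
//		"fmt"
//		"go/build"
//	)
//
//	func main() {
//		// Locate the fmt package without scanning its source files.
//		pkg, err := build.Import("fmt", "", build.FindOnly)
//		if err != nil {
//			fmt.Println("import failed:", err)
//			return
//		}
//		fmt.Println(pkg.ImportPath, "->", pkg.Dir) // e.g. fmt -> $GOROOT/src/fmt
//	}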
func fileListForExt(p *Package, ext string) *[]string {
switch ext {
case ".c":
return &p.CFiles
case ".cc", ".cpp", ".cxx":
return &p.CXXFiles
case ".m":
return &p.MFiles
case ".h", ".hh", ".hpp", ".hxx":
return &p.HFiles
case ".f", ".F", ".for", ".f90":
return &p.FFiles
case ".s", ".S", ".sx":
return &p.SFiles
case ".swig":
return &p.SwigFiles
case ".swigcxx":
return &p.SwigCXXFiles
case ".syso":
return &p.SysoFiles
}
return nil
}
func uniq(list []string) []string {
if list == nil {
return nil
}
out := make([]string, len(list))
copy(out, list)
sort.Strings(out)
uniq := out[:0]
for _, x := range out {
if len(uniq) == 0 || uniq[len(uniq)-1] != x {
uniq = append(uniq, x)
}
}
return uniq
}
var errNoModules = errors.New("not using modules")
// importGo checks whether it can use the go command to find the directory for path.
// If using the go command is not appropriate, importGo returns errNoModules.
// Otherwise, importGo tries using the go command and reports whether that succeeded.
// Using the go command lets build.Import and build.Context.Import find code
// in Go modules. In the long term we want tools to use go/packages (currently golang.org/x/tools/go/packages),
// which will also use the go command.
// Invoking the go command here is not very efficient in that it computes information
// about the requested package and all dependencies and then only reports about the requested package.
// Then we reinvoke it for every dependency. But this is still better than not working at all.
// See golang.org/issue/26504.
func (ctxt *Context) importGo(p *Package, path, srcDir string, mode ImportMode) error {
// To invoke the go command,
// we must not be doing special things like AllowBinary or IgnoreVendor,
// and all the file system callbacks must be nil (we're meant to use the local file system).
if mode&AllowBinary != 0 || mode&IgnoreVendor != 0 ||
ctxt.JoinPath != nil || ctxt.SplitPathList != nil || ctxt.IsAbsPath != nil || ctxt.IsDir != nil || ctxt.HasSubdir != nil || ctxt.ReadDir != nil || ctxt.OpenFile != nil || !equal(ctxt.ToolTags, defaultToolTags) || !equal(ctxt.ReleaseTags, defaultReleaseTags) {
return errNoModules
}
// If ctxt.GOROOT is not set, we don't know which go command to invoke,
// and even if we did we might return packages in GOROOT that we wouldn't otherwise find
// (because we don't know to search in 'go env GOROOT' otherwise).
if ctxt.GOROOT == "" {
return errNoModules
}
// Predict whether module aware mode is enabled by checking the value of
// GO111MODULE and looking for a go.mod file in the source directory or
// one of its parents. Running 'go env GOMOD' in the source directory would
// give a canonical answer, but we'd prefer not to execute another command.
go111Module := os.Getenv("GO111MODULE")
switch go111Module {
case "off":
return errNoModules
default: // "", "on", "auto", anything else
// Maybe use modules.
}
if srcDir != "" {
var absSrcDir string
if filepath.IsAbs(srcDir) {
absSrcDir = srcDir
} else if ctxt.Dir != "" {
return fmt.Errorf("go/build: Dir is non-empty, so relative srcDir is not allowed: %v", srcDir)
} else {
// Find the absolute source directory. hasSubdir does not handle
// relative paths (and can't because the callbacks don't support this).
var err error
absSrcDir, err = filepath.Abs(srcDir)
if err != nil {
return errNoModules
}
}
// If the source directory is in GOROOT, then the in-process code works fine
// and we should keep using it. Moreover, the 'go list' approach below doesn't
// take standard-library vendoring into account and will fail.
if _, ok := ctxt.hasSubdir(filepath.Join(ctxt.GOROOT, "src"), absSrcDir); ok {
return errNoModules
}
}
// For efficiency, if path is a standard library package, let the usual lookup code handle it.
if dir := ctxt.joinPath(ctxt.GOROOT, "src", path); ctxt.isDir(dir) {
return errNoModules
}
// If GO111MODULE=auto, look to see if there is a go.mod.
// Since go1.13, it doesn't matter if we're inside GOPATH.
if go111Module == "auto" {
var (
parent string
err error
)
if ctxt.Dir == "" {
parent, err = os.Getwd()
if err != nil {
// A nonexistent working directory can't be in a module.
return errNoModules
}
} else {
parent, err = filepath.Abs(ctxt.Dir)
if err != nil {
// If the caller passed a bogus Dir explicitly, that's materially
// different from not having modules enabled.
return err
}
}
for {
if f, err := ctxt.openFile(ctxt.joinPath(parent, "go.mod")); err == nil {
buf := make([]byte, 100)
_, err := f.Read(buf)
f.Close()
if err == nil || err == io.EOF {
// go.mod exists and is readable (is a file, not a directory).
break
}
}
d := filepath.Dir(parent)
if len(d) >= len(parent) {
return errNoModules // reached top of file system, no go.mod
}
parent = d
}
}
goCmd := filepath.Join(ctxt.GOROOT, "bin", "go")
cmd := exec.Command(goCmd, "list", "-e", "-compiler="+ctxt.Compiler, "-tags="+strings.Join(ctxt.BuildTags, ","), "-installsuffix="+ctxt.InstallSuffix, "-f={{.Dir}}\n{{.ImportPath}}\n{{.Root}}\n{{.Goroot}}\n{{if .Error}}{{.Error}}{{end}}\n", "--", path)
if ctxt.Dir != "" {
cmd.Dir = ctxt.Dir
}
var stdout, stderr strings.Builder
cmd.Stdout = &stdout
cmd.Stderr = &stderr
cgo := "0"
if ctxt.CgoEnabled {
cgo = "1"
}
cmd.Env = append(cmd.Environ(),
"GOOS="+ctxt.GOOS,
"GOARCH="+ctxt.GOARCH,
"GOROOT="+ctxt.GOROOT,
"GOPATH="+ctxt.GOPATH,
"CGO_ENABLED="+cgo,
)
if err := cmd.Run(); err != nil {
return fmt.Errorf("go/build: go list %s: %v\n%s\n", path, err, stderr.String())
}
f := strings.SplitN(stdout.String(), "\n", 5)
if len(f) != 5 {
return fmt.Errorf("go/build: importGo %s: unexpected output:\n%s\n", path, stdout.String())
}
dir := f[0]
errStr := strings.TrimSpace(f[4])
if errStr != "" && dir == "" {
// If 'go list' could not locate the package (dir is empty),
// return the same error that 'go list' reported.
return errors.New(errStr)
}
// If 'go list' did locate the package, ignore the error.
// It was probably related to loading source files, and we'll
// encounter it ourselves shortly if the FindOnly flag isn't set.
p.Dir = dir
p.ImportPath = f[1]
p.Root = f[2]
p.Goroot = f[3] == "true"
return nil
}
func equal(x, y []string) bool {
if len(x) != len(y) {
return false
}
for i, xi := range x {
if xi != y[i] {
return false
}
}
return true
}
// hasGoFiles reports whether dir contains any files with names ending in .go.
// For a vendor check we must exclude directories that contain no .go files.
// Otherwise it is not possible to vendor just a/b/c and still import the
// non-vendored a/b. See golang.org/issue/13832.
func hasGoFiles(ctxt *Context, dir string) bool {
ents, _ := ctxt.readDir(dir)
for _, ent := range ents {
if !ent.IsDir() && strings.HasSuffix(ent.Name(), ".go") {
return true
}
}
return false
}
func findImportComment(data []byte) (s string, line int) {
// expect keyword package
word, data := parseWord(data)
if string(word) != "package" {
return "", 0
}
// expect package name
_, data = parseWord(data)
// now ready for import comment, a // or /* */ comment
// beginning and ending on the current line.
for len(data) > 0 && (data[0] == ' ' || data[0] == '\t' || data[0] == '\r') {
data = data[1:]
}
var comment []byte
switch {
case bytes.HasPrefix(data, slashSlash):
comment, _, _ = bytes.Cut(data[2:], newline)
case bytes.HasPrefix(data, slashStar):
var ok bool
comment, _, ok = bytes.Cut(data[2:], starSlash)
if !ok {
// malformed comment
return "", 0
}
if bytes.Contains(comment, newline) {
return "", 0
}
}
comment = bytes.TrimSpace(comment)
// split comment into `import`, `"pkg"`
word, arg := parseWord(comment)
if string(word) != "import" {
return "", 0
}
line = 1 + bytes.Count(data[:cap(data)-cap(arg)], newline)
return strings.TrimSpace(string(arg)), line
}
var (
slashSlash = []byte("//")
slashStar = []byte("/*")
starSlash = []byte("*/")
newline = []byte("\n")
)
// skipSpaceOrComment returns data with any leading spaces or comments removed.
func skipSpaceOrComment(data []byte) []byte {
for len(data) > 0 {
switch data[0] {
case ' ', '\t', '\r', '\n':
data = data[1:]
continue
case '/':
if bytes.HasPrefix(data, slashSlash) {
i := bytes.Index(data, newline)
if i < 0 {
return nil
}
data = data[i+1:]
continue
}
if bytes.HasPrefix(data, slashStar) {
data = data[2:]
i := bytes.Index(data, starSlash)
if i < 0 {
return nil
}
data = data[i+2:]
continue
}
}
break
}
return data
}
// parseWord skips any leading spaces or comments in data
// and then parses the beginning of data as an identifier or keyword,
// returning that word and what remains after the word.
func parseWord(data []byte) (word, rest []byte) {
data = skipSpaceOrComment(data)
// Parse past leading word characters.
rest = data
for {
r, size := utf8.DecodeRune(rest)
if unicode.IsLetter(r) || '0' <= r && r <= '9' || r == '_' {
rest = rest[size:]
continue
}
break
}
word = data[:len(data)-len(rest)]
if len(word) == 0 {
return nil, nil
}
return word, rest
}
// MatchFile reports whether the file with the given name in the given directory
// matches the context and would be included in a Package created by ImportDir
// of that directory.
//
// MatchFile considers the name of the file and may use ctxt.OpenFile to
// read some or all of the file's content.
func (ctxt *Context) MatchFile(dir, name string) (match bool, err error) {
info, err := ctxt.matchFile(dir, name, nil, nil, nil)
return info != nil, err
}
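// For illustration, a sketch of MatchFile with a hypothetical directory and
// file name; because the name carries a GOOS suffix, the answer comes from
// the file name alone and the file need not exist:
//
//	package main
//
//	import (
//		"fmt"
//		"go/build"
//	)
//
//	func main() {
//		ctxt := build.Default
//		ctxt.GOOS = "linux" // evaluate constraints as if building for linux
//		match, err := ctxt.MatchFile("/some/pkg", "conn_windows.go")
//		fmt.Println(match, err) // false <nil>: the _windows suffix rules it out
//	}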
var dummyPkg Package
// fileInfo records information learned about a file included in a build.
type fileInfo struct {
name string // full name including dir
header []byte
fset *token.FileSet
parsed *ast.File
parseErr error
imports []fileImport
embeds []fileEmbed
directives []Directive
}
type fileImport struct {
path string
pos token.Pos
doc *ast.CommentGroup
}
type fileEmbed struct {
pattern string
pos token.Position
}
// matchFile determines whether the file with the given name in the given directory
// should be included in the package being constructed.
// If the file should be included, matchFile returns a non-nil *fileInfo (and a nil error).
// Non-nil errors are reserved for unexpected problems.
//
// If name denotes a Go program, matchFile reads until the end of the
// imports and returns that section of the file in the fileInfo's header field,
// even though it only considers text up to the first non-comment line
// when looking for //go:build constraints.
//
// If allTags is non-nil, matchFile records any encountered build tag
// by setting allTags[tag] = true.
func (ctxt *Context) matchFile(dir, name string, allTags map[string]bool, binaryOnly *bool, fset *token.FileSet) (*fileInfo, error) {
if strings.HasPrefix(name, "_") ||
strings.HasPrefix(name, ".") {
return nil, nil
}
i := strings.LastIndex(name, ".")
if i < 0 {
i = len(name)
}
ext := name[i:]
if ext != ".go" && fileListForExt(&dummyPkg, ext) == nil {
// skip
return nil, nil
}
if !ctxt.goodOSArchFile(name, allTags) && !ctxt.UseAllFiles {
return nil, nil
}
info := &fileInfo{name: ctxt.joinPath(dir, name), fset: fset}
if ext == ".syso" {
// binary, no reading
return info, nil
}
f, err := ctxt.openFile(info.name)
if err != nil {
return nil, err
}
if strings.HasSuffix(name, ".go") {
err = readGoInfo(f, info)
if strings.HasSuffix(name, "_test.go") {
binaryOnly = nil // ignore //go:binary-only-package comments in _test.go files
}
} else {
binaryOnly = nil // ignore //go:binary-only-package comments in non-Go sources
info.header, err = readComments(f)
}
f.Close()
if err != nil {
return info, fmt.Errorf("read %s: %v", info.name, err)
}
// Look for go:build comments to accept or reject the file.
ok, sawBinaryOnly, err := ctxt.shouldBuild(info.header, allTags)
if err != nil {
return nil, fmt.Errorf("%s: %v", name, err)
}
if !ok && !ctxt.UseAllFiles {
return nil, nil
}
if binaryOnly != nil && sawBinaryOnly {
*binaryOnly = true
}
return info, nil
}
func cleanDecls(m map[string][]token.Position) ([]string, map[string][]token.Position) {
all := make([]string, 0, len(m))
for path := range m {
all = append(all, path)
}
sort.Strings(all)
return all, m
}
// Import is shorthand for Default.Import.
func Import(path, srcDir string, mode ImportMode) (*Package, error) {
return Default.Import(path, srcDir, mode)
}
// ImportDir is shorthand for Default.ImportDir.
func ImportDir(dir string, mode ImportMode) (*Package, error) {
return Default.ImportDir(dir, mode)
}
var (
plusBuild = []byte("+build")
goBuildComment = []byte("//go:build")
errMultipleGoBuild = errors.New("multiple //go:build comments")
)
func isGoBuildComment(line []byte) bool {
if !bytes.HasPrefix(line, goBuildComment) {
return false
}
line = bytes.TrimSpace(line)
rest := line[len(goBuildComment):]
return len(rest) == 0 || len(bytes.TrimSpace(rest)) < len(rest)
}
// Special comment denoting a binary-only package.
// See https://golang.org/design/2775-binary-only-packages
// for more about the design of binary-only packages.
var binaryOnlyComment = []byte("//go:binary-only-package")
// shouldBuild reports whether it is okay to use this file.
// The rule is that in the file's leading run of // comments
// and blank lines, which must be followed by a blank line
// (to avoid including a Go package clause doc comment),
// lines beginning with '//go:build' are taken as build directives.
//
// The file is accepted only if each such line lists something
// matching the file. For example:
//
// //go:build windows || linux
//
// marks the file as applicable only on Windows and Linux.
//
// For each build tag it consults, shouldBuild sets allTags[tag] = true.
//
// shouldBuild reports whether the file should be built
// and whether a //go:binary-only-package comment was found.
func (ctxt *Context) shouldBuild(content []byte, allTags map[string]bool) (shouldBuild, binaryOnly bool, err error) {
// Identify leading run of // comments and blank lines,
// which must be followed by a blank line.
// Also identify any //go:build comments.
content, goBuild, sawBinaryOnly, err := parseFileHeader(content)
if err != nil {
return false, false, err
}
// If //go:build line is present, it controls.
// Otherwise fall back to +build processing.
switch {
case goBuild != nil:
x, err := constraint.Parse(string(goBuild))
if err != nil {
return false, false, fmt.Errorf("parsing //go:build line: %v", err)
}
shouldBuild = ctxt.eval(x, allTags)
default:
shouldBuild = true
p := content
for len(p) > 0 {
line := p
if i := bytes.IndexByte(line, '\n'); i >= 0 {
line, p = line[:i], p[i+1:]
} else {
p = p[len(p):]
}
line = bytes.TrimSpace(line)
if !bytes.HasPrefix(line, slashSlash) || !bytes.Contains(line, plusBuild) {
continue
}
text := string(line)
if !constraint.IsPlusBuild(text) {
continue
}
if x, err := constraint.Parse(text); err == nil {
if !ctxt.eval(x, allTags) {
shouldBuild = false
}
}
}
}
return shouldBuild, sawBinaryOnly, nil
}
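// For illustration, the shape of a file header that shouldBuild accepts: the
// //go:build line sits in the leading run of // comments and blank lines and
// is separated from the package clause by a blank line (the package and tag
// names below are hypothetical):
//
//	//go:build linux && amd64
//
//	package fastpath
//
// Without the blank line, the comment would attach to the package clause as
// documentation and would not be treated as a build directive.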
func parseFileHeader(content []byte) (trimmed, goBuild []byte, sawBinaryOnly bool, err error) {
end := 0
p := content
ended := false // found non-blank, non-// line, so stopped accepting //go:build lines
inSlashStar := false // in /* */ comment
Lines:
for len(p) > 0 {
line := p
if i := bytes.IndexByte(line, '\n'); i >= 0 {
line, p = line[:i], p[i+1:]
} else {
p = p[len(p):]
}
line = bytes.TrimSpace(line)
if len(line) == 0 && !ended { // Blank line
// Remember position of most recent blank line.
// When we find the first non-blank, non-// line,
// this "end" position marks the latest file position
// where a //go:build line can appear.
// (It must appear _before_ a blank line before the non-blank, non-// line.
// Yes, that's confusing, which is part of why we moved to //go:build lines.)
// Note that ended==false here means that inSlashStar==false,
// since seeing a /* would have set ended==true.
end = len(content) - len(p)
continue Lines
}
if !bytes.HasPrefix(line, slashSlash) { // Not comment line
ended = true
}
if !inSlashStar && isGoBuildComment(line) {
if goBuild != nil {
return nil, nil, false, errMultipleGoBuild
}
goBuild = line
}
if !inSlashStar && bytes.Equal(line, binaryOnlyComment) {
sawBinaryOnly = true
}
Comments:
for len(line) > 0 {
if inSlashStar {
if i := bytes.Index(line, starSlash); i >= 0 {
inSlashStar = false
line = bytes.TrimSpace(line[i+len(starSlash):])
continue Comments
}
continue Lines
}
if bytes.HasPrefix(line, slashSlash) {
continue Lines
}
if bytes.HasPrefix(line, slashStar) {
inSlashStar = true
line = bytes.TrimSpace(line[len(slashStar):])
continue Comments
}
// Found non-comment text.
break Lines
}
}
return content[:end], goBuild, sawBinaryOnly, nil
}
// saveCgo saves the information from the #cgo lines in the import "C" comment.
// These lines set CFLAGS, CPPFLAGS, CXXFLAGS, FFLAGS, and LDFLAGS, as well as
// pkg-config directives, all of which affect the way cgo's C code is built.
func (ctxt *Context) saveCgo(filename string, di *Package, cg *ast.CommentGroup) error {
text := cg.Text()
for _, line := range strings.Split(text, "\n") {
orig := line
// Line is
// #cgo [GOOS/GOARCH...] LDFLAGS: stuff
//
line = strings.TrimSpace(line)
if len(line) < 5 || line[:4] != "#cgo" || (line[4] != ' ' && line[4] != '\t') {
continue
}
// Split at colon.
line, argstr, ok := strings.Cut(strings.TrimSpace(line[4:]), ":")
if !ok {
return fmt.Errorf("%s: invalid #cgo line: %s", filename, orig)
}
// Parse GOOS/GOARCH stuff.
f := strings.Fields(line)
if len(f) < 1 {
return fmt.Errorf("%s: invalid #cgo line: %s", filename, orig)
}
cond, verb := f[:len(f)-1], f[len(f)-1]
if len(cond) > 0 {
ok := false
for _, c := range cond {
if ctxt.matchAuto(c, nil) {
ok = true
break
}
}
if !ok {
continue
}
}
args, err := splitQuoted(argstr)
if err != nil {
return fmt.Errorf("%s: invalid #cgo line: %s", filename, orig)
}
for i, arg := range args {
if arg, ok = expandSrcDir(arg, di.Dir); !ok {
return fmt.Errorf("%s: malformed #cgo argument: %s", filename, arg)
}
args[i] = arg
}
switch verb {
case "CFLAGS", "CPPFLAGS", "CXXFLAGS", "FFLAGS", "LDFLAGS":
// Change relative paths to absolute.
ctxt.makePathsAbsolute(args, di.Dir)
}
switch verb {
case "CFLAGS":
di.CgoCFLAGS = append(di.CgoCFLAGS, args...)
case "CPPFLAGS":
di.CgoCPPFLAGS = append(di.CgoCPPFLAGS, args...)
case "CXXFLAGS":
di.CgoCXXFLAGS = append(di.CgoCXXFLAGS, args...)
case "FFLAGS":
di.CgoFFLAGS = append(di.CgoFFLAGS, args...)
case "LDFLAGS":
di.CgoLDFLAGS = append(di.CgoLDFLAGS, args...)
case "pkg-config":
di.CgoPkgConfig = append(di.CgoPkgConfig, args...)
default:
return fmt.Errorf("%s: invalid #cgo verb: %s", filename, orig)
}
}
return nil
}
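// For illustration, a sketch of the #cgo lines saveCgo parses, in the comment
// attached to import "C"; the flags, paths, and library names are hypothetical:
//
//	/*
//	#cgo CFLAGS: -I${SRCDIR}/include
//	#cgo linux LDFLAGS: -lrt
//	#cgo pkg-config: zlib
//	#include <zlib.h>
//	*/
//	import "C"
//
// The optional GOOS/GOARCH list ("linux" above) is matched with matchAuto,
// so a line may be limited to particular build configurations.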
// expandSrcDir expands any occurrence of ${SRCDIR}, making sure
// the result is safe for the shell.
func expandSrcDir(str string, srcdir string) (string, bool) {
// "\" delimited paths cause safeCgoName to fail
// so convert native paths with a different delimiter
// to "/" before starting (eg: on windows).
srcdir = filepath.ToSlash(srcdir)
chunks := strings.Split(str, "${SRCDIR}")
if len(chunks) < 2 {
return str, safeCgoName(str)
}
ok := true
for _, chunk := range chunks {
ok = ok && (chunk == "" || safeCgoName(chunk))
}
ok = ok && (srcdir == "" || safeCgoName(srcdir))
res := strings.Join(chunks, srcdir)
return res, ok && res != ""
}
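// For illustration, a sketch of the expansion; the argument and directory are
// hypothetical, and since expandSrcDir is unexported this runs only within
// this package:
//
//	expanded, ok := expandSrcDir("-I${SRCDIR}/include", "/home/user/proj")
//	fmt.Println(expanded, ok) // -I/home/user/proj/include true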
// makePathsAbsolute looks for compiler options that take paths and
// makes them absolute. We do this because through the 1.8 release we
// ran the compiler in the package directory, so any relative -I or -L
// options would be relative to that directory. In 1.9 we changed to
// running the compiler in the build directory, to get consistent
// build results (issue #19964). To keep builds working, we change any
// relative -I or -L options to be absolute.
//
// Using filepath.IsAbs and filepath.Join here means the results will be
// different on different systems, but that's OK: -I and -L options are
// inherently system-dependent.
func (ctxt *Context) makePathsAbsolute(args []string, srcDir string) {
nextPath := false
for i, arg := range args {
if nextPath {
if !filepath.IsAbs(arg) {
args[i] = filepath.Join(srcDir, arg)
}
nextPath = false
} else if strings.HasPrefix(arg, "-I") || strings.HasPrefix(arg, "-L") {
if len(arg) == 2 {
nextPath = true
} else {
if !filepath.IsAbs(arg[2:]) {
args[i] = arg[:2] + filepath.Join(srcDir, arg[2:])
}
}
}
}
}
// NOTE: $ is not safe for the shell, but it is allowed here because of linker options like -Wl,$ORIGIN.
// We never pass these arguments to a shell (just to programs we construct argv for), so this should be okay.
// See golang.org/issue/6038.
// The @ is for OS X. See golang.org/issue/13720.
// The % is for Jenkins. See golang.org/issue/16959.
// The ! is because module paths may use it. See golang.org/issue/26716.
// The ~ and ^ are for sr.ht. See golang.org/issue/32260.
const safeString = "+-.,/0123456789=ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz:$@%! ~^"
func safeCgoName(s string) bool {
if s == "" {
return false
}
for i := 0; i < len(s); i++ {
if c := s[i]; c < utf8.RuneSelf && strings.IndexByte(safeString, c) < 0 {
return false
}
}
return true
}
// splitQuoted splits the string s around each instance of one or more consecutive
// white space characters while taking into account quotes and escaping, and
// returns a slice of the substrings of s, or an empty slice if s contains only white space.
// Single quotes and double quotes are recognized to prevent splitting within the
// quoted region, and are removed from the resulting substrings. If a quote in s
// isn't closed err will be set and r will have the unclosed argument as the
// last element. The backslash is used for escaping.
//
// For example, the following string:
//
// a b:"c d" 'e''f' "g\""
//
// Would be parsed as:
//
// []string{"a", "b:c d", "ef", `g"`}
func splitQuoted(s string) (r []string, err error) {
var args []string
arg := make([]rune, len(s))
escaped := false
quoted := false
quote := '\x00'
i := 0
for _, rune := range s {
switch {
case escaped:
escaped = false
case rune == '\\':
escaped = true
continue
case quote != '\x00':
if rune == quote {
quote = '\x00'
continue
}
case rune == '"' || rune == '\'':
quoted = true
quote = rune
continue
case unicode.IsSpace(rune):
if quoted || i > 0 {
quoted = false
args = append(args, string(arg[:i]))
i = 0
}
continue
}
arg[i] = rune
i++
}
if quoted || i > 0 {
args = append(args, string(arg[:i]))
}
if quote != 0 {
err = errors.New("unclosed quote")
} else if escaped {
err = errors.New("unfinished escaping")
}
return args, err
}
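// For illustration, a sketch reproducing the example from the doc comment
// above (splitQuoted is unexported, so this runs only within this package):
//
//	args, err := splitQuoted(`a b:"c d" 'e''f' "g\""`)
//	fmt.Println(args, err) // [a b:c d ef g"] <nil>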
// matchAuto interprets text as either a +build or //go:build expression (whichever works),
// reporting whether the expression matches the build context.
//
// matchAuto is only used for testing of tag evaluation
// and in #cgo lines, which accept either syntax.
func (ctxt *Context) matchAuto(text string, allTags map[string]bool) bool {
if strings.ContainsAny(text, "&|()") {
text = "//go:build " + text
} else {
text = "// +build " + text
}
x, err := constraint.Parse(text)
if err != nil {
return false
}
return ctxt.eval(x, allTags)
}
func (ctxt *Context) eval(x constraint.Expr, allTags map[string]bool) bool {
return x.Eval(func(tag string) bool { return ctxt.matchTag(tag, allTags) })
}
// matchTag reports whether the name is one of:
//
// cgo (if cgo is enabled)
// $GOOS
// $GOARCH
// ctxt.Compiler
// linux (if GOOS = android)
// solaris (if GOOS = illumos)
// darwin (if GOOS = ios)
// unix (if this is a Unix GOOS)
// boringcrypto (if GOEXPERIMENT=boringcrypto is enabled)
// tag (if tag is listed in ctxt.BuildTags, ctxt.ToolTags, or ctxt.ReleaseTags)
//
// It records all consulted tags in allTags.
func (ctxt *Context) matchTag(name string, allTags map[string]bool) bool {
if allTags != nil {
allTags[name] = true
}
// special tags
if ctxt.CgoEnabled && name == "cgo" {
return true
}
if name == ctxt.GOOS || name == ctxt.GOARCH || name == ctxt.Compiler {
return true
}
if ctxt.GOOS == "android" && name == "linux" {
return true
}
if ctxt.GOOS == "illumos" && name == "solaris" {
return true
}
if ctxt.GOOS == "ios" && name == "darwin" {
return true
}
if name == "unix" && unixOS[ctxt.GOOS] {
return true
}
if name == "boringcrypto" {
name = "goexperiment.boringcrypto" // boringcrypto is an old name for goexperiment.boringcrypto
}
// other tags
for _, tag := range ctxt.BuildTags {
if tag == name {
return true
}
}
for _, tag := range ctxt.ToolTags {
if tag == name {
return true
}
}
for _, tag := range ctxt.ReleaseTags {
if tag == name {
return true
}
}
return false
}
// goodOSArchFile returns false if the name contains a $GOOS or $GOARCH
// suffix which does not match the current system.
// The recognized name formats are:
//
// name_$(GOOS).*
// name_$(GOARCH).*
// name_$(GOOS)_$(GOARCH).*
// name_$(GOOS)_test.*
// name_$(GOARCH)_test.*
// name_$(GOOS)_$(GOARCH)_test.*
//
// Exceptions:
// if GOOS=android, then files with GOOS=linux are also matched.
// if GOOS=illumos, then files with GOOS=solaris are also matched.
// if GOOS=ios, then files with GOOS=darwin are also matched.
func (ctxt *Context) goodOSArchFile(name string, allTags map[string]bool) bool {
name, _, _ = strings.Cut(name, ".")
// Before Go 1.4, a file called "linux.go" would be equivalent to having a
// build tag "linux" in that file. For Go 1.4 and beyond, we require this
// auto-tagging to apply only to files with a non-empty prefix, so
// "foo_linux.go" is tagged but "linux.go" is not. This allows new operating
// systems, such as android, to arrive without breaking existing code with
// innocuous source code in "android.go". The easiest fix: cut everything
// in the name before the initial _.
i := strings.Index(name, "_")
if i < 0 {
return true
}
name = name[i:] // ignore everything before first _
l := strings.Split(name, "_")
if n := len(l); n > 0 && l[n-1] == "test" {
l = l[:n-1]
}
n := len(l)
if n >= 2 && knownOS[l[n-2]] && knownArch[l[n-1]] {
if allTags != nil {
// In case we short-circuit on l[n-1].
allTags[l[n-2]] = true
}
return ctxt.matchTag(l[n-1], allTags) && ctxt.matchTag(l[n-2], allTags)
}
if n >= 1 && (knownOS[l[n-1]] || knownArch[l[n-1]]) {
return ctxt.matchTag(l[n-1], allTags)
}
return true
}
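// For illustration, how the rule applies to some hypothetical file names,
// assuming GOOS=linux and GOARCH=amd64:
//
//	net_linux.go      // built: GOOS matches
//	net_windows.go    // skipped: GOOS differs
//	asm_linux_amd64.s // built: GOOS and GOARCH both match
//	cpu_arm64_test.go // skipped: the _test suffix is ignored, GOARCH differs
//	linux.go          // built: no _ in the name, so no implicit tag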
// ToolDir is the directory containing build tools.
var ToolDir = getToolDir()
// IsLocalImport reports whether the import path is
// a local import path, like ".", "..", "./foo", or "../foo".
func IsLocalImport(path string) bool {
return path == "." || path == ".." ||
strings.HasPrefix(path, "./") || strings.HasPrefix(path, "../")
}
// ArchChar returns "?" and an error.
// In earlier versions of Go, the returned string was used to derive
// the compiler and linker tool names, the default object file suffix,
// and the default linker output name. As of Go 1.5, those strings
// no longer vary by architecture; they are compile, link, .o, and a.out, respectively.
func ArchChar(goarch string) (string, error) {
return "?", errors.New("architecture letter no longer used")
}
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package constraint implements parsing and evaluation of build constraint lines.
// See https://golang.org/cmd/go/#hdr-Build_constraints for documentation about build constraints themselves.
//
// This package parses both the original “// +build” syntax and the “//go:build” syntax that was added in Go 1.17.
// See https://golang.org/design/draft-gobuild for details about the “//go:build” syntax.
package constraint
import (
"errors"
"strings"
"unicode"
"unicode/utf8"
)
// An Expr is a build tag constraint expression.
// The underlying concrete type is *AndExpr, *OrExpr, *NotExpr, or *TagExpr.
type Expr interface {
// String returns the string form of the expression,
// using the boolean syntax used in //go:build lines.
String() string
// Eval reports whether the expression evaluates to true.
// It calls ok(tag) as needed to find out whether a given build tag
// is satisfied by the current build configuration.
Eval(ok func(tag string) bool) bool
// The presence of an isExpr method explicitly marks the type as an Expr.
// Only implementations in this package should be used as Exprs.
isExpr()
}
// A TagExpr is an Expr for the single tag Tag.
type TagExpr struct {
Tag string // for example, “linux” or “cgo”
}
func (x *TagExpr) isExpr() {}
func (x *TagExpr) Eval(ok func(tag string) bool) bool {
return ok(x.Tag)
}
func (x *TagExpr) String() string {
return x.Tag
}
func tag(tag string) Expr { return &TagExpr{tag} }
// A NotExpr represents the expression !X (the negation of X).
type NotExpr struct {
X Expr
}
func (x *NotExpr) isExpr() {}
func (x *NotExpr) Eval(ok func(tag string) bool) bool {
return !x.X.Eval(ok)
}
func (x *NotExpr) String() string {
s := x.X.String()
switch x.X.(type) {
case *AndExpr, *OrExpr:
s = "(" + s + ")"
}
return "!" + s
}
func not(x Expr) Expr { return &NotExpr{x} }
// An AndExpr represents the expression X && Y.
type AndExpr struct {
X, Y Expr
}
func (x *AndExpr) isExpr() {}
func (x *AndExpr) Eval(ok func(tag string) bool) bool {
// Note: Eval both, to make sure ok func observes all tags.
xok := x.X.Eval(ok)
yok := x.Y.Eval(ok)
return xok && yok
}
func (x *AndExpr) String() string {
return andArg(x.X) + " && " + andArg(x.Y)
}
func andArg(x Expr) string {
s := x.String()
if _, ok := x.(*OrExpr); ok {
s = "(" + s + ")"
}
return s
}
func and(x, y Expr) Expr {
return &AndExpr{x, y}
}
// An OrExpr represents the expression X || Y.
type OrExpr struct {
X, Y Expr
}
func (x *OrExpr) isExpr() {}
func (x *OrExpr) Eval(ok func(tag string) bool) bool {
// Note: Eval both, to make sure ok func observes all tags.
xok := x.X.Eval(ok)
yok := x.Y.Eval(ok)
return xok || yok
}
func (x *OrExpr) String() string {
return orArg(x.X) + " || " + orArg(x.Y)
}
func orArg(x Expr) string {
s := x.String()
if _, ok := x.(*AndExpr); ok {
s = "(" + s + ")"
}
return s
}
func or(x, y Expr) Expr {
return &OrExpr{x, y}
}
// A SyntaxError reports a syntax error in a parsed build expression.
type SyntaxError struct {
Offset int // byte offset in input where error was detected
Err string // description of error
}
func (e *SyntaxError) Error() string {
return e.Err
}
var errNotConstraint = errors.New("not a build constraint")
// Parse parses a single build constraint line of the form “//go:build ...” or “// +build ...”
// and returns the corresponding boolean expression.
func Parse(line string) (Expr, error) {
if text, ok := splitGoBuild(line); ok {
return parseExpr(text)
}
if text, ok := splitPlusBuild(line); ok {
return parsePlusBuildExpr(text), nil
}
return nil, errNotConstraint
}
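// For illustration, a minimal sketch of Parse and Eval; the tag set below is
// an arbitrary stand-in for a real build configuration:
//
//	package main
//
//	import (
//		"fmt"
//		"go/build/constraint"
//	)
//
//	func main() {
//		expr, err := constraint.Parse("//go:build linux && !cgo")
//		if err != nil {
//			panic(err)
//		}
//		tags := map[string]bool{"linux": true}
//		ok := expr.Eval(func(tag string) bool { return tags[tag] })
//		fmt.Println(expr, ok) // linux && !cgo true
//	}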
// IsGoBuild reports whether the line of text is a “//go:build” constraint.
// It only checks the prefix of the text, not that the expression itself parses.
func IsGoBuild(line string) bool {
_, ok := splitGoBuild(line)
return ok
}
// splitGoBuild splits apart the leading //go:build prefix in line from the build expression itself.
// It returns "", false if the input is not a //go:build line or if the input contains multiple lines.
func splitGoBuild(line string) (expr string, ok bool) {
// A single trailing newline is OK; otherwise multiple lines are not.
if len(line) > 0 && line[len(line)-1] == '\n' {
line = line[:len(line)-1]
}
if strings.Contains(line, "\n") {
return "", false
}
if !strings.HasPrefix(line, "//go:build") {
return "", false
}
line = strings.TrimSpace(line)
line = line[len("//go:build"):]
// If strings.TrimSpace finds more to trim after removing the //go:build prefix,
// it means that the prefix was followed by a space, making this a //go:build line
// (as opposed to a //go:buildsomethingelse line).
// If line is empty, we had "//go:build" by itself, which also counts.
trim := strings.TrimSpace(line)
if len(line) == len(trim) && line != "" {
return "", false
}
return trim, true
}
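// For illustration, how the prefix check behaves on a few inputs:
//
//	IsGoBuild("//go:build linux")  // true
//	IsGoBuild("//go:build")        // true: the bare prefix counts
//	IsGoBuild("//go:buildfoo")     // false: no space after the prefix
//	IsGoBuild("// go:build linux") // false: space before go:build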
// An exprParser holds state for parsing a build expression.
type exprParser struct {
s string // input string
i int // next read location in s
tok string // last token read
isTag bool
pos int // position (start) of last token
}
// parseExpr parses a boolean build tag expression.
func parseExpr(text string) (x Expr, err error) {
defer func() {
if e := recover(); e != nil {
if e, ok := e.(*SyntaxError); ok {
err = e
return
}
panic(e) // unreachable unless parser has a bug
}
}()
p := &exprParser{s: text}
x = p.or()
if p.tok != "" {
panic(&SyntaxError{Offset: p.pos, Err: "unexpected token " + p.tok})
}
return x, nil
}
// or parses a sequence of || expressions.
// On entry, the next input token has not yet been lexed.
// On exit, the next input token has been lexed and is in p.tok.
func (p *exprParser) or() Expr {
x := p.and()
for p.tok == "||" {
x = or(x, p.and())
}
return x
}
// and parses a sequence of && expressions.
// On entry, the next input token has not yet been lexed.
// On exit, the next input token has been lexed and is in p.tok.
func (p *exprParser) and() Expr {
x := p.not()
for p.tok == "&&" {
x = and(x, p.not())
}
return x
}
// not parses a ! expression.
// On entry, the next input token has not yet been lexed.
// On exit, the next input token has been lexed and is in p.tok.
func (p *exprParser) not() Expr {
p.lex()
if p.tok == "!" {
p.lex()
if p.tok == "!" {
panic(&SyntaxError{Offset: p.pos, Err: "double negation not allowed"})
}
return not(p.atom())
}
return p.atom()
}
// atom parses a tag or a parenthesized expression.
// On entry, the next input token HAS been lexed.
// On exit, the next input token has been lexed and is in p.tok.
func (p *exprParser) atom() Expr {
// first token already in p.tok
if p.tok == "(" {
pos := p.pos
defer func() {
if e := recover(); e != nil {
if e, ok := e.(*SyntaxError); ok && e.Err == "unexpected end of expression" {
e.Err = "missing close paren"
}
panic(e)
}
}()
x := p.or()
if p.tok != ")" {
panic(&SyntaxError{Offset: pos, Err: "missing close paren"})
}
p.lex()
return x
}
if !p.isTag {
if p.tok == "" {
panic(&SyntaxError{Offset: p.pos, Err: "unexpected end of expression"})
}
panic(&SyntaxError{Offset: p.pos, Err: "unexpected token " + p.tok})
}
tok := p.tok
p.lex()
return tag(tok)
}
// lex finds and consumes the next token in the input stream.
// On return, p.tok is set to the token text,
// p.isTag reports whether the token was a tag,
// and p.pos records the byte offset of the start of the token in the input stream.
// If lex reaches the end of the input, p.tok is set to the empty string.
// For any other syntax error, lex panics with a SyntaxError.
func (p *exprParser) lex() {
p.isTag = false
for p.i < len(p.s) && (p.s[p.i] == ' ' || p.s[p.i] == '\t') {
p.i++
}
if p.i >= len(p.s) {
p.tok = ""
p.pos = p.i
return
}
switch p.s[p.i] {
case '(', ')', '!':
p.pos = p.i
p.i++
p.tok = p.s[p.pos:p.i]
return
case '&', '|':
if p.i+1 >= len(p.s) || p.s[p.i+1] != p.s[p.i] {
panic(&SyntaxError{Offset: p.i, Err: "invalid syntax at " + string(rune(p.s[p.i]))})
}
p.pos = p.i
p.i += 2
p.tok = p.s[p.pos:p.i]
return
}
tag := p.s[p.i:]
for i, c := range tag {
if !unicode.IsLetter(c) && !unicode.IsDigit(c) && c != '_' && c != '.' {
tag = tag[:i]
break
}
}
if tag == "" {
c, _ := utf8.DecodeRuneInString(p.s[p.i:])
panic(&SyntaxError{Offset: p.i, Err: "invalid syntax at " + string(c)})
}
p.pos = p.i
p.i += len(tag)
p.tok = p.s[p.pos:p.i]
p.isTag = true
}
// IsPlusBuild reports whether the line of text is a “// +build” constraint.
// It only checks the prefix of the text, not that the expression itself parses.
func IsPlusBuild(line string) bool {
_, ok := splitPlusBuild(line)
return ok
}
// splitPlusBuild splits apart the leading // +build prefix in line from the build expression itself.
// It returns "", false if the input is not a // +build line or if the input contains multiple lines.
func splitPlusBuild(line string) (expr string, ok bool) {
// A single trailing newline is OK; otherwise multiple lines are not.
if len(line) > 0 && line[len(line)-1] == '\n' {
line = line[:len(line)-1]
}
if strings.Contains(line, "\n") {
return "", false
}
if !strings.HasPrefix(line, "//") {
return "", false
}
line = line[len("//"):]
// Note the space is optional; "//+build" is recognized too.
line = strings.TrimSpace(line)
if !strings.HasPrefix(line, "+build") {
return "", false
}
line = line[len("+build"):]
// If strings.TrimSpace finds more to trim after removing the +build prefix,
// it means that the prefix was followed by a space, making this a +build line
// (as opposed to a +buildsomethingelse line).
// If line is empty, we had "// +build" by itself, which also counts.
trim := strings.TrimSpace(line)
if len(line) == len(trim) && line != "" {
return "", false
}
return trim, true
}
// parsePlusBuildExpr parses a legacy build tag expression (as used with “// +build”).
func parsePlusBuildExpr(text string) Expr {
var x Expr
for _, clause := range strings.Fields(text) {
var y Expr
for _, lit := range strings.Split(clause, ",") {
var z Expr
var neg bool
if strings.HasPrefix(lit, "!!") || lit == "!" {
z = tag("ignore")
} else {
if strings.HasPrefix(lit, "!") {
neg = true
lit = lit[len("!"):]
}
if isValidTag(lit) {
z = tag(lit)
} else {
z = tag("ignore")
}
if neg {
z = not(z)
}
}
if y == nil {
y = z
} else {
y = and(y, z)
}
}
if x == nil {
x = y
} else {
x = or(x, y)
}
}
if x == nil {
x = tag("ignore")
}
return x
}
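// For illustration, the legacy syntax read by parsePlusBuildExpr: within a
// clause, commas mean AND; between space-separated clauses, OR:
//
//	expr, _ := constraint.Parse("// +build linux,386 darwin")
//	fmt.Println(expr) // (linux && 386) || darwin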
// isValidTag reports whether the word is a valid build tag.
// Tags must be letters, digits, underscores or dots.
// Unlike in Go identifiers, all digits are fine (e.g., "386").
func isValidTag(word string) bool {
if word == "" {
return false
}
for _, c := range word {
if !unicode.IsLetter(c) && !unicode.IsDigit(c) && c != '_' && c != '.' {
return false
}
}
return true
}
var errComplex = errors.New("expression too complex for // +build lines")
// PlusBuildLines returns a sequence of “// +build” lines that evaluate to the build expression x.
// If the expression is too complex to convert directly to “// +build” lines, PlusBuildLines returns an error.
func PlusBuildLines(x Expr) ([]string, error) {
// Push all NOTs to the expression leaves, so that //go:build !(x && y) can be treated as !x || !y.
// This rewrite is both efficient and commonly needed, so it's worth doing.
// Essentially all other possible rewrites are too expensive and too rarely needed.
x = pushNot(x, false)
// Split into AND of ORs of ANDs of literals (tag or NOT tag).
var split [][][]Expr
for _, or := range appendSplitAnd(nil, x) {
var ands [][]Expr
for _, and := range appendSplitOr(nil, or) {
var lits []Expr
for _, lit := range appendSplitAnd(nil, and) {
switch lit.(type) {
case *TagExpr, *NotExpr:
lits = append(lits, lit)
default:
return nil, errComplex
}
}
ands = append(ands, lits)
}
split = append(split, ands)
}
// If all the ORs have length 1 (no actual OR'ing going on),
// push the top-level ANDs to the bottom level, so that we get
// one // +build line instead of many.
maxOr := 0
for _, or := range split {
if maxOr < len(or) {
maxOr = len(or)
}
}
if maxOr == 1 {
var lits []Expr
for _, or := range split {
lits = append(lits, or[0]...)
}
split = [][][]Expr{{lits}}
}
// Prepare the +build lines.
var lines []string
for _, or := range split {
line := "// +build"
for _, and := range or {
clause := ""
for i, lit := range and {
if i > 0 {
clause += ","
}
clause += lit.String()
}
line += " " + clause
}
lines = append(lines, line)
}
return lines, nil
}
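// For illustration, a sketch converting a //go:build expression to the legacy
// lines (one line per top-level AND term, space-separated OR within a line):
//
//	expr, _ := constraint.Parse("//go:build (linux || darwin) && !cgo")
//	lines, _ := constraint.PlusBuildLines(expr)
//	for _, line := range lines {
//		fmt.Println(line)
//	}
//	// Output:
//	// // +build linux darwin
//	// // +build !cgo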
// pushNot applies DeMorgan's law to push negations down the expression,
// so that only tags are negated in the result.
// (It applies the rewrites !(X && Y) => (!X || !Y) and !(X || Y) => (!X && !Y).)
func pushNot(x Expr, not bool) Expr {
switch x := x.(type) {
default:
// unreachable
return x
case *NotExpr:
if _, ok := x.X.(*TagExpr); ok && !not {
return x
}
return pushNot(x.X, !not)
case *TagExpr:
if not {
return &NotExpr{X: x}
}
return x
case *AndExpr:
x1 := pushNot(x.X, not)
y1 := pushNot(x.Y, not)
if not {
return or(x1, y1)
}
if x1 == x.X && y1 == x.Y {
return x
}
return and(x1, y1)
case *OrExpr:
x1 := pushNot(x.X, not)
y1 := pushNot(x.Y, not)
if not {
return and(x1, y1)
}
if x1 == x.X && y1 == x.Y {
return x
}
return or(x1, y1)
}
}
// appendSplitAnd appends x to list while splitting apart any top-level && expressions.
// For example, appendSplitAnd({W}, X && Y && Z) = {W, X, Y, Z}.
func appendSplitAnd(list []Expr, x Expr) []Expr {
if x, ok := x.(*AndExpr); ok {
list = appendSplitAnd(list, x.X)
list = appendSplitAnd(list, x.Y)
return list
}
return append(list, x)
}
// appendSplitOr appends x to list while splitting apart any top-level || expressions.
// For example, appendSplitOr({W}, X || Y || Z) = {W, X, Y, Z}.
func appendSplitOr(list []Expr, x Expr) []Expr {
if x, ok := x.(*OrExpr); ok {
list = appendSplitOr(list, x.X)
list = appendSplitOr(list, x.Y)
return list
}
return append(list, x)
}
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gc
package build
import (
"path/filepath"
"runtime"
)
// getToolDir returns the default value of ToolDir.
func getToolDir() string {
return filepath.Join(runtime.GOROOT(), "pkg/tool/"+runtime.GOOS+"_"+runtime.GOARCH)
}
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package build
import (
"bufio"
"bytes"
"errors"
"fmt"
"go/ast"
"go/parser"
"go/token"
"io"
"strconv"
"strings"
"unicode"
"unicode/utf8"
)
type importReader struct {
b *bufio.Reader
buf []byte
peek byte
err error
eof bool
nerr int
pos token.Position
}
var bom = []byte{0xef, 0xbb, 0xbf}
func newImportReader(name string, r io.Reader) *importReader {
b := bufio.NewReader(r)
// Remove leading UTF-8 BOM.
// Per https://golang.org/ref/spec#Source_code_representation:
// a compiler may ignore a UTF-8-encoded byte order mark (U+FEFF)
// if it is the first Unicode code point in the source text.
if leadingBytes, err := b.Peek(3); err == nil && bytes.Equal(leadingBytes, bom) {
b.Discard(3)
}
return &importReader{
b: b,
pos: token.Position{
Filename: name,
Line: 1,
Column: 1,
},
}
}
func isIdent(c byte) bool {
return 'A' <= c && c <= 'Z' || 'a' <= c && c <= 'z' || '0' <= c && c <= '9' || c == '_' || c >= utf8.RuneSelf
}
var (
errSyntax = errors.New("syntax error")
errNUL = errors.New("unexpected NUL in input")
)
// syntaxError records a syntax error, but only if an I/O error has not already been recorded.
func (r *importReader) syntaxError() {
if r.err == nil {
r.err = errSyntax
}
}
// readByte reads the next byte from the input, saves it in buf, and returns it.
// If an error occurs, readByte records the error in r.err and returns 0.
func (r *importReader) readByte() byte {
c, err := r.b.ReadByte()
if err == nil {
r.buf = append(r.buf, c)
if c == 0 {
err = errNUL
}
}
if err != nil {
if err == io.EOF {
r.eof = true
} else if r.err == nil {
r.err = err
}
c = 0
}
return c
}
// readByteNoBuf is like readByte but doesn't buffer the byte.
// It exhausts r.buf before reading from r.b.
func (r *importReader) readByteNoBuf() byte {
var c byte
var err error
if len(r.buf) > 0 {
c = r.buf[0]
r.buf = r.buf[1:]
} else {
c, err = r.b.ReadByte()
if err == nil && c == 0 {
err = errNUL
}
}
if err != nil {
if err == io.EOF {
r.eof = true
} else if r.err == nil {
r.err = err
}
return 0
}
r.pos.Offset++
if c == '\n' {
r.pos.Line++
r.pos.Column = 1
} else {
r.pos.Column++
}
return c
}
// peekByte returns the next byte from the input reader but does not advance beyond it.
// If skipSpace is set, peekByte skips leading spaces and comments.
func (r *importReader) peekByte(skipSpace bool) byte {
if r.err != nil {
if r.nerr++; r.nerr > 10000 {
panic("go/build: import reader looping")
}
return 0
}
// Use r.peek as first input byte.
// Don't just return r.peek here: it might have been left by peekByte(false)
// and this might be peekByte(true).
c := r.peek
if c == 0 {
c = r.readByte()
}
for r.err == nil && !r.eof {
if skipSpace {
// For the purposes of this reader, semicolons are never necessary to
// understand the input and are treated as spaces.
switch c {
case ' ', '\f', '\t', '\r', '\n', ';':
c = r.readByte()
continue
case '/':
c = r.readByte()
if c == '/' {
for c != '\n' && r.err == nil && !r.eof {
c = r.readByte()
}
} else if c == '*' {
var c1 byte
for (c != '*' || c1 != '/') && r.err == nil {
if r.eof {
r.syntaxError()
}
c, c1 = c1, r.readByte()
}
} else {
r.syntaxError()
}
c = r.readByte()
continue
}
}
break
}
r.peek = c
return r.peek
}
// nextByte is like peekByte but advances beyond the returned byte.
func (r *importReader) nextByte(skipSpace bool) byte {
c := r.peekByte(skipSpace)
r.peek = 0
return c
}
var goEmbed = []byte("go:embed")
// findEmbed advances the input reader to the next //go:embed comment.
// It reports whether it found a comment.
// (Otherwise it found an error or EOF.)
func (r *importReader) findEmbed(first bool) bool {
// The import block scan stopped after a non-space character,
// so the reader is not at the start of a line on the first call.
// After that, each //go:embed extraction leaves the reader
// at the end of a line.
startLine := !first
var c byte
for r.err == nil && !r.eof {
c = r.readByteNoBuf()
Reswitch:
switch c {
default:
startLine = false
case '\n':
startLine = true
case ' ', '\t':
// leave startLine alone
case '"':
startLine = false
for r.err == nil {
if r.eof {
r.syntaxError()
}
c = r.readByteNoBuf()
if c == '\\' {
r.readByteNoBuf()
if r.err != nil {
r.syntaxError()
return false
}
continue
}
if c == '"' {
c = r.readByteNoBuf()
goto Reswitch
}
}
goto Reswitch
case '`':
startLine = false
for r.err == nil {
if r.eof {
r.syntaxError()
}
c = r.readByteNoBuf()
if c == '`' {
c = r.readByteNoBuf()
goto Reswitch
}
}
case '\'':
startLine = false
for r.err == nil {
if r.eof {
r.syntaxError()
}
c = r.readByteNoBuf()
if c == '\\' {
r.readByteNoBuf()
if r.err != nil {
r.syntaxError()
return false
}
continue
}
if c == '\'' {
c = r.readByteNoBuf()
goto Reswitch
}
}
case '/':
c = r.readByteNoBuf()
switch c {
default:
startLine = false
goto Reswitch
case '*':
var c1 byte
for (c != '*' || c1 != '/') && r.err == nil {
if r.eof {
r.syntaxError()
}
c, c1 = c1, r.readByteNoBuf()
}
startLine = false
case '/':
if startLine {
// Try to read this as a //go:embed comment.
for i := range goEmbed {
c = r.readByteNoBuf()
if c != goEmbed[i] {
goto SkipSlashSlash
}
}
c = r.readByteNoBuf()
if c == ' ' || c == '\t' {
// Found one!
return true
}
}
SkipSlashSlash:
for c != '\n' && r.err == nil && !r.eof {
c = r.readByteNoBuf()
}
startLine = true
}
}
}
return false
}
// readKeyword reads the given keyword from the input.
// If the keyword is not present, readKeyword records a syntax error.
func (r *importReader) readKeyword(kw string) {
r.peekByte(true)
for i := 0; i < len(kw); i++ {
if r.nextByte(false) != kw[i] {
r.syntaxError()
return
}
}
if isIdent(r.peekByte(false)) {
r.syntaxError()
}
}
// readIdent reads an identifier from the input.
// If an identifier is not present, readIdent records a syntax error.
func (r *importReader) readIdent() {
c := r.peekByte(true)
if !isIdent(c) {
r.syntaxError()
return
}
for isIdent(r.peekByte(false)) {
r.peek = 0
}
}
// readString reads a quoted string literal from the input.
// If a string literal is not present, readString records a syntax error.
func (r *importReader) readString() {
switch r.nextByte(true) {
case '`':
for r.err == nil {
if r.nextByte(false) == '`' {
break
}
if r.eof {
r.syntaxError()
}
}
case '"':
for r.err == nil {
c := r.nextByte(false)
if c == '"' {
break
}
if r.eof || c == '\n' {
r.syntaxError()
}
if c == '\\' {
r.nextByte(false)
}
}
default:
r.syntaxError()
}
}
// readImport reads an import clause - optional identifier followed by quoted string -
// from the input.
func (r *importReader) readImport() {
c := r.peekByte(true)
if c == '.' {
r.peek = 0
} else if isIdent(c) {
r.readIdent()
}
r.readString()
}
// readComments is like io.ReadAll, except that it only reads the leading
// block of comments in the file.
func readComments(f io.Reader) ([]byte, error) {
r := newImportReader("", f)
r.peekByte(true)
if r.err == nil && !r.eof {
// Didn't reach EOF, so must have found a non-space byte. Remove it.
r.buf = r.buf[:len(r.buf)-1]
}
return r.buf, r.err
}
// readGoInfo expects a Go file as input and reads the file up to and including the import section.
// It records what it learned in *info.
// If info.fset is non-nil, readGoInfo parses the file and sets info.parsed, info.parseErr,
// info.imports and info.embeds.
//
// It only returns an error if there are problems reading the file,
// not for syntax errors in the file itself.
func readGoInfo(f io.Reader, info *fileInfo) error {
r := newImportReader(info.name, f)
r.readKeyword("package")
r.readIdent()
for r.peekByte(true) == 'i' {
r.readKeyword("import")
if r.peekByte(true) == '(' {
r.nextByte(false)
for r.peekByte(true) != ')' && r.err == nil {
r.readImport()
}
r.nextByte(false)
} else {
r.readImport()
}
}
info.header = r.buf
// If we stopped successfully before EOF, we read a byte that told us we were done.
// Return all but that last byte, which would cause a syntax error if we let it through.
if r.err == nil && !r.eof {
info.header = r.buf[:len(r.buf)-1]
}
// If we stopped for a syntax error, consume the whole file so that
// we are sure we don't change the errors that go/parser returns.
if r.err == errSyntax {
r.err = nil
for r.err == nil && !r.eof {
r.readByte()
}
info.header = r.buf
}
if r.err != nil {
return r.err
}
if info.fset == nil {
return nil
}
// Parse file header & record imports.
info.parsed, info.parseErr = parser.ParseFile(info.fset, info.name, info.header, parser.ImportsOnly|parser.ParseComments)
if info.parseErr != nil {
return nil
}
hasEmbed := false
for _, decl := range info.parsed.Decls {
d, ok := decl.(*ast.GenDecl)
if !ok {
continue
}
for _, dspec := range d.Specs {
spec, ok := dspec.(*ast.ImportSpec)
if !ok {
continue
}
quoted := spec.Path.Value
path, err := strconv.Unquote(quoted)
if err != nil {
return fmt.Errorf("parser returned invalid quoted string: <%s>", quoted)
}
if path == "embed" {
hasEmbed = true
}
doc := spec.Doc
if doc == nil && len(d.Specs) == 1 {
doc = d.Doc
}
info.imports = append(info.imports, fileImport{path, spec.Pos(), doc})
}
}
// Extract directives.
for _, group := range info.parsed.Comments {
if group.Pos() >= info.parsed.Package {
break
}
for _, c := range group.List {
if strings.HasPrefix(c.Text, "//go:") {
info.directives = append(info.directives, Directive{c.Text, info.fset.Position(c.Slash)})
}
}
}
// If the file imports "embed",
// we have to look for //go:embed comments
// in the remainder of the file.
// The compiler will enforce the mapping of comments to
// declared variables. We just need to know the patterns.
// If there were //go:embed comments earlier in the file
// (near the package statement or imports), the compiler
// will reject them. They can be (and have already been) ignored.
if hasEmbed {
var line []byte
for first := true; r.findEmbed(first); first = false {
line = line[:0]
pos := r.pos
for {
c := r.readByteNoBuf()
if c == '\n' || r.err != nil || r.eof {
break
}
line = append(line, c)
}
// Add args if line is well-formed.
// Ignore badly-formed lines - the compiler will report them when it finds them,
// and we can pretend they are not there to help go list succeed with what it knows.
embs, err := parseGoEmbed(string(line), pos)
if err == nil {
info.embeds = append(info.embeds, embs...)
}
}
}
return nil
}
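// sketchReadGoInfo is an editor's illustrative sketch, not part of the
// original file: scanning a source file that imports "embed" records both
// the import and the //go:embed pattern that follows it.
func sketchReadGoInfo() {
const src = "package p\n\nimport \"embed\"\n\n//go:embed hello.txt\nvar data embed.FS\n"
info := &fileInfo{name: "p.go", fset: token.NewFileSet()}
if err := readGoInfo(strings.NewReader(src), info); err == nil {
fmt.Println(len(info.imports), len(info.embeds)) // 1 1
}
}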
// parseGoEmbed parses the text following "//go:embed" to extract the glob patterns.
// It accepts unquoted space-separated patterns as well as double-quoted and back-quoted Go strings.
// This is based on a similar function in cmd/compile/internal/gc/noder.go;
// this version calculates position information as well.
func parseGoEmbed(args string, pos token.Position) ([]fileEmbed, error) {
trimBytes := func(n int) {
pos.Offset += n
pos.Column += utf8.RuneCountInString(args[:n])
args = args[n:]
}
trimSpace := func() {
trim := strings.TrimLeftFunc(args, unicode.IsSpace)
trimBytes(len(args) - len(trim))
}
var list []fileEmbed
for trimSpace(); args != ""; trimSpace() {
var path string
pathPos := pos
Switch:
switch args[0] {
default:
i := len(args)
for j, c := range args {
if unicode.IsSpace(c) {
i = j
break
}
}
path = args[:i]
trimBytes(i)
case '`':
var ok bool
path, _, ok = strings.Cut(args[1:], "`")
if !ok {
return nil, fmt.Errorf("invalid quoted string in //go:embed: %s", args)
}
trimBytes(1 + len(path) + 1)
case '"':
i := 1
for ; i < len(args); i++ {
if args[i] == '\\' {
i++
continue
}
if args[i] == '"' {
q, err := strconv.Unquote(args[:i+1])
if err != nil {
return nil, fmt.Errorf("invalid quoted string in //go:embed: %s", args[:i+1])
}
path = q
trimBytes(i + 1)
break Switch
}
}
if i >= len(args) {
return nil, fmt.Errorf("invalid quoted string in //go:embed: %s", args)
}
}
if args != "" {
r, _ := utf8.DecodeRuneInString(args)
if !unicode.IsSpace(r) {
return nil, fmt.Errorf("invalid quoted string in //go:embed: %s", args)
}
}
list = append(list, fileEmbed{path, pathPos})
}
return list, nil
}
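// sketchParseGoEmbed is an editor's illustrative sketch, not part of the
// original file: it shows the three pattern forms parseGoEmbed accepts
// (unquoted, double-quoted, and back-quoted) in the text that follows a
// //go:embed comment.
func sketchParseGoEmbed() {
args := `images/*.png "file with space.txt" ` + "`raw.txt`"
embs, err := parseGoEmbed(args, token.Position{Filename: "x.go", Line: 10, Column: 12})
fmt.Println(len(embs), err) // 3 <nil>: one fileEmbed per pattern
}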
// Code generated by "stringer -type Kind"; DO NOT EDIT.
package constant
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[Unknown-0]
_ = x[Bool-1]
_ = x[String-2]
_ = x[Int-3]
_ = x[Float-4]
_ = x[Complex-5]
}
const _Kind_name = "UnknownBoolStringIntFloatComplex"
var _Kind_index = [...]uint8{0, 7, 11, 17, 20, 25, 32}
func (i Kind) String() string {
if i < 0 || i >= Kind(len(_Kind_index)-1) {
return "Kind(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _Kind_name[_Kind_index[i]:_Kind_index[i+1]]
}
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package constant implements Values representing untyped
// Go constants and their corresponding operations.
//
// A special Unknown value may be used when a value
// is unknown due to an error. Operations on unknown
// values produce unknown values unless specified
// otherwise.
package constant
import (
"fmt"
"go/token"
"math"
"math/big"
"math/bits"
"strconv"
"strings"
"sync"
"unicode/utf8"
)
//go:generate stringer -type Kind
// Kind specifies the kind of value represented by a Value.
type Kind int
const (
// unknown values
Unknown Kind = iota
// non-numeric values
Bool
String
// numeric values
Int
Float
Complex
)
// A Value represents the value of a Go constant.
type Value interface {
// Kind returns the value kind.
Kind() Kind
// String returns a short, quoted (human-readable) form of the value.
// For numeric values, the result may be an approximation;
// for String values the result may be a shortened string.
// Use ExactString for a string representing a value exactly.
String() string
// ExactString returns an exact, quoted (human-readable) form of the value.
// If the Value is of Kind String, use StringVal to obtain the unquoted string.
ExactString() string
// Prevent external implementations.
implementsValue()
}
// ----------------------------------------------------------------------------
// Implementations
// Maximum supported mantissa precision.
// The spec requires at least 256 bits; typical implementations use 512 bits.
const prec = 512
// TODO(gri) Consider storing "error" information in an unknownVal so clients
// can provide better error messages. For instance, if a number is
// too large (incl. infinity), that could be recorded in unknownVal.
// See also #20583 and #42695 for use cases.
// Representation of values:
//
// Values of Int and Float Kind have two different representations each: int64Val
// and intVal, and ratVal and floatVal. When possible, the "smaller", respectively
// more precise (for Floats) representation is chosen. However, once a Float value
// is represented as a floatVal, any subsequent results remain floatVals (unless
// explicitly converted); i.e., no attempt is made to convert a floatVal back into
// a ratVal. The reasoning is that all representations but floatVal are mathematically
// exact, but once that precision is lost (by moving to floatVal), moving back to
// a different representation implies a precision that's not actually there.
type (
unknownVal struct{}
boolVal bool
stringVal struct {
// Lazy value: either a string (l,r==nil) or an addition (l,r!=nil).
mu sync.Mutex
s string
l, r *stringVal
}
int64Val int64 // Int values representable as an int64
intVal struct{ val *big.Int } // Int values not representable as an int64
ratVal struct{ val *big.Rat } // Float values representable as a fraction
floatVal struct{ val *big.Float } // Float values not representable as a fraction
complexVal struct{ re, im Value }
)
func (unknownVal) Kind() Kind { return Unknown }
func (boolVal) Kind() Kind { return Bool }
func (*stringVal) Kind() Kind { return String }
func (int64Val) Kind() Kind { return Int }
func (intVal) Kind() Kind { return Int }
func (ratVal) Kind() Kind { return Float }
func (floatVal) Kind() Kind { return Float }
func (complexVal) Kind() Kind { return Complex }
func (unknownVal) String() string { return "unknown" }
func (x boolVal) String() string { return strconv.FormatBool(bool(x)) }
// String returns a possibly shortened quoted form of the String value.
func (x *stringVal) String() string {
const maxLen = 72 // a reasonable length
s := strconv.Quote(x.string())
if utf8.RuneCountInString(s) > maxLen {
// The string without the enclosing quotes is greater than maxLen-2 runes
// long. Remove the last 3 runes (including the closing '"') by keeping
// only the first maxLen-3 runes; then add "...".
i := 0
for n := 0; n < maxLen-3; n++ {
_, size := utf8.DecodeRuneInString(s[i:])
i += size
}
s = s[:i] + "..."
}
return s
}
// string constructs and returns the actual string literal value.
// If x represents an addition, then it rewrites x to be a single
// string, to speed future calls. This lazy construction avoids
// building different string values for all subpieces of a large
// concatenation. See golang.org/issue/23348.
func (x *stringVal) string() string {
x.mu.Lock()
if x.l != nil {
x.s = strings.Join(reverse(x.appendReverse(nil)), "")
x.l = nil
x.r = nil
}
s := x.s
x.mu.Unlock()
return s
}
// reverse reverses x in place and returns it.
func reverse(x []string) []string {
n := len(x)
for i := 0; i+i < n; i++ {
x[i], x[n-1-i] = x[n-1-i], x[i]
}
return x
}
// appendReverse appends to list all of x's subpieces, but in reverse,
// and returns the result. Appending the reversal allows processing
// the right side in a recursive call and the left side in a loop.
// Because a chain like a + b + c + d + e is actually represented
// as ((((a + b) + c) + d) + e), the left-side loop avoids deep recursion.
// x must be locked.
func (x *stringVal) appendReverse(list []string) []string {
y := x
for y.r != nil {
y.r.mu.Lock()
list = y.r.appendReverse(list)
y.r.mu.Unlock()
l := y.l
if y != x {
y.mu.Unlock()
}
l.mu.Lock()
y = l
}
s := y.s
if y != x {
y.mu.Unlock()
}
return append(list, s)
}
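// sketchLazyStrings is an editor's illustrative sketch, not part of the
// original file: BinaryOp on String values builds the lazy tree described
// above, and StringVal flattens it exactly once, on first use.
func sketchLazyStrings() {
s := MakeString("a")
for _, part := range []string{"b", "c", "d"} {
s = BinaryOp(s, token.ADD, MakeString(part)) // builds ((a+b)+c)+d without copying
}
fmt.Println(StringVal(s)) // abcd: the concatenation happens here, not above
}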
func (x int64Val) String() string { return strconv.FormatInt(int64(x), 10) }
func (x intVal) String() string { return x.val.String() }
func (x ratVal) String() string { return rtof(x).String() }
// String returns a decimal approximation of the Float value.
func (x floatVal) String() string {
f := x.val
// Don't try to convert infinities (will not terminate).
if f.IsInf() {
return f.String()
}
// Use exact fmt formatting if in float64 range (common case):
// proceed if f doesn't underflow to 0 or overflow to inf.
if x, _ := f.Float64(); f.Sign() == 0 == (x == 0) && !math.IsInf(x, 0) {
s := fmt.Sprintf("%.6g", x)
if !f.IsInt() && strings.IndexByte(s, '.') < 0 {
// f is not an integer, but its string representation
// doesn't reflect that. Use more digits. See issue 56220.
s = fmt.Sprintf("%g", x)
}
return s
}
// Out of float64 range. Do approximate manual to decimal
// conversion to avoid precise but possibly slow Float
// formatting.
// f = mant * 2**exp
var mant big.Float
exp := f.MantExp(&mant) // 0.5 <= |mant| < 1.0
// approximate float64 mantissa m and decimal exponent d
// f ~ m * 10**d
m, _ := mant.Float64() // 0.5 <= |m| < 1.0
d := float64(exp) * (math.Ln2 / math.Ln10) // log_10(2)
// adjust m for truncated (integer) decimal exponent e
e := int64(d)
m *= math.Pow(10, d-float64(e))
// ensure 1 <= |m| < 10
switch am := math.Abs(m); {
case am < 1-0.5e-6:
// The %.6g format below rounds m to 5 digits after the
// decimal point. Make sure that m*10 < 10 even after
// rounding up: m*10 + 0.5e-5 < 10 => m < 1 - 0.5e-6.
m *= 10
e--
case am >= 10:
m /= 10
e++
}
return fmt.Sprintf("%.6ge%+d", m, e)
}
func (x complexVal) String() string { return fmt.Sprintf("(%s + %si)", x.re, x.im) }
func (x unknownVal) ExactString() string { return x.String() }
func (x boolVal) ExactString() string { return x.String() }
func (x *stringVal) ExactString() string { return strconv.Quote(x.string()) }
func (x int64Val) ExactString() string { return x.String() }
func (x intVal) ExactString() string { return x.String() }
func (x ratVal) ExactString() string {
r := x.val
if r.IsInt() {
return r.Num().String()
}
return r.String()
}
func (x floatVal) ExactString() string { return x.val.Text('p', 0) }
func (x complexVal) ExactString() string {
return fmt.Sprintf("(%s + %si)", x.re.ExactString(), x.im.ExactString())
}
func (unknownVal) implementsValue() {}
func (boolVal) implementsValue() {}
func (*stringVal) implementsValue() {}
func (int64Val) implementsValue() {}
func (ratVal) implementsValue() {}
func (intVal) implementsValue() {}
func (floatVal) implementsValue() {}
func (complexVal) implementsValue() {}
func newInt() *big.Int { return new(big.Int) }
func newRat() *big.Rat { return new(big.Rat) }
func newFloat() *big.Float { return new(big.Float).SetPrec(prec) }
func i64toi(x int64Val) intVal { return intVal{newInt().SetInt64(int64(x))} }
func i64tor(x int64Val) ratVal { return ratVal{newRat().SetInt64(int64(x))} }
func i64tof(x int64Val) floatVal { return floatVal{newFloat().SetInt64(int64(x))} }
func itor(x intVal) ratVal { return ratVal{newRat().SetInt(x.val)} }
func itof(x intVal) floatVal { return floatVal{newFloat().SetInt(x.val)} }
func rtof(x ratVal) floatVal { return floatVal{newFloat().SetRat(x.val)} }
func vtoc(x Value) complexVal { return complexVal{x, int64Val(0)} }
func makeInt(x *big.Int) Value {
if x.IsInt64() {
return int64Val(x.Int64())
}
return intVal{x}
}
func makeRat(x *big.Rat) Value {
a := x.Num()
b := x.Denom()
if smallInt(a) && smallInt(b) {
// ok to remain fraction
return ratVal{x}
}
// components too large => switch to float
return floatVal{newFloat().SetRat(x)}
}
var floatVal0 = floatVal{newFloat()}
func makeFloat(x *big.Float) Value {
// convert -0
if x.Sign() == 0 {
return floatVal0
}
if x.IsInf() {
return unknownVal{}
}
// No attempt is made to "go back" to ratVal, even if possible,
// to avoid providing the illusion of a mathematically exact
// representation.
return floatVal{x}
}
func makeComplex(re, im Value) Value {
if re.Kind() == Unknown || im.Kind() == Unknown {
return unknownVal{}
}
return complexVal{re, im}
}
func makeFloatFromLiteral(lit string) Value {
if f, ok := newFloat().SetString(lit); ok {
if smallFloat(f) {
// ok to use rationals
if f.Sign() == 0 {
// Issue 20228: If the float underflowed to zero, parse just "0".
// Otherwise, lit might contain a value with a large negative exponent,
// such as -6e-1886451601. As a float, that will underflow to 0,
// but it'll take forever to parse as a Rat.
lit = "0"
}
if r, ok := newRat().SetString(lit); ok {
return ratVal{r}
}
}
// otherwise use floats
return makeFloat(f)
}
return nil
}
// Permit fractions with component sizes up to maxExp
// before switching to using floating-point numbers.
const maxExp = 4 << 10
// smallInt reports whether x would lead to a "reasonably"-sized fraction
// if converted to a *big.Rat.
func smallInt(x *big.Int) bool {
return x.BitLen() < maxExp
}
// smallFloat64 reports whether x would lead to a "reasonably"-sized fraction
// if converted to a *big.Rat.
func smallFloat64(x float64) bool {
if math.IsInf(x, 0) {
return false
}
_, e := math.Frexp(x)
return -maxExp < e && e < maxExp
}
// smallFloat reports whether x would lead to a "reasonably"-sized fraction
// if converted to a *big.Rat.
func smallFloat(x *big.Float) bool {
if x.IsInf() {
return false
}
e := x.MantExp(nil)
return -maxExp < e && e < maxExp
}
// ----------------------------------------------------------------------------
// Factories
// MakeUnknown returns the Unknown value.
func MakeUnknown() Value { return unknownVal{} }
// MakeBool returns the Bool value for b.
func MakeBool(b bool) Value { return boolVal(b) }
// MakeString returns the String value for s.
func MakeString(s string) Value {
if s == "" {
return &emptyString // common case
}
return &stringVal{s: s}
}
var emptyString stringVal
// MakeInt64 returns the Int value for x.
func MakeInt64(x int64) Value { return int64Val(x) }
// MakeUint64 returns the Int value for x.
func MakeUint64(x uint64) Value {
if x < 1<<63 {
return int64Val(int64(x))
}
return intVal{newInt().SetUint64(x)}
}
// MakeFloat64 returns the Float value for x.
// If x is -0.0, the result is 0.0.
// If x is not finite, the result is an Unknown.
func MakeFloat64(x float64) Value {
if math.IsInf(x, 0) || math.IsNaN(x) {
return unknownVal{}
}
if smallFloat64(x) {
return ratVal{newRat().SetFloat64(x + 0)} // convert -0 to 0
}
return floatVal{newFloat().SetFloat64(x + 0)}
}
// MakeFromLiteral returns the corresponding integer, floating-point,
// imaginary, character, or string value for a Go literal string. The
// tok value must be one of token.INT, token.FLOAT, token.IMAG,
// token.CHAR, or token.STRING. The final argument must be zero.
// If the literal string syntax is invalid, the result is an Unknown.
func MakeFromLiteral(lit string, tok token.Token, zero uint) Value {
if zero != 0 {
panic("MakeFromLiteral called with non-zero last argument")
}
switch tok {
case token.INT:
if x, err := strconv.ParseInt(lit, 0, 64); err == nil {
return int64Val(x)
}
if x, ok := newInt().SetString(lit, 0); ok {
return intVal{x}
}
case token.FLOAT:
if x := makeFloatFromLiteral(lit); x != nil {
return x
}
case token.IMAG:
if n := len(lit); n > 0 && lit[n-1] == 'i' {
if im := makeFloatFromLiteral(lit[:n-1]); im != nil {
return makeComplex(int64Val(0), im)
}
}
case token.CHAR:
if n := len(lit); n >= 2 {
if code, _, _, err := strconv.UnquoteChar(lit[1:n-1], '\''); err == nil {
return MakeInt64(int64(code))
}
}
case token.STRING:
if s, err := strconv.Unquote(lit); err == nil {
return MakeString(s)
}
default:
panic(fmt.Sprintf("%v is not a valid token", tok))
}
return unknownVal{}
}
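// sketchMakeFromLiteral is an editor's illustrative sketch, not part of the
// original file, covering each supported token kind and the Unknown fallback.
func sketchMakeFromLiteral() {
fmt.Println(MakeFromLiteral("0x2a", token.INT, 0)) // 42
fmt.Println(MakeFromLiteral("1e-3", token.FLOAT, 0)) // 0.001
fmt.Println(MakeFromLiteral("2i", token.IMAG, 0)) // (0 + 2i)
fmt.Println(MakeFromLiteral("'A'", token.CHAR, 0)) // 65
fmt.Println(MakeFromLiteral(`"hi"`, token.STRING, 0)) // "hi"
fmt.Println(MakeFromLiteral("3x", token.INT, 0)) // unknown (invalid syntax)
}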
// ----------------------------------------------------------------------------
// Accessors
//
// For unknown arguments the result is the zero value for the respective
// accessor type, except for Sign, where the result is 1.
// BoolVal returns the Go boolean value of x, which must be a Bool or an Unknown.
// If x is Unknown, the result is false.
func BoolVal(x Value) bool {
switch x := x.(type) {
case boolVal:
return bool(x)
case unknownVal:
return false
default:
panic(fmt.Sprintf("%v not a Bool", x))
}
}
// StringVal returns the Go string value of x, which must be a String or an Unknown.
// If x is Unknown, the result is "".
func StringVal(x Value) string {
switch x := x.(type) {
case *stringVal:
return x.string()
case unknownVal:
return ""
default:
panic(fmt.Sprintf("%v not a String", x))
}
}
// Int64Val returns the Go int64 value of x and whether the result is exact;
// x must be an Int or an Unknown. If the result is not exact, its value is undefined.
// If x is Unknown, the result is (0, false).
func Int64Val(x Value) (int64, bool) {
switch x := x.(type) {
case int64Val:
return int64(x), true
case intVal:
return x.val.Int64(), false // not an int64Val and thus not exact
case unknownVal:
return 0, false
default:
panic(fmt.Sprintf("%v not an Int", x))
}
}
// Uint64Val returns the Go uint64 value of x and whether the result is exact;
// x must be an Int or an Unknown. If the result is not exact, its value is undefined.
// If x is Unknown, the result is (0, false).
func Uint64Val(x Value) (uint64, bool) {
switch x := x.(type) {
case int64Val:
return uint64(x), x >= 0
case intVal:
return x.val.Uint64(), x.val.IsUint64()
case unknownVal:
return 0, false
default:
panic(fmt.Sprintf("%v not an Int", x))
}
}
// Float32Val is like Float64Val but for float32 instead of float64.
func Float32Val(x Value) (float32, bool) {
switch x := x.(type) {
case int64Val:
f := float32(x)
return f, int64Val(f) == x
case intVal:
f, acc := newFloat().SetInt(x.val).Float32()
return f, acc == big.Exact
case ratVal:
return x.val.Float32()
case floatVal:
f, acc := x.val.Float32()
return f, acc == big.Exact
case unknownVal:
return 0, false
default:
panic(fmt.Sprintf("%v not a Float", x))
}
}
// Float64Val returns the nearest Go float64 value of x and whether the result is exact;
// x must be numeric or an Unknown, but not Complex. For values too small (too close to 0)
// to represent as float64, Float64Val silently underflows to 0. The result sign always
// matches the sign of x, even for 0.
// If x is Unknown, the result is (0, false).
func Float64Val(x Value) (float64, bool) {
switch x := x.(type) {
case int64Val:
f := float64(int64(x))
return f, int64Val(f) == x
case intVal:
f, acc := newFloat().SetInt(x.val).Float64()
return f, acc == big.Exact
case ratVal:
return x.val.Float64()
case floatVal:
f, acc := x.val.Float64()
return f, acc == big.Exact
case unknownVal:
return 0, false
default:
panic(fmt.Sprintf("%v not a Float", x))
}
}
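// sketchFloat64Val is an editor's illustrative sketch, not part of the
// original file: the second result reports whether the conversion is exact.
func sketchFloat64Val() {
f, exact := Float64Val(MakeFloat64(0.5))
fmt.Println(f, exact) // 0.5 true: 0.5 is a power of two
f, exact = Float64Val(MakeFromLiteral("0.1", token.FLOAT, 0))
fmt.Println(f, exact) // 0.1 false: 1/10 has no finite binary representation
}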
// Val returns the underlying value for a given constant. Since it returns an
// interface, it is up to the caller to type assert the result to the expected
// type. The possible dynamic return types are:
//
// x Kind type of result
// -----------------------------------------
// Bool bool
// String string
// Int int64 or *big.Int
// Float *big.Float or *big.Rat
// everything else nil
func Val(x Value) any {
switch x := x.(type) {
case boolVal:
return bool(x)
case *stringVal:
return x.string()
case int64Val:
return int64(x)
case intVal:
return x.val
case ratVal:
return x.val
case floatVal:
return x.val
default:
return nil
}
}
// Make returns the Value for x.
//
// type of x result Kind
// ----------------------------
// bool Bool
// string String
// int64 Int
// *big.Int Int
// *big.Float Float
// *big.Rat Float
// anything else Unknown
func Make(x any) Value {
switch x := x.(type) {
case bool:
return boolVal(x)
case string:
return &stringVal{s: x}
case int64:
return int64Val(x)
case *big.Int:
return makeInt(x)
case *big.Rat:
return makeRat(x)
case *big.Float:
return makeFloat(x)
default:
return unknownVal{}
}
}
// BitLen returns the number of bits required to represent
// the absolute value of x in binary representation; x must be an Int or an Unknown.
// If x is Unknown, the result is 0.
func BitLen(x Value) int {
switch x := x.(type) {
case int64Val:
u := uint64(x)
if x < 0 {
u = uint64(-x)
}
return 64 - bits.LeadingZeros64(u)
case intVal:
return x.val.BitLen()
case unknownVal:
return 0
default:
panic(fmt.Sprintf("%v not an Int", x))
}
}
// Sign returns -1, 0, or 1 depending on whether x < 0, x == 0, or x > 0;
// x must be numeric or Unknown. For complex values x, the sign is 0 if x == 0,
// otherwise it is != 0. If x is Unknown, the result is 1.
func Sign(x Value) int {
switch x := x.(type) {
case int64Val:
switch {
case x < 0:
return -1
case x > 0:
return 1
}
return 0
case intVal:
return x.val.Sign()
case ratVal:
return x.val.Sign()
case floatVal:
return x.val.Sign()
case complexVal:
return Sign(x.re) | Sign(x.im)
case unknownVal:
return 1 // avoid spurious division by zero errors
default:
panic(fmt.Sprintf("%v not numeric", x))
}
}
// ----------------------------------------------------------------------------
// Support for assembling/disassembling numeric values
const (
// Compute the size of a Word in bytes.
_m = ^big.Word(0)
_log = _m>>8&1 + _m>>16&1 + _m>>32&1
wordSize = 1 << _log
)
// Bytes returns the bytes for the absolute value of x in little-
// endian binary representation; x must be an Int.
func Bytes(x Value) []byte {
var t intVal
switch x := x.(type) {
case int64Val:
t = i64toi(x)
case intVal:
t = x
default:
panic(fmt.Sprintf("%v not an Int", x))
}
words := t.val.Bits()
bytes := make([]byte, len(words)*wordSize)
i := 0
for _, w := range words {
for j := 0; j < wordSize; j++ {
bytes[i] = byte(w)
w >>= 8
i++
}
}
// remove leading 0's
for i > 0 && bytes[i-1] == 0 {
i--
}
return bytes[:i]
}
// MakeFromBytes returns the Int value given the bytes of its little-endian
// binary representation. An empty byte slice argument represents 0.
func MakeFromBytes(bytes []byte) Value {
words := make([]big.Word, (len(bytes)+(wordSize-1))/wordSize)
i := 0
var w big.Word
var s uint
for _, b := range bytes {
w |= big.Word(b) << s
if s += 8; s == wordSize*8 {
words[i] = w
i++
w = 0
s = 0
}
}
// store last word
if i < len(words) {
words[i] = w
i++
}
// remove leading 0's
for i > 0 && words[i-1] == 0 {
i--
}
return makeInt(newInt().SetBits(words[:i]))
}
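// sketchBytesRoundTrip is an editor's illustrative sketch, not part of the
// original file: Bytes and MakeFromBytes round-trip the absolute value,
// little-endian, with leading zero bytes stripped.
func sketchBytesRoundTrip() {
v := MakeInt64(0x010203)
b := Bytes(v) // [3 2 1]: least significant byte first
fmt.Println(Compare(MakeFromBytes(b), token.EQL, v)) // true
}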
// Num returns the numerator of x; x must be Int, Float, or Unknown.
// If x is Unknown, or if it is too large or small to represent as a
// fraction, the result is Unknown. Otherwise the result is an Int
// with the same sign as x.
func Num(x Value) Value {
switch x := x.(type) {
case int64Val, intVal:
return x
case ratVal:
return makeInt(x.val.Num())
case floatVal:
if smallFloat(x.val) {
r, _ := x.val.Rat(nil)
return makeInt(r.Num())
}
case unknownVal:
break
default:
panic(fmt.Sprintf("%v not Int or Float", x))
}
return unknownVal{}
}
// Denom returns the denominator of x; x must be Int, Float, or Unknown.
// If x is Unknown, or if it is too large or small to represent as a
// fraction, the result is Unknown. Otherwise the result is an Int >= 1.
func Denom(x Value) Value {
switch x := x.(type) {
case int64Val, intVal:
return int64Val(1)
case ratVal:
return makeInt(x.val.Denom())
case floatVal:
if smallFloat(x.val) {
r, _ := x.val.Rat(nil)
return makeInt(r.Denom())
}
case unknownVal:
break
default:
panic(fmt.Sprintf("%v not Int or Float", x))
}
return unknownVal{}
}
// MakeImag returns the Complex value x*i;
// x must be Int, Float, or Unknown.
// If x is Unknown, the result is Unknown.
func MakeImag(x Value) Value {
switch x.(type) {
case unknownVal:
return x
case int64Val, intVal, ratVal, floatVal:
return makeComplex(int64Val(0), x)
default:
panic(fmt.Sprintf("%v not Int or Float", x))
}
}
// Real returns the real part of x, which must be a numeric or unknown value.
// If x is Unknown, the result is Unknown.
func Real(x Value) Value {
switch x := x.(type) {
case unknownVal, int64Val, intVal, ratVal, floatVal:
return x
case complexVal:
return x.re
default:
panic(fmt.Sprintf("%v not numeric", x))
}
}
// Imag returns the imaginary part of x, which must be a numeric or unknown value.
// If x is Unknown, the result is Unknown.
func Imag(x Value) Value {
switch x := x.(type) {
case unknownVal:
return x
case int64Val, intVal, ratVal, floatVal:
return int64Val(0)
case complexVal:
return x.im
default:
panic(fmt.Sprintf("%v not numeric", x))
}
}
// ----------------------------------------------------------------------------
// Numeric conversions
// ToInt converts x to an Int value if x is representable as an Int.
// Otherwise it returns an Unknown.
func ToInt(x Value) Value {
switch x := x.(type) {
case int64Val, intVal:
return x
case ratVal:
if x.val.IsInt() {
return makeInt(x.val.Num())
}
case floatVal:
// avoid creation of huge integers
// (Existing tests require permitting exponents of at least 1024;
// allow any value that would also be permissible as a fraction.)
if smallFloat(x.val) {
i := newInt()
if _, acc := x.val.Int(i); acc == big.Exact {
return makeInt(i)
}
// If we can get an integer by rounding up or down,
// assume x is not an integer because of rounding
// errors in prior computations.
const delta = 4 // a small number of bits > 0
var t big.Float
t.SetPrec(prec - delta)
// try rounding down a little
t.SetMode(big.ToZero)
t.Set(x.val)
if _, acc := t.Int(i); acc == big.Exact {
return makeInt(i)
}
// try rounding up a little
t.SetMode(big.AwayFromZero)
t.Set(x.val)
if _, acc := t.Int(i); acc == big.Exact {
return makeInt(i)
}
}
case complexVal:
if re := ToFloat(x); re.Kind() == Float {
return ToInt(re)
}
}
return unknownVal{}
}
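// sketchToInt is an editor's illustrative sketch, not part of the original
// file: the conversion succeeds only for values that are exactly integers.
func sketchToInt() {
fmt.Println(ToInt(MakeFloat64(42))) // 42
fmt.Println(ToInt(MakeFloat64(1.5))) // unknown: 3/2 is not an integer
}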
// ToFloat converts x to a Float value if x is representable as a Float.
// Otherwise it returns an Unknown.
func ToFloat(x Value) Value {
switch x := x.(type) {
case int64Val:
return i64tor(x) // x is always a small int
case intVal:
if smallInt(x.val) {
return itor(x)
}
return itof(x)
case ratVal, floatVal:
return x
case complexVal:
if Sign(x.im) == 0 {
return ToFloat(x.re)
}
}
return unknownVal{}
}
// ToComplex converts x to a Complex value if x is representable as a Complex.
// Otherwise it returns an Unknown.
func ToComplex(x Value) Value {
switch x := x.(type) {
case int64Val, intVal, ratVal, floatVal:
return vtoc(x)
case complexVal:
return x
}
return unknownVal{}
}
// ----------------------------------------------------------------------------
// Operations
// is32bit reports whether x can be represented using 32 bits.
func is32bit(x int64) bool {
const s = 32
return -1<<(s-1) <= x && x <= 1<<(s-1)-1
}
// is63bit reports whether x can be represented using 63 bits.
func is63bit(x int64) bool {
const s = 63
return -1<<(s-1) <= x && x <= 1<<(s-1)-1
}
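// Editor's note, not part of the original file: is63bit guards the int64Val
// addition and subtraction fast paths in BinaryOp below; the sum or
// difference of two 63-bit values always fits in an int64, so those paths
// cannot overflow. is32bit plays the same role for multiplication.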
// UnaryOp returns the result of the unary expression op y.
// The operation must be defined for the operand.
// If prec > 0 it specifies the ^ (xor) result size in bits.
// If y is Unknown, the result is Unknown.
func UnaryOp(op token.Token, y Value, prec uint) Value {
switch op {
case token.ADD:
switch y.(type) {
case unknownVal, int64Val, intVal, ratVal, floatVal, complexVal:
return y
}
case token.SUB:
switch y := y.(type) {
case unknownVal:
return y
case int64Val:
if z := -y; z != y {
return z // no overflow
}
return makeInt(newInt().Neg(big.NewInt(int64(y))))
case intVal:
return makeInt(newInt().Neg(y.val))
case ratVal:
return makeRat(newRat().Neg(y.val))
case floatVal:
return makeFloat(newFloat().Neg(y.val))
case complexVal:
re := UnaryOp(token.SUB, y.re, 0)
im := UnaryOp(token.SUB, y.im, 0)
return makeComplex(re, im)
}
case token.XOR:
z := newInt()
switch y := y.(type) {
case unknownVal:
return y
case int64Val:
z.Not(big.NewInt(int64(y)))
case intVal:
z.Not(y.val)
default:
goto Error
}
// For unsigned types, the result will be negative and
// thus "too large": We must limit the result precision
// to the type's precision.
if prec > 0 {
z.AndNot(z, newInt().Lsh(big.NewInt(-1), prec)) // z &^= (-1)<<prec
}
return makeInt(z)
case token.NOT:
switch y := y.(type) {
case unknownVal:
return y
case boolVal:
return !y
}
}
Error:
panic(fmt.Sprintf("invalid unary operation %s%v", op, y))
}
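// sketchUnaryXor is an editor's illustrative sketch, not part of the
// original file: a non-zero prec masks the bitwise complement back into an
// unsigned type's range, per the comment inside UnaryOp.
func sketchUnaryXor() {
fmt.Println(UnaryOp(token.XOR, MakeInt64(1), 0)) // -2: untyped ^1
fmt.Println(UnaryOp(token.XOR, MakeInt64(1), 8)) // 254: ^1 truncated to 8 bits
}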
func ord(x Value) int {
switch x.(type) {
default:
// force invalid value into "x position" in match
// (don't panic here so that callers can provide a better error message)
return -1
case unknownVal:
return 0
case boolVal, *stringVal:
return 1
case int64Val:
return 2
case intVal:
return 3
case ratVal:
return 4
case floatVal:
return 5
case complexVal:
return 6
}
}
// match returns the matching representation (same type) with the
// smallest complexity for two values x and y. If one of them is
// numeric, both of them must be numeric. If one of them is Unknown
// or invalid (say, nil), both results are that value.
func match(x, y Value) (_, _ Value) {
switch ox, oy := ord(x), ord(y); {
case ox < oy:
x, y = match0(x, y)
case ox > oy:
y, x = match0(y, x)
}
return x, y
}
// match0 must only be called by match.
// Invariant: ord(x) < ord(y)
func match0(x, y Value) (_, _ Value) {
// Prefer to return the original x and y arguments when possible,
// to avoid unnecessary heap allocations.
switch y.(type) {
case intVal:
switch x1 := x.(type) {
case int64Val:
return i64toi(x1), y
}
case ratVal:
switch x1 := x.(type) {
case int64Val:
return i64tor(x1), y
case intVal:
return itor(x1), y
}
case floatVal:
switch x1 := x.(type) {
case int64Val:
return i64tof(x1), y
case intVal:
return itof(x1), y
case ratVal:
return rtof(x1), y
}
case complexVal:
return vtoc(x), y
}
// force unknown and invalid values into "x position" in callers of match
// (don't panic here so that callers can provide a better error message)
return x, x
}
// BinaryOp returns the result of the binary expression x op y.
// The operation must be defined for the operands. If one of the
// operands is Unknown, the result is Unknown.
// BinaryOp doesn't handle comparisons or shifts; use Compare
// or Shift instead.
//
// To force integer division of Int operands, use op == token.QUO_ASSIGN
// instead of token.QUO; the result is guaranteed to be Int in this case.
// Division by zero leads to a run-time panic.
func BinaryOp(x_ Value, op token.Token, y_ Value) Value {
x, y := match(x_, y_)
switch x := x.(type) {
case unknownVal:
return x
case boolVal:
y := y.(boolVal)
switch op {
case token.LAND:
return x && y
case token.LOR:
return x || y
}
case int64Val:
a := int64(x)
b := int64(y.(int64Val))
var c int64
switch op {
case token.ADD:
if !is63bit(a) || !is63bit(b) {
return makeInt(newInt().Add(big.NewInt(a), big.NewInt(b)))
}
c = a + b
case token.SUB:
if !is63bit(a) || !is63bit(b) {
return makeInt(newInt().Sub(big.NewInt(a), big.NewInt(b)))
}
c = a - b
case token.MUL:
if !is32bit(a) || !is32bit(b) {
return makeInt(newInt().Mul(big.NewInt(a), big.NewInt(b)))
}
c = a * b
case token.QUO:
return makeRat(big.NewRat(a, b))
case token.QUO_ASSIGN: // force integer division
c = a / b
case token.REM:
c = a % b
case token.AND:
c = a & b
case token.OR:
c = a | b
case token.XOR:
c = a ^ b
case token.AND_NOT:
c = a &^ b
default:
goto Error
}
return int64Val(c)
case intVal:
a := x.val
b := y.(intVal).val
c := newInt()
switch op {
case token.ADD:
c.Add(a, b)
case token.SUB:
c.Sub(a, b)
case token.MUL:
c.Mul(a, b)
case token.QUO:
return makeRat(newRat().SetFrac(a, b))
case token.QUO_ASSIGN: // force integer division
c.Quo(a, b)
case token.REM:
c.Rem(a, b)
case token.AND:
c.And(a, b)
case token.OR:
c.Or(a, b)
case token.XOR:
c.Xor(a, b)
case token.AND_NOT:
c.AndNot(a, b)
default:
goto Error
}
return makeInt(c)
case ratVal:
a := x.val
b := y.(ratVal).val
c := newRat()
switch op {
case token.ADD:
c.Add(a, b)
case token.SUB:
c.Sub(a, b)
case token.MUL:
c.Mul(a, b)
case token.QUO:
c.Quo(a, b)
default:
goto Error
}
return makeRat(c)
case floatVal:
a := x.val
b := y.(floatVal).val
c := newFloat()
switch op {
case token.ADD:
c.Add(a, b)
case token.SUB:
c.Sub(a, b)
case token.MUL:
c.Mul(a, b)
case token.QUO:
c.Quo(a, b)
default:
goto Error
}
return makeFloat(c)
case complexVal:
y := y.(complexVal)
a, b := x.re, x.im
c, d := y.re, y.im
var re, im Value
switch op {
case token.ADD:
// (a+c) + i(b+d)
re = add(a, c)
im = add(b, d)
case token.SUB:
// (a-c) + i(b-d)
re = sub(a, c)
im = sub(b, d)
case token.MUL:
// (ac-bd) + i(bc+ad)
ac := mul(a, c)
bd := mul(b, d)
bc := mul(b, c)
ad := mul(a, d)
re = sub(ac, bd)
im = add(bc, ad)
case token.QUO:
// (ac+bd)/s + i(bc-ad)/s, with s = cc + dd
ac := mul(a, c)
bd := mul(b, d)
bc := mul(b, c)
ad := mul(a, d)
cc := mul(c, c)
dd := mul(d, d)
s := add(cc, dd)
re = add(ac, bd)
re = quo(re, s)
im = sub(bc, ad)
im = quo(im, s)
default:
goto Error
}
return makeComplex(re, im)
case *stringVal:
if op == token.ADD {
return &stringVal{l: x, r: y.(*stringVal)}
}
}
Error:
panic(fmt.Sprintf("invalid binary operation %v %s %v", x_, op, y_))
}
func add(x, y Value) Value { return BinaryOp(x, token.ADD, y) }
func sub(x, y Value) Value { return BinaryOp(x, token.SUB, y) }
func mul(x, y Value) Value { return BinaryOp(x, token.MUL, y) }
func quo(x, y Value) Value { return BinaryOp(x, token.QUO, y) }
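// sketchQuo is an editor's illustrative sketch, not part of the original
// file: token.QUO on Ints yields an exact Float, while token.QUO_ASSIGN
// forces integer division, as the BinaryOp documentation describes.
func sketchQuo() {
x, y := MakeInt64(7), MakeInt64(2)
fmt.Println(BinaryOp(x, token.QUO, y)) // 3.5 (Float)
fmt.Println(BinaryOp(x, token.QUO_ASSIGN, y)) // 3 (Int)
}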
// Shift returns the result of the shift expression x op s
// with op == token.SHL or token.SHR (<< or >>). x must be
// an Int or an Unknown. If x is Unknown, the result is x.
func Shift(x Value, op token.Token, s uint) Value {
switch x := x.(type) {
case unknownVal:
return x
case int64Val:
if s == 0 {
return x
}
switch op {
case token.SHL:
z := i64toi(x).val
return makeInt(z.Lsh(z, s))
case token.SHR:
return x >> s
}
case intVal:
if s == 0 {
return x
}
z := newInt()
switch op {
case token.SHL:
return makeInt(z.Lsh(x.val, s))
case token.SHR:
return makeInt(z.Rsh(x.val, s))
}
}
panic(fmt.Sprintf("invalid shift %v %s %d", x, op, s))
}
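// sketchShift is an editor's illustrative sketch, not part of the original
// file: left shifts promote int64 operands to big integers, so they cannot
// overflow; right shifts of negative values are arithmetic.
func sketchShift() {
fmt.Println(Shift(MakeInt64(1), token.SHL, 100)) // 2**100 = 1267650600228229401496703205376
fmt.Println(Shift(MakeInt64(-8), token.SHR, 1)) // -4
}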
func cmpZero(x int, op token.Token) bool {
switch op {
case token.EQL:
return x == 0
case token.NEQ:
return x != 0
case token.LSS:
return x < 0
case token.LEQ:
return x <= 0
case token.GTR:
return x > 0
case token.GEQ:
return x >= 0
}
panic(fmt.Sprintf("invalid comparison %v %s 0", x, op))
}
// Compare returns the result of the comparison x op y.
// The comparison must be defined for the operands.
// If one of the operands is Unknown, the result is
// false.
func Compare(x_ Value, op token.Token, y_ Value) bool {
x, y := match(x_, y_)
switch x := x.(type) {
case unknownVal:
return false
case boolVal:
y := y.(boolVal)
switch op {
case token.EQL:
return x == y
case token.NEQ:
return x != y
}
case int64Val:
y := y.(int64Val)
switch op {
case token.EQL:
return x == y
case token.NEQ:
return x != y
case token.LSS:
return x < y
case token.LEQ:
return x <= y
case token.GTR:
return x > y
case token.GEQ:
return x >= y
}
case intVal:
return cmpZero(x.val.Cmp(y.(intVal).val), op)
case ratVal:
return cmpZero(x.val.Cmp(y.(ratVal).val), op)
case floatVal:
return cmpZero(x.val.Cmp(y.(floatVal).val), op)
case complexVal:
y := y.(complexVal)
re := Compare(x.re, token.EQL, y.re)
im := Compare(x.im, token.EQL, y.im)
switch op {
case token.EQL:
return re && im
case token.NEQ:
return !re || !im
}
case *stringVal:
xs := x.string()
ys := y.(*stringVal).string()
switch op {
case token.EQL:
return xs == ys
case token.NEQ:
return xs != ys
case token.LSS:
return xs < ys
case token.LEQ:
return xs <= ys
case token.GTR:
return xs > ys
case token.GEQ:
return xs >= ys
}
}
panic(fmt.Sprintf("invalid comparison %v %s %v", x_, op, y_))
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package doc
import (
"go/doc/comment"
"io"
)
// ToHTML converts comment text to formatted HTML.
//
// Deprecated: ToHTML cannot identify documentation links
// in the doc comment, because they depend on knowing what
// package the text came from, which is not included in this API.
//
// Given the *[doc.Package] p where text was found,
// ToHTML(w, text, nil) can be replaced by:
//
// w.Write(p.HTML(text))
//
// which is in turn shorthand for:
//
// w.Write(p.Printer().HTML(p.Parser().Parse(text)))
//
// If words may be non-nil, the longer replacement is:
//
// parser := p.Parser()
// parser.Words = words
// w.Write(p.Printer().HTML(parser.Parse(text)))
func ToHTML(w io.Writer, text string, words map[string]string) {
p := new(Package).Parser()
p.Words = words
d := p.Parse(text)
pr := new(comment.Printer)
w.Write(pr.HTML(d))
}
// ToText converts comment text to formatted text.
//
// Deprecated: ToText cannot identify documentation links
// in the doc comment, because they depend on knowing what
// package the text came from, which is not included in this API.
//
// Given the *[doc.Package] p where text was found,
// ToText(w, text, "", "\t", 80) can be replaced by:
//
// w.Write(p.Text(text))
//
// In the general case, ToText(w, text, prefix, codePrefix, width)
// can be replaced by:
//
// d := p.Parser().Parse(text)
// pr := p.Printer()
// pr.TextPrefix = prefix
// pr.TextCodePrefix = codePrefix
// pr.TextWidth = width
// w.Write(pr.Text(d))
//
// See the documentation for [Package.Text] and [comment.Printer.Text]
// for more details.
func ToText(w io.Writer, text string, prefix, codePrefix string, width int) {
d := new(Package).Parser().Parse(text)
pr := &comment.Printer{
TextPrefix: prefix,
TextCodePrefix: codePrefix,
TextWidth: width,
}
w.Write(pr.Text(d))
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package comment
import (
"bytes"
"fmt"
"strconv"
)
// An htmlPrinter holds the state needed for printing a Doc as HTML.
type htmlPrinter struct {
*Printer
tight bool
}
// HTML returns an HTML formatting of the Doc.
// See the [Printer] documentation for ways to customize the HTML output.
func (p *Printer) HTML(d *Doc) []byte {
hp := &htmlPrinter{Printer: p}
var out bytes.Buffer
for _, x := range d.Content {
hp.block(&out, x)
}
return out.Bytes()
}
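// sketchHTML is an editor's illustrative sketch, not part of the original
// file: a zero-value Parser and Printer are enough to render a comment.
func sketchHTML() {
var p Parser
var pr Printer
d := p.Parse("Hello, world.\n\nSecond paragraph.\n")
fmt.Printf("%s", pr.HTML(d)) // <p>Hello, world.\n<p>Second paragraph.\n
}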
// block prints the block x to out.
func (p *htmlPrinter) block(out *bytes.Buffer, x Block) {
switch x := x.(type) {
default:
fmt.Fprintf(out, "?%T", x)
case *Paragraph:
if !p.tight {
out.WriteString("<p>")
}
p.text(out, x.Text)
out.WriteString("\n")
case *Heading:
out.WriteString("<h")
h := strconv.Itoa(p.headingLevel())
out.WriteString(h)
if id := p.headingID(x); id != "" {
out.WriteString(` id="`)
p.escape(out, id)
out.WriteString(`"`)
}
out.WriteString(">")
p.text(out, x.Text)
out.WriteString("</h")
out.WriteString(h)
out.WriteString(">\n")
case *Code:
out.WriteString("<pre>")
p.escape(out, x.Text)
out.WriteString("</pre>\n")
case *List:
kind := "ol>\n"
if x.Items[0].Number == "" {
kind = "ul>\n"
}
out.WriteString("<")
out.WriteString(kind)
next := "1"
for _, item := range x.Items {
out.WriteString("<li")
if n := item.Number; n != "" {
if n != next {
out.WriteString(` value="`)
out.WriteString(n)
out.WriteString(`"`)
next = n
}
next = inc(next)
}
out.WriteString(">")
p.tight = !x.BlankBetween()
for _, blk := range item.Content {
p.block(out, blk)
}
p.tight = false
}
out.WriteString("</")
out.WriteString(kind)
}
}
// inc increments the decimal string s.
// For example, inc("1199") == "1200".
func inc(s string) string {
b := []byte(s)
for i := len(b) - 1; i >= 0; i-- {
if b[i] < '9' {
b[i]++
return string(b)
}
b[i] = '0'
}
return "1" + string(b)
}
// text prints the text sequence x to out.
func (p *htmlPrinter) text(out *bytes.Buffer, x []Text) {
for _, t := range x {
switch t := t.(type) {
case Plain:
p.escape(out, string(t))
case Italic:
out.WriteString("<i>")
p.escape(out, string(t))
out.WriteString("</i>")
case *Link:
out.WriteString(`<a href="`)
p.escape(out, t.URL)
out.WriteString(`">`)
p.text(out, t.Text)
out.WriteString("</a>")
case *DocLink:
url := p.docLinkURL(t)
if url != "" {
out.WriteString(`<a href="`)
p.escape(out, url)
out.WriteString(`">`)
}
p.text(out, t.Text)
if url != "" {
out.WriteString("</a>")
}
}
}
}
// escape prints s to out as plain text,
// escaping < & " ' and > to avoid being misinterpreted
// in larger HTML constructs.
func (p *htmlPrinter) escape(out *bytes.Buffer, s string) {
start := 0
for i := 0; i < len(s); i++ {
switch s[i] {
case '<':
out.WriteString(s[start:i])
out.WriteString("&lt;")
start = i + 1
case '&':
out.WriteString(s[start:i])
out.WriteString("&amp;")
start = i + 1
case '"':
out.WriteString(s[start:i])
out.WriteString("&quot;")
start = i + 1
case '\'':
out.WriteString(s[start:i])
out.WriteString("&#39;")
start = i + 1
case '>':
out.WriteString(s[start:i])
out.WriteString("&gt;")
start = i + 1
}
}
out.WriteString(s[start:])
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package comment
import (
"bytes"
"fmt"
"strings"
)
// An mdPrinter holds the state needed for printing a Doc as Markdown.
type mdPrinter struct {
*Printer
headingPrefix string
raw bytes.Buffer
}
// Markdown returns a Markdown formatting of the Doc.
// See the [Printer] documentation for ways to customize the Markdown output.
func (p *Printer) Markdown(d *Doc) []byte {
mp := &mdPrinter{
Printer: p,
headingPrefix: strings.Repeat("#", p.headingLevel()) + " ",
}
var out bytes.Buffer
for i, x := range d.Content {
if i > 0 {
out.WriteByte('\n')
}
mp.block(&out, x)
}
return out.Bytes()
}
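// sketchMarkdown is an editor's illustrative sketch, not part of the
// original file: the same parsed Doc renders as Markdown, with headings
// prefixed according to the Printer's heading level (three '#' by default).
func sketchMarkdown() {
var p Parser
var pr Printer
d := p.Parse("# Title\n\nBody text.\n")
fmt.Printf("%s", pr.Markdown(d)) // "### Title ..." then a blank line and the paragraph
}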
// block prints the block x to out.
func (p *mdPrinter) block(out *bytes.Buffer, x Block) {
switch x := x.(type) {
default:
fmt.Fprintf(out, "?%T", x)
case *Paragraph:
p.text(out, x.Text)
out.WriteString("\n")
case *Heading:
out.WriteString(p.headingPrefix)
p.text(out, x.Text)
if id := p.headingID(x); id != "" {
out.WriteString(" {#")
out.WriteString(id)
out.WriteString("}")
}
out.WriteString("\n")
case *Code:
md := x.Text
for md != "" {
var line string
line, md, _ = strings.Cut(md, "\n")
if line != "" {
out.WriteString("\t")
out.WriteString(line)
}
out.WriteString("\n")
}
case *List:
loose := x.BlankBetween()
for i, item := range x.Items {
if i > 0 && loose {
out.WriteString("\n")
}
if n := item.Number; n != "" {
out.WriteString(" ")
out.WriteString(n)
out.WriteString(". ")
} else {
out.WriteString("  - ") // SP SP - SP
}
for i, blk := range item.Content {
const fourSpace = "    "
if i > 0 {
out.WriteString("\n" + fourSpace)
}
p.text(out, blk.(*Paragraph).Text)
out.WriteString("\n")
}
}
}
}
// text prints the text sequence x to out.
func (p *mdPrinter) text(out *bytes.Buffer, x []Text) {
p.raw.Reset()
p.rawText(&p.raw, x)
line := bytes.TrimSpace(p.raw.Bytes())
if len(line) == 0 {
return
}
switch line[0] {
case '+', '-', '*', '#':
// Escape what would be the start of an unordered list or heading.
out.WriteByte('\\')
case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
i := 1
for i < len(line) && '0' <= line[i] && line[i] <= '9' {
i++
}
if i < len(line) && (line[i] == '.' || line[i] == ')') {
// Escape what would be the start of an ordered list.
out.Write(line[:i])
out.WriteByte('\\')
line = line[i:]
}
}
out.Write(line)
}
// rawText prints the text sequence x to out,
// without worrying about escaping characters
// that have special meaning at the start of a Markdown line.
func (p *mdPrinter) rawText(out *bytes.Buffer, x []Text) {
for _, t := range x {
switch t := t.(type) {
case Plain:
p.escape(out, string(t))
case Italic:
out.WriteString("*")
p.escape(out, string(t))
out.WriteString("*")
case *Link:
out.WriteString("[")
p.rawText(out, t.Text)
out.WriteString("](")
out.WriteString(t.URL)
out.WriteString(")")
case *DocLink:
url := p.docLinkURL(t)
if url != "" {
out.WriteString("[")
}
p.rawText(out, t.Text)
if url != "" {
out.WriteString("](")
url = strings.ReplaceAll(url, "(", "%28")
url = strings.ReplaceAll(url, ")", "%29")
out.WriteString(url)
out.WriteString(")")
}
}
}
}
// escape prints s to out as plain text,
// escaping special characters to avoid being misinterpreted
// as Markdown markup sequences.
func (p *mdPrinter) escape(out *bytes.Buffer, s string) {
start := 0
for i := 0; i < len(s); i++ {
switch s[i] {
case '\n':
// Turn all \n into spaces, for a few reasons:
// - Avoid introducing paragraph breaks accidentally.
// - Avoid the need to reindent after the newline.
// - Avoid problems with Markdown renderers treating
// every mid-paragraph newline as a <br>.
out.WriteString(s[start:i])
out.WriteByte(' ')
start = i + 1
continue
case '`', '_', '*', '[', '<', '\\':
// Not all of these need to be escaped all the time,
// but it is valid and easy to do so.
// We assume the Markdown is being passed to a
// Markdown renderer, not edited by a person,
// so it's fine to have escapes that are not strictly
// necessary in some cases.
out.WriteString(s[start:i])
out.WriteByte('\\')
out.WriteByte(s[i])
start = i + 1
}
}
out.WriteString(s[start:])
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package comment
import (
"sort"
"strings"
"unicode"
"unicode/utf8"
)
// A Doc is a parsed Go doc comment.
type Doc struct {
// Content is the sequence of content blocks in the comment.
Content []Block
// Links is the link definitions in the comment.
Links []*LinkDef
}
// A LinkDef is a single link definition.
type LinkDef struct {
Text string // the link text
URL string // the link URL
Used bool // whether the comment uses the definition
}
// A Block is block-level content in a doc comment,
// one of [*Code], [*Heading], [*List], or [*Paragraph].
type Block interface {
block()
}
// A Heading is a doc comment heading.
type Heading struct {
Text []Text // the heading text
}
func (*Heading) block() {}
// A List is a numbered or bullet list.
// Lists are always non-empty: len(Items) > 0.
// In a numbered list, every Items[i].Number is a non-empty string.
// In a bullet list, every Items[i].Number is an empty string.
type List struct {
// Items is the list items.
Items []*ListItem
// ForceBlankBefore indicates that the list must be
// preceded by a blank line when reformatting the comment,
// overriding the usual conditions. See the BlankBefore method.
//
// The comment parser sets ForceBlankBefore for any list
// that is preceded by a blank line, to make sure
// the blank line is preserved when printing.
ForceBlankBefore bool
// ForceBlankBetween indicates that list items must be
// separated by blank lines when reformatting the comment,
// overriding the usual conditions. See the BlankBetween method.
//
// The comment parser sets ForceBlankBetween for any list
// that has a blank line between any two of its items, to make sure
// the blank lines are preserved when printing.
ForceBlankBetween bool
}
func (*List) block() {}
// BlankBefore reports whether a reformatting of the comment
// should include a blank line before the list.
// The default rule is the same as for [BlankBetween]:
// if the list item content contains any blank lines
// (meaning at least one item has multiple paragraphs)
// then the list itself must be preceded by a blank line.
// A preceding blank line can be forced by setting [List].ForceBlankBefore.
func (l *List) BlankBefore() bool {
return l.ForceBlankBefore || l.BlankBetween()
}
// BlankBetween reports whether a reformatting of the comment
// should include a blank line between each pair of list items.
// The default rule is that if the list item content contains any blank lines
// (meaning at least one item has multiple paragraphs)
// then list items must themselves be separated by blank lines.
// Blank line separators can be forced by setting [List].ForceBlankBetween.
func (l *List) BlankBetween() bool {
if l.ForceBlankBetween {
return true
}
for _, item := range l.Items {
if len(item.Content) != 1 {
// Unreachable for parsed comments today,
// since the only way to get multiple item.Content
// is multiple paragraphs, which must have been
// separated by a blank line.
return true
}
}
return false
}
// A ListItem is a single item in a numbered or bullet list.
type ListItem struct {
// Number is a decimal string in a numbered list
// or an empty string in a bullet list.
Number string // "1", "2", ...; "" for bullet list
// Content is the list content.
// Currently, restrictions in the parser and printer
// require every element of Content to be a *Paragraph.
Content []Block // Content of this item.
}
// A Paragraph is a paragraph of text.
type Paragraph struct {
Text []Text
}
func (*Paragraph) block() {}
// A Code is a preformatted code block.
type Code struct {
// Text is the preformatted text, ending with a newline character.
// It may be multiple lines, each of which ends with a newline character.
// It is never empty, nor does it start or end with a blank line.
Text string
}
func (*Code) block() {}
// A Text is text-level content in a doc comment,
// one of [Plain], [Italic], [*Link], or [*DocLink].
type Text interface {
text()
}
// A Plain is a string rendered as plain text (not italicized).
type Plain string
func (Plain) text() {}
// An Italic is a string rendered as italicized text.
type Italic string
func (Italic) text() {}
// A Link is a link to a specific URL.
type Link struct {
Auto bool // is this an automatic (implicit) link of a literal URL?
Text []Text // text of link
URL string // target URL of link
}
func (*Link) text() {}
// A DocLink is a link to documentation for a Go package or symbol.
type DocLink struct {
Text []Text // text of link
// ImportPath, Recv, and Name identify the Go package or symbol
// that is the link target. The potential combinations of
// non-empty fields are:
// - ImportPath: a link to another package
// - ImportPath, Name: a link to a const, func, type, or var in another package
// - ImportPath, Recv, Name: a link to a method in another package
// - Name: a link to a const, func, type, or var in this package
// - Recv, Name: a link to a method in this package
ImportPath string // import path
Recv string // receiver type, without any pointer star, for methods
Name string // const, func, type, var, or method name
}
func (*DocLink) text() {}
// A Parser is a doc comment parser.
// The fields in the struct can be filled in before calling Parse
// in order to customize the details of the parsing process.
type Parser struct {
// Words is a map of Go identifier words that
// should be italicized and potentially linked.
// If Words[w] is the empty string, then the word w
// is only italicized. Otherwise it is linked, using
// Words[w] as the link target.
// Words corresponds to the [go/doc.ToHTML] words parameter.
Words map[string]string
// LookupPackage resolves a package name to an import path.
//
// If LookupPackage(name) returns ok == true, then [name]
// (or [name.Sym] or [name.Sym.Method])
// is considered a documentation link to importPath's package docs.
// It is valid to return "", true, in which case name is considered
// to refer to the current package.
//
// If LookupPackage(name) returns ok == false,
// then [name] (or [name.Sym] or [name.Sym.Method])
// will not be considered a documentation link,
// except in the case where name is the full (but single-element) import path
// of a package in the standard library, such as in [math] or [io.Reader].
// LookupPackage is still called for such names,
// in order to permit references to imports of other packages
// with the same package names.
//
// Setting LookupPackage to nil is equivalent to setting it to
// a function that always returns "", false.
LookupPackage func(name string) (importPath string, ok bool)
// LookupSym reports whether a symbol name or method name
// exists in the current package.
//
// If LookupSym("", "Name") returns true, then [Name]
// is considered a documentation link for a const, func, type, or var.
//
// Similarly, if LookupSym("Recv", "Name") returns true,
// then [Recv.Name] is considered a documentation link for
// type Recv's method Name.
//
// Setting LookupSym to nil is equivalent to setting it to a function
// that always returns false.
LookupSym func(recv, name string) (ok bool)
}
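// A minimal usage sketch (illustrative only; the symbol table below is hypothetical):
//
//	syms := map[string]bool{"Parse": true}
//	p := &Parser{
//		LookupSym: func(recv, name string) bool { return recv == "" && syms[name] },
//	}
//	d := p.Parse("Use [Parse] to parse a doc comment.")
//	// d.Content holds one *Paragraph whose Text includes a *DocLink with Name "Parse".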
// parseDoc is parsing state for a single doc comment.
type parseDoc struct {
*Parser
*Doc
links map[string]*LinkDef
lines []string
lookupSym func(recv, name string) bool
}
// lookupPkg is called to look up the pkg in [pkg], [pkg.Name], and [pkg.Recv.Name].
// If pkg has a slash, it is assumed to be the full import path and is returned with ok = true.
//
// Otherwise, pkg is probably a simple package name like "rand" (not "crypto/rand" or "math/rand").
// d.LookupPackage provides a way for the caller to allow resolving such names with reference
// to the imports in the surrounding package.
//
// There is one collision between these two cases: single-element standard library names
// like "math" are full import paths but don't contain slashes. We let d.LookupPackage have
// the first chance to resolve it, in case there's a different package imported as math,
// and otherwise we refer to a built-in list of single-element standard library package names.
func (d *parseDoc) lookupPkg(pkg string) (importPath string, ok bool) {
if strings.Contains(pkg, "/") { // assume a full import path
if validImportPath(pkg) {
return pkg, true
}
return "", false
}
if d.LookupPackage != nil {
// Give LookupPackage a chance.
if path, ok := d.LookupPackage(pkg); ok {
return path, true
}
}
return DefaultLookupPackage(pkg)
}
func isStdPkg(path string) bool {
// TODO(rsc): Use sort.Find once we don't have to worry about
// copying this code into older Go environments.
i := sort.Search(len(stdPkgs), func(i int) bool { return stdPkgs[i] >= path })
return i < len(stdPkgs) && stdPkgs[i] == path
}
// DefaultLookupPackage is the default package lookup
// function, used when [Parser].LookupPackage is nil.
// It recognizes names of the packages from the standard
// library with single-element import paths, such as math,
// which would otherwise be impossible to name.
//
// Note that the go/doc package provides a more sophisticated
// lookup based on the imports used in the current package.
func DefaultLookupPackage(name string) (importPath string, ok bool) {
if isStdPkg(name) {
return name, true
}
return "", false
}
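// For example:
//
//	DefaultLookupPackage("math") // returns "math", true
//	DefaultLookupPackage("rand") // returns "", false (ambiguous: crypto/rand vs. math/rand)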
// Parse parses the doc comment text and returns the *Doc form.
// Comment markers (/* // and */) in the text must have already been removed.
func (p *Parser) Parse(text string) *Doc {
lines := unindent(strings.Split(text, "\n"))
d := &parseDoc{
Parser: p,
Doc: new(Doc),
links: make(map[string]*LinkDef),
lines: lines,
lookupSym: func(recv, name string) bool { return false },
}
if p.LookupSym != nil {
d.lookupSym = p.LookupSym
}
// First pass: break into block structure and collect known links.
// The text is all recorded as Plain for now.
var prev span
for _, s := range parseSpans(lines) {
var b Block
switch s.kind {
default:
panic("go/doc/comment: internal error: unknown span kind")
case spanList:
b = d.list(lines[s.start:s.end], prev.end < s.start)
case spanCode:
b = d.code(lines[s.start:s.end])
case spanOldHeading:
b = d.oldHeading(lines[s.start])
case spanHeading:
b = d.heading(lines[s.start])
case spanPara:
b = d.paragraph(lines[s.start:s.end])
}
if b != nil {
d.Content = append(d.Content, b)
}
prev = s
}
// Second pass: interpret all the Plain text now that we know the links.
for _, b := range d.Content {
switch b := b.(type) {
case *Paragraph:
b.Text = d.parseLinkedText(string(b.Text[0].(Plain)))
case *List:
for _, i := range b.Items {
for _, c := range i.Content {
p := c.(*Paragraph)
p.Text = d.parseLinkedText(string(p.Text[0].(Plain)))
}
}
}
}
return d.Doc
}
// A span represents a single span of comment lines (lines[start:end])
// of an identified kind (code, heading, paragraph, and so on).
type span struct {
start int
end int
kind spanKind
}
// A spanKind describes the kind of span.
type spanKind int
const (
_ spanKind = iota
spanCode
spanHeading
spanList
spanOldHeading
spanPara
)
func parseSpans(lines []string) []span {
var spans []span
// The loop may process a line twice: once as unindented
// and again forced indented. So the maximum expected
// number of iterations is 2*len(lines). The repeating logic
// can be subtle, though, and to protect against introduction
// of infinite loops in future changes, we watch to see that
// we are not looping too much. A panic is better than a
// quiet infinite loop.
watchdog := 2 * len(lines)
i := 0
forceIndent := 0
Spans:
for {
// Skip blank lines.
for i < len(lines) && lines[i] == "" {
i++
}
if i >= len(lines) {
break
}
if watchdog--; watchdog < 0 {
panic("go/doc/comment: internal error: not making progress")
}
var kind spanKind
start := i
end := i
if i < forceIndent || indented(lines[i]) {
// Indented (or force indented).
// Ends before next unindented. (Blank lines are OK.)
// If this is an unindented list that we are heuristically treating as indented,
// then accept unindented list item lines up to the first blank line.
// The heuristic is disabled at blank lines to contain its effect
// to non-gofmt'ed sections of the comment.
unindentedListOK := isList(lines[i]) && i < forceIndent
i++
for i < len(lines) && (lines[i] == "" || i < forceIndent || indented(lines[i]) || (unindentedListOK && isList(lines[i]))) {
if lines[i] == "" {
unindentedListOK = false
}
i++
}
// Drop trailing blank lines.
end = i
for end > start && lines[end-1] == "" {
end--
}
// If indented lines are followed (without a blank line)
// by an unindented line ending in a brace,
// take that one line too. This fixes the common mistake
// of pasting in something like
//
// func main() {
// fmt.Println("hello, world")
// }
//
// and forgetting to indent it.
// The heuristic will never trigger on a gofmt'ed comment,
// because any gofmt'ed code block or list would be
// followed by a blank line or end of comment.
if end < len(lines) && strings.HasPrefix(lines[end], "}") {
end++
}
if isList(lines[start]) {
kind = spanList
} else {
kind = spanCode
}
} else {
// Unindented. Ends at next blank or indented line.
i++
for i < len(lines) && lines[i] != "" && !indented(lines[i]) {
i++
}
end = i
// If unindented lines are followed (without a blank line)
// by an indented line that would start a code block,
// check whether the final unindented lines
// should be left for the indented section.
// This can happen for the common mistakes of
// unindented code or unindented lists.
// The heuristic will never trigger on a gofmt'ed comment,
// because any gofmt'ed code block would have a blank line
// preceding it after the unindented lines.
if i < len(lines) && lines[i] != "" && !isList(lines[i]) {
switch {
case isList(lines[i-1]):
// If the final unindented line looks like a list item,
// this may be the first indented line wrap of
// a mistakenly unindented list.
// Leave all the unindented list items.
forceIndent = end
end--
for end > start && isList(lines[end-1]) {
end--
}
case strings.HasSuffix(lines[i-1], "{") || strings.HasSuffix(lines[i-1], `\`):
// If the final unindented line ended in { or \
// it is probably the start of a misindented code block.
// Give the user a single line fix.
// Often that's enough; if not, the user can fix the others themselves.
forceIndent = end
end--
}
if start == end && forceIndent > start {
i = start
continue Spans
}
}
// Span is either paragraph or heading.
if end-start == 1 && isHeading(lines[start]) {
kind = spanHeading
} else if end-start == 1 && isOldHeading(lines[start], lines, start) {
kind = spanOldHeading
} else {
kind = spanPara
}
}
spans = append(spans, span{start, end, kind})
i = end
}
return spans
}
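// For illustration (an assumed input, not from the original source):
//
//	parseSpans([]string{"Intro text.", "", "\tfmt.Println(1)", "\tfmt.Println(2)"})
//
// returns the spans {0, 1, spanPara} and {2, 4, spanCode}.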
// indented reports whether line is indented
// (starts with a leading space or tab).
func indented(line string) bool {
return line != "" && (line[0] == ' ' || line[0] == '\t')
}
// unindent removes any common space/tab prefix
// from each line in lines, returning a copy of lines in which
// those prefixes have been trimmed from each line.
// It also replaces any lines containing only spaces with blank lines (empty strings).
func unindent(lines []string) []string {
// Trim leading and trailing blank lines.
for len(lines) > 0 && isBlank(lines[0]) {
lines = lines[1:]
}
for len(lines) > 0 && isBlank(lines[len(lines)-1]) {
lines = lines[:len(lines)-1]
}
if len(lines) == 0 {
return nil
}
// Compute and remove common indentation.
prefix := leadingSpace(lines[0])
for _, line := range lines[1:] {
if !isBlank(line) {
prefix = commonPrefix(prefix, leadingSpace(line))
}
}
out := make([]string, len(lines))
for i, line := range lines {
line = strings.TrimPrefix(line, prefix)
if strings.TrimSpace(line) == "" {
line = ""
}
out[i] = line
}
for len(out) > 0 && out[0] == "" {
out = out[1:]
}
for len(out) > 0 && out[len(out)-1] == "" {
out = out[:len(out)-1]
}
return out
}
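// For example, unindent([]string{"\t\ta", "", "\t\t\tb"}) returns
// []string{"a", "", "\tb"}: the common "\t\t" prefix is removed and the
// all-blank middle line stays blank.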
// isBlank reports whether s is a blank line.
func isBlank(s string) bool {
return len(s) == 0 || (len(s) == 1 && s[0] == '\n')
}
// commonPrefix returns the longest common prefix of a and b.
func commonPrefix(a, b string) string {
i := 0
for i < len(a) && i < len(b) && a[i] == b[i] {
i++
}
return a[0:i]
}
// leadingSpace returns the longest prefix of s consisting of spaces and tabs.
func leadingSpace(s string) string {
i := 0
for i < len(s) && (s[i] == ' ' || s[i] == '\t') {
i++
}
return s[:i]
}
// isOldHeading reports whether line is an old-style section heading.
// line is all[off].
func isOldHeading(line string, all []string, off int) bool {
if off <= 0 || all[off-1] != "" || off+2 >= len(all) || all[off+1] != "" || leadingSpace(all[off+2]) != "" {
return false
}
line = strings.TrimSpace(line)
// a heading must start with an uppercase letter
r, _ := utf8.DecodeRuneInString(line)
if !unicode.IsLetter(r) || !unicode.IsUpper(r) {
return false
}
// it must end in a letter or digit:
r, _ = utf8.DecodeLastRuneInString(line)
if !unicode.IsLetter(r) && !unicode.IsDigit(r) {
return false
}
// exclude lines with illegal characters. we allow "(),"
if strings.ContainsAny(line, ";:!?+*/=[]{}_^°&§~%#@<\">\\") {
return false
}
// allow "'" for possessive "'s" only
for b := line; ; {
var ok bool
if _, b, ok = strings.Cut(b, "'"); !ok {
break
}
if b != "s" && !strings.HasPrefix(b, "s ") {
return false // ' not followed by s and then end-of-word
}
}
// allow "." when followed by non-space
for b := line; ; {
var ok bool
if _, b, ok = strings.Cut(b, "."); !ok {
break
}
if b == "" || strings.HasPrefix(b, " ") {
return false // not followed by non-space
}
}
return true
}
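// For example, a line "Examples" with a blank line before it, a blank line
// after it, and an unindented line after that qualifies; "examples"
// (no leading capital) and "Examples:" (illegal trailing ':') do not.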
// oldHeading returns the *Heading for the given old-style section heading line.
func (d *parseDoc) oldHeading(line string) Block {
return &Heading{Text: []Text{Plain(strings.TrimSpace(line))}}
}
// isHeading reports whether line is a new-style section heading.
func isHeading(line string) bool {
return len(line) >= 2 &&
line[0] == '#' &&
(line[1] == ' ' || line[1] == '\t') &&
strings.TrimSpace(line) != "#"
}
// heading returns the *Heading for the given new-style section heading line.
func (d *parseDoc) heading(line string) Block {
return &Heading{Text: []Text{Plain(strings.TrimSpace(line[1:]))}}
}
// code returns a code block built from the lines.
func (d *parseDoc) code(lines []string) *Code {
body := unindent(lines)
body = append(body, "") // to get final \n from Join
return &Code{Text: strings.Join(body, "\n")}
}
// paragraph returns a paragraph block built from the lines.
// If the lines are link definitions, paragraph adds them to d and returns nil.
func (d *parseDoc) paragraph(lines []string) Block {
// Is this a block of known links? Handle.
var defs []*LinkDef
for _, line := range lines {
def, ok := parseLink(line)
if !ok {
goto NoDefs
}
defs = append(defs, def)
}
for _, def := range defs {
d.Links = append(d.Links, def)
if d.links[def.Text] == nil {
d.links[def.Text] = def
}
}
return nil
NoDefs:
return &Paragraph{Text: []Text{Plain(strings.Join(lines, "\n"))}}
}
// parseLink parses a single link definition line:
//
// [text]: url
//
// It returns the link definition and whether the line was well formed.
func parseLink(line string) (*LinkDef, bool) {
if line == "" || line[0] != '[' {
return nil, false
}
i := strings.Index(line, "]:")
if i < 0 || i+3 >= len(line) || (line[i+2] != ' ' && line[i+2] != '\t') {
return nil, false
}
text := line[1:i]
url := strings.TrimSpace(line[i+3:])
j := strings.Index(url, "://")
if j < 0 || !isScheme(url[:j]) {
return nil, false
}
// Line has right form and has valid scheme://.
// That's good enough for us - we are not as picky
// about the characters beyond the :// as we are
// when extracting inline URLs from text.
return &LinkDef{Text: text, URL: url}, true
}
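// For example, parseLink(`[Go home page]: https://go.dev`) returns
// &LinkDef{Text: "Go home page", URL: "https://go.dev"}, true.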
// list returns a list built from the indented lines,
// using forceBlankBefore as the value of the List's ForceBlankBefore field.
func (d *parseDoc) list(lines []string, forceBlankBefore bool) *List {
num, _, _ := listMarker(lines[0])
var (
list *List = &List{ForceBlankBefore: forceBlankBefore}
item *ListItem
text []string
)
flush := func() {
if item != nil {
if para := d.paragraph(text); para != nil {
item.Content = append(item.Content, para)
}
}
text = nil
}
for _, line := range lines {
if n, after, ok := listMarker(line); ok && (n != "") == (num != "") {
// start new list item
flush()
item = &ListItem{Number: n}
list.Items = append(list.Items, item)
line = after
}
line = strings.TrimSpace(line)
if line == "" {
list.ForceBlankBetween = true
flush()
continue
}
text = append(text, strings.TrimSpace(line))
}
flush()
return list
}
// listMarker parses the line as beginning with a list marker.
// If it can do that, it returns the numeric marker ("" for a bullet list),
// the rest of the line, and ok == true.
// Otherwise, it returns "", "", false.
func listMarker(line string) (num, rest string, ok bool) {
line = strings.TrimSpace(line)
if line == "" {
return "", "", false
}
// Can we find a marker?
if r, n := utf8.DecodeRuneInString(line); r == '•' || r == '*' || r == '+' || r == '-' {
num, rest = "", line[n:]
} else if '0' <= line[0] && line[0] <= '9' {
n := 1
for n < len(line) && '0' <= line[n] && line[n] <= '9' {
n++
}
if n >= len(line) || (line[n] != '.' && line[n] != ')') {
return "", "", false
}
num, rest = line[:n], line[n+1:]
} else {
return "", "", false
}
if !indented(rest) || strings.TrimSpace(rest) == "" {
return "", "", false
}
return num, rest, true
}
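// For example, listMarker("  3. three") returns "3", " three", true,
// and listMarker("  - item") returns "", " item", true.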
// isList reports whether the line is the first line of a list,
// meaning starts with a list marker after any indentation.
// (The caller is responsible for checking the line is indented, as appropriate.)
func isList(line string) bool {
_, _, ok := listMarker(line)
return ok
}
// parseLinkedText parses text that is allowed to contain explicit links,
// such as [math.Sin] or [Go home page], into a slice of Text items.
//
// A “pkg” is only assumed to be a full import path if it starts with
// a domain name (a path element with a dot) or is one of the packages
// from the standard library (“[os]”, “[encoding/json]”, and so on).
// To avoid problems with maps, generics, and array types, doc links
// must be both preceded and followed by punctuation, spaces, tabs,
// or the start or end of a line. An example problem would be treating
// map[ast.Expr]TypeAndValue as containing a link.
func (d *parseDoc) parseLinkedText(text string) []Text {
var out []Text
wrote := 0
flush := func(i int) {
if wrote < i {
out = d.parseText(out, text[wrote:i], true)
wrote = i
}
}
start := -1
var buf []byte
for i := 0; i < len(text); i++ {
c := text[i]
if c == '\n' || c == '\t' {
c = ' '
}
switch c {
case '[':
start = i
case ']':
if start >= 0 {
if def, ok := d.links[string(buf)]; ok {
def.Used = true
flush(start)
out = append(out, &Link{
Text: d.parseText(nil, text[start+1:i], false),
URL: def.URL,
})
wrote = i + 1
} else if link, ok := d.docLink(text[start+1:i], text[:start], text[i+1:]); ok {
flush(start)
link.Text = d.parseText(nil, text[start+1:i], false)
out = append(out, link)
wrote = i + 1
}
}
start = -1
buf = buf[:0]
}
if start >= 0 && i != start {
buf = append(buf, c)
}
}
flush(len(text))
return out
}
// docLink parses text, which was found inside [ ] brackets,
// as a doc link if possible, returning the DocLink and ok == true
// or else nil, false.
// The before and after strings are the text before the [ and after the ]
// on the same line. Doc links must be preceded and followed by
// punctuation, spaces, tabs, or the start or end of a line.
func (d *parseDoc) docLink(text, before, after string) (link *DocLink, ok bool) {
if before != "" {
r, _ := utf8.DecodeLastRuneInString(before)
if !unicode.IsPunct(r) && r != ' ' && r != '\t' && r != '\n' {
return nil, false
}
}
if after != "" {
r, _ := utf8.DecodeRuneInString(after)
if !unicode.IsPunct(r) && r != ' ' && r != '\t' && r != '\n' {
return nil, false
}
}
text = strings.TrimPrefix(text, "*")
pkg, name, ok := splitDocName(text)
var recv string
if ok {
pkg, recv, _ = splitDocName(pkg)
}
if pkg != "" {
if pkg, ok = d.lookupPkg(pkg); !ok {
return nil, false
}
} else {
if ok = d.lookupSym(recv, name); !ok {
return nil, false
}
}
link = &DocLink{
ImportPath: pkg,
Recv: recv,
Name: name,
}
return link, true
}
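// For example, in the text "see [math.Sqrt]." the bracketed "math.Sqrt"
// (with before "see " and after ".") yields
// &DocLink{ImportPath: "math", Name: "Sqrt"}, true,
// because "math" resolves through lookupPkg and "Sqrt" is a capitalized name.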
// If text is of the form before.Name, where Name is a capitalized Go identifier,
// then splitDocName returns before, name, true.
// Otherwise it returns text, "", false.
func splitDocName(text string) (before, name string, foundDot bool) {
i := strings.LastIndex(text, ".")
name = text[i+1:]
if !isName(name) {
return text, "", false
}
if i >= 0 {
before = text[:i]
}
return before, name, true
}
// parseText parses s as text and returns the result of appending
// those parsed Text elements to out.
// parseText does not handle explicit links like [math.Sin] or [Go home page]:
// those are handled by parseLinkedText.
// If autoLink is true, then parseText recognizes URLs and words from d.Words
// and converts those to links as appropriate.
func (d *parseDoc) parseText(out []Text, s string, autoLink bool) []Text {
var w strings.Builder
wrote := 0
writeUntil := func(i int) {
w.WriteString(s[wrote:i])
wrote = i
}
flush := func(i int) {
writeUntil(i)
if w.Len() > 0 {
out = append(out, Plain(w.String()))
w.Reset()
}
}
for i := 0; i < len(s); {
t := s[i:]
if autoLink {
if url, ok := autoURL(t); ok {
flush(i)
// Note: The old comment parser would look up the URL in words
// and replace the target with words[URL] if it was non-empty.
// That would allow creating links that display as one URL but
// when clicked go to a different URL. Not sure what the point
// of that is, so we're not doing that lookup here.
out = append(out, &Link{Auto: true, Text: []Text{Plain(url)}, URL: url})
i += len(url)
wrote = i
continue
}
if id, ok := ident(t); ok {
url, italics := d.Words[id]
if !italics {
i += len(id)
continue
}
flush(i)
if url == "" {
out = append(out, Italic(id))
} else {
out = append(out, &Link{Auto: true, Text: []Text{Italic(id)}, URL: url})
}
i += len(id)
wrote = i
continue
}
}
switch {
case strings.HasPrefix(t, "``"):
if len(t) >= 3 && t[2] == '`' {
// Do not convert `` inside ```, in case people are mistakenly writing Markdown.
i += 3
for i < len(t) && t[i] == '`' {
i++
}
break
}
writeUntil(i)
w.WriteRune('“')
i += 2
wrote = i
case strings.HasPrefix(t, "''"):
writeUntil(i)
w.WriteRune('”')
i += 2
wrote = i
default:
i++
}
}
flush(len(s))
return out
}
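// For illustration, assuming d.Words is empty, parsing the text
//
//	Use ``gofmt'' on https://go.dev code.
//
// with autoLink set yields Plain("Use “gofmt” on "), an automatic *Link for
// https://go.dev, and Plain(" code.").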
// autoURL checks whether s begins with a URL that should be hyperlinked.
// If so, it returns the URL, which is a prefix of s, and ok == true.
// Otherwise it returns "", false.
// The caller should skip over the first len(url) bytes of s
// before further processing.
func autoURL(s string) (url string, ok bool) {
// Find the ://. Fast path to pick off non-URL,
// since we call this at every position in the string.
// The shortest possible URL is ftp://x, 7 bytes.
var i int
switch {
case len(s) < 7:
return "", false
case s[3] == ':':
i = 3
case s[4] == ':':
i = 4
case s[5] == ':':
i = 5
case s[6] == ':':
i = 6
default:
return "", false
}
if i+3 > len(s) || s[i:i+3] != "://" {
return "", false
}
// Check valid scheme.
if !isScheme(s[:i]) {
return "", false
}
// Scan host part. Must have at least one byte,
// and must start and end in non-punctuation.
i += 3
if i >= len(s) || !isHost(s[i]) || isPunct(s[i]) {
return "", false
}
i++
end := i
for i < len(s) && isHost(s[i]) {
if !isPunct(s[i]) {
end = i + 1
}
i++
}
i = end
// At this point we are definitely returning a URL (scheme://host).
// We just have to find the longest path we can add to it.
// Heuristics abound.
// We allow parens, braces, and brackets,
// but only if they match (#5043, #22285).
// We allow .,:;?! in the path but not at the end,
// to avoid end-of-sentence punctuation (#18139, #16565).
stk := []byte{}
end = i
Path:
for ; i < len(s); i++ {
if isPunct(s[i]) {
continue
}
if !isPath(s[i]) {
break
}
switch s[i] {
case '(':
stk = append(stk, ')')
case '{':
stk = append(stk, '}')
case '[':
stk = append(stk, ']')
case ')', '}', ']':
if len(stk) == 0 || stk[len(stk)-1] != s[i] {
break Path
}
stk = stk[:len(stk)-1]
}
if len(stk) == 0 {
end = i + 1
}
}
return s[:end], true
}
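// For example, autoURL("https://go.dev/doc. Next") returns "https://go.dev/doc", true:
// the trailing period is treated as end-of-sentence punctuation and excluded.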
// isScheme reports whether s is a recognized URL scheme.
// Note that if strings of new length (beyond 3-7)
// are added here, the fast path at the top of autoURL will need updating.
func isScheme(s string) bool {
switch s {
case "file",
"ftp",
"gopher",
"http",
"https",
"mailto",
"nntp":
return true
}
return false
}
// isHost reports whether c is a byte that can appear in a URL host,
// like www.example.com or user@[::1]:8080
func isHost(c byte) bool {
// mask is a 128-bit bitmap with 1s for allowed bytes,
// so that the byte c can be tested with a shift and an and.
// If c > 128, then 1<<c and 1<<(c-64) will both be zero,
// and this function will return false.
const mask = 0 |
(1<<26-1)<<'A' |
(1<<26-1)<<'a' |
(1<<10-1)<<'0' |
1<<'_' |
1<<'@' |
1<<'-' |
1<<'.' |
1<<'[' |
1<<']' |
1<<':'
return ((uint64(1)<<c)&(mask&(1<<64-1)) |
(uint64(1)<<(c-64))&(mask>>64)) != 0
}
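// For example, for c == ':' (0x3a, below 64), the first term tests bit 0x3a
// of the low 64 bits of mask, which the 1<<':' term sets, so isHost(':') is true.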
// isPunct reports whether c is a punctuation byte that can appear
// inside a path but not at the end.
func isPunct(c byte) bool {
// mask is a 128-bit bitmap with 1s for allowed bytes,
// so that the byte c can be tested with a shift and an and.
// If c > 128, then 1<<c and 1<<(c-64) will both be zero,
// and this function will return false.
const mask = 0 |
1<<'.' |
1<<',' |
1<<':' |
1<<';' |
1<<'?' |
1<<'!'
return ((uint64(1)<<c)&(mask&(1<<64-1)) |
(uint64(1)<<(c-64))&(mask>>64)) != 0
}
// isPath reports whether c is a (non-punctuation) path byte.
func isPath(c byte) bool {
// mask is a 128-bit bitmap with 1s for allowed bytes,
// so that the byte c can be tested with a shift and an and.
// If c > 128, then 1<<c and 1<<(c-64) will both be zero,
// and this function will return false.
const mask = 0 |
(1<<26-1)<<'A' |
(1<<26-1)<<'a' |
(1<<10-1)<<'0' |
1<<'$' |
1<<'\'' |
1<<'(' |
1<<')' |
1<<'*' |
1<<'+' |
1<<'&' |
1<<'#' |
1<<'=' |
1<<'@' |
1<<'~' |
1<<'_' |
1<<'/' |
1<<'-' |
1<<'[' |
1<<']' |
1<<'{' |
1<<'}' |
1<<'%'
return ((uint64(1)<<c)&(mask&(1<<64-1)) |
(uint64(1)<<(c-64))&(mask>>64)) != 0
}
// isName reports whether s is a capitalized Go identifier (like Name).
func isName(s string) bool {
t, ok := ident(s)
if !ok || t != s {
return false
}
r, _ := utf8.DecodeRuneInString(s)
return unicode.IsUpper(r)
}
// ident checks whether s begins with a Go identifier.
// If so, it returns the identifier, which is a prefix of s, and ok == true.
// Otherwise it returns "", false.
// The caller should skip over the first len(id) bytes of s
// before further processing.
func ident(s string) (id string, ok bool) {
// Scan [\pL_][\pL_0-9]*
n := 0
for n < len(s) {
if c := s[n]; c < utf8.RuneSelf {
if isIdentASCII(c) && (n > 0 || c < '0' || c > '9') {
n++
continue
}
break
}
r, nr := utf8.DecodeRuneInString(s[n:])
if unicode.IsLetter(r) {
n += nr
continue
}
break
}
return s[:n], n > 0
}
// isIdentASCII reports whether c is an ASCII identifier byte.
func isIdentASCII(c byte) bool {
// mask is a 128-bit bitmap with 1s for allowed bytes,
// so that the byte c can be tested with a shift and an and.
// If c > 128, then 1<<c and 1<<(c-64) will both be zero,
// and this function will return false.
const mask = 0 |
(1<<26-1)<<'A' |
(1<<26-1)<<'a' |
(1<<10-1)<<'0' |
1<<'_'
return ((uint64(1)<<c)&(mask&(1<<64-1)) |
(uint64(1)<<(c-64))&(mask>>64)) != 0
}
// validImportPath reports whether path is a valid import path.
// It is a lightly edited copy of golang.org/x/mod/module.CheckImportPath.
func validImportPath(path string) bool {
if !utf8.ValidString(path) {
return false
}
if path == "" {
return false
}
if path[0] == '-' {
return false
}
if strings.Contains(path, "//") {
return false
}
if path[len(path)-1] == '/' {
return false
}
elemStart := 0
for i, r := range path {
if r == '/' {
if !validImportPathElem(path[elemStart:i]) {
return false
}
elemStart = i + 1
}
}
return validImportPathElem(path[elemStart:])
}
func validImportPathElem(elem string) bool {
if elem == "" || elem[0] == '.' || elem[len(elem)-1] == '.' {
return false
}
for i := 0; i < len(elem); i++ {
if !importPathOK(elem[i]) {
return false
}
}
return true
}
func importPathOK(c byte) bool {
// mask is a 128-bit bitmap with 1s for allowed bytes,
// so that the byte c can be tested with a shift and an and.
// If c > 128, then 1<<c and 1<<(c-64) will both be zero,
// and this function will return false.
const mask = 0 |
(1<<26-1)<<'A' |
(1<<26-1)<<'a' |
(1<<10-1)<<'0' |
1<<'-' |
1<<'.' |
1<<'~' |
1<<'_' |
1<<'+'
return ((uint64(1)<<c)&(mask&(1<<64-1)) |
(uint64(1)<<(c-64))&(mask>>64)) != 0
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package comment
import (
"bytes"
"fmt"
"strings"
)
// A Printer is a doc comment printer.
// The fields in the struct can be filled in before calling
// any of the printing methods
// in order to customize the details of the printing process.
type Printer struct {
// HeadingLevel is the nesting level used for
// HTML and Markdown headings.
// If HeadingLevel is zero, it defaults to level 3,
// meaning to use <h3> and ###.
HeadingLevel int
// HeadingID is a function that computes the heading ID
// (anchor tag) to use for the heading h when generating
// HTML and Markdown. If HeadingID returns an empty string,
// then the heading ID is omitted.
// If HeadingID is nil, h.DefaultID is used.
HeadingID func(h *Heading) string
// DocLinkURL is a function that computes the URL for the given DocLink.
// If DocLinkURL is nil, then link.DefaultURL(p.DocLinkBaseURL) is used.
DocLinkURL func(link *DocLink) string
// DocLinkBaseURL is used when DocLinkURL is nil,
// passed to [DocLink.DefaultURL] to construct a DocLink's URL.
// See that method's documentation for details.
DocLinkBaseURL string
// TextPrefix is a prefix to print at the start of every line
// when generating text output using the Text method.
TextPrefix string
// TextCodePrefix is the prefix to print at the start of each
// preformatted (code block) line when generating text output,
// instead of (not in addition to) TextPrefix.
// If TextCodePrefix is the empty string, it defaults to TextPrefix+"\t".
TextCodePrefix string
// TextWidth is the maximum width text line to generate,
// measured in Unicode code points,
// excluding TextPrefix and the newline character.
// If TextWidth is zero, it defaults to 80 minus the number of code points in TextPrefix.
// If TextWidth is negative, there is no limit.
TextWidth int
}
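// A minimal usage sketch (illustrative; d is a previously parsed *Doc):
//
//	pr := &Printer{TextPrefix: "// ", TextWidth: 60}
//	out := pr.Text(d) // plain text, wrapped to 60 code points after the prefix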
func (p *Printer) headingLevel() int {
if p.HeadingLevel <= 0 {
return 3
}
return p.HeadingLevel
}
func (p *Printer) headingID(h *Heading) string {
if p.HeadingID == nil {
return h.DefaultID()
}
return p.HeadingID(h)
}
func (p *Printer) docLinkURL(link *DocLink) string {
if p.DocLinkURL != nil {
return p.DocLinkURL(link)
}
return link.DefaultURL(p.DocLinkBaseURL)
}
// DefaultURL constructs and returns the documentation URL for l,
// using baseURL as a prefix for links to other packages.
//
// The possible forms returned by DefaultURL are:
// - baseURL/ImportPath, for a link to another package
// - baseURL/ImportPath#Name, for a link to a const, func, type, or var in another package
// - baseURL/ImportPath#Recv.Name, for a link to a method in another package
// - #Name, for a link to a const, func, type, or var in this package
// - #Recv.Name, for a link to a method in this package
//
// If baseURL ends in a trailing slash, then DefaultURL inserts
// a slash between ImportPath and # in the anchored forms.
// For example, here are some baseURL values and URLs they can generate:
//
// "/pkg/" → "/pkg/math/#Sqrt"
// "/pkg" → "/pkg/math#Sqrt"
// "/" → "/math/#Sqrt"
// "" → "/math#Sqrt"
func (l *DocLink) DefaultURL(baseURL string) string {
if l.ImportPath != "" {
slash := ""
if strings.HasSuffix(baseURL, "/") {
slash = "/"
} else {
baseURL += "/"
}
switch {
case l.Name == "":
return baseURL + l.ImportPath + slash
case l.Recv != "":
return baseURL + l.ImportPath + slash + "#" + l.Recv + "." + l.Name
default:
return baseURL + l.ImportPath + slash + "#" + l.Name
}
}
if l.Recv != "" {
return "#" + l.Recv + "." + l.Name
}
return "#" + l.Name
}
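// For example (illustrative):
//
//	l := &DocLink{ImportPath: "math", Name: "Sqrt"}
//	l.DefaultURL("/pkg/") // returns "/pkg/math/#Sqrt"
//	l.DefaultURL("/pkg")  // returns "/pkg/math#Sqrt"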
// DefaultID returns the default anchor ID for the heading h.
//
// The default anchor ID is constructed by converting every
// rune that is not alphanumeric ASCII to an underscore
// and then adding the prefix “hdr-”.
// For example, if the heading text is “Go Doc Comments”,
// the default ID is “hdr-Go_Doc_Comments”.
func (h *Heading) DefaultID() string {
// Note: The “hdr-” prefix is important to avoid DOM clobbering attacks.
// See https://pkg.go.dev/github.com/google/safehtml#Identifier.
var out strings.Builder
var p textPrinter
p.oneLongLine(&out, h.Text)
s := strings.TrimSpace(out.String())
if s == "" {
return ""
}
out.Reset()
out.WriteString("hdr-")
for _, r := range s {
if r < 0x80 && isIdentASCII(byte(r)) {
out.WriteByte(byte(r))
} else {
out.WriteByte('_')
}
}
return out.String()
}
type commentPrinter struct {
*Printer
headingPrefix string
needDoc map[string]bool
}
// Comment returns the standard Go formatting of the Doc,
// without any comment markers.
func (p *Printer) Comment(d *Doc) []byte {
cp := &commentPrinter{Printer: p}
var out bytes.Buffer
for i, x := range d.Content {
if i > 0 && blankBefore(x) {
out.WriteString("\n")
}
cp.block(&out, x)
}
// Print one block containing all the link definitions that were used,
// and then a second block containing all the unused ones.
// This makes it easy to clean up the unused ones: gofmt and
// delete the final block. And it's a nice visual signal without
// affecting the way the comment formats for users.
for i := 0; i < 2; i++ {
used := i == 0
first := true
for _, def := range d.Links {
if def.Used == used {
if first {
out.WriteString("\n")
first = false
}
out.WriteString("[")
out.WriteString(def.Text)
out.WriteString("]: ")
out.WriteString(def.URL)
out.WriteString("\n")
}
}
}
return out.Bytes()
}
// blankBefore reports whether the block x requires a blank line before it.
// All blocks do, except for Lists that return false from x.BlankBefore().
func blankBefore(x Block) bool {
if x, ok := x.(*List); ok {
return x.BlankBefore()
}
return true
}
// block prints the block x to out.
func (p *commentPrinter) block(out *bytes.Buffer, x Block) {
switch x := x.(type) {
default:
fmt.Fprintf(out, "?%T", x)
case *Paragraph:
p.text(out, "", x.Text)
out.WriteString("\n")
case *Heading:
out.WriteString("# ")
p.text(out, "", x.Text)
out.WriteString("\n")
case *Code:
md := x.Text
for md != "" {
var line string
line, md, _ = strings.Cut(md, "\n")
if line != "" {
out.WriteString("\t")
out.WriteString(line)
}
out.WriteString("\n")
}
case *List:
loose := x.BlankBetween()
for i, item := range x.Items {
if i > 0 && loose {
out.WriteString("\n")
}
out.WriteString(" ")
if item.Number == "" {
out.WriteString(" - ")
} else {
out.WriteString(item.Number)
out.WriteString(". ")
}
for i, blk := range item.Content {
const fourSpace = " "
if i > 0 {
out.WriteString("\n" + fourSpace)
}
p.text(out, fourSpace, blk.(*Paragraph).Text)
out.WriteString("\n")
}
}
}
}
// text prints the text sequence x to out.
func (p *commentPrinter) text(out *bytes.Buffer, indent string, x []Text) {
for _, t := range x {
switch t := t.(type) {
case Plain:
p.indent(out, indent, string(t))
case Italic:
p.indent(out, indent, string(t))
case *Link:
if t.Auto {
p.text(out, indent, t.Text)
} else {
out.WriteString("[")
p.text(out, indent, t.Text)
out.WriteString("]")
}
case *DocLink:
out.WriteString("[")
p.text(out, indent, t.Text)
out.WriteString("]")
}
}
}
// indent prints s to out, indenting with the indent string
// after each newline in s.
func (p *commentPrinter) indent(out *bytes.Buffer, indent, s string) {
for s != "" {
line, rest, ok := strings.Cut(s, "\n")
out.WriteString(line)
if ok {
out.WriteString("\n")
out.WriteString(indent)
}
s = rest
}
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package comment
import (
"bytes"
"fmt"
"sort"
"strings"
"unicode/utf8"
)
// A textPrinter holds the state needed for printing a Doc as plain text.
type textPrinter struct {
*Printer
long strings.Builder
prefix string
codePrefix string
width int
}
// Text returns a textual formatting of the Doc.
// See the [Printer] documentation for ways to customize the text output.
func (p *Printer) Text(d *Doc) []byte {
tp := &textPrinter{
Printer: p,
prefix: p.TextPrefix,
codePrefix: p.TextCodePrefix,
width: p.TextWidth,
}
if tp.codePrefix == "" {
tp.codePrefix = p.TextPrefix + "\t"
}
if tp.width == 0 {
tp.width = 80 - utf8.RuneCountInString(tp.prefix)
}
var out bytes.Buffer
for i, x := range d.Content {
if i > 0 && blankBefore(x) {
out.WriteString(tp.prefix)
writeNL(&out)
}
tp.block(&out, x)
}
anyUsed := false
for _, def := range d.Links {
if def.Used {
anyUsed = true
break
}
}
if anyUsed {
writeNL(&out)
for _, def := range d.Links {
if def.Used {
fmt.Fprintf(&out, "[%s]: %s\n", def.Text, def.URL)
}
}
}
return out.Bytes()
}
// writeNL calls out.WriteByte('\n')
// but first trims trailing spaces on the previous line.
func writeNL(out *bytes.Buffer) {
// Trim trailing spaces.
data := out.Bytes()
n := 0
for n < len(data) && (data[len(data)-n-1] == ' ' || data[len(data)-n-1] == '\t') {
n++
}
if n > 0 {
out.Truncate(len(data) - n)
}
out.WriteByte('\n')
}
// block prints the block x to out.
func (p *textPrinter) block(out *bytes.Buffer, x Block) {
switch x := x.(type) {
default:
fmt.Fprintf(out, "?%T\n", x)
case *Paragraph:
out.WriteString(p.prefix)
p.text(out, "", x.Text)
case *Heading:
out.WriteString(p.prefix)
out.WriteString("# ")
p.text(out, "", x.Text)
case *Code:
text := x.Text
for text != "" {
var line string
line, text, _ = strings.Cut(text, "\n")
if line != "" {
out.WriteString(p.codePrefix)
out.WriteString(line)
}
writeNL(out)
}
case *List:
loose := x.BlankBetween()
for i, item := range x.Items {
if i > 0 && loose {
out.WriteString(p.prefix)
writeNL(out)
}
out.WriteString(p.prefix)
out.WriteString(" ")
if item.Number == "" {
out.WriteString(" - ")
} else {
out.WriteString(item.Number)
out.WriteString(". ")
}
for i, blk := range item.Content {
const fourSpace = " "
if i > 0 {
writeNL(out)
out.WriteString(p.prefix)
out.WriteString(fourSpace)
}
p.text(out, fourSpace, blk.(*Paragraph).Text)
}
}
}
}
// text prints the text sequence x to out.
func (p *textPrinter) text(out *bytes.Buffer, indent string, x []Text) {
p.oneLongLine(&p.long, x)
words := strings.Fields(p.long.String())
p.long.Reset()
var seq []int
if p.width < 0 || len(words) == 0 {
seq = []int{0, len(words)} // one long line
} else {
seq = wrap(words, p.width-utf8.RuneCountInString(indent))
}
for i := 0; i+1 < len(seq); i++ {
if i > 0 {
out.WriteString(p.prefix)
out.WriteString(indent)
}
for j, w := range words[seq[i]:seq[i+1]] {
if j > 0 {
out.WriteString(" ")
}
out.WriteString(w)
}
writeNL(out)
}
}
// oneLongLine prints the text sequence x to out as one long line,
// without worrying about line wrapping.
// Explicit links have the [ ] dropped to improve readability.
func (p *textPrinter) oneLongLine(out *strings.Builder, x []Text) {
for _, t := range x {
switch t := t.(type) {
case Plain:
out.WriteString(string(t))
case Italic:
out.WriteString(string(t))
case *Link:
p.oneLongLine(out, t.Text)
case *DocLink:
p.oneLongLine(out, t.Text)
}
}
}
// wrap wraps words into lines of at most max runes,
// minimizing the sum of the squares of the leftover lengths
// at the end of each line (except the last, of course),
// with a preference for ending lines at punctuation (.,:;).
//
// The returned slice gives the indexes of the first words
// on each line in the wrapped text with a final entry of len(words).
// Thus the lines are words[seq[0]:seq[1]], words[seq[1]:seq[2]],
// ..., words[seq[len(seq)-2]:seq[len(seq)-1]].
//
// The implementation runs in O(n log n) time, where n = len(words),
// using the algorithm described in D. S. Hirschberg and L. L. Larmore,
// “[The least weight subsequence problem],” FOCS 1985, pp. 137-143.
//
// [The least weight subsequence problem]: https://doi.org/10.1109/SFCS.1985.60
func wrap(words []string, max int) (seq []int) {
// The algorithm requires that our scoring function be concave,
// meaning that for all i₀ ≤ i₁ < j₀ ≤ j₁,
// weight(i₀, j₀) + weight(i₁, j₁) ≤ weight(i₀, j₁) + weight(i₁, j₀).
//
// Our weights are two-element pairs [hi, lo]
// ordered by elementwise comparison.
// The hi entry counts the weight for lines that are longer than max,
// and the lo entry counts the weight for lines that are not.
// This forces the algorithm to first minimize the number of lines
// that are longer than max, which correspond to lines with
// single very long words. Having done that, it can move on to
// minimizing the lo score, which is more interesting.
//
// The lo score is the sum for each line of the square of the
// number of spaces remaining at the end of the line and a
// penalty of 64 given out for not ending the line in a
// punctuation character (.,:;).
// The penalty is somewhat arbitrarily chosen by trying
// different amounts and judging how nice the wrapped text looks.
// Roughly speaking, using 64 means that we are willing to
// end a line with eight blank spaces in order to end at a
// punctuation character, even if the next word would fit in
// those spaces.
//
// We care about ending in punctuation characters because
// it makes the text easier to skim if not too many sentences
// or phrases begin with a single word on the previous line.
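// For example, with max = 80, a non-final line whose words and separating
// spaces total 72 runes and whose last word does not end in .,:; scores
// lo = (80-72)² + 64 = 128.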
// A score is the score (also called weight) for a given line.
// add and cmp add and compare scores.
type score struct {
hi int64
lo int64
}
add := func(s, t score) score { return score{s.hi + t.hi, s.lo + t.lo} }
cmp := func(s, t score) int {
switch {
case s.hi < t.hi:
return -1
case s.hi > t.hi:
return +1
case s.lo < t.lo:
return -1
case s.lo > t.lo:
return +1
}
return 0
}
// total[j] is the total number of runes
// (including separating spaces) in words[:j].
total := make([]int, len(words)+1)
total[0] = 0
for i, s := range words {
total[1+i] = total[i] + utf8.RuneCountInString(s) + 1
}
// weight returns weight(i, j).
weight := func(i, j int) score {
// On the last line, there is zero weight for being too short.
n := total[j] - 1 - total[i]
if j == len(words) && n <= max {
return score{0, 0}
}
// Otherwise the weight is the penalty plus the square of the number of
// characters remaining on the line or by which the line goes over.
// In the latter case, that value goes in the hi part of the score.
// (See note above.)
p := wrapPenalty(words[j-1])
v := int64(max-n) * int64(max-n)
if n > max {
return score{v, p}
}
return score{0, v + p}
}
// The rest of this function is “The Basic Algorithm” from
// Hirschberg and Larmore's conference paper,
// using the same names as in the paper.
f := []score{{0, 0}}
g := func(i, j int) score { return add(f[i], weight(i, j)) }
bridge := func(a, b, c int) bool {
k := c + sort.Search(len(words)+1-c, func(k int) bool {
k += c
return cmp(g(a, k), g(b, k)) > 0
})
if k > len(words) {
return true
}
return cmp(g(c, k), g(b, k)) <= 0
}
// d is a one-ended deque implemented as a slice.
d := make([]int, 1, len(words))
d[0] = 0
bestleft := make([]int, 1, len(words))
bestleft[0] = -1
for m := 1; m < len(words); m++ {
f = append(f, g(d[0], m))
bestleft = append(bestleft, d[0])
for len(d) > 1 && cmp(g(d[1], m+1), g(d[0], m+1)) <= 0 {
d = d[1:] // “Retire”
}
for len(d) > 1 && bridge(d[len(d)-2], d[len(d)-1], m) {
d = d[:len(d)-1] // “Fire”
}
if cmp(g(m, len(words)), g(d[len(d)-1], len(words))) < 0 {
d = append(d, m) // “Hire”
// The next few lines are not in the paper but are necessary
// to handle two-word inputs correctly. It appears to be
// just a bug in the paper's pseudocode.
if len(d) == 2 && cmp(g(d[1], m+1), g(d[0], m+1)) <= 0 {
d = d[1:]
}
}
}
bestleft = append(bestleft, d[0])
// Recover least weight sequence from bestleft.
n := 1
for m := len(words); m > 0; m = bestleft[m] {
n++
}
seq = make([]int, n)
for m := len(words); m > 0; m = bestleft[m] {
n--
seq[n] = m
}
return seq
}
// wrapPenalty is the penalty for inserting a line break after word s.
func wrapPenalty(s string) int64 {
switch s[len(s)-1] {
case '.', ',', ':', ';':
return 0
}
return 64
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package doc extracts source code documentation from a Go AST.
package doc
import (
"fmt"
"go/ast"
"go/doc/comment"
"go/token"
"strings"
)
// Package is the documentation for an entire package.
type Package struct {
Doc string
Name string
ImportPath string
Imports []string
Filenames []string
Notes map[string][]*Note
// Deprecated: For backward compatibility Bugs is still populated,
// but all new code should use Notes instead.
Bugs []string
// declarations
Consts []*Value
Types []*Type
Vars []*Value
Funcs []*Func
// Examples is a sorted list of examples associated with
// the package. Examples are extracted from _test.go files
// provided to NewFromFiles.
Examples []*Example
importByName map[string]string
syms map[string]bool
}
// Value is the documentation for a (possibly grouped) var or const declaration.
type Value struct {
Doc string
Names []string // var or const names in declaration order
Decl *ast.GenDecl
order int
}
// Type is the documentation for a type declaration.
type Type struct {
Doc string
Name string
Decl *ast.GenDecl
// associated declarations
Consts []*Value // sorted list of constants of (mostly) this type
Vars []*Value // sorted list of variables of (mostly) this type
Funcs []*Func // sorted list of functions returning this type
Methods []*Func // sorted list of methods (including embedded ones) of this type
// Examples is a sorted list of examples associated with
// this type. Examples are extracted from _test.go files
// provided to NewFromFiles.
Examples []*Example
}
// Func is the documentation for a func declaration.
type Func struct {
Doc string
Name string
Decl *ast.FuncDecl
// methods
// (for functions, these fields have the respective zero value)
Recv string // actual receiver "T" or "*T" possibly followed by type parameters [P1, ..., Pn]
Orig string // original receiver "T" or "*T"
Level int // embedding level; 0 means not embedded
// Examples is a sorted list of examples associated with this
// function or method. Examples are extracted from _test.go files
// provided to NewFromFiles.
Examples []*Example
}
// A Note represents a marked comment starting with "MARKER(uid): note body".
// Any note with a marker of 2 or more upper case [A-Z] letters and a uid of
// at least one character is recognized. The ":" following the uid is optional.
// Notes are collected in the Package.Notes map indexed by the notes marker.
type Note struct {
Pos, End token.Pos // position range of the comment containing the marker
UID string // uid found with the marker
Body string // note body text
}
// Mode values control the operation of New and NewFromFiles.
type Mode int
const (
// AllDecls says to extract documentation for all package-level
// declarations, not just exported ones.
AllDecls Mode = 1 << iota
// AllMethods says to show all embedded methods, not just the ones of
// invisible (unexported) anonymous fields.
AllMethods
// PreserveAST says to leave the AST unmodified. Originally, pieces of
// the AST such as function bodies were nil-ed out to save memory in
// godoc, but not all programs want that behavior.
PreserveAST
)
// New computes the package documentation for the given package AST.
// New takes ownership of the AST pkg and may edit or overwrite it.
// To have the Examples fields populated, use NewFromFiles and include
// the package's _test.go files.
func New(pkg *ast.Package, importPath string, mode Mode) *Package {
var r reader
r.readPackage(pkg, mode)
r.computeMethodSets()
r.cleanupTypes()
p := &Package{
Doc: r.doc,
Name: pkg.Name,
ImportPath: importPath,
Imports: sortedKeys(r.imports),
Filenames: r.filenames,
Notes: r.notes,
Bugs: noteBodies(r.notes["BUG"]),
Consts: sortedValues(r.values, token.CONST),
Types: sortedTypes(r.types, mode&AllMethods != 0),
Vars: sortedValues(r.values, token.VAR),
Funcs: sortedFuncs(r.funcs, true),
importByName: r.importByName,
syms: make(map[string]bool),
}
p.collectValues(p.Consts)
p.collectValues(p.Vars)
p.collectTypes(p.Types)
p.collectFuncs(p.Funcs)
return p
}
func (p *Package) collectValues(values []*Value) {
for _, v := range values {
for _, name := range v.Names {
p.syms[name] = true
}
}
}
func (p *Package) collectTypes(types []*Type) {
for _, t := range types {
if p.syms[t.Name] {
// Shouldn't be any cycles but stop just in case.
continue
}
p.syms[t.Name] = true
p.collectValues(t.Consts)
p.collectValues(t.Vars)
p.collectFuncs(t.Funcs)
p.collectFuncs(t.Methods)
}
}
func (p *Package) collectFuncs(funcs []*Func) {
for _, f := range funcs {
if f.Recv != "" {
r := strings.TrimPrefix(f.Recv, "*")
if i := strings.IndexByte(r, '['); i >= 0 {
r = r[:i] // remove type parameters
}
p.syms[r+"."+f.Name] = true
} else {
p.syms[f.Name] = true
}
}
}
// NewFromFiles computes documentation for a package.
//
// The package is specified by a list of *ast.Files and corresponding
// file set, which must not be nil.
// NewFromFiles uses all provided files when computing documentation,
// so it is the caller's responsibility to provide only the files that
// match the desired build context. "go/build".Context.MatchFile can
// be used for determining whether a file matches a build context with
// the desired GOOS and GOARCH values, and other build constraints.
// The import path of the package is specified by importPath.
//
// Examples found in _test.go files are associated with the corresponding
// type, function, method, or the package, based on their name.
// If the example has a suffix in its name, it is set in the
// Example.Suffix field. Examples with malformed names are skipped.
//
// Optionally, a single extra argument of type Mode can be provided to
// control low-level aspects of the documentation extraction behavior.
//
// NewFromFiles takes ownership of the AST files and may edit them,
// unless the PreserveAST Mode bit is on.
func NewFromFiles(fset *token.FileSet, files []*ast.File, importPath string, opts ...any) (*Package, error) {
// Check for invalid API usage.
if fset == nil {
panic(fmt.Errorf("doc.NewFromFiles: no token.FileSet provided (fset == nil)"))
}
var mode Mode
switch len(opts) { // There can only be 0 or 1 options, so a simple switch works for now.
case 0:
// Nothing to do.
case 1:
m, ok := opts[0].(Mode)
if !ok {
panic(fmt.Errorf("doc.NewFromFiles: option argument type must be doc.Mode"))
}
mode = m
default:
panic(fmt.Errorf("doc.NewFromFiles: there must not be more than 1 option argument"))
}
// Collect .go and _test.go files.
var (
goFiles = make(map[string]*ast.File)
testGoFiles []*ast.File
)
for i := range files {
f := fset.File(files[i].Pos())
if f == nil {
return nil, fmt.Errorf("file files[%d] is not found in the provided file set", i)
}
switch name := f.Name(); {
case strings.HasSuffix(name, ".go") && !strings.HasSuffix(name, "_test.go"):
goFiles[name] = files[i]
case strings.HasSuffix(name, "_test.go"):
testGoFiles = append(testGoFiles, files[i])
default:
return nil, fmt.Errorf("file files[%d] filename %q does not have a .go extension", i, name)
}
}
// TODO(dmitshur,gri): A relatively high level call to ast.NewPackage with a simpleImporter
// ast.Importer implementation is made below. It might be possible to short-circuit and simplify.
// Compute package documentation.
pkg, _ := ast.NewPackage(fset, goFiles, simpleImporter, nil) // Ignore errors that can happen due to unresolved identifiers.
p := New(pkg, importPath, mode)
classifyExamples(p, Examples(testGoFiles...))
return p, nil
}
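// A minimal usage sketch (illustrative; the file name and import path are
// hypothetical, and parser is the go/parser package):
//
//	fset := token.NewFileSet()
//	f, err := parser.ParseFile(fset, "demo.go", src, parser.ParseComments)
//	if err != nil { ... }
//	p, err := NewFromFiles(fset, []*ast.File{f}, "example.com/demo")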
// simpleImporter returns a (dummy) package object named by the last path
// component of the provided package path (as is the convention for packages).
// This is sufficient to resolve package identifiers without doing an actual
// import. It never returns an error.
func simpleImporter(imports map[string]*ast.Object, path string) (*ast.Object, error) {
pkg := imports[path]
if pkg == nil {
// note that strings.LastIndex returns -1 if there is no "/"
pkg = ast.NewObj(ast.Pkg, path[strings.LastIndex(path, "/")+1:])
pkg.Data = ast.NewScope(nil) // required by ast.NewPackage for dot-import
imports[path] = pkg
}
return pkg, nil
}
// lookupSym reports whether the package has a given symbol or method.
//
// If recv == "", HasSym reports whether the package has a top-level
// const, func, type, or var named name.
//
// If recv != "", HasSym reports whether the package has a type
// named recv with a method named name.
func (p *Package) lookupSym(recv, name string) bool {
if recv != "" {
return p.syms[recv+"."+name]
}
return p.syms[name]
}
// lookupPackage returns the import path identified by name
// in the given package. If name uniquely identifies a single import,
// then lookupPackage returns that import path.
// If multiple packages are imported as name, lookupPackage returns "", false.
// Otherwise, if name is the name of p itself, lookupPackage returns "", true,
// to signal a reference to p.
// Otherwise, lookupPackage returns "", false.
func (p *Package) lookupPackage(name string) (importPath string, ok bool) {
if path, ok := p.importByName[name]; ok {
if path == "" {
return "", false // multiple imports used the name
}
return path, true // found import
}
if p.Name == name {
return "", true // allow reference to this package
}
return "", false // unknown name
}
// Parser returns a doc comment parser configured
// for parsing doc comments from package p.
// Each call returns a new parser, so that the caller may
// customize it before use.
func (p *Package) Parser() *comment.Parser {
return &comment.Parser{
LookupPackage: p.lookupPackage,
LookupSym: p.lookupSym,
}
}
// Printer returns a doc comment printer configured
// for printing doc comments from package p.
// Each call returns a new printer, so that the caller may
// customize it before use.
func (p *Package) Printer() *comment.Printer {
// No customization today, but having p.Printer()
// gives us flexibility in the future, and it is convenient for callers.
return &comment.Printer{}
}
// HTML returns formatted HTML for the doc comment text.
//
// To customize details of the HTML, use [Package.Printer]
// to obtain a [comment.Printer], and configure it
// before calling its HTML method.
func (p *Package) HTML(text string) []byte {
return p.Printer().HTML(p.Parser().Parse(text))
}
// Markdown returns formatted Markdown for the doc comment text.
//
// To customize details of the Markdown, use [Package.Printer]
// to obtain a [comment.Printer], and configure it
// before calling its Markdown method.
func (p *Package) Markdown(text string) []byte {
return p.Printer().Markdown(p.Parser().Parse(text))
}
// Text returns formatted text for the doc comment text,
// wrapped to 80 Unicode code points and using tabs for
// code block indentation.
//
// To customize details of the formatting, use [Package.Printer]
// to obtain a [comment.Printer], and configure it
// before calling its Text method.
func (p *Package) Text(text string) []byte {
return p.Printer().Text(p.Parser().Parse(text))
}
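// For illustration, given a *Package p built by NewFromFiles:
//
//	html := p.HTML(p.Doc)   // package doc comment rendered as HTML
//	md := p.Markdown(p.Doc) // ... as Markdown
//	txt := p.Text(p.Doc)    // ... as wrapped plain text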
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Extract example functions from file ASTs.
package doc
import (
"go/ast"
"go/token"
"internal/lazyregexp"
"path"
"sort"
"strconv"
"strings"
"unicode"
"unicode/utf8"
)
// An Example represents an example function found in a test source file.
type Example struct {
Name string // name of the item being exemplified (including optional suffix)
Suffix string // example suffix, without leading '_' (only populated by NewFromFiles)
Doc string // example function doc string
Code ast.Node
Play *ast.File // a whole program version of the example
Comments []*ast.CommentGroup
Output string // expected output
Unordered bool
EmptyOutput bool // expect empty output
Order int // original source code order
}
// Examples returns the examples found in testFiles, sorted by Name field.
// The Order fields record the order in which the examples were encountered.
// The Suffix field is not populated when Examples is called directly; it is
// only populated by NewFromFiles for examples it finds in _test.go files.
//
// Playable Examples must be in a package whose name ends in "_test".
// An Example is "playable" (the Play field is non-nil) in either of these
// circumstances:
// - The example function is self-contained: the function references only
// identifiers from other packages (or predeclared identifiers, such as
// "int") and the test file does not include a dot import.
// - The entire test file is the example: the file contains exactly one
// example function, zero test, fuzz test, or benchmark functions, and at
// least one top-level function, type, variable, or constant declaration
// other than the example function.
func Examples(testFiles ...*ast.File) []*Example {
var list []*Example
for _, file := range testFiles {
hasTests := false // file contains tests, fuzz test, or benchmarks
numDecl := 0 // number of non-import declarations in the file
var flist []*Example
for _, decl := range file.Decls {
if g, ok := decl.(*ast.GenDecl); ok && g.Tok != token.IMPORT {
numDecl++
continue
}
f, ok := decl.(*ast.FuncDecl)
if !ok || f.Recv != nil {
continue
}
numDecl++
name := f.Name.Name
if isTest(name, "Test") || isTest(name, "Benchmark") || isTest(name, "Fuzz") {
hasTests = true
continue
}
if !isTest(name, "Example") {
continue
}
if params := f.Type.Params; len(params.List) != 0 {
continue // function has params; not a valid example
}
if f.Body == nil { // guard against a nil function body (see issue 28044)
continue
}
var doc string
if f.Doc != nil {
doc = f.Doc.Text()
}
output, unordered, hasOutput := exampleOutput(f.Body, file.Comments)
flist = append(flist, &Example{
Name: name[len("Example"):],
Doc: doc,
Code: f.Body,
Play: playExample(file, f),
Comments: file.Comments,
Output: output,
Unordered: unordered,
EmptyOutput: output == "" && hasOutput,
Order: len(flist),
})
}
if !hasTests && numDecl > 1 && len(flist) == 1 {
// If this file only has one example function, some
// other top-level declarations, and no tests or
// benchmarks, use the whole file as the example.
flist[0].Code = file
flist[0].Play = playExampleFile(file)
}
list = append(list, flist...)
}
// sort by name
sort.Slice(list, func(i, j int) bool {
return list[i].Name < list[j].Name
})
return list
}
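// For illustration (a hypothetical test file), a file demo_test.go containing
//
//	package demo_test
//
//	import "fmt"
//
//	func ExampleHello() {
//		fmt.Println("hello")
//		// Output: hello
//	}
//
// yields one Example with Name "Hello", Output "hello\n", EmptyOutput false,
// and a non-nil Play, since the only unresolved identifier (fmt) is satisfied
// by the file's imports.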
var outputPrefix = lazyregexp.New(`(?i)^[[:space:]]*(unordered )?output:`)
// exampleOutput extracts the expected output and whether there was a valid output comment.
func exampleOutput(b *ast.BlockStmt, comments []*ast.CommentGroup) (output string, unordered, ok bool) {
if _, last := lastComment(b, comments); last != nil {
// test that it begins with the correct prefix
text := last.Text()
if loc := outputPrefix.FindStringSubmatchIndex(text); loc != nil {
if loc[2] != -1 {
unordered = true
}
text = text[loc[1]:]
// Strip zero or more spaces followed by \n or a single space.
text = strings.TrimLeft(text, " ")
if len(text) > 0 && text[0] == '\n' {
text = text[1:]
}
return text, unordered, true
}
}
return "", false, false // no suitable comment found
}
// isTest tells whether name looks like a test, example, fuzz test, or
// benchmark. It is a Test (say) if there is a character after Test that is not
// a lower-case letter. (We don't want Testiness.)
func isTest(name, prefix string) bool {
if !strings.HasPrefix(name, prefix) {
return false
}
if len(name) == len(prefix) { // "Test" is ok
return true
}
rune, _ := utf8.DecodeRuneInString(name[len(prefix):])
return !unicode.IsLower(rune)
}
// playExample synthesizes a new *ast.File based on the provided
// file with the provided function body as the body of main.
func playExample(file *ast.File, f *ast.FuncDecl) *ast.File {
body := f.Body
if !strings.HasSuffix(file.Name.Name, "_test") {
// We don't support examples that are part of the
// greater package (yet).
return nil
}
// Collect top-level declarations in the file.
topDecls := make(map[*ast.Object]ast.Decl)
typMethods := make(map[string][]ast.Decl)
for _, decl := range file.Decls {
switch d := decl.(type) {
case *ast.FuncDecl:
if d.Recv == nil {
topDecls[d.Name.Obj] = d
} else {
if len(d.Recv.List) == 1 {
t := d.Recv.List[0].Type
tname, _ := baseTypeName(t)
typMethods[tname] = append(typMethods[tname], d)
}
}
case *ast.GenDecl:
for _, spec := range d.Specs {
switch s := spec.(type) {
case *ast.TypeSpec:
topDecls[s.Name.Obj] = d
case *ast.ValueSpec:
for _, name := range s.Names {
topDecls[name.Obj] = d
}
}
}
}
}
// Find unresolved identifiers and uses of top-level declarations.
depDecls, unresolved := findDeclsAndUnresolved(body, topDecls, typMethods)
// Remove predeclared identifiers from unresolved list.
for n := range unresolved {
if predeclaredTypes[n] || predeclaredConstants[n] || predeclaredFuncs[n] {
delete(unresolved, n)
}
}
// Use unresolved identifiers to determine the imports used by this
// example. The heuristic assumes package names match base import
// paths for imports w/o renames (should be good enough most of the time).
var namedImports []ast.Spec
var blankImports []ast.Spec // _ imports
// To preserve the blank lines between groups of imports, find the
// start position of each group, and assign that position to all
// imports from that group.
groupStarts := findImportGroupStarts(file.Imports)
groupStart := func(s *ast.ImportSpec) token.Pos {
for i, start := range groupStarts {
if s.Path.ValuePos < start {
return groupStarts[i-1]
}
}
return groupStarts[len(groupStarts)-1]
}
for _, s := range file.Imports {
p, err := strconv.Unquote(s.Path.Value)
if err != nil {
continue
}
if p == "syscall/js" {
// We don't support examples that import syscall/js,
// because the package syscall/js is not available in the playground.
return nil
}
n := path.Base(p)
if s.Name != nil {
n = s.Name.Name
switch n {
case "_":
blankImports = append(blankImports, s)
continue
case ".":
// We can't resolve dot imports (yet).
return nil
}
}
if unresolved[n] {
// Copy the spec and its path to avoid modifying the original.
spec := *s
path := *s.Path
spec.Path = &path
spec.Path.ValuePos = groupStart(&spec)
namedImports = append(namedImports, &spec)
delete(unresolved, n)
}
}
// If there are other unresolved identifiers, give up because this
// synthesized file is not going to build.
if len(unresolved) > 0 {
return nil
}
// Include documentation belonging to blank imports.
var comments []*ast.CommentGroup
for _, s := range blankImports {
if c := s.(*ast.ImportSpec).Doc; c != nil {
comments = append(comments, c)
}
}
// Include comments that are inside the function body.
for _, c := range file.Comments {
if body.Pos() <= c.Pos() && c.End() <= body.End() {
comments = append(comments, c)
}
}
// Strip the "Output:" or "Unordered output:" comment and adjust body
// end position.
body, comments = stripOutputComment(body, comments)
// Include documentation belonging to dependent declarations.
for _, d := range depDecls {
switch d := d.(type) {
case *ast.GenDecl:
if d.Doc != nil {
comments = append(comments, d.Doc)
}
case *ast.FuncDecl:
if d.Doc != nil {
comments = append(comments, d.Doc)
}
}
}
// Synthesize import declaration.
importDecl := &ast.GenDecl{
Tok: token.IMPORT,
Lparen: 1, // Need non-zero Lparen and Rparen so that printer
Rparen: 1, // treats this as a factored import.
}
importDecl.Specs = append(namedImports, blankImports...)
// Synthesize main function.
funcDecl := &ast.FuncDecl{
Name: ast.NewIdent("main"),
Type: f.Type,
Body: body,
}
decls := make([]ast.Decl, 0, 2+len(depDecls))
decls = append(decls, importDecl)
decls = append(decls, depDecls...)
decls = append(decls, funcDecl)
sort.Slice(decls, func(i, j int) bool {
return decls[i].Pos() < decls[j].Pos()
})
sort.Slice(comments, func(i, j int) bool {
return comments[i].Pos() < comments[j].Pos()
})
// Synthesize file.
return &ast.File{
Name: ast.NewIdent("main"),
Decls: decls,
Comments: comments,
}
}
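// For example (illustrative): given a self-contained example in package
// foo_test containing
//
//	func ExampleHello() {
//		fmt.Println("hello")
//		// Output: hello
//	}
//
// playExample synthesizes roughly
//
//	package main
//
//	import "fmt"
//
//	func main() {
//		fmt.Println("hello")
//	}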
// findDeclsAndUnresolved returns all the top-level declarations mentioned in
// the body, and a set of unresolved symbols (those that appear in the body but
// have no declaration in the program).
//
// topDecls maps objects to the top-level declaration declaring them (not
// necessarily obj.Decl, as obj.Decl will be a Spec for GenDecls, but
// topDecls[obj] will be the GenDecl itself).
func findDeclsAndUnresolved(body ast.Node, topDecls map[*ast.Object]ast.Decl, typMethods map[string][]ast.Decl) ([]ast.Decl, map[string]bool) {
// This function recursively finds every top-level declaration used
// transitively by the body, populating usedDecls and usedObjs. Then it
// trims down the declarations to include only the symbols actually
// referenced by the body.
unresolved := make(map[string]bool)
var depDecls []ast.Decl
usedDecls := make(map[ast.Decl]bool) // set of top-level decls reachable from the body
usedObjs := make(map[*ast.Object]bool) // set of objects reachable from the body (each declared by a usedDecl)
var inspectFunc func(ast.Node) bool
inspectFunc = func(n ast.Node) bool {
switch e := n.(type) {
case *ast.Ident:
if e.Obj == nil && e.Name != "_" {
unresolved[e.Name] = true
} else if d := topDecls[e.Obj]; d != nil {
usedObjs[e.Obj] = true
if !usedDecls[d] {
usedDecls[d] = true
depDecls = append(depDecls, d)
}
}
return true
case *ast.SelectorExpr:
// For selector expressions, only inspect the left hand side.
// (For an expression like fmt.Println, only add "fmt" to the
// set of unresolved names, not "Println".)
ast.Inspect(e.X, inspectFunc)
return false
case *ast.KeyValueExpr:
// For key value expressions, only inspect the value
// as the key should be resolved by the type of the
// composite literal.
ast.Inspect(e.Value, inspectFunc)
return false
}
return true
}
inspectFieldList := func(fl *ast.FieldList) {
if fl != nil {
for _, f := range fl.List {
ast.Inspect(f.Type, inspectFunc)
}
}
}
// Find the decls immediately referenced by body.
ast.Inspect(body, inspectFunc)
// Now loop over them, adding to the list when we find a new decl that the
// body depends on. Keep going until we don't find anything new.
for i := 0; i < len(depDecls); i++ {
switch d := depDecls[i].(type) {
case *ast.FuncDecl:
// Inspect type parameters.
inspectFieldList(d.Type.TypeParams)
// Inspect types of parameters and results. See #28492.
inspectFieldList(d.Type.Params)
inspectFieldList(d.Type.Results)
// Functions might not have a body. See #42706.
if d.Body != nil {
ast.Inspect(d.Body, inspectFunc)
}
case *ast.GenDecl:
for _, spec := range d.Specs {
switch s := spec.(type) {
case *ast.TypeSpec:
inspectFieldList(s.TypeParams)
ast.Inspect(s.Type, inspectFunc)
depDecls = append(depDecls, typMethods[s.Name.Name]...)
case *ast.ValueSpec:
if s.Type != nil {
ast.Inspect(s.Type, inspectFunc)
}
for _, val := range s.Values {
ast.Inspect(val, inspectFunc)
}
}
}
}
}
// Some decls include multiple specs, such as a variable declaration with
// multiple variables on the same line, or a parenthesized declaration. Trim
// the declarations to include only the specs that are actually mentioned.
// However, if there is a constant group with iota, leave it all: later
// constant declarations in the group may have no value and so cannot stand
// on their own, and removing any constant from the group could change the
// values of subsequent ones.
// See testdata/examples/iota.go for a minimal example.
var ds []ast.Decl
for _, d := range depDecls {
switch d := d.(type) {
case *ast.FuncDecl:
ds = append(ds, d)
case *ast.GenDecl:
containsIota := false // does any spec have iota?
// Collect all Specs that were mentioned in the example.
var specs []ast.Spec
for _, s := range d.Specs {
switch s := s.(type) {
case *ast.TypeSpec:
if usedObjs[s.Name.Obj] {
specs = append(specs, s)
}
case *ast.ValueSpec:
if !containsIota {
containsIota = hasIota(s)
}
// A ValueSpec may have multiple names (e.g. "var a, b int").
// Keep only the names that were mentioned in the example.
// Exception: the multiple names have a single initializer (which
// would be a function call with multiple return values). In that
// case, keep everything.
if len(s.Names) > 1 && len(s.Values) == 1 {
specs = append(specs, s)
continue
}
ns := *s
ns.Names = nil
ns.Values = nil
for i, n := range s.Names {
if usedObjs[n.Obj] {
ns.Names = append(ns.Names, n)
if s.Values != nil {
ns.Values = append(ns.Values, s.Values[i])
}
}
}
if len(ns.Names) > 0 {
specs = append(specs, &ns)
}
}
}
if len(specs) > 0 {
// Constant with iota? Keep it all.
if d.Tok == token.CONST && containsIota {
ds = append(ds, d)
} else {
// Synthesize a GenDecl with just the Specs we need.
nd := *d // copy the GenDecl
nd.Specs = specs
if len(specs) == 1 {
// Remove grouping parens if there is only one spec.
nd.Lparen = 0
}
ds = append(ds, &nd)
}
}
}
}
return ds, unresolved
}
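// hasIota reports whether s contains a reference to the predeclared
// identifier iota.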
func hasIota(s ast.Spec) bool {
has := false
ast.Inspect(s, func(n ast.Node) bool {
// Check that this is the special built-in "iota" identifier, not
// a user-defined shadow.
if id, ok := n.(*ast.Ident); ok && id.Name == "iota" && id.Obj == nil {
has = true
return false
}
return true
})
return has
}
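// For example (illustrative): in
//
//	const (
//		A Kind = iota
//		B
//		C
//	)
//
// a body that mentions only A still pulls in the whole group, because B and
// C have no values of their own and removing A would change theirs.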
// findImportGroupStarts finds the start positions of each sequence of import
// specs that are not separated by a blank line.
func findImportGroupStarts(imps []*ast.ImportSpec) []token.Pos {
startImps := findImportGroupStarts1(imps)
groupStarts := make([]token.Pos, len(startImps))
for i, imp := range startImps {
groupStarts[i] = imp.Pos()
}
return groupStarts
}
// findImportGroupStarts1 is a helper for findImportGroupStarts to ease testing.
func findImportGroupStarts1(origImps []*ast.ImportSpec) []*ast.ImportSpec {
// Copy to avoid mutation.
imps := make([]*ast.ImportSpec, len(origImps))
copy(imps, origImps)
// Assume the imports are sorted by position.
sort.Slice(imps, func(i, j int) bool { return imps[i].Pos() < imps[j].Pos() })
// Assume gofmt has been applied, so there is a blank line between adjacent imps
// if and only if they are more than 2 positions apart (newline, tab).
var groupStarts []*ast.ImportSpec
prevEnd := token.Pos(-2)
for _, imp := range imps {
if imp.Pos()-prevEnd > 2 {
groupStarts = append(groupStarts, imp)
}
prevEnd = imp.End()
// Account for end-of-line comments.
if imp.Comment != nil {
prevEnd = imp.Comment.End()
}
}
return groupStarts
}
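// For example (illustrative; the import paths are assumptions): in
//
//	import (
//		"fmt"
//		"strings"
//
//		"example.com/extra"
//	)
//
// there are two groups, starting at "fmt" and at "example.com/extra".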
// playExampleFile takes a whole file example and synthesizes a new *ast.File
// such that the example is function main in package main.
func playExampleFile(file *ast.File) *ast.File {
// Strip copyright comment if present.
comments := file.Comments
if len(comments) > 0 && strings.HasPrefix(comments[0].Text(), "Copyright") {
comments = comments[1:]
}
// Copy declaration slice, rewriting the ExampleX function to main.
var decls []ast.Decl
for _, d := range file.Decls {
if f, ok := d.(*ast.FuncDecl); ok && isTest(f.Name.Name, "Example") {
// Copy the FuncDecl, as it may be used elsewhere.
newF := *f
newF.Name = ast.NewIdent("main")
newF.Body, comments = stripOutputComment(f.Body, comments)
d = &newF
}
decls = append(decls, d)
}
// Copy the File, as it may be used elsewhere.
f := *file
f.Name = ast.NewIdent("main")
f.Decls = decls
f.Comments = comments
return &f
}
// stripOutputComment finds and removes the "Output:" or "Unordered output:"
// comment from body and comments, and adjusts the body block's end position.
func stripOutputComment(body *ast.BlockStmt, comments []*ast.CommentGroup) (*ast.BlockStmt, []*ast.CommentGroup) {
// Do nothing if there is no "Output:" or "Unordered output:" comment.
i, last := lastComment(body, comments)
if last == nil || !outputPrefix.MatchString(last.Text()) {
return body, comments
}
// Copy body and comments, as the originals may be used elsewhere.
newBody := &ast.BlockStmt{
Lbrace: body.Lbrace,
List: body.List,
Rbrace: last.Pos(),
}
newComments := make([]*ast.CommentGroup, len(comments)-1)
copy(newComments, comments[:i])
copy(newComments[i:], comments[i+1:])
return newBody, newComments
}
// lastComment returns the last comment inside the provided block.
func lastComment(b *ast.BlockStmt, c []*ast.CommentGroup) (i int, last *ast.CommentGroup) {
if b == nil {
return
}
pos, end := b.Pos(), b.End()
for j, cg := range c {
if cg.Pos() < pos {
continue
}
if cg.End() > end {
break
}
i, last = j, cg
}
return
}
// classifyExamples classifies examples and assigns them to the Examples field
// of the relevant Func, Type, or Package that the example is associated with.
//
// The classification process is ambiguous in some cases:
//
// - ExampleFoo_Bar matches a type named Foo_Bar
// or a method named Foo.Bar.
// - ExampleFoo_bar matches a type named Foo_bar
// or Foo (with a "bar" suffix).
//
// Examples with malformed names are not associated with anything.
func classifyExamples(p *Package, examples []*Example) {
if len(examples) == 0 {
return
}
// Mapping of names for funcs, types, and methods to the example listing.
ids := make(map[string]*[]*Example)
ids[""] = &p.Examples // package-level examples have an empty name
for _, f := range p.Funcs {
if !token.IsExported(f.Name) {
continue
}
ids[f.Name] = &f.Examples
}
for _, t := range p.Types {
if !token.IsExported(t.Name) {
continue
}
ids[t.Name] = &t.Examples
for _, f := range t.Funcs {
if !token.IsExported(f.Name) {
continue
}
ids[f.Name] = &f.Examples
}
for _, m := range t.Methods {
if !token.IsExported(m.Name) {
continue
}
ids[strings.TrimPrefix(nameWithoutInst(m.Recv), "*")+"_"+m.Name] = &m.Examples
}
}
// Group each example with the associated func, type, or method.
for _, ex := range examples {
// Consider all possible split points for the suffix
// by starting at the end of string (no suffix case),
// then trying all positions that contain a '_' character.
//
// An association is made on the first successful match.
// Examples with malformed names that match nothing are skipped.
for i := len(ex.Name); i >= 0; i = strings.LastIndexByte(ex.Name[:i], '_') {
prefix, suffix, ok := splitExampleName(ex.Name, i)
if !ok {
continue
}
exs, ok := ids[prefix]
if !ok {
continue
}
ex.Suffix = suffix
*exs = append(*exs, ex)
break
}
}
// Sort each list of examples by the user-specified suffix name.
for _, exs := range ids {
sort.Slice((*exs), func(i, j int) bool {
return (*exs)[i].Suffix < (*exs)[j].Suffix
})
}
}
// nameWithoutInst returns name if name has no brackets. If name contains
// brackets, then it returns name with all the contents between (and including)
// the outermost left and right bracket removed.
//
// Adapted from debug/gosym/symtab.go:Sym.nameWithoutInst.
func nameWithoutInst(name string) string {
start := strings.Index(name, "[")
if start < 0 {
return name
}
end := strings.LastIndex(name, "]")
if end < 0 {
// Malformed name, should contain closing bracket too.
return name
}
return name[0:start] + name[end+1:]
}
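// For example (illustrative):
//
//	nameWithoutInst("List[int]")        // "List"
//	nameWithoutInst("Map[string, int]") // "Map"
//	nameWithoutInst("Plain")            // "Plain"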
// splitExampleName attempts to split example name s at index i,
// and reports whether that produces a valid split. The suffix may be
// absent. Otherwise, it must start with a lower-case letter and
// be preceded by '_'.
//
// One of i == len(s) or s[i] == '_' must be true.
func splitExampleName(s string, i int) (prefix, suffix string, ok bool) {
if i == len(s) {
return s, "", true
}
if i == len(s)-1 {
return "", "", false
}
prefix, suffix = s[:i], s[i+1:]
return prefix, suffix, isExampleSuffix(suffix)
}
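// For example (illustrative):
//
//	splitExampleName("Foo_bar", 3) // ("Foo", "bar", true)
//	splitExampleName("Foo_Bar", 3) // ("Foo", "Bar", false): suffix not lower-case
//	splitExampleName("Foo", 3)     // ("Foo", "", true): i == len(s), no suffix
//
// isExampleSuffix reports whether s begins with a lower-case letter, as
// required of an example name suffix.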
func isExampleSuffix(s string) bool {
r, size := utf8.DecodeRuneInString(s)
return size > 0 && unicode.IsLower(r)
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements export filtering of an AST.
package doc
import (
"go/ast"
"go/token"
)
// filterIdentList removes unexported names from list in place
// and returns the resulting list.
func filterIdentList(list []*ast.Ident) []*ast.Ident {
j := 0
for _, x := range list {
if token.IsExported(x.Name) {
list[j] = x
j++
}
}
return list[0:j]
}
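// For example (illustrative): the list [A, b, C] is filtered in place to
// [A, C].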
var underscore = ast.NewIdent("_")
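// filterCompositeLit filters the elements of a composite literal in place
// and marks the literal as incomplete if any elements were removed.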
func filterCompositeLit(lit *ast.CompositeLit, filter Filter, export bool) {
n := len(lit.Elts)
lit.Elts = filterExprList(lit.Elts, filter, export)
if len(lit.Elts) < n {
lit.Incomplete = true
}
}
func filterExprList(list []ast.Expr, filter Filter, export bool) []ast.Expr {
j := 0
for _, exp := range list {
switch x := exp.(type) {
case *ast.CompositeLit:
filterCompositeLit(x, filter, export)
case *ast.KeyValueExpr:
if x, ok := x.Key.(*ast.Ident); ok && !filter(x.Name) {
continue
}
if x, ok := x.Value.(*ast.CompositeLit); ok {
filterCompositeLit(x, filter, export)
}
}
list[j] = exp
j++
}
return list[0:j]
}
// updateIdentList replaces all unexported identifiers with underscore
// and reports whether at least one exported name exists.
func updateIdentList(list []*ast.Ident) (hasExported bool) {
for i, x := range list {
if token.IsExported(x.Name) {
hasExported = true
} else {
list[i] = underscore
}
}
return hasExported
}
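// For example (illustrative): the list [A, b] becomes [A, _] and the
// function reports true; [a, b] becomes [_, _] and it reports false.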
// hasExportedName reports whether list contains any exported names.
func hasExportedName(list []*ast.Ident) bool {
for _, x := range list {
if x.IsExported() {
return true
}
}
return false
}
// removeAnonymousField removes anonymous fields named name from an interface.
func removeAnonymousField(name string, ityp *ast.InterfaceType) {
list := ityp.Methods.List // we know that ityp.Methods != nil
j := 0
for _, field := range list {
keepField := true
if n := len(field.Names); n == 0 {
// anonymous field
if fname, _ := baseTypeName(field.Type); fname == name {
keepField = false
}
}
if keepField {
list[j] = field
j++
}
}
if j < len(list) {
ityp.Incomplete = true
}
ityp.Methods.List = list[0:j]
}
// filterFieldList removes unexported fields (field names) from the field list
// in place and reports whether fields were removed. Anonymous fields are
// recorded with the parent type. filterType is called with the types of
// all remaining fields.
func (r *reader) filterFieldList(parent *namedType, fields *ast.FieldList, ityp *ast.InterfaceType) (removedFields bool) {
if fields == nil {
return
}
list := fields.List
j := 0
for _, field := range list {
keepField := false
if n := len(field.Names); n == 0 {
// anonymous field or embedded type or union element
fname := r.recordAnonymousField(parent, field.Type)
if fname != "" {
if token.IsExported(fname) {
keepField = true
} else if ityp != nil && predeclaredTypes[fname] {
// possibly an embedded predeclared type; keep it for now but
// remember this interface so that it can be fixed if name is also
// defined locally
keepField = true
r.remember(fname, ityp)
}
} else {
// If we're operating on an interface, assume that this is an embedded
// type or union element.
//
// TODO(rfindley): consider traversing into approximation/unions
// elements to see if they are entirely unexported.
keepField = ityp != nil
}
} else {
field.Names = filterIdentList(field.Names)
if len(field.Names) < n {
removedFields = true
}
if len(field.Names) > 0 {
keepField = true
}
}
if keepField {
r.filterType(nil, field.Type)
list[j] = field
j++
}
}
if j < len(list) {
removedFields = true
}
fields.List = list[0:j]
return
}
// filterParamList applies filterType to each parameter type in fields.
func (r *reader) filterParamList(fields *ast.FieldList) {
if fields != nil {
for _, f := range fields.List {
r.filterType(nil, f.Type)
}
}
}
// filterType strips any unexported struct fields or method types from typ
// in place. If fields (or methods) have been removed, the corresponding
// struct or interface type has the Incomplete field set to true.
func (r *reader) filterType(parent *namedType, typ ast.Expr) {
switch t := typ.(type) {
case *ast.Ident:
// nothing to do
case *ast.ParenExpr:
r.filterType(nil, t.X)
case *ast.StarExpr: // possibly an embedded type literal
r.filterType(nil, t.X)
case *ast.UnaryExpr:
if t.Op == token.TILDE { // approximation element
r.filterType(nil, t.X)
}
case *ast.BinaryExpr:
if t.Op == token.OR { // union
r.filterType(nil, t.X)
r.filterType(nil, t.Y)
}
case *ast.ArrayType:
r.filterType(nil, t.Elt)
case *ast.StructType:
if r.filterFieldList(parent, t.Fields, nil) {
t.Incomplete = true
}
case *ast.FuncType:
r.filterParamList(t.TypeParams)
r.filterParamList(t.Params)
r.filterParamList(t.Results)
case *ast.InterfaceType:
if r.filterFieldList(parent, t.Methods, t) {
t.Incomplete = true
}
case *ast.MapType:
r.filterType(nil, t.Key)
r.filterType(nil, t.Value)
case *ast.ChanType:
r.filterType(nil, t.Value)
}
}
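// filterSpec reports whether the given spec should be kept, filtering
// its contents in place as a side effect.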
func (r *reader) filterSpec(spec ast.Spec) bool {
switch s := spec.(type) {
case *ast.ImportSpec:
// always keep imports so we can collect them
return true
case *ast.ValueSpec:
s.Values = filterExprList(s.Values, token.IsExported, true)
if len(s.Values) > 0 || s.Type == nil && len(s.Values) == 0 {
// If there are values declared on the RHS, just replace the unexported
// identifiers on the LHS with underscore, so that they match
// the sequence of expressions on the RHS.
//
// Similarly, if there is no type and there are no values, then this
// spec must follow an iota expression, where order matters.
if updateIdentList(s.Names) {
r.filterType(nil, s.Type)
return true
}
} else {
s.Names = filterIdentList(s.Names)
if len(s.Names) > 0 {
r.filterType(nil, s.Type)
return true
}
}
case *ast.TypeSpec:
// Don't filter type parameters here, by analogy with function parameters
// which are not filtered for top-level function declarations.
if name := s.Name.Name; token.IsExported(name) {
r.filterType(r.lookupType(s.Name.Name), s.Type)
return true
} else if IsPredeclared(name) {
if r.shadowedPredecl == nil {
r.shadowedPredecl = make(map[string]bool)
}
r.shadowedPredecl[name] = true
}
}
return false
}
// copyConstType returns a copy of typ with position pos.
// typ must be a valid constant type.
// In practice, only (possibly qualified) identifiers are possible.
func copyConstType(typ ast.Expr, pos token.Pos) ast.Expr {
switch typ := typ.(type) {
case *ast.Ident:
return &ast.Ident{Name: typ.Name, NamePos: pos}
case *ast.SelectorExpr:
if id, ok := typ.X.(*ast.Ident); ok {
// presumably a qualified identifier
return &ast.SelectorExpr{
Sel: ast.NewIdent(typ.Sel.Name),
X: &ast.Ident{Name: id.Name, NamePos: pos},
}
}
}
return nil // shouldn't happen, but be conservative and don't panic
}
func (r *reader) filterSpecList(list []ast.Spec, tok token.Token) []ast.Spec {
if tok == token.CONST {
// Propagate any type information that would get lost otherwise
// when unexported constants are filtered.
var prevType ast.Expr
for _, spec := range list {
spec := spec.(*ast.ValueSpec)
if spec.Type == nil && len(spec.Values) == 0 && prevType != nil {
// provide current spec with an explicit type
spec.Type = copyConstType(prevType, spec.Pos())
}
if hasExportedName(spec.Names) {
// exported names are preserved so there's no need to propagate the type
prevType = nil
} else {
prevType = spec.Type
}
}
}
j := 0
for _, s := range list {
if r.filterSpec(s) {
list[j] = s
j++
}
}
return list[0:j]
}
func (r *reader) filterDecl(decl ast.Decl) bool {
switch d := decl.(type) {
case *ast.GenDecl:
d.Specs = r.filterSpecList(d.Specs, d.Tok)
return len(d.Specs) > 0
case *ast.FuncDecl:
// ok to filter these methods early because any
// conflicting method will be filtered here, too -
// thus, removing these methods early will not lead
// to the false removal of possible conflicts
return token.IsExported(d.Name.Name)
}
return false
}
// fileExports removes unexported declarations from src in place.
func (r *reader) fileExports(src *ast.File) {
j := 0
for _, d := range src.Decls {
if r.filterDecl(d) {
src.Decls[j] = d
j++
}
}
src.Decls = src.Decls[0:j]
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package doc
import "go/ast"
type Filter func(string) bool
func matchFields(fields *ast.FieldList, f Filter) bool {
if fields != nil {
for _, field := range fields.List {
for _, name := range field.Names {
if f(name.Name) {
return true
}
}
}
}
return false
}
func matchDecl(d *ast.GenDecl, f Filter) bool {
for _, d := range d.Specs {
switch v := d.(type) {
case *ast.ValueSpec:
for _, name := range v.Names {
if f(name.Name) {
return true
}
}
case *ast.TypeSpec:
if f(v.Name.Name) {
return true
}
// We don't match ordinary parameters in filterFuncs, so by analogy don't
// match type parameters here.
switch t := v.Type.(type) {
case *ast.StructType:
if matchFields(t.Fields, f) {
return true
}
case *ast.InterfaceType:
if matchFields(t.Methods, f) {
return true
}
}
}
}
return false
}
func filterValues(a []*Value, f Filter) []*Value {
w := 0
for _, vd := range a {
if matchDecl(vd.Decl, f) {
a[w] = vd
w++
}
}
return a[0:w]
}
func filterFuncs(a []*Func, f Filter) []*Func {
w := 0
for _, fd := range a {
if f(fd.Name) {
a[w] = fd
w++
}
}
return a[0:w]
}
func filterTypes(a []*Type, f Filter) []*Type {
w := 0
for _, td := range a {
n := 0 // number of matches
if matchDecl(td.Decl, f) {
n = 1
} else {
// type name doesn't match, but we may have matching consts, vars, factories or methods
td.Consts = filterValues(td.Consts, f)
td.Vars = filterValues(td.Vars, f)
td.Funcs = filterFuncs(td.Funcs, f)
td.Methods = filterFuncs(td.Methods, f)
n += len(td.Consts) + len(td.Vars) + len(td.Funcs) + len(td.Methods)
}
if n > 0 {
a[w] = td
w++
}
}
return a[0:w]
}
// Filter eliminates documentation for names that don't pass through the filter f.
// TODO(gri): Recognize "Type.Method" as a name.
func (p *Package) Filter(f Filter) {
p.Consts = filterValues(p.Consts, f)
p.Vars = filterValues(p.Vars, f)
p.Types = filterTypes(p.Types, f)
p.Funcs = filterFuncs(p.Funcs, f)
p.Doc = "" // don't show top-level package doc
}
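// A minimal usage sketch (illustrative; the predicate is an assumption):
//
//	pkg.Filter(func(name string) bool {
//		return strings.HasPrefix(name, "New")
//	})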
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package doc
import (
"fmt"
"go/ast"
"go/token"
"internal/lazyregexp"
"path"
"sort"
"strconv"
"strings"
"unicode"
"unicode/utf8"
)
// ----------------------------------------------------------------------------
// function/method sets
//
// Internally, we treat functions like methods and collect them in method sets.
// A methodSet describes a set of methods. Entries where Decl == nil are conflict
// entries (more than one method with the same name at the same embedding level).
type methodSet map[string]*Func
// recvString returns a string representation of recv of the form "T", "*T",
// "T[A, ...]", "*T[A, ...]" or "BADRECV" (if not a proper receiver type).
func recvString(recv ast.Expr) string {
switch t := recv.(type) {
case *ast.Ident:
return t.Name
case *ast.StarExpr:
return "*" + recvString(t.X)
case *ast.IndexExpr:
// Generic type with one parameter.
return fmt.Sprintf("%s[%s]", recvString(t.X), recvParam(t.Index))
case *ast.IndexListExpr:
// Generic type with multiple parameters.
if len(t.Indices) > 0 {
var b strings.Builder
b.WriteString(recvString(t.X))
b.WriteByte('[')
b.WriteString(recvParam(t.Indices[0]))
for _, e := range t.Indices[1:] {
b.WriteString(", ")
b.WriteString(recvParam(e))
}
b.WriteByte(']')
return b.String()
}
}
return "BADRECV"
}
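// For example (illustrative): a receiver declared as (l *List[T]) yields
// "*List[T]", and one declared as (t T) yields "T".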
func recvParam(p ast.Expr) string {
if id, ok := p.(*ast.Ident); ok {
return id.Name
}
return "BADPARAM"
}
// set creates the corresponding Func for f and adds it to mset.
// If there are multiple f's with the same name, set keeps the first
// one with documentation; conflicts are ignored. The boolean
// specifies whether to leave the AST untouched.
func (mset methodSet) set(f *ast.FuncDecl, preserveAST bool) {
name := f.Name.Name
if g := mset[name]; g != nil && g.Doc != "" {
// A function with the same name has already been registered;
// since it has documentation, assume f is simply another
// implementation and ignore it. This does not happen if the
// caller is using go/build.ScanDir to determine the list of
// files implementing a package.
return
}
// function doesn't exist or has no documentation; use f
recv := ""
if f.Recv != nil {
var typ ast.Expr
// be careful in case of incorrect ASTs
if list := f.Recv.List; len(list) == 1 {
typ = list[0].Type
}
recv = recvString(typ)
}
mset[name] = &Func{
Doc: f.Doc.Text(),
Name: name,
Decl: f,
Recv: recv,
Orig: recv,
}
if !preserveAST {
f.Doc = nil // doc consumed - remove from AST
}
}
// add adds method m to the method set; m is ignored if the method set
// already contains a method with the same name at the same or a higher
// level than m.
func (mset methodSet) add(m *Func) {
old := mset[m.Name]
if old == nil || m.Level < old.Level {
mset[m.Name] = m
return
}
if m.Level == old.Level {
// conflict - mark it using a method with nil Decl
mset[m.Name] = &Func{
Name: m.Name,
Level: m.Level,
}
}
}
// ----------------------------------------------------------------------------
// Named types
// baseTypeName returns the name of the base type of x (or "")
// and reports whether the type is imported.
func baseTypeName(x ast.Expr) (name string, imported bool) {
switch t := x.(type) {
case *ast.Ident:
return t.Name, false
case *ast.IndexExpr:
return baseTypeName(t.X)
case *ast.IndexListExpr:
return baseTypeName(t.X)
case *ast.SelectorExpr:
if _, ok := t.X.(*ast.Ident); ok {
// only possible for qualified type names;
// assume type is imported
return t.Sel.Name, true
}
case *ast.ParenExpr:
return baseTypeName(t.X)
case *ast.StarExpr:
return baseTypeName(t.X)
}
return "", false
}
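// For example (illustrative):
//
//	baseTypeName of *Foo    // ("Foo", false)
//	baseTypeName of pkg.Bar // ("Bar", true)
//	baseTypeName of List[T] // ("List", false)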
// An embeddedSet describes a set of embedded types.
type embeddedSet map[*namedType]bool
// A namedType represents a named unqualified (package local, or possibly
// predeclared) type. The namedType for a type name is always found via
// reader.lookupType.
type namedType struct {
doc string // doc comment for type
name string // type name
decl *ast.GenDecl // nil if declaration hasn't been seen yet
isEmbedded bool // true if this type is embedded
isStruct bool // true if this type is a struct
embedded embeddedSet // embedded types; the set value is true if embedded as a pointer
// associated declarations
values []*Value // consts and vars
funcs methodSet
methods methodSet
}
// ----------------------------------------------------------------------------
// AST reader
// reader accumulates documentation for a single package.
// It modifies the AST: Comments (declaration documentation)
// that have been collected by the reader are set to nil
// in the respective AST nodes so that they are not printed
// twice (once when printing the documentation and once when
// printing the corresponding AST node).
type reader struct {
mode Mode
// package properties
doc string // package documentation, if any
filenames []string
notes map[string][]*Note
// imports
imports map[string]int
hasDotImp bool // if set, package contains a dot import
importByName map[string]string
// declarations
values []*Value // consts and vars
order int // sort order of const and var declarations (when we can't use a name)
types map[string]*namedType
funcs methodSet
// support for package-local shadowing of predeclared types
shadowedPredecl map[string]bool
fixmap map[string][]*ast.InterfaceType
}
func (r *reader) isVisible(name string) bool {
return r.mode&AllDecls != 0 || token.IsExported(name)
}
// lookupType returns the base type with the given name.
// If the base type has not been encountered yet, a new
// type with the given name but no associated declaration
// is added to the type map.
func (r *reader) lookupType(name string) *namedType {
if name == "" || name == "_" {
return nil // no type docs for anonymous types
}
if typ, found := r.types[name]; found {
return typ
}
// type not found - add one without declaration
typ := &namedType{
name: name,
embedded: make(embeddedSet),
funcs: make(methodSet),
methods: make(methodSet),
}
r.types[name] = typ
return typ
}
// recordAnonymousField registers fieldType as the type of an
// anonymous field in the parent type. If the field is imported
// (qualified name) or the parent is nil, the field is ignored.
// The function returns the field name.
func (r *reader) recordAnonymousField(parent *namedType, fieldType ast.Expr) (fname string) {
fname, imp := baseTypeName(fieldType)
if parent == nil || imp {
return
}
if ftype := r.lookupType(fname); ftype != nil {
ftype.isEmbedded = true
_, ptr := fieldType.(*ast.StarExpr)
parent.embedded[ftype] = ptr
}
return
}
func (r *reader) readDoc(comment *ast.CommentGroup) {
// By convention there should be only one package comment
// but collect all of them if there are more than one.
text := comment.Text()
if r.doc == "" {
r.doc = text
return
}
r.doc += "\n" + text
}
func (r *reader) remember(predecl string, typ *ast.InterfaceType) {
if r.fixmap == nil {
r.fixmap = make(map[string][]*ast.InterfaceType)
}
r.fixmap[predecl] = append(r.fixmap[predecl], typ)
}
func specNames(specs []ast.Spec) []string {
names := make([]string, 0, len(specs)) // reasonable estimate
for _, s := range specs {
// s guaranteed to be an *ast.ValueSpec by readValue
for _, ident := range s.(*ast.ValueSpec).Names {
names = append(names, ident.Name)
}
}
return names
}
// readValue processes a const or var declaration.
func (r *reader) readValue(decl *ast.GenDecl) {
// determine if decl should be associated with a type
// Heuristic: For each typed entry, determine the type name, if any.
// If there is exactly one type name that is sufficiently
// frequent, associate the decl with the respective type.
domName := ""
domFreq := 0
prev := ""
n := 0
for _, spec := range decl.Specs {
s, ok := spec.(*ast.ValueSpec)
if !ok {
continue // should not happen, but be conservative
}
name := ""
switch {
case s.Type != nil:
// a type is present; determine its name
if n, imp := baseTypeName(s.Type); !imp {
name = n
}
case decl.Tok == token.CONST && len(s.Values) == 0:
// no type or value is present but we have a constant declaration;
// use the previous type name (possibly the empty string)
name = prev
}
if name != "" {
// entry has a named type
if domName != "" && domName != name {
// more than one type name - do not associate
// with any type
domName = ""
break
}
domName = name
domFreq++
}
prev = name
n++
}
// nothing to do w/o a legal declaration
if n == 0 {
return
}
// determine values list with which to associate the Value for this decl
values := &r.values
const threshold = 0.75
if domName != "" && r.isVisible(domName) && domFreq >= int(float64(len(decl.Specs))*threshold) {
// typed entries are sufficiently frequent
if typ := r.lookupType(domName); typ != nil {
values = &typ.values // associate with that type
}
}
*values = append(*values, &Value{
Doc: decl.Doc.Text(),
Names: specNames(decl.Specs),
Decl: decl,
order: r.order,
})
if r.mode&PreserveAST == 0 {
decl.Doc = nil // doc consumed - remove from AST
}
// Note: It's important that the order used here is global because the cleanupTypes
// methods may move values associated with types back into the global list. If the
// order is list-specific, sorting is not deterministic because the same order value
// may appear multiple times (was bug, found when fixing #16153).
r.order++
}
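// For example (illustrative): in
//
//	const (
//		KindA Kind = iota
//		KindB
//		KindC
//	)
//
// the first spec names the type Kind and the untyped, valueless specs
// inherit it, so domFreq == 3 == len(decl.Specs), well above the 75%
// threshold, and the group is associated with type Kind (assuming Kind
// is visible).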
// fields returns a struct's fields or an interface's methods.
func fields(typ ast.Expr) (list []*ast.Field, isStruct bool) {
var fields *ast.FieldList
switch t := typ.(type) {
case *ast.StructType:
fields = t.Fields
isStruct = true
case *ast.InterfaceType:
fields = t.Methods
}
if fields != nil {
list = fields.List
}
return
}
// readType processes a type declaration.
func (r *reader) readType(decl *ast.GenDecl, spec *ast.TypeSpec) {
typ := r.lookupType(spec.Name.Name)
if typ == nil {
return // no name or blank name - ignore the type
}
// A type should be added at most once, so typ.decl
// should be nil - if it is not, simply overwrite it.
typ.decl = decl
// compute documentation
doc := spec.Doc
if doc == nil {
// no doc associated with the spec, use the declaration doc, if any
doc = decl.Doc
}
if r.mode&PreserveAST == 0 {
spec.Doc = nil // doc consumed - remove from AST
decl.Doc = nil // doc consumed - remove from AST
}
typ.doc = doc.Text()
// record anonymous fields (they may contribute methods)
// (some fields may have been recorded already when filtering
// exports, but that's ok)
var list []*ast.Field
list, typ.isStruct = fields(spec.Type)
for _, field := range list {
if len(field.Names) == 0 {
r.recordAnonymousField(typ, field.Type)
}
}
}
// isPredeclared reports whether n denotes a predeclared type.
func (r *reader) isPredeclared(n string) bool {
return predeclaredTypes[n] && r.types[n] == nil
}
// readFunc processes a func or method declaration.
func (r *reader) readFunc(fun *ast.FuncDecl) {
// strip function body if requested.
if r.mode&PreserveAST == 0 {
fun.Body = nil
}
// associate methods with the receiver type, if any
if fun.Recv != nil {
// method
if len(fun.Recv.List) == 0 {
// should not happen (incorrect AST; see issue 17788)
// don't show this method
return
}
recvTypeName, imp := baseTypeName(fun.Recv.List[0].Type)
if imp {
// should not happen (incorrect AST);
// don't show this method
return
}
if typ := r.lookupType(recvTypeName); typ != nil {
typ.methods.set(fun, r.mode&PreserveAST != 0)
}
// otherwise ignore the method
// TODO(gri): There may be exported methods of non-exported types
// that can be called because of exported values (consts, vars, or
// function results) of that type. Could determine if that is the
// case and then show those methods in an appropriate section.
return
}
// Associate factory functions with the first visible result type, as long as
// others are predeclared types.
if fun.Type.Results.NumFields() >= 1 {
var typ *namedType // type to associate the function with
numResultTypes := 0
for _, res := range fun.Type.Results.List {
factoryType := res.Type
if t, ok := factoryType.(*ast.ArrayType); ok {
// We consider functions that return slices or arrays of type
// T (or pointers to T) as factory functions of T.
factoryType = t.Elt
}
if n, imp := baseTypeName(factoryType); !imp && r.isVisible(n) && !r.isPredeclared(n) {
if lookupTypeParam(n, fun.Type.TypeParams) != nil {
// Issue #49477: don't associate fun with its type parameter result.
// A type parameter is not a defined type.
continue
}
if t := r.lookupType(n); t != nil {
typ = t
numResultTypes++
if numResultTypes > 1 {
break
}
}
}
}
// If there is exactly one result type,
// associate the function with that type.
if numResultTypes == 1 {
typ.funcs.set(fun, r.mode&PreserveAST != 0)
return
}
}
// just an ordinary function
r.funcs.set(fun, r.mode&PreserveAST != 0)
}
// lookupTypeParam searches for type parameters named name within the tparams
// field list, returning the relevant identifier if found, or nil if not.
func lookupTypeParam(name string, tparams *ast.FieldList) *ast.Ident {
if tparams == nil {
return nil
}
for _, field := range tparams.List {
for _, id := range field.Names {
if id.Name == name {
return id
}
}
}
return nil
}
var (
noteMarker = `([A-Z][A-Z]+)\(([^)]+)\):?` // MARKER(uid), MARKER at least 2 chars, uid at least 1 char
noteMarkerRx = lazyregexp.New(`^[ \t]*` + noteMarker) // MARKER(uid) at text start
noteCommentRx = lazyregexp.New(`^/[/*][ \t]*` + noteMarker) // MARKER(uid) at comment start
)
// clean replaces each sequence of space, \r, or \t characters
// with a single space and removes any trailing and leading spaces.
func clean(s string) string {
var b []byte
p := byte(' ')
for i := 0; i < len(s); i++ {
q := s[i]
if q == '\r' || q == '\t' {
q = ' '
}
if q != ' ' || p != ' ' {
b = append(b, q)
p = q
}
}
// remove trailing blank, if any
if n := len(b); n > 0 && p == ' ' {
b = b[0 : n-1]
}
return string(b)
}
// readNote collects a single note from a sequence of comments.
func (r *reader) readNote(list []*ast.Comment) {
text := (&ast.CommentGroup{List: list}).Text()
if m := noteMarkerRx.FindStringSubmatchIndex(text); m != nil {
// The note body starts after the marker.
// We remove any formatting so that we don't
// get spurious line breaks/indentation when
// showing the TODO body.
body := clean(text[m[1]:])
if body != "" {
marker := text[m[2]:m[3]]
r.notes[marker] = append(r.notes[marker], &Note{
Pos: list[0].Pos(),
End: list[len(list)-1].End(),
UID: text[m[4]:m[5]],
Body: body,
})
}
}
}
// readNotes extracts notes from comments.
// A note must start at the beginning of a comment with "MARKER(uid):"
// and is followed by the note body (e.g., "// BUG(gri): fix this").
// The note ends at the end of the comment group or at the start of
// another note in the same comment group, whichever comes first.
func (r *reader) readNotes(comments []*ast.CommentGroup) {
for _, group := range comments {
i := -1 // comment index of most recent note start, valid if >= 0
list := group.List
for j, c := range list {
if noteCommentRx.MatchString(c.Text) {
if i >= 0 {
r.readNote(list[i:j])
}
i = j
}
}
if i >= 0 {
r.readNote(list[i:])
}
}
}
// readFile adds the AST for a source file to the reader.
func (r *reader) readFile(src *ast.File) {
// add package documentation
if src.Doc != nil {
r.readDoc(src.Doc)
if r.mode&PreserveAST == 0 {
src.Doc = nil // doc consumed - remove from AST
}
}
// add all declarations but for functions which are processed in a separate pass
for _, decl := range src.Decls {
switch d := decl.(type) {
case *ast.GenDecl:
switch d.Tok {
case token.IMPORT:
// imports are handled individually
for _, spec := range d.Specs {
if s, ok := spec.(*ast.ImportSpec); ok {
if import_, err := strconv.Unquote(s.Path.Value); err == nil {
r.imports[import_] = 1
var name string
if s.Name != nil {
name = s.Name.Name
if name == "." {
r.hasDotImp = true
}
}
if name != "." {
if name == "" {
name = assumedPackageName(import_)
}
old, ok := r.importByName[name]
if !ok {
r.importByName[name] = import_
} else if old != import_ && old != "" {
r.importByName[name] = "" // ambiguous
}
}
}
}
}
case token.CONST, token.VAR:
// constants and variables are always handled as a group
r.readValue(d)
case token.TYPE:
// types are handled individually
if len(d.Specs) == 1 && !d.Lparen.IsValid() {
// common case: single declaration w/o parentheses
// (if a single declaration is parenthesized,
// create a new fake declaration below, so that
// go/doc type declarations always appear w/o
// parentheses)
if s, ok := d.Specs[0].(*ast.TypeSpec); ok {
r.readType(d, s)
}
break
}
for _, spec := range d.Specs {
if s, ok := spec.(*ast.TypeSpec); ok {
// use an individual (possibly fake) declaration
// for each type; this also ensures that each type
// gets to (re-)use the declaration documentation
// if there's none associated with the spec itself
fake := &ast.GenDecl{
Doc: d.Doc,
// don't use the existing TokPos because it
// will lead to the wrong selection range for
// the fake declaration if there are more
// than one type in the group (this affects
// src/cmd/godoc/godoc.go's posLink_urlFunc)
TokPos: s.Pos(),
Tok: token.TYPE,
Specs: []ast.Spec{s},
}
r.readType(fake, s)
}
}
}
}
}
// collect MARKER(...): annotations
r.readNotes(src.Comments)
if r.mode&PreserveAST == 0 {
src.Comments = nil // consumed unassociated comments - remove from AST
}
}
func (r *reader) readPackage(pkg *ast.Package, mode Mode) {
// initialize reader
r.filenames = make([]string, len(pkg.Files))
r.imports = make(map[string]int)
r.mode = mode
r.types = make(map[string]*namedType)
r.funcs = make(methodSet)
r.notes = make(map[string][]*Note)
r.importByName = make(map[string]string)
// sort package files before reading them so that the
// result does not depend on map iteration order
i := 0
for filename := range pkg.Files {
r.filenames[i] = filename
i++
}
sort.Strings(r.filenames)
// process files in sorted order
for _, filename := range r.filenames {
f := pkg.Files[filename]
if mode&AllDecls == 0 {
r.fileExports(f)
}
r.readFile(f)
}
for name, path := range r.importByName {
if path == "" {
delete(r.importByName, name)
}
}
// process functions now that we have better type information
for _, f := range pkg.Files {
for _, decl := range f.Decls {
if d, ok := decl.(*ast.FuncDecl); ok {
r.readFunc(d)
}
}
}
}
// ----------------------------------------------------------------------------
// Types
func customizeRecv(f *Func, recvTypeName string, embeddedIsPtr bool, level int) *Func {
if f == nil || f.Decl == nil || f.Decl.Recv == nil || len(f.Decl.Recv.List) != 1 {
return f // shouldn't happen, but be safe
}
// copy existing receiver field and set new type
newField := *f.Decl.Recv.List[0]
origPos := newField.Type.Pos()
_, origRecvIsPtr := newField.Type.(*ast.StarExpr)
newIdent := &ast.Ident{NamePos: origPos, Name: recvTypeName}
var typ ast.Expr = newIdent
if !embeddedIsPtr && origRecvIsPtr {
newIdent.NamePos++ // '*' is one character
typ = &ast.StarExpr{Star: origPos, X: newIdent}
}
newField.Type = typ
// copy existing receiver field list and set new receiver field
newFieldList := *f.Decl.Recv
newFieldList.List = []*ast.Field{&newField}
// copy existing function declaration and set new receiver field list
newFuncDecl := *f.Decl
newFuncDecl.Recv = &newFieldList
// copy existing function documentation and set new declaration
newF := *f
newF.Decl = &newFuncDecl
newF.Recv = recvString(typ)
// the Orig field never changes
newF.Level = level
return &newF
}
// collectEmbeddedMethods collects the embedded methods of typ in mset.
func (r *reader) collectEmbeddedMethods(mset methodSet, typ *namedType, recvTypeName string, embeddedIsPtr bool, level int, visited embeddedSet) {
visited[typ] = true
for embedded, isPtr := range typ.embedded {
// Once an embedded type is embedded as a pointer type
// all embedded types in those types are treated like
// pointer types for the purpose of the receiver type
// computation; i.e., embeddedIsPtr is sticky for this
// embedding hierarchy.
thisEmbeddedIsPtr := embeddedIsPtr || isPtr
for _, m := range embedded.methods {
// only top-level methods are embedded
if m.Level == 0 {
mset.add(customizeRecv(m, recvTypeName, thisEmbeddedIsPtr, level))
}
}
if !visited[embedded] {
r.collectEmbeddedMethods(mset, embedded, recvTypeName, thisEmbeddedIsPtr, level+1, visited)
}
}
delete(visited, typ)
}
// computeMethodSets determines the actual method sets for each type encountered.
func (r *reader) computeMethodSets() {
for _, t := range r.types {
// collect embedded methods for t
if t.isStruct {
// struct
r.collectEmbeddedMethods(t.methods, t, t.name, false, 1, make(embeddedSet))
} else {
// interface
// TODO(gri) fix this
}
}
// For any predeclared names that are declared locally, don't treat them as
// exported fields anymore.
for predecl := range r.shadowedPredecl {
for _, ityp := range r.fixmap[predecl] {
removeAnonymousField(predecl, ityp)
}
}
}
// cleanupTypes removes the association of functions and methods with
// types that have no declaration. Instead, these functions and methods
// are shown at the package level. It also removes types with missing
// declarations or which are not visible.
func (r *reader) cleanupTypes() {
for _, t := range r.types {
visible := r.isVisible(t.name)
predeclared := predeclaredTypes[t.name]
if t.decl == nil && (predeclared || visible && (t.isEmbedded || r.hasDotImp)) {
// t.name is a predeclared type (and was not redeclared in this package),
// or it was embedded somewhere but its declaration is missing (because
// the AST is incomplete), or we have a dot-import (and all bets are off):
// move any associated values, funcs, and methods back to the top-level so
// that they are not lost.
// 1) move values
r.values = append(r.values, t.values...)
// 2) move factory functions
for name, f := range t.funcs {
// in a correct AST, package-level function names
// are all different - no need to check for conflicts
r.funcs[name] = f
}
// 3) move methods
if !predeclared {
for name, m := range t.methods {
// don't overwrite functions with the same name - drop them
if _, found := r.funcs[name]; !found {
r.funcs[name] = m
}
}
}
}
// remove types w/o declaration or which are not visible
if t.decl == nil || !visible {
delete(r.types, t.name)
}
}
}
// ----------------------------------------------------------------------------
// Sorting
type data struct {
n int
swap func(i, j int)
less func(i, j int) bool
}
func (d *data) Len() int { return d.n }
func (d *data) Swap(i, j int) { d.swap(i, j) }
func (d *data) Less(i, j int) bool { return d.less(i, j) }
// sortBy is a helper function for sorting.
func sortBy(less func(i, j int) bool, swap func(i, j int), n int) {
sort.Sort(&data{n, swap, less})
}
func sortedKeys(m map[string]int) []string {
list := make([]string, len(m))
i := 0
for key := range m {
list[i] = key
i++
}
sort.Strings(list)
return list
}
// sortingName returns the name to use when sorting d into place.
func sortingName(d *ast.GenDecl) string {
if len(d.Specs) == 1 {
if s, ok := d.Specs[0].(*ast.ValueSpec); ok {
return s.Names[0].Name
}
}
return ""
}
func sortedValues(m []*Value, tok token.Token) []*Value {
list := make([]*Value, len(m)) // big enough in any case
i := 0
for _, val := range m {
if val.Decl.Tok == tok {
list[i] = val
i++
}
}
list = list[0:i]
sortBy(
func(i, j int) bool {
if ni, nj := sortingName(list[i].Decl), sortingName(list[j].Decl); ni != nj {
return ni < nj
}
return list[i].order < list[j].order
},
func(i, j int) { list[i], list[j] = list[j], list[i] },
len(list),
)
return list
}
func sortedTypes(m map[string]*namedType, allMethods bool) []*Type {
list := make([]*Type, len(m))
i := 0
for _, t := range m {
list[i] = &Type{
Doc: t.doc,
Name: t.name,
Decl: t.decl,
Consts: sortedValues(t.values, token.CONST),
Vars: sortedValues(t.values, token.VAR),
Funcs: sortedFuncs(t.funcs, true),
Methods: sortedFuncs(t.methods, allMethods),
}
i++
}
sortBy(
func(i, j int) bool { return list[i].Name < list[j].Name },
func(i, j int) { list[i], list[j] = list[j], list[i] },
len(list),
)
return list
}
func removeStar(s string) string {
if len(s) > 0 && s[0] == '*' {
return s[1:]
}
return s
}
func sortedFuncs(m methodSet, allMethods bool) []*Func {
list := make([]*Func, len(m))
i := 0
for _, m := range m {
// determine which methods to include
switch {
case m.Decl == nil:
// exclude conflict entry
case allMethods, m.Level == 0, !token.IsExported(removeStar(m.Orig)):
// forced inclusion, method not embedded, or method
// embedded but original receiver type not exported
list[i] = m
i++
}
}
list = list[0:i]
sortBy(
func(i, j int) bool { return list[i].Name < list[j].Name },
func(i, j int) { list[i], list[j] = list[j], list[i] },
len(list),
)
return list
}
// noteBodies returns a list of note body strings given a list of notes.
// This is only used to populate the deprecated Package.Bugs field.
func noteBodies(notes []*Note) []string {
var list []string
for _, n := range notes {
list = append(list, n.Body)
}
return list
}
// ----------------------------------------------------------------------------
// Predeclared identifiers
// IsPredeclared reports whether s is a predeclared identifier.
func IsPredeclared(s string) bool {
return predeclaredTypes[s] || predeclaredFuncs[s] || predeclaredConstants[s]
}
var predeclaredTypes = map[string]bool{
"any": true,
"bool": true,
"byte": true,
"comparable": true,
"complex64": true,
"complex128": true,
"error": true,
"float32": true,
"float64": true,
"int": true,
"int8": true,
"int16": true,
"int32": true,
"int64": true,
"rune": true,
"string": true,
"uint": true,
"uint8": true,
"uint16": true,
"uint32": true,
"uint64": true,
"uintptr": true,
}
var predeclaredFuncs = map[string]bool{
"append": true,
"cap": true,
"close": true,
"complex": true,
"copy": true,
"delete": true,
"imag": true,
"len": true,
"make": true,
"new": true,
"panic": true,
"print": true,
"println": true,
"real": true,
"recover": true,
}
var predeclaredConstants = map[string]bool{
"false": true,
"iota": true,
"nil": true,
"true": true,
}
// assumedPackageName returns the assumed package name
// for a given import path. This is a copy of
// golang.org/x/tools/internal/imports.ImportPathToAssumedName.
func assumedPackageName(importPath string) string {
notIdentifier := func(ch rune) bool {
return !('a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z' ||
'0' <= ch && ch <= '9' ||
ch == '_' ||
ch >= utf8.RuneSelf && (unicode.IsLetter(ch) || unicode.IsDigit(ch)))
}
base := path.Base(importPath)
if strings.HasPrefix(base, "v") {
if _, err := strconv.Atoi(base[1:]); err == nil {
dir := path.Dir(importPath)
if dir != "." {
base = path.Base(dir)
}
}
}
base = strings.TrimPrefix(base, "go-")
if i := strings.IndexFunc(base, notIdentifier); i >= 0 {
base = base[:i]
}
return base
}
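// For example (illustrative; the import paths are assumptions):
//
//	assumedPackageName("example.com/user/go-sql") // "sql"
//	assumedPackageName("example.com/user/pkg/v2") // "pkg"
//	assumedPackageName("example.com/foo-bar")     // "foo"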
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package doc
import (
"go/doc/comment"
"strings"
"unicode"
)
// firstSentence returns the first sentence in s.
// The sentence ends after the first period followed by space and
// not preceded by exactly one uppercase letter, or after the first
// ideographic or full-width full stop.
func firstSentence(s string) string {
var ppp, pp, p rune
for i, q := range s {
if q == '\n' || q == '\r' || q == '\t' {
q = ' '
}
if q == ' ' && p == '.' && (!unicode.IsUpper(pp) || unicode.IsUpper(ppp)) {
return s[:i]
}
// An ideographic ('。') or full-width ('．') full stop always ends the sentence.
if p == '。' || p == '．' {
return s[:i]
}
ppp, pp, p = pp, p, q
}
return s
}
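// For example (illustrative):
//
// firstSentence("Package doc extracts documentation. More detail.")
// returns "Package doc extracts documentation.", while
// firstSentence("Written by A. Developer. More text.") does not stop
// after the single-initial "A." and returns "Written by A. Developer."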
// Synopsis returns a cleaned version of the first sentence in text.
//
// Deprecated: New programs should use [Package.Synopsis] instead,
// which handles links in text properly.
func Synopsis(text string) string {
var p Package
return p.Synopsis(text)
}
// IllegalPrefixes is a list of lower-case prefixes that identify
// a comment as not being a doc comment.
// This helps to avoid misinterpreting the common mistake
// of a copyright notice immediately before a package statement
// as being a doc comment.
var IllegalPrefixes = []string{
"copyright",
"all rights",
"author",
}
// Synopsis returns a cleaned version of the first sentence in text.
// That sentence ends after the first period followed by space and not
// preceded by exactly one uppercase letter, or at the first paragraph break.
// The result string has no \n, \r, or \t characters and uses only single
// spaces between words. If text starts with any of the IllegalPrefixes,
// the result is the empty string.
func (p *Package) Synopsis(text string) string {
text = firstSentence(text)
lower := strings.ToLower(text)
for _, prefix := range IllegalPrefixes {
if strings.HasPrefix(lower, prefix) {
return ""
}
}
pr := p.Printer()
pr.TextWidth = -1
d := p.Parser().Parse(text)
if len(d.Content) == 0 {
return ""
}
if _, ok := d.Content[0].(*comment.Paragraph); !ok {
return ""
}
d.Content = d.Content[:1] // might be blank lines, code blocks, etc in “first sentence”
return strings.TrimSpace(string(pr.Text(d)))
}
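// A minimal usage sketch from a client package (illustrative):
//
//	var p doc.Package
//	s := p.Synopsis("Package doc extracts documentation.\nMore detail.")
//	// s == "Package doc extracts documentation."
//	s = p.Synopsis("Copyright 2009 The Go Authors.")
//	// s == "" (text starts with an IllegalPrefixes entry)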
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package format implements standard formatting of Go source.
//
// Note that formatting of Go source code changes over time, so tools relying on
// consistent formatting should execute a specific version of the gofmt binary
// instead of using this package. That way, the formatting will be stable, and
// the tools won't need to be recompiled each time gofmt changes.
//
// For example, pre-submit checks that use this package directly would behave
// differently depending on what Go version each developer uses, causing the
// check to be inherently fragile.
package format
import (
"bytes"
"fmt"
"go/ast"
"go/parser"
"go/printer"
"go/token"
"io"
)
// Keep these in sync with cmd/gofmt/gofmt.go.
const (
tabWidth = 8
printerMode = printer.UseSpaces | printer.TabIndent | printerNormalizeNumbers
// printerNormalizeNumbers means to canonicalize number literal prefixes
// and exponents while printing. See https://golang.org/doc/go1.13#gofmt.
//
// This value is defined in go/printer specifically for go/format and cmd/gofmt.
printerNormalizeNumbers = 1 << 30
)
var config = printer.Config{Mode: printerMode, Tabwidth: tabWidth}
const parserMode = parser.ParseComments | parser.SkipObjectResolution
// Node formats node in canonical gofmt style and writes the result to dst.
//
// The node type must be *ast.File, *printer.CommentedNode, []ast.Decl,
// []ast.Stmt, or assignment-compatible to ast.Expr, ast.Decl, ast.Spec,
// or ast.Stmt. Node does not modify node. Imports are not sorted for
// nodes representing partial source files (for instance, if the node is
// not an *ast.File or a *printer.CommentedNode not wrapping an *ast.File).
//
// The function may return early (before the entire result is written)
// and return a formatting error, for instance due to an incorrect AST.
func Node(dst io.Writer, fset *token.FileSet, node any) error {
// Determine if we have a complete source file (file != nil).
var file *ast.File
var cnode *printer.CommentedNode
switch n := node.(type) {
case *ast.File:
file = n
case *printer.CommentedNode:
if f, ok := n.Node.(*ast.File); ok {
file = f
cnode = n
}
}
// Sort imports if necessary.
if file != nil && hasUnsortedImports(file) {
// Make a copy of the AST because ast.SortImports is destructive.
// TODO(gri) Do this more efficiently.
var buf bytes.Buffer
err := config.Fprint(&buf, fset, file)
if err != nil {
return err
}
file, err = parser.ParseFile(fset, "", buf.Bytes(), parserMode)
if err != nil {
// We should never get here. If we do, provide a good diagnostic.
return fmt.Errorf("format.Node internal error (%s)", err)
}
ast.SortImports(fset, file)
// Use new file with sorted imports.
node = file
if cnode != nil {
node = &printer.CommentedNode{Node: file, Comments: cnode.Comments}
}
}
return config.Fprint(dst, fset, node)
}
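// Example (editor's sketch, not part of the original source): Node applied to
// a parsed file whose grouped imports are out of order; the file name and
// source text are made up for illustration.
//
//	package main
//
//	import (
//		"go/format"
//		"go/parser"
//		"go/token"
//		"log"
//		"os"
//	)
//
//	func main() {
//		const src = "package p\n\nimport (\n\"os\"\n\"fmt\"\n)\n\nvar _ = fmt.Sprint(os.Args)\n"
//		fset := token.NewFileSet()
//		f, err := parser.ParseFile(fset, "p.go", src, parser.ParseComments)
//		if err != nil {
//			log.Fatal(err)
//		}
//		// The grouped imports are unsorted, so Node reprints, reparses,
//		// sorts them, and then writes canonical gofmt output.
//		if err := format.Node(os.Stdout, fset, f); err != nil {
//			log.Fatal(err)
//		}
//	}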
// Source formats src in canonical gofmt style and returns the result
// or an (I/O or syntax) error. src is expected to be a syntactically
// correct Go source file, or a list of Go declarations or statements.
//
// If src is a partial source file, the leading and trailing space of src
// is applied to the result (such that it has the same leading and trailing
// space as src), and the result is indented by the same amount as the first
// line of src containing code. Imports are not sorted for partial source files.
func Source(src []byte) ([]byte, error) {
fset := token.NewFileSet()
file, sourceAdj, indentAdj, err := parse(fset, "", src, true)
if err != nil {
return nil, err
}
if sourceAdj == nil {
// Complete source file.
// TODO(gri) consider doing this always.
ast.SortImports(fset, file)
}
return format(fset, file, sourceAdj, indentAdj, src, config)
}
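// Example (editor's sketch, not part of the original source): Source applied
// to a statement list rather than a whole file, exercising the fragment path.
//
//	package main
//
//	import (
//		"fmt"
//		"go/format"
//		"log"
//	)
//
//	func main() {
//		out, err := format.Source([]byte("x:=1\nfmt.Println( x )"))
//		if err != nil {
//			log.Fatal(err)
//		}
//		fmt.Printf("%s\n", out)
//		// Prints:
//		// x := 1
//		// fmt.Println(x)
//	}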
func hasUnsortedImports(file *ast.File) bool {
for _, d := range file.Decls {
d, ok := d.(*ast.GenDecl)
if !ok || d.Tok != token.IMPORT {
// Not an import declaration, so we're done.
// Imports are always first.
return false
}
if d.Lparen.IsValid() {
// For now assume all grouped imports are unsorted.
// TODO(gri) Should check if they are sorted already.
return true
}
// Ungrouped imports are sorted by default.
}
return false
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// TODO(gri): This file and the file src/cmd/gofmt/internal.go are
// the same (but for this comment and the package name). Do not modify
// one without the other. Determine if we can factor out functionality
// in a public API. See also #11844 for context.
package format
import (
"bytes"
"go/ast"
"go/parser"
"go/printer"
"go/token"
"strings"
)
// parse parses src, which was read from the named file,
// as a Go source file, declaration, or statement list.
func parse(fset *token.FileSet, filename string, src []byte, fragmentOk bool) (
file *ast.File,
sourceAdj func(src []byte, indent int) []byte,
indentAdj int,
err error,
) {
// Try as whole source file.
file, err = parser.ParseFile(fset, filename, src, parserMode)
// If there's no error, return. If the error is that the source file didn't begin with a
// package line and source fragments are ok, fall through to
// try as a source fragment. Stop and return on any other error.
if err == nil || !fragmentOk || !strings.Contains(err.Error(), "expected 'package'") {
return
}
// If this is a declaration list, make it a source file
// by inserting a package clause.
// Insert using a ';', not a newline, so that the line numbers
// in psrc match the ones in src.
psrc := append([]byte("package p;"), src...)
file, err = parser.ParseFile(fset, filename, psrc, parserMode)
if err == nil {
sourceAdj = func(src []byte, indent int) []byte {
// Remove the package clause.
// Gofmt has turned the ';' into a '\n'.
src = src[indent+len("package p\n"):]
return bytes.TrimSpace(src)
}
return
}
// If the error is that the source file didn't begin with a
// declaration, fall through to try as a statement list.
// Stop and return on any other error.
if !strings.Contains(err.Error(), "expected declaration") {
return
}
// If this is a statement list, make it a source file
// by inserting a package clause and turning the list
// into a function body. This handles expressions too.
// Insert using a ';', not a newline, so that the line numbers
// in fsrc match the ones in src. Add an extra '\n' before the '}'
// to make sure comments are flushed before the '}'.
fsrc := append(append([]byte("package p; func _() {"), src...), '\n', '\n', '}')
file, err = parser.ParseFile(fset, filename, fsrc, parserMode)
if err == nil {
sourceAdj = func(src []byte, indent int) []byte {
// Cap adjusted indent to zero.
if indent < 0 {
indent = 0
}
// Remove the wrapping.
// Gofmt has turned the "; " into a "\n\n".
// There will be two non-blank lines with indent, hence 2*indent.
src = src[2*indent+len("package p\n\nfunc _() {"):]
// Remove only the "}\n" suffix: remaining whitespaces will be trimmed anyway
src = src[:len(src)-len("}\n")]
return bytes.TrimSpace(src)
}
// Gofmt has also indented the function body one level.
// Adjust that with indentAdj.
indentAdj = -1
}
// Succeeded, or out of options.
return
}
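// Editor's note: the wrapping performed above can be reproduced standalone.
// A minimal sketch (file name and fragment are made up) showing that a bare
// statement list becomes parseable once wrapped:
//
//	package main
//
//	import (
//		"fmt"
//		"go/parser"
//		"go/token"
//	)
//
//	func main() {
//		src := []byte("x := 1")
//		// Same transformation as above: the ';' keeps line numbers aligned,
//		// and the extra '\n' flushes trailing comments before the '}'.
//		fsrc := append(append([]byte("package p; func _() {"), src...), '\n', '\n', '}')
//		f, err := parser.ParseFile(token.NewFileSet(), "frag.go", fsrc, parser.ParseComments)
//		fmt.Println(err, len(f.Decls)) // <nil> 1
//	}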
// format formats the given package file originally obtained from src
// and adjusts the result based on the original source via sourceAdj
// and indentAdj.
func format(
fset *token.FileSet,
file *ast.File,
sourceAdj func(src []byte, indent int) []byte,
indentAdj int,
src []byte,
cfg printer.Config,
) ([]byte, error) {
if sourceAdj == nil {
// Complete source file.
var buf bytes.Buffer
err := cfg.Fprint(&buf, fset, file)
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// Partial source file.
// Determine and prepend leading space.
i, j := 0, 0
for j < len(src) && isSpace(src[j]) {
if src[j] == '\n' {
i = j + 1 // byte offset of last line in leading space
}
j++
}
var res []byte
res = append(res, src[:i]...)
// Determine and prepend indentation of first code line.
// Spaces are ignored unless there are no tabs,
// in which case spaces count as one tab.
indent := 0
hasSpace := false
for _, b := range src[i:j] {
switch b {
case ' ':
hasSpace = true
case '\t':
indent++
}
}
if indent == 0 && hasSpace {
indent = 1
}
for i := 0; i < indent; i++ {
res = append(res, '\t')
}
// Format the source.
// Write it without any leading and trailing space.
cfg.Indent = indent + indentAdj
var buf bytes.Buffer
err := cfg.Fprint(&buf, fset, file)
if err != nil {
return nil, err
}
out := sourceAdj(buf.Bytes(), cfg.Indent)
// If the adjusted output is empty, the source
// was empty but (possibly) for white space.
// The result is the incoming source.
if len(out) == 0 {
return src, nil
}
// Otherwise, append output to leading space.
res = append(res, out...)
// Determine and append trailing space.
i = len(src)
for i > 0 && isSpace(src[i-1]) {
i--
}
return append(res, src[i:]...), nil
}
// isSpace reports whether the byte is a space character.
// isSpace defines a space as being among the following bytes: ' ', '\t', '\n' and '\r'.
func isSpace(b byte) bool {
return b == ' ' || b == '\t' || b == '\n' || b == '\r'
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package importer provides access to export data importers.
package importer
import (
"go/build"
"go/internal/gccgoimporter"
"go/internal/gcimporter"
"go/internal/srcimporter"
"go/token"
"go/types"
"io"
"runtime"
)
// A Lookup function returns a reader to access package data for
// a given import path, or an error if no matching package is found.
type Lookup func(path string) (io.ReadCloser, error)
// ForCompiler returns an Importer for importing from installed packages
// for the compilers "gc" and "gccgo", or for importing directly
// from the source if the compiler argument is "source". In this
// latter case, importing may fail under circumstances where the
// exported API is not entirely defined in pure Go source code
// (if the package API depends on cgo-defined entities, the type
// checker won't have access to those).
//
// The lookup function is called each time the resulting importer needs
// to resolve an import path. In this mode the importer can only be
// invoked with canonical import paths (not relative or absolute ones);
// it is assumed that the translation to canonical import paths is being
// done by the client of the importer.
//
// A lookup function must be provided for correct module-aware operation.
//
// Deprecated: If lookup is nil, for backwards-compatibility, the importer
// will attempt to resolve imports in the $GOPATH workspace.
func ForCompiler(fset *token.FileSet, compiler string, lookup Lookup) types.Importer {
switch compiler {
case "gc":
return &gcimports{
fset: fset,
packages: make(map[string]*types.Package),
lookup: lookup,
}
case "gccgo":
var inst gccgoimporter.GccgoInstallation
if err := inst.InitFromDriver("gccgo"); err != nil {
return nil
}
return &gccgoimports{
packages: make(map[string]*types.Package),
importer: inst.GetImporter(nil, nil),
lookup: lookup,
}
case "source":
if lookup != nil {
panic("source importer for custom import path lookup not supported (issue #13847).")
}
return srcimporter.New(&build.Default, fset, make(map[string]*types.Package))
}
// compiler not supported
return nil
}
// For calls ForCompiler with a new FileSet.
//
// Deprecated: Use ForCompiler, which populates a FileSet
// with the positions of objects created by the importer.
func For(compiler string, lookup Lookup) types.Importer {
return ForCompiler(token.NewFileSet(), compiler, lookup)
}
// Default returns an Importer for the compiler that built the running binary.
// If available, the result implements types.ImporterFrom.
func Default() types.Importer {
return For(runtime.Compiler, nil)
}
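// Example (editor's sketch, not part of the original source): typical use of
// this package from a type-checking client. Importing "fmt" assumes compiled
// export data is available for the current toolchain.
//
//	package main
//
//	import (
//		"fmt"
//		"go/importer"
//		"log"
//	)
//
//	func main() {
//		pkg, err := importer.Default().Import("fmt")
//		if err != nil {
//			log.Fatal(err)
//		}
//		fmt.Println(pkg.Path(), pkg.Name()) // fmt fmt
//	}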
// gc importer
type gcimports struct {
fset *token.FileSet
packages map[string]*types.Package
lookup Lookup
}
func (m *gcimports) Import(path string) (*types.Package, error) {
return m.ImportFrom(path, "" /* no vendoring */, 0)
}
func (m *gcimports) ImportFrom(path, srcDir string, mode types.ImportMode) (*types.Package, error) {
if mode != 0 {
panic("mode must be 0")
}
return gcimporter.Import(m.fset, m.packages, path, srcDir, m.lookup)
}
// gccgo importer
type gccgoimports struct {
packages map[string]*types.Package
importer gccgoimporter.Importer
lookup Lookup
}
func (m *gccgoimports) Import(path string) (*types.Package, error) {
return m.ImportFrom(path, "" /* no vendoring */, 0)
}
func (m *gccgoimports) ImportFrom(path, srcDir string, mode types.ImportMode) (*types.Package, error) {
if mode != 0 {
panic("mode must be 0")
}
return m.importer(m.packages, path, srcDir, m.lookup)
}
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gccgoimporter
import (
"bytes"
"debug/elf"
"errors"
"fmt"
"internal/xcoff"
"io"
"strconv"
"strings"
)
// Magic strings for different archive file formats.
const (
armag = "!<arch>\n"
armagt = "!<thin>\n"
armagb = "<bigaf>\n"
)
// Offsets and sizes for fields in a standard archive header.
const (
arNameOff = 0
arNameSize = 16
arDateOff = arNameOff + arNameSize
arDateSize = 12
arUIDOff = arDateOff + arDateSize
arUIDSize = 6
arGIDOff = arUIDOff + arUIDSize
arGIDSize = 6
arModeOff = arGIDOff + arGIDSize
arModeSize = 8
arSizeOff = arModeOff + arModeSize
arSizeSize = 10
arFmagOff = arSizeOff + arSizeSize
arFmagSize = 2
arHdrSize = arFmagOff + arFmagSize
)
// The contents of the fmag field of a standard archive header.
const arfmag = "`\n"
// arExportData takes an archive file and returns a ReadSeeker for the
// export data in that file. This assumes that there is only one
// object in the archive containing export data, which is not quite
// what gccgo does; gccgo concatenates together all the export data
// for all the objects in the file. In practice that case does not arise.
func arExportData(archive io.ReadSeeker) (io.ReadSeeker, error) {
if _, err := archive.Seek(0, io.SeekStart); err != nil {
return nil, err
}
var buf [len(armag)]byte
if _, err := io.ReadFull(archive, buf[:]); err != nil {
return nil, err
}
switch string(buf[:]) {
case armag:
return standardArExportData(archive)
case armagt:
return nil, errors.New("unsupported thin archive")
case armagb:
return aixBigArExportData(archive)
default:
return nil, fmt.Errorf("unrecognized archive file format %q", buf[:])
}
}
// standardArExportData returns export data from a standard archive.
func standardArExportData(archive io.ReadSeeker) (io.ReadSeeker, error) {
off := int64(len(armag))
for {
var hdrBuf [arHdrSize]byte
if _, err := io.ReadFull(archive, hdrBuf[:]); err != nil {
return nil, err
}
off += arHdrSize
if !bytes.Equal(hdrBuf[arFmagOff:arFmagOff+arFmagSize], []byte(arfmag)) {
return nil, fmt.Errorf("archive header format header (%q)", hdrBuf[:])
}
size, err := strconv.ParseInt(strings.TrimSpace(string(hdrBuf[arSizeOff:arSizeOff+arSizeSize])), 10, 64)
if err != nil {
return nil, fmt.Errorf("error parsing size in archive header (%q): %v", hdrBuf[:], err)
}
fn := hdrBuf[arNameOff : arNameOff+arNameSize]
if fn[0] == '/' && (fn[1] == ' ' || fn[1] == '/' || string(fn[:8]) == "/SYM64/ ") {
// Archive symbol table or extended name table,
// which we don't care about.
} else {
archiveAt := readerAtFromSeeker(archive)
ret, err := elfFromAr(io.NewSectionReader(archiveAt, off, size))
if ret != nil || err != nil {
return ret, err
}
}
if size&1 != 0 {
size++
}
off += size
if _, err := archive.Seek(off, io.SeekStart); err != nil {
return nil, err
}
}
}
// elfFromAr tries to get export data from an archive member as an ELF file.
// If there is no export data, this returns nil, nil.
func elfFromAr(member *io.SectionReader) (io.ReadSeeker, error) {
ef, err := elf.NewFile(member)
if err != nil {
return nil, err
}
sec := ef.Section(".go_export")
if sec == nil {
return nil, nil
}
return sec.Open(), nil
}
// aixBigArExportData returns export data from an AIX big archive.
func aixBigArExportData(archive io.ReadSeeker) (io.ReadSeeker, error) {
archiveAt := readerAtFromSeeker(archive)
arch, err := xcoff.NewArchive(archiveAt)
if err != nil {
return nil, err
}
for _, mem := range arch.Members {
f, err := arch.GetFile(mem.Name)
if err != nil {
return nil, err
}
sdat := f.CSect(".go_export")
if sdat != nil {
return bytes.NewReader(sdat), nil
}
}
return nil, fmt.Errorf(".go_export not found in this archive")
}
// readerAtFromSeeker turns an io.ReadSeeker into an io.ReaderAt.
// This is only safe because there won't be any concurrent seeks
// while this code is executing.
func readerAtFromSeeker(rs io.ReadSeeker) io.ReaderAt {
if ret, ok := rs.(io.ReaderAt); ok {
return ret
}
return seekerReadAt{rs}
}
type seekerReadAt struct {
seeker io.ReadSeeker
}
func (sra seekerReadAt) ReadAt(p []byte, off int64) (int, error) {
if _, err := sra.seeker.Seek(off, io.SeekStart); err != nil {
return 0, err
}
return sra.seeker.Read(p)
}
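// Editor's note: a minimal re-creation of the adapter above (outside this
// package, so the unexported names are redeclared) showing the fallback path
// for a ReadSeeker that does not implement io.ReaderAt:
//
//	package main
//
//	import (
//		"fmt"
//		"io"
//		"strings"
//	)
//
//	type seekerReadAt struct{ seeker io.ReadSeeker }
//
//	func (sra seekerReadAt) ReadAt(p []byte, off int64) (int, error) {
//		if _, err := sra.seeker.Seek(off, io.SeekStart); err != nil {
//			return 0, err
//		}
//		return sra.seeker.Read(p)
//	}
//
//	// onlySeeker hides strings.Reader's ReadAt method so the adapter is used.
//	type onlySeeker struct{ io.ReadSeeker }
//
//	func main() {
//		ra := seekerReadAt{onlySeeker{strings.NewReader("hello, archive")}}
//		sr := io.NewSectionReader(ra, 7, 7) // read "archive" through the adapter
//		b, _ := io.ReadAll(sr)
//		fmt.Printf("%s\n", b) // archive
//	}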
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gccgoimporter
import (
"bufio"
"go/types"
"os"
"os/exec"
"path/filepath"
"strings"
)
// GccgoInstallation describes a specific installation of gccgo.
type GccgoInstallation struct {
// Version of gcc (e.g. 4.8.0).
GccVersion string
// Target triple (e.g. x86_64-unknown-linux-gnu).
TargetTriple string
// Built-in library paths used by this installation.
LibPaths []string
}
// InitFromDriver asks the driver at the given path for information about
// this GccgoInstallation. The given arguments are passed directly to the
// invocation of the driver.
func (inst *GccgoInstallation) InitFromDriver(gccgoPath string, args ...string) (err error) {
argv := append([]string{"-###", "-S", "-x", "go", "-"}, args...)
cmd := exec.Command(gccgoPath, argv...)
stderr, err := cmd.StderrPipe()
if err != nil {
return
}
err = cmd.Start()
if err != nil {
return
}
scanner := bufio.NewScanner(stderr)
for scanner.Scan() {
line := scanner.Text()
switch {
case strings.HasPrefix(line, "Target: "):
inst.TargetTriple = line[8:]
case len(line) > 0 && line[0] == ' ':
args := strings.Fields(line)
for _, arg := range args[1:] {
if strings.HasPrefix(arg, "-L") {
inst.LibPaths = append(inst.LibPaths, arg[2:])
}
}
}
}
argv = append([]string{"-dumpversion"}, args...)
stdout, err := exec.Command(gccgoPath, argv...).Output()
if err != nil {
return
}
inst.GccVersion = strings.TrimSpace(string(stdout))
return
}
// SearchPaths returns the list of export search paths for this GccgoInstallation.
func (inst *GccgoInstallation) SearchPaths() (paths []string) {
for _, lpath := range inst.LibPaths {
spath := filepath.Join(lpath, "go", inst.GccVersion)
fi, err := os.Stat(spath)
if err != nil || !fi.IsDir() {
continue
}
paths = append(paths, spath)
spath = filepath.Join(spath, inst.TargetTriple)
fi, err = os.Stat(spath)
if err != nil || !fi.IsDir() {
continue
}
paths = append(paths, spath)
}
paths = append(paths, inst.LibPaths...)
return
}
// GetImporter returns an Importer that searches incpaths, followed by the
// gcc installation's built-in search paths and the current directory.
func (inst *GccgoInstallation) GetImporter(incpaths []string, initmap map[*types.Package]InitData) Importer {
return GetImporter(append(append(incpaths, inst.SearchPaths()...), "."), initmap)
}
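// Example (editor's sketch, not part of the original source): the intended
// flow for this type. It assumes a gccgo driver on $PATH and the usual fmt,
// log, and go/types imports; being in an internal package, this is only
// callable from within GOROOT.
//
//	var inst gccgoimporter.GccgoInstallation
//	if err := inst.InitFromDriver("gccgo"); err != nil {
//		log.Fatal(err)
//	}
//	imp := inst.GetImporter(nil, nil)
//	imports := make(map[string]*types.Package)
//	pkg, err := imp(imports, "fmt", ".", nil)
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println(pkg.Path())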
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package gccgoimporter implements Import for gccgo-generated object files.
package gccgoimporter // import "go/internal/gccgoimporter"
import (
"bytes"
"debug/elf"
"fmt"
"go/types"
"internal/xcoff"
"io"
"os"
"path/filepath"
"strings"
)
// A PackageInit describes an imported package that needs initialization.
type PackageInit struct {
Name string // short package name
InitFunc string // name of init function
Priority int // priority of init function, see InitData.Priority
}
// InitData is the gccgo-specific init data for a package.
type InitData struct {
// Initialization priority of this package relative to other packages.
// This is based on the maximum depth of the package's dependency graph;
// it is guaranteed to be greater than that of its dependencies.
Priority int
// The list of packages which this package depends on to be initialized,
// including itself if needed. This is the subset of the transitive closure of
// the package's dependencies that need initialization.
Inits []PackageInit
}
// findExportFile locates the file from which to read export data.
// This is intended to replicate the logic in gofrontend.
func findExportFile(searchpaths []string, pkgpath string) (string, error) {
for _, spath := range searchpaths {
pkgfullpath := filepath.Join(spath, pkgpath)
pkgdir, name := filepath.Split(pkgfullpath)
for _, filepath := range [...]string{
pkgfullpath,
pkgfullpath + ".gox",
pkgdir + "lib" + name + ".so",
pkgdir + "lib" + name + ".a",
pkgfullpath + ".o",
} {
fi, err := os.Stat(filepath)
if err == nil && !fi.IsDir() {
return filepath, nil
}
}
}
return "", fmt.Errorf("%s: could not find export data (tried %s)", pkgpath, strings.Join(searchpaths, ":"))
}
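// Editor's illustration (hypothetical paths): with a search path of
// /usr/lib/go/13 and pkgpath "net/http", the candidates are tried in order:
//
//	/usr/lib/go/13/net/http
//	/usr/lib/go/13/net/http.gox
//	/usr/lib/go/13/net/libhttp.so
//	/usr/lib/go/13/net/libhttp.a
//	/usr/lib/go/13/net/http.o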
const (
gccgov1Magic = "v1;\n"
gccgov2Magic = "v2;\n"
gccgov3Magic = "v3;\n"
goimporterMagic = "\n$$ "
archiveMagic = "!<ar"
aixbigafMagic = "<big"
)
// openExportFile opens the export data file at the given path. If this is an ELF file,
// searches for and opens the .go_export section. If this is an archive,
// reads the export data from the first member, which is assumed to be an ELF file.
// This is intended to replicate the logic in gofrontend.
func openExportFile(fpath string) (reader io.ReadSeeker, closer io.Closer, err error) {
f, err := os.Open(fpath)
if err != nil {
return
}
closer = f
defer func() {
if err != nil && closer != nil {
f.Close()
}
}()
var magic [4]byte
_, err = f.ReadAt(magic[:], 0)
if err != nil {
return
}
var objreader io.ReaderAt
switch string(magic[:]) {
case gccgov1Magic, gccgov2Magic, gccgov3Magic, goimporterMagic:
// Raw export data.
reader = f
return
case archiveMagic, aixbigafMagic:
reader, err = arExportData(f)
return
default:
objreader = f
}
ef, err := elf.NewFile(objreader)
if err == nil {
sec := ef.Section(".go_export")
if sec == nil {
err = fmt.Errorf("%s: .go_export section not found", fpath)
return
}
reader = sec.Open()
return
}
xf, err := xcoff.NewFile(objreader)
if err == nil {
sdat := xf.CSect(".go_export")
if sdat == nil {
err = fmt.Errorf("%s: .go_export section not found", fpath)
return
}
reader = bytes.NewReader(sdat)
return
}
err = fmt.Errorf("%s: unrecognized file format", fpath)
return
}
// An Importer resolves import paths to Packages. The imports map records
// packages already known, indexed by package path.
// An importer must determine the canonical package path and check whether
// that package is already present in the map. If so, the Importer can return
// the map entry. Otherwise, the importer must load the package data for the
// given path into a new *Package, record it in the imports map, and return
// the package.
type Importer func(imports map[string]*types.Package, path, srcDir string, lookup func(string) (io.ReadCloser, error)) (*types.Package, error)
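// GetImporter returns an Importer that resolves packages by reading gccgo
// export data, either from files located via searchpaths (see findExportFile)
// or via the optional lookup function. If initmap is non-nil, it is populated
// with each imported package's InitData.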
func GetImporter(searchpaths []string, initmap map[*types.Package]InitData) Importer {
return func(imports map[string]*types.Package, pkgpath, srcDir string, lookup func(string) (io.ReadCloser, error)) (pkg *types.Package, err error) {
// TODO(gri): Use srcDir.
// Or not. It's possible that srcDir will fade in importance as
// the go command and other tools provide a translation table
// for relative imports (like ./foo or vendored imports).
if pkgpath == "unsafe" {
return types.Unsafe, nil
}
var reader io.ReadSeeker
var fpath string
var rc io.ReadCloser
if lookup != nil {
if p := imports[pkgpath]; p != nil && p.Complete() {
return p, nil
}
rc, err = lookup(pkgpath)
if err != nil {
return nil, err
}
}
if rc != nil {
defer rc.Close()
rs, ok := rc.(io.ReadSeeker)
if !ok {
return nil, fmt.Errorf("gccgo importer requires lookup to return an io.ReadSeeker, have %T", rc)
}
reader = rs
fpath = "<lookup " + pkgpath + ">"
// Take name from Name method (like on os.File) if present.
if n, ok := rc.(interface{ Name() string }); ok {
fpath = n.Name()
}
} else {
fpath, err = findExportFile(searchpaths, pkgpath)
if err != nil {
return nil, err
}
r, closer, err := openExportFile(fpath)
if err != nil {
return nil, err
}
if closer != nil {
defer closer.Close()
}
reader = r
}
var magics string
magics, err = readMagic(reader)
if err != nil {
return
}
if magics == archiveMagic || magics == aixbigafMagic {
reader, err = arExportData(reader)
if err != nil {
return
}
magics, err = readMagic(reader)
if err != nil {
return
}
}
switch magics {
case gccgov1Magic, gccgov2Magic, gccgov3Magic:
var p parser
p.init(fpath, reader, imports)
pkg = p.parsePackage()
if initmap != nil {
initmap[pkg] = p.initdata
}
// Excluded for now: Standard gccgo doesn't support this import format currently.
// case goimporterMagic:
// var data []byte
// data, err = io.ReadAll(reader)
// if err != nil {
// return
// }
// var n int
// n, pkg, err = importer.ImportData(imports, data)
// if err != nil {
// return
// }
// if initmap != nil {
// suffixreader := bytes.NewReader(data[n:])
// var p parser
// p.init(fpath, suffixreader, nil)
// p.parseInitData()
// initmap[pkg] = p.initdata
// }
default:
err = fmt.Errorf("unrecognized magic string: %q", magics)
}
return
}
}
// readMagic reads the four bytes at the start of a ReadSeeker and
// returns them as a string.
func readMagic(reader io.ReadSeeker) (string, error) {
var magic [4]byte
if _, err := io.ReadFull(reader, magic[:]); err != nil {
return "", err
}
if _, err := reader.Seek(0, io.SeekStart); err != nil {
return "", err
}
return string(magic[:]), nil
}
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gccgoimporter
import (
"errors"
"fmt"
"go/constant"
"go/token"
"go/types"
"io"
"strconv"
"strings"
"text/scanner"
"unicode/utf8"
)
type parser struct {
scanner *scanner.Scanner
version string // format version
tok rune // current token
lit string // literal string; only valid for Ident, Int, String tokens
pkgpath string // package path of imported package
pkgname string // name of imported package
pkg *types.Package // reference to imported package
imports map[string]*types.Package // package path -> package object
typeList []types.Type // type number -> type
typeData []string // unparsed type data (v3 and later)
fixups []fixupRecord // fixups to apply at end of parsing
initdata InitData // package init priority data
aliases map[int]string // maps saved type number to alias name
}
// When reading export data it's possible to encounter a defined type
// N1 with an underlying defined type N2 while we are still reading in
// that defined type N2; see issues #29006 and #29198 for instances
// of this. Example:
//
// type N1 N2
// type N2 struct {
// ...
// p *N1
// }
//
// To handle such cases, the parser generates a fixup record (below) and
// delays setting of N1's underlying type until parsing is complete, at
// which point fixups are applied.
type fixupRecord struct {
toUpdate *types.Named // type to modify when fixup is processed
target types.Type // type that was incomplete when fixup was created
}
func (p *parser) init(filename string, src io.Reader, imports map[string]*types.Package) {
p.scanner = new(scanner.Scanner)
p.initScanner(filename, src)
p.imports = imports
p.aliases = make(map[int]string)
p.typeList = make([]types.Type, 1 /* type numbers start at 1 */, 16)
}
func (p *parser) initScanner(filename string, src io.Reader) {
p.scanner.Init(src)
p.scanner.Error = func(_ *scanner.Scanner, msg string) { p.error(msg) }
p.scanner.Mode = scanner.ScanIdents | scanner.ScanInts | scanner.ScanFloats | scanner.ScanStrings
p.scanner.Whitespace = 1<<'\t' | 1<<' '
p.scanner.Filename = filename // for good error messages
p.next()
}
type importError struct {
pos scanner.Position
err error
}
func (e importError) Error() string {
return fmt.Sprintf("import error %s (byte offset = %d): %s", e.pos, e.pos.Offset, e.err)
}
func (p *parser) error(err any) {
if s, ok := err.(string); ok {
err = errors.New(s)
}
// panic with a runtime.Error if err is not an error
panic(importError{p.scanner.Pos(), err.(error)})
}
func (p *parser) errorf(format string, args ...any) {
p.error(fmt.Errorf(format, args...))
}
func (p *parser) expect(tok rune) string {
lit := p.lit
if p.tok != tok {
p.errorf("expected %s, got %s (%s)", scanner.TokenString(tok), scanner.TokenString(p.tok), lit)
}
p.next()
return lit
}
func (p *parser) expectEOL() {
if p.version == "v1" || p.version == "v2" {
p.expect(';')
}
p.expect('\n')
}
func (p *parser) expectKeyword(keyword string) {
lit := p.expect(scanner.Ident)
if lit != keyword {
p.errorf("expected keyword %s, got %q", keyword, lit)
}
}
func (p *parser) parseString() string {
str, err := strconv.Unquote(p.expect(scanner.String))
if err != nil {
p.error(err)
}
return str
}
// unquotedString = { unquotedStringChar } .
// unquotedStringChar = <neither a whitespace nor a ';' char> .
func (p *parser) parseUnquotedString() string {
if p.tok == scanner.EOF {
p.error("unexpected EOF")
}
var b strings.Builder
b.WriteString(p.scanner.TokenText())
// This loop needs to examine each character before deciding whether to consume it. If we see a semicolon,
// we need to let it be consumed by p.next().
for ch := p.scanner.Peek(); ch != '\n' && ch != ';' && ch != scanner.EOF && p.scanner.Whitespace&(1<<uint(ch)) == 0; ch = p.scanner.Peek() {
b.WriteRune(ch)
p.scanner.Next()
}
p.next()
return b.String()
}
func (p *parser) next() {
p.tok = p.scanner.Scan()
switch p.tok {
case scanner.Ident, scanner.Int, scanner.Float, scanner.String, '·':
p.lit = p.scanner.TokenText()
default:
p.lit = ""
}
}
func (p *parser) parseQualifiedName() (path, name string) {
return p.parseQualifiedNameStr(p.parseString())
}
func (p *parser) parseUnquotedQualifiedName() (path, name string) {
return p.parseQualifiedNameStr(p.parseUnquotedString())
}
// qualifiedName = [ ["."] unquotedString "." ] unquotedString .
//
// The above production uses greedy matching.
func (p *parser) parseQualifiedNameStr(unquotedName string) (pkgpath, name string) {
parts := strings.Split(unquotedName, ".")
if parts[0] == "" {
parts = parts[1:]
}
switch len(parts) {
case 0:
p.errorf("malformed qualified name: %q", unquotedName)
case 1:
// unqualified name
pkgpath = p.pkgpath
name = parts[0]
default:
// qualified name, which may contain periods
pkgpath = strings.Join(parts[0:len(parts)-1], ".")
name = parts[len(parts)-1]
}
return
}
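// Editor's illustration (made-up inputs): because matching is greedy, only
// the final '.'-separated part becomes the name. Assuming p.pkgpath is
// "example.com/m":
//
//	"Reader"                -> pkgpath "example.com/m",    name "Reader"
//	"bytes.Buffer"          -> pkgpath "bytes",            name "Buffer"
//	"gopkg.in/yaml.v2.Node" -> pkgpath "gopkg.in/yaml.v2", name "Node"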
// getPkg returns the package for a given path. If the package is
// not found but we have a package name, create the package and
// add it to the p.imports map.
func (p *parser) getPkg(pkgpath, name string) *types.Package {
// package unsafe is not in the imports map - handle explicitly
if pkgpath == "unsafe" {
return types.Unsafe
}
pkg := p.imports[pkgpath]
if pkg == nil && name != "" {
pkg = types.NewPackage(pkgpath, name)
p.imports[pkgpath] = pkg
}
return pkg
}
// parseExportedName is like parseQualifiedName, but
// the package path is resolved to an imported *types.Package.
//
// ExportedName = string [string] .
func (p *parser) parseExportedName() (pkg *types.Package, name string) {
path, name := p.parseQualifiedName()
var pkgname string
if p.tok == scanner.String {
pkgname = p.parseString()
}
pkg = p.getPkg(path, pkgname)
if pkg == nil {
p.errorf("package %s (path = %q) not found", name, path)
}
return
}
// Name = QualifiedName | "?" .
func (p *parser) parseName() string {
if p.tok == '?' {
// Anonymous.
p.next()
return ""
}
// The package path is redundant for us. Don't try to parse it.
_, name := p.parseUnquotedQualifiedName()
return name
}
func deref(typ types.Type) types.Type {
if p, _ := typ.(*types.Pointer); p != nil {
typ = p.Elem()
}
return typ
}
// Field = Name Type [string] .
func (p *parser) parseField(pkg *types.Package) (field *types.Var, tag string) {
name := p.parseName()
typ, n := p.parseTypeExtended(pkg)
anon := false
if name == "" {
anon = true
// Alias?
if aname, ok := p.aliases[n]; ok {
name = aname
} else {
switch typ := deref(typ).(type) {
case *types.Basic:
name = typ.Name()
case *types.Named:
name = typ.Obj().Name()
default:
p.error("embedded field expected")
}
}
}
field = types.NewField(token.NoPos, pkg, name, typ, anon)
if p.tok == scanner.String {
tag = p.parseString()
}
return
}
// Param = Name ["..."] Type .
func (p *parser) parseParam(pkg *types.Package) (param *types.Var, isVariadic bool) {
name := p.parseName()
// Ignore names invented for inlinable functions.
if strings.HasPrefix(name, "p.") || strings.HasPrefix(name, "r.") || strings.HasPrefix(name, "$ret") {
name = ""
}
if p.tok == '<' && p.scanner.Peek() == 'e' {
// EscInfo = "<esc:" int ">" . (optional and ignored)
p.next()
p.expectKeyword("esc")
p.expect(':')
p.expect(scanner.Int)
p.expect('>')
}
if p.tok == '.' {
p.next()
p.expect('.')
p.expect('.')
isVariadic = true
}
typ := p.parseType(pkg)
if isVariadic {
typ = types.NewSlice(typ)
}
param = types.NewParam(token.NoPos, pkg, name, typ)
return
}
// Var = Name Type .
func (p *parser) parseVar(pkg *types.Package) *types.Var {
name := p.parseName()
v := types.NewVar(token.NoPos, pkg, name, p.parseType(pkg))
if name[0] == '.' || name[0] == '<' {
// This is an unexported variable,
// or a variable defined in a different package.
// We only want to record exported variables.
return nil
}
return v
}
// Conversion = "convert" "(" Type "," ConstValue ")" .
func (p *parser) parseConversion(pkg *types.Package) (val constant.Value, typ types.Type) {
p.expectKeyword("convert")
p.expect('(')
typ = p.parseType(pkg)
p.expect(',')
val, _ = p.parseConstValue(pkg)
p.expect(')')
return
}
// ConstValue = string | "false" | "true" | ["-"] (int ["'"] | FloatOrComplex) | Conversion .
// FloatOrComplex = float ["i" | ("+"|"-") float "i"] .
func (p *parser) parseConstValue(pkg *types.Package) (val constant.Value, typ types.Type) {
// v3 changed to $false, $true, $convert, to avoid confusion
// with variable names in inline function bodies.
if p.tok == '$' {
p.next()
if p.tok != scanner.Ident {
p.errorf("expected identifier after '$', got %s (%q)", scanner.TokenString(p.tok), p.lit)
}
}
switch p.tok {
case scanner.String:
str := p.parseString()
val = constant.MakeString(str)
typ = types.Typ[types.UntypedString]
return
case scanner.Ident:
b := false
switch p.lit {
case "false":
case "true":
b = true
case "convert":
return p.parseConversion(pkg)
default:
p.errorf("expected const value, got %s (%q)", scanner.TokenString(p.tok), p.lit)
}
p.next()
val = constant.MakeBool(b)
typ = types.Typ[types.UntypedBool]
return
}
sign := ""
if p.tok == '-' {
p.next()
sign = "-"
}
switch p.tok {
case scanner.Int:
val = constant.MakeFromLiteral(sign+p.lit, token.INT, 0)
if val == nil {
p.error("could not parse integer literal")
}
p.next()
if p.tok == '\'' {
p.next()
typ = types.Typ[types.UntypedRune]
} else {
typ = types.Typ[types.UntypedInt]
}
case scanner.Float:
re := sign + p.lit
p.next()
var im string
switch p.tok {
case '+':
p.next()
im = p.expect(scanner.Float)
case '-':
p.next()
im = "-" + p.expect(scanner.Float)
case scanner.Ident:
// re is in fact the imaginary component. Expect "i" below.
im = re
re = "0"
default:
val = constant.MakeFromLiteral(re, token.FLOAT, 0)
if val == nil {
p.error("could not parse float literal")
}
typ = types.Typ[types.UntypedFloat]
return
}
p.expectKeyword("i")
reval := constant.MakeFromLiteral(re, token.FLOAT, 0)
if reval == nil {
p.error("could not parse real component of complex literal")
}
imval := constant.MakeFromLiteral(im+"i", token.IMAG, 0)
if imval == nil {
p.error("could not parse imag component of complex literal")
}
val = constant.BinaryOp(reval, token.ADD, imval)
typ = types.Typ[types.UntypedComplex]
default:
p.errorf("expected const value, got %s (%q)", scanner.TokenString(p.tok), p.lit)
}
return
}
// Const = Name [Type] "=" ConstValue .
func (p *parser) parseConst(pkg *types.Package) *types.Const {
name := p.parseName()
var typ types.Type
if p.tok == '<' {
typ = p.parseType(pkg)
}
p.expect('=')
val, vtyp := p.parseConstValue(pkg)
if typ == nil {
typ = vtyp
}
return types.NewConst(token.NoPos, pkg, name, typ, val)
}
// reserved is a singleton type used to fill type map slots that have
// been reserved (i.e., for which a type number has been parsed) but
// which don't have their actual type yet. When the type map is updated,
// the actual type must replace a reserved entry (or we have an internal
// error). Used for self-verification only - not required for correctness.
var reserved = new(struct{ types.Type })
// reserve reserves the type map entry n for future use.
func (p *parser) reserve(n int) {
// Notes:
// - for pre-V3 export data, the type numbers we see are
// guaranteed to be in increasing order, so we append a
// reserved entry onto the list.
// - for V3+ export data, type numbers can appear in
// any order, however the 'types' section tells us the
// total number of types, hence typeList is pre-allocated.
if len(p.typeData) == 0 {
if n != len(p.typeList) {
p.errorf("invalid type number %d (out of sync)", n)
}
p.typeList = append(p.typeList, reserved)
} else {
if p.typeList[n] != nil {
p.errorf("previously visited type number %d", n)
}
p.typeList[n] = reserved
}
}
// update sets the type map entries for the entries in nlist to t.
// An entry in nlist can be a type number in p.typeList,
// used to resolve named types, or it can be a *types.Pointer,
// used to resolve pointers to named types in case they are referenced
// by embedded fields.
func (p *parser) update(t types.Type, nlist []any) {
if t == reserved {
p.errorf("internal error: update(%v) invoked on reserved", nlist)
}
if t == nil {
p.errorf("internal error: update(%v) invoked on nil", nlist)
}
for _, n := range nlist {
switch n := n.(type) {
case int:
if p.typeList[n] == t {
continue
}
if p.typeList[n] != reserved {
p.errorf("internal error: update(%v): %d not reserved", nlist, n)
}
p.typeList[n] = t
case *types.Pointer:
if *n != (types.Pointer{}) {
elem := n.Elem()
if elem == t {
continue
}
p.errorf("internal error: update: pointer already set to %v, expected %v", elem, t)
}
*n = *types.NewPointer(t)
default:
p.errorf("internal error: %T on nlist", n)
}
}
}
// NamedType = TypeName [ "=" ] Type { Method } .
// TypeName = ExportedName .
// Method = "func" "(" Param ")" Name ParamList ResultList [InlineBody] ";" .
func (p *parser) parseNamedType(nlist []any) types.Type {
pkg, name := p.parseExportedName()
scope := pkg.Scope()
obj := scope.Lookup(name)
if obj != nil && obj.Type() == nil {
p.errorf("%v has nil type", obj)
}
if p.tok == scanner.Ident && p.lit == "notinheap" {
p.next()
// The go/types package has no way of recording that
// this type is marked notinheap. Presumably no user
// of this package actually cares.
}
// type alias
if p.tok == '=' {
p.next()
p.aliases[nlist[len(nlist)-1].(int)] = name
if obj != nil {
// use the previously imported (canonical) type
t := obj.Type()
p.update(t, nlist)
p.parseType(pkg) // discard
return t
}
t := p.parseType(pkg, nlist...)
obj = types.NewTypeName(token.NoPos, pkg, name, t)
scope.Insert(obj)
return t
}
// defined type
if obj == nil {
// A named type may be referred to before the underlying type
// is known - set it up.
tname := types.NewTypeName(token.NoPos, pkg, name, nil)
types.NewNamed(tname, nil, nil)
scope.Insert(tname)
obj = tname
}
// use the previously imported (canonical), or newly created type
t := obj.Type()
p.update(t, nlist)
nt, ok := t.(*types.Named)
if !ok {
// This can happen for unsafe.Pointer, which is a TypeName holding a Basic type.
pt := p.parseType(pkg)
if pt != t {
p.error("unexpected underlying type for non-named TypeName")
}
return t
}
underlying := p.parseType(pkg)
if nt.Underlying() == nil {
if underlying.Underlying() == nil {
fix := fixupRecord{toUpdate: nt, target: underlying}
p.fixups = append(p.fixups, fix)
} else {
nt.SetUnderlying(underlying.Underlying())
}
}
if p.tok == '\n' {
p.next()
// collect associated methods
for p.tok == scanner.Ident {
p.expectKeyword("func")
if p.tok == '/' {
// Skip a /*nointerface*/ or /*asm ID */ comment.
p.expect('/')
p.expect('*')
if p.expect(scanner.Ident) == "asm" {
p.parseUnquotedString()
}
p.expect('*')
p.expect('/')
}
p.expect('(')
receiver, _ := p.parseParam(pkg)
p.expect(')')
name := p.parseName()
params, isVariadic := p.parseParamList(pkg)
results := p.parseResultList(pkg)
p.skipInlineBody()
p.expectEOL()
sig := types.NewSignatureType(receiver, nil, nil, params, results, isVariadic)
nt.AddMethod(types.NewFunc(token.NoPos, pkg, name, sig))
}
}
return nt
}
func (p *parser) parseInt64() int64 {
lit := p.expect(scanner.Int)
n, err := strconv.ParseInt(lit, 10, 64)
if err != nil {
p.error(err)
}
return n
}
func (p *parser) parseInt() int {
lit := p.expect(scanner.Int)
n, err := strconv.ParseInt(lit, 10, 0 /* int */)
if err != nil {
p.error(err)
}
return int(n)
}
// ArrayOrSliceType = "[" [ int ] "]" Type .
func (p *parser) parseArrayOrSliceType(pkg *types.Package, nlist []any) types.Type {
p.expect('[')
if p.tok == ']' {
p.next()
t := new(types.Slice)
p.update(t, nlist)
*t = *types.NewSlice(p.parseType(pkg))
return t
}
t := new(types.Array)
p.update(t, nlist)
len := p.parseInt64()
p.expect(']')
*t = *types.NewArray(p.parseType(pkg), len)
return t
}
// MapType = "map" "[" Type "]" Type .
func (p *parser) parseMapType(pkg *types.Package, nlist []any) types.Type {
p.expectKeyword("map")
t := new(types.Map)
p.update(t, nlist)
p.expect('[')
key := p.parseType(pkg)
p.expect(']')
elem := p.parseType(pkg)
*t = *types.NewMap(key, elem)
return t
}
// ChanType = "chan" ["<-" | "-<"] Type .
func (p *parser) parseChanType(pkg *types.Package, nlist []any) types.Type {
p.expectKeyword("chan")
t := new(types.Chan)
p.update(t, nlist)
dir := types.SendRecv
switch p.tok {
case '-':
p.next()
p.expect('<')
dir = types.SendOnly
case '<':
// don't consume '<' if it belongs to Type
if p.scanner.Peek() == '-' {
p.next()
p.expect('-')
dir = types.RecvOnly
}
}
*t = *types.NewChan(dir, p.parseType(pkg))
return t
}
// StructType = "struct" "{" { Field } "}" .
func (p *parser) parseStructType(pkg *types.Package, nlist []any) types.Type {
p.expectKeyword("struct")
t := new(types.Struct)
p.update(t, nlist)
var fields []*types.Var
var tags []string
p.expect('{')
for p.tok != '}' && p.tok != scanner.EOF {
field, tag := p.parseField(pkg)
p.expect(';')
fields = append(fields, field)
tags = append(tags, tag)
}
p.expect('}')
*t = *types.NewStruct(fields, tags)
return t
}
// ParamList = "(" [ { Parameter "," } Parameter ] ")" .
func (p *parser) parseParamList(pkg *types.Package) (*types.Tuple, bool) {
var list []*types.Var
isVariadic := false
p.expect('(')
for p.tok != ')' && p.tok != scanner.EOF {
if len(list) > 0 {
p.expect(',')
}
par, variadic := p.parseParam(pkg)
list = append(list, par)
if variadic {
if isVariadic {
p.error("... not on final argument")
}
isVariadic = true
}
}
p.expect(')')
return types.NewTuple(list...), isVariadic
}
// ResultList = Type | ParamList .
func (p *parser) parseResultList(pkg *types.Package) *types.Tuple {
switch p.tok {
case '<':
p.next()
if p.tok == scanner.Ident && p.lit == "inl" {
return nil
}
taa, _ := p.parseTypeAfterAngle(pkg)
return types.NewTuple(types.NewParam(token.NoPos, pkg, "", taa))
case '(':
params, _ := p.parseParamList(pkg)
return params
default:
return nil
}
}
// FunctionType = ParamList ResultList .
func (p *parser) parseFunctionType(pkg *types.Package, nlist []any) *types.Signature {
t := new(types.Signature)
p.update(t, nlist)
params, isVariadic := p.parseParamList(pkg)
results := p.parseResultList(pkg)
*t = *types.NewSignatureType(nil, nil, nil, params, results, isVariadic)
return t
}
// Func = Name FunctionType [InlineBody] .
func (p *parser) parseFunc(pkg *types.Package) *types.Func {
if p.tok == '/' {
// Skip an /*asm ID */ comment.
p.expect('/')
p.expect('*')
if p.expect(scanner.Ident) == "asm" {
p.parseUnquotedString()
}
p.expect('*')
p.expect('/')
}
name := p.parseName()
f := types.NewFunc(token.NoPos, pkg, name, p.parseFunctionType(pkg, nil))
p.skipInlineBody()
if name[0] == '.' || name[0] == '<' || strings.ContainsRune(name, '$') {
// This is an unexported function,
// or a function defined in a different package,
// or a type$equal or type$hash function.
// We only want to record exported functions.
return nil
}
return f
}
// InterfaceType = "interface" "{" { ("?" Type | Func) ";" } "}" .
func (p *parser) parseInterfaceType(pkg *types.Package, nlist []any) types.Type {
p.expectKeyword("interface")
t := new(types.Interface)
p.update(t, nlist)
var methods []*types.Func
var embeddeds []types.Type
p.expect('{')
for p.tok != '}' && p.tok != scanner.EOF {
if p.tok == '?' {
p.next()
embeddeds = append(embeddeds, p.parseType(pkg))
} else {
method := p.parseFunc(pkg)
if method != nil {
methods = append(methods, method)
}
}
p.expect(';')
}
p.expect('}')
*t = *types.NewInterfaceType(methods, embeddeds)
return t
}
// PointerType = "*" ("any" | Type) .
func (p *parser) parsePointerType(pkg *types.Package, nlist []any) types.Type {
p.expect('*')
if p.tok == scanner.Ident {
p.expectKeyword("any")
t := types.Typ[types.UnsafePointer]
p.update(t, nlist)
return t
}
t := new(types.Pointer)
p.update(t, nlist)
*t = *types.NewPointer(p.parseType(pkg, t))
return t
}
// TypeSpec = NamedType | MapType | ChanType | StructType | InterfaceType | PointerType | ArrayOrSliceType | FunctionType .
func (p *parser) parseTypeSpec(pkg *types.Package, nlist []any) types.Type {
switch p.tok {
case scanner.String:
return p.parseNamedType(nlist)
case scanner.Ident:
switch p.lit {
case "map":
return p.parseMapType(pkg, nlist)
case "chan":
return p.parseChanType(pkg, nlist)
case "struct":
return p.parseStructType(pkg, nlist)
case "interface":
return p.parseInterfaceType(pkg, nlist)
}
case '*':
return p.parsePointerType(pkg, nlist)
case '[':
return p.parseArrayOrSliceType(pkg, nlist)
case '(':
return p.parseFunctionType(pkg, nlist)
}
p.errorf("expected type name or literal, got %s", scanner.TokenString(p.tok))
return nil
}
const (
// From gofrontend/go/export.h
// Note that these values are negative in the gofrontend and have been made positive
// in the gccgoimporter.
gccgoBuiltinINT8 = 1
gccgoBuiltinINT16 = 2
gccgoBuiltinINT32 = 3
gccgoBuiltinINT64 = 4
gccgoBuiltinUINT8 = 5
gccgoBuiltinUINT16 = 6
gccgoBuiltinUINT32 = 7
gccgoBuiltinUINT64 = 8
gccgoBuiltinFLOAT32 = 9
gccgoBuiltinFLOAT64 = 10
gccgoBuiltinINT = 11
gccgoBuiltinUINT = 12
gccgoBuiltinUINTPTR = 13
gccgoBuiltinBOOL = 15
gccgoBuiltinSTRING = 16
gccgoBuiltinCOMPLEX64 = 17
gccgoBuiltinCOMPLEX128 = 18
gccgoBuiltinERROR = 19
gccgoBuiltinBYTE = 20
gccgoBuiltinRUNE = 21
)
func lookupBuiltinType(typ int) types.Type {
return [...]types.Type{
gccgoBuiltinINT8: types.Typ[types.Int8],
gccgoBuiltinINT16: types.Typ[types.Int16],
gccgoBuiltinINT32: types.Typ[types.Int32],
gccgoBuiltinINT64: types.Typ[types.Int64],
gccgoBuiltinUINT8: types.Typ[types.Uint8],
gccgoBuiltinUINT16: types.Typ[types.Uint16],
gccgoBuiltinUINT32: types.Typ[types.Uint32],
gccgoBuiltinUINT64: types.Typ[types.Uint64],
gccgoBuiltinFLOAT32: types.Typ[types.Float32],
gccgoBuiltinFLOAT64: types.Typ[types.Float64],
gccgoBuiltinINT: types.Typ[types.Int],
gccgoBuiltinUINT: types.Typ[types.Uint],
gccgoBuiltinUINTPTR: types.Typ[types.Uintptr],
gccgoBuiltinBOOL: types.Typ[types.Bool],
gccgoBuiltinSTRING: types.Typ[types.String],
gccgoBuiltinCOMPLEX64: types.Typ[types.Complex64],
gccgoBuiltinCOMPLEX128: types.Typ[types.Complex128],
gccgoBuiltinERROR: types.Universe.Lookup("error").Type(),
gccgoBuiltinBYTE: types.Universe.Lookup("byte").Type(),
gccgoBuiltinRUNE: types.Universe.Lookup("rune").Type(),
}[typ]
}
// Type = "<" "type" ( "-" int | int [ TypeSpec ] ) ">" .
//
// parseType updates the type map to t for all type numbers n.
func (p *parser) parseType(pkg *types.Package, n ...any) types.Type {
p.expect('<')
t, _ := p.parseTypeAfterAngle(pkg, n...)
return t
}
// parseTypeAfterAngle parses a Type, assuming the leading "<" has already been consumed.
func (p *parser) parseTypeAfterAngle(pkg *types.Package, n ...any) (t types.Type, n1 int) {
p.expectKeyword("type")
n1 = 0
switch p.tok {
case scanner.Int:
n1 = p.parseInt()
if p.tok == '>' {
if len(p.typeData) > 0 && p.typeList[n1] == nil {
p.parseSavedType(pkg, n1, n)
}
t = p.typeList[n1]
if len(p.typeData) == 0 && t == reserved {
p.errorf("invalid type cycle, type %d not yet defined (nlist=%v)", n1, n)
}
p.update(t, n)
} else {
p.reserve(n1)
t = p.parseTypeSpec(pkg, append(n, n1))
}
case '-':
p.next()
n1 := p.parseInt()
t = lookupBuiltinType(n1)
p.update(t, n)
default:
p.errorf("expected type number, got %s (%q)", scanner.TokenString(p.tok), p.lit)
return nil, 0
}
if t == nil || t == reserved {
p.errorf("internal error: bad return from parseType(%v)", n)
}
p.expect('>')
return
}
// parseTypeExtended is identical to parseType, but if the type in
// question is a saved type, returns the index as well as the type
// pointer (index returned is zero if we parsed a builtin).
func (p *parser) parseTypeExtended(pkg *types.Package, n ...any) (t types.Type, n1 int) {
p.expect('<')
t, n1 = p.parseTypeAfterAngle(pkg, n...)
return
}
// InlineBody = "<inl:NN>" .{NN}
// skipInlineBody skips over the body of an inlinable function, if present.
func (p *parser) skipInlineBody() {
// We may or may not have seen the '<' already, depending on
// whether the function had a result type or not.
if p.tok == '<' {
p.next()
p.expectKeyword("inl")
} else if p.tok != scanner.Ident || p.lit != "inl" {
return
} else {
p.next()
}
p.expect(':')
want := p.parseInt()
p.expect('>')
defer func(w uint64) {
p.scanner.Whitespace = w
}(p.scanner.Whitespace)
p.scanner.Whitespace = 0
got := 0
for got < want {
r := p.scanner.Next()
if r == scanner.EOF {
p.error("unexpected EOF")
}
got += utf8.RuneLen(r)
}
}
// Types = "types" maxp1 exportedp1 (offset length)* .
func (p *parser) parseTypes(pkg *types.Package) {
maxp1 := p.parseInt()
exportedp1 := p.parseInt()
p.typeList = make([]types.Type, maxp1)
type typeOffset struct {
offset int
length int
}
var typeOffsets []typeOffset
total := 0
for i := 1; i < maxp1; i++ {
len := p.parseInt()
typeOffsets = append(typeOffsets, typeOffset{total, len})
total += len
}
defer func(w uint64) {
p.scanner.Whitespace = w
}(p.scanner.Whitespace)
p.scanner.Whitespace = 0
// We should now have p.tok pointing to the final newline.
// The next runes from the scanner should be the type data.
var sb strings.Builder
for sb.Len() < total {
r := p.scanner.Next()
if r == scanner.EOF {
p.error("unexpected EOF")
}
sb.WriteRune(r)
}
allTypeData := sb.String()
p.typeData = []string{""} // type 0, unused
for _, to := range typeOffsets {
p.typeData = append(p.typeData, allTypeData[to.offset:to.offset+to.length])
}
for i := 1; i < exportedp1; i++ {
p.parseSavedType(pkg, i, nil)
}
}
// parseSavedType parses one saved type definition.
func (p *parser) parseSavedType(pkg *types.Package, i int, nlist []any) {
defer func(s *scanner.Scanner, tok rune, lit string) {
p.scanner = s
p.tok = tok
p.lit = lit
}(p.scanner, p.tok, p.lit)
p.scanner = new(scanner.Scanner)
p.initScanner(p.scanner.Filename, strings.NewReader(p.typeData[i]))
p.expectKeyword("type")
id := p.parseInt()
if id != i {
p.errorf("type ID mismatch: got %d, want %d", id, i)
}
if p.typeList[i] == reserved {
p.errorf("internal error: %d already reserved in parseSavedType", i)
}
if p.typeList[i] == nil {
p.reserve(i)
p.parseTypeSpec(pkg, append(nlist, i))
}
if p.typeList[i] == nil || p.typeList[i] == reserved {
p.errorf("internal error: parseSavedType(%d,%v) reserved/nil", i, nlist)
}
}
// PackageInit = unquotedString unquotedString int .
func (p *parser) parsePackageInit() PackageInit {
name := p.parseUnquotedString()
initfunc := p.parseUnquotedString()
priority := -1
if p.version == "v1" {
priority = p.parseInt()
}
return PackageInit{Name: name, InitFunc: initfunc, Priority: priority}
}
// Create the package if we have parsed both the package path and package name.
func (p *parser) maybeCreatePackage() {
if p.pkgname != "" && p.pkgpath != "" {
p.pkg = p.getPkg(p.pkgpath, p.pkgname)
}
}
// InitDataDirective = ( "v1" | "v2" | "v3" ) ";" |
//
// "priority" int ";" |
// "init" { PackageInit } ";" |
// "checksum" unquotedString ";" .
func (p *parser) parseInitDataDirective() {
if p.tok != scanner.Ident {
// unexpected token kind; panic
p.expect(scanner.Ident)
}
switch p.lit {
case "v1", "v2", "v3":
p.version = p.lit
p.next()
p.expect(';')
p.expect('\n')
case "priority":
p.next()
p.initdata.Priority = p.parseInt()
p.expectEOL()
case "init":
p.next()
for p.tok != '\n' && p.tok != ';' && p.tok != scanner.EOF {
p.initdata.Inits = append(p.initdata.Inits, p.parsePackageInit())
}
p.expectEOL()
case "init_graph":
p.next()
// The graph data is thrown away for now.
for p.tok != '\n' && p.tok != ';' && p.tok != scanner.EOF {
p.parseInt64()
p.parseInt64()
}
p.expectEOL()
case "checksum":
// Don't let the scanner try to parse the checksum as a number.
defer func(mode uint) {
p.scanner.Mode = mode
}(p.scanner.Mode)
p.scanner.Mode &^= scanner.ScanInts | scanner.ScanFloats
p.next()
p.parseUnquotedString()
p.expectEOL()
default:
p.errorf("unexpected identifier: %q", p.lit)
}
}
// Directive = InitDataDirective |
//
// "package" unquotedString [ unquotedString ] [ unquotedString ] ";" |
// "pkgpath" unquotedString ";" |
// "prefix" unquotedString ";" |
// "import" unquotedString unquotedString string ";" |
// "indirectimport" unquotedString unquotedstring ";" |
// "func" Func ";" |
// "type" Type ";" |
// "var" Var ";" |
// "const" Const ";" .
func (p *parser) parseDirective() {
if p.tok != scanner.Ident {
// unexpected token kind; panic
p.expect(scanner.Ident)
}
switch p.lit {
case "v1", "v2", "v3", "priority", "init", "init_graph", "checksum":
p.parseInitDataDirective()
case "package":
p.next()
p.pkgname = p.parseUnquotedString()
p.maybeCreatePackage()
if p.version != "v1" && p.tok != '\n' && p.tok != ';' {
p.parseUnquotedString()
p.parseUnquotedString()
}
p.expectEOL()
case "pkgpath":
p.next()
p.pkgpath = p.parseUnquotedString()
p.maybeCreatePackage()
p.expectEOL()
case "prefix":
p.next()
p.pkgpath = p.parseUnquotedString()
p.expectEOL()
case "import":
p.next()
pkgname := p.parseUnquotedString()
pkgpath := p.parseUnquotedString()
p.getPkg(pkgpath, pkgname)
p.parseString()
p.expectEOL()
case "indirectimport":
p.next()
pkgname := p.parseUnquotedString()
pkgpath := p.parseUnquotedString()
p.getPkg(pkgpath, pkgname)
p.expectEOL()
case "types":
p.next()
p.parseTypes(p.pkg)
p.expectEOL()
case "func":
p.next()
fun := p.parseFunc(p.pkg)
if fun != nil {
p.pkg.Scope().Insert(fun)
}
p.expectEOL()
case "type":
p.next()
p.parseType(p.pkg)
p.expectEOL()
case "var":
p.next()
v := p.parseVar(p.pkg)
if v != nil {
p.pkg.Scope().Insert(v)
}
p.expectEOL()
case "const":
p.next()
c := p.parseConst(p.pkg)
p.pkg.Scope().Insert(c)
p.expectEOL()
default:
p.errorf("unexpected identifier: %q", p.lit)
}
}
// Package = { Directive } .
func (p *parser) parsePackage() *types.Package {
for p.tok != scanner.EOF {
p.parseDirective()
}
for _, f := range p.fixups {
if f.target.Underlying() == nil {
p.errorf("internal error: fixup can't be applied, loop required")
}
f.toUpdate.SetUnderlying(f.target.Underlying())
}
p.fixups = nil
for _, typ := range p.typeList {
if it, ok := typ.(*types.Interface); ok {
it.Complete()
}
}
p.pkg.MarkComplete()
return p.pkg
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements FindExportData.
package gcimporter
import (
"bufio"
"fmt"
"io"
"strconv"
"strings"
)
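// readGopackHeader reads the next archive member header from r and
// returns the member's name and size.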
func readGopackHeader(r *bufio.Reader) (name string, size int, err error) {
// See $GOROOT/include/ar.h.
hdr := make([]byte, 16+12+6+6+8+10+2)
_, err = io.ReadFull(r, hdr)
if err != nil {
return
}
// leave for debugging
if false {
fmt.Printf("header: %s", hdr)
}
s := strings.TrimSpace(string(hdr[16+12+6+6+8:][:10]))
size, err = strconv.Atoi(s)
if err != nil || hdr[len(hdr)-2] != '`' || hdr[len(hdr)-1] != '\n' {
err = fmt.Errorf("invalid archive header")
return
}
name = strings.TrimSpace(string(hdr[:16]))
return
}
// FindExportData positions the reader r at the beginning of the
// export data section of an underlying GC-created object/archive
// file by reading from it. The reader must be positioned at the
// start of the file before calling this function. The hdr result
// is the string before the export data, either "$$" or "$$B".
func FindExportData(r *bufio.Reader) (hdr string, size int, err error) {
// Read first line to make sure this is an object file.
line, err := r.ReadSlice('\n')
if err != nil {
err = fmt.Errorf("can't find export data (%v)", err)
return
}
if string(line) == "!<arch>\n" {
// Archive file. Scan to __.PKGDEF.
var name string
if name, size, err = readGopackHeader(r); err != nil {
return
}
// First entry should be __.PKGDEF.
if name != "__.PKGDEF" {
err = fmt.Errorf("go archive is missing __.PKGDEF")
return
}
// Read first line of __.PKGDEF data, so that line
// is once again the first line of the input.
if line, err = r.ReadSlice('\n'); err != nil {
err = fmt.Errorf("can't find export data (%v)", err)
return
}
}
// Now at __.PKGDEF in archive or still at beginning of file.
// Either way, line should begin with "go object ".
if !strings.HasPrefix(string(line), "go object ") {
err = fmt.Errorf("not a Go object file")
return
}
size -= len(line)
// Skip over object header to export data.
// Begins after first line starting with $$.
for line[0] != '$' {
if line, err = r.ReadSlice('\n'); err != nil {
err = fmt.Errorf("can't find export data (%v)", err)
return
}
size -= len(line)
}
hdr = string(line)
return
}
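// Example (editor's sketch, not part of the original source): positioning a
// reader at the export data of a gc-built archive. The file name is
// hypothetical, the snippet assumes the usual bufio, fmt, log, and os
// imports, and go/internal/gcimporter is only importable within GOROOT.
//
//	f, err := os.Open("fmt.a") // an archive produced by the gc toolchain
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer f.Close()
//	buf := bufio.NewReader(f)
//	hdr, size, err := FindExportData(buf)
//	if err != nil {
//		log.Fatal(err)
//	}
//	// buf is now positioned at the export data; hdr is "$$\n" or "$$B\n".
//	fmt.Printf("export header %q (%d bytes follow)\n", hdr, size)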
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package gcimporter implements Import for gc-generated object files.
package gcimporter // import "go/internal/gcimporter"
import (
"bufio"
"bytes"
"fmt"
"go/build"
"go/token"
"go/types"
"internal/pkgbits"
"io"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
)
// debugging/development support
const debug = false
var exportMap sync.Map // package dir → func() (string, bool)
// lookupGorootExport returns the location of the export data
// (normally found in the build cache, but located in GOROOT/pkg
// in prior Go releases) for the package located in pkgDir.
//
// (We use the package's directory instead of its import path
// mainly to simplify handling of the packages in src/vendor
// and cmd/vendor.)
func lookupGorootExport(pkgDir string) (string, bool) {
f, ok := exportMap.Load(pkgDir)
if !ok {
var (
listOnce sync.Once
exportPath string
)
f, _ = exportMap.LoadOrStore(pkgDir, func() (string, bool) {
listOnce.Do(func() {
cmd := exec.Command("go", "list", "-export", "-f", "{{.Export}}", pkgDir)
cmd.Dir = build.Default.GOROOT
var output []byte
output, err := cmd.Output()
if err != nil {
return
}
exports := strings.Split(string(bytes.TrimSpace(output)), "\n")
if len(exports) != 1 {
return
}
exportPath = exports[0]
})
return exportPath, exportPath != ""
})
}
return f.(func() (string, bool))()
}
var pkgExts = [...]string{".a", ".o"} // a file from the build cache will have no extension
// FindPkg returns the filename and unique package id for an import
// path based on package information provided by build.Import (using
// the build.Default build.Context). A relative srcDir is interpreted
// relative to the current working directory.
// If no file was found, an empty filename is returned.
func FindPkg(path, srcDir string) (filename, id string) {
if path == "" {
return
}
var noext string
switch {
default:
// "x" -> "$GOPATH/pkg/$GOOS_$GOARCH/x.ext", "x"
// Don't require the source files to be present.
if abs, err := filepath.Abs(srcDir); err == nil { // see issue 14282
srcDir = abs
}
bp, _ := build.Import(path, srcDir, build.FindOnly|build.AllowBinary)
if bp.PkgObj == "" {
var ok bool
if bp.Goroot && bp.Dir != "" {
filename, ok = lookupGorootExport(bp.Dir)
}
if !ok {
id = path // make sure we have an id to print in error message
return
}
} else {
noext = strings.TrimSuffix(bp.PkgObj, ".a")
}
id = bp.ImportPath
case build.IsLocalImport(path):
// "./x" -> "/this/directory/x.ext", "/this/directory/x"
noext = filepath.Join(srcDir, path)
id = noext
case filepath.IsAbs(path):
// for completeness only - go/build.Import
// does not support absolute imports
// "/x" -> "/x.ext", "/x"
noext = path
id = path
}
if false { // for debugging
if path != id {
fmt.Printf("%s -> %s\n", path, id)
}
}
if filename != "" {
if f, err := os.Stat(filename); err == nil && !f.IsDir() {
return
}
}
// try extensions
for _, ext := range pkgExts {
filename = noext + ext
if f, err := os.Stat(filename); err == nil && !f.IsDir() {
return
}
}
filename = "" // not found
return
}
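// Illustrative FindPkg resolutions for the three import-path forms
// (all concrete paths hypothetical):
//
//	FindPkg("fmt", "")             // package path: resolved via build.Import;
//	                               // export data from the build cache or pkg dir
//	FindPkg("./util", "/src/app")  // local: tries "/src/app/util.a", then ".o"
//	FindPkg("/src/app/util", "")   // absolute: tries "/src/app/util.a", then ".o"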
// Import imports a gc-generated package given its import path and srcDir, adds
// the corresponding package object to the packages map, and returns the object.
// The packages map must contain all packages already imported.
func Import(fset *token.FileSet, packages map[string]*types.Package, path, srcDir string, lookup func(path string) (io.ReadCloser, error)) (pkg *types.Package, err error) {
var rc io.ReadCloser
var id string
if lookup != nil {
// With custom lookup specified, assume that caller has
// converted path to a canonical import path for use in the map.
if path == "unsafe" {
return types.Unsafe, nil
}
id = path
// No need to re-import if the package was imported completely before.
if pkg = packages[id]; pkg != nil && pkg.Complete() {
return
}
f, err := lookup(path)
if err != nil {
return nil, err
}
rc = f
} else {
var filename string
filename, id = FindPkg(path, srcDir)
if filename == "" {
if path == "unsafe" {
return types.Unsafe, nil
}
return nil, fmt.Errorf("can't find import: %q", id)
}
// no need to re-import if the package was imported completely before
if pkg = packages[id]; pkg != nil && pkg.Complete() {
return
}
// open file
f, err := os.Open(filename)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
// add file name to error
err = fmt.Errorf("%s: %v", filename, err)
}
}()
rc = f
}
defer rc.Close()
buf := bufio.NewReader(rc)
hdr, size, err := FindExportData(buf)
if err != nil {
return
}
switch hdr {
case "$$\n":
err = fmt.Errorf("import %q: old textual export format no longer supported (recompile library)", path)
case "$$B\n":
var exportFormat byte
if exportFormat, err = buf.ReadByte(); err != nil {
return
}
// The unified export format starts with a 'u'; the indexed export
// format starts with an 'i'; and the older binary export format
// starts with a 'c', 'd', or 'v' (from "version"). Select
// appropriate importer.
switch exportFormat {
case 'u':
var data []byte
var r io.Reader = buf
if size >= 0 {
r = io.LimitReader(r, int64(size))
}
if data, err = io.ReadAll(r); err != nil {
return
}
s := string(data)
s = s[:strings.LastIndex(s, "\n$$\n")]
input := pkgbits.NewPkgDecoder(id, s)
pkg = readUnifiedPackage(fset, nil, packages, input)
case 'i':
pkg, err = iImportData(fset, packages, buf, id)
default:
err = fmt.Errorf("import %q: old binary export format no longer supported (recompile library)", path)
}
default:
err = fmt.Errorf("import %q: unknown export data header: %q", path, hdr)
}
return
}
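// A minimal sketch of calling Import without a custom lookup function
// (the import path is hypothetical):
//
//	fset := token.NewFileSet()
//	imports := make(map[string]*types.Package)
//	pkg, err := Import(fset, imports, "fmt", ".", nil)
//	// On success, pkg is complete and is also recorded in imports
//	// under its resolved import path.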
type byPath []*types.Package
func (a byPath) Len() int { return len(a) }
func (a byPath) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a byPath) Less(i, j int) bool { return a[i].Path() < a[j].Path() }
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Indexed package import.
// See cmd/compile/internal/gc/iexport.go for the export data format.
package gcimporter
import (
"bufio"
"bytes"
"encoding/binary"
"fmt"
"go/constant"
"go/token"
"go/types"
"internal/saferio"
"io"
"math"
"math/big"
"sort"
"strings"
)
type intReader struct {
*bufio.Reader
path string
}
func (r *intReader) int64() int64 {
i, err := binary.ReadVarint(r.Reader)
if err != nil {
errorf("import %q: read varint error: %v", r.path, err)
}
return i
}
func (r *intReader) uint64() uint64 {
i, err := binary.ReadUvarint(r.Reader)
if err != nil {
errorf("import %q: read varint error: %v", r.path, err)
}
return i
}
// Keep this in sync with constants in iexport.go.
const (
iexportVersionGo1_11 = 0
iexportVersionPosCol = 1
iexportVersionGenerics = 2
iexportVersionGo1_18 = 2
iexportVersionCurrent = 2
)
type ident struct {
pkg *types.Package
name string
}
const predeclReserved = 32
type itag uint64
const (
// Types
definedType itag = iota
pointerType
sliceType
arrayType
chanType
mapType
signatureType
structType
interfaceType
typeParamType
instanceType
unionType
)
// iImportData imports a package from the serialized package data
// and returns a reference to the package.
// If the export data version is not recognized or the format is otherwise
// compromised, an error is returned.
func iImportData(fset *token.FileSet, imports map[string]*types.Package, dataReader *bufio.Reader, path string) (pkg *types.Package, err error) {
const currentVersion = iexportVersionCurrent
version := int64(-1)
defer func() {
if e := recover(); e != nil {
if version > currentVersion {
err = fmt.Errorf("cannot import %q (%v), export data is newer version - update tool", path, e)
} else {
err = fmt.Errorf("cannot import %q (%v), possibly version skew - reinstall package", path, e)
}
}
}()
r := &intReader{dataReader, path}
version = int64(r.uint64())
switch version {
case iexportVersionGo1_18, iexportVersionPosCol, iexportVersionGo1_11:
default:
errorf("unknown iexport format version %d", version)
}
sLen := r.uint64()
dLen := r.uint64()
if sLen > math.MaxUint64-dLen {
errorf("lengths out of range (%d, %d)", sLen, dLen)
}
data, err := saferio.ReadData(r, sLen+dLen)
if err != nil {
errorf("cannot read %d bytes of stringData and declData: %s", sLen+dLen, err)
}
stringData := data[:sLen]
declData := data[sLen:]
p := iimporter{
exportVersion: version,
ipath: path,
version: int(version),
stringData: stringData,
stringCache: make(map[uint64]string),
pkgCache: make(map[uint64]*types.Package),
declData: declData,
pkgIndex: make(map[*types.Package]map[string]uint64),
typCache: make(map[uint64]types.Type),
// Separate map for typeparams, keyed by their package and unique
// name (name with subscript).
tparamIndex: make(map[ident]*types.TypeParam),
fake: fakeFileSet{
fset: fset,
files: make(map[string]*fileInfo),
},
}
defer p.fake.setLines() // set lines for files in fset
for i, pt := range predeclared {
p.typCache[uint64(i)] = pt
}
pkgList := make([]*types.Package, r.uint64())
for i := range pkgList {
pkgPathOff := r.uint64()
pkgPath := p.stringAt(pkgPathOff)
pkgName := p.stringAt(r.uint64())
_ = r.uint64() // package height; unused by go/types
if pkgPath == "" {
pkgPath = path
}
pkg := imports[pkgPath]
if pkg == nil {
pkg = types.NewPackage(pkgPath, pkgName)
imports[pkgPath] = pkg
} else if pkg.Name() != pkgName {
errorf("conflicting names %s and %s for package %q", pkg.Name(), pkgName, path)
}
p.pkgCache[pkgPathOff] = pkg
nameIndex := make(map[string]uint64)
for nSyms := r.uint64(); nSyms > 0; nSyms-- {
name := p.stringAt(r.uint64())
nameIndex[name] = r.uint64()
}
p.pkgIndex[pkg] = nameIndex
pkgList[i] = pkg
}
localpkg := pkgList[0]
names := make([]string, 0, len(p.pkgIndex[localpkg]))
for name := range p.pkgIndex[localpkg] {
names = append(names, name)
}
sort.Strings(names)
for _, name := range names {
p.doDecl(localpkg, name)
}
// SetConstraint can't be called if the constraint type is not yet complete.
// When type params are created in the 'P' case of (*importReader).obj(),
// the associated constraint type may not be complete due to recursion.
// Therefore, we defer calling SetConstraint there, and call it here instead
// after all types are complete.
for _, d := range p.later {
d.t.SetConstraint(d.constraint)
}
for _, typ := range p.interfaceList {
typ.Complete()
}
// record all referenced packages as imports
list := append(([]*types.Package)(nil), pkgList[1:]...)
sort.Sort(byPath(list))
localpkg.SetImports(list)
// package was imported completely and without errors
localpkg.MarkComplete()
return localpkg, nil
}
type setConstraintArgs struct {
t *types.TypeParam
constraint types.Type
}
type iimporter struct {
exportVersion int64
ipath string
version int
stringData []byte
stringCache map[uint64]string
pkgCache map[uint64]*types.Package
declData []byte
pkgIndex map[*types.Package]map[string]uint64
typCache map[uint64]types.Type
tparamIndex map[ident]*types.TypeParam
fake fakeFileSet
interfaceList []*types.Interface
// Arguments for calls to SetConstraint that are deferred due to recursive types
later []setConstraintArgs
}
func (p *iimporter) doDecl(pkg *types.Package, name string) {
// See if we've already imported this declaration.
if obj := pkg.Scope().Lookup(name); obj != nil {
return
}
off, ok := p.pkgIndex[pkg][name]
if !ok {
errorf("%v.%v not in index", pkg, name)
}
r := &importReader{p: p, currPkg: pkg}
r.declReader.Reset(p.declData[off:])
r.obj(name)
}
func (p *iimporter) stringAt(off uint64) string {
if s, ok := p.stringCache[off]; ok {
return s
}
slen, n := binary.Uvarint(p.stringData[off:])
if n <= 0 {
errorf("varint failed")
}
spos := off + uint64(n)
s := string(p.stringData[spos : spos+slen])
p.stringCache[off] = s
return s
}
func (p *iimporter) pkgAt(off uint64) *types.Package {
if pkg, ok := p.pkgCache[off]; ok {
return pkg
}
path := p.stringAt(off)
errorf("missing package %q in %q", path, p.ipath)
return nil
}
func (p *iimporter) typAt(off uint64, base *types.Named) types.Type {
if t, ok := p.typCache[off]; ok && canReuse(base, t) {
return t
}
if off < predeclReserved {
errorf("predeclared type missing from cache: %v", off)
}
r := &importReader{p: p}
r.declReader.Reset(p.declData[off-predeclReserved:])
t := r.doType(base)
if canReuse(base, t) {
p.typCache[off] = t
}
return t
}
// canReuse reports whether the type rhs on the RHS of the declaration for def
// may be re-used.
//
// Specifically, if def is non-nil and rhs is an interface type with methods, it
// may not be re-used because we have a convention of setting the receiver type
// for interface methods to def.
func canReuse(def *types.Named, rhs types.Type) bool {
if def == nil {
return true
}
iface, _ := rhs.(*types.Interface)
if iface == nil {
return true
}
// Don't use iface.Empty() here as iface may not be complete.
return iface.NumEmbeddeds() == 0 && iface.NumExplicitMethods() == 0
}
type importReader struct {
p *iimporter
declReader bytes.Reader
currPkg *types.Package
prevFile string
prevLine int64
prevColumn int64
}
func (r *importReader) obj(name string) {
tag := r.byte()
pos := r.pos()
switch tag {
case 'A':
typ := r.typ()
r.declare(types.NewTypeName(pos, r.currPkg, name, typ))
case 'C':
typ, val := r.value()
r.declare(types.NewConst(pos, r.currPkg, name, typ, val))
case 'F', 'G':
var tparams []*types.TypeParam
if tag == 'G' {
tparams = r.tparamList()
}
sig := r.signature(nil, nil, tparams)
r.declare(types.NewFunc(pos, r.currPkg, name, sig))
case 'T', 'U':
// Types can be recursive. We need to set up a stub
// declaration before recursing.
obj := types.NewTypeName(pos, r.currPkg, name, nil)
named := types.NewNamed(obj, nil, nil)
// Declare obj before calling r.tparamList, so the new type name is recognized
// if used in the constraint of one of its own typeparams (see #48280).
r.declare(obj)
if tag == 'U' {
tparams := r.tparamList()
named.SetTypeParams(tparams)
}
underlying := r.p.typAt(r.uint64(), named).Underlying()
named.SetUnderlying(underlying)
if !isInterface(underlying) {
for n := r.uint64(); n > 0; n-- {
mpos := r.pos()
mname := r.ident()
recv := r.param()
// If the receiver has any targs, set those as the
// rparams of the method (since those are the
// typeparams being used in the method sig/body).
targs := baseType(recv.Type()).TypeArgs()
var rparams []*types.TypeParam
if targs.Len() > 0 {
rparams = make([]*types.TypeParam, targs.Len())
for i := range rparams {
rparams[i], _ = targs.At(i).(*types.TypeParam)
}
}
msig := r.signature(recv, rparams, nil)
named.AddMethod(types.NewFunc(mpos, r.currPkg, mname, msig))
}
}
case 'P':
// We need to "declare" a typeparam in order to have a name that
// can be referenced recursively (if needed) in the type param's
// bound.
if r.p.exportVersion < iexportVersionGenerics {
errorf("unexpected type param type")
}
// Remove the "path" from the type param name that makes it unique,
// and revert any unique name used for blank typeparams.
name0 := tparamName(name)
tn := types.NewTypeName(pos, r.currPkg, name0, nil)
t := types.NewTypeParam(tn, nil)
// To handle recursive references to the typeparam within its
// bound, save the partial type in tparamIndex before reading the bounds.
id := ident{r.currPkg, name}
r.p.tparamIndex[id] = t
var implicit bool
if r.p.exportVersion >= iexportVersionGo1_18 {
implicit = r.bool()
}
constraint := r.typ()
if implicit {
iface, _ := constraint.(*types.Interface)
if iface == nil {
errorf("non-interface constraint marked implicit")
}
iface.MarkImplicit()
}
// The constraint type may not be complete, if we
// are in the middle of a type recursion involving type
// constraints. So, we defer SetConstraint until we have
// completely set up all types in ImportData.
r.p.later = append(r.p.later, setConstraintArgs{t: t, constraint: constraint})
case 'V':
typ := r.typ()
r.declare(types.NewVar(pos, r.currPkg, name, typ))
default:
errorf("unexpected tag: %v", tag)
}
}
func (r *importReader) declare(obj types.Object) {
obj.Pkg().Scope().Insert(obj)
}
func (r *importReader) value() (typ types.Type, val constant.Value) {
typ = r.typ()
if r.p.exportVersion >= iexportVersionGo1_18 {
// TODO: add support for using the kind
_ = constant.Kind(r.int64())
}
switch b := typ.Underlying().(*types.Basic); b.Info() & types.IsConstType {
case types.IsBoolean:
val = constant.MakeBool(r.bool())
case types.IsString:
val = constant.MakeString(r.string())
case types.IsInteger:
var x big.Int
r.mpint(&x, b)
val = constant.Make(&x)
case types.IsFloat:
val = r.mpfloat(b)
case types.IsComplex:
re := r.mpfloat(b)
im := r.mpfloat(b)
val = constant.BinaryOp(re, token.ADD, constant.MakeImag(im))
default:
errorf("unexpected type %v", typ) // panics
panic("unreachable")
}
return
}
func intSize(b *types.Basic) (signed bool, maxBytes uint) {
if (b.Info() & types.IsUntyped) != 0 {
return true, 64
}
switch b.Kind() {
case types.Float32, types.Complex64:
return true, 3
case types.Float64, types.Complex128:
return true, 7
}
signed = (b.Info() & types.IsUnsigned) == 0
switch b.Kind() {
case types.Int8, types.Uint8:
maxBytes = 1
case types.Int16, types.Uint16:
maxBytes = 2
case types.Int32, types.Uint32:
maxBytes = 4
default:
maxBytes = 8
}
return
}
func (r *importReader) mpint(x *big.Int, typ *types.Basic) {
signed, maxBytes := intSize(typ)
maxSmall := 256 - maxBytes
if signed {
maxSmall = 256 - 2*maxBytes
}
if maxBytes == 1 {
maxSmall = 256
}
n, _ := r.declReader.ReadByte()
if uint(n) < maxSmall {
v := int64(n)
if signed {
v >>= 1
if n&1 != 0 {
v = ^v
}
}
x.SetInt64(v)
return
}
v := -n
if signed {
v = -(n &^ 1) >> 1
}
if v < 1 || uint(v) > maxBytes {
errorf("weird decoding: %v, %v => %v", n, signed, v)
}
b := make([]byte, v)
io.ReadFull(&r.declReader, b)
x.SetBytes(b)
if signed && n&1 != 0 {
x.Neg(x)
}
}
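// A worked example of the integer encoding consumed by mpint above
// (a sketch derived from the code): for a signed 8-byte kind,
// maxSmall is 256-2*8 = 240. A first byte n < 240 is a zigzag-encoded
// value: n = 4 decodes to 2, n = 5 to -3, n = 6 to 3. A first byte
// n >= 240 announces a big value: v = -(n &^ 1) >> 1 magnitude bytes
// follow (byte arithmetic wraps around), negated when n&1 != 0, so
// n = 254 means one byte of positive magnitude follows and n = 255
// one byte of negative magnitude.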
func (r *importReader) mpfloat(typ *types.Basic) constant.Value {
var mant big.Int
r.mpint(&mant, typ)
var f big.Float
f.SetInt(&mant)
if f.Sign() != 0 {
f.SetMantExp(&f, int(r.int64()))
}
return constant.Make(&f)
}
func (r *importReader) ident() string {
return r.string()
}
func (r *importReader) qualifiedIdent() (*types.Package, string) {
name := r.string()
pkg := r.pkg()
return pkg, name
}
func (r *importReader) pos() token.Pos {
if r.p.version >= 1 {
r.posv1()
} else {
r.posv0()
}
if r.prevFile == "" && r.prevLine == 0 && r.prevColumn == 0 {
return token.NoPos
}
return r.p.fake.pos(r.prevFile, int(r.prevLine), int(r.prevColumn))
}
func (r *importReader) posv0() {
delta := r.int64()
if delta != deltaNewFile {
r.prevLine += delta
} else if l := r.int64(); l == -1 {
r.prevLine += deltaNewFile
} else {
r.prevFile = r.string()
r.prevLine = l
}
}
func (r *importReader) posv1() {
delta := r.int64()
r.prevColumn += delta >> 1
if delta&1 != 0 {
delta = r.int64()
r.prevLine += delta >> 1
if delta&1 != 0 {
r.prevFile = r.string()
}
}
}
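// A sketch of the v1 delta encoding consumed by posv1: the low bit of
// each varint flags "more follows" and the remaining bits carry a
// signed delta. For example, delta = 6 means column += 3 on the same
// line and file; delta = 5 means column += 2 and a line delta follows,
// whose own low bit in turn signals that a new file name follows.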
func (r *importReader) typ() types.Type {
return r.p.typAt(r.uint64(), nil)
}
func isInterface(t types.Type) bool {
_, ok := t.(*types.Interface)
return ok
}
func (r *importReader) pkg() *types.Package { return r.p.pkgAt(r.uint64()) }
func (r *importReader) string() string { return r.p.stringAt(r.uint64()) }
func (r *importReader) doType(base *types.Named) types.Type {
switch k := r.kind(); k {
default:
errorf("unexpected kind tag in %q: %v", r.p.ipath, k)
return nil
case definedType:
pkg, name := r.qualifiedIdent()
r.p.doDecl(pkg, name)
return pkg.Scope().Lookup(name).(*types.TypeName).Type()
case pointerType:
return types.NewPointer(r.typ())
case sliceType:
return types.NewSlice(r.typ())
case arrayType:
n := r.uint64()
return types.NewArray(r.typ(), int64(n))
case chanType:
dir := chanDir(int(r.uint64()))
return types.NewChan(dir, r.typ())
case mapType:
return types.NewMap(r.typ(), r.typ())
case signatureType:
r.currPkg = r.pkg()
return r.signature(nil, nil, nil)
case structType:
r.currPkg = r.pkg()
fields := make([]*types.Var, r.uint64())
tags := make([]string, len(fields))
for i := range fields {
fpos := r.pos()
fname := r.ident()
ftyp := r.typ()
emb := r.bool()
tag := r.string()
fields[i] = types.NewField(fpos, r.currPkg, fname, ftyp, emb)
tags[i] = tag
}
return types.NewStruct(fields, tags)
case interfaceType:
r.currPkg = r.pkg()
embeddeds := make([]types.Type, r.uint64())
for i := range embeddeds {
_ = r.pos()
embeddeds[i] = r.typ()
}
methods := make([]*types.Func, r.uint64())
for i := range methods {
mpos := r.pos()
mname := r.ident()
// TODO(mdempsky): Matches bimport.go, but I
// don't agree with this.
var recv *types.Var
if base != nil {
recv = types.NewVar(token.NoPos, r.currPkg, "", base)
}
msig := r.signature(recv, nil, nil)
methods[i] = types.NewFunc(mpos, r.currPkg, mname, msig)
}
typ := types.NewInterfaceType(methods, embeddeds)
r.p.interfaceList = append(r.p.interfaceList, typ)
return typ
case typeParamType:
if r.p.exportVersion < iexportVersionGenerics {
errorf("unexpected type param type")
}
pkg, name := r.qualifiedIdent()
id := ident{pkg, name}
if t, ok := r.p.tparamIndex[id]; ok {
// We're already in the process of importing this typeparam.
return t
}
// Otherwise, import the definition of the typeparam now.
r.p.doDecl(pkg, name)
return r.p.tparamIndex[id]
case instanceType:
if r.p.exportVersion < iexportVersionGenerics {
errorf("unexpected instantiation type")
}
// pos does not matter for instances: they are positioned on the original
// type.
_ = r.pos()
len := r.uint64()
targs := make([]types.Type, len)
for i := range targs {
targs[i] = r.typ()
}
baseType := r.typ()
// The imported instantiated type doesn't include any methods, so
// we must always use the methods of the base (orig) type.
// TODO provide a non-nil *Context
t, _ := types.Instantiate(nil, baseType, targs, false)
return t
case unionType:
if r.p.exportVersion < iexportVersionGenerics {
errorf("unexpected instantiation type")
}
terms := make([]*types.Term, r.uint64())
for i := range terms {
terms[i] = types.NewTerm(r.bool(), r.typ())
}
return types.NewUnion(terms)
}
}
func (r *importReader) kind() itag {
return itag(r.uint64())
}
func (r *importReader) signature(recv *types.Var, rparams, tparams []*types.TypeParam) *types.Signature {
params := r.paramList()
results := r.paramList()
variadic := params.Len() > 0 && r.bool()
return types.NewSignatureType(recv, rparams, tparams, params, results, variadic)
}
func (r *importReader) tparamList() []*types.TypeParam {
n := r.uint64()
if n == 0 {
return nil
}
xs := make([]*types.TypeParam, n)
for i := range xs {
xs[i], _ = r.typ().(*types.TypeParam)
}
return xs
}
func (r *importReader) paramList() *types.Tuple {
xs := make([]*types.Var, r.uint64())
for i := range xs {
xs[i] = r.param()
}
return types.NewTuple(xs...)
}
func (r *importReader) param() *types.Var {
pos := r.pos()
name := r.ident()
typ := r.typ()
return types.NewParam(pos, r.currPkg, name, typ)
}
func (r *importReader) bool() bool {
return r.uint64() != 0
}
func (r *importReader) int64() int64 {
n, err := binary.ReadVarint(&r.declReader)
if err != nil {
errorf("readVarint: %v", err)
}
return n
}
func (r *importReader) uint64() uint64 {
n, err := binary.ReadUvarint(&r.declReader)
if err != nil {
errorf("readUvarint: %v", err)
}
return n
}
func (r *importReader) byte() byte {
x, err := r.declReader.ReadByte()
if err != nil {
errorf("declReader.ReadByte: %v", err)
}
return x
}
func baseType(typ types.Type) *types.Named {
// pointer receivers are never types.Named types
if p, _ := typ.(*types.Pointer); p != nil {
typ = p.Elem()
}
// receiver base types are always (possibly generic) types.Named types
n, _ := typ.(*types.Named)
return n
}
const blankMarker = "$"
// tparamName returns the real name of a type parameter, after stripping its
// qualifying prefix and reverting blank-name encoding. See tparamExportName
// for details.
func tparamName(exportName string) string {
// Remove the "path" from the type param name that makes it unique.
ix := strings.LastIndex(exportName, ".")
if ix < 0 {
errorf("malformed type parameter export name %s: missing prefix", exportName)
}
name := exportName[ix+1:]
if strings.HasPrefix(name, blankMarker) {
return "_"
}
return name
}
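// For example (names illustrative): tparamName("pkg.Foo.T") returns
// "T", and a blank type parameter exported with the blankMarker
// prefix, such as "pkg.Foo.$1", is reverted to "_".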
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements support functionality for iimport.go.
package gcimporter
import (
"fmt"
"go/token"
"go/types"
"internal/pkgbits"
"sync"
)
func assert(b bool) {
if !b {
panic("assertion failed")
}
}
func errorf(format string, args ...any) {
panic(fmt.Sprintf(format, args...))
}
// deltaNewFile is a magic line delta offset indicating a new file.
// We use -64 because it is rare; see issue 20080 and CL 41619.
// -64 is the smallest int that fits in a single byte as a varint.
const deltaNewFile = -64
// Synthesize a token.Pos
type fakeFileSet struct {
fset *token.FileSet
files map[string]*fileInfo
}
type fileInfo struct {
file *token.File
lastline int
}
const maxlines = 64 * 1024
func (s *fakeFileSet) pos(file string, line, column int) token.Pos {
// TODO(mdempsky): Make use of column.
// Since we don't know the set of needed file positions, we reserve
// maxlines positions per file. We delay calling token.File.SetLines until
// all positions have been calculated (by way of fakeFileSet.setLines), so
// that we can avoid setting unnecessary lines. See also golang/go#46586.
f := s.files[file]
if f == nil {
f = &fileInfo{file: s.fset.AddFile(file, -1, maxlines)}
s.files[file] = f
}
if line > maxlines {
line = 1
}
if line > f.lastline {
f.lastline = line
}
// Return a fake position assuming that f.file consists only of newlines.
return token.Pos(f.file.Base() + line - 1)
}
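// For example, pos("f.go", 3, 5) yields token.Pos(f.file.Base()+2):
// the fake file is treated as consisting only of newlines, so line n
// maps to offset n-1 and the column is dropped.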
func (s *fakeFileSet) setLines() {
fakeLinesOnce.Do(func() {
fakeLines = make([]int, maxlines)
for i := range fakeLines {
fakeLines[i] = i
}
})
for _, f := range s.files {
f.file.SetLines(fakeLines[:f.lastline])
}
}
var (
fakeLines []int
fakeLinesOnce sync.Once
)
func chanDir(d int) types.ChanDir {
// tag values must match the constants in cmd/compile/internal/gc/go.go
switch d {
case 1 /* Crecv */ :
return types.RecvOnly
case 2 /* Csend */ :
return types.SendOnly
case 3 /* Cboth */ :
return types.SendRecv
default:
errorf("unexpected channel dir %d", d)
return 0
}
}
var predeclared = []types.Type{
// basic types
types.Typ[types.Bool],
types.Typ[types.Int],
types.Typ[types.Int8],
types.Typ[types.Int16],
types.Typ[types.Int32],
types.Typ[types.Int64],
types.Typ[types.Uint],
types.Typ[types.Uint8],
types.Typ[types.Uint16],
types.Typ[types.Uint32],
types.Typ[types.Uint64],
types.Typ[types.Uintptr],
types.Typ[types.Float32],
types.Typ[types.Float64],
types.Typ[types.Complex64],
types.Typ[types.Complex128],
types.Typ[types.String],
// basic type aliases
types.Universe.Lookup("byte").Type(),
types.Universe.Lookup("rune").Type(),
// error
types.Universe.Lookup("error").Type(),
// untyped types
types.Typ[types.UntypedBool],
types.Typ[types.UntypedInt],
types.Typ[types.UntypedRune],
types.Typ[types.UntypedFloat],
types.Typ[types.UntypedComplex],
types.Typ[types.UntypedString],
types.Typ[types.UntypedNil],
// package unsafe
types.Typ[types.UnsafePointer],
// invalid type
types.Typ[types.Invalid], // only appears in packages with errors
// used internally by gc; never used by this package or in .a files
// not to be confused with the universe any
anyType{},
// comparable
types.Universe.Lookup("comparable").Type(),
// any
types.Universe.Lookup("any").Type(),
}
type anyType struct{}
func (t anyType) Underlying() types.Type { return t }
func (t anyType) String() string { return "any" }
// See cmd/compile/internal/noder.derivedInfo.
type derivedInfo struct {
idx pkgbits.Index
needed bool
}
// See cmd/compile/internal/noder.typeInfo.
type typeInfo struct {
idx pkgbits.Index
derived bool
}
// See cmd/compile/internal/types.SplitVargenSuffix.
func splitVargenSuffix(name string) (base, suffix string) {
i := len(name)
for i > 0 && name[i-1] >= '0' && name[i-1] <= '9' {
i--
}
const dot = "·"
if i >= len(dot) && name[i-len(dot):i] == dot {
i -= len(dot)
return name[:i], name[i:]
}
return name, ""
}
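// For example (illustrative): splitVargenSuffix("T·3") returns
// ("T", "·3"), while splitVargenSuffix("T3") returns ("T3", "")
// because the trailing digits are not preceded by the "·" marker.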
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gcimporter
import (
"go/token"
"go/types"
"internal/pkgbits"
"sort"
)
// A pkgReader holds the shared state for reading a unified IR package
// description.
type pkgReader struct {
pkgbits.PkgDecoder
fake fakeFileSet
ctxt *types.Context
imports map[string]*types.Package // previously imported packages, indexed by path
// lazily initialized arrays corresponding to the unified IR
// PosBase, Pkg, and Type sections, respectively.
posBases []string // position bases (i.e., file names)
pkgs []*types.Package
typs []types.Type
// laterFns holds functions that need to be invoked at the end of
// import reading.
laterFns []func()
// ifaces holds a list of constructed Interfaces, which need to have
// Complete called after importing is done.
ifaces []*types.Interface
}
// later adds a function to be invoked at the end of import reading.
func (pr *pkgReader) later(fn func()) {
pr.laterFns = append(pr.laterFns, fn)
}
// readUnifiedPackage reads a package description from the given
// unified IR export data decoder.
func readUnifiedPackage(fset *token.FileSet, ctxt *types.Context, imports map[string]*types.Package, input pkgbits.PkgDecoder) *types.Package {
pr := pkgReader{
PkgDecoder: input,
fake: fakeFileSet{
fset: fset,
files: make(map[string]*fileInfo),
},
ctxt: ctxt,
imports: imports,
posBases: make([]string, input.NumElems(pkgbits.RelocPosBase)),
pkgs: make([]*types.Package, input.NumElems(pkgbits.RelocPkg)),
typs: make([]types.Type, input.NumElems(pkgbits.RelocType)),
}
defer pr.fake.setLines()
r := pr.newReader(pkgbits.RelocMeta, pkgbits.PublicRootIdx, pkgbits.SyncPublic)
pkg := r.pkg()
r.Bool() // TODO(mdempsky): Remove; was "has init"
for i, n := 0, r.Len(); i < n; i++ {
// As if r.obj(), but avoiding the Scope.Lookup call,
// to avoid eager loading of imports.
r.Sync(pkgbits.SyncObject)
assert(!r.Bool())
r.p.objIdx(r.Reloc(pkgbits.RelocObj))
assert(r.Len() == 0)
}
r.Sync(pkgbits.SyncEOF)
for _, fn := range pr.laterFns {
fn()
}
for _, iface := range pr.ifaces {
iface.Complete()
}
// Imports() of pkg are all of the transitive packages that were loaded.
var imps []*types.Package
for _, imp := range pr.pkgs {
if imp != nil && imp != pkg {
imps = append(imps, imp)
}
}
sort.Sort(byPath(imps))
pkg.SetImports(imps)
pkg.MarkComplete()
return pkg
}
// A reader holds the state for reading a single unified IR element
// within a package.
type reader struct {
pkgbits.Decoder
p *pkgReader
dict *readerDict
}
// A readerDict holds the state for type parameters that parameterize
// the current unified IR element.
type readerDict struct {
// bounds is a slice of typeInfos corresponding to the underlying
// bounds of the element's type parameters.
bounds []typeInfo
// tparams is a slice of the constructed TypeParams for the element.
tparams []*types.TypeParam
// derived is a slice of types derived from tparams, which may be
// instantiated while reading the current element.
derived []derivedInfo
derivedTypes []types.Type // lazily instantiated from derived
}
func (pr *pkgReader) newReader(k pkgbits.RelocKind, idx pkgbits.Index, marker pkgbits.SyncMarker) *reader {
return &reader{
Decoder: pr.NewDecoder(k, idx, marker),
p: pr,
}
}
func (pr *pkgReader) tempReader(k pkgbits.RelocKind, idx pkgbits.Index, marker pkgbits.SyncMarker) *reader {
return &reader{
Decoder: pr.TempDecoder(k, idx, marker),
p: pr,
}
}
func (pr *pkgReader) retireReader(r *reader) {
pr.RetireDecoder(&r.Decoder)
}
// @@@ Positions
func (r *reader) pos() token.Pos {
r.Sync(pkgbits.SyncPos)
if !r.Bool() {
return token.NoPos
}
// TODO(mdempsky): Delta encoding.
posBase := r.posBase()
line := r.Uint()
col := r.Uint()
return r.p.fake.pos(posBase, int(line), int(col))
}
func (r *reader) posBase() string {
return r.p.posBaseIdx(r.Reloc(pkgbits.RelocPosBase))
}
func (pr *pkgReader) posBaseIdx(idx pkgbits.Index) string {
if b := pr.posBases[idx]; b != "" {
return b
}
var filename string
{
r := pr.tempReader(pkgbits.RelocPosBase, idx, pkgbits.SyncPosBase)
// Within types2, position bases have a lot more details (e.g.,
// keeping track of where //line directives appeared exactly).
//
// For go/types, we just track the file name.
filename = r.String()
if r.Bool() { // file base
// Was: "b = token.NewTrimmedFileBase(filename, true)"
} else { // line base
pos := r.pos()
line := r.Uint()
col := r.Uint()
// Was: "b = token.NewLineBase(pos, filename, true, line, col)"
_, _, _ = pos, line, col
}
pr.retireReader(r)
}
b := filename
pr.posBases[idx] = b
return b
}
// @@@ Packages
func (r *reader) pkg() *types.Package {
r.Sync(pkgbits.SyncPkg)
return r.p.pkgIdx(r.Reloc(pkgbits.RelocPkg))
}
func (pr *pkgReader) pkgIdx(idx pkgbits.Index) *types.Package {
// TODO(mdempsky): Consider using some non-nil pointer to indicate
// the universe scope, so we don't need to keep re-reading it.
if pkg := pr.pkgs[idx]; pkg != nil {
return pkg
}
pkg := pr.newReader(pkgbits.RelocPkg, idx, pkgbits.SyncPkgDef).doPkg()
pr.pkgs[idx] = pkg
return pkg
}
func (r *reader) doPkg() *types.Package {
path := r.String()
switch path {
case "":
path = r.p.PkgPath()
case "builtin":
return nil // universe
case "unsafe":
return types.Unsafe
}
if pkg := r.p.imports[path]; pkg != nil {
return pkg
}
name := r.String()
pkg := types.NewPackage(path, name)
r.p.imports[path] = pkg
return pkg
}
// @@@ Types
func (r *reader) typ() types.Type {
return r.p.typIdx(r.typInfo(), r.dict)
}
func (r *reader) typInfo() typeInfo {
r.Sync(pkgbits.SyncType)
if r.Bool() {
return typeInfo{idx: pkgbits.Index(r.Len()), derived: true}
}
return typeInfo{idx: r.Reloc(pkgbits.RelocType), derived: false}
}
func (pr *pkgReader) typIdx(info typeInfo, dict *readerDict) types.Type {
idx := info.idx
var where *types.Type
if info.derived {
where = &dict.derivedTypes[idx]
idx = dict.derived[idx].idx
} else {
where = &pr.typs[idx]
}
if typ := *where; typ != nil {
return typ
}
var typ types.Type
{
r := pr.tempReader(pkgbits.RelocType, idx, pkgbits.SyncTypeIdx)
r.dict = dict
typ = r.doTyp()
assert(typ != nil)
pr.retireReader(r)
}
// See comment in pkgReader.typIdx explaining how this happens.
if prev := *where; prev != nil {
return prev
}
*where = typ
return typ
}
func (r *reader) doTyp() (res types.Type) {
switch tag := pkgbits.CodeType(r.Code(pkgbits.SyncType)); tag {
default:
errorf("unhandled type tag: %v", tag)
panic("unreachable")
case pkgbits.TypeBasic:
return types.Typ[r.Len()]
case pkgbits.TypeNamed:
obj, targs := r.obj()
name := obj.(*types.TypeName)
if len(targs) != 0 {
t, _ := types.Instantiate(r.p.ctxt, name.Type(), targs, false)
return t
}
return name.Type()
case pkgbits.TypeTypeParam:
return r.dict.tparams[r.Len()]
case pkgbits.TypeArray:
len := int64(r.Uint64())
return types.NewArray(r.typ(), len)
case pkgbits.TypeChan:
dir := types.ChanDir(r.Len())
return types.NewChan(dir, r.typ())
case pkgbits.TypeMap:
return types.NewMap(r.typ(), r.typ())
case pkgbits.TypePointer:
return types.NewPointer(r.typ())
case pkgbits.TypeSignature:
return r.signature(nil, nil, nil)
case pkgbits.TypeSlice:
return types.NewSlice(r.typ())
case pkgbits.TypeStruct:
return r.structType()
case pkgbits.TypeInterface:
return r.interfaceType()
case pkgbits.TypeUnion:
return r.unionType()
}
}
func (r *reader) structType() *types.Struct {
fields := make([]*types.Var, r.Len())
var tags []string
for i := range fields {
pos := r.pos()
pkg, name := r.selector()
ftyp := r.typ()
tag := r.String()
embedded := r.Bool()
fields[i] = types.NewField(pos, pkg, name, ftyp, embedded)
if tag != "" {
for len(tags) < i {
tags = append(tags, "")
}
tags = append(tags, tag)
}
}
return types.NewStruct(fields, tags)
}
func (r *reader) unionType() *types.Union {
terms := make([]*types.Term, r.Len())
for i := range terms {
terms[i] = types.NewTerm(r.Bool(), r.typ())
}
return types.NewUnion(terms)
}
func (r *reader) interfaceType() *types.Interface {
methods := make([]*types.Func, r.Len())
embeddeds := make([]types.Type, r.Len())
implicit := len(methods) == 0 && len(embeddeds) == 1 && r.Bool()
for i := range methods {
pos := r.pos()
pkg, name := r.selector()
mtyp := r.signature(nil, nil, nil)
methods[i] = types.NewFunc(pos, pkg, name, mtyp)
}
for i := range embeddeds {
embeddeds[i] = r.typ()
}
iface := types.NewInterfaceType(methods, embeddeds)
if implicit {
iface.MarkImplicit()
}
// We need to call iface.Complete(), but if there are any embedded
// defined types, then we may not have set their underlying
// interface type yet. So we need to defer calling Complete until
// after we've called SetUnderlying everywhere.
//
// TODO(mdempsky): After CL 424876 lands, it should be safe to call
// iface.Complete() immediately.
r.p.ifaces = append(r.p.ifaces, iface)
return iface
}
func (r *reader) signature(recv *types.Var, rtparams, tparams []*types.TypeParam) *types.Signature {
r.Sync(pkgbits.SyncSignature)
params := r.params()
results := r.params()
variadic := r.Bool()
return types.NewSignatureType(recv, rtparams, tparams, params, results, variadic)
}
func (r *reader) params() *types.Tuple {
r.Sync(pkgbits.SyncParams)
params := make([]*types.Var, r.Len())
for i := range params {
params[i] = r.param()
}
return types.NewTuple(params...)
}
func (r *reader) param() *types.Var {
r.Sync(pkgbits.SyncParam)
pos := r.pos()
pkg, name := r.localIdent()
typ := r.typ()
return types.NewParam(pos, pkg, name, typ)
}
// @@@ Objects
func (r *reader) obj() (types.Object, []types.Type) {
r.Sync(pkgbits.SyncObject)
assert(!r.Bool())
pkg, name := r.p.objIdx(r.Reloc(pkgbits.RelocObj))
obj := pkgScope(pkg).Lookup(name)
targs := make([]types.Type, r.Len())
for i := range targs {
targs[i] = r.typ()
}
return obj, targs
}
func (pr *pkgReader) objIdx(idx pkgbits.Index) (*types.Package, string) {
var objPkg *types.Package
var objName string
var tag pkgbits.CodeObj
{
rname := pr.tempReader(pkgbits.RelocName, idx, pkgbits.SyncObject1)
objPkg, objName = rname.qualifiedIdent()
assert(objName != "")
tag = pkgbits.CodeObj(rname.Code(pkgbits.SyncCodeObj))
pr.retireReader(rname)
}
if tag == pkgbits.ObjStub {
assert(objPkg == nil || objPkg == types.Unsafe)
return objPkg, objName
}
// Ignore local types promoted to global scope (#55110).
if _, suffix := splitVargenSuffix(objName); suffix != "" {
return objPkg, objName
}
if objPkg.Scope().Lookup(objName) == nil {
dict := pr.objDictIdx(idx)
r := pr.newReader(pkgbits.RelocObj, idx, pkgbits.SyncObject1)
r.dict = dict
declare := func(obj types.Object) {
objPkg.Scope().Insert(obj)
}
switch tag {
default:
panic("weird")
case pkgbits.ObjAlias:
pos := r.pos()
typ := r.typ()
declare(types.NewTypeName(pos, objPkg, objName, typ))
case pkgbits.ObjConst:
pos := r.pos()
typ := r.typ()
val := r.Value()
declare(types.NewConst(pos, objPkg, objName, typ, val))
case pkgbits.ObjFunc:
pos := r.pos()
tparams := r.typeParamNames()
sig := r.signature(nil, nil, tparams)
declare(types.NewFunc(pos, objPkg, objName, sig))
case pkgbits.ObjType:
pos := r.pos()
obj := types.NewTypeName(pos, objPkg, objName, nil)
named := types.NewNamed(obj, nil, nil)
declare(obj)
named.SetTypeParams(r.typeParamNames())
underlying := r.typ().Underlying()
// If the underlying type is an interface, we need to
// duplicate its methods so we can replace the receiver
// parameter's type (#49906).
if iface, ok := underlying.(*types.Interface); ok && iface.NumExplicitMethods() != 0 {
methods := make([]*types.Func, iface.NumExplicitMethods())
for i := range methods {
fn := iface.ExplicitMethod(i)
sig := fn.Type().(*types.Signature)
recv := types.NewVar(fn.Pos(), fn.Pkg(), "", named)
methods[i] = types.NewFunc(fn.Pos(), fn.Pkg(), fn.Name(), types.NewSignature(recv, sig.Params(), sig.Results(), sig.Variadic()))
}
embeds := make([]types.Type, iface.NumEmbeddeds())
for i := range embeds {
embeds[i] = iface.EmbeddedType(i)
}
newIface := types.NewInterfaceType(methods, embeds)
r.p.ifaces = append(r.p.ifaces, newIface)
underlying = newIface
}
named.SetUnderlying(underlying)
for i, n := 0, r.Len(); i < n; i++ {
named.AddMethod(r.method())
}
case pkgbits.ObjVar:
pos := r.pos()
typ := r.typ()
declare(types.NewVar(pos, objPkg, objName, typ))
}
}
return objPkg, objName
}
func (pr *pkgReader) objDictIdx(idx pkgbits.Index) *readerDict {
var dict readerDict
{
r := pr.tempReader(pkgbits.RelocObjDict, idx, pkgbits.SyncObject1)
if implicits := r.Len(); implicits != 0 {
errorf("unexpected object with %v implicit type parameter(s)", implicits)
}
dict.bounds = make([]typeInfo, r.Len())
for i := range dict.bounds {
dict.bounds[i] = r.typInfo()
}
dict.derived = make([]derivedInfo, r.Len())
dict.derivedTypes = make([]types.Type, len(dict.derived))
for i := range dict.derived {
dict.derived[i] = derivedInfo{r.Reloc(pkgbits.RelocType), r.Bool()}
}
pr.retireReader(r)
}
// function references follow, but reader doesn't need those
return &dict
}
func (r *reader) typeParamNames() []*types.TypeParam {
r.Sync(pkgbits.SyncTypeParamNames)
// Note: This code assumes it only processes objects without
// implicit type parameters. This is currently fine, because
// reader is only used to read in exported declarations, which are
// always package scoped.
if len(r.dict.bounds) == 0 {
return nil
}
// Careful: Type parameter lists may have cycles. To allow for this,
// we construct the type parameter list in two passes: first we
// create all the TypeNames and TypeParams, then we construct and
// set the bound type.
r.dict.tparams = make([]*types.TypeParam, len(r.dict.bounds))
for i := range r.dict.bounds {
pos := r.pos()
pkg, name := r.localIdent()
tname := types.NewTypeName(pos, pkg, name, nil)
r.dict.tparams[i] = types.NewTypeParam(tname, nil)
}
typs := make([]types.Type, len(r.dict.bounds))
for i, bound := range r.dict.bounds {
typs[i] = r.p.typIdx(bound, r.dict)
}
// TODO(mdempsky): This is subtle, elaborate further.
//
// We have to save tparams outside of the closure, because
// typeParamNames() can be called multiple times with the same
// dictionary instance.
//
// Also, this needs to happen later to make sure SetUnderlying has
// been called.
//
// TODO(mdempsky): Is it safe to have a single "later" slice or do
// we need to have multiple passes? See comments on CL 386002 and
// go.dev/issue/52104.
tparams := r.dict.tparams
r.p.later(func() {
for i, typ := range typs {
tparams[i].SetConstraint(typ)
}
})
return r.dict.tparams
}
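// The two-pass construction above is what permits cyclic type
// parameter lists such as (illustrative):
//
//	type G[A interface{ M() B }, B interface{ M() A }] struct{}
//
// where each bound mentions the other type parameter, so every
// TypeParam must exist before any bound is resolved.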
func (r *reader) method() *types.Func {
r.Sync(pkgbits.SyncMethod)
pos := r.pos()
pkg, name := r.selector()
rparams := r.typeParamNames()
sig := r.signature(r.param(), rparams, nil)
_ = r.pos() // TODO(mdempsky): Remove; this is a hack for linker.go.
return types.NewFunc(pos, pkg, name, sig)
}
func (r *reader) qualifiedIdent() (*types.Package, string) { return r.ident(pkgbits.SyncSym) }
func (r *reader) localIdent() (*types.Package, string) { return r.ident(pkgbits.SyncLocalIdent) }
func (r *reader) selector() (*types.Package, string) { return r.ident(pkgbits.SyncSelector) }
func (r *reader) ident(marker pkgbits.SyncMarker) (*types.Package, string) {
r.Sync(marker)
return r.pkg(), r.String()
}
// pkgScope returns pkg.Scope().
// If pkg is nil, it returns types.Universe instead.
//
// TODO(mdempsky): Remove after x/tools can depend on Go 1.19.
func pkgScope(pkg *types.Package) *types.Scope {
if pkg != nil {
return pkg.Scope()
}
return types.Universe
}
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package srcimporter implements importing directly
// from source files rather than installed packages.
package srcimporter // import "go/internal/srcimporter"
import (
"fmt"
"go/ast"
"go/build"
"go/parser"
"go/token"
"go/types"
"io"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
_ "unsafe" // for go:linkname
)
// An Importer provides the context for importing packages from source code.
type Importer struct {
ctxt *build.Context
fset *token.FileSet
sizes types.Sizes
packages map[string]*types.Package
}
// New returns a new Importer for the given context, file set, and map
// of packages. The context is used to resolve import paths to package paths
// and to identify the files belonging to the package. If the context provides
// non-nil file system functions, they are used instead of the regular package
// os functions. The file set is used to track position information of package
// files, and imported packages are added to the packages map.
func New(ctxt *build.Context, fset *token.FileSet, packages map[string]*types.Package) *Importer {
return &Importer{
ctxt: ctxt,
fset: fset,
sizes: types.SizesFor(ctxt.Compiler, ctxt.GOARCH), // uses go/types default if GOARCH not found
packages: packages,
}
}
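// A minimal sketch of constructing an Importer and importing a
// package (the import path is hypothetical):
//
//	imp := New(&build.Default, token.NewFileSet(), make(map[string]*types.Package))
//	pkg, err := imp.Import("fmt")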
// importing is a sentinel taking the place in Importer.packages
// of a package that is in the process of being imported.
var importing types.Package
// Import(path) is a shortcut for ImportFrom(path, ".", 0).
func (p *Importer) Import(path string) (*types.Package, error) {
return p.ImportFrom(path, ".", 0) // use "." rather than "" (see issue #24441)
}
// ImportFrom imports the package with the given import path resolved from the given srcDir,
// adds the new package to the set of packages maintained by the importer, and returns the
// package. Package path resolution and file system operations are controlled by the context
// maintained with the importer. The import mode must be zero but is otherwise ignored.
// Packages that are not composed entirely of pure Go files may fail to import because
// type checker may not be able to determine all exported entities (e.g. due to cgo dependencies).
func (p *Importer) ImportFrom(path, srcDir string, mode types.ImportMode) (*types.Package, error) {
if mode != 0 {
panic("non-zero import mode")
}
if abs, err := p.absPath(srcDir); err == nil { // see issue #14282
srcDir = abs
}
bp, err := p.ctxt.Import(path, srcDir, 0)
if err != nil {
return nil, err // err may be *build.NoGoError - return as is
}
// package unsafe is known to the type checker
if bp.ImportPath == "unsafe" {
return types.Unsafe, nil
}
// no need to re-import if the package was imported completely before
pkg := p.packages[bp.ImportPath]
if pkg != nil {
if pkg == &importing {
return nil, fmt.Errorf("import cycle through package %q", bp.ImportPath)
}
if !pkg.Complete() {
// Package exists but is not complete - we cannot handle this
// at the moment since the source importer replaces the package
// wholesale rather than augmenting it (see #19337 for details).
// Return incomplete package with error (see #16088).
return pkg, fmt.Errorf("reimported partially imported package %q", bp.ImportPath)
}
return pkg, nil
}
p.packages[bp.ImportPath] = &importing
defer func() {
// clean up in case of error
// TODO(gri) Eventually we may want to leave a (possibly empty)
// package in the map in all cases (and use that package to
// identify cycles). See also issue 16088.
if p.packages[bp.ImportPath] == &importing {
p.packages[bp.ImportPath] = nil
}
}()
var filenames []string
filenames = append(filenames, bp.GoFiles...)
filenames = append(filenames, bp.CgoFiles...)
files, err := p.parseFiles(bp.Dir, filenames)
if err != nil {
return nil, err
}
// type-check package files
var firstHardErr error
conf := types.Config{
IgnoreFuncBodies: true,
// continue type-checking after the first error
Error: func(err error) {
if firstHardErr == nil && !err.(types.Error).Soft {
firstHardErr = err
}
},
Importer: p,
Sizes: p.sizes,
}
if len(bp.CgoFiles) > 0 {
if p.ctxt.OpenFile != nil {
// cgo, gcc, pkg-config, etc. do not support
// build.Context's VFS.
conf.FakeImportC = true
} else {
setUsesCgo(&conf)
file, err := p.cgo(bp)
if err != nil {
return nil, fmt.Errorf("error processing cgo for package %q: %w", bp.ImportPath, err)
}
files = append(files, file)
}
}
pkg, err = conf.Check(bp.ImportPath, p.fset, files, nil)
if err != nil {
// If there was a hard error it is possibly unsafe
// to use the package as it may not be fully populated.
// Do not return it (see also #20837, #20855).
if firstHardErr != nil {
pkg = nil
err = firstHardErr // give preference to first hard error over any soft error
}
return pkg, fmt.Errorf("type-checking package %q failed (%v)", bp.ImportPath, err)
}
if firstHardErr != nil {
// this can only happen if we have a bug in go/types
panic("package is not safe yet no error was returned")
}
p.packages[bp.ImportPath] = pkg
return pkg, nil
}
func (p *Importer) parseFiles(dir string, filenames []string) ([]*ast.File, error) {
// use build.Context's OpenFile if there is one
open := p.ctxt.OpenFile
if open == nil {
open = func(name string) (io.ReadCloser, error) { return os.Open(name) }
}
files := make([]*ast.File, len(filenames))
errors := make([]error, len(filenames))
var wg sync.WaitGroup
wg.Add(len(filenames))
for i, filename := range filenames {
go func(i int, filepath string) {
defer wg.Done()
src, err := open(filepath)
if err != nil {
errors[i] = err // open provides operation and filename in error
return
}
files[i], errors[i] = parser.ParseFile(p.fset, filepath, src, parser.SkipObjectResolution)
src.Close() // ignore Close error - parsing may have succeeded which is all we need
}(i, p.joinPath(dir, filename))
}
wg.Wait()
// if there are errors, return the first one for deterministic results
for _, err := range errors {
if err != nil {
return nil, err
}
}
return files, nil
}
func (p *Importer) cgo(bp *build.Package) (*ast.File, error) {
tmpdir, err := os.MkdirTemp("", "srcimporter")
if err != nil {
return nil, err
}
defer os.RemoveAll(tmpdir)
goCmd := "go"
if p.ctxt.GOROOT != "" {
goCmd = filepath.Join(p.ctxt.GOROOT, "bin", "go")
}
args := []string{goCmd, "tool", "cgo", "-objdir", tmpdir}
if bp.Goroot {
switch bp.ImportPath {
case "runtime/cgo":
args = append(args, "-import_runtime_cgo=false", "-import_syscall=false")
case "runtime/race":
args = append(args, "-import_syscall=false")
}
}
args = append(args, "--")
args = append(args, strings.Fields(os.Getenv("CGO_CPPFLAGS"))...)
args = append(args, bp.CgoCPPFLAGS...)
if len(bp.CgoPkgConfig) > 0 {
cmd := exec.Command("pkg-config", append([]string{"--cflags"}, bp.CgoPkgConfig...)...)
out, err := cmd.Output()
if err != nil {
return nil, fmt.Errorf("pkg-config --cflags: %w", err)
}
args = append(args, strings.Fields(string(out))...)
}
args = append(args, "-I", tmpdir)
args = append(args, strings.Fields(os.Getenv("CGO_CFLAGS"))...)
args = append(args, bp.CgoCFLAGS...)
args = append(args, bp.CgoFiles...)
cmd := exec.Command(args[0], args[1:]...)
cmd.Dir = bp.Dir
if err := cmd.Run(); err != nil {
return nil, fmt.Errorf("go tool cgo: %w", err)
}
return parser.ParseFile(p.fset, filepath.Join(tmpdir, "_cgo_gotypes.go"), nil, parser.SkipObjectResolution)
}
// context-controlled file system operations
func (p *Importer) absPath(path string) (string, error) {
// TODO(gri) This should be using p.ctxt.AbsPath which doesn't
// exist but probably should. See also issue #14282.
return filepath.Abs(path)
}
func (p *Importer) isAbsPath(path string) bool {
if f := p.ctxt.IsAbsPath; f != nil {
return f(path)
}
return filepath.IsAbs(path)
}
func (p *Importer) joinPath(elem ...string) string {
if f := p.ctxt.JoinPath; f != nil {
return f(elem...)
}
return filepath.Join(elem...)
}
//go:linkname setUsesCgo go/types.srcimporter_setUsesCgo
func setUsesCgo(conf *types.Config)
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file contains the exported entry points for invoking the parser.
package parser
import (
"bytes"
"errors"
"go/ast"
"go/token"
"io"
"io/fs"
"os"
"path/filepath"
"strings"
)
// If src != nil, readSource converts src to a []byte if possible;
// otherwise it returns an error. If src == nil, readSource returns
// the result of reading the file specified by filename.
func readSource(filename string, src any) ([]byte, error) {
if src != nil {
switch s := src.(type) {
case string:
return []byte(s), nil
case []byte:
return s, nil
case *bytes.Buffer:
// is io.Reader, but src is already available in []byte form
if s != nil {
return s.Bytes(), nil
}
case io.Reader:
return io.ReadAll(s)
}
return nil, errors.New("invalid source")
}
return os.ReadFile(filename)
}
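// Illustrative calls, all accepted (the file name is hypothetical):
//
//	readSource("main.go", nil)                    // read main.go from disk
//	readSource("main.go", "package main")         // string source
//	readSource("main.go", []byte("package main")) // byte-slice source
//	readSource("main.go", strings.NewReader("package main")) // io.Reader source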
// A Mode value is a set of flags (or 0).
// They control the amount of source code parsed and other optional
// parser functionality.
type Mode uint
const (
PackageClauseOnly Mode = 1 << iota // stop parsing after package clause
ImportsOnly // stop parsing after import declarations
ParseComments // parse comments and add them to AST
Trace // print a trace of parsed productions
DeclarationErrors // report declaration errors
SpuriousErrors // same as AllErrors, for backward-compatibility
SkipObjectResolution // don't resolve identifiers to objects - see ParseFile
AllErrors = SpuriousErrors // report all errors (not just the first 10 on different lines)
)
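// Mode flags combine with bitwise OR; for example (illustrative):
//
//	mode := ParseComments | SkipObjectResolution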
// ParseFile parses the source code of a single Go source file and returns
// the corresponding ast.File node. The source code may be provided via
// the filename of the source file, or via the src parameter.
//
// If src != nil, ParseFile parses the source from src and the filename is
// only used when recording position information. The type of the argument
// for the src parameter must be string, []byte, or io.Reader.
// If src == nil, ParseFile parses the file specified by filename.
//
// The mode parameter controls the amount of source text parsed and other
// optional parser functionality. If the SkipObjectResolution mode bit is set,
// the object resolution phase of parsing will be skipped, causing File.Scope,
// File.Unresolved, and all Ident.Obj fields to be nil.
//
// Position information is recorded in the file set fset, which must not be
// nil.
//
// If the source couldn't be read, the returned AST is nil and the error
// indicates the specific failure. If the source was read but syntax
// errors were found, the result is a partial AST (with ast.Bad* nodes
// representing the fragments of erroneous source code). Multiple errors
// are returned via a scanner.ErrorList which is sorted by source position.
func ParseFile(fset *token.FileSet, filename string, src any, mode Mode) (f *ast.File, err error) {
if fset == nil {
panic("parser.ParseFile: no token.FileSet provided (fset == nil)")
}
// get source
text, err := readSource(filename, src)
if err != nil {
return nil, err
}
var p parser
defer func() {
if e := recover(); e != nil {
// resume same panic if it's not a bailout
bail, ok := e.(bailout)
if !ok {
panic(e)
} else if bail.msg != "" {
p.errors.Add(p.file.Position(bail.pos), bail.msg)
}
}
// set result values
if f == nil {
// source is not a valid Go source file - satisfy
// ParseFile API and return a valid (but empty)
// *ast.File
f = &ast.File{
Name: new(ast.Ident),
Scope: ast.NewScope(nil),
}
}
p.errors.Sort()
err = p.errors.Err()
}()
// parse source
p.init(fset, filename, text, mode)
f = p.parseFile()
return
}
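// A minimal sketch of parsing in-memory source (the file name is
// hypothetical and used only for position information):
//
//	fset := token.NewFileSet()
//	f, err := ParseFile(fset, "hello.go", "package main; func main() {}", ParseComments)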
// ParseDir calls ParseFile for all files with names ending in ".go" in the
// directory specified by path and returns a map of package name -> package
// AST with all the packages found.
//
// If filter != nil, only the files with fs.FileInfo entries passing through
// the filter (and ending in ".go") are considered. The mode bits are passed
// to ParseFile unchanged. Position information is recorded in fset, which
// must not be nil.
//
// If the directory couldn't be read, a nil map and the respective error are
// returned. If a parse error occurred, a non-nil but incomplete map and the
// first error encountered are returned.
func ParseDir(fset *token.FileSet, path string, filter func(fs.FileInfo) bool, mode Mode) (pkgs map[string]*ast.Package, first error) {
list, err := os.ReadDir(path)
if err != nil {
return nil, err
}
pkgs = make(map[string]*ast.Package)
for _, d := range list {
if d.IsDir() || !strings.HasSuffix(d.Name(), ".go") {
continue
}
if filter != nil {
info, err := d.Info()
if err != nil {
return nil, err
}
if !filter(info) {
continue
}
}
filename := filepath.Join(path, d.Name())
if src, err := ParseFile(fset, filename, nil, mode); err == nil {
name := src.Name.Name
pkg, found := pkgs[name]
if !found {
pkg = &ast.Package{
Name: name,
Files: make(map[string]*ast.File),
}
pkgs[name] = pkg
}
pkg.Files[filename] = src
} else if first == nil {
first = err
}
}
return
}
// ParseExprFrom is a convenience function for parsing an expression.
// The arguments have the same meaning as for ParseFile, but the source must
// be a valid Go (type or value) expression. Specifically, fset must not
// be nil.
//
// If the source couldn't be read, the returned AST is nil and the error
// indicates the specific failure. If the source was read but syntax
// errors were found, the result is a partial AST (with ast.Bad* nodes
// representing the fragments of erroneous source code). Multiple errors
// are returned via a scanner.ErrorList which is sorted by source position.
func ParseExprFrom(fset *token.FileSet, filename string, src any, mode Mode) (expr ast.Expr, err error) {
if fset == nil {
panic("parser.ParseExprFrom: no token.FileSet provided (fset == nil)")
}
// get source
text, err := readSource(filename, src)
if err != nil {
return nil, err
}
var p parser
defer func() {
if e := recover(); e != nil {
// resume same panic if it's not a bailout
bail, ok := e.(bailout)
if !ok {
panic(e)
} else if bail.msg != "" {
p.errors.Add(p.file.Position(bail.pos), bail.msg)
}
}
p.errors.Sort()
err = p.errors.Err()
}()
// parse expr
p.init(fset, filename, text, mode)
expr = p.parseRhs()
// If a semicolon was inserted, consume it;
// report an error if there are more tokens.
if p.tok == token.SEMICOLON && p.lit == "\n" {
p.next()
}
p.expect(token.EOF)
return
}
// ParseExpr is a convenience function for obtaining the AST of an expression x.
// The position information recorded in the AST is undefined. The filename used
// in error messages is the empty string.
//
// If syntax errors were found, the result is a partial AST (with ast.Bad* nodes
// representing the fragments of erroneous source code). Multiple errors are
// returned via a scanner.ErrorList which is sorted by source position.
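//
// A minimal usage sketch (ast.Print is from go/ast; passing a nil
// *token.FileSet makes it print raw position offsets):
//
//	x, err := ParseExpr("a + b*2")
//	if err == nil {
//		ast.Print(nil, x)
//	}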
func ParseExpr(x string) (ast.Expr, error) {
return ParseExprFrom(token.NewFileSet(), "", []byte(x), 0)
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package parser implements a parser for Go source files. Input may be
// provided in a variety of forms (see the various Parse* functions); the
// output is an abstract syntax tree (AST) representing the Go source. The
// parser is invoked through one of the Parse* functions.
//
// The parser accepts a larger language than is syntactically permitted by
// the Go spec, for simplicity, and for improved robustness in the presence
// of syntax errors. For instance, in method declarations, the receiver is
// treated like an ordinary parameter list and thus may contain multiple
// entries where the spec permits exactly one. Consequently, the corresponding
// field in the AST (ast.FuncDecl.Recv) is not restricted to one entry.
package parser
import (
"fmt"
"go/ast"
"go/internal/typeparams"
"go/scanner"
"go/token"
)
// The parser structure holds the parser's internal state.
type parser struct {
file *token.File
errors scanner.ErrorList
scanner scanner.Scanner
// Tracing/debugging
mode Mode // parsing mode
trace bool // == (mode&Trace != 0)
indent int // indentation used for tracing output
// Comments
comments []*ast.CommentGroup
leadComment *ast.CommentGroup // last lead comment
lineComment *ast.CommentGroup // last line comment
// Next token
pos token.Pos // token position
tok token.Token // one token look-ahead
lit string // token literal
// Error recovery
// (used to limit the number of calls to parser.advance
// w/o making scanning progress - avoids potential endless
// loops across multiple parser functions during error recovery)
syncPos token.Pos // last synchronization position
syncCnt int // number of parser.advance calls without progress
// Non-syntactic parser control
exprLev int // < 0: in control clause, >= 0: in expression
inRhs bool // if set, the parser is parsing a rhs expression
imports []*ast.ImportSpec // list of imports
// nestLev is used to track and limit the recursion depth
// during parsing.
nestLev int
}
func (p *parser) init(fset *token.FileSet, filename string, src []byte, mode Mode) {
p.file = fset.AddFile(filename, -1, len(src))
var m scanner.Mode
if mode&ParseComments != 0 {
m = scanner.ScanComments
}
eh := func(pos token.Position, msg string) { p.errors.Add(pos, msg) }
p.scanner.Init(p.file, src, eh, m)
p.mode = mode
p.trace = mode&Trace != 0 // for convenience (p.trace is used frequently)
p.next()
}
// ----------------------------------------------------------------------------
// Parsing support
func (p *parser) printTrace(a ...any) {
const dots = ". . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . "
const n = len(dots)
pos := p.file.Position(p.pos)
fmt.Printf("%5d:%3d: ", pos.Line, pos.Column)
i := 2 * p.indent
for i > n {
fmt.Print(dots)
i -= n
}
// i <= n
fmt.Print(dots[0:i])
fmt.Println(a...)
}
func trace(p *parser, msg string) *parser {
p.printTrace(msg, "(")
p.indent++
return p
}
// Usage pattern: defer un(trace(p, "..."))
func un(p *parser) {
p.indent--
p.printTrace(")")
}
// maxNestLev is the deepest we're willing to recurse during parsing
const maxNestLev int = 1e5
func incNestLev(p *parser) *parser {
p.nestLev++
if p.nestLev > maxNestLev {
p.error(p.pos, "exceeded max nesting depth")
panic(bailout{})
}
return p
}
// decNestLev is used to track nesting depth during parsing to prevent stack exhaustion.
// It is used along with incNestLev in a similar fashion to how un and trace are used.
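//
// Usage pattern, mirroring un and trace (see the recursion entry points
// below):
//
//	defer decNestLev(incNestLev(p))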
func decNestLev(p *parser) {
p.nestLev--
}
// Advance to the next token.
func (p *parser) next0() {
// Because of one-token look-ahead, print the previous token
// when tracing as it provides a more readable output. The
// very first token (!p.pos.IsValid()) is not initialized
// (it is token.ILLEGAL), so don't print it.
if p.trace && p.pos.IsValid() {
s := p.tok.String()
switch {
case p.tok.IsLiteral():
p.printTrace(s, p.lit)
case p.tok.IsOperator(), p.tok.IsKeyword():
p.printTrace("\"" + s + "\"")
default:
p.printTrace(s)
}
}
p.pos, p.tok, p.lit = p.scanner.Scan()
}
// Consume a comment and return it and the line on which it ends.
func (p *parser) consumeComment() (comment *ast.Comment, endline int) {
// /*-style comments may end on a different line than where they start.
// Scan the comment for '\n' chars and adjust endline accordingly.
endline = p.file.Line(p.pos)
if p.lit[1] == '*' {
// don't use range here - no need to decode Unicode code points
for i := 0; i < len(p.lit); i++ {
if p.lit[i] == '\n' {
endline++
}
}
}
comment = &ast.Comment{Slash: p.pos, Text: p.lit}
p.next0()
return
}
// Consume a group of adjacent comments, add it to the parser's
// comments list, and return it together with the line at which
// the last comment in the group ends. A non-comment token or n
// empty lines terminate a comment group.
func (p *parser) consumeCommentGroup(n int) (comments *ast.CommentGroup, endline int) {
var list []*ast.Comment
endline = p.file.Line(p.pos)
for p.tok == token.COMMENT && p.file.Line(p.pos) <= endline+n {
var comment *ast.Comment
comment, endline = p.consumeComment()
list = append(list, comment)
}
// add comment group to the comments list
comments = &ast.CommentGroup{List: list}
p.comments = append(p.comments, comments)
return
}
// Advance to the next non-comment token. In the process, collect
// any comment groups encountered, and remember the last lead and
// line comments.
//
// A lead comment is a comment group that starts and ends in a
// line without any other tokens and that is followed by a non-comment
// token on the line immediately after the comment group.
//
// A line comment is a comment group that follows a non-comment
// token on the same line, and that has no tokens after it on the line
// where it ends.
//
// Lead and line comments may be considered documentation that is
// stored in the AST.
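//
// For example (a sketch):
//
//	// Foo does nothing.  (lead comment: becomes the FuncDecl's Doc)
//	func Foo() {}
//
//	var x int // note     (line comment: becomes the ValueSpec's Comment)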
func (p *parser) next() {
p.leadComment = nil
p.lineComment = nil
prev := p.pos
p.next0()
if p.tok == token.COMMENT {
var comment *ast.CommentGroup
var endline int
if p.file.Line(p.pos) == p.file.Line(prev) {
// The comment is on same line as the previous token; it
// cannot be a lead comment but may be a line comment.
comment, endline = p.consumeCommentGroup(0)
if p.file.Line(p.pos) != endline || p.tok == token.SEMICOLON || p.tok == token.EOF {
// The next token is on a different line, thus
// the last comment group is a line comment.
p.lineComment = comment
}
}
// consume successor comments, if any
endline = -1
for p.tok == token.COMMENT {
comment, endline = p.consumeCommentGroup(1)
}
if endline+1 == p.file.Line(p.pos) {
// The next token is following on the line immediately after the
// comment group, thus the last comment group is a lead comment.
p.leadComment = comment
}
}
}
// A bailout panic is raised to indicate early termination. pos and msg are
// only populated when bailing out of object resolution.
type bailout struct {
pos token.Pos
msg string
}
func (p *parser) error(pos token.Pos, msg string) {
if p.trace {
defer un(trace(p, "error: "+msg))
}
epos := p.file.Position(pos)
// If AllErrors is not set, discard errors reported on the same line
// as the last recorded error and stop parsing if there are more than
// 10 errors.
if p.mode&AllErrors == 0 {
n := len(p.errors)
if n > 0 && p.errors[n-1].Pos.Line == epos.Line {
return // discard - likely a spurious error
}
if n > 10 {
panic(bailout{})
}
}
p.errors.Add(epos, msg)
}
func (p *parser) errorExpected(pos token.Pos, msg string) {
msg = "expected " + msg
if pos == p.pos {
// the error happened at the current position;
// make the error message more specific
switch {
case p.tok == token.SEMICOLON && p.lit == "\n":
msg += ", found newline"
case p.tok.IsLiteral():
// print 123 rather than 'INT', etc.
msg += ", found " + p.lit
default:
msg += ", found '" + p.tok.String() + "'"
}
}
p.error(pos, msg)
}
func (p *parser) expect(tok token.Token) token.Pos {
pos := p.pos
if p.tok != tok {
p.errorExpected(pos, "'"+tok.String()+"'")
}
p.next() // make progress
return pos
}
// expect2 is like expect, but it returns an invalid position
// if the expected token is not found.
func (p *parser) expect2(tok token.Token) (pos token.Pos) {
if p.tok == tok {
pos = p.pos
} else {
p.errorExpected(p.pos, "'"+tok.String()+"'")
}
p.next() // make progress
return
}
// expectClosing is like expect but provides a better error message
// for the common case of a missing comma before a newline.
func (p *parser) expectClosing(tok token.Token, context string) token.Pos {
if p.tok != tok && p.tok == token.SEMICOLON && p.lit == "\n" {
p.error(p.pos, "missing ',' before newline in "+context)
p.next()
}
return p.expect(tok)
}
// expectSemi consumes a semicolon and returns the applicable line comment.
func (p *parser) expectSemi() (comment *ast.CommentGroup) {
// semicolon is optional before a closing ')' or '}'
if p.tok != token.RPAREN && p.tok != token.RBRACE {
switch p.tok {
case token.COMMA:
// permit a ',' instead of a ';' but complain
p.errorExpected(p.pos, "';'")
fallthrough
case token.SEMICOLON:
if p.lit == ";" {
// explicit semicolon
p.next()
comment = p.lineComment // use following comments
} else {
// artificial semicolon
comment = p.lineComment // use preceding comments
p.next()
}
return comment
default:
p.errorExpected(p.pos, "';'")
p.advance(stmtStart)
}
}
return nil
}
func (p *parser) atComma(context string, follow token.Token) bool {
if p.tok == token.COMMA {
return true
}
if p.tok != follow {
msg := "missing ','"
if p.tok == token.SEMICOLON && p.lit == "\n" {
msg += " before newline"
}
p.error(p.pos, msg+" in "+context)
return true // "insert" comma and continue
}
return false
}
func assert(cond bool, msg string) {
if !cond {
panic("go/parser internal error: " + msg)
}
}
// advance consumes tokens until the current token p.tok
// is in the 'to' set, or token.EOF. For error recovery.
func (p *parser) advance(to map[token.Token]bool) {
for ; p.tok != token.EOF; p.next() {
if to[p.tok] {
// Return only if parser made some progress since last
// sync or if it has not reached 10 advance calls without
// progress. Otherwise consume at least one token to
// avoid an endless parser loop (it is possible that
// both parseOperand and parseStmt call advance and
// correctly do not advance, thus the need for the
// invocation limit p.syncCnt).
if p.pos == p.syncPos && p.syncCnt < 10 {
p.syncCnt++
return
}
if p.pos > p.syncPos {
p.syncPos = p.pos
p.syncCnt = 0
return
}
// Reaching here indicates a parser bug, likely an
// incorrect token list in this function, but it only
// leads to skipping of possibly correct code if a
// previous error is present, and thus is preferred
// over a non-terminating parse.
}
}
}
var stmtStart = map[token.Token]bool{
token.BREAK: true,
token.CONST: true,
token.CONTINUE: true,
token.DEFER: true,
token.FALLTHROUGH: true,
token.FOR: true,
token.GO: true,
token.GOTO: true,
token.IF: true,
token.RETURN: true,
token.SELECT: true,
token.SWITCH: true,
token.TYPE: true,
token.VAR: true,
}
var declStart = map[token.Token]bool{
token.IMPORT: true,
token.CONST: true,
token.TYPE: true,
token.VAR: true,
}
var exprEnd = map[token.Token]bool{
token.COMMA: true,
token.COLON: true,
token.SEMICOLON: true,
token.RPAREN: true,
token.RBRACK: true,
token.RBRACE: true,
}
// safePos returns a valid file position for a given position: If pos
// is valid to begin with, safePos returns pos. If pos is out-of-range,
// safePos returns the EOF position.
//
// This is a hack to work around "artificial" end positions in the AST which
// are computed by adding 1 to (presumably valid) token positions. If the
// token positions are invalid due to parse errors, the resulting end position
// may be past the file's EOF position, which would lead to panics if used
// later on.
func (p *parser) safePos(pos token.Pos) (res token.Pos) {
defer func() {
if recover() != nil {
res = token.Pos(p.file.Base() + p.file.Size()) // EOF position
}
}()
_ = p.file.Offset(pos) // trigger a panic if position is out-of-range
return pos
}
// ----------------------------------------------------------------------------
// Identifiers
func (p *parser) parseIdent() *ast.Ident {
pos := p.pos
name := "_"
if p.tok == token.IDENT {
name = p.lit
p.next()
} else {
p.expect(token.IDENT) // use expect() error handling
}
return &ast.Ident{NamePos: pos, Name: name}
}
func (p *parser) parseIdentList() (list []*ast.Ident) {
if p.trace {
defer un(trace(p, "IdentList"))
}
list = append(list, p.parseIdent())
for p.tok == token.COMMA {
p.next()
list = append(list, p.parseIdent())
}
return
}
// ----------------------------------------------------------------------------
// Common productions
// Result list elements that are identifiers are not resolved.
func (p *parser) parseExprList() (list []ast.Expr) {
if p.trace {
defer un(trace(p, "ExpressionList"))
}
list = append(list, p.parseExpr())
for p.tok == token.COMMA {
p.next()
list = append(list, p.parseExpr())
}
return
}
func (p *parser) parseList(inRhs bool) []ast.Expr {
old := p.inRhs
p.inRhs = inRhs
list := p.parseExprList()
p.inRhs = old
return list
}
// ----------------------------------------------------------------------------
// Types
func (p *parser) parseType() ast.Expr {
if p.trace {
defer un(trace(p, "Type"))
}
typ := p.tryIdentOrType()
if typ == nil {
pos := p.pos
p.errorExpected(pos, "type")
p.advance(exprEnd)
return &ast.BadExpr{From: pos, To: p.pos}
}
return typ
}
func (p *parser) parseQualifiedIdent(ident *ast.Ident) ast.Expr {
if p.trace {
defer un(trace(p, "QualifiedIdent"))
}
typ := p.parseTypeName(ident)
if p.tok == token.LBRACK {
typ = p.parseTypeInstance(typ)
}
return typ
}
// If the result is an identifier, it is not resolved.
func (p *parser) parseTypeName(ident *ast.Ident) ast.Expr {
if p.trace {
defer un(trace(p, "TypeName"))
}
if ident == nil {
ident = p.parseIdent()
}
if p.tok == token.PERIOD {
// ident is a package name
p.next()
sel := p.parseIdent()
return &ast.SelectorExpr{X: ident, Sel: sel}
}
return ident
}
// "[" has already been consumed, and lbrack is its position.
// If len != nil it is the already consumed array length.
func (p *parser) parseArrayType(lbrack token.Pos, len ast.Expr) *ast.ArrayType {
if p.trace {
defer un(trace(p, "ArrayType"))
}
if len == nil {
p.exprLev++
// always permit ellipsis for more fault-tolerant parsing
if p.tok == token.ELLIPSIS {
len = &ast.Ellipsis{Ellipsis: p.pos}
p.next()
} else if p.tok != token.RBRACK {
len = p.parseRhs()
}
p.exprLev--
}
if p.tok == token.COMMA {
// Trailing commas are accepted in type parameter
// lists but not in array type declarations.
// Accept for better error handling but complain.
p.error(p.pos, "unexpected comma; expecting ]")
p.next()
}
p.expect(token.RBRACK)
elt := p.parseType()
return &ast.ArrayType{Lbrack: lbrack, Len: len, Elt: elt}
}
func (p *parser) parseArrayFieldOrTypeInstance(x *ast.Ident) (*ast.Ident, ast.Expr) {
if p.trace {
defer un(trace(p, "ArrayFieldOrTypeInstance"))
}
lbrack := p.expect(token.LBRACK)
trailingComma := token.NoPos // if valid, the position of a trailing comma preceding the ']'
var args []ast.Expr
if p.tok != token.RBRACK {
p.exprLev++
args = append(args, p.parseRhs())
for p.tok == token.COMMA {
comma := p.pos
p.next()
if p.tok == token.RBRACK {
trailingComma = comma
break
}
args = append(args, p.parseRhs())
}
p.exprLev--
}
rbrack := p.expect(token.RBRACK)
if len(args) == 0 {
// x []E
elt := p.parseType()
return x, &ast.ArrayType{Lbrack: lbrack, Elt: elt}
}
// x [P]E or x[P]
if len(args) == 1 {
elt := p.tryIdentOrType()
if elt != nil {
// x [P]E
if trailingComma.IsValid() {
// Trailing commas are invalid in array type fields.
p.error(trailingComma, "unexpected comma; expecting ]")
}
return x, &ast.ArrayType{Lbrack: lbrack, Len: args[0], Elt: elt}
}
}
// x[P], x[P1, P2], ...
return nil, typeparams.PackIndexExpr(x, lbrack, args, rbrack)
}
func (p *parser) parseFieldDecl() *ast.Field {
if p.trace {
defer un(trace(p, "FieldDecl"))
}
doc := p.leadComment
var names []*ast.Ident
var typ ast.Expr
switch p.tok {
case token.IDENT:
name := p.parseIdent()
if p.tok == token.PERIOD || p.tok == token.STRING || p.tok == token.SEMICOLON || p.tok == token.RBRACE {
// embedded type
typ = name
if p.tok == token.PERIOD {
typ = p.parseQualifiedIdent(name)
}
} else {
// name1, name2, ... T
names = []*ast.Ident{name}
for p.tok == token.COMMA {
p.next()
names = append(names, p.parseIdent())
}
// Careful dance: We don't know if we have an embedded instantiated
// type T[P1, P2, ...] or a field T of array type []E or [P]E.
if len(names) == 1 && p.tok == token.LBRACK {
name, typ = p.parseArrayFieldOrTypeInstance(name)
if name == nil {
names = nil
}
} else {
// T P
typ = p.parseType()
}
}
case token.MUL:
star := p.pos
p.next()
if p.tok == token.LPAREN {
// *(T)
p.error(p.pos, "cannot parenthesize embedded type")
p.next()
typ = p.parseQualifiedIdent(nil)
// expect closing ')' but no need to complain if missing
if p.tok == token.RPAREN {
p.next()
}
} else {
// *T
typ = p.parseQualifiedIdent(nil)
}
typ = &ast.StarExpr{Star: star, X: typ}
case token.LPAREN:
p.error(p.pos, "cannot parenthesize embedded type")
p.next()
if p.tok == token.MUL {
// (*T)
star := p.pos
p.next()
typ = &ast.StarExpr{Star: star, X: p.parseQualifiedIdent(nil)}
} else {
// (T)
typ = p.parseQualifiedIdent(nil)
}
// expect closing ')' but no need to complain if missing
if p.tok == token.RPAREN {
p.next()
}
default:
pos := p.pos
p.errorExpected(pos, "field name or embedded type")
p.advance(exprEnd)
typ = &ast.BadExpr{From: pos, To: p.pos}
}
var tag *ast.BasicLit
if p.tok == token.STRING {
tag = &ast.BasicLit{ValuePos: p.pos, Kind: p.tok, Value: p.lit}
p.next()
}
comment := p.expectSemi()
field := &ast.Field{Doc: doc, Names: names, Type: typ, Tag: tag, Comment: comment}
return field
}
func (p *parser) parseStructType() *ast.StructType {
if p.trace {
defer un(trace(p, "StructType"))
}
pos := p.expect(token.STRUCT)
lbrace := p.expect(token.LBRACE)
var list []*ast.Field
for p.tok == token.IDENT || p.tok == token.MUL || p.tok == token.LPAREN {
// a field declaration cannot start with a '(' but we accept
// it here for more robust parsing and better error messages
// (parseFieldDecl will check and complain if necessary)
list = append(list, p.parseFieldDecl())
}
rbrace := p.expect(token.RBRACE)
return &ast.StructType{
Struct: pos,
Fields: &ast.FieldList{
Opening: lbrace,
List: list,
Closing: rbrace,
},
}
}
func (p *parser) parsePointerType() *ast.StarExpr {
if p.trace {
defer un(trace(p, "PointerType"))
}
star := p.expect(token.MUL)
base := p.parseType()
return &ast.StarExpr{Star: star, X: base}
}
func (p *parser) parseDotsType() *ast.Ellipsis {
if p.trace {
defer un(trace(p, "DotsType"))
}
pos := p.expect(token.ELLIPSIS)
elt := p.parseType()
return &ast.Ellipsis{Ellipsis: pos, Elt: elt}
}
type field struct {
name *ast.Ident
typ ast.Expr
}
func (p *parser) parseParamDecl(name *ast.Ident, typeSetsOK bool) (f field) {
// TODO(rFindley) refactor to be more similar to paramDeclOrNil in the syntax
// package
if p.trace {
defer un(trace(p, "ParamDeclOrNil"))
}
ptok := p.tok
if name != nil {
p.tok = token.IDENT // force token.IDENT case in switch below
} else if typeSetsOK && p.tok == token.TILDE {
// "~" ...
return field{nil, p.embeddedElem(nil)}
}
switch p.tok {
case token.IDENT:
// name
if name != nil {
f.name = name
p.tok = ptok
} else {
f.name = p.parseIdent()
}
switch p.tok {
case token.IDENT, token.MUL, token.ARROW, token.FUNC, token.CHAN, token.MAP, token.STRUCT, token.INTERFACE, token.LPAREN:
// name type
f.typ = p.parseType()
case token.LBRACK:
// name "[" type1, ..., typeN "]" or name "[" n "]" type
f.name, f.typ = p.parseArrayFieldOrTypeInstance(f.name)
case token.ELLIPSIS:
// name "..." type
f.typ = p.parseDotsType()
return // don't allow ...type "|" ...
case token.PERIOD:
// name "." ...
f.typ = p.parseQualifiedIdent(f.name)
f.name = nil
case token.TILDE:
if typeSetsOK {
f.typ = p.embeddedElem(nil)
return
}
case token.OR:
if typeSetsOK {
// name "|" typeset
f.typ = p.embeddedElem(f.name)
f.name = nil
return
}
}
case token.MUL, token.ARROW, token.FUNC, token.LBRACK, token.CHAN, token.MAP, token.STRUCT, token.INTERFACE, token.LPAREN:
// type
f.typ = p.parseType()
case token.ELLIPSIS:
// "..." type
// (always accepted)
f.typ = p.parseDotsType()
return // don't allow ...type "|" ...
default:
// TODO(rfindley): this is incorrect in the case of type parameter lists
// (should be "']'" in that case)
p.errorExpected(p.pos, "')'")
p.advance(exprEnd)
}
// [name] type "|"
if typeSetsOK && p.tok == token.OR && f.typ != nil {
f.typ = p.embeddedElem(f.typ)
}
return
}
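// parseParameterList parses a comma-separated list of (type) parameter
// declarations. As a sketch of the grouping performed below (identifiers
// shown by name), the list "(a, b int, c string)" yields two fields,
// because a and b share one type expression:
//
//	&ast.Field{Names: []*ast.Ident{a, b}, Type: int}
//	&ast.Field{Names: []*ast.Ident{c}, Type: string}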
func (p *parser) parseParameterList(name0 *ast.Ident, typ0 ast.Expr, closing token.Token) (params []*ast.Field) {
if p.trace {
defer un(trace(p, "ParameterList"))
}
// Type parameters are the only parameter list closed by ']'.
tparams := closing == token.RBRACK
// Type set notation is ok in type parameter lists.
typeSetsOK := tparams
pos := p.pos
if name0 != nil {
pos = name0.Pos()
}
var list []field
var named int // number of parameters that have an explicit name and type
for name0 != nil || p.tok != closing && p.tok != token.EOF {
var par field
if typ0 != nil {
if typeSetsOK {
typ0 = p.embeddedElem(typ0)
}
par = field{name0, typ0}
} else {
par = p.parseParamDecl(name0, typeSetsOK)
}
name0 = nil // 1st name was consumed if present
typ0 = nil // 1st typ was consumed if present
if par.name != nil || par.typ != nil {
list = append(list, par)
if par.name != nil && par.typ != nil {
named++
}
}
if !p.atComma("parameter list", closing) {
break
}
p.next()
}
if len(list) == 0 {
return // not uncommon
}
// TODO(gri) parameter distribution and conversion to []*ast.Field
// can be combined and made more efficient
// distribute parameter types
if named == 0 {
// all unnamed => found names are type names
for i := 0; i < len(list); i++ {
par := &list[i]
if typ := par.name; typ != nil {
par.typ = typ
par.name = nil
}
}
if tparams {
p.error(pos, "type parameters must be named")
}
} else if named != len(list) {
// some named => all must be named
ok := true
var typ ast.Expr
missingName := pos
for i := len(list) - 1; i >= 0; i-- {
if par := &list[i]; par.typ != nil {
typ = par.typ
if par.name == nil {
ok = false
missingName = par.typ.Pos()
n := ast.NewIdent("_")
n.NamePos = typ.Pos() // correct position
par.name = n
}
} else if typ != nil {
par.typ = typ
} else {
// par.typ == nil && typ == nil => we only have a par.name
ok = false
missingName = par.name.Pos()
par.typ = &ast.BadExpr{From: par.name.Pos(), To: p.pos}
}
}
if !ok {
if tparams {
p.error(missingName, "type parameters must be named")
} else {
p.error(pos, "mixed named and unnamed parameters")
}
}
}
// convert list []*ast.Field
if named == 0 {
// parameter list consists of types only
for _, par := range list {
assert(par.typ != nil, "nil type in unnamed parameter list")
params = append(params, &ast.Field{Type: par.typ})
}
return
}
// parameter list consists of named parameters with types
var names []*ast.Ident
var typ ast.Expr
addParams := func() {
assert(typ != nil, "nil type in named parameter list")
field := &ast.Field{Names: names, Type: typ}
params = append(params, field)
names = nil
}
for _, par := range list {
if par.typ != typ {
if len(names) > 0 {
addParams()
}
typ = par.typ
}
names = append(names, par.name)
}
if len(names) > 0 {
addParams()
}
return
}
func (p *parser) parseParameters(acceptTParams bool) (tparams, params *ast.FieldList) {
if p.trace {
defer un(trace(p, "Parameters"))
}
if acceptTParams && p.tok == token.LBRACK {
opening := p.pos
p.next()
// [T any](params) syntax
list := p.parseParameterList(nil, nil, token.RBRACK)
rbrack := p.expect(token.RBRACK)
tparams = &ast.FieldList{Opening: opening, List: list, Closing: rbrack}
// Type parameter lists must not be empty.
if tparams.NumFields() == 0 {
p.error(tparams.Closing, "empty type parameter list")
tparams = nil // avoid follow-on errors
}
}
opening := p.expect(token.LPAREN)
var fields []*ast.Field
if p.tok != token.RPAREN {
fields = p.parseParameterList(nil, nil, token.RPAREN)
}
rparen := p.expect(token.RPAREN)
params = &ast.FieldList{Opening: opening, List: fields, Closing: rparen}
return
}
func (p *parser) parseResult() *ast.FieldList {
if p.trace {
defer un(trace(p, "Result"))
}
if p.tok == token.LPAREN {
_, results := p.parseParameters(false)
return results
}
typ := p.tryIdentOrType()
if typ != nil {
list := make([]*ast.Field, 1)
list[0] = &ast.Field{Type: typ}
return &ast.FieldList{List: list}
}
return nil
}
func (p *parser) parseFuncType() *ast.FuncType {
if p.trace {
defer un(trace(p, "FuncType"))
}
pos := p.expect(token.FUNC)
tparams, params := p.parseParameters(true)
if tparams != nil {
p.error(tparams.Pos(), "function type must have no type parameters")
}
results := p.parseResult()
return &ast.FuncType{Func: pos, Params: params, Results: results}
}
func (p *parser) parseMethodSpec() *ast.Field {
if p.trace {
defer un(trace(p, "MethodSpec"))
}
doc := p.leadComment
var idents []*ast.Ident
var typ ast.Expr
x := p.parseTypeName(nil)
if ident, _ := x.(*ast.Ident); ident != nil {
switch {
case p.tok == token.LBRACK:
// generic method or embedded instantiated type
lbrack := p.pos
p.next()
p.exprLev++
x := p.parseExpr()
p.exprLev--
if name0, _ := x.(*ast.Ident); name0 != nil && p.tok != token.COMMA && p.tok != token.RBRACK {
// generic method m[T any]
//
// Interface methods do not have type parameters. We parse them for a
// better error message and improved error recovery.
_ = p.parseParameterList(name0, nil, token.RBRACK)
_ = p.expect(token.RBRACK)
p.error(lbrack, "interface method must have no type parameters")
// TODO(rfindley) refactor to share code with parseFuncType.
_, params := p.parseParameters(false)
results := p.parseResult()
idents = []*ast.Ident{ident}
typ = &ast.FuncType{
Func: token.NoPos,
Params: params,
Results: results,
}
} else {
// embedded instantiated type
// TODO(rfindley) should resolve all identifiers in x.
list := []ast.Expr{x}
if p.atComma("type argument list", token.RBRACK) {
p.exprLev++
p.next()
for p.tok != token.RBRACK && p.tok != token.EOF {
list = append(list, p.parseType())
if !p.atComma("type argument list", token.RBRACK) {
break
}
p.next()
}
p.exprLev--
}
rbrack := p.expectClosing(token.RBRACK, "type argument list")
typ = typeparams.PackIndexExpr(ident, lbrack, list, rbrack)
}
case p.tok == token.LPAREN:
// ordinary method
// TODO(rfindley) refactor to share code with parseFuncType.
_, params := p.parseParameters(false)
results := p.parseResult()
idents = []*ast.Ident{ident}
typ = &ast.FuncType{Func: token.NoPos, Params: params, Results: results}
default:
// embedded type
typ = x
}
} else {
// embedded, possibly instantiated type
typ = x
if p.tok == token.LBRACK {
// embedded instantiated interface
typ = p.parseTypeInstance(typ)
}
}
// Comment is added at the callsite: the field below may be joined with
// additional type specs using '|'.
// TODO(rfindley) this should be refactored.
// TODO(rfindley) add more tests for comment handling.
return &ast.Field{Doc: doc, Names: idents, Type: typ}
}
func (p *parser) embeddedElem(x ast.Expr) ast.Expr {
if p.trace {
defer un(trace(p, "EmbeddedElem"))
}
if x == nil {
x = p.embeddedTerm()
}
for p.tok == token.OR {
t := new(ast.BinaryExpr)
t.OpPos = p.pos
t.Op = token.OR
p.next()
t.X = x
t.Y = p.embeddedTerm()
x = t
}
return x
}
func (p *parser) embeddedTerm() ast.Expr {
if p.trace {
defer un(trace(p, "EmbeddedTerm"))
}
if p.tok == token.TILDE {
t := new(ast.UnaryExpr)
t.OpPos = p.pos
t.Op = token.TILDE
p.next()
t.X = p.parseType()
return t
}
t := p.tryIdentOrType()
if t == nil {
pos := p.pos
p.errorExpected(pos, "~ term or type")
p.advance(exprEnd)
return &ast.BadExpr{From: pos, To: p.pos}
}
return t
}
func (p *parser) parseInterfaceType() *ast.InterfaceType {
if p.trace {
defer un(trace(p, "InterfaceType"))
}
pos := p.expect(token.INTERFACE)
lbrace := p.expect(token.LBRACE)
var list []*ast.Field
parseElements:
for {
switch {
case p.tok == token.IDENT:
f := p.parseMethodSpec()
if f.Names == nil {
f.Type = p.embeddedElem(f.Type)
}
f.Comment = p.expectSemi()
list = append(list, f)
case p.tok == token.TILDE:
typ := p.embeddedElem(nil)
comment := p.expectSemi()
list = append(list, &ast.Field{Type: typ, Comment: comment})
default:
if t := p.tryIdentOrType(); t != nil {
typ := p.embeddedElem(t)
comment := p.expectSemi()
list = append(list, &ast.Field{Type: typ, Comment: comment})
} else {
break parseElements
}
}
}
// TODO(rfindley): the error produced here could be improved, since we could
// accept an identifier, 'type', or a '}' at this point.
rbrace := p.expect(token.RBRACE)
return &ast.InterfaceType{
Interface: pos,
Methods: &ast.FieldList{
Opening: lbrace,
List: list,
Closing: rbrace,
},
}
}
func (p *parser) parseMapType() *ast.MapType {
if p.trace {
defer un(trace(p, "MapType"))
}
pos := p.expect(token.MAP)
p.expect(token.LBRACK)
key := p.parseType()
p.expect(token.RBRACK)
value := p.parseType()
return &ast.MapType{Map: pos, Key: key, Value: value}
}
func (p *parser) parseChanType() *ast.ChanType {
if p.trace {
defer un(trace(p, "ChanType"))
}
pos := p.pos
dir := ast.SEND | ast.RECV
var arrow token.Pos
if p.tok == token.CHAN {
p.next()
if p.tok == token.ARROW {
arrow = p.pos
p.next()
dir = ast.SEND
}
} else {
arrow = p.expect(token.ARROW)
p.expect(token.CHAN)
dir = ast.RECV
}
value := p.parseType()
return &ast.ChanType{Begin: pos, Arrow: arrow, Dir: dir, Value: value}
}
func (p *parser) parseTypeInstance(typ ast.Expr) ast.Expr {
if p.trace {
defer un(trace(p, "TypeInstance"))
}
opening := p.expect(token.LBRACK)
p.exprLev++
var list []ast.Expr
for p.tok != token.RBRACK && p.tok != token.EOF {
list = append(list, p.parseType())
if !p.atComma("type argument list", token.RBRACK) {
break
}
p.next()
}
p.exprLev--
closing := p.expectClosing(token.RBRACK, "type argument list")
if len(list) == 0 {
p.errorExpected(closing, "type argument list")
return &ast.IndexExpr{
X: typ,
Lbrack: opening,
Index: &ast.BadExpr{From: opening + 1, To: closing},
Rbrack: closing,
}
}
return typeparams.PackIndexExpr(typ, opening, list, closing)
}
func (p *parser) tryIdentOrType() ast.Expr {
defer decNestLev(incNestLev(p))
switch p.tok {
case token.IDENT:
typ := p.parseTypeName(nil)
if p.tok == token.LBRACK {
typ = p.parseTypeInstance(typ)
}
return typ
case token.LBRACK:
lbrack := p.expect(token.LBRACK)
return p.parseArrayType(lbrack, nil)
case token.STRUCT:
return p.parseStructType()
case token.MUL:
return p.parsePointerType()
case token.FUNC:
return p.parseFuncType()
case token.INTERFACE:
return p.parseInterfaceType()
case token.MAP:
return p.parseMapType()
case token.CHAN, token.ARROW:
return p.parseChanType()
case token.LPAREN:
lparen := p.pos
p.next()
typ := p.parseType()
rparen := p.expect(token.RPAREN)
return &ast.ParenExpr{Lparen: lparen, X: typ, Rparen: rparen}
}
// no type found
return nil
}
// ----------------------------------------------------------------------------
// Blocks
func (p *parser) parseStmtList() (list []ast.Stmt) {
if p.trace {
defer un(trace(p, "StatementList"))
}
for p.tok != token.CASE && p.tok != token.DEFAULT && p.tok != token.RBRACE && p.tok != token.EOF {
list = append(list, p.parseStmt())
}
return
}
func (p *parser) parseBody() *ast.BlockStmt {
if p.trace {
defer un(trace(p, "Body"))
}
lbrace := p.expect(token.LBRACE)
list := p.parseStmtList()
rbrace := p.expect2(token.RBRACE)
return &ast.BlockStmt{Lbrace: lbrace, List: list, Rbrace: rbrace}
}
func (p *parser) parseBlockStmt() *ast.BlockStmt {
if p.trace {
defer un(trace(p, "BlockStmt"))
}
lbrace := p.expect(token.LBRACE)
list := p.parseStmtList()
rbrace := p.expect2(token.RBRACE)
return &ast.BlockStmt{Lbrace: lbrace, List: list, Rbrace: rbrace}
}
// ----------------------------------------------------------------------------
// Expressions
func (p *parser) parseFuncTypeOrLit() ast.Expr {
if p.trace {
defer un(trace(p, "FuncTypeOrLit"))
}
typ := p.parseFuncType()
if p.tok != token.LBRACE {
// function type only
return typ
}
p.exprLev++
body := p.parseBody()
p.exprLev--
return &ast.FuncLit{Type: typ, Body: body}
}
// parseOperand may return an expression or a raw type (incl. array
// types of the form [...]T). Callers must verify the result.
func (p *parser) parseOperand() ast.Expr {
if p.trace {
defer un(trace(p, "Operand"))
}
switch p.tok {
case token.IDENT:
x := p.parseIdent()
return x
case token.INT, token.FLOAT, token.IMAG, token.CHAR, token.STRING:
x := &ast.BasicLit{ValuePos: p.pos, Kind: p.tok, Value: p.lit}
p.next()
return x
case token.LPAREN:
lparen := p.pos
p.next()
p.exprLev++
x := p.parseRhs() // types may be parenthesized: (some type)
p.exprLev--
rparen := p.expect(token.RPAREN)
return &ast.ParenExpr{Lparen: lparen, X: x, Rparen: rparen}
case token.FUNC:
return p.parseFuncTypeOrLit()
}
if typ := p.tryIdentOrType(); typ != nil { // do not consume trailing type parameters
// could be type for composite literal or conversion
_, isIdent := typ.(*ast.Ident)
assert(!isIdent, "type cannot be identifier")
return typ
}
// we have an error
pos := p.pos
p.errorExpected(pos, "operand")
p.advance(stmtStart)
return &ast.BadExpr{From: pos, To: p.pos}
}
func (p *parser) parseSelector(x ast.Expr) ast.Expr {
if p.trace {
defer un(trace(p, "Selector"))
}
sel := p.parseIdent()
return &ast.SelectorExpr{X: x, Sel: sel}
}
func (p *parser) parseTypeAssertion(x ast.Expr) ast.Expr {
if p.trace {
defer un(trace(p, "TypeAssertion"))
}
lparen := p.expect(token.LPAREN)
var typ ast.Expr
if p.tok == token.TYPE {
// type switch: typ == nil
p.next()
} else {
typ = p.parseType()
}
rparen := p.expect(token.RPAREN)
return &ast.TypeAssertExpr{X: x, Type: typ, Lparen: lparen, Rparen: rparen}
}
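// parseIndexOrSliceOrInstance parses the bracketed suffix of x.
// Illustrative outcomes (a sketch):
//
//	a[i]      -> *ast.IndexExpr
//	a[i:j]    -> *ast.SliceExpr
//	a[i:j:k]  -> *ast.SliceExpr with Slice3 set
//	a[T1, T2] -> index list expression (via typeparams.PackIndexExpr)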
func (p *parser) parseIndexOrSliceOrInstance(x ast.Expr) ast.Expr {
if p.trace {
defer un(trace(p, "parseIndexOrSliceOrInstance"))
}
lbrack := p.expect(token.LBRACK)
if p.tok == token.RBRACK {
// empty index, slice or instance expressions are not permitted;
// accept them for parsing tolerance, but complain
p.errorExpected(p.pos, "operand")
rbrack := p.pos
p.next()
return &ast.IndexExpr{
X: x,
Lbrack: lbrack,
Index: &ast.BadExpr{From: rbrack, To: rbrack},
Rbrack: rbrack,
}
}
p.exprLev++
const N = 3 // change the 3 to 2 to disable 3-index slices
var args []ast.Expr
var index [N]ast.Expr
var colons [N - 1]token.Pos
if p.tok != token.COLON {
// We can't know if we have an index expression or a type instantiation,
// so even if we see a (named) type we are not going to be in type context.
index[0] = p.parseRhs()
}
ncolons := 0
switch p.tok {
case token.COLON:
// slice expression
for p.tok == token.COLON && ncolons < len(colons) {
colons[ncolons] = p.pos
ncolons++
p.next()
if p.tok != token.COLON && p.tok != token.RBRACK && p.tok != token.EOF {
index[ncolons] = p.parseRhs()
}
}
case token.COMMA:
// instance expression
args = append(args, index[0])
for p.tok == token.COMMA {
p.next()
if p.tok != token.RBRACK && p.tok != token.EOF {
args = append(args, p.parseType())
}
}
}
p.exprLev--
rbrack := p.expect(token.RBRACK)
if ncolons > 0 {
// slice expression
slice3 := false
if ncolons == 2 {
slice3 = true
// Check presence of middle and final index here rather than during type-checking
// to prevent erroneous programs from passing through gofmt (was issue 7305).
if index[1] == nil {
p.error(colons[0], "middle index required in 3-index slice")
index[1] = &ast.BadExpr{From: colons[0] + 1, To: colons[1]}
}
if index[2] == nil {
p.error(colons[1], "final index required in 3-index slice")
index[2] = &ast.BadExpr{From: colons[1] + 1, To: rbrack}
}
}
return &ast.SliceExpr{X: x, Lbrack: lbrack, Low: index[0], High: index[1], Max: index[2], Slice3: slice3, Rbrack: rbrack}
}
if len(args) == 0 {
// index expression
return &ast.IndexExpr{X: x, Lbrack: lbrack, Index: index[0], Rbrack: rbrack}
}
// instance expression
return typeparams.PackIndexExpr(x, lbrack, args, rbrack)
}
func (p *parser) parseCallOrConversion(fun ast.Expr) *ast.CallExpr {
if p.trace {
defer un(trace(p, "CallOrConversion"))
}
lparen := p.expect(token.LPAREN)
p.exprLev++
var list []ast.Expr
var ellipsis token.Pos
for p.tok != token.RPAREN && p.tok != token.EOF && !ellipsis.IsValid() {
list = append(list, p.parseRhs()) // builtins may expect a type: make(some type, ...)
if p.tok == token.ELLIPSIS {
ellipsis = p.pos
p.next()
}
if !p.atComma("argument list", token.RPAREN) {
break
}
p.next()
}
p.exprLev--
rparen := p.expectClosing(token.RPAREN, "argument list")
return &ast.CallExpr{Fun: fun, Lparen: lparen, Args: list, Ellipsis: ellipsis, Rparen: rparen}
}
func (p *parser) parseValue() ast.Expr {
if p.trace {
defer un(trace(p, "Element"))
}
if p.tok == token.LBRACE {
return p.parseLiteralValue(nil)
}
x := p.parseExpr()
return x
}
func (p *parser) parseElement() ast.Expr {
if p.trace {
defer un(trace(p, "Element"))
}
x := p.parseValue()
if p.tok == token.COLON {
colon := p.pos
p.next()
x = &ast.KeyValueExpr{Key: x, Colon: colon, Value: p.parseValue()}
}
return x
}
func (p *parser) parseElementList() (list []ast.Expr) {
if p.trace {
defer un(trace(p, "ElementList"))
}
for p.tok != token.RBRACE && p.tok != token.EOF {
list = append(list, p.parseElement())
if !p.atComma("composite literal", token.RBRACE) {
break
}
p.next()
}
return
}
func (p *parser) parseLiteralValue(typ ast.Expr) ast.Expr {
if p.trace {
defer un(trace(p, "LiteralValue"))
}
lbrace := p.expect(token.LBRACE)
var elts []ast.Expr
p.exprLev++
if p.tok != token.RBRACE {
elts = p.parseElementList()
}
p.exprLev--
rbrace := p.expectClosing(token.RBRACE, "composite literal")
return &ast.CompositeLit{Type: typ, Lbrace: lbrace, Elts: elts, Rbrace: rbrace}
}
// If x is of the form (T), unparen returns unparen(T), otherwise it returns x.
func unparen(x ast.Expr) ast.Expr {
if p, isParen := x.(*ast.ParenExpr); isParen {
x = unparen(p.X)
}
return x
}
func (p *parser) parsePrimaryExpr(x ast.Expr) ast.Expr {
if p.trace {
defer un(trace(p, "PrimaryExpr"))
}
if x == nil {
x = p.parseOperand()
}
// We track the nesting here rather than at the entry for the function,
// since it can iteratively produce a nested output, and we want to
// limit how deep a structure we generate.
var n int
defer func() { p.nestLev -= n }()
for n = 1; ; n++ {
incNestLev(p)
switch p.tok {
case token.PERIOD:
p.next()
switch p.tok {
case token.IDENT:
x = p.parseSelector(x)
case token.LPAREN:
x = p.parseTypeAssertion(x)
default:
pos := p.pos
p.errorExpected(pos, "selector or type assertion")
// TODO(rFindley) The check for token.RBRACE below is a targeted fix
// to error recovery, sufficient to make the x/tools tests
// pass with the new parsing logic introduced for type
// parameters. Remove this once error recovery has been
// more generally reconsidered.
if p.tok != token.RBRACE {
p.next() // make progress
}
sel := &ast.Ident{NamePos: pos, Name: "_"}
x = &ast.SelectorExpr{X: x, Sel: sel}
}
case token.LBRACK:
x = p.parseIndexOrSliceOrInstance(x)
case token.LPAREN:
x = p.parseCallOrConversion(x)
case token.LBRACE:
// operand may have returned a parenthesized complit
// type; accept it but complain if we have a complit
t := unparen(x)
// determine if '{' belongs to a composite literal or a block statement
switch t.(type) {
case *ast.BadExpr, *ast.Ident, *ast.SelectorExpr:
if p.exprLev < 0 {
return x
}
// x is possibly a composite literal type
case *ast.IndexExpr, *ast.IndexListExpr:
if p.exprLev < 0 {
return x
}
// x is possibly a composite literal type
case *ast.ArrayType, *ast.StructType, *ast.MapType:
// x is a composite literal type
default:
return x
}
if t != x {
p.error(t.Pos(), "cannot parenthesize type in composite literal")
// already progressed, no need to advance
}
x = p.parseLiteralValue(x)
default:
return x
}
}
}
func (p *parser) parseUnaryExpr() ast.Expr {
defer decNestLev(incNestLev(p))
if p.trace {
defer un(trace(p, "UnaryExpr"))
}
switch p.tok {
case token.ADD, token.SUB, token.NOT, token.XOR, token.AND, token.TILDE:
pos, op := p.pos, p.tok
p.next()
x := p.parseUnaryExpr()
return &ast.UnaryExpr{OpPos: pos, Op: op, X: x}
case token.ARROW:
// channel type or receive expression
arrow := p.pos
p.next()
// If the next token is token.CHAN we still don't know if it
// is a channel type or a receive operation - we only know
// once we have found the end of the unary expression. There
// are two cases:
//
// <- type => (<-type) must be channel type
// <- expr => <-(expr) is a receive from an expression
//
// In the first case, the arrow must be re-associated with
// the channel type parsed already:
//
// <- (chan type) => (<-chan type)
// <- (chan<- type) => (<-chan (<-type))
x := p.parseUnaryExpr()
// determine which case we have
if typ, ok := x.(*ast.ChanType); ok {
// (<-type)
// re-associate position info and <-
dir := ast.SEND
for ok && dir == ast.SEND {
if typ.Dir == ast.RECV {
// error: (<-type) is (<-(<-chan T))
p.errorExpected(typ.Arrow, "'chan'")
}
arrow, typ.Begin, typ.Arrow = typ.Arrow, arrow, arrow
dir, typ.Dir = typ.Dir, ast.RECV
typ, ok = typ.Value.(*ast.ChanType)
}
if dir == ast.SEND {
p.errorExpected(arrow, "channel type")
}
return x
}
// <-(expr)
return &ast.UnaryExpr{OpPos: arrow, Op: token.ARROW, X: x}
case token.MUL:
// pointer type or unary "*" expression
pos := p.pos
p.next()
x := p.parseUnaryExpr()
return &ast.StarExpr{Star: pos, X: x}
}
return p.parsePrimaryExpr(nil)
}
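// tokPrec returns the current token and its precedence, treating '=' as
// '==' when parsing a rhs. For example (a sketch): in "if x = 1 {", the
// condition is parsed as a rhs, so the binary-expression loop consumes
// the '=' and p.expect(token.EQL) reports "expected '=='" rather than
// stopping at the '='.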
func (p *parser) tokPrec() (token.Token, int) {
tok := p.tok
if p.inRhs && tok == token.ASSIGN {
tok = token.EQL
}
return tok, tok.Precedence()
}
// parseBinaryExpr parses a (possibly) binary expression.
// If x is non-nil, it is used as the left operand.
//
// TODO(rfindley): parseBinaryExpr has become overloaded. Consider refactoring.
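//
// As a sketch, precedence climbing groups "1 + 2*3" as follows (ADD has
// precedence 4, MUL has 5; the top-level call passes prec1 == 1):
//
//	parseBinaryExpr(nil, 1)     // x = 1; sees '+' (4 >= 1)
//	  parseBinaryExpr(nil, 5)   // x = 2; sees '*' (5 >= 5)
//	    parseBinaryExpr(nil, 6) // parses 3; next prec < 6, returns 3
//	  // returns 2*3
//	// returns 1 + (2*3)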
func (p *parser) parseBinaryExpr(x ast.Expr, prec1 int) ast.Expr {
if p.trace {
defer un(trace(p, "BinaryExpr"))
}
if x == nil {
x = p.parseUnaryExpr()
}
// We track the nesting here rather than at the entry for the function,
// since it can iteratively produce a nested output, and we want to
// limit how deep a structure we generate.
var n int
defer func() { p.nestLev -= n }()
for n = 1; ; n++ {
incNestLev(p)
op, oprec := p.tokPrec()
if oprec < prec1 {
return x
}
pos := p.expect(op)
y := p.parseBinaryExpr(nil, oprec+1)
x = &ast.BinaryExpr{X: x, OpPos: pos, Op: op, Y: y}
}
}
// The result may be a type or even a raw type ([...]int).
func (p *parser) parseExpr() ast.Expr {
if p.trace {
defer un(trace(p, "Expression"))
}
return p.parseBinaryExpr(nil, token.LowestPrec+1)
}
func (p *parser) parseRhs() ast.Expr {
old := p.inRhs
p.inRhs = true
x := p.parseExpr()
p.inRhs = old
return x
}
// ----------------------------------------------------------------------------
// Statements
// Parsing modes for parseSimpleStmt.
const (
basic = iota
labelOk
rangeOk
)
// parseSimpleStmt returns true as 2nd result if it parsed the assignment
// of a range clause (with mode == rangeOk). The returned statement is an
// assignment with a right-hand side that is a single unary expression of
// the form "range x". No guarantees are given for the left-hand side.
func (p *parser) parseSimpleStmt(mode int) (ast.Stmt, bool) {
if p.trace {
defer un(trace(p, "SimpleStmt"))
}
x := p.parseList(false)
switch p.tok {
case
token.DEFINE, token.ASSIGN, token.ADD_ASSIGN,
token.SUB_ASSIGN, token.MUL_ASSIGN, token.QUO_ASSIGN,
token.REM_ASSIGN, token.AND_ASSIGN, token.OR_ASSIGN,
token.XOR_ASSIGN, token.SHL_ASSIGN, token.SHR_ASSIGN, token.AND_NOT_ASSIGN:
// assignment statement, possibly part of a range clause
pos, tok := p.pos, p.tok
p.next()
var y []ast.Expr
isRange := false
if mode == rangeOk && p.tok == token.RANGE && (tok == token.DEFINE || tok == token.ASSIGN) {
pos := p.pos
p.next()
y = []ast.Expr{&ast.UnaryExpr{OpPos: pos, Op: token.RANGE, X: p.parseRhs()}}
isRange = true
} else {
y = p.parseList(true)
}
return &ast.AssignStmt{Lhs: x, TokPos: pos, Tok: tok, Rhs: y}, isRange
}
if len(x) > 1 {
p.errorExpected(x[0].Pos(), "1 expression")
// continue with first expression
}
switch p.tok {
case token.COLON:
// labeled statement
colon := p.pos
p.next()
if label, isIdent := x[0].(*ast.Ident); mode == labelOk && isIdent {
// Go spec: The scope of a label is the body of the function
// in which it is declared and excludes the body of any nested
// function.
stmt := &ast.LabeledStmt{Label: label, Colon: colon, Stmt: p.parseStmt()}
return stmt, false
}
// The label declaration typically starts at x[0].Pos(), but the label
// declaration may be erroneous due to a token after that position (and
// before the ':'). If SpuriousErrors is not set, the (only) error
// reported for the line is the illegal label error instead of the token
// before the ':' that caused the problem. Thus, use the (latest) colon
// position for error reporting.
p.error(colon, "illegal label declaration")
return &ast.BadStmt{From: x[0].Pos(), To: colon + 1}, false
case token.ARROW:
// send statement
arrow := p.pos
p.next()
y := p.parseRhs()
return &ast.SendStmt{Chan: x[0], Arrow: arrow, Value: y}, false
case token.INC, token.DEC:
// increment or decrement
s := &ast.IncDecStmt{X: x[0], TokPos: p.pos, Tok: p.tok}
p.next()
return s, false
}
// expression
return &ast.ExprStmt{X: x[0]}, false
}
func (p *parser) parseCallExpr(callType string) *ast.CallExpr {
x := p.parseRhs() // could be a conversion: (some type)(x)
if t := unparen(x); t != x {
p.error(x.Pos(), fmt.Sprintf("expression in %s must not be parenthesized", callType))
x = t
}
if call, isCall := x.(*ast.CallExpr); isCall {
return call
}
if _, isBad := x.(*ast.BadExpr); !isBad {
// only report error if it's a new one
p.error(p.safePos(x.End()), fmt.Sprintf("expression in %s must be function call", callType))
}
return nil
}
func (p *parser) parseGoStmt() ast.Stmt {
if p.trace {
defer un(trace(p, "GoStmt"))
}
pos := p.expect(token.GO)
call := p.parseCallExpr("go")
p.expectSemi()
if call == nil {
return &ast.BadStmt{From: pos, To: pos + 2} // len("go")
}
return &ast.GoStmt{Go: pos, Call: call}
}
func (p *parser) parseDeferStmt() ast.Stmt {
if p.trace {
defer un(trace(p, "DeferStmt"))
}
pos := p.expect(token.DEFER)
call := p.parseCallExpr("defer")
p.expectSemi()
if call == nil {
return &ast.BadStmt{From: pos, To: pos + 5} // len("defer")
}
return &ast.DeferStmt{Defer: pos, Call: call}
}
func (p *parser) parseReturnStmt() *ast.ReturnStmt {
if p.trace {
defer un(trace(p, "ReturnStmt"))
}
pos := p.pos
p.expect(token.RETURN)
var x []ast.Expr
if p.tok != token.SEMICOLON && p.tok != token.RBRACE {
x = p.parseList(true)
}
p.expectSemi()
return &ast.ReturnStmt{Return: pos, Results: x}
}
func (p *parser) parseBranchStmt(tok token.Token) *ast.BranchStmt {
if p.trace {
defer un(trace(p, "BranchStmt"))
}
pos := p.expect(tok)
var label *ast.Ident
if tok != token.FALLTHROUGH && p.tok == token.IDENT {
label = p.parseIdent()
}
p.expectSemi()
return &ast.BranchStmt{TokPos: pos, Tok: tok, Label: label}
}
func (p *parser) makeExpr(s ast.Stmt, want string) ast.Expr {
if s == nil {
return nil
}
if es, isExpr := s.(*ast.ExprStmt); isExpr {
return es.X
}
found := "simple statement"
if _, isAss := s.(*ast.AssignStmt); isAss {
found = "assignment"
}
p.error(s.Pos(), fmt.Sprintf("expected %s, found %s (missing parentheses around composite literal?)", want, found))
return &ast.BadExpr{From: s.Pos(), To: p.safePos(s.End())}
}
// parseIfHeader is an adjusted version of parser.header
// in cmd/compile/internal/syntax/parser.go, which has
// been tuned for better error handling.
func (p *parser) parseIfHeader() (init ast.Stmt, cond ast.Expr) {
if p.tok == token.LBRACE {
p.error(p.pos, "missing condition in if statement")
cond = &ast.BadExpr{From: p.pos, To: p.pos}
return
}
// p.tok != token.LBRACE
prevLev := p.exprLev
p.exprLev = -1
if p.tok != token.SEMICOLON {
// accept potential variable declaration but complain
if p.tok == token.VAR {
p.next()
p.error(p.pos, "var declaration not allowed in if initializer")
}
init, _ = p.parseSimpleStmt(basic)
}
var condStmt ast.Stmt
var semi struct {
pos token.Pos
lit string // ";" or "\n"; valid if pos.IsValid()
}
if p.tok != token.LBRACE {
if p.tok == token.SEMICOLON {
semi.pos = p.pos
semi.lit = p.lit
p.next()
} else {
p.expect(token.SEMICOLON)
}
if p.tok != token.LBRACE {
condStmt, _ = p.parseSimpleStmt(basic)
}
} else {
condStmt = init
init = nil
}
if condStmt != nil {
cond = p.makeExpr(condStmt, "boolean expression")
} else if semi.pos.IsValid() {
if semi.lit == "\n" {
p.error(semi.pos, "unexpected newline, expecting { after if clause")
} else {
p.error(semi.pos, "missing condition in if statement")
}
}
// make sure we have a valid AST
if cond == nil {
cond = &ast.BadExpr{From: p.pos, To: p.pos}
}
p.exprLev = prevLev
return
}
func (p *parser) parseIfStmt() *ast.IfStmt {
defer decNestLev(incNestLev(p))
if p.trace {
defer un(trace(p, "IfStmt"))
}
pos := p.expect(token.IF)
init, cond := p.parseIfHeader()
body := p.parseBlockStmt()
var else_ ast.Stmt
if p.tok == token.ELSE {
p.next()
switch p.tok {
case token.IF:
else_ = p.parseIfStmt()
case token.LBRACE:
else_ = p.parseBlockStmt()
p.expectSemi()
default:
p.errorExpected(p.pos, "if statement or block")
else_ = &ast.BadStmt{From: p.pos, To: p.pos}
}
} else {
p.expectSemi()
}
return &ast.IfStmt{If: pos, Init: init, Cond: cond, Body: body, Else: else_}
}
func (p *parser) parseCaseClause() *ast.CaseClause {
if p.trace {
defer un(trace(p, "CaseClause"))
}
pos := p.pos
var list []ast.Expr
if p.tok == token.CASE {
p.next()
list = p.parseList(true)
} else {
p.expect(token.DEFAULT)
}
colon := p.expect(token.COLON)
body := p.parseStmtList()
return &ast.CaseClause{Case: pos, List: list, Colon: colon, Body: body}
}
func isTypeSwitchAssert(x ast.Expr) bool {
a, ok := x.(*ast.TypeAssertExpr)
return ok && a.Type == nil
}
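// isTypeSwitchGuard reports whether s is a type switch guard. For example
// (a sketch), both headers below are recognized, the '=' form with a
// complaint:
//
//	switch x.(type) { ... }
//	switch v := x.(type) { ... }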
func (p *parser) isTypeSwitchGuard(s ast.Stmt) bool {
switch t := s.(type) {
case *ast.ExprStmt:
// x.(type)
return isTypeSwitchAssert(t.X)
case *ast.AssignStmt:
// v := x.(type)
if len(t.Lhs) == 1 && len(t.Rhs) == 1 && isTypeSwitchAssert(t.Rhs[0]) {
switch t.Tok {
case token.ASSIGN:
// permit v = x.(type) but complain
p.error(t.TokPos, "expected ':=', found '='")
fallthrough
case token.DEFINE:
return true
}
}
}
return false
}
func (p *parser) parseSwitchStmt() ast.Stmt {
if p.trace {
defer un(trace(p, "SwitchStmt"))
}
pos := p.expect(token.SWITCH)
var s1, s2 ast.Stmt
if p.tok != token.LBRACE {
prevLev := p.exprLev
p.exprLev = -1
if p.tok != token.SEMICOLON {
s2, _ = p.parseSimpleStmt(basic)
}
if p.tok == token.SEMICOLON {
p.next()
s1 = s2
s2 = nil
if p.tok != token.LBRACE {
// A TypeSwitchGuard may declare a variable in addition
// to the variable declared in the initial SimpleStmt.
// Introduce extra scope to avoid redeclaration errors:
//
// switch t := 0; t := x.(T) { ... }
//
// (this code is not valid Go because the first t
// cannot be accessed and thus is never used; the extra
// scope is needed for the correct error message).
//
// If we don't have a type switch, s2 must be an expression.
// Having the extra nested but empty scope won't affect it.
s2, _ = p.parseSimpleStmt(basic)
}
}
p.exprLev = prevLev
}
typeSwitch := p.isTypeSwitchGuard(s2)
lbrace := p.expect(token.LBRACE)
var list []ast.Stmt
for p.tok == token.CASE || p.tok == token.DEFAULT {
list = append(list, p.parseCaseClause())
}
rbrace := p.expect(token.RBRACE)
p.expectSemi()
body := &ast.BlockStmt{Lbrace: lbrace, List: list, Rbrace: rbrace}
if typeSwitch {
return &ast.TypeSwitchStmt{Switch: pos, Init: s1, Assign: s2, Body: body}
}
return &ast.SwitchStmt{Switch: pos, Init: s1, Tag: p.makeExpr(s2, "switch expression"), Body: body}
}
func (p *parser) parseCommClause() *ast.CommClause {
if p.trace {
defer un(trace(p, "CommClause"))
}
pos := p.pos
var comm ast.Stmt
if p.tok == token.CASE {
p.next()
lhs := p.parseList(false)
if p.tok == token.ARROW {
// SendStmt
if len(lhs) > 1 {
p.errorExpected(lhs[0].Pos(), "1 expression")
// continue with first expression
}
arrow := p.pos
p.next()
rhs := p.parseRhs()
comm = &ast.SendStmt{Chan: lhs[0], Arrow: arrow, Value: rhs}
} else {
// RecvStmt
if tok := p.tok; tok == token.ASSIGN || tok == token.DEFINE {
// RecvStmt with assignment
if len(lhs) > 2 {
p.errorExpected(lhs[0].Pos(), "1 or 2 expressions")
// continue with first two expressions
lhs = lhs[0:2]
}
pos := p.pos
p.next()
rhs := p.parseRhs()
comm = &ast.AssignStmt{Lhs: lhs, TokPos: pos, Tok: tok, Rhs: []ast.Expr{rhs}}
} else {
// lhs must be single receive operation
if len(lhs) > 1 {
p.errorExpected(lhs[0].Pos(), "1 expression")
// continue with first expression
}
comm = &ast.ExprStmt{X: lhs[0]}
}
}
} else {
p.expect(token.DEFAULT)
}
colon := p.expect(token.COLON)
body := p.parseStmtList()
return &ast.CommClause{Case: pos, Comm: comm, Colon: colon, Body: body}
}
func (p *parser) parseSelectStmt() *ast.SelectStmt {
if p.trace {
defer un(trace(p, "SelectStmt"))
}
pos := p.expect(token.SELECT)
lbrace := p.expect(token.LBRACE)
var list []ast.Stmt
for p.tok == token.CASE || p.tok == token.DEFAULT {
list = append(list, p.parseCommClause())
}
rbrace := p.expect(token.RBRACE)
p.expectSemi()
body := &ast.BlockStmt{Lbrace: lbrace, List: list, Rbrace: rbrace}
return &ast.SelectStmt{Select: pos, Body: body}
}
func (p *parser) parseForStmt() ast.Stmt {
if p.trace {
defer un(trace(p, "ForStmt"))
}
pos := p.expect(token.FOR)
var s1, s2, s3 ast.Stmt
var isRange bool
if p.tok != token.LBRACE {
prevLev := p.exprLev
p.exprLev = -1
if p.tok != token.SEMICOLON {
if p.tok == token.RANGE {
// "for range x" (nil lhs in assignment)
pos := p.pos
p.next()
y := []ast.Expr{&ast.UnaryExpr{OpPos: pos, Op: token.RANGE, X: p.parseRhs()}}
s2 = &ast.AssignStmt{Rhs: y}
isRange = true
} else {
s2, isRange = p.parseSimpleStmt(rangeOk)
}
}
if !isRange && p.tok == token.SEMICOLON {
p.next()
s1 = s2
s2 = nil
if p.tok != token.SEMICOLON {
s2, _ = p.parseSimpleStmt(basic)
}
p.expectSemi()
if p.tok != token.LBRACE {
s3, _ = p.parseSimpleStmt(basic)
}
}
p.exprLev = prevLev
}
body := p.parseBlockStmt()
p.expectSemi()
if isRange {
as := s2.(*ast.AssignStmt)
// check lhs
var key, value ast.Expr
switch len(as.Lhs) {
case 0:
// nothing to do
case 1:
key = as.Lhs[0]
case 2:
key, value = as.Lhs[0], as.Lhs[1]
default:
p.errorExpected(as.Lhs[len(as.Lhs)-1].Pos(), "at most 2 expressions")
return &ast.BadStmt{From: pos, To: p.safePos(body.End())}
}
// parseSimpleStmt returned a right-hand side that
// is a single unary expression of the form "range x"
x := as.Rhs[0].(*ast.UnaryExpr).X
return &ast.RangeStmt{
For: pos,
Key: key,
Value: value,
TokPos: as.TokPos,
Tok: as.Tok,
Range: as.Rhs[0].Pos(),
X: x,
Body: body,
}
}
// regular for statement
return &ast.ForStmt{
For: pos,
Init: s1,
Cond: p.makeExpr(s2, "boolean or range expression"),
Post: s3,
Body: body,
}
}
func (p *parser) parseStmt() (s ast.Stmt) {
defer decNestLev(incNestLev(p))
if p.trace {
defer un(trace(p, "Statement"))
}
switch p.tok {
case token.CONST, token.TYPE, token.VAR:
s = &ast.DeclStmt{Decl: p.parseDecl(stmtStart)}
case
// tokens that may start an expression
token.IDENT, token.INT, token.FLOAT, token.IMAG, token.CHAR, token.STRING, token.FUNC, token.LPAREN, // operands
token.LBRACK, token.STRUCT, token.MAP, token.CHAN, token.INTERFACE, // composite types
token.ADD, token.SUB, token.MUL, token.AND, token.XOR, token.ARROW, token.NOT: // unary operators
s, _ = p.parseSimpleStmt(labelOk)
// because of the required look-ahead, labeled statements are
// parsed by parseSimpleStmt - don't expect a semicolon after
// them
if _, isLabeledStmt := s.(*ast.LabeledStmt); !isLabeledStmt {
p.expectSemi()
}
case token.GO:
s = p.parseGoStmt()
case token.DEFER:
s = p.parseDeferStmt()
case token.RETURN:
s = p.parseReturnStmt()
case token.BREAK, token.CONTINUE, token.GOTO, token.FALLTHROUGH:
s = p.parseBranchStmt(p.tok)
case token.LBRACE:
s = p.parseBlockStmt()
p.expectSemi()
case token.IF:
s = p.parseIfStmt()
case token.SWITCH:
s = p.parseSwitchStmt()
case token.SELECT:
s = p.parseSelectStmt()
case token.FOR:
s = p.parseForStmt()
case token.SEMICOLON:
// Is it ever possible to have an implicit semicolon
// producing an empty statement in a valid program?
// (handle correctly anyway)
s = &ast.EmptyStmt{Semicolon: p.pos, Implicit: p.lit == "\n"}
p.next()
case token.RBRACE:
// a semicolon may be omitted before a closing "}"
s = &ast.EmptyStmt{Semicolon: p.pos, Implicit: true}
default:
// no statement found
pos := p.pos
p.errorExpected(pos, "statement")
p.advance(stmtStart)
s = &ast.BadStmt{From: pos, To: p.pos}
}
return
}
// ----------------------------------------------------------------------------
// Declarations
type parseSpecFunction func(doc *ast.CommentGroup, keyword token.Token, iota int) ast.Spec
func (p *parser) parseImportSpec(doc *ast.CommentGroup, _ token.Token, _ int) ast.Spec {
if p.trace {
defer un(trace(p, "ImportSpec"))
}
var ident *ast.Ident
switch p.tok {
case token.IDENT:
ident = p.parseIdent()
case token.PERIOD:
ident = &ast.Ident{NamePos: p.pos, Name: "."}
p.next()
}
pos := p.pos
var path string
if p.tok == token.STRING {
path = p.lit
p.next()
} else if p.tok.IsLiteral() {
p.error(pos, "import path must be a string")
p.next()
} else {
p.error(pos, "missing import path")
p.advance(exprEnd)
}
comment := p.expectSemi()
// collect imports
spec := &ast.ImportSpec{
Doc: doc,
Name: ident,
Path: &ast.BasicLit{ValuePos: pos, Kind: token.STRING, Value: path},
Comment: comment,
}
p.imports = append(p.imports, spec)
return spec
}
func (p *parser) parseValueSpec(doc *ast.CommentGroup, keyword token.Token, iota int) ast.Spec {
if p.trace {
defer un(trace(p, keyword.String()+"Spec"))
}
idents := p.parseIdentList()
var typ ast.Expr
var values []ast.Expr
switch keyword {
case token.CONST:
// always permit optional type and initialization for more tolerant parsing
if p.tok != token.EOF && p.tok != token.SEMICOLON && p.tok != token.RPAREN {
typ = p.tryIdentOrType()
if p.tok == token.ASSIGN {
p.next()
values = p.parseList(true)
}
}
case token.VAR:
if p.tok != token.ASSIGN {
typ = p.parseType()
}
if p.tok == token.ASSIGN {
p.next()
values = p.parseList(true)
}
default:
panic("unreachable")
}
comment := p.expectSemi()
spec := &ast.ValueSpec{
Doc: doc,
Names: idents,
Type: typ,
Values: values,
Comment: comment,
}
return spec
}
func (p *parser) parseGenericType(spec *ast.TypeSpec, openPos token.Pos, name0 *ast.Ident, typ0 ast.Expr) {
if p.trace {
defer un(trace(p, "parseGenericType"))
}
list := p.parseParameterList(name0, typ0, token.RBRACK)
closePos := p.expect(token.RBRACK)
spec.TypeParams = &ast.FieldList{Opening: openPos, List: list, Closing: closePos}
// Let the type checker decide whether to accept type parameters on aliases:
// see issue #46477.
if p.tok == token.ASSIGN {
// type alias
spec.Assign = p.pos
p.next()
}
spec.Type = p.parseType()
}
func (p *parser) parseTypeSpec(doc *ast.CommentGroup, _ token.Token, _ int) ast.Spec {
if p.trace {
defer un(trace(p, "TypeSpec"))
}
name := p.parseIdent()
spec := &ast.TypeSpec{Doc: doc, Name: name}
if p.tok == token.LBRACK {
// spec.Name "[" ...
// array/slice type or type parameter list
lbrack := p.pos
p.next()
if p.tok == token.IDENT {
// We may have an array type or a type parameter list.
// In either case we expect an expression x (which may
// just be a name, or a more complex expression) which
// we can analyze further.
//
// A type parameter list may have a type bound starting
// with a "[" as in: P []E. In that case, simply parsing
// an expression would lead to an error: P[] is invalid.
// But since index or slice expressions are never constant
// and thus invalid array length expressions, if the name
// is followed by "[" it must be the start of an array or
// slice constraint. Only if we don't see a "[" do we
// need to parse a full expression. Notably, name <- x
// is not a concern because name <- x is a statement and
// not an expression.
var x ast.Expr = p.parseIdent()
if p.tok != token.LBRACK {
// To parse the expression starting with name, expand
// the call sequence we would get by passing in name
// to parser.expr, and pass in name to parsePrimaryExpr.
p.exprLev++
lhs := p.parsePrimaryExpr(x)
x = p.parseBinaryExpr(lhs, token.LowestPrec+1)
p.exprLev--
}
// Analyze expression x. If we can split x into a type parameter
// name, possibly followed by a type parameter type, we consider
// this the start of a type parameter list, with some caveats:
// a single name followed by "]" tilts the decision towards an
// array declaration; a type parameter type that could also be
// an ordinary expression but which is followed by a comma tilts
// the decision towards a type parameter list.
if pname, ptype := extractName(x, p.tok == token.COMMA); pname != nil && (ptype != nil || p.tok != token.RBRACK) {
// spec.Name "[" pname ...
// spec.Name "[" pname ptype ...
// spec.Name "[" pname ptype "," ...
p.parseGenericType(spec, lbrack, pname, ptype) // ptype may be nil
} else {
// spec.Name "[" pname "]" ...
// spec.Name "[" x ...
spec.Type = p.parseArrayType(lbrack, x)
}
} else {
// array type
spec.Type = p.parseArrayType(lbrack, nil)
}
} else {
// no type parameters
if p.tok == token.ASSIGN {
// type alias
spec.Assign = p.pos
p.next()
}
spec.Type = p.parseType()
}
spec.Comment = p.expectSemi()
return spec
}
// extractName splits the expression x into (name, expr) if syntactically
// x can be written as name expr. The split only happens if expr is a type
// element (per the isTypeElem predicate) or if force is set.
// If x is just a name, the result is (name, nil). If the split succeeds,
// the result is (name, expr). Otherwise the result is (nil, x).
// Examples:
//
// x           force   name   expr
// -----------------------------------
// P*[]int     T/F     P      *[]int
// P*E         T       P      *E
// P*E         F       nil    P*E
// P([]int)    T/F     P      []int
// P(E)        T       P      E
// P(E)        F       nil    P(E)
// P*E|F|~G    T/F     P      *E|F|~G
// P*E|F|G     T       P      *E|F|G
// P*E|F|G     F       nil    P*E|F|G
func extractName(x ast.Expr, force bool) (*ast.Ident, ast.Expr) {
switch x := x.(type) {
case *ast.Ident:
return x, nil
case *ast.BinaryExpr:
switch x.Op {
case token.MUL:
if name, _ := x.X.(*ast.Ident); name != nil && (force || isTypeElem(x.Y)) {
// x = name *x.Y
return name, &ast.StarExpr{Star: x.OpPos, X: x.Y}
}
case token.OR:
if name, lhs := extractName(x.X, force || isTypeElem(x.Y)); name != nil && lhs != nil {
// x = name lhs|x.Y
op := *x
op.X = lhs
return name, &op
}
}
case *ast.CallExpr:
if name, _ := x.Fun.(*ast.Ident); name != nil {
if len(x.Args) == 1 && x.Ellipsis == token.NoPos && (force || isTypeElem(x.Args[0])) {
// x = name "(" x.ArgList[0] ")"
return name, x.Args[0]
}
}
}
return nil, x
}
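// Editor's note (illustrative, derived from the table above): the practical
// effect of the split on type declarations. A starred operand only yields a
// type parameter list when it is a type element or a comma forces the split:
//
//	type T1[P any] int             // type parameter list
//	type T2[P *E,] int             // trailing comma forces the split
//	type T3[P interface{ *E }] int // *E is a type element here
//	type T4 [P * E]int             // no split: parsed as an array length P*E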
// isTypeElem reports whether x is a (possibly parenthesized) type element expression.
// The result is false if x could be a type element OR an ordinary (value) expression.
func isTypeElem(x ast.Expr) bool {
switch x := x.(type) {
case *ast.ArrayType, *ast.StructType, *ast.FuncType, *ast.InterfaceType, *ast.MapType, *ast.ChanType:
return true
case *ast.BinaryExpr:
return isTypeElem(x.X) || isTypeElem(x.Y)
case *ast.UnaryExpr:
return x.Op == token.TILDE
case *ast.ParenExpr:
return isTypeElem(x.X)
}
return false
}
func (p *parser) parseGenDecl(keyword token.Token, f parseSpecFunction) *ast.GenDecl {
if p.trace {
defer un(trace(p, "GenDecl("+keyword.String()+")"))
}
doc := p.leadComment
pos := p.expect(keyword)
var lparen, rparen token.Pos
var list []ast.Spec
if p.tok == token.LPAREN {
lparen = p.pos
p.next()
for iota := 0; p.tok != token.RPAREN && p.tok != token.EOF; iota++ {
list = append(list, f(p.leadComment, keyword, iota))
}
rparen = p.expect(token.RPAREN)
p.expectSemi()
} else {
list = append(list, f(nil, keyword, 0))
}
return &ast.GenDecl{
Doc: doc,
TokPos: pos,
Tok: keyword,
Lparen: lparen,
Specs: list,
Rparen: rparen,
}
}
func (p *parser) parseFuncDecl() *ast.FuncDecl {
if p.trace {
defer un(trace(p, "FunctionDecl"))
}
doc := p.leadComment
pos := p.expect(token.FUNC)
var recv *ast.FieldList
if p.tok == token.LPAREN {
_, recv = p.parseParameters(false)
}
ident := p.parseIdent()
tparams, params := p.parseParameters(true)
if recv != nil && tparams != nil {
// Method declarations do not have type parameters. We parse them for a
// better error message and improved error recovery.
p.error(tparams.Opening, "method must have no type parameters")
tparams = nil
}
results := p.parseResult()
var body *ast.BlockStmt
switch p.tok {
case token.LBRACE:
body = p.parseBody()
p.expectSemi()
case token.SEMICOLON:
p.next()
if p.tok == token.LBRACE {
// opening { of function declaration on next line
p.error(p.pos, "unexpected semicolon or newline before {")
body = p.parseBody()
p.expectSemi()
}
default:
p.expectSemi()
}
decl := &ast.FuncDecl{
Doc: doc,
Recv: recv,
Name: ident,
Type: &ast.FuncType{
Func: pos,
TypeParams: tparams,
Params: params,
Results: results,
},
Body: body,
}
return decl
}
func (p *parser) parseDecl(sync map[token.Token]bool) ast.Decl {
if p.trace {
defer un(trace(p, "Declaration"))
}
var f parseSpecFunction
switch p.tok {
case token.IMPORT:
f = p.parseImportSpec
case token.CONST, token.VAR:
f = p.parseValueSpec
case token.TYPE:
f = p.parseTypeSpec
case token.FUNC:
return p.parseFuncDecl()
default:
pos := p.pos
p.errorExpected(pos, "declaration")
p.advance(sync)
return &ast.BadDecl{From: pos, To: p.pos}
}
return p.parseGenDecl(p.tok, f)
}
// ----------------------------------------------------------------------------
// Source files
func (p *parser) parseFile() *ast.File {
if p.trace {
defer un(trace(p, "File"))
}
// Don't bother parsing the rest if we had errors scanning the first token.
// Likely not a Go source file at all.
if p.errors.Len() != 0 {
return nil
}
// package clause
doc := p.leadComment
pos := p.expect(token.PACKAGE)
// Go spec: The package clause is not a declaration;
// the package name does not appear in any scope.
ident := p.parseIdent()
if ident.Name == "_" && p.mode&DeclarationErrors != 0 {
p.error(p.pos, "invalid package name _")
}
p.expectSemi()
// Don't bother parsing the rest if we had errors parsing the package clause.
// Likely not a Go source file at all.
if p.errors.Len() != 0 {
return nil
}
var decls []ast.Decl
if p.mode&PackageClauseOnly == 0 {
// import decls
for p.tok == token.IMPORT {
decls = append(decls, p.parseGenDecl(token.IMPORT, p.parseImportSpec))
}
if p.mode&ImportsOnly == 0 {
// rest of package body
prev := token.IMPORT
for p.tok != token.EOF {
// Continue to accept import declarations for error tolerance, but complain.
if p.tok == token.IMPORT && prev != token.IMPORT {
p.error(p.pos, "imports must appear before other declarations")
}
prev = p.tok
decls = append(decls, p.parseDecl(declStart))
}
}
}
f := &ast.File{
Doc: doc,
Package: pos,
Name: ident,
Decls: decls,
FileStart: token.Pos(p.file.Base()),
FileEnd: token.Pos(p.file.Base() + p.file.Size()),
Imports: p.imports,
Comments: p.comments,
}
var declErr func(token.Pos, string)
if p.mode&DeclarationErrors != 0 {
declErr = p.error
}
if p.mode&SkipObjectResolution == 0 {
resolveFile(f, p.file, declErr)
}
return f
}
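// Editor's note (usage sketch, not part of the original source): the Mode
// flags consulted above are set by clients of go/parser, e.g.:
//
//	fset := token.NewFileSet()
//	// Stop after the package clause:
//	f1, err1 := parser.ParseFile(fset, "x.go", src, parser.PackageClauseOnly)
//	// Stop after the import declarations:
//	f2, err2 := parser.ParseFile(fset, "x.go", src, parser.ImportsOnly)
//	// Skip the legacy object-resolution pass below:
//	f3, err3 := parser.ParseFile(fset, "x.go", src, parser.SkipObjectResolution)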
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package parser
import (
"fmt"
"go/ast"
"go/token"
"strings"
)
const debugResolve = false
// resolveFile walks the given file to resolve identifiers within the file
// scope, updating ast.Ident.Obj fields with declaration information.
//
// If declErr is non-nil, it is used to report declaration errors during
// resolution. handle is used to format positions in error messages.
func resolveFile(file *ast.File, handle *token.File, declErr func(token.Pos, string)) {
pkgScope := ast.NewScope(nil)
r := &resolver{
handle: handle,
declErr: declErr,
topScope: pkgScope,
pkgScope: pkgScope,
depth: 1,
}
for _, decl := range file.Decls {
ast.Walk(r, decl)
}
r.closeScope()
assert(r.topScope == nil, "unbalanced scopes")
assert(r.labelScope == nil, "unbalanced label scopes")
// resolve global identifiers within the same file
i := 0
for _, ident := range r.unresolved {
// i <= index for current ident
assert(ident.Obj == unresolved, "object already resolved")
ident.Obj = r.pkgScope.Lookup(ident.Name) // also removes unresolved sentinel
if ident.Obj == nil {
r.unresolved[i] = ident
i++
} else if debugResolve {
pos := ident.Obj.Decl.(interface{ Pos() token.Pos }).Pos()
r.trace("resolved %s@%v to package object %v", ident.Name, ident.Pos(), pos)
}
}
file.Scope = r.pkgScope
file.Unresolved = r.unresolved[0:i]
}
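// Editor's note (illustrative sketch): identifiers that cannot be resolved
// within this file end up in ast.File.Unresolved; they may be declared in
// another file of the package, or in the universe scope:
//
//	src := "package p\nvar x = fmt.Sprint(len(s))"
//	f, _ := parser.ParseFile(token.NewFileSet(), "p.go", src, 0)
//	for _, id := range f.Unresolved {
//		fmt.Println(id.Name) // fmt, len, s
//	}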
const maxScopeDepth int = 1e3
type resolver struct {
handle *token.File
declErr func(token.Pos, string)
// Ordinary identifier scopes
pkgScope *ast.Scope // pkgScope.Outer == nil
topScope *ast.Scope // top-most scope; may be pkgScope
unresolved []*ast.Ident // unresolved identifiers
depth int // scope depth
// Label scopes
// (maintained by open/close LabelScope)
labelScope *ast.Scope // label scope for current function
targetStack [][]*ast.Ident // stack of unresolved labels
}
func (r *resolver) trace(format string, args ...any) {
fmt.Println(strings.Repeat(". ", r.depth) + r.sprintf(format, args...))
}
func (r *resolver) sprintf(format string, args ...any) string {
for i, arg := range args {
switch arg := arg.(type) {
case token.Pos:
args[i] = r.handle.Position(arg)
}
}
return fmt.Sprintf(format, args...)
}
func (r *resolver) openScope(pos token.Pos) {
r.depth++
if r.depth > maxScopeDepth {
panic(bailout{pos: pos, msg: "exceeded max scope depth during object resolution"})
}
if debugResolve {
r.trace("opening scope @%v", pos)
}
r.topScope = ast.NewScope(r.topScope)
}
func (r *resolver) closeScope() {
r.depth--
if debugResolve {
r.trace("closing scope")
}
r.topScope = r.topScope.Outer
}
func (r *resolver) openLabelScope() {
r.labelScope = ast.NewScope(r.labelScope)
r.targetStack = append(r.targetStack, nil)
}
func (r *resolver) closeLabelScope() {
// resolve labels
n := len(r.targetStack) - 1
scope := r.labelScope
for _, ident := range r.targetStack[n] {
ident.Obj = scope.Lookup(ident.Name)
if ident.Obj == nil && r.declErr != nil {
r.declErr(ident.Pos(), fmt.Sprintf("label %s undefined", ident.Name))
}
}
// pop label scope
r.targetStack = r.targetStack[0:n]
r.labelScope = r.labelScope.Outer
}
func (r *resolver) declare(decl, data any, scope *ast.Scope, kind ast.ObjKind, idents ...*ast.Ident) {
for _, ident := range idents {
if ident.Obj != nil {
panic(fmt.Sprintf("%v: identifier %s already declared or resolved", ident.Pos(), ident.Name))
}
obj := ast.NewObj(kind, ident.Name)
// remember the corresponding declaration for redeclaration
// errors and global variable resolution/typechecking phase
obj.Decl = decl
obj.Data = data
// Identifiers (for receiver type parameters) are written to the scope, but
// never set as the resolved object. See issue #50956.
if _, ok := decl.(*ast.Ident); !ok {
ident.Obj = obj
}
if ident.Name != "_" {
if debugResolve {
r.trace("declaring %s@%v", ident.Name, ident.Pos())
}
if alt := scope.Insert(obj); alt != nil && r.declErr != nil {
prevDecl := ""
if pos := alt.Pos(); pos.IsValid() {
prevDecl = r.sprintf("\n\tprevious declaration at %v", pos)
}
r.declErr(ident.Pos(), fmt.Sprintf("%s redeclared in this block%s", ident.Name, prevDecl))
}
}
}
}
func (r *resolver) shortVarDecl(decl *ast.AssignStmt) {
// Go spec: A short variable declaration may redeclare variables
// provided they were originally declared in the same block with
// the same type, and at least one of the non-blank variables is new.
n := 0 // number of new variables
for _, x := range decl.Lhs {
if ident, isIdent := x.(*ast.Ident); isIdent {
assert(ident.Obj == nil, "identifier already declared or resolved")
obj := ast.NewObj(ast.Var, ident.Name)
// remember corresponding assignment for other tools
obj.Decl = decl
ident.Obj = obj
if ident.Name != "_" {
if debugResolve {
r.trace("declaring %s@%v", ident.Name, ident.Pos())
}
if alt := r.topScope.Insert(obj); alt != nil {
ident.Obj = alt // redeclaration
} else {
n++ // new declaration
}
}
}
}
if n == 0 && r.declErr != nil {
r.declErr(decl.Lhs[0].Pos(), "no new variables on left side of :=")
}
}
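// Editor's note (illustrative): the spec rule enforced above, in source form:
//
//	a, b := 1, 2 // declares a and b
//	a, c := 3, 4 // ok: c is new; a is only reassigned
//	a, b := 5, 6 // error: no new variables on left side of :=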
// The unresolved object is a sentinel to mark identifiers that have been added
// to the list of unresolved identifiers. The sentinel is only used for verifying
// internal consistency.
var unresolved = new(ast.Object)
// If x is an identifier, resolve attempts to resolve x by looking up
// the object it denotes. If no object is found and collectUnresolved is
// set, x is marked as unresolved and collected in the list of unresolved
// identifiers.
func (r *resolver) resolve(ident *ast.Ident, collectUnresolved bool) {
if ident.Obj != nil {
panic(r.sprintf("%v: identifier %s already declared or resolved", ident.Pos(), ident.Name))
}
// '_' should never refer to existing declarations, because it has special
// handling in the spec.
if ident.Name == "_" {
return
}
for s := r.topScope; s != nil; s = s.Outer {
if obj := s.Lookup(ident.Name); obj != nil {
if debugResolve {
r.trace("resolved %v:%s to %v", ident.Pos(), ident.Name, obj)
}
assert(obj.Name != "", "obj with no name")
// Identifiers (for receiver type parameters) are written to the scope,
// but never set as the resolved object. See issue #50956.
if _, ok := obj.Decl.(*ast.Ident); !ok {
ident.Obj = obj
}
return
}
}
// all local scopes are known, so any unresolved identifier
// must be found either in the file scope, package scope
// (perhaps in another file), or universe scope --- collect
// them so that they can be resolved later
if collectUnresolved {
ident.Obj = unresolved
r.unresolved = append(r.unresolved, ident)
}
}
func (r *resolver) walkExprs(list []ast.Expr) {
for _, node := range list {
ast.Walk(r, node)
}
}
func (r *resolver) walkLHS(list []ast.Expr) {
for _, expr := range list {
expr := unparen(expr)
if _, ok := expr.(*ast.Ident); !ok && expr != nil {
ast.Walk(r, expr)
}
}
}
func (r *resolver) walkStmts(list []ast.Stmt) {
for _, stmt := range list {
ast.Walk(r, stmt)
}
}
func (r *resolver) Visit(node ast.Node) ast.Visitor {
if debugResolve && node != nil {
r.trace("node %T@%v", node, node.Pos())
}
switch n := node.(type) {
// Expressions.
case *ast.Ident:
r.resolve(n, true)
case *ast.FuncLit:
r.openScope(n.Pos())
defer r.closeScope()
r.walkFuncType(n.Type)
r.walkBody(n.Body)
case *ast.SelectorExpr:
ast.Walk(r, n.X)
// Note: don't try to resolve n.Sel, as we don't support qualified
// resolution.
case *ast.StructType:
r.openScope(n.Pos())
defer r.closeScope()
r.walkFieldList(n.Fields, ast.Var)
case *ast.FuncType:
r.openScope(n.Pos())
defer r.closeScope()
r.walkFuncType(n)
case *ast.CompositeLit:
if n.Type != nil {
ast.Walk(r, n.Type)
}
for _, e := range n.Elts {
if kv, _ := e.(*ast.KeyValueExpr); kv != nil {
// See issue #45160: try to resolve composite lit keys, but don't
// collect them as unresolved if resolution failed. This replicates
// existing behavior when resolving during parsing.
if ident, _ := kv.Key.(*ast.Ident); ident != nil {
r.resolve(ident, false)
} else {
ast.Walk(r, kv.Key)
}
ast.Walk(r, kv.Value)
} else {
ast.Walk(r, e)
}
}
case *ast.InterfaceType:
r.openScope(n.Pos())
defer r.closeScope()
r.walkFieldList(n.Methods, ast.Fun)
// Statements
case *ast.LabeledStmt:
r.declare(n, nil, r.labelScope, ast.Lbl, n.Label)
ast.Walk(r, n.Stmt)
case *ast.AssignStmt:
r.walkExprs(n.Rhs)
if n.Tok == token.DEFINE {
r.shortVarDecl(n)
} else {
r.walkExprs(n.Lhs)
}
case *ast.BranchStmt:
// add to list of unresolved targets
if n.Tok != token.FALLTHROUGH && n.Label != nil {
depth := len(r.targetStack) - 1
r.targetStack[depth] = append(r.targetStack[depth], n.Label)
}
case *ast.BlockStmt:
r.openScope(n.Pos())
defer r.closeScope()
r.walkStmts(n.List)
case *ast.IfStmt:
r.openScope(n.Pos())
defer r.closeScope()
if n.Init != nil {
ast.Walk(r, n.Init)
}
ast.Walk(r, n.Cond)
ast.Walk(r, n.Body)
if n.Else != nil {
ast.Walk(r, n.Else)
}
case *ast.CaseClause:
r.walkExprs(n.List)
r.openScope(n.Pos())
defer r.closeScope()
r.walkStmts(n.Body)
case *ast.SwitchStmt:
r.openScope(n.Pos())
defer r.closeScope()
if n.Init != nil {
ast.Walk(r, n.Init)
}
if n.Tag != nil {
// The scope below reproduces some unnecessary behavior of the parser,
// opening an extra scope in case this is a type switch. It's not needed
// for expression switches.
// TODO: remove this once we've matched the parser resolution exactly.
if n.Init != nil {
r.openScope(n.Tag.Pos())
defer r.closeScope()
}
ast.Walk(r, n.Tag)
}
if n.Body != nil {
r.walkStmts(n.Body.List)
}
case *ast.TypeSwitchStmt:
if n.Init != nil {
r.openScope(n.Pos())
defer r.closeScope()
ast.Walk(r, n.Init)
}
r.openScope(n.Assign.Pos())
defer r.closeScope()
ast.Walk(r, n.Assign)
// n.Body consists only of case clauses, so it does not get its own
// scope.
if n.Body != nil {
r.walkStmts(n.Body.List)
}
case *ast.CommClause:
r.openScope(n.Pos())
defer r.closeScope()
if n.Comm != nil {
ast.Walk(r, n.Comm)
}
r.walkStmts(n.Body)
case *ast.SelectStmt:
// as for switch statements, select statement bodies don't get their own
// scope.
if n.Body != nil {
r.walkStmts(n.Body.List)
}
case *ast.ForStmt:
r.openScope(n.Pos())
defer r.closeScope()
if n.Init != nil {
ast.Walk(r, n.Init)
}
if n.Cond != nil {
ast.Walk(r, n.Cond)
}
if n.Post != nil {
ast.Walk(r, n.Post)
}
ast.Walk(r, n.Body)
case *ast.RangeStmt:
r.openScope(n.Pos())
defer r.closeScope()
ast.Walk(r, n.X)
var lhs []ast.Expr
if n.Key != nil {
lhs = append(lhs, n.Key)
}
if n.Value != nil {
lhs = append(lhs, n.Value)
}
if len(lhs) > 0 {
if n.Tok == token.DEFINE {
// Note: we can't exactly match the behavior of object resolution
// during the parsing pass here, as it uses the position of the RANGE
// token for the RHS OpPos. That information is not contained within
// the AST.
as := &ast.AssignStmt{
Lhs: lhs,
Tok: token.DEFINE,
TokPos: n.TokPos,
Rhs: []ast.Expr{&ast.UnaryExpr{Op: token.RANGE, X: n.X}},
}
// TODO(rFindley): this walkLHS reproduces the parser resolution, but
// is it necessary? By comparison, for a normal AssignStmt we don't
// walk the LHS in case there is an invalid identifier list.
r.walkLHS(lhs)
r.shortVarDecl(as)
} else {
r.walkExprs(lhs)
}
}
ast.Walk(r, n.Body)
// Declarations
case *ast.GenDecl:
switch n.Tok {
case token.CONST, token.VAR:
for i, spec := range n.Specs {
spec := spec.(*ast.ValueSpec)
kind := ast.Con
if n.Tok == token.VAR {
kind = ast.Var
}
r.walkExprs(spec.Values)
if spec.Type != nil {
ast.Walk(r, spec.Type)
}
r.declare(spec, i, r.topScope, kind, spec.Names...)
}
case token.TYPE:
for _, spec := range n.Specs {
spec := spec.(*ast.TypeSpec)
// Go spec: The scope of a type identifier declared inside a function begins
// at the identifier in the TypeSpec and ends at the end of the innermost
// containing block.
r.declare(spec, nil, r.topScope, ast.Typ, spec.Name)
if spec.TypeParams != nil {
r.openScope(spec.Pos())
defer r.closeScope()
r.walkTParams(spec.TypeParams)
}
ast.Walk(r, spec.Type)
}
}
case *ast.FuncDecl:
// Open the function scope.
r.openScope(n.Pos())
defer r.closeScope()
r.walkRecv(n.Recv)
// Type parameters are walked normally: they can reference each other, and
// can be referenced by normal parameters.
if n.Type.TypeParams != nil {
r.walkTParams(n.Type.TypeParams)
// TODO(rFindley): need to address receiver type parameters.
}
// Resolve and declare parameters in a specific order to get duplicate
// declaration errors in the correct location.
r.resolveList(n.Type.Params)
r.resolveList(n.Type.Results)
r.declareList(n.Recv, ast.Var)
r.declareList(n.Type.Params, ast.Var)
r.declareList(n.Type.Results, ast.Var)
r.walkBody(n.Body)
if n.Recv == nil && n.Name.Name != "init" {
r.declare(n, nil, r.pkgScope, ast.Fun, n.Name)
}
default:
return r
}
return nil
}
func (r *resolver) walkFuncType(typ *ast.FuncType) {
// typ.TypeParams must be walked separately for FuncDecls.
r.resolveList(typ.Params)
r.resolveList(typ.Results)
r.declareList(typ.Params, ast.Var)
r.declareList(typ.Results, ast.Var)
}
func (r *resolver) resolveList(list *ast.FieldList) {
if list == nil {
return
}
for _, f := range list.List {
if f.Type != nil {
ast.Walk(r, f.Type)
}
}
}
func (r *resolver) declareList(list *ast.FieldList, kind ast.ObjKind) {
if list == nil {
return
}
for _, f := range list.List {
r.declare(f, nil, r.topScope, kind, f.Names...)
}
}
func (r *resolver) walkRecv(recv *ast.FieldList) {
// If our receiver has receiver type parameters, we must declare them before
// trying to resolve the rest of the receiver, and avoid re-resolving the
// type parameter identifiers.
if recv == nil || len(recv.List) == 0 {
return // nothing to do
}
typ := recv.List[0].Type
if ptr, ok := typ.(*ast.StarExpr); ok {
typ = ptr.X
}
var declareExprs []ast.Expr // exprs to declare
var resolveExprs []ast.Expr // exprs to resolve
switch typ := typ.(type) {
case *ast.IndexExpr:
declareExprs = []ast.Expr{typ.Index}
resolveExprs = append(resolveExprs, typ.X)
case *ast.IndexListExpr:
declareExprs = typ.Indices
resolveExprs = append(resolveExprs, typ.X)
default:
resolveExprs = append(resolveExprs, typ)
}
for _, expr := range declareExprs {
if id, _ := expr.(*ast.Ident); id != nil {
r.declare(expr, nil, r.topScope, ast.Typ, id)
} else {
// The receiver type parameter expression is invalid, but try to resolve
// it anyway for consistency.
resolveExprs = append(resolveExprs, expr)
}
}
for _, expr := range resolveExprs {
if expr != nil {
ast.Walk(r, expr)
}
}
// A receiver with more than one field is invalid, but resolve the extra fields anyway for consistency.
for _, f := range recv.List[1:] {
if f.Type != nil {
ast.Walk(r, f.Type)
}
}
}
func (r *resolver) walkFieldList(list *ast.FieldList, kind ast.ObjKind) {
if list == nil {
return
}
r.resolveList(list)
r.declareList(list, kind)
}
// walkTParams is like walkFieldList, but declares type parameters eagerly so
// that they may be resolved in the constraint expressions held in the field
// Type.
func (r *resolver) walkTParams(list *ast.FieldList) {
r.declareList(list, ast.Typ)
r.resolveList(list)
}
func (r *resolver) walkBody(body *ast.BlockStmt) {
if body == nil {
return
}
r.openLabelScope()
defer r.closeLabelScope()
r.walkStmts(body.List)
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package printer
import (
"go/ast"
"go/doc/comment"
"strings"
)
// formatDocComment reformats the doc comment list,
// returning the canonical formatting.
func formatDocComment(list []*ast.Comment) []*ast.Comment {
// Extract comment text (removing comment markers).
var kind, text string
var directives []*ast.Comment
if len(list) == 1 && strings.HasPrefix(list[0].Text, "/*") {
kind = "/*"
text = list[0].Text
if !strings.Contains(text, "\n") || allStars(text) {
// Single-line /* .. */ comment in doc comment position,
// or multiline old-style comment like
// /*
// * Comment
// * text here.
// */
// Should not happen, since it will not work well as a
// doc comment, but if it does, just ignore:
// reformatting it will only make the situation worse.
return list
}
text = text[2 : len(text)-2] // cut /* and */
} else if strings.HasPrefix(list[0].Text, "//") {
kind = "//"
var b strings.Builder
for _, c := range list {
after, found := strings.CutPrefix(c.Text, "//")
if !found {
return list
}
// Accumulate //go:build etc lines separately.
if isDirective(after) {
directives = append(directives, c)
continue
}
b.WriteString(strings.TrimPrefix(after, " "))
b.WriteString("\n")
}
text = b.String()
} else {
// Not sure what this is, so leave alone.
return list
}
if text == "" {
return list
}
// Parse comment and reformat as text.
var p comment.Parser
d := p.Parse(text)
var pr comment.Printer
text = string(pr.Comment(d))
// For /* */ comment, return one big comment with text inside.
slash := list[0].Slash
if kind == "/*" {
c := &ast.Comment{
Slash: slash,
Text: "/*\n" + text + "*/",
}
return []*ast.Comment{c}
}
// For // comment, return sequence of // lines.
var out []*ast.Comment
for text != "" {
var line string
line, text, _ = strings.Cut(text, "\n")
if line == "" {
line = "//"
} else if strings.HasPrefix(line, "\t") {
line = "//" + line
} else {
line = "// " + line
}
out = append(out, &ast.Comment{
Slash: slash,
Text: line,
})
}
if len(directives) > 0 {
out = append(out, &ast.Comment{
Slash: slash,
Text: "//",
})
for _, c := range directives {
out = append(out, &ast.Comment{
Slash: slash,
Text: c.Text,
})
}
}
return out
}
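// Editor's note (usage sketch, assuming the usual os import): the
// Parse/Comment round trip above is the exported go/doc/comment API and can
// be invoked directly:
//
//	var p comment.Parser
//	d := p.Parse("Deprecated: use New instead.\n")
//	var pr comment.Printer
//	os.Stdout.Write(pr.Comment(d)) // canonical doc comment text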
// isDirective reports whether c is a comment directive.
// See go.dev/issue/37974.
// This code is also in go/ast.
func isDirective(c string) bool {
// "//line " is a line directive.
// "//extern " is for gccgo.
// "//export " is for cgo.
// (The // has been removed.)
if strings.HasPrefix(c, "line ") || strings.HasPrefix(c, "extern ") || strings.HasPrefix(c, "export ") {
return true
}
// "//[a-z0-9]+:[a-z0-9]"
// (The // has been removed.)
colon := strings.Index(c, ":")
if colon <= 0 || colon+1 >= len(c) {
return false
}
for i := 0; i <= colon+1; i++ {
if i == colon {
continue
}
b := c[i]
if !('a' <= b && b <= 'z' || '0' <= b && b <= '9') {
return false
}
}
return true
}
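// Editor's note (illustrative): sample inputs, with the leading "//"
// already removed as isDirective expects:
//
//	isDirective("go:generate stringer -type Op") // true ("//[a-z0-9]+:[a-z0-9]")
//	isDirective("go:build linux")                // true
//	isDirective("line /tmp/x.go:10")             // true (line directive)
//	isDirective("nolint")                        // false (no colon)
//	isDirective("this is: prose")                // false (space before colon)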
// allStars reports whether text is the interior of an
// old-style /* */ comment with a star at the start of each line.
func allStars(text string) bool {
for i := 0; i < len(text); i++ {
if text[i] == '\n' {
j := i + 1
for j < len(text) && (text[j] == ' ' || text[j] == '\t') {
j++
}
if j < len(text) && text[j] != '*' {
return false
}
}
}
return true
}
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package printer
import (
"go/build/constraint"
"sort"
"text/tabwriter"
)
func (p *printer) fixGoBuildLines() {
if len(p.goBuild)+len(p.plusBuild) == 0 {
return
}
// Find latest possible placement of //go:build and // +build comments.
// That's just after the last blank line before we find a non-comment.
// (We'll add another blank line after our comment block.)
// When we start dropping // +build comments, we can skip over /* */ comments too.
// Note that we are processing tabwriter input, so every comment
// begins and ends with a tabwriter.Escape byte.
// And some newlines have turned into \f bytes.
insert := 0
for pos := 0; ; {
// Skip leading space at beginning of line.
blank := true
for pos < len(p.output) && (p.output[pos] == ' ' || p.output[pos] == '\t') {
pos++
}
// Skip over // comment if any.
if pos+3 < len(p.output) && p.output[pos] == tabwriter.Escape && p.output[pos+1] == '/' && p.output[pos+2] == '/' {
blank = false
for pos < len(p.output) && !isNL(p.output[pos]) {
pos++
}
}
// Skip over \n at end of line.
if pos >= len(p.output) || !isNL(p.output[pos]) {
break
}
pos++
if blank {
insert = pos
}
}
// If there is a //go:build comment before the place we identified,
// use that point instead. (Earlier in the file is always fine.)
if len(p.goBuild) > 0 && p.goBuild[0] < insert {
insert = p.goBuild[0]
} else if len(p.plusBuild) > 0 && p.plusBuild[0] < insert {
insert = p.plusBuild[0]
}
var x constraint.Expr
switch len(p.goBuild) {
case 0:
// Synthesize //go:build expression from // +build lines.
for _, pos := range p.plusBuild {
y, err := constraint.Parse(p.commentTextAt(pos))
if err != nil {
x = nil
break
}
if x == nil {
x = y
} else {
x = &constraint.AndExpr{X: x, Y: y}
}
}
case 1:
// Parse //go:build expression.
x, _ = constraint.Parse(p.commentTextAt(p.goBuild[0]))
}
var block []byte
if x == nil {
// Don't have a valid //go:build expression to treat as truth.
// Bring all the lines together but leave them alone.
// Note that these are already tabwriter-escaped.
for _, pos := range p.goBuild {
block = append(block, p.lineAt(pos)...)
}
for _, pos := range p.plusBuild {
block = append(block, p.lineAt(pos)...)
}
} else {
block = append(block, tabwriter.Escape)
block = append(block, "//go:build "...)
block = append(block, x.String()...)
block = append(block, tabwriter.Escape, '\n')
if len(p.plusBuild) > 0 {
lines, err := constraint.PlusBuildLines(x)
if err != nil {
lines = []string{"// +build error: " + err.Error()}
}
for _, line := range lines {
block = append(block, tabwriter.Escape)
block = append(block, line...)
block = append(block, tabwriter.Escape, '\n')
}
}
}
block = append(block, '\n')
// Build sorted list of lines to delete from remainder of output.
toDelete := append(p.goBuild, p.plusBuild...)
sort.Ints(toDelete)
// Collect output after insertion point, with lines deleted, into after.
var after []byte
start := insert
for _, end := range toDelete {
if end < start {
continue
}
after = appendLines(after, p.output[start:end])
start = end + len(p.lineAt(end))
}
after = appendLines(after, p.output[start:])
if n := len(after); n >= 2 && isNL(after[n-1]) && isNL(after[n-2]) {
after = after[:n-1]
}
p.output = p.output[:insert]
p.output = append(p.output, block...)
p.output = append(p.output, after...)
}
// appendLines is like append(x, y...)
// but it avoids creating doubled blank lines,
// which would not be gofmt-standard output.
// It assumes that only whole blocks of lines are being appended,
// not line fragments.
func appendLines(x, y []byte) []byte {
if len(y) > 0 && isNL(y[0]) && // y starts in blank line
(len(x) == 0 || len(x) >= 2 && isNL(x[len(x)-1]) && isNL(x[len(x)-2])) { // x is empty or ends in blank line
y = y[1:] // delete y's leading blank line
}
return append(x, y...)
}
func (p *printer) lineAt(start int) []byte {
pos := start
for pos < len(p.output) && !isNL(p.output[pos]) {
pos++
}
if pos < len(p.output) {
pos++
}
return p.output[start:pos]
}
func (p *printer) commentTextAt(start int) string {
if start < len(p.output) && p.output[start] == tabwriter.Escape {
start++
}
pos := start
for pos < len(p.output) && p.output[pos] != tabwriter.Escape && !isNL(p.output[pos]) {
pos++
}
return string(p.output[start:pos])
}
func isNL(b byte) bool {
return b == '\n' || b == '\f'
}
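// Editor's note (usage sketch, assuming a fmt import): the
// go/build/constraint calls used above behave like this:
//
//	x, err := constraint.Parse("//go:build linux && amd64")
//	if err == nil {
//		lines, _ := constraint.PlusBuildLines(x)
//		fmt.Println(lines) // [// +build linux,amd64]
//	}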
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements printing of AST nodes; specifically
// expressions, statements, declarations, and files. It uses
// the print functionality implemented in printer.go.
package printer
import (
"go/ast"
"go/token"
"math"
"strconv"
"strings"
"unicode"
"unicode/utf8"
)
// Formatting issues:
// - better comment formatting for /*-style comments at the end of a line (e.g. a declaration)
// when the comment spans multiple lines; if such a comment is just two lines, formatting is
// not idempotent
// - formatting of expression lists
// - should use blank instead of tab to separate one-line function bodies from
// the function header unless there is a group of consecutive one-liners
// ----------------------------------------------------------------------------
// Common AST nodes.
// Print as many newlines as necessary (but at least min newlines) to get to
// the current line. ws is printed before the first line break. If newSection
// is set, the first line break is printed as formfeed. Returns 0 if no line
// breaks were printed, returns 1 if there was exactly one newline printed,
// and returns a value > 1 if there was a formfeed or more than one newline
// printed.
//
// TODO(gri): linebreak may add too many lines if the next statement at "line"
// is preceded by comments because the computation of n assumes
// the current position before the comment and the target position
// after the comment. Thus, after interspersing such comments, the
// space taken up by them is not considered to reduce the number of
// linebreaks. At the moment there is no easy way to know about
// future (not yet interspersed) comments in this function.
func (p *printer) linebreak(line, min int, ws whiteSpace, newSection bool) (nbreaks int) {
n := nlimit(line - p.pos.Line)
if n < min {
n = min
}
if n > 0 {
p.print(ws)
if newSection {
p.print(formfeed)
n--
nbreaks = 2
}
nbreaks += n
for ; n > 0; n-- {
p.print(newline)
}
}
return
}
// setComment sets g as the next comment if g != nil and if node comments
// are enabled - this mode is used when printing source code fragments such
// as exports only. It assumes that there is no pending comment in p.comments
// and at most one pending comment in the p.comment cache.
func (p *printer) setComment(g *ast.CommentGroup) {
if g == nil || !p.useNodeComments {
return
}
if p.comments == nil {
// initialize p.comments lazily
p.comments = make([]*ast.CommentGroup, 1)
} else if p.cindex < len(p.comments) {
// for some reason there are pending comments; this
// should never happen - handle gracefully and flush
// all comments up to g, ignore anything after that
p.flush(p.posFor(g.List[0].Pos()), token.ILLEGAL)
p.comments = p.comments[0:1]
// in debug mode, report error
p.internalError("setComment found pending comments")
}
p.comments[0] = g
p.cindex = 0
// don't overwrite any pending comment in the p.comment cache
// (there may be a pending comment when a line comment is
// immediately followed by a lead comment with no other
// tokens between)
if p.commentOffset == infinity {
p.nextComment() // get comment ready for use
}
}
type exprListMode uint
const (
commaTerm exprListMode = 1 << iota // list is optionally terminated by a comma
noIndent // no extra indentation in multi-line lists
)
// If indent is set, a multi-line identifier list is indented after the
// first linebreak encountered.
func (p *printer) identList(list []*ast.Ident, indent bool) {
// convert into an expression list so we can re-use exprList formatting
xlist := make([]ast.Expr, len(list))
for i, x := range list {
xlist[i] = x
}
var mode exprListMode
if !indent {
mode = noIndent
}
p.exprList(token.NoPos, xlist, 1, mode, token.NoPos, false)
}
const filteredMsg = "contains filtered or unexported fields"
// Print a list of expressions. If the list spans multiple
// source lines, the original line breaks are respected between
// expressions.
//
// TODO(gri) Consider rewriting this to be independent of []ast.Expr
// so that we can use the algorithm for any kind of list
//
// (e.g., pass list via a channel over which to range).
func (p *printer) exprList(prev0 token.Pos, list []ast.Expr, depth int, mode exprListMode, next0 token.Pos, isIncomplete bool) {
if len(list) == 0 {
if isIncomplete {
prev := p.posFor(prev0)
next := p.posFor(next0)
if prev.IsValid() && prev.Line == next.Line {
p.print("/* " + filteredMsg + " */")
} else {
p.print(newline)
p.print(indent, "// "+filteredMsg, unindent, newline)
}
}
return
}
prev := p.posFor(prev0)
next := p.posFor(next0)
line := p.lineFor(list[0].Pos())
endLine := p.lineFor(list[len(list)-1].End())
if prev.IsValid() && prev.Line == line && line == endLine {
// all list entries on a single line
for i, x := range list {
if i > 0 {
// use position of expression following the comma as
// comma position for correct comment placement
p.setPos(x.Pos())
p.print(token.COMMA, blank)
}
p.expr0(x, depth)
}
if isIncomplete {
p.print(token.COMMA, blank, "/* "+filteredMsg+" */")
}
return
}
// list entries span multiple lines;
// use source code positions to guide line breaks
// Don't add extra indentation if noIndent is set;
// i.e., pretend that the first line is already indented.
ws := ignore
if mode&noIndent == 0 {
ws = indent
}
// The first linebreak is always a formfeed since this section must not
// depend on any previous formatting.
prevBreak := -1 // index of last expression that was followed by a linebreak
if prev.IsValid() && prev.Line < line && p.linebreak(line, 0, ws, true) > 0 {
ws = ignore
prevBreak = 0
}
// initialize expression/key size: a zero value indicates expr/key doesn't fit on a single line
size := 0
// We use the ratio between the geometric mean of the previous key sizes and
// the current size to determine if there should be a break in the alignment.
// To compute the geometric mean we accumulate the ln(size) values (lnsum)
// and the number of sizes included (count).
lnsum := 0.0
count := 0
// print all list elements
prevLine := prev.Line
for i, x := range list {
line = p.lineFor(x.Pos())
// Determine if the next linebreak, if any, needs to use formfeed:
// in general, use the entire node size to make the decision; for
// key:value expressions, use the key size.
// TODO(gri) for a better result, should probably incorporate both
// the key and the node size into the decision process
useFF := true
// Determine element size: All bets are off if we don't have
// position information for the previous and next token (likely
// generated code - simply ignore the size in this case by setting
// it to 0).
prevSize := size
const infinity = 1e6 // larger than any source line
size = p.nodeSize(x, infinity)
pair, isPair := x.(*ast.KeyValueExpr)
if size <= infinity && prev.IsValid() && next.IsValid() {
// x fits on a single line
if isPair {
size = p.nodeSize(pair.Key, infinity) // size <= infinity
}
} else {
// size too large or we don't have good layout information
size = 0
}
// If the previous line and the current line had single-
// line-expressions and the key sizes are small or the
// ratio between the current key and the geometric mean
// of the previous key sizes does not exceed a threshold,
// align columns and do not use formfeed.
if prevSize > 0 && size > 0 {
const smallSize = 40
if count == 0 || prevSize <= smallSize && size <= smallSize {
useFF = false
} else {
const r = 2.5 // threshold
geomean := math.Exp(lnsum / float64(count)) // count > 0
ratio := float64(size) / geomean
useFF = r*ratio <= 1 || r <= ratio
}
}
needsLinebreak := 0 < prevLine && prevLine < line
if i > 0 {
// Use position of expression following the comma as
// comma position for correct comment placement, but
// only if the expression is on the same line.
if !needsLinebreak {
p.setPos(x.Pos())
}
p.print(token.COMMA)
needsBlank := true
if needsLinebreak {
// Lines are broken using newlines so comments remain aligned
// unless useFF is set or there are multiple expressions on
// the same line in which case formfeed is used.
nbreaks := p.linebreak(line, 0, ws, useFF || prevBreak+1 < i)
if nbreaks > 0 {
ws = ignore
prevBreak = i
needsBlank = false // we got a line break instead
}
// If there was a new section or more than one new line
// (which means that the tabwriter will implicitly break
// the section), reset the geomean variables since we are
// starting a new group of elements with the next element.
if nbreaks > 1 {
lnsum = 0
count = 0
}
}
if needsBlank {
p.print(blank)
}
}
if len(list) > 1 && isPair && size > 0 && needsLinebreak {
// We have a key:value expression that fits onto one line
// and it's not on the same line as the prior expression:
// Use a column for the key such that consecutive entries
// can align if possible.
// (needsLinebreak is set if we started a new line before)
p.expr(pair.Key)
p.setPos(pair.Colon)
p.print(token.COLON, vtab)
p.expr(pair.Value)
} else {
p.expr0(x, depth)
}
if size > 0 {
lnsum += math.Log(float64(size))
count++
}
prevLine = line
}
if mode&commaTerm != 0 && next.IsValid() && p.pos.Line < next.Line {
// Print a terminating comma if the next token is on a new line.
p.print(token.COMMA)
if isIncomplete {
p.print(newline)
p.print("// " + filteredMsg)
}
if ws == ignore && mode&noIndent == 0 {
// unindent if we indented
p.print(unindent)
}
p.print(formfeed) // terminating comma needs a line break to look good
return
}
if isIncomplete {
p.print(token.COMMA, newline)
p.print("// "+filteredMsg, newline)
}
if ws == ignore && mode&noIndent == 0 {
// unindent if we indented
p.print(unindent)
}
}
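// Editor's note (worked numbers for the alignment heuristic above): with
// smallSize = 40 and r = 2.5, consecutive keys of size <= 40 never break
// alignment. For larger keys the geometric mean decides: after keys of
// size 45 and 50, geomean = exp((ln 45 + ln 50)/2) ≈ 47.4. A next key of
// size 48 has ratio ≈ 1.01, so columns stay aligned; a key of size 150 has
// ratio ≈ 3.2 >= 2.5, so useFF is set and the tabwriter starts a new
// alignment section.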
type paramMode int
const (
funcParam paramMode = iota
funcTParam
typeTParam
)
func (p *printer) parameters(fields *ast.FieldList, mode paramMode) {
openTok, closeTok := token.LPAREN, token.RPAREN
if mode != funcParam {
openTok, closeTok = token.LBRACK, token.RBRACK
}
p.setPos(fields.Opening)
p.print(openTok)
if len(fields.List) > 0 {
prevLine := p.lineFor(fields.Opening)
ws := indent
for i, par := range fields.List {
// determine par begin and end line (may be different
// if there are multiple parameter names for this par
// or the type is on a separate line)
parLineBeg := p.lineFor(par.Pos())
parLineEnd := p.lineFor(par.End())
// separating "," if needed
needsLinebreak := 0 < prevLine && prevLine < parLineBeg
if i > 0 {
// use position of parameter following the comma as
// comma position for correct comma placement, but
// only if the next parameter is on the same line
if !needsLinebreak {
p.setPos(par.Pos())
}
p.print(token.COMMA)
}
// separator if needed (linebreak or blank)
if needsLinebreak && p.linebreak(parLineBeg, 0, ws, true) > 0 {
// break line if the opening "(" or previous parameter ended on a different line
ws = ignore
} else if i > 0 {
p.print(blank)
}
// parameter names
if len(par.Names) > 0 {
// Very subtle: If we indented before (ws == ignore), identList
// won't indent again. If we didn't (ws == indent), identList will
// indent if the identList spans multiple lines, and it will outdent
// again at the end (and still ws == indent). Thus, a subsequent indent
// by a linebreak call after a type, or in the next multi-line identList
// will do the right thing.
p.identList(par.Names, ws == indent)
p.print(blank)
}
// parameter type
p.expr(stripParensAlways(par.Type))
prevLine = parLineEnd
}
// if the closing ")" is on a separate line from the last parameter,
// print an additional "," and line break
if closing := p.lineFor(fields.Closing); 0 < prevLine && prevLine < closing {
p.print(token.COMMA)
p.linebreak(closing, 0, ignore, true)
} else if mode == typeTParam && fields.NumFields() == 1 && combinesWithName(fields.List[0].Type) {
// A type parameter list [P T] where the name P and the type expression T syntactically
// combine to another valid (value) expression requires a trailing comma, as in [P *T,]
// (or an enclosing interface as in [P interface{*T}]), so that the type parameter list
// is not parsed as an array length [P*T].
p.print(token.COMMA)
}
// unindent if we indented
if ws == ignore {
p.print(unindent)
}
}
p.setPos(fields.Closing)
p.print(closeTok)
}
// combinesWithName reports whether a name followed by the expression x
// syntactically combines to another valid (value) expression. For instance
// using *T for x, "name *T" syntactically appears as the expression name*T.
// On the other hand, using P|Q or *P|~Q for x, "name P|Q" or "name *P|~Q"
// cannot be combined into a valid (value) expression.
func combinesWithName(x ast.Expr) bool {
switch x := x.(type) {
case *ast.StarExpr:
// name *x.X combines to name*x.X if x.X is not a type element
return !isTypeElem(x.X)
case *ast.BinaryExpr:
return combinesWithName(x.X) && !isTypeElem(x.Y)
case *ast.ParenExpr:
// name(x) combines but we are making sure at
// the call site that x is never parenthesized.
panic("unexpected parenthesized expression")
}
return false
}
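// Editor's note (illustrative): the trailing comma printed above keeps a
// one-element type parameter list from re-parsing as an array type:
//
//	type S[P *T,] struct{}             // type parameter list
//	type A [P * T]struct{}             // without the comma: array length P*T
//	type U[P interface{ *T }] struct{} // an interface also disambiguates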
// isTypeElem reports whether x is a (possibly parenthesized) type element expression.
// The result is false if x could be a type element OR an ordinary (value) expression.
func isTypeElem(x ast.Expr) bool {
switch x := x.(type) {
case *ast.ArrayType, *ast.StructType, *ast.FuncType, *ast.InterfaceType, *ast.MapType, *ast.ChanType:
return true
case *ast.UnaryExpr:
return x.Op == token.TILDE
case *ast.BinaryExpr:
return isTypeElem(x.X) || isTypeElem(x.Y)
case *ast.ParenExpr:
return isTypeElem(x.X)
}
return false
}
func (p *printer) signature(sig *ast.FuncType) {
if sig.TypeParams != nil {
p.parameters(sig.TypeParams, funcTParam)
}
if sig.Params != nil {
p.parameters(sig.Params, funcParam)
} else {
p.print(token.LPAREN, token.RPAREN)
}
res := sig.Results
n := res.NumFields()
if n > 0 {
// res != nil
p.print(blank)
if n == 1 && res.List[0].Names == nil {
// single anonymous res; no ()'s
p.expr(stripParensAlways(res.List[0].Type))
return
}
p.parameters(res, funcParam)
}
}
func identListSize(list []*ast.Ident, maxSize int) (size int) {
for i, x := range list {
if i > 0 {
size += len(", ")
}
size += utf8.RuneCountInString(x.Name)
if size >= maxSize {
break
}
}
return
}
func (p *printer) isOneLineFieldList(list []*ast.Field) bool {
if len(list) != 1 {
return false // allow only one field
}
f := list[0]
if f.Tag != nil || f.Comment != nil {
return false // don't allow tags or comments
}
// only name(s) and type
const maxSize = 30 // adjust as appropriate, this is an approximate value
namesSize := identListSize(f.Names, maxSize)
if namesSize > 0 {
namesSize = 1 // blank between names and types
}
typeSize := p.nodeSize(f.Type, maxSize)
return namesSize+typeSize <= maxSize
}
func (p *printer) setLineComment(text string) {
p.setComment(&ast.CommentGroup{List: []*ast.Comment{{Slash: token.NoPos, Text: text}}})
}
func (p *printer) fieldList(fields *ast.FieldList, isStruct, isIncomplete bool) {
lbrace := fields.Opening
list := fields.List
rbrace := fields.Closing
hasComments := isIncomplete || p.commentBefore(p.posFor(rbrace))
srcIsOneLine := lbrace.IsValid() && rbrace.IsValid() && p.lineFor(lbrace) == p.lineFor(rbrace)
if !hasComments && srcIsOneLine {
// possibly a one-line struct/interface
if len(list) == 0 {
// no blank between keyword and {} in this case
p.setPos(lbrace)
p.print(token.LBRACE)
p.setPos(rbrace)
p.print(token.RBRACE)
return
} else if p.isOneLineFieldList(list) {
// small enough - print on one line
// (don't use identList and ignore source line breaks)
p.setPos(lbrace)
p.print(token.LBRACE, blank)
f := list[0]
if isStruct {
for i, x := range f.Names {
if i > 0 {
// no comments so no need for comma position
p.print(token.COMMA, blank)
}
p.expr(x)
}
if len(f.Names) > 0 {
p.print(blank)
}
p.expr(f.Type)
} else { // interface
if len(f.Names) > 0 {
name := f.Names[0] // method name
p.expr(name)
p.signature(f.Type.(*ast.FuncType)) // don't print "func"
} else {
// embedded interface
p.expr(f.Type)
}
}
p.print(blank)
p.setPos(rbrace)
p.print(token.RBRACE)
return
}
}
// hasComments || !srcIsOneLine
p.print(blank)
p.setPos(lbrace)
p.print(token.LBRACE, indent)
if hasComments || len(list) > 0 {
p.print(formfeed)
}
if isStruct {
sep := vtab
if len(list) == 1 {
sep = blank
}
var line int
for i, f := range list {
if i > 0 {
p.linebreak(p.lineFor(f.Pos()), 1, ignore, p.linesFrom(line) > 0)
}
extraTabs := 0
p.setComment(f.Doc)
p.recordLine(&line)
if len(f.Names) > 0 {
// named fields
p.identList(f.Names, false)
p.print(sep)
p.expr(f.Type)
extraTabs = 1
} else {
// anonymous field
p.expr(f.Type)
extraTabs = 2
}
if f.Tag != nil {
if len(f.Names) > 0 && sep == vtab {
p.print(sep)
}
p.print(sep)
p.expr(f.Tag)
extraTabs = 0
}
if f.Comment != nil {
for ; extraTabs > 0; extraTabs-- {
p.print(sep)
}
p.setComment(f.Comment)
}
}
if isIncomplete {
if len(list) > 0 {
p.print(formfeed)
}
p.flush(p.posFor(rbrace), token.RBRACE) // make sure we don't lose the last line comment
p.setLineComment("// " + filteredMsg)
}
} else { // interface
var line int
var prev *ast.Ident // previous "type" identifier
for i, f := range list {
var name *ast.Ident // first name, or nil
if len(f.Names) > 0 {
name = f.Names[0]
}
if i > 0 {
// don't do a line break (min == 0) if we are printing a list of types
// TODO(gri) this doesn't work quite right if the list of types is
// spread across multiple lines
min := 1
if prev != nil && name == prev {
min = 0
}
p.linebreak(p.lineFor(f.Pos()), min, ignore, p.linesFrom(line) > 0)
}
p.setComment(f.Doc)
p.recordLine(&line)
if name != nil {
// method
p.expr(name)
p.signature(f.Type.(*ast.FuncType)) // don't print "func"
prev = nil
} else {
// embedded interface
p.expr(f.Type)
prev = nil
}
p.setComment(f.Comment)
}
if isIncomplete {
if len(list) > 0 {
p.print(formfeed)
}
p.flush(p.posFor(rbrace), token.RBRACE) // make sure we don't lose the last line comment
p.setLineComment("// contains filtered or unexported methods")
}
}
p.print(unindent, formfeed)
p.setPos(rbrace)
p.print(token.RBRACE)
}
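// Editor's note (illustrative): output shapes this method produces. A small
// single-field struct written on one line keeps its one-line form; otherwise
// fields are placed on separate lines with vtab-aligned columns:
//
//	struct{}
//	struct{ x, y int }
//	struct {
//		name string // types and tags aligned by the tabwriter
//		n    int
//	}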
// ----------------------------------------------------------------------------
// Expressions
func walkBinary(e *ast.BinaryExpr) (has4, has5 bool, maxProblem int) {
switch e.Op.Precedence() {
case 4:
has4 = true
case 5:
has5 = true
}
switch l := e.X.(type) {
case *ast.BinaryExpr:
if l.Op.Precedence() < e.Op.Precedence() {
// parens will be inserted.
// pretend this is an *ast.ParenExpr and do nothing.
break
}
h4, h5, mp := walkBinary(l)
has4 = has4 || h4
has5 = has5 || h5
if maxProblem < mp {
maxProblem = mp
}
}
switch r := e.Y.(type) {
case *ast.BinaryExpr:
if r.Op.Precedence() <= e.Op.Precedence() {
// parens will be inserted.
// pretend this is an *ast.ParenExpr and do nothing.
break
}
h4, h5, mp := walkBinary(r)
has4 = has4 || h4
has5 = has5 || h5
if maxProblem < mp {
maxProblem = mp
}
case *ast.StarExpr:
if e.Op == token.QUO { // `*/`
maxProblem = 5
}
case *ast.UnaryExpr:
switch e.Op.String() + r.Op.String() {
case "/*", "&&", "&^":
maxProblem = 5
case "++", "--":
if maxProblem < 4 {
maxProblem = 4
}
}
}
return
}
func cutoff(e *ast.BinaryExpr, depth int) int {
has4, has5, maxProblem := walkBinary(e)
if maxProblem > 0 {
return maxProblem + 1
}
if has4 && has5 {
if depth == 1 {
return 5
}
return 4
}
if depth == 1 {
return 6
}
return 4
}
func diffPrec(expr ast.Expr, prec int) int {
x, ok := expr.(*ast.BinaryExpr)
if !ok || prec != x.Op.Precedence() {
return 1
}
return 0
}
func reduceDepth(depth int) int {
depth--
if depth < 1 {
depth = 1
}
return depth
}
// Format the binary expression: decide the cutoff and then format.
// Let's call depth == 1 Normal mode, and depth > 1 Compact mode.
// (Algorithm suggestion by Russ Cox.)
//
// The precedences are:
//
// 5   *  /  %  <<  >>  &  &^
// 4   +  -  |  ^
// 3   ==  !=  <  <=  >  >=
// 2   &&
// 1   ||
//
// The only decision is whether there will be spaces around levels 4 and 5.
// There are never spaces at level 6 (unary), and always spaces at levels 3 and below.
//
// To choose the cutoff, look at the whole expression but excluding primary
// expressions (function calls, parenthesized exprs), and apply these rules:
//
// 1. If there is a binary operator with a right side unary operand
// that would clash without a space, the cutoff must be (in order):
//
// /*   6
// &&   6
// &^   6
// ++   5
// --   5
//
// (Comparison operators always have spaces around them.)
//
// 2. If there is a mix of level 5 and level 4 operators, then the cutoff
// is 5 (use spaces to distinguish precedence) in Normal mode
// and 4 (never use spaces) in Compact mode.
//
// 3. If there are no level 4 operators or no level 5 operators, then the
// cutoff is 6 (always use spaces) in Normal mode
// and 4 (never use spaces) in Compact mode.
func (p *printer) binaryExpr(x *ast.BinaryExpr, prec1, cutoff, depth int) {
prec := x.Op.Precedence()
if prec < prec1 {
// parenthesis needed
// Note: The parser inserts an ast.ParenExpr node; thus this case
// can only occur if the AST is created in a different way.
p.print(token.LPAREN)
p.expr0(x, reduceDepth(depth)) // parentheses undo one level of depth
p.print(token.RPAREN)
return
}
printBlank := prec < cutoff
ws := indent
p.expr1(x.X, prec, depth+diffPrec(x.X, prec))
if printBlank {
p.print(blank)
}
xline := p.pos.Line // before the operator (it may be on the next line!)
yline := p.lineFor(x.Y.Pos())
p.setPos(x.OpPos)
p.print(x.Op)
if xline != yline && xline > 0 && yline > 0 {
// at least one line break, but respect an extra empty line
// in the source
if p.linebreak(yline, 1, ws, true) > 0 {
ws = ignore
printBlank = false // no blank after line break
}
}
if printBlank {
p.print(blank)
}
p.expr1(x.Y, prec+1, depth+1)
if ws == ignore {
p.print(unindent)
}
}
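// Editor's note (illustrative): spacing produced by the cutoff rules above:
//
//	x := a + b   // only level-4 operators, Normal mode: cutoff 6, spaces
//	y := a + b*c // mixed levels 4 and 5: cutoff 5, spaces at level 4 only
//	f(a+b, c)    // Compact mode (depth > 1): cutoff 4, no spaces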
func isBinary(expr ast.Expr) bool {
_, ok := expr.(*ast.BinaryExpr)
return ok
}
func (p *printer) expr1(expr ast.Expr, prec1, depth int) {
p.setPos(expr.Pos())
switch x := expr.(type) {
case *ast.BadExpr:
p.print("BadExpr")
case *ast.Ident:
p.print(x)
case *ast.BinaryExpr:
if depth < 1 {
p.internalError("depth < 1:", depth)
depth = 1
}
p.binaryExpr(x, prec1, cutoff(x, depth), depth)
case *ast.KeyValueExpr:
p.expr(x.Key)
p.setPos(x.Colon)
p.print(token.COLON, blank)
p.expr(x.Value)
case *ast.StarExpr:
const prec = token.UnaryPrec
if prec < prec1 {
// parenthesis needed
p.print(token.LPAREN)
p.print(token.MUL)
p.expr(x.X)
p.print(token.RPAREN)
} else {
// no parenthesis needed
p.print(token.MUL)
p.expr(x.X)
}
case *ast.UnaryExpr:
const prec = token.UnaryPrec
if prec < prec1 {
// parenthesis needed
p.print(token.LPAREN)
p.expr(x)
p.print(token.RPAREN)
} else {
// no parenthesis needed
p.print(x.Op)
if x.Op == token.RANGE {
// TODO(gri) Remove this code if it cannot be reached.
p.print(blank)
}
p.expr1(x.X, prec, depth)
}
case *ast.BasicLit:
if p.Config.Mode&normalizeNumbers != 0 {
x = normalizedNumber(x)
}
p.print(x)
case *ast.FuncLit:
p.setPos(x.Type.Pos())
p.print(token.FUNC)
// See the comment in funcDecl about how the header size is computed.
startCol := p.out.Column - len("func")
p.signature(x.Type)
p.funcBody(p.distanceFrom(x.Type.Pos(), startCol), blank, x.Body)
case *ast.ParenExpr:
if _, hasParens := x.X.(*ast.ParenExpr); hasParens {
// don't print parentheses around an already parenthesized expression
// TODO(gri) consider making this more general and incorporate precedence levels
p.expr0(x.X, depth)
} else {
p.print(token.LPAREN)
p.expr0(x.X, reduceDepth(depth)) // parentheses undo one level of depth
p.setPos(x.Rparen)
p.print(token.RPAREN)
}
case *ast.SelectorExpr:
p.selectorExpr(x, depth, false)
case *ast.TypeAssertExpr:
p.expr1(x.X, token.HighestPrec, depth)
p.print(token.PERIOD)
p.setPos(x.Lparen)
p.print(token.LPAREN)
if x.Type != nil {
p.expr(x.Type)
} else {
p.print(token.TYPE)
}
p.setPos(x.Rparen)
p.print(token.RPAREN)
case *ast.IndexExpr:
// TODO(gri): should treat [] like parentheses and undo one level of depth
p.expr1(x.X, token.HighestPrec, 1)
p.setPos(x.Lbrack)
p.print(token.LBRACK)
p.expr0(x.Index, depth+1)
p.setPos(x.Rbrack)
p.print(token.RBRACK)
case *ast.IndexListExpr:
// TODO(gri): as for IndexExpr, should treat [] like parentheses and undo
// one level of depth
p.expr1(x.X, token.HighestPrec, 1)
p.setPos(x.Lbrack)
p.print(token.LBRACK)
p.exprList(x.Lbrack, x.Indices, depth+1, commaTerm, x.Rbrack, false)
p.setPos(x.Rbrack)
p.print(token.RBRACK)
case *ast.SliceExpr:
// TODO(gri): should treat [] like parentheses and undo one level of depth
p.expr1(x.X, token.HighestPrec, 1)
p.setPos(x.Lbrack)
p.print(token.LBRACK)
indices := []ast.Expr{x.Low, x.High}
if x.Max != nil {
indices = append(indices, x.Max)
}
// determine if we need extra blanks around ':'
var needsBlanks bool
if depth <= 1 {
var indexCount int
var hasBinaries bool
for _, x := range indices {
if x != nil {
indexCount++
if isBinary(x) {
hasBinaries = true
}
}
}
if indexCount > 1 && hasBinaries {
needsBlanks = true
}
}
for i, x := range indices {
if i > 0 {
if indices[i-1] != nil && needsBlanks {
p.print(blank)
}
p.print(token.COLON)
if x != nil && needsBlanks {
p.print(blank)
}
}
if x != nil {
p.expr0(x, depth+1)
}
}
p.setPos(x.Rbrack)
p.print(token.RBRACK)
case *ast.CallExpr:
if len(x.Args) > 1 {
depth++
}
var wasIndented bool
if _, ok := x.Fun.(*ast.FuncType); ok {
// conversions to literal function types require parentheses around the type
p.print(token.LPAREN)
wasIndented = p.possibleSelectorExpr(x.Fun, token.HighestPrec, depth)
p.print(token.RPAREN)
} else {
wasIndented = p.possibleSelectorExpr(x.Fun, token.HighestPrec, depth)
}
p.setPos(x.Lparen)
p.print(token.LPAREN)
if x.Ellipsis.IsValid() {
p.exprList(x.Lparen, x.Args, depth, 0, x.Ellipsis, false)
p.setPos(x.Ellipsis)
p.print(token.ELLIPSIS)
if x.Rparen.IsValid() && p.lineFor(x.Ellipsis) < p.lineFor(x.Rparen) {
p.print(token.COMMA, formfeed)
}
} else {
p.exprList(x.Lparen, x.Args, depth, commaTerm, x.Rparen, false)
}
p.setPos(x.Rparen)
p.print(token.RPAREN)
if wasIndented {
p.print(unindent)
}
case *ast.CompositeLit:
// composite literal elements that are composite literals themselves may have the type omitted
if x.Type != nil {
p.expr1(x.Type, token.HighestPrec, depth)
}
p.level++
p.setPos(x.Lbrace)
p.print(token.LBRACE)
p.exprList(x.Lbrace, x.Elts, 1, commaTerm, x.Rbrace, x.Incomplete)
// do not insert extra line break following a /*-style comment
// before the closing '}' as it might break the code if there
// is no trailing ','
mode := noExtraLinebreak
// do not insert extra blank following a /*-style comment
// before the closing '}' unless the literal is empty
if len(x.Elts) > 0 {
mode |= noExtraBlank
}
// need the initial indent to print lone comments with
// the proper level of indentation
p.print(indent, unindent, mode)
p.setPos(x.Rbrace)
p.print(token.RBRACE, mode)
p.level--
case *ast.Ellipsis:
p.print(token.ELLIPSIS)
if x.Elt != nil {
p.expr(x.Elt)
}
case *ast.ArrayType:
p.print(token.LBRACK)
if x.Len != nil {
p.expr(x.Len)
}
p.print(token.RBRACK)
p.expr(x.Elt)
case *ast.StructType:
p.print(token.STRUCT)
p.fieldList(x.Fields, true, x.Incomplete)
case *ast.FuncType:
p.print(token.FUNC)
p.signature(x)
case *ast.InterfaceType:
p.print(token.INTERFACE)
p.fieldList(x.Methods, false, x.Incomplete)
case *ast.MapType:
p.print(token.MAP, token.LBRACK)
p.expr(x.Key)
p.print(token.RBRACK)
p.expr(x.Value)
case *ast.ChanType:
switch x.Dir {
case ast.SEND | ast.RECV:
p.print(token.CHAN)
case ast.RECV:
p.print(token.ARROW, token.CHAN) // x.Arrow and x.Pos() are the same
case ast.SEND:
p.print(token.CHAN)
p.setPos(x.Arrow)
p.print(token.ARROW)
}
p.print(blank)
p.expr(x.Value)
default:
panic("unreachable")
}
}
// normalizedNumber rewrites base prefixes and exponents
// of numbers to use lower-case letters (0X123 to 0x123 and 1.2E3 to 1.2e3),
// and removes leading 0's from integer imaginary literals (0765i to 765i).
// It leaves hexadecimal digits alone.
//
// normalizedNumber doesn't modify the ast.BasicLit value lit points to.
// If lit is not a number or a number in canonical format already,
// lit is returned as is. Otherwise a new ast.BasicLit is created.
func normalizedNumber(lit *ast.BasicLit) *ast.BasicLit {
if lit.Kind != token.INT && lit.Kind != token.FLOAT && lit.Kind != token.IMAG {
return lit // not a number - nothing to do
}
if len(lit.Value) < 2 {
return lit // only one digit (common case) - nothing to do
}
// len(lit.Value) >= 2
// We ignore lit.Kind because for lit.Kind == token.IMAG the literal may be an integer
// or floating-point value, decimal or not. Instead, just consider the literal pattern.
x := lit.Value
switch x[:2] {
default:
// 0-prefix octal, decimal int, or float (possibly with 'i' suffix)
if i := strings.LastIndexByte(x, 'E'); i >= 0 {
x = x[:i] + "e" + x[i+1:]
break
}
// remove leading 0's from integer (but not floating-point) imaginary literals
if x[len(x)-1] == 'i' && !strings.ContainsAny(x, ".e") {
x = strings.TrimLeft(x, "0_")
if x == "i" {
x = "0i"
}
}
case "0X":
x = "0x" + x[2:]
// possibly a hexadecimal float
if i := strings.LastIndexByte(x, 'P'); i >= 0 {
x = x[:i] + "p" + x[i+1:]
}
case "0x":
// possibly a hexadecimal float
i := strings.LastIndexByte(x, 'P')
if i == -1 {
return lit // nothing to do
}
x = x[:i] + "p" + x[i+1:]
case "0O":
x = "0o" + x[2:]
case "0o":
return lit // nothing to do
case "0B":
x = "0b" + x[2:]
case "0b":
return lit // nothing to do
}
return &ast.BasicLit{ValuePos: lit.ValuePos, Kind: lit.Kind, Value: x}
}
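// For example (illustrative, not part of the original source):
//
//	normalizedNumber(&ast.BasicLit{Kind: token.FLOAT, Value: "1.2E3"}).Value // "1.2e3"
//	normalizedNumber(&ast.BasicLit{Kind: token.INT, Value: "0XABC"}).Value   // "0xABC" (hex digits are left alone)
//	normalizedNumber(&ast.BasicLit{Kind: token.IMAG, Value: "0765i"}).Value  // "765i"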
func (p *printer) possibleSelectorExpr(expr ast.Expr, prec1, depth int) bool {
if x, ok := expr.(*ast.SelectorExpr); ok {
return p.selectorExpr(x, depth, true)
}
p.expr1(expr, prec1, depth)
return false
}
// selectorExpr handles an *ast.SelectorExpr node and reports whether x spans
// multiple lines.
func (p *printer) selectorExpr(x *ast.SelectorExpr, depth int, isMethod bool) bool {
p.expr1(x.X, token.HighestPrec, depth)
p.print(token.PERIOD)
if line := p.lineFor(x.Sel.Pos()); p.pos.IsValid() && p.pos.Line < line {
p.print(indent, newline)
p.setPos(x.Sel.Pos())
p.print(x.Sel)
if !isMethod {
p.print(unindent)
}
return true
}
p.setPos(x.Sel.Pos())
p.print(x.Sel)
return false
}
func (p *printer) expr0(x ast.Expr, depth int) {
p.expr1(x, token.LowestPrec, depth)
}
func (p *printer) expr(x ast.Expr) {
const depth = 1
p.expr1(x, token.LowestPrec, depth)
}
// ----------------------------------------------------------------------------
// Statements
// Print the statement list indented, but without a newline after the last statement.
// Extra line breaks between statements in the source are respected but at most one
// empty line is printed between statements.
func (p *printer) stmtList(list []ast.Stmt, nindent int, nextIsRBrace bool) {
if nindent > 0 {
p.print(indent)
}
var line int
i := 0
for _, s := range list {
// ignore empty statements (was issue 3466)
if _, isEmpty := s.(*ast.EmptyStmt); !isEmpty {
// nindent == 0 only for lists of switch/select case clauses;
// in those cases each clause is a new section
if len(p.output) > 0 {
// only print line break if we are not at the beginning of the output
// (i.e., we are not printing only a partial program)
p.linebreak(p.lineFor(s.Pos()), 1, ignore, i == 0 || nindent == 0 || p.linesFrom(line) > 0)
}
p.recordLine(&line)
p.stmt(s, nextIsRBrace && i == len(list)-1)
// labeled statements put labels on a separate line, but here
// we only care about the start line of the actual statement
// without label - correct line for each label
for t := s; ; {
lt, _ := t.(*ast.LabeledStmt)
if lt == nil {
break
}
line++
t = lt.Stmt
}
i++
}
}
if nindent > 0 {
p.print(unindent)
}
}
// block prints an *ast.BlockStmt; it always spans at least two lines.
func (p *printer) block(b *ast.BlockStmt, nindent int) {
p.setPos(b.Lbrace)
p.print(token.LBRACE)
p.stmtList(b.List, nindent, true)
p.linebreak(p.lineFor(b.Rbrace), 1, ignore, true)
p.setPos(b.Rbrace)
p.print(token.RBRACE)
}
func isTypeName(x ast.Expr) bool {
switch t := x.(type) {
case *ast.Ident:
return true
case *ast.SelectorExpr:
return isTypeName(t.X)
}
return false
}
func stripParens(x ast.Expr) ast.Expr {
if px, strip := x.(*ast.ParenExpr); strip {
// parentheses must not be stripped if there are any
// unparenthesized composite literals starting with
// a type name
ast.Inspect(px.X, func(node ast.Node) bool {
switch x := node.(type) {
case *ast.ParenExpr:
// parentheses protect enclosed composite literals
return false
case *ast.CompositeLit:
if isTypeName(x.Type) {
strip = false // do not strip parentheses
}
return false
}
// in all other cases, keep inspecting
return true
})
if strip {
return stripParens(px.X)
}
}
return x
}
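// For example (illustrative, not part of the original source): stripParens
// rewrites "((x + y))" to "x + y", but keeps the parentheses in "(T{1})"
// because an unparenthesized composite literal starting with a type name
// would be ambiguous in a control clause such as "if T{1} == x {".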
func stripParensAlways(x ast.Expr) ast.Expr {
if x, ok := x.(*ast.ParenExpr); ok {
return stripParensAlways(x.X)
}
return x
}
func (p *printer) controlClause(isForStmt bool, init ast.Stmt, expr ast.Expr, post ast.Stmt) {
p.print(blank)
needsBlank := false
if init == nil && post == nil {
// no semicolons required
if expr != nil {
p.expr(stripParens(expr))
needsBlank = true
}
} else {
// all semicolons required
// (they are not separators, print them explicitly)
if init != nil {
p.stmt(init, false)
}
p.print(token.SEMICOLON, blank)
if expr != nil {
p.expr(stripParens(expr))
needsBlank = true
}
if isForStmt {
p.print(token.SEMICOLON, blank)
needsBlank = false
if post != nil {
p.stmt(post, false)
needsBlank = true
}
}
}
if needsBlank {
p.print(blank)
}
}
// indentList reports whether an expression list would look better if it
// were indented wholesale (starting with the very first element, rather
// than starting at the first line break).
func (p *printer) indentList(list []ast.Expr) bool {
// Heuristic: indentList reports whether the list contains more than one
// multi-line element, or any element that does not start on the same line
// as the previous one ends.
if len(list) >= 2 {
var b = p.lineFor(list[0].Pos())
var e = p.lineFor(list[len(list)-1].End())
if 0 < b && b < e {
// list spans multiple lines
n := 0 // multi-line element count
line := b
for _, x := range list {
xb := p.lineFor(x.Pos())
xe := p.lineFor(x.End())
if line < xb {
// x is not starting on the same
// line as the previous one ended
return true
}
if xb < xe {
// x is a multi-line element
n++
}
line = xe
}
return n > 1
}
}
return false
}
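// For example (an illustrative sketch, not part of the original source): in
//
//	return x + y,
//		someCall(a,
//			b)
//
// the second result starts on the line after the first one ended, so
// indentList reports true and the whole list is indented starting at x + y.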
func (p *printer) stmt(stmt ast.Stmt, nextIsRBrace bool) {
p.setPos(stmt.Pos())
switch s := stmt.(type) {
case *ast.BadStmt:
p.print("BadStmt")
case *ast.DeclStmt:
p.decl(s.Decl)
case *ast.EmptyStmt:
// nothing to do
case *ast.LabeledStmt:
// a "correcting" unindent immediately following a line break
// is applied before the line break if there is no comment
// between (see writeWhitespace)
p.print(unindent)
p.expr(s.Label)
p.setPos(s.Colon)
p.print(token.COLON, indent)
if e, isEmpty := s.Stmt.(*ast.EmptyStmt); isEmpty {
if !nextIsRBrace {
p.print(newline)
p.setPos(e.Pos())
p.print(token.SEMICOLON)
break
}
} else {
p.linebreak(p.lineFor(s.Stmt.Pos()), 1, ignore, true)
}
p.stmt(s.Stmt, nextIsRBrace)
case *ast.ExprStmt:
const depth = 1
p.expr0(s.X, depth)
case *ast.SendStmt:
const depth = 1
p.expr0(s.Chan, depth)
p.print(blank)
p.setPos(s.Arrow)
p.print(token.ARROW, blank)
p.expr0(s.Value, depth)
case *ast.IncDecStmt:
const depth = 1
p.expr0(s.X, depth+1)
p.setPos(s.TokPos)
p.print(s.Tok)
case *ast.AssignStmt:
var depth = 1
if len(s.Lhs) > 1 && len(s.Rhs) > 1 {
depth++
}
p.exprList(s.Pos(), s.Lhs, depth, 0, s.TokPos, false)
p.print(blank)
p.setPos(s.TokPos)
p.print(s.Tok, blank)
p.exprList(s.TokPos, s.Rhs, depth, 0, token.NoPos, false)
case *ast.GoStmt:
p.print(token.GO, blank)
p.expr(s.Call)
case *ast.DeferStmt:
p.print(token.DEFER, blank)
p.expr(s.Call)
case *ast.ReturnStmt:
p.print(token.RETURN)
if s.Results != nil {
p.print(blank)
// Use indentList heuristic to make corner cases look
// better (issue 1207). A more systematic approach would
// always indent, but this would cause significant
// reformatting of the code base and not necessarily
// lead to more nicely formatted code in general.
if p.indentList(s.Results) {
p.print(indent)
// Use NoPos so that a newline never goes before
// the results (see issue #32854).
p.exprList(token.NoPos, s.Results, 1, noIndent, token.NoPos, false)
p.print(unindent)
} else {
p.exprList(token.NoPos, s.Results, 1, 0, token.NoPos, false)
}
}
case *ast.BranchStmt:
p.print(s.Tok)
if s.Label != nil {
p.print(blank)
p.expr(s.Label)
}
case *ast.BlockStmt:
p.block(s, 1)
case *ast.IfStmt:
p.print(token.IF)
p.controlClause(false, s.Init, s.Cond, nil)
p.block(s.Body, 1)
if s.Else != nil {
p.print(blank, token.ELSE, blank)
switch s.Else.(type) {
case *ast.BlockStmt, *ast.IfStmt:
p.stmt(s.Else, nextIsRBrace)
default:
// This can only happen with an incorrectly
// constructed AST. Permit it but print so
// that it can be parsed without errors.
p.print(token.LBRACE, indent, formfeed)
p.stmt(s.Else, true)
p.print(unindent, formfeed, token.RBRACE)
}
}
case *ast.CaseClause:
if s.List != nil {
p.print(token.CASE, blank)
p.exprList(s.Pos(), s.List, 1, 0, s.Colon, false)
} else {
p.print(token.DEFAULT)
}
p.setPos(s.Colon)
p.print(token.COLON)
p.stmtList(s.Body, 1, nextIsRBrace)
case *ast.SwitchStmt:
p.print(token.SWITCH)
p.controlClause(false, s.Init, s.Tag, nil)
p.block(s.Body, 0)
case *ast.TypeSwitchStmt:
p.print(token.SWITCH)
if s.Init != nil {
p.print(blank)
p.stmt(s.Init, false)
p.print(token.SEMICOLON)
}
p.print(blank)
p.stmt(s.Assign, false)
p.print(blank)
p.block(s.Body, 0)
case *ast.CommClause:
if s.Comm != nil {
p.print(token.CASE, blank)
p.stmt(s.Comm, false)
} else {
p.print(token.DEFAULT)
}
p.setPos(s.Colon)
p.print(token.COLON)
p.stmtList(s.Body, 1, nextIsRBrace)
case *ast.SelectStmt:
p.print(token.SELECT, blank)
body := s.Body
if len(body.List) == 0 && !p.commentBefore(p.posFor(body.Rbrace)) {
// print empty select statement w/o comments on one line
p.setPos(body.Lbrace)
p.print(token.LBRACE)
p.setPos(body.Rbrace)
p.print(token.RBRACE)
} else {
p.block(body, 0)
}
case *ast.ForStmt:
p.print(token.FOR)
p.controlClause(true, s.Init, s.Cond, s.Post)
p.block(s.Body, 1)
case *ast.RangeStmt:
p.print(token.FOR, blank)
if s.Key != nil {
p.expr(s.Key)
if s.Value != nil {
// use position of value following the comma as
// comma position for correct comment placement
p.setPos(s.Value.Pos())
p.print(token.COMMA, blank)
p.expr(s.Value)
}
p.print(blank)
p.setPos(s.TokPos)
p.print(s.Tok, blank)
}
p.print(token.RANGE, blank)
p.expr(stripParens(s.X))
p.print(blank)
p.block(s.Body, 1)
default:
panic("unreachable")
}
}
// ----------------------------------------------------------------------------
// Declarations
// The keepTypeColumn function determines if the type column of a series of
// consecutive const or var declarations must be kept, or if initialization
// values (V) can be placed in the type column (T) instead. The i'th entry
// in the result slice is true if the type column in spec[i] must be kept.
//
// For example, the declaration:
//
//	const (
//		foobar int = 42 // comment
//		x          = 7  // comment
//		foo
//		bar = 991
//	)
//
// leads to the type/values matrix below. A run of value columns (V) can
// be moved into the type column if there is no type for any of the values
// in that column (we only move entire columns so that they align properly).
//
//	matrix        formatted     result
//	              matrix
//	T  V    ->    T  V    ->    true      there is a T and so the type
//	-  V    ->    -  V    ->    true      column must be kept
//	-  -    ->    -  -    ->    false
//	-  V    ->    V  -    ->    false     V is moved into T column
func keepTypeColumn(specs []ast.Spec) []bool {
m := make([]bool, len(specs))
populate := func(i, j int, keepType bool) {
if keepType {
for ; i < j; i++ {
m[i] = true
}
}
}
i0 := -1 // if i0 >= 0 we are in a run and i0 is the start of the run
var keepType bool
for i, s := range specs {
t := s.(*ast.ValueSpec)
if t.Values != nil {
if i0 < 0 {
// start of a run of ValueSpecs with non-nil Values
i0 = i
keepType = false
}
} else {
if i0 >= 0 {
// end of a run
populate(i0, i, keepType)
i0 = -1
}
}
if t.Type != nil {
keepType = true
}
}
if i0 >= 0 {
// end of a run
populate(i0, len(specs), keepType)
}
return m
}
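// For example (illustrative, not part of the original source): for the const
// group in the comment above, keepTypeColumn returns [true, true, false, false]:
// the run {foobar, x} keeps its type column because foobar has a type, while
// foo and bar do not, so bar's value may move into the type column.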
func (p *printer) valueSpec(s *ast.ValueSpec, keepType bool) {
p.setComment(s.Doc)
p.identList(s.Names, false) // always present
extraTabs := 3
if s.Type != nil || keepType {
p.print(vtab)
extraTabs--
}
if s.Type != nil {
p.expr(s.Type)
}
if s.Values != nil {
p.print(vtab, token.ASSIGN, blank)
p.exprList(token.NoPos, s.Values, 1, 0, token.NoPos, false)
extraTabs--
}
if s.Comment != nil {
for ; extraTabs > 0; extraTabs-- {
p.print(vtab)
}
p.setComment(s.Comment)
}
}
func sanitizeImportPath(lit *ast.BasicLit) *ast.BasicLit {
// Note: An unmodified AST generated by go/parser will already
// contain a backward- or double-quoted path string that does
// not contain any invalid characters, and most of the work
// here is not needed. However, a modified or generated AST
// may possibly contain non-canonical paths. Do the work in
// all cases since it's not too hard and not speed-critical.
// if we don't have a proper string, be conservative and return whatever we have
if lit.Kind != token.STRING {
return lit
}
s, err := strconv.Unquote(lit.Value)
if err != nil {
return lit
}
// if the string is an invalid path, return whatever we have
//
// spec: "Implementation restriction: A compiler may restrict
// ImportPaths to non-empty strings using only characters belonging
// to Unicode's L, M, N, P, and S general categories (the Graphic
// characters without spaces) and may also exclude the characters
// !"#$%&'()*,:;<=>?[\]^`{|} and the Unicode replacement character
// U+FFFD."
if s == "" {
return lit
}
const illegalChars = `!"#$%&'()*,:;<=>?[\]^{|}` + "`\uFFFD"
for _, r := range s {
if !unicode.IsGraphic(r) || unicode.IsSpace(r) || strings.ContainsRune(illegalChars, r) {
return lit
}
}
// otherwise, return the double-quoted path
s = strconv.Quote(s)
if s == lit.Value {
return lit // nothing wrong with lit
}
return &ast.BasicLit{ValuePos: lit.ValuePos, Kind: token.STRING, Value: s}
}
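// For example (illustrative, not part of the original source): a backquoted
// path such as `fmt` is rewritten to the double-quoted form "fmt", while an
// empty path or one containing an invalid character (e.g., the space in
// "a b") is returned unchanged.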
// The parameter n is the number of specs in the group. If doIndent is set,
// multi-line identifier lists in the spec are indented when the first
// linebreak is encountered.
func (p *printer) spec(spec ast.Spec, n int, doIndent bool) {
switch s := spec.(type) {
case *ast.ImportSpec:
p.setComment(s.Doc)
if s.Name != nil {
p.expr(s.Name)
p.print(blank)
}
p.expr(sanitizeImportPath(s.Path))
p.setComment(s.Comment)
p.setPos(s.EndPos)
case *ast.ValueSpec:
if n != 1 {
p.internalError("expected n = 1; got", n)
}
p.setComment(s.Doc)
p.identList(s.Names, doIndent) // always present
if s.Type != nil {
p.print(blank)
p.expr(s.Type)
}
if s.Values != nil {
p.print(blank, token.ASSIGN, blank)
p.exprList(token.NoPos, s.Values, 1, 0, token.NoPos, false)
}
p.setComment(s.Comment)
case *ast.TypeSpec:
p.setComment(s.Doc)
p.expr(s.Name)
if s.TypeParams != nil {
p.parameters(s.TypeParams, typeTParam)
}
if n == 1 {
p.print(blank)
} else {
p.print(vtab)
}
if s.Assign.IsValid() {
p.print(token.ASSIGN, blank)
}
p.expr(s.Type)
p.setComment(s.Comment)
default:
panic("unreachable")
}
}
func (p *printer) genDecl(d *ast.GenDecl) {
p.setComment(d.Doc)
p.setPos(d.Pos())
p.print(d.Tok, blank)
if d.Lparen.IsValid() || len(d.Specs) > 1 {
// group of parenthesized declarations
p.setPos(d.Lparen)
p.print(token.LPAREN)
if n := len(d.Specs); n > 0 {
p.print(indent, formfeed)
if n > 1 && (d.Tok == token.CONST || d.Tok == token.VAR) {
// two or more grouped const/var declarations:
// determine if the type column must be kept
keepType := keepTypeColumn(d.Specs)
var line int
for i, s := range d.Specs {
if i > 0 {
p.linebreak(p.lineFor(s.Pos()), 1, ignore, p.linesFrom(line) > 0)
}
p.recordLine(&line)
p.valueSpec(s.(*ast.ValueSpec), keepType[i])
}
} else {
var line int
for i, s := range d.Specs {
if i > 0 {
p.linebreak(p.lineFor(s.Pos()), 1, ignore, p.linesFrom(line) > 0)
}
p.recordLine(&line)
p.spec(s, n, false)
}
}
p.print(unindent, formfeed)
}
p.setPos(d.Rparen)
p.print(token.RPAREN)
} else if len(d.Specs) > 0 {
// single declaration
p.spec(d.Specs[0], 1, true)
}
}
// sizeCounter is an io.Writer which counts the number of bytes written,
// as well as whether a newline character was seen.
type sizeCounter struct {
hasNewline bool
size int
}
func (c *sizeCounter) Write(p []byte) (int, error) {
if !c.hasNewline {
for _, b := range p {
if b == '\n' || b == '\f' {
c.hasNewline = true
break
}
}
}
c.size += len(p)
return len(p), nil
}
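// For example (illustrative, not part of the original source):
//
//	var c sizeCounter
//	c.Write([]byte("ab\ncd"))
//	// c.size == 5, c.hasNewline == true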
// nodeSize determines the size of n in chars after formatting.
// The result is <= maxSize if the node fits on one line with at
// most maxSize chars and the formatted output doesn't contain
// any control chars. Otherwise, the result is > maxSize.
func (p *printer) nodeSize(n ast.Node, maxSize int) (size int) {
// nodeSize invokes the printer, which may invoke nodeSize
// recursively. For deep composite literal nests, this can
// lead to an exponential algorithm. Remember previous
// results to prune the recursion (was issue 1628).
if size, found := p.nodeSizes[n]; found {
return size
}
size = maxSize + 1 // assume n doesn't fit
p.nodeSizes[n] = size
// nodeSize computation must be independent of particular
// style so that we always get the same decision; print
// in RawFormat
cfg := Config{Mode: RawFormat}
var counter sizeCounter
if err := cfg.fprint(&counter, p.fset, n, p.nodeSizes); err != nil {
return
}
if counter.size <= maxSize && !counter.hasNewline {
// n fits in a single line
size = counter.size
p.nodeSizes[n] = size
}
return
}
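// For example (illustrative, not part of the original source): measuring a
// composite literal nested k levels deep would otherwise re-format every
// inner literal once per enclosing level, giving exponential work; with the
// nodeSizes cache each node is formatted at most once.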
// numLines returns the number of lines spanned by node n in the original source.
func (p *printer) numLines(n ast.Node) int {
if from := n.Pos(); from.IsValid() {
if to := n.End(); to.IsValid() {
return p.lineFor(to) - p.lineFor(from) + 1
}
}
return infinity
}
// bodySize is like nodeSize but it is specialized for *ast.BlockStmt's.
func (p *printer) bodySize(b *ast.BlockStmt, maxSize int) int {
pos1 := b.Pos()
pos2 := b.Rbrace
if pos1.IsValid() && pos2.IsValid() && p.lineFor(pos1) != p.lineFor(pos2) {
// opening and closing brace are on different lines - don't make it a one-liner
return maxSize + 1
}
if len(b.List) > 5 {
// too many statements - don't make it a one-liner
return maxSize + 1
}
// otherwise, estimate body size
bodySize := p.commentSizeBefore(p.posFor(pos2))
for i, s := range b.List {
if bodySize > maxSize {
break // no need to continue
}
if i > 0 {
bodySize += 2 // space for a semicolon and blank
}
bodySize += p.nodeSize(s, maxSize)
}
return bodySize
}
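// For example (an illustrative sketch, not part of the original source):
// a body { x++; y++ } with both braces on the same source line is estimated
// as nodeSize("x++") + 2 + nodeSize("y++") plus the size of any preceding
// same-line comments; funcBody then makes it a one-liner only if header and
// body together fit in maxSize (100) columns.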
// funcBody prints a function body following a function header of given headerSize.
// If the header's and block's size are "small enough" and the block is "simple enough",
// the block is printed on the current line, without line breaks, spaced from the header
// by sep. Otherwise the block's opening "{" is printed on the current line, followed by
// lines for the block's statements and its closing "}".
func (p *printer) funcBody(headerSize int, sep whiteSpace, b *ast.BlockStmt) {
if b == nil {
return
}
// save/restore composite literal nesting level
defer func(level int) {
p.level = level
}(p.level)
p.level = 0
const maxSize = 100
if headerSize+p.bodySize(b, maxSize) <= maxSize {
p.print(sep)
p.setPos(b.Lbrace)
p.print(token.LBRACE)
if len(b.List) > 0 {
p.print(blank)
for i, s := range b.List {
if i > 0 {
p.print(token.SEMICOLON, blank)
}
p.stmt(s, i == len(b.List)-1)
}
p.print(blank)
}
p.print(noExtraLinebreak)
p.setPos(b.Rbrace)
p.print(token.RBRACE, noExtraLinebreak)
return
}
if sep != ignore {
p.print(blank) // always use blank
}
p.block(b, 1)
}
// distanceFrom returns the column difference between p.out (the current output
// position) and startOutCol. If the start position is on a different line from
// the current position (or either is unknown), the result is infinity.
func (p *printer) distanceFrom(startPos token.Pos, startOutCol int) int {
if startPos.IsValid() && p.pos.IsValid() && p.posFor(startPos).Line == p.pos.Line {
return p.out.Column - startOutCol
}
return infinity
}
func (p *printer) funcDecl(d *ast.FuncDecl) {
p.setComment(d.Doc)
p.setPos(d.Pos())
p.print(token.FUNC, blank)
// We have to save startCol only after emitting FUNC; otherwise it can be on a
// different line (all whitespace preceding the FUNC is emitted only when the
// FUNC is emitted).
startCol := p.out.Column - len("func ")
if d.Recv != nil {
p.parameters(d.Recv, funcParam) // method: print receiver
p.print(blank)
}
p.expr(d.Name)
p.signature(d.Type)
p.funcBody(p.distanceFrom(d.Pos(), startCol), vtab, d.Body)
}
func (p *printer) decl(decl ast.Decl) {
switch d := decl.(type) {
case *ast.BadDecl:
p.setPos(d.Pos())
p.print("BadDecl")
case *ast.GenDecl:
p.genDecl(d)
case *ast.FuncDecl:
p.funcDecl(d)
default:
panic("unreachable")
}
}
// ----------------------------------------------------------------------------
// Files
func declToken(decl ast.Decl) (tok token.Token) {
tok = token.ILLEGAL
switch d := decl.(type) {
case *ast.GenDecl:
tok = d.Tok
case *ast.FuncDecl:
tok = token.FUNC
}
return
}
func (p *printer) declList(list []ast.Decl) {
tok := token.ILLEGAL
for _, d := range list {
prev := tok
tok = declToken(d)
// If the declaration token changed (e.g., from CONST to TYPE)
// or the next declaration has documentation associated with it,
// print an empty line between top-level declarations.
// (because p.linebreak is called with the position of d, which
// is past any documentation, the minimum requirement is satisfied
// even w/o the extra getDoc(d) nil-check - leave it in case the
// linebreak logic improves - there's already a TODO).
if len(p.output) > 0 {
// only print line break if we are not at the beginning of the output
// (i.e., we are not printing only a partial program)
min := 1
if prev != tok || getDoc(d) != nil {
min = 2
}
// start a new section if the next declaration is a function
// that spans multiple lines (see also issue #19544)
p.linebreak(p.lineFor(d.Pos()), min, ignore, tok == token.FUNC && p.numLines(d) > 1)
}
p.decl(d)
}
}
func (p *printer) file(src *ast.File) {
p.setComment(src.Doc)
p.setPos(src.Pos())
p.print(token.PACKAGE, blank)
p.expr(src.Name)
p.declList(src.Decls)
p.print(newline)
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package printer implements printing of AST nodes.
package printer
import (
"fmt"
"go/ast"
"go/build/constraint"
"go/token"
"io"
"os"
"strings"
"sync"
"text/tabwriter"
"unicode"
)
const (
maxNewlines = 2 // max. number of newlines between source text
debug = false // enable for debugging
infinity = 1 << 30
)
type whiteSpace byte
const (
ignore = whiteSpace(0)
blank = whiteSpace(' ')
vtab = whiteSpace('\v')
newline = whiteSpace('\n')
formfeed = whiteSpace('\f')
indent = whiteSpace('>')
unindent = whiteSpace('<')
)
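// For example (an illustrative sketch, not part of the original source):
// a call such as
//
//	p.print(token.LBRACE, indent, formfeed)
//
// emits "{" and then queues an indent marker and a formfeed in p.wsbuf;
// the indentation only takes effect when writeWhitespace later drains the
// buffer, just before the next non-whitespace token.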
// A pmode value represents the current printer mode.
type pmode int
const (
noExtraBlank pmode = 1 << iota // disables extra blank after /*-style comment
noExtraLinebreak // disables extra line break after /*-style comment
)
type commentInfo struct {
cindex int // current comment index
comment *ast.CommentGroup // = printer.comments[cindex]; or nil
commentOffset int // = printer.posFor(printer.comments[cindex].List[0].Pos()).Offset; or infinity
commentNewline bool // true if the comment group contains newlines
}
type printer struct {
// Configuration (does not change after initialization)
Config
fset *token.FileSet
// Current state
output []byte // raw printer result
indent int // current indentation
level int // level == 0: outside composite literal; level > 0: inside composite literal
mode pmode // current printer mode
endAlignment bool // if set, terminate alignment immediately
impliedSemi bool // if set, a linebreak implies a semicolon
lastTok token.Token // last token printed (token.ILLEGAL if it's whitespace)
prevOpen token.Token // previous non-brace "open" token (, [, or token.ILLEGAL
wsbuf []whiteSpace // delayed white space
goBuild []int // start index of all //go:build comments in output
plusBuild []int // start index of all // +build comments in output
// Positions
// The out position differs from the pos position when the result
// formatting differs from the source formatting (in the amount of
// white space). If there's a difference and SourcePos is set in
// Config.Mode, //line directives are used in the output to restore
// original source positions for a reader.
pos token.Position // current position in AST (source) space
out token.Position // current position in output space
last token.Position // value of pos after calling writeString
linePtr *int // if set, record out.Line for the next token in *linePtr
// The list of all source comments, in order of appearance.
comments []*ast.CommentGroup // may be nil
useNodeComments bool // if not set, ignore lead and line comments of nodes
// Information about p.comments[p.cindex]; set up by nextComment.
commentInfo
// Cache of already computed node sizes.
nodeSizes map[ast.Node]int
// Cache of most recently computed line position.
cachedPos token.Pos
cachedLine int // line corresponding to cachedPos
}
func (p *printer) internalError(msg ...any) {
if debug {
fmt.Print(p.pos.String() + ": ")
fmt.Println(msg...)
panic("go/printer")
}
}
// commentsHaveNewline reports whether a list of comments belonging to
// an *ast.CommentGroup contains newlines. Because the position information
// may only be partially correct, we also have to read the comment text.
func (p *printer) commentsHaveNewline(list []*ast.Comment) bool {
// len(list) > 0
line := p.lineFor(list[0].Pos())
for i, c := range list {
if i > 0 && p.lineFor(list[i].Pos()) != line {
// not all comments on the same line
return true
}
if t := c.Text; len(t) >= 2 && (t[1] == '/' || strings.Contains(t, "\n")) {
return true
}
}
_ = line
return false
}
func (p *printer) nextComment() {
for p.cindex < len(p.comments) {
c := p.comments[p.cindex]
p.cindex++
if list := c.List; len(list) > 0 {
p.comment = c
p.commentOffset = p.posFor(list[0].Pos()).Offset
p.commentNewline = p.commentsHaveNewline(list)
return
}
// we should not reach here (correct ASTs don't have empty
// ast.CommentGroup nodes), but be conservative and try again
}
// no more comments
p.commentOffset = infinity
}
// commentBefore reports whether the current comment group occurs
// before the next position in the source code and printing it does
// not introduce implicit semicolons.
func (p *printer) commentBefore(next token.Position) bool {
return p.commentOffset < next.Offset && (!p.impliedSemi || !p.commentNewline)
}
// commentSizeBefore returns the estimated size of the
// comments on the same line before the next position.
func (p *printer) commentSizeBefore(next token.Position) int {
// save/restore current p.commentInfo (p.nextComment() modifies it)
defer func(info commentInfo) {
p.commentInfo = info
}(p.commentInfo)
size := 0
for p.commentBefore(next) {
for _, c := range p.comment.List {
size += len(c.Text)
}
p.nextComment()
}
return size
}
// recordLine records the output line number for the next non-whitespace
// token in *linePtr. It is used to compute an accurate line number for a
// formatted construct, independent of pending (not yet emitted) whitespace
// or comments.
func (p *printer) recordLine(linePtr *int) {
p.linePtr = linePtr
}
// linesFrom returns the number of output lines between the current
// output line and the line argument, ignoring any pending (not yet
// emitted) whitespace or comments. It is used to compute an accurate
// size (in number of lines) for a formatted construct.
func (p *printer) linesFrom(line int) int {
return p.out.Line - line
}
func (p *printer) posFor(pos token.Pos) token.Position {
// not used frequently enough to cache entire token.Position
return p.fset.PositionFor(pos, false /* absolute position */)
}
func (p *printer) lineFor(pos token.Pos) int {
if pos != p.cachedPos {
p.cachedPos = pos
p.cachedLine = p.fset.PositionFor(pos, false /* absolute position */).Line
}
return p.cachedLine
}
// writeLineDirective writes a //line directive if necessary.
func (p *printer) writeLineDirective(pos token.Position) {
if pos.IsValid() && (p.out.Line != pos.Line || p.out.Filename != pos.Filename) {
p.output = append(p.output, tabwriter.Escape) // protect '\n' in //line from tabwriter interpretation
p.output = append(p.output, fmt.Sprintf("//line %s:%d\n", pos.Filename, pos.Line)...)
p.output = append(p.output, tabwriter.Escape)
// p.out must match the //line directive
p.out.Filename = pos.Filename
p.out.Line = pos.Line
}
}
// writeIndent writes indentation.
func (p *printer) writeIndent() {
// use "hard" htabs - indentation columns
// must not be discarded by the tabwriter
n := p.Config.Indent + p.indent // include base indentation
for i := 0; i < n; i++ {
p.output = append(p.output, '\t')
}
// update positions
p.pos.Offset += n
p.pos.Column += n
p.out.Column += n
}
// writeByte writes ch n times to p.output and updates p.pos.
// Only used to write formatting (white space) characters.
func (p *printer) writeByte(ch byte, n int) {
if p.endAlignment {
// Ignore any alignment control character;
// and at the end of the line, break with
// a formfeed to indicate termination of
// existing columns.
switch ch {
case '\t', '\v':
ch = ' '
case '\n', '\f':
ch = '\f'
p.endAlignment = false
}
}
if p.out.Column == 1 {
// no need to write line directives before white space
p.writeIndent()
}
for i := 0; i < n; i++ {
p.output = append(p.output, ch)
}
// update positions
p.pos.Offset += n
if ch == '\n' || ch == '\f' {
p.pos.Line += n
p.out.Line += n
p.pos.Column = 1
p.out.Column = 1
return
}
p.pos.Column += n
p.out.Column += n
}
// writeString writes the string s to p.output and updates p.pos, p.out,
// and p.last. If isLit is set, s is escaped w/ tabwriter.Escape characters
// to protect s from being interpreted by the tabwriter.
//
// Note: writeString is only used to write Go tokens, literals, and
// comments, all of which must be written literally. Thus, it is correct
// to always set isLit = true. However, setting it explicitly only when
// needed (i.e., when we don't know that s contains no tabs or line breaks)
// avoids processing extra escape characters and reduces run time of the
// printer benchmark by up to 10%.
func (p *printer) writeString(pos token.Position, s string, isLit bool) {
if p.out.Column == 1 {
if p.Config.Mode&SourcePos != 0 {
p.writeLineDirective(pos)
}
p.writeIndent()
}
if pos.IsValid() {
// update p.pos (if pos is invalid, continue with existing p.pos)
// Note: Must do this after handling line beginnings because
// writeIndent updates p.pos if there's indentation, but p.pos
// is the position of s.
p.pos = pos
}
if isLit {
// Protect s such that it passes through the tabwriter
// unchanged. Note that valid Go programs cannot contain
// tabwriter.Escape bytes since they do not appear in legal
// UTF-8 sequences.
p.output = append(p.output, tabwriter.Escape)
}
if debug {
p.output = append(p.output, fmt.Sprintf("/*%s*/", pos)...) // do not update p.pos!
}
p.output = append(p.output, s...)
// update positions
nlines := 0
var li int // index of last newline; valid if nlines > 0
for i := 0; i < len(s); i++ {
// Raw string literals may contain any character except back quote (`).
if ch := s[i]; ch == '\n' || ch == '\f' {
// account for line break
nlines++
li = i
// A line break inside a literal will break whatever column
// formatting is in place; ignore any further alignment through
// the end of the line.
p.endAlignment = true
}
}
p.pos.Offset += len(s)
if nlines > 0 {
p.pos.Line += nlines
p.out.Line += nlines
c := len(s) - li
p.pos.Column = c
p.out.Column = c
} else {
p.pos.Column += len(s)
p.out.Column += len(s)
}
if isLit {
p.output = append(p.output, tabwriter.Escape)
}
p.last = p.pos
}
// writeCommentPrefix writes the whitespace before a comment.
// If there is any pending whitespace, it consumes as much of
// it as is likely to help position the comment nicely.
// pos is the comment position, next the position of the item
// after all pending comments, prev is the previous comment in
// a group of comments (or nil), and tok is the next token.
func (p *printer) writeCommentPrefix(pos, next token.Position, prev *ast.Comment, tok token.Token) {
if len(p.output) == 0 {
// the comment is the first item to be printed - don't write any whitespace
return
}
if pos.IsValid() && pos.Filename != p.last.Filename {
// comment in a different file - separate with newlines
p.writeByte('\f', maxNewlines)
return
}
if pos.Line == p.last.Line && (prev == nil || prev.Text[1] != '/') {
// comment on the same line as last item:
// separate with at least one separator
hasSep := false
if prev == nil {
// first comment of a comment group
j := 0
for i, ch := range p.wsbuf {
switch ch {
case blank:
// ignore any blanks before a comment
p.wsbuf[i] = ignore
continue
case vtab:
// respect existing tabs - important
// for proper formatting of commented structs
hasSep = true
continue
case indent:
// apply pending indentation
continue
}
j = i
break
}
p.writeWhitespace(j)
}
// make sure there is at least one separator
if !hasSep {
sep := byte('\t')
if pos.Line == next.Line {
// next item is on the same line as the comment
// (which must be a /*-style comment): separate
// with a blank instead of a tab
sep = ' '
}
p.writeByte(sep, 1)
}
} else {
// comment on a different line:
// separate with at least one line break
droppedLinebreak := false
j := 0
for i, ch := range p.wsbuf {
switch ch {
case blank, vtab:
// ignore any horizontal whitespace before line breaks
p.wsbuf[i] = ignore
continue
case indent:
// apply pending indentation
continue
case unindent:
// if this is not the last unindent, apply it
// as it is (likely) belonging to the last
// construct (e.g., a multi-line expression list)
// and is not part of closing a block
if i+1 < len(p.wsbuf) && p.wsbuf[i+1] == unindent {
continue
}
// if the next token is not a closing }, apply the unindent
// if it appears that the comment is aligned with the
// token; otherwise assume the unindent is part of a
// closing block and stop (this scenario appears with
// comments before a case label where the comments
// apply to the next case instead of the current one)
if tok != token.RBRACE && pos.Column == next.Column {
continue
}
case newline, formfeed:
p.wsbuf[i] = ignore
droppedLinebreak = prev == nil // record only if first comment of a group
}
j = i
break
}
p.writeWhitespace(j)
// determine number of linebreaks before the comment
n := 0
if pos.IsValid() && p.last.IsValid() {
n = pos.Line - p.last.Line
if n < 0 { // should never happen
n = 0
}
}
// at the package scope level only (p.indent == 0),
// add an extra newline if we dropped one before:
// this preserves a blank line before documentation
// comments at the package scope level (issue 2570)
if p.indent == 0 && droppedLinebreak {
n++
}
// make sure there is at least one line break
// if the previous comment was a line comment
if n == 0 && prev != nil && prev.Text[1] == '/' {
n = 1
}
if n > 0 {
// use formfeeds to break columns before a comment;
// this is analogous to using formfeeds to separate
// individual lines of /*-style comments
p.writeByte('\f', nlimit(n))
}
}
}
// isBlank reports whether s contains only white space
// (only tabs and blanks can appear in the printer's context).
func isBlank(s string) bool {
for i := 0; i < len(s); i++ {
if s[i] > ' ' {
return false
}
}
return true
}
// commonPrefix returns the common prefix of a and b.
func commonPrefix(a, b string) string {
i := 0
for i < len(a) && i < len(b) && a[i] == b[i] && (a[i] <= ' ' || a[i] == '*') {
i++
}
return a[0:i]
}
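// For example (illustrative, not part of the original source):
//
//	commonPrefix("\t * foo", "\t * bar") == "\t * "
//
// Only white space characters and '*' can become part of the prefix.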
// trimRight returns s with trailing whitespace removed.
func trimRight(s string) string {
return strings.TrimRightFunc(s, unicode.IsSpace)
}
// stripCommonPrefix removes a common prefix from /*-style comment lines
// (unless no comment line is indented, all but the first line have some
// form of space prefix). The prefix is computed using heuristics such that
// it is likely that the comment contents are nicely laid out after
// re-printing each line using the printer's current indentation.
func stripCommonPrefix(lines []string) {
if len(lines) <= 1 {
return // at most one line - nothing to do
}
// len(lines) > 1
// The heuristic in this function tries to handle a few
// common patterns of /*-style comments: Comments where
// the opening /* and closing */ are aligned and the
// rest of the comment text is aligned and indented with
// blanks or tabs, cases with a vertical "line of stars"
// on the left, and cases where the closing */ is on the
// same line as the last comment text.
// Compute maximum common white prefix of all but the first,
// last, and blank lines, and replace blank lines with empty
// lines (the first line starts with /* and has no prefix).
// In cases where only the first and last lines are not blank,
// such as two-line comments, or comments where all inner lines
// are blank, consider the last line for the prefix computation
// since otherwise the prefix would be empty.
//
// Note that the first and last line are never empty (they
// contain the opening /* and closing */ respectively) and
// thus they can be ignored by the blank line check.
prefix := ""
prefixSet := false
if len(lines) > 2 {
for i, line := range lines[1 : len(lines)-1] {
if isBlank(line) {
lines[1+i] = "" // range starts with lines[1]
} else {
if !prefixSet {
prefix = line
prefixSet = true
}
prefix = commonPrefix(prefix, line)
}
}
}
// If we don't have a prefix yet, consider the last line.
if !prefixSet {
line := lines[len(lines)-1]
prefix = commonPrefix(line, line)
}
/*
* Check for vertical "line of stars" and correct prefix accordingly.
*/
lineOfStars := false
if p, _, ok := strings.Cut(prefix, "*"); ok {
// remove trailing blank from prefix so stars remain aligned
prefix = strings.TrimSuffix(p, " ")
lineOfStars = true
} else {
// No line of stars present.
// Determine the white space on the first line after the /*
// and before the beginning of the comment text, assume two
// blanks instead of the /* unless the first character after
// the /* is a tab. If the first comment line is empty but
// for the opening /*, assume up to 3 blanks or a tab. This
// whitespace may be found as suffix in the common prefix.
first := lines[0]
if isBlank(first[2:]) {
// no comment text on the first line:
// reduce prefix by up to 3 blanks or a tab
// if present - this keeps comment text indented
// relative to the /* and */'s if it was indented
// in the first place
i := len(prefix)
for n := 0; n < 3 && i > 0 && prefix[i-1] == ' '; n++ {
i--
}
if i == len(prefix) && i > 0 && prefix[i-1] == '\t' {
i--
}
prefix = prefix[0:i]
} else {
// comment text on the first line
suffix := make([]byte, len(first))
n := 2 // start after opening /*
for n < len(first) && first[n] <= ' ' {
suffix[n] = first[n]
n++
}
if n > 2 && suffix[2] == '\t' {
// assume the '\t' compensates for the /*
suffix = suffix[2:n]
} else {
// otherwise assume two blanks
suffix[0], suffix[1] = ' ', ' '
suffix = suffix[0:n]
}
// Shorten the computed common prefix by the length of
// suffix, if it is found as suffix of the prefix.
prefix = strings.TrimSuffix(prefix, string(suffix))
}
}
// Handle last line: If it only contains a closing */, align it
// with the opening /*, otherwise align the text with the other
// lines.
last := lines[len(lines)-1]
closing := "*/"
before, _, _ := strings.Cut(last, closing) // closing always present
if isBlank(before) {
// last line only contains closing */
if lineOfStars {
closing = " */" // add blank to align final star
}
lines[len(lines)-1] = prefix + closing
} else {
// last line contains more comment text - assume
// it is aligned like the other lines and include
// in prefix computation
prefix = commonPrefix(prefix, last)
}
// Remove the common prefix from all but the first and empty lines.
for i, line := range lines {
if i > 0 && line != "" {
lines[i] = line[len(prefix):]
}
}
}
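// For example (an illustrative sketch, not part of the original source): in
//
//	/*
//	 * foo
//	 * bar
//	 */
//
// the inner lines share the prefix " * "; since the prefix contains a '*',
// the "line of stars" case applies: the prefix is cut back to just before
// the star, minus one trailing blank, so the stars stay aligned, and the
// closing */ is printed as " */" to line up with them.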
func (p *printer) writeComment(comment *ast.Comment) {
text := comment.Text
pos := p.posFor(comment.Pos())
const linePrefix = "//line "
if strings.HasPrefix(text, linePrefix) && (!pos.IsValid() || pos.Column == 1) {
// Possibly a //-style line directive.
// Suspend indentation temporarily to keep line directive valid.
defer func(indent int) { p.indent = indent }(p.indent)
p.indent = 0
}
// shortcut common case of //-style comments
if text[1] == '/' {
if constraint.IsGoBuild(text) {
p.goBuild = append(p.goBuild, len(p.output))
} else if constraint.IsPlusBuild(text) {
p.plusBuild = append(p.plusBuild, len(p.output))
}
p.writeString(pos, trimRight(text), true)
return
}
// for /*-style comments, print line by line and let the
// write function take care of the proper indentation
lines := strings.Split(text, "\n")
// The comment started in the first column but is going
// to be indented. For an idempotent result, add indentation
// to all lines such that they look like they were indented
// before - this will make sure the common prefix computation
// is the same independent of how many times formatting is
// applied (was issue 1835).
if pos.IsValid() && pos.Column == 1 && p.indent > 0 {
for i, line := range lines[1:] {
lines[1+i] = " " + line
}
}
stripCommonPrefix(lines)
// write comment lines, separated by formfeed,
// without a line break after the last line
for i, line := range lines {
if i > 0 {
p.writeByte('\f', 1)
pos = p.pos
}
if len(line) > 0 {
p.writeString(pos, trimRight(line), true)
}
}
}
// writeCommentSuffix writes a line break after a comment if indicated
// and processes any leftover indentation information. If a line break
// is needed, the kind of break (newline vs formfeed) depends on the
// pending whitespace. The writeCommentSuffix result indicates if a
// newline was written or if a formfeed was dropped from the whitespace
// buffer.
func (p *printer) writeCommentSuffix(needsLinebreak bool) (wroteNewline, droppedFF bool) {
for i, ch := range p.wsbuf {
switch ch {
case blank, vtab:
// ignore trailing whitespace
p.wsbuf[i] = ignore
case indent, unindent:
// don't lose indentation information
case newline, formfeed:
// if we need a line break, keep exactly one
// but remember if we dropped any formfeeds
if needsLinebreak {
needsLinebreak = false
wroteNewline = true
} else {
if ch == formfeed {
droppedFF = true
}
p.wsbuf[i] = ignore
}
}
}
p.writeWhitespace(len(p.wsbuf))
// make sure we have a line break
if needsLinebreak {
p.writeByte('\n', 1)
wroteNewline = true
}
return
}
// containsLinebreak reports whether the whitespace buffer contains any line breaks.
func (p *printer) containsLinebreak() bool {
for _, ch := range p.wsbuf {
if ch == newline || ch == formfeed {
return true
}
}
return false
}
// intersperseComments consumes all comments that appear before the next token
// tok and prints them together with the buffered whitespace (i.e., the whitespace
// that needs to be written before the next token). A heuristic is used to mix
// the comments and whitespace. The intersperseComments result indicates if a
// newline was written or if a formfeed was dropped from the whitespace buffer.
func (p *printer) intersperseComments(next token.Position, tok token.Token) (wroteNewline, droppedFF bool) {
var last *ast.Comment
for p.commentBefore(next) {
list := p.comment.List
changed := false
if p.lastTok != token.IMPORT && // do not rewrite cgo's import "C" comments
p.posFor(p.comment.Pos()).Column == 1 &&
p.posFor(p.comment.End()+1) == next {
// Unindented comment abutting next token position:
// a top-level doc comment.
list = formatDocComment(list)
changed = true
if len(p.comment.List) > 0 && len(list) == 0 {
// The doc comment was removed entirely.
// Keep preceding whitespace.
p.writeCommentPrefix(p.posFor(p.comment.Pos()), next, last, tok)
// Change print state to continue at next.
p.pos = next
p.last = next
// There can't be any more comments.
p.nextComment()
return p.writeCommentSuffix(false)
}
}
for _, c := range list {
p.writeCommentPrefix(p.posFor(c.Pos()), next, last, tok)
p.writeComment(c)
last = c
}
// In case list was rewritten, change print state to where
// the original list would have ended.
if len(p.comment.List) > 0 && changed {
last = p.comment.List[len(p.comment.List)-1]
p.pos = p.posFor(last.End())
p.last = p.pos
}
p.nextComment()
}
if last != nil {
// If the last comment is a /*-style comment and the next item
// follows on the same line but is not a comma, and not a "closing"
// token immediately following its corresponding "opening" token,
// add an extra separator unless explicitly disabled. Use a blank
// as separator unless we have pending linebreaks, they are not
// disabled, and we are outside a composite literal, in which case
// we want a linebreak (issue 15137).
// TODO(gri) This has become overly complicated. We should be able
// to track whether we're inside an expression or statement and
// use that information to decide more directly.
needsLinebreak := false
if p.mode&noExtraBlank == 0 &&
last.Text[1] == '*' && p.lineFor(last.Pos()) == next.Line &&
tok != token.COMMA &&
(tok != token.RPAREN || p.prevOpen == token.LPAREN) &&
(tok != token.RBRACK || p.prevOpen == token.LBRACK) {
if p.containsLinebreak() && p.mode&noExtraLinebreak == 0 && p.level == 0 {
needsLinebreak = true
} else {
p.writeByte(' ', 1)
}
}
// Ensure that there is a line break after a //-style comment,
// before EOF, and before a closing '}' unless explicitly disabled.
if last.Text[1] == '/' ||
tok == token.EOF ||
tok == token.RBRACE && p.mode&noExtraLinebreak == 0 {
needsLinebreak = true
}
return p.writeCommentSuffix(needsLinebreak)
}
// no comment was written - we should never reach here since
// intersperseComments should not be called in that case
p.internalError("intersperseComments called without pending comments")
return
}
// writeWhitespace writes the first n whitespace entries.
func (p *printer) writeWhitespace(n int) {
// write entries
for i := 0; i < n; i++ {
switch ch := p.wsbuf[i]; ch {
case ignore:
// ignore!
case indent:
p.indent++
case unindent:
p.indent--
if p.indent < 0 {
p.internalError("negative indentation:", p.indent)
p.indent = 0
}
case newline, formfeed:
// A line break immediately followed by a "correcting"
// unindent is swapped with the unindent - this permits
// proper label positioning. If a comment is between
// the line break and the label, the unindent is not
// part of the comment whitespace prefix and the comment
// will be positioned correctly indented.
if i+1 < n && p.wsbuf[i+1] == unindent {
// Use a formfeed to terminate the current section.
// Otherwise, a long label name on the next line leading
// to a wide column may increase the indentation column
// of lines before the label; effectively leading to wrong
// indentation.
p.wsbuf[i], p.wsbuf[i+1] = unindent, formfeed
i-- // do it again
continue
}
fallthrough
default:
p.writeByte(byte(ch), 1)
}
}
// shift remaining entries down
l := copy(p.wsbuf, p.wsbuf[n:])
p.wsbuf = p.wsbuf[:l]
}
// ----------------------------------------------------------------------------
// Printing interface
// nlimit limits n to maxNewlines.
func nlimit(n int) int {
if n > maxNewlines {
n = maxNewlines
}
return n
}
func mayCombine(prev token.Token, next byte) (b bool) {
switch prev {
case token.INT:
b = next == '.' // 1.
case token.ADD:
b = next == '+' // ++
case token.SUB:
b = next == '-' // --
case token.QUO:
b = next == '*' // /*
case token.LSS:
b = next == '-' || next == '<' // <- or <<
case token.AND:
b = next == '&' || next == '^' // && or &^
}
return
}
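// For example (illustrative, not part of the original source): printing
// token.ADD directly followed by the unary expression +x would yield "++x",
// which tokenizes as an increment; mayCombine(token.ADD, '+') reports true,
// so print separates the two tokens with a blank ("+ +x").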
func (p *printer) setPos(pos token.Pos) {
if pos.IsValid() {
p.pos = p.posFor(pos) // accurate position of next item
}
}
// print prints a list of "items" (roughly corresponding to syntactic
// tokens, but also including whitespace and formatting information).
// It is the only print function that should be called directly from
// any of the AST printing functions in nodes.go.
//
// Whitespace is accumulated until a non-whitespace token appears. Any
// comments that need to appear before that token are printed first,
// taking into account the amount and structure of any pending white-
// space for best comment placement. Then, any leftover whitespace is
// printed, followed by the actual token.
func (p *printer) print(args ...any) {
for _, arg := range args {
// information about the current arg
var data string
var isLit bool
var impliedSemi bool // value for p.impliedSemi after this arg
// record previous opening token, if any
switch p.lastTok {
case token.ILLEGAL:
// ignore (white space)
case token.LPAREN, token.LBRACK:
p.prevOpen = p.lastTok
default:
// other tokens followed any opening token
p.prevOpen = token.ILLEGAL
}
switch x := arg.(type) {
case pmode:
// toggle printer mode
p.mode ^= x
continue
case whiteSpace:
if x == ignore {
// don't add ignore's to the buffer; they
// may screw up "correcting" unindents (see
// LabeledStmt)
continue
}
i := len(p.wsbuf)
if i == cap(p.wsbuf) {
// Whitespace sequences are very short so this should
// never happen. Handle gracefully (but possibly with
// bad comment placement) if it does happen.
p.writeWhitespace(i)
i = 0
}
p.wsbuf = p.wsbuf[0 : i+1]
p.wsbuf[i] = x
if x == newline || x == formfeed {
// newlines affect the current state (p.impliedSemi)
// and not the state after printing arg (impliedSemi)
// because comments can be interspersed before the arg
// in this case
p.impliedSemi = false
}
p.lastTok = token.ILLEGAL
continue
case *ast.Ident:
data = x.Name
impliedSemi = true
p.lastTok = token.IDENT
case *ast.BasicLit:
data = x.Value
isLit = true
impliedSemi = true
p.lastTok = x.Kind
case token.Token:
s := x.String()
if mayCombine(p.lastTok, s[0]) {
// the previous and the current token must be
// separated by a blank otherwise they combine
// into a different incorrect token sequence
// (except for token.INT followed by a '.' this
// should never happen because it is taken care
// of via binary expression formatting)
if len(p.wsbuf) != 0 {
p.internalError("whitespace buffer not empty")
}
p.wsbuf = p.wsbuf[0:1]
p.wsbuf[0] = ' '
}
data = s
// some keywords followed by a newline imply a semicolon
switch x {
case token.BREAK, token.CONTINUE, token.FALLTHROUGH, token.RETURN,
token.INC, token.DEC, token.RPAREN, token.RBRACK, token.RBRACE:
impliedSemi = true
}
p.lastTok = x
case string:
// incorrect AST - print error message
data = x
isLit = true
impliedSemi = true
p.lastTok = token.STRING
default:
fmt.Fprintf(os.Stderr, "print: unsupported argument %v (%T)\n", arg, arg)
panic("go/printer type")
}
// data != ""
next := p.pos // estimated/accurate position of next item
wroteNewline, droppedFF := p.flush(next, p.lastTok)
// intersperse extra newlines if present in the source and
// if they don't cause extra semicolons (don't do this in
// flush as it will cause extra newlines at the end of a file)
if !p.impliedSemi {
n := nlimit(next.Line - p.pos.Line)
// don't exceed maxNewlines if we already wrote one
if wroteNewline && n == maxNewlines {
n = maxNewlines - 1
}
if n > 0 {
ch := byte('\n')
if droppedFF {
ch = '\f' // use formfeed since we dropped one before
}
p.writeByte(ch, n)
impliedSemi = false
}
}
// the next token starts now - record its line number if requested
if p.linePtr != nil {
*p.linePtr = p.out.Line
p.linePtr = nil
}
p.writeString(next, data, isLit)
p.impliedSemi = impliedSemi
}
}
// flush prints any pending comments and whitespace occurring textually
// before the position of the next token tok. The flush result indicates
// if a newline was written or if a formfeed was dropped from the whitespace
// buffer.
func (p *printer) flush(next token.Position, tok token.Token) (wroteNewline, droppedFF bool) {
if p.commentBefore(next) {
// if there are comments before the next item, intersperse them
wroteNewline, droppedFF = p.intersperseComments(next, tok)
} else {
// otherwise, write any leftover whitespace
p.writeWhitespace(len(p.wsbuf))
}
return
}
// getDoc returns the ast.CommentGroup associated with n, if any.
func getDoc(n ast.Node) *ast.CommentGroup {
switch n := n.(type) {
case *ast.Field:
return n.Doc
case *ast.ImportSpec:
return n.Doc
case *ast.ValueSpec:
return n.Doc
case *ast.TypeSpec:
return n.Doc
case *ast.GenDecl:
return n.Doc
case *ast.FuncDecl:
return n.Doc
case *ast.File:
return n.Doc
}
return nil
}
func getLastComment(n ast.Node) *ast.CommentGroup {
switch n := n.(type) {
case *ast.Field:
return n.Comment
case *ast.ImportSpec:
return n.Comment
case *ast.ValueSpec:
return n.Comment
case *ast.TypeSpec:
return n.Comment
case *ast.GenDecl:
if len(n.Specs) > 0 {
return getLastComment(n.Specs[len(n.Specs)-1])
}
case *ast.File:
if len(n.Comments) > 0 {
return n.Comments[len(n.Comments)-1]
}
}
return nil
}
func (p *printer) printNode(node any) error {
// unpack *CommentedNode, if any
var comments []*ast.CommentGroup
if cnode, ok := node.(*CommentedNode); ok {
node = cnode.Node
comments = cnode.Comments
}
if comments != nil {
// commented node - restrict comment list to relevant range
n, ok := node.(ast.Node)
if !ok {
goto unsupported
}
beg := n.Pos()
end := n.End()
// if the node has associated documentation,
// include that commentgroup in the range
// (the comment list is sorted in the order
// of the comment appearance in the source code)
if doc := getDoc(n); doc != nil {
beg = doc.Pos()
}
if com := getLastComment(n); com != nil {
if e := com.End(); e > end {
end = e
}
}
// token.Pos values are global offsets, we can
// compare them directly
i := 0
for i < len(comments) && comments[i].End() < beg {
i++
}
j := i
for j < len(comments) && comments[j].Pos() < end {
j++
}
if i < j {
p.comments = comments[i:j]
}
} else if n, ok := node.(*ast.File); ok {
// use ast.File comments, if any
p.comments = n.Comments
}
// if there are no comments, use node comments
p.useNodeComments = p.comments == nil
// get comments ready for use
p.nextComment()
p.print(pmode(0))
// format node
switch n := node.(type) {
case ast.Expr:
p.expr(n)
case ast.Stmt:
// A labeled statement will un-indent to position the label.
// Set p.indent to 1 so we don't get indent "underflow".
if _, ok := n.(*ast.LabeledStmt); ok {
p.indent = 1
}
p.stmt(n, false)
case ast.Decl:
p.decl(n)
case ast.Spec:
p.spec(n, 1, false)
case []ast.Stmt:
// A labeled statement will un-indent to position the label.
// Set p.indent to 1 so we don't get indent "underflow".
for _, s := range n {
if _, ok := s.(*ast.LabeledStmt); ok {
p.indent = 1
}
}
p.stmtList(n, 0, false)
case []ast.Decl:
p.declList(n)
case *ast.File:
p.file(n)
default:
goto unsupported
}
return nil
unsupported:
return fmt.Errorf("go/printer: unsupported node type %T", node)
}
// ----------------------------------------------------------------------------
// Trimmer
// A trimmer is an io.Writer filter for stripping tabwriter.Escape
// characters, trailing blanks and tabs, and for converting formfeed
// and vtab characters into newlines and htabs (in case no tabwriter
// is used). Text bracketed by tabwriter.Escape characters is passed
// through unchanged.
type trimmer struct {
output io.Writer
state int
space []byte
}
// trimmer is implemented as a state machine.
// It can be in one of the following states:
const (
inSpace = iota // inside space
inEscape // inside text bracketed by tabwriter.Escapes
inText // inside text
)
func (p *trimmer) resetSpace() {
p.state = inSpace
p.space = p.space[0:0]
}
// Design note: It is tempting to eliminate extra blanks occurring in
// whitespace in this function as it could simplify some
// of the blanks logic in the node printing functions.
// However, this would mess up any formatting done by
// the tabwriter.
var aNewline = []byte("\n")
func (p *trimmer) Write(data []byte) (n int, err error) {
// invariants:
// p.state == inSpace:
// p.space is unwritten
// p.state == inEscape, inText:
// data[m:n] is unwritten
m := 0
var b byte
for n, b = range data {
if b == '\v' {
b = '\t' // convert to htab
}
switch p.state {
case inSpace:
switch b {
case '\t', ' ':
p.space = append(p.space, b)
case '\n', '\f':
p.resetSpace() // discard trailing space
_, err = p.output.Write(aNewline)
case tabwriter.Escape:
_, err = p.output.Write(p.space)
p.state = inEscape
m = n + 1 // +1: skip tabwriter.Escape
default:
_, err = p.output.Write(p.space)
p.state = inText
m = n
}
case inEscape:
if b == tabwriter.Escape {
_, err = p.output.Write(data[m:n])
p.resetSpace()
}
case inText:
switch b {
case '\t', ' ':
_, err = p.output.Write(data[m:n])
p.resetSpace()
p.space = append(p.space, b)
case '\n', '\f':
_, err = p.output.Write(data[m:n])
p.resetSpace()
if err == nil {
_, err = p.output.Write(aNewline)
}
case tabwriter.Escape:
_, err = p.output.Write(data[m:n])
p.state = inEscape
m = n + 1 // +1: skip tabwriter.Escape
}
default:
panic("unreachable")
}
if err != nil {
return
}
}
n = len(data)
switch p.state {
case inEscape, inText:
_, err = p.output.Write(data[m:n])
p.resetSpace()
}
return
}
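// Illustrative sketch (not part of the original source): tracing the state
// machine above, feeding the trimmer the bytes "foo \t\nbar\fbaz" writes
// "foo\nbar\nbaz" to the underlying writer — the blank and tab trailing
// "foo" are dropped, and the formfeed becomes a newline.
//
//	var buf bytes.Buffer
//	t := &trimmer{output: &buf}
//	t.Write([]byte("foo \t\nbar\fbaz"))
//	// buf now holds "foo\nbar\nbaz"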
// ----------------------------------------------------------------------------
// Public interface
// A Mode value is a set of flags (or 0). They control printing.
type Mode uint
const (
RawFormat Mode = 1 << iota // do not use a tabwriter; if set, UseSpaces is ignored
TabIndent // use tabs for indentation independent of UseSpaces
UseSpaces // use spaces instead of tabs for alignment
SourcePos // emit //line directives to preserve original source positions
)
// The mode below is not included in printer's public API because
// editing code text is deemed out of scope. Because this mode is
// unexported, it's also possible to modify or remove it based on
// the evolving needs of go/format and cmd/gofmt without breaking
// users. See discussion in CL 240683.
const (
// normalizeNumbers means to canonicalize number
// literal prefixes and exponents while printing.
//
// This value is known in and used by go/format and cmd/gofmt.
// It is currently more convenient and performant for those
// packages to apply number normalization during printing,
// rather than by modifying the AST in advance.
normalizeNumbers Mode = 1 << 30
)
// A Config node controls the output of Fprint.
type Config struct {
Mode Mode // default: 0
Tabwidth int // default: 8
Indent int // default: 0 (all code is indented at least by this much)
}
var printerPool = sync.Pool{
New: func() any {
return &printer{
// Whitespace sequences are short.
wsbuf: make([]whiteSpace, 0, 16),
// We start the printer with a 16K output buffer, which is currently
// larger than about 80% of Go files in the standard library.
output: make([]byte, 0, 16<<10),
}
},
}
func newPrinter(cfg *Config, fset *token.FileSet, nodeSizes map[ast.Node]int) *printer {
p := printerPool.Get().(*printer)
*p = printer{
Config: *cfg,
fset: fset,
pos: token.Position{Line: 1, Column: 1},
out: token.Position{Line: 1, Column: 1},
wsbuf: p.wsbuf[:0],
nodeSizes: nodeSizes,
cachedPos: -1,
output: p.output[:0],
}
return p
}
func (p *printer) free() {
// Hard limit on buffer size; see https://golang.org/issue/23199.
if cap(p.output) > 64<<10 {
return
}
printerPool.Put(p)
}
// fprint implements Fprint and takes a nodeSizes map for setting up the printer state.
func (cfg *Config) fprint(output io.Writer, fset *token.FileSet, node any, nodeSizes map[ast.Node]int) (err error) {
// print node
p := newPrinter(cfg, fset, nodeSizes)
defer p.free()
if err = p.printNode(node); err != nil {
return
}
// print outstanding comments
p.impliedSemi = false // EOF acts like a newline
p.flush(token.Position{Offset: infinity, Line: infinity}, token.EOF)
// output is buffered in p.output now.
// fix //go:build and // +build comments if needed.
p.fixGoBuildLines()
// redirect output through a trimmer to eliminate trailing whitespace
// (Input to a tabwriter must be untrimmed since trailing tabs provide
// formatting information. The tabwriter could provide trimming
// functionality but no tabwriter is used when RawFormat is set.)
output = &trimmer{output: output}
// redirect output through a tabwriter if necessary
if cfg.Mode&RawFormat == 0 {
minwidth := cfg.Tabwidth
padchar := byte('\t')
if cfg.Mode&UseSpaces != 0 {
padchar = ' '
}
twmode := tabwriter.DiscardEmptyColumns
if cfg.Mode&TabIndent != 0 {
minwidth = 0
twmode |= tabwriter.TabIndent
}
output = tabwriter.NewWriter(output, minwidth, cfg.Tabwidth, 1, padchar, twmode)
}
// write printer result via tabwriter/trimmer to output
if _, err = output.Write(p.output); err != nil {
return
}
// flush tabwriter, if any
if tw, _ := output.(*tabwriter.Writer); tw != nil {
err = tw.Flush()
}
return
}
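// For illustration (not in the original source), the writer pipeline
// assembled by fprint above is
//
//	p.output -> tabwriter (unless RawFormat is set) -> trimmer -> output
//
// so column alignment happens before trailing whitespace is stripped.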
// A CommentedNode bundles an AST node and corresponding comments.
// It may be provided as argument to any of the Fprint functions.
type CommentedNode struct {
Node any // *ast.File, or ast.Expr, ast.Decl, ast.Spec, or ast.Stmt
Comments []*ast.CommentGroup
}
// Fprint "pretty-prints" an AST node to output for a given configuration cfg.
// Position information is interpreted relative to the file set fset.
// The node type must be *ast.File, *CommentedNode, []ast.Decl, []ast.Stmt,
// or assignment-compatible to ast.Expr, ast.Decl, ast.Spec, or ast.Stmt.
func (cfg *Config) Fprint(output io.Writer, fset *token.FileSet, node any) error {
return cfg.fprint(output, fset, node, make(map[ast.Node]int))
}
// Fprint "pretty-prints" an AST node to output.
// It calls Config.Fprint with default settings.
// Note that gofmt uses tabs for indentation but spaces for alignment;
// use format.Node (package go/format) for output that matches gofmt.
func Fprint(output io.Writer, fset *token.FileSet, node any) error {
return (&Config{Tabwidth: 8}).Fprint(output, fset, node)
}
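// Illustrative usage sketch (not part of the original source): parsing a
// file and pretty-printing it with a custom Config. parser.ParseFile,
// bytes.Buffer, and the src value are assumed from the standard library
// and the caller's context.
//
//	fset := token.NewFileSet()
//	f, err := parser.ParseFile(fset, "hello.go", src, parser.ParseComments)
//	if err != nil {
//		log.Fatal(err)
//	}
//	var buf bytes.Buffer
//	cfg := Config{Mode: UseSpaces | TabIndent, Tabwidth: 8}
//	if err := cfg.Fprint(&buf, fset, f); err != nil {
//		log.Fatal(err)
//	}
//	// buf now holds the formatted source (tabs for indentation,
//	// spaces for alignment, close to gofmt's output).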
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package scanner
import (
"fmt"
"go/token"
"io"
"sort"
)
// In an ErrorList, an error is represented by an *Error.
// The position Pos, if valid, points to the beginning of
// the offending token, and the error condition is described
// by Msg.
type Error struct {
Pos token.Position
Msg string
}
// Error implements the error interface.
func (e Error) Error() string {
if e.Pos.Filename != "" || e.Pos.IsValid() {
// don't print "<unknown position>"
// TODO(gri) reconsider the semantics of Position.IsValid
return e.Pos.String() + ": " + e.Msg
}
return e.Msg
}
// ErrorList is a list of *Errors.
// The zero value for an ErrorList is an empty ErrorList ready to use.
type ErrorList []*Error
// Add adds an Error with given position and error message to an ErrorList.
func (p *ErrorList) Add(pos token.Position, msg string) {
*p = append(*p, &Error{pos, msg})
}
// Reset resets an ErrorList to no errors.
func (p *ErrorList) Reset() { *p = (*p)[0:0] }
// ErrorList implements the sort Interface.
func (p ErrorList) Len() int { return len(p) }
func (p ErrorList) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
func (p ErrorList) Less(i, j int) bool {
e := &p[i].Pos
f := &p[j].Pos
// Note that it is not sufficient to simply compare file offsets because
// the offsets do not reflect modified line information (through //line
// comments).
if e.Filename != f.Filename {
return e.Filename < f.Filename
}
if e.Line != f.Line {
return e.Line < f.Line
}
if e.Column != f.Column {
return e.Column < f.Column
}
return p[i].Msg < p[j].Msg
}
// Sort sorts an ErrorList. Entries are ordered by position
// (filename, then line, then column); entries with equal
// positions are ordered by error message.
func (p ErrorList) Sort() {
sort.Sort(p)
}
// RemoveMultiples sorts an ErrorList and removes all but the first error per line.
func (p *ErrorList) RemoveMultiples() {
sort.Sort(p)
var last token.Position // initial last.Line is != any legal error line
i := 0
for _, e := range *p {
if e.Pos.Filename != last.Filename || e.Pos.Line != last.Line {
last = e.Pos
(*p)[i] = e
i++
}
}
*p = (*p)[0:i]
}
// An ErrorList implements the error interface.
func (p ErrorList) Error() string {
switch len(p) {
case 0:
return "no errors"
case 1:
return p[0].Error()
}
return fmt.Sprintf("%s (and %d more errors)", p[0], len(p)-1)
}
// Err returns an error equivalent to this error list.
// If the list is empty, Err returns nil.
func (p ErrorList) Err() error {
if len(p) == 0 {
return nil
}
return p
}
// PrintError is a utility function that prints a list of errors to w,
// one error per line, if the err parameter is an ErrorList. Otherwise
// it prints the err string.
func PrintError(w io.Writer, err error) {
if list, ok := err.(ErrorList); ok {
for _, e := range list {
fmt.Fprintf(w, "%s\n", e)
}
} else if err != nil {
fmt.Fprintf(w, "%s\n", err)
}
}
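// Illustrative sketch (not part of the original source): collecting,
// sorting, and reporting errors with an ErrorList. The positions and
// messages are made up for the example.
//
//	var list ErrorList
//	list.Add(token.Position{Filename: "f.go", Line: 2, Column: 1}, "unexpected '}'")
//	list.Add(token.Position{Filename: "f.go", Line: 1, Column: 5}, "missing '('")
//	list.Sort()                       // order by position
//	PrintError(os.Stderr, list.Err()) // one error per line; Err is nil if empty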
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package scanner implements a scanner for Go source text.
// It takes a []byte as source which can then be tokenized
// through repeated calls to the Scan method.
package scanner
import (
"bytes"
"fmt"
"go/token"
"path/filepath"
"strconv"
"unicode"
"unicode/utf8"
)
// An ErrorHandler may be provided to Scanner.Init. If a syntax error is
// encountered and a handler was installed, the handler is called with a
// position and an error message. The position points to the beginning of
// the offending token.
type ErrorHandler func(pos token.Position, msg string)
// A Scanner holds the scanner's internal state while processing
// a given text. It can be allocated as part of another data
// structure but must be initialized via Init before use.
type Scanner struct {
// immutable state
file *token.File // source file handle
dir string // directory portion of file.Name()
src []byte // source
err ErrorHandler // error reporting; or nil
mode Mode // scanning mode
// scanning state
ch rune // current character
offset int // character offset
rdOffset int // reading offset (position after current character)
lineOffset int // current line offset
insertSemi bool // insert a semicolon before next newline
nlPos token.Pos // position of newline in preceding comment
// public state - ok to modify
ErrorCount int // number of errors encountered
}
const (
bom = 0xFEFF // byte order mark, only permitted as very first character
eof = -1 // end of file
)
// Read the next Unicode char into s.ch.
// s.ch < 0 means end-of-file.
//
// For optimization, there is some overlap between this method and
// s.scanIdentifier.
func (s *Scanner) next() {
if s.rdOffset < len(s.src) {
s.offset = s.rdOffset
if s.ch == '\n' {
s.lineOffset = s.offset
s.file.AddLine(s.offset)
}
r, w := rune(s.src[s.rdOffset]), 1
switch {
case r == 0:
s.error(s.offset, "illegal character NUL")
case r >= utf8.RuneSelf:
// not ASCII
r, w = utf8.DecodeRune(s.src[s.rdOffset:])
if r == utf8.RuneError && w == 1 {
s.error(s.offset, "illegal UTF-8 encoding")
} else if r == bom && s.offset > 0 {
s.error(s.offset, "illegal byte order mark")
}
}
s.rdOffset += w
s.ch = r
} else {
s.offset = len(s.src)
if s.ch == '\n' {
s.lineOffset = s.offset
s.file.AddLine(s.offset)
}
s.ch = eof
}
}
// peek returns the byte following the most recently read character without
// advancing the scanner. If the scanner is at EOF, peek returns 0.
func (s *Scanner) peek() byte {
if s.rdOffset < len(s.src) {
return s.src[s.rdOffset]
}
return 0
}
// A mode value is a set of flags (or 0).
// They control scanner behavior.
type Mode uint
const (
ScanComments Mode = 1 << iota // return comments as COMMENT tokens
dontInsertSemis // do not automatically insert semicolons - for testing only
)
// Init prepares the scanner s to tokenize the text src by setting the
// scanner at the beginning of src. The scanner uses the file set file
// for position information and it adds line information for each line.
// It is ok to re-use the same file when re-scanning the same file, since
// line information that is already present is ignored. Init causes a
// panic if the file size does not match the src size.
//
// Calls to Scan will invoke the error handler err if they encounter a
// syntax error and err is not nil. Also, for each error encountered,
// the Scanner field ErrorCount is incremented by one. The mode parameter
// determines how comments are handled.
//
// Note that Init may call err if there is an error in the first character
// of the file.
func (s *Scanner) Init(file *token.File, src []byte, err ErrorHandler, mode Mode) {
// Explicitly initialize all fields since a scanner may be reused.
if file.Size() != len(src) {
panic(fmt.Sprintf("file size (%d) does not match src len (%d)", file.Size(), len(src)))
}
s.file = file
s.dir, _ = filepath.Split(file.Name())
s.src = src
s.err = err
s.mode = mode
s.ch = ' '
s.offset = 0
s.rdOffset = 0
s.lineOffset = 0
s.insertSemi = false
s.ErrorCount = 0
s.next()
if s.ch == bom {
s.next() // ignore BOM at file beginning
}
}
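// Illustrative usage sketch (not part of the original source; src is a
// hypothetical []byte of Go source): the typical Init/Scan loop.
//
//	var s Scanner
//	fset := token.NewFileSet()
//	file := fset.AddFile("example.go", fset.Base(), len(src))
//	s.Init(file, src, nil /* no error handler */, ScanComments)
//	for {
//		pos, tok, lit := s.Scan()
//		if tok == token.EOF {
//			break
//		}
//		fmt.Printf("%s\t%s\t%q\n", fset.Position(pos), tok, lit)
//	}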
func (s *Scanner) error(offs int, msg string) {
if s.err != nil {
s.err(s.file.Position(s.file.Pos(offs)), msg)
}
s.ErrorCount++
}
func (s *Scanner) errorf(offs int, format string, args ...any) {
s.error(offs, fmt.Sprintf(format, args...))
}
// scanComment returns the text of the comment and (if nonzero)
// the offset of the first newline within it, which implies a
// /*...*/ comment.
func (s *Scanner) scanComment() (string, int) {
// initial '/' already consumed; s.ch == '/' || s.ch == '*'
offs := s.offset - 1 // position of initial '/'
next := -1 // position immediately following the comment; < 0 means invalid comment
numCR := 0
nlOffset := 0 // offset of first newline within /*...*/ comment
if s.ch == '/' {
//-style comment
// (the final '\n' is not considered part of the comment)
s.next()
for s.ch != '\n' && s.ch >= 0 {
if s.ch == '\r' {
numCR++
}
s.next()
}
// if we are at '\n', the position following the comment is afterwards
next = s.offset
if s.ch == '\n' {
next++
}
goto exit
}
/*-style comment */
s.next()
for s.ch >= 0 {
ch := s.ch
if ch == '\r' {
numCR++
} else if ch == '\n' && nlOffset == 0 {
nlOffset = s.offset
}
s.next()
if ch == '*' && s.ch == '/' {
s.next()
next = s.offset
goto exit
}
}
s.error(offs, "comment not terminated")
exit:
lit := s.src[offs:s.offset]
// On Windows, a (//-comment) line may end in "\r\n".
// Remove the final '\r' before analyzing the text for
// line directives (matching the compiler). Remove any
// other '\r' afterwards (matching the pre-existing
// behavior of the scanner).
if numCR > 0 && len(lit) >= 2 && lit[1] == '/' && lit[len(lit)-1] == '\r' {
lit = lit[:len(lit)-1]
numCR--
}
// interpret line directives
// (//line directives must start at the beginning of the current line)
if next >= 0 /* implies valid comment */ && (lit[1] == '*' || offs == s.lineOffset) && bytes.HasPrefix(lit[2:], prefix) {
s.updateLineInfo(next, offs, lit)
}
if numCR > 0 {
lit = stripCR(lit, lit[1] == '*')
}
return string(lit), nlOffset
}
var prefix = []byte("line ")
// updateLineInfo parses the incoming comment text at offset offs
// as a line directive. If successful, it updates the line info table
// for the position next per the line directive.
func (s *Scanner) updateLineInfo(next, offs int, text []byte) {
// extract comment text
if text[1] == '*' {
text = text[:len(text)-2] // lop off trailing "*/"
}
text = text[7:] // lop off leading "//line " or "/*line "
offs += 7
i, n, ok := trailingDigits(text)
if i == 0 {
return // ignore (not a line directive)
}
// i > 0
if !ok {
// text has a suffix :xxx but xxx is not a number
s.error(offs+i, "invalid line number: "+string(text[i:]))
return
}
var line, col int
i2, n2, ok2 := trailingDigits(text[:i-1])
if ok2 {
//line filename:line:col
i, i2 = i2, i
line, col = n2, n
if col == 0 {
s.error(offs+i2, "invalid column number: "+string(text[i2:]))
return
}
text = text[:i2-1] // lop off ":col"
} else {
//line filename:line
line = n
}
if line == 0 {
s.error(offs+i, "invalid line number: "+string(text[i:]))
return
}
// If we have a column (//line filename:line:col form),
// an empty filename means to use the previous filename.
filename := string(text[:i-1]) // lop off ":line"
if filename == "" && ok2 {
filename = s.file.Position(s.file.Pos(offs)).Filename
} else if filename != "" {
// Put a relative filename in the current directory.
// This is for compatibility with earlier releases.
// See issue 26671.
filename = filepath.Clean(filename)
if !filepath.IsAbs(filename) {
filename = filepath.Join(s.dir, filename)
}
}
s.file.AddLineColumnInfo(next, filename, line, col)
}
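// For illustration (derived from the parsing above), the accepted
// directive forms are
//
//	//line filename:line
//	//line filename:line:col
//	/*line filename:line:col*/
//
// where a //line form must start at the beginning of a line. For example,
// "//line gen.go:10:5" makes the position immediately following the
// directive report file gen.go, line 10, column 5; a relative filename is
// joined with the scanned file's directory (s.dir above).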
func trailingDigits(text []byte) (int, int, bool) {
i := bytes.LastIndexByte(text, ':') // look from right (Windows filenames may contain ':')
if i < 0 {
return 0, 0, false // no ":"
}
// i >= 0
n, err := strconv.ParseUint(string(text[i+1:]), 10, 0)
return i + 1, int(n), err == nil
}
func isLetter(ch rune) bool {
return 'a' <= lower(ch) && lower(ch) <= 'z' || ch == '_' || ch >= utf8.RuneSelf && unicode.IsLetter(ch)
}
func isDigit(ch rune) bool {
return isDecimal(ch) || ch >= utf8.RuneSelf && unicode.IsDigit(ch)
}
// scanIdentifier reads the string of valid identifier characters at s.offset.
// It must only be called when s.ch is known to be a valid letter.
//
// Be careful when making changes to this function: it is optimized and affects
// scanning performance significantly.
func (s *Scanner) scanIdentifier() string {
offs := s.offset
// Optimize for the common case of an ASCII identifier.
//
// Ranging over s.src[s.rdOffset:] lets us avoid some bounds checks, and
// avoids conversions to runes.
//
// In case we encounter a non-ASCII character, fall back on the slower path
// of calling into s.next().
for rdOffset, b := range s.src[s.rdOffset:] {
if 'a' <= b && b <= 'z' || 'A' <= b && b <= 'Z' || b == '_' || '0' <= b && b <= '9' {
// Avoid assigning a rune for the common case of an ascii character.
continue
}
s.rdOffset += rdOffset
if 0 < b && b < utf8.RuneSelf {
// Optimization: we've encountered an ASCII character that's not a letter
// or number. Avoid the call into s.next() and corresponding set up.
//
// Note that s.next() does some line accounting if s.ch is '\n', so this
// shortcut is only possible because we know that the preceding character
// is not '\n'.
s.ch = rune(b)
s.offset = s.rdOffset
s.rdOffset++
goto exit
}
// We know that the preceding character is valid for an identifier because
// scanIdentifier is only called when s.ch is a letter, so calling s.next()
// at s.rdOffset resets the scanner state.
s.next()
for isLetter(s.ch) || isDigit(s.ch) {
s.next()
}
goto exit
}
s.offset = len(s.src)
s.rdOffset = len(s.src)
s.ch = eof
exit:
return string(s.src[offs:s.offset])
}
func digitVal(ch rune) int {
switch {
case '0' <= ch && ch <= '9':
return int(ch - '0')
case 'a' <= lower(ch) && lower(ch) <= 'f':
return int(lower(ch) - 'a' + 10)
}
return 16 // larger than any legal digit val
}
func lower(ch rune) rune { return ('a' - 'A') | ch } // returns lower-case ch iff ch is ASCII letter
func isDecimal(ch rune) bool { return '0' <= ch && ch <= '9' }
func isHex(ch rune) bool { return '0' <= ch && ch <= '9' || 'a' <= lower(ch) && lower(ch) <= 'f' }
// digits accepts the sequence { digit | '_' }.
// If base <= 10, digits accepts any decimal digit but records
// the offset (relative to the source start) of a digit >= base
// in *invalid, if *invalid < 0.
// digits returns a bitset describing whether the sequence contained
// digits (bit 0 is set), or separators '_' (bit 1 is set).
func (s *Scanner) digits(base int, invalid *int) (digsep int) {
if base <= 10 {
max := rune('0' + base)
for isDecimal(s.ch) || s.ch == '_' {
ds := 1
if s.ch == '_' {
ds = 2
} else if s.ch >= max && *invalid < 0 {
*invalid = s.offset // record invalid rune offset
}
digsep |= ds
s.next()
}
} else {
for isHex(s.ch) || s.ch == '_' {
ds := 1
if s.ch == '_' {
ds = 2
}
digsep |= ds
s.next()
}
}
return
}
func (s *Scanner) scanNumber() (token.Token, string) {
offs := s.offset
tok := token.ILLEGAL
base := 10 // number base
prefix := rune(0) // one of 0 (decimal), '0' (0-octal), 'x', 'o', or 'b'
digsep := 0 // bit 0: digit present, bit 1: '_' present
invalid := -1 // index of invalid digit in literal, or < 0
// integer part
if s.ch != '.' {
tok = token.INT
if s.ch == '0' {
s.next()
switch lower(s.ch) {
case 'x':
s.next()
base, prefix = 16, 'x'
case 'o':
s.next()
base, prefix = 8, 'o'
case 'b':
s.next()
base, prefix = 2, 'b'
default:
base, prefix = 8, '0'
digsep = 1 // leading 0
}
}
digsep |= s.digits(base, &invalid)
}
// fractional part
if s.ch == '.' {
tok = token.FLOAT
if prefix == 'o' || prefix == 'b' {
s.error(s.offset, "invalid radix point in "+litname(prefix))
}
s.next()
digsep |= s.digits(base, &invalid)
}
if digsep&1 == 0 {
s.error(s.offset, litname(prefix)+" has no digits")
}
// exponent
if e := lower(s.ch); e == 'e' || e == 'p' {
switch {
case e == 'e' && prefix != 0 && prefix != '0':
s.errorf(s.offset, "%q exponent requires decimal mantissa", s.ch)
case e == 'p' && prefix != 'x':
s.errorf(s.offset, "%q exponent requires hexadecimal mantissa", s.ch)
}
s.next()
tok = token.FLOAT
if s.ch == '+' || s.ch == '-' {
s.next()
}
ds := s.digits(10, nil)
digsep |= ds
if ds&1 == 0 {
s.error(s.offset, "exponent has no digits")
}
} else if prefix == 'x' && tok == token.FLOAT {
s.error(s.offset, "hexadecimal mantissa requires a 'p' exponent")
}
// suffix 'i'
if s.ch == 'i' {
tok = token.IMAG
s.next()
}
lit := string(s.src[offs:s.offset])
if tok == token.INT && invalid >= 0 {
s.errorf(invalid, "invalid digit %q in %s", lit[invalid-offs], litname(prefix))
}
if digsep&2 != 0 {
if i := invalidSep(lit); i >= 0 {
s.error(offs+i, "'_' must separate successive digits")
}
}
return tok, lit
}
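// Illustration (not in the original source): some literals and the token
// scanNumber reports for them.
//
//	"42"     -> token.INT   (decimal)
//	"0600"   -> token.INT   (0-octal)
//	"0o600"  -> token.INT   (octal prefix 'o')
//	"0b101"  -> token.INT   (binary prefix 'b')
//	"1_000"  -> token.INT   ('_' digit separators)
//	"3.14"   -> token.FLOAT
//	"0x1p-2" -> token.FLOAT (hex mantissa, 'p' exponent)
//	"2i"     -> token.IMAG  (imaginary suffix)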
func litname(prefix rune) string {
switch prefix {
case 'x':
return "hexadecimal literal"
case 'o', '0':
return "octal literal"
case 'b':
return "binary literal"
}
return "decimal literal"
}
// invalidSep returns the index of the first invalid separator in x, or -1.
func invalidSep(x string) int {
x1 := ' ' // prefix char, we only care if it's 'x'
d := '.' // digit, one of '_', '0' (a digit), or '.' (anything else)
i := 0
// a prefix counts as a digit
if len(x) >= 2 && x[0] == '0' {
x1 = lower(rune(x[1]))
if x1 == 'x' || x1 == 'o' || x1 == 'b' {
d = '0'
i = 2
}
}
// mantissa and exponent
for ; i < len(x); i++ {
p := d // previous digit
d = rune(x[i])
switch {
case d == '_':
if p != '0' {
return i
}
case isDecimal(d) || x1 == 'x' && isHex(d):
d = '0'
default:
if p == '_' {
return i - 1
}
d = '.'
}
}
if d == '_' {
return len(x) - 1
}
return -1
}
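// Illustration (not in the original source): separator validity per the
// rules enforced above.
//
//	"1_000" valid   ('_' between digits)
//	"0x_1"  valid   (a prefix counts as a digit)
//	"1__0"  invalid (consecutive '_')
//	"1000_" invalid (trailing '_')
//	"1_.5"  invalid ('_' must separate successive digits)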
// scanEscape parses an escape sequence where rune is the accepted
// escaped quote. In case of a syntax error, it stops at the offending
// character (without consuming it) and returns false. Otherwise
// it returns true.
func (s *Scanner) scanEscape(quote rune) bool {
offs := s.offset
var n int
var base, max uint32
switch s.ch {
case 'a', 'b', 'f', 'n', 'r', 't', 'v', '\\', quote:
s.next()
return true
case '0', '1', '2', '3', '4', '5', '6', '7':
n, base, max = 3, 8, 255
case 'x':
s.next()
n, base, max = 2, 16, 255
case 'u':
s.next()
n, base, max = 4, 16, unicode.MaxRune
case 'U':
s.next()
n, base, max = 8, 16, unicode.MaxRune
default:
msg := "unknown escape sequence"
if s.ch < 0 {
msg = "escape sequence not terminated"
}
s.error(offs, msg)
return false
}
var x uint32
for n > 0 {
d := uint32(digitVal(s.ch))
if d >= base {
msg := fmt.Sprintf("illegal character %#U in escape sequence", s.ch)
if s.ch < 0 {
msg = "escape sequence not terminated"
}
s.error(s.offset, msg)
return false
}
x = x*base + d
s.next()
n--
}
if x > max || 0xD800 <= x && x < 0xE000 {
s.error(offs, "escape sequence is invalid Unicode code point")
return false
}
return true
}
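// Illustration (not in the original source): the escape forms accepted
// above, shown for quote == '\'':
//
//	\n  \t  \\  \'  single-character escapes
//	\101            3 octal digits, value <= 255
//	\x41            2 hex digits,   value <= 255
//	\u00e9          4 hex digits,   any valid code point
//	\U0001F600      8 hex digits,   any valid code point
//
// Values in the surrogate range [0xD800, 0xE000) are rejected.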
func (s *Scanner) scanRune() string {
// '\'' opening already consumed
offs := s.offset - 1
valid := true
n := 0
for {
ch := s.ch
if ch == '\n' || ch < 0 {
// only report error if we don't have one already
if valid {
s.error(offs, "rune literal not terminated")
valid = false
}
break
}
s.next()
if ch == '\'' {
break
}
n++
if ch == '\\' {
if !s.scanEscape('\'') {
valid = false
}
// continue to read to closing quote
}
}
if valid && n != 1 {
s.error(offs, "illegal rune literal")
}
return string(s.src[offs:s.offset])
}
func (s *Scanner) scanString() string {
// '"' opening already consumed
offs := s.offset - 1
for {
ch := s.ch
if ch == '\n' || ch < 0 {
s.error(offs, "string literal not terminated")
break
}
s.next()
if ch == '"' {
break
}
if ch == '\\' {
s.scanEscape('"')
}
}
return string(s.src[offs:s.offset])
}
func stripCR(b []byte, comment bool) []byte {
c := make([]byte, len(b))
i := 0
for j, ch := range b {
// In a /*-style comment, don't strip \r from *\r/ (incl.
// sequences of \r from *\r\r...\r/) since the resulting
// */ would terminate the comment too early unless the \r
// is immediately following the opening /* in which case
// it's ok because /*/ is not closed yet (issue #11151).
if ch != '\r' || comment && i > len("/*") && c[i-1] == '*' && j+1 < len(b) && b[j+1] == '/' {
c[i] = ch
i++
}
}
return c[:i]
}
func (s *Scanner) scanRawString() string {
// '`' opening already consumed
offs := s.offset - 1
hasCR := false
for {
ch := s.ch
if ch < 0 {
s.error(offs, "raw string literal not terminated")
break
}
s.next()
if ch == '`' {
break
}
if ch == '\r' {
hasCR = true
}
}
lit := s.src[offs:s.offset]
if hasCR {
lit = stripCR(lit, false)
}
return string(lit)
}
func (s *Scanner) skipWhitespace() {
for s.ch == ' ' || s.ch == '\t' || s.ch == '\n' && !s.insertSemi || s.ch == '\r' {
s.next()
}
}
// Helper functions for scanning multi-byte tokens such as >> += >>= .
// Different routines recognize tokens of different lengths (tok_i) based
// on matches of subsequent characters (ch_i). If a token ends in '=', the
// result is tok1 or tok3
// respectively. Otherwise, the result is tok0 if there was no other
// matching character, or tok2 if the matching character was ch2.
func (s *Scanner) switch2(tok0, tok1 token.Token) token.Token {
if s.ch == '=' {
s.next()
return tok1
}
return tok0
}
func (s *Scanner) switch3(tok0, tok1 token.Token, ch2 rune, tok2 token.Token) token.Token {
if s.ch == '=' {
s.next()
return tok1
}
if s.ch == ch2 {
s.next()
return tok2
}
return tok0
}
func (s *Scanner) switch4(tok0, tok1 token.Token, ch2 rune, tok2, tok3 token.Token) token.Token {
if s.ch == '=' {
s.next()
return tok1
}
if s.ch == ch2 {
s.next()
if s.ch == '=' {
s.next()
return tok3
}
return tok2
}
return tok0
}
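// Illustration (not in the original source): for '>', the call
// s.switch4(token.GTR, token.GEQ, '>', token.SHR, token.SHR_ASSIGN)
// resolves ">" to GTR, ">=" to GEQ, ">>" to SHR, and ">>=" to SHR_ASSIGN.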
// Scan scans the next token and returns the token position, the token,
// and its literal string if applicable. The source end is indicated by
// token.EOF.
//
// If the returned token is a literal (token.IDENT, token.INT, token.FLOAT,
// token.IMAG, token.CHAR, token.STRING) or token.COMMENT, the literal string
// has the corresponding value.
//
// If the returned token is a keyword, the literal string is the keyword.
//
// If the returned token is token.SEMICOLON, the corresponding
// literal string is ";" if the semicolon was present in the source,
// and "\n" if the semicolon was inserted because of a newline or
// at EOF.
//
// If the returned token is token.ILLEGAL, the literal string is the
// offending character.
//
// In all other cases, Scan returns an empty literal string.
//
// For more tolerant parsing, Scan will return a valid token if
// possible even if a syntax error was encountered. Thus, even
// if the resulting token sequence contains no illegal tokens,
// a client may not assume that no error occurred. Instead it
// must check the scanner's ErrorCount or the number of calls
// of the error handler, if there was one installed.
//
// Scan adds line information to the file added to the file
// set with Init. Token positions are relative to that file
// and thus relative to the file set.
func (s *Scanner) Scan() (pos token.Pos, tok token.Token, lit string) {
scanAgain:
if s.nlPos.IsValid() {
// Return artificial ';' token after /*...*/ comment
// containing newline, at position of first newline.
pos, tok, lit = s.nlPos, token.SEMICOLON, "\n"
s.nlPos = token.NoPos
return
}
s.skipWhitespace()
// current token start
pos = s.file.Pos(s.offset)
// determine token value
insertSemi := false
switch ch := s.ch; {
case isLetter(ch):
lit = s.scanIdentifier()
if len(lit) > 1 {
// keywords are longer than one letter - avoid lookup otherwise
tok = token.Lookup(lit)
switch tok {
case token.IDENT, token.BREAK, token.CONTINUE, token.FALLTHROUGH, token.RETURN:
insertSemi = true
}
} else {
insertSemi = true
tok = token.IDENT
}
case isDecimal(ch) || ch == '.' && isDecimal(rune(s.peek())):
insertSemi = true
tok, lit = s.scanNumber()
default:
s.next() // always make progress
switch ch {
case eof:
if s.insertSemi {
s.insertSemi = false // EOF consumed
return pos, token.SEMICOLON, "\n"
}
tok = token.EOF
case '\n':
// we only reach here if s.insertSemi was
// set in the first place, in which case
// s.skipWhitespace() stopped at the newline
s.insertSemi = false // newline consumed
return pos, token.SEMICOLON, "\n"
case '"':
insertSemi = true
tok = token.STRING
lit = s.scanString()
case '\'':
insertSemi = true
tok = token.CHAR
lit = s.scanRune()
case '`':
insertSemi = true
tok = token.STRING
lit = s.scanRawString()
case ':':
tok = s.switch2(token.COLON, token.DEFINE)
case '.':
// fractions starting with a '.' are handled by outer switch
tok = token.PERIOD
if s.ch == '.' && s.peek() == '.' {
s.next()
s.next() // consume last '.'
tok = token.ELLIPSIS
}
case ',':
tok = token.COMMA
case ';':
tok = token.SEMICOLON
lit = ";"
case '(':
tok = token.LPAREN
case ')':
insertSemi = true
tok = token.RPAREN
case '[':
tok = token.LBRACK
case ']':
insertSemi = true
tok = token.RBRACK
case '{':
tok = token.LBRACE
case '}':
insertSemi = true
tok = token.RBRACE
case '+':
tok = s.switch3(token.ADD, token.ADD_ASSIGN, '+', token.INC)
if tok == token.INC {
insertSemi = true
}
case '-':
tok = s.switch3(token.SUB, token.SUB_ASSIGN, '-', token.DEC)
if tok == token.DEC {
insertSemi = true
}
case '*':
tok = s.switch2(token.MUL, token.MUL_ASSIGN)
case '/':
if s.ch == '/' || s.ch == '*' {
// comment
comment, nlOffset := s.scanComment()
if s.insertSemi && nlOffset != 0 {
// For /*...*/ containing \n, return
// COMMENT then artificial SEMICOLON.
s.nlPos = s.file.Pos(nlOffset)
s.insertSemi = false
} else {
insertSemi = s.insertSemi // preserve insertSemi info
}
if s.mode&ScanComments == 0 {
// skip comment
goto scanAgain
}
tok = token.COMMENT
lit = comment
} else {
// division
tok = s.switch2(token.QUO, token.QUO_ASSIGN)
}
case '%':
tok = s.switch2(token.REM, token.REM_ASSIGN)
case '^':
tok = s.switch2(token.XOR, token.XOR_ASSIGN)
case '<':
if s.ch == '-' {
s.next()
tok = token.ARROW
} else {
tok = s.switch4(token.LSS, token.LEQ, '<', token.SHL, token.SHL_ASSIGN)
}
case '>':
tok = s.switch4(token.GTR, token.GEQ, '>', token.SHR, token.SHR_ASSIGN)
case '=':
tok = s.switch2(token.ASSIGN, token.EQL)
case '!':
tok = s.switch2(token.NOT, token.NEQ)
case '&':
if s.ch == '^' {
s.next()
tok = s.switch2(token.AND_NOT, token.AND_NOT_ASSIGN)
} else {
tok = s.switch3(token.AND, token.AND_ASSIGN, '&', token.LAND)
}
case '|':
tok = s.switch3(token.OR, token.OR_ASSIGN, '|', token.LOR)
case '~':
tok = token.TILDE
default:
// next reports unexpected BOMs - don't repeat
if ch != bom {
s.errorf(s.file.Offset(pos), "illegal character %#U", ch)
}
insertSemi = s.insertSemi // preserve insertSemi info
tok = token.ILLEGAL
lit = string(ch)
}
}
if s.mode&dontInsertSemis == 0 {
s.insertSemi = insertSemi
}
return
}
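// Illustration (not in the original source): automatic semicolon insertion
// as implemented above. Scanning the line
//
//	x := 1
//
// yields IDENT("x"), DEFINE, INT("1"), and then SEMICOLON with literal
// "\n": the INT sets insertSemi, so the following newline produces an
// artificial semicolon.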
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package token
import (
"fmt"
"sort"
"sync"
"sync/atomic"
)
// -----------------------------------------------------------------------------
// Positions
// Position describes an arbitrary source position
// including the file, line, and column location.
// A Position is valid if the line number is > 0.
type Position struct {
Filename string // filename, if any
Offset int // offset, starting at 0
Line int // line number, starting at 1
Column int // column number, starting at 1 (byte count)
}
// IsValid reports whether the position is valid.
func (pos *Position) IsValid() bool { return pos.Line > 0 }
// String returns a string in one of several forms:
//
// file:line:column valid position with file name
// file:line valid position with file name but no column (column == 0)
// line:column valid position without file name
// line valid position without file name and no column (column == 0)
// file invalid position with file name
// - invalid position without file name
func (pos Position) String() string {
s := pos.Filename
if pos.IsValid() {
if s != "" {
s += ":"
}
s += fmt.Sprintf("%d", pos.Line)
if pos.Column != 0 {
s += fmt.Sprintf(":%d", pos.Column)
}
}
if s == "" {
s = "-"
}
return s
}
// Pos is a compact encoding of a source position within a file set.
// It can be converted into a Position for a more convenient, but much
// larger, representation.
//
// The Pos value for a given file is a number in the range [base, base+size],
// where base and size are specified when a file is added to the file set.
// The difference between a Pos value and the corresponding file base
// corresponds to the byte offset of that position (represented by the Pos value)
// from the beginning of the file. Thus, the file base offset is the Pos value
// representing the first byte in the file.
//
// To create the Pos value for a specific source offset (measured in bytes),
// first add the respective file to the current file set using FileSet.AddFile
// and then call File.Pos(offset) for that file. Given a Pos value p
// for a specific file set fset, the corresponding Position value is
// obtained by calling fset.Position(p).
//
// Pos values can be compared directly with the usual comparison operators:
// If two Pos values p and q are in the same file, comparing p and q is
// equivalent to comparing the respective source file offsets. If p and q
// are in different files, p < q is true if the file implied by p was added
// to the respective file set before the file implied by q.
type Pos int
// The zero value for Pos is NoPos; there is no file and line information
// associated with it, and NoPos.IsValid() is false. NoPos is always
// smaller than any other Pos value. The corresponding Position value
// for NoPos is the zero value for Position.
const NoPos Pos = 0
// IsValid reports whether the position is valid.
func (p Pos) IsValid() bool {
return p != NoPos
}
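// Illustrative sketch (not part of the original source): converting between
// byte offsets, Pos values, and Position values.
//
//	fset := NewFileSet()
//	f := fset.AddFile("a.go", fset.Base(), 100) // covers offsets 0..100
//	p := f.Pos(42)                              // Pos for byte offset 42
//	pos := fset.Position(p)                     // Position{Filename: "a.go", ...}
//	_ = pos.Line                                // 1-based line number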
// -----------------------------------------------------------------------------
// File
// A File is a handle for a file belonging to a FileSet.
// A File has a name, size, and line offset table.
type File struct {
name string // file name as provided to AddFile
base int // Pos value range for this file is [base...base+size]
size int // file size as provided to AddFile
// lines and infos are protected by mutex
mutex sync.Mutex
lines []int // lines contains the offset of the first character for each line (the first entry is always 0)
infos []lineInfo
}
// Name returns the file name of file f as registered with AddFile.
func (f *File) Name() string {
return f.name
}
// Base returns the base offset of file f as registered with AddFile.
func (f *File) Base() int {
return f.base
}
// Size returns the size of file f as registered with AddFile.
func (f *File) Size() int {
return f.size
}
// LineCount returns the number of lines in file f.
func (f *File) LineCount() int {
f.mutex.Lock()
n := len(f.lines)
f.mutex.Unlock()
return n
}
// AddLine adds the line offset for a new line.
// The line offset must be larger than the offset for the previous line
// and smaller than the file size; otherwise the line offset is ignored.
func (f *File) AddLine(offset int) {
f.mutex.Lock()
if i := len(f.lines); (i == 0 || f.lines[i-1] < offset) && offset < f.size {
f.lines = append(f.lines, offset)
}
f.mutex.Unlock()
}
// MergeLine merges a line with the following line. It is akin to replacing
// the newline character at the end of the line with a space (to not change the
// remaining offsets). To obtain the line number, consult e.g. Position.Line.
// MergeLine will panic if given an invalid line number.
func (f *File) MergeLine(line int) {
if line < 1 {
panic(fmt.Sprintf("invalid line number %d (should be >= 1)", line))
}
f.mutex.Lock()
defer f.mutex.Unlock()
if line >= len(f.lines) {
panic(fmt.Sprintf("invalid line number %d (should be < %d)", line, len(f.lines)))
}
// To merge the line numbered <line> with the line numbered <line+1>,
// we need to remove the entry in lines corresponding to the line
// numbered <line+1>. The entry in lines corresponding to the line
// numbered <line+1> is located at index <line>, since indices in lines
// are 0-based and line numbers are 1-based.
copy(f.lines[line:], f.lines[line+1:])
f.lines = f.lines[:len(f.lines)-1]
}
// Lines returns the effective line offset table of the form described by SetLines.
// Callers must not mutate the result.
func (f *File) Lines() []int {
f.mutex.Lock()
lines := f.lines
f.mutex.Unlock()
return lines
}
// SetLines sets the line offsets for a file and reports whether it succeeded.
// The line offsets are the offsets of the first character of each line;
// for instance for the content "ab\nc\n" the line offsets are {0, 3}.
// An empty file has an empty line offset table.
// Each line offset must be larger than the offset for the previous line
// and smaller than the file size; otherwise SetLines fails and returns
// false.
// Callers must not mutate the provided slice after SetLines returns.
func (f *File) SetLines(lines []int) bool {
// verify validity of lines table
size := f.size
for i, offset := range lines {
if i > 0 && offset <= lines[i-1] || size <= offset {
return false
}
}
// set lines table
f.mutex.Lock()
f.lines = lines
f.mutex.Unlock()
return true
}
// SetLinesForContent sets the line offsets for the given file content.
// It ignores position-altering //line comments.
func (f *File) SetLinesForContent(content []byte) {
var lines []int
line := 0
for offset, b := range content {
if line >= 0 {
lines = append(lines, line)
}
line = -1
if b == '\n' {
line = offset + 1
}
}
// set lines table
f.mutex.Lock()
f.lines = lines
f.mutex.Unlock()
}
// LineStart returns the Pos value of the start of the specified line.
// It ignores any alternative positions set using AddLineColumnInfo.
// LineStart panics if the 1-based line number is invalid.
func (f *File) LineStart(line int) Pos {
if line < 1 {
panic(fmt.Sprintf("invalid line number %d (should be >= 1)", line))
}
f.mutex.Lock()
defer f.mutex.Unlock()
if line > len(f.lines) {
panic(fmt.Sprintf("invalid line number %d (should be < %d)", line, len(f.lines)))
}
return Pos(f.base + f.lines[line-1])
}
// A lineInfo object describes alternative file, line, and column
// number information (such as provided via a //line directive)
// for a given file offset.
type lineInfo struct {
// fields are exported to make them accessible to gob
Offset int
Filename string
Line, Column int
}
// AddLineInfo is like AddLineColumnInfo with a column = 1 argument.
// It exists for backward compatibility with code written before Go 1.11.
func (f *File) AddLineInfo(offset int, filename string, line int) {
f.AddLineColumnInfo(offset, filename, line, 1)
}
// AddLineColumnInfo adds alternative file, line, and column number
// information for a given file offset. The offset must be larger
// than the offset for the previously added alternative line info
// and smaller than the file size; otherwise the information is
// ignored.
//
// AddLineColumnInfo is typically used to register alternative position
// information for line directives such as //line filename:line:column.
func (f *File) AddLineColumnInfo(offset int, filename string, line, column int) {
f.mutex.Lock()
if i := len(f.infos); (i == 0 || f.infos[i-1].Offset < offset) && offset < f.size {
f.infos = append(f.infos, lineInfo{offset, filename, line, column})
}
f.mutex.Unlock()
}
// Pos returns the Pos value for the given file offset;
// the offset must be <= f.Size().
// f.Pos(f.Offset(p)) == p.
func (f *File) Pos(offset int) Pos {
if offset > f.size {
panic(fmt.Sprintf("invalid file offset %d (should be <= %d)", offset, f.size))
}
return Pos(f.base + offset)
}
// Offset returns the offset for the given file position p;
// p must be a valid Pos value in that file.
// f.Offset(f.Pos(offset)) == offset.
func (f *File) Offset(p Pos) int {
if int(p) < f.base || int(p) > f.base+f.size {
panic(fmt.Sprintf("invalid Pos value %d (should be in [%d, %d])", p, f.base, f.base+f.size))
}
return int(p) - f.base
}
// Line returns the line number for the given file position p;
// p must be a Pos value in that file or NoPos.
func (f *File) Line(p Pos) int {
return f.Position(p).Line
}
func searchLineInfos(a []lineInfo, x int) int {
return sort.Search(len(a), func(i int) bool { return a[i].Offset > x }) - 1
}
// unpack returns the filename and line and column number for a file offset.
// If adjusted is set, unpack will return the filename and line information
// possibly adjusted by //line comments; otherwise those comments are ignored.
func (f *File) unpack(offset int, adjusted bool) (filename string, line, column int) {
f.mutex.Lock()
filename = f.name
if i := searchInts(f.lines, offset); i >= 0 {
line, column = i+1, offset-f.lines[i]+1
}
if adjusted && len(f.infos) > 0 {
// few files have extra line infos
if i := searchLineInfos(f.infos, offset); i >= 0 {
alt := &f.infos[i]
filename = alt.Filename
if i := searchInts(f.lines, alt.Offset); i >= 0 {
// i+1 is the line at which the alternative position was recorded
d := line - (i + 1) // line distance from alternative position base
line = alt.Line + d
if alt.Column == 0 {
// alternative column is unknown => relative column is unknown
// (the current specification for line directives requires
// this to apply until the next PosBase/line directive,
// not just until the next newline)
column = 0
} else if d == 0 {
// the alternative position base is on the current line
// => column is relative to alternative column
column = alt.Column + (offset - alt.Offset)
}
}
}
}
// TODO(mvdan): move Unlock back under Lock with a defer statement once
// https://go.dev/issue/38471 is fixed to remove the performance penalty.
f.mutex.Unlock()
return
}
func (f *File) position(p Pos, adjusted bool) (pos Position) {
offset := int(p) - f.base
pos.Offset = offset
pos.Filename, pos.Line, pos.Column = f.unpack(offset, adjusted)
return
}
// PositionFor returns the Position value for the given file position p.
// If adjusted is set, the position may be adjusted by position-altering
// //line comments; otherwise those comments are ignored.
// p must be a Pos value in f or NoPos.
func (f *File) PositionFor(p Pos, adjusted bool) (pos Position) {
if p != NoPos {
if int(p) < f.base || int(p) > f.base+f.size {
panic(fmt.Sprintf("invalid Pos value %d (should be in [%d, %d])", p, f.base, f.base+f.size))
}
pos = f.position(p, adjusted)
}
return
}
// Position returns the Position value for the given file position p.
// Calling f.Position(p) is equivalent to calling f.PositionFor(p, true).
func (f *File) Position(p Pos) (pos Position) {
return f.PositionFor(p, true)
}
// -----------------------------------------------------------------------------
// FileSet
// A FileSet represents a set of source files.
// Methods of file sets are synchronized; multiple goroutines
// may invoke them concurrently.
//
// The byte offsets for each file in a file set are mapped into
// distinct (integer) intervals, one interval [base, base+size]
// per file. Base represents the first byte in the file, and size
// is the corresponding file size. A Pos value is a value in such
// an interval. By determining the interval a Pos value belongs
// to, the file, its file base, and thus the byte offset (position)
// the Pos value is representing can be computed.
//
// When adding a new file, a file base must be provided. That can
// be any integer value that is past the end of any interval of any
// file already in the file set. For convenience, FileSet.Base provides
// such a value, which is simply the end of the Pos interval of the most
// recently added file, plus one. Unless there is a need to extend an
// interval later, FileSet.Base should be used as the argument for
// FileSet.AddFile.
//
// A File may be removed from a FileSet when it is no longer needed.
// This may reduce memory usage in a long-running application.
type FileSet struct {
mutex sync.RWMutex // protects the file set
base int // base offset for the next file
files []*File // list of files in the order added to the set
last atomic.Pointer[File] // cache of last file looked up
}
// NewFileSet creates a new file set.
func NewFileSet() *FileSet {
return &FileSet{
base: 1, // 0 == NoPos
}
}
// Base returns the minimum base offset that must be provided to
// AddFile when adding the next file.
func (s *FileSet) Base() int {
s.mutex.RLock()
b := s.base
s.mutex.RUnlock()
return b
}
// AddFile adds a new file with a given filename, base offset, and file size
// to the file set s and returns the file. Multiple files may have the same
// name. The base offset must not be smaller than the FileSet's Base(), and
// size must not be negative. As a special case, if a negative base is provided,
// the current value of the FileSet's Base() is used instead.
//
// Adding the file will set the file set's Base() value to base + size + 1
// as the minimum base value for the next file. The following relationship
// exists between a Pos value p and the corresponding file offset offs:
//
// int(p) = base + offs
//
// with offs in the range [0, size] and thus p in the range [base, base+size].
// For convenience, File.Pos may be used to create file-specific position
// values from a file offset.
func (s *FileSet) AddFile(filename string, base, size int) *File {
// Allocate f outside the critical section.
f := &File{name: filename, size: size, lines: []int{0}}
s.mutex.Lock()
defer s.mutex.Unlock()
if base < 0 {
base = s.base
}
if base < s.base {
panic(fmt.Sprintf("invalid base %d (should be >= %d)", base, s.base))
}
f.base = base
if size < 0 {
panic(fmt.Sprintf("invalid size %d (should be >= 0)", size))
}
// base >= s.base && size >= 0
base += size + 1 // +1 because EOF also has a position
if base < 0 {
panic("token.Pos offset overflow (> 2G of source code in file set)")
}
// add the file to the file set
s.base = base
s.files = append(s.files, f)
s.last.Store(f)
return f
}
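// Illustration (not in the original source) of the base arithmetic above:
// after
//
//	f := fset.AddFile("a.go", -1, 10) // negative base: use fset.Base()
//
// f covers the Pos interval [f.Base(), f.Base()+10], fset.Base() has
// advanced to f.Base()+11 (the +1 accounts for the EOF position), and
// int(f.Pos(offs)) == f.Base()+offs for offs in [0, 10].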
// RemoveFile removes a file from the FileSet so that subsequent
// queries for its Pos interval yield a negative result.
// This reduces the memory usage of a long-lived FileSet that
// encounters an unbounded stream of files.
//
// Removing a file that does not belong to the set has no effect.
func (s *FileSet) RemoveFile(file *File) {
s.last.CompareAndSwap(file, nil) // clear last file cache
s.mutex.Lock()
defer s.mutex.Unlock()
if i := searchFiles(s.files, file.base); i >= 0 && s.files[i] == file {
last := &s.files[len(s.files)-1]
s.files = append(s.files[:i], s.files[i+1:]...)
*last = nil // don't prolong lifetime when popping last element
}
}
// Iterate calls f for the files in the file set in the order they were added
// until f returns false.
func (s *FileSet) Iterate(f func(*File) bool) {
for i := 0; ; i++ {
var file *File
s.mutex.RLock()
if i < len(s.files) {
file = s.files[i]
}
s.mutex.RUnlock()
if file == nil || !f(file) {
break
}
}
}
func searchFiles(a []*File, x int) int {
return sort.Search(len(a), func(i int) bool { return a[i].base > x }) - 1
}
func (s *FileSet) file(p Pos) *File {
// common case: p is in last file.
if f := s.last.Load(); f != nil && f.base <= int(p) && int(p) <= f.base+f.size {
return f
}
s.mutex.RLock()
defer s.mutex.RUnlock()
// p is not in last file - search all files
if i := searchFiles(s.files, int(p)); i >= 0 {
f := s.files[i]
// f.base <= int(p) by definition of searchFiles
if int(p) <= f.base+f.size {
// Update cache of last file. A race is ok,
// but an exclusive lock causes heavy contention.
s.last.Store(f)
return f
}
}
return nil
}
// File returns the file that contains the position p.
// If no such file is found (for instance for p == NoPos),
// the result is nil.
func (s *FileSet) File(p Pos) (f *File) {
if p != NoPos {
f = s.file(p)
}
return
}
// PositionFor converts a Pos p in the fileset into a Position value.
// If adjusted is set, the position may be adjusted by position-altering
// //line comments; otherwise those comments are ignored.
// p must be a Pos value in s or NoPos.
func (s *FileSet) PositionFor(p Pos, adjusted bool) (pos Position) {
if p != NoPos {
if f := s.file(p); f != nil {
return f.position(p, adjusted)
}
}
return
}
// Position converts a Pos p in the fileset into a Position value.
// Calling s.Position(p) is equivalent to calling s.PositionFor(p, true).
func (s *FileSet) Position(p Pos) (pos Position) {
return s.PositionFor(p, true)
}
// -----------------------------------------------------------------------------
// Helper functions
func searchInts(a []int, x int) int {
// This function body is a manually inlined version of:
//
// return sort.Search(len(a), func(i int) bool { return a[i] > x }) - 1
//
// With better compiler optimizations, this may not be needed in the
// future, but at the moment this change improves the go/printer
// benchmark performance by ~30%. This has a direct impact on the
// speed of gofmt and thus seems worthwhile (2011-04-29).
// TODO(gri): Remove this when compilers have caught up.
i, j := 0, len(a)
for i < j {
h := int(uint(i+j) >> 1) // avoid overflow when computing h
// i ≤ h < j
if a[h] <= x {
i = h + 1
} else {
j = h
}
}
return i - 1
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package token
type serializedFile struct {
// fields correspond 1:1 to fields with same (lower-case) name in File
Name string
Base int
Size int
Lines []int
Infos []lineInfo
}
type serializedFileSet struct {
Base int
Files []serializedFile
}
// Read calls decode to deserialize a file set into s; s must not be nil.
func (s *FileSet) Read(decode func(any) error) error {
var ss serializedFileSet
if err := decode(&ss); err != nil {
return err
}
s.mutex.Lock()
s.base = ss.Base
files := make([]*File, len(ss.Files))
for i := 0; i < len(ss.Files); i++ {
f := &ss.Files[i]
files[i] = &File{
name: f.Name,
base: f.Base,
size: f.Size,
lines: f.Lines,
infos: f.Infos,
}
}
s.files = files
s.last.Store(nil)
s.mutex.Unlock()
return nil
}
// Write calls encode to serialize the file set s.
func (s *FileSet) Write(encode func(any) error) error {
var ss serializedFileSet
s.mutex.Lock()
ss.Base = s.base
files := make([]serializedFile, len(s.files))
for i, f := range s.files {
f.mutex.Lock()
files[i] = serializedFile{
Name: f.name,
Base: f.base,
Size: f.size,
Lines: append([]int(nil), f.lines...),
Infos: append([]lineInfo(nil), f.infos...),
}
f.mutex.Unlock()
}
ss.Files = files
s.mutex.Unlock()
return encode(ss)
}
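// Illustrative sketch (not part of the original source): Read and Write
// pair naturally with encoding/gob, whose Encode and Decode methods match
// the func(any) error signatures expected here.
//
//	var buf bytes.Buffer
//	if err := fset.Write(gob.NewEncoder(&buf).Encode); err != nil {
//		log.Fatal(err)
//	}
//	restored := NewFileSet()
//	if err := restored.Read(gob.NewDecoder(&buf).Decode); err != nil {
//		log.Fatal(err)
//	}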
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package token defines constants representing the lexical tokens of the Go
// programming language and basic operations on tokens (printing, predicates).
package token
import (
"strconv"
"unicode"
"unicode/utf8"
)
// Token is the set of lexical tokens of the Go programming language.
type Token int
// The list of tokens.
const (
// Special tokens
ILLEGAL Token = iota
EOF
COMMENT
literal_beg
// Identifiers and basic type literals
// (these tokens stand for classes of literals)
IDENT // main
INT // 12345
FLOAT // 123.45
IMAG // 123.45i
CHAR // 'a'
STRING // "abc"
literal_end
operator_beg
// Operators and delimiters
ADD // +
SUB // -
MUL // *
QUO // /
REM // %
AND // &
OR // |
XOR // ^
SHL // <<
SHR // >>
AND_NOT // &^
ADD_ASSIGN // +=
SUB_ASSIGN // -=
MUL_ASSIGN // *=
QUO_ASSIGN // /=
REM_ASSIGN // %=
AND_ASSIGN // &=
OR_ASSIGN // |=
XOR_ASSIGN // ^=
SHL_ASSIGN // <<=
SHR_ASSIGN // >>=
AND_NOT_ASSIGN // &^=
LAND // &&
LOR // ||
ARROW // <-
INC // ++
DEC // --
EQL // ==
LSS // <
GTR // >
ASSIGN // =
NOT // !
NEQ // !=
LEQ // <=
GEQ // >=
DEFINE // :=
ELLIPSIS // ...
LPAREN // (
LBRACK // [
LBRACE // {
COMMA // ,
PERIOD // .
RPAREN // )
RBRACK // ]
RBRACE // }
SEMICOLON // ;
COLON // :
operator_end
keyword_beg
// Keywords
BREAK
CASE
CHAN
CONST
CONTINUE
DEFAULT
DEFER
ELSE
FALLTHROUGH
FOR
FUNC
GO
GOTO
IF
IMPORT
INTERFACE
MAP
PACKAGE
RANGE
RETURN
SELECT
STRUCT
SWITCH
TYPE
VAR
keyword_end
additional_beg
// additional tokens, handled in an ad-hoc manner
TILDE
additional_end
)
var tokens = [...]string{
ILLEGAL: "ILLEGAL",
EOF: "EOF",
COMMENT: "COMMENT",
IDENT: "IDENT",
INT: "INT",
FLOAT: "FLOAT",
IMAG: "IMAG",
CHAR: "CHAR",
STRING: "STRING",
ADD: "+",
SUB: "-",
MUL: "*",
QUO: "/",
REM: "%",
AND: "&",
OR: "|",
XOR: "^",
SHL: "<<",
SHR: ">>",
AND_NOT: "&^",
ADD_ASSIGN: "+=",
SUB_ASSIGN: "-=",
MUL_ASSIGN: "*=",
QUO_ASSIGN: "/=",
REM_ASSIGN: "%=",
AND_ASSIGN: "&=",
OR_ASSIGN: "|=",
XOR_ASSIGN: "^=",
SHL_ASSIGN: "<<=",
SHR_ASSIGN: ">>=",
AND_NOT_ASSIGN: "&^=",
LAND: "&&",
LOR: "||",
ARROW: "<-",
INC: "++",
DEC: "--",
EQL: "==",
LSS: "<",
GTR: ">",
ASSIGN: "=",
NOT: "!",
NEQ: "!=",
LEQ: "<=",
GEQ: ">=",
DEFINE: ":=",
ELLIPSIS: "...",
LPAREN: "(",
LBRACK: "[",
LBRACE: "{",
COMMA: ",",
PERIOD: ".",
RPAREN: ")",
RBRACK: "]",
RBRACE: "}",
SEMICOLON: ";",
COLON: ":",
BREAK: "break",
CASE: "case",
CHAN: "chan",
CONST: "const",
CONTINUE: "continue",
DEFAULT: "default",
DEFER: "defer",
ELSE: "else",
FALLTHROUGH: "fallthrough",
FOR: "for",
FUNC: "func",
GO: "go",
GOTO: "goto",
IF: "if",
IMPORT: "import",
INTERFACE: "interface",
MAP: "map",
PACKAGE: "package",
RANGE: "range",
RETURN: "return",
SELECT: "select",
STRUCT: "struct",
SWITCH: "switch",
TYPE: "type",
VAR: "var",
TILDE: "~",
}
// String returns the string corresponding to the token tok.
// For operators, delimiters, and keywords the string is the actual
// token character sequence (e.g., for the token ADD, the string is
// "+"). For all other tokens the string corresponds to the token
// constant name (e.g. for the token IDENT, the string is "IDENT").
func (tok Token) String() string {
s := ""
if 0 <= tok && tok < Token(len(tokens)) {
s = tokens[tok]
}
if s == "" {
s = "token(" + strconv.Itoa(int(tok)) + ")"
}
return s
}
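// exampleTokenStrings is an illustrative sketch, not part of the original
// source: String yields the spelling for operators, delimiters, and
// keywords, the constant name for other tokens, and a fallback form for
// out-of-range values.
func exampleTokenStrings() []string {
	return []string{
		ADD.String(),         // "+"
		IDENT.String(),       // "IDENT"
		RETURN.String(),      // "return"
		Token(9999).String(), // "token(9999)"
	}
}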
// A set of constants for precedence-based expression parsing.
// Non-operators have lowest precedence, followed by operators
// starting with precedence 1 up to unary operators. The highest
// precedence serves as "catch-all" precedence for selector,
// indexing, and other operator and delimiter tokens.
const (
LowestPrec = 0 // non-operators
UnaryPrec = 6
HighestPrec = 7
)
// Precedence returns the operator precedence of the binary
// operator op. If op is not a binary operator, the result
// is LowestPrec.
func (op Token) Precedence() int {
switch op {
case LOR:
return 1
case LAND:
return 2
case EQL, NEQ, LSS, LEQ, GTR, GEQ:
return 3
case ADD, SUB, OR, XOR:
return 4
case MUL, QUO, REM, SHL, SHR, AND, AND_NOT:
return 5
}
return LowestPrec
}
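// examplePrecedence is an illustrative sketch, not part of the original
// source: an expression parser can compare Precedence values to decide how
// tightly operators bind. MUL (precedence 5) binds tighter than ADD
// (precedence 4), and non-operators fall back to LowestPrec.
func examplePrecedence() (bool, bool) {
	return MUL.Precedence() > ADD.Precedence(), // true
		IDENT.Precedence() == LowestPrec // true: IDENT is not a binary operator
}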
var keywords map[string]Token
func init() {
keywords = make(map[string]Token, keyword_end-(keyword_beg+1))
for i := keyword_beg + 1; i < keyword_end; i++ {
keywords[tokens[i]] = i
}
}
// Lookup maps an identifier to its keyword token or IDENT (if not a keyword).
func Lookup(ident string) Token {
if tok, is_keyword := keywords[ident]; is_keyword {
return tok
}
return IDENT
}
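// exampleLookup is an illustrative sketch, not part of the original source:
// keyword spellings map to their token; any other identifier maps to IDENT.
func exampleLookup() (Token, Token) {
	return Lookup("func"), Lookup("funky") // FUNC, IDENT
}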
// Predicates
// IsLiteral returns true for tokens corresponding to identifiers
// and basic type literals; it returns false otherwise.
func (tok Token) IsLiteral() bool { return literal_beg < tok && tok < literal_end }
// IsOperator returns true for tokens corresponding to operators and
// delimiters; it returns false otherwise.
func (tok Token) IsOperator() bool {
return (operator_beg < tok && tok < operator_end) || tok == TILDE
}
// IsKeyword returns true for tokens corresponding to keywords;
// it returns false otherwise.
func (tok Token) IsKeyword() bool { return keyword_beg < tok && tok < keyword_end }
// IsExported reports whether name starts with an upper-case letter.
func IsExported(name string) bool {
ch, _ := utf8.DecodeRuneInString(name)
return unicode.IsUpper(ch)
}
// IsKeyword reports whether name is a Go keyword, such as "func" or "return".
func IsKeyword(name string) bool {
// TODO: opt: use a perfect hash function instead of a global map.
_, ok := keywords[name]
return ok
}
// IsIdentifier reports whether name is a Go identifier, that is, a non-empty
// string made up of letters, digits, and underscores, where the first character
// is not a digit. Keywords are not identifiers.
func IsIdentifier(name string) bool {
if name == "" || IsKeyword(name) {
return false
}
for i, c := range name {
if !unicode.IsLetter(c) && c != '_' && (i == 0 || !unicode.IsDigit(c)) {
return false
}
}
return true
}
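// examplePredicates is an illustrative sketch, not part of the original
// source, combining the predicates above.
func examplePredicates() (bool, bool, bool, bool) {
	return RETURN.IsKeyword(), // true
		STRING.IsLiteral(), // true
		IsExported("Println"), // true
		IsIdentifier("go") // false: "go" is a keyword, not an identifier
}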
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package types declares the data types and implements
// the algorithms for type-checking of Go packages. Use
// Config.Check to invoke the type checker for a package.
// Alternatively, create a new type checker with NewChecker
// and invoke it incrementally by calling Checker.Files.
//
// Type-checking consists of several interdependent phases:
//
// Name resolution maps each identifier (ast.Ident) in the program to the
// language object (Object) it denotes.
// Use Info.{Defs,Uses,Implicits} for the results of name resolution.
//
// Constant folding computes the exact constant value (constant.Value)
// for every expression (ast.Expr) that is a compile-time constant.
// Use Info.Types[expr].Value for the results of constant folding.
//
// Type inference computes the type (Type) of every expression (ast.Expr)
// and checks for compliance with the language specification.
// Use Info.Types[expr].Type for the results of type inference.
//
// For a tutorial, see https://golang.org/s/types-tutorial.
package types
import (
"bytes"
"fmt"
"go/ast"
"go/constant"
"go/token"
. "internal/types/errors"
)
// An Error describes a type-checking error; it implements the error interface.
// A "soft" error is an error that still permits a valid interpretation of a
// package (such as "unused variable"); "hard" errors may lead to unpredictable
// behavior if ignored.
type Error struct {
Fset *token.FileSet // file set for interpretation of Pos
Pos token.Pos // error position
Msg string // error message
Soft bool // if set, error is "soft"
// go116code is a future API, unexported as the set of error codes is large
// and likely to change significantly during experimentation. Tools wishing
// to preview this feature may read go116code using reflection (see
// errorcodes_test.go), but beware that there is no guarantee of future
// compatibility.
go116code Code
go116start token.Pos
go116end token.Pos
}
// Error returns an error string formatted as follows:
// filename:line:column: message
func (err Error) Error() string {
return fmt.Sprintf("%s: %s", err.Fset.Position(err.Pos), err.Msg)
}
// An ArgumentError holds an error associated with an argument index.
type ArgumentError struct {
Index int
Err error
}
func (e *ArgumentError) Error() string { return e.Err.Error() }
func (e *ArgumentError) Unwrap() error { return e.Err }
// An Importer resolves import paths to Packages.
//
// CAUTION: This interface does not support the import of locally
// vendored packages. See https://golang.org/s/go15vendor.
// If possible, external implementations should implement ImporterFrom.
type Importer interface {
// Import returns the imported package for the given import path.
// The semantics are the same as for ImporterFrom.ImportFrom,
// except that dir and mode are ignored (since they are not present).
Import(path string) (*Package, error)
}
// ImportMode is reserved for future use.
type ImportMode int
// An ImporterFrom resolves import paths to packages; it
// supports vendoring per https://golang.org/s/go15vendor.
// Use go/importer to obtain an ImporterFrom implementation.
type ImporterFrom interface {
// Importer is present for backward-compatibility. Calling
// Import(path) is the same as calling ImportFrom(path, "", 0);
// i.e., locally vendored packages may not be found.
// The types package does not call Import if an ImporterFrom
// is present.
Importer
// ImportFrom returns the imported package for the given import
// path when imported by a package file located in dir.
// If the import failed, besides returning an error, ImportFrom
// is encouraged to cache and return a package anyway, if one
// was created. This will reduce package inconsistencies and
// follow-on type checker errors due to the missing package.
// The mode value must be 0; it is reserved for future use.
// Two calls to ImportFrom with the same path and dir must
// return the same package.
ImportFrom(path, dir string, mode ImportMode) (*Package, error)
}
// A Config specifies the configuration for type checking.
// The zero value for Config is a ready-to-use default configuration.
type Config struct {
// Context is the context used for resolving global identifiers. If nil, the
// type checker will initialize this field with a newly created context.
Context *Context
// GoVersion describes the accepted Go language version. The string
// must follow the format "go%d.%d" (e.g. "go1.12") or it must be
// empty; an empty string indicates the latest language version.
// If the format is invalid, invoking the type checker will cause a
// panic.
GoVersion string
// If IgnoreFuncBodies is set, function bodies are not
// type-checked.
IgnoreFuncBodies bool
// If FakeImportC is set, `import "C"` (for packages requiring Cgo)
// declares an empty "C" package and errors are omitted for qualified
// identifiers referring to package C (which won't find an object).
// This feature is intended for the standard library cmd/api tool.
//
// Caution: Effects may be unpredictable due to follow-on errors.
// Do not use casually!
FakeImportC bool
// If go115UsesCgo is set, the type checker expects the
// _cgo_gotypes.go file generated by running cmd/cgo to be
// provided as a package source file. Qualified identifiers
// referring to package C will be resolved to cgo-provided
// declarations within _cgo_gotypes.go.
//
// It is an error to set both FakeImportC and go115UsesCgo.
go115UsesCgo bool
// If _Trace is set, a debug trace is printed to stdout.
_Trace bool
// If Error != nil, it is called with each error found
// during type checking; err has dynamic type Error.
// Secondary errors (for instance, to enumerate all types
// involved in an invalid recursive type declaration) have
// error strings that start with a '\t' character.
// If Error == nil, type-checking stops with the first
// error found.
Error func(err error)
// An importer is used to import packages referred to from
// import declarations.
// If the installed importer implements ImporterFrom, the type
// checker calls ImportFrom instead of Import.
// The type checker reports an error if an importer is needed
// but none was installed.
Importer Importer
// If Sizes != nil, it provides the sizing functions for package unsafe.
// Otherwise SizesFor("gc", "amd64") is used instead.
Sizes Sizes
// If DisableUnusedImportCheck is set, packages are not checked
// for unused imports.
DisableUnusedImportCheck bool
}
func srcimporter_setUsesCgo(conf *Config) {
conf.go115UsesCgo = true
}
// Info holds result type information for a type-checked package.
// Only the information for which a map is provided is collected.
// If the package has type errors, the collected information may
// be incomplete.
type Info struct {
// Types maps expressions to their types, and for constant
// expressions, also their values. Invalid expressions are
// omitted.
//
// For (possibly parenthesized) identifiers denoting built-in
// functions, the recorded signatures are call-site specific:
// if the call result is not a constant, the recorded type is
// an argument-specific signature. Otherwise, the recorded type
// is invalid.
//
// The Types map does not record the type of every identifier,
// only those that appear where an arbitrary expression is
// permitted. For instance, the identifier f in a selector
// expression x.f is found only in the Selections map, the
// identifier z in a variable declaration 'var z int' is found
// only in the Defs map, and identifiers denoting packages in
// qualified identifiers are collected in the Uses map.
Types map[ast.Expr]TypeAndValue
// Instances maps identifiers denoting generic types or functions to their
// type arguments and instantiated type.
//
// For example, Instances will map the identifier for 'T' in the type
// instantiation T[int, string] to the type arguments [int, string] and
// resulting instantiated *Named type. Given a generic function
// func F[A any](A), Instances will map the identifier for 'F' in the call
// expression F(int(1)) to the inferred type arguments [int], and resulting
// instantiated *Signature.
//
// Invariant: Instantiating Uses[id].Type() with Instances[id].TypeArgs
// results in an equivalent of Instances[id].Type.
Instances map[*ast.Ident]Instance
// Defs maps identifiers to the objects they define (including
// package names, dots "." of dot-imports, and blank "_" identifiers).
// For identifiers that do not denote objects (e.g., the package name
// in package clauses, or symbolic variables t in t := x.(type) of
// type switch headers), the corresponding objects are nil.
//
// For an embedded field, Defs returns the field *Var it defines.
//
// Invariant: Defs[id] == nil || Defs[id].Pos() == id.Pos()
Defs map[*ast.Ident]Object
// Uses maps identifiers to the objects they denote.
//
// For an embedded field, Uses returns the *TypeName it denotes.
//
// Invariant: Uses[id].Pos() != id.Pos()
Uses map[*ast.Ident]Object
// Implicits maps nodes to their implicitly declared objects, if any.
// The following node and object types may appear:
//
// node declared object
//
// *ast.ImportSpec *PkgName for imports without renames
// *ast.CaseClause type-specific *Var for each type switch case clause (incl. default)
// *ast.Field anonymous parameter *Var (incl. unnamed results)
//
Implicits map[ast.Node]Object
// Selections maps selector expressions (excluding qualified identifiers)
// to their corresponding selections.
Selections map[*ast.SelectorExpr]*Selection
// Scopes maps ast.Nodes to the scopes they define. Package scopes are not
// associated with a specific node but with all files belonging to a package.
// Thus, the package scope can be found in the type-checked Package object.
// Scopes nest, with the Universe scope being the outermost scope, enclosing
// the package scope, which contains (one or more) file scopes, which enclose
// function scopes, which in turn enclose statement and function literal scopes.
// Note that even though package-level functions are declared in the package
// scope, the function scopes are embedded in the file scope of the file
// containing the function declaration.
//
// The following node types may appear in Scopes:
//
// *ast.File
// *ast.FuncType
// *ast.TypeSpec
// *ast.BlockStmt
// *ast.IfStmt
// *ast.SwitchStmt
// *ast.TypeSwitchStmt
// *ast.CaseClause
// *ast.CommClause
// *ast.ForStmt
// *ast.RangeStmt
//
Scopes map[ast.Node]*Scope
// InitOrder is the list of package-level initializers in the order in which
// they must be executed. Initializers referring to variables related by an
// initialization dependency appear in topological order, the others appear
// in source order. Variables without an initialization expression do not
// appear in this list.
InitOrder []*Initializer
}
// TypeOf returns the type of expression e, or nil if not found.
// Precondition: the Types, Uses and Defs maps are populated.
func (info *Info) TypeOf(e ast.Expr) Type {
if t, ok := info.Types[e]; ok {
return t.Type
}
if id, _ := e.(*ast.Ident); id != nil {
if obj := info.ObjectOf(id); obj != nil {
return obj.Type()
}
}
return nil
}
// ObjectOf returns the object denoted by the specified id,
// or nil if not found.
//
// If id is an embedded struct field, ObjectOf returns the field (*Var)
// it defines, not the type (*TypeName) it uses.
//
// Precondition: the Uses and Defs maps are populated.
func (info *Info) ObjectOf(id *ast.Ident) Object {
if obj := info.Defs[id]; obj != nil {
return obj
}
return info.Uses[id]
}
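// exampleObjectOf is an illustrative sketch, not part of the original
// source: with the Defs and Uses maps populated, ObjectOf resolves an
// identifier regardless of whether it is a defining or a using occurrence.
func exampleObjectOf(info *Info, id *ast.Ident) string {
	if obj := info.ObjectOf(id); obj != nil {
		return obj.Name()
	}
	return "<unresolved>"
}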
// TypeAndValue reports the type and value (for constants)
// of the corresponding expression.
type TypeAndValue struct {
mode operandMode
Type Type
Value constant.Value
}
// IsVoid reports whether the corresponding expression
// is a function call without results.
func (tv TypeAndValue) IsVoid() bool {
return tv.mode == novalue
}
// IsType reports whether the corresponding expression specifies a type.
func (tv TypeAndValue) IsType() bool {
return tv.mode == typexpr
}
// IsBuiltin reports whether the corresponding expression denotes
// a (possibly parenthesized) built-in function.
func (tv TypeAndValue) IsBuiltin() bool {
return tv.mode == builtin
}
// IsValue reports whether the corresponding expression is a value.
// Builtins are not considered values. Constant values have a non-
// nil Value.
func (tv TypeAndValue) IsValue() bool {
switch tv.mode {
case constant_, variable, mapindex, value, commaok, commaerr:
return true
}
return false
}
// IsNil reports whether the corresponding expression denotes the
// predeclared value nil.
func (tv TypeAndValue) IsNil() bool {
return tv.mode == value && tv.Type == Typ[UntypedNil]
}
// Addressable reports whether the corresponding expression
// is addressable (https://golang.org/ref/spec#Address_operators).
func (tv TypeAndValue) Addressable() bool {
return tv.mode == variable
}
// Assignable reports whether the corresponding expression
// is assignable to (provided a value of the right type).
func (tv TypeAndValue) Assignable() bool {
return tv.mode == variable || tv.mode == mapindex
}
// HasOk reports whether the corresponding expression may be
// used on the rhs of a comma-ok assignment.
func (tv TypeAndValue) HasOk() bool {
return tv.mode == commaok || tv.mode == mapindex
}
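// exampleClassify is an illustrative sketch, not part of the original
// source: the predicates above let a client classify the TypeAndValue
// recorded for an expression in Info.Types.
func exampleClassify(tv TypeAndValue) string {
	switch {
	case tv.IsVoid():
		return "no value (call of a function without results)"
	case tv.IsType():
		return "type expression"
	case tv.IsBuiltin():
		return "built-in function"
	case tv.IsNil():
		return "predeclared nil"
	case tv.IsValue():
		return "value"
	}
	return "unclassified"
}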
// An Instance holds the type arguments and instantiated type for type and
// function instantiations. For type instantiations, Type will be of dynamic
// type *Named. For function instantiations, Type will be of dynamic type
// *Signature.
type Instance struct {
TypeArgs *TypeList
Type Type
}
// An Initializer describes a package-level variable, or a list of variables in case
// of a multi-valued initialization expression, and the corresponding initialization
// expression.
type Initializer struct {
Lhs []*Var // var Lhs = Rhs
Rhs ast.Expr
}
func (init *Initializer) String() string {
var buf bytes.Buffer
for i, lhs := range init.Lhs {
if i > 0 {
buf.WriteString(", ")
}
buf.WriteString(lhs.Name())
}
buf.WriteString(" = ")
WriteExpr(&buf, init.Rhs)
return buf.String()
}
// Check type-checks a package and returns the resulting package object and
// the first error if any. Additionally, if info != nil, Check populates each
// of the non-nil maps in the Info struct.
//
// The package is marked as complete if no errors occurred, otherwise it is
// incomplete. See Config.Error for controlling behavior in the presence of
// errors.
//
// The package is specified by a list of *ast.Files and corresponding
// file set, and the package path the package is identified with.
// The clean path must not be empty or dot (".").
func (conf *Config) Check(path string, fset *token.FileSet, files []*ast.File, info *Info) (*Package, error) {
pkg := NewPackage(path, "")
return pkg, NewChecker(conf, fset, pkg, info).Files(files)
}
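// A typical invocation from a client package looks roughly like the sketch
// below (go/parser, go/importer, and the file name are assumptions for
// illustration, not part of this file):
//
//	fset := token.NewFileSet()
//	f, err := parser.ParseFile(fset, "hello.go", src, 0)
//	if err != nil {
//		log.Fatal(err)
//	}
//	info := &types.Info{Types: make(map[ast.Expr]types.TypeAndValue)}
//	conf := types.Config{Importer: importer.Default()}
//	pkg, err := conf.Check("hello", fset, []*ast.File{f}, info)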
// AssertableTo reports whether a value of type V can be asserted to have type T.
//
// The behavior of AssertableTo is unspecified in three cases:
// - if T is Typ[Invalid]
// - if V is a generalized interface; i.e., an interface that may only be used
// as a type constraint in Go code
// - if T is an uninstantiated generic type
func AssertableTo(V *Interface, T Type) bool {
// Checker.newAssertableTo suppresses errors for invalid types, so we need special
// handling here.
if T.Underlying() == Typ[Invalid] {
return false
}
return (*Checker)(nil).newAssertableTo(V, T, nil)
}
// AssignableTo reports whether a value of type V is assignable to a variable
// of type T.
//
// The behavior of AssignableTo is unspecified if V or T is Typ[Invalid] or an
// uninstantiated generic type.
func AssignableTo(V, T Type) bool {
x := operand{mode: value, typ: V}
ok, _ := x.assignableTo(nil, T, nil) // check not needed for non-constant x
return ok
}
// ConvertibleTo reports whether a value of type V is convertible to a value of
// type T.
//
// The behavior of ConvertibleTo is unspecified if V or T is Typ[Invalid] or an
// uninstantiated generic type.
func ConvertibleTo(V, T Type) bool {
x := operand{mode: value, typ: V}
return x.convertibleTo(nil, T, nil) // check not needed for non-constant x
}
// Implements reports whether type V implements interface T.
//
// The behavior of Implements is unspecified if V is Typ[Invalid] or an uninstantiated
// generic type.
func Implements(V Type, T *Interface) bool {
if T.Empty() {
// All types (even Typ[Invalid]) implement the empty interface.
return true
}
// Checker.implements suppresses errors for invalid types, so we need special
// handling here.
if V.Underlying() == Typ[Invalid] {
return false
}
return (*Checker)(nil).implements(V, T, false, nil)
}
// Satisfies reports whether type V satisfies the constraint T.
//
// The behavior of Satisfies is unspecified if V is Typ[Invalid] or an uninstantiated
// generic type.
func Satisfies(V Type, T *Interface) bool {
return (*Checker)(nil).implements(V, T, true, nil)
}
// Identical reports whether x and y are identical types.
// Receivers of Signature types are ignored.
func Identical(x, y Type) bool {
var c comparer
return c.identical(x, y, nil)
}
// IdenticalIgnoreTags reports whether x and y are identical types if tags are ignored.
// Receivers of Signature types are ignored.
func IdenticalIgnoreTags(x, y Type) bool {
var c comparer
c.ignoreTags = true
return c.identical(x, y, nil)
}
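// exampleTypePredicates is an illustrative sketch, not part of the original
// source, exercising the exported predicates with predeclared basic types.
func exampleTypePredicates() (bool, bool, bool, bool) {
	return AssignableTo(Typ[Int], Typ[Int]), // true
		ConvertibleTo(Typ[Int], Typ[Float64]), // true
		Identical(Typ[Int], Typ[Int32]), // false: int and int32 are distinct types
		Identical(Typ[Int32], Typ[Rune]) // true: rune is an alias for int32
}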
// Code generated by "go test -run=Generate -write=all"; DO NOT EDIT.
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package types
// An Array represents an array type.
type Array struct {
len int64
elem Type
}
// NewArray returns a new array type for the given element type and length.
// A negative length indicates an unknown length.
func NewArray(elem Type, len int64) *Array { return &Array{len: len, elem: elem} }
// Len returns the length of array a.
// A negative result indicates an unknown length.
func (a *Array) Len() int64 { return a.len }
// Elem returns the element type of array a.
func (a *Array) Elem() Type { return a.elem }
func (a *Array) Underlying() Type { return a }
func (a *Array) String() string { return TypeString(a, nil) }
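// exampleArray is an illustrative sketch, not part of the original source:
// constructing and inspecting the array type [16]byte.
func exampleArray() (int64, Type) {
	a := NewArray(Typ[Byte], 16)
	return a.Len(), a.Elem() // 16, uint8
}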
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements initialization and assignment checks.
package types
import (
"fmt"
"go/ast"
. "internal/types/errors"
"strings"
)
// assignment checks that x can be assigned to a variable of type T,
// if necessary by attempting to convert untyped values to the appropriate
// type. context describes the context in which the assignment takes place.
// Use T == nil to indicate assignment to an untyped blank identifier.
// x.mode is set to invalid if the assignment failed.
func (check *Checker) assignment(x *operand, T Type, context string) {
check.singleValue(x)
switch x.mode {
case invalid:
return // error reported before
case constant_, variable, mapindex, value, commaok, commaerr:
// ok
default:
// we may get here because of other problems (go.dev/issue/39634, crash 12)
// TODO(gri) do we need a new "generic" error code here?
check.errorf(x, IncompatibleAssign, "cannot assign %s to %s in %s", x, T, context)
return
}
if isUntyped(x.typ) {
target := T
// spec: "If an untyped constant is assigned to a variable of interface
// type or the blank identifier, the constant is first converted to type
// bool, rune, int, float64, complex128 or string respectively, depending
// on whether the value is a boolean, rune, integer, floating-point,
// complex, or string constant."
if T == nil || isNonTypeParamInterface(T) {
if T == nil && x.typ == Typ[UntypedNil] {
check.errorf(x, UntypedNilUse, "use of untyped nil in %s", context)
x.mode = invalid
return
}
target = Default(x.typ)
}
newType, val, code := check.implicitTypeAndValue(x, target)
if code != 0 {
msg := check.sprintf("cannot use %s as %s value in %s", x, target, context)
switch code {
case TruncatedFloat:
msg += " (truncated)"
case NumericOverflow:
msg += " (overflows)"
default:
code = IncompatibleAssign
}
check.error(x, code, msg)
x.mode = invalid
return
}
if val != nil {
x.val = val
check.updateExprVal(x.expr, val)
}
if newType != x.typ {
x.typ = newType
check.updateExprType(x.expr, newType, false)
}
}
// A generic (non-instantiated) function value cannot be assigned to a variable.
if sig, _ := under(x.typ).(*Signature); sig != nil && sig.TypeParams().Len() > 0 {
check.errorf(x, WrongTypeArgCount, "cannot use generic function %s without instantiation in %s", x, context)
}
// spec: "If a left-hand side is the blank identifier, any typed or
// non-constant value except for the predeclared identifier nil may
// be assigned to it."
if T == nil {
return
}
cause := ""
if ok, code := x.assignableTo(check, T, &cause); !ok {
if cause != "" {
check.errorf(x, code, "cannot use %s as %s value in %s: %s", x, T, context, cause)
} else {
check.errorf(x, code, "cannot use %s as %s value in %s", x, T, context)
}
x.mode = invalid
}
}
func (check *Checker) initConst(lhs *Const, x *operand) {
if x.mode == invalid || x.typ == Typ[Invalid] || lhs.typ == Typ[Invalid] {
if lhs.typ == nil {
lhs.typ = Typ[Invalid]
}
return
}
// rhs must be a constant
if x.mode != constant_ {
check.errorf(x, InvalidConstInit, "%s is not constant", x)
if lhs.typ == nil {
lhs.typ = Typ[Invalid]
}
return
}
assert(isConstType(x.typ))
// If the lhs doesn't have a type yet, use the type of x.
if lhs.typ == nil {
lhs.typ = x.typ
}
check.assignment(x, lhs.typ, "constant declaration")
if x.mode == invalid {
return
}
lhs.val = x.val
}
func (check *Checker) initVar(lhs *Var, x *operand, context string) Type {
if x.mode == invalid || x.typ == Typ[Invalid] || lhs.typ == Typ[Invalid] {
if lhs.typ == nil {
lhs.typ = Typ[Invalid]
}
return nil
}
// If the lhs doesn't have a type yet, use the type of x.
if lhs.typ == nil {
typ := x.typ
if isUntyped(typ) {
// convert untyped types to default types
if typ == Typ[UntypedNil] {
check.errorf(x, UntypedNilUse, "use of untyped nil in %s", context)
lhs.typ = Typ[Invalid]
return nil
}
typ = Default(typ)
}
lhs.typ = typ
}
check.assignment(x, lhs.typ, context)
if x.mode == invalid {
return nil
}
return x.typ
}
func (check *Checker) assignVar(lhs ast.Expr, x *operand) Type {
if x.mode == invalid || x.typ == Typ[Invalid] {
check.useLHS(lhs)
return nil
}
// Determine if the lhs is a (possibly parenthesized) identifier.
ident, _ := unparen(lhs).(*ast.Ident)
// Don't evaluate lhs if it is the blank identifier.
if ident != nil && ident.Name == "_" {
check.recordDef(ident, nil)
check.assignment(x, nil, "assignment to _ identifier")
if x.mode == invalid {
return nil
}
return x.typ
}
// If the lhs is an identifier denoting a variable v, this assignment
// is not a 'use' of v. Remember current value of v.used and restore
// after evaluating the lhs via check.expr.
var v *Var
var v_used bool
if ident != nil {
if obj := check.lookup(ident.Name); obj != nil {
// It's ok to mark non-local variables, but ignore variables
// from other packages to avoid potential race conditions with
// dot-imported variables.
if w, _ := obj.(*Var); w != nil && w.pkg == check.pkg {
v = w
v_used = v.used
}
}
}
var z operand
check.expr(&z, lhs)
if v != nil {
v.used = v_used // restore v.used
}
if z.mode == invalid || z.typ == Typ[Invalid] {
return nil
}
// spec: "Each left-hand side operand must be addressable, a map index
// expression, or the blank identifier. Operands may be parenthesized."
switch z.mode {
case invalid:
return nil
case variable, mapindex:
// ok
default:
if sel, ok := z.expr.(*ast.SelectorExpr); ok {
var op operand
check.expr(&op, sel.X)
if op.mode == mapindex {
check.errorf(&z, UnaddressableFieldAssign, "cannot assign to struct field %s in map", ExprString(z.expr))
return nil
}
}
check.errorf(&z, UnassignableOperand, "cannot assign to %s", &z)
return nil
}
check.assignment(x, z.typ, "assignment")
if x.mode == invalid {
return nil
}
return x.typ
}
// operandTypes returns the list of types for the given operands.
func operandTypes(list []*operand) (res []Type) {
for _, x := range list {
res = append(res, x.typ)
}
return res
}
// varTypes returns the list of types for the given variables.
func varTypes(list []*Var) (res []Type) {
for _, x := range list {
res = append(res, x.typ)
}
return res
}
// typesSummary returns a string of the form "(t1, t2, ...)" where the
// ti's are user-friendly string representations for the given types.
// If variadic is set and the last type is a slice, its string is of
// the form "...E" where E is the slice's element type.
func (check *Checker) typesSummary(list []Type, variadic bool) string {
var res []string
for i, t := range list {
var s string
switch {
case t == nil:
fallthrough // should not happen but be cautious
case t == Typ[Invalid]:
s = "<T>"
case isUntyped(t):
if isNumeric(t) {
// Do not imply a specific type requirement:
// "have number, want float64" is better than
// "have untyped int, want float64" or
// "have int, want float64".
s = "number"
} else {
// If we don't have a number, omit the "untyped" qualifier
// for compactness.
s = strings.Replace(t.(*Basic).name, "untyped ", "", -1)
}
case variadic && i == len(list)-1:
s = check.sprintf("...%s", t.(*Slice).elem)
}
if s == "" {
s = check.sprintf("%s", t)
}
res = append(res, s)
}
return "(" + strings.Join(res, ", ") + ")"
}
func measure(x int, unit string) string {
if x != 1 {
unit += "s"
}
return fmt.Sprintf("%d %s", x, unit)
}
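// exampleMeasure is an illustrative sketch, not part of the original
// source: measure pluralizes the unit for any count other than one.
func exampleMeasure() (string, string) {
	return measure(1, "variable"), measure(3, "value") // "1 variable", "3 values"
}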
func (check *Checker) assignError(rhs []ast.Expr, nvars, nvals int) {
vars := measure(nvars, "variable")
vals := measure(nvals, "value")
rhs0 := rhs[0]
if len(rhs) == 1 {
if call, _ := unparen(rhs0).(*ast.CallExpr); call != nil {
check.errorf(rhs0, WrongAssignCount, "assignment mismatch: %s but %s returns %s", vars, call.Fun, vals)
return
}
}
check.errorf(rhs0, WrongAssignCount, "assignment mismatch: %s but %s", vars, vals)
}
// If returnStmt != nil, initVars is called to type-check the assignment
// of return expressions, and returnStmt is the return statement.
func (check *Checker) initVars(lhs []*Var, origRHS []ast.Expr, returnStmt ast.Stmt) {
rhs, commaOk := check.exprList(origRHS, len(lhs) == 2 && returnStmt == nil)
if len(lhs) != len(rhs) {
// invalidate lhs
for _, obj := range lhs {
obj.used = true // avoid declared and not used errors
if obj.typ == nil {
obj.typ = Typ[Invalid]
}
}
// don't report an error if we already reported one
for _, x := range rhs {
if x.mode == invalid {
return
}
}
if returnStmt != nil {
var at positioner = returnStmt
qualifier := "not enough"
if len(rhs) > len(lhs) {
at = rhs[len(lhs)].expr // report at first extra value
qualifier = "too many"
} else if len(rhs) > 0 {
at = rhs[len(rhs)-1].expr // report at last value
}
err := newErrorf(at, WrongResultCount, "%s return values", qualifier)
err.errorf(nopos, "have %s", check.typesSummary(operandTypes(rhs), false))
err.errorf(nopos, "want %s", check.typesSummary(varTypes(lhs), false))
check.report(err)
return
}
check.assignError(origRHS, len(lhs), len(rhs))
return
}
context := "assignment"
if returnStmt != nil {
context = "return statement"
}
if commaOk {
var a [2]Type
for i := range a {
a[i] = check.initVar(lhs[i], rhs[i], context)
}
check.recordCommaOkTypes(origRHS[0], a)
return
}
for i, lhs := range lhs {
check.initVar(lhs, rhs[i], context)
}
}
func (check *Checker) assignVars(lhs, origRHS []ast.Expr) {
rhs, commaOk := check.exprList(origRHS, len(lhs) == 2)
if len(lhs) != len(rhs) {
check.useLHS(lhs...)
// don't report an error if we already reported one
for _, x := range rhs {
if x.mode == invalid {
return
}
}
check.assignError(origRHS, len(lhs), len(rhs))
return
}
if commaOk {
var a [2]Type
for i := range a {
a[i] = check.assignVar(lhs[i], rhs[i])
}
check.recordCommaOkTypes(origRHS[0], a)
return
}
for i, lhs := range lhs {
check.assignVar(lhs, rhs[i])
}
}
func (check *Checker) shortVarDecl(pos positioner, lhs, rhs []ast.Expr) {
top := len(check.delayed)
scope := check.scope
// collect lhs variables
seen := make(map[string]bool, len(lhs))
lhsVars := make([]*Var, len(lhs))
newVars := make([]*Var, 0, len(lhs))
hasErr := false
for i, lhs := range lhs {
ident, _ := lhs.(*ast.Ident)
if ident == nil {
check.useLHS(lhs)
// TODO(rFindley) this is redundant with a parser error. Consider omitting?
check.errorf(lhs, BadDecl, "non-name %s on left side of :=", lhs)
hasErr = true
continue
}
name := ident.Name
if name != "_" {
if seen[name] {
check.errorf(lhs, RepeatedDecl, "%s repeated on left side of :=", lhs)
hasErr = true
continue
}
seen[name] = true
}
// Use the correct obj if the ident is redeclared. The
// variable's scope starts after the declaration; so we
// must use Scope.Lookup here and call Scope.Insert
// (via check.declare) later.
if alt := scope.Lookup(name); alt != nil {
check.recordUse(ident, alt)
// redeclared object must be a variable
if obj, _ := alt.(*Var); obj != nil {
lhsVars[i] = obj
} else {
check.errorf(lhs, UnassignableOperand, "cannot assign to %s", lhs)
hasErr = true
}
continue
}
// declare new variable
obj := NewVar(ident.Pos(), check.pkg, name, nil)
lhsVars[i] = obj
if name != "_" {
newVars = append(newVars, obj)
}
check.recordDef(ident, obj)
}
// create dummy variables where the lhs is invalid
for i, obj := range lhsVars {
if obj == nil {
lhsVars[i] = NewVar(lhs[i].Pos(), check.pkg, "_", nil)
}
}
check.initVars(lhsVars, rhs, nil)
// process function literals in rhs expressions before scope changes
check.processDelayed(top)
if len(newVars) == 0 && !hasErr {
check.softErrorf(pos, NoNewVar, "no new variables on left side of :=")
return
}
// declare new variables
// spec: "The scope of a constant or variable identifier declared inside
// a function begins at the end of the ConstSpec or VarSpec (ShortVarDecl
// for short variable declarations) and ends at the end of the innermost
// containing block."
scopePos := rhs[len(rhs)-1].End()
for _, obj := range newVars {
check.declare(scope, nil, obj, scopePos) // id = nil: recordDef already called
}
}
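// The scope rule quoted above means the newly declared variables are not
// visible in their own initialization expressions. Illustrative snippet
// (an assumption for exposition, not part of this file):
//
//	x := 1
//	{
//		x := x + 1 // rhs x is the outer x; the inner x's scope starts after the statement
//		_ = x      // 2
//	}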
// Code generated by "go test -run=Generate -write=all"; DO NOT EDIT.
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package types
// BasicKind describes the kind of basic type.
type BasicKind int
const (
Invalid BasicKind = iota // type is invalid
// predeclared types
Bool
Int
Int8
Int16
Int32
Int64
Uint
Uint8
Uint16
Uint32
Uint64
Uintptr
Float32
Float64
Complex64
Complex128
String
UnsafePointer
// types for untyped values
UntypedBool
UntypedInt
UntypedRune
UntypedFloat
UntypedComplex
UntypedString
UntypedNil
// aliases
Byte = Uint8
Rune = Int32
)
// BasicInfo is a set of flags describing properties of a basic type.
type BasicInfo int
// Properties of basic types.
const (
IsBoolean BasicInfo = 1 << iota
IsInteger
IsUnsigned
IsFloat
IsComplex
IsString
IsUntyped
IsOrdered = IsInteger | IsFloat | IsString
IsNumeric = IsInteger | IsFloat | IsComplex
IsConstType = IsBoolean | IsNumeric | IsString
)
// A Basic represents a basic type.
type Basic struct {
kind BasicKind
info BasicInfo
name string
}
// Kind returns the kind of basic type b.
func (b *Basic) Kind() BasicKind { return b.kind }
// Info returns information about properties of basic type b.
func (b *Basic) Info() BasicInfo { return b.info }
// Name returns the name of basic type b.
func (b *Basic) Name() string { return b.name }
func (b *Basic) Underlying() Type { return b }
func (b *Basic) String() string { return TypeString(b, nil) }
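// exampleBasicInfo is an illustrative sketch, not part of the original
// source: BasicInfo is a bit set, so properties are tested with bitwise AND.
func exampleBasicInfo() (bool, bool, bool) {
	info := Typ[Float64].Info()
	return info&IsNumeric != 0, // true
		info&IsOrdered != 0, // true
		info&IsComplex != 0 // false
}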
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements typechecking of builtin function calls.
package types
import (
"go/ast"
"go/constant"
"go/token"
. "internal/types/errors"
)
// builtin type-checks a call to the built-in specified by id and
// reports whether the call is valid, with *x holding the result;
// but x.expr is not set. If the call is invalid, the result is
// false, and *x is undefined.
func (check *Checker) builtin(x *operand, call *ast.CallExpr, id builtinId) (_ bool) {
// append is the only built-in that permits the use of ... for the last argument
bin := predeclaredFuncs[id]
if call.Ellipsis.IsValid() && id != _Append {
check.errorf(atPos(call.Ellipsis),
InvalidDotDotDot,
invalidOp+"invalid use of ... with built-in %s", bin.name)
check.use(call.Args...)
return
}
// For len(x) and cap(x) we need to know if x contains any function calls or
// receive operations. Save/restore current setting and set hasCallOrRecv to
// false for the evaluation of x so that we can check it afterwards.
// Note: We must do this _before_ calling exprList because exprList evaluates
// all arguments.
if id == _Len || id == _Cap {
defer func(b bool) {
check.hasCallOrRecv = b
}(check.hasCallOrRecv)
check.hasCallOrRecv = false
}
// determine actual arguments
var arg func(*operand, int) // TODO(gri) remove use of arg getter in favor of using xlist directly
nargs := len(call.Args)
switch id {
default:
// make argument getter
xlist, _ := check.exprList(call.Args, false)
arg = func(x *operand, i int) { *x = *xlist[i] }
nargs = len(xlist)
// evaluate first argument, if present
if nargs > 0 {
arg(x, 0)
if x.mode == invalid {
return
}
}
case _Make, _New, _Offsetof, _Trace:
// arguments require special handling
}
// check argument count
{
msg := ""
if nargs < bin.nargs {
msg = "not enough"
} else if !bin.variadic && nargs > bin.nargs {
msg = "too many"
}
if msg != "" {
check.errorf(inNode(call, call.Rparen), WrongArgCount, invalidOp+"%s arguments for %s (expected %d, found %d)", msg, call, bin.nargs, nargs)
return
}
}
switch id {
case _Append:
// append(s S, x ...T) S, where T is the element type of S
// spec: "The variadic function append appends zero or more values x to s of type
// S, which must be a slice type, and returns the resulting slice, also of type S.
// The values x are passed to a parameter of type ...T where T is the element type
// of S and the respective parameter passing rules apply."
S := x.typ
var T Type
if s, _ := coreType(S).(*Slice); s != nil {
T = s.elem
} else {
var cause string
switch {
case x.isNil():
cause = "have untyped nil"
case isTypeParam(S):
if u := coreType(S); u != nil {
cause = check.sprintf("%s has core type %s", x, u)
} else {
cause = check.sprintf("%s has no core type", x)
}
default:
cause = check.sprintf("have %s", x)
}
// don't use Checker.invalidArg here as it would repeat "argument" in the error message
check.errorf(x, InvalidAppend, "first argument to append must be a slice; %s", cause)
return
}
// remember arguments that have been evaluated already
alist := []operand{*x}
// spec: "As a special case, append also accepts a first argument assignable
// to type []byte with a second argument of string type followed by ... .
// This form appends the bytes of the string."
if nargs == 2 && call.Ellipsis.IsValid() {
if ok, _ := x.assignableTo(check, NewSlice(universeByte), nil); ok {
arg(x, 1)
if x.mode == invalid {
return
}
if t := coreString(x.typ); t != nil && isString(t) {
if check.Types != nil {
sig := makeSig(S, S, x.typ)
sig.variadic = true
check.recordBuiltinType(call.Fun, sig)
}
x.mode = value
x.typ = S
break
}
alist = append(alist, *x)
// fallthrough
}
}
// check general case by creating custom signature
sig := makeSig(S, S, NewSlice(T)) // []T required for variadic signature
sig.variadic = true
var xlist []*operand
// convert []operand to []*operand
for i := range alist {
xlist = append(xlist, &alist[i])
}
for i := len(alist); i < nargs; i++ {
var x operand
arg(&x, i)
xlist = append(xlist, &x)
}
check.arguments(call, sig, nil, xlist, nil) // discard result (we know the result type)
// ok to continue even if check.arguments reported errors
x.mode = value
x.typ = S
if check.Types != nil {
check.recordBuiltinType(call.Fun, sig)
}
case _Cap, _Len:
// cap(x)
// len(x)
mode := invalid
var val constant.Value
switch t := arrayPtrDeref(under(x.typ)).(type) {
case *Basic:
if isString(t) && id == _Len {
if x.mode == constant_ {
mode = constant_
val = constant.MakeInt64(int64(len(constant.StringVal(x.val))))
} else {
mode = value
}
}
case *Array:
mode = value
// spec: "The expressions len(s) and cap(s) are constants
// if the type of s is an array or pointer to an array and
// the expression s does not contain channel receives or
// function calls; in this case s is not evaluated."
if !check.hasCallOrRecv {
mode = constant_
if t.len >= 0 {
val = constant.MakeInt64(t.len)
} else {
val = constant.MakeUnknown()
}
}
case *Slice, *Chan:
mode = value
case *Map:
if id == _Len {
mode = value
}
case *Interface:
if !isTypeParam(x.typ) {
break
}
if t.typeSet().underIs(func(t Type) bool {
switch t := arrayPtrDeref(t).(type) {
case *Basic:
if isString(t) && id == _Len {
return true
}
case *Array, *Slice, *Chan:
return true
case *Map:
if id == _Len {
return true
}
}
return false
}) {
mode = value
}
}
if mode == invalid && under(x.typ) != Typ[Invalid] {
code := InvalidCap
if id == _Len {
code = InvalidLen
}
check.errorf(x, code, invalidArg+"%s for %s", x, bin.name)
return
}
// record the signature before changing x.typ
if check.Types != nil && mode != constant_ {
check.recordBuiltinType(call.Fun, makeSig(Typ[Int], x.typ))
}
x.mode = mode
x.typ = Typ[Int]
x.val = val
case _Clear:
// clear(m)
if !check.allowVersion(check.pkg, 1, 21) {
check.error(call.Fun, UnsupportedFeature, "clear requires go1.21 or later")
return
}
if !underIs(x.typ, func(u Type) bool {
switch u.(type) {
case *Map, *Slice:
return true
}
check.errorf(x, InvalidClear, invalidArg+"cannot clear %s: argument must be (or constrained by) map or slice", x)
return false
}) {
return
}
x.mode = novalue
if check.Types != nil {
check.recordBuiltinType(call.Fun, makeSig(nil, x.typ))
}
case _Close:
// close(c)
if !underIs(x.typ, func(u Type) bool {
uch, _ := u.(*Chan)
if uch == nil {
check.errorf(x, InvalidClose, invalidOp+"cannot close non-channel %s", x)
return false
}
if uch.dir == RecvOnly {
check.errorf(x, InvalidClose, invalidOp+"cannot close receive-only channel %s", x)
return false
}
return true
}) {
return
}
x.mode = novalue
if check.Types != nil {
check.recordBuiltinType(call.Fun, makeSig(nil, x.typ))
}
case _Complex:
// complex(x, y floatT) complexT
var y operand
arg(&y, 1)
if y.mode == invalid {
return
}
// convert or check untyped arguments
d := 0
if isUntyped(x.typ) {
d |= 1
}
if isUntyped(y.typ) {
d |= 2
}
switch d {
case 0:
// x and y are typed => nothing to do
case 1:
// only x is untyped => convert to type of y
check.convertUntyped(x, y.typ)
case 2:
// only y is untyped => convert to type of x
check.convertUntyped(&y, x.typ)
case 3:
// x and y are untyped =>
// 1) if both are constants, convert them to untyped
// floating-point numbers if possible,
// 2) if one of them is not constant (possible because
// it contains a shift that is yet untyped), convert
// both of them to float64 since they must have the
// same type to succeed (this will result in an error
// because shifts of floats are not permitted)
if x.mode == constant_ && y.mode == constant_ {
toFloat := func(x *operand) {
if isNumeric(x.typ) && constant.Sign(constant.Imag(x.val)) == 0 {
x.typ = Typ[UntypedFloat]
}
}
toFloat(x)
toFloat(&y)
} else {
check.convertUntyped(x, Typ[Float64])
check.convertUntyped(&y, Typ[Float64])
// x and y should be invalid now, but be conservative
// and check below
}
}
if x.mode == invalid || y.mode == invalid {
return
}
// both argument types must be identical
if !Identical(x.typ, y.typ) {
check.errorf(x, InvalidComplex, invalidArg+"mismatched types %s and %s", x.typ, y.typ)
return
}
// the argument types must be of floating-point type
// (applyTypeFunc never calls f with a type parameter)
f := func(typ Type) Type {
assert(!isTypeParam(typ))
if t, _ := under(typ).(*Basic); t != nil {
switch t.kind {
case Float32:
return Typ[Complex64]
case Float64:
return Typ[Complex128]
case UntypedFloat:
return Typ[UntypedComplex]
}
}
return nil
}
resTyp := check.applyTypeFunc(f, x, id)
if resTyp == nil {
check.errorf(x, InvalidComplex, invalidArg+"arguments have type %s, expected floating-point", x.typ)
return
}
// if both arguments are constants, the result is a constant
if x.mode == constant_ && y.mode == constant_ {
x.val = constant.BinaryOp(constant.ToFloat(x.val), token.ADD, constant.MakeImag(constant.ToFloat(y.val)))
} else {
x.mode = value
}
if check.Types != nil && x.mode != constant_ {
check.recordBuiltinType(call.Fun, makeSig(resTyp, x.typ, x.typ))
}
x.typ = resTyp
case _Copy:
// copy(x, y []T) int
dst, _ := coreType(x.typ).(*Slice)
var y operand
arg(&y, 1)
if y.mode == invalid {
return
}
src0 := coreString(y.typ)
if src0 != nil && isString(src0) {
src0 = NewSlice(universeByte)
}
src, _ := src0.(*Slice)
if dst == nil || src == nil {
check.errorf(x, InvalidCopy, invalidArg+"copy expects slice arguments; found %s and %s", x, &y)
return
}
if !Identical(dst.elem, src.elem) {
check.errorf(x, InvalidCopy, "arguments to copy %s and %s have different element types %s and %s", x, &y, dst.elem, src.elem)
return
}
if check.Types != nil {
check.recordBuiltinType(call.Fun, makeSig(Typ[Int], x.typ, y.typ))
}
x.mode = value
x.typ = Typ[Int]
case _Delete:
// delete(map_, key)
// map_ must be a map type or a type parameter describing map types.
// The key cannot be a type parameter for now.
map_ := x.typ
var key Type
if !underIs(map_, func(u Type) bool {
map_, _ := u.(*Map)
if map_ == nil {
check.errorf(x, InvalidDelete, invalidArg+"%s is not a map", x)
return false
}
if key != nil && !Identical(map_.key, key) {
check.errorf(x, InvalidDelete, invalidArg+"maps of %s must have identical key types", x)
return false
}
key = map_.key
return true
}) {
return
}
arg(x, 1) // k
if x.mode == invalid {
return
}
check.assignment(x, key, "argument to delete")
if x.mode == invalid {
return
}
x.mode = novalue
if check.Types != nil {
check.recordBuiltinType(call.Fun, makeSig(nil, map_, key))
}
case _Imag, _Real:
// imag(complexT) floatT
// real(complexT) floatT
// convert or check untyped argument
if isUntyped(x.typ) {
if x.mode == constant_ {
// an untyped constant number can always be considered
// as a complex constant
if isNumeric(x.typ) {
x.typ = Typ[UntypedComplex]
}
} else {
// an untyped non-constant argument may appear if
// it contains a (yet untyped non-constant) shift
// expression: convert it to complex128 which will
// result in an error (shift of complex value)
check.convertUntyped(x, Typ[Complex128])
// x should be invalid now, but be conservative and check
if x.mode == invalid {
return
}
}
}
// the argument must be of complex type
// (applyTypeFunc never calls f with a type parameter)
f := func(typ Type) Type {
assert(!isTypeParam(typ))
if t, _ := under(typ).(*Basic); t != nil {
switch t.kind {
case Complex64:
return Typ[Float32]
case Complex128:
return Typ[Float64]
case UntypedComplex:
return Typ[UntypedFloat]
}
}
return nil
}
resTyp := check.applyTypeFunc(f, x, id)
if resTyp == nil {
code := InvalidImag
if id == _Real {
code = InvalidReal
}
check.errorf(x, code, invalidArg+"argument has type %s, expected complex type", x.typ)
return
}
// if the argument is a constant, the result is a constant
if x.mode == constant_ {
if id == _Real {
x.val = constant.Real(x.val)
} else {
x.val = constant.Imag(x.val)
}
} else {
x.mode = value
}
if check.Types != nil && x.mode != constant_ {
check.recordBuiltinType(call.Fun, makeSig(resTyp, x.typ))
}
x.typ = resTyp
case _Make:
// make(T, n)
// make(T, n, m)
// (no argument evaluated yet)
arg0 := call.Args[0]
T := check.varType(arg0)
if T == Typ[Invalid] {
return
}
var min int // minimum number of arguments
switch coreType(T).(type) {
case *Slice:
min = 2
case *Map, *Chan:
min = 1
case nil:
check.errorf(arg0, InvalidMake, "cannot make %s: no core type", arg0)
return
default:
check.errorf(arg0, InvalidMake, invalidArg+"cannot make %s; type must be slice, map, or channel", arg0)
return
}
if nargs < min || min+1 < nargs {
check.errorf(call, WrongArgCount, invalidOp+"%v expects %d or %d arguments; found %d", call, min, min+1, nargs)
return
}
types := []Type{T}
var sizes []int64 // constant integer arguments, if any
for _, arg := range call.Args[1:] {
typ, size := check.index(arg, -1) // ok to continue with typ == Typ[Invalid]
types = append(types, typ)
if size >= 0 {
sizes = append(sizes, size)
}
}
if len(sizes) == 2 && sizes[0] > sizes[1] {
check.error(call.Args[1], SwappedMakeArgs, invalidArg+"length and capacity swapped")
// safe to continue
}
x.mode = value
x.typ = T
if check.Types != nil {
check.recordBuiltinType(call.Fun, makeSig(x.typ, types...))
}
case _New:
// new(T)
// (no argument evaluated yet)
T := check.varType(call.Args[0])
if T == Typ[Invalid] {
return
}
x.mode = value
x.typ = &Pointer{base: T}
if check.Types != nil {
check.recordBuiltinType(call.Fun, makeSig(x.typ, T))
}
case _Panic:
// panic(x)
// record panic call if inside a function with result parameters
// (for use in Checker.isTerminating)
if check.sig != nil && check.sig.results.Len() > 0 {
// function has result parameters
p := check.isPanic
if p == nil {
// allocate lazily
p = make(map[*ast.CallExpr]bool)
check.isPanic = p
}
p[call] = true
}
check.assignment(x, &emptyInterface, "argument to panic")
if x.mode == invalid {
return
}
x.mode = novalue
if check.Types != nil {
check.recordBuiltinType(call.Fun, makeSig(nil, &emptyInterface))
}
case _Print, _Println:
// print(x, y, ...)
// println(x, y, ...)
var params []Type
if nargs > 0 {
params = make([]Type, nargs)
for i := 0; i < nargs; i++ {
if i > 0 {
arg(x, i) // first argument already evaluated
}
check.assignment(x, nil, "argument to "+predeclaredFuncs[id].name)
if x.mode == invalid {
// TODO(gri) "use" all arguments?
return
}
params[i] = x.typ
}
}
x.mode = novalue
if check.Types != nil {
check.recordBuiltinType(call.Fun, makeSig(nil, params...))
}
case _Recover:
// recover() interface{}
x.mode = value
x.typ = &emptyInterface
if check.Types != nil {
check.recordBuiltinType(call.Fun, makeSig(x.typ))
}
case _Add:
// unsafe.Add(ptr unsafe.Pointer, len IntegerType) unsafe.Pointer
if !check.allowVersion(check.pkg, 1, 17) {
check.error(call.Fun, UnsupportedFeature, "unsafe.Add requires go1.17 or later")
return
}
check.assignment(x, Typ[UnsafePointer], "argument to unsafe.Add")
if x.mode == invalid {
return
}
var y operand
arg(&y, 1)
if !check.isValidIndex(&y, InvalidUnsafeAdd, "length", true) {
return
}
x.mode = value
x.typ = Typ[UnsafePointer]
if check.Types != nil {
check.recordBuiltinType(call.Fun, makeSig(x.typ, x.typ, y.typ))
}
case _Alignof:
// unsafe.Alignof(x T) uintptr
check.assignment(x, nil, "argument to unsafe.Alignof")
if x.mode == invalid {
return
}
if hasVarSize(x.typ, nil) {
x.mode = value
if check.Types != nil {
check.recordBuiltinType(call.Fun, makeSig(Typ[Uintptr], x.typ))
}
} else {
x.mode = constant_
x.val = constant.MakeInt64(check.conf.alignof(x.typ))
// result is constant - no need to record signature
}
x.typ = Typ[Uintptr]
case _Offsetof:
// unsafe.Offsetof(x T) uintptr, where x must be a selector
// (no argument evaluated yet)
arg0 := call.Args[0]
selx, _ := unparen(arg0).(*ast.SelectorExpr)
if selx == nil {
check.errorf(arg0, BadOffsetofSyntax, invalidArg+"%s is not a selector expression", arg0)
check.use(arg0)
return
}
check.expr(x, selx.X)
if x.mode == invalid {
return
}
base := derefStructPtr(x.typ)
sel := selx.Sel.Name
obj, index, indirect := LookupFieldOrMethod(base, false, check.pkg, sel)
switch obj.(type) {
case nil:
check.errorf(x, MissingFieldOrMethod, invalidArg+"%s has no single field %s", base, sel)
return
case *Func:
// TODO(gri) Using derefStructPtr may result in methods being found
// that don't actually exist. An error either way, but the error
// message is confusing. See: https://play.golang.org/p/al75v23kUy ,
// but go/types reports: "invalid argument: x.m is a method value".
check.errorf(arg0, InvalidOffsetof, invalidArg+"%s is a method value", arg0)
return
}
if indirect {
check.errorf(x, InvalidOffsetof, invalidArg+"field %s is embedded via a pointer in %s", sel, base)
return
}
// TODO(gri) Should we pass x.typ instead of base (and have indirect report if derefStructPtr indirected)?
check.recordSelection(selx, FieldVal, base, obj, index, false)
// record the selector expression (was bug - go.dev/issue/47895)
{
mode := value
if x.mode == variable || indirect {
mode = variable
}
check.record(&operand{mode, selx, obj.Type(), nil, 0})
}
// The field offset is considered a variable even if the field is declared before
// the part of the struct which is variable-sized. This makes both the rules
// simpler and also permits (or at least doesn't prevent) a compiler from re-
// arranging struct fields if it wanted to.
if hasVarSize(base, nil) {
x.mode = value
if check.Types != nil {
check.recordBuiltinType(call.Fun, makeSig(Typ[Uintptr], obj.Type()))
}
} else {
x.mode = constant_
x.val = constant.MakeInt64(check.conf.offsetof(base, index))
// result is constant - no need to record signature
}
x.typ = Typ[Uintptr]
case _Sizeof:
// unsafe.Sizeof(x T) uintptr
check.assignment(x, nil, "argument to unsafe.Sizeof")
if x.mode == invalid {
return
}
if hasVarSize(x.typ, nil) {
x.mode = value
if check.Types != nil {
check.recordBuiltinType(call.Fun, makeSig(Typ[Uintptr], x.typ))
}
} else {
x.mode = constant_
x.val = constant.MakeInt64(check.conf.sizeof(x.typ))
// result is constant - no need to record signature
}
x.typ = Typ[Uintptr]
case _Slice:
// unsafe.Slice(ptr *T, len IntegerType) []T
if !check.allowVersion(check.pkg, 1, 17) {
check.error(call.Fun, UnsupportedFeature, "unsafe.Slice requires go1.17 or later")
return
}
ptr, _ := under(x.typ).(*Pointer) // TODO(gri) should this be coreType rather than under?
if ptr == nil {
check.errorf(x, InvalidUnsafeSlice, invalidArg+"%s is not a pointer", x)
return
}
var y operand
arg(&y, 1)
if !check.isValidIndex(&y, InvalidUnsafeSlice, "length", false) {
return
}
x.mode = value
x.typ = NewSlice(ptr.base)
if check.Types != nil {
check.recordBuiltinType(call.Fun, makeSig(x.typ, ptr, y.typ))
}
case _SliceData:
// unsafe.SliceData(slice []T) *T
if !check.allowVersion(check.pkg, 1, 20) {
check.error(call.Fun, UnsupportedFeature, "unsafe.SliceData requires go1.20 or later")
return
}
slice, _ := under(x.typ).(*Slice) // TODO(gri) should this be coreType rather than under?
if slice == nil {
check.errorf(x, InvalidUnsafeSliceData, invalidArg+"%s is not a slice", x)
return
}
x.mode = value
x.typ = NewPointer(slice.elem)
if check.Types != nil {
check.recordBuiltinType(call.Fun, makeSig(x.typ, slice))
}
case _String:
// unsafe.String(ptr *byte, len IntegerType) string
if !check.allowVersion(check.pkg, 1, 20) {
check.error(call.Fun, UnsupportedFeature, "unsafe.String requires go1.20 or later")
return
}
check.assignment(x, NewPointer(universeByte), "argument to unsafe.String")
if x.mode == invalid {
return
}
var y operand
arg(&y, 1)
if !check.isValidIndex(&y, InvalidUnsafeString, "length", false) {
return
}
x.mode = value
x.typ = Typ[String]
if check.Types != nil {
check.recordBuiltinType(call.Fun, makeSig(x.typ, NewPointer(universeByte), y.typ))
}
case _StringData:
// unsafe.StringData(str string) *byte
if !check.allowVersion(check.pkg, 1, 20) {
check.error(call.Fun, UnsupportedFeature, "unsafe.StringData requires go1.20 or later")
return
}
check.assignment(x, Typ[String], "argument to unsafe.StringData")
if x.mode == invalid {
return
}
x.mode = value
x.typ = NewPointer(universeByte)
if check.Types != nil {
check.recordBuiltinType(call.Fun, makeSig(x.typ, Typ[String]))
}
case _Assert:
// assert(pred) causes a typechecker error if pred is false.
// The result of assert is the value of pred if there is no error.
// Note: assert is only available in self-test mode.
if x.mode != constant_ || !isBoolean(x.typ) {
check.errorf(x, Test, invalidArg+"%s is not a boolean constant", x)
return
}
if x.val.Kind() != constant.Bool {
check.errorf(x, Test, "internal error: value of %s should be a boolean constant", x)
return
}
if !constant.BoolVal(x.val) {
check.errorf(call, Test, "%v failed", call)
// compile-time assertion failure - safe to continue
}
// result is constant - no need to record signature
case _Trace:
// trace(x, y, z, ...) dumps the positions, expressions, and
// values of its arguments. The result of trace is the value
// of the first argument.
// Note: trace is only available in self-test mode.
// (no argument evaluated yet)
if nargs == 0 {
check.dump("%v: trace() without arguments", call.Pos())
x.mode = novalue
break
}
var t operand
x1 := x
for _, arg := range call.Args {
check.rawExpr(x1, arg, nil, false) // permit trace for types, e.g.: new(trace(T))
check.dump("%v: %s", x1.Pos(), x1)
x1 = &t // use incoming x only for first argument
}
// trace is only available in test mode - no need to record signature
default:
unreachable()
}
return true
}
// hasVarSize reports whether the size of type t is variable due to type
// parameters, or whether the type is infinitely-sized due to a cycle for
// which the type has not yet been checked.
func hasVarSize(t Type, seen map[*Named]bool) (varSized bool) {
// Cycles are only possible through *Named types.
// The seen map is used to detect cycles and track
// the results of previously seen types.
if named, _ := t.(*Named); named != nil {
if v, ok := seen[named]; ok {
return v
}
if seen == nil {
seen = make(map[*Named]bool)
}
seen[named] = true // possibly cyclic until proven otherwise
defer func() {
seen[named] = varSized // record final determination for named
}()
}
switch u := under(t).(type) {
case *Array:
return hasVarSize(u.elem, seen)
case *Struct:
for _, f := range u.fields {
if hasVarSize(f.typ, seen) {
return true
}
}
case *Interface:
return isTypeParam(t)
case *Named, *Union:
unreachable()
}
return false
}
// applyTypeFunc applies f to x. If x is a type parameter,
// the result is a type parameter constrained by a new
// interface bound. The type bounds for that interface
// are computed by applying f to each of the type bounds
// of x. If any of these applications of f return nil,
// applyTypeFunc returns nil.
// If x is not a type parameter, the result is f(x).
func (check *Checker) applyTypeFunc(f func(Type) Type, x *operand, id builtinId) Type {
if tp, _ := x.typ.(*TypeParam); tp != nil {
// Test if t satisfies the requirements for the argument
// type and collect possible result types at the same time.
var terms []*Term
if !tp.is(func(t *term) bool {
if t == nil {
return false
}
if r := f(t.typ); r != nil {
terms = append(terms, NewTerm(t.tilde, r))
return true
}
return false
}) {
return nil
}
// We can type-check this fine but we're introducing a synthetic
// type parameter for the result. It's not clear what the API
// implications are here. Report an error for 1.18 (see go.dev/issue/50912),
// but continue type-checking.
var code Code
switch id {
case _Real:
code = InvalidReal
case _Imag:
code = InvalidImag
case _Complex:
code = InvalidComplex
default:
unreachable()
}
check.softErrorf(x, code, "%s not supported as argument to %s for go1.18 (see go.dev/issue/50937)", x, predeclaredFuncs[id].name)
// Construct a suitable new type parameter for the result type.
// The type parameter is placed in the current package so export/import
// works as expected.
tpar := NewTypeName(nopos, check.pkg, tp.obj.name, nil)
ptyp := check.newTypeParam(tpar, NewInterfaceType(nil, []Type{NewUnion(terms)})) // assigns type to tpar as a side-effect
ptyp.index = tp.index
return ptyp
}
return f(x.typ)
}
// makeSig makes a signature for the given argument and result types.
// Default types are used for untyped arguments, and res may be nil.
func makeSig(res Type, args ...Type) *Signature {
list := make([]*Var, len(args))
for i, param := range args {
list[i] = NewVar(nopos, nil, "", Default(param))
}
params := NewTuple(list...)
var result *Tuple
if res != nil {
assert(!isUntyped(res))
result = NewTuple(NewVar(nopos, nil, "", res))
}
return &Signature{params: params, results: result}
}
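// Example (illustrative, not part of the original source): the signature
// recorded for the unsafe.String call checked earlier in this file is
// built as
//
//	makeSig(Typ[String], NewPointer(universeByte), y.typ)
//
// i.e., func(*byte, <integer type>) string, with untyped argument types
// first replaced by their default types via Default.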
// arrayPtrDeref returns A if typ is of the form *A and A is an array;
// otherwise it returns typ.
func arrayPtrDeref(typ Type) Type {
if p, ok := typ.(*Pointer); ok {
if a, _ := under(p.base).(*Array); a != nil {
return a
}
}
return typ
}
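// Examples (illustrative, not part of the original source):
//
//	arrayPtrDeref(*[4]int) // [4]int
//	arrayPtrDeref([]int)   // []int (unchanged: not a pointer)
//	arrayPtrDeref(*int)    // *int  (unchanged: base is not an array)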
// unparen returns e with any enclosing parentheses stripped.
func unparen(e ast.Expr) ast.Expr {
for {
p, ok := e.(*ast.ParenExpr)
if !ok {
return e
}
e = p.X
}
}
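// Example (illustrative, not part of the original source): for the
// expression ((f)), unparen returns the inner *ast.Ident f, stripping
// both enclosing *ast.ParenExpr nodes.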
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements typechecking of call and selector expressions.
package types
import (
"go/ast"
"go/internal/typeparams"
"go/token"
. "internal/types/errors"
"strings"
"unicode"
)
// funcInst type-checks the function instantiation ix and returns the result in x.
// The operand x must be the evaluation of ix.X and its type must be a signature.
func (check *Checker) funcInst(x *operand, ix *typeparams.IndexExpr) {
if !check.allowVersion(check.pkg, 1, 18) {
check.softErrorf(inNode(ix.Orig, ix.Lbrack), UnsupportedFeature, "function instantiation requires go1.18 or later")
}
targs := check.typeList(ix.Indices)
if targs == nil {
x.mode = invalid
x.expr = ix.Orig
return
}
assert(len(targs) == len(ix.Indices))
// check number of type arguments (got) vs number of type parameters (want)
sig := x.typ.(*Signature)
got, want := len(targs), sig.TypeParams().Len()
if got > want {
check.errorf(ix.Indices[got-1], WrongTypeArgCount, "got %d type arguments but want %d", got, want)
x.mode = invalid
x.expr = ix.Orig
return
}
if got < want {
targs = check.infer(ix.Orig, sig.TypeParams().list(), targs, nil, nil)
if targs == nil {
// error was already reported
x.mode = invalid
x.expr = ix.Orig
return
}
got = len(targs)
}
assert(got == want)
// instantiate function signature
sig = check.instantiateSignature(x.Pos(), sig, targs, ix.Indices)
assert(sig.TypeParams().Len() == 0) // signature is not generic anymore
check.recordInstance(ix.Orig, targs, sig)
x.typ = sig
x.mode = value
x.expr = ix.Orig
}
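// Example (illustrative sketch, not part of the original source): given
//
//	func min[T int | float64](a, b T) T { ... }
//	f := min[int]
//
// funcInst type-checks min[int], records the instance, and sets x.typ to
// the instantiated signature func(a int, b int) int.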
func (check *Checker) instantiateSignature(pos token.Pos, typ *Signature, targs []Type, xlist []ast.Expr) (res *Signature) {
assert(check != nil)
assert(len(targs) == typ.TypeParams().Len())
if check.conf._Trace {
check.trace(pos, "-- instantiating signature %s with %s", typ, targs)
check.indent++
defer func() {
check.indent--
check.trace(pos, "=> %s (under = %s)", res, res.Underlying())
}()
}
inst := check.instance(pos, typ, targs, nil, check.context()).(*Signature)
assert(len(xlist) <= len(targs))
// verify instantiation lazily (was go.dev/issue/50450)
check.later(func() {
tparams := typ.TypeParams().list()
if i, err := check.verify(pos, tparams, targs, check.context()); err != nil {
// best position for error reporting
pos := pos
if i < len(xlist) {
pos = xlist[i].Pos()
}
check.softErrorf(atPos(pos), InvalidTypeArg, "%s", err)
} else {
check.mono.recordInstance(check.pkg, pos, tparams, targs, xlist)
}
}).describef(atPos(pos), "verify instantiation")
return inst
}
func (check *Checker) callExpr(x *operand, call *ast.CallExpr) exprKind {
ix := typeparams.UnpackIndexExpr(call.Fun)
if ix != nil {
if check.indexExpr(x, ix) {
// Delay function instantiation to argument checking,
// where we combine type and value arguments for type
// inference.
assert(x.mode == value)
} else {
ix = nil
}
x.expr = call.Fun
check.record(x)
} else {
check.exprOrType(x, call.Fun, true)
}
// x.typ may be generic
switch x.mode {
case invalid:
check.use(call.Args...)
x.expr = call
return statement
case typexpr:
// conversion
check.nonGeneric(x)
if x.mode == invalid {
return conversion
}
T := x.typ
x.mode = invalid
switch n := len(call.Args); n {
case 0:
check.errorf(inNode(call, call.Rparen), WrongArgCount, "missing argument in conversion to %s", T)
case 1:
check.expr(x, call.Args[0])
if x.mode != invalid {
if call.Ellipsis.IsValid() {
check.errorf(call.Args[0], BadDotDotDotSyntax, "invalid use of ... in conversion to %s", T)
break
}
if t, _ := under(T).(*Interface); t != nil && !isTypeParam(T) {
if !t.IsMethodSet() {
check.errorf(call, MisplacedConstraintIface, "cannot use interface %s in conversion (contains specific type constraints or is comparable)", T)
break
}
}
check.conversion(x, T)
}
default:
check.use(call.Args...)
check.errorf(call.Args[n-1], WrongArgCount, "too many arguments in conversion to %s", T)
}
x.expr = call
return conversion
case builtin:
// no need to check for non-genericity here
id := x.id
if !check.builtin(x, call, id) {
x.mode = invalid
}
x.expr = call
// a non-constant result implies a function call
if x.mode != invalid && x.mode != constant_ {
check.hasCallOrRecv = true
}
return predeclaredFuncs[id].kind
}
// ordinary function/method call
// signature may be generic
cgocall := x.mode == cgofunc
// a type parameter may be "called" if all types have the same signature
sig, _ := coreType(x.typ).(*Signature)
if sig == nil {
check.errorf(x, InvalidCall, invalidOp+"cannot call non-function %s", x)
x.mode = invalid
x.expr = call
return statement
}
// Capture wasGeneric before sig is potentially instantiated below.
wasGeneric := sig.TypeParams().Len() > 0
// evaluate type arguments, if any
var xlist []ast.Expr
var targs []Type
if ix != nil {
xlist = ix.Indices
targs = check.typeList(xlist)
if targs == nil {
check.use(call.Args...)
x.mode = invalid
x.expr = call
return statement
}
assert(len(targs) == len(xlist))
// check number of type arguments (got) vs number of type parameters (want)
got, want := len(targs), sig.TypeParams().Len()
if got > want {
check.errorf(xlist[want], WrongTypeArgCount, "got %d type arguments but want %d", got, want)
check.use(call.Args...)
x.mode = invalid
x.expr = call
return statement
}
// If sig is generic and all type arguments are provided, preempt function
// argument type inference by explicitly instantiating the signature. This
// ensures that we record accurate type information for sig, even if there
// is an error checking its arguments (for example, if an incorrect number
// of arguments is supplied).
if got == want && want > 0 {
if !check.allowVersion(check.pkg, 1, 18) {
check.softErrorf(inNode(call.Fun, ix.Lbrack), UnsupportedFeature, "function instantiation requires go1.18 or later")
}
sig = check.instantiateSignature(ix.Pos(), sig, targs, xlist)
assert(sig.TypeParams().Len() == 0) // signature is not generic anymore
check.recordInstance(ix.Orig, targs, sig)
// targs have been consumed; proceed with checking arguments of the
// non-generic signature.
targs = nil
xlist = nil
}
}
// evaluate arguments
args, _ := check.exprList(call.Args, false)
sig = check.arguments(call, sig, targs, args, xlist)
if wasGeneric && sig.TypeParams().Len() == 0 {
// Update the recorded type of call.Fun to its instantiated type.
check.recordTypeAndValue(call.Fun, value, sig, nil)
}
// determine result
switch sig.results.Len() {
case 0:
x.mode = novalue
case 1:
if cgocall {
x.mode = commaerr
} else {
x.mode = value
}
x.typ = sig.results.vars[0].typ // unpack tuple
default:
x.mode = value
x.typ = sig.results
}
x.expr = call
check.hasCallOrRecv = true
// if type inference failed, a parameterized result must be invalidated
// (operands cannot have a parameterized type)
if x.mode == value && sig.TypeParams().Len() > 0 && isParameterized(sig.TypeParams().list(), x.typ) {
x.mode = invalid
}
return statement
}
func (check *Checker) exprList(elist []ast.Expr, allowCommaOk bool) (xlist []*operand, commaOk bool) {
switch len(elist) {
case 0:
// nothing to do
case 1:
// single (possibly comma-ok) value, or function returning multiple values
e := elist[0]
var x operand
check.multiExpr(&x, e)
if t, ok := x.typ.(*Tuple); ok && x.mode != invalid {
// multiple values
xlist = make([]*operand, t.Len())
for i, v := range t.vars {
xlist[i] = &operand{mode: value, expr: e, typ: v.typ}
}
break
}
// exactly one (possibly invalid or comma-ok) value
xlist = []*operand{&x}
if allowCommaOk && (x.mode == mapindex || x.mode == commaok || x.mode == commaerr) {
x2 := &operand{mode: value, expr: e, typ: Typ[UntypedBool]}
if x.mode == commaerr {
x2.typ = universeError
}
xlist = append(xlist, x2)
commaOk = true
}
default:
// multiple (possibly invalid) values
xlist = make([]*operand, len(elist))
for i, e := range elist {
var x operand
check.expr(&x, e)
xlist[i] = &x
}
}
return
}
// xlist is the list of type argument expressions supplied in the source code.
func (check *Checker) arguments(call *ast.CallExpr, sig *Signature, targs []Type, args []*operand, xlist []ast.Expr) (rsig *Signature) {
rsig = sig
// TODO(gri) try to eliminate this extra verification loop
for _, a := range args {
switch a.mode {
case typexpr:
check.errorf(a, NotAnExpr, "%s used as value", a)
return
case invalid:
return
}
}
// Function call argument/parameter count requirements
//
// | standard call | dotdotdot call |
// --------------+------------------+----------------+
// standard func | nargs == npars | invalid |
// --------------+------------------+----------------+
// variadic func | nargs >= npars-1 | nargs == npars |
// --------------+------------------+----------------+
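// For example (illustrative, not part of the original source), given
// func f(a int, rest ...string) and a []string variable s:
//
//	f(1)           // ok: standard call, nargs == npars-1
//	f(1, "x", "y") // ok: standard call, nargs >= npars-1
//	f(1, s...)     // ok: dotdotdot call, nargs == npars
//	f(s...)        // invalid: nargs != npars in a dotdotdot call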
nargs := len(args)
npars := sig.params.Len()
ddd := call.Ellipsis.IsValid()
// set up parameters
sigParams := sig.params // adjusted for variadic functions (may be nil for empty parameter lists!)
adjusted := false // set if sigParams differs from sig.params
if sig.variadic {
if ddd {
// variadic_func(a, b, c...)
if len(call.Args) == 1 && nargs > 1 {
// f()... is not permitted if f() is multi-valued
check.errorf(inNode(call, call.Ellipsis), InvalidDotDotDot, "cannot use ... with %d-valued %s", nargs, call.Args[0])
return
}
} else {
// variadic_func(a, b, c)
if nargs >= npars-1 {
// Create custom parameters for arguments: keep
// the first npars-1 parameters and add one for
// each argument mapping to the ... parameter.
vars := make([]*Var, npars-1) // npars > 0 for variadic functions
copy(vars, sig.params.vars)
last := sig.params.vars[npars-1]
typ := last.typ.(*Slice).elem
for len(vars) < nargs {
vars = append(vars, NewParam(last.pos, last.pkg, last.name, typ))
}
sigParams = NewTuple(vars...) // possibly nil!
adjusted = true
npars = nargs
} else {
// nargs < npars-1
npars-- // for correct error message below
}
}
} else {
if ddd {
// standard_func(a, b, c...)
check.errorf(inNode(call, call.Ellipsis), NonVariadicDotDotDot, "cannot use ... in call to non-variadic %s", call.Fun)
return
}
// standard_func(a, b, c)
}
// check argument count
if nargs != npars {
var at positioner = call
qualifier := "not enough"
if nargs > npars {
at = args[npars].expr // report at first extra argument
qualifier = "too many"
} else {
at = atPos(call.Rparen) // report at closing )
}
// take care of empty parameter lists represented by nil tuples
var params []*Var
if sig.params != nil {
params = sig.params.vars
}
err := newErrorf(at, WrongArgCount, "%s arguments in call to %s", qualifier, call.Fun)
err.errorf(nopos, "have %s", check.typesSummary(operandTypes(args), false))
err.errorf(nopos, "want %s", check.typesSummary(varTypes(params), sig.variadic))
check.report(err)
return
}
// infer type arguments and instantiate signature if necessary
if sig.TypeParams().Len() > 0 {
if !check.allowVersion(check.pkg, 1, 18) {
switch call.Fun.(type) {
case *ast.IndexExpr, *ast.IndexListExpr:
ix := typeparams.UnpackIndexExpr(call.Fun)
check.softErrorf(inNode(call.Fun, ix.Lbrack), UnsupportedFeature, "function instantiation requires go1.18 or later")
default:
check.softErrorf(inNode(call, call.Lparen), UnsupportedFeature, "implicit function instantiation requires go1.18 or later")
}
}
targs := check.infer(call, sig.TypeParams().list(), targs, sigParams, args)
if targs == nil {
return // error already reported
}
// compute result signature
rsig = check.instantiateSignature(call.Pos(), sig, targs, xlist)
assert(rsig.TypeParams().Len() == 0) // signature is not generic anymore
check.recordInstance(call.Fun, targs, rsig)
// Optimization: Only if the parameter list was adjusted do we
// need to compute it from the adjusted list; otherwise we can
// simply use the result signature's parameter list.
if adjusted {
sigParams = check.subst(call.Pos(), sigParams, makeSubstMap(sig.TypeParams().list(), targs), nil, check.context()).(*Tuple)
} else {
sigParams = rsig.params
}
}
// check arguments
if len(args) > 0 {
context := check.sprintf("argument to %s", call.Fun)
for i, a := range args {
check.assignment(a, sigParams.vars[i].typ, context)
}
}
return
}
var cgoPrefixes = [...]string{
"_Ciconst_",
"_Cfconst_",
"_Csconst_",
"_Ctype_",
"_Cvar_", // actually a pointer to the var
"_Cfpvar_fp_",
"_Cfunc_",
"_Cmacro_", // function to evaluate the expanded expression
}
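// Illustrative note (not part of the original source): cgo rewrites a
// reference such as C.puts into the object _Cfunc_puts declared in the
// generated file _cgo_gotypes.go; the prefixes above cover the different
// object kinds (constants, types, variables, functions, macros) that the
// selector code below must try when resolving a C.name reference.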
func (check *Checker) selector(x *operand, e *ast.SelectorExpr, def *Named, wantType bool) {
// these must be declared before the "goto Error" statements
var (
obj Object
index []int
indirect bool
)
sel := e.Sel.Name
// If the identifier refers to a package, handle everything here
// so we don't need a "package" mode for operands: package names
// can only appear in qualified identifiers which are mapped to
// selector expressions.
if ident, ok := e.X.(*ast.Ident); ok {
obj := check.lookup(ident.Name)
if pname, _ := obj.(*PkgName); pname != nil {
assert(pname.pkg == check.pkg)
check.recordUse(ident, pname)
pname.used = true
pkg := pname.imported
var exp Object
funcMode := value
if pkg.cgo {
// cgo special cases C.malloc: it's
// rewritten to _CMalloc and does not
// support two-result calls.
if sel == "malloc" {
sel = "_CMalloc"
} else {
funcMode = cgofunc
}
for _, prefix := range cgoPrefixes {
// cgo objects are part of the current package (in file
// _cgo_gotypes.go). Use regular lookup.
_, exp = check.scope.LookupParent(prefix+sel, check.pos)
if exp != nil {
break
}
}
if exp == nil {
check.errorf(e.Sel, UndeclaredImportedName, "undefined: %s", ast.Expr(e)) // cast to ast.Expr to silence vet
goto Error
}
check.objDecl(exp, nil)
} else {
exp = pkg.scope.Lookup(sel)
if exp == nil {
if !pkg.fake {
check.errorf(e.Sel, UndeclaredImportedName, "undefined: %s", ast.Expr(e))
}
goto Error
}
if !exp.Exported() {
check.errorf(e.Sel, UnexportedName, "%s not exported by package %s", sel, pkg.name)
// ok to continue
}
}
check.recordUse(e.Sel, exp)
// Simplified version of the code for *ast.Idents:
// - imported objects are always fully initialized
switch exp := exp.(type) {
case *Const:
assert(exp.Val() != nil)
x.mode = constant_
x.typ = exp.typ
x.val = exp.val
case *TypeName:
x.mode = typexpr
x.typ = exp.typ
case *Var:
x.mode = variable
x.typ = exp.typ
if pkg.cgo && strings.HasPrefix(exp.name, "_Cvar_") {
x.typ = x.typ.(*Pointer).base
}
case *Func:
x.mode = funcMode
x.typ = exp.typ
if pkg.cgo && strings.HasPrefix(exp.name, "_Cmacro_") {
x.mode = value
x.typ = x.typ.(*Signature).results.vars[0].typ
}
case *Builtin:
x.mode = builtin
x.typ = exp.typ
x.id = exp.id
default:
check.dump("%v: unexpected object %v", e.Sel.Pos(), exp)
unreachable()
}
x.expr = e
return
}
}
check.exprOrType(x, e.X, false)
switch x.mode {
case typexpr:
// don't crash for "type T T.x" (was go.dev/issue/51509)
if def != nil && x.typ == def {
check.cycleError([]Object{def.obj})
goto Error
}
case builtin:
// types2 uses the position of '.' for the error
check.errorf(e.Sel, UncalledBuiltin, "cannot select on %s", x)
goto Error
case invalid:
goto Error
}
// Avoid crashing when checking an invalid selector in a method declaration
// (i.e., where def is not set):
//
// type S[T any] struct{}
// type V = S[any]
// func (fs *S[T]) M(x V.M) {}
//
// All codepaths below return a non-type expression. If we get here while
// expecting a type expression, it is an error.
//
// See go.dev/issue/57522 for more details.
//
// TODO(rfindley): We should do better by refusing to check selectors in all cases where
// x.typ is incomplete.
if wantType {
check.errorf(e.Sel, NotAType, "%s is not a type", ast.Expr(e))
goto Error
}
obj, index, indirect = LookupFieldOrMethod(x.typ, x.mode == variable, check.pkg, sel)
if obj == nil {
// Don't report another error if the underlying type was invalid (go.dev/issue/49541).
if under(x.typ) == Typ[Invalid] {
goto Error
}
if index != nil {
// TODO(gri) should provide actual type where the conflict happens
check.errorf(e.Sel, AmbiguousSelector, "ambiguous selector %s.%s", x.expr, sel)
goto Error
}
if indirect {
if x.mode == typexpr {
check.errorf(e.Sel, InvalidMethodExpr, "invalid method expression %s.%s (needs pointer receiver (*%s).%s)", x.typ, sel, x.typ, sel)
} else {
check.errorf(e.Sel, InvalidMethodExpr, "cannot call pointer method %s on %s", sel, x.typ)
}
goto Error
}
var why string
if isInterfacePtr(x.typ) {
why = check.interfacePtrError(x.typ)
} else {
why = check.sprintf("type %s has no field or method %s", x.typ, sel)
// Check if capitalization of sel matters and provide better error message in that case.
// TODO(gri) This code only looks at the first character but LookupFieldOrMethod should
// have an (internal) mechanism for case-insensitive lookup that we should use
// instead (see types2).
if len(sel) > 0 {
var changeCase string
if r := rune(sel[0]); unicode.IsUpper(r) {
changeCase = string(unicode.ToLower(r)) + sel[1:]
} else {
changeCase = string(unicode.ToUpper(r)) + sel[1:]
}
if obj, _, _ = LookupFieldOrMethod(x.typ, x.mode == variable, check.pkg, changeCase); obj != nil {
why += ", but does have " + changeCase
}
}
}
check.errorf(e.Sel, MissingFieldOrMethod, "%s.%s undefined (%s)", x.expr, sel, why)
goto Error
}
// methods may not have a fully set up signature yet
if m, _ := obj.(*Func); m != nil {
check.objDecl(m, nil)
}
if x.mode == typexpr {
// method expression
m, _ := obj.(*Func)
if m == nil {
// TODO(gri) should check if capitalization of sel matters and provide better error message in that case
check.errorf(e.Sel, MissingFieldOrMethod, "%s.%s undefined (type %s has no method %s)", x.expr, sel, x.typ, sel)
goto Error
}
check.recordSelection(e, MethodExpr, x.typ, m, index, indirect)
sig := m.typ.(*Signature)
if sig.recv == nil {
check.error(e, InvalidDeclCycle, "illegal cycle in method declaration")
goto Error
}
// the receiver type becomes the type of the first function
// argument of the method expression's function type
var params []*Var
if sig.params != nil {
params = sig.params.vars
}
// Be consistent about named/unnamed parameters. This is not needed
// for type-checking, but the newly constructed signature may appear
// in an error message and then have mixed named/unnamed parameters.
// (An alternative would be to not print parameter names in errors,
// but it's useful to see them; this is cheap and method expressions
// are rare.)
name := ""
if len(params) > 0 && params[0].name != "" {
// name needed
name = sig.recv.name
if name == "" {
name = "_"
}
}
params = append([]*Var{NewVar(sig.recv.pos, sig.recv.pkg, name, x.typ)}, params...)
x.mode = value
x.typ = &Signature{
tparams: sig.tparams,
params: NewTuple(params...),
results: sig.results,
variadic: sig.variadic,
}
check.addDeclDep(m)
} else {
// regular selector
switch obj := obj.(type) {
case *Var:
check.recordSelection(e, FieldVal, x.typ, obj, index, indirect)
if x.mode == variable || indirect {
x.mode = variable
} else {
x.mode = value
}
x.typ = obj.typ
case *Func:
// TODO(gri) If we needed to take into account the receiver's
// addressability, should we report the type &(x.typ) instead?
check.recordSelection(e, MethodVal, x.typ, obj, index, indirect)
// TODO(gri) The verification pass below is disabled for now because
// method sets don't match method lookup in some cases.
// For instance, if we made a copy above when creating a
// custom method for a parameterized receiver type, the
// method set method doesn't match (no copy there). There
// may be other situations.
disabled := true
if !disabled && debug {
// Verify that LookupFieldOrMethod and MethodSet.Lookup agree.
// TODO(gri) This only works because we call LookupFieldOrMethod
// _before_ calling NewMethodSet: LookupFieldOrMethod completes
// any incomplete interfaces so they are available to NewMethodSet
// (which assumes that interfaces have been completed already).
typ := x.typ
if x.mode == variable {
// If typ is not an (unnamed) pointer or an interface,
// use *typ instead, because the method set of *typ
// includes the methods of typ.
// Variables are addressable, so we can always take their
// address.
if _, ok := typ.(*Pointer); !ok && !IsInterface(typ) {
typ = &Pointer{base: typ}
}
}
// If we created a synthetic pointer type above, we will throw
// away the method set computed here after use.
// TODO(gri) Method set computation should probably always compute
// both, the value and the pointer receiver method set and represent
// them in a single structure.
// TODO(gri) Consider also using a method set cache for the lifetime
// of checker once we rely on MethodSet lookup instead of individual
// lookup.
mset := NewMethodSet(typ)
if m := mset.Lookup(check.pkg, sel); m == nil || m.obj != obj {
check.dump("%v: (%s).%v -> %s", e.Pos(), typ, obj.name, m)
check.dump("%s\n", mset)
// Caution: MethodSets are supposed to be used externally
// only (after all interface types were completed). It's
// now possible that we get here incorrectly. Not urgent
// to fix since we only run this code in debug mode.
// TODO(gri) fix this eventually.
panic("method sets and lookup don't agree")
}
}
x.mode = value
// remove receiver
sig := *obj.typ.(*Signature)
sig.recv = nil
x.typ = &sig
check.addDeclDep(obj)
default:
unreachable()
}
}
// everything went well
x.expr = e
return
Error:
x.mode = invalid
x.expr = e
}
// use type-checks each argument.
// Useful to make sure expressions are evaluated
// (and variables are "used") in the presence of other errors.
// The arguments may be nil.
func (check *Checker) use(arg ...ast.Expr) {
var x operand
for _, e := range arg {
// The nil check below is necessary since certain AST fields
// may legally be nil (e.g., the ast.SliceExpr.High field).
if e != nil {
check.rawExpr(&x, e, nil, false)
}
}
}
// useLHS is like use, but doesn't "use" top-level identifiers.
// It should be called instead of use if the arguments are
// expressions on the lhs of an assignment.
// The arguments must not be nil.
func (check *Checker) useLHS(arg ...ast.Expr) {
var x operand
for _, e := range arg {
// If the lhs is an identifier denoting a variable v, this assignment
// is not a 'use' of v. Remember current value of v.used and restore
// after evaluating the lhs via check.rawExpr.
var v *Var
var v_used bool
if ident, _ := unparen(e).(*ast.Ident); ident != nil {
// never type-check the blank name on the lhs
if ident.Name == "_" {
continue
}
if _, obj := check.scope.LookupParent(ident.Name, nopos); obj != nil {
// It's ok to mark non-local variables, but ignore variables
// from other packages to avoid potential race conditions with
// dot-imported variables.
if w, _ := obj.(*Var); w != nil && w.pkg == check.pkg {
v = w
v_used = v.used
}
}
}
check.rawExpr(&x, e, nil, false)
if v != nil {
v.used = v_used // restore v.used
}
}
}
// Code generated by "go test -run=Generate -write=all"; DO NOT EDIT.
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package types
// A Chan represents a channel type.
type Chan struct {
dir ChanDir
elem Type
}
// A ChanDir value indicates a channel direction.
type ChanDir int
// The direction of a channel is indicated by one of these constants.
const (
SendRecv ChanDir = iota
SendOnly
RecvOnly
)
// NewChan returns a new channel type for the given direction and element type.
func NewChan(dir ChanDir, elem Type) *Chan {
return &Chan{dir: dir, elem: elem}
}
// Dir returns the direction of channel c.
func (c *Chan) Dir() ChanDir { return c.dir }
// Elem returns the element type of channel c.
func (c *Chan) Elem() Type { return c.elem }
func (c *Chan) Underlying() Type { return c }
func (c *Chan) String() string { return TypeString(c, nil) }
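// Example (illustrative, not part of the original source), using the
// exported API:
//
//	c := NewChan(RecvOnly, Typ[Int])
//	_ = c.Dir()            // RecvOnly
//	_ = c.Elem()           // int
//	_ = TypeString(c, nil) // "<-chan int"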
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements the Check function, which drives type-checking.
package types
import (
"errors"
"fmt"
"go/ast"
"go/constant"
"go/token"
. "internal/types/errors"
)
// nopos indicates an unknown position
var nopos token.Pos
// debugging/development support
const debug = false // set to true during development
// exprInfo stores information about an untyped expression.
type exprInfo struct {
isLhs bool // expression is lhs operand of a shift with delayed type-check
mode operandMode
typ *Basic
val constant.Value // constant value; or nil (if not a constant)
}
// An environment represents the environment within which an object is
// type-checked.
type environment struct {
decl *declInfo // package-level declaration whose init expression/function body is checked
scope *Scope // top-most scope for lookups
pos token.Pos // if valid, identifiers are looked up as if at position pos (used by Eval)
iota constant.Value // value of iota in a constant declaration; nil otherwise
errpos positioner // if set, identifier position of a constant with inherited initializer
inTParamList bool // set if inside a type parameter list
sig *Signature // function signature if inside a function; nil otherwise
isPanic map[*ast.CallExpr]bool // set of panic call expressions (used for termination check)
hasLabel bool // set if a function makes use of labels (only ~1% of functions); unused outside functions
hasCallOrRecv bool // set if an expression contains a function call or channel receive operation
}
// lookup looks up name in the current environment and returns the matching object, or nil.
func (env *environment) lookup(name string) Object {
_, obj := env.scope.LookupParent(name, env.pos)
return obj
}
// An importKey identifies an imported package by import path and source directory
// (directory containing the file containing the import). In practice, the directory
// may always be the same, or may not matter. Given an (import path, directory), an
// importer must always return the same package (but given two different import paths,
// an importer may still return the same package by mapping them to the same package
// paths).
type importKey struct {
path, dir string
}
// A dotImportKey describes a dot-imported object in the given scope.
type dotImportKey struct {
scope *Scope
name string
}
// An action describes a (delayed) action.
type action struct {
f func() // action to be executed
desc *actionDesc // action description; may be nil, requires debug to be set
}
// If debug is set, describef sets a printf-formatted description for action a.
// Otherwise, it is a no-op.
func (a *action) describef(pos positioner, format string, args ...any) {
if debug {
a.desc = &actionDesc{pos, format, args}
}
}
// An actionDesc provides information on an action.
// For debugging only.
type actionDesc struct {
pos positioner
format string
args []any
}
// A Checker maintains the state of the type checker.
// It must be created with NewChecker.
type Checker struct {
// package information
// (initialized by NewChecker, valid for the life-time of checker)
conf *Config
ctxt *Context // context for de-duplicating instances
fset *token.FileSet
pkg *Package
*Info
version version // accepted language version
nextID uint64 // unique Id for type parameters (first valid Id is 1)
objMap map[Object]*declInfo // maps package-level objects and (non-interface) methods to declaration info
impMap map[importKey]*Package // maps (import path, source directory) to (complete or fake) package
valids instanceLookup // valid *Named (incl. instantiated) types per the validType check
// pkgPathMap maps package names to the set of distinct import paths we've
// seen for that name, anywhere in the import graph. It is used for
// disambiguating package names in error messages.
//
// pkgPathMap is allocated lazily, so that we don't pay the price of building
// it on the happy path. seenPkgMap tracks the packages that we've already
// walked.
pkgPathMap map[string]map[string]bool
seenPkgMap map[*Package]bool
// information collected during type-checking of a set of package files
// (initialized by Files, valid only for the duration of check.Files;
// maps and lists are allocated on demand)
files []*ast.File // package files
imports []*PkgName // list of imported packages
dotImportMap map[dotImportKey]*PkgName // maps dot-imported objects to the package they were dot-imported through
recvTParamMap map[*ast.Ident]*TypeParam // maps blank receiver type parameters to their type
brokenAliases map[*TypeName]bool // set of aliases with broken (not yet determined) types
unionTypeSets map[*Union]*_TypeSet // computed type sets for union types
mono monoGraph // graph for detecting non-monomorphizable instantiation loops
firstErr error // first error encountered
methods map[*TypeName][]*Func // maps package scope type names to associated non-blank (non-interface) methods
untyped map[ast.Expr]exprInfo // map of expressions without final type
delayed []action // stack of delayed action segments; segments are processed in FIFO order
objPath []Object // path of object dependencies during type inference (for cycle reporting)
cleaners []cleaner // list of types that may need a final cleanup at the end of type-checking
// environment within which the current object is type-checked (valid only
// for the duration of type-checking a specific object)
environment
// debugging
indent int // indentation for tracing
}
// addDeclDep adds the dependency edge (check.decl -> to) if check.decl exists
func (check *Checker) addDeclDep(to Object) {
from := check.decl
if from == nil {
return // not in a package-level init expression
}
if _, found := check.objMap[to]; !found {
return // to is not a package-level object
}
from.addDep(to)
}
// brokenAlias records that alias doesn't have a determined type yet.
// It also sets alias.typ to Typ[Invalid].
func (check *Checker) brokenAlias(alias *TypeName) {
if check.brokenAliases == nil {
check.brokenAliases = make(map[*TypeName]bool)
}
check.brokenAliases[alias] = true
alias.typ = Typ[Invalid]
}
// validAlias records that alias has the valid type typ (possibly Typ[Invalid]).
func (check *Checker) validAlias(alias *TypeName, typ Type) {
delete(check.brokenAliases, alias)
alias.typ = typ
}
// isBrokenAlias reports whether alias doesn't have a determined type yet.
func (check *Checker) isBrokenAlias(alias *TypeName) bool {
return alias.typ == Typ[Invalid] && check.brokenAliases[alias]
}
func (check *Checker) rememberUntyped(e ast.Expr, lhs bool, mode operandMode, typ *Basic, val constant.Value) {
m := check.untyped
if m == nil {
m = make(map[ast.Expr]exprInfo)
check.untyped = m
}
m[e] = exprInfo{lhs, mode, typ, val}
}
// later pushes f on to the stack of actions that will be processed later;
// either at the end of the current statement, or in case of a local constant
// or variable declaration, before the constant or variable is in scope
// (so that f still sees the scope before any new declarations).
// later returns the pushed action so one can provide a description
// via action.describef for debugging, if desired.
func (check *Checker) later(f func()) *action {
i := len(check.delayed)
check.delayed = append(check.delayed, action{f: f})
return &check.delayed[i]
}
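// Example (illustrative, not part of the original source), mirroring the
// use in instantiateSignature earlier in this document:
//
//	check.later(func() {
//		// ... deferred verification work ...
//	}).describef(atPos(pos), "verify instantiation")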
// push pushes obj onto the object path and returns its index in the path.
func (check *Checker) push(obj Object) int {
check.objPath = append(check.objPath, obj)
return len(check.objPath) - 1
}
// pop pops and returns the topmost object from the object path.
func (check *Checker) pop() Object {
i := len(check.objPath) - 1
obj := check.objPath[i]
check.objPath[i] = nil
check.objPath = check.objPath[:i]
return obj
}
type cleaner interface {
cleanup()
}
// needsCleanup records objects/types that implement the cleanup method
// which will be called at the end of type-checking.
func (check *Checker) needsCleanup(c cleaner) {
check.cleaners = append(check.cleaners, c)
}
// NewChecker returns a new Checker instance for a given package.
// Package files may be added incrementally via checker.Files.
func NewChecker(conf *Config, fset *token.FileSet, pkg *Package, info *Info) *Checker {
// make sure we have a configuration
if conf == nil {
conf = new(Config)
}
// make sure we have an info struct
if info == nil {
info = new(Info)
}
version, err := parseGoVersion(conf.GoVersion)
if err != nil {
panic(fmt.Sprintf("invalid Go version %q (%v)", conf.GoVersion, err))
}
return &Checker{
conf: conf,
ctxt: conf.Context,
fset: fset,
pkg: pkg,
Info: info,
version: version,
objMap: make(map[Object]*declInfo),
impMap: make(map[importKey]*Package),
}
}
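// Example (illustrative sketch, not part of the original source; fset,
// files, and myImporter are placeholders supplied by the caller):
//
//	conf := &Config{Importer: myImporter}
//	pkg := NewPackage("example.org/p", "p")
//	info := &Info{Types: make(map[ast.Expr]TypeAndValue)}
//	check := NewChecker(conf, fset, pkg, info)
//	err := check.Files(files) // type-check the package files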
// initFiles initializes the files-specific portion of checker.
// The provided files must all belong to the same package.
func (check *Checker) initFiles(files []*ast.File) {
// start with a clean slate (check.Files may be called multiple times)
check.files = nil
check.imports = nil
check.dotImportMap = nil
check.firstErr = nil
check.methods = nil
check.untyped = nil
check.delayed = nil
check.objPath = nil
check.cleaners = nil
// determine package name and collect valid files
pkg := check.pkg
for _, file := range files {
switch name := file.Name.Name; pkg.name {
case "":
if name != "_" {
pkg.name = name
} else {
check.error(file.Name, BlankPkgName, "invalid package name _")
}
fallthrough
case name:
check.files = append(check.files, file)
default:
check.errorf(atPos(file.Package), MismatchedPkgName, "package %s; expected %s", name, pkg.name)
// ignore this file
}
}
}
// A bailout panic is used for early termination.
type bailout struct{}
func (check *Checker) handleBailout(err *error) {
switch p := recover().(type) {
case nil, bailout:
// normal return or early exit
*err = check.firstErr
default:
// re-panic
panic(p)
}
}
// Files checks the provided files as part of the checker's package.
func (check *Checker) Files(files []*ast.File) error { return check.checkFiles(files) }
var errBadCgo = errors.New("cannot use FakeImportC and go115UsesCgo together")
func (check *Checker) checkFiles(files []*ast.File) (err error) {
if check.conf.FakeImportC && check.conf.go115UsesCgo {
return errBadCgo
}
defer check.handleBailout(&err)
print := func(msg string) {
if check.conf._Trace {
fmt.Println()
fmt.Println(msg)
}
}
print("== initFiles ==")
check.initFiles(files)
print("== collectObjects ==")
check.collectObjects()
print("== packageObjects ==")
check.packageObjects()
print("== processDelayed ==")
check.processDelayed(0) // incl. all functions
print("== cleanup ==")
check.cleanup()
print("== initOrder ==")
check.initOrder()
if !check.conf.DisableUnusedImportCheck {
print("== unusedImports ==")
check.unusedImports()
}
print("== recordUntyped ==")
check.recordUntyped()
if check.firstErr == nil {
// TODO(mdempsky): Ensure monomorph is safe when errors exist.
check.monomorph()
}
check.pkg.complete = true
// no longer needed - release memory
check.imports = nil
check.dotImportMap = nil
check.pkgPathMap = nil
check.seenPkgMap = nil
check.recvTParamMap = nil
check.brokenAliases = nil
check.unionTypeSets = nil
check.ctxt = nil
// TODO(rFindley) There's more memory we should release at this point.
return
}
// processDelayed processes all delayed actions pushed after top.
func (check *Checker) processDelayed(top int) {
// If each delayed action pushes a new action, the
// stack will continue to grow during this loop.
// However, it is only processing functions (which
// are processed in a delayed fashion) that may
// add more actions (such as nested functions), so
// this is a sufficiently bounded process.
for i := top; i < len(check.delayed); i++ {
a := &check.delayed[i]
if check.conf._Trace {
if a.desc != nil {
check.trace(a.desc.pos.Pos(), "-- "+a.desc.format, a.desc.args...)
} else {
check.trace(nopos, "-- delayed %p", a.f)
}
}
a.f() // may append to check.delayed
if check.conf._Trace {
fmt.Println()
}
}
assert(top <= len(check.delayed)) // stack must not have shrunk
check.delayed = check.delayed[:top]
}
// cleanup runs cleanup for all collected cleaners.
func (check *Checker) cleanup() {
// Don't use a range clause since Named.cleanup may add more cleaners.
for i := 0; i < len(check.cleaners); i++ {
check.cleaners[i].cleanup()
}
check.cleaners = nil
}
func (check *Checker) record(x *operand) {
// convert x into a user-friendly set of values
// TODO(gri) this code can be simplified
var typ Type
var val constant.Value
switch x.mode {
case invalid:
typ = Typ[Invalid]
case novalue:
typ = (*Tuple)(nil)
case constant_:
typ = x.typ
val = x.val
default:
typ = x.typ
}
assert(x.expr != nil && typ != nil)
if isUntyped(typ) {
// delay type and value recording until we know the type
// or until the end of type checking
check.rememberUntyped(x.expr, false, x.mode, typ.(*Basic), val)
} else {
check.recordTypeAndValue(x.expr, x.mode, typ, val)
}
}
func (check *Checker) recordUntyped() {
if !debug && check.Types == nil {
return // nothing to do
}
for x, info := range check.untyped {
if debug && isTyped(info.typ) {
check.dump("%v: %s (type %s) is typed", x.Pos(), x, info.typ)
unreachable()
}
check.recordTypeAndValue(x, info.mode, info.typ, info.val)
}
}
func (check *Checker) recordTypeAndValue(x ast.Expr, mode operandMode, typ Type, val constant.Value) {
assert(x != nil)
assert(typ != nil)
if mode == invalid {
return // omit
}
if mode == constant_ {
assert(val != nil)
// We check allBasic(typ, IsConstType) here as constant expressions may be
// recorded as type parameters.
assert(typ == Typ[Invalid] || allBasic(typ, IsConstType))
}
if m := check.Types; m != nil {
m[x] = TypeAndValue{mode, typ, val}
}
}
func (check *Checker) recordBuiltinType(f ast.Expr, sig *Signature) {
// f must be a (possibly parenthesized, possibly qualified)
// identifier denoting a built-in (including unsafe's non-constant
// functions Add and Slice): record the signature for f and possible
// children.
for {
check.recordTypeAndValue(f, builtin, sig, nil)
switch p := f.(type) {
case *ast.Ident, *ast.SelectorExpr:
return // we're done
case *ast.ParenExpr:
f = p.X
default:
unreachable()
}
}
}
func (check *Checker) recordCommaOkTypes(x ast.Expr, a [2]Type) {
assert(x != nil)
if a[0] == nil || a[1] == nil {
return
}
assert(isTyped(a[0]) && isTyped(a[1]) && (isBoolean(a[1]) || a[1] == universeError))
if m := check.Types; m != nil {
for {
tv := m[x]
assert(tv.Type != nil) // should have been recorded already
pos := x.Pos()
tv.Type = NewTuple(
NewVar(pos, check.pkg, "", a[0]),
NewVar(pos, check.pkg, "", a[1]),
)
m[x] = tv
// if x is a parenthesized expression (p.X), update p.X
p, _ := x.(*ast.ParenExpr)
if p == nil {
break
}
x = p.X
}
}
}
// recordInstance records instantiation information into check.Info, if the
// Instances map is non-nil. The given expr must be an ident, selector, or
// index (list) expr with ident or selector operand.
//
// TODO(rfindley): the expr parameter is fragile. See if we can access the
// instantiated identifier in some other way.
func (check *Checker) recordInstance(expr ast.Expr, targs []Type, typ Type) {
ident := instantiatedIdent(expr)
assert(ident != nil)
assert(typ != nil)
if m := check.Instances; m != nil {
m[ident] = Instance{newTypeList(targs), typ}
}
}
func instantiatedIdent(expr ast.Expr) *ast.Ident {
var selOrIdent ast.Expr
switch e := expr.(type) {
case *ast.IndexExpr:
selOrIdent = e.X
case *ast.IndexListExpr:
selOrIdent = e.X
case *ast.SelectorExpr, *ast.Ident:
selOrIdent = e
}
switch x := selOrIdent.(type) {
case *ast.Ident:
return x
case *ast.SelectorExpr:
return x.Sel
}
panic("instantiated ident not found")
}
func (check *Checker) recordDef(id *ast.Ident, obj Object) {
assert(id != nil)
if m := check.Defs; m != nil {
m[id] = obj
}
}
func (check *Checker) recordUse(id *ast.Ident, obj Object) {
assert(id != nil)
assert(obj != nil)
if m := check.Uses; m != nil {
m[id] = obj
}
}
func (check *Checker) recordImplicit(node ast.Node, obj Object) {
assert(node != nil)
assert(obj != nil)
if m := check.Implicits; m != nil {
m[node] = obj
}
}
func (check *Checker) recordSelection(x *ast.SelectorExpr, kind SelectionKind, recv Type, obj Object, index []int, indirect bool) {
assert(obj != nil && (recv == nil || len(index) > 0))
check.recordUse(x.Sel, obj)
if m := check.Selections; m != nil {
m[x] = &Selection{kind, recv, obj, index, indirect}
}
}
func (check *Checker) recordScope(node ast.Node, scope *Scope) {
assert(node != nil)
assert(scope != nil)
if m := check.Scopes; m != nil {
m[node] = scope
}
}
// Code generated by "go test -run=Generate -write=all"; DO NOT EDIT.
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package types
import (
"bytes"
"fmt"
"strconv"
"strings"
"sync"
)
// This file contains a definition of the type-checking context; an opaque type
// that may be supplied by users during instantiation.
//
// Contexts serve two purposes:
// - reduce the duplication of identical instances
// - short-circuit instantiation cycles
//
// For the latter purpose, we must always have a context during instantiation,
// whether or not it is supplied by the user. For both purposes, it must be the
// case that hashing a pointer-identical type produces consistent results
// (somewhat obviously).
//
// However, neither of these purposes require that our hash is perfect, and so
// this was not an explicit design goal of the context type. In fact, due to
// concurrent use it is convenient not to guarantee de-duplication.
//
// Nevertheless, in the future it could be helpful to allow users to leverage
// contexts to canonicalize instances, and it would probably be possible to
// achieve such a guarantee.
// A Context is an opaque type checking context. It may be used to share
// identical type instances across type-checked packages or calls to
// Instantiate. Contexts are safe for concurrent use.
//
// The use of a shared context does not guarantee that identical instances are
// deduplicated in all cases.
type Context struct {
mu sync.Mutex
typeMap map[string][]ctxtEntry // type hash -> instances entries
nextID int // next unique ID
originIDs map[Type]int // origin type -> unique ID
}
type ctxtEntry struct {
orig Type
targs []Type
instance Type // = orig[targs]
}
// NewContext creates a new Context.
func NewContext() *Context {
return &Context{
typeMap: make(map[string][]ctxtEntry),
originIDs: make(map[Type]int),
}
}
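// Example (illustrative, not part of the original source; orig is a
// placeholder for some generic named type): a shared Context lets
// separate Instantiate calls reuse identical instances:
//
//	ctxt := NewContext()
//	i1, _ := Instantiate(ctxt, orig, []Type{Typ[Int]}, false)
//	i2, _ := Instantiate(ctxt, orig, []Type{Typ[Int]}, false)
//	// i1 and i2 are usually, though not guaranteed to be, the same instance.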
// instanceHash returns a string representation of typ instantiated with targs.
// The hash should be a perfect hash, though out of caution the type checker
// does not assume this. The result is guaranteed to not contain blanks.
func (ctxt *Context) instanceHash(orig Type, targs []Type) string {
assert(ctxt != nil)
assert(orig != nil)
var buf bytes.Buffer
h := newTypeHasher(&buf, ctxt)
h.string(strconv.Itoa(ctxt.getID(orig)))
// Because we've already written the unique origin ID this call to h.typ is
// unnecessary, but we leave it for hash readability. It can be removed later
// if performance is an issue.
h.typ(orig)
if len(targs) > 0 {
// TODO(rfindley): consider asserting on isGeneric(typ) here, if and when
// isGeneric handles *Signature types.
h.typeList(targs)
}
return strings.Replace(buf.String(), " ", "#", -1) // ReplaceAll is not available in Go1.4
}
// lookup returns an existing instantiation of orig with targs, if it exists.
// Otherwise, it returns nil.
func (ctxt *Context) lookup(h string, orig Type, targs []Type) Type {
ctxt.mu.Lock()
defer ctxt.mu.Unlock()
for _, e := range ctxt.typeMap[h] {
if identicalInstance(orig, targs, e.orig, e.targs) {
return e.instance
}
if debug {
// Panic during development to surface any imperfections in our hash.
panic(fmt.Sprintf("non-identical instances: (orig: %s, targs: %v) and %s", orig, targs, e.instance))
}
}
return nil
}
// update de-duplicates inst against previously seen types with the hash h.
// If an identical type is found with the type hash h, the previously seen
// type is returned. Otherwise, inst is returned and recorded in the Context
// for the hash h.
func (ctxt *Context) update(h string, orig Type, targs []Type, inst Type) Type {
assert(inst != nil)
ctxt.mu.Lock()
defer ctxt.mu.Unlock()
for _, e := range ctxt.typeMap[h] {
if inst == nil || Identical(inst, e.instance) {
return e.instance
}
if debug {
// Panic during development to surface any imperfections in our hash.
panic(fmt.Sprintf("%s and %s are not identical", inst, e.instance))
}
}
ctxt.typeMap[h] = append(ctxt.typeMap[h], ctxtEntry{
orig: orig,
targs: targs,
instance: inst,
})
return inst
}
// getID returns a unique ID for the type t.
func (ctxt *Context) getID(t Type) int {
ctxt.mu.Lock()
defer ctxt.mu.Unlock()
id, ok := ctxt.originIDs[t]
if !ok {
id = ctxt.nextID
ctxt.originIDs[t] = id
ctxt.nextID++
}
return id
}
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements typechecking of conversions.
package types
import (
"go/constant"
. "internal/types/errors"
"unicode"
)
// conversion type-checks the conversion T(x).
// The result is in x.
func (check *Checker) conversion(x *operand, T Type) {
constArg := x.mode == constant_
constConvertibleTo := func(T Type, val *constant.Value) bool {
switch t, _ := under(T).(*Basic); {
case t == nil:
// nothing to do
case representableConst(x.val, check, t, val):
return true
case isInteger(x.typ) && isString(t):
codepoint := unicode.ReplacementChar
if i, ok := constant.Uint64Val(x.val); ok && i <= unicode.MaxRune {
codepoint = rune(i)
}
if val != nil {
*val = constant.MakeString(string(codepoint))
}
return true
}
return false
}
var ok bool
var cause string
switch {
case constArg && isConstType(T):
// constant conversion
ok = constConvertibleTo(T, &x.val)
case constArg && isTypeParam(T):
// x is convertible to T if it is convertible
// to each specific type in the type set of T.
// If T's type set is empty, or if it doesn't
// have specific types, constant x cannot be
// converted.
ok = T.(*TypeParam).underIs(func(u Type) bool {
// u is nil if there are no specific type terms
if u == nil {
cause = check.sprintf("%s does not contain specific types", T)
return false
}
if isString(x.typ) && isBytesOrRunes(u) {
return true
}
if !constConvertibleTo(u, nil) {
cause = check.sprintf("cannot convert %s to type %s (in %s)", x, u, T)
return false
}
return true
})
x.mode = value // type parameters are not constants
case x.convertibleTo(check, T, &cause):
// non-constant conversion
ok = true
x.mode = value
}
if !ok {
if cause != "" {
check.errorf(x, InvalidConversion, "cannot convert %s to type %s: %s", x, T, cause)
} else {
check.errorf(x, InvalidConversion, "cannot convert %s to type %s", x, T)
}
x.mode = invalid
return
}
// The conversion argument types are final. For untyped values the
// conversion provides the type, per the spec: "A constant may be
// given a type explicitly by a constant declaration or conversion,...".
if isUntyped(x.typ) {
final := T
// - For conversions to interfaces, use the argument's default type.
// - For conversions of untyped constants to non-constant types, also
// use the default type (e.g., []byte("foo") should report string
// not []byte as type for the constant "foo").
// - Keep untyped nil for untyped nil arguments.
// - For constant integer to string conversions, keep the argument type.
// (See also the TODO below.)
if isNonTypeParamInterface(T) || constArg && !isConstType(T) || x.isNil() {
final = Default(x.typ) // default type of untyped nil is untyped nil
} else if x.mode == constant_ && isInteger(x.typ) && allString(T) {
final = x.typ
}
check.updateExprType(x.expr, final, true)
}
x.typ = T
}
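// Examples (illustrative, not part of the original source) of constant
// conversions as handled above:
//
//	string(rune(65)) // "A"
//	string(66)       // "B": the integer value becomes a code point
//	string(1 << 40)  // "\uFFFD": value exceeds MaxRune, so ReplacementChar is used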
// TODO(gri) convertibleTo checks if T(x) is valid. It assumes that the type
// of x is fully known, but that's not the case for say string(1<<s + 1.0):
// Here, the type of 1<<s + 1.0 will be UntypedFloat which will lead to the
// (correct!) refusal of the conversion. But the reported error is essentially
// "cannot convert untyped float value to string", yet the correct error (per
// the spec) is that we cannot shift a floating-point value: 1 in 1<<s should
// be converted to UntypedFloat because of the addition of 1.0. Fixing this
// is tricky because we'd have to run updateExprType on the argument first.
// (go.dev/issue/21982.)
// convertibleTo reports whether T(x) is valid. In the failure case, *cause
// may be set to the cause for the failure.
// The check parameter may be nil if convertibleTo is invoked through an
// exported API call, i.e., when all methods have been type-checked.
func (x *operand) convertibleTo(check *Checker, T Type, cause *string) bool {
// "x is assignable to T"
if ok, _ := x.assignableTo(check, T, cause); ok {
return true
}
// "V and T have identical underlying types if tags are ignored
// and V and T are not type parameters"
V := x.typ
Vu := under(V)
Tu := under(T)
Vp, _ := V.(*TypeParam)
Tp, _ := T.(*TypeParam)
if IdenticalIgnoreTags(Vu, Tu) && Vp == nil && Tp == nil {
return true
}
// "V and T are unnamed pointer types and their pointer base types
// have identical underlying types if tags are ignored
// and their pointer base types are not type parameters"
if V, ok := V.(*Pointer); ok {
if T, ok := T.(*Pointer); ok {
if IdenticalIgnoreTags(under(V.base), under(T.base)) && !isTypeParam(V.base) && !isTypeParam(T.base) {
return true
}
}
}
// "V and T are both integer or floating point types"
if isIntegerOrFloat(Vu) && isIntegerOrFloat(Tu) {
return true
}
// "V and T are both complex types"
if isComplex(Vu) && isComplex(Tu) {
return true
}
// "V is an integer or a slice of bytes or runes and T is a string type"
if (isInteger(Vu) || isBytesOrRunes(Vu)) && isString(Tu) {
return true
}
// "V is a string and T is a slice of bytes or runes"
if isString(Vu) && isBytesOrRunes(Tu) {
return true
}
// package unsafe:
// "any pointer or value of underlying type uintptr can be converted into a unsafe.Pointer"
if (isPointer(Vu) || isUintptr(Vu)) && isUnsafePointer(Tu) {
return true
}
// "and vice versa"
if isUnsafePointer(Vu) && (isPointer(Tu) || isUintptr(Tu)) {
return true
}
// "V is a slice, T is an array or pointer-to-array type,
// and the slice and array types have identical element types."
if s, _ := Vu.(*Slice); s != nil {
switch a := Tu.(type) {
case *Array:
if Identical(s.Elem(), a.Elem()) {
if check == nil || check.allowVersion(check.pkg, 1, 20) {
return true
}
// check != nil
if cause != nil {
// TODO(gri) consider restructuring versionErrorf so we can use it here and below
*cause = "conversion of slices to arrays requires go1.20 or later"
}
return false
}
case *Pointer:
if a, _ := under(a.Elem()).(*Array); a != nil {
if Identical(s.Elem(), a.Elem()) {
if check == nil || check.allowVersion(check.pkg, 1, 17) {
return true
}
// check != nil
if cause != nil {
*cause = "conversion of slices to array pointers requires go1.17 or later"
}
return false
}
}
}
}
// optimization: if we don't have type parameters, we're done
if Vp == nil && Tp == nil {
return false
}
errorf := func(format string, args ...any) {
if check != nil && cause != nil {
msg := check.sprintf(format, args...)
if *cause != "" {
msg += "\n\t" + *cause
}
*cause = msg
}
}
// generic cases with specific type terms
// (generic operands cannot be constants, so we can ignore x.val)
switch {
case Vp != nil && Tp != nil:
x := *x // don't clobber outer x
return Vp.is(func(V *term) bool {
if V == nil {
return false // no specific types
}
x.typ = V.typ
return Tp.is(func(T *term) bool {
if T == nil {
return false // no specific types
}
if !x.convertibleTo(check, T.typ, cause) {
errorf("cannot convert %s (in %s) to type %s (in %s)", V.typ, Vp, T.typ, Tp)
return false
}
return true
})
})
case Vp != nil:
x := *x // don't clobber outer x
return Vp.is(func(V *term) bool {
if V == nil {
return false // no specific types
}
x.typ = V.typ
if !x.convertibleTo(check, T, cause) {
errorf("cannot convert %s (in %s) to type %s", V.typ, Vp, T)
return false
}
return true
})
case Tp != nil:
return Tp.is(func(T *term) bool {
if T == nil {
return false // no specific types
}
if !x.convertibleTo(check, T.typ, cause) {
errorf("cannot convert %s to type %s (in %s)", x.typ, T.typ, Tp)
return false
}
return true
})
}
return false
}
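// Examples (illustrative, not part of the original source) of the
// non-constant conversion rules above, for
// type MyInt int, var i int, and var b []byte:
//
//	MyInt(i)      // ok: identical underlying types
//	float64(i)    // ok: both integer or floating-point types
//	string(b)     // ok: byte slice to string
//	[4]byte(b)    // ok with go1.20 or later (slice to array)
//	(*[4]byte)(b) // ok with go1.17 or later (slice to array pointer)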
func isUintptr(typ Type) bool {
t, _ := under(typ).(*Basic)
return t != nil && t.kind == Uintptr
}
func isUnsafePointer(typ Type) bool {
t, _ := under(typ).(*Basic)
return t != nil && t.kind == UnsafePointer
}
func isPointer(typ Type) bool {
_, ok := under(typ).(*Pointer)
return ok
}
func isBytesOrRunes(typ Type) bool {
if s, _ := under(typ).(*Slice); s != nil {
t, _ := under(s.elem).(*Basic)
return t != nil && (t.kind == Byte || t.kind == Rune)
}
return false
}
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package types
import (
"fmt"
"go/ast"
"go/constant"
"go/token"
. "internal/types/errors"
)
func (check *Checker) reportAltDecl(obj Object) {
if pos := obj.Pos(); pos.IsValid() {
// We use "other" rather than "previous" here because
// the first declaration seen may not be textually
// earlier in the source.
check.errorf(obj, DuplicateDecl, "\tother declaration of %s", obj.Name()) // secondary error, \t indented
}
}
func (check *Checker) declare(scope *Scope, id *ast.Ident, obj Object, pos token.Pos) {
// spec: "The blank identifier, represented by the underscore
// character _, may be used in a declaration like any other
// identifier but the declaration does not introduce a new
// binding."
if obj.Name() != "_" {
if alt := scope.Insert(obj); alt != nil {
check.errorf(obj, DuplicateDecl, "%s redeclared in this block", obj.Name())
check.reportAltDecl(alt)
return
}
obj.setScopePos(pos)
}
if id != nil {
check.recordDef(id, obj)
}
}
// pathString returns a string of the form a->b-> ... ->g for a path [a, b, ... g].
func pathString(path []Object) string {
var s string
for i, p := range path {
if i > 0 {
s += "->"
}
s += p.Name()
}
return s
}
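// Example (illustrative, not part of the original source): for objects
// named a, b, and c, pathString([]Object{a, b, c}) returns "a->b->c".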
// objDecl type-checks the declaration of obj in its respective (file) environment.
// For the meaning of def, see Checker.definedType, in typexpr.go.
func (check *Checker) objDecl(obj Object, def *Named) {
if check.conf._Trace && obj.Type() == nil {
if check.indent == 0 {
fmt.Println() // empty line between top-level objects for readability
}
check.trace(obj.Pos(), "-- checking %s (%s, objPath = %s)", obj, obj.color(), pathString(check.objPath))
check.indent++
defer func() {
check.indent--
check.trace(obj.Pos(), "=> %s (%s)", obj, obj.color())
}()
}
// Checking the declaration of obj means inferring its type
// (and possibly its value, for constants).
// An object's type (and thus the object) may be in one of
// three states which are expressed by colors:
//
// - an object whose type is not yet known is painted white (initial color)
// - an object whose type is in the process of being inferred is painted grey
// - an object whose type is fully inferred is painted black
//
// During type inference, an object's color changes from white to grey
// to black (pre-declared objects are painted black from the start).
// A black object (i.e., its type) can only depend on (refer to) other black
// ones. White and grey objects may depend on white and black objects.
// A dependency on a grey object indicates a cycle which may or may not be
// valid.
//
// When objects turn grey, they are pushed on the object path (a stack);
// they are popped again when they turn black. Thus, if a grey object (a
// cycle) is encountered, it is on the object path, and all the objects
// it depends on are the remaining objects on that path. Color encoding
// is such that the color value of a grey object indicates the index of
// that object in the object path.
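// For example (illustrative), if the object path is [a, b, c] and b is
// encountered again while grey, b's color value is grey+1, identifying
// objPath[1:] = [b, c] as the cycle.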
// During type-checking, white objects may be assigned a type without
// traversing through objDecl; e.g., when initializing constants and
// variables. Update the colors of those objects here (rather than
// everywhere where we set the type) to satisfy the color invariants.
if obj.color() == white && obj.Type() != nil {
obj.setColor(black)
return
}
switch obj.color() {
case white:
assert(obj.Type() == nil)
// All color values other than white and black are considered grey.
// Because black and white are < grey, all values >= grey are grey.
// Use those values to encode the object's index into the object path.
obj.setColor(grey + color(check.push(obj)))
defer func() {
check.pop().setColor(black)
}()
case black:
assert(obj.Type() != nil)
return
default:
// Color values other than white or black are considered grey.
fallthrough
case grey:
// We have a (possibly invalid) cycle.
// In the existing code, this is marked by a non-nil type
// for the object except for constants and variables whose
// type may be non-nil (known), or nil if it depends on the
// not-yet known initialization value.
// In the former case, set the type to Typ[Invalid] because
// we have an initialization cycle. The cycle error will be
// reported later, when determining initialization order.
// TODO(gri) Report cycle here and simplify initialization
// order code.
switch obj := obj.(type) {
case *Const:
if !check.validCycle(obj) || obj.typ == nil {
obj.typ = Typ[Invalid]
}
case *Var:
if !check.validCycle(obj) || obj.typ == nil {
obj.typ = Typ[Invalid]
}
case *TypeName:
if !check.validCycle(obj) {
// break cycle
// (without this, calling underlying()
// below may lead to an endless loop
// if we have a cycle for a defined
// (*Named) type)
obj.typ = Typ[Invalid]
}
case *Func:
if !check.validCycle(obj) {
// Don't set obj.typ to Typ[Invalid] here
// because plenty of code type-asserts that
// functions have a *Signature type. Grey
// functions have their type set to an empty
// signature which makes it impossible to
// initialize a variable with the function.
}
default:
unreachable()
}
assert(obj.Type() != nil)
return
}
d := check.objMap[obj]
if d == nil {
check.dump("%v: %s should have been declared", obj.Pos(), obj)
unreachable()
}
// save/restore current environment and set up object environment
defer func(env environment) {
check.environment = env
}(check.environment)
check.environment = environment{
scope: d.file,
}
// Const and var declarations must not have initialization
// cycles. We track them by remembering the current declaration
// in check.decl. Initialization expressions depending on other
// consts, vars, or functions, add dependencies to the current
// check.decl.
switch obj := obj.(type) {
case *Const:
check.decl = d // new package-level const decl
check.constDecl(obj, d.vtyp, d.init, d.inherited)
case *Var:
check.decl = d // new package-level var decl
check.varDecl(obj, d.lhs, d.vtyp, d.init)
case *TypeName:
// invalid recursive types are detected via path
check.typeDecl(obj, d.tdecl, def)
check.collectMethods(obj) // methods can only be added to top-level types
case *Func:
// functions may be recursive - no need to track dependencies
check.funcDecl(obj, d)
default:
unreachable()
}
}
// validCycle checks if the cycle starting with obj is valid and
// reports an error if it is not.
func (check *Checker) validCycle(obj Object) (valid bool) {
// The object map contains the package scope objects and the non-interface methods.
if debug {
info := check.objMap[obj]
inObjMap := info != nil && (info.fdecl == nil || info.fdecl.Recv == nil) // exclude methods
isPkgObj := obj.Parent() == check.pkg.scope
if isPkgObj != inObjMap {
check.dump("%v: inconsistent object map for %s (isPkgObj = %v, inObjMap = %v)", obj.Pos(), obj, isPkgObj, inObjMap)
unreachable()
}
}
// Count cycle objects.
assert(obj.color() >= grey)
start := obj.color() - grey // index of obj in objPath
cycle := check.objPath[start:]
tparCycle := false // if set, the cycle is through a type parameter list
nval := 0 // number of (constant or variable) values in the cycle; valid if !generic
ndef := 0 // number of type definitions in the cycle; valid if !generic
loop:
for _, obj := range cycle {
switch obj := obj.(type) {
case *Const, *Var:
nval++
case *TypeName:
// If we reach a generic type that is part of a cycle
// and we are in a type parameter list, we have a cycle
// through a type parameter list, which is invalid.
if check.inTParamList && isGeneric(obj.typ) {
tparCycle = true
break loop
}
// Determine if the type name is an alias or not. For
// package-level objects, use the object map which
// provides syntactic information (which doesn't rely
// on the order in which the objects are set up). For
// local objects, we can rely on the order, so use
// the object's predicate.
// TODO(gri) It would be less fragile to always access
// the syntactic information. We should consider storing
// this information explicitly in the object.
var alias bool
if d := check.objMap[obj]; d != nil {
alias = d.tdecl.Assign.IsValid() // package-level object
} else {
alias = obj.IsAlias() // function local object
}
if !alias {
ndef++
}
case *Func:
// ignored for now
default:
unreachable()
}
}
if check.conf._Trace {
check.trace(obj.Pos(), "## cycle detected: objPath = %s->%s (len = %d)", pathString(cycle), obj.Name(), len(cycle))
if tparCycle {
check.trace(obj.Pos(), "## cycle contains: generic type in a type parameter list")
} else {
check.trace(obj.Pos(), "## cycle contains: %d values, %d type definitions", nval, ndef)
}
defer func() {
if valid {
check.trace(obj.Pos(), "=> cycle is valid")
} else {
check.trace(obj.Pos(), "=> error: cycle is invalid")
}
}()
}
if !tparCycle {
// A cycle involving only constants and variables is invalid but we
// ignore them here because they are reported via the initialization
// cycle check.
if nval == len(cycle) {
return true
}
// A cycle involving only types (and possibly functions) must have at least
// one type definition to be permitted: If there is no type definition, we
// have a sequence of alias type names which will expand ad infinitum.
if nval == 0 && ndef > 0 {
return true
}
}
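// For example (illustrative), the cycle for
//
//	type T *T
//
// contains one type definition (ndef == 1) and was accepted above, while
// a cycle consisting only of aliases, such as
//
//	type A = B
//	type B = A
//
// has nval == 0 and ndef == 0 and is reported below.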
check.cycleError(cycle)
return false
}
// cycleError reports a declaration cycle starting with
// the object in cycle that is "first" in the source.
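// For example (illustrative), a declaration cycle A -> B -> A is reported
// with errors of the form:
//
//	invalid recursive type A
//		A refers to
//		B refers to
//		A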
func (check *Checker) cycleError(cycle []Object) {
// name returns the (possibly qualified) object name.
// This is needed because with generic types, cycles
// may refer to imported types. See go.dev/issue/50788.
// TODO(gri) This functionality is used elsewhere. Factor it out.
name := func(obj Object) string {
return packagePrefix(obj.Pkg(), check.qualifier) + obj.Name()
}
// TODO(gri) Should we start with the last (rather than the first) object in the cycle
// since that is the earliest point in the source where we start seeing the
// cycle? That would be more consistent with other error messages.
i := firstInSrc(cycle)
obj := cycle[i]
objName := name(obj)
// If obj is a type alias, mark it as valid (not broken) in order to avoid follow-on errors.
tname, _ := obj.(*TypeName)
if tname != nil && tname.IsAlias() {
check.validAlias(tname, Typ[Invalid])
}
// report a more concise error for self references
if len(cycle) == 1 {
if tname != nil {
check.errorf(obj, InvalidDeclCycle, "invalid recursive type: %s refers to itself", objName)
} else {
check.errorf(obj, InvalidDeclCycle, "invalid cycle in declaration: %s refers to itself", objName)
}
return
}
if tname != nil {
check.errorf(obj, InvalidDeclCycle, "invalid recursive type %s", objName)
} else {
check.errorf(obj, InvalidDeclCycle, "invalid cycle in declaration of %s", objName)
}
for range cycle {
check.errorf(obj, InvalidDeclCycle, "\t%s refers to", objName) // secondary error, \t indented
i++
if i >= len(cycle) {
i = 0
}
obj = cycle[i]
objName = name(obj)
}
check.errorf(obj, InvalidDeclCycle, "\t%s", objName)
}
// firstInSrc reports the index of the object with the "smallest"
// source position in path. path must not be empty.
func firstInSrc(path []Object) int {
fst, pos := 0, path[0].Pos()
for i, t := range path[1:] {
if cmpPos(t.Pos(), pos) < 0 {
fst, pos = i+1, t.Pos()
}
}
return fst
}
type (
decl interface {
node() ast.Node
}
importDecl struct{ spec *ast.ImportSpec }
constDecl struct {
spec *ast.ValueSpec
iota int
typ ast.Expr
init []ast.Expr
inherited bool
}
varDecl struct{ spec *ast.ValueSpec }
typeDecl struct{ spec *ast.TypeSpec }
funcDecl struct{ decl *ast.FuncDecl }
)
func (d importDecl) node() ast.Node { return d.spec }
func (d constDecl) node() ast.Node { return d.spec }
func (d varDecl) node() ast.Node { return d.spec }
func (d typeDecl) node() ast.Node { return d.spec }
func (d funcDecl) node() ast.Node { return d.decl }
func (check *Checker) walkDecls(decls []ast.Decl, f func(decl)) {
for _, d := range decls {
check.walkDecl(d, f)
}
}
func (check *Checker) walkDecl(d ast.Decl, f func(decl)) {
switch d := d.(type) {
case *ast.BadDecl:
// ignore
case *ast.GenDecl:
var last *ast.ValueSpec // last ValueSpec with type or init exprs seen
for iota, s := range d.Specs {
switch s := s.(type) {
case *ast.ImportSpec:
f(importDecl{s})
case *ast.ValueSpec:
switch d.Tok {
case token.CONST:
// determine which initialization expressions to use
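// For example (illustrative):
//
//	const (
//		x = iota // x == 0; this spec has its own init expression
//		y        // y == 1; inherits "iota" from the previous spec
//	)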
inherited := true
switch {
case s.Type != nil || len(s.Values) > 0:
last = s
inherited = false
case last == nil:
last = new(ast.ValueSpec) // make sure last exists
inherited = false
}
check.arityMatch(s, last)
f(constDecl{spec: s, iota: iota, typ: last.Type, init: last.Values, inherited: inherited})
case token.VAR:
check.arityMatch(s, nil)
f(varDecl{s})
default:
check.errorf(s, InvalidSyntaxTree, "invalid token %s", d.Tok)
}
case *ast.TypeSpec:
f(typeDecl{s})
default:
check.errorf(s, InvalidSyntaxTree, "unknown ast.Spec node %T", s)
}
}
case *ast.FuncDecl:
f(funcDecl{d})
default:
check.errorf(d, InvalidSyntaxTree, "unknown ast.Decl node %T", d)
}
}
func (check *Checker) constDecl(obj *Const, typ, init ast.Expr, inherited bool) {
assert(obj.typ == nil)
// use the correct value of iota
defer func(iota constant.Value, errpos positioner) {
check.iota = iota
check.errpos = errpos
}(check.iota, check.errpos)
check.iota = obj.val
check.errpos = nil
// provide valid constant value under all circumstances
obj.val = constant.MakeUnknown()
// determine type, if any
if typ != nil {
t := check.typ(typ)
if !isConstType(t) {
// don't report an error if the type is an invalid C (defined) type
// (go.dev/issue/22090)
if under(t) != Typ[Invalid] {
check.errorf(typ, InvalidConstType, "invalid constant type %s", t)
}
obj.typ = Typ[Invalid]
return
}
obj.typ = t
}
// check initialization
var x operand
if init != nil {
if inherited {
// The initialization expression is inherited from a previous
// constant declaration, and (error) positions refer to that
// expression and not the current constant declaration. Use
// the constant identifier position for any errors during
// init expression evaluation since that is all we have
// (see issues go.dev/issue/42991, go.dev/issue/42992).
check.errpos = atPos(obj.pos)
}
check.expr(&x, init)
}
check.initConst(obj, &x)
}
func (check *Checker) varDecl(obj *Var, lhs []*Var, typ, init ast.Expr) {
assert(obj.typ == nil)
// determine type, if any
if typ != nil {
obj.typ = check.varType(typ)
// We cannot spread the type to all lhs variables if there
// are more than one since that would mark them as checked
// (see Checker.objDecl) and the assignment of init exprs,
// if any, would not be checked.
//
// TODO(gri) If we have no init expr, we should distribute
// a given type; otherwise we need to re-evaluate the type
// expr for each lhs variable, leading to duplicate work.
}
// check initialization
if init == nil {
if typ == nil {
// error reported before by arityMatch
obj.typ = Typ[Invalid]
}
return
}
if lhs == nil || len(lhs) == 1 {
assert(lhs == nil || lhs[0] == obj)
var x operand
check.expr(&x, init)
check.initVar(obj, &x, "variable declaration")
return
}
if debug {
// obj must be one of lhs
found := false
for _, lhs := range lhs {
if obj == lhs {
found = true
break
}
}
if !found {
panic("inconsistent lhs")
}
}
// We have multiple variables on the lhs and one init expr.
// Make sure all variables have been given the same type if
// one was specified, otherwise they assume the type of the
// init expression values (was go.dev/issue/15755).
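// For example (illustrative):
//
//	var a, b, c int = threeInts() // hypothetical threeInts returning (int, int, int)
//
// All three variables are given the type int here; the init expression
// itself is checked only once, via initVars below.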
if typ != nil {
for _, lhs := range lhs {
lhs.typ = obj.typ
}
}
check.initVars(lhs, []ast.Expr{init}, nil)
}
// isImportedConstraint reports whether typ is an imported type constraint.
func (check *Checker) isImportedConstraint(typ Type) bool {
named, _ := typ.(*Named)
if named == nil || named.obj.pkg == check.pkg || named.obj.pkg == nil {
return false
}
u, _ := named.under().(*Interface)
return u != nil && !u.IsMethodSet()
}
func (check *Checker) typeDecl(obj *TypeName, tdecl *ast.TypeSpec, def *Named) {
assert(obj.typ == nil)
var rhs Type
check.later(func() {
if t, _ := obj.typ.(*Named); t != nil { // type may be invalid
check.validType(t)
}
// If typ is local, an error was already reported where typ is specified/defined.
if check.isImportedConstraint(rhs) && !check.allowVersion(check.pkg, 1, 18) {
check.errorf(tdecl.Type, UnsupportedFeature, "using type constraint %s requires go1.18 or later", rhs)
}
}).describef(obj, "validType(%s)", obj.Name())
alias := tdecl.Assign.IsValid()
if alias && tdecl.TypeParams.NumFields() != 0 {
// The parser will ensure this but we may still get an invalid AST.
// Complain and continue as regular type definition.
check.error(atPos(tdecl.Assign), BadDecl, "generic type cannot be alias")
alias = false
}
// alias declaration
if alias {
if !check.allowVersion(check.pkg, 1, 9) {
check.error(atPos(tdecl.Assign), UnsupportedFeature, "type aliases requires go1.9 or later")
}
check.brokenAlias(obj)
rhs = check.typ(tdecl.Type)
check.validAlias(obj, rhs)
return
}
// type definition or generic type declaration
named := check.newNamed(obj, nil, nil)
def.setUnderlying(named)
if tdecl.TypeParams != nil {
check.openScope(tdecl, "type parameters")
defer check.closeScope()
check.collectTypeParams(&named.tparams, tdecl.TypeParams)
}
// determine underlying type of named
rhs = check.definedType(tdecl.Type, named)
assert(rhs != nil)
named.fromRHS = rhs
// If the underlying type was not set while type-checking the right-hand
// side, it is invalid and an error should have been reported elsewhere.
if named.underlying == nil {
named.underlying = Typ[Invalid]
}
// Disallow a lone type parameter as the RHS of a type declaration (go.dev/issue/45639).
// We don't need this restriction anymore if we make the underlying type of a type
// parameter its constraint interface: if the RHS is a lone type parameter, we will
// use its underlying type (like we do for any RHS in a type declaration), and its
// underlying type is an interface and the type declaration is well defined.
if isTypeParam(rhs) {
check.error(tdecl.Type, MisplacedTypeParam, "cannot use a type parameter as RHS in type declaration")
named.underlying = Typ[Invalid]
}
}
func (check *Checker) collectTypeParams(dst **TypeParamList, list *ast.FieldList) {
var tparams []*TypeParam
// Declare type parameters up-front, with empty interface as type bound.
// The scope of type parameters starts at the beginning of the type parameter
// list (so we can have mutually recursive parameterized interfaces).
for _, f := range list.List {
tparams = check.declareTypeParams(tparams, f.Names)
}
// Set the type parameters before collecting the type constraints because
// the parameterized type may be used by the constraints (go.dev/issue/47887).
// Example: type T[P T[P]] interface{}
*dst = bindTParams(tparams)
// Signal to cycle detection that we are in a type parameter list.
// We can only be inside one type parameter list at any given time:
// function closures may appear inside a type parameter list but they
// cannot be generic, and their bodies are processed in delayed and
// sequential fashion. Note that with each new declaration, we save
// the existing environment and restore it when done; thus inTParamList is
// true exactly when we are in a specific type parameter list.
assert(!check.inTParamList)
check.inTParamList = true
defer func() {
check.inTParamList = false
}()
index := 0
for _, f := range list.List {
var bound Type
// NOTE: we may be able to assert that f.Type != nil here, but this is not
// an invariant of the AST, so we are cautious.
if f.Type != nil {
bound = check.bound(f.Type)
if isTypeParam(bound) {
// We may be able to allow this since it is now well-defined what
// the underlying type and thus type set of a type parameter is.
// But we may need some additional form of cycle detection within
// type parameter lists.
check.error(f.Type, MisplacedTypeParam, "cannot use a type parameter as constraint")
bound = Typ[Invalid]
}
} else {
bound = Typ[Invalid]
}
for i := range f.Names {
tparams[index+i].bound = bound
}
index += len(f.Names)
}
}
func (check *Checker) bound(x ast.Expr) Type {
// A type set literal of the form ~T or A|B may only appear as a constraint;
// embed it in an implicit interface so that only interface type-checking
// needs to take care of such type expressions.
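// For example (illustrative), the constraint expression ~int in
//
//	type T[P ~int] struct{}
//
// is wrapped into the implicit interface interface{~int} before checking.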
wrap := false
switch op := x.(type) {
case *ast.UnaryExpr:
wrap = op.Op == token.TILDE
case *ast.BinaryExpr:
wrap = op.Op == token.OR
}
if wrap {
x = &ast.InterfaceType{Methods: &ast.FieldList{List: []*ast.Field{{Type: x}}}}
t := check.typ(x)
// mark t as implicit interface if all went well
if t, _ := t.(*Interface); t != nil {
t.implicit = true
}
return t
}
return check.typ(x)
}
func (check *Checker) declareTypeParams(tparams []*TypeParam, names []*ast.Ident) []*TypeParam {
// Use Typ[Invalid] for the type constraint to ensure that a type
// is present even if the actual constraint has not been assigned
// yet.
// TODO(gri) Need to systematically review all uses of type parameter
// constraints to make sure we don't rely on them if they
// are not properly set yet.
for _, name := range names {
tname := NewTypeName(name.Pos(), check.pkg, name.Name, nil)
tpar := check.newTypeParam(tname, Typ[Invalid]) // assigns type to tpar as a side-effect
check.declare(check.scope, name, tname, check.scope.pos) // TODO(gri) check scope position
tparams = append(tparams, tpar)
}
if check.conf._Trace && len(names) > 0 {
check.trace(names[0].Pos(), "type params = %v", tparams[len(tparams)-len(names):])
}
return tparams
}
func (check *Checker) collectMethods(obj *TypeName) {
// get associated methods
// (Checker.collectObjects only collects methods with non-blank names;
// Checker.resolveBaseTypeName ensures that obj is not an alias name
// if it has attached methods.)
methods := check.methods[obj]
if methods == nil {
return
}
delete(check.methods, obj)
assert(!check.objMap[obj].tdecl.Assign.IsValid()) // don't use TypeName.IsAlias (requires fully set up object)
// use an objset to check for name conflicts
var mset objset
// spec: "If the base type is a struct type, the non-blank method
// and field names must be distinct."
base, _ := obj.typ.(*Named) // shouldn't fail but be conservative
if base != nil {
assert(base.TypeArgs().Len() == 0) // collectMethods should not be called on an instantiated type
// See go.dev/issue/52529: we must delay the expansion of underlying here, as
// base may not be fully set-up.
check.later(func() {
check.checkFieldUniqueness(base)
}).describef(obj, "verifying field uniqueness for %v", base)
// Checker.Files may be called multiple times; additional package files
// may add methods to already type-checked types. Add pre-existing methods
// so that we can detect redeclarations.
for i := 0; i < base.NumMethods(); i++ {
m := base.Method(i)
assert(m.name != "_")
assert(mset.insert(m) == nil)
}
}
// add valid methods
for _, m := range methods {
// spec: "For a base type, the non-blank names of methods bound
// to it must be unique."
assert(m.name != "_")
if alt := mset.insert(m); alt != nil {
if alt.Pos().IsValid() {
check.errorf(m, DuplicateMethod, "method %s.%s already declared at %s", obj.Name(), m.name, alt.Pos())
} else {
check.errorf(m, DuplicateMethod, "method %s.%s already declared", obj.Name(), m.name)
}
continue
}
if base != nil {
base.AddMethod(m)
}
}
}
func (check *Checker) checkFieldUniqueness(base *Named) {
if t, _ := base.under().(*Struct); t != nil {
var mset objset
for i := 0; i < base.NumMethods(); i++ {
m := base.Method(i)
assert(m.name != "_")
assert(mset.insert(m) == nil)
}
// Check that any non-blank field names of base are distinct from its
// method names.
for _, fld := range t.fields {
if fld.name != "_" {
if alt := mset.insert(fld); alt != nil {
// Struct fields should already be unique, so we should only
// encounter an alternate via collision with a method name.
_ = alt.(*Func)
// For historical consistency, we report the primary error on the
// method, and the alt decl on the field.
check.errorf(alt, DuplicateFieldAndMethod, "field and method with the same name %s", fld.name)
check.reportAltDecl(fld)
}
}
}
}
}
func (check *Checker) funcDecl(obj *Func, decl *declInfo) {
assert(obj.typ == nil)
// func declarations cannot use iota
assert(check.iota == nil)
sig := new(Signature)
obj.typ = sig // guard against cycles
// Avoid cycle error when referring to method while type-checking the signature.
// This avoids a nuisance in the best case (non-parameterized receiver type) and
// since the method is not a type, we get an error. If we have a parameterized
// receiver type, instantiating the receiver type leads to the instantiation of
// its methods, and we don't want a cycle error in that case.
// TODO(gri) review if this is correct and/or whether we still need this?
saved := obj.color_
obj.color_ = black
fdecl := decl.fdecl
check.funcType(sig, fdecl.Recv, fdecl.Type)
obj.color_ = saved
if fdecl.Type.TypeParams.NumFields() > 0 && fdecl.Body == nil {
check.softErrorf(fdecl.Name, BadDecl, "generic function is missing function body")
}
// function body must be type-checked after global declarations
// (functions implemented elsewhere have no body)
if !check.conf.IgnoreFuncBodies && fdecl.Body != nil {
check.later(func() {
check.funcBody(decl, obj.name, sig, fdecl.Body, nil)
}).describef(obj, "func %s", obj.name)
}
}
func (check *Checker) declStmt(d ast.Decl) {
pkg := check.pkg
check.walkDecl(d, func(d decl) {
switch d := d.(type) {
case constDecl:
top := len(check.delayed)
// declare all constants
lhs := make([]*Const, len(d.spec.Names))
for i, name := range d.spec.Names {
obj := NewConst(name.Pos(), pkg, name.Name, nil, constant.MakeInt64(int64(d.iota)))
lhs[i] = obj
var init ast.Expr
if i < len(d.init) {
init = d.init[i]
}
check.constDecl(obj, d.typ, init, d.inherited)
}
// process function literals in init expressions before scope changes
check.processDelayed(top)
// spec: "The scope of a constant or variable identifier declared
// inside a function begins at the end of the ConstSpec or VarSpec
// (ShortVarDecl for short variable declarations) and ends at the
// end of the innermost containing block."
scopePos := d.spec.End()
for i, name := range d.spec.Names {
check.declare(check.scope, name, lhs[i], scopePos)
}
case varDecl:
top := len(check.delayed)
lhs0 := make([]*Var, len(d.spec.Names))
for i, name := range d.spec.Names {
lhs0[i] = NewVar(name.Pos(), pkg, name.Name, nil)
}
// initialize all variables
for i, obj := range lhs0 {
var lhs []*Var
var init ast.Expr
switch len(d.spec.Values) {
case len(d.spec.Names):
// lhs and rhs match
init = d.spec.Values[i]
case 1:
// rhs is expected to be a multi-valued expression
lhs = lhs0
init = d.spec.Values[0]
default:
if i < len(d.spec.Values) {
init = d.spec.Values[i]
}
}
check.varDecl(obj, lhs, d.spec.Type, init)
if len(d.spec.Values) == 1 {
// If we have a single lhs variable we are done either way.
// If we have a single rhs expression, it must be a multi-
// valued expression, in which case handling the first lhs
// variable will cause all lhs variables to have a type
// assigned, and we are done as well.
if debug {
for _, obj := range lhs0 {
assert(obj.typ != nil)
}
}
break
}
}
// process function literals in init expressions before scope changes
check.processDelayed(top)
// declare all variables
// (only at this point are the variable scopes (parents) set)
scopePos := d.spec.End() // see constant declarations
for i, name := range d.spec.Names {
// see constant declarations
check.declare(check.scope, name, lhs0[i], scopePos)
}
case typeDecl:
obj := NewTypeName(d.spec.Name.Pos(), pkg, d.spec.Name.Name, nil)
// spec: "The scope of a type identifier declared inside a function
// begins at the identifier in the TypeSpec and ends at the end of
// the innermost containing block."
scopePos := d.spec.Name.Pos()
check.declare(check.scope, d.spec.Name, obj, scopePos)
// mark and unmark type before calling typeDecl; its type is still nil (see Checker.objDecl)
obj.setColor(grey + color(check.push(obj)))
check.typeDecl(obj, d.spec, nil)
check.pop().setColor(black)
default:
check.errorf(d.node(), InvalidSyntaxTree, "unknown ast.Decl node %T", d.node())
}
})
}
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements various error reporters.
package types
import (
"bytes"
"fmt"
"go/ast"
"go/token"
. "internal/types/errors"
"runtime"
"strconv"
"strings"
)
func assert(p bool) {
if !p {
msg := "assertion failed"
// Include information about the assertion location. Due to panic recovery,
// this location is otherwise buried in the middle of the panicking stack.
if _, file, line, ok := runtime.Caller(1); ok {
msg = fmt.Sprintf("%s:%d: %s", file, line, msg)
}
panic(msg)
}
}
func unreachable() {
panic("unreachable")
}
// An error_ represents a type-checking error.
// To report an error_, call Checker.report.
type error_ struct {
desc []errorDesc
code Code
soft bool // TODO(gri) eventually determine this from an error code
}
// An errorDesc describes part of a type-checking error.
type errorDesc struct {
posn positioner
format string
args []interface{}
}
func (err *error_) empty() bool {
return err.desc == nil
}
func (err *error_) pos() token.Pos {
if err.empty() {
return nopos
}
return err.desc[0].posn.Pos()
}
func (err *error_) msg(fset *token.FileSet, qf Qualifier) string {
if err.empty() {
return "no error"
}
var buf strings.Builder
for i := range err.desc {
p := &err.desc[i]
if i > 0 {
fmt.Fprint(&buf, "\n\t")
if p.posn.Pos().IsValid() {
fmt.Fprintf(&buf, "%s: ", fset.Position(p.posn.Pos()))
}
}
buf.WriteString(sprintf(fset, qf, false, p.format, p.args...))
}
return buf.String()
}
// String is for testing.
func (err *error_) String() string {
if err.empty() {
return "no error"
}
return fmt.Sprintf("%d: %s", err.pos(), err.msg(nil, nil))
}
// errorf adds formatted error information to err.
// It may be called multiple times to provide additional information.
func (err *error_) errorf(at token.Pos, format string, args ...interface{}) {
err.desc = append(err.desc, errorDesc{atPos(at), format, args})
}
func (check *Checker) qualifier(pkg *Package) string {
// Qualify the package unless it's the package being type-checked.
if pkg != check.pkg {
if check.pkgPathMap == nil {
check.pkgPathMap = make(map[string]map[string]bool)
check.seenPkgMap = make(map[*Package]bool)
check.markImports(check.pkg)
}
// If the same package name was used by multiple packages, display the full path.
if len(check.pkgPathMap[pkg.name]) > 1 {
return strconv.Quote(pkg.path)
}
return pkg.name
}
return ""
}
// markImports recursively walks pkg and its imports, to record unique import
// paths in pkgPathMap.
func (check *Checker) markImports(pkg *Package) {
if check.seenPkgMap[pkg] {
return
}
check.seenPkgMap[pkg] = true
forName, ok := check.pkgPathMap[pkg.name]
if !ok {
forName = make(map[string]bool)
check.pkgPathMap[pkg.name] = forName
}
forName[pkg.path] = true
for _, imp := range pkg.imports {
check.markImports(imp)
}
}
// check may be nil.
func (check *Checker) sprintf(format string, args ...any) string {
var fset *token.FileSet
var qf Qualifier
if check != nil {
fset = check.fset
qf = check.qualifier
}
return sprintf(fset, qf, false, format, args...)
}
func sprintf(fset *token.FileSet, qf Qualifier, tpSubscripts bool, format string, args ...any) string {
for i, arg := range args {
switch a := arg.(type) {
case nil:
arg = "<nil>"
case operand:
panic("got operand instead of *operand")
case *operand:
arg = operandString(a, qf)
case token.Pos:
if fset != nil {
arg = fset.Position(a).String()
}
case ast.Expr:
arg = ExprString(a)
case []ast.Expr:
var buf bytes.Buffer
buf.WriteByte('[')
writeExprList(&buf, a)
buf.WriteByte(']')
arg = buf.String()
case Object:
arg = ObjectString(a, qf)
case Type:
var buf bytes.Buffer
w := newTypeWriter(&buf, qf)
w.tpSubscripts = tpSubscripts
w.typ(a)
arg = buf.String()
case []Type:
var buf bytes.Buffer
w := newTypeWriter(&buf, qf)
w.tpSubscripts = tpSubscripts
buf.WriteByte('[')
for i, x := range a {
if i > 0 {
buf.WriteString(", ")
}
w.typ(x)
}
buf.WriteByte(']')
arg = buf.String()
case []*TypeParam:
var buf bytes.Buffer
w := newTypeWriter(&buf, qf)
w.tpSubscripts = tpSubscripts
buf.WriteByte('[')
for i, x := range a {
if i > 0 {
buf.WriteString(", ")
}
w.typ(x)
}
buf.WriteByte(']')
arg = buf.String()
}
args[i] = arg
}
return fmt.Sprintf(format, args...)
}
func (check *Checker) trace(pos token.Pos, format string, args ...any) {
fmt.Printf("%s:\t%s%s\n",
check.fset.Position(pos),
strings.Repeat(". ", check.indent),
sprintf(check.fset, check.qualifier, true, format, args...),
)
}
// dump is only needed for debugging
func (check *Checker) dump(format string, args ...any) {
fmt.Println(sprintf(check.fset, check.qualifier, true, format, args...))
}
// Report records the error pointed to by errp, setting check.firstErr if
// necessary.
func (check *Checker) report(errp *error_) {
if errp.empty() {
panic("empty error details")
}
msg := errp.msg(check.fset, check.qualifier)
switch errp.code {
case InvalidSyntaxTree:
msg = "invalid AST: " + msg
case 0:
panic("no error code provided")
}
span := spanOf(errp.desc[0].posn)
e := Error{
Fset: check.fset,
Pos: span.pos,
Msg: msg,
Soft: errp.soft,
go116code: errp.code,
go116start: span.start,
go116end: span.end,
}
// Cheap trick: Don't report errors with messages containing
// "invalid operand" or "invalid type" as those tend to be
// follow-on errors which don't add useful information. Only
// exclude them if these strings are not at the beginning,
// and only if we have at least one error already reported.
isInvalidErr := strings.Index(e.Msg, "invalid operand") > 0 || strings.Index(e.Msg, "invalid type") > 0
if check.firstErr != nil && isInvalidErr {
return
}
e.Msg = stripAnnotations(e.Msg)
if check.errpos != nil {
// If we have an internal error and the errpos override is set, use it to
// augment our error positioning.
// TODO(rFindley) we may also want to augment the error message and refer
// to the position (pos) in the original expression.
span := spanOf(check.errpos)
e.Pos = span.pos
e.go116start = span.start
e.go116end = span.end
}
err := e
if check.firstErr == nil {
check.firstErr = err
}
if check.conf._Trace {
pos := e.Pos
msg := e.Msg
check.trace(pos, "ERROR: %s", msg)
}
f := check.conf.Error
if f == nil {
panic(bailout{}) // report only first error
}
f(err)
}
const (
invalidArg = "invalid argument: "
invalidOp = "invalid operation: "
)
// newErrorf creates a new error_ for later reporting with check.report.
func newErrorf(at positioner, code Code, format string, args ...any) *error_ {
return &error_{
desc: []errorDesc{{at, format, args}},
code: code,
}
}
func (check *Checker) error(at positioner, code Code, msg string) {
check.report(newErrorf(at, code, msg))
}
func (check *Checker) errorf(at positioner, code Code, format string, args ...any) {
check.report(newErrorf(at, code, format, args...))
}
func (check *Checker) softErrorf(at positioner, code Code, format string, args ...any) {
err := newErrorf(at, code, format, args...)
err.soft = true
check.report(err)
}
func (check *Checker) versionErrorf(at positioner, goVersion string, format string, args ...interface{}) {
msg := check.sprintf(format, args...)
err := newErrorf(at, UnsupportedFeature, "%s requires %s or later", msg, goVersion)
check.report(err)
}
// The positioner interface is used to extract the position of type-checker
// errors.
type positioner interface {
Pos() token.Pos
}
// posSpan holds a position range along with a highlighted position within that
// range. This is used for positioning errors, with pos by convention being the
// first position in the source where the error is known to exist, and start
// and end defining the full span of syntax being considered when the error was
// detected. Invariant: start <= pos < end || start == pos == end.
type posSpan struct {
start, pos, end token.Pos
}
func (e posSpan) Pos() token.Pos {
return e.pos
}
// inNode creates a posSpan for the given node.
// Invariant: node.Pos() <= pos < node.End() (node.End() is the position of the
// first byte after node within the source).
func inNode(node ast.Node, pos token.Pos) posSpan {
start, end := node.Pos(), node.End()
if debug {
assert(start <= pos && pos < end)
}
return posSpan{start, pos, end}
}
// atPos wraps a token.Pos to implement the positioner interface.
type atPos token.Pos
func (s atPos) Pos() token.Pos {
return token.Pos(s)
}
// spanOf extracts an error span from the given positioner. By default this is
// the trivial span starting and ending at pos, but this span is expanded when
// the argument naturally corresponds to a span of source code.
func spanOf(at positioner) posSpan {
switch x := at.(type) {
case nil:
panic("nil positioner")
case posSpan:
return x
case ast.Node:
pos := x.Pos()
return posSpan{pos, pos, x.End()}
case *operand:
if x.expr != nil {
pos := x.Pos()
return posSpan{pos, pos, x.expr.End()}
}
return posSpan{nopos, nopos, nopos}
default:
pos := at.Pos()
return posSpan{pos, pos, pos}
}
}
// stripAnnotations removes internal (type) annotations from s.
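// For example (illustrative), a type parameter printed as P₁ is reduced
// to P.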
func stripAnnotations(s string) string {
var buf strings.Builder
for _, r := range s {
// strip subscript digits ('₀' through '₉')
if r < '₀' || '₀'+10 <= r { // '₀' == U+2080
buf.WriteRune(r)
}
}
if buf.Len() < len(s) {
return buf.String()
}
return s
}
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package types
import (
"fmt"
"go/ast"
"go/parser"
"go/token"
)
// Eval returns the type and, if constant, the value for the
// expression expr, evaluated at position pos of package pkg,
// which must have been derived from type-checking an AST with
// complete position information relative to the provided file
// set.
//
// The meaning of the parameters fset, pkg, and pos is the
// same as in CheckExpr. An error is returned if expr cannot
// be parsed successfully, or the resulting expr AST cannot be
// type-checked.
func Eval(fset *token.FileSet, pkg *Package, pos token.Pos, expr string) (_ TypeAndValue, err error) {
// parse expressions
node, err := parser.ParseExprFrom(fset, "eval", expr, 0)
if err != nil {
return TypeAndValue{}, err
}
info := &Info{
Types: make(map[ast.Expr]TypeAndValue),
}
err = CheckExpr(fset, pkg, pos, node, info)
return info.Types[node], err
}
// CheckExpr type checks the expression expr as if it had appeared at position
// pos of package pkg. Type information about the expression is recorded in
// info. The expression may be an identifier denoting an uninstantiated generic
// function or type.
//
// If pkg == nil, the Universe scope is used and the provided
// position pos is ignored. If pkg != nil, and pos is invalid,
// the package scope is used. Otherwise, pos must belong to the
// package.
//
// An error is returned if pos is not within the package or
// if the node cannot be type-checked.
//
// Note: Eval and CheckExpr should not be used instead of running Check
// to compute types and values, but in addition to Check, as these
// functions ignore the context in which an expression is used (e.g., an
// assignment). Thus, top-level untyped constants will return an
// untyped type rather than the respective context-specific type.
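// For example (illustrative), evaluating "1 + 1" with Eval yields an
// untyped int constant 2, even if the source context at pos would have
// implicitly converted the expression to another type.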
func CheckExpr(fset *token.FileSet, pkg *Package, pos token.Pos, expr ast.Expr, info *Info) (err error) {
// determine scope
var scope *Scope
if pkg == nil {
scope = Universe
pos = nopos
} else if !pos.IsValid() {
scope = pkg.scope
} else {
// The package scope extent (position information) may be
// incorrect (files spread across a wide range of fset
// positions) - ignore it and just consider its children
// (file scopes).
for _, fscope := range pkg.scope.children {
if scope = fscope.Innermost(pos); scope != nil {
break
}
}
if scope == nil || debug {
s := scope
for s != nil && s != pkg.scope {
s = s.parent
}
// s == nil || s == pkg.scope
if s == nil {
return fmt.Errorf("no position %s found in package %s", fset.Position(pos), pkg.name)
}
}
}
// initialize checker
check := NewChecker(nil, fset, pkg, info)
check.scope = scope
check.pos = pos
defer check.handleBailout(&err)
// evaluate node
var x operand
check.rawExpr(&x, expr, nil, true) // allow generic expressions
check.processDelayed(0) // incl. all functions
check.recordUntyped()
return nil
}
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements typechecking of expressions.
package types
import (
"fmt"
"go/ast"
"go/constant"
"go/internal/typeparams"
"go/token"
. "internal/types/errors"
"math"
)
/*
Basic algorithm:
Expressions are checked recursively, top down. Expression checker functions
are generally of the form:
func f(x *operand, e *ast.Expr, ...)
where e is the expression to be checked, and x is the result of the check.
The check performed by f may fail in which case x.mode == invalid, and
related error messages will have been issued by f.
If a hint argument is present, it is the composite literal element type
of an outer composite literal; it is used to type-check composite literal
elements that have no explicit type specification in the source
(e.g.: []T{{...}, {...}}, the hint is the type T in this case).
All expressions are checked via rawExpr, which dispatches according
to expression kind. Upon returning, rawExpr records the types and
constant values for all expressions that have an untyped type (those types
may change on the way up in the expression tree). Usually these are constants,
but the results of comparisons or non-constant shifts of untyped constants
may also be untyped, though not constant.
Untyped expressions may eventually become fully typed (i.e., not untyped),
typically when the value is assigned to a variable, or is used otherwise.
The updateExprType method is used to record this final type and update
the recorded types: the type-checked expression tree is again traversed down,
and the new type is propagated as needed. Untyped constant expression values
that become fully typed must now be representable by the full type (constant
sub-expression trees are left alone except for their roots). This mechanism
ensures that a client sees the actual (run-time) type an untyped value would
have. It also permits type-checking of lhs shift operands "as if the shift
were not present": when updateExprType visits an untyped lhs shift operand
and assigns it its final type, that type must be an integer type, and a
constant lhs must be representable as an integer.
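For example (illustrative):

	var s uint = 1
	var f = 1.0 << s // invalid: shift of non-integer type float64

Here the untyped lhs operand 1.0 receives the type it would have if the
shift were not present (its default type float64); since float64 is not an
integer type, the shift is rejected.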
When an expression gets its final type, either on the way out from rawExpr,
on the way down in updateExprType, or at the end of the type checker run,
the type (and constant value, if any) is recorded via Info.Types, if present.
*/
type opPredicates map[token.Token]func(Type) bool
var unaryOpPredicates opPredicates
func init() {
// Setting unaryOpPredicates in init avoids declaration cycles.
unaryOpPredicates = opPredicates{
token.ADD: allNumeric,
token.SUB: allNumeric,
token.XOR: allInteger,
token.NOT: allBoolean,
}
}
func (check *Checker) op(m opPredicates, x *operand, op token.Token) bool {
if pred := m[op]; pred != nil {
if !pred(x.typ) {
check.errorf(x, UndefinedOp, invalidOp+"operator %s not defined on %s", op, x)
return false
}
} else {
check.errorf(x, InvalidSyntaxTree, "unknown operator %s", op)
return false
}
return true
}
// overflow checks that the constant x is representable by its type.
// For untyped constants, it checks that the value doesn't become
// arbitrarily large.
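// For example (illustrative), the typed constant expression uint8(255) + 1
// is not representable as a uint8 and is reported here, while an untyped
// constant such as 1 << 600 exceeds the 512-bit precision limit enforced below.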
func (check *Checker) overflow(x *operand, opPos token.Pos) {
assert(x.mode == constant_)
if x.val.Kind() == constant.Unknown {
// TODO(gri) We should report exactly what went wrong. At the
// moment we don't have the (go/constant) API for that.
// See also TODO in go/constant/value.go.
check.error(atPos(opPos), InvalidConstVal, "constant result is not representable")
return
}
// Typed constants must be representable in
// their type after each constant operation.
// x.typ cannot be a type parameter (type
// parameters cannot be constant types).
if isTyped(x.typ) {
check.representable(x, under(x.typ).(*Basic))
return
}
// Untyped integer values must not grow arbitrarily.
const prec = 512 // 512 is the constant precision
if x.val.Kind() == constant.Int && constant.BitLen(x.val) > prec {
op := opName(x.expr)
if op != "" {
op += " "
}
check.errorf(atPos(opPos), InvalidConstVal, "constant %soverflow", op)
x.val = constant.MakeUnknown()
}
}
// opName returns the name of the operation if x is an operation
// that might overflow; otherwise it returns the empty string.
func opName(e ast.Expr) string {
switch e := e.(type) {
case *ast.BinaryExpr:
if int(e.Op) < len(op2str2) {
return op2str2[e.Op]
}
case *ast.UnaryExpr:
if int(e.Op) < len(op2str1) {
return op2str1[e.Op]
}
}
return ""
}
var op2str1 = [...]string{
token.XOR: "bitwise complement",
}
// This is only used for operations that may cause overflow.
var op2str2 = [...]string{
token.ADD: "addition",
token.SUB: "subtraction",
token.XOR: "bitwise XOR",
token.MUL: "multiplication",
token.SHL: "shift",
}
// If typ is a type parameter, underIs returns the result of typ.underIs(f).
// Otherwise, underIs returns the result of f(under(typ)).
func underIs(typ Type, f func(Type) bool) bool {
if tpar, _ := typ.(*TypeParam); tpar != nil {
return tpar.underIs(f)
}
return f(under(typ))
}
// The unary expression e may be nil. It's passed in for better error messages only.
func (check *Checker) unary(x *operand, e *ast.UnaryExpr) {
check.expr(x, e.X)
if x.mode == invalid {
return
}
op := e.Op
switch op {
case token.AND:
// spec: "As an exception to the addressability
// requirement x may also be a composite literal."
if _, ok := unparen(e.X).(*ast.CompositeLit); !ok && x.mode != variable {
check.errorf(x, UnaddressableOperand, invalidOp+"cannot take address of %s", x)
x.mode = invalid
return
}
x.mode = value
x.typ = &Pointer{base: x.typ}
return
case token.ARROW:
u := coreType(x.typ)
if u == nil {
check.errorf(x, InvalidReceive, invalidOp+"cannot receive from %s (no core type)", x)
x.mode = invalid
return
}
ch, _ := u.(*Chan)
if ch == nil {
check.errorf(x, InvalidReceive, invalidOp+"cannot receive from non-channel %s", x)
x.mode = invalid
return
}
if ch.dir == SendOnly {
check.errorf(x, InvalidReceive, invalidOp+"cannot receive from send-only channel %s", x)
x.mode = invalid
return
}
x.mode = commaok
x.typ = ch.elem
check.hasCallOrRecv = true
return
case token.TILDE:
// Provide a better error position and message than what check.op below would do.
if !allInteger(x.typ) {
check.error(e, UndefinedOp, "cannot use ~ outside of interface or type constraint")
x.mode = invalid
return
}
check.error(e, UndefinedOp, "cannot use ~ outside of interface or type constraint (use ^ for bitwise complement)")
op = token.XOR
}
if !check.op(unaryOpPredicates, x, op) {
x.mode = invalid
return
}
if x.mode == constant_ {
if x.val.Kind() == constant.Unknown {
// nothing to do (and don't cause an error below in the overflow check)
return
}
var prec uint
if isUnsigned(x.typ) {
prec = uint(check.conf.sizeof(x.typ) * 8)
}
x.val = constant.UnaryOp(op, x.val, prec)
x.expr = e
check.overflow(x, x.Pos())
return
}
x.mode = value
// x.typ remains unchanged
}
func isShift(op token.Token) bool {
return op == token.SHL || op == token.SHR
}
func isComparison(op token.Token) bool {
// Note: the token values are not ordered in a way that would allow a simple range check here.
switch op {
case token.EQL, token.NEQ, token.LSS, token.LEQ, token.GTR, token.GEQ:
return true
}
return false
}
func fitsFloat32(x constant.Value) bool {
f32, _ := constant.Float32Val(x)
f := float64(f32)
return !math.IsInf(f, 0)
}
func roundFloat32(x constant.Value) constant.Value {
f32, _ := constant.Float32Val(x)
f := float64(f32)
if !math.IsInf(f, 0) {
return constant.MakeFloat64(f)
}
return nil
}
func fitsFloat64(x constant.Value) bool {
f, _ := constant.Float64Val(x)
return !math.IsInf(f, 0)
}
func roundFloat64(x constant.Value) constant.Value {
f, _ := constant.Float64Val(x)
if !math.IsInf(f, 0) {
return constant.MakeFloat64(f)
}
return nil
}
// representableConst reports whether x can be represented as a
// value of the given basic type and for the configuration
// provided (only needed for int/uint sizes).
//
// If rounded != nil, *rounded is set to the rounded value of x for
// representable floating-point and complex values, and to an Int
// value for integer values; it is left alone otherwise.
// It is ok to provide the address of the first argument for rounded.
//
// The check parameter may be nil if representableConst is invoked
// (indirectly) through an exported API call (AssignableTo, ConvertibleTo)
// because we don't need the Checker's config for those calls.
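// For example (illustrative), the integer value 300 is representable as a
// value of type uint16 but not of type uint8; with rounded != nil, a float
// constant is rounded to the target precision (e.g., float32 for Float32).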
func representableConst(x constant.Value, check *Checker, typ *Basic, rounded *constant.Value) bool {
if x.Kind() == constant.Unknown {
return true // avoid follow-up errors
}
var conf *Config
if check != nil {
conf = check.conf
}
switch {
case isInteger(typ):
x := constant.ToInt(x)
if x.Kind() != constant.Int {
return false
}
if rounded != nil {
*rounded = x
}
if x, ok := constant.Int64Val(x); ok {
switch typ.kind {
case Int:
var s = uint(conf.sizeof(typ)) * 8
return int64(-1)<<(s-1) <= x && x <= int64(1)<<(s-1)-1
case Int8:
const s = 8
return -1<<(s-1) <= x && x <= 1<<(s-1)-1
case Int16:
const s = 16
return -1<<(s-1) <= x && x <= 1<<(s-1)-1
case Int32:
const s = 32
return -1<<(s-1) <= x && x <= 1<<(s-1)-1
case Int64, UntypedInt:
return true
case Uint, Uintptr:
if s := uint(conf.sizeof(typ)) * 8; s < 64 {
return 0 <= x && x <= int64(1)<<s-1
}
return 0 <= x
case Uint8:
const s = 8
return 0 <= x && x <= 1<<s-1
case Uint16:
const s = 16
return 0 <= x && x <= 1<<s-1
case Uint32:
const s = 32
return 0 <= x && x <= 1<<s-1
case Uint64:
return 0 <= x
default:
unreachable()
}
}
// x does not fit into int64
switch n := constant.BitLen(x); typ.kind {
case Uint, Uintptr:
var s = uint(conf.sizeof(typ)) * 8
return constant.Sign(x) >= 0 && n <= int(s)
case Uint64:
return constant.Sign(x) >= 0 && n <= 64
case UntypedInt:
return true
}
case isFloat(typ):
x := constant.ToFloat(x)
if x.Kind() != constant.Float {
return false
}
switch typ.kind {
case Float32:
if rounded == nil {
return fitsFloat32(x)
}
r := roundFloat32(x)
if r != nil {
*rounded = r
return true
}
case Float64:
if rounded == nil {
return fitsFloat64(x)
}
r := roundFloat64(x)
if r != nil {
*rounded = r
return true
}
case UntypedFloat:
return true
default:
unreachable()
}
case isComplex(typ):
x := constant.ToComplex(x)
if x.Kind() != constant.Complex {
return false
}
switch typ.kind {
case Complex64:
if rounded == nil {
return fitsFloat32(constant.Real(x)) && fitsFloat32(constant.Imag(x))
}
re := roundFloat32(constant.Real(x))
im := roundFloat32(constant.Imag(x))
if re != nil && im != nil {
*rounded = constant.BinaryOp(re, token.ADD, constant.MakeImag(im))
return true
}
case Complex128:
if rounded == nil {
return fitsFloat64(constant.Real(x)) && fitsFloat64(constant.Imag(x))
}
re := roundFloat64(constant.Real(x))
im := roundFloat64(constant.Imag(x))
if re != nil && im != nil {
*rounded = constant.BinaryOp(re, token.ADD, constant.MakeImag(im))
return true
}
case UntypedComplex:
return true
default:
unreachable()
}
case isString(typ):
return x.Kind() == constant.String
case isBoolean(typ):
return x.Kind() == constant.Bool
}
return false
}
// representable checks that a constant operand is representable in the given
// basic type.
func (check *Checker) representable(x *operand, typ *Basic) {
v, code := check.representation(x, typ)
if code != 0 {
check.invalidConversion(code, x, typ)
x.mode = invalid
return
}
assert(v != nil)
x.val = v
}
// representation returns the representation of the constant operand x as the
// basic type typ.
//
// If no such representation is possible, it returns a non-zero error code.
func (check *Checker) representation(x *operand, typ *Basic) (constant.Value, Code) {
assert(x.mode == constant_)
v := x.val
if !representableConst(x.val, check, typ, &v) {
if isNumeric(x.typ) && isNumeric(typ) {
// numeric conversion : error msg
//
// integer -> integer : overflows
// integer -> float : overflows (actually not possible)
// float -> integer : truncated
// float -> float : overflows
//
if !isInteger(x.typ) && isInteger(typ) {
return nil, TruncatedFloat
} else {
return nil, NumericOverflow
}
}
return nil, InvalidConstVal
}
return v, 0
}
func (check *Checker) invalidConversion(code Code, x *operand, target Type) {
msg := "cannot convert %s to type %s"
switch code {
case TruncatedFloat:
msg = "%s truncated to %s"
case NumericOverflow:
msg = "%s overflows %s"
}
check.errorf(x, code, msg, x, target)
}
// updateExprType updates the type of x to typ and invokes itself
// recursively for the operands of x, depending on expression kind.
// If typ is still an untyped and not the final type, updateExprType
// only updates the recorded untyped type for x and possibly its
// operands. Otherwise (i.e., typ is not an untyped type anymore,
// or it is the final type for x), the type and value are recorded.
// Also, if x is a constant, it must be representable as a value of typ,
// and if x is the (formerly untyped) lhs operand of a non-constant
// shift, it must be an integer value.
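// For example (illustrative), in
//
//	var x int32 = 1 + 2
//
// the untyped constant expression 1 + 2 is eventually recorded with the
// final type int32, and its value 3 must be representable as an int32.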
func (check *Checker) updateExprType(x ast.Expr, typ Type, final bool) {
check.updateExprType0(nil, x, typ, final)
}
func (check *Checker) updateExprType0(parent, x ast.Expr, typ Type, final bool) {
old, found := check.untyped[x]
if !found {
return // nothing to do
}
// update operands of x if necessary
switch x := x.(type) {
case *ast.BadExpr,
*ast.FuncLit,
*ast.CompositeLit,
*ast.IndexExpr,
*ast.SliceExpr,
*ast.TypeAssertExpr,
*ast.StarExpr,
*ast.KeyValueExpr,
*ast.ArrayType,
*ast.StructType,
*ast.FuncType,
*ast.InterfaceType,
*ast.MapType,
*ast.ChanType:
// These expressions are never untyped - nothing to do.
// The respective sub-expressions got their final types
// upon assignment or use.
if debug {
check.dump("%v: found old type(%s): %s (new: %s)", x.Pos(), x, old.typ, typ)
unreachable()
}
return
case *ast.CallExpr:
// Resulting in an untyped constant (e.g., built-in complex).
// The respective calls take care of calling updateExprType
// for the arguments if necessary.
case *ast.Ident, *ast.BasicLit, *ast.SelectorExpr:
// An identifier denoting a constant, a constant literal,
// or a qualified identifier (imported untyped constant).
// No operands to take care of.
case *ast.ParenExpr:
check.updateExprType0(x, x.X, typ, final)
case *ast.UnaryExpr:
// If x is a constant, the operands were constants.
// The operands don't need to be updated since they
// never get "materialized" into a typed value. If
// left in the untyped map, they will be processed
// at the end of the type check.
if old.val != nil {
break
}
check.updateExprType0(x, x.X, typ, final)
case *ast.BinaryExpr:
if old.val != nil {
break // see comment for unary expressions
}
if isComparison(x.Op) {
// The result type is independent of operand types
// and the operand types must have final types.
} else if isShift(x.Op) {
// The result type depends only on lhs operand.
// The rhs type was updated when checking the shift.
check.updateExprType0(x, x.X, typ, final)
} else {
// The operand types match the result type.
check.updateExprType0(x, x.X, typ, final)
check.updateExprType0(x, x.Y, typ, final)
}
default:
unreachable()
}
// If the new type is not final and still untyped, just
// update the recorded type.
if !final && isUntyped(typ) {
old.typ = under(typ).(*Basic)
check.untyped[x] = old
return
}
// Otherwise we have the final (typed or untyped type).
// Remove it from the map of yet untyped expressions.
delete(check.untyped, x)
if old.isLhs {
// If x is the lhs of a shift, its final type must be integer.
// We already know from the shift check that it is representable
// as an integer if it is a constant.
if !allInteger(typ) {
check.errorf(x, InvalidShiftOperand, invalidOp+"shifted operand %s (type %s) must be integer", x, typ)
return
}
// Even if we have an integer, if the value is a constant we
// still must check that it is representable as the specific
// int type requested (was go.dev/issue/22969). Fall through here.
}
if old.val != nil {
// If x is a constant, it must be representable as a value of typ.
c := operand{old.mode, x, old.typ, old.val, 0}
check.convertUntyped(&c, typ)
if c.mode == invalid {
return
}
}
// Everything's fine, record final type and value for x.
check.recordTypeAndValue(x, old.mode, typ, old.val)
}
// updateExprVal updates the value of x to val.
func (check *Checker) updateExprVal(x ast.Expr, val constant.Value) {
if info, ok := check.untyped[x]; ok {
info.val = val
check.untyped[x] = info
}
}
// convertUntyped attempts to set the type of an untyped value to the target type.
func (check *Checker) convertUntyped(x *operand, target Type) {
newType, val, code := check.implicitTypeAndValue(x, target)
if code != 0 {
t := target
if !isTypeParam(target) {
t = safeUnderlying(target)
}
check.invalidConversion(code, x, t)
x.mode = invalid
return
}
if val != nil {
x.val = val
check.updateExprVal(x.expr, val)
}
if newType != x.typ {
x.typ = newType
check.updateExprType(x.expr, newType, false)
}
}
// implicitTypeAndValue returns the implicit type of x when used in a context
// where the target type is expected. If no such implicit conversion is
// possible, it returns a nil Type and non-zero error code.
//
// If x is a constant operand, the returned constant.Value will be the
// representation of x in this context.
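// For example (illustrative), for the untyped constant operand 1 and target
// type float64, the result is (float64, the rounded value 1.0, 0); for an
// untyped nil operand and a map target type, the nil stays untyped and
// Typ[UntypedNil] is returned.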
func (check *Checker) implicitTypeAndValue(x *operand, target Type) (Type, constant.Value, Code) {
if x.mode == invalid || isTyped(x.typ) || target == Typ[Invalid] {
return x.typ, nil, 0
}
if isUntyped(target) {
// both x and target are untyped
xkind := x.typ.(*Basic).kind
tkind := target.(*Basic).kind
if isNumeric(x.typ) && isNumeric(target) {
if xkind < tkind {
return target, nil, 0
}
} else if xkind != tkind {
return nil, nil, InvalidUntypedConversion
}
return x.typ, nil, 0
}
switch u := under(target).(type) {
case *Basic:
if x.mode == constant_ {
v, code := check.representation(x, u)
if code != 0 {
return nil, nil, code
}
return target, v, code
}
// Non-constant untyped values may appear as the
// result of comparisons (untyped bool), intermediate
// (delayed-checked) rhs operands of shifts, and as
// the value nil.
switch x.typ.(*Basic).kind {
case UntypedBool:
if !isBoolean(target) {
return nil, nil, InvalidUntypedConversion
}
case UntypedInt, UntypedRune, UntypedFloat, UntypedComplex:
if !isNumeric(target) {
return nil, nil, InvalidUntypedConversion
}
case UntypedString:
// Non-constant untyped string values are not permitted by the spec and
// should not occur during normal typechecking passes, but this path is
// reachable via the AssignableTo API.
if !isString(target) {
return nil, nil, InvalidUntypedConversion
}
case UntypedNil:
// unsafe.Pointer is a basic type that includes nil.
if !hasNil(target) {
return nil, nil, InvalidUntypedConversion
}
// Preserve the type of nil as UntypedNil: see go.dev/issue/13061.
return Typ[UntypedNil], nil, 0
default:
return nil, nil, InvalidUntypedConversion
}
case *Interface:
if isTypeParam(target) {
if !u.typeSet().underIs(func(u Type) bool {
if u == nil {
return false
}
t, _, _ := check.implicitTypeAndValue(x, u)
return t != nil
}) {
return nil, nil, InvalidUntypedConversion
}
// keep nil untyped (was bug go.dev/issue/39755)
if x.isNil() {
return Typ[UntypedNil], nil, 0
}
break
}
// Values must have concrete dynamic types. If the value is nil,
// keep it untyped (this is important for tools such as go vet which
// need the dynamic type for argument checking of, say, print
// functions)
if x.isNil() {
return Typ[UntypedNil], nil, 0
}
// cannot assign untyped values to non-empty interfaces
if !u.Empty() {
return nil, nil, InvalidUntypedConversion
}
return Default(x.typ), nil, 0
case *Pointer, *Signature, *Slice, *Map, *Chan:
if !x.isNil() {
return nil, nil, InvalidUntypedConversion
}
// Keep nil untyped - see comment for interfaces, above.
return Typ[UntypedNil], nil, 0
default:
return nil, nil, InvalidUntypedConversion
}
return target, nil, 0
}
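// To illustrate the untyped kind ordering above: mixing untyped constants
// promotes toward the "larger" kind. A sketch using the exported API from
// a client package:
//
// fset := token.NewFileSet()
// tv, _ := types.Eval(fset, nil, token.NoPos, "'a' + 1.5")
// fmt.Println(tv.Type) // untyped float (rune kind < float kind)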
// If switchCase is true, the operator op is ignored.
func (check *Checker) comparison(x, y *operand, op token.Token, switchCase bool) {
// Avoid spurious errors if any of the operands has an invalid type (go.dev/issue/54405).
if x.typ == Typ[Invalid] || y.typ == Typ[Invalid] {
x.mode = invalid
return
}
if switchCase {
op = token.EQL
}
errOp := x // operand for which error is reported, if any
cause := "" // specific error cause, if any
// spec: "In any comparison, the first operand must be assignable
// to the type of the second operand, or vice versa."
code := MismatchedTypes
ok, _ := x.assignableTo(check, y.typ, nil)
if !ok {
ok, _ = y.assignableTo(check, x.typ, nil)
}
if !ok {
// Report the error on the 2nd operand since we only
// know after seeing the 2nd operand whether we have
// a type mismatch.
errOp = y
cause = check.sprintf("mismatched types %s and %s", x.typ, y.typ)
goto Error
}
// check if comparison is defined for operands
code = UndefinedOp
switch op {
case token.EQL, token.NEQ:
// spec: "The equality operators == and != apply to operands that are comparable."
switch {
case x.isNil() || y.isNil():
// Comparison against nil requires that the other operand type has nil.
typ := x.typ
if x.isNil() {
typ = y.typ
}
if !hasNil(typ) {
// This case should only be possible for "nil == nil".
// Report the error on the 2nd operand since we only
// know after seeing the 2nd operand whether we have
// an invalid comparison.
errOp = y
goto Error
}
case !Comparable(x.typ):
errOp = x
cause = check.incomparableCause(x.typ)
goto Error
case !Comparable(y.typ):
errOp = y
cause = check.incomparableCause(y.typ)
goto Error
}
case token.LSS, token.LEQ, token.GTR, token.GEQ:
// spec: "The ordering operators <, <=, >, and >= apply to operands that are ordered."
switch {
case !allOrdered(x.typ):
errOp = x
goto Error
case !allOrdered(y.typ):
errOp = y
goto Error
}
default:
unreachable()
}
// comparison is ok
if x.mode == constant_ && y.mode == constant_ {
x.val = constant.MakeBool(constant.Compare(x.val, op, y.val))
// The operands are never materialized; no need to update
// their types.
} else {
x.mode = value
// The operands now have their final types, which at run-
// time will be materialized. Update the expression trees.
// If the current types are untyped, the materialized type
// is the respective default type.
check.updateExprType(x.expr, Default(x.typ), true)
check.updateExprType(y.expr, Default(y.typ), true)
}
// spec: "Comparison operators compare two operands and yield
// an untyped boolean value."
x.typ = Typ[UntypedBool]
return
Error:
// We have an offending operand errOp and possibly an error cause.
if cause == "" {
if isTypeParam(x.typ) || isTypeParam(y.typ) {
// TODO(gri) should report the specific type causing the problem, if any
if !isTypeParam(x.typ) {
errOp = y
}
cause = check.sprintf("type parameter %s is not comparable with %s", errOp.typ, op)
} else {
cause = check.sprintf("operator %s not defined on %s", op, check.kindString(errOp.typ)) // catch-all
}
}
if switchCase {
check.errorf(x, code, "invalid case %s in switch on %s (%s)", x.expr, y.expr, cause) // error position always at 1st operand
} else {
check.errorf(errOp, code, invalidOp+"%s %s %s (%s)", x.expr, op, y.expr, cause)
}
x.mode = invalid
}
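// Because the result type is untyped bool, a comparison is assignable to
// any boolean type without an explicit conversion. For instance:
//
// type myBool bool
// var b myBool = 3 < 4 // ok: the untyped bool result converts implicitly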
// incomparableCause returns a more specific cause why typ is not comparable.
// If there is no more specific cause, the result is "".
func (check *Checker) incomparableCause(typ Type) string {
switch under(typ).(type) {
case *Slice, *Signature, *Map:
return check.kindString(typ) + " can only be compared to nil"
}
// see if we can extract a more specific error
var cause string
comparable(typ, true, nil, func(format string, args ...interface{}) {
cause = check.sprintf(format, args...)
})
return cause
}
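// For instance, given
//
// var s []int
// _ = s == s // ERROR: slice can only be compared to nil
//
// incomparableCause reports the slice-specific cause, while s == nil is
// accepted via the hasNil case in comparison.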
// kindString returns the type kind as a string.
func (check *Checker) kindString(typ Type) string {
switch under(typ).(type) {
case *Array:
return "array"
case *Slice:
return "slice"
case *Struct:
return "struct"
case *Pointer:
return "pointer"
case *Signature:
return "func"
case *Interface:
if isTypeParam(typ) {
return check.sprintf("type parameter %s", typ)
}
return "interface"
case *Map:
return "map"
case *Chan:
return "chan"
default:
return check.sprintf("%s", typ) // catch-all
}
}
// If e != nil, it must be the shift expression; it may be nil for non-constant shifts.
func (check *Checker) shift(x, y *operand, e ast.Expr, op token.Token) {
// TODO(gri) This function seems overly complex. Revisit.
var xval constant.Value
if x.mode == constant_ {
xval = constant.ToInt(x.val)
}
if allInteger(x.typ) || isUntyped(x.typ) && xval != nil && xval.Kind() == constant.Int {
// The lhs is of integer type or an untyped constant representable
// as an integer. Nothing to do.
} else {
// shift has no chance
check.errorf(x, InvalidShiftOperand, invalidOp+"shifted operand %s must be integer", x)
x.mode = invalid
return
}
// spec: "The right operand in a shift expression must have integer type
// or be an untyped constant representable by a value of type uint."
// Check that constants are representable by uint, but do not convert them
// (see also go.dev/issue/47243).
var yval constant.Value
if y.mode == constant_ {
// Provide a good error message for negative shift counts.
yval = constant.ToInt(y.val) // consider -1, 1.0, but not -1.1
if yval.Kind() == constant.Int && constant.Sign(yval) < 0 {
check.errorf(y, InvalidShiftCount, invalidOp+"negative shift count %s", y)
x.mode = invalid
return
}
if isUntyped(y.typ) {
// Caution: Check for representability here, rather than in the switch
// below, because isInteger includes untyped integers (was bug go.dev/issue/43697).
check.representable(y, Typ[Uint])
if y.mode == invalid {
x.mode = invalid
return
}
}
} else {
// Check that RHS is otherwise at least of integer type.
switch {
case allInteger(y.typ):
if !allUnsigned(y.typ) && !check.allowVersion(check.pkg, 1, 13) {
check.errorf(y, UnsupportedFeature, invalidOp+"signed shift count %s requires go1.13 or later", y)
x.mode = invalid
return
}
case isUntyped(y.typ):
// This is incorrect, but preserves pre-existing behavior.
// See also go.dev/issue/47410.
check.convertUntyped(y, Typ[Uint])
if y.mode == invalid {
x.mode = invalid
return
}
default:
check.errorf(y, InvalidShiftCount, invalidOp+"shift count %s must be integer", y)
x.mode = invalid
return
}
}
if x.mode == constant_ {
if y.mode == constant_ {
// if either x or y has an unknown value, the result is unknown
if x.val.Kind() == constant.Unknown || y.val.Kind() == constant.Unknown {
x.val = constant.MakeUnknown()
// ensure the correct type - see comment below
if !isInteger(x.typ) {
x.typ = Typ[UntypedInt]
}
return
}
// rhs must be within reasonable bounds in constant shifts
const shiftBound = 1023 - 1 + 52 // so we can express smallestFloat64 (see go.dev/issue/44057)
s, ok := constant.Uint64Val(yval)
if !ok || s > shiftBound {
check.errorf(y, InvalidShiftCount, invalidOp+"invalid shift count %s", y)
x.mode = invalid
return
}
// The lhs is representable as an integer but may not be an integer
// (e.g., 2.0, an untyped float) - this can only happen for untyped
// non-integer numeric constants. Correct the type so that the shift
// result is of integer type.
if !isInteger(x.typ) {
x.typ = Typ[UntypedInt]
}
// x is a constant so xval != nil and it must be of Int kind.
x.val = constant.Shift(xval, op, uint(s))
x.expr = e
opPos := x.Pos()
if b, _ := e.(*ast.BinaryExpr); b != nil {
opPos = b.OpPos
}
check.overflow(x, opPos)
return
}
// non-constant shift with constant lhs
if isUntyped(x.typ) {
// spec: "If the left operand of a non-constant shift
// expression is an untyped constant, the type of the
// constant is what it would be if the shift expression
// were replaced by its left operand alone.".
//
// Delay operand checking until we know the final type
// by marking the lhs expression as lhs shift operand.
//
// Usually (in correct programs), the lhs expression
// is in the untyped map. However, it is possible to
// create incorrect programs where the same expression
// is evaluated twice (via a declaration cycle) such
// that the lhs expression type is determined in the
// first round and thus deleted from the map, and then
// not found in the second round (double insertion of
// the same expr node still just leads to one entry for
// that node, and it can only be deleted once).
// Be cautious and check for presence of entry.
// Example: var e, f = int(1<<""[f]) // go.dev/issue/11347
if info, found := check.untyped[x.expr]; found {
info.isLhs = true
check.untyped[x.expr] = info
}
// keep x's type
x.mode = value
return
}
}
// non-constant shift - lhs must be an integer
if !allInteger(x.typ) {
check.errorf(x, InvalidShiftOperand, invalidOp+"shifted operand %s must be integer", x)
x.mode = invalid
return
}
x.mode = value
}
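// Two consequences of the constant-shift handling above, as a sketch:
//
// const a = 2.0 << 1  // ok: 2.0 is representable as an integer; a == 4 (untyped int)
// const b = 1 << 1100 // ERROR: invalid shift count (1100 > shiftBound == 1074)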
var binaryOpPredicates opPredicates
func init() {
// Setting binaryOpPredicates in init avoids declaration cycles.
binaryOpPredicates = opPredicates{
token.ADD: allNumericOrString,
token.SUB: allNumeric,
token.MUL: allNumeric,
token.QUO: allNumeric,
token.REM: allInteger,
token.AND: allInteger,
token.OR: allInteger,
token.XOR: allInteger,
token.AND_NOT: allInteger,
token.LAND: allBoolean,
token.LOR: allBoolean,
}
}
// If e != nil, it must be the binary expression; it may be nil for non-constant expressions
// (when invoked for an assignment operation where the binary expression is implicit).
func (check *Checker) binary(x *operand, e ast.Expr, lhs, rhs ast.Expr, op token.Token, opPos token.Pos) {
var y operand
check.expr(x, lhs)
check.expr(&y, rhs)
if x.mode == invalid {
return
}
if y.mode == invalid {
x.mode = invalid
x.expr = y.expr
return
}
if isShift(op) {
check.shift(x, &y, e, op)
return
}
// mayConvert reports whether the operands x and y may
// possibly have matching types after converting one
// untyped operand to the type of the other.
// If mayConvert returns true, we try to convert the
// operands to each other's types, and if that fails
// we report a conversion failure.
// If mayConvert returns false, we continue without an
// attempt at conversion, and if the operand types are
// not compatible, we report a type mismatch error.
mayConvert := func(x, y *operand) bool {
// If both operands are typed, there's no need for an implicit conversion.
if isTyped(x.typ) && isTyped(y.typ) {
return false
}
// An untyped operand may convert to its default type when paired with an empty interface
// TODO(gri) This should only matter for comparisons (the only binary operation that is
// valid with interfaces), but in that case the assignability check should take
// care of the conversion. Verify and possibly eliminate this extra test.
if isNonTypeParamInterface(x.typ) || isNonTypeParamInterface(y.typ) {
return true
}
// A boolean type can only convert to another boolean type.
if allBoolean(x.typ) != allBoolean(y.typ) {
return false
}
// A string type can only convert to another string type.
if allString(x.typ) != allString(y.typ) {
return false
}
// Untyped nil can only convert to a type that has a nil.
if x.isNil() {
return hasNil(y.typ)
}
if y.isNil() {
return hasNil(x.typ)
}
// An untyped operand cannot convert to a pointer.
// TODO(gri) generalize to type parameters
if isPointer(x.typ) || isPointer(y.typ) {
return false
}
return true
}
if mayConvert(x, &y) {
check.convertUntyped(x, y.typ)
if x.mode == invalid {
return
}
check.convertUntyped(&y, x.typ)
if y.mode == invalid {
x.mode = invalid
return
}
}
if isComparison(op) {
check.comparison(x, &y, op, false)
return
}
if !Identical(x.typ, y.typ) {
// only report an error if we have valid types
// (otherwise we had an error reported elsewhere already)
if x.typ != Typ[Invalid] && y.typ != Typ[Invalid] {
var posn positioner = x
if e != nil {
posn = e
}
if e != nil {
check.errorf(posn, MismatchedTypes, invalidOp+"%s (mismatched types %s and %s)", e, x.typ, y.typ)
} else {
check.errorf(posn, MismatchedTypes, invalidOp+"%s %s= %s (mismatched types %s and %s)", lhs, op, rhs, x.typ, y.typ)
}
}
x.mode = invalid
return
}
if !check.op(binaryOpPredicates, x, op) {
x.mode = invalid
return
}
if op == token.QUO || op == token.REM {
// check for zero divisor
if (x.mode == constant_ || allInteger(x.typ)) && y.mode == constant_ && constant.Sign(y.val) == 0 {
check.error(&y, DivByZero, invalidOp+"division by zero")
x.mode = invalid
return
}
// check for divisor underflow in complex division (see go.dev/issue/20227)
if x.mode == constant_ && y.mode == constant_ && isComplex(x.typ) {
re, im := constant.Real(y.val), constant.Imag(y.val)
re2, im2 := constant.BinaryOp(re, token.MUL, re), constant.BinaryOp(im, token.MUL, im)
if constant.Sign(re2) == 0 && constant.Sign(im2) == 0 {
check.error(&y, DivByZero, invalidOp+"division by zero")
x.mode = invalid
return
}
}
}
if x.mode == constant_ && y.mode == constant_ {
// if either x or y has an unknown value, the result is unknown
if x.val.Kind() == constant.Unknown || y.val.Kind() == constant.Unknown {
x.val = constant.MakeUnknown()
// x.typ is unchanged
return
}
// force integer division of integer operands
if op == token.QUO && isInteger(x.typ) {
op = token.QUO_ASSIGN
}
x.val = constant.BinaryOp(x.val, op, y.val)
x.expr = e
check.overflow(x, opPos)
return
}
x.mode = value
// x.typ is unchanged
}
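// The QUO to QUO_ASSIGN rewrite above makes go/constant use truncated
// integer division for integer operands, so that, for example:
//
// const q = 5 / 2   // untyped int: q == 2 (integer division)
// const r = 5.0 / 2 // untyped float: r == 2.5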
// exprKind describes the kind of an expression; the kind
// determines if an expression is valid in 'statement context'.
type exprKind int
const (
conversion exprKind = iota
expression
statement
)
// rawExpr typechecks expression e and initializes x with the expression
// value or type. If an error occurred, x.mode is set to invalid.
// If hint != nil, it is the type of a composite literal element.
// If allowGeneric is set, the operand type may be an uninstantiated
// parameterized type or function value.
func (check *Checker) rawExpr(x *operand, e ast.Expr, hint Type, allowGeneric bool) exprKind {
if check.conf._Trace {
check.trace(e.Pos(), "-- expr %s", e)
check.indent++
defer func() {
check.indent--
check.trace(e.Pos(), "=> %s", x)
}()
}
kind := check.exprInternal(x, e, hint)
if !allowGeneric {
check.nonGeneric(x)
}
check.record(x)
return kind
}
// If x is a generic function or type, nonGeneric reports an error and invalidates x.mode and x.typ.
// Otherwise it leaves x alone.
func (check *Checker) nonGeneric(x *operand) {
if x.mode == invalid || x.mode == novalue {
return
}
var what string
switch t := x.typ.(type) {
case *Named:
if isGeneric(t) {
what = "type"
}
case *Signature:
if t.tparams != nil {
what = "function"
}
}
if what != "" {
check.errorf(x.expr, WrongTypeArgCount, "cannot use generic %s %s without instantiation", what, x.expr)
x.mode = invalid
x.typ = Typ[Invalid]
}
}
// exprInternal contains the core of type checking of expressions.
// Must only be called by rawExpr.
func (check *Checker) exprInternal(x *operand, e ast.Expr, hint Type) exprKind {
// make sure x has a valid state in case of bailout
// (was go.dev/issue/5770)
x.mode = invalid
x.typ = Typ[Invalid]
switch e := e.(type) {
case *ast.BadExpr:
goto Error // error was reported before
case *ast.Ident:
check.ident(x, e, nil, false)
case *ast.Ellipsis:
// ellipses are handled explicitly where they are legal
// (array composite literals and parameter lists)
check.error(e, BadDotDotDotSyntax, "invalid use of '...'")
goto Error
case *ast.BasicLit:
switch e.Kind {
case token.INT, token.FLOAT, token.IMAG:
check.langCompat(e)
// The max. mantissa precision for untyped numeric values
// is 512 bits, or 4048 bits for each of the two integer
// parts of a fraction for floating-point numbers that are
// represented accurately in the go/constant package.
// Constant literals that are longer than this many bits
// are not meaningful; and excessively long constants may
// consume a lot of space and time for a useless conversion.
// Cap constant length with a generous upper limit that also
// allows for separators between all digits.
const limit = 10000
if len(e.Value) > limit {
check.errorf(e, InvalidConstVal, "excessively long constant: %s... (%d chars)", e.Value[:10], len(e.Value))
goto Error
}
}
x.setConst(e.Kind, e.Value)
if x.mode == invalid {
// The parser already establishes syntactic correctness.
// If we reach here it's because of number under-/overflow.
// TODO(gri) setConst (and in turn the go/constant package)
// should return an error describing the issue.
check.errorf(e, InvalidConstVal, "malformed constant: %s", e.Value)
goto Error
}
// Ensure that integer values don't overflow (go.dev/issue/54280).
check.overflow(x, e.Pos())
case *ast.FuncLit:
if sig, ok := check.typ(e.Type).(*Signature); ok {
if !check.conf.IgnoreFuncBodies && e.Body != nil {
// Anonymous functions are considered part of the
// init expression/func declaration which contains
// them: use existing package-level declaration info.
decl := check.decl // capture for use in closure below
iota := check.iota // capture for use in closure below (go.dev/issue/22345)
// Don't type-check right away because the function may
// be part of a type definition to which the function
// body refers. Instead, type-check as soon as possible,
// but before the enclosing scope contents change (go.dev/issue/22992).
check.later(func() {
check.funcBody(decl, "<function literal>", sig, e.Body, iota)
}).describef(e, "func literal")
}
x.mode = value
x.typ = sig
} else {
check.errorf(e, InvalidSyntaxTree, "invalid function literal %s", e)
goto Error
}
case *ast.CompositeLit:
var typ, base Type
switch {
case e.Type != nil:
// composite literal type present - use it
// [...]T array types may only appear with composite literals.
// Check for them here so we don't have to handle ... in general.
if atyp, _ := e.Type.(*ast.ArrayType); atyp != nil && atyp.Len != nil {
if ellip, _ := atyp.Len.(*ast.Ellipsis); ellip != nil && ellip.Elt == nil {
// We have an "open" [...]T array type.
// Create a new ArrayType with unknown length (-1)
// and finish setting it up after analyzing the literal.
typ = &Array{len: -1, elem: check.varType(atyp.Elt)}
base = typ
break
}
}
typ = check.typ(e.Type)
base = typ
case hint != nil:
// no composite literal type present - use hint (element type of enclosing type)
typ = hint
base, _ = deref(coreType(typ)) // *T implies &T{}
if base == nil {
check.errorf(e, InvalidLit, "invalid composite literal element type %s (no core type)", typ)
goto Error
}
default:
// TODO(gri) provide better error messages depending on context
check.error(e, UntypedLit, "missing type in composite literal")
goto Error
}
switch utyp := coreType(base).(type) {
case *Struct:
// Prevent crash if the struct referred to is not yet set up.
// See analogous comment for *Array.
if utyp.fields == nil {
check.error(e, InvalidTypeCycle, "invalid recursive type")
goto Error
}
if len(e.Elts) == 0 {
break
}
// Convention for error messages on invalid struct literals:
// we mention the struct type only if it clarifies the error
// (e.g., a duplicate field error doesn't need the struct type).
fields := utyp.fields
if _, ok := e.Elts[0].(*ast.KeyValueExpr); ok {
// all elements must have keys
visited := make([]bool, len(fields))
for _, e := range e.Elts {
kv, _ := e.(*ast.KeyValueExpr)
if kv == nil {
check.error(e, MixedStructLit, "mixture of field:value and value elements in struct literal")
continue
}
key, _ := kv.Key.(*ast.Ident)
// do all possible checks early (before exiting due to errors)
// so we don't drop information on the floor
check.expr(x, kv.Value)
if key == nil {
check.errorf(kv, InvalidLitField, "invalid field name %s in struct literal", kv.Key)
continue
}
i := fieldIndex(utyp.fields, check.pkg, key.Name)
if i < 0 {
check.errorf(kv, MissingLitField, "unknown field %s in struct literal of type %s", key.Name, base)
continue
}
fld := fields[i]
check.recordUse(key, fld)
etyp := fld.typ
check.assignment(x, etyp, "struct literal")
// 0 <= i < len(fields)
if visited[i] {
check.errorf(kv, DuplicateLitField, "duplicate field name %s in struct literal", key.Name)
continue
}
visited[i] = true
}
} else {
// no element may have a key
for i, e := range e.Elts {
if kv, _ := e.(*ast.KeyValueExpr); kv != nil {
check.error(kv, MixedStructLit, "mixture of field:value and value elements in struct literal")
continue
}
check.expr(x, e)
if i >= len(fields) {
check.errorf(x, InvalidStructLit, "too many values in struct literal of type %s", base)
break // cannot continue
}
// i < len(fields)
fld := fields[i]
if !fld.Exported() && fld.pkg != check.pkg {
check.errorf(x,
UnexportedLitField,
"implicit assignment to unexported field %s in struct literal of type %s", fld.name, base)
continue
}
etyp := fld.typ
check.assignment(x, etyp, "struct literal")
}
if len(e.Elts) < len(fields) {
check.errorf(inNode(e, e.Rbrace), InvalidStructLit, "too few values in struct literal of type %s", base)
// ok to continue
}
}
case *Array:
// Prevent crash if the array referred to is not yet set up. Was go.dev/issue/18643.
// This is a stop-gap solution. Should use Checker.objPath to report entire
// path starting with earliest declaration in the source. TODO(gri) fix this.
if utyp.elem == nil {
check.error(e, InvalidTypeCycle, "invalid recursive type")
goto Error
}
n := check.indexedElts(e.Elts, utyp.elem, utyp.len)
// If we have an array of unknown length (usually [...]T arrays, but also
// arrays [n]T where n is invalid) set the length now that we know it and
// record the type for the array (usually done by check.typ which is not
// called for [...]T). We handle [...]T arrays and arrays with invalid
// length the same here because it makes sense to "guess" the length for
// the latter if we have a composite literal; e.g. for [n]int{1, 2, 3}
// where n is invalid for some reason, it seems fair to assume it should
// be 3 (see also Checker.arrayLength and go.dev/issue/27346).
if utyp.len < 0 {
utyp.len = n
// e.Type is missing if we have a composite literal element
// that is itself a composite literal with omitted type. In
// that case there is nothing to record (there is no type in
// the source at that point).
if e.Type != nil {
check.recordTypeAndValue(e.Type, typexpr, utyp, nil)
}
}
case *Slice:
// Prevent crash if the slice referred to is not yet set up.
// See analogous comment for *Array.
if utyp.elem == nil {
check.error(e, InvalidTypeCycle, "invalid recursive type")
goto Error
}
check.indexedElts(e.Elts, utyp.elem, -1)
case *Map:
// Prevent crash if the map referred to is not yet set up.
// See analogous comment for *Array.
if utyp.key == nil || utyp.elem == nil {
check.error(e, InvalidTypeCycle, "invalid recursive type")
goto Error
}
// If the map key type is an interface (but not a type parameter),
// the type of a constant key must be considered when checking for
// duplicates.
keyIsInterface := isNonTypeParamInterface(utyp.key)
visited := make(map[any][]Type, len(e.Elts))
for _, e := range e.Elts {
kv, _ := e.(*ast.KeyValueExpr)
if kv == nil {
check.error(e, MissingLitKey, "missing key in map literal")
continue
}
check.exprWithHint(x, kv.Key, utyp.key)
check.assignment(x, utyp.key, "map literal")
if x.mode == invalid {
continue
}
if x.mode == constant_ {
duplicate := false
xkey := keyVal(x.val)
if keyIsInterface {
for _, vtyp := range visited[xkey] {
if Identical(vtyp, x.typ) {
duplicate = true
break
}
}
visited[xkey] = append(visited[xkey], x.typ)
} else {
_, duplicate = visited[xkey]
visited[xkey] = nil
}
if duplicate {
check.errorf(x, DuplicateLitKey, "duplicate key %s in map literal", x.val)
continue
}
}
check.exprWithHint(x, kv.Value, utyp.elem)
check.assignment(x, utyp.elem, "map literal")
}
default:
// when "using" all elements unpack KeyValueExpr
// explicitly because check.use doesn't accept them
for _, e := range e.Elts {
if kv, _ := e.(*ast.KeyValueExpr); kv != nil {
// Ideally, we should also "use" kv.Key but we can't know
// if it's an externally defined struct key or not. Going
// forward anyway can lead to other errors. Give up instead.
e = kv.Value
}
check.use(e)
}
// if utyp is invalid, an error was reported before
if utyp != Typ[Invalid] {
check.errorf(e, InvalidLit, "invalid composite literal type %s", typ)
goto Error
}
}
x.mode = value
x.typ = typ
case *ast.ParenExpr:
kind := check.rawExpr(x, e.X, nil, false)
x.expr = e
return kind
case *ast.SelectorExpr:
check.selector(x, e, nil, false)
case *ast.IndexExpr, *ast.IndexListExpr:
ix := typeparams.UnpackIndexExpr(e)
if check.indexExpr(x, ix) {
check.funcInst(x, ix)
}
if x.mode == invalid {
goto Error
}
case *ast.SliceExpr:
check.sliceExpr(x, e)
if x.mode == invalid {
goto Error
}
case *ast.TypeAssertExpr:
check.expr(x, e.X)
if x.mode == invalid {
goto Error
}
// x.(type) expressions are handled explicitly in type switches
if e.Type == nil {
// Don't use invalidAST because this can occur in the AST produced by
// go/parser.
check.error(e, BadTypeKeyword, "use of .(type) outside type switch")
goto Error
}
// TODO(gri) we may want to permit type assertions on type parameter values at some point
if isTypeParam(x.typ) {
check.errorf(x, InvalidAssert, invalidOp+"cannot use type assertion on type parameter value %s", x)
goto Error
}
if _, ok := under(x.typ).(*Interface); !ok {
check.errorf(x, InvalidAssert, invalidOp+"%s is not an interface", x)
goto Error
}
T := check.varType(e.Type)
if T == Typ[Invalid] {
goto Error
}
check.typeAssertion(e, x, T, false)
x.mode = commaok
x.typ = T
case *ast.CallExpr:
return check.callExpr(x, e)
case *ast.StarExpr:
check.exprOrType(x, e.X, false)
switch x.mode {
case invalid:
goto Error
case typexpr:
check.validVarType(e.X, x.typ)
x.typ = &Pointer{base: x.typ}
default:
var base Type
if !underIs(x.typ, func(u Type) bool {
p, _ := u.(*Pointer)
if p == nil {
check.errorf(x, InvalidIndirection, invalidOp+"cannot indirect %s", x)
return false
}
if base != nil && !Identical(p.base, base) {
check.errorf(x, InvalidIndirection, invalidOp+"pointers of %s must have identical base types", x)
return false
}
base = p.base
return true
}) {
goto Error
}
x.mode = variable
x.typ = base
}
case *ast.UnaryExpr:
check.unary(x, e)
if x.mode == invalid {
goto Error
}
if e.Op == token.ARROW {
x.expr = e
return statement // receive operations may appear in statement context
}
case *ast.BinaryExpr:
check.binary(x, e, e.X, e.Y, e.Op, e.OpPos)
if x.mode == invalid {
goto Error
}
case *ast.KeyValueExpr:
// key:value expressions are handled in composite literals
check.error(e, InvalidSyntaxTree, "no key:value expected")
goto Error
case *ast.ArrayType, *ast.StructType, *ast.FuncType,
*ast.InterfaceType, *ast.MapType, *ast.ChanType:
x.mode = typexpr
x.typ = check.typ(e)
// Note: rawExpr (caller of exprInternal) will call check.recordTypeAndValue
// even though check.typ has already called it. This is fine as both
// times the same expression and type are recorded. It is also not a
// performance issue because we only reach here for composite literal
// types, which are comparatively rare.
default:
panic(fmt.Sprintf("%s: unknown expression type %T", check.fset.Position(e.Pos()), e))
}
// everything went well
x.expr = e
return expression
Error:
x.mode = invalid
x.expr = e
return statement // avoid follow-up errors
}
// keyVal maps a complex, float, integer, string or boolean constant value
// to the corresponding complex128, float64, int64, uint64, string, or bool
// Go value if possible; otherwise it returns x.
// A complex constant that can be represented as a float (such as 1.2 + 0i)
// is returned as a floating point value; if a floating point value can be
// represented as an integer (such as 1.0) it is returned as an integer value.
// This ensures that constants of different kind but equal value (such as
// 1.0 + 0i, 1.0, 1) result in the same value.
func keyVal(x constant.Value) interface{} {
switch x.Kind() {
case constant.Complex:
f := constant.ToFloat(x)
if f.Kind() != constant.Float {
r, _ := constant.Float64Val(constant.Real(x))
i, _ := constant.Float64Val(constant.Imag(x))
return complex(r, i)
}
x = f
fallthrough
case constant.Float:
i := constant.ToInt(x)
if i.Kind() != constant.Int {
v, _ := constant.Float64Val(x)
return v
}
x = i
fallthrough
case constant.Int:
if v, ok := constant.Int64Val(x); ok {
return v
}
if v, ok := constant.Uint64Val(x); ok {
return v
}
case constant.String:
return constant.StringVal(x)
case constant.Bool:
return constant.BoolVal(x)
}
return x
}
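// As a consequence, constant keys of different kinds but equal value
// collide for non-interface key types:
//
// _ = map[float64]int{1: 1, 1.0: 2} // ERROR: duplicate key 1 in map literal
//
// With an interface key type the duplicate check also compares the keys'
// (default) types, so map[any]int{1: 1, 1.0: 2} is valid: the dynamic
// types int and float64 differ.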
// typeAssertion checks x.(T). The type of x must be an interface.
func (check *Checker) typeAssertion(e ast.Expr, x *operand, T Type, typeSwitch bool) {
var cause string
if check.assertableTo(x.typ, T, &cause) {
return // success
}
if typeSwitch {
check.errorf(e, ImpossibleAssert, "impossible type switch case: %s\n\t%s cannot have dynamic type %s %s", e, x, T, cause)
return
}
check.errorf(e, ImpossibleAssert, "impossible type assertion: %s\n\t%s does not implement %s %s", e, T, x.typ, cause)
}
// expr typechecks expression e and initializes x with the expression value.
// The result must be a single value.
// If an error occurred, x.mode is set to invalid.
func (check *Checker) expr(x *operand, e ast.Expr) {
check.rawExpr(x, e, nil, false)
check.exclude(x, 1<<novalue|1<<builtin|1<<typexpr)
check.singleValue(x)
}
// multiExpr is like expr but the result may also be a multi-value.
func (check *Checker) multiExpr(x *operand, e ast.Expr) {
check.rawExpr(x, e, nil, false)
check.exclude(x, 1<<novalue|1<<builtin|1<<typexpr)
}
// exprWithHint typechecks expression e and initializes x with the expression value;
// hint is the type of a composite literal element.
// If an error occurred, x.mode is set to invalid.
func (check *Checker) exprWithHint(x *operand, e ast.Expr, hint Type) {
assert(hint != nil)
check.rawExpr(x, e, hint, false)
check.exclude(x, 1<<novalue|1<<builtin|1<<typexpr)
check.singleValue(x)
}
// exprOrType typechecks expression or type e and initializes x with the expression value or type.
// If allowGeneric is set, the operand type may be an uninstantiated parameterized type or function
// value.
// If an error occurred, x.mode is set to invalid.
func (check *Checker) exprOrType(x *operand, e ast.Expr, allowGeneric bool) {
check.rawExpr(x, e, nil, allowGeneric)
check.exclude(x, 1<<novalue)
check.singleValue(x)
}
// exclude reports an error if x.mode is in modeset and sets x.mode to invalid.
// The modeset may contain any of 1<<novalue, 1<<builtin, 1<<typexpr.
func (check *Checker) exclude(x *operand, modeset uint) {
if modeset&(1<<x.mode) != 0 {
var msg string
var code Code
switch x.mode {
case novalue:
if modeset&(1<<typexpr) != 0 {
msg = "%s used as value"
} else {
msg = "%s used as value or type"
}
code = TooManyValues
case builtin:
msg = "%s must be called"
code = UncalledBuiltin
case typexpr:
msg = "%s is not an expression"
code = NotAnExpr
default:
unreachable()
}
check.errorf(x, code, msg, x)
x.mode = invalid
}
}
// singleValue reports an error if x describes a tuple and sets x.mode to invalid.
func (check *Checker) singleValue(x *operand) {
if x.mode == value {
// tuple types are never named - no need for underlying type below
if t, ok := x.typ.(*Tuple); ok {
assert(t.Len() != 1)
check.errorf(x, TooManyValues, "multiple-value %s in single-value context", x)
x.mode = invalid
}
}
}
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements printing of expressions.
package types
import (
"bytes"
"fmt"
"go/ast"
"go/internal/typeparams"
)
// ExprString returns the (possibly shortened) string representation for x.
// Shortened representations are suitable for user interfaces but may not
// necessarily follow Go syntax.
func ExprString(x ast.Expr) string {
var buf bytes.Buffer
WriteExpr(&buf, x)
return buf.String()
}
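// A minimal use of ExprString (sketch; assumes go/parser):
//
// x, _ := parser.ParseExpr("f(a, b...)")
// s := types.ExprString(x) // "f(a, b...)"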
// WriteExpr writes the (possibly shortened) string representation for x to buf.
// Shortened representations are suitable for user interfaces but may not
// necessarily follow Go syntax.
func WriteExpr(buf *bytes.Buffer, x ast.Expr) {
// The AST preserves source-level parentheses so there is
// no need to introduce them here to correct for different
// operator precedences. (This assumes that the AST was
// generated by a Go parser.)
switch x := x.(type) {
default:
fmt.Fprintf(buf, "(ast: %T)", x) // nil, ast.BadExpr, ast.KeyValueExpr
case *ast.Ident:
buf.WriteString(x.Name)
case *ast.Ellipsis:
buf.WriteString("...")
if x.Elt != nil {
WriteExpr(buf, x.Elt)
}
case *ast.BasicLit:
buf.WriteString(x.Value)
case *ast.FuncLit:
buf.WriteByte('(')
WriteExpr(buf, x.Type)
buf.WriteString(" literal)") // shortened
case *ast.CompositeLit:
WriteExpr(buf, x.Type)
buf.WriteByte('{')
if len(x.Elts) > 0 {
buf.WriteString("…")
}
buf.WriteByte('}')
case *ast.ParenExpr:
buf.WriteByte('(')
WriteExpr(buf, x.X)
buf.WriteByte(')')
case *ast.SelectorExpr:
WriteExpr(buf, x.X)
buf.WriteByte('.')
buf.WriteString(x.Sel.Name)
case *ast.IndexExpr, *ast.IndexListExpr:
ix := typeparams.UnpackIndexExpr(x)
WriteExpr(buf, ix.X)
buf.WriteByte('[')
writeExprList(buf, ix.Indices)
buf.WriteByte(']')
case *ast.SliceExpr:
WriteExpr(buf, x.X)
buf.WriteByte('[')
if x.Low != nil {
WriteExpr(buf, x.Low)
}
buf.WriteByte(':')
if x.High != nil {
WriteExpr(buf, x.High)
}
if x.Slice3 {
buf.WriteByte(':')
if x.Max != nil {
WriteExpr(buf, x.Max)
}
}
buf.WriteByte(']')
case *ast.TypeAssertExpr:
WriteExpr(buf, x.X)
buf.WriteString(".(")
WriteExpr(buf, x.Type)
buf.WriteByte(')')
case *ast.CallExpr:
WriteExpr(buf, x.Fun)
buf.WriteByte('(')
writeExprList(buf, x.Args)
if x.Ellipsis.IsValid() {
buf.WriteString("...")
}
buf.WriteByte(')')
case *ast.StarExpr:
buf.WriteByte('*')
WriteExpr(buf, x.X)
case *ast.UnaryExpr:
buf.WriteString(x.Op.String())
WriteExpr(buf, x.X)
case *ast.BinaryExpr:
WriteExpr(buf, x.X)
buf.WriteByte(' ')
buf.WriteString(x.Op.String())
buf.WriteByte(' ')
WriteExpr(buf, x.Y)
case *ast.ArrayType:
buf.WriteByte('[')
if x.Len != nil {
WriteExpr(buf, x.Len)
}
buf.WriteByte(']')
WriteExpr(buf, x.Elt)
case *ast.StructType:
buf.WriteString("struct{")
writeFieldList(buf, x.Fields.List, "; ", false)
buf.WriteByte('}')
case *ast.FuncType:
buf.WriteString("func")
writeSigExpr(buf, x)
case *ast.InterfaceType:
buf.WriteString("interface{")
writeFieldList(buf, x.Methods.List, "; ", true)
buf.WriteByte('}')
case *ast.MapType:
buf.WriteString("map[")
WriteExpr(buf, x.Key)
buf.WriteByte(']')
WriteExpr(buf, x.Value)
case *ast.ChanType:
var s string
switch x.Dir {
case ast.SEND:
s = "chan<- "
case ast.RECV:
s = "<-chan "
default:
s = "chan "
}
buf.WriteString(s)
WriteExpr(buf, x.Value)
}
}
func writeSigExpr(buf *bytes.Buffer, sig *ast.FuncType) {
buf.WriteByte('(')
writeFieldList(buf, sig.Params.List, ", ", false)
buf.WriteByte(')')
res := sig.Results
n := res.NumFields()
if n == 0 {
// no result
return
}
buf.WriteByte(' ')
if n == 1 && len(res.List[0].Names) == 0 {
// single unnamed result
WriteExpr(buf, res.List[0].Type)
return
}
// multiple or named result(s)
buf.WriteByte('(')
writeFieldList(buf, res.List, ", ", false)
buf.WriteByte(')')
}
func writeFieldList(buf *bytes.Buffer, list []*ast.Field, sep string, iface bool) {
for i, f := range list {
if i > 0 {
buf.WriteString(sep)
}
// field list names
writeIdentList(buf, f.Names)
// types of interface methods consist of signatures only
if sig, _ := f.Type.(*ast.FuncType); sig != nil && iface {
writeSigExpr(buf, sig)
continue
}
// named fields are separated with a blank from the field type
if len(f.Names) > 0 {
buf.WriteByte(' ')
}
WriteExpr(buf, f.Type)
// ignore tag
}
}
func writeIdentList(buf *bytes.Buffer, list []*ast.Ident) {
for i, x := range list {
if i > 0 {
buf.WriteString(", ")
}
buf.WriteString(x.Name)
}
}
func writeExprList(buf *bytes.Buffer, list []ast.Expr) {
for i, x := range list {
if i > 0 {
buf.WriteString(", ")
}
WriteExpr(buf, x)
}
}
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements typechecking of index/slice expressions.
package types
import (
"go/ast"
"go/constant"
"go/internal/typeparams"
. "internal/types/errors"
)
// If e is a valid function instantiation, indexExpr returns true.
// In that case x represents the uninstantiated function value and
// it is the caller's responsibility to instantiate the function.
func (check *Checker) indexExpr(x *operand, e *typeparams.IndexExpr) (isFuncInst bool) {
check.exprOrType(x, e.X, true)
// x may be generic
switch x.mode {
case invalid:
check.use(e.Indices...)
return false
case typexpr:
// type instantiation
x.mode = invalid
// TODO(gri) here we re-evaluate e.X - try to avoid this
x.typ = check.varType(e.Orig)
if x.typ != Typ[Invalid] {
x.mode = typexpr
}
return false
case value:
if sig, _ := under(x.typ).(*Signature); sig != nil && sig.TypeParams().Len() > 0 {
// function instantiation
return true
}
}
// x should not be generic at this point, but be safe and check
check.nonGeneric(x)
if x.mode == invalid {
return false
}
// ordinary index expression
valid := false
length := int64(-1) // valid if >= 0
switch typ := under(x.typ).(type) {
case *Basic:
if isString(typ) {
valid = true
if x.mode == constant_ {
length = int64(len(constant.StringVal(x.val)))
}
// an indexed string always yields a byte value
// (not a constant) even if the string and the
// index are constant
x.mode = value
x.typ = universeByte // use 'byte' name
}
case *Array:
valid = true
length = typ.len
if x.mode != variable {
x.mode = value
}
x.typ = typ.elem
case *Pointer:
if typ, _ := under(typ.base).(*Array); typ != nil {
valid = true
length = typ.len
x.mode = variable
x.typ = typ.elem
}
case *Slice:
valid = true
x.mode = variable
x.typ = typ.elem
case *Map:
index := check.singleIndex(e)
if index == nil {
x.mode = invalid
return false
}
var key operand
check.expr(&key, index)
check.assignment(&key, typ.key, "map index")
// ok to continue even if indexing failed - map element type is known
x.mode = mapindex
x.typ = typ.elem
x.expr = e.Orig
return false
case *Interface:
if !isTypeParam(x.typ) {
break
}
// TODO(gri) report detailed failure cause for better error messages
var key, elem Type // key != nil: we must have all maps
mode := variable // non-maps result mode
// TODO(gri) factor out closure and use it for non-typeparam cases as well
if typ.typeSet().underIs(func(u Type) bool {
l := int64(-1) // valid if >= 0
var k, e Type // k is only set for maps
switch t := u.(type) {
case *Basic:
if isString(t) {
e = universeByte
mode = value
}
case *Array:
l = t.len
e = t.elem
if x.mode != variable {
mode = value
}
case *Pointer:
if t, _ := under(t.base).(*Array); t != nil {
l = t.len
e = t.elem
}
case *Slice:
e = t.elem
case *Map:
k = t.key
e = t.elem
}
if e == nil {
return false
}
if elem == nil {
// first type
length = l
key, elem = k, e
return true
}
// all map keys must be identical (incl. all nil)
// (that is, we cannot mix maps with other types)
if !Identical(key, k) {
return false
}
// all element types must be identical
if !Identical(elem, e) {
return false
}
// track the minimal length for arrays, if any
if l >= 0 && l < length {
length = l
}
return true
}) {
// For maps, the index expression must be assignable to the map key type.
if key != nil {
index := check.singleIndex(e)
if index == nil {
x.mode = invalid
return false
}
var k operand
check.expr(&k, index)
check.assignment(&k, key, "map index")
// ok to continue even if indexing failed - map element type is known
x.mode = mapindex
x.typ = elem
x.expr = e
return false
}
// no maps
valid = true
x.mode = mode
x.typ = elem
}
}
if !valid {
// types2 uses the position of '[' for the error
check.errorf(x, NonIndexableOperand, invalidOp+"cannot index %s", x)
x.mode = invalid
return false
}
index := check.singleIndex(e)
if index == nil {
x.mode = invalid
return false
}
// In pathological (invalid) cases (e.g.: type T1 [][[]T1{}[0][0]]T0)
// the element type may be accessed before it's set. Make sure we have
// a valid type.
if x.typ == nil {
x.typ = Typ[Invalid]
}
check.index(index, length)
return false
}
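// For the type parameter case above, all types in the constraint's type
// set must agree on the element (and key) types. A sketch:
//
// func first[T []int | [4]int](x T) int { return x[0] }  // ok: common element type int
// func bad[T []int | []string](x T) int { return x[0] }  // ERROR: cannot index x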
func (check *Checker) sliceExpr(x *operand, e *ast.SliceExpr) {
check.expr(x, e.X)
if x.mode == invalid {
check.use(e.Low, e.High, e.Max)
return
}
valid := false
length := int64(-1) // valid if >= 0
switch u := coreString(x.typ).(type) {
case nil:
check.errorf(x, NonSliceableOperand, invalidOp+"cannot slice %s: %s has no core type", x, x.typ)
x.mode = invalid
return
case *Basic:
if isString(u) {
if e.Slice3 {
at := e.Max
if at == nil {
at = e // e.Index[2] should be present but be careful
}
check.error(at, InvalidSliceExpr, invalidOp+"3-index slice of string")
x.mode = invalid
return
}
valid = true
if x.mode == constant_ {
length = int64(len(constant.StringVal(x.val)))
}
// spec: "For untyped string operands the result
// is a non-constant value of type string."
if isUntyped(x.typ) {
x.typ = Typ[String]
}
}
case *Array:
valid = true
length = u.len
if x.mode != variable {
check.errorf(x, NonSliceableOperand, invalidOp+"cannot slice %s (value not addressable)", x)
x.mode = invalid
return
}
x.typ = &Slice{elem: u.elem}
case *Pointer:
if u, _ := under(u.base).(*Array); u != nil {
valid = true
length = u.len
x.typ = &Slice{elem: u.elem}
}
case *Slice:
valid = true
// x.typ doesn't change
}
if !valid {
check.errorf(x, NonSliceableOperand, invalidOp+"cannot slice %s", x)
x.mode = invalid
return
}
x.mode = value
// spec: "Only the first index may be omitted; it defaults to 0."
if e.Slice3 && (e.High == nil || e.Max == nil) {
check.error(inNode(e, e.Rbrack), InvalidSyntaxTree, "2nd and 3rd index required in 3-index slice")
x.mode = invalid
return
}
// check indices
var ind [3]int64
for i, expr := range []ast.Expr{e.Low, e.High, e.Max} {
x := int64(-1)
switch {
case expr != nil:
// The "capacity" is only known statically for strings, arrays,
// and pointers to arrays, and it is the same as the length for
// those types.
max := int64(-1)
if length >= 0 {
max = length + 1
}
if _, v := check.index(expr, max); v >= 0 {
x = v
}
case i == 0:
// default is 0 for the first index
x = 0
case length >= 0:
// default is length (== capacity) otherwise
x = length
}
ind[i] = x
}
// constant indices must be in range
// (check.index already checks that existing indices >= 0)
L:
for i, x := range ind[:len(ind)-1] {
if x > 0 {
for j, y := range ind[i+1:] {
if y >= 0 && y < x {
// The value y corresponds to the expression e.Index[i+1+j].
// Because y >= 0, it must have been set from the expression
// when checking indices and thus e.Index[i+1+j] is not nil.
at := []ast.Expr{e.Low, e.High, e.Max}[i+1+j]
check.errorf(at, SwappedSliceIndices, "invalid slice indices: %d < %d", y, x)
break L // only report one error, ok to continue
}
}
}
}
}
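// Examples for the cases above, as a sketch:
//
// var a [8]int
// _ = a[2:4]        // ok: a is addressable
// _ = [8]int{}[2:4] // ERROR: cannot slice (value not addressable)
// _ = "foo"[1:2:3]  // ERROR: 3-index slice of string
// _ = a[4:2]        // ERROR: invalid slice indices: 2 < 4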
// singleIndex returns the (single) index from the index expression e.
// If the index is missing, or if there are multiple indices, an error
// is reported and the result is nil.
func (check *Checker) singleIndex(expr *typeparams.IndexExpr) ast.Expr {
if len(expr.Indices) == 0 {
check.errorf(expr.Orig, InvalidSyntaxTree, "index expression %v with 0 indices", expr)
return nil
}
if len(expr.Indices) > 1 {
// TODO(rFindley) should this get a distinct error code?
check.error(expr.Indices[1], InvalidIndex, invalidOp+"more than one index")
}
return expr.Indices[0]
}
// index checks an index expression for validity.
// If max >= 0, it is the upper bound for index.
// If the result typ is != Typ[Invalid], index is valid and typ is its (possibly named) integer type.
// If the result val >= 0, index is valid and val is its constant int value.
func (check *Checker) index(index ast.Expr, max int64) (typ Type, val int64) {
typ = Typ[Invalid]
val = -1
var x operand
check.expr(&x, index)
if !check.isValidIndex(&x, InvalidIndex, "index", false) {
return
}
if x.mode != constant_ {
return x.typ, -1
}
if x.val.Kind() == constant.Unknown {
return
}
v, ok := constant.Int64Val(x.val)
assert(ok)
if max >= 0 && v >= max {
check.errorf(&x, InvalidIndex, invalidArg+"index %s out of bounds [0:%d]", x.val.String(), max)
return
}
// 0 <= v [ && v < max ]
return x.typ, v
}
func (check *Checker) isValidIndex(x *operand, code Code, what string, allowNegative bool) bool {
if x.mode == invalid {
return false
}
// spec: "a constant index that is untyped is given type int"
check.convertUntyped(x, Typ[Int])
if x.mode == invalid {
return false
}
// spec: "the index x must be of integer type or an untyped constant"
if !allInteger(x.typ) {
check.errorf(x, code, invalidArg+"%s %s must be integer", what, x)
return false
}
if x.mode == constant_ {
// spec: "a constant index must be non-negative ..."
if !allowNegative && constant.Sign(x.val) < 0 {
check.errorf(x, code, invalidArg+"%s %s must not be negative", what, x)
return false
}
// spec: "... and representable by a value of type int"
if !representableConst(x.val, check, Typ[Int], &x.val) {
check.errorf(x, code, invalidArg+"%s %s overflows int", what, x)
return false
}
}
return true
}
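// For instance, on a 64-bit platform:
//
// var s []int
// _ = s[-1]      // ERROR: index -1 must not be negative
// _ = s[1 << 70] // ERROR: index 1 << 70 overflows int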
// indexedElts checks the elements (elts) of an array or slice composite literal
// against the literal's element type (typ), and the element indices against
// the literal length if known (length >= 0). It returns the length of the
// literal (maximum index value + 1).
func (check *Checker) indexedElts(elts []ast.Expr, typ Type, length int64) int64 {
visited := make(map[int64]bool, len(elts))
var index, max int64
for _, e := range elts {
// determine and check index
validIndex := false
eval := e
if kv, _ := e.(*ast.KeyValueExpr); kv != nil {
if typ, i := check.index(kv.Key, length); typ != Typ[Invalid] {
if i >= 0 {
index = i
validIndex = true
} else {
check.errorf(e, InvalidLitIndex, "index %s must be integer constant", kv.Key)
}
}
eval = kv.Value
} else if length >= 0 && index >= length {
check.errorf(e, OversizeArrayLit, "index %d is out of bounds (>= %d)", index, length)
} else {
validIndex = true
}
// if we have a valid index, check for duplicate entries
if validIndex {
if visited[index] {
check.errorf(e, DuplicateLitKey, "duplicate index %d in array or slice literal", index)
}
visited[index] = true
}
index++
if index > max {
max = index
}
// check element against composite literal element type
var x operand
check.exprWithHint(&x, eval, typ)
check.assignment(&x, typ, "array or slice literal")
}
return max
}
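// For instance, in
//
// a := [...]int{5: 1, 2} // length 7: maximum index 6, +1
// b := []int{0: 1, 0: 2} // ERROR: duplicate index 0 in array or slice literal
//
// indexedElts returns 7 for a and reports the duplicate index for b.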
// Code generated by "go test -run=Generate -write=all"; DO NOT EDIT.
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements type parameter inference.
package types
import (
"fmt"
"go/token"
. "internal/types/errors"
"strings"
)
// infer attempts to infer the complete set of type arguments for generic function instantiation/call
// based on the given type parameters tparams, type arguments targs, function parameters params, and
// function arguments args, if any. There must be at least one type parameter, no more type arguments
// than type parameters, and params and args must match in number (incl. zero).
// If successful, infer returns the complete list of given and inferred type arguments, one for each
// type parameter. Otherwise the result is nil and appropriate errors will be reported.
func (check *Checker) infer(posn positioner, tparams []*TypeParam, targs []Type, params *Tuple, args []*operand) (inferred []Type) {
if debug {
defer func() {
assert(inferred == nil || len(inferred) == len(tparams))
for _, targ := range inferred {
assert(targ != nil)
}
}()
}
if traceInference {
check.dump("-- infer %s%s ➞ %s", tparams, params, targs)
defer func() {
check.dump("=> %s ➞ %s\n", tparams, inferred)
}()
}
// There must be at least one type parameter, and no more type arguments than type parameters.
n := len(tparams)
assert(n > 0 && len(targs) <= n)
// Function parameters and arguments must match in number.
assert(params.Len() == len(args))
// If we already have all type arguments, we're done.
if len(targs) == n {
return targs
}
// len(targs) < n
// Rename type parameters to avoid conflicts in recursive instantiation scenarios.
tparams, params = check.renameTParams(posn.Pos(), tparams, params)
if traceInference {
check.dump("after rename: %s%s ➞ %s\n", tparams, params, targs)
}
// Make sure we have a "full" list of type arguments, some of which may
// be nil (unknown). Make a copy so as to not clobber the incoming slice.
if len(targs) < n {
targs2 := make([]Type, n)
copy(targs2, targs)
targs = targs2
}
// len(targs) == n
// Continue with the type arguments we have. Avoid matching generic
// parameters that already have type arguments against function arguments:
// It may fail because matching uses type identity while parameter passing
// uses assignment rules. Instantiate the parameter list with the type
// arguments we have, and continue with that parameter list.
// Substitute type arguments for their respective type parameters in params,
// if any. Note that nil targs entries are ignored by check.subst.
// We do this for better error messages; it's not needed for correctness.
// For instance, given:
//
// func f[P, Q any](P, Q) {}
//
// func _(s string) {
// f[int](s, s) // ERROR
// }
//
// With substitution, we get the error:
// "cannot use s (variable of type string) as int value in argument to f[int]"
//
// Without substitution we get the (worse) error:
// "type string of s does not match inferred type int for P"
// even though the type int was provided (not inferred) for P.
//
// TODO(gri) We might be able to finesse this in the error message reporting
// (which only happens in case of an error) and then avoid doing
// the substitution (which always happens).
if params.Len() > 0 {
smap := makeSubstMap(tparams, targs)
params = check.subst(nopos, params, smap, nil, check.context()).(*Tuple)
}
// Unify parameter and argument types for generic parameters with typed arguments
// and collect the indices of generic parameters with untyped arguments.
// Terminology: generic parameter = function parameter with a type-parameterized type
u := newUnifier(tparams, targs)
errorf := func(kind string, tpar, targ Type, arg *operand) {
// provide a better error message if we can
targs := u.inferred(tparams)
if targs[0] == nil {
// The first type parameter couldn't be inferred.
// If none of them could be inferred, don't try
// to provide the inferred type in the error msg.
allFailed := true
for _, targ := range targs {
if targ != nil {
allFailed = false
break
}
}
if allFailed {
check.errorf(arg, CannotInferTypeArgs, "%s %s of %s does not match %s (cannot infer %s)", kind, targ, arg.expr, tpar, typeParamsString(tparams))
return
}
}
smap := makeSubstMap(tparams, targs)
// TODO(gri): pass a poser here, rather than arg.Pos().
inferred := check.subst(arg.Pos(), tpar, smap, nil, check.context())
// CannotInferTypeArgs indicates a failure of inference, though the actual
// error may be better attributed to a user-provided type argument (hence
// InvalidTypeArg). We can't differentiate these cases, so fall back on
// the more general CannotInferTypeArgs.
if inferred != tpar {
check.errorf(arg, CannotInferTypeArgs, "%s %s of %s does not match inferred type %s for %s", kind, targ, arg.expr, inferred, tpar)
} else {
check.errorf(arg, CannotInferTypeArgs, "%s %s of %s does not match %s", kind, targ, arg.expr, tpar)
}
}
// indices of generic parameters with untyped arguments, for later use
var untyped []int
// --- 1 ---
// use information from function arguments
if traceInference {
u.tracef("parameters: %s", params)
u.tracef("arguments : %s", args)
}
for i, arg := range args {
par := params.At(i)
// If we permit bidirectional unification, this conditional code needs to be
// executed even if par.typ is not parameterized since the argument may be a
// generic function (for which we want to infer its type arguments).
if isParameterized(tparams, par.typ) {
if arg.mode == invalid {
// An error was reported earlier. Ignore this targ
// and continue, we may still be able to infer all
// targs resulting in fewer follow-on errors.
continue
}
if isTyped(arg.typ) {
if !u.unify(par.typ, arg.typ) {
errorf("type", par.typ, arg.typ, arg)
return nil
}
} else if _, ok := par.typ.(*TypeParam); ok {
// Since default types are all basic (i.e., non-composite) types, an
// untyped argument will never match a composite parameter type; the
// only parameter type it can possibly match against is a *TypeParam.
// Thus, for untyped arguments we only need to look at parameter types
// that are single type parameters.
untyped = append(untyped, i)
}
}
}
if traceInference {
inferred := u.inferred(tparams)
u.tracef("=> %s ➞ %s\n", tparams, inferred)
}
// --- 2 ---
// use information from type parameter constraints
if traceInference {
u.tracef("type parameters: %s", tparams)
}
// Unify type parameters with their constraints as long
// as progress is being made.
//
// This is an O(n^2) algorithm where n is the number of
// type parameters: if there is progress, at least one
// type argument is inferred per iteration, and we have
// a doubly nested loop.
//
// In practice this is not a problem because the number
// of type parameters tends to be very small (< 5 or so).
// (It should be possible for unification to efficiently
// signal newly inferred type arguments; then the loops
// here could handle the respective type parameters only,
// but that will come at a cost of extra complexity which
// may not be worth it.)
for {
nn := u.unknowns()
for _, tpar := range tparams {
tx := u.at(tpar)
if traceInference && tx != nil {
u.tracef("%s = %s", tpar, tx)
}
// If there is a core term (i.e., a core type with tilde information)
// unify the type parameter with the core type.
if core, single := coreTerm(tpar); core != nil {
if traceInference {
u.tracef("core(%s) = %s (single = %v)", tpar, core, single)
}
// A type parameter can be unified with its core type in two cases.
switch {
case tx != nil:
// The corresponding type argument tx is known. There are 2 cases:
// 1) If the core type has a tilde, per spec requirement for tilde
// elements, the core type is an underlying (literal) type.
// And because of the tilde, the underlying type of tx must match
// against the core type.
// But because unify automatically matches a defined type against
// an underlying literal type, we can simply unify tx with the
// core type.
// 2) If the core type doesn't have a tilde, we also must unify tx
// with the core type.
if !u.unify(tx, core.typ) {
check.errorf(posn, CannotInferTypeArgs, "%s does not match %s", tpar, core.typ)
return nil
}
case single && !core.tilde:
// The corresponding type argument tx is unknown and there's a single
// specific type and no tilde.
// In this case the type argument must be that single type; set it.
u.set(tpar, core.typ)
}
} else {
if traceInference {
u.tracef("core(%s) = nil", tpar)
}
if tx != nil {
// We don't have a core type, but the type argument tx is known.
// It must have (at least) all the methods of the type constraint,
// and the method signatures must unify; otherwise tx cannot satisfy
// the constraint.
var cause string
constraint := tpar.iface()
if m, _ := check.missingMethod(tx, constraint, true, u.unify, &cause); m != nil {
check.errorf(posn, CannotInferTypeArgs, "%s does not satisfy %s %s", tx, constraint, cause)
return nil
}
}
}
}
if u.unknowns() == nn {
break // no progress
}
}
if traceInference {
inferred := u.inferred(tparams)
u.tracef("=> %s ➞ %s\n", tparams, inferred)
}
// --- 3 ---
// use information from untyped constants
if traceInference {
u.tracef("untyped: %v", untyped)
}
// Some generic parameters with untyped arguments may have been given a type by now.
// Collect all remaining parameters that don't have a type yet and unify them with
// the default types of the untyped arguments.
// We need to collect them all before unifying them with their untyped arguments;
// otherwise a parameter type that appears multiple times will have a type after
// the first unification and will be skipped later on, leading to incorrect results.
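// For instance (illustrative):
//
//	func g[P any](x P) {}
//	g(1) // the untyped constant 1 has default type int, so P = int is inferred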
j := 0
for _, i := range untyped {
tpar := params.At(i).typ.(*TypeParam) // is type parameter by construction of untyped
if u.at(tpar) == nil {
untyped[j] = i
j++
}
}
// untyped[:j] are the indices of parameters without a type yet
for _, i := range untyped[:j] {
tpar := params.At(i).typ.(*TypeParam)
arg := args[i]
typ := Default(arg.typ)
// The default type of an untyped nil is untyped nil, which must
// not be inferred as a type parameter's type. Ignore such arguments
// by making sure all default types are typed.
if isTyped(typ) && !u.unify(tpar, typ) {
errorf("default type", tpar, typ, arg)
return nil
}
}
// --- simplify ---
// u.inferred(tparams) now contains the incoming type arguments plus any additional type
// arguments which were inferred. The inferred non-nil entries may still contain
// references to other type parameters found in constraints.
// For instance, for [A any, B interface{ []C }, C interface{ *A }], if A == int
// was given, unification produced the type list [int, []C, *A]. We eliminate the
// remaining type parameters by substituting the type parameters in this type list
// until nothing changes anymore.
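// For the example above, the substitution proceeds as follows (illustrative):
//
//	start:  [int, []C, *A]
//	step 1: [int, []*A, *int]
//	step 2: [int, []*int, *int] // no type parameters left; done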
inferred = u.inferred(tparams)
if debug {
for i, targ := range targs {
assert(targ == nil || inferred[i] == targ)
}
}
// The data structure of each (provided or inferred) type represents a graph, where
// each node corresponds to a type and each (directed) edge points to a component
// type. The substitution process described above repeatedly replaces type parameter
// nodes in these graphs with the graphs of the types the type parameters stand for,
// which creates a new (possibly bigger) graph for each type.
// The substitution process will not stop if the replacement graph for a type parameter
// also contains that type parameter.
// For instance, for [A interface{ *A }], without any type argument provided for A,
// unification produces the type list [*A]. Substituting A in *A with the value for
// A will lead to infinite expansion by producing [**A], [****A], [********A], etc.,
// because the graph A -> *A has a cycle through A.
// Generally, cycles may occur across multiple type parameters and inferred types
// (for instance, consider [P interface{ *Q }, Q interface{ func(P) }]).
// We eliminate cycles by walking the graphs for all type parameters. If a cycle
// through a type parameter is detected, cycleFinder nils out the respective type
// which kills the cycle; this also means that the respective type could not be
// inferred.
//
// TODO(gri) If useful, we could report the respective cycle as an error. We don't
// do this now because type inference will fail anyway, and furthermore,
// constraints with cycles of this kind cannot currently be satisfied by
// any user-supplied type. But should that change, reporting an error
// would be wrong.
w := cycleFinder{tparams, inferred, make(map[Type]bool)}
for _, t := range tparams {
w.typ(t) // t != nil
}
// dirty tracks the indices of all types that may still contain type parameters.
// We know that nil type entries and entries corresponding to provided (non-nil)
// type arguments are clean, so exclude them from the start.
var dirty []int
for i, typ := range inferred {
if typ != nil && (i >= len(targs) || targs[i] == nil) {
dirty = append(dirty, i)
}
}
for len(dirty) > 0 {
// TODO(gri) Instead of creating a new substMap for each iteration,
// provide an update operation for substMaps and only change when
// needed. Optimization.
smap := makeSubstMap(tparams, inferred)
n := 0
for _, index := range dirty {
t0 := inferred[index]
if t1 := check.subst(nopos, t0, smap, nil, check.context()); t1 != t0 {
inferred[index] = t1
dirty[n] = index
n++
}
}
dirty = dirty[:n]
}
// Once nothing changes anymore, we may still have type parameters left;
// e.g., a constraint with core type *P may match a type parameter Q but
// we don't have any type arguments to fill in for *P or Q (go.dev/issue/45548).
// Don't let such inferences escape; instead treat them as unresolved.
for i, typ := range inferred {
if typ == nil || isParameterized(tparams, typ) {
obj := tparams[i].obj
check.errorf(posn, CannotInferTypeArgs, "cannot infer %s (%s)", obj.name, obj.pos)
return nil
}
}
return
}
// renameTParams renames the type parameters in a function signature described by its
// type and ordinary parameters (tparams and params) such that each type parameter is
// given a new identity. renameTParams returns the new type and ordinary parameters.
func (check *Checker) renameTParams(pos token.Pos, tparams []*TypeParam, params *Tuple) ([]*TypeParam, *Tuple) {
// For the purpose of type inference we must differentiate type parameters
// occurring in explicit type or value function arguments from the type
// parameters we are solving for via unification because they may be the
// same in self-recursive calls:
//
// func f[P constraint](x P) {
// f(x)
// }
//
// In this example, without type parameter renaming, the P used in the
// instantiation f[P] has the same pointer identity as the P we are trying
// to solve for through type inference. This causes problems for type
// unification. Because any such self-recursive call is equivalent to
// a mutually recursive call, type parameter renaming can be used to
// create separate, disentangled type parameters. The above example
// can be rewritten into the following equivalent code:
//
// func f[P constraint](x P) {
// f2(x)
// }
//
// func f2[P2 constraint](x P2) {
// f(x)
// }
//
// Type parameter renaming turns the first example into the second
// example by renaming the type parameter P into P2.
tparams2 := make([]*TypeParam, len(tparams))
for i, tparam := range tparams {
tname := NewTypeName(tparam.Obj().Pos(), tparam.Obj().Pkg(), tparam.Obj().Name(), nil)
tparams2[i] = NewTypeParam(tname, nil)
tparams2[i].index = tparam.index // == i
}
renameMap := makeRenameMap(tparams, tparams2)
for i, tparam := range tparams {
tparams2[i].bound = check.subst(pos, tparam.bound, renameMap, nil, check.context())
}
return tparams2, check.subst(pos, params, renameMap, nil, check.context()).(*Tuple)
}
// typeParamsString produces a string containing all the type parameter names
// in list suitable for human consumption.
func typeParamsString(list []*TypeParam) string {
// common cases
n := len(list)
switch n {
case 0:
return ""
case 1:
return list[0].obj.name
case 2:
return list[0].obj.name + " and " + list[1].obj.name
}
// general case (n > 2)
var buf strings.Builder
for i, tname := range list[:n-1] {
if i > 0 {
buf.WriteString(", ")
}
buf.WriteString(tname.obj.name)
}
buf.WriteString(", and ")
buf.WriteString(list[n-1].obj.name)
return buf.String()
}
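// Illustrative outputs for type parameters named P, Q, R (hypothetical):
//
//	[]      => ""
//	[P]     => "P"
//	[P Q]   => "P and Q"
//	[P Q R] => "P, Q, and R"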
// isParameterized reports whether typ contains any of the type parameters of tparams.
func isParameterized(tparams []*TypeParam, typ Type) bool {
w := tpWalker{
seen: make(map[Type]bool),
tparams: tparams,
}
return w.isParameterized(typ)
}
type tpWalker struct {
seen map[Type]bool
tparams []*TypeParam
}
func (w *tpWalker) isParameterized(typ Type) (res bool) {
// detect cycles
if x, ok := w.seen[typ]; ok {
return x
}
w.seen[typ] = false
defer func() {
w.seen[typ] = res
}()
switch t := typ.(type) {
case nil, *Basic: // TODO(gri) should nil be handled here?
break
case *Array:
return w.isParameterized(t.elem)
case *Slice:
return w.isParameterized(t.elem)
case *Struct:
for _, fld := range t.fields {
if w.isParameterized(fld.typ) {
return true
}
}
case *Pointer:
return w.isParameterized(t.base)
case *Tuple:
n := t.Len()
for i := 0; i < n; i++ {
if w.isParameterized(t.At(i).typ) {
return true
}
}
case *Signature:
// t.tparams may not be nil if we are looking at a signature
// of a generic function type (or an interface method) that is
// part of the type we're testing. We don't care about these type
// parameters.
// Similarly, the receiver of a method may declare (rather than
// use) type parameters; we don't care about those either.
// Thus, we only need to look at the input and result parameters.
return w.isParameterized(t.params) || w.isParameterized(t.results)
case *Interface:
tset := t.typeSet()
for _, m := range tset.methods {
if w.isParameterized(m.typ) {
return true
}
}
return tset.is(func(t *term) bool {
return t != nil && w.isParameterized(t.typ)
})
case *Map:
return w.isParameterized(t.key) || w.isParameterized(t.elem)
case *Chan:
return w.isParameterized(t.elem)
case *Named:
return w.isParameterizedTypeList(t.TypeArgs().list())
case *TypeParam:
// t must be one of w.tparams
return tparamIndex(w.tparams, t) >= 0
default:
unreachable()
}
return false
}
func (w *tpWalker) isParameterizedTypeList(list []Type) bool {
for _, t := range list {
if w.isParameterized(t) {
return true
}
}
return false
}
// If the type parameter has a single specific type S, coreTerm returns (S, true).
// Otherwise, if tpar has a core type T, it returns a term corresponding to that
// core type and false. In that case, if any term of tpar has a tilde, the core
// term has a tilde. In all other cases coreTerm returns (nil, false).
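//
// For illustration (a sketch, not exhaustive):
//
//	[P interface{ int }]            // coreTerm returns (int, true)
//	[P interface{ ~int }]           // coreTerm returns (~int, true)
//	[P interface{ ~int | ~string }] // coreTerm returns (nil, false): no core type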
func coreTerm(tpar *TypeParam) (*term, bool) {
n := 0
var single *term // valid if n == 1
var tilde bool
tpar.is(func(t *term) bool {
if t == nil {
assert(n == 0)
return false // no terms
}
n++
single = t
if t.tilde {
tilde = true
}
return true
})
if n == 1 {
if debug {
assert(under(single.typ) == coreType(tpar))
}
return single, true
}
if typ := coreType(tpar); typ != nil {
// A core type is always an underlying type.
// If any term of tpar has a tilde, we don't
// have a precise core type and we must return
// a tilde as well.
return &term{tilde, typ}, false
}
return nil, false
}
type cycleFinder struct {
tparams []*TypeParam
types []Type
seen map[Type]bool
}
func (w *cycleFinder) typ(typ Type) {
if w.seen[typ] {
// We have seen typ before. If it is one of the type parameters
// in tparams, iterative substitution will lead to infinite expansion.
// Nil out the corresponding type which effectively kills the cycle.
if tpar, _ := typ.(*TypeParam); tpar != nil {
if i := tparamIndex(w.tparams, tpar); i >= 0 {
// cycle through tpar
w.types[i] = nil
}
}
// If we don't have one of our type parameters, the cycle is due
// to an ordinary recursive type and we can just stop walking it.
return
}
w.seen[typ] = true
defer delete(w.seen, typ)
switch t := typ.(type) {
case *Basic:
// nothing to do
case *Array:
w.typ(t.elem)
case *Slice:
w.typ(t.elem)
case *Struct:
w.varList(t.fields)
case *Pointer:
w.typ(t.base)
// case *Tuple:
// This case should not occur because tuples only appear
// in signatures where they are handled explicitly.
case *Signature:
if t.params != nil {
w.varList(t.params.vars)
}
if t.results != nil {
w.varList(t.results.vars)
}
case *Union:
for _, t := range t.terms {
w.typ(t.typ)
}
case *Interface:
for _, m := range t.methods {
w.typ(m.typ)
}
for _, t := range t.embeddeds {
w.typ(t)
}
case *Map:
w.typ(t.key)
w.typ(t.elem)
case *Chan:
w.typ(t.elem)
case *Named:
for _, tpar := range t.TypeArgs().list() {
w.typ(tpar)
}
case *TypeParam:
if i := tparamIndex(w.tparams, t); i >= 0 && w.types[i] != nil {
w.typ(w.types[i])
}
default:
panic(fmt.Sprintf("unexpected %T", typ))
}
}
func (w *cycleFinder) varList(list []*Var) {
for _, v := range list {
w.typ(v.typ)
}
}
// If tpar is a type parameter in list, tparamIndex returns the type parameter index.
// Otherwise, the result is < 0. tpar must not be nil.
func tparamIndex(list []*TypeParam, tpar *TypeParam) int {
// Once a type parameter is bound its index is >= 0. However, there are some
// code paths (namely tracing and type hashing) by which it is possible to
// arrive here with a type parameter that has not been bound, hence the check
// for 0 <= i below.
// TODO(rfindley): investigate a better approach for guarding against using
// unbound type parameters.
if i := tpar.index; 0 <= i && i < len(list) && list[i] == tpar {
return i
}
return -1
}
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package types
import (
"container/heap"
"fmt"
. "internal/types/errors"
"sort"
)
// initOrder computes the Info.InitOrder for package variables.
func (check *Checker) initOrder() {
// An InitOrder may already have been computed if a package is
// built from several calls to (*Checker).Files. Clear it.
check.Info.InitOrder = check.Info.InitOrder[:0]
// Compute the object dependency graph and initialize
// a priority queue with the list of graph nodes.
pq := nodeQueue(dependencyGraph(check.objMap))
heap.Init(&pq)
const debug = false
if debug {
fmt.Printf("Computing initialization order for %s\n\n", check.pkg)
fmt.Println("Object dependency graph:")
for obj, d := range check.objMap {
// only print objects that may appear in the dependency graph
if obj, _ := obj.(dependency); obj != nil {
if len(d.deps) > 0 {
fmt.Printf("\t%s depends on\n", obj.Name())
for dep := range d.deps {
fmt.Printf("\t\t%s\n", dep.Name())
}
} else {
fmt.Printf("\t%s has no dependencies\n", obj.Name())
}
}
}
fmt.Println()
fmt.Println("Transposed object dependency graph (functions eliminated):")
for _, n := range pq {
fmt.Printf("\t%s depends on %d nodes\n", n.obj.Name(), n.ndeps)
for p := range n.pred {
fmt.Printf("\t\t%s is dependent\n", p.obj.Name())
}
}
fmt.Println()
fmt.Println("Processing nodes:")
}
// Determine initialization order by removing the highest priority node
// (the one with the fewest dependencies) and its edges from the graph,
// repeatedly, until there are no nodes left.
// In a valid Go program, those nodes always have zero dependencies (after
// removing all incoming dependencies), otherwise there are initialization
// cycles.
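// For instance (illustrative), given
//
//	var c = b + 1
//	var a = 1
//	var b = a + 1
//
// the computed InitOrder is: a = 1, b = a + 1, c = b + 1.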
emitted := make(map[*declInfo]bool)
for len(pq) > 0 {
// get the next node
n := heap.Pop(&pq).(*graphNode)
if debug {
fmt.Printf("\t%s (src pos %d) depends on %d nodes now\n",
n.obj.Name(), n.obj.order(), n.ndeps)
}
// if n still depends on other nodes, we have a cycle
if n.ndeps > 0 {
cycle := findPath(check.objMap, n.obj, n.obj, make(map[Object]bool))
// If n.obj is not part of the cycle (e.g., n.obj->b->c->d->c),
// cycle will be nil. Don't report anything in that case since
// the cycle is reported when the algorithm gets to an object
// in the cycle.
// Furthermore, once an object in the cycle is encountered,
// the cycle will be broken (dependency count will be reduced
// below), and so the remaining nodes in the cycle don't trigger
// another error (unless they are part of multiple cycles).
if cycle != nil {
check.reportCycle(cycle)
}
// Ok to continue, but the variable initialization order
// will be incorrect at this point since it assumes no
// cycle errors.
}
// reduce dependency count of all dependent nodes
// and update priority queue
for p := range n.pred {
p.ndeps--
heap.Fix(&pq, p.index)
}
// record the init order for variables with initializers only
v, _ := n.obj.(*Var)
info := check.objMap[v]
if v == nil || !info.hasInitializer() {
continue
}
// n:1 variable declarations such as: a, b = f()
// introduce a node for each lhs variable (here: a, b);
// but they all have the same initializer - emit only
// one, for the first variable seen
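// E.g. (illustrative): for "var a, b = f()" both a and b map to the same
// declInfo, so a single Initializer{[a b], f()} is appended below.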
if emitted[info] {
continue // initializer already emitted, if any
}
emitted[info] = true
infoLhs := info.lhs // possibly nil (see declInfo.lhs field comment)
if infoLhs == nil {
infoLhs = []*Var{v}
}
init := &Initializer{infoLhs, info.init}
check.Info.InitOrder = append(check.Info.InitOrder, init)
}
if debug {
fmt.Println()
fmt.Println("Initialization order:")
for _, init := range check.Info.InitOrder {
fmt.Printf("\t%s\n", init)
}
fmt.Println()
}
}
// findPath returns the (reversed) list of objects []Object{to, ... from}
// such that there is a path of object dependencies from 'from' to 'to'.
// If there is no such path, the result is nil.
func findPath(objMap map[Object]*declInfo, from, to Object, seen map[Object]bool) []Object {
if seen[from] {
return nil
}
seen[from] = true
for d := range objMap[from].deps {
if d == to {
return []Object{d}
}
if P := findPath(objMap, d, to, seen); P != nil {
return append(P, d)
}
}
return nil
}
// reportCycle reports an error for the given cycle.
func (check *Checker) reportCycle(cycle []Object) {
obj := cycle[0]
// report a more concise error for self references
if len(cycle) == 1 {
check.errorf(obj, InvalidInitCycle, "initialization cycle: %s refers to itself", obj.Name())
return
}
check.errorf(obj, InvalidInitCycle, "initialization cycle for %s", obj.Name())
// subtle loop: print cycle[i] for i = 0, n-1, n-2, ... 1 for len(cycle) = n
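// E.g. (illustrative), for cycle = [a, b, c] (n = 3) this prints:
//
//	a refers to
//	c refers to
//	b refers to
//	a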
for i := len(cycle) - 1; i >= 0; i-- {
check.errorf(obj, InvalidInitCycle, "\t%s refers to", obj.Name()) // secondary error, \t indented
obj = cycle[i]
}
// print cycle[0] again to close the cycle
check.errorf(obj, InvalidInitCycle, "\t%s", obj.Name())
}
// ----------------------------------------------------------------------------
// Object dependency graph
// A dependency is an object that may be a dependency in an initialization
// expression. Only constants, variables, and functions can be dependencies.
// Constants are here because constant expression cycles are reported during
// initialization order computation.
type dependency interface {
Object
isDependency()
}
// A graphNode represents a node in the object dependency graph.
// Each node p in n.pred represents an edge p->n, and each node
// s in n.succ represents an edge n->s; with a->b indicating that
// a depends on b.
type graphNode struct {
obj dependency // object represented by this node
pred, succ nodeSet // consumers and dependencies of this node (lazily initialized)
index int // node index in graph slice/priority queue
ndeps int // number of outstanding dependencies before this object can be initialized
}
// cost returns the cost of removing this node, which involves copying each
// predecessor to each successor (and vice-versa).
func (n *graphNode) cost() int {
return len(n.pred) * len(n.succ)
}
type nodeSet map[*graphNode]bool
func (s *nodeSet) add(p *graphNode) {
if *s == nil {
*s = make(nodeSet)
}
(*s)[p] = true
}
// dependencyGraph computes the object dependency graph from the given objMap,
// with any function nodes removed. The resulting graph contains only constants
// and variables.
func dependencyGraph(objMap map[Object]*declInfo) []*graphNode {
// M is the dependency (Object) -> graphNode mapping
M := make(map[dependency]*graphNode)
for obj := range objMap {
// only consider nodes that may be an initialization dependency
if obj, _ := obj.(dependency); obj != nil {
M[obj] = &graphNode{obj: obj}
}
}
// compute edges for graph M
// (We need to include all nodes, even isolated ones, because they still need
// to be scheduled for initialization in correct order relative to other nodes.)
for obj, n := range M {
// for each dependency obj -> d (= deps[i]), add d to n's successors and n to d's predecessors
for d := range objMap[obj].deps {
// only consider nodes that may be an initialization dependency
if d, _ := d.(dependency); d != nil {
d := M[d]
n.succ.add(d)
d.pred.add(n)
}
}
}
var G, funcG []*graphNode // separate non-functions and functions
for _, n := range M {
if _, ok := n.obj.(*Func); ok {
funcG = append(funcG, n)
} else {
G = append(G, n)
}
}
// remove function nodes and collect remaining graph nodes in G
// (Mutually recursive functions may introduce cycles among themselves
// which are permitted. Yet such cycles may incorrectly inflate the dependency
// count for variables which in turn may not get scheduled for initialization
// in correct order.)
//
// Note that because we recursively copy predecessors and successors
// throughout the function graph, the cost of removing a function at
// position X is proportional to cost * (len(funcG)-X). Therefore, we should
// remove high-cost functions last.
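// For instance (illustrative):
//
//	var a = f()               // a -> f
//	func f() int { return b } // f -> b
//	var b = 1
//
// Eliminating the function node f yields the direct edge a -> b.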
sort.Slice(funcG, func(i, j int) bool {
return funcG[i].cost() < funcG[j].cost()
})
for _, n := range funcG {
// connect each predecessor p of n with each successor s
// and drop the function node (don't collect it in G)
for p := range n.pred {
// ignore self-cycles
if p != n {
// Each successor s of n becomes a successor of p, and
// each predecessor p of n becomes a predecessor of s.
for s := range n.succ {
// ignore self-cycles
if s != n {
p.succ.add(s)
s.pred.add(p)
}
}
delete(p.succ, n) // remove edge to n
}
}
for s := range n.succ {
delete(s.pred, n) // remove edge to n
}
}
// fill in index and ndeps fields
for i, n := range G {
n.index = i
n.ndeps = len(n.succ)
}
return G
}
// ----------------------------------------------------------------------------
// Priority queue
// nodeQueue implements the container/heap interface;
// a nodeQueue may be used as a priority queue.
type nodeQueue []*graphNode
func (a nodeQueue) Len() int { return len(a) }
func (a nodeQueue) Swap(i, j int) {
x, y := a[i], a[j]
a[i], a[j] = y, x
x.index, y.index = j, i
}
func (a nodeQueue) Less(i, j int) bool {
x, y := a[i], a[j]
// nodes are prioritized by number of incoming dependencies (1st key)
// and source order (2nd key)
return x.ndeps < y.ndeps || x.ndeps == y.ndeps && x.obj.order() < y.obj.order()
}
func (a *nodeQueue) Push(x any) {
panic("unreachable")
}
func (a *nodeQueue) Pop() any {
n := len(*a)
x := (*a)[n-1]
x.index = -1 // for safety
*a = (*a)[:n-1]
return x
}
// Code generated by "go test -run=Generate -write=all"; DO NOT EDIT.
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements instantiation of generic types
// through substitution of type parameters by type arguments.
package types
import (
"errors"
"fmt"
"go/token"
. "internal/types/errors"
)
// Instantiate instantiates the type orig with the given type arguments targs.
// orig must be a *Named or a *Signature type. If there is no error, the
// resulting Type is an instantiated type of the same kind (either a *Named or
// a *Signature). Methods attached to a *Named type are also instantiated, and
// associated with a new *Func that has the same position as the original
// method, but nil function scope.
//
// If ctxt is non-nil, it may be used to de-duplicate the instance against
// previous instances with the same identity. As a special case, generic
// *Signature origin types are only considered identical if they are pointer
// equivalent, so that instantiating distinct (but possibly identical)
// signatures will yield different instances. The use of a shared context does
// not guarantee that identical instances are deduplicated in all cases.
//
// If validate is set, Instantiate verifies that the number of type arguments
// and parameters match, and that the type arguments satisfy their
// corresponding type constraints. If verification fails, the resulting error
// may wrap an *ArgumentError indicating which type argument did not satisfy
// its corresponding type parameter constraint, and why.
//
// If validate is not set, Instantiate does not verify the type argument count
// or whether the type arguments satisfy their constraints. Instantiate is
// guaranteed to not return an error, but may panic. Specifically, for
// *Signature types, Instantiate will panic immediately if the type argument
// count is incorrect; for *Named types, a panic may occur later inside the
// *Named API.
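//
// Illustrative use (orig is assumed to be a generic type with one type parameter):
//
//	inst, err := Instantiate(NewContext(), orig, []Type{Typ[Int]}, true)
//	if err != nil {
//		var argErr *ArgumentError
//		if errors.As(err, &argErr) {
//			// argErr.Index identifies the offending type argument
//		}
//	}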
func Instantiate(ctxt *Context, orig Type, targs []Type, validate bool) (Type, error) {
if ctxt == nil {
ctxt = NewContext()
}
if validate {
var tparams []*TypeParam
switch t := orig.(type) {
case *Named:
tparams = t.TypeParams().list()
case *Signature:
tparams = t.TypeParams().list()
}
if len(targs) != len(tparams) {
return nil, fmt.Errorf("got %d type arguments but %s has %d type parameters", len(targs), orig, len(tparams))
}
if i, err := (*Checker)(nil).verify(nopos, tparams, targs, ctxt); err != nil {
return nil, &ArgumentError{i, err}
}
}
inst := (*Checker)(nil).instance(nopos, orig, targs, nil, ctxt)
return inst, nil
}
// instance instantiates the given original (generic) function or type with the
// provided type arguments and returns the resulting instance. If an identical
// instance exists already in the given contexts, it returns that instance,
// otherwise it creates a new one.
//
// If expanding is non-nil, it is the Named instance type currently being
// expanded. If ctxt is non-nil, it is the context associated with the current
// type-checking pass or call to Instantiate. At least one of expanding or ctxt
// must be non-nil.
//
// For Named types the resulting instance may be unexpanded.
func (check *Checker) instance(pos token.Pos, orig Type, targs []Type, expanding *Named, ctxt *Context) (res Type) {
// The order of the contexts below matters: we always prefer instances in the
// expanding instance context in order to preserve reference cycles.
//
// Invariant: if expanding != nil, the returned instance will be the instance
// recorded in expanding.inst.ctxt.
var ctxts []*Context
if expanding != nil {
ctxts = append(ctxts, expanding.inst.ctxt)
}
if ctxt != nil {
ctxts = append(ctxts, ctxt)
}
assert(len(ctxts) > 0)
// Compute all hashes; hashes may differ across contexts due to different
// unique IDs for Named types within the hasher.
hashes := make([]string, len(ctxts))
for i, ctxt := range ctxts {
hashes[i] = ctxt.instanceHash(orig, targs)
}
// updateContexts records res in all contexts and returns the type
// recorded in the first (preferred) context.
updateContexts := func(res Type) Type {
for i := len(ctxts) - 1; i >= 0; i-- {
res = ctxts[i].update(hashes[i], orig, targs, res)
}
return res
}
// typ may already have been instantiated with identical type arguments. In
// that case, re-use the existing instance.
for i, ctxt := range ctxts {
if inst := ctxt.lookup(hashes[i], orig, targs); inst != nil {
return updateContexts(inst)
}
}
switch orig := orig.(type) {
case *Named:
res = check.newNamedInstance(pos, orig, targs, expanding) // substituted lazily
case *Signature:
assert(expanding == nil) // function instances cannot be reached from Named types
tparams := orig.TypeParams()
if !check.validateTArgLen(pos, tparams.Len(), len(targs)) {
return Typ[Invalid]
}
if tparams.Len() == 0 {
return orig // nothing to do (minor optimization)
}
sig := check.subst(pos, orig, makeSubstMap(tparams.list(), targs), nil, ctxt).(*Signature)
// If the signature doesn't use its type parameters, subst
// will not make a copy. In that case, make a copy now (so
// we can set tparams to nil w/o causing side-effects).
if sig == orig {
copy := *sig
sig = &copy
}
// After instantiating a generic signature, it is not generic
// anymore; we need to set tparams to nil.
sig.tparams = nil
res = sig
default:
// only types and functions can be generic
panic(fmt.Sprintf("%v: cannot instantiate %v", pos, orig))
}
// Update all contexts; it's possible that we've lost a race.
return updateContexts(res)
}
// validateTArgLen verifies that the length of targs and tparams matches,
// reporting an error if not. If validation fails and check is nil,
// validateTArgLen panics.
func (check *Checker) validateTArgLen(pos token.Pos, ntparams, ntargs int) bool {
if ntargs != ntparams {
// TODO(gri) provide better error message
if check != nil {
check.errorf(atPos(pos), WrongTypeArgCount, "got %d arguments but %d type parameters", ntargs, ntparams)
return false
}
panic(fmt.Sprintf("%v: got %d arguments but %d type parameters", pos, ntargs, ntparams))
}
return true
}
func (check *Checker) verify(pos token.Pos, tparams []*TypeParam, targs []Type, ctxt *Context) (int, error) {
smap := makeSubstMap(tparams, targs)
for i, tpar := range tparams {
// Ensure that we have a (possibly implicit) interface as type bound (go.dev/issue/51048).
tpar.iface()
// The type parameter bound is parameterized with the same type parameters
// as the instantiated type; before we can use it for bounds checking we
// need to instantiate it with the type arguments with which we instantiated
// the parameterized type.
bound := check.subst(pos, tpar.bound, smap, nil, ctxt)
var cause string
if !check.implements(targs[i], bound, true, &cause) {
return i, errors.New(cause)
}
}
return -1, nil
}
// implements checks if V implements T. The receiver may be nil if implements
// is called through an exported API call such as AssignableTo. If constraint
// is set, T is a type constraint.
//
// If the provided cause is non-nil, it may be set to an error string
// explaining why V does not implement (or satisfy, for constraints) T.
func (check *Checker) implements(V, T Type, constraint bool, cause *string) bool {
Vu := under(V)
Tu := under(T)
if Vu == Typ[Invalid] || Tu == Typ[Invalid] {
return true // avoid follow-on errors
}
if p, _ := Vu.(*Pointer); p != nil && under(p.base) == Typ[Invalid] {
return true // avoid follow-on errors (see go.dev/issue/49541 for an example)
}
verb := "implement"
if constraint {
verb = "satisfy"
}
Ti, _ := Tu.(*Interface)
if Ti == nil {
if cause != nil {
var detail string
if isInterfacePtr(Tu) {
detail = check.sprintf("type %s is pointer to interface, not interface", T)
} else {
detail = check.sprintf("%s is not an interface", T)
}
*cause = check.sprintf("%s does not %s %s (%s)", V, verb, T, detail)
}
return false
}
// Every type satisfies the empty interface.
if Ti.Empty() {
return true
}
// T is not the empty interface (i.e., the type set of T is restricted)
// An interface V with an empty type set satisfies any interface.
// (The empty set is a subset of any set.)
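// E.g. (illustrative): a type parameter P with constraint interface{ int; string }
// has an empty type set (no type is both int and string), so P satisfies any T.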
Vi, _ := Vu.(*Interface)
if Vi != nil && Vi.typeSet().IsEmpty() {
return true
}
// type set of V is not empty
// No type with non-empty type set satisfies the empty type set.
if Ti.typeSet().IsEmpty() {
if cause != nil {
*cause = check.sprintf("cannot %s %s (empty type set)", verb, T)
}
return false
}
// V must implement T's methods, if any.
if m, _ := check.missingMethod(V, T, true, Identical, cause); m != nil /* !Implements(V, T) */ {
if cause != nil {
*cause = check.sprintf("%s does not %s %s %s", V, verb, T, *cause)
}
return false
}
// Only check comparability if we don't have a more specific error.
checkComparability := func() bool {
if !Ti.IsComparable() {
return true
}
// If T is comparable, V must be comparable.
// If V is strictly comparable, we're done.
if comparable(V, false /* strict comparability */, nil, nil) {
return true
}
// For constraint satisfaction, use dynamic (spec) comparability
// so that ordinary, non-type parameter interfaces implement comparable.
if constraint && comparable(V, true /* spec comparability */, nil, nil) {
// V is comparable if we are at Go 1.20 or higher.
if check == nil || check.allowVersion(check.pkg, 1, 20) {
return true
}
if cause != nil {
*cause = check.sprintf("%s to %s comparable requires go1.20 or later", V, verb)
}
return false
}
if cause != nil {
*cause = check.sprintf("%s does not %s comparable", V, verb)
}
return false
}
// V must also be in the set of types of T, if any.
// Constraints with empty type sets were already excluded above.
if !Ti.typeSet().hasTerms() {
return checkComparability() // nothing to do
}
// If V is itself an interface, each of its possible types must be in the set
// of T types (i.e., the V type set must be a subset of the T type set).
// Interfaces V with empty type sets were already excluded above.
if Vi != nil {
if !Vi.typeSet().subsetOf(Ti.typeSet()) {
// TODO(gri) report which type is missing
if cause != nil {
*cause = check.sprintf("%s does not %s %s", V, verb, T)
}
return false
}
return checkComparability()
}
// Otherwise, V's type must be included in the iface type set.
var alt Type
if Ti.typeSet().is(func(t *term) bool {
if !t.includes(V) {
// If V ∉ t.typ but V ∈ ~t.typ then remember this type
// so we can suggest it as an alternative in the error
// message.
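// (E.g., with T = interface{ []int } and V declared as "type S []int",
// V is not in the type set, but it would be with ~[]int; the error
// then suggests "possibly missing ~ for []int".)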
if alt == nil && !t.tilde && Identical(t.typ, under(t.typ)) {
tt := *t
tt.tilde = true
if tt.includes(V) {
alt = t.typ
}
}
return true
}
return false
}) {
if cause != nil {
var detail string
switch {
case alt != nil:
detail = check.sprintf("possibly missing ~ for %s in %s", alt, T)
case mentions(Ti, V):
detail = check.sprintf("%s mentions %s, but %s is not in the type set of %s", T, V, V, T)
default:
detail = check.sprintf("%s missing in %s", V, Ti.typeSet().terms)
}
*cause = check.sprintf("%s does not %s %s (%s)", V, verb, T, detail)
}
return false
}
return checkComparability()
}
// mentions reports whether type T "mentions" typ in an (embedded) element or term
// of T (whether typ is in the type set of T or not). For better error messages.
func mentions(T, typ Type) bool {
switch T := T.(type) {
case *Interface:
for _, e := range T.embeddeds {
if mentions(e, typ) {
return true
}
}
case *Union:
for _, t := range T.terms {
if mentions(t.typ, typ) {
return true
}
}
default:
if Identical(T, typ) {
return true
}
}
return false
}
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package types
import (
"go/ast"
"go/token"
. "internal/types/errors"
)
// ----------------------------------------------------------------------------
// API
// An Interface represents an interface type.
type Interface struct {
check *Checker // for error reporting; nil once type set is computed
methods []*Func // ordered list of explicitly declared methods
embeddeds []Type // ordered list of explicitly embedded elements
embedPos *[]token.Pos // positions of embedded elements; or nil (for error messages) - use pointer to save space
implicit bool // interface is wrapper for type set literal (non-interface T, ~T, or A|B)
complete bool // indicates that obj, methods, and embeddeds are set and type set can be computed
tset *_TypeSet // type set described by this interface, computed lazily
}
// typeSet returns the type set for interface t.
func (t *Interface) typeSet() *_TypeSet { return computeInterfaceTypeSet(t.check, nopos, t) }
// emptyInterface represents the empty (completed) interface
var emptyInterface = Interface{complete: true, tset: &topTypeSet}
// NewInterface returns a new interface for the given methods and embedded types.
// NewInterface takes ownership of the provided methods and may modify their types
// by setting missing receivers.
//
// Deprecated: Use NewInterfaceType instead which allows arbitrary embedded types.
func NewInterface(methods []*Func, embeddeds []*Named) *Interface {
tnames := make([]Type, len(embeddeds))
for i, t := range embeddeds {
tnames[i] = t
}
return NewInterfaceType(methods, tnames)
}
// NewInterfaceType returns a new interface for the given methods and embedded
// types. NewInterfaceType takes ownership of the provided methods and may
// modify their types by setting missing receivers.
//
// To avoid race conditions, the interface's type set should be computed before
// concurrent use of the interface, by explicitly calling Complete.
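//
// Illustrative use (a sketch): construct an interface equivalent to
// the constraint interface{ ~int | ~string }:
//
//	u := NewUnion([]*Term{NewTerm(true, Typ[Int]), NewTerm(true, Typ[String])})
//	iface := NewInterfaceType(nil, []Type{u}).Complete()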
func NewInterfaceType(methods []*Func, embeddeds []Type) *Interface {
if len(methods) == 0 && len(embeddeds) == 0 {
return &emptyInterface
}
// set method receivers if necessary
typ := (*Checker)(nil).newInterface()
for _, m := range methods {
if sig := m.typ.(*Signature); sig.recv == nil {
sig.recv = NewVar(m.pos, m.pkg, "", typ)
}
}
// sort for API stability
sortMethods(methods)
typ.methods = methods
typ.embeddeds = embeddeds
typ.complete = true
return typ
}
// check may be nil
func (check *Checker) newInterface() *Interface {
typ := &Interface{check: check}
if check != nil {
check.needsCleanup(typ)
}
return typ
}
// MarkImplicit marks the interface t as implicit, meaning this interface
// corresponds to a constraint literal such as ~T or A|B without explicit
// interface embedding. MarkImplicit should be called before any concurrent use
// of implicit interfaces.
func (t *Interface) MarkImplicit() {
t.implicit = true
}
// NumExplicitMethods returns the number of explicitly declared methods of interface t.
func (t *Interface) NumExplicitMethods() int { return len(t.methods) }
// ExplicitMethod returns the i'th explicitly declared method of interface t for 0 <= i < t.NumExplicitMethods().
// The methods are ordered by their unique Id.
func (t *Interface) ExplicitMethod(i int) *Func { return t.methods[i] }
// NumEmbeddeds returns the number of embedded types in interface t.
func (t *Interface) NumEmbeddeds() int { return len(t.embeddeds) }
// Embedded returns the i'th embedded defined (*Named) type of interface t for 0 <= i < t.NumEmbeddeds().
// The result is nil if the i'th embedded type is not a defined type.
//
// Deprecated: Use EmbeddedType which is not restricted to defined (*Named) types.
func (t *Interface) Embedded(i int) *Named { tname, _ := t.embeddeds[i].(*Named); return tname }
// EmbeddedType returns the i'th embedded type of interface t for 0 <= i < t.NumEmbeddeds().
func (t *Interface) EmbeddedType(i int) Type { return t.embeddeds[i] }
// NumMethods returns the total number of methods of interface t.
func (t *Interface) NumMethods() int { return t.typeSet().NumMethods() }
// Method returns the i'th method of interface t for 0 <= i < t.NumMethods().
// The methods are ordered by their unique Id.
func (t *Interface) Method(i int) *Func { return t.typeSet().Method(i) }
// Empty reports whether t is the empty interface.
func (t *Interface) Empty() bool { return t.typeSet().IsAll() }
// IsComparable reports whether each type in interface t's type set is comparable.
func (t *Interface) IsComparable() bool { return t.typeSet().IsComparable(nil) }
// IsMethodSet reports whether the interface t is fully described by its method
// set.
func (t *Interface) IsMethodSet() bool { return t.typeSet().IsMethodSet() }
// IsImplicit reports whether the interface t is a wrapper for a type set literal.
func (t *Interface) IsImplicit() bool { return t.implicit }
// Complete computes the interface's type set. It must be called by users of
// NewInterfaceType and NewInterface after the interface's embedded types are
// fully defined and before using the interface type in any way other than to
// form other types. The interface must not contain duplicate methods or a
// panic occurs. Complete returns the receiver.
//
// Interface types that have been completed are safe for concurrent use.
func (t *Interface) Complete() *Interface {
if !t.complete {
t.complete = true
}
t.typeSet() // checks if t.tset is already set
return t
}
func (t *Interface) Underlying() Type { return t }
func (t *Interface) String() string { return TypeString(t, nil) }
// ----------------------------------------------------------------------------
// Implementation
func (t *Interface) cleanup() {
t.check = nil
t.embedPos = nil
}
func (check *Checker) interfaceType(ityp *Interface, iface *ast.InterfaceType, def *Named) {
addEmbedded := func(pos token.Pos, typ Type) {
ityp.embeddeds = append(ityp.embeddeds, typ)
if ityp.embedPos == nil {
ityp.embedPos = new([]token.Pos)
}
*ityp.embedPos = append(*ityp.embedPos, pos)
}
for _, f := range iface.Methods.List {
if len(f.Names) == 0 {
addEmbedded(f.Type.Pos(), parseUnion(check, f.Type))
continue
}
// len(f.Names) > 0
// We have a method with name f.Names[0].
name := f.Names[0]
if name.Name == "_" {
check.error(name, BlankIfaceMethod, "methods must have a unique non-blank name")
continue // ignore
}
typ := check.typ(f.Type)
sig, _ := typ.(*Signature)
if sig == nil {
if typ != Typ[Invalid] {
check.errorf(f.Type, InvalidSyntaxTree, "%s is not a method signature", typ)
}
continue // ignore
}
// Always type-check method type parameters but complain if they are not enabled.
// (This extra check is needed here because interface method signatures don't have
// a receiver specification.)
if sig.tparams != nil {
var at positioner = f.Type
if ftyp, _ := f.Type.(*ast.FuncType); ftyp != nil && ftyp.TypeParams != nil {
at = ftyp.TypeParams
}
check.error(at, InvalidMethodTypeParams, "methods cannot have type parameters")
}
// use named receiver type if available (for better error messages)
var recvTyp Type = ityp
if def != nil {
recvTyp = def
}
sig.recv = NewVar(name.Pos(), check.pkg, "", recvTyp)
m := NewFunc(name.Pos(), check.pkg, name.Name, sig)
check.recordDef(name, m)
ityp.methods = append(ityp.methods, m)
}
// All methods and embedded elements for this interface are collected;
// i.e., this interface may be used in a type set computation.
ityp.complete = true
if len(ityp.methods) == 0 && len(ityp.embeddeds) == 0 {
// empty interface
ityp.tset = &topTypeSet
return
}
// sort for API stability
sortMethods(ityp.methods)
// (don't sort embeddeds: they must correspond to *embedPos entries)
// Compute type set as soon as possible to report any errors.
// Subsequent uses of type sets will use this computed type
// set and won't need to pass in a *Checker.
check.later(func() {
computeInterfaceTypeSet(check, iface.Pos(), ityp)
}).describef(iface, "compute type set for %s", ityp)
}
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package types
import (
"go/ast"
"go/token"
. "internal/types/errors"
)
// labels checks correct label use in body.
func (check *Checker) labels(body *ast.BlockStmt) {
// set of all labels in this body
all := NewScope(nil, body.Pos(), body.End(), "label")
fwdJumps := check.blockBranches(all, nil, nil, body.List)
// If there are any forward jumps left, no label was found for
// the corresponding goto statements. Either those labels were
// never defined, or they are inside blocks and not reachable
// for the respective gotos.
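// For instance (illustrative):
//
//	goto L // error: goto L jumps into block
//	if x {
//	L:
//	}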
for _, jmp := range fwdJumps {
var msg string
var code Code
name := jmp.Label.Name
if alt := all.Lookup(name); alt != nil {
msg = "goto %s jumps into block"
alt.(*Label).used = true // avoid another error
code = JumpIntoBlock
} else {
msg = "label %s not declared"
code = UndeclaredLabel
}
check.errorf(jmp.Label, code, msg, name)
}
// spec: "It is illegal to define a label that is never used."
for name, obj := range all.elems {
obj = resolve(name, obj)
if lbl := obj.(*Label); !lbl.used {
check.softErrorf(lbl, UnusedLabel, "label %s declared and not used", lbl.name)
}
}
}
// A block tracks label declarations in a block and its enclosing blocks.
type block struct {
parent *block // enclosing block
lstmt *ast.LabeledStmt // labeled statement to which this block belongs, or nil
labels map[string]*ast.LabeledStmt // allocated lazily
}
// insert records a new label declaration for the current block.
// The label must not have been declared before in any block.
func (b *block) insert(s *ast.LabeledStmt) {
name := s.Label.Name
if debug {
assert(b.gotoTarget(name) == nil)
}
labels := b.labels
if labels == nil {
labels = make(map[string]*ast.LabeledStmt)
b.labels = labels
}
labels[name] = s
}
// gotoTarget returns the labeled statement in the current
// or an enclosing block with the given label name, or nil.
func (b *block) gotoTarget(name string) *ast.LabeledStmt {
for s := b; s != nil; s = s.parent {
if t := s.labels[name]; t != nil {
return t
}
}
return nil
}
// enclosingTarget returns the innermost enclosing labeled
// statement with the given label name, or nil.
func (b *block) enclosingTarget(name string) *ast.LabeledStmt {
for s := b; s != nil; s = s.parent {
if t := s.lstmt; t != nil && t.Label.Name == name {
return t
}
}
return nil
}
// blockBranches processes a block's statement list and returns the set of outgoing forward jumps.
// all is the scope of all declared labels, parent the set of labels declared in the immediately
// enclosing block, and lstmt is the labeled statement this block is associated with (or nil).
func (check *Checker) blockBranches(all *Scope, parent *block, lstmt *ast.LabeledStmt, list []ast.Stmt) []*ast.BranchStmt {
b := &block{parent: parent, lstmt: lstmt}
var (
varDeclPos token.Pos
fwdJumps, badJumps []*ast.BranchStmt
)
// All forward jumps jumping over a variable declaration are possibly
// invalid (they may still jump out of the block and be ok).
// recordVarDecl records them for the given position.
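// E.g. (illustrative):
//
//	goto L // error: goto L jumps over variable declaration
//	v := 0
//	L:
//	_ = v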
recordVarDecl := func(pos token.Pos) {
varDeclPos = pos
badJumps = append(badJumps[:0], fwdJumps...) // copy fwdJumps to badJumps
}
jumpsOverVarDecl := func(jmp *ast.BranchStmt) bool {
if varDeclPos.IsValid() {
for _, bad := range badJumps {
if jmp == bad {
return true
}
}
}
return false
}
blockBranches := func(lstmt *ast.LabeledStmt, list []ast.Stmt) {
// Unresolved forward jumps inside the nested block
// become forward jumps in the current block.
fwdJumps = append(fwdJumps, check.blockBranches(all, b, lstmt, list)...)
}
var stmtBranches func(ast.Stmt)
stmtBranches = func(s ast.Stmt) {
switch s := s.(type) {
case *ast.DeclStmt:
if d, _ := s.Decl.(*ast.GenDecl); d != nil && d.Tok == token.VAR {
recordVarDecl(d.Pos())
}
case *ast.LabeledStmt:
// declare non-blank label
if name := s.Label.Name; name != "_" {
lbl := NewLabel(s.Label.Pos(), check.pkg, name)
if alt := all.Insert(lbl); alt != nil {
check.softErrorf(lbl, DuplicateLabel, "label %s already declared", name)
check.reportAltDecl(alt)
// ok to continue
} else {
b.insert(s)
check.recordDef(s.Label, lbl)
}
// resolve matching forward jumps and remove them from fwdJumps
i := 0
for _, jmp := range fwdJumps {
if jmp.Label.Name == name {
// match
lbl.used = true
check.recordUse(jmp.Label, lbl)
if jumpsOverVarDecl(jmp) {
check.softErrorf(
jmp.Label,
JumpOverDecl,
"goto %s jumps over variable declaration at line %d",
name,
check.fset.Position(varDeclPos).Line,
)
// ok to continue
}
} else {
// no match - record new forward jump
fwdJumps[i] = jmp
i++
}
}
fwdJumps = fwdJumps[:i]
lstmt = s
}
stmtBranches(s.Stmt)
case *ast.BranchStmt:
if s.Label == nil {
return // checked in 1st pass (check.stmt)
}
// determine and validate target
name := s.Label.Name
switch s.Tok {
case token.BREAK:
// spec: "If there is a label, it must be that of an enclosing
// "for", "switch", or "select" statement, and that is the one
// whose execution terminates."
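// E.g. (illustrative): "Loop: for { break Loop }" is valid, whereas
// "L: { break L }" is not, because L does not label a "for", "switch",
// or "select" statement.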
valid := false
if t := b.enclosingTarget(name); t != nil {
switch t.Stmt.(type) {
case *ast.SwitchStmt, *ast.TypeSwitchStmt, *ast.SelectStmt, *ast.ForStmt, *ast.RangeStmt:
valid = true
}
}
if !valid {
check.errorf(s.Label, MisplacedLabel, "invalid break label %s", name)
return
}
case token.CONTINUE:
// spec: "If there is a label, it must be that of an enclosing
// "for" statement, and that is the one whose execution advances."
valid := false
if t := b.enclosingTarget(name); t != nil {
switch t.Stmt.(type) {
case *ast.ForStmt, *ast.RangeStmt:
valid = true
}
}
if !valid {
check.errorf(s.Label, MisplacedLabel, "invalid continue label %s", name)
return
}
case token.GOTO:
if b.gotoTarget(name) == nil {
// label may be declared later - add branch to forward jumps
fwdJumps = append(fwdJumps, s)
return
}
default:
check.errorf(s, InvalidSyntaxTree, "branch statement: %s %s", s.Tok, name)
return
}
// record label use
obj := all.Lookup(name)
obj.(*Label).used = true
check.recordUse(s.Label, obj)
case *ast.AssignStmt:
if s.Tok == token.DEFINE {
recordVarDecl(s.Pos())
}
case *ast.BlockStmt:
blockBranches(lstmt, s.List)
case *ast.IfStmt:
stmtBranches(s.Body)
if s.Else != nil {
stmtBranches(s.Else)
}
case *ast.CaseClause:
blockBranches(nil, s.Body)
case *ast.SwitchStmt:
stmtBranches(s.Body)
case *ast.TypeSwitchStmt:
stmtBranches(s.Body)
case *ast.CommClause:
blockBranches(nil, s.Body)
case *ast.SelectStmt:
stmtBranches(s.Body)
case *ast.ForStmt:
stmtBranches(s.Body)
case *ast.RangeStmt:
stmtBranches(s.Body)
}
}
for _, s := range list {
stmtBranches(s)
}
return fwdJumps
}
// Code generated by "go test -run=Generate -write=all"; DO NOT EDIT.
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements various field and method lookup functions.
package types
import (
"bytes"
"strings"
)
// Internal use of LookupFieldOrMethod: If the obj result is a method
// associated with a concrete (non-interface) type, the method's signature
// may not be fully set up. Call Checker.objDecl(obj, nil) before accessing
// the method's type.
// LookupFieldOrMethod looks up a field or method with given package and name
// in T and returns the corresponding *Var or *Func, an index sequence, and a
// bool indicating if there were any pointer indirections on the path to the
// field or method. If addressable is set, T is the type of an addressable
// variable (only matters for method lookups). T must not be nil.
//
// The last index entry is the field or method index in the (possibly embedded)
// type where the entry was found, either:
//
// 1. the list of declared methods of a named type; or
// 2. the list of all methods (method set) of an interface type; or
// 3. the list of fields of a struct type.
//
// The earlier index entries are the indices of the embedded struct fields
// traversed to get to the found entry, starting at depth 0.
//
// If no entry is found, a nil object is returned. In this case, the returned
// index and indirect values have the following meaning:
//
// - If index != nil, the index sequence points to an ambiguous entry
// (the same name appeared more than once at the same embedding level).
//
// - If indirect is set, a method with a pointer receiver type was found
// but there was no pointer on the path from the actual receiver type to
// the method's formal receiver base type, nor was the receiver addressable.
func LookupFieldOrMethod(T Type, addressable bool, pkg *Package, name string) (obj Object, index []int, indirect bool) {
if T == nil {
panic("LookupFieldOrMethod on nil type")
}
// Methods cannot be associated to a named pointer type.
// (spec: "The type denoted by T is called the receiver base type;
// it must not be a pointer or interface type and it must be declared
// in the same package as the method.").
// Thus, if we have a named pointer type, proceed with the underlying
// pointer type but discard the result if it is a method since we would
// not have found it for T (see also go.dev/issue/8590).
if t, _ := T.(*Named); t != nil {
if p, _ := t.Underlying().(*Pointer); p != nil {
obj, index, indirect = lookupFieldOrMethodImpl(p, false, pkg, name, false)
if _, ok := obj.(*Func); ok {
return nil, nil, false
}
return
}
}
obj, index, indirect = lookupFieldOrMethodImpl(T, addressable, pkg, name, false)
// If we didn't find anything and if we have a type parameter with a core type,
// see if there is a matching field (but not a method, those need to be declared
// explicitly in the constraint). If the constraint is a named pointer type (see
// above), we are ok here because only fields are accepted as results.
const enableTParamFieldLookup = false // see go.dev/issue/51576
if enableTParamFieldLookup && obj == nil && isTypeParam(T) {
if t := coreType(T); t != nil {
obj, index, indirect = lookupFieldOrMethodImpl(t, addressable, pkg, name, false)
if _, ok := obj.(*Var); !ok {
obj, index, indirect = nil, nil, false // accept fields (variables) only
}
}
}
return
}
// lookupFieldOrMethodImpl is the implementation of LookupFieldOrMethod.
// Notably, in contrast to LookupFieldOrMethod, it won't find struct fields
// in base types of defined (*Named) pointer types T. For instance, given
// the declaration:
//
// type T *struct{f int}
//
// lookupFieldOrMethodImpl won't find the field f in the defined (*Named) type T
// (methods on T are not permitted in the first place).
//
// Thus, lookupFieldOrMethodImpl should only be called by LookupFieldOrMethod
// and missingMethod (the latter doesn't care about struct fields).
//
// If foldCase is true, method names are considered equal if they are equal
// with case folding.
//
// The resulting object may not be fully type-checked.
func lookupFieldOrMethodImpl(T Type, addressable bool, pkg *Package, name string, foldCase bool) (obj Object, index []int, indirect bool) {
// WARNING: The code in this function is extremely subtle - do not modify casually!
if name == "_" {
return // blank fields/methods are never found
}
// Importantly, we must not call under before the call to deref below (nor
// does deref call under), as doing so could incorrectly result in finding
// methods of the pointer base type when T is a (*Named) pointer type.
typ, isPtr := deref(T)
// *typ where typ is an interface (incl. a type parameter) has no methods.
if isPtr {
if _, ok := under(typ).(*Interface); ok {
return
}
}
// Start with typ as single entry at shallowest depth.
current := []embeddedType{{typ, nil, isPtr, false}}
// seen tracks named types that we have seen already, allocated lazily.
// Used to avoid endless searches in case of recursive types.
//
// We must use a lookup on identity rather than a simple map[*Named]bool as
// instantiated types may be identical but not equal.
var seen instanceLookup
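// For instance (illustrative):
//
//	type A struct{ X int }
//	type B struct{ X int }
//	type S struct{ A; B } // looking up X in S finds it twice at depth 1: a collision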
// search current depth
for len(current) > 0 {
var next []embeddedType // embedded types found at current depth
// look for (pkg, name) in all types at current depth
for _, e := range current {
typ := e.typ
// If we have a named type, we may have associated methods.
// Look for those first.
if named, _ := typ.(*Named); named != nil {
if alt := seen.lookup(named); alt != nil {
// We have seen this type before, at a more shallow depth
// (note that multiples of this type at the current depth
// were consolidated before). The type at that depth shadows
// this same type at the current depth, so we can ignore
// this one.
continue
}
seen.add(named)
// look for a matching attached method
if i, m := named.lookupMethod(pkg, name, foldCase); m != nil {
// potential match
// caution: method may not have a proper signature yet
index = concat(e.index, i)
if obj != nil || e.multiples {
return nil, index, false // collision
}
obj = m
indirect = e.indirect
continue // we can't have a matching field or interface method
}
}
switch t := under(typ).(type) {
case *Struct:
// look for a matching field and collect embedded types
for i, f := range t.fields {
if f.sameId(pkg, name) {
assert(f.typ != nil)
index = concat(e.index, i)
if obj != nil || e.multiples {
return nil, index, false // collision
}
obj = f
indirect = e.indirect
continue // we can't have a matching interface method
}
// Collect embedded struct fields for searching the next
// lower depth, but only if we have not seen a match yet
// (if we have a match it is either the desired field or
// we have a name collision on the same depth; in either
// case we don't need to look further).
// Embedded fields are always of the form T or *T where
// T is a type name. If e.typ appeared multiple times at
// this depth, f.typ appears multiple times at the next
// depth.
if obj == nil && f.embedded {
typ, isPtr := deref(f.typ)
// TODO(gri) optimization: ignore types that can't
// have fields or methods (only Named, Struct, and
// Interface types need to be considered).
next = append(next, embeddedType{typ, concat(e.index, i), e.indirect || isPtr, e.multiples})
}
}
case *Interface:
// look for a matching method (interface may be a type parameter)
if i, m := t.typeSet().LookupMethod(pkg, name, foldCase); m != nil {
assert(m.typ != nil)
index = concat(e.index, i)
if obj != nil || e.multiples {
return nil, index, false // collision
}
obj = m
indirect = e.indirect
}
}
}
if obj != nil {
// found a potential match
// spec: "A method call x.m() is valid if the method set of (the type of) x
// contains m and the argument list can be assigned to the parameter
// list of m. If x is addressable and &x's method set contains m, x.m()
// is shorthand for (&x).m()".
if f, _ := obj.(*Func); f != nil {
// determine if method has a pointer receiver
if f.hasPtrRecv() && !indirect && !addressable {
return nil, nil, true // pointer/addressable receiver required
}
}
return
}
current = consolidateMultiples(next)
}
return nil, nil, false // not found
}
// embeddedType represents an embedded type
type embeddedType struct {
typ Type
index []int // embedded field indices, starting with index at depth 0
indirect bool // if set, there was a pointer indirection on the path to this field
multiples bool // if set, typ appears multiple times at this depth
}
// consolidateMultiples collects multiple list entries with the same type
// into a single entry marked as containing multiples. The result is the
// consolidated list.
func consolidateMultiples(list []embeddedType) []embeddedType {
if len(list) <= 1 {
return list // at most one entry - nothing to do
}
n := 0 // number of entries w/ unique type
prev := make(map[Type]int) // index at which type was previously seen
for _, e := range list {
if i, found := lookupType(prev, e.typ); found {
list[i].multiples = true
// ignore this entry
} else {
prev[e.typ] = n
list[n] = e
n++
}
}
return list[:n]
}
func lookupType(m map[Type]int, typ Type) (int, bool) {
// fast path: maybe the types are equal
if i, found := m[typ]; found {
return i, true
}
for t, i := range m {
if Identical(t, typ) {
return i, true
}
}
return 0, false
}
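// Editorial illustration of why the Identical fallback above is needed
// (not part of the original source; List stands for an assumed generic
// named type such as "type List[T any] struct{ next *List[T] }"):
//
//	a, _ := types.Instantiate(nil, List, []types.Type{types.Typ[types.Int]}, false)
//	b, _ := types.Instantiate(nil, List, []types.Type{types.Typ[types.Int]}, false)
//	types.Identical(a, b) // true: the same instantiation
//	a == b                // false: distinct pointers, so a plain map lookup on the key misses b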
type instanceLookup struct {
// buf is used to avoid allocating the map m in the common case of a small
// number of instances.
buf [3]*Named
m map[*Named][]*Named
}
func (l *instanceLookup) lookup(inst *Named) *Named {
for _, t := range l.buf {
if t != nil && Identical(inst, t) {
return t
}
}
for _, t := range l.m[inst.Origin()] {
if Identical(inst, t) {
return t
}
}
return nil
}
func (l *instanceLookup) add(inst *Named) {
for i, t := range l.buf {
if t == nil {
l.buf[i] = inst
return
}
}
if l.m == nil {
l.m = make(map[*Named][]*Named)
}
insts := l.m[inst.Origin()]
l.m[inst.Origin()] = append(insts, inst)
}
// MissingMethod returns (nil, false) if V implements T, otherwise it
// returns a missing method required by T and whether it is missing or
// just has the wrong type: either a pointer receiver or wrong signature.
//
// For non-interface types V, or if static is set, V implements T if all
// methods of T are present in V. Otherwise (V is an interface and static
// is not set), MissingMethod only checks that methods of T which are also
// present in V have matching types (e.g., for a type assertion x.(T) where
// x is of interface type V).
func MissingMethod(V Type, T *Interface, static bool) (method *Func, wrongType bool) {
return (*Checker)(nil).missingMethod(V, T, static, Identical, nil)
}
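// Editorial sketch of calling the exported MissingMethod from client code
// (assumes pkg is a type-checked *types.Package containing
// "type T struct{}" and "type I interface{ M() }"; setup elided):
//
//	T := pkg.Scope().Lookup("T").Type()
//	I := pkg.Scope().Lookup("I").Type().Underlying().(*types.Interface)
//	m, wrongType := types.MissingMethod(T, I, true)
//	// m is I's method M; wrongType is false: T lacks M entirely.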
// missingMethod is like MissingMethod but accepts a *Checker as receiver,
// an equivalence function for type comparison, and a *string for error causes.
// The receiver may be nil if missingMethod is invoked through an exported
// API call (such as MissingMethod), i.e., when all methods have been type-
// checked.
// The underlying type of T must be an interface; T (rather than its under-
// lying type) is used for better error messages (reported through *cause).
// The comparator is used to compare signatures.
// If a method is missing and cause is not nil, *cause describes the error.
func (check *Checker) missingMethod(V, T Type, static bool, equivalent func(x, y Type) bool, cause *string) (method *Func, wrongType bool) {
methods := under(T).(*Interface).typeSet().methods // T must be an interface
if len(methods) == 0 {
return nil, false
}
const (
ok = iota
notFound
wrongName
wrongSig
ptrRecv
field
)
state := ok
var m *Func // method on T we're trying to implement
var f *Func // method on V, if found (state is one of ok, wrongName, wrongSig, ptrRecv)
if u, _ := under(V).(*Interface); u != nil {
tset := u.typeSet()
for _, m = range methods {
_, f = tset.LookupMethod(m.pkg, m.name, false)
if f == nil {
if !static {
continue
}
state = notFound
break
}
if !equivalent(f.typ, m.typ) {
state = wrongSig
break
}
}
} else {
for _, m = range methods {
obj, _, _ := lookupFieldOrMethodImpl(V, false, m.pkg, m.name, false)
// check if m is on *V, or on V with case-folding
if obj == nil {
state = notFound
// TODO(gri) Instead of NewPointer(V) below, can we just set the "addressable" argument?
obj, _, _ = lookupFieldOrMethodImpl(NewPointer(V), false, m.pkg, m.name, false)
if obj != nil {
f, _ = obj.(*Func)
if f != nil {
state = ptrRecv
}
// otherwise we found a field, keep state == notFound
break
}
obj, _, _ = lookupFieldOrMethodImpl(V, false, m.pkg, m.name, true /* fold case */)
if obj != nil {
f, _ = obj.(*Func)
if f != nil {
state = wrongName
}
// otherwise we found a (differently spelled) field, keep state == notFound
}
break
}
// we must have a method (not a struct field)
f, _ = obj.(*Func)
if f == nil {
state = field
break
}
// methods may not have a fully set up signature yet
if check != nil {
check.objDecl(f, nil)
}
if !equivalent(f.typ, m.typ) {
state = wrongSig
break
}
}
}
if state == ok {
return nil, false
}
if cause != nil {
switch state {
case notFound:
switch {
case isInterfacePtr(V):
*cause = "(" + check.interfacePtrError(V) + ")"
case isInterfacePtr(T):
*cause = "(" + check.interfacePtrError(T) + ")"
default:
*cause = check.sprintf("(missing method %s)", m.Name())
}
case wrongName:
fs, ms := check.funcString(f, false), check.funcString(m, false)
*cause = check.sprintf("(missing method %s)\n\t\thave %s\n\t\twant %s",
m.Name(), fs, ms)
case wrongSig:
fs, ms := check.funcString(f, false), check.funcString(m, false)
if fs == ms {
// Don't report "want Foo, have Foo".
// Add package information to disambiguate (go.dev/issue/54258).
fs, ms = check.funcString(f, true), check.funcString(m, true)
}
*cause = check.sprintf("(wrong type for method %s)\n\t\thave %s\n\t\twant %s",
m.Name(), fs, ms)
case ptrRecv:
*cause = check.sprintf("(method %s has pointer receiver)", m.Name())
case field:
*cause = check.sprintf("(%s.%s is a field, not a method)", V, m.Name())
default:
unreachable()
}
}
return m, state == wrongSig || state == ptrRecv
}
func isInterfacePtr(T Type) bool {
p, _ := under(T).(*Pointer)
return p != nil && IsInterface(p.base)
}
// check may be nil.
func (check *Checker) interfacePtrError(T Type) string {
assert(isInterfacePtr(T))
if p, _ := under(T).(*Pointer); isTypeParam(p.base) {
return check.sprintf("type %s is pointer to type parameter, not type parameter", T)
}
return check.sprintf("type %s is pointer to interface, not interface", T)
}
// funcString returns a string of the form name + signature for f.
// check may be nil.
func (check *Checker) funcString(f *Func, pkgInfo bool) string {
buf := bytes.NewBufferString(f.name)
var qf Qualifier
if check != nil && !pkgInfo {
qf = check.qualifier
}
w := newTypeWriter(buf, qf)
w.pkgInfo = pkgInfo
w.paramNames = false
w.signature(f.typ.(*Signature))
return buf.String()
}
// assertableTo reports whether a value of type V can be asserted to have type T.
// The receiver may be nil if assertableTo is invoked through an exported API call
// (such as AssertableTo), i.e., when all methods have been type-checked.
// The underlying type of V must be an interface.
// If the result is false and cause is not nil, *cause describes the error.
// TODO(gri) replace calls to this function with calls to newAssertableTo.
func (check *Checker) assertableTo(V, T Type, cause *string) bool {
// no static check is required if T is an interface
// spec: "If T is an interface type, x.(T) asserts that the
// dynamic type of x implements the interface T."
if IsInterface(T) {
return true
}
// TODO(gri) fix this for generalized interfaces
m, _ := check.missingMethod(T, V, false, Identical, cause)
return m == nil
}
// newAssertableTo reports whether a value of type V can be asserted to have type T.
// It also implements behavior for interfaces that currently are only permitted
// in constraint position (we have not yet defined that behavior in the spec).
// The underlying type of V must be an interface.
// If the result is false and cause is not nil, *cause is set to the error cause.
func (check *Checker) newAssertableTo(V, T Type, cause *string) bool {
// no static check is required if T is an interface
// spec: "If T is an interface type, x.(T) asserts that the
// dynamic type of x implements the interface T."
if IsInterface(T) {
return true
}
return check.implements(T, V, false, cause)
}
// deref dereferences typ if it is a *Pointer (but not a *Named type
// with an underlying pointer type!) and returns its base and true.
// Otherwise it returns (typ, false).
func deref(typ Type) (Type, bool) {
if p, _ := typ.(*Pointer); p != nil {
// p.base should never be nil, but be conservative
if p.base == nil {
if debug {
panic("pointer with nil base type (possibly due to an invalid cyclic declaration)")
}
return Typ[Invalid], true
}
return p.base, true
}
return typ, false
}
// derefStructPtr dereferences typ if it is a (named or unnamed) pointer to a
// (named or unnamed) struct and returns its base. Otherwise it returns typ.
func derefStructPtr(typ Type) Type {
if p, _ := under(typ).(*Pointer); p != nil {
if _, ok := under(p.base).(*Struct); ok {
return p.base
}
}
return typ
}
// concat returns the result of concatenating list and i.
// The result does not share its underlying array with list.
func concat(list []int, i int) []int {
var t []int
t = append(t, list...)
return append(t, i)
}
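// Editorial note on why concat copies: index slices are retained inside
// embeddedType values across search depths, so sharing a backing array
// between siblings could let one append overwrite another. For example:
//
//	base := make([]int, 1, 8)
//	x := append(base, 1) // x may share base's backing array
//	y := append(base, 2) // may write over x[1], corrupting x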
// fieldIndex returns the index for the field with matching package and name, or a value < 0.
func fieldIndex(fields []*Var, pkg *Package, name string) int {
if name != "_" {
for i, f := range fields {
if f.sameId(pkg, name) {
return i
}
}
}
return -1
}
// lookupMethod returns the index of, and the method with, a matching package and name, or (-1, nil).
// If foldCase is true, method names are considered equal if they are equal with case folding.
func lookupMethod(methods []*Func, pkg *Package, name string, foldCase bool) (int, *Func) {
if name != "_" {
for i, m := range methods {
if (m.name == name || foldCase && strings.EqualFold(m.name, name)) && m.sameId(pkg, m.name) {
return i, m
}
}
}
return -1, nil
}
// Code generated by "go test -run=Generate -write=all"; DO NOT EDIT.
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package types
// A Map represents a map type.
type Map struct {
key, elem Type
}
// NewMap returns a new map for the given key and element types.
func NewMap(key, elem Type) *Map {
return &Map{key: key, elem: elem}
}
// Key returns the key type of map m.
func (m *Map) Key() Type { return m.key }
// Elem returns the element type of map m.
func (m *Map) Elem() Type { return m.elem }
func (t *Map) Underlying() Type { return t }
func (t *Map) String() string { return TypeString(t, nil) }
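// Editorial usage sketch for the exported Map API (client-side view, not
// part of the original source):
//
//	m := types.NewMap(types.Typ[types.String], types.Typ[types.Int])
//	m.Key()    // string
//	m.Elem()   // int
//	m.String() // "map[string]int"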
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements method sets.
package types
import (
"fmt"
"sort"
"strings"
)
// A MethodSet is an ordered set of concrete or abstract (interface) methods;
// a method is a MethodVal selection, and they are ordered by ascending m.Obj().Id().
// The zero value for a MethodSet is a ready-to-use empty method set.
type MethodSet struct {
list []*Selection
}
func (s *MethodSet) String() string {
if s.Len() == 0 {
return "MethodSet {}"
}
var buf strings.Builder
fmt.Fprintln(&buf, "MethodSet {")
for _, f := range s.list {
fmt.Fprintf(&buf, "\t%s\n", f)
}
fmt.Fprintln(&buf, "}")
return buf.String()
}
// Len returns the number of methods in s.
func (s *MethodSet) Len() int { return len(s.list) }
// At returns the i'th method in s for 0 <= i < s.Len().
func (s *MethodSet) At(i int) *Selection { return s.list[i] }
// Lookup returns the method with matching package and name, or nil if not found.
func (s *MethodSet) Lookup(pkg *Package, name string) *Selection {
if s.Len() == 0 {
return nil
}
key := Id(pkg, name)
i := sort.Search(len(s.list), func(i int) bool {
m := s.list[i]
return m.obj.Id() >= key
})
if i < len(s.list) {
m := s.list[i]
if m.obj.Id() == key {
return m
}
}
return nil
}
// Shared empty method set.
var emptyMethodSet MethodSet
// Note: NewMethodSet is intended for external use only as it
// requires interfaces to be complete. It may be used
// internally if LookupFieldOrMethod completed the same
// interfaces beforehand.
// NewMethodSet returns the method set for the given type T.
// It always returns a non-nil method set, even if it is empty.
func NewMethodSet(T Type) *MethodSet {
// WARNING: The code in this function is extremely subtle - do not modify casually!
// This function and lookupFieldOrMethod should be kept in sync.
// TODO(rfindley) confirm that this code is in sync with lookupFieldOrMethod
// with respect to type params.
// method set up to the current depth, allocated lazily
var base methodSet
typ, isPtr := deref(T)
// *typ where typ is an interface has no methods.
if isPtr && IsInterface(typ) {
return &emptyMethodSet
}
// Start with typ as single entry at shallowest depth.
current := []embeddedType{{typ, nil, isPtr, false}}
// seen tracks named types that we have seen already, allocated lazily.
// Used to avoid endless searches in case of recursive types.
//
// We must use a lookup on identity rather than a simple map[*Named]bool as
// instantiated types may be identical but not equal.
var seen instanceLookup
// collect methods at current depth
for len(current) > 0 {
var next []embeddedType // embedded types found at current depth
// field and method sets at current depth, indexed by names (Id's), and allocated lazily
var fset map[string]bool // we only care about the field names
var mset methodSet
for _, e := range current {
typ := e.typ
// If we have a named type, we may have associated methods.
// Look for those first.
if named, _ := typ.(*Named); named != nil {
if alt := seen.lookup(named); alt != nil {
// We have seen this type before, at a more shallow depth
// (note that multiples of this type at the current depth
// were consolidated before). The type at that depth shadows
// this same type at the current depth, so we can ignore
// this one.
continue
}
seen.add(named)
for i := 0; i < named.NumMethods(); i++ {
mset = mset.addOne(named.Method(i), concat(e.index, i), e.indirect, e.multiples)
}
}
switch t := under(typ).(type) {
case *Struct:
for i, f := range t.fields {
if fset == nil {
fset = make(map[string]bool)
}
fset[f.Id()] = true
// Embedded fields are always of the form T or *T where
// T is a type name. If typ appeared multiple times at
// this depth, f.Type appears multiple times at the next
// depth.
if f.embedded {
typ, isPtr := deref(f.typ)
// TODO(gri) optimization: ignore types that can't
// have fields or methods (only Named, Struct, and
// Interface types need to be considered).
next = append(next, embeddedType{typ, concat(e.index, i), e.indirect || isPtr, e.multiples})
}
}
case *Interface:
mset = mset.add(t.typeSet().methods, e.index, true, e.multiples)
}
}
// Add methods and collisions at this depth to base if no entries with matching
// names exist already.
for k, m := range mset {
if _, found := base[k]; !found {
// Fields collide with methods of the same name at this depth.
if fset[k] {
m = nil // collision
}
if base == nil {
base = make(methodSet)
}
base[k] = m
}
}
// Add all (remaining) fields at this depth as collisions (since they will
// hide any method further down) if no entries with matching names exist already.
for k := range fset {
if _, found := base[k]; !found {
if base == nil {
base = make(methodSet)
}
base[k] = nil // collision
}
}
current = consolidateMultiples(next)
}
if len(base) == 0 {
return &emptyMethodSet
}
// collect methods
var list []*Selection
for _, m := range base {
if m != nil {
m.recv = T
list = append(list, m)
}
}
// sort by unique name
sort.Slice(list, func(i, j int) bool {
return list[i].obj.Id() < list[j].obj.Id()
})
return &MethodSet{list}
}
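// Editorial usage sketch (assumes T is a *types.Named obtained from a
// type-checked package; not part of the original source):
//
//	mset := types.NewMethodSet(types.NewPointer(T))
//	for i := 0; i < mset.Len(); i++ {
//		sel := mset.At(i)
//		fmt.Println(sel.Obj().Name(), sel.Type())
//	}
//
// Note that NewMethodSet(types.NewPointer(T)) also includes methods with
// pointer receivers, while NewMethodSet(T) includes only those in the
// method set of the value type.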
// A methodSet is a set of methods and name collisions.
// A collision indicates that multiple methods with the
// same unique id, or a method and a field with that id, appeared.
type methodSet map[string]*Selection // a nil entry indicates a name collision
// add adds all functions in list to the method set s.
// If multiples is set, every function in list appears multiple times
// and is treated as a collision.
func (s methodSet) add(list []*Func, index []int, indirect bool, multiples bool) methodSet {
if len(list) == 0 {
return s
}
for i, f := range list {
s = s.addOne(f, concat(index, i), indirect, multiples)
}
return s
}
func (s methodSet) addOne(f *Func, index []int, indirect bool, multiples bool) methodSet {
if s == nil {
s = make(methodSet)
}
key := f.Id()
// if f is not in the set, add it
if !multiples {
// TODO(gri) A found method may not be added because it's not in the method set
// (!indirect && f.hasPtrRecv()). A 2nd method on the same level may be in the method
// set and may not collide with the first one, thus leading to a false positive.
// Is that possible? Investigate.
if _, found := s[key]; !found && (indirect || !f.hasPtrRecv()) {
s[key] = &Selection{MethodVal, nil, f, index, indirect}
return s
}
}
s[key] = nil // collision
return s
}
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package types
import (
"go/ast"
"go/token"
. "internal/types/errors"
)
// This file implements a check to validate that a Go package doesn't
// have unbounded recursive instantiation, which is not compatible
// with compilers using static instantiation (such as
// monomorphization).
//
// It implements a sort of "type flow" analysis by detecting which
// type parameters are instantiated with other type parameters (or
// types derived thereof). A package cannot be statically instantiated
// if the graph has any cycles involving at least one derived type.
//
// Concretely, we construct a directed, weighted graph. Vertices are
// used to represent type parameters as well as some defined
// types. Edges are used to represent how types depend on each other:
//
// * Everywhere a type-parameterized function or type is instantiated,
// we add edges to each type parameter from the vertices (if any)
// representing each type parameter or defined type referenced by
// the type argument. If the type argument is just the referenced
// type itself, then the edge has weight 0, otherwise 1.
//
// * For every defined type declared within a type-parameterized
// function or method, we add an edge of weight 1 to the defined
// type from each ambient type parameter.
//
// For example, given:
//
// func f[A, B any]() {
// type T int
// f[T, map[A]B]()
// }
//
// we construct vertices representing types A, B, and T. Because of
// declaration "type T int", we construct edges T<-A and T<-B with
// weight 1; and because of instantiation "f[T, map[A]B]" we construct
// edges A<-T with weight 0, and B<-A and B<-B with weight 1.
//
// Finally, we look for any positive-weight cycles. Zero-weight cycles
// are allowed because static instantiation will reach a fixed point.
type monoGraph struct {
vertices []monoVertex
edges []monoEdge
// canon maps method receiver type parameters to their respective
// receiver type's type parameters.
canon map[*TypeParam]*TypeParam
// nameIdx maps a defined type or (canonical) type parameter to its
// vertex index.
nameIdx map[*TypeName]int
}
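// As an editorial illustration of the analysis described above, the
// following program is rejected, since A flows back into f through the
// derived type *A, forming a positive-weight cycle:
//
//	func f[A any]() {
//		f[*A]() // edge A <- A with weight 1: unbounded instantiation
//	}
//
// A plain recursive call f[A]() would only add a weight-0 self-edge,
// which is permitted.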
type monoVertex struct {
weight int // weight of heaviest known path to this vertex
pre int // previous edge (if any) in the above path
len int // length of the above path
// obj is the defined type or type parameter represented by this
// vertex.
obj *TypeName
}
type monoEdge struct {
dst, src int
weight int
pos token.Pos
typ Type
}
func (check *Checker) monomorph() {
// We detect unbounded instantiation cycles using a variant of
// Bellman-Ford's algorithm. Namely, instead of always running |V|
// iterations, we run until we either reach a fixed point or we've
// found a path of length |V|. This allows us to terminate earlier
// when there are no cycles, which should be the common case.
again := true
for again {
again = false
for i, edge := range check.mono.edges {
src := &check.mono.vertices[edge.src]
dst := &check.mono.vertices[edge.dst]
// N.B., we're looking for the greatest weight paths, unlike
// typical Bellman-Ford.
w := src.weight + edge.weight
if w <= dst.weight {
continue
}
dst.pre = i
dst.len = src.len + 1
if dst.len == len(check.mono.vertices) {
check.reportInstanceLoop(edge.dst)
return
}
dst.weight = w
again = true
}
}
}
func (check *Checker) reportInstanceLoop(v int) {
var stack []int
seen := make([]bool, len(check.mono.vertices))
// We have a path that contains a cycle and ends at v, but v may
// only be reachable from the cycle, not on the cycle itself. We
// start by walking backwards along the path until we find a vertex
// that appears twice.
for !seen[v] {
stack = append(stack, v)
seen[v] = true
v = check.mono.edges[check.mono.vertices[v].pre].src
}
// Trim any vertices we visited before visiting v the first
// time. Since v is the first vertex we found within the cycle, any
// vertices we visited earlier cannot be part of the cycle.
for stack[0] != v {
stack = stack[1:]
}
// TODO(mdempsky): Pivot stack so we report the cycle from the top?
obj0 := check.mono.vertices[v].obj
check.error(obj0, InvalidInstanceCycle, "instantiation cycle:")
qf := RelativeTo(check.pkg)
for _, v := range stack {
edge := check.mono.edges[check.mono.vertices[v].pre]
obj := check.mono.vertices[edge.dst].obj
switch obj.Type().(type) {
default:
panic("unexpected type")
case *Named:
check.errorf(atPos(edge.pos), InvalidInstanceCycle, "\t%s implicitly parameterized by %s", obj.Name(), TypeString(edge.typ, qf)) // secondary error, \t indented
case *TypeParam:
check.errorf(atPos(edge.pos), InvalidInstanceCycle, "\t%s instantiated as %s", obj.Name(), TypeString(edge.typ, qf)) // secondary error, \t indented
}
}
}
// recordCanon records that tpar is the canonical type parameter
// corresponding to method type parameter mpar.
func (w *monoGraph) recordCanon(mpar, tpar *TypeParam) {
if w.canon == nil {
w.canon = make(map[*TypeParam]*TypeParam)
}
w.canon[mpar] = tpar
}
// recordInstance records that the given type parameters were
// instantiated with the corresponding type arguments.
func (w *monoGraph) recordInstance(pkg *Package, pos token.Pos, tparams []*TypeParam, targs []Type, xlist []ast.Expr) {
for i, tpar := range tparams {
pos := pos
if i < len(xlist) {
pos = xlist[i].Pos()
}
w.assign(pkg, pos, tpar, targs[i])
}
}
// assign records that tpar was instantiated as targ at pos.
func (w *monoGraph) assign(pkg *Package, pos token.Pos, tpar *TypeParam, targ Type) {
// Go generics do not have an analog to C++'s template-templates,
// where a template parameter can itself be an instantiable
// template. So any instantiation cycles must occur within a single
// package. Accordingly, we can ignore instantiations of imported
// type parameters.
//
// TODO(mdempsky): Push this check up into recordInstance? All type
// parameters in a list will appear in the same package.
if tpar.Obj().Pkg() != pkg {
return
}
// flow adds an edge from vertex src representing that typ flows to tpar.
flow := func(src int, typ Type) {
weight := 1
if typ == targ {
weight = 0
}
w.addEdge(w.typeParamVertex(tpar), src, weight, pos, targ)
}
// Recursively walk the type argument to find any defined types or
// type parameters.
var do func(typ Type)
do = func(typ Type) {
switch typ := typ.(type) {
default:
panic("unexpected type")
case *TypeParam:
assert(typ.Obj().Pkg() == pkg)
flow(w.typeParamVertex(typ), typ)
case *Named:
if src := w.localNamedVertex(pkg, typ.Origin()); src >= 0 {
flow(src, typ)
}
targs := typ.TypeArgs()
for i := 0; i < targs.Len(); i++ {
do(targs.At(i))
}
case *Array:
do(typ.Elem())
case *Basic:
// ok
case *Chan:
do(typ.Elem())
case *Map:
do(typ.Key())
do(typ.Elem())
case *Pointer:
do(typ.Elem())
case *Slice:
do(typ.Elem())
case *Interface:
for i := 0; i < typ.NumMethods(); i++ {
do(typ.Method(i).Type())
}
case *Signature:
tuple := func(tup *Tuple) {
for i := 0; i < tup.Len(); i++ {
do(tup.At(i).Type())
}
}
tuple(typ.Params())
tuple(typ.Results())
case *Struct:
for i := 0; i < typ.NumFields(); i++ {
do(typ.Field(i).Type())
}
}
}
do(targ)
}
// localNamedVertex returns the index of the vertex representing
// named, or -1 if named doesn't need representation.
func (w *monoGraph) localNamedVertex(pkg *Package, named *Named) int {
obj := named.Obj()
if obj.Pkg() != pkg {
return -1 // imported type
}
root := pkg.Scope()
if obj.Parent() == root {
return -1 // package scope, no ambient type parameters
}
if idx, ok := w.nameIdx[obj]; ok {
return idx
}
idx := -1
// Walk the type definition's scope to find any ambient type
// parameters that it's implicitly parameterized by.
for scope := obj.Parent(); scope != root; scope = scope.Parent() {
for _, elem := range scope.elems {
if elem, ok := elem.(*TypeName); ok && !elem.IsAlias() && cmpPos(elem.Pos(), obj.Pos()) < 0 {
if tpar, ok := elem.Type().(*TypeParam); ok {
if idx < 0 {
idx = len(w.vertices)
w.vertices = append(w.vertices, monoVertex{obj: obj})
}
w.addEdge(idx, w.typeParamVertex(tpar), 1, obj.Pos(), tpar)
}
}
}
}
if w.nameIdx == nil {
w.nameIdx = make(map[*TypeName]int)
}
w.nameIdx[obj] = idx
return idx
}
// typeParamVertex returns the index of the vertex representing tpar.
func (w *monoGraph) typeParamVertex(tpar *TypeParam) int {
if x, ok := w.canon[tpar]; ok {
tpar = x
}
obj := tpar.Obj()
if idx, ok := w.nameIdx[obj]; ok {
return idx
}
if w.nameIdx == nil {
w.nameIdx = make(map[*TypeName]int)
}
idx := len(w.vertices)
w.vertices = append(w.vertices, monoVertex{obj: obj})
w.nameIdx[obj] = idx
return idx
}
func (w *monoGraph) addEdge(dst, src, weight int, pos token.Pos, typ Type) {
// TODO(mdempsky): Deduplicate redundant edges?
w.edges = append(w.edges, monoEdge{
dst: dst,
src: src,
weight: weight,
pos: pos,
typ: typ,
})
}
// Code generated by "go test -run=Generate -write=all"; DO NOT EDIT.
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package types
import (
"go/token"
"sync"
"sync/atomic"
)
// Type-checking Named types is subtle, because they may be recursively
// defined, and because their full details may be spread across multiple
// declarations (via methods). For this reason they are type-checked lazily,
// to avoid information being accessed before it is complete.
//
// Conceptually, it is helpful to think of named types as having two distinct
// sets of information:
// - "LHS" information, defining their identity: Obj() and TypeArgs()
// - "RHS" information, defining their details: TypeParams(), Underlying(),
// and methods.
//
// In this taxonomy, LHS information is available immediately, but RHS
// information is lazy. Specifically, a named type N may be constructed in any
// of the following ways:
// 1. type-checked from the source
// 2. loaded eagerly from export data
// 3. loaded lazily from export data (when using unified IR)
// 4. instantiated from a generic type
//
// In cases 1, 3, and 4, it is possible that the underlying type or methods of
// N may not be immediately available.
// - During type-checking, we allocate N before type-checking its underlying
// type or methods, so that we may resolve recursive references.
// - When loading from export data, we may load its methods and underlying
// type lazily using a provided load function.
// - After instantiating, we lazily expand the underlying type and methods
// (note that instances may be created while still in the process of
// type-checking the original type declaration).
//
// In cases 3 and 4 this lazy construction may also occur concurrently, due to
// concurrent use of the type checker API (after type checking or importing has
// finished). It is critical that we keep track of state, so that Named types
// are constructed exactly once and so that we do not access their details too
// soon.
//
// We achieve this by tracking state with an atomic state variable, and
// guarding potentially concurrent calculations with a mutex. At any point in
// time this state variable determines which data on N may be accessed. As
// state monotonically progresses, any data available at state M may be
// accessed without acquiring the mutex at state N, provided N >= M.
//
// GLOSSARY: Here are a few terms used in this file to describe Named types:
// - We say that a Named type is "instantiated" if it has been constructed by
// instantiating a generic named type with type arguments.
// - We say that a Named type is "declared" if it corresponds to a type
// declaration in the source. Instantiated named types correspond to a type
// instantiation in the source, not a declaration. But their Origin type is
// a declared type.
// - We say that a Named type is "resolved" if its RHS information has been
// loaded or fully type-checked. For Named types constructed from export
// data, this may involve invoking a loader function to extract information
// from export data. For instantiated named types this involves reading
// information from their origin.
// - We say that a Named type is "expanded" if it is an instantiated type and
// type parameters in its underlying type and methods have been substituted
// with the type arguments from the instantiation. A type may be partially
// expanded if some but not all of these details have been substituted.
// Similarly, we refer to these individual details (underlying type or
// method) as being "expanded".
// - When all information is known for a named type, we say it is "complete".
//
// Some invariants to keep in mind: each declared Named type has a single
// corresponding object, and that object's type is the (possibly generic) Named
// type. Declared Named types are identical if and only if their pointers are
// identical. On the other hand, multiple instantiated Named types may be
// identical even though their pointers are not identical. One has to use
// Identical to compare them. For instantiated named types, their obj is a
// synthetic placeholder that records the position of the corresponding
// instantiation in the source (if they were constructed during type checking).
//
// To prevent infinite expansion of named instances that are created outside of
// type-checking, instances share a Context with other instances created during
// their expansion. By the pigeonhole principle, this guarantees that in the
// presence of a cycle of named types, expansion will eventually find an
// existing instance in the Context and short-circuit the expansion.
//
// Once an instance is complete, we can nil out this shared Context to unpin
// memory, though this Context may still be held by other incomplete instances
// in its "lineage".
// A Named represents a named (defined) type.
type Named struct {
check *Checker // non-nil during type-checking; nil otherwise
obj *TypeName // corresponding declared object for declared types; see above for instantiated types
// fromRHS holds the type (on RHS of declaration) this *Named type is derived
// from (for cycle reporting). Only used by validType, and therefore does not
// require synchronization.
fromRHS Type
// information for instantiated types; nil otherwise
inst *instance
mu sync.Mutex // guards all fields below
state_ uint32 // the current state of this type; must only be accessed atomically
underlying Type // possibly a *Named during setup; never a *Named once set up completely
tparams *TypeParamList // type parameters, or nil
// methods declared for this type (not the method set of this type)
// Signatures are type-checked lazily.
// For non-instantiated types, this is a fully populated list of methods. For
// instantiated types, methods are individually expanded when they are first
// accessed.
methods []*Func
// loader may be provided to lazily load type parameters, underlying type, and methods.
loader func(*Named) (tparams []*TypeParam, underlying Type, methods []*Func)
}
// instance holds information that is only necessary for instantiated named
// types.
type instance struct {
orig *Named // original, uninstantiated type
targs *TypeList // type arguments
expandedMethods int // number of expanded methods; expandedMethods <= len(orig.methods)
ctxt *Context // local Context; set to nil after full expansion
}
// namedState represents the possible states that a named type may assume.
type namedState uint32
const (
unresolved namedState = iota // tparams, underlying type and methods might be unavailable
resolved // resolve has run; methods might be incomplete (for instances)
complete // all data is known
)
// NewNamed returns a new named type for the given type name, underlying type, and associated methods.
// If the given type name obj doesn't have a type yet, its type is set to the returned named type.
// The underlying type must not be a *Named.
func NewNamed(obj *TypeName, underlying Type, methods []*Func) *Named {
if _, ok := underlying.(*Named); ok {
panic("underlying type must not be *Named")
}
return (*Checker)(nil).newNamed(obj, underlying, methods)
}
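// Editorial usage sketch for NewNamed (the names here are assumptions,
// not part of the original source):
//
//	obj := types.NewTypeName(token.NoPos, nil, "MyInt", nil)
//	named := types.NewNamed(obj, types.Typ[types.Int], nil)
//	// obj.Type() == named (set as a side effect, since obj had no type yet)
//	// named.Underlying() == types.Typ[types.Int]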
// resolve resolves the type parameters, methods, and underlying type of n.
// This information may be loaded from a provided loader function, or computed
// from an origin type (in the case of instances).
//
// After resolution, the type parameters, methods, and underlying type of n are
// accessible; but if n is an instantiated type, its methods may still be
// unexpanded.
func (n *Named) resolve() *Named {
if n.state() >= resolved { // avoid locking below
return n
}
// TODO(rfindley): if n.check is non-nil we can avoid locking here, since
// type-checking is not concurrent. Evaluate if this is worth doing.
n.mu.Lock()
defer n.mu.Unlock()
if n.state() >= resolved {
return n
}
if n.inst != nil {
assert(n.underlying == nil) // n is an unresolved instance
assert(n.loader == nil) // instances are created by instantiation, in which case n.loader is nil
orig := n.inst.orig
orig.resolve()
underlying := n.expandUnderlying()
n.tparams = orig.tparams
n.underlying = underlying
n.fromRHS = orig.fromRHS // for cycle detection
if len(orig.methods) == 0 {
n.setState(complete) // nothing further to do
n.inst.ctxt = nil
} else {
n.setState(resolved)
}
return n
}
// TODO(mdempsky): Since we're passing n to the loader anyway
// (necessary because types2 expects the receiver type for methods
// on defined interface types to be the Named rather than the
// underlying Interface), maybe it should just handle calling
// SetTypeParams, SetUnderlying, and AddMethod instead? Those
// methods would need to support reentrant calls though. It would
// also make the API more future-proof towards further extensions.
if n.loader != nil {
assert(n.underlying == nil)
assert(n.TypeArgs().Len() == 0) // instances are created by instantiation, in which case n.loader is nil
tparams, underlying, methods := n.loader(n)
n.tparams = bindTParams(tparams)
n.underlying = underlying
n.fromRHS = underlying // for cycle detection
n.methods = methods
n.loader = nil
}
n.setState(complete)
return n
}
// state atomically accesses the current state of the receiver.
func (n *Named) state() namedState {
return namedState(atomic.LoadUint32(&n.state_))
}
// setState atomically stores the given state for n.
// Must only be called while holding n.mu.
func (n *Named) setState(state namedState) {
atomic.StoreUint32(&n.state_, uint32(state))
}
// newNamed is like NewNamed but with a *Checker receiver and additional orig argument.
func (check *Checker) newNamed(obj *TypeName, underlying Type, methods []*Func) *Named {
typ := &Named{check: check, obj: obj, fromRHS: underlying, underlying: underlying, methods: methods}
if obj.typ == nil {
obj.typ = typ
}
// Ensure that typ is always sanity-checked.
if check != nil {
check.needsCleanup(typ)
}
return typ
}
// newNamedInstance creates a new named instance for the given origin and type
// arguments, recording pos as the position of its synthetic object (for error
// reporting).
//
// If set, expanding is the named type instance currently being expanded, that
// led to the creation of this instance.
func (check *Checker) newNamedInstance(pos token.Pos, orig *Named, targs []Type, expanding *Named) *Named {
assert(len(targs) > 0)
obj := NewTypeName(pos, orig.obj.pkg, orig.obj.name, nil)
inst := &instance{orig: orig, targs: newTypeList(targs)}
// Only pass the expanding context to the new instance if their packages
// match. Since type reference cycles are only possible within a single
// package, this is sufficient for the purposes of short-circuiting cycles.
// Not passing the context in other cases prevents unnecessary coupling
// of types across packages.
if expanding != nil && expanding.Obj().pkg == obj.pkg {
inst.ctxt = expanding.inst.ctxt
}
typ := &Named{check: check, obj: obj, inst: inst}
obj.typ = typ
// Ensure that typ is always sanity-checked.
if check != nil {
check.needsCleanup(typ)
}
return typ
}
func (t *Named) cleanup() {
assert(t.inst == nil || t.inst.orig.inst == nil)
// Ensure that every defined type created in the course of type-checking has
// either non-*Named underlying type, or is unexpanded.
//
// This guarantees that we don't leak any types whose underlying type is
// *Named, because any unexpanded instances will lazily compute their
// underlying type by substituting in the underlying type of their origin.
// The origin must have either been imported or type-checked and expanded
// here, and in either case its underlying type will be fully expanded.
switch t.underlying.(type) {
case nil:
if t.TypeArgs().Len() == 0 {
panic("nil underlying")
}
case *Named:
t.under() // t.under may add entries to check.cleaners
}
t.check = nil
}
// Obj returns the type name for the declaration defining the named type t. For
// instantiated types, this is the same as the type name of the origin type.
func (t *Named) Obj() *TypeName {
if t.inst == nil {
return t.obj
}
return t.inst.orig.obj
}
// Origin returns the generic type from which the named type t is
// instantiated. If t is not an instantiated type, the result is t.
func (t *Named) Origin() *Named {
if t.inst == nil {
return t
}
return t.inst.orig
}
// TypeParams returns the type parameters of the named type t, or nil.
// The result is non-nil for an (originally) generic type even if it is instantiated.
func (t *Named) TypeParams() *TypeParamList { return t.resolve().tparams }
// SetTypeParams sets the type parameters of the named type t.
// t must not have type arguments.
func (t *Named) SetTypeParams(tparams []*TypeParam) {
assert(t.inst == nil)
t.resolve().tparams = bindTParams(tparams)
}
// TypeArgs returns the type arguments used to instantiate the named type t.
func (t *Named) TypeArgs() *TypeList {
if t.inst == nil {
return nil
}
return t.inst.targs
}
// NumMethods returns the number of explicit methods defined for t.
func (t *Named) NumMethods() int {
return len(t.Origin().resolve().methods)
}
// Method returns the i'th method of named type t for 0 <= i < t.NumMethods().
//
// For an ordinary or instantiated type t, the receiver base type of this
// method is the named type t. For an uninstantiated generic type t, each
// method receiver is instantiated with its receiver type parameters.
func (t *Named) Method(i int) *Func {
t.resolve()
if t.state() >= complete {
return t.methods[i]
}
assert(t.inst != nil) // only instances should have incomplete methods
orig := t.inst.orig
t.mu.Lock()
defer t.mu.Unlock()
if len(t.methods) != len(orig.methods) {
assert(len(t.methods) == 0)
t.methods = make([]*Func, len(orig.methods))
}
if t.methods[i] == nil {
assert(t.inst.ctxt != nil) // we should still have a context remaining from the resolution phase
t.methods[i] = t.expandMethod(i)
t.inst.expandedMethods++
// Check if we've created all methods at this point. If we have, mark the
// type as fully expanded.
if t.inst.expandedMethods == len(orig.methods) {
t.setState(complete)
t.inst.ctxt = nil // no need for a context anymore
}
}
return t.methods[i]
}
// expandMethod substitutes type arguments in the i'th method for an
// instantiated receiver.
func (t *Named) expandMethod(i int) *Func {
// t.orig.methods is not lazy. origm is the method instantiated with its
// receiver type parameters (the "origin" method).
origm := t.inst.orig.Method(i)
assert(origm != nil)
check := t.check
// Ensure that the original method is type-checked.
if check != nil {
check.objDecl(origm, nil)
}
origSig := origm.typ.(*Signature)
rbase, _ := deref(origSig.Recv().Type())
// If rbase is t, then origm is already the instantiated method we're looking
// for. In this case, we return origm to preserve the invariant that
// traversing Method->Receiver Type->Method should get back to the same
// method.
//
// This occurs if t is instantiated with the receiver type parameters, as in
// the use of m in func (r T[_]) m() { r.m() }.
if rbase == t {
return origm
}
sig := origSig
// We can only substitute if we have a correspondence between type arguments
// and type parameters. This check is necessary in the presence of invalid
// code.
if origSig.RecvTypeParams().Len() == t.inst.targs.Len() {
smap := makeSubstMap(origSig.RecvTypeParams().list(), t.inst.targs.list())
var ctxt *Context
if check != nil {
ctxt = check.context()
}
sig = check.subst(origm.pos, origSig, smap, t, ctxt).(*Signature)
}
if sig == origSig {
// No substitution occurred, but we still need to create a new signature to
// hold the instantiated receiver.
copy := *origSig
sig = &copy
}
var rtyp Type
if origm.hasPtrRecv() {
rtyp = NewPointer(t)
} else {
rtyp = t
}
sig.recv = substVar(origSig.recv, rtyp)
return substFunc(origm, sig)
}
// SetUnderlying sets the underlying type and marks t as complete.
// t must not have type arguments.
func (t *Named) SetUnderlying(underlying Type) {
assert(t.inst == nil)
if underlying == nil {
panic("underlying type must not be nil")
}
if _, ok := underlying.(*Named); ok {
panic("underlying type must not be *Named")
}
t.resolve().underlying = underlying
if t.fromRHS == nil {
t.fromRHS = underlying // for cycle detection
}
}
// AddMethod adds method m unless it is already in the method list.
// t must not have type arguments.
func (t *Named) AddMethod(m *Func) {
assert(t.inst == nil)
t.resolve()
if i, _ := lookupMethod(t.methods, m.pkg, m.name, false); i < 0 {
t.methods = append(t.methods, m)
}
}
func (t *Named) Underlying() Type { return t.resolve().underlying }
func (t *Named) String() string { return TypeString(t, nil) }
// ----------------------------------------------------------------------------
// Implementation
//
// TODO(rfindley): reorganize the loading and expansion methods under this
// heading.
// under returns the expanded underlying type of n0; possibly by following
// forward chains of named types. If an underlying type is found, resolve
// the chain by setting the underlying type for each defined type in the
// chain before returning it. If no underlying type is found or a cycle
// is detected, the result is Typ[Invalid]. If a cycle is detected and
// n0.check != nil, the cycle is reported.
//
// This is necessary because the underlying type of a named type may itself be a
// named type that is incomplete:
//
// type (
// A B
// B *C
// C A
// )
//
// The type of C is the (named) type of A which is incomplete,
// and which has as its underlying type the named type B.
func (n0 *Named) under() Type {
u := n0.Underlying()
// If the underlying type of a defined type is not a defined
// (incl. instance) type, then that is the desired underlying
// type.
var n1 *Named
switch u1 := u.(type) {
case nil:
// After expansion via Underlying(), we should never encounter a nil
// underlying.
panic("nil underlying")
default:
// common case
return u
case *Named:
// handled below
n1 = u1
}
if n0.check == nil {
panic("Named.check == nil but type is incomplete")
}
// Invariant: after this point n0 as well as any named types in its
// underlying chain should be set up when this function exits.
check := n0.check
n := n0
seen := make(map[*Named]int) // types that need their underlying type resolved
var path []Object // objects encountered, for cycle reporting
loop:
for {
seen[n] = len(seen)
path = append(path, n.obj)
n = n1
if i, ok := seen[n]; ok {
// cycle
check.cycleError(path[i:])
u = Typ[Invalid]
break
}
u = n.Underlying()
switch u1 := u.(type) {
case nil:
u = Typ[Invalid]
break loop
default:
break loop
case *Named:
// Continue collecting *Named types in the chain.
n1 = u1
}
}
for n := range seen {
// We should never have to update the underlying type of an imported type;
// those underlying types should have been resolved during the import.
// Also, doing so would lead to a race condition (was go.dev/issue/31749).
// Do this check always, not just in debug mode (it's cheap).
if n.obj.pkg != check.pkg {
panic("imported type with unresolved underlying type")
}
n.underlying = u
}
return u
}
func (n *Named) setUnderlying(typ Type) {
if n != nil {
n.underlying = typ
}
}
func (n *Named) lookupMethod(pkg *Package, name string, foldCase bool) (int, *Func) {
n.resolve()
// If n is an instance, we may not have yet instantiated all of its methods.
// Look up the method index in orig, and only instantiate method at the
// matching index (if any).
i, _ := lookupMethod(n.Origin().methods, pkg, name, foldCase)
if i < 0 {
return -1, nil
}
// For instances, m.Method(i) will be different from the orig method.
return i, n.Method(i)
}
// context returns the type-checker context.
func (check *Checker) context() *Context {
if check.ctxt == nil {
check.ctxt = NewContext()
}
return check.ctxt
}
// expandUnderlying substitutes type arguments in the underlying type n.orig,
// returning the result. Returns Typ[Invalid] if there was an error.
func (n *Named) expandUnderlying() Type {
check := n.check
if check != nil && check.conf._Trace {
check.trace(n.obj.pos, "-- Named.expandUnderlying %s", n)
check.indent++
defer func() {
check.indent--
check.trace(n.obj.pos, "=> %s (tparams = %s, under = %s)", n, n.tparams.list(), n.underlying)
}()
}
assert(n.inst.orig.underlying != nil)
if n.inst.ctxt == nil {
n.inst.ctxt = NewContext()
}
orig := n.inst.orig
targs := n.inst.targs
if _, unexpanded := orig.underlying.(*Named); unexpanded {
// We should only get a Named underlying type here during type checking
// (for example, in recursive type declarations).
assert(check != nil)
}
if orig.tparams.Len() != targs.Len() {
// Mismatching arg and tparam length may be checked elsewhere.
return Typ[Invalid]
}
// Ensure that an instance is recorded before substituting, so that we
// resolve n for any recursive references.
h := n.inst.ctxt.instanceHash(orig, targs.list())
n2 := n.inst.ctxt.update(h, orig, n.TypeArgs().list(), n)
assert(n == n2)
smap := makeSubstMap(orig.tparams.list(), targs.list())
var ctxt *Context
if check != nil {
ctxt = check.context()
}
underlying := n.check.subst(n.obj.pos, orig.underlying, smap, n, ctxt)
// If the underlying type of n is an interface, we need to set the receiver of
// its methods accurately -- we set the receiver of interface methods on
// the RHS of a type declaration to the defined type.
if iface, _ := underlying.(*Interface); iface != nil {
if methods, copied := replaceRecvType(iface.methods, orig, n); copied {
// If the underlying type doesn't actually use type parameters, it's
// possible that it wasn't substituted. In this case we need to create
// a new *Interface before modifying receivers.
if iface == orig.underlying {
old := iface
iface = check.newInterface()
iface.embeddeds = old.embeddeds
iface.complete = old.complete
iface.implicit = old.implicit // should be false but be conservative
underlying = iface
}
iface.methods = methods
}
}
return underlying
}
// safeUnderlying returns the underlying type of typ without expanding
// instances, to avoid infinite recursion.
//
// TODO(rfindley): eliminate this function or give it a better name.
func safeUnderlying(typ Type) Type {
if t, _ := typ.(*Named); t != nil {
return t.underlying
}
return typ.Underlying()
}
// Code generated by "go test -run=Generate -write=all"; DO NOT EDIT.
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package types
import (
"bytes"
"fmt"
"go/constant"
"go/token"
"unicode"
"unicode/utf8"
)
// An Object describes a named language entity such as a package,
// constant, type, variable, function (incl. methods), or label.
// All objects implement the Object interface.
type Object interface {
Parent() *Scope // scope in which this object is declared; nil for methods and struct fields
Pos() token.Pos // position of object identifier in declaration
Pkg() *Package // package to which this object belongs; nil for labels and objects in the Universe scope
Name() string // package local object name
Type() Type // object type
Exported() bool // reports whether the name starts with a capital letter
Id() string // object name if exported, qualified name if not exported (see func Id)
// String returns a human-readable string of the object.
String() string
// order reflects a package-level object's source order: if object
// a is before object b in the source, then a.order() < b.order().
// order returns a value > 0 for package-level objects; it returns
// 0 for all other objects (including objects in file scopes).
order() uint32
// color returns the object's color.
color() color
// setType sets the type of the object.
setType(Type)
// setOrder sets the order number of the object. It must be > 0.
setOrder(uint32)
// setColor sets the object's color. It must not be white.
setColor(color color)
// setParent sets the parent scope of the object.
setParent(*Scope)
// sameId reports whether obj.Id() and Id(pkg, name) are the same.
sameId(pkg *Package, name string) bool
// scopePos returns the start position of the scope of this Object
scopePos() token.Pos
// setScopePos sets the start position of the scope for this Object.
setScopePos(pos token.Pos)
}
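// Editorial sketch of how clients typically discriminate Objects (not
// part of the original source):
//
//	switch obj := obj.(type) {
//	case *types.Var:
//		// variable, parameter, or struct field
//	case *types.Func:
//		_ = obj.Type().(*types.Signature) // always a signature
//	case *types.TypeName:
//		// name of a defined or alias type
//	case *types.Const:
//		_ = obj.Val() // the constant's value
//	}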
func isExported(name string) bool {
ch, _ := utf8.DecodeRuneInString(name)
return unicode.IsUpper(ch)
}
// Id returns name if it is exported, otherwise it
// returns the name qualified with the package path.
func Id(pkg *Package, name string) string {
if isExported(name) {
return name
}
// unexported names need the package path for differentiation
// (if there's no package, make sure we don't start with '.'
// as that may change the order of methods between a setup
// inside a package and outside a package - which breaks some
// tests)
path := "_"
// pkg is nil for objects in Universe scope and possibly types
// introduced via Eval (see also comment in object.sameId)
if pkg != nil && pkg.path != "" {
path = pkg.path
}
return path + "." + name
}
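// For example (editorial sketch; pkg is an assumed *Package with path
// "example.com/p"):
//
//	Id(pkg, "Foo") // "Foo": exported, unqualified
//	Id(pkg, "foo") // "example.com/p.foo": unexported, path-qualified
//	Id(nil, "foo") // "_.foo": no package, placeholder path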
// An object implements the common parts of an Object.
type object struct {
parent *Scope
pos token.Pos
pkg *Package
name string
typ Type
order_ uint32
color_ color
scopePos_ token.Pos
}
// color encodes the color of an object (see Checker.objDecl for details).
type color uint32
// An object may be painted in one of three colors.
// Color values other than white or black are considered grey.
const (
white color = iota
black
grey // must be > white and black
)
func (c color) String() string {
switch c {
case white:
return "white"
case black:
return "black"
default:
return "grey"
}
}
// colorFor returns the (initial) color for an object depending on
// whether its type t is known or not.
func colorFor(t Type) color {
if t != nil {
return black
}
return white
}
// Parent returns the scope in which the object is declared.
// The result is nil for methods and struct fields.
func (obj *object) Parent() *Scope { return obj.parent }
// Pos returns the declaration position of the object's identifier.
func (obj *object) Pos() token.Pos { return obj.pos }
// Pkg returns the package to which the object belongs.
// The result is nil for labels and objects in the Universe scope.
func (obj *object) Pkg() *Package { return obj.pkg }
// Name returns the object's (package-local, unqualified) name.
func (obj *object) Name() string { return obj.name }
// Type returns the object's type.
func (obj *object) Type() Type { return obj.typ }
// Exported reports whether the object is exported (starts with a capital letter).
// It doesn't take into account whether the object is in a local (function) scope
// or not.
func (obj *object) Exported() bool { return isExported(obj.name) }
// Id is a wrapper for Id(obj.Pkg(), obj.Name()).
func (obj *object) Id() string { return Id(obj.pkg, obj.name) }
func (obj *object) String() string { panic("abstract") }
func (obj *object) order() uint32 { return obj.order_ }
func (obj *object) color() color { return obj.color_ }
func (obj *object) scopePos() token.Pos { return obj.scopePos_ }
func (obj *object) setParent(parent *Scope) { obj.parent = parent }
func (obj *object) setType(typ Type) { obj.typ = typ }
func (obj *object) setOrder(order uint32) { assert(order > 0); obj.order_ = order }
func (obj *object) setColor(color color) { assert(color != white); obj.color_ = color }
func (obj *object) setScopePos(pos token.Pos) { obj.scopePos_ = pos }
func (obj *object) sameId(pkg *Package, name string) bool {
// spec:
// "Two identifiers are different if they are spelled differently,
// or if they appear in different packages and are not exported.
// Otherwise, they are the same."
if name != obj.name {
return false
}
// obj.Name == name
if obj.Exported() {
return true
}
// not exported, so packages must be the same (pkg == nil for
// fields in Universe scope; this can only happen for types
// introduced via Eval)
if pkg == nil || obj.pkg == nil {
return pkg == obj.pkg
}
// pkg != nil && obj.pkg != nil
return pkg.path == obj.pkg.path
}
// less reports whether object a is ordered before object b.
//
// Objects are ordered nil before non-nil, exported before
// non-exported, then by name, and finally (for non-exported
// functions) by package path.
func (a *object) less(b *object) bool {
if a == b {
return false
}
// Nil before non-nil.
if a == nil {
return true
}
if b == nil {
return false
}
// Exported functions before non-exported.
ea := isExported(a.name)
eb := isExported(b.name)
if ea != eb {
return ea
}
// Order by name and then (for non-exported names) by package.
if a.name != b.name {
return a.name < b.name
}
if !ea {
return a.pkg.path < b.pkg.path
}
return false
}
// A PkgName represents an imported Go package.
// PkgNames don't have a type.
type PkgName struct {
object
imported *Package
used bool // set if the package was used
}
// NewPkgName returns a new PkgName object representing an imported package.
// The remaining arguments set the attributes found with all Objects.
func NewPkgName(pos token.Pos, pkg *Package, name string, imported *Package) *PkgName {
return &PkgName{object{nil, pos, pkg, name, Typ[Invalid], 0, black, nopos}, imported, false}
}
// Imported returns the package that was imported.
// It is distinct from Pkg(), which is the package containing the import statement.
func (obj *PkgName) Imported() *Package { return obj.imported }
// A Const represents a declared constant.
type Const struct {
object
val constant.Value
}
// NewConst returns a new constant with value val.
// The remaining arguments set the attributes found with all Objects.
func NewConst(pos token.Pos, pkg *Package, name string, typ Type, val constant.Value) *Const {
return &Const{object{nil, pos, pkg, name, typ, 0, colorFor(typ), nopos}, val}
}
// Val returns the constant's value.
func (obj *Const) Val() constant.Value { return obj.val }
func (*Const) isDependency() {} // a constant may be a dependency of an initialization expression
// A TypeName represents a name for a (defined or alias) type.
type TypeName struct {
object
}
// NewTypeName returns a new type name denoting the given typ.
// The remaining arguments set the attributes found with all Objects.
//
// The typ argument may be a defined (Named) type or an alias type.
// It may also be nil such that the returned TypeName can be used as
// argument for NewNamed, which will set the TypeName's type as a side-
// effect.
func NewTypeName(pos token.Pos, pkg *Package, name string, typ Type) *TypeName {
return &TypeName{object{nil, pos, pkg, name, typ, 0, colorFor(typ), nopos}}
}
// _NewTypeNameLazy returns a new defined type like NewTypeName, but it
// lazily calls resolve to finish constructing the Named object.
func _NewTypeNameLazy(pos token.Pos, pkg *Package, name string, load func(named *Named) (tparams []*TypeParam, underlying Type, methods []*Func)) *TypeName {
obj := NewTypeName(pos, pkg, name, nil)
NewNamed(obj, nil, nil).loader = load
return obj
}
// IsAlias reports whether obj is an alias name for a type.
func (obj *TypeName) IsAlias() bool {
switch t := obj.typ.(type) {
case nil:
return false
case *Basic:
// unsafe.Pointer is not an alias.
if obj.pkg == Unsafe {
return false
}
// Any user-defined type name for a basic type is an alias for a
// basic type (because basic types are pre-declared in the Universe
// scope, outside any package scope), and so is any type name with
// a different name than the name of the basic type it refers to.
// Additionally, we need to look for "byte" and "rune" because they
// are aliases but have the same names (for better error messages).
return obj.pkg != nil || t.name != obj.name || t == universeByte || t == universeRune
case *Named:
return obj != t.obj
case *TypeParam:
return obj != t.obj
default:
return true
}
}
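// Editorial examples of IsAlias results (client-side view, not part of
// the original source):
//
//	types.Universe.Lookup("rune").(*types.TypeName).IsAlias() // true: rune = int32
//	types.Universe.Lookup("int").(*types.TypeName).IsAlias()  // false
//
// For a declaration "type Celsius float64" IsAlias reports false, while
// for "type Temp = float64" it reports true.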
// A Variable represents a declared variable (including function parameters and results, and struct fields).
type Var struct {
object
embedded bool // if set, the variable is an embedded struct field, and name is the type name
isField bool // var is struct field
used bool // set if the variable was used
origin *Var // if non-nil, the Var from which this one was instantiated
}
// NewVar returns a new variable.
// The arguments set the attributes found with all Objects.
func NewVar(pos token.Pos, pkg *Package, name string, typ Type) *Var {
return &Var{object: object{nil, pos, pkg, name, typ, 0, colorFor(typ), nopos}}
}
// NewParam returns a new variable representing a function parameter.
func NewParam(pos token.Pos, pkg *Package, name string, typ Type) *Var {
return &Var{object: object{nil, pos, pkg, name, typ, 0, colorFor(typ), nopos}, used: true} // parameters are always 'used'
}
// NewField returns a new variable representing a struct field.
// For embedded fields, the name is the unqualified type name
// under which the field is accessible.
func NewField(pos token.Pos, pkg *Package, name string, typ Type, embedded bool) *Var {
return &Var{object: object{nil, pos, pkg, name, typ, 0, colorFor(typ), nopos}, embedded: embedded, isField: true}
}
// Anonymous reports whether the variable is an embedded field.
// Same as Embedded; only present for backward-compatibility.
func (obj *Var) Anonymous() bool { return obj.embedded }
// Embedded reports whether the variable is an embedded field.
func (obj *Var) Embedded() bool { return obj.embedded }
// IsField reports whether the variable is a struct field.
func (obj *Var) IsField() bool { return obj.isField }
// Origin returns the canonical Var for its receiver, i.e. the Var object
// recorded in Info.Defs.
//
// For synthetic Vars created during instantiation (such as struct fields or
// function parameters that depend on type arguments), this will be the
// corresponding Var on the generic (uninstantiated) type. For all other Vars
// Origin returns the receiver.
func (obj *Var) Origin() *Var {
if obj.origin != nil {
return obj.origin
}
return obj
}
func (*Var) isDependency() {} // a variable may be a dependency of an initialization expression
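// Illustrative sketch (readerTyp is a placeholder for some Type): the three
// constructors distinguish plain variables, parameters, and struct fields.
//
// v := NewVar(nopos, pkg, "v", Typ[Int])               // var v int
// p := NewParam(nopos, pkg, "x", Typ[Int])             // parameter x int; marked used
// f := NewField(nopos, pkg, "Reader", readerTyp, true) // embedded field
// f.Embedded() // true
// f.IsField()  // true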
// A Func represents a declared function, concrete method, or abstract
// (interface) method. Its Type() is always a *Signature.
// An abstract method may belong to many interfaces due to embedding.
type Func struct {
object
hasPtrRecv_ bool // only valid for methods that don't have a type yet; use hasPtrRecv() to read
origin *Func // if non-nil, the Func from which this one was instantiated
}
// NewFunc returns a new function with the given signature, representing
// the function's type.
func NewFunc(pos token.Pos, pkg *Package, name string, sig *Signature) *Func {
// don't store a (typed) nil signature
var typ Type
if sig != nil {
typ = sig
}
return &Func{object{nil, pos, pkg, name, typ, 0, colorFor(typ), nopos}, false, nil}
}
// FullName returns the package- or receiver-type-qualified name of
// function or method obj.
func (obj *Func) FullName() string {
var buf bytes.Buffer
writeFuncName(&buf, obj, nil)
return buf.String()
}
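// Illustrative sketch (hypothetical package path): a package-level function
// with no receiver is qualified by its package.
//
// pkg := NewPackage("example.com/p", "p")
// sig := NewSignatureType(nil, nil, nil, nil, nil, false) // func()
// f := NewFunc(nopos, pkg, "F", sig)
// f.FullName() // "example.com/p.F"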
// Scope returns the scope of the function's body block.
// The result is nil for imported or instantiated functions and methods
// (but there is also no mechanism to get to an instantiated function).
func (obj *Func) Scope() *Scope { return obj.typ.(*Signature).scope }
// Origin returns the canonical Func for its receiver, i.e. the Func object
// recorded in Info.Defs.
//
// For synthetic functions created during instantiation (such as methods on an
// instantiated Named type or interface methods that depend on type arguments),
// this will be the corresponding Func on the generic (uninstantiated) type.
// For all other Funcs Origin returns the receiver.
func (obj *Func) Origin() *Func {
if obj.origin != nil {
return obj.origin
}
return obj
}
// hasPtrRecv reports whether the receiver is of the form *T for the given method obj.
func (obj *Func) hasPtrRecv() bool {
// If a method's receiver type is set, use that as the source of truth for the receiver.
// Caution: Checker.funcDecl (decl.go) marks a function by setting its type to an empty
// signature. We may reach here before the signature is fully set up: we must explicitly
// check if the receiver is set (we cannot just look for non-nil obj.typ).
if sig, _ := obj.typ.(*Signature); sig != nil && sig.recv != nil {
_, isPtr := deref(sig.recv.typ)
return isPtr
}
// If a method's type is not set it may be a method/function that is:
// 1) client-supplied (via NewFunc with no signature), or
// 2) internally created but not yet type-checked.
// For case 1) we can't do anything; the client must know what they are doing.
// For case 2) we can use the information gathered by the resolver.
return obj.hasPtrRecv_
}
func (*Func) isDependency() {} // a function may be a dependency of an initialization expression
// A Label represents a declared label.
// Labels don't have a type.
type Label struct {
object
used bool // set if the label was used
}
// NewLabel returns a new label.
func NewLabel(pos token.Pos, pkg *Package, name string) *Label {
return &Label{object{pos: pos, pkg: pkg, name: name, typ: Typ[Invalid], color_: black}, false}
}
// A Builtin represents a built-in function.
// Builtins don't have a valid type.
type Builtin struct {
object
id builtinId
}
func newBuiltin(id builtinId) *Builtin {
return &Builtin{object{name: predeclaredFuncs[id].name, typ: Typ[Invalid], color_: black}, id}
}
// Nil represents the predeclared value nil.
type Nil struct {
object
}
func writeObject(buf *bytes.Buffer, obj Object, qf Qualifier) {
var tname *TypeName
typ := obj.Type()
switch obj := obj.(type) {
case *PkgName:
fmt.Fprintf(buf, "package %s", obj.Name())
if path := obj.imported.path; path != "" && path != obj.name {
fmt.Fprintf(buf, " (%q)", path)
}
return
case *Const:
buf.WriteString("const")
case *TypeName:
tname = obj
buf.WriteString("type")
if isTypeParam(typ) {
buf.WriteString(" parameter")
}
case *Var:
if obj.isField {
buf.WriteString("field")
} else {
buf.WriteString("var")
}
case *Func:
buf.WriteString("func ")
writeFuncName(buf, obj, qf)
if typ != nil {
WriteSignature(buf, typ.(*Signature), qf)
}
return
case *Label:
buf.WriteString("label")
typ = nil
case *Builtin:
buf.WriteString("builtin")
typ = nil
case *Nil:
buf.WriteString("nil")
return
default:
panic(fmt.Sprintf("writeObject(%T)", obj))
}
buf.WriteByte(' ')
// For package-level objects, qualify the name.
if obj.Pkg() != nil && obj.Pkg().scope.Lookup(obj.Name()) == obj {
buf.WriteString(packagePrefix(obj.Pkg(), qf))
}
buf.WriteString(obj.Name())
if typ == nil {
return
}
if tname != nil {
switch t := typ.(type) {
case *Basic:
// Don't print anything more for basic types since there's
// no more information.
return
case *Named:
if t.TypeParams().Len() > 0 {
newTypeWriter(buf, qf).tParamList(t.TypeParams().list())
}
}
if tname.IsAlias() {
buf.WriteString(" =")
} else if t, _ := typ.(*TypeParam); t != nil {
typ = t.bound
} else {
// TODO(gri) should this be fromRHS for *Named?
typ = under(typ)
}
}
// Special handling for any: because WriteType will format 'any' as 'any',
// resulting in the object string `type any = any` rather than `type any =
// interface{}`. To avoid this, swap in a different empty interface.
if obj == universeAny {
assert(Identical(typ, &emptyInterface))
typ = &emptyInterface
}
buf.WriteByte(' ')
WriteType(buf, typ, qf)
}
func packagePrefix(pkg *Package, qf Qualifier) string {
if pkg == nil {
return ""
}
var s string
if qf != nil {
s = qf(pkg)
} else {
s = pkg.Path()
}
if s != "" {
s += "."
}
return s
}
// ObjectString returns the string form of obj.
// The Qualifier controls the printing of
// package-level objects, and may be nil.
func ObjectString(obj Object, qf Qualifier) string {
var buf bytes.Buffer
writeObject(&buf, obj, qf)
return buf.String()
}
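// Illustrative sketch: formatting the constant c from the NewConst example
// above. The constant is not declared in a package scope, so no package
// prefix is printed.
//
// ObjectString(c, nil) // "const answer int"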
func (obj *PkgName) String() string { return ObjectString(obj, nil) }
func (obj *Const) String() string { return ObjectString(obj, nil) }
func (obj *TypeName) String() string { return ObjectString(obj, nil) }
func (obj *Var) String() string { return ObjectString(obj, nil) }
func (obj *Func) String() string { return ObjectString(obj, nil) }
func (obj *Label) String() string { return ObjectString(obj, nil) }
func (obj *Builtin) String() string { return ObjectString(obj, nil) }
func (obj *Nil) String() string { return ObjectString(obj, nil) }
func writeFuncName(buf *bytes.Buffer, f *Func, qf Qualifier) {
if f.typ != nil {
sig := f.typ.(*Signature)
if recv := sig.Recv(); recv != nil {
buf.WriteByte('(')
if _, ok := recv.Type().(*Interface); ok {
// gcimporter creates abstract methods of
// named interfaces using the interface type
// (not the named type) as the receiver.
// Don't print it in full.
buf.WriteString("interface")
} else {
WriteType(buf, recv.Type(), qf)
}
buf.WriteByte(')')
buf.WriteByte('.')
} else if f.pkg != nil {
buf.WriteString(packagePrefix(f.pkg, qf))
}
}
buf.WriteString(f.name)
}
// Code generated by "go test -run=Generate -write=all"; DO NOT EDIT.
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements objsets.
//
// An objset is similar to a Scope but objset elements
// are identified by their unique id, instead of their
// object name.
package types
// An objset is a set of objects identified by their unique id.
// The zero value for objset is a ready-to-use empty objset.
type objset map[string]Object // initialized lazily
// insert attempts to insert an object obj into objset s.
// If s already contains an alternative object alt with
// the same name, insert leaves s unchanged and returns alt.
// Otherwise it inserts obj and returns nil.
func (s *objset) insert(obj Object) Object {
id := obj.Id()
if alt := (*s)[id]; alt != nil {
return alt
}
if *s == nil {
*s = make(map[string]Object)
}
(*s)[id] = obj
return nil
}
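// Usage sketch (in-package; the names here are hypothetical): rejecting
// duplicate methods while building a method set.
//
// var mset objset
// for _, m := range methodList {
//     if alt := mset.insert(m); alt != nil {
//         // a method with the same unique id already exists; report alt
//     }
// }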
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file defines operands and associated operations.
package types
import (
"bytes"
"go/ast"
"go/constant"
"go/token"
. "internal/types/errors"
)
// An operandMode specifies the (addressing) mode of an operand.
type operandMode byte
const (
invalid operandMode = iota // operand is invalid
novalue // operand represents no value (result of a function call w/o result)
builtin // operand is a built-in function
typexpr // operand is a type
constant_ // operand is a constant; the operand's typ is a Basic type
variable // operand is an addressable variable
mapindex // operand is a map index expression (acts like a variable on lhs, commaok on rhs of an assignment)
value // operand is a computed value
commaok // like value, but operand may be used in a comma,ok expression
commaerr // like commaok, but second value is error, not boolean
cgofunc // operand is a cgo function
)
var operandModeString = [...]string{
invalid: "invalid operand",
novalue: "no value",
builtin: "built-in",
typexpr: "type",
constant_: "constant",
variable: "variable",
mapindex: "map index expression",
value: "value",
commaok: "comma, ok expression",
commaerr: "comma, error expression",
cgofunc: "cgo function",
}
// An operand represents an intermediate value during type checking.
// Operands have an (addressing) mode, the expression evaluating to
// the operand, the operand's type, a value for constants, and an id
// for built-in functions.
// The zero value of operand is a ready-to-use invalid operand.
type operand struct {
mode operandMode
expr ast.Expr
typ Type
val constant.Value
id builtinId
}
// Pos returns the position of the expression corresponding to x.
// If x is invalid the position is nopos.
func (x *operand) Pos() token.Pos {
// x.expr may not be set if x is invalid
if x.expr == nil {
return nopos
}
return x.expr.Pos()
}
// Operand string formats
// (not all "untyped" cases can appear due to the type system,
// but they fall out naturally here)
//
// mode format
//
// invalid <expr> ( <mode> )
// novalue <expr> ( <mode> )
// builtin <expr> ( <mode> )
// typexpr <expr> ( <mode> )
//
// constant <expr> (<untyped kind> <mode> )
// constant <expr> ( <mode> of type <typ>)
// constant <expr> (<untyped kind> <mode> <val> )
// constant <expr> ( <mode> <val> of type <typ>)
//
// variable <expr> (<untyped kind> <mode> )
// variable <expr> ( <mode> of type <typ>)
//
// mapindex <expr> (<untyped kind> <mode> )
// mapindex <expr> ( <mode> of type <typ>)
//
// value <expr> (<untyped kind> <mode> )
// value <expr> ( <mode> of type <typ>)
//
// commaok <expr> (<untyped kind> <mode> )
// commaok <expr> ( <mode> of type <typ>)
//
// commaerr <expr> (<untyped kind> <mode> )
// commaerr <expr> ( <mode> of type <typ>)
//
// cgofunc <expr> (<untyped kind> <mode> )
// cgofunc <expr> ( <mode> of type <typ>)
func operandString(x *operand, qf Qualifier) string {
// special-case nil
if x.mode == value && x.typ == Typ[UntypedNil] {
return "nil"
}
var buf bytes.Buffer
var expr string
if x.expr != nil {
expr = ExprString(x.expr)
} else {
switch x.mode {
case builtin:
expr = predeclaredFuncs[x.id].name
case typexpr:
expr = TypeString(x.typ, qf)
case constant_:
expr = x.val.String()
}
}
// <expr> (
if expr != "" {
buf.WriteString(expr)
buf.WriteString(" (")
}
// <untyped kind>
hasType := false
switch x.mode {
case invalid, novalue, builtin, typexpr:
// no type
default:
// should have a type, but be cautious (don't crash during printing)
if x.typ != nil {
if isUntyped(x.typ) {
buf.WriteString(x.typ.(*Basic).name)
buf.WriteByte(' ')
break
}
hasType = true
}
}
// <mode>
buf.WriteString(operandModeString[x.mode])
// <val>
if x.mode == constant_ {
if s := x.val.String(); s != expr {
buf.WriteByte(' ')
buf.WriteString(s)
}
}
// <typ>
if hasType {
if x.typ != Typ[Invalid] {
var intro string
if isGeneric(x.typ) {
intro = " of generic type "
} else {
intro = " of type "
}
buf.WriteString(intro)
WriteType(&buf, x.typ, qf)
if tpar, _ := x.typ.(*TypeParam); tpar != nil {
buf.WriteString(" constrained by ")
WriteType(&buf, tpar.bound, qf) // do not compute interface type sets here
// If we have the type set and it's empty, say so for better error messages.
if hasEmptyTypeset(tpar) {
buf.WriteString(" with empty type set")
}
}
} else {
buf.WriteString(" with invalid type")
}
}
// )
if expr != "" {
buf.WriteByte(')')
}
return buf.String()
}
func (x *operand) String() string {
return operandString(x, nil)
}
// setConst sets x to the untyped constant for literal lit.
func (x *operand) setConst(tok token.Token, lit string) {
var kind BasicKind
switch tok {
case token.INT:
kind = UntypedInt
case token.FLOAT:
kind = UntypedFloat
case token.IMAG:
kind = UntypedComplex
case token.CHAR:
kind = UntypedRune
case token.STRING:
kind = UntypedString
default:
unreachable()
}
val := constant.MakeFromLiteral(lit, tok, 0)
if val.Kind() == constant.Unknown {
x.mode = invalid
x.typ = Typ[Invalid]
return
}
x.mode = constant_
x.typ = Typ[kind]
x.val = val
}
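// Usage sketch (in-package): initializing an operand from an integer literal.
//
// var x operand
// x.setConst(token.INT, "42")
// // x.mode == constant_, x.typ == Typ[UntypedInt], x.val represents 42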
// isNil reports whether x is the nil value.
func (x *operand) isNil() bool {
return x.mode == value && x.typ == Typ[UntypedNil]
}
// assignableTo reports whether x is assignable to a variable of type T. If the
// result is false and a non-nil cause is provided, it may be set to a more
// detailed explanation of the failure (in which case *cause != ""). The returned error code
// is only valid if the (first) result is false. The check parameter may be nil
// if assignableTo is invoked through an exported API call, i.e., when all
// methods have been type-checked.
func (x *operand) assignableTo(check *Checker, T Type, cause *string) (bool, Code) {
if x.mode == invalid || T == Typ[Invalid] {
return true, 0 // avoid spurious errors
}
V := x.typ
// x's type is identical to T
if Identical(V, T) {
return true, 0
}
Vu := under(V)
Tu := under(T)
Vp, _ := V.(*TypeParam)
Tp, _ := T.(*TypeParam)
// x is an untyped value representable by a value of type T.
if isUntyped(Vu) {
assert(Vp == nil)
if Tp != nil {
// T is a type parameter: x is assignable to T if it is
// representable by each specific type in the type set of T.
return Tp.is(func(t *term) bool {
if t == nil {
return false
}
// A term may be a tilde term but the underlying
// type of an untyped value doesn't change so we
// don't need to do anything special.
newType, _, _ := check.implicitTypeAndValue(x, t.typ)
return newType != nil
}), IncompatibleAssign
}
newType, _, _ := check.implicitTypeAndValue(x, T)
return newType != nil, IncompatibleAssign
}
// Vu is typed
// x's type V and T have identical underlying types
// and at least one of V or T is not a named type
// and neither V nor T is a type parameter.
if Identical(Vu, Tu) && (!hasName(V) || !hasName(T)) && Vp == nil && Tp == nil {
return true, 0
}
// T is an interface type and x implements T and T is not a type parameter.
// Also handle the case where T is a pointer to an interface.
if _, ok := Tu.(*Interface); ok && Tp == nil || isInterfacePtr(Tu) {
if !check.implements(V, T, false, cause) {
return false, InvalidIfaceAssign
}
return true, 0
}
// If V is an interface, check if a missing type assertion is the problem.
if Vi, _ := Vu.(*Interface); Vi != nil && Vp == nil {
if check.implements(T, V, false, nil) {
// T implements V, so give hint about type assertion.
if cause != nil {
*cause = "need type assertion"
}
return false, IncompatibleAssign
}
}
// x is a bidirectional channel value, T is a channel
// type, x's type V and T have identical element types,
// and at least one of V or T is not a named type.
if Vc, ok := Vu.(*Chan); ok && Vc.dir == SendRecv {
if Tc, ok := Tu.(*Chan); ok && Identical(Vc.elem, Tc.elem) {
return !hasName(V) || !hasName(T), InvalidChanAssign
}
}
// optimization: if we don't have type parameters, we're done
if Vp == nil && Tp == nil {
return false, IncompatibleAssign
}
errorf := func(format string, args ...any) {
if check != nil && cause != nil {
msg := check.sprintf(format, args...)
if *cause != "" {
msg += "\n\t" + *cause
}
*cause = msg
}
}
// x's type V is not a named type and T is a type parameter, and
// x is assignable to each specific type in T's type set.
if !hasName(V) && Tp != nil {
ok := false
code := IncompatibleAssign
Tp.is(func(T *term) bool {
if T == nil {
return false // no specific types
}
ok, code = x.assignableTo(check, T.typ, cause)
if !ok {
errorf("cannot assign %s to %s (in %s)", x.typ, T.typ, Tp)
return false
}
return true
})
return ok, code
}
// x's type V is a type parameter and T is not a named type,
// and values x' of each specific type in V's type set are
// assignable to T.
if Vp != nil && !hasName(T) {
x := *x // don't clobber outer x
ok := false
code := IncompatibleAssign
Vp.is(func(V *term) bool {
if V == nil {
return false // no specific types
}
x.typ = V.typ
ok, code = x.assignableTo(check, T, cause)
if !ok {
errorf("cannot assign %s (in %s) to %s", V.typ, Vp, T)
return false
}
return true
})
return ok, code
}
return false, IncompatibleAssign
}
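// Illustrative sketch via the exported AssignableTo wrapper (defined
// elsewhere in this package), which calls assignableTo with a nil Checker:
//
// AssignableTo(NewSlice(Typ[Int]), NewSlice(Typ[Int]))                   // true: identical types
// AssignableTo(NewChan(SendRecv, Typ[Int]), NewChan(SendOnly, Typ[Int])) // true: bidirectional to send-only channel
// AssignableTo(Typ[Int], Typ[Int32])                                     // false: distinct types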
// Code generated by "go test -run=Generate -write=all"; DO NOT EDIT.
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package types
import (
"fmt"
)
// A Package describes a Go package.
type Package struct {
path string
name string
scope *Scope
imports []*Package
complete bool
fake bool // scope lookup errors are silently dropped if package is fake (internal use only)
cgo bool // uses of this package will be rewritten into uses of declarations from _cgo_gotypes.go
}
// NewPackage returns a new Package for the given package path and name.
// The package is not complete and contains no explicit imports.
func NewPackage(path, name string) *Package {
scope := NewScope(Universe, nopos, nopos, fmt.Sprintf("package %q", path))
return &Package{path: path, name: name, scope: scope}
}
// Path returns the package path.
func (pkg *Package) Path() string { return pkg.path }
// Name returns the package name.
func (pkg *Package) Name() string { return pkg.name }
// SetName sets the package name.
func (pkg *Package) SetName(name string) { pkg.name = name }
// Scope returns the (complete or incomplete) package scope
// holding the objects declared at package level (TypeNames,
// Consts, Vars, and Funcs).
// For a nil pkg receiver, Scope returns the Universe scope.
func (pkg *Package) Scope() *Scope {
if pkg != nil {
return pkg.scope
}
return Universe
}
// A package is complete if its scope contains (at least) all
// exported objects; otherwise it is incomplete.
func (pkg *Package) Complete() bool { return pkg.complete }
// MarkComplete marks a package as complete.
func (pkg *Package) MarkComplete() { pkg.complete = true }
// Imports returns the list of packages directly imported by
// pkg; the list is in source order.
//
// If pkg was loaded from export data, Imports includes packages that
// provide package-level objects referenced by pkg. This may be more or
// less than the set of packages directly imported by pkg's source code.
//
// If pkg uses cgo and the FakeImportC configuration option
// was enabled, the imports list may contain a fake "C" package.
func (pkg *Package) Imports() []*Package { return pkg.imports }
// SetImports sets the list of explicitly imported packages to list.
// It is the caller's responsibility to make sure list elements are unique.
func (pkg *Package) SetImports(list []*Package) { pkg.imports = list }
func (pkg *Package) String() string {
return fmt.Sprintf("package %s (%q)", pkg.name, pkg.path)
}
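// Illustrative sketch (hypothetical path and name):
//
// pkg := NewPackage("example.com/p", "p")
// pkg.Path()     // "example.com/p"
// pkg.Name()     // "p"
// pkg.Complete() // false, until MarkComplete is called
// pkg.Scope()    // the (initially empty) package scope, a child of Universe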
// Code generated by "go test -run=Generate -write=all"; DO NOT EDIT.
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package types
// A Pointer represents a pointer type.
type Pointer struct {
base Type // element type
}
// NewPointer returns a new pointer type for the given element (base) type.
func NewPointer(elem Type) *Pointer { return &Pointer{base: elem} }
// Elem returns the element type for the given pointer p.
func (p *Pointer) Elem() Type { return p.base }
func (p *Pointer) Underlying() Type { return p }
func (p *Pointer) String() string { return TypeString(p, nil) }
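// Illustrative sketch:
//
// ptr := NewPointer(Typ[Int])
// ptr.Elem()       // Typ[Int]
// ptr.Underlying() // ptr itself: a *Pointer is its own underlying type
// ptr.String()     // "*int"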
// Code generated by "go test -run=Generate -write=all"; DO NOT EDIT.
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements commonly used type predicates.
package types
// The isX predicates below report whether t is an X.
// If t is a type parameter the result is false; i.e.,
// these predicates don't look inside a type parameter.
func isBoolean(t Type) bool { return isBasic(t, IsBoolean) }
func isInteger(t Type) bool { return isBasic(t, IsInteger) }
func isUnsigned(t Type) bool { return isBasic(t, IsUnsigned) }
func isFloat(t Type) bool { return isBasic(t, IsFloat) }
func isComplex(t Type) bool { return isBasic(t, IsComplex) }
func isNumeric(t Type) bool { return isBasic(t, IsNumeric) }
func isString(t Type) bool { return isBasic(t, IsString) }
func isIntegerOrFloat(t Type) bool { return isBasic(t, IsInteger|IsFloat) }
func isConstType(t Type) bool { return isBasic(t, IsConstType) }
// isBasic reports whether under(t) is a basic type with the specified info.
// If t is a type parameter the result is false; i.e.,
// isBasic does not look inside a type parameter.
func isBasic(t Type, info BasicInfo) bool {
u, _ := under(t).(*Basic)
return u != nil && u.info&info != 0
}
// The allX predicates below report whether t is an X.
// If t is a type parameter the result is true if isX is true
// for all specific types of the type parameter's type set.
// allX is an optimized version of isX(coreType(t)) (which
// is the same as underIs(t, isX)).
func allBoolean(t Type) bool { return allBasic(t, IsBoolean) }
func allInteger(t Type) bool { return allBasic(t, IsInteger) }
func allUnsigned(t Type) bool { return allBasic(t, IsUnsigned) }
func allNumeric(t Type) bool { return allBasic(t, IsNumeric) }
func allString(t Type) bool { return allBasic(t, IsString) }
func allOrdered(t Type) bool { return allBasic(t, IsOrdered) }
func allNumericOrString(t Type) bool { return allBasic(t, IsNumeric|IsString) }
// allBasic reports whether under(t) is a basic type with the specified info.
// If t is a type parameter, the result is true if isBasic(t, info) is true
// for all specific types of the type parameter's type set.
// allBasic(t, info) is an optimized version of isBasic(coreType(t), info).
func allBasic(t Type, info BasicInfo) bool {
if tpar, _ := t.(*TypeParam); tpar != nil {
return tpar.is(func(t *term) bool { return t != nil && isBasic(t.typ, info) })
}
return isBasic(t, info)
}
// hasName reports whether t has a name. This includes
// predeclared types, defined types, and type parameters.
// hasName may be called with types that are not fully set up.
func hasName(t Type) bool {
switch t.(type) {
case *Basic, *Named, *TypeParam:
return true
}
return false
}
// isTypeLit reports whether t is a type literal.
// This includes all non-defined types, but also basic types.
// isTypeLit may be called with types that are not fully set up.
func isTypeLit(t Type) bool {
switch t.(type) {
case *Named, *TypeParam:
return false
}
return true
}
// isTyped reports whether t is typed; i.e., not an untyped
// constant or boolean. isTyped may be called with types that
// are not fully set up.
func isTyped(t Type) bool {
// isTyped is called with types that are not fully
// set up. Must not call under()!
b, _ := t.(*Basic)
return b == nil || b.info&IsUntyped == 0
}
// isUntyped(t) is the same as !isTyped(t).
func isUntyped(t Type) bool {
return !isTyped(t)
}
// IsInterface reports whether t is an interface type.
func IsInterface(t Type) bool {
_, ok := under(t).(*Interface)
return ok
}
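// Illustrative sketch: the predeclared error type is a defined type whose
// underlying type is an interface, so IsInterface follows under(t) and
// reports true.
//
// IsInterface(Universe.Lookup("error").Type()) // true
// IsInterface(Typ[Int])                        // false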
// isNonTypeParamInterface reports whether t is an interface type but not a type parameter.
func isNonTypeParamInterface(t Type) bool {
return !isTypeParam(t) && IsInterface(t)
}
// isTypeParam reports whether t is a type parameter.
func isTypeParam(t Type) bool {
_, ok := t.(*TypeParam)
return ok
}
// hasEmptyTypeset reports whether t is a type parameter with an empty type set.
// The function does not force the computation of the type set and so is safe to
// use anywhere, but it may report a false negative if the type set has not been
// computed yet.
func hasEmptyTypeset(t Type) bool {
if tpar, _ := t.(*TypeParam); tpar != nil && tpar.bound != nil {
iface, _ := safeUnderlying(tpar.bound).(*Interface)
return iface != nil && iface.tset != nil && iface.tset.IsEmpty()
}
return false
}
// isGeneric reports whether a type is a generic, uninstantiated type
// (generic signatures are not included).
// TODO(gri) should we include signatures or assert that they are not present?
func isGeneric(t Type) bool {
// A parameterized type is only generic if it doesn't have an instantiation already.
named, _ := t.(*Named)
return named != nil && named.obj != nil && named.inst == nil && named.TypeParams().Len() > 0
}
// Comparable reports whether values of type T are comparable.
func Comparable(T Type) bool {
return comparable(T, true, nil, nil)
}
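// Illustrative sketch:
//
// Comparable(Typ[Int])             // true
// Comparable(NewPointer(Typ[Int])) // true: pointers are comparable
// Comparable(NewSlice(Typ[Int]))   // false: slices are not comparable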
// comparable reports whether values of type T are comparable.
// If dynamic is set, non-type parameter interfaces are always comparable.
// If reportf != nil, it may be used to report why T is not comparable.
func comparable(T Type, dynamic bool, seen map[Type]bool, reportf func(string, ...interface{})) bool {
if seen[T] {
return true
}
if seen == nil {
seen = make(map[Type]bool)
}
seen[T] = true
switch t := under(T).(type) {
case *Basic:
// assume invalid types to be comparable
// to avoid follow-up errors
return t.kind != UntypedNil
case *Pointer, *Chan:
return true
case *Struct:
for _, f := range t.fields {
if !comparable(f.typ, dynamic, seen, nil) {
if reportf != nil {
reportf("struct containing %s cannot be compared", f.typ)
}
return false
}
}
return true
case *Array:
if !comparable(t.elem, dynamic, seen, nil) {
if reportf != nil {
reportf("%s cannot be compared", t)
}
return false
}
return true
case *Interface:
if dynamic && !isTypeParam(T) || t.typeSet().IsComparable(seen) {
return true
}
if reportf != nil {
if t.typeSet().IsEmpty() {
reportf("empty type set")
} else {
reportf("incomparable types in type set")
}
}
// fallthrough
}
return false
}
// hasNil reports whether type t includes the nil value.
func hasNil(t Type) bool {
switch u := under(t).(type) {
case *Basic:
return u.kind == UnsafePointer
case *Slice, *Pointer, *Signature, *Map, *Chan:
return true
case *Interface:
return !isTypeParam(t) || u.typeSet().underIs(func(u Type) bool {
return u != nil && hasNil(u)
})
}
return false
}
// An ifacePair is a node in a stack of interface type pairs compared for identity.
type ifacePair struct {
x, y *Interface
prev *ifacePair
}
func (p *ifacePair) identical(q *ifacePair) bool {
return p.x == q.x && p.y == q.y || p.x == q.y && p.y == q.x
}
// A comparer is used to compare types.
type comparer struct {
ignoreTags bool // if set, identical ignores struct tags
ignoreInvalids bool // if set, identical treats an invalid type as identical to any type
}
// For changes to this code, the corresponding changes should be made to unifier.nify.
func (c *comparer) identical(x, y Type, p *ifacePair) bool {
if x == y {
return true
}
if c.ignoreInvalids && (x == Typ[Invalid] || y == Typ[Invalid]) {
return true
}
switch x := x.(type) {
case *Basic:
// Basic types are singletons except for the rune and byte
// aliases, thus we cannot solely rely on the x == y check
// above. See also comment in TypeName.IsAlias.
if y, ok := y.(*Basic); ok {
return x.kind == y.kind
}
case *Array:
// Two array types are identical if they have identical element types
// and the same array length.
if y, ok := y.(*Array); ok {
// If one or both array lengths are unknown (< 0) due to some error,
// assume they are the same to avoid spurious follow-on errors.
return (x.len < 0 || y.len < 0 || x.len == y.len) && c.identical(x.elem, y.elem, p)
}
case *Slice:
// Two slice types are identical if they have identical element types.
if y, ok := y.(*Slice); ok {
return c.identical(x.elem, y.elem, p)
}
case *Struct:
// Two struct types are identical if they have the same sequence of fields,
// and if corresponding fields have the same names, and identical types,
// and identical tags. Two embedded fields are considered to have the same
// name. Lower-case field names from different packages are always different.
if y, ok := y.(*Struct); ok {
if x.NumFields() == y.NumFields() {
for i, f := range x.fields {
g := y.fields[i]
if f.embedded != g.embedded ||
!c.ignoreTags && x.Tag(i) != y.Tag(i) ||
!f.sameId(g.pkg, g.name) ||
!c.identical(f.typ, g.typ, p) {
return false
}
}
return true
}
}
case *Pointer:
// Two pointer types are identical if they have identical base types.
if y, ok := y.(*Pointer); ok {
return c.identical(x.base, y.base, p)
}
case *Tuple:
// Two tuple types are identical if they have the same number of elements
// and corresponding elements have identical types.
if y, ok := y.(*Tuple); ok {
if x.Len() == y.Len() {
if x != nil {
for i, v := range x.vars {
w := y.vars[i]
if !c.identical(v.typ, w.typ, p) {
return false
}
}
}
return true
}
}
case *Signature:
y, _ := y.(*Signature)
if y == nil {
return false
}
// Two function types are identical if they have the same number of
// parameters and result values, corresponding parameter and result types
// are identical, and either both functions are variadic or neither is.
// Parameter and result names are not required to match, and type
// parameters are considered identical modulo renaming.
if x.TypeParams().Len() != y.TypeParams().Len() {
return false
}
// In the case of generic signatures, we will substitute in yparams and
// yresults.
yparams := y.params
yresults := y.results
if x.TypeParams().Len() > 0 {
// We must ignore type parameter names when comparing x and y. The
// easiest way to do this is to substitute x's type parameters for y's.
xtparams := x.TypeParams().list()
ytparams := y.TypeParams().list()
var targs []Type
for i := range xtparams {
targs = append(targs, x.TypeParams().At(i))
}
smap := makeSubstMap(ytparams, targs)
var check *Checker // ok to call subst on a nil *Checker
ctxt := NewContext() // need a non-nil Context for the substitution below
// Constraints must be pair-wise identical, after substitution.
for i, xtparam := range xtparams {
ybound := check.subst(nopos, ytparams[i].bound, smap, nil, ctxt)
if !c.identical(xtparam.bound, ybound, p) {
return false
}
}
yparams = check.subst(nopos, y.params, smap, nil, ctxt).(*Tuple)
yresults = check.subst(nopos, y.results, smap, nil, ctxt).(*Tuple)
}
return x.variadic == y.variadic &&
c.identical(x.params, yparams, p) &&
c.identical(x.results, yresults, p)
case *Union:
if y, _ := y.(*Union); y != nil {
// TODO(rfindley): can this be reached during type checking? If so,
// consider passing a type set map.
unionSets := make(map[*Union]*_TypeSet)
xset := computeUnionTypeSet(nil, unionSets, nopos, x)
yset := computeUnionTypeSet(nil, unionSets, nopos, y)
return xset.terms.equal(yset.terms)
}
case *Interface:
// Two interface types are identical if they describe the same type sets.
// With the existing implementation restriction, this simplifies to:
//
// Two interface types are identical if they have the same set of methods with
// the same names and identical function types, and if any type restrictions
// are the same. Lower-case method names from different packages are always
// different. The order of the methods is irrelevant.
if y, ok := y.(*Interface); ok {
xset := x.typeSet()
yset := y.typeSet()
if xset.comparable != yset.comparable {
return false
}
if !xset.terms.equal(yset.terms) {
return false
}
a := xset.methods
b := yset.methods
if len(a) == len(b) {
// Interface types are the only types where cycles can occur
// that are not "terminated" via named types; and such cycles
// can only be created via method parameter types that are
// anonymous interfaces (directly or indirectly) embedding
// the current interface. Example:
//
// type T interface {
// m() interface{T}
// }
//
// If two such (differently named) interfaces are compared,
// endless recursion occurs if the cycle is not detected.
//
// If x and y were compared before, they must be equal
// (if they were not, the recursion would have stopped);
// search the ifacePair stack for the same pair.
//
// This is a quadratic algorithm, but in practice these stacks
// are extremely short (bounded by the nesting depth of interface
// type declarations that recur via parameter types, an extremely
// rare occurrence). An alternative implementation might use a
// "visited" map, but that is probably less efficient overall.
q := &ifacePair{x, y, p}
for p != nil {
if p.identical(q) {
return true // same pair was compared before
}
p = p.prev
}
if debug {
assertSortedMethods(a)
assertSortedMethods(b)
}
for i, f := range a {
g := b[i]
if f.Id() != g.Id() || !c.identical(f.typ, g.typ, q) {
return false
}
}
return true
}
}
case *Map:
// Two map types are identical if they have identical key and value types.
if y, ok := y.(*Map); ok {
return c.identical(x.key, y.key, p) && c.identical(x.elem, y.elem, p)
}
case *Chan:
// Two channel types are identical if they have identical value types
// and the same direction.
if y, ok := y.(*Chan); ok {
return x.dir == y.dir && c.identical(x.elem, y.elem, p)
}
case *Named:
// Two named types are identical if their type names originate
// in the same type declaration.
if y, ok := y.(*Named); ok {
xargs := x.TypeArgs().list()
yargs := y.TypeArgs().list()
if len(xargs) != len(yargs) {
return false
}
if len(xargs) > 0 {
// Instances are identical if their original type and type arguments
// are identical.
if !Identical(x.Origin(), y.Origin()) {
return false
}
for i, xa := range xargs {
if !Identical(xa, yargs[i]) {
return false
}
}
return true
}
// TODO(gri) Why is x == y not sufficient? And if it is,
// we can just return false here because x == y
// is caught in the very beginning of this function.
return x.obj == y.obj
}
case *TypeParam:
// nothing to do (x and y being equal is caught in the very beginning of this function)
case nil:
// avoid a crash in case of nil type
default:
unreachable()
}
return false
}
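// Illustrative sketch via the exported Identical function (defined elsewhere
// in this package), which uses a comparer with default settings:
//
// Identical(NewSlice(Typ[Int]), NewSlice(Typ[Int])) // true: identical element types
// Identical(Typ[Int], Typ[Int32])                   // false: int and int32 are distinct
// Identical(Typ[Byte], Typ[Uint8])                  // true: byte is an alias for uint8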
// identicalInstance reports whether two type instantiations are identical.
// Instantiations are identical if their origin and type arguments are
// identical.
func identicalInstance(xorig Type, xargs []Type, yorig Type, yargs []Type) bool {
if len(xargs) != len(yargs) {
return false
}
for i, xa := range xargs {
if !Identical(xa, yargs[i]) {
return false
}
}
return Identical(xorig, yorig)
}
// Default returns the default "typed" type for an "untyped" type;
// it returns the incoming type for all other types. The default type
// for untyped nil is untyped nil.
func Default(t Type) Type {
if t, ok := t.(*Basic); ok {
switch t.kind {
case UntypedBool:
return Typ[Bool]
case UntypedInt:
return Typ[Int]
case UntypedRune:
return universeRune // use 'rune' name
case UntypedFloat:
return Typ[Float64]
case UntypedComplex:
return Typ[Complex128]
case UntypedString:
return Typ[String]
}
}
return t
}
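// Illustrative sketch:
//
// Default(Typ[UntypedInt])   // Typ[Int]
// Default(Typ[UntypedFloat]) // Typ[Float64]
// Default(Typ[UntypedRune])  // universeRune (rune)
// Default(Typ[String])       // Typ[String]: already typed, returned unchanged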
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package types
import (
"fmt"
"go/ast"
"go/constant"
"go/internal/typeparams"
"go/token"
. "internal/types/errors"
"sort"
"strconv"
"strings"
"unicode"
)
// A declInfo describes a package-level const, type, var, or func declaration.
type declInfo struct {
file *Scope // scope of file containing this declaration
lhs []*Var // lhs of n:1 variable declarations, or nil
vtyp ast.Expr // type, or nil (for const and var declarations only)
init ast.Expr // init/orig expression, or nil (for const and var declarations only)
inherited bool // if set, the init expression is inherited from a previous constant declaration
tdecl *ast.TypeSpec // type declaration, or nil
fdecl *ast.FuncDecl // func declaration, or nil
// The deps field tracks initialization expression dependencies.
deps map[Object]bool // lazily initialized
}
// hasInitializer reports whether the declared object has an initialization
// expression or function body.
func (d *declInfo) hasInitializer() bool {
return d.init != nil || d.fdecl != nil && d.fdecl.Body != nil
}
// addDep adds obj to the set of objects d's init expression depends on.
func (d *declInfo) addDep(obj Object) {
m := d.deps
if m == nil {
m = make(map[Object]bool)
d.deps = m
}
m[obj] = true
}
// arityMatch checks that the lhs and rhs of a const or var decl
// have the appropriate number of names and init exprs. For const
// decls, init is the value spec providing the init exprs; for
// var decls, init is nil (the init exprs are in s in this case).
func (check *Checker) arityMatch(s, init *ast.ValueSpec) {
l := len(s.Names)
r := len(s.Values)
if init != nil {
r = len(init.Values)
}
const code = WrongAssignCount
switch {
case init == nil && r == 0:
// var decl w/o init expr
if s.Type == nil {
check.error(s, code, "missing type or init expr")
}
case l < r:
if l < len(s.Values) {
// init exprs from s
n := s.Values[l]
check.errorf(n, code, "extra init expr %s", n)
// TODO(gri) avoid declared and not used error here
} else {
// init exprs "inherited"
check.errorf(s, code, "extra init expr at %s", check.fset.Position(init.Pos()))
// TODO(gri) avoid declared and not used error here
}
case l > r && (init != nil || r != 1):
n := s.Names[r]
check.errorf(n, code, "missing init expr for %s", n)
}
}
func validatedImportPath(path string) (string, error) {
s, err := strconv.Unquote(path)
if err != nil {
return "", err
}
if s == "" {
return "", fmt.Errorf("empty string")
}
const illegalChars = `!"#$%&'()*,:;<=>?[\]^{|}` + "`\uFFFD"
for _, r := range s {
if !unicode.IsGraphic(r) || unicode.IsSpace(r) || strings.ContainsRune(illegalChars, r) {
return s, fmt.Errorf("invalid character %#U", r)
}
}
return s, nil
}
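// Usage sketch (in-package): the argument is the quoted import path literal
// exactly as written in the source.
//
// validatedImportPath(`"fmt"`) // "fmt", nil
// validatedImportPath(`"a b"`) // error: invalid character U+0020 ' '
// validatedImportPath(`""`)    // error: empty string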
// declarePkgObj declares obj in the package scope, records its ident -> obj mapping,
// and updates check.objMap. The object must not be a function or method.
func (check *Checker) declarePkgObj(ident *ast.Ident, obj Object, d *declInfo) {
assert(ident.Name == obj.Name())
// spec: "A package-scope or file-scope identifier with name init
// may only be declared to be a function with this (func()) signature."
if ident.Name == "init" {
check.error(ident, InvalidInitDecl, "cannot declare init - must be func")
return
}
// spec: "The main package must have package name main and declare
// a function main that takes no arguments and returns no value."
if ident.Name == "main" && check.pkg.name == "main" {
check.error(ident, InvalidMainDecl, "cannot declare main - must be func")
return
}
check.declare(check.pkg.scope, ident, obj, nopos)
check.objMap[obj] = d
obj.setOrder(uint32(len(check.objMap)))
}
// filename returns a filename suitable for debugging output.
func (check *Checker) filename(fileNo int) string {
file := check.files[fileNo]
if pos := file.Pos(); pos.IsValid() {
return check.fset.File(pos).Name()
}
return fmt.Sprintf("file[%d]", fileNo)
}
func (check *Checker) importPackage(at positioner, path, dir string) *Package {
// If we already have a package for the given (path, dir)
// pair, use it instead of doing a full import.
// Checker.impMap only caches packages that are marked Complete
// or fake (dummy packages for failed imports). Incomplete but
// non-fake packages do require an import to complete them.
key := importKey{path, dir}
imp := check.impMap[key]
if imp != nil {
return imp
}
// no package yet => import it
if path == "C" && (check.conf.FakeImportC || check.conf.go115UsesCgo) {
imp = NewPackage("C", "C")
imp.fake = true // package scope is not populated
imp.cgo = check.conf.go115UsesCgo
} else {
// ordinary import
var err error
if importer := check.conf.Importer; importer == nil {
err = fmt.Errorf("Config.Importer not installed")
} else if importerFrom, ok := importer.(ImporterFrom); ok {
imp, err = importerFrom.ImportFrom(path, dir, 0)
if imp == nil && err == nil {
err = fmt.Errorf("Config.Importer.ImportFrom(%s, %s, 0) returned nil but no error", path, dir)
}
} else {
imp, err = importer.Import(path)
if imp == nil && err == nil {
err = fmt.Errorf("Config.Importer.Import(%s) returned nil but no error", path)
}
}
// make sure we have a valid package name
// (errors here can only happen through manipulation of packages after creation)
if err == nil && imp != nil && (imp.name == "_" || imp.name == "") {
err = fmt.Errorf("invalid package name: %q", imp.name)
imp = nil // create fake package below
}
if err != nil {
check.errorf(at, BrokenImport, "could not import %s (%s)", path, err)
if imp == nil {
// create a new fake package
// come up with a sensible package name (heuristic)
name := path
if i := len(name); i > 0 && name[i-1] == '/' {
name = name[:i-1]
}
if i := strings.LastIndex(name, "/"); i >= 0 {
name = name[i+1:]
}
imp = NewPackage(path, name)
}
// continue to use the package as best as we can
imp.fake = true // avoid follow-up lookup failures
}
}
// package should be complete or marked fake, but be cautious
if imp.complete || imp.fake {
check.impMap[key] = imp
// Once we've formatted an error message, keep the pkgPathMap
// up-to-date on subsequent imports. It is used for package
// qualification in error messages.
if check.pkgPathMap != nil {
check.markImports(imp)
}
return imp
}
// something went wrong (importer may have returned incomplete package without error)
return nil
}
// collectObjects collects all file and package objects and inserts them
// into their respective scopes. It also performs imports and associates
// methods with receiver base type names.
func (check *Checker) collectObjects() {
pkg := check.pkg
// pkgImports is the set of packages already imported by any package file seen
// so far. Used to avoid duplicate entries in pkg.imports. Allocate and populate
// it (pkg.imports may not be empty if we are checking test files incrementally).
// Note that pkgImports is keyed by package (and thus package path), not by an
// importKey value. Two different importKey values may map to the same package
// which is why we cannot use the check.impMap here.
var pkgImports = make(map[*Package]bool)
for _, imp := range pkg.imports {
pkgImports[imp] = true
}
type methodInfo struct {
obj *Func // method
ptr bool // true if pointer receiver
recv *ast.Ident // receiver type name
}
var methods []methodInfo // collected methods with valid receivers and non-blank names
var fileScopes []*Scope
for fileNo, file := range check.files {
// The package identifier denotes the current package,
// but there is no corresponding package object.
check.recordDef(file.Name, nil)
// Use the actual source file extent rather than *ast.File extent since the
// latter doesn't include comments which appear at the start or end of the file.
// Be conservative and use the *ast.File extent if we don't have a *token.File.
pos, end := file.Pos(), file.End()
if f := check.fset.File(file.Pos()); f != nil {
pos, end = token.Pos(f.Base()), token.Pos(f.Base()+f.Size())
}
fileScope := NewScope(check.pkg.scope, pos, end, check.filename(fileNo))
fileScopes = append(fileScopes, fileScope)
check.recordScope(file, fileScope)
// determine file directory, necessary to resolve imports
// FileName may be "" (typically for tests) in which case
// we get "." as the directory which is what we would want.
fileDir := dir(check.fset.Position(file.Name.Pos()).Filename)
check.walkDecls(file.Decls, func(d decl) {
switch d := d.(type) {
case importDecl:
// import package
if d.spec.Path.Value == "" {
return // error reported by parser
}
path, err := validatedImportPath(d.spec.Path.Value)
if err != nil {
check.errorf(d.spec.Path, BadImportPath, "invalid import path (%s)", err)
return
}
imp := check.importPackage(d.spec.Path, path, fileDir)
if imp == nil {
return
}
// local name overrides imported package name
name := imp.name
if d.spec.Name != nil {
name = d.spec.Name.Name
if path == "C" {
// match 1.17 cmd/compile (not prescribed by spec)
check.error(d.spec.Name, ImportCRenamed, `cannot rename import "C"`)
return
}
}
if name == "init" {
check.error(d.spec, InvalidInitDecl, "cannot import package as init - init must be a func")
return
}
// add package to list of explicit imports
// (this functionality is provided as a convenience
// for clients; it is not needed for type-checking)
if !pkgImports[imp] {
pkgImports[imp] = true
pkg.imports = append(pkg.imports, imp)
}
pkgName := NewPkgName(d.spec.Pos(), pkg, name, imp)
if d.spec.Name != nil {
// in a dot-import, the dot represents the package
check.recordDef(d.spec.Name, pkgName)
} else {
check.recordImplicit(d.spec, pkgName)
}
if imp.fake {
// match 1.17 cmd/compile (not prescribed by spec)
pkgName.used = true
}
// add import to file scope
check.imports = append(check.imports, pkgName)
if name == "." {
// dot-import
if check.dotImportMap == nil {
check.dotImportMap = make(map[dotImportKey]*PkgName)
}
// merge imported scope with file scope
for name, obj := range imp.scope.elems {
// Note: Avoid eager resolve(name, obj) here, so we only
// resolve dot-imported objects as needed.
// A package scope may contain non-exported objects,
// do not import them!
if token.IsExported(name) {
// declare dot-imported object
// (Do not use check.declare because it modifies the object
// via Object.setScopePos, which leads to a race condition;
// the object may be imported into more than one file scope
// concurrently. See go.dev/issue/32154.)
if alt := fileScope.Lookup(name); alt != nil {
check.errorf(d.spec.Name, DuplicateDecl, "%s redeclared in this block", alt.Name())
check.reportAltDecl(alt)
} else {
fileScope.insert(name, obj)
check.dotImportMap[dotImportKey{fileScope, name}] = pkgName
}
}
}
} else {
// declare imported package object in file scope
// (no need to provide s.Name since we called check.recordDef earlier)
check.declare(fileScope, nil, pkgName, nopos)
}
case constDecl:
// declare all constants
for i, name := range d.spec.Names {
obj := NewConst(name.Pos(), pkg, name.Name, nil, constant.MakeInt64(int64(d.iota)))
var init ast.Expr
if i < len(d.init) {
init = d.init[i]
}
d := &declInfo{file: fileScope, vtyp: d.typ, init: init, inherited: d.inherited}
check.declarePkgObj(name, obj, d)
}
case varDecl:
lhs := make([]*Var, len(d.spec.Names))
// If there's exactly one rhs initializer, use
// the same declInfo d1 for all lhs variables
// so that each lhs variable depends on the same
// rhs initializer (n:1 var declaration).
var d1 *declInfo
if len(d.spec.Values) == 1 {
// The lhs elements are only set up after the for loop below,
// but that's ok because declareVar only collects the declInfo
// for a later phase.
d1 = &declInfo{file: fileScope, lhs: lhs, vtyp: d.spec.Type, init: d.spec.Values[0]}
}
// declare all variables
for i, name := range d.spec.Names {
obj := NewVar(name.Pos(), pkg, name.Name, nil)
lhs[i] = obj
di := d1
if di == nil {
// individual assignments
var init ast.Expr
if i < len(d.spec.Values) {
init = d.spec.Values[i]
}
di = &declInfo{file: fileScope, vtyp: d.spec.Type, init: init}
}
check.declarePkgObj(name, obj, di)
}
case typeDecl:
if d.spec.TypeParams.NumFields() != 0 && !check.allowVersion(pkg, 1, 18) {
check.softErrorf(d.spec.TypeParams.List[0], UnsupportedFeature, "type parameter requires go1.18 or later")
}
obj := NewTypeName(d.spec.Name.Pos(), pkg, d.spec.Name.Name, nil)
check.declarePkgObj(d.spec.Name, obj, &declInfo{file: fileScope, tdecl: d.spec})
case funcDecl:
name := d.decl.Name.Name
obj := NewFunc(d.decl.Name.Pos(), pkg, name, nil)
hasTParamError := false // avoid duplicate type parameter errors
if d.decl.Recv.NumFields() == 0 {
// regular function
if d.decl.Recv != nil {
check.error(d.decl.Recv, BadRecv, "method has no receiver")
// treat as function
}
if name == "init" || (name == "main" && check.pkg.name == "main") {
code := InvalidInitDecl
if name == "main" {
code = InvalidMainDecl
}
if d.decl.Type.TypeParams.NumFields() != 0 {
check.softErrorf(d.decl.Type.TypeParams.List[0], code, "func %s must have no type parameters", name)
hasTParamError = true
}
if t := d.decl.Type; t.Params.NumFields() != 0 || t.Results != nil {
// TODO(rFindley) Should this be a hard error?
check.softErrorf(d.decl.Name, code, "func %s must have no arguments and no return values", name)
}
}
if name == "init" {
// don't declare init functions in the package scope - they are invisible
obj.parent = pkg.scope
check.recordDef(d.decl.Name, obj)
// init functions must have a body
if d.decl.Body == nil {
// TODO(gri) make this error message consistent with the others above
check.softErrorf(obj, MissingInitBody, "missing function body")
}
} else {
check.declare(pkg.scope, d.decl.Name, obj, nopos)
}
} else {
// method
// TODO(rFindley) earlier versions of this code checked that methods
// have no type parameters, but this is checked later
// when type checking the function type. Confirm that
// we don't need to check tparams here.
ptr, recv, _ := check.unpackRecv(d.decl.Recv.List[0].Type, false)
// (Methods with an invalid receiver cannot be associated to a type, and
// methods with blank (_) names are never found; no need to collect any
// of them. They will still be type-checked with all the other functions.)
if recv != nil && name != "_" {
methods = append(methods, methodInfo{obj, ptr, recv})
}
check.recordDef(d.decl.Name, obj)
}
if d.decl.Type.TypeParams.NumFields() != 0 && !check.allowVersion(pkg, 1, 18) && !hasTParamError {
check.softErrorf(d.decl.Type.TypeParams.List[0], UnsupportedFeature, "type parameter requires go1.18 or later")
}
info := &declInfo{file: fileScope, fdecl: d.decl}
// Methods are not package-level objects but we still track them in the
// object map so that we can handle them like regular functions (if the
// receiver is invalid); also we need their fdecl info when associating
// them with their receiver base type, below.
check.objMap[obj] = info
obj.setOrder(uint32(len(check.objMap)))
}
})
}
// verify that objects in package and file scopes have different names
for _, scope := range fileScopes {
for name, obj := range scope.elems {
if alt := pkg.scope.Lookup(name); alt != nil {
obj = resolve(name, obj)
if pkg, ok := obj.(*PkgName); ok {
check.errorf(alt, DuplicateDecl, "%s already declared through import of %s", alt.Name(), pkg.Imported())
check.reportAltDecl(pkg)
} else {
check.errorf(alt, DuplicateDecl, "%s already declared through dot-import of %s", alt.Name(), obj.Pkg())
// TODO(gri) dot-imported objects don't have a position; reportAltDecl won't print anything
check.reportAltDecl(obj)
}
}
}
}
// Now that we have all package scope objects and all methods,
// associate methods with receiver base type name where possible.
// Ignore methods that have an invalid receiver. They will be
// type-checked later, with regular functions.
if methods == nil {
return // nothing to do
}
check.methods = make(map[*TypeName][]*Func)
for i := range methods {
m := &methods[i]
// Determine the receiver base type and associate m with it.
ptr, base := check.resolveBaseTypeName(m.ptr, m.recv)
if base != nil {
m.obj.hasPtrRecv_ = ptr
check.methods[base] = append(check.methods[base], m.obj)
}
}
}
// unpackRecv unpacks a receiver type and returns its components: ptr indicates whether
// rtyp is a pointer receiver, rname is the receiver type name, and tparams are its
// type parameters, if any. The type parameters are only unpacked if unpackParams is
// set. If rname is nil, the receiver is unusable (i.e., the source has a bug which we
// cannot easily work around).
func (check *Checker) unpackRecv(rtyp ast.Expr, unpackParams bool) (ptr bool, rname *ast.Ident, tparams []*ast.Ident) {
L: // unpack receiver type
// This accepts invalid receivers such as ***T and does not
// work for other invalid receivers, but we don't care. The
// validity of receiver expressions is checked elsewhere.
for {
switch t := rtyp.(type) {
case *ast.ParenExpr:
rtyp = t.X
case *ast.StarExpr:
ptr = true
rtyp = t.X
default:
break L
}
}
// unpack type parameters, if any
switch rtyp.(type) {
case *ast.IndexExpr, *ast.IndexListExpr:
ix := typeparams.UnpackIndexExpr(rtyp)
rtyp = ix.X
if unpackParams {
for _, arg := range ix.Indices {
var par *ast.Ident
switch arg := arg.(type) {
case *ast.Ident:
par = arg
case *ast.BadExpr:
// ignore - error already reported by parser
case nil:
check.error(ix.Orig, InvalidSyntaxTree, "parameterized receiver contains nil parameters")
default:
check.errorf(arg, BadDecl, "receiver type parameter %s must be an identifier", arg)
}
if par == nil {
par = &ast.Ident{NamePos: arg.Pos(), Name: "_"}
}
tparams = append(tparams, par)
}
}
}
// unpack receiver name
if name, _ := rtyp.(*ast.Ident); name != nil {
rname = name
}
return
}
// resolveBaseTypeName returns the non-alias base type name for typ, and whether
// there was a pointer indirection to get to it. The base type name must be declared
// in package scope, and there can be at most one pointer indirection. If no such type
// name exists, the returned base is nil.
func (check *Checker) resolveBaseTypeName(seenPtr bool, name *ast.Ident) (ptr bool, base *TypeName) {
// Algorithm: Starting from a type expression, which may be a name,
// we follow that type through alias declarations until we reach a
// non-alias type name. If we encounter anything but pointer types or
// parentheses we're done. If we encounter more than one pointer type
// we're done.
ptr = seenPtr
var seen map[*TypeName]bool
var typ ast.Expr = name
for {
typ = unparen(typ)
// check if we have a pointer type
if pexpr, _ := typ.(*ast.StarExpr); pexpr != nil {
// if we've already seen a pointer, we're done
if ptr {
return false, nil
}
ptr = true
typ = unparen(pexpr.X) // continue with pointer base type
}
// typ must be a name
name, _ := typ.(*ast.Ident)
if name == nil {
return false, nil
}
// name must denote an object found in the current package scope
// (note that dot-imported objects are not in the package scope!)
obj := check.pkg.scope.Lookup(name.Name)
if obj == nil {
return false, nil
}
// the object must be a type name...
tname, _ := obj.(*TypeName)
if tname == nil {
return false, nil
}
// ... which we have not seen before
if seen[tname] {
return false, nil
}
// we're done if tdecl defined tname as a new type
// (rather than an alias)
tdecl := check.objMap[tname].tdecl // must exist for objects in package scope
if !tdecl.Assign.IsValid() {
return ptr, tname
}
// otherwise, continue resolving
typ = tdecl.Type
if seen == nil {
seen = make(map[*TypeName]bool)
}
seen[tname] = true
}
}
// packageObjects typechecks all package objects, but not function bodies.
func (check *Checker) packageObjects() {
// process package objects in source order for reproducible results
objList := make([]Object, len(check.objMap))
i := 0
for obj := range check.objMap {
objList[i] = obj
i++
}
sort.Sort(inSourceOrder(objList))
// add new methods to already type-checked types (from a prior Checker.Files call)
for _, obj := range objList {
if obj, _ := obj.(*TypeName); obj != nil && obj.typ != nil {
check.collectMethods(obj)
}
}
// We process non-alias type declarations first, followed by alias declarations,
// and then everything else. This appears to avoid most situations where the type
// of an alias is needed before it is available.
// There may still be cases where this is not good enough (see also go.dev/issue/25838).
// In those cases Checker.ident will report an error ("invalid use of type alias").
var aliasList []*TypeName
var othersList []Object // everything that's not a type
// phase 1: non-alias type declarations
for _, obj := range objList {
if tname, _ := obj.(*TypeName); tname != nil {
if check.objMap[tname].tdecl.Assign.IsValid() {
aliasList = append(aliasList, tname)
} else {
check.objDecl(obj, nil)
}
} else {
othersList = append(othersList, obj)
}
}
// phase 2: alias type declarations
for _, obj := range aliasList {
check.objDecl(obj, nil)
}
// phase 3: all other declarations
for _, obj := range othersList {
check.objDecl(obj, nil)
}
// At this point we may have a non-empty check.methods map; this means that not all
// entries were deleted at the end of typeDecl because the respective receiver base
// types were not found. In that case, an error was reported when declaring those
// methods. We can now safely discard this map.
check.methods = nil
}
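// A hypothetical sketch of why non-alias types go first (not part of this
// file): in
//
//	type A = B   // alias, handled in phase 2
//	type B int   // defined type, handled in phase 1
//
// processing B before A ensures that B's type is available when the alias A
// is type-checked, regardless of source order.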
// inSourceOrder implements the sort.Sort interface.
type inSourceOrder []Object
func (a inSourceOrder) Len() int { return len(a) }
func (a inSourceOrder) Less(i, j int) bool { return a[i].order() < a[j].order() }
func (a inSourceOrder) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
// unusedImports checks for unused imports.
func (check *Checker) unusedImports() {
// If function bodies are not checked, packages' uses are likely missing - don't check.
if check.conf.IgnoreFuncBodies {
return
}
// spec: "It is illegal (...) to directly import a package without referring to
// any of its exported identifiers. To import a package solely for its side-effects
// (initialization), use the blank identifier as explicit package name."
for _, obj := range check.imports {
if !obj.used && obj.name != "_" {
check.errorUnusedPkg(obj)
}
}
}
func (check *Checker) errorUnusedPkg(obj *PkgName) {
// If the package was imported with a name other than the final
// import path element, show it explicitly in the error message.
// Note that this handles both renamed imports and imports of
// packages containing unconventional package declarations.
// Note that this uses / always, even on Windows, because Go import
// paths always use forward slashes.
path := obj.imported.path
elem := path
if i := strings.LastIndex(elem, "/"); i >= 0 {
elem = elem[i+1:]
}
if obj.name == "" || obj.name == "." || obj.name == elem {
check.softErrorf(obj, UnusedImport, "%q imported and not used", path)
} else {
check.softErrorf(obj, UnusedImport, "%q imported as %s and not used", path, obj.name)
}
}
// dir makes a good-faith attempt to return the directory
// portion of path. If path is empty, the result is ".".
// (Per the go/build package dependency tests, we cannot import
// path/filepath and simply use filepath.Dir.)
func dir(path string) string {
if i := strings.LastIndexAny(path, `/\`); i > 0 {
return path[:i]
}
// i <= 0
return "."
}
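// Illustrative inputs and outputs for dir (a sketch, not exhaustive):
//
//	dir("a/b/c.go") == "a/b"
//	dir("c.go")     == "."
//	dir("")         == "."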
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements isTerminating.
package types
import (
"go/ast"
"go/token"
)
// isTerminating reports whether s is a terminating statement.
// If s is labeled, label is the label name; otherwise label
// is "".
func (check *Checker) isTerminating(s ast.Stmt, label string) bool {
switch s := s.(type) {
default:
unreachable()
case *ast.BadStmt, *ast.DeclStmt, *ast.EmptyStmt, *ast.SendStmt,
*ast.IncDecStmt, *ast.AssignStmt, *ast.GoStmt, *ast.DeferStmt,
*ast.RangeStmt:
// no chance
case *ast.LabeledStmt:
return check.isTerminating(s.Stmt, s.Label.Name)
case *ast.ExprStmt:
// calling the predeclared (possibly parenthesized) panic() function is terminating
if call, ok := unparen(s.X).(*ast.CallExpr); ok && check.isPanic[call] {
return true
}
case *ast.ReturnStmt:
return true
case *ast.BranchStmt:
if s.Tok == token.GOTO || s.Tok == token.FALLTHROUGH {
return true
}
case *ast.BlockStmt:
return check.isTerminatingList(s.List, "")
case *ast.IfStmt:
if s.Else != nil &&
check.isTerminating(s.Body, "") &&
check.isTerminating(s.Else, "") {
return true
}
case *ast.SwitchStmt:
return check.isTerminatingSwitch(s.Body, label)
case *ast.TypeSwitchStmt:
return check.isTerminatingSwitch(s.Body, label)
case *ast.SelectStmt:
for _, s := range s.Body.List {
cc := s.(*ast.CommClause)
if !check.isTerminatingList(cc.Body, "") || hasBreakList(cc.Body, label, true) {
return false
}
}
return true
case *ast.ForStmt:
if s.Cond == nil && !hasBreak(s.Body, label, true) {
return true
}
}
return false
}
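// For intuition, a hypothetical sketch (not part of this file): the body of
// f below is terminating (both branches return), while the body of g is not
// (the if statement has no else), so g is reported as missing a return:
//
//	func f(x int) int {
//		if x > 0 {
//			return x
//		}
//		return -x
//	}
//
//	func g(x int) int {
//		if x > 0 {
//			return x
//		}
//	} // missing return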
func (check *Checker) isTerminatingList(list []ast.Stmt, label string) bool {
// trailing empty statements are permitted - skip them
for i := len(list) - 1; i >= 0; i-- {
if _, ok := list[i].(*ast.EmptyStmt); !ok {
return check.isTerminating(list[i], label)
}
}
return false // all statements are empty
}
func (check *Checker) isTerminatingSwitch(body *ast.BlockStmt, label string) bool {
hasDefault := false
for _, s := range body.List {
cc := s.(*ast.CaseClause)
if cc.List == nil {
hasDefault = true
}
if !check.isTerminatingList(cc.Body, "") || hasBreakList(cc.Body, label, true) {
return false
}
}
return hasDefault
}
// TODO(gri) For nested breakable statements, the current implementation of hasBreak
// will traverse the same subtree repeatedly, once for each label. Replace
// with a single-pass label/break matching phase.
// hasBreak reports whether s is or contains a break statement
// referring to the labeled statement or, implicitly, the
// closest outer breakable statement.
func hasBreak(s ast.Stmt, label string, implicit bool) bool {
switch s := s.(type) {
default:
unreachable()
case *ast.BadStmt, *ast.DeclStmt, *ast.EmptyStmt, *ast.ExprStmt,
*ast.SendStmt, *ast.IncDecStmt, *ast.AssignStmt, *ast.GoStmt,
*ast.DeferStmt, *ast.ReturnStmt:
// no chance
case *ast.LabeledStmt:
return hasBreak(s.Stmt, label, implicit)
case *ast.BranchStmt:
if s.Tok == token.BREAK {
if s.Label == nil {
return implicit
}
if s.Label.Name == label {
return true
}
}
case *ast.BlockStmt:
return hasBreakList(s.List, label, implicit)
case *ast.IfStmt:
if hasBreak(s.Body, label, implicit) ||
s.Else != nil && hasBreak(s.Else, label, implicit) {
return true
}
case *ast.CaseClause:
return hasBreakList(s.Body, label, implicit)
case *ast.SwitchStmt:
if label != "" && hasBreak(s.Body, label, false) {
return true
}
case *ast.TypeSwitchStmt:
if label != "" && hasBreak(s.Body, label, false) {
return true
}
case *ast.CommClause:
return hasBreakList(s.Body, label, implicit)
case *ast.SelectStmt:
if label != "" && hasBreak(s.Body, label, false) {
return true
}
case *ast.ForStmt:
if label != "" && hasBreak(s.Body, label, false) {
return true
}
case *ast.RangeStmt:
if label != "" && hasBreak(s.Body, label, false) {
return true
}
}
return false
}
func hasBreakList(list []ast.Stmt, label string, implicit bool) bool {
for _, s := range list {
if hasBreak(s, label, implicit) {
return true
}
}
return false
}
// Code generated by "go test -run=Generate -write=all"; DO NOT EDIT.
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements Scopes.
package types
import (
"fmt"
"go/token"
"io"
"sort"
"strings"
"sync"
)
// A Scope maintains a set of objects and links to its containing
// (parent) and contained (children) scopes. Objects may be inserted
// and looked up by name. The zero value for Scope is a ready-to-use
// empty scope.
type Scope struct {
parent *Scope
children []*Scope
number int // parent.children[number-1] is this scope; 0 if there is no parent
elems map[string]Object // lazily allocated
pos, end token.Pos // scope extent; may be invalid
comment string // for debugging only
isFunc bool // set if this is a function scope (internal use only)
}
// NewScope returns a new, empty scope contained in the given parent
// scope, if any. The comment is for debugging only.
func NewScope(parent *Scope, pos, end token.Pos, comment string) *Scope {
s := &Scope{parent, nil, 0, nil, pos, end, comment, false}
// don't add children to Universe scope!
if parent != nil && parent != Universe {
parent.children = append(parent.children, s)
s.number = len(parent.children)
}
return s
}
// Parent returns the scope's containing (parent) scope.
func (s *Scope) Parent() *Scope { return s.parent }
// Len returns the number of scope elements.
func (s *Scope) Len() int { return len(s.elems) }
// Names returns the scope's element names in sorted order.
func (s *Scope) Names() []string {
names := make([]string, len(s.elems))
i := 0
for name := range s.elems {
names[i] = name
i++
}
sort.Strings(names)
return names
}
// NumChildren returns the number of scopes nested in s.
func (s *Scope) NumChildren() int { return len(s.children) }
// Child returns the i'th child scope for 0 <= i < NumChildren().
func (s *Scope) Child(i int) *Scope { return s.children[i] }
// Lookup returns the object in scope s with the given name if such an
// object exists; otherwise the result is nil.
func (s *Scope) Lookup(name string) Object {
return resolve(name, s.elems[name])
}
// LookupParent follows the parent chain of scopes starting with s until
// it finds a scope where Lookup(name) returns a non-nil object, and then
// returns that scope and object. If a valid position pos is provided,
// only objects that were declared at or before pos are considered.
// If no such scope and object exists, the result is (nil, nil).
//
// Note that obj.Parent() may be different from the returned scope if the
// object was inserted into the scope and already had a parent at that
// time (see Insert). This can only happen for dot-imported objects
// whose scope is the scope of the package that exported them.
func (s *Scope) LookupParent(name string, pos token.Pos) (*Scope, Object) {
for ; s != nil; s = s.parent {
if obj := s.Lookup(name); obj != nil && (!pos.IsValid() || cmpPos(obj.scopePos(), pos) <= 0) {
return s, obj
}
}
return nil, nil
}
// Insert attempts to insert an object obj into scope s.
// If s already contains an alternative object alt with
// the same name, Insert leaves s unchanged and returns alt.
// Otherwise it inserts obj, sets the object's parent scope
// if not already set, and returns nil.
func (s *Scope) Insert(obj Object) Object {
name := obj.Name()
if alt := s.Lookup(name); alt != nil {
return alt
}
s.insert(name, obj)
if obj.Parent() == nil {
obj.setParent(s)
}
return nil
}
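// A minimal usage sketch of the exported Scope API, as seen from outside the
// package (hypothetical):
//
//	scope := types.NewScope(nil, token.NoPos, token.NoPos, "demo")
//	v := types.NewVar(token.NoPos, nil, "x", types.Typ[types.Int])
//	_ = scope.Insert(v)      // returns nil: "x" was not present
//	obj := scope.Lookup("x") // obj == v
//	alt := scope.Insert(v)   // returns v: name already taken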
// InsertLazy is like Insert, but allows deferring construction of the
// inserted object until it's accessed with Lookup. The Object
// returned by resolve must have the same name as given to InsertLazy.
// If s already contains an alternative object with the same name,
// InsertLazy leaves s unchanged and returns false. Otherwise it
// records the binding and returns true. The object's parent scope
// will be set to s after resolve is called.
func (s *Scope) _InsertLazy(name string, resolve func() Object) bool {
if s.elems[name] != nil {
return false
}
s.insert(name, &lazyObject{parent: s, resolve: resolve})
return true
}
func (s *Scope) insert(name string, obj Object) {
if s.elems == nil {
s.elems = make(map[string]Object)
}
s.elems[name] = obj
}
// squash merges s with its parent scope p by adding all
// objects of s to p, adding all children of s to the
// children of p, and removing s from p's children.
// The function err is called for each object obj in s which
// has an object alt in p. s should be discarded after
// having been squashed.
func (s *Scope) squash(err func(obj, alt Object)) {
p := s.parent
assert(p != nil)
for name, obj := range s.elems {
obj = resolve(name, obj)
obj.setParent(nil)
if alt := p.Insert(obj); alt != nil {
err(obj, alt)
}
}
j := -1 // index of s in p.children
for i, ch := range p.children {
if ch == s {
j = i
break
}
}
assert(j >= 0)
k := len(p.children) - 1
p.children[j] = p.children[k]
p.children = p.children[:k]
p.children = append(p.children, s.children...)
s.children = nil
s.elems = nil
}
// Pos and End describe the scope's source code extent [pos, end).
// The results are guaranteed to be valid only if the type-checked
// AST has complete position information. The extent is undefined
// for Universe and package scopes.
func (s *Scope) Pos() token.Pos { return s.pos }
func (s *Scope) End() token.Pos { return s.end }
// Contains reports whether pos is within the scope's extent.
// The result is guaranteed to be valid only if the type-checked
// AST has complete position information.
func (s *Scope) Contains(pos token.Pos) bool {
return cmpPos(s.pos, pos) <= 0 && cmpPos(pos, s.end) < 0
}
// Innermost returns the innermost (child) scope containing
// pos. If pos is not within any scope, the result is nil.
// The result is also nil for the Universe scope.
// The result is guaranteed to be valid only if the type-checked
// AST has complete position information.
func (s *Scope) Innermost(pos token.Pos) *Scope {
// Package scopes do not have extents since they may be
// discontiguous, so iterate over the package's files.
if s.parent == Universe {
for _, s := range s.children {
if inner := s.Innermost(pos); inner != nil {
return inner
}
}
}
if s.Contains(pos) {
for _, s := range s.children {
if s.Contains(pos) {
return s.Innermost(pos)
}
}
return s
}
return nil
}
// WriteTo writes a string representation of the scope to w,
// with the scope elements sorted by name.
// The level of indentation is controlled by n >= 0, with
// n == 0 for no indentation.
// If recurse is set, it also writes nested (children) scopes.
func (s *Scope) WriteTo(w io.Writer, n int, recurse bool) {
const ind = ". "
indn := strings.Repeat(ind, n)
fmt.Fprintf(w, "%s%s scope %p {\n", indn, s.comment, s)
indn1 := indn + ind
for _, name := range s.Names() {
fmt.Fprintf(w, "%s%s\n", indn1, s.Lookup(name))
}
if recurse {
for _, s := range s.children {
s.WriteTo(w, n+1, recurse)
}
}
fmt.Fprintf(w, "%s}\n", indn)
}
// String returns a string representation of the scope, for debugging.
func (s *Scope) String() string {
var buf strings.Builder
s.WriteTo(&buf, 0, false)
return buf.String()
}
// A lazyObject represents an imported Object that has not been fully
// resolved yet by its importer.
type lazyObject struct {
parent *Scope
resolve func() Object
obj Object
once sync.Once
}
// resolve returns the Object represented by obj, resolving lazy
// objects as appropriate.
func resolve(name string, obj Object) Object {
if lazy, ok := obj.(*lazyObject); ok {
lazy.once.Do(func() {
obj := lazy.resolve()
if _, ok := obj.(*lazyObject); ok {
panic("recursive lazy object")
}
if obj.Name() != name {
panic("lazy object has unexpected name")
}
if obj.Parent() == nil {
obj.setParent(lazy.parent)
}
lazy.obj = obj
})
obj = lazy.obj
}
return obj
}
// stub implementations so *lazyObject implements Object and we can
// store them directly into Scope.elems.
func (*lazyObject) Parent() *Scope { panic("unreachable") }
func (*lazyObject) Pos() token.Pos { panic("unreachable") }
func (*lazyObject) Pkg() *Package { panic("unreachable") }
func (*lazyObject) Name() string { panic("unreachable") }
func (*lazyObject) Type() Type { panic("unreachable") }
func (*lazyObject) Exported() bool { panic("unreachable") }
func (*lazyObject) Id() string { panic("unreachable") }
func (*lazyObject) String() string { panic("unreachable") }
func (*lazyObject) order() uint32 { panic("unreachable") }
func (*lazyObject) color() color { panic("unreachable") }
func (*lazyObject) setType(Type) { panic("unreachable") }
func (*lazyObject) setOrder(uint32) { panic("unreachable") }
func (*lazyObject) setColor(color color) { panic("unreachable") }
func (*lazyObject) setParent(*Scope) { panic("unreachable") }
func (*lazyObject) sameId(pkg *Package, name string) bool { panic("unreachable") }
func (*lazyObject) scopePos() token.Pos { panic("unreachable") }
func (*lazyObject) setScopePos(pos token.Pos) { panic("unreachable") }
// Code generated by "go test -run=Generate -write=all"; DO NOT EDIT.
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements Selections.
package types
import (
"bytes"
"fmt"
)
// SelectionKind describes the kind of a selector expression x.f
// (excluding qualified identifiers).
type SelectionKind int
const (
FieldVal SelectionKind = iota // x.f is a struct field selector
MethodVal // x.f is a method selector
MethodExpr // x.f is a method expression
)
// A Selection describes a selector expression x.f.
// For the declarations:
//
// type T struct{ x int; E }
// type E struct{}
// func (e E) m() {}
// var p *T
//
// the following relations exist:
//
// Selector Kind Recv Obj Type Index Indirect
//
// p.x FieldVal T x int {0} true
// p.m MethodVal *T m func() {1, 0} true
// T.m MethodExpr T m func(T) {1, 0} false
type Selection struct {
kind SelectionKind
recv Type // type of x
obj Object // object denoted by x.f
index []int // path from x to x.f
indirect bool // set if there was any pointer indirection on the path
}
// Kind returns the selection kind.
func (s *Selection) Kind() SelectionKind { return s.kind }
// Recv returns the type of x in x.f.
func (s *Selection) Recv() Type { return s.recv }
// Obj returns the object denoted by x.f; a *Var for
// a field selection, and a *Func in all other cases.
func (s *Selection) Obj() Object { return s.obj }
// Type returns the type of x.f, which may be different from the type of f.
// See Selection for more information.
func (s *Selection) Type() Type {
switch s.kind {
case MethodVal:
// The type of x.f is a method with its receiver type set
// to the type of x.
sig := *s.obj.(*Func).typ.(*Signature)
recv := *sig.recv
recv.typ = s.recv
sig.recv = &recv
return &sig
case MethodExpr:
// The type of x.f is a function (without receiver)
// and an additional first argument with the same type as x.
// TODO(gri) Similar code is already in call.go - factor!
// TODO(gri) Compute this eagerly to avoid allocations.
sig := *s.obj.(*Func).typ.(*Signature)
arg0 := *sig.recv
sig.recv = nil
arg0.typ = s.recv
var params []*Var
if sig.params != nil {
params = sig.params.vars
}
sig.params = NewTuple(append([]*Var{&arg0}, params...)...)
return &sig
}
// In all other cases, the type of x.f is the type of f.
return s.obj.Type()
}
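// Continuing the example from the Selection documentation above: for
//
//	type T struct{ x int; E }
//	type E struct{}
//	func (e E) m() {}
//	var p *T
//
// Type returns int for p.x, func() for p.m (the signature's receiver type is
// rewritten to *T), and func(T) for T.m (the receiver becomes an ordinary
// first parameter).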
// Index describes the path from x to f in x.f.
// The last index entry is the field or method index of the type declaring f;
// either:
//
// 1. the list of declared methods of a named type; or
// 2. the list of methods of an interface type; or
// 3. the list of fields of a struct type.
//
// The earlier index entries are the indices of the embedded fields implicitly
// traversed to get from (the type of) x to f, starting at embedding depth 0.
func (s *Selection) Index() []int { return s.index }
// Indirect reports whether any pointer indirection was required to get from
// x to f in x.f.
func (s *Selection) Indirect() bool { return s.indirect }
func (s *Selection) String() string { return SelectionString(s, nil) }
// SelectionString returns the string form of s.
// The Qualifier controls the printing of
// package-level objects, and may be nil.
//
// Examples:
//
// "field (T) f int"
// "method (T) f(X) Y"
// "method expr (T) f(X) Y"
func SelectionString(s *Selection, qf Qualifier) string {
var k string
switch s.kind {
case FieldVal:
k = "field "
case MethodVal:
k = "method "
case MethodExpr:
k = "method expr "
default:
unreachable()
}
var buf bytes.Buffer
buf.WriteString(k)
buf.WriteByte('(')
WriteType(&buf, s.Recv(), qf)
fmt.Fprintf(&buf, ") %s", s.obj.Name())
if T := s.Type(); s.kind == FieldVal {
buf.WriteByte(' ')
WriteType(&buf, T, qf)
} else {
WriteSignature(&buf, T.(*Signature), qf)
}
return buf.String()
}
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package types
import (
"fmt"
"go/ast"
. "internal/types/errors"
)
// ----------------------------------------------------------------------------
// API
// A Signature represents a (non-builtin) function or method type.
// The receiver is ignored when comparing signatures for identity.
type Signature struct {
// We need to keep the scope in Signature (rather than passing it around
// and store it in the Func Object) because when type-checking a function
// literal we call the general type checker which returns a general Type.
// We then unpack the *Signature and use the scope for the literal body.
rparams *TypeParamList // receiver type parameters from left to right, or nil
tparams *TypeParamList // type parameters from left to right, or nil
scope *Scope // function scope for package-local and non-instantiated signatures; nil otherwise
recv *Var // nil if not a method
params *Tuple // (incoming) parameters from left to right; or nil
results *Tuple // (outgoing) results from left to right; or nil
variadic bool // true if the last parameter's type is of the form ...T (or string, for append built-in only)
}
// NewSignature returns a new function type for the given receiver, parameters,
// and results, either of which may be nil. If variadic is set, the function
// is variadic, it must have at least one parameter, and the last parameter
// must be of unnamed slice type.
//
// Deprecated: Use NewSignatureType instead which allows for type parameters.
func NewSignature(recv *Var, params, results *Tuple, variadic bool) *Signature {
return NewSignatureType(recv, nil, nil, params, results, variadic)
}
// NewSignatureType creates a new function type for the given receiver,
// receiver type parameters, type parameters, parameters, and results. If
// variadic is set, params must hold at least one parameter and the last
// parameter's core type must be of unnamed slice or bytestring type.
// If recv is non-nil, typeParams must be empty. If recvTypeParams is
// non-empty, recv must be non-nil.
func NewSignatureType(recv *Var, recvTypeParams, typeParams []*TypeParam, params, results *Tuple, variadic bool) *Signature {
if variadic {
n := params.Len()
if n == 0 {
panic("variadic function must have at least one parameter")
}
core := coreString(params.At(n - 1).typ)
if _, ok := core.(*Slice); !ok && !isString(core) {
panic(fmt.Sprintf("got %s, want variadic parameter with unnamed slice type or string as core type", core.String()))
}
}
sig := &Signature{recv: recv, params: params, results: results, variadic: variadic}
if len(recvTypeParams) != 0 {
if recv == nil {
panic("function with receiver type parameters must have a receiver")
}
sig.rparams = bindTParams(recvTypeParams)
}
if len(typeParams) != 0 {
if recv != nil {
panic("function with type parameters cannot have a receiver")
}
sig.tparams = bindTParams(typeParams)
}
return sig
}
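// A minimal construction sketch, as seen from outside the package
// (hypothetical):
//
//	xs := types.NewParam(token.NoPos, nil, "xs", types.NewSlice(types.Typ[types.Int]))
//	sig := types.NewSignatureType(nil, nil, nil, types.NewTuple(xs), nil, true)
//	_ = sig.String() // "func(xs ...int)"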
// Recv returns the receiver of signature s (if a method), or nil if a
// function. It is ignored when comparing signatures for identity.
//
// For an abstract method, Recv returns the enclosing interface either
// as a *Named or an *Interface. Due to embedding, an interface may
// contain methods whose receiver type is a different interface.
func (s *Signature) Recv() *Var { return s.recv }
// TypeParams returns the type parameters of signature s, or nil.
func (s *Signature) TypeParams() *TypeParamList { return s.tparams }
// RecvTypeParams returns the receiver type parameters of signature s, or nil.
func (s *Signature) RecvTypeParams() *TypeParamList { return s.rparams }
// Params returns the parameters of signature s, or nil.
func (s *Signature) Params() *Tuple { return s.params }
// Results returns the results of signature s, or nil.
func (s *Signature) Results() *Tuple { return s.results }
// Variadic reports whether the signature s is variadic.
func (s *Signature) Variadic() bool { return s.variadic }
func (t *Signature) Underlying() Type { return t }
func (t *Signature) String() string { return TypeString(t, nil) }
// ----------------------------------------------------------------------------
// Implementation
// funcType type-checks a function or method type.
func (check *Checker) funcType(sig *Signature, recvPar *ast.FieldList, ftyp *ast.FuncType) {
check.openScope(ftyp, "function")
check.scope.isFunc = true
check.recordScope(ftyp, check.scope)
sig.scope = check.scope
defer check.closeScope()
if recvPar != nil && len(recvPar.List) > 0 {
// collect generic receiver type parameters, if any
// - a receiver type parameter is like any other type parameter, except that it is declared implicitly
// - the receiver specification acts as a local declaration for its type parameters, which may be blank
_, rname, rparams := check.unpackRecv(recvPar.List[0].Type, true)
if len(rparams) > 0 {
tparams := check.declareTypeParams(nil, rparams)
sig.rparams = bindTParams(tparams)
// Blank identifiers don't get declared, so naive type-checking of the
// receiver type expression would fail in Checker.collectParams below,
// when Checker.ident cannot resolve the _ to a type.
//
// Checker.recvTParamMap maps these blank identifiers to their type parameter
// types, so that they may be resolved in Checker.ident when they fail
// lookup in the scope.
for i, p := range rparams {
if p.Name == "_" {
if check.recvTParamMap == nil {
check.recvTParamMap = make(map[*ast.Ident]*TypeParam)
}
check.recvTParamMap[p] = tparams[i]
}
}
// determine receiver type to get its type parameters
// and the respective type parameter bounds
var recvTParams []*TypeParam
if rname != nil {
// recv should be a Named type (otherwise an error is reported elsewhere)
// Also: Don't report an error via genericType since it will be reported
// again when we type-check the signature.
// TODO(gri) maybe the receiver should be marked as invalid instead?
if recv, _ := check.genericType(rname, nil).(*Named); recv != nil {
recvTParams = recv.TypeParams().list()
}
}
// provide type parameter bounds
if len(tparams) == len(recvTParams) {
smap := makeRenameMap(recvTParams, tparams)
for i, tpar := range tparams {
recvTPar := recvTParams[i]
check.mono.recordCanon(tpar, recvTPar)
// recvTPar.bound is (possibly) parameterized in the context of the
// receiver type declaration. Substitute parameters for the current
// context.
tpar.bound = check.subst(tpar.obj.pos, recvTPar.bound, smap, nil, check.context())
}
} else if len(tparams) < len(recvTParams) {
// Reporting an error here is a stop-gap measure to avoid crashes in the
// compiler when a type parameter/argument cannot be inferred later. It
// may lead to follow-on errors (see issues go.dev/issue/51339, go.dev/issue/51343).
// TODO(gri) find a better solution
got := measure(len(tparams), "type parameter")
check.errorf(recvPar, BadRecv, "got %s, but receiver base type declares %d", got, len(recvTParams))
}
}
}
if ftyp.TypeParams != nil {
check.collectTypeParams(&sig.tparams, ftyp.TypeParams)
// Always type-check method type parameters but complain that they are not allowed.
// (A separate check is needed when type-checking interface method signatures because
// they don't have a receiver specification.)
if recvPar != nil {
check.error(ftyp.TypeParams, InvalidMethodTypeParams, "methods cannot have type parameters")
}
}
// Value (non-type) parameters' scope starts in the function body. Use a temporary scope for their
// declarations and then squash that scope into the parent scope (and report any redeclarations at
// that time).
scope := NewScope(check.scope, nopos, nopos, "function body (temp. scope)")
recvList, _ := check.collectParams(scope, recvPar, false)
params, variadic := check.collectParams(scope, ftyp.Params, true)
results, _ := check.collectParams(scope, ftyp.Results, false)
scope.squash(func(obj, alt Object) {
check.errorf(obj, DuplicateDecl, "%s redeclared in this block", obj.Name())
check.reportAltDecl(alt)
})
if recvPar != nil {
// recv parameter list present (may be empty)
// spec: "The receiver is specified via an extra parameter section preceding the
// method name. That parameter section must declare a single parameter, the receiver."
var recv *Var
switch len(recvList) {
case 0:
// error reported by resolver
recv = NewParam(nopos, nil, "", Typ[Invalid]) // ignore recv below
default:
// more than one receiver
check.error(recvList[len(recvList)-1], InvalidRecv, "method has multiple receivers")
fallthrough // continue with first receiver
case 1:
recv = recvList[0]
}
sig.recv = recv
// Delay validation of receiver type as it may cause premature expansion
// of types the receiver type is dependent on (see issues go.dev/issue/51232, go.dev/issue/51233).
check.later(func() {
// spec: "The receiver type must be of the form T or *T where T is a type name."
rtyp, _ := deref(recv.typ)
if rtyp == Typ[Invalid] {
return // error was reported before
}
// spec: "The type denoted by T is called the receiver base type; it must not
// be a pointer or interface type and it must be declared in the same package
// as the method."
switch T := rtyp.(type) {
case *Named:
// The receiver type may be an instantiated type referred to
// by an alias (which cannot have receiver parameters for now).
if T.TypeArgs() != nil && sig.RecvTypeParams() == nil {
check.errorf(recv, InvalidRecv, "cannot define new methods on instantiated type %s", rtyp)
break
}
if T.obj.pkg != check.pkg {
check.errorf(recv, InvalidRecv, "cannot define new methods on non-local type %s", rtyp)
break
}
var cause string
switch u := T.under().(type) {
case *Basic:
// unsafe.Pointer is treated like a regular pointer
if u.kind == UnsafePointer {
cause = "unsafe.Pointer"
}
case *Pointer, *Interface:
cause = "pointer or interface type"
case *TypeParam:
// The underlying type of a receiver base type cannot be a
// type parameter: "type T[P any] P" is not a valid declaration.
unreachable()
}
if cause != "" {
check.errorf(recv, InvalidRecv, "invalid receiver type %s (%s)", rtyp, cause)
}
case *Basic:
check.errorf(recv, InvalidRecv, "cannot define new methods on non-local type %s", rtyp)
default:
check.errorf(recv, InvalidRecv, "invalid receiver type %s", recv.typ)
}
}).describef(recv, "validate receiver %s", recv)
}
sig.params = NewTuple(params...)
sig.results = NewTuple(results...)
sig.variadic = variadic
}
// collectParams declares the parameters of list in scope and returns the corresponding
// variable list.
func (check *Checker) collectParams(scope *Scope, list *ast.FieldList, variadicOk bool) (params []*Var, variadic bool) {
if list == nil {
return
}
var named, anonymous bool
for i, field := range list.List {
ftype := field.Type
if t, _ := ftype.(*ast.Ellipsis); t != nil {
ftype = t.Elt
if variadicOk && i == len(list.List)-1 && len(field.Names) <= 1 {
variadic = true
} else {
check.softErrorf(t, MisplacedDotDotDot, "can only use ... with final parameter in list")
// ignore ... and continue
}
}
typ := check.varType(ftype)
// The parser ensures that field.Tag is nil and we don't
// care if a constructed AST contains a non-nil tag.
if len(field.Names) > 0 {
// named parameter
for _, name := range field.Names {
if name.Name == "" {
check.error(name, InvalidSyntaxTree, "anonymous parameter")
// ok to continue
}
par := NewParam(name.Pos(), check.pkg, name.Name, typ)
check.declare(scope, name, par, scope.pos)
params = append(params, par)
}
named = true
} else {
// anonymous parameter
par := NewParam(ftype.Pos(), check.pkg, "", typ)
check.recordImplicit(field, par)
params = append(params, par)
anonymous = true
}
}
if named && anonymous {
check.error(list, InvalidSyntaxTree, "list contains both named and anonymous parameters")
// ok to continue
}
// For a variadic function, change the last parameter's type from T to []T.
// Since we type-checked T rather than ...T, we also need to retroactively
// record the type for ...T.
if variadic {
last := params[len(params)-1]
last.typ = &Slice{elem: last.typ}
check.recordTypeAndValue(list.List[len(list.List)-1].Type, typexpr, last.typ, nil)
}
return
}
// Code generated by "go test -run=Generate -write=all"; DO NOT EDIT.
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements Sizes.
package types
// Sizes defines the sizing functions for package unsafe.
type Sizes interface {
// Alignof returns the alignment of a variable of type T.
// Alignof must implement the alignment guarantees required by the spec.
Alignof(T Type) int64
// Offsetsof returns the offsets of the given struct fields, in bytes.
// Offsetsof must implement the offset guarantees required by the spec.
Offsetsof(fields []*Var) []int64
// Sizeof returns the size of a variable of type T.
// Sizeof must implement the size guarantees required by the spec.
Sizeof(T Type) int64
}
// StdSizes is a convenience type for creating commonly used Sizes.
// It makes the following simplifying assumptions:
//
// - The size of explicitly sized basic types (int16, etc.) is the
// specified size.
// - The size of strings and interfaces is 2*WordSize.
// - The size of slices is 3*WordSize.
// - The size of an array of n elements corresponds to the size of
// a struct of n consecutive fields of the array's element type.
// - The size of a struct is the offset of the last field plus that
// field's size. As with all element types, if the struct is used
// in an array its size must first be aligned to a multiple of the
// struct's alignment.
// - All other types have size WordSize.
// - Arrays and structs are aligned per spec definition; all other
// types are naturally aligned with a maximum alignment MaxAlign.
//
// *StdSizes implements Sizes.
type StdSizes struct {
WordSize int64 // word size in bytes - must be >= 4 (32bits)
MaxAlign int64 // maximum alignment in bytes - must be >= 1
}
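// A worked sketch under the assumptions above, using 64-bit words
// (hypothetical values):
//
//	s := &StdSizes{WordSize: 8, MaxAlign: 8}
//	s.Sizeof(Typ[Int32])          // 4: explicitly sized basic type
//	s.Sizeof(Typ[String])         // 16: 2*WordSize
//	s.Sizeof(NewSlice(Typ[Byte])) // 24: 3*WordSize
//	s.Alignof(Typ[Complex128])    // 8: aligned like [2]float64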
func (s *StdSizes) Alignof(T Type) int64 {
// For arrays and structs, alignment is defined in terms
// of alignment of the elements and fields, respectively.
switch t := under(T).(type) {
case *Array:
// spec: "For a variable x of array type: unsafe.Alignof(x)
// is the same as unsafe.Alignof(x[0]), but at least 1."
return s.Alignof(t.elem)
case *Struct:
if len(t.fields) == 0 && _IsSyncAtomicAlign64(T) {
// Special case: sync/atomic.align64 is an
// empty struct we recognize as a signal that
// the struct it contains must be
// 64-bit-aligned.
//
// This logic is equivalent to the logic in
// cmd/compile/internal/types/size.go:calcStructOffset
return 8
}
// spec: "For a variable x of struct type: unsafe.Alignof(x)
// is the largest of the values unsafe.Alignof(x.f) for each
// field f of x, but at least 1."
max := int64(1)
for _, f := range t.fields {
if a := s.Alignof(f.typ); a > max {
max = a
}
}
return max
case *Slice, *Interface:
// Multiword data structures are effectively structs
// in which each element has size WordSize.
// Type parameters lead to variable sizes/alignments;
// StdSizes.Alignof won't be called for them.
assert(!isTypeParam(T))
return s.WordSize
case *Basic:
// Strings are like slices and interfaces.
if t.Info()&IsString != 0 {
return s.WordSize
}
case *TypeParam, *Union:
unreachable()
}
a := s.Sizeof(T) // may be 0
// spec: "For a variable x of any type: unsafe.Alignof(x) is at least 1."
if a < 1 {
return 1
}
// complex{64,128} are aligned like [2]float{32,64}.
if isComplex(T) {
a /= 2
}
if a > s.MaxAlign {
return s.MaxAlign
}
return a
}
func _IsSyncAtomicAlign64(T Type) bool {
named, ok := T.(*Named)
if !ok {
return false
}
obj := named.Obj()
return obj.Name() == "align64" &&
obj.Pkg() != nil &&
(obj.Pkg().Path() == "sync/atomic" ||
obj.Pkg().Path() == "runtime/internal/atomic")
}
func (s *StdSizes) Offsetsof(fields []*Var) []int64 {
offsets := make([]int64, len(fields))
var o int64
for i, f := range fields {
a := s.Alignof(f.typ)
o = align(o, a)
offsets[i] = o
o += s.Sizeof(f.typ)
}
return offsets
}
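// A worked example (assuming StdSizes{WordSize: 8, MaxAlign: 8}): for the
// fields of
//
//	struct {
//		a byte  // size 1, align 1 -> offset 0
//		b int32 // size 4, align 4 -> offset align(1, 4) = 4
//		c int64 // size 8, align 8 -> offset align(8, 8) = 8
//	}
//
// Offsetsof returns [0, 4, 8].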
var basicSizes = [...]byte{
Bool: 1,
Int8: 1,
Int16: 2,
Int32: 4,
Int64: 8,
Uint8: 1,
Uint16: 2,
Uint32: 4,
Uint64: 8,
Float32: 4,
Float64: 8,
Complex64: 8,
Complex128: 16,
}
func (s *StdSizes) Sizeof(T Type) int64 {
switch t := under(T).(type) {
case *Basic:
assert(isTyped(T))
k := t.kind
if int(k) < len(basicSizes) {
if s := basicSizes[k]; s > 0 {
return int64(s)
}
}
if k == String {
return s.WordSize * 2
}
case *Array:
n := t.len
if n <= 0 {
return 0
}
// n > 0
a := s.Alignof(t.elem)
z := s.Sizeof(t.elem)
return align(z, a)*(n-1) + z
case *Slice:
return s.WordSize * 3
case *Struct:
n := t.NumFields()
if n == 0 {
return 0
}
offsets := s.Offsetsof(t.fields)
return offsets[n-1] + s.Sizeof(t.fields[n-1].typ)
case *Interface:
// Type parameters lead to variable sizes/alignments;
// StdSizes.Sizeof won't be called for them.
assert(!isTypeParam(T))
return s.WordSize * 2
case *TypeParam, *Union:
unreachable()
}
return s.WordSize // catch-all
}
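// The array formula align(z, a)*(n-1) + z above does not pad the last
// element to the element alignment. A worked example: for [3]int32, with
// z = 4 and a = 4, the size is 4*2 + 4 = 12 bytes.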
// common architecture word sizes and alignments
var gcArchSizes = map[string]*StdSizes{
"386": {4, 4},
"amd64": {8, 8},
"amd64p32": {4, 8},
"arm": {4, 4},
"arm64": {8, 8},
"loong64": {8, 8},
"mips": {4, 4},
"mipsle": {4, 4},
"mips64": {8, 8},
"mips64le": {8, 8},
"ppc64": {8, 8},
"ppc64le": {8, 8},
"riscv64": {8, 8},
"s390x": {8, 8},
"sparc64": {8, 8},
"wasm": {8, 8},
// When adding more architectures here,
// update the doc string of SizesFor below.
}
// SizesFor returns the Sizes used by a compiler for an architecture.
// The result is nil if a compiler/architecture pair is not known.
//
// Supported architectures for compiler "gc":
// "386", "amd64", "amd64p32", "arm", "arm64", "loong64", "mips", "mipsle",
// "mips64", "mips64le", "ppc64", "ppc64le", "riscv64", "s390x", "sparc64", "wasm".
func SizesFor(compiler, arch string) Sizes {
var m map[string]*StdSizes
switch compiler {
case "gc":
m = gcArchSizes
case "gccgo":
m = gccgoArchSizes
default:
return nil
}
s, ok := m[arch]
if !ok {
return nil
}
return s
}
// stdSizes is used if Config.Sizes == nil.
var stdSizes = SizesFor("gc", "amd64")
func (conf *Config) alignof(T Type) int64 {
if s := conf.Sizes; s != nil {
if a := s.Alignof(T); a >= 1 {
return a
}
panic("Config.Sizes.Alignof returned an alignment < 1")
}
return stdSizes.Alignof(T)
}
func (conf *Config) offsetsof(T *Struct) []int64 {
var offsets []int64
if T.NumFields() > 0 {
// compute offsets on demand
if s := conf.Sizes; s != nil {
offsets = s.Offsetsof(T.fields)
// sanity checks
if len(offsets) != T.NumFields() {
panic("Config.Sizes.Offsetsof returned the wrong number of offsets")
}
for _, o := range offsets {
if o < 0 {
panic("Config.Sizes.Offsetsof returned an offset < 0")
}
}
} else {
offsets = stdSizes.Offsetsof(T.fields)
}
}
return offsets
}
// offsetof returns the offset of the field specified via
// the index sequence relative to typ. All embedded fields
// must be structs (rather than pointer to structs).
func (conf *Config) offsetof(typ Type, index []int) int64 {
var o int64
for _, i := range index {
s := under(typ).(*Struct)
o += conf.offsetsof(s)[i]
typ = s.fields[i].typ
}
return o
}
func (conf *Config) sizeof(T Type) int64 {
if s := conf.Sizes; s != nil {
if z := s.Sizeof(T); z >= 0 {
return z
}
panic("Config.Sizes.Sizeof returned a size < 0")
}
return stdSizes.Sizeof(T)
}
// align returns the smallest y >= x such that y % a == 0.
func align(x, a int64) int64 {
y := x + a - 1
return y - y%a
}
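// Worked examples: align(5, 4) = 8 (y = 8, and 8%4 == 0), align(8, 4) = 8
// (already aligned), and align(0, 8) = 0.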
// Code generated by "go test -run=Generate -write=all"; DO NOT EDIT.
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package types
// A Slice represents a slice type.
type Slice struct {
elem Type
}
// NewSlice returns a new slice type for the given element type.
func NewSlice(elem Type) *Slice { return &Slice{elem: elem} }
// Elem returns the element type of slice s.
func (s *Slice) Elem() Type { return s.elem }
func (s *Slice) Underlying() Type { return s }
func (s *Slice) String() string { return TypeString(s, nil) }
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements typechecking of statements.
package types
import (
"go/ast"
"go/constant"
"go/token"
. "internal/types/errors"
"sort"
)
func (check *Checker) funcBody(decl *declInfo, name string, sig *Signature, body *ast.BlockStmt, iota constant.Value) {
if check.conf.IgnoreFuncBodies {
panic("function body not ignored")
}
if check.conf._Trace {
check.trace(body.Pos(), "-- %s: %s", name, sig)
}
// set function scope extent
sig.scope.pos = body.Pos()
sig.scope.end = body.End()
// save/restore current environment and set up function environment
// (and use 0 indentation at function start)
defer func(env environment, indent int) {
check.environment = env
check.indent = indent
}(check.environment, check.indent)
check.environment = environment{
decl: decl,
scope: sig.scope,
iota: iota,
sig: sig,
}
check.indent = 0
check.stmtList(0, body.List)
if check.hasLabel {
check.labels(body)
}
if sig.results.Len() > 0 && !check.isTerminating(body, "") {
check.error(atPos(body.Rbrace), MissingReturn, "missing return")
}
// spec: "Implementation restriction: A compiler may make it illegal to
// declare a variable inside a function body if the variable is never used."
check.usage(sig.scope)
}
func (check *Checker) usage(scope *Scope) {
var unused []*Var
for name, elem := range scope.elems {
elem = resolve(name, elem)
if v, _ := elem.(*Var); v != nil && !v.used {
unused = append(unused, v)
}
}
sort.Slice(unused, func(i, j int) bool {
return cmpPos(unused[i].pos, unused[j].pos) < 0
})
for _, v := range unused {
check.softErrorf(v, UnusedVar, "%s declared and not used", v.name)
}
for _, scope := range scope.children {
// Don't go inside function literal scopes a second time;
// they are handled explicitly by funcBody.
if !scope.isFunc {
check.usage(scope)
}
}
}
// stmtContext is a bitset describing which
// control-flow statements are permissible,
// and provides additional context information
// for better error messages.
type stmtContext uint
const (
// permissible control-flow statements
breakOk stmtContext = 1 << iota
continueOk
fallthroughOk
// additional context information
finalSwitchCase
inTypeSwitch
)
func (check *Checker) simpleStmt(s ast.Stmt) {
if s != nil {
check.stmt(0, s)
}
}
func trimTrailingEmptyStmts(list []ast.Stmt) []ast.Stmt {
for i := len(list); i > 0; i-- {
if _, ok := list[i-1].(*ast.EmptyStmt); !ok {
return list[:i]
}
}
return nil
}
func (check *Checker) stmtList(ctxt stmtContext, list []ast.Stmt) {
ok := ctxt&fallthroughOk != 0
inner := ctxt &^ fallthroughOk
list = trimTrailingEmptyStmts(list) // trailing empty statements are "invisible" to fallthrough analysis
for i, s := range list {
inner := inner
if ok && i+1 == len(list) {
inner |= fallthroughOk
}
check.stmt(inner, s)
}
}
func (check *Checker) multipleDefaults(list []ast.Stmt) {
var first ast.Stmt
for _, s := range list {
var d ast.Stmt
switch c := s.(type) {
case *ast.CaseClause:
if len(c.List) == 0 {
d = s
}
case *ast.CommClause:
if c.Comm == nil {
d = s
}
default:
check.error(s, InvalidSyntaxTree, "case/communication clause expected")
}
if d != nil {
if first != nil {
check.errorf(d, DuplicateDefault, "multiple defaults (first at %s)", check.fset.Position(first.Pos()))
} else {
first = d
}
}
}
}
func (check *Checker) openScope(node ast.Node, comment string) {
scope := NewScope(check.scope, node.Pos(), node.End(), comment)
check.recordScope(node, scope)
check.scope = scope
}
func (check *Checker) closeScope() {
check.scope = check.scope.Parent()
}
func assignOp(op token.Token) token.Token {
// token_test.go verifies the token ordering this function relies on
if token.ADD_ASSIGN <= op && op <= token.AND_NOT_ASSIGN {
return op + (token.ADD - token.ADD_ASSIGN)
}
return token.ILLEGAL
}
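// The token constants for the assignment operations parallel those for the
// corresponding binary operations, so a constant offset maps one onto the
// other. For example, assignOp(token.ADD_ASSIGN) == token.ADD (+= maps to +),
// while assignOp(token.DEFINE) == token.ILLEGAL.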
func (check *Checker) suspendedCall(keyword string, call *ast.CallExpr) {
var x operand
var msg string
var code Code
switch check.rawExpr(&x, call, nil, false) {
case conversion:
msg = "requires function call, not conversion"
code = InvalidDefer
if keyword == "go" {
code = InvalidGo
}
case expression:
msg = "discards result of"
code = UnusedResults
case statement:
return
default:
unreachable()
}
check.errorf(&x, code, "%s %s %s", keyword, msg, &x)
}
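// Hypothetical inputs and the resulting diagnostics (a sketch, derived from
// the message strings above):
//
//	go f()       // ok: an ordinary function call
//	go int(x)    // error: go requires function call, not conversion
//	defer len(s) // error: defer discards result of len(s)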
// goVal returns the Go value for val, or nil.
func goVal(val constant.Value) any {
// val should exist, but be conservative and check
if val == nil {
return nil
}
// Match implementation restriction of other compilers.
// gc only checks duplicates for integer, floating-point
// and string values, so only create Go values for these
// types.
switch val.Kind() {
case constant.Int:
if x, ok := constant.Int64Val(val); ok {
return x
}
if x, ok := constant.Uint64Val(val); ok {
return x
}
case constant.Float:
if x, ok := constant.Float64Val(val); ok {
return x
}
case constant.String:
return constant.StringVal(val)
}
return nil
}
// A valueMap maps a case value (of a basic Go type) to a list of positions
// where the same case value appeared, together with the corresponding case
// types.
// Since two case values may have the same "underlying" value but different
// types we need to also check the value's types (e.g., byte(1) vs myByte(1))
// when the switch expression is of interface type.
type (
valueMap map[any][]valueType // underlying Go value -> valueType
valueType struct {
pos token.Pos
typ Type
}
)
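// A sketch of the distinction (hypothetical): in a switch over an interface
// value, the two cases below carry the same underlying Go value 1 but have
// different types, so they are not reported as duplicates:
//
//	type myByte byte
//
//	switch x := any(byte(1)); x {
//	case byte(1):
//	case myByte(1): // same value, different type: not a duplicate
//	}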
func (check *Checker) caseValues(x *operand, values []ast.Expr, seen valueMap) {
L:
for _, e := range values {
var v operand
check.expr(&v, e)
if x.mode == invalid || v.mode == invalid {
continue L
}
check.convertUntyped(&v, x.typ)
if v.mode == invalid {
continue L
}
// Order matters: By comparing v against x, error positions are at the case values.
res := v // keep original v unchanged
check.comparison(&res, x, token.EQL, true)
if res.mode == invalid {
continue L
}
if v.mode != constant_ {
continue L // we're done
}
// look for duplicate values
if val := goVal(v.val); val != nil {
// look for duplicate types for a given value
// (quadratic algorithm, but these lists tend to be very short)
for _, vt := range seen[val] {
if Identical(v.typ, vt.typ) {
check.errorf(&v, DuplicateCase, "duplicate case %s in expression switch", &v)
check.error(atPos(vt.pos), DuplicateCase, "\tprevious case") // secondary error, \t indented
continue L
}
}
seen[val] = append(seen[val], valueType{v.Pos(), v.typ})
}
}
}
// isNil reports whether the expression e denotes the predeclared value nil.
func (check *Checker) isNil(e ast.Expr) bool {
// The only way to express the nil value is by literally writing nil (possibly in parentheses).
if name, _ := unparen(e).(*ast.Ident); name != nil {
_, ok := check.lookup(name.Name).(*Nil)
return ok
}
return false
}
// caseTypes typechecks the type expressions of a type switch's cases against
// the switch expression x. If the type switch expression is invalid, x is nil.
func (check *Checker) caseTypes(x *operand, types []ast.Expr, seen map[Type]ast.Expr) (T Type) {
var dummy operand
L:
for _, e := range types {
// The spec allows the value nil instead of a type.
if check.isNil(e) {
T = nil
check.expr(&dummy, e) // run e through expr so we get the usual Info recordings
} else {
T = check.varType(e)
if T == Typ[Invalid] {
continue L
}
}
// look for duplicate types
// (quadratic algorithm, but type switches tend to be reasonably small)
for t, other := range seen {
if T == nil && t == nil || T != nil && t != nil && Identical(T, t) {
// talk about "case" rather than "type" because of nil case
Ts := "nil"
if T != nil {
Ts = TypeString(T, check.qualifier)
}
check.errorf(e, DuplicateCase, "duplicate case %s in type switch", Ts)
check.error(other, DuplicateCase, "\tprevious case") // secondary error, \t indented
continue L
}
}
seen[T] = e
if x != nil && T != nil {
check.typeAssertion(e, x, T, true)
}
}
return
}
// TODO(gri) Once we are certain that typeHash is correct in all situations, use this version of caseTypes instead.
// (Currently it may be possible that different types have identical names and import paths due to ImporterFrom.)
//
// func (check *Checker) caseTypes(x *operand, xtyp *Interface, types []ast.Expr, seen map[string]ast.Expr) (T Type) {
// var dummy operand
// L:
// for _, e := range types {
// // The spec allows the value nil instead of a type.
// var hash string
// if check.isNil(e) {
// check.expr(&dummy, e) // run e through expr so we get the usual Info recordings
// T = nil
// hash = "<nil>" // avoid collision with a type named nil
// } else {
// T = check.varType(e)
// if T == Typ[Invalid] {
// continue L
// }
// hash = typeHash(T, nil)
// }
// // look for duplicate types
// if other := seen[hash]; other != nil {
// // talk about "case" rather than "type" because of nil case
// Ts := "nil"
// if T != nil {
// Ts = TypeString(T, check.qualifier)
// }
// var err error_
// err.code = DuplicateCase
// err.errorf(e, "duplicate case %s in type switch", Ts)
// err.errorf(other, "previous case")
// check.report(&err)
// continue L
// }
// seen[hash] = e
// if T != nil {
// check.typeAssertion(e.Pos(), x, xtyp, T)
// }
// }
// return
// }
// stmt typechecks statement s.
func (check *Checker) stmt(ctxt stmtContext, s ast.Stmt) {
// statements must end with the same top scope as they started with
if debug {
defer func(scope *Scope) {
// don't check if code is panicking
if p := recover(); p != nil {
panic(p)
}
assert(scope == check.scope)
}(check.scope)
}
// process collected function literals before scope changes
defer check.processDelayed(len(check.delayed))
// reset context for statements of inner blocks
inner := ctxt &^ (fallthroughOk | finalSwitchCase | inTypeSwitch)
switch s := s.(type) {
case *ast.BadStmt, *ast.EmptyStmt:
// ignore
case *ast.DeclStmt:
check.declStmt(s.Decl)
case *ast.LabeledStmt:
check.hasLabel = true
check.stmt(ctxt, s.Stmt)
case *ast.ExprStmt:
// spec: "With the exception of specific built-in functions,
// function and method calls and receive operations can appear
// in statement context. Such statements may be parenthesized."
var x operand
kind := check.rawExpr(&x, s.X, nil, false)
var msg string
var code Code
switch x.mode {
default:
if kind == statement {
return
}
msg = "is not used"
code = UnusedExpr
case builtin:
msg = "must be called"
code = UncalledBuiltin
case typexpr:
msg = "is not an expression"
code = NotAnExpr
}
check.errorf(&x, code, "%s %s", &x, msg)
case *ast.SendStmt:
var ch, val operand
check.expr(&ch, s.Chan)
check.expr(&val, s.Value)
if ch.mode == invalid || val.mode == invalid {
return
}
u := coreType(ch.typ)
if u == nil {
check.errorf(inNode(s, s.Arrow), InvalidSend, invalidOp+"cannot send to %s: no core type", &ch)
return
}
uch, _ := u.(*Chan)
if uch == nil {
check.errorf(inNode(s, s.Arrow), InvalidSend, invalidOp+"cannot send to non-channel %s", &ch)
return
}
if uch.dir == RecvOnly {
check.errorf(inNode(s, s.Arrow), InvalidSend, invalidOp+"cannot send to receive-only channel %s", &ch)
return
}
check.assignment(&val, uch.elem, "send")
case *ast.IncDecStmt:
var op token.Token
switch s.Tok {
case token.INC:
op = token.ADD
case token.DEC:
op = token.SUB
default:
check.errorf(inNode(s, s.TokPos), InvalidSyntaxTree, "unknown inc/dec operation %s", s.Tok)
return
}
var x operand
check.expr(&x, s.X)
if x.mode == invalid {
return
}
if !allNumeric(x.typ) {
check.errorf(s.X, NonNumericIncDec, invalidOp+"%s%s (non-numeric type %s)", s.X, s.Tok, x.typ)
return
}
Y := &ast.BasicLit{ValuePos: s.X.Pos(), Kind: token.INT, Value: "1"} // use x's position
check.binary(&x, nil, s.X, Y, op, s.TokPos)
if x.mode == invalid {
return
}
check.assignVar(s.X, &x)
case *ast.AssignStmt:
switch s.Tok {
case token.ASSIGN, token.DEFINE:
if len(s.Lhs) == 0 {
check.error(s, InvalidSyntaxTree, "missing lhs in assignment")
return
}
if s.Tok == token.DEFINE {
check.shortVarDecl(inNode(s, s.TokPos), s.Lhs, s.Rhs)
} else {
// regular assignment
check.assignVars(s.Lhs, s.Rhs)
}
default:
// assignment operations
if len(s.Lhs) != 1 || len(s.Rhs) != 1 {
check.errorf(inNode(s, s.TokPos), MultiValAssignOp, "assignment operation %s requires single-valued expressions", s.Tok)
return
}
op := assignOp(s.Tok)
if op == token.ILLEGAL {
check.errorf(atPos(s.TokPos), InvalidSyntaxTree, "unknown assignment operation %s", s.Tok)
return
}
var x operand
check.binary(&x, nil, s.Lhs[0], s.Rhs[0], op, s.TokPos)
if x.mode == invalid {
return
}
check.assignVar(s.Lhs[0], &x)
}
case *ast.GoStmt:
check.suspendedCall("go", s.Call)
case *ast.DeferStmt:
check.suspendedCall("defer", s.Call)
case *ast.ReturnStmt:
res := check.sig.results
// A return with implicit results is allowed for functions with named results.
// (If one result is named, all are named.)
if len(s.Results) == 0 && res.Len() > 0 && res.vars[0].name != "" {
// spec: "Implementation restriction: A compiler may disallow an empty expression
// list in a "return" statement if a different entity (constant, type, or variable)
// with the same name as a result parameter is in scope at the place of the return."
for _, obj := range res.vars {
if alt := check.lookup(obj.name); alt != nil && alt != obj {
check.errorf(s, OutOfScopeResult, "result parameter %s not in scope at return", obj.name)
check.errorf(alt, OutOfScopeResult, "\tinner declaration of %s", obj)
// ok to continue
}
}
} else {
var lhs []*Var
if res.Len() > 0 {
lhs = res.vars
}
check.initVars(lhs, s.Results, s)
}
case *ast.BranchStmt:
if s.Label != nil {
check.hasLabel = true
return // checked in 2nd pass (check.labels)
}
switch s.Tok {
case token.BREAK:
if ctxt&breakOk == 0 {
check.error(s, MisplacedBreak, "break not in for, switch, or select statement")
}
case token.CONTINUE:
if ctxt&continueOk == 0 {
check.error(s, MisplacedContinue, "continue not in for statement")
}
case token.FALLTHROUGH:
if ctxt&fallthroughOk == 0 {
var msg string
switch {
case ctxt&finalSwitchCase != 0:
msg = "cannot fallthrough final case in switch"
case ctxt&inTypeSwitch != 0:
msg = "cannot fallthrough in type switch"
default:
msg = "fallthrough statement out of place"
}
check.error(s, MisplacedFallthrough, msg)
}
default:
check.errorf(s, InvalidSyntaxTree, "branch statement: %s", s.Tok)
}
case *ast.BlockStmt:
check.openScope(s, "block")
defer check.closeScope()
check.stmtList(inner, s.List)
case *ast.IfStmt:
check.openScope(s, "if")
defer check.closeScope()
check.simpleStmt(s.Init)
var x operand
check.expr(&x, s.Cond)
if x.mode != invalid && !allBoolean(x.typ) {
check.error(s.Cond, InvalidCond, "non-boolean condition in if statement")
}
check.stmt(inner, s.Body)
// The parser produces a correct AST but if it was modified
// elsewhere the else branch may be invalid. Check again.
switch s.Else.(type) {
case nil, *ast.BadStmt:
// valid or error already reported
case *ast.IfStmt, *ast.BlockStmt:
check.stmt(inner, s.Else)
default:
check.error(s.Else, InvalidSyntaxTree, "invalid else branch in if statement")
}
case *ast.SwitchStmt:
inner |= breakOk
check.openScope(s, "switch")
defer check.closeScope()
check.simpleStmt(s.Init)
var x operand
if s.Tag != nil {
check.expr(&x, s.Tag)
// By checking assignment of x to an invisible temporary
// (as a compiler would), we get all the relevant checks.
check.assignment(&x, nil, "switch expression")
if x.mode != invalid && !Comparable(x.typ) && !hasNil(x.typ) {
check.errorf(&x, InvalidExprSwitch, "cannot switch on %s (%s is not comparable)", &x, x.typ)
x.mode = invalid
}
} else {
// spec: "A missing switch expression is
// equivalent to the boolean value true."
x.mode = constant_
x.typ = Typ[Bool]
x.val = constant.MakeBool(true)
x.expr = &ast.Ident{NamePos: s.Body.Lbrace, Name: "true"}
}
check.multipleDefaults(s.Body.List)
seen := make(valueMap) // map of seen case values to positions and types
for i, c := range s.Body.List {
clause, _ := c.(*ast.CaseClause)
if clause == nil {
check.error(c, InvalidSyntaxTree, "incorrect expression switch case")
continue
}
check.caseValues(&x, clause.List, seen)
check.openScope(clause, "case")
inner := inner
if i+1 < len(s.Body.List) {
inner |= fallthroughOk
} else {
inner |= finalSwitchCase
}
check.stmtList(inner, clause.Body)
check.closeScope()
}
case *ast.TypeSwitchStmt:
inner |= breakOk | inTypeSwitch
check.openScope(s, "type switch")
defer check.closeScope()
check.simpleStmt(s.Init)
// A type switch guard must be of the form:
//
// TypeSwitchGuard = [ identifier ":=" ] PrimaryExpr "." "(" "type" ")" .
//
// The parser is checking syntactic correctness;
// remaining syntactic errors are considered AST errors here.
// TODO(gri) better factoring of error handling (invalid ASTs)
//
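// For example (hypothetical guards):
//
//	switch x.(type) {}      // ExprStmt guard, no identifier
//	switch v := x.(type) {} // AssignStmt guard; v is declared in each clause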
var lhs *ast.Ident // lhs identifier or nil
var rhs ast.Expr
switch guard := s.Assign.(type) {
case *ast.ExprStmt:
rhs = guard.X
case *ast.AssignStmt:
if len(guard.Lhs) != 1 || guard.Tok != token.DEFINE || len(guard.Rhs) != 1 {
check.error(s, InvalidSyntaxTree, "incorrect form of type switch guard")
return
}
lhs, _ = guard.Lhs[0].(*ast.Ident)
if lhs == nil {
check.error(s, InvalidSyntaxTree, "incorrect form of type switch guard")
return
}
if lhs.Name == "_" {
// _ := x.(type) is an invalid short variable declaration
check.softErrorf(lhs, NoNewVar, "no new variable on left side of :=")
lhs = nil // avoid declared and not used error below
} else {
check.recordDef(lhs, nil) // lhs variable is implicitly declared in each case clause
}
rhs = guard.Rhs[0]
default:
check.error(s, InvalidSyntaxTree, "incorrect form of type switch guard")
return
}
// rhs must be of the form: expr.(type) and expr must be an ordinary interface
expr, _ := rhs.(*ast.TypeAssertExpr)
if expr == nil || expr.Type != nil {
check.error(s, InvalidSyntaxTree, "incorrect form of type switch guard")
return
}
var x operand
check.expr(&x, expr.X)
if x.mode == invalid {
return
}
// TODO(gri) we may want to permit type switches on type parameter values at some point
var sx *operand // switch expression against which cases are compared; nil if invalid
if isTypeParam(x.typ) {
check.errorf(&x, InvalidTypeSwitch, "cannot use type switch on type parameter value %s", &x)
} else {
if _, ok := under(x.typ).(*Interface); ok {
sx = &x
} else {
check.errorf(&x, InvalidTypeSwitch, "%s is not an interface", &x)
}
}
check.multipleDefaults(s.Body.List)
var lhsVars []*Var // list of implicitly declared lhs variables
seen := make(map[Type]ast.Expr) // map of seen types to expressions (used for position information)
for _, s := range s.Body.List {
clause, _ := s.(*ast.CaseClause)
if clause == nil {
check.error(s, InvalidSyntaxTree, "incorrect type switch case")
continue
}
// Check each type in this type switch case.
T := check.caseTypes(sx, clause.List, seen)
check.openScope(clause, "case")
// If lhs exists, declare a corresponding variable in the case-local scope.
if lhs != nil {
// spec: "The TypeSwitchGuard may include a short variable declaration.
// When that form is used, the variable is declared at the beginning of
// the implicit block in each clause. In clauses with a case listing
// exactly one type, the variable has that type; otherwise, the variable
// has the type of the expression in the TypeSwitchGuard."
if len(clause.List) != 1 || T == nil {
T = x.typ
}
obj := NewVar(lhs.Pos(), check.pkg, lhs.Name, T)
scopePos := clause.Pos() + token.Pos(len("default")) // for default clause (len(List) == 0)
if n := len(clause.List); n > 0 {
scopePos = clause.List[n-1].End()
}
check.declare(check.scope, nil, obj, scopePos)
check.recordImplicit(clause, obj)
// For the "declared and not used" error, all lhs variables act as
// one; i.e., if any one of them is 'used', all of them are 'used'.
// Collect them for later analysis.
lhsVars = append(lhsVars, obj)
}
check.stmtList(inner, clause.Body)
check.closeScope()
}
// If lhs exists, we must have at least one lhs variable that was used.
if lhs != nil {
var used bool
for _, v := range lhsVars {
if v.used {
used = true
}
v.used = true // avoid usage error when checking entire function
}
if !used {
check.softErrorf(lhs, UnusedVar, "%s declared and not used", lhs.Name)
}
}
case *ast.SelectStmt:
inner |= breakOk
check.multipleDefaults(s.Body.List)
for _, s := range s.Body.List {
clause, _ := s.(*ast.CommClause)
if clause == nil {
continue // error reported before
}
// clause.Comm must be a SendStmt, RecvStmt, or default case
valid := false
var rhs ast.Expr // rhs of RecvStmt, or nil
switch s := clause.Comm.(type) {
case nil, *ast.SendStmt:
valid = true
case *ast.AssignStmt:
if len(s.Rhs) == 1 {
rhs = s.Rhs[0]
}
case *ast.ExprStmt:
rhs = s.X
}
// if present, rhs must be a receive operation
if rhs != nil {
if x, _ := unparen(rhs).(*ast.UnaryExpr); x != nil && x.Op == token.ARROW {
valid = true
}
}
if !valid {
check.error(clause.Comm, InvalidSelectCase, "select case must be send or receive (possibly with assignment)")
continue
}
check.openScope(s, "case")
if clause.Comm != nil {
check.stmt(inner, clause.Comm)
}
check.stmtList(inner, clause.Body)
check.closeScope()
}
case *ast.ForStmt:
inner |= breakOk | continueOk
check.openScope(s, "for")
defer check.closeScope()
check.simpleStmt(s.Init)
if s.Cond != nil {
var x operand
check.expr(&x, s.Cond)
if x.mode != invalid && !allBoolean(x.typ) {
check.error(s.Cond, InvalidCond, "non-boolean condition in for statement")
}
}
check.simpleStmt(s.Post)
// spec: "The init statement may be a short variable
// declaration, but the post statement must not."
if s, _ := s.Post.(*ast.AssignStmt); s != nil && s.Tok == token.DEFINE {
check.softErrorf(s, InvalidPostDecl, "cannot declare in post statement")
// Don't call useLHS here because we want to use the lhs in
// this erroneous statement so that we don't get errors about
// these lhs variables being declared and not used.
check.use(s.Lhs...) // avoid follow-up errors
}
check.stmt(inner, s.Body)
case *ast.RangeStmt:
inner |= breakOk | continueOk
// check expression to iterate over
var x operand
check.expr(&x, s.X)
// determine key/value types
var key, val Type
if x.mode != invalid {
// Ranging over a type parameter is permitted if it has a core type.
var cause string
u := coreType(x.typ)
switch t := u.(type) {
case nil:
cause = check.sprintf("%s has no core type", x.typ)
case *Chan:
if s.Value != nil {
check.softErrorf(s.Value, InvalidIterVar, "range over %s permits only one iteration variable", &x)
// ok to continue
}
if t.dir == SendOnly {
cause = "receive from send-only channel"
}
}
key, val = rangeKeyVal(u)
if key == nil || cause != "" {
if cause == "" {
check.softErrorf(&x, InvalidRangeExpr, "cannot range over %s", &x)
} else {
check.softErrorf(&x, InvalidRangeExpr, "cannot range over %s (%s)", &x, cause)
}
// ok to continue
}
}
// Open the for-statement block scope now, after the range clause.
// Iteration variables declared with := need to go in this scope (was go.dev/issue/51437).
check.openScope(s, "range")
defer check.closeScope()
// check assignment to/declaration of iteration variables
// (irregular assignment, cannot easily map to existing assignment checks)
// lhs expressions and initialization value (rhs) types
lhs := [2]ast.Expr{s.Key, s.Value}
rhs := [2]Type{key, val} // key, val may be nil
if s.Tok == token.DEFINE {
// short variable declaration
var vars []*Var
for i, lhs := range lhs {
if lhs == nil {
continue
}
// determine lhs variable
var obj *Var
if ident, _ := lhs.(*ast.Ident); ident != nil {
// declare new variable
name := ident.Name
obj = NewVar(ident.Pos(), check.pkg, name, nil)
check.recordDef(ident, obj)
// _ variables don't count as new variables
if name != "_" {
vars = append(vars, obj)
}
} else {
check.errorf(lhs, InvalidSyntaxTree, "cannot declare %s", lhs)
obj = NewVar(lhs.Pos(), check.pkg, "_", nil) // dummy variable
}
// initialize lhs variable
if typ := rhs[i]; typ != nil {
x.mode = value
x.expr = lhs // we don't have a better rhs expression to use here
x.typ = typ
check.initVar(obj, &x, "range clause")
} else {
obj.typ = Typ[Invalid]
obj.used = true // don't complain about unused variable
}
}
// declare variables
if len(vars) > 0 {
scopePos := s.Body.Pos()
for _, obj := range vars {
check.declare(check.scope, nil /* recordDef already called */, obj, scopePos)
}
} else {
check.error(inNode(s, s.TokPos), NoNewVar, "no new variables on left side of :=")
}
} else {
// ordinary assignment
for i, lhs := range lhs {
if lhs == nil {
continue
}
if typ := rhs[i]; typ != nil {
x.mode = value
x.expr = lhs // we don't have a better rhs expression to use here
x.typ = typ
check.assignVar(lhs, &x)
}
}
}
check.stmt(inner, s.Body)
default:
check.error(s, InvalidSyntaxTree, "invalid statement")
}
}
// rangeKeyVal returns the key and value type produced by a range clause
// over an expression of type typ. If the range clause is not permitted
// the results are nil.
func rangeKeyVal(typ Type) (key, val Type) {
switch typ := arrayPtrDeref(typ).(type) {
case *Basic:
if isString(typ) {
return Typ[Int], universeRune // use 'rune' name
}
case *Array:
return Typ[Int], typ.elem
case *Slice:
return Typ[Int], typ.elem
case *Map:
return typ.key, typ.elem
case *Chan:
return typ.elem, Typ[Invalid]
}
return
}
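// In summary (derived from the cases above), rangeKeyVal maps a range
// operand's type as follows:
//
//	string           -> (int, rune)
//	[N]T, *[N]T      -> (int, T)
//	[]T              -> (int, T)
//	map[K]V          -> (K, V)
//	chan T, <-chan T -> (T, Typ[Invalid])
//
// Any other type yields (nil, nil), which the caller reports as an error.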
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package types
import (
"go/ast"
"go/token"
. "internal/types/errors"
"strconv"
)
// ----------------------------------------------------------------------------
// API
// A Struct represents a struct type.
type Struct struct {
fields []*Var // fields != nil indicates the struct is set up (possibly with len(fields) == 0)
tags []string // field tags; nil if there are no tags
}
// NewStruct returns a new struct with the given fields and corresponding field tags.
// If a field with index i has a tag, tags[i] must be that tag, but len(tags) may be
// only as long as required to hold the tag with the largest index i. Consequently,
// if no field has a tag, tags may be nil.
func NewStruct(fields []*Var, tags []string) *Struct {
var fset objset
for _, f := range fields {
if f.name != "_" && fset.insert(f) != nil {
panic("multiple fields with the same name")
}
}
if len(tags) > len(fields) {
panic("more tags than fields")
}
s := &Struct{fields: fields, tags: tags}
s.markComplete()
return s
}
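// Example (a minimal, hypothetical use of this package's exported API,
// shown unqualified; the package path and field names are illustrative):
//
//	pkg := NewPackage("example.com/p", "p")
//	name := NewField(token.NoPos, pkg, "Name", Typ[String], false)
//	age := NewField(token.NoPos, pkg, "Age", Typ[Int], false)
//	s := NewStruct([]*Var{name, age}, []string{`json:"name"`})
//	// s.NumFields() == 2
//	// s.Tag(0) == `json:"name"`, s.Tag(1) == "" (tags may be shorter than fields)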
// NumFields returns the number of fields in the struct (including blank and embedded fields).
func (s *Struct) NumFields() int { return len(s.fields) }
// Field returns the i'th field for 0 <= i < NumFields().
func (s *Struct) Field(i int) *Var { return s.fields[i] }
// Tag returns the i'th field tag for 0 <= i < NumFields().
func (s *Struct) Tag(i int) string {
if i < len(s.tags) {
return s.tags[i]
}
return ""
}
func (t *Struct) Underlying() Type { return t }
func (t *Struct) String() string { return TypeString(t, nil) }
// ----------------------------------------------------------------------------
// Implementation
func (s *Struct) markComplete() {
if s.fields == nil {
s.fields = make([]*Var, 0)
}
}
func (check *Checker) structType(styp *Struct, e *ast.StructType) {
list := e.Fields
if list == nil {
styp.markComplete()
return
}
// struct fields and tags
var fields []*Var
var tags []string
// for double-declaration checks
var fset objset
// current field typ and tag
var typ Type
var tag string
add := func(ident *ast.Ident, embedded bool, pos token.Pos) {
if tag != "" && tags == nil {
tags = make([]string, len(fields))
}
if tags != nil {
tags = append(tags, tag)
}
name := ident.Name
fld := NewField(pos, check.pkg, name, typ, embedded)
// spec: "Within a struct, non-blank field names must be unique."
if name == "_" || check.declareInSet(&fset, pos, fld) {
fields = append(fields, fld)
check.recordDef(ident, fld)
}
}
// addInvalid adds an embedded field of invalid type to the struct for
// fields with errors; this keeps the number of struct fields in sync
// with the source as long as the fields are _ or have different names
// (go.dev/issue/25627).
addInvalid := func(ident *ast.Ident, pos token.Pos) {
typ = Typ[Invalid]
tag = ""
add(ident, true, pos)
}
for _, f := range list.List {
typ = check.varType(f.Type)
tag = check.tag(f.Tag)
if len(f.Names) > 0 {
// named fields
for _, name := range f.Names {
add(name, false, name.Pos())
}
} else {
// embedded field
// spec: "An embedded type must be specified as a type name T or as a
// pointer to a non-interface type name *T, and T itself may not be a
// pointer type."
pos := f.Type.Pos()
name := embeddedFieldIdent(f.Type)
if name == nil {
check.errorf(f.Type, InvalidSyntaxTree, "embedded field type %s has no name", f.Type)
name = ast.NewIdent("_")
name.NamePos = pos
addInvalid(name, pos)
continue
}
add(name, true, pos)
// Because we have a name, typ must be of the form T or *T, where T is the name
// of a (named or alias) type, and t (= deref(typ)) must be the type of T.
// We must delay this check to the end because we don't want to instantiate
// (via under(t)) a possibly incomplete type.
// for use in the closure below
embeddedTyp := typ
embeddedPos := f.Type
check.later(func() {
t, isPtr := deref(embeddedTyp)
switch u := under(t).(type) {
case *Basic:
if t == Typ[Invalid] {
// error was reported before
return
}
// unsafe.Pointer is treated like a regular pointer
if u.kind == UnsafePointer {
check.error(embeddedPos, InvalidPtrEmbed, "embedded field type cannot be unsafe.Pointer")
}
case *Pointer:
check.error(embeddedPos, InvalidPtrEmbed, "embedded field type cannot be a pointer")
case *Interface:
if isTypeParam(t) {
// The error code here is inconsistent with other error codes for
// invalid embedding, because this restriction may be relaxed in the
// future, and so it did not warrant a new error code.
check.error(embeddedPos, MisplacedTypeParam, "embedded field type cannot be a (pointer to a) type parameter")
break
}
if isPtr {
check.error(embeddedPos, InvalidPtrEmbed, "embedded field type cannot be a pointer to an interface")
}
}
}).describef(embeddedPos, "check embedded type %s", embeddedTyp)
}
}
styp.fields = fields
styp.tags = tags
styp.markComplete()
}
func embeddedFieldIdent(e ast.Expr) *ast.Ident {
switch e := e.(type) {
case *ast.Ident:
return e
case *ast.StarExpr:
// *T is valid, but **T is not
if _, ok := e.X.(*ast.StarExpr); !ok {
return embeddedFieldIdent(e.X)
}
case *ast.SelectorExpr:
return e.Sel
case *ast.IndexExpr:
return embeddedFieldIdent(e.X)
case *ast.IndexListExpr:
return embeddedFieldIdent(e.X)
}
return nil // invalid embedded field
}
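// In summary, embeddedFieldIdent resolves embedded field expressions as follows:
//
//	T             -> T
//	*T            -> T (but **T -> nil)
//	p.T           -> T (the selector's Sel)
//	T[X], T[X, Y] -> T (generic instantiations)
//	anything else -> nil (invalid embedded field)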
func (check *Checker) declareInSet(oset *objset, pos token.Pos, obj Object) bool {
if alt := oset.insert(obj); alt != nil {
check.errorf(atPos(pos), DuplicateDecl, "%s redeclared", obj.Name())
check.reportAltDecl(alt)
return false
}
return true
}
func (check *Checker) tag(t *ast.BasicLit) string {
if t != nil {
if t.Kind == token.STRING {
if val, err := strconv.Unquote(t.Value); err == nil {
return val
}
}
check.errorf(t, InvalidSyntaxTree, "incorrect tag syntax: %q", t.Value)
}
return ""
}
// Code generated by "go test -run=Generate -write=all"; DO NOT EDIT.
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements type parameter substitution.
package types
import (
"go/token"
)
type substMap map[*TypeParam]Type
// makeSubstMap creates a new substitution map mapping tpars[i] to targs[i].
// If targs[i] is nil, tpars[i] is not substituted.
func makeSubstMap(tpars []*TypeParam, targs []Type) substMap {
assert(len(tpars) == len(targs))
proj := make(substMap, len(tpars))
for i, tpar := range tpars {
proj[tpar] = targs[i]
}
return proj
}
// makeRenameMap is like makeSubstMap, but creates a map used to rename type
// parameters in from with the type parameters in to.
func makeRenameMap(from, to []*TypeParam) substMap {
assert(len(from) == len(to))
proj := make(substMap, len(from))
for i, tpar := range from {
proj[tpar] = to[i]
}
return proj
}
func (m substMap) empty() bool {
return len(m) == 0
}
func (m substMap) lookup(tpar *TypeParam) Type {
if t := m[tpar]; t != nil {
return t
}
return tpar
}
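// For instance (an illustrative sketch; p and q stand for previously
// constructed *TypeParam values):
//
//	smap := makeSubstMap([]*TypeParam{p}, []Type{Typ[Int]})
//	smap.lookup(p) // Typ[Int]
//	smap.lookup(q) // q itself: parameters without a mapping are returned unchanged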
// subst returns the type typ with its type parameters tpars replaced by the
// corresponding type arguments targs, recursively. subst doesn't modify the
// incoming type. If a substitution took place, the result type is different
// from the incoming type.
//
// If expanding is non-nil, it is the instance type currently being expanded.
// One of expanding or ctxt must be non-nil.
func (check *Checker) subst(pos token.Pos, typ Type, smap substMap, expanding *Named, ctxt *Context) Type {
assert(expanding != nil || ctxt != nil)
if smap.empty() {
return typ
}
// common cases
switch t := typ.(type) {
case *Basic:
return typ // nothing to do
case *TypeParam:
return smap.lookup(t)
}
// general case
subst := subster{
pos: pos,
smap: smap,
check: check,
expanding: expanding,
ctxt: ctxt,
}
return subst.typ(typ)
}
type subster struct {
pos token.Pos
smap substMap
check *Checker // nil if called via Instantiate
expanding *Named // if non-nil, the instance that is being expanded
ctxt *Context
}
func (subst *subster) typ(typ Type) Type {
switch t := typ.(type) {
case nil:
// Call typOrNil if it's possible that typ is nil.
panic("nil typ")
case *Basic:
// nothing to do
case *Array:
elem := subst.typOrNil(t.elem)
if elem != t.elem {
return &Array{len: t.len, elem: elem}
}
case *Slice:
elem := subst.typOrNil(t.elem)
if elem != t.elem {
return &Slice{elem: elem}
}
case *Struct:
if fields, copied := subst.varList(t.fields); copied {
s := &Struct{fields: fields, tags: t.tags}
s.markComplete()
return s
}
case *Pointer:
base := subst.typ(t.base)
if base != t.base {
return &Pointer{base: base}
}
case *Tuple:
return subst.tuple(t)
case *Signature:
// Preserve the receiver: it is handled during *Interface and *Named type
// substitution.
//
// Naively doing the substitution here can lead to an infinite recursion in
// the case where the receiver is an interface. For example, consider the
// following declaration:
//
// type T[A any] struct { f interface{ m() } }
//
// In this case, the type of f is an interface that is itself the receiver
// type of all of its methods. Because we have no type name to break
// cycles, substituting in the recv results in an infinite loop of
// recv->interface->recv->interface->...
recv := t.recv
params := subst.tuple(t.params)
results := subst.tuple(t.results)
if params != t.params || results != t.results {
return &Signature{
rparams: t.rparams,
// TODO(gri) why can't we nil out tparams here, rather than in instantiate?
tparams: t.tparams,
// instantiated signatures have a nil scope
recv: recv,
params: params,
results: results,
variadic: t.variadic,
}
}
case *Union:
terms, copied := subst.termlist(t.terms)
if copied {
// term list substitution may introduce duplicate terms (unlikely but possible).
// This is ok; lazy type set computation will determine the actual type set
// in normal form.
return &Union{terms}
}
case *Interface:
methods, mcopied := subst.funcList(t.methods)
embeddeds, ecopied := subst.typeList(t.embeddeds)
if mcopied || ecopied {
iface := subst.check.newInterface()
iface.embeddeds = embeddeds
iface.implicit = t.implicit
iface.complete = t.complete
// If we've changed the interface type, we may need to replace its
// receiver if the receiver type is the original interface. Receivers of
// *Named type are replaced during named type expansion.
//
// Notably, it's possible to reach here and not create a new *Interface,
// even though the receiver type may be parameterized. For example:
//
// type T[P any] interface{ m() }
//
// In this case the interface will not be substituted here, because its
// method signatures do not depend on the type parameter P, but we still
// need to create new interface methods to hold the instantiated
// receiver. This is handled by Named.expandUnderlying.
iface.methods, _ = replaceRecvType(methods, t, iface)
return iface
}
case *Map:
key := subst.typ(t.key)
elem := subst.typ(t.elem)
if key != t.key || elem != t.elem {
return &Map{key: key, elem: elem}
}
case *Chan:
elem := subst.typ(t.elem)
if elem != t.elem {
return &Chan{dir: t.dir, elem: elem}
}
case *Named:
// dump is for debugging
dump := func(string, ...interface{}) {}
if subst.check != nil && subst.check.conf._Trace {
subst.check.indent++
defer func() {
subst.check.indent--
}()
dump = func(format string, args ...interface{}) {
subst.check.trace(subst.pos, format, args...)
}
}
// subst is called during expansion, so in this function we need to be
// careful not to call any methods that would cause t to be expanded: doing
// so would result in deadlock.
//
// So we call t.Origin().TypeParams() rather than t.TypeParams().
orig := t.Origin()
n := orig.TypeParams().Len()
if n == 0 {
dump(">>> %s is not parameterized", t)
return t // type is not parameterized
}
var newTArgs []Type
if t.TypeArgs().Len() != n {
return Typ[Invalid] // error reported elsewhere
}
// already instantiated
dump(">>> %s already instantiated", t)
// For each (existing) type argument targ, determine if it needs
// to be substituted; i.e., if it is or contains a type parameter
// that has a type argument for it.
for i, targ := range t.TypeArgs().list() {
dump(">>> %d targ = %s", i, targ)
new_targ := subst.typ(targ)
if new_targ != targ {
dump(">>> substituted %d targ %s => %s", i, targ, new_targ)
if newTArgs == nil {
newTArgs = make([]Type, n)
copy(newTArgs, t.TypeArgs().list())
}
newTArgs[i] = new_targ
}
}
if newTArgs == nil {
dump(">>> nothing to substitute in %s", t)
return t // nothing to substitute
}
// Create a new instance and populate the context to avoid endless
// recursion. The position used here is irrelevant because validation only
// occurs on t (we don't call validType on named), but we use subst.pos to
// help with debugging.
return subst.check.instance(subst.pos, orig, newTArgs, subst.expanding, subst.ctxt)
case *TypeParam:
return subst.smap.lookup(t)
default:
unreachable()
}
return typ
}
// typOrNil is like typ but if the argument is nil it is replaced with Typ[Invalid].
// A nil type may appear in pathological cases such as type T[P any] []func(_ T([]_))
// where an array/slice element is accessed before it is set up.
func (subst *subster) typOrNil(typ Type) Type {
if typ == nil {
return Typ[Invalid]
}
return subst.typ(typ)
}
func (subst *subster) var_(v *Var) *Var {
if v != nil {
if typ := subst.typ(v.typ); typ != v.typ {
return substVar(v, typ)
}
}
return v
}
func substVar(v *Var, typ Type) *Var {
copy := *v
copy.typ = typ
copy.origin = v.Origin()
return &copy
}
func (subst *subster) tuple(t *Tuple) *Tuple {
if t != nil {
if vars, copied := subst.varList(t.vars); copied {
return &Tuple{vars: vars}
}
}
return t
}
func (subst *subster) varList(in []*Var) (out []*Var, copied bool) {
out = in
for i, v := range in {
if w := subst.var_(v); w != v {
if !copied {
// first variable that got substituted => allocate new out slice
// and copy all variables
new := make([]*Var, len(in))
copy(new, out)
out = new
copied = true
}
out[i] = w
}
}
return
}
func (subst *subster) func_(f *Func) *Func {
if f != nil {
if typ := subst.typ(f.typ); typ != f.typ {
return substFunc(f, typ)
}
}
return f
}
func substFunc(f *Func, typ Type) *Func {
copy := *f
copy.typ = typ
copy.origin = f.Origin()
return &copy
}
func (subst *subster) funcList(in []*Func) (out []*Func, copied bool) {
out = in
for i, f := range in {
if g := subst.func_(f); g != f {
if !copied {
// first function that got substituted => allocate new out slice
// and copy all functions
new := make([]*Func, len(in))
copy(new, out)
out = new
copied = true
}
out[i] = g
}
}
return
}
func (subst *subster) typeList(in []Type) (out []Type, copied bool) {
out = in
for i, t := range in {
if u := subst.typ(t); u != t {
if !copied {
// first type that got substituted => allocate new out slice
// and copy all types
new := make([]Type, len(in))
copy(new, out)
out = new
copied = true
}
out[i] = u
}
}
return
}
func (subst *subster) termlist(in []*Term) (out []*Term, copied bool) {
out = in
for i, t := range in {
if u := subst.typ(t.typ); u != t.typ {
if !copied {
// first term that got substituted => allocate new out slice
// and copy all terms
new := make([]*Term, len(in))
copy(new, out)
out = new
copied = true
}
out[i] = NewTerm(t.tilde, u)
}
}
return
}
// replaceRecvType updates any function receivers that have type old to have
// type new. It does not modify the input slice; if modifications are required,
// the input slice and any affected signatures will be copied before mutating.
//
// The resulting out slice contains the updated functions, and copied reports
// if anything was modified.
func replaceRecvType(in []*Func, old, new Type) (out []*Func, copied bool) {
out = in
for i, method := range in {
sig := method.Type().(*Signature)
if sig.recv != nil && sig.recv.Type() == old {
if !copied {
// Allocate a new methods slice before mutating for the first time.
// This is defensive, as we may share methods across instantiations of
// a given interface type if they do not get substituted.
out = make([]*Func, len(in))
copy(out, in)
copied = true
}
newsig := *sig
newsig.recv = substVar(sig.recv, new)
out[i] = substFunc(method, &newsig)
}
}
return
}
// Code generated by "go test -run=Generate -write=all"; DO NOT EDIT.
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package types
import "strings"
// A termlist represents the type set represented by the union
// t1 ∪ t2 ∪ ... ∪ tn of the type sets of the terms t1 to tn.
// A termlist is in normal form if all terms are disjoint.
// termlist operations don't require the operands to be in
// normal form.
type termlist []*term
// allTermlist represents the set of all types.
// It is in normal form.
var allTermlist = termlist{new(term)}
// termSep is the separator used between individual terms.
const termSep = " | "
// String prints the termlist exactly (without normalization).
func (xl termlist) String() string {
if len(xl) == 0 {
return "∅"
}
var buf strings.Builder
for i, x := range xl {
if i > 0 {
buf.WriteString(termSep)
}
buf.WriteString(x.String())
}
return buf.String()
}
// isEmpty reports whether the termlist xl represents the empty set of types.
func (xl termlist) isEmpty() bool {
// If there's a non-nil term, the entire list is not empty.
// If the termlist is in normal form, this requires at most
// one iteration.
for _, x := range xl {
if x != nil {
return false
}
}
return true
}
// isAll reports whether the termlist xl represents the set of all types.
func (xl termlist) isAll() bool {
// If there's a 𝓤 term, the entire list is 𝓤.
// If the termlist is in normal form, this requires at most
// one iteration.
for _, x := range xl {
if x != nil && x.typ == nil {
return true
}
}
return false
}
// norm returns the normal form of xl.
func (xl termlist) norm() termlist {
// Quadratic algorithm, but good enough for now.
// TODO(gri) fix asymptotic performance
used := make([]bool, len(xl))
var rl termlist
for i, xi := range xl {
if xi == nil || used[i] {
continue
}
for j := i + 1; j < len(xl); j++ {
xj := xl[j]
if xj == nil || used[j] {
continue
}
if u1, u2 := xi.union(xj); u2 == nil {
// If we encounter a 𝓤 term, the entire list is 𝓤.
// Exit early.
// (Note that this is not just an optimization;
// if we continue, we may end up with a 𝓤 term
// and other terms and the result would not be
// in normal form.)
if u1.typ == nil {
return allTermlist
}
xi = u1
used[j] = true // xj is now unioned into xi - ignore it in future iterations
}
}
rl = append(rl, xi)
}
return rl
}
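// For example, writing terms in the notation of this file:
//
//	{int, ~int}.norm()   == {~int}        // int ⊆ ~int, so the two terms merge
//	{int, string}.norm() == {int, string} // disjoint terms are kept as-is
//	{~int, 𝓤}.norm()     == 𝓤             // a universe term absorbs the whole list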
// union returns the union xl ∪ yl.
func (xl termlist) union(yl termlist) termlist {
return append(xl, yl...).norm()
}
// intersect returns the intersection xl ∩ yl.
func (xl termlist) intersect(yl termlist) termlist {
if xl.isEmpty() || yl.isEmpty() {
return nil
}
// Quadratic algorithm, but good enough for now.
// TODO(gri) fix asymptotic performance
var rl termlist
for _, x := range xl {
for _, y := range yl {
if r := x.intersect(y); r != nil {
rl = append(rl, r)
}
}
}
return rl.norm()
}
// equal reports whether xl and yl represent the same type set.
func (xl termlist) equal(yl termlist) bool {
// TODO(gri) this should be more efficient
return xl.subsetOf(yl) && yl.subsetOf(xl)
}
// includes reports whether t ∈ xl.
func (xl termlist) includes(t Type) bool {
for _, x := range xl {
if x.includes(t) {
return true
}
}
return false
}
// supersetOf reports whether y ⊆ xl.
func (xl termlist) supersetOf(y *term) bool {
for _, x := range xl {
if y.subsetOf(x) {
return true
}
}
return false
}
// subsetOf reports whether xl ⊆ yl.
func (xl termlist) subsetOf(yl termlist) bool {
if yl.isEmpty() {
return xl.isEmpty()
}
// each term x of xl must be a subset of yl
for _, x := range xl {
if !yl.supersetOf(x) {
return false // x is not a subset of yl
}
}
return true
}
// Code generated by "go test -run=Generate -write=all"; DO NOT EDIT.
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package types
// A Tuple represents an ordered list of variables; a nil *Tuple is a valid (empty) tuple.
// Tuples are used as components of signatures and to represent the type of multiple
// assignments; they are not first class types of Go.
type Tuple struct {
vars []*Var
}
// NewTuple returns a new tuple for the given variables.
func NewTuple(x ...*Var) *Tuple {
if len(x) > 0 {
return &Tuple{vars: x}
}
return nil
}
// Len returns the number of variables of tuple t.
func (t *Tuple) Len() int {
if t != nil {
return len(t.vars)
}
return 0
}
// At returns the i'th variable of tuple t.
func (t *Tuple) At(i int) *Var { return t.vars[i] }
func (t *Tuple) Underlying() Type { return t }
func (t *Tuple) String() string { return TypeString(t, nil) }
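// Example (hypothetical client code):
//
//	x := NewVar(token.NoPos, nil, "x", Typ[Int])
//	y := NewVar(token.NoPos, nil, "y", Typ[String])
//	t := NewTuple(x, y)
//	// t.Len() == 2, t.At(1).Type() == Typ[String]
//	// NewTuple() with no variables returns nil, and a nil *Tuple has Len() == 0.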
// Code generated by "go test -run=Generate -write=all"; DO NOT EDIT.
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package types
// TypeParamList holds a list of type parameters.
type TypeParamList struct{ tparams []*TypeParam }
// Len returns the number of type parameters in the list.
// It is safe to call on a nil receiver.
func (l *TypeParamList) Len() int { return len(l.list()) }
// At returns the i'th type parameter in the list.
func (l *TypeParamList) At(i int) *TypeParam { return l.tparams[i] }
// list is for internal use where we expect a []*TypeParam.
// TODO(rfindley): list should probably be eliminated: we can pass around a
// TypeParamList instead.
func (l *TypeParamList) list() []*TypeParam {
if l == nil {
return nil
}
return l.tparams
}
// TypeList holds a list of types.
type TypeList struct{ types []Type }
// newTypeList returns a new TypeList with the types in list.
func newTypeList(list []Type) *TypeList {
if len(list) == 0 {
return nil
}
return &TypeList{list}
}
// Len returns the number of types in the list.
// It is safe to call on a nil receiver.
func (l *TypeList) Len() int { return len(l.list()) }
// At returns the i'th type in the list.
func (l *TypeList) At(i int) Type { return l.types[i] }
// list is for internal use where we expect a []Type.
// TODO(rfindley): list should probably be eliminated: we can pass around a
// TypeList instead.
func (l *TypeList) list() []Type {
if l == nil {
return nil
}
return l.types
}
// ----------------------------------------------------------------------------
// Implementation
func bindTParams(list []*TypeParam) *TypeParamList {
if len(list) == 0 {
return nil
}
for i, typ := range list {
if typ.index >= 0 {
panic("type parameter bound more than once")
}
typ.index = i
}
return &TypeParamList{tparams: list}
}
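// For example, binding two freshly created type parameters (each starting
// with index -1; p and q are illustrative) assigns their indices in order:
//
//	l := bindTParams([]*TypeParam{p, q}) // p.index == 0, q.index == 1
//	bindTParams(nil)                     // nil: an empty list has no TypeParamList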
// Code generated by "go test -run=Generate -write=all"; DO NOT EDIT.
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package types
import "sync/atomic"
// Note: This is a uint32 rather than a uint64 because the
// respective 64 bit atomic instructions are not available
// on all platforms.
var lastID uint32
// nextID returns a value increasing monotonically by 1 with
// each call, starting with 1. It may be called concurrently.
func nextID() uint64 { return uint64(atomic.AddUint32(&lastID, 1)) }
// A TypeParam represents a type parameter type.
type TypeParam struct {
check *Checker // for lazy type bound completion
id uint64 // unique id, for debugging only
obj *TypeName // corresponding type name
index int // type parameter index in source order, starting at 0
bound Type // any type, but underlying is eventually *Interface for correct programs (see TypeParam.iface)
}
// NewTypeParam returns a new TypeParam. Type parameters may be set on a Named
// or Signature type by calling SetTypeParams. Setting a type parameter on more
// than one type will result in a panic.
//
// The constraint argument can be nil, and set later via SetConstraint. If the
// constraint is non-nil, it must be fully defined.
func NewTypeParam(obj *TypeName, constraint Type) *TypeParam {
return (*Checker)(nil).newTypeParam(obj, constraint)
}
// check may be nil
func (check *Checker) newTypeParam(obj *TypeName, constraint Type) *TypeParam {
// Always increment lastID, even if it is not used.
id := nextID()
if check != nil {
check.nextID++
id = check.nextID
}
typ := &TypeParam{check: check, id: id, obj: obj, index: -1, bound: constraint}
if obj.typ == nil {
obj.typ = typ
}
// iface may mutate typ.bound, so we must ensure that iface() is called
// at least once before the resulting TypeParam escapes.
if check != nil {
check.needsCleanup(typ)
} else if constraint != nil {
typ.iface()
}
return typ
}
// Obj returns the type name for the type parameter t.
func (t *TypeParam) Obj() *TypeName { return t.obj }
// Index returns the index of the type param within its param list, or -1 if
// the type parameter has not yet been bound to a type.
func (t *TypeParam) Index() int {
return t.index
}
// Constraint returns the type constraint specified for t.
func (t *TypeParam) Constraint() Type {
return t.bound
}
// SetConstraint sets the type constraint for t.
//
// It must be called by users of NewTypeParam after the bound's underlying is
// fully defined, and before using the type parameter in any way other than to
// form other types. Once SetConstraint returns the receiver, t is safe for
// concurrent use.
func (t *TypeParam) SetConstraint(bound Type) {
if bound == nil {
panic("nil constraint")
}
t.bound = bound
// iface may mutate t.bound (if bound is not an interface), so ensure that
// this is done before returning.
t.iface()
}
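// Example (hypothetical client code observing the required ordering):
//
//	tn := NewTypeName(token.NoPos, nil, "P", nil)
//	p := NewTypeParam(tn, nil)                     // constraint supplied below
//	p.SetConstraint(Universe.Lookup("any").Type()) // p is now safe for concurrent use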
func (t *TypeParam) Underlying() Type {
return t.iface()
}
func (t *TypeParam) String() string { return TypeString(t, nil) }
// ----------------------------------------------------------------------------
// Implementation
func (t *TypeParam) cleanup() {
t.iface()
t.check = nil
}
// iface returns the constraint interface of t.
func (t *TypeParam) iface() *Interface {
bound := t.bound
// determine constraint interface
var ityp *Interface
switch u := under(bound).(type) {
case *Basic:
if u == Typ[Invalid] {
// error is reported elsewhere
return &emptyInterface
}
case *Interface:
if isTypeParam(bound) {
// error is reported in Checker.collectTypeParams
return &emptyInterface
}
ityp = u
}
// If we don't have an interface, wrap constraint into an implicit interface.
if ityp == nil {
ityp = NewInterfaceType(nil, []Type{bound})
ityp.implicit = true
t.bound = ityp // update t.bound for next time (optimization)
}
// compute type set if necessary
if ityp.tset == nil {
// pos is used for tracing output; start with the type parameter position.
pos := t.obj.pos
// use the (original or possibly instantiated) type bound position if we have one
if n, _ := bound.(*Named); n != nil {
pos = n.obj.pos
}
computeInterfaceTypeSet(t.check, pos, ityp)
}
return ityp
}
// is calls f with the specific type terms of t's constraint and reports whether
// all calls to f returned true. If there are no specific terms, is
// returns the result of f(nil).
func (t *TypeParam) is(f func(*term) bool) bool {
return t.iface().typeSet().is(f)
}
// underIs calls f with the underlying types of the specific type terms
// of t's constraint and reports whether all calls to f returned true.
// If there are no specific terms, underIs returns the result of f(nil).
func (t *TypeParam) underIs(f func(Type) bool) bool {
return t.iface().typeSet().underIs(f)
}
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package types
import (
"fmt"
"go/token"
. "internal/types/errors"
"sort"
"strings"
)
// ----------------------------------------------------------------------------
// API
// A _TypeSet represents the type set of an interface.
// Because of existing language restrictions, methods can be "factored out"
// from the terms. The actual type set is the intersection of the type set
// implied by the methods and the type set described by the terms and the
// comparable bit. To test whether a type is included in a type set
// ("implements" relation), the type must implement all methods _and_ be
// an element of the type set described by the terms and the comparable bit.
// If the term list describes the set of all types and comparable is true,
// only comparable types are meant; in all other cases comparable is false.
type _TypeSet struct {
methods []*Func // all methods of the interface; sorted by unique ID
terms termlist // type terms of the type set
comparable bool // invariant: !comparable || terms.isAll()
}
// IsEmpty reports whether type set s is the empty set.
func (s *_TypeSet) IsEmpty() bool { return s.terms.isEmpty() }
// IsAll reports whether type set s is the set of all types (corresponding to the empty interface).
func (s *_TypeSet) IsAll() bool { return s.IsMethodSet() && len(s.methods) == 0 }
// IsMethodSet reports whether the interface t is fully described by its method set.
func (s *_TypeSet) IsMethodSet() bool { return !s.comparable && s.terms.isAll() }
// IsComparable reports whether each type in the set is comparable.
func (s *_TypeSet) IsComparable(seen map[Type]bool) bool {
if s.terms.isAll() {
return s.comparable
}
return s.is(func(t *term) bool {
return t != nil && comparable(t.typ, false, seen, nil)
})
}
// NumMethods returns the number of methods available.
func (s *_TypeSet) NumMethods() int { return len(s.methods) }
// Method returns the i'th method of type set s for 0 <= i < s.NumMethods().
// The methods are ordered by their unique ID.
func (s *_TypeSet) Method(i int) *Func { return s.methods[i] }
// LookupMethod returns the index of, and the method with, a matching package and name, or (-1, nil).
func (s *_TypeSet) LookupMethod(pkg *Package, name string, foldCase bool) (int, *Func) {
return lookupMethod(s.methods, pkg, name, foldCase)
}
func (s *_TypeSet) String() string {
switch {
case s.IsEmpty():
return "∅"
case s.IsAll():
return "𝓤"
}
hasMethods := len(s.methods) > 0
hasTerms := s.hasTerms()
var buf strings.Builder
buf.WriteByte('{')
if s.comparable {
buf.WriteString("comparable")
if hasMethods || hasTerms {
buf.WriteString("; ")
}
}
for i, m := range s.methods {
if i > 0 {
buf.WriteString("; ")
}
buf.WriteString(m.String())
}
if hasMethods && hasTerms {
buf.WriteString("; ")
}
if hasTerms {
buf.WriteString(s.terms.String())
}
buf.WriteString("}")
return buf.String()
}
// ----------------------------------------------------------------------------
// Implementation
// hasTerms reports whether the type set has specific type terms.
func (s *_TypeSet) hasTerms() bool { return !s.terms.isEmpty() && !s.terms.isAll() }
// subsetOf reports whether s1 ⊆ s2.
func (s1 *_TypeSet) subsetOf(s2 *_TypeSet) bool { return s1.terms.subsetOf(s2.terms) }
// TODO(gri) TypeSet.is and TypeSet.underIs should probably also go into termlist.go
// is calls f with the specific type terms of s and reports whether
// all calls to f returned true. If there are no specific terms, is
// returns the result of f(nil).
func (s *_TypeSet) is(f func(*term) bool) bool {
if !s.hasTerms() {
return f(nil)
}
for _, t := range s.terms {
assert(t.typ != nil)
if !f(t) {
return false
}
}
return true
}
// underIs calls f with the underlying types of the specific type terms
// of s and reports whether all calls to f returned true. If there are
// no specific terms, underIs returns the result of f(nil).
func (s *_TypeSet) underIs(f func(Type) bool) bool {
if !s.hasTerms() {
return f(nil)
}
for _, t := range s.terms {
assert(t.typ != nil)
// x == under(x) for ~x terms
u := t.typ
if !t.tilde {
u = under(u)
}
if debug {
assert(Identical(u, under(u)))
}
if !f(u) {
return false
}
}
return true
}
// topTypeSet may be used as type set for the empty interface.
var topTypeSet = _TypeSet{terms: allTermlist}
// computeInterfaceTypeSet may be called with check == nil.
func computeInterfaceTypeSet(check *Checker, pos token.Pos, ityp *Interface) *_TypeSet {
if ityp.tset != nil {
return ityp.tset
}
// If the interface is not fully set up yet, the type set will
// not be complete, which may lead to errors when using the
// type set (e.g. missing method). Don't compute a partial type
// set (and don't store it!), so that we still compute the full
// type set eventually. Instead, return the top type set and
// let any follow-on errors play out.
//
// TODO(gri) Consider recording when this happens and reporting
// it as an error (but only if there were no other errors so as
// not to have unnecessary follow-on errors).
if !ityp.complete {
return &topTypeSet
}
if check != nil && check.conf._Trace {
// Types don't generally have position information.
// If we don't have a valid pos provided, try to use
// one close enough.
if !pos.IsValid() && len(ityp.methods) > 0 {
pos = ityp.methods[0].pos
}
check.trace(pos, "-- type set for %s", ityp)
check.indent++
defer func() {
check.indent--
check.trace(pos, "=> %s ", ityp.typeSet())
}()
}
// An infinitely expanding interface (due to a cycle) is detected
// elsewhere (Checker.validType), so here we simply assume we only
// have valid interfaces. Mark the interface as complete to avoid
// infinite recursion if the validType check occurs later for some
// reason.
ityp.tset = &_TypeSet{terms: allTermlist} // TODO(gri) is this sufficient?
var unionSets map[*Union]*_TypeSet
if check != nil {
if check.unionTypeSets == nil {
check.unionTypeSets = make(map[*Union]*_TypeSet)
}
unionSets = check.unionTypeSets
} else {
unionSets = make(map[*Union]*_TypeSet)
}
// Methods of embedded interfaces are collected unchanged; i.e., the identity
// of a method I.m's Func Object of an interface I is the same as that of
// the method m in an interface that embeds interface I. On the other hand,
// if a method is embedded via multiple overlapping embedded interfaces, we
// don't provide a guarantee which "original m" got chosen for the embedding
// interface. See also go.dev/issue/34421.
//
// If we don't care to provide this identity guarantee anymore, instead of
// reusing the original method in embeddings, we can clone the method's Func
// Object and give it the position of a corresponding embedded interface. Then
// we can get rid of the mpos map below and simply use the cloned method's
// position.
var todo []*Func
var seen objset
var allMethods []*Func
mpos := make(map[*Func]token.Pos) // method specification or method embedding position, for good error messages
addMethod := func(pos token.Pos, m *Func, explicit bool) {
switch other := seen.insert(m); {
case other == nil:
allMethods = append(allMethods, m)
mpos[m] = pos
case explicit:
if check == nil {
panic(fmt.Sprintf("%v: duplicate method %s", m.pos, m.name))
}
// check != nil
check.errorf(atPos(pos), DuplicateDecl, "duplicate method %s", m.name)
check.errorf(atPos(mpos[other.(*Func)]), DuplicateDecl, "\tother declaration of %s", m.name) // secondary error, \t indented
default:
// We have a duplicate method name in an embedded (not explicitly declared) method.
// Check method signatures after all types are computed (go.dev/issue/33656).
// If we're pre-go1.14 (overlapping embeddings are not permitted), report that
// error here as well (even though we could do it eagerly) because it's the same
// error message.
if check == nil {
// check method signatures after all locally embedded interfaces are computed
todo = append(todo, m, other.(*Func))
break
}
// check != nil
check.later(func() {
if !check.allowVersion(m.pkg, 1, 14) || !Identical(m.typ, other.Type()) {
check.errorf(atPos(pos), DuplicateDecl, "duplicate method %s", m.name)
check.errorf(atPos(mpos[other.(*Func)]), DuplicateDecl, "\tother declaration of %s", m.name) // secondary error, \t indented
}
}).describef(atPos(pos), "duplicate method check for %s", m.name)
}
}
for _, m := range ityp.methods {
addMethod(m.pos, m, true)
}
// collect embedded elements
allTerms := allTermlist
allComparable := false
for i, typ := range ityp.embeddeds {
// The embedding position is nil for imported interfaces
// and also for interface copies after substitution (but
// in that case we don't need to report errors again).
var pos token.Pos // embedding position
if ityp.embedPos != nil {
pos = (*ityp.embedPos)[i]
}
var comparable bool
var terms termlist
switch u := under(typ).(type) {
case *Interface:
// For now we don't permit type parameters as constraints.
assert(!isTypeParam(typ))
tset := computeInterfaceTypeSet(check, pos, u)
// If typ is local, an error was already reported where typ is specified/defined.
if check != nil && check.isImportedConstraint(typ) && !check.allowVersion(check.pkg, 1, 18) {
check.errorf(atPos(pos), UnsupportedFeature, "embedding constraint interface %s requires go1.18 or later", typ)
continue
}
comparable = tset.comparable
for _, m := range tset.methods {
addMethod(pos, m, false) // use embedding position pos rather than m.pos
}
terms = tset.terms
case *Union:
if check != nil && !check.allowVersion(check.pkg, 1, 18) {
check.errorf(atPos(pos), UnsupportedFeature, "embedding interface element %s requires go1.18 or later", u)
continue
}
tset := computeUnionTypeSet(check, unionSets, pos, u)
if tset == &invalidTypeSet {
continue // ignore invalid unions
}
assert(!tset.comparable)
assert(len(tset.methods) == 0)
terms = tset.terms
default:
if u == Typ[Invalid] {
continue
}
if check != nil && !check.allowVersion(check.pkg, 1, 18) {
check.errorf(atPos(pos), UnsupportedFeature, "embedding non-interface type %s requires go1.18 or later", typ)
continue
}
terms = termlist{{false, typ}}
}
// The type set of an interface is the intersection of the type sets of all its elements.
// Due to language restrictions, only embedded interfaces can add methods, they are handled
// separately. Here we only need to intersect the term lists and comparable bits.
allTerms, allComparable = intersectTermLists(allTerms, allComparable, terms, comparable)
}
ityp.embedPos = nil // not needed anymore (errors have been reported)
// process todo's (this only happens if check == nil)
for i := 0; i < len(todo); i += 2 {
m := todo[i]
other := todo[i+1]
if !Identical(m.typ, other.typ) {
panic(fmt.Sprintf("%v: duplicate method %s", m.pos, m.name))
}
}
ityp.tset.comparable = allComparable
if len(allMethods) != 0 {
sortMethods(allMethods)
ityp.tset.methods = allMethods
}
ityp.tset.terms = allTerms
return ityp.tset
}
// TODO(gri) The intersectTermLists function belongs to the termlist implementation.
// The comparable type set may also be best represented as a term (using
// a special type).
// intersectTermLists computes the intersection of two term lists and respective comparable bits.
// xcomp, ycomp are valid only if xterms.isAll() and yterms.isAll() respectively.
func intersectTermLists(xterms termlist, xcomp bool, yterms termlist, ycomp bool) (termlist, bool) {
terms := xterms.intersect(yterms)
// If one of xterms or yterms is marked as comparable,
// the result must only include comparable types.
comp := xcomp || ycomp
if comp && !terms.isAll() {
// only keep comparable terms
i := 0
for _, t := range terms {
assert(t.typ != nil)
if comparable(t.typ, false /* strictly comparable */, nil, nil) {
terms[i] = t
i++
}
}
terms = terms[:i]
if !terms.isAll() {
comp = false
}
}
assert(!comp || terms.isAll()) // comparable invariant
return terms, comp
}
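// For example (an illustrative sketch using this file's internal term
// notation), intersecting the set of all comparable types with {~int, string}:
//
//	intersectTermLists(allTermlist, true, termlist{{true, Typ[Int]}, {false, Typ[String]}}, false)
//	// Both terms are comparable and survive the filter, but the result is no
//	// longer the set of all types, so the comparable bit is dropped:
//	// => ({~int, string}, false)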
func sortMethods(list []*Func) {
sort.Sort(byUniqueMethodName(list))
}
func assertSortedMethods(list []*Func) {
if !debug {
panic("assertSortedMethods called outside debug mode")
}
if !sort.IsSorted(byUniqueMethodName(list)) {
panic("methods not sorted")
}
}
// byUniqueMethodName method lists can be sorted by their unique method names.
type byUniqueMethodName []*Func
func (a byUniqueMethodName) Len() int { return len(a) }
func (a byUniqueMethodName) Less(i, j int) bool { return a[i].less(&a[j].object) }
func (a byUniqueMethodName) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
// invalidTypeSet is a singleton type set to signal an invalid type set
// due to an error. It's also a valid empty type set, so consumers of
// type sets may choose to ignore it.
var invalidTypeSet _TypeSet
// computeUnionTypeSet may be called with check == nil.
// The result is &invalidTypeSet if the union overflows.
func computeUnionTypeSet(check *Checker, unionSets map[*Union]*_TypeSet, pos token.Pos, utyp *Union) *_TypeSet {
if tset, _ := unionSets[utyp]; tset != nil {
return tset
}
// avoid infinite recursion (see also computeInterfaceTypeSet)
unionSets[utyp] = new(_TypeSet)
var allTerms termlist
for _, t := range utyp.terms {
var terms termlist
u := under(t.typ)
if ui, _ := u.(*Interface); ui != nil {
// For now we don't permit type parameters as constraints.
assert(!isTypeParam(t.typ))
terms = computeInterfaceTypeSet(check, pos, ui).terms
} else if u == Typ[Invalid] {
continue
} else {
if t.tilde && !Identical(t.typ, u) {
// There is no underlying type which is t.typ.
// The corresponding type set is empty.
t = nil // ∅ term
}
terms = termlist{(*term)(t)}
}
// The type set of a union expression is the union
// of the type sets of each term.
allTerms = allTerms.union(terms)
if len(allTerms) > maxTermCount {
if check != nil {
check.errorf(atPos(pos), InvalidUnion, "cannot handle more than %d union terms (implementation limitation)", maxTermCount)
}
unionSets[utyp] = &invalidTypeSet
return unionSets[utyp]
}
}
unionSets[utyp].terms = allTerms
return unionSets[utyp]
}
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements printing of types.
package types
import (
"bytes"
"fmt"
"go/token"
"sort"
"strconv"
"strings"
"unicode/utf8"
)
// A Qualifier controls how named package-level objects are printed in
// calls to TypeString, ObjectString, and SelectionString.
//
// These three formatting routines call the Qualifier for each
// package-level object O, and if the Qualifier returns a non-empty
// string p, the object is printed in the form p.O.
// If it returns an empty string, only the object name O is printed.
//
// Using a nil Qualifier is equivalent to using (*Package).Path: the
// object is qualified by the import path, e.g., "encoding/json.Marshal".
type Qualifier func(*Package) string
// RelativeTo returns a Qualifier that fully qualifies members of
// all packages other than pkg.
func RelativeTo(pkg *Package) Qualifier {
if pkg == nil {
return nil
}
return func(other *Package) string {
if pkg == other {
return "" // same package; unqualified
}
return other.Path()
}
}
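// Example (hypothetical client code; the package path and type name are
// illustrative):
//
//	pkg := NewPackage("example.com/p", "p")
//	obj := NewTypeName(token.NoPos, pkg, "T", nil)
//	named := NewNamed(obj, Typ[Int], nil)
//	TypeString(named, nil)             // "example.com/p.T"
//	TypeString(named, RelativeTo(pkg)) // "T"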
// TypeString returns the string representation of typ.
// The Qualifier controls the printing of
// package-level objects, and may be nil.
func TypeString(typ Type, qf Qualifier) string {
var buf bytes.Buffer
WriteType(&buf, typ, qf)
return buf.String()
}
// WriteType writes the string representation of typ to buf.
// The Qualifier controls the printing of
// package-level objects, and may be nil.
func WriteType(buf *bytes.Buffer, typ Type, qf Qualifier) {
newTypeWriter(buf, qf).typ(typ)
}
// WriteSignature writes the representation of the signature sig to buf,
// without a leading "func" keyword. The Qualifier controls the printing
// of package-level objects, and may be nil.
func WriteSignature(buf *bytes.Buffer, sig *Signature, qf Qualifier) {
newTypeWriter(buf, qf).signature(sig)
}
type typeWriter struct {
buf *bytes.Buffer
seen map[Type]bool
qf Qualifier
ctxt *Context // if non-nil, we are type hashing
tparams *TypeParamList // local type parameters
paramNames bool // if set, write function parameter names, otherwise, write types only
tpSubscripts bool // if set, write type parameter indices as subscripts
pkgInfo bool // package-annotate first unexported-type field to avoid confusing type description
}
func newTypeWriter(buf *bytes.Buffer, qf Qualifier) *typeWriter {
return &typeWriter{buf, make(map[Type]bool), qf, nil, nil, true, false, false}
}
func newTypeHasher(buf *bytes.Buffer, ctxt *Context) *typeWriter {
assert(ctxt != nil)
return &typeWriter{buf, make(map[Type]bool), nil, ctxt, nil, false, false, false}
}
func (w *typeWriter) byte(b byte) {
if w.ctxt != nil {
if b == ' ' {
b = '#'
}
w.buf.WriteByte(b)
return
}
w.buf.WriteByte(b)
if b == ',' || b == ';' {
w.buf.WriteByte(' ')
}
}
func (w *typeWriter) string(s string) {
w.buf.WriteString(s)
}
func (w *typeWriter) error(msg string) {
if w.ctxt != nil {
panic(msg)
}
w.buf.WriteString("<" + msg + ">")
}
func (w *typeWriter) typ(typ Type) {
if w.seen[typ] {
w.error("cycle to " + goTypeName(typ))
return
}
w.seen[typ] = true
defer delete(w.seen, typ)
switch t := typ.(type) {
case nil:
w.error("nil")
case *Basic:
// exported basic types go into package unsafe
// (currently this is just unsafe.Pointer)
if token.IsExported(t.name) {
if obj, _ := Unsafe.scope.Lookup(t.name).(*TypeName); obj != nil {
w.typeName(obj)
break
}
}
w.string(t.name)
case *Array:
w.byte('[')
w.string(strconv.FormatInt(t.len, 10))
w.byte(']')
w.typ(t.elem)
case *Slice:
w.string("[]")
w.typ(t.elem)
case *Struct:
w.string("struct{")
for i, f := range t.fields {
if i > 0 {
w.byte(';')
}
// If disambiguating one struct for another, look for the first unexported field.
// Do this first in case of nested structs; tag the first-outermost field.
pkgAnnotate := false
if w.qf == nil && w.pkgInfo && !token.IsExported(f.name) {
// Note: for embedded types the type name is the field name, and names like "string" are lower case and hence unexported.
pkgAnnotate = true
w.pkgInfo = false // only tag once
}
// This doesn't do the right thing for embedded type
// aliases where we should print the alias name, not
// the aliased type (see go.dev/issue/44410).
if !f.embedded {
w.string(f.name)
w.byte(' ')
}
w.typ(f.typ)
if pkgAnnotate {
w.string(" /* package ")
w.string(f.pkg.Path())
w.string(" */ ")
}
if tag := t.Tag(i); tag != "" {
w.byte(' ')
// TODO(rfindley) If tag contains blanks, replacing them with '#'
// in Context.TypeHash may produce another tag
// accidentally.
w.string(strconv.Quote(tag))
}
}
w.byte('}')
case *Pointer:
w.byte('*')
w.typ(t.base)
case *Tuple:
w.tuple(t, false)
case *Signature:
w.string("func")
w.signature(t)
case *Union:
// Unions only appear as (syntactic) embedded elements
// in interfaces and syntactically cannot be empty.
if t.Len() == 0 {
w.error("empty union")
break
}
for i, t := range t.terms {
if i > 0 {
w.string(termSep)
}
if t.tilde {
w.byte('~')
}
w.typ(t.typ)
}
case *Interface:
if w.ctxt == nil {
if t == universeAny.Type() {
// When not hashing, we can try to improve type strings by writing "any"
// for a type that is pointer-identical to universeAny. This logic should
// be deprecated by more robust handling for aliases.
w.string("any")
break
}
if t == universeComparable.Type().(*Named).underlying {
w.string("interface{comparable}")
break
}
}
if t.implicit {
if len(t.methods) == 0 && len(t.embeddeds) == 1 {
w.typ(t.embeddeds[0])
break
}
// Something's wrong with the implicit interface.
// Print it as such and continue.
w.string("/* implicit */ ")
}
w.string("interface{")
first := true
if w.ctxt != nil {
w.typeSet(t.typeSet())
} else {
for _, m := range t.methods {
if !first {
w.byte(';')
}
first = false
w.string(m.name)
w.signature(m.typ.(*Signature))
}
for _, typ := range t.embeddeds {
if !first {
w.byte(';')
}
first = false
w.typ(typ)
}
}
w.byte('}')
case *Map:
w.string("map[")
w.typ(t.key)
w.byte(']')
w.typ(t.elem)
case *Chan:
var s string
var parens bool
switch t.dir {
case SendRecv:
s = "chan "
// chan (<-chan T) requires parentheses
if c, _ := t.elem.(*Chan); c != nil && c.dir == RecvOnly {
parens = true
}
case SendOnly:
s = "chan<- "
case RecvOnly:
s = "<-chan "
default:
w.error("unknown channel direction")
}
w.string(s)
if parens {
w.byte('(')
}
w.typ(t.elem)
if parens {
w.byte(')')
}
case *Named:
// If hashing, write a unique prefix for t to represent its identity, since
// named type identity is pointer identity.
if w.ctxt != nil {
w.string(strconv.Itoa(w.ctxt.getID(t)))
}
w.typeName(t.obj) // when hashing written for readability of the hash only
if t.inst != nil {
// instantiated type
w.typeList(t.inst.targs.list())
} else if w.ctxt == nil && t.TypeParams().Len() != 0 { // For type hashing, don't need to format the TypeParams
// parameterized type
w.tParamList(t.TypeParams().list())
}
case *TypeParam:
if t.obj == nil {
w.error("unnamed type parameter")
break
}
if i := tparamIndex(w.tparams.list(), t); i >= 0 {
// The names of type parameters that are declared by the type being
// hashed are not part of the type identity. Replace them with a
// placeholder indicating their index.
w.string(fmt.Sprintf("$%d", i))
} else {
w.string(t.obj.name)
if w.tpSubscripts || w.ctxt != nil {
w.string(subscript(t.id))
}
// If the type parameter name is the same as a predeclared object
// (say int), point out where it is declared to avoid confusing
// error messages. This doesn't need to be super-elegant; we just
// need a clear indication that this is not a predeclared name.
// Note: types2 prints position information here - we can't do
// that because we don't have a token.FileSet accessible.
if w.ctxt == nil && Universe.Lookup(t.obj.name) != nil {
w.string("/* type parameter */")
}
}
default:
// For externally defined implementations of Type.
// Note: In this case cycles won't be caught.
w.string(t.String())
}
}
// typeSet writes a canonical hash for an interface type set.
func (w *typeWriter) typeSet(s *_TypeSet) {
assert(w.ctxt != nil)
first := true
for _, m := range s.methods {
if !first {
w.byte(';')
}
first = false
w.string(m.name)
w.signature(m.typ.(*Signature))
}
switch {
case s.terms.isAll():
// nothing to do
case s.terms.isEmpty():
w.string(s.terms.String())
default:
var termHashes []string
for _, term := range s.terms {
// terms are not canonically sorted, so we sort their hashes instead.
var buf bytes.Buffer
if term.tilde {
buf.WriteByte('~')
}
newTypeHasher(&buf, w.ctxt).typ(term.typ)
termHashes = append(termHashes, buf.String())
}
sort.Strings(termHashes)
if !first {
w.byte(';')
}
w.string(strings.Join(termHashes, "|"))
}
}
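// Illustrative sketch (not part of the original source): because the terms of
// a type set are not canonically ordered, interface{ ~int | string } and
// interface{ string | ~int } must produce the same hash. typeSet therefore
// hashes each term on its own ("~" plus the hash of int; the hash of string),
// sorts the per-term hashes, and joins them with "|".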
func (w *typeWriter) typeList(list []Type) {
w.byte('[')
for i, typ := range list {
if i > 0 {
w.byte(',')
}
w.typ(typ)
}
w.byte(']')
}
func (w *typeWriter) tParamList(list []*TypeParam) {
w.byte('[')
var prev Type
for i, tpar := range list {
// Determine the type parameter and its constraint.
// list is expected to hold type parameter names,
// but don't crash if that's not the case.
if tpar == nil {
w.error("nil type parameter")
continue
}
if i > 0 {
if tpar.bound != prev {
// bound changed - write previous one before advancing
w.byte(' ')
w.typ(prev)
}
w.byte(',')
}
prev = tpar.bound
w.typ(tpar)
}
if prev != nil {
w.byte(' ')
w.typ(prev)
}
w.byte(']')
}
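// Illustrative sketch (not part of the original source): tParamList writes a
// bound only when it differs from the bound of the preceding type parameter,
// mirroring Go's declaration syntax. For type parameters P and Q bounded by
// comparable and R bounded by any, the writer emits
//
//	[P,Q comparable,R any]
//
// P's and Q's shared bound is flushed just before the comma preceding R, and
// R's bound is flushed after the loop.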
func (w *typeWriter) typeName(obj *TypeName) {
w.string(packagePrefix(obj.pkg, w.qf))
w.string(obj.name)
}
func (w *typeWriter) tuple(tup *Tuple, variadic bool) {
w.byte('(')
if tup != nil {
for i, v := range tup.vars {
if i > 0 {
w.byte(',')
}
// parameter names are ignored for type identity and thus type hashes
if w.ctxt == nil && v.name != "" && w.paramNames {
w.string(v.name)
w.byte(' ')
}
typ := v.typ
if variadic && i == len(tup.vars)-1 {
if s, ok := typ.(*Slice); ok {
w.string("...")
typ = s.elem
} else {
// special case:
// append(s, "foo"...) leads to signature func([]byte, string...)
if t, _ := under(typ).(*Basic); t == nil || t.kind != String {
w.error("expected string type")
continue
}
w.typ(typ)
w.string("...")
continue
}
}
w.typ(typ)
}
}
w.byte(')')
}
func (w *typeWriter) signature(sig *Signature) {
if sig.TypeParams().Len() != 0 {
if w.ctxt != nil {
assert(w.tparams == nil)
w.tparams = sig.TypeParams()
defer func() {
w.tparams = nil
}()
}
w.tParamList(sig.TypeParams().list())
}
w.tuple(sig.params, sig.variadic)
n := sig.results.Len()
if n == 0 {
// no result
return
}
w.byte(' ')
if n == 1 && (w.ctxt != nil || sig.results.vars[0].name == "") {
// single unnamed result (if type hashing, name must be ignored)
w.typ(sig.results.vars[0].typ)
return
}
// multiple or named result(s)
w.tuple(sig.results, false)
}
// subscript returns the decimal (utf8) representation of x using subscript digits.
func subscript(x uint64) string {
const w = len("₀") // all digits 0...9 have the same utf8 width
var buf [32 * w]byte
i := len(buf)
for {
i -= w
utf8.EncodeRune(buf[i:], '₀'+rune(x%10)) // '₀' == U+2080
x /= 10
if x == 0 {
break
}
}
return string(buf[i:])
}
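// Illustrative sketch (not part of the original source): subscript fills the
// buffer from the least significant digit backwards; all subscript digits
// U+2080 ('₀') through U+2089 ('₉') have the same 3-byte UTF-8 encoding, so a
// fixed per-digit width works. For example:
//
//	subscript(0)   == "₀"
//	subscript(42)  == "₄₂"
//	subscript(100) == "₁₀₀"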
// Code generated by "go test -run=Generate -write=all"; DO NOT EDIT.
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package types
// A term describes elementary type sets:
//
// ∅: (*term)(nil) == ∅ // set of no types (empty set)
// 𝓤: &term{} == 𝓤 // set of all types (𝓤niverse)
// T: &term{false, T} == {T} // set of type T
// ~t: &term{true, t} == {t' | under(t') == t} // set of types with underlying type t
type term struct {
tilde bool // valid if typ != nil
typ Type
}
func (x *term) String() string {
switch {
case x == nil:
return "∅"
case x.typ == nil:
return "𝓤"
case x.tilde:
return "~" + x.typ.String()
default:
return x.typ.String()
}
}
// equal reports whether x and y represent the same type set.
func (x *term) equal(y *term) bool {
// easy cases
switch {
case x == nil || y == nil:
return x == y
case x.typ == nil || y.typ == nil:
return x.typ == y.typ
}
// ∅ ⊂ x, y ⊂ 𝓤
return x.tilde == y.tilde && Identical(x.typ, y.typ)
}
// union returns the union x ∪ y: zero, one, or two non-nil terms.
func (x *term) union(y *term) (_, _ *term) {
// easy cases
switch {
case x == nil && y == nil:
return nil, nil // ∅ ∪ ∅ == ∅
case x == nil:
return y, nil // ∅ ∪ y == y
case y == nil:
return x, nil // x ∪ ∅ == x
case x.typ == nil:
return x, nil // 𝓤 ∪ y == 𝓤
case y.typ == nil:
return y, nil // x ∪ 𝓤 == 𝓤
}
// ∅ ⊂ x, y ⊂ 𝓤
if x.disjoint(y) {
return x, y // x ∪ y == (x, y) if x ∩ y == ∅
}
// x.typ == y.typ
// ~t ∪ ~t == ~t
// ~t ∪ T == ~t
// T ∪ ~t == ~t
// T ∪ T == T
if x.tilde || !y.tilde {
return x, nil
}
return y, nil
}
// intersect returns the intersection x ∩ y.
func (x *term) intersect(y *term) *term {
// easy cases
switch {
case x == nil || y == nil:
return nil // ∅ ∩ y == ∅ and x ∩ ∅ == ∅
case x.typ == nil:
return y // 𝓤 ∩ y == y
case y.typ == nil:
return x // x ∩ 𝓤 == x
}
// ∅ ⊂ x, y ⊂ 𝓤
if x.disjoint(y) {
return nil // x ∩ y == ∅ since x and y are disjoint
}
// x.typ == y.typ
// ~t ∩ ~t == ~t
// ~t ∩ T == T
// T ∩ ~t == T
// T ∩ T == T
if !x.tilde || y.tilde {
return x
}
return y
}
// includes reports whether t ∈ x.
func (x *term) includes(t Type) bool {
// easy cases
switch {
case x == nil:
return false // t ∈ ∅ == false
case x.typ == nil:
return true // t ∈ 𝓤 == true
}
// ∅ ⊂ x ⊂ 𝓤
u := t
if x.tilde {
u = under(u)
}
return Identical(x.typ, u)
}
// subsetOf reports whether x ⊆ y.
func (x *term) subsetOf(y *term) bool {
// easy cases
switch {
case x == nil:
return true // ∅ ⊆ y == true
case y == nil:
return false // x ⊆ ∅ == false since x != ∅
case y.typ == nil:
return true // x ⊆ 𝓤 == true
case x.typ == nil:
return false // 𝓤 ⊆ y == false since y != 𝓤
}
// ∅ ⊂ x, y ⊂ 𝓤
if x.disjoint(y) {
return false // x ⊆ y == false if x ∩ y == ∅
}
// x.typ == y.typ
// ~t ⊆ ~t == true
// ~t ⊆ T == false
// T ⊆ ~t == true
// T ⊆ T == true
return !x.tilde || y.tilde
}
// disjoint reports whether x ∩ y == ∅.
// x.typ and y.typ must not be nil.
func (x *term) disjoint(y *term) bool {
if debug && (x.typ == nil || y.typ == nil) {
panic("invalid argument(s)")
}
ux := x.typ
if y.tilde {
ux = under(ux)
}
uy := y.typ
if x.tilde {
uy = under(uy)
}
return !Identical(ux, uy)
}
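// Illustrative sketch (not part of the original source): the operations above
// implement the elementary set algebra used for union type sets. With
//
//	x := &term{false, Typ[Int]} // {int}
//	y := &term{true, Typ[Int]}  // ~int: all types with underlying type int
//
// x.subsetOf(y) == true, y.subsetOf(x) == false, x.union(y) yields (y, nil)
// since ~int absorbs {int}, and x.intersect(y) == x. Terms over different
// underlying types, such as {int} and {string}, are disjoint.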
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements type-checking of identifiers and type expressions.
package types
import (
"fmt"
"go/ast"
"go/constant"
"go/internal/typeparams"
. "internal/types/errors"
"strings"
)
// ident type-checks identifier e and initializes x with the value or type of e.
// If an error occurred, x.mode is set to invalid.
// For the meaning of def, see Checker.definedType, below.
// If wantType is set, the identifier e is expected to denote a type.
func (check *Checker) ident(x *operand, e *ast.Ident, def *Named, wantType bool) {
x.mode = invalid
x.expr = e
// Note that we cannot use check.lookup here because the returned scope
// may be different from obj.Parent(). See also Scope.LookupParent doc.
scope, obj := check.scope.LookupParent(e.Name, check.pos)
switch obj {
case nil:
if e.Name == "_" {
// Blank identifiers are never declared, but the current identifier may
// be a placeholder for a receiver type parameter. In this case we can
// resolve its type and object from Checker.recvTParamMap.
if tpar := check.recvTParamMap[e]; tpar != nil {
x.mode = typexpr
x.typ = tpar
} else {
check.error(e, InvalidBlank, "cannot use _ as value or type")
}
} else {
check.errorf(e, UndeclaredName, "undefined: %s", e.Name)
}
return
case universeAny, universeComparable:
if !check.allowVersion(check.pkg, 1, 18) {
check.versionErrorf(e, "go1.18", "predeclared %s", e.Name)
return // avoid follow-on errors
}
}
check.recordUse(e, obj)
// Type-check the object.
// Only call Checker.objDecl if the object doesn't have a type yet
// (in which case we must actually determine it) or the object is a
// TypeName and we also want a type (in which case we might detect
// a cycle which needs to be reported). Otherwise we can skip the
// call and avoid a possible cycle error in favor of the more
// informative "not a type/value" error that this function's caller
// will issue (see go.dev/issue/25790).
typ := obj.Type()
if _, gotType := obj.(*TypeName); typ == nil || gotType && wantType {
check.objDecl(obj, def)
typ = obj.Type() // type must have been assigned by Checker.objDecl
}
assert(typ != nil)
// The object may have been dot-imported.
// If so, mark the respective package as used.
// (This code is only needed for dot-imports. Without them,
// we only have to mark variables, see *Var case below).
if pkgName := check.dotImportMap[dotImportKey{scope, obj.Name()}]; pkgName != nil {
pkgName.used = true
}
switch obj := obj.(type) {
case *PkgName:
check.errorf(e, InvalidPkgUse, "use of package %s not in selector", obj.name)
return
case *Const:
check.addDeclDep(obj)
if typ == Typ[Invalid] {
return
}
if obj == universeIota {
if check.iota == nil {
check.error(e, InvalidIota, "cannot use iota outside constant declaration")
return
}
x.val = check.iota
} else {
x.val = obj.val
}
assert(x.val != nil)
x.mode = constant_
case *TypeName:
if check.isBrokenAlias(obj) {
check.errorf(e, InvalidDeclCycle, "invalid use of type alias %s in recursive type (see go.dev/issue/50729)", obj.name)
return
}
x.mode = typexpr
case *Var:
// It's ok to mark non-local variables, but ignore variables
// from other packages to avoid potential race conditions with
// dot-imported variables.
if obj.pkg == check.pkg {
obj.used = true
}
check.addDeclDep(obj)
if typ == Typ[Invalid] {
return
}
x.mode = variable
case *Func:
check.addDeclDep(obj)
x.mode = value
case *Builtin:
x.id = obj.id
x.mode = builtin
case *Nil:
x.mode = value
default:
unreachable()
}
x.typ = typ
}
// typ type-checks the type expression e and returns its type, or Typ[Invalid].
// The type must not be an (uninstantiated) generic type.
func (check *Checker) typ(e ast.Expr) Type {
return check.definedType(e, nil)
}
// varType type-checks the type expression e and returns its type, or Typ[Invalid].
// The type must not be an (uninstantiated) generic type and it must not be a
// constraint interface.
func (check *Checker) varType(e ast.Expr) Type {
typ := check.definedType(e, nil)
check.validVarType(e, typ)
return typ
}
// validVarType reports an error if typ is a constraint interface.
// The expression e is used for error reporting, if any.
func (check *Checker) validVarType(e ast.Expr, typ Type) {
// If we have a type parameter there's nothing to do.
if isTypeParam(typ) {
return
}
// We don't want to call under() or complete interfaces while we are in
// the middle of type-checking parameter declarations that might belong
// to interface methods. Delay this check to the end of type-checking.
check.later(func() {
if t, _ := under(typ).(*Interface); t != nil {
tset := computeInterfaceTypeSet(check, e.Pos(), t) // TODO(gri) is this the correct position?
if !tset.IsMethodSet() {
if tset.comparable {
check.softErrorf(e, MisplacedConstraintIface, "cannot use type %s outside a type constraint: interface is (or embeds) comparable", typ)
} else {
check.softErrorf(e, MisplacedConstraintIface, "cannot use type %s outside a type constraint: interface contains type constraints", typ)
}
}
}
}).describef(e, "check var type %s", typ)
}
// definedType is like typ but also accepts a type name def.
// If def != nil, e is the type specification for the defined type def, declared
// in a type declaration, and def.underlying will be set to the type of e before
// any components of e are type-checked.
func (check *Checker) definedType(e ast.Expr, def *Named) Type {
typ := check.typInternal(e, def)
assert(isTyped(typ))
if isGeneric(typ) {
check.errorf(e, WrongTypeArgCount, "cannot use generic type %s without instantiation", typ)
typ = Typ[Invalid]
}
check.recordTypeAndValue(e, typexpr, typ, nil)
return typ
}
// genericType is like typ but the type must be an (uninstantiated) generic
// type. If cause is non-nil and the type expression was a valid type but not
// generic, cause will be populated with a message describing the error.
func (check *Checker) genericType(e ast.Expr, cause *string) Type {
typ := check.typInternal(e, nil)
assert(isTyped(typ))
if typ != Typ[Invalid] && !isGeneric(typ) {
if cause != nil {
*cause = check.sprintf("%s is not a generic type", typ)
}
typ = Typ[Invalid]
}
// TODO(gri) what is the correct call below?
check.recordTypeAndValue(e, typexpr, typ, nil)
return typ
}
// goTypeName returns the Go type name for typ and
// removes any occurrences of "types." from that name.
func goTypeName(typ Type) string {
return strings.ReplaceAll(fmt.Sprintf("%T", typ), "types.", "")
}
// typInternal drives type checking of types.
// Must only be called by definedType or genericType.
func (check *Checker) typInternal(e0 ast.Expr, def *Named) (T Type) {
if check.conf._Trace {
check.trace(e0.Pos(), "-- type %s", e0)
check.indent++
defer func() {
check.indent--
var under Type
if T != nil {
// Calling under() here may lead to endless instantiations.
// Test case: type T[P any] *T[P]
under = safeUnderlying(T)
}
if T == under {
check.trace(e0.Pos(), "=> %s // %s", T, goTypeName(T))
} else {
check.trace(e0.Pos(), "=> %s (under = %s) // %s", T, under, goTypeName(T))
}
}()
}
switch e := e0.(type) {
case *ast.BadExpr:
// ignore - error reported before
case *ast.Ident:
var x operand
check.ident(&x, e, def, true)
switch x.mode {
case typexpr:
typ := x.typ
def.setUnderlying(typ)
return typ
case invalid:
// ignore - error reported before
case novalue:
check.errorf(&x, NotAType, "%s used as type", &x)
default:
check.errorf(&x, NotAType, "%s is not a type", &x)
}
case *ast.SelectorExpr:
var x operand
check.selector(&x, e, def, true)
switch x.mode {
case typexpr:
typ := x.typ
def.setUnderlying(typ)
return typ
case invalid:
// ignore - error reported before
case novalue:
check.errorf(&x, NotAType, "%s used as type", &x)
default:
check.errorf(&x, NotAType, "%s is not a type", &x)
}
case *ast.IndexExpr, *ast.IndexListExpr:
ix := typeparams.UnpackIndexExpr(e)
if !check.allowVersion(check.pkg, 1, 18) {
check.softErrorf(inNode(e, ix.Lbrack), UnsupportedFeature, "type instantiation requires go1.18 or later")
}
return check.instantiatedType(ix, def)
case *ast.ParenExpr:
// Generic types must be instantiated before they can be used in any form.
// Consequently, generic types cannot be parenthesized.
return check.definedType(e.X, def)
case *ast.ArrayType:
if e.Len == nil {
typ := new(Slice)
def.setUnderlying(typ)
typ.elem = check.varType(e.Elt)
return typ
}
typ := new(Array)
def.setUnderlying(typ)
// Provide a more specific error when encountering a [...] array
// rather than leaving it to the handling of the ... expression.
if _, ok := e.Len.(*ast.Ellipsis); ok {
check.error(e.Len, BadDotDotDotSyntax, "invalid use of [...] array (outside a composite literal)")
typ.len = -1
} else {
typ.len = check.arrayLength(e.Len)
}
typ.elem = check.varType(e.Elt)
if typ.len >= 0 {
return typ
}
// report error if we encountered [...]
case *ast.Ellipsis:
// dots are handled explicitly where they are legal
// (array composite literals and parameter lists)
check.error(e, InvalidDotDotDot, "invalid use of '...'")
check.use(e.Elt)
case *ast.StructType:
typ := new(Struct)
def.setUnderlying(typ)
check.structType(typ, e)
return typ
case *ast.StarExpr:
typ := new(Pointer)
typ.base = Typ[Invalid] // avoid nil base in invalid recursive type declaration
def.setUnderlying(typ)
typ.base = check.varType(e.X)
return typ
case *ast.FuncType:
typ := new(Signature)
def.setUnderlying(typ)
check.funcType(typ, nil, e)
return typ
case *ast.InterfaceType:
typ := check.newInterface()
def.setUnderlying(typ)
check.interfaceType(typ, e, def)
return typ
case *ast.MapType:
typ := new(Map)
def.setUnderlying(typ)
typ.key = check.varType(e.Key)
typ.elem = check.varType(e.Value)
// spec: "The comparison operators == and != must be fully defined
// for operands of the key type; thus the key type must not be a
// function, map, or slice."
//
// Delay this check because it requires fully setup types;
// it is safe to continue in any case (was go.dev/issue/6667).
check.later(func() {
if !Comparable(typ.key) {
var why string
if isTypeParam(typ.key) {
why = " (missing comparable constraint)"
}
check.errorf(e.Key, IncomparableMapKey, "invalid map key type %s%s", typ.key, why)
}
}).describef(e.Key, "check map key %s", typ.key)
return typ
case *ast.ChanType:
typ := new(Chan)
def.setUnderlying(typ)
dir := SendRecv
switch e.Dir {
case ast.SEND | ast.RECV:
// nothing to do
case ast.SEND:
dir = SendOnly
case ast.RECV:
dir = RecvOnly
default:
check.errorf(e, InvalidSyntaxTree, "unknown channel direction %d", e.Dir)
// ok to continue
}
typ.dir = dir
typ.elem = check.varType(e.Value)
return typ
default:
check.errorf(e0, NotAType, "%s is not a type", e0)
check.use(e0)
}
typ := Typ[Invalid]
def.setUnderlying(typ)
return typ
}
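// Illustrative sketch (not part of the original source): typInternal calls
// def.setUnderlying(typ) before type-checking the components of typ. This is
// what lets recursive declarations such as
//
//	type List *List
//	type Tree map[string]Tree
//
// terminate: when the checker re-encounters the defined type inside its own
// right-hand side, the type's underlying type is already set.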
func (check *Checker) instantiatedType(ix *typeparams.IndexExpr, def *Named) (res Type) {
if check.conf._Trace {
check.trace(ix.Pos(), "-- instantiating type %s with %s", ix.X, ix.Indices)
check.indent++
defer func() {
check.indent--
// Don't format the underlying here. It will always be nil.
check.trace(ix.Pos(), "=> %s", res)
}()
}
var cause string
gtyp := check.genericType(ix.X, &cause)
if cause != "" {
check.errorf(ix.Orig, NotAGenericType, invalidOp+"%s (%s)", ix.Orig, cause)
}
if gtyp == Typ[Invalid] {
return gtyp // error already reported
}
orig, _ := gtyp.(*Named)
if orig == nil {
panic(fmt.Sprintf("%v: cannot instantiate %v", ix.Pos(), gtyp))
}
// evaluate arguments
targs := check.typeList(ix.Indices)
if targs == nil {
def.setUnderlying(Typ[Invalid]) // avoid errors later due to lazy instantiation
return Typ[Invalid]
}
// create the instance
inst := check.instance(ix.Pos(), orig, targs, nil, check.context()).(*Named)
def.setUnderlying(inst)
// orig.tparams may not be set up, so we need to do expansion later.
check.later(func() {
// This is an instance from the source, not from recursive substitution,
// and so it must be resolved during type-checking so that we can report
// errors.
check.recordInstance(ix.Orig, inst.TypeArgs().list(), inst)
if check.validateTArgLen(ix.Pos(), inst.TypeParams().Len(), inst.TypeArgs().Len()) {
if i, err := check.verify(ix.Pos(), inst.TypeParams().list(), inst.TypeArgs().list(), check.context()); err != nil {
// best position for error reporting
pos := ix.Pos()
if i < len(ix.Indices) {
pos = ix.Indices[i].Pos()
}
check.softErrorf(atPos(pos), InvalidTypeArg, err.Error())
} else {
check.mono.recordInstance(check.pkg, ix.Pos(), inst.TypeParams().list(), inst.TypeArgs().list(), ix.Indices)
}
}
// TODO(rfindley): remove this call: we don't need to call validType here,
// as cycles can only occur for types used inside a Named type declaration,
// and so it suffices to call validType from declared types.
check.validType(inst)
}).describef(ix, "resolve instance %s", inst)
return inst
}
// arrayLength type-checks the array length expression e
// and returns the constant length >= 0, or a value < 0
// to indicate an error (and thus an unknown length).
func (check *Checker) arrayLength(e ast.Expr) int64 {
// If e is an identifier, the array declaration might be an
// attempt at a parameterized type declaration with missing
// constraint. Provide an error message that mentions array
// length.
if name, _ := e.(*ast.Ident); name != nil {
obj := check.lookup(name.Name)
if obj == nil {
check.errorf(name, InvalidArrayLen, "undefined array length %s or missing type constraint", name.Name)
return -1
}
if _, ok := obj.(*Const); !ok {
check.errorf(name, InvalidArrayLen, "invalid array length %s", name.Name)
return -1
}
}
var x operand
check.expr(&x, e)
if x.mode != constant_ {
if x.mode != invalid {
check.errorf(&x, InvalidArrayLen, "array length %s must be constant", &x)
}
return -1
}
if isUntyped(x.typ) || isInteger(x.typ) {
if val := constant.ToInt(x.val); val.Kind() == constant.Int {
if representableConst(val, check, Typ[Int], nil) {
if n, ok := constant.Int64Val(val); ok && n >= 0 {
return n
}
check.errorf(&x, InvalidArrayLen, "invalid array length %s", &x)
return -1
}
}
}
check.errorf(&x, InvalidArrayLen, "array length %s must be integer", &x)
return -1
}
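// Illustrative sketch (not part of the original source): arrayLength accepts
// any constant expression whose value is a representable, non-negative
// integer:
//
//	const n = 3
//	var _ [n]int       // ok: length 3
//	var _ [2*n + 1]int // ok: length 7
//	var _ [-1]int      // error: invalid array length
//	var _ [n * 1.5]int // error: array length 4.5 must be integer
//
// An undefined identifier produces the "or missing type constraint" variant
// of the error, since [T]E may be an attempted generic declaration.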
// typeList provides the list of types corresponding to the incoming expression list.
// If an error occurred, the result is nil, but all list elements were type-checked.
func (check *Checker) typeList(list []ast.Expr) []Type {
res := make([]Type, len(list)) // res != nil even if len(list) == 0
for i, x := range list {
t := check.varType(x)
if t == Typ[Invalid] {
res = nil
}
if res != nil {
res[i] = t
}
}
return res
}
// Code generated by "go test -run=Generate -write=all"; DO NOT EDIT.
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package types
// under returns the true expanded underlying type.
// If it doesn't exist, the result is Typ[Invalid].
// under must only be called when a type is known
// to be fully set up.
func under(t Type) Type {
if t, _ := t.(*Named); t != nil {
return t.under()
}
return t.Underlying()
}
// If t is not a type parameter, coreType returns the underlying type.
// If t is a type parameter, coreType returns the single underlying
// type of all types in its type set if it exists, or nil otherwise. If the
// type set contains only unrestricted and restricted channel types (with
// identical element types), the single underlying type is the restricted
// channel type if the restrictions are always the same, or nil otherwise.
func coreType(t Type) Type {
tpar, _ := t.(*TypeParam)
if tpar == nil {
return under(t)
}
var su Type
if tpar.underIs(func(u Type) bool {
if u == nil {
return false
}
if su != nil {
u = match(su, u)
if u == nil {
return false
}
}
// su == nil || match(su, u) != nil
su = u
return true
}) {
return su
}
return nil
}
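// Illustrative sketch (not part of the original source): for a type parameter
// P constrained by interface{ ~[]int | []int }, coreType(P) is []int, since
// every type in the type set has that underlying type. For
// interface{ ~int | ~string } the underlying types differ, so the result is
// nil. For interface{ chan int | <-chan int }, match selects the restricted
// direction and the core type is <-chan int.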
// coreString is like coreType but also considers []byte
// and strings as identical. In this case, if successful and we saw
// a string, the result is of type (possibly untyped) string.
func coreString(t Type) Type {
tpar, _ := t.(*TypeParam)
if tpar == nil {
return under(t) // string or untyped string
}
var su Type
hasString := false
if tpar.underIs(func(u Type) bool {
if u == nil {
return false
}
if isString(u) {
u = NewSlice(universeByte)
hasString = true
}
if su != nil {
u = match(su, u)
if u == nil {
return false
}
}
// su == nil || match(su, u) != nil
su = u
return true
}) {
if hasString {
return Typ[String]
}
return su
}
return nil
}
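// Illustrative sketch (not part of the original source): for a type parameter
// P constrained by interface{ ~[]byte | ~string }, every string term is first
// mapped to []byte, so the terms match on []byte; because a string was seen,
// coreString(P) is string. This treatment supports operations that accept
// mixed byte-slice and string type sets, e.g. copying from a value of type P
// into a []byte.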
// If x and y are identical, match returns x.
// If x and y are identical channels but for their direction
// and one of them is unrestricted, match returns the channel
// with the restricted direction.
// In all other cases, match returns nil.
func match(x, y Type) Type {
// Common case: we don't have channels.
if Identical(x, y) {
return x
}
// We may have channels that differ in direction only.
if x, _ := x.(*Chan); x != nil {
if y, _ := y.(*Chan); y != nil && Identical(x.elem, y.elem) {
// We have channels that differ in direction only.
// If there's an unrestricted channel, select the restricted one.
switch {
case x.dir == SendRecv:
return y
case y.dir == SendRecv:
return x
}
}
}
// types are different
return nil
}
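// Illustrative sketch (not part of the original source):
//
//	match(int, int)               == int
//	match(chan int, <-chan int)   == <-chan int // restricted direction wins
//	match(<-chan int, chan<- int) == nil        // both restricted, conflicting
//	match([]int, [2]int)          == nil        // structurally different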
// Code generated by "go test -run=Generate -write=all"; DO NOT EDIT.
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements type unification.
//
// Type unification attempts to make two types x and y structurally
// identical by determining the types for a given list of (bound)
// type parameters which may occur within x and y. If x and y are
// structurally different (say []T vs chan T), or conflicting
// types are determined for type parameters, unification fails.
// If unification succeeds, as a side-effect, the types of the
// bound type parameters may be determined.
//
// Unification typically requires multiple calls u.unify(x, y) to
// a given unifier u, with various combinations of types x and y.
// In each call, additional type parameter types may be determined
// as a side effect. If a call fails (returns false), unification
// fails.
//
// In the unification context, structural identity ignores the
// difference between a defined type and its underlying type.
// It also ignores the difference between an (external, unbound)
// type parameter and its core type.
// If two types are not structurally identical, they cannot be Go
// identical types. On the other hand, if they are structurally
// identical, they may be Go identical or at least assignable, or
// they may be in the type set of a constraint.
// Whether they indeed are identical or assignable is determined
// upon instantiation and function argument passing.
package types
import (
"bytes"
"fmt"
"sort"
"strings"
)
const (
// Upper limit for recursion depth. Used to catch infinite recursions
// due to implementation issues (e.g., see issues go.dev/issue/48619, go.dev/issue/48656).
unificationDepthLimit = 50
// Whether to panic when unificationDepthLimit is reached.
// If disabled, a recursion depth overflow results in a (quiet)
// unification failure.
panicAtUnificationDepthLimit = true
// If enableCoreTypeUnification is set, unification will consider
// the core types, if any, of non-local (unbound) type parameters.
enableCoreTypeUnification = true
// If traceInference is set, unification will print a trace of its operation.
// Interpretation of trace:
// x ≡ y attempt to unify types x and y
// p ➞ y type parameter p is set to type y (p is inferred to be y)
// p ⇄ q type parameters p and q match (p is inferred to be q and vice versa)
// x ≢ y types x and y cannot be unified
// [p, q, ...] ➞ [x, y, ...] mapping from type parameters to types
traceInference = false
)
// A unifier maintains a list of type parameters and
// corresponding types inferred for each type parameter.
// A unifier is created by calling newUnifier.
type unifier struct {
// handles maps each type parameter to its inferred type through
// an indirection *Type called (inferred type) "handle".
// Initially, each type parameter has its own, separate handle,
// with a nil (i.e., not yet inferred) type.
// After a type parameter P is unified with a type parameter Q,
// P and Q share the same handle (and thus type). This ensures
// that inferring the type for a given type parameter P will
// automatically infer the same type for all other parameters
// unified (joined) with P.
handles map[*TypeParam]*Type
depth int // recursion depth during unification
}
// newUnifier returns a new unifier initialized with the given type parameter
// and corresponding type argument lists. The type argument list may be shorter
// than the type parameter list, and it may contain nil types. Matching type
// parameters and arguments must have the same index.
func newUnifier(tparams []*TypeParam, targs []Type) *unifier {
assert(len(tparams) >= len(targs))
handles := make(map[*TypeParam]*Type, len(tparams))
// Allocate all handles up-front: in a correct program, all type parameters
// must be resolved and thus eventually will get a handle.
// Also, sharing of handles caused by unified type parameters is rare and
// so it's ok to not optimize for that case (and delay handle allocation).
for i, x := range tparams {
var t Type
if i < len(targs) {
t = targs[i]
}
handles[x] = &t
}
return &unifier{handles, 0}
}
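// Illustrative sketch (not part of the original source): for
// tparams = [P, Q] and targs = [int], newUnifier allocates
//
//	handles[P] = &tP with tP = int // P is already known
//	handles[Q] = &tQ with tQ = nil // Q remains to be inferred
//
// Each type parameter starts with its own *Type handle; joining two type
// parameters later makes them share a single handle.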
// unify attempts to unify x and y and reports whether it succeeded.
// As a side-effect, types may be inferred for type parameters.
func (u *unifier) unify(x, y Type) bool {
return u.nify(x, y, nil)
}
func (u *unifier) tracef(format string, args ...interface{}) {
fmt.Println(strings.Repeat(". ", u.depth) + sprintf(nil, nil, true, format, args...))
}
// String returns a string representation of the current mapping
// from type parameters to types.
func (u *unifier) String() string {
// sort type parameters for reproducible strings
tparams := make(typeParamsById, len(u.handles))
i := 0
for tpar := range u.handles {
tparams[i] = tpar
i++
}
sort.Sort(tparams)
var buf bytes.Buffer
w := newTypeWriter(&buf, nil)
w.byte('[')
for i, x := range tparams {
if i > 0 {
w.string(", ")
}
w.typ(x)
w.string(": ")
w.typ(u.at(x))
}
w.byte(']')
return buf.String()
}
type typeParamsById []*TypeParam
func (s typeParamsById) Len() int { return len(s) }
func (s typeParamsById) Less(i, j int) bool { return s[i].id < s[j].id }
func (s typeParamsById) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
// join unifies the given type parameters x and y.
// If both type parameters already have a type associated with them
// and they are not joined, join fails and returns false.
func (u *unifier) join(x, y *TypeParam) bool {
if traceInference {
u.tracef("%s ⇄ %s", x, y)
}
switch hx, hy := u.handles[x], u.handles[y]; {
case hx == hy:
// Both type parameters already share the same handle. Nothing to do.
case *hx != nil && *hy != nil:
// Both type parameters have (possibly different) inferred types. Cannot join.
return false
case *hx != nil:
// Only type parameter x has an inferred type. Use handle of x.
u.setHandle(y, hx)
// This case is treated like the default case.
// case *hy != nil:
// // Only type parameter y has an inferred type. Use handle of y.
// u.setHandle(x, hy)
default:
// Neither type parameter has an inferred type. Use handle of y.
u.setHandle(x, hy)
}
return true
}
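// Illustrative sketch (not part of the original source): for fresh type
// parameters P and Q, join(P, Q) makes both map to the same handle; a
// subsequent u.set(P, int) then stores int through the shared handle, so
// u.at(Q) == int as well. If P and Q already carry (possibly different)
// inferred types, join returns false and the caller must compare the two
// inferred types directly (see nify).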
// asTypeParam returns x.(*TypeParam) if x is a type parameter recorded with u.
// Otherwise, the result is nil.
func (u *unifier) asTypeParam(x Type) *TypeParam {
if x, _ := x.(*TypeParam); x != nil {
if _, found := u.handles[x]; found {
return x
}
}
return nil
}
// setHandle sets the handle for type parameter x
// (and all its joined type parameters) to h.
func (u *unifier) setHandle(x *TypeParam, h *Type) {
hx := u.handles[x]
assert(hx != nil)
for y, hy := range u.handles {
if hy == hx {
u.handles[y] = h
}
}
}
// at returns the (possibly nil) type for type parameter x.
func (u *unifier) at(x *TypeParam) Type {
return *u.handles[x]
}
// set sets the type t for type parameter x;
// t must not be nil.
func (u *unifier) set(x *TypeParam, t Type) {
assert(t != nil)
if traceInference {
u.tracef("%s ➞ %s", x, t)
}
*u.handles[x] = t
}
// unknowns returns the number of type parameters for which no type has been set yet.
func (u *unifier) unknowns() int {
n := 0
for _, h := range u.handles {
if *h == nil {
n++
}
}
return n
}
// inferred returns the list of inferred types for the given type parameter list.
// The result is never nil and has the same length as tparams; result types that
// could not be inferred are nil. Corresponding type parameters and result types
// have identical indices.
func (u *unifier) inferred(tparams []*TypeParam) []Type {
list := make([]Type, len(tparams))
for i, x := range tparams {
list[i] = u.at(x)
}
return list
}
// nify implements the core unification algorithm which is an
// adapted version of Checker.identical. For changes to that
// code the corresponding changes should be made here.
// Must not be called directly from outside the unifier.
func (u *unifier) nify(x, y Type, p *ifacePair) (result bool) {
u.depth++
if traceInference {
u.tracef("%s ≡ %s", x, y)
}
defer func() {
if traceInference && !result {
u.tracef("%s ≢ %s", x, y)
}
u.depth--
}()
// nothing to do if x == y
if x == y {
return true
}
// Stopgap for cases where unification fails.
if u.depth > unificationDepthLimit {
if traceInference {
u.tracef("depth %d >= %d", u.depth, unificationDepthLimit)
}
if panicAtUnificationDepthLimit {
panic("unification reached recursion depth limit")
}
return false
}
// Unification is symmetric, so we can swap the operands.
// Ensure that if we have at least one
// - defined type, it is in y
// - type parameter recorded with u, it is in x
if _, ok := x.(*Named); ok || u.asTypeParam(y) != nil {
if traceInference {
u.tracef("%s ≡ %s (swap)", y, x)
}
x, y = y, x
}
// Unification will fail if we match a defined type against a type literal.
// Per the (spec) assignment rules, assignments of values to variables with
// the same type structure are permitted as long as at least one of them
// is not a defined type. To accommodate that possibility, we continue
// unification with the underlying type of a defined type if the other type
// is a type literal.
// We also continue if the other type is a basic type because basic types
// are valid underlying types and may appear as core types of type constraints.
// If we exclude them, inferred defined types for type parameters may not
// match against the core types of their constraints (even though they might
// correctly match against some of the types in the constraint's type set).
// Finally, if unification (incorrectly) succeeds by matching the underlying
// type of a defined type against a basic type (because we include basic types
// as type literals here), and if that leads to an incorrectly inferred type,
// we will fail at function instantiation or argument assignment time.
//
// If we have at least one defined type, there is one in y.
if ny, _ := y.(*Named); ny != nil && isTypeLit(x) {
if traceInference {
u.tracef("%s ≡ under %s", x, ny)
}
y = ny.under()
// Per the spec, a defined type cannot have an underlying type
// that is a type parameter.
assert(!isTypeParam(y))
// x and y may be identical now
if x == y {
return true
}
}
// Cases where at least one of x or y is a type parameter recorded with u.
// If we have at least one type parameter, there is one in x.
// If we have exactly one type parameter, because it is in x,
// isTypeLit(x) is false and y was not changed above. In other
// words, if y was a defined type, it is still a defined type
// (relevant for the logic below).
switch px, py := u.asTypeParam(x), u.asTypeParam(y); {
case px != nil && py != nil:
// both x and y are type parameters
if u.join(px, py) {
return true
}
// both x and y have an inferred type - they must match
return u.nify(u.at(px), u.at(py), p)
case px != nil:
// x is a type parameter, y is not
if x := u.at(px); x != nil {
// x has an inferred type which must match y
if u.nify(x, y, p) {
// If we have a match, possibly through underlying types,
// and y is a defined type, make sure we record that type
// for type parameter x, which may have until now only
// recorded an underlying type (go.dev/issue/43056).
if _, ok := y.(*Named); ok {
u.set(px, y)
}
return true
}
return false
}
// otherwise, infer type from y
u.set(px, y)
return true
}
// x != y if we get here
assert(x != y)
// If we get here and x or y is a type parameter, they are unbound
// (not recorded with the unifier).
// Ensure that if we have at least one type parameter, it is in x
// (the earlier swap checks for _recorded_ type parameters only).
if isTypeParam(y) {
if traceInference {
u.tracef("%s ≡ %s (swap)", y, x)
}
x, y = y, x
}
switch x := x.(type) {
case *Basic:
// Basic types are singletons except for the rune and byte
// aliases, thus we cannot solely rely on the x == y check
// above. See also comment in TypeName.IsAlias.
if y, ok := y.(*Basic); ok {
return x.kind == y.kind
}
case *Array:
// Two array types are identical if they have identical element types
// and the same array length.
if y, ok := y.(*Array); ok {
// If one or both array lengths are unknown (< 0) due to some error,
// assume they are the same to avoid spurious follow-on errors.
return (x.len < 0 || y.len < 0 || x.len == y.len) && u.nify(x.elem, y.elem, p)
}
case *Slice:
// Two slice types are identical if they have identical element types.
if y, ok := y.(*Slice); ok {
return u.nify(x.elem, y.elem, p)
}
case *Struct:
// Two struct types are identical if they have the same sequence of fields,
// and if corresponding fields have the same names, and identical types,
// and identical tags. Two embedded fields are considered to have the same
// name. Lower-case field names from different packages are always different.
if y, ok := y.(*Struct); ok {
if x.NumFields() == y.NumFields() {
for i, f := range x.fields {
g := y.fields[i]
if f.embedded != g.embedded ||
x.Tag(i) != y.Tag(i) ||
!f.sameId(g.pkg, g.name) ||
!u.nify(f.typ, g.typ, p) {
return false
}
}
return true
}
}
case *Pointer:
// Two pointer types are identical if they have identical base types.
if y, ok := y.(*Pointer); ok {
return u.nify(x.base, y.base, p)
}
case *Tuple:
// Two tuple types are identical if they have the same number of elements
// and corresponding elements have identical types.
if y, ok := y.(*Tuple); ok {
if x.Len() == y.Len() {
if x != nil {
for i, v := range x.vars {
w := y.vars[i]
if !u.nify(v.typ, w.typ, p) {
return false
}
}
}
return true
}
}
case *Signature:
// Two function types are identical if they have the same number of parameters
// and result values, corresponding parameter and result types are identical,
// and either both functions are variadic or neither is. Parameter and result
// names are not required to match.
// TODO(gri) handle type parameters or document why we can ignore them.
if y, ok := y.(*Signature); ok {
return x.variadic == y.variadic &&
u.nify(x.params, y.params, p) &&
u.nify(x.results, y.results, p)
}
case *Interface:
// Two interface types are identical if they have the same set of methods with
// the same names and identical function types. Lower-case method names from
// different packages are always different. The order of the methods is irrelevant.
if y, ok := y.(*Interface); ok {
xset := x.typeSet()
yset := y.typeSet()
if xset.comparable != yset.comparable {
return false
}
if !xset.terms.equal(yset.terms) {
return false
}
a := xset.methods
b := yset.methods
if len(a) == len(b) {
// Interface types are the only types where cycles can occur
// that are not "terminated" via named types; and such cycles
// can only be created via method parameter types that are
// anonymous interfaces (directly or indirectly) embedding
// the current interface. Example:
//
// type T interface {
// m() interface{T}
// }
//
// If two such (differently named) interfaces are compared,
// endless recursion occurs if the cycle is not detected.
//
// If x and y were compared before, they must be equal
// (if they were not, the recursion would have stopped);
// search the ifacePair stack for the same pair.
//
// This is a quadratic algorithm, but in practice these stacks
// are extremely short (bounded by the nesting depth of interface
// type declarations that recur via parameter types, an extremely
// rare occurrence). An alternative implementation might use a
// "visited" map, but that is probably less efficient overall.
q := &ifacePair{x, y, p}
for p != nil {
if p.identical(q) {
return true // same pair was compared before
}
p = p.prev
}
if debug {
assertSortedMethods(a)
assertSortedMethods(b)
}
for i, f := range a {
g := b[i]
if f.Id() != g.Id() || !u.nify(f.typ, g.typ, q) {
return false
}
}
return true
}
}
case *Map:
// Two map types are identical if they have identical key and value types.
if y, ok := y.(*Map); ok {
return u.nify(x.key, y.key, p) && u.nify(x.elem, y.elem, p)
}
case *Chan:
// Two channel types are identical if they have identical value types.
if y, ok := y.(*Chan); ok {
return u.nify(x.elem, y.elem, p)
}
case *Named:
// TODO(gri) This code differs now from the parallel code in Checker.identical. Investigate.
if y, ok := y.(*Named); ok {
xargs := x.TypeArgs().list()
yargs := y.TypeArgs().list()
if len(xargs) != len(yargs) {
return false
}
// TODO(gri) This is not always correct: two types may have the same names
// in the same package if one of them is nested in a function.
// Extremely unlikely but we need an always correct solution.
if x.obj.pkg == y.obj.pkg && x.obj.name == y.obj.name {
for i, x := range xargs {
if !u.nify(x, yargs[i], p) {
return false
}
}
return true
}
}
case *TypeParam:
// x must be an unbound type parameter (see comment above).
if debug {
assert(u.asTypeParam(x) == nil)
}
// By definition, a valid type argument must be in the type set of
// the respective type constraint. Therefore, the type argument's
// underlying type must be in the set of underlying types of that
// constraint. If there is a single such underlying type, it's the
// constraint's core type. It must match the type argument's under-
// lying type, irrespective of whether the actual type argument,
// which may be a defined type, is actually in the type set (that
// will be determined at instantiation time).
// Thus, if we have the core type of an unbound type parameter,
// we know the structure of the possible types satisfying such
// parameters. Use that core type for further unification
// (see go.dev/issue/50755 for a test case).
if enableCoreTypeUnification {
// Because the core type is always an underlying type,
// unification will take care of matching against a
// defined or literal type automatically.
// If y is also an unbound type parameter, we will end
// up here again with x and y swapped, so we don't
// need to take care of that case separately.
if cx := coreType(x); cx != nil {
if traceInference {
u.tracef("core %s ≡ %s", x, y)
}
return u.nify(cx, y, p)
}
}
// x != y and there's nothing to do
case nil:
// avoid a crash in case of nil type
default:
panic(sprintf(nil, nil, true, "u.nify(%s, %s)", x, y))
}
return false
}
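// Illustrative sketch (not part of the original source): for a call such as
//
//	func f[P any](xs []P) {}
//	f([]int{1}) // requires unify([]P, []int)
//
// nify descends through the *Slice case into u.nify(P, int), finds that P is
// a bound type parameter with no inferred type yet, and records P ➞ int via
// u.set. A later u.inferred([]*TypeParam{P}) reports [int].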
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package types
import (
"go/ast"
"go/token"
. "internal/types/errors"
)
// ----------------------------------------------------------------------------
// API
// A Union represents a union of terms embedded in an interface.
type Union struct {
terms []*Term // list of syntactical terms (not a canonicalized termlist)
}
// NewUnion returns a new Union type with the given terms.
// It is an error to create an empty union; they are syntactically not possible.
func NewUnion(terms []*Term) *Union {
if len(terms) == 0 {
panic("empty union")
}
return &Union{terms}
}
func (u *Union) Len() int { return len(u.terms) }
func (u *Union) Term(i int) *Term { return u.terms[i] }
func (u *Union) Underlying() Type { return u }
func (u *Union) String() string { return TypeString(u, nil) }
// A Term represents a term in a Union.
type Term term
// NewTerm returns a new union term.
func NewTerm(tilde bool, typ Type) *Term { return &Term{tilde, typ} }
func (t *Term) Tilde() bool { return t.tilde }
func (t *Term) Type() Type { return t.typ }
func (t *Term) String() string { return (*term)(t).String() }
// ----------------------------------------------------------------------------
// Implementation
// Avoid excessive type-checking times due to quadratic termlist operations.
const maxTermCount = 100
// parseUnion parses uexpr as a union of expressions.
// The result is a Union type, or Typ[Invalid] for some errors.
func parseUnion(check *Checker, uexpr ast.Expr) Type {
blist, tlist := flattenUnion(nil, uexpr)
assert(len(blist) == len(tlist)-1)
var terms []*Term
var u Type
for i, x := range tlist {
term := parseTilde(check, x)
if len(tlist) == 1 && !term.tilde {
// Single type. Ok to return early because all relevant
// checks have been performed in parseTilde (no need to
// run through term validity check below).
return term.typ // typ already recorded through check.typ in parseTilde
}
if len(terms) >= maxTermCount {
if u != Typ[Invalid] {
check.errorf(x, InvalidUnion, "cannot handle more than %d union terms (implementation limitation)", maxTermCount)
u = Typ[Invalid]
}
} else {
terms = append(terms, term)
u = &Union{terms}
}
if i > 0 {
check.recordTypeAndValue(blist[i-1], typexpr, u, nil)
}
}
if u == Typ[Invalid] {
return u
}
// Check validity of terms.
// Do this check later because it requires types to be set up.
// Note: This is a quadratic algorithm, but unions tend to be short.
check.later(func() {
for i, t := range terms {
if t.typ == Typ[Invalid] {
continue
}
u := under(t.typ)
f, _ := u.(*Interface)
if t.tilde {
if f != nil {
check.errorf(tlist[i], InvalidUnion, "invalid use of ~ (%s is an interface)", t.typ)
continue // don't report another error for t
}
if !Identical(u, t.typ) {
check.errorf(tlist[i], InvalidUnion, "invalid use of ~ (underlying type of %s is %s)", t.typ, u)
continue
}
}
// Stand-alone embedded interfaces are ok and are handled by the single-type case
// in the beginning. Embedded interfaces with tilde are excluded above. If we reach
// here, we must have at least two terms in the syntactic term list (but not necessarily
// in the term list of the union's type set).
if f != nil {
tset := f.typeSet()
switch {
case tset.NumMethods() != 0:
check.errorf(tlist[i], InvalidUnion, "cannot use %s in union (%s contains methods)", t, t)
case t.typ == universeComparable.Type():
check.error(tlist[i], InvalidUnion, "cannot use comparable in union")
case tset.comparable:
check.errorf(tlist[i], InvalidUnion, "cannot use %s in union (%s embeds comparable)", t, t)
}
continue // terms with interface types are not subject to the no-overlap rule
}
// Report overlapping (non-disjoint) terms such as
// a|a, a|~a, ~a|~a, and ~a|A (where under(A) == a).
if j := overlappingTerm(terms[:i], t); j >= 0 {
check.softErrorf(tlist[i], InvalidUnion, "overlapping terms %s and %s", t, terms[j])
}
}
}).describef(uexpr, "check term validity %s", uexpr)
return u
}
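// Illustrative sketch (not part of the original source): for the constraint
//
//	type C interface{ ~int | string }
//
// parseUnion produces a *Union with terms [~int, string]. The delayed check
// would reject ~error (tilde applied to an interface), ~MyInt for a defined
// type MyInt (the type after ~ must be its own underlying type), and
// overlapping terms such as int | ~int.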
func parseTilde(check *Checker, tx ast.Expr) *Term {
x := tx
var tilde bool
if op, _ := x.(*ast.UnaryExpr); op != nil && op.Op == token.TILDE {
x = op.X
tilde = true
}
typ := check.typ(x)
// Embedding stand-alone type parameters is not permitted (go.dev/issue/47127).
// We don't need this restriction anymore if we make the underlying type of a type
// parameter its constraint interface: if we embed a lone type parameter, we will
// simply use its underlying type (like we do for other named, embedded interfaces),
// and since the underlying type is an interface the embedding is well defined.
if isTypeParam(typ) {
if tilde {
check.errorf(x, MisplacedTypeParam, "type in term %s cannot be a type parameter", tx)
} else {
check.error(x, MisplacedTypeParam, "term cannot be a type parameter")
}
typ = Typ[Invalid]
}
term := NewTerm(tilde, typ)
if tilde {
check.recordTypeAndValue(tx, typexpr, &Union{[]*Term{term}}, nil)
}
return term
}
// overlappingTerm reports the index of the term x in terms which is
// overlapping (not disjoint) from y. The result is < 0 if there is no
// such term. The type of term y must not be an interface, and terms
// with an interface type are ignored in the terms list.
func overlappingTerm(terms []*Term, y *Term) int {
assert(!IsInterface(y.typ))
for i, x := range terms {
if IsInterface(x.typ) {
continue
}
// disjoint requires non-nil, non-top arguments,
// and non-interface types as term types.
if debug {
if x == nil || x.typ == nil || y == nil || y.typ == nil {
panic("empty or top union term")
}
}
if !(*term)(x).disjoint((*term)(y)) {
return i
}
}
return -1
}
// flattenUnion walks a union type expression of the form A | B | C | ...,
// extracting both the binary exprs (blist) and leaf types (tlist).
func flattenUnion(list []ast.Expr, x ast.Expr) (blist, tlist []ast.Expr) {
if o, _ := x.(*ast.BinaryExpr); o != nil && o.Op == token.OR {
blist, tlist = flattenUnion(list, o.X)
blist = append(blist, o)
x = o.Y
}
return blist, append(tlist, x)
}
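// Illustrative sketch (not part of the original source): for the expression
// A | B | C, parsed as (A|B)|C, flattenUnion returns
//
//	blist = [A|B, (A|B)|C] // the binary OR nodes, innermost first
//	tlist = [A, B, C]      // the leaf type expressions
//
// preserving the parseUnion invariant len(blist) == len(tlist)-1.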
// Code generated by "go test -run=Generate -write=all"; DO NOT EDIT.
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file sets up the universe scope and the unsafe package.
package types
import (
"go/constant"
"strings"
)
// The Universe scope contains all predeclared objects of Go.
// It is the outermost scope of any chain of nested scopes.
var Universe *Scope
// The Unsafe package is the package returned by an importer
// for the import path "unsafe".
var Unsafe *Package
var (
universeIota Object
universeByte Type // uint8 alias, but has name "byte"
universeRune Type // int32 alias, but has name "rune"
universeAny Object
universeError Type
universeComparable Object
)
// Typ contains the predeclared *Basic types indexed by their
// corresponding BasicKind.
//
// The *Basic type for Typ[Byte] will have the name "uint8".
// Use Universe.Lookup("byte").Type() to obtain the specific
// alias basic type named "byte" (and analogous for "rune").
var Typ = []*Basic{
Invalid: {Invalid, 0, "invalid type"},
Bool: {Bool, IsBoolean, "bool"},
Int: {Int, IsInteger, "int"},
Int8: {Int8, IsInteger, "int8"},
Int16: {Int16, IsInteger, "int16"},
Int32: {Int32, IsInteger, "int32"},
Int64: {Int64, IsInteger, "int64"},
Uint: {Uint, IsInteger | IsUnsigned, "uint"},
Uint8: {Uint8, IsInteger | IsUnsigned, "uint8"},
Uint16: {Uint16, IsInteger | IsUnsigned, "uint16"},
Uint32: {Uint32, IsInteger | IsUnsigned, "uint32"},
Uint64: {Uint64, IsInteger | IsUnsigned, "uint64"},
Uintptr: {Uintptr, IsInteger | IsUnsigned, "uintptr"},
Float32: {Float32, IsFloat, "float32"},
Float64: {Float64, IsFloat, "float64"},
Complex64: {Complex64, IsComplex, "complex64"},
Complex128: {Complex128, IsComplex, "complex128"},
String: {String, IsString, "string"},
UnsafePointer: {UnsafePointer, 0, "Pointer"},
UntypedBool: {UntypedBool, IsBoolean | IsUntyped, "untyped bool"},
UntypedInt: {UntypedInt, IsInteger | IsUntyped, "untyped int"},
UntypedRune: {UntypedRune, IsInteger | IsUntyped, "untyped rune"},
UntypedFloat: {UntypedFloat, IsFloat | IsUntyped, "untyped float"},
UntypedComplex: {UntypedComplex, IsComplex | IsUntyped, "untyped complex"},
UntypedString: {UntypedString, IsString | IsUntyped, "untyped string"},
UntypedNil: {UntypedNil, IsUntyped, "untyped nil"},
}
var aliases = [...]*Basic{
{Byte, IsInteger | IsUnsigned, "byte"},
{Rune, IsInteger, "rune"},
}
func defPredeclaredTypes() {
for _, t := range Typ {
def(NewTypeName(nopos, nil, t.name, t))
}
for _, t := range aliases {
def(NewTypeName(nopos, nil, t.name, t))
}
// type any = interface{}
// Note: don't use &emptyInterface for the type of any. Using a unique
// pointer allows us to detect any and format it as "any" rather than
// interface{}, which clarifies user-facing error messages significantly.
def(NewTypeName(nopos, nil, "any", &Interface{complete: true, tset: &topTypeSet}))
// type error interface{ Error() string }
{
obj := NewTypeName(nopos, nil, "error", nil)
obj.setColor(black)
typ := NewNamed(obj, nil, nil)
// error.Error() string
recv := NewVar(nopos, nil, "", typ)
res := NewVar(nopos, nil, "", Typ[String])
sig := NewSignatureType(recv, nil, nil, nil, NewTuple(res), false)
err := NewFunc(nopos, nil, "Error", sig)
// interface{ Error() string }
ityp := &Interface{methods: []*Func{err}, complete: true}
computeInterfaceTypeSet(nil, nopos, ityp) // prevent races due to lazy computation of tset
typ.SetUnderlying(ityp)
def(obj)
}
// type comparable interface{} // marked as comparable
{
obj := NewTypeName(nopos, nil, "comparable", nil)
obj.setColor(black)
typ := NewNamed(obj, nil, nil)
// interface{} // marked as comparable
ityp := &Interface{complete: true, tset: &_TypeSet{nil, allTermlist, true}}
typ.SetUnderlying(ityp)
def(obj)
}
}
var predeclaredConsts = [...]struct {
name string
kind BasicKind
val constant.Value
}{
{"true", UntypedBool, constant.MakeBool(true)},
{"false", UntypedBool, constant.MakeBool(false)},
{"iota", UntypedInt, constant.MakeInt64(0)},
}
func defPredeclaredConsts() {
for _, c := range predeclaredConsts {
def(NewConst(nopos, nil, c.name, Typ[c.kind], c.val))
}
}
func defPredeclaredNil() {
def(&Nil{object{name: "nil", typ: Typ[UntypedNil], color_: black}})
}
// A builtinId is the id of a builtin function.
type builtinId int
const (
// universe scope
_Append builtinId = iota
_Cap
_Clear
_Close
_Complex
_Copy
_Delete
_Imag
_Len
_Make
_New
_Panic
_Print
_Println
_Real
_Recover
// package unsafe
_Add
_Alignof
_Offsetof
_Sizeof
_Slice
_SliceData
_String
_StringData
// testing support
_Assert
_Trace
)
var predeclaredFuncs = [...]struct {
name string
nargs int
variadic bool
kind exprKind
}{
_Append: {"append", 1, true, expression},
_Cap: {"cap", 1, false, expression},
_Clear: {"clear", 1, false, statement},
_Close: {"close", 1, false, statement},
_Complex: {"complex", 2, false, expression},
_Copy: {"copy", 2, false, statement},
_Delete: {"delete", 2, false, statement},
_Imag: {"imag", 1, false, expression},
_Len: {"len", 1, false, expression},
_Make: {"make", 1, true, expression},
_New: {"new", 1, false, expression},
_Panic: {"panic", 1, false, statement},
_Print: {"print", 0, true, statement},
_Println: {"println", 0, true, statement},
_Real: {"real", 1, false, expression},
_Recover: {"recover", 0, false, statement},
_Add: {"Add", 2, false, expression},
_Alignof: {"Alignof", 1, false, expression},
_Offsetof: {"Offsetof", 1, false, expression},
_Sizeof: {"Sizeof", 1, false, expression},
_Slice: {"Slice", 2, false, expression},
_SliceData: {"SliceData", 1, false, expression},
_String: {"String", 2, false, expression},
_StringData: {"StringData", 1, false, expression},
_Assert: {"assert", 1, false, statement},
_Trace: {"trace", 0, true, statement},
}
func defPredeclaredFuncs() {
for i := range predeclaredFuncs {
id := builtinId(i)
if id == _Assert || id == _Trace {
continue // only define these in testing environment
}
def(newBuiltin(id))
}
}
// DefPredeclaredTestFuncs defines the assert and trace built-ins.
// These built-ins are intended for debugging and testing of this
// package only.
func DefPredeclaredTestFuncs() {
if Universe.Lookup("assert") != nil {
return // already defined
}
def(newBuiltin(_Assert))
def(newBuiltin(_Trace))
}
func init() {
Universe = NewScope(nil, nopos, nopos, "universe")
Unsafe = NewPackage("unsafe", "unsafe")
Unsafe.complete = true
defPredeclaredTypes()
defPredeclaredConsts()
defPredeclaredNil()
defPredeclaredFuncs()
universeIota = Universe.Lookup("iota")
universeByte = Universe.Lookup("byte").Type()
universeRune = Universe.Lookup("rune").Type()
universeAny = Universe.Lookup("any")
universeError = Universe.Lookup("error").Type()
universeComparable = Universe.Lookup("comparable")
}
// Objects with names containing blanks are internal and not entered into
// a scope. Objects with exported names are inserted in the unsafe package
// scope; other objects are inserted in the universe scope.
func def(obj Object) {
assert(obj.color() == black)
name := obj.Name()
if strings.Contains(name, " ") {
return // nothing to do
}
// fix Obj link for named types
if typ, _ := obj.Type().(*Named); typ != nil {
typ.obj = obj.(*TypeName)
}
// exported identifiers go into package unsafe
scope := Universe
if obj.Exported() {
scope = Unsafe.scope
// set Pkg field
switch obj := obj.(type) {
case *TypeName:
obj.pkg = Unsafe
case *Builtin:
obj.pkg = Unsafe
default:
unreachable()
}
}
if scope.Insert(obj) != nil {
panic("double declaration of predeclared identifier")
}
}
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file contains various functionality that is
// different between go/types and types2. Factoring
// out this code allows more of the rest of the code
// to be shared.
package types
import "go/token"
// cmpPos compares the positions p and q and returns a result r as follows:
//
// r < 0: p is before q
// r == 0: p and q are the same position (but may not be identical)
// r > 0: p is after q
//
// If p and q are in different files, p is before q if the filename
// of p sorts lexicographically before the filename of q.
func cmpPos(p, q token.Pos) int { return int(p - q) }
// Code generated by "go test -run=Generate -write=all"; DO NOT EDIT.
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package types
// validType verifies that the given type does not "expand" indefinitely
// producing a cycle in the type graph.
// (Cycles involving alias types, as in "type A = [10]A" are detected
// earlier, via the objDecl cycle detection mechanism.)
func (check *Checker) validType(typ *Named) {
check.validType0(typ, nil, nil)
}
// validType0 checks if the given type is valid. If typ is a type parameter
// its value is looked up in the type argument list of the instantiated
// (enclosing) type, if it exists. Otherwise the type parameter must be from
// an enclosing function and can be ignored.
// The nest list describes the stack (the "nest in memory") of types which
// contain (or embed in the case of interfaces) other types. For instance, a
// struct named S which contains a field of named type F contains (the memory
// of) F in S, leading to the nest S->F. If a type appears in its own nest
// (say S->F->S) we have an invalid recursive type. The path list is the full
// path of named types in a cycle; it is only needed for error reporting.
func (check *Checker) validType0(typ Type, nest, path []*Named) bool {
switch t := typ.(type) {
case nil:
// We should never see a nil type but be conservative and panic
// only in debug mode.
if debug {
panic("validType0(nil)")
}
case *Array:
return check.validType0(t.elem, nest, path)
case *Struct:
for _, f := range t.fields {
if !check.validType0(f.typ, nest, path) {
return false
}
}
case *Union:
for _, t := range t.terms {
if !check.validType0(t.typ, nest, path) {
return false
}
}
case *Interface:
for _, etyp := range t.embeddeds {
if !check.validType0(etyp, nest, path) {
return false
}
}
case *Named:
// Exit early if we already know t is valid.
// This is purely an optimization but it prevents excessive computation
// times in pathological cases such as testdata/fixedbugs/issue6977.go.
// (Note: The valids map could also be allocated locally, once for each
// validType call.)
if check.valids.lookup(t) != nil {
break
}
// Don't report a 2nd error if we already know the type is invalid
// (e.g., if a cycle was detected earlier, via under).
// Note: ensure that t.orig is fully resolved by calling Underlying().
if t.Underlying() == Typ[Invalid] {
return false
}
// If the current type t is also found in nest, (the memory of) t is
// embedded in itself, indicating an invalid recursive type.
for _, e := range nest {
if Identical(e, t) {
// We have a cycle. If t != t.Origin() then t is an instance of
// the generic type t.Origin(). Because t is in the nest, t must
// occur within the definition (RHS) of the generic type t.Origin(),
// directly or indirectly, after expansion of the RHS.
// Therefore t.Origin() must be invalid, no matter how it is
// instantiated since the instantiation t of t.Origin() happens
// inside t.Origin()'s RHS and thus is always the same and always
// present.
// Therefore we can mark the underlying of both t and t.Origin()
// as invalid. If t is not an instance of a generic type, t and
// t.Origin() are the same.
// Furthermore, because we check all types in a package for validity
// before type checking is complete, any exported type that is invalid
// will have an invalid underlying type and we can't reach here with
// such a type (invalid types are excluded above).
// Thus, if we reach here with a type t, both t and t.Origin() (if
// different in the first place) must be from the current package;
// they cannot have been imported.
// Therefore it is safe to change their underlying types; there is
// no chance for a race condition (the types of the current package
// are not yet available to other goroutines).
assert(t.obj.pkg == check.pkg)
assert(t.Origin().obj.pkg == check.pkg)
t.underlying = Typ[Invalid]
t.Origin().underlying = Typ[Invalid]
// Find the starting point of the cycle and report it.
// Because each type in nest must also appear in path (see invariant below),
// type t must be in path since it was found in nest. But not every type in path
// is in nest. Specifically t may appear in path with an earlier index than the
// index of t in nest. Search again.
for start, p := range path {
if Identical(p, t) {
check.cycleError(makeObjList(path[start:]))
return false
}
}
panic("cycle start not found")
}
}
// No cycle was found. Check the RHS of t.
// Every type added to nest is also added to path; thus every type that is in nest
// must also be in path (invariant). But not every type in path is in nest, since
// nest may be pruned (see below, *TypeParam case).
if !check.validType0(t.Origin().fromRHS, append(nest, t), append(path, t)) {
return false
}
check.valids.add(t) // t is valid
case *TypeParam:
// A type parameter stands for the type (argument) it was instantiated with.
// Check the corresponding type argument for validity if we are in an
// instantiated type.
if len(nest) > 0 {
inst := nest[len(nest)-1] // the type instance
// Find the corresponding type argument for the type parameter
// and proceed with checking that type argument.
for i, tparam := range inst.TypeParams().list() {
// The type parameter and type argument lists should
// match in length but be careful in case of errors.
if t == tparam && i < inst.TypeArgs().Len() {
targ := inst.TypeArgs().At(i)
// The type argument must be valid in the enclosing
// type (where inst was instantiated), hence we must
// check targ's validity in the type nest excluding
// the current (instantiated) type (see the example
// at the end of this file).
// For error reporting we keep the full path.
return check.validType0(targ, nest[:len(nest)-1], path)
}
}
}
}
return true
}
// makeObjList returns the list of type name objects for the given
// list of named types.
func makeObjList(tlist []*Named) []Object {
olist := make([]Object, len(tlist))
for i, t := range tlist {
olist[i] = t.obj
}
return olist
}
// Here is an example illustrating why we need to exclude the
// instantiated type from nest when evaluating the validity of
// a type parameter. Given the declarations
//
// var _ A[A[string]]
//
// type A[P any] struct { _ B[P] }
// type B[P any] struct { _ P }
//
// we want to determine if the type A[A[string]] is valid.
// We start evaluating A[A[string]] outside any type nest:
//
// A[A[string]]
// nest =
// path =
//
// The RHS of A is now evaluated in the A[A[string]] nest:
//
// struct{_ B[P₁]}
// nest = A[A[string]]
// path = A[A[string]]
//
// The struct has a single field of type B[P₁] with which
// we continue:
//
// B[P₁]
// nest = A[A[string]]
// path = A[A[string]]
//
// struct{_ P₂}
// nest = A[A[string]]->B[P]
// path = A[A[string]]->B[P]
//
// Eventually we reach the type parameter P of type B (P₂):
//
// P₂
// nest = A[A[string]]->B[P]
// path = A[A[string]]->B[P]
//
// The type argument for P of B is the type parameter P of A (P₁).
// It must be evaluated in the type nest that existed when B was
// instantiated:
//
// P₁
// nest = A[A[string]] <== type nest at B's instantiation time
// path = A[A[string]]->B[P]
//
// Had we used the current nest, it would correspond to the path,
// which would be wrong as we will see shortly. P's type argument
// is A[string], which again must be evaluated in the type nest
// that existed when A was instantiated with A[string]. That type
// nest is empty:
//
// A[string]
// nest = <== type nest at A's instantiation time
// path = A[A[string]]->B[P]
//
// Evaluation then proceeds as before for A[string]:
//
// struct{_ B[P₁]}
// nest = A[string]
// path = A[A[string]]->B[P]->A[string]
//
// Now we reach B[P] again. If we had not adjusted nest, it would
// correspond to path, and we would find B[P] in nest, indicating
// a cycle, which would clearly be wrong since there's no cycle in
// A[string]:
//
// B[P₁]
// nest = A[string]
// path = A[A[string]]->B[P]->A[string] <== path contains B[P]!
//
// But because we use the correct type nest, evaluation proceeds without
// errors and we get the evaluation sequence:
//
// struct{_ P₂}
// nest = A[string]->B[P]
// path = A[A[string]]->B[P]->A[string]->B[P]
// P₂
// nest = A[string]->B[P]
// path = A[A[string]]->B[P]->A[string]->B[P]
// P₁
// nest = A[string]
// path = A[A[string]]->B[P]->A[string]->B[P]
// string
// nest =
// path = A[A[string]]->B[P]->A[string]->B[P]
//
// At this point we're done and A[A[string]] is valid.
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package types
import (
"fmt"
"go/ast"
"go/token"
"internal/lazyregexp"
. "internal/types/errors"
"strconv"
"strings"
)
// langCompat reports an error if the representation of a numeric
// literal is not compatible with the current language version.
func (check *Checker) langCompat(lit *ast.BasicLit) {
s := lit.Value
if len(s) <= 2 || check.allowVersion(check.pkg, 1, 13) {
return
}
// len(s) > 2
if strings.Contains(s, "_") {
check.error(lit, UnsupportedFeature, "underscores in numeric literals requires go1.13 or later")
return
}
if s[0] != '0' {
return
}
radix := s[1]
if radix == 'b' || radix == 'B' {
check.error(lit, UnsupportedFeature, "binary literals requires go1.13 or later")
return
}
if radix == 'o' || radix == 'O' {
check.error(lit, UnsupportedFeature, "0o/0O-style octal literals requires go1.13 or later")
return
}
if lit.Kind != token.INT && (radix == 'x' || radix == 'X') {
check.error(lit, UnsupportedFeature, "hexadecimal floating-point literals requires go1.13 or later")
}
}
// allowVersion reports whether the given package
// is allowed to use version major.minor.
func (check *Checker) allowVersion(pkg *Package, major, minor int) bool {
// We assume that imported packages have all been checked,
// so we only have to check for the local package.
if pkg != check.pkg {
return true
}
ma, mi := check.version.major, check.version.minor
return ma == 0 && mi == 0 || ma > major || ma == major && mi >= minor
}
type version struct {
major, minor int
}
// parseGoVersion parses a Go version string (such as "go1.12")
// and returns the version, or an error. If s is the empty
// string, the version is 0.0.
func parseGoVersion(s string) (v version, err error) {
if s == "" {
return
}
matches := goVersionRx.FindStringSubmatch(s)
if matches == nil {
err = fmt.Errorf(`should be something like "go1.12"`)
return
}
v.major, err = strconv.Atoi(matches[1])
if err != nil {
return
}
v.minor, err = strconv.Atoi(matches[2])
return
}
// goVersionRx matches a Go version string, e.g. "go1.12".
var goVersionRx = lazyregexp.New(`^go([1-9]\d*)\.(0|[1-9]\d*)$`)
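// A brief in-package sketch of how these pieces behave (illustrative
// only; parseGoVersion and version are unexported):
//
//	v, err := parseGoVersion("go1.21")
//	// v == version{major: 1, minor: 21}, err == nil
//	_, err = parseGoVersion("1.21")
//	// err != nil: the string must look like "go1.12"
//	v, _ = parseGoVersion("")
//	// v == version{0, 0}; allowVersion treats 0.0 as "no restriction"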
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package adler32 implements the Adler-32 checksum.
//
// It is defined in RFC 1950:
//
// Adler-32 is composed of two sums accumulated per byte: s1 is
// the sum of all bytes, s2 is the sum of all s1 values. Both sums
// are done modulo 65521. s1 is initialized to 1, s2 to zero. The
// Adler-32 checksum is stored as s2*65536 + s1 in most-
// significant-byte first (network) order.
package adler32
import (
"errors"
"hash"
)
const (
// mod is the largest prime that is less than 65536.
mod = 65521
// nmax is the largest n such that
// 255 * n * (n+1) / 2 + (n+1) * (mod-1) <= 2^32-1.
// It is mentioned in RFC 1950 (search for "5552").
nmax = 5552
)
// The size of an Adler-32 checksum in bytes.
const Size = 4
// digest represents the partial evaluation of a checksum.
// The low 16 bits are s1, the high 16 bits are s2.
type digest uint32
func (d *digest) Reset() { *d = 1 }
// New returns a new hash.Hash32 computing the Adler-32 checksum. Its
// Sum method will lay the value out in big-endian byte order. The
// returned Hash32 also implements encoding.BinaryMarshaler and
// encoding.BinaryUnmarshaler to marshal and unmarshal the internal
// state of the hash.
func New() hash.Hash32 {
d := new(digest)
d.Reset()
return d
}
func (d *digest) Size() int { return Size }
func (d *digest) BlockSize() int { return 4 }
const (
magic = "adl\x01"
marshaledSize = len(magic) + 4
)
func (d *digest) MarshalBinary() ([]byte, error) {
b := make([]byte, 0, marshaledSize)
b = append(b, magic...)
b = appendUint32(b, uint32(*d))
return b, nil
}
func (d *digest) UnmarshalBinary(b []byte) error {
if len(b) < len(magic) || string(b[:len(magic)]) != magic {
return errors.New("hash/adler32: invalid hash state identifier")
}
if len(b) != marshaledSize {
return errors.New("hash/adler32: invalid hash state size")
}
*d = digest(readUint32(b[len(magic):]))
return nil
}
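// A minimal sketch of saving and restoring hash state through the
// binary marshaling interfaces (illustrative only; h1, h2, and state
// are hypothetical names):
//
//	h1 := New()
//	h1.Write([]byte("hello, "))
//	state, _ := h1.(encoding.BinaryMarshaler).MarshalBinary()
//
//	h2 := New()
//	h2.(encoding.BinaryUnmarshaler).UnmarshalBinary(state)
//	h2.Write([]byte("world"))
//	// h2.Sum32() == Checksum([]byte("hello, world"))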
func appendUint32(b []byte, x uint32) []byte {
a := [4]byte{
byte(x >> 24),
byte(x >> 16),
byte(x >> 8),
byte(x),
}
return append(b, a[:]...)
}
func readUint32(b []byte) uint32 {
_ = b[3]
return uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24
}
// Add p to the running checksum d.
func update(d digest, p []byte) digest {
s1, s2 := uint32(d&0xffff), uint32(d>>16)
for len(p) > 0 {
var q []byte
if len(p) > nmax {
p, q = p[:nmax], p[nmax:]
}
for len(p) >= 4 {
s1 += uint32(p[0])
s2 += s1
s1 += uint32(p[1])
s2 += s1
s1 += uint32(p[2])
s2 += s1
s1 += uint32(p[3])
s2 += s1
p = p[4:]
}
for _, x := range p {
s1 += uint32(x)
s2 += s1
}
s1 %= mod
s2 %= mod
p = q
}
return digest(s2<<16 | s1)
}
func (d *digest) Write(p []byte) (nn int, err error) {
*d = update(*d, p)
return len(p), nil
}
func (d *digest) Sum32() uint32 { return uint32(*d) }
func (d *digest) Sum(in []byte) []byte {
s := uint32(*d)
return append(in, byte(s>>24), byte(s>>16), byte(s>>8), byte(s))
}
// Checksum returns the Adler-32 checksum of data.
func Checksum(data []byte) uint32 { return uint32(update(1, data)) }
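// A naive reference implementation of the RFC 1950 definition quoted in
// the package comment, handy for spec-checking Checksum (a sketch, not
// part of this file; it reduces modulo mod on every byte and is slow):
//
//	func naiveChecksum(data []byte) uint32 {
//		s1, s2 := uint32(1), uint32(0)
//		for _, b := range data {
//			s1 = (s1 + uint32(b)) % mod
//			s2 = (s2 + s1) % mod
//		}
//		return s2<<16 | s1
//	}
//
// For any input, naiveChecksum(data) == Checksum(data); update merely
// defers the modulo reductions for up to nmax bytes at a time.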
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package crc32 implements the 32-bit cyclic redundancy check, or CRC-32,
// checksum. See https://en.wikipedia.org/wiki/Cyclic_redundancy_check for
// information.
//
// Polynomials are represented in LSB-first form, also known as reversed representation.
//
// See https://en.wikipedia.org/wiki/Mathematics_of_cyclic_redundancy_checks#Reversed_representations_and_reciprocal_polynomials
// for information.
package crc32
import (
"errors"
"hash"
"sync"
"sync/atomic"
)
// The size of a CRC-32 checksum in bytes.
const Size = 4
// Predefined polynomials.
const (
// IEEE is by far and away the most common CRC-32 polynomial.
// Used by ethernet (IEEE 802.3), v.42, fddi, gzip, zip, png, ...
IEEE = 0xedb88320
// Castagnoli's polynomial, used in iSCSI.
// Has better error detection characteristics than IEEE.
// https://dx.doi.org/10.1109/26.231911
Castagnoli = 0x82f63b78
// Koopman's polynomial.
// Also has better error detection characteristics than IEEE.
// https://dx.doi.org/10.1109/DSN.2002.1028931
Koopman = 0xeb31d82e
)
// Table is a 256-word table representing the polynomial for efficient processing.
type Table [256]uint32
// This file makes use of functions implemented in architecture-specific files.
// The interface that they implement is as follows:
//
// // archAvailableIEEE reports whether an architecture-specific CRC32-IEEE
// // algorithm is available.
// archAvailableIEEE() bool
//
// // archInitIEEE initializes the architecture-specific CRC32-IEEE algorithm.
// // It can only be called if archAvailableIEEE() returns true.
// archInitIEEE()
//
// // archUpdateIEEE updates the given CRC32-IEEE. It can only be called if
// // archInitIEEE() was previously called.
// archUpdateIEEE(crc uint32, p []byte) uint32
//
// // archAvailableCastagnoli reports whether an architecture-specific
// // CRC32-C algorithm is available.
// archAvailableCastagnoli() bool
//
// // archInitCastagnoli initializes the architecture-specific CRC32-C
// // algorithm. It can only be called if archAvailableCastagnoli() returns
// // true.
// archInitCastagnoli()
//
// // archUpdateCastagnoli updates the given CRC32-C. It can only be called
// // if archInitCastagnoli() was previously called.
// archUpdateCastagnoli(crc uint32, p []byte) uint32
// castagnoliTable points to a lazily initialized Table for the Castagnoli
// polynomial. MakeTable will always return this value when asked to make a
// Castagnoli table so we can compare against it to find when the caller is
// using this polynomial.
var castagnoliTable *Table
var castagnoliTable8 *slicing8Table
var updateCastagnoli func(crc uint32, p []byte) uint32
var castagnoliOnce sync.Once
var haveCastagnoli atomic.Bool
func castagnoliInit() {
castagnoliTable = simpleMakeTable(Castagnoli)
if archAvailableCastagnoli() {
archInitCastagnoli()
updateCastagnoli = archUpdateCastagnoli
} else {
// Initialize the slicing-by-8 table.
castagnoliTable8 = slicingMakeTable(Castagnoli)
updateCastagnoli = func(crc uint32, p []byte) uint32 {
return slicingUpdate(crc, castagnoliTable8, p)
}
}
haveCastagnoli.Store(true)
}
// IEEETable is the table for the IEEE polynomial.
var IEEETable = simpleMakeTable(IEEE)
// ieeeTable8 is the slicing8Table for IEEE
var ieeeTable8 *slicing8Table
var updateIEEE func(crc uint32, p []byte) uint32
var ieeeOnce sync.Once
func ieeeInit() {
if archAvailableIEEE() {
archInitIEEE()
updateIEEE = archUpdateIEEE
} else {
// Initialize the slicing-by-8 table.
ieeeTable8 = slicingMakeTable(IEEE)
updateIEEE = func(crc uint32, p []byte) uint32 {
return slicingUpdate(crc, ieeeTable8, p)
}
}
}
// MakeTable returns a Table constructed from the specified polynomial.
// The contents of this Table must not be modified.
func MakeTable(poly uint32) *Table {
switch poly {
case IEEE:
ieeeOnce.Do(ieeeInit)
return IEEETable
case Castagnoli:
castagnoliOnce.Do(castagnoliInit)
return castagnoliTable
default:
return simpleMakeTable(poly)
}
}
// digest represents the partial evaluation of a checksum.
type digest struct {
crc uint32
tab *Table
}
// New creates a new hash.Hash32 computing the CRC-32 checksum using the
// polynomial represented by the Table. Its Sum method will lay the
// value out in big-endian byte order. The returned Hash32 also
// implements encoding.BinaryMarshaler and encoding.BinaryUnmarshaler to
// marshal and unmarshal the internal state of the hash.
func New(tab *Table) hash.Hash32 {
if tab == IEEETable {
ieeeOnce.Do(ieeeInit)
}
return &digest{0, tab}
}
// NewIEEE creates a new hash.Hash32 computing the CRC-32 checksum using
// the IEEE polynomial. Its Sum method will lay the value out in
// big-endian byte order. The returned Hash32 also implements
// encoding.BinaryMarshaler and encoding.BinaryUnmarshaler to marshal
// and unmarshal the internal state of the hash.
func NewIEEE() hash.Hash32 { return New(IEEETable) }
func (d *digest) Size() int { return Size }
func (d *digest) BlockSize() int { return 1 }
func (d *digest) Reset() { d.crc = 0 }
const (
magic = "crc\x01"
marshaledSize = len(magic) + 4 + 4
)
func (d *digest) MarshalBinary() ([]byte, error) {
b := make([]byte, 0, marshaledSize)
b = append(b, magic...)
b = appendUint32(b, tableSum(d.tab))
b = appendUint32(b, d.crc)
return b, nil
}
func (d *digest) UnmarshalBinary(b []byte) error {
if len(b) < len(magic) || string(b[:len(magic)]) != magic {
return errors.New("hash/crc32: invalid hash state identifier")
}
if len(b) != marshaledSize {
return errors.New("hash/crc32: invalid hash state size")
}
if tableSum(d.tab) != readUint32(b[4:]) {
return errors.New("hash/crc32: tables do not match")
}
d.crc = readUint32(b[8:])
return nil
}
func appendUint32(b []byte, x uint32) []byte {
a := [4]byte{
byte(x >> 24),
byte(x >> 16),
byte(x >> 8),
byte(x),
}
return append(b, a[:]...)
}
func readUint32(b []byte) uint32 {
_ = b[3]
return uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24
}
func update(crc uint32, tab *Table, p []byte, checkInitIEEE bool) uint32 {
switch {
case haveCastagnoli.Load() && tab == castagnoliTable:
return updateCastagnoli(crc, p)
case tab == IEEETable:
if checkInitIEEE {
ieeeOnce.Do(ieeeInit)
}
return updateIEEE(crc, p)
default:
return simpleUpdate(crc, tab, p)
}
}
// Update returns the result of adding the bytes in p to the crc.
func Update(crc uint32, tab *Table, p []byte) uint32 {
// Unfortunately, because IEEETable is exported, IEEE may be used without a
// call to MakeTable. We have to make sure it gets initialized in that case.
return update(crc, tab, p, true)
}
func (d *digest) Write(p []byte) (n int, err error) {
// We only create digest objects through New() which takes care of
// initialization in this case.
d.crc = update(d.crc, d.tab, p, false)
return len(p), nil
}
func (d *digest) Sum32() uint32 { return d.crc }
func (d *digest) Sum(in []byte) []byte {
s := d.Sum32()
return append(in, byte(s>>24), byte(s>>16), byte(s>>8), byte(s))
}
// Checksum returns the CRC-32 checksum of data
// using the polynomial represented by the Table.
func Checksum(data []byte, tab *Table) uint32 { return Update(0, tab, data) }
// ChecksumIEEE returns the CRC-32 checksum of data
// using the IEEE polynomial.
func ChecksumIEEE(data []byte) uint32 {
ieeeOnce.Do(ieeeInit)
return updateIEEE(0, data)
}
// tableSum returns the IEEE checksum of table t.
func tableSum(t *Table) uint32 {
var a [1024]byte
b := a[:0]
if t != nil {
for _, x := range t {
b = appendUint32(b, x)
}
}
return ChecksumIEEE(b)
}
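// A minimal usage sketch of the exported API (illustrative only):
//
//	data := []byte("hello, world")
//	sum := ChecksumIEEE(data) // one-shot, IEEE polynomial
//	tab := MakeTable(Castagnoli)
//	crc := Update(0, tab, data[:5]) // streaming, in two chunks
//	crc = Update(crc, tab, data[5:])
//	// crc == Checksum(data, tab); sum is the IEEE equivalent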
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// AMD64-specific hardware-assisted CRC32 algorithms. See crc32.go for a
// description of the interface that each architecture-specific file
// implements.
package crc32
import (
"internal/cpu"
"unsafe"
)
// This file contains the code to call the SSE 4.2 version of the Castagnoli
// and IEEE CRC.
// castagnoliSSE42 is defined in crc32_amd64.s and uses the SSE 4.2 CRC32
// instruction.
//
//go:noescape
func castagnoliSSE42(crc uint32, p []byte) uint32
// castagnoliSSE42Triple is defined in crc32_amd64.s and uses the SSE 4.2 CRC32
// instruction.
//
//go:noescape
func castagnoliSSE42Triple(
crcA, crcB, crcC uint32,
a, b, c []byte,
rounds uint32,
) (retA uint32, retB uint32, retC uint32)
// ieeeCLMUL is defined in crc32_amd64.s and uses the PCLMULQDQ
// instruction as well as SSE 4.1.
//
//go:noescape
func ieeeCLMUL(crc uint32, p []byte) uint32
const castagnoliK1 = 168
const castagnoliK2 = 1344
type sse42Table [4]Table
var castagnoliSSE42TableK1 *sse42Table
var castagnoliSSE42TableK2 *sse42Table
func archAvailableCastagnoli() bool {
return cpu.X86.HasSSE42
}
func archInitCastagnoli() {
if !cpu.X86.HasSSE42 {
panic("arch-specific Castagnoli not available")
}
castagnoliSSE42TableK1 = new(sse42Table)
castagnoliSSE42TableK2 = new(sse42Table)
// See description in updateCastagnoli.
// t[0][i] = CRC(i000, O)
// t[1][i] = CRC(0i00, O)
// t[2][i] = CRC(00i0, O)
// t[3][i] = CRC(000i, O)
// where O is a sequence of K zeros.
var tmp [castagnoliK2]byte
for b := 0; b < 4; b++ {
for i := 0; i < 256; i++ {
val := uint32(i) << uint32(b*8)
castagnoliSSE42TableK1[b][i] = castagnoliSSE42(val, tmp[:castagnoliK1])
castagnoliSSE42TableK2[b][i] = castagnoliSSE42(val, tmp[:])
}
}
}
// castagnoliShift computes the CRC32-C of K1 or K2 zeroes (depending on the
// table given) with the given initial crc value. This corresponds to
// CRC(crc, O) in the description in updateCastagnoli.
func castagnoliShift(table *sse42Table, crc uint32) uint32 {
return table[3][crc>>24] ^
table[2][(crc>>16)&0xFF] ^
table[1][(crc>>8)&0xFF] ^
table[0][crc&0xFF]
}
func archUpdateCastagnoli(crc uint32, p []byte) uint32 {
if !cpu.X86.HasSSE42 {
panic("not available")
}
// This method is inspired from the algorithm in Intel's white paper:
// "Fast CRC Computation for iSCSI Polynomial Using CRC32 Instruction"
// The same strategy of splitting the buffer in three is used but the
// combining calculation is different; the complete derivation is explained
// below.
//
// -- The basic idea --
//
// The CRC32 instruction (available in SSE4.2) can process 8 bytes at a
// time. In recent Intel architectures the instruction takes 3 cycles;
// however the processor can pipeline up to three instructions if they
// don't depend on each other.
//
// Roughly this means that we can process three buffers in about the same
// time we can process one buffer.
//
// The idea is then to split the buffer in three, CRC the three pieces
// separately and then combine the results.
//
// Combining the results requires precomputed tables, so we must choose a
// fixed buffer length to optimize. The longer the length, the faster; but
// only buffers longer than this length will use the optimization. We choose
// two cutoffs and compute tables for both:
// - one around 512: 168*3=504
// - one around 4KB: 1344*3=4032
//
// -- The nitty gritty --
//
// Let CRC(I, X) be the non-inverted CRC32-C of the sequence X (with
// initial non-inverted CRC I). This function has the following properties:
// (a) CRC(I, AB) = CRC(CRC(I, A), B)
// (b) CRC(I, A xor B) = CRC(I, A) xor CRC(0, B)
//
// Say we want to compute CRC(I, ABC) where A, B, C are three sequences of
// K bytes each, where K is a fixed constant. Let O be the sequence of K zero
// bytes.
//
// CRC(I, ABC) = CRC(I, ABO xor C)
// = CRC(I, ABO) xor CRC(0, C)
// = CRC(CRC(I, AB), O) xor CRC(0, C)
// = CRC(CRC(I, AO xor B), O) xor CRC(0, C)
// = CRC(CRC(I, AO) xor CRC(0, B), O) xor CRC(0, C)
// = CRC(CRC(CRC(I, A), O) xor CRC(0, B), O) xor CRC(0, C)
//
// The castagnoliSSE42Triple function can compute CRC(I, A), CRC(0, B),
// and CRC(0, C) efficiently. We just need to find a way to quickly compute
// CRC(uvwx, O) given a 4-byte initial value uvwx. We can precompute these
// values; since we can't have a 32-bit table, we break it up into four
// 8-bit tables:
//
// CRC(uvwx, O) = CRC(u000, O) xor
// CRC(0v00, O) xor
// CRC(00w0, O) xor
// CRC(000x, O)
//
// We can compute tables corresponding to the four terms for all 8-bit
// values.
crc = ^crc
// If a buffer is long enough to use the optimization, process the first few
// bytes to align the buffer to an 8-byte boundary (if necessary).
if len(p) >= castagnoliK1*3 {
delta := int(uintptr(unsafe.Pointer(&p[0])) & 7)
if delta != 0 {
delta = 8 - delta
crc = castagnoliSSE42(crc, p[:delta])
p = p[delta:]
}
}
// Process 3*K2 at a time.
for len(p) >= castagnoliK2*3 {
// Compute CRC(I, A), CRC(0, B), and CRC(0, C).
crcA, crcB, crcC := castagnoliSSE42Triple(
crc, 0, 0,
p, p[castagnoliK2:], p[castagnoliK2*2:],
castagnoliK2/24)
// CRC(I, AB) = CRC(CRC(I, A), O) xor CRC(0, B)
crcAB := castagnoliShift(castagnoliSSE42TableK2, crcA) ^ crcB
// CRC(I, ABC) = CRC(CRC(I, AB), O) xor CRC(0, C)
crc = castagnoliShift(castagnoliSSE42TableK2, crcAB) ^ crcC
p = p[castagnoliK2*3:]
}
// Process 3*K1 at a time.
for len(p) >= castagnoliK1*3 {
// Compute CRC(I, A), CRC(0, B), and CRC(0, C).
crcA, crcB, crcC := castagnoliSSE42Triple(
crc, 0, 0,
p, p[castagnoliK1:], p[castagnoliK1*2:],
castagnoliK1/24)
// CRC(I, AB) = CRC(CRC(I, A), O) xor CRC(0, B)
crcAB := castagnoliShift(castagnoliSSE42TableK1, crcA) ^ crcB
// CRC(I, ABC) = CRC(CRC(I, AB), O) xor CRC(0, C)
crc = castagnoliShift(castagnoliSSE42TableK1, crcAB) ^ crcC
p = p[castagnoliK1*3:]
}
// Use the simple implementation for what's left.
crc = castagnoliSSE42(crc, p)
return ^crc
}
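// The splitting strategy above ultimately rests on property (a):
// CRC(I, AB) = CRC(CRC(I, A), B). In terms of this package's exported
// API the same property reads (illustrative sketch):
//
//	tab := MakeTable(Castagnoli)
//	a, b := []byte("first half "), []byte("second half")
//	whole := Checksum(append(append([]byte{}, a...), b...), tab)
//	piecewise := Update(Update(0, tab, a), tab, b)
//	// whole == piecewise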
func archAvailableIEEE() bool {
return cpu.X86.HasPCLMULQDQ && cpu.X86.HasSSE41
}
var archIeeeTable8 *slicing8Table
func archInitIEEE() {
if !cpu.X86.HasPCLMULQDQ || !cpu.X86.HasSSE41 {
panic("not available")
}
// We still use slicing-by-8 for small buffers.
archIeeeTable8 = slicingMakeTable(IEEE)
}
func archUpdateIEEE(crc uint32, p []byte) uint32 {
if !cpu.X86.HasPCLMULQDQ || !cpu.X86.HasSSE41 {
panic("not available")
}
if len(p) >= 64 {
left := len(p) & 15
do := len(p) - left
crc = ^ieeeCLMUL(^crc, p[:do])
p = p[do:]
}
if len(p) == 0 {
return crc
}
return slicingUpdate(crc, archIeeeTable8, p)
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file contains CRC32 algorithms that are not specific to any architecture
// and don't use hardware acceleration.
//
// The simple (and slow) CRC32 implementation only uses a 256*4 bytes table.
//
// The slicing-by-8 algorithm is a faster implementation that uses a bigger
// table (8*256*4 bytes).
package crc32
// simpleMakeTable allocates and constructs a Table for the specified
// polynomial. The table is suitable for use with the simple algorithm
// (simpleUpdate).
func simpleMakeTable(poly uint32) *Table {
t := new(Table)
simplePopulateTable(poly, t)
return t
}
// simplePopulateTable constructs a Table for the specified polynomial, suitable
// for use with simpleUpdate.
func simplePopulateTable(poly uint32, t *Table) {
for i := 0; i < 256; i++ {
crc := uint32(i)
for j := 0; j < 8; j++ {
if crc&1 == 1 {
crc = (crc >> 1) ^ poly
} else {
crc >>= 1
}
}
t[i] = crc
}
}
// simpleUpdate uses the simple algorithm to update the CRC, given a table that
// was previously computed using simpleMakeTable.
func simpleUpdate(crc uint32, tab *Table, p []byte) uint32 {
crc = ^crc
for _, v := range p {
crc = tab[byte(crc)^v] ^ (crc >> 8)
}
return ^crc
}
// Use slicing-by-8 when payload >= this value.
const slicing8Cutoff = 16
// slicing8Table is an array of 8 Tables, used by the slicing-by-8 algorithm.
type slicing8Table [8]Table
// slicingMakeTable constructs a slicing8Table for the specified polynomial. The
// table is suitable for use with the slicing-by-8 algorithm (slicingUpdate).
func slicingMakeTable(poly uint32) *slicing8Table {
t := new(slicing8Table)
simplePopulateTable(poly, &t[0])
for i := 0; i < 256; i++ {
crc := t[0][i]
for j := 1; j < 8; j++ {
crc = t[0][crc&0xFF] ^ (crc >> 8)
t[j][i] = crc
}
}
return t
}
// slicingUpdate uses the slicing-by-8 algorithm to update the CRC, given a
// table that was previously computed using slicingMakeTable.
func slicingUpdate(crc uint32, tab *slicing8Table, p []byte) uint32 {
if len(p) >= slicing8Cutoff {
crc = ^crc
for len(p) > 8 {
crc ^= uint32(p[0]) | uint32(p[1])<<8 | uint32(p[2])<<16 | uint32(p[3])<<24
crc = tab[0][p[7]] ^ tab[1][p[6]] ^ tab[2][p[5]] ^ tab[3][p[4]] ^
tab[4][crc>>24] ^ tab[5][(crc>>16)&0xFF] ^
tab[6][(crc>>8)&0xFF] ^ tab[7][crc&0xFF]
p = p[8:]
}
crc = ^crc
}
if len(p) == 0 {
return crc
}
return simpleUpdate(crc, &tab[0], p)
}
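// An in-package sanity sketch (illustrative only; these functions are
// unexported): both algorithms compute the same CRC for any input, the
// slicing-by-8 variant just consumes 8 bytes per round of table lookups.
//
//	tab := simpleMakeTable(IEEE)
//	tab8 := slicingMakeTable(IEEE)
//	p := []byte("any input, long or short")
//	// simpleUpdate(0, tab, p) == slicingUpdate(0, tab8, p)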
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package crc64 implements the 64-bit cyclic redundancy check, or CRC-64,
// checksum. See https://en.wikipedia.org/wiki/Cyclic_redundancy_check for
// information.
package crc64
import (
"errors"
"hash"
"sync"
)
// The size of a CRC-64 checksum in bytes.
const Size = 8
// Predefined polynomials.
const (
// The ISO polynomial, defined in ISO 3309 and used in HDLC.
ISO = 0xD800000000000000
// The ECMA polynomial, defined in ECMA 182.
ECMA = 0xC96C5795D7870F42
)
// Table is a 256-word table representing the polynomial for efficient processing.
type Table [256]uint64
var (
slicing8TablesBuildOnce sync.Once
slicing8TableISO *[8]Table
slicing8TableECMA *[8]Table
)
func buildSlicing8TablesOnce() {
slicing8TablesBuildOnce.Do(buildSlicing8Tables)
}
func buildSlicing8Tables() {
slicing8TableISO = makeSlicingBy8Table(makeTable(ISO))
slicing8TableECMA = makeSlicingBy8Table(makeTable(ECMA))
}
// MakeTable returns a Table constructed from the specified polynomial.
// The contents of this Table must not be modified.
func MakeTable(poly uint64) *Table {
buildSlicing8TablesOnce()
switch poly {
case ISO:
return &slicing8TableISO[0]
case ECMA:
return &slicing8TableECMA[0]
default:
return makeTable(poly)
}
}
func makeTable(poly uint64) *Table {
t := new(Table)
for i := 0; i < 256; i++ {
crc := uint64(i)
for j := 0; j < 8; j++ {
if crc&1 == 1 {
crc = (crc >> 1) ^ poly
} else {
crc >>= 1
}
}
t[i] = crc
}
return t
}
func makeSlicingBy8Table(t *Table) *[8]Table {
var helperTable [8]Table
helperTable[0] = *t
for i := 0; i < 256; i++ {
crc := t[i]
for j := 1; j < 8; j++ {
crc = t[crc&0xff] ^ (crc >> 8)
helperTable[j][i] = crc
}
}
return &helperTable
}
// digest represents the partial evaluation of a checksum.
type digest struct {
crc uint64
tab *Table
}
// New creates a new hash.Hash64 computing the CRC-64 checksum using the
// polynomial represented by the Table. Its Sum method will lay the
// value out in big-endian byte order. The returned Hash64 also
// implements encoding.BinaryMarshaler and encoding.BinaryUnmarshaler to
// marshal and unmarshal the internal state of the hash.
func New(tab *Table) hash.Hash64 { return &digest{0, tab} }
func (d *digest) Size() int { return Size }
func (d *digest) BlockSize() int { return 1 }
func (d *digest) Reset() { d.crc = 0 }
const (
magic = "crc\x02"
marshaledSize = len(magic) + 8 + 8
)
func (d *digest) MarshalBinary() ([]byte, error) {
b := make([]byte, 0, marshaledSize)
b = append(b, magic...)
b = appendUint64(b, tableSum(d.tab))
b = appendUint64(b, d.crc)
return b, nil
}
func (d *digest) UnmarshalBinary(b []byte) error {
if len(b) < len(magic) || string(b[:len(magic)]) != magic {
return errors.New("hash/crc64: invalid hash state identifier")
}
if len(b) != marshaledSize {
return errors.New("hash/crc64: invalid hash state size")
}
if tableSum(d.tab) != readUint64(b[4:]) {
return errors.New("hash/crc64: tables do not match")
}
d.crc = readUint64(b[12:])
return nil
}
func appendUint64(b []byte, x uint64) []byte {
a := [8]byte{
byte(x >> 56),
byte(x >> 48),
byte(x >> 40),
byte(x >> 32),
byte(x >> 24),
byte(x >> 16),
byte(x >> 8),
byte(x),
}
return append(b, a[:]...)
}
func readUint64(b []byte) uint64 {
_ = b[7]
return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
}
func update(crc uint64, tab *Table, p []byte) uint64 {
buildSlicing8TablesOnce()
crc = ^crc
// Table comparison is somewhat expensive, so avoid it for small sizes
for len(p) >= 64 {
var helperTable *[8]Table
if *tab == slicing8TableECMA[0] {
helperTable = slicing8TableECMA
} else if *tab == slicing8TableISO[0] {
helperTable = slicing8TableISO
// For smaller sizes, creating the extended table takes too much time
} else if len(p) >= 2048 {
// According to tests across various x86 and ARM CPUs, 2k is a reasonable
// threshold for now. This may change in the future.
helperTable = makeSlicingBy8Table(tab)
} else {
break
}
// Update using slicing-by-8
for len(p) > 8 {
crc ^= uint64(p[0]) | uint64(p[1])<<8 | uint64(p[2])<<16 | uint64(p[3])<<24 |
uint64(p[4])<<32 | uint64(p[5])<<40 | uint64(p[6])<<48 | uint64(p[7])<<56
crc = helperTable[7][crc&0xff] ^
helperTable[6][(crc>>8)&0xff] ^
helperTable[5][(crc>>16)&0xff] ^
helperTable[4][(crc>>24)&0xff] ^
helperTable[3][(crc>>32)&0xff] ^
helperTable[2][(crc>>40)&0xff] ^
helperTable[1][(crc>>48)&0xff] ^
helperTable[0][crc>>56]
p = p[8:]
}
}
// For remainders or small sizes
for _, v := range p {
crc = tab[byte(crc)^v] ^ (crc >> 8)
}
return ^crc
}
// Update returns the result of adding the bytes in p to the crc.
func Update(crc uint64, tab *Table, p []byte) uint64 {
return update(crc, tab, p)
}
func (d *digest) Write(p []byte) (n int, err error) {
d.crc = update(d.crc, d.tab, p)
return len(p), nil
}
func (d *digest) Sum64() uint64 { return d.crc }
func (d *digest) Sum(in []byte) []byte {
s := d.Sum64()
return append(in, byte(s>>56), byte(s>>48), byte(s>>40), byte(s>>32), byte(s>>24), byte(s>>16), byte(s>>8), byte(s))
}
// Checksum returns the CRC-64 checksum of data
// using the polynomial represented by the Table.
func Checksum(data []byte, tab *Table) uint64 { return update(0, tab, data) }
// tableSum returns the ISO checksum of table t.
func tableSum(t *Table) uint64 {
var a [2048]byte
b := a[:0]
if t != nil {
for _, x := range t {
b = appendUint64(b, x)
}
}
return Checksum(b, MakeTable(ISO))
}
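// A minimal usage sketch of the exported API (illustrative only):
//
//	tab := MakeTable(ECMA)
//	data := []byte("hello, world")
//	sum := Checksum(data, tab)
//	crc := Update(0, tab, data[:5]) // equivalent streaming form
//	crc = Update(crc, tab, data[5:])
//	// crc == sum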
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package fnv implements FNV-1 and FNV-1a, non-cryptographic hash functions
// created by Glenn Fowler, Landon Curt Noll, and Phong Vo.
// See
// https://en.wikipedia.org/wiki/Fowler-Noll-Vo_hash_function.
//
// All the hash.Hash implementations returned by this package also
// implement encoding.BinaryMarshaler and encoding.BinaryUnmarshaler to
// marshal and unmarshal the internal state of the hash.
package fnv
import (
"errors"
"hash"
"math/bits"
)
type (
sum32 uint32
sum32a uint32
sum64 uint64
sum64a uint64
sum128 [2]uint64
sum128a [2]uint64
)
const (
offset32 = 2166136261
offset64 = 14695981039346656037
offset128Lower = 0x62b821756295c58d
offset128Higher = 0x6c62272e07bb0142
prime32 = 16777619
prime64 = 1099511628211
prime128Lower = 0x13b
prime128Shift = 24
)
// New32 returns a new 32-bit FNV-1 hash.Hash.
// Its Sum method will lay the value out in big-endian byte order.
func New32() hash.Hash32 {
var s sum32 = offset32
return &s
}
// New32a returns a new 32-bit FNV-1a hash.Hash.
// Its Sum method will lay the value out in big-endian byte order.
func New32a() hash.Hash32 {
var s sum32a = offset32
return &s
}
// New64 returns a new 64-bit FNV-1 hash.Hash.
// Its Sum method will lay the value out in big-endian byte order.
func New64() hash.Hash64 {
var s sum64 = offset64
return &s
}
// New64a returns a new 64-bit FNV-1a hash.Hash.
// Its Sum method will lay the value out in big-endian byte order.
func New64a() hash.Hash64 {
var s sum64a = offset64
return &s
}
// New128 returns a new 128-bit FNV-1 hash.Hash.
// Its Sum method will lay the value out in big-endian byte order.
func New128() hash.Hash {
var s sum128
s[0] = offset128Higher
s[1] = offset128Lower
return &s
}
// New128a returns a new 128-bit FNV-1a hash.Hash.
// Its Sum method will lay the value out in big-endian byte order.
func New128a() hash.Hash {
var s sum128a
s[0] = offset128Higher
s[1] = offset128Lower
return &s
}
func (s *sum32) Reset() { *s = offset32 }
func (s *sum32a) Reset() { *s = offset32 }
func (s *sum64) Reset() { *s = offset64 }
func (s *sum64a) Reset() { *s = offset64 }
func (s *sum128) Reset() { s[0] = offset128Higher; s[1] = offset128Lower }
func (s *sum128a) Reset() { s[0] = offset128Higher; s[1] = offset128Lower }
func (s *sum32) Sum32() uint32 { return uint32(*s) }
func (s *sum32a) Sum32() uint32 { return uint32(*s) }
func (s *sum64) Sum64() uint64 { return uint64(*s) }
func (s *sum64a) Sum64() uint64 { return uint64(*s) }
func (s *sum32) Write(data []byte) (int, error) {
hash := *s
for _, c := range data {
hash *= prime32
hash ^= sum32(c)
}
*s = hash
return len(data), nil
}
func (s *sum32a) Write(data []byte) (int, error) {
hash := *s
for _, c := range data {
hash ^= sum32a(c)
hash *= prime32
}
*s = hash
return len(data), nil
}
func (s *sum64) Write(data []byte) (int, error) {
hash := *s
for _, c := range data {
hash *= prime64
hash ^= sum64(c)
}
*s = hash
return len(data), nil
}
func (s *sum64a) Write(data []byte) (int, error) {
hash := *s
for _, c := range data {
hash ^= sum64a(c)
hash *= prime64
}
*s = hash
return len(data), nil
}
func (s *sum128) Write(data []byte) (int, error) {
for _, c := range data {
// Compute the multiplication
s0, s1 := bits.Mul64(prime128Lower, s[1])
s0 += s[1]<<prime128Shift + prime128Lower*s[0]
// Update the values
s[1] = s1
s[0] = s0
s[1] ^= uint64(c)
}
return len(data), nil
}
func (s *sum128a) Write(data []byte) (int, error) {
for _, c := range data {
s[1] ^= uint64(c)
// Compute the multiplication
s0, s1 := bits.Mul64(prime128Lower, s[1])
s0 += s[1]<<prime128Shift + prime128Lower*s[0]
// Update the values
s[1] = s1
s[0] = s0
}
return len(data), nil
}
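// The Write methods above differ only in the order of the multiply and
// xor steps. A one-byte trace makes the FNV-1 vs. FNV-1a difference
// concrete (illustrative sketch; arithmetic is modulo 2^32):
//
//	h := New32()   // FNV-1
//	ha := New32a() // FNV-1a
//	h.Write([]byte{'x'})
//	ha.Write([]byte{'x'})
//	// h.Sum32()  == (offset32*prime32) ^ 'x'
//	// ha.Sum32() == (offset32^'x') * prime32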
func (s *sum32) Size() int { return 4 }
func (s *sum32a) Size() int { return 4 }
func (s *sum64) Size() int { return 8 }
func (s *sum64a) Size() int { return 8 }
func (s *sum128) Size() int { return 16 }
func (s *sum128a) Size() int { return 16 }
func (s *sum32) BlockSize() int { return 1 }
func (s *sum32a) BlockSize() int { return 1 }
func (s *sum64) BlockSize() int { return 1 }
func (s *sum64a) BlockSize() int { return 1 }
func (s *sum128) BlockSize() int { return 1 }
func (s *sum128a) BlockSize() int { return 1 }
func (s *sum32) Sum(in []byte) []byte {
v := uint32(*s)
return append(in, byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
}
func (s *sum32a) Sum(in []byte) []byte {
v := uint32(*s)
return append(in, byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
}
func (s *sum64) Sum(in []byte) []byte {
v := uint64(*s)
return append(in, byte(v>>56), byte(v>>48), byte(v>>40), byte(v>>32), byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
}
func (s *sum64a) Sum(in []byte) []byte {
v := uint64(*s)
return append(in, byte(v>>56), byte(v>>48), byte(v>>40), byte(v>>32), byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
}
func (s *sum128) Sum(in []byte) []byte {
return append(in,
byte(s[0]>>56), byte(s[0]>>48), byte(s[0]>>40), byte(s[0]>>32), byte(s[0]>>24), byte(s[0]>>16), byte(s[0]>>8), byte(s[0]),
byte(s[1]>>56), byte(s[1]>>48), byte(s[1]>>40), byte(s[1]>>32), byte(s[1]>>24), byte(s[1]>>16), byte(s[1]>>8), byte(s[1]),
)
}
func (s *sum128a) Sum(in []byte) []byte {
return append(in,
byte(s[0]>>56), byte(s[0]>>48), byte(s[0]>>40), byte(s[0]>>32), byte(s[0]>>24), byte(s[0]>>16), byte(s[0]>>8), byte(s[0]),
byte(s[1]>>56), byte(s[1]>>48), byte(s[1]>>40), byte(s[1]>>32), byte(s[1]>>24), byte(s[1]>>16), byte(s[1]>>8), byte(s[1]),
)
}
const (
magic32 = "fnv\x01"
magic32a = "fnv\x02"
magic64 = "fnv\x03"
magic64a = "fnv\x04"
magic128 = "fnv\x05"
magic128a = "fnv\x06"
marshaledSize32 = len(magic32) + 4
marshaledSize64 = len(magic64) + 8
marshaledSize128 = len(magic128) + 8*2
)
func (s *sum32) MarshalBinary() ([]byte, error) {
b := make([]byte, 0, marshaledSize32)
b = append(b, magic32...)
b = appendUint32(b, uint32(*s))
return b, nil
}
func (s *sum32a) MarshalBinary() ([]byte, error) {
b := make([]byte, 0, marshaledSize32)
b = append(b, magic32a...)
b = appendUint32(b, uint32(*s))
return b, nil
}
func (s *sum64) MarshalBinary() ([]byte, error) {
b := make([]byte, 0, marshaledSize64)
b = append(b, magic64...)
b = appendUint64(b, uint64(*s))
return b, nil
}
func (s *sum64a) MarshalBinary() ([]byte, error) {
b := make([]byte, 0, marshaledSize64)
b = append(b, magic64a...)
b = appendUint64(b, uint64(*s))
return b, nil
}
func (s *sum128) MarshalBinary() ([]byte, error) {
b := make([]byte, 0, marshaledSize128)
b = append(b, magic128...)
b = appendUint64(b, s[0])
b = appendUint64(b, s[1])
return b, nil
}
func (s *sum128a) MarshalBinary() ([]byte, error) {
b := make([]byte, 0, marshaledSize128)
b = append(b, magic128a...)
b = appendUint64(b, s[0])
b = appendUint64(b, s[1])
return b, nil
}
func (s *sum32) UnmarshalBinary(b []byte) error {
if len(b) < len(magic32) || string(b[:len(magic32)]) != magic32 {
return errors.New("hash/fnv: invalid hash state identifier")
}
if len(b) != marshaledSize32 {
return errors.New("hash/fnv: invalid hash state size")
}
*s = sum32(readUint32(b[4:]))
return nil
}
func (s *sum32a) UnmarshalBinary(b []byte) error {
if len(b) < len(magic32a) || string(b[:len(magic32a)]) != magic32a {
return errors.New("hash/fnv: invalid hash state identifier")
}
if len(b) != marshaledSize32 {
return errors.New("hash/fnv: invalid hash state size")
}
*s = sum32a(readUint32(b[4:]))
return nil
}
func (s *sum64) UnmarshalBinary(b []byte) error {
if len(b) < len(magic64) || string(b[:len(magic64)]) != magic64 {
return errors.New("hash/fnv: invalid hash state identifier")
}
if len(b) != marshaledSize64 {
return errors.New("hash/fnv: invalid hash state size")
}
*s = sum64(readUint64(b[4:]))
return nil
}
func (s *sum64a) UnmarshalBinary(b []byte) error {
if len(b) < len(magic64a) || string(b[:len(magic64a)]) != magic64a {
return errors.New("hash/fnv: invalid hash state identifier")
}
if len(b) != marshaledSize64 {
return errors.New("hash/fnv: invalid hash state size")
}
*s = sum64a(readUint64(b[4:]))
return nil
}
func (s *sum128) UnmarshalBinary(b []byte) error {
if len(b) < len(magic128) || string(b[:len(magic128)]) != magic128 {
return errors.New("hash/fnv: invalid hash state identifier")
}
if len(b) != marshaledSize128 {
return errors.New("hash/fnv: invalid hash state size")
}
s[0] = readUint64(b[4:])
s[1] = readUint64(b[12:])
return nil
}
func (s *sum128a) UnmarshalBinary(b []byte) error {
if len(b) < len(magic128a) || string(b[:len(magic128a)]) != magic128a {
return errors.New("hash/fnv: invalid hash state identifier")
}
if len(b) != marshaledSize128 {
return errors.New("hash/fnv: invalid hash state size")
}
s[0] = readUint64(b[4:])
s[1] = readUint64(b[12:])
return nil
}
func readUint32(b []byte) uint32 {
_ = b[3]
return uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24
}
func appendUint32(b []byte, x uint32) []byte {
a := [4]byte{
byte(x >> 24),
byte(x >> 16),
byte(x >> 8),
byte(x),
}
return append(b, a[:]...)
}
func appendUint64(b []byte, x uint64) []byte {
a := [8]byte{
byte(x >> 56),
byte(x >> 48),
byte(x >> 40),
byte(x >> 32),
byte(x >> 24),
byte(x >> 16),
byte(x >> 8),
byte(x),
}
return append(b, a[:]...)
}
func readUint64(b []byte) uint64 {
_ = b[7]
return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
}
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package maphash provides hash functions on byte sequences.
// These hash functions are intended to be used to implement hash tables or
// other data structures that need to map arbitrary strings or byte
// sequences to a uniform distribution on unsigned 64-bit integers.
// Each different instance of a hash table or data structure should use its own Seed.
//
// The hash functions are not cryptographically secure.
// (See crypto/sha256 and crypto/sha512 for cryptographic use.)
package maphash
// A Seed is a random value that selects the specific hash function
// computed by a Hash. If two Hashes use the same Seeds, they
// will compute the same hash values for any given input.
// If two Hashes use different Seeds, they are very likely to compute
// distinct hash values for any given input.
//
// A Seed must be initialized by calling MakeSeed.
// The zero seed is uninitialized and not valid for use with Hash's SetSeed method.
//
// Each Seed value is local to a single process and cannot be serialized
// or otherwise recreated in a different process.
type Seed struct {
s uint64
}
// Bytes returns the hash of b with the given seed.
//
// Bytes is equivalent to, but more convenient and efficient than:
//
// var h Hash
// h.SetSeed(seed)
// h.Write(b)
// return h.Sum64()
func Bytes(seed Seed, b []byte) uint64 {
state := seed.s
if state == 0 {
panic("maphash: use of uninitialized Seed")
}
if len(b) > bufSize {
b = b[:len(b):len(b)] // merge len and cap calculations when reslicing
for len(b) > bufSize {
state = rthash(b[:bufSize], state)
b = b[bufSize:]
}
}
return rthash(b, state)
}
// String returns the hash of s with the given seed.
//
// String is equivalent to, but more convenient and efficient than:
//
// var h Hash
// h.SetSeed(seed)
// h.WriteString(s)
// return h.Sum64()
func String(seed Seed, s string) uint64 {
state := seed.s
if state == 0 {
panic("maphash: use of uninitialized Seed")
}
for len(s) > bufSize {
state = rthashString(s[:bufSize], state)
s = s[bufSize:]
}
return rthashString(s, state)
}
// A Hash computes a seeded hash of a byte sequence.
//
// The zero Hash is a valid Hash ready to use.
// A zero Hash chooses a random seed for itself during
// the first call to a Reset, Write, Seed, or Sum64 method.
// For control over the seed, use SetSeed.
//
// The computed hash values depend only on the initial seed and
// the sequence of bytes provided to the Hash object, not on the way
// in which the bytes are provided. For example, the three sequences
//
// h.Write([]byte{'f','o','o'})
// h.WriteByte('f'); h.WriteByte('o'); h.WriteByte('o')
// h.WriteString("foo")
//
// all have the same effect.
//
// Hashes are intended to be collision-resistant, even for situations
// where an adversary controls the byte sequences being hashed.
//
// A Hash is not safe for concurrent use by multiple goroutines, but a Seed is.
// If multiple goroutines must compute the same seeded hash,
// each can declare its own Hash and call SetSeed with a common Seed.
type Hash struct {
_ [0]func() // not comparable
seed Seed // initial seed used for this hash
state Seed // current hash of all flushed bytes
buf [bufSize]byte // unflushed byte buffer
n int // number of unflushed bytes
}
// bufSize is the size of the Hash write buffer.
// The buffer ensures that writes depend only on the sequence of bytes,
// not the sequence of WriteByte/Write/WriteString calls,
// by always calling rthash with a full buffer (except for the tail).
const bufSize = 128
// initSeed seeds the hash if necessary.
// initSeed is called lazily before any operation that actually uses h.seed/h.state.
// Note that this does not include Write/WriteByte/WriteString in the case
// where they only add to h.buf. (If they write too much, they call h.flush,
// which does call h.initSeed.)
func (h *Hash) initSeed() {
if h.seed.s == 0 {
seed := MakeSeed()
h.seed = seed
h.state = seed
}
}
// WriteByte adds b to the sequence of bytes hashed by h.
// It never fails; the error result is for implementing io.ByteWriter.
func (h *Hash) WriteByte(b byte) error {
if h.n == len(h.buf) {
h.flush()
}
h.buf[h.n] = b
h.n++
return nil
}
// Write adds b to the sequence of bytes hashed by h.
// It always writes all of b and never fails; the count and error result are for implementing io.Writer.
func (h *Hash) Write(b []byte) (int, error) {
size := len(b)
// Deal with bytes left over in h.buf.
// h.n <= bufSize is always true.
// Checking it is ~free and it lets the compiler eliminate a bounds check.
if h.n > 0 && h.n <= bufSize {
k := copy(h.buf[h.n:], b)
h.n += k
if h.n < bufSize {
// Copied the entirety of b to h.buf.
return size, nil
}
b = b[k:]
h.flush()
// No need to set h.n = 0 here; it happens just before exit.
}
// Process as many full buffers as possible, without copying, calling initSeed only once.
if len(b) > bufSize {
h.initSeed()
for len(b) > bufSize {
h.state.s = rthash(b[:bufSize], h.state.s)
b = b[bufSize:]
}
}
// Copy the tail.
copy(h.buf[:], b)
h.n = len(b)
return size, nil
}
// WriteString adds the bytes of s to the sequence of bytes hashed by h.
// It always writes all of s and never fails; the count and error result are for implementing io.StringWriter.
func (h *Hash) WriteString(s string) (int, error) {
// WriteString mirrors Write. See Write for comments.
size := len(s)
if h.n > 0 && h.n <= bufSize {
k := copy(h.buf[h.n:], s)
h.n += k
if h.n < bufSize {
return size, nil
}
s = s[k:]
h.flush()
}
if len(s) > bufSize {
h.initSeed()
for len(s) > bufSize {
h.state.s = rthashString(s[:bufSize], h.state.s)
s = s[bufSize:]
}
}
copy(h.buf[:], s)
h.n = len(s)
return size, nil
}
// Seed returns h's seed value.
func (h *Hash) Seed() Seed {
h.initSeed()
return h.seed
}
// SetSeed sets h to use seed, which must have been returned by MakeSeed
// or by another Hash's Seed method.
// Two Hash objects with the same seed behave identically.
// Two Hash objects with different seeds will very likely behave differently.
// Any bytes added to h before this call will be discarded.
func (h *Hash) SetSeed(seed Seed) {
if seed.s == 0 {
panic("maphash: use of uninitialized Seed")
}
h.seed = seed
h.state = seed
h.n = 0
}
// Reset discards all bytes added to h.
// (The seed remains the same.)
func (h *Hash) Reset() {
h.initSeed()
h.state = h.seed
h.n = 0
}
// precondition: buffer is full.
func (h *Hash) flush() {
if h.n != len(h.buf) {
panic("maphash: flush of partially full buffer")
}
h.initSeed()
h.state.s = rthash(h.buf[:h.n], h.state.s)
h.n = 0
}
// Sum64 returns h's current 64-bit value, which depends on
// h's seed and the sequence of bytes added to h since the
// last call to Reset or SetSeed.
//
// All bits of the Sum64 result are close to uniformly and
// independently distributed, so it can be safely reduced
// by using bit masking, shifting, or modular arithmetic.
func (h *Hash) Sum64() uint64 {
h.initSeed()
return rthash(h.buf[:h.n], h.state.s)
}
// MakeSeed returns a new random seed.
func MakeSeed() Seed {
var s uint64
for {
s = randUint64()
// We use seed 0 to indicate an uninitialized seed/hash,
// so keep trying until we get a non-zero seed.
if s != 0 {
break
}
}
return Seed{s: s}
}
// Sum appends the hash's current 64-bit value to b.
// It exists for implementing hash.Hash.
// For direct calls, it is more efficient to use Sum64.
func (h *Hash) Sum(b []byte) []byte {
x := h.Sum64()
return append(b,
byte(x>>0),
byte(x>>8),
byte(x>>16),
byte(x>>24),
byte(x>>32),
byte(x>>40),
byte(x>>48),
byte(x>>56))
}
// Size returns h's hash value size, 8 bytes.
func (h *Hash) Size() int { return 8 }
// BlockSize returns h's block size.
func (h *Hash) BlockSize() int { return len(h.buf) }
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !purego
package maphash
import (
"unsafe"
)
//go:linkname runtime_fastrand64 runtime.fastrand64
func runtime_fastrand64() uint64
//go:linkname runtime_memhash runtime.memhash
//go:noescape
func runtime_memhash(p unsafe.Pointer, seed, s uintptr) uintptr
func rthash(buf []byte, seed uint64) uint64 {
if len(buf) == 0 {
return seed
}
len := len(buf)
// The runtime hasher only works on uintptr. For 64-bit
// architectures, we use the hasher directly. Otherwise,
// we use two parallel hashers on the lower and upper 32 bits.
if unsafe.Sizeof(uintptr(0)) == 8 {
return uint64(runtime_memhash(unsafe.Pointer(&buf[0]), uintptr(seed), uintptr(len)))
}
lo := runtime_memhash(unsafe.Pointer(&buf[0]), uintptr(seed), uintptr(len))
hi := runtime_memhash(unsafe.Pointer(&buf[0]), uintptr(seed>>32), uintptr(len))
return uint64(hi)<<32 | uint64(lo)
}
func rthashString(s string, state uint64) uint64 {
buf := unsafe.Slice(unsafe.StringData(s), len(s))
return rthash(buf, state)
}
func randUint64() uint64 {
return runtime_fastrand64()
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package html
import "sync"
// All entities that do not end with ';' are 6 or fewer bytes long.
const longestEntityWithoutSemicolon = 6
// entity is a map from HTML entity names to their values. The semicolon matters:
// https://html.spec.whatwg.org/multipage/named-characters.html
// lists both "amp" and "amp;" as two separate entries.
//
// Note that the HTML5 list is larger than the HTML4 list at
// http://www.w3.org/TR/html4/sgml/entities.html
var entity map[string]rune
// HTML entities that are two Unicode code points.
var entity2 map[string][2]rune
// populateMapsOnce guards calling populateMaps.
var populateMapsOnce sync.Once
// populateMaps populates entity and entity2.
func populateMaps() {
entity = map[string]rune{
"AElig;": '\U000000C6',
"AMP;": '\U00000026',
"Aacute;": '\U000000C1',
"Abreve;": '\U00000102',
"Acirc;": '\U000000C2',
"Acy;": '\U00000410',
"Afr;": '\U0001D504',
"Agrave;": '\U000000C0',
"Alpha;": '\U00000391',
"Amacr;": '\U00000100',
"And;": '\U00002A53',
"Aogon;": '\U00000104',
"Aopf;": '\U0001D538',
"ApplyFunction;": '\U00002061',
"Aring;": '\U000000C5',
"Ascr;": '\U0001D49C',
"Assign;": '\U00002254',
"Atilde;": '\U000000C3',
"Auml;": '\U000000C4',
"Backslash;": '\U00002216',
"Barv;": '\U00002AE7',
"Barwed;": '\U00002306',
"Bcy;": '\U00000411',
"Because;": '\U00002235',
"Bernoullis;": '\U0000212C',
"Beta;": '\U00000392',
"Bfr;": '\U0001D505',
"Bopf;": '\U0001D539',
"Breve;": '\U000002D8',
"Bscr;": '\U0000212C',
"Bumpeq;": '\U0000224E',
"CHcy;": '\U00000427',
"COPY;": '\U000000A9',
"Cacute;": '\U00000106',
"Cap;": '\U000022D2',
"CapitalDifferentialD;": '\U00002145',
"Cayleys;": '\U0000212D',
"Ccaron;": '\U0000010C',
"Ccedil;": '\U000000C7',
"Ccirc;": '\U00000108',
"Cconint;": '\U00002230',
"Cdot;": '\U0000010A',
"Cedilla;": '\U000000B8',
"CenterDot;": '\U000000B7',
"Cfr;": '\U0000212D',
"Chi;": '\U000003A7',
"CircleDot;": '\U00002299',
"CircleMinus;": '\U00002296',
"CirclePlus;": '\U00002295',
"CircleTimes;": '\U00002297',
"ClockwiseContourIntegral;": '\U00002232',
"CloseCurlyDoubleQuote;": '\U0000201D',
"CloseCurlyQuote;": '\U00002019',
"Colon;": '\U00002237',
"Colone;": '\U00002A74',
"Congruent;": '\U00002261',
"Conint;": '\U0000222F',
"ContourIntegral;": '\U0000222E',
"Copf;": '\U00002102',
"Coproduct;": '\U00002210',
"CounterClockwiseContourIntegral;": '\U00002233',
"Cross;": '\U00002A2F',
"Cscr;": '\U0001D49E',
"Cup;": '\U000022D3',
"CupCap;": '\U0000224D',
"DD;": '\U00002145',
"DDotrahd;": '\U00002911',
"DJcy;": '\U00000402',
"DScy;": '\U00000405',
"DZcy;": '\U0000040F',
"Dagger;": '\U00002021',
"Darr;": '\U000021A1',
"Dashv;": '\U00002AE4',
"Dcaron;": '\U0000010E',
"Dcy;": '\U00000414',
"Del;": '\U00002207',
"Delta;": '\U00000394',
"Dfr;": '\U0001D507',
"DiacriticalAcute;": '\U000000B4',
"DiacriticalDot;": '\U000002D9',
"DiacriticalDoubleAcute;": '\U000002DD',
"DiacriticalGrave;": '\U00000060',
"DiacriticalTilde;": '\U000002DC',
"Diamond;": '\U000022C4',
"DifferentialD;": '\U00002146',
"Dopf;": '\U0001D53B',
"Dot;": '\U000000A8',
"DotDot;": '\U000020DC',
"DotEqual;": '\U00002250',
"DoubleContourIntegral;": '\U0000222F',
"DoubleDot;": '\U000000A8',
"DoubleDownArrow;": '\U000021D3',
"DoubleLeftArrow;": '\U000021D0',
"DoubleLeftRightArrow;": '\U000021D4',
"DoubleLeftTee;": '\U00002AE4',
"DoubleLongLeftArrow;": '\U000027F8',
"DoubleLongLeftRightArrow;": '\U000027FA',
"DoubleLongRightArrow;": '\U000027F9',
"DoubleRightArrow;": '\U000021D2',
"DoubleRightTee;": '\U000022A8',
"DoubleUpArrow;": '\U000021D1',
"DoubleUpDownArrow;": '\U000021D5',
"DoubleVerticalBar;": '\U00002225',
"DownArrow;": '\U00002193',
"DownArrowBar;": '\U00002913',
"DownArrowUpArrow;": '\U000021F5',
"DownBreve;": '\U00000311',
"DownLeftRightVector;": '\U00002950',
"DownLeftTeeVector;": '\U0000295E',
"DownLeftVector;": '\U000021BD',
"DownLeftVectorBar;": '\U00002956',
"DownRightTeeVector;": '\U0000295F',
"DownRightVector;": '\U000021C1',
"DownRightVectorBar;": '\U00002957',
"DownTee;": '\U000022A4',
"DownTeeArrow;": '\U000021A7',
"Downarrow;": '\U000021D3',
"Dscr;": '\U0001D49F',
"Dstrok;": '\U00000110',
"ENG;": '\U0000014A',
"ETH;": '\U000000D0',
"Eacute;": '\U000000C9',
"Ecaron;": '\U0000011A',
"Ecirc;": '\U000000CA',
"Ecy;": '\U0000042D',
"Edot;": '\U00000116',
"Efr;": '\U0001D508',
"Egrave;": '\U000000C8',
"Element;": '\U00002208',
"Emacr;": '\U00000112',
"EmptySmallSquare;": '\U000025FB',
"EmptyVerySmallSquare;": '\U000025AB',
"Eogon;": '\U00000118',
"Eopf;": '\U0001D53C',
"Epsilon;": '\U00000395',
"Equal;": '\U00002A75',
"EqualTilde;": '\U00002242',
"Equilibrium;": '\U000021CC',
"Escr;": '\U00002130',
"Esim;": '\U00002A73',
"Eta;": '\U00000397',
"Euml;": '\U000000CB',
"Exists;": '\U00002203',
"ExponentialE;": '\U00002147',
"Fcy;": '\U00000424',
"Ffr;": '\U0001D509',
"FilledSmallSquare;": '\U000025FC',
"FilledVerySmallSquare;": '\U000025AA',
"Fopf;": '\U0001D53D',
"ForAll;": '\U00002200',
"Fouriertrf;": '\U00002131',
"Fscr;": '\U00002131',
"GJcy;": '\U00000403',
"GT;": '\U0000003E',
"Gamma;": '\U00000393',
"Gammad;": '\U000003DC',
"Gbreve;": '\U0000011E',
"Gcedil;": '\U00000122',
"Gcirc;": '\U0000011C',
"Gcy;": '\U00000413',
"Gdot;": '\U00000120',
"Gfr;": '\U0001D50A',
"Gg;": '\U000022D9',
"Gopf;": '\U0001D53E',
"GreaterEqual;": '\U00002265',
"GreaterEqualLess;": '\U000022DB',
"GreaterFullEqual;": '\U00002267',
"GreaterGreater;": '\U00002AA2',
"GreaterLess;": '\U00002277',
"GreaterSlantEqual;": '\U00002A7E',
"GreaterTilde;": '\U00002273',
"Gscr;": '\U0001D4A2',
"Gt;": '\U0000226B',
"HARDcy;": '\U0000042A',
"Hacek;": '\U000002C7',
"Hat;": '\U0000005E',
"Hcirc;": '\U00000124',
"Hfr;": '\U0000210C',
"HilbertSpace;": '\U0000210B',
"Hopf;": '\U0000210D',
"HorizontalLine;": '\U00002500',
"Hscr;": '\U0000210B',
"Hstrok;": '\U00000126',
"HumpDownHump;": '\U0000224E',
"HumpEqual;": '\U0000224F',
"IEcy;": '\U00000415',
"IJlig;": '\U00000132',
"IOcy;": '\U00000401',
"Iacute;": '\U000000CD',
"Icirc;": '\U000000CE',
"Icy;": '\U00000418',
"Idot;": '\U00000130',
"Ifr;": '\U00002111',
"Igrave;": '\U000000CC',
"Im;": '\U00002111',
"Imacr;": '\U0000012A',
"ImaginaryI;": '\U00002148',
"Implies;": '\U000021D2',
"Int;": '\U0000222C',
"Integral;": '\U0000222B',
"Intersection;": '\U000022C2',
"InvisibleComma;": '\U00002063',
"InvisibleTimes;": '\U00002062',
"Iogon;": '\U0000012E',
"Iopf;": '\U0001D540',
"Iota;": '\U00000399',
"Iscr;": '\U00002110',
"Itilde;": '\U00000128',
"Iukcy;": '\U00000406',
"Iuml;": '\U000000CF',
"Jcirc;": '\U00000134',
"Jcy;": '\U00000419',
"Jfr;": '\U0001D50D',
"Jopf;": '\U0001D541',
"Jscr;": '\U0001D4A5',
"Jsercy;": '\U00000408',
"Jukcy;": '\U00000404',
"KHcy;": '\U00000425',
"KJcy;": '\U0000040C',
"Kappa;": '\U0000039A',
"Kcedil;": '\U00000136',
"Kcy;": '\U0000041A',
"Kfr;": '\U0001D50E',
"Kopf;": '\U0001D542',
"Kscr;": '\U0001D4A6',
"LJcy;": '\U00000409',
"LT;": '\U0000003C',
"Lacute;": '\U00000139',
"Lambda;": '\U0000039B',
"Lang;": '\U000027EA',
"Laplacetrf;": '\U00002112',
"Larr;": '\U0000219E',
"Lcaron;": '\U0000013D',
"Lcedil;": '\U0000013B',
"Lcy;": '\U0000041B',
"LeftAngleBracket;": '\U000027E8',
"LeftArrow;": '\U00002190',
"LeftArrowBar;": '\U000021E4',
"LeftArrowRightArrow;": '\U000021C6',
"LeftCeiling;": '\U00002308',
"LeftDoubleBracket;": '\U000027E6',
"LeftDownTeeVector;": '\U00002961',
"LeftDownVector;": '\U000021C3',
"LeftDownVectorBar;": '\U00002959',
"LeftFloor;": '\U0000230A',
"LeftRightArrow;": '\U00002194',
"LeftRightVector;": '\U0000294E',
"LeftTee;": '\U000022A3',
"LeftTeeArrow;": '\U000021A4',
"LeftTeeVector;": '\U0000295A',
"LeftTriangle;": '\U000022B2',
"LeftTriangleBar;": '\U000029CF',
"LeftTriangleEqual;": '\U000022B4',
"LeftUpDownVector;": '\U00002951',
"LeftUpTeeVector;": '\U00002960',
"LeftUpVector;": '\U000021BF',
"LeftUpVectorBar;": '\U00002958',
"LeftVector;": '\U000021BC',
"LeftVectorBar;": '\U00002952',
"Leftarrow;": '\U000021D0',
"Leftrightarrow;": '\U000021D4',
"LessEqualGreater;": '\U000022DA',
"LessFullEqual;": '\U00002266',
"LessGreater;": '\U00002276',
"LessLess;": '\U00002AA1',
"LessSlantEqual;": '\U00002A7D',
"LessTilde;": '\U00002272',
"Lfr;": '\U0001D50F',
"Ll;": '\U000022D8',
"Lleftarrow;": '\U000021DA',
"Lmidot;": '\U0000013F',
"LongLeftArrow;": '\U000027F5',
"LongLeftRightArrow;": '\U000027F7',
"LongRightArrow;": '\U000027F6',
"Longleftarrow;": '\U000027F8',
"Longleftrightarrow;": '\U000027FA',
"Longrightarrow;": '\U000027F9',
"Lopf;": '\U0001D543',
"LowerLeftArrow;": '\U00002199',
"LowerRightArrow;": '\U00002198',
"Lscr;": '\U00002112',
"Lsh;": '\U000021B0',
"Lstrok;": '\U00000141',
"Lt;": '\U0000226A',
"Map;": '\U00002905',
"Mcy;": '\U0000041C',
"MediumSpace;": '\U0000205F',
"Mellintrf;": '\U00002133',
"Mfr;": '\U0001D510',
"MinusPlus;": '\U00002213',
"Mopf;": '\U0001D544',
"Mscr;": '\U00002133',
"Mu;": '\U0000039C',
"NJcy;": '\U0000040A',
"Nacute;": '\U00000143',
"Ncaron;": '\U00000147',
"Ncedil;": '\U00000145',
"Ncy;": '\U0000041D',
"NegativeMediumSpace;": '\U0000200B',
"NegativeThickSpace;": '\U0000200B',
"NegativeThinSpace;": '\U0000200B',
"NegativeVeryThinSpace;": '\U0000200B',
"NestedGreaterGreater;": '\U0000226B',
"NestedLessLess;": '\U0000226A',
"NewLine;": '\U0000000A',
"Nfr;": '\U0001D511',
"NoBreak;": '\U00002060',
"NonBreakingSpace;": '\U000000A0',
"Nopf;": '\U00002115',
"Not;": '\U00002AEC',
"NotCongruent;": '\U00002262',
"NotCupCap;": '\U0000226D',
"NotDoubleVerticalBar;": '\U00002226',
"NotElement;": '\U00002209',
"NotEqual;": '\U00002260',
"NotExists;": '\U00002204',
"NotGreater;": '\U0000226F',
"NotGreaterEqual;": '\U00002271',
"NotGreaterLess;": '\U00002279',
"NotGreaterTilde;": '\U00002275',
"NotLeftTriangle;": '\U000022EA',
"NotLeftTriangleEqual;": '\U000022EC',
"NotLess;": '\U0000226E',
"NotLessEqual;": '\U00002270',
"NotLessGreater;": '\U00002278',
"NotLessTilde;": '\U00002274',
"NotPrecedes;": '\U00002280',
"NotPrecedesSlantEqual;": '\U000022E0',
"NotReverseElement;": '\U0000220C',
"NotRightTriangle;": '\U000022EB',
"NotRightTriangleEqual;": '\U000022ED',
"NotSquareSubsetEqual;": '\U000022E2',
"NotSquareSupersetEqual;": '\U000022E3',
"NotSubsetEqual;": '\U00002288',
"NotSucceeds;": '\U00002281',
"NotSucceedsSlantEqual;": '\U000022E1',
"NotSupersetEqual;": '\U00002289',
"NotTilde;": '\U00002241',
"NotTildeEqual;": '\U00002244',
"NotTildeFullEqual;": '\U00002247',
"NotTildeTilde;": '\U00002249',
"NotVerticalBar;": '\U00002224',
"Nscr;": '\U0001D4A9',
"Ntilde;": '\U000000D1',
"Nu;": '\U0000039D',
"OElig;": '\U00000152',
"Oacute;": '\U000000D3',
"Ocirc;": '\U000000D4',
"Ocy;": '\U0000041E',
"Odblac;": '\U00000150',
"Ofr;": '\U0001D512',
"Ograve;": '\U000000D2',
"Omacr;": '\U0000014C',
"Omega;": '\U000003A9',
"Omicron;": '\U0000039F',
"Oopf;": '\U0001D546',
"OpenCurlyDoubleQuote;": '\U0000201C',
"OpenCurlyQuote;": '\U00002018',
"Or;": '\U00002A54',
"Oscr;": '\U0001D4AA',
"Oslash;": '\U000000D8',
"Otilde;": '\U000000D5',
"Otimes;": '\U00002A37',
"Ouml;": '\U000000D6',
"OverBar;": '\U0000203E',
"OverBrace;": '\U000023DE',
"OverBracket;": '\U000023B4',
"OverParenthesis;": '\U000023DC',
"PartialD;": '\U00002202',
"Pcy;": '\U0000041F',
"Pfr;": '\U0001D513',
"Phi;": '\U000003A6',
"Pi;": '\U000003A0',
"PlusMinus;": '\U000000B1',
"Poincareplane;": '\U0000210C',
"Popf;": '\U00002119',
"Pr;": '\U00002ABB',
"Precedes;": '\U0000227A',
"PrecedesEqual;": '\U00002AAF',
"PrecedesSlantEqual;": '\U0000227C',
"PrecedesTilde;": '\U0000227E',
"Prime;": '\U00002033',
"Product;": '\U0000220F',
"Proportion;": '\U00002237',
"Proportional;": '\U0000221D',
"Pscr;": '\U0001D4AB',
"Psi;": '\U000003A8',
"QUOT;": '\U00000022',
"Qfr;": '\U0001D514',
"Qopf;": '\U0000211A',
"Qscr;": '\U0001D4AC',
"RBarr;": '\U00002910',
"REG;": '\U000000AE',
"Racute;": '\U00000154',
"Rang;": '\U000027EB',
"Rarr;": '\U000021A0',
"Rarrtl;": '\U00002916',
"Rcaron;": '\U00000158',
"Rcedil;": '\U00000156',
"Rcy;": '\U00000420',
"Re;": '\U0000211C',
"ReverseElement;": '\U0000220B',
"ReverseEquilibrium;": '\U000021CB',
"ReverseUpEquilibrium;": '\U0000296F',
"Rfr;": '\U0000211C',
"Rho;": '\U000003A1',
"RightAngleBracket;": '\U000027E9',
"RightArrow;": '\U00002192',
"RightArrowBar;": '\U000021E5',
"RightArrowLeftArrow;": '\U000021C4',
"RightCeiling;": '\U00002309',
"RightDoubleBracket;": '\U000027E7',
"RightDownTeeVector;": '\U0000295D',
"RightDownVector;": '\U000021C2',
"RightDownVectorBar;": '\U00002955',
"RightFloor;": '\U0000230B',
"RightTee;": '\U000022A2',
"RightTeeArrow;": '\U000021A6',
"RightTeeVector;": '\U0000295B',
"RightTriangle;": '\U000022B3',
"RightTriangleBar;": '\U000029D0',
"RightTriangleEqual;": '\U000022B5',
"RightUpDownVector;": '\U0000294F',
"RightUpTeeVector;": '\U0000295C',
"RightUpVector;": '\U000021BE',
"RightUpVectorBar;": '\U00002954',
"RightVector;": '\U000021C0',
"RightVectorBar;": '\U00002953',
"Rightarrow;": '\U000021D2',
"Ropf;": '\U0000211D',
"RoundImplies;": '\U00002970',
"Rrightarrow;": '\U000021DB',
"Rscr;": '\U0000211B',
"Rsh;": '\U000021B1',
"RuleDelayed;": '\U000029F4',
"SHCHcy;": '\U00000429',
"SHcy;": '\U00000428',
"SOFTcy;": '\U0000042C',
"Sacute;": '\U0000015A',
"Sc;": '\U00002ABC',
"Scaron;": '\U00000160',
"Scedil;": '\U0000015E',
"Scirc;": '\U0000015C',
"Scy;": '\U00000421',
"Sfr;": '\U0001D516',
"ShortDownArrow;": '\U00002193',
"ShortLeftArrow;": '\U00002190',
"ShortRightArrow;": '\U00002192',
"ShortUpArrow;": '\U00002191',
"Sigma;": '\U000003A3',
"SmallCircle;": '\U00002218',
"Sopf;": '\U0001D54A',
"Sqrt;": '\U0000221A',
"Square;": '\U000025A1',
"SquareIntersection;": '\U00002293',
"SquareSubset;": '\U0000228F',
"SquareSubsetEqual;": '\U00002291',
"SquareSuperset;": '\U00002290',
"SquareSupersetEqual;": '\U00002292',
"SquareUnion;": '\U00002294',
"Sscr;": '\U0001D4AE',
"Star;": '\U000022C6',
"Sub;": '\U000022D0',
"Subset;": '\U000022D0',
"SubsetEqual;": '\U00002286',
"Succeeds;": '\U0000227B',
"SucceedsEqual;": '\U00002AB0',
"SucceedsSlantEqual;": '\U0000227D',
"SucceedsTilde;": '\U0000227F',
"SuchThat;": '\U0000220B',
"Sum;": '\U00002211',
"Sup;": '\U000022D1',
"Superset;": '\U00002283',
"SupersetEqual;": '\U00002287',
"Supset;": '\U000022D1',
"THORN;": '\U000000DE',
"TRADE;": '\U00002122',
"TSHcy;": '\U0000040B',
"TScy;": '\U00000426',
"Tab;": '\U00000009',
"Tau;": '\U000003A4',
"Tcaron;": '\U00000164',
"Tcedil;": '\U00000162',
"Tcy;": '\U00000422',
"Tfr;": '\U0001D517',
"Therefore;": '\U00002234',
"Theta;": '\U00000398',
"ThinSpace;": '\U00002009',
"Tilde;": '\U0000223C',
"TildeEqual;": '\U00002243',
"TildeFullEqual;": '\U00002245',
"TildeTilde;": '\U00002248',
"Topf;": '\U0001D54B',
"TripleDot;": '\U000020DB',
"Tscr;": '\U0001D4AF',
"Tstrok;": '\U00000166',
"Uacute;": '\U000000DA',
"Uarr;": '\U0000219F',
"Uarrocir;": '\U00002949',
"Ubrcy;": '\U0000040E',
"Ubreve;": '\U0000016C',
"Ucirc;": '\U000000DB',
"Ucy;": '\U00000423',
"Udblac;": '\U00000170',
"Ufr;": '\U0001D518',
"Ugrave;": '\U000000D9',
"Umacr;": '\U0000016A',
"UnderBar;": '\U0000005F',
"UnderBrace;": '\U000023DF',
"UnderBracket;": '\U000023B5',
"UnderParenthesis;": '\U000023DD',
"Union;": '\U000022C3',
"UnionPlus;": '\U0000228E',
"Uogon;": '\U00000172',
"Uopf;": '\U0001D54C',
"UpArrow;": '\U00002191',
"UpArrowBar;": '\U00002912',
"UpArrowDownArrow;": '\U000021C5',
"UpDownArrow;": '\U00002195',
"UpEquilibrium;": '\U0000296E',
"UpTee;": '\U000022A5',
"UpTeeArrow;": '\U000021A5',
"Uparrow;": '\U000021D1',
"Updownarrow;": '\U000021D5',
"UpperLeftArrow;": '\U00002196',
"UpperRightArrow;": '\U00002197',
"Upsi;": '\U000003D2',
"Upsilon;": '\U000003A5',
"Uring;": '\U0000016E',
"Uscr;": '\U0001D4B0',
"Utilde;": '\U00000168',
"Uuml;": '\U000000DC',
"VDash;": '\U000022AB',
"Vbar;": '\U00002AEB',
"Vcy;": '\U00000412',
"Vdash;": '\U000022A9',
"Vdashl;": '\U00002AE6',
"Vee;": '\U000022C1',
"Verbar;": '\U00002016',
"Vert;": '\U00002016',
"VerticalBar;": '\U00002223',
"VerticalLine;": '\U0000007C',
"VerticalSeparator;": '\U00002758',
"VerticalTilde;": '\U00002240',
"VeryThinSpace;": '\U0000200A',
"Vfr;": '\U0001D519',
"Vopf;": '\U0001D54D',
"Vscr;": '\U0001D4B1',
"Vvdash;": '\U000022AA',
"Wcirc;": '\U00000174',
"Wedge;": '\U000022C0',
"Wfr;": '\U0001D51A',
"Wopf;": '\U0001D54E',
"Wscr;": '\U0001D4B2',
"Xfr;": '\U0001D51B',
"Xi;": '\U0000039E',
"Xopf;": '\U0001D54F',
"Xscr;": '\U0001D4B3',
"YAcy;": '\U0000042F',
"YIcy;": '\U00000407',
"YUcy;": '\U0000042E',
"Yacute;": '\U000000DD',
"Ycirc;": '\U00000176',
"Ycy;": '\U0000042B',
"Yfr;": '\U0001D51C',
"Yopf;": '\U0001D550',
"Yscr;": '\U0001D4B4',
"Yuml;": '\U00000178',
"ZHcy;": '\U00000416',
"Zacute;": '\U00000179',
"Zcaron;": '\U0000017D',
"Zcy;": '\U00000417',
"Zdot;": '\U0000017B',
"ZeroWidthSpace;": '\U0000200B',
"Zeta;": '\U00000396',
"Zfr;": '\U00002128',
"Zopf;": '\U00002124',
"Zscr;": '\U0001D4B5',
"aacute;": '\U000000E1',
"abreve;": '\U00000103',
"ac;": '\U0000223E',
"acd;": '\U0000223F',
"acirc;": '\U000000E2',
"acute;": '\U000000B4',
"acy;": '\U00000430',
"aelig;": '\U000000E6',
"af;": '\U00002061',
"afr;": '\U0001D51E',
"agrave;": '\U000000E0',
"alefsym;": '\U00002135',
"aleph;": '\U00002135',
"alpha;": '\U000003B1',
"amacr;": '\U00000101',
"amalg;": '\U00002A3F',
"amp;": '\U00000026',
"and;": '\U00002227',
"andand;": '\U00002A55',
"andd;": '\U00002A5C',
"andslope;": '\U00002A58',
"andv;": '\U00002A5A',
"ang;": '\U00002220',
"ange;": '\U000029A4',
"angle;": '\U00002220',
"angmsd;": '\U00002221',
"angmsdaa;": '\U000029A8',
"angmsdab;": '\U000029A9',
"angmsdac;": '\U000029AA',
"angmsdad;": '\U000029AB',
"angmsdae;": '\U000029AC',
"angmsdaf;": '\U000029AD',
"angmsdag;": '\U000029AE',
"angmsdah;": '\U000029AF',
"angrt;": '\U0000221F',
"angrtvb;": '\U000022BE',
"angrtvbd;": '\U0000299D',
"angsph;": '\U00002222',
"angst;": '\U000000C5',
"angzarr;": '\U0000237C',
"aogon;": '\U00000105',
"aopf;": '\U0001D552',
"ap;": '\U00002248',
"apE;": '\U00002A70',
"apacir;": '\U00002A6F',
"ape;": '\U0000224A',
"apid;": '\U0000224B',
"apos;": '\U00000027',
"approx;": '\U00002248',
"approxeq;": '\U0000224A',
"aring;": '\U000000E5',
"ascr;": '\U0001D4B6',
"ast;": '\U0000002A',
"asymp;": '\U00002248',
"asympeq;": '\U0000224D',
"atilde;": '\U000000E3',
"auml;": '\U000000E4',
"awconint;": '\U00002233',
"awint;": '\U00002A11',
"bNot;": '\U00002AED',
"backcong;": '\U0000224C',
"backepsilon;": '\U000003F6',
"backprime;": '\U00002035',
"backsim;": '\U0000223D',
"backsimeq;": '\U000022CD',
"barvee;": '\U000022BD',
"barwed;": '\U00002305',
"barwedge;": '\U00002305',
"bbrk;": '\U000023B5',
"bbrktbrk;": '\U000023B6',
"bcong;": '\U0000224C',
"bcy;": '\U00000431',
"bdquo;": '\U0000201E',
"becaus;": '\U00002235',
"because;": '\U00002235',
"bemptyv;": '\U000029B0',
"bepsi;": '\U000003F6',
"bernou;": '\U0000212C',
"beta;": '\U000003B2',
"beth;": '\U00002136',
"between;": '\U0000226C',
"bfr;": '\U0001D51F',
"bigcap;": '\U000022C2',
"bigcirc;": '\U000025EF',
"bigcup;": '\U000022C3',
"bigodot;": '\U00002A00',
"bigoplus;": '\U00002A01',
"bigotimes;": '\U00002A02',
"bigsqcup;": '\U00002A06',
"bigstar;": '\U00002605',
"bigtriangledown;": '\U000025BD',
"bigtriangleup;": '\U000025B3',
"biguplus;": '\U00002A04',
"bigvee;": '\U000022C1',
"bigwedge;": '\U000022C0',
"bkarow;": '\U0000290D',
"blacklozenge;": '\U000029EB',
"blacksquare;": '\U000025AA',
"blacktriangle;": '\U000025B4',
"blacktriangledown;": '\U000025BE',
"blacktriangleleft;": '\U000025C2',
"blacktriangleright;": '\U000025B8',
"blank;": '\U00002423',
"blk12;": '\U00002592',
"blk14;": '\U00002591',
"blk34;": '\U00002593',
"block;": '\U00002588',
"bnot;": '\U00002310',
"bopf;": '\U0001D553',
"bot;": '\U000022A5',
"bottom;": '\U000022A5',
"bowtie;": '\U000022C8',
"boxDL;": '\U00002557',
"boxDR;": '\U00002554',
"boxDl;": '\U00002556',
"boxDr;": '\U00002553',
"boxH;": '\U00002550',
"boxHD;": '\U00002566',
"boxHU;": '\U00002569',
"boxHd;": '\U00002564',
"boxHu;": '\U00002567',
"boxUL;": '\U0000255D',
"boxUR;": '\U0000255A',
"boxUl;": '\U0000255C',
"boxUr;": '\U00002559',
"boxV;": '\U00002551',
"boxVH;": '\U0000256C',
"boxVL;": '\U00002563',
"boxVR;": '\U00002560',
"boxVh;": '\U0000256B',
"boxVl;": '\U00002562',
"boxVr;": '\U0000255F',
"boxbox;": '\U000029C9',
"boxdL;": '\U00002555',
"boxdR;": '\U00002552',
"boxdl;": '\U00002510',
"boxdr;": '\U0000250C',
"boxh;": '\U00002500',
"boxhD;": '\U00002565',
"boxhU;": '\U00002568',
"boxhd;": '\U0000252C',
"boxhu;": '\U00002534',
"boxminus;": '\U0000229F',
"boxplus;": '\U0000229E',
"boxtimes;": '\U000022A0',
"boxuL;": '\U0000255B',
"boxuR;": '\U00002558',
"boxul;": '\U00002518',
"boxur;": '\U00002514',
"boxv;": '\U00002502',
"boxvH;": '\U0000256A',
"boxvL;": '\U00002561',
"boxvR;": '\U0000255E',
"boxvh;": '\U0000253C',
"boxvl;": '\U00002524',
"boxvr;": '\U0000251C',
"bprime;": '\U00002035',
"breve;": '\U000002D8',
"brvbar;": '\U000000A6',
"bscr;": '\U0001D4B7',
"bsemi;": '\U0000204F',
"bsim;": '\U0000223D',
"bsime;": '\U000022CD',
"bsol;": '\U0000005C',
"bsolb;": '\U000029C5',
"bsolhsub;": '\U000027C8',
"bull;": '\U00002022',
"bullet;": '\U00002022',
"bump;": '\U0000224E',
"bumpE;": '\U00002AAE',
"bumpe;": '\U0000224F',
"bumpeq;": '\U0000224F',
"cacute;": '\U00000107',
"cap;": '\U00002229',
"capand;": '\U00002A44',
"capbrcup;": '\U00002A49',
"capcap;": '\U00002A4B',
"capcup;": '\U00002A47',
"capdot;": '\U00002A40',
"caret;": '\U00002041',
"caron;": '\U000002C7',
"ccaps;": '\U00002A4D',
"ccaron;": '\U0000010D',
"ccedil;": '\U000000E7',
"ccirc;": '\U00000109',
"ccups;": '\U00002A4C',
"ccupssm;": '\U00002A50',
"cdot;": '\U0000010B',
"cedil;": '\U000000B8',
"cemptyv;": '\U000029B2',
"cent;": '\U000000A2',
"centerdot;": '\U000000B7',
"cfr;": '\U0001D520',
"chcy;": '\U00000447',
"check;": '\U00002713',
"checkmark;": '\U00002713',
"chi;": '\U000003C7',
"cir;": '\U000025CB',
"cirE;": '\U000029C3',
"circ;": '\U000002C6',
"circeq;": '\U00002257',
"circlearrowleft;": '\U000021BA',
"circlearrowright;": '\U000021BB',
"circledR;": '\U000000AE',
"circledS;": '\U000024C8',
"circledast;": '\U0000229B',
"circledcirc;": '\U0000229A',
"circleddash;": '\U0000229D',
"cire;": '\U00002257',
"cirfnint;": '\U00002A10',
"cirmid;": '\U00002AEF',
"cirscir;": '\U000029C2',
"clubs;": '\U00002663',
"clubsuit;": '\U00002663',
"colon;": '\U0000003A',
"colone;": '\U00002254',
"coloneq;": '\U00002254',
"comma;": '\U0000002C',
"commat;": '\U00000040',
"comp;": '\U00002201',
"compfn;": '\U00002218',
"complement;": '\U00002201',
"complexes;": '\U00002102',
"cong;": '\U00002245',
"congdot;": '\U00002A6D',
"conint;": '\U0000222E',
"copf;": '\U0001D554',
"coprod;": '\U00002210',
"copy;": '\U000000A9',
"copysr;": '\U00002117',
"crarr;": '\U000021B5',
"cross;": '\U00002717',
"cscr;": '\U0001D4B8',
"csub;": '\U00002ACF',
"csube;": '\U00002AD1',
"csup;": '\U00002AD0',
"csupe;": '\U00002AD2',
"ctdot;": '\U000022EF',
"cudarrl;": '\U00002938',
"cudarrr;": '\U00002935',
"cuepr;": '\U000022DE',
"cuesc;": '\U000022DF',
"cularr;": '\U000021B6',
"cularrp;": '\U0000293D',
"cup;": '\U0000222A',
"cupbrcap;": '\U00002A48',
"cupcap;": '\U00002A46',
"cupcup;": '\U00002A4A',
"cupdot;": '\U0000228D',
"cupor;": '\U00002A45',
"curarr;": '\U000021B7',
"curarrm;": '\U0000293C',
"curlyeqprec;": '\U000022DE',
"curlyeqsucc;": '\U000022DF',
"curlyvee;": '\U000022CE',
"curlywedge;": '\U000022CF',
"curren;": '\U000000A4',
"curvearrowleft;": '\U000021B6',
"curvearrowright;": '\U000021B7',
"cuvee;": '\U000022CE',
"cuwed;": '\U000022CF',
"cwconint;": '\U00002232',
"cwint;": '\U00002231',
"cylcty;": '\U0000232D',
"dArr;": '\U000021D3',
"dHar;": '\U00002965',
"dagger;": '\U00002020',
"daleth;": '\U00002138',
"darr;": '\U00002193',
"dash;": '\U00002010',
"dashv;": '\U000022A3',
"dbkarow;": '\U0000290F',
"dblac;": '\U000002DD',
"dcaron;": '\U0000010F',
"dcy;": '\U00000434',
"dd;": '\U00002146',
"ddagger;": '\U00002021',
"ddarr;": '\U000021CA',
"ddotseq;": '\U00002A77',
"deg;": '\U000000B0',
"delta;": '\U000003B4',
"demptyv;": '\U000029B1',
"dfisht;": '\U0000297F',
"dfr;": '\U0001D521',
"dharl;": '\U000021C3',
"dharr;": '\U000021C2',
"diam;": '\U000022C4',
"diamond;": '\U000022C4',
"diamondsuit;": '\U00002666',
"diams;": '\U00002666',
"die;": '\U000000A8',
"digamma;": '\U000003DD',
"disin;": '\U000022F2',
"div;": '\U000000F7',
"divide;": '\U000000F7',
"divideontimes;": '\U000022C7',
"divonx;": '\U000022C7',
"djcy;": '\U00000452',
"dlcorn;": '\U0000231E',
"dlcrop;": '\U0000230D',
"dollar;": '\U00000024',
"dopf;": '\U0001D555',
"dot;": '\U000002D9',
"doteq;": '\U00002250',
"doteqdot;": '\U00002251',
"dotminus;": '\U00002238',
"dotplus;": '\U00002214',
"dotsquare;": '\U000022A1',
"doublebarwedge;": '\U00002306',
"downarrow;": '\U00002193',
"downdownarrows;": '\U000021CA',
"downharpoonleft;": '\U000021C3',
"downharpoonright;": '\U000021C2',
"drbkarow;": '\U00002910',
"drcorn;": '\U0000231F',
"drcrop;": '\U0000230C',
"dscr;": '\U0001D4B9',
"dscy;": '\U00000455',
"dsol;": '\U000029F6',
"dstrok;": '\U00000111',
"dtdot;": '\U000022F1',
"dtri;": '\U000025BF',
"dtrif;": '\U000025BE',
"duarr;": '\U000021F5',
"duhar;": '\U0000296F',
"dwangle;": '\U000029A6',
"dzcy;": '\U0000045F',
"dzigrarr;": '\U000027FF',
"eDDot;": '\U00002A77',
"eDot;": '\U00002251',
"eacute;": '\U000000E9',
"easter;": '\U00002A6E',
"ecaron;": '\U0000011B',
"ecir;": '\U00002256',
"ecirc;": '\U000000EA',
"ecolon;": '\U00002255',
"ecy;": '\U0000044D',
"edot;": '\U00000117',
"ee;": '\U00002147',
"efDot;": '\U00002252',
"efr;": '\U0001D522',
"eg;": '\U00002A9A',
"egrave;": '\U000000E8',
"egs;": '\U00002A96',
"egsdot;": '\U00002A98',
"el;": '\U00002A99',
"elinters;": '\U000023E7',
"ell;": '\U00002113',
"els;": '\U00002A95',
"elsdot;": '\U00002A97',
"emacr;": '\U00000113',
"empty;": '\U00002205',
"emptyset;": '\U00002205',
"emptyv;": '\U00002205',
"emsp;": '\U00002003',
"emsp13;": '\U00002004',
"emsp14;": '\U00002005',
"eng;": '\U0000014B',
"ensp;": '\U00002002',
"eogon;": '\U00000119',
"eopf;": '\U0001D556',
"epar;": '\U000022D5',
"eparsl;": '\U000029E3',
"eplus;": '\U00002A71',
"epsi;": '\U000003B5',
"epsilon;": '\U000003B5',
"epsiv;": '\U000003F5',
"eqcirc;": '\U00002256',
"eqcolon;": '\U00002255',
"eqsim;": '\U00002242',
"eqslantgtr;": '\U00002A96',
"eqslantless;": '\U00002A95',
"equals;": '\U0000003D',
"equest;": '\U0000225F',
"equiv;": '\U00002261',
"equivDD;": '\U00002A78',
"eqvparsl;": '\U000029E5',
"erDot;": '\U00002253',
"erarr;": '\U00002971',
"escr;": '\U0000212F',
"esdot;": '\U00002250',
"esim;": '\U00002242',
"eta;": '\U000003B7',
"eth;": '\U000000F0',
"euml;": '\U000000EB',
"euro;": '\U000020AC',
"excl;": '\U00000021',
"exist;": '\U00002203',
"expectation;": '\U00002130',
"exponentiale;": '\U00002147',
"fallingdotseq;": '\U00002252',
"fcy;": '\U00000444',
"female;": '\U00002640',
"ffilig;": '\U0000FB03',
"fflig;": '\U0000FB00',
"ffllig;": '\U0000FB04',
"ffr;": '\U0001D523',
"filig;": '\U0000FB01',
"flat;": '\U0000266D',
"fllig;": '\U0000FB02',
"fltns;": '\U000025B1',
"fnof;": '\U00000192',
"fopf;": '\U0001D557',
"forall;": '\U00002200',
"fork;": '\U000022D4',
"forkv;": '\U00002AD9',
"fpartint;": '\U00002A0D',
"frac12;": '\U000000BD',
"frac13;": '\U00002153',
"frac14;": '\U000000BC',
"frac15;": '\U00002155',
"frac16;": '\U00002159',
"frac18;": '\U0000215B',
"frac23;": '\U00002154',
"frac25;": '\U00002156',
"frac34;": '\U000000BE',
"frac35;": '\U00002157',
"frac38;": '\U0000215C',
"frac45;": '\U00002158',
"frac56;": '\U0000215A',
"frac58;": '\U0000215D',
"frac78;": '\U0000215E',
"frasl;": '\U00002044',
"frown;": '\U00002322',
"fscr;": '\U0001D4BB',
"gE;": '\U00002267',
"gEl;": '\U00002A8C',
"gacute;": '\U000001F5',
"gamma;": '\U000003B3',
"gammad;": '\U000003DD',
"gap;": '\U00002A86',
"gbreve;": '\U0000011F',
"gcirc;": '\U0000011D',
"gcy;": '\U00000433',
"gdot;": '\U00000121',
"ge;": '\U00002265',
"gel;": '\U000022DB',
"geq;": '\U00002265',
"geqq;": '\U00002267',
"geqslant;": '\U00002A7E',
"ges;": '\U00002A7E',
"gescc;": '\U00002AA9',
"gesdot;": '\U00002A80',
"gesdoto;": '\U00002A82',
"gesdotol;": '\U00002A84',
"gesles;": '\U00002A94',
"gfr;": '\U0001D524',
"gg;": '\U0000226B',
"ggg;": '\U000022D9',
"gimel;": '\U00002137',
"gjcy;": '\U00000453',
"gl;": '\U00002277',
"glE;": '\U00002A92',
"gla;": '\U00002AA5',
"glj;": '\U00002AA4',
"gnE;": '\U00002269',
"gnap;": '\U00002A8A',
"gnapprox;": '\U00002A8A',
"gne;": '\U00002A88',
"gneq;": '\U00002A88',
"gneqq;": '\U00002269',
"gnsim;": '\U000022E7',
"gopf;": '\U0001D558',
"grave;": '\U00000060',
"gscr;": '\U0000210A',
"gsim;": '\U00002273',
"gsime;": '\U00002A8E',
"gsiml;": '\U00002A90',
"gt;": '\U0000003E',
"gtcc;": '\U00002AA7',
"gtcir;": '\U00002A7A',
"gtdot;": '\U000022D7',
"gtlPar;": '\U00002995',
"gtquest;": '\U00002A7C',
"gtrapprox;": '\U00002A86',
"gtrarr;": '\U00002978',
"gtrdot;": '\U000022D7',
"gtreqless;": '\U000022DB',
"gtreqqless;": '\U00002A8C',
"gtrless;": '\U00002277',
"gtrsim;": '\U00002273',
"hArr;": '\U000021D4',
"hairsp;": '\U0000200A',
"half;": '\U000000BD',
"hamilt;": '\U0000210B',
"hardcy;": '\U0000044A',
"harr;": '\U00002194',
"harrcir;": '\U00002948',
"harrw;": '\U000021AD',
"hbar;": '\U0000210F',
"hcirc;": '\U00000125',
"hearts;": '\U00002665',
"heartsuit;": '\U00002665',
"hellip;": '\U00002026',
"hercon;": '\U000022B9',
"hfr;": '\U0001D525',
"hksearow;": '\U00002925',
"hkswarow;": '\U00002926',
"hoarr;": '\U000021FF',
"homtht;": '\U0000223B',
"hookleftarrow;": '\U000021A9',
"hookrightarrow;": '\U000021AA',
"hopf;": '\U0001D559',
"horbar;": '\U00002015',
"hscr;": '\U0001D4BD',
"hslash;": '\U0000210F',
"hstrok;": '\U00000127',
"hybull;": '\U00002043',
"hyphen;": '\U00002010',
"iacute;": '\U000000ED',
"ic;": '\U00002063',
"icirc;": '\U000000EE',
"icy;": '\U00000438',
"iecy;": '\U00000435',
"iexcl;": '\U000000A1',
"iff;": '\U000021D4',
"ifr;": '\U0001D526',
"igrave;": '\U000000EC',
"ii;": '\U00002148',
"iiiint;": '\U00002A0C',
"iiint;": '\U0000222D',
"iinfin;": '\U000029DC',
"iiota;": '\U00002129',
"ijlig;": '\U00000133',
"imacr;": '\U0000012B',
"image;": '\U00002111',
"imagline;": '\U00002110',
"imagpart;": '\U00002111',
"imath;": '\U00000131',
"imof;": '\U000022B7',
"imped;": '\U000001B5',
"in;": '\U00002208',
"incare;": '\U00002105',
"infin;": '\U0000221E',
"infintie;": '\U000029DD',
"inodot;": '\U00000131',
"int;": '\U0000222B',
"intcal;": '\U000022BA',
"integers;": '\U00002124',
"intercal;": '\U000022BA',
"intlarhk;": '\U00002A17',
"intprod;": '\U00002A3C',
"iocy;": '\U00000451',
"iogon;": '\U0000012F',
"iopf;": '\U0001D55A',
"iota;": '\U000003B9',
"iprod;": '\U00002A3C',
"iquest;": '\U000000BF',
"iscr;": '\U0001D4BE',
"isin;": '\U00002208',
"isinE;": '\U000022F9',
"isindot;": '\U000022F5',
"isins;": '\U000022F4',
"isinsv;": '\U000022F3',
"isinv;": '\U00002208',
"it;": '\U00002062',
"itilde;": '\U00000129',
"iukcy;": '\U00000456',
"iuml;": '\U000000EF',
"jcirc;": '\U00000135',
"jcy;": '\U00000439',
"jfr;": '\U0001D527',
"jmath;": '\U00000237',
"jopf;": '\U0001D55B',
"jscr;": '\U0001D4BF',
"jsercy;": '\U00000458',
"jukcy;": '\U00000454',
"kappa;": '\U000003BA',
"kappav;": '\U000003F0',
"kcedil;": '\U00000137',
"kcy;": '\U0000043A',
"kfr;": '\U0001D528',
"kgreen;": '\U00000138',
"khcy;": '\U00000445',
"kjcy;": '\U0000045C',
"kopf;": '\U0001D55C',
"kscr;": '\U0001D4C0',
"lAarr;": '\U000021DA',
"lArr;": '\U000021D0',
"lAtail;": '\U0000291B',
"lBarr;": '\U0000290E',
"lE;": '\U00002266',
"lEg;": '\U00002A8B',
"lHar;": '\U00002962',
"lacute;": '\U0000013A',
"laemptyv;": '\U000029B4',
"lagran;": '\U00002112',
"lambda;": '\U000003BB',
"lang;": '\U000027E8',
"langd;": '\U00002991',
"langle;": '\U000027E8',
"lap;": '\U00002A85',
"laquo;": '\U000000AB',
"larr;": '\U00002190',
"larrb;": '\U000021E4',
"larrbfs;": '\U0000291F',
"larrfs;": '\U0000291D',
"larrhk;": '\U000021A9',
"larrlp;": '\U000021AB',
"larrpl;": '\U00002939',
"larrsim;": '\U00002973',
"larrtl;": '\U000021A2',
"lat;": '\U00002AAB',
"latail;": '\U00002919',
"late;": '\U00002AAD',
"lbarr;": '\U0000290C',
"lbbrk;": '\U00002772',
"lbrace;": '\U0000007B',
"lbrack;": '\U0000005B',
"lbrke;": '\U0000298B',
"lbrksld;": '\U0000298F',
"lbrkslu;": '\U0000298D',
"lcaron;": '\U0000013E',
"lcedil;": '\U0000013C',
"lceil;": '\U00002308',
"lcub;": '\U0000007B',
"lcy;": '\U0000043B',
"ldca;": '\U00002936',
"ldquo;": '\U0000201C',
"ldquor;": '\U0000201E',
"ldrdhar;": '\U00002967',
"ldrushar;": '\U0000294B',
"ldsh;": '\U000021B2',
"le;": '\U00002264',
"leftarrow;": '\U00002190',
"leftarrowtail;": '\U000021A2',
"leftharpoondown;": '\U000021BD',
"leftharpoonup;": '\U000021BC',
"leftleftarrows;": '\U000021C7',
"leftrightarrow;": '\U00002194',
"leftrightarrows;": '\U000021C6',
"leftrightharpoons;": '\U000021CB',
"leftrightsquigarrow;": '\U000021AD',
"leftthreetimes;": '\U000022CB',
"leg;": '\U000022DA',
"leq;": '\U00002264',
"leqq;": '\U00002266',
"leqslant;": '\U00002A7D',
"les;": '\U00002A7D',
"lescc;": '\U00002AA8',
"lesdot;": '\U00002A7F',
"lesdoto;": '\U00002A81',
"lesdotor;": '\U00002A83',
"lesges;": '\U00002A93',
"lessapprox;": '\U00002A85',
"lessdot;": '\U000022D6',
"lesseqgtr;": '\U000022DA',
"lesseqqgtr;": '\U00002A8B',
"lessgtr;": '\U00002276',
"lesssim;": '\U00002272',
"lfisht;": '\U0000297C',
"lfloor;": '\U0000230A',
"lfr;": '\U0001D529',
"lg;": '\U00002276',
"lgE;": '\U00002A91',
"lhard;": '\U000021BD',
"lharu;": '\U000021BC',
"lharul;": '\U0000296A',
"lhblk;": '\U00002584',
"ljcy;": '\U00000459',
"ll;": '\U0000226A',
"llarr;": '\U000021C7',
"llcorner;": '\U0000231E',
"llhard;": '\U0000296B',
"lltri;": '\U000025FA',
"lmidot;": '\U00000140',
"lmoust;": '\U000023B0',
"lmoustache;": '\U000023B0',
"lnE;": '\U00002268',
"lnap;": '\U00002A89',
"lnapprox;": '\U00002A89',
"lne;": '\U00002A87',
"lneq;": '\U00002A87',
"lneqq;": '\U00002268',
"lnsim;": '\U000022E6',
"loang;": '\U000027EC',
"loarr;": '\U000021FD',
"lobrk;": '\U000027E6',
"longleftarrow;": '\U000027F5',
"longleftrightarrow;": '\U000027F7',
"longmapsto;": '\U000027FC',
"longrightarrow;": '\U000027F6',
"looparrowleft;": '\U000021AB',
"looparrowright;": '\U000021AC',
"lopar;": '\U00002985',
"lopf;": '\U0001D55D',
"loplus;": '\U00002A2D',
"lotimes;": '\U00002A34',
"lowast;": '\U00002217',
"lowbar;": '\U0000005F',
"loz;": '\U000025CA',
"lozenge;": '\U000025CA',
"lozf;": '\U000029EB',
"lpar;": '\U00000028',
"lparlt;": '\U00002993',
"lrarr;": '\U000021C6',
"lrcorner;": '\U0000231F',
"lrhar;": '\U000021CB',
"lrhard;": '\U0000296D',
"lrm;": '\U0000200E',
"lrtri;": '\U000022BF',
"lsaquo;": '\U00002039',
"lscr;": '\U0001D4C1',
"lsh;": '\U000021B0',
"lsim;": '\U00002272',
"lsime;": '\U00002A8D',
"lsimg;": '\U00002A8F',
"lsqb;": '\U0000005B',
"lsquo;": '\U00002018',
"lsquor;": '\U0000201A',
"lstrok;": '\U00000142',
"lt;": '\U0000003C',
"ltcc;": '\U00002AA6',
"ltcir;": '\U00002A79',
"ltdot;": '\U000022D6',
"lthree;": '\U000022CB',
"ltimes;": '\U000022C9',
"ltlarr;": '\U00002976',
"ltquest;": '\U00002A7B',
"ltrPar;": '\U00002996',
"ltri;": '\U000025C3',
"ltrie;": '\U000022B4',
"ltrif;": '\U000025C2',
"lurdshar;": '\U0000294A',
"luruhar;": '\U00002966',
"mDDot;": '\U0000223A',
"macr;": '\U000000AF',
"male;": '\U00002642',
"malt;": '\U00002720',
"maltese;": '\U00002720',
"map;": '\U000021A6',
"mapsto;": '\U000021A6',
"mapstodown;": '\U000021A7',
"mapstoleft;": '\U000021A4',
"mapstoup;": '\U000021A5',
"marker;": '\U000025AE',
"mcomma;": '\U00002A29',
"mcy;": '\U0000043C',
"mdash;": '\U00002014',
"measuredangle;": '\U00002221',
"mfr;": '\U0001D52A',
"mho;": '\U00002127',
"micro;": '\U000000B5',
"mid;": '\U00002223',
"midast;": '\U0000002A',
"midcir;": '\U00002AF0',
"middot;": '\U000000B7',
"minus;": '\U00002212',
"minusb;": '\U0000229F',
"minusd;": '\U00002238',
"minusdu;": '\U00002A2A',
"mlcp;": '\U00002ADB',
"mldr;": '\U00002026',
"mnplus;": '\U00002213',
"models;": '\U000022A7',
"mopf;": '\U0001D55E',
"mp;": '\U00002213',
"mscr;": '\U0001D4C2',
"mstpos;": '\U0000223E',
"mu;": '\U000003BC',
"multimap;": '\U000022B8',
"mumap;": '\U000022B8',
"nLeftarrow;": '\U000021CD',
"nLeftrightarrow;": '\U000021CE',
"nRightarrow;": '\U000021CF',
"nVDash;": '\U000022AF',
"nVdash;": '\U000022AE',
"nabla;": '\U00002207',
"nacute;": '\U00000144',
"nap;": '\U00002249',
"napos;": '\U00000149',
"napprox;": '\U00002249',
"natur;": '\U0000266E',
"natural;": '\U0000266E',
"naturals;": '\U00002115',
"nbsp;": '\U000000A0',
"ncap;": '\U00002A43',
"ncaron;": '\U00000148',
"ncedil;": '\U00000146',
"ncong;": '\U00002247',
"ncup;": '\U00002A42',
"ncy;": '\U0000043D',
"ndash;": '\U00002013',
"ne;": '\U00002260',
"neArr;": '\U000021D7',
"nearhk;": '\U00002924',
"nearr;": '\U00002197',
"nearrow;": '\U00002197',
"nequiv;": '\U00002262',
"nesear;": '\U00002928',
"nexist;": '\U00002204',
"nexists;": '\U00002204',
"nfr;": '\U0001D52B',
"nge;": '\U00002271',
"ngeq;": '\U00002271',
"ngsim;": '\U00002275',
"ngt;": '\U0000226F',
"ngtr;": '\U0000226F',
"nhArr;": '\U000021CE',
"nharr;": '\U000021AE',
"nhpar;": '\U00002AF2',
"ni;": '\U0000220B',
"nis;": '\U000022FC',
"nisd;": '\U000022FA',
"niv;": '\U0000220B',
"njcy;": '\U0000045A',
"nlArr;": '\U000021CD',
"nlarr;": '\U0000219A',
"nldr;": '\U00002025',
"nle;": '\U00002270',
"nleftarrow;": '\U0000219A',
"nleftrightarrow;": '\U000021AE',
"nleq;": '\U00002270',
"nless;": '\U0000226E',
"nlsim;": '\U00002274',
"nlt;": '\U0000226E',
"nltri;": '\U000022EA',
"nltrie;": '\U000022EC',
"nmid;": '\U00002224',
"nopf;": '\U0001D55F',
"not;": '\U000000AC',
"notin;": '\U00002209',
"notinva;": '\U00002209',
"notinvb;": '\U000022F7',
"notinvc;": '\U000022F6',
"notni;": '\U0000220C',
"notniva;": '\U0000220C',
"notnivb;": '\U000022FE',
"notnivc;": '\U000022FD',
"npar;": '\U00002226',
"nparallel;": '\U00002226',
"npolint;": '\U00002A14',
"npr;": '\U00002280',
"nprcue;": '\U000022E0',
"nprec;": '\U00002280',
"nrArr;": '\U000021CF',
"nrarr;": '\U0000219B',
"nrightarrow;": '\U0000219B',
"nrtri;": '\U000022EB',
"nrtrie;": '\U000022ED',
"nsc;": '\U00002281',
"nsccue;": '\U000022E1',
"nscr;": '\U0001D4C3',
"nshortmid;": '\U00002224',
"nshortparallel;": '\U00002226',
"nsim;": '\U00002241',
"nsime;": '\U00002244',
"nsimeq;": '\U00002244',
"nsmid;": '\U00002224',
"nspar;": '\U00002226',
"nsqsube;": '\U000022E2',
"nsqsupe;": '\U000022E3',
"nsub;": '\U00002284',
"nsube;": '\U00002288',
"nsubseteq;": '\U00002288',
"nsucc;": '\U00002281',
"nsup;": '\U00002285',
"nsupe;": '\U00002289',
"nsupseteq;": '\U00002289',
"ntgl;": '\U00002279',
"ntilde;": '\U000000F1',
"ntlg;": '\U00002278',
"ntriangleleft;": '\U000022EA',
"ntrianglelefteq;": '\U000022EC',
"ntriangleright;": '\U000022EB',
"ntrianglerighteq;": '\U000022ED',
"nu;": '\U000003BD',
"num;": '\U00000023',
"numero;": '\U00002116',
"numsp;": '\U00002007',
"nvDash;": '\U000022AD',
"nvHarr;": '\U00002904',
"nvdash;": '\U000022AC',
"nvinfin;": '\U000029DE',
"nvlArr;": '\U00002902',
"nvrArr;": '\U00002903',
"nwArr;": '\U000021D6',
"nwarhk;": '\U00002923',
"nwarr;": '\U00002196',
"nwarrow;": '\U00002196',
"nwnear;": '\U00002927',
"oS;": '\U000024C8',
"oacute;": '\U000000F3',
"oast;": '\U0000229B',
"ocir;": '\U0000229A',
"ocirc;": '\U000000F4',
"ocy;": '\U0000043E',
"odash;": '\U0000229D',
"odblac;": '\U00000151',
"odiv;": '\U00002A38',
"odot;": '\U00002299',
"odsold;": '\U000029BC',
"oelig;": '\U00000153',
"ofcir;": '\U000029BF',
"ofr;": '\U0001D52C',
"ogon;": '\U000002DB',
"ograve;": '\U000000F2',
"ogt;": '\U000029C1',
"ohbar;": '\U000029B5',
"ohm;": '\U000003A9',
"oint;": '\U0000222E',
"olarr;": '\U000021BA',
"olcir;": '\U000029BE',
"olcross;": '\U000029BB',
"oline;": '\U0000203E',
"olt;": '\U000029C0',
"omacr;": '\U0000014D',
"omega;": '\U000003C9',
"omicron;": '\U000003BF',
"omid;": '\U000029B6',
"ominus;": '\U00002296',
"oopf;": '\U0001D560',
"opar;": '\U000029B7',
"operp;": '\U000029B9',
"oplus;": '\U00002295',
"or;": '\U00002228',
"orarr;": '\U000021BB',
"ord;": '\U00002A5D',
"order;": '\U00002134',
"orderof;": '\U00002134',
"ordf;": '\U000000AA',
"ordm;": '\U000000BA',
"origof;": '\U000022B6',
"oror;": '\U00002A56',
"orslope;": '\U00002A57',
"orv;": '\U00002A5B',
"oscr;": '\U00002134',
"oslash;": '\U000000F8',
"osol;": '\U00002298',
"otilde;": '\U000000F5',
"otimes;": '\U00002297',
"otimesas;": '\U00002A36',
"ouml;": '\U000000F6',
"ovbar;": '\U0000233D',
"par;": '\U00002225',
"para;": '\U000000B6',
"parallel;": '\U00002225',
"parsim;": '\U00002AF3',
"parsl;": '\U00002AFD',
"part;": '\U00002202',
"pcy;": '\U0000043F',
"percnt;": '\U00000025',
"period;": '\U0000002E',
"permil;": '\U00002030',
"perp;": '\U000022A5',
"pertenk;": '\U00002031',
"pfr;": '\U0001D52D',
"phi;": '\U000003C6',
"phiv;": '\U000003D5',
"phmmat;": '\U00002133',
"phone;": '\U0000260E',
"pi;": '\U000003C0',
"pitchfork;": '\U000022D4',
"piv;": '\U000003D6',
"planck;": '\U0000210F',
"planckh;": '\U0000210E',
"plankv;": '\U0000210F',
"plus;": '\U0000002B',
"plusacir;": '\U00002A23',
"plusb;": '\U0000229E',
"pluscir;": '\U00002A22',
"plusdo;": '\U00002214',
"plusdu;": '\U00002A25',
"pluse;": '\U00002A72',
"plusmn;": '\U000000B1',
"plussim;": '\U00002A26',
"plustwo;": '\U00002A27',
"pm;": '\U000000B1',
"pointint;": '\U00002A15',
"popf;": '\U0001D561',
"pound;": '\U000000A3',
"pr;": '\U0000227A',
"prE;": '\U00002AB3',
"prap;": '\U00002AB7',
"prcue;": '\U0000227C',
"pre;": '\U00002AAF',
"prec;": '\U0000227A',
"precapprox;": '\U00002AB7',
"preccurlyeq;": '\U0000227C',
"preceq;": '\U00002AAF',
"precnapprox;": '\U00002AB9',
"precneqq;": '\U00002AB5',
"precnsim;": '\U000022E8',
"precsim;": '\U0000227E',
"prime;": '\U00002032',
"primes;": '\U00002119',
"prnE;": '\U00002AB5',
"prnap;": '\U00002AB9',
"prnsim;": '\U000022E8',
"prod;": '\U0000220F',
"profalar;": '\U0000232E',
"profline;": '\U00002312',
"profsurf;": '\U00002313',
"prop;": '\U0000221D',
"propto;": '\U0000221D',
"prsim;": '\U0000227E',
"prurel;": '\U000022B0',
"pscr;": '\U0001D4C5',
"psi;": '\U000003C8',
"puncsp;": '\U00002008',
"qfr;": '\U0001D52E',
"qint;": '\U00002A0C',
"qopf;": '\U0001D562',
"qprime;": '\U00002057',
"qscr;": '\U0001D4C6',
"quaternions;": '\U0000210D',
"quatint;": '\U00002A16',
"quest;": '\U0000003F',
"questeq;": '\U0000225F',
"quot;": '\U00000022',
"rAarr;": '\U000021DB',
"rArr;": '\U000021D2',
"rAtail;": '\U0000291C',
"rBarr;": '\U0000290F',
"rHar;": '\U00002964',
"racute;": '\U00000155',
"radic;": '\U0000221A',
"raemptyv;": '\U000029B3',
"rang;": '\U000027E9',
"rangd;": '\U00002992',
"range;": '\U000029A5',
"rangle;": '\U000027E9',
"raquo;": '\U000000BB',
"rarr;": '\U00002192',
"rarrap;": '\U00002975',
"rarrb;": '\U000021E5',
"rarrbfs;": '\U00002920',
"rarrc;": '\U00002933',
"rarrfs;": '\U0000291E',
"rarrhk;": '\U000021AA',
"rarrlp;": '\U000021AC',
"rarrpl;": '\U00002945',
"rarrsim;": '\U00002974',
"rarrtl;": '\U000021A3',
"rarrw;": '\U0000219D',
"ratail;": '\U0000291A',
"ratio;": '\U00002236',
"rationals;": '\U0000211A',
"rbarr;": '\U0000290D',
"rbbrk;": '\U00002773',
"rbrace;": '\U0000007D',
"rbrack;": '\U0000005D',
"rbrke;": '\U0000298C',
"rbrksld;": '\U0000298E',
"rbrkslu;": '\U00002990',
"rcaron;": '\U00000159',
"rcedil;": '\U00000157',
"rceil;": '\U00002309',
"rcub;": '\U0000007D',
"rcy;": '\U00000440',
"rdca;": '\U00002937',
"rdldhar;": '\U00002969',
"rdquo;": '\U0000201D',
"rdquor;": '\U0000201D',
"rdsh;": '\U000021B3',
"real;": '\U0000211C',
"realine;": '\U0000211B',
"realpart;": '\U0000211C',
"reals;": '\U0000211D',
"rect;": '\U000025AD',
"reg;": '\U000000AE',
"rfisht;": '\U0000297D',
"rfloor;": '\U0000230B',
"rfr;": '\U0001D52F',
"rhard;": '\U000021C1',
"rharu;": '\U000021C0',
"rharul;": '\U0000296C',
"rho;": '\U000003C1',
"rhov;": '\U000003F1',
"rightarrow;": '\U00002192',
"rightarrowtail;": '\U000021A3',
"rightharpoondown;": '\U000021C1',
"rightharpoonup;": '\U000021C0',
"rightleftarrows;": '\U000021C4',
"rightleftharpoons;": '\U000021CC',
"rightrightarrows;": '\U000021C9',
"rightsquigarrow;": '\U0000219D',
"rightthreetimes;": '\U000022CC',
"ring;": '\U000002DA',
"risingdotseq;": '\U00002253',
"rlarr;": '\U000021C4',
"rlhar;": '\U000021CC',
"rlm;": '\U0000200F',
"rmoust;": '\U000023B1',
"rmoustache;": '\U000023B1',
"rnmid;": '\U00002AEE',
"roang;": '\U000027ED',
"roarr;": '\U000021FE',
"robrk;": '\U000027E7',
"ropar;": '\U00002986',
"ropf;": '\U0001D563',
"roplus;": '\U00002A2E',
"rotimes;": '\U00002A35',
"rpar;": '\U00000029',
"rpargt;": '\U00002994',
"rppolint;": '\U00002A12',
"rrarr;": '\U000021C9',
"rsaquo;": '\U0000203A',
"rscr;": '\U0001D4C7',
"rsh;": '\U000021B1',
"rsqb;": '\U0000005D',
"rsquo;": '\U00002019',
"rsquor;": '\U00002019',
"rthree;": '\U000022CC',
"rtimes;": '\U000022CA',
"rtri;": '\U000025B9',
"rtrie;": '\U000022B5',
"rtrif;": '\U000025B8',
"rtriltri;": '\U000029CE',
"ruluhar;": '\U00002968',
"rx;": '\U0000211E',
"sacute;": '\U0000015B',
"sbquo;": '\U0000201A',
"sc;": '\U0000227B',
"scE;": '\U00002AB4',
"scap;": '\U00002AB8',
"scaron;": '\U00000161',
"sccue;": '\U0000227D',
"sce;": '\U00002AB0',
"scedil;": '\U0000015F',
"scirc;": '\U0000015D',
"scnE;": '\U00002AB6',
"scnap;": '\U00002ABA',
"scnsim;": '\U000022E9',
"scpolint;": '\U00002A13',
"scsim;": '\U0000227F',
"scy;": '\U00000441',
"sdot;": '\U000022C5',
"sdotb;": '\U000022A1',
"sdote;": '\U00002A66',
"seArr;": '\U000021D8',
"searhk;": '\U00002925',
"searr;": '\U00002198',
"searrow;": '\U00002198',
"sect;": '\U000000A7',
"semi;": '\U0000003B',
"seswar;": '\U00002929',
"setminus;": '\U00002216',
"setmn;": '\U00002216',
"sext;": '\U00002736',
"sfr;": '\U0001D530',
"sfrown;": '\U00002322',
"sharp;": '\U0000266F',
"shchcy;": '\U00000449',
"shcy;": '\U00000448',
"shortmid;": '\U00002223',
"shortparallel;": '\U00002225',
"shy;": '\U000000AD',
"sigma;": '\U000003C3',
"sigmaf;": '\U000003C2',
"sigmav;": '\U000003C2',
"sim;": '\U0000223C',
"simdot;": '\U00002A6A',
"sime;": '\U00002243',
"simeq;": '\U00002243',
"simg;": '\U00002A9E',
"simgE;": '\U00002AA0',
"siml;": '\U00002A9D',
"simlE;": '\U00002A9F',
"simne;": '\U00002246',
"simplus;": '\U00002A24',
"simrarr;": '\U00002972',
"slarr;": '\U00002190',
"smallsetminus;": '\U00002216',
"smashp;": '\U00002A33',
"smeparsl;": '\U000029E4',
"smid;": '\U00002223',
"smile;": '\U00002323',
"smt;": '\U00002AAA',
"smte;": '\U00002AAC',
"softcy;": '\U0000044C',
"sol;": '\U0000002F',
"solb;": '\U000029C4',
"solbar;": '\U0000233F',
"sopf;": '\U0001D564',
"spades;": '\U00002660',
"spadesuit;": '\U00002660',
"spar;": '\U00002225',
"sqcap;": '\U00002293',
"sqcup;": '\U00002294',
"sqsub;": '\U0000228F',
"sqsube;": '\U00002291',
"sqsubset;": '\U0000228F',
"sqsubseteq;": '\U00002291',
"sqsup;": '\U00002290',
"sqsupe;": '\U00002292',
"sqsupset;": '\U00002290',
"sqsupseteq;": '\U00002292',
"squ;": '\U000025A1',
"square;": '\U000025A1',
"squarf;": '\U000025AA',
"squf;": '\U000025AA',
"srarr;": '\U00002192',
"sscr;": '\U0001D4C8',
"ssetmn;": '\U00002216',
"ssmile;": '\U00002323',
"sstarf;": '\U000022C6',
"star;": '\U00002606',
"starf;": '\U00002605',
"straightepsilon;": '\U000003F5',
"straightphi;": '\U000003D5',
"strns;": '\U000000AF',
"sub;": '\U00002282',
"subE;": '\U00002AC5',
"subdot;": '\U00002ABD',
"sube;": '\U00002286',
"subedot;": '\U00002AC3',
"submult;": '\U00002AC1',
"subnE;": '\U00002ACB',
"subne;": '\U0000228A',
"subplus;": '\U00002ABF',
"subrarr;": '\U00002979',
"subset;": '\U00002282',
"subseteq;": '\U00002286',
"subseteqq;": '\U00002AC5',
"subsetneq;": '\U0000228A',
"subsetneqq;": '\U00002ACB',
"subsim;": '\U00002AC7',
"subsub;": '\U00002AD5',
"subsup;": '\U00002AD3',
"succ;": '\U0000227B',
"succapprox;": '\U00002AB8',
"succcurlyeq;": '\U0000227D',
"succeq;": '\U00002AB0',
"succnapprox;": '\U00002ABA',
"succneqq;": '\U00002AB6',
"succnsim;": '\U000022E9',
"succsim;": '\U0000227F',
"sum;": '\U00002211',
"sung;": '\U0000266A',
"sup;": '\U00002283',
"sup1;": '\U000000B9',
"sup2;": '\U000000B2',
"sup3;": '\U000000B3',
"supE;": '\U00002AC6',
"supdot;": '\U00002ABE',
"supdsub;": '\U00002AD8',
"supe;": '\U00002287',
"supedot;": '\U00002AC4',
"suphsol;": '\U000027C9',
"suphsub;": '\U00002AD7',
"suplarr;": '\U0000297B',
"supmult;": '\U00002AC2',
"supnE;": '\U00002ACC',
"supne;": '\U0000228B',
"supplus;": '\U00002AC0',
"supset;": '\U00002283',
"supseteq;": '\U00002287',
"supseteqq;": '\U00002AC6',
"supsetneq;": '\U0000228B',
"supsetneqq;": '\U00002ACC',
"supsim;": '\U00002AC8',
"supsub;": '\U00002AD4',
"supsup;": '\U00002AD6',
"swArr;": '\U000021D9',
"swarhk;": '\U00002926',
"swarr;": '\U00002199',
"swarrow;": '\U00002199',
"swnwar;": '\U0000292A',
"szlig;": '\U000000DF',
"target;": '\U00002316',
"tau;": '\U000003C4',
"tbrk;": '\U000023B4',
"tcaron;": '\U00000165',
"tcedil;": '\U00000163',
"tcy;": '\U00000442',
"tdot;": '\U000020DB',
"telrec;": '\U00002315',
"tfr;": '\U0001D531',
"there4;": '\U00002234',
"therefore;": '\U00002234',
"theta;": '\U000003B8',
"thetasym;": '\U000003D1',
"thetav;": '\U000003D1',
"thickapprox;": '\U00002248',
"thicksim;": '\U0000223C',
"thinsp;": '\U00002009',
"thkap;": '\U00002248',
"thksim;": '\U0000223C',
"thorn;": '\U000000FE',
"tilde;": '\U000002DC',
"times;": '\U000000D7',
"timesb;": '\U000022A0',
"timesbar;": '\U00002A31',
"timesd;": '\U00002A30',
"tint;": '\U0000222D',
"toea;": '\U00002928',
"top;": '\U000022A4',
"topbot;": '\U00002336',
"topcir;": '\U00002AF1',
"topf;": '\U0001D565',
"topfork;": '\U00002ADA',
"tosa;": '\U00002929',
"tprime;": '\U00002034',
"trade;": '\U00002122',
"triangle;": '\U000025B5',
"triangledown;": '\U000025BF',
"triangleleft;": '\U000025C3',
"trianglelefteq;": '\U000022B4',
"triangleq;": '\U0000225C',
"triangleright;": '\U000025B9',
"trianglerighteq;": '\U000022B5',
"tridot;": '\U000025EC',
"trie;": '\U0000225C',
"triminus;": '\U00002A3A',
"triplus;": '\U00002A39',
"trisb;": '\U000029CD',
"tritime;": '\U00002A3B',
"trpezium;": '\U000023E2',
"tscr;": '\U0001D4C9',
"tscy;": '\U00000446',
"tshcy;": '\U0000045B',
"tstrok;": '\U00000167',
"twixt;": '\U0000226C',
"twoheadleftarrow;": '\U0000219E',
"twoheadrightarrow;": '\U000021A0',
"uArr;": '\U000021D1',
"uHar;": '\U00002963',
"uacute;": '\U000000FA',
"uarr;": '\U00002191',
"ubrcy;": '\U0000045E',
"ubreve;": '\U0000016D',
"ucirc;": '\U000000FB',
"ucy;": '\U00000443',
"udarr;": '\U000021C5',
"udblac;": '\U00000171',
"udhar;": '\U0000296E',
"ufisht;": '\U0000297E',
"ufr;": '\U0001D532',
"ugrave;": '\U000000F9',
"uharl;": '\U000021BF',
"uharr;": '\U000021BE',
"uhblk;": '\U00002580',
"ulcorn;": '\U0000231C',
"ulcorner;": '\U0000231C',
"ulcrop;": '\U0000230F',
"ultri;": '\U000025F8',
"umacr;": '\U0000016B',
"uml;": '\U000000A8',
"uogon;": '\U00000173',
"uopf;": '\U0001D566',
"uparrow;": '\U00002191',
"updownarrow;": '\U00002195',
"upharpoonleft;": '\U000021BF',
"upharpoonright;": '\U000021BE',
"uplus;": '\U0000228E',
"upsi;": '\U000003C5',
"upsih;": '\U000003D2',
"upsilon;": '\U000003C5',
"upuparrows;": '\U000021C8',
"urcorn;": '\U0000231D',
"urcorner;": '\U0000231D',
"urcrop;": '\U0000230E',
"uring;": '\U0000016F',
"urtri;": '\U000025F9',
"uscr;": '\U0001D4CA',
"utdot;": '\U000022F0',
"utilde;": '\U00000169',
"utri;": '\U000025B5',
"utrif;": '\U000025B4',
"uuarr;": '\U000021C8',
"uuml;": '\U000000FC',
"uwangle;": '\U000029A7',
"vArr;": '\U000021D5',
"vBar;": '\U00002AE8',
"vBarv;": '\U00002AE9',
"vDash;": '\U000022A8',
"vangrt;": '\U0000299C',
"varepsilon;": '\U000003F5',
"varkappa;": '\U000003F0',
"varnothing;": '\U00002205',
"varphi;": '\U000003D5',
"varpi;": '\U000003D6',
"varpropto;": '\U0000221D',
"varr;": '\U00002195',
"varrho;": '\U000003F1',
"varsigma;": '\U000003C2',
"vartheta;": '\U000003D1',
"vartriangleleft;": '\U000022B2',
"vartriangleright;": '\U000022B3',
"vcy;": '\U00000432',
"vdash;": '\U000022A2',
"vee;": '\U00002228',
"veebar;": '\U000022BB',
"veeeq;": '\U0000225A',
"vellip;": '\U000022EE',
"verbar;": '\U0000007C',
"vert;": '\U0000007C',
"vfr;": '\U0001D533',
"vltri;": '\U000022B2',
"vopf;": '\U0001D567',
"vprop;": '\U0000221D',
"vrtri;": '\U000022B3',
"vscr;": '\U0001D4CB',
"vzigzag;": '\U0000299A',
"wcirc;": '\U00000175',
"wedbar;": '\U00002A5F',
"wedge;": '\U00002227',
"wedgeq;": '\U00002259',
"weierp;": '\U00002118',
"wfr;": '\U0001D534',
"wopf;": '\U0001D568',
"wp;": '\U00002118',
"wr;": '\U00002240',
"wreath;": '\U00002240',
"wscr;": '\U0001D4CC',
"xcap;": '\U000022C2',
"xcirc;": '\U000025EF',
"xcup;": '\U000022C3',
"xdtri;": '\U000025BD',
"xfr;": '\U0001D535',
"xhArr;": '\U000027FA',
"xharr;": '\U000027F7',
"xi;": '\U000003BE',
"xlArr;": '\U000027F8',
"xlarr;": '\U000027F5',
"xmap;": '\U000027FC',
"xnis;": '\U000022FB',
"xodot;": '\U00002A00',
"xopf;": '\U0001D569',
"xoplus;": '\U00002A01',
"xotime;": '\U00002A02',
"xrArr;": '\U000027F9',
"xrarr;": '\U000027F6',
"xscr;": '\U0001D4CD',
"xsqcup;": '\U00002A06',
"xuplus;": '\U00002A04',
"xutri;": '\U000025B3',
"xvee;": '\U000022C1',
"xwedge;": '\U000022C0',
"yacute;": '\U000000FD',
"yacy;": '\U0000044F',
"ycirc;": '\U00000177',
"ycy;": '\U0000044B',
"yen;": '\U000000A5',
"yfr;": '\U0001D536',
"yicy;": '\U00000457',
"yopf;": '\U0001D56A',
"yscr;": '\U0001D4CE',
"yucy;": '\U0000044E',
"yuml;": '\U000000FF',
"zacute;": '\U0000017A',
"zcaron;": '\U0000017E',
"zcy;": '\U00000437',
"zdot;": '\U0000017C',
"zeetrf;": '\U00002128',
"zeta;": '\U000003B6',
"zfr;": '\U0001D537',
"zhcy;": '\U00000436',
"zigrarr;": '\U000021DD',
"zopf;": '\U0001D56B',
"zscr;": '\U0001D4CF',
"zwj;": '\U0000200D',
"zwnj;": '\U0000200C',
"AElig": '\U000000C6',
"AMP": '\U00000026',
"Aacute": '\U000000C1',
"Acirc": '\U000000C2',
"Agrave": '\U000000C0',
"Aring": '\U000000C5',
"Atilde": '\U000000C3',
"Auml": '\U000000C4',
"COPY": '\U000000A9',
"Ccedil": '\U000000C7',
"ETH": '\U000000D0',
"Eacute": '\U000000C9',
"Ecirc": '\U000000CA',
"Egrave": '\U000000C8',
"Euml": '\U000000CB',
"GT": '\U0000003E',
"Iacute": '\U000000CD',
"Icirc": '\U000000CE',
"Igrave": '\U000000CC',
"Iuml": '\U000000CF',
"LT": '\U0000003C',
"Ntilde": '\U000000D1',
"Oacute": '\U000000D3',
"Ocirc": '\U000000D4',
"Ograve": '\U000000D2',
"Oslash": '\U000000D8',
"Otilde": '\U000000D5',
"Ouml": '\U000000D6',
"QUOT": '\U00000022',
"REG": '\U000000AE',
"THORN": '\U000000DE',
"Uacute": '\U000000DA',
"Ucirc": '\U000000DB',
"Ugrave": '\U000000D9',
"Uuml": '\U000000DC',
"Yacute": '\U000000DD',
"aacute": '\U000000E1',
"acirc": '\U000000E2',
"acute": '\U000000B4',
"aelig": '\U000000E6',
"agrave": '\U000000E0',
"amp": '\U00000026',
"aring": '\U000000E5',
"atilde": '\U000000E3',
"auml": '\U000000E4',
"brvbar": '\U000000A6',
"ccedil": '\U000000E7',
"cedil": '\U000000B8',
"cent": '\U000000A2',
"copy": '\U000000A9',
"curren": '\U000000A4',
"deg": '\U000000B0',
"divide": '\U000000F7',
"eacute": '\U000000E9',
"ecirc": '\U000000EA',
"egrave": '\U000000E8',
"eth": '\U000000F0',
"euml": '\U000000EB',
"frac12": '\U000000BD',
"frac14": '\U000000BC',
"frac34": '\U000000BE',
"gt": '\U0000003E',
"iacute": '\U000000ED',
"icirc": '\U000000EE',
"iexcl": '\U000000A1',
"igrave": '\U000000EC',
"iquest": '\U000000BF',
"iuml": '\U000000EF',
"laquo": '\U000000AB',
"lt": '\U0000003C',
"macr": '\U000000AF',
"micro": '\U000000B5',
"middot": '\U000000B7',
"nbsp": '\U000000A0',
"not": '\U000000AC',
"ntilde": '\U000000F1',
"oacute": '\U000000F3',
"ocirc": '\U000000F4',
"ograve": '\U000000F2',
"ordf": '\U000000AA',
"ordm": '\U000000BA',
"oslash": '\U000000F8',
"otilde": '\U000000F5',
"ouml": '\U000000F6',
"para": '\U000000B6',
"plusmn": '\U000000B1',
"pound": '\U000000A3',
"quot": '\U00000022',
"raquo": '\U000000BB',
"reg": '\U000000AE',
"sect": '\U000000A7',
"shy": '\U000000AD',
"sup1": '\U000000B9',
"sup2": '\U000000B2',
"sup3": '\U000000B3',
"szlig": '\U000000DF',
"thorn": '\U000000FE',
"times": '\U000000D7',
"uacute": '\U000000FA',
"ucirc": '\U000000FB',
"ugrave": '\U000000F9',
"uml": '\U000000A8',
"uuml": '\U000000FC',
"yacute": '\U000000FD',
"yen": '\U000000A5',
"yuml": '\U000000FF',
}
entity2 = map[string][2]rune{
// TODO(nigeltao): Handle replacements that are wider than their names.
// "nLt;": {'\u226A', '\u20D2'},
// "nGt;": {'\u226B', '\u20D2'},
"NotEqualTilde;": {'\u2242', '\u0338'},
"NotGreaterFullEqual;": {'\u2267', '\u0338'},
"NotGreaterGreater;": {'\u226B', '\u0338'},
"NotGreaterSlantEqual;": {'\u2A7E', '\u0338'},
"NotHumpDownHump;": {'\u224E', '\u0338'},
"NotHumpEqual;": {'\u224F', '\u0338'},
"NotLeftTriangleBar;": {'\u29CF', '\u0338'},
"NotLessLess;": {'\u226A', '\u0338'},
"NotLessSlantEqual;": {'\u2A7D', '\u0338'},
"NotNestedGreaterGreater;": {'\u2AA2', '\u0338'},
"NotNestedLessLess;": {'\u2AA1', '\u0338'},
"NotPrecedesEqual;": {'\u2AAF', '\u0338'},
"NotRightTriangleBar;": {'\u29D0', '\u0338'},
"NotSquareSubset;": {'\u228F', '\u0338'},
"NotSquareSuperset;": {'\u2290', '\u0338'},
"NotSubset;": {'\u2282', '\u20D2'},
"NotSucceedsEqual;": {'\u2AB0', '\u0338'},
"NotSucceedsTilde;": {'\u227F', '\u0338'},
"NotSuperset;": {'\u2283', '\u20D2'},
"ThickSpace;": {'\u205F', '\u200A'},
"acE;": {'\u223E', '\u0333'},
"bne;": {'\u003D', '\u20E5'},
"bnequiv;": {'\u2261', '\u20E5'},
"caps;": {'\u2229', '\uFE00'},
"cups;": {'\u222A', '\uFE00'},
"fjlig;": {'\u0066', '\u006A'},
"gesl;": {'\u22DB', '\uFE00'},
"gvertneqq;": {'\u2269', '\uFE00'},
"gvnE;": {'\u2269', '\uFE00'},
"lates;": {'\u2AAD', '\uFE00'},
"lesg;": {'\u22DA', '\uFE00'},
"lvertneqq;": {'\u2268', '\uFE00'},
"lvnE;": {'\u2268', '\uFE00'},
"nGg;": {'\u22D9', '\u0338'},
"nGtv;": {'\u226B', '\u0338'},
"nLl;": {'\u22D8', '\u0338'},
"nLtv;": {'\u226A', '\u0338'},
"nang;": {'\u2220', '\u20D2'},
"napE;": {'\u2A70', '\u0338'},
"napid;": {'\u224B', '\u0338'},
"nbump;": {'\u224E', '\u0338'},
"nbumpe;": {'\u224F', '\u0338'},
"ncongdot;": {'\u2A6D', '\u0338'},
"nedot;": {'\u2250', '\u0338'},
"nesim;": {'\u2242', '\u0338'},
"ngE;": {'\u2267', '\u0338'},
"ngeqq;": {'\u2267', '\u0338'},
"ngeqslant;": {'\u2A7E', '\u0338'},
"nges;": {'\u2A7E', '\u0338'},
"nlE;": {'\u2266', '\u0338'},
"nleqq;": {'\u2266', '\u0338'},
"nleqslant;": {'\u2A7D', '\u0338'},
"nles;": {'\u2A7D', '\u0338'},
"notinE;": {'\u22F9', '\u0338'},
"notindot;": {'\u22F5', '\u0338'},
"nparsl;": {'\u2AFD', '\u20E5'},
"npart;": {'\u2202', '\u0338'},
"npre;": {'\u2AAF', '\u0338'},
"npreceq;": {'\u2AAF', '\u0338'},
"nrarrc;": {'\u2933', '\u0338'},
"nrarrw;": {'\u219D', '\u0338'},
"nsce;": {'\u2AB0', '\u0338'},
"nsubE;": {'\u2AC5', '\u0338'},
"nsubset;": {'\u2282', '\u20D2'},
"nsubseteqq;": {'\u2AC5', '\u0338'},
"nsucceq;": {'\u2AB0', '\u0338'},
"nsupE;": {'\u2AC6', '\u0338'},
"nsupset;": {'\u2283', '\u20D2'},
"nsupseteqq;": {'\u2AC6', '\u0338'},
"nvap;": {'\u224D', '\u20D2'},
"nvge;": {'\u2265', '\u20D2'},
"nvgt;": {'\u003E', '\u20D2'},
"nvle;": {'\u2264', '\u20D2'},
"nvlt;": {'\u003C', '\u20D2'},
"nvltrie;": {'\u22B4', '\u20D2'},
"nvrtrie;": {'\u22B5', '\u20D2'},
"nvsim;": {'\u223C', '\u20D2'},
"race;": {'\u223D', '\u0331'},
"smtes;": {'\u2AAC', '\uFE00'},
"sqcaps;": {'\u2293', '\uFE00'},
"sqcups;": {'\u2294', '\uFE00'},
"varsubsetneq;": {'\u228A', '\uFE00'},
"varsubsetneqq;": {'\u2ACB', '\uFE00'},
"varsupsetneq;": {'\u228B', '\uFE00'},
"varsupsetneqq;": {'\u2ACC', '\uFE00'},
"vnsub;": {'\u2282', '\u20D2'},
"vnsup;": {'\u2283', '\u20D2'},
"vsubnE;": {'\u2ACB', '\uFE00'},
"vsubne;": {'\u228A', '\uFE00'},
"vsupnE;": {'\u2ACC', '\uFE00'},
"vsupne;": {'\u228B', '\uFE00'},
}
}
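// Lookup sketch (illustrative, not part of the original source): after
// populateMaps has run, single-rune and two-rune entities live in
// separate maps, and a missing key yields the zero value:
//
//	populateMapsOnce.Do(populateMaps)
//	entity["amp;"]          // '&'
//	entity2["NotSubset;"]   // [2]rune{'\u2282', '\u20D2'}
//	entity["bogus;"]        // 0; unescapeEntity treats 0 as "not found"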
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package html provides functions for escaping and unescaping HTML text.
package html
import (
"strings"
"unicode/utf8"
)
// These replacements permit compatibility with old numeric entities that
// assumed Windows-1252 encoding.
// https://html.spec.whatwg.org/multipage/parsing.html#numeric-character-reference-end-state
var replacementTable = [...]rune{
'\u20AC', // First entry is what 0x80 should be replaced with.
'\u0081',
'\u201A',
'\u0192',
'\u201E',
'\u2026',
'\u2020',
'\u2021',
'\u02C6',
'\u2030',
'\u0160',
'\u2039',
'\u0152',
'\u008D',
'\u017D',
'\u008F',
'\u0090',
'\u2018',
'\u2019',
'\u201C',
'\u201D',
'\u2022',
'\u2013',
'\u2014',
'\u02DC',
'\u2122',
'\u0161',
'\u203A',
'\u0153',
'\u009D',
'\u017E',
'\u0178', // Last entry is 0x9F.
// 0x00->'\uFFFD' is handled programmatically.
// 0x0D->'\u000D' is a no-op.
}
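// Worked example (illustrative, not part of the original source): the
// numeric reference "&#128;" (0x80) falls in the remapped range, so it
// decodes to the table's first entry rather than the C1 control U+0080:
//
//	UnescapeString("&#128;") // "\u20AC", i.e. "€"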
// unescapeEntity reads an entity like "&lt;" from b[src:] and writes the
// corresponding "<" to b[dst:], returning the incremented dst and src cursors.
// Precondition: b[src] == '&' && dst <= src.
func unescapeEntity(b []byte, dst, src int) (dst1, src1 int) {
const attribute = false
// http://www.whatwg.org/specs/web-apps/current-work/multipage/tokenization.html#consume-a-character-reference
// i starts at 1 because we already know that s[0] == '&'.
i, s := 1, b[src:]
if len(s) <= 1 {
b[dst] = b[src]
return dst + 1, src + 1
}
if s[i] == '#' {
if len(s) <= 3 { // We need to have at least "&#.".
b[dst] = b[src]
return dst + 1, src + 1
}
i++
c := s[i]
hex := false
if c == 'x' || c == 'X' {
hex = true
i++
}
x := '\x00'
for i < len(s) {
c = s[i]
i++
if hex {
if '0' <= c && c <= '9' {
x = 16*x + rune(c) - '0'
continue
} else if 'a' <= c && c <= 'f' {
x = 16*x + rune(c) - 'a' + 10
continue
} else if 'A' <= c && c <= 'F' {
x = 16*x + rune(c) - 'A' + 10
continue
}
} else if '0' <= c && c <= '9' {
x = 10*x + rune(c) - '0'
continue
}
if c != ';' {
i--
}
break
}
if i <= 3 { // No characters matched.
b[dst] = b[src]
return dst + 1, src + 1
}
if 0x80 <= x && x <= 0x9F {
// Replace characters from Windows-1252 with UTF-8 equivalents.
x = replacementTable[x-0x80]
} else if x == 0 || (0xD800 <= x && x <= 0xDFFF) || x > 0x10FFFF {
// Replace invalid characters with the replacement character.
x = '\uFFFD'
}
return dst + utf8.EncodeRune(b[dst:], x), src + i
}
// Consume the maximum number of characters possible, with the
// consumed characters matching one of the named references.
for i < len(s) {
c := s[i]
i++
// Lower-cased characters are more common in entities, so we check for them first.
if 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || '0' <= c && c <= '9' {
continue
}
if c != ';' {
i--
}
break
}
entityName := s[1:i]
if len(entityName) == 0 {
// No-op.
} else if attribute && entityName[len(entityName)-1] != ';' && len(s) > i && s[i] == '=' {
// No-op.
} else if x := entity[string(entityName)]; x != 0 {
return dst + utf8.EncodeRune(b[dst:], x), src + i
} else if x := entity2[string(entityName)]; x[0] != 0 {
dst1 := dst + utf8.EncodeRune(b[dst:], x[0])
return dst1 + utf8.EncodeRune(b[dst1:], x[1]), src + i
} else if !attribute {
maxLen := len(entityName) - 1
if maxLen > longestEntityWithoutSemicolon {
maxLen = longestEntityWithoutSemicolon
}
for j := maxLen; j > 1; j-- {
if x := entity[string(entityName[:j])]; x != 0 {
return dst + utf8.EncodeRune(b[dst:], x), src + j + 1
}
}
}
dst1, src1 = dst+i, src+i
copy(b[dst:dst1], b[src:src1])
return dst1, src1
}
var htmlEscaper = strings.NewReplacer(
`&`, "&amp;",
`'`, "&#39;", // "&#39;" is shorter than "&apos;" and apos was not in HTML until HTML5.
`<`, "&lt;",
`>`, "&gt;",
`"`, "&#34;", // "&#34;" is shorter than "&quot;".
)
// EscapeString escapes special characters like "<" to become "&lt;". It
// escapes only five such characters: <, >, &, ' and ".
// UnescapeString(EscapeString(s)) == s always holds, but the converse isn't
// always true.
func EscapeString(s string) string {
return htmlEscaper.Replace(s)
}
// UnescapeString unescapes entities like "&lt;" to become "<". It unescapes a
// larger range of entities than EscapeString escapes. For example, "&aacute;"
// unescapes to "á", as does "&#225;" and "&#xE1;".
// UnescapeString(EscapeString(s)) == s always holds, but the converse isn't
// always true.
func UnescapeString(s string) string {
populateMapsOnce.Do(populateMaps)
i := strings.IndexByte(s, '&')
if i < 0 {
return s
}
b := []byte(s)
dst, src := unescapeEntity(b, i, i)
for len(s[src:]) > 0 {
if s[src] == '&' {
i = 0
} else {
i = strings.IndexByte(s[src:], '&')
}
if i < 0 {
dst += copy(b[dst:], s[src:])
break
}
if i > 0 {
copy(b[dst:], s[src:src+i])
}
dst, src = unescapeEntity(b, dst+i, src+i)
}
return string(b[:dst])
}
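// Illustrative usage (a minimal sketch, not part of the original source):
//
// const s = `"Fran & Freddie's Diner" <tasty@example.com>`
// escaped := EscapeString(s)
// // escaped == "&#34;Fran &amp; Freddie&#39;s Diner&#34; &lt;tasty@example.com&gt;"
// UnescapeString(escaped) == s // the documented round-trip property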
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package template
import (
"strings"
)
// attrTypeMap[n] describes the value of the given attribute.
// If an attribute affects (or can mask) the encoding or interpretation of
// other content, or affects the contents, idempotency, or credentials of a
// network message, then the value in this map is contentTypeUnsafe.
// This map is derived from HTML5, specifically
// https://www.w3.org/TR/html5/Overview.html#attributes-1
// as well as "%URI"-typed attributes from
// https://www.w3.org/TR/html4/index/attributes.html
var attrTypeMap = map[string]contentType{
"accept": contentTypePlain,
"accept-charset": contentTypeUnsafe,
"action": contentTypeURL,
"alt": contentTypePlain,
"archive": contentTypeURL,
"async": contentTypeUnsafe,
"autocomplete": contentTypePlain,
"autofocus": contentTypePlain,
"autoplay": contentTypePlain,
"background": contentTypeURL,
"border": contentTypePlain,
"checked": contentTypePlain,
"cite": contentTypeURL,
"challenge": contentTypeUnsafe,
"charset": contentTypeUnsafe,
"class": contentTypePlain,
"classid": contentTypeURL,
"codebase": contentTypeURL,
"cols": contentTypePlain,
"colspan": contentTypePlain,
"content": contentTypeUnsafe,
"contenteditable": contentTypePlain,
"contextmenu": contentTypePlain,
"controls": contentTypePlain,
"coords": contentTypePlain,
"crossorigin": contentTypeUnsafe,
"data": contentTypeURL,
"datetime": contentTypePlain,
"default": contentTypePlain,
"defer": contentTypeUnsafe,
"dir": contentTypePlain,
"dirname": contentTypePlain,
"disabled": contentTypePlain,
"draggable": contentTypePlain,
"dropzone": contentTypePlain,
"enctype": contentTypeUnsafe,
"for": contentTypePlain,
"form": contentTypeUnsafe,
"formaction": contentTypeURL,
"formenctype": contentTypeUnsafe,
"formmethod": contentTypeUnsafe,
"formnovalidate": contentTypeUnsafe,
"formtarget": contentTypePlain,
"headers": contentTypePlain,
"height": contentTypePlain,
"hidden": contentTypePlain,
"high": contentTypePlain,
"href": contentTypeURL,
"hreflang": contentTypePlain,
"http-equiv": contentTypeUnsafe,
"icon": contentTypeURL,
"id": contentTypePlain,
"ismap": contentTypePlain,
"keytype": contentTypeUnsafe,
"kind": contentTypePlain,
"label": contentTypePlain,
"lang": contentTypePlain,
"language": contentTypeUnsafe,
"list": contentTypePlain,
"longdesc": contentTypeURL,
"loop": contentTypePlain,
"low": contentTypePlain,
"manifest": contentTypeURL,
"max": contentTypePlain,
"maxlength": contentTypePlain,
"media": contentTypePlain,
"mediagroup": contentTypePlain,
"method": contentTypeUnsafe,
"min": contentTypePlain,
"multiple": contentTypePlain,
"name": contentTypePlain,
"novalidate": contentTypeUnsafe,
// Skip handler names from
// https://www.w3.org/TR/html5/webappapis.html#event-handlers-on-elements,-document-objects,-and-window-objects
// since we have special handling in attrType.
"open": contentTypePlain,
"optimum": contentTypePlain,
"pattern": contentTypeUnsafe,
"placeholder": contentTypePlain,
"poster": contentTypeURL,
"profile": contentTypeURL,
"preload": contentTypePlain,
"pubdate": contentTypePlain,
"radiogroup": contentTypePlain,
"readonly": contentTypePlain,
"rel": contentTypeUnsafe,
"required": contentTypePlain,
"reversed": contentTypePlain,
"rows": contentTypePlain,
"rowspan": contentTypePlain,
"sandbox": contentTypeUnsafe,
"spellcheck": contentTypePlain,
"scope": contentTypePlain,
"scoped": contentTypePlain,
"seamless": contentTypePlain,
"selected": contentTypePlain,
"shape": contentTypePlain,
"size": contentTypePlain,
"sizes": contentTypePlain,
"span": contentTypePlain,
"src": contentTypeURL,
"srcdoc": contentTypeHTML,
"srclang": contentTypePlain,
"srcset": contentTypeSrcset,
"start": contentTypePlain,
"step": contentTypePlain,
"style": contentTypeCSS,
"tabindex": contentTypePlain,
"target": contentTypePlain,
"title": contentTypePlain,
"type": contentTypeUnsafe,
"usemap": contentTypeURL,
"value": contentTypeUnsafe,
"width": contentTypePlain,
"wrap": contentTypePlain,
"xmlns": contentTypeURL,
}
// attrType returns a conservative (upper-bound on authority) guess at the
// type of the lowercase named attribute.
func attrType(name string) contentType {
if strings.HasPrefix(name, "data-") {
// Strip data- so that custom attribute heuristics below are
// widely applied.
// Treat data-action as URL below.
name = name[5:]
} else if prefix, short, ok := strings.Cut(name, ":"); ok {
if prefix == "xmlns" {
return contentTypeURL
}
// Treat svg:href and xlink:href as href below.
name = short
}
if t, ok := attrTypeMap[name]; ok {
return t
}
// Treat partial event handler names as script.
if strings.HasPrefix(name, "on") {
return contentTypeJS
}
// Heuristics to prevent "javascript:..." injection in custom
// data attributes and custom attributes like g:tweetUrl.
// https://www.w3.org/TR/html5/dom.html#embedding-custom-non-visible-data-with-the-data-*-attributes
// "Custom data attributes are intended to store custom data
// private to the page or application, for which there are no
// more appropriate attributes or elements."
// Developers seem to store URL content in data URLs that start
// or end with "URI" or "URL".
if strings.Contains(name, "src") ||
strings.Contains(name, "uri") ||
strings.Contains(name, "url") {
return contentTypeURL
}
return contentTypePlain
}
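// For illustration (not part of the original source), the lookup and
// heuristics above yield:
//
// attrType("href")       == contentTypeURL // table lookup
// attrType("onclick")    == contentTypeJS  // "on" prefix
// attrType("data-url")   == contentTypeURL // "data-" stripped, contains "url"
// attrType("g:tweeturl") == contentTypeURL // namespace stripped, contains "url"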
// Code generated by "stringer -type attr"; DO NOT EDIT.
package template
import "strconv"
const _attr_name = "attrNoneattrScriptattrScriptTypeattrStyleattrURLattrSrcset"
var _attr_index = [...]uint8{0, 8, 18, 32, 41, 48, 58}
func (i attr) String() string {
if i >= attr(len(_attr_index)-1) {
return "attr(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _attr_name[_attr_index[i]:_attr_index[i+1]]
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package template
import (
"fmt"
"reflect"
)
// Strings of content from a trusted source.
type (
// CSS encapsulates known safe content that matches any of:
// 1. The CSS3 stylesheet production, such as `p { color: purple }`.
// 2. The CSS3 rule production, such as `a[href=~"https:"].foo#bar`.
// 3. CSS3 declaration productions, such as `color: red; margin: 2px`.
// 4. The CSS3 value production, such as `rgba(0, 0, 255, 127)`.
// See https://www.w3.org/TR/css3-syntax/#parsing and
// https://web.archive.org/web/20090211114933/http://w3.org/TR/css3-syntax#style
//
// Use of this type presents a security risk:
// the encapsulated content should come from a trusted source,
// as it will be included verbatim in the template output.
CSS string
// HTML encapsulates a known safe HTML document fragment.
// It should not be used for HTML from a third-party, or HTML with
// unclosed tags or comments. The outputs of a sound HTML sanitizer
// and a template escaped by this package are fine for use with HTML.
//
// Use of this type presents a security risk:
// the encapsulated content should come from a trusted source,
// as it will be included verbatim in the template output.
HTML string
// HTMLAttr encapsulates an HTML attribute from a trusted source,
// for example, ` dir="ltr"`.
//
// Use of this type presents a security risk:
// the encapsulated content should come from a trusted source,
// as it will be included verbatim in the template output.
HTMLAttr string
// JS encapsulates a known safe EcmaScript5 Expression, for example,
// `(x + y * z())`.
// Template authors are responsible for ensuring that typed expressions
// do not break the intended precedence and that there is no
// statement/expression ambiguity as when passing an expression like
// "{ foo: bar() }\n['foo']()", which is both a valid Expression and a
// valid Program with a very different meaning.
//
// Use of this type presents a security risk:
// the encapsulated content should come from a trusted source,
// as it will be included verbatim in the template output.
//
// Using JS to include valid but untrusted JSON is not safe.
// A safe alternative is to parse the JSON with json.Unmarshal and then
// pass the resultant object into the template, where it will be
// converted to sanitized JSON when presented in a JavaScript context.
JS string
// JSStr encapsulates a sequence of characters meant to be embedded
// between quotes in a JavaScript expression.
// The string must match a series of StringCharacters:
// StringCharacter :: SourceCharacter but not `\` or LineTerminator
// | EscapeSequence
// Note that LineContinuations are not allowed.
// JSStr("foo\\nbar") is fine, but JSStr("foo\\\nbar") is not.
//
// Use of this type presents a security risk:
// the encapsulated content should come from a trusted source,
// as it will be included verbatim in the template output.
JSStr string
// URL encapsulates a known safe URL or URL substring (see RFC 3986).
// A URL like `javascript:checkThatFormNotEditedBeforeLeavingPage()`
// from a trusted source should go in the page, but by default dynamic
// `javascript:` URLs are filtered out since they are a frequently
// exploited injection vector.
//
// Use of this type presents a security risk:
// the encapsulated content should come from a trusted source,
// as it will be included verbatim in the template output.
URL string
// Srcset encapsulates a known safe srcset attribute
// (see https://w3c.github.io/html/semantics-embedded-content.html#element-attrdef-img-srcset).
//
// Use of this type presents a security risk:
// the encapsulated content should come from a trusted source,
// as it will be included verbatim in the template output.
Srcset string
)
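// For example (an illustrative sketch, not part of the original source),
// the difference between a plain string and HTML in a text context:
//
// tmpl := Must(New("t").Parse(`{{.}}`))
// tmpl.Execute(os.Stdout, "<b>hi</b>")       // prints &lt;b&gt;hi&lt;/b&gt;
// tmpl.Execute(os.Stdout, HTML("<b>hi</b>")) // prints <b>hi</b>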
type contentType uint8
const (
contentTypePlain contentType = iota
contentTypeCSS
contentTypeHTML
contentTypeHTMLAttr
contentTypeJS
contentTypeJSStr
contentTypeURL
contentTypeSrcset
// contentTypeUnsafe is used in attr.go for values that affect how
// embedded content and network messages are formed, vetted,
// or interpreted; or which credentials network messages carry.
contentTypeUnsafe
)
// indirect returns the value, after dereferencing as many times
// as necessary to reach the base type (or nil).
func indirect(a any) any {
if a == nil {
return nil
}
if t := reflect.TypeOf(a); t.Kind() != reflect.Pointer {
// Avoid creating a reflect.Value if it's not a pointer.
return a
}
v := reflect.ValueOf(a)
for v.Kind() == reflect.Pointer && !v.IsNil() {
v = v.Elem()
}
return v.Interface()
}
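// For example (not part of the original source):
//
// x := 42
// p := &x
// indirect(&p) == 42 // **int dereferenced down to int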
var (
errorType = reflect.TypeOf((*error)(nil)).Elem()
fmtStringerType = reflect.TypeOf((*fmt.Stringer)(nil)).Elem()
)
// indirectToStringerOrError returns the value, after dereferencing as many times
// as necessary to reach the base type (or nil) or an implementation of fmt.Stringer
// or error.
func indirectToStringerOrError(a any) any {
if a == nil {
return nil
}
v := reflect.ValueOf(a)
for !v.Type().Implements(fmtStringerType) && !v.Type().Implements(errorType) && v.Kind() == reflect.Pointer && !v.IsNil() {
v = v.Elem()
}
return v.Interface()
}
// stringify converts its arguments to a string and the type of the content.
// All pointers are dereferenced, as in the text/template package.
func stringify(args ...any) (string, contentType) {
if len(args) == 1 {
switch s := indirect(args[0]).(type) {
case string:
return s, contentTypePlain
case CSS:
return string(s), contentTypeCSS
case HTML:
return string(s), contentTypeHTML
case HTMLAttr:
return string(s), contentTypeHTMLAttr
case JS:
return string(s), contentTypeJS
case JSStr:
return string(s), contentTypeJSStr
case URL:
return string(s), contentTypeURL
case Srcset:
return string(s), contentTypeSrcset
}
}
i := 0
for _, arg := range args {
// We skip untyped nil arguments for backward compatibility.
// Without this they would be output as <nil>, escaped.
// See issue 25875.
if arg == nil {
continue
}
args[i] = indirectToStringerOrError(arg)
i++
}
return fmt.Sprint(args[:i]...), contentTypePlain
}
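// For example (not part of the original source):
//
// stringify("<b>")       // "<b>", contentTypePlain
// stringify(HTML("<b>")) // "<b>", contentTypeHTML
// stringify(1, 2)        // "1 2", contentTypePlain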
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package template
import (
"fmt"
"text/template/parse"
)
// context describes the state an HTML parser must be in when it reaches the
// portion of HTML produced by evaluating a particular template node.
//
// The zero value of type context is the start context for a template that
// produces an HTML fragment as defined at
// https://www.w3.org/TR/html5/syntax.html#the-end
// where the context element is null.
type context struct {
state state
delim delim
urlPart urlPart
jsCtx jsCtx
attr attr
element element
n parse.Node // for range break/continue
err *Error
}
func (c context) String() string {
var err error
if c.err != nil {
err = c.err
}
return fmt.Sprintf("{%v %v %v %v %v %v %v}", c.state, c.delim, c.urlPart, c.jsCtx, c.attr, c.element, err)
}
// eq reports whether two contexts are equal.
func (c context) eq(d context) bool {
return c.state == d.state &&
c.delim == d.delim &&
c.urlPart == d.urlPart &&
c.jsCtx == d.jsCtx &&
c.attr == d.attr &&
c.element == d.element &&
c.err == d.err
}
// mangle produces an identifier that includes a suffix that distinguishes it
// from template names mangled with different contexts.
func (c context) mangle(templateName string) string {
// The mangled name for the default context is the input templateName.
if c.state == stateText {
return templateName
}
s := templateName + "$htmltemplate_" + c.state.String()
if c.delim != delimNone {
s += "_" + c.delim.String()
}
if c.urlPart != urlPartNone {
s += "_" + c.urlPart.String()
}
if c.jsCtx != jsCtxRegexp {
s += "_" + c.jsCtx.String()
}
if c.attr != attrNone {
s += "_" + c.attr.String()
}
if c.element != elementNone {
s += "_" + c.element.String()
}
return s
}
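// For example (not part of the original source), a template "T" whose call
// site is inside a script element mangles to:
//
// context{state: stateJS, element: elementScript}.mangle("T")
// // == "T$htmltemplate_stateJS_elementScript"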
// state describes a high-level HTML parser state.
//
// It bounds the top of the element stack, and by extension the HTML insertion
// mode, but also contains state that does not correspond to anything in the
// HTML5 parsing algorithm because a single token production in the HTML
// grammar may contain embedded actions in a template. For instance, the quoted
// HTML attribute produced by
//
// <div title="Hello {{.World}}">
//
// is a single token in HTML's grammar but in a template spans several nodes.
type state uint8
//go:generate stringer -type state
const (
// stateText is parsed character data. An HTML parser is in
// this state when its parse position is outside an HTML tag,
// directive, comment, and special element body.
stateText state = iota
// stateTag occurs before an HTML attribute or the end of a tag.
stateTag
// stateAttrName occurs inside an attribute name.
// It occurs between the ^'s in ` ^name^ = value`.
stateAttrName
// stateAfterName occurs after an attr name has ended but before any
// equals sign. It occurs between the ^'s in ` name^ ^= value`.
stateAfterName
// stateBeforeValue occurs after the equals sign but before the value.
// It occurs between the ^'s in ` name =^ ^value`.
stateBeforeValue
// stateHTMLCmt occurs inside an <!-- HTML comment -->.
stateHTMLCmt
// stateRCDATA occurs inside an RCDATA element (<textarea> or <title>)
// as described at https://www.w3.org/TR/html5/syntax.html#elements-0
stateRCDATA
// stateAttr occurs inside an HTML attribute whose content is text.
stateAttr
// stateURL occurs inside an HTML attribute whose content is a URL.
stateURL
// stateSrcset occurs inside an HTML srcset attribute.
stateSrcset
// stateJS occurs inside an event handler or script element.
stateJS
// stateJSDqStr occurs inside a JavaScript double quoted string.
stateJSDqStr
// stateJSSqStr occurs inside a JavaScript single quoted string.
stateJSSqStr
// stateJSRegexp occurs inside a JavaScript regexp literal.
stateJSRegexp
// stateJSBlockCmt occurs inside a JavaScript /* block comment */.
stateJSBlockCmt
// stateJSLineCmt occurs inside a JavaScript // line comment.
stateJSLineCmt
// stateCSS occurs inside a <style> element or style attribute.
stateCSS
// stateCSSDqStr occurs inside a CSS double quoted string.
stateCSSDqStr
// stateCSSSqStr occurs inside a CSS single quoted string.
stateCSSSqStr
// stateCSSDqURL occurs inside a CSS double quoted url("...").
stateCSSDqURL
// stateCSSSqURL occurs inside a CSS single quoted url('...').
stateCSSSqURL
// stateCSSURL occurs inside a CSS unquoted url(...).
stateCSSURL
// stateCSSBlockCmt occurs inside a CSS /* block comment */.
stateCSSBlockCmt
// stateCSSLineCmt occurs inside a CSS // line comment.
stateCSSLineCmt
// stateError is an infectious error state outside any valid
// HTML/CSS/JS construct.
stateError
// stateDead marks unreachable code after a {{break}} or {{continue}}.
stateDead
)
// isComment is true for any state that contains content meant for template
// authors & maintainers, not for end-users or machines.
func isComment(s state) bool {
switch s {
case stateHTMLCmt, stateJSBlockCmt, stateJSLineCmt, stateCSSBlockCmt, stateCSSLineCmt:
return true
}
return false
}
// isInTag reports whether s occurs solely inside an HTML tag.
func isInTag(s state) bool {
switch s {
case stateTag, stateAttrName, stateAfterName, stateBeforeValue, stateAttr:
return true
}
return false
}
// delim is the delimiter that will end the current HTML attribute.
type delim uint8
//go:generate stringer -type delim
const (
// delimNone occurs outside any attribute.
delimNone delim = iota
// delimDoubleQuote occurs when a double quote (") closes the attribute.
delimDoubleQuote
// delimSingleQuote occurs when a single quote (') closes the attribute.
delimSingleQuote
// delimSpaceOrTagEnd occurs when a space or right angle bracket (>)
// closes the attribute.
delimSpaceOrTagEnd
)
// urlPart identifies a part in an RFC 3986 hierarchical URL to allow different
// encoding strategies.
type urlPart uint8
//go:generate stringer -type urlPart
const (
// urlPartNone occurs when not in a URL, or possibly at the start:
// ^ in "^http://auth/path?k=v#frag".
urlPartNone urlPart = iota
// urlPartPreQuery occurs in the scheme, authority, or path; between the
// ^s in "h^ttp://auth/path^?k=v#frag".
urlPartPreQuery
// urlPartQueryOrFrag occurs in the query portion between the ^s in
// "http://auth/path?^k=v#frag^".
urlPartQueryOrFrag
// urlPartUnknown occurs due to joining of contexts both before and
// after the query separator.
urlPartUnknown
)
// jsCtx determines whether a '/' starts a regular expression literal or a
// division operator.
type jsCtx uint8
//go:generate stringer -type jsCtx
const (
// jsCtxRegexp occurs where a '/' would start a regexp literal.
jsCtxRegexp jsCtx = iota
// jsCtxDivOp occurs where a '/' would start a division operator.
jsCtxDivOp
// jsCtxUnknown occurs where a '/' is ambiguous due to context joining.
jsCtxUnknown
)
// element identifies the HTML element when inside a start tag or special body.
// Certain HTML elements (for example <script> and <style>) have bodies that are
// treated differently from stateText, so the element type is necessary to
// transition into the correct context at the end of a tag and to identify the
// end delimiter for the body.
type element uint8
//go:generate stringer -type element
const (
// elementNone occurs outside a special tag or special element body.
elementNone element = iota
// elementScript corresponds to the raw text <script> element
// with JS MIME type or no type attribute.
elementScript
// elementStyle corresponds to the raw text <style> element.
elementStyle
// elementTextarea corresponds to the RCDATA <textarea> element.
elementTextarea
// elementTitle corresponds to the RCDATA <title> element.
elementTitle
)
//go:generate stringer -type attr
// attr identifies the current HTML attribute when inside the attribute,
// that is, starting from stateAttrName until stateTag/stateText (exclusive).
type attr uint8
const (
// attrNone corresponds to a normal attribute or no attribute.
attrNone attr = iota
// attrScript corresponds to an event handler attribute.
attrScript
// attrScriptType corresponds to the type attribute in script HTML element
attrScriptType
// attrStyle corresponds to the style attribute whose value is CSS.
attrStyle
// attrURL corresponds to an attribute whose value is a URL.
attrURL
// attrSrcset corresponds to a srcset attribute.
attrSrcset
)
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package template
import (
"bytes"
"fmt"
"strings"
"unicode"
"unicode/utf8"
)
// endsWithCSSKeyword reports whether b ends with an ident that
// case-insensitively matches the lower-case kw.
func endsWithCSSKeyword(b []byte, kw string) bool {
i := len(b) - len(kw)
if i < 0 {
// Too short.
return false
}
if i != 0 {
r, _ := utf8.DecodeLastRune(b[:i])
if isCSSNmchar(r) {
// Too long.
return false
}
}
// Many CSS keywords, such as "!important" can have characters encoded,
// but the URI production does not allow that according to
// https://www.w3.org/TR/css3-syntax/#TOK-URI
// This does not attempt to recognize encoded keywords. For example,
// given "\75\72\6c" and "url" this return false.
return string(bytes.ToLower(b[i:])) == kw
}
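// For example (not part of the original source):
//
// endsWithCSSKeyword([]byte("10px !Important"), "important") == true
// endsWithCSSKeyword([]byte("unimportant"), "important")     == false // preceded by an nmchar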
// isCSSNmchar reports whether rune is allowed anywhere in a CSS identifier.
func isCSSNmchar(r rune) bool {
// Based on the CSS3 nmchar production but ignores multi-rune escape
// sequences.
// https://www.w3.org/TR/css3-syntax/#SUBTOK-nmchar
return 'a' <= r && r <= 'z' ||
'A' <= r && r <= 'Z' ||
'0' <= r && r <= '9' ||
r == '-' ||
r == '_' ||
// Non-ASCII cases below.
0x80 <= r && r <= 0xd7ff ||
0xe000 <= r && r <= 0xfffd ||
0x10000 <= r && r <= 0x10ffff
}
// decodeCSS decodes CSS3 escapes given a sequence of stringchars.
// If there is no change, it returns the input, otherwise it returns a slice
// backed by a new array.
// https://www.w3.org/TR/css3-syntax/#SUBTOK-stringchar defines stringchar.
func decodeCSS(s []byte) []byte {
i := bytes.IndexByte(s, '\\')
if i == -1 {
return s
}
// The UTF-8 sequence for a codepoint is never longer than 1 + the
// number of hex digits needed to represent that codepoint, so len(s) is an
// upper bound on the output length.
b := make([]byte, 0, len(s))
for len(s) != 0 {
i := bytes.IndexByte(s, '\\')
if i == -1 {
i = len(s)
}
b, s = append(b, s[:i]...), s[i:]
if len(s) < 2 {
break
}
// https://www.w3.org/TR/css3-syntax/#SUBTOK-escape
// escape ::= unicode | '\' [#x20-#x7E#x80-#xD7FF#xE000-#xFFFD#x10000-#x10FFFF]
if isHex(s[1]) {
// https://www.w3.org/TR/css3-syntax/#SUBTOK-unicode
// unicode ::= '\' [0-9a-fA-F]{1,6} wc?
j := 2
for j < len(s) && j < 7 && isHex(s[j]) {
j++
}
r := hexDecode(s[1:j])
if r > unicode.MaxRune {
r, j = r/16, j-1
}
n := utf8.EncodeRune(b[len(b):cap(b)], r)
// The optional space at the end allows a hex
// sequence to be followed by a literal hex.
// string(decodeCSS([]byte(`\A B`))) == "\nB"
b, s = b[:len(b)+n], skipCSSSpace(s[j:])
} else {
// `\\` decodes to `\` and `\"` to `"`.
_, n := utf8.DecodeRune(s[1:])
b, s = append(b, s[1:1+n]...), s[1+n:]
}
}
return b
}
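// For example (not part of the original source):
//
// string(decodeCSS([]byte(`\7b x\7d`))) == "{x}" // the space terminates the first hex escape
// string(decodeCSS([]byte(`\\`))) == `\`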
// isHex reports whether the given character is a hex digit.
func isHex(c byte) bool {
return '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F'
}
// hexDecode decodes a short hex digit sequence: "10" -> 16.
func hexDecode(s []byte) rune {
n := '\x00'
for _, c := range s {
n <<= 4
switch {
case '0' <= c && c <= '9':
n |= rune(c - '0')
case 'a' <= c && c <= 'f':
n |= rune(c-'a') + 10
case 'A' <= c && c <= 'F':
n |= rune(c-'A') + 10
default:
panic(fmt.Sprintf("Bad hex digit in %q", s))
}
}
return n
}
// skipCSSSpace returns a suffix of c, skipping over a single space.
func skipCSSSpace(c []byte) []byte {
if len(c) == 0 {
return c
}
// wc ::= #x9 | #xA | #xC | #xD | #x20
switch c[0] {
case '\t', '\n', '\f', ' ':
return c[1:]
case '\r':
// This differs from CSS3's wc production because it contains a
// probable spec error whereby wc contains all the single byte
// sequences in nl (newline) but not CRLF.
if len(c) >= 2 && c[1] == '\n' {
return c[2:]
}
return c[1:]
}
return c
}
// isCSSSpace reports whether b is a CSS space char as defined in wc.
func isCSSSpace(b byte) bool {
switch b {
case '\t', '\n', '\f', '\r', ' ':
return true
}
return false
}
// cssEscaper escapes HTML and CSS special characters using \<hex>+ escapes.
func cssEscaper(args ...any) string {
s, _ := stringify(args...)
var b strings.Builder
r, w, written := rune(0), 0, 0
for i := 0; i < len(s); i += w {
// See comment in htmlEscaper.
r, w = utf8.DecodeRuneInString(s[i:])
var repl string
switch {
case int(r) < len(cssReplacementTable) && cssReplacementTable[r] != "":
repl = cssReplacementTable[r]
default:
continue
}
if written == 0 {
b.Grow(len(s))
}
b.WriteString(s[written:i])
b.WriteString(repl)
written = i + w
if repl != `\\` && (written == len(s) || isHex(s[written]) || isCSSSpace(s[written])) {
b.WriteByte(' ')
}
}
if written == 0 {
return s
}
b.WriteString(s[written:])
return b.String()
}
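// For example (not part of the original source):
//
// cssEscaper("a<b") == `a\3c b` // trailing space added because 'b' is a hex digit
// cssEscaper("x;y") == `x\3by`  // no space needed before 'y'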
var cssReplacementTable = []string{
0: `\0`,
'\t': `\9`,
'\n': `\a`,
'\f': `\c`,
'\r': `\d`,
// Encode HTML specials as hex so the output can be embedded
// in HTML attributes without further encoding.
'"': `\22`,
'&': `\26`,
'\'': `\27`,
'(': `\28`,
')': `\29`,
'+': `\2b`,
'/': `\2f`,
':': `\3a`,
';': `\3b`,
'<': `\3c`,
'>': `\3e`,
'\\': `\\`,
'{': `\7b`,
'}': `\7d`,
}
var expressionBytes = []byte("expression")
var mozBindingBytes = []byte("mozbinding")
// cssValueFilter allows innocuous CSS values in the output including CSS
// quantities (10px or 25%), ID or class literals (#foo, .bar), keyword values
// (inherit, blue), and colors (#888).
// It filters out unsafe values, such as those that affect token boundaries,
// and anything that might execute scripts.
func cssValueFilter(args ...any) string {
s, t := stringify(args...)
if t == contentTypeCSS {
return s
}
b, id := decodeCSS([]byte(s)), make([]byte, 0, 64)
// CSS3 error handling is specified as honoring string boundaries per
// https://www.w3.org/TR/css3-syntax/#error-handling :
// Malformed declarations. User agents must handle unexpected
// tokens encountered while parsing a declaration by reading until
// the end of the declaration, while observing the rules for
// matching pairs of (), [], {}, "", and '', and correctly handling
// escapes. For example, a malformed declaration may be missing a
// property, colon (:) or value.
// So we need to make sure that values do not have mismatched bracket
// or quote characters to prevent the browser from restarting parsing
// inside a string that might embed JavaScript source.
for i, c := range b {
switch c {
case 0, '"', '\'', '(', ')', '/', ';', '@', '[', '\\', ']', '`', '{', '}':
return filterFailsafe
case '-':
// Disallow <!-- or -->.
// -- should not appear in valid identifiers.
if i != 0 && b[i-1] == '-' {
return filterFailsafe
}
default:
if c < utf8.RuneSelf && isCSSNmchar(rune(c)) {
id = append(id, c)
}
}
}
id = bytes.ToLower(id)
if bytes.Contains(id, expressionBytes) || bytes.Contains(id, mozBindingBytes) {
return filterFailsafe
}
return string(b)
}
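// For example (not part of the original source):
//
// cssValueFilter("blue")                    == "blue"
// cssValueFilter("#88ff00")                 == "#88ff00"
// cssValueFilter("expression(alert(1337))") == "ZgotmplZ" // '(' can break token boundaries
// cssValueFilter(CSS("url(/x)"))            == "url(/x)"  // typed CSS passes through unfiltered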
// Code generated by "stringer -type delim"; DO NOT EDIT.
package template
import "strconv"
const _delim_name = "delimNonedelimDoubleQuotedelimSingleQuotedelimSpaceOrTagEnd"
var _delim_index = [...]uint8{0, 9, 25, 41, 59}
func (i delim) String() string {
if i >= delim(len(_delim_index)-1) {
return "delim(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _delim_name[_delim_index[i]:_delim_index[i+1]]
}
// Code generated by "stringer -type element"; DO NOT EDIT.
package template
import "strconv"
const _element_name = "elementNoneelementScriptelementStyleelementTextareaelementTitle"
var _element_index = [...]uint8{0, 11, 24, 36, 51, 63}
func (i element) String() string {
if i >= element(len(_element_index)-1) {
return "element(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _element_name[_element_index[i]:_element_index[i+1]]
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package template
import (
"fmt"
"text/template/parse"
)
// Error describes a problem encountered during template Escaping.
type Error struct {
// ErrorCode describes the kind of error.
ErrorCode ErrorCode
// Node is the node that caused the problem, if known.
// If not nil, it overrides Name and Line.
Node parse.Node
// Name is the name of the template in which the error was encountered.
Name string
// Line is the line number of the error in the template source or 0.
Line int
// Description is a human-readable description of the problem.
Description string
}
// ErrorCode is a code for a kind of error.
type ErrorCode int
// We define codes for each error that manifests while escaping templates, but
// escaped templates may also fail at runtime.
//
// Output: "ZgotmplZ"
// Example:
//
// <img src="{{.X}}">
// where {{.X}} evaluates to `javascript:...`
//
// Discussion:
//
// "ZgotmplZ" is a special value that indicates that unsafe content reached a
// CSS or URL context at runtime. The output of the example will be
// <img src="#ZgotmplZ">
// If the data comes from a trusted source, use content types to exempt it
// from filtering: URL(`javascript:...`).
const (
// OK indicates the lack of an error.
OK ErrorCode = iota
// ErrAmbigContext: "... appears in an ambiguous context within a URL"
// Example:
// <a href="
// {{if .C}}
// /path/
// {{else}}
// /search?q=
// {{end}}
// {{.X}}
// ">
// Discussion:
// {{.X}} is in an ambiguous URL context since, depending on {{.C}},
// it may be either a URL suffix or a query parameter.
// Moving {{.X}} into the condition removes the ambiguity:
// <a href="{{if .C}}/path/{{.X}}{{else}}/search?q={{.X}}">
ErrAmbigContext
// ErrBadHTML: "expected space, attr name, or end of tag, but got ...",
// "... in unquoted attr", "... in attribute name"
// Example:
// <a href = /search?q=foo>
// <href=foo>
// <form na<e=...>
// <option selected<
// Discussion:
// This is often due to a typo in an HTML element, but some runes
// are banned in tag names, attribute names, and unquoted attribute
// values because they can tickle parser ambiguities.
// Quoting all attributes is the best policy.
ErrBadHTML
// ErrBranchEnd: "{{if}} branches end in different contexts"
// Example:
// {{if .C}}<a href="{{end}}{{.X}}
// Discussion:
// Package html/template statically examines each path through an
// {{if}}, {{range}}, or {{with}} to escape any following pipelines.
// The example is ambiguous since {{.X}} might be an HTML text node,
// or a URL prefix in an HTML attribute. The context of {{.X}} is
// used to figure out how to escape it, but that context depends on
// the run-time value of {{.C}} which is not statically known.
//
// The problem is usually something like missing quotes or angle
// brackets, or can be avoided by refactoring to put the two contexts
// into different branches of an if, range or with. If the problem
// is in a {{range}} over a collection that should never be empty,
// adding a dummy {{else}} can help.
ErrBranchEnd
// ErrEndContext: "... ends in a non-text context: ..."
// Examples:
// <div
// <div title="no close quote>
// <script>f()
// Discussion:
// Executed templates should produce a DocumentFragment of HTML.
// Templates that end without closing tags will trigger this error.
// Templates that should not be used in an HTML context or that
// produce incomplete Fragments should not be executed directly.
//
// {{define "main"}} <script>{{template "helper"}}</script> {{end}}
// {{define "helper"}} document.write(' <div title=" ') {{end}}
//
// "helper" does not produce a valid document fragment, so should
// not be Executed directly.
ErrEndContext
// ErrNoSuchTemplate: "no such template ..."
// Examples:
// {{define "main"}}<div {{template "attrs"}}>{{end}}
// {{define "attrs"}}href="{{.URL}}"{{end}}
// Discussion:
// Package html/template looks through template calls to compute the
// context.
// Here the {{.URL}} in "attrs" must be treated as a URL when called
// from "main", but you will get this error if "attrs" is not defined
// when "main" is parsed.
ErrNoSuchTemplate
// ErrOutputContext: "cannot compute output context for template ..."
// Examples:
// {{define "t"}}{{if .T}}{{template "t" .T}}{{end}}{{.H}}",{{end}}
// Discussion:
// A recursive template does not end in the same context in which it
// starts, and a reliable output context cannot be computed.
// Look for typos in the named template.
// If the template should not be called in the named start context,
// look for calls to that template in unexpected contexts.
// Maybe refactor recursive templates to not be recursive.
ErrOutputContext
// ErrPartialCharset: "unfinished JS regexp charset in ..."
// Example:
// <script>var pattern = /foo[{{.Chars}}]/</script>
// Discussion:
// Package html/template does not support interpolation into regular
// expression literal character sets.
ErrPartialCharset
// ErrPartialEscape: "unfinished escape sequence in ..."
// Example:
// <script>alert("\{{.X}}")</script>
// Discussion:
// Package html/template does not support actions following a
// backslash.
// This is usually an error and there are better solutions; for
// example
// <script>alert("{{.X}}")</script>
// should work, and if {{.X}} is a partial escape sequence such as
// "xA0", mark the whole sequence as safe content: JSStr(`\xA0`)
ErrPartialEscape
// ErrRangeLoopReentry: "on range loop re-entry: ..."
// Example:
// <script>var x = [{{range .}}'{{.}},{{end}}]</script>
// Discussion:
// If an iteration through a range would cause it to end in a
// different context than an earlier pass, there is no single context.
// In the example, a close quote is missing, so it is not clear
// whether {{.}} is meant to be inside a JS string or in a JS value
// context. The second iteration would produce something like
//
// <script>var x = ['firstValue,'secondValue]</script>
ErrRangeLoopReentry
// ErrSlashAmbig: '/' could start a division or regexp.
// Example:
// <script>
// {{if .C}}var x = 1{{end}}
// /-{{.N}}/i.test(x) ? doThis : doThat();
// </script>
// Discussion:
// The example above could produce `var x = 1/-2/i.test(s)...`
// in which the first '/' is a mathematical division operator or it
// could produce `/-2/i.test(s)` in which the first '/' starts a
// regexp literal.
// Look for missing semicolons inside branches, and maybe add
// parentheses to make it clear which interpretation you intend.
ErrSlashAmbig
// ErrPredefinedEscaper: "predefined escaper ... disallowed in template"
// Example:
// <div class={{. | html}}>Hello<div>
// Discussion:
// Package html/template already contextually escapes all pipelines to
// produce HTML output safe against code injection. Manually escaping
// pipeline output using the predefined escapers "html" or "urlquery" is
// unnecessary, and may affect the correctness or safety of the escaped
// pipeline output in Go 1.8 and earlier.
//
// In most cases, such as the given example, this error can be resolved by
// simply removing the predefined escaper from the pipeline and letting the
// contextual autoescaper handle the escaping of the pipeline. In other
// instances, where the predefined escaper occurs in the middle of a
// pipeline where subsequent commands expect escaped input, e.g.
// {{.X | html | makeALink}}
// where makeALink does
// return `<a href="`+input+`">link</a>`
// consider refactoring the surrounding template to make use of the
// contextual autoescaper, i.e.
// <a href="{{.X}}">link</a>
//
// To ease migration to Go 1.9 and beyond, "html" and "urlquery" will
// continue to be allowed as the last command in a pipeline. However, if the
// pipeline occurs in an unquoted attribute value context, "html" is
// disallowed. Avoid using "html" and "urlquery" entirely in new templates.
ErrPredefinedEscaper
)
func (e *Error) Error() string {
switch {
case e.Node != nil:
loc, _ := (*parse.Tree)(nil).ErrorContext(e.Node)
return fmt.Sprintf("html/template:%s: %s", loc, e.Description)
case e.Line != 0:
return fmt.Sprintf("html/template:%s:%d: %s", e.Name, e.Line, e.Description)
case e.Name != "":
return fmt.Sprintf("html/template:%s: %s", e.Name, e.Description)
}
return "html/template: " + e.Description
}
// errorf creates an error given a format string f and args.
// The template Name still needs to be supplied.
func errorf(k ErrorCode, node parse.Node, line int, f string, args ...any) *Error {
return &Error{k, node, "", line, fmt.Sprintf(f, args...)}
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package template
import (
"bytes"
"fmt"
"html"
"io"
"text/template"
"text/template/parse"
)
// escapeTemplate rewrites the named template, which must be
// associated with t, to guarantee that the output of any of the named
// templates is properly escaped. If no error is returned, then the named templates have
// been modified. Otherwise the named templates have been rendered
// unusable.
func escapeTemplate(tmpl *Template, node parse.Node, name string) error {
c, _ := tmpl.esc.escapeTree(context{}, node, name, 0)
var err error
if c.err != nil {
err, c.err.Name = c.err, name
} else if c.state != stateText {
err = &Error{ErrEndContext, nil, name, 0, fmt.Sprintf("ends in a non-text context: %v", c)}
}
if err != nil {
// Prevent execution of unsafe templates.
if t := tmpl.set[name]; t != nil {
t.escapeErr = err
t.text.Tree = nil
t.Tree = nil
}
return err
}
tmpl.esc.commit()
if t := tmpl.set[name]; t != nil {
t.escapeErr = escapeOK
t.Tree = t.text.Tree
}
return nil
}
// evalArgs formats the list of arguments into a string. It is equivalent to
// fmt.Sprint(args...), except that it dereferences all pointers.
func evalArgs(args ...any) string {
// Optimization for simple common case of a single string argument.
if len(args) == 1 {
if s, ok := args[0].(string); ok {
return s
}
}
for i, arg := range args {
args[i] = indirectToStringerOrError(arg)
}
return fmt.Sprint(args...)
}
// funcMap maps command names to functions that render their inputs safe.
var funcMap = template.FuncMap{
"_html_template_attrescaper": attrEscaper,
"_html_template_commentescaper": commentEscaper,
"_html_template_cssescaper": cssEscaper,
"_html_template_cssvaluefilter": cssValueFilter,
"_html_template_htmlnamefilter": htmlNameFilter,
"_html_template_htmlescaper": htmlEscaper,
"_html_template_jsregexpescaper": jsRegexpEscaper,
"_html_template_jsstrescaper": jsStrEscaper,
"_html_template_jsvalescaper": jsValEscaper,
"_html_template_nospaceescaper": htmlNospaceEscaper,
"_html_template_rcdataescaper": rcdataEscaper,
"_html_template_srcsetescaper": srcsetFilterAndEscaper,
"_html_template_urlescaper": urlEscaper,
"_html_template_urlfilter": urlFilter,
"_html_template_urlnormalizer": urlNormalizer,
"_eval_args_": evalArgs,
}
// escaper collects type inferences about templates and changes needed to make
// templates injection safe.
type escaper struct {
// ns is the nameSpace that this escaper is associated with.
ns *nameSpace
// output[templateName] is the output context for a templateName that
// has been mangled to include its input context.
output map[string]context
// derived[c.mangle(name)] maps to a template derived from the template
// named name, specialized for the start context c.
derived map[string]*template.Template
// called[templateName] is a set of called mangled template names.
called map[string]bool
// xxxNodeEdits are the accumulated edits to apply during commit.
// Such edits are not applied immediately in case a template set
// executes a given template in different escaping contexts.
actionNodeEdits map[*parse.ActionNode][]string
templateNodeEdits map[*parse.TemplateNode]string
textNodeEdits map[*parse.TextNode][]byte
// rangeContext holds context about the current range loop.
rangeContext *rangeContext
}
// rangeContext holds information about the current range loop.
type rangeContext struct {
outer *rangeContext // outer loop
breaks []context // context at each break action
continues []context // context at each continue action
}
// makeEscaper creates a blank escaper for the given set.
func makeEscaper(n *nameSpace) escaper {
return escaper{
n,
map[string]context{},
map[string]*template.Template{},
map[string]bool{},
map[*parse.ActionNode][]string{},
map[*parse.TemplateNode]string{},
map[*parse.TextNode][]byte{},
nil,
}
}
// filterFailsafe is an innocuous word that is emitted in place of unsafe values
// by sanitizer functions. It is not a keyword in any programming language,
// contains no special characters, is not empty, and when it appears in output
// it is distinct enough that a developer can find the source of the problem
// via a search engine.
const filterFailsafe = "ZgotmplZ"
// escape escapes a template node.
func (e *escaper) escape(c context, n parse.Node) context {
switch n := n.(type) {
case *parse.ActionNode:
return e.escapeAction(c, n)
case *parse.BreakNode:
c.n = n
e.rangeContext.breaks = append(e.rangeContext.breaks, c)
return context{state: stateDead}
case *parse.CommentNode:
return c
case *parse.ContinueNode:
c.n = n
e.rangeContext.continues = append(e.rangeContext.continues, c)
return context{state: stateDead}
case *parse.IfNode:
return e.escapeBranch(c, &n.BranchNode, "if")
case *parse.ListNode:
return e.escapeList(c, n)
case *parse.RangeNode:
return e.escapeBranch(c, &n.BranchNode, "range")
case *parse.TemplateNode:
return e.escapeTemplate(c, n)
case *parse.TextNode:
return e.escapeText(c, n)
case *parse.WithNode:
return e.escapeBranch(c, &n.BranchNode, "with")
}
panic("escaping " + n.String() + " is unimplemented")
}
// escapeAction escapes an action template node.
func (e *escaper) escapeAction(c context, n *parse.ActionNode) context {
if len(n.Pipe.Decl) != 0 {
// A local variable assignment, not an interpolation.
return c
}
c = nudge(c)
// Check for disallowed use of predefined escapers in the pipeline.
for pos, idNode := range n.Pipe.Cmds {
node, ok := idNode.Args[0].(*parse.IdentifierNode)
if !ok {
// A predefined escaper "esc" will never be found as an identifier in a
// Chain or Field node, since:
// - "esc.x ..." is invalid, since predefined escapers return strings, and
// strings do not have methods, keys or fields.
// - "... .esc" is invalid, since predefined escapers are global functions,
// not methods or fields of any types.
// Therefore, it is safe to ignore these two node types.
continue
}
ident := node.Ident
if _, ok := predefinedEscapers[ident]; ok {
if pos < len(n.Pipe.Cmds)-1 ||
c.state == stateAttr && c.delim == delimSpaceOrTagEnd && ident == "html" {
return context{
state: stateError,
err: errorf(ErrPredefinedEscaper, n, n.Line, "predefined escaper %q disallowed in template", ident),
}
}
}
}
s := make([]string, 0, 3)
switch c.state {
case stateError:
return c
case stateURL, stateCSSDqStr, stateCSSSqStr, stateCSSDqURL, stateCSSSqURL, stateCSSURL:
switch c.urlPart {
case urlPartNone:
s = append(s, "_html_template_urlfilter")
fallthrough
case urlPartPreQuery:
switch c.state {
case stateCSSDqStr, stateCSSSqStr:
s = append(s, "_html_template_cssescaper")
default:
s = append(s, "_html_template_urlnormalizer")
}
case urlPartQueryOrFrag:
s = append(s, "_html_template_urlescaper")
case urlPartUnknown:
return context{
state: stateError,
err: errorf(ErrAmbigContext, n, n.Line, "%s appears in an ambiguous context within a URL", n),
}
default:
panic(c.urlPart.String())
}
case stateJS:
s = append(s, "_html_template_jsvalescaper")
// A slash after a value starts a div operator.
c.jsCtx = jsCtxDivOp
case stateJSDqStr, stateJSSqStr:
s = append(s, "_html_template_jsstrescaper")
case stateJSRegexp:
s = append(s, "_html_template_jsregexpescaper")
case stateCSS:
s = append(s, "_html_template_cssvaluefilter")
case stateText:
s = append(s, "_html_template_htmlescaper")
case stateRCDATA:
s = append(s, "_html_template_rcdataescaper")
case stateAttr:
// Handled below in delim check.
case stateAttrName, stateTag:
c.state = stateAttrName
s = append(s, "_html_template_htmlnamefilter")
case stateSrcset:
s = append(s, "_html_template_srcsetescaper")
default:
if isComment(c.state) {
s = append(s, "_html_template_commentescaper")
} else {
panic("unexpected state " + c.state.String())
}
}
switch c.delim {
case delimNone:
// No extra-escaping needed for raw text content.
case delimSpaceOrTagEnd:
s = append(s, "_html_template_nospaceescaper")
default:
s = append(s, "_html_template_attrescaper")
}
e.editActionNode(n, s)
return c
}
// ensurePipelineContains ensures that the pipeline ends with the commands with
// the identifiers in s in order. If the pipeline ends with a predefined escaper
// (i.e. "html" or "urlquery"), merge it with the identifiers in s.
func ensurePipelineContains(p *parse.PipeNode, s []string) {
if len(s) == 0 {
// Do not rewrite pipeline if we have no escapers to insert.
return
}
// Precondition: p.Cmds contains at most one predefined escaper and the
// escaper will be present at p.Cmds[len(p.Cmds)-1]. This precondition is
// always true because of the checks in escapeAction.
pipelineLen := len(p.Cmds)
if pipelineLen > 0 {
lastCmd := p.Cmds[pipelineLen-1]
if idNode, ok := lastCmd.Args[0].(*parse.IdentifierNode); ok {
if esc := idNode.Ident; predefinedEscapers[esc] {
// Pipeline ends with a predefined escaper.
if len(p.Cmds) == 1 && len(lastCmd.Args) > 1 {
// Special case: pipeline is of the form {{ esc arg1 arg2 ... argN }},
// where esc is the predefined escaper, and arg1...argN are its arguments.
// Convert this into the equivalent form
// {{ _eval_args_ arg1 arg2 ... argN | esc }}, so that esc can be easily
// merged with the escapers in s.
lastCmd.Args[0] = parse.NewIdentifier("_eval_args_").SetTree(nil).SetPos(lastCmd.Args[0].Position())
p.Cmds = appendCmd(p.Cmds, newIdentCmd(esc, p.Position()))
pipelineLen++
}
// If any of the commands in s that we are about to insert is equivalent
// to the predefined escaper, use the predefined escaper instead.
dup := false
for i, escaper := range s {
if escFnsEq(esc, escaper) {
s[i] = idNode.Ident
dup = true
}
}
if dup {
// The predefined escaper will already be inserted along with the
// escapers in s, so do not copy it to the rewritten pipeline.
pipelineLen--
}
}
}
}
// Rewrite the pipeline, creating the escapers in s at the end of the pipeline.
newCmds := make([]*parse.CommandNode, pipelineLen, pipelineLen+len(s))
insertedIdents := make(map[string]bool)
for i := 0; i < pipelineLen; i++ {
cmd := p.Cmds[i]
newCmds[i] = cmd
if idNode, ok := cmd.Args[0].(*parse.IdentifierNode); ok {
insertedIdents[normalizeEscFn(idNode.Ident)] = true
}
}
for _, name := range s {
if !insertedIdents[normalizeEscFn(name)] {
// When two templates share an underlying parse tree via the use of
// AddParseTree and one template is executed after the other, this check
// ensures that escapers that were already inserted into the pipeline on
// the first escaping pass do not get inserted again.
newCmds = appendCmd(newCmds, newIdentCmd(name, p.Position()))
}
}
p.Cmds = newCmds
}
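// For example (an illustrative note, not part of the original source), given
// {{.X | html}} in a text context, s is ["_html_template_htmlescaper"]; since
// that escaper is equivalent to the predefined "html", the existing command is
// reused and the rewritten pipeline remains {{.X | html}} rather than being
// double-escaped.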
// predefinedEscapers contains template predefined escapers that are equivalent
// to some contextual escapers. Keep in sync with equivEscapers.
var predefinedEscapers = map[string]bool{
"html": true,
"urlquery": true,
}
// equivEscapers matches contextual escapers to equivalent predefined
// template escapers.
var equivEscapers = map[string]string{
// The following pairs of HTML escapers provide equivalent security
// guarantees, since they all escape '\000', '\'', '"', '&', '<', and '>'.
"_html_template_attrescaper": "html",
"_html_template_htmlescaper": "html",
"_html_template_rcdataescaper": "html",
// These two URL escapers produce URLs safe for embedding in a URL query by
// percent-encoding all the reserved characters specified in RFC 3986 Section
// 2.2
"_html_template_urlescaper": "urlquery",
// These two functions are not actually equivalent; urlquery is stricter as it
// escapes reserved characters (e.g. '#'), while _html_template_urlnormalizer
// does not. It is therefore only safe to replace _html_template_urlnormalizer
// with urlquery (this happens in ensurePipelineContains), but not the other
// way around. We keep this entry around to preserve the behavior of templates
// written before Go 1.9, which might depend on this substitution taking place.
"_html_template_urlnormalizer": "urlquery",
}
// escFnsEq reports whether the two escaping functions are equivalent.
func escFnsEq(a, b string) bool {
return normalizeEscFn(a) == normalizeEscFn(b)
}
// normalizeEscFn(a) is equal to normalizeEscFn(b) for any pair of names of
// escaper functions a and b that are equivalent.
func normalizeEscFn(e string) string {
if norm := equivEscapers[e]; norm != "" {
return norm
}
return e
}
// redundantFuncs[a][b] implies that funcMap[b](funcMap[a](x)) == funcMap[a](x)
// for all x.
var redundantFuncs = map[string]map[string]bool{
"_html_template_commentescaper": {
"_html_template_attrescaper": true,
"_html_template_nospaceescaper": true,
"_html_template_htmlescaper": true,
},
"_html_template_cssescaper": {
"_html_template_attrescaper": true,
},
"_html_template_jsregexpescaper": {
"_html_template_attrescaper": true,
},
"_html_template_jsstrescaper": {
"_html_template_attrescaper": true,
},
"_html_template_urlescaper": {
"_html_template_urlnormalizer": true,
},
}
// appendCmd appends the given command to the end of the command pipeline
// unless it is redundant with the last command.
func appendCmd(cmds []*parse.CommandNode, cmd *parse.CommandNode) []*parse.CommandNode {
if n := len(cmds); n != 0 {
last, okLast := cmds[n-1].Args[0].(*parse.IdentifierNode)
next, okNext := cmd.Args[0].(*parse.IdentifierNode)
if okLast && okNext && redundantFuncs[last.Ident][next.Ident] {
return cmds
}
}
return append(cmds, cmd)
}
// newIdentCmd produces a command containing a single identifier node.
func newIdentCmd(identifier string, pos parse.Pos) *parse.CommandNode {
return &parse.CommandNode{
NodeType: parse.NodeCommand,
Args: []parse.Node{parse.NewIdentifier(identifier).SetTree(nil).SetPos(pos)}, // TODO: SetTree.
}
}
// nudge returns the context that would result from following empty string
// transitions from the input context.
// For example, parsing:
//
// `<a href=`
//
// will end in context{stateBeforeValue, attrURL}, but parsing one extra rune:
//
// `<a href=x`
//
// will end in context{stateURL, delimSpaceOrTagEnd, ...}.
// There are two transitions that happen when the 'x' is seen:
// (1) Transition from a before-value state to a start-of-value state without
// consuming any character.
// (2) Consume 'x' and transition past the first value character.
// In this case, nudging produces the context after (1) happens.
func nudge(c context) context {
switch c.state {
case stateTag:
// In `<foo {{.}}`, the action should emit an attribute.
c.state = stateAttrName
case stateBeforeValue:
// In `<foo bar={{.}}`, the action is an undelimited value.
c.state, c.delim, c.attr = attrStartStates[c.attr], delimSpaceOrTagEnd, attrNone
case stateAfterName:
// In `<foo bar {{.}}`, the action is an attribute name.
c.state, c.attr = stateAttrName, attrNone
}
return c
}
// join joins the two contexts of a branch template node. The result is an
// error context if either of the input contexts are error contexts, or if the
// input contexts differ.
func join(a, b context, node parse.Node, nodeName string) context {
if a.state == stateError {
return a
}
if b.state == stateError {
return b
}
if a.state == stateDead {
return b
}
if b.state == stateDead {
return a
}
if a.eq(b) {
return a
}
c := a
c.urlPart = b.urlPart
if c.eq(b) {
// The contexts differ only by urlPart.
c.urlPart = urlPartUnknown
return c
}
c = a
c.jsCtx = b.jsCtx
if c.eq(b) {
// The contexts differ only by jsCtx.
c.jsCtx = jsCtxUnknown
return c
}
// Allow a nudged context to join with an unnudged one.
// This means that
// <p title={{if .C}}{{.}}{{end}}
// ends in an unquoted value state even though the else branch
// ends in stateBeforeValue.
if c, d := nudge(a), nudge(b); !(c.eq(a) && d.eq(b)) {
if e := join(c, d, node, nodeName); e.state != stateError {
return e
}
}
return context{
state: stateError,
err: errorf(ErrBranchEnd, node, 0, "{{%s}} branches end in different contexts: %v, %v", nodeName, a, b),
}
}
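// For example (not part of the original source), joining the branches of
//
// <a href="/{{if .C}}a{{else}}b?q={{end}}{{.X}}">
//
// yields contexts that differ only in urlPart, so the join succeeds with
// urlPartUnknown; escapeAction then rejects {{.X}} with ErrAmbigContext.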
// escapeBranch escapes a branch template node: "if", "range" and "with".
func (e *escaper) escapeBranch(c context, n *parse.BranchNode, nodeName string) context {
if nodeName == "range" {
e.rangeContext = &rangeContext{outer: e.rangeContext}
}
c0 := e.escapeList(c, n.List)
if nodeName == "range" {
if c0.state != stateError {
c0 = joinRange(c0, e.rangeContext)
}
e.rangeContext = e.rangeContext.outer
if c0.state == stateError {
return c0
}
// The "true" branch of a "range" node can execute multiple times.
// We check that executing n.List once results in the same context
// as executing n.List twice.
e.rangeContext = &rangeContext{outer: e.rangeContext}
c1, _ := e.escapeListConditionally(c0, n.List, nil)
c0 = join(c0, c1, n, nodeName)
if c0.state == stateError {
e.rangeContext = e.rangeContext.outer
// Make clear that this is a problem on loop re-entry
// since developers tend to overlook that branch when
// debugging templates.
c0.err.Line = n.Line
c0.err.Description = "on range loop re-entry: " + c0.err.Description
return c0
}
c0 = joinRange(c0, e.rangeContext)
e.rangeContext = e.rangeContext.outer
if c0.state == stateError {
return c0
}
}
c1 := e.escapeList(c, n.ElseList)
return join(c0, c1, n, nodeName)
}
func joinRange(c0 context, rc *rangeContext) context {
// Merge contexts at break and continue statements into overall body context.
// In theory we could treat breaks differently from continues, but for now it is
// enough to treat them both as going back to the start of the loop (which may then stop).
for _, c := range rc.breaks {
c0 = join(c0, c, c.n, "range")
if c0.state == stateError {
c0.err.Line = c.n.(*parse.BreakNode).Line
c0.err.Description = "at range loop break: " + c0.err.Description
return c0
}
}
for _, c := range rc.continues {
c0 = join(c0, c, c.n, "range")
if c0.state == stateError {
c0.err.Line = c.n.(*parse.ContinueNode).Line
c0.err.Description = "at range loop continue: " + c0.err.Description
return c0
}
}
return c0
}
// escapeList escapes a list template node.
func (e *escaper) escapeList(c context, n *parse.ListNode) context {
if n == nil {
return c
}
for _, m := range n.Nodes {
c = e.escape(c, m)
if c.state == stateDead {
break
}
}
return c
}
// escapeListConditionally escapes a list node but only preserves edits and
// inferences in e if the inferences and output context satisfy filter.
// It returns the best guess at an output context, and the result of the filter
// which is the same as whether e was updated.
func (e *escaper) escapeListConditionally(c context, n *parse.ListNode, filter func(*escaper, context) bool) (context, bool) {
e1 := makeEscaper(e.ns)
e1.rangeContext = e.rangeContext
// Make type inferences available to f.
for k, v := range e.output {
e1.output[k] = v
}
c = e1.escapeList(c, n)
ok := filter != nil && filter(&e1, c)
if ok {
// Copy inferences and edits from e1 back into e.
for k, v := range e1.output {
e.output[k] = v
}
for k, v := range e1.derived {
e.derived[k] = v
}
for k, v := range e1.called {
e.called[k] = v
}
for k, v := range e1.actionNodeEdits {
e.editActionNode(k, v)
}
for k, v := range e1.templateNodeEdits {
e.editTemplateNode(k, v)
}
for k, v := range e1.textNodeEdits {
e.editTextNode(k, v)
}
}
return c, ok
}
// escapeTemplate escapes a {{template}} call node.
func (e *escaper) escapeTemplate(c context, n *parse.TemplateNode) context {
c, name := e.escapeTree(c, n, n.Name, n.Line)
if name != n.Name {
e.editTemplateNode(n, name)
}
return c
}
// escapeTree escapes the named template starting in the given context as
// necessary and returns its output context.
func (e *escaper) escapeTree(c context, node parse.Node, name string, line int) (context, string) {
// Mangle the template name with the input context to produce a reliable
// identifier.
dname := c.mangle(name)
e.called[dname] = true
if out, ok := e.output[dname]; ok {
// Already escaped.
return out, dname
}
t := e.template(name)
if t == nil {
// Two cases: The template exists but is empty, or has never been mentioned at
// all. Distinguish the cases in the error messages.
if e.ns.set[name] != nil {
return context{
state: stateError,
err: errorf(ErrNoSuchTemplate, node, line, "%q is an incomplete or empty template", name),
}, dname
}
return context{
state: stateError,
err: errorf(ErrNoSuchTemplate, node, line, "no such template %q", name),
}, dname
}
if dname != name {
// Use any template derived during an earlier call to escapeTemplate
// with different top level templates, or clone if necessary.
dt := e.template(dname)
if dt == nil {
dt = template.New(dname)
dt.Tree = &parse.Tree{Name: dname, Root: t.Root.CopyList()}
e.derived[dname] = dt
}
t = dt
}
return e.computeOutCtx(c, t), dname
}
// computeOutCtx takes a template and its start context and computes the output
// context while storing any inferences in e.
func (e *escaper) computeOutCtx(c context, t *template.Template) context {
// Propagate context over the body.
c1, ok := e.escapeTemplateBody(c, t)
if !ok {
// Look for a fixed point by assuming c1 as the output context.
if c2, ok2 := e.escapeTemplateBody(c1, t); ok2 {
c1, ok = c2, true
}
// Use c1 as the error context if neither assumption worked.
}
if !ok && c1.state != stateError {
return context{
state: stateError,
err: errorf(ErrOutputContext, t.Tree.Root, 0, "cannot compute output context for template %s", t.Name()),
}
}
return c1
}
// escapeTemplateBody escapes the given template assuming the given output
// context, and returns the best guess at the output context and whether the
// assumption was correct.
func (e *escaper) escapeTemplateBody(c context, t *template.Template) (context, bool) {
filter := func(e1 *escaper, c1 context) bool {
if c1.state == stateError {
// Do not update the input escaper, e.
return false
}
if !e1.called[t.Name()] {
// If t is not recursively called, then c1 is an
// accurate output context.
return true
}
// c1 is accurate if it matches our assumed output context.
return c.eq(c1)
}
// We need to assume an output context so that recursive template calls
// take the fast path out of escapeTree instead of infinitely recurring.
// Naively assuming that the input context is the same as the output
// works >90% of the time.
e.output[t.Name()] = c
return e.escapeListConditionally(c, t.Tree.Root, filter)
}
// delimEnds maps each delim to a string of characters that terminate it.
var delimEnds = [...]string{
delimDoubleQuote: `"`,
delimSingleQuote: "'",
// Determined empirically by running the below in various browsers.
// var div = document.createElement("DIV");
// for (var i = 0; i < 0x10000; ++i) {
//   div.innerHTML = "<span title=x" + String.fromCharCode(i) + "-bar>";
//   if (div.getElementsByTagName("SPAN")[0].title.indexOf("bar") < 0)
//     document.write("<p>U+" + i.toString(16));
// }
delimSpaceOrTagEnd: " \t\n\f\r>",
}
var doctypeBytes = []byte("<!DOCTYPE")
// escapeText escapes a text template node.
func (e *escaper) escapeText(c context, n *parse.TextNode) context {
s, written, i, b := n.Text, 0, 0, new(bytes.Buffer)
for i != len(s) {
c1, nread := contextAfterText(c, s[i:])
i1 := i + nread
if c.state == stateText || c.state == stateRCDATA {
end := i1
if c1.state != c.state {
for j := end - 1; j >= i; j-- {
if s[j] == '<' {
end = j
break
}
}
}
for j := i; j < end; j++ {
if s[j] == '<' && !bytes.HasPrefix(bytes.ToUpper(s[j:]), doctypeBytes) {
b.Write(s[written:j])
b.WriteString("&lt;")
written = j + 1
}
}
} else if isComment(c.state) && c.delim == delimNone {
switch c.state {
case stateJSBlockCmt:
// https://es5.github.com/#x7.4:
// "Comments behave like white space and are
// discarded except that, if a MultiLineComment
// contains a line terminator character, then
// the entire comment is considered to be a
// LineTerminator for purposes of parsing by
// the syntactic grammar."
if bytes.ContainsAny(s[written:i1], "\n\r\u2028\u2029") {
b.WriteByte('\n')
} else {
b.WriteByte(' ')
}
case stateCSSBlockCmt:
b.WriteByte(' ')
}
written = i1
}
if c.state != c1.state && isComment(c1.state) && c1.delim == delimNone {
// Preserve the portion between written and the comment start.
cs := i1 - 2
if c1.state == stateHTMLCmt {
// "<!--" instead of "/*" or "//"
cs -= 2
}
b.Write(s[written:cs])
written = i1
}
if i == i1 && c.state == c1.state {
panic(fmt.Sprintf("infinite loop from %v to %v on %q..%q", c, c1, s[:i], s[i:]))
}
c, i = c1, i1
}
if written != 0 && c.state != stateError {
if !isComment(c.state) || c.delim != delimNone {
b.Write(n.Text[written:])
}
e.editTextNode(n, b.Bytes())
}
return c
}
// contextAfterText starts in context c, consumes some tokens from the front of
// s, then returns the context after those tokens and the number of bytes of s
// consumed.
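//
// For example (an illustrative sketch):
//
// c, n := contextAfterText(context{state: stateText}, []byte("Hello, <b>"))
// // c == context{state: stateTag}, n == 9: the text and tag name are consumed.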
func contextAfterText(c context, s []byte) (context, int) {
if c.delim == delimNone {
c1, i := tSpecialTagEnd(c, s)
if i == 0 {
// A special end tag (`</script>`) has been seen and
// all content preceding it has been consumed.
return c1, 0
}
// Consider all content up to any end tag.
return transitionFunc[c.state](c, s[:i])
}
// We are at the beginning of an attribute value.
i := bytes.IndexAny(s, delimEnds[c.delim])
if i == -1 {
i = len(s)
}
if c.delim == delimSpaceOrTagEnd {
// https://www.w3.org/TR/html5/syntax.html#attribute-value-(unquoted)-state
// lists the runes below as error characters.
// Error out because HTML parsers may differ on whether
// "<a id= onclick=f(" ends inside id's or onclick's value,
// "<a class=`foo " ends inside a value,
// "<a style=font:'Arial'" needs open-quote fixup.
// IE treats '`' as a quotation character.
if j := bytes.IndexAny(s[:i], "\"'<=`"); j >= 0 {
return context{
state: stateError,
err: errorf(ErrBadHTML, nil, 0, "%q in unquoted attr: %q", s[j:j+1], s[:i]),
}, len(s)
}
}
if i == len(s) {
// Remain inside the attribute.
// Decode the value so non-HTML rules can easily handle
// <button onclick="alert("Hi!")">
// without having to entity decode token boundaries.
for u := []byte(html.UnescapeString(string(s))); len(u) != 0; {
c1, i1 := transitionFunc[c.state](c, u)
c, u = c1, u[i1:]
}
return c, len(s)
}
element := c.element
// If this is a non-JS "type" attribute inside "script" tag, do not treat the contents as JS.
if c.state == stateAttr && c.element == elementScript && c.attr == attrScriptType && !isJSType(string(s[:i])) {
element = elementNone
}
if c.delim != delimSpaceOrTagEnd {
// Consume any quote.
i++
}
// On exiting an attribute, we discard all state information
// except the state and element.
return context{state: stateTag, element: element}, i
}
// editActionNode records a change to an action pipeline for later commit.
func (e *escaper) editActionNode(n *parse.ActionNode, cmds []string) {
if _, ok := e.actionNodeEdits[n]; ok {
panic(fmt.Sprintf("node %s shared between templates", n))
}
e.actionNodeEdits[n] = cmds
}
// editTemplateNode records a change to a {{template}} callee for later commit.
func (e *escaper) editTemplateNode(n *parse.TemplateNode, callee string) {
if _, ok := e.templateNodeEdits[n]; ok {
panic(fmt.Sprintf("node %s shared between templates", n))
}
e.templateNodeEdits[n] = callee
}
// editTextNode records a change to a text node for later commit.
func (e *escaper) editTextNode(n *parse.TextNode, text []byte) {
if _, ok := e.textNodeEdits[n]; ok {
panic(fmt.Sprintf("node %s shared between templates", n))
}
e.textNodeEdits[n] = text
}
// commit applies changes to actions and template calls needed to contextually
// autoescape content and adds any derived templates to the set.
func (e *escaper) commit() {
for name := range e.output {
e.template(name).Funcs(funcMap)
}
// Any template from the name space associated with this escaper can be used
// to add derived templates to the underlying text/template name space.
tmpl := e.arbitraryTemplate()
for _, t := range e.derived {
if _, err := tmpl.text.AddParseTree(t.Name(), t.Tree); err != nil {
panic("error adding derived template")
}
}
for n, s := range e.actionNodeEdits {
ensurePipelineContains(n.Pipe, s)
}
for n, name := range e.templateNodeEdits {
n.Name = name
}
for n, s := range e.textNodeEdits {
n.Text = s
}
// Reset state that is specific to this commit so that the same changes are
// not re-applied to the template on subsequent calls to commit.
e.called = make(map[string]bool)
e.actionNodeEdits = make(map[*parse.ActionNode][]string)
e.templateNodeEdits = make(map[*parse.TemplateNode]string)
e.textNodeEdits = make(map[*parse.TextNode][]byte)
}
// template returns the named template given a mangled template name.
func (e *escaper) template(name string) *template.Template {
// Any template from the name space associated with this escaper can be used
// to look up templates in the underlying text/template name space.
t := e.arbitraryTemplate().text.Lookup(name)
if t == nil {
t = e.derived[name]
}
return t
}
// arbitraryTemplate returns an arbitrary template from the name space
// associated with e and panics if no templates are found.
func (e *escaper) arbitraryTemplate() *Template {
for _, t := range e.ns.set {
return t
}
panic("no templates in name space")
}
// Forwarding functions so that clients need only import this package
// to reach the general escaping functions of text/template.
// HTMLEscape writes to w the escaped HTML equivalent of the plain text data b.
func HTMLEscape(w io.Writer, b []byte) {
template.HTMLEscape(w, b)
}
// HTMLEscapeString returns the escaped HTML equivalent of the plain text data s.
func HTMLEscapeString(s string) string {
return template.HTMLEscapeString(s)
}
// HTMLEscaper returns the escaped HTML equivalent of the textual
// representation of its arguments.
func HTMLEscaper(args ...any) string {
return template.HTMLEscaper(args...)
}
// JSEscape writes to w the escaped JavaScript equivalent of the plain text data b.
func JSEscape(w io.Writer, b []byte) {
template.JSEscape(w, b)
}
// JSEscapeString returns the escaped JavaScript equivalent of the plain text data s.
func JSEscapeString(s string) string {
return template.JSEscapeString(s)
}
// JSEscaper returns the escaped JavaScript equivalent of the textual
// representation of its arguments.
func JSEscaper(args ...any) string {
return template.JSEscaper(args...)
}
// URLQueryEscaper returns the escaped value of the textual representation of
// its arguments in a form suitable for embedding in a URL query.
func URLQueryEscaper(args ...any) string {
return template.URLQueryEscaper(args...)
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package template
import (
"bytes"
"fmt"
"strings"
"unicode/utf8"
)
// htmlNospaceEscaper escapes for inclusion in unquoted attribute values.
func htmlNospaceEscaper(args ...any) string {
s, t := stringify(args...)
if t == contentTypeHTML {
return htmlReplacer(stripTags(s), htmlNospaceNormReplacementTable, false)
}
return htmlReplacer(s, htmlNospaceReplacementTable, false)
}
// attrEscaper escapes for inclusion in quoted attribute values.
func attrEscaper(args ...any) string {
s, t := stringify(args...)
if t == contentTypeHTML {
return htmlReplacer(stripTags(s), htmlNormReplacementTable, true)
}
return htmlReplacer(s, htmlReplacementTable, true)
}
// rcdataEscaper escapes for inclusion in an RCDATA element body.
func rcdataEscaper(args ...any) string {
s, t := stringify(args...)
if t == contentTypeHTML {
return htmlReplacer(s, htmlNormReplacementTable, true)
}
return htmlReplacer(s, htmlReplacementTable, true)
}
// htmlEscaper escapes for inclusion in HTML text.
func htmlEscaper(args ...any) string {
s, t := stringify(args...)
if t == contentTypeHTML {
return s
}
return htmlReplacer(s, htmlReplacementTable, true)
}
// htmlReplacementTable contains the runes that need to be escaped
// inside a quoted attribute value or in a text node.
var htmlReplacementTable = []string{
// https://www.w3.org/TR/html5/syntax.html#attribute-value-(unquoted)-state
// U+0000 NULL Parse error. Append a U+FFFD REPLACEMENT
// CHARACTER character to the current attribute's value.
// "
// and similarly
// https://www.w3.org/TR/html5/syntax.html#before-attribute-value-state
0: "\uFFFD",
'"': """,
'&': "&",
'\'': "'",
'+': "+",
'<': "<",
'>': ">",
}
// htmlNormReplacementTable is like htmlReplacementTable but without '&' to
// avoid over-encoding existing entities.
var htmlNormReplacementTable = []string{
0: "\uFFFD",
'"': """,
'\'': "'",
'+': "+",
'<': "<",
'>': ">",
}
// htmlNospaceReplacementTable contains the runes that need to be escaped
// inside an unquoted attribute value.
// The set of runes escaped is the union of the HTML specials and
// those determined by running the JS below in browsers:
// <div id=d></div>
// <script>(function () {
// var a = [], d = document.getElementById("d"), i, c, s;
// for (i = 0; i < 0x10000; ++i) {
//   c = String.fromCharCode(i);
//   d.innerHTML = "<span title=" + c + "lt" + c + "></span>"
//   s = d.getElementsByTagName("SPAN")[0];
//   if (!s || s.title !== c + "lt" + c) { a.push(i.toString(16)); }
// }
// document.write(a.join(", "));
// })()</script>
var htmlNospaceReplacementTable = []string{
0: "�",
'\t': "	",
'\n': " ",
'\v': "",
'\f': "",
'\r': " ",
' ': " ",
'"': """,
'&': "&",
'\'': "'",
'+': "+",
'<': "<",
'=': "=",
'>': ">",
// A parse error in the attribute value (unquoted) and
// before attribute value states.
// Treated as a quoting character by IE.
'`': "`",
}
// htmlNospaceNormReplacementTable is like htmlNospaceReplacementTable but
// without '&' to avoid over-encoding existing entities.
var htmlNospaceNormReplacementTable = []string{
0: "�",
'\t': "	",
'\n': " ",
'\v': "",
'\f': "",
'\r': " ",
' ': " ",
'"': """,
'\'': "'",
'+': "+",
'<': "<",
'=': "=",
'>': ">",
// A parse error in the attribute value (unquoted) and
// before attribute value states.
// Treated as a quoting character by IE.
'`': "`",
}
// htmlReplacer returns s with runes replaced according to replacementTable
// and when badRunes is true, certain bad runes are allowed through unescaped.
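//
// For example (an illustrative sketch):
//
// htmlReplacer("a < b", htmlReplacementTable, true) // "a &lt; b"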
func htmlReplacer(s string, replacementTable []string, badRunes bool) string {
written, b := 0, new(strings.Builder)
r, w := rune(0), 0
for i := 0; i < len(s); i += w {
// Cannot use 'for range s' because we need to preserve the width
// of the runes in the input. If we see a decoding error, the input
// width will not be utf8.RuneLen(r) and we will overrun the buffer.
r, w = utf8.DecodeRuneInString(s[i:])
if int(r) < len(replacementTable) {
if repl := replacementTable[r]; len(repl) != 0 {
if written == 0 {
b.Grow(len(s))
}
b.WriteString(s[written:i])
b.WriteString(repl)
written = i + w
}
} else if badRunes {
// No-op.
// IE does not allow these ranges in unquoted attrs.
} else if 0xfdd0 <= r && r <= 0xfdef || 0xfff0 <= r && r <= 0xffff {
if written == 0 {
b.Grow(len(s))
}
fmt.Fprintf(b, "%s&#x%x;", s[written:i], r)
written = i + w
}
}
if written == 0 {
return s
}
b.WriteString(s[written:])
return b.String()
}
// stripTags takes a snippet of HTML and returns only the text content.
// For example, `<b>¡Hi!</b> <script>...</script>` -> `¡Hi! `.
func stripTags(html string) string {
var b strings.Builder
s, c, i, allText := []byte(html), context{}, 0, true
// Using the transition funcs helps us avoid mangling
// `<div title="1>2">` or `I <3 Ponies!`.
for i != len(s) {
if c.delim == delimNone {
st := c.state
// Use RCDATA instead of parsing into JS or CSS styles.
if c.element != elementNone && !isInTag(st) {
st = stateRCDATA
}
d, nread := transitionFunc[st](c, s[i:])
i1 := i + nread
if c.state == stateText || c.state == stateRCDATA {
// Emit text up to the start of the tag or comment.
j := i1
if d.state != c.state {
for j1 := j - 1; j1 >= i; j1-- {
if s[j1] == '<' {
j = j1
break
}
}
}
b.Write(s[i:j])
} else {
allText = false
}
c, i = d, i1
continue
}
i1 := i + bytes.IndexAny(s[i:], delimEnds[c.delim])
if i1 < i {
break
}
if c.delim != delimSpaceOrTagEnd {
// Consume any quote.
i1++
}
c, i = context{state: stateTag, element: c.element}, i1
}
if allText {
return html
} else if c.state == stateText || c.state == stateRCDATA {
b.Write(s[i:])
}
return b.String()
}
// htmlNameFilter accepts valid parts of an HTML attribute or tag name or
// a known-safe HTML attribute.
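//
// For example (an illustrative sketch):
//
// htmlNameFilter("title")   // "title": a benign attribute name passes through
// htmlNameFilter("onclick") // filterFailsafe: event-handler names are rejected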
func htmlNameFilter(args ...any) string {
s, t := stringify(args...)
if t == contentTypeHTMLAttr {
return s
}
if len(s) == 0 {
// Avoid violation of structure preservation.
// <input checked {{.K}}={{.V}}>.
// Without this, if .K is empty then .V is the value of
// checked, but otherwise .V is the value of the attribute
// named .K.
return filterFailsafe
}
s = strings.ToLower(s)
if t := attrType(s); t != contentTypePlain {
// TODO: Split attr and element name part filters so we can recognize known attributes.
return filterFailsafe
}
for _, r := range s {
switch {
case '0' <= r && r <= '9':
case 'a' <= r && r <= 'z':
default:
return filterFailsafe
}
}
return s
}
// commentEscaper returns the empty string regardless of input.
// Comment content does not correspond to any parsed structure or
// human-readable content, so the simplest and most secure policy is to drop
// content interpolated into comments.
// This approach is equally valid whether or not static comment content is
// removed from the template.
func commentEscaper(args ...any) string {
return ""
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package template
import (
"bytes"
"encoding/json"
"fmt"
"reflect"
"strings"
"unicode/utf8"
)
// nextJSCtx returns the context that determines whether a slash after the
// given run of tokens starts a regular expression instead of a division
// operator: / or /=.
//
// This assumes that the token run does not include any string tokens, comment
// tokens, regular expression literal tokens, or division operators.
//
// This fails on some valid but nonsensical JavaScript programs like
// "x = ++/foo/i" which is quite different than "x++/foo/i", but is not known to
// fail on any known useful programs. It is based on the draft
// JavaScript 2.0 lexical grammar and requires one token of lookbehind:
// https://www.mozilla.org/js/language/js20-2000-07/rationale/syntax.html
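//
// For example (an illustrative sketch):
//
// nextJSCtx([]byte("return "), jsCtxRegexp) // jsCtxRegexp: a '/' here starts a regexp
// nextJSCtx([]byte("x"), jsCtxRegexp)       // jsCtxDivOp: a '/' here is division
// nextJSCtx([]byte(" \t"), jsCtxDivOp)      // jsCtxDivOp: all space, defers to preceding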
func nextJSCtx(s []byte, preceding jsCtx) jsCtx {
s = bytes.TrimRight(s, "\t\n\f\r \u2028\u2029")
if len(s) == 0 {
return preceding
}
// All cases below are in the single-byte UTF-8 group.
switch c, n := s[len(s)-1], len(s); c {
case '+', '-':
// ++ and -- are not regexp preceders, but + and - are, whether
// they are used as infix or prefix operators.
start := n - 1
// Count the number of adjacent dashes or pluses.
for start > 0 && s[start-1] == c {
start--
}
if (n-start)&1 == 1 {
// Reached for trailing minus signs since "---" is the
// same as "-- -".
return jsCtxRegexp
}
return jsCtxDivOp
case '.':
// Handle "42."
if n != 1 && '0' <= s[n-2] && s[n-2] <= '9' {
return jsCtxDivOp
}
return jsCtxRegexp
// Suffixes for all punctuators from section 7.7 of the language spec
// that only end binary operators not handled above.
case ',', '<', '>', '=', '*', '%', '&', '|', '^', '?':
return jsCtxRegexp
// Suffixes for all punctuators from section 7.7 of the language spec
// that are prefix operators not handled above.
case '!', '~':
return jsCtxRegexp
// Matches all the punctuators from section 7.7 of the language spec
// that are open brackets not handled above.
case '(', '[':
return jsCtxRegexp
// Matches all the punctuators from section 7.7 of the language spec
// that precede expression starts.
case ':', ';', '{':
return jsCtxRegexp
// CAVEAT: the close punctuators ('}', ']', ')') precede div ops and
// are handled in the default except for '}' which can precede a
// division op as in
// ({ valueOf: function () { return 42 } } / 2
// which is valid, but, in practice, developers don't divide object
// literals, so our heuristic works well for code like
// function () { ... } /foo/.test(x) && sideEffect();
// The ')' punctuator can precede a regular expression as in
// if (b) /foo/.test(x) && ...
// but this is much less likely than
// (a + b) / c
case '}':
return jsCtxRegexp
default:
// Look for an IdentifierName and see if it is a keyword that
// can precede a regular expression.
j := n
for j > 0 && isJSIdentPart(rune(s[j-1])) {
j--
}
if regexpPrecederKeywords[string(s[j:])] {
return jsCtxRegexp
}
}
// Otherwise is a punctuator not listed above, or
// a string which precedes a div op, or an identifier
// which precedes a div op.
return jsCtxDivOp
}
// regexpPrecederKeywords is a set of reserved JS keywords that can precede a
// regular expression in JS source.
var regexpPrecederKeywords = map[string]bool{
"break": true,
"case": true,
"continue": true,
"delete": true,
"do": true,
"else": true,
"finally": true,
"in": true,
"instanceof": true,
"return": true,
"throw": true,
"try": true,
"typeof": true,
"void": true,
}
var jsonMarshalType = reflect.TypeOf((*json.Marshaler)(nil)).Elem()
// indirectToJSONMarshaler returns the value, after dereferencing as many times
// as necessary to reach the base type (or nil) or an implementation of json.Marshaler.
func indirectToJSONMarshaler(a any) any {
// text/template now supports passing untyped nil as a func call
// argument, so we must support it. Otherwise we'd panic below, as one
// cannot call the Type or Interface methods on an invalid
// reflect.Value. See golang.org/issue/18716.
if a == nil {
return nil
}
v := reflect.ValueOf(a)
for !v.Type().Implements(jsonMarshalType) && v.Kind() == reflect.Pointer && !v.IsNil() {
v = v.Elem()
}
return v.Interface()
}
// jsValEscaper escapes its inputs to a JS Expression (section 11.14) that has
// neither side-effects nor free variables outside (NaN, Infinity).
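//
// For example (an illustrative sketch):
//
// jsValEscaper("foo")      // `"foo"`
// jsValEscaper(42)         // ` 42 `: padded so it cannot run into adjacent tokens
// jsValEscaper(struct{}{}) // `{}`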
func jsValEscaper(args ...any) string {
var a any
if len(args) == 1 {
a = indirectToJSONMarshaler(args[0])
switch t := a.(type) {
case JS:
return string(t)
case JSStr:
// TODO: normalize quotes.
return `"` + string(t) + `"`
case json.Marshaler:
// Do not treat as a Stringer.
case fmt.Stringer:
a = t.String()
}
} else {
for i, arg := range args {
args[i] = indirectToJSONMarshaler(arg)
}
a = fmt.Sprint(args...)
}
// TODO: detect cycles before calling Marshal which loops infinitely on
// cyclic data. This may be an unacceptable DoS risk.
b, err := json.Marshal(a)
if err != nil {
// Put a space before comment so that if it is flush against
// a division operator it is not turned into a line comment:
// x/{{y}}
// turning into
// x//* error marshaling y:
// second line of error message */null
return fmt.Sprintf(" /* %s */null ", strings.ReplaceAll(err.Error(), "*/", "* /"))
}
// TODO: maybe post-process output to prevent it from containing
// "<!--", "-->", "<![CDATA[", "]]>", or "</script"
// in case custom marshalers produce output containing those.
// Note: Do not use \x escaping to save bytes because it is not JSON compatible and this escaper
// supports ld+json content-type.
if len(b) == 0 {
// In `x=y/{{.}}*z`, a json.Marshaler that produces "" should
// not cause the output `x=y/*z`.
return " null "
}
first, _ := utf8.DecodeRune(b)
last, _ := utf8.DecodeLastRune(b)
var buf strings.Builder
// Prevent IdentifierNames and NumericLiterals from running into
// keywords: in, instanceof, typeof, void
pad := isJSIdentPart(first) || isJSIdentPart(last)
if pad {
buf.WriteByte(' ')
}
written := 0
// Make sure that json.Marshal escapes codepoints U+2028 & U+2029
// so it falls within the subset of JSON which is valid JS.
for i := 0; i < len(b); {
rune, n := utf8.DecodeRune(b[i:])
repl := ""
if rune == 0x2028 {
repl = `\u2028`
} else if rune == 0x2029 {
repl = `\u2029`
}
if repl != "" {
buf.Write(b[written:i])
buf.WriteString(repl)
written = i + n
}
i += n
}
if buf.Len() != 0 {
buf.Write(b[written:])
if pad {
buf.WriteByte(' ')
}
return buf.String()
}
return string(b)
}
// jsStrEscaper produces a string that can be included between quotes in
// JavaScript source, in JavaScript embedded in an HTML5 <script> element,
// or in an HTML5 event handler attribute such as onclick.
func jsStrEscaper(args ...any) string {
s, t := stringify(args...)
if t == contentTypeJSStr {
return replace(s, jsStrNormReplacementTable)
}
return replace(s, jsStrReplacementTable)
}
// jsRegexpEscaper behaves like jsStrEscaper but escapes regular expression
// specials so the result is treated literally when included in a regular
// expression literal. /foo{{.X}}bar/ matches the string "foo" followed by
// the literal text of {{.X}} followed by the string "bar".
func jsRegexpEscaper(args ...any) string {
s, _ := stringify(args...)
s = replace(s, jsRegexpReplacementTable)
if s == "" {
// /{{.X}}/ should not produce a line comment when .X == "".
return "(?:)"
}
return s
}
// replace replaces each rune r of s with replacementTable[r], provided that
// r < len(replacementTable). If replacementTable[r] is the empty string then
// no replacement is made.
// It also replaces runes U+2028 and U+2029 with the raw strings `\u2028` and
// `\u2029`.
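//
// For example (an illustrative sketch):
//
// replace("a\nb", jsStrReplacementTable) // `a\nb`: the newline becomes a two-character escape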
func replace(s string, replacementTable []string) string {
var b strings.Builder
r, w, written := rune(0), 0, 0
for i := 0; i < len(s); i += w {
// See the width-preservation comment in htmlReplacer.
r, w = utf8.DecodeRuneInString(s[i:])
var repl string
switch {
case int(r) < len(lowUnicodeReplacementTable):
repl = lowUnicodeReplacementTable[r]
case int(r) < len(replacementTable) && replacementTable[r] != "":
repl = replacementTable[r]
case r == '\u2028':
repl = `\u2028`
case r == '\u2029':
repl = `\u2029`
default:
continue
}
if written == 0 {
b.Grow(len(s))
}
b.WriteString(s[written:i])
b.WriteString(repl)
written = i + w
}
if written == 0 {
return s
}
b.WriteString(s[written:])
return b.String()
}
var lowUnicodeReplacementTable = []string{
0: `\u0000`, 1: `\u0001`, 2: `\u0002`, 3: `\u0003`, 4: `\u0004`, 5: `\u0005`, 6: `\u0006`,
'\a': `\u0007`,
'\b': `\u0008`,
'\t': `\t`,
'\n': `\n`,
'\v': `\u000b`, // "\v" == "v" on IE 6.
'\f': `\f`,
'\r': `\r`,
0xe: `\u000e`, 0xf: `\u000f`, 0x10: `\u0010`, 0x11: `\u0011`, 0x12: `\u0012`, 0x13: `\u0013`,
0x14: `\u0014`, 0x15: `\u0015`, 0x16: `\u0016`, 0x17: `\u0017`, 0x18: `\u0018`, 0x19: `\u0019`,
0x1a: `\u001a`, 0x1b: `\u001b`, 0x1c: `\u001c`, 0x1d: `\u001d`, 0x1e: `\u001e`, 0x1f: `\u001f`,
}
var jsStrReplacementTable = []string{
0: `\u0000`,
'\t': `\t`,
'\n': `\n`,
'\v': `\u000b`, // "\v" == "v" on IE 6.
'\f': `\f`,
'\r': `\r`,
// Encode HTML specials as hex so the output can be embedded
// in HTML attributes without further encoding.
'"': `\u0022`,
'&': `\u0026`,
'\'': `\u0027`,
'+': `\u002b`,
'/': `\/`,
'<': `\u003c`,
'>': `\u003e`,
'\\': `\\`,
}
// jsStrNormReplacementTable is like jsStrReplacementTable but does not
// overencode existing escapes since this table has no entry for `\`.
var jsStrNormReplacementTable = []string{
0: `\u0000`,
'\t': `\t`,
'\n': `\n`,
'\v': `\u000b`, // "\v" == "v" on IE 6.
'\f': `\f`,
'\r': `\r`,
// Encode HTML specials as hex so the output can be embedded
// in HTML attributes without further encoding.
'"': `\u0022`,
'&': `\u0026`,
'\'': `\u0027`,
'+': `\u002b`,
'/': `\/`,
'<': `\u003c`,
'>': `\u003e`,
}
var jsRegexpReplacementTable = []string{
0: `\u0000`,
'\t': `\t`,
'\n': `\n`,
'\v': `\u000b`, // "\v" == "v" on IE 6.
'\f': `\f`,
'\r': `\r`,
// Encode HTML specials as hex so the output can be embedded
// in HTML attributes without further encoding.
'"': `\u0022`,
'$': `\$`,
'&': `\u0026`,
'\'': `\u0027`,
'(': `\(`,
')': `\)`,
'*': `\*`,
'+': `\u002b`,
'-': `\-`,
'.': `\.`,
'/': `\/`,
'<': `\u003c`,
'>': `\u003e`,
'?': `\?`,
'[': `\[`,
'\\': `\\`,
']': `\]`,
'^': `\^`,
'{': `\{`,
'|': `\|`,
'}': `\}`,
}
// isJSIdentPart reports whether the given rune is a JS identifier part.
// It does not handle all the non-Latin letters, joiners, and combining marks,
// but it does handle every codepoint that can occur in a numeric literal or
// a keyword.
func isJSIdentPart(r rune) bool {
switch {
case r == '$':
return true
case '0' <= r && r <= '9':
return true
case 'A' <= r && r <= 'Z':
return true
case r == '_':
return true
case 'a' <= r && r <= 'z':
return true
}
return false
}
// isJSType reports whether the given MIME type should be considered JavaScript.
//
// It is used to determine whether a script tag with a type attribute is a JavaScript container.
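//
// For example (an illustrative sketch):
//
// isJSType("text/javascript; charset=utf-8") // true: parameters are discarded
// isJSType("text/html")                      // false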
func isJSType(mimeType string) bool {
// per
//   https://www.w3.org/TR/html5/scripting-1.html#attr-script-type
//   https://tools.ietf.org/html/rfc7231#section-3.1.1
//   https://tools.ietf.org/html/rfc4329#section-3
//   https://www.ietf.org/rfc/rfc4627.txt
// discard parameters
mimeType, _, _ = strings.Cut(mimeType, ";")
mimeType = strings.ToLower(mimeType)
mimeType = strings.TrimSpace(mimeType)
switch mimeType {
case
"application/ecmascript",
"application/javascript",
"application/json",
"application/ld+json",
"application/x-ecmascript",
"application/x-javascript",
"module",
"text/ecmascript",
"text/javascript",
"text/javascript1.0",
"text/javascript1.1",
"text/javascript1.2",
"text/javascript1.3",
"text/javascript1.4",
"text/javascript1.5",
"text/jscript",
"text/livescript",
"text/x-ecmascript",
"text/x-javascript":
return true
default:
return false
}
}
// Code generated by "stringer -type jsCtx"; DO NOT EDIT.
package template
import "strconv"
const _jsCtx_name = "jsCtxRegexpjsCtxDivOpjsCtxUnknown"
var _jsCtx_index = [...]uint8{0, 11, 21, 33}
func (i jsCtx) String() string {
if i >= jsCtx(len(_jsCtx_index)-1) {
return "jsCtx(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _jsCtx_name[_jsCtx_index[i]:_jsCtx_index[i+1]]
}
// Code generated by "stringer -type state"; DO NOT EDIT.
package template
import "strconv"
const _state_name = "stateTextstateTagstateAttrNamestateAfterNamestateBeforeValuestateHTMLCmtstateRCDATAstateAttrstateURLstateSrcsetstateJSstateJSDqStrstateJSSqStrstateJSRegexpstateJSBlockCmtstateJSLineCmtstateCSSstateCSSDqStrstateCSSSqStrstateCSSDqURLstateCSSSqURLstateCSSURLstateCSSBlockCmtstateCSSLineCmtstateError"
var _state_index = [...]uint16{0, 9, 17, 30, 44, 60, 72, 83, 92, 100, 111, 118, 130, 142, 155, 170, 184, 192, 205, 218, 231, 244, 255, 271, 286, 296}
func (i state) String() string {
if i >= state(len(_state_index)-1) {
return "state(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _state_name[_state_index[i]:_state_index[i+1]]
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package template
import (
"fmt"
"io"
"io/fs"
"os"
"path"
"path/filepath"
"sync"
"text/template"
"text/template/parse"
)
// Template is a specialized Template from "text/template" that produces a safe
// HTML document fragment.
type Template struct {
// Sticky error if escaping fails, or escapeOK if succeeded.
escapeErr error
// We could embed the text/template field, but it's safer not to because
// we need to keep our version of the name space and the underlying
// template's in sync.
text *template.Template
// The underlying template's parse tree, updated to be HTML-safe.
Tree *parse.Tree
*nameSpace // common to all associated templates
}
// escapeOK is a sentinel value used to indicate valid escaping.
var escapeOK = fmt.Errorf("template escaped correctly")
// nameSpace is the data structure shared by all templates in an association.
type nameSpace struct {
mu sync.Mutex
set map[string]*Template
escaped bool
esc escaper
}
// Templates returns a slice of the templates associated with t, including t
// itself.
func (t *Template) Templates() []*Template {
ns := t.nameSpace
ns.mu.Lock()
defer ns.mu.Unlock()
// Return a slice so we don't expose the map.
m := make([]*Template, 0, len(ns.set))
for _, v := range ns.set {
m = append(m, v)
}
return m
}
// Option sets options for the template. Options are described by
// strings, either a simple string or "key=value". There can be at
// most one equals sign in an option string. If the option string
// is unrecognized or otherwise invalid, Option panics.
//
// Known options:
//
// missingkey: Control the behavior during execution if a map is
// indexed with a key that is not present in the map.
//
// "missingkey=default" or "missingkey=invalid"
// The default behavior: Do nothing and continue execution.
// If printed, the result of the index operation is the string
// "<no value>".
// "missingkey=zero"
// The operation returns the zero value for the map type's element.
// "missingkey=error"
// Execution stops immediately with an error.
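//
// For example, to fail execution when a referenced map key is missing:
//
// t.Option("missingkey=error")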
func (t *Template) Option(opt ...string) *Template {
t.text.Option(opt...)
return t
}
// checkCanParse checks whether it is OK to parse templates.
// If not, it returns an error.
func (t *Template) checkCanParse() error {
if t == nil {
return nil
}
t.nameSpace.mu.Lock()
defer t.nameSpace.mu.Unlock()
if t.nameSpace.escaped {
return fmt.Errorf("html/template: cannot Parse after Execute")
}
return nil
}
// escape escapes all associated templates.
func (t *Template) escape() error {
t.nameSpace.mu.Lock()
defer t.nameSpace.mu.Unlock()
t.nameSpace.escaped = true
if t.escapeErr == nil {
if t.Tree == nil {
return fmt.Errorf("template: %q is an incomplete or empty template", t.Name())
}
if err := escapeTemplate(t, t.text.Root, t.Name()); err != nil {
return err
}
} else if t.escapeErr != escapeOK {
return t.escapeErr
}
return nil
}
// Execute applies a parsed template to the specified data object,
// writing the output to wr.
// If an error occurs executing the template or writing its output,
// execution stops, but partial results may already have been written to
// the output writer.
// A template may be executed safely in parallel, although if parallel
// executions share a Writer the output may be interleaved.
func (t *Template) Execute(wr io.Writer, data any) error {
if err := t.escape(); err != nil {
return err
}
return t.text.Execute(wr, data)
}
// ExecuteTemplate applies the template associated with t that has the given
// name to the specified data object and writes the output to wr.
// If an error occurs executing the template or writing its output,
// execution stops, but partial results may already have been written to
// the output writer.
// A template may be executed safely in parallel, although if parallel
// executions share a Writer the output may be interleaved.
func (t *Template) ExecuteTemplate(wr io.Writer, name string, data any) error {
tmpl, err := t.lookupAndEscapeTemplate(name)
if err != nil {
return err
}
return tmpl.text.Execute(wr, data)
}
// lookupAndEscapeTemplate guarantees that the template with the given name
// is escaped, or returns an error if it cannot be. It returns the named
// template.
func (t *Template) lookupAndEscapeTemplate(name string) (tmpl *Template, err error) {
t.nameSpace.mu.Lock()
defer t.nameSpace.mu.Unlock()
t.nameSpace.escaped = true
tmpl = t.set[name]
if tmpl == nil {
return nil, fmt.Errorf("html/template: %q is undefined", name)
}
if tmpl.escapeErr != nil && tmpl.escapeErr != escapeOK {
return nil, tmpl.escapeErr
}
if tmpl.text.Tree == nil || tmpl.text.Root == nil {
return nil, fmt.Errorf("html/template: %q is an incomplete template", name)
}
if t.text.Lookup(name) == nil {
panic("html/template internal error: template escaping out of sync")
}
if tmpl.escapeErr == nil {
err = escapeTemplate(tmpl, tmpl.text.Root, name)
}
return tmpl, err
}
// DefinedTemplates returns a string listing the defined templates,
// prefixed by the string "; defined templates are: ". If there are none,
// it returns the empty string. Used to generate an error message.
func (t *Template) DefinedTemplates() string {
return t.text.DefinedTemplates()
}
// Parse parses text as a template body for t.
// Named template definitions ({{define ...}} or {{block ...}} statements) in text
// define additional templates associated with t and are removed from the
// definition of t itself.
//
// Templates can be redefined in successive calls to Parse,
// before the first use of Execute on t or any associated template.
// A template definition with a body containing only white space and comments
// is considered empty and will not replace an existing template's body.
// This allows using Parse to add new named template definitions without
// overwriting the main template body.
func (t *Template) Parse(text string) (*Template, error) {
if err := t.checkCanParse(); err != nil {
return nil, err
}
ret, err := t.text.Parse(text)
if err != nil {
return nil, err
}
// In general, all the named templates might have changed underfoot.
// Regardless, some new ones may have been defined.
// The template.Template set has been updated; update ours.
t.nameSpace.mu.Lock()
defer t.nameSpace.mu.Unlock()
for _, v := range ret.Templates() {
name := v.Name()
tmpl := t.set[name]
if tmpl == nil {
tmpl = t.new(name)
}
tmpl.text = v
tmpl.Tree = v.Tree
}
return t, nil
}
// AddParseTree creates a new template with the name and parse tree
// and associates it with t.
//
// It returns an error if t or any associated template has already been executed.
func (t *Template) AddParseTree(name string, tree *parse.Tree) (*Template, error) {
if err := t.checkCanParse(); err != nil {
return nil, err
}
t.nameSpace.mu.Lock()
defer t.nameSpace.mu.Unlock()
text, err := t.text.AddParseTree(name, tree)
if err != nil {
return nil, err
}
ret := &Template{
nil,
text,
text.Tree,
t.nameSpace,
}
t.set[name] = ret
return ret, nil
}
// Clone returns a duplicate of the template, including all associated
// templates. The actual representation is not copied, but the name space of
// associated templates is, so further calls to Parse in the copy will add
// templates to the copy but not to the original. Clone can be used to prepare
// common templates and use them with variant definitions for other templates
// by adding the variants after the clone is made.
//
// It returns an error if t has already been executed.
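//
// For example (an illustrative sketch; the names are hypothetical):
//
// base := Must(New("page").Parse(`{{define "footer"}}...{{end}}`))
// clone := Must(base.Clone())
// // Definitions added to clone do not affect base.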
func (t *Template) Clone() (*Template, error) {
t.nameSpace.mu.Lock()
defer t.nameSpace.mu.Unlock()
if t.escapeErr != nil {
return nil, fmt.Errorf("html/template: cannot Clone %q after it has executed", t.Name())
}
textClone, err := t.text.Clone()
if err != nil {
return nil, err
}
ns := &nameSpace{set: make(map[string]*Template)}
ns.esc = makeEscaper(ns)
ret := &Template{
nil,
textClone,
textClone.Tree,
ns,
}
ret.set[ret.Name()] = ret
for _, x := range textClone.Templates() {
name := x.Name()
src := t.set[name]
if src == nil || src.escapeErr != nil {
return nil, fmt.Errorf("html/template: cannot Clone %q after it has executed", t.Name())
}
x.Tree = x.Tree.Copy()
ret.set[name] = &Template{
nil,
x,
x.Tree,
ret.nameSpace,
}
}
// Return the template associated with the name of this template.
return ret.set[ret.Name()], nil
}
// New allocates a new HTML template with the given name.
func New(name string) *Template {
ns := &nameSpace{set: make(map[string]*Template)}
ns.esc = makeEscaper(ns)
tmpl := &Template{
nil,
template.New(name),
nil,
ns,
}
tmpl.set[name] = tmpl
return tmpl
}
// New allocates a new HTML template associated with the given one
// and with the same delimiters. The association, which is transitive,
// allows one template to invoke another with a {{template}} action.
//
// If a template with the given name already exists, the new HTML template
// will replace it. The existing template will be reset and disassociated from
// t.
func (t *Template) New(name string) *Template {
t.nameSpace.mu.Lock()
defer t.nameSpace.mu.Unlock()
return t.new(name)
}
// new is the implementation of New, without the lock.
func (t *Template) new(name string) *Template {
tmpl := &Template{
nil,
t.text.New(name),
nil,
t.nameSpace,
}
if existing, ok := tmpl.set[name]; ok {
emptyTmpl := New(existing.Name())
*existing = *emptyTmpl
}
tmpl.set[name] = tmpl
return tmpl
}
// Name returns the name of the template.
func (t *Template) Name() string {
return t.text.Name()
}
type FuncMap = template.FuncMap
// Funcs adds the elements of the argument map to the template's function map.
// It must be called before the template is parsed.
// It panics if a value in the map is not a function with appropriate return
// type. However, it is legal to overwrite elements of the map. The return
// value is the template, so calls can be chained.
func (t *Template) Funcs(funcMap FuncMap) *Template {
t.text.Funcs(template.FuncMap(funcMap))
return t
}
// Delims sets the action delimiters to the specified strings, to be used in
// subsequent calls to Parse, ParseFiles, or ParseGlob. Nested template
// definitions will inherit the settings. An empty delimiter stands for the
// corresponding default: {{ or }}.
// The return value is the template, so calls can be chained.
func (t *Template) Delims(left, right string) *Template {
t.text.Delims(left, right)
return t
}
// Lookup returns the template with the given name that is associated with t,
// or nil if there is no such template.
func (t *Template) Lookup(name string) *Template {
t.nameSpace.mu.Lock()
defer t.nameSpace.mu.Unlock()
return t.set[name]
}
// Must is a helper that wraps a call to a function returning (*Template, error)
// and panics if the error is non-nil. It is intended for use in variable initializations
// such as
//
// var t = template.Must(template.New("name").Parse("html"))
func Must(t *Template, err error) *Template {
if err != nil {
panic(err)
}
return t
}
// ParseFiles creates a new Template and parses the template definitions from
// the named files. The returned template's name will have the (base) name and
// (parsed) contents of the first file. There must be at least one file.
// If an error occurs, parsing stops and the returned *Template is nil.
//
// When parsing multiple files with the same name in different directories,
// the last one mentioned will be the one that results.
// For instance, ParseFiles("a/foo", "b/foo") stores "b/foo" as the template
// named "foo", while "a/foo" is unavailable.
func ParseFiles(filenames ...string) (*Template, error) {
return parseFiles(nil, readFileOS, filenames...)
}
// ParseFiles parses the named files and associates the resulting templates with
// t. If an error occurs, parsing stops and the returned template is nil;
// otherwise it is t. There must be at least one file.
//
// When parsing multiple files with the same name in different directories,
// the last one mentioned will be the one that results.
//
// ParseFiles returns an error if t or any associated template has already been executed.
func (t *Template) ParseFiles(filenames ...string) (*Template, error) {
return parseFiles(t, readFileOS, filenames...)
}
// parseFiles is the helper for the method and function. If the argument
// template is nil, it is created from the first file.
func parseFiles(t *Template, readFile func(string) (string, []byte, error), filenames ...string) (*Template, error) {
if err := t.checkCanParse(); err != nil {
return nil, err
}
if len(filenames) == 0 {
// Not really a problem, but be consistent.
return nil, fmt.Errorf("html/template: no files named in call to ParseFiles")
}
for _, filename := range filenames {
name, b, err := readFile(filename)
if err != nil {
return nil, err
}
s := string(b)
// First template becomes return value if not already defined,
// and we use that one for subsequent New calls to associate
// all the templates together. Also, if this file has the same name
// as t, this file becomes the contents of t, so
// t, err := New(name).Funcs(xxx).ParseFiles(name)
// works. Otherwise we create a new template associated with t.
var tmpl *Template
if t == nil {
t = New(name)
}
if name == t.Name() {
tmpl = t
} else {
tmpl = t.New(name)
}
_, err = tmpl.Parse(s)
if err != nil {
return nil, err
}
}
return t, nil
}
// ParseGlob creates a new Template and parses the template definitions from
// the files identified by the pattern. The files are matched according to the
// semantics of filepath.Match, and the pattern must match at least one file.
// The returned template will have the (base) name and (parsed) contents of the
// first file matched by the pattern. ParseGlob is equivalent to calling
// ParseFiles with the list of files matched by the pattern.
//
// When parsing multiple files with the same name in different directories,
// the last one mentioned will be the one that results.
func ParseGlob(pattern string) (*Template, error) {
return parseGlob(nil, pattern)
}
// ParseGlob parses the template definitions in the files identified by the
// pattern and associates the resulting templates with t. The files are matched
// according to the semantics of filepath.Match, and the pattern must match at
// least one file. ParseGlob is equivalent to calling t.ParseFiles with the
// list of files matched by the pattern.
//
// When parsing multiple files with the same name in different directories,
// the last one mentioned will be the one that results.
//
// ParseGlob returns an error if t or any associated template has already been executed.
func (t *Template) ParseGlob(pattern string) (*Template, error) {
return parseGlob(t, pattern)
}
// parseGlob is the implementation of the function and method ParseGlob.
func parseGlob(t *Template, pattern string) (*Template, error) {
if err := t.checkCanParse(); err != nil {
return nil, err
}
filenames, err := filepath.Glob(pattern)
if err != nil {
return nil, err
}
if len(filenames) == 0 {
return nil, fmt.Errorf("html/template: pattern matches no files: %#q", pattern)
}
return parseFiles(t, readFileOS, filenames...)
}
// IsTrue reports whether the value is 'true', in the sense of not the zero of its type,
// and whether the value has a meaningful truth value. This is the definition of
// truth used by if and other such actions.
func IsTrue(val any) (truth, ok bool) {
return template.IsTrue(val)
}
// ParseFS is like ParseFiles or ParseGlob but reads from the file system fs
// instead of the host operating system's file system.
// It accepts a list of glob patterns.
// (Note that most file names serve as glob patterns matching only themselves.)
func ParseFS(fs fs.FS, patterns ...string) (*Template, error) {
return parseFS(nil, fs, patterns)
}
// ParseFS is like ParseFiles or ParseGlob but reads from the file system fs
// instead of the host operating system's file system.
// It accepts a list of glob patterns.
// (Note that most file names serve as glob patterns matching only themselves.)
func (t *Template) ParseFS(fs fs.FS, patterns ...string) (*Template, error) {
return parseFS(t, fs, patterns)
}
func parseFS(t *Template, fsys fs.FS, patterns []string) (*Template, error) {
var filenames []string
for _, pattern := range patterns {
list, err := fs.Glob(fsys, pattern)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, fmt.Errorf("template: pattern matches no files: %#q", pattern)
}
filenames = append(filenames, list...)
}
return parseFiles(t, readFileFS(fsys), filenames...)
}
func readFileOS(file string) (name string, b []byte, err error) {
name = filepath.Base(file)
b, err = os.ReadFile(file)
return
}
func readFileFS(fsys fs.FS) func(string) (string, []byte, error) {
return func(file string) (name string, b []byte, err error) {
name = path.Base(file)
b, err = fs.ReadFile(fsys, file)
return
}
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package template
import (
"bytes"
"strings"
)
// transitionFunc is the array of context transition functions for text nodes.
// A transition function takes a context and template text input, and returns
// the updated context and the number of bytes consumed from the front of the
// input.
var transitionFunc = [...]func(context, []byte) (context, int){
stateText: tText,
stateTag: tTag,
stateAttrName: tAttrName,
stateAfterName: tAfterName,
stateBeforeValue: tBeforeValue,
stateHTMLCmt: tHTMLCmt,
stateRCDATA: tSpecialTagEnd,
stateAttr: tAttr,
stateURL: tURL,
stateSrcset: tURL,
stateJS: tJS,
stateJSDqStr: tJSDelimited,
stateJSSqStr: tJSDelimited,
stateJSRegexp: tJSDelimited,
stateJSBlockCmt: tBlockCmt,
stateJSLineCmt: tLineCmt,
stateCSS: tCSS,
stateCSSDqStr: tCSSStr,
stateCSSSqStr: tCSSStr,
stateCSSDqURL: tCSSStr,
stateCSSSqURL: tCSSStr,
stateCSSURL: tCSSStr,
stateCSSBlockCmt: tBlockCmt,
stateCSSLineCmt: tLineCmt,
stateError: tError,
}
var commentStart = []byte("<!--")
var commentEnd = []byte("-->")
// tText is the context transition function for the text state.
func tText(c context, s []byte) (context, int) {
k := 0
for {
i := k + bytes.IndexByte(s[k:], '<')
if i < k || i+1 == len(s) {
return c, len(s)
} else if i+4 <= len(s) && bytes.Equal(commentStart, s[i:i+4]) {
return context{state: stateHTMLCmt}, i + 4
}
i++
end := false
if s[i] == '/' {
if i+1 == len(s) {
return c, len(s)
}
end, i = true, i+1
}
j, e := eatTagName(s, i)
if j != i {
if end {
e = elementNone
}
// We've found an HTML tag.
return context{state: stateTag, element: e}, j
}
k = j
}
}
var elementContentType = [...]state{
elementNone: stateText,
elementScript: stateJS,
elementStyle: stateCSS,
elementTextarea: stateRCDATA,
elementTitle: stateRCDATA,
}
// tTag is the context transition function for the tag state.
func tTag(c context, s []byte) (context, int) {
// Find the attribute name.
i := eatWhiteSpace(s, 0)
if i == len(s) {
return c, len(s)
}
if s[i] == '>' {
return context{
state: elementContentType[c.element],
element: c.element,
}, i + 1
}
j, err := eatAttrName(s, i)
if err != nil {
return context{state: stateError, err: err}, len(s)
}
state, attr := stateTag, attrNone
if i == j {
return context{
state: stateError,
err: errorf(ErrBadHTML, nil, 0, "expected space, attr name, or end of tag, but got %q", s[i:]),
}, len(s)
}
attrName := strings.ToLower(string(s[i:j]))
if c.element == elementScript && attrName == "type" {
attr = attrScriptType
} else {
switch attrType(attrName) {
case contentTypeURL:
attr = attrURL
case contentTypeCSS:
attr = attrStyle
case contentTypeJS:
attr = attrScript
case contentTypeSrcset:
attr = attrSrcset
}
}
if j == len(s) {
state = stateAttrName
} else {
state = stateAfterName
}
return context{state: state, element: c.element, attr: attr}, j
}
// tAttrName is the context transition function for stateAttrName.
func tAttrName(c context, s []byte) (context, int) {
i, err := eatAttrName(s, 0)
if err != nil {
return context{state: stateError, err: err}, len(s)
} else if i != len(s) {
c.state = stateAfterName
}
return c, i
}
// tAfterName is the context transition function for stateAfterName.
func tAfterName(c context, s []byte) (context, int) {
// Look for the start of the value.
i := eatWhiteSpace(s, 0)
if i == len(s) {
return c, len(s)
} else if s[i] != '=' {
// Occurs due to a tag ending '>' or a valueless attribute.
c.state = stateTag
return c, i
}
c.state = stateBeforeValue
// Consume the "=".
return c, i + 1
}
var attrStartStates = [...]state{
attrNone: stateAttr,
attrScript: stateJS,
attrScriptType: stateAttr,
attrStyle: stateCSS,
attrURL: stateURL,
attrSrcset: stateSrcset,
}
// tBeforeValue is the context transition function for stateBeforeValue.
func tBeforeValue(c context, s []byte) (context, int) {
i := eatWhiteSpace(s, 0)
if i == len(s) {
return c, len(s)
}
// Find the attribute delimiter.
delim := delimSpaceOrTagEnd
switch s[i] {
case '\'':
delim, i = delimSingleQuote, i+1
case '"':
delim, i = delimDoubleQuote, i+1
}
c.state, c.delim = attrStartStates[c.attr], delim
return c, i
}
// tHTMLCmt is the context transition function for stateHTMLCmt.
func tHTMLCmt(c context, s []byte) (context, int) {
if i := bytes.Index(s, commentEnd); i != -1 {
return context{}, i + 3
}
return c, len(s)
}
// specialTagEndMarkers maps element types to the character sequence that
// case-insensitively signals the end of the special tag body.
var specialTagEndMarkers = [...][]byte{
elementScript: []byte("script"),
elementStyle: []byte("style"),
elementTextarea: []byte("textarea"),
elementTitle: []byte("title"),
}
var (
specialTagEndPrefix = []byte("</")
tagEndSeparators = []byte("> \t\n\f/")
)
// tSpecialTagEnd is the context transition function for raw text and RCDATA
// element states.
func tSpecialTagEnd(c context, s []byte) (context, int) {
if c.element != elementNone {
if i := indexTagEnd(s, specialTagEndMarkers[c.element]); i != -1 {
return context{}, i
}
}
return c, len(s)
}
// indexTagEnd finds the index of a special tag end in a case-insensitive way, or returns -1 if none is found.
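// For example (an illustrative sketch):
//
// indexTagEnd([]byte("x</SCRIPT>"), []byte("script")) // 1: matched case-insensitively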
func indexTagEnd(s []byte, tag []byte) int {
res := 0
plen := len(specialTagEndPrefix)
for len(s) > 0 {
// Try to find the tag end prefix first
i := bytes.Index(s, specialTagEndPrefix)
if i == -1 {
return i
}
s = s[i+plen:]
// Try to match the actual tag if there is still space for it
if len(tag) <= len(s) && bytes.EqualFold(tag, s[:len(tag)]) {
s = s[len(tag):]
// Check the tag is followed by a proper separator
if len(s) > 0 && bytes.IndexByte(tagEndSeparators, s[0]) != -1 {
return res + i
}
res += len(tag)
}
res += i + plen
}
return -1
}
// tAttr is the context transition function for the attribute state.
func tAttr(c context, s []byte) (context, int) {
return c, len(s)
}
// tURL is the context transition function for the URL state.
func tURL(c context, s []byte) (context, int) {
if bytes.ContainsAny(s, "#?") {
c.urlPart = urlPartQueryOrFrag
} else if len(s) != eatWhiteSpace(s, 0) && c.urlPart == urlPartNone {
// HTML5 uses "Valid URL potentially surrounded by spaces" for
// attrs: https://www.w3.org/TR/html5/index.html#attributes-1
c.urlPart = urlPartPreQuery
}
return c, len(s)
}
// tJS is the context transition function for the JS state.
func tJS(c context, s []byte) (context, int) {
i := bytes.IndexAny(s, `"'/`)
if i == -1 {
// Entire input is non string, comment, regexp tokens.
c.jsCtx = nextJSCtx(s, c.jsCtx)
return c, len(s)
}
c.jsCtx = nextJSCtx(s[:i], c.jsCtx)
switch s[i] {
case '"':
c.state, c.jsCtx = stateJSDqStr, jsCtxRegexp
case '\'':
c.state, c.jsCtx = stateJSSqStr, jsCtxRegexp
case '/':
switch {
case i+1 < len(s) && s[i+1] == '/':
c.state, i = stateJSLineCmt, i+1
case i+1 < len(s) && s[i+1] == '*':
c.state, i = stateJSBlockCmt, i+1
case c.jsCtx == jsCtxRegexp:
c.state = stateJSRegexp
case c.jsCtx == jsCtxDivOp:
c.jsCtx = jsCtxRegexp
default:
return context{
state: stateError,
err: errorf(ErrSlashAmbig, nil, 0, "'/' could start a division or regexp: %.32q", s[i:]),
}, len(s)
}
default:
panic("unreachable")
}
return c, i + 1
}
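// A minimal sketch of the '/' disambiguation above, assuming hypothetical
// template text (the jsCtx values come from nextJSCtx):
//
// `x = 20 / 2` // after "20", jsCtx is jsCtxDivOp, so '/' is division
// `x = /foo/` // after "=", jsCtx is jsCtxRegexp, so '/' enters stateJSRegexp
//
// When branches such as {{if}} arms leave jsCtx unknown, the default case
// above reports ErrSlashAmbig rather than guessing.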
// tJSDelimited is the context transition function for the JS string and regexp
// states.
func tJSDelimited(c context, s []byte) (context, int) {
specials := `\"`
switch c.state {
case stateJSSqStr:
specials = `\'`
case stateJSRegexp:
specials = `\/[]`
}
k, inCharset := 0, false
for {
i := k + bytes.IndexAny(s[k:], specials)
if i < k {
break
}
switch s[i] {
case '\\':
i++
if i == len(s) {
return context{
state: stateError,
err: errorf(ErrPartialEscape, nil, 0, "unfinished escape sequence in JS string: %q", s),
}, len(s)
}
case '[':
inCharset = true
case ']':
inCharset = false
default:
// end delimiter
if !inCharset {
c.state, c.jsCtx = stateJS, jsCtxDivOp
return c, i + 1
}
}
k = i + 1
}
if inCharset {
// This can be fixed by making context richer if interpolation
// into charsets is desired.
return context{
state: stateError,
err: errorf(ErrPartialCharset, nil, 0, "unfinished JS regexp charset: %q", s),
}, len(s)
}
return c, len(s)
}
var blockCommentEnd = []byte("*/")
// tBlockCmt is the context transition function for /*comment*/ states.
func tBlockCmt(c context, s []byte) (context, int) {
i := bytes.Index(s, blockCommentEnd)
if i == -1 {
return c, len(s)
}
switch c.state {
case stateJSBlockCmt:
c.state = stateJS
case stateCSSBlockCmt:
c.state = stateCSS
default:
panic(c.state.String())
}
return c, i + 2
}
// tLineCmt is the context transition function for //comment states.
func tLineCmt(c context, s []byte) (context, int) {
var lineTerminators string
var endState state
switch c.state {
case stateJSLineCmt:
lineTerminators, endState = "\n\r\u2028\u2029", stateJS
case stateCSSLineCmt:
lineTerminators, endState = "\n\f\r", stateCSS
// Line comments are not part of any published CSS standard but
// are supported by the 4 major browsers.
// This defines line comments as
// LINECOMMENT ::= "//" [^\n\f\r]*
// since https://www.w3.org/TR/css3-syntax/#SUBTOK-nl defines
// newlines:
// nl ::= #xA | #xD #xA | #xD | #xC
default:
panic(c.state.String())
}
i := bytes.IndexAny(s, lineTerminators)
if i == -1 {
return c, len(s)
}
c.state = endState
// Per section 7.4 of EcmaScript 5 : https://es5.github.com/#x7.4
// "However, the LineTerminator at the end of the line is not
// considered to be part of the single-line comment; it is
// recognized separately by the lexical grammar and becomes part
// of the stream of input elements for the syntactic grammar."
return c, i
}
// tCSS is the context transition function for the CSS state.
func tCSS(c context, s []byte) (context, int) {
// CSS quoted strings are almost never used except for:
// (1) URLs as in background: "/foo.png"
// (2) Multiword font-names as in font-family: "Times New Roman"
// (3) List separators in content values as in inline-lists:
// <style>
// ul.inlineList { list-style: none; padding:0 }
// ul.inlineList > li { display: inline }
// ul.inlineList > li:before { content: ", " }
// ul.inlineList > li:first-child:before { content: "" }
// </style>
// <ul class=inlineList><li>One<li>Two<li>Three</ul>
// (4) Attribute value selectors as in a[href="http://example.com/"]
//
// We conservatively treat all strings as URLs, but make some
// allowances to avoid confusion.
//
// In (1), our conservative assumption is justified.
// In (2), valid font names do not contain ':', '?', or '#', so our
// conservative assumption is fine since we will never transition past
// urlPartPreQuery.
// In (3), our protocol heuristic should not be tripped, and there
// should not be non-space content after a '?' or '#', so as long as
// we only %-encode RFC 3986 reserved characters we are ok.
// In (4), we should URL escape for URL attributes, and for others we
// have the attribute name available if our conservative assumption
// proves problematic for real code.
k := 0
for {
i := k + bytes.IndexAny(s[k:], `("'/`)
if i < k {
return c, len(s)
}
switch s[i] {
case '(':
// Look for url to the left.
p := bytes.TrimRight(s[:i], "\t\n\f\r ")
if endsWithCSSKeyword(p, "url") {
j := len(s) - len(bytes.TrimLeft(s[i+1:], "\t\n\f\r "))
switch {
case j != len(s) && s[j] == '"':
c.state, j = stateCSSDqURL, j+1
case j != len(s) && s[j] == '\'':
c.state, j = stateCSSSqURL, j+1
default:
c.state = stateCSSURL
}
return c, j
}
case '/':
if i+1 < len(s) {
switch s[i+1] {
case '/':
c.state = stateCSSLineCmt
return c, i + 2
case '*':
c.state = stateCSSBlockCmt
return c, i + 2
}
}
case '"':
c.state = stateCSSDqStr
return c, i + 1
case '\'':
c.state = stateCSSSqStr
return c, i + 1
}
k = i + 1
}
}
// tCSSStr is the context transition function for the CSS string and URL states.
func tCSSStr(c context, s []byte) (context, int) {
var endAndEsc string
switch c.state {
case stateCSSDqStr, stateCSSDqURL:
endAndEsc = `\"`
case stateCSSSqStr, stateCSSSqURL:
endAndEsc = `\'`
case stateCSSURL:
// Unquoted URLs end with a newline or close parenthesis.
// The below includes the wc (whitespace character) and nl.
endAndEsc = "\\\t\n\f\r )"
default:
panic(c.state.String())
}
k := 0
for {
i := k + bytes.IndexAny(s[k:], endAndEsc)
if i < k {
c, nread := tURL(c, decodeCSS(s[k:]))
return c, k + nread
}
if s[i] == '\\' {
i++
if i == len(s) {
return context{
state: stateError,
err: errorf(ErrPartialEscape, nil, 0, "unfinished escape sequence in CSS string: %q", s),
}, len(s)
}
} else {
c.state = stateCSS
return c, i + 1
}
c, _ = tURL(c, decodeCSS(s[:i+1]))
k = i + 1
}
}
// tError is the context transition function for the error state.
func tError(c context, s []byte) (context, int) {
return c, len(s)
}
// eatAttrName returns the largest j such that s[i:j] is an attribute name.
// It returns an error if s[i:] does not look like it begins with an
// attribute name, such as encountering a quote mark without a preceding
// equals sign.
func eatAttrName(s []byte, i int) (int, *Error) {
for j := i; j < len(s); j++ {
switch s[j] {
case ' ', '\t', '\n', '\f', '\r', '=', '>':
return j, nil
case '\'', '"', '<':
// These result in a parse warning in HTML5 and are
// indicative of serious problems if seen in an attr
// name in a template.
return -1, errorf(ErrBadHTML, nil, 0, "%q in attribute name: %.32q", s[j:j+1], s)
default:
// No-op.
}
}
return len(s), nil
}
var elementNameMap = map[string]element{
"script": elementScript,
"style": elementStyle,
"textarea": elementTextarea,
"title": elementTitle,
}
// asciiAlpha reports whether c is an ASCII letter.
func asciiAlpha(c byte) bool {
return 'A' <= c && c <= 'Z' || 'a' <= c && c <= 'z'
}
// asciiAlphaNum reports whether c is an ASCII letter or digit.
func asciiAlphaNum(c byte) bool {
return asciiAlpha(c) || '0' <= c && c <= '9'
}
// eatTagName returns the largest j such that s[i:j] is a tag name and the tag type.
func eatTagName(s []byte, i int) (int, element) {
if i == len(s) || !asciiAlpha(s[i]) {
return i, elementNone
}
j := i + 1
for j < len(s) {
x := s[j]
if asciiAlphaNum(x) {
j++
continue
}
// Allow "x-y" or "x:y" but not "x-", "-y", or "x--y".
if (x == ':' || x == '-') && j+1 < len(s) && asciiAlphaNum(s[j+1]) {
j += 2
continue
}
break
}
return j, elementNameMap[strings.ToLower(string(s[i:j]))]
}
// eatWhiteSpace returns the largest j such that s[i:j] is white space.
func eatWhiteSpace(s []byte, i int) int {
for j := i; j < len(s); j++ {
switch s[j] {
case ' ', '\t', '\n', '\f', '\r':
// No-op.
default:
return j
}
}
return len(s)
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package template
import (
"fmt"
"strings"
)
// urlFilter returns its input unless it contains an unsafe scheme in which
// case it defangs the entire URL.
//
// Schemes that cause unintended side effects that are irreversible without user
// interaction are considered unsafe. For example, clicking on a "javascript:"
// link can immediately trigger JavaScript code execution.
//
// This filter conservatively assumes that all schemes other than the following
// are unsafe:
// - http: Navigates to a new website, and may open a new window or tab.
// These side effects can be reversed by navigating back to the
// previous website, or closing the window or tab. No irreversible
// changes will take place without further user interaction with
// the new website.
// - https: Same as http.
// - mailto: Opens an email program and starts a new draft. This side effect
// is not irreversible until the user explicitly clicks send; it
// can be undone by closing the email program.
//
// To allow URLs containing other schemes to bypass this filter, developers must
// explicitly indicate that such a URL is expected and safe by encapsulating it
// in a template.URL value.
func urlFilter(args ...any) string {
s, t := stringify(args...)
if t == contentTypeURL {
return s
}
if !isSafeURL(s) {
return "#" + filterFailsafe
}
return s
}
// isSafeURL is true if s is a relative URL or if it has a protocol in
// (http, https, mailto).
func isSafeURL(s string) bool {
if protocol, _, ok := strings.Cut(s, ":"); ok && !strings.Contains(protocol, "/") {
if !strings.EqualFold(protocol, "http") && !strings.EqualFold(protocol, "https") && !strings.EqualFold(protocol, "mailto") {
return false
}
}
return true
}
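// A minimal sketch of how isSafeURL classifies inputs (hypothetical values):
//
// isSafeURL("http://example.com/") // true: protocol is http
// isSafeURL("/relative/path") // true: no scheme at all
// isSafeURL("javascript:alert(1)") // false: unsafe scheme, so urlFilter defangs it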
// urlEscaper produces an output that can be embedded in a URL query.
// The output can be embedded in an HTML attribute without further escaping.
func urlEscaper(args ...any) string {
return urlProcessor(false, args...)
}
// urlNormalizer normalizes URL content so it can be embedded in a quote-delimited
// string or parenthesis-delimited url(...).
// The normalizer does not encode all HTML specials. Specifically, it does not
// encode '&' so correct embedding in an HTML attribute requires escaping of
// '&' to '&amp;'.
func urlNormalizer(args ...any) string {
return urlProcessor(true, args...)
}
// urlProcessor normalizes (when norm is true) or escapes its input to produce
// a valid hierarchical or opaque URL part.
func urlProcessor(norm bool, args ...any) string {
s, t := stringify(args...)
if t == contentTypeURL {
norm = true
}
var b strings.Builder
if processURLOnto(s, norm, &b) {
return b.String()
}
return s
}
// processURLOnto appends a normalized URL corresponding to its input to b
// and reports whether the appended content differs from s.
func processURLOnto(s string, norm bool, b *strings.Builder) bool {
b.Grow(len(s) + 16)
written := 0
// The byte loop below assumes that all URLs use UTF-8 as the
// content-encoding. This is similar to the URI to IRI encoding scheme
// defined in section 3.1 of RFC 3987, and behaves the same as the
// EcmaScript builtin encodeURIComponent.
// It should not cause any misencoding of URLs in pages with
// Content-type: text/html;charset=UTF-8.
for i, n := 0, len(s); i < n; i++ {
c := s[i]
switch c {
// Single quote and parens are sub-delims in RFC 3986, but we
// escape them so the output can be embedded in single
// quoted attributes and unquoted CSS url(...) constructs.
// Single quotes are reserved in URLs, but are only used in
// the obsolete "mark" rule in an appendix in RFC 3986
// so can be safely encoded.
case '!', '#', '$', '&', '*', '+', ',', '/', ':', ';', '=', '?', '@', '[', ']':
if norm {
continue
}
// Unreserved according to RFC 3986 sec 2.3
// "For consistency, percent-encoded octets in the ranges of
// ALPHA (%41-%5A and %61-%7A), DIGIT (%30-%39), hyphen (%2D),
// period (%2E), underscore (%5F), or tilde (%7E) should not be
// created by URI producers ..."
case '-', '.', '_', '~':
continue
case '%':
// When normalizing do not re-encode valid escapes.
if norm && i+2 < len(s) && isHex(s[i+1]) && isHex(s[i+2]) {
continue
}
default:
// Unreserved according to RFC 3986 sec 2.3
if 'a' <= c && c <= 'z' {
continue
}
if 'A' <= c && c <= 'Z' {
continue
}
if '0' <= c && c <= '9' {
continue
}
}
b.WriteString(s[written:i])
fmt.Fprintf(b, "%%%02x", c)
written = i + 1
}
b.WriteString(s[written:])
return written != 0
}
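// A minimal sketch contrasting the two modes, assuming hypothetical inputs:
//
// urlEscaper("a b?c") // "a%20b%3fc": escaping encodes '?' as well
// urlNormalizer("a b?c") // "a%20b?c": normalizing keeps RFC 3986 reserved characters
// urlNormalizer("%41") // "%41": valid escapes are not re-encoded when normalizing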
// srcsetFilterAndEscaper filters and normalizes srcset values, which are
// comma-separated URLs followed by metadata.
func srcsetFilterAndEscaper(args ...any) string {
s, t := stringify(args...)
switch t {
case contentTypeSrcset:
return s
case contentTypeURL:
// Normalizing gets rid of all HTML whitespace, which
// separates the image URL from its metadata.
var b strings.Builder
if processURLOnto(s, true, &b) {
s = b.String()
}
// Additionally, commas separate one source from another.
return strings.ReplaceAll(s, ",", "%2c")
}
var b strings.Builder
written := 0
for i := 0; i < len(s); i++ {
if s[i] == ',' {
filterSrcsetElement(s, written, i, &b)
b.WriteString(",")
written = i + 1
}
}
filterSrcsetElement(s, written, len(s), &b)
return b.String()
}
// Derived from https://play.golang.org/p/Dhmj7FORT5
const htmlSpaceAndASCIIAlnumBytes = "\x00\x36\x00\x00\x01\x00\xff\x03\xfe\xff\xff\x07\xfe\xff\xff\x07"
// isHTMLSpace is true iff c is a whitespace character per
// https://infra.spec.whatwg.org/#ascii-whitespace
func isHTMLSpace(c byte) bool {
return (c <= 0x20) && 0 != (htmlSpaceAndASCIIAlnumBytes[c>>3]&(1<<uint(c&0x7)))
}
// isHTMLSpaceOrASCIIAlnum is true iff c is an HTML space character or an
// ASCII letter or digit.
func isHTMLSpaceOrASCIIAlnum(c byte) bool {
return (c < 0x80) && 0 != (htmlSpaceAndASCIIAlnumBytes[c>>3]&(1<<uint(c&0x7)))
}
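// A worked sketch of the bitmask lookup above: byte c selects bit c&7 of
// htmlSpaceAndASCIIAlnumBytes[c>>3]. For example, '\n' (0x0a) indexes byte 1
// (0x36 = 0b00110110), whose bits 1, 2, 4 and 5 mark '\t', '\n', '\f' and
// '\r'; bit 2 is set, so isHTMLSpace('\n') is true. ' ' (0x20) indexes byte 4
// (0x01), whose bit 0 marks the space character itself.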
// filterSrcsetElement writes one srcset image candidate, s[left:right], to b,
// normalizing its URL, or writes the failsafe URL if the candidate is unsafe.
func filterSrcsetElement(s string, left int, right int, b *strings.Builder) {
start := left
for start < right && isHTMLSpace(s[start]) {
start++
}
end := right
for i := start; i < right; i++ {
if isHTMLSpace(s[i]) {
end = i
break
}
}
if url := s[start:end]; isSafeURL(url) {
// If image metadata is only spaces or alnums then
// we don't need to URL normalize it.
metadataOk := true
for i := end; i < right; i++ {
if !isHTMLSpaceOrASCIIAlnum(s[i]) {
metadataOk = false
break
}
}
if metadataOk {
b.WriteString(s[left:start])
processURLOnto(url, true, b)
b.WriteString(s[end:right])
return
}
}
b.WriteString("#")
b.WriteString(filterFailsafe)
}
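// A minimal end-to-end sketch, assuming a hypothetical srcset value:
//
// srcsetFilterAndEscaper("/a.png 1x, javascript:alert(1) 2x")
// // "/a.png 1x,#ZgotmplZ": the safe candidate passes through and the
// // unsafe candidate is replaced by the failsafe URL.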
// Code generated by "stringer -type urlPart"; DO NOT EDIT.
package template
import "strconv"
const _urlPart_name = "urlPartNoneurlPartPreQueryurlPartQueryOrFragurlPartUnknown"
var _urlPart_index = [...]uint8{0, 11, 26, 44, 58}
func (i urlPart) String() string {
if i >= urlPart(len(_urlPart_index)-1) {
return "urlPart(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _urlPart_name[_urlPart_index[i]:_urlPart_index[i+1]]
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package color implements a basic color library.
package color
// Color can convert itself to alpha-premultiplied 16-bits per channel RGBA.
// The conversion may be lossy.
type Color interface {
// RGBA returns the alpha-premultiplied red, green, blue and alpha values
// for the color. Each value ranges within [0, 0xffff], but is represented
// by a uint32 so that multiplying by a blend factor up to 0xffff will not
// overflow.
//
// An alpha-premultiplied color component c has been scaled by alpha (a),
// so has valid values 0 <= c <= a.
RGBA() (r, g, b, a uint32)
}
// RGBA represents a traditional 32-bit alpha-premultiplied color, having 8
// bits for each of red, green, blue and alpha.
//
// An alpha-premultiplied color component C has been scaled by alpha (A), so
// has valid values 0 <= C <= A.
type RGBA struct {
R, G, B, A uint8
}
func (c RGBA) RGBA() (r, g, b, a uint32) {
r = uint32(c.R)
r |= r << 8
g = uint32(c.G)
g |= g << 8
b = uint32(c.B)
b |= b << 8
a = uint32(c.A)
a |= a << 8
return
}
// RGBA64 represents a 64-bit alpha-premultiplied color, having 16 bits for
// each of red, green, blue and alpha.
//
// An alpha-premultiplied color component C has been scaled by alpha (A), so
// has valid values 0 <= C <= A.
type RGBA64 struct {
R, G, B, A uint16
}
func (c RGBA64) RGBA() (r, g, b, a uint32) {
return uint32(c.R), uint32(c.G), uint32(c.B), uint32(c.A)
}
// NRGBA represents a non-alpha-premultiplied 32-bit color.
type NRGBA struct {
R, G, B, A uint8
}
func (c NRGBA) RGBA() (r, g, b, a uint32) {
r = uint32(c.R)
r |= r << 8
r *= uint32(c.A)
r /= 0xff
g = uint32(c.G)
g |= g << 8
g *= uint32(c.A)
g /= 0xff
b = uint32(c.B)
b |= b << 8
b *= uint32(c.A)
b /= 0xff
a = uint32(c.A)
a |= a << 8
return
}
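// A worked sketch of the premultiplication above, assuming a hypothetical
// half-transparent red:
//
// NRGBA{R: 0xff, A: 0x80}.RGBA() // (0x8080, 0, 0, 0x8080)
//
// R widens to 0xffff and is scaled by alpha: 0xffff * 0x80 / 0xff = 0x8080,
// which never exceeds the returned alpha, as the Color contract requires.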
// NRGBA64 represents a non-alpha-premultiplied 64-bit color,
// having 16 bits for each of red, green, blue and alpha.
type NRGBA64 struct {
R, G, B, A uint16
}
func (c NRGBA64) RGBA() (r, g, b, a uint32) {
r = uint32(c.R)
r *= uint32(c.A)
r /= 0xffff
g = uint32(c.G)
g *= uint32(c.A)
g /= 0xffff
b = uint32(c.B)
b *= uint32(c.A)
b /= 0xffff
a = uint32(c.A)
return
}
// Alpha represents an 8-bit alpha color.
type Alpha struct {
A uint8
}
func (c Alpha) RGBA() (r, g, b, a uint32) {
a = uint32(c.A)
a |= a << 8
return a, a, a, a
}
// Alpha16 represents a 16-bit alpha color.
type Alpha16 struct {
A uint16
}
func (c Alpha16) RGBA() (r, g, b, a uint32) {
a = uint32(c.A)
return a, a, a, a
}
// Gray represents an 8-bit grayscale color.
type Gray struct {
Y uint8
}
func (c Gray) RGBA() (r, g, b, a uint32) {
y := uint32(c.Y)
y |= y << 8
return y, y, y, 0xffff
}
// Gray16 represents a 16-bit grayscale color.
type Gray16 struct {
Y uint16
}
func (c Gray16) RGBA() (r, g, b, a uint32) {
y := uint32(c.Y)
return y, y, y, 0xffff
}
// Model can convert any Color to one from its own color model. The conversion
// may be lossy.
type Model interface {
Convert(c Color) Color
}
// ModelFunc returns a Model that invokes f to implement the conversion.
func ModelFunc(f func(Color) Color) Model {
// Note: using *modelFunc as the implementation
// means that callers can still use comparisons
// like m == RGBAModel. This is not possible if
// we use the func value directly, because funcs
// are no longer comparable.
return &modelFunc{f}
}
type modelFunc struct {
f func(Color) Color
}
func (m *modelFunc) Convert(c Color) Color {
return m.f(c)
}
// Models for the standard color types.
var (
RGBAModel Model = ModelFunc(rgbaModel)
RGBA64Model Model = ModelFunc(rgba64Model)
NRGBAModel Model = ModelFunc(nrgbaModel)
NRGBA64Model Model = ModelFunc(nrgba64Model)
AlphaModel Model = ModelFunc(alphaModel)
Alpha16Model Model = ModelFunc(alpha16Model)
GrayModel Model = ModelFunc(grayModel)
Gray16Model Model = ModelFunc(gray16Model)
)
func rgbaModel(c Color) Color {
if _, ok := c.(RGBA); ok {
return c
}
r, g, b, a := c.RGBA()
return RGBA{uint8(r >> 8), uint8(g >> 8), uint8(b >> 8), uint8(a >> 8)}
}
func rgba64Model(c Color) Color {
if _, ok := c.(RGBA64); ok {
return c
}
r, g, b, a := c.RGBA()
return RGBA64{uint16(r), uint16(g), uint16(b), uint16(a)}
}
func nrgbaModel(c Color) Color {
if _, ok := c.(NRGBA); ok {
return c
}
r, g, b, a := c.RGBA()
if a == 0xffff {
return NRGBA{uint8(r >> 8), uint8(g >> 8), uint8(b >> 8), 0xff}
}
if a == 0 {
return NRGBA{0, 0, 0, 0}
}
// Since Color.RGBA returns an alpha-premultiplied color, we should have r <= a && g <= a && b <= a.
r = (r * 0xffff) / a
g = (g * 0xffff) / a
b = (b * 0xffff) / a
return NRGBA{uint8(r >> 8), uint8(g >> 8), uint8(b >> 8), uint8(a >> 8)}
}
func nrgba64Model(c Color) Color {
if _, ok := c.(NRGBA64); ok {
return c
}
r, g, b, a := c.RGBA()
if a == 0xffff {
return NRGBA64{uint16(r), uint16(g), uint16(b), 0xffff}
}
if a == 0 {
return NRGBA64{0, 0, 0, 0}
}
// Since Color.RGBA returns an alpha-premultiplied color, we should have r <= a && g <= a && b <= a.
r = (r * 0xffff) / a
g = (g * 0xffff) / a
b = (b * 0xffff) / a
return NRGBA64{uint16(r), uint16(g), uint16(b), uint16(a)}
}
func alphaModel(c Color) Color {
if _, ok := c.(Alpha); ok {
return c
}
_, _, _, a := c.RGBA()
return Alpha{uint8(a >> 8)}
}
func alpha16Model(c Color) Color {
if _, ok := c.(Alpha16); ok {
return c
}
_, _, _, a := c.RGBA()
return Alpha16{uint16(a)}
}
func grayModel(c Color) Color {
if _, ok := c.(Gray); ok {
return c
}
r, g, b, _ := c.RGBA()
// These coefficients (the fractions 0.299, 0.587 and 0.114) are the same
// as those given by the JFIF specification and used by func RGBToYCbCr in
// ycbcr.go.
//
// Note that 19595 + 38470 + 7471 equals 65536.
//
// The 24 is 16 + 8. The 16 is the same as used in RGBToYCbCr. The 8 is
// because the return value is 8 bit color, not 16 bit color.
y := (19595*r + 38470*g + 7471*b + 1<<15) >> 24
return Gray{uint8(y)}
}
func gray16Model(c Color) Color {
if _, ok := c.(Gray16); ok {
return c
}
r, g, b, _ := c.RGBA()
// These coefficients (the fractions 0.299, 0.587 and 0.114) are the same
// as those given by the JFIF specification and used by func RGBToYCbCr in
// ycbcr.go.
//
// Note that 19595 + 38470 + 7471 equals 65536.
y := (19595*r + 38470*g + 7471*b + 1<<15) >> 16
return Gray16{uint16(y)}
}
// Palette is a palette of colors.
type Palette []Color
// Convert returns the palette color closest to c in Euclidean R,G,B space.
func (p Palette) Convert(c Color) Color {
if len(p) == 0 {
return nil
}
return p[p.Index(c)]
}
// Index returns the index of the palette color closest to c in Euclidean
// R,G,B,A space.
func (p Palette) Index(c Color) int {
// A batch version of this computation is in image/draw/draw.go.
cr, cg, cb, ca := c.RGBA()
ret, bestSum := 0, uint32(1<<32-1)
for i, v := range p {
vr, vg, vb, va := v.RGBA()
sum := sqDiff(cr, vr) + sqDiff(cg, vg) + sqDiff(cb, vb) + sqDiff(ca, va)
if sum < bestSum {
if sum == 0 {
return i
}
ret, bestSum = i, sum
}
}
return ret
}
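// A minimal usage sketch, assuming a hypothetical two-color palette:
//
// p := Palette{Black, White}
// p.Convert(Gray{Y: 0x40}) // Black: 0x4040 is closer to 0x0000 than to 0xffff
// p.Index(Gray{Y: 0xc0}) // 1: the white entry wins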
// sqDiff returns the squared-difference of x and y, shifted by 2 so that
// adding four of those won't overflow a uint32.
//
// x and y are both assumed to be in the range [0, 0xffff].
func sqDiff(x, y uint32) uint32 {
// The canonical code of this function looks as follows:
//
// var d uint32
// if x > y {
// d = x - y
// } else {
// d = y - x
// }
// return (d * d) >> 2
//
// Language spec guarantees the following properties of unsigned integer
// values operations with respect to overflow/wrap around:
//
// > For unsigned integer values, the operations +, -, *, and << are
// > computed modulo 2^n, where n is the bit width of the unsigned
// > integer's type. Loosely speaking, these unsigned integer operations
// > discard high bits upon overflow, and programs may rely on ``wrap
// > around''.
//
// Considering these properties and the fact that this function is
// called in the hot paths (x,y loops), it is reduced to the below code
// which is slightly faster. See TestSqDiff for correctness check.
d := x - y
return (d * d) >> 2
}
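// A worked sketch of the wrap-around equivalence, assuming x = 0, y = 0xffff:
// the canonical form computes d = 0xffff and (d*d)>>2 = 0x3fff8000; the
// optimized form computes x-y = 0xffff0001, whose square modulo 2^32 is the
// same 0xfffe0001, so the shifted result is identical.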
// Standard colors.
var (
Black = Gray16{0}
White = Gray16{0xffff}
Transparent = Alpha16{0}
Opaque = Alpha16{0xffff}
)
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package color
// RGBToYCbCr converts an RGB triple to a Y'CbCr triple.
func RGBToYCbCr(r, g, b uint8) (uint8, uint8, uint8) {
// The JFIF specification says:
// Y' = 0.2990*R + 0.5870*G + 0.1140*B
// Cb = -0.1687*R - 0.3313*G + 0.5000*B + 128
// Cr = 0.5000*R - 0.4187*G - 0.0813*B + 128
// https://www.w3.org/Graphics/JPEG/jfif3.pdf says Y but means Y'.
r1 := int32(r)
g1 := int32(g)
b1 := int32(b)
// yy is in range [0,0xff].
//
// Note that 19595 + 38470 + 7471 equals 65536.
yy := (19595*r1 + 38470*g1 + 7471*b1 + 1<<15) >> 16
// The bit twiddling below is equivalent to
//
// cb := (-11056*r1 - 21712*g1 + 32768*b1 + 257<<15) >> 16
// if cb < 0 {
// cb = 0
// } else if cb > 0xff {
// cb = ^int32(0)
// }
//
// but uses fewer branches and is faster.
// Note that the uint8 type conversion in the return
// statement will convert ^int32(0) to 0xff.
// The code below to compute cr uses a similar pattern.
//
// Note that -11056 - 21712 + 32768 equals 0.
cb := -11056*r1 - 21712*g1 + 32768*b1 + 257<<15
if uint32(cb)&0xff000000 == 0 {
cb >>= 16
} else {
cb = ^(cb >> 31)
}
// Note that 32768 - 27440 - 5328 equals 0.
cr := 32768*r1 - 27440*g1 - 5328*b1 + 257<<15
if uint32(cr)&0xff000000 == 0 {
cr >>= 16
} else {
cr = ^(cr >> 31)
}
return uint8(yy), uint8(cb), uint8(cr)
}
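// A worked sketch, assuming pure red input:
//
// RGBToYCbCr(0xff, 0, 0) // (76, 85, 255)
//
// yy = (19595*255 + 1<<15) >> 16 = 76 and cb rounds to 85, while cr would be
// 256 and so takes the clamp branch: ^(cr>>31) truncates to 0xff.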
// YCbCrToRGB converts a Y'CbCr triple to an RGB triple.
func YCbCrToRGB(y, cb, cr uint8) (uint8, uint8, uint8) {
// The JFIF specification says:
// R = Y' + 1.40200*(Cr-128)
// G = Y' - 0.34414*(Cb-128) - 0.71414*(Cr-128)
// B = Y' + 1.77200*(Cb-128)
// https://www.w3.org/Graphics/JPEG/jfif3.pdf says Y but means Y'.
//
// Those formulae use non-integer multiplication factors. When computing,
// integer math is generally faster than floating point math. We multiply
// all of those factors by 1<<16 and round to the nearest integer:
// 91881 = roundToNearestInteger(1.40200 * 65536).
// 22554 = roundToNearestInteger(0.34414 * 65536).
// 46802 = roundToNearestInteger(0.71414 * 65536).
// 116130 = roundToNearestInteger(1.77200 * 65536).
//
// Adding a rounding adjustment in the range [0, 1<<16-1] and then shifting
// right by 16 gives us an integer math version of the original formulae.
// R = (65536*Y' + 91881 *(Cr-128) + adjustment) >> 16
// G = (65536*Y' - 22554 *(Cb-128) - 46802*(Cr-128) + adjustment) >> 16
// B = (65536*Y' + 116130 *(Cb-128) + adjustment) >> 16
// A constant rounding adjustment of 1<<15, one half of 1<<16, would mean
// round-to-nearest when dividing by 65536 (shifting right by 16).
// Similarly, a constant rounding adjustment of 0 would mean round-down.
//
// Defining YY1 = 65536*Y' + adjustment simplifies the formulae and
// requires fewer CPU operations:
// R = (YY1 + 91881 *(Cr-128) ) >> 16
// G = (YY1 - 22554 *(Cb-128) - 46802*(Cr-128)) >> 16
// B = (YY1 + 116130 *(Cb-128) ) >> 16
//
// The inputs (y, cb, cr) are 8 bit color, ranging in [0x00, 0xff]. In this
// function, the output is also 8 bit color, but in the related YCbCr.RGBA
// method, below, the output is 16 bit color, ranging in [0x0000, 0xffff].
// Outputting 16 bit color simply requires changing the 16 to 8 in the "R =
// etc >> 16" equation, and likewise for G and B.
//
// As mentioned above, a constant rounding adjustment of 1<<15 is a natural
// choice, but there is an additional constraint: if c0 := YCbCr{Y: y, Cb:
// 0x80, Cr: 0x80} and c1 := Gray{Y: y} then c0.RGBA() should equal
// c1.RGBA(). Specifically, if y == 0 then "R = etc >> 8" should yield
// 0x0000 and if y == 0xff then "R = etc >> 8" should yield 0xffff. If we
// used a constant rounding adjustment of 1<<15, then it would yield 0x0080
// and 0xff80 respectively.
//
// Note that when cb == 0x80 and cr == 0x80 then the formulae collapse to:
// R = YY1 >> n
// G = YY1 >> n
// B = YY1 >> n
// where n is 16 for this function (8 bit color output) and 8 for the
// YCbCr.RGBA method (16 bit color output).
//
// The solution is to make the rounding adjustment non-constant, and equal
// to 257*Y', which ranges over [0, 1<<16-1] as Y' ranges over [0, 255].
// YY1 is then defined as:
// YY1 = 65536*Y' + 257*Y'
// or equivalently:
// YY1 = Y' * 0x10101
yy1 := int32(y) * 0x10101
cb1 := int32(cb) - 128
cr1 := int32(cr) - 128
// The bit twiddling below is equivalent to
//
// r := (yy1 + 91881*cr1) >> 16
// if r < 0 {
// r = 0
// } else if r > 0xff {
// r = ^int32(0)
// }
//
// but uses fewer branches and is faster.
// Note that the uint8 type conversion in the return
// statement will convert ^int32(0) to 0xff.
// The code below to compute g and b uses a similar pattern.
r := yy1 + 91881*cr1
if uint32(r)&0xff000000 == 0 {
r >>= 16
} else {
r = ^(r >> 31)
}
g := yy1 - 22554*cb1 - 46802*cr1
if uint32(g)&0xff000000 == 0 {
g >>= 16
} else {
g = ^(g >> 31)
}
b := yy1 + 116130*cb1
if uint32(b)&0xff000000 == 0 {
b >>= 16
} else {
b = ^(b >> 31)
}
return uint8(r), uint8(g), uint8(b)
}
// YCbCr represents a fully opaque 24-bit Y'CbCr color, having 8 bits each for
// one luma and two chroma components.
//
// JPEG, VP8, the MPEG family and other codecs use this color model. Such
// codecs often use the terms YUV and Y'CbCr interchangeably, but strictly
// speaking, the term YUV applies only to analog video signals, and Y' (luma)
// is Y (luminance) after applying gamma correction.
//
// Conversion between RGB and Y'CbCr is lossy and there are multiple, slightly
// different formulae for converting between the two. This package follows
// the JFIF specification at https://www.w3.org/Graphics/JPEG/jfif3.pdf.
type YCbCr struct {
Y, Cb, Cr uint8
}
func (c YCbCr) RGBA() (uint32, uint32, uint32, uint32) {
// This code is a copy of the YCbCrToRGB function above, except that it
// returns values in the range [0, 0xffff] instead of [0, 0xff]. There is a
// subtle difference between doing this and having YCbCr satisfy the Color
// interface by first converting to an RGBA. The latter loses some
// information by going to and from 8 bits per channel.
//
// For example, this code:
// const y, cb, cr = 0x7f, 0x7f, 0x7f
// r, g, b := color.YCbCrToRGB(y, cb, cr)
// r0, g0, b0, _ := color.YCbCr{y, cb, cr}.RGBA()
// r1, g1, b1, _ := color.RGBA{r, g, b, 0xff}.RGBA()
// fmt.Printf("0x%04x 0x%04x 0x%04x\n", r0, g0, b0)
// fmt.Printf("0x%04x 0x%04x 0x%04x\n", r1, g1, b1)
// prints:
// 0x7e18 0x808d 0x7db9
// 0x7e7e 0x8080 0x7d7d
yy1 := int32(c.Y) * 0x10101
cb1 := int32(c.Cb) - 128
cr1 := int32(c.Cr) - 128
// The bit twiddling below is equivalent to
//
// r := (yy1 + 91881*cr1) >> 8
// if r < 0 {
// r = 0
// } else if r > 0xff {
// r = 0xffff
// }
//
// but uses fewer branches and is faster.
// The code below to compute g and b uses a similar pattern.
r := yy1 + 91881*cr1
if uint32(r)&0xff000000 == 0 {
r >>= 8
} else {
r = ^(r >> 31) & 0xffff
}
g := yy1 - 22554*cb1 - 46802*cr1
if uint32(g)&0xff000000 == 0 {
g >>= 8
} else {
g = ^(g >> 31) & 0xffff
}
b := yy1 + 116130*cb1
if uint32(b)&0xff000000 == 0 {
b >>= 8
} else {
b = ^(b >> 31) & 0xffff
}
return uint32(r), uint32(g), uint32(b), 0xffff
}
// YCbCrModel is the Model for Y'CbCr colors.
var YCbCrModel Model = ModelFunc(yCbCrModel)
func yCbCrModel(c Color) Color {
if _, ok := c.(YCbCr); ok {
return c
}
r, g, b, _ := c.RGBA()
y, u, v := RGBToYCbCr(uint8(r>>8), uint8(g>>8), uint8(b>>8))
return YCbCr{y, u, v}
}
// NYCbCrA represents a non-alpha-premultiplied Y'CbCr-with-alpha color, having
// 8 bits each for one luma, two chroma and one alpha component.
type NYCbCrA struct {
YCbCr
A uint8
}
func (c NYCbCrA) RGBA() (uint32, uint32, uint32, uint32) {
// The first part of this method is the same as YCbCr.RGBA.
yy1 := int32(c.Y) * 0x10101
cb1 := int32(c.Cb) - 128
cr1 := int32(c.Cr) - 128
// The bit twiddling below is equivalent to
//
// r := (yy1 + 91881*cr1) >> 8
// if r < 0 {
// r = 0
// } else if r > 0xff {
// r = 0xffff
// }
//
// but uses fewer branches and is faster.
// The code below to compute g and b uses a similar pattern.
r := yy1 + 91881*cr1
if uint32(r)&0xff000000 == 0 {
r >>= 8
} else {
r = ^(r >> 31) & 0xffff
}
g := yy1 - 22554*cb1 - 46802*cr1
if uint32(g)&0xff000000 == 0 {
g >>= 8
} else {
g = ^(g >> 31) & 0xffff
}
b := yy1 + 116130*cb1
if uint32(b)&0xff000000 == 0 {
b >>= 8
} else {
b = ^(b >> 31) & 0xffff
}
// The second part of this method applies the alpha.
a := uint32(c.A) * 0x101
return uint32(r) * a / 0xffff, uint32(g) * a / 0xffff, uint32(b) * a / 0xffff, a
}
// NYCbCrAModel is the Model for non-alpha-premultiplied Y'CbCr-with-alpha
// colors.
var NYCbCrAModel Model = ModelFunc(nYCbCrAModel)
func nYCbCrAModel(c Color) Color {
switch c := c.(type) {
case NYCbCrA:
return c
case YCbCr:
return NYCbCrA{c, 0xff}
}
r, g, b, a := c.RGBA()
// Convert from alpha-premultiplied to non-alpha-premultiplied.
if a != 0 {
r = (r * 0xffff) / a
g = (g * 0xffff) / a
b = (b * 0xffff) / a
}
y, u, v := RGBToYCbCr(uint8(r>>8), uint8(g>>8), uint8(b>>8))
return NYCbCrA{YCbCr{Y: y, Cb: u, Cr: v}, uint8(a >> 8)}
}
// RGBToCMYK converts an RGB triple to a CMYK quadruple.
func RGBToCMYK(r, g, b uint8) (uint8, uint8, uint8, uint8) {
rr := uint32(r)
gg := uint32(g)
bb := uint32(b)
w := rr
if w < gg {
w = gg
}
if w < bb {
w = bb
}
if w == 0 {
return 0, 0, 0, 0xff
}
c := (w - rr) * 0xff / w
m := (w - gg) * 0xff / w
y := (w - bb) * 0xff / w
return uint8(c), uint8(m), uint8(y), uint8(0xff - w)
}
// CMYKToRGB converts a CMYK quadruple to an RGB triple.
func CMYKToRGB(c, m, y, k uint8) (uint8, uint8, uint8) {
w := 0xffff - uint32(k)*0x101
r := (0xffff - uint32(c)*0x101) * w / 0xffff
g := (0xffff - uint32(m)*0x101) * w / 0xffff
b := (0xffff - uint32(y)*0x101) * w / 0xffff
return uint8(r >> 8), uint8(g >> 8), uint8(b >> 8)
}
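// A worked round-trip sketch, assuming pure red:
//
// RGBToCMYK(0xff, 0, 0) // (0, 0xff, 0xff, 0): full magenta and yellow, no black
// CMYKToRGB(0, 0xff, 0xff, 0) // (0xff, 0, 0)
//
// The conversion is lossy in general, but is exact for this value.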
// CMYK represents a fully opaque CMYK color, having 8 bits for each of cyan,
// magenta, yellow and black.
//
// It is not associated with any particular color profile.
type CMYK struct {
C, M, Y, K uint8
}
func (c CMYK) RGBA() (uint32, uint32, uint32, uint32) {
// This code is a copy of the CMYKToRGB function above, except that it
// returns values in the range [0, 0xffff] instead of [0, 0xff].
w := 0xffff - uint32(c.K)*0x101
r := (0xffff - uint32(c.C)*0x101) * w / 0xffff
g := (0xffff - uint32(c.M)*0x101) * w / 0xffff
b := (0xffff - uint32(c.Y)*0x101) * w / 0xffff
return r, g, b, 0xffff
}
// CMYKModel is the Model for CMYK colors.
var CMYKModel Model = ModelFunc(cmykModel)
func cmykModel(c Color) Color {
if _, ok := c.(CMYK); ok {
return c
}
r, g, b, _ := c.RGBA()
cc, mm, yy, kk := RGBToCMYK(uint8(r>>8), uint8(g>>8), uint8(b>>8))
return CMYK{cc, mm, yy, kk}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package draw provides image composition functions.
//
// See "The Go image/draw package" for an introduction to this package:
// https://golang.org/doc/articles/image_draw.html
package draw
import (
"image"
"image/color"
"image/internal/imageutil"
)
// m is the maximum color value returned by color.Color.RGBA.
const m = 1<<16 - 1
// Image is an image.Image with a Set method to change a single pixel.
type Image interface {
image.Image
Set(x, y int, c color.Color)
}
// RGBA64Image extends both the Image and image.RGBA64Image interfaces with a
// SetRGBA64 method to change a single pixel. SetRGBA64 is equivalent to
// calling Set, but it can avoid allocations from converting concrete color
// types to the color.Color interface type.
type RGBA64Image interface {
image.RGBA64Image
Set(x, y int, c color.Color)
SetRGBA64(x, y int, c color.RGBA64)
}
// Quantizer produces a palette for an image.
type Quantizer interface {
// Quantize appends up to cap(p) - len(p) colors to p and returns the
// updated palette suitable for converting m to a paletted image.
Quantize(p color.Palette, m image.Image) color.Palette
}
// Op is a Porter-Duff compositing operator.
type Op int
const (
// Over specifies ``(src in mask) over dst''.
Over Op = iota
// Src specifies ``src in mask''.
Src
)
// Draw implements the Drawer interface by calling the Draw function with this
// Op.
func (op Op) Draw(dst Image, r image.Rectangle, src image.Image, sp image.Point) {
DrawMask(dst, r, src, sp, nil, image.Point{}, op)
}
// Drawer contains the Draw method.
type Drawer interface {
// Draw aligns r.Min in dst with sp in src and then replaces the
// rectangle r in dst with the result of drawing src on dst.
Draw(dst Image, r image.Rectangle, src image.Image, sp image.Point)
}
// FloydSteinberg is a Drawer that is the Src Op with Floyd-Steinberg error
// diffusion.
var FloydSteinberg Drawer = floydSteinberg{}
type floydSteinberg struct{}
func (floydSteinberg) Draw(dst Image, r image.Rectangle, src image.Image, sp image.Point) {
clip(dst, &r, src, &sp, nil, nil)
if r.Empty() {
return
}
drawPaletted(dst, r, src, sp, true)
}
// clip clips r against each image's bounds (after translating into the
// destination image's coordinate space) and shifts the points sp and mp by
// the same amount as the change in r.Min.
func clip(dst Image, r *image.Rectangle, src image.Image, sp *image.Point, mask image.Image, mp *image.Point) {
orig := r.Min
*r = r.Intersect(dst.Bounds())
*r = r.Intersect(src.Bounds().Add(orig.Sub(*sp)))
if mask != nil {
*r = r.Intersect(mask.Bounds().Add(orig.Sub(*mp)))
}
dx := r.Min.X - orig.X
dy := r.Min.Y - orig.Y
if dx == 0 && dy == 0 {
return
}
sp.X += dx
sp.Y += dy
if mp != nil {
mp.X += dx
mp.Y += dy
}
}
// processBackward reports whether dst and src overlap in a way that requires
// composing the pixels in reverse order, so that source pixels are read
// before they are overwritten.
func processBackward(dst image.Image, r image.Rectangle, src image.Image, sp image.Point) bool {
return dst == src &&
r.Overlaps(r.Add(sp.Sub(r.Min))) &&
(sp.Y < r.Min.Y || (sp.Y == r.Min.Y && sp.X < r.Min.X))
}
// Draw calls DrawMask with a nil mask.
func Draw(dst Image, r image.Rectangle, src image.Image, sp image.Point, op Op) {
DrawMask(dst, r, src, sp, nil, image.Point{}, op)
}
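// A minimal usage sketch, assuming a hypothetical 100x100 destination:
//
// dst := image.NewRGBA(image.Rect(0, 0, 100, 100))
// blue := &image.Uniform{C: color.RGBA{B: 0xff, A: 0xff}}
// Draw(dst, dst.Bounds(), blue, image.Point{}, Src)
//
// Because the source is an opaque *image.Uniform and the mask is nil,
// DrawMask takes the drawFillSrc fast path below.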
// DrawMask aligns r.Min in dst with sp in src and mp in mask and then replaces the rectangle r
// in dst with the result of a Porter-Duff composition. A nil mask is treated as opaque.
func DrawMask(dst Image, r image.Rectangle, src image.Image, sp image.Point, mask image.Image, mp image.Point, op Op) {
clip(dst, &r, src, &sp, mask, &mp)
if r.Empty() {
return
}
// Fast paths for special cases. If none of them apply, then we fall back
// to general but slower implementations.
//
// For NRGBA and NRGBA64 image types, the code paths aren't just faster.
// They also avoid the information loss that would otherwise occur from
// converting non-alpha-premultiplied color to and from alpha-premultiplied
// color. See TestDrawSrcNonpremultiplied.
switch dst0 := dst.(type) {
case *image.RGBA:
if op == Over {
if mask == nil {
switch src0 := src.(type) {
case *image.Uniform:
sr, sg, sb, sa := src0.RGBA()
if sa == 0xffff {
drawFillSrc(dst0, r, sr, sg, sb, sa)
} else {
drawFillOver(dst0, r, sr, sg, sb, sa)
}
return
case *image.RGBA:
drawCopyOver(dst0, r, src0, sp)
return
case *image.NRGBA:
drawNRGBAOver(dst0, r, src0, sp)
return
case *image.YCbCr:
// An image.YCbCr is always fully opaque, and so if the
// mask is nil (i.e. fully opaque) then the op is
// effectively always Src. Similarly for image.Gray and
// image.CMYK.
if imageutil.DrawYCbCr(dst0, r, src0, sp) {
return
}
case *image.Gray:
drawGray(dst0, r, src0, sp)
return
case *image.CMYK:
drawCMYK(dst0, r, src0, sp)
return
}
} else if mask0, ok := mask.(*image.Alpha); ok {
switch src0 := src.(type) {
case *image.Uniform:
drawGlyphOver(dst0, r, src0, mask0, mp)
return
case *image.RGBA:
drawRGBAMaskOver(dst0, r, src0, sp, mask0, mp)
return
case *image.Gray:
drawGrayMaskOver(dst0, r, src0, sp, mask0, mp)
return
// Case order matters. The next case (image.RGBA64Image) is an
// interface type that the concrete types above also implement.
case image.RGBA64Image:
drawRGBA64ImageMaskOver(dst0, r, src0, sp, mask0, mp)
return
}
}
} else {
if mask == nil {
switch src0 := src.(type) {
case *image.Uniform:
sr, sg, sb, sa := src0.RGBA()
drawFillSrc(dst0, r, sr, sg, sb, sa)
return
case *image.RGBA:
d0 := dst0.PixOffset(r.Min.X, r.Min.Y)
s0 := src0.PixOffset(sp.X, sp.Y)
drawCopySrc(
dst0.Pix[d0:], dst0.Stride, r, src0.Pix[s0:], src0.Stride, sp, 4*r.Dx())
return
case *image.NRGBA:
drawNRGBASrc(dst0, r, src0, sp)
return
case *image.YCbCr:
if imageutil.DrawYCbCr(dst0, r, src0, sp) {
return
}
case *image.Gray:
drawGray(dst0, r, src0, sp)
return
case *image.CMYK:
drawCMYK(dst0, r, src0, sp)
return
}
}
}
drawRGBA(dst0, r, src, sp, mask, mp, op)
return
case *image.Paletted:
if op == Src && mask == nil {
if src0, ok := src.(*image.Uniform); ok {
colorIndex := uint8(dst0.Palette.Index(src0.C))
i0 := dst0.PixOffset(r.Min.X, r.Min.Y)
i1 := i0 + r.Dx()
for i := i0; i < i1; i++ {
dst0.Pix[i] = colorIndex
}
firstRow := dst0.Pix[i0:i1]
for y := r.Min.Y + 1; y < r.Max.Y; y++ {
i0 += dst0.Stride
i1 += dst0.Stride
copy(dst0.Pix[i0:i1], firstRow)
}
return
} else if !processBackward(dst, r, src, sp) {
drawPaletted(dst0, r, src, sp, false)
return
}
}
case *image.NRGBA:
if op == Src && mask == nil {
if src0, ok := src.(*image.NRGBA); ok {
d0 := dst0.PixOffset(r.Min.X, r.Min.Y)
s0 := src0.PixOffset(sp.X, sp.Y)
drawCopySrc(
dst0.Pix[d0:], dst0.Stride, r, src0.Pix[s0:], src0.Stride, sp, 4*r.Dx())
return
}
}
case *image.NRGBA64:
if op == Src && mask == nil {
if src0, ok := src.(*image.NRGBA64); ok {
d0 := dst0.PixOffset(r.Min.X, r.Min.Y)
s0 := src0.PixOffset(sp.X, sp.Y)
drawCopySrc(
dst0.Pix[d0:], dst0.Stride, r, src0.Pix[s0:], src0.Stride, sp, 8*r.Dx())
return
}
}
}
x0, x1, dx := r.Min.X, r.Max.X, 1
y0, y1, dy := r.Min.Y, r.Max.Y, 1
if processBackward(dst, r, src, sp) {
x0, x1, dx = x1-1, x0-1, -1
y0, y1, dy = y1-1, y0-1, -1
}
// FALLBACK1.17
//
// Try the draw.RGBA64Image and image.RGBA64Image interfaces, part of the
// standard library since Go 1.17. These are like the draw.Image and
// image.Image interfaces but they can avoid allocations from converting
// concrete color types to the color.Color interface type.
if dst0, _ := dst.(RGBA64Image); dst0 != nil {
if src0, _ := src.(image.RGBA64Image); src0 != nil {
if mask == nil {
sy := sp.Y + y0 - r.Min.Y
my := mp.Y + y0 - r.Min.Y
for y := y0; y != y1; y, sy, my = y+dy, sy+dy, my+dy {
sx := sp.X + x0 - r.Min.X
mx := mp.X + x0 - r.Min.X
for x := x0; x != x1; x, sx, mx = x+dx, sx+dx, mx+dx {
if op == Src {
dst0.SetRGBA64(x, y, src0.RGBA64At(sx, sy))
} else {
srgba := src0.RGBA64At(sx, sy)
a := m - uint32(srgba.A)
drgba := dst0.RGBA64At(x, y)
dst0.SetRGBA64(x, y, color.RGBA64{
R: uint16((uint32(drgba.R)*a)/m) + srgba.R,
G: uint16((uint32(drgba.G)*a)/m) + srgba.G,
B: uint16((uint32(drgba.B)*a)/m) + srgba.B,
A: uint16((uint32(drgba.A)*a)/m) + srgba.A,
})
}
}
}
return
} else if mask0, _ := mask.(image.RGBA64Image); mask0 != nil {
sy := sp.Y + y0 - r.Min.Y
my := mp.Y + y0 - r.Min.Y
for y := y0; y != y1; y, sy, my = y+dy, sy+dy, my+dy {
sx := sp.X + x0 - r.Min.X
mx := mp.X + x0 - r.Min.X
for x := x0; x != x1; x, sx, mx = x+dx, sx+dx, mx+dx {
ma := uint32(mask0.RGBA64At(mx, my).A)
switch {
case ma == 0:
if op == Over {
// No-op.
} else {
dst0.SetRGBA64(x, y, color.RGBA64{})
}
case ma == m && op == Src:
dst0.SetRGBA64(x, y, src0.RGBA64At(sx, sy))
default:
srgba := src0.RGBA64At(sx, sy)
if op == Over {
drgba := dst0.RGBA64At(x, y)
a := m - (uint32(srgba.A) * ma / m)
dst0.SetRGBA64(x, y, color.RGBA64{
R: uint16((uint32(drgba.R)*a + uint32(srgba.R)*ma) / m),
G: uint16((uint32(drgba.G)*a + uint32(srgba.G)*ma) / m),
B: uint16((uint32(drgba.B)*a + uint32(srgba.B)*ma) / m),
A: uint16((uint32(drgba.A)*a + uint32(srgba.A)*ma) / m),
})
} else {
dst0.SetRGBA64(x, y, color.RGBA64{
R: uint16(uint32(srgba.R) * ma / m),
G: uint16(uint32(srgba.G) * ma / m),
B: uint16(uint32(srgba.B) * ma / m),
A: uint16(uint32(srgba.A) * ma / m),
})
}
}
}
}
return
}
}
}
// FALLBACK1.0
//
// If none of the faster code paths above apply, use the draw.Image and
// image.Image interfaces, part of the standard library since Go 1.0.
var out color.RGBA64
sy := sp.Y + y0 - r.Min.Y
my := mp.Y + y0 - r.Min.Y
for y := y0; y != y1; y, sy, my = y+dy, sy+dy, my+dy {
sx := sp.X + x0 - r.Min.X
mx := mp.X + x0 - r.Min.X
for x := x0; x != x1; x, sx, mx = x+dx, sx+dx, mx+dx {
ma := uint32(m)
if mask != nil {
_, _, _, ma = mask.At(mx, my).RGBA()
}
switch {
case ma == 0:
if op == Over {
// No-op.
} else {
dst.Set(x, y, color.Transparent)
}
case ma == m && op == Src:
dst.Set(x, y, src.At(sx, sy))
default:
sr, sg, sb, sa := src.At(sx, sy).RGBA()
if op == Over {
dr, dg, db, da := dst.At(x, y).RGBA()
a := m - (sa * ma / m)
out.R = uint16((dr*a + sr*ma) / m)
out.G = uint16((dg*a + sg*ma) / m)
out.B = uint16((db*a + sb*ma) / m)
out.A = uint16((da*a + sa*ma) / m)
} else {
out.R = uint16(sr * ma / m)
out.G = uint16(sg * ma / m)
out.B = uint16(sb * ma / m)
out.A = uint16(sa * ma / m)
}
// The third argument is &out instead of out (and out is
// declared outside of the inner loop) to avoid the implicit
// conversion to color.Color here allocating memory in the
// inner loop if sizeof(color.RGBA64) > sizeof(uintptr).
dst.Set(x, y, &out)
}
}
}
}
func drawFillOver(dst *image.RGBA, r image.Rectangle, sr, sg, sb, sa uint32) {
// The 0x101 is here for the same reason as in drawRGBA.
a := (m - sa) * 0x101
i0 := dst.PixOffset(r.Min.X, r.Min.Y)
i1 := i0 + r.Dx()*4
for y := r.Min.Y; y != r.Max.Y; y++ {
for i := i0; i < i1; i += 4 {
dr := &dst.Pix[i+0]
dg := &dst.Pix[i+1]
db := &dst.Pix[i+2]
da := &dst.Pix[i+3]
*dr = uint8((uint32(*dr)*a/m + sr) >> 8)
*dg = uint8((uint32(*dg)*a/m + sg) >> 8)
*db = uint8((uint32(*db)*a/m + sb) >> 8)
*da = uint8((uint32(*da)*a/m + sa) >> 8)
}
i0 += dst.Stride
i1 += dst.Stride
}
}
func drawFillSrc(dst *image.RGBA, r image.Rectangle, sr, sg, sb, sa uint32) {
sr8 := uint8(sr >> 8)
sg8 := uint8(sg >> 8)
sb8 := uint8(sb >> 8)
sa8 := uint8(sa >> 8)
// The built-in copy function is faster than a straightforward for loop to fill the destination with
// the color, but copy requires a slice source. We therefore use a for loop to fill the first row, and
// then use the first row as the slice source for the remaining rows.
i0 := dst.PixOffset(r.Min.X, r.Min.Y)
i1 := i0 + r.Dx()*4
for i := i0; i < i1; i += 4 {
dst.Pix[i+0] = sr8
dst.Pix[i+1] = sg8
dst.Pix[i+2] = sb8
dst.Pix[i+3] = sa8
}
firstRow := dst.Pix[i0:i1]
for y := r.Min.Y + 1; y < r.Max.Y; y++ {
i0 += dst.Stride
i1 += dst.Stride
copy(dst.Pix[i0:i1], firstRow)
}
}
func drawCopyOver(dst *image.RGBA, r image.Rectangle, src *image.RGBA, sp image.Point) {
dx, dy := r.Dx(), r.Dy()
d0 := dst.PixOffset(r.Min.X, r.Min.Y)
s0 := src.PixOffset(sp.X, sp.Y)
var (
ddelta, sdelta int
i0, i1, idelta int
)
if r.Min.Y < sp.Y || r.Min.Y == sp.Y && r.Min.X <= sp.X {
ddelta = dst.Stride
sdelta = src.Stride
i0, i1, idelta = 0, dx*4, +4
} else {
// If the source start point is higher than the destination start point, or equal height but to the left,
// then we compose the rows in right-to-left, bottom-up order instead of left-to-right, top-down.
d0 += (dy - 1) * dst.Stride
s0 += (dy - 1) * src.Stride
ddelta = -dst.Stride
sdelta = -src.Stride
i0, i1, idelta = (dx-1)*4, -4, -4
}
for ; dy > 0; dy-- {
dpix := dst.Pix[d0:]
spix := src.Pix[s0:]
for i := i0; i != i1; i += idelta {
s := spix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
sr := uint32(s[0]) * 0x101
sg := uint32(s[1]) * 0x101
sb := uint32(s[2]) * 0x101
sa := uint32(s[3]) * 0x101
// The 0x101 is here for the same reason as in drawRGBA.
a := (m - sa) * 0x101
d := dpix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
d[0] = uint8((uint32(d[0])*a/m + sr) >> 8)
d[1] = uint8((uint32(d[1])*a/m + sg) >> 8)
d[2] = uint8((uint32(d[2])*a/m + sb) >> 8)
d[3] = uint8((uint32(d[3])*a/m + sa) >> 8)
}
d0 += ddelta
s0 += sdelta
}
}
// drawCopySrc copies bytes to dstPix from srcPix. These arguments roughly
// correspond to the Pix fields of the image package's concrete image.Image
// implementations, but are offset (dstPix is dst.Pix[dpOffset:] not dst.Pix).
func drawCopySrc(
dstPix []byte, dstStride int, r image.Rectangle,
srcPix []byte, srcStride int, sp image.Point,
bytesPerRow int) {
d0, s0, ddelta, sdelta, dy := 0, 0, dstStride, srcStride, r.Dy()
if r.Min.Y > sp.Y {
// If the source start point is higher than the destination start
// point, then we compose the rows in bottom-up order instead of
// top-down. Unlike the drawCopyOver function, we don't have to check
// the x coordinates because the built-in copy function can handle
// overlapping slices.
d0 = (dy - 1) * dstStride
s0 = (dy - 1) * srcStride
ddelta = -dstStride
sdelta = -srcStride
}
for ; dy > 0; dy-- {
copy(dstPix[d0:d0+bytesPerRow], srcPix[s0:s0+bytesPerRow])
d0 += ddelta
s0 += sdelta
}
}
func drawNRGBAOver(dst *image.RGBA, r image.Rectangle, src *image.NRGBA, sp image.Point) {
i0 := (r.Min.X - dst.Rect.Min.X) * 4
i1 := (r.Max.X - dst.Rect.Min.X) * 4
si0 := (sp.X - src.Rect.Min.X) * 4
yMax := r.Max.Y - dst.Rect.Min.Y
y := r.Min.Y - dst.Rect.Min.Y
sy := sp.Y - src.Rect.Min.Y
for ; y != yMax; y, sy = y+1, sy+1 {
dpix := dst.Pix[y*dst.Stride:]
spix := src.Pix[sy*src.Stride:]
for i, si := i0, si0; i < i1; i, si = i+4, si+4 {
// Convert from non-premultiplied color to pre-multiplied color.
s := spix[si : si+4 : si+4] // Small cap improves performance, see https://golang.org/issue/27857
sa := uint32(s[3]) * 0x101
sr := uint32(s[0]) * sa / 0xff
sg := uint32(s[1]) * sa / 0xff
sb := uint32(s[2]) * sa / 0xff
d := dpix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
dr := uint32(d[0])
dg := uint32(d[1])
db := uint32(d[2])
da := uint32(d[3])
// The 0x101 is here for the same reason as in drawRGBA.
a := (m - sa) * 0x101
d[0] = uint8((dr*a/m + sr) >> 8)
d[1] = uint8((dg*a/m + sg) >> 8)
d[2] = uint8((db*a/m + sb) >> 8)
d[3] = uint8((da*a/m + sa) >> 8)
}
}
}
func drawNRGBASrc(dst *image.RGBA, r image.Rectangle, src *image.NRGBA, sp image.Point) {
i0 := (r.Min.X - dst.Rect.Min.X) * 4
i1 := (r.Max.X - dst.Rect.Min.X) * 4
si0 := (sp.X - src.Rect.Min.X) * 4
yMax := r.Max.Y - dst.Rect.Min.Y
y := r.Min.Y - dst.Rect.Min.Y
sy := sp.Y - src.Rect.Min.Y
for ; y != yMax; y, sy = y+1, sy+1 {
dpix := dst.Pix[y*dst.Stride:]
spix := src.Pix[sy*src.Stride:]
for i, si := i0, si0; i < i1; i, si = i+4, si+4 {
// Convert from non-premultiplied color to pre-multiplied color.
s := spix[si : si+4 : si+4] // Small cap improves performance, see https://golang.org/issue/27857
sa := uint32(s[3]) * 0x101
sr := uint32(s[0]) * sa / 0xff
sg := uint32(s[1]) * sa / 0xff
sb := uint32(s[2]) * sa / 0xff
d := dpix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
d[0] = uint8(sr >> 8)
d[1] = uint8(sg >> 8)
d[2] = uint8(sb >> 8)
d[3] = uint8(sa >> 8)
}
}
}
func drawGray(dst *image.RGBA, r image.Rectangle, src *image.Gray, sp image.Point) {
i0 := (r.Min.X - dst.Rect.Min.X) * 4
i1 := (r.Max.X - dst.Rect.Min.X) * 4
si0 := (sp.X - src.Rect.Min.X) * 1
yMax := r.Max.Y - dst.Rect.Min.Y
y := r.Min.Y - dst.Rect.Min.Y
sy := sp.Y - src.Rect.Min.Y
for ; y != yMax; y, sy = y+1, sy+1 {
dpix := dst.Pix[y*dst.Stride:]
spix := src.Pix[sy*src.Stride:]
for i, si := i0, si0; i < i1; i, si = i+4, si+1 {
p := spix[si]
d := dpix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
d[0] = p
d[1] = p
d[2] = p
d[3] = 255
}
}
}
func drawCMYK(dst *image.RGBA, r image.Rectangle, src *image.CMYK, sp image.Point) {
i0 := (r.Min.X - dst.Rect.Min.X) * 4
i1 := (r.Max.X - dst.Rect.Min.X) * 4
si0 := (sp.X - src.Rect.Min.X) * 4
yMax := r.Max.Y - dst.Rect.Min.Y
y := r.Min.Y - dst.Rect.Min.Y
sy := sp.Y - src.Rect.Min.Y
for ; y != yMax; y, sy = y+1, sy+1 {
dpix := dst.Pix[y*dst.Stride:]
spix := src.Pix[sy*src.Stride:]
for i, si := i0, si0; i < i1; i, si = i+4, si+4 {
s := spix[si : si+4 : si+4] // Small cap improves performance, see https://golang.org/issue/27857
d := dpix[i : i+4 : i+4]
d[0], d[1], d[2] = color.CMYKToRGB(s[0], s[1], s[2], s[3])
d[3] = 255
}
}
}
func drawGlyphOver(dst *image.RGBA, r image.Rectangle, src *image.Uniform, mask *image.Alpha, mp image.Point) {
i0 := dst.PixOffset(r.Min.X, r.Min.Y)
i1 := i0 + r.Dx()*4
mi0 := mask.PixOffset(mp.X, mp.Y)
sr, sg, sb, sa := src.RGBA()
for y, my := r.Min.Y, mp.Y; y != r.Max.Y; y, my = y+1, my+1 {
for i, mi := i0, mi0; i < i1; i, mi = i+4, mi+1 {
ma := uint32(mask.Pix[mi])
if ma == 0 {
continue
}
ma |= ma << 8
// The 0x101 is here for the same reason as in drawRGBA.
a := (m - (sa * ma / m)) * 0x101
d := dst.Pix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
d[0] = uint8((uint32(d[0])*a + sr*ma) / m >> 8)
d[1] = uint8((uint32(d[1])*a + sg*ma) / m >> 8)
d[2] = uint8((uint32(d[2])*a + sb*ma) / m >> 8)
d[3] = uint8((uint32(d[3])*a + sa*ma) / m >> 8)
}
i0 += dst.Stride
i1 += dst.Stride
mi0 += mask.Stride
}
}
func drawGrayMaskOver(dst *image.RGBA, r image.Rectangle, src *image.Gray, sp image.Point, mask *image.Alpha, mp image.Point) {
x0, x1, dx := r.Min.X, r.Max.X, 1
y0, y1, dy := r.Min.Y, r.Max.Y, 1
if r.Overlaps(r.Add(sp.Sub(r.Min))) {
if sp.Y < r.Min.Y || sp.Y == r.Min.Y && sp.X < r.Min.X {
x0, x1, dx = x1-1, x0-1, -1
y0, y1, dy = y1-1, y0-1, -1
}
}
sy := sp.Y + y0 - r.Min.Y
my := mp.Y + y0 - r.Min.Y
sx0 := sp.X + x0 - r.Min.X
mx0 := mp.X + x0 - r.Min.X
sx1 := sx0 + (x1 - x0)
i0 := dst.PixOffset(x0, y0)
di := dx * 4
for y := y0; y != y1; y, sy, my = y+dy, sy+dy, my+dy {
for i, sx, mx := i0, sx0, mx0; sx != sx1; i, sx, mx = i+di, sx+dx, mx+dx {
mi := mask.PixOffset(mx, my)
ma := uint32(mask.Pix[mi])
ma |= ma << 8
si := src.PixOffset(sx, sy)
sy := uint32(src.Pix[si])
sy |= sy << 8
sa := uint32(0xffff)
d := dst.Pix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
dr := uint32(d[0])
dg := uint32(d[1])
db := uint32(d[2])
da := uint32(d[3])
// dr, dg, db and da are all 8-bit color at the moment, ranging in [0,255].
// We work in 16-bit color, and so would normally do:
// dr |= dr << 8
// and similarly for dg, db and da, but instead we multiply a
// (which is a 16-bit color, ranging in [0,65535]) by 0x101.
// This yields the same result, but is fewer arithmetic operations.
a := (m - (sa * ma / m)) * 0x101
d[0] = uint8((dr*a + sy*ma) / m >> 8)
d[1] = uint8((dg*a + sy*ma) / m >> 8)
d[2] = uint8((db*a + sy*ma) / m >> 8)
d[3] = uint8((da*a + sa*ma) / m >> 8)
}
i0 += dy * dst.Stride
}
}
func drawRGBAMaskOver(dst *image.RGBA, r image.Rectangle, src *image.RGBA, sp image.Point, mask *image.Alpha, mp image.Point) {
x0, x1, dx := r.Min.X, r.Max.X, 1
y0, y1, dy := r.Min.Y, r.Max.Y, 1
if dst == src && r.Overlaps(r.Add(sp.Sub(r.Min))) {
if sp.Y < r.Min.Y || sp.Y == r.Min.Y && sp.X < r.Min.X {
x0, x1, dx = x1-1, x0-1, -1
y0, y1, dy = y1-1, y0-1, -1
}
}
sy := sp.Y + y0 - r.Min.Y
my := mp.Y + y0 - r.Min.Y
sx0 := sp.X + x0 - r.Min.X
mx0 := mp.X + x0 - r.Min.X
sx1 := sx0 + (x1 - x0)
i0 := dst.PixOffset(x0, y0)
di := dx * 4
for y := y0; y != y1; y, sy, my = y+dy, sy+dy, my+dy {
for i, sx, mx := i0, sx0, mx0; sx != sx1; i, sx, mx = i+di, sx+dx, mx+dx {
mi := mask.PixOffset(mx, my)
ma := uint32(mask.Pix[mi])
ma |= ma << 8
si := src.PixOffset(sx, sy)
sr := uint32(src.Pix[si+0])
sg := uint32(src.Pix[si+1])
sb := uint32(src.Pix[si+2])
sa := uint32(src.Pix[si+3])
sr |= sr << 8
sg |= sg << 8
sb |= sb << 8
sa |= sa << 8
d := dst.Pix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
dr := uint32(d[0])
dg := uint32(d[1])
db := uint32(d[2])
da := uint32(d[3])
// dr, dg, db and da are all 8-bit color at the moment, ranging in [0,255].
// We work in 16-bit color, and so would normally do:
// dr |= dr << 8
// and similarly for dg, db and da, but instead we multiply a
// (which is a 16-bit color, ranging in [0,65535]) by 0x101.
// This yields the same result, but is fewer arithmetic operations.
a := (m - (sa * ma / m)) * 0x101
d[0] = uint8((dr*a + sr*ma) / m >> 8)
d[1] = uint8((dg*a + sg*ma) / m >> 8)
d[2] = uint8((db*a + sb*ma) / m >> 8)
d[3] = uint8((da*a + sa*ma) / m >> 8)
}
i0 += dy * dst.Stride
}
}
func drawRGBA64ImageMaskOver(dst *image.RGBA, r image.Rectangle, src image.RGBA64Image, sp image.Point, mask *image.Alpha, mp image.Point) {
x0, x1, dx := r.Min.X, r.Max.X, 1
y0, y1, dy := r.Min.Y, r.Max.Y, 1
if image.Image(dst) == src && r.Overlaps(r.Add(sp.Sub(r.Min))) {
if sp.Y < r.Min.Y || sp.Y == r.Min.Y && sp.X < r.Min.X {
x0, x1, dx = x1-1, x0-1, -1
y0, y1, dy = y1-1, y0-1, -1
}
}
sy := sp.Y + y0 - r.Min.Y
my := mp.Y + y0 - r.Min.Y
sx0 := sp.X + x0 - r.Min.X
mx0 := mp.X + x0 - r.Min.X
sx1 := sx0 + (x1 - x0)
i0 := dst.PixOffset(x0, y0)
di := dx * 4
for y := y0; y != y1; y, sy, my = y+dy, sy+dy, my+dy {
for i, sx, mx := i0, sx0, mx0; sx != sx1; i, sx, mx = i+di, sx+dx, mx+dx {
mi := mask.PixOffset(mx, my)
ma := uint32(mask.Pix[mi])
ma |= ma << 8
srgba := src.RGBA64At(sx, sy)
d := dst.Pix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
dr := uint32(d[0])
dg := uint32(d[1])
db := uint32(d[2])
da := uint32(d[3])
// dr, dg, db and da are all 8-bit color at the moment, ranging in [0,255].
// We work in 16-bit color, and so would normally do:
// dr |= dr << 8
// and similarly for dg, db and da, but instead we multiply a
// (which is a 16-bit color, ranging in [0,65535]) by 0x101.
// This yields the same result, but is fewer arithmetic operations.
a := (m - (uint32(srgba.A) * ma / m)) * 0x101
d[0] = uint8((dr*a + uint32(srgba.R)*ma) / m >> 8)
d[1] = uint8((dg*a + uint32(srgba.G)*ma) / m >> 8)
d[2] = uint8((db*a + uint32(srgba.B)*ma) / m >> 8)
d[3] = uint8((da*a + uint32(srgba.A)*ma) / m >> 8)
}
i0 += dy * dst.Stride
}
}
func drawRGBA(dst *image.RGBA, r image.Rectangle, src image.Image, sp image.Point, mask image.Image, mp image.Point, op Op) {
x0, x1, dx := r.Min.X, r.Max.X, 1
y0, y1, dy := r.Min.Y, r.Max.Y, 1
if image.Image(dst) == src && r.Overlaps(r.Add(sp.Sub(r.Min))) {
if sp.Y < r.Min.Y || sp.Y == r.Min.Y && sp.X < r.Min.X {
x0, x1, dx = x1-1, x0-1, -1
y0, y1, dy = y1-1, y0-1, -1
}
}
sy := sp.Y + y0 - r.Min.Y
my := mp.Y + y0 - r.Min.Y
sx0 := sp.X + x0 - r.Min.X
mx0 := mp.X + x0 - r.Min.X
sx1 := sx0 + (x1 - x0)
i0 := dst.PixOffset(x0, y0)
di := dx * 4
// Try the image.RGBA64Image interface, part of the standard library since
// Go 1.17.
//
// This optimization is similar to how FALLBACK1.17 optimizes FALLBACK1.0
// in DrawMask, except here the concrete type of dst is known to be
// *image.RGBA.
if src0, _ := src.(image.RGBA64Image); src0 != nil {
if mask == nil {
if op == Over {
for y := y0; y != y1; y, sy, my = y+dy, sy+dy, my+dy {
for i, sx, mx := i0, sx0, mx0; sx != sx1; i, sx, mx = i+di, sx+dx, mx+dx {
srgba := src0.RGBA64At(sx, sy)
d := dst.Pix[i : i+4 : i+4]
dr := uint32(d[0])
dg := uint32(d[1])
db := uint32(d[2])
da := uint32(d[3])
a := (m - uint32(srgba.A)) * 0x101
d[0] = uint8((dr*a/m + uint32(srgba.R)) >> 8)
d[1] = uint8((dg*a/m + uint32(srgba.G)) >> 8)
d[2] = uint8((db*a/m + uint32(srgba.B)) >> 8)
d[3] = uint8((da*a/m + uint32(srgba.A)) >> 8)
}
i0 += dy * dst.Stride
}
} else {
for y := y0; y != y1; y, sy, my = y+dy, sy+dy, my+dy {
for i, sx, mx := i0, sx0, mx0; sx != sx1; i, sx, mx = i+di, sx+dx, mx+dx {
srgba := src0.RGBA64At(sx, sy)
d := dst.Pix[i : i+4 : i+4]
d[0] = uint8(srgba.R >> 8)
d[1] = uint8(srgba.G >> 8)
d[2] = uint8(srgba.B >> 8)
d[3] = uint8(srgba.A >> 8)
}
i0 += dy * dst.Stride
}
}
return
} else if mask0, _ := mask.(image.RGBA64Image); mask0 != nil {
if op == Over {
for y := y0; y != y1; y, sy, my = y+dy, sy+dy, my+dy {
for i, sx, mx := i0, sx0, mx0; sx != sx1; i, sx, mx = i+di, sx+dx, mx+dx {
ma := uint32(mask0.RGBA64At(mx, my).A)
srgba := src0.RGBA64At(sx, sy)
d := dst.Pix[i : i+4 : i+4]
dr := uint32(d[0])
dg := uint32(d[1])
db := uint32(d[2])
da := uint32(d[3])
a := (m - (uint32(srgba.A) * ma / m)) * 0x101
d[0] = uint8((dr*a + uint32(srgba.R)*ma) / m >> 8)
d[1] = uint8((dg*a + uint32(srgba.G)*ma) / m >> 8)
d[2] = uint8((db*a + uint32(srgba.B)*ma) / m >> 8)
d[3] = uint8((da*a + uint32(srgba.A)*ma) / m >> 8)
}
i0 += dy * dst.Stride
}
} else {
for y := y0; y != y1; y, sy, my = y+dy, sy+dy, my+dy {
for i, sx, mx := i0, sx0, mx0; sx != sx1; i, sx, mx = i+di, sx+dx, mx+dx {
ma := uint32(mask0.RGBA64At(mx, my).A)
srgba := src0.RGBA64At(sx, sy)
d := dst.Pix[i : i+4 : i+4]
d[0] = uint8(uint32(srgba.R) * ma / m >> 8)
d[1] = uint8(uint32(srgba.G) * ma / m >> 8)
d[2] = uint8(uint32(srgba.B) * ma / m >> 8)
d[3] = uint8(uint32(srgba.A) * ma / m >> 8)
}
i0 += dy * dst.Stride
}
}
return
}
}
// Use the image.Image interface, part of the standard library since Go
// 1.0.
//
// This is similar to FALLBACK1.0 in DrawMask, except here the concrete
// type of dst is known to be *image.RGBA.
for y := y0; y != y1; y, sy, my = y+dy, sy+dy, my+dy {
for i, sx, mx := i0, sx0, mx0; sx != sx1; i, sx, mx = i+di, sx+dx, mx+dx {
ma := uint32(m)
if mask != nil {
_, _, _, ma = mask.At(mx, my).RGBA()
}
sr, sg, sb, sa := src.At(sx, sy).RGBA()
d := dst.Pix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
if op == Over {
dr := uint32(d[0])
dg := uint32(d[1])
db := uint32(d[2])
da := uint32(d[3])
// dr, dg, db and da are all 8-bit color at the moment, ranging in [0,255].
// We work in 16-bit color, and so would normally do:
// dr |= dr << 8
// and similarly for dg, db and da, but instead we multiply a
// (which is a 16-bit color, ranging in [0,65535]) by 0x101.
// This yields the same result, but is fewer arithmetic operations.
a := (m - (sa * ma / m)) * 0x101
d[0] = uint8((dr*a + sr*ma) / m >> 8)
d[1] = uint8((dg*a + sg*ma) / m >> 8)
d[2] = uint8((db*a + sb*ma) / m >> 8)
d[3] = uint8((da*a + sa*ma) / m >> 8)
} else {
d[0] = uint8(sr * ma / m >> 8)
d[1] = uint8(sg * ma / m >> 8)
d[2] = uint8(sb * ma / m >> 8)
d[3] = uint8(sa * ma / m >> 8)
}
}
i0 += dy * dst.Stride
}
}
// clamp clamps i to the interval [0, 0xffff].
func clamp(i int32) int32 {
if i < 0 {
return 0
}
if i > 0xffff {
return 0xffff
}
return i
}
// sqDiff returns the squared-difference of x and y, shifted by 2 so that
// adding four of those won't overflow a uint32.
//
// x and y are both assumed to be in the range [0, 0xffff].
func sqDiff(x, y int32) uint32 {
// This is optimized code relying on the overflow/wrap-around properties
// of unsigned integer operations, as guaranteed by the language spec.
// See sqDiff from the image/color package for more details.
d := uint32(x - y)
return (d * d) >> 2
}
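// sqDiffWorks is a hypothetical self-check (illustration only, not part of
// the original source) of the wrap-around argument above: for x = 0 and
// y = 0xffff the subtraction wraps to 0xffff0001, but squaring modulo 2^32
// still yields 0xffff*0xffff = 0xfffe0001, so the shifted result is exact.
func sqDiffWorks() bool {
return sqDiff(0, 0xffff) == (0xffff*0xffff)>>2
}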
func drawPaletted(dst Image, r image.Rectangle, src image.Image, sp image.Point, floydSteinberg bool) {
// TODO(nigeltao): handle the case where the dst and src overlap.
// Does it even make sense to try to do Floyd-Steinberg while
// walking the image backward (right-to-left, bottom-to-top)?
// If dst is an *image.Paletted, we have a fast path for dst.Set and
// dst.At. The dst.Set equivalent is a batch version of the algorithm
// used by color.Palette's Index method in image/color/color.go, plus
// optional Floyd-Steinberg error diffusion.
palette, pix, stride := [][4]int32(nil), []byte(nil), 0
if p, ok := dst.(*image.Paletted); ok {
palette = make([][4]int32, len(p.Palette))
for i, col := range p.Palette {
r, g, b, a := col.RGBA()
palette[i][0] = int32(r)
palette[i][1] = int32(g)
palette[i][2] = int32(b)
palette[i][3] = int32(a)
}
pix, stride = p.Pix[p.PixOffset(r.Min.X, r.Min.Y):], p.Stride
}
// quantErrorCurr and quantErrorNext are the Floyd-Steinberg quantization
// errors that have been propagated to the pixels in the current and next
// rows. The +2 simplifies calculation near the edges.
var quantErrorCurr, quantErrorNext [][4]int32
if floydSteinberg {
quantErrorCurr = make([][4]int32, r.Dx()+2)
quantErrorNext = make([][4]int32, r.Dx()+2)
}
pxRGBA := func(x, y int) (r, g, b, a uint32) { return src.At(x, y).RGBA() }
// Fast paths for special cases to avoid excessive use of the color.Color
// interface, which escapes to the heap and would otherwise need to be
// resolved for each pixel in r. See also https://golang.org/issues/15759.
switch src0 := src.(type) {
case *image.RGBA:
pxRGBA = func(x, y int) (r, g, b, a uint32) { return src0.RGBAAt(x, y).RGBA() }
case *image.NRGBA:
pxRGBA = func(x, y int) (r, g, b, a uint32) { return src0.NRGBAAt(x, y).RGBA() }
case *image.YCbCr:
pxRGBA = func(x, y int) (r, g, b, a uint32) { return src0.YCbCrAt(x, y).RGBA() }
}
// Loop over each source pixel.
out := color.RGBA64{A: 0xffff}
for y := 0; y != r.Dy(); y++ {
for x := 0; x != r.Dx(); x++ {
// er, eg and eb are the pixel's R,G,B values plus the
// optional Floyd-Steinberg error.
sr, sg, sb, sa := pxRGBA(sp.X+x, sp.Y+y)
er, eg, eb, ea := int32(sr), int32(sg), int32(sb), int32(sa)
if floydSteinberg {
er = clamp(er + quantErrorCurr[x+1][0]/16)
eg = clamp(eg + quantErrorCurr[x+1][1]/16)
eb = clamp(eb + quantErrorCurr[x+1][2]/16)
ea = clamp(ea + quantErrorCurr[x+1][3]/16)
}
if palette != nil {
// Find the closest palette color in Euclidean R,G,B,A space:
// the one that minimizes sum-squared-difference.
// TODO(nigeltao): consider smarter algorithms.
bestIndex, bestSum := 0, uint32(1<<32-1)
for index, p := range palette {
sum := sqDiff(er, p[0]) + sqDiff(eg, p[1]) + sqDiff(eb, p[2]) + sqDiff(ea, p[3])
if sum < bestSum {
bestIndex, bestSum = index, sum
if sum == 0 {
break
}
}
}
pix[y*stride+x] = byte(bestIndex)
if !floydSteinberg {
continue
}
er -= palette[bestIndex][0]
eg -= palette[bestIndex][1]
eb -= palette[bestIndex][2]
ea -= palette[bestIndex][3]
} else {
out.R = uint16(er)
out.G = uint16(eg)
out.B = uint16(eb)
out.A = uint16(ea)
// The third argument is &out instead of out (and out is
// declared outside of the inner loop) to avoid the implicit
// conversion to color.Color here allocating memory in the
// inner loop if sizeof(color.RGBA64) > sizeof(uintptr).
dst.Set(r.Min.X+x, r.Min.Y+y, &out)
if !floydSteinberg {
continue
}
sr, sg, sb, sa = dst.At(r.Min.X+x, r.Min.Y+y).RGBA()
er -= int32(sr)
eg -= int32(sg)
eb -= int32(sb)
ea -= int32(sa)
}
// Propagate the Floyd-Steinberg quantization error.
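// The weights below are the classic Floyd-Steinberg kernel, in sixteenths
// (the /16 is applied when the error is consumed):
//
//	    *   7
//	3   5   1
//
// where * is the current pixel. quantErrorCurr and quantErrorNext are
// indexed with a +1 offset, so index x+0 is image column x-1; hence the
// 3/5/1 weights land at x+0, x+1, x+2 on the next row and the 7 at x+2 on
// the current row.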
quantErrorNext[x+0][0] += er * 3
quantErrorNext[x+0][1] += eg * 3
quantErrorNext[x+0][2] += eb * 3
quantErrorNext[x+0][3] += ea * 3
quantErrorNext[x+1][0] += er * 5
quantErrorNext[x+1][1] += eg * 5
quantErrorNext[x+1][2] += eb * 5
quantErrorNext[x+1][3] += ea * 5
quantErrorNext[x+2][0] += er * 1
quantErrorNext[x+2][1] += eg * 1
quantErrorNext[x+2][2] += eb * 1
quantErrorNext[x+2][3] += ea * 1
quantErrorCurr[x+2][0] += er * 7
quantErrorCurr[x+2][1] += eg * 7
quantErrorCurr[x+2][2] += eb * 7
quantErrorCurr[x+2][3] += ea * 7
}
// Recycle the quantization error buffers.
if floydSteinberg {
quantErrorCurr, quantErrorNext = quantErrorNext, quantErrorCurr
for i := range quantErrorNext {
quantErrorNext[i] = [4]int32{}
}
}
}
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package image
import (
"bufio"
"errors"
"io"
"sync"
"sync/atomic"
)
// ErrFormat indicates that decoding encountered an unknown format.
var ErrFormat = errors.New("image: unknown format")
// A format holds an image format's name, magic header and how to decode it.
type format struct {
name, magic string
decode func(io.Reader) (Image, error)
decodeConfig func(io.Reader) (Config, error)
}
// Formats is the list of registered formats.
var (
formatsMu sync.Mutex
atomicFormats atomic.Value
)
// RegisterFormat registers an image format for use by Decode.
// Name is the name of the format, like "jpeg" or "png".
// Magic is the magic prefix that identifies the format's encoding. The magic
// string can contain "?" wildcards that each match any one byte.
// Decode is the function that decodes the encoded image.
// DecodeConfig is the function that decodes just its configuration.
func RegisterFormat(name, magic string, decode func(io.Reader) (Image, error), decodeConfig func(io.Reader) (Config, error)) {
formatsMu.Lock()
formats, _ := atomicFormats.Load().([]format)
atomicFormats.Store(append(formats, format{name, magic, decode, decodeConfig}))
formatsMu.Unlock()
}
// A reader is an io.Reader that can also peek ahead.
type reader interface {
io.Reader
Peek(int) ([]byte, error)
}
// asReader converts an io.Reader to a reader.
func asReader(r io.Reader) reader {
if rr, ok := r.(reader); ok {
return rr
}
return bufio.NewReader(r)
}
// match reports whether magic matches b. Magic may contain "?" wildcards.
func match(magic string, b []byte) bool {
if len(magic) != len(b) {
return false
}
for i, c := range b {
if magic[i] != c && magic[i] != '?' {
return false
}
}
return true
}
// sniff determines the format of r's data.
func sniff(r reader) format {
formats, _ := atomicFormats.Load().([]format)
for _, f := range formats {
b, err := r.Peek(len(f.magic))
if err == nil && match(f.magic, b) {
return f
}
}
return format{}
}
// Decode decodes an image that has been encoded in a registered format.
// The string returned is the format name used during format registration.
// Format registration is typically done by an init function in the codec-
// specific package.
func Decode(r io.Reader) (Image, string, error) {
rr := asReader(r)
f := sniff(rr)
if f.decode == nil {
return nil, "", ErrFormat
}
m, err := f.decode(rr)
return m, f.name, err
}
// DecodeConfig decodes the color model and dimensions of an image that has
// been encoded in a registered format. The string returned is the format name
// used during format registration. Format registration is typically done by
// an init function in the codec-specific package.
func DecodeConfig(r io.Reader) (Config, string, error) {
rr := asReader(r)
f := sniff(rr)
if f.decodeConfig == nil {
return Config{}, "", ErrFormat
}
c, err := f.decodeConfig(rr)
return c, f.name, err
}
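// registerExampleFormat is a hypothetical, illustrative sketch (not part of
// the original source) of how a codec package typically calls RegisterFormat
// from its init function. The "?" in the magic string matches any one byte,
// so the fifth header byte may vary.
func registerExampleFormat() {
decode := func(r io.Reader) (Image, error) { return nil, ErrFormat }
decodeConfig := func(r io.Reader) (Config, error) { return Config{}, ErrFormat }
RegisterFormat("example", "EXMP?a", decode, decodeConfig)
}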
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package image
import (
"image/color"
"math/bits"
"strconv"
)
// A Point is an X, Y coordinate pair. The axes increase right and down.
type Point struct {
X, Y int
}
// String returns a string representation of p like "(3,4)".
func (p Point) String() string {
return "(" + strconv.Itoa(p.X) + "," + strconv.Itoa(p.Y) + ")"
}
// Add returns the vector p+q.
func (p Point) Add(q Point) Point {
return Point{p.X + q.X, p.Y + q.Y}
}
// Sub returns the vector p-q.
func (p Point) Sub(q Point) Point {
return Point{p.X - q.X, p.Y - q.Y}
}
// Mul returns the vector p*k.
func (p Point) Mul(k int) Point {
return Point{p.X * k, p.Y * k}
}
// Div returns the vector p/k.
func (p Point) Div(k int) Point {
return Point{p.X / k, p.Y / k}
}
// In reports whether p is in r.
func (p Point) In(r Rectangle) bool {
return r.Min.X <= p.X && p.X < r.Max.X &&
r.Min.Y <= p.Y && p.Y < r.Max.Y
}
// Mod returns the point q in r such that p.X-q.X is a multiple of r's width
// and p.Y-q.Y is a multiple of r's height.
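// For example, with r = Rect(0, 0, 10, 5), Pt(13, -2).Mod(r) is Pt(3, 3):
// 13 wraps to 3 modulo the width 10, and -2 wraps to 3 modulo the height 5.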
func (p Point) Mod(r Rectangle) Point {
w, h := r.Dx(), r.Dy()
p = p.Sub(r.Min)
p.X = p.X % w
if p.X < 0 {
p.X += w
}
p.Y = p.Y % h
if p.Y < 0 {
p.Y += h
}
return p.Add(r.Min)
}
// Eq reports whether p and q are equal.
func (p Point) Eq(q Point) bool {
return p == q
}
// ZP is the zero Point.
//
// Deprecated: Use a literal image.Point{} instead.
var ZP Point
// Pt is shorthand for Point{X, Y}.
func Pt(X, Y int) Point {
return Point{X, Y}
}
// A Rectangle contains the points with Min.X <= X < Max.X, Min.Y <= Y < Max.Y.
// It is well-formed if Min.X <= Max.X and likewise for Y. Points are always
// well-formed. A rectangle's methods always return well-formed outputs for
// well-formed inputs.
//
// A Rectangle is also an Image whose bounds are the rectangle itself. At
// returns color.Opaque for points in the rectangle and color.Transparent
// otherwise.
type Rectangle struct {
Min, Max Point
}
// String returns a string representation of r like "(3,4)-(6,5)".
func (r Rectangle) String() string {
return r.Min.String() + "-" + r.Max.String()
}
// Dx returns r's width.
func (r Rectangle) Dx() int {
return r.Max.X - r.Min.X
}
// Dy returns r's height.
func (r Rectangle) Dy() int {
return r.Max.Y - r.Min.Y
}
// Size returns r's width and height.
func (r Rectangle) Size() Point {
return Point{
r.Max.X - r.Min.X,
r.Max.Y - r.Min.Y,
}
}
// Add returns the rectangle r translated by p.
func (r Rectangle) Add(p Point) Rectangle {
return Rectangle{
Point{r.Min.X + p.X, r.Min.Y + p.Y},
Point{r.Max.X + p.X, r.Max.Y + p.Y},
}
}
// Sub returns the rectangle r translated by -p.
func (r Rectangle) Sub(p Point) Rectangle {
return Rectangle{
Point{r.Min.X - p.X, r.Min.Y - p.Y},
Point{r.Max.X - p.X, r.Max.Y - p.Y},
}
}
// Inset returns the rectangle r inset by n, which may be negative. If either
// of r's dimensions is less than 2*n then an empty rectangle near the center
// of r will be returned.
func (r Rectangle) Inset(n int) Rectangle {
if r.Dx() < 2*n {
r.Min.X = (r.Min.X + r.Max.X) / 2
r.Max.X = r.Min.X
} else {
r.Min.X += n
r.Max.X -= n
}
if r.Dy() < 2*n {
r.Min.Y = (r.Min.Y + r.Max.Y) / 2
r.Max.Y = r.Min.Y
} else {
r.Min.Y += n
r.Max.Y -= n
}
return r
}
// Intersect returns the largest rectangle contained by both r and s. If the
// two rectangles do not overlap then the zero rectangle will be returned.
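// For example, Rect(0, 0, 10, 10).Intersect(Rect(5, 5, 15, 15)) is
// Rect(5, 5, 10, 10), and the intersection of two disjoint rectangles is
// the zero Rectangle.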
func (r Rectangle) Intersect(s Rectangle) Rectangle {
if r.Min.X < s.Min.X {
r.Min.X = s.Min.X
}
if r.Min.Y < s.Min.Y {
r.Min.Y = s.Min.Y
}
if r.Max.X > s.Max.X {
r.Max.X = s.Max.X
}
if r.Max.Y > s.Max.Y {
r.Max.Y = s.Max.Y
}
// Letting r0 and s0 be the values of r and s at the time that the method
// is called, this next line is equivalent to:
//
// if max(r0.Min.X, s0.Min.X) >= min(r0.Max.X, s0.Max.X) || likewiseForY { etc }
if r.Empty() {
return ZR
}
return r
}
// Union returns the smallest rectangle that contains both r and s.
func (r Rectangle) Union(s Rectangle) Rectangle {
if r.Empty() {
return s
}
if s.Empty() {
return r
}
if r.Min.X > s.Min.X {
r.Min.X = s.Min.X
}
if r.Min.Y > s.Min.Y {
r.Min.Y = s.Min.Y
}
if r.Max.X < s.Max.X {
r.Max.X = s.Max.X
}
if r.Max.Y < s.Max.Y {
r.Max.Y = s.Max.Y
}
return r
}
// Empty reports whether the rectangle contains no points.
func (r Rectangle) Empty() bool {
return r.Min.X >= r.Max.X || r.Min.Y >= r.Max.Y
}
// Eq reports whether r and s contain the same set of points. All empty
// rectangles are considered equal.
func (r Rectangle) Eq(s Rectangle) bool {
return r == s || r.Empty() && s.Empty()
}
// Overlaps reports whether r and s have a non-empty intersection.
func (r Rectangle) Overlaps(s Rectangle) bool {
return !r.Empty() && !s.Empty() &&
r.Min.X < s.Max.X && s.Min.X < r.Max.X &&
r.Min.Y < s.Max.Y && s.Min.Y < r.Max.Y
}
// In reports whether every point in r is in s.
func (r Rectangle) In(s Rectangle) bool {
if r.Empty() {
return true
}
// Note that r.Max is an exclusive bound for r, so that r.In(s)
// does not require that r.Max.In(s).
return s.Min.X <= r.Min.X && r.Max.X <= s.Max.X &&
s.Min.Y <= r.Min.Y && r.Max.Y <= s.Max.Y
}
// Canon returns the canonical version of r. The returned rectangle has minimum
// and maximum coordinates swapped if necessary so that it is well-formed.
func (r Rectangle) Canon() Rectangle {
if r.Max.X < r.Min.X {
r.Min.X, r.Max.X = r.Max.X, r.Min.X
}
if r.Max.Y < r.Min.Y {
r.Min.Y, r.Max.Y = r.Max.Y, r.Min.Y
}
return r
}
// At implements the Image interface.
func (r Rectangle) At(x, y int) color.Color {
if (Point{x, y}).In(r) {
return color.Opaque
}
return color.Transparent
}
// RGBA64At implements the RGBA64Image interface.
func (r Rectangle) RGBA64At(x, y int) color.RGBA64 {
if (Point{x, y}).In(r) {
return color.RGBA64{0xffff, 0xffff, 0xffff, 0xffff}
}
return color.RGBA64{}
}
// Bounds implements the Image interface.
func (r Rectangle) Bounds() Rectangle {
return r
}
// ColorModel implements the Image interface.
func (r Rectangle) ColorModel() color.Model {
return color.Alpha16Model
}
// ZR is the zero Rectangle.
//
// Deprecated: Use a literal image.Rectangle{} instead.
var ZR Rectangle
// Rect is shorthand for Rectangle{Pt(x0, y0), Pt(x1, y1)}. The returned
// rectangle has minimum and maximum coordinates swapped if necessary so that
// it is well-formed.
func Rect(x0, y0, x1, y1 int) Rectangle {
if x0 > x1 {
x0, x1 = x1, x0
}
if y0 > y1 {
y0, y1 = y1, y0
}
return Rectangle{Point{x0, y0}, Point{x1, y1}}
}
// mul3NonNeg returns (x * y * z), unless at least one argument is negative or
// if the computation overflows the int type, in which case it returns -1.
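// For example, mul3NonNeg(4, 640, 480) is 1228800, the Pix length of a
// 640x480 RGBA image, while mul3NonNeg(4, 1<<31, 1<<31) is -1 on 64-bit
// platforms because the product overflows int.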
func mul3NonNeg(x int, y int, z int) int {
if (x < 0) || (y < 0) || (z < 0) {
return -1
}
hi, lo := bits.Mul64(uint64(x), uint64(y))
if hi != 0 {
return -1
}
hi, lo = bits.Mul64(lo, uint64(z))
if hi != 0 {
return -1
}
a := int(lo)
if (a < 0) || (uint64(a) != lo) {
return -1
}
return a
}
// add2NonNeg returns (x + y), unless at least one argument is negative or if
// the computation overflows the int type, in which case it returns -1.
func add2NonNeg(x int, y int) int {
if (x < 0) || (y < 0) {
return -1
}
a := x + y
if a < 0 {
return -1
}
return a
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package gif implements a GIF image decoder and encoder.
//
// The GIF specification is at https://www.w3.org/Graphics/GIF/spec-gif89a.txt.
package gif
import (
"bufio"
"compress/lzw"
"errors"
"fmt"
"image"
"image/color"
"io"
)
var (
errNotEnough = errors.New("gif: not enough image data")
errTooMuch = errors.New("gif: too much image data")
errBadPixel = errors.New("gif: invalid pixel value")
)
// If the io.Reader does not also have ReadByte, then decode will introduce its own buffering.
type reader interface {
io.Reader
io.ByteReader
}
// Masks etc.
const (
// Fields.
fColorTable = 1 << 7
fInterlace = 1 << 6
fColorTableBitsMask = 7
// Graphic control flags.
gcTransparentColorSet = 1 << 0
gcDisposalMethodMask = 7 << 2
)
// Disposal Methods.
const (
DisposalNone = 0x01
DisposalBackground = 0x02
DisposalPrevious = 0x03
)
// Section indicators.
const (
sExtension = 0x21
sImageDescriptor = 0x2C
sTrailer = 0x3B
)
// Extensions.
const (
eText = 0x01 // Plain Text
eGraphicControl = 0xF9 // Graphic Control
eComment = 0xFE // Comment
eApplication = 0xFF // Application
)
func readFull(r io.Reader, b []byte) error {
_, err := io.ReadFull(r, b)
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return err
}
func readByte(r io.ByteReader) (byte, error) {
b, err := r.ReadByte()
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return b, err
}
// decoder is the type used to decode a GIF file.
type decoder struct {
r reader
// From header.
vers string
width int
height int
loopCount int
delayTime int
backgroundIndex byte
disposalMethod byte
// From image descriptor.
imageFields byte
// From graphics control.
transparentIndex byte
hasTransparentIndex bool
// Computed.
globalColorTable color.Palette
// Used when decoding.
delay []int
disposal []byte
image []*image.Paletted
tmp [1024]byte // must be at least 768 so we can read color table
}
// blockReader parses the block structure of GIF image data, which comprises
// (n, (n bytes)) blocks, with 1 <= n <= 255. It is the reader given to the
// LZW decoder, which is thus immune to the blocking. After the LZW decoder
// completes, there will be a 0-byte block remaining (0, ()), which is
// consumed when checking that the blockReader is exhausted.
//
// To avoid the allocation of a bufio.Reader for the lzw Reader, blockReader
// implements io.ByteReader and buffers blocks into the decoder's "tmp" buffer.
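// For example (illustrative bytes), the stream
//
//	05 AA BB CC DD EE   // 5-byte data sub-block
//	02 FF 11            // 2-byte data sub-block
//	00                  // block terminator
//
// yields AA BB CC DD EE FF 11 from successive ReadByte calls, after which
// fill sees the zero-length block and reports io.EOF.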
type blockReader struct {
d *decoder
i, j uint8 // d.tmp[i:j] contains the buffered bytes
err error
}
func (b *blockReader) fill() {
if b.err != nil {
return
}
b.j, b.err = readByte(b.d.r)
if b.j == 0 && b.err == nil {
b.err = io.EOF
}
if b.err != nil {
return
}
b.i = 0
b.err = readFull(b.d.r, b.d.tmp[:b.j])
if b.err != nil {
b.j = 0
}
}
func (b *blockReader) ReadByte() (byte, error) {
if b.i == b.j {
b.fill()
if b.err != nil {
return 0, b.err
}
}
c := b.d.tmp[b.i]
b.i++
return c, nil
}
// blockReader must implement io.Reader, but its Read shouldn't ever actually
// be called in practice. The compress/lzw package will only call ReadByte.
func (b *blockReader) Read(p []byte) (int, error) {
if len(p) == 0 || b.err != nil {
return 0, b.err
}
if b.i == b.j {
b.fill()
if b.err != nil {
return 0, b.err
}
}
n := copy(p, b.d.tmp[b.i:b.j])
b.i += uint8(n)
return n, nil
}
// close primarily detects whether a block terminator was encountered after
// reading a sequence of data sub-blocks. It allows at most one trailing
// sub-block worth of data. I.e., if some number of bytes exist in one sub-block
// following the end of LZW data, the very next sub-block must be the block
// terminator. If the very end of LZW data happened to fill one sub-block, at
// most one more sub-block of length 1 may exist before the block terminator.
// These accommodations allow us to support GIFs created by less strict encoders.
// See https://golang.org/issue/16146.
func (b *blockReader) close() error {
if b.err == io.EOF {
// A clean block-sequence terminator was encountered while reading.
return nil
} else if b.err != nil {
// Some other error was encountered while reading.
return b.err
}
if b.i == b.j {
// We reached the end of a sub block reading LZW data. We'll allow at
// most one more sub block of data with a length of 1 byte.
b.fill()
if b.err == io.EOF {
return nil
} else if b.err != nil {
return b.err
} else if b.j > 1 {
return errTooMuch
}
}
// Part of a sub-block remains buffered. We expect that the next attempt to
// buffer a sub-block will reach the block terminator.
b.fill()
if b.err == io.EOF {
return nil
} else if b.err != nil {
return b.err
}
return errTooMuch
}
// decode reads a GIF image from r and stores the result in d.
func (d *decoder) decode(r io.Reader, configOnly, keepAllFrames bool) error {
// Add buffering if r does not provide ReadByte.
if rr, ok := r.(reader); ok {
d.r = rr
} else {
d.r = bufio.NewReader(r)
}
d.loopCount = -1
err := d.readHeaderAndScreenDescriptor()
if err != nil {
return err
}
if configOnly {
return nil
}
for {
c, err := readByte(d.r)
if err != nil {
return fmt.Errorf("gif: reading frames: %v", err)
}
switch c {
case sExtension:
if err = d.readExtension(); err != nil {
return err
}
case sImageDescriptor:
if err = d.readImageDescriptor(keepAllFrames); err != nil {
return err
}
if !keepAllFrames && len(d.image) == 1 {
return nil
}
case sTrailer:
if len(d.image) == 0 {
return fmt.Errorf("gif: missing image data")
}
return nil
default:
return fmt.Errorf("gif: unknown block type: 0x%.2x", c)
}
}
}
func (d *decoder) readHeaderAndScreenDescriptor() error {
err := readFull(d.r, d.tmp[:13])
if err != nil {
return fmt.Errorf("gif: reading header: %v", err)
}
d.vers = string(d.tmp[:6])
if d.vers != "GIF87a" && d.vers != "GIF89a" {
return fmt.Errorf("gif: can't recognize format %q", d.vers)
}
d.width = int(d.tmp[6]) + int(d.tmp[7])<<8
d.height = int(d.tmp[8]) + int(d.tmp[9])<<8
if fields := d.tmp[10]; fields&fColorTable != 0 {
d.backgroundIndex = d.tmp[11]
// readColorTable overwrites the contents of d.tmp, but that's OK.
if d.globalColorTable, err = d.readColorTable(fields); err != nil {
return err
}
}
// d.tmp[12] is the Pixel Aspect Ratio, which is ignored.
return nil
}
func (d *decoder) readColorTable(fields byte) (color.Palette, error) {
n := 1 << (1 + uint(fields&fColorTableBitsMask))
err := readFull(d.r, d.tmp[:3*n])
if err != nil {
return nil, fmt.Errorf("gif: reading color table: %s", err)
}
j, p := 0, make(color.Palette, n)
for i := range p {
p[i] = color.RGBA{d.tmp[j+0], d.tmp[j+1], d.tmp[j+2], 0xFF}
j += 3
}
return p, nil
}
func (d *decoder) readExtension() error {
extension, err := readByte(d.r)
if err != nil {
return fmt.Errorf("gif: reading extension: %v", err)
}
size := 0
switch extension {
case eText:
size = 13
case eGraphicControl:
return d.readGraphicControl()
case eComment:
// nothing to do but read the data.
case eApplication:
b, err := readByte(d.r)
if err != nil {
return fmt.Errorf("gif: reading extension: %v", err)
}
// The spec requires size be 11, but Adobe sometimes uses 10.
size = int(b)
default:
return fmt.Errorf("gif: unknown extension 0x%.2x", extension)
}
if size > 0 {
if err := readFull(d.r, d.tmp[:size]); err != nil {
return fmt.Errorf("gif: reading extension: %v", err)
}
}
// Application Extension with "NETSCAPE2.0" as string and 1 in data means
// this extension defines a loop count.
if extension == eApplication && string(d.tmp[:size]) == "NETSCAPE2.0" {
n, err := d.readBlock()
if err != nil {
return fmt.Errorf("gif: reading extension: %v", err)
}
if n == 0 {
return nil
}
if n == 3 && d.tmp[0] == 1 {
d.loopCount = int(d.tmp[1]) | int(d.tmp[2])<<8
}
}
for {
n, err := d.readBlock()
if err != nil {
return fmt.Errorf("gif: reading extension: %v", err)
}
if n == 0 {
return nil
}
}
}
func (d *decoder) readGraphicControl() error {
if err := readFull(d.r, d.tmp[:6]); err != nil {
return fmt.Errorf("gif: can't read graphic control: %s", err)
}
if d.tmp[0] != 4 {
return fmt.Errorf("gif: invalid graphic control extension block size: %d", d.tmp[0])
}
flags := d.tmp[1]
d.disposalMethod = (flags & gcDisposalMethodMask) >> 2
d.delayTime = int(d.tmp[2]) | int(d.tmp[3])<<8
if flags&gcTransparentColorSet != 0 {
d.transparentIndex = d.tmp[4]
d.hasTransparentIndex = true
}
if d.tmp[5] != 0 {
return fmt.Errorf("gif: invalid graphic control extension block terminator: %d", d.tmp[5])
}
return nil
}
func (d *decoder) readImageDescriptor(keepAllFrames bool) error {
m, err := d.newImageFromDescriptor()
if err != nil {
return err
}
useLocalColorTable := d.imageFields&fColorTable != 0
if useLocalColorTable {
m.Palette, err = d.readColorTable(d.imageFields)
if err != nil {
return err
}
} else {
if d.globalColorTable == nil {
return errors.New("gif: no color table")
}
m.Palette = d.globalColorTable
}
if d.hasTransparentIndex {
if !useLocalColorTable {
// Clone the global color table.
m.Palette = append(color.Palette(nil), d.globalColorTable...)
}
if ti := int(d.transparentIndex); ti < len(m.Palette) {
m.Palette[ti] = color.RGBA{}
} else {
// The transparentIndex is out of range, which is an error
// according to the spec, but Firefox and Google Chrome
// seem OK with this, so we enlarge the palette with
// transparent colors. See golang.org/issue/15059.
p := make(color.Palette, ti+1)
copy(p, m.Palette)
for i := len(m.Palette); i < len(p); i++ {
p[i] = color.RGBA{}
}
m.Palette = p
}
}
litWidth, err := readByte(d.r)
if err != nil {
return fmt.Errorf("gif: reading image data: %v", err)
}
if litWidth < 2 || litWidth > 8 {
return fmt.Errorf("gif: pixel size in decode out of range: %d", litWidth)
}
// The blockReader transparently strips the GIF sub-block framing, so the
// LZW reader below sees one contiguous byte stream.
br := &blockReader{d: d}
lzwr := lzw.NewReader(br, lzw.LSB, int(litWidth))
defer lzwr.Close()
if err = readFull(lzwr, m.Pix); err != nil {
if err != io.ErrUnexpectedEOF {
return fmt.Errorf("gif: reading image data: %v", err)
}
return errNotEnough
}
// In theory, both lzwr and br should be exhausted. Reading from them
// should yield (0, io.EOF).
//
// The spec (Appendix F - Compression), says that "An End of
// Information code... must be the last code output by the encoder
// for an image". In practice, though, giflib (a widely used C
// library) does not enforce this, so we also accept lzwr returning
// io.ErrUnexpectedEOF (meaning that the encoded stream hit io.EOF
// before the LZW decoder saw an explicit end code), provided that
// the io.ReadFull call above successfully read len(m.Pix) bytes.
// See https://golang.org/issue/9856 for an example GIF.
if n, err := lzwr.Read(d.tmp[256:257]); n != 0 || (err != io.EOF && err != io.ErrUnexpectedEOF) {
if err != nil {
return fmt.Errorf("gif: reading image data: %v", err)
}
return errTooMuch
}
// In practice, some GIFs have an extra byte in the data sub-block
// stream, which we ignore. See https://golang.org/issue/16146.
if err := br.close(); err == errTooMuch {
return errTooMuch
} else if err != nil {
return fmt.Errorf("gif: reading image data: %v", err)
}
// Check that the color indexes are inside the palette.
if len(m.Palette) < 256 {
for _, pixel := range m.Pix {
if int(pixel) >= len(m.Palette) {
return errBadPixel
}
}
}
// Undo the interlacing if necessary.
if d.imageFields&fInterlace != 0 {
uninterlace(m)
}
if keepAllFrames || len(d.image) == 0 {
d.image = append(d.image, m)
d.delay = append(d.delay, d.delayTime)
d.disposal = append(d.disposal, d.disposalMethod)
}
// The GIF89a spec, Section 23 (Graphic Control Extension) says:
// "The scope of this extension is the first graphic rendering block
// to follow." We therefore reset the GCE fields to zero.
d.delayTime = 0
d.hasTransparentIndex = false
return nil
}
func (d *decoder) newImageFromDescriptor() (*image.Paletted, error) {
if err := readFull(d.r, d.tmp[:9]); err != nil {
return nil, fmt.Errorf("gif: can't read image descriptor: %s", err)
}
left := int(d.tmp[0]) + int(d.tmp[1])<<8
top := int(d.tmp[2]) + int(d.tmp[3])<<8
width := int(d.tmp[4]) + int(d.tmp[5])<<8
height := int(d.tmp[6]) + int(d.tmp[7])<<8
d.imageFields = d.tmp[8]
// The GIF89a spec, Section 20 (Image Descriptor) says: "Each image must
// fit within the boundaries of the Logical Screen, as defined in the
// Logical Screen Descriptor."
//
// This is conceptually similar to testing
// frameBounds := image.Rect(left, top, left+width, top+height)
// imageBounds := image.Rect(0, 0, d.width, d.height)
// if !frameBounds.In(imageBounds) { etc }
// but the semantics of the Go image.Rectangle type is that r.In(s) is true
// whenever r is an empty rectangle, even if r.Min.X > s.Max.X. Here, we
// want something stricter.
//
// Note that, by construction, left >= 0 && top >= 0, so we only have to
// explicitly compare frameBounds.Max (left+width, top+height) against
// imageBounds.Max (d.width, d.height) and not frameBounds.Min (left, top)
// against imageBounds.Min (0, 0).
if left+width > d.width || top+height > d.height {
return nil, errors.New("gif: frame bounds larger than image bounds")
}
return image.NewPaletted(image.Rectangle{
Min: image.Point{left, top},
Max: image.Point{left + width, top + height},
}, nil), nil
}
func (d *decoder) readBlock() (int, error) {
n, err := readByte(d.r)
if n == 0 || err != nil {
return 0, err
}
if err := readFull(d.r, d.tmp[:n]); err != nil {
return 0, err
}
return int(n), nil
}
// interlaceScan defines the ordering for a pass of the interlace algorithm.
type interlaceScan struct {
skip, start int
}
// interlacing represents the set of scans in an interlaced GIF image.
var interlacing = []interlaceScan{
{8, 0}, // Group 1 : Every 8th. row, starting with row 0.
{8, 4}, // Group 2 : Every 8th. row, starting with row 4.
{4, 2}, // Group 3 : Every 4th. row, starting with row 2.
{2, 1}, // Group 4 : Every 2nd. row, starting with row 1.
}
// uninterlace rearranges the pixels in m to account for interlaced input.
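// For example, an 8-row interlaced image stores its rows in input order
// 0, 4, 2, 6, 1, 3, 5, 7; uninterlace copies each sequential input row back
// to its natural position.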
func uninterlace(m *image.Paletted) {
var nPix []uint8
dx := m.Bounds().Dx()
dy := m.Bounds().Dy()
nPix = make([]uint8, dx*dy)
offset := 0 // steps through the input by sequential scan lines.
for _, pass := range interlacing {
nOffset := pass.start * dx // steps through the output as defined by pass.
for y := pass.start; y < dy; y += pass.skip {
copy(nPix[nOffset:nOffset+dx], m.Pix[offset:offset+dx])
offset += dx
nOffset += dx * pass.skip
}
}
m.Pix = nPix
}
// Decode reads a GIF image from r and returns the first embedded
// image as an image.Image.
func Decode(r io.Reader) (image.Image, error) {
var d decoder
if err := d.decode(r, false, false); err != nil {
return nil, err
}
return d.image[0], nil
}
// GIF represents the possibly multiple images stored in a GIF file.
type GIF struct {
Image []*image.Paletted // The successive images.
Delay []int // The successive delay times, one per frame, in 100ths of a second.
// LoopCount controls the number of times an animation will be
// restarted during display.
// A LoopCount of 0 means to loop forever.
// A LoopCount of -1 means to show each frame only once.
// Otherwise, the animation is looped LoopCount+1 times.
LoopCount int
// Disposal is the successive disposal methods, one per frame. For
// backwards compatibility, a nil Disposal is valid to pass to EncodeAll,
// and implies that each frame's disposal method is 0 (no disposal
// specified).
Disposal []byte
// Config is the global color table (palette), width and height. A nil or
// empty-color.Palette Config.ColorModel means that each frame has its own
// color table and there is no global color table. Each frame's bounds must
// be within the rectangle defined by the two points (0, 0) and
// (Config.Width, Config.Height).
//
// For backwards compatibility, a zero-valued Config is valid to pass to
// EncodeAll, and implies that the overall GIF's width and height equals
// the first frame's bounds' Rectangle.Max point.
Config image.Config
// BackgroundIndex is the background index in the global color table, for
// use with the DisposalBackground disposal method.
BackgroundIndex byte
}
// DecodeAll reads a GIF image from r and returns the sequential frames
// and timing information.
func DecodeAll(r io.Reader) (*GIF, error) {
var d decoder
if err := d.decode(r, false, true); err != nil {
return nil, err
}
gif := &GIF{
Image: d.image,
LoopCount: d.loopCount,
Delay: d.delay,
Disposal: d.disposal,
Config: image.Config{
ColorModel: d.globalColorTable,
Width: d.width,
Height: d.height,
},
BackgroundIndex: d.backgroundIndex,
}
return gif, nil
}
// DecodeConfig returns the global color model and dimensions of a GIF image
// without decoding the entire image.
func DecodeConfig(r io.Reader) (image.Config, error) {
var d decoder
if err := d.decode(r, true, false); err != nil {
return image.Config{}, err
}
return image.Config{
ColorModel: d.globalColorTable,
Width: d.width,
Height: d.height,
}, nil
}
func init() {
image.RegisterFormat("gif", "GIF8?a", Decode, DecodeConfig)
}
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gif
import (
"bufio"
"bytes"
"compress/lzw"
"errors"
"image"
"image/color"
"image/color/palette"
"image/draw"
"io"
)
// Graphic control extension fields.
const (
gcLabel = 0xF9
gcBlockSize = 0x04
)
var log2Lookup = [8]int{2, 4, 8, 16, 32, 64, 128, 256}
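// log2 returns the smallest i such that x <= 2^(i+1), or -1 if x > 256.
// GIF color tables store their size as this exponent: a field value of n
// means the table has 2^(n+1) entries.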
func log2(x int) int {
for i, v := range log2Lookup {
if x <= v {
return i
}
}
return -1
}
// Little-endian.
func writeUint16(b []uint8, u uint16) {
b[0] = uint8(u)
b[1] = uint8(u >> 8)
}
// writer is a buffered writer.
type writer interface {
Flush() error
io.Writer
io.ByteWriter
}
// encoder encodes an image to the GIF format.
type encoder struct {
// w is the writer to write to. err is the first error encountered during
// writing. All attempted writes after the first error become no-ops.
w writer
err error
// g is a reference to the data that is being encoded.
g GIF
// globalCT is the size in bytes of the global color table.
globalCT int
// buf is a scratch buffer. It must be at least 256 for the blockWriter.
buf [256]byte
globalColorTable [3 * 256]byte
localColorTable [3 * 256]byte
}
// blockWriter writes the block structure of GIF image data, which
// comprises (n, (n bytes)) blocks, with 1 <= n <= 255. It is the
// writer given to the LZW encoder, which is thus immune to the
// blocking.
type blockWriter struct {
e *encoder
}
func (b blockWriter) setup() {
b.e.buf[0] = 0
}
func (b blockWriter) Flush() error {
return b.e.err
}
func (b blockWriter) WriteByte(c byte) error {
if b.e.err != nil {
return b.e.err
}
// Append c to buffered sub-block.
b.e.buf[0]++
b.e.buf[b.e.buf[0]] = c
if b.e.buf[0] < 255 {
return nil
}
// Flush block
b.e.write(b.e.buf[:256])
b.e.buf[0] = 0
return b.e.err
}
// blockWriter must be an io.Writer to be passed to lzw.NewWriter, but its
// Write method is never actually called in practice: compress/lzw writes
// through WriteByte when, as here, the destination implements io.ByteWriter
// and Flush.
func (b blockWriter) Write(data []byte) (int, error) {
for i, c := range data {
if err := b.WriteByte(c); err != nil {
return i, err
}
}
return len(data), nil
}
func (b blockWriter) close() {
// Write the block terminator (0x00), either by itself, or along with a
// pending sub-block.
if b.e.buf[0] == 0 {
b.e.writeByte(0)
} else {
n := uint(b.e.buf[0])
b.e.buf[n+1] = 0
b.e.write(b.e.buf[:n+2])
}
b.e.flush()
}
func (e *encoder) flush() {
if e.err != nil {
return
}
e.err = e.w.Flush()
}
func (e *encoder) write(p []byte) {
if e.err != nil {
return
}
_, e.err = e.w.Write(p)
}
func (e *encoder) writeByte(b byte) {
if e.err != nil {
return
}
e.err = e.w.WriteByte(b)
}
func (e *encoder) writeHeader() {
if e.err != nil {
return
}
_, e.err = io.WriteString(e.w, "GIF89a")
if e.err != nil {
return
}
// Logical screen width and height.
writeUint16(e.buf[0:2], uint16(e.g.Config.Width))
writeUint16(e.buf[2:4], uint16(e.g.Config.Height))
e.write(e.buf[:4])
if p, ok := e.g.Config.ColorModel.(color.Palette); ok && len(p) > 0 {
paddedSize := log2(len(p)) // Size of Global Color Table: 2^(1+n).
e.buf[0] = fColorTable | uint8(paddedSize)
e.buf[1] = e.g.BackgroundIndex
e.buf[2] = 0x00 // Pixel Aspect Ratio.
e.write(e.buf[:3])
var err error
e.globalCT, err = encodeColorTable(e.globalColorTable[:], p, paddedSize)
if err != nil && e.err == nil {
e.err = err
return
}
e.write(e.globalColorTable[:e.globalCT])
} else {
// All frames have a local color table, so a global color table
// is not needed.
e.buf[0] = 0x00
e.buf[1] = 0x00 // Background Color Index.
e.buf[2] = 0x00 // Pixel Aspect Ratio.
e.write(e.buf[:3])
}
// Add animation info if necessary.
if len(e.g.Image) > 1 && e.g.LoopCount >= 0 {
e.buf[0] = 0x21 // Extension Introducer.
e.buf[1] = 0xff // Application Label.
e.buf[2] = 0x0b // Block Size.
e.write(e.buf[:3])
_, err := io.WriteString(e.w, "NETSCAPE2.0") // Application Identifier.
if err != nil && e.err == nil {
e.err = err
return
}
e.buf[0] = 0x03 // Block Size.
e.buf[1] = 0x01 // Sub-block Index.
writeUint16(e.buf[2:4], uint16(e.g.LoopCount))
e.buf[4] = 0x00 // Block Terminator.
e.write(e.buf[:5])
}
}
func encodeColorTable(dst []byte, p color.Palette, size int) (int, error) {
if uint(size) >= uint(len(log2Lookup)) {
return 0, errors.New("gif: cannot encode color table with more than 256 entries")
}
for i, c := range p {
if c == nil {
return 0, errors.New("gif: cannot encode color table with nil entries")
}
var r, g, b uint8
// It is most likely that the palette is full of color.RGBAs, so they
// get a fast path.
if rgba, ok := c.(color.RGBA); ok {
r, g, b = rgba.R, rgba.G, rgba.B
} else {
rr, gg, bb, _ := c.RGBA()
r, g, b = uint8(rr>>8), uint8(gg>>8), uint8(bb>>8)
}
dst[3*i+0] = r
dst[3*i+1] = g
dst[3*i+2] = b
}
n := log2Lookup[size]
if n > len(p) {
// Pad with black.
fill := dst[3*len(p) : 3*n]
for i := range fill {
fill[i] = 0
}
}
return 3 * n, nil
}
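// colorTablesMatch reports whether the local color table for the current
// frame matches the global color table, skipping the entry at
// transparentIndex (if non-negative), which DecodeAll may have rewritten to
// transparent.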
func (e *encoder) colorTablesMatch(localLen, transparentIndex int) bool {
localSize := 3 * localLen
if transparentIndex >= 0 {
trOff := 3 * transparentIndex
return bytes.Equal(e.globalColorTable[:trOff], e.localColorTable[:trOff]) &&
bytes.Equal(e.globalColorTable[trOff+3:localSize], e.localColorTable[trOff+3:localSize])
}
return bytes.Equal(e.globalColorTable[:localSize], e.localColorTable[:localSize])
}
func (e *encoder) writeImageBlock(pm *image.Paletted, delay int, disposal byte) {
if e.err != nil {
return
}
if len(pm.Palette) == 0 {
e.err = errors.New("gif: cannot encode image block with empty palette")
return
}
b := pm.Bounds()
if b.Min.X < 0 || b.Max.X >= 1<<16 || b.Min.Y < 0 || b.Max.Y >= 1<<16 {
e.err = errors.New("gif: image block is too large to encode")
return
}
if !b.In(image.Rectangle{Max: image.Point{e.g.Config.Width, e.g.Config.Height}}) {
e.err = errors.New("gif: image block is out of bounds")
return
}
transparentIndex := -1
for i, c := range pm.Palette {
if c == nil {
e.err = errors.New("gif: cannot encode color table with nil entries")
return
}
if _, _, _, a := c.RGBA(); a == 0 {
transparentIndex = i
break
}
}
if delay > 0 || disposal != 0 || transparentIndex != -1 {
e.buf[0] = sExtension // Extension Introducer.
e.buf[1] = gcLabel // Graphic Control Label.
e.buf[2] = gcBlockSize // Block Size.
if transparentIndex != -1 {
e.buf[3] = 0x01 | disposal<<2
} else {
e.buf[3] = 0x00 | disposal<<2
}
writeUint16(e.buf[4:6], uint16(delay)) // Delay Time (1/100ths of a second)
// Transparent color index.
if transparentIndex != -1 {
e.buf[6] = uint8(transparentIndex)
} else {
e.buf[6] = 0x00
}
e.buf[7] = 0x00 // Block Terminator.
e.write(e.buf[:8])
}
e.buf[0] = sImageDescriptor
writeUint16(e.buf[1:3], uint16(b.Min.X))
writeUint16(e.buf[3:5], uint16(b.Min.Y))
writeUint16(e.buf[5:7], uint16(b.Dx()))
writeUint16(e.buf[7:9], uint16(b.Dy()))
e.write(e.buf[:9])
// To determine whether or not this frame's palette is the same as the
// global palette, we can check a couple things. First, do they actually
// point to the same []color.Color? If so, they are equal so long as the
// frame's palette is not longer than the global palette...
paddedSize := log2(len(pm.Palette)) // Size of Local Color Table: 2^(1+n).
if gp, ok := e.g.Config.ColorModel.(color.Palette); ok && len(pm.Palette) <= len(gp) && &gp[0] == &pm.Palette[0] {
e.writeByte(0) // Use the global color table.
} else {
ct, err := encodeColorTable(e.localColorTable[:], pm.Palette, paddedSize)
if err != nil {
if e.err == nil {
e.err = err
}
return
}
// This frame's palette is not the very same slice as the global
// palette, but it might be a copy, possibly with one value turned into
// transparency by DecodeAll.
if ct <= e.globalCT && e.colorTablesMatch(len(pm.Palette), transparentIndex) {
e.writeByte(0) // Use the global color table.
} else {
// Use a local color table.
e.writeByte(fColorTable | uint8(paddedSize))
e.write(e.localColorTable[:ct])
}
}
litWidth := paddedSize + 1
if litWidth < 2 {
litWidth = 2
}
e.writeByte(uint8(litWidth)) // LZW Minimum Code Size.
bw := blockWriter{e: e}
bw.setup()
lzww := lzw.NewWriter(bw, lzw.LSB, litWidth)
if dx := b.Dx(); dx == pm.Stride {
_, e.err = lzww.Write(pm.Pix[:dx*b.Dy()])
if e.err != nil {
lzww.Close()
return
}
} else {
for i, y := 0, b.Min.Y; y < b.Max.Y; i, y = i+pm.Stride, y+1 {
_, e.err = lzww.Write(pm.Pix[i : i+dx])
if e.err != nil {
lzww.Close()
return
}
}
}
lzww.Close() // flush to bw
bw.close() // flush to e.w
}
// Options are the encoding parameters.
type Options struct {
// NumColors is the maximum number of colors used in the image.
// It ranges from 1 to 256.
NumColors int
// Quantizer is used to produce a palette with size NumColors.
// palette.Plan9 is used in place of a nil Quantizer.
Quantizer draw.Quantizer
// Drawer is used to convert the source image to the desired palette.
// draw.FloydSteinberg is used in place of a nil Drawer.
Drawer draw.Drawer
}
// EncodeAll writes the images in g to w in GIF format with the
// given loop count and delay between frames.
func EncodeAll(w io.Writer, g *GIF) error {
if len(g.Image) == 0 {
return errors.New("gif: must provide at least one image")
}
if len(g.Image) != len(g.Delay) {
return errors.New("gif: mismatched image and delay lengths")
}
e := encoder{g: *g}
// The GIF.Disposal, GIF.Config and GIF.BackgroundIndex fields were added
// in Go 1.5. Valid Go 1.4 code, such as when the Disposal field is omitted
// in a GIF struct literal, should still produce valid GIFs.
if e.g.Disposal != nil && len(e.g.Image) != len(e.g.Disposal) {
return errors.New("gif: mismatched image and disposal lengths")
}
if e.g.Config == (image.Config{}) {
p := g.Image[0].Bounds().Max
e.g.Config.Width = p.X
e.g.Config.Height = p.Y
} else if e.g.Config.ColorModel != nil {
if _, ok := e.g.Config.ColorModel.(color.Palette); !ok {
return errors.New("gif: GIF color model must be a color.Palette")
}
}
if ww, ok := w.(writer); ok {
e.w = ww
} else {
e.w = bufio.NewWriter(w)
}
e.writeHeader()
for i, pm := range g.Image {
disposal := uint8(0)
if g.Disposal != nil {
disposal = g.Disposal[i]
}
e.writeImageBlock(pm, g.Delay[i], disposal)
}
e.writeByte(sTrailer)
e.flush()
return e.err
}
// Encode writes the Image m to w in GIF format.
func Encode(w io.Writer, m image.Image, o *Options) error {
// Check for bounds and size restrictions.
b := m.Bounds()
if b.Dx() >= 1<<16 || b.Dy() >= 1<<16 {
return errors.New("gif: image is too large to encode")
}
opts := Options{}
if o != nil {
opts = *o
}
if opts.NumColors < 1 || 256 < opts.NumColors {
opts.NumColors = 256
}
if opts.Drawer == nil {
opts.Drawer = draw.FloydSteinberg
}
pm, _ := m.(*image.Paletted)
if pm == nil {
if cp, ok := m.ColorModel().(color.Palette); ok {
pm = image.NewPaletted(b, cp)
for y := b.Min.Y; y < b.Max.Y; y++ {
for x := b.Min.X; x < b.Max.X; x++ {
pm.Set(x, y, cp.Convert(m.At(x, y)))
}
}
}
}
if pm == nil || len(pm.Palette) > opts.NumColors {
// Set pm to be a paletted copy of m, including its bounds, which
// might not start at (0, 0).
//
// TODO: Pick a better sub-sample of the Plan 9 palette.
pm = image.NewPaletted(b, palette.Plan9[:opts.NumColors])
if opts.Quantizer != nil {
pm.Palette = opts.Quantizer.Quantize(make(color.Palette, 0, opts.NumColors), m)
}
opts.Drawer.Draw(pm, b, m, b.Min)
}
// When calling Encode instead of EncodeAll, the single-frame image is
// translated such that its top-left corner is (0, 0), so that the single
// frame completely fills the overall GIF's bounds.
if pm.Rect.Min != (image.Point{}) {
dup := *pm
dup.Rect = dup.Rect.Sub(dup.Rect.Min)
pm = &dup
}
return EncodeAll(w, &GIF{
Image: []*image.Paletted{pm},
Delay: []int{0},
Config: image.Config{
ColorModel: pm.Palette,
Width: b.Dx(),
Height: b.Dy(),
},
})
}
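// encodeExample is a hypothetical, illustrative helper (not part of the
// original source) showing a typical single-frame encode: a reduced 16-color
// palette (a Plan 9 sub-sample by default) dithered by the default
// draw.FloydSteinberg drawer that Encode falls back to.
func encodeExample(w io.Writer, m image.Image) error {
return Encode(w, m, &Options{NumColors: 16})
}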
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package image implements a basic 2-D image library.
//
// The fundamental interface is called Image. An Image contains colors, which
// are described in the image/color package.
//
// Values of the Image interface are created either by calling functions such
// as NewRGBA and NewPaletted, or by calling Decode on an io.Reader containing
// image data in a format such as GIF, JPEG or PNG. Decoding any particular
// image format requires the prior registration of a decoder function.
// Registration is typically automatic as a side effect of initializing that
// format's package so that, to decode a PNG image, it suffices to have
//
// import _ "image/png"
//
// in a program's main package. The _ means to import a package purely for its
// initialization side effects.
//
// See "The Go image package" for more details:
// https://golang.org/doc/articles/image_package.html
package image
import (
"image/color"
)
// Config holds an image's color model and dimensions.
type Config struct {
ColorModel color.Model
Width, Height int
}
// Image is a finite rectangular grid of color.Color values taken from a color
// model.
type Image interface {
// ColorModel returns the Image's color model.
ColorModel() color.Model
// Bounds returns the domain for which At can return non-zero color.
// The bounds do not necessarily contain the point (0, 0).
Bounds() Rectangle
// At returns the color of the pixel at (x, y).
// At(Bounds().Min.X, Bounds().Min.Y) returns the upper-left pixel of the grid.
// At(Bounds().Max.X-1, Bounds().Max.Y-1) returns the lower-right one.
At(x, y int) color.Color
}
// RGBA64Image is an Image whose pixels can be converted directly to a
// color.RGBA64.
type RGBA64Image interface {
// RGBA64At returns the RGBA64 color of the pixel at (x, y). It is
// equivalent to calling At(x, y).RGBA() and converting the resulting
// 32-bit return values to a color.RGBA64, but it can avoid allocations
// from converting concrete color types to the color.Color interface type.
RGBA64At(x, y int) color.RGBA64
Image
}
// PalettedImage is an image whose colors may come from a limited palette.
// If m is a PalettedImage and m.ColorModel() returns a color.Palette p,
// then m.At(x, y) should be equivalent to p[m.ColorIndexAt(x, y)]. If m's
// color model is not a color.Palette, then ColorIndexAt's behavior is
// undefined.
type PalettedImage interface {
// ColorIndexAt returns the palette index of the pixel at (x, y).
ColorIndexAt(x, y int) uint8
Image
}
// pixelBufferLength returns the length of the []uint8 typed Pix slice field
// for the NewXxx functions. Conceptually, this is just (bpp * width * height),
// but this function panics if at least one of those is negative or if the
// computation would overflow the int type.
//
// This panics instead of returning an error because of backwards
// compatibility. The NewXxx functions do not return an error.
func pixelBufferLength(bytesPerPixel int, r Rectangle, imageTypeName string) int {
totalLength := mul3NonNeg(bytesPerPixel, r.Dx(), r.Dy())
if totalLength < 0 {
panic("image: New" + imageTypeName + " Rectangle has huge or negative dimensions")
}
return totalLength
}
// RGBA is an in-memory image whose At method returns color.RGBA values.
type RGBA struct {
// Pix holds the image's pixels, in R, G, B, A order. The pixel at
// (x, y) starts at Pix[(y-Rect.Min.Y)*Stride + (x-Rect.Min.X)*4].
Pix []uint8
// Stride is the Pix stride (in bytes) between vertically adjacent pixels.
Stride int
// Rect is the image's bounds.
Rect Rectangle
}
func (p *RGBA) ColorModel() color.Model { return color.RGBAModel }
func (p *RGBA) Bounds() Rectangle { return p.Rect }
func (p *RGBA) At(x, y int) color.Color {
return p.RGBAAt(x, y)
}
func (p *RGBA) RGBA64At(x, y int) color.RGBA64 {
if !(Point{x, y}.In(p.Rect)) {
return color.RGBA64{}
}
i := p.PixOffset(x, y)
s := p.Pix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
r := uint16(s[0])
g := uint16(s[1])
b := uint16(s[2])
a := uint16(s[3])
return color.RGBA64{
(r << 8) | r,
(g << 8) | g,
(b << 8) | b,
(a << 8) | a,
}
}
func (p *RGBA) RGBAAt(x, y int) color.RGBA {
if !(Point{x, y}.In(p.Rect)) {
return color.RGBA{}
}
i := p.PixOffset(x, y)
s := p.Pix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
return color.RGBA{s[0], s[1], s[2], s[3]}
}
// PixOffset returns the index of the first element of Pix that corresponds to
// the pixel at (x, y).
func (p *RGBA) PixOffset(x, y int) int {
return (y-p.Rect.Min.Y)*p.Stride + (x-p.Rect.Min.X)*4
}
func (p *RGBA) Set(x, y int, c color.Color) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
c1 := color.RGBAModel.Convert(c).(color.RGBA)
s := p.Pix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
s[0] = c1.R
s[1] = c1.G
s[2] = c1.B
s[3] = c1.A
}
func (p *RGBA) SetRGBA64(x, y int, c color.RGBA64) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
s := p.Pix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
s[0] = uint8(c.R >> 8)
s[1] = uint8(c.G >> 8)
s[2] = uint8(c.B >> 8)
s[3] = uint8(c.A >> 8)
}
func (p *RGBA) SetRGBA(x, y int, c color.RGBA) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
s := p.Pix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
s[0] = c.R
s[1] = c.G
s[2] = c.B
s[3] = c.A
}
// SubImage returns an image representing the portion of the image p visible
// through r. The returned value shares pixels with the original image.
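// For example, p.SubImage(Rect(1, 1, 3, 3)) of a 4x4 image is a 2x2 view
// aliasing the same buffer, so writes through the sub-image are visible in p.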
func (p *RGBA) SubImage(r Rectangle) Image {
r = r.Intersect(p.Rect)
// If r1 and r2 are Rectangles, r1.Intersect(r2) is not guaranteed to be inside
// either r1 or r2 if the intersection is empty. Without explicitly checking for
// this, the Pix[i:] expression below can panic.
if r.Empty() {
return &RGBA{}
}
i := p.PixOffset(r.Min.X, r.Min.Y)
return &RGBA{
Pix: p.Pix[i:],
Stride: p.Stride,
Rect: r,
}
}
// Opaque scans the entire image and reports whether it is fully opaque.
func (p *RGBA) Opaque() bool {
if p.Rect.Empty() {
return true
}
i0, i1 := 3, p.Rect.Dx()*4
for y := p.Rect.Min.Y; y < p.Rect.Max.Y; y++ {
for i := i0; i < i1; i += 4 {
if p.Pix[i] != 0xff {
return false
}
}
i0 += p.Stride
i1 += p.Stride
}
return true
}
// NewRGBA returns a new RGBA image with the given bounds.
func NewRGBA(r Rectangle) *RGBA {
return &RGBA{
Pix: make([]uint8, pixelBufferLength(4, r, "RGBA")),
Stride: 4 * r.Dx(),
Rect: r,
}
}
// RGBA64 is an in-memory image whose At method returns color.RGBA64 values.
type RGBA64 struct {
// Pix holds the image's pixels, in R, G, B, A order and big-endian format. The pixel at
// (x, y) starts at Pix[(y-Rect.Min.Y)*Stride + (x-Rect.Min.X)*8].
Pix []uint8
// Stride is the Pix stride (in bytes) between vertically adjacent pixels.
Stride int
// Rect is the image's bounds.
Rect Rectangle
}
func (p *RGBA64) ColorModel() color.Model { return color.RGBA64Model }
func (p *RGBA64) Bounds() Rectangle { return p.Rect }
func (p *RGBA64) At(x, y int) color.Color {
return p.RGBA64At(x, y)
}
func (p *RGBA64) RGBA64At(x, y int) color.RGBA64 {
if !(Point{x, y}.In(p.Rect)) {
return color.RGBA64{}
}
i := p.PixOffset(x, y)
s := p.Pix[i : i+8 : i+8] // Small cap improves performance, see https://golang.org/issue/27857
return color.RGBA64{
uint16(s[0])<<8 | uint16(s[1]),
uint16(s[2])<<8 | uint16(s[3]),
uint16(s[4])<<8 | uint16(s[5]),
uint16(s[6])<<8 | uint16(s[7]),
}
}
// PixOffset returns the index of the first element of Pix that corresponds to
// the pixel at (x, y).
func (p *RGBA64) PixOffset(x, y int) int {
return (y-p.Rect.Min.Y)*p.Stride + (x-p.Rect.Min.X)*8
}
func (p *RGBA64) Set(x, y int, c color.Color) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
c1 := color.RGBA64Model.Convert(c).(color.RGBA64)
s := p.Pix[i : i+8 : i+8] // Small cap improves performance, see https://golang.org/issue/27857
s[0] = uint8(c1.R >> 8)
s[1] = uint8(c1.R)
s[2] = uint8(c1.G >> 8)
s[3] = uint8(c1.G)
s[4] = uint8(c1.B >> 8)
s[5] = uint8(c1.B)
s[6] = uint8(c1.A >> 8)
s[7] = uint8(c1.A)
}
func (p *RGBA64) SetRGBA64(x, y int, c color.RGBA64) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
s := p.Pix[i : i+8 : i+8] // Small cap improves performance, see https://golang.org/issue/27857
s[0] = uint8(c.R >> 8)
s[1] = uint8(c.R)
s[2] = uint8(c.G >> 8)
s[3] = uint8(c.G)
s[4] = uint8(c.B >> 8)
s[5] = uint8(c.B)
s[6] = uint8(c.A >> 8)
s[7] = uint8(c.A)
}
// SubImage returns an image representing the portion of the image p visible
// through r. The returned value shares pixels with the original image.
func (p *RGBA64) SubImage(r Rectangle) Image {
r = r.Intersect(p.Rect)
// If r1 and r2 are Rectangles, r1.Intersect(r2) is not guaranteed to be inside
// either r1 or r2 if the intersection is empty. Without explicitly checking for
// this, the Pix[i:] expression below can panic.
if r.Empty() {
return &RGBA64{}
}
i := p.PixOffset(r.Min.X, r.Min.Y)
return &RGBA64{
Pix: p.Pix[i:],
Stride: p.Stride,
Rect: r,
}
}
// Opaque scans the entire image and reports whether it is fully opaque.
func (p *RGBA64) Opaque() bool {
if p.Rect.Empty() {
return true
}
i0, i1 := 6, p.Rect.Dx()*8
for y := p.Rect.Min.Y; y < p.Rect.Max.Y; y++ {
for i := i0; i < i1; i += 8 {
if p.Pix[i+0] != 0xff || p.Pix[i+1] != 0xff {
return false
}
}
i0 += p.Stride
i1 += p.Stride
}
return true
}
// NewRGBA64 returns a new RGBA64 image with the given bounds.
func NewRGBA64(r Rectangle) *RGBA64 {
return &RGBA64{
Pix: make([]uint8, pixelBufferLength(8, r, "RGBA64")),
Stride: 8 * r.Dx(),
Rect: r,
}
}
// NRGBA is an in-memory image whose At method returns color.NRGBA values.
type NRGBA struct {
// Pix holds the image's pixels, in R, G, B, A order. The pixel at
// (x, y) starts at Pix[(y-Rect.Min.Y)*Stride + (x-Rect.Min.X)*4].
Pix []uint8
// Stride is the Pix stride (in bytes) between vertically adjacent pixels.
Stride int
// Rect is the image's bounds.
Rect Rectangle
}
func (p *NRGBA) ColorModel() color.Model { return color.NRGBAModel }
func (p *NRGBA) Bounds() Rectangle { return p.Rect }
func (p *NRGBA) At(x, y int) color.Color {
return p.NRGBAAt(x, y)
}
func (p *NRGBA) RGBA64At(x, y int) color.RGBA64 {
r, g, b, a := p.NRGBAAt(x, y).RGBA()
return color.RGBA64{uint16(r), uint16(g), uint16(b), uint16(a)}
}
func (p *NRGBA) NRGBAAt(x, y int) color.NRGBA {
if !(Point{x, y}.In(p.Rect)) {
return color.NRGBA{}
}
i := p.PixOffset(x, y)
s := p.Pix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
return color.NRGBA{s[0], s[1], s[2], s[3]}
}
// PixOffset returns the index of the first element of Pix that corresponds to
// the pixel at (x, y).
func (p *NRGBA) PixOffset(x, y int) int {
return (y-p.Rect.Min.Y)*p.Stride + (x-p.Rect.Min.X)*4
}
func (p *NRGBA) Set(x, y int, c color.Color) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
c1 := color.NRGBAModel.Convert(c).(color.NRGBA)
s := p.Pix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
s[0] = c1.R
s[1] = c1.G
s[2] = c1.B
s[3] = c1.A
}
func (p *NRGBA) SetRGBA64(x, y int, c color.RGBA64) {
if !(Point{x, y}.In(p.Rect)) {
return
}
r, g, b, a := uint32(c.R), uint32(c.G), uint32(c.B), uint32(c.A)
if (a != 0) && (a != 0xffff) {
r = (r * 0xffff) / a
g = (g * 0xffff) / a
b = (b * 0xffff) / a
}
i := p.PixOffset(x, y)
s := p.Pix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
s[0] = uint8(r >> 8)
s[1] = uint8(g >> 8)
s[2] = uint8(b >> 8)
s[3] = uint8(a >> 8)
}
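// SetRGBA64 divides the alpha back out because NRGBA stores non-premultiplied
// values while color.RGBA64 is alpha-premultiplied. A worked example:
//
//	c := color.RGBA64{R: 0x4000, A: 0x8000} // half-intensity red at 50% alpha
//	// non-premultiplied red: 0x4000 * 0xffff / 0x8000 == 0x7fff,
//	// so the stored bytes are s[0] == 0x7f and s[3] == 0x80.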
func (p *NRGBA) SetNRGBA(x, y int, c color.NRGBA) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
s := p.Pix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
s[0] = c.R
s[1] = c.G
s[2] = c.B
s[3] = c.A
}
// SubImage returns an image representing the portion of the image p visible
// through r. The returned value shares pixels with the original image.
func (p *NRGBA) SubImage(r Rectangle) Image {
r = r.Intersect(p.Rect)
// If r1 and r2 are Rectangles, r1.Intersect(r2) is not guaranteed to be inside
// either r1 or r2 if the intersection is empty. Without explicitly checking for
// this, the Pix[i:] expression below can panic.
if r.Empty() {
return &NRGBA{}
}
i := p.PixOffset(r.Min.X, r.Min.Y)
return &NRGBA{
Pix: p.Pix[i:],
Stride: p.Stride,
Rect: r,
}
}
// Opaque scans the entire image and reports whether it is fully opaque.
func (p *NRGBA) Opaque() bool {
if p.Rect.Empty() {
return true
}
i0, i1 := 3, p.Rect.Dx()*4
for y := p.Rect.Min.Y; y < p.Rect.Max.Y; y++ {
for i := i0; i < i1; i += 4 {
if p.Pix[i] != 0xff {
return false
}
}
i0 += p.Stride
i1 += p.Stride
}
return true
}
// NewNRGBA returns a new NRGBA image with the given bounds.
func NewNRGBA(r Rectangle) *NRGBA {
return &NRGBA{
Pix: make([]uint8, pixelBufferLength(4, r, "NRGBA")),
Stride: 4 * r.Dx(),
Rect: r,
}
}
// NRGBA64 is an in-memory image whose At method returns color.NRGBA64 values.
type NRGBA64 struct {
// Pix holds the image's pixels, in R, G, B, A order and big-endian format. The pixel at
// (x, y) starts at Pix[(y-Rect.Min.Y)*Stride + (x-Rect.Min.X)*8].
Pix []uint8
// Stride is the Pix stride (in bytes) between vertically adjacent pixels.
Stride int
// Rect is the image's bounds.
Rect Rectangle
}
func (p *NRGBA64) ColorModel() color.Model { return color.NRGBA64Model }
func (p *NRGBA64) Bounds() Rectangle { return p.Rect }
func (p *NRGBA64) At(x, y int) color.Color {
return p.NRGBA64At(x, y)
}
func (p *NRGBA64) RGBA64At(x, y int) color.RGBA64 {
r, g, b, a := p.NRGBA64At(x, y).RGBA()
return color.RGBA64{uint16(r), uint16(g), uint16(b), uint16(a)}
}
func (p *NRGBA64) NRGBA64At(x, y int) color.NRGBA64 {
if !(Point{x, y}.In(p.Rect)) {
return color.NRGBA64{}
}
i := p.PixOffset(x, y)
s := p.Pix[i : i+8 : i+8] // Small cap improves performance, see https://golang.org/issue/27857
return color.NRGBA64{
uint16(s[0])<<8 | uint16(s[1]),
uint16(s[2])<<8 | uint16(s[3]),
uint16(s[4])<<8 | uint16(s[5]),
uint16(s[6])<<8 | uint16(s[7]),
}
}
// PixOffset returns the index of the first element of Pix that corresponds to
// the pixel at (x, y).
func (p *NRGBA64) PixOffset(x, y int) int {
return (y-p.Rect.Min.Y)*p.Stride + (x-p.Rect.Min.X)*8
}
func (p *NRGBA64) Set(x, y int, c color.Color) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
c1 := color.NRGBA64Model.Convert(c).(color.NRGBA64)
s := p.Pix[i : i+8 : i+8] // Small cap improves performance, see https://golang.org/issue/27857
s[0] = uint8(c1.R >> 8)
s[1] = uint8(c1.R)
s[2] = uint8(c1.G >> 8)
s[3] = uint8(c1.G)
s[4] = uint8(c1.B >> 8)
s[5] = uint8(c1.B)
s[6] = uint8(c1.A >> 8)
s[7] = uint8(c1.A)
}
func (p *NRGBA64) SetRGBA64(x, y int, c color.RGBA64) {
if !(Point{x, y}.In(p.Rect)) {
return
}
r, g, b, a := uint32(c.R), uint32(c.G), uint32(c.B), uint32(c.A)
if (a != 0) && (a != 0xffff) {
r = (r * 0xffff) / a
g = (g * 0xffff) / a
b = (b * 0xffff) / a
}
i := p.PixOffset(x, y)
s := p.Pix[i : i+8 : i+8] // Small cap improves performance, see https://golang.org/issue/27857
s[0] = uint8(r >> 8)
s[1] = uint8(r)
s[2] = uint8(g >> 8)
s[3] = uint8(g)
s[4] = uint8(b >> 8)
s[5] = uint8(b)
s[6] = uint8(a >> 8)
s[7] = uint8(a)
}
func (p *NRGBA64) SetNRGBA64(x, y int, c color.NRGBA64) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
s := p.Pix[i : i+8 : i+8] // Small cap improves performance, see https://golang.org/issue/27857
s[0] = uint8(c.R >> 8)
s[1] = uint8(c.R)
s[2] = uint8(c.G >> 8)
s[3] = uint8(c.G)
s[4] = uint8(c.B >> 8)
s[5] = uint8(c.B)
s[6] = uint8(c.A >> 8)
s[7] = uint8(c.A)
}
// SubImage returns an image representing the portion of the image p visible
// through r. The returned value shares pixels with the original image.
func (p *NRGBA64) SubImage(r Rectangle) Image {
r = r.Intersect(p.Rect)
// If r1 and r2 are Rectangles, r1.Intersect(r2) is not guaranteed to be inside
// either r1 or r2 if the intersection is empty. Without explicitly checking for
// this, the Pix[i:] expression below can panic.
if r.Empty() {
return &NRGBA64{}
}
i := p.PixOffset(r.Min.X, r.Min.Y)
return &NRGBA64{
Pix: p.Pix[i:],
Stride: p.Stride,
Rect: r,
}
}
// Opaque scans the entire image and reports whether it is fully opaque.
func (p *NRGBA64) Opaque() bool {
if p.Rect.Empty() {
return true
}
i0, i1 := 6, p.Rect.Dx()*8
for y := p.Rect.Min.Y; y < p.Rect.Max.Y; y++ {
for i := i0; i < i1; i += 8 {
if p.Pix[i+0] != 0xff || p.Pix[i+1] != 0xff {
return false
}
}
i0 += p.Stride
i1 += p.Stride
}
return true
}
// NewNRGBA64 returns a new NRGBA64 image with the given bounds.
func NewNRGBA64(r Rectangle) *NRGBA64 {
return &NRGBA64{
Pix: make([]uint8, pixelBufferLength(8, r, "NRGBA64")),
Stride: 8 * r.Dx(),
Rect: r,
}
}
// Alpha is an in-memory image whose At method returns color.Alpha values.
type Alpha struct {
// Pix holds the image's pixels, as alpha values. The pixel at
// (x, y) starts at Pix[(y-Rect.Min.Y)*Stride + (x-Rect.Min.X)*1].
Pix []uint8
// Stride is the Pix stride (in bytes) between vertically adjacent pixels.
Stride int
// Rect is the image's bounds.
Rect Rectangle
}
func (p *Alpha) ColorModel() color.Model { return color.AlphaModel }
func (p *Alpha) Bounds() Rectangle { return p.Rect }
func (p *Alpha) At(x, y int) color.Color {
return p.AlphaAt(x, y)
}
func (p *Alpha) RGBA64At(x, y int) color.RGBA64 {
a := uint16(p.AlphaAt(x, y).A)
a |= a << 8
return color.RGBA64{a, a, a, a}
}
func (p *Alpha) AlphaAt(x, y int) color.Alpha {
if !(Point{x, y}.In(p.Rect)) {
return color.Alpha{}
}
i := p.PixOffset(x, y)
return color.Alpha{p.Pix[i]}
}
// PixOffset returns the index of the first element of Pix that corresponds to
// the pixel at (x, y).
func (p *Alpha) PixOffset(x, y int) int {
return (y-p.Rect.Min.Y)*p.Stride + (x-p.Rect.Min.X)*1
}
func (p *Alpha) Set(x, y int, c color.Color) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
p.Pix[i] = color.AlphaModel.Convert(c).(color.Alpha).A
}
func (p *Alpha) SetRGBA64(x, y int, c color.RGBA64) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
p.Pix[i] = uint8(c.A >> 8)
}
func (p *Alpha) SetAlpha(x, y int, c color.Alpha) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
p.Pix[i] = c.A
}
// SubImage returns an image representing the portion of the image p visible
// through r. The returned value shares pixels with the original image.
func (p *Alpha) SubImage(r Rectangle) Image {
r = r.Intersect(p.Rect)
// If r1 and r2 are Rectangles, r1.Intersect(r2) is not guaranteed to be inside
// either r1 or r2 if the intersection is empty. Without explicitly checking for
// this, the Pix[i:] expression below can panic.
if r.Empty() {
return &Alpha{}
}
i := p.PixOffset(r.Min.X, r.Min.Y)
return &Alpha{
Pix: p.Pix[i:],
Stride: p.Stride,
Rect: r,
}
}
// Opaque scans the entire image and reports whether it is fully opaque.
func (p *Alpha) Opaque() bool {
if p.Rect.Empty() {
return true
}
i0, i1 := 0, p.Rect.Dx()
for y := p.Rect.Min.Y; y < p.Rect.Max.Y; y++ {
for i := i0; i < i1; i++ {
if p.Pix[i] != 0xff {
return false
}
}
i0 += p.Stride
i1 += p.Stride
}
return true
}
// NewAlpha returns a new Alpha image with the given bounds.
func NewAlpha(r Rectangle) *Alpha {
return &Alpha{
Pix: make([]uint8, pixelBufferLength(1, r, "Alpha")),
Stride: 1 * r.Dx(),
Rect: r,
}
}
// Alpha16 is an in-memory image whose At method returns color.Alpha16 values.
type Alpha16 struct {
// Pix holds the image's pixels, as alpha values in big-endian format. The pixel at
// (x, y) starts at Pix[(y-Rect.Min.Y)*Stride + (x-Rect.Min.X)*2].
Pix []uint8
// Stride is the Pix stride (in bytes) between vertically adjacent pixels.
Stride int
// Rect is the image's bounds.
Rect Rectangle
}
func (p *Alpha16) ColorModel() color.Model { return color.Alpha16Model }
func (p *Alpha16) Bounds() Rectangle { return p.Rect }
func (p *Alpha16) At(x, y int) color.Color {
return p.Alpha16At(x, y)
}
func (p *Alpha16) RGBA64At(x, y int) color.RGBA64 {
a := p.Alpha16At(x, y).A
return color.RGBA64{a, a, a, a}
}
func (p *Alpha16) Alpha16At(x, y int) color.Alpha16 {
if !(Point{x, y}.In(p.Rect)) {
return color.Alpha16{}
}
i := p.PixOffset(x, y)
return color.Alpha16{uint16(p.Pix[i+0])<<8 | uint16(p.Pix[i+1])}
}
// PixOffset returns the index of the first element of Pix that corresponds to
// the pixel at (x, y).
func (p *Alpha16) PixOffset(x, y int) int {
return (y-p.Rect.Min.Y)*p.Stride + (x-p.Rect.Min.X)*2
}
func (p *Alpha16) Set(x, y int, c color.Color) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
c1 := color.Alpha16Model.Convert(c).(color.Alpha16)
p.Pix[i+0] = uint8(c1.A >> 8)
p.Pix[i+1] = uint8(c1.A)
}
func (p *Alpha16) SetRGBA64(x, y int, c color.RGBA64) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
p.Pix[i+0] = uint8(c.A >> 8)
p.Pix[i+1] = uint8(c.A)
}
func (p *Alpha16) SetAlpha16(x, y int, c color.Alpha16) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
p.Pix[i+0] = uint8(c.A >> 8)
p.Pix[i+1] = uint8(c.A)
}
// SubImage returns an image representing the portion of the image p visible
// through r. The returned value shares pixels with the original image.
func (p *Alpha16) SubImage(r Rectangle) Image {
r = r.Intersect(p.Rect)
// If r1 and r2 are Rectangles, r1.Intersect(r2) is not guaranteed to be inside
// either r1 or r2 if the intersection is empty. Without explicitly checking for
// this, the Pix[i:] expression below can panic.
if r.Empty() {
return &Alpha16{}
}
i := p.PixOffset(r.Min.X, r.Min.Y)
return &Alpha16{
Pix: p.Pix[i:],
Stride: p.Stride,
Rect: r,
}
}
// Opaque scans the entire image and reports whether it is fully opaque.
func (p *Alpha16) Opaque() bool {
if p.Rect.Empty() {
return true
}
i0, i1 := 0, p.Rect.Dx()*2
for y := p.Rect.Min.Y; y < p.Rect.Max.Y; y++ {
for i := i0; i < i1; i += 2 {
if p.Pix[i+0] != 0xff || p.Pix[i+1] != 0xff {
return false
}
}
i0 += p.Stride
i1 += p.Stride
}
return true
}
// NewAlpha16 returns a new Alpha16 image with the given bounds.
func NewAlpha16(r Rectangle) *Alpha16 {
return &Alpha16{
Pix: make([]uint8, pixelBufferLength(2, r, "Alpha16")),
Stride: 2 * r.Dx(),
Rect: r,
}
}
// Gray is an in-memory image whose At method returns color.Gray values.
type Gray struct {
// Pix holds the image's pixels, as gray values. The pixel at
// (x, y) starts at Pix[(y-Rect.Min.Y)*Stride + (x-Rect.Min.X)*1].
Pix []uint8
// Stride is the Pix stride (in bytes) between vertically adjacent pixels.
Stride int
// Rect is the image's bounds.
Rect Rectangle
}
func (p *Gray) ColorModel() color.Model { return color.GrayModel }
func (p *Gray) Bounds() Rectangle { return p.Rect }
func (p *Gray) At(x, y int) color.Color {
return p.GrayAt(x, y)
}
func (p *Gray) RGBA64At(x, y int) color.RGBA64 {
gray := uint16(p.GrayAt(x, y).Y)
gray |= gray << 8
return color.RGBA64{gray, gray, gray, 0xffff}
}
func (p *Gray) GrayAt(x, y int) color.Gray {
if !(Point{x, y}.In(p.Rect)) {
return color.Gray{}
}
i := p.PixOffset(x, y)
return color.Gray{p.Pix[i]}
}
// PixOffset returns the index of the first element of Pix that corresponds to
// the pixel at (x, y).
func (p *Gray) PixOffset(x, y int) int {
return (y-p.Rect.Min.Y)*p.Stride + (x-p.Rect.Min.X)*1
}
func (p *Gray) Set(x, y int, c color.Color) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
p.Pix[i] = color.GrayModel.Convert(c).(color.Gray).Y
}
func (p *Gray) SetRGBA64(x, y int, c color.RGBA64) {
if !(Point{x, y}.In(p.Rect)) {
return
}
// This formula is the same as in color.grayModel.
gray := (19595*uint32(c.R) + 38470*uint32(c.G) + 7471*uint32(c.B) + 1<<15) >> 24
i := p.PixOffset(x, y)
p.Pix[i] = uint8(gray)
}
func (p *Gray) SetGray(x, y int, c color.Gray) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
p.Pix[i] = c.Y
}
// SubImage returns an image representing the portion of the image p visible
// through r. The returned value shares pixels with the original image.
func (p *Gray) SubImage(r Rectangle) Image {
r = r.Intersect(p.Rect)
// If r1 and r2 are Rectangles, r1.Intersect(r2) is not guaranteed to be inside
// either r1 or r2 if the intersection is empty. Without explicitly checking for
// this, the Pix[i:] expression below can panic.
if r.Empty() {
return &Gray{}
}
i := p.PixOffset(r.Min.X, r.Min.Y)
return &Gray{
Pix: p.Pix[i:],
Stride: p.Stride,
Rect: r,
}
}
// Opaque scans the entire image and reports whether it is fully opaque.
func (p *Gray) Opaque() bool {
return true
}
// NewGray returns a new Gray image with the given bounds.
func NewGray(r Rectangle) *Gray {
return &Gray{
Pix: make([]uint8, pixelBufferLength(1, r, "Gray")),
Stride: 1 * r.Dx(),
Rect: r,
}
}
// Gray16 is an in-memory image whose At method returns color.Gray16 values.
type Gray16 struct {
// Pix holds the image's pixels, as gray values in big-endian format. The pixel at
// (x, y) starts at Pix[(y-Rect.Min.Y)*Stride + (x-Rect.Min.X)*2].
Pix []uint8
// Stride is the Pix stride (in bytes) between vertically adjacent pixels.
Stride int
// Rect is the image's bounds.
Rect Rectangle
}
func (p *Gray16) ColorModel() color.Model { return color.Gray16Model }
func (p *Gray16) Bounds() Rectangle { return p.Rect }
func (p *Gray16) At(x, y int) color.Color {
return p.Gray16At(x, y)
}
func (p *Gray16) RGBA64At(x, y int) color.RGBA64 {
gray := p.Gray16At(x, y).Y
return color.RGBA64{gray, gray, gray, 0xffff}
}
func (p *Gray16) Gray16At(x, y int) color.Gray16 {
if !(Point{x, y}.In(p.Rect)) {
return color.Gray16{}
}
i := p.PixOffset(x, y)
return color.Gray16{uint16(p.Pix[i+0])<<8 | uint16(p.Pix[i+1])}
}
// PixOffset returns the index of the first element of Pix that corresponds to
// the pixel at (x, y).
func (p *Gray16) PixOffset(x, y int) int {
return (y-p.Rect.Min.Y)*p.Stride + (x-p.Rect.Min.X)*2
}
func (p *Gray16) Set(x, y int, c color.Color) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
c1 := color.Gray16Model.Convert(c).(color.Gray16)
p.Pix[i+0] = uint8(c1.Y >> 8)
p.Pix[i+1] = uint8(c1.Y)
}
func (p *Gray16) SetRGBA64(x, y int, c color.RGBA64) {
if !(Point{x, y}.In(p.Rect)) {
return
}
// This formula is the same as in color.gray16Model.
gray := (19595*uint32(c.R) + 38470*uint32(c.G) + 7471*uint32(c.B) + 1<<15) >> 16
i := p.PixOffset(x, y)
p.Pix[i+0] = uint8(gray >> 8)
p.Pix[i+1] = uint8(gray)
}
func (p *Gray16) SetGray16(x, y int, c color.Gray16) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
p.Pix[i+0] = uint8(c.Y >> 8)
p.Pix[i+1] = uint8(c.Y)
}
// SubImage returns an image representing the portion of the image p visible
// through r. The returned value shares pixels with the original image.
func (p *Gray16) SubImage(r Rectangle) Image {
r = r.Intersect(p.Rect)
// If r1 and r2 are Rectangles, r1.Intersect(r2) is not guaranteed to be inside
// either r1 or r2 if the intersection is empty. Without explicitly checking for
// this, the Pix[i:] expression below can panic.
if r.Empty() {
return &Gray16{}
}
i := p.PixOffset(r.Min.X, r.Min.Y)
return &Gray16{
Pix: p.Pix[i:],
Stride: p.Stride,
Rect: r,
}
}
// Opaque scans the entire image and reports whether it is fully opaque.
func (p *Gray16) Opaque() bool {
return true
}
// NewGray16 returns a new Gray16 image with the given bounds.
func NewGray16(r Rectangle) *Gray16 {
return &Gray16{
Pix: make([]uint8, pixelBufferLength(2, r, "Gray16")),
Stride: 2 * r.Dx(),
Rect: r,
}
}
// CMYK is an in-memory image whose At method returns color.CMYK values.
type CMYK struct {
// Pix holds the image's pixels, in C, M, Y, K order. The pixel at
// (x, y) starts at Pix[(y-Rect.Min.Y)*Stride + (x-Rect.Min.X)*4].
Pix []uint8
// Stride is the Pix stride (in bytes) between vertically adjacent pixels.
Stride int
// Rect is the image's bounds.
Rect Rectangle
}
func (p *CMYK) ColorModel() color.Model { return color.CMYKModel }
func (p *CMYK) Bounds() Rectangle { return p.Rect }
func (p *CMYK) At(x, y int) color.Color {
return p.CMYKAt(x, y)
}
func (p *CMYK) RGBA64At(x, y int) color.RGBA64 {
r, g, b, a := p.CMYKAt(x, y).RGBA()
return color.RGBA64{uint16(r), uint16(g), uint16(b), uint16(a)}
}
func (p *CMYK) CMYKAt(x, y int) color.CMYK {
if !(Point{x, y}.In(p.Rect)) {
return color.CMYK{}
}
i := p.PixOffset(x, y)
s := p.Pix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
return color.CMYK{s[0], s[1], s[2], s[3]}
}
// PixOffset returns the index of the first element of Pix that corresponds to
// the pixel at (x, y).
func (p *CMYK) PixOffset(x, y int) int {
return (y-p.Rect.Min.Y)*p.Stride + (x-p.Rect.Min.X)*4
}
func (p *CMYK) Set(x, y int, c color.Color) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
c1 := color.CMYKModel.Convert(c).(color.CMYK)
s := p.Pix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
s[0] = c1.C
s[1] = c1.M
s[2] = c1.Y
s[3] = c1.K
}
func (p *CMYK) SetRGBA64(x, y int, c color.RGBA64) {
if !(Point{x, y}.In(p.Rect)) {
return
}
cc, mm, yy, kk := color.RGBToCMYK(uint8(c.R>>8), uint8(c.G>>8), uint8(c.B>>8))
i := p.PixOffset(x, y)
s := p.Pix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
s[0] = cc
s[1] = mm
s[2] = yy
s[3] = kk
}
func (p *CMYK) SetCMYK(x, y int, c color.CMYK) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
s := p.Pix[i : i+4 : i+4] // Small cap improves performance, see https://golang.org/issue/27857
s[0] = c.C
s[1] = c.M
s[2] = c.Y
s[3] = c.K
}
// SubImage returns an image representing the portion of the image p visible
// through r. The returned value shares pixels with the original image.
func (p *CMYK) SubImage(r Rectangle) Image {
r = r.Intersect(p.Rect)
// If r1 and r2 are Rectangles, r1.Intersect(r2) is not guaranteed to be inside
// either r1 or r2 if the intersection is empty. Without explicitly checking for
// this, the Pix[i:] expression below can panic.
if r.Empty() {
return &CMYK{}
}
i := p.PixOffset(r.Min.X, r.Min.Y)
return &CMYK{
Pix: p.Pix[i:],
Stride: p.Stride,
Rect: r,
}
}
// Opaque scans the entire image and reports whether it is fully opaque.
func (p *CMYK) Opaque() bool {
return true
}
// NewCMYK returns a new CMYK image with the given bounds.
func NewCMYK(r Rectangle) *CMYK {
return &CMYK{
Pix: make([]uint8, pixelBufferLength(4, r, "CMYK")),
Stride: 4 * r.Dx(),
Rect: r,
}
}
// Paletted is an in-memory image of uint8 indices into a given palette.
type Paletted struct {
// Pix holds the image's pixels, as palette indices. The pixel at
// (x, y) starts at Pix[(y-Rect.Min.Y)*Stride + (x-Rect.Min.X)*1].
Pix []uint8
// Stride is the Pix stride (in bytes) between vertically adjacent pixels.
Stride int
// Rect is the image's bounds.
Rect Rectangle
// Palette is the image's palette.
Palette color.Palette
}
func (p *Paletted) ColorModel() color.Model { return p.Palette }
func (p *Paletted) Bounds() Rectangle { return p.Rect }
func (p *Paletted) At(x, y int) color.Color {
if len(p.Palette) == 0 {
return nil
}
if !(Point{x, y}.In(p.Rect)) {
return p.Palette[0]
}
i := p.PixOffset(x, y)
return p.Palette[p.Pix[i]]
}
func (p *Paletted) RGBA64At(x, y int) color.RGBA64 {
if len(p.Palette) == 0 {
return color.RGBA64{}
}
c := color.Color(nil)
if !(Point{x, y}.In(p.Rect)) {
c = p.Palette[0]
} else {
i := p.PixOffset(x, y)
c = p.Palette[p.Pix[i]]
}
r, g, b, a := c.RGBA()
return color.RGBA64{
uint16(r),
uint16(g),
uint16(b),
uint16(a),
}
}
// PixOffset returns the index of the first element of Pix that corresponds to
// the pixel at (x, y).
func (p *Paletted) PixOffset(x, y int) int {
return (y-p.Rect.Min.Y)*p.Stride + (x-p.Rect.Min.X)*1
}
func (p *Paletted) Set(x, y int, c color.Color) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
p.Pix[i] = uint8(p.Palette.Index(c))
}
func (p *Paletted) SetRGBA64(x, y int, c color.RGBA64) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
p.Pix[i] = uint8(p.Palette.Index(c))
}
func (p *Paletted) ColorIndexAt(x, y int) uint8 {
if !(Point{x, y}.In(p.Rect)) {
return 0
}
i := p.PixOffset(x, y)
return p.Pix[i]
}
func (p *Paletted) SetColorIndex(x, y int, index uint8) {
if !(Point{x, y}.In(p.Rect)) {
return
}
i := p.PixOffset(x, y)
p.Pix[i] = index
}
// SubImage returns an image representing the portion of the image p visible
// through r. The returned value shares pixels with the original image.
func (p *Paletted) SubImage(r Rectangle) Image {
r = r.Intersect(p.Rect)
// If r1 and r2 are Rectangles, r1.Intersect(r2) is not guaranteed to be inside
// either r1 or r2 if the intersection is empty. Without explicitly checking for
// this, the Pix[i:] expression below can panic.
if r.Empty() {
return &Paletted{
Palette: p.Palette,
}
}
i := p.PixOffset(r.Min.X, r.Min.Y)
return &Paletted{
Pix: p.Pix[i:],
Stride: p.Stride,
Rect: p.Rect.Intersect(r),
Palette: p.Palette,
}
}
// Opaque scans the entire image and reports whether it is fully opaque.
func (p *Paletted) Opaque() bool {
var present [256]bool
i0, i1 := 0, p.Rect.Dx()
for y := p.Rect.Min.Y; y < p.Rect.Max.Y; y++ {
for _, c := range p.Pix[i0:i1] {
present[c] = true
}
i0 += p.Stride
i1 += p.Stride
}
for i, c := range p.Palette {
if !present[i] {
continue
}
_, _, _, a := c.RGBA()
if a != 0xffff {
return false
}
}
return true
}
// NewPaletted returns a new Paletted image with the given width, height and
// palette.
func NewPaletted(r Rectangle, p color.Palette) *Paletted {
return &Paletted{
Pix: make([]uint8, pixelBufferLength(1, r, "Paletted")),
Stride: 1 * r.Dx(),
Rect: r,
Palette: p,
}
}
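// A minimal usage sketch, building a two-entry paletted image:
//
//	pal := color.Palette{color.Black, color.White}
//	img := NewPaletted(Rect(0, 0, 8, 8), pal)
//	img.SetColorIndex(3, 3, 1) // store the palette index directly
//	_ = img.At(3, 3)           // color.White, via the palette lookup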
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package jpeg
// This file implements a Forward Discrete Cosine Transformation.
/*
It is based on the code in jfdctint.c from the Independent JPEG Group,
found at http://www.ijg.org/files/jpegsrc.v8c.tar.gz.
The "LEGAL ISSUES" section of the README in that archive says:
In plain English:
1. We don't promise that this software works. (But if you find any bugs,
please let us know!)
2. You can use this software for whatever you want. You don't have to pay us.
3. You may not pretend that you wrote this software. If you use it in a
program, you must acknowledge somewhere in your documentation that
you've used the IJG code.
In legalese:
The authors make NO WARRANTY or representation, either express or implied,
with respect to this software, its quality, accuracy, merchantability, or
fitness for a particular purpose. This software is provided "AS IS", and you,
its user, assume the entire risk as to its quality and accuracy.
This software is copyright (C) 1991-2011, Thomas G. Lane, Guido Vollbeding.
All Rights Reserved except as specified below.
Permission is hereby granted to use, copy, modify, and distribute this
software (or portions thereof) for any purpose, without fee, subject to these
conditions:
(1) If any part of the source code for this software is distributed, then this
README file must be included, with this copyright and no-warranty notice
unaltered; and any additions, deletions, or changes to the original files
must be clearly indicated in accompanying documentation.
(2) If only executable code is distributed, then the accompanying
documentation must state that "this software is based in part on the work of
the Independent JPEG Group".
(3) Permission for use of this software is granted only if the user accepts
full responsibility for any undesirable consequences; the authors accept
NO LIABILITY for damages of any kind.
These conditions apply to any software derived from or based on the IJG code,
not just to the unmodified library. If you use our work, you ought to
acknowledge us.
Permission is NOT granted for the use of any IJG author's name or company name
in advertising or publicity relating to this software or products derived from
it. This software may be referred to only as "the Independent JPEG Group's
software".
We specifically permit and encourage the use of this software as the basis of
commercial products, provided that all warranty or liability claims are
assumed by the product vendor.
*/
// Trigonometric constants in 13-bit fixed point format.
const (
fix_0_298631336 = 2446
fix_0_390180644 = 3196
fix_0_541196100 = 4433
fix_0_765366865 = 6270
fix_0_899976223 = 7373
fix_1_175875602 = 9633
fix_1_501321110 = 12299
fix_1_847759065 = 15137
fix_1_961570560 = 16069
fix_2_053119869 = 16819
fix_2_562915447 = 20995
fix_3_072711026 = 25172
)
const (
constBits = 13
pass1Bits = 2
centerJSample = 128
)
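// Each fix_* constant is the named trigonometric value scaled by 1<<constBits
// (i.e. 8192) and rounded to the nearest integer. For example:
//
//	0.541196100 * 8192 == 4433.48... -> fix_0_541196100 == 4433
//	1.175875602 * 8192 == 9632.77... -> fix_1_175875602 == 9633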
// fdct performs a forward DCT on an 8x8 block of coefficients, including a
// level shift.
func fdct(b *block) {
// Pass 1: process rows.
for y := 0; y < 8; y++ {
y8 := y * 8
s := b[y8 : y8+8 : y8+8] // Small cap improves performance, see https://golang.org/issue/27857
x0 := s[0]
x1 := s[1]
x2 := s[2]
x3 := s[3]
x4 := s[4]
x5 := s[5]
x6 := s[6]
x7 := s[7]
tmp0 := x0 + x7
tmp1 := x1 + x6
tmp2 := x2 + x5
tmp3 := x3 + x4
tmp10 := tmp0 + tmp3
tmp12 := tmp0 - tmp3
tmp11 := tmp1 + tmp2
tmp13 := tmp1 - tmp2
tmp0 = x0 - x7
tmp1 = x1 - x6
tmp2 = x2 - x5
tmp3 = x3 - x4
s[0] = (tmp10 + tmp11 - 8*centerJSample) << pass1Bits
s[4] = (tmp10 - tmp11) << pass1Bits
z1 := (tmp12 + tmp13) * fix_0_541196100
z1 += 1 << (constBits - pass1Bits - 1)
s[2] = (z1 + tmp12*fix_0_765366865) >> (constBits - pass1Bits)
s[6] = (z1 - tmp13*fix_1_847759065) >> (constBits - pass1Bits)
tmp10 = tmp0 + tmp3
tmp11 = tmp1 + tmp2
tmp12 = tmp0 + tmp2
tmp13 = tmp1 + tmp3
z1 = (tmp12 + tmp13) * fix_1_175875602
z1 += 1 << (constBits - pass1Bits - 1)
tmp0 *= fix_1_501321110
tmp1 *= fix_3_072711026
tmp2 *= fix_2_053119869
tmp3 *= fix_0_298631336
tmp10 *= -fix_0_899976223
tmp11 *= -fix_2_562915447
tmp12 *= -fix_0_390180644
tmp13 *= -fix_1_961570560
tmp12 += z1
tmp13 += z1
s[1] = (tmp0 + tmp10 + tmp12) >> (constBits - pass1Bits)
s[3] = (tmp1 + tmp11 + tmp13) >> (constBits - pass1Bits)
s[5] = (tmp2 + tmp11 + tmp12) >> (constBits - pass1Bits)
s[7] = (tmp3 + tmp10 + tmp13) >> (constBits - pass1Bits)
}
// Pass 2: process columns.
// We remove pass1Bits scaling, but leave results scaled up by an overall factor of 8.
for x := 0; x < 8; x++ {
tmp0 := b[0*8+x] + b[7*8+x]
tmp1 := b[1*8+x] + b[6*8+x]
tmp2 := b[2*8+x] + b[5*8+x]
tmp3 := b[3*8+x] + b[4*8+x]
tmp10 := tmp0 + tmp3 + 1<<(pass1Bits-1)
tmp12 := tmp0 - tmp3
tmp11 := tmp1 + tmp2
tmp13 := tmp1 - tmp2
tmp0 = b[0*8+x] - b[7*8+x]
tmp1 = b[1*8+x] - b[6*8+x]
tmp2 = b[2*8+x] - b[5*8+x]
tmp3 = b[3*8+x] - b[4*8+x]
b[0*8+x] = (tmp10 + tmp11) >> pass1Bits
b[4*8+x] = (tmp10 - tmp11) >> pass1Bits
z1 := (tmp12 + tmp13) * fix_0_541196100
z1 += 1 << (constBits + pass1Bits - 1)
b[2*8+x] = (z1 + tmp12*fix_0_765366865) >> (constBits + pass1Bits)
b[6*8+x] = (z1 - tmp13*fix_1_847759065) >> (constBits + pass1Bits)
tmp10 = tmp0 + tmp3
tmp11 = tmp1 + tmp2
tmp12 = tmp0 + tmp2
tmp13 = tmp1 + tmp3
z1 = (tmp12 + tmp13) * fix_1_175875602
z1 += 1 << (constBits + pass1Bits - 1)
tmp0 *= fix_1_501321110
tmp1 *= fix_3_072711026
tmp2 *= fix_2_053119869
tmp3 *= fix_0_298631336
tmp10 *= -fix_0_899976223
tmp11 *= -fix_2_562915447
tmp12 *= -fix_0_390180644
tmp13 *= -fix_1_961570560
tmp12 += z1
tmp13 += z1
b[1*8+x] = (tmp0 + tmp10 + tmp12) >> (constBits + pass1Bits)
b[3*8+x] = (tmp1 + tmp11 + tmp13) >> (constBits + pass1Bits)
b[5*8+x] = (tmp2 + tmp11 + tmp12) >> (constBits + pass1Bits)
b[7*8+x] = (tmp3 + tmp10 + tmp13) >> (constBits + pass1Bits)
}
}
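// The "z1 += 1 << (constBits - pass1Bits - 1)" lines above (and their
// "constBits + pass1Bits" counterparts in pass 2) add half of the final
// divisor before the right shift, implementing round-to-nearest rather than
// truncation. For example, with a shift of 3:
//
//	(13 + 4) >> 3 == 2 // 13/8 == 1.625 rounds up to 2
//	(11 + 4) >> 3 == 1 // 11/8 == 1.375 rounds down to 1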
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package jpeg
import (
"io"
)
// maxCodeLength is the maximum (inclusive) number of bits in a Huffman code.
const maxCodeLength = 16
// maxNCodes is the maximum (inclusive) number of codes in a Huffman tree.
const maxNCodes = 256
// lutSize is the log-2 size of the Huffman decoder's look-up table.
const lutSize = 8
// huffman is a Huffman decoder, specified in section C.
type huffman struct {
// nCodes is the number of codes in the tree.
nCodes int32
// lut is the look-up table for the next lutSize bits in the bit-stream.
// The high 8 bits of the uint16 are the encoded value. The low 8 bits
// are 1 plus the code length, or 0 if the value is too large to fit in
// lutSize bits.
lut [1 << lutSize]uint16
// vals are the decoded values, sorted by their encoding.
vals [maxNCodes]uint8
// minCodes[i] is the minimum code of length i, or -1 if there are no
// codes of that length.
minCodes [maxCodeLength]int32
// maxCodes[i] is the maximum code of length i, or -1 if there are no
// codes of that length.
maxCodes [maxCodeLength]int32
// valsIndices[i] is the index into vals of minCodes[i].
valsIndices [maxCodeLength]int32
}
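// For concreteness: a lut entry packs the decoded value into the high byte
// and 1 plus the code length into the low byte. A value 0x05 with a 3-bit
// code occupies the 1<<(8-3) == 32 slots whose top three bits equal the
// code, each holding
//
//	uint16(0x05)<<8 | uint16(3+1) // 0x0504
//
// while a zero entry means the code is longer than lutSize bits.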
// errShortHuffmanData means that an unexpected EOF occurred while decoding
// Huffman data.
var errShortHuffmanData = FormatError("short Huffman data")
// ensureNBits reads bytes from the byte buffer to ensure that d.bits.n is at
// least n. For best performance (avoiding function calls inside hot loops),
// the caller is responsible for first checking that d.bits.n < n.
func (d *decoder) ensureNBits(n int32) error {
for {
c, err := d.readByteStuffedByte()
if err != nil {
if err == io.ErrUnexpectedEOF {
return errShortHuffmanData
}
return err
}
d.bits.a = d.bits.a<<8 | uint32(c)
d.bits.n += 8
if d.bits.m == 0 {
d.bits.m = 1 << 7
} else {
d.bits.m <<= 8
}
if d.bits.n >= n {
break
}
}
return nil
}
// receiveExtend is the composition of RECEIVE and EXTEND, specified in section
// F.2.2.1.
func (d *decoder) receiveExtend(t uint8) (int32, error) {
if d.bits.n < int32(t) {
if err := d.ensureNBits(int32(t)); err != nil {
return 0, err
}
}
d.bits.n -= int32(t)
d.bits.m >>= t
s := int32(1) << t
x := int32(d.bits.a>>uint8(d.bits.n)) & (s - 1)
if x < s>>1 {
x += ((-1) << t) + 1
}
return x, nil
}
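// A worked EXTEND example with t == 3: the raw bits span 0..7, and values
// below s>>1 == 4 map to the negative range -7..-4. For raw bits 0b010:
//
//	x := int32(2)
//	x += ((-1) << 3) + 1 // x == 2 - 7 == -5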
// processDHT processes a Define Huffman Table marker, and initializes a huffman
// struct from its contents. Specified in section B.2.4.2.
func (d *decoder) processDHT(n int) error {
for n > 0 {
if n < 17 {
return FormatError("DHT has wrong length")
}
if err := d.readFull(d.tmp[:17]); err != nil {
return err
}
tc := d.tmp[0] >> 4
if tc > maxTc {
return FormatError("bad Tc value")
}
th := d.tmp[0] & 0x0f
// The baseline th <= 1 restriction is specified in table B.5.
if th > maxTh || (d.baseline && th > 1) {
return FormatError("bad Th value")
}
h := &d.huff[tc][th]
// Read nCodes and h.vals (and derive h.nCodes).
// nCodes[i] is the number of codes with code length i.
// h.nCodes is the total number of codes.
h.nCodes = 0
var nCodes [maxCodeLength]int32
for i := range nCodes {
nCodes[i] = int32(d.tmp[i+1])
h.nCodes += nCodes[i]
}
if h.nCodes == 0 {
return FormatError("Huffman table has zero length")
}
if h.nCodes > maxNCodes {
return FormatError("Huffman table has excessive length")
}
n -= int(h.nCodes) + 17
if n < 0 {
return FormatError("DHT has wrong length")
}
if err := d.readFull(h.vals[:h.nCodes]); err != nil {
return err
}
// Derive the look-up table.
for i := range h.lut {
h.lut[i] = 0
}
var x, code uint32
for i := uint32(0); i < lutSize; i++ {
code <<= 1
for j := int32(0); j < nCodes[i]; j++ {
// The codeLength is 1+i, so shift code by 8-(1+i) to
// calculate the high bits for every 8-bit sequence
// whose first codeLength bits match code.
// The high 8 bits of lutValue are the encoded value.
// The low 8 bits are 1 plus the codeLength.
base := uint8(code << (7 - i))
lutValue := uint16(h.vals[x])<<8 | uint16(2+i)
for k := uint8(0); k < 1<<(7-i); k++ {
h.lut[base|k] = lutValue
}
code++
x++
}
}
// Derive minCodes, maxCodes, and valsIndices.
var c, index int32
for i, n := range nCodes {
if n == 0 {
h.minCodes[i] = -1
h.maxCodes[i] = -1
h.valsIndices[i] = -1
} else {
h.minCodes[i] = c
h.maxCodes[i] = c + n - 1
h.valsIndices[i] = index
c += n
index += n
}
c <<= 1
}
}
return nil
}
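// To make the lut derivation concrete: if nCodes[1] == 2 (two codes of
// length 2), the canonical codes are 0b00 and 0b01. On the i == 1, j == 1
// iteration:
//
//	base := uint8(1 << (7 - 1)) // 0x40
//	// h.lut[0x40] through h.lut[0x7f] all receive the same lutValue,
//	// whose low byte is 2+i == 3, i.e. 1 plus the code length.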
// decodeHuffman returns the next Huffman-coded value from the bit-stream,
// decoded according to h.
func (d *decoder) decodeHuffman(h *huffman) (uint8, error) {
if h.nCodes == 0 {
return 0, FormatError("uninitialized Huffman table")
}
if d.bits.n < 8 {
if err := d.ensureNBits(8); err != nil {
if err != errMissingFF00 && err != errShortHuffmanData {
return 0, err
}
// There are no more bytes of data in this segment, but we may still
// be able to read the next symbol out of the previously read bits.
// First, undo the readByte that the ensureNBits call made.
if d.bytes.nUnreadable != 0 {
d.unreadByteStuffedByte()
}
goto slowPath
}
}
if v := h.lut[(d.bits.a>>uint32(d.bits.n-lutSize))&0xff]; v != 0 {
n := (v & 0xff) - 1
d.bits.n -= int32(n)
d.bits.m >>= n
return uint8(v >> 8), nil
}
slowPath:
for i, code := 0, int32(0); i < maxCodeLength; i++ {
if d.bits.n == 0 {
if err := d.ensureNBits(1); err != nil {
return 0, err
}
}
if d.bits.a&d.bits.m != 0 {
code |= 1
}
d.bits.n--
d.bits.m >>= 1
if code <= h.maxCodes[i] {
return h.vals[h.valsIndices[i]+code-h.minCodes[i]], nil
}
code <<= 1
}
return 0, FormatError("bad Huffman code")
}
func (d *decoder) decodeBit() (bool, error) {
if d.bits.n == 0 {
if err := d.ensureNBits(1); err != nil {
return false, err
}
}
ret := d.bits.a&d.bits.m != 0
d.bits.n--
d.bits.m >>= 1
return ret, nil
}
func (d *decoder) decodeBits(n int32) (uint32, error) {
if d.bits.n < n {
if err := d.ensureNBits(n); err != nil {
return 0, err
}
}
ret := d.bits.a >> uint32(d.bits.n-n)
ret &= (1 << uint32(n)) - 1
d.bits.n -= n
d.bits.m >>= uint32(n)
return ret, nil
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package jpeg
// This is a Go translation of idct.c from
//
// http://standards.iso.org/ittf/PubliclyAvailableStandards/ISO_IEC_13818-4_2004_Conformance_Testing/Video/verifier/mpeg2decode_960109.tar.gz
//
// which carries the following notice:
/* Copyright (C) 1996, MPEG Software Simulation Group. All Rights Reserved. */
/*
* Disclaimer of Warranty
*
* These software programs are available to the user without any license fee or
* royalty on an "as is" basis. The MPEG Software Simulation Group disclaims
* any and all warranties, whether express, implied, or statuary, including any
* implied warranties or merchantability or of fitness for a particular
* purpose. In no event shall the copyright-holder be liable for any
* incidental, punitive, or consequential damages of any kind whatsoever
* arising from the use of these programs.
*
* This disclaimer of warranty extends to the user of these programs and user's
* customers, employees, agents, transferees, successors, and assigns.
*
* The MPEG Software Simulation Group does not represent or warrant that the
* programs furnished hereunder are free of infringement of any third-party
* patents.
*
* Commercial implementations of MPEG-1 and MPEG-2 video, including shareware,
* are subject to royalty fees to patent holders. Many of these patents are
* general enough such that they are unavoidable regardless of implementation
* design.
*
*/
const blockSize = 64 // A DCT block is 8x8.
type block [blockSize]int32
const (
w1 = 2841 // 2048*sqrt(2)*cos(1*pi/16)
w2 = 2676 // 2048*sqrt(2)*cos(2*pi/16)
w3 = 2408 // 2048*sqrt(2)*cos(3*pi/16)
w5 = 1609 // 2048*sqrt(2)*cos(5*pi/16)
w6 = 1108 // 2048*sqrt(2)*cos(6*pi/16)
w7 = 565 // 2048*sqrt(2)*cos(7*pi/16)
w1pw7 = w1 + w7
w1mw7 = w1 - w7
w2pw6 = w2 + w6
w2mw6 = w2 - w6
w3pw5 = w3 + w5
w3mw5 = w3 - w5
r2 = 181 // 256/sqrt(2)
)
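// Each w* constant is 2048*sqrt(2)*cos(k*pi/16) rounded to the nearest
// integer. For example:
//
//	2048 * 1.414214 * cos(7*pi/16) == 2896.31 * 0.195090 == 565.05... -> w7 == 565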
// idct performs a 2-D Inverse Discrete Cosine Transformation.
//
// The input coefficients should already have been multiplied by the
// appropriate quantization table. We use fixed-point computation, with the
// number of bits for the fractional component varying over the intermediate
// stages.
//
// For more on the actual algorithm, see Z. Wang, "Fast algorithms for the
// discrete W transform and for the discrete Fourier transform", IEEE Trans. on
// ASSP, Vol. ASSP- 32, pp. 803-816, Aug. 1984.
func idct(src *block) {
// Horizontal 1-D IDCT.
for y := 0; y < 8; y++ {
y8 := y * 8
s := src[y8 : y8+8 : y8+8] // Small cap improves performance, see https://golang.org/issue/27857
// If all the AC components are zero, then the IDCT is trivial.
if s[1] == 0 && s[2] == 0 && s[3] == 0 &&
s[4] == 0 && s[5] == 0 && s[6] == 0 && s[7] == 0 {
dc := s[0] << 3
s[0] = dc
s[1] = dc
s[2] = dc
s[3] = dc
s[4] = dc
s[5] = dc
s[6] = dc
s[7] = dc
continue
}
// Prescale.
x0 := (s[0] << 11) + 128
x1 := s[4] << 11
x2 := s[6]
x3 := s[2]
x4 := s[1]
x5 := s[7]
x6 := s[5]
x7 := s[3]
// Stage 1.
x8 := w7 * (x4 + x5)
x4 = x8 + w1mw7*x4
x5 = x8 - w1pw7*x5
x8 = w3 * (x6 + x7)
x6 = x8 - w3mw5*x6
x7 = x8 - w3pw5*x7
// Stage 2.
x8 = x0 + x1
x0 -= x1
x1 = w6 * (x3 + x2)
x2 = x1 - w2pw6*x2
x3 = x1 + w2mw6*x3
x1 = x4 + x6
x4 -= x6
x6 = x5 + x7
x5 -= x7
// Stage 3.
x7 = x8 + x3
x8 -= x3
x3 = x0 + x2
x0 -= x2
x2 = (r2*(x4+x5) + 128) >> 8
x4 = (r2*(x4-x5) + 128) >> 8
// Stage 4.
s[0] = (x7 + x1) >> 8
s[1] = (x3 + x2) >> 8
s[2] = (x0 + x4) >> 8
s[3] = (x8 + x6) >> 8
s[4] = (x8 - x6) >> 8
s[5] = (x0 - x4) >> 8
s[6] = (x3 - x2) >> 8
s[7] = (x7 - x1) >> 8
}
// Vertical 1-D IDCT.
for x := 0; x < 8; x++ {
// Similar to the horizontal 1-D IDCT case, if all the AC components are zero, then the IDCT is trivial.
// However, after performing the horizontal 1-D IDCT, there are typically non-zero AC components, so
// we do not bother to check for the all-zero case.
s := src[x : x+57 : x+57] // Small cap improves performance, see https://golang.org/issue/27857
// Prescale.
y0 := (s[8*0] << 8) + 8192
y1 := s[8*4] << 8
y2 := s[8*6]
y3 := s[8*2]
y4 := s[8*1]
y5 := s[8*7]
y6 := s[8*5]
y7 := s[8*3]
// Stage 1.
y8 := w7*(y4+y5) + 4
y4 = (y8 + w1mw7*y4) >> 3
y5 = (y8 - w1pw7*y5) >> 3
y8 = w3*(y6+y7) + 4
y6 = (y8 - w3mw5*y6) >> 3
y7 = (y8 - w3pw5*y7) >> 3
// Stage 2.
y8 = y0 + y1
y0 -= y1
y1 = w6*(y3+y2) + 4
y2 = (y1 - w2pw6*y2) >> 3
y3 = (y1 + w2mw6*y3) >> 3
y1 = y4 + y6
y4 -= y6
y6 = y5 + y7
y5 -= y7
// Stage 3.
y7 = y8 + y3
y8 -= y3
y3 = y0 + y2
y0 -= y2
y2 = (r2*(y4+y5) + 128) >> 8
y4 = (r2*(y4-y5) + 128) >> 8
// Stage 4.
s[8*0] = (y7 + y1) >> 14
s[8*1] = (y3 + y2) >> 14
s[8*2] = (y0 + y4) >> 14
s[8*3] = (y8 + y6) >> 14
s[8*4] = (y8 - y6) >> 14
s[8*5] = (y0 - y4) >> 14
s[8*6] = (y3 - y2) >> 14
s[8*7] = (y7 - y1) >> 14
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package jpeg implements a JPEG image decoder and encoder.
//
// JPEG is defined in ITU-T T.81: https://www.w3.org/Graphics/JPEG/itu-t81.pdf.
package jpeg
import (
"image"
"image/color"
"image/internal/imageutil"
"io"
)
// A FormatError reports that the input is not a valid JPEG.
type FormatError string
func (e FormatError) Error() string { return "invalid JPEG format: " + string(e) }
// An UnsupportedError reports that the input uses a valid but unimplemented JPEG feature.
type UnsupportedError string
func (e UnsupportedError) Error() string { return "unsupported JPEG feature: " + string(e) }
var errUnsupportedSubsamplingRatio = UnsupportedError("luma/chroma subsampling ratio")
// Component specification, specified in section B.2.2.
type component struct {
h int // Horizontal sampling factor.
v int // Vertical sampling factor.
c uint8 // Component identifier.
tq uint8 // Quantization table destination selector.
}
const (
dcTable = 0
acTable = 1
maxTc = 1
maxTh = 3
maxTq = 3
maxComponents = 4
)
const (
sof0Marker = 0xc0 // Start Of Frame (Baseline Sequential).
sof1Marker = 0xc1 // Start Of Frame (Extended Sequential).
sof2Marker = 0xc2 // Start Of Frame (Progressive).
dhtMarker = 0xc4 // Define Huffman Table.
rst0Marker = 0xd0 // ReSTart (0).
rst7Marker = 0xd7 // ReSTart (7).
soiMarker = 0xd8 // Start Of Image.
eoiMarker = 0xd9 // End Of Image.
sosMarker = 0xda // Start Of Scan.
dqtMarker = 0xdb // Define Quantization Table.
driMarker = 0xdd // Define Restart Interval.
comMarker = 0xfe // COMment.
// "APPlication specific" markers aren't part of the JPEG spec per se,
// but in practice, their use is described at
// https://www.sno.phy.queensu.ca/~phil/exiftool/TagNames/JPEG.html
app0Marker = 0xe0
app14Marker = 0xee
app15Marker = 0xef
)
// See https://www.sno.phy.queensu.ca/~phil/exiftool/TagNames/JPEG.html#Adobe
const (
adobeTransformUnknown = 0
adobeTransformYCbCr = 1
adobeTransformYCbCrK = 2
)
// unzig maps from the zig-zag ordering to the natural ordering. For example,
// unzig[3] is the column and row of the fourth element in zig-zag order. The
// value is 16, which means first column (16%8 == 0) and third row (16/8 == 2).
var unzig = [blockSize]int{
0, 1, 8, 16, 9, 2, 3, 10,
17, 24, 32, 25, 18, 11, 4, 5,
12, 19, 26, 33, 40, 48, 41, 34,
27, 20, 13, 6, 7, 14, 21, 28,
35, 42, 49, 56, 57, 50, 43, 36,
29, 22, 15, 23, 30, 37, 44, 51,
58, 59, 52, 45, 38, 31, 39, 46,
53, 60, 61, 54, 47, 55, 62, 63,
}
// Deprecated: Reader is not used by the image/jpeg package and should
// not be used by others. It is kept for compatibility.
type Reader interface {
io.ByteReader
io.Reader
}
// bits holds the unprocessed bits that have been taken from the byte-stream.
// The n least significant bits of a form the unread bits, to be read in MSB to
// LSB order.
type bits struct {
a uint32 // accumulator.
m uint32 // mask. m==1<<(n-1) when n>0, with m==0 when n==0.
n int32 // the number of unread bits in a.
}
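// For concreteness: after ensureNBits loads the byte 0xA5 into an empty
// accumulator, the state is
//
//	a == 0xA5, n == 8, m == 1<<7
//
// and reading one bit tests a&m, then decrements n and halves m, preserving
// the m == 1<<(n-1) invariant.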
type decoder struct {
r io.Reader
bits bits
// bytes is a byte buffer, similar to a bufio.Reader, except that it
// has to be able to unread more than 1 byte, due to byte stuffing.
// Byte stuffing is specified in section F.1.2.3.
bytes struct {
// buf[i:j] are the buffered bytes read from the underlying
// io.Reader that haven't yet been passed further on.
buf [4096]byte
i, j int
// nUnreadable is the number of bytes to back up i after
// overshooting. It can be 0, 1 or 2.
nUnreadable int
}
width, height int
img1 *image.Gray
img3 *image.YCbCr
blackPix []byte
blackStride int
ri int // Restart Interval.
nComp int
// As per section 4.5, there are four modes of operation (selected by the
// SOF? markers): sequential DCT, progressive DCT, lossless and
// hierarchical, although this implementation does not support the latter
// two non-DCT modes. Sequential DCT is further split into baseline and
// extended, as per section 4.11.
baseline bool
progressive bool
jfif bool
adobeTransformValid bool
adobeTransform uint8
eobRun uint16 // End-of-Band run, specified in section G.1.2.2.
comp [maxComponents]component
progCoeffs [maxComponents][]block // Saved state between progressive-mode scans.
huff [maxTc + 1][maxTh + 1]huffman
quant [maxTq + 1]block // Quantization tables, in zig-zag order.
tmp [2 * blockSize]byte
}
// fill fills up the d.bytes.buf buffer from the underlying io.Reader. It
// should only be called when there are no unread bytes in d.bytes.
func (d *decoder) fill() error {
if d.bytes.i != d.bytes.j {
panic("jpeg: fill called when unread bytes exist")
}
// Move the last 2 bytes to the start of the buffer, in case we need
// to call unreadByteStuffedByte.
if d.bytes.j > 2 {
d.bytes.buf[0] = d.bytes.buf[d.bytes.j-2]
d.bytes.buf[1] = d.bytes.buf[d.bytes.j-1]
d.bytes.i, d.bytes.j = 2, 2
}
// Fill in the rest of the buffer.
n, err := d.r.Read(d.bytes.buf[d.bytes.j:])
d.bytes.j += n
if n > 0 {
return nil
}
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return err
}
// unreadByteStuffedByte undoes the most recent readByteStuffedByte call,
// giving a byte of data back from d.bits to d.bytes. The Huffman look-up table
// requires at least 8 bits for look-up, which means that Huffman decoding can
// sometimes overshoot and read one or two too many bytes. Two-byte overshoot
// can happen when expecting to read a 0xff 0x00 byte-stuffed byte.
func (d *decoder) unreadByteStuffedByte() {
d.bytes.i -= d.bytes.nUnreadable
d.bytes.nUnreadable = 0
if d.bits.n >= 8 {
d.bits.a >>= 8
d.bits.n -= 8
d.bits.m >>= 8
}
}
// readByte returns the next byte, whether buffered or not buffered. It does
// not care about byte stuffing.
func (d *decoder) readByte() (x byte, err error) {
for d.bytes.i == d.bytes.j {
if err = d.fill(); err != nil {
return 0, err
}
}
x = d.bytes.buf[d.bytes.i]
d.bytes.i++
d.bytes.nUnreadable = 0
return x, nil
}
// errMissingFF00 means that readByteStuffedByte encountered an 0xff byte (a
// marker byte) that wasn't the expected byte-stuffed sequence 0xff, 0x00.
var errMissingFF00 = FormatError("missing 0xff00 sequence")
// readByteStuffedByte is like readByte but is for byte-stuffed Huffman data.
func (d *decoder) readByteStuffedByte() (x byte, err error) {
// Take the fast path if d.bytes.buf contains at least two bytes.
if d.bytes.i+2 <= d.bytes.j {
x = d.bytes.buf[d.bytes.i]
d.bytes.i++
d.bytes.nUnreadable = 1
if x != 0xff {
return x, err
}
if d.bytes.buf[d.bytes.i] != 0x00 {
return 0, errMissingFF00
}
d.bytes.i++
d.bytes.nUnreadable = 2
return 0xff, nil
}
d.bytes.nUnreadable = 0
x, err = d.readByte()
if err != nil {
return 0, err
}
d.bytes.nUnreadable = 1
if x != 0xff {
return x, nil
}
x, err = d.readByte()
if err != nil {
return 0, err
}
d.bytes.nUnreadable = 2
if x != 0x00 {
return 0, errMissingFF00
}
return 0xff, nil
}
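// In short, within entropy-coded data an encoder escapes a literal 0xff byte
// as the pair 0xff 0x00, so a bare 0xff always introduces a marker. The
// mapping implemented above is:
//
//	0x12      -> 0x12, one byte consumed (nUnreadable == 1)
//	0xff 0x00 -> 0xff, two bytes consumed (nUnreadable == 2)
//	0xff 0xd9 -> errMissingFF00 (a marker, not stuffed data)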
// readFull reads exactly len(p) bytes into p. It does not care about byte
// stuffing.
func (d *decoder) readFull(p []byte) error {
// Unread the overshot bytes, if any.
if d.bytes.nUnreadable != 0 {
if d.bits.n >= 8 {
d.unreadByteStuffedByte()
}
d.bytes.nUnreadable = 0
}
for {
n := copy(p, d.bytes.buf[d.bytes.i:d.bytes.j])
p = p[n:]
d.bytes.i += n
if len(p) == 0 {
break
}
if err := d.fill(); err != nil {
return err
}
}
return nil
}
// ignore ignores the next n bytes.
func (d *decoder) ignore(n int) error {
// Unread the overshot bytes, if any.
if d.bytes.nUnreadable != 0 {
if d.bits.n >= 8 {
d.unreadByteStuffedByte()
}
d.bytes.nUnreadable = 0
}
for {
m := d.bytes.j - d.bytes.i
if m > n {
m = n
}
d.bytes.i += m
n -= m
if n == 0 {
break
}
if err := d.fill(); err != nil {
return err
}
}
return nil
}
// Specified in section B.2.2.
func (d *decoder) processSOF(n int) error {
if d.nComp != 0 {
return FormatError("multiple SOF markers")
}
switch n {
case 6 + 3*1: // Grayscale image.
d.nComp = 1
case 6 + 3*3: // YCbCr or RGB image.
d.nComp = 3
case 6 + 3*4: // YCbCrK or CMYK image.
d.nComp = 4
default:
return UnsupportedError("number of components")
}
if err := d.readFull(d.tmp[:n]); err != nil {
return err
}
// We only support 8-bit precision.
if d.tmp[0] != 8 {
return UnsupportedError("precision")
}
d.height = int(d.tmp[1])<<8 + int(d.tmp[2])
d.width = int(d.tmp[3])<<8 + int(d.tmp[4])
if int(d.tmp[5]) != d.nComp {
return FormatError("SOF has wrong length")
}
for i := 0; i < d.nComp; i++ {
d.comp[i].c = d.tmp[6+3*i]
// Section B.2.2 states that "the value of C_i shall be different from
// the values of C_1 through C_(i-1)".
for j := 0; j < i; j++ {
if d.comp[i].c == d.comp[j].c {
return FormatError("repeated component identifier")
}
}
d.comp[i].tq = d.tmp[8+3*i]
if d.comp[i].tq > maxTq {
return FormatError("bad Tq value")
}
hv := d.tmp[7+3*i]
h, v := int(hv>>4), int(hv&0x0f)
if h < 1 || 4 < h || v < 1 || 4 < v {
return FormatError("luma/chroma subsampling ratio")
}
if h == 3 || v == 3 {
return errUnsupportedSubsamplingRatio
}
switch d.nComp {
case 1:
// If a JPEG image has only one component, section A.2 says "this data
// is non-interleaved by definition" and section A.2.2 says "[in this
// case...] the order of data units within a scan shall be left-to-right
// and top-to-bottom... regardless of the values of H_1 and V_1". Section
// 4.8.2 also says "[for non-interleaved data], the MCU is defined to be
// one data unit". Similarly, section A.1.1 explains that it is the ratio
// of H_i to max_j(H_j) that matters, and similarly for V. For grayscale
// images, H_1 is the maximum H_j for all components j, so that ratio is
// always 1. The component's (h, v) is effectively always (1, 1): even if
// the nominal (h, v) is (2, 1), a 20x5 image is encoded in three 8x8
// MCUs, not two 16x8 MCUs.
h, v = 1, 1
case 3:
// For YCbCr images, we only support 4:4:4, 4:4:0, 4:2:2, 4:2:0,
// 4:1:1 or 4:1:0 chroma subsampling ratios. This implies that the
// (h, v) values for the Y component are either (1, 1), (1, 2),
// (2, 1), (2, 2), (4, 1) or (4, 2), and the Y component's values
// must be a multiple of the Cb and Cr component's values. We also
// assume that the two chroma components have the same subsampling
// ratio.
switch i {
case 0: // Y.
// We have already verified, above, that h and v are both
// either 1, 2 or 4, so invalid (h, v) combinations are those
// with v == 4.
if v == 4 {
return errUnsupportedSubsamplingRatio
}
case 1: // Cb.
if d.comp[0].h%h != 0 || d.comp[0].v%v != 0 {
return errUnsupportedSubsamplingRatio
}
case 2: // Cr.
if d.comp[1].h != h || d.comp[1].v != v {
return errUnsupportedSubsamplingRatio
}
}
case 4:
// For 4-component images (either CMYK or YCbCrK), we only support two
// hv vectors: [0x11 0x11 0x11 0x11] and [0x22 0x11 0x11 0x22].
// Theoretically, 4-component JPEG images could mix and match hv values
// but in practice, those two combinations are the only ones in use,
// and it simplifies the applyBlack code below if we can assume that:
// - for CMYK, the C and K channels have full samples, and if the M
// and Y channels subsample, they subsample both horizontally and
// vertically.
// - for YCbCrK, the Y and K channels have full samples.
switch i {
case 0:
if hv != 0x11 && hv != 0x22 {
return errUnsupportedSubsamplingRatio
}
case 1, 2:
if hv != 0x11 {
return errUnsupportedSubsamplingRatio
}
case 3:
if d.comp[0].h != h || d.comp[0].v != v {
return errUnsupportedSubsamplingRatio
}
}
}
d.comp[i].h = h
d.comp[i].v = v
}
return nil
}
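// For reference, the SOF payload layout parsed above (d.tmp[:n], which
// follows the 2-byte segment length):
//
//	tmp[0]     sample precision (must be 8)
//	tmp[1:3]   image height, big-endian
//	tmp[3:5]   image width, big-endian
//	tmp[5]     number of components
//	tmp[6+3*i] component identifier C_i
//	tmp[7+3*i] sampling factors, h<<4 | v
//	tmp[8+3*i] quantization table selector Tq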
// Specified in section B.2.4.1.
func (d *decoder) processDQT(n int) error {
loop:
for n > 0 {
n--
x, err := d.readByte()
if err != nil {
return err
}
tq := x & 0x0f
if tq > maxTq {
return FormatError("bad Tq value")
}
switch x >> 4 {
default:
return FormatError("bad Pq value")
case 0:
if n < blockSize {
break loop
}
n -= blockSize
if err := d.readFull(d.tmp[:blockSize]); err != nil {
return err
}
for i := range d.quant[tq] {
d.quant[tq][i] = int32(d.tmp[i])
}
case 1:
if n < 2*blockSize {
break loop
}
n -= 2 * blockSize
if err := d.readFull(d.tmp[:2*blockSize]); err != nil {
return err
}
for i := range d.quant[tq] {
d.quant[tq][i] = int32(d.tmp[2*i])<<8 | int32(d.tmp[2*i+1])
}
}
}
if n != 0 {
return FormatError("DQT has wrong length")
}
return nil
}
// Specified in section B.2.4.4.
func (d *decoder) processDRI(n int) error {
if n != 2 {
return FormatError("DRI has wrong length")
}
if err := d.readFull(d.tmp[:2]); err != nil {
return err
}
d.ri = int(d.tmp[0])<<8 + int(d.tmp[1])
return nil
}
func (d *decoder) processApp0Marker(n int) error {
if n < 5 {
return d.ignore(n)
}
if err := d.readFull(d.tmp[:5]); err != nil {
return err
}
n -= 5
d.jfif = d.tmp[0] == 'J' && d.tmp[1] == 'F' && d.tmp[2] == 'I' && d.tmp[3] == 'F' && d.tmp[4] == '\x00'
if n > 0 {
return d.ignore(n)
}
return nil
}
func (d *decoder) processApp14Marker(n int) error {
if n < 12 {
return d.ignore(n)
}
if err := d.readFull(d.tmp[:12]); err != nil {
return err
}
n -= 12
if d.tmp[0] == 'A' && d.tmp[1] == 'd' && d.tmp[2] == 'o' && d.tmp[3] == 'b' && d.tmp[4] == 'e' {
d.adobeTransformValid = true
d.adobeTransform = d.tmp[11]
}
if n > 0 {
return d.ignore(n)
}
return nil
}
// decode reads a JPEG image from r and returns it as an image.Image.
func (d *decoder) decode(r io.Reader, configOnly bool) (image.Image, error) {
d.r = r
// Check for the Start Of Image marker.
if err := d.readFull(d.tmp[:2]); err != nil {
return nil, err
}
if d.tmp[0] != 0xff || d.tmp[1] != soiMarker {
return nil, FormatError("missing SOI marker")
}
// Process the remaining segments until the End Of Image marker.
for {
err := d.readFull(d.tmp[:2])
if err != nil {
return nil, err
}
for d.tmp[0] != 0xff {
// Strictly speaking, this is a format error. However, libjpeg is
// liberal in what it accepts. As of version 9, next_marker in
// jdmarker.c treats this as a warning (JWRN_EXTRANEOUS_DATA) and
// continues to decode the stream. Even before next_marker sees
// extraneous data, jpeg_fill_bit_buffer in jdhuff.c reads as many
// bytes as it can, possibly past the end of a scan's data. It
// effectively puts back any markers that it overscanned (e.g. an
// "\xff\xd9" EOI marker), but it does not put back non-marker data,
// and thus it can silently ignore a small number of extraneous
// non-marker bytes before next_marker has a chance to see them (and
// print a warning).
//
// We are therefore also liberal in what we accept. Extraneous data
// is silently ignored.
//
// This is similar to, but not exactly the same as, the restart
// mechanism within a scan (the RST[0-7] markers).
//
// Note that extraneous 0xff bytes in e.g. SOS data are escaped as
// "\xff\x00", and so are detected a little further down below.
d.tmp[0] = d.tmp[1]
d.tmp[1], err = d.readByte()
if err != nil {
return nil, err
}
}
marker := d.tmp[1]
if marker == 0 {
// Treat "\xff\x00" as extraneous data.
continue
}
for marker == 0xff {
// Section B.1.1.2 says, "Any marker may optionally be preceded by any
// number of fill bytes, which are bytes assigned code X'FF'".
marker, err = d.readByte()
if err != nil {
return nil, err
}
}
if marker == eoiMarker { // End Of Image.
break
}
if rst0Marker <= marker && marker <= rst7Marker {
// Figures B.2 and B.16 of the specification suggest that restart markers should
// only occur between Entropy Coded Segments and not after the final ECS.
// However, some encoders may generate incorrect JPEGs with a final restart
// marker. That restart marker will be seen here instead of inside the processSOS
// method, and is ignored as a harmless error. Restart markers have no extra data,
// so we check for this before we read the 16-bit length of the segment.
continue
}
// Read the 16-bit length of the segment. The value includes the 2 bytes for the
// length itself, so we subtract 2 to get the number of remaining bytes.
if err = d.readFull(d.tmp[:2]); err != nil {
return nil, err
}
n := int(d.tmp[0])<<8 + int(d.tmp[1]) - 2
if n < 0 {
return nil, FormatError("short segment length")
}
switch marker {
case sof0Marker, sof1Marker, sof2Marker:
d.baseline = marker == sof0Marker
d.progressive = marker == sof2Marker
err = d.processSOF(n)
if configOnly && d.jfif {
return nil, err
}
case dhtMarker:
if configOnly {
err = d.ignore(n)
} else {
err = d.processDHT(n)
}
case dqtMarker:
if configOnly {
err = d.ignore(n)
} else {
err = d.processDQT(n)
}
case sosMarker:
if configOnly {
return nil, nil
}
err = d.processSOS(n)
case driMarker:
if configOnly {
err = d.ignore(n)
} else {
err = d.processDRI(n)
}
case app0Marker:
err = d.processApp0Marker(n)
case app14Marker:
err = d.processApp14Marker(n)
default:
if app0Marker <= marker && marker <= app15Marker || marker == comMarker {
err = d.ignore(n)
} else if marker < 0xc0 { // See Table B.1 "Marker code assignments".
err = FormatError("unknown marker")
} else {
err = UnsupportedError("unknown marker")
}
}
if err != nil {
return nil, err
}
}
if d.progressive {
if err := d.reconstructProgressiveImage(); err != nil {
return nil, err
}
}
if d.img1 != nil {
return d.img1, nil
}
if d.img3 != nil {
if d.blackPix != nil {
return d.applyBlack()
} else if d.isRGB() {
return d.convertToRGB()
}
return d.img3, nil
}
return nil, FormatError("missing SOS marker")
}
// applyBlack combines d.img3 and d.blackPix into a CMYK image. The formula
// used depends on whether the JPEG image is stored as CMYK or YCbCrK,
// indicated by the APP14 (Adobe) metadata.
//
// Adobe CMYK JPEG images are inverted, where 255 means no ink instead of full
// ink, so we apply "v = 255 - v" at various points. Note that a double
// inversion is a no-op, so inversions might be implicit in the code below.
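// For example, v = 10 inverts to 255 - 10 = 245, and inverting again
// yields 255 - 245 = 10, the original value.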
func (d *decoder) applyBlack() (image.Image, error) {
if !d.adobeTransformValid {
return nil, UnsupportedError("unknown color model: 4-component JPEG doesn't have Adobe APP14 metadata")
}
// If the 4-component JPEG image isn't explicitly marked as "Unknown (RGB
// or CMYK)" as per
// https://www.sno.phy.queensu.ca/~phil/exiftool/TagNames/JPEG.html#Adobe
// we assume that it is YCbCrK. This matches libjpeg's jdapimin.c.
if d.adobeTransform != adobeTransformUnknown {
// Convert the YCbCr part of the YCbCrK to RGB, invert the RGB to get
// CMY, and patch in the original K. The RGB to CMY inversion cancels
// out the 'Adobe inversion' described in the applyBlack doc comment
// above, so in practice, only the fourth channel (black) is inverted.
bounds := d.img3.Bounds()
img := image.NewRGBA(bounds)
imageutil.DrawYCbCr(img, bounds, d.img3, bounds.Min)
for iBase, y := 0, bounds.Min.Y; y < bounds.Max.Y; iBase, y = iBase+img.Stride, y+1 {
for i, x := iBase+3, bounds.Min.X; x < bounds.Max.X; i, x = i+4, x+1 {
img.Pix[i] = 255 - d.blackPix[(y-bounds.Min.Y)*d.blackStride+(x-bounds.Min.X)]
}
}
return &image.CMYK{
Pix: img.Pix,
Stride: img.Stride,
Rect: img.Rect,
}, nil
}
// The first three channels (cyan, magenta, yellow) of the CMYK
// were decoded into d.img3, but each channel was decoded into a separate
// []byte slice, and some channels may be subsampled. We interleave the
// separate channels into an image.CMYK's single []byte slice containing 4
// contiguous bytes per pixel.
bounds := d.img3.Bounds()
img := image.NewCMYK(bounds)
translations := [4]struct {
src []byte
stride int
}{
{d.img3.Y, d.img3.YStride},
{d.img3.Cb, d.img3.CStride},
{d.img3.Cr, d.img3.CStride},
{d.blackPix, d.blackStride},
}
for t, translation := range translations {
subsample := d.comp[t].h != d.comp[0].h || d.comp[t].v != d.comp[0].v
for iBase, y := 0, bounds.Min.Y; y < bounds.Max.Y; iBase, y = iBase+img.Stride, y+1 {
sy := y - bounds.Min.Y
if subsample {
sy /= 2
}
for i, x := iBase+t, bounds.Min.X; x < bounds.Max.X; i, x = i+4, x+1 {
sx := x - bounds.Min.X
if subsample {
sx /= 2
}
img.Pix[i] = 255 - translation.src[sy*translation.stride+sx]
}
}
}
return img, nil
}
func (d *decoder) isRGB() bool {
if d.jfif {
return false
}
if d.adobeTransformValid && d.adobeTransform == adobeTransformUnknown {
// https://www.sno.phy.queensu.ca/~phil/exiftool/TagNames/JPEG.html#Adobe
// says that 0 means Unknown (and in practice RGB) and 1 means YCbCr.
return true
}
return d.comp[0].c == 'R' && d.comp[1].c == 'G' && d.comp[2].c == 'B'
}
func (d *decoder) convertToRGB() (image.Image, error) {
cScale := d.comp[0].h / d.comp[1].h
bounds := d.img3.Bounds()
img := image.NewRGBA(bounds)
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
po := img.PixOffset(bounds.Min.X, y)
yo := d.img3.YOffset(bounds.Min.X, y)
co := d.img3.COffset(bounds.Min.X, y)
for i, iMax := 0, bounds.Max.X-bounds.Min.X; i < iMax; i++ {
img.Pix[po+4*i+0] = d.img3.Y[yo+i]
img.Pix[po+4*i+1] = d.img3.Cb[co+i/cScale]
img.Pix[po+4*i+2] = d.img3.Cr[co+i/cScale]
img.Pix[po+4*i+3] = 255
}
}
return img, nil
}
// Decode reads a JPEG image from r and returns it as an image.Image.
func Decode(r io.Reader) (image.Image, error) {
var d decoder
return d.decode(r, false)
}
// DecodeConfig returns the color model and dimensions of a JPEG image without
// decoding the entire image.
func DecodeConfig(r io.Reader) (image.Config, error) {
var d decoder
if _, err := d.decode(r, true); err != nil {
return image.Config{}, err
}
switch d.nComp {
case 1:
return image.Config{
ColorModel: color.GrayModel,
Width: d.width,
Height: d.height,
}, nil
case 3:
cm := color.YCbCrModel
if d.isRGB() {
cm = color.RGBAModel
}
return image.Config{
ColorModel: cm,
Width: d.width,
Height: d.height,
}, nil
case 4:
return image.Config{
ColorModel: color.CMYKModel,
Width: d.width,
Height: d.height,
}, nil
}
return image.Config{}, FormatError("missing SOF marker")
}
func init() {
image.RegisterFormat("jpeg", "\xff\xd8", Decode, DecodeConfig)
}
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package jpeg
import (
"image"
)
// makeImg allocates and initializes the destination image.
func (d *decoder) makeImg(mxx, myy int) {
if d.nComp == 1 {
m := image.NewGray(image.Rect(0, 0, 8*mxx, 8*myy))
d.img1 = m.SubImage(image.Rect(0, 0, d.width, d.height)).(*image.Gray)
return
}
h0 := d.comp[0].h
v0 := d.comp[0].v
hRatio := h0 / d.comp[1].h
vRatio := v0 / d.comp[1].v
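// The two ratios are packed into a single byte below: for example, 4:2:0
// chroma subsampling has (hRatio, vRatio) = (2, 2), which is the case
// value 0x22.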
var subsampleRatio image.YCbCrSubsampleRatio
switch hRatio<<4 | vRatio {
case 0x11:
subsampleRatio = image.YCbCrSubsampleRatio444
case 0x12:
subsampleRatio = image.YCbCrSubsampleRatio440
case 0x21:
subsampleRatio = image.YCbCrSubsampleRatio422
case 0x22:
subsampleRatio = image.YCbCrSubsampleRatio420
case 0x41:
subsampleRatio = image.YCbCrSubsampleRatio411
case 0x42:
subsampleRatio = image.YCbCrSubsampleRatio410
default:
panic("unreachable")
}
m := image.NewYCbCr(image.Rect(0, 0, 8*h0*mxx, 8*v0*myy), subsampleRatio)
d.img3 = m.SubImage(image.Rect(0, 0, d.width, d.height)).(*image.YCbCr)
if d.nComp == 4 {
h3, v3 := d.comp[3].h, d.comp[3].v
d.blackPix = make([]byte, 8*h3*mxx*8*v3*myy)
d.blackStride = 8 * h3 * mxx
}
}
// Specified in section B.2.3.
func (d *decoder) processSOS(n int) error {
if d.nComp == 0 {
return FormatError("missing SOF marker")
}
if n < 6 || 4+2*d.nComp < n || n%2 != 0 {
return FormatError("SOS has wrong length")
}
if err := d.readFull(d.tmp[:n]); err != nil {
return err
}
nComp := int(d.tmp[0])
if n != 4+2*nComp {
return FormatError("SOS length inconsistent with number of components")
}
var scan [maxComponents]struct {
compIndex uint8
td uint8 // DC table selector.
ta uint8 // AC table selector.
}
totalHV := 0
for i := 0; i < nComp; i++ {
cs := d.tmp[1+2*i] // Component selector.
compIndex := -1
for j, comp := range d.comp[:d.nComp] {
if cs == comp.c {
compIndex = j
}
}
if compIndex < 0 {
return FormatError("unknown component selector")
}
scan[i].compIndex = uint8(compIndex)
// Section B.2.3 states that "the value of Cs_j shall be different from
// the values of Cs_1 through Cs_(j-1)". Since we have previously
// verified that a frame's component identifiers (C_i values in section
// B.2.2) are unique, it suffices to check that the implicit indexes
// into d.comp are unique.
for j := 0; j < i; j++ {
if scan[i].compIndex == scan[j].compIndex {
return FormatError("repeated component selector")
}
}
totalHV += d.comp[compIndex].h * d.comp[compIndex].v
// The baseline t <= 1 restriction is specified in table B.3.
scan[i].td = d.tmp[2+2*i] >> 4
if t := scan[i].td; t > maxTh || (d.baseline && t > 1) {
return FormatError("bad Td value")
}
scan[i].ta = d.tmp[2+2*i] & 0x0f
if t := scan[i].ta; t > maxTh || (d.baseline && t > 1) {
return FormatError("bad Ta value")
}
}
// Section B.2.3 states that if there is more than one component then the
// total H*V values in a scan must be <= 10.
if d.nComp > 1 && totalHV > 10 {
return FormatError("total sampling factors too large")
}
// zigStart and zigEnd are the spectral selection bounds.
// ah and al are the successive approximation high and low values.
// The spec calls these values Ss, Se, Ah and Al.
//
// For progressive JPEGs, these are the two more-or-less independent
// aspects of progression. Spectral selection progression is when not
// all of a block's 64 DCT coefficients are transmitted in one pass.
// For example, three passes could transmit coefficient 0 (the DC
// component), coefficients 1-5, and coefficients 6-63, in zig-zag
// order. Successive approximation is when not all of the bits of a
// band of coefficients are transmitted in one pass. For example,
// three passes could transmit the 6 most significant bits, followed
// by the second-least significant bit, followed by the least
// significant bit.
//
// For sequential JPEGs, these parameters are hard-coded to 0/63/0/0, as
// per table B.3.
zigStart, zigEnd, ah, al := int32(0), int32(blockSize-1), uint32(0), uint32(0)
if d.progressive {
zigStart = int32(d.tmp[1+2*nComp])
zigEnd = int32(d.tmp[2+2*nComp])
ah = uint32(d.tmp[3+2*nComp] >> 4)
al = uint32(d.tmp[3+2*nComp] & 0x0f)
if (zigStart == 0 && zigEnd != 0) || zigStart > zigEnd || blockSize <= zigEnd {
return FormatError("bad spectral selection bounds")
}
if zigStart != 0 && nComp != 1 {
return FormatError("progressive AC coefficients for more than one component")
}
if ah != 0 && ah != al+1 {
return FormatError("bad successive approximation values")
}
}
// mxx and myy are the number of MCUs (Minimum Coded Units) in the image.
h0, v0 := d.comp[0].h, d.comp[0].v // The h and v values from the Y components.
mxx := (d.width + 8*h0 - 1) / (8 * h0)
myy := (d.height + 8*v0 - 1) / (8 * v0)
if d.img1 == nil && d.img3 == nil {
d.makeImg(mxx, myy)
}
if d.progressive {
for i := 0; i < nComp; i++ {
compIndex := scan[i].compIndex
if d.progCoeffs[compIndex] == nil {
d.progCoeffs[compIndex] = make([]block, mxx*myy*d.comp[compIndex].h*d.comp[compIndex].v)
}
}
}
d.bits = bits{}
mcu, expectedRST := 0, uint8(rst0Marker)
var (
// b is the decoded coefficients, in natural (not zig-zag) order.
b block
dc [maxComponents]int32
// bx and by are the location of the current block, in units of 8x8
// blocks: the third block in the first row has (bx, by) = (2, 0).
bx, by int
blockCount int
)
for my := 0; my < myy; my++ {
for mx := 0; mx < mxx; mx++ {
for i := 0; i < nComp; i++ {
compIndex := scan[i].compIndex
hi := d.comp[compIndex].h
vi := d.comp[compIndex].v
for j := 0; j < hi*vi; j++ {
// The blocks are traversed one MCU at a time. For 4:2:0 chroma
// subsampling, there are four Y 8x8 blocks in every 16x16 MCU.
//
// For a sequential 32x16 pixel image, the Y blocks visiting order is:
// 0 1 4 5
// 2 3 6 7
//
// For progressive images, the interleaved scans (those with nComp > 1)
// are traversed as above, but non-interleaved scans are traversed left
// to right, top to bottom:
// 0 1 2 3
// 4 5 6 7
// Only DC scans (zigStart == 0) can be interleaved. AC scans must have
// only one component.
//
// To further complicate matters, for non-interleaved scans, there is no
// data for any blocks that are inside the image at the MCU level but
// outside the image at the pixel level. For example, a 24x16 pixel 4:2:0
// progressive image consists of two 16x16 MCUs. The interleaved scans
// will process 8 Y blocks:
// 0 1 4 5
// 2 3 6 7
// The non-interleaved scans will process only 6 Y blocks:
// 0 1 2
// 3 4 5
if nComp != 1 {
bx = hi*mx + j%hi
by = vi*my + j/hi
} else {
q := mxx * hi
bx = blockCount % q
by = blockCount / q
blockCount++
if bx*8 >= d.width || by*8 >= d.height {
continue
}
}
// Load the previous partially decoded coefficients, if applicable.
if d.progressive {
b = d.progCoeffs[compIndex][by*mxx*hi+bx]
} else {
b = block{}
}
if ah != 0 {
if err := d.refine(&b, &d.huff[acTable][scan[i].ta], zigStart, zigEnd, 1<<al); err != nil {
return err
}
} else {
zig := zigStart
if zig == 0 {
zig++
// Decode the DC coefficient, as specified in section F.2.2.1.
value, err := d.decodeHuffman(&d.huff[dcTable][scan[i].td])
if err != nil {
return err
}
if value > 16 {
return UnsupportedError("excessive DC component")
}
dcDelta, err := d.receiveExtend(value)
if err != nil {
return err
}
dc[compIndex] += dcDelta
b[0] = dc[compIndex] << al
}
if zig <= zigEnd && d.eobRun > 0 {
d.eobRun--
} else {
// Decode the AC coefficients, as specified in section F.2.2.2.
huff := &d.huff[acTable][scan[i].ta]
for ; zig <= zigEnd; zig++ {
value, err := d.decodeHuffman(huff)
if err != nil {
return err
}
val0 := value >> 4
val1 := value & 0x0f
if val1 != 0 {
zig += int32(val0)
if zig > zigEnd {
break
}
ac, err := d.receiveExtend(val1)
if err != nil {
return err
}
b[unzig[zig]] = ac << al
} else {
if val0 != 0x0f {
d.eobRun = uint16(1 << val0)
if val0 != 0 {
bits, err := d.decodeBits(int32(val0))
if err != nil {
return err
}
d.eobRun |= uint16(bits)
}
d.eobRun--
break
}
zig += 0x0f
}
}
}
}
if d.progressive {
// Save the coefficients.
d.progCoeffs[compIndex][by*mxx*hi+bx] = b
// At this point, we could call reconstructBlock to dequantize and perform the
// inverse DCT, to save early stages of a progressive image to the *image.YCbCr
// buffers (the whole point of progressive encoding), but in Go, the jpeg.Decode
// function does not return until the entire image is decoded, so we "continue"
// here to avoid wasted computation. Instead, reconstructBlock is called on each
// accumulated block by the reconstructProgressiveImage method after all of the
// SOS markers are processed.
continue
}
if err := d.reconstructBlock(&b, bx, by, int(compIndex)); err != nil {
return err
}
} // for j
} // for i
mcu++
if d.ri > 0 && mcu%d.ri == 0 && mcu < mxx*myy {
// A more sophisticated decoder could use RST[0-7] markers to resynchronize from corrupt input,
// but this one assumes well-formed input, and hence the restart marker follows immediately.
if err := d.readFull(d.tmp[:2]); err != nil {
return err
}
// Section F.1.2.3 says that "Byte alignment of markers is
// achieved by padding incomplete bytes with 1-bits. If padding
// with 1-bits creates a X’FF’ value, a zero byte is stuffed
// before adding the marker."
//
// Seeing "\xff\x00" here is not spec compliant, as we are not
// expecting an *incomplete* byte (that needed padding). Still,
// some real world encoders (see golang.org/issue/28717) insert
// it, so we accept it and re-try the 2 byte read.
//
// libjpeg issues a warning (but not an error) for this:
// https://github.com/LuaDist/libjpeg/blob/6c0fcb8ddee365e7abc4d332662b06900612e923/jdmarker.c#L1041-L1046
if d.tmp[0] == 0xff && d.tmp[1] == 0x00 {
if err := d.readFull(d.tmp[:2]); err != nil {
return err
}
}
if d.tmp[0] != 0xff || d.tmp[1] != expectedRST {
return FormatError("bad RST marker")
}
expectedRST++
if expectedRST == rst7Marker+1 {
expectedRST = rst0Marker
}
// Reset the Huffman decoder.
d.bits = bits{}
// Reset the DC components, as per section F.2.1.3.1.
dc = [maxComponents]int32{}
// Reset the progressive decoder state, as per section G.1.2.2.
d.eobRun = 0
}
} // for mx
} // for my
return nil
}
// refine decodes a successive approximation refinement block, as specified in
// section G.1.2.
func (d *decoder) refine(b *block, h *huffman, zigStart, zigEnd, delta int32) error {
// Refining a DC component is trivial.
if zigStart == 0 {
if zigEnd != 0 {
panic("unreachable")
}
bit, err := d.decodeBit()
if err != nil {
return err
}
if bit {
b[0] |= delta
}
return nil
}
// Refining AC components is more complicated; see sections G.1.2.2 and G.1.2.3.
zig := zigStart
if d.eobRun == 0 {
loop:
for ; zig <= zigEnd; zig++ {
z := int32(0)
value, err := d.decodeHuffman(h)
if err != nil {
return err
}
val0 := value >> 4
val1 := value & 0x0f
switch val1 {
case 0:
if val0 != 0x0f {
d.eobRun = uint16(1 << val0)
if val0 != 0 {
bits, err := d.decodeBits(int32(val0))
if err != nil {
return err
}
d.eobRun |= uint16(bits)
}
break loop
}
case 1:
z = delta
bit, err := d.decodeBit()
if err != nil {
return err
}
if !bit {
z = -z
}
default:
return FormatError("unexpected Huffman code")
}
zig, err = d.refineNonZeroes(b, zig, zigEnd, int32(val0), delta)
if err != nil {
return err
}
if zig > zigEnd {
return FormatError("too many coefficients")
}
if z != 0 {
b[unzig[zig]] = z
}
}
}
if d.eobRun > 0 {
d.eobRun--
if _, err := d.refineNonZeroes(b, zig, zigEnd, -1, delta); err != nil {
return err
}
}
return nil
}
// refineNonZeroes refines non-zero entries of b in zig-zag order. If nz >= 0,
// the first nz zero entries are skipped over.
func (d *decoder) refineNonZeroes(b *block, zig, zigEnd, nz, delta int32) (int32, error) {
for ; zig <= zigEnd; zig++ {
u := unzig[zig]
if b[u] == 0 {
if nz == 0 {
break
}
nz--
continue
}
bit, err := d.decodeBit()
if err != nil {
return 0, err
}
if !bit {
continue
}
if b[u] >= 0 {
b[u] += delta
} else {
b[u] -= delta
}
}
return zig, nil
}
func (d *decoder) reconstructProgressiveImage() error {
// The h0, mxx, by and bx variables have the same meaning as in the
// processSOS method.
h0 := d.comp[0].h
mxx := (d.width + 8*h0 - 1) / (8 * h0)
for i := 0; i < d.nComp; i++ {
if d.progCoeffs[i] == nil {
continue
}
v := 8 * d.comp[0].v / d.comp[i].v
h := 8 * d.comp[0].h / d.comp[i].h
stride := mxx * d.comp[i].h
for by := 0; by*v < d.height; by++ {
for bx := 0; bx*h < d.width; bx++ {
if err := d.reconstructBlock(&d.progCoeffs[i][by*stride+bx], bx, by, i); err != nil {
return err
}
}
}
}
return nil
}
// reconstructBlock dequantizes, performs the inverse DCT and stores the block
// to the image.
func (d *decoder) reconstructBlock(b *block, bx, by, compIndex int) error {
qt := &d.quant[d.comp[compIndex].tq]
for zig := 0; zig < blockSize; zig++ {
b[unzig[zig]] *= qt[zig]
}
idct(b)
dst, stride := []byte(nil), 0
if d.nComp == 1 {
dst, stride = d.img1.Pix[8*(by*d.img1.Stride+bx):], d.img1.Stride
} else {
switch compIndex {
case 0:
dst, stride = d.img3.Y[8*(by*d.img3.YStride+bx):], d.img3.YStride
case 1:
dst, stride = d.img3.Cb[8*(by*d.img3.CStride+bx):], d.img3.CStride
case 2:
dst, stride = d.img3.Cr[8*(by*d.img3.CStride+bx):], d.img3.CStride
case 3:
dst, stride = d.blackPix[8*(by*d.blackStride+bx):], d.blackStride
default:
return UnsupportedError("too many components")
}
}
// Level shift by +128, clip to [0, 255], and write to dst.
for y := 0; y < 8; y++ {
y8 := y * 8
yStride := y * stride
for x := 0; x < 8; x++ {
c := b[y8+x]
if c < -128 {
c = 0
} else if c > 127 {
c = 255
} else {
c += 128
}
dst[yStride+x] = uint8(c)
}
}
return nil
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package jpeg
import (
"bufio"
"errors"
"image"
"image/color"
"io"
)
// min returns the minimum of two integers.
func min(x, y int) int {
if x < y {
return x
}
return y
}
// div returns a/b rounded to the nearest integer, instead of rounded to zero.
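// For example, div(5, 2) == 3 and div(-5, 2) == -3, whereas plain integer
// division truncates toward zero: 5/2 == 2 and -5/2 == -2.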
func div(a, b int32) int32 {
if a >= 0 {
return (a + (b >> 1)) / b
}
return -((-a + (b >> 1)) / b)
}
// bitCount counts the number of bits needed to hold an integer.
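// For example, bitCount[6] == 3, since 6 == 0b110 needs 3 bits.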
var bitCount = [256]byte{
0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4,
5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
}
type quantIndex int
const (
quantIndexLuminance quantIndex = iota
quantIndexChrominance
nQuantIndex
)
// unscaledQuant are the unscaled quantization tables in zig-zag order. Each
// encoder copies and scales the tables according to its quality parameter.
// The values are derived from section K.1 after converting from natural to
// zig-zag order.
var unscaledQuant = [nQuantIndex][blockSize]byte{
// Luminance.
{
16, 11, 12, 14, 12, 10, 16, 14,
13, 14, 18, 17, 16, 19, 24, 40,
26, 24, 22, 22, 24, 49, 35, 37,
29, 40, 58, 51, 61, 60, 57, 51,
56, 55, 64, 72, 92, 78, 64, 68,
87, 69, 55, 56, 80, 109, 81, 87,
95, 98, 103, 104, 103, 62, 77, 113,
121, 112, 100, 120, 92, 101, 103, 99,
},
// Chrominance.
{
17, 18, 18, 24, 21, 24, 47, 26,
26, 47, 99, 66, 56, 66, 99, 99,
99, 99, 99, 99, 99, 99, 99, 99,
99, 99, 99, 99, 99, 99, 99, 99,
99, 99, 99, 99, 99, 99, 99, 99,
99, 99, 99, 99, 99, 99, 99, 99,
99, 99, 99, 99, 99, 99, 99, 99,
99, 99, 99, 99, 99, 99, 99, 99,
},
}
type huffIndex int
const (
huffIndexLuminanceDC huffIndex = iota
huffIndexLuminanceAC
huffIndexChrominanceDC
huffIndexChrominanceAC
nHuffIndex
)
// huffmanSpec specifies a Huffman encoding.
type huffmanSpec struct {
// count[i] is the number of codes of length i bits.
count [16]byte
// value[i] is the decoded value of the i'th codeword.
value []byte
}
// theHuffmanSpec is the Huffman encoding specifications.
// This encoder uses the same Huffman encoding for all images.
var theHuffmanSpec = [nHuffIndex]huffmanSpec{
// Luminance DC.
{
[16]byte{0, 1, 5, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0},
[]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
},
// Luminance AC.
{
[16]byte{0, 2, 1, 3, 3, 2, 4, 3, 5, 5, 4, 4, 0, 0, 1, 125},
[]byte{
0x01, 0x02, 0x03, 0x00, 0x04, 0x11, 0x05, 0x12,
0x21, 0x31, 0x41, 0x06, 0x13, 0x51, 0x61, 0x07,
0x22, 0x71, 0x14, 0x32, 0x81, 0x91, 0xa1, 0x08,
0x23, 0x42, 0xb1, 0xc1, 0x15, 0x52, 0xd1, 0xf0,
0x24, 0x33, 0x62, 0x72, 0x82, 0x09, 0x0a, 0x16,
0x17, 0x18, 0x19, 0x1a, 0x25, 0x26, 0x27, 0x28,
0x29, 0x2a, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39,
0x3a, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49,
0x4a, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59,
0x5a, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69,
0x6a, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79,
0x7a, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89,
0x8a, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, 0x98,
0x99, 0x9a, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7,
0xa8, 0xa9, 0xaa, 0xb2, 0xb3, 0xb4, 0xb5, 0xb6,
0xb7, 0xb8, 0xb9, 0xba, 0xc2, 0xc3, 0xc4, 0xc5,
0xc6, 0xc7, 0xc8, 0xc9, 0xca, 0xd2, 0xd3, 0xd4,
0xd5, 0xd6, 0xd7, 0xd8, 0xd9, 0xda, 0xe1, 0xe2,
0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea,
0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8,
0xf9, 0xfa,
},
},
// Chrominance DC.
{
[16]byte{0, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0},
[]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
},
// Chrominance AC.
{
[16]byte{0, 2, 1, 2, 4, 4, 3, 4, 7, 5, 4, 4, 0, 1, 2, 119},
[]byte{
0x00, 0x01, 0x02, 0x03, 0x11, 0x04, 0x05, 0x21,
0x31, 0x06, 0x12, 0x41, 0x51, 0x07, 0x61, 0x71,
0x13, 0x22, 0x32, 0x81, 0x08, 0x14, 0x42, 0x91,
0xa1, 0xb1, 0xc1, 0x09, 0x23, 0x33, 0x52, 0xf0,
0x15, 0x62, 0x72, 0xd1, 0x0a, 0x16, 0x24, 0x34,
0xe1, 0x25, 0xf1, 0x17, 0x18, 0x19, 0x1a, 0x26,
0x27, 0x28, 0x29, 0x2a, 0x35, 0x36, 0x37, 0x38,
0x39, 0x3a, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48,
0x49, 0x4a, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58,
0x59, 0x5a, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68,
0x69, 0x6a, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78,
0x79, 0x7a, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87,
0x88, 0x89, 0x8a, 0x92, 0x93, 0x94, 0x95, 0x96,
0x97, 0x98, 0x99, 0x9a, 0xa2, 0xa3, 0xa4, 0xa5,
0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xb2, 0xb3, 0xb4,
0xb5, 0xb6, 0xb7, 0xb8, 0xb9, 0xba, 0xc2, 0xc3,
0xc4, 0xc5, 0xc6, 0xc7, 0xc8, 0xc9, 0xca, 0xd2,
0xd3, 0xd4, 0xd5, 0xd6, 0xd7, 0xd8, 0xd9, 0xda,
0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9,
0xea, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8,
0xf9, 0xfa,
},
},
}
// huffmanLUT is a compiled look-up table representation of a huffmanSpec.
// Each value maps to a uint32 of which the 8 most significant bits hold the
// codeword size in bits and the 24 least significant bits hold the codeword.
// The maximum codeword size is 16 bits.
type huffmanLUT []uint32
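// For example, initializing a huffmanLUT from the Luminance DC spec above
// (count = {0, 1, 5, 1, ...}) assigns the canonical codes in order: value 0
// gets the single 2-bit code "00", values 1 through 5 get the five 3-bit
// codes "010" through "110", and value 6 gets the 4-bit code "1110".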
func (h *huffmanLUT) init(s huffmanSpec) {
maxValue := 0
for _, v := range s.value {
if int(v) > maxValue {
maxValue = int(v)
}
}
*h = make([]uint32, maxValue+1)
code, k := uint32(0), 0
for i := 0; i < len(s.count); i++ {
nBits := uint32(i+1) << 24
for j := uint8(0); j < s.count[i]; j++ {
(*h)[s.value[k]] = nBits | code
code++
k++
}
code <<= 1
}
}
// theHuffmanLUT are compiled representations of theHuffmanSpec.
var theHuffmanLUT [4]huffmanLUT
func init() {
for i, s := range theHuffmanSpec {
theHuffmanLUT[i].init(s)
}
}
// writer is a buffered writer.
type writer interface {
Flush() error
io.Writer
io.ByteWriter
}
// encoder encodes an image to the JPEG format.
type encoder struct {
// w is the writer to write to. err is the first error encountered during
// writing. All attempted writes after the first error become no-ops.
w writer
err error
// buf is a scratch buffer.
buf [16]byte
// bits and nBits are accumulated bits to write to w.
bits, nBits uint32
// quant is the scaled quantization tables, in zig-zag order.
quant [nQuantIndex][blockSize]byte
}
func (e *encoder) flush() {
if e.err != nil {
return
}
e.err = e.w.Flush()
}
func (e *encoder) write(p []byte) {
if e.err != nil {
return
}
_, e.err = e.w.Write(p)
}
func (e *encoder) writeByte(b byte) {
if e.err != nil {
return
}
e.err = e.w.WriteByte(b)
}
// emit emits the least significant nBits bits of bits to the bit-stream.
// The precondition is bits < 1<<nBits && nBits <= 16.
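// For example, starting with an empty accumulator, emit(0b101, 3) merely
// buffers three bits, but a subsequent emit(0b11010, 5) completes the
// byte 0b10111010, which is written to w. (Had that byte been 0xff, a
// 0x00 byte would have been stuffed after it.)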
func (e *encoder) emit(bits, nBits uint32) {
nBits += e.nBits
bits <<= 32 - nBits
bits |= e.bits
for nBits >= 8 {
b := uint8(bits >> 24)
e.writeByte(b)
if b == 0xff {
e.writeByte(0x00)
}
bits <<= 8
nBits -= 8
}
e.bits, e.nBits = bits, nBits
}
// emitHuff emits the given value with the given Huffman encoder.
func (e *encoder) emitHuff(h huffIndex, value int32) {
x := theHuffmanLUT[h][value]
e.emit(x&(1<<24-1), x>>24)
}
// emitHuffRLE emits a run of runLength copies of value encoded with the given
// Huffman encoder.
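// For example, emitHuffRLE(h, 0, -3) emits the code for the symbol 0x02
// (run 0, magnitude 2 bits) followed by the two bits 00: negative values
// are transmitted as the low bits of value-1, here -4.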
func (e *encoder) emitHuffRLE(h huffIndex, runLength, value int32) {
a, b := value, value
if a < 0 {
a, b = -value, value-1
}
var nBits uint32
if a < 0x100 {
nBits = uint32(bitCount[a])
} else {
nBits = 8 + uint32(bitCount[a>>8])
}
e.emitHuff(h, runLength<<4|int32(nBits))
if nBits > 0 {
e.emit(uint32(b)&(1<<nBits-1), nBits)
}
}
// writeMarkerHeader writes the header for a marker with the given length.
func (e *encoder) writeMarkerHeader(marker uint8, markerlen int) {
e.buf[0] = 0xff
e.buf[1] = marker
e.buf[2] = uint8(markerlen >> 8)
e.buf[3] = uint8(markerlen & 0xff)
e.write(e.buf[:4])
}
// writeDQT writes the Define Quantization Table marker.
func (e *encoder) writeDQT() {
const markerlen = 2 + int(nQuantIndex)*(1+blockSize)
e.writeMarkerHeader(dqtMarker, markerlen)
for i := range e.quant {
e.writeByte(uint8(i))
e.write(e.quant[i][:])
}
}
// writeSOF0 writes the Start Of Frame (Baseline Sequential) marker.
func (e *encoder) writeSOF0(size image.Point, nComponent int) {
markerlen := 8 + 3*nComponent
e.writeMarkerHeader(sof0Marker, markerlen)
e.buf[0] = 8 // 8-bit color.
e.buf[1] = uint8(size.Y >> 8)
e.buf[2] = uint8(size.Y & 0xff)
e.buf[3] = uint8(size.X >> 8)
e.buf[4] = uint8(size.X & 0xff)
e.buf[5] = uint8(nComponent)
if nComponent == 1 {
e.buf[6] = 1
// No subsampling for grayscale image.
e.buf[7] = 0x11
e.buf[8] = 0x00
} else {
for i := 0; i < nComponent; i++ {
e.buf[3*i+6] = uint8(i + 1)
// We use 4:2:0 chroma subsampling.
e.buf[3*i+7] = "\x22\x11\x11"[i]
e.buf[3*i+8] = "\x00\x01\x01"[i]
}
}
e.write(e.buf[:3*(nComponent-1)+9])
}
// writeDHT writes the Define Huffman Table marker.
func (e *encoder) writeDHT(nComponent int) {
markerlen := 2
specs := theHuffmanSpec[:]
if nComponent == 1 {
// Drop the Chrominance tables.
specs = specs[:2]
}
for _, s := range specs {
markerlen += 1 + 16 + len(s.value)
}
e.writeMarkerHeader(dhtMarker, markerlen)
for i, s := range specs {
e.writeByte("\x00\x10\x01\x11"[i])
e.write(s.count[:])
e.write(s.value)
}
}
// writeBlock writes a block of pixel data using the given quantization table,
// returning the post-quantized DC value of the DCT-transformed block. b is in
// natural (not zig-zag) order.
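// The divisor passed to div below is 8 times the quantization table
// entry: the extra factor of 8 compensates for the scaling that fdct
// leaves in its outputs.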
func (e *encoder) writeBlock(b *block, q quantIndex, prevDC int32) int32 {
fdct(b)
// Emit the DC delta.
dc := div(b[0], 8*int32(e.quant[q][0]))
e.emitHuffRLE(huffIndex(2*q+0), 0, dc-prevDC)
// Emit the AC components.
h, runLength := huffIndex(2*q+1), int32(0)
for zig := 1; zig < blockSize; zig++ {
ac := div(b[unzig[zig]], 8*int32(e.quant[q][zig]))
if ac == 0 {
runLength++
} else {
for runLength > 15 {
e.emitHuff(h, 0xf0)
runLength -= 16
}
e.emitHuffRLE(h, runLength, ac)
runLength = 0
}
}
if runLength > 0 {
e.emitHuff(h, 0x00)
}
return dc
}
// toYCbCr converts the 8x8 region of m whose top-left corner is p to its
// YCbCr values.
func toYCbCr(m image.Image, p image.Point, yBlock, cbBlock, crBlock *block) {
b := m.Bounds()
xmax := b.Max.X - 1
ymax := b.Max.Y - 1
for j := 0; j < 8; j++ {
for i := 0; i < 8; i++ {
r, g, b, _ := m.At(min(p.X+i, xmax), min(p.Y+j, ymax)).RGBA()
yy, cb, cr := color.RGBToYCbCr(uint8(r>>8), uint8(g>>8), uint8(b>>8))
yBlock[8*j+i] = int32(yy)
cbBlock[8*j+i] = int32(cb)
crBlock[8*j+i] = int32(cr)
}
}
}
// grayToY stores the 8x8 region of m whose top-left corner is p in yBlock.
func grayToY(m *image.Gray, p image.Point, yBlock *block) {
b := m.Bounds()
xmax := b.Max.X - 1
ymax := b.Max.Y - 1
pix := m.Pix
for j := 0; j < 8; j++ {
for i := 0; i < 8; i++ {
idx := m.PixOffset(min(p.X+i, xmax), min(p.Y+j, ymax))
yBlock[8*j+i] = int32(pix[idx])
}
}
}
// rgbaToYCbCr is a specialized version of toYCbCr for image.RGBA images.
func rgbaToYCbCr(m *image.RGBA, p image.Point, yBlock, cbBlock, crBlock *block) {
b := m.Bounds()
xmax := b.Max.X - 1
ymax := b.Max.Y - 1
for j := 0; j < 8; j++ {
sj := p.Y + j
if sj > ymax {
sj = ymax
}
offset := (sj-b.Min.Y)*m.Stride - b.Min.X*4
for i := 0; i < 8; i++ {
sx := p.X + i
if sx > xmax {
sx = xmax
}
pix := m.Pix[offset+sx*4:]
yy, cb, cr := color.RGBToYCbCr(pix[0], pix[1], pix[2])
yBlock[8*j+i] = int32(yy)
cbBlock[8*j+i] = int32(cb)
crBlock[8*j+i] = int32(cr)
}
}
}
// yCbCrToYCbCr is a specialized version of toYCbCr for image.YCbCr images.
func yCbCrToYCbCr(m *image.YCbCr, p image.Point, yBlock, cbBlock, crBlock *block) {
b := m.Bounds()
xmax := b.Max.X - 1
ymax := b.Max.Y - 1
for j := 0; j < 8; j++ {
sy := p.Y + j
if sy > ymax {
sy = ymax
}
for i := 0; i < 8; i++ {
sx := p.X + i
if sx > xmax {
sx = xmax
}
yi := m.YOffset(sx, sy)
ci := m.COffset(sx, sy)
yBlock[8*j+i] = int32(m.Y[yi])
cbBlock[8*j+i] = int32(m.Cb[ci])
crBlock[8*j+i] = int32(m.Cr[ci])
}
}
}
// scale scales the 16x16 region represented by the 4 src blocks to the 8x8
// dst block.
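// For i = 0 through 3, dstOff is 0, 4, 32 and 36: the top-left, top-right,
// bottom-left and bottom-right 4x4 quadrants of dst. Each destination
// value is the average of a 2x2 group of source values, rounded to
// nearest by the (sum + 2) >> 2.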
func scale(dst *block, src *[4]block) {
for i := 0; i < 4; i++ {
dstOff := (i&2)<<4 | (i&1)<<2
for y := 0; y < 4; y++ {
for x := 0; x < 4; x++ {
j := 16*y + 2*x
sum := src[i][j] + src[i][j+1] + src[i][j+8] + src[i][j+9]
dst[8*y+x+dstOff] = (sum + 2) >> 2
}
}
}
}
// sosHeaderY is the SOS marker "\xff\xda" followed by 8 bytes:
// - the marker length "\x00\x08",
// - the number of components "\x01",
// - component 1 uses DC table 0 and AC table 0 "\x01\x00",
// - the bytes "\x00\x3f\x00". Section B.2.3 of the spec says that for
// sequential DCTs, those bytes (8-bit Ss, 8-bit Se, 4-bit Ah, 4-bit Al)
// should be 0x00, 0x3f, 0x00<<4 | 0x00.
var sosHeaderY = []byte{
0xff, 0xda, 0x00, 0x08, 0x01, 0x01, 0x00, 0x00, 0x3f, 0x00,
}
// sosHeaderYCbCr is the SOS marker "\xff\xda" followed by 12 bytes:
// - the marker length "\x00\x0c",
// - the number of components "\x03",
// - component 1 uses DC table 0 and AC table 0 "\x01\x00",
// - component 2 uses DC table 1 and AC table 1 "\x02\x11",
// - component 3 uses DC table 1 and AC table 1 "\x03\x11",
// - the bytes "\x00\x3f\x00". Section B.2.3 of the spec says that for
// sequential DCTs, those bytes (8-bit Ss, 8-bit Se, 4-bit Ah, 4-bit Al)
// should be 0x00, 0x3f, 0x00<<4 | 0x00.
var sosHeaderYCbCr = []byte{
0xff, 0xda, 0x00, 0x0c, 0x03, 0x01, 0x00, 0x02,
0x11, 0x03, 0x11, 0x00, 0x3f, 0x00,
}
// writeSOS writes the StartOfScan marker.
func (e *encoder) writeSOS(m image.Image) {
switch m.(type) {
case *image.Gray:
e.write(sosHeaderY)
default:
e.write(sosHeaderYCbCr)
}
var (
// Scratch buffers to hold the YCbCr values.
// The blocks are in natural (not zig-zag) order.
b block
cb, cr [4]block
// DC components are delta-encoded.
prevDCY, prevDCCb, prevDCCr int32
)
bounds := m.Bounds()
switch m := m.(type) {
// TODO(wathiede): switch on m.ColorModel() instead of type.
case *image.Gray:
for y := bounds.Min.Y; y < bounds.Max.Y; y += 8 {
for x := bounds.Min.X; x < bounds.Max.X; x += 8 {
p := image.Pt(x, y)
grayToY(m, p, &b)
prevDCY = e.writeBlock(&b, 0, prevDCY)
}
}
default:
rgba, _ := m.(*image.RGBA)
ycbcr, _ := m.(*image.YCbCr)
for y := bounds.Min.Y; y < bounds.Max.Y; y += 16 {
for x := bounds.Min.X; x < bounds.Max.X; x += 16 {
for i := 0; i < 4; i++ {
xOff := (i & 1) * 8
yOff := (i & 2) * 4
p := image.Pt(x+xOff, y+yOff)
if rgba != nil {
rgbaToYCbCr(rgba, p, &b, &cb[i], &cr[i])
} else if ycbcr != nil {
yCbCrToYCbCr(ycbcr, p, &b, &cb[i], &cr[i])
} else {
toYCbCr(m, p, &b, &cb[i], &cr[i])
}
prevDCY = e.writeBlock(&b, 0, prevDCY)
}
scale(&b, &cb)
prevDCCb = e.writeBlock(&b, 1, prevDCCb)
scale(&b, &cr)
prevDCCr = e.writeBlock(&b, 1, prevDCCr)
}
}
}
// Pad the last byte with 1's.
e.emit(0x7f, 7)
}
// DefaultQuality is the default quality encoding parameter.
const DefaultQuality = 75
// Options are the encoding parameters.
// Quality ranges from 1 to 100 inclusive, higher is better.
type Options struct {
Quality int
}
// Encode writes the Image m to w in JPEG 4:2:0 baseline format with the given
// options. Default parameters are used if a nil *Options is passed.
func Encode(w io.Writer, m image.Image, o *Options) error {
b := m.Bounds()
if b.Dx() >= 1<<16 || b.Dy() >= 1<<16 {
return errors.New("jpeg: image is too large to encode")
}
var e encoder
if ww, ok := w.(writer); ok {
e.w = ww
} else {
e.w = bufio.NewWriter(w)
}
// Clip quality to [1, 100].
quality := DefaultQuality
if o != nil {
quality = o.Quality
if quality < 1 {
quality = 1
} else if quality > 100 {
quality = 100
}
}
// Convert from a quality rating to a scaling factor.
var scale int
if quality < 50 {
scale = 5000 / quality
} else {
scale = 200 - quality*2
}
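// For example, quality 50 maps to scale 100 (the unscaled tables are used
// as-is), quality 75 maps to scale 50 (all entries halved), and quality 25
// maps to scale 200 (all entries doubled, subject to the clipping below).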
// Initialize the quantization tables.
for i := range e.quant {
for j := range e.quant[i] {
x := int(unscaledQuant[i][j])
x = (x*scale + 50) / 100
if x < 1 {
x = 1
} else if x > 255 {
x = 255
}
e.quant[i][j] = uint8(x)
}
}
// Compute number of components based on input image type.
nComponent := 3
switch m.(type) {
// TODO(wathiede): switch on m.ColorModel() instead of type.
case *image.Gray:
nComponent = 1
}
// Write the Start Of Image marker.
e.buf[0] = 0xff
e.buf[1] = 0xd8
e.write(e.buf[:2])
// Write the quantization tables.
e.writeDQT()
// Write the image dimensions.
e.writeSOF0(b.Size(), nComponent)
// Write the Huffman tables.
e.writeDHT(nComponent)
// Write the image data.
e.writeSOS(m)
// Write the End Of Image marker.
e.buf[0] = 0xff
e.buf[1] = 0xd9
e.write(e.buf[:2])
e.flush()
return e.err
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package image
import (
"image/color"
)
var (
// Black is an opaque black uniform image.
Black = NewUniform(color.Black)
// White is an opaque white uniform image.
White = NewUniform(color.White)
// Transparent is a fully transparent uniform image.
Transparent = NewUniform(color.Transparent)
// Opaque is a fully opaque uniform image.
Opaque = NewUniform(color.Opaque)
)
// Uniform is an infinite-sized Image of uniform color.
// It implements the color.Color, color.Model, and Image interfaces.
type Uniform struct {
C color.Color
}
func (c *Uniform) RGBA() (r, g, b, a uint32) {
return c.C.RGBA()
}
func (c *Uniform) ColorModel() color.Model {
return c
}
func (c *Uniform) Convert(color.Color) color.Color {
return c.C
}
func (c *Uniform) Bounds() Rectangle { return Rectangle{Point{-1e9, -1e9}, Point{1e9, 1e9}} }
func (c *Uniform) At(x, y int) color.Color { return c.C }
func (c *Uniform) RGBA64At(x, y int) color.RGBA64 {
r, g, b, a := c.C.RGBA()
return color.RGBA64{uint16(r), uint16(g), uint16(b), uint16(a)}
}
// Opaque scans the entire image and reports whether it is fully opaque.
func (c *Uniform) Opaque() bool {
_, _, _, a := c.C.RGBA()
return a == 0xffff
}
// NewUniform returns a new Uniform image of the given color.
func NewUniform(c color.Color) *Uniform {
return &Uniform{c}
}
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package png
// intSize is either 32 or 64.
const intSize = 32 << (^uint(0) >> 63)
func abs(x int) int {
// m := -1 if x < 0. m := 0 otherwise.
m := x >> (intSize - 1)
// In two's complement representation, the negative of any number
// (except the smallest one) can be computed by flipping all the
// bits and adding 1. This is faster than code with a branch.
// See Hacker's Delight, section 2-4.
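// For example, abs(-5): m is -1, x ^ m flips every bit of -5 to give 4,
// and subtracting m adds 1, yielding 5.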
return (x ^ m) - m
}
// paeth implements the Paeth filter function, as per the PNG specification.
func paeth(a, b, c uint8) uint8 {
// This is an optimized version of the sample code in the PNG spec.
// For example, the sample code starts with:
// p := int(a) + int(b) - int(c)
// pa := abs(p - int(a))
// but the optimized form uses fewer arithmetic operations:
// pa := int(b) - int(c)
// pa = abs(pa)
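// For example, paeth(10, 20, 12) has p = 10 + 20 - 12 = 18, giving
// pa = 8, pb = 2 and pc = 6; pb is the smallest, so b (20) is returned.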
pc := int(c)
pa := int(b) - pc
pb := int(a) - pc
pc = abs(pa + pb)
pa = abs(pa)
pb = abs(pb)
if pa <= pb && pa <= pc {
return a
} else if pb <= pc {
return b
}
return c
}
// filterPaeth applies the Paeth filter to the cdat slice.
// cdat is the current row's data, pdat is the previous row's data.
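// Filtering proceeds independently on each of the bytesPerPixel
// interleaved byte columns: within a column, a is the already
// reconstructed byte to the left, b is the byte above and c is the byte
// above-left.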
func filterPaeth(cdat, pdat []byte, bytesPerPixel int) {
var a, b, c, pa, pb, pc int
for i := 0; i < bytesPerPixel; i++ {
a, c = 0, 0
for j := i; j < len(cdat); j += bytesPerPixel {
b = int(pdat[j])
pa = b - c
pb = a - c
pc = abs(pa + pb)
pa = abs(pa)
pb = abs(pb)
if pa <= pb && pa <= pc {
// No-op.
} else if pb <= pc {
a = b
} else {
a = c
}
a += int(cdat[j])
a &= 0xff
cdat[j] = uint8(a)
c = b
}
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package png implements a PNG image decoder and encoder.
//
// The PNG specification is at https://www.w3.org/TR/PNG/.
package png
import (
"compress/zlib"
"encoding/binary"
"fmt"
"hash"
"hash/crc32"
"image"
"image/color"
"io"
)
// Color type, as per the PNG spec.
const (
ctGrayscale = 0
ctTrueColor = 2
ctPaletted = 3
ctGrayscaleAlpha = 4
ctTrueColorAlpha = 6
)
// A cb is a combination of color type and bit depth.
const (
cbInvalid = iota
cbG1
cbG2
cbG4
cbG8
cbGA8
cbTC8
cbP1
cbP2
cbP4
cbP8
cbTCA8
cbG16
cbGA16
cbTC16
cbTCA16
)
func cbPaletted(cb int) bool {
return cbP1 <= cb && cb <= cbP8
}
func cbTrueColor(cb int) bool {
return cb == cbTC8 || cb == cbTC16
}
// Filter type, as per the PNG spec.
const (
ftNone = 0
ftSub = 1
ftUp = 2
ftAverage = 3
ftPaeth = 4
nFilter = 5
)
// Interlace type.
const (
itNone = 0
itAdam7 = 1
)
// interlaceScan defines the placement and size of a pass for Adam7 interlacing.
type interlaceScan struct {
xFactor, yFactor, xOffset, yOffset int
}
// interlacing defines Adam7 interlacing, with 7 passes of reduced images.
// See https://www.w3.org/TR/PNG/#8Interlace
var interlacing = []interlaceScan{
{8, 8, 0, 0},
{8, 8, 4, 0},
{4, 8, 0, 4},
{4, 4, 2, 0},
{2, 4, 0, 2},
{2, 2, 1, 0},
{1, 2, 0, 1},
}
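// For example, the pass numbers for the pixels of any 8x8 tile are laid
// out as:
//
//	1 6 4 6 2 6 4 6
//	7 7 7 7 7 7 7 7
//	5 6 5 6 5 6 5 6
//	7 7 7 7 7 7 7 7
//	3 6 4 6 3 6 4 6
//	7 7 7 7 7 7 7 7
//	5 6 5 6 5 6 5 6
//	7 7 7 7 7 7 7 7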
// Decoding stage.
// The PNG specification says that the IHDR, PLTE (if present), tRNS (if
// present), IDAT and IEND chunks must appear in that order. There may be
// multiple IDAT chunks, and IDAT chunks must be sequential (i.e. they may not
// have any other chunks between them).
// https://www.w3.org/TR/PNG/#5ChunkOrdering
const (
dsStart = iota
dsSeenIHDR
dsSeenPLTE
dsSeentRNS
dsSeenIDAT
dsSeenIEND
)
const pngHeader = "\x89PNG\r\n\x1a\n"
type decoder struct {
r io.Reader
img image.Image
crc hash.Hash32
width, height int
depth int
palette color.Palette
cb int
stage int
idatLength uint32
tmp [3 * 256]byte
interlace int
// useTransparent and transparent are used for grayscale and truecolor
// transparency, as opposed to palette transparency.
useTransparent bool
transparent [6]byte
}
// A FormatError reports that the input is not a valid PNG.
type FormatError string
func (e FormatError) Error() string { return "png: invalid format: " + string(e) }
var chunkOrderError = FormatError("chunk out of order")
// An UnsupportedError reports that the input uses a valid but unimplemented PNG feature.
type UnsupportedError string
func (e UnsupportedError) Error() string { return "png: unsupported feature: " + string(e) }
func min(a, b int) int {
if a < b {
return a
}
return b
}
func (d *decoder) parseIHDR(length uint32) error {
if length != 13 {
return FormatError("bad IHDR length")
}
if _, err := io.ReadFull(d.r, d.tmp[:13]); err != nil {
return err
}
d.crc.Write(d.tmp[:13])
if d.tmp[10] != 0 {
return UnsupportedError("compression method")
}
if d.tmp[11] != 0 {
return UnsupportedError("filter method")
}
if d.tmp[12] != itNone && d.tmp[12] != itAdam7 {
return FormatError("invalid interlace method")
}
d.interlace = int(d.tmp[12])
w := int32(binary.BigEndian.Uint32(d.tmp[0:4]))
h := int32(binary.BigEndian.Uint32(d.tmp[4:8]))
if w <= 0 || h <= 0 {
return FormatError("non-positive dimension")
}
nPixels64 := int64(w) * int64(h)
nPixels := int(nPixels64)
if nPixels64 != int64(nPixels) {
return UnsupportedError("dimension overflow")
}
// There can be up to 8 bytes per pixel, for 16 bits per channel RGBA.
if nPixels != (nPixels*8)/8 {
return UnsupportedError("dimension overflow")
}
d.cb = cbInvalid
d.depth = int(d.tmp[8])
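// Only certain (bit depth, color type) combinations are valid. As per the
// PNG spec (table 11.1), bit depths 1, 2 and 4 are valid only for
// grayscale and paletted images, and bit depth 16 is valid for every
// color type except paletted. The switch below leaves d.cb as cbInvalid
// for any other combination.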
switch d.depth {
case 1:
switch d.tmp[9] {
case ctGrayscale:
d.cb = cbG1
case ctPaletted:
d.cb = cbP1
}
case 2:
switch d.tmp[9] {
case ctGrayscale:
d.cb = cbG2
case ctPaletted:
d.cb = cbP2
}
case 4:
switch d.tmp[9] {
case ctGrayscale:
d.cb = cbG4
case ctPaletted:
d.cb = cbP4
}
case 8:
switch d.tmp[9] {
case ctGrayscale:
d.cb = cbG8
case ctTrueColor:
d.cb = cbTC8
case ctPaletted:
d.cb = cbP8
case ctGrayscaleAlpha:
d.cb = cbGA8
case ctTrueColorAlpha:
d.cb = cbTCA8
}
case 16:
switch d.tmp[9] {
case ctGrayscale:
d.cb = cbG16
case ctTrueColor:
d.cb = cbTC16
case ctGrayscaleAlpha:
d.cb = cbGA16
case ctTrueColorAlpha:
d.cb = cbTCA16
}
}
if d.cb == cbInvalid {
return UnsupportedError(fmt.Sprintf("bit depth %d, color type %d", d.tmp[8], d.tmp[9]))
}
d.width, d.height = int(w), int(h)
return d.verifyChecksum()
}
func (d *decoder) parsePLTE(length uint32) error {
np := int(length / 3) // The number of palette entries.
if length%3 != 0 || np <= 0 || np > 256 || np > 1<<uint(d.depth) {
return FormatError("bad PLTE length")
}
n, err := io.ReadFull(d.r, d.tmp[:3*np])
if err != nil {
return err
}
d.crc.Write(d.tmp[:n])
switch d.cb {
case cbP1, cbP2, cbP4, cbP8:
d.palette = make(color.Palette, 256)
for i := 0; i < np; i++ {
d.palette[i] = color.RGBA{d.tmp[3*i+0], d.tmp[3*i+1], d.tmp[3*i+2], 0xff}
}
for i := np; i < 256; i++ {
// Initialize the rest of the palette to opaque black. The spec (section
// 11.2.3) says that "any out-of-range pixel value found in the image data
// is an error", but some real-world PNG files have out-of-range pixel
// values. We fall back to opaque black, the same as libpng 1.5.13;
// ImageMagick 6.5.7 returns an error.
d.palette[i] = color.RGBA{0x00, 0x00, 0x00, 0xff}
}
d.palette = d.palette[:np]
case cbTC8, cbTCA8, cbTC16, cbTCA16:
// As per the PNG spec, a PLTE chunk is optional (and for practical purposes,
// ignorable) for the ctTrueColor and ctTrueColorAlpha color types (section 4.1.2).
default:
return FormatError("PLTE, color type mismatch")
}
return d.verifyChecksum()
}
func (d *decoder) parsetRNS(length uint32) error {
switch d.cb {
case cbG1, cbG2, cbG4, cbG8, cbG16:
if length != 2 {
return FormatError("bad tRNS length")
}
n, err := io.ReadFull(d.r, d.tmp[:length])
if err != nil {
return err
}
d.crc.Write(d.tmp[:n])
copy(d.transparent[:], d.tmp[:length])
switch d.cb {
case cbG1:
d.transparent[1] *= 0xff
case cbG2:
d.transparent[1] *= 0x55
case cbG4:
d.transparent[1] *= 0x11
}
d.useTransparent = true
case cbTC8, cbTC16:
if length != 6 {
return FormatError("bad tRNS length")
}
n, err := io.ReadFull(d.r, d.tmp[:length])
if err != nil {
return err
}
d.crc.Write(d.tmp[:n])
copy(d.transparent[:], d.tmp[:length])
d.useTransparent = true
case cbP1, cbP2, cbP4, cbP8:
if length > 256 {
return FormatError("bad tRNS length")
}
n, err := io.ReadFull(d.r, d.tmp[:length])
if err != nil {
return err
}
d.crc.Write(d.tmp[:n])
if len(d.palette) < n {
d.palette = d.palette[:n]
}
for i := 0; i < n; i++ {
rgba := d.palette[i].(color.RGBA)
d.palette[i] = color.NRGBA{rgba.R, rgba.G, rgba.B, d.tmp[i]}
}
default:
return FormatError("tRNS, color type mismatch")
}
return d.verifyChecksum()
}
// Read presents one or more IDAT chunks as one continuous stream (minus the
// intermediate chunk headers and footers). If the PNG data looked like:
//
// ... len0 IDAT xxx crc0 len1 IDAT yy crc1 len2 IEND crc2
//
// then this reader presents xxxyy. For well-formed PNG data, the decoder state
// immediately before the first Read call is that d.r is positioned between the
// first IDAT and xxx, and the decoder state immediately after the last Read
// call is that d.r is positioned between yy and crc1.
func (d *decoder) Read(p []byte) (int, error) {
if len(p) == 0 {
return 0, nil
}
for d.idatLength == 0 {
// We have exhausted an IDAT chunk. Verify the checksum of that chunk.
if err := d.verifyChecksum(); err != nil {
return 0, err
}
// Read the length and chunk type of the next chunk, and check that
// it is an IDAT chunk.
if _, err := io.ReadFull(d.r, d.tmp[:8]); err != nil {
return 0, err
}
d.idatLength = binary.BigEndian.Uint32(d.tmp[:4])
if string(d.tmp[4:8]) != "IDAT" {
return 0, FormatError("not enough pixel data")
}
d.crc.Reset()
d.crc.Write(d.tmp[4:8])
}
if int(d.idatLength) < 0 {
return 0, UnsupportedError("IDAT chunk length overflow")
}
n, err := d.r.Read(p[:min(len(p), int(d.idatLength))])
d.crc.Write(p[:n])
d.idatLength -= uint32(n)
return n, err
}
// decode decodes the IDAT data into an image.
func (d *decoder) decode() (image.Image, error) {
r, err := zlib.NewReader(d)
if err != nil {
return nil, err
}
defer r.Close()
var img image.Image
if d.interlace == itNone {
img, err = d.readImagePass(r, 0, false)
if err != nil {
return nil, err
}
} else if d.interlace == itAdam7 {
// Allocate a blank image of the full size.
img, err = d.readImagePass(nil, 0, true)
if err != nil {
return nil, err
}
for pass := 0; pass < 7; pass++ {
imagePass, err := d.readImagePass(r, pass, false)
if err != nil {
return nil, err
}
if imagePass != nil {
d.mergePassInto(img, imagePass, pass)
}
}
}
// Check for EOF, to verify the zlib checksum.
n := 0
for i := 0; n == 0 && err == nil; i++ {
if i == 100 {
return nil, io.ErrNoProgress
}
n, err = r.Read(d.tmp[:1])
}
if err != nil && err != io.EOF {
return nil, FormatError(err.Error())
}
if n != 0 || d.idatLength != 0 {
return nil, FormatError("too much pixel data")
}
return img, nil
}
// readImagePass reads a single image pass, sized according to the pass number.
func (d *decoder) readImagePass(r io.Reader, pass int, allocateOnly bool) (image.Image, error) {
bitsPerPixel := 0
pixOffset := 0
var (
gray *image.Gray
rgba *image.RGBA
paletted *image.Paletted
nrgba *image.NRGBA
gray16 *image.Gray16
rgba64 *image.RGBA64
nrgba64 *image.NRGBA64
img image.Image
)
width, height := d.width, d.height
if d.interlace == itAdam7 && !allocateOnly {
p := interlacing[pass]
// Add the multiplication factor and subtract one, effectively rounding up.
width = (width - p.xOffset + p.xFactor - 1) / p.xFactor
height = (height - p.yOffset + p.yFactor - 1) / p.yFactor
// A PNG image can't have zero width or height, but for an interlaced
// image, an individual pass might have zero width or height. If so, we
// shouldn't even read a per-row filter type byte, so return early.
if width == 0 || height == 0 {
return nil, nil
}
}
switch d.cb {
case cbG1, cbG2, cbG4, cbG8:
bitsPerPixel = d.depth
if d.useTransparent {
nrgba = image.NewNRGBA(image.Rect(0, 0, width, height))
img = nrgba
} else {
gray = image.NewGray(image.Rect(0, 0, width, height))
img = gray
}
case cbGA8:
bitsPerPixel = 16
nrgba = image.NewNRGBA(image.Rect(0, 0, width, height))
img = nrgba
case cbTC8:
bitsPerPixel = 24
if d.useTransparent {
nrgba = image.NewNRGBA(image.Rect(0, 0, width, height))
img = nrgba
} else {
rgba = image.NewRGBA(image.Rect(0, 0, width, height))
img = rgba
}
case cbP1, cbP2, cbP4, cbP8:
bitsPerPixel = d.depth
paletted = image.NewPaletted(image.Rect(0, 0, width, height), d.palette)
img = paletted
case cbTCA8:
bitsPerPixel = 32
nrgba = image.NewNRGBA(image.Rect(0, 0, width, height))
img = nrgba
case cbG16:
bitsPerPixel = 16
if d.useTransparent {
nrgba64 = image.NewNRGBA64(image.Rect(0, 0, width, height))
img = nrgba64
} else {
gray16 = image.NewGray16(image.Rect(0, 0, width, height))
img = gray16
}
case cbGA16:
bitsPerPixel = 32
nrgba64 = image.NewNRGBA64(image.Rect(0, 0, width, height))
img = nrgba64
case cbTC16:
bitsPerPixel = 48
if d.useTransparent {
nrgba64 = image.NewNRGBA64(image.Rect(0, 0, width, height))
img = nrgba64
} else {
rgba64 = image.NewRGBA64(image.Rect(0, 0, width, height))
img = rgba64
}
case cbTCA16:
bitsPerPixel = 64
nrgba64 = image.NewNRGBA64(image.Rect(0, 0, width, height))
img = nrgba64
}
if allocateOnly {
return img, nil
}
bytesPerPixel := (bitsPerPixel + 7) / 8
// The +1 is for the per-row filter type, which is at cr[0].
rowSize := 1 + (int64(bitsPerPixel)*int64(width)+7)/8
if rowSize != int64(int(rowSize)) {
return nil, UnsupportedError("dimension overflow")
}
// cr and pr are the bytes for the current and previous row.
cr := make([]uint8, rowSize)
pr := make([]uint8, rowSize)
for y := 0; y < height; y++ {
// Read the decompressed bytes.
_, err := io.ReadFull(r, cr)
if err != nil {
if err == io.EOF || err == io.ErrUnexpectedEOF {
return nil, FormatError("not enough pixel data")
}
return nil, err
}
// Apply the filter.
cdat := cr[1:]
pdat := pr[1:]
switch cr[0] {
case ftNone:
// No-op.
case ftSub:
for i := bytesPerPixel; i < len(cdat); i++ {
cdat[i] += cdat[i-bytesPerPixel]
}
case ftUp:
for i, p := range pdat {
cdat[i] += p
}
case ftAverage:
// The first column has no column to the left of it, so it is a
// special case. We know that the first column exists because we
// check above that width != 0, and so len(cdat) != 0.
for i := 0; i < bytesPerPixel; i++ {
cdat[i] += pdat[i] / 2
}
for i := bytesPerPixel; i < len(cdat); i++ {
cdat[i] += uint8((int(cdat[i-bytesPerPixel]) + int(pdat[i])) / 2)
}
case ftPaeth:
filterPaeth(cdat, pdat, bytesPerPixel)
default:
return nil, FormatError("bad filter type")
}
// Convert from bytes to colors.
switch d.cb {
case cbG1:
if d.useTransparent {
ty := d.transparent[1]
for x := 0; x < width; x += 8 {
b := cdat[x/8]
for x2 := 0; x2 < 8 && x+x2 < width; x2++ {
ycol := (b >> 7) * 0xff
acol := uint8(0xff)
if ycol == ty {
acol = 0x00
}
nrgba.SetNRGBA(x+x2, y, color.NRGBA{ycol, ycol, ycol, acol})
b <<= 1
}
}
} else {
for x := 0; x < width; x += 8 {
b := cdat[x/8]
for x2 := 0; x2 < 8 && x+x2 < width; x2++ {
gray.SetGray(x+x2, y, color.Gray{(b >> 7) * 0xff})
b <<= 1
}
}
}
case cbG2:
if d.useTransparent {
ty := d.transparent[1]
for x := 0; x < width; x += 4 {
b := cdat[x/4]
for x2 := 0; x2 < 4 && x+x2 < width; x2++ {
ycol := (b >> 6) * 0x55
acol := uint8(0xff)
if ycol == ty {
acol = 0x00
}
nrgba.SetNRGBA(x+x2, y, color.NRGBA{ycol, ycol, ycol, acol})
b <<= 2
}
}
} else {
for x := 0; x < width; x += 4 {
b := cdat[x/4]
for x2 := 0; x2 < 4 && x+x2 < width; x2++ {
gray.SetGray(x+x2, y, color.Gray{(b >> 6) * 0x55})
b <<= 2
}
}
}
case cbG4:
if d.useTransparent {
ty := d.transparent[1]
for x := 0; x < width; x += 2 {
b := cdat[x/2]
for x2 := 0; x2 < 2 && x+x2 < width; x2++ {
ycol := (b >> 4) * 0x11
acol := uint8(0xff)
if ycol == ty {
acol = 0x00
}
nrgba.SetNRGBA(x+x2, y, color.NRGBA{ycol, ycol, ycol, acol})
b <<= 4
}
}
} else {
for x := 0; x < width; x += 2 {
b := cdat[x/2]
for x2 := 0; x2 < 2 && x+x2 < width; x2++ {
gray.SetGray(x+x2, y, color.Gray{(b >> 4) * 0x11})
b <<= 4
}
}
}
case cbG8:
if d.useTransparent {
ty := d.transparent[1]
for x := 0; x < width; x++ {
ycol := cdat[x]
acol := uint8(0xff)
if ycol == ty {
acol = 0x00
}
nrgba.SetNRGBA(x, y, color.NRGBA{ycol, ycol, ycol, acol})
}
} else {
copy(gray.Pix[pixOffset:], cdat)
pixOffset += gray.Stride
}
case cbGA8:
for x := 0; x < width; x++ {
ycol := cdat[2*x+0]
nrgba.SetNRGBA(x, y, color.NRGBA{ycol, ycol, ycol, cdat[2*x+1]})
}
case cbTC8:
if d.useTransparent {
pix, i, j := nrgba.Pix, pixOffset, 0
tr, tg, tb := d.transparent[1], d.transparent[3], d.transparent[5]
for x := 0; x < width; x++ {
r := cdat[j+0]
g := cdat[j+1]
b := cdat[j+2]
a := uint8(0xff)
if r == tr && g == tg && b == tb {
a = 0x00
}
pix[i+0] = r
pix[i+1] = g
pix[i+2] = b
pix[i+3] = a
i += 4
j += 3
}
pixOffset += nrgba.Stride
} else {
pix, i, j := rgba.Pix, pixOffset, 0
for x := 0; x < width; x++ {
pix[i+0] = cdat[j+0]
pix[i+1] = cdat[j+1]
pix[i+2] = cdat[j+2]
pix[i+3] = 0xff
i += 4
j += 3
}
pixOffset += rgba.Stride
}
case cbP1:
for x := 0; x < width; x += 8 {
b := cdat[x/8]
for x2 := 0; x2 < 8 && x+x2 < width; x2++ {
idx := b >> 7
if len(paletted.Palette) <= int(idx) {
paletted.Palette = paletted.Palette[:int(idx)+1]
}
paletted.SetColorIndex(x+x2, y, idx)
b <<= 1
}
}
case cbP2:
for x := 0; x < width; x += 4 {
b := cdat[x/4]
for x2 := 0; x2 < 4 && x+x2 < width; x2++ {
idx := b >> 6
if len(paletted.Palette) <= int(idx) {
paletted.Palette = paletted.Palette[:int(idx)+1]
}
paletted.SetColorIndex(x+x2, y, idx)
b <<= 2
}
}
case cbP4:
for x := 0; x < width; x += 2 {
b := cdat[x/2]
for x2 := 0; x2 < 2 && x+x2 < width; x2++ {
idx := b >> 4
if len(paletted.Palette) <= int(idx) {
paletted.Palette = paletted.Palette[:int(idx)+1]
}
paletted.SetColorIndex(x+x2, y, idx)
b <<= 4
}
}
case cbP8:
if len(paletted.Palette) != 256 {
for x := 0; x < width; x++ {
if len(paletted.Palette) <= int(cdat[x]) {
paletted.Palette = paletted.Palette[:int(cdat[x])+1]
}
}
}
copy(paletted.Pix[pixOffset:], cdat)
pixOffset += paletted.Stride
case cbTCA8:
copy(nrgba.Pix[pixOffset:], cdat)
pixOffset += nrgba.Stride
case cbG16:
if d.useTransparent {
ty := uint16(d.transparent[0])<<8 | uint16(d.transparent[1])
for x := 0; x < width; x++ {
ycol := uint16(cdat[2*x+0])<<8 | uint16(cdat[2*x+1])
acol := uint16(0xffff)
if ycol == ty {
acol = 0x0000
}
nrgba64.SetNRGBA64(x, y, color.NRGBA64{ycol, ycol, ycol, acol})
}
} else {
for x := 0; x < width; x++ {
ycol := uint16(cdat[2*x+0])<<8 | uint16(cdat[2*x+1])
gray16.SetGray16(x, y, color.Gray16{ycol})
}
}
case cbGA16:
for x := 0; x < width; x++ {
ycol := uint16(cdat[4*x+0])<<8 | uint16(cdat[4*x+1])
acol := uint16(cdat[4*x+2])<<8 | uint16(cdat[4*x+3])
nrgba64.SetNRGBA64(x, y, color.NRGBA64{ycol, ycol, ycol, acol})
}
case cbTC16:
if d.useTransparent {
tr := uint16(d.transparent[0])<<8 | uint16(d.transparent[1])
tg := uint16(d.transparent[2])<<8 | uint16(d.transparent[3])
tb := uint16(d.transparent[4])<<8 | uint16(d.transparent[5])
for x := 0; x < width; x++ {
rcol := uint16(cdat[6*x+0])<<8 | uint16(cdat[6*x+1])
gcol := uint16(cdat[6*x+2])<<8 | uint16(cdat[6*x+3])
bcol := uint16(cdat[6*x+4])<<8 | uint16(cdat[6*x+5])
acol := uint16(0xffff)
if rcol == tr && gcol == tg && bcol == tb {
acol = 0x0000
}
nrgba64.SetNRGBA64(x, y, color.NRGBA64{rcol, gcol, bcol, acol})
}
} else {
for x := 0; x < width; x++ {
rcol := uint16(cdat[6*x+0])<<8 | uint16(cdat[6*x+1])
gcol := uint16(cdat[6*x+2])<<8 | uint16(cdat[6*x+3])
bcol := uint16(cdat[6*x+4])<<8 | uint16(cdat[6*x+5])
rgba64.SetRGBA64(x, y, color.RGBA64{rcol, gcol, bcol, 0xffff})
}
}
case cbTCA16:
for x := 0; x < width; x++ {
rcol := uint16(cdat[8*x+0])<<8 | uint16(cdat[8*x+1])
gcol := uint16(cdat[8*x+2])<<8 | uint16(cdat[8*x+3])
bcol := uint16(cdat[8*x+4])<<8 | uint16(cdat[8*x+5])
acol := uint16(cdat[8*x+6])<<8 | uint16(cdat[8*x+7])
nrgba64.SetNRGBA64(x, y, color.NRGBA64{rcol, gcol, bcol, acol})
}
}
// The current row for y is the previous row for y+1.
pr, cr = cr, pr
}
return img, nil
}
// mergePassInto merges a single pass into a full sized image.
func (d *decoder) mergePassInto(dst image.Image, src image.Image, pass int) {
p := interlacing[pass]
var (
srcPix []uint8
dstPix []uint8
stride int
rect image.Rectangle
bytesPerPixel int
)
switch target := dst.(type) {
case *image.Alpha:
srcPix = src.(*image.Alpha).Pix
dstPix, stride, rect = target.Pix, target.Stride, target.Rect
bytesPerPixel = 1
case *image.Alpha16:
srcPix = src.(*image.Alpha16).Pix
dstPix, stride, rect = target.Pix, target.Stride, target.Rect
bytesPerPixel = 2
case *image.Gray:
srcPix = src.(*image.Gray).Pix
dstPix, stride, rect = target.Pix, target.Stride, target.Rect
bytesPerPixel = 1
case *image.Gray16:
srcPix = src.(*image.Gray16).Pix
dstPix, stride, rect = target.Pix, target.Stride, target.Rect
bytesPerPixel = 2
case *image.NRGBA:
srcPix = src.(*image.NRGBA).Pix
dstPix, stride, rect = target.Pix, target.Stride, target.Rect
bytesPerPixel = 4
case *image.NRGBA64:
srcPix = src.(*image.NRGBA64).Pix
dstPix, stride, rect = target.Pix, target.Stride, target.Rect
bytesPerPixel = 8
case *image.Paletted:
source := src.(*image.Paletted)
srcPix = source.Pix
dstPix, stride, rect = target.Pix, target.Stride, target.Rect
bytesPerPixel = 1
if len(target.Palette) < len(source.Palette) {
// readImagePass can return a paletted image whose implicit palette
// length (one more than the maximum Pix value) is larger than the
// explicit palette length (what's in the PLTE chunk). Make the
// same adjustment here.
target.Palette = source.Palette
}
case *image.RGBA:
srcPix = src.(*image.RGBA).Pix
dstPix, stride, rect = target.Pix, target.Stride, target.Rect
bytesPerPixel = 4
case *image.RGBA64:
srcPix = src.(*image.RGBA64).Pix
dstPix, stride, rect = target.Pix, target.Stride, target.Rect
bytesPerPixel = 8
}
s, bounds := 0, src.Bounds()
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
dBase := (y*p.yFactor+p.yOffset-rect.Min.Y)*stride + (p.xOffset-rect.Min.X)*bytesPerPixel
for x := bounds.Min.X; x < bounds.Max.X; x++ {
d := dBase + x*p.xFactor*bytesPerPixel
copy(dstPix[d:], srcPix[s:s+bytesPerPixel])
s += bytesPerPixel
}
}
}
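// As a worked instance of the index arithmetic above: the interlacing table
// holds the seven Adam7 passes, and in the first pass (xFactor = yFactor = 8,
// offsets 0,0) the source pixel at (x, y) lands at destination pixel
// (8x, 8y), i.e. at dBase = 8y*stride with d = dBase + 8x*bytesPerPixel when
// rect.Min is (0, 0).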
func (d *decoder) parseIDAT(length uint32) (err error) {
d.idatLength = length
d.img, err = d.decode()
if err != nil {
return err
}
return d.verifyChecksum()
}
func (d *decoder) parseIEND(length uint32) error {
if length != 0 {
return FormatError("bad IEND length")
}
return d.verifyChecksum()
}
func (d *decoder) parseChunk(configOnly bool) error {
// Read the length and chunk type.
if _, err := io.ReadFull(d.r, d.tmp[:8]); err != nil {
return err
}
length := binary.BigEndian.Uint32(d.tmp[:4])
d.crc.Reset()
d.crc.Write(d.tmp[4:8])
// Read the chunk data.
switch string(d.tmp[4:8]) {
case "IHDR":
if d.stage != dsStart {
return chunkOrderError
}
d.stage = dsSeenIHDR
return d.parseIHDR(length)
case "PLTE":
if d.stage != dsSeenIHDR {
return chunkOrderError
}
d.stage = dsSeenPLTE
return d.parsePLTE(length)
case "tRNS":
if cbPaletted(d.cb) {
if d.stage != dsSeenPLTE {
return chunkOrderError
}
} else if cbTrueColor(d.cb) {
if d.stage != dsSeenIHDR && d.stage != dsSeenPLTE {
return chunkOrderError
}
} else if d.stage != dsSeenIHDR {
return chunkOrderError
}
d.stage = dsSeentRNS
return d.parsetRNS(length)
case "IDAT":
if d.stage < dsSeenIHDR || d.stage > dsSeenIDAT || (d.stage == dsSeenIHDR && cbPaletted(d.cb)) {
return chunkOrderError
} else if d.stage == dsSeenIDAT {
// Ignore trailing zero-length or garbage IDAT chunks.
//
// This does not affect valid PNG images that contain multiple IDAT
// chunks, since the first call to parseIDAT below will consume all
// consecutive IDAT chunks required for decoding the image.
break
}
d.stage = dsSeenIDAT
if configOnly {
return nil
}
return d.parseIDAT(length)
case "IEND":
if d.stage != dsSeenIDAT {
return chunkOrderError
}
d.stage = dsSeenIEND
return d.parseIEND(length)
}
if length > 0x7fffffff {
return FormatError(fmt.Sprintf("Bad chunk length: %d", length))
}
// Ignore this chunk (of a known length).
var ignored [4096]byte
for length > 0 {
n, err := io.ReadFull(d.r, ignored[:min(len(ignored), int(length))])
if err != nil {
return err
}
d.crc.Write(ignored[:n])
length -= uint32(n)
}
return d.verifyChecksum()
}
func (d *decoder) verifyChecksum() error {
if _, err := io.ReadFull(d.r, d.tmp[:4]); err != nil {
return err
}
if binary.BigEndian.Uint32(d.tmp[:4]) != d.crc.Sum32() {
return FormatError("invalid checksum")
}
return nil
}
func (d *decoder) checkHeader() error {
_, err := io.ReadFull(d.r, d.tmp[:len(pngHeader)])
if err != nil {
return err
}
if string(d.tmp[:len(pngHeader)]) != pngHeader {
return FormatError("not a PNG file")
}
return nil
}
// Decode reads a PNG image from r and returns it as an image.Image.
// The type of Image returned depends on the PNG contents.
func Decode(r io.Reader) (image.Image, error) {
d := &decoder{
r: r,
crc: crc32.NewIEEE(),
}
if err := d.checkHeader(); err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return nil, err
}
for d.stage != dsSeenIEND {
if err := d.parseChunk(false); err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return nil, err
}
}
return d.img, nil
}
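// A minimal usage sketch from a hypothetical caller of this package (the
// file name and error handling are illustrative only):
//
//	f, err := os.Open("input.png")
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer f.Close()
//	img, err := png.Decode(f)
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println(img.Bounds(), img.ColorModel())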
// DecodeConfig returns the color model and dimensions of a PNG image without
// decoding the entire image.
func DecodeConfig(r io.Reader) (image.Config, error) {
d := &decoder{
r: r,
crc: crc32.NewIEEE(),
}
if err := d.checkHeader(); err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return image.Config{}, err
}
for {
if err := d.parseChunk(true); err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return image.Config{}, err
}
if cbPaletted(d.cb) {
if d.stage >= dsSeentRNS {
break
}
} else {
if d.stage >= dsSeenIHDR {
break
}
}
}
var cm color.Model
switch d.cb {
case cbG1, cbG2, cbG4, cbG8:
cm = color.GrayModel
case cbGA8:
cm = color.NRGBAModel
case cbTC8:
cm = color.RGBAModel
case cbP1, cbP2, cbP4, cbP8:
cm = d.palette
case cbTCA8:
cm = color.NRGBAModel
case cbG16:
cm = color.Gray16Model
case cbGA16:
cm = color.NRGBA64Model
case cbTC16:
cm = color.RGBA64Model
case cbTCA16:
cm = color.NRGBA64Model
}
return image.Config{
ColorModel: cm,
Width: d.width,
Height: d.height,
}, nil
}
func init() {
image.RegisterFormat("png", pngHeader, Decode, DecodeConfig)
}
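// The RegisterFormat call above means callers can decode PNGs through the
// generic image.Decode entry point after importing this package only for
// its side effects. A sketch:
//
//	import (
//		"image"
//		_ "image/png" // registers the PNG decoder
//	)
//
//	img, format, err := image.Decode(r) // format is "png" for PNG input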
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package png
import (
"bufio"
"compress/zlib"
"encoding/binary"
"hash/crc32"
"image"
"image/color"
"io"
"strconv"
)
// Encoder configures encoding PNG images.
type Encoder struct {
CompressionLevel CompressionLevel
// BufferPool optionally specifies a buffer pool to get temporary
// EncoderBuffers when encoding an image.
BufferPool EncoderBufferPool
}
// EncoderBufferPool is an interface for getting and returning temporary
// instances of the EncoderBuffer struct. This can be used to reuse buffers
// when encoding multiple images.
type EncoderBufferPool interface {
Get() *EncoderBuffer
Put(*EncoderBuffer)
}
// EncoderBuffer holds the buffers used for encoding PNG images.
type EncoderBuffer encoder
type encoder struct {
enc *Encoder
w io.Writer
m image.Image
cb int
err error
header [8]byte
footer [4]byte
tmp [4 * 256]byte
cr [nFilter][]uint8
pr []uint8
zw *zlib.Writer
zwLevel int
bw *bufio.Writer
}
// CompressionLevel indicates the compression level.
type CompressionLevel int
const (
DefaultCompression CompressionLevel = 0
NoCompression CompressionLevel = -1
BestSpeed CompressionLevel = -2
BestCompression CompressionLevel = -3
// Positive CompressionLevel values are reserved to mean a numeric zlib
// compression level, although that is not implemented yet.
)
type opaquer interface {
Opaque() bool
}
// opaque reports whether the image is fully opaque.
func opaque(m image.Image) bool {
if o, ok := m.(opaquer); ok {
return o.Opaque()
}
b := m.Bounds()
for y := b.Min.Y; y < b.Max.Y; y++ {
for x := b.Min.X; x < b.Max.X; x++ {
_, _, _, a := m.At(x, y).RGBA()
if a != 0xffff {
return false
}
}
}
return true
}
// abs8 returns the absolute value of a byte interpreted as a signed int8.
func abs8(d uint8) int {
if d < 128 {
return int(d)
}
return 256 - int(d)
}
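// For example, abs8(0x01) = 1, abs8(0xff) = 1 (0xff is int8(-1)), and
// abs8(0x80) = 128 (0x80 is int8(-128)). Treating filtered bytes as signed
// values this way rewards residuals near zero in the filter heuristic below.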
func (e *encoder) writeChunk(b []byte, name string) {
if e.err != nil {
return
}
n := uint32(len(b))
if int(n) != len(b) {
e.err = UnsupportedError(name + " chunk is too large: " + strconv.Itoa(len(b)))
return
}
binary.BigEndian.PutUint32(e.header[:4], n)
e.header[4] = name[0]
e.header[5] = name[1]
e.header[6] = name[2]
e.header[7] = name[3]
crc := crc32.NewIEEE()
crc.Write(e.header[4:8])
crc.Write(b)
binary.BigEndian.PutUint32(e.footer[:4], crc.Sum32())
_, e.err = e.w.Write(e.header[:8])
if e.err != nil {
return
}
_, e.err = e.w.Write(b)
if e.err != nil {
return
}
_, e.err = e.w.Write(e.footer[:4])
}
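// The bytes produced by writeChunk follow the PNG chunk layout: a 4-byte
// big-endian data length, the 4-byte chunk type, the data itself, and a
// 4-byte IEEE CRC-32 of the type and data. For example, the empty IEND
// chunk written by writeIEND below is exactly these 12 bytes:
//
//	00 00 00 00  49 45 4e 44  ae 42 60 82
//	(length 0)   ("IEND")     (CRC-32 of "IEND")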
func (e *encoder) writeIHDR() {
b := e.m.Bounds()
binary.BigEndian.PutUint32(e.tmp[0:4], uint32(b.Dx()))
binary.BigEndian.PutUint32(e.tmp[4:8], uint32(b.Dy()))
// Set bit depth and color type.
switch e.cb {
case cbG8:
e.tmp[8] = 8
e.tmp[9] = ctGrayscale
case cbTC8:
e.tmp[8] = 8
e.tmp[9] = ctTrueColor
case cbP8:
e.tmp[8] = 8
e.tmp[9] = ctPaletted
case cbP4:
e.tmp[8] = 4
e.tmp[9] = ctPaletted
case cbP2:
e.tmp[8] = 2
e.tmp[9] = ctPaletted
case cbP1:
e.tmp[8] = 1
e.tmp[9] = ctPaletted
case cbTCA8:
e.tmp[8] = 8
e.tmp[9] = ctTrueColorAlpha
case cbG16:
e.tmp[8] = 16
e.tmp[9] = ctGrayscale
case cbTC16:
e.tmp[8] = 16
e.tmp[9] = ctTrueColor
case cbTCA16:
e.tmp[8] = 16
e.tmp[9] = ctTrueColorAlpha
}
e.tmp[10] = 0 // default compression method
e.tmp[11] = 0 // default filter method
e.tmp[12] = 0 // non-interlaced
e.writeChunk(e.tmp[:13], "IHDR")
}
func (e *encoder) writePLTEAndTRNS(p color.Palette) {
if len(p) < 1 || len(p) > 256 {
e.err = FormatError("bad palette length: " + strconv.Itoa(len(p)))
return
}
last := -1
for i, c := range p {
c1 := color.NRGBAModel.Convert(c).(color.NRGBA)
e.tmp[3*i+0] = c1.R
e.tmp[3*i+1] = c1.G
e.tmp[3*i+2] = c1.B
if c1.A != 0xff {
last = i
}
e.tmp[3*256+i] = c1.A
}
e.writeChunk(e.tmp[:3*len(p)], "PLTE")
if last != -1 {
e.writeChunk(e.tmp[3*256:3*256+1+last], "tRNS")
}
}
// An encoder is an io.Writer that satisfies writes by writing PNG IDAT chunks,
// including an 8-byte header and 4-byte CRC checksum per Write call. Such calls
// should be relatively infrequent, since writeIDATs uses a bufio.Writer.
//
// This method should only be called from writeIDATs (via writeImage).
// No other code should treat an encoder as an io.Writer.
func (e *encoder) Write(b []byte) (int, error) {
e.writeChunk(b, "IDAT")
if e.err != nil {
return 0, e.err
}
return len(b), nil
}
// filter chooses the filter to use for encoding the current row, and applies it.
// The return value is the index of the filter and also of the row in cr that has had it applied.
func filter(cr *[nFilter][]byte, pr []byte, bpp int) int {
// We try all five filter types, and pick the one that minimizes the sum of absolute differences.
// This is the same heuristic that libpng uses, although the filters are attempted in order of
// estimated most likely to be minimal (ftUp, ftPaeth, ftNone, ftSub, ftAverage), rather than
// in their enumeration order (ftNone, ftSub, ftUp, ftAverage, ftPaeth).
cdat0 := cr[0][1:]
cdat1 := cr[1][1:]
cdat2 := cr[2][1:]
cdat3 := cr[3][1:]
cdat4 := cr[4][1:]
pdat := pr[1:]
n := len(cdat0)
// The up filter.
sum := 0
for i := 0; i < n; i++ {
cdat2[i] = cdat0[i] - pdat[i]
sum += abs8(cdat2[i])
}
best := sum
filter := ftUp
// The Paeth filter.
sum = 0
for i := 0; i < bpp; i++ {
cdat4[i] = cdat0[i] - pdat[i]
sum += abs8(cdat4[i])
}
for i := bpp; i < n; i++ {
cdat4[i] = cdat0[i] - paeth(cdat0[i-bpp], pdat[i], pdat[i-bpp])
sum += abs8(cdat4[i])
if sum >= best {
break
}
}
if sum < best {
best = sum
filter = ftPaeth
}
// The none filter.
sum = 0
for i := 0; i < n; i++ {
sum += abs8(cdat0[i])
if sum >= best {
break
}
}
if sum < best {
best = sum
filter = ftNone
}
// The sub filter.
sum = 0
for i := 0; i < bpp; i++ {
cdat1[i] = cdat0[i]
sum += abs8(cdat1[i])
}
for i := bpp; i < n; i++ {
cdat1[i] = cdat0[i] - cdat0[i-bpp]
sum += abs8(cdat1[i])
if sum >= best {
break
}
}
if sum < best {
best = sum
filter = ftSub
}
// The average filter.
sum = 0
for i := 0; i < bpp; i++ {
cdat3[i] = cdat0[i] - pdat[i]/2
sum += abs8(cdat3[i])
}
for i := bpp; i < n; i++ {
cdat3[i] = cdat0[i] - uint8((int(cdat0[i-bpp])+int(pdat[i]))/2)
sum += abs8(cdat3[i])
if sum >= best {
break
}
}
if sum < best {
filter = ftAverage
}
return filter
}
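// A worked example of the heuristic with bpp = 1, current row bytes
// [10 12 14], and previous row bytes [9 12 15]: the none filter scores
// 10+12+14 = 36; the sub filter's residuals [10 2 2] score 14; the up
// filter's residuals [1 0 -1] score 2, because -1 is stored as the byte
// 0xff and abs8 counts it as 1. The Paeth filter also scores 2 here, but a
// tie keeps the filter tried first, so the function returns ftUp and
// cr[ftUp] holds the filtered row.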
func zeroMemory(v []uint8) {
for i := range v {
v[i] = 0
}
}
func (e *encoder) writeImage(w io.Writer, m image.Image, cb int, level int) error {
if e.zw == nil || e.zwLevel != level {
zw, err := zlib.NewWriterLevel(w, level)
if err != nil {
return err
}
e.zw = zw
e.zwLevel = level
} else {
e.zw.Reset(w)
}
defer e.zw.Close()
bitsPerPixel := 0
switch cb {
case cbG8:
bitsPerPixel = 8
case cbTC8:
bitsPerPixel = 24
case cbP8:
bitsPerPixel = 8
case cbP4:
bitsPerPixel = 4
case cbP2:
bitsPerPixel = 2
case cbP1:
bitsPerPixel = 1
case cbTCA8:
bitsPerPixel = 32
case cbTC16:
bitsPerPixel = 48
case cbTCA16:
bitsPerPixel = 64
case cbG16:
bitsPerPixel = 16
}
// cr[*] and pr are the bytes for the current and previous row.
// cr[0] is unfiltered (or equivalently, filtered with the ftNone filter).
// cr[ft], for non-zero filter types ft, are buffers for transforming cr[0] under the
// other PNG filter types. These buffers are allocated once and re-used for each row.
// The +1 is for the per-row filter type, which is at cr[*][0].
b := m.Bounds()
sz := 1 + (bitsPerPixel*b.Dx()+7)/8
for i := range e.cr {
if cap(e.cr[i]) < sz {
e.cr[i] = make([]uint8, sz)
} else {
e.cr[i] = e.cr[i][:sz]
}
e.cr[i][0] = uint8(i)
}
cr := e.cr
if cap(e.pr) < sz {
e.pr = make([]uint8, sz)
} else {
e.pr = e.pr[:sz]
zeroMemory(e.pr)
}
pr := e.pr
gray, _ := m.(*image.Gray)
rgba, _ := m.(*image.RGBA)
paletted, _ := m.(*image.Paletted)
nrgba, _ := m.(*image.NRGBA)
for y := b.Min.Y; y < b.Max.Y; y++ {
// Convert from colors to bytes.
i := 1
switch cb {
case cbG8:
if gray != nil {
offset := (y - b.Min.Y) * gray.Stride
copy(cr[0][1:], gray.Pix[offset:offset+b.Dx()])
} else {
for x := b.Min.X; x < b.Max.X; x++ {
c := color.GrayModel.Convert(m.At(x, y)).(color.Gray)
cr[0][i] = c.Y
i++
}
}
case cbTC8:
// We have previously verified that the alpha value is fully opaque.
cr0 := cr[0]
stride, pix := 0, []byte(nil)
if rgba != nil {
stride, pix = rgba.Stride, rgba.Pix
} else if nrgba != nil {
stride, pix = nrgba.Stride, nrgba.Pix
}
if stride != 0 {
j0 := (y - b.Min.Y) * stride
j1 := j0 + b.Dx()*4
for j := j0; j < j1; j += 4 {
cr0[i+0] = pix[j+0]
cr0[i+1] = pix[j+1]
cr0[i+2] = pix[j+2]
i += 3
}
} else {
for x := b.Min.X; x < b.Max.X; x++ {
r, g, b, _ := m.At(x, y).RGBA()
cr0[i+0] = uint8(r >> 8)
cr0[i+1] = uint8(g >> 8)
cr0[i+2] = uint8(b >> 8)
i += 3
}
}
case cbP8:
if paletted != nil {
offset := (y - b.Min.Y) * paletted.Stride
copy(cr[0][1:], paletted.Pix[offset:offset+b.Dx()])
} else {
pi := m.(image.PalettedImage)
for x := b.Min.X; x < b.Max.X; x++ {
cr[0][i] = pi.ColorIndexAt(x, y)
i += 1
}
}
case cbP4, cbP2, cbP1:
pi := m.(image.PalettedImage)
var a uint8
var c int
pixelsPerByte := 8 / bitsPerPixel
for x := b.Min.X; x < b.Max.X; x++ {
a = a<<uint(bitsPerPixel) | pi.ColorIndexAt(x, y)
c++
if c == pixelsPerByte {
cr[0][i] = a
i += 1
a = 0
c = 0
}
}
if c != 0 {
for c != pixelsPerByte {
a = a << uint(bitsPerPixel)
c++
}
cr[0][i] = a
}
case cbTCA8:
if nrgba != nil {
offset := (y - b.Min.Y) * nrgba.Stride
copy(cr[0][1:], nrgba.Pix[offset:offset+b.Dx()*4])
} else if rgba != nil {
dst := cr[0][1:]
src := rgba.Pix[rgba.PixOffset(b.Min.X, y):rgba.PixOffset(b.Max.X, y)]
for ; len(src) >= 4; dst, src = dst[4:], src[4:] {
d := (*[4]byte)(dst)
s := (*[4]byte)(src)
if s[3] == 0x00 {
d[0] = 0
d[1] = 0
d[2] = 0
d[3] = 0
} else if s[3] == 0xff {
copy(d[:], s[:])
} else {
// This code does the same as color.NRGBAModel.Convert(
// rgba.At(x, y)).(color.NRGBA) but with no extra memory
// allocations or interface/function call overhead.
//
// The multiplier m combines 0x101 (which converts
// 8-bit color to 16-bit color) and 0xffff (which, when
// combined with the division-by-a, converts from
// alpha-premultiplied to non-alpha-premultiplied).
const m = 0x101 * 0xffff
a := uint32(s[3]) * 0x101
d[0] = uint8((uint32(s[0]) * m / a) >> 8)
d[1] = uint8((uint32(s[1]) * m / a) >> 8)
d[2] = uint8((uint32(s[2]) * m / a) >> 8)
d[3] = s[3]
}
}
} else {
// Convert from image.Image (which is alpha-premultiplied) to PNG's non-alpha-premultiplied.
for x := b.Min.X; x < b.Max.X; x++ {
c := color.NRGBAModel.Convert(m.At(x, y)).(color.NRGBA)
cr[0][i+0] = c.R
cr[0][i+1] = c.G
cr[0][i+2] = c.B
cr[0][i+3] = c.A
i += 4
}
}
case cbG16:
for x := b.Min.X; x < b.Max.X; x++ {
c := color.Gray16Model.Convert(m.At(x, y)).(color.Gray16)
cr[0][i+0] = uint8(c.Y >> 8)
cr[0][i+1] = uint8(c.Y)
i += 2
}
case cbTC16:
// We have previously verified that the alpha value is fully opaque.
for x := b.Min.X; x < b.Max.X; x++ {
r, g, b, _ := m.At(x, y).RGBA()
cr[0][i+0] = uint8(r >> 8)
cr[0][i+1] = uint8(r)
cr[0][i+2] = uint8(g >> 8)
cr[0][i+3] = uint8(g)
cr[0][i+4] = uint8(b >> 8)
cr[0][i+5] = uint8(b)
i += 6
}
case cbTCA16:
// Convert from image.Image (which is alpha-premultiplied) to PNG's non-alpha-premultiplied.
for x := b.Min.X; x < b.Max.X; x++ {
c := color.NRGBA64Model.Convert(m.At(x, y)).(color.NRGBA64)
cr[0][i+0] = uint8(c.R >> 8)
cr[0][i+1] = uint8(c.R)
cr[0][i+2] = uint8(c.G >> 8)
cr[0][i+3] = uint8(c.G)
cr[0][i+4] = uint8(c.B >> 8)
cr[0][i+5] = uint8(c.B)
cr[0][i+6] = uint8(c.A >> 8)
cr[0][i+7] = uint8(c.A)
i += 8
}
}
// Apply the filter.
// Skip the filter for NoCompression and for paletted images (cbP8, cbP4,
// cbP2, cbP1) as "filters are rarely useful on palette images" and will
// result in larger files (see http://www.libpng.org/pub/png/book/chapter09.html).
f := ftNone
if level != zlib.NoCompression && cb != cbP8 && cb != cbP4 && cb != cbP2 && cb != cbP1 {
// Since we skip paletted images we don't have to worry about
// bitsPerPixel not being a multiple of 8
bpp := bitsPerPixel / 8
f = filter(&cr, pr, bpp)
}
// Write the compressed bytes.
if _, err := e.zw.Write(cr[f]); err != nil {
return err
}
// The current row for y is the previous row for y+1.
pr, cr[0] = cr[0], pr
}
return nil
}
// writeIDATs writes the actual image data to one or more IDAT chunks.
func (e *encoder) writeIDATs() {
if e.err != nil {
return
}
if e.bw == nil {
e.bw = bufio.NewWriterSize(e, 1<<15)
} else {
e.bw.Reset(e)
}
e.err = e.writeImage(e.bw, e.m, e.cb, levelToZlib(e.enc.CompressionLevel))
if e.err != nil {
return
}
e.err = e.bw.Flush()
}
// levelToZlib converts a CompressionLevel to a zlib compression level.
// This mapping is required because we want the zero value of
// Encoder.CompressionLevel to map to zlib.DefaultCompression.
func levelToZlib(l CompressionLevel) int {
switch l {
case DefaultCompression:
return zlib.DefaultCompression
case NoCompression:
return zlib.NoCompression
case BestSpeed:
return zlib.BestSpeed
case BestCompression:
return zlib.BestCompression
default:
return zlib.DefaultCompression
}
}
func (e *encoder) writeIEND() { e.writeChunk(nil, "IEND") }
// Encode writes the Image m to w in PNG format. Any Image may be
// encoded, but images that are not image.NRGBA might be encoded lossily.
func Encode(w io.Writer, m image.Image) error {
var e Encoder
return e.Encode(w, m)
}
// Encode writes the Image m to w in PNG format.
func (enc *Encoder) Encode(w io.Writer, m image.Image) error {
// Obviously, negative widths and heights are invalid. Furthermore, the PNG
// spec section 11.2.2 says that zero is invalid. Excessively large images are
// also rejected.
mw, mh := int64(m.Bounds().Dx()), int64(m.Bounds().Dy())
if mw <= 0 || mh <= 0 || mw >= 1<<32 || mh >= 1<<32 {
return FormatError("invalid image size: " + strconv.FormatInt(mw, 10) + "x" + strconv.FormatInt(mh, 10))
}
var e *encoder
if enc.BufferPool != nil {
buffer := enc.BufferPool.Get()
e = (*encoder)(buffer)
}
if e == nil {
e = &encoder{}
}
if enc.BufferPool != nil {
defer enc.BufferPool.Put((*EncoderBuffer)(e))
}
e.enc = enc
e.w = w
e.m = m
var pal color.Palette
// cbP8 encoding needs PalettedImage's ColorIndexAt method.
if _, ok := m.(image.PalettedImage); ok {
pal, _ = m.ColorModel().(color.Palette)
}
if pal != nil {
if len(pal) <= 2 {
e.cb = cbP1
} else if len(pal) <= 4 {
e.cb = cbP2
} else if len(pal) <= 16 {
e.cb = cbP4
} else {
e.cb = cbP8
}
} else {
switch m.ColorModel() {
case color.GrayModel:
e.cb = cbG8
case color.Gray16Model:
e.cb = cbG16
case color.RGBAModel, color.NRGBAModel, color.AlphaModel:
if opaque(m) {
e.cb = cbTC8
} else {
e.cb = cbTCA8
}
default:
if opaque(m) {
e.cb = cbTC16
} else {
e.cb = cbTCA16
}
}
}
_, e.err = io.WriteString(w, pngHeader)
e.writeIHDR()
if pal != nil {
e.writePLTEAndTRNS(pal)
}
e.writeIDATs()
e.writeIEND()
return e.err
}
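// A sketch of reusing encoder buffers across many Encode calls with a
// sync.Pool-backed EncoderBufferPool. The pool type, w, and images below
// are illustrative, not part of this package; a nil Get result is fine
// because Encode falls back to allocating a fresh encoder.
//
//	type pool struct{ p sync.Pool }
//
//	func (p *pool) Get() *png.EncoderBuffer  { b, _ := p.p.Get().(*png.EncoderBuffer); return b }
//	func (p *pool) Put(b *png.EncoderBuffer) { p.p.Put(b) }
//
//	enc := &png.Encoder{CompressionLevel: png.BestSpeed, BufferPool: &pool{}}
//	for _, m := range images {
//		if err := enc.Encode(w, m); err != nil {
//			log.Fatal(err)
//		}
//	}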
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package image
import (
"image/color"
)
// YCbCrSubsampleRatio is the chroma subsample ratio used in a YCbCr image.
type YCbCrSubsampleRatio int
const (
YCbCrSubsampleRatio444 YCbCrSubsampleRatio = iota
YCbCrSubsampleRatio422
YCbCrSubsampleRatio420
YCbCrSubsampleRatio440
YCbCrSubsampleRatio411
YCbCrSubsampleRatio410
)
func (s YCbCrSubsampleRatio) String() string {
switch s {
case YCbCrSubsampleRatio444:
return "YCbCrSubsampleRatio444"
case YCbCrSubsampleRatio422:
return "YCbCrSubsampleRatio422"
case YCbCrSubsampleRatio420:
return "YCbCrSubsampleRatio420"
case YCbCrSubsampleRatio440:
return "YCbCrSubsampleRatio440"
case YCbCrSubsampleRatio411:
return "YCbCrSubsampleRatio411"
case YCbCrSubsampleRatio410:
return "YCbCrSubsampleRatio410"
}
return "YCbCrSubsampleRatioUnknown"
}
// YCbCr is an in-memory image of Y'CbCr colors. There is one Y sample per
// pixel, but each Cb and Cr sample can span one or more pixels.
// YStride is the Y slice index delta between vertically adjacent pixels.
// CStride is the Cb and Cr slice index delta between vertically adjacent pixels
// that map to separate chroma samples.
// It is not an absolute requirement, but YStride and len(Y) are typically
// multiples of 8, and:
//
// For 4:4:4, CStride == YStride/1 && len(Cb) == len(Cr) == len(Y)/1.
// For 4:2:2, CStride == YStride/2 && len(Cb) == len(Cr) == len(Y)/2.
// For 4:2:0, CStride == YStride/2 && len(Cb) == len(Cr) == len(Y)/4.
// For 4:4:0, CStride == YStride/1 && len(Cb) == len(Cr) == len(Y)/2.
// For 4:1:1, CStride == YStride/4 && len(Cb) == len(Cr) == len(Y)/4.
// For 4:1:0, CStride == YStride/4 && len(Cb) == len(Cr) == len(Y)/8.
type YCbCr struct {
Y, Cb, Cr []uint8
YStride int
CStride int
SubsampleRatio YCbCrSubsampleRatio
Rect Rectangle
}
func (p *YCbCr) ColorModel() color.Model {
return color.YCbCrModel
}
func (p *YCbCr) Bounds() Rectangle {
return p.Rect
}
func (p *YCbCr) At(x, y int) color.Color {
return p.YCbCrAt(x, y)
}
func (p *YCbCr) RGBA64At(x, y int) color.RGBA64 {
r, g, b, a := p.YCbCrAt(x, y).RGBA()
return color.RGBA64{uint16(r), uint16(g), uint16(b), uint16(a)}
}
func (p *YCbCr) YCbCrAt(x, y int) color.YCbCr {
if !(Point{x, y}.In(p.Rect)) {
return color.YCbCr{}
}
yi := p.YOffset(x, y)
ci := p.COffset(x, y)
return color.YCbCr{
p.Y[yi],
p.Cb[ci],
p.Cr[ci],
}
}
// YOffset returns the index of the first element of Y that corresponds to
// the pixel at (x, y).
func (p *YCbCr) YOffset(x, y int) int {
return (y-p.Rect.Min.Y)*p.YStride + (x - p.Rect.Min.X)
}
// COffset returns the index of the first element of Cb or Cr that corresponds
// to the pixel at (x, y).
func (p *YCbCr) COffset(x, y int) int {
switch p.SubsampleRatio {
case YCbCrSubsampleRatio422:
return (y-p.Rect.Min.Y)*p.CStride + (x/2 - p.Rect.Min.X/2)
case YCbCrSubsampleRatio420:
return (y/2-p.Rect.Min.Y/2)*p.CStride + (x/2 - p.Rect.Min.X/2)
case YCbCrSubsampleRatio440:
return (y/2-p.Rect.Min.Y/2)*p.CStride + (x - p.Rect.Min.X)
case YCbCrSubsampleRatio411:
return (y-p.Rect.Min.Y)*p.CStride + (x/4 - p.Rect.Min.X/4)
case YCbCrSubsampleRatio410:
return (y/2-p.Rect.Min.Y/2)*p.CStride + (x/4 - p.Rect.Min.X/4)
}
// Default to 4:4:4 subsampling.
return (y-p.Rect.Min.Y)*p.CStride + (x - p.Rect.Min.X)
}
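// A worked example: for a 4:2:0 image with Rect.Min = (0, 0), the four
// pixels (0,0), (1,0), (0,1) and (1,1) all map to ci = 0, sharing one Cb
// and one Cr sample, while their Y samples are at yi = 0, 1, YStride and
// YStride+1 respectively.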
// SubImage returns an image representing the portion of the image p visible
// through r. The returned value shares pixels with the original image.
func (p *YCbCr) SubImage(r Rectangle) Image {
r = r.Intersect(p.Rect)
// If r1 and r2 are Rectangles, r1.Intersect(r2) is not guaranteed to be inside
// either r1 or r2 if the intersection is empty. Without explicitly checking for
// this, the Pix[i:] expression below can panic.
if r.Empty() {
return &YCbCr{
SubsampleRatio: p.SubsampleRatio,
}
}
yi := p.YOffset(r.Min.X, r.Min.Y)
ci := p.COffset(r.Min.X, r.Min.Y)
return &YCbCr{
Y: p.Y[yi:],
Cb: p.Cb[ci:],
Cr: p.Cr[ci:],
SubsampleRatio: p.SubsampleRatio,
YStride: p.YStride,
CStride: p.CStride,
Rect: r,
}
}
func (p *YCbCr) Opaque() bool {
return true
}
func yCbCrSize(r Rectangle, subsampleRatio YCbCrSubsampleRatio) (w, h, cw, ch int) {
w, h = r.Dx(), r.Dy()
switch subsampleRatio {
case YCbCrSubsampleRatio422:
cw = (r.Max.X+1)/2 - r.Min.X/2
ch = h
case YCbCrSubsampleRatio420:
cw = (r.Max.X+1)/2 - r.Min.X/2
ch = (r.Max.Y+1)/2 - r.Min.Y/2
case YCbCrSubsampleRatio440:
cw = w
ch = (r.Max.Y+1)/2 - r.Min.Y/2
case YCbCrSubsampleRatio411:
cw = (r.Max.X+3)/4 - r.Min.X/4
ch = h
case YCbCrSubsampleRatio410:
cw = (r.Max.X+3)/4 - r.Min.X/4
ch = (r.Max.Y+1)/2 - r.Min.Y/2
default:
// Default to 4:4:4 subsampling.
cw = w
ch = h
}
return
}
// NewYCbCr returns a new YCbCr image with the given bounds and subsample
// ratio.
func NewYCbCr(r Rectangle, subsampleRatio YCbCrSubsampleRatio) *YCbCr {
w, h, cw, ch := yCbCrSize(r, subsampleRatio)
// totalLength should be the same as i2, below, for a valid Rectangle r.
totalLength := add2NonNeg(
mul3NonNeg(1, w, h),
mul3NonNeg(2, cw, ch),
)
if totalLength < 0 {
panic("image: NewYCbCr Rectangle has huge or negative dimensions")
}
i0 := w*h + 0*cw*ch
i1 := w*h + 1*cw*ch
i2 := w*h + 2*cw*ch
b := make([]byte, i2)
return &YCbCr{
Y: b[:i0:i0],
Cb: b[i0:i1:i1],
Cr: b[i1:i2:i2],
SubsampleRatio: subsampleRatio,
YStride: w,
CStride: cw,
Rect: r,
}
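// For example, NewYCbCr(Rect(0, 0, 16, 16), YCbCrSubsampleRatio420) gives
// w = h = 16 and cw = ch = 8, so the single backing allocation holds
// 16*16 + 2*8*8 = 384 bytes: len(Y) = 256 and len(Cb) = len(Cr) = 64.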
}
// NYCbCrA is an in-memory image of non-alpha-premultiplied Y'CbCr-with-alpha
// colors. A and AStride are analogous to the Y and YStride fields of the
// embedded YCbCr.
type NYCbCrA struct {
YCbCr
A []uint8
AStride int
}
func (p *NYCbCrA) ColorModel() color.Model {
return color.NYCbCrAModel
}
func (p *NYCbCrA) At(x, y int) color.Color {
return p.NYCbCrAAt(x, y)
}
func (p *NYCbCrA) RGBA64At(x, y int) color.RGBA64 {
r, g, b, a := p.NYCbCrAAt(x, y).RGBA()
return color.RGBA64{uint16(r), uint16(g), uint16(b), uint16(a)}
}
func (p *NYCbCrA) NYCbCrAAt(x, y int) color.NYCbCrA {
if !(Point{X: x, Y: y}.In(p.Rect)) {
return color.NYCbCrA{}
}
yi := p.YOffset(x, y)
ci := p.COffset(x, y)
ai := p.AOffset(x, y)
return color.NYCbCrA{
color.YCbCr{
Y: p.Y[yi],
Cb: p.Cb[ci],
Cr: p.Cr[ci],
},
p.A[ai],
}
}
// AOffset returns the index of the first element of A that corresponds to the
// pixel at (x, y).
func (p *NYCbCrA) AOffset(x, y int) int {
return (y-p.Rect.Min.Y)*p.AStride + (x - p.Rect.Min.X)
}
// SubImage returns an image representing the portion of the image p visible
// through r. The returned value shares pixels with the original image.
func (p *NYCbCrA) SubImage(r Rectangle) Image {
r = r.Intersect(p.Rect)
// If r1 and r2 are Rectangles, r1.Intersect(r2) is not guaranteed to be inside
// either r1 or r2 if the intersection is empty. Without explicitly checking for
// this, the Pix[i:] expression below can panic.
if r.Empty() {
return &NYCbCrA{
YCbCr: YCbCr{
SubsampleRatio: p.SubsampleRatio,
},
}
}
yi := p.YOffset(r.Min.X, r.Min.Y)
ci := p.COffset(r.Min.X, r.Min.Y)
ai := p.AOffset(r.Min.X, r.Min.Y)
return &NYCbCrA{
YCbCr: YCbCr{
Y: p.Y[yi:],
Cb: p.Cb[ci:],
Cr: p.Cr[ci:],
SubsampleRatio: p.SubsampleRatio,
YStride: p.YStride,
CStride: p.CStride,
Rect: r,
},
A: p.A[ai:],
AStride: p.AStride,
}
}
// Opaque scans the entire image and reports whether it is fully opaque.
func (p *NYCbCrA) Opaque() bool {
if p.Rect.Empty() {
return true
}
i0, i1 := 0, p.Rect.Dx()
for y := p.Rect.Min.Y; y < p.Rect.Max.Y; y++ {
for _, a := range p.A[i0:i1] {
if a != 0xff {
return false
}
}
i0 += p.AStride
i1 += p.AStride
}
return true
}
// NewNYCbCrA returns a new NYCbCrA image with the given bounds and subsample
// ratio.
func NewNYCbCrA(r Rectangle, subsampleRatio YCbCrSubsampleRatio) *NYCbCrA {
w, h, cw, ch := yCbCrSize(r, subsampleRatio)
// totalLength should be the same as i3, below, for a valid Rectangle r.
totalLength := add2NonNeg(
mul3NonNeg(2, w, h),
mul3NonNeg(2, cw, ch),
)
if totalLength < 0 {
panic("image: NewNYCbCrA Rectangle has huge or negative dimension")
}
i0 := 1*w*h + 0*cw*ch
i1 := 1*w*h + 1*cw*ch
i2 := 1*w*h + 2*cw*ch
i3 := 2*w*h + 2*cw*ch
b := make([]byte, i3)
return &NYCbCrA{
YCbCr: YCbCr{
Y: b[:i0:i0],
Cb: b[i0:i1:i1],
Cr: b[i1:i2:i2],
SubsampleRatio: subsampleRatio,
YStride: w,
CStride: cw,
Rect: r,
},
A: b[i2:],
AStride: w,
}
}
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Suffix array construction by induced sorting (SAIS).
// See Ge Nong, Sen Zhang, and Wai Hong Chen,
// "Two Efficient Algorithms for Linear Time Suffix Array Construction",
// especially section 3 (https://ieeexplore.ieee.org/document/5582081).
// See also http://zork.net/~st/jottings/sais.html.
//
// With optimizations inspired by Yuta Mori's sais-lite
// (https://sites.google.com/site/yuta256/sais).
//
// And with other new optimizations.
// Many of these functions are parameterized by the sizes of
// the types they operate on. The generator gen.go makes
// copies of these functions for use with other sizes.
// Specifically:
//
// - A function with a name ending in _8_32 takes []byte and []int32 arguments
// and is duplicated into _32_32, _8_64, and _64_64 forms.
// The _32_32 and _64_64 suffixes are shortened to plain _32 and _64.
// Any lines in the function body that contain the text "byte-only" or "256"
// are stripped when creating _32_32 and _64_64 forms.
// (Those lines are typically 8-bit-specific optimizations.)
//
// - A function with a name ending only in _32 operates on []int32
// and is duplicated into a _64 form. (Note that it may still take a []byte,
// but there is no need for a version of the function in which the []byte
// is widened to a full integer array.)
// The overall runtime of this code is linear in the input size:
// it runs a sequence of linear passes to reduce the problem to
// a subproblem at most half as big, invokes itself recursively,
// and then runs a sequence of linear passes to turn the answer
// for the subproblem into the answer for the original problem.
// This gives T(N) = O(N) + T(N/2) = O(N) + O(N/2) + O(N/4) + ... = O(N).
//
// The outline of the code, with the forward and backward scans
// through O(N)-sized arrays called out, is:
//
// sais_I_N
// placeLMS_I_B
// bucketMax_I_B
// freq_I_B
// <scan +text> (1)
// <scan +freq> (2)
// <scan -text, random bucket> (3)
// induceSubL_I_B
// bucketMin_I_B
// freq_I_B
// <scan +text, often optimized away> (4)
// <scan +freq> (5)
// <scan +sa, random text, random bucket> (6)
// induceSubS_I_B
// bucketMax_I_B
// freq_I_B
// <scan +text, often optimized away> (7)
// <scan +freq> (8)
// <scan -sa, random text, random bucket> (9)
// assignID_I_B
// <scan +sa, random text substrings> (10)
// map_B
// <scan -sa> (11)
// recurse_B
// (recursive call to sais_B_B for a subproblem of size at most 1/2 input, often much smaller)
// unmap_I_B
// <scan -text> (12)
// <scan +sa> (13)
// expand_I_B
// bucketMax_I_B
// freq_I_B
// <scan +text, often optimized away> (14)
// <scan +freq> (15)
// <scan -sa, random text, random bucket> (16)
// induceL_I_B
// bucketMin_I_B
// freq_I_B
// <scan +text, often optimized away> (17)
// <scan +freq> (18)
// <scan +sa, random text, random bucket> (19)
// induceS_I_B
// bucketMax_I_B
// freq_I_B
// <scan +text, often optimized away> (20)
// <scan +freq> (21)
// <scan -sa, random text, random bucket> (22)
//
// Here, _B indicates the suffix array size (_32 or _64) and _I the input size (_8 or _B).
//
// The outline shows there are in general 22 scans through
// O(N)-sized arrays for a given level of the recursion.
// In the top level, operating on 8-bit input text,
// the six freq scans are fixed size (256) instead of potentially
// input-sized. Also, the frequency is counted once and cached
// whenever there is room to do so (there is nearly always room in general,
// and always room at the top level), which eliminates all but
// the first freq_I_B text scans (that is, 5 of the 6).
// So the top level of the recursion only does 22 - 6 - 5 = 11
// input-sized scans and a typical level does 16 scans.
//
// The linear scans do not cost anywhere near as much as
// the random accesses to the text made during a few of
// the scans (specifically #6, #9, #16, #19, #22 marked above).
// In real texts, there is not much but some locality to
// the accesses, due to the repetitive structure of the text
// (the same reason Burrows-Wheeler compression is so effective).
// For random inputs, there is no locality, which makes those
// accesses even more expensive, especially once the text
// no longer fits in cache.
// For example, running on 50 MB of Go source code, induceSubL_8_32
// (which runs only once, at the top level of the recursion)
// takes 0.44s, while on 50 MB of random input, it takes 2.55s.
// Nearly all the relative slowdown is explained by the text access:
//
// c0, c1 := text[k-1], text[k]
//
// That line runs for 0.23s on the Go text and 2.02s on random text.
//go:generate go run gen.go
package suffixarray
// text_32 returns the suffix array for the input text.
// It requires that len(text) fit in an int32
// and that the caller zero sa.
func text_32(text []byte, sa []int32) {
if int(int32(len(text))) != len(text) || len(text) != len(sa) {
panic("suffixarray: misuse of text_32")
}
sais_8_32(text, 256, sa, make([]int32, 2*256))
}
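// A minimal sketch of an in-package caller (data is illustrative; make
// zeroes the slice, which satisfies the zeroed-sa requirement):
//
//	sa := make([]int32, len(data))
//	text_32(data, sa)
//	// sa now holds the starting indexes of the sorted suffixes of data.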
// sais_8_32 computes the suffix array of text.
// The text must contain only values in [0, textMax).
// The suffix array is stored in sa, which the caller
// must ensure is already zeroed.
// The caller must also provide temporary space tmp
// with len(tmp) ≥ textMax. If len(tmp) ≥ 2*textMax
// then the algorithm runs a little faster.
// If sais_8_32 modifies tmp, it sets tmp[0] = -1 on return.
func sais_8_32(text []byte, textMax int, sa, tmp []int32) {
if len(sa) != len(text) || len(tmp) < int(textMax) {
panic("suffixarray: misuse of sais_8_32")
}
// Trivial base cases. Sorting 0 or 1 things is easy.
if len(text) == 0 {
return
}
if len(text) == 1 {
sa[0] = 0
return
}
// Establish slices indexed by text character
// holding character frequency and bucket-sort offsets.
// If there's only enough tmp for one slice,
// we make it the bucket offsets and recompute
// the character frequency each time we need it.
var freq, bucket []int32
if len(tmp) >= 2*textMax {
freq, bucket = tmp[:textMax], tmp[textMax:2*textMax]
freq[0] = -1 // mark as uninitialized
} else {
freq, bucket = nil, tmp[:textMax]
}
// The SAIS algorithm.
// Each of these calls makes one scan through sa.
// See the individual functions for documentation
// about each's role in the algorithm.
numLMS := placeLMS_8_32(text, sa, freq, bucket)
if numLMS <= 1 {
// 0 or 1 items are already sorted. Do nothing.
} else {
induceSubL_8_32(text, sa, freq, bucket)
induceSubS_8_32(text, sa, freq, bucket)
length_8_32(text, sa, numLMS)
maxID := assignID_8_32(text, sa, numLMS)
if maxID < numLMS {
map_32(sa, numLMS)
recurse_32(sa, tmp, numLMS, maxID)
unmap_8_32(text, sa, numLMS)
} else {
// If maxID == numLMS, then each LMS-substring
// is unique, so the relative ordering of two LMS-suffixes
// is determined by just the leading LMS-substring.
// That is, the LMS-suffix sort order matches the
// (simpler) LMS-substring sort order.
// Copy the original LMS-substring order into the
// suffix array destination.
copy(sa, sa[len(sa)-numLMS:])
}
expand_8_32(text, freq, bucket, sa, numLMS)
}
induceL_8_32(text, sa, freq, bucket)
induceS_8_32(text, sa, freq, bucket)
// Mark for caller that we overwrote tmp.
tmp[0] = -1
}
// freq_8_32 returns the character frequencies
// for text, as a slice indexed by character value.
// If freq is nil, freq_8_32 uses and returns bucket.
// If freq is non-nil, freq_8_32 assumes that freq[0] >= 0
// means the frequencies are already computed.
// If the frequency data is overwritten or uninitialized,
// the caller must set freq[0] = -1 to force recomputation
// the next time it is needed.
func freq_8_32(text []byte, freq, bucket []int32) []int32 {
if freq != nil && freq[0] >= 0 {
return freq // already computed
}
if freq == nil {
freq = bucket
}
freq = freq[:256] // eliminate bounds check for freq[c] below
for i := range freq {
freq[i] = 0
}
for _, c := range text {
freq[c]++
}
return freq
}
// bucketMin_8_32 stores into bucket[c] the minimum index
// in the bucket for character c in a bucket-sort of text.
func bucketMin_8_32(text []byte, freq, bucket []int32) {
freq = freq_8_32(text, freq, bucket)
freq = freq[:256] // establish len(freq) = 256, so 0 ≤ i < 256 below
bucket = bucket[:256] // eliminate bounds check for bucket[i] below
total := int32(0)
for i, n := range freq {
bucket[i] = total
total += n
}
}
// bucketMax_8_32 stores into bucket[c] the maximum index
// in the bucket for character c in a bucket-sort of text.
// The bucket indexes for c are [min, max).
// That is, max is one past the final index in that bucket.
func bucketMax_8_32(text []byte, freq, bucket []int32) {
freq = freq_8_32(text, freq, bucket)
freq = freq[:256] // establish len(freq) = 256, so 0 ≤ i < 256 below
bucket = bucket[:256] // eliminate bounds check for bucket[i] below
total := int32(0)
for i, n := range freq {
total += n
bucket[i] = total
}
}
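// A worked example for text = "banana": freq_8_32 yields freq['a'] = 3,
// freq['b'] = 1, freq['n'] = 2. bucketMin_8_32 then yields bucket['a'] = 0,
// bucket['b'] = 3, bucket['n'] = 4 (the first index of each bucket), while
// bucketMax_8_32 yields bucket['a'] = 3, bucket['b'] = 4, bucket['n'] = 6
// (one past the last index of each bucket).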
// The SAIS algorithm proceeds in a sequence of scans through sa.
// Each of the following functions implements one scan,
// and the functions appear here in the order they execute in the algorithm.
// placeLMS_8_32 places into sa the indexes of the
// final characters of the LMS substrings of text,
// sorted into the rightmost ends of their correct buckets
// in the suffix array.
//
// The imaginary sentinel character at the end of the text
// is the final character of the final LMS substring, but there
// is no bucket for the imaginary sentinel character,
// which has a smaller value than any real character.
// The caller must therefore pretend that sa[-1] == len(text).
//
// The text indexes of LMS-substring characters are always ≥ 1
// (the first LMS-substring must be preceded by one or more L-type
// characters that are not part of any LMS-substring),
// so using 0 as a “not present” suffix array entry is safe,
// both in this function and in most later functions
// (until induceL_8_32 below).
func placeLMS_8_32(text []byte, sa, freq, bucket []int32) int {
bucketMax_8_32(text, freq, bucket)
numLMS := 0
lastB := int32(-1)
bucket = bucket[:256] // eliminate bounds check for bucket[c1] below
// The next stanza of code (until the blank line) loops backward
// over text, stopping to execute a code body at each position i
// such that text[i] is an L-character and text[i+1] is an S-character.
// That is, i+1 is the position of the start of an LMS-substring.
// These could be hoisted out into a function with a callback,
// but at a significant speed cost. Instead, we just write these
// seven lines a few times in this source file. The copies below
// refer back to the pattern established by this original as the
// "LMS-substring iterator".
//
// In every scan through the text, c0, c1 are successive characters of text.
// In this backward scan, c0 == text[i] and c1 == text[i+1].
// By scanning backward, we can keep track of whether the current
// position is type-S or type-L according to the usual definition:
//
// - position len(text) is type S with text[len(text)] == -1 (the sentinel)
// - position i is type S if text[i] < text[i+1], or if text[i] == text[i+1] && i+1 is type S.
// - position i is type L if text[i] > text[i+1], or if text[i] == text[i+1] && i+1 is type L.
//
// The backward scan lets us maintain the current type,
// update it when we see c0 != c1, and otherwise leave it alone.
// We want to identify all S positions with a preceding L.
// Position len(text) is one such position by definition, but we have
// nowhere to write it down, so we eliminate it by untruthfully
// setting isTypeS = false at the start of the loop.
c0, c1, isTypeS := byte(0), byte(0), false
for i := len(text) - 1; i >= 0; i-- {
c0, c1 = text[i], c0
if c0 < c1 {
isTypeS = true
} else if c0 > c1 && isTypeS {
isTypeS = false
// Bucket the index i+1 for the start of an LMS-substring.
b := bucket[c1] - 1
bucket[c1] = b
sa[b] = int32(i + 1)
lastB = b
numLMS++
}
}
// We recorded the LMS-substring starts but really want the ends.
// Luckily, with two differences, the start indexes and the end indexes are the same.
// The first difference is that the rightmost LMS-substring's end index is len(text),
// so the caller must pretend that sa[-1] == len(text), as noted above.
// The second difference is that the first leftmost LMS-substring start index
// does not end an earlier LMS-substring, so as an optimization we can omit
// that leftmost LMS-substring start index (the last one we wrote).
//
// Exception: if numLMS <= 1, the caller is not going to bother with
// the recursion at all and will treat the result as containing LMS-substring starts.
// In that case, we don't remove the final entry.
if numLMS > 1 {
sa[lastB] = 0
}
return numLMS
}
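// Continuing the "banana" example from bucketMax_8_32 above: scanning
// backward, positions 1 and 3 (each an 'a' that is S-type and preceded by
// an L-type character) are the LMS-substring starts, so numLMS = 2. Both
// are bucketed at the right end of the 'a' bucket, sa[2] = 3 and then
// sa[1] = 1, and since numLMS > 1 the last entry written (sa[1]) is
// cleared again, leaving sa = [0 0 3 0 0 0].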
// induceSubL_8_32 inserts the L-type text indexes of LMS-substrings
// into sa, assuming that the final characters of the LMS-substrings
// are already inserted into sa, sorted by final character, and at the
// right (not left) end of the corresponding character bucket.
// Each LMS-substring has the form (as a regexp) /S+L+S/:
// one or more S-type, one or more L-type, final S-type.
// induceSubL_8_32 leaves behind only the leftmost L-type text
// index for each LMS-substring. That is, it removes the final S-type
// indexes that are present on entry, and it inserts but then removes
// the interior L-type indexes too.
// (Only the leftmost L-type index is needed by induceSubS_8_32.)
func induceSubL_8_32(text []byte, sa, freq, bucket []int32) {
// Initialize positions for left side of character buckets.
bucketMin_8_32(text, freq, bucket)
bucket = bucket[:256] // eliminate bounds check for bucket[cB] below
// As we scan the array left-to-right, each sa[i] = j > 0 is a correctly
// sorted suffix array entry (for text[j:]) for which we know that j-1 is type L.
// Because j-1 is type L, inserting it into sa now will sort it correctly.
// But we want to distinguish a j-1 with j-2 of type L from type S.
// We can process the former but want to leave the latter for the caller.
// We record the difference by negating j-1 if it is preceded by type S.
// Either way, the insertion (into the text[j-1] bucket) is guaranteed to
// happen at sa[i´] for some i´ > i, that is, in the portion of sa we have
// yet to scan. A single pass therefore sees indexes j, j-1, j-2, j-3,
// and so on, in sorted but not necessarily adjacent order, until it finds
// one preceded by an index of type S, at which point it must stop.
//
// As we scan through the array, we clear the worked entries (sa[i] > 0) to zero,
// and we flip sa[i] < 0 to -sa[i], so that the loop finishes with sa containing
// only the indexes of the leftmost L-type indexes for each LMS-substring.
//
// The suffix array sa therefore serves simultaneously as input, output,
// and a miraculously well-tailored work queue.
// placeLMS_8_32 left out the implicit entry sa[-1] == len(text),
// corresponding to the identified type-L index len(text)-1.
// Process it before the left-to-right scan of sa proper.
// See body in loop for commentary.
k := len(text) - 1
c0, c1 := text[k-1], text[k]
if c0 < c1 {
k = -k
}
// Cache recently used bucket index:
// we're processing suffixes in sorted order
// and accessing buckets indexed by the
// byte before the sorted order, which still
// has very good locality.
// Invariant: b is cached, possibly dirty copy of bucket[cB].
cB := c1
b := bucket[cB]
sa[b] = int32(k)
b++
for i := 0; i < len(sa); i++ {
j := int(sa[i])
if j == 0 {
// Skip empty entry.
continue
}
if j < 0 {
// Leave discovered type-S index for caller.
sa[i] = int32(-j)
continue
}
sa[i] = 0
// Index j was on work queue, meaning k := j-1 is L-type,
// so we can now place k correctly into sa.
// If k-1 is L-type, queue k for processing later in this loop.
// If k-1 is S-type (text[k-1] < text[k]), queue -k to save for the caller.
k := j - 1
c0, c1 := text[k-1], text[k]
if c0 < c1 {
k = -k
}
if cB != c1 {
bucket[cB] = b
cB = c1
b = bucket[cB]
}
sa[b] = int32(k)
b++
}
}
// induceSubS_8_32 inserts the S-type text indexes of LMS-substrings
// into sa, assuming that the leftmost L-type text indexes are already
// inserted into sa, sorted by LMS-substring suffix, and at the
// left end of the corresponding character bucket.
// Each LMS-substring has the form (as a regexp) /S+L+S/:
// one or more S-type, one or more L-type, final S-type.
// induceSubS_8_32 leaves behind only the leftmost S-type text
// index for each LMS-substring, in sorted order, at the right end of sa.
// That is, it removes the L-type indexes that are present on entry,
// and it inserts but then removes the interior S-type indexes too,
// leaving the LMS-substring start indexes packed into sa[len(sa)-numLMS:].
// (Only the LMS-substring start indexes are processed by the recursion.)
func induceSubS_8_32(text []byte, sa, freq, bucket []int32) {
// Initialize positions for right side of character buckets.
bucketMax_8_32(text, freq, bucket)
bucket = bucket[:256] // eliminate bounds check for bucket[cB] below
// Analogous to induceSubL_8_32 above,
// as we scan the array right-to-left, each sa[i] = j > 0 is a correctly
// sorted suffix array entry (for text[j:]) for which we know that j-1 is type S.
// Because j-1 is type S, inserting it into sa now will sort it correctly.
// But we want to distinguish a j-1 with j-2 of type S from type L.
// We can process the former but want to leave the latter for the caller.
// We record the difference by negating j-1 if it is preceded by type L.
// Either way, the insertion (into the text[j-1] bucket) is guaranteed to
// happen at sa[i´] for some i´ < i, that is, in the portion of sa we have
// yet to scan. A single pass therefore sees indexes j, j-1, j-2, j-3,
// and so on, in sorted but not necessarily adjacent order, until it finds
// one preceded by an index of type L, at which point it must stop.
// That index (preceded by one of type L) is an LMS-substring start.
//
// As we scan through the array, we clear the worked entries (sa[i] > 0) to zero,
// and we flip sa[i] < 0 to -sa[i] and compact into the top of sa,
// so that the loop finishes with the top of sa containing exactly
// the LMS-substring start indexes, sorted by LMS-substring.
// Cache recently used bucket index:
cB := byte(0)
b := bucket[cB]
top := len(sa)
for i := len(sa) - 1; i >= 0; i-- {
j := int(sa[i])
if j == 0 {
// Skip empty entry.
continue
}
sa[i] = 0
if j < 0 {
// Leave discovered LMS-substring start index for caller.
top--
sa[top] = int32(-j)
continue
}
// Index j was on work queue, meaning k := j-1 is S-type,
// so we can now place k correctly into sa.
// If k-1 is S-type, queue k for processing later in this loop.
// If k-1 is L-type (text[k-1] > text[k]), queue -k to save for the caller.
k := j - 1
c1 := text[k]
c0 := text[k-1]
if c0 > c1 {
k = -k
}
if cB != c1 {
bucket[cB] = b
cB = c1
b = bucket[cB]
}
b--
sa[b] = int32(k)
}
}
// length_8_32 computes and records the length of each LMS-substring in text.
// The length of the LMS-substring at index j is stored at sa[j/2],
// avoiding the LMS-substring indexes already stored in the top half of sa.
// (If index j is an LMS-substring start, then index j-1 is type L and cannot be one.)
// There are two exceptions, made for optimizations in name_8_32 below.
//
// First, the final LMS-substring is recorded as having length 0, which is otherwise
// impossible, instead of giving it a length that includes the implicit sentinel.
// This ensures the final LMS-substring has length unequal to all others
// and therefore can be detected as different without text comparison
// (it is unequal because it is the only one that ends in the implicit sentinel,
// and the text comparison would be problematic since the implicit sentinel
// is not actually present at text[len(text)]).
//
// Second, to avoid text comparison entirely, if an LMS-substring is very short,
// sa[j/2] records its actual text instead of its length, so that if two such
// substrings have matching “length,” the text need not be read at all.
// The definition of “very short” is that the text bytes must pack into an uint32,
// and the unsigned encoding e must be ≥ len(text), so that it can be
// distinguished from a valid length.
func length_8_32(text []byte, sa []int32, numLMS int) {
end := 0 // index of current LMS-substring end (0 indicates final LMS-substring)
// The encoding of N text bytes into a “length” word
// adds 1 to each byte, packs them into the bottom
// N*8 bits of a word, and then bitwise inverts the result.
// That is, the text sequence A B C (hex 41 42 43)
// encodes as ^uint32(0x42_43_44).
// LMS-substrings can never start or end with 0xFF.
// Adding 1 ensures the encoded byte sequence never
// starts or ends with 0x00, so that present bytes can be
// distinguished from zero-padding in the top bits,
// so the length need not be separately encoded.
// Inverting the bytes increases the chance that a
// 4-byte encoding will still be ≥ len(text).
// In particular, if the first byte is ASCII (<= 0x7E, so +1 <= 0x7F)
// then the high bit of the inversion will be set,
// making it clearly not a valid length (it would be a negative one).
//
// cx holds the pre-inverted encoding (the packed incremented bytes).
cx := uint32(0) // byte-only
// This stanza (until the blank line) is the "LMS-substring iterator",
// described in placeLMS_8_32 above, with one line added to maintain cx.
c0, c1, isTypeS := byte(0), byte(0), false
for i := len(text) - 1; i >= 0; i-- {
c0, c1 = text[i], c0
cx = cx<<8 | uint32(c1+1) // byte-only
if c0 < c1 {
isTypeS = true
} else if c0 > c1 && isTypeS {
isTypeS = false
// Index j = i+1 is the start of an LMS-substring.
// Compute length or encoded text to store in sa[j/2].
j := i + 1
var code int32
if end == 0 {
code = 0
} else {
code = int32(end - j)
if code <= 32/8 && ^cx >= uint32(len(text)) { // byte-only
code = int32(^cx) // byte-only
} // byte-only
}
sa[j>>1] = code
end = j + 1
cx = uint32(c1 + 1) // byte-only
}
}
}
// assignID_8_32 assigns a dense ID numbering to the
// set of LMS-substrings respecting string ordering and equality,
// returning the maximum assigned ID.
// For example given the input "ababab", the LMS-substrings
// are "aba", "aba", and "ab", renumbered as 2 2 1.
// sa[len(sa)-numLMS:] holds the LMS-substring indexes
// sorted in string order, so to assign numbers we can
// consider each in turn, removing adjacent duplicates.
// The new ID for the LMS-substring at index j is written to sa[j/2],
// overwriting the length previously stored there (by length_8_32 above).
func assignID_8_32(text []byte, sa []int32, numLMS int) int {
id := 0
lastLen := int32(-1) // impossible
lastPos := int32(0)
for _, j := range sa[len(sa)-numLMS:] {
// Is the LMS-substring at index j new, or is it the same as the last one we saw?
n := sa[j/2]
if n != lastLen {
goto New
}
if uint32(n) >= uint32(len(text)) {
// “Length” is really encoded full text, and they match.
goto Same
}
{
// Compare actual texts.
n := int(n)
this := text[j:][:n]
last := text[lastPos:][:n]
for i := 0; i < n; i++ {
if this[i] != last[i] {
goto New
}
}
goto Same
}
New:
id++
lastPos = j
lastLen = n
Same:
sa[j/2] = int32(id)
}
return id
}
// map_32 maps the LMS-substrings in text to their new IDs,
// producing the subproblem for the recursion.
// The mapping itself was mostly applied by assignID_8_32:
// sa[i] is either 0, the ID for the LMS-substring at index 2*i,
// or the ID for the LMS-substring at index 2*i+1.
// To produce the subproblem we need only remove the zeros
// and change ID into ID-1 (our IDs start at 1, but text chars start at 0).
//
// map_32 packs the result, which is the input to the recursion,
// into the top of sa, so that the recursion result can be stored
// in the bottom of sa, which sets up for expand_8_32 well.
func map_32(sa []int32, numLMS int) {
w := len(sa)
for i := len(sa) / 2; i >= 0; i-- {
j := sa[i]
if j > 0 {
w--
sa[w] = j - 1
}
}
}
// recurse_32 calls sais_32 recursively to solve the subproblem we've built.
// The subproblem is at the right end of sa, the suffix array result will be
// written at the left end of sa, and the middle of sa is available for use as
// temporary frequency and bucket storage.
func recurse_32(sa, oldTmp []int32, numLMS, maxID int) {
dst, saTmp, text := sa[:numLMS], sa[numLMS:len(sa)-numLMS], sa[len(sa)-numLMS:]
// Set up temporary space for recursive call.
// We must pass sais_32 a tmp buffer with at least maxID entries.
//
// The subproblem is guaranteed to have length at most len(sa)/2,
// so that sa can hold both the subproblem and its suffix array.
// Nearly all the time, however, the subproblem has length < len(sa)/3,
// in which case there is a subproblem-sized middle of sa that
// we can reuse for temporary space (saTmp).
// When recurse_32 is called from sais_8_32, oldTmp is length 512
// (from text_32), and saTmp will typically be much larger, so we'll use saTmp.
// When deeper recursions come back to recurse_32, now oldTmp is
// the saTmp from the top-most recursion, it is typically larger than
// the current saTmp (because the current sa gets smaller and smaller
// as the recursion gets deeper), and we keep reusing that top-most
// large saTmp instead of the offered smaller ones.
//
// Why is the subproblem length so often just under len(sa)/3?
// See Nong, Zhang, and Chen, section 3.6 for a plausible explanation.
// In brief, the len(sa)/2 case would correspond to an SLSLSLSLSLSL pattern
// in the input, perfect alternation of larger and smaller input bytes.
// Real text doesn't do that. If each L-type index is randomly followed
// by either an L-type or S-type index, then half the substrings will
// be of the form SLS, but the other half will be longer. Of that half,
// half (a quarter overall) will be SLLS; an eighth will be SLLLS, and so on.
// Not counting the final S in each (which overlaps the first S in the next),
// this works out to an average length 2×½ + 3×¼ + 4×⅛ + ... = 3.
// The space we need is further reduced by the fact that many of the
// short patterns like SLS will often be the same character sequences
// repeated throughout the text, reducing maxID relative to numLMS.
//
// For short inputs, the averages may not run in our favor, but then we
// can often fall back to using the length-512 tmp available in the
// top-most call. (Also a short allocation would not be a big deal.)
//
// For pathological inputs, we fall back to allocating a new tmp of length
// max(maxID, numLMS/2). This level of the recursion needs maxID,
// and all deeper levels of the recursion will need no more than numLMS/2,
// so this one allocation is guaranteed to suffice for the entire stack
// of recursive calls.
tmp := oldTmp
if len(tmp) < len(saTmp) {
tmp = saTmp
}
if len(tmp) < numLMS {
// TestSAIS/forcealloc reaches this code.
n := maxID
if n < numLMS/2 {
n = numLMS / 2
}
tmp = make([]int32, n)
}
// sais_32 requires that the caller arrange to clear dst,
// because in general the caller may know dst is
// freshly-allocated and already cleared. But this one is not.
for i := range dst {
dst[i] = 0
}
sais_32(text, maxID, dst, tmp)
}
// unmap_8_32 unmaps the subproblem back to the original.
// sa[:numLMS] is the LMS-substring numbers, which don't matter much anymore.
// sa[len(sa)-numLMS:] is the sorted list of those LMS-substring numbers.
// The key part is that if the list says K, it means the K'th substring.
// We can replace sa[:numLMS] with the indexes of the LMS-substrings.
// Then if the list says K it really means sa[K].
// Having mapped the list back to LMS-substring indexes,
// we can place those into the right buckets.
func unmap_8_32(text []byte, sa []int32, numLMS int) {
unmap := sa[len(sa)-numLMS:]
j := len(unmap)
// "LMS-substring iterator" (see placeLMS_8_32 above).
c0, c1, isTypeS := byte(0), byte(0), false
for i := len(text) - 1; i >= 0; i-- {
c0, c1 = text[i], c0
if c0 < c1 {
isTypeS = true
} else if c0 > c1 && isTypeS {
isTypeS = false
// Populate inverse map.
j--
unmap[j] = int32(i + 1)
}
}
// Apply inverse map to subproblem suffix array.
sa = sa[:numLMS]
for i := 0; i < len(sa); i++ {
sa[i] = unmap[sa[i]]
}
}
// expand_8_32 distributes the compacted, sorted LMS-suffix indexes
// from sa[:numLMS] into the tops of the appropriate buckets in sa,
// preserving the sorted order and making room for the L-type indexes
// to be slotted into the sorted sequence by induceL_8_32.
func expand_8_32(text []byte, freq, bucket, sa []int32, numLMS int) {
bucketMax_8_32(text, freq, bucket)
bucket = bucket[:256] // eliminate bound check for bucket[c] below
// Loop backward through sa, always tracking
// the next index to populate from sa[:numLMS].
// When we get to one, populate it.
// Zero the rest of the slots; they have dead values in them.
x := numLMS - 1
saX := sa[x]
c := text[saX]
b := bucket[c] - 1
bucket[c] = b
for i := len(sa) - 1; i >= 0; i-- {
if i != int(b) {
sa[i] = 0
continue
}
sa[i] = saX
// Load next entry to put down (if any).
if x > 0 {
x--
saX = sa[x] // TODO bounds check
c = text[saX]
b = bucket[c] - 1
bucket[c] = b
}
}
}
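// Worked example (illustrative): for text "banana", the sorted LMS-suffix
// starts are [3 1] ("ana" sorts before "anana"). Both suffixes begin with
// 'a', whose bucket is sa[0:3], so expansion produces sa = [0 3 1 0 0 0],
// leaving the zeroed slots for induceL_8_32 to fill.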
// induceL_8_32 inserts L-type text indexes into sa,
// assuming that the leftmost S-type indexes are inserted
// into sa, in sorted order, in the right bucket halves.
// It leaves all the L-type indexes in sa, but the
// leftmost L-type indexes are negated, to mark them
// for processing by induceS_8_32.
func induceL_8_32(text []byte, sa, freq, bucket []int32) {
// Initialize positions for left side of character buckets.
bucketMin_8_32(text, freq, bucket)
bucket = bucket[:256] // eliminate bounds check for bucket[cB] below
// This scan is similar to the one in induceSubL_8_32 above.
// That one arranges to clear all but the leftmost L-type indexes.
// This scan leaves all the L-type indexes and the original S-type
// indexes, but it negates the positive leftmost L-type indexes
// (the ones that induceS_8_32 needs to process).
// expand_8_32 left out the implicit entry sa[-1] == len(text),
// corresponding to the identified type-L index len(text)-1.
// Process it before the left-to-right scan of sa proper.
// See body in loop for commentary.
k := len(text) - 1
c0, c1 := text[k-1], text[k]
if c0 < c1 {
k = -k
}
// Cache recently used bucket index.
cB := c1
b := bucket[cB]
sa[b] = int32(k)
b++
for i := 0; i < len(sa); i++ {
j := int(sa[i])
if j <= 0 {
// Skip empty or negated entry (including negated zero).
continue
}
// Index j was on work queue, meaning k := j-1 is L-type,
// so we can now place k correctly into sa.
// If k-1 is L-type, queue k for processing later in this loop.
// If k-1 is S-type (text[k-1] < text[k]), queue -k to save for the caller.
// If k is zero, k-1 doesn't exist, so we only need to leave it
// for the caller. The caller can't tell the difference between
// an empty slot and a non-empty zero, but there's no need
// to distinguish them anyway: the final suffix array will end up
// with one zero somewhere, and that will be a real zero.
k := j - 1
c1 := text[k]
if k > 0 {
if c0 := text[k-1]; c0 < c1 {
k = -k
}
}
if cB != c1 {
bucket[cB] = b
cB = c1
b = bucket[cB]
}
sa[b] = int32(k)
b++
}
}
func induceS_8_32(text []byte, sa, freq, bucket []int32) {
// Initialize positions for right side of character buckets.
bucketMax_8_32(text, freq, bucket)
bucket = bucket[:256] // eliminate bounds check for bucket[cB] below
cB := byte(0)
b := bucket[cB]
for i := len(sa) - 1; i >= 0; i-- {
j := int(sa[i])
if j >= 0 {
// Skip non-flagged entry.
// (This loop can't see an empty entry; 0 means the real zero index.)
continue
}
// Negative j is a work queue entry; rewrite to positive j for final suffix array.
j = -j
sa[i] = int32(j)
// Index j was on work queue (encoded as -j but now decoded),
// meaning k := j-1 is S-type,
// so we can now place k correctly into sa.
// If k-1 is S-type, queue -k for processing later in this loop.
// If k-1 is L-type (text[k-1] > text[k]), queue k to save for the caller.
// If k is zero, k-1 doesn't exist, so we only need to leave it
// for the caller.
k := j - 1
c1 := text[k]
if k > 0 {
if c0 := text[k-1]; c0 <= c1 {
k = -k
}
}
if cB != c1 {
bucket[cB] = b
cB = c1
b = bucket[cB]
}
b--
sa[b] = int32(k)
}
}
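// End-to-end intuition (illustrative): for text "banana", induceL_8_32
// leaves sa = [5 3 1 0 -4 -2]; the scan above flips the flagged entries
// and places the remaining S-type indexes, finishing with the suffix
// array [5 3 1 0 4 2].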
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Code generated by go generate; DO NOT EDIT.
package suffixarray
func text_64(text []byte, sa []int64) {
if int(int64(len(text))) != len(text) || len(text) != len(sa) {
panic("suffixarray: misuse of text_64")
}
sais_8_64(text, 256, sa, make([]int64, 2*256))
}
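// Illustrative use of text_64 (a sketch, mirroring the package's text_32):
//
//	sa := make([]int64, 6)
//	text_64([]byte("banana"), sa)
//	// sa is now [5 3 1 0 4 2]: the suffixes "a", "ana", "anana",
//	// "banana", "na", "nana" in sorted order.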
func sais_8_64(text []byte, textMax int, sa, tmp []int64) {
if len(sa) != len(text) || len(tmp) < textMax {
panic("suffixarray: misuse of sais_8_64")
}
// Trivial base cases. Sorting 0 or 1 things is easy.
if len(text) == 0 {
return
}
if len(text) == 1 {
sa[0] = 0
return
}
// Establish slices indexed by text character
// holding character frequency and bucket-sort offsets.
// If there's only enough tmp for one slice,
// we make it the bucket offsets and recompute
// the character frequency each time we need it.
var freq, bucket []int64
if len(tmp) >= 2*textMax {
freq, bucket = tmp[:textMax], tmp[textMax:2*textMax]
freq[0] = -1 // mark as uninitialized
} else {
freq, bucket = nil, tmp[:textMax]
}
// The SAIS algorithm.
// Each of these calls makes one scan through sa.
// See the individual functions for documentation
// about each's role in the algorithm.
numLMS := placeLMS_8_64(text, sa, freq, bucket)
if numLMS <= 1 {
// 0 or 1 items are already sorted. Do nothing.
} else {
induceSubL_8_64(text, sa, freq, bucket)
induceSubS_8_64(text, sa, freq, bucket)
length_8_64(text, sa, numLMS)
maxID := assignID_8_64(text, sa, numLMS)
if maxID < numLMS {
map_64(sa, numLMS)
recurse_64(sa, tmp, numLMS, maxID)
unmap_8_64(text, sa, numLMS)
} else {
// If maxID == numLMS, then each LMS-substring
// is unique, so the relative ordering of two LMS-suffixes
// is determined by just the leading LMS-substring.
// That is, the LMS-suffix sort order matches the
// (simpler) LMS-substring sort order.
// Copy the original LMS-substring order into the
// suffix array destination.
copy(sa, sa[len(sa)-numLMS:])
}
expand_8_64(text, freq, bucket, sa, numLMS)
}
induceL_8_64(text, sa, freq, bucket)
induceS_8_64(text, sa, freq, bucket)
// Mark for caller that we overwrote tmp.
tmp[0] = -1
}
func sais_32(text []int32, textMax int, sa, tmp []int32) {
if len(sa) != len(text) || len(tmp) < textMax {
panic("suffixarray: misuse of sais_32")
}
// Trivial base cases. Sorting 0 or 1 things is easy.
if len(text) == 0 {
return
}
if len(text) == 1 {
sa[0] = 0
return
}
// Establish slices indexed by text character
// holding character frequency and bucket-sort offsets.
// If there's only enough tmp for one slice,
// we make it the bucket offsets and recompute
// the character frequency each time we need it.
var freq, bucket []int32
if len(tmp) >= 2*textMax {
freq, bucket = tmp[:textMax], tmp[textMax:2*textMax]
freq[0] = -1 // mark as uninitialized
} else {
freq, bucket = nil, tmp[:textMax]
}
// The SAIS algorithm.
// Each of these calls makes one scan through sa.
// See the individual functions for documentation
// about each's role in the algorithm.
numLMS := placeLMS_32(text, sa, freq, bucket)
if numLMS <= 1 {
// 0 or 1 items are already sorted. Do nothing.
} else {
induceSubL_32(text, sa, freq, bucket)
induceSubS_32(text, sa, freq, bucket)
length_32(text, sa, numLMS)
maxID := assignID_32(text, sa, numLMS)
if maxID < numLMS {
map_32(sa, numLMS)
recurse_32(sa, tmp, numLMS, maxID)
unmap_32(text, sa, numLMS)
} else {
// If maxID == numLMS, then each LMS-substring
// is unique, so the relative ordering of two LMS-suffixes
// is determined by just the leading LMS-substring.
// That is, the LMS-suffix sort order matches the
// (simpler) LMS-substring sort order.
// Copy the original LMS-substring order into the
// suffix array destination.
copy(sa, sa[len(sa)-numLMS:])
}
expand_32(text, freq, bucket, sa, numLMS)
}
induceL_32(text, sa, freq, bucket)
induceS_32(text, sa, freq, bucket)
// Mark for caller that we overwrote tmp.
tmp[0] = -1
}
func sais_64(text []int64, textMax int, sa, tmp []int64) {
if len(sa) != len(text) || len(tmp) < textMax {
panic("suffixarray: misuse of sais_64")
}
// Trivial base cases. Sorting 0 or 1 things is easy.
if len(text) == 0 {
return
}
if len(text) == 1 {
sa[0] = 0
return
}
// Establish slices indexed by text character
// holding character frequency and bucket-sort offsets.
// If there's only enough tmp for one slice,
// we make it the bucket offsets and recompute
// the character frequency each time we need it.
var freq, bucket []int64
if len(tmp) >= 2*textMax {
freq, bucket = tmp[:textMax], tmp[textMax:2*textMax]
freq[0] = -1 // mark as uninitialized
} else {
freq, bucket = nil, tmp[:textMax]
}
// The SAIS algorithm.
// Each of these calls makes one scan through sa.
// See the individual functions for documentation
// about each's role in the algorithm.
numLMS := placeLMS_64(text, sa, freq, bucket)
if numLMS <= 1 {
// 0 or 1 items are already sorted. Do nothing.
} else {
induceSubL_64(text, sa, freq, bucket)
induceSubS_64(text, sa, freq, bucket)
length_64(text, sa, numLMS)
maxID := assignID_64(text, sa, numLMS)
if maxID < numLMS {
map_64(sa, numLMS)
recurse_64(sa, tmp, numLMS, maxID)
unmap_64(text, sa, numLMS)
} else {
// If maxID == numLMS, then each LMS-substring
// is unique, so the relative ordering of two LMS-suffixes
// is determined by just the leading LMS-substring.
// That is, the LMS-suffix sort order matches the
// (simpler) LMS-substring sort order.
// Copy the original LMS-substring order into the
// suffix array destination.
copy(sa, sa[len(sa)-numLMS:])
}
expand_64(text, freq, bucket, sa, numLMS)
}
induceL_64(text, sa, freq, bucket)
induceS_64(text, sa, freq, bucket)
// Mark for caller that we overwrote tmp.
tmp[0] = -1
}
func freq_8_64(text []byte, freq, bucket []int64) []int64 {
if freq != nil && freq[0] >= 0 {
return freq // already computed
}
if freq == nil {
freq = bucket
}
freq = freq[:256] // eliminate bounds check for freq[c] below
for i := range freq {
freq[i] = 0
}
for _, c := range text {
freq[c]++
}
return freq
}
func freq_32(text []int32, freq, bucket []int32) []int32 {
if freq != nil && freq[0] >= 0 {
return freq // already computed
}
if freq == nil {
freq = bucket
}
for i := range freq {
freq[i] = 0
}
for _, c := range text {
freq[c]++
}
return freq
}
func freq_64(text []int64, freq, bucket []int64) []int64 {
if freq != nil && freq[0] >= 0 {
return freq // already computed
}
if freq == nil {
freq = bucket
}
for i := range freq {
freq[i] = 0
}
for _, c := range text {
freq[c]++
}
return freq
}
func bucketMin_8_64(text []byte, freq, bucket []int64) {
freq = freq_8_64(text, freq, bucket)
freq = freq[:256] // establish len(freq) = 256, so 0 ≤ i < 256 below
bucket = bucket[:256] // eliminate bounds check for bucket[i] below
total := int64(0)
for i, n := range freq {
bucket[i] = total
total += n
}
}
func bucketMin_32(text []int32, freq, bucket []int32) {
freq = freq_32(text, freq, bucket)
total := int32(0)
for i, n := range freq {
bucket[i] = total
total += n
}
}
func bucketMin_64(text []int64, freq, bucket []int64) {
freq = freq_64(text, freq, bucket)
total := int64(0)
for i, n := range freq {
bucket[i] = total
total += n
}
}
func bucketMax_8_64(text []byte, freq, bucket []int64) {
freq = freq_8_64(text, freq, bucket)
freq = freq[:256] // establish len(freq) = 256, so 0 ≤ i < 256 below
bucket = bucket[:256] // eliminate bounds check for bucket[i] below
total := int64(0)
for i, n := range freq {
total += n
bucket[i] = total
}
}
func bucketMax_32(text []int32, freq, bucket []int32) {
freq = freq_32(text, freq, bucket)
total := int32(0)
for i, n := range freq {
total += n
bucket[i] = total
}
}
func bucketMax_64(text []int64, freq, bucket []int64) {
freq = freq_64(text, freq, bucket)
total := int64(0)
for i, n := range freq {
total += n
bucket[i] = total
}
}
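// Worked example (illustrative): for text "banana", freq is 'a':3, 'b':1,
// 'n':2. bucketMin yields the start of each character's bucket ('a'→0,
// 'b'→3, 'n'→4) and bucketMax yields one past the end ('a'→3, 'b'→4,
// 'n'→6): the 'a'-suffixes occupy sa[0:3], 'b' sa[3:4], and 'n' sa[4:6].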
func placeLMS_8_64(text []byte, sa, freq, bucket []int64) int {
bucketMax_8_64(text, freq, bucket)
numLMS := 0
lastB := int64(-1)
bucket = bucket[:256] // eliminate bounds check for bucket[c1] below
// The next stanza of code (until the blank line) loops backward
// over text, stopping to execute a code body at each position i
// such that text[i] is an L-character and text[i+1] is an S-character.
// That is, i+1 is the position of the start of an LMS-substring.
// These could be hoisted out into a function with a callback,
// but at a significant speed cost. Instead, we just write these
// seven lines a few times in this source file. The copies below
// refer back to the pattern established by this original as the
// "LMS-substring iterator".
//
// In every scan through the text, c0, c1 are successive characters of text.
// In this backward scan, c0 == text[i] and c1 == text[i+1].
// By scanning backward, we can keep track of whether the current
// position is type-S or type-L according to the usual definition:
//
// - position len(text) is type S with text[len(text)] == -1 (the sentinel)
// - position i is type S if text[i] < text[i+1], or if text[i] == text[i+1] && i+1 is type S.
// - position i is type L if text[i] > text[i+1], or if text[i] == text[i+1] && i+1 is type L.
//
// The backward scan lets us maintain the current type,
// update it when we see c0 != c1, and otherwise leave it alone.
// We want to identify all S positions with a preceding L.
// Position len(text) is one such position by definition, but we have
// nowhere to write it down, so we eliminate it by untruthfully
// setting isTypeS = false at the start of the loop.
c0, c1, isTypeS := byte(0), byte(0), false
for i := len(text) - 1; i >= 0; i-- {
c0, c1 = text[i], c0
if c0 < c1 {
isTypeS = true
} else if c0 > c1 && isTypeS {
isTypeS = false
// Bucket the index i+1 for the start of an LMS-substring.
b := bucket[c1] - 1
bucket[c1] = b
sa[b] = int64(i + 1)
lastB = b
numLMS++
}
}
// We recorded the LMS-substring starts but really want the ends.
// Luckily, with two differences, the start indexes and the end indexes are the same.
// The first difference is that the rightmost LMS-substring's end index is len(text),
// so the caller must pretend that sa[-1] == len(text), as noted above.
// The second difference is that the first leftmost LMS-substring start index
// does not end an earlier LMS-substring, so as an optimization we can omit
// that leftmost LMS-substring start index (the last one we wrote).
//
// Exception: if numLMS <= 1, the caller is not going to bother with
// the recursion at all and will treat the result as containing LMS-substring starts.
// In that case, we don't remove the final entry.
if numLMS > 1 {
sa[lastB] = 0
}
return numLMS
}
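// The S/L definition above can be made concrete with a standalone
// classifier. This is an illustrative sketch, not used by the package:
// it marks each position S or L by the same backward scan and collects
// the LMS positions (S positions preceded by an L) in decreasing order.
// The implicit LMS position len(text) (the sentinel) is omitted, as above.
func classifyTypes(text []byte) (isS []bool, lms []int) {
	n := len(text)
	isS = make([]bool, n)
	for i := n - 1; i >= 0; i-- {
		switch {
		case i == n-1:
			isS[i] = false // text[n-1] compares greater than the sentinel, so type L
		case text[i] != text[i+1]:
			isS[i] = text[i] < text[i+1]
		default:
			isS[i] = isS[i+1] // equal characters inherit the type to their right
		}
		if i+1 < n && isS[i+1] && !isS[i] {
			lms = append(lms, i+1)
		}
	}
	return isS, lms
}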
func placeLMS_32(text []int32, sa, freq, bucket []int32) int {
bucketMax_32(text, freq, bucket)
numLMS := 0
lastB := int32(-1)
// The next stanza of code (until the blank line) loops backward
// over text, stopping to execute a code body at each position i
// such that text[i] is an L-character and text[i+1] is an S-character.
// That is, i+1 is the position of the start of an LMS-substring.
// These could be hoisted out into a function with a callback,
// but at a significant speed cost. Instead, we just write these
// seven lines a few times in this source file. The copies below
// refer back to the pattern established by this original as the
// "LMS-substring iterator".
//
// In every scan through the text, c0, c1 are successive characters of text.
// In this backward scan, c0 == text[i] and c1 == text[i+1].
// By scanning backward, we can keep track of whether the current
// position is type-S or type-L according to the usual definition:
//
// - position len(text) is type S with text[len(text)] == -1 (the sentinel)
// - position i is type S if text[i] < text[i+1], or if text[i] == text[i+1] && i+1 is type S.
// - position i is type L if text[i] > text[i+1], or if text[i] == text[i+1] && i+1 is type L.
//
// The backward scan lets us maintain the current type,
// update it when we see c0 != c1, and otherwise leave it alone.
// We want to identify all S positions with a preceding L.
// Position len(text) is one such position by definition, but we have
// nowhere to write it down, so we eliminate it by untruthfully
// setting isTypeS = false at the start of the loop.
c0, c1, isTypeS := int32(0), int32(0), false
for i := len(text) - 1; i >= 0; i-- {
c0, c1 = text[i], c0
if c0 < c1 {
isTypeS = true
} else if c0 > c1 && isTypeS {
isTypeS = false
// Bucket the index i+1 for the start of an LMS-substring.
b := bucket[c1] - 1
bucket[c1] = b
sa[b] = int32(i + 1)
lastB = b
numLMS++
}
}
// We recorded the LMS-substring starts but really want the ends.
// Luckily, with two differences, the start indexes and the end indexes are the same.
// The first difference is that the rightmost LMS-substring's end index is len(text),
// so the caller must pretend that sa[-1] == len(text), as noted above.
// The second difference is that the first leftmost LMS-substring start index
// does not end an earlier LMS-substring, so as an optimization we can omit
// that leftmost LMS-substring start index (the last one we wrote).
//
// Exception: if numLMS <= 1, the caller is not going to bother with
// the recursion at all and will treat the result as containing LMS-substring starts.
// In that case, we don't remove the final entry.
if numLMS > 1 {
sa[lastB] = 0
}
return numLMS
}
func placeLMS_64(text []int64, sa, freq, bucket []int64) int {
bucketMax_64(text, freq, bucket)
numLMS := 0
lastB := int64(-1)
// The next stanza of code (until the blank line) loops backward
// over text, stopping to execute a code body at each position i
// such that text[i] is an L-character and text[i+1] is an S-character.
// That is, i+1 is the position of the start of an LMS-substring.
// These could be hoisted out into a function with a callback,
// but at a significant speed cost. Instead, we just write these
// seven lines a few times in this source file. The copies below
// refer back to the pattern established by this original as the
// "LMS-substring iterator".
//
// In every scan through the text, c0, c1 are successive characters of text.
// In this backward scan, c0 == text[i] and c1 == text[i+1].
// By scanning backward, we can keep track of whether the current
// position is type-S or type-L according to the usual definition:
//
// - position len(text) is type S with text[len(text)] == -1 (the sentinel)
// - position i is type S if text[i] < text[i+1], or if text[i] == text[i+1] && i+1 is type S.
// - position i is type L if text[i] > text[i+1], or if text[i] == text[i+1] && i+1 is type L.
//
// The backward scan lets us maintain the current type,
// update it when we see c0 != c1, and otherwise leave it alone.
// We want to identify all S positions with a preceding L.
// Position len(text) is one such position by definition, but we have
// nowhere to write it down, so we eliminate it by untruthfully
// setting isTypeS = false at the start of the loop.
c0, c1, isTypeS := int64(0), int64(0), false
for i := len(text) - 1; i >= 0; i-- {
c0, c1 = text[i], c0
if c0 < c1 {
isTypeS = true
} else if c0 > c1 && isTypeS {
isTypeS = false
// Bucket the index i+1 for the start of an LMS-substring.
b := bucket[c1] - 1
bucket[c1] = b
sa[b] = int64(i + 1)
lastB = b
numLMS++
}
}
// We recorded the LMS-substring starts but really want the ends.
// Luckily, with two differences, the start indexes and the end indexes are the same.
// The first difference is that the rightmost LMS-substring's end index is len(text),
// so the caller must pretend that sa[-1] == len(text), as noted above.
// The second difference is that the first leftmost LMS-substring start index
// does not end an earlier LMS-substring, so as an optimization we can omit
// that leftmost LMS-substring start index (the last one we wrote).
//
// Exception: if numLMS <= 1, the caller is not going to bother with
// the recursion at all and will treat the result as containing LMS-substring starts.
// In that case, we don't remove the final entry.
if numLMS > 1 {
sa[lastB] = 0
}
return numLMS
}
func induceSubL_8_64(text []byte, sa, freq, bucket []int64) {
// Initialize positions for left side of character buckets.
bucketMin_8_64(text, freq, bucket)
bucket = bucket[:256] // eliminate bounds check for bucket[cB] below
// As we scan the array left-to-right, each sa[i] = j > 0 is a correctly
// sorted suffix array entry (for text[j:]) for which we know that j-1 is type L.
// Because j-1 is type L, inserting it into sa now will sort it correctly.
// But we want to distinguish a j-1 with j-2 of type L from type S.
// We can process the former but want to leave the latter for the caller.
// We record the difference by negating j-1 if it is preceded by type S.
// Either way, the insertion (into the text[j-1] bucket) is guaranteed to
// happen at sa[i´] for some i´ > i, that is, in the portion of sa we have
// yet to scan. A single pass therefore sees indexes j, j-1, j-2, j-3,
// and so on, in sorted but not necessarily adjacent order, until it finds
// one preceded by an index of type S, at which point it must stop.
//
// As we scan through the array, we clear the worked entries (sa[i] > 0) to zero,
// and we flip sa[i] < 0 to -sa[i], so that the loop finishes with sa containing
// only the indexes of the leftmost L-type indexes for each LMS-substring.
//
// The suffix array sa therefore serves simultaneously as input, output,
// and a miraculously well-tailored work queue.
// placeLMS_8_64 left out the implicit entry sa[-1] == len(text),
// corresponding to the identified type-L index len(text)-1.
// Process it before the left-to-right scan of sa proper.
// See body in loop for commentary.
k := len(text) - 1
c0, c1 := text[k-1], text[k]
if c0 < c1 {
k = -k
}
// Cache recently used bucket index:
// we're processing suffixes in sorted order
// and accessing buckets indexed by the
// byte preceding each sorted suffix, which still
// has very good locality.
// Invariant: b is cached, possibly dirty copy of bucket[cB].
cB := c1
b := bucket[cB]
sa[b] = int64(k)
b++
for i := 0; i < len(sa); i++ {
j := int(sa[i])
if j == 0 {
// Skip empty entry.
continue
}
if j < 0 {
// Leave discovered type-S index for caller.
sa[i] = int64(-j)
continue
}
sa[i] = 0
// Index j was on work queue, meaning k := j-1 is L-type,
// so we can now place k correctly into sa.
// If k-1 is L-type, queue k for processing later in this loop.
// If k-1 is S-type (text[k-1] < text[k]), queue -k to save for the caller.
k := j - 1
c0, c1 := text[k-1], text[k]
if c0 < c1 {
k = -k
}
if cB != c1 {
bucket[cB] = b
cB = c1
b = bucket[cB]
}
sa[b] = int64(k)
b++
}
}
func induceSubL_32(text []int32, sa, freq, bucket []int32) {
// Initialize positions for left side of character buckets.
bucketMin_32(text, freq, bucket)
// As we scan the array left-to-right, each sa[i] = j > 0 is a correctly
// sorted suffix array entry (for text[j:]) for which we know that j-1 is type L.
// Because j-1 is type L, inserting it into sa now will sort it correctly.
// But we want to distinguish a j-1 with j-2 of type L from type S.
// We can process the former but want to leave the latter for the caller.
// We record the difference by negating j-1 if it is preceded by type S.
// Either way, the insertion (into the text[j-1] bucket) is guaranteed to
// happen at sa[i´] for some i´ > i, that is, in the portion of sa we have
// yet to scan. A single pass therefore sees indexes j, j-1, j-2, j-3,
// and so on, in sorted but not necessarily adjacent order, until it finds
// one preceded by an index of type S, at which point it must stop.
//
// As we scan through the array, we clear the worked entries (sa[i] > 0) to zero,
// and we flip sa[i] < 0 to -sa[i], so that the loop finishes with sa containing
// only the indexes of the leftmost L-type indexes for each LMS-substring.
//
// The suffix array sa therefore serves simultaneously as input, output,
// and a miraculously well-tailored work queue.
// placeLMS_32 left out the implicit entry sa[-1] == len(text),
// corresponding to the identified type-L index len(text)-1.
// Process it before the left-to-right scan of sa proper.
// See body in loop for commentary.
k := len(text) - 1
c0, c1 := text[k-1], text[k]
if c0 < c1 {
k = -k
}
// Cache recently used bucket index:
// we're processing suffixes in sorted order
// and accessing buckets indexed by the
// int32 preceding each sorted suffix, which still
// has very good locality.
// Invariant: b is cached, possibly dirty copy of bucket[cB].
cB := c1
b := bucket[cB]
sa[b] = int32(k)
b++
for i := 0; i < len(sa); i++ {
j := int(sa[i])
if j == 0 {
// Skip empty entry.
continue
}
if j < 0 {
// Leave discovered type-S index for caller.
sa[i] = int32(-j)
continue
}
sa[i] = 0
// Index j was on work queue, meaning k := j-1 is L-type,
// so we can now place k correctly into sa.
// If k-1 is L-type, queue k for processing later in this loop.
// If k-1 is S-type (text[k-1] < text[k]), queue -k to save for the caller.
k := j - 1
c0, c1 := text[k-1], text[k]
if c0 < c1 {
k = -k
}
if cB != c1 {
bucket[cB] = b
cB = c1
b = bucket[cB]
}
sa[b] = int32(k)
b++
}
}
func induceSubL_64(text []int64, sa, freq, bucket []int64) {
// Initialize positions for left side of character buckets.
bucketMin_64(text, freq, bucket)
// As we scan the array left-to-right, each sa[i] = j > 0 is a correctly
// sorted suffix array entry (for text[j:]) for which we know that j-1 is type L.
// Because j-1 is type L, inserting it into sa now will sort it correctly.
// But we want to distinguish a j-1 with j-2 of type L from type S.
// We can process the former but want to leave the latter for the caller.
// We record the difference by negating j-1 if it is preceded by type S.
// Either way, the insertion (into the text[j-1] bucket) is guaranteed to
// happen at sa[i´] for some i´ > i, that is, in the portion of sa we have
// yet to scan. A single pass therefore sees indexes j, j-1, j-2, j-3,
// and so on, in sorted but not necessarily adjacent order, until it finds
// one preceded by an index of type S, at which point it must stop.
//
// As we scan through the array, we clear the worked entries (sa[i] > 0) to zero,
// and we flip sa[i] < 0 to -sa[i], so that the loop finishes with sa containing
// only the indexes of the leftmost L-type indexes for each LMS-substring.
//
// The suffix array sa therefore serves simultaneously as input, output,
// and a miraculously well-tailored work queue.
// placeLMS_64 left out the implicit entry sa[-1] == len(text),
// corresponding to the identified type-L index len(text)-1.
// Process it before the left-to-right scan of sa proper.
// See body in loop for commentary.
k := len(text) - 1
c0, c1 := text[k-1], text[k]
if c0 < c1 {
k = -k
}
// Cache recently used bucket index:
// we're processing suffixes in sorted order
// and accessing buckets indexed by the
// int64 preceding each sorted suffix, which still
// has very good locality.
// Invariant: b is cached, possibly dirty copy of bucket[cB].
cB := c1
b := bucket[cB]
sa[b] = int64(k)
b++
for i := 0; i < len(sa); i++ {
j := int(sa[i])
if j == 0 {
// Skip empty entry.
continue
}
if j < 0 {
// Leave discovered type-S index for caller.
sa[i] = int64(-j)
continue
}
sa[i] = 0
// Index j was on work queue, meaning k := j-1 is L-type,
// so we can now place k correctly into sa.
// If k-1 is L-type, queue k for processing later in this loop.
// If k-1 is S-type (text[k-1] < text[k]), queue -k to save for the caller.
k := j - 1
c0, c1 := text[k-1], text[k]
if c0 < c1 {
k = -k
}
if cB != c1 {
bucket[cB] = b
cB = c1
b = bucket[cB]
}
sa[b] = int64(k)
b++
}
}
func induceSubS_8_64(text []byte, sa, freq, bucket []int64) {
// Initialize positions for right side of character buckets.
bucketMax_8_64(text, freq, bucket)
bucket = bucket[:256] // eliminate bounds check for bucket[cB] below
// Analogous to induceSubL_8_64 above,
// as we scan the array right-to-left, each sa[i] = j > 0 is a correctly
// sorted suffix array entry (for text[j:]) for which we know that j-1 is type S.
// Because j-1 is type S, inserting it into sa now will sort it correctly.
// But we want to distinguish a j-1 with j-2 of type S from type L.
// We can process the former but want to leave the latter for the caller.
// We record the difference by negating j-1 if it is preceded by type L.
// Either way, the insertion (into the text[j-1] bucket) is guaranteed to
// happen at sa[i´] for some i´ < i, that is, in the portion of sa we have
// yet to scan. A single pass therefore sees indexes j, j-1, j-2, j-3,
// and so on, in sorted but not necessarily adjacent order, until it finds
// one preceded by an index of type L, at which point it must stop.
// That index (preceded by one of type L) is an LMS-substring start.
//
// As we scan through the array, we clear the worked entries (sa[i] > 0) to zero,
// and we flip sa[i] < 0 to -sa[i] and compact into the top of sa,
// so that the loop finishes with the top of sa containing exactly
// the LMS-substring start indexes, sorted by LMS-substring.
// Cache recently used bucket index:
cB := byte(0)
b := bucket[cB]
top := len(sa)
for i := len(sa) - 1; i >= 0; i-- {
j := int(sa[i])
if j == 0 {
// Skip empty entry.
continue
}
sa[i] = 0
if j < 0 {
// Leave discovered LMS-substring start index for caller.
top--
sa[top] = int64(-j)
continue
}
// Index j was on work queue, meaning k := j-1 is S-type,
// so we can now place k correctly into sa.
// If k-1 is S-type, queue k for processing later in this loop.
// If k-1 is L-type (text[k-1] > text[k]), queue -k to save for the caller.
k := j - 1
c1 := text[k]
c0 := text[k-1]
if c0 > c1 {
k = -k
}
if cB != c1 {
bucket[cB] = b
cB = c1
b = bucket[cB]
}
b--
sa[b] = int64(k)
}
}
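// Worked example (illustrative): for text "banana", placeLMS_8_64 leaves
// sa = [0 0 3 0 0 0]; induceSubL_8_64 turns that into [0 0 0 0 4 2]
// (queued entries are negated in flight and restored as they are scanned);
// and the scan above finishes with sa = [0 0 0 0 3 1]: the LMS-substring
// start indexes, sorted by LMS-substring and compacted at the top.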
func induceSubS_32(text []int32, sa, freq, bucket []int32) {
// Initialize positions for right side of character buckets.
bucketMax_32(text, freq, bucket)
// Analogous to induceSubL_32 above,
// as we scan the array right-to-left, each sa[i] = j > 0 is a correctly
// sorted suffix array entry (for text[j:]) for which we know that j-1 is type S.
// Because j-1 is type S, inserting it into sa now will sort it correctly.
// But we want to distinguish a j-1 with j-2 of type S from type L.
// We can process the former but want to leave the latter for the caller.
// We record the difference by negating j-1 if it is preceded by type L.
// Either way, the insertion (into the text[j-1] bucket) is guaranteed to
// happen at sa[i´] for some i´ < i, that is, in the portion of sa we have
// yet to scan. A single pass therefore sees indexes j, j-1, j-2, j-3,
// and so on, in sorted but not necessarily adjacent order, until it finds
// one preceded by an index of type L, at which point it must stop.
// That index (preceded by one of type L) is an LMS-substring start.
//
// As we scan through the array, we clear the worked entries (sa[i] > 0) to zero,
// and we flip sa[i] < 0 to -sa[i] and compact into the top of sa,
// so that the loop finishes with the top of sa containing exactly
// the LMS-substring start indexes, sorted by LMS-substring.
// Cache recently used bucket index:
cB := int32(0)
b := bucket[cB]
top := len(sa)
for i := len(sa) - 1; i >= 0; i-- {
j := int(sa[i])
if j == 0 {
// Skip empty entry.
continue
}
sa[i] = 0
if j < 0 {
// Leave discovered LMS-substring start index for caller.
top--
sa[top] = int32(-j)
continue
}
// Index j was on work queue, meaning k := j-1 is S-type,
// so we can now place k correctly into sa.
// If k-1 is S-type, queue k for processing later in this loop.
// If k-1 is L-type (text[k-1] > text[k]), queue -k to save for the caller.
k := j - 1
c1 := text[k]
c0 := text[k-1]
if c0 > c1 {
k = -k
}
if cB != c1 {
bucket[cB] = b
cB = c1
b = bucket[cB]
}
b--
sa[b] = int32(k)
}
}
func induceSubS_64(text []int64, sa, freq, bucket []int64) {
// Initialize positions for right side of character buckets.
bucketMax_64(text, freq, bucket)
// Analogous to induceSubL_64 above,
// as we scan the array right-to-left, each sa[i] = j > 0 is a correctly
// sorted suffix array entry (for text[j:]) for which we know that j-1 is type S.
// Because j-1 is type S, inserting it into sa now will sort it correctly.
// But we want to distinguish a j-1 with j-2 of type S from type L.
// We can process the former but want to leave the latter for the caller.
// We record the difference by negating j-1 if it is preceded by type L.
// Either way, the insertion (into the text[j-1] bucket) is guaranteed to
// happen at sa[i´] for some i´ < i, that is, in the portion of sa we have
// yet to scan. A single pass therefore sees indexes j, j-1, j-2, j-3,
// and so on, in sorted but not necessarily adjacent order, until it finds
// one preceded by an index of type L, at which point it must stop.
// That index (preceded by one of type L) is an LMS-substring start.
//
// As we scan through the array, we clear the worked entries (sa[i] > 0) to zero,
// and we flip sa[i] < 0 to -sa[i] and compact into the top of sa,
// so that the loop finishes with the top of sa containing exactly
// the LMS-substring start indexes, sorted by LMS-substring.
// Cache recently used bucket index:
cB := int64(0)
b := bucket[cB]
top := len(sa)
for i := len(sa) - 1; i >= 0; i-- {
j := int(sa[i])
if j == 0 {
// Skip empty entry.
continue
}
sa[i] = 0
if j < 0 {
// Leave discovered LMS-substring start index for caller.
top--
sa[top] = int64(-j)
continue
}
// Index j was on work queue, meaning k := j-1 is S-type,
// so we can now place k correctly into sa.
// If k-1 is S-type, queue k for processing later in this loop.
// If k-1 is L-type (text[k-1] > text[k]), queue -k to save for the caller.
k := j - 1
c1 := text[k]
c0 := text[k-1]
if c0 > c1 {
k = -k
}
if cB != c1 {
bucket[cB] = b
cB = c1
b = bucket[cB]
}
b--
sa[b] = int64(k)
}
}
func length_8_64(text []byte, sa []int64, numLMS int) {
end := 0 // index of current LMS-substring end (0 indicates final LMS-substring)
// The encoding of N text bytes into a “length” word
// adds 1 to each byte, packs them into the bottom
// N*8 bits of a word, and then bitwise inverts the result.
// That is, the text sequence A B C (hex 41 42 43)
// encodes as ^uint64(0x42_43_44).
// LMS-substrings can never start or end with 0xFF.
// Adding 1 ensures the encoded byte sequence never
// starts or ends with 0x00, so that present bytes can be
// distinguished from zero-padding in the top bits,
// so the length need not be separately encoded.
// Inverting the bytes increases the chance that an
// encoding of up to eight bytes will still be ≥ len(text).
// In particular, if the first byte is ASCII (<= 0x7E, so +1 <= 0x7F)
// then the high bit of the inversion will be set,
// making it clearly not a valid length (as an int64 it would be negative).
//
// cx holds the pre-inverted encoding (the packed incremented bytes).
cx := uint64(0) // byte-only
// This stanza (until the blank line) is the "LMS-substring iterator",
// described in placeLMS_8_64 above, with one line added to maintain cx.
c0, c1, isTypeS := byte(0), byte(0), false
for i := len(text) - 1; i >= 0; i-- {
c0, c1 = text[i], c0
cx = cx<<8 | uint64(c1+1) // byte-only
if c0 < c1 {
isTypeS = true
} else if c0 > c1 && isTypeS {
isTypeS = false
// Index j = i+1 is the start of an LMS-substring.
// Compute length or encoded text to store in sa[j/2].
j := i + 1
var code int64
if end == 0 {
code = 0
} else {
code = int64(end - j)
if code <= 64/8 && ^cx >= uint64(len(text)) { // byte-only
code = int64(^cx) // byte-only
} // byte-only
}
sa[j>>1] = code
end = j + 1
cx = uint64(c1 + 1) // byte-only
}
}
}
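// The packed-text encoding above can be shown in isolation. This helper is
// an illustrative sketch, not used by the package: add 1 to each byte, pack
// big-endian into the low bits, and invert. For "ABC" it returns
// ^uint64(0x42_43_44). Unused high bytes invert to 0xFF, so for sequences
// of at most eight bytes the result is very likely ≥ len(text), which is
// the condition length_8_64 checks before substituting it for the length.
func encodeLMSText(s []byte) uint64 {
	cx := uint64(0)
	for _, c := range s {
		cx = cx<<8 | uint64(c+1)
	}
	return ^cx
}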
func length_32(text []int32, sa []int32, numLMS int) {
end := 0 // index of current LMS-substring end (0 indicates final LMS-substring)
// The byte variant (length_8_64 above) also encodes short LMS-substrings
// directly into the “length” word: add 1 to each byte, pack the bytes
// into the bottom bits, and bitwise invert, so that equal short substrings
// can be compared by value rather than by rescanning the text.
// That fast path is marked "byte-only" there and does not apply to
// int32 text, so this variant always stores the plain LMS-substring
// length, with 0 marking the final substring.
//
// This stanza (until the blank line) is the "LMS-substring iterator",
// described in placeLMS_32 above.
c0, c1, isTypeS := int32(0), int32(0), false
for i := len(text) - 1; i >= 0; i-- {
c0, c1 = text[i], c0
if c0 < c1 {
isTypeS = true
} else if c0 > c1 && isTypeS {
isTypeS = false
// Index j = i+1 is the start of an LMS-substring.
// Compute length or encoded text to store in sa[j/2].
j := i + 1
var code int32
if end == 0 {
code = 0
} else {
code = int32(end - j)
}
sa[j>>1] = code
end = j + 1
}
}
}
func length_64(text []int64, sa []int64, numLMS int) {
end := 0 // index of current LMS-substring end (0 indicates final LMS-substring)
// The byte variant (length_8_64 above) also encodes short LMS-substrings
// directly into the “length” word: add 1 to each byte, pack the bytes
// into the bottom bits, and bitwise invert, so that equal short substrings
// can be compared by value rather than by rescanning the text.
// That fast path is marked "byte-only" there and does not apply to
// int64 text, so this variant always stores the plain LMS-substring
// length, with 0 marking the final substring.
//
// This stanza (until the blank line) is the "LMS-substring iterator",
// described in placeLMS_64 above.
c0, c1, isTypeS := int64(0), int64(0), false
for i := len(text) - 1; i >= 0; i-- {
c0, c1 = text[i], c0
if c0 < c1 {
isTypeS = true
} else if c0 > c1 && isTypeS {
isTypeS = false
// Index j = i+1 is the start of an LMS-substring.
// Compute length or encoded text to store in sa[j/2].
j := i + 1
var code int64
if end == 0 {
code = 0
} else {
code = int64(end - j)
}
sa[j>>1] = code
end = j + 1
}
}
}
func assignID_8_64(text []byte, sa []int64, numLMS int) int {
id := 0
lastLen := int64(-1) // impossible
lastPos := int64(0)
for _, j := range sa[len(sa)-numLMS:] {
// Is the LMS-substring at index j new, or is it the same as the last one we saw?
n := sa[j/2]
if n != lastLen {
goto New
}
if uint64(n) >= uint64(len(text)) {
// “Length” is really encoded full text, and they match.
goto Same
}
{
// Compare actual texts.
n := int(n)
this := text[j:][:n]
last := text[lastPos:][:n]
for i := 0; i < n; i++ {
if this[i] != last[i] {
goto New
}
}
goto Same
}
New:
id++
lastPos = j
lastLen = n
Same:
sa[j/2] = int64(id)
}
return id
}
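// Continuing the "banana" example (illustrative): the sorted LMS-substring
// starts are [3 1] and both substrings read "ana", but length_8_64 stored
// code 0 for the final one (at 3) and a length or encoded text for the one
// at 1, so they compare unequal and receive distinct IDs 1 and 2. Then
// maxID == numLMS: every LMS-substring is unique and the recursion is skipped.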
func assignID_32(text []int32, sa []int32, numLMS int) int {
id := 0
lastLen := int32(-1) // impossible
lastPos := int32(0)
for _, j := range sa[len(sa)-numLMS:] {
// Is the LMS-substring at index j new, or is it the same as the last one we saw?
n := sa[j/2]
if n != lastLen {
goto New
}
if uint32(n) >= uint32(len(text)) {
// “Length” is really encoded full text, and they match.
goto Same
}
{
// Compare actual texts.
n := int(n)
this := text[j:][:n]
last := text[lastPos:][:n]
for i := 0; i < n; i++ {
if this[i] != last[i] {
goto New
}
}
goto Same
}
New:
id++
lastPos = j
lastLen = n
Same:
sa[j/2] = int32(id)
}
return id
}
func assignID_64(text []int64, sa []int64, numLMS int) int {
id := 0
lastLen := int64(-1) // impossible
lastPos := int64(0)
for _, j := range sa[len(sa)-numLMS:] {
// Is the LMS-substring at index j new, or is it the same as the last one we saw?
n := sa[j/2]
if n != lastLen {
goto New
}
if uint64(n) >= uint64(len(text)) {
// “Length” is really encoded full text, and they match.
goto Same
}
{
// Compare actual texts.
n := int(n)
this := text[j:][:n]
last := text[lastPos:][:n]
for i := 0; i < n; i++ {
if this[i] != last[i] {
goto New
}
}
goto Same
}
New:
id++
lastPos = j
lastLen = n
Same:
sa[j/2] = int64(id)
}
return id
}
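// map_64 compacts the LMS-substring IDs that assignID_8_64 or assignID_64
// stored at sa[j/2] into sa[len(sa)-numLMS:], preserving text order and
// subtracting 1 so the IDs become the 0-based symbols of the recursive
// subproblem's text.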
func map_64(sa []int64, numLMS int) {
w := len(sa)
for i := len(sa) / 2; i >= 0; i-- {
j := sa[i]
if j > 0 {
w--
sa[w] = j - 1
}
}
}
func recurse_64(sa, oldTmp []int64, numLMS, maxID int) {
dst, saTmp, text := sa[:numLMS], sa[numLMS:len(sa)-numLMS], sa[len(sa)-numLMS:]
// Set up temporary space for recursive call.
// We must pass sais_64 a tmp buffer with at least maxID entries.
//
// The subproblem is guaranteed to have length at most len(sa)/2,
// so that sa can hold both the subproblem and its suffix array.
// Nearly all the time, however, the subproblem has length < len(sa)/3,
// in which case there is a subproblem-sized middle of sa that
// we can reuse for temporary space (saTmp).
// When recurse_64 is called from sais_8_64, oldTmp is length 512
// (from text_64), and saTmp will typically be much larger, so we'll use saTmp.
// When deeper recursions come back to recurse_64, now oldTmp is
// the saTmp from the top-most recursion, it is typically larger than
// the current saTmp (because the current sa gets smaller and smaller
// as the recursion gets deeper), and we keep reusing that top-most
// large saTmp instead of the offered smaller ones.
//
// Why is the subproblem length so often just under len(sa)/3?
// See Nong, Zhang, and Chen, section 3.6 for a plausible explanation.
// In brief, the len(sa)/2 case would correspond to an SLSLSLSLSLSL pattern
// in the input, perfect alternation of larger and smaller input bytes.
// Real text doesn't do that. If each L-type index is randomly followed
// by either an L-type or S-type index, then half the substrings will
// be of the form SLS, but the other half will be longer. Of that half,
// half (a quarter overall) will be SLLS; an eighth will be SLLLS, and so on.
// Not counting the final S in each (which overlaps the first S in the next),
// this works out to an average length 2×½ + 3×¼ + 4×⅛ + ... = 3.
// The space we need is further reduced by the fact that many of the
// short patterns like SLS will often be the same character sequences
// repeated throughout the text, reducing maxID relative to numLMS.
//
// For short inputs, the averages may not run in our favor, but then we
// can often fall back to using the length-512 tmp available in the
// top-most call. (Also a short allocation would not be a big deal.)
//
// For pathological inputs, we fall back to allocating a new tmp of length
// max(maxID, numLMS/2). This level of the recursion needs maxID,
// and all deeper levels of the recursion will need no more than numLMS/2,
// so this one allocation is guaranteed to suffice for the entire stack
// of recursive calls.
tmp := oldTmp
if len(tmp) < len(saTmp) {
tmp = saTmp
}
if len(tmp) < numLMS {
// TestSAIS/forcealloc reaches this code.
n := maxID
if n < numLMS/2 {
n = numLMS / 2
}
tmp = make([]int64, n)
}
// sais_64 requires dst to be zeroed, but leaves the zeroing to the
// caller, because in general the caller may know that dst is freshly
// allocated and therefore already cleared. This dst is not, so clear it here.
for i := range dst {
dst[i] = 0
}
sais_64(text, maxID, dst, tmp)
}
func unmap_8_64(text []byte, sa []int64, numLMS int) {
unmap := sa[len(sa)-numLMS:]
j := len(unmap)
// "LMS-substring iterator" (see placeLMS_8_64 above).
c0, c1, isTypeS := byte(0), byte(0), false
for i := len(text) - 1; i >= 0; i-- {
c0, c1 = text[i], c0
if c0 < c1 {
isTypeS = true
} else if c0 > c1 && isTypeS {
isTypeS = false
// Populate inverse map.
j--
unmap[j] = int64(i + 1)
}
}
// Apply inverse map to subproblem suffix array.
sa = sa[:numLMS]
for i := 0; i < len(sa); i++ {
sa[i] = unmap[sa[i]]
}
}
func unmap_32(text []int32, sa []int32, numLMS int) {
unmap := sa[len(sa)-numLMS:]
j := len(unmap)
// "LMS-substring iterator" (see placeLMS_32 above).
c0, c1, isTypeS := int32(0), int32(0), false
for i := len(text) - 1; i >= 0; i-- {
c0, c1 = text[i], c0
if c0 < c1 {
isTypeS = true
} else if c0 > c1 && isTypeS {
isTypeS = false
// Populate inverse map.
j--
unmap[j] = int32(i + 1)
}
}
// Apply inverse map to subproblem suffix array.
sa = sa[:numLMS]
for i := 0; i < len(sa); i++ {
sa[i] = unmap[sa[i]]
}
}
func unmap_64(text []int64, sa []int64, numLMS int) {
unmap := sa[len(sa)-numLMS:]
j := len(unmap)
// "LMS-substring iterator" (see placeLMS_64 above).
c0, c1, isTypeS := int64(0), int64(0), false
for i := len(text) - 1; i >= 0; i-- {
c0, c1 = text[i], c0
if c0 < c1 {
isTypeS = true
} else if c0 > c1 && isTypeS {
isTypeS = false
// Populate inverse map.
j--
unmap[j] = int64(i + 1)
}
}
// Apply inverse map to subproblem suffix array.
sa = sa[:numLMS]
for i := 0; i < len(sa); i++ {
sa[i] = unmap[sa[i]]
}
}
func expand_8_64(text []byte, freq, bucket, sa []int64, numLMS int) {
bucketMax_8_64(text, freq, bucket)
bucket = bucket[:256] // eliminate bounds check for bucket[c] below
// Loop backward through sa, always tracking
// the next index to populate from sa[:numLMS].
// When we get to one, populate it.
// Zero the rest of the slots; they have dead values in them.
x := numLMS - 1
saX := sa[x]
c := text[saX]
b := bucket[c] - 1
bucket[c] = b
for i := len(sa) - 1; i >= 0; i-- {
if i != int(b) {
sa[i] = 0
continue
}
sa[i] = saX
// Load next entry to put down (if any).
if x > 0 {
x--
saX = sa[x] // TODO bounds check
c = text[saX]
b = bucket[c] - 1
bucket[c] = b
}
}
}
func expand_32(text []int32, freq, bucket, sa []int32, numLMS int) {
bucketMax_32(text, freq, bucket)
// Loop backward through sa, always tracking
// the next index to populate from sa[:numLMS].
// When we get to one, populate it.
// Zero the rest of the slots; they have dead values in them.
x := numLMS - 1
saX := sa[x]
c := text[saX]
b := bucket[c] - 1
bucket[c] = b
for i := len(sa) - 1; i >= 0; i-- {
if i != int(b) {
sa[i] = 0
continue
}
sa[i] = saX
// Load next entry to put down (if any).
if x > 0 {
x--
saX = sa[x] // TODO bounds check
c = text[saX]
b = bucket[c] - 1
bucket[c] = b
}
}
}
func expand_64(text []int64, freq, bucket, sa []int64, numLMS int) {
bucketMax_64(text, freq, bucket)
// Loop backward through sa, always tracking
// the next index to populate from sa[:numLMS].
// When we get to one, populate it.
// Zero the rest of the slots; they have dead values in them.
x := numLMS - 1
saX := sa[x]
c := text[saX]
b := bucket[c] - 1
bucket[c] = b
for i := len(sa) - 1; i >= 0; i-- {
if i != int(b) {
sa[i] = 0
continue
}
sa[i] = saX
// Load next entry to put down (if any).
if x > 0 {
x--
saX = sa[x] // TODO bounds check
c = text[saX]
b = bucket[c] - 1
bucket[c] = b
}
}
}
func induceL_8_64(text []byte, sa, freq, bucket []int64) {
// Initialize positions for left side of character buckets.
bucketMin_8_64(text, freq, bucket)
bucket = bucket[:256] // eliminate bounds check for bucket[cB] below
// This scan is similar to the one in induceSubL_8_64 above.
// That one arranges to clear all but the leftmost L-type indexes.
// This scan leaves all the L-type indexes and the original S-type
// indexes, but it negates the positive leftmost L-type indexes
// (the ones that induceS_8_64 needs to process).
// expand_8_64 left out the implicit entry sa[-1] == len(text),
// corresponding to the identified type-L index len(text)-1.
// Process it before the left-to-right scan of sa proper.
// See body in loop for commentary.
k := len(text) - 1
c0, c1 := text[k-1], text[k]
if c0 < c1 {
k = -k
}
// Cache recently used bucket index.
cB := c1
b := bucket[cB]
sa[b] = int64(k)
b++
for i := 0; i < len(sa); i++ {
j := int(sa[i])
if j <= 0 {
// Skip empty or negated entry (including negated zero).
continue
}
// Index j was on work queue, meaning k := j-1 is L-type,
// so we can now place k correctly into sa.
// If k-1 is L-type, queue k for processing later in this loop.
// If k-1 is S-type (text[k-1] < text[k]), queue -k to save for the caller.
// If k is zero, k-1 doesn't exist, so we only need to leave it
// for the caller. The caller can't tell the difference between
// an empty slot and a non-empty zero, but there's no need
// to distinguish them anyway: the final suffix array will end up
// with one zero somewhere, and that will be a real zero.
k := j - 1
c1 := text[k]
if k > 0 {
if c0 := text[k-1]; c0 < c1 {
k = -k
}
}
if cB != c1 {
bucket[cB] = b
cB = c1
b = bucket[cB]
}
sa[b] = int64(k)
b++
}
}
func induceL_32(text []int32, sa, freq, bucket []int32) {
// Initialize positions for left side of character buckets.
bucketMin_32(text, freq, bucket)
// This scan is similar to the one in induceSubL_32 above.
// That one arranges to clear all but the leftmost L-type indexes.
// This scan leaves all the L-type indexes and the original S-type
// indexes, but it negates the positive leftmost L-type indexes
// (the ones that induceS_32 needs to process).
// expand_32 left out the implicit entry sa[-1] == len(text),
// corresponding to the identified type-L index len(text)-1.
// Process it before the left-to-right scan of sa proper.
// See body in loop for commentary.
k := len(text) - 1
c0, c1 := text[k-1], text[k]
if c0 < c1 {
k = -k
}
// Cache recently used bucket index.
cB := c1
b := bucket[cB]
sa[b] = int32(k)
b++
for i := 0; i < len(sa); i++ {
j := int(sa[i])
if j <= 0 {
// Skip empty or negated entry (including negated zero).
continue
}
// Index j was on work queue, meaning k := j-1 is L-type,
// so we can now place k correctly into sa.
// If k-1 is L-type, queue k for processing later in this loop.
// If k-1 is S-type (text[k-1] < text[k]), queue -k to save for the caller.
// If k is zero, k-1 doesn't exist, so we only need to leave it
// for the caller. The caller can't tell the difference between
// an empty slot and a non-empty zero, but there's no need
// to distinguish them anyway: the final suffix array will end up
// with one zero somewhere, and that will be a real zero.
k := j - 1
c1 := text[k]
if k > 0 {
if c0 := text[k-1]; c0 < c1 {
k = -k
}
}
if cB != c1 {
bucket[cB] = b
cB = c1
b = bucket[cB]
}
sa[b] = int32(k)
b++
}
}
func induceL_64(text []int64, sa, freq, bucket []int64) {
// Initialize positions for left side of character buckets.
bucketMin_64(text, freq, bucket)
// This scan is similar to the one in induceSubL_64 above.
// That one arranges to clear all but the leftmost L-type indexes.
// This scan leaves all the L-type indexes and the original S-type
// indexes, but it negates the positive leftmost L-type indexes
// (the ones that induceS_64 needs to process).
// expand_64 left out the implicit entry sa[-1] == len(text),
// corresponding to the identified type-L index len(text)-1.
// Process it before the left-to-right scan of sa proper.
// See body in loop for commentary.
k := len(text) - 1
c0, c1 := text[k-1], text[k]
if c0 < c1 {
k = -k
}
// Cache recently used bucket index.
cB := c1
b := bucket[cB]
sa[b] = int64(k)
b++
for i := 0; i < len(sa); i++ {
j := int(sa[i])
if j <= 0 {
// Skip empty or negated entry (including negated zero).
continue
}
// Index j was on work queue, meaning k := j-1 is L-type,
// so we can now place k correctly into sa.
// If k-1 is L-type, queue k for processing later in this loop.
// If k-1 is S-type (text[k-1] < text[k]), queue -k to save for the caller.
// If k is zero, k-1 doesn't exist, so we only need to leave it
// for the caller. The caller can't tell the difference between
// an empty slot and a non-empty zero, but there's no need
// to distinguish them anyway: the final suffix array will end up
// with one zero somewhere, and that will be a real zero.
k := j - 1
c1 := text[k]
if k > 0 {
if c0 := text[k-1]; c0 < c1 {
k = -k
}
}
if cB != c1 {
bucket[cB] = b
cB = c1
b = bucket[cB]
}
sa[b] = int64(k)
b++
}
}
func induceS_8_64(text []byte, sa, freq, bucket []int64) {
// Initialize positions for right side of character buckets.
bucketMax_8_64(text, freq, bucket)
bucket = bucket[:256] // eliminate bounds check for bucket[cB] below
cB := byte(0)
b := bucket[cB]
for i := len(sa) - 1; i >= 0; i-- {
j := int(sa[i])
if j >= 0 {
// Skip non-flagged entry.
// (This loop can't see an empty entry; 0 means the real zero index.)
continue
}
// Negative j is a work queue entry; rewrite to positive j for final suffix array.
j = -j
sa[i] = int64(j)
// Index j was on work queue (encoded as -j but now decoded),
// meaning k := j-1 is S-type,
// so we can now place k correctly into sa.
// If k-1 is S-type, queue -k for processing later in this loop.
// If k-1 is L-type (text[k-1] > text[k]), queue k to save for the caller.
// If k is zero, k-1 doesn't exist, so we only need to leave it
// for the caller.
k := j - 1
c1 := text[k]
if k > 0 {
if c0 := text[k-1]; c0 <= c1 {
k = -k
}
}
if cB != c1 {
bucket[cB] = b
cB = c1
b = bucket[cB]
}
b--
sa[b] = int64(k)
}
}
func induceS_32(text []int32, sa, freq, bucket []int32) {
// Initialize positions for right side of character buckets.
bucketMax_32(text, freq, bucket)
cB := int32(0)
b := bucket[cB]
for i := len(sa) - 1; i >= 0; i-- {
j := int(sa[i])
if j >= 0 {
// Skip non-flagged entry.
// (This loop can't see an empty entry; 0 means the real zero index.)
continue
}
// Negative j is a work queue entry; rewrite to positive j for final suffix array.
j = -j
sa[i] = int32(j)
// Index j was on work queue (encoded as -j but now decoded),
// meaning k := j-1 is S-type,
// so we can now place k correctly into sa.
// If k-1 is S-type, queue -k for processing later in this loop.
// If k-1 is L-type (text[k-1] > text[k]), queue k to save for the caller.
// If k is zero, k-1 doesn't exist, so we only need to leave it
// for the caller.
k := j - 1
c1 := text[k]
if k > 0 {
if c0 := text[k-1]; c0 <= c1 {
k = -k
}
}
if cB != c1 {
bucket[cB] = b
cB = c1
b = bucket[cB]
}
b--
sa[b] = int32(k)
}
}
func induceS_64(text []int64, sa, freq, bucket []int64) {
// Initialize positions for right side of character buckets.
bucketMax_64(text, freq, bucket)
cB := int64(0)
b := bucket[cB]
for i := len(sa) - 1; i >= 0; i-- {
j := int(sa[i])
if j >= 0 {
// Skip non-flagged entry.
// (This loop can't see an empty entry; 0 means the real zero index.)
continue
}
// Negative j is a work queue entry; rewrite to positive j for final suffix array.
j = -j
sa[i] = int64(j)
// Index j was on work queue (encoded as -j but now decoded),
// meaning k := j-1 is S-type,
// so we can now place k correctly into sa.
// If k-1 is S-type, queue -k for processing later in this loop.
// If k-1 is L-type (text[k-1] > text[k]), queue k to save for the caller.
// If k is zero, k-1 doesn't exist, so we only need to leave it
// for the caller.
k := j - 1
c1 := text[k]
if k > 0 {
if c0 := text[k-1]; c0 <= c1 {
k = -k
}
}
if cB != c1 {
bucket[cB] = b
cB = c1
b = bucket[cB]
}
b--
sa[b] = int64(k)
}
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package suffixarray implements substring search in logarithmic time using
// an in-memory suffix array.
//
// Example use:
//
// // create index for some data
// index := suffixarray.New(data)
//
// // lookup byte slice s
// offsets1 := index.Lookup(s, -1) // the list of all indices where s occurs in data
// offsets2 := index.Lookup(s, 3) // the list of at most 3 indices where s occurs in data
package suffixarray
import (
"bytes"
"encoding/binary"
"errors"
"io"
"math"
"regexp"
"sort"
)
// Can be changed for testing.
var maxData32 int = realMaxData32
const realMaxData32 = math.MaxInt32
// Index implements a suffix array for fast substring search.
type Index struct {
data []byte
sa ints // suffix array for data; sa.len() == len(data)
}
// An ints is either an []int32 or an []int64.
// That is, one of them is empty, and one is the real data.
// The int64 form is used when len(data) > maxData32.
type ints struct {
int32 []int32
int64 []int64
}
func (a *ints) len() int {
return len(a.int32) + len(a.int64)
}
func (a *ints) get(i int) int64 {
if a.int32 != nil {
return int64(a.int32[i])
}
return a.int64[i]
}
func (a *ints) set(i int, v int64) {
if a.int32 != nil {
a.int32[i] = int32(v)
} else {
a.int64[i] = v
}
}
func (a *ints) slice(i, j int) ints {
if a.int32 != nil {
return ints{a.int32[i:j], nil}
}
return ints{nil, a.int64[i:j]}
}
// New creates a new Index for data.
// Index creation time is O(N) for N = len(data).
func New(data []byte) *Index {
ix := &Index{data: data}
if len(data) <= maxData32 {
ix.sa.int32 = make([]int32, len(data))
text_32(data, ix.sa.int32)
} else {
ix.sa.int64 = make([]int64, len(data))
text_64(data, ix.sa.int64)
}
return ix
}
// writeInt writes an int x to w using buf to buffer the write.
func writeInt(w io.Writer, buf []byte, x int) error {
binary.PutVarint(buf, int64(x))
_, err := w.Write(buf[0:binary.MaxVarintLen64])
return err
}
// readInt reads an int x from r using buf to buffer the read and returns x.
func readInt(r io.Reader, buf []byte) (int64, error) {
_, err := io.ReadFull(r, buf[0:binary.MaxVarintLen64]) // ok to continue with error
x, _ := binary.Varint(buf)
return x, err
}
// writeSlice writes data[:n] to w and returns n.
// It uses buf to buffer the write.
func writeSlice(w io.Writer, buf []byte, data ints) (n int, err error) {
// encode as many elements as fit into buf
p := binary.MaxVarintLen64
m := data.len()
for ; n < m && p+binary.MaxVarintLen64 <= len(buf); n++ {
p += binary.PutUvarint(buf[p:], uint64(data.get(n)))
}
// record the total chunk size at the front of the buffer
binary.PutVarint(buf, int64(p))
// write buffer
_, err = w.Write(buf[0:p])
return
}
var errTooBig = errors.New("suffixarray: data too large")
// readSlice reads data[:n] from r and returns n.
// It uses buf to buffer the read.
func readSlice(r io.Reader, buf []byte, data ints) (n int, err error) {
// read buffer size
var size64 int64
size64, err = readInt(r, buf)
if err != nil {
return
}
if int64(int(size64)) != size64 || int(size64) < 0 {
// We never write chunks this big anyway.
return 0, errTooBig
}
size := int(size64)
// read buffer w/o the size
if _, err = io.ReadFull(r, buf[binary.MaxVarintLen64:size]); err != nil {
return
}
// decode as many elements as present in buf
for p := binary.MaxVarintLen64; p < size; n++ {
x, w := binary.Uvarint(buf[p:])
data.set(n, int64(x))
p += w
}
return
}
const bufSize = 16 << 10 // reasonable for BenchmarkSaveRestore
// Read reads the index from r into x; x must not be nil.
func (x *Index) Read(r io.Reader) error {
// buffer for all reads
buf := make([]byte, bufSize)
// read length
n64, err := readInt(r, buf)
if err != nil {
return err
}
if int64(int(n64)) != n64 || int(n64) < 0 {
return errTooBig
}
n := int(n64)
// allocate space
if 2*n < cap(x.data) || cap(x.data) < n || x.sa.int32 != nil && n > maxData32 || x.sa.int64 != nil && n <= maxData32 {
// new data is significantly smaller or larger than
// existing buffers - allocate new ones
x.data = make([]byte, n)
x.sa.int32 = nil
x.sa.int64 = nil
if n <= maxData32 {
x.sa.int32 = make([]int32, n)
} else {
x.sa.int64 = make([]int64, n)
}
} else {
// re-use existing buffers
x.data = x.data[0:n]
x.sa = x.sa.slice(0, n)
}
// read data
if _, err := io.ReadFull(r, x.data); err != nil {
return err
}
// read index
sa := x.sa
for sa.len() > 0 {
n, err := readSlice(r, buf, sa)
if err != nil {
return err
}
sa = sa.slice(n, sa.len())
}
return nil
}
// Write writes the index x to w.
func (x *Index) Write(w io.Writer) error {
// buffer for all writes
buf := make([]byte, bufSize)
// write length
if err := writeInt(w, buf, len(x.data)); err != nil {
return err
}
// write data
if _, err := w.Write(x.data); err != nil {
return err
}
// write index
sa := x.sa
for sa.len() > 0 {
n, err := writeSlice(w, buf, sa)
if err != nil {
return err
}
sa = sa.slice(n, sa.len())
}
return nil
}
// Bytes returns the data over which the index was created.
// It must not be modified.
func (x *Index) Bytes() []byte {
return x.data
}
func (x *Index) at(i int) []byte {
return x.data[x.sa.get(i):]
}
// lookupAll returns a slice into the matching region of the index.
// The runtime is O(log(N)*len(s)).
func (x *Index) lookupAll(s []byte) ints {
// find matching suffix index range [i:j]
// find the first index where s would be the prefix
i := sort.Search(x.sa.len(), func(i int) bool { return bytes.Compare(x.at(i), s) >= 0 })
// starting at i, find the first index at which s is not a prefix
j := i + sort.Search(x.sa.len()-i, func(j int) bool { return !bytes.HasPrefix(x.at(j+i), s) })
return x.sa.slice(i, j)
}
// Lookup returns an unsorted list of at most n indices where the byte string s
// occurs in the indexed data. If n < 0, all occurrences are returned.
// The result is nil if s is empty, s is not found, or n == 0.
// Lookup time is O(log(N)*len(s) + len(result)) where N is the
// size of the indexed data.
func (x *Index) Lookup(s []byte, n int) (result []int) {
if len(s) > 0 && n != 0 {
matches := x.lookupAll(s)
count := matches.len()
if n < 0 || count < n {
n = count
}
// 0 <= n <= count
if n > 0 {
result = make([]int, n)
if matches.int32 != nil {
for i := range result {
result[i] = int(matches.int32[i])
}
} else {
for i := range result {
result[i] = int(matches.int64[i])
}
}
}
}
return
}
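// For example (illustrative, not from the original source), with the
// data "banana":
//
// index := New([]byte("banana"))
// index.Lookup([]byte("ana"), -1) // e.g. [1 3], in unspecified order
// index.Lookup([]byte("na"), 1)   // at most one of the two occurrences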
// FindAllIndex returns a sorted list of non-overlapping matches of the
// regular expression r, where a match is a pair of indices specifying
// the matched slice of x.Bytes(). If n < 0, all matches are returned
// in successive order. Otherwise, at most n matches are returned and
// they may not be successive. The result is nil if there are no matches,
// or if n == 0.
func (x *Index) FindAllIndex(r *regexp.Regexp, n int) (result [][]int) {
// a non-empty literal prefix is used to determine possible
// match start indices with Lookup
prefix, complete := r.LiteralPrefix()
lit := []byte(prefix)
// worst-case scenario: no literal prefix
if prefix == "" {
return r.FindAllIndex(x.data, n)
}
// if regexp is a literal just use Lookup and convert its
// result into match pairs
if complete {
// Lookup returns indices that may belong to overlapping matches.
// After eliminating them, we may end up with fewer than n matches.
// If we don't have enough at the end, redo the search with an
// increased value n1, but only if Lookup returned all the requested
// indices in the first place (if it returned fewer than that then
// there cannot be more).
for n1 := n; ; n1 += 2 * (n - len(result)) /* overflow ok */ {
indices := x.Lookup(lit, n1)
if len(indices) == 0 {
return
}
sort.Ints(indices)
pairs := make([]int, 2*len(indices))
result = make([][]int, len(indices))
count := 0
prev := 0
for _, i := range indices {
if count == n {
break
}
// ignore indices leading to overlapping matches
if prev <= i {
j := 2 * count
pairs[j+0] = i
pairs[j+1] = i + len(lit)
result[count] = pairs[j : j+2]
count++
prev = i + len(lit)
}
}
result = result[0:count]
if len(result) >= n || len(indices) != n1 {
// found all matches or there's no chance to find more
// (n and n1 can be negative)
break
}
}
if len(result) == 0 {
result = nil
}
return
}
// regexp has a non-empty literal prefix; Lookup(lit) computes
// the indices of possible complete matches; use these as starting
// points for anchored searches
// (regexp "^" matches beginning of input, not beginning of line)
r = regexp.MustCompile("^" + r.String()) // compiles because r compiled
// same comment about Lookup applies here as in the loop above
for n1 := n; ; n1 += 2 * (n - len(result)) /* overflow ok */ {
indices := x.Lookup(lit, n1)
if len(indices) == 0 {
return
}
sort.Ints(indices)
result = result[0:0]
prev := 0
for _, i := range indices {
if len(result) == n {
break
}
m := r.FindIndex(x.data[i:]) // anchored search - will not run off
// ignore indices leading to overlapping matches
if m != nil && prev <= i {
m[0] = i // correct m
m[1] += i
result = append(result, m)
prev = m[1]
}
}
if len(result) >= n || len(indices) != n1 {
// found all matches or there's no chance to find more
// (n and n1 can be negative)
break
}
}
if len(result) == 0 {
result = nil
}
return
}
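// For example (illustrative, not from the original source): a regexp with
// the literal prefix "ab" lets FindAllIndex narrow candidate start
// positions via Lookup before running anchored searches.
//
// index := New([]byte("abc abd abe"))
// re := regexp.MustCompile(`ab[cd]`)
// index.FindAllIndex(re, -1) // [[0 3] [4 7]], sorted and non-overlapping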
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package abi
import (
"internal/goarch"
"unsafe"
)
// RegArgs is a struct that has space for each argument
// and return value register on the current architecture.
//
// Assembly code knows the layout of the first two fields
// of RegArgs.
//
// RegArgs also contains additional space to hold pointers
// when it may not be safe to keep them only in the integer
// register space otherwise.
type RegArgs struct {
// Values in these slots should be precisely the bit-by-bit
// representation of how they would appear in a register.
//
// This means that on big endian arches, integer values should
// be in the top bits of the slot. Floats are usually just
// directly represented, but some architectures treat narrow
// width floating point values specially (e.g. they're promoted
// first, or they need to be NaN-boxed).
Ints [IntArgRegs]uintptr // untyped integer registers
Floats [FloatArgRegs]uint64 // untyped float registers
// Fields above this point are known to assembly.
// Ptrs is a space that duplicates Ints but with pointer type,
// used to make pointers passed or returned in registers
// visible to the GC by making the type unsafe.Pointer.
Ptrs [IntArgRegs]unsafe.Pointer
// ReturnIsPtr is a bitmap that indicates which registers
// contain or will contain pointers on the return path from
// a reflectcall. The i'th bit indicates whether the i'th
// register contains or will contain a valid Go pointer.
ReturnIsPtr IntArgRegBitmap
}
func (r *RegArgs) Dump() {
print("Ints:")
for _, x := range r.Ints {
print(" ", x)
}
println()
print("Floats:")
for _, x := range r.Floats {
print(" ", x)
}
println()
print("Ptrs:")
for _, x := range r.Ptrs {
print(" ", x)
}
println()
}
// IntRegArgAddr returns a pointer inside of r.Ints[reg] that is appropriately
// offset for an argument of size argSize.
//
// argSize must be non-zero, fit in a register, and be a power of two.
//
// This method is a helper for dealing with the endianness of different CPU
// architectures, since sub-word-sized arguments in big endian architectures
// need to be "aligned" to the upper edge of the register to be interpreted
// by the CPU correctly.
func (r *RegArgs) IntRegArgAddr(reg int, argSize uintptr) unsafe.Pointer {
if argSize > goarch.PtrSize || argSize == 0 || argSize&(argSize-1) != 0 {
panic("invalid argSize")
}
offset := uintptr(0)
if goarch.BigEndian {
offset = goarch.PtrSize - argSize
}
return unsafe.Pointer(uintptr(unsafe.Pointer(&r.Ints[reg])) + offset)
}
// IntArgRegBitmap is a bitmap large enough to hold one bit per
// integer argument/return register.
type IntArgRegBitmap [(IntArgRegs + 7) / 8]uint8
// Set sets the i'th bit of the bitmap to 1.
func (b *IntArgRegBitmap) Set(i int) {
b[i/8] |= uint8(1) << (i % 8)
}
// Get returns whether the i'th bit of the bitmap is set.
//
// nosplit because it's called in extremely sensitive contexts, like
// on the reflectcall return path.
//
//go:nosplit
func (b *IntArgRegBitmap) Get(i int) bool {
return b[i/8]&(uint8(1)<<(i%8)) != 0
}
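// Illustrative sketch (not part of the original source): marking integer
// register 2 as containing a pointer on the return path.
//
// var r RegArgs
// r.ReturnIsPtr.Set(2)
// r.ReturnIsPtr.Get(2) // true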
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package buildcfg provides access to the build configuration
// described by the current environment. It is for use by build tools
// such as cmd/go or cmd/compile and for setting up go/build's Default context.
//
// Note that it does NOT provide access to the build configuration used to
// build the currently-running binary. For that, use runtime.GOOS etc
// as well as internal/goexperiment.
package buildcfg
import (
"fmt"
"os"
"path/filepath"
"runtime"
"strings"
)
var (
GOROOT = runtime.GOROOT() // cached for efficiency
GOARCH = envOr("GOARCH", defaultGOARCH)
GOOS = envOr("GOOS", defaultGOOS)
GO386 = envOr("GO386", defaultGO386)
GOAMD64 = goamd64()
GOARM = goarm()
GOMIPS = gomips()
GOMIPS64 = gomips64()
GOPPC64 = goppc64()
GOWASM = gowasm()
ToolTags = toolTags()
GO_LDSO = defaultGO_LDSO
Version = version
)
// Error is one of the errors found (if any) in the build configuration.
var Error error
// Check exits the program with a fatal error if Error is non-nil.
func Check() {
if Error != nil {
fmt.Fprintf(os.Stderr, "%s: %v\n", filepath.Base(os.Args[0]), Error)
os.Exit(2)
}
}
func envOr(key, value string) string {
if x := os.Getenv(key); x != "" {
return x
}
return value
}
func goamd64() int {
switch v := envOr("GOAMD64", defaultGOAMD64); v {
case "v1":
return 1
case "v2":
return 2
case "v3":
return 3
case "v4":
return 4
}
Error = fmt.Errorf("invalid GOAMD64: must be v1, v2, v3, v4")
return int(defaultGOAMD64[len("v")] - '0')
}
func goarm() int {
def := defaultGOARM
if GOOS == "android" && GOARCH == "arm" {
// Android arm devices always support GOARM=7.
def = "7"
}
switch v := envOr("GOARM", def); v {
case "5":
return 5
case "6":
return 6
case "7":
return 7
}
Error = fmt.Errorf("invalid GOARM: must be 5, 6, 7")
return int(def[0] - '0')
}
func gomips() string {
switch v := envOr("GOMIPS", defaultGOMIPS); v {
case "hardfloat", "softfloat":
return v
}
Error = fmt.Errorf("invalid GOMIPS: must be hardfloat, softfloat")
return defaultGOMIPS
}
func gomips64() string {
switch v := envOr("GOMIPS64", defaultGOMIPS64); v {
case "hardfloat", "softfloat":
return v
}
Error = fmt.Errorf("invalid GOMIPS64: must be hardfloat, softfloat")
return defaultGOMIPS64
}
func goppc64() int {
switch v := envOr("GOPPC64", defaultGOPPC64); v {
case "power8":
return 8
case "power9":
return 9
case "power10":
return 10
}
Error = fmt.Errorf("invalid GOPPC64: must be power8, power9, power10")
return int(defaultGOPPC64[len("power")] - '0')
}
type gowasmFeatures struct {
SatConv bool
SignExt bool
}
func (f gowasmFeatures) String() string {
var flags []string
if f.SatConv {
flags = append(flags, "satconv")
}
if f.SignExt {
flags = append(flags, "signext")
}
return strings.Join(flags, ",")
}
func gowasm() (f gowasmFeatures) {
for _, opt := range strings.Split(envOr("GOWASM", ""), ",") {
switch opt {
case "satconv":
f.SatConv = true
case "signext":
f.SignExt = true
case "":
// ignore
default:
Error = fmt.Errorf("invalid GOWASM: no such feature %q", opt)
}
}
return
}
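// For example (illustrative): GOWASM="satconv,signext" yields
// gowasmFeatures{SatConv: true, SignExt: true}, which String renders
// back as "satconv,signext".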
func Getgoextlinkenabled() string {
return envOr("GO_EXTLINK_ENABLED", defaultGO_EXTLINK_ENABLED)
}
func toolTags() []string {
tags := experimentTags()
tags = append(tags, gogoarchTags()...)
return tags
}
func experimentTags() []string {
var list []string
// For each experiment that has been enabled in the toolchain, define a
// build tag with the same name but prefixed by "goexperiment." which can be
// used for compiling alternative files for the experiment. This allows
// changes for the experiment, like extra struct fields in the runtime,
// without affecting the base non-experiment code at all.
for _, exp := range Experiment.Enabled() {
list = append(list, "goexperiment."+exp)
}
return list
}
// GOGOARCH returns the name and value of the GO$GOARCH setting.
// For example, if GOARCH is "amd64" it might return "GOAMD64", "v2".
func GOGOARCH() (name, value string) {
switch GOARCH {
case "386":
return "GO386", GO386
case "amd64":
return "GOAMD64", fmt.Sprintf("v%d", GOAMD64)
case "arm":
return "GOARM", fmt.Sprintf("%d", GOARM)
case "mips", "mipsle":
return "GOMIPS", GOMIPS
case "mips64", "mips64le":
return "GOMIPS64", GOMIPS64
case "ppc64", "ppc64le":
return "GOPPC64", fmt.Sprintf("power%d", GOPPC64)
case "wasm":
return "GOWASM", GOWASM.String()
}
return "", ""
}
func gogoarchTags() []string {
switch GOARCH {
case "386":
return []string{GOARCH + "." + GO386}
case "amd64":
var list []string
for i := 1; i <= GOAMD64; i++ {
list = append(list, fmt.Sprintf("%s.v%d", GOARCH, i))
}
return list
case "arm":
var list []string
for i := 5; i <= GOARM; i++ {
list = append(list, fmt.Sprintf("%s.%d", GOARCH, i))
}
return list
case "mips", "mipsle":
return []string{GOARCH + "." + GOMIPS}
case "mips64", "mips64le":
return []string{GOARCH + "." + GOMIPS64}
case "ppc64", "ppc64le":
var list []string
for i := 8; i <= GOPPC64; i++ {
list = append(list, fmt.Sprintf("%s.power%d", GOARCH, i))
}
return list
case "wasm":
var list []string
if GOWASM.SatConv {
list = append(list, GOARCH+".satconv")
}
if GOWASM.SignExt {
list = append(list, GOARCH+".signext")
}
return list
}
return nil
}
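// For example (illustrative): with GOARCH="amd64" and GOAMD64=3, the
// cumulative tags are:
//
// gogoarchTags() // ["amd64.v1", "amd64.v2", "amd64.v3"]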
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package buildcfg
import (
"fmt"
"reflect"
"strings"
"internal/goexperiment"
)
// ExperimentFlags represents a set of GOEXPERIMENT flags relative to a baseline
// (platform-default) experiment configuration.
type ExperimentFlags struct {
goexperiment.Flags
baseline goexperiment.Flags
}
// Experiment contains the toolchain experiments enabled for the
// current build.
//
// (This is not necessarily the set of experiments the compiler itself
// was built with.)
//
// experimentBaseline specifies the experiment flags that are enabled by
// default in the current toolchain. This is, in effect, the "control"
// configuration and any variation from this is an experiment.
var Experiment ExperimentFlags = func() ExperimentFlags {
flags, err := ParseGOEXPERIMENT(GOOS, GOARCH, envOr("GOEXPERIMENT", defaultGOEXPERIMENT))
if err != nil {
Error = err
return ExperimentFlags{}
}
return *flags
}()
// DefaultGOEXPERIMENT is the embedded default GOEXPERIMENT string.
// It is not guaranteed to be canonical.
const DefaultGOEXPERIMENT = defaultGOEXPERIMENT
// FramePointerEnabled enables the use of platform conventions for
// saving frame pointers.
//
// This used to be an experiment, but now it's always enabled on
// platforms that support it.
//
// Note: must agree with runtime.framepointer_enabled.
var FramePointerEnabled = GOARCH == "amd64" || GOARCH == "arm64"
// ParseGOEXPERIMENT parses a (GOOS, GOARCH, GOEXPERIMENT)
// configuration tuple and returns the enabled and baseline experiment
// flag sets.
//
// TODO(mdempsky): Move to internal/goexperiment.
func ParseGOEXPERIMENT(goos, goarch, goexp string) (*ExperimentFlags, error) {
// regabiSupported is set to true on platforms where register ABI is
// supported and enabled by default.
// regabiAlwaysOn is set to true on platforms where register ABI is
// always on.
var regabiSupported, regabiAlwaysOn bool
switch goarch {
case "amd64", "arm64", "ppc64le", "ppc64", "riscv64":
regabiAlwaysOn = true
regabiSupported = true
}
baseline := goexperiment.Flags{
RegabiWrappers: regabiSupported,
RegabiArgs: regabiSupported,
CoverageRedesign: true,
}
// Start with the statically enabled set of experiments.
flags := &ExperimentFlags{
Flags: baseline,
baseline: baseline,
}
// Pick up any changes to the baseline configuration from the
// GOEXPERIMENT environment. This can be set at make.bash time
// and overridden at build time.
if goexp != "" {
// Create a map of known experiment names.
names := make(map[string]func(bool))
rv := reflect.ValueOf(&flags.Flags).Elem()
rt := rv.Type()
for i := 0; i < rt.NumField(); i++ {
field := rv.Field(i)
names[strings.ToLower(rt.Field(i).Name)] = field.SetBool
}
// "regabi" is an alias for all working regabi
// subexperiments, and not an experiment itself. Doing
// this as an alias makes both "regabi" and "noregabi"
// do the right thing.
names["regabi"] = func(v bool) {
flags.RegabiWrappers = v
flags.RegabiArgs = v
}
// Parse names.
for _, f := range strings.Split(goexp, ",") {
if f == "" {
continue
}
if f == "none" {
// GOEXPERIMENT=none disables all experiment flags.
// This is used by cmd/dist, which doesn't know how
// to build with any experiment flags.
flags.Flags = goexperiment.Flags{}
continue
}
val := true
if strings.HasPrefix(f, "no") {
f, val = f[2:], false
}
set, ok := names[f]
if !ok {
return nil, fmt.Errorf("unknown GOEXPERIMENT %s", f)
}
set(val)
}
}
if regabiAlwaysOn {
flags.RegabiWrappers = true
flags.RegabiArgs = true
}
// regabi is only supported on amd64, arm64, riscv64, ppc64 and ppc64le.
if !regabiSupported {
flags.RegabiWrappers = false
flags.RegabiArgs = false
}
// Check regabi dependencies.
if flags.RegabiArgs && !flags.RegabiWrappers {
return nil, fmt.Errorf("GOEXPERIMENT regabiargs requires regabiwrappers")
}
return flags, nil
}
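// Illustrative sketch (not part of the original source; assumes the
// fieldtrack experiment exists in this toolchain's goexperiment.Flags):
//
// flags, _ := ParseGOEXPERIMENT("linux", "amd64", "fieldtrack,noregabiwrappers")
// // flags.FieldTrack == true; RegabiWrappers ends up true anyway,
// // because the register ABI is always on for amd64.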
// String returns the canonical GOEXPERIMENT string to enable this experiment
// configuration. (Experiments in the same state as in the baseline are elided.)
func (exp *ExperimentFlags) String() string {
return strings.Join(expList(&exp.Flags, &exp.baseline, false), ",")
}
// expList returns the list of lower-cased experiment names for
// experiments that differ from base. base may be nil to indicate no
// experiments. If all is true, then include all experiment flags,
// regardless of base.
func expList(exp, base *goexperiment.Flags, all bool) []string {
var list []string
rv := reflect.ValueOf(exp).Elem()
var rBase reflect.Value
if base != nil {
rBase = reflect.ValueOf(base).Elem()
}
rt := rv.Type()
for i := 0; i < rt.NumField(); i++ {
name := strings.ToLower(rt.Field(i).Name)
val := rv.Field(i).Bool()
baseVal := false
if base != nil {
baseVal = rBase.Field(i).Bool()
}
if all || val != baseVal {
if val {
list = append(list, name)
} else {
list = append(list, "no"+name)
}
}
}
return list
}
// Enabled returns a list of enabled experiments, as
// lower-cased experiment names.
func (exp *ExperimentFlags) Enabled() []string {
return expList(&exp.Flags, nil, false)
}
// All returns a list of all experiment settings.
// Disabled experiments appear in the list prefixed by "no".
func (exp *ExperimentFlags) All() []string {
return expList(&exp.Flags, nil, true)
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cformat
// This package provides APIs for producing human-readable summaries
// of coverage data (e.g. a coverage percentage for a given package or
// set of packages) and for writing data in the legacy test format
// emitted by "go test -coverprofile=<outfile>".
//
// The model for using these APIs is to create a Formatter object,
// then make a series of calls to SetPackage and AddUnit passing in
// data read from coverage meta-data and counter-data files. E.g.
//
// myformatter := cformat.NewFormatter()
// ...
// for each package P in meta-data file: {
// myformatter.SetPackage(P)
// for each function F in P: {
// for each coverable unit U in F: {
// myformatter.AddUnit(U)
// }
// }
// }
// myformatter.EmitPercent(os.Stdout, "")
// myformatter.EmitTextual(somefile)
//
// These APIs are linked into tests that are built with "-cover", and
// called at the end of test execution to produce text output or
// emit coverage percentages.
import (
"fmt"
"internal/coverage"
"internal/coverage/cmerge"
"io"
"sort"
"text/tabwriter"
)
type Formatter struct {
// Maps import path to package state.
pm map[string]*pstate
// Records current package being visited.
pkg string
// Pointer to current package state.
p *pstate
// Counter mode.
cm coverage.CounterMode
}
// pstate records package-level coverage data state:
// - a table of functions (file/fname/literal)
// - a map recording the index/ID of each func encountered so far
// - a table storing execution count for the coverable units in each func
type pstate struct {
// slice of unique functions
funcs []fnfile
// maps function to index in slice above (index acts as function ID)
funcTable map[fnfile]uint32
// A table storing coverage counts for each coverable unit.
unitTable map[extcu]uint32
}
// extcu encapsulates a coverable unit within some function.
type extcu struct {
fnfid uint32 // index into p.funcs slice
coverage.CoverableUnit
}
// fnfile is a function-name/file-name tuple.
type fnfile struct {
file string
fname string
lit bool
}
func NewFormatter(cm coverage.CounterMode) *Formatter {
return &Formatter{
pm: make(map[string]*pstate),
cm: cm,
}
}
// SetPackage tells the formatter that we're about to visit the
// coverage data for the package with the specified import path.
// Note that it's OK to call SetPackage more than once with the
// same import path; counter data values will be accumulated.
func (fm *Formatter) SetPackage(importpath string) {
if importpath == fm.pkg {
return
}
fm.pkg = importpath
ps, ok := fm.pm[importpath]
if !ok {
ps = new(pstate)
fm.pm[importpath] = ps
ps.unitTable = make(map[extcu]uint32)
ps.funcTable = make(map[fnfile]uint32)
}
fm.p = ps
}
// AddUnit passes info on a single coverable unit (file, funcname,
// literal flag, range of lines, and counter value) to the formatter.
// Counter values will be accumulated where appropriate.
func (fm *Formatter) AddUnit(file string, fname string, isfnlit bool, unit coverage.CoverableUnit, count uint32) {
if fm.p == nil {
panic("AddUnit invoked before SetPackage")
}
fkey := fnfile{file: file, fname: fname, lit: isfnlit}
idx, ok := fm.p.funcTable[fkey]
if !ok {
idx = uint32(len(fm.p.funcs))
fm.p.funcs = append(fm.p.funcs, fkey)
fm.p.funcTable[fkey] = idx
}
ukey := extcu{fnfid: idx, CoverableUnit: unit}
pcount := fm.p.unitTable[ukey]
var result uint32
if fm.cm == coverage.CtrModeSet {
if count != 0 || pcount != 0 {
result = 1
}
} else {
// Use saturating arithmetic.
result, _ = cmerge.SaturatingAdd(pcount, count)
}
fm.p.unitTable[ukey] = result
}
// sortUnits sorts a slice of extcu objects in a package according to
// source position information (e.g. file and line). Note that we don't
// include function name as part of the sorting criteria, the thinking
// being that it is better to provide things in the original source order.
func (p *pstate) sortUnits(units []extcu) {
sort.Slice(units, func(i, j int) bool {
ui := units[i]
uj := units[j]
ifile := p.funcs[ui.fnfid].file
jfile := p.funcs[uj.fnfid].file
if ifile != jfile {
return ifile < jfile
}
// NB: not taking function literal flag into account here (no
// need, since other fields are guaranteed to be distinct).
if units[i].StLine != units[j].StLine {
return units[i].StLine < units[j].StLine
}
if units[i].EnLine != units[j].EnLine {
return units[i].EnLine < units[j].EnLine
}
if units[i].StCol != units[j].StCol {
return units[i].StCol < units[j].StCol
}
if units[i].EnCol != units[j].EnCol {
return units[i].EnCol < units[j].EnCol
}
return units[i].NxStmts < units[j].NxStmts
})
}
// EmitTextual writes the accumulated coverage data in the legacy
// cmd/cover text format to the writer 'w'. We sort the data items by
// importpath, source file, and line number before emitting (this sorting
// is not explicitly mandated by the format, but seems like a good idea
// for repeatable/deterministic dumps).
func (fm *Formatter) EmitTextual(w io.Writer) error {
if fm.cm == coverage.CtrModeInvalid {
panic("internal error, counter mode unset")
}
if _, err := fmt.Fprintf(w, "mode: %s\n", fm.cm.String()); err != nil {
return err
}
pkgs := make([]string, 0, len(fm.pm))
for importpath := range fm.pm {
pkgs = append(pkgs, importpath)
}
sort.Strings(pkgs)
for _, importpath := range pkgs {
p := fm.pm[importpath]
units := make([]extcu, 0, len(p.unitTable))
for u := range p.unitTable {
units = append(units, u)
}
p.sortUnits(units)
for _, u := range units {
count := p.unitTable[u]
file := p.funcs[u.fnfid].file
if _, err := fmt.Fprintf(w, "%s:%d.%d,%d.%d %d %d\n",
file, u.StLine, u.StCol,
u.EnLine, u.EnCol, u.NxStmts, count); err != nil {
return err
}
}
}
return nil
}
// EmitPercent writes out a "percentage covered" string to the writer 'w'.
func (fm *Formatter) EmitPercent(w io.Writer, covpkgs string, noteEmpty bool) error {
pkgs := make([]string, 0, len(fm.pm))
for importpath := range fm.pm {
pkgs = append(pkgs, importpath)
}
sort.Strings(pkgs)
seenPkg := false
for _, importpath := range pkgs {
seenPkg = true
p := fm.pm[importpath]
var totalStmts, coveredStmts uint64
for unit, count := range p.unitTable {
nx := uint64(unit.NxStmts)
totalStmts += nx
if count != 0 {
coveredStmts += nx
}
}
if _, err := fmt.Fprintf(w, "\t%s\t", importpath); err != nil {
return err
}
if totalStmts == 0 {
if _, err := fmt.Fprintf(w, "coverage: [no statements]\n"); err != nil {
return err
}
} else {
if _, err := fmt.Fprintf(w, "coverage: %.1f%% of statements%s\n", 100*float64(coveredStmts)/float64(totalStmts), covpkgs); err != nil {
return err
}
}
}
if noteEmpty && !seenPkg {
if _, err := fmt.Fprintf(w, "coverage: [no statements]\n"); err != nil {
return err
}
}
return nil
}
// EmitFuncs writes out a function-level summary to the writer 'w'. A
// note on handling function literals: although we collect coverage
// data for unnamed literals, it probably does not make sense to
// include them in the function summary since there isn't any good way
// to name them (this is also consistent with the legacy cmd/cover
// implementation). We do want to include their counts in the overall
// summary, however.
func (fm *Formatter) EmitFuncs(w io.Writer) error {
if fm.cm == coverage.CtrModeInvalid {
panic("internal error, counter mode unset")
}
perc := func(covered, total uint64) float64 {
if total == 0 {
total = 1
}
return 100.0 * float64(covered) / float64(total)
}
tabber := tabwriter.NewWriter(w, 1, 8, 1, '\t', 0)
defer tabber.Flush()
allStmts := uint64(0)
covStmts := uint64(0)
pkgs := make([]string, 0, len(fm.pm))
for importpath := range fm.pm {
pkgs = append(pkgs, importpath)
}
sort.Strings(pkgs)
// Emit functions for each package, sorted by import path.
for _, importpath := range pkgs {
p := fm.pm[importpath]
if len(p.unitTable) == 0 {
continue
}
units := make([]extcu, 0, len(p.unitTable))
for u := range p.unitTable {
units = append(units, u)
}
// Within a package, sort the units, then walk through the
// sorted array. Each time we hit a new function, emit the
// summary entry for the previous function, then make one last
// emit call at the end of the loop.
p.sortUnits(units)
fname := ""
ffile := ""
flit := false
var fline uint32
var cstmts, tstmts uint64
captureFuncStart := func(u extcu) {
fname = p.funcs[u.fnfid].fname
ffile = p.funcs[u.fnfid].file
flit = p.funcs[u.fnfid].lit
fline = u.StLine
}
emitFunc := func(u extcu) error {
// Don't emit entries for function literals (see discussion
// in function header comment above).
if !flit {
if _, err := fmt.Fprintf(tabber, "%s:%d:\t%s\t%.1f%%\n",
ffile, fline, fname, perc(cstmts, tstmts)); err != nil {
return err
}
}
captureFuncStart(u)
allStmts += tstmts
covStmts += cstmts
tstmts = 0
cstmts = 0
return nil
}
for k, u := range units {
if k == 0 {
captureFuncStart(u)
} else {
if fname != p.funcs[u.fnfid].fname {
// New function; emit entry for previous one.
if err := emitFunc(u); err != nil {
return err
}
}
}
tstmts += uint64(u.NxStmts)
count := p.unitTable[u]
if count != 0 {
cstmts += uint64(u.NxStmts)
}
}
if err := emitFunc(extcu{}); err != nil {
return err
}
}
if _, err := fmt.Fprintf(tabber, "%s\t%s\t%.1f%%\n",
"total", "(statements)", perc(covStmts, allStmts)); err != nil {
return err
}
return nil
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cmerge
// Package cmerge provides a few small utility APIs for helping
// with merging of counter data for a given function.
import (
"fmt"
"internal/coverage"
"math"
)
// Merger provides state and methods to help manage the process of
// merging together coverage counter data for a given function, for
// tools that need to implicitly merge counter data as they read multiple
// coverage counter data files.
type Merger struct {
cmode coverage.CounterMode
cgran coverage.CounterGranularity
overflow bool
}
// MergeCounters takes the counter values in 'src' and merges them
// into 'dst' according to the correct counter mode.
func (m *Merger) MergeCounters(dst, src []uint32) (error, bool) {
if len(src) != len(dst) {
return fmt.Errorf("merging counters: len(dst)=%d len(src)=%d", len(dst), len(src)), false
}
if m.cmode == coverage.CtrModeSet {
for i := 0; i < len(src); i++ {
if src[i] != 0 {
dst[i] = 1
}
}
} else {
for i := 0; i < len(src); i++ {
dst[i] = m.SaturatingAdd(dst[i], src[i])
}
}
ovf := m.overflow
m.overflow = false
return nil, ovf
}
// SaturatingAdd does a saturating addition of 'dst' and 'src',
// returning the added value, or math.MaxUint32 if there is an overflow.
// Overflows are recorded in case the client needs to track them.
func (m *Merger) SaturatingAdd(dst, src uint32) uint32 {
result, overflow := SaturatingAdd(dst, src)
if overflow {
m.overflow = true
}
return result
}
// SaturatingAdd does a saturating addition of 'dst' and 'src',
// returning the added value (or math.MaxUint32 on overflow) plus an overflow flag.
func SaturatingAdd(dst, src uint32) (uint32, bool) {
d, s := uint64(dst), uint64(src)
sum := d + s
overflow := false
if uint64(uint32(sum)) != sum {
overflow = true
sum = math.MaxUint32
}
return uint32(sum), overflow
}
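// For example (illustrative):
//
// SaturatingAdd(1, 2)              // (3, false)
// SaturatingAdd(math.MaxUint32, 1) // (math.MaxUint32, true)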
// SetModeAndGranularity records the counter mode and granularity for
// the current merge. In the specific case of merging across coverage
// data files from different binaries, where we're combining data from
// more than one meta-data file, we need to check for mode/granularity
// clashes.
func (cm *Merger) SetModeAndGranularity(mdf string, cmode coverage.CounterMode, cgran coverage.CounterGranularity) error {
// Collect counter mode and granularity so as to detect clashes.
if cm.cmode != coverage.CtrModeInvalid {
if cm.cmode != cmode {
return fmt.Errorf("counter mode clash while reading meta-data file %s: previous file had %s, new file has %s", mdf, cm.cmode.String(), cmode.String())
}
if cm.cgran != cgran {
return fmt.Errorf("counter granularity clash while reading meta-data file %s: previous file had %s, new file has %s", mdf, cm.cgran.String(), cgran.String())
}
}
cm.cmode = cmode
cm.cgran = cgran
return nil
}
func (cm *Merger) ResetModeAndGranularity() {
cm.cmode = coverage.CtrModeInvalid
cm.cgran = coverage.CtrGranularityInvalid
cm.overflow = false
}
func (cm *Merger) Mode() coverage.CounterMode {
return cm.cmode
}
func (cm *Merger) Granularity() coverage.CounterGranularity {
return cm.cgran
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pods
import (
"fmt"
"internal/coverage"
"os"
"path/filepath"
"regexp"
"sort"
"strconv"
)
// Pod encapsulates a set of files emitted during the executions of a
// coverage-instrumented binary. Each pod contains a single meta-data
// file, and then 0 or more counter data files that refer to that
// meta-data file. Pods are intended to simplify processing of
// coverage output files in the case where we have several coverage
// output directories containing output files derived from more
// than one instrumented executable. In the case where the files that
// make up a pod are spread out across multiple directories, each
// element of the "Origins" field below will be populated with the
// index of the originating directory for the corresponding counter
// data file (within the slice of input dirs handed to CollectPods).
// The ProcessIDs field will be populated with the process ID of each
// data file in the CounterDataFiles slice.
type Pod struct {
MetaFile string
CounterDataFiles []string
Origins []int
ProcessIDs []int
}
// CollectPods visits the files contained within the directories in
// the list 'dirs', collects any coverage-related files, partitions
// them into pods, and returns a list of the pods to the caller, along
// with an error if something went wrong during directory/file
// reading.
//
// CollectPods skips over any file that is not related to coverage
// (e.g. avoids looking at things that are not meta-data files or
// counter-data files). CollectPods also skips over 'orphaned' counter
// data files (e.g. counter data files for which we can't find the
// corresponding meta-data file). If "warn" is true, CollectPods will
// issue warnings to stderr when it encounters non-fatal problems (for
// orphans or a directory with no meta-data files).
func CollectPods(dirs []string, warn bool) ([]Pod, error) {
files := []string{}
dirIndices := []int{}
for k, dir := range dirs {
dents, err := os.ReadDir(dir)
if err != nil {
return nil, err
}
for _, e := range dents {
if e.IsDir() {
continue
}
files = append(files, filepath.Join(dir, e.Name()))
dirIndices = append(dirIndices, k)
}
}
return collectPodsImpl(files, dirIndices, warn), nil
}
// CollectPodsFromFiles functions the same as "CollectPods" but
// operates on an explicit list of files instead of a directory.
func CollectPodsFromFiles(files []string, warn bool) []Pod {
return collectPodsImpl(files, nil, warn)
}
type fileWithAnnotations struct {
file string
origin int
pid int
}
type protoPod struct {
mf string
elements []fileWithAnnotations
}
// collectPodsImpl examines the specified list of files and picks out
// subsets that correspond to coverage pods. The first stage in this
// process is collecting a set { M1, M2, ... MN } where each M_k is a
// distinct coverage meta-data file. We then create a single pod for
// each meta-data file M_k, then find all of the counter data files
// that refer to that meta-data file (recall that the counter data
// file name incorporates the meta-data hash), and add the counter
// data file to the appropriate pod.
//
// This process is complicated by the fact that we need to keep track
// of directory indices for counter data files. Here is an example to
// motivate:
//
// directory 1:
//
// M1 covmeta.9bbf1777f47b3fcacb05c38b035512d6
// C1 covcounters.9bbf1777f47b3fcacb05c38b035512d6.1677673.1662138360208416486
// C2 covcounters.9bbf1777f47b3fcacb05c38b035512d6.1677637.1662138359974441782
//
// directory 2:
//
// M2 covmeta.9bbf1777f47b3fcacb05c38b035512d6
// C3 covcounters.9bbf1777f47b3fcacb05c38b035512d6.1677445.1662138360208416480
// C4 covcounters.9bbf1777f47b3fcacb05c38b035512d6.1677677.1662138359974441781
// M3 covmeta.a723844208cea2ae80c63482c78b2245
// C5 covcounters.a723844208cea2ae80c63482c78b2245.3677445.1662138360208416480
// C6 covcounters.a723844208cea2ae80c63482c78b2245.1877677.1662138359974441781
//
// In these two directories we have three meta-data files, but only
// two are distinct, meaning that we'll wind up with two pods. The
// first pod (with meta-file M1) will have four counter data files
// (C1, C2, C3, C4) and the second pod will have two counter data files
// (C5, C6).
func collectPodsImpl(files []string, dirIndices []int, warn bool) []Pod {
metaRE := regexp.MustCompile(fmt.Sprintf(`^%s\.(\S+)$`, coverage.MetaFilePref))
mm := make(map[string]protoPod)
for _, f := range files {
base := filepath.Base(f)
if m := metaRE.FindStringSubmatch(base); m != nil {
tag := m[1]
// We need to allow for the possibility of duplicate
// meta-data files. If we hit this case, use the
// first encountered as the canonical version.
if _, ok := mm[tag]; !ok {
mm[tag] = protoPod{mf: f}
}
// FIXME: should probably check file length and hash here for
// the duplicate.
}
}
counterRE := regexp.MustCompile(fmt.Sprintf(coverage.CounterFileRegexp, coverage.CounterFilePref))
for k, f := range files {
base := filepath.Base(f)
if m := counterRE.FindStringSubmatch(base); m != nil {
tag := m[1] // meta hash
pid, err := strconv.Atoi(m[2])
if err != nil {
continue
}
if v, ok := mm[tag]; ok {
idx := -1
if dirIndices != nil {
idx = dirIndices[k]
}
fo := fileWithAnnotations{file: f, origin: idx, pid: pid}
v.elements = append(v.elements, fo)
mm[tag] = v
} else {
if warn {
warning("skipping orphaned counter file: %s", f)
}
}
}
}
if len(mm) == 0 {
if warn {
warning("no coverage data files found")
}
return nil
}
pods := make([]Pod, 0, len(mm))
for _, p := range mm {
sort.Slice(p.elements, func(i, j int) bool {
return p.elements[i].file < p.elements[j].file
})
pod := Pod{
MetaFile: p.mf,
CounterDataFiles: make([]string, 0, len(p.elements)),
Origins: make([]int, 0, len(p.elements)),
ProcessIDs: make([]int, 0, len(p.elements)),
}
for _, e := range p.elements {
pod.CounterDataFiles = append(pod.CounterDataFiles, e.file)
pod.Origins = append(pod.Origins, e.origin)
pod.ProcessIDs = append(pod.ProcessIDs, e.pid)
}
pods = append(pods, pod)
}
sort.Slice(pods, func(i, j int) bool {
return pods[i].MetaFile < pods[j].MetaFile
})
return pods
}
func warning(s string, a ...interface{}) {
fmt.Fprintf(os.Stderr, "warning: ")
fmt.Fprintf(os.Stderr, s, a...)
fmt.Fprintf(os.Stderr, "\n")
}
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package slicereader
import (
"encoding/binary"
"unsafe"
)
// This file contains the helper "SliceReader", a utility for
// reading values from a byte slice that may or may not be backed
// by a read-only mmap'd region.
type Reader struct {
b []byte
readonly bool
off int64
}
func NewReader(b []byte, readonly bool) *Reader {
r := Reader{
b: b,
readonly: readonly,
}
return &r
}
func (r *Reader) Read(b []byte) (int, error) {
amt := len(b)
toread := r.b[r.off:]
if len(toread) < amt {
amt = len(toread)
}
copy(b, toread)
r.off += int64(amt)
return amt, nil
}
func (r *Reader) SeekTo(off int64) {
r.off = off
}
func (r *Reader) Offset() int64 {
return r.off
}
func (r *Reader) ReadUint8() uint8 {
rv := uint8(r.b[int(r.off)])
r.off += 1
return rv
}
func (r *Reader) ReadUint32() uint32 {
end := int(r.off) + 4
rv := binary.LittleEndian.Uint32(r.b[int(r.off):end:end])
r.off += 4
return rv
}
func (r *Reader) ReadUint64() uint64 {
end := int(r.off) + 8
rv := binary.LittleEndian.Uint64(r.b[int(r.off):end:end])
r.off += 8
return rv
}
func (r *Reader) ReadULEB128() (value uint64) {
var shift uint
for {
b := r.b[r.off]
r.off++
value |= (uint64(b&0x7F) << shift)
if b&0x80 == 0 {
break
}
shift += 7
}
return
}
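// For example (illustrative): the byte sequence 0xE5 0x8E 0x26 decodes to
// 0x65 | 0x0E<<7 | 0x26<<14 == 624485.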
func (r *Reader) ReadString(len int64) string {
b := r.b[r.off : r.off+len]
r.off += len
if r.readonly {
return toString(b) // backed by RO memory, ok to make unsafe string
}
return string(b)
}
func toString(b []byte) string {
if len(b) == 0 {
return ""
}
return unsafe.String(&b[0], len(b))
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package slicewriter
import (
"fmt"
"io"
)
// WriteSeeker is a helper object that implements the io.WriteSeeker
// interface. Clients can create a WriteSeeker, make a series of Write
// calls to add data to it (and possibly Seek calls to update
// previously written portions), then finally invoke BytesWritten() to
// retrieve the constructed byte slice.
type WriteSeeker struct {
payload []byte
off int64
}
func (sws *WriteSeeker) Write(p []byte) (n int, err error) {
amt := len(p)
towrite := sws.payload[sws.off:]
if len(towrite) < amt {
sws.payload = append(sws.payload, make([]byte, amt-len(towrite))...)
towrite = sws.payload[sws.off:]
}
copy(towrite, p)
sws.off += int64(amt)
return amt, nil
}
// Seek repositions the read/write position of the WriteSeeker within
// its internally maintained slice. Note that it is not possible to
// expand the size of the slice using SEEK_SET; trying to seek outside
// the slice will result in an error.
func (sws *WriteSeeker) Seek(offset int64, whence int) (int64, error) {
switch whence {
case io.SeekStart:
if sws.off != offset && (offset < 0 || offset >= int64(len(sws.payload))) {
return 0, fmt.Errorf("invalid seek: new offset %d (out of range [0 %d]", offset, len(sws.payload))
}
sws.off = offset
return offset, nil
case io.SeekCurrent:
newoff := sws.off + offset
if newoff != sws.off && (newoff < 0 || newoff >= int64(len(sws.payload))) {
return 0, fmt.Errorf("invalid seek: new offset %d (out of range [0 %d]", newoff, len(sws.payload))
}
sws.off += offset
return sws.off, nil
case io.SeekEnd:
newoff := int64(len(sws.payload)) + offset
if newoff != sws.off && (newoff < 0 || newoff >= int64(len(sws.payload))) {
return 0, fmt.Errorf("invalid seek: new offset %d (out of range [0 %d]", newoff, len(sws.payload))
}
sws.off = newoff
return sws.off, nil
}
// other modes not supported
return 0, fmt.Errorf("unsupported seek mode %d", whence)
}
// BytesWritten returns the underlying byte slice for the WriteSeeker,
// containing the data written to it via Write/Seek calls.
func (sws *WriteSeeker) BytesWritten() []byte {
return sws.payload
}
func (sws *WriteSeeker) Read(p []byte) (n int, err error) {
amt := len(p)
toread := sws.payload[sws.off:]
if len(toread) < amt {
amt = len(toread)
}
copy(p, toread)
sws.off += int64(amt)
return amt, nil
}
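// Illustrative usage sketch (not part of the original source):
//
// var sws WriteSeeker
// sws.Write([]byte{1, 2, 3, 4})
// sws.Seek(0, io.SeekStart) // rewind
// sws.Write([]byte{9})      // overwrite the first byte
// sws.BytesWritten()        // [9 2 3 4]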
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package cpu implements processor feature detection
// used by the Go standard library.
package cpu
// DebugOptions is set to true by the runtime if the OS supports reading
// GODEBUG early in runtime startup.
// This should not be changed after it is initialized.
var DebugOptions bool
// CacheLinePad is used to pad structs to avoid false sharing.
type CacheLinePad struct{ _ [CacheLinePadSize]byte }
// CacheLineSize is the CPU's assumed cache line size.
// There is currently no runtime detection of the real cache line size
// so we use the constant per GOARCH CacheLinePadSize as an approximation.
var CacheLineSize uintptr = CacheLinePadSize
// The booleans in X86 contain the correspondingly named cpuid feature bit.
// HasAVX and HasAVX2 are only set if the OS does support XMM and YMM registers
// in addition to the cpuid feature bit being set.
// The struct is padded to avoid false sharing.
var X86 struct {
_ CacheLinePad
HasAES bool
HasADX bool
HasAVX bool
HasAVX2 bool
HasBMI1 bool
HasBMI2 bool
HasERMS bool
HasFMA bool
HasOSXSAVE bool
HasPCLMULQDQ bool
HasPOPCNT bool
HasRDTSCP bool
HasSHA bool
HasSSE3 bool
HasSSSE3 bool
HasSSE41 bool
HasSSE42 bool
_ CacheLinePad
}
// The booleans in ARM contain the correspondingly named cpu feature bit.
// The struct is padded to avoid false sharing.
var ARM struct {
_ CacheLinePad
HasVFPv4 bool
HasIDIVA bool
_ CacheLinePad
}
// The booleans in ARM64 contain the correspondingly named cpu feature bit.
// The struct is padded to avoid false sharing.
var ARM64 struct {
_ CacheLinePad
HasAES bool
HasPMULL bool
HasSHA1 bool
HasSHA2 bool
HasSHA512 bool
HasCRC32 bool
HasATOMICS bool
HasCPUID bool
IsNeoverseN1 bool
IsNeoverseV1 bool
_ CacheLinePad
}
var MIPS64X struct {
_ CacheLinePad
HasMSA bool // MIPS SIMD architecture
_ CacheLinePad
}
// For ppc64(le), it is safe to check only for ISA level starting on ISA v3.00,
// since there are no optional categories. There are some exceptions that also
// require kernel support to work (darn, scv), so there are feature bits for
// those as well. The minimum processor requirement is POWER8 (ISA 2.07).
// The struct is padded to avoid false sharing.
var PPC64 struct {
_ CacheLinePad
HasDARN bool // Hardware random number generator (requires kernel enablement)
HasSCV bool // Syscall vectored (requires kernel enablement)
IsPOWER8 bool // ISA v2.07 (POWER8)
IsPOWER9 bool // ISA v3.00 (POWER9)
IsPOWER10 bool // ISA v3.1 (POWER10)
_ CacheLinePad
}
var S390X struct {
_ CacheLinePad
HasZARCH bool // z architecture mode is active [mandatory]
HasSTFLE bool // store facility list extended [mandatory]
HasLDISP bool // long (20-bit) displacements [mandatory]
HasEIMM bool // 32-bit immediates [mandatory]
HasDFP bool // decimal floating point
HasETF3EH bool // ETF-3 enhanced
HasMSA bool // message security assist (CPACF)
HasAES bool // KM-AES{128,192,256} functions
HasAESCBC bool // KMC-AES{128,192,256} functions
HasAESCTR bool // KMCTR-AES{128,192,256} functions
HasAESGCM bool // KMA-GCM-AES{128,192,256} functions
HasGHASH bool // KIMD-GHASH function
HasSHA1 bool // K{I,L}MD-SHA-1 functions
HasSHA256 bool // K{I,L}MD-SHA-256 functions
HasSHA512 bool // K{I,L}MD-SHA-512 functions
HasSHA3 bool // K{I,L}MD-SHA3-{224,256,384,512} and K{I,L}MD-SHAKE-{128,256} functions
HasVX bool // vector facility. Note: the runtime sets this when it processes auxv records.
HasVXE bool // vector-enhancements facility 1
HasKDSA bool // elliptic curve functions
HasECDSA bool // NIST curves
HasEDDSA bool // Edwards curves
_ CacheLinePad
}
// Initialize examines the processor and sets the relevant variables above.
// This is called by the runtime package early in program initialization,
// before normal init functions are run. env is set by runtime if the OS supports
// cpu feature options in GODEBUG.
func Initialize(env string) {
doinit()
processOptions(env)
}
// options contains the cpu debug options that can be used in GODEBUG.
// Options are arch dependent and are added by the arch specific doinit functions.
// Features that are mandatory for the specific GOARCH should not be added to options
// (e.g. SSE2 on amd64).
var options []option
// Option names should be lower case, e.g. avx instead of AVX.
type option struct {
Name string
Feature *bool
Specified bool // whether feature value was specified in GODEBUG
Enable bool // whether feature should be enabled
}
// processOptions enables or disables CPU feature values based on the parsed env string.
// The env string is expected to be of the form cpu.feature1=value1,cpu.feature2=value2...
// where each feature name is one of the architecture-specific entries stored in the
// cpu package's options variable and each value is either 'on' or 'off'.
// If env contains cpu.all=off then all cpu features referenced through the options
// variable are disabled. Other feature names and values result in warning messages.
func processOptions(env string) {
field:
for env != "" {
field := ""
i := indexByte(env, ',')
if i < 0 {
field, env = env, ""
} else {
field, env = env[:i], env[i+1:]
}
if len(field) < 4 || field[:4] != "cpu." {
continue
}
i = indexByte(field, '=')
if i < 0 {
print("GODEBUG: no value specified for \"", field, "\"\n")
continue
}
key, value := field[4:i], field[i+1:] // e.g. "SSE2", "on"
var enable bool
switch value {
case "on":
enable = true
case "off":
enable = false
default:
print("GODEBUG: value \"", value, "\" not supported for cpu option \"", key, "\"\n")
continue field
}
if key == "all" {
for i := range options {
options[i].Specified = true
options[i].Enable = enable
}
continue field
}
for i := range options {
if options[i].Name == key {
options[i].Specified = true
options[i].Enable = enable
continue field
}
}
print("GODEBUG: unknown cpu feature \"", key, "\"\n")
}
for _, o := range options {
if !o.Specified {
continue
}
if o.Enable && !*o.Feature {
print("GODEBUG: can not enable \"", o.Name, "\", missing CPU support\n")
continue
}
*o.Feature = o.Enable
}
}
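// For example (illustrative): GODEBUG="cpu.avx=off,cpu.sse3=on" marks the
// avx option as specified-and-disabled and the sse3 option as
// specified-and-enabled; sse3 then remains enabled only if the CPU
// actually supports it.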
// indexByte returns the index of the first instance of c in s,
// or -1 if c is not present in s.
func indexByte(s string, c byte) int {
for i := 0; i < len(s); i++ {
if s[i] == c {
return i
}
}
return -1
}
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build 386 || amd64
package cpu
const CacheLinePadSize = 64
// cpuid is implemented in cpu_x86.s.
func cpuid(eaxArg, ecxArg uint32) (eax, ebx, ecx, edx uint32)
// xgetbv with ecx = 0 is implemented in cpu_x86.s.
func xgetbv() (eax, edx uint32)
// getGOAMD64level is implemented in cpu_x86.s. Returns number in [1,4].
func getGOAMD64level() int32
const (
// edx bits
cpuid_SSE2 = 1 << 26
// ecx bits
cpuid_SSE3 = 1 << 0
cpuid_PCLMULQDQ = 1 << 1
cpuid_SSSE3 = 1 << 9
cpuid_FMA = 1 << 12
cpuid_SSE41 = 1 << 19
cpuid_SSE42 = 1 << 20
cpuid_POPCNT = 1 << 23
cpuid_AES = 1 << 25
cpuid_OSXSAVE = 1 << 27
cpuid_AVX = 1 << 28
// ebx bits
cpuid_BMI1 = 1 << 3
cpuid_AVX2 = 1 << 5
cpuid_BMI2 = 1 << 8
cpuid_ERMS = 1 << 9
cpuid_ADX = 1 << 19
cpuid_SHA = 1 << 29
// edx bits for CPUID 0x80000001
cpuid_RDTSCP = 1 << 27
)
var maxExtendedFunctionInformation uint32
func doinit() {
options = []option{
{Name: "adx", Feature: &X86.HasADX},
{Name: "aes", Feature: &X86.HasAES},
{Name: "erms", Feature: &X86.HasERMS},
{Name: "pclmulqdq", Feature: &X86.HasPCLMULQDQ},
{Name: "rdtscp", Feature: &X86.HasRDTSCP},
{Name: "sha", Feature: &X86.HasSHA},
}
level := getGOAMD64level()
if level < 2 {
// These options are required at level 2. At lower levels
// they can be turned off.
options = append(options,
option{Name: "popcnt", Feature: &X86.HasPOPCNT},
option{Name: "sse3", Feature: &X86.HasSSE3},
option{Name: "sse41", Feature: &X86.HasSSE41},
option{Name: "sse42", Feature: &X86.HasSSE42},
option{Name: "ssse3", Feature: &X86.HasSSSE3})
}
if level < 3 {
// These options are required at level 3. At lower levels
// they can be turned off.
options = append(options,
option{Name: "avx", Feature: &X86.HasAVX},
option{Name: "avx2", Feature: &X86.HasAVX2},
option{Name: "bmi1", Feature: &X86.HasBMI1},
option{Name: "bmi2", Feature: &X86.HasBMI2},
option{Name: "fma", Feature: &X86.HasFMA})
}
maxID, _, _, _ := cpuid(0, 0)
if maxID < 1 {
return
}
maxExtendedFunctionInformation, _, _, _ = cpuid(0x80000000, 0)
_, _, ecx1, _ := cpuid(1, 0)
X86.HasSSE3 = isSet(ecx1, cpuid_SSE3)
X86.HasPCLMULQDQ = isSet(ecx1, cpuid_PCLMULQDQ)
X86.HasSSSE3 = isSet(ecx1, cpuid_SSSE3)
X86.HasSSE41 = isSet(ecx1, cpuid_SSE41)
X86.HasSSE42 = isSet(ecx1, cpuid_SSE42)
X86.HasPOPCNT = isSet(ecx1, cpuid_POPCNT)
X86.HasAES = isSet(ecx1, cpuid_AES)
// OSXSAVE can be false when using older operating systems
// or when explicitly disabled on newer operating systems by
// e.g. setting the xsavedisable boot option on Windows 10.
X86.HasOSXSAVE = isSet(ecx1, cpuid_OSXSAVE)
// The FMA instruction set extension only has VEX prefixed instructions.
// VEX prefixed instructions require OSXSAVE to be enabled.
// See Intel 64 and IA-32 Architecture Software Developer’s Manual Volume 2
// Section 2.4 "AVX and SSE Instruction Exception Specification"
X86.HasFMA = isSet(ecx1, cpuid_FMA) && X86.HasOSXSAVE
osSupportsAVX := false
// For XGETBV, OSXSAVE bit is required and sufficient.
if X86.HasOSXSAVE {
eax, _ := xgetbv()
// Check if XMM and YMM registers have OS support.
osSupportsAVX = isSet(eax, 1<<1) && isSet(eax, 1<<2)
}
X86.HasAVX = isSet(ecx1, cpuid_AVX) && osSupportsAVX
if maxID < 7 {
return
}
_, ebx7, _, _ := cpuid(7, 0)
X86.HasBMI1 = isSet(ebx7, cpuid_BMI1)
X86.HasAVX2 = isSet(ebx7, cpuid_AVX2) && osSupportsAVX
X86.HasBMI2 = isSet(ebx7, cpuid_BMI2)
X86.HasERMS = isSet(ebx7, cpuid_ERMS)
X86.HasADX = isSet(ebx7, cpuid_ADX)
X86.HasSHA = isSet(ebx7, cpuid_SHA)
var maxExtendedInformation uint32
maxExtendedInformation, _, _, _ = cpuid(0x80000000, 0)
if maxExtendedInformation < 0x80000001 {
return
}
_, _, _, edxExt1 := cpuid(0x80000001, 0)
X86.HasRDTSCP = isSet(edxExt1, cpuid_RDTSCP)
}
func isSet(hwc uint32, value uint32) bool {
return hwc&value != 0
}
// Name returns the CPU name given by the vendor.
// If the CPU name cannot be determined, an
// empty string is returned.
func Name() string {
if maxExtendedFunctionInformation < 0x80000004 {
return ""
}
data := make([]byte, 0, 3*4*4)
var eax, ebx, ecx, edx uint32
eax, ebx, ecx, edx = cpuid(0x80000002, 0)
data = appendBytes(data, eax, ebx, ecx, edx)
eax, ebx, ecx, edx = cpuid(0x80000003, 0)
data = appendBytes(data, eax, ebx, ecx, edx)
eax, ebx, ecx, edx = cpuid(0x80000004, 0)
data = appendBytes(data, eax, ebx, ecx, edx)
// Trim leading spaces.
for len(data) > 0 && data[0] == ' ' {
data = data[1:]
}
// Trim tail after and including the first null byte.
for i, c := range data {
if c == '\x00' {
data = data[:i]
break
}
}
return string(data)
}
func appendBytes(b []byte, args ...uint32) []byte {
for _, arg := range args {
b = append(b,
byte((arg >> 0)),
byte((arg >> 8)),
byte((arg >> 16)),
byte((arg >> 24)))
}
return b
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dag
// Transpose reverses all edges in g.
func (g *Graph) Transpose() {
old := g.edges
g.edges = make(map[string]map[string]bool)
for _, n := range g.Nodes {
g.edges[n] = make(map[string]bool)
}
for from, tos := range old {
for to := range tos {
g.edges[to][from] = true
}
}
}
// Topo returns a topological sort of g. This function is deterministic.
func (g *Graph) Topo() []string {
topo := make([]string, 0, len(g.Nodes))
marks := make(map[string]bool)
var visit func(n string)
visit = func(n string) {
if marks[n] {
return
}
for _, to := range g.Edges(n) {
visit(to)
}
marks[n] = true
topo = append(topo, n)
}
for _, root := range g.Nodes {
visit(root)
}
for i, j := 0, len(topo)-1; i < j; i, j = i+1, j-1 {
topo[i], topo[j] = topo[j], topo[i]
}
return topo
}
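// Illustrative sketch (not part of the original source): for edges
// b -> a and c -> b, every node precedes the nodes it points to.
//
// g := newGraph()
// g.addNode("a"); g.addNode("b"); g.addNode("c")
// g.AddEdge("b", "a")
// g.AddEdge("c", "b")
// g.Topo() // ["c", "b", "a"]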
// TransitiveReduction removes edges from g that are transitively
// reachable. g must be transitively closed.
func (g *Graph) TransitiveReduction() {
// For i -> j -> k, if i -> k exists, delete it.
for _, i := range g.Nodes {
for _, j := range g.Nodes {
if g.HasEdge(i, j) {
for _, k := range g.Nodes {
if g.HasEdge(j, k) {
g.DelEdge(i, k)
}
}
}
}
}
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package dag implements a language for expressing directed acyclic
// graphs.
//
// The general syntax of a rule is:
//
// a, b < c, d;
//
// which means c and d come after a and b in the partial order
// (that is, there are edges from c and d to a and b),
// but doesn't provide a relative order between a vs b or c vs d.
//
// The rules can chain together, as in:
//
// e < f, g < h;
//
// which is equivalent to
//
// e < f, g;
// f, g < h;
//
// Except for the special bottom element "NONE", each name
// must appear exactly once on the right-hand side of any rule.
// That rule serves as the definition of the allowed successor
// for that name. The definition must appear before any uses
// of the name on the left-hand side of a rule. (That is, the
// rules themselves must be ordered according to the partial
// order, for easier reading by people.)
//
// Negative assertions double-check the partial order:
//
// i !< j
//
// means that it must NOT be the case that i < j.
// Negative assertions may appear anywhere in the rules,
// even before i and j have been defined.
//
// Comments begin with #.
package dag
import (
"fmt"
"sort"
"strings"
)
type Graph struct {
Nodes []string
byLabel map[string]int
edges map[string]map[string]bool
}
func newGraph() *Graph {
return &Graph{byLabel: map[string]int{}, edges: map[string]map[string]bool{}}
}
func (g *Graph) addNode(label string) bool {
if _, ok := g.byLabel[label]; ok {
return false
}
g.byLabel[label] = len(g.Nodes)
g.Nodes = append(g.Nodes, label)
g.edges[label] = map[string]bool{}
return true
}
func (g *Graph) AddEdge(from, to string) {
g.edges[from][to] = true
}
func (g *Graph) DelEdge(from, to string) {
delete(g.edges[from], to)
}
func (g *Graph) HasEdge(from, to string) bool {
return g.edges[from] != nil && g.edges[from][to]
}
func (g *Graph) Edges(from string) []string {
edges := make([]string, 0, 16)
for k := range g.edges[from] {
edges = append(edges, k)
}
sort.Slice(edges, func(i, j int) bool { return g.byLabel[edges[i]] < g.byLabel[edges[j]] })
return edges
}
// Parse parses the DAG language and returns the transitive closure of
// the described graph. In the returned graph, there is an edge from "b"
// to "a" if b < a (or a > b) in the partial order.
func Parse(dag string) (*Graph, error) {
g := newGraph()
disallowed := []rule{}
rules, err := parseRules(dag)
if err != nil {
return nil, err
}
// TODO: Add line numbers to errors.
var errors []string
errorf := func(format string, a ...any) {
errors = append(errors, fmt.Sprintf(format, a...))
}
for _, r := range rules {
if r.op == "!<" {
disallowed = append(disallowed, r)
continue
}
for _, def := range r.def {
if def == "NONE" {
errorf("NONE cannot be a predecessor")
continue
}
if !g.addNode(def) {
errorf("multiple definitions for %s", def)
}
for _, less := range r.less {
if less == "NONE" {
continue
}
if _, ok := g.byLabel[less]; !ok {
errorf("use of %s before its definition", less)
} else {
g.AddEdge(def, less)
}
}
}
}
// Check for missing definition.
for _, tos := range g.edges {
for to := range tos {
if g.edges[to] == nil {
errorf("missing definition for %s", to)
}
}
}
// Complete transitive closure.
for _, k := range g.Nodes {
for _, i := range g.Nodes {
for _, j := range g.Nodes {
if i != k && k != j && g.HasEdge(i, k) && g.HasEdge(k, j) {
if i == j {
// Can only happen along with a "use of X before its definition" error above,
// but this error is more specific - it makes clear that reordering the
// rules will not be enough to fix the problem.
errorf("graph cycle: %s < %s < %s", j, k, i)
}
g.AddEdge(i, j)
}
}
}
}
// Check negative assertions against completed allowed graph.
for _, bad := range disallowed {
for _, less := range bad.less {
for _, def := range bad.def {
if g.HasEdge(def, less) {
errorf("graph edge assertion failed: %s !< %s", less, def)
}
}
}
}
if len(errors) > 0 {
return nil, fmt.Errorf("%s", strings.Join(errors, "\n"))
}
return g, nil
}
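// Usage sketch (the rule text is hypothetical, not part of the original sources):
//
//	g, err := Parse("NONE < a; a < b; b < c;")
//	if err != nil {
//		// handle malformed rules
//	}
//	g.HasEdge("c", "a") // true: the returned graph is transitively closed
//	g.Topo()            // []string{"c", "b", "a"}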
// A rule is a line in the DAG language of the form "less < def" or "less !< def".
type rule struct {
less []string
op string // Either "<" or "!<"
def []string
}
type syntaxError string
func (e syntaxError) Error() string {
return string(e)
}
// parseRules parses the rules of a DAG.
func parseRules(rules string) (out []rule, err error) {
defer func() {
e := recover()
switch e := e.(type) {
case nil:
return
case syntaxError:
err = e
default:
panic(e)
}
}()
p := &rulesParser{lineno: 1, text: rules}
var prev []string
var op string
for {
list, tok := p.nextList()
if tok == "" {
if prev == nil {
break
}
p.syntaxError("unexpected EOF")
}
if prev != nil {
out = append(out, rule{prev, op, list})
}
prev = list
if tok == ";" {
prev = nil
op = ""
continue
}
if tok != "<" && tok != "!<" {
p.syntaxError("missing <")
}
op = tok
}
return out, err
}
// A rulesParser parses the depsRules syntax described above.
type rulesParser struct {
lineno int
lastWord string
text string
}
// syntaxError reports a parsing error.
func (p *rulesParser) syntaxError(msg string) {
panic(syntaxError(fmt.Sprintf("parsing graph: line %d: syntax error: %s near %s", p.lineno, msg, p.lastWord)))
}
// nextList parses and returns a comma-separated list of names.
func (p *rulesParser) nextList() (list []string, token string) {
for {
tok := p.nextToken()
switch tok {
case "":
if len(list) == 0 {
return nil, ""
}
fallthrough
case ",", "<", "!<", ";":
p.syntaxError("bad list syntax")
}
list = append(list, tok)
tok = p.nextToken()
if tok != "," {
return list, tok
}
}
}
// nextToken returns the next token in the deps rules,
// one of ";" "," "<" "!<" or a name.
func (p *rulesParser) nextToken() string {
for {
if p.text == "" {
return ""
}
switch p.text[0] {
case ';', ',', '<':
t := p.text[:1]
p.text = p.text[1:]
return t
case '!':
if len(p.text) < 2 || p.text[1] != '<' {
p.syntaxError("unexpected token !")
}
p.text = p.text[2:]
return "!<"
case '#':
i := strings.Index(p.text, "\n")
if i < 0 {
i = len(p.text)
}
p.text = p.text[i:]
continue
case '\n':
p.lineno++
fallthrough
case ' ', '\t':
p.text = p.text[1:]
continue
default:
i := strings.IndexAny(p.text, "!;,<#\n \t")
if i < 0 {
i = len(p.text)
}
t := p.text[:i]
p.text = p.text[i:]
p.lastWord = t
return t
}
}
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package diff
import (
"bytes"
"fmt"
"sort"
"strings"
)
// A pair is a pair of values tracked for both the x and y side of a diff.
// It is typically a pair of line indexes.
type pair struct{ x, y int }
// Diff returns an anchored diff of the two texts old and new
// in the “unified diff” format. If old and new are identical,
// Diff returns a nil slice (no output).
//
// Unix diff implementations typically look for a diff with
// the smallest number of lines inserted and removed,
// which can in the worst case take time quadratic in the
// number of lines in the texts. As a result, many implementations
// either can be made to run for a long time or cut off the search
// after a predetermined amount of work.
//
// In contrast, this implementation looks for a diff with the
// smallest number of “unique” lines inserted and removed,
// where unique means a line that appears just once in both old and new.
// We call this an “anchored diff” because the unique lines anchor
// the chosen matching regions. An anchored diff is usually clearer
// than a standard diff, because the algorithm does not try to
// reuse unrelated blank lines or closing braces.
// The algorithm also guarantees to run in O(n log n) time
// instead of the standard O(n²) time.
//
// Some systems call this approach a “patience diff,” named for
// the “patience sorting” algorithm, itself named for a solitaire card game.
// We avoid that name for two reasons. First, the name has been used
// for a few different variants of the algorithm, so it is imprecise.
// Second, the name is frequently interpreted as meaning that you have
// to wait longer (to be patient) for the diff, meaning that it is a slower algorithm,
// when in fact the algorithm is faster than the standard one.
func Diff(oldName string, old []byte, newName string, new []byte) []byte {
if bytes.Equal(old, new) {
return nil
}
x := lines(old)
y := lines(new)
// Print diff header.
var out bytes.Buffer
fmt.Fprintf(&out, "diff %s %s\n", oldName, newName)
fmt.Fprintf(&out, "--- %s\n", oldName)
fmt.Fprintf(&out, "+++ %s\n", newName)
// Loop over matches to consider,
// expanding each match to include surrounding lines,
// and then printing diff chunks.
// To avoid setup/teardown cases outside the loop,
// tgs returns a leading {0,0} and trailing {len(x), len(y)} pair
// in the sequence of matches.
var (
done pair // printed up to x[:done.x] and y[:done.y]
chunk pair // start lines of current chunk
count pair // number of lines from each side in current chunk
ctext []string // lines for current chunk
)
for _, m := range tgs(x, y) {
if m.x < done.x {
// Already handled scanning forward from earlier match.
continue
}
// Expand matching lines as far as possible,
// establishing that x[start.x:end.x] == y[start.y:end.y].
// Note that on the first (or last) iteration we may (or definitely do)
// have an empty match: start.x==end.x and start.y==end.y.
start := m
for start.x > done.x && start.y > done.y && x[start.x-1] == y[start.y-1] {
start.x--
start.y--
}
end := m
for end.x < len(x) && end.y < len(y) && x[end.x] == y[end.y] {
end.x++
end.y++
}
// Emit the mismatched lines before start into this chunk.
// (No effect on first sentinel iteration, when start = {0,0}.)
for _, s := range x[done.x:start.x] {
ctext = append(ctext, "-"+s)
count.x++
}
for _, s := range y[done.y:start.y] {
ctext = append(ctext, "+"+s)
count.y++
}
// If we're not at EOF and have too few common lines,
// the chunk includes all the common lines and continues.
const C = 3 // number of context lines
if (end.x < len(x) || end.y < len(y)) &&
(end.x-start.x < C || (len(ctext) > 0 && end.x-start.x < 2*C)) {
for _, s := range x[start.x:end.x] {
ctext = append(ctext, " "+s)
count.x++
count.y++
}
done = end
continue
}
// End chunk with common lines for context.
if len(ctext) > 0 {
n := end.x - start.x
if n > C {
n = C
}
for _, s := range x[start.x : start.x+n] {
ctext = append(ctext, " "+s)
count.x++
count.y++
}
done = pair{start.x + n, start.y + n}
// Format and emit chunk.
// Convert line numbers to 1-indexed.
// Special case: empty file shows up as 0,0 not 1,0.
if count.x > 0 {
chunk.x++
}
if count.y > 0 {
chunk.y++
}
fmt.Fprintf(&out, "@@ -%d,%d +%d,%d @@\n", chunk.x, count.x, chunk.y, count.y)
for _, s := range ctext {
out.WriteString(s)
}
count.x = 0
count.y = 0
ctext = ctext[:0]
}
// If we reached EOF, we're done.
if end.x >= len(x) && end.y >= len(y) {
break
}
// Otherwise start a new chunk.
chunk = pair{end.x - C, end.y - C}
for _, s := range x[chunk.x:end.x] {
ctext = append(ctext, " "+s)
count.x++
count.y++
}
done = end
}
return out.Bytes()
}
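// Usage sketch (file names and contents are illustrative):
//
//	out := Diff("old.txt", []byte("a\nb\n"), "new.txt", []byte("a\nc\n"))
//	// out holds "diff old.txt new.txt" followed by the ---/+++ headers and
//	// a single "@@ -1,2 +1,2 @@" chunk replacing "b" with "c".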
// lines returns the lines in the file x, including newlines.
// If the file does not end in a newline, one is supplied
// along with a warning about the missing newline.
func lines(x []byte) []string {
l := strings.SplitAfter(string(x), "\n")
if l[len(l)-1] == "" {
l = l[:len(l)-1]
} else {
// Treat last line as having a message about the missing newline attached,
// using the same text as BSD/GNU diff (including the leading backslash).
l[len(l)-1] += "\n\\ No newline at end of file\n"
}
return l
}
// tgs returns the pairs of indexes of the longest common subsequence
// of unique lines in x and y, where a unique line is one that appears
// once in x and once in y.
//
// The longest common subsequence algorithm is as described in
// Thomas G. Szymanski, “A Special Case of the Maximal Common
// Subsequence Problem,” Princeton TR #170 (January 1975),
// available at https://research.swtch.com/tgs170.pdf.
func tgs(x, y []string) []pair {
// Count the number of times each string appears in x and y.
// We only care about 0, 1, many, counted as 0, -1, -2
// for the x side and 0, -4, -8 for the y side.
// Using negative numbers now lets us distinguish positive line numbers later.
m := make(map[string]int)
for _, s := range x {
if c := m[s]; c > -2 {
m[s] = c - 1
}
}
for _, s := range y {
if c := m[s]; c > -8 {
m[s] = c - 4
}
}
// Now unique strings can be identified by m[s] == -1+-4.
//
// Gather the indexes of those strings in x and y, building:
// xi[i] = increasing indexes of unique strings in x.
// yi[i] = increasing indexes of unique strings in y.
// inv[i] = index j such that x[xi[i]] = y[yi[j]].
var xi, yi, inv []int
for i, s := range y {
if m[s] == -1+-4 {
m[s] = len(yi)
yi = append(yi, i)
}
}
for i, s := range x {
if j, ok := m[s]; ok && j >= 0 {
xi = append(xi, i)
inv = append(inv, j)
}
}
// Apply Algorithm A from Szymanski's paper.
// In those terms, A = J = inv and B = [0, n).
// We add sentinel pairs {0,0}, and {len(x),len(y)}
// to the returned sequence, to help the processing loop.
J := inv
n := len(xi)
T := make([]int, n)
L := make([]int, n)
for i := range T {
T[i] = n + 1
}
for i := 0; i < n; i++ {
k := sort.Search(n, func(k int) bool {
return T[k] >= J[i]
})
T[k] = J[i]
L[i] = k + 1
}
k := 0
for _, v := range L {
if k < v {
k = v
}
}
seq := make([]pair, 2+k)
seq[1+k] = pair{len(x), len(y)} // sentinel at end
lastj := n
for i := n - 1; i >= 0; i-- {
if L[i] == k && J[i] < lastj {
seq[k] = pair{xi[i], yi[J[i]]}
k--
}
}
seq[0] = pair{0, 0} // sentinel at start
return seq
}
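// Worked illustration (hypothetical inputs): for x = [a b c] and y = [a c b],
// every line is unique in both texts, and the longest common subsequence of
// unique lines is a, c. tgs returns {0,0} (sentinel), {0,0} and {2,1} (the
// matches for a and c), and {3,3} (sentinel).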
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package fmtsort provides a general stable ordering mechanism
// for maps, on behalf of the fmt and text/template packages.
// It is not guaranteed to be efficient and works only for types
// that are valid map keys.
package fmtsort
import (
"reflect"
"sort"
)
// Note: Throughout this package we avoid calling reflect.Value.Interface as
// it is not always legal to do so and it's easier to avoid the issue than to face it.
// SortedMap represents a map's keys and values. The keys and values are
// aligned in index order: Value[i] is the value in the map corresponding to Key[i].
type SortedMap struct {
Key []reflect.Value
Value []reflect.Value
}
func (o *SortedMap) Len() int { return len(o.Key) }
func (o *SortedMap) Less(i, j int) bool { return compare(o.Key[i], o.Key[j]) < 0 }
func (o *SortedMap) Swap(i, j int) {
o.Key[i], o.Key[j] = o.Key[j], o.Key[i]
o.Value[i], o.Value[j] = o.Value[j], o.Value[i]
}
// Sort accepts a map and returns a SortedMap that has the same keys and
// values but in a stable sorted order according to the keys, modulo issues
// raised by unorderable key values such as NaNs.
//
// The ordering rules are more general than with Go's < operator:
//
// - when applicable, nil compares low
// - ints, floats, and strings order by <
// - NaN compares less than non-NaN floats
// - bool compares false before true
// - complex compares real, then imag
// - pointers compare by machine address
// - channel values compare by machine address
// - structs compare each field in turn
// - arrays compare each element in turn.
// Otherwise identical arrays compare by length.
// - interface values compare first by reflect.Type describing the concrete type
// and then by concrete value as described in the previous rules.
func Sort(mapValue reflect.Value) *SortedMap {
if mapValue.Type().Kind() != reflect.Map {
return nil
}
// Note: this code is arranged to not panic even in the presence
// of a concurrent map update. The runtime is responsible for
// yelling loudly if that happens. See issue 33275.
n := mapValue.Len()
key := make([]reflect.Value, 0, n)
value := make([]reflect.Value, 0, n)
iter := mapValue.MapRange()
for iter.Next() {
key = append(key, iter.Key())
value = append(value, iter.Value())
}
sorted := &SortedMap{
Key: key,
Value: value,
}
sort.Stable(sorted)
return sorted
}
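// Usage sketch (illustrative only; fmtsort is an internal package):
//
//	m := map[int]string{3: "c", 1: "a", 2: "b"}
//	sm := Sort(reflect.ValueOf(m))
//	for i := range sm.Key {
//		_ = sm.Key[i].Int()      // 1, 2, 3 in ascending key order
//		_ = sm.Value[i].String() // "a", "b", "c"
//	}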
// compare compares two values of the same type. It returns -1, 0, or 1
// according to whether a < b (-1), a == b (0), or a > b (1).
// If the types differ, it returns -1.
// See the comment on Sort for the comparison rules.
func compare(aVal, bVal reflect.Value) int {
aType, bType := aVal.Type(), bVal.Type()
if aType != bType {
return -1 // No good answer possible, but don't return 0: they're not equal.
}
switch aVal.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
a, b := aVal.Int(), bVal.Int()
switch {
case a < b:
return -1
case a > b:
return 1
default:
return 0
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
a, b := aVal.Uint(), bVal.Uint()
switch {
case a < b:
return -1
case a > b:
return 1
default:
return 0
}
case reflect.String:
a, b := aVal.String(), bVal.String()
switch {
case a < b:
return -1
case a > b:
return 1
default:
return 0
}
case reflect.Float32, reflect.Float64:
return floatCompare(aVal.Float(), bVal.Float())
case reflect.Complex64, reflect.Complex128:
a, b := aVal.Complex(), bVal.Complex()
if c := floatCompare(real(a), real(b)); c != 0 {
return c
}
return floatCompare(imag(a), imag(b))
case reflect.Bool:
a, b := aVal.Bool(), bVal.Bool()
switch {
case a == b:
return 0
case a:
return 1
default:
return -1
}
case reflect.Pointer, reflect.UnsafePointer:
a, b := aVal.Pointer(), bVal.Pointer()
switch {
case a < b:
return -1
case a > b:
return 1
default:
return 0
}
case reflect.Chan:
if c, ok := nilCompare(aVal, bVal); ok {
return c
}
ap, bp := aVal.Pointer(), bVal.Pointer()
switch {
case ap < bp:
return -1
case ap > bp:
return 1
default:
return 0
}
case reflect.Struct:
for i := 0; i < aVal.NumField(); i++ {
if c := compare(aVal.Field(i), bVal.Field(i)); c != 0 {
return c
}
}
return 0
case reflect.Array:
for i := 0; i < aVal.Len(); i++ {
if c := compare(aVal.Index(i), bVal.Index(i)); c != 0 {
return c
}
}
return 0
case reflect.Interface:
if c, ok := nilCompare(aVal, bVal); ok {
return c
}
c := compare(reflect.ValueOf(aVal.Elem().Type()), reflect.ValueOf(bVal.Elem().Type()))
if c != 0 {
return c
}
return compare(aVal.Elem(), bVal.Elem())
default:
// Certain types cannot appear as keys (maps, funcs, slices), but be explicit.
panic("bad type in compare: " + aType.String())
}
}
// nilCompare checks whether either value is nil. If not, the boolean is false.
// If either value is nil, the boolean is true and the integer is the comparison
// value. The comparison is defined to be 0 if both are nil, otherwise the one
// nil value compares low. Both arguments must represent a chan, func,
// interface, map, pointer, or slice.
func nilCompare(aVal, bVal reflect.Value) (int, bool) {
if aVal.IsNil() {
if bVal.IsNil() {
return 0, true
}
return -1, true
}
if bVal.IsNil() {
return 1, true
}
return 0, false
}
// floatCompare compares two floating-point values. NaNs compare low.
func floatCompare(a, b float64) int {
switch {
case isNaN(a):
return -1 // No good answer if b is a NaN so don't bother checking.
case isNaN(b):
return 1
case a < b:
return -1
case a > b:
return 1
}
return 0
}
func isNaN(a float64) bool {
return a != a
}
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build (darwin || linux || windows || freebsd) && (amd64 || arm64)
package fuzz
import (
"unsafe"
)
// coverage returns a []byte containing unique 8-bit counters for each edge of
// the instrumented source code. This coverage data will only be generated if
// `-d=libfuzzer` is set at build time. This can be used to understand the code
// coverage of a test execution.
func coverage() []byte {
addr := unsafe.Pointer(&_counters)
size := uintptr(unsafe.Pointer(&_ecounters)) - uintptr(addr)
return unsafe.Slice((*byte)(addr), int(size))
}
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fuzz
import (
"fmt"
"math/bits"
)
// ResetCoverage sets all of the counters for each edge of the instrumented
// source code to 0.
func ResetCoverage() {
cov := coverage()
for i := range cov {
cov[i] = 0
}
}
// SnapshotCoverage copies the current counter values into coverageSnapshot,
// preserving them for later inspection. SnapshotCoverage also rounds each
// counter down to the nearest power of two. This lets the coordinator store
// multiple values for each counter by OR'ing them together.
func SnapshotCoverage() {
cov := coverage()
for i, b := range cov {
b |= b >> 1
b |= b >> 2
b |= b >> 4
b -= b >> 1
coverageSnapshot[i] = b
}
}
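// Illustration of the rounding in SnapshotCoverage: a counter of 0b00101101
// (45) is smeared to 0b00111111 by the shifts, then b -= b >> 1 leaves
// 0b00100000 (32): the value rounded down to the nearest power of two.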
// diffCoverage returns a set of bits set in snapshot but not in base.
// If there are no new bits set, diffCoverage returns nil.
func diffCoverage(base, snapshot []byte) []byte {
if len(base) != len(snapshot) {
panic(fmt.Sprintf("the number of coverage bits changed: before=%d, after=%d", len(base), len(snapshot)))
}
found := false
for i := range snapshot {
if snapshot[i]&^base[i] != 0 {
found = true
break
}
}
if !found {
return nil
}
diff := make([]byte, len(snapshot))
for i := range diff {
diff[i] = snapshot[i] &^ base[i]
}
return diff
}
// countNewCoverageBits returns the number of bits set in snapshot that are not
// set in base.
func countNewCoverageBits(base, snapshot []byte) int {
n := 0
for i := range snapshot {
n += bits.OnesCount8(snapshot[i] &^ base[i])
}
return n
}
// isCoverageSubset returns true if all the base coverage bits are set in
// snapshot.
func isCoverageSubset(base, snapshot []byte) bool {
for i, v := range base {
if v&snapshot[i] != v {
return false
}
}
return true
}
// hasCoverageBit returns true if snapshot has at least one bit set that is
// also set in base.
func hasCoverageBit(base, snapshot []byte) bool {
for i := range snapshot {
if snapshot[i]&base[i] != 0 {
return true
}
}
return false
}
func countBits(cov []byte) int {
n := 0
for _, c := range cov {
n += bits.OnesCount8(c)
}
return n
}
var (
coverageEnabled = len(coverage()) > 0
coverageSnapshot = make([]byte, len(coverage()))
// _counters and _ecounters mark the start and end, respectively, of where
// the 8-bit coverage counters reside in memory. They're known to cmd/link,
// which specially assigns their addresses for this purpose.
_counters, _ecounters [0]byte
)
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fuzz
import (
"bytes"
"fmt"
"go/ast"
"go/parser"
"go/token"
"math"
"strconv"
"strings"
"unicode/utf8"
)
// encVersion1 will be the first line of a file with version 1 encoding.
var encVersion1 = "go test fuzz v1"
// marshalCorpusFile encodes an arbitrary number of arguments into the file format for the
// corpus.
func marshalCorpusFile(vals ...any) []byte {
if len(vals) == 0 {
panic("must have at least one value to marshal")
}
b := bytes.NewBuffer([]byte(encVersion1 + "\n"))
// TODO(katiehockman): keep uint8 and int32 encoding where applicable,
// instead of changing to byte and rune respectively.
for _, val := range vals {
switch t := val.(type) {
case int, int8, int16, int64, uint, uint16, uint32, uint64, bool:
fmt.Fprintf(b, "%T(%v)\n", t, t)
case float32:
if math.IsNaN(float64(t)) && math.Float32bits(t) != math.Float32bits(float32(math.NaN())) {
// We encode unusual NaNs as hex values, because that is how users are
// likely to encounter them in literature about floating-point encoding.
// This allows us to reproduce fuzz failures that depend on the specific
// NaN representation (for float32 there are about 2^24 possibilities!),
// not just the fact that the value is *a* NaN.
//
// Note that the specific value of float32(math.NaN()) can vary based on
// whether the architecture represents signaling NaNs using a low bit
// (as is common) or a high bit (as commonly implemented on MIPS
// hardware before around 2012). We believe that the increase in clarity
// from identifying "NaN" with math.NaN() is worth the slight ambiguity
// from a platform-dependent value.
fmt.Fprintf(b, "math.Float32frombits(0x%x)\n", math.Float32bits(t))
} else {
// We encode all other values — including the NaN value that is
// bitwise-identical to float32(math.NaN()) — using the default
// formatting, which is equivalent to strconv.FormatFloat with format
// 'g' and can be parsed by strconv.ParseFloat.
//
// For an ordinary floating-point number this format includes
// sufficiently many digits to reconstruct the exact value. For positive
// or negative infinity it is the string "+Inf" or "-Inf". For positive
// or negative zero it is "0" or "-0". For NaN, it is the string "NaN".
fmt.Fprintf(b, "%T(%v)\n", t, t)
}
case float64:
if math.IsNaN(t) && math.Float64bits(t) != math.Float64bits(math.NaN()) {
fmt.Fprintf(b, "math.Float64frombits(0x%x)\n", math.Float64bits(t))
} else {
fmt.Fprintf(b, "%T(%v)\n", t, t)
}
case string:
fmt.Fprintf(b, "string(%q)\n", t)
case rune: // int32
// Although rune and int32 are represented by the same type, only a subset
// of valid int32 values can be expressed as rune literals. Notably,
// negative numbers, surrogate halves, and values above unicode.MaxRune
// have no quoted representation.
//
// fmt with "%q" (and the corresponding functions in the strconv package)
// would quote out-of-range values to the Unicode replacement character
// instead of the original value (see https://go.dev/issue/51526), so
// they must be treated as int32 instead.
//
// We arbitrarily draw the line at UTF-8 validity, which biases toward the
// "rune" interpretation. (However, we accept either format as input.)
if utf8.ValidRune(t) {
fmt.Fprintf(b, "rune(%q)\n", t)
} else {
fmt.Fprintf(b, "int32(%v)\n", t)
}
case byte: // uint8
// For bytes, we arbitrarily prefer the character interpretation.
// (Every byte has a valid character encoding.)
fmt.Fprintf(b, "byte(%q)\n", t)
case []byte: // []uint8
fmt.Fprintf(b, "[]byte(%q)\n", t)
default:
panic(fmt.Sprintf("unsupported type: %T", t))
}
}
return b.Bytes()
}
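// For example (hypothetical call):
//
//	marshalCorpusFile(int(42), "hi", []byte{0, 1})
//
// produces:
//
//	go test fuzz v1
//	int(42)
//	string("hi")
//	[]byte("\x00\x01")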
// unmarshalCorpusFile decodes corpus bytes into their respective values.
func unmarshalCorpusFile(b []byte) ([]any, error) {
if len(b) == 0 {
return nil, fmt.Errorf("cannot unmarshal empty string")
}
lines := bytes.Split(b, []byte("\n"))
if len(lines) < 2 {
return nil, fmt.Errorf("must include version and at least one value")
}
version := strings.TrimSuffix(string(lines[0]), "\r")
if version != encVersion1 {
return nil, fmt.Errorf("unknown encoding version: %s", version)
}
var vals []any
for _, line := range lines[1:] {
line = bytes.TrimSpace(line)
if len(line) == 0 {
continue
}
v, err := parseCorpusValue(line)
if err != nil {
return nil, fmt.Errorf("malformed line %q: %v", line, err)
}
vals = append(vals, v)
}
return vals, nil
}
func parseCorpusValue(line []byte) (any, error) {
fs := token.NewFileSet()
expr, err := parser.ParseExprFrom(fs, "(test)", line, 0)
if err != nil {
return nil, err
}
call, ok := expr.(*ast.CallExpr)
if !ok {
return nil, fmt.Errorf("expected call expression")
}
if len(call.Args) != 1 {
return nil, fmt.Errorf("expected call expression with 1 argument; got %d", len(call.Args))
}
arg := call.Args[0]
if arrayType, ok := call.Fun.(*ast.ArrayType); ok {
if arrayType.Len != nil {
return nil, fmt.Errorf("expected []byte or primitive type")
}
elt, ok := arrayType.Elt.(*ast.Ident)
if !ok || elt.Name != "byte" {
return nil, fmt.Errorf("expected []byte")
}
lit, ok := arg.(*ast.BasicLit)
if !ok || lit.Kind != token.STRING {
return nil, fmt.Errorf("string literal required for type []byte")
}
s, err := strconv.Unquote(lit.Value)
if err != nil {
return nil, err
}
return []byte(s), nil
}
var idType *ast.Ident
if selector, ok := call.Fun.(*ast.SelectorExpr); ok {
xIdent, ok := selector.X.(*ast.Ident)
if !ok || xIdent.Name != "math" {
return nil, fmt.Errorf("invalid selector type")
}
switch selector.Sel.Name {
case "Float64frombits":
idType = &ast.Ident{Name: "float64-bits"}
case "Float32frombits":
idType = &ast.Ident{Name: "float32-bits"}
default:
return nil, fmt.Errorf("invalid selector type")
}
} else {
idType, ok = call.Fun.(*ast.Ident)
if !ok {
return nil, fmt.Errorf("expected []byte or primitive type")
}
if idType.Name == "bool" {
id, ok := arg.(*ast.Ident)
if !ok {
return nil, fmt.Errorf("malformed bool")
}
if id.Name == "true" {
return true, nil
} else if id.Name == "false" {
return false, nil
} else {
return nil, fmt.Errorf("true or false required for type bool")
}
}
}
var (
val string
kind token.Token
)
if op, ok := arg.(*ast.UnaryExpr); ok {
switch lit := op.X.(type) {
case *ast.BasicLit:
if op.Op != token.SUB {
return nil, fmt.Errorf("unsupported operation on int/float: %v", op.Op)
}
// Special case for negative numbers.
val = op.Op.String() + lit.Value // e.g. "-" + "124"
kind = lit.Kind
case *ast.Ident:
if lit.Name != "Inf" {
return nil, fmt.Errorf("expected operation on int or float type")
}
if op.Op == token.SUB {
val = "-Inf"
} else {
val = "+Inf"
}
kind = token.FLOAT
default:
return nil, fmt.Errorf("expected operation on int or float type")
}
} else {
switch lit := arg.(type) {
case *ast.BasicLit:
val, kind = lit.Value, lit.Kind
case *ast.Ident:
if lit.Name != "NaN" {
return nil, fmt.Errorf("literal value required for primitive type")
}
val, kind = "NaN", token.FLOAT
default:
return nil, fmt.Errorf("literal value required for primitive type")
}
}
switch typ := idType.Name; typ {
case "string":
if kind != token.STRING {
return nil, fmt.Errorf("string literal value required for type string")
}
return strconv.Unquote(val)
case "byte", "rune":
if kind == token.INT {
switch typ {
case "rune":
return parseInt(val, typ)
case "byte":
return parseUint(val, typ)
}
}
if kind != token.CHAR {
return nil, fmt.Errorf("character literal required for byte/rune types")
}
n := len(val)
if n < 2 {
return nil, fmt.Errorf("malformed character literal, missing single quotes")
}
code, _, _, err := strconv.UnquoteChar(val[1:n-1], '\'')
if err != nil {
return nil, err
}
if typ == "rune" {
return code, nil
}
if code >= 256 {
return nil, fmt.Errorf("can only encode single byte to a byte type")
}
return byte(code), nil
case "int", "int8", "int16", "int32", "int64":
if kind != token.INT {
return nil, fmt.Errorf("integer literal required for int types")
}
return parseInt(val, typ)
case "uint", "uint8", "uint16", "uint32", "uint64":
if kind != token.INT {
return nil, fmt.Errorf("integer literal required for uint types")
}
return parseUint(val, typ)
case "float32":
if kind != token.FLOAT && kind != token.INT {
return nil, fmt.Errorf("float or integer literal required for float32 type")
}
v, err := strconv.ParseFloat(val, 32)
return float32(v), err
case "float64":
if kind != token.FLOAT && kind != token.INT {
return nil, fmt.Errorf("float or integer literal required for float64 type")
}
return strconv.ParseFloat(val, 64)
case "float32-bits":
if kind != token.INT {
return nil, fmt.Errorf("integer literal required for math.Float32frombits type")
}
bits, err := parseUint(val, "uint32")
if err != nil {
return nil, err
}
return math.Float32frombits(bits.(uint32)), nil
case "float64-bits":
if kind != token.INT {
return nil, fmt.Errorf("integer literal required for math.Float64frombits type")
}
bits, err := parseUint(val, "uint64")
if err != nil {
return nil, err
}
return math.Float64frombits(bits.(uint64)), nil
default:
return nil, fmt.Errorf("expected []byte or primitive type")
}
}
// parseInt returns an integer of value val and type typ.
func parseInt(val, typ string) (any, error) {
switch typ {
case "int":
// The int type may be either 32 or 64 bits. If 32, the fuzz tests in the
// corpus may include 64-bit values produced by fuzzing runs on 64-bit
// architectures. When running those tests, we implicitly wrap the values to
// fit in a regular int. (The test case is still “interesting”, even if the
// specific values of its inputs are platform-dependent.)
i, err := strconv.ParseInt(val, 0, 64)
return int(i), err
case "int8":
i, err := strconv.ParseInt(val, 0, 8)
return int8(i), err
case "int16":
i, err := strconv.ParseInt(val, 0, 16)
return int16(i), err
case "int32", "rune":
i, err := strconv.ParseInt(val, 0, 32)
return int32(i), err
case "int64":
return strconv.ParseInt(val, 0, 64)
default:
panic("unreachable")
}
}
// parseUint returns an unsigned integer of value val and type typ.
func parseUint(val, typ string) (any, error) {
switch typ {
case "uint":
i, err := strconv.ParseUint(val, 0, 64)
return uint(i), err
case "uint8", "byte":
i, err := strconv.ParseUint(val, 0, 8)
return uint8(i), err
case "uint16":
i, err := strconv.ParseUint(val, 0, 16)
return uint16(i), err
case "uint32":
i, err := strconv.ParseUint(val, 0, 32)
return uint32(i), err
case "uint64":
return strconv.ParseUint(val, 0, 64)
default:
panic("unreachable")
}
}
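// Round-trip sketch (hypothetical values): encoding then decoding preserves
// each value's type and contents:
//
//	vals, err := unmarshalCorpusFile(marshalCorpusFile(uint16(7), 1.5))
//	// vals == []any{uint16(7), float64(1.5)}, err == nil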
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package fuzz provides common fuzzing functionality for tests built with
// "go test" and for programs that use fuzzing functionality in the testing
// package.
package fuzz
import (
"bytes"
"context"
"crypto/sha256"
"errors"
"fmt"
"internal/godebug"
"io"
"math/bits"
"os"
"path/filepath"
"reflect"
"runtime"
"strings"
"time"
)
// CoordinateFuzzingOpts is a set of arguments for CoordinateFuzzing.
// The zero value is valid for each field unless specified otherwise.
type CoordinateFuzzingOpts struct {
// Log is a writer for logging progress messages and warnings.
// If nil, io.Discard will be used instead.
Log io.Writer
// Timeout is the amount of wall clock time to spend fuzzing after the corpus
// has loaded. If zero, there will be no time limit.
Timeout time.Duration
// Limit is the number of random values to generate and test. If zero,
// there will be no limit on the number of generated values.
Limit int64
// MinimizeTimeout is the amount of wall clock time to spend minimizing
// after discovering a crasher. If zero, there will be no time limit. If
// MinimizeTimeout and MinimizeLimit are both zero, then minimization will
// be disabled.
MinimizeTimeout time.Duration
// MinimizeLimit is the maximum number of calls to the fuzz function to be
// made while minimizing after finding a crash. If zero, there will be no
// limit. Calls to the fuzz function made when minimizing also count toward
// Limit. If MinimizeTimeout and MinimizeLimit are both zero, then
// minimization will be disabled.
MinimizeLimit int64
// Parallel is the number of worker processes to run in parallel. If zero,
// CoordinateFuzzing will run GOMAXPROCS workers.
Parallel int
// Seed is a list of seed values added by the fuzz target with testing.F.Add
// and in testdata.
Seed []CorpusEntry
// Types is the list of types which make up a corpus entry.
// Types must be set and must match values in Seed.
Types []reflect.Type
// CorpusDir is a directory where files containing values that crash the
// code being tested may be written. CorpusDir must be set.
CorpusDir string
// CacheDir is a directory containing additional "interesting" values.
// The fuzzer may derive new values from these, and may write new values here.
CacheDir string
}
// CoordinateFuzzing creates several worker processes and communicates with
// them to test random inputs that could trigger crashes and expose bugs.
// The worker processes run the same binary in the same directory with the
// same environment variables as the coordinator process. Workers also run
// with the same arguments as the coordinator, except with the -test.fuzzworker
// flag prepended to the argument list.
//
// If a crash occurs, the function will return an error containing information
// about the crash, which can be reported to the user.
func CoordinateFuzzing(ctx context.Context, opts CoordinateFuzzingOpts) (err error) {
if err := ctx.Err(); err != nil {
return err
}
if opts.Log == nil {
opts.Log = io.Discard
}
if opts.Parallel == 0 {
opts.Parallel = runtime.GOMAXPROCS(0)
}
if opts.Limit > 0 && int64(opts.Parallel) > opts.Limit {
// Don't start more workers than we need.
opts.Parallel = int(opts.Limit)
}
c, err := newCoordinator(opts)
if err != nil {
return err
}
if opts.Timeout > 0 {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, opts.Timeout)
defer cancel()
}
// fuzzCtx is used to stop workers, for example, after finding a crasher.
fuzzCtx, cancelWorkers := context.WithCancel(ctx)
defer cancelWorkers()
doneC := ctx.Done()
// stop is called when a worker encounters a fatal error.
var fuzzErr error
stopping := false
stop := func(err error) {
if err == fuzzCtx.Err() || isInterruptError(err) {
// Suppress cancellation errors and terminations due to SIGINT.
// The messages are not helpful since either the user triggered the error
// (with ^C) or another more helpful message will be printed (a crasher).
err = nil
}
if err != nil && (fuzzErr == nil || fuzzErr == ctx.Err()) {
fuzzErr = err
}
if stopping {
return
}
stopping = true
cancelWorkers()
doneC = nil
}
// Ensure that any crash we find is written to the corpus, even if an error
// or interruption occurs while minimizing it.
crashWritten := false
defer func() {
if c.crashMinimizing == nil || crashWritten {
return
}
werr := writeToCorpus(&c.crashMinimizing.entry, opts.CorpusDir)
if werr != nil {
err = fmt.Errorf("%w\n%v", err, werr)
return
}
if err == nil {
err = &crashError{
path: c.crashMinimizing.entry.Path,
err: errors.New(c.crashMinimizing.crasherMsg),
}
}
}()
// Start workers.
// TODO(jayconrod): do we want to support fuzzing different binaries?
dir := "" // same as self
binPath := os.Args[0]
args := append([]string{"-test.fuzzworker"}, os.Args[1:]...)
env := os.Environ() // same as self
errC := make(chan error)
workers := make([]*worker, opts.Parallel)
for i := range workers {
var err error
workers[i], err = newWorker(c, dir, binPath, args, env)
if err != nil {
return err
}
}
for i := range workers {
w := workers[i]
go func() {
err := w.coordinate(fuzzCtx)
if fuzzCtx.Err() != nil || isInterruptError(err) {
err = nil
}
cleanErr := w.cleanup()
if err == nil {
err = cleanErr
}
errC <- err
}()
}
// Main event loop.
// Do not return until all workers have terminated. We avoid a deadlock by
// receiving messages from workers even after ctx is cancelled.
activeWorkers := len(workers)
statTicker := time.NewTicker(3 * time.Second)
defer statTicker.Stop()
defer c.logStats()
c.logStats()
for {
var inputC chan fuzzInput
input, ok := c.peekInput()
if ok && c.crashMinimizing == nil && !stopping {
inputC = c.inputC
}
var minimizeC chan fuzzMinimizeInput
minimizeInput, ok := c.peekMinimizeInput()
if ok && !stopping {
minimizeC = c.minimizeC
}
select {
case <-doneC:
// Interrupted, cancelled, or timed out.
// stop sets doneC to nil so we don't busy wait here.
stop(ctx.Err())
case err := <-errC:
// A worker terminated, possibly after encountering a fatal error.
stop(err)
activeWorkers--
if activeWorkers == 0 {
return fuzzErr
}
case result := <-c.resultC:
// Received response from worker.
if stopping {
break
}
c.updateStats(result)
if result.crasherMsg != "" {
if c.warmupRun() && result.entry.IsSeed {
target := filepath.Base(c.opts.CorpusDir)
fmt.Fprintf(c.opts.Log, "failure while testing seed corpus entry: %s/%s\n", target, testName(result.entry.Parent))
stop(errors.New(result.crasherMsg))
break
}
if c.canMinimize() && result.canMinimize {
if c.crashMinimizing != nil {
// This crash is not minimized, and another crash is being minimized.
// Ignore this one and wait for the other one to finish.
break
}
// Found a crasher but haven't yet attempted to minimize it.
// Send it back to a worker for minimization. Disable inputC so
// other workers don't continue fuzzing.
c.crashMinimizing = &result
fmt.Fprintf(c.opts.Log, "fuzz: minimizing %d-byte failing input file\n", len(result.entry.Data))
c.queueForMinimization(result, nil)
} else if !crashWritten {
// Found a crasher that's either minimized or not minimizable.
// Write to corpus and stop.
err := writeToCorpus(&result.entry, opts.CorpusDir)
if err == nil {
crashWritten = true
err = &crashError{
path: result.entry.Path,
err: errors.New(result.crasherMsg),
}
}
if shouldPrintDebugInfo() {
fmt.Fprintf(
c.opts.Log,
"DEBUG new crasher, elapsed: %s, id: %s, parent: %s, gen: %d, size: %d, exec time: %s\n",
c.elapsed(),
result.entry.Path,
result.entry.Parent,
result.entry.Generation,
len(result.entry.Data),
result.entryDuration,
)
}
stop(err)
}
} else if result.coverageData != nil {
if c.warmupRun() {
if shouldPrintDebugInfo() {
fmt.Fprintf(
c.opts.Log,
"DEBUG processed an initial input, elapsed: %s, id: %s, new bits: %d, size: %d, exec time: %s\n",
c.elapsed(),
result.entry.Parent,
countBits(diffCoverage(c.coverageMask, result.coverageData)),
len(result.entry.Data),
result.entryDuration,
)
}
c.updateCoverage(result.coverageData)
c.warmupInputLeft--
if c.warmupInputLeft == 0 {
fmt.Fprintf(c.opts.Log, "fuzz: elapsed: %s, gathering baseline coverage: %d/%d completed, now fuzzing with %d workers\n", c.elapsed(), c.warmupInputCount, c.warmupInputCount, c.opts.Parallel)
if shouldPrintDebugInfo() {
fmt.Fprintf(
c.opts.Log,
"DEBUG finished processing input corpus, elapsed: %s, entries: %d, initial coverage bits: %d\n",
c.elapsed(),
len(c.corpus.entries),
countBits(c.coverageMask),
)
}
}
} else if keepCoverage := diffCoverage(c.coverageMask, result.coverageData); keepCoverage != nil {
// Found a value that expanded coverage.
// It's not a crasher, but we may want to add it to the on-disk
// corpus and prioritize it for future fuzzing.
// TODO(jayconrod, katiehockman): Prioritize fuzzing these
// values which expanded coverage, perhaps based on the
// number of new edges that this result expanded.
// TODO(jayconrod, katiehockman): Don't write a value that's already
// in the corpus.
if c.canMinimize() && result.canMinimize && c.crashMinimizing == nil {
// Send back to workers to find a smaller value that preserves
// at least one new coverage bit.
c.queueForMinimization(result, keepCoverage)
} else {
// Update the coordinator's coverage mask and save the value.
inputSize := len(result.entry.Data)
entryNew, err := c.addCorpusEntries(true, result.entry)
if err != nil {
stop(err)
break
}
if !entryNew {
continue
}
c.updateCoverage(keepCoverage)
c.inputQueue.enqueue(result.entry)
c.interestingCount++
if shouldPrintDebugInfo() {
fmt.Fprintf(
c.opts.Log,
"DEBUG new interesting input, elapsed: %s, id: %s, parent: %s, gen: %d, new bits: %d, total bits: %d, size: %d, exec time: %s\n",
c.elapsed(),
result.entry.Path,
result.entry.Parent,
result.entry.Generation,
countBits(keepCoverage),
countBits(c.coverageMask),
inputSize,
result.entryDuration,
)
}
}
} else {
if shouldPrintDebugInfo() {
fmt.Fprintf(
c.opts.Log,
"DEBUG worker reported interesting input that doesn't expand coverage, elapsed: %s, id: %s, parent: %s, canMinimize: %t\n",
c.elapsed(),
result.entry.Path,
result.entry.Parent,
result.canMinimize,
)
}
}
} else if c.warmupRun() {
// No error or coverage data was reported for this input during
// warmup, so continue processing results.
c.warmupInputLeft--
if c.warmupInputLeft == 0 {
fmt.Fprintf(c.opts.Log, "fuzz: elapsed: %s, testing seed corpus: %d/%d completed, now fuzzing with %d workers\n", c.elapsed(), c.warmupInputCount, c.warmupInputCount, c.opts.Parallel)
if shouldPrintDebugInfo() {
fmt.Fprintf(
c.opts.Log,
"DEBUG finished testing-only phase, elapsed: %s, entries: %d\n",
time.Since(c.startTime),
len(c.corpus.entries),
)
}
}
}
// Once the result has been processed, stop the worker if we
// have reached the fuzzing limit.
if c.opts.Limit > 0 && c.count >= c.opts.Limit {
stop(nil)
}
case inputC <- input:
// Sent the next input to a worker.
c.sentInput(input)
case minimizeC <- minimizeInput:
// Sent the next input for minimization to a worker.
c.sentMinimizeInput(minimizeInput)
case <-statTicker.C:
c.logStats()
}
}
// TODO(jayconrod,katiehockman): if a crasher can't be written to the corpus,
// write to the cache instead.
}
// crashError wraps a crasher written to the seed corpus. It saves the name
// of the file where the input causing the crasher was saved. The testing
// framework uses this to report a command to re-run that specific input.
type crashError struct {
path string
err error
}
func (e *crashError) Error() string {
return e.err.Error()
}
func (e *crashError) Unwrap() error {
return e.err
}
func (e *crashError) CrashPath() string {
return e.path
}
type corpus struct {
entries []CorpusEntry
hashes map[[sha256.Size]byte]bool
}
// addCorpusEntries adds entries to the corpus, and optionally writes the entries
// to the cache directory. If an entry is already in the corpus, it is skipped. If
// all of the entries are unique, addCorpusEntries returns true and a nil error;
// if at least one of the entries was a duplicate, it returns false and a nil error.
func (c *coordinator) addCorpusEntries(addToCache bool, entries ...CorpusEntry) (bool, error) {
noDupes := true
for _, e := range entries {
data, err := corpusEntryData(e)
if err != nil {
return false, err
}
h := sha256.Sum256(data)
if c.corpus.hashes[h] {
noDupes = false
continue
}
if addToCache {
if err := writeToCorpus(&e, c.opts.CacheDir); err != nil {
return false, err
}
// For entries written to disk, we don't hold onto the bytes,
// since the corpus would consume a significant amount of
// memory.
e.Data = nil
}
c.corpus.hashes[h] = true
c.corpus.entries = append(c.corpus.entries, e)
}
return noDupes, nil
}
// CorpusEntry represents an individual input for fuzzing.
//
// We must use an equivalent type in the testing and testing/internal/testdeps
// packages, but testing can't import this package directly, and we don't want
// to export this type from testing. Instead, we use the same struct type and
// use a type alias (not a defined type) for convenience.
type CorpusEntry = struct {
Parent string
// Path is the path of the corpus file, if the entry was loaded from disk.
// For other entries, including seed values provided by f.Add, Path is the
// name of the test, e.g. seed#0 or its hash.
Path string
// Data is the raw input data. Data should only be populated for seed
// values. For on-disk corpus files, Data will be nil, as it will be loaded
// from disk using Path.
Data []byte
// Values is the unmarshaled values from a corpus file.
Values []any
Generation int
// IsSeed indicates whether this entry is part of the seed corpus.
IsSeed bool
}
// corpusEntryData returns the raw input bytes, either from the data struct
// field, or from disk.
func corpusEntryData(ce CorpusEntry) ([]byte, error) {
if ce.Data != nil {
return ce.Data, nil
}
return os.ReadFile(ce.Path)
}
type fuzzInput struct {
// entry is the value to test initially. The worker will randomly mutate
// values from this starting point.
entry CorpusEntry
// timeout is the time to spend fuzzing variations of this input,
// not including starting or cleaning up.
timeout time.Duration
// limit is the maximum number of calls to the fuzz function the worker may
// make. The worker may make fewer calls, for example, if it finds an
// error early. If limit is zero, there is no limit on calls to the
// fuzz function.
limit int64
// warmup indicates whether this is a warmup input before fuzzing begins. If
// true, the input should not be fuzzed.
warmup bool
// coverageData reflects the coordinator's current coverageMask.
coverageData []byte
}
type fuzzResult struct {
// entry is an interesting value or a crasher.
entry CorpusEntry
// crasherMsg is an error message from a crash. It's "" if no crash was found.
crasherMsg string
// canMinimize is true if the worker should attempt to minimize this result.
// It may be false because an attempt has already been made.
canMinimize bool
// coverageData is set if the worker found new coverage.
coverageData []byte
// limit is the number of values the coordinator asked the worker
// to test. 0 if there was no limit.
limit int64
// count is the number of values the worker actually tested.
count int64
// totalDuration is the time the worker spent testing inputs.
totalDuration time.Duration
// entryDuration is the time the worker spent executing an interesting result.
entryDuration time.Duration
}
type fuzzMinimizeInput struct {
// entry is an interesting value or crasher to minimize.
entry CorpusEntry
// crasherMsg is an error message from a crash. It's "" if no crash was found.
// If set, the worker will attempt to find a smaller input that also produces
// an error, though not necessarily the same error.
crasherMsg string
// limit is the maximum number of calls to the fuzz function the worker may
// make. The worker may make fewer calls, for example, if it can't reproduce
// an error. If limit is zero, there is no limit on calls to the fuzz function.
limit int64
// timeout is the time to spend minimizing this input.
// A zero timeout means no limit.
timeout time.Duration
// keepCoverage is a set of coverage bits that entry found that were not in
// the coordinator's combined set. When minimizing, the worker should find an
// input that preserves at least one of these bits. keepCoverage is nil for
// crashing inputs.
keepCoverage []byte
}
// coordinator holds channels that workers can use to communicate with
// the coordinator.
type coordinator struct {
opts CoordinateFuzzingOpts
// startTime is the time we started the workers after loading the corpus.
// Used for logging.
startTime time.Time
// inputC is sent values to fuzz by the coordinator. Any worker may receive
// values from this channel. Workers send results to resultC.
inputC chan fuzzInput
// minimizeC is sent values to minimize by the coordinator. Any worker may
// receive values from this channel. Workers send results to resultC.
minimizeC chan fuzzMinimizeInput
// resultC is sent results of fuzzing by workers. The coordinator
// receives these. Multiple types of messages are allowed.
resultC chan fuzzResult
// count is the number of values fuzzed so far.
count int64
// countLastLog is the number of values fuzzed when the output was last
// logged.
countLastLog int64
// timeLastLog is the time at which the output was last logged.
timeLastLog time.Time
// interestingCount is the number of unique interesting values which have
// been found this execution.
interestingCount int
// warmupInputCount is the count of all entries in the corpus which will
// need to be received from workers to run once during warmup, but not fuzz.
// This could be for coverage data, or only for the purposes of verifying
// that the seed corpus doesn't have any crashers. See warmupRun.
warmupInputCount int
// warmupInputLeft is the number of entries in the corpus which still need
// to be received from workers to run once during warmup, but not fuzz.
// See warmupInputCount.
warmupInputLeft int
// duration is the time spent fuzzing inside workers, not counting time
// starting up or tearing down.
duration time.Duration
// countWaiting is the number of fuzzing executions the coordinator is
// waiting on workers to complete.
countWaiting int64
// corpus is a set of interesting values, including the seed corpus and
// generated values that workers reported as interesting.
corpus corpus
// minimizationAllowed is true if one or more of the types of fuzz
// function's parameters can be minimized.
minimizationAllowed bool
// inputQueue is a queue of inputs that workers should try fuzzing. This is
// initially populated from the seed corpus and cached inputs. More inputs
// may be added as new coverage is discovered.
inputQueue queue
// minimizeQueue is a queue of inputs that caused errors or exposed new
// coverage. Workers should attempt to find smaller inputs that do the
// same thing.
minimizeQueue queue
// crashMinimizing is the crash that is currently being minimized.
crashMinimizing *fuzzResult
// coverageMask aggregates coverage that was found for all inputs in the
// corpus. Each byte represents a single basic execution block. Each set bit
// within the byte indicates that an input has triggered that block at least
// 1 << n times, where n is the position of the bit in the byte. For example, a
// value of 12 indicates that separate inputs have triggered this block
// between 4-7 times and 8-15 times.
coverageMask []byte
}
func newCoordinator(opts CoordinateFuzzingOpts) (*coordinator, error) {
// Make sure all of the seed corpus has marshalled data.
for i := range opts.Seed {
if opts.Seed[i].Data == nil && opts.Seed[i].Values != nil {
opts.Seed[i].Data = marshalCorpusFile(opts.Seed[i].Values...)
}
}
c := &coordinator{
opts: opts,
startTime: time.Now(),
inputC: make(chan fuzzInput),
minimizeC: make(chan fuzzMinimizeInput),
resultC: make(chan fuzzResult),
timeLastLog: time.Now(),
corpus: corpus{hashes: make(map[[sha256.Size]byte]bool)},
}
if err := c.readCache(); err != nil {
return nil, err
}
if opts.MinimizeLimit > 0 || opts.MinimizeTimeout > 0 {
for _, t := range opts.Types {
if isMinimizable(t) {
c.minimizationAllowed = true
break
}
}
}
covSize := len(coverage())
if covSize == 0 {
fmt.Fprintf(c.opts.Log, "warning: the test binary was not built with coverage instrumentation, so fuzzing will run without coverage guidance and may be inefficient\n")
// Even though a coverage-only run won't occur, we should still run all
// of the seed corpus to make sure there are no existing failures before
// we start fuzzing.
c.warmupInputCount = len(c.opts.Seed)
for _, e := range c.opts.Seed {
c.inputQueue.enqueue(e)
}
} else {
c.warmupInputCount = len(c.corpus.entries)
for _, e := range c.corpus.entries {
c.inputQueue.enqueue(e)
}
// Set c.coverageMask to a clean []byte full of zeros.
c.coverageMask = make([]byte, covSize)
}
c.warmupInputLeft = c.warmupInputCount
if len(c.corpus.entries) == 0 {
fmt.Fprintf(c.opts.Log, "warning: starting with empty corpus\n")
var vals []any
for _, t := range opts.Types {
vals = append(vals, zeroValue(t))
}
data := marshalCorpusFile(vals...)
h := sha256.Sum256(data)
name := fmt.Sprintf("%x", h[:4])
c.addCorpusEntries(false, CorpusEntry{Path: name, Data: data})
}
return c, nil
}
func (c *coordinator) updateStats(result fuzzResult) {
c.count += result.count
c.countWaiting -= result.limit
c.duration += result.totalDuration
}
func (c *coordinator) logStats() {
now := time.Now()
if c.warmupRun() {
runSoFar := c.warmupInputCount - c.warmupInputLeft
if coverageEnabled {
fmt.Fprintf(c.opts.Log, "fuzz: elapsed: %s, gathering baseline coverage: %d/%d completed\n", c.elapsed(), runSoFar, c.warmupInputCount)
} else {
fmt.Fprintf(c.opts.Log, "fuzz: elapsed: %s, testing seed corpus: %d/%d completed\n", c.elapsed(), runSoFar, c.warmupInputCount)
}
} else if c.crashMinimizing != nil {
fmt.Fprintf(c.opts.Log, "fuzz: elapsed: %s, minimizing\n", c.elapsed())
} else {
rate := float64(c.count-c.countLastLog) / now.Sub(c.timeLastLog).Seconds()
if coverageEnabled {
total := c.warmupInputCount + c.interestingCount
fmt.Fprintf(c.opts.Log, "fuzz: elapsed: %s, execs: %d (%.0f/sec), new interesting: %d (total: %d)\n", c.elapsed(), c.count, rate, c.interestingCount, total)
} else {
fmt.Fprintf(c.opts.Log, "fuzz: elapsed: %s, execs: %d (%.0f/sec)\n", c.elapsed(), c.count, rate)
}
}
c.countLastLog = c.count
c.timeLastLog = now
}
// peekInput returns the next value that should be sent to workers.
// If the number of executions is limited, the returned value includes
// a limit for one worker. If there are no executions left, peekInput returns
// a zero value and false.
//
// peekInput doesn't actually remove the input from the queue. The caller
// must call sentInput after sending the input.
//
// If the input queue is empty and the coverage/testing-only run has completed,
// peekInput refills it from the corpus.
func (c *coordinator) peekInput() (fuzzInput, bool) {
if c.opts.Limit > 0 && c.count+c.countWaiting >= c.opts.Limit {
// Already making the maximum number of calls to the fuzz function.
// Don't send more inputs right now.
return fuzzInput{}, false
}
if c.inputQueue.len == 0 {
if c.warmupRun() {
// Wait for coverage/testing-only run to finish before sending more
// inputs.
return fuzzInput{}, false
}
c.refillInputQueue()
}
entry, ok := c.inputQueue.peek()
if !ok {
panic("input queue empty after refill")
}
input := fuzzInput{
entry: entry.(CorpusEntry),
timeout: workerFuzzDuration,
warmup: c.warmupRun(),
}
if c.coverageMask != nil {
input.coverageData = bytes.Clone(c.coverageMask)
}
if input.warmup {
// No fuzzing will occur, but it should count toward the limit set by
// -fuzztime.
input.limit = 1
return input, true
}
if c.opts.Limit > 0 {
input.limit = c.opts.Limit / int64(c.opts.Parallel)
if c.opts.Limit%int64(c.opts.Parallel) > 0 {
input.limit++
}
remaining := c.opts.Limit - c.count - c.countWaiting
if input.limit > remaining {
input.limit = remaining
}
}
return input, true
}
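// Illustration of the limit split above (hypothetical numbers): with
// opts.Limit = 10 and opts.Parallel = 4, each input is sent with
// limit = ceil(10/4) = 3, clamped to the executions still remaining.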
// sentInput updates internal counters after an input is sent to c.inputC.
func (c *coordinator) sentInput(input fuzzInput) {
c.inputQueue.dequeue()
c.countWaiting += input.limit
}
// refillInputQueue refills the input queue from the corpus after it becomes
// empty.
func (c *coordinator) refillInputQueue() {
for _, e := range c.corpus.entries {
c.inputQueue.enqueue(e)
}
}
// queueForMinimization creates a fuzzMinimizeInput from result and adds it
// to the minimization queue to be sent to workers.
func (c *coordinator) queueForMinimization(result fuzzResult, keepCoverage []byte) {
if result.crasherMsg != "" {
c.minimizeQueue.clear()
}
input := fuzzMinimizeInput{
entry: result.entry,
crasherMsg: result.crasherMsg,
keepCoverage: keepCoverage,
}
c.minimizeQueue.enqueue(input)
}
// peekMinimizeInput returns the next input that should be sent to workers for
// minimization.
func (c *coordinator) peekMinimizeInput() (fuzzMinimizeInput, bool) {
if !c.canMinimize() {
// Already making the maximum number of calls to the fuzz function.
// Don't send more inputs right now.
return fuzzMinimizeInput{}, false
}
v, ok := c.minimizeQueue.peek()
if !ok {
return fuzzMinimizeInput{}, false
}
input := v.(fuzzMinimizeInput)
if c.opts.MinimizeTimeout > 0 {
input.timeout = c.opts.MinimizeTimeout
}
if c.opts.MinimizeLimit > 0 {
input.limit = c.opts.MinimizeLimit
} else if c.opts.Limit > 0 {
if input.crasherMsg != "" {
input.limit = c.opts.Limit
} else {
input.limit = c.opts.Limit / int64(c.opts.Parallel)
if c.opts.Limit%int64(c.opts.Parallel) > 0 {
input.limit++
}
}
}
if c.opts.Limit > 0 {
remaining := c.opts.Limit - c.count - c.countWaiting
if input.limit > remaining {
input.limit = remaining
}
}
return input, true
}
// sentMinimizeInput removes an input from the minimization queue after it's
// sent to minimizeC.
func (c *coordinator) sentMinimizeInput(input fuzzMinimizeInput) {
c.minimizeQueue.dequeue()
c.countWaiting += input.limit
}
// warmupRun returns true while the coordinator is running inputs without
// mutating them as a warmup before fuzzing. This could be to gather baseline
// coverage data for entries in the corpus, or to test all of the seed corpus
// for errors before fuzzing begins.
//
// The coordinator doesn't store coverage data in the cache with each input
// because that data would be invalid when counter offsets in the test binary
// change.
//
// When gathering coverage, the coordinator sends each entry to a worker to
// gather coverage for that entry only, without fuzzing or minimizing. This
// phase ends when all workers have finished, and the coordinator has a combined
// coverage map.
func (c *coordinator) warmupRun() bool {
return c.warmupInputLeft > 0
}
// updateCoverage sets bits in c.coverageMask that are set in newCoverage.
// updateCoverage returns the number of newly set bits. See the comment on
// coverageMask for the format.
func (c *coordinator) updateCoverage(newCoverage []byte) int {
if len(newCoverage) != len(c.coverageMask) {
panic(fmt.Sprintf("number of coverage counters changed at runtime: %d, expected %d", len(newCoverage), len(c.coverageMask)))
}
newBitCount := 0
for i := range newCoverage {
diff := newCoverage[i] &^ c.coverageMask[i]
newBitCount += bits.OnesCount8(diff)
c.coverageMask[i] |= newCoverage[i]
}
return newBitCount
}
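// For example (an illustrative sketch, not part of the original source),
// with a single-byte mask the update behaves as follows:
//
//	mask := []byte{0b0000_1010}
//	newCov := []byte{0b0000_1110}
//	diff := newCov[0] &^ mask[0] // 0b0000_0100: bits set only in newCov
//	n := bits.OnesCount8(diff)   // 1 newly set bit
//	mask[0] |= newCov[0]         // mask becomes 0b0000_1110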
// canMinimize returns whether the coordinator should attempt to find smaller
// inputs that reproduce a crash or new coverage.
func (c *coordinator) canMinimize() bool {
return c.minimizationAllowed &&
(c.opts.Limit == 0 || c.count+c.countWaiting < c.opts.Limit)
}
func (c *coordinator) elapsed() time.Duration {
return time.Since(c.startTime).Round(1 * time.Second)
}
// readCache creates a combined corpus from seed values and values in the cache
// (in GOCACHE/fuzz).
//
// TODO(fuzzing): need a mechanism that can remove values that
// aren't useful anymore, for example, because they have the wrong type.
func (c *coordinator) readCache() error {
if _, err := c.addCorpusEntries(false, c.opts.Seed...); err != nil {
return err
}
entries, err := ReadCorpus(c.opts.CacheDir, c.opts.Types)
if err != nil {
if _, ok := err.(*MalformedCorpusError); !ok {
// It's okay if some files in the cache directory are malformed and
// are not included in the corpus, but fail if it's an I/O error.
return err
}
// TODO(jayconrod,katiehockman): consider printing some kind of warning
// indicating the number of files which were skipped because they are
// malformed.
}
if _, err := c.addCorpusEntries(false, entries...); err != nil {
return err
}
return nil
}
// MalformedCorpusError is an error found while reading the corpus from the
// filesystem. All of the errors are stored in the errs list. The testing
// framework uses this to report malformed files in testdata.
type MalformedCorpusError struct {
errs []error
}
func (e *MalformedCorpusError) Error() string {
var msgs []string
for _, s := range e.errs {
msgs = append(msgs, s.Error())
}
return strings.Join(msgs, "\n")
}
// ReadCorpus reads the corpus from the provided dir. The returned corpus
// entries are guaranteed to match the given types. Errors for any malformed
// files are collected in a MalformedCorpusError, which is returned along with
// the valid entries.
func ReadCorpus(dir string, types []reflect.Type) ([]CorpusEntry, error) {
files, err := os.ReadDir(dir)
if os.IsNotExist(err) {
return nil, nil // No corpus to read
} else if err != nil {
return nil, fmt.Errorf("reading seed corpus from testdata: %v", err)
}
var corpus []CorpusEntry
var errs []error
for _, file := range files {
// TODO(jayconrod,katiehockman): determine when a file is a fuzzing input
// based on its name. We should only read files created by writeToCorpus.
// If we read ALL files, we won't be able to change the file format by
// changing the extension. We also won't be able to add files like
// README.txt explaining why the directory exists.
if file.IsDir() {
continue
}
filename := filepath.Join(dir, file.Name())
data, err := os.ReadFile(filename)
if err != nil {
return nil, fmt.Errorf("failed to read corpus file: %v", err)
}
var vals []any
vals, err = readCorpusData(data, types)
if err != nil {
errs = append(errs, fmt.Errorf("%q: %v", filename, err))
continue
}
corpus = append(corpus, CorpusEntry{Path: filename, Values: vals})
}
if len(errs) > 0 {
return corpus, &MalformedCorpusError{errs: errs}
}
return corpus, nil
}
func readCorpusData(data []byte, types []reflect.Type) ([]any, error) {
vals, err := unmarshalCorpusFile(data)
if err != nil {
return nil, fmt.Errorf("unmarshal: %v", err)
}
if err = CheckCorpus(vals, types); err != nil {
return nil, err
}
return vals, nil
}
// CheckCorpus verifies that the types in vals match the expected types
// provided.
func CheckCorpus(vals []any, types []reflect.Type) error {
if len(vals) != len(types) {
return fmt.Errorf("wrong number of values in corpus entry: %d, want %d", len(vals), len(types))
}
valsT := make([]reflect.Type, len(vals))
for valsI, v := range vals {
valsT[valsI] = reflect.TypeOf(v)
}
for i := range types {
if valsT[i] != types[i] {
return fmt.Errorf("mismatched types in corpus entry: %v, want %v", valsT, types)
}
}
return nil
}
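// For example (an illustrative sketch, not part of the original source),
// an entry for a fuzz target taking (string, int) must carry exactly those
// types in that order:
//
//	types := []reflect.Type{reflect.TypeOf(""), reflect.TypeOf(int(0))}
//	err := CheckCorpus([]any{"hello", 42}, types) // nil
//	err = CheckCorpus([]any{42, "hello"}, types)  // mismatched types error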
// writeToCorpus atomically writes the given bytes to a new file in testdata. If
// the directory does not exist, it will be created. If the file already exists,
// writeToCorpus will not rewrite it. writeToCorpus sets entry.Path to the new
// file that was just written, and returns an error if the write failed.
func writeToCorpus(entry *CorpusEntry, dir string) (err error) {
sum := fmt.Sprintf("%x", sha256.Sum256(entry.Data))[:16]
entry.Path = filepath.Join(dir, sum)
if err := os.MkdirAll(dir, 0777); err != nil {
return err
}
if err := os.WriteFile(entry.Path, entry.Data, 0666); err != nil {
os.Remove(entry.Path) // remove partially written file
return err
}
return nil
}
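// For example (an illustrative sketch, not part of the original source),
// the file name is the first 16 hex digits of the SHA-256 sum of the encoded
// entry, so identical inputs always map to the same path:
//
//	sum := fmt.Sprintf("%x", sha256.Sum256(entry.Data))[:16]
//	path := filepath.Join("testdata/fuzz/FuzzFoo", sum) // directory name is illustrative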
func testName(path string) string {
return filepath.Base(path)
}
func zeroValue(t reflect.Type) any {
for _, v := range zeroVals {
if reflect.TypeOf(v) == t {
return v
}
}
panic(fmt.Sprintf("unsupported type: %v", t))
}
var zeroVals []any = []any{
[]byte(""),
string(""),
false,
byte(0),
rune(0),
float32(0),
float64(0),
int(0),
int8(0),
int16(0),
int32(0),
int64(0),
uint(0),
uint8(0),
uint16(0),
uint32(0),
uint64(0),
}
var debugInfo = godebug.New("fuzzdebug").Value() == "1"
func shouldPrintDebugInfo() bool {
return debugInfo
}
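// For example (an illustrative usage note, not part of the original source),
// the extra debug logging is enabled through the GODEBUG environment variable:
//
//	GODEBUG=fuzzdebug=1 go test -fuzz=Fuzz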
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fuzz
import (
"bytes"
"fmt"
"os"
"unsafe"
)
// sharedMem manages access to a region of virtual memory mapped from a file,
// shared between multiple processes. The region includes space for a header and
// a value of variable length.
//
// When fuzzing, the coordinator creates a sharedMem from a temporary file for
// each worker. This buffer is used to pass values to fuzz between processes.
// Care must be taken to manage access to shared memory across processes;
// sharedMem provides no synchronization on its own. See workerComm for an
// explanation.
type sharedMem struct {
// f is the file mapped into memory.
f *os.File
// region is the mapped region of virtual memory for f. The content of f may
// be read or written through this slice.
region []byte
// removeOnClose is true if the file should be deleted by Close.
removeOnClose bool
// sys contains OS-specific information.
sys sharedMemSys
}
// sharedMemHeader stores metadata in shared memory.
type sharedMemHeader struct {
// count is the number of times the worker has called the fuzz function.
// May be reset by coordinator.
count int64
// valueLen is the number of bytes in region which should be read.
valueLen int
// randState and randInc hold the state of a pseudo-random number generator.
randState, randInc uint64
// rawInMem is true if the region holds raw bytes, which occurs during
// minimization. If true after the worker fails during minimization, this
// indicates that an unrecoverable error occurred, and the region can be
// used to retrieve the raw bytes that caused the error.
rawInMem bool
}
// sharedMemSize returns the size needed for a shared memory buffer that can
// contain values of the given size.
func sharedMemSize(valueSize int) int {
// TODO(jayconrod): set a reasonable maximum size per platform.
return int(unsafe.Sizeof(sharedMemHeader{})) + valueSize
}
// sharedMemTempFile creates a new temporary file of the given size, then maps
// it into memory. The file will be removed when the Close method is called.
func sharedMemTempFile(size int) (m *sharedMem, err error) {
// Create a temporary file.
f, err := os.CreateTemp("", "fuzz-*")
if err != nil {
return nil, err
}
defer func() {
if err != nil {
f.Close()
os.Remove(f.Name())
}
}()
// Resize it to the correct size.
totalSize := sharedMemSize(size)
if err := f.Truncate(int64(totalSize)); err != nil {
return nil, err
}
// Map the file into memory.
removeOnClose := true
return sharedMemMapFile(f, totalSize, removeOnClose)
}
// header returns a pointer to metadata within the shared memory region.
func (m *sharedMem) header() *sharedMemHeader {
return (*sharedMemHeader)(unsafe.Pointer(&m.region[0]))
}
// valueRef returns the value currently stored in shared memory. The returned
// slice points to shared memory; it is not a copy.
func (m *sharedMem) valueRef() []byte {
length := m.header().valueLen
valueOffset := int(unsafe.Sizeof(sharedMemHeader{}))
return m.region[valueOffset : valueOffset+length]
}
// valueCopy returns a copy of the value stored in shared memory.
func (m *sharedMem) valueCopy() []byte {
ref := m.valueRef()
return bytes.Clone(ref)
}
// setValue copies the data in b into the shared memory buffer and sets
// the length. len(b) must be less than or equal to the capacity of the buffer
// (as returned by cap(m.valueRef())).
func (m *sharedMem) setValue(b []byte) {
v := m.valueRef()
if len(b) > cap(v) {
panic(fmt.Sprintf("value length %d larger than shared memory capacity %d", len(b), cap(v)))
}
m.header().valueLen = len(b)
copy(v[:cap(v)], b)
}
// setValueLen sets the length of the shared memory buffer returned by valueRef
// to n, which may be at most the cap of that slice.
//
// Note that we can only store the length in the shared memory header. The full
// slice header contains a pointer, which is likely only valid for one process,
// since each process can map shared memory at a different virtual address.
func (m *sharedMem) setValueLen(n int) {
v := m.valueRef()
if n > cap(v) {
panic(fmt.Sprintf("length %d larger than shared memory capacity %d", n, cap(v)))
}
m.header().valueLen = n
}
// TODO(jayconrod): add method to resize the buffer. We'll need that when the
// mutator can increase input length. Only the coordinator will be able to
// do it, since we'll need to send a message to the worker telling it to
// remap the file.
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fuzz
import (
"reflect"
)
func isMinimizable(t reflect.Type) bool {
return t == reflect.TypeOf("") || t == reflect.TypeOf([]byte(nil))
}
func minimizeBytes(v []byte, try func([]byte) bool, shouldStop func() bool) {
tmp := make([]byte, len(v))
// If minimization was successful at any point during minimizeBytes,
// then the vals slice in (*workerServer).minimizeInput will point to
// tmp. Since tmp is altered while making new candidates, we need to
// make sure that it is equal to the correct value, v, before exiting
// this function.
defer copy(tmp, v)
// First, try to cut the tail.
for n := 1024; n != 0; n /= 2 {
for len(v) > n {
if shouldStop() {
return
}
candidate := v[:len(v)-n]
if !try(candidate) {
break
}
// Set v to the new value to continue iterating.
v = candidate
}
}
// Then, try to remove each individual byte.
for i := 0; i < len(v)-1; i++ {
if shouldStop() {
return
}
candidate := tmp[:len(v)-1]
copy(candidate[:i], v[:i])
copy(candidate[i:], v[i+1:])
if !try(candidate) {
continue
}
// Update v to delete the value at index i.
copy(v[i:], v[i+1:])
v = v[:len(candidate)]
// v[i] is now different, so decrement i to redo this iteration
// of the loop with the new value.
i--
}
// Then, try to remove each possible subset of bytes.
for i := 0; i < len(v)-1; i++ {
copy(tmp, v[:i])
for j := len(v); j > i+1; j-- {
if shouldStop() {
return
}
candidate := tmp[:len(v)-j+i]
copy(candidate[i:], v[j:])
if !try(candidate) {
continue
}
// Update v and reset the loop with the new length.
copy(v[i:], v[j:])
v = v[:len(candidate)]
j = len(v)
}
}
// Finally, try to simplify the input and make it more human-readable by
// replacing each byte with a printable character.
printableChars := []byte("012789ABCXYZabcxyz !\"#$%&'()*+,.")
for i, b := range v {
if shouldStop() {
return
}
for _, pc := range printableChars {
v[i] = pc
if try(v) {
// Successful. Move on to the next byte in v.
break
}
// Unsuccessful. Revert v[i] back to original value.
v[i] = b
}
}
}
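// For example (an illustrative sketch, not part of the original source),
// a caller drives minimizeBytes with two closures: try reports whether a
// candidate still reproduces the failure, and shouldStop enforces a budget:
//
//	count := 0
//	try := func(candidate []byte) bool {
//		count++
//		return stillFails(candidate) // stillFails is a hypothetical predicate
//	}
//	shouldStop := func() bool { return count >= 1000 }
//	minimizeBytes(v, try, shouldStop)
//	// v now holds the smallest failing input found within the budget.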
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fuzz
import (
"encoding/binary"
"fmt"
"math"
"unsafe"
)
type mutator struct {
r mutatorRand
scratch []byte // scratch slice to avoid additional allocations
}
func newMutator() *mutator {
return &mutator{r: newPcgRand()}
}
func (m *mutator) rand(n int) int {
return m.r.intn(n)
}
func (m *mutator) randByteOrder() binary.ByteOrder {
if m.r.bool() {
return binary.LittleEndian
}
return binary.BigEndian
}
// chooseLen chooses the length of a range mutation in the range [1,n].
// It gives preference to shorter ranges.
func (m *mutator) chooseLen(n int) int {
switch x := m.rand(100); {
case x < 90:
return m.rand(min(8, n)) + 1
case x < 99:
return m.rand(min(32, n)) + 1
default:
return m.rand(n) + 1
}
}
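// In other words, chooseLen draws from a mixture: 90% of draws land in
// [1, min(8, n)], 9% in [1, min(32, n)], and 1% anywhere in [1, n].
// For example (illustrative, not part of the original source):
//
//	m.chooseLen(1000) // usually 1..8, occasionally 1..32, rarely 1..1000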
func min(a, b int) int {
if a < b {
return a
}
return b
}
// mutate performs several mutations on the provided values.
func (m *mutator) mutate(vals []any, maxBytes int) {
// TODO(katiehockman): pull some of these functions into helper methods and
// test that each case is working as expected.
// TODO(katiehockman): perform more types of mutations for []byte.
// maxPerVal is the maximum number of bytes that each value is allowed to
// occupy after mutating, giving an equal amount of capacity to each value.
// Allow a little wiggle room for the encoding.
maxPerVal := maxBytes/len(vals) - 100
// Pick a random value to mutate.
// TODO: consider mutating more than one value at a time.
i := m.rand(len(vals))
switch v := vals[i].(type) {
case int:
vals[i] = int(m.mutateInt(int64(v), maxInt))
case int8:
vals[i] = int8(m.mutateInt(int64(v), math.MaxInt8))
case int16:
vals[i] = int16(m.mutateInt(int64(v), math.MaxInt16))
case int64:
vals[i] = m.mutateInt(v, maxInt)
case uint:
vals[i] = uint(m.mutateUInt(uint64(v), maxUint))
case uint16:
vals[i] = uint16(m.mutateUInt(uint64(v), math.MaxUint16))
case uint32:
vals[i] = uint32(m.mutateUInt(uint64(v), math.MaxUint32))
case uint64:
vals[i] = m.mutateUInt(uint64(v), maxUint)
case float32:
vals[i] = float32(m.mutateFloat(float64(v), math.MaxFloat32))
case float64:
vals[i] = m.mutateFloat(v, math.MaxFloat64)
case bool:
if m.rand(2) == 1 {
vals[i] = !v // 50% chance of flipping the bool
}
case rune: // int32
vals[i] = rune(m.mutateInt(int64(v), math.MaxInt32))
case byte: // uint8
vals[i] = byte(m.mutateUInt(uint64(v), math.MaxUint8))
case string:
if len(v) > maxPerVal {
panic(fmt.Sprintf("cannot mutate bytes of length %d", len(v)))
}
if cap(m.scratch) < maxPerVal {
m.scratch = append(make([]byte, 0, maxPerVal), v...)
} else {
m.scratch = m.scratch[:len(v)]
copy(m.scratch, v)
}
m.mutateBytes(&m.scratch)
vals[i] = string(m.scratch)
case []byte:
if len(v) > maxPerVal {
panic(fmt.Sprintf("cannot mutate bytes of length %d", len(v)))
}
if cap(m.scratch) < maxPerVal {
m.scratch = append(make([]byte, 0, maxPerVal), v...)
} else {
m.scratch = m.scratch[:len(v)]
copy(m.scratch, v)
}
m.mutateBytes(&m.scratch)
vals[i] = m.scratch
default:
panic(fmt.Sprintf("type not supported for mutating: %T", vals[i]))
}
}
func (m *mutator) mutateInt(v, maxValue int64) int64 {
var max int64
for {
max = 100
switch m.rand(2) {
case 0:
// Add a random number
if v >= maxValue {
continue
}
if v > 0 && maxValue-v < max {
// Don't let v exceed maxValue
max = maxValue - v
}
v += int64(1 + m.rand(int(max)))
return v
case 1:
// Subtract a random number
if v <= -maxValue {
continue
}
if v < 0 && maxValue+v < max {
// Don't let v drop below -maxValue
max = maxValue + v
}
v -= int64(1 + m.rand(int(max)))
return v
}
}
}
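// For example (an illustrative sketch, not part of the original source),
// the clamp above keeps an addition from overshooting maxValue: with
// v = maxValue - 30, max shrinks from 100 to 30, so
//
//	v += int64(1 + m.rand(30)) // lands in (v, maxValue]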
func (m *mutator) mutateUInt(v, maxValue uint64) uint64 {
var max uint64
for {
max = 100
switch m.rand(2) {
case 0:
// Add a random number
if v >= maxValue {
continue
}
if v > 0 && maxValue-v < max {
// Don't let v exceed maxValue
max = maxValue - v
}
v += uint64(1 + m.rand(int(max)))
return v
case 1:
// Subtract a random number
if v <= 0 {
continue
}
if v < max {
// Don't let v drop below 0
max = v
}
v -= uint64(1 + m.rand(int(max)))
return v
}
}
}
func (m *mutator) mutateFloat(v, maxValue float64) float64 {
var max float64
for {
switch m.rand(4) {
case 0:
// Add a random number
if v >= maxValue {
continue
}
max = 100
if v > 0 && maxValue-v < max {
// Don't let v exceed maxValue
max = maxValue - v
}
v += float64(1 + m.rand(int(max)))
return v
case 1:
// Subtract a random number
if v <= -maxValue {
continue
}
max = 100
if v < 0 && maxValue+v < max {
// Don't let v drop below -maxValue
max = maxValue + v
}
v -= float64(1 + m.rand(int(max)))
return v
case 2:
// Multiply by a random number
absV := math.Abs(v)
if v == 0 || absV >= maxValue {
continue
}
max = 10
if maxValue/absV < max {
// Don't let v go beyond the minimum or maximum value
max = maxValue / absV
}
v *= float64(1 + m.rand(int(max)))
return v
case 3:
// Divide by a random number
if v == 0 {
continue
}
v /= float64(1 + m.rand(10))
return v
}
}
}
type byteSliceMutator func(*mutator, []byte) []byte
var byteSliceMutators = []byteSliceMutator{
byteSliceRemoveBytes,
byteSliceInsertRandomBytes,
byteSliceDuplicateBytes,
byteSliceOverwriteBytes,
byteSliceBitFlip,
byteSliceXORByte,
byteSliceSwapByte,
byteSliceArithmeticUint8,
byteSliceArithmeticUint16,
byteSliceArithmeticUint32,
byteSliceArithmeticUint64,
byteSliceOverwriteInterestingUint8,
byteSliceOverwriteInterestingUint16,
byteSliceOverwriteInterestingUint32,
byteSliceInsertConstantBytes,
byteSliceOverwriteConstantBytes,
byteSliceShuffleBytes,
byteSliceSwapBytes,
}
func (m *mutator) mutateBytes(ptrB *[]byte) {
b := *ptrB
defer func() {
if unsafe.SliceData(*ptrB) != unsafe.SliceData(b) {
panic("data moved to new address")
}
*ptrB = b
}()
for {
mut := byteSliceMutators[m.rand(len(byteSliceMutators))]
if mutated := mut(m, b); mutated != nil {
b = mutated
return
}
}
}
var (
interesting8 = []int8{-128, -1, 0, 1, 16, 32, 64, 100, 127}
interesting16 = []int16{-32768, -129, 128, 255, 256, 512, 1000, 1024, 4096, 32767}
interesting32 = []int32{-2147483648, -100663046, -32769, 32768, 65535, 65536, 100663045, 2147483647}
)
const (
maxUint = uint64(^uint(0))
maxInt = int64(maxUint >> 1)
)
func init() {
for _, v := range interesting8 {
interesting16 = append(interesting16, int16(v))
}
for _, v := range interesting16 {
interesting32 = append(interesting32, int32(v))
}
}
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fuzz
// byteSliceRemoveBytes removes a random chunk of bytes from b.
func byteSliceRemoveBytes(m *mutator, b []byte) []byte {
if len(b) <= 1 {
return nil
}
pos0 := m.rand(len(b))
pos1 := pos0 + m.chooseLen(len(b)-pos0)
copy(b[pos0:], b[pos1:])
b = b[:len(b)-(pos1-pos0)]
return b
}
// byteSliceInsertRandomBytes inserts a chunk of random bytes into b at a random
// position.
func byteSliceInsertRandomBytes(m *mutator, b []byte) []byte {
pos := m.rand(len(b) + 1)
n := m.chooseLen(1024)
if len(b)+n >= cap(b) {
return nil
}
b = b[:len(b)+n]
copy(b[pos+n:], b[pos:])
for i := 0; i < n; i++ {
b[pos+i] = byte(m.rand(256))
}
return b
}
// byteSliceDuplicateBytes duplicates a chunk of bytes in b and inserts it into
// a random position.
func byteSliceDuplicateBytes(m *mutator, b []byte) []byte {
if len(b) <= 1 {
return nil
}
src := m.rand(len(b))
dst := m.rand(len(b))
for dst == src {
dst = m.rand(len(b))
}
n := m.chooseLen(len(b) - src)
// Use the end of the slice as scratch space to avoid doing an
// allocation. If the slice is too small, abort and try something
// else.
if len(b)+(n*2) >= cap(b) {
return nil
}
end := len(b)
// Increase the size of b to fit the duplicated block as well as
// some extra working space
b = b[:end+(n*2)]
// Copy the block of bytes we want to duplicate to the end of the
// slice
copy(b[end+n:], b[src:src+n])
// Shift the bytes after the splice point n positions to the right
// to make room for the new block
copy(b[dst+n:end+n], b[dst:end])
// Insert the duplicate block into the splice point
copy(b[dst:], b[end+n:])
b = b[:end+n]
return b
}
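// For example (an illustrative walkthrough, not part of the original source),
// with b = [A B C D], src = 1, dst = 3, n = 2:
//
//	b = b[:8]            // [A B C D _ _ _ _] grow into scratch space
//	copy(b[6:], b[1:3])  // [A B C D _ _ B C] stash the block at the end
//	copy(b[5:6], b[3:4]) // [A B C D _ D B C] shift the tail right by n
//	copy(b[3:], b[6:8])  // [A B C B C D B C] splice the block in at dst
//	b = b[:6]            // [A B C B C D]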
// byteSliceOverwriteBytes overwrites a chunk of b with another chunk of b.
func byteSliceOverwriteBytes(m *mutator, b []byte) []byte {
if len(b) <= 1 {
return nil
}
src := m.rand(len(b))
dst := m.rand(len(b))
for dst == src {
dst = m.rand(len(b))
}
n := m.chooseLen(len(b) - src - 1)
copy(b[dst:], b[src:src+n])
return b
}
// byteSliceBitFlip flips a random bit in a random byte in b.
func byteSliceBitFlip(m *mutator, b []byte) []byte {
if len(b) == 0 {
return nil
}
pos := m.rand(len(b))
b[pos] ^= 1 << uint(m.rand(8))
return b
}
// byteSliceXORByte XORs a random byte in b with a random value.
func byteSliceXORByte(m *mutator, b []byte) []byte {
if len(b) == 0 {
return nil
}
pos := m.rand(len(b))
// In order to avoid a no-op (where the random value matches
// the existing value), use XOR instead of just setting to
// the random value.
b[pos] ^= byte(1 + m.rand(255))
return b
}
// byteSliceSwapByte swaps two random bytes in b.
func byteSliceSwapByte(m *mutator, b []byte) []byte {
if len(b) <= 1 {
return nil
}
src := m.rand(len(b))
dst := m.rand(len(b))
for dst == src {
dst = m.rand(len(b))
}
b[src], b[dst] = b[dst], b[src]
return b
}
// byteSliceArithmeticUint8 adds/subtracts from a random byte in b.
func byteSliceArithmeticUint8(m *mutator, b []byte) []byte {
if len(b) == 0 {
return nil
}
pos := m.rand(len(b))
v := byte(m.rand(35) + 1)
if m.r.bool() {
b[pos] += v
} else {
b[pos] -= v
}
return b
}
// byteSliceArithmeticUint16 adds/subtracts from a random uint16 in b.
func byteSliceArithmeticUint16(m *mutator, b []byte) []byte {
if len(b) < 2 {
return nil
}
v := uint16(m.rand(35) + 1)
if m.r.bool() {
v = 0 - v
}
pos := m.rand(len(b) - 1)
enc := m.randByteOrder()
enc.PutUint16(b[pos:], enc.Uint16(b[pos:])+v)
return b
}
// byteSliceArithmeticUint32 adds/subtracts from a random uint32 in b.
func byteSliceArithmeticUint32(m *mutator, b []byte) []byte {
if len(b) < 4 {
return nil
}
v := uint32(m.rand(35) + 1)
if m.r.bool() {
v = 0 - v
}
pos := m.rand(len(b) - 3)
enc := m.randByteOrder()
enc.PutUint32(b[pos:], enc.Uint32(b[pos:])+v)
return b
}
// byteSliceArithmeticUint64 adds/subtracts from a random uint64 in b.
func byteSliceArithmeticUint64(m *mutator, b []byte) []byte {
if len(b) < 8 {
return nil
}
v := uint64(m.rand(35) + 1)
if m.r.bool() {
v = 0 - v
}
pos := m.rand(len(b) - 7)
enc := m.randByteOrder()
enc.PutUint64(b[pos:], enc.Uint64(b[pos:])+v)
return b
}
// byteSliceOverwriteInterestingUint8 overwrites a random byte in b with an interesting
// value.
func byteSliceOverwriteInterestingUint8(m *mutator, b []byte) []byte {
if len(b) == 0 {
return nil
}
pos := m.rand(len(b))
b[pos] = byte(interesting8[m.rand(len(interesting8))])
return b
}
// byteSliceOverwriteInterestingUint16 overwrites a random uint16 in b with an interesting
// value.
func byteSliceOverwriteInterestingUint16(m *mutator, b []byte) []byte {
if len(b) < 2 {
return nil
}
pos := m.rand(len(b) - 1)
v := uint16(interesting16[m.rand(len(interesting16))])
m.randByteOrder().PutUint16(b[pos:], v)
return b
}
// byteSliceOverwriteInterestingUint32 overwrites a random uint32 in b with an interesting
// value.
func byteSliceOverwriteInterestingUint32(m *mutator, b []byte) []byte {
if len(b) < 4 {
return nil
}
pos := m.rand(len(b) - 3)
v := uint32(interesting32[m.rand(len(interesting32))])
m.randByteOrder().PutUint32(b[pos:], v)
return b
}
// byteSliceInsertConstantBytes inserts a chunk of constant bytes into a random position in b.
func byteSliceInsertConstantBytes(m *mutator, b []byte) []byte {
if len(b) <= 1 {
return nil
}
dst := m.rand(len(b))
// TODO(rolandshoemaker,katiehockman): 4096 was mainly picked
// randomly. We may want to either pick a much larger value
// (AFL uses 32768, paired with a similar impl to chooseLen
// which biases towards smaller lengths that grow over time),
// or set the max based on characteristics of the corpus
// (libFuzzer sets a min/max based on the min/max size of
// entries in the corpus and then picks uniformly from
// that range).
n := m.chooseLen(4096)
if len(b)+n >= cap(b) {
return nil
}
b = b[:len(b)+n]
copy(b[dst+n:], b[dst:])
rb := byte(m.rand(256))
for i := dst; i < dst+n; i++ {
b[i] = rb
}
return b
}
// byteSliceOverwriteConstantBytes overwrites a chunk of b with constant bytes.
func byteSliceOverwriteConstantBytes(m *mutator, b []byte) []byte {
if len(b) <= 1 {
return nil
}
dst := m.rand(len(b))
n := m.chooseLen(len(b) - dst)
rb := byte(m.rand(256))
for i := dst; i < dst+n; i++ {
b[i] = rb
}
return b
}
// byteSliceShuffleBytes shuffles a chunk of bytes in b.
func byteSliceShuffleBytes(m *mutator, b []byte) []byte {
if len(b) <= 1 {
return nil
}
dst := m.rand(len(b))
n := m.chooseLen(len(b) - dst)
if n <= 2 {
return nil
}
// Start at the end of the range, and iterate backwards
// to dst, swapping each element with another element in
// dst:dst+n (Fisher-Yates shuffle).
for i := n - 1; i > 0; i-- {
j := m.rand(i + 1)
b[dst+i], b[dst+j] = b[dst+j], b[dst+i]
}
return b
}
// byteSliceSwapBytes swaps two chunks of bytes in b.
func byteSliceSwapBytes(m *mutator, b []byte) []byte {
if len(b) <= 1 {
return nil
}
src := m.rand(len(b))
dst := m.rand(len(b))
for dst == src {
dst = m.rand(len(b))
}
// Choose the random length as len(b) - max(src, dst) - 1
// so that we don't attempt to swap a chunk that extends
// beyond the end of the slice
max := dst
if src > max {
max = src
}
n := m.chooseLen(len(b) - max - 1)
// Check that the two chunks don't intersect, so that we don't end up
// duplicating parts of the input, rather than swapping them
if src > dst && dst+n >= src || dst > src && src+n >= dst {
return nil
}
// Use the end of the slice as scratch space to avoid doing an
// allocation. If the slice is too small, abort and try something
// else.
if len(b)+n >= cap(b) {
return nil
}
end := len(b)
b = b[:end+n]
copy(b[end:], b[dst:dst+n])
copy(b[dst:], b[src:src+n])
copy(b[src:], b[end:])
b = b[:end]
return b
}
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fuzz
import (
"math/bits"
"os"
"strconv"
"strings"
"sync/atomic"
"time"
)
type mutatorRand interface {
uint32() uint32
intn(int) int
uint32n(uint32) uint32
exp2() int
bool() bool
save(randState, randInc *uint64)
restore(randState, randInc uint64)
}
// The functions in pcg implement a 32 bit PRNG with a 64 bit period: PCG XSH RR
// 64/32. See https://www.pcg-random.org/ for more information. This
// implementation is geared specifically towards the needs of fuzzing: Simple
// creation and use, no reproducibility, no concurrency safety, just the
// necessary methods, optimized for speed.
var globalInc uint64 // PCG stream
const multiplier uint64 = 6364136223846793005
// pcgRand is a PRNG. It should not be copied or shared. No Rand methods are
// concurrency safe.
type pcgRand struct {
noCopy noCopy // help avoid mistakes: ask vet to ensure that we don't make a copy
state uint64
inc uint64
}
func godebugSeed() *int {
debug := strings.Split(os.Getenv("GODEBUG"), ",")
for _, f := range debug {
if strings.HasPrefix(f, "fuzzseed=") {
seed, err := strconv.Atoi(strings.TrimPrefix(f, "fuzzseed="))
if err != nil {
panic("malformed fuzzseed")
}
return &seed
}
}
return nil
}
// newPcgRand generates a new, seeded Rand, ready for use.
func newPcgRand() *pcgRand {
r := new(pcgRand)
now := uint64(time.Now().UnixNano())
if seed := godebugSeed(); seed != nil {
now = uint64(*seed)
}
inc := atomic.AddUint64(&globalInc, 1)
r.state = now
r.inc = (inc << 1) | 1
r.step()
r.state += now
r.step()
return r
}
func (r *pcgRand) step() {
r.state *= multiplier
r.state += r.inc
}
func (r *pcgRand) save(randState, randInc *uint64) {
*randState = r.state
*randInc = r.inc
}
func (r *pcgRand) restore(randState, randInc uint64) {
r.state = randState
r.inc = randInc
}
// uint32 returns a pseudo-random uint32.
func (r *pcgRand) uint32() uint32 {
x := r.state
r.step()
return bits.RotateLeft32(uint32(((x>>18)^x)>>27), -int(x>>59))
}
// intn returns a pseudo-random number in [0, n).
// n must fit in a uint32.
func (r *pcgRand) intn(n int) int {
if int(uint32(n)) != n {
panic("large Intn")
}
return int(r.uint32n(uint32(n)))
}
// uint32n returns a pseudo-random number in [0, n).
//
// For implementation details, see:
// https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction
// https://lemire.me/blog/2016/06/30/fast-random-shuffling
func (r *pcgRand) uint32n(n uint32) uint32 {
v := r.uint32()
prod := uint64(v) * uint64(n)
low := uint32(prod)
if low < n {
thresh := uint32(-int32(n)) % n
for low < thresh {
v = r.uint32()
prod = uint64(v) * uint64(n)
low = uint32(prod)
}
}
return uint32(prod >> 32)
}
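// For example (an illustrative sketch, not part of the original source),
// the multiply-shift maps a uniform 32-bit v into [0, n) without a modulo:
//
//	prod := uint64(v) * uint64(n) // spread v*n across 64 bits
//	hi := uint32(prod >> 32)      // hi is uniform in [0, n)...
//	lo := uint32(prod)            // ...once draws with lo < 2^32 mod n
//	                              // (the biased region) are rejected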
// exp2 generates n with probability 1/2^(n+1).
func (r *pcgRand) exp2() int {
return bits.TrailingZeros32(r.uint32())
}
// bool generates a random bool.
func (r *pcgRand) bool() bool {
return r.uint32()&1 == 0
}
// noCopy may be embedded into structs which must not be copied
// after the first use.
//
// See https://golang.org/issues/8005#issuecomment-190753527
// for details.
type noCopy struct{}
// lock is a no-op used by -copylocks checker from `go vet`.
func (*noCopy) lock() {}
func (*noCopy) unlock() {}
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fuzz
// queue holds a growable sequence of inputs for fuzzing and minimization.
//
// For now, this is a simple ring buffer
// (https://en.wikipedia.org/wiki/Circular_buffer).
//
// TODO(golang.org/issue/46224): use a prioritization algorithm based on input
// size, previous duration, coverage, and any other metrics that seem useful.
type queue struct {
// elems holds a ring buffer.
// head is the index of the first element, and len is the number of queued
// elements. The queue is empty when len = 0 and full (until grow is called)
// when len = cap(elems).
elems []any
head, len int
}
func (q *queue) cap() int {
return len(q.elems)
}
func (q *queue) grow() {
oldCap := q.cap()
newCap := oldCap * 2
if newCap == 0 {
newCap = 8
}
newElems := make([]any, newCap)
oldLen := q.len
for i := 0; i < oldLen; i++ {
newElems[i] = q.elems[(q.head+i)%oldCap]
}
q.elems = newElems
q.head = 0
}
func (q *queue) enqueue(e any) {
if q.len+1 > q.cap() {
q.grow()
}
i := (q.head + q.len) % q.cap()
q.elems[i] = e
q.len++
}
func (q *queue) dequeue() (any, bool) {
if q.len == 0 {
return nil, false
}
e := q.elems[q.head]
q.elems[q.head] = nil
q.head = (q.head + 1) % q.cap()
q.len--
return e, true
}
func (q *queue) peek() (any, bool) {
if q.len == 0 {
return nil, false
}
return q.elems[q.head], true
}
func (q *queue) clear() {
*q = queue{}
}
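// For example (an illustrative sketch, not part of the original source):
//
//	var q queue
//	q.enqueue("a")
//	q.enqueue("b")
//	v, ok := q.peek()   // "a", true; the element stays queued
//	v, ok = q.dequeue() // "a", true
//	v, ok = q.dequeue() // "b", true
//	v, ok = q.dequeue() // nil, false: the queue is empty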
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build darwin || freebsd || linux
package fuzz
import (
"fmt"
"os"
"os/exec"
"syscall"
)
type sharedMemSys struct{}
func sharedMemMapFile(f *os.File, size int, removeOnClose bool) (*sharedMem, error) {
prot := syscall.PROT_READ | syscall.PROT_WRITE
flags := syscall.MAP_FILE | syscall.MAP_SHARED
region, err := syscall.Mmap(int(f.Fd()), 0, size, prot, flags)
if err != nil {
return nil, err
}
return &sharedMem{f: f, region: region, removeOnClose: removeOnClose}, nil
}
// Close unmaps the shared memory and closes the temporary file. If this
// sharedMem was created with sharedMemTempFile, Close also removes the file.
func (m *sharedMem) Close() error {
// Attempt all operations, even if we get an error for an earlier operation.
// os.File.Close may fail due to I/O errors, but we still want to delete
// the temporary file.
var errs []error
errs = append(errs,
syscall.Munmap(m.region),
m.f.Close())
if m.removeOnClose {
errs = append(errs, os.Remove(m.f.Name()))
}
for _, err := range errs {
if err != nil {
return err
}
}
return nil
}
// setWorkerComm configures communication channels on the cmd that will
// run a worker process.
func setWorkerComm(cmd *exec.Cmd, comm workerComm) {
mem := <-comm.memMu
memFile := mem.f
comm.memMu <- mem
cmd.ExtraFiles = []*os.File{comm.fuzzIn, comm.fuzzOut, memFile}
}
// getWorkerComm returns communication channels in the worker process.
func getWorkerComm() (comm workerComm, err error) {
fuzzIn := os.NewFile(3, "fuzz_in")
fuzzOut := os.NewFile(4, "fuzz_out")
memFile := os.NewFile(5, "fuzz_mem")
fi, err := memFile.Stat()
if err != nil {
return workerComm{}, err
}
size := int(fi.Size())
if int64(size) != fi.Size() {
return workerComm{}, fmt.Errorf("fuzz temp file exceeds maximum size")
}
removeOnClose := false
mem, err := sharedMemMapFile(memFile, size, removeOnClose)
if err != nil {
return workerComm{}, err
}
memMu := make(chan *sharedMem, 1)
memMu <- mem
return workerComm{fuzzIn: fuzzIn, fuzzOut: fuzzOut, memMu: memMu}, nil
}
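// The descriptor numbers 3, 4, and 5 are not arbitrary: os/exec assigns
// cmd.ExtraFiles[i] to descriptor 3+i in the child process, so the files
// passed by setWorkerComm arrive here in the same order. For example (an
// illustrative sketch, not part of the original source):
//
//	cmd.ExtraFiles = []*os.File{fuzzIn, fuzzOut, memFile} // child fds 3, 4, 5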
// isInterruptError returns whether an error was returned by a process that
// was terminated by an interrupt signal (SIGINT).
func isInterruptError(err error) bool {
exitErr, ok := err.(*exec.ExitError)
if !ok || exitErr.ExitCode() >= 0 {
return false
}
status := exitErr.Sys().(syscall.WaitStatus)
return status.Signal() == syscall.SIGINT
}
// terminationSignal checks if err is an exec.ExitError with a signal status.
// If it is, terminationSignal returns the signal and true.
// If not, it returns -1 and false.
func terminationSignal(err error) (os.Signal, bool) {
exitErr, ok := err.(*exec.ExitError)
if !ok || exitErr.ExitCode() >= 0 {
return syscall.Signal(-1), false
}
status := exitErr.Sys().(syscall.WaitStatus)
return status.Signal(), status.Signaled()
}
// isCrashSignal returns whether a signal was likely to have been caused by an
// error in the program that received it, triggered by a fuzz input. For
// example, SIGSEGV would be received after a nil pointer dereference.
// Other signals like SIGKILL or SIGHUP are more likely to have been sent by
// another process, and we shouldn't record a crasher if the worker process
// receives one of these.
//
// Note that Go installs its own signal handlers on startup, so some of these
// signals may only be received if signal handlers are changed. For example,
// SIGSEGV is normally transformed into a panic that causes the process to exit
// with status 2 if not recovered, which we handle as a crash.
func isCrashSignal(signal os.Signal) bool {
switch signal {
case
syscall.SIGILL, // illegal instruction
syscall.SIGTRAP, // breakpoint
syscall.SIGABRT, // abort() called
syscall.SIGBUS, // invalid memory access (e.g., misaligned address)
syscall.SIGFPE, // math error, e.g., integer divide by zero
syscall.SIGSEGV, // invalid memory access (e.g., write to read-only)
syscall.SIGPIPE: // sent data to closed pipe or socket
return true
default:
return false
}
}
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !libfuzzer
package fuzz
import _ "unsafe" // for go:linkname
//go:linkname libfuzzerTraceCmp1 runtime.libfuzzerTraceCmp1
//go:linkname libfuzzerTraceCmp2 runtime.libfuzzerTraceCmp2
//go:linkname libfuzzerTraceCmp4 runtime.libfuzzerTraceCmp4
//go:linkname libfuzzerTraceCmp8 runtime.libfuzzerTraceCmp8
//go:linkname libfuzzerTraceConstCmp1 runtime.libfuzzerTraceConstCmp1
//go:linkname libfuzzerTraceConstCmp2 runtime.libfuzzerTraceConstCmp2
//go:linkname libfuzzerTraceConstCmp4 runtime.libfuzzerTraceConstCmp4
//go:linkname libfuzzerTraceConstCmp8 runtime.libfuzzerTraceConstCmp8
//go:linkname libfuzzerHookStrCmp runtime.libfuzzerHookStrCmp
//go:linkname libfuzzerHookEqualFold runtime.libfuzzerHookEqualFold
func libfuzzerTraceCmp1(arg0, arg1 uint8, fakePC uint) {}
func libfuzzerTraceCmp2(arg0, arg1 uint16, fakePC uint) {}
func libfuzzerTraceCmp4(arg0, arg1 uint32, fakePC uint) {}
func libfuzzerTraceCmp8(arg0, arg1 uint64, fakePC uint) {}
func libfuzzerTraceConstCmp1(arg0, arg1 uint8, fakePC uint) {}
func libfuzzerTraceConstCmp2(arg0, arg1 uint16, fakePC uint) {}
func libfuzzerTraceConstCmp4(arg0, arg1 uint32, fakePC uint) {}
func libfuzzerTraceConstCmp8(arg0, arg1 uint64, fakePC uint) {}
func libfuzzerHookStrCmp(arg0, arg1 string, fakePC uint) {}
func libfuzzerHookEqualFold(arg0, arg1 string, fakePC uint) {}
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fuzz
import (
"bytes"
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"os/exec"
"reflect"
"runtime"
"sync"
"time"
)
const (
// workerFuzzDuration is the amount of time a worker can spend testing random
// variations of an input given by the coordinator.
workerFuzzDuration = 100 * time.Millisecond
// workerTimeoutDuration is the amount of time a worker can go without
// responding to the coordinator before being stopped.
workerTimeoutDuration = 1 * time.Second
// workerExitCode is used as an exit code by fuzz worker processes after an internal error.
// This distinguishes internal errors from uncontrolled panics and other crashes.
// Keep in sync with internal/fuzz.workerExitCode.
workerExitCode = 70
// workerSharedMemSize is the maximum size of the shared memory file used to
// communicate with workers. This limits the size of fuzz inputs.
workerSharedMemSize = 100 << 20 // 100 MB
)
// worker manages a worker process running a test binary. The worker object
// exists only in the coordinator (the process started by 'go test -fuzz').
// workerClient is used by the coordinator to send RPCs to the worker process,
// which handles them with workerServer.
type worker struct {
dir string // working directory, same as package directory
binPath string // path to test executable
args []string // arguments for test executable
env []string // environment for test executable
coordinator *coordinator
memMu chan *sharedMem // mutex guarding shared memory with worker; persists across processes.
cmd *exec.Cmd // current worker process
client *workerClient // used to communicate with worker process
waitErr error // last error returned by wait, set before termC is closed.
interrupted bool // true after stop interrupts a running worker.
termC chan struct{} // closed by wait when worker process terminates
}
func newWorker(c *coordinator, dir, binPath string, args, env []string) (*worker, error) {
mem, err := sharedMemTempFile(workerSharedMemSize)
if err != nil {
return nil, err
}
memMu := make(chan *sharedMem, 1)
memMu <- mem
return &worker{
dir: dir,
binPath: binPath,
args: args,
env: env[:len(env):len(env)], // copy on append to ensure workers don't overwrite each other.
coordinator: c,
memMu: memMu,
}, nil
}
// cleanup releases persistent resources associated with the worker.
func (w *worker) cleanup() error {
mem := <-w.memMu
if mem == nil {
return nil
}
close(w.memMu)
return mem.Close()
}
// coordinate runs the test binary to perform fuzzing.
//
// coordinate loops until ctx is cancelled or a fatal error is encountered.
// If a test process terminates unexpectedly while fuzzing, coordinate will
// attempt to restart and continue unless the termination can be attributed
// to an interruption (from a timer or the user).
//
// While looping, coordinate receives inputs from the coordinator, passes
// those inputs to the worker process, then passes the results back to
// the coordinator.
func (w *worker) coordinate(ctx context.Context) error {
// Main event loop.
for {
// Start or restart the worker if it's not running.
if !w.isRunning() {
if err := w.startAndPing(ctx); err != nil {
return err
}
}
select {
case <-ctx.Done():
// Worker was told to stop.
err := w.stop()
if err != nil && !w.interrupted && !isInterruptError(err) {
return err
}
return ctx.Err()
case <-w.termC:
// Worker process terminated unexpectedly while waiting for input.
err := w.stop()
if w.interrupted {
panic("worker interrupted after unexpected termination")
}
if err == nil || isInterruptError(err) {
// Worker stopped, either by exiting with status 0 or after being
// interrupted with a signal that was not sent by the coordinator.
//
// When the user presses ^C, on POSIX platforms, SIGINT is delivered to
// all processes in the group concurrently, and the worker may see it
// before the coordinator. The worker should exit 0 gracefully (in
// theory).
//
// This condition is probably intended by the user, so suppress
// the error.
return nil
}
if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == workerExitCode {
// Worker exited with a code indicating F.Fuzz was not called correctly,
// for example, F.Fail was called first.
return fmt.Errorf("fuzzing process exited unexpectedly due to an internal failure: %w", err)
}
// Worker exited non-zero or was terminated by a non-interrupt
// signal (for example, SIGSEGV) while fuzzing.
return fmt.Errorf("fuzzing process hung or terminated unexpectedly: %w", err)
// TODO(jayconrod,katiehockman): if -keepfuzzing, restart worker.
case input := <-w.coordinator.inputC:
// Received input from coordinator.
args := fuzzArgs{
Limit: input.limit,
Timeout: input.timeout,
Warmup: input.warmup,
CoverageData: input.coverageData,
}
entry, resp, isInternalError, err := w.client.fuzz(ctx, input.entry, args)
canMinimize := true
if err != nil {
// Error communicating with worker.
w.stop()
if ctx.Err() != nil {
// Timeout or interruption.
return ctx.Err()
}
if w.interrupted {
// Communication error before we stopped the worker.
// Report an error, but don't record a crasher.
return fmt.Errorf("communicating with fuzzing process: %v", err)
}
if sig, ok := terminationSignal(w.waitErr); ok && !isCrashSignal(sig) {
// Worker terminated by a signal that probably wasn't caused by a
// specific input to the fuzz function. For example, on Linux,
// the kernel (OOM killer) may send SIGKILL to a process using a lot
// of memory. Or the shell might send SIGHUP when the terminal
// is closed. Don't record a crasher.
return fmt.Errorf("fuzzing process terminated by unexpected signal; no crash will be recorded: %v", w.waitErr)
}
if isInternalError {
// An internal error occurred which shouldn't be considered
// a crash.
return err
}
// Unexpected termination. Set error message and fall through.
// We'll restart the worker on the next iteration.
// Don't attempt to minimize this since it crashed the worker.
resp.Err = fmt.Sprintf("fuzzing process hung or terminated unexpectedly: %v", w.waitErr)
canMinimize = false
}
result := fuzzResult{
limit: input.limit,
count: resp.Count,
totalDuration: resp.TotalDuration,
entryDuration: resp.InterestingDuration,
entry: entry,
crasherMsg: resp.Err,
coverageData: resp.CoverageData,
canMinimize: canMinimize,
}
w.coordinator.resultC <- result
case input := <-w.coordinator.minimizeC:
// Received input to minimize from coordinator.
result, err := w.minimize(ctx, input)
if err != nil {
// Error minimizing. Send back the original input. If it didn't cause
// an error before, report it as causing an error now.
// TODO: double-check this is handled correctly when
// implementing -keepfuzzing.
result = fuzzResult{
entry: input.entry,
crasherMsg: input.crasherMsg,
canMinimize: false,
limit: input.limit,
}
if result.crasherMsg == "" {
result.crasherMsg = err.Error()
}
}
w.coordinator.resultC <- result
}
}
}
// minimize tells a worker process to attempt to find a smaller value that
// either causes an error (if we started minimizing because we found an input
// that causes an error) or preserves new coverage (if we started minimizing
// because we found an input that expands coverage).
func (w *worker) minimize(ctx context.Context, input fuzzMinimizeInput) (min fuzzResult, err error) {
if w.coordinator.opts.MinimizeTimeout != 0 {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, w.coordinator.opts.MinimizeTimeout)
defer cancel()
}
args := minimizeArgs{
Limit: input.limit,
Timeout: input.timeout,
KeepCoverage: input.keepCoverage,
}
entry, resp, err := w.client.minimize(ctx, input.entry, args)
if err != nil {
// Error communicating with worker.
w.stop()
if ctx.Err() != nil || w.interrupted || isInterruptError(w.waitErr) {
// Worker was interrupted, possibly by the user pressing ^C.
// Normally, workers can handle interrupts and timeouts gracefully and
// will return without error. An error here indicates the worker
// may not have been in a good state, but the error won't be meaningful
// to the user. Just return the original crasher without logging anything.
return fuzzResult{
entry: input.entry,
crasherMsg: input.crasherMsg,
coverageData: input.keepCoverage,
canMinimize: false,
limit: input.limit,
}, nil
}
return fuzzResult{
entry: entry,
crasherMsg: fmt.Sprintf("fuzzing process hung or terminated unexpectedly while minimizing: %v", err),
canMinimize: false,
limit: input.limit,
count: resp.Count,
totalDuration: resp.Duration,
}, nil
}
if input.crasherMsg != "" && resp.Err == "" {
return fuzzResult{}, fmt.Errorf("attempted to minimize a crash but could not reproduce")
}
return fuzzResult{
entry: entry,
crasherMsg: resp.Err,
coverageData: resp.CoverageData,
canMinimize: false,
limit: input.limit,
count: resp.Count,
totalDuration: resp.Duration,
}, nil
}
func (w *worker) isRunning() bool {
return w.cmd != nil
}
// startAndPing starts the worker process and sends it a message to make sure it
// can communicate.
//
// startAndPing returns an error if any part of this didn't work, including if
// the context is expired or the worker process was interrupted before it
// responded. Errors that happen after start but before the ping response
// likely indicate that the worker did not call F.Fuzz or called F.Fail first.
// We don't record crashers for these errors.
func (w *worker) startAndPing(ctx context.Context) error {
if ctx.Err() != nil {
return ctx.Err()
}
if err := w.start(); err != nil {
return err
}
if err := w.client.ping(ctx); err != nil {
w.stop()
if ctx.Err() != nil {
return ctx.Err()
}
if isInterruptError(err) {
// User may have pressed ^C before worker responded.
return err
}
// TODO: record and return stderr.
return fmt.Errorf("fuzzing process terminated without fuzzing: %w", err)
}
return nil
}
// start runs a new worker process.
//
// If the process couldn't be started, start returns an error. Start won't
// return later termination errors from the process if they occur.
//
// If the process starts successfully, start returns nil. stop must be called
// once later to clean up, even if the process terminates on its own.
//
// When the process terminates, w.waitErr is set to the error (if any), and
// w.termC is closed.
func (w *worker) start() (err error) {
if w.isRunning() {
panic("worker already started")
}
w.waitErr = nil
w.interrupted = false
w.termC = nil
cmd := exec.Command(w.binPath, w.args...)
cmd.Dir = w.dir
cmd.Env = w.env[:len(w.env):len(w.env)] // copy on append to ensure workers don't overwrite each other.
// Create the "fuzz_in" and "fuzz_out" pipes so we can communicate with
// the worker. We don't use stdin and stdout, since the test binary may
// do something else with those.
//
// Each pipe has a reader and a writer. The coordinator writes to fuzzInW
// and reads from fuzzOutR. The worker inherits fuzzInR and fuzzOutW.
// The coordinator closes fuzzInR and fuzzOutW after starting the worker,
// since we have no further need of them.
fuzzInR, fuzzInW, err := os.Pipe()
if err != nil {
return err
}
defer fuzzInR.Close()
fuzzOutR, fuzzOutW, err := os.Pipe()
if err != nil {
fuzzInW.Close()
return err
}
defer fuzzOutW.Close()
setWorkerComm(cmd, workerComm{fuzzIn: fuzzInR, fuzzOut: fuzzOutW, memMu: w.memMu})
// Start the worker process.
if err := cmd.Start(); err != nil {
fuzzInW.Close()
fuzzOutR.Close()
return err
}
// Worker started successfully.
// After this, w.client owns fuzzInW and fuzzOutR, so w.client.Close must be
// called later by stop.
w.cmd = cmd
w.termC = make(chan struct{})
comm := workerComm{fuzzIn: fuzzInW, fuzzOut: fuzzOutR, memMu: w.memMu}
m := newMutator()
w.client = newWorkerClient(comm, m)
go func() {
w.waitErr = w.cmd.Wait()
close(w.termC)
}()
return nil
}
// stop tells the worker process to exit by closing w.client, then blocks until
// it terminates. If the worker doesn't terminate after a short time, stop
// signals it with os.Interrupt (where supported), then os.Kill.
//
// stop returns the error the process terminated with, if any (same as
// w.waitErr).
//
// stop must be called at least once after start returns successfully, even if
// the worker process terminates unexpectedly.
func (w *worker) stop() error {
if w.termC == nil {
panic("worker was not started successfully")
}
select {
case <-w.termC:
// Worker already terminated.
if w.client == nil {
// stop already called.
return w.waitErr
}
// Possible unexpected termination.
w.client.Close()
w.cmd = nil
w.client = nil
return w.waitErr
default:
// Worker still running.
}
// Tell the worker to stop by closing fuzz_in. It won't actually stop until it
// finishes with earlier calls.
closeC := make(chan struct{})
go func() {
w.client.Close()
close(closeC)
}()
sig := os.Interrupt
if runtime.GOOS == "windows" {
// Per https://golang.org/pkg/os/#Signal, “Interrupt is not implemented on
// Windows; using it with os.Process.Signal will return an error.”
// Fall back to Kill instead.
sig = os.Kill
}
t := time.NewTimer(workerTimeoutDuration)
for {
select {
case <-w.termC:
// Worker terminated.
t.Stop()
<-closeC
w.cmd = nil
w.client = nil
return w.waitErr
case <-t.C:
// Timer fired before worker terminated.
w.interrupted = true
switch sig {
case os.Interrupt:
// Try to stop the worker with SIGINT and wait a little longer.
w.cmd.Process.Signal(sig)
sig = os.Kill
t.Reset(workerTimeoutDuration)
case os.Kill:
// Try to stop the worker with SIGKILL and keep waiting.
w.cmd.Process.Signal(sig)
sig = nil
t.Reset(workerTimeoutDuration)
case nil:
// Still waiting. Print a message to let the user know why.
fmt.Fprintf(w.coordinator.opts.Log, "waiting for fuzzing process to terminate...\n")
}
}
}
}
// RunFuzzWorker is called in a worker process to communicate with the
// coordinator process in order to fuzz random inputs. RunFuzzWorker loops
// until the coordinator tells it to stop.
//
// fn is a wrapper on the fuzz function. It may return an error to indicate
// a given input "crashed". The coordinator will also record a crasher if
// the function times out or terminates the process.
//
// RunFuzzWorker returns an error if it could not communicate with the
// coordinator process.
func RunFuzzWorker(ctx context.Context, fn func(CorpusEntry) error) error {
comm, err := getWorkerComm()
if err != nil {
return err
}
srv := &workerServer{
workerComm: comm,
fuzzFn: func(e CorpusEntry) (time.Duration, error) {
timer := time.AfterFunc(10*time.Second, func() {
panic("deadlocked!") // this error message won't be printed
})
defer timer.Stop()
start := time.Now()
err := fn(e)
return time.Since(start), err
},
m: newMutator(),
}
return srv.serve(ctx)
}
// call is serialized and sent from the coordinator on fuzz_in. It acts as
// a minimalist RPC mechanism. Exactly one of its fields must be set to indicate
// which method to call.
type call struct {
Ping *pingArgs
Fuzz *fuzzArgs
Minimize *minimizeArgs
}
// minimizeArgs contains arguments to workerServer.minimize. The value to
// minimize is already in shared memory.
type minimizeArgs struct {
// Timeout is the time to spend minimizing. This may include time to start up,
// especially if the input causes the worker process to terminate, requiring
// repeated restarts.
Timeout time.Duration
// Limit is the maximum number of values to test, without spending more time
// than Timeout. 0 indicates no limit.
Limit int64
// KeepCoverage is a set of coverage counters the worker should attempt to
// keep in minimized values. When provided, the worker will reject inputs that
// don't cause at least one of these bits to be set.
KeepCoverage []byte
// Index is the index of the fuzz target parameter to be minimized.
Index int
}
// minimizeResponse contains results from workerServer.minimize.
type minimizeResponse struct {
// WroteToMem is true if the worker found a smaller input and wrote it to
// shared memory. If minimizeArgs.KeepCoverage was set, the minimized input
// preserved at least one coverage bit and did not cause an error.
// Otherwise, the minimized input caused some error, recorded in Err.
WroteToMem bool
// Err is the error string caused by the value in shared memory, if any.
Err string
// CoverageData is the set of coverage bits activated by the minimized value
// in shared memory. When set, it contains at least one bit from KeepCoverage.
// CoverageData will be nil if Err is set or if minimization failed.
CoverageData []byte
// Duration is the time spent minimizing, not including starting or cleaning up.
Duration time.Duration
// Count is the number of values tested.
Count int64
}
// fuzzArgs contains arguments to workerServer.fuzz. The value to fuzz is
// passed in shared memory.
type fuzzArgs struct {
// Timeout is the time to spend fuzzing, not including starting or
// cleaning up.
Timeout time.Duration
// Limit is the maximum number of values to test, without spending more time
// than Timeout. 0 indicates no limit.
Limit int64
// Warmup indicates whether this is part of a warmup run, meaning that
// fuzzing should not occur. If coverageEnabled is true, then coverage data
// should be reported.
Warmup bool
// CoverageData is the coverage data. If set, the worker should update its
// local coverage data prior to fuzzing.
CoverageData []byte
}
// fuzzResponse contains results from workerServer.fuzz.
type fuzzResponse struct {
// TotalDuration is the time spent fuzzing, not including starting or cleaning
// up. InterestingDuration is the time spent running the input that was found
// to be interesting (one that caused an error or expanded coverage).
TotalDuration time.Duration
InterestingDuration time.Duration
// Count is the number of values tested.
Count int64
// CoverageData is set if the value in shared memory expands coverage
// and therefore may be interesting to the coordinator.
CoverageData []byte
// Err is the error string caused by the value in shared memory, which is
// non-empty if the value in shared memory caused a crash.
Err string
// InternalErr is the error string caused by an internal error in the
// worker. This shouldn't be considered a crasher.
InternalErr string
}
// pingArgs contains arguments to workerServer.ping.
type pingArgs struct{}
// pingResponse contains results from workerServer.ping.
type pingResponse struct{}
// workerComm holds pipes and shared memory used for communication
// between the coordinator process (client) and a worker process (server).
// These values are unique to each worker; they are shared only with the
// coordinator, not with other workers.
//
// Access to shared memory is synchronized implicitly over the RPC protocol
// implemented in workerServer and workerClient. During a call, the server
// (the worker process) has exclusive access to shared memory; at other times,
// the client (the coordinator) has exclusive access.
type workerComm struct {
fuzzIn, fuzzOut *os.File
memMu chan *sharedMem // mutex guarding shared memory
}
// workerServer is a minimalist RPC server, run by fuzz worker processes.
// It allows the coordinator process (using workerClient) to call methods in a
// worker process. This system allows the coordinator to run multiple worker
// processes in parallel and to collect inputs that caused crashes from shared
// memory after a worker process terminates unexpectedly.
type workerServer struct {
workerComm
m *mutator
// coverageMask is the local coverage data for the worker. It is
// periodically updated to reflect the data in the coordinator when new
// coverage is found.
coverageMask []byte
// fuzzFn runs the worker's fuzz target on the given input and returns an
// error if it finds a crasher (the process may also exit or crash), and the
// time it took to run the input. It sets a deadline of 10 seconds, at which
// point it will panic with the assumption that the process is hanging or
// deadlocked.
fuzzFn func(CorpusEntry) (time.Duration, error)
}
// serve reads serialized RPC messages on fuzzIn. When serve receives a message,
// it calls the corresponding method, then sends the serialized result back
// on fuzzOut.
//
// serve handles RPC calls synchronously; it will not attempt to read a message
// until the previous call has finished.
//
// serve returns errors that occurred when communicating over pipes. serve
// does not return errors from method calls; those are passed through serialized
// responses.
func (ws *workerServer) serve(ctx context.Context) error {
enc := json.NewEncoder(ws.fuzzOut)
dec := json.NewDecoder(&contextReader{ctx: ctx, r: ws.fuzzIn})
for {
var c call
if err := dec.Decode(&c); err != nil {
if err == io.EOF || err == ctx.Err() {
return nil
} else {
return err
}
}
var resp any
switch {
case c.Fuzz != nil:
resp = ws.fuzz(ctx, *c.Fuzz)
case c.Minimize != nil:
resp = ws.minimize(ctx, *c.Minimize)
case c.Ping != nil:
resp = ws.ping(ctx, *c.Ping)
default:
return errors.New("no arguments provided for any call")
}
if err := enc.Encode(resp); err != nil {
return err
}
}
}
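// The call type is defined elsewhere in this package; the dispatch above
// treats it as a union with at most one non-nil field per message.
// Illustratively, a request that decodes with only Ping set, such as
//
//	{"Ping":{}}
//
// selects ws.ping, and the resulting pingResponse is encoded back on
// fuzzOut as {}.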
// chainedMutations is how many mutations are applied before the worker
// resets the input to its original state.
// NOTE: this number was picked without much thought. It is low enough that
// it seems to create a significant diversity in mutated inputs. We may want
// to consider looking into this more closely once we have a proper performance
// testing framework. Another option is to randomly pick the number of chained
// mutations on each invocation of the workerServer.fuzz method (this appears to
// be what libFuzzer does, although there seems to be no documentation which
// explains why this choice was made.)
const chainedMutations = 5
// fuzz runs the test function on random variations of the input value in shared
// memory for a limited duration or number of iterations.
//
// fuzz returns early if it finds an input that crashes the fuzz function (with
// fuzzResponse.Err set) or an input that expands coverage (with
// fuzzResponse.CoverageData and InterestingDuration set).
//
// fuzz does not modify the input in shared memory. Instead, it saves the
// initial PRNG state in shared memory and increments a counter in shared
// memory before each call to the test function. The caller may reconstruct
// the crashing input with this information, since the PRNG is deterministic.
func (ws *workerServer) fuzz(ctx context.Context, args fuzzArgs) (resp fuzzResponse) {
if args.CoverageData != nil {
if ws.coverageMask != nil && len(args.CoverageData) != len(ws.coverageMask) {
resp.InternalErr = fmt.Sprintf("unexpected size for CoverageData: got %d, expected %d", len(args.CoverageData), len(ws.coverageMask))
return resp
}
ws.coverageMask = args.CoverageData
}
start := time.Now()
defer func() { resp.TotalDuration = time.Since(start) }()
if args.Timeout != 0 {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, args.Timeout)
defer cancel()
}
mem := <-ws.memMu
ws.m.r.save(&mem.header().randState, &mem.header().randInc)
defer func() {
resp.Count = mem.header().count
ws.memMu <- mem
}()
if args.Limit > 0 && mem.header().count >= args.Limit {
resp.InternalErr = fmt.Sprintf("mem.header().count %d already exceeds args.Limit %d", mem.header().count, args.Limit)
return resp
}
originalVals, err := unmarshalCorpusFile(mem.valueCopy())
if err != nil {
resp.InternalErr = err.Error()
return resp
}
vals := make([]any, len(originalVals))
copy(vals, originalVals)
shouldStop := func() bool {
return args.Limit > 0 && mem.header().count >= args.Limit
}
fuzzOnce := func(entry CorpusEntry) (dur time.Duration, cov []byte, errMsg string) {
mem.header().count++
var err error
dur, err = ws.fuzzFn(entry)
if err != nil {
errMsg = err.Error()
if errMsg == "" {
errMsg = "fuzz function failed with no input"
}
return dur, nil, errMsg
}
if ws.coverageMask != nil && countNewCoverageBits(ws.coverageMask, coverageSnapshot) > 0 {
return dur, coverageSnapshot, ""
}
return dur, nil, ""
}
if args.Warmup {
dur, _, errMsg := fuzzOnce(CorpusEntry{Values: vals})
if errMsg != "" {
resp.Err = errMsg
return resp
}
resp.InterestingDuration = dur
if coverageEnabled {
resp.CoverageData = coverageSnapshot
}
return resp
}
for {
select {
case <-ctx.Done():
return resp
default:
if mem.header().count%chainedMutations == 0 {
copy(vals, originalVals)
ws.m.r.save(&mem.header().randState, &mem.header().randInc)
}
ws.m.mutate(vals, cap(mem.valueRef()))
entry := CorpusEntry{Values: vals}
dur, cov, errMsg := fuzzOnce(entry)
if errMsg != "" {
resp.Err = errMsg
return resp
}
if cov != nil {
resp.CoverageData = cov
resp.InterestingDuration = dur
return resp
}
if shouldStop() {
return resp
}
}
}
}
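// Because fuzz never writes mutated values back to shared memory, the
// coordinator replays them from the saved PRNG state and counter. A condensed
// sketch of that reconstruction, mirroring workerClient.fuzz later in this
// file:
//
//	wc.m.r.restore(mem.header().randState, mem.header().randInc)
//	numMutations := ((resp.Count - 1) % chainedMutations) + 1
//	for i := int64(0); i < numMutations; i++ {
//		wc.m.mutate(valuesOut, cap(mem.valueRef()))
//	}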
func (ws *workerServer) minimize(ctx context.Context, args minimizeArgs) (resp minimizeResponse) {
start := time.Now()
defer func() { resp.Duration = time.Since(start) }()
mem := <-ws.memMu
defer func() { ws.memMu <- mem }()
vals, err := unmarshalCorpusFile(mem.valueCopy())
if err != nil {
panic(err)
}
inpHash := sha256.Sum256(mem.valueCopy())
if args.Timeout != 0 {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, args.Timeout)
defer cancel()
}
// Minimize the values in vals, then write to shared memory. We only write
// to shared memory after completing minimization.
success, err := ws.minimizeInput(ctx, vals, mem, args)
if success {
writeToMem(vals, mem)
outHash := sha256.Sum256(mem.valueCopy())
mem.header().rawInMem = false
resp.WroteToMem = true
if err != nil {
resp.Err = err.Error()
} else {
// If the values didn't change during minimization then coverageSnapshot is likely
// a dirty snapshot which represents the very last step of minimization, not the
// coverage for the initial input. In that case just return the coverage we were
// given initially, since it more accurately represents the coverage map for the
// input we are returning.
if outHash != inpHash {
resp.CoverageData = coverageSnapshot
} else {
resp.CoverageData = args.KeepCoverage
}
}
}
return resp
}
// minimizeInput applies a series of minimizing transformations on the provided
// vals, ensuring that each minimization still causes an error, or keeps
// coverage, in fuzzFn. It uses the context to determine how long to run,
// stopping once the context is done. It returns a bool indicating whether
// minimization was successful and an error if one was found.
func (ws *workerServer) minimizeInput(ctx context.Context, vals []any, mem *sharedMem, args minimizeArgs) (success bool, retErr error) {
keepCoverage := args.KeepCoverage
memBytes := mem.valueRef()
bPtr := &memBytes
count := &mem.header().count
shouldStop := func() bool {
return ctx.Err() != nil ||
(args.Limit > 0 && *count >= args.Limit)
}
if shouldStop() {
return false, nil
}
// Check that the original value preserves coverage or causes an error.
// If not, then whatever caused us to think the value was interesting may
// have been a flake, and we can't minimize it.
*count++
_, retErr = ws.fuzzFn(CorpusEntry{Values: vals})
if keepCoverage != nil {
if !hasCoverageBit(keepCoverage, coverageSnapshot) || retErr != nil {
return false, nil
}
} else if retErr == nil {
return false, nil
}
mem.header().rawInMem = true
// tryMinimized runs the fuzz function with candidate replacing the value
// at index args.Index. It reports whether the input with candidate is
// interesting for the same reason as the original input: the fuzz function
// returns an error when one is expected, or coverage is preserved.
tryMinimized := func(candidate []byte) bool {
prev := vals[args.Index]
switch prev.(type) {
case []byte:
vals[args.Index] = candidate
case string:
vals[args.Index] = string(candidate)
default:
panic("impossible")
}
copy(*bPtr, candidate)
*bPtr = (*bPtr)[:len(candidate)]
mem.setValueLen(len(candidate))
*count++
_, err := ws.fuzzFn(CorpusEntry{Values: vals})
if err != nil {
retErr = err
if keepCoverage != nil {
// Now that we've found a crash, that's more important than any
// minimization of interesting inputs that was being done. Clear out
// keepCoverage to only minimize the crash going forward.
keepCoverage = nil
}
return true
}
// Minimization should preserve coverage bits.
if keepCoverage != nil && isCoverageSubset(keepCoverage, coverageSnapshot) {
return true
}
vals[args.Index] = prev
return false
}
switch v := vals[args.Index].(type) {
case string:
minimizeBytes([]byte(v), tryMinimized, shouldStop)
case []byte:
minimizeBytes(v, tryMinimized, shouldStop)
default:
panic("impossible")
}
return true, retErr
}
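// minimizeBytes is implemented elsewhere in this package. A simplified,
// illustrative stand-in under the same callback contract (not its actual
// algorithm): repeatedly propose inputs with a chunk removed, keeping each
// shorter input that tryMinimized still finds interesting.
//
//	func minimizeBytesSketch(v []byte, tryMinimized func([]byte) bool, shouldStop func() bool) {
//		for chunk := len(v); chunk > 0; chunk /= 2 {
//			i := 0
//			for i+chunk <= len(v) {
//				if shouldStop() {
//					return
//				}
//				// Propose v with v[i:i+chunk] removed.
//				candidate := append(append([]byte(nil), v[:i]...), v[i+chunk:]...)
//				if tryMinimized(candidate) {
//					v = candidate // keep the shorter input; retry at the same offset
//				} else {
//					i += chunk
//				}
//			}
//		}
//	}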
func writeToMem(vals []any, mem *sharedMem) {
b := marshalCorpusFile(vals...)
mem.setValue(b)
}
// ping does nothing. The coordinator calls this method to ensure the worker
// has called F.Fuzz and can communicate.
func (ws *workerServer) ping(ctx context.Context, args pingArgs) pingResponse {
return pingResponse{}
}
// workerClient is a minimalist RPC client. The coordinator process uses a
// workerClient to call methods in each worker process (handled by
// workerServer).
type workerClient struct {
workerComm
m *mutator
// mu is the mutex protecting the workerComm.fuzzIn pipe. This must be
// locked before making calls to the workerServer. It prevents
// workerClient.Close from closing fuzzIn while workerClient methods are
// writing to it concurrently, and prevents multiple callers from writing to
// fuzzIn concurrently.
mu sync.Mutex
}
func newWorkerClient(comm workerComm, m *mutator) *workerClient {
return &workerClient{workerComm: comm, m: m}
}
// Close shuts down the connection to the RPC server (the worker process) by
// closing fuzzIn. Close drains fuzzOut (avoiding a SIGPIPE in the worker),
// and closes it after the worker process closes the other end.
func (wc *workerClient) Close() error {
wc.mu.Lock()
defer wc.mu.Unlock()
// Close fuzzIn. This signals to the server that there are no more calls,
// and it should exit.
if err := wc.fuzzIn.Close(); err != nil {
wc.fuzzOut.Close()
return err
}
// Drain fuzzOut and close it. When the server exits, the kernel will close
// its end of fuzzOut, and we'll get EOF.
if _, err := io.Copy(io.Discard, wc.fuzzOut); err != nil {
wc.fuzzOut.Close()
return err
}
return wc.fuzzOut.Close()
}
// errSharedMemClosed is returned by workerClient methods that cannot access
// shared memory because it was closed and unmapped by another goroutine. That
// can happen when worker.cleanup is called in the worker goroutine while a
// workerClient.fuzz call runs concurrently.
//
// This error should not be reported. It indicates the operation was
// interrupted.
var errSharedMemClosed = errors.New("internal error: shared memory was closed and unmapped")
// minimize tells the worker to call the minimize method. See
// workerServer.minimize.
func (wc *workerClient) minimize(ctx context.Context, entryIn CorpusEntry, args minimizeArgs) (entryOut CorpusEntry, resp minimizeResponse, retErr error) {
wc.mu.Lock()
defer wc.mu.Unlock()
mem, ok := <-wc.memMu
if !ok {
return CorpusEntry{}, minimizeResponse{}, errSharedMemClosed
}
mem.header().count = 0
inp, err := corpusEntryData(entryIn)
if err != nil {
return CorpusEntry{}, minimizeResponse{}, err
}
mem.setValue(inp)
defer func() { wc.memMu <- mem }()
entryOut = entryIn
entryOut.Values, err = unmarshalCorpusFile(inp)
if err != nil {
return CorpusEntry{}, minimizeResponse{}, fmt.Errorf("workerClient.minimize unmarshaling provided value: %v", err)
}
for i, v := range entryOut.Values {
if !isMinimizable(reflect.TypeOf(v)) {
continue
}
wc.memMu <- mem
args.Index = i
c := call{Minimize: &args}
callErr := wc.callLocked(ctx, c, &resp)
mem, ok = <-wc.memMu
if !ok {
return CorpusEntry{}, minimizeResponse{}, errSharedMemClosed
}
if callErr != nil {
retErr = callErr
if !mem.header().rawInMem {
// An unrecoverable error occurred before minimization began.
return entryIn, minimizeResponse{}, retErr
}
// An unrecoverable error occurred during minimization. mem now
// holds the raw, unmarshaled bytes of entryIn.Values[i] that
// caused the error.
switch entryOut.Values[i].(type) {
case string:
entryOut.Values[i] = string(mem.valueCopy())
case []byte:
entryOut.Values[i] = mem.valueCopy()
default:
panic("impossible")
}
entryOut.Data = marshalCorpusFile(entryOut.Values...)
// Stop minimizing; another unrecoverable error is likely to occur.
break
}
if resp.WroteToMem {
// Minimization succeeded, and mem holds the marshaled data.
entryOut.Data = mem.valueCopy()
entryOut.Values, err = unmarshalCorpusFile(entryOut.Data)
if err != nil {
return CorpusEntry{}, minimizeResponse{}, fmt.Errorf("workerClient.minimize unmarshaling minimized value: %v", err)
}
}
// Prepare for next iteration of the loop.
if args.Timeout != 0 {
args.Timeout -= resp.Duration
if args.Timeout <= 0 {
break
}
}
if args.Limit != 0 {
args.Limit -= mem.header().count
if args.Limit <= 0 {
break
}
}
}
resp.Count = mem.header().count
h := sha256.Sum256(entryOut.Data)
entryOut.Path = fmt.Sprintf("%x", h[:4])
return entryOut, resp, retErr
}
// fuzz tells the worker to call the fuzz method. See workerServer.fuzz.
func (wc *workerClient) fuzz(ctx context.Context, entryIn CorpusEntry, args fuzzArgs) (entryOut CorpusEntry, resp fuzzResponse, isInternalError bool, err error) {
wc.mu.Lock()
defer wc.mu.Unlock()
mem, ok := <-wc.memMu
if !ok {
return CorpusEntry{}, fuzzResponse{}, true, errSharedMemClosed
}
mem.header().count = 0
inp, err := corpusEntryData(entryIn)
if err != nil {
return CorpusEntry{}, fuzzResponse{}, true, err
}
mem.setValue(inp)
wc.memMu <- mem
c := call{Fuzz: &args}
callErr := wc.callLocked(ctx, c, &resp)
if resp.InternalErr != "" {
return CorpusEntry{}, fuzzResponse{}, true, errors.New(resp.InternalErr)
}
mem, ok = <-wc.memMu
if !ok {
return CorpusEntry{}, fuzzResponse{}, true, errSharedMemClosed
}
defer func() { wc.memMu <- mem }()
resp.Count = mem.header().count
if !bytes.Equal(inp, mem.valueRef()) {
return CorpusEntry{}, fuzzResponse{}, true, errors.New("workerServer.fuzz modified input")
}
needEntryOut := callErr != nil || resp.Err != "" ||
(!args.Warmup && resp.CoverageData != nil)
if needEntryOut {
valuesOut, err := unmarshalCorpusFile(inp)
if err != nil {
return CorpusEntry{}, fuzzResponse{}, true, fmt.Errorf("unmarshaling fuzz input value after call: %v", err)
}
wc.m.r.restore(mem.header().randState, mem.header().randInc)
if !args.Warmup {
// Only mutate the valuesOut if fuzzing actually occurred.
numMutations := ((resp.Count - 1) % chainedMutations) + 1
for i := int64(0); i < numMutations; i++ {
wc.m.mutate(valuesOut, cap(mem.valueRef()))
}
}
dataOut := marshalCorpusFile(valuesOut...)
h := sha256.Sum256(dataOut)
name := fmt.Sprintf("%x", h[:4])
entryOut = CorpusEntry{
Parent: entryIn.Path,
Path: name,
Data: dataOut,
Generation: entryIn.Generation + 1,
}
if args.Warmup {
// The bytes weren't mutated, so if entryIn was a seed corpus value,
// then entryOut is too.
entryOut.IsSeed = entryIn.IsSeed
}
}
return entryOut, resp, false, callErr
}
// ping tells the worker to call the ping method. See workerServer.ping.
func (wc *workerClient) ping(ctx context.Context) error {
wc.mu.Lock()
defer wc.mu.Unlock()
c := call{Ping: &pingArgs{}}
var resp pingResponse
return wc.callLocked(ctx, c, &resp)
}
// callLocked sends an RPC from the coordinator to the worker process and waits
// for the response. callLocked may be cancelled with ctx.
func (wc *workerClient) callLocked(ctx context.Context, c call, resp any) (err error) {
enc := json.NewEncoder(wc.fuzzIn)
dec := json.NewDecoder(&contextReader{ctx: ctx, r: wc.fuzzOut})
if err := enc.Encode(c); err != nil {
return err
}
return dec.Decode(resp)
}
// contextReader wraps a Reader with a Context. If the context is cancelled
// while the underlying reader is blocked, Read returns immediately.
//
// This is useful for reading from a pipe. Closing a pipe file descriptor does
// not unblock pending Reads on that file descriptor. All copies of the pipe's
// other file descriptor (the write end) must be closed in all processes that
// inherit it. This is difficult to do correctly in the situation we care about
// (process group termination).
type contextReader struct {
ctx context.Context
r io.Reader
}
func (cr *contextReader) Read(b []byte) (int, error) {
if ctxErr := cr.ctx.Err(); ctxErr != nil {
return 0, ctxErr
}
done := make(chan struct{})
// This goroutine may stay blocked after Read returns because the underlying
// read is blocked.
var n int
var err error
go func() {
n, err = cr.r.Read(b)
close(done)
}()
select {
case <-cr.ctx.Done():
return 0, cr.ctx.Err()
case <-done:
return n, err
}
}
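// A typical use, matching serve and callLocked above: wrap one end of a pipe
// so a blocked JSON decode unblocks when the context is cancelled.
//
//	dec := json.NewDecoder(&contextReader{ctx: ctx, r: pipe})
//	if err := dec.Decode(&msg); err != nil {
//		// err is ctx.Err() if the context was cancelled mid-read.
//	}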
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package godebug makes the settings in the $GODEBUG environment variable
// available to other packages. These settings are often used for compatibility
// tweaks, when we need to change a default behavior but want to let users
// opt back in to the original. For example GODEBUG=http2server=0 disables
// HTTP/2 support in the net/http server.
//
// In typical usage, code should declare a Setting as a global
// and then call Value each time the current setting value is needed:
//
//	var http2server = godebug.New("http2server")
//
//	func ServeConn(c net.Conn) {
//		if http2server.Value() == "0" {
//			disallow HTTP/2
//			...
//		}
//		...
//	}
//
// Each time a non-default setting causes a change in program behavior,
// code should call [Setting.IncNonDefault] to increment a counter that can
// be reported by [runtime/metrics.Read].
// Note that counters used with IncNonDefault must be added to
// various tables in other packages. See the [Setting.IncNonDefault]
// documentation for details.
package godebug
import (
"sync"
"sync/atomic"
_ "unsafe" // go:linkname
)
// A Setting is a single setting in the $GODEBUG environment variable.
type Setting struct {
name string
once sync.Once
*setting
}
type setting struct {
value atomic.Pointer[string]
nonDefaultOnce sync.Once
nonDefault atomic.Uint64
}
// New returns a new Setting for the $GODEBUG setting with the given name.
func New(name string) *Setting {
return &Setting{name: name}
}
// Name returns the name of the setting.
func (s *Setting) Name() string {
return s.name
}
// String returns a printable form for the setting: name=value.
func (s *Setting) String() string {
return s.name + "=" + s.Value()
}
// IncNonDefault increments the non-default behavior counter
// associated with the given setting.
// This counter is exposed in the runtime/metrics value
// /godebug/non-default-behavior/<name>:events.
//
// Note that Value must be called at least once before IncNonDefault.
//
// Any GODEBUG setting that can call IncNonDefault must be listed
// in three more places:
//
// - the table in ../runtime/metrics.go (search for non-default-behavior)
// - the table in ../../runtime/metrics/description.go (search for non-default-behavior; run 'go generate' afterward)
// - the table in ../../cmd/go/internal/load/godebug.go (search for defaultGodebugs)
func (s *Setting) IncNonDefault() {
s.nonDefaultOnce.Do(s.register)
s.nonDefault.Add(1)
}
func (s *Setting) register() {
registerMetric("/godebug/non-default-behavior/"+s.name+":events", s.nonDefault.Load)
}
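// Usage sketch, combining the package example with the Value-before-
// IncNonDefault rule documented above:
//
//	var http2server = godebug.New("http2server")
//
//	if http2server.Value() == "0" {
//		http2server.IncNonDefault() // count this non-default behavior
//		// ... disable HTTP/2 ...
//	}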
// cache is a cache of all the GODEBUG settings,
// a locked map[string]*atomic.Pointer[string].
//
// All Settings with the same name share a single
// *atomic.Pointer[string], so that when GODEBUG
// changes only that single atomic string pointer
// needs to be updated.
//
// A name appears in the values map either if it is the
// name of a Setting for which Value has been called
// at least once, or if the name has ever appeared in
// a name=value pair in the $GODEBUG environment variable.
// Once entered into the map, the name is never removed.
var cache sync.Map // name string -> value *atomic.Pointer[string]
var empty string
// Value returns the current value for the GODEBUG setting s.
//
// Value maintains an internal cache that is synchronized
// with changes to the $GODEBUG environment variable,
// making Value efficient to call as frequently as needed.
// Clients should therefore typically not attempt their own
// caching of Value's result.
func (s *Setting) Value() string {
s.once.Do(func() {
s.setting = lookup(s.name)
})
return *s.value.Load()
}
// lookup returns the unique *setting value for the given name.
func lookup(name string) *setting {
if v, ok := cache.Load(name); ok {
return v.(*setting)
}
s := new(setting)
s.value.Store(&empty)
if v, loaded := cache.LoadOrStore(name, s); loaded {
// Lost race: someone else created it. Use theirs.
return v.(*setting)
}
return s
}
// setUpdate is provided by package runtime.
// It calls update(def, env), where def is the default GODEBUG setting
// and env is the current value of the $GODEBUG environment variable.
// After that first call, the runtime calls update(def, env)
// again each time the environment variable changes
// (due to use of os.Setenv, for example).
//
//go:linkname setUpdate
func setUpdate(update func(string, string))
// registerMetric is provided by package runtime.
// It forwards registrations to runtime/metrics.
//
//go:linkname registerMetric
func registerMetric(name string, read func() uint64)
// setNewIncNonDefault is provided by package runtime.
// The runtime can do
//
//	inc := newIncNonDefault(name)
//
// instead of
//
//	inc := godebug.New(name).IncNonDefault
//
// since it cannot import godebug.
//
//go:linkname setNewIncNonDefault
func setNewIncNonDefault(newIncNonDefault func(string) func())
func init() {
setUpdate(update)
setNewIncNonDefault(newIncNonDefault)
}
func newIncNonDefault(name string) func() {
s := New(name)
s.Value()
return s.IncNonDefault
}
var updateMu sync.Mutex
// update records an updated GODEBUG setting.
// def is the default GODEBUG setting for the running binary,
// and env is the current value of the $GODEBUG environment variable.
func update(def, env string) {
updateMu.Lock()
defer updateMu.Unlock()
// Update all the cached values, creating new ones as needed.
// We parse the environment variable first, so that any settings it has
// are already locked in place (did[name] = true) before we consider
// the defaults.
did := make(map[string]bool)
parse(did, env)
parse(did, def)
// Clear any cached values that are no longer present.
cache.Range(func(name, s any) bool {
if !did[name.(string)] {
s.(*setting).value.Store(&empty)
}
return true
})
}
// parse parses the GODEBUG setting string s,
// which has the form k=v,k2=v2,k3=v3.
// Later settings override earlier ones.
// Parse only updates settings k=v for which did[k] = false.
// It also sets did[k] = true for settings that it updates.
func parse(did map[string]bool, s string) {
// Scan the string backward so that later settings are used
// and earlier settings are ignored.
// Note that a forward scan would cause cached values
// to temporarily use the ignored value before being
// updated to the "correct" one.
end := len(s)
eq := -1
for i := end - 1; i >= -1; i-- {
if i == -1 || s[i] == ',' {
if eq >= 0 {
name, value := s[i+1:eq], s[eq+1:end]
if !did[name] {
did[name] = true
lookup(name).value.Store(&value)
}
}
eq = -1
end = i
} else if s[i] == '=' {
eq = i
}
}
}
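// A worked example of the precedence rules (illustrative values): update
// parses env before def, so within env the later setting wins, and def
// cannot override anything env already set:
//
//	did := make(map[string]bool)
//	parse(did, "http2server=0,http2server=1") // env: later wins, http2server="1"
//	parse(did, "http2server=2,foo=bar")       // def: http2server already set; foo="bar"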
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package intern lets you make smaller comparable values by boxing
// a larger comparable value (such as a 16 byte string header) down
// into a globally unique 8 byte pointer.
//
// The globally unique pointers are garbage collected with weak
// references and finalizers. This package hides that.
package intern
import (
"internal/godebug"
"runtime"
"sync"
"unsafe"
)
// A Value pointer is the handle to an underlying comparable value.
// See func Get for how Value pointers may be used.
type Value struct {
_ [0]func() // prevent people from accidentally using value type as comparable
cmpVal any
// resurrected is guarded by mu (for all instances of Value).
// It is set true whenever v is synthesized from a uintptr.
resurrected bool
}
// Get returns the comparable value passed to the Get func
// that returned v.
func (v *Value) Get() any { return v.cmpVal }
// key is a key in our global value map.
// It contains type-specialized fields to avoid allocations
// when converting common types to empty interfaces.
type key struct {
s string
cmpVal any
// isString reports whether key contains a string.
// Without it, the zero value of key is ambiguous.
isString bool
}
// keyFor returns a key to use with cmpVal.
func keyFor(cmpVal any) key {
if s, ok := cmpVal.(string); ok {
return key{s: s, isString: true}
}
return key{cmpVal: cmpVal}
}
// Value returns a *Value built from k.
func (k key) Value() *Value {
if k.isString {
return &Value{cmpVal: k.s}
}
return &Value{cmpVal: k.cmpVal}
}
var (
// mu guards valMap, a weakref map of *Value by underlying value.
// It also guards the resurrected field of all *Values.
mu sync.Mutex
valMap = map[key]uintptr{} // to uintptr(*Value)
valSafe = safeMap() // non-nil in safe+leaky mode
)
var intern = godebug.New("intern")
// safeMap returns a non-nil map if we're in safe-but-leaky mode,
// as controlled by GODEBUG=intern=leaky.
func safeMap() map[key]*Value {
if intern.Value() == "leaky" {
return map[key]*Value{}
}
return nil
}
// Get returns a pointer representing the comparable value cmpVal.
//
// The returned pointer will be the same for Get(v) and Get(v2)
// if and only if v == v2, and can be used as a map key.
func Get(cmpVal any) *Value {
return get(keyFor(cmpVal))
}
// GetByString is identical to Get, except that it is specialized for strings.
// This avoids an allocation from putting a string into an interface{}
// to pass as an argument to Get.
func GetByString(s string) *Value {
return get(key{s: s, isString: true})
}
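// Usage sketch: interned handles are canonical, so pointer equality matches
// value equality and a *Value is a cheap map key.
//
//	a := intern.GetByString("hello")
//	b := intern.GetByString("hello")
//	_ = a == b            // true: equal values share one *Value handle
//	s := a.Get().(string) // "hello"
//	_ = s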
// We play unsafe games that violate Go's rules (and assume a non-moving
// collector). So we quiet Go here.
// See the comment below Get for more implementation details.
//
//go:nocheckptr
func get(k key) *Value {
mu.Lock()
defer mu.Unlock()
var v *Value
if valSafe != nil {
v = valSafe[k]
} else if addr, ok := valMap[k]; ok {
v = (*Value)(unsafe.Pointer(addr))
v.resurrected = true
}
if v != nil {
return v
}
v = k.Value()
if valSafe != nil {
valSafe[k] = v
} else {
// SetFinalizer before uintptr conversion (theoretical concern;
// see https://github.com/go4org/intern/issues/13)
runtime.SetFinalizer(v, finalize)
valMap[k] = uintptr(unsafe.Pointer(v))
}
return v
}
func finalize(v *Value) {
mu.Lock()
defer mu.Unlock()
if v.resurrected {
// We lost the race. Somebody resurrected it while we
// were about to finalize it. Try again next round.
v.resurrected = false
runtime.SetFinalizer(v, finalize)
return
}
delete(valMap, keyFor(v.cmpVal))
}
// Interning is simple if you don't require that unused values be
// garbage collectable. But we do require that; we don't want to be
// a DoS vector. We do this by using a uintptr to hide the pointer from
// the garbage collector, and using a finalizer to eliminate the
// pointer when no other code is using it.
//
// The obvious implementation of this is to use a
// map[interface{}]uintptr-of-*interface{}, and set up a finalizer to
// delete from the map. Unfortunately, this is racy. Because pointers
// are being created in violation of Go's unsafety rules, it's
// possible to create a pointer to a value concurrently with the GC
// concluding that the value can be collected. There are other races
// that break the equality invariant as well, but the use-after-free
// will cause a runtime crash.
//
// To make this work, the finalizer needs to know that no references
// have been unsafely created since the finalizer was set up. To do
// this, values carry a "resurrected" sentinel, which gets set
// whenever a pointer is unsafely created. If the finalizer encounters
// the sentinel, it clears the sentinel and delays collection for one
// additional GC cycle, by re-installing itself as finalizer. This
// ensures that the unsafely created pointer is visible to the GC, and
// will correctly prevent collection.
//
// This technique does mean that interned values that get reused take
// at least 3 GC cycles to fully collect (1 to clear the sentinel, 1
// to clean up the unsafe map, 1 to be actually deleted).
//
// @ianlancetaylor commented in
// https://github.com/golang/go/issues/41303#issuecomment-717401656
// that it is possible to implement weak references in terms of
// finalizers without unsafe. Unfortunately, the approach he outlined
// does not work here, for two reasons. First, there is no way to
// construct a strong pointer out of a weak pointer; our map stores
// weak pointers, but we must return strong pointers to callers.
// Second, and more fundamentally, we must return not just _a_ strong
// pointer to callers, but _the same_ strong pointer to callers. In
// order to return _the same_ strong pointer to callers, we must track
// it, which is exactly what we cannot do with strong pointers.
//
// See https://github.com/inetaf/netaddr/issues/53 for more
// discussion, and https://github.com/go4org/intern/issues/2 for an
// illustration of the subtleties at play.
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Simple conversions to avoid depending on strconv.
package itoa
// Itoa converts val to a decimal string.
func Itoa(val int) string {
if val < 0 {
return "-" + Uitoa(uint(-val))
}
return Uitoa(uint(val))
}
// Uitoa converts val to a decimal string.
func Uitoa(val uint) string {
if val == 0 { // avoid string allocation
return "0"
}
var buf [20]byte // big enough for 64-bit value base 10
i := len(buf) - 1
for val >= 10 {
q := val / 10
buf[i] = byte('0' + val - q*10)
i--
val = q
}
// val < 10
buf[i] = byte('0' + val)
return string(buf[i:])
}
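// Usage sketch; the package exists so low-level packages can format integers
// without importing strconv:
//
//	itoa.Itoa(-42)    // "-42"
//	itoa.Uitoa(65535) // "65535"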
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package poll
import (
"internal/syscall/unix"
"sync"
"syscall"
)
var (
kernelVersion53Once sync.Once
kernelVersion53 bool
)
const maxCopyFileRangeRound = 1 << 30
// CopyFileRange copies at most remain bytes of data from src to dst, using
// the copy_file_range system call. dst and src must refer to regular files.
func CopyFileRange(dst, src *FD, remain int64) (written int64, handled bool, err error) {
kernelVersion53Once.Do(func() {
major, minor := unix.KernelVersion()
// copy_file_range(2) is broken in various ways on kernels older than 5.3,
// see issue #42400 and
// https://man7.org/linux/man-pages/man2/copy_file_range.2.html#VERSIONS
if major > 5 || (major == 5 && minor >= 3) {
kernelVersion53 = true
}
})
if !kernelVersion53 {
return 0, false, nil
}
for remain > 0 {
max := remain
if max > maxCopyFileRangeRound {
max = maxCopyFileRangeRound
}
n, err := copyFileRange(dst, src, int(max))
switch err {
case syscall.ENOSYS:
// copy_file_range(2) was introduced in Linux 4.5.
// Go supports Linux >= 2.6.33, so the system call
// may not be present.
//
// If we see ENOSYS, we have certainly not transferred
// any data, so we can tell the caller that we
// couldn't handle the transfer and let them fall
// back to more generic code.
return 0, false, nil
case syscall.EXDEV, syscall.EINVAL, syscall.EIO, syscall.EOPNOTSUPP, syscall.EPERM:
// Prior to Linux 5.3, it was not possible to
// copy_file_range across file systems. Similarly to
// the ENOSYS case above, if we see EXDEV, we have
// not transferred any data, and we can let the caller
// fall back to generic code.
//
// As for EINVAL, that is what we see if, for example,
// dst or src refer to a pipe rather than a regular
// file. This is another case where no data has been
// transferred, so we consider it unhandled.
//
// If src and dst are on CIFS, we can see EIO.
// See issue #42334.
//
// If the file is on NFS, we can see EOPNOTSUPP.
// See issue #40731.
//
// If the process is running inside a Docker container,
// we might see EPERM instead of ENOSYS. See issue
// #40893. Since EPERM might also be a legitimate error,
// don't mark copy_file_range(2) as unsupported.
return 0, false, nil
case nil:
if n == 0 {
// If we did not read any bytes at all,
// then this file may be in a file system
// where copy_file_range silently fails.
// https://lore.kernel.org/linux-fsdevel/20210126233840.GG4626@dread.disaster.area/T/#m05753578c7f7882f6e9ffe01f981bc223edef2b0
if written == 0 {
return 0, false, nil
}
// Otherwise src is at EOF, which means
// we are done.
return written, true, nil
}
remain -= n
written += n
default:
return written, true, err
}
}
return written, true, nil
}
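// Callers treat handled == false as "fall back to a generic copy". A hedged
// sketch of the calling convention (genericCopy is a hypothetical read/write
// fallback, not part of this package):
//
//	written, handled, err := CopyFileRange(dst, src, remain)
//	if !handled {
//		return genericCopy(dst, src, remain)
//	}
//	return written, err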
// copyFileRange performs one round of copy_file_range(2).
func copyFileRange(dst, src *FD, max int) (written int64, err error) {
// The signature of copy_file_range(2) is:
//
//	ssize_t copy_file_range(int fd_in, loff_t *off_in,
//	                        int fd_out, loff_t *off_out,
//	                        size_t len, unsigned int flags);
//
// Note that in the call to unix.CopyFileRange below, we use nil
// values for off_in and off_out. For the system call, this means
// "use and update the file offsets". That is why we must acquire
// locks for both file descriptors (and why this whole machinery is
// in the internal/poll package to begin with).
if err := dst.writeLock(); err != nil {
return 0, err
}
defer dst.writeUnlock()
if err := src.readLock(); err != nil {
return 0, err
}
defer src.readUnlock()
var n int
for {
n, err = unix.CopyFileRange(src.Sysfd, nil, dst.Sysfd, nil, max, 0)
if err != syscall.EINTR {
break
}
}
return int64(n), err
}
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix
package poll
import "syscall"
// Do the interface allocations only once for common
// Errno values.
var (
errEAGAIN error = syscall.EAGAIN
errEINVAL error = syscall.EINVAL
errENOENT error = syscall.ENOENT
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return nil
case syscall.EAGAIN:
return errEAGAIN
case syscall.EINVAL:
return errEINVAL
case syscall.ENOENT:
return errENOENT
}
return e
}
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build dragonfly || freebsd || linux || netbsd || (openbsd && mips64)
package poll
import (
"internal/syscall/unix"
"syscall"
)
func fcntl(fd int, cmd int, arg int) (int, error) {
r, _, e := syscall.Syscall(unix.FcntlSyscall, uintptr(fd), uintptr(cmd), uintptr(arg))
if e != 0 {
return int(r), syscall.Errno(e)
}
return int(r), nil
}
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package poll supports non-blocking I/O on file descriptors with polling.
// This supports I/O operations that block only a goroutine, not a thread.
// This is used by the net and os packages.
// It uses a poller built into the runtime, with support from the
// runtime scheduler.
package poll
import (
"errors"
)
// errNetClosing is the type of the variable ErrNetClosing.
// This is used to implement the net.Error interface.
type errNetClosing struct{}
// Error returns the error message for ErrNetClosing.
// Keep this string consistent because of issue #4373:
// since historically programs have not been able to detect
// this error, they look for the string.
func (e errNetClosing) Error() string { return "use of closed network connection" }
func (e errNetClosing) Timeout() bool { return false }
func (e errNetClosing) Temporary() bool { return false }
// ErrNetClosing is returned when a network descriptor is used after
// it has been closed.
var ErrNetClosing = errNetClosing{}
// ErrFileClosing is returned when a file descriptor is used after it
// has been closed.
var ErrFileClosing = errors.New("use of closed file")
// ErrNoDeadline is returned when a request is made to set a deadline
// on a file type that does not use the poller.
var ErrNoDeadline = errors.New("file type does not support deadline")
// errClosing returns the appropriate closing error based on isFile.
func errClosing(isFile bool) error {
if isFile {
return ErrFileClosing
}
return ErrNetClosing
}
// ErrDeadlineExceeded is returned for an expired deadline.
// This is exported by the os package as os.ErrDeadlineExceeded.
var ErrDeadlineExceeded error = &DeadlineExceededError{}
// DeadlineExceededError is returned for an expired deadline.
type DeadlineExceededError struct{}
// Implement the net.Error interface.
// The string is "i/o timeout" because that is what was returned
// by earlier Go versions. Changing it may break programs that
// match on error strings.
func (e *DeadlineExceededError) Error() string { return "i/o timeout" }
func (e *DeadlineExceededError) Timeout() bool { return true }
func (e *DeadlineExceededError) Temporary() bool { return true }
// ErrNotPollable is returned when the file or socket is not suitable
// for event notification.
var ErrNotPollable = errors.New("not pollable")
// consume removes data from a slice of byte slices, for writev.
func consume(v *[][]byte, n int64) {
for len(*v) > 0 {
ln0 := int64(len((*v)[0]))
if ln0 > n {
(*v)[0] = (*v)[0][n:]
return
}
n -= ln0
(*v)[0] = nil
*v = (*v)[1:]
}
}
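// For example (illustrative): after a short writev that wrote 4 bytes,
// consume advances the vector so the next call resumes mid-buffer.
//
//	v := [][]byte{[]byte("abc"), []byte("defg")}
//	consume(&v, 4)
//	// v is now [][]byte{[]byte("efg")}: the first buffer was dropped and
//	// one byte was trimmed from the second.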
// TestHookDidWritev is a hook for testing writev.
var TestHookDidWritev = func(wrote int) {}
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix || dragonfly || freebsd || (js && wasm) || linux || netbsd || openbsd || solaris
package poll
import "syscall"
// Fsync wraps syscall.Fsync.
func (fd *FD) Fsync() error {
if err := fd.incref(); err != nil {
return err
}
defer fd.decref()
return ignoringEINTR(func() error {
return syscall.Fsync(fd.Sysfd)
})
}
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package poll
import "sync/atomic"
// fdMutex is a specialized synchronization primitive that manages
// lifetime of an fd and serializes access to Read, Write and Close
// methods on FD.
type fdMutex struct {
state uint64
rsema uint32
wsema uint32
}
// fdMutex.state is organized as follows:
// 1 bit - whether FD is closed, if set all subsequent lock operations will fail.
// 1 bit - lock for read operations.
// 1 bit - lock for write operations.
// 20 bits - total number of references (read+write+misc).
// 20 bits - number of outstanding read waiters.
// 20 bits - number of outstanding write waiters.
const (
mutexClosed = 1 << 0
mutexRLock = 1 << 1
mutexWLock = 1 << 2
mutexRef = 1 << 3
mutexRefMask = (1<<20 - 1) << 3
mutexRWait = 1 << 23
mutexRMask = (1<<20 - 1) << 23
mutexWWait = 1 << 43
mutexWMask = (1<<20 - 1) << 43
)
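// A sketch of how the packed fields decode, using the masks above
// (illustrative only):
//
//	closed := state&mutexClosed != 0
//	refs := (state & mutexRefMask) >> 3 // total references
//	rwait := (state & mutexRMask) >> 23 // read waiters
//	wwait := (state & mutexWMask) >> 43 // write waiters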
const overflowMsg = "too many concurrent operations on a single file or socket (max 1048575)"
// Read operations must do rwlock(true)/rwunlock(true).
//
// Write operations must do rwlock(false)/rwunlock(false).
//
// Misc operations must do incref/decref.
// Misc operations include functions like setsockopt and setDeadline.
// They need to use incref/decref to ensure that they operate on the
// correct fd in presence of a concurrent close call (otherwise fd can
// be closed under their feet).
//
// Close operations must do increfAndClose/decref.
// incref adds a reference to mu.
// It reports whether mu is available for reading or writing.
func (mu *fdMutex) incref() bool {
for {
old := atomic.LoadUint64(&mu.state)
if old&mutexClosed != 0 {
return false
}
new := old + mutexRef
if new&mutexRefMask == 0 {
panic(overflowMsg)
}
if atomic.CompareAndSwapUint64(&mu.state, old, new) {
return true
}
}
}
// increfAndClose sets the state of mu to closed.
// It returns false if the file was already closed.
func (mu *fdMutex) increfAndClose() bool {
for {
old := atomic.LoadUint64(&mu.state)
if old&mutexClosed != 0 {
return false
}
// Mark as closed and acquire a reference.
new := (old | mutexClosed) + mutexRef
if new&mutexRefMask == 0 {
panic(overflowMsg)
}
// Remove all read and write waiters.
new &^= mutexRMask | mutexWMask
if atomic.CompareAndSwapUint64(&mu.state, old, new) {
// Wake all read and write waiters,
// they will observe closed flag after wakeup.
for old&mutexRMask != 0 {
old -= mutexRWait
runtime_Semrelease(&mu.rsema)
}
for old&mutexWMask != 0 {
old -= mutexWWait
runtime_Semrelease(&mu.wsema)
}
return true
}
}
}
// decref removes a reference from mu.
// It reports whether there is no remaining reference.
func (mu *fdMutex) decref() bool {
for {
old := atomic.LoadUint64(&mu.state)
if old&mutexRefMask == 0 {
panic("inconsistent poll.fdMutex")
}
new := old - mutexRef
if atomic.CompareAndSwapUint64(&mu.state, old, new) {
return new&(mutexClosed|mutexRefMask) == mutexClosed
}
}
}
// rwlock adds a reference to mu and locks mu for reading or writing.
// It reports whether mu is available for reading or writing.
func (mu *fdMutex) rwlock(read bool) bool {
var mutexBit, mutexWait, mutexMask uint64
var mutexSema *uint32
if read {
mutexBit = mutexRLock
mutexWait = mutexRWait
mutexMask = mutexRMask
mutexSema = &mu.rsema
} else {
mutexBit = mutexWLock
mutexWait = mutexWWait
mutexMask = mutexWMask
mutexSema = &mu.wsema
}
for {
old := atomic.LoadUint64(&mu.state)
if old&mutexClosed != 0 {
return false
}
var new uint64
if old&mutexBit == 0 {
// Lock is free, acquire it.
new = (old | mutexBit) + mutexRef
if new&mutexRefMask == 0 {
panic(overflowMsg)
}
} else {
// Wait for lock.
new = old + mutexWait
if new&mutexMask == 0 {
panic(overflowMsg)
}
}
if atomic.CompareAndSwapUint64(&mu.state, old, new) {
if old&mutexBit == 0 {
return true
}
runtime_Semacquire(mutexSema)
// The signaller has subtracted mutexWait.
}
}
}
// rwunlock removes a reference from mu and unlocks mu for reading or writing.
// It reports whether there is no remaining reference.
func (mu *fdMutex) rwunlock(read bool) bool {
var mutexBit, mutexWait, mutexMask uint64
var mutexSema *uint32
if read {
mutexBit = mutexRLock
mutexWait = mutexRWait
mutexMask = mutexRMask
mutexSema = &mu.rsema
} else {
mutexBit = mutexWLock
mutexWait = mutexWWait
mutexMask = mutexWMask
mutexSema = &mu.wsema
}
for {
old := atomic.LoadUint64(&mu.state)
if old&mutexBit == 0 || old&mutexRefMask == 0 {
panic("inconsistent poll.fdMutex")
}
// Drop lock, drop reference and wake read waiter if present.
new := (old &^ mutexBit) - mutexRef
if old&mutexMask != 0 {
new -= mutexWait
}
if atomic.CompareAndSwapUint64(&mu.state, old, new) {
if old&mutexMask != 0 {
runtime_Semrelease(mutexSema)
}
return new&(mutexClosed|mutexRefMask) == mutexClosed
}
}
}
// Implemented in runtime package.
func runtime_Semacquire(sema *uint32)
func runtime_Semrelease(sema *uint32)
// incref adds a reference to fd.
// It returns an error when fd cannot be used.
func (fd *FD) incref() error {
if !fd.fdmu.incref() {
return errClosing(fd.isFile)
}
return nil
}
// decref removes a reference from fd.
// It also closes fd when the state of fd is set to closed and there
// is no remaining reference.
func (fd *FD) decref() error {
if fd.fdmu.decref() {
return fd.destroy()
}
return nil
}
// readLock adds a reference to fd and locks fd for reading.
// It returns an error when fd cannot be used for reading.
func (fd *FD) readLock() error {
if !fd.fdmu.rwlock(true) {
return errClosing(fd.isFile)
}
return nil
}
// readUnlock removes a reference from fd and unlocks fd for reading.
// It also closes fd when the state of fd is set to closed and there
// is no remaining reference.
func (fd *FD) readUnlock() {
if fd.fdmu.rwunlock(true) {
fd.destroy()
}
}
// writeLock adds a reference to fd and locks fd for writing.
// It returns an error when fd cannot be used for writing.
func (fd *FD) writeLock() error {
if !fd.fdmu.rwlock(false) {
return errClosing(fd.isFile)
}
return nil
}
// writeUnlock removes a reference from fd and unlocks fd for writing.
// It also closes fd when the state of fd is set to closed and there
// is no remaining reference.
func (fd *FD) writeUnlock() {
if fd.fdmu.rwunlock(false) {
fd.destroy()
}
}
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || windows
package poll
import (
"errors"
"sync"
"syscall"
"time"
_ "unsafe" // for go:linkname
)
// runtimeNano returns the current value of the runtime clock in nanoseconds.
//
//go:linkname runtimeNano runtime.nanotime
func runtimeNano() int64
func runtime_pollServerInit()
func runtime_pollOpen(fd uintptr) (uintptr, int)
func runtime_pollClose(ctx uintptr)
func runtime_pollWait(ctx uintptr, mode int) int
func runtime_pollWaitCanceled(ctx uintptr, mode int)
func runtime_pollReset(ctx uintptr, mode int) int
func runtime_pollSetDeadline(ctx uintptr, d int64, mode int)
func runtime_pollUnblock(ctx uintptr)
func runtime_isPollServerDescriptor(fd uintptr) bool
type pollDesc struct {
runtimeCtx uintptr
}
var serverInit sync.Once
func (pd *pollDesc) init(fd *FD) error {
serverInit.Do(runtime_pollServerInit)
ctx, errno := runtime_pollOpen(uintptr(fd.Sysfd))
if errno != 0 {
return errnoErr(syscall.Errno(errno))
}
pd.runtimeCtx = ctx
return nil
}
func (pd *pollDesc) close() {
if pd.runtimeCtx == 0 {
return
}
runtime_pollClose(pd.runtimeCtx)
pd.runtimeCtx = 0
}
// evict evicts fd from the pending list, unblocking any I/O running on fd.
func (pd *pollDesc) evict() {
if pd.runtimeCtx == 0 {
return
}
runtime_pollUnblock(pd.runtimeCtx)
}
func (pd *pollDesc) prepare(mode int, isFile bool) error {
if pd.runtimeCtx == 0 {
return nil
}
res := runtime_pollReset(pd.runtimeCtx, mode)
return convertErr(res, isFile)
}
func (pd *pollDesc) prepareRead(isFile bool) error {
return pd.prepare('r', isFile)
}
func (pd *pollDesc) prepareWrite(isFile bool) error {
return pd.prepare('w', isFile)
}
func (pd *pollDesc) wait(mode int, isFile bool) error {
if pd.runtimeCtx == 0 {
return errors.New("waiting for unsupported file type")
}
res := runtime_pollWait(pd.runtimeCtx, mode)
return convertErr(res, isFile)
}
func (pd *pollDesc) waitRead(isFile bool) error {
return pd.wait('r', isFile)
}
func (pd *pollDesc) waitWrite(isFile bool) error {
return pd.wait('w', isFile)
}
func (pd *pollDesc) waitCanceled(mode int) {
if pd.runtimeCtx == 0 {
return
}
runtime_pollWaitCanceled(pd.runtimeCtx, mode)
}
func (pd *pollDesc) pollable() bool {
return pd.runtimeCtx != 0
}
// Error values returned by runtime_pollReset and runtime_pollWait.
// These must match the values in runtime/netpoll.go.
const (
pollNoError = 0
pollErrClosing = 1
pollErrTimeout = 2
pollErrNotPollable = 3
)
func convertErr(res int, isFile bool) error {
switch res {
case pollNoError:
return nil
case pollErrClosing:
return errClosing(isFile)
case pollErrTimeout:
return ErrDeadlineExceeded
case pollErrNotPollable:
return ErrNotPollable
}
println("unreachable: ", res)
panic("unreachable")
}
// SetDeadline sets the read and write deadlines associated with fd.
func (fd *FD) SetDeadline(t time.Time) error {
return setDeadlineImpl(fd, t, 'r'+'w')
}
// SetReadDeadline sets the read deadline associated with fd.
func (fd *FD) SetReadDeadline(t time.Time) error {
return setDeadlineImpl(fd, t, 'r')
}
// SetWriteDeadline sets the write deadline associated with fd.
func (fd *FD) SetWriteDeadline(t time.Time) error {
return setDeadlineImpl(fd, t, 'w')
}
func setDeadlineImpl(fd *FD, t time.Time, mode int) error {
var d int64
if !t.IsZero() {
d = int64(time.Until(t))
if d == 0 {
d = -1 // don't confuse deadline right now with no deadline
}
}
if err := fd.incref(); err != nil {
return err
}
defer fd.decref()
if fd.pd.runtimeCtx == 0 {
return ErrNoDeadline
}
runtime_pollSetDeadline(fd.pd.runtimeCtx, d, mode)
return nil
}
// IsPollDescriptor reports whether fd is the descriptor being used by the poller.
// This is only used for testing.
func IsPollDescriptor(fd uintptr) bool {
return runtime_isPollServerDescriptor(fd)
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || (js && wasm) || windows
package poll
import (
"io"
"syscall"
)
// eofError returns io.EOF when fd is available for reading end of
// file.
func (fd *FD) eofError(n int, err error) error {
if n == 0 && err == nil && fd.ZeroReadIsEOF {
return io.EOF
}
return err
}
// Shutdown wraps syscall.Shutdown.
func (fd *FD) Shutdown(how int) error {
if err := fd.incref(); err != nil {
return err
}
defer fd.decref()
return syscall.Shutdown(fd.Sysfd, how)
}
// Fchown wraps syscall.Fchown.
func (fd *FD) Fchown(uid, gid int) error {
if err := fd.incref(); err != nil {
return err
}
defer fd.decref()
return ignoringEINTR(func() error {
return syscall.Fchown(fd.Sysfd, uid, gid)
})
}
// Ftruncate wraps syscall.Ftruncate.
func (fd *FD) Ftruncate(size int64) error {
if err := fd.incref(); err != nil {
return err
}
defer fd.decref()
return ignoringEINTR(func() error {
return syscall.Ftruncate(fd.Sysfd, size)
})
}
// RawControl invokes the user-defined function f for a non-IO
// operation.
func (fd *FD) RawControl(f func(uintptr)) error {
if err := fd.incref(); err != nil {
return err
}
defer fd.decref()
f(uintptr(fd.Sysfd))
return nil
}
// ignoringEINTR makes a function call and repeats it if it returns
// an EINTR error. This appears to be required even though we install all
// signal handlers with SA_RESTART: see #22838, #38033, #38836, #40846.
// Also #20400 and #36644 are issues in which a signal handler is
// installed without setting SA_RESTART. None of these are the common case,
// but there are enough of them that it seems that we can't avoid
// an EINTR loop.
func ignoringEINTR(fn func() error) error {
for {
err := fn()
if err != syscall.EINTR {
return err
}
}
}
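// Usage sketch, matching Fsync and Fchown above: wrap any call whose only
// transient failure is EINTR.
//
//	return ignoringEINTR(func() error {
//		return syscall.Fchdir(fd.Sysfd)
//	})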
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || (js && wasm)
package poll
import (
"internal/syscall/unix"
"io"
"sync/atomic"
"syscall"
)
// FD is a file descriptor. The net and os packages use this type as a
// field of a larger type representing a network connection or OS file.
type FD struct {
// Lock sysfd and serialize access to Read and Write methods.
fdmu fdMutex
// System file descriptor. Immutable until Close.
Sysfd int
// I/O poller.
pd pollDesc
// Writev cache.
iovecs *[]syscall.Iovec
// Semaphore signaled when file is closed.
csema uint32
// Non-zero if this file has been set to blocking mode.
isBlocking uint32
// Whether this is a streaming descriptor, as opposed to a
// packet-based descriptor like a UDP socket. Immutable.
IsStream bool
// Whether a zero byte read indicates EOF. This is false for a
// message based socket connection.
ZeroReadIsEOF bool
// Whether this is a file rather than a network socket.
isFile bool
}
// Init initializes the FD. The Sysfd field should already be set.
// This can be called multiple times on a single FD.
// The net argument is a network name from the net package (e.g., "tcp"),
// or "file".
// Set pollable to true if fd should be managed by runtime netpoll.
func (fd *FD) Init(net string, pollable bool) error {
// We don't actually care about the various network types.
if net == "file" {
fd.isFile = true
}
if !pollable {
fd.isBlocking = 1
return nil
}
err := fd.pd.init(fd)
if err != nil {
// If we could not initialize the runtime poller,
// assume we are using blocking mode.
fd.isBlocking = 1
}
return err
}
// Destroy closes the file descriptor. This is called when there are
// no remaining references.
func (fd *FD) destroy() error {
// Poller may want to unregister fd in readiness notification mechanism,
// so this must be executed before CloseFunc.
fd.pd.close()
// We don't use ignoringEINTR here because POSIX does not define
// whether the descriptor is closed if close returns EINTR.
// If the descriptor is indeed closed, using a loop would race
// with some other goroutine opening a new descriptor.
// (The Linux kernel guarantees that it is closed on an EINTR error.)
err := CloseFunc(fd.Sysfd)
fd.Sysfd = -1
runtime_Semrelease(&fd.csema)
return err
}
// Close closes the FD. The underlying file descriptor is closed by the
// destroy method when there are no remaining references.
func (fd *FD) Close() error {
if !fd.fdmu.increfAndClose() {
return errClosing(fd.isFile)
}
// Unblock any I/O. Once it all unblocks and returns,
// so that it cannot be referring to fd.sysfd anymore,
// the final decref will close fd.sysfd. This should happen
// fairly quickly, since all the I/O is non-blocking, and any
// attempts to block in the pollDesc will return errClosing(fd.isFile).
fd.pd.evict()
// The call to decref will call destroy if there are no other
// references.
err := fd.decref()
// Wait until the descriptor is closed. If this was the only
// reference, it is already closed. Only wait if the file has
// not been set to blocking mode, as otherwise any current I/O
// may be blocking, and that would block the Close.
// No need for an atomic read of isBlocking, increfAndClose means
// we have exclusive access to fd.
if fd.isBlocking == 0 {
runtime_Semacquire(&fd.csema)
}
return err
}
// SetBlocking puts the file into blocking mode.
func (fd *FD) SetBlocking() error {
if err := fd.incref(); err != nil {
return err
}
defer fd.decref()
// Atomic store so that concurrent calls to SetBlocking
// do not cause a race condition. isBlocking only ever goes
// from 0 to 1 so there is no real race here.
atomic.StoreUint32(&fd.isBlocking, 1)
return syscall.SetNonblock(fd.Sysfd, false)
}
// Darwin and FreeBSD can't read or write 2GB+ files at a time,
// even on 64-bit systems.
// The same is true of socket implementations on many systems.
// See golang.org/issue/7812 and golang.org/issue/16266.
// Use 1GB instead of, say, 2GB-1, to keep subsequent reads aligned.
const maxRW = 1 << 30
// Read implements io.Reader.
func (fd *FD) Read(p []byte) (int, error) {
if err := fd.readLock(); err != nil {
return 0, err
}
defer fd.readUnlock()
if len(p) == 0 {
// If the caller wanted a zero byte read, return immediately
// without trying (but after acquiring the readLock).
// Otherwise syscall.Read returns 0, nil which looks like
// io.EOF.
// TODO(bradfitz): make it wait for readability? (Issue 15735)
return 0, nil
}
if err := fd.pd.prepareRead(fd.isFile); err != nil {
return 0, err
}
if fd.IsStream && len(p) > maxRW {
p = p[:maxRW]
}
for {
n, err := ignoringEINTRIO(syscall.Read, fd.Sysfd, p)
if err != nil {
n = 0
if err == syscall.EAGAIN && fd.pd.pollable() {
if err = fd.pd.waitRead(fd.isFile); err == nil {
continue
}
}
}
err = fd.eofError(n, err)
return n, err
}
}
// Pread wraps the pread system call.
func (fd *FD) Pread(p []byte, off int64) (int, error) {
// Call incref, not readLock, because since pread specifies the
// offset it is independent from other reads.
// Similarly, using the poller doesn't make sense for pread.
if err := fd.incref(); err != nil {
return 0, err
}
if fd.IsStream && len(p) > maxRW {
p = p[:maxRW]
}
var (
n int
err error
)
for {
n, err = syscall.Pread(fd.Sysfd, p, off)
if err != syscall.EINTR {
break
}
}
if err != nil {
n = 0
}
fd.decref()
err = fd.eofError(n, err)
return n, err
}
// ReadFrom wraps the recvfrom network call.
func (fd *FD) ReadFrom(p []byte) (int, syscall.Sockaddr, error) {
if err := fd.readLock(); err != nil {
return 0, nil, err
}
defer fd.readUnlock()
if err := fd.pd.prepareRead(fd.isFile); err != nil {
return 0, nil, err
}
for {
n, sa, err := syscall.Recvfrom(fd.Sysfd, p, 0)
if err != nil {
if err == syscall.EINTR {
continue
}
n = 0
if err == syscall.EAGAIN && fd.pd.pollable() {
if err = fd.pd.waitRead(fd.isFile); err == nil {
continue
}
}
}
err = fd.eofError(n, err)
return n, sa, err
}
}
// ReadFromInet4 wraps the recvfrom network call for IPv4.
func (fd *FD) ReadFromInet4(p []byte, from *syscall.SockaddrInet4) (int, error) {
if err := fd.readLock(); err != nil {
return 0, err
}
defer fd.readUnlock()
if err := fd.pd.prepareRead(fd.isFile); err != nil {
return 0, err
}
for {
n, err := unix.RecvfromInet4(fd.Sysfd, p, 0, from)
if err != nil {
if err == syscall.EINTR {
continue
}
n = 0
if err == syscall.EAGAIN && fd.pd.pollable() {
if err = fd.pd.waitRead(fd.isFile); err == nil {
continue
}
}
}
err = fd.eofError(n, err)
return n, err
}
}
// ReadFromInet6 wraps the recvfrom network call for IPv6.
func (fd *FD) ReadFromInet6(p []byte, from *syscall.SockaddrInet6) (int, error) {
if err := fd.readLock(); err != nil {
return 0, err
}
defer fd.readUnlock()
if err := fd.pd.prepareRead(fd.isFile); err != nil {
return 0, err
}
for {
n, err := unix.RecvfromInet6(fd.Sysfd, p, 0, from)
if err != nil {
if err == syscall.EINTR {
continue
}
n = 0
if err == syscall.EAGAIN && fd.pd.pollable() {
if err = fd.pd.waitRead(fd.isFile); err == nil {
continue
}
}
}
err = fd.eofError(n, err)
return n, err
}
}
// ReadMsg wraps the recvmsg network call.
func (fd *FD) ReadMsg(p []byte, oob []byte, flags int) (int, int, int, syscall.Sockaddr, error) {
if err := fd.readLock(); err != nil {
return 0, 0, 0, nil, err
}
defer fd.readUnlock()
if err := fd.pd.prepareRead(fd.isFile); err != nil {
return 0, 0, 0, nil, err
}
for {
n, oobn, sysflags, sa, err := syscall.Recvmsg(fd.Sysfd, p, oob, flags)
if err != nil {
if err == syscall.EINTR {
continue
}
// TODO(dfc) should n and oobn be set to 0
if err == syscall.EAGAIN && fd.pd.pollable() {
if err = fd.pd.waitRead(fd.isFile); err == nil {
continue
}
}
}
err = fd.eofError(n, err)
return n, oobn, sysflags, sa, err
}
}
// ReadMsgInet4 is ReadMsg, but specialized for syscall.SockaddrInet4.
func (fd *FD) ReadMsgInet4(p []byte, oob []byte, flags int, sa4 *syscall.SockaddrInet4) (int, int, int, error) {
if err := fd.readLock(); err != nil {
return 0, 0, 0, err
}
defer fd.readUnlock()
if err := fd.pd.prepareRead(fd.isFile); err != nil {
return 0, 0, 0, err
}
for {
n, oobn, sysflags, err := unix.RecvmsgInet4(fd.Sysfd, p, oob, flags, sa4)
if err != nil {
if err == syscall.EINTR {
continue
}
// TODO(dfc) should n and oobn be set to 0
if err == syscall.EAGAIN && fd.pd.pollable() {
if err = fd.pd.waitRead(fd.isFile); err == nil {
continue
}
}
}
err = fd.eofError(n, err)
return n, oobn, sysflags, err
}
}
// ReadMsgInet6 is ReadMsg, but specialized for syscall.SockaddrInet6.
func (fd *FD) ReadMsgInet6(p []byte, oob []byte, flags int, sa6 *syscall.SockaddrInet6) (int, int, int, error) {
if err := fd.readLock(); err != nil {
return 0, 0, 0, err
}
defer fd.readUnlock()
if err := fd.pd.prepareRead(fd.isFile); err != nil {
return 0, 0, 0, err
}
for {
n, oobn, sysflags, err := unix.RecvmsgInet6(fd.Sysfd, p, oob, flags, sa6)
if err != nil {
if err == syscall.EINTR {
continue
}
// TODO(dfc) should n and oobn be set to 0
if err == syscall.EAGAIN && fd.pd.pollable() {
if err = fd.pd.waitRead(fd.isFile); err == nil {
continue
}
}
}
err = fd.eofError(n, err)
return n, oobn, sysflags, err
}
}
// Write implements io.Writer.
func (fd *FD) Write(p []byte) (int, error) {
if err := fd.writeLock(); err != nil {
return 0, err
}
defer fd.writeUnlock()
if err := fd.pd.prepareWrite(fd.isFile); err != nil {
return 0, err
}
var nn int
for {
max := len(p)
if fd.IsStream && max-nn > maxRW {
max = nn + maxRW
}
n, err := ignoringEINTRIO(syscall.Write, fd.Sysfd, p[nn:max])
if n > 0 {
nn += n
}
if nn == len(p) {
return nn, err
}
if err == syscall.EAGAIN && fd.pd.pollable() {
if err = fd.pd.waitWrite(fd.isFile); err == nil {
continue
}
}
if err != nil {
return nn, err
}
if n == 0 {
return nn, io.ErrUnexpectedEOF
}
}
}
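// Both Write above and Pwrite below cap each system call at maxRW
// bytes for stream descriptors. A sketch of the chunk arithmetic in
// isolation (chunkEnd is a hypothetical helper, for illustration only):
//
//	func chunkEnd(lenP, nn int, isStream bool) int {
//		end := lenP // write the rest by default
//		if isStream && end-nn > maxRW {
//			end = nn + maxRW // cap the next chunk at 1GB
//		}
//		return end // the next syscall writes p[nn:end]
//	}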
// Pwrite wraps the pwrite system call.
func (fd *FD) Pwrite(p []byte, off int64) (int, error) {
// Call incref, not writeLock, because pwrite specifies the offset and
// is therefore independent of other writes.
// Similarly, using the poller doesn't make sense for pwrite.
if err := fd.incref(); err != nil {
return 0, err
}
defer fd.decref()
var nn int
for {
max := len(p)
if fd.IsStream && max-nn > maxRW {
max = nn + maxRW
}
n, err := syscall.Pwrite(fd.Sysfd, p[nn:max], off+int64(nn))
if err == syscall.EINTR {
continue
}
if n > 0 {
nn += n
}
if nn == len(p) {
return nn, err
}
if err != nil {
return nn, err
}
if n == 0 {
return nn, io.ErrUnexpectedEOF
}
}
}
// WriteToInet4 wraps the sendto network call for IPv4 addresses.
func (fd *FD) WriteToInet4(p []byte, sa *syscall.SockaddrInet4) (int, error) {
if err := fd.writeLock(); err != nil {
return 0, err
}
defer fd.writeUnlock()
if err := fd.pd.prepareWrite(fd.isFile); err != nil {
return 0, err
}
for {
err := unix.SendtoInet4(fd.Sysfd, p, 0, sa)
if err == syscall.EINTR {
continue
}
if err == syscall.EAGAIN && fd.pd.pollable() {
if err = fd.pd.waitWrite(fd.isFile); err == nil {
continue
}
}
if err != nil {
return 0, err
}
return len(p), nil
}
}
// WriteToInet6 wraps the sendto network call for IPv6 addresses.
func (fd *FD) WriteToInet6(p []byte, sa *syscall.SockaddrInet6) (int, error) {
if err := fd.writeLock(); err != nil {
return 0, err
}
defer fd.writeUnlock()
if err := fd.pd.prepareWrite(fd.isFile); err != nil {
return 0, err
}
for {
err := unix.SendtoInet6(fd.Sysfd, p, 0, sa)
if err == syscall.EINTR {
continue
}
if err == syscall.EAGAIN && fd.pd.pollable() {
if err = fd.pd.waitWrite(fd.isFile); err == nil {
continue
}
}
if err != nil {
return 0, err
}
return len(p), nil
}
}
// WriteTo wraps the sendto network call.
func (fd *FD) WriteTo(p []byte, sa syscall.Sockaddr) (int, error) {
if err := fd.writeLock(); err != nil {
return 0, err
}
defer fd.writeUnlock()
if err := fd.pd.prepareWrite(fd.isFile); err != nil {
return 0, err
}
for {
err := syscall.Sendto(fd.Sysfd, p, 0, sa)
if err == syscall.EINTR {
continue
}
if err == syscall.EAGAIN && fd.pd.pollable() {
if err = fd.pd.waitWrite(fd.isFile); err == nil {
continue
}
}
if err != nil {
return 0, err
}
return len(p), nil
}
}
// WriteMsg wraps the sendmsg network call.
func (fd *FD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (int, int, error) {
if err := fd.writeLock(); err != nil {
return 0, 0, err
}
defer fd.writeUnlock()
if err := fd.pd.prepareWrite(fd.isFile); err != nil {
return 0, 0, err
}
for {
n, err := syscall.SendmsgN(fd.Sysfd, p, oob, sa, 0)
if err == syscall.EINTR {
continue
}
if err == syscall.EAGAIN && fd.pd.pollable() {
if err = fd.pd.waitWrite(fd.isFile); err == nil {
continue
}
}
if err != nil {
return n, 0, err
}
return n, len(oob), err
}
}
// WriteMsgInet4 is WriteMsg specialized for syscall.SockaddrInet4.
func (fd *FD) WriteMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4) (int, int, error) {
if err := fd.writeLock(); err != nil {
return 0, 0, err
}
defer fd.writeUnlock()
if err := fd.pd.prepareWrite(fd.isFile); err != nil {
return 0, 0, err
}
for {
n, err := unix.SendmsgNInet4(fd.Sysfd, p, oob, sa, 0)
if err == syscall.EINTR {
continue
}
if err == syscall.EAGAIN && fd.pd.pollable() {
if err = fd.pd.waitWrite(fd.isFile); err == nil {
continue
}
}
if err != nil {
return n, 0, err
}
return n, len(oob), err
}
}
// WriteMsgInet6 is WriteMsg specialized for syscall.SockaddrInet6.
func (fd *FD) WriteMsgInet6(p []byte, oob []byte, sa *syscall.SockaddrInet6) (int, int, error) {
if err := fd.writeLock(); err != nil {
return 0, 0, err
}
defer fd.writeUnlock()
if err := fd.pd.prepareWrite(fd.isFile); err != nil {
return 0, 0, err
}
for {
n, err := unix.SendmsgNInet6(fd.Sysfd, p, oob, sa, 0)
if err == syscall.EINTR {
continue
}
if err == syscall.EAGAIN && fd.pd.pollable() {
if err = fd.pd.waitWrite(fd.isFile); err == nil {
continue
}
}
if err != nil {
return n, 0, err
}
return n, len(oob), err
}
}
// Accept wraps the accept network call.
func (fd *FD) Accept() (int, syscall.Sockaddr, string, error) {
if err := fd.readLock(); err != nil {
return -1, nil, "", err
}
defer fd.readUnlock()
if err := fd.pd.prepareRead(fd.isFile); err != nil {
return -1, nil, "", err
}
for {
s, rsa, errcall, err := accept(fd.Sysfd)
if err == nil {
return s, rsa, "", err
}
switch err {
case syscall.EINTR:
continue
case syscall.EAGAIN:
if fd.pd.pollable() {
if err = fd.pd.waitRead(fd.isFile); err == nil {
continue
}
}
case syscall.ECONNABORTED:
// This means that a socket on the listen
// queue was closed before we Accept()ed it;
// it's a silly error, so try again.
continue
}
return -1, nil, errcall, err
}
}
// Seek wraps syscall.Seek.
func (fd *FD) Seek(offset int64, whence int) (int64, error) {
if err := fd.incref(); err != nil {
return 0, err
}
defer fd.decref()
return syscall.Seek(fd.Sysfd, offset, whence)
}
// ReadDirent wraps syscall.ReadDirent.
// We treat this like an ordinary system call rather than a call
// that tries to fill the buffer.
func (fd *FD) ReadDirent(buf []byte) (int, error) {
if err := fd.incref(); err != nil {
return 0, err
}
defer fd.decref()
for {
n, err := ignoringEINTRIO(syscall.ReadDirent, fd.Sysfd, buf)
if err != nil {
n = 0
if err == syscall.EAGAIN && fd.pd.pollable() {
if err = fd.pd.waitRead(fd.isFile); err == nil {
continue
}
}
}
// Do not call eofError; caller does not expect to see io.EOF.
return n, err
}
}
// Fchmod wraps syscall.Fchmod.
func (fd *FD) Fchmod(mode uint32) error {
if err := fd.incref(); err != nil {
return err
}
defer fd.decref()
return ignoringEINTR(func() error {
return syscall.Fchmod(fd.Sysfd, mode)
})
}
// Fchdir wraps syscall.Fchdir.
func (fd *FD) Fchdir() error {
if err := fd.incref(); err != nil {
return err
}
defer fd.decref()
return syscall.Fchdir(fd.Sysfd)
}
// Fstat wraps syscall.Fstat.
func (fd *FD) Fstat(s *syscall.Stat_t) error {
if err := fd.incref(); err != nil {
return err
}
defer fd.decref()
return ignoringEINTR(func() error {
return syscall.Fstat(fd.Sysfd, s)
})
}
// dupCloexecUnsupported indicates whether F_DUPFD_CLOEXEC is supported by the kernel.
var dupCloexecUnsupported atomic.Bool
// DupCloseOnExec dups fd and marks it close-on-exec.
func DupCloseOnExec(fd int) (int, string, error) {
if syscall.F_DUPFD_CLOEXEC != 0 && !dupCloexecUnsupported.Load() {
r0, e1 := fcntl(fd, syscall.F_DUPFD_CLOEXEC, 0)
if e1 == nil {
return r0, "", nil
}
switch e1.(syscall.Errno) {
case syscall.EINVAL, syscall.ENOSYS:
// Old kernel, or js/wasm (which returns
// ENOSYS). Fall back to the portable way from
// now on.
dupCloexecUnsupported.Store(true)
default:
return -1, "fcntl", e1
}
}
return dupCloseOnExecOld(fd)
}
// dupCloseOnExecOld is the traditional way to dup an fd and
// set its O_CLOEXEC bit, using two system calls.
func dupCloseOnExecOld(fd int) (int, string, error) {
syscall.ForkLock.RLock()
defer syscall.ForkLock.RUnlock()
newfd, err := syscall.Dup(fd)
if err != nil {
return -1, "dup", err
}
syscall.CloseOnExec(newfd)
return newfd, "", nil
}
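// A minimal usage sketch for DupCloseOnExec, duplicating standard
// error (fd 2, assumed valid; error handling abbreviated):
//
//	newfd, call, err := DupCloseOnExec(2)
//	if err != nil {
//		return fmt.Errorf("%s: %v", call, err) // call names the failing syscall
//	}
//	defer syscall.Close(newfd)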
// Dup duplicates the file descriptor.
func (fd *FD) Dup() (int, string, error) {
if err := fd.incref(); err != nil {
return -1, "", err
}
defer fd.decref()
return DupCloseOnExec(fd.Sysfd)
}
// On Unix variants only, expose the IO event for the net code.
// WaitWrite waits until data can be written to fd.
func (fd *FD) WaitWrite() error {
return fd.pd.waitWrite(fd.isFile)
}
// WriteOnce is for testing only. It makes a single write call.
func (fd *FD) WriteOnce(p []byte) (int, error) {
if err := fd.writeLock(); err != nil {
return 0, err
}
defer fd.writeUnlock()
return ignoringEINTRIO(syscall.Write, fd.Sysfd, p)
}
// RawRead invokes the user-defined function f for a read operation.
func (fd *FD) RawRead(f func(uintptr) bool) error {
if err := fd.readLock(); err != nil {
return err
}
defer fd.readUnlock()
if err := fd.pd.prepareRead(fd.isFile); err != nil {
return err
}
for {
if f(uintptr(fd.Sysfd)) {
return nil
}
if err := fd.pd.waitRead(fd.isFile); err != nil {
return err
}
}
}
// RawWrite invokes the user-defined function f for a write operation.
func (fd *FD) RawWrite(f func(uintptr) bool) error {
if err := fd.writeLock(); err != nil {
return err
}
defer fd.writeUnlock()
if err := fd.pd.prepareWrite(fd.isFile); err != nil {
return err
}
for {
if f(uintptr(fd.Sysfd)) {
return nil
}
if err := fd.pd.waitWrite(fd.isFile); err != nil {
return err
}
}
}
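// A sketch of the contract RawRead and RawWrite implement (mirroring
// syscall.RawConn): f is invoked with the raw descriptor and retried,
// after waiting for readiness, until it reports completion by
// returning true. buf is an assumed caller-provided []byte:
//
//	var n int
//	err := fd.RawRead(func(s uintptr) bool {
//		m, e := syscall.Read(int(s), buf)
//		if e == syscall.EAGAIN {
//			return false // not ready; RawRead waits, then calls f again
//		}
//		n = m // any other error is deliberately dropped in this sketch
//		return true // done; RawRead returns nil
//	})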
// ignoringEINTRIO is like ignoringEINTR, but just for IO calls.
func ignoringEINTRIO(fn func(fd int, p []byte) (int, error), fd int, p []byte) (int, error) {
for {
n, err := fn(fd, p)
if err != syscall.EINTR {
return n, err
}
}
}
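// ignoringEINTR, used by Fchmod and Fstat above, is defined elsewhere
// in this package; assuming it follows the same retry pattern, a
// minimal sketch would be:
//
//	func ignoringEINTR(fn func() error) error {
//		for {
//			err := fn()
//			if err != syscall.EINTR {
//				return err
//			}
//		}
//	}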
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build dragonfly || freebsd || linux || netbsd || (openbsd && mips64)
package poll
import (
"syscall"
"unsafe"
)
func writev(fd int, iovecs []syscall.Iovec) (uintptr, error) {
var (
r uintptr
e syscall.Errno
)
for {
r, _, e = syscall.Syscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(len(iovecs)))
if e != syscall.EINTR {
break
}
}
if e != 0 {
return r, e
}
return r, nil
}
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd
package poll
import "syscall"
func newIovecWithBase(base *byte) syscall.Iovec {
return syscall.Iovec{Base: base}
}
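// A minimal sketch of assembling an iovec slice for the writev wrapper
// above, from two buffers (fd is an assumed open descriptor):
//
//	bufs := [][]byte{[]byte("hello, "), []byte("world\n")}
//	iovecs := make([]syscall.Iovec, 0, len(bufs))
//	for _, b := range bufs {
//		iov := newIovecWithBase(&b[0])
//		iov.SetLen(len(b))
//		iovecs = append(iovecs, iov)
//	}
//	n, err := writev(fd, iovecs) // writes both buffers in one syscall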
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package poll
import "syscall"
// maxSendfileSize is the largest chunk size we ask the kernel to copy
// at a time.
const maxSendfileSize int = 4 << 20
// SendFile wraps the sendfile system call.
func SendFile(dstFD *FD, src int, remain int64) (int64, error, bool) {
if err := dstFD.writeLock(); err != nil {
return 0, err, false
}
defer dstFD.writeUnlock()
if err := dstFD.pd.prepareWrite(dstFD.isFile); err != nil {
return 0, err, false
}
dst := dstFD.Sysfd
var (
written int64
err error
handled = true
)
for remain > 0 {
n := maxSendfileSize
if int64(n) > remain {
n = int(remain)
}
n, err1 := syscall.Sendfile(dst, src, nil, n)
if n > 0 {
written += int64(n)
remain -= int64(n)
} else if n == 0 && err1 == nil {
break
}
if err1 == syscall.EINTR {
continue
}
if err1 == syscall.EAGAIN {
if err1 = dstFD.pd.waitWrite(dstFD.isFile); err1 == nil {
continue
}
}
if err1 != nil {
// This includes syscall.ENOSYS (no kernel
// support) and syscall.EINVAL (fd types which
// don't implement sendfile)
err = err1
handled = false
break
}
}
return written, err, handled
}
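// A caller-side sketch, assuming dstFD is a stream *FD and srcFd is an
// open regular file with remain bytes left to send:
//
//	written, err, handled := SendFile(dstFD, srcFd, remain)
//	if !handled {
//		// sendfile is unavailable for these fds (e.g. ENOSYS or
//		// EINVAL); fall back to a generic read/write copy loop.
//	}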
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements accept for platforms that provide a fast path for
// setting the new descriptor to nonblocking and close-on-exec.
//go:build dragonfly || freebsd || (linux && !arm) || netbsd || openbsd || solaris
package poll
import "syscall"
// Wrapper around the accept system call that marks the returned file
// descriptor as nonblocking and close-on-exec.
func accept(s int) (int, syscall.Sockaddr, string, error) {
ns, sa, err := Accept4Func(s, syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC)
if err != nil {
return -1, sa, "accept4", err
}
return ns, sa, "", nil
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || windows
package poll
import "syscall"
// SetsockoptInt wraps the setsockopt network call with an int argument.
func (fd *FD) SetsockoptInt(level, name, arg int) error {
if err := fd.incref(); err != nil {
return err
}
defer fd.decref()
return syscall.SetsockoptInt(fd.Sysfd, level, name, arg)
}
// SetsockoptInet4Addr wraps the setsockopt network call with an IPv4 address.
func (fd *FD) SetsockoptInet4Addr(level, name int, arg [4]byte) error {
if err := fd.incref(); err != nil {
return err
}
defer fd.decref()
return syscall.SetsockoptInet4Addr(fd.Sysfd, level, name, arg)
}
// SetsockoptLinger wraps the setsockopt network call with a Linger argument.
func (fd *FD) SetsockoptLinger(level, name int, l *syscall.Linger) error {
if err := fd.incref(); err != nil {
return err
}
defer fd.decref()
return syscall.SetsockoptLinger(fd.Sysfd, level, name, l)
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package poll
import "syscall"
// SetsockoptIPMreqn wraps the setsockopt network call with an IPMreqn argument.
func (fd *FD) SetsockoptIPMreqn(level, name int, mreq *syscall.IPMreqn) error {
if err := fd.incref(); err != nil {
return err
}
defer fd.decref()
return syscall.SetsockoptIPMreqn(fd.Sysfd, level, name, mreq)
}
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix
package poll
import "syscall"
// SetsockoptByte wraps the setsockopt network call with a byte argument.
func (fd *FD) SetsockoptByte(level, name int, arg byte) error {
if err := fd.incref(); err != nil {
return err
}
defer fd.decref()
return syscall.SetsockoptByte(fd.Sysfd, level, name, arg)
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || windows
package poll
import "syscall"
// SetsockoptIPMreq wraps the setsockopt network call with an IPMreq argument.
func (fd *FD) SetsockoptIPMreq(level, name int, mreq *syscall.IPMreq) error {
if err := fd.incref(); err != nil {
return err
}
defer fd.decref()
return syscall.SetsockoptIPMreq(fd.Sysfd, level, name, mreq)
}
// SetsockoptIPv6Mreq wraps the setsockopt network call with an IPv6Mreq argument.
func (fd *FD) SetsockoptIPv6Mreq(level, name int, mreq *syscall.IPv6Mreq) error {
if err := fd.incref(); err != nil {
return err
}
defer fd.decref()
return syscall.SetsockoptIPv6Mreq(fd.Sysfd, level, name, mreq)
}
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package poll
import (
"runtime"
"sync"
"syscall"
"unsafe"
)
const (
// spliceNonblock makes calls to splice(2) non-blocking.
spliceNonblock = 0x2
// maxSpliceSize is the maximum amount of data Splice asks
// the kernel to move in a single call to splice(2).
// We use 1MB as Splice writes data through a pipe, and 1MB is the default maximum pipe buffer size,
// which is determined by /proc/sys/fs/pipe-max-size.
maxSpliceSize = 1 << 20
)
// Splice transfers at most remain bytes of data from src to dst, using the
// splice system call to minimize copies of data from and to userspace.
//
// Splice gets a pipe buffer from the pool or creates a new one if needed, to serve as a buffer for the data transfer.
// src and dst must both be stream-oriented sockets.
//
// If err != nil, sc is the system call which caused the error.
func Splice(dst, src *FD, remain int64) (written int64, handled bool, sc string, err error) {
p, sc, err := getPipe()
if err != nil {
return 0, false, sc, err
}
defer putPipe(p)
var inPipe, n int
for err == nil && remain > 0 {
max := maxSpliceSize
if int64(max) > remain {
max = int(remain)
}
inPipe, err = spliceDrain(p.wfd, src, max)
// The operation is considered handled if splice returns no
// error, or an error other than EINVAL. An EINVAL means the
// kernel does not support splice for the socket type of src.
// The failed syscall does not consume any data so it is safe
// to fall back to a generic copy.
//
// spliceDrain should never return EAGAIN, so if err != nil,
// Splice cannot continue.
//
// If inPipe == 0 && err == nil, src is at EOF, and the
// transfer is complete.
handled = handled || (err != syscall.EINVAL)
if err != nil || inPipe == 0 {
break
}
p.data += inPipe
n, err = splicePump(dst, p.rfd, inPipe)
if n > 0 {
written += int64(n)
remain -= int64(n)
p.data -= n
}
}
if err != nil {
return written, handled, "splice", err
}
return written, true, "", nil
}
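// A caller-side sketch, assuming dst and src are stream-oriented
// socket *FD values (as the net package holds internally):
//
//	written, handled, sc, err := Splice(dst, src, remain)
//	if err != nil && !handled {
//		// The kernel rejected splice for these sockets (EINVAL from
//		// the call named by sc); fall back to a userspace copy.
//	}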
// spliceDrain moves data from a socket to a pipe.
//
// Invariant: when entering spliceDrain, the pipe is empty. It is either in its
// initial state, or splicePump has emptied it previously.
//
// Given this, spliceDrain can reasonably assume that the pipe is ready for
// writing, so if splice returns EAGAIN, it must be because the socket is not
// ready for reading.
//
// If spliceDrain returns (0, nil), src is at EOF.
func spliceDrain(pipefd int, sock *FD, max int) (int, error) {
if err := sock.readLock(); err != nil {
return 0, err
}
defer sock.readUnlock()
if err := sock.pd.prepareRead(sock.isFile); err != nil {
return 0, err
}
for {
n, err := splice(pipefd, sock.Sysfd, max, spliceNonblock)
if err == syscall.EINTR {
continue
}
if err != syscall.EAGAIN {
return n, err
}
if err := sock.pd.waitRead(sock.isFile); err != nil {
return n, err
}
}
}
// splicePump moves all the buffered data from a pipe to a socket.
//
// Invariant: when entering splicePump, there are exactly inPipe
// bytes of data in the pipe, from a previous call to spliceDrain.
//
// By analogy with the invariant in spliceDrain, splicePump only
// needs to poll the socket for writability if splice returns
// EAGAIN.
//
// If splicePump cannot move all the data in a single call to
// splice(2), it loops over the buffered data until it has written
// all of it to the socket. This behavior is similar to the Write
// step of an io.Copy in userspace.
func splicePump(sock *FD, pipefd int, inPipe int) (int, error) {
if err := sock.writeLock(); err != nil {
return 0, err
}
defer sock.writeUnlock()
if err := sock.pd.prepareWrite(sock.isFile); err != nil {
return 0, err
}
written := 0
for inPipe > 0 {
n, err := splice(sock.Sysfd, pipefd, inPipe, spliceNonblock)
// Here, the condition n == 0 && err == nil should never be
// observed, since Splice controls the write side of the pipe.
if n > 0 {
inPipe -= n
written += n
continue
}
if err != syscall.EAGAIN {
return written, err
}
if err := sock.pd.waitWrite(sock.isFile); err != nil {
return written, err
}
}
return written, nil
}
// splice wraps the splice system call. Since the current implementation
// only uses splice on sockets and pipes, the offset arguments are unused.
// splice returns int instead of int64, because callers never ask it to
// move more data in a single call than can fit in an int32.
func splice(out int, in int, max int, flags int) (int, error) {
n, err := syscall.Splice(in, nil, out, nil, max, flags)
return int(n), err
}
type splicePipeFields struct {
rfd int
wfd int
data int
}
type splicePipe struct {
splicePipeFields
// We want to use a finalizer, so ensure that the size is
// large enough to not use the tiny allocator.
_ [24 - unsafe.Sizeof(splicePipeFields{})%24]byte
}
// splicePipePool caches pipes to avoid high-frequency construction and destruction of pipe buffers.
// The garbage collector will free all pipes in the sync.Pool periodically, thus we need to set up
// a finalizer for each pipe to close its file descriptors before the actual GC.
var splicePipePool = sync.Pool{New: newPoolPipe}
func newPoolPipe() any {
// Discard any error from pipe creation and return nil, in which case
// data transmission falls back to the conventional read(2) and
// write(2) copy.
p := newPipe()
if p == nil {
return nil
}
runtime.SetFinalizer(p, destroyPipe)
return p
}
// getPipe tries to acquire a pipe buffer from the pool, creating a new one
// with newPipe() if the cache yields nil.
//
// Note that newPipe() may fail to create a pipe buffer, in which case getPipe()
// returns a generic error with the system call name "splice" as the indication.
func getPipe() (*splicePipe, string, error) {
v := splicePipePool.Get()
if v == nil {
return nil, "splice", syscall.EINVAL
}
return v.(*splicePipe), "", nil
}
func putPipe(p *splicePipe) {
// If there is still data left in the pipe,
// then close and discard it instead of putting it back into the pool.
if p.data != 0 {
runtime.SetFinalizer(p, nil)
destroyPipe(p)
return
}
splicePipePool.Put(p)
}
// newPipe sets up a pipe for a splice operation.
func newPipe() *splicePipe {
var fds [2]int
if err := syscall.Pipe2(fds[:], syscall.O_CLOEXEC|syscall.O_NONBLOCK); err != nil {
return nil
}
// Splice will loop writing maxSpliceSize bytes from the source to the pipe,
// and then write those bytes from the pipe to the destination.
// Set the pipe buffer size to maxSpliceSize to optimize that.
// Ignore errors here, as a smaller buffer size will work,
// although it will require more system calls.
fcntl(fds[0], syscall.F_SETPIPE_SZ, maxSpliceSize)
return &splicePipe{splicePipeFields: splicePipeFields{rfd: fds[0], wfd: fds[1]}}
}
// destroyPipe destroys a pipe.
func destroyPipe(p *splicePipe) {
CloseFunc(p.rfd)
CloseFunc(p.wfd)
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix
package poll
import (
"io"
"runtime"
"syscall"
)
// Writev wraps the writev system call.
func (fd *FD) Writev(v *[][]byte) (int64, error) {
if err := fd.writeLock(); err != nil {
return 0, err
}
defer fd.writeUnlock()
if err := fd.pd.prepareWrite(fd.isFile); err != nil {
return 0, err
}
var iovecs []syscall.Iovec
if fd.iovecs != nil {
iovecs = *fd.iovecs
}
// TODO: read from sysconf(_SC_IOV_MAX)? The Linux default is
// 1024 and this seems conservative enough for now. Darwin's
// UIO_MAXIOV also seems to be 1024.
maxVec := 1024
if runtime.GOOS == "aix" || runtime.GOOS == "solaris" {
// IOV_MAX is set to XOPEN_IOV_MAX on AIX and Solaris.
maxVec = 16
}
var n int64
var err error
for len(*v) > 0 {
iovecs = iovecs[:0]
for _, chunk := range *v {
if len(chunk) == 0 {
continue
}
iovecs = append(iovecs, newIovecWithBase(&chunk[0]))
if fd.IsStream && len(chunk) > 1<<30 {
iovecs[len(iovecs)-1].SetLen(1 << 30)
break // continue chunk on next writev
}
iovecs[len(iovecs)-1].SetLen(len(chunk))
if len(iovecs) == maxVec {
break
}
}
if len(iovecs) == 0 {
break
}
if fd.iovecs == nil {
fd.iovecs = new([]syscall.Iovec)
}
*fd.iovecs = iovecs // cache
var wrote uintptr
wrote, err = writev(fd.Sysfd, iovecs)
if wrote == ^uintptr(0) {
wrote = 0
}
TestHookDidWritev(int(wrote))
n += int64(wrote)
consume(v, int64(wrote))
for i := range iovecs {
iovecs[i] = syscall.Iovec{}
}
if err != nil {
if err == syscall.EINTR {
continue
}
if err == syscall.EAGAIN {
if err = fd.pd.waitWrite(fd.isFile); err == nil {
continue
}
}
break
}
if n == 0 {
err = io.ErrUnexpectedEOF
break
}
}
return n, err
}
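// A minimal usage sketch: Writev takes a pointer because consume trims
// the buffers in place as they are written, so after the call *v holds
// only the data that was not written (hdr and payload are assumed
// []byte values):
//
//	bufs := [][]byte{hdr, payload}
//	n, err := fd.Writev(&bufs)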
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package profile
import (
"errors"
"fmt"
"sort"
)
func (p *Profile) decoder() []decoder {
return profileDecoder
}
// preEncode populates the unexported fields to be used by encode
// (with suffix X) from the corresponding exported fields. The
// exported fields are cleared up to facilitate testing.
func (p *Profile) preEncode() {
strings := make(map[string]int)
addString(strings, "")
for _, st := range p.SampleType {
st.typeX = addString(strings, st.Type)
st.unitX = addString(strings, st.Unit)
}
for _, s := range p.Sample {
s.labelX = nil
var keys []string
for k := range s.Label {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
vs := s.Label[k]
for _, v := range vs {
s.labelX = append(s.labelX,
Label{
keyX: addString(strings, k),
strX: addString(strings, v),
},
)
}
}
var numKeys []string
for k := range s.NumLabel {
numKeys = append(numKeys, k)
}
sort.Strings(numKeys)
for _, k := range numKeys {
vs := s.NumLabel[k]
for _, v := range vs {
s.labelX = append(s.labelX,
Label{
keyX: addString(strings, k),
numX: v,
},
)
}
}
s.locationIDX = nil
for _, l := range s.Location {
s.locationIDX = append(s.locationIDX, l.ID)
}
}
for _, m := range p.Mapping {
m.fileX = addString(strings, m.File)
m.buildIDX = addString(strings, m.BuildID)
}
for _, l := range p.Location {
for i, ln := range l.Line {
if ln.Function != nil {
l.Line[i].functionIDX = ln.Function.ID
} else {
l.Line[i].functionIDX = 0
}
}
if l.Mapping != nil {
l.mappingIDX = l.Mapping.ID
} else {
l.mappingIDX = 0
}
}
for _, f := range p.Function {
f.nameX = addString(strings, f.Name)
f.systemNameX = addString(strings, f.SystemName)
f.filenameX = addString(strings, f.Filename)
}
p.dropFramesX = addString(strings, p.DropFrames)
p.keepFramesX = addString(strings, p.KeepFrames)
if pt := p.PeriodType; pt != nil {
pt.typeX = addString(strings, pt.Type)
pt.unitX = addString(strings, pt.Unit)
}
p.stringTable = make([]string, len(strings))
for s, i := range strings {
p.stringTable[i] = s
}
}
func (p *Profile) encode(b *buffer) {
for _, x := range p.SampleType {
encodeMessage(b, 1, x)
}
for _, x := range p.Sample {
encodeMessage(b, 2, x)
}
for _, x := range p.Mapping {
encodeMessage(b, 3, x)
}
for _, x := range p.Location {
encodeMessage(b, 4, x)
}
for _, x := range p.Function {
encodeMessage(b, 5, x)
}
encodeStrings(b, 6, p.stringTable)
encodeInt64Opt(b, 7, p.dropFramesX)
encodeInt64Opt(b, 8, p.keepFramesX)
encodeInt64Opt(b, 9, p.TimeNanos)
encodeInt64Opt(b, 10, p.DurationNanos)
if pt := p.PeriodType; pt != nil && (pt.typeX != 0 || pt.unitX != 0) {
encodeMessage(b, 11, p.PeriodType)
}
encodeInt64Opt(b, 12, p.Period)
}
var profileDecoder = []decoder{
nil, // 0
// repeated ValueType sample_type = 1
func(b *buffer, m message) error {
x := new(ValueType)
pp := m.(*Profile)
pp.SampleType = append(pp.SampleType, x)
return decodeMessage(b, x)
},
// repeated Sample sample = 2
func(b *buffer, m message) error {
x := new(Sample)
pp := m.(*Profile)
pp.Sample = append(pp.Sample, x)
return decodeMessage(b, x)
},
// repeated Mapping mapping = 3
func(b *buffer, m message) error {
x := new(Mapping)
pp := m.(*Profile)
pp.Mapping = append(pp.Mapping, x)
return decodeMessage(b, x)
},
// repeated Location location = 4
func(b *buffer, m message) error {
x := new(Location)
pp := m.(*Profile)
pp.Location = append(pp.Location, x)
return decodeMessage(b, x)
},
// repeated Function function = 5
func(b *buffer, m message) error {
x := new(Function)
pp := m.(*Profile)
pp.Function = append(pp.Function, x)
return decodeMessage(b, x)
},
// repeated string string_table = 6
func(b *buffer, m message) error {
err := decodeStrings(b, &m.(*Profile).stringTable)
if err != nil {
return err
}
if m.(*Profile).stringTable[0] != "" {
return errors.New("string_table[0] must be ''")
}
return nil
},
// repeated int64 drop_frames = 7
func(b *buffer, m message) error { return decodeInt64(b, &m.(*Profile).dropFramesX) },
// repeated int64 keep_frames = 8
func(b *buffer, m message) error { return decodeInt64(b, &m.(*Profile).keepFramesX) },
// repeated int64 time_nanos = 9
func(b *buffer, m message) error { return decodeInt64(b, &m.(*Profile).TimeNanos) },
// repeated int64 duration_nanos = 10
func(b *buffer, m message) error { return decodeInt64(b, &m.(*Profile).DurationNanos) },
// optional string period_type = 11
func(b *buffer, m message) error {
x := new(ValueType)
pp := m.(*Profile)
pp.PeriodType = x
return decodeMessage(b, x)
},
// repeated int64 period = 12
func(b *buffer, m message) error { return decodeInt64(b, &m.(*Profile).Period) },
// repeated int64 comment = 13
func(b *buffer, m message) error { return decodeInt64s(b, &m.(*Profile).commentX) },
// int64 defaultSampleType = 14
func(b *buffer, m message) error { return decodeInt64(b, &m.(*Profile).defaultSampleTypeX) },
}
// postDecode takes the unexported fields populated by decode (with
// suffix X) and populates the corresponding exported fields.
// The unexported fields are cleared up to facilitate testing.
func (p *Profile) postDecode() error {
var err error
mappings := make(map[uint64]*Mapping)
for _, m := range p.Mapping {
m.File, err = getString(p.stringTable, &m.fileX, err)
m.BuildID, err = getString(p.stringTable, &m.buildIDX, err)
mappings[m.ID] = m
}
functions := make(map[uint64]*Function)
for _, f := range p.Function {
f.Name, err = getString(p.stringTable, &f.nameX, err)
f.SystemName, err = getString(p.stringTable, &f.systemNameX, err)
f.Filename, err = getString(p.stringTable, &f.filenameX, err)
functions[f.ID] = f
}
locations := make(map[uint64]*Location)
for _, l := range p.Location {
l.Mapping = mappings[l.mappingIDX]
l.mappingIDX = 0
for i, ln := range l.Line {
if id := ln.functionIDX; id != 0 {
l.Line[i].Function = functions[id]
if l.Line[i].Function == nil {
return fmt.Errorf("Function ID %d not found", id)
}
l.Line[i].functionIDX = 0
}
}
locations[l.ID] = l
}
for _, st := range p.SampleType {
st.Type, err = getString(p.stringTable, &st.typeX, err)
st.Unit, err = getString(p.stringTable, &st.unitX, err)
}
for _, s := range p.Sample {
labels := make(map[string][]string)
numLabels := make(map[string][]int64)
for _, l := range s.labelX {
var key, value string
key, err = getString(p.stringTable, &l.keyX, err)
if l.strX != 0 {
value, err = getString(p.stringTable, &l.strX, err)
labels[key] = append(labels[key], value)
} else {
numLabels[key] = append(numLabels[key], l.numX)
}
}
if len(labels) > 0 {
s.Label = labels
}
if len(numLabels) > 0 {
s.NumLabel = numLabels
}
s.Location = nil
for _, lid := range s.locationIDX {
s.Location = append(s.Location, locations[lid])
}
s.locationIDX = nil
}
p.DropFrames, err = getString(p.stringTable, &p.dropFramesX, err)
p.KeepFrames, err = getString(p.stringTable, &p.keepFramesX, err)
if pt := p.PeriodType; pt == nil {
p.PeriodType = &ValueType{}
}
if pt := p.PeriodType; pt != nil {
pt.Type, err = getString(p.stringTable, &pt.typeX, err)
pt.Unit, err = getString(p.stringTable, &pt.unitX, err)
}
for _, i := range p.commentX {
var c string
c, err = getString(p.stringTable, &i, err)
p.Comments = append(p.Comments, c)
}
p.commentX = nil
p.DefaultSampleType, err = getString(p.stringTable, &p.defaultSampleTypeX, err)
p.stringTable = nil
return nil
}
func (p *ValueType) decoder() []decoder {
return valueTypeDecoder
}
func (p *ValueType) encode(b *buffer) {
encodeInt64Opt(b, 1, p.typeX)
encodeInt64Opt(b, 2, p.unitX)
}
var valueTypeDecoder = []decoder{
nil, // 0
// optional int64 type = 1
func(b *buffer, m message) error { return decodeInt64(b, &m.(*ValueType).typeX) },
// optional int64 unit = 2
func(b *buffer, m message) error { return decodeInt64(b, &m.(*ValueType).unitX) },
}
func (p *Sample) decoder() []decoder {
return sampleDecoder
}
func (p *Sample) encode(b *buffer) {
encodeUint64s(b, 1, p.locationIDX)
for _, x := range p.Value {
encodeInt64(b, 2, x)
}
for _, x := range p.labelX {
encodeMessage(b, 3, x)
}
}
var sampleDecoder = []decoder{
nil, // 0
// repeated uint64 location = 1
func(b *buffer, m message) error { return decodeUint64s(b, &m.(*Sample).locationIDX) },
// repeated int64 value = 2
func(b *buffer, m message) error { return decodeInt64s(b, &m.(*Sample).Value) },
// repeated Label label = 3
func(b *buffer, m message) error {
s := m.(*Sample)
n := len(s.labelX)
s.labelX = append(s.labelX, Label{})
return decodeMessage(b, &s.labelX[n])
},
}
func (p Label) decoder() []decoder {
return labelDecoder
}
func (p Label) encode(b *buffer) {
encodeInt64Opt(b, 1, p.keyX)
encodeInt64Opt(b, 2, p.strX)
encodeInt64Opt(b, 3, p.numX)
}
var labelDecoder = []decoder{
nil, // 0
// optional int64 key = 1
func(b *buffer, m message) error { return decodeInt64(b, &m.(*Label).keyX) },
// optional int64 str = 2
func(b *buffer, m message) error { return decodeInt64(b, &m.(*Label).strX) },
// optional int64 num = 3
func(b *buffer, m message) error { return decodeInt64(b, &m.(*Label).numX) },
}
func (p *Mapping) decoder() []decoder {
return mappingDecoder
}
func (p *Mapping) encode(b *buffer) {
encodeUint64Opt(b, 1, p.ID)
encodeUint64Opt(b, 2, p.Start)
encodeUint64Opt(b, 3, p.Limit)
encodeUint64Opt(b, 4, p.Offset)
encodeInt64Opt(b, 5, p.fileX)
encodeInt64Opt(b, 6, p.buildIDX)
encodeBoolOpt(b, 7, p.HasFunctions)
encodeBoolOpt(b, 8, p.HasFilenames)
encodeBoolOpt(b, 9, p.HasLineNumbers)
encodeBoolOpt(b, 10, p.HasInlineFrames)
}
var mappingDecoder = []decoder{
nil, // 0
func(b *buffer, m message) error { return decodeUint64(b, &m.(*Mapping).ID) }, // optional uint64 id = 1
func(b *buffer, m message) error { return decodeUint64(b, &m.(*Mapping).Start) }, // optional uint64 memory_offset = 2
func(b *buffer, m message) error { return decodeUint64(b, &m.(*Mapping).Limit) }, // optional uint64 memory_limit = 3
func(b *buffer, m message) error { return decodeUint64(b, &m.(*Mapping).Offset) }, // optional uint64 file_offset = 4
func(b *buffer, m message) error { return decodeInt64(b, &m.(*Mapping).fileX) }, // optional int64 filename = 5
func(b *buffer, m message) error { return decodeInt64(b, &m.(*Mapping).buildIDX) }, // optional int64 build_id = 6
func(b *buffer, m message) error { return decodeBool(b, &m.(*Mapping).HasFunctions) }, // optional bool has_functions = 7
func(b *buffer, m message) error { return decodeBool(b, &m.(*Mapping).HasFilenames) }, // optional bool has_filenames = 8
func(b *buffer, m message) error { return decodeBool(b, &m.(*Mapping).HasLineNumbers) }, // optional bool has_line_numbers = 9
func(b *buffer, m message) error { return decodeBool(b, &m.(*Mapping).HasInlineFrames) }, // optional bool has_inline_frames = 10
}
func (p *Location) decoder() []decoder {
return locationDecoder
}
func (p *Location) encode(b *buffer) {
encodeUint64Opt(b, 1, p.ID)
encodeUint64Opt(b, 2, p.mappingIDX)
encodeUint64Opt(b, 3, p.Address)
for i := range p.Line {
encodeMessage(b, 4, &p.Line[i])
}
}
var locationDecoder = []decoder{
nil, // 0
func(b *buffer, m message) error { return decodeUint64(b, &m.(*Location).ID) }, // optional uint64 id = 1;
func(b *buffer, m message) error { return decodeUint64(b, &m.(*Location).mappingIDX) }, // optional uint64 mapping_id = 2;
func(b *buffer, m message) error { return decodeUint64(b, &m.(*Location).Address) }, // optional uint64 address = 3;
func(b *buffer, m message) error { // repeated Line line = 4
pp := m.(*Location)
n := len(pp.Line)
pp.Line = append(pp.Line, Line{})
return decodeMessage(b, &pp.Line[n])
},
}
func (p *Line) decoder() []decoder {
return lineDecoder
}
func (p *Line) encode(b *buffer) {
encodeUint64Opt(b, 1, p.functionIDX)
encodeInt64Opt(b, 2, p.Line)
}
var lineDecoder = []decoder{
nil, // 0
// optional uint64 function_id = 1
func(b *buffer, m message) error { return decodeUint64(b, &m.(*Line).functionIDX) },
// optional int64 line = 2
func(b *buffer, m message) error { return decodeInt64(b, &m.(*Line).Line) },
}
func (p *Function) decoder() []decoder {
return functionDecoder
}
func (p *Function) encode(b *buffer) {
encodeUint64Opt(b, 1, p.ID)
encodeInt64Opt(b, 2, p.nameX)
encodeInt64Opt(b, 3, p.systemNameX)
encodeInt64Opt(b, 4, p.filenameX)
encodeInt64Opt(b, 5, p.StartLine)
}
var functionDecoder = []decoder{
nil, // 0
// optional uint64 id = 1
func(b *buffer, m message) error { return decodeUint64(b, &m.(*Function).ID) },
// optional int64 function_name = 2
func(b *buffer, m message) error { return decodeInt64(b, &m.(*Function).nameX) },
// optional int64 function_system_name = 3
func(b *buffer, m message) error { return decodeInt64(b, &m.(*Function).systemNameX) },
// repeated int64 filename = 4
func(b *buffer, m message) error { return decodeInt64(b, &m.(*Function).filenameX) },
// optional int64 start_line = 5
func(b *buffer, m message) error { return decodeInt64(b, &m.(*Function).StartLine) },
}
func addString(strings map[string]int, s string) int64 {
i, ok := strings[s]
if !ok {
i = len(strings)
strings[s] = i
}
return int64(i)
}
func getString(strings []string, strng *int64, err error) (string, error) {
if err != nil {
return "", err
}
s := int(*strng)
if s < 0 || s >= len(strings) {
return "", errMalformed
}
*strng = 0
return strings[s], nil
}
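// A round-trip sketch of the string-table helpers above, mirroring how
// preEncode builds the table and postDecode consumes it:
//
//	strs := make(map[string]int)
//	addString(strs, "")          // index 0 must be the empty string
//	i := addString(strs, "cpu")  // i == 1
//	table := make([]string, len(strs))
//	for s, idx := range strs {
//		table[idx] = s
//	}
//	s, err := getString(table, &i, nil) // s == "cpu"; i is reset to 0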
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Implements methods to filter samples from profiles.
package profile
import "regexp"
// FilterSamplesByName filters the samples in a profile and only keeps
// samples where at least one frame matches focus but none match ignore.
// Returns true if the corresponding regexp matched at least one sample.
func (p *Profile) FilterSamplesByName(focus, ignore, hide *regexp.Regexp) (fm, im, hm bool) {
focusOrIgnore := make(map[uint64]bool)
hidden := make(map[uint64]bool)
for _, l := range p.Location {
if ignore != nil && l.matchesName(ignore) {
im = true
focusOrIgnore[l.ID] = false
} else if focus == nil || l.matchesName(focus) {
fm = true
focusOrIgnore[l.ID] = true
}
if hide != nil && l.matchesName(hide) {
hm = true
l.Line = l.unmatchedLines(hide)
if len(l.Line) == 0 {
hidden[l.ID] = true
}
}
}
s := make([]*Sample, 0, len(p.Sample))
for _, sample := range p.Sample {
if focusedAndNotIgnored(sample.Location, focusOrIgnore) {
if len(hidden) > 0 {
var locs []*Location
for _, loc := range sample.Location {
if !hidden[loc.ID] {
locs = append(locs, loc)
}
}
if len(locs) == 0 {
// Remove sample with no locations (by not adding it to s).
continue
}
sample.Location = locs
}
s = append(s, sample)
}
}
p.Sample = s
return
}
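// A minimal usage sketch, assuming p is a parsed *Profile:
//
//	focus := regexp.MustCompile(`runtime\.mallocgc`)
//	fm, _, _ := p.FilterSamplesByName(focus, nil, nil)
//	// fm reports whether any sample had a frame matching focus;
//	// p.Sample now holds only the matching samples.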
// matchesName reports whether the function name or file in the
// location matches the regular expression.
func (loc *Location) matchesName(re *regexp.Regexp) bool {
for _, ln := range loc.Line {
if fn := ln.Function; fn != nil {
if re.MatchString(fn.Name) {
return true
}
if re.MatchString(fn.Filename) {
return true
}
}
}
return false
}
// unmatchedLines returns the lines in the location that do not match
// the regular expression.
func (loc *Location) unmatchedLines(re *regexp.Regexp) []Line {
var lines []Line
for _, ln := range loc.Line {
if fn := ln.Function; fn != nil {
if re.MatchString(fn.Name) {
continue
}
if re.MatchString(fn.Filename) {
continue
}
}
lines = append(lines, ln)
}
return lines
}
// focusedAndNotIgnored looks up a slice of ids against a map of
// focused/ignored locations. The map only contains locations that are
// explicitly focused or ignored. Returns whether there is at least
// one focused location but no ignored locations.
func focusedAndNotIgnored(locs []*Location, m map[uint64]bool) bool {
var f bool
for _, loc := range locs {
if focus, focusOrIgnore := m[loc.ID]; focusOrIgnore {
if focus {
// Found focused location. Must keep searching in case there
// is an ignored one as well.
f = true
} else {
// Found ignored location. Can return false right away.
return false
}
}
}
return f
}
// TagMatch selects tags for filtering
type TagMatch func(key, val string, nval int64) bool
// FilterSamplesByTag removes all samples from the profile, except
// those that match focus and do not match the ignore regular
// expression.
func (p *Profile) FilterSamplesByTag(focus, ignore TagMatch) (fm, im bool) {
samples := make([]*Sample, 0, len(p.Sample))
for _, s := range p.Sample {
focused, ignored := focusedSample(s, focus, ignore)
fm = fm || focused
im = im || ignored
if focused && !ignored {
samples = append(samples, s)
}
}
p.Sample = samples
return
}
// focusedSample checks a sample against focus and ignore regexps.
// Returns whether the focus/ignore regexps match any tags.
func focusedSample(s *Sample, focus, ignore TagMatch) (fm, im bool) {
fm = focus == nil
for key, vals := range s.Label {
for _, val := range vals {
if ignore != nil && ignore(key, val, 0) {
im = true
}
if !fm && focus(key, val, 0) {
fm = true
}
}
}
for key, vals := range s.NumLabel {
for _, val := range vals {
if ignore != nil && ignore(key, "", val) {
im = true
}
if !fm && focus(key, "", val) {
fm = true
}
}
}
return fm, im
}
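// A minimal TagMatch sketch: keep only samples whose "bytes" numeric
// label (set by parseHeapSample below) is at least 1MiB:
//
//	focus := func(key, val string, nval int64) bool {
//		return key == "bytes" && nval >= 1<<20
//	}
//	fm, _ := p.FilterSamplesByTag(focus, nil)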
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements parsers to convert legacy profiles into the
// profile.proto format.
package profile
import (
"bufio"
"bytes"
"fmt"
"internal/lazyregexp"
"io"
"math"
"strconv"
"strings"
)
var (
countStartRE = lazyregexp.New(`\A(\w+) profile: total \d+\n\z`)
countRE = lazyregexp.New(`\A(\d+) @(( 0x[0-9a-f]+)+)\n\z`)
heapHeaderRE = lazyregexp.New(`heap profile: *(\d+): *(\d+) *\[ *(\d+): *(\d+) *\] *@ *(heap[_a-z0-9]*)/?(\d*)`)
heapSampleRE = lazyregexp.New(`(-?\d+): *(-?\d+) *\[ *(\d+): *(\d+) *] @([ x0-9a-f]*)`)
contentionSampleRE = lazyregexp.New(`(\d+) *(\d+) @([ x0-9a-f]*)`)
hexNumberRE = lazyregexp.New(`0x[0-9a-f]+`)
growthHeaderRE = lazyregexp.New(`heap profile: *(\d+): *(\d+) *\[ *(\d+): *(\d+) *\] @ growthz`)
fragmentationHeaderRE = lazyregexp.New(`heap profile: *(\d+): *(\d+) *\[ *(\d+): *(\d+) *\] @ fragmentationz`)
threadzStartRE = lazyregexp.New(`--- threadz \d+ ---`)
threadStartRE = lazyregexp.New(`--- Thread ([[:xdigit:]]+) \(name: (.*)/(\d+)\) stack: ---`)
procMapsRE = lazyregexp.New(`([[:xdigit:]]+)-([[:xdigit:]]+)\s+([-rwxp]+)\s+([[:xdigit:]]+)\s+([[:xdigit:]]+):([[:xdigit:]]+)\s+([[:digit:]]+)\s*(\S+)?`)
briefMapsRE = lazyregexp.New(`\s*([[:xdigit:]]+)-([[:xdigit:]]+):\s*(\S+)(\s.*@)?([[:xdigit:]]+)?`)
// LegacyHeapAllocated instructs the heapz parsers to use the
// allocated memory stats instead of the default in-use memory. Note
// that tcmalloc doesn't provide all allocated memory, only in-use
// stats.
LegacyHeapAllocated bool
)
func isSpaceOrComment(line string) bool {
trimmed := strings.TrimSpace(line)
return len(trimmed) == 0 || trimmed[0] == '#'
}
// parseGoCount parses a Go count profile (e.g., threadcreate or
// goroutine) and returns a new Profile.
func parseGoCount(b []byte) (*Profile, error) {
r := bytes.NewBuffer(b)
var line string
var err error
for {
// Skip past comments and empty lines seeking a real header.
line, err = r.ReadString('\n')
if err != nil {
return nil, err
}
if !isSpaceOrComment(line) {
break
}
}
m := countStartRE.FindStringSubmatch(line)
if m == nil {
return nil, errUnrecognized
}
profileType := m[1]
p := &Profile{
PeriodType: &ValueType{Type: profileType, Unit: "count"},
Period: 1,
SampleType: []*ValueType{{Type: profileType, Unit: "count"}},
}
locations := make(map[uint64]*Location)
for {
line, err = r.ReadString('\n')
if err != nil {
if err == io.EOF {
break
}
return nil, err
}
if isSpaceOrComment(line) {
continue
}
if strings.HasPrefix(line, "---") {
break
}
m := countRE.FindStringSubmatch(line)
if m == nil {
return nil, errMalformed
}
n, err := strconv.ParseInt(m[1], 0, 64)
if err != nil {
return nil, errMalformed
}
fields := strings.Fields(m[2])
locs := make([]*Location, 0, len(fields))
for _, stk := range fields {
addr, err := strconv.ParseUint(stk, 0, 64)
if err != nil {
return nil, errMalformed
}
// Adjust all frames by -1 to land on the call instruction.
addr--
loc := locations[addr]
if loc == nil {
loc = &Location{
Address: addr,
}
locations[addr] = loc
p.Location = append(p.Location, loc)
}
locs = append(locs, loc)
}
p.Sample = append(p.Sample, &Sample{
Location: locs,
Value: []int64{n},
})
}
if err = parseAdditionalSections(strings.TrimSpace(line), r, p); err != nil {
return nil, err
}
return p, nil
}
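// For reference, the legacy text format accepted by parseGoCount looks
// like this (matching countStartRE and countRE; addresses shortened):
//
//	goroutine profile: total 2
//	1 @ 0x42e15a 0x42e1d1
//	1 @ 0x42a1f4 0x42e3a0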
// remapLocationIDs ensures there is a location for each address
// referenced by a sample, and remaps the samples to point to the new
// location ids.
func (p *Profile) remapLocationIDs() {
seen := make(map[*Location]bool, len(p.Location))
var locs []*Location
for _, s := range p.Sample {
for _, l := range s.Location {
if seen[l] {
continue
}
l.ID = uint64(len(locs) + 1)
locs = append(locs, l)
seen[l] = true
}
}
p.Location = locs
}
func (p *Profile) remapFunctionIDs() {
seen := make(map[*Function]bool, len(p.Function))
var fns []*Function
for _, l := range p.Location {
for _, ln := range l.Line {
fn := ln.Function
if fn == nil || seen[fn] {
continue
}
fn.ID = uint64(len(fns) + 1)
fns = append(fns, fn)
seen[fn] = true
}
}
p.Function = fns
}
// remapMappingIDs matches location addresses with existing mappings
// and updates them appropriately. This is O(N*M); if this ever shows
// up as a bottleneck, evaluate sorting the mappings and doing a
// binary search, which would make it O(N*log(M)).
func (p *Profile) remapMappingIDs() {
if len(p.Mapping) == 0 {
return
}
// Some profile handlers will incorrectly set regions for the main
// executable if its section is remapped. Fix them through heuristics.
// Remove the initial mapping if named '/anon_hugepage' and has a
// consecutive adjacent mapping.
if m := p.Mapping[0]; strings.HasPrefix(m.File, "/anon_hugepage") {
if len(p.Mapping) > 1 && m.Limit == p.Mapping[1].Start {
p.Mapping = p.Mapping[1:]
}
}
for _, l := range p.Location {
if a := l.Address; a != 0 {
for _, m := range p.Mapping {
if m.Start <= a && a < m.Limit {
l.Mapping = m
break
}
}
}
}
// Reset all mapping IDs.
for i, m := range p.Mapping {
m.ID = uint64(i + 1)
}
}
var cpuInts = []func([]byte) (uint64, []byte){
get32l,
get32b,
get64l,
get64b,
}
func get32l(b []byte) (uint64, []byte) {
if len(b) < 4 {
return 0, nil
}
return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24, b[4:]
}
func get32b(b []byte) (uint64, []byte) {
if len(b) < 4 {
return 0, nil
}
return uint64(b[3]) | uint64(b[2])<<8 | uint64(b[1])<<16 | uint64(b[0])<<24, b[4:]
}
func get64l(b []byte) (uint64, []byte) {
if len(b) < 8 {
return 0, nil
}
return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 | uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56, b[8:]
}
func get64b(b []byte) (uint64, []byte) {
if len(b) < 8 {
return 0, nil
}
return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 | uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56, b[8:]
}
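// For example, the bytes {0x01, 0x00, 0x00, 0x00} decode to 1 via
// get32l (little-endian) but to 1<<24 via get32b (big-endian);
// parseCPU below simply tries each decoder until the header words
// take the expected values.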
// ParseTracebacks parses a set of tracebacks and returns a newly
// populated profile. It will accept any text file and generate a
// Profile out of it with any hex addresses it can identify, including
// a process map if it can recognize one. Each sample will include a
// tag "source" with the addresses recognized in string format.
func ParseTracebacks(b []byte) (*Profile, error) {
r := bytes.NewBuffer(b)
p := &Profile{
PeriodType: &ValueType{Type: "trace", Unit: "count"},
Period: 1,
SampleType: []*ValueType{
{Type: "trace", Unit: "count"},
},
}
var sources []string
var sloc []*Location
locs := make(map[uint64]*Location)
for {
l, err := r.ReadString('\n')
if err != nil {
if err != io.EOF {
return nil, err
}
if l == "" {
break
}
}
if sectionTrigger(l) == memoryMapSection {
break
}
if s, addrs := extractHexAddresses(l); len(s) > 0 {
for _, addr := range addrs {
// Addresses from stack traces point to the next instruction after
// each call. Adjust by -1 to land somewhere on the actual call.
addr--
loc := locs[addr]
if locs[addr] == nil {
loc = &Location{
Address: addr,
}
p.Location = append(p.Location, loc)
locs[addr] = loc
}
sloc = append(sloc, loc)
}
sources = append(sources, s...)
} else {
if len(sources) > 0 || len(sloc) > 0 {
addTracebackSample(sloc, sources, p)
sloc, sources = nil, nil
}
}
}
// Add final sample to save any leftover data.
if len(sources) > 0 || len(sloc) > 0 {
addTracebackSample(sloc, sources, p)
}
if err := p.ParseMemoryMap(r); err != nil {
return nil, err
}
return p, nil
}
func addTracebackSample(l []*Location, s []string, p *Profile) {
p.Sample = append(p.Sample,
&Sample{
Value: []int64{1},
Location: l,
Label: map[string][]string{"source": s},
})
}
// parseCPU parses a profilez legacy profile and returns a newly
// populated Profile.
//
// The general format for profilez samples is a sequence of words in
// binary format. The first words are a header with the following data:
//
// 1st word -- 0
// 2nd word -- 3
// 3rd word -- 0 if a c++ application, 1 if a java application.
// 4th word -- Sampling period (in microseconds).
// 5th word -- Padding.
func parseCPU(b []byte) (*Profile, error) {
var parse func([]byte) (uint64, []byte)
var n1, n2, n3, n4, n5 uint64
for _, parse = range cpuInts {
var tmp []byte
n1, tmp = parse(b)
n2, tmp = parse(tmp)
n3, tmp = parse(tmp)
n4, tmp = parse(tmp)
n5, tmp = parse(tmp)
if tmp != nil && n1 == 0 && n2 == 3 && n3 == 0 && n4 > 0 && n5 == 0 {
b = tmp
return cpuProfile(b, int64(n4), parse)
}
}
return nil, errUnrecognized
}
// cpuProfile returns a new Profile from C++ profilez data.
// b is the profile bytes after the header, period is the profiling
// period, and parse is a function to parse 8-byte chunks from the
// profile in its native endianness.
func cpuProfile(b []byte, period int64, parse func(b []byte) (uint64, []byte)) (*Profile, error) {
p := &Profile{
Period: period * 1000,
PeriodType: &ValueType{Type: "cpu", Unit: "nanoseconds"},
SampleType: []*ValueType{
{Type: "samples", Unit: "count"},
{Type: "cpu", Unit: "nanoseconds"},
},
}
var err error
if b, _, err = parseCPUSamples(b, parse, true, p); err != nil {
return nil, err
}
// If all samples have the same second-to-the-bottom frame, it
// strongly suggests that it is an uninteresting artifact of
// measurement -- a stack frame pushed by the signal handler. The
// bottom frame is always correct as it is picked up from the signal
// structure, not the stack. Check if this is the case and if so,
// remove.
if len(p.Sample) > 1 && len(p.Sample[0].Location) > 1 {
allSame := true
id1 := p.Sample[0].Location[1].Address
for _, s := range p.Sample {
if len(s.Location) < 2 || id1 != s.Location[1].Address {
allSame = false
break
}
}
if allSame {
for _, s := range p.Sample {
s.Location = append(s.Location[:1], s.Location[2:]...)
}
}
}
if err := p.ParseMemoryMap(bytes.NewBuffer(b)); err != nil {
return nil, err
}
return p, nil
}
// parseCPUSamples parses a collection of profilez samples from a
// profile.
//
// profilez samples are a repeated sequence of stack frames of the
// form:
//
// 1st word -- The number of times this stack was encountered.
// 2nd word -- The size of the stack (StackSize).
// 3rd word -- The first address on the stack.
// ...
// StackSize + 2 -- The last address on the stack
//
// The last stack trace is of the form:
//
// 1st word -- 0
// 2nd word -- 1
// 3rd word -- 0
//
// Addresses from stack traces may point to the next instruction after
// each call. Optionally adjust by -1 to land somewhere on the actual
// call (except for the leaf, which is not a call).
func parseCPUSamples(b []byte, parse func(b []byte) (uint64, []byte), adjust bool, p *Profile) ([]byte, map[uint64]*Location, error) {
locs := make(map[uint64]*Location)
for len(b) > 0 {
var count, nstk uint64
count, b = parse(b)
nstk, b = parse(b)
if b == nil || nstk > uint64(len(b)/4) {
return nil, nil, errUnrecognized
}
var sloc []*Location
addrs := make([]uint64, nstk)
for i := 0; i < int(nstk); i++ {
addrs[i], b = parse(b)
}
if count == 0 && nstk == 1 && addrs[0] == 0 {
// End of data marker
break
}
for i, addr := range addrs {
if adjust && i > 0 {
addr--
}
loc := locs[addr]
if loc == nil {
loc = &Location{
Address: addr,
}
locs[addr] = loc
p.Location = append(p.Location, loc)
}
sloc = append(sloc, loc)
}
p.Sample = append(p.Sample,
&Sample{
Value: []int64{int64(count), int64(count) * p.Period},
Location: sloc,
})
}
// Reached the end without finding the EOD marker.
return b, locs, nil
}
// parseHeap parses a heapz legacy or a growthz profile and
// returns a newly populated Profile.
func parseHeap(b []byte) (p *Profile, err error) {
r := bytes.NewBuffer(b)
l, err := r.ReadString('\n')
if err != nil {
return nil, errUnrecognized
}
sampling := ""
if header := heapHeaderRE.FindStringSubmatch(l); header != nil {
p = &Profile{
SampleType: []*ValueType{
{Type: "objects", Unit: "count"},
{Type: "space", Unit: "bytes"},
},
PeriodType: &ValueType{Type: "objects", Unit: "bytes"},
}
var period int64
if len(header[6]) > 0 {
if period, err = strconv.ParseInt(header[6], 10, 64); err != nil {
return nil, errUnrecognized
}
}
switch header[5] {
case "heapz_v2", "heap_v2":
sampling, p.Period = "v2", period
case "heapprofile":
sampling, p.Period = "", 1
case "heap":
sampling, p.Period = "v2", period/2
default:
return nil, errUnrecognized
}
} else if header = growthHeaderRE.FindStringSubmatch(l); header != nil {
p = &Profile{
SampleType: []*ValueType{
{Type: "objects", Unit: "count"},
{Type: "space", Unit: "bytes"},
},
PeriodType: &ValueType{Type: "heapgrowth", Unit: "count"},
Period: 1,
}
} else if header = fragmentationHeaderRE.FindStringSubmatch(l); header != nil {
p = &Profile{
SampleType: []*ValueType{
{Type: "objects", Unit: "count"},
{Type: "space", Unit: "bytes"},
},
PeriodType: &ValueType{Type: "allocations", Unit: "count"},
Period: 1,
}
} else {
return nil, errUnrecognized
}
if LegacyHeapAllocated {
for _, st := range p.SampleType {
st.Type = "alloc_" + st.Type
}
} else {
for _, st := range p.SampleType {
st.Type = "inuse_" + st.Type
}
}
locs := make(map[uint64]*Location)
for {
l, err = r.ReadString('\n')
if err != nil {
if err != io.EOF {
return nil, err
}
if l == "" {
break
}
}
if isSpaceOrComment(l) {
continue
}
l = strings.TrimSpace(l)
if sectionTrigger(l) != unrecognizedSection {
break
}
value, blocksize, addrs, err := parseHeapSample(l, p.Period, sampling)
if err != nil {
return nil, err
}
var sloc []*Location
for _, addr := range addrs {
// Addresses from stack traces point to the next instruction after
// each call. Adjust by -1 to land somewhere on the actual call.
addr--
loc := locs[addr]
if locs[addr] == nil {
loc = &Location{
Address: addr,
}
p.Location = append(p.Location, loc)
locs[addr] = loc
}
sloc = append(sloc, loc)
}
p.Sample = append(p.Sample, &Sample{
Value: value,
Location: sloc,
NumLabel: map[string][]int64{"bytes": {blocksize}},
})
}
if err = parseAdditionalSections(l, r, p); err != nil {
return nil, err
}
return p, nil
}
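// For reference, a heapz v2 header line accepted above looks roughly like
// this (an illustrative example, not taken from a real profile):
//
//	heap profile: 3: 12288 [6: 24576] @ heap_v2/524288
//
// The token after "@" selects the sampling scheme (header[5]) and the
// trailing number is the sampling period in bytes (header[6]).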
// parseHeapSample parses a single row from a heap profile into a new Sample.
func parseHeapSample(line string, rate int64, sampling string) (value []int64, blocksize int64, addrs []uint64, err error) {
sampleData := heapSampleRE.FindStringSubmatch(line)
if len(sampleData) != 6 {
return value, blocksize, addrs, fmt.Errorf("unexpected number of sample values: got %d, want 6", len(sampleData))
}
// Use the first two values by default; tcmalloc sampling generates the
// same value for both. Only the older heap-profile format collects
// separate stats for in-use and allocated objects.
valueIndex := 1
if LegacyHeapAllocated {
valueIndex = 3
}
var v1, v2 int64
if v1, err = strconv.ParseInt(sampleData[valueIndex], 10, 64); err != nil {
return value, blocksize, addrs, fmt.Errorf("malformed sample: %s: %v", line, err)
}
if v2, err = strconv.ParseInt(sampleData[valueIndex+1], 10, 64); err != nil {
return value, blocksize, addrs, fmt.Errorf("malformed sample: %s: %v", line, err)
}
if v1 == 0 {
if v2 != 0 {
return value, blocksize, addrs, fmt.Errorf("allocation count was 0 but allocation bytes was %d", v2)
}
} else {
blocksize = v2 / v1
if sampling == "v2" {
v1, v2 = scaleHeapSample(v1, v2, rate)
}
}
value = []int64{v1, v2}
addrs = parseHexAddresses(sampleData[5])
return value, blocksize, addrs, nil
}
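// A sample row in that format has roughly this shape (illustrative):
//
//	1: 262144 [4: 1048576] @ 0x42c542 0x42d55e 0x45c879
//
// i.e. in-use objects and bytes, then allocated objects and bytes in
// brackets, then the stack addresses. valueIndex selects the in-use pair
// (1) or the allocated pair (3) depending on LegacyHeapAllocated.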
// extractHexAddresses extracts hex numbers from a string and returns
// them as strings along with their parsed numeric values.
func extractHexAddresses(s string) ([]string, []uint64) {
hexStrings := hexNumberRE.FindAllString(s, -1)
var ids []uint64
for _, s := range hexStrings {
if id, err := strconv.ParseUint(s, 0, 64); err == nil {
ids = append(ids, id)
} else {
// Do not expect any parsing failures due to the regexp matching.
panic("failed to parse hex value:" + s)
}
}
return hexStrings, ids
}
// parseHexAddresses parses hex numbers from a string and returns them
// in a slice.
func parseHexAddresses(s string) []uint64 {
_, ids := extractHexAddresses(s)
return ids
}
// scaleHeapSample adjusts the data from a heapz Sample to
// account for its probability of appearing in the collected
// data. heapz profiles are a sampling of the memory allocation
// requests in a program. We estimate the unsampled value by dividing
// each collected sample by its probability of appearing in the
// profile. heapz v2 profiles rely on a Poisson process to determine
// which samples to collect, based on the desired average collection
// rate R. The probability of a sample of size S appearing in that
// profile is 1-exp(-S/R).
func scaleHeapSample(count, size, rate int64) (int64, int64) {
if count == 0 || size == 0 {
return 0, 0
}
if rate <= 1 {
// if rate==1 all samples were collected so no adjustment is needed.
// if rate<1 treat as unknown and skip scaling.
return count, size
}
avgSize := float64(size) / float64(count)
scale := 1 / (1 - math.Exp(-avgSize/float64(rate)))
return int64(float64(count) * scale), int64(float64(size) * scale)
}
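// A quick worked example of the scaling above: with rate R = 524288 and a
// single sampled allocation of size S = 262144, the probability of the
// sample being collected is 1-exp(-262144/524288), about 0.393, so the
// scale factor is about 2.54 and scaleHeapSample(1, 262144, 524288)
// reports a count of 2 and roughly 666237 bytes after truncation.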
// parseContention parses a mutex or contention profile. There are two cases:
// "--- contentionz " for legacy C++ profiles (and backwards compatibility),
// and "--- mutex:" or "--- contention:" for profiles generated by the Go
// runtime. This code converts the text output from the runtime into a
// *Profile. (In the future the runtime might write a serialized Profile
// directly, making this unnecessary.)
func parseContention(b []byte) (*Profile, error) {
r := bytes.NewBuffer(b)
var l string
var err error
for {
// Skip past comments and empty lines seeking a real header.
l, err = r.ReadString('\n')
if err != nil {
return nil, err
}
if !isSpaceOrComment(l) {
break
}
}
if strings.HasPrefix(l, "--- contentionz ") {
return parseCppContention(r)
} else if strings.HasPrefix(l, "--- mutex:") {
return parseCppContention(r)
} else if strings.HasPrefix(l, "--- contention:") {
return parseCppContention(r)
}
return nil, errUnrecognized
}
// parseCppContention parses the output from synchronization_profiling.cc
// for backward compatibility, and the compatible (non-debug) block profile
// output from the Go runtime.
func parseCppContention(r *bytes.Buffer) (*Profile, error) {
p := &Profile{
PeriodType: &ValueType{Type: "contentions", Unit: "count"},
Period: 1,
SampleType: []*ValueType{
{Type: "contentions", Unit: "count"},
{Type: "delay", Unit: "nanoseconds"},
},
}
var cpuHz int64
var l string
var err error
// Parse text of the form "attribute = value" before the samples.
const delimiter = '='
for {
l, err = r.ReadString('\n')
if err != nil {
if err != io.EOF {
return nil, err
}
if l == "" {
break
}
}
if isSpaceOrComment(l) {
continue
}
if l = strings.TrimSpace(l); l == "" {
continue
}
if strings.HasPrefix(l, "---") {
break
}
index := strings.IndexByte(l, delimiter)
if index < 0 {
break
}
key := l[:index]
val := l[index+1:]
key, val = strings.TrimSpace(key), strings.TrimSpace(val)
var err error
switch key {
case "cycles/second":
if cpuHz, err = strconv.ParseInt(val, 0, 64); err != nil {
return nil, errUnrecognized
}
case "sampling period":
if p.Period, err = strconv.ParseInt(val, 0, 64); err != nil {
return nil, errUnrecognized
}
case "ms since reset":
ms, err := strconv.ParseInt(val, 0, 64)
if err != nil {
return nil, errUnrecognized
}
p.DurationNanos = ms * 1000 * 1000
case "format":
// CPP contentionz profiles don't have format.
return nil, errUnrecognized
case "resolution":
// CPP contentionz profiles don't have resolution.
return nil, errUnrecognized
case "discarded samples":
default:
return nil, errUnrecognized
}
}
locs := make(map[uint64]*Location)
for {
if !isSpaceOrComment(l) {
if l = strings.TrimSpace(l); strings.HasPrefix(l, "---") {
break
}
value, addrs, err := parseContentionSample(l, p.Period, cpuHz)
if err != nil {
return nil, err
}
var sloc []*Location
for _, addr := range addrs {
// Addresses from stack traces point to the next instruction after
// each call. Adjust by -1 to land somewhere on the actual call.
addr--
loc := locs[addr]
if locs[addr] == nil {
loc = &Location{
Address: addr,
}
p.Location = append(p.Location, loc)
locs[addr] = loc
}
sloc = append(sloc, loc)
}
p.Sample = append(p.Sample, &Sample{
Value: value,
Location: sloc,
})
}
if l, err = r.ReadString('\n'); err != nil {
if err != io.EOF {
return nil, err
}
if l == "" {
break
}
}
}
if err = parseAdditionalSections(l, r, p); err != nil {
return nil, err
}
return p, nil
}
// parseContentionSample parses a single row from a contention profile
// into a new Sample.
func parseContentionSample(line string, period, cpuHz int64) (value []int64, addrs []uint64, err error) {
sampleData := contentionSampleRE.FindStringSubmatch(line)
if sampleData == nil {
return value, addrs, errUnrecognized
}
v1, err := strconv.ParseInt(sampleData[1], 10, 64)
if err != nil {
return value, addrs, fmt.Errorf("malformed sample: %s: %v", line, err)
}
v2, err := strconv.ParseInt(sampleData[2], 10, 64)
if err != nil {
return value, addrs, fmt.Errorf("malformed sample: %s: %v", line, err)
}
// Unsample values if period and cpuHz are available.
// - Delays are scaled to cycles and then to nanoseconds.
// - Contentions are scaled to cycles.
if period > 0 {
if cpuHz > 0 {
cpuGHz := float64(cpuHz) / 1e9
v1 = int64(float64(v1) * float64(period) / cpuGHz)
}
v2 = v2 * period
}
value = []int64{v2, v1}
addrs = parseHexAddresses(sampleData[3])
return value, addrs, nil
}
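// An illustrative contentionz sample row (not from a real profile):
//
//	19490304 27 @ 0x42c542 0x45c879
//
// records 19490304 cycles of delay across 27 contention events at the
// given stack. With period = 100 and cycles/second = 1e9 (cpuGHz = 1),
// the delay unsamples to 19490304*100 = 1949030400 ns and the contention
// count to 27*100 = 2700, stored as value {2700, 1949030400}.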
// parseThread parses a threadz profile and returns a new Profile.
func parseThread(b []byte) (*Profile, error) {
r := bytes.NewBuffer(b)
var line string
var err error
for {
// Skip past comments and empty lines seeking a real header.
line, err = r.ReadString('\n')
if err != nil {
return nil, err
}
if !isSpaceOrComment(line) {
break
}
}
if m := threadzStartRE.FindStringSubmatch(line); m != nil {
// Advance over initial comments until first stack trace.
for {
line, err = r.ReadString('\n')
if err != nil {
if err != io.EOF {
return nil, err
}
if line == "" {
break
}
}
if sectionTrigger(line) != unrecognizedSection || line[0] == '-' {
break
}
}
} else if t := threadStartRE.FindStringSubmatch(line); len(t) != 4 {
return nil, errUnrecognized
}
p := &Profile{
SampleType: []*ValueType{{Type: "thread", Unit: "count"}},
PeriodType: &ValueType{Type: "thread", Unit: "count"},
Period: 1,
}
locs := make(map[uint64]*Location)
// Recognize each thread and populate profile samples.
for sectionTrigger(line) == unrecognizedSection {
if strings.HasPrefix(line, "---- no stack trace for") {
line = ""
break
}
if t := threadStartRE.FindStringSubmatch(line); len(t) != 4 {
return nil, errUnrecognized
}
var addrs []uint64
line, addrs, err = parseThreadSample(r)
if err != nil {
return nil, errUnrecognized
}
if len(addrs) == 0 {
// We got a --same as previous threads--. Bump counters.
if len(p.Sample) > 0 {
s := p.Sample[len(p.Sample)-1]
s.Value[0]++
}
continue
}
var sloc []*Location
for _, addr := range addrs {
// Addresses from stack traces point to the next instruction after
// each call. Adjust by -1 to land somewhere on the actual call.
addr--
loc := locs[addr]
if locs[addr] == nil {
loc = &Location{
Address: addr,
}
p.Location = append(p.Location, loc)
locs[addr] = loc
}
sloc = append(sloc, loc)
}
p.Sample = append(p.Sample, &Sample{
Value: []int64{1},
Location: sloc,
})
}
if err = parseAdditionalSections(line, r, p); err != nil {
return nil, err
}
return p, nil
}
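// For reference, the threadz input accepted above looks roughly like the
// following (reconstructed for illustration from the regexps this package
// uses; the exact text may vary):
//
//	--- threadz 1 ---
//
//	--- Thread 7eff063d9940 (name: main/25376) stack: ---
//	0x000000000040b5e4 0x00000000004d5f03
//
// Each "--- Thread ..." block contributes one sample with value 1, and a
// "same as previous thread" line bumps the previous sample's count instead.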
// parseThreadSample parses a symbolized or unsymbolized stack trace.
// Returns the first line after the traceback, the sample (or nil if
// it hits a 'same-as-previous' marker) and an error.
func parseThreadSample(b *bytes.Buffer) (nextl string, addrs []uint64, err error) {
var l string
sameAsPrevious := false
for {
if l, err = b.ReadString('\n'); err != nil {
if err != io.EOF {
return "", nil, err
}
if l == "" {
break
}
}
if l = strings.TrimSpace(l); l == "" {
continue
}
if strings.HasPrefix(l, "---") {
break
}
if strings.Contains(l, "same as previous thread") {
sameAsPrevious = true
continue
}
addrs = append(addrs, parseHexAddresses(l)...)
}
if sameAsPrevious {
return l, nil, nil
}
return l, addrs, nil
}
// parseAdditionalSections parses any additional sections in the
// profile, ignoring any unrecognized sections.
func parseAdditionalSections(l string, b *bytes.Buffer, p *Profile) (err error) {
for {
if sectionTrigger(l) == memoryMapSection {
break
}
// Ignore any unrecognized sections.
if l, err := b.ReadString('\n'); err != nil {
if err != io.EOF {
return err
}
if l == "" {
break
}
}
}
return p.ParseMemoryMap(b)
}
// ParseMemoryMap parses a memory map in the format of
// /proc/self/maps, and overrides the mappings in the current profile.
// It renumbers the locations, functions, and mappings in the profile
// correspondingly.
func (p *Profile) ParseMemoryMap(rd io.Reader) error {
b := bufio.NewReader(rd)
var attrs []string
var r *strings.Replacer
const delimiter = '='
for {
l, err := b.ReadString('\n')
if err != nil {
if err != io.EOF {
return err
}
if l == "" {
break
}
}
if l = strings.TrimSpace(l); l == "" {
continue
}
if r != nil {
l = r.Replace(l)
}
m, err := parseMappingEntry(l)
if err != nil {
if err == errUnrecognized {
// Recognize assignments of the form: attr=value, and replace
// $attr with value on subsequent mappings.
idx := strings.IndexByte(l, delimiter)
if idx >= 0 {
attr := l[:idx]
value := l[idx+1:]
attrs = append(attrs, "$"+strings.TrimSpace(attr), strings.TrimSpace(value))
r = strings.NewReplacer(attrs...)
}
// Ignore any unrecognized entries
continue
}
return err
}
if m == nil || (m.File == "" && len(p.Mapping) != 0) {
// In some cases the first entry may include the address range
// but not the name of the file. It should be followed by
// another entry with the name.
continue
}
if len(p.Mapping) == 1 && p.Mapping[0].File == "" {
// Update the name if this is the entry following that empty one.
p.Mapping[0].File = m.File
continue
}
p.Mapping = append(p.Mapping, m)
}
p.remapLocationIDs()
p.remapFunctionIDs()
p.remapMappingIDs()
return nil
}
func parseMappingEntry(l string) (*Mapping, error) {
mapping := &Mapping{}
var err error
if me := procMapsRE.FindStringSubmatch(l); len(me) == 9 {
if !strings.Contains(me[3], "x") {
// Skip non-executable entries.
return nil, nil
}
if mapping.Start, err = strconv.ParseUint(me[1], 16, 64); err != nil {
return nil, errUnrecognized
}
if mapping.Limit, err = strconv.ParseUint(me[2], 16, 64); err != nil {
return nil, errUnrecognized
}
if me[4] != "" {
if mapping.Offset, err = strconv.ParseUint(me[4], 16, 64); err != nil {
return nil, errUnrecognized
}
}
mapping.File = me[8]
return mapping, nil
}
if me := briefMapsRE.FindStringSubmatch(l); len(me) == 6 {
if mapping.Start, err = strconv.ParseUint(me[1], 16, 64); err != nil {
return nil, errUnrecognized
}
if mapping.Limit, err = strconv.ParseUint(me[2], 16, 64); err != nil {
return nil, errUnrecognized
}
mapping.File = me[3]
if me[5] != "" {
if mapping.Offset, err = strconv.ParseUint(me[5], 16, 64); err != nil {
return nil, errUnrecognized
}
}
return mapping, nil
}
return nil, errUnrecognized
}
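// A typical /proc/self/maps entry matched by procMapsRE has this shape:
//
//	7f2d14ab6000-7f2d14c6d000 r-xp 00000000 fd:01 1059219 /lib/x86_64-linux-gnu/libc-2.19.so
//
// Start and Limit come from the address range, Offset from the file-offset
// field, and entries without "x" in the permissions are skipped.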
type sectionType int
const (
unrecognizedSection sectionType = iota
memoryMapSection
)
var memoryMapTriggers = []string{
"--- Memory map: ---",
"MAPPED_LIBRARIES:",
}
func sectionTrigger(line string) sectionType {
for _, trigger := range memoryMapTriggers {
if strings.Contains(line, trigger) {
return memoryMapSection
}
}
return unrecognizedSection
}
func (p *Profile) addLegacyFrameInfo() {
switch {
case isProfileType(p, heapzSampleTypes) ||
isProfileType(p, heapzInUseSampleTypes) ||
isProfileType(p, heapzAllocSampleTypes):
p.DropFrames, p.KeepFrames = allocRxStr, allocSkipRxStr
case isProfileType(p, contentionzSampleTypes):
p.DropFrames, p.KeepFrames = lockRxStr, ""
default:
p.DropFrames, p.KeepFrames = cpuProfilerRxStr, ""
}
}
var heapzSampleTypes = []string{"allocations", "size"} // early Go pprof profiles
var heapzInUseSampleTypes = []string{"inuse_objects", "inuse_space"}
var heapzAllocSampleTypes = []string{"alloc_objects", "alloc_space"}
var contentionzSampleTypes = []string{"contentions", "delay"}
func isProfileType(p *Profile, t []string) bool {
st := p.SampleType
if len(st) != len(t) {
return false
}
for i := range st {
if st[i].Type != t[i] {
return false
}
}
return true
}
var allocRxStr = strings.Join([]string{
// POSIX entry points.
`calloc`,
`cfree`,
`malloc`,
`free`,
`memalign`,
`do_memalign`,
`(__)?posix_memalign`,
`pvalloc`,
`valloc`,
`realloc`,
// TC malloc.
`tcmalloc::.*`,
`tc_calloc`,
`tc_cfree`,
`tc_malloc`,
`tc_free`,
`tc_memalign`,
`tc_posix_memalign`,
`tc_pvalloc`,
`tc_valloc`,
`tc_realloc`,
`tc_new`,
`tc_delete`,
`tc_newarray`,
`tc_deletearray`,
`tc_new_nothrow`,
`tc_newarray_nothrow`,
// Memory-allocation routines on OS X.
`malloc_zone_malloc`,
`malloc_zone_calloc`,
`malloc_zone_valloc`,
`malloc_zone_realloc`,
`malloc_zone_memalign`,
`malloc_zone_free`,
// Go runtime
`runtime\..*`,
// Other misc. memory allocation routines
`BaseArena::.*`,
`(::)?do_malloc_no_errno`,
`(::)?do_malloc_pages`,
`(::)?do_malloc`,
`DoSampledAllocation`,
`MallocedMemBlock::MallocedMemBlock`,
`_M_allocate`,
`__builtin_(vec_)?delete`,
`__builtin_(vec_)?new`,
`__gnu_cxx::new_allocator::allocate`,
`__libc_malloc`,
`__malloc_alloc_template::allocate`,
`allocate`,
`cpp_alloc`,
`operator new(\[\])?`,
`simple_alloc::allocate`,
}, `|`)
var allocSkipRxStr = strings.Join([]string{
// Preserve Go runtime frames that appear in the middle/bottom of
// the stack.
`runtime\.panic`,
`runtime\.reflectcall`,
`runtime\.call[0-9]*`,
}, `|`)
var cpuProfilerRxStr = strings.Join([]string{
`ProfileData::Add`,
`ProfileData::prof_handler`,
`CpuProfiler::prof_handler`,
`__pthread_sighandler`,
`__restore`,
}, `|`)
var lockRxStr = strings.Join([]string{
`RecordLockProfileData`,
`(base::)?RecordLockProfileData.*`,
`(base::)?SubmitMutexProfileData.*`,
`(base::)?SubmitSpinLockProfileData.*`,
`(Mutex::)?AwaitCommon.*`,
`(Mutex::)?Unlock.*`,
`(Mutex::)?UnlockSlow.*`,
`(Mutex::)?ReaderUnlock.*`,
`(MutexLock::)?~MutexLock.*`,
`(SpinLock::)?Unlock.*`,
`(SpinLock::)?SlowUnlock.*`,
`(SpinLockHolder::)?~SpinLockHolder.*`,
}, `|`)
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package profile
import (
"fmt"
"sort"
"strconv"
"strings"
)
// Merge merges all the profiles in profs into a single Profile.
// Returns a new profile independent of the input profiles. The merged
// profile is compacted to eliminate unused samples, locations,
// functions and mappings. Profiles must have identical profile sample
// and period types or the merge will fail. profile.Period of the
// resulting profile will be the maximum of all profiles, and
// profile.TimeNanos will be the earliest nonzero one.
func Merge(srcs []*Profile) (*Profile, error) {
if len(srcs) == 0 {
return nil, fmt.Errorf("no profiles to merge")
}
p, err := combineHeaders(srcs)
if err != nil {
return nil, err
}
pm := &profileMerger{
p: p,
samples: make(map[sampleKey]*Sample, len(srcs[0].Sample)),
locations: make(map[locationKey]*Location, len(srcs[0].Location)),
functions: make(map[functionKey]*Function, len(srcs[0].Function)),
mappings: make(map[mappingKey]*Mapping, len(srcs[0].Mapping)),
}
for _, src := range srcs {
// Clear the profile-specific hash tables
pm.locationsByID = make(map[uint64]*Location, len(src.Location))
pm.functionsByID = make(map[uint64]*Function, len(src.Function))
pm.mappingsByID = make(map[uint64]mapInfo, len(src.Mapping))
if len(pm.mappings) == 0 && len(src.Mapping) > 0 {
// The Mapping list has the property that the first mapping
// represents the main binary. Take the first Mapping we see,
// otherwise the operations below will add mappings in an
// arbitrary order.
pm.mapMapping(src.Mapping[0])
}
for _, s := range src.Sample {
if !isZeroSample(s) {
pm.mapSample(s)
}
}
}
for _, s := range p.Sample {
if isZeroSample(s) {
// If there are any zero samples, re-merge the profile to GC
// them.
return Merge([]*Profile{p})
}
}
return p, nil
}
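// A minimal usage sketch (p1 and p2 are assumed to be compatible profiles
// parsed elsewhere):
//
//	merged, err := Merge([]*Profile{p1, p2})
//	if err != nil {
//		// The profiles had mismatched sample or period types.
//	}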
// Normalize normalizes the source profile by multiplying each value in the
// profile by the ratio of the sum of the base profile's values for that
// sample type to the sum of the source profile's values for that sample type.
func (p *Profile) Normalize(pb *Profile) error {
if err := p.compatible(pb); err != nil {
return err
}
baseVals := make([]int64, len(p.SampleType))
for _, s := range pb.Sample {
for i, v := range s.Value {
baseVals[i] += v
}
}
srcVals := make([]int64, len(p.SampleType))
for _, s := range p.Sample {
for i, v := range s.Value {
srcVals[i] += v
}
}
normScale := make([]float64, len(baseVals))
for i := range baseVals {
if srcVals[i] == 0 {
normScale[i] = 0.0
} else {
normScale[i] = float64(baseVals[i]) / float64(srcVals[i])
}
}
p.ScaleN(normScale)
return nil
}
func isZeroSample(s *Sample) bool {
for _, v := range s.Value {
if v != 0 {
return false
}
}
return true
}
type profileMerger struct {
p *Profile
// Memoization tables within a profile.
locationsByID map[uint64]*Location
functionsByID map[uint64]*Function
mappingsByID map[uint64]mapInfo
// Memoization tables for profile entities.
samples map[sampleKey]*Sample
locations map[locationKey]*Location
functions map[functionKey]*Function
mappings map[mappingKey]*Mapping
}
type mapInfo struct {
m *Mapping
offset int64
}
func (pm *profileMerger) mapSample(src *Sample) *Sample {
s := &Sample{
Location: make([]*Location, len(src.Location)),
Value: make([]int64, len(src.Value)),
Label: make(map[string][]string, len(src.Label)),
NumLabel: make(map[string][]int64, len(src.NumLabel)),
NumUnit: make(map[string][]string, len(src.NumLabel)),
}
for i, l := range src.Location {
s.Location[i] = pm.mapLocation(l)
}
for k, v := range src.Label {
vv := make([]string, len(v))
copy(vv, v)
s.Label[k] = vv
}
for k, v := range src.NumLabel {
u := src.NumUnit[k]
vv := make([]int64, len(v))
uu := make([]string, len(u))
copy(vv, v)
copy(uu, u)
s.NumLabel[k] = vv
s.NumUnit[k] = uu
}
// Check memoization table. Must be done on the remapped location to
// account for the remapped mapping. Add current values to the
// existing sample.
k := s.key()
if ss, ok := pm.samples[k]; ok {
for i, v := range src.Value {
ss.Value[i] += v
}
return ss
}
copy(s.Value, src.Value)
pm.samples[k] = s
pm.p.Sample = append(pm.p.Sample, s)
return s
}
// key generates sampleKey to be used as a key for maps.
func (sample *Sample) key() sampleKey {
ids := make([]string, len(sample.Location))
for i, l := range sample.Location {
ids[i] = strconv.FormatUint(l.ID, 16)
}
labels := make([]string, 0, len(sample.Label))
for k, v := range sample.Label {
labels = append(labels, fmt.Sprintf("%q%q", k, v))
}
sort.Strings(labels)
numlabels := make([]string, 0, len(sample.NumLabel))
for k, v := range sample.NumLabel {
numlabels = append(numlabels, fmt.Sprintf("%q%x%x", k, v, sample.NumUnit[k]))
}
sort.Strings(numlabels)
return sampleKey{
strings.Join(ids, "|"),
strings.Join(labels, ""),
strings.Join(numlabels, ""),
}
}
type sampleKey struct {
locations string
labels string
numlabels string
}
func (pm *profileMerger) mapLocation(src *Location) *Location {
if src == nil {
return nil
}
if l, ok := pm.locationsByID[src.ID]; ok {
return l
}
mi := pm.mapMapping(src.Mapping)
l := &Location{
ID: uint64(len(pm.p.Location) + 1),
Mapping: mi.m,
Address: uint64(int64(src.Address) + mi.offset),
Line: make([]Line, len(src.Line)),
IsFolded: src.IsFolded,
}
for i, ln := range src.Line {
l.Line[i] = pm.mapLine(ln)
}
// Check memoization table. Must be done on the remapped location to
// account for the remapped mapping ID.
k := l.key()
if ll, ok := pm.locations[k]; ok {
pm.locationsByID[src.ID] = ll
return ll
}
pm.locationsByID[src.ID] = l
pm.locations[k] = l
pm.p.Location = append(pm.p.Location, l)
return l
}
// key generates locationKey to be used as a key for maps.
func (l *Location) key() locationKey {
key := locationKey{
addr: l.Address,
isFolded: l.IsFolded,
}
if l.Mapping != nil {
// Normalizes address to handle address space randomization.
key.addr -= l.Mapping.Start
key.mappingID = l.Mapping.ID
}
lines := make([]string, len(l.Line)*2)
for i, line := range l.Line {
if line.Function != nil {
lines[i*2] = strconv.FormatUint(line.Function.ID, 16)
}
lines[i*2+1] = strconv.FormatInt(line.Line, 16)
}
key.lines = strings.Join(lines, "|")
return key
}
type locationKey struct {
addr, mappingID uint64
lines string
isFolded bool
}
func (pm *profileMerger) mapMapping(src *Mapping) mapInfo {
if src == nil {
return mapInfo{}
}
if mi, ok := pm.mappingsByID[src.ID]; ok {
return mi
}
// Check memoization tables.
mk := src.key()
if m, ok := pm.mappings[mk]; ok {
mi := mapInfo{m, int64(m.Start) - int64(src.Start)}
pm.mappingsByID[src.ID] = mi
return mi
}
m := &Mapping{
ID: uint64(len(pm.p.Mapping) + 1),
Start: src.Start,
Limit: src.Limit,
Offset: src.Offset,
File: src.File,
BuildID: src.BuildID,
HasFunctions: src.HasFunctions,
HasFilenames: src.HasFilenames,
HasLineNumbers: src.HasLineNumbers,
HasInlineFrames: src.HasInlineFrames,
}
pm.p.Mapping = append(pm.p.Mapping, m)
// Update memoization tables.
pm.mappings[mk] = m
mi := mapInfo{m, 0}
pm.mappingsByID[src.ID] = mi
return mi
}
// key generates a mappingKey from the Mapping, for use as a map key.
func (m *Mapping) key() mappingKey {
// Normalize addresses to handle address space randomization.
// Round up to next 4K boundary to avoid minor discrepancies.
const mapsizeRounding = 0x1000
size := m.Limit - m.Start
size = size + mapsizeRounding - 1
size = size - (size % mapsizeRounding)
key := mappingKey{
size: size,
offset: m.Offset,
}
switch {
case m.BuildID != "":
key.buildIDOrFile = m.BuildID
case m.File != "":
key.buildIDOrFile = m.File
default:
// A mapping containing neither build ID nor file name is a fake mapping. A
// key with empty buildIDOrFile is used for fake mappings so that they are
// treated as the same mapping during merging.
}
return key
}
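// For example, a mapping with Start=0x400000 and Limit=0x401ed5 has size
// 0x1ed5, which rounds up to 0x2000 here, so mappings of the same file
// whose sizes fall in the same 4K bucket compare equal.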
type mappingKey struct {
size, offset uint64
buildIDOrFile string
}
func (pm *profileMerger) mapLine(src Line) Line {
ln := Line{
Function: pm.mapFunction(src.Function),
Line: src.Line,
}
return ln
}
func (pm *profileMerger) mapFunction(src *Function) *Function {
if src == nil {
return nil
}
if f, ok := pm.functionsByID[src.ID]; ok {
return f
}
k := src.key()
if f, ok := pm.functions[k]; ok {
pm.functionsByID[src.ID] = f
return f
}
f := &Function{
ID: uint64(len(pm.p.Function) + 1),
Name: src.Name,
SystemName: src.SystemName,
Filename: src.Filename,
StartLine: src.StartLine,
}
pm.functions[k] = f
pm.functionsByID[src.ID] = f
pm.p.Function = append(pm.p.Function, f)
return f
}
// key generates a struct to be used as a key for maps.
func (f *Function) key() functionKey {
return functionKey{
f.StartLine,
f.Name,
f.SystemName,
f.Filename,
}
}
type functionKey struct {
startLine int64
name, systemName, fileName string
}
// combineHeaders checks that all profiles can be merged and returns
// their combined profile.
func combineHeaders(srcs []*Profile) (*Profile, error) {
for _, s := range srcs[1:] {
if err := srcs[0].compatible(s); err != nil {
return nil, err
}
}
var timeNanos, durationNanos, period int64
var comments []string
seenComments := map[string]bool{}
var defaultSampleType string
for _, s := range srcs {
if timeNanos == 0 || s.TimeNanos < timeNanos {
timeNanos = s.TimeNanos
}
durationNanos += s.DurationNanos
if period == 0 || period < s.Period {
period = s.Period
}
for _, c := range s.Comments {
if seen := seenComments[c]; !seen {
comments = append(comments, c)
seenComments[c] = true
}
}
if defaultSampleType == "" {
defaultSampleType = s.DefaultSampleType
}
}
p := &Profile{
SampleType: make([]*ValueType, len(srcs[0].SampleType)),
DropFrames: srcs[0].DropFrames,
KeepFrames: srcs[0].KeepFrames,
TimeNanos: timeNanos,
DurationNanos: durationNanos,
PeriodType: srcs[0].PeriodType,
Period: period,
Comments: comments,
DefaultSampleType: defaultSampleType,
}
copy(p.SampleType, srcs[0].SampleType)
return p, nil
}
// compatible determines if two profiles can be compared/merged.
// returns nil if the profiles are compatible; otherwise an error with
// details on the incompatibility.
func (p *Profile) compatible(pb *Profile) error {
if !equalValueType(p.PeriodType, pb.PeriodType) {
return fmt.Errorf("incompatible period types %v and %v", p.PeriodType, pb.PeriodType)
}
if len(p.SampleType) != len(pb.SampleType) {
return fmt.Errorf("incompatible sample types %v and %v", p.SampleType, pb.SampleType)
}
for i := range p.SampleType {
if !equalValueType(p.SampleType[i], pb.SampleType[i]) {
return fmt.Errorf("incompatible sample types %v and %v", p.SampleType, pb.SampleType)
}
}
return nil
}
// equalValueType returns true if the two value types are semantically
// equal. It ignores the internal fields used during encode/decode.
func equalValueType(st1, st2 *ValueType) bool {
return st1.Type == st2.Type && st1.Unit == st2.Unit
}
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package profile provides a representation of
// github.com/google/pprof/proto/profile.proto and
// methods to encode/decode/merge profiles in this format.
package profile
import (
"bytes"
"compress/gzip"
"fmt"
"internal/lazyregexp"
"io"
"strings"
"time"
)
// Profile is an in-memory representation of profile.proto.
type Profile struct {
SampleType []*ValueType
DefaultSampleType string
Sample []*Sample
Mapping []*Mapping
Location []*Location
Function []*Function
Comments []string
DropFrames string
KeepFrames string
TimeNanos int64
DurationNanos int64
PeriodType *ValueType
Period int64
commentX []int64
dropFramesX int64
keepFramesX int64
stringTable []string
defaultSampleTypeX int64
}
// ValueType corresponds to Profile.ValueType
type ValueType struct {
Type string // cpu, wall, inuse_space, etc
Unit string // seconds, nanoseconds, bytes, etc
typeX int64
unitX int64
}
// Sample corresponds to Profile.Sample
type Sample struct {
Location []*Location
Value []int64
Label map[string][]string
NumLabel map[string][]int64
NumUnit map[string][]string
locationIDX []uint64
labelX []Label
}
// Label corresponds to Profile.Label
type Label struct {
keyX int64
// Exactly one of the two following values must be set
strX int64
numX int64 // Integer value for this label
}
// Mapping corresponds to Profile.Mapping
type Mapping struct {
ID uint64
Start uint64
Limit uint64
Offset uint64
File string
BuildID string
HasFunctions bool
HasFilenames bool
HasLineNumbers bool
HasInlineFrames bool
fileX int64
buildIDX int64
}
// Location corresponds to Profile.Location
type Location struct {
ID uint64
Mapping *Mapping
Address uint64
Line []Line
IsFolded bool
mappingIDX uint64
}
// Line corresponds to Profile.Line
type Line struct {
Function *Function
Line int64
functionIDX uint64
}
// Function corresponds to Profile.Function
type Function struct {
ID uint64
Name string
SystemName string
Filename string
StartLine int64
nameX int64
systemNameX int64
filenameX int64
}
// Parse parses a profile and checks for its validity. The input
// may be a gzip-compressed encoded protobuf or one of many legacy
// profile formats which may be unsupported in the future.
func Parse(r io.Reader) (*Profile, error) {
orig, err := io.ReadAll(r)
if err != nil {
return nil, err
}
var p *Profile
if len(orig) >= 2 && orig[0] == 0x1f && orig[1] == 0x8b {
gz, err := gzip.NewReader(bytes.NewBuffer(orig))
if err != nil {
return nil, fmt.Errorf("decompressing profile: %v", err)
}
data, err := io.ReadAll(gz)
if err != nil {
return nil, fmt.Errorf("decompressing profile: %v", err)
}
orig = data
}
if p, err = parseUncompressed(orig); err != nil {
if p, err = parseLegacy(orig); err != nil {
return nil, fmt.Errorf("parsing profile: %v", err)
}
}
if err := p.CheckValid(); err != nil {
return nil, fmt.Errorf("malformed profile: %v", err)
}
return p, nil
}
var errUnrecognized = fmt.Errorf("unrecognized profile format")
var errMalformed = fmt.Errorf("malformed profile format")
func parseLegacy(data []byte) (*Profile, error) {
parsers := []func([]byte) (*Profile, error){
parseCPU,
parseHeap,
parseGoCount, // goroutine, threadcreate
parseThread,
parseContention,
}
for _, parser := range parsers {
p, err := parser(data)
if err == nil {
p.setMain()
p.addLegacyFrameInfo()
return p, nil
}
if err != errUnrecognized {
return nil, err
}
}
return nil, errUnrecognized
}
func parseUncompressed(data []byte) (*Profile, error) {
p := &Profile{}
if err := unmarshal(data, p); err != nil {
return nil, err
}
if err := p.postDecode(); err != nil {
return nil, err
}
return p, nil
}
var libRx = lazyregexp.New(`([.]so$|[.]so[._][0-9]+)`)
// setMain scans Mapping entries and guesses which entry is main
// because legacy profiles don't obey the convention of putting main
// first.
func (p *Profile) setMain() {
for i := 0; i < len(p.Mapping); i++ {
file := strings.TrimSpace(strings.ReplaceAll(p.Mapping[i].File, "(deleted)", ""))
if len(file) == 0 {
continue
}
if len(libRx.FindStringSubmatch(file)) > 0 {
continue
}
if strings.HasPrefix(file, "[") {
continue
}
// Swap what we guess is main to position 0.
p.Mapping[i], p.Mapping[0] = p.Mapping[0], p.Mapping[i]
break
}
}
// Write writes the profile as a gzip-compressed marshaled protobuf.
func (p *Profile) Write(w io.Writer) error {
p.preEncode()
b := marshal(p)
zw := gzip.NewWriter(w)
defer zw.Close()
_, err := zw.Write(b)
return err
}
// CheckValid tests whether the profile is valid. Checks include, but are
// not limited to:
// - len(Profile.Sample[n].value) == len(Profile.value_unit)
// - Sample.id has a corresponding Profile.Location
func (p *Profile) CheckValid() error {
// Check that sample values are consistent
sampleLen := len(p.SampleType)
if sampleLen == 0 && len(p.Sample) != 0 {
return fmt.Errorf("missing sample type information")
}
for _, s := range p.Sample {
if len(s.Value) != sampleLen {
return fmt.Errorf("mismatch: sample has: %d values vs. %d types", len(s.Value), len(p.SampleType))
}
}
// Check that all mappings/locations/functions are in the tables
// Check that there are no duplicate ids
mappings := make(map[uint64]*Mapping, len(p.Mapping))
for _, m := range p.Mapping {
if m.ID == 0 {
return fmt.Errorf("found mapping with reserved ID=0")
}
if mappings[m.ID] != nil {
return fmt.Errorf("multiple mappings with same id: %d", m.ID)
}
mappings[m.ID] = m
}
functions := make(map[uint64]*Function, len(p.Function))
for _, f := range p.Function {
if f.ID == 0 {
return fmt.Errorf("found function with reserved ID=0")
}
if functions[f.ID] != nil {
return fmt.Errorf("multiple functions with same id: %d", f.ID)
}
functions[f.ID] = f
}
locations := make(map[uint64]*Location, len(p.Location))
for _, l := range p.Location {
if l.ID == 0 {
return fmt.Errorf("found location with reserved id=0")
}
if locations[l.ID] != nil {
return fmt.Errorf("multiple locations with same id: %d", l.ID)
}
locations[l.ID] = l
if m := l.Mapping; m != nil {
if m.ID == 0 || mappings[m.ID] != m {
return fmt.Errorf("inconsistent mapping %p: %d", m, m.ID)
}
}
for _, ln := range l.Line {
if f := ln.Function; f != nil {
if f.ID == 0 || functions[f.ID] != f {
return fmt.Errorf("inconsistent function %p: %d", f, f.ID)
}
}
}
}
return nil
}
// Aggregate merges the locations in the profile into equivalence
// classes, preserving the requested attributes. It also updates the
// samples to point to the merged locations.
func (p *Profile) Aggregate(inlineFrame, function, filename, linenumber, address bool) error {
for _, m := range p.Mapping {
m.HasInlineFrames = m.HasInlineFrames && inlineFrame
m.HasFunctions = m.HasFunctions && function
m.HasFilenames = m.HasFilenames && filename
m.HasLineNumbers = m.HasLineNumbers && linenumber
}
// Aggregate functions
if !function || !filename {
for _, f := range p.Function {
if !function {
f.Name = ""
f.SystemName = ""
}
if !filename {
f.Filename = ""
}
}
}
// Aggregate locations
if !inlineFrame || !address || !linenumber {
for _, l := range p.Location {
if !inlineFrame && len(l.Line) > 1 {
l.Line = l.Line[len(l.Line)-1:]
}
if !linenumber {
for i := range l.Line {
l.Line[i].Line = 0
}
}
if !address {
l.Address = 0
}
}
}
return p.CheckValid()
}
// String dumps a text representation of a profile. Intended mainly
// for debugging purposes.
func (p *Profile) String() string {
ss := make([]string, 0, len(p.Sample)+len(p.Mapping)+len(p.Location))
if pt := p.PeriodType; pt != nil {
ss = append(ss, fmt.Sprintf("PeriodType: %s %s", pt.Type, pt.Unit))
}
ss = append(ss, fmt.Sprintf("Period: %d", p.Period))
if p.TimeNanos != 0 {
ss = append(ss, fmt.Sprintf("Time: %v", time.Unix(0, p.TimeNanos)))
}
if p.DurationNanos != 0 {
ss = append(ss, fmt.Sprintf("Duration: %v", time.Duration(p.DurationNanos)))
}
ss = append(ss, "Samples:")
var sh1 string
for _, s := range p.SampleType {
sh1 = sh1 + fmt.Sprintf("%s/%s ", s.Type, s.Unit)
}
ss = append(ss, strings.TrimSpace(sh1))
for _, s := range p.Sample {
var sv string
for _, v := range s.Value {
sv = fmt.Sprintf("%s %10d", sv, v)
}
sv = sv + ": "
for _, l := range s.Location {
sv = sv + fmt.Sprintf("%d ", l.ID)
}
ss = append(ss, sv)
const labelHeader = " "
if len(s.Label) > 0 {
ls := labelHeader
for k, v := range s.Label {
ls = ls + fmt.Sprintf("%s:%v ", k, v)
}
ss = append(ss, ls)
}
if len(s.NumLabel) > 0 {
ls := labelHeader
for k, v := range s.NumLabel {
ls = ls + fmt.Sprintf("%s:%v ", k, v)
}
ss = append(ss, ls)
}
}
ss = append(ss, "Locations")
for _, l := range p.Location {
locStr := fmt.Sprintf("%6d: %#x ", l.ID, l.Address)
if m := l.Mapping; m != nil {
locStr = locStr + fmt.Sprintf("M=%d ", m.ID)
}
if len(l.Line) == 0 {
ss = append(ss, locStr)
}
for li := range l.Line {
lnStr := "??"
if fn := l.Line[li].Function; fn != nil {
lnStr = fmt.Sprintf("%s %s:%d s=%d",
fn.Name,
fn.Filename,
l.Line[li].Line,
fn.StartLine)
if fn.Name != fn.SystemName {
lnStr = lnStr + "(" + fn.SystemName + ")"
}
}
ss = append(ss, locStr+lnStr)
// Do not print location details past the first line
locStr = " "
}
}
ss = append(ss, "Mappings")
for _, m := range p.Mapping {
bits := ""
if m.HasFunctions {
bits += "[FN]"
}
if m.HasFilenames {
bits += "[FL]"
}
if m.HasLineNumbers {
bits += "[LN]"
}
if m.HasInlineFrames {
bits += "[IN]"
}
ss = append(ss, fmt.Sprintf("%d: %#x/%#x/%#x %s %s %s",
m.ID,
m.Start, m.Limit, m.Offset,
m.File,
m.BuildID,
bits))
}
return strings.Join(ss, "\n") + "\n"
}
// Merge adds profile pb adjusted by ratio r into profile p. Profiles
// must be compatible (same Type and SampleType).
// TODO(rsilvera): consider normalizing the profiles based on the
// total samples collected.
func (p *Profile) Merge(pb *Profile, r float64) error {
if err := p.Compatible(pb); err != nil {
return err
}
pb = pb.Copy()
// Keep the largest of the two periods.
if pb.Period > p.Period {
p.Period = pb.Period
}
p.DurationNanos += pb.DurationNanos
p.Mapping = append(p.Mapping, pb.Mapping...)
for i, m := range p.Mapping {
m.ID = uint64(i + 1)
}
p.Location = append(p.Location, pb.Location...)
for i, l := range p.Location {
l.ID = uint64(i + 1)
}
p.Function = append(p.Function, pb.Function...)
for i, f := range p.Function {
f.ID = uint64(i + 1)
}
if r != 1.0 {
for _, s := range pb.Sample {
for i, v := range s.Value {
s.Value[i] = int64(float64(v) * r)
}
}
}
p.Sample = append(p.Sample, pb.Sample...)
return p.CheckValid()
}
// Compatible determines if two profiles can be compared/merged.
// returns nil if the profiles are compatible; otherwise an error with
// details on the incompatibility.
func (p *Profile) Compatible(pb *Profile) error {
if !compatibleValueTypes(p.PeriodType, pb.PeriodType) {
return fmt.Errorf("incompatible period types %v and %v", p.PeriodType, pb.PeriodType)
}
if len(p.SampleType) != len(pb.SampleType) {
return fmt.Errorf("incompatible sample types %v and %v", p.SampleType, pb.SampleType)
}
for i := range p.SampleType {
if !compatibleValueTypes(p.SampleType[i], pb.SampleType[i]) {
return fmt.Errorf("incompatible sample types %v and %v", p.SampleType, pb.SampleType)
}
}
return nil
}
// HasFunctions determines if all locations in this profile have
// symbolized function information.
func (p *Profile) HasFunctions() bool {
for _, l := range p.Location {
if l.Mapping == nil || !l.Mapping.HasFunctions {
return false
}
}
return true
}
// HasFileLines determines if all locations in this profile have
// symbolized file and line number information.
func (p *Profile) HasFileLines() bool {
for _, l := range p.Location {
if l.Mapping == nil || (!l.Mapping.HasFilenames || !l.Mapping.HasLineNumbers) {
return false
}
}
return true
}
func compatibleValueTypes(v1, v2 *ValueType) bool {
if v1 == nil || v2 == nil {
return true // No grounds to disqualify.
}
return v1.Type == v2.Type && v1.Unit == v2.Unit
}
// Copy makes a fully independent copy of a profile.
func (p *Profile) Copy() *Profile {
p.preEncode()
b := marshal(p)
pp := &Profile{}
if err := unmarshal(b, pp); err != nil {
panic(err)
}
if err := pp.postDecode(); err != nil {
panic(err)
}
return pp
}
// Demangler maps symbol names to a human-readable form. This may
// include C++ demangling and additional simplification. Names that
// are not demangled may be missing from the resulting map.
type Demangler func(name []string) (map[string]string, error)
// Demangle attempts to demangle and optionally simplify any function
// names referenced in the profile. It works on a best-effort basis:
// it will silently preserve the original names in case of any errors.
func (p *Profile) Demangle(d Demangler) error {
// Collect names to demangle.
var names []string
for _, fn := range p.Function {
names = append(names, fn.SystemName)
}
// Update profile with demangled names.
demangled, err := d(names)
if err != nil {
return err
}
for _, fn := range p.Function {
if dd, ok := demangled[fn.SystemName]; ok {
fn.Name = dd
}
}
return nil
}
// Empty reports whether the profile contains no samples.
func (p *Profile) Empty() bool {
return len(p.Sample) == 0
}
// Scale multiplies all sample values in a profile by a constant.
func (p *Profile) Scale(ratio float64) {
if ratio == 1 {
return
}
ratios := make([]float64, len(p.SampleType))
for i := range p.SampleType {
ratios[i] = ratio
}
p.ScaleN(ratios)
}
// ScaleN multiplies each value in every sample by the corresponding ratio.
func (p *Profile) ScaleN(ratios []float64) error {
if len(p.SampleType) != len(ratios) {
return fmt.Errorf("mismatched scale ratios, got %d, want %d", len(ratios), len(p.SampleType))
}
allOnes := true
for _, r := range ratios {
if r != 1 {
allOnes = false
break
}
}
if allOnes {
return nil
}
for _, s := range p.Sample {
for i, v := range s.Value {
if ratios[i] != 1 {
s.Value[i] = int64(float64(v) * ratios[i])
}
}
}
return nil
}
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file is a simple protocol buffer encoder and decoder.
//
// A protocol message must implement the message interface:
// decoder() []decoder
// encode(*buffer)
//
// The decoder method returns a slice indexed by field number that gives the
// function used to decode that field.
// The encode method encodes its receiver into the given buffer.
//
// The two methods are simple enough to be implemented by hand rather than
// by using a protocol compiler.
//
// See profile.go for examples of messages implementing this interface.
//
// There is no support for groups, message sets, or "has" bits.
package profile
import (
"errors"
"fmt"
)
type buffer struct {
field int
typ int
u64 uint64
data []byte
tmp [16]byte
}
type decoder func(*buffer, message) error
type message interface {
decoder() []decoder
encode(*buffer)
}
func marshal(m message) []byte {
var b buffer
m.encode(&b)
return b.data
}
func encodeVarint(b *buffer, x uint64) {
for x >= 128 {
b.data = append(b.data, byte(x)|0x80)
x >>= 7
}
b.data = append(b.data, byte(x))
}
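// Varint example: x = 300 (binary 100101100) encodes as the two bytes
// 0xAC, 0x02: the low 7 bits with the continuation bit set, followed by
// the remaining bits.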
func encodeLength(b *buffer, tag int, len int) {
encodeVarint(b, uint64(tag)<<3|2)
encodeVarint(b, uint64(len))
}
func encodeUint64(b *buffer, tag int, x uint64) {
// append varint to b.data
encodeVarint(b, uint64(tag)<<3|0)
encodeVarint(b, x)
}
func encodeUint64s(b *buffer, tag int, x []uint64) {
if len(x) > 2 {
// Use packed encoding
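// The payload varints are appended first, then the tag+length
// header, and the header is rotated in front of the payload using
// b.tmp as scratch space.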
n1 := len(b.data)
for _, u := range x {
encodeVarint(b, u)
}
n2 := len(b.data)
encodeLength(b, tag, n2-n1)
n3 := len(b.data)
copy(b.tmp[:], b.data[n2:n3])
copy(b.data[n1+(n3-n2):], b.data[n1:n2])
copy(b.data[n1:], b.tmp[:n3-n2])
return
}
for _, u := range x {
encodeUint64(b, tag, u)
}
}
func encodeUint64Opt(b *buffer, tag int, x uint64) {
if x == 0 {
return
}
encodeUint64(b, tag, x)
}
func encodeInt64(b *buffer, tag int, x int64) {
u := uint64(x)
encodeUint64(b, tag, u)
}
func encodeInt64Opt(b *buffer, tag int, x int64) {
if x == 0 {
return
}
encodeInt64(b, tag, x)
}
func encodeInt64s(b *buffer, tag int, x []int64) {
if len(x) > 2 {
// Use packed encoding
n1 := len(b.data)
for _, u := range x {
encodeVarint(b, uint64(u))
}
n2 := len(b.data)
encodeLength(b, tag, n2-n1)
n3 := len(b.data)
copy(b.tmp[:], b.data[n2:n3])
copy(b.data[n1+(n3-n2):], b.data[n1:n2])
copy(b.data[n1:], b.tmp[:n3-n2])
return
}
for _, u := range x {
encodeInt64(b, tag, u)
}
}
func encodeString(b *buffer, tag int, x string) {
encodeLength(b, tag, len(x))
b.data = append(b.data, x...)
}
func encodeStrings(b *buffer, tag int, x []string) {
for _, s := range x {
encodeString(b, tag, s)
}
}
func encodeBool(b *buffer, tag int, x bool) {
if x {
encodeUint64(b, tag, 1)
} else {
encodeUint64(b, tag, 0)
}
}
func encodeBoolOpt(b *buffer, tag int, x bool) {
if !x {
return
}
encodeBool(b, tag, x)
}
func encodeMessage(b *buffer, tag int, m message) {
n1 := len(b.data)
m.encode(b)
n2 := len(b.data)
encodeLength(b, tag, n2-n1)
n3 := len(b.data)
copy(b.tmp[:], b.data[n2:n3])
copy(b.data[n1+(n3-n2):], b.data[n1:n2])
copy(b.data[n1:], b.tmp[:n3-n2])
}
func unmarshal(data []byte, m message) (err error) {
b := buffer{data: data, typ: 2}
return decodeMessage(&b, m)
}
func le64(p []byte) uint64 {
return uint64(p[0]) | uint64(p[1])<<8 | uint64(p[2])<<16 | uint64(p[3])<<24 | uint64(p[4])<<32 | uint64(p[5])<<40 | uint64(p[6])<<48 | uint64(p[7])<<56
}
func le32(p []byte) uint32 {
return uint32(p[0]) | uint32(p[1])<<8 | uint32(p[2])<<16 | uint32(p[3])<<24
}
func decodeVarint(data []byte) (uint64, []byte, error) {
var i int
var u uint64
for i = 0; ; i++ {
if i >= 10 || i >= len(data) {
return 0, nil, errors.New("bad varint")
}
u |= uint64(data[i]&0x7F) << uint(7*i)
if data[i]&0x80 == 0 {
return u, data[i+1:], nil
}
}
}
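// Decoding mirrors the varint example above: the bytes 0xAC, 0x02 yield
// u = 300 plus the remainder of the input slice.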
func decodeField(b *buffer, data []byte) ([]byte, error) {
x, data, err := decodeVarint(data)
if err != nil {
return nil, err
}
b.field = int(x >> 3)
b.typ = int(x & 7)
b.data = nil
b.u64 = 0
switch b.typ {
case 0:
b.u64, data, err = decodeVarint(data)
if err != nil {
return nil, err
}
case 1:
if len(data) < 8 {
return nil, errors.New("not enough data")
}
b.u64 = le64(data[:8])
data = data[8:]
case 2:
var n uint64
n, data, err = decodeVarint(data)
if err != nil {
return nil, err
}
if n > uint64(len(data)) {
return nil, errors.New("too much data")
}
b.data = data[:n]
data = data[n:]
case 5:
if len(data) < 4 {
return nil, errors.New("not enough data")
}
b.u64 = uint64(le32(data[:4]))
data = data[4:]
default:
return nil, fmt.Errorf("unknown wire type: %d", b.typ)
}
return data, nil
}
func checkType(b *buffer, typ int) error {
if b.typ != typ {
return errors.New("type mismatch")
}
return nil
}
func decodeMessage(b *buffer, m message) error {
if err := checkType(b, 2); err != nil {
return err
}
dec := m.decoder()
data := b.data
for len(data) > 0 {
// pull varint field# + type
var err error
data, err = decodeField(b, data)
if err != nil {
return err
}
if b.field >= len(dec) || dec[b.field] == nil {
continue
}
if err := dec[b.field](b, m); err != nil {
return err
}
}
return nil
}
func decodeInt64(b *buffer, x *int64) error {
if err := checkType(b, 0); err != nil {
return err
}
*x = int64(b.u64)
return nil
}
func decodeInt64s(b *buffer, x *[]int64) error {
if b.typ == 2 {
// Packed encoding
data := b.data
for len(data) > 0 {
var u uint64
var err error
if u, data, err = decodeVarint(data); err != nil {
return err
}
*x = append(*x, int64(u))
}
return nil
}
var i int64
if err := decodeInt64(b, &i); err != nil {
return err
}
*x = append(*x, i)
return nil
}
func decodeUint64(b *buffer, x *uint64) error {
if err := checkType(b, 0); err != nil {
return err
}
*x = b.u64
return nil
}
func decodeUint64s(b *buffer, x *[]uint64) error {
if b.typ == 2 {
data := b.data
// Packed encoding
for len(data) > 0 {
var u uint64
var err error
if u, data, err = decodeVarint(data); err != nil {
return err
}
*x = append(*x, u)
}
return nil
}
var u uint64
if err := decodeUint64(b, &u); err != nil {
return err
}
*x = append(*x, u)
return nil
}
func decodeString(b *buffer, x *string) error {
if err := checkType(b, 2); err != nil {
return err
}
*x = string(b.data)
return nil
}
func decodeStrings(b *buffer, x *[]string) error {
var s string
if err := decodeString(b, &s); err != nil {
return err
}
*x = append(*x, s)
return nil
}
func decodeBool(b *buffer, x *bool) error {
if err := checkType(b, 0); err != nil {
return err
}
*x = b.u64 != 0
return nil
}
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Implements methods to remove frames from profiles.
package profile
import (
"fmt"
"regexp"
)
// Prune removes all nodes beneath a node matching dropRx, and not
// matching keepRx. If the root node of a Sample matches, the sample
// will have an empty stack.
func (p *Profile) Prune(dropRx, keepRx *regexp.Regexp) {
prune := make(map[uint64]bool)
pruneBeneath := make(map[uint64]bool)
for _, loc := range p.Location {
var i int
for i = len(loc.Line) - 1; i >= 0; i-- {
if fn := loc.Line[i].Function; fn != nil && fn.Name != "" {
funcName := fn.Name
// Account for leading '.' on the PPC ELF v1 ABI.
if funcName[0] == '.' {
funcName = funcName[1:]
}
if dropRx.MatchString(funcName) {
if keepRx == nil || !keepRx.MatchString(funcName) {
break
}
}
}
}
if i >= 0 {
// Found matching entry to prune.
pruneBeneath[loc.ID] = true
// Remove the matching location.
if i == len(loc.Line)-1 {
// Matched the top entry: prune the whole location.
prune[loc.ID] = true
} else {
loc.Line = loc.Line[i+1:]
}
}
}
// Prune locs from each Sample
for _, sample := range p.Sample {
// Scan from the root to the leaves to find the prune location.
// Do not prune frames before the first user frame, to avoid
// pruning everything.
foundUser := false
for i := len(sample.Location) - 1; i >= 0; i-- {
id := sample.Location[i].ID
if !prune[id] && !pruneBeneath[id] {
foundUser = true
continue
}
if !foundUser {
continue
}
if prune[id] {
sample.Location = sample.Location[i+1:]
break
}
if pruneBeneath[id] {
sample.Location = sample.Location[i:]
break
}
}
}
}
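// Illustrative example: with dropRx matching `runtime\..*` and keepRx nil,
// a sample whose stack is, leaf first,
//
//	runtime.mallocgc -> main.work -> main.main
//
// keeps main.work -> main.main; the matching frame and everything beneath
// it (toward the leaf) is dropped.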
// RemoveUninteresting prunes and elides profiles using built-in
// tables of uninteresting function names.
func (p *Profile) RemoveUninteresting() error {
var keep, drop *regexp.Regexp
var err error
if p.DropFrames != "" {
if drop, err = regexp.Compile("^(" + p.DropFrames + ")$"); err != nil {
return fmt.Errorf("failed to compile regexp %s: %v", p.DropFrames, err)
}
if p.KeepFrames != "" {
if keep, err = regexp.Compile("^(" + p.KeepFrames + ")$"); err != nil {
return fmt.Errorf("failed to compile regexp %s: %v", p.KeepFrames, err)
}
}
p.Prune(drop, keep)
}
return nil
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package reflectlite
import (
"internal/goarch"
"internal/unsafeheader"
"unsafe"
)
// Swapper returns a function that swaps the elements in the provided
// slice.
//
// Swapper panics if the provided interface is not a slice.
func Swapper(slice any) func(i, j int) {
v := ValueOf(slice)
if v.Kind() != Slice {
panic(&ValueError{Method: "Swapper", Kind: v.Kind()})
}
// Fast path for slices of size 0 and 1. Nothing to swap.
switch v.Len() {
case 0:
return func(i, j int) { panic("reflect: slice index out of range") }
case 1:
return func(i, j int) {
if i != 0 || j != 0 {
panic("reflect: slice index out of range")
}
}
}
typ := v.Type().Elem().(*rtype)
size := typ.Size()
hasPtr := typ.ptrdata != 0
// Some common & small cases, without using memmove:
if hasPtr {
if size == goarch.PtrSize {
ps := *(*[]unsafe.Pointer)(v.ptr)
return func(i, j int) { ps[i], ps[j] = ps[j], ps[i] }
}
if typ.Kind() == String {
ss := *(*[]string)(v.ptr)
return func(i, j int) { ss[i], ss[j] = ss[j], ss[i] }
}
} else {
switch size {
case 8:
is := *(*[]int64)(v.ptr)
return func(i, j int) { is[i], is[j] = is[j], is[i] }
case 4:
is := *(*[]int32)(v.ptr)
return func(i, j int) { is[i], is[j] = is[j], is[i] }
case 2:
is := *(*[]int16)(v.ptr)
return func(i, j int) { is[i], is[j] = is[j], is[i] }
case 1:
is := *(*[]int8)(v.ptr)
return func(i, j int) { is[i], is[j] = is[j], is[i] }
}
}
s := (*unsafeheader.Slice)(v.ptr)
tmp := unsafe_New(typ) // swap scratch space
return func(i, j int) {
if uint(i) >= uint(s.Len) || uint(j) >= uint(s.Len) {
panic("reflect: slice index out of range")
}
val1 := arrayAt(s.Data, i, size, "i < s.Len")
val2 := arrayAt(s.Data, j, size, "j < s.Len")
typedmemmove(typ, tmp, val1)
typedmemmove(typ, val1, val2)
typedmemmove(typ, val2, tmp)
}
}
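// A minimal usage sketch (the sort package obtains its swap function this
// way, via the reflect counterpart of this package):
//
//	s := []int{3, 1, 2}
//	swap := Swapper(s)
//	swap(0, 1) // s is now [1, 3, 2]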
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package reflectlite implements a lightweight version of reflect, using no
// packages other than "runtime" and "unsafe".
package reflectlite
import "unsafe"
// Type is the representation of a Go type.
//
// Not all methods apply to all kinds of types. Restrictions,
// if any, are noted in the documentation for each method.
// Use the Kind method to find out the kind of type before
// calling kind-specific methods. Calling a method
// inappropriate to the kind of type causes a run-time panic.
//
// Type values are comparable, such as with the == operator,
// so they can be used as map keys.
// Two Type values are equal if they represent identical types.
type Type interface {
// Methods applicable to all types.
// Name returns the type's name within its package for a defined type.
// For other (non-defined) types it returns the empty string.
Name() string
// PkgPath returns a defined type's package path, that is, the import path
// that uniquely identifies the package, such as "encoding/base64".
// If the type was predeclared (string, error) or not defined (*T, struct{},
// []int, or A where A is an alias for a non-defined type), the package path
// will be the empty string.
PkgPath() string
// Size returns the number of bytes needed to store
// a value of the given type; it is analogous to unsafe.Sizeof.
Size() uintptr
// Kind returns the specific kind of this type.
Kind() Kind
// Implements reports whether the type implements the interface type u.
Implements(u Type) bool
// AssignableTo reports whether a value of the type is assignable to type u.
AssignableTo(u Type) bool
// Comparable reports whether values of this type are comparable.
Comparable() bool
// String returns a string representation of the type.
// The string representation may use shortened package names
// (e.g., base64 instead of "encoding/base64") and is not
// guaranteed to be unique among types. To test for type identity,
// compare the Types directly.
String() string
// Elem returns a type's element type.
// It panics if the type's Kind is not Ptr.
Elem() Type
common() *rtype
uncommon() *uncommonType
}
/*
 * These data structures are known to the compiler (../../cmd/internal/reflectdata/reflect.go)
 * and to ../runtime/type.go, which conveys a few of them to debuggers.
 */
// A Kind represents the specific kind of type that a Type represents.
// The zero Kind is not a valid kind.
type Kind uint
const (
Invalid Kind = iota
Bool
Int
Int8
Int16
Int32
Int64
Uint
Uint8
Uint16
Uint32
Uint64
Uintptr
Float32
Float64
Complex64
Complex128
Array
Chan
Func
Interface
Map
Pointer
Slice
String
Struct
UnsafePointer
)
const Ptr = Pointer
// tflag is used by an rtype to signal what extra type information is
// available in the memory directly following the rtype value.
//
// tflag values must be kept in sync with copies in:
//
// cmd/compile/internal/reflectdata/reflect.go
// cmd/link/internal/ld/decodesym.go
// runtime/type.go
type tflag uint8
const (
// tflagUncommon means that there is a pointer, *uncommonType,
// just beyond the outer type structure.
//
// For example, if t.Kind() == Struct and t.tflag&tflagUncommon != 0,
// then t has uncommonType data and it can be accessed as:
//
// type tUncommon struct {
// structType
// u uncommonType
// }
// u := &(*tUncommon)(unsafe.Pointer(t)).u
tflagUncommon tflag = 1 << 0
// tflagExtraStar means the name in the str field has an
// extraneous '*' prefix. This is because for most types T in
// a program, the type *T also exists and reusing the str data
// saves binary size.
tflagExtraStar tflag = 1 << 1
// tflagNamed means the type has a name.
tflagNamed tflag = 1 << 2
// tflagRegularMemory means that equal and hash functions can treat
// this type as a single region of t.size bytes.
tflagRegularMemory tflag = 1 << 3
)
// rtype is the common implementation of most values.
// It is embedded in other struct types.
//
// rtype must be kept in sync with ../runtime/type.go:/^type._type.
type rtype struct {
size uintptr
ptrdata uintptr // number of bytes in the type that can contain pointers
hash uint32 // hash of type; avoids computation in hash tables
tflag tflag // extra type information flags
align uint8 // alignment of variable with this type
fieldAlign uint8 // alignment of struct field with this type
kind uint8 // enumeration for C
// function for comparing objects of this type
// (ptr to object A, ptr to object B) -> ==?
equal func(unsafe.Pointer, unsafe.Pointer) bool
gcdata *byte // garbage collection data
str nameOff // string form
ptrToThis typeOff // type for pointer to this type, may be zero
}
// Method on non-interface type
type method struct {
name nameOff // name of method
mtyp typeOff // method type (without receiver)
ifn textOff // fn used in interface call (one-word receiver)
tfn textOff // fn used for normal method call
}
// uncommonType is present only for defined types or types with methods
// (if T is a defined type, the uncommonTypes for T and *T have methods).
// Using a pointer to this struct reduces the overall size required
// to describe a non-defined type with no methods.
type uncommonType struct {
pkgPath nameOff // import path; empty for built-in types like int, string
mcount uint16 // number of methods
xcount uint16 // number of exported methods
moff uint32 // offset from this uncommontype to [mcount]method
_ uint32 // unused
}
// chanDir represents a channel type's direction.
type chanDir int
const (
recvDir chanDir = 1 << iota // <-chan
sendDir // chan<-
bothDir = recvDir | sendDir // chan
)
// arrayType represents a fixed array type.
type arrayType struct {
rtype
elem *rtype // array element type
slice *rtype // slice type
len uintptr
}
// chanType represents a channel type.
type chanType struct {
rtype
elem *rtype // channel element type
dir uintptr // channel direction (chanDir)
}
// funcType represents a function type.
//
// A *rtype for each in and out parameter is stored in an array that
// directly follows the funcType (and possibly its uncommonType). So
// a function type with one method, one input, and one output is:
//
// struct {
// funcType
// uncommonType
// [2]*rtype // [0] is in, [1] is out
// }
type funcType struct {
rtype
inCount uint16
outCount uint16 // top bit is set if last input parameter is ...
}
// imethod represents a method on an interface type
type imethod struct {
name nameOff // name of method
typ typeOff // .(*FuncType) underneath
}
// interfaceType represents an interface type.
type interfaceType struct {
rtype
pkgPath name // import path
methods []imethod // sorted by hash
}
// mapType represents a map type.
type mapType struct {
rtype
key *rtype // map key type
elem *rtype // map element (value) type
bucket *rtype // internal bucket structure
// function for hashing keys (ptr to key, seed) -> hash
hasher func(unsafe.Pointer, uintptr) uintptr
keysize uint8 // size of key slot
valuesize uint8 // size of value slot
bucketsize uint16 // size of bucket
flags uint32
}
// ptrType represents a pointer type.
type ptrType struct {
rtype
elem *rtype // pointer element (pointed at) type
}
// sliceType represents a slice type.
type sliceType struct {
rtype
elem *rtype // slice element type
}
// Struct field
type structField struct {
name name // name is always non-empty
typ *rtype // type of field
offset uintptr // byte offset of field
}
func (f *structField) embedded() bool {
return f.name.embedded()
}
// structType represents a struct type.
type structType struct {
rtype
pkgPath name
fields []structField // sorted by offset
}
// name is an encoded type name with optional extra data.
//
// The first byte is a bit field containing:
//
// 1<<0 the name is exported
// 1<<1 tag data follows the name
// 1<<2 pkgPath nameOff follows the name and tag
// 1<<3 the name is of an embedded (a.k.a. anonymous) field
//
// Following the first byte is a varint-encoded length of the name,
// followed by the name itself (see readVarint below).
//
// If tag data follows, it likewise consists of a varint-encoded
// length followed by the tag itself.
//
// If the import path follows, then 4 bytes at the end of
// the data form a nameOff. The import path is only set for concrete
// methods that are defined in a different package than their type.
//
// If a name starts with "*", then the exported bit represents
// whether the pointed to type is exported.
type name struct {
bytes *byte
}
func (n name) data(off int, whySafe string) *byte {
return (*byte)(add(unsafe.Pointer(n.bytes), uintptr(off), whySafe))
}
func (n name) isExported() bool {
return (*n.bytes)&(1<<0) != 0
}
func (n name) hasTag() bool {
return (*n.bytes)&(1<<1) != 0
}
func (n name) embedded() bool {
return (*n.bytes)&(1<<3) != 0
}
// readVarint parses a varint as encoded by encoding/binary.
// It returns the number of encoded bytes and the encoded value.
func (n name) readVarint(off int) (int, int) {
v := 0
for i := 0; ; i++ {
x := *n.data(off+i, "read varint")
v += int(x&0x7f) << (7 * i)
if x&0x80 == 0 {
return i + 1, v
}
}
}
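// decodeVarintSketch is a hypothetical standalone helper (not part of
// the original source) restating the encoding parsed above: each byte
// contributes its low 7 bits, least significant group first, and a
// clear high bit ends the value.
func decodeVarintSketch(data []byte) (n int, v int) {
	for i := 0; ; i++ {
		x := data[i]
		v += int(x&0x7f) << (7 * i)
		if x&0x80 == 0 {
			return i + 1, v
		}
	}
}

// For example, decodeVarintSketch([]byte{0x85, 0x01}) returns (2, 133):
// 0x85 contributes 5, and 0x01 contributes 1<<7 = 128.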
func (n name) name() string {
if n.bytes == nil {
return ""
}
i, l := n.readVarint(1)
return unsafe.String(n.data(1+i, "non-empty string"), l)
}
func (n name) tag() string {
if !n.hasTag() {
return ""
}
i, l := n.readVarint(1)
i2, l2 := n.readVarint(1 + i + l)
return unsafe.String(n.data(1+i+l+i2, "non-empty string"), l2)
}
func (n name) pkgPath() string {
if n.bytes == nil || *n.data(0, "name flag field")&(1<<2) == 0 {
return ""
}
i, l := n.readVarint(1)
off := 1 + i + l
if n.hasTag() {
i2, l2 := n.readVarint(off)
off += i2 + l2
}
var nameOff int32
// Note that this field may not be aligned in memory,
// so we cannot use a direct int32 assignment here.
copy((*[4]byte)(unsafe.Pointer(&nameOff))[:], (*[4]byte)(unsafe.Pointer(n.data(off, "name offset field")))[:])
pkgPathName := name{(*byte)(resolveTypeOff(unsafe.Pointer(n.bytes), nameOff))}
return pkgPathName.name()
}
/*
* The compiler knows the exact layout of all the data structures above.
* The compiler does not know about the data structures and methods below.
*/
const (
kindDirectIface = 1 << 5
kindGCProg = 1 << 6 // Type.gc points to GC program
kindMask = (1 << 5) - 1
)
// String returns the name of k.
func (k Kind) String() string {
if int(k) < len(kindNames) {
return kindNames[k]
}
return kindNames[0]
}
var kindNames = []string{
Invalid: "invalid",
Bool: "bool",
Int: "int",
Int8: "int8",
Int16: "int16",
Int32: "int32",
Int64: "int64",
Uint: "uint",
Uint8: "uint8",
Uint16: "uint16",
Uint32: "uint32",
Uint64: "uint64",
Uintptr: "uintptr",
Float32: "float32",
Float64: "float64",
Complex64: "complex64",
Complex128: "complex128",
Array: "array",
Chan: "chan",
Func: "func",
Interface: "interface",
Map: "map",
Ptr: "ptr",
Slice: "slice",
String: "string",
Struct: "struct",
UnsafePointer: "unsafe.Pointer",
}
func (t *uncommonType) methods() []method {
if t.mcount == 0 {
return nil
}
return (*[1 << 16]method)(add(unsafe.Pointer(t), uintptr(t.moff), "t.mcount > 0"))[:t.mcount:t.mcount]
}
func (t *uncommonType) exportedMethods() []method {
if t.xcount == 0 {
return nil
}
return (*[1 << 16]method)(add(unsafe.Pointer(t), uintptr(t.moff), "t.xcount > 0"))[:t.xcount:t.xcount]
}
// resolveNameOff resolves a name offset from a base pointer.
// The (*rtype).nameOff method is a convenience wrapper for this function.
// Implemented in the runtime package.
func resolveNameOff(ptrInModule unsafe.Pointer, off int32) unsafe.Pointer
// resolveTypeOff resolves an *rtype offset from a base type.
// The (*rtype).typeOff method is a convenience wrapper for this function.
// Implemented in the runtime package.
func resolveTypeOff(rtype unsafe.Pointer, off int32) unsafe.Pointer
type nameOff int32 // offset to a name
type typeOff int32 // offset to an *rtype
type textOff int32 // offset from top of text section
func (t *rtype) nameOff(off nameOff) name {
return name{(*byte)(resolveNameOff(unsafe.Pointer(t), int32(off)))}
}
func (t *rtype) typeOff(off typeOff) *rtype {
return (*rtype)(resolveTypeOff(unsafe.Pointer(t), int32(off)))
}
func (t *rtype) uncommon() *uncommonType {
if t.tflag&tflagUncommon == 0 {
return nil
}
switch t.Kind() {
case Struct:
return &(*structTypeUncommon)(unsafe.Pointer(t)).u
case Ptr:
type u struct {
ptrType
u uncommonType
}
return &(*u)(unsafe.Pointer(t)).u
case Func:
type u struct {
funcType
u uncommonType
}
return &(*u)(unsafe.Pointer(t)).u
case Slice:
type u struct {
sliceType
u uncommonType
}
return &(*u)(unsafe.Pointer(t)).u
case Array:
type u struct {
arrayType
u uncommonType
}
return &(*u)(unsafe.Pointer(t)).u
case Chan:
type u struct {
chanType
u uncommonType
}
return &(*u)(unsafe.Pointer(t)).u
case Map:
type u struct {
mapType
u uncommonType
}
return &(*u)(unsafe.Pointer(t)).u
case Interface:
type u struct {
interfaceType
u uncommonType
}
return &(*u)(unsafe.Pointer(t)).u
default:
type u struct {
rtype
u uncommonType
}
return &(*u)(unsafe.Pointer(t)).u
}
}
func (t *rtype) String() string {
s := t.nameOff(t.str).name()
if t.tflag&tflagExtraStar != 0 {
return s[1:]
}
return s
}
func (t *rtype) Size() uintptr { return t.size }
func (t *rtype) Kind() Kind { return Kind(t.kind & kindMask) }
func (t *rtype) pointers() bool { return t.ptrdata != 0 }
func (t *rtype) common() *rtype { return t }
func (t *rtype) exportedMethods() []method {
ut := t.uncommon()
if ut == nil {
return nil
}
return ut.exportedMethods()
}
func (t *rtype) NumMethod() int {
if t.Kind() == Interface {
tt := (*interfaceType)(unsafe.Pointer(t))
return tt.NumMethod()
}
return len(t.exportedMethods())
}
func (t *rtype) PkgPath() string {
if t.tflag&tflagNamed == 0 {
return ""
}
ut := t.uncommon()
if ut == nil {
return ""
}
return t.nameOff(ut.pkgPath).name()
}
func (t *rtype) hasName() bool {
return t.tflag&tflagNamed != 0
}
func (t *rtype) Name() string {
if !t.hasName() {
return ""
}
s := t.String()
i := len(s) - 1
sqBrackets := 0
for i >= 0 && (s[i] != '.' || sqBrackets != 0) {
switch s[i] {
case ']':
sqBrackets++
case '[':
sqBrackets--
}
i--
}
return s[i+1:]
}
func (t *rtype) chanDir() chanDir {
if t.Kind() != Chan {
panic("reflect: chanDir of non-chan type")
}
tt := (*chanType)(unsafe.Pointer(t))
return chanDir(tt.dir)
}
func (t *rtype) Elem() Type {
switch t.Kind() {
case Array:
tt := (*arrayType)(unsafe.Pointer(t))
return toType(tt.elem)
case Chan:
tt := (*chanType)(unsafe.Pointer(t))
return toType(tt.elem)
case Map:
tt := (*mapType)(unsafe.Pointer(t))
return toType(tt.elem)
case Ptr:
tt := (*ptrType)(unsafe.Pointer(t))
return toType(tt.elem)
case Slice:
tt := (*sliceType)(unsafe.Pointer(t))
return toType(tt.elem)
}
panic("reflect: Elem of invalid type")
}
func (t *rtype) In(i int) Type {
if t.Kind() != Func {
panic("reflect: In of non-func type")
}
tt := (*funcType)(unsafe.Pointer(t))
return toType(tt.in()[i])
}
func (t *rtype) Key() Type {
if t.Kind() != Map {
panic("reflect: Key of non-map type")
}
tt := (*mapType)(unsafe.Pointer(t))
return toType(tt.key)
}
func (t *rtype) Len() int {
if t.Kind() != Array {
panic("reflect: Len of non-array type")
}
tt := (*arrayType)(unsafe.Pointer(t))
return int(tt.len)
}
func (t *rtype) NumField() int {
if t.Kind() != Struct {
panic("reflect: NumField of non-struct type")
}
tt := (*structType)(unsafe.Pointer(t))
return len(tt.fields)
}
func (t *rtype) NumIn() int {
if t.Kind() != Func {
panic("reflect: NumIn of non-func type")
}
tt := (*funcType)(unsafe.Pointer(t))
return int(tt.inCount)
}
func (t *rtype) NumOut() int {
if t.Kind() != Func {
panic("reflect: NumOut of non-func type")
}
tt := (*funcType)(unsafe.Pointer(t))
return len(tt.out())
}
func (t *rtype) Out(i int) Type {
if t.Kind() != Func {
panic("reflect: Out of non-func type")
}
tt := (*funcType)(unsafe.Pointer(t))
return toType(tt.out()[i])
}
func (t *funcType) in() []*rtype {
uadd := unsafe.Sizeof(*t)
if t.tflag&tflagUncommon != 0 {
uadd += unsafe.Sizeof(uncommonType{})
}
if t.inCount == 0 {
return nil
}
return (*[1 << 20]*rtype)(add(unsafe.Pointer(t), uadd, "t.inCount > 0"))[:t.inCount:t.inCount]
}
func (t *funcType) out() []*rtype {
uadd := unsafe.Sizeof(*t)
if t.tflag&tflagUncommon != 0 {
uadd += unsafe.Sizeof(uncommonType{})
}
outCount := t.outCount & (1<<15 - 1)
if outCount == 0 {
return nil
}
return (*[1 << 20]*rtype)(add(unsafe.Pointer(t), uadd, "outCount > 0"))[t.inCount : t.inCount+outCount : t.inCount+outCount]
}
// add returns p+x.
//
// The whySafe string is ignored, so that the function still inlines
// as efficiently as p+x, but all call sites should use the string to
// record why the addition is safe, which is to say why the addition
// does not cause x to advance to the very end of p's allocation
// and therefore point incorrectly at the next block in memory.
func add(p unsafe.Pointer, x uintptr, whySafe string) unsafe.Pointer {
return unsafe.Pointer(uintptr(p) + x)
}
// NumMethod returns the number of interface methods in the type's method set.
func (t *interfaceType) NumMethod() int { return len(t.methods) }
// TypeOf returns the reflection Type that represents the dynamic type of i.
// If i is a nil interface value, TypeOf returns nil.
func TypeOf(i any) Type {
eface := *(*emptyInterface)(unsafe.Pointer(&i))
return toType(eface.typ)
}
func (t *rtype) Implements(u Type) bool {
if u == nil {
panic("reflect: nil type passed to Type.Implements")
}
if u.Kind() != Interface {
panic("reflect: non-interface type passed to Type.Implements")
}
return implements(u.(*rtype), t)
}
func (t *rtype) AssignableTo(u Type) bool {
if u == nil {
panic("reflect: nil type passed to Type.AssignableTo")
}
uu := u.(*rtype)
return directlyAssignable(uu, t) || implements(uu, t)
}
func (t *rtype) Comparable() bool {
return t.equal != nil
}
// implements reports whether the type V implements the interface type T.
func implements(T, V *rtype) bool {
if T.Kind() != Interface {
return false
}
t := (*interfaceType)(unsafe.Pointer(T))
if len(t.methods) == 0 {
return true
}
// The same algorithm applies in both cases, but the
// method tables for an interface type and a concrete type
// are different, so the code is duplicated.
// In both cases the algorithm is a linear scan over the two
// lists - T's methods and V's methods - simultaneously.
// Since method tables are stored in a unique sorted order
// (alphabetical, with no duplicate method names), the scan
// through V's methods must hit a match for each of T's
// methods along the way, or else V does not implement T.
// This lets us run the scan in overall linear time instead of
// the quadratic time a naive search would require.
// See also ../runtime/iface.go.
if V.Kind() == Interface {
v := (*interfaceType)(unsafe.Pointer(V))
i := 0
for j := 0; j < len(v.methods); j++ {
tm := &t.methods[i]
tmName := t.nameOff(tm.name)
vm := &v.methods[j]
vmName := V.nameOff(vm.name)
if vmName.name() == tmName.name() && V.typeOff(vm.typ) == t.typeOff(tm.typ) {
if !tmName.isExported() {
tmPkgPath := tmName.pkgPath()
if tmPkgPath == "" {
tmPkgPath = t.pkgPath.name()
}
vmPkgPath := vmName.pkgPath()
if vmPkgPath == "" {
vmPkgPath = v.pkgPath.name()
}
if tmPkgPath != vmPkgPath {
continue
}
}
if i++; i >= len(t.methods) {
return true
}
}
}
return false
}
v := V.uncommon()
if v == nil {
return false
}
i := 0
vmethods := v.methods()
for j := 0; j < int(v.mcount); j++ {
tm := &t.methods[i]
tmName := t.nameOff(tm.name)
vm := vmethods[j]
vmName := V.nameOff(vm.name)
if vmName.name() == tmName.name() && V.typeOff(vm.mtyp) == t.typeOff(tm.typ) {
if !tmName.isExported() {
tmPkgPath := tmName.pkgPath()
if tmPkgPath == "" {
tmPkgPath = t.pkgPath.name()
}
vmPkgPath := vmName.pkgPath()
if vmPkgPath == "" {
vmPkgPath = V.nameOff(v.pkgPath).name()
}
if tmPkgPath != vmPkgPath {
continue
}
}
if i++; i >= len(t.methods) {
return true
}
}
}
return false
}
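// containsAllSorted is a hypothetical helper (not part of the original
// source) distilling the scan above: with both method-name lists sorted
// and free of duplicates, a single simultaneous pass decides containment
// in linear time, which is what keeps implements from being quadratic.
func containsAllSorted(want, have []string) bool {
	i := 0
	for j := 0; j < len(have) && i < len(want); j++ {
		if have[j] == want[i] {
			i++
		}
	}
	return i == len(want)
}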
// directlyAssignable reports whether a value x of type V can be directly
// assigned (using memmove) to a value of type T.
// https://golang.org/doc/go_spec.html#Assignability
// Ignoring the interface rules (implemented elsewhere)
// and the ideal constant rules (no ideal constants at run time).
func directlyAssignable(T, V *rtype) bool {
// x's type V is identical to T?
if T == V {
return true
}
// Otherwise at least one of T and V must not be defined
// and they must have the same kind.
if T.hasName() && V.hasName() || T.Kind() != V.Kind() {
return false
}
// x's type V and T must have identical underlying types.
return haveIdenticalUnderlyingType(T, V, true)
}
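// A minimal sketch (not part of the original source) of the rule above:
// a defined slice type and the unnamed []int are mutually assignable,
// because their underlying types are identical and at most one side is
// a defined type.
//
//	type IntSlice []int
//	var s []int = IntSlice{1, 2} // allowed: []int is not a defined type
//	var u IntSlice = s           // allowed for the same reason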
func haveIdenticalType(T, V Type, cmpTags bool) bool {
if cmpTags {
return T == V
}
if T.Name() != V.Name() || T.Kind() != V.Kind() {
return false
}
return haveIdenticalUnderlyingType(T.common(), V.common(), false)
}
func haveIdenticalUnderlyingType(T, V *rtype, cmpTags bool) bool {
if T == V {
return true
}
kind := T.Kind()
if kind != V.Kind() {
return false
}
// Non-composite types of equal kind have same underlying type
// (the predefined instance of the type).
if Bool <= kind && kind <= Complex128 || kind == String || kind == UnsafePointer {
return true
}
// Composite types.
switch kind {
case Array:
return T.Len() == V.Len() && haveIdenticalType(T.Elem(), V.Elem(), cmpTags)
case Chan:
// Special case:
// x is a bidirectional channel value, T is a channel type,
// and x's type V and T have identical element types.
if V.chanDir() == bothDir && haveIdenticalType(T.Elem(), V.Elem(), cmpTags) {
return true
}
// Otherwise continue test for identical underlying type.
return V.chanDir() == T.chanDir() && haveIdenticalType(T.Elem(), V.Elem(), cmpTags)
case Func:
t := (*funcType)(unsafe.Pointer(T))
v := (*funcType)(unsafe.Pointer(V))
if t.outCount != v.outCount || t.inCount != v.inCount {
return false
}
for i := 0; i < t.NumIn(); i++ {
if !haveIdenticalType(t.In(i), v.In(i), cmpTags) {
return false
}
}
for i := 0; i < t.NumOut(); i++ {
if !haveIdenticalType(t.Out(i), v.Out(i), cmpTags) {
return false
}
}
return true
case Interface:
t := (*interfaceType)(unsafe.Pointer(T))
v := (*interfaceType)(unsafe.Pointer(V))
if len(t.methods) == 0 && len(v.methods) == 0 {
return true
}
// Might have the same methods but still
// need a run time conversion.
return false
case Map:
return haveIdenticalType(T.Key(), V.Key(), cmpTags) && haveIdenticalType(T.Elem(), V.Elem(), cmpTags)
case Ptr, Slice:
return haveIdenticalType(T.Elem(), V.Elem(), cmpTags)
case Struct:
t := (*structType)(unsafe.Pointer(T))
v := (*structType)(unsafe.Pointer(V))
if len(t.fields) != len(v.fields) {
return false
}
if t.pkgPath.name() != v.pkgPath.name() {
return false
}
for i := range t.fields {
tf := &t.fields[i]
vf := &v.fields[i]
if tf.name.name() != vf.name.name() {
return false
}
if !haveIdenticalType(tf.typ, vf.typ, cmpTags) {
return false
}
if cmpTags && tf.name.tag() != vf.name.tag() {
return false
}
if tf.offset != vf.offset {
return false
}
if tf.embedded() != vf.embedded() {
return false
}
}
return true
}
return false
}
type structTypeUncommon struct {
structType
u uncommonType
}
// toType converts from a *rtype to a Type that can be returned
// to the client of package reflect. In gc, the only concern is that
// a nil *rtype must be replaced by a nil Type, but in gccgo this
// function takes care of ensuring that multiple *rtype for the same
// type are coalesced into a single Type.
func toType(t *rtype) Type {
if t == nil {
return nil
}
return t
}
// ifaceIndir reports whether t is stored indirectly in an interface value.
func ifaceIndir(t *rtype) bool {
return t.kind&kindDirectIface == 0
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package reflectlite
import (
"internal/goarch"
"internal/unsafeheader"
"runtime"
"unsafe"
)
// Value is the reflection interface to a Go value.
//
// Not all methods apply to all kinds of values. Restrictions,
// if any, are noted in the documentation for each method.
// Use the Kind method to find out the kind of value before
// calling kind-specific methods. Calling a method
// inappropriate to the kind of type causes a run time panic.
//
// The zero Value represents no value.
// Its IsValid method returns false, its Kind method returns Invalid,
// its String method returns "<invalid Value>", and all other methods panic.
// Most functions and methods never return an invalid value.
// If one does, its documentation states the conditions explicitly.
//
// A Value can be used concurrently by multiple goroutines provided that
// the underlying Go value can be used concurrently for the equivalent
// direct operations.
//
// To compare two Values, compare the results of the Interface method.
// Using == on two Values does not compare the underlying values
// they represent.
type Value struct {
// typ holds the type of the value represented by a Value.
typ *rtype
// Pointer-valued data or, if flagIndir is set, pointer to data.
// Valid when either flagIndir is set or typ.pointers() is true.
ptr unsafe.Pointer
// flag holds metadata about the value.
// The lowest bits are flag bits:
// - flagStickyRO: obtained via unexported not embedded field, so read-only
// - flagEmbedRO: obtained via unexported embedded field, so read-only
// - flagIndir: val holds a pointer to the data
// - flagAddr: v.CanAddr is true (implies flagIndir)
// Value cannot represent method values.
// The next five bits give the Kind of the value.
// This repeats typ.Kind() except for method values.
// The remaining 23+ bits give a method number for method values.
// If flag.kind() != Func, code can assume that flagMethod is unset.
// If ifaceIndir(typ), code can assume that flagIndir is set.
flag
// A method value represents a curried method invocation
// like r.Read for some receiver r. The typ+val+flag bits describe
// the receiver r, but the flag's Kind bits say Func (methods are
// functions), and the top bits of the flag give the method number
// in r's type's method table.
}
type flag uintptr
const (
flagKindWidth = 5 // there are 27 kinds
flagKindMask flag = 1<<flagKindWidth - 1
flagStickyRO flag = 1 << 5
flagEmbedRO flag = 1 << 6
flagIndir flag = 1 << 7
flagAddr flag = 1 << 8
flagMethod flag = 1 << 9
flagMethodShift = 10
flagRO flag = flagStickyRO | flagEmbedRO
)
func (f flag) kind() Kind {
return Kind(f & flagKindMask)
}
func (f flag) ro() flag {
if f&flagRO != 0 {
return flagStickyRO
}
return 0
}
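// flagLayoutSketch is a hypothetical helper (not part of the original
// source) illustrating the layout documented above: the low five bits
// of a flag store the Kind, and the bits above them are independent
// booleans, so composing and decoding are plain bit operations.
func flagLayoutSketch() bool {
	f := flag(Struct) | flagIndir | flagAddr // e.g. a value reached through a pointer
	return f.kind() == Struct && f&flagRO == 0
}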
// pointer returns the underlying pointer represented by v.
// v.Kind() must be Pointer, Map, Chan, Func, or UnsafePointer
func (v Value) pointer() unsafe.Pointer {
if v.typ.size != goarch.PtrSize || !v.typ.pointers() {
panic("can't call pointer on a non-pointer Value")
}
if v.flag&flagIndir != 0 {
return *(*unsafe.Pointer)(v.ptr)
}
return v.ptr
}
// packEface converts v to the empty interface.
func packEface(v Value) any {
t := v.typ
var i any
e := (*emptyInterface)(unsafe.Pointer(&i))
// First, fill in the data portion of the interface.
switch {
case ifaceIndir(t):
if v.flag&flagIndir == 0 {
panic("bad indir")
}
// Value is indirect, and so is the interface we're making.
ptr := v.ptr
if v.flag&flagAddr != 0 {
// TODO: pass safe boolean from valueInterface so
// we don't need to copy if safe==true?
c := unsafe_New(t)
typedmemmove(t, c, ptr)
ptr = c
}
e.word = ptr
case v.flag&flagIndir != 0:
// Value is indirect, but interface is direct. We need
// to load the data at v.ptr into the interface data word.
e.word = *(*unsafe.Pointer)(v.ptr)
default:
// Value is direct, and so is the interface.
e.word = v.ptr
}
// Now, fill in the type portion. We're very careful here not
// to have any operation between the e.word and e.typ assignments
// that would let the garbage collector observe the partially-built
// interface value.
e.typ = t
return i
}
// unpackEface converts the empty interface i to a Value.
func unpackEface(i any) Value {
e := (*emptyInterface)(unsafe.Pointer(&i))
// NOTE: don't read e.word until we know whether it is really a pointer or not.
t := e.typ
if t == nil {
return Value{}
}
f := flag(t.Kind())
if ifaceIndir(t) {
f |= flagIndir
}
return Value{t, e.word, f}
}
// A ValueError occurs when a Value method is invoked on
// a Value that does not support it. Such cases are documented
// in the description of each method.
type ValueError struct {
Method string
Kind Kind
}
func (e *ValueError) Error() string {
if e.Kind == 0 {
return "reflect: call of " + e.Method + " on zero Value"
}
return "reflect: call of " + e.Method + " on " + e.Kind.String() + " Value"
}
// methodName returns the name of the calling method,
// assumed to be two stack frames above.
func methodName() string {
pc, _, _, _ := runtime.Caller(2)
f := runtime.FuncForPC(pc)
if f == nil {
return "unknown method"
}
return f.Name()
}
// emptyInterface is the header for an interface{} value.
type emptyInterface struct {
typ *rtype
word unsafe.Pointer
}
// mustBeExported panics if f records that the value was obtained using
// an unexported field.
func (f flag) mustBeExported() {
if f == 0 {
panic(&ValueError{methodName(), 0})
}
if f&flagRO != 0 {
panic("reflect: " + methodName() + " using value obtained using unexported field")
}
}
// mustBeAssignable panics if f records that the value is not assignable,
// which is to say that either it was obtained using an unexported field
// or it is not addressable.
func (f flag) mustBeAssignable() {
if f == 0 {
panic(&ValueError{methodName(), Invalid})
}
// Assignable if addressable and not read-only.
if f&flagRO != 0 {
panic("reflect: " + methodName() + " using value obtained using unexported field")
}
if f&flagAddr == 0 {
panic("reflect: " + methodName() + " using unaddressable value")
}
}
// CanSet reports whether the value of v can be changed.
// A Value can be changed only if it is addressable and was not
// obtained by the use of unexported struct fields.
// If CanSet returns false, calling Set panics.
func (v Value) CanSet() bool {
return v.flag&(flagAddr|flagRO) == flagAddr
}
// Elem returns the value that the interface v contains
// or that the pointer v points to.
// It panics if v's Kind is not Interface or Pointer.
// It returns the zero Value if v is nil.
func (v Value) Elem() Value {
k := v.kind()
switch k {
case Interface:
var eface any
if v.typ.NumMethod() == 0 {
eface = *(*any)(v.ptr)
} else {
eface = (any)(*(*interface {
M()
})(v.ptr))
}
x := unpackEface(eface)
if x.flag != 0 {
x.flag |= v.flag.ro()
}
return x
case Pointer:
ptr := v.ptr
if v.flag&flagIndir != 0 {
ptr = *(*unsafe.Pointer)(ptr)
}
// The returned value's address is v's value.
if ptr == nil {
return Value{}
}
tt := (*ptrType)(unsafe.Pointer(v.typ))
typ := tt.elem
fl := v.flag&flagRO | flagIndir | flagAddr
fl |= flag(typ.Kind())
return Value{typ, ptr, fl}
}
panic(&ValueError{"reflectlite.Value.Elem", v.kind()})
}
func valueInterface(v Value) any {
if v.flag == 0 {
panic(&ValueError{"reflectlite.Value.Interface", 0})
}
if v.kind() == Interface {
// Special case: return the element inside the interface.
// Empty interface has one layout, all interfaces with
// methods have a second layout.
if v.numMethod() == 0 {
return *(*any)(v.ptr)
}
return *(*interface {
M()
})(v.ptr)
}
// TODO: pass safe to packEface so we don't need to copy if safe==true?
return packEface(v)
}
// IsNil reports whether its argument v is nil. The argument must be
// a chan, func, interface, map, pointer, or slice value; if it is
// not, IsNil panics. Note that IsNil is not always equivalent to a
// regular comparison with nil in Go. For example, if v was created
// by calling ValueOf with an uninitialized interface variable i,
// i==nil will be true but v.IsNil will panic as v will be the zero
// Value.
func (v Value) IsNil() bool {
k := v.kind()
switch k {
case Chan, Func, Map, Pointer, UnsafePointer:
// if v.flag&flagMethod != 0 {
// return false
// }
ptr := v.ptr
if v.flag&flagIndir != 0 {
ptr = *(*unsafe.Pointer)(ptr)
}
return ptr == nil
case Interface, Slice:
// Both interface and slice are nil if first word is 0.
// Both are always bigger than a word; assume flagIndir.
return *(*unsafe.Pointer)(v.ptr) == nil
}
panic(&ValueError{"reflectlite.Value.IsNil", v.kind()})
}
// IsValid reports whether v represents a value.
// It returns false if v is the zero Value.
// If IsValid returns false, all other methods except String panic.
// Most functions and methods never return an invalid Value.
// If one does, its documentation states the conditions explicitly.
func (v Value) IsValid() bool {
return v.flag != 0
}
// Kind returns v's Kind.
// If v is the zero Value (IsValid returns false), Kind returns Invalid.
func (v Value) Kind() Kind {
return v.kind()
}
// implemented in runtime:
func chanlen(unsafe.Pointer) int
func maplen(unsafe.Pointer) int
// Len returns v's length.
// It panics if v's Kind is not Array, Chan, Map, Slice, or String.
func (v Value) Len() int {
k := v.kind()
switch k {
case Array:
tt := (*arrayType)(unsafe.Pointer(v.typ))
return int(tt.len)
case Chan:
return chanlen(v.pointer())
case Map:
return maplen(v.pointer())
case Slice:
// Slice is bigger than a word; assume flagIndir.
return (*unsafeheader.Slice)(v.ptr).Len
case String:
// String is bigger than a word; assume flagIndir.
return (*unsafeheader.String)(v.ptr).Len
}
panic(&ValueError{"reflect.Value.Len", v.kind()})
}
// numMethod returns the number of exported methods in the value's method set.
func (v Value) numMethod() int {
if v.typ == nil {
panic(&ValueError{"reflectlite.Value.NumMethod", Invalid})
}
return v.typ.NumMethod()
}
// Set assigns x to the value v.
// It panics if CanSet returns false.
// As in Go, x's value must be assignable to v's type.
func (v Value) Set(x Value) {
v.mustBeAssignable()
x.mustBeExported() // do not let unexported x leak
var target unsafe.Pointer
if v.kind() == Interface {
target = v.ptr
}
x = x.assignTo("reflectlite.Set", v.typ, target)
if x.flag&flagIndir != 0 {
typedmemmove(v.typ, v.ptr, x.ptr)
} else {
*(*unsafe.Pointer)(v.ptr) = x.ptr
}
}
// Type returns v's type.
func (v Value) Type() Type {
f := v.flag
if f == 0 {
panic(&ValueError{"reflectlite.Value.Type", Invalid})
}
// Method values not supported.
return v.typ
}
/*
* constructors
*/
// implemented in package runtime
func unsafe_New(*rtype) unsafe.Pointer
// ValueOf returns a new Value initialized to the concrete value
// stored in the interface i. ValueOf(nil) returns the zero Value.
func ValueOf(i any) Value {
if i == nil {
return Value{}
}
// TODO: Maybe allow contents of a Value to live on the stack.
// For now we make the contents always escape to the heap. It
// makes life easier in a few places (see chanrecv/mapassign
// comment below).
escapes(i)
return unpackEface(i)
}
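// setThroughPointerSketch is a hypothetical usage sketch (not part of
// the original source): a Value is settable only when it is addressable
// and exported, which a caller arranges by passing a pointer and calling
// Elem, mirroring the CanSet/Set contract documented above.
func setThroughPointerSketch() int {
	x := 1
	v := ValueOf(&x).Elem() // addressable: reached through a pointer
	if v.CanSet() {
		v.Set(ValueOf(2))
	}
	return x // now 2
}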
// assignTo returns a value v that can be assigned directly to typ.
// It panics if v is not assignable to typ.
// For a conversion to an interface type, target is a suggested scratch space to use.
func (v Value) assignTo(context string, dst *rtype, target unsafe.Pointer) Value {
// if v.flag&flagMethod != 0 {
// v = makeMethodValue(context, v)
// }
switch {
case directlyAssignable(dst, v.typ):
// Overwrite type so that they match.
// Same memory layout, so no harm done.
fl := v.flag&(flagAddr|flagIndir) | v.flag.ro()
fl |= flag(dst.Kind())
return Value{dst, v.ptr, fl}
case implements(dst, v.typ):
if target == nil {
target = unsafe_New(dst)
}
if v.Kind() == Interface && v.IsNil() {
// A nil ReadWriter passed to nil Reader is OK,
// but using ifaceE2I below will panic.
// Avoid the panic by returning a nil dst (e.g., Reader) explicitly.
return Value{dst, nil, flag(Interface)}
}
x := valueInterface(v)
if dst.NumMethod() == 0 {
*(*any)(target) = x
} else {
ifaceE2I(dst, x, target)
}
return Value{dst, target, flagIndir | flag(Interface)}
}
// Failed.
panic(context + ": value of type " + v.typ.String() + " is not assignable to type " + dst.String())
}
// arrayAt returns the i-th element of p,
// an array whose elements are eltSize bytes wide.
// The array pointed at by p must have at least i+1 elements:
// it is invalid (but impossible to check here) to pass i >= len,
// because then the result will point outside the array.
// whySafe must explain why i < len. (Passing "i < len" is fine;
// the benefit is to surface this assumption at the call site.)
func arrayAt(p unsafe.Pointer, i int, eltSize uintptr, whySafe string) unsafe.Pointer {
return add(p, uintptr(i)*eltSize, "i < len")
}
func ifaceE2I(t *rtype, src any, dst unsafe.Pointer)
// typedmemmove copies a value of type t to dst from src.
//
//go:noescape
func typedmemmove(t *rtype, dst, src unsafe.Pointer)
// Dummy annotation marking that the value x escapes,
// for use in cases where the reflect code is so clever that
// the compiler cannot follow.
func escapes(x any) {
if dummy.b {
dummy.x = x
}
}
var dummy struct {
b bool
x any
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package safefilepath manipulates operating-system file paths.
package safefilepath
import (
"errors"
)
var errInvalidPath = errors.New("invalid path")
// FromFS converts a slash-separated path into an operating-system path.
//
// FromFS returns an error if the path cannot be represented by the operating
// system. For example, paths containing '\' and ':' characters are rejected
// on Windows.
func FromFS(path string) (string, error) {
return fromFS(path)
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !windows
package safefilepath
import "runtime"
func fromFS(path string) (string, error) {
if runtime.GOOS == "plan9" {
if len(path) > 0 && path[0] == '#' {
return "", errInvalidPath
}
}
for i := range path {
if path[i] == 0 {
return "", errInvalidPath
}
}
return path, nil
}
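// For example (a sketch, not part of the original source), a NUL byte
// is rejected on every platform:
//
//	_, err := FromFS("a/b\x00c") // err == errInvalidPath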
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package saferio provides I/O functions that avoid allocating large
// amounts of memory unnecessarily. This is intended for packages that
// read data from an [io.Reader] where the size is part of the input
// data but the input may be corrupt, or may be provided by an
// untrusted source.
package saferio
import (
"io"
"reflect"
)
// chunk is an arbitrary limit on how much memory we are willing
// to allocate without concern.
const chunk = 10 << 20 // 10M
// ReadData reads n bytes from the input stream, but avoids allocating
// all n bytes if n is large. This avoids crashing the program by
// allocating all n bytes in cases where n is incorrect.
//
// The error is io.EOF only if no bytes were read.
// If an io.EOF happens after reading some but not all the bytes,
// ReadData returns io.ErrUnexpectedEOF.
func ReadData(r io.Reader, n uint64) ([]byte, error) {
if int64(n) < 0 || n != uint64(int(n)) {
// n is too large to fit in int, so we can't allocate
// a buffer large enough. Treat this as a read failure.
return nil, io.ErrUnexpectedEOF
}
if n < chunk {
buf := make([]byte, n)
_, err := io.ReadFull(r, buf)
if err != nil {
return nil, err
}
return buf, nil
}
var buf []byte
buf1 := make([]byte, chunk)
for n > 0 {
next := n
if next > chunk {
next = chunk
}
_, err := io.ReadFull(r, buf1[:next])
if err != nil {
if len(buf) > 0 && err == io.EOF {
err = io.ErrUnexpectedEOF
}
return nil, err
}
buf = append(buf, buf1[:next]...)
n -= next
}
return buf, nil
}
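// readDataSketch is a hypothetical usage sketch (not part of the
// original source): even if a corrupt header claims a 1 TiB payload,
// at most one 10M chunk is allocated before the first short read
// surfaces io.ErrUnexpectedEOF, instead of the process attempting the
// full allocation up front.
func readDataSketch(r io.Reader) ([]byte, error) {
	return ReadData(r, 1<<40)
}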
// ReadDataAt reads n bytes from the input stream at off, but avoids
// allocating all n bytes if n is large. This avoids crashing the program
// by allocating all n bytes in cases where n is incorrect.
func ReadDataAt(r io.ReaderAt, n uint64, off int64) ([]byte, error) {
if int64(n) < 0 || n != uint64(int(n)) {
// n is too large to fit in int, so we can't allocate
// a buffer large enough. Treat this as a read failure.
return nil, io.ErrUnexpectedEOF
}
if n < chunk {
buf := make([]byte, n)
_, err := r.ReadAt(buf, off)
if err != nil {
// io.SectionReader can return EOF for n == 0,
// but for our purposes that is a success.
if err != io.EOF || n > 0 {
return nil, err
}
}
return buf, nil
}
var buf []byte
buf1 := make([]byte, chunk)
for n > 0 {
next := n
if next > chunk {
next = chunk
}
_, err := r.ReadAt(buf1[:next], off)
if err != nil {
return nil, err
}
buf = append(buf, buf1[:next]...)
n -= next
off += int64(next)
}
return buf, nil
}
// SliceCap returns the capacity to use when allocating a slice.
// After the slice is allocated with the capacity, it should be
// built using append. This will avoid allocating too much memory
// if the capacity is large and incorrect.
//
// A negative result means that the value is always too big.
//
// The element type is described by passing a pointer to a value of that type.
// This would ideally use generics, but this code is built with
// the bootstrap compiler which need not support generics.
// We use a pointer so that we can handle slices of interface type.
func SliceCap(v any, c uint64) int {
if int64(c) < 0 || c != uint64(int(c)) {
return -1
}
typ := reflect.TypeOf(v)
if typ.Kind() != reflect.Ptr {
panic("SliceCap called with non-pointer type")
}
size := uint64(typ.Elem().Size())
if size > 0 && c > (1<<64-1)/size {
return -1
}
if c*size > chunk {
c = uint64(chunk / size)
if c == 0 {
c = 1
}
}
return int(c)
}
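// makeInt32sSketch is a hypothetical usage sketch (not part of the
// original source): the claimed element count c comes from untrusted
// input, so the slice starts with a bounded capacity and is grown by
// append only as real data arrives.
func makeInt32sSketch(c uint64) ([]int32, bool) {
	n := SliceCap((*int32)(nil), c)
	if n < 0 {
		return nil, false // c can never be satisfied
	}
	return make([]int32, 0, n), true
}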
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package singleflight provides a duplicate function call suppression
// mechanism.
package singleflight
import "sync"
// call is an in-flight or completed singleflight.Do call
type call struct {
wg sync.WaitGroup
// These fields are written once before the WaitGroup is done
// and are only read after the WaitGroup is done.
val any
err error
// These fields are read and written with the singleflight
// mutex held before the WaitGroup is done, and are read but
// not written after the WaitGroup is done.
dups int
chans []chan<- Result
}
// Group represents a class of work and forms a namespace in
// which units of work can be executed with duplicate suppression.
type Group struct {
mu sync.Mutex // protects m
m map[string]*call // lazily initialized
}
// Result holds the results of Do, so they can be passed
// on a channel.
type Result struct {
Val any
Err error
Shared bool
}
// Do executes and returns the results of the given function, making
// sure that only one execution is in-flight for a given key at a
// time. If a duplicate comes in, the duplicate caller waits for the
// original to complete and receives the same results.
// The return value shared indicates whether v was given to multiple callers.
func (g *Group) Do(key string, fn func() (any, error)) (v any, err error, shared bool) {
g.mu.Lock()
if g.m == nil {
g.m = make(map[string]*call)
}
if c, ok := g.m[key]; ok {
c.dups++
g.mu.Unlock()
c.wg.Wait()
return c.val, c.err, true
}
c := new(call)
c.wg.Add(1)
g.m[key] = c
g.mu.Unlock()
g.doCall(c, key, fn)
return c.val, c.err, c.dups > 0
}
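// fetchOnceSketch is a hypothetical usage sketch (not part of the
// original source): concurrent callers passing the same key share a
// single execution of the function, and duplicates block in Do until
// the original call completes.
func fetchOnceSketch(g *Group, key string) (any, error) {
	v, err, _ := g.Do(key, func() (any, error) {
		// Expensive work runs at most once per in-flight key.
		return "value for " + key, nil
	})
	return v, err
}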
// DoChan is like Do but returns a channel that will receive the
// results when they are ready.
func (g *Group) DoChan(key string, fn func() (any, error)) <-chan Result {
ch := make(chan Result, 1)
g.mu.Lock()
if g.m == nil {
g.m = make(map[string]*call)
}
if c, ok := g.m[key]; ok {
c.dups++
c.chans = append(c.chans, ch)
g.mu.Unlock()
return ch
}
c := &call{chans: []chan<- Result{ch}}
c.wg.Add(1)
g.m[key] = c
g.mu.Unlock()
go g.doCall(c, key, fn)
return ch
}
// doCall handles the single call for a key.
func (g *Group) doCall(c *call, key string, fn func() (any, error)) {
c.val, c.err = fn()
g.mu.Lock()
c.wg.Done()
if g.m[key] == c {
delete(g.m, key)
}
for _, ch := range c.chans {
ch <- Result{c.val, c.err, c.dups > 0}
}
g.mu.Unlock()
}
// ForgetUnshared tells the singleflight to forget about a key if it is not
// shared with any other goroutines. Future calls to Do for a forgotten key
// will call the function rather than waiting for an earlier call to complete.
// Returns whether the key was forgotten or unknown--that is, whether no
// other goroutines are waiting for the result.
func (g *Group) ForgetUnshared(key string) bool {
g.mu.Lock()
defer g.mu.Unlock()
c, ok := g.m[key]
if !ok {
return true
}
if c.dups == 0 {
delete(g.m, key)
return true
}
return false
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package testenv
import (
"context"
"os"
"os/exec"
"runtime"
"strconv"
"strings"
"sync"
"testing"
"time"
)
// HasExec reports whether the current system can start new processes
// using os.StartProcess or (more commonly) exec.Command.
func HasExec() bool {
switch runtime.GOOS {
case "js", "ios":
return false
}
return true
}
// MustHaveExec checks that the current system can start new processes
// using os.StartProcess or (more commonly) exec.Command.
// If not, MustHaveExec calls t.Skip with an explanation.
func MustHaveExec(t testing.TB) {
if !HasExec() {
t.Skipf("skipping test: cannot exec subprocess on %s/%s", runtime.GOOS, runtime.GOARCH)
}
}
var execPaths sync.Map // path -> error
// MustHaveExecPath checks that the current system can start the named executable
// using os.StartProcess or (more commonly) exec.Command.
// If not, MustHaveExecPath calls t.Skip with an explanation.
func MustHaveExecPath(t testing.TB, path string) {
MustHaveExec(t)
err, found := execPaths.Load(path)
if !found {
_, err = exec.LookPath(path)
err, _ = execPaths.LoadOrStore(path, err)
}
if err != nil {
t.Skipf("skipping test: %s: %s", path, err)
}
}
// CleanCmdEnv will fill cmd.Env with the environment, excluding certain
// variables that could modify the behavior of the Go tools such as
// GODEBUG and GOTRACEBACK.
func CleanCmdEnv(cmd *exec.Cmd) *exec.Cmd {
if cmd.Env != nil {
panic("environment already set")
}
for _, env := range os.Environ() {
// Exclude GODEBUG from the environment to prevent its output
// from breaking tests that are trying to parse other command output.
if strings.HasPrefix(env, "GODEBUG=") {
continue
}
// Exclude GOTRACEBACK for the same reason.
if strings.HasPrefix(env, "GOTRACEBACK=") {
continue
}
cmd.Env = append(cmd.Env, env)
}
return cmd
}
// CommandContext is like exec.CommandContext, but:
// - skips t if the platform does not support os/exec,
// - sends SIGQUIT (if supported by the platform) instead of SIGKILL
// in its Cancel function
// - if the test has a deadline, adds a Context timeout and WaitDelay
// for an arbitrary grace period before the test's deadline expires,
// - fails the test if the command does not complete before the test's deadline, and
// - sets a Cleanup function that verifies that the test did not leak a subprocess.
func CommandContext(t testing.TB, ctx context.Context, name string, args ...string) *exec.Cmd {
t.Helper()
MustHaveExec(t)
var (
cancelCtx context.CancelFunc
gracePeriod time.Duration // unlimited unless the test has a deadline (to allow for interactive debugging)
)
if t, ok := t.(interface {
testing.TB
Deadline() (time.Time, bool)
}); ok {
if td, ok := t.Deadline(); ok {
// Start with a minimum grace period, just long enough to consume the
// output of a reasonable program after it terminates.
gracePeriod = 100 * time.Millisecond
if s := os.Getenv("GO_TEST_TIMEOUT_SCALE"); s != "" {
scale, err := strconv.Atoi(s)
if err != nil {
t.Fatalf("invalid GO_TEST_TIMEOUT_SCALE: %v", err)
}
gracePeriod *= time.Duration(scale)
}
// If time allows, increase the termination grace period to 5% of the
// test's remaining time.
testTimeout := time.Until(td)
if gp := testTimeout / 20; gp > gracePeriod {
gracePeriod = gp
}
// When we run commands that execute subprocesses, we want to reserve two
// grace periods to clean up: one for the delay between the first
// termination signal being sent (via the Cancel callback when the Context
// expires) and the process being forcibly terminated (via the WaitDelay
// field), and a second one for the delay between the process being
// terminated and the test logging its output for debugging.
//
// (We want to ensure that the test process itself has enough time to
// log the output before it is also terminated.)
cmdTimeout := testTimeout - 2*gracePeriod
if cd, ok := ctx.Deadline(); !ok || time.Until(cd) > cmdTimeout {
// Either ctx doesn't have a deadline, or its deadline would expire
// after (or too close before) the test has already timed out.
// Add a shorter timeout so that the test will produce useful output.
ctx, cancelCtx = context.WithTimeout(ctx, cmdTimeout)
}
}
}
cmd := exec.CommandContext(ctx, name, args...)
cmd.Cancel = func() error {
if cancelCtx != nil && ctx.Err() == context.DeadlineExceeded {
// The command timed out due to running too close to the test's deadline.
// There is no way the test did that intentionally — it's too close to the
// wire! — so mark it as a test failure. That way, if the test expects the
// command to fail for some other reason, it doesn't have to distinguish
// between that reason and a timeout.
t.Errorf("test timed out while running command: %v", cmd)
} else {
// The command is being terminated due to ctx being canceled, but
// apparently not due to an explicit test deadline that we added.
// Log that information in case it is useful for diagnosing a failure,
// but don't actually fail the test because of it.
t.Logf("%v: terminating command: %v", ctx.Err(), cmd)
}
return cmd.Process.Signal(Sigquit)
}
cmd.WaitDelay = gracePeriod
t.Cleanup(func() {
if cancelCtx != nil {
cancelCtx()
}
if cmd.Process != nil && cmd.ProcessState == nil {
t.Errorf("command was started, but test did not wait for it to complete: %v", cmd)
}
})
return cmd
}
// Command is like exec.Command, but applies the same changes as
// testenv.CommandContext (with a default Context).
func Command(t testing.TB, name string, args ...string) *exec.Cmd {
t.Helper()
return CommandContext(t, context.Background(), name, args...)
}
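// A hypothetical usage sketch (not part of the original source): inside
// a test, Command wires up the deadline-aware cancellation and leak
// checking described above before the subprocess runs.
//
//	func TestGoVersion(t *testing.T) {
//		cmd := Command(t, GoToolPath(t), "version")
//		out, err := cmd.Output()
//		if err != nil {
//			t.Fatal(err)
//		}
//		t.Logf("go version output: %s", out)
//	}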
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !noopt
package testenv
// OptimizationOff reports whether optimization is disabled.
func OptimizationOff() bool {
return false
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package testenv provides information about what functionality
// is available in different testing environments run by the Go team.
//
// It is an internal package because these details are specific
// to the Go team's test setup (on build.golang.org) and not
// fundamental to tests in general.
package testenv
import (
"bytes"
"errors"
"flag"
"fmt"
"internal/cfg"
"internal/platform"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"testing"
)
// Builder reports the name of the builder running this test
// (for example, "linux-amd64" or "windows-386-gce").
// If the test is not running on the build infrastructure,
// Builder returns the empty string.
func Builder() string {
return os.Getenv("GO_BUILDER_NAME")
}
// HasGoBuild reports whether the current system can build programs with “go build”
// and then run them with os.StartProcess or exec.Command.
func HasGoBuild() bool {
if os.Getenv("GO_GCFLAGS") != "" {
// It's too much work to require every caller of the go command
// to pass along "-gcflags="+os.Getenv("GO_GCFLAGS").
// For now, if $GO_GCFLAGS is set, report that we simply can't
// run go build.
return false
}
switch runtime.GOOS {
case "android", "js", "ios":
return false
}
return true
}
// MustHaveGoBuild checks that the current system can build programs with “go build”
// and then run them with os.StartProcess or exec.Command.
// If not, MustHaveGoBuild calls t.Skip with an explanation.
func MustHaveGoBuild(t testing.TB) {
if os.Getenv("GO_GCFLAGS") != "" {
t.Skipf("skipping test: 'go build' not compatible with setting $GO_GCFLAGS")
}
if !HasGoBuild() {
t.Skipf("skipping test: 'go build' not available on %s/%s", runtime.GOOS, runtime.GOARCH)
}
}
// HasGoRun reports whether the current system can run programs with “go run”.
func HasGoRun() bool {
// For now, having go run and having go build are the same.
return HasGoBuild()
}
// MustHaveGoRun checks that the current system can run programs with “go run”.
// If not, MustHaveGoRun calls t.Skip with an explanation.
func MustHaveGoRun(t testing.TB) {
if !HasGoRun() {
t.Skipf("skipping test: 'go run' not available on %s/%s", runtime.GOOS, runtime.GOARCH)
}
}
// GoToolPath reports the path to the Go tool.
// It is a convenience wrapper around GoTool.
// If the tool is unavailable GoToolPath calls t.Skip.
// If the tool should be available and isn't, GoToolPath calls t.Fatal.
func GoToolPath(t testing.TB) string {
MustHaveGoBuild(t)
path, err := GoTool()
if err != nil {
t.Fatal(err)
}
// Add all environment variables that affect the Go command to test metadata.
// Cached test results will be invalidated when these variables change.
// See golang.org/issue/32285.
for _, envVar := range strings.Fields(cfg.KnownEnv) {
os.Getenv(envVar)
}
return path
}
var (
gorootOnce sync.Once
gorootPath string
gorootErr error
)
func findGOROOT() (string, error) {
gorootOnce.Do(func() {
gorootPath = runtime.GOROOT()
if gorootPath != "" {
// If runtime.GOROOT() is non-empty, assume that it is valid.
//
// (It might not be: for example, the user may have explicitly set GOROOT
// to the wrong directory, or explicitly set GOROOT_FINAL but not GOROOT
// and hasn't moved the tree to GOROOT_FINAL yet. But those cases are
// rare, and if that happens the user can fix what they broke.)
return
}
// runtime.GOROOT doesn't know where GOROOT is (perhaps because the test
// binary was built with -trimpath, or perhaps because GOROOT_FINAL was set
// without GOROOT and the tree hasn't been moved there yet).
//
// Since this is internal/testenv, we can cheat and assume that the caller
// is a test of some package in a subdirectory of GOROOT/src. ('go test'
// runs the test in the directory containing the package under test.) That
// means that if we start walking up the tree, we should eventually find
// GOROOT/src/go.mod, and we can report the parent directory of that.
cwd, err := os.Getwd()
if err != nil {
gorootErr = fmt.Errorf("finding GOROOT: %w", err)
return
}
dir := cwd
for {
parent := filepath.Dir(dir)
if parent == dir {
// dir is either "." or only a volume name.
gorootErr = fmt.Errorf("failed to locate GOROOT/src in any parent directory")
return
}
if base := filepath.Base(dir); base != "src" {
dir = parent
continue // dir cannot be GOROOT/src if it doesn't end in "src".
}
b, err := os.ReadFile(filepath.Join(dir, "go.mod"))
if err != nil {
if os.IsNotExist(err) {
dir = parent
continue
}
gorootErr = fmt.Errorf("finding GOROOT: %w", err)
return
}
goMod := string(b)
for goMod != "" {
var line string
line, goMod, _ = strings.Cut(goMod, "\n")
fields := strings.Fields(line)
if len(fields) >= 2 && fields[0] == "module" && fields[1] == "std" {
// Found "module std", which is the module declaration in GOROOT/src!
gorootPath = parent
return
}
}
}
})
return gorootPath, gorootErr
}
// GOROOT reports the path to the directory containing the root of the Go
// project source tree. This is normally equivalent to runtime.GOROOT, but
// works even if the test binary was built with -trimpath.
//
// If GOROOT cannot be found, GOROOT skips t if t is non-nil,
// or panics otherwise.
func GOROOT(t testing.TB) string {
path, err := findGOROOT()
if err != nil {
if t == nil {
panic(err)
}
t.Helper()
t.Skip(err)
}
return path
}
// GoTool reports the path to the Go tool.
func GoTool() (string, error) {
goToolOnce.Do(func() {
goToolPath, goToolErr = func() (string, error) {
if !HasGoBuild() {
return "", errors.New("platform cannot run go tool")
}
var exeSuffix string
if runtime.GOOS == "windows" {
exeSuffix = ".exe"
}
goroot, err := findGOROOT()
if err != nil {
return "", fmt.Errorf("cannot find go tool: %w", err)
}
path := filepath.Join(goroot, "bin", "go"+exeSuffix)
if _, err := os.Stat(path); err == nil {
return path, nil
}
goBin, err := exec.LookPath("go" + exeSuffix)
if err != nil {
return "", errors.New("cannot find go tool: " + err.Error())
}
return goBin, nil
}()
})
return goToolPath, goToolErr
}
var (
goToolOnce sync.Once
goToolPath string
goToolErr error
)
// HasSrc reports whether the entire source tree is available under GOROOT.
func HasSrc() bool {
switch runtime.GOOS {
case "ios":
return false
}
return true
}
// HasExternalNetwork reports whether the current system can use
// external (non-localhost) networks.
func HasExternalNetwork() bool {
return !testing.Short() && runtime.GOOS != "js"
}
// MustHaveExternalNetwork checks that the current system can use
// external (non-localhost) networks.
// If not, MustHaveExternalNetwork calls t.Skip with an explanation.
func MustHaveExternalNetwork(t testing.TB) {
if runtime.GOOS == "js" {
t.Skipf("skipping test: no external network on %s", runtime.GOOS)
}
if testing.Short() {
t.Skipf("skipping test: no external network in -short mode")
}
}
// HasCGO reports whether the current system can use cgo.
func HasCGO() bool {
hasCgoOnce.Do(func() {
goTool, err := GoTool()
if err != nil {
return
}
cmd := exec.Command(goTool, "env", "CGO_ENABLED")
out, err := cmd.Output()
if err != nil {
panic(fmt.Sprintf("%v: %v", cmd, out))
}
hasCgo, err = strconv.ParseBool(string(bytes.TrimSpace(out)))
if err != nil {
panic(fmt.Sprintf("%v: non-boolean output %q", cmd, out))
}
})
return hasCgo
}
var (
hasCgoOnce sync.Once
hasCgo bool
)
// MustHaveCGO calls t.Skip if cgo is not available.
func MustHaveCGO(t testing.TB) {
if !HasCGO() {
t.Skipf("skipping test: no cgo")
}
}
// CanInternalLink reports whether the current system can link programs with
// internal linking.
func CanInternalLink(withCgo bool) bool {
return !platform.MustLinkExternal(runtime.GOOS, runtime.GOARCH, withCgo)
}
// MustInternalLink checks that the current system can link programs with internal
// linking.
// If not, MustInternalLink calls t.Skip with an explanation.
func MustInternalLink(t testing.TB, withCgo bool) {
if !CanInternalLink(withCgo) {
if withCgo && CanInternalLink(false) {
t.Skipf("skipping test: internal linking on %s/%s is not supported with cgo", runtime.GOOS, runtime.GOARCH)
}
t.Skipf("skipping test: internal linking on %s/%s is not supported", runtime.GOOS, runtime.GOARCH)
}
}
// HasSymlink reports whether the current system can use os.Symlink.
func HasSymlink() bool {
ok, _ := hasSymlink()
return ok
}
// MustHaveSymlink reports whether the current system can use os.Symlink.
// If not, MustHaveSymlink calls t.Skip with an explanation.
func MustHaveSymlink(t testing.TB) {
ok, reason := hasSymlink()
if !ok {
t.Skipf("skipping test: cannot make symlinks on %s/%s%s", runtime.GOOS, runtime.GOARCH, reason)
}
}
// HasLink reports whether the current system can use os.Link.
func HasLink() bool {
// From Android release M (Marshmallow), hard linking files is blocked
// and an attempt to call link() on a file will return EACCES.
// - https://code.google.com/p/android-developer-preview/issues/detail?id=3150
return runtime.GOOS != "plan9" && runtime.GOOS != "android"
}
// MustHaveLink reports whether the current system can use os.Link.
// If not, MustHaveLink calls t.Skip with an explanation.
func MustHaveLink(t testing.TB) {
if !HasLink() {
t.Skipf("skipping test: hardlinks are not supported on %s/%s", runtime.GOOS, runtime.GOARCH)
}
}
var flaky = flag.Bool("flaky", false, "run known-flaky tests too")
func SkipFlaky(t testing.TB, issue int) {
t.Helper()
if !*flaky {
t.Skipf("skipping known flaky test without the -flaky flag; see golang.org/issue/%d", issue)
}
}
func SkipFlakyNet(t testing.TB) {
t.Helper()
if v, _ := strconv.ParseBool(os.Getenv("GO_BUILDER_FLAKY_NET")); v {
t.Skip("skipping test on builder known to have frequent network failures")
}
}
// CPUIsSlow reports whether the CPU running the test is suspected to be slow.
func CPUIsSlow() bool {
switch runtime.GOARCH {
case "arm", "mips", "mipsle", "mips64", "mips64le":
return true
}
return false
}
// SkipIfShortAndSlow skips t if -short is set and the CPU running the test is
// suspected to be slow.
//
// (This is useful for CPU-intensive tests that otherwise complete quickly.)
func SkipIfShortAndSlow(t testing.TB) {
if testing.Short() && CPUIsSlow() {
t.Helper()
t.Skipf("skipping test in -short mode on %s", runtime.GOARCH)
}
}
// SkipIfOptimizationOff skips t if optimization is disabled.
func SkipIfOptimizationOff(t testing.TB) {
if OptimizationOff() {
t.Helper()
t.Skip("skipping test with optimization disabled")
}
}
// WriteImportcfg writes an importcfg file used by the compiler or linker to
// dstPath containing entries for the file mappings in packageFiles, as well
// as for the packages transitively imported by the package(s) in pkgs.
//
// pkgs may include any package pattern that is valid to pass to 'go list',
// so it may also be a list of Go source files all in the same directory.
func WriteImportcfg(t testing.TB, dstPath string, packageFiles map[string]string, pkgs ...string) {
t.Helper()
icfg := new(bytes.Buffer)
icfg.WriteString("# import config\n")
for k, v := range packageFiles {
fmt.Fprintf(icfg, "packagefile %s=%s\n", k, v)
}
if len(pkgs) > 0 {
// Use 'go list' to resolve any missing packages and rewrite the import map.
cmd := Command(t, GoToolPath(t), "list", "-export", "-deps", "-f", `{{if ne .ImportPath "command-line-arguments"}}{{if .Export}}{{.ImportPath}}={{.Export}}{{end}}{{end}}`)
cmd.Args = append(cmd.Args, pkgs...)
cmd.Stderr = new(strings.Builder)
out, err := cmd.Output()
if err != nil {
t.Fatalf("%v: %v\n%s", cmd, err, cmd.Stderr)
}
for _, line := range strings.Split(string(out), "\n") {
if line == "" {
continue
}
importPath, export, ok := strings.Cut(line, "=")
if !ok {
t.Fatalf("invalid line in output from %v:\n%s", cmd, line)
}
if packageFiles[importPath] == "" {
fmt.Fprintf(icfg, "packagefile %s=%s\n", importPath, export)
}
}
}
if err := os.WriteFile(dstPath, icfg.Bytes(), 0666); err != nil {
t.Fatal(err)
}
}
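// exampleWriteImportcfg is an illustrative sketch, not part of the original
// source: it shows one hypothetical way a test might call WriteImportcfg,
// mapping a made-up package "p" to a prebuilt archive "p.a" and letting
// 'go list' resolve the transitive imports of "fmt". The names here are
// assumptions for illustration only.
func exampleWriteImportcfg(t testing.TB, dir string) {
	WriteImportcfg(t, dir+"/importcfg", map[string]string{"p": "p.a"}, "fmt")
}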
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !windows
package testenv
import (
"runtime"
)
func hasSymlink() (ok bool, reason string) {
switch runtime.GOOS {
case "android", "plan9":
return false, ""
}
return true, ""
}
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package trace
import (
"container/heap"
"math"
"sort"
"strings"
"time"
)
// MutatorUtil is a change in mutator utilization at a particular
// time. Mutator utilization functions are represented as a
// time-ordered []MutatorUtil.
type MutatorUtil struct {
Time int64
// Util is the mean mutator utilization starting at Time. This
// is in the range [0, 1].
Util float64
}
// UtilFlags controls the behavior of MutatorUtilization.
type UtilFlags int
const (
// UtilSTW means utilization should account for STW events.
UtilSTW UtilFlags = 1 << iota
// UtilBackground means utilization should account for
// background mark workers.
UtilBackground
// UtilAssist means utilization should account for mark
// assists.
UtilAssist
// UtilSweep means utilization should account for sweeping.
UtilSweep
// UtilPerProc means each P should be given a separate
// utilization function. Otherwise, there is a single function
// and each P is given a fraction of the utilization.
UtilPerProc
)
// MutatorUtilization returns a set of mutator utilization functions
// for the given trace. Each function will always end with 0
// utilization. The bounds of each function are implicit in the first
// and last event; outside of these bounds each function is undefined.
//
// If the UtilPerProc flag is not given, this always returns a single
// utilization function. Otherwise, it returns one function per P.
func MutatorUtilization(events []*Event, flags UtilFlags) [][]MutatorUtil {
if len(events) == 0 {
return nil
}
type perP struct {
// gc > 0 indicates that GC is active on this P.
gc int
// series is the logical series number for this P. This
// is necessary because a P may be removed and then
// re-added, in which case the new P needs a new series.
series int
}
ps := []perP{}
stw := 0
out := [][]MutatorUtil{}
assists := map[uint64]bool{}
block := map[uint64]*Event{}
bgMark := map[uint64]bool{}
for _, ev := range events {
switch ev.Type {
case EvGomaxprocs:
gomaxprocs := int(ev.Args[0])
if len(ps) > gomaxprocs {
if flags&UtilPerProc != 0 {
// End each P's series.
for _, p := range ps[gomaxprocs:] {
out[p.series] = addUtil(out[p.series], MutatorUtil{ev.Ts, 0})
}
}
ps = ps[:gomaxprocs]
}
for len(ps) < gomaxprocs {
// Start new P's series.
series := 0
if flags&UtilPerProc != 0 || len(out) == 0 {
series = len(out)
out = append(out, []MutatorUtil{{ev.Ts, 1}})
}
ps = append(ps, perP{series: series})
}
case EvGCSTWStart:
if flags&UtilSTW != 0 {
stw++
}
case EvGCSTWDone:
if flags&UtilSTW != 0 {
stw--
}
case EvGCMarkAssistStart:
if flags&UtilAssist != 0 {
ps[ev.P].gc++
assists[ev.G] = true
}
case EvGCMarkAssistDone:
if flags&UtilAssist != 0 {
ps[ev.P].gc--
delete(assists, ev.G)
}
case EvGCSweepStart:
if flags&UtilSweep != 0 {
ps[ev.P].gc++
}
case EvGCSweepDone:
if flags&UtilSweep != 0 {
ps[ev.P].gc--
}
case EvGoStartLabel:
if flags&UtilBackground != 0 && strings.HasPrefix(ev.SArgs[0], "GC ") && ev.SArgs[0] != "GC (idle)" {
// Background mark worker.
//
// If we're in per-proc mode, we don't
// count dedicated workers because
// they kick all of the goroutines off
// that P, so they don't directly
// contribute to goroutine latency.
if !(flags&UtilPerProc != 0 && ev.SArgs[0] == "GC (dedicated)") {
bgMark[ev.G] = true
ps[ev.P].gc++
}
}
fallthrough
case EvGoStart:
if assists[ev.G] {
// Unblocked during assist.
ps[ev.P].gc++
}
block[ev.G] = ev.Link
default:
if ev != block[ev.G] {
continue
}
if assists[ev.G] {
// Blocked during assist.
ps[ev.P].gc--
}
if bgMark[ev.G] {
// Background mark worker done.
ps[ev.P].gc--
delete(bgMark, ev.G)
}
delete(block, ev.G)
}
if flags&UtilPerProc == 0 {
// Compute the current average utilization.
if len(ps) == 0 {
continue
}
gcPs := 0
if stw > 0 {
gcPs = len(ps)
} else {
for i := range ps {
if ps[i].gc > 0 {
gcPs++
}
}
}
mu := MutatorUtil{ev.Ts, 1 - float64(gcPs)/float64(len(ps))}
// Record the utilization change. (Since
// len(ps) == len(out), we know len(out) > 0.)
out[0] = addUtil(out[0], mu)
} else {
// Check for per-P utilization changes.
for i := range ps {
p := &ps[i]
util := 1.0
if stw > 0 || p.gc > 0 {
util = 0.0
}
out[p.series] = addUtil(out[p.series], MutatorUtil{ev.Ts, util})
}
}
}
// Add final 0 utilization event to any remaining series. This
// is important to mark the end of the trace. The exact value
// shouldn't matter since no window should extend beyond this,
// but using 0 is symmetric with the start of the trace.
mu := MutatorUtil{events[len(events)-1].Ts, 0}
for i := range ps {
out[ps[i].series] = addUtil(out[ps[i].series], mu)
}
return out
}
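// exampleMutatorUtilization is an illustrative sketch, not part of the
// original source: it computes a single aggregate utilization function that
// accounts for STW pauses, background mark workers, assists, and sweeping.
// The events argument is assumed to be a parsed, ordered trace.
func exampleMutatorUtilization(events []*Event) []MutatorUtil {
	mu := MutatorUtilization(events, UtilSTW|UtilBackground|UtilAssist|UtilSweep)
	if len(mu) == 0 {
		return nil // no events or no Ps observed in the trace
	}
	return mu[0] // without UtilPerProc there is exactly one function
}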
func addUtil(util []MutatorUtil, mu MutatorUtil) []MutatorUtil {
if len(util) > 0 {
if mu.Util == util[len(util)-1].Util {
// No change.
return util
}
if mu.Time == util[len(util)-1].Time {
// Take the lowest utilization at a time stamp.
if mu.Util < util[len(util)-1].Util {
util[len(util)-1] = mu
}
return util
}
}
return append(util, mu)
}
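// exampleAddUtil is an illustrative sketch, not part of the original source,
// showing addUtil's coalescing rules on hypothetical points.
func exampleAddUtil() []MutatorUtil {
	u := []MutatorUtil{{0, 0.5}}
	u = addUtil(u, MutatorUtil{10, 0.5}) // dropped: utilization unchanged
	u = addUtil(u, MutatorUtil{10, 0.2}) // appended: utilization changed
	u = addUtil(u, MutatorUtil{10, 0.1}) // replaces: lower value at same time
	return u                             // [{0, 0.5}, {10, 0.1}]
}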
// totalUtil is total utilization, measured in nanoseconds. This is a
// separate type primarily to distinguish it from mean utilization,
// which is also a float64.
type totalUtil float64
func totalUtilOf(meanUtil float64, dur int64) totalUtil {
return totalUtil(meanUtil * float64(dur))
}
// mean returns the mean utilization over dur.
func (u totalUtil) mean(dur time.Duration) float64 {
return float64(u) / float64(dur)
}
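// exampleTotalUtil is an illustrative sketch, not part of the original
// source: 0.5 mean utilization over one second accumulates to 0.5e9
// utilization-nanoseconds, and mean recovers 0.5.
func exampleTotalUtil() float64 {
	u := totalUtilOf(0.5, int64(time.Second))
	return u.mean(time.Second) // 0.5
}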
// An MMUCurve is the minimum mutator utilization curve across
// multiple window sizes.
type MMUCurve struct {
series []mmuSeries
}
type mmuSeries struct {
util []MutatorUtil
// sums[j] is the cumulative sum of util[:j].
sums []totalUtil
// bands summarizes util in non-overlapping bands of duration
// bandDur.
bands []mmuBand
// bandDur is the duration of each band.
bandDur int64
}
type mmuBand struct {
// minUtil is the minimum instantaneous mutator utilization in
// this band.
minUtil float64
// cumUtil is the cumulative total mutator utilization between
// time 0 and the left edge of this band.
cumUtil totalUtil
// integrator is the integrator for the left edge of this
// band.
integrator integrator
}
// NewMMUCurve returns an MMU curve for the given mutator utilization
// function.
func NewMMUCurve(utils [][]MutatorUtil) *MMUCurve {
series := make([]mmuSeries, len(utils))
for i, util := range utils {
series[i] = newMMUSeries(util)
}
return &MMUCurve{series}
}
// bandsPerSeries is the number of bands to divide each series into.
// This is only changed by tests.
var bandsPerSeries = 1000
func newMMUSeries(util []MutatorUtil) mmuSeries {
// Compute cumulative sum.
sums := make([]totalUtil, len(util))
var prev MutatorUtil
var sum totalUtil
for j, u := range util {
sum += totalUtilOf(prev.Util, u.Time-prev.Time)
sums[j] = sum
prev = u
}
// Divide the utilization curve up into equal size
// non-overlapping "bands" and compute a summary for each of
// these bands.
//
// Compute the duration of each band.
numBands := bandsPerSeries
if numBands > len(util) {
// There's no point in having lots of bands if there
// aren't many events.
numBands = len(util)
}
dur := util[len(util)-1].Time - util[0].Time
bandDur := (dur + int64(numBands) - 1) / int64(numBands)
if bandDur < 1 {
bandDur = 1
}
// Compute the bands. There are numBands+1 bands in order to
// record the final cumulative sum.
bands := make([]mmuBand, numBands+1)
s := mmuSeries{util, sums, bands, bandDur}
leftSum := integrator{&s, 0}
for i := range bands {
startTime, endTime := s.bandTime(i)
cumUtil := leftSum.advance(startTime)
predIdx := leftSum.pos
minUtil := 1.0
for j := predIdx; j < len(util) && util[j].Time < endTime; j++ {
minUtil = math.Min(minUtil, util[j].Util)
}
bands[i] = mmuBand{minUtil, cumUtil, leftSum}
}
return s
}
func (s *mmuSeries) bandTime(i int) (start, end int64) {
start = int64(i)*s.bandDur + s.util[0].Time
end = start + s.bandDur
return
}
type bandUtil struct {
// Utilization series index
series int
// Band index
i int
// Lower bound of mutator utilization for all windows
// with a left edge in this band.
utilBound float64
}
type bandUtilHeap []bandUtil
func (h bandUtilHeap) Len() int {
return len(h)
}
func (h bandUtilHeap) Less(i, j int) bool {
return h[i].utilBound < h[j].utilBound
}
func (h bandUtilHeap) Swap(i, j int) {
h[i], h[j] = h[j], h[i]
}
func (h *bandUtilHeap) Push(x any) {
*h = append(*h, x.(bandUtil))
}
func (h *bandUtilHeap) Pop() any {
x := (*h)[len(*h)-1]
*h = (*h)[:len(*h)-1]
return x
}
// UtilWindow is a specific window at Time.
type UtilWindow struct {
Time int64
// MutatorUtil is the mean mutator utilization in this window.
MutatorUtil float64
}
type utilHeap []UtilWindow
func (h utilHeap) Len() int {
return len(h)
}
func (h utilHeap) Less(i, j int) bool {
if h[i].MutatorUtil != h[j].MutatorUtil {
return h[i].MutatorUtil > h[j].MutatorUtil
}
return h[i].Time > h[j].Time
}
func (h utilHeap) Swap(i, j int) {
h[i], h[j] = h[j], h[i]
}
func (h *utilHeap) Push(x any) {
*h = append(*h, x.(UtilWindow))
}
func (h *utilHeap) Pop() any {
x := (*h)[len(*h)-1]
*h = (*h)[:len(*h)-1]
return x
}
// An accumulator takes a windowed mutator utilization function and
// tracks various statistics for that function.
type accumulator struct {
mmu float64
// bound is the mutator utilization bound where adding any
// mutator utilization above this bound cannot affect the
// accumulated statistics.
bound float64
// Worst N window tracking
nWorst int
wHeap utilHeap
// Mutator utilization distribution tracking
mud *mud
// preciseMass is the distribution mass that must be precise
// before accumulation is stopped.
preciseMass float64
// lastTime and lastMU are the previous point added to the
// windowed mutator utilization function.
lastTime int64
lastMU float64
}
// resetTime declares a discontinuity in the windowed mutator
// utilization function by resetting the current time.
func (acc *accumulator) resetTime() {
// This only matters for distribution collection, since that's
// the only thing that depends on the progression of the
// windowed mutator utilization function.
acc.lastTime = math.MaxInt64
}
// addMU adds a point to the windowed mutator utilization function at
// (time, mu). This must be called for monotonically increasing values
// of time.
//
// It returns true if further calls to addMU would be pointless.
func (acc *accumulator) addMU(time int64, mu float64, window time.Duration) bool {
if mu < acc.mmu {
acc.mmu = mu
}
acc.bound = acc.mmu
if acc.nWorst == 0 {
// If the minimum has reached zero, it can't go any
// lower, so we can stop early.
return mu == 0
}
// Consider adding this window to the n worst.
if len(acc.wHeap) < acc.nWorst || mu < acc.wHeap[0].MutatorUtil {
// This window is lower than the N'th worst window.
//
// Check if there's any overlapping window
// already in the heap and keep whichever is
// worse.
for i, ui := range acc.wHeap {
if time+int64(window) > ui.Time && ui.Time+int64(window) > time {
if ui.MutatorUtil <= mu {
// Keep the first window.
goto keep
} else {
// Replace it with this window.
heap.Remove(&acc.wHeap, i)
break
}
}
}
heap.Push(&acc.wHeap, UtilWindow{time, mu})
if len(acc.wHeap) > acc.nWorst {
heap.Pop(&acc.wHeap)
}
keep:
}
if len(acc.wHeap) < acc.nWorst {
// We don't have N windows yet, so keep accumulating.
acc.bound = 1.0
} else {
// Anything above the least worst window has no effect.
acc.bound = math.Max(acc.bound, acc.wHeap[0].MutatorUtil)
}
if acc.mud != nil {
if acc.lastTime != math.MaxInt64 {
// Update distribution.
acc.mud.add(acc.lastMU, mu, float64(time-acc.lastTime))
}
acc.lastTime, acc.lastMU = time, mu
if _, mudBound, ok := acc.mud.approxInvCumulativeSum(); ok {
acc.bound = math.Max(acc.bound, mudBound)
} else {
// We haven't accumulated enough total precise
// mass yet to even reach our goal, so keep
// accumulating.
acc.bound = 1
}
// It's not worth checking percentiles every time, so
// just keep accumulating this band.
return false
}
// If we've found enough 0 utilizations, we can stop immediately.
return len(acc.wHeap) == acc.nWorst && acc.wHeap[0].MutatorUtil == 0
}
// MMU returns the minimum mutator utilization for the given time
// window. This is the minimum utilization for all windows of this
// duration across the execution. The returned value is in the range
// [0, 1].
func (c *MMUCurve) MMU(window time.Duration) (mmu float64) {
acc := accumulator{mmu: 1.0, bound: 1.0}
c.mmu(window, &acc)
return acc.mmu
}
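// exampleMMU is an illustrative sketch, not part of the original source:
// the minimum mean mutator utilization over any 10ms window of a trace,
// given a utilization function from MutatorUtilization.
func exampleMMU(mu [][]MutatorUtil) float64 {
	c := NewMMUCurve(mu)
	return c.MMU(10 * time.Millisecond)
}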
// Examples returns n specific examples of the lowest mutator
// utilization for the given window size. The returned windows will be
// disjoint (otherwise there would be a huge number of
// mostly-overlapping windows at the single lowest point). There are
// no guarantees on which set of disjoint windows this returns.
func (c *MMUCurve) Examples(window time.Duration, n int) (worst []UtilWindow) {
acc := accumulator{mmu: 1.0, bound: 1.0, nWorst: n}
c.mmu(window, &acc)
sort.Sort(sort.Reverse(acc.wHeap))
return ([]UtilWindow)(acc.wHeap)
}
// MUD returns mutator utilization distribution quantiles for the
// given window size.
//
// The mutator utilization distribution is the distribution of mean
// mutator utilization across all windows of the given window size in
// the trace.
//
// The minimum mutator utilization is the minimum (0th percentile) of
// this distribution. (However, if only the minimum is desired, it's
// more efficient to use the MMU method.)
func (c *MMUCurve) MUD(window time.Duration, quantiles []float64) []float64 {
if len(quantiles) == 0 {
return []float64{}
}
// Each unrefined band contributes a known total mass to the
// distribution (bandDur except at the end), but in an unknown
// way. However, we know that all the mass it contributes must
// be at or above its worst-case mean mutator utilization.
//
// Hence, we refine bands until the highest desired
// distribution quantile is less than the next worst-case mean
// mutator utilization. At this point, all further
// contributions to the distribution must be beyond the
// desired quantile and hence cannot affect it.
//
// First, find the highest desired distribution quantile.
maxQ := quantiles[0]
for _, q := range quantiles {
if q > maxQ {
maxQ = q
}
}
// The distribution's mass is in units of time (it's not
// normalized because this would make it more annoying to
// account for future contributions of unrefined bands). The
// total final mass will be the duration of the trace itself
// minus the window size. Using this, we can compute the mass
// corresponding to quantile maxQ.
var duration int64
for _, s := range c.series {
duration1 := s.util[len(s.util)-1].Time - s.util[0].Time
if duration1 >= int64(window) {
duration += duration1 - int64(window)
}
}
qMass := float64(duration) * maxQ
// Accumulate the MUD until we have precise information for
// everything to the left of qMass.
acc := accumulator{mmu: 1.0, bound: 1.0, preciseMass: qMass, mud: new(mud)}
acc.mud.setTrackMass(qMass)
c.mmu(window, &acc)
// Evaluate the quantiles on the accumulated MUD.
out := make([]float64, len(quantiles))
for i := range out {
mu, _ := acc.mud.invCumulativeSum(float64(duration) * quantiles[i])
if math.IsNaN(mu) {
// There are a few legitimate ways this can
// happen:
//
// 1. If the window is the full trace
// duration, then the windowed MU function is
// only defined at a single point, so the MU
// distribution is not well-defined.
//
// 2. If there are no events, then the MU
// distribution has no mass.
//
// Either way, all of the quantiles will have
// converged toward the MMU at this point.
mu = acc.mmu
}
out[i] = mu
}
return out
}
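// exampleMUD is an illustrative sketch, not part of the original source:
// querying the median (0.5) and minimum (0) of the distribution of mean
// mutator utilization across all 10ms windows. The minimum quantile equals
// the MMU, though the MMU method is the cheaper way to get it.
func exampleMUD(c *MMUCurve) []float64 {
	return c.MUD(10*time.Millisecond, []float64{0.5, 0})
}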
func (c *MMUCurve) mmu(window time.Duration, acc *accumulator) {
if window <= 0 {
acc.mmu = 0
return
}
var bandU bandUtilHeap
windows := make([]time.Duration, len(c.series))
for i, s := range c.series {
windows[i] = window
if max := time.Duration(s.util[len(s.util)-1].Time - s.util[0].Time); window > max {
windows[i] = max
}
bandU1 := bandUtilHeap(s.mkBandUtil(i, windows[i]))
if bandU == nil {
bandU = bandU1
} else {
bandU = append(bandU, bandU1...)
}
}
// Process bands from lowest utilization bound to highest.
heap.Init(&bandU)
// Refine each band into a precise window and MMU until
// refining the next lowest band can no longer affect the MMU
// or windows.
for len(bandU) > 0 && bandU[0].utilBound < acc.bound {
i := bandU[0].series
c.series[i].bandMMU(bandU[0].i, windows[i], acc)
heap.Pop(&bandU)
}
}
func (c *mmuSeries) mkBandUtil(series int, window time.Duration) []bandUtil {
// For each band, compute the worst-possible total mutator
// utilization for all windows that start in that band.
// minBands is the minimum number of bands a window can span
// and maxBands is the maximum number of bands a window can
// span in any alignment.
minBands := int((int64(window) + c.bandDur - 1) / c.bandDur)
maxBands := int((int64(window) + 2*(c.bandDur-1)) / c.bandDur)
if window > 1 && maxBands < 2 {
panic("maxBands < 2")
}
tailDur := int64(window) % c.bandDur
nUtil := len(c.bands) - maxBands + 1
if nUtil < 0 {
nUtil = 0
}
bandU := make([]bandUtil, nUtil)
for i := range bandU {
// To compute the worst-case MU, we assume the minimum
// for any bands that are only partially overlapped by
// some window and the mean for any bands that are
// completely covered by all windows.
var util totalUtil
// Find the lowest and second lowest of the partial
// bands.
l := c.bands[i].minUtil
r1 := c.bands[i+minBands-1].minUtil
r2 := c.bands[i+maxBands-1].minUtil
minBand := math.Min(l, math.Min(r1, r2))
// Assume the worst window maximally overlaps the
// worst minimum and then the rest overlaps the second
// worst minimum.
if minBands == 1 {
util += totalUtilOf(minBand, int64(window))
} else {
util += totalUtilOf(minBand, c.bandDur)
midBand := 0.0
switch {
case minBand == l:
midBand = math.Min(r1, r2)
case minBand == r1:
midBand = math.Min(l, r2)
case minBand == r2:
midBand = math.Min(l, r1)
}
util += totalUtilOf(midBand, tailDur)
}
// Add the total mean MU of bands that are completely
// overlapped by all windows.
if minBands > 2 {
util += c.bands[i+minBands-1].cumUtil - c.bands[i+1].cumUtil
}
bandU[i] = bandUtil{series, i, util.mean(window)}
}
return bandU
}
// bandMMU computes the precise minimum mutator utilization for
// windows with a left edge in band bandIdx.
func (c *mmuSeries) bandMMU(bandIdx int, window time.Duration, acc *accumulator) {
util := c.util
// We think of the mutator utilization over time as the
// box-filtered utilization function, which we call the
// "windowed mutator utilization function". The resulting
// function is continuous and piecewise linear (unless
// window==0, which we handle elsewhere), where the boundaries
// between segments occur when either edge of the window
// encounters a change in the instantaneous mutator
// utilization function. Hence, the minimum of this function
// will always occur when one of the edges of the window
// aligns with a utilization change, so these are the only
// points we need to consider.
//
// We compute the mutator utilization function incrementally
// by tracking the integral from t=0 to the left edge of the
// window and to the right edge of the window.
left := c.bands[bandIdx].integrator
right := left
time, endTime := c.bandTime(bandIdx)
if utilEnd := util[len(util)-1].Time - int64(window); utilEnd < endTime {
endTime = utilEnd
}
acc.resetTime()
for {
// Advance edges to time and time+window.
mu := (right.advance(time+int64(window)) - left.advance(time)).mean(window)
if acc.addMU(time, mu, window) {
break
}
if time == endTime {
break
}
// The maximum slope of the windowed mutator
// utilization function is 1/window, so we can always
// advance the time by at least (mu - acc.bound) * window
// without dropping below acc.bound.
minTime := time + int64((mu-acc.bound)*float64(window))
// Advance the window to the next time where either
// the left or right edge of the window encounters a
// change in the utilization curve.
if t1, t2 := left.next(time), right.next(time+int64(window))-int64(window); t1 < t2 {
time = t1
} else {
time = t2
}
if time < minTime {
time = minTime
}
if time >= endTime {
// For MMUs we could stop here, but for MUDs
// it's important that we span the entire
// band.
time = endTime
}
}
}
// An integrator tracks a position in a utilization function and
// integrates it.
type integrator struct {
u *mmuSeries
// pos is the index in u.util of the current time's non-strict
// predecessor.
pos int
}
// advance returns the integral of the utilization function from 0 to
// time. advance must be called on monotonically increasing values of
// time.
func (in *integrator) advance(time int64) totalUtil {
util, pos := in.u.util, in.pos
// Advance pos until pos+1 is time's strict successor (making
// pos time's non-strict predecessor).
//
// Very often, this will be nearby, so we optimize that case,
// but it may be arbitrarily far away, so we handle that
// efficiently, too.
const maxSeq = 8
if pos+maxSeq < len(util) && util[pos+maxSeq].Time > time {
// Nearby. Use a linear scan.
for pos+1 < len(util) && util[pos+1].Time <= time {
pos++
}
} else {
// Far. Binary search for time's strict successor.
l, r := pos, len(util)
for l < r {
h := int(uint(l+r) >> 1)
if util[h].Time <= time {
l = h + 1
} else {
r = h
}
}
pos = l - 1 // Non-strict predecessor.
}
in.pos = pos
var partial totalUtil
if time != util[pos].Time {
partial = totalUtilOf(util[pos].Util, time-util[pos].Time)
}
return in.u.sums[pos] + partial
}
// next returns the smallest time t' > time of a change in the
// utilization function.
func (in *integrator) next(time int64) int64 {
for _, u := range in.u.util[in.pos:] {
if u.Time > time {
return u.Time
}
}
return 1<<63 - 1
}
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package trace
import (
"sort"
"strings"
)
// GDesc contains statistics and execution details of a single goroutine.
type GDesc struct {
ID uint64
Name string
PC uint64
CreationTime int64
StartTime int64
EndTime int64
// List of regions in the goroutine, sorted based on the start time.
Regions []*UserRegionDesc
// Statistics of execution time during the goroutine execution.
GExecutionStat
*gdesc // private part.
}
// UserRegionDesc represents a region and goroutine execution stats
// while the region was active.
type UserRegionDesc struct {
TaskID uint64
Name string
// Region start event. Normally EvUserRegion start event or nil,
// but can be EvGoCreate event if the region is a synthetic
// region representing task inheritance from the parent goroutine.
Start *Event
// Region end event. Normally EvUserRegion end event or nil,
// but can be EvGoStop or EvGoEnd event if the goroutine
// terminated without explicitly ending the region.
End *Event
GExecutionStat
}
// GExecutionStat contains statistics about a goroutine's execution
// during a period of time.
type GExecutionStat struct {
ExecTime int64
SchedWaitTime int64
IOTime int64
BlockTime int64
SyscallTime int64
GCTime int64
SweepTime int64
TotalTime int64
}
// sub returns the stats v-s.
func (s GExecutionStat) sub(v GExecutionStat) (r GExecutionStat) {
r = s
r.ExecTime -= v.ExecTime
r.SchedWaitTime -= v.SchedWaitTime
r.IOTime -= v.IOTime
r.BlockTime -= v.BlockTime
r.SyscallTime -= v.SyscallTime
r.GCTime -= v.GCTime
r.SweepTime -= v.SweepTime
r.TotalTime -= v.TotalTime
return r
}
// snapshotStat returns the snapshot of the goroutine execution statistics.
// This is called as we process the ordered trace event stream. lastTs and
// activeGCStartTime are used to process pending statistics if this is called
// before any goroutine end event.
func (g *GDesc) snapshotStat(lastTs, activeGCStartTime int64) (ret GExecutionStat) {
ret = g.GExecutionStat
if g.gdesc == nil {
return ret // finalized GDesc. No pending state.
}
if activeGCStartTime != 0 { // terminating while GC is active
if g.CreationTime < activeGCStartTime {
ret.GCTime += lastTs - activeGCStartTime
} else {
// The goroutine's lifetime completely overlaps
// with a GC.
ret.GCTime += lastTs - g.CreationTime
}
}
if g.TotalTime == 0 {
ret.TotalTime = lastTs - g.CreationTime
}
if g.lastStartTime != 0 {
ret.ExecTime += lastTs - g.lastStartTime
}
if g.blockNetTime != 0 {
ret.IOTime += lastTs - g.blockNetTime
}
if g.blockSyncTime != 0 {
ret.BlockTime += lastTs - g.blockSyncTime
}
if g.blockSyscallTime != 0 {
ret.SyscallTime += lastTs - g.blockSyscallTime
}
if g.blockSchedTime != 0 {
ret.SchedWaitTime += lastTs - g.blockSchedTime
}
if g.blockSweepTime != 0 {
ret.SweepTime += lastTs - g.blockSweepTime
}
return ret
}
// finalize is called when processing a goroutine end event or at
// the end of trace processing (in which case trigger is nil). It
// finalizes the execution stats and any active regions in the goroutine.
func (g *GDesc) finalize(lastTs, activeGCStartTime int64, trigger *Event) {
if trigger != nil {
g.EndTime = trigger.Ts
}
finalStat := g.snapshotStat(lastTs, activeGCStartTime)
g.GExecutionStat = finalStat
// System goroutines are never part of regions, even though they
// "inherit" a task due to creation (EvGoCreate) from within a region.
// This may happen e.g. if the first GC is triggered within a region,
// starting the GC worker goroutines.
if !IsSystemGoroutine(g.Name) {
for _, s := range g.activeRegions {
s.End = trigger
s.GExecutionStat = finalStat.sub(s.GExecutionStat)
g.Regions = append(g.Regions, s)
}
}
*(g.gdesc) = gdesc{}
}
// gdesc is a private part of GDesc that is required only during analysis.
type gdesc struct {
lastStartTime int64
blockNetTime int64
blockSyncTime int64
blockSyscallTime int64
blockSweepTime int64
blockGCTime int64
blockSchedTime int64
activeRegions []*UserRegionDesc // stack of active regions
}
// GoroutineStats generates statistics for all goroutines in the trace.
func GoroutineStats(events []*Event) map[uint64]*GDesc {
gs := make(map[uint64]*GDesc)
var lastTs int64
var gcStartTime int64 // gcStartTime == 0 indicates gc is inactive.
for _, ev := range events {
lastTs = ev.Ts
switch ev.Type {
case EvGoCreate:
g := &GDesc{ID: ev.Args[0], CreationTime: ev.Ts, gdesc: new(gdesc)}
g.blockSchedTime = ev.Ts
// When a goroutine is newly created, inherit the task
// of the active region. For ease of handling this
// case, we create a fake region description with the
// task id. This isn't strictly necessary, as this
// goroutine may not be associated with the task, but
// it can be convenient to see all children created
// during a region.
if creatorG := gs[ev.G]; creatorG != nil && len(creatorG.gdesc.activeRegions) > 0 {
regions := creatorG.gdesc.activeRegions
s := regions[len(regions)-1]
if s.TaskID != 0 {
g.gdesc.activeRegions = []*UserRegionDesc{
{TaskID: s.TaskID, Start: ev},
}
}
}
gs[g.ID] = g
case EvGoStart, EvGoStartLabel:
g := gs[ev.G]
if g.PC == 0 && len(ev.Stk) > 0 {
g.PC = ev.Stk[0].PC
g.Name = ev.Stk[0].Fn
}
g.lastStartTime = ev.Ts
if g.StartTime == 0 {
g.StartTime = ev.Ts
}
if g.blockSchedTime != 0 {
g.SchedWaitTime += ev.Ts - g.blockSchedTime
g.blockSchedTime = 0
}
case EvGoEnd, EvGoStop:
g := gs[ev.G]
g.finalize(ev.Ts, gcStartTime, ev)
case EvGoBlockSend, EvGoBlockRecv, EvGoBlockSelect,
EvGoBlockSync, EvGoBlockCond:
g := gs[ev.G]
g.ExecTime += ev.Ts - g.lastStartTime
g.lastStartTime = 0
g.blockSyncTime = ev.Ts
case EvGoSched, EvGoPreempt:
g := gs[ev.G]
g.ExecTime += ev.Ts - g.lastStartTime
g.lastStartTime = 0
g.blockSchedTime = ev.Ts
case EvGoSleep, EvGoBlock:
g := gs[ev.G]
g.ExecTime += ev.Ts - g.lastStartTime
g.lastStartTime = 0
case EvGoBlockNet:
g := gs[ev.G]
g.ExecTime += ev.Ts - g.lastStartTime
g.lastStartTime = 0
g.blockNetTime = ev.Ts
case EvGoBlockGC:
g := gs[ev.G]
g.ExecTime += ev.Ts - g.lastStartTime
g.lastStartTime = 0
g.blockGCTime = ev.Ts
case EvGoUnblock:
g := gs[ev.Args[0]]
if g.blockNetTime != 0 {
g.IOTime += ev.Ts - g.blockNetTime
g.blockNetTime = 0
}
if g.blockSyncTime != 0 {
g.BlockTime += ev.Ts - g.blockSyncTime
g.blockSyncTime = 0
}
g.blockSchedTime = ev.Ts
case EvGoSysBlock:
g := gs[ev.G]
g.ExecTime += ev.Ts - g.lastStartTime
g.lastStartTime = 0
g.blockSyscallTime = ev.Ts
case EvGoSysExit:
g := gs[ev.G]
if g.blockSyscallTime != 0 {
g.SyscallTime += ev.Ts - g.blockSyscallTime
g.blockSyscallTime = 0
}
g.blockSchedTime = ev.Ts
case EvGCSweepStart:
g := gs[ev.G]
if g != nil {
// Sweep can happen during GC on a system goroutine.
g.blockSweepTime = ev.Ts
}
case EvGCSweepDone:
g := gs[ev.G]
if g != nil && g.blockSweepTime != 0 {
g.SweepTime += ev.Ts - g.blockSweepTime
g.blockSweepTime = 0
}
case EvGCStart:
gcStartTime = ev.Ts
case EvGCDone:
for _, g := range gs {
if g.EndTime != 0 {
continue
}
if gcStartTime < g.CreationTime {
g.GCTime += ev.Ts - g.CreationTime
} else {
g.GCTime += ev.Ts - gcStartTime
}
}
gcStartTime = 0 // indicates gc is inactive.
case EvUserRegion:
g := gs[ev.G]
switch mode := ev.Args[1]; mode {
case 0: // region start
g.activeRegions = append(g.activeRegions, &UserRegionDesc{
Name: ev.SArgs[0],
TaskID: ev.Args[0],
Start: ev,
GExecutionStat: g.snapshotStat(lastTs, gcStartTime),
})
case 1: // region end
var sd *UserRegionDesc
if regionStk := g.activeRegions; len(regionStk) > 0 {
n := len(regionStk)
sd = regionStk[n-1]
regionStk = regionStk[:n-1] // pop
g.activeRegions = regionStk
} else {
sd = &UserRegionDesc{
Name: ev.SArgs[0],
TaskID: ev.Args[0],
}
}
sd.GExecutionStat = g.snapshotStat(lastTs, gcStartTime).sub(sd.GExecutionStat)
sd.End = ev
g.Regions = append(g.Regions, sd)
}
}
}
for _, g := range gs {
g.finalize(lastTs, gcStartTime, nil)
// sort based on region start time
sort.Slice(g.Regions, func(i, j int) bool {
x := g.Regions[i].Start
y := g.Regions[j].Start
if x == nil {
return true
}
if y == nil {
return false
}
return x.Ts < y.Ts
})
g.gdesc = nil
}
return gs
}
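// exampleGoroutineStats is an illustrative sketch, not part of the original
// source: it sums total execution time across all goroutines in a trace.
func exampleGoroutineStats(events []*Event) int64 {
	var total int64
	for _, g := range GoroutineStats(events) {
		total += g.ExecTime
	}
	return total
}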
// RelatedGoroutines finds a set of goroutines related to goroutine goid.
func RelatedGoroutines(events []*Event, goid uint64) map[uint64]bool {
// BFS of depth 2 over "unblock" edges
// (what goroutines unblock goroutine goid?).
gmap := make(map[uint64]bool)
gmap[goid] = true
for i := 0; i < 2; i++ {
gmap1 := make(map[uint64]bool)
for g := range gmap {
gmap1[g] = true
}
for _, ev := range events {
if ev.Type == EvGoUnblock && gmap[ev.Args[0]] {
gmap1[ev.G] = true
}
}
gmap = gmap1
}
gmap[0] = true // for GC events
return gmap
}
func IsSystemGoroutine(entryFn string) bool {
// This mimics runtime.isSystemGoroutine as closely as
// possible.
// Also, a locked g in an extra M (with an empty entryFn) is a system goroutine.
return entryFn == "" || entryFn != "runtime.main" && strings.HasPrefix(entryFn, "runtime.")
}
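// exampleIsSystemGoroutine is an illustrative sketch, not part of the
// original source: runtime-internal entry functions (other than
// runtime.main) and the empty entry function count as system goroutines.
func exampleIsSystemGoroutine() (bool, bool, bool) {
	return IsSystemGoroutine("runtime.bgsweep"), // true
		IsSystemGoroutine("runtime.main"), // false
		IsSystemGoroutine("main.main") // false
}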
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package trace
import (
"math"
"sort"
)
// mud is an updatable mutator utilization distribution.
//
// This is a continuous distribution of duration over mutator
// utilization. For example, the integral from mutator utilization a
// to b is the total duration during which the mutator utilization was
// in the range [a, b].
//
// This distribution is *not* normalized (it is not a probability
// distribution). This makes it easier to work with as it's being
// updated.
//
// It is represented as the sum of scaled uniform distribution
// functions and Dirac delta functions (which are treated as
// degenerate uniform distributions).
type mud struct {
sorted, unsorted []edge
// trackMass is the inverse cumulative sum to track as the
// distribution is updated.
trackMass float64
// trackBucket is the bucket in which trackMass falls. If the
// total mass of the distribution is < trackMass, this is
// len(hist).
trackBucket int
// trackSum is the cumulative sum of hist[:trackBucket]. Once
// trackSum >= trackMass, trackBucket must be recomputed.
trackSum float64
// hist is a hierarchical histogram of distribution mass.
hist [mudDegree]float64
}
const (
// mudDegree is the number of buckets in the MUD summary
// histogram.
mudDegree = 1024
)
type edge struct {
// At x, the function increases by y.
x, delta float64
// Additionally at x is a Dirac delta function with area dirac.
dirac float64
}
// add adds a uniform function over [l, r] scaled so the total weight
// of the uniform is area. If l==r, this adds a Dirac delta function.
func (d *mud) add(l, r, area float64) {
if area == 0 {
return
}
if r < l {
l, r = r, l
}
// Add the edges.
if l == r {
d.unsorted = append(d.unsorted, edge{l, 0, area})
} else {
delta := area / (r - l)
d.unsorted = append(d.unsorted, edge{l, delta, 0}, edge{r, -delta, 0})
}
// Update the histogram.
h := &d.hist
lbFloat, lf := math.Modf(l * mudDegree)
lb := int(lbFloat)
if lb >= mudDegree {
lb, lf = mudDegree-1, 1
}
if l == r {
h[lb] += area
} else {
rbFloat, rf := math.Modf(r * mudDegree)
rb := int(rbFloat)
if rb >= mudDegree {
rb, rf = mudDegree-1, 1
}
if lb == rb {
h[lb] += area
} else {
perBucket := area / (r - l) / mudDegree
h[lb] += perBucket * (1 - lf)
h[rb] += perBucket * rf
for i := lb + 1; i < rb; i++ {
h[i] += perBucket
}
}
}
// Update mass tracking.
if thresh := float64(d.trackBucket) / mudDegree; l < thresh {
if r < thresh {
d.trackSum += area
} else {
d.trackSum += area * (thresh - l) / (r - l)
}
if d.trackSum >= d.trackMass {
// The tracked mass now falls in a different
// bucket. Recompute the inverse cumulative sum.
d.setTrackMass(d.trackMass)
}
}
}
// setTrackMass sets the mass to track the inverse cumulative sum for.
//
// Specifically, mass is a cumulative duration, and the mutator
// utilization bounds for this duration can be queried using
// approxInvCumulativeSum.
func (d *mud) setTrackMass(mass float64) {
d.trackMass = mass
// Find the bucket currently containing trackMass by computing
// the cumulative sum.
sum := 0.0
for i, val := range d.hist[:] {
newSum := sum + val
if newSum > mass {
// mass falls in bucket i.
d.trackBucket = i
d.trackSum = sum
return
}
sum = newSum
}
d.trackBucket = len(d.hist)
d.trackSum = sum
}
// approxInvCumulativeSum is like invCumulativeSum, but specifically
// operates on the tracked mass and returns an upper and lower bound
// approximation of the inverse cumulative sum.
//
// The true inverse cumulative sum will be in the range [lower, upper).
func (d *mud) approxInvCumulativeSum() (float64, float64, bool) {
if d.trackBucket == len(d.hist) {
return math.NaN(), math.NaN(), false
}
return float64(d.trackBucket) / mudDegree, float64(d.trackBucket+1) / mudDegree, true
}
// invCumulativeSum returns x such that the integral of d from -∞ to x
// is y. If the total weight of d is less than y, it returns the
// maximum of the distribution and false.
//
// Specifically, y is a cumulative duration, and invCumulativeSum
// returns the mutator utilization x such that at least y time has
// been spent with mutator utilization <= x.
func (d *mud) invCumulativeSum(y float64) (float64, bool) {
if len(d.sorted) == 0 && len(d.unsorted) == 0 {
return math.NaN(), false
}
// Sort edges.
edges := d.unsorted
sort.Slice(edges, func(i, j int) bool {
return edges[i].x < edges[j].x
})
// Merge with sorted edges.
d.unsorted = nil
if d.sorted == nil {
d.sorted = edges
} else {
oldSorted := d.sorted
newSorted := make([]edge, len(oldSorted)+len(edges))
i, j := 0, 0
for o := range newSorted {
if i >= len(oldSorted) {
copy(newSorted[o:], edges[j:])
break
} else if j >= len(edges) {
copy(newSorted[o:], oldSorted[i:])
break
} else if oldSorted[i].x < edges[j].x {
newSorted[o] = oldSorted[i]
i++
} else {
newSorted[o] = edges[j]
j++
}
}
d.sorted = newSorted
}
// Traverse edges in order computing a cumulative sum.
csum, rate, prevX := 0.0, 0.0, 0.0
for _, e := range d.sorted {
newCsum := csum + (e.x-prevX)*rate
if newCsum >= y {
// y was exceeded between the previous edge
// and this one.
if rate == 0 {
// Anywhere between prevX and
// e.x will do. We return e.x
// because that takes care of
// the y==0 case naturally.
return e.x, true
}
return (y-csum)/rate + prevX, true
}
newCsum += e.dirac
if newCsum >= y {
// y was exceeded by the Dirac delta at e.x.
return e.x, true
}
csum, prevX = newCsum, e.x
rate += e.delta
}
return prevX, false
}
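// exampleMud is an illustrative sketch, not part of the original source:
// 1e9 ns of mass spread uniformly over utilization [0.2, 0.8] plus a Dirac
// delta of 0.5e9 ns at 1.0. The first 0.75e9 ns of mass is exhausted at
// utilization 0.2 + 0.75e9/(1e9/0.6) = 0.65.
func exampleMud() (float64, bool) {
	d := new(mud)
	d.add(0.2, 0.8, 1e9) // uniform mass over [0.2, 0.8] totaling 1e9 ns
	d.add(1.0, 1.0, 5e8) // Dirac delta at full utilization
	return d.invCumulativeSum(0.75e9) // 0.65, true
}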
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package trace
import (
"fmt"
"sort"
)
type eventBatch struct {
events []*Event
selected bool
}
type orderEvent struct {
ev *Event
batch int
g uint64
init gState
next gState
}
type gStatus int
type gState struct {
seq uint64
status gStatus
}
const (
gDead gStatus = iota
gRunnable
gRunning
gWaiting
unordered = ^uint64(0)
garbage = ^uint64(0) - 1
noseq = ^uint64(0)
seqinc = ^uint64(0) - 1
)
// order1007 merges a set of per-P event batches into a single, consistent stream.
// The high level idea is as follows. Events within an individual batch are in
// correct order, because they are emitted by a single P. So we need to produce
// a correct interleaving of the batches. To do this we take the first unmerged
// event from each batch (the frontier). Then we choose the subset that is "ready"
// to be merged, that is, events for which all dependencies are already merged.
// Then we choose the event with the lowest timestamp from the subset, merge it,
// and repeat. This approach ensures that we form a consistent stream even if
// timestamps are incorrect (a condition observed on some machines).
func order1007(m map[int][]*Event) (events []*Event, err error) {
pending := 0
// The ordering of CPU profile sample events in the data stream is based on
// when each run of the signal handler was able to acquire the spinlock,
// with original timestamps corresponding to when ReadTrace pulled the data
// off of the profBuf queue. Re-sort them by the timestamp we captured
// inside the signal handler.
sort.Stable(eventList(m[ProfileP]))
var batches []*eventBatch
for _, v := range m {
pending += len(v)
batches = append(batches, &eventBatch{v, false})
}
gs := make(map[uint64]gState)
var frontier []orderEvent
for ; pending != 0; pending-- {
for i, b := range batches {
if b.selected || len(b.events) == 0 {
continue
}
ev := b.events[0]
g, init, next := stateTransition(ev)
if !transitionReady(g, gs[g], init) {
continue
}
frontier = append(frontier, orderEvent{ev, i, g, init, next})
b.events = b.events[1:]
b.selected = true
// Get rid of "Local" events; they are intended merely for ordering.
switch ev.Type {
case EvGoStartLocal:
ev.Type = EvGoStart
case EvGoUnblockLocal:
ev.Type = EvGoUnblock
case EvGoSysExitLocal:
ev.Type = EvGoSysExit
}
}
if len(frontier) == 0 {
return nil, fmt.Errorf("no consistent ordering of events possible")
}
sort.Sort(orderEventList(frontier))
f := frontier[0]
frontier[0] = frontier[len(frontier)-1]
frontier = frontier[:len(frontier)-1]
events = append(events, f.ev)
transition(gs, f.g, f.init, f.next)
if !batches[f.batch].selected {
panic("frontier batch is not selected")
}
batches[f.batch].selected = false
}
// At this point we have a consistent stream of events.
// Make sure time stamps respect the ordering.
// The tests will skip (not fail) the test case if they see this error.
if !sort.IsSorted(eventList(events)) {
return nil, ErrTimeOrder
}
// The last part is giving correct timestamps to EvGoSysExit events.
// The problem with EvGoSysExit is that the actual syscall exit timestamp
// (ev.Args[2]) is potentially acquired long before event emission. So far
// we've used the timestamp of event emission (ev.Ts).
// We could not set ev.Ts = ev.Args[2] earlier, because it would produce
// seemingly broken timestamps (misplaced event).
// We also can't simply update the timestamp and re-sort events, because
// if timestamps are broken we will misplace the event and later report a
// logically broken trace (instead of reporting broken timestamps).
lastSysBlock := make(map[uint64]int64)
for _, ev := range events {
switch ev.Type {
case EvGoSysBlock, EvGoInSyscall:
lastSysBlock[ev.G] = ev.Ts
case EvGoSysExit:
ts := int64(ev.Args[2])
if ts == 0 {
continue
}
block := lastSysBlock[ev.G]
if block == 0 {
return nil, fmt.Errorf("stray syscall exit")
}
if ts < block {
return nil, ErrTimeOrder
}
ev.Ts = ts
}
}
sort.Stable(eventList(events))
return
}
// stateTransition returns goroutine state (sequence and status) when the event
// becomes ready for merging (init) and the goroutine state after the event (next).
func stateTransition(ev *Event) (g uint64, init, next gState) {
switch ev.Type {
case EvGoCreate:
g = ev.Args[0]
init = gState{0, gDead}
next = gState{1, gRunnable}
case EvGoWaiting, EvGoInSyscall:
g = ev.G
init = gState{1, gRunnable}
next = gState{2, gWaiting}
case EvGoStart, EvGoStartLabel:
g = ev.G
init = gState{ev.Args[1], gRunnable}
next = gState{ev.Args[1] + 1, gRunning}
case EvGoStartLocal:
// noseq means that this event is ready for merging as soon as
// the frontier reaches it (EvGoStartLocal is emitted on the same P
// as the corresponding EvGoCreate/EvGoUnblock, and thus the latter
// is already merged).
// seqinc is a stub for cases when an event increments the g sequence,
// but since we don't know the current seq we also don't know the next seq.
g = ev.G
init = gState{noseq, gRunnable}
next = gState{seqinc, gRunning}
case EvGoBlock, EvGoBlockSend, EvGoBlockRecv, EvGoBlockSelect,
EvGoBlockSync, EvGoBlockCond, EvGoBlockNet, EvGoSleep,
EvGoSysBlock, EvGoBlockGC:
g = ev.G
init = gState{noseq, gRunning}
next = gState{noseq, gWaiting}
case EvGoSched, EvGoPreempt:
g = ev.G
init = gState{noseq, gRunning}
next = gState{noseq, gRunnable}
case EvGoUnblock, EvGoSysExit:
g = ev.Args[0]
init = gState{ev.Args[1], gWaiting}
next = gState{ev.Args[1] + 1, gRunnable}
case EvGoUnblockLocal, EvGoSysExitLocal:
g = ev.Args[0]
init = gState{noseq, gWaiting}
next = gState{seqinc, gRunnable}
case EvGCStart:
g = garbage
init = gState{ev.Args[0], gDead}
next = gState{ev.Args[0] + 1, gDead}
default:
// no ordering requirements
g = unordered
}
return
}
func transitionReady(g uint64, curr, init gState) bool {
return g == unordered || (init.seq == noseq || init.seq == curr.seq) && init.status == curr.status
}
func transition(gs map[uint64]gState, g uint64, init, next gState) {
if g == unordered {
return
}
curr := gs[g]
if !transitionReady(g, curr, init) {
panic("event sequences are broken")
}
switch next.seq {
case noseq:
next.seq = curr.seq
case seqinc:
next.seq = curr.seq + 1
}
gs[g] = next
}
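// exampleTransition is an illustrative sketch, not part of the original
// source: an EvGoStart with sequence argument 5 for goroutine 7 is ready
// only while gs[7] is {5, gRunnable}, and merging it advances the state to
// {6, gRunning}.
func exampleTransition() bool {
	gs := map[uint64]gState{7: {5, gRunnable}}
	g, init, next := stateTransition(&Event{Type: EvGoStart, G: 7, Args: [3]uint64{7, 5, 0}})
	if !transitionReady(g, gs[g], init) {
		return false
	}
	transition(gs, g, init, next)
	return gs[7] == gState{6, gRunning} // true
}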
// order1005 merges a set of per-P event batches into a single, consistent stream.
func order1005(m map[int][]*Event) (events []*Event, err error) {
for _, batch := range m {
events = append(events, batch...)
}
for _, ev := range events {
if ev.Type == EvGoSysExit {
// EvGoSysExit emission is delayed until the thread has a P.
// Give it the real sequence number and time stamp.
ev.seq = int64(ev.Args[1])
if ev.Args[2] != 0 {
ev.Ts = int64(ev.Args[2])
}
}
}
sort.Sort(eventSeqList(events))
if !sort.IsSorted(eventList(events)) {
return nil, ErrTimeOrder
}
return
}
type orderEventList []orderEvent
func (l orderEventList) Len() int {
return len(l)
}
func (l orderEventList) Less(i, j int) bool {
return l[i].ev.Ts < l[j].ev.Ts
}
func (l orderEventList) Swap(i, j int) {
l[i], l[j] = l[j], l[i]
}
type eventList []*Event
func (l eventList) Len() int {
return len(l)
}
func (l eventList) Less(i, j int) bool {
return l[i].Ts < l[j].Ts
}
func (l eventList) Swap(i, j int) {
l[i], l[j] = l[j], l[i]
}
type eventSeqList []*Event
func (l eventSeqList) Len() int {
return len(l)
}
func (l eventSeqList) Less(i, j int) bool {
return l[i].seq < l[j].seq
}
func (l eventSeqList) Swap(i, j int) {
l[i], l[j] = l[j], l[i]
}
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package trace
import (
"bufio"
"bytes"
"fmt"
"io"
"math/rand"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"
_ "unsafe"
)
func goCmd() string {
var exeSuffix string
if runtime.GOOS == "windows" {
exeSuffix = ".exe"
}
path := filepath.Join(runtime.GOROOT(), "bin", "go"+exeSuffix)
if _, err := os.Stat(path); err == nil {
return path
}
return "go"
}
// Event describes one event in the trace.
type Event struct {
Off int // offset in input file (for debugging and error reporting)
Type byte // one of Ev*
seq int64 // sequence number
Ts int64 // timestamp in nanoseconds
P int // P on which the event happened (can be one of TimerP, NetpollP, SyscallP)
G uint64 // G on which the event happened
StkID uint64 // unique stack ID
Stk []*Frame // stack trace (can be empty)
Args [3]uint64 // event-type-specific arguments
SArgs []string // event-type-specific string args
// linked event (can be nil), depends on event type:
// for GCStart: the GCStop
// for GCSTWStart: the GCSTWDone
// for GCSweepStart: the GCSweepDone
// for GoCreate: first GoStart of the created goroutine
// for GoStart/GoStartLabel: the associated GoEnd, GoBlock or other blocking event
// for GoSched/GoPreempt: the next GoStart
// for GoBlock and other blocking events: the unblock event
// for GoUnblock: the associated GoStart
// for blocking GoSysCall: the associated GoSysExit
// for GoSysExit: the next GoStart
// for GCMarkAssistStart: the associated GCMarkAssistDone
// for UserTaskCreate: the UserTaskEnd
// for UserRegion: if the start region, the corresponding UserRegion end event
Link *Event
}
// Frame is a frame in stack traces.
type Frame struct {
PC uint64
Fn string
File string
Line int
}
const (
// Special P identifiers:
FakeP = 1000000 + iota
TimerP // depicts timer unblocks
NetpollP // depicts network unblocks
SyscallP // depicts returns from syscalls
GCP // depicts GC state
ProfileP // depicts recording of CPU profile samples
)
// ParseResult is the result of Parse.
type ParseResult struct {
// Events is the sorted list of Events in the trace.
Events []*Event
// Stacks is the stack traces keyed by stack IDs from the trace.
Stacks map[uint64][]*Frame
}
// Parse parses, post-processes and verifies the trace.
func Parse(r io.Reader, bin string) (ParseResult, error) {
ver, res, err := parse(r, bin)
if err != nil {
return ParseResult{}, err
}
if ver < 1007 && bin == "" {
return ParseResult{}, fmt.Errorf("for traces produced by go 1.6 or below, the binary argument must be provided")
}
return res, nil
}
// parse parses, post-processes and verifies the trace. It returns the
// trace version and the list of events.
func parse(r io.Reader, bin string) (int, ParseResult, error) {
ver, rawEvents, strings, err := readTrace(r)
if err != nil {
return 0, ParseResult{}, err
}
events, stacks, err := parseEvents(ver, rawEvents, strings)
if err != nil {
return 0, ParseResult{}, err
}
events = removeFutile(events)
err = postProcessTrace(ver, events)
if err != nil {
return 0, ParseResult{}, err
}
// Attach stack traces.
for _, ev := range events {
if ev.StkID != 0 {
ev.Stk = stacks[ev.StkID]
}
}
if ver < 1007 && bin != "" {
if err := symbolize(events, bin); err != nil {
return 0, ParseResult{}, err
}
}
return ver, ParseResult{Events: events, Stacks: stacks}, nil
}
// rawEvent is a helper type used during parsing.
type rawEvent struct {
off int
typ byte
args []uint64
sargs []string
}
// readTrace does wire-format parsing and verification.
// It does not care about specific event types and argument meaning.
func readTrace(r io.Reader) (ver int, events []rawEvent, strings map[uint64]string, err error) {
// Read and validate trace header.
var buf [16]byte
off, err := io.ReadFull(r, buf[:])
if err != nil {
err = fmt.Errorf("failed to read header: read %v, err %v", off, err)
return
}
ver, err = parseHeader(buf[:])
if err != nil {
return
}
switch ver {
case 1005, 1007, 1008, 1009, 1010, 1011, 1019:
// Note: When adding a new version, confirm that canned traces from the
// old version are part of the test suite. Add them using mkcanned.bash.
break
default:
err = fmt.Errorf("unsupported trace file version %v.%v (update Go toolchain) %v", ver/1000, ver%1000, ver)
return
}
// Read events.
strings = make(map[uint64]string)
for {
// Read event type and number of arguments (1 byte).
off0 := off
var n int
n, err = r.Read(buf[:1])
if err == io.EOF {
err = nil
break
}
if err != nil || n != 1 {
err = fmt.Errorf("failed to read trace at offset 0x%x: n=%v err=%v", off0, n, err)
return
}
off += n
typ := buf[0] << 2 >> 2 // event type is in the low 6 bits
narg := buf[0]>>6 + 1   // number of arguments is in the top 2 bits, plus 1
inlineArgs := byte(4)
if ver < 1007 {
narg++
inlineArgs++
}
if typ == EvNone || typ >= EvCount || EventDescriptions[typ].minVersion > ver {
err = fmt.Errorf("unknown event type %v at offset 0x%x", typ, off0)
return
}
if typ == EvString {
// String dictionary entry [ID, length, string].
var id uint64
id, off, err = readVal(r, off)
if err != nil {
return
}
if id == 0 {
err = fmt.Errorf("string at offset %d has invalid id 0", off)
return
}
if strings[id] != "" {
err = fmt.Errorf("string at offset %d has duplicate id %v", off, id)
return
}
var ln uint64
ln, off, err = readVal(r, off)
if err != nil {
return
}
if ln == 0 {
err = fmt.Errorf("string at offset %d has invalid length 0", off)
return
}
if ln > 1e6 {
err = fmt.Errorf("string at offset %d has too large length %v", off, ln)
return
}
buf := make([]byte, ln)
var n int
n, err = io.ReadFull(r, buf)
if err != nil {
err = fmt.Errorf("failed to read trace at offset %d: read %v, want %v, error %v", off, n, ln, err)
return
}
off += n
strings[id] = string(buf)
continue
}
ev := rawEvent{typ: typ, off: off0}
if narg < inlineArgs {
for i := 0; i < int(narg); i++ {
var v uint64
v, off, err = readVal(r, off)
if err != nil {
err = fmt.Errorf("failed to read event %v argument at offset %v (%v)", typ, off, err)
return
}
ev.args = append(ev.args, v)
}
} else {
// More than inlineArgs args; the first value is the length of the event in bytes.
var v uint64
v, off, err = readVal(r, off)
if err != nil {
err = fmt.Errorf("failed to read event %v argument at offset %v (%v)", typ, off, err)
return
}
evLen := v
off1 := off
for evLen > uint64(off-off1) {
v, off, err = readVal(r, off)
if err != nil {
err = fmt.Errorf("failed to read event %v argument at offset %v (%v)", typ, off, err)
return
}
ev.args = append(ev.args, v)
}
if evLen != uint64(off-off1) {
err = fmt.Errorf("event has wrong length at offset 0x%x: want %v, got %v", off0, evLen, off-off1)
return
}
}
switch ev.typ {
case EvUserLog: // EvUserLog records are followed by a value string of length ev.args[len(ev.args)-1]
var s string
s, off, err = readStr(r, off)
ev.sargs = append(ev.sargs, s)
}
events = append(events, ev)
}
return
}
func readStr(r io.Reader, off0 int) (s string, off int, err error) {
var sz uint64
sz, off, err = readVal(r, off0)
if err != nil || sz == 0 {
return "", off, err
}
if sz > 1e6 {
return "", off, fmt.Errorf("string at offset %d is too large (len=%d)", off, sz)
}
buf := make([]byte, sz)
n, err := io.ReadFull(r, buf)
if err != nil || sz != uint64(n) {
return "", off + n, fmt.Errorf("failed to read trace at offset %d: read %v, want %v, error %v", off, n, sz, err)
}
return string(buf), off + n, nil
}
// parseHeader parses trace header of the form "go 1.7 trace\x00\x00\x00\x00"
// and returns parsed version as 1007.
func parseHeader(buf []byte) (int, error) {
if len(buf) != 16 {
return 0, fmt.Errorf("bad header length")
}
if buf[0] != 'g' || buf[1] != 'o' || buf[2] != ' ' ||
buf[3] < '1' || buf[3] > '9' ||
buf[4] != '.' ||
buf[5] < '1' || buf[5] > '9' {
return 0, fmt.Errorf("not a trace file")
}
ver := int(buf[5] - '0')
i := 0
for ; buf[6+i] >= '0' && buf[6+i] <= '9' && i < 2; i++ {
ver = ver*10 + int(buf[6+i]-'0')
}
ver += int(buf[3]-'0') * 1000
if !bytes.Equal(buf[6+i:], []byte(" trace\x00\x00\x00\x00")[:10-i]) {
return 0, fmt.Errorf("not a trace file")
}
return ver, nil
}
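// exampleParseHeader is an illustrative sketch, not part of the original
// source: a "go 1.19 trace" header (padded with NULs to 16 bytes) parses to
// version 1019, i.e. major*1000 + minor.
func exampleParseHeader() (int, error) {
	hdr := []byte("go 1.19 trace\x00\x00\x00")
	return parseHeader(hdr) // 1019, nil
}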
// parseEvents transforms raw events into events.
// It analyzes and verifies per-event-type arguments.
func parseEvents(ver int, rawEvents []rawEvent, strings map[uint64]string) (events []*Event, stacks map[uint64][]*Frame, err error) {
var ticksPerSec, lastSeq, lastTs int64
var lastG uint64
var lastP int
timerGoids := make(map[uint64]bool)
lastGs := make(map[int]uint64) // last goroutine running on P
stacks = make(map[uint64][]*Frame)
batches := make(map[int][]*Event) // events by P
for _, raw := range rawEvents {
desc := EventDescriptions[raw.typ]
if desc.Name == "" {
err = fmt.Errorf("missing description for event type %v", raw.typ)
return
}
narg := argNum(raw, ver)
if len(raw.args) != narg {
err = fmt.Errorf("%v has wrong number of arguments at offset 0x%x: want %v, got %v",
desc.Name, raw.off, narg, len(raw.args))
return
}
switch raw.typ {
case EvBatch:
lastGs[lastP] = lastG
lastP = int(raw.args[0])
lastG = lastGs[lastP]
if ver < 1007 {
lastSeq = int64(raw.args[1])
lastTs = int64(raw.args[2])
} else {
lastTs = int64(raw.args[1])
}
case EvFrequency:
ticksPerSec = int64(raw.args[0])
if ticksPerSec <= 0 {
// The most likely cause for this is tick skew on different CPUs.
// For example, solaris/amd64 seems to have wildly different
// ticks on different CPUs.
err = ErrTimeOrder
return
}
case EvTimerGoroutine:
timerGoids[raw.args[0]] = true
case EvStack:
if len(raw.args) < 2 {
err = fmt.Errorf("EvStack has wrong number of arguments at offset 0x%x: want at least 2, got %v",
raw.off, len(raw.args))
return
}
size := raw.args[1]
if size > 1000 {
err = fmt.Errorf("EvStack has bad number of frames at offset 0x%x: %v",
raw.off, size)
return
}
want := 2 + 4*size
if ver < 1007 {
want = 2 + size
}
if uint64(len(raw.args)) != want {
err = fmt.Errorf("EvStack has wrong number of arguments at offset 0x%x: want %v, got %v",
raw.off, want, len(raw.args))
return
}
id := raw.args[0]
if id != 0 && size > 0 {
stk := make([]*Frame, size)
for i := 0; i < int(size); i++ {
if ver < 1007 {
stk[i] = &Frame{PC: raw.args[2+i]}
} else {
pc := raw.args[2+i*4+0]
fn := raw.args[2+i*4+1]
file := raw.args[2+i*4+2]
line := raw.args[2+i*4+3]
stk[i] = &Frame{PC: pc, Fn: strings[fn], File: strings[file], Line: int(line)}
}
}
stacks[id] = stk
}
default:
e := &Event{Off: raw.off, Type: raw.typ, P: lastP, G: lastG}
var argOffset int
if ver < 1007 {
e.seq = lastSeq + int64(raw.args[0])
e.Ts = lastTs + int64(raw.args[1])
lastSeq = e.seq
argOffset = 2
} else {
e.Ts = lastTs + int64(raw.args[0])
argOffset = 1
}
lastTs = e.Ts
for i := argOffset; i < narg; i++ {
if i == narg-1 && desc.Stack {
e.StkID = raw.args[i]
} else {
e.Args[i-argOffset] = raw.args[i]
}
}
switch raw.typ {
case EvGoStart, EvGoStartLocal, EvGoStartLabel:
lastG = e.Args[0]
e.G = lastG
if raw.typ == EvGoStartLabel {
e.SArgs = []string{strings[e.Args[2]]}
}
case EvGCSTWStart:
e.G = 0
switch e.Args[0] {
case 0:
e.SArgs = []string{"mark termination"}
case 1:
e.SArgs = []string{"sweep termination"}
default:
err = fmt.Errorf("unknown STW kind %d", e.Args[0])
return
}
case EvGCStart, EvGCDone, EvGCSTWDone:
e.G = 0
case EvGoEnd, EvGoStop, EvGoSched, EvGoPreempt,
EvGoSleep, EvGoBlock, EvGoBlockSend, EvGoBlockRecv,
EvGoBlockSelect, EvGoBlockSync, EvGoBlockCond, EvGoBlockNet,
EvGoSysBlock, EvGoBlockGC:
lastG = 0
case EvGoSysExit, EvGoWaiting, EvGoInSyscall:
e.G = e.Args[0]
case EvUserTaskCreate:
// e.Args 0: taskID, 1:parentID, 2:nameID
e.SArgs = []string{strings[e.Args[2]]}
case EvUserRegion:
// e.Args 0: taskID, 1: mode, 2:nameID
e.SArgs = []string{strings[e.Args[2]]}
case EvUserLog:
// e.Args 0: taskID, 1:keyID, 2: stackID
e.SArgs = []string{strings[e.Args[1]], raw.sargs[0]}
case EvCPUSample:
e.Ts = int64(e.Args[0])
e.P = int(e.Args[1])
e.G = e.Args[2]
e.Args[0] = 0
}
switch raw.typ {
default:
batches[lastP] = append(batches[lastP], e)
case EvCPUSample:
// Most events are written out by the active P at the exact
// moment they describe. CPU profile samples are different
// because they're written to the tracing log after some delay,
// by a separate worker goroutine, into a separate buffer.
//
// We keep these in their own batch until all of the batches are
// merged in timestamp order. We also (right before the merge)
// re-sort these events by the timestamp captured in the
// profiling signal handler.
batches[ProfileP] = append(batches[ProfileP], e)
}
}
}
if len(batches) == 0 {
err = fmt.Errorf("trace is empty")
return
}
if ticksPerSec == 0 {
err = fmt.Errorf("no EvFrequency event")
return
}
if BreakTimestampsForTesting {
var batchArr [][]*Event
for _, batch := range batches {
batchArr = append(batchArr, batch)
}
for i := 0; i < 5; i++ {
batch := batchArr[rand.Intn(len(batchArr))]
batch[rand.Intn(len(batch))].Ts += int64(rand.Intn(2000) - 1000)
}
}
if ver < 1007 {
events, err = order1005(batches)
} else {
events, err = order1007(batches)
}
if err != nil {
return
}
// Translate cpu ticks to real time.
minTs := events[0].Ts
// Use floating point to avoid integer overflows.
freq := 1e9 / float64(ticksPerSec)
for _, ev := range events {
ev.Ts = int64(float64(ev.Ts-minTs) * freq)
// Move timers and syscalls to separate fake Ps.
if timerGoids[ev.G] && ev.Type == EvGoUnblock {
ev.P = TimerP
}
if ev.Type == EvGoSysExit {
ev.P = SyscallP
}
}
return
}
// removeFutile removes all constituents of futile wakeups (block, unblock, start).
// For example, a goroutine was unblocked on a mutex, but another goroutine got
// ahead and acquired the mutex before the first goroutine was scheduled,
// so the first goroutine had to block again. Such wakeups happen on buffered
// channels and sync.Mutex, but are generally not interesting to the end user.
func removeFutile(events []*Event) []*Event {
// Two non-trivial aspects:
// 1. A goroutine can be preempted during a futile wakeup and migrate to another P.
// We want to remove all of that.
// 2. Tracing can start in the middle of a futile wakeup.
// That is, we can see a futile wakeup event w/o the actual wakeup before it.
// postProcessTrace runs after us and ensures that we leave the trace in a consistent state.
// Phase 1: determine futile wakeup sequences.
type G struct {
futile bool
wakeup []*Event // wakeup sequence (subject for removal)
}
gs := make(map[uint64]G)
futile := make(map[*Event]bool)
for _, ev := range events {
switch ev.Type {
case EvGoUnblock:
g := gs[ev.Args[0]]
g.wakeup = []*Event{ev}
gs[ev.Args[0]] = g
case EvGoStart, EvGoPreempt, EvFutileWakeup:
g := gs[ev.G]
g.wakeup = append(g.wakeup, ev)
if ev.Type == EvFutileWakeup {
g.futile = true
}
gs[ev.G] = g
case EvGoBlock, EvGoBlockSend, EvGoBlockRecv, EvGoBlockSelect, EvGoBlockSync, EvGoBlockCond:
g := gs[ev.G]
if g.futile {
futile[ev] = true
for _, ev1 := range g.wakeup {
futile[ev1] = true
}
}
delete(gs, ev.G)
}
}
// Phase 2: remove futile wakeup sequences.
newEvents := events[:0] // overwrite the original slice
for _, ev := range events {
if !futile[ev] {
newEvents = append(newEvents, ev)
}
}
return newEvents
}
// ErrTimeOrder is returned by Parse when the trace contains
// time stamps that do not respect actual event ordering.
var ErrTimeOrder = fmt.Errorf("time stamps out of order")
// postProcessTrace does inter-event verification and information restoration.
// The resulting trace is guaranteed to be consistent
// (for example, a P does not run two Gs at the same time, or a G is indeed
// blocked before an unblock event).
func postProcessTrace(ver int, events []*Event) error {
const (
gDead = iota
gRunnable
gRunning
gWaiting
)
type gdesc struct {
state int
ev *Event
evStart *Event
evCreate *Event
evMarkAssist *Event
}
type pdesc struct {
running bool
g uint64
evSTW *Event
evSweep *Event
}
gs := make(map[uint64]gdesc)
ps := make(map[int]pdesc)
tasks := make(map[uint64]*Event) // task id to task creation events
activeRegions := make(map[uint64][]*Event) // goroutine id to stack of regions
gs[0] = gdesc{state: gRunning}
var evGC, evSTW *Event
checkRunning := func(p pdesc, g gdesc, ev *Event, allowG0 bool) error {
name := EventDescriptions[ev.Type].Name
if g.state != gRunning {
return fmt.Errorf("g %v is not running while %v (offset %v, time %v)", ev.G, name, ev.Off, ev.Ts)
}
if p.g != ev.G {
return fmt.Errorf("p %v is not running g %v while %v (offset %v, time %v)", ev.P, ev.G, name, ev.Off, ev.Ts)
}
if !allowG0 && ev.G == 0 {
return fmt.Errorf("g 0 did %v (offset %v, time %v)", EventDescriptions[ev.Type].Name, ev.Off, ev.Ts)
}
return nil
}
for _, ev := range events {
g := gs[ev.G]
p := ps[ev.P]
switch ev.Type {
case EvProcStart:
if p.running {
return fmt.Errorf("p %v is running before start (offset %v, time %v)", ev.P, ev.Off, ev.Ts)
}
p.running = true
case EvProcStop:
if !p.running {
return fmt.Errorf("p %v is not running before stop (offset %v, time %v)", ev.P, ev.Off, ev.Ts)
}
if p.g != 0 {
return fmt.Errorf("p %v is running a goroutine %v during stop (offset %v, time %v)", ev.P, p.g, ev.Off, ev.Ts)
}
p.running = false
case EvGCStart:
if evGC != nil {
return fmt.Errorf("previous GC is not ended before a new one (offset %v, time %v)", ev.Off, ev.Ts)
}
evGC = ev
// Attribute this to the global GC state.
ev.P = GCP
case EvGCDone:
if evGC == nil {
return fmt.Errorf("bogus GC end (offset %v, time %v)", ev.Off, ev.Ts)
}
evGC.Link = ev
evGC = nil
case EvGCSTWStart:
evp := &evSTW
if ver < 1010 {
// Before 1.10, EvGCSTWStart was per-P.
evp = &p.evSTW
}
if *evp != nil {
return fmt.Errorf("previous STW is not ended before a new one (offset %v, time %v)", ev.Off, ev.Ts)
}
*evp = ev
case EvGCSTWDone:
evp := &evSTW
if ver < 1010 {
// Before 1.10, EvGCSTWDone was per-P.
evp = &p.evSTW
}
if *evp == nil {
return fmt.Errorf("bogus STW end (offset %v, time %v)", ev.Off, ev.Ts)
}
(*evp).Link = ev
*evp = nil
case EvGCSweepStart:
if p.evSweep != nil {
return fmt.Errorf("previous sweeping is not ended before a new one (offset %v, time %v)", ev.Off, ev.Ts)
}
p.evSweep = ev
case EvGCMarkAssistStart:
if g.evMarkAssist != nil {
return fmt.Errorf("previous mark assist is not ended before a new one (offset %v, time %v)", ev.Off, ev.Ts)
}
g.evMarkAssist = ev
case EvGCMarkAssistDone:
// Unlike most events, mark assists can already be in progress when
// tracing starts, so we can't report an error here.
if g.evMarkAssist != nil {
g.evMarkAssist.Link = ev
g.evMarkAssist = nil
}
case EvGCSweepDone:
if p.evSweep == nil {
return fmt.Errorf("bogus sweeping end (offset %v, time %v)", ev.Off, ev.Ts)
}
p.evSweep.Link = ev
p.evSweep = nil
case EvGoWaiting:
if g.state != gRunnable {
return fmt.Errorf("g %v is not runnable before EvGoWaiting (offset %v, time %v)", ev.G, ev.Off, ev.Ts)
}
g.state = gWaiting
g.ev = ev
case EvGoInSyscall:
if g.state != gRunnable {
return fmt.Errorf("g %v is not runnable before EvGoInSyscall (offset %v, time %v)", ev.G, ev.Off, ev.Ts)
}
g.state = gWaiting
g.ev = ev
case EvGoCreate:
if err := checkRunning(p, g, ev, true); err != nil {
return err
}
if _, ok := gs[ev.Args[0]]; ok {
return fmt.Errorf("g %v already exists (offset %v, time %v)", ev.Args[0], ev.Off, ev.Ts)
}
gs[ev.Args[0]] = gdesc{state: gRunnable, ev: ev, evCreate: ev}
case EvGoStart, EvGoStartLabel:
if g.state != gRunnable {
return fmt.Errorf("g %v is not runnable before start (offset %v, time %v)", ev.G, ev.Off, ev.Ts)
}
if p.g != 0 {
return fmt.Errorf("p %v is already running g %v while start g %v (offset %v, time %v)", ev.P, p.g, ev.G, ev.Off, ev.Ts)
}
g.state = gRunning
g.evStart = ev
p.g = ev.G
if g.evCreate != nil {
if ver < 1007 {
// +1 because symbolizer expects return pc.
ev.Stk = []*Frame{{PC: g.evCreate.Args[1] + 1}}
} else {
ev.StkID = g.evCreate.Args[1]
}
g.evCreate = nil
}
if g.ev != nil {
g.ev.Link = ev
g.ev = nil
}
case EvGoEnd, EvGoStop:
if err := checkRunning(p, g, ev, false); err != nil {
return err
}
g.evStart.Link = ev
g.evStart = nil
g.state = gDead
p.g = 0
if ev.Type == EvGoEnd { // flush all active regions
regions := activeRegions[ev.G]
for _, s := range regions {
s.Link = ev
}
delete(activeRegions, ev.G)
}
case EvGoSched, EvGoPreempt:
if err := checkRunning(p, g, ev, false); err != nil {
return err
}
g.state = gRunnable
g.evStart.Link = ev
g.evStart = nil
p.g = 0
g.ev = ev
case EvGoUnblock:
if g.state != gRunning {
return fmt.Errorf("g %v is not running while unpark (offset %v, time %v)", ev.G, ev.Off, ev.Ts)
}
if ev.P != TimerP && p.g != ev.G {
return fmt.Errorf("p %v is not running g %v while unpark (offset %v, time %v)", ev.P, ev.G, ev.Off, ev.Ts)
}
g1 := gs[ev.Args[0]]
if g1.state != gWaiting {
return fmt.Errorf("g %v is not waiting before unpark (offset %v, time %v)", ev.Args[0], ev.Off, ev.Ts)
}
if g1.ev != nil && g1.ev.Type == EvGoBlockNet && ev.P != TimerP {
ev.P = NetpollP
}
if g1.ev != nil {
g1.ev.Link = ev
}
g1.state = gRunnable
g1.ev = ev
gs[ev.Args[0]] = g1
case EvGoSysCall:
if err := checkRunning(p, g, ev, false); err != nil {
return err
}
g.ev = ev
case EvGoSysBlock:
if err := checkRunning(p, g, ev, false); err != nil {
return err
}
g.state = gWaiting
g.evStart.Link = ev
g.evStart = nil
p.g = 0
case EvGoSysExit:
if g.state != gWaiting {
return fmt.Errorf("g %v is not waiting during syscall exit (offset %v, time %v)", ev.G, ev.Off, ev.Ts)
}
if g.ev != nil && g.ev.Type == EvGoSysCall {
g.ev.Link = ev
}
g.state = gRunnable
g.ev = ev
case EvGoSleep, EvGoBlock, EvGoBlockSend, EvGoBlockRecv,
EvGoBlockSelect, EvGoBlockSync, EvGoBlockCond, EvGoBlockNet, EvGoBlockGC:
if err := checkRunning(p, g, ev, false); err != nil {
return err
}
g.state = gWaiting
g.ev = ev
g.evStart.Link = ev
g.evStart = nil
p.g = 0
case EvUserTaskCreate:
taskid := ev.Args[0]
if prevEv, ok := tasks[taskid]; ok {
return fmt.Errorf("task id conflicts (id:%d), %q vs %q", taskid, ev, prevEv)
}
tasks[ev.Args[0]] = ev
case EvUserTaskEnd:
taskid := ev.Args[0]
if taskCreateEv, ok := tasks[taskid]; ok {
taskCreateEv.Link = ev
delete(tasks, taskid)
}
case EvUserRegion:
mode := ev.Args[1]
regions := activeRegions[ev.G]
if mode == 0 { // region start
activeRegions[ev.G] = append(regions, ev) // push
} else if mode == 1 { // region end
n := len(regions)
if n > 0 { // matching region start event is in the trace.
s := regions[n-1]
if s.Args[0] != ev.Args[0] || s.SArgs[0] != ev.SArgs[0] { // task id, region name mismatch
return fmt.Errorf("misuse of region in goroutine %d: span end %q when the inner-most active span start event is %q", ev.G, ev, s)
}
// Link the region start event with the region end event.
s.Link = ev
if n > 1 {
activeRegions[ev.G] = regions[:n-1]
} else {
delete(activeRegions, ev.G)
}
}
} else {
return fmt.Errorf("invalid user region mode: %q", ev)
}
}
gs[ev.G] = g
ps[ev.P] = p
}
// TODO(dvyukov): restore stacks for EvGoStart events.
// TODO(dvyukov): test that all EvGoStart events have a non-nil Link.
return nil
}
// symbolize attaches func/file/line info to stack traces.
func symbolize(events []*Event, bin string) error {
// First, collect and dedup all pcs.
pcs := make(map[uint64]*Frame)
for _, ev := range events {
for _, f := range ev.Stk {
pcs[f.PC] = nil
}
}
// Start addr2line.
cmd := exec.Command(goCmd(), "tool", "addr2line", bin)
in, err := cmd.StdinPipe()
if err != nil {
return fmt.Errorf("failed to pipe addr2line stdin: %v", err)
}
cmd.Stderr = os.Stderr
out, err := cmd.StdoutPipe()
if err != nil {
return fmt.Errorf("failed to pipe addr2line stdout: %v", err)
}
err = cmd.Start()
if err != nil {
return fmt.Errorf("failed to start addr2line: %v", err)
}
outb := bufio.NewReader(out)
// Write all pcs to addr2line.
// Copy the pcs into a slice, because map iteration order is non-deterministic.
var pcArray []uint64
for pc := range pcs {
pcArray = append(pcArray, pc)
_, err := fmt.Fprintf(in, "0x%x\n", pc-1)
if err != nil {
return fmt.Errorf("failed to write to addr2line: %v", err)
}
}
in.Close()
// Read in answers.
for _, pc := range pcArray {
fn, err := outb.ReadString('\n')
if err != nil {
return fmt.Errorf("failed to read from addr2line: %v", err)
}
file, err := outb.ReadString('\n')
if err != nil {
return fmt.Errorf("failed to read from addr2line: %v", err)
}
f := &Frame{PC: pc}
f.Fn = fn[:len(fn)-1]
f.File = file[:len(file)-1]
if colon := strings.LastIndex(f.File, ":"); colon != -1 {
ln, err := strconv.Atoi(f.File[colon+1:])
if err == nil {
f.File = f.File[:colon]
f.Line = ln
}
}
pcs[pc] = f
}
cmd.Wait()
// Replace frames in events array.
for _, ev := range events {
for i, f := range ev.Stk {
ev.Stk[i] = pcs[f.PC]
}
}
return nil
}
// readVal reads an unsigned base-128 (varint) value from r.
func readVal(r io.Reader, off0 int) (v uint64, off int, err error) {
off = off0
for i := 0; i < 10; i++ {
var buf [1]byte
var n int
n, err = r.Read(buf[:])
if err != nil || n != 1 {
return 0, 0, fmt.Errorf("failed to read trace at offset %d: read %v, error %v", off0, n, err)
}
off++
v |= uint64(buf[0]&0x7f) << (uint(i) * 7)
if buf[0]&0x80 == 0 {
return
}
}
return 0, 0, fmt.Errorf("bad value at offset 0x%x", off0)
}
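// Illustrative sketch (not part of the original source): readVal decodes
// least-significant 7 bits first, using the high bit of each byte as a
// continuation flag. Under that assumption:
//
//	r := bytes.NewReader([]byte{0xAC, 0x02}) // hypothetical input
//	v, off, err := readVal(r, 0)
//	// v == 300 (0x2C | 0x02<<7), off == 2, err == nil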
// Print dumps events to stdout. For debugging.
func Print(events []*Event) {
for _, ev := range events {
PrintEvent(ev)
}
}
// PrintEvent dumps the event to stdout. For debugging.
func PrintEvent(ev *Event) {
fmt.Printf("%s\n", ev)
}
func (ev *Event) String() string {
desc := EventDescriptions[ev.Type]
w := new(strings.Builder)
fmt.Fprintf(w, "%v %v p=%v g=%v off=%v", ev.Ts, desc.Name, ev.P, ev.G, ev.Off)
for i, a := range desc.Args {
fmt.Fprintf(w, " %v=%v", a, ev.Args[i])
}
for i, a := range desc.SArgs {
fmt.Fprintf(w, " %v=%v", a, ev.SArgs[i])
}
return w.String()
}
// argNum returns the total number of args for the event, accounting for timestamps,
// sequence numbers, and differences between trace format versions.
func argNum(raw rawEvent, ver int) int {
desc := EventDescriptions[raw.typ]
if raw.typ == EvStack {
return len(raw.args)
}
narg := len(desc.Args)
if desc.Stack {
narg++
}
switch raw.typ {
case EvBatch, EvFrequency, EvTimerGoroutine:
if ver < 1007 {
narg++ // there was an unused arg before 1.7
}
return narg
}
narg++ // timestamp
if ver < 1007 {
narg++ // sequence
}
switch raw.typ {
case EvGCSweepDone:
if ver < 1009 {
narg -= 2 // 1.9 added two arguments
}
case EvGCStart, EvGoStart, EvGoUnblock:
if ver < 1007 {
narg-- // 1.7 added an additional seq arg
}
case EvGCSTWStart:
if ver < 1010 {
narg-- // 1.10 added an argument
}
}
return narg
}
// BreakTimestampsForTesting causes the parser to randomly alter timestamps (for testing of broken cputicks).
var BreakTimestampsForTesting bool
// Event types in the trace.
// Verbatim copy from src/runtime/trace.go with the "trace" prefix removed.
const (
EvNone = 0 // unused
EvBatch = 1 // start of per-P batch of events [pid, timestamp]
EvFrequency = 2 // contains tracer timer frequency [frequency (ticks per second)]
EvStack = 3 // stack [stack id, number of PCs, array of {PC, func string ID, file string ID, line}]
EvGomaxprocs = 4 // current value of GOMAXPROCS [timestamp, GOMAXPROCS, stack id]
EvProcStart = 5 // start of P [timestamp, thread id]
EvProcStop = 6 // stop of P [timestamp]
EvGCStart = 7 // GC start [timestamp, seq, stack id]
EvGCDone = 8 // GC done [timestamp]
EvGCSTWStart = 9 // GC mark termination start [timestamp, kind]
EvGCSTWDone = 10 // GC mark termination done [timestamp]
EvGCSweepStart = 11 // GC sweep start [timestamp, stack id]
EvGCSweepDone = 12 // GC sweep done [timestamp, swept, reclaimed]
EvGoCreate = 13 // goroutine creation [timestamp, new goroutine id, new stack id, stack id]
EvGoStart = 14 // goroutine starts running [timestamp, goroutine id, seq]
EvGoEnd = 15 // goroutine ends [timestamp]
EvGoStop = 16 // goroutine stops (like in select{}) [timestamp, stack]
EvGoSched = 17 // goroutine calls Gosched [timestamp, stack]
EvGoPreempt = 18 // goroutine is preempted [timestamp, stack]
EvGoSleep = 19 // goroutine calls Sleep [timestamp, stack]
EvGoBlock = 20 // goroutine blocks [timestamp, stack]
EvGoUnblock = 21 // goroutine is unblocked [timestamp, goroutine id, seq, stack]
EvGoBlockSend = 22 // goroutine blocks on chan send [timestamp, stack]
EvGoBlockRecv = 23 // goroutine blocks on chan recv [timestamp, stack]
EvGoBlockSelect = 24 // goroutine blocks on select [timestamp, stack]
EvGoBlockSync = 25 // goroutine blocks on Mutex/RWMutex [timestamp, stack]
EvGoBlockCond = 26 // goroutine blocks on Cond [timestamp, stack]
EvGoBlockNet = 27 // goroutine blocks on network [timestamp, stack]
EvGoSysCall = 28 // syscall enter [timestamp, stack]
EvGoSysExit = 29 // syscall exit [timestamp, goroutine id, seq, real timestamp]
EvGoSysBlock = 30 // syscall blocks [timestamp]
EvGoWaiting = 31 // denotes that goroutine is blocked when tracing starts [timestamp, goroutine id]
EvGoInSyscall = 32 // denotes that goroutine is in syscall when tracing starts [timestamp, goroutine id]
EvHeapAlloc = 33 // gcController.heapLive change [timestamp, heap live bytes]
EvHeapGoal = 34 // gcController.heapGoal change [timestamp, heap goal bytes]
EvTimerGoroutine = 35 // denotes timer goroutine [timer goroutine id]
EvFutileWakeup = 36 // denotes that the previous wakeup of this goroutine was futile [timestamp]
EvString = 37 // string dictionary entry [ID, length, string]
EvGoStartLocal = 38 // goroutine starts running on the same P as the last event [timestamp, goroutine id]
EvGoUnblockLocal = 39 // goroutine is unblocked on the same P as the last event [timestamp, goroutine id, stack]
EvGoSysExitLocal = 40 // syscall exit on the same P as the last event [timestamp, goroutine id, real timestamp]
EvGoStartLabel = 41 // goroutine starts running with label [timestamp, goroutine id, seq, label string id]
EvGoBlockGC = 42 // goroutine blocks on GC assist [timestamp, stack]
EvGCMarkAssistStart = 43 // GC mark assist start [timestamp, stack]
EvGCMarkAssistDone = 44 // GC mark assist done [timestamp]
EvUserTaskCreate = 45 // trace.NewContext [timestamp, internal task id, internal parent id, stack, name string]
EvUserTaskEnd = 46 // end of task [timestamp, internal task id, stack]
EvUserRegion = 47 // trace.WithRegion [timestamp, internal task id, mode(0:start, 1:end), stack, name string]
EvUserLog = 48 // trace.Log [timestamp, internal id, key string id, stack, value string]
EvCPUSample = 49 // CPU profiling sample [timestamp, real timestamp, real P id (-1 when absent), goroutine id, stack]
EvCount = 50
)
var EventDescriptions = [EvCount]struct {
Name string
minVersion int
Stack bool
Args []string
SArgs []string // string arguments
}{
EvNone: {"None", 1005, false, []string{}, nil},
EvBatch: {"Batch", 1005, false, []string{"p", "ticks"}, nil}, // in 1.5 format it was {"p", "seq", "ticks"}
EvFrequency: {"Frequency", 1005, false, []string{"freq"}, nil}, // in 1.5 format it was {"freq", "unused"}
EvStack: {"Stack", 1005, false, []string{"id", "siz"}, nil},
EvGomaxprocs: {"Gomaxprocs", 1005, true, []string{"procs"}, nil},
EvProcStart: {"ProcStart", 1005, false, []string{"thread"}, nil},
EvProcStop: {"ProcStop", 1005, false, []string{}, nil},
EvGCStart: {"GCStart", 1005, true, []string{"seq"}, nil}, // in 1.5 format it was {}
EvGCDone: {"GCDone", 1005, false, []string{}, nil},
EvGCSTWStart: {"GCSTWStart", 1005, false, []string{"kindid"}, []string{"kind"}}, // <= 1.9, args was {} (implicitly {0})
EvGCSTWDone: {"GCSTWDone", 1005, false, []string{}, nil},
EvGCSweepStart: {"GCSweepStart", 1005, true, []string{}, nil},
EvGCSweepDone: {"GCSweepDone", 1005, false, []string{"swept", "reclaimed"}, nil}, // before 1.9, format was {}
EvGoCreate: {"GoCreate", 1005, true, []string{"g", "stack"}, nil},
EvGoStart: {"GoStart", 1005, false, []string{"g", "seq"}, nil}, // in 1.5 format it was {"g"}
EvGoEnd: {"GoEnd", 1005, false, []string{}, nil},
EvGoStop: {"GoStop", 1005, true, []string{}, nil},
EvGoSched: {"GoSched", 1005, true, []string{}, nil},
EvGoPreempt: {"GoPreempt", 1005, true, []string{}, nil},
EvGoSleep: {"GoSleep", 1005, true, []string{}, nil},
EvGoBlock: {"GoBlock", 1005, true, []string{}, nil},
EvGoUnblock: {"GoUnblock", 1005, true, []string{"g", "seq"}, nil}, // in 1.5 format it was {"g"}
EvGoBlockSend: {"GoBlockSend", 1005, true, []string{}, nil},
EvGoBlockRecv: {"GoBlockRecv", 1005, true, []string{}, nil},
EvGoBlockSelect: {"GoBlockSelect", 1005, true, []string{}, nil},
EvGoBlockSync: {"GoBlockSync", 1005, true, []string{}, nil},
EvGoBlockCond: {"GoBlockCond", 1005, true, []string{}, nil},
EvGoBlockNet: {"GoBlockNet", 1005, true, []string{}, nil},
EvGoSysCall: {"GoSysCall", 1005, true, []string{}, nil},
EvGoSysExit: {"GoSysExit", 1005, false, []string{"g", "seq", "ts"}, nil},
EvGoSysBlock: {"GoSysBlock", 1005, false, []string{}, nil},
EvGoWaiting: {"GoWaiting", 1005, false, []string{"g"}, nil},
EvGoInSyscall: {"GoInSyscall", 1005, false, []string{"g"}, nil},
EvHeapAlloc: {"HeapAlloc", 1005, false, []string{"mem"}, nil},
EvHeapGoal: {"HeapGoal", 1005, false, []string{"mem"}, nil},
EvTimerGoroutine: {"TimerGoroutine", 1005, false, []string{"g"}, nil}, // in 1.5 format it was {"g", "unused"}
EvFutileWakeup: {"FutileWakeup", 1005, false, []string{}, nil},
EvString: {"String", 1007, false, []string{}, nil},
EvGoStartLocal: {"GoStartLocal", 1007, false, []string{"g"}, nil},
EvGoUnblockLocal: {"GoUnblockLocal", 1007, true, []string{"g"}, nil},
EvGoSysExitLocal: {"GoSysExitLocal", 1007, false, []string{"g", "ts"}, nil},
EvGoStartLabel: {"GoStartLabel", 1008, false, []string{"g", "seq", "labelid"}, []string{"label"}},
EvGoBlockGC: {"GoBlockGC", 1008, true, []string{}, nil},
EvGCMarkAssistStart: {"GCMarkAssistStart", 1009, true, []string{}, nil},
EvGCMarkAssistDone: {"GCMarkAssistDone", 1009, false, []string{}, nil},
EvUserTaskCreate: {"UserTaskCreate", 1011, true, []string{"taskid", "pid", "typeid"}, []string{"name"}},
EvUserTaskEnd: {"UserTaskEnd", 1011, true, []string{"taskid"}, nil},
EvUserRegion: {"UserRegion", 1011, true, []string{"taskid", "mode", "typeid"}, []string{"name"}},
EvUserLog: {"UserLog", 1011, true, []string{"id", "keyid"}, []string{"category", "message"}},
EvCPUSample: {"CPUSample", 1019, true, []string{"ts", "p", "g"}, nil},
}
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package trace
import "bytes"
// Writer is a test trace writer.
type Writer struct {
bytes.Buffer
}
func NewWriter() *Writer {
w := new(Writer)
w.Write([]byte("go 1.9 trace\x00\x00\x00\x00"))
return w
}
// Emit writes an event record to the trace.
// See Event types for valid types and required arguments.
func (w *Writer) Emit(typ byte, args ...uint64) {
nargs := byte(len(args)) - 1
if nargs > 3 {
nargs = 3
}
buf := []byte{typ | nargs<<6}
if nargs == 3 {
buf = append(buf, 0)
}
for _, a := range args {
buf = appendVarint(buf, a)
}
if nargs == 3 {
buf[1] = byte(len(buf) - 2)
}
n, err := w.Write(buf)
if n != len(buf) || err != nil {
panic("failed to write")
}
}
func appendVarint(buf []byte, v uint64) []byte {
for ; v >= 0x80; v >>= 7 {
buf = append(buf, 0x80|byte(v))
}
buf = append(buf, byte(v))
return buf
}
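// Illustrative sketch (not part of the original source): appendVarint
// emits the low 7 bits first and sets the high bit on every byte except
// the last, matching the encoding readVal in the parser consumes:
//
//	appendVarint(nil, 0x7F) // [0x7F]
//	appendVarint(nil, 300)  // [0xAC, 0x02]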
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xcoff
import (
"encoding/binary"
"fmt"
"io"
"os"
"strconv"
"strings"
)
const (
SAIAMAG = 0x8
AIAFMAG = "`\n"
AIAMAG = "<aiaff>\n"
AIAMAGBIG = "<bigaf>\n"
// Sizes of the big archive file header and member header.
FL_HSZ_BIG = 0x80
AR_HSZ_BIG = 0x70
)
type bigarFileHeader struct {
Flmagic [SAIAMAG]byte // Archive magic string
Flmemoff [20]byte // Member table offset
Flgstoff [20]byte // 32-bits global symtab offset
Flgst64off [20]byte // 64-bits global symtab offset
Flfstmoff [20]byte // First member offset
Fllstmoff [20]byte // Last member offset
Flfreeoff [20]byte // First member on free list offset
}
type bigarMemberHeader struct {
Arsize [20]byte // File member size
Arnxtmem [20]byte // Next member pointer
Arprvmem [20]byte // Previous member pointer
Ardate [12]byte // File member date
Aruid [12]byte // File member uid
Argid [12]byte // File member gid
Armode [12]byte // File member mode (octal)
Arnamlen [4]byte // File member name length
// _ar_nam is removed because it's easier to get the name without it.
}
// Archive represents an open AIX big archive.
type Archive struct {
ArchiveHeader
Members []*Member
closer io.Closer
}
// ArchiveHeader holds information about a big archive file header
type ArchiveHeader struct {
magic string
}
// Member represents a member of an AIX big archive.
type Member struct {
MemberHeader
sr *io.SectionReader
}
// MemberHeader holds information about a big archive member
type MemberHeader struct {
Name string
Size uint64
}
// OpenArchive opens the named archive using os.Open and prepares it for use
// as an AIX big archive.
func OpenArchive(name string) (*Archive, error) {
f, err := os.Open(name)
if err != nil {
return nil, err
}
arch, err := NewArchive(f)
if err != nil {
f.Close()
return nil, err
}
arch.closer = f
return arch, nil
}
// Close closes the Archive.
// If the Archive was created using NewArchive directly instead of OpenArchive,
// Close has no effect.
func (a *Archive) Close() error {
var err error
if a.closer != nil {
err = a.closer.Close()
a.closer = nil
}
return err
}
// NewArchive creates a new Archive for accessing an AIX big archive in an underlying reader.
func NewArchive(r io.ReaderAt) (*Archive, error) {
parseDecimalBytes := func(b []byte) (int64, error) {
return strconv.ParseInt(strings.TrimSpace(string(b)), 10, 64)
}
sr := io.NewSectionReader(r, 0, 1<<63-1)
// Read File Header
var magic [SAIAMAG]byte
if _, err := sr.ReadAt(magic[:], 0); err != nil {
return nil, err
}
arch := new(Archive)
switch string(magic[:]) {
case AIAMAGBIG:
arch.magic = string(magic[:])
case AIAMAG:
return nil, fmt.Errorf("small AIX archive not supported")
default:
return nil, fmt.Errorf("unrecognised archive magic: 0x%x", magic)
}
var fhdr bigarFileHeader
if _, err := sr.Seek(0, io.SeekStart); err != nil {
return nil, err
}
if err := binary.Read(sr, binary.BigEndian, &fhdr); err != nil {
return nil, err
}
off, err := parseDecimalBytes(fhdr.Flfstmoff[:])
if err != nil {
return nil, fmt.Errorf("error parsing offset of first member in archive header(%q); %v", fhdr, err)
}
if off == 0 {
// Occurs if the archive is empty.
return arch, nil
}
lastoff, err := parseDecimalBytes(fhdr.Fllstmoff[:])
if err != nil {
return nil, fmt.Errorf("error parsing offset of first member in archive header(%q); %v", fhdr, err)
}
// Read members
for {
// Read Member Header
// The member header is normally 2 bytes larger. But it's easier
// to read the name if the header is read without _ar_nam.
// However, AIAFMAG must be read afterward.
if _, err := sr.Seek(off, io.SeekStart); err != nil {
return nil, err
}
var mhdr bigarMemberHeader
if err := binary.Read(sr, binary.BigEndian, &mhdr); err != nil {
return nil, err
}
member := new(Member)
arch.Members = append(arch.Members, member)
size, err := parseDecimalBytes(mhdr.Arsize[:])
if err != nil {
return nil, fmt.Errorf("error parsing size in member header(%q); %v", mhdr, err)
}
member.Size = uint64(size)
// Read name
namlen, err := parseDecimalBytes(mhdr.Arnamlen[:])
if err != nil {
return nil, fmt.Errorf("error parsing name length in member header(%q); %v", mhdr, err)
}
name := make([]byte, namlen)
if err := binary.Read(sr, binary.BigEndian, name); err != nil {
return nil, err
}
member.Name = string(name)
fileoff := off + AR_HSZ_BIG + namlen
if fileoff&1 != 0 {
fileoff++
if _, err := sr.Seek(1, io.SeekCurrent); err != nil {
return nil, err
}
}
// Read AIAFMAG string
var fmag [2]byte
if err := binary.Read(sr, binary.BigEndian, &fmag); err != nil {
return nil, err
}
if string(fmag[:]) != AIAFMAG {
return nil, fmt.Errorf("AIAFMAG not found after member header")
}
fileoff += 2 // Add the two bytes of AIAFMAG
member.sr = io.NewSectionReader(sr, fileoff, size)
if off == lastoff {
break
}
off, err = parseDecimalBytes(mhdr.Arnxtmem[:])
if err != nil {
return nil, fmt.Errorf("error parsing offset of first member in archive header(%q); %v", fhdr, err)
}
}
return arch, nil
}
// GetFile returns the XCOFF file defined by member name.
// FIXME: This doesn't work if an archive has two members with the same
// name, which can occur if an archive contains both 32-bit and 64-bit files.
func (arch *Archive) GetFile(name string) (*File, error) {
for _, mem := range arch.Members {
if mem.Name == name {
return NewFile(mem.sr)
}
}
return nil, fmt.Errorf("unknown member %s in archive", name)
}
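// Illustrative sketch (not part of the original source): typical use of
// the archive API, with hypothetical file and member names.
//
//	arch, err := OpenArchive("libfoo.a")
//	if err != nil {
//		// handle error
//	}
//	defer arch.Close()
//	f, err := arch.GetFile("foo.o") // f is an *xcoff.File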
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package xcoff implements access to XCOFF (Extended Common Object File Format) files.
package xcoff
import (
"debug/dwarf"
"encoding/binary"
"fmt"
"internal/saferio"
"io"
"os"
"strings"
)
// SectionHeader holds information about an XCOFF section header.
type SectionHeader struct {
Name string
VirtualAddress uint64
Size uint64
Type uint32
Relptr uint64
Nreloc uint32
}
type Section struct {
SectionHeader
Relocs []Reloc
io.ReaderAt
sr *io.SectionReader
}
// AuxiliaryCSect holds information about an XCOFF symbol in an AUX_CSECT entry.
type AuxiliaryCSect struct {
Length int64
StorageMappingClass int
SymbolType int
}
// AuxiliaryFcn holds information about an XCOFF symbol in an AUX_FCN entry.
type AuxiliaryFcn struct {
Size int64
}
type Symbol struct {
Name string
Value uint64
SectionNumber int
StorageClass int
AuxFcn AuxiliaryFcn
AuxCSect AuxiliaryCSect
}
type Reloc struct {
VirtualAddress uint64
Symbol *Symbol
Signed bool
InstructionFixed bool
Length uint8
Type uint8
}
// ImportedSymbol holds information about an imported XCOFF symbol.
type ImportedSymbol struct {
Name string
Library string
}
// FileHeader holds information about an XCOFF file header.
type FileHeader struct {
TargetMachine uint16
}
// A File represents an open XCOFF file.
type File struct {
FileHeader
Sections []*Section
Symbols []*Symbol
StringTable []byte
LibraryPaths []string
closer io.Closer
}
// Open opens the named file using os.Open and prepares it for use as an XCOFF binary.
func Open(name string) (*File, error) {
f, err := os.Open(name)
if err != nil {
return nil, err
}
ff, err := NewFile(f)
if err != nil {
f.Close()
return nil, err
}
ff.closer = f
return ff, nil
}
// Close closes the File.
// If the File was created using NewFile directly instead of Open,
// Close has no effect.
func (f *File) Close() error {
var err error
if f.closer != nil {
err = f.closer.Close()
f.closer = nil
}
return err
}
// Section returns the first section with the given name, or nil if no such
// section exists.
// XCOFF limits section names to 8 bytes, so the names of sections like
// .gosymtab may be truncated, but this method will still find them.
func (f *File) Section(name string) *Section {
for _, s := range f.Sections {
if s.Name == name || (len(name) > 8 && s.Name == name[:8]) {
return s
}
}
return nil
}
// SectionByType returns the first section in f with the
// given type, or nil if there is no such section.
func (f *File) SectionByType(typ uint32) *Section {
for _, s := range f.Sections {
if s.Type == typ {
return s
}
}
return nil
}
// cstring converts the ASCII byte sequence b to a string.
// It stops at the first 0 byte or at the end of b.
func cstring(b []byte) string {
var i int
for i = 0; i < len(b) && b[i] != 0; i++ {
}
return string(b[:i])
}
// getString extracts a string from an XCOFF string table.
func getString(st []byte, offset uint32) (string, bool) {
if offset < 4 || int(offset) >= len(st) {
return "", false
}
return cstring(st[offset:]), true
}
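// Illustrative sketch (not part of the original source): cstring and
// getString together resolve names from the string table, whose first
// four bytes hold its length (which is why offsets below 4 are invalid).
// With a hypothetical table:
//
//	st := []byte{0, 0, 0, 10, '.', 't', 'e', 'x', 't', 0}
//	name, ok := getString(st, 4) // name == ".text", ok == true
//	_, ok = getString(st, 2)     // ok == false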
// NewFile creates a new File for accessing an XCOFF binary in an underlying reader.
func NewFile(r io.ReaderAt) (*File, error) {
sr := io.NewSectionReader(r, 0, 1<<63-1)
// Read XCOFF target machine
var magic uint16
if err := binary.Read(sr, binary.BigEndian, &magic); err != nil {
return nil, err
}
if magic != U802TOCMAGIC && magic != U64_TOCMAGIC {
return nil, fmt.Errorf("unrecognised XCOFF magic: 0x%x", magic)
}
f := new(File)
f.TargetMachine = magic
// Read XCOFF file header
if _, err := sr.Seek(0, io.SeekStart); err != nil {
return nil, err
}
var nscns uint16
var symptr uint64
var nsyms uint32
var opthdr uint16
var hdrsz int
switch f.TargetMachine {
case U802TOCMAGIC:
fhdr := new(FileHeader32)
if err := binary.Read(sr, binary.BigEndian, fhdr); err != nil {
return nil, err
}
nscns = fhdr.Fnscns
symptr = uint64(fhdr.Fsymptr)
nsyms = fhdr.Fnsyms
opthdr = fhdr.Fopthdr
hdrsz = FILHSZ_32
case U64_TOCMAGIC:
fhdr := new(FileHeader64)
if err := binary.Read(sr, binary.BigEndian, fhdr); err != nil {
return nil, err
}
nscns = fhdr.Fnscns
symptr = fhdr.Fsymptr
nsyms = fhdr.Fnsyms
opthdr = fhdr.Fopthdr
hdrsz = FILHSZ_64
}
if symptr == 0 || nsyms <= 0 {
return nil, fmt.Errorf("no symbol table")
}
// Read string table (located right after symbol table).
offset := symptr + uint64(nsyms)*SYMESZ
if _, err := sr.Seek(int64(offset), io.SeekStart); err != nil {
return nil, err
}
// The first 4 bytes contain the length (in bytes).
var l uint32
if err := binary.Read(sr, binary.BigEndian, &l); err != nil {
return nil, err
}
if l > 4 {
st, err := saferio.ReadDataAt(sr, uint64(l), int64(offset))
if err != nil {
return nil, err
}
f.StringTable = st
}
// Read section headers
if _, err := sr.Seek(int64(hdrsz)+int64(opthdr), io.SeekStart); err != nil {
return nil, err
}
c := saferio.SliceCap((**Section)(nil), uint64(nscns))
if c < 0 {
return nil, fmt.Errorf("too many XCOFF sections (%d)", nscns)
}
f.Sections = make([]*Section, 0, c)
for i := 0; i < int(nscns); i++ {
var scnptr uint64
s := new(Section)
switch f.TargetMachine {
case U802TOCMAGIC:
shdr := new(SectionHeader32)
if err := binary.Read(sr, binary.BigEndian, shdr); err != nil {
return nil, err
}
s.Name = cstring(shdr.Sname[:])
s.VirtualAddress = uint64(shdr.Svaddr)
s.Size = uint64(shdr.Ssize)
scnptr = uint64(shdr.Sscnptr)
s.Type = shdr.Sflags
s.Relptr = uint64(shdr.Srelptr)
s.Nreloc = uint32(shdr.Snreloc)
case U64_TOCMAGIC:
shdr := new(SectionHeader64)
if err := binary.Read(sr, binary.BigEndian, shdr); err != nil {
return nil, err
}
s.Name = cstring(shdr.Sname[:])
s.VirtualAddress = shdr.Svaddr
s.Size = shdr.Ssize
scnptr = shdr.Sscnptr
s.Type = shdr.Sflags
s.Relptr = shdr.Srelptr
s.Nreloc = shdr.Snreloc
}
r2 := r
if scnptr == 0 { // .bss must have all 0s
r2 = zeroReaderAt{}
}
s.sr = io.NewSectionReader(r2, int64(scnptr), int64(s.Size))
s.ReaderAt = s.sr
f.Sections = append(f.Sections, s)
}
// Symbol map needed by relocation
var idxToSym = make(map[int]*Symbol)
// Read symbol table
if _, err := sr.Seek(int64(symptr), io.SeekStart); err != nil {
return nil, err
}
f.Symbols = make([]*Symbol, 0)
for i := 0; i < int(nsyms); i++ {
var numaux int
var ok, needAuxFcn bool
sym := new(Symbol)
switch f.TargetMachine {
case U802TOCMAGIC:
se := new(SymEnt32)
if err := binary.Read(sr, binary.BigEndian, se); err != nil {
return nil, err
}
numaux = int(se.Nnumaux)
sym.SectionNumber = int(se.Nscnum)
sym.StorageClass = int(se.Nsclass)
sym.Value = uint64(se.Nvalue)
needAuxFcn = se.Ntype&SYM_TYPE_FUNC != 0 && numaux > 1
zeroes := binary.BigEndian.Uint32(se.Nname[:4])
if zeroes != 0 {
sym.Name = cstring(se.Nname[:])
} else {
offset := binary.BigEndian.Uint32(se.Nname[4:])
sym.Name, ok = getString(f.StringTable, offset)
if !ok {
goto skip
}
}
case U64_TOCMAGIC:
se := new(SymEnt64)
if err := binary.Read(sr, binary.BigEndian, se); err != nil {
return nil, err
}
numaux = int(se.Nnumaux)
sym.SectionNumber = int(se.Nscnum)
sym.StorageClass = int(se.Nsclass)
sym.Value = se.Nvalue
needAuxFcn = se.Ntype&SYM_TYPE_FUNC != 0 && numaux > 1
sym.Name, ok = getString(f.StringTable, se.Noffset)
if !ok {
goto skip
}
}
if sym.StorageClass != C_EXT && sym.StorageClass != C_WEAKEXT && sym.StorageClass != C_HIDEXT {
goto skip
}
// Must have at least one csect auxiliary entry.
if numaux < 1 || i+numaux >= int(nsyms) {
goto skip
}
if sym.SectionNumber > int(nscns) {
goto skip
}
if sym.SectionNumber == 0 {
sym.Value = 0
} else {
sym.Value -= f.Sections[sym.SectionNumber-1].VirtualAddress
}
idxToSym[i] = sym
// If this symbol is a function, its size must be retrieved from
// its AUX_FCN entry.
// A function symbol may not have any AUX_FCN entry; in that case,
// needAuxFcn is false and its size is set to 0.
if needAuxFcn {
switch f.TargetMachine {
case U802TOCMAGIC:
aux := new(AuxFcn32)
if err := binary.Read(sr, binary.BigEndian, aux); err != nil {
return nil, err
}
sym.AuxFcn.Size = int64(aux.Xfsize)
case U64_TOCMAGIC:
aux := new(AuxFcn64)
if err := binary.Read(sr, binary.BigEndian, aux); err != nil {
return nil, err
}
sym.AuxFcn.Size = int64(aux.Xfsize)
}
}
// Read csect auxiliary entry (by convention, it is the last).
if !needAuxFcn {
if _, err := sr.Seek(int64(numaux-1)*SYMESZ, io.SeekCurrent); err != nil {
return nil, err
}
}
i += numaux
numaux = 0
switch f.TargetMachine {
case U802TOCMAGIC:
aux := new(AuxCSect32)
if err := binary.Read(sr, binary.BigEndian, aux); err != nil {
return nil, err
}
sym.AuxCSect.SymbolType = int(aux.Xsmtyp & 0x7)
sym.AuxCSect.StorageMappingClass = int(aux.Xsmclas)
sym.AuxCSect.Length = int64(aux.Xscnlen)
case U64_TOCMAGIC:
aux := new(AuxCSect64)
if err := binary.Read(sr, binary.BigEndian, aux); err != nil {
return nil, err
}
sym.AuxCSect.SymbolType = int(aux.Xsmtyp & 0x7)
sym.AuxCSect.StorageMappingClass = int(aux.Xsmclas)
sym.AuxCSect.Length = int64(aux.Xscnlenhi)<<32 | int64(aux.Xscnlenlo)
}
f.Symbols = append(f.Symbols, sym)
skip:
i += numaux // Skip auxiliary entries
if _, err := sr.Seek(int64(numaux)*SYMESZ, io.SeekCurrent); err != nil {
return nil, err
}
}
// Read relocations
// Only for the .text and .data sections.
for sectNum, sect := range f.Sections {
if sect.Type != STYP_TEXT && sect.Type != STYP_DATA {
continue
}
if sect.Relptr == 0 {
continue
}
c := saferio.SliceCap((*Reloc)(nil), uint64(sect.Nreloc))
if c < 0 {
return nil, fmt.Errorf("too many relocs (%d) for section %d", sect.Nreloc, sectNum)
}
sect.Relocs = make([]Reloc, 0, c)
if _, err := sr.Seek(int64(sect.Relptr), io.SeekStart); err != nil {
return nil, err
}
for i := uint32(0); i < sect.Nreloc; i++ {
var reloc Reloc
switch f.TargetMachine {
case U802TOCMAGIC:
rel := new(Reloc32)
if err := binary.Read(sr, binary.BigEndian, rel); err != nil {
return nil, err
}
reloc.VirtualAddress = uint64(rel.Rvaddr)
reloc.Symbol = idxToSym[int(rel.Rsymndx)]
reloc.Type = rel.Rtype
reloc.Length = rel.Rsize&0x3F + 1
if rel.Rsize&0x80 != 0 {
reloc.Signed = true
}
if rel.Rsize&0x40 != 0 {
reloc.InstructionFixed = true
}
case U64_TOCMAGIC:
rel := new(Reloc64)
if err := binary.Read(sr, binary.BigEndian, rel); err != nil {
return nil, err
}
reloc.VirtualAddress = rel.Rvaddr
reloc.Symbol = idxToSym[int(rel.Rsymndx)]
reloc.Type = rel.Rtype
reloc.Length = rel.Rsize&0x3F + 1
if rel.Rsize&0x80 != 0 {
reloc.Signed = true
}
if rel.Rsize&0x40 != 0 {
reloc.InstructionFixed = true
}
}
sect.Relocs = append(sect.Relocs, reloc)
}
}
return f, nil
}
// zeroReaderAt is an io.ReaderAt that reads only zeros.
type zeroReaderAt struct{}
// ReadAt writes len(p) 0s into p.
func (w zeroReaderAt) ReadAt(p []byte, off int64) (n int, err error) {
for i := range p {
p[i] = 0
}
return len(p), nil
}
// Data reads and returns the contents of the XCOFF section s.
func (s *Section) Data() ([]byte, error) {
dat := make([]byte, s.sr.Size())
n, err := s.sr.ReadAt(dat, 0)
if n == len(dat) {
err = nil
}
return dat[:n], err
}
// CSect reads and returns the contents of a csect.
func (f *File) CSect(name string) []byte {
for _, sym := range f.Symbols {
if sym.Name == name && sym.AuxCSect.SymbolType == XTY_SD {
if i := sym.SectionNumber - 1; 0 <= i && i < len(f.Sections) {
s := f.Sections[i]
if sym.Value+uint64(sym.AuxCSect.Length) <= s.Size {
dat := make([]byte, sym.AuxCSect.Length)
_, err := s.sr.ReadAt(dat, int64(sym.Value))
if err != nil {
return nil
}
return dat
}
}
break
}
}
return nil
}
func (f *File) DWARF() (*dwarf.Data, error) {
// There are many other DWARF sections, but these
// are the ones the debug/dwarf package uses.
// Don't bother loading others.
var subtypes = [...]uint32{SSUBTYP_DWABREV, SSUBTYP_DWINFO, SSUBTYP_DWLINE, SSUBTYP_DWRNGES, SSUBTYP_DWSTR}
var dat [len(subtypes)][]byte
for i, subtype := range subtypes {
s := f.SectionByType(STYP_DWARF | subtype)
if s != nil {
b, err := s.Data()
if err != nil && uint64(len(b)) < s.Size {
return nil, err
}
dat[i] = b
}
}
abbrev, info, line, ranges, str := dat[0], dat[1], dat[2], dat[3], dat[4]
return dwarf.New(abbrev, nil, nil, info, line, nil, ranges, str)
}
// readImportIDs returns the import file IDs stored inside the .loader section.
// The library name pattern is either path/base/member or base/member.
func (f *File) readImportIDs(s *Section) ([]string, error) {
// Read loader header
if _, err := s.sr.Seek(0, io.SeekStart); err != nil {
return nil, err
}
var istlen uint32
var nimpid uint32
var impoff uint64
switch f.TargetMachine {
case U802TOCMAGIC:
lhdr := new(LoaderHeader32)
if err := binary.Read(s.sr, binary.BigEndian, lhdr); err != nil {
return nil, err
}
istlen = lhdr.Listlen
nimpid = lhdr.Lnimpid
impoff = uint64(lhdr.Limpoff)
case U64_TOCMAGIC:
lhdr := new(LoaderHeader64)
if err := binary.Read(s.sr, binary.BigEndian, lhdr); err != nil {
return nil, err
}
istlen = lhdr.Listlen
nimpid = lhdr.Lnimpid
impoff = lhdr.Limpoff
}
// Read loader import file ID table
if _, err := s.sr.Seek(int64(impoff), io.SeekStart); err != nil {
return nil, err
}
table := make([]byte, istlen)
if _, err := io.ReadFull(s.sr, table); err != nil {
return nil, err
}
offset := 0
// First import file ID is the default LIBPATH value
libpath := cstring(table[offset:])
f.LibraryPaths = strings.Split(libpath, ":")
offset += len(libpath) + 3 // 3 null bytes
all := make([]string, 0)
for i := 1; i < int(nimpid); i++ {
impidpath := cstring(table[offset:])
offset += len(impidpath) + 1
impidbase := cstring(table[offset:])
offset += len(impidbase) + 1
impidmem := cstring(table[offset:])
offset += len(impidmem) + 1
var path string
if len(impidpath) > 0 {
path = impidpath + "/" + impidbase + "/" + impidmem
} else {
path = impidbase + "/" + impidmem
}
all = append(all, path)
}
return all, nil
}
// ImportedSymbols returns the names of all symbols
// referred to by the binary f that are expected to be
// satisfied by other libraries at dynamic load time.
// It does not return weak symbols.
func (f *File) ImportedSymbols() ([]ImportedSymbol, error) {
s := f.SectionByType(STYP_LOADER)
if s == nil {
return nil, nil
}
// Read loader header
if _, err := s.sr.Seek(0, io.SeekStart); err != nil {
return nil, err
}
var stlen uint32
var stoff uint64
var nsyms uint32
var symoff uint64
switch f.TargetMachine {
case U802TOCMAGIC:
lhdr := new(LoaderHeader32)
if err := binary.Read(s.sr, binary.BigEndian, lhdr); err != nil {
return nil, err
}
stlen = lhdr.Lstlen
stoff = uint64(lhdr.Lstoff)
nsyms = lhdr.Lnsyms
symoff = LDHDRSZ_32
case U64_TOCMAGIC:
lhdr := new(LoaderHeader64)
if err := binary.Read(s.sr, binary.BigEndian, lhdr); err != nil {
return nil, err
}
stlen = lhdr.Lstlen
stoff = lhdr.Lstoff
nsyms = lhdr.Lnsyms
symoff = lhdr.Lsymoff
}
// Read loader section string table
if _, err := s.sr.Seek(int64(stoff), io.SeekStart); err != nil {
return nil, err
}
st := make([]byte, stlen)
if _, err := io.ReadFull(s.sr, st); err != nil {
return nil, err
}
// Read imported libraries
libs, err := f.readImportIDs(s)
if err != nil {
return nil, err
}
// Read loader symbol table
if _, err := s.sr.Seek(int64(symoff), io.SeekStart); err != nil {
return nil, err
}
all := make([]ImportedSymbol, 0)
for i := 0; i < int(nsyms); i++ {
var name string
var ifile uint32
var ok bool
switch f.TargetMachine {
case U802TOCMAGIC:
ldsym := new(LoaderSymbol32)
if err := binary.Read(s.sr, binary.BigEndian, ldsym); err != nil {
return nil, err
}
if ldsym.Lsmtype&0x40 == 0 {
continue // Imported symbols only
}
zeroes := binary.BigEndian.Uint32(ldsym.Lname[:4])
if zeroes != 0 {
name = cstring(ldsym.Lname[:])
} else {
offset := binary.BigEndian.Uint32(ldsym.Lname[4:])
name, ok = getString(st, offset)
if !ok {
continue
}
}
ifile = ldsym.Lifile
case U64_TOCMAGIC:
ldsym := new(LoaderSymbol64)
if err := binary.Read(s.sr, binary.BigEndian, ldsym); err != nil {
return nil, err
}
if ldsym.Lsmtype&0x40 == 0 {
continue // Imported symbols only
}
name, ok = getString(st, ldsym.Loffset)
if !ok {
continue
}
ifile = ldsym.Lifile
}
var sym ImportedSymbol
sym.Name = name
if ifile >= 1 && int(ifile) <= len(libs) {
sym.Library = libs[ifile-1]
}
all = append(all, sym)
}
return all, nil
}
// ImportedLibraries returns the names of all libraries
// referred to by the binary f that are expected to be
// linked with the binary at dynamic link time.
func (f *File) ImportedLibraries() ([]string, error) {
s := f.SectionByType(STYP_LOADER)
if s == nil {
return nil, nil
}
all, err := f.readImportIDs(s)
return all, err
}
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package fs defines basic interfaces to a file system.
// A file system can be provided by the host operating system
// but also by other packages.
package fs
import (
"internal/oserror"
"time"
"unicode/utf8"
)
// An FS provides access to a hierarchical file system.
//
// The FS interface is the minimum implementation required of the file system.
// A file system may implement additional interfaces,
// such as ReadFileFS, to provide additional or optimized functionality.
type FS interface {
// Open opens the named file.
//
// When Open returns an error, it should be of type *PathError
// with the Op field set to "open", the Path field set to name,
// and the Err field describing the problem.
//
// Open should reject attempts to open names that do not satisfy
// ValidPath(name), returning a *PathError with Err set to
// ErrInvalid or ErrNotExist.
Open(name string) (File, error)
}
// ValidPath reports whether the given path name
// is valid for use in a call to Open.
//
// Path names passed to Open are UTF-8-encoded,
// unrooted, slash-separated sequences of path elements, like “x/y/z”.
// Path names must not contain an element that is “.” or “..” or the empty string,
// except for the special case that the root directory is named “.”.
// Paths must not start or end with a slash: “/x” and “x/” are invalid.
//
// Note that paths are slash-separated on all systems, even Windows.
// Paths containing other characters such as backslash and colon
// are accepted as valid, but those characters must never be
// interpreted by an FS implementation as path element separators.
func ValidPath(name string) bool {
if !utf8.ValidString(name) {
return false
}
if name == "." {
// special case
return true
}
// Iterate over elements in name, checking each.
for {
i := 0
for i < len(name) && name[i] != '/' {
i++
}
elem := name[:i]
if elem == "" || elem == "." || elem == ".." {
return false
}
if i == len(name) {
return true // reached clean ending
}
name = name[i+1:]
}
}
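// Illustrative sketch (not part of the original source): how ValidPath
// treats a few representative names, following the rules above.
//
//	ValidPath(".")      // true: the root directory
//	ValidPath("x/y/z")  // true
//	ValidPath("")       // false: empty element
//	ValidPath("/x")     // false: rooted
//	ValidPath("x/")     // false: trailing slash
//	ValidPath("x/../y") // false: ".." element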
// A File provides access to a single file.
// The File interface is the minimum implementation required of the file.
// Directory files should also implement ReadDirFile.
// A file may implement io.ReaderAt or io.Seeker as optimizations.
type File interface {
Stat() (FileInfo, error)
Read([]byte) (int, error)
Close() error
}
// A DirEntry is an entry read from a directory
// (using the ReadDir function or a ReadDirFile's ReadDir method).
type DirEntry interface {
// Name returns the name of the file (or subdirectory) described by the entry.
// This name is only the final element of the path (the base name), not the entire path.
// For example, Name would return "hello.go" not "home/gopher/hello.go".
Name() string
// IsDir reports whether the entry describes a directory.
IsDir() bool
// Type returns the type bits for the entry.
// The type bits are a subset of the usual FileMode bits, those returned by the FileMode.Type method.
Type() FileMode
// Info returns the FileInfo for the file or subdirectory described by the entry.
// The returned FileInfo may be from the time of the original directory read
// or from the time of the call to Info. If the file has been removed or renamed
// since the directory read, Info may return an error satisfying errors.Is(err, ErrNotExist).
// If the entry denotes a symbolic link, Info reports the information about the link itself,
// not the link's target.
Info() (FileInfo, error)
}
// A ReadDirFile is a directory file whose entries can be read with the ReadDir method.
// Every directory file should implement this interface.
// (It is permissible for any file to implement this interface,
// but if so ReadDir should return an error for non-directories.)
type ReadDirFile interface {
File
// ReadDir reads the contents of the directory and returns
// a slice of up to n DirEntry values in directory order.
// Subsequent calls on the same file will yield further DirEntry values.
//
// If n > 0, ReadDir returns at most n DirEntry structures.
// In this case, if ReadDir returns an empty slice, it will return
// a non-nil error explaining why.
// At the end of a directory, the error is io.EOF.
// (ReadDir must return io.EOF itself, not an error wrapping io.EOF.)
//
// If n <= 0, ReadDir returns all the DirEntry values from the directory
// in a single slice. In this case, if ReadDir succeeds (reads all the way
// to the end of the directory), it returns the slice and a nil error.
// If it encounters an error before the end of the directory,
// ReadDir returns the DirEntry list read until that point and a non-nil error.
ReadDir(n int) ([]DirEntry, error)
}
// Generic file system errors.
// Errors returned by file systems can be tested against these errors
// using errors.Is.
var (
ErrInvalid = errInvalid() // "invalid argument"
ErrPermission = errPermission() // "permission denied"
ErrExist = errExist() // "file already exists"
ErrNotExist = errNotExist() // "file does not exist"
ErrClosed = errClosed() // "file already closed"
)
func errInvalid() error { return oserror.ErrInvalid }
func errPermission() error { return oserror.ErrPermission }
func errExist() error { return oserror.ErrExist }
func errNotExist() error { return oserror.ErrNotExist }
func errClosed() error { return oserror.ErrClosed }
// A FileInfo describes a file and is returned by Stat.
type FileInfo interface {
Name() string // base name of the file
Size() int64 // length in bytes for regular files; system-dependent for others
Mode() FileMode // file mode bits
ModTime() time.Time // modification time
IsDir() bool // abbreviation for Mode().IsDir()
Sys() any // underlying data source (can return nil)
}
// A FileMode represents a file's mode and permission bits.
// The bits have the same definition on all systems, so that
// information about files can be moved from one system
// to another portably. Not all bits apply to all systems.
// The only required bit is ModeDir for directories.
type FileMode uint32
// The defined file mode bits are the most significant bits of the FileMode.
// The nine least-significant bits are the standard Unix rwxrwxrwx permissions.
// The values of these bits should be considered part of the public API and
// may be used in wire protocols or disk representations: they must not be
// changed, although new bits might be added.
const (
// The single letters are the abbreviations
// used by the String method's formatting.
ModeDir FileMode = 1 << (32 - 1 - iota) // d: is a directory
ModeAppend // a: append-only
ModeExclusive // l: exclusive use
ModeTemporary // T: temporary file; Plan 9 only
ModeSymlink // L: symbolic link
ModeDevice // D: device file
ModeNamedPipe // p: named pipe (FIFO)
ModeSocket // S: Unix domain socket
ModeSetuid // u: setuid
ModeSetgid // g: setgid
ModeCharDevice // c: Unix character device, when ModeDevice is set
ModeSticky // t: sticky
ModeIrregular // ?: non-regular file; nothing else is known about this file
// Mask for the type bits. For regular files, none will be set.
ModeType = ModeDir | ModeSymlink | ModeNamedPipe | ModeSocket | ModeDevice | ModeCharDevice | ModeIrregular
ModePerm FileMode = 0777 // Unix permission bits
)
func (m FileMode) String() string {
const str = "dalTLDpSugct?"
var buf [32]byte // Mode is uint32.
w := 0
for i, c := range str {
if m&(1<<uint(32-1-i)) != 0 {
buf[w] = byte(c)
w++
}
}
if w == 0 {
buf[w] = '-'
w++
}
const rwx = "rwxrwxrwx"
for i, c := range rwx {
if m&(1<<uint(9-1-i)) != 0 {
buf[w] = byte(c)
} else {
buf[w] = '-'
}
w++
}
return string(buf[:w])
}
// IsDir reports whether m describes a directory.
// That is, it tests for the ModeDir bit being set in m.
func (m FileMode) IsDir() bool {
return m&ModeDir != 0
}
// IsRegular reports whether m describes a regular file.
// That is, it tests that no mode type bits are set.
func (m FileMode) IsRegular() bool {
return m&ModeType == 0
}
// Perm returns the Unix permission bits in m (m & ModePerm).
func (m FileMode) Perm() FileMode {
return m & ModePerm
}
// Type returns type bits in m (m & ModeType).
func (m FileMode) Type() FileMode {
return m & ModeType
}
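// Illustrative sketch (not part of the original source): how the mode
// bits render through String and the accessor methods.
//
//	m := ModeDir | 0755
//	m.String() // "drwxr-xr-x"
//	m.IsDir()  // true
//	m.Perm()   // 0755
//	m.Type()   // ModeDir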
// PathError records an error and the operation and file path that caused it.
type PathError struct {
Op string
Path string
Err error
}
func (e *PathError) Error() string { return e.Op + " " + e.Path + ": " + e.Err.Error() }
func (e *PathError) Unwrap() error { return e.Err }
// Timeout reports whether this error represents a timeout.
func (e *PathError) Timeout() bool {
t, ok := e.Err.(interface{ Timeout() bool })
return ok && t.Timeout()
}
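// Illustrative sketch (not part of the original source): implementations
// wrap the sentinel errors above in a *PathError, so callers match them
// with errors.Is rather than direct comparison. For some fsys FS
// (assuming the errors package is imported):
//
//	if _, err := fsys.Open("no/such/file"); errors.Is(err, ErrNotExist) {
//		// handle the missing file
//	}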
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fs
import (
"path"
)
// A GlobFS is a file system with a Glob method.
type GlobFS interface {
FS
// Glob returns the names of all files matching pattern,
// providing an implementation of the top-level
// Glob function.
Glob(pattern string) ([]string, error)
}
// Glob returns the names of all files matching pattern or nil
// if there is no matching file. The syntax of patterns is the same
// as in path.Match. The pattern may describe hierarchical names such as
// usr/*/bin/ed.
//
// Glob ignores file system errors such as I/O errors reading directories.
// The only possible returned error is path.ErrBadPattern, reporting that
// the pattern is malformed.
//
// If fsys implements GlobFS, Glob calls fsys.Glob.
// Otherwise, Glob uses ReadDir to traverse the directory tree
// and look for matches for the pattern.
func Glob(fsys FS, pattern string) (matches []string, err error) {
return globWithLimit(fsys, pattern, 0)
}
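// Illustrative sketch (not part of the original source): Glob applied to
// an in-memory file system; assumes the testing/fstest package.
//
//	fsys := fstest.MapFS{
//		"a/b.txt": &fstest.MapFile{},
//		"a/c.go":  &fstest.MapFile{},
//	}
//	matches, err := Glob(fsys, "a/*.txt")
//	// matches == []string{"a/b.txt"}, err == nil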
func globWithLimit(fsys FS, pattern string, depth int) (matches []string, err error) {
// This limit is added to prevent stack exhaustion issues. See
// CVE-2022-30630.
const pathSeparatorsLimit = 10000
if depth > pathSeparatorsLimit {
return nil, path.ErrBadPattern
}
if fsys, ok := fsys.(GlobFS); ok {
return fsys.Glob(pattern)
}
// Check pattern is well-formed.
if _, err := path.Match(pattern, ""); err != nil {
return nil, err
}
if !hasMeta(pattern) {
if _, err = Stat(fsys, pattern); err != nil {
return nil, nil
}
return []string{pattern}, nil
}
dir, file := path.Split(pattern)
dir = cleanGlobPath(dir)
if !hasMeta(dir) {
return glob(fsys, dir, file, nil)
}
// Prevent infinite recursion. See issue 15879.
if dir == pattern {
return nil, path.ErrBadPattern
}
var m []string
m, err = globWithLimit(fsys, dir, depth+1)
if err != nil {
return nil, err
}
for _, d := range m {
matches, err = glob(fsys, d, file, matches)
if err != nil {
return
}
}
return
}
// cleanGlobPath prepares path for glob matching.
func cleanGlobPath(path string) string {
switch path {
case "":
return "."
default:
return path[0 : len(path)-1] // chop off trailing separator
}
}
// glob searches for files matching pattern in the directory dir
// and appends them to matches, returning the updated slice.
// If the directory cannot be opened, glob returns the existing matches.
// New matches are added in lexicographical order.
func glob(fs FS, dir, pattern string, matches []string) (m []string, e error) {
m = matches
infos, err := ReadDir(fs, dir)
if err != nil {
return // ignore I/O error
}
for _, info := range infos {
n := info.Name()
matched, err := path.Match(pattern, n)
if err != nil {
return m, err
}
if matched {
m = append(m, path.Join(dir, n))
}
}
return
}
// hasMeta reports whether path contains any of the magic characters
// recognized by path.Match.
func hasMeta(path string) bool {
for i := 0; i < len(path); i++ {
switch path[i] {
case '*', '?', '[', '\\':
return true
}
}
return false
}
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fs
import (
"errors"
"sort"
)
// ReadDirFS is the interface implemented by a file system
// that provides an optimized implementation of ReadDir.
type ReadDirFS interface {
FS
// ReadDir reads the named directory
// and returns a list of directory entries sorted by filename.
ReadDir(name string) ([]DirEntry, error)
}
// ReadDir reads the named directory
// and returns a list of directory entries sorted by filename.
//
// If fsys implements ReadDirFS, ReadDir calls fsys.ReadDir.
// Otherwise ReadDir calls fsys.Open and uses ReadDir and Close
// on the returned file.
func ReadDir(fsys FS, name string) ([]DirEntry, error) {
if fsys, ok := fsys.(ReadDirFS); ok {
return fsys.ReadDir(name)
}
file, err := fsys.Open(name)
if err != nil {
return nil, err
}
defer file.Close()
dir, ok := file.(ReadDirFile)
if !ok {
return nil, &PathError{Op: "readdir", Path: name, Err: errors.New("not implemented")}
}
list, err := dir.ReadDir(-1)
sort.Slice(list, func(i, j int) bool { return list[i].Name() < list[j].Name() })
return list, err
}
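// Illustrative sketch (not part of the original source): ReadDir on an
// in-memory file system returns entries sorted by name; assumes the
// testing/fstest package.
//
//	fsys := fstest.MapFS{
//		"dir/b": &fstest.MapFile{},
//		"dir/a": &fstest.MapFile{},
//	}
//	entries, err := ReadDir(fsys, "dir")
//	// entries[0].Name() == "a", entries[1].Name() == "b", err == nil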
// dirInfo is a DirEntry based on a FileInfo.
type dirInfo struct {
fileInfo FileInfo
}
func (di dirInfo) IsDir() bool {
return di.fileInfo.IsDir()
}
func (di dirInfo) Type() FileMode {
return di.fileInfo.Mode().Type()
}
func (di dirInfo) Info() (FileInfo, error) {
return di.fileInfo, nil
}
func (di dirInfo) Name() string {
return di.fileInfo.Name()
}
// FileInfoToDirEntry returns a DirEntry that returns information from info.
// If info is nil, FileInfoToDirEntry returns nil.
func FileInfoToDirEntry(info FileInfo) DirEntry {
if info == nil {
return nil
}
return dirInfo{fileInfo: info}
}
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fs
import "io"
// ReadFileFS is the interface implemented by a file system
// that provides an optimized implementation of ReadFile.
type ReadFileFS interface {
FS
// ReadFile reads the named file and returns its contents.
// A successful call returns a nil error, not io.EOF.
// (Because ReadFile reads the whole file, the expected EOF
// from the final Read is not treated as an error to be reported.)
//
// The caller is permitted to modify the returned byte slice.
// This method should return a copy of the underlying data.
ReadFile(name string) ([]byte, error)
}
// ReadFile reads the named file from the file system fs and returns its contents.
// A successful call returns a nil error, not io.EOF.
// (Because ReadFile reads the whole file, the expected EOF
// from the final Read is not treated as an error to be reported.)
//
// If fsys implements ReadFileFS, ReadFile calls fsys.ReadFile.
// Otherwise ReadFile calls fsys.Open and uses Read and Close
// on the returned file.
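//
// A hedged usage sketch (the file name is hypothetical):
//
//	data, err := ReadFile(fsys, "config.json")
//	if err != nil { ... }
//	fmt.Printf("read %d bytes\n", len(data))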
func ReadFile(fsys FS, name string) ([]byte, error) {
if fsys, ok := fsys.(ReadFileFS); ok {
return fsys.ReadFile(name)
}
file, err := fsys.Open(name)
if err != nil {
return nil, err
}
defer file.Close()
var size int
if info, err := file.Stat(); err == nil {
size64 := info.Size()
if int64(int(size64)) == size64 {
size = int(size64)
}
}
data := make([]byte, 0, size+1)
for {
if len(data) >= cap(data) {
d := append(data[:cap(data)], 0)
data = d[:len(data)]
}
n, err := file.Read(data[len(data):cap(data)])
data = data[:len(data)+n]
if err != nil {
if err == io.EOF {
err = nil
}
return data, err
}
}
}
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fs
// A StatFS is a file system with a Stat method.
type StatFS interface {
FS
// Stat returns a FileInfo describing the file.
// If there is an error, it should be of type *PathError.
Stat(name string) (FileInfo, error)
}
// Stat returns a FileInfo describing the named file from the file system.
//
// If fsys implements StatFS, Stat calls fsys.Stat.
// Otherwise, Stat opens the file to stat it.
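//
// A hedged usage sketch (the file name is hypothetical):
//
//	info, err := Stat(fsys, "config.json")
//	if err != nil { ... }
//	fmt.Println(info.Name(), info.Size(), info.Mode())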
func Stat(fsys FS, name string) (FileInfo, error) {
if fsys, ok := fsys.(StatFS); ok {
return fsys.Stat(name)
}
file, err := fsys.Open(name)
if err != nil {
return nil, err
}
defer file.Close()
return file.Stat()
}
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fs
import (
"errors"
"path"
)
// A SubFS is a file system with a Sub method.
type SubFS interface {
FS
// Sub returns an FS corresponding to the subtree rooted at dir.
Sub(dir string) (FS, error)
}
// Sub returns an FS corresponding to the subtree rooted at fsys's dir.
//
// If dir is ".", Sub returns fsys unchanged.
// Otherwise, if fsys implements SubFS, Sub returns fsys.Sub(dir).
// Otherwise, Sub returns a new FS implementation sub that,
// in effect, implements sub.Open(name) as fsys.Open(path.Join(dir, name)).
// The implementation also translates calls to ReadDir, ReadFile, and Glob appropriately.
//
// Note that Sub(os.DirFS("/"), "prefix") is equivalent to os.DirFS("/prefix")
// and that neither of them guarantees to avoid operating system
// accesses outside "/prefix", because the implementation of os.DirFS
// does not check for symbolic links inside "/prefix" that point to
// other directories. That is, os.DirFS is not a general substitute for a
// chroot-style security mechanism, and Sub does not change that fact.
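//
// A hedged usage sketch (the directory name "static" is hypothetical):
//
//	sub, err := Sub(fsys, "static")
//	if err != nil { ... }
//	// sub.Open("css/site.css") now opens "static/css/site.css" in fsys.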
func Sub(fsys FS, dir string) (FS, error) {
if !ValidPath(dir) {
return nil, &PathError{Op: "sub", Path: dir, Err: errors.New("invalid name")}
}
if dir == "." {
return fsys, nil
}
if fsys, ok := fsys.(SubFS); ok {
return fsys.Sub(dir)
}
return &subFS{fsys, dir}, nil
}
type subFS struct {
fsys FS
dir string
}
// fullName maps name to the fully-qualified name dir/name.
func (f *subFS) fullName(op string, name string) (string, error) {
if !ValidPath(name) {
return "", &PathError{Op: op, Path: name, Err: errors.New("invalid name")}
}
return path.Join(f.dir, name), nil
}
// shorten maps name, which should start with f.dir, back to the suffix after f.dir.
func (f *subFS) shorten(name string) (rel string, ok bool) {
if name == f.dir {
return ".", true
}
if len(name) >= len(f.dir)+2 && name[len(f.dir)] == '/' && name[:len(f.dir)] == f.dir {
return name[len(f.dir)+1:], true
}
return "", false
}
// fixErr shortens any reported names in PathErrors by stripping f.dir.
func (f *subFS) fixErr(err error) error {
if e, ok := err.(*PathError); ok {
if short, ok := f.shorten(e.Path); ok {
e.Path = short
}
}
return err
}
func (f *subFS) Open(name string) (File, error) {
full, err := f.fullName("open", name)
if err != nil {
return nil, err
}
file, err := f.fsys.Open(full)
return file, f.fixErr(err)
}
func (f *subFS) ReadDir(name string) ([]DirEntry, error) {
full, err := f.fullName("read", name)
if err != nil {
return nil, err
}
dir, err := ReadDir(f.fsys, full)
return dir, f.fixErr(err)
}
func (f *subFS) ReadFile(name string) ([]byte, error) {
full, err := f.fullName("read", name)
if err != nil {
return nil, err
}
data, err := ReadFile(f.fsys, full)
return data, f.fixErr(err)
}
func (f *subFS) Glob(pattern string) ([]string, error) {
// Check pattern is well-formed.
if _, err := path.Match(pattern, ""); err != nil {
return nil, err
}
if pattern == "." {
return []string{"."}, nil
}
full := f.dir + "/" + pattern
list, err := Glob(f.fsys, full)
for i, name := range list {
name, ok := f.shorten(name)
if !ok {
return nil, errors.New("invalid result from inner fsys Glob: " + name + " not in " + f.dir) // can't use fmt in this package
}
list[i] = name
}
return list, f.fixErr(err)
}
func (f *subFS) Sub(dir string) (FS, error) {
if dir == "." {
return f, nil
}
full, err := f.fullName("sub", dir)
if err != nil {
return nil, err
}
return &subFS{f.fsys, full}, nil
}
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fs
import (
"errors"
"path"
)
// SkipDir is used as a return value from WalkDirFuncs to indicate that
// the directory named in the call is to be skipped. It is not returned
// as an error by any function.
var SkipDir = errors.New("skip this directory")
// SkipAll is used as a return value from WalkDirFuncs to indicate that
// all remaining files and directories are to be skipped. It is not returned
// as an error by any function.
var SkipAll = errors.New("skip everything and stop the walk")
// WalkDirFunc is the type of the function called by WalkDir to visit
// each file or directory.
//
// The path argument contains the argument to WalkDir as a prefix.
// That is, if WalkDir is called with root argument "dir" and finds a file
// named "a" in that directory, the walk function will be called with
// argument "dir/a".
//
// The d argument is the fs.DirEntry for the named path.
//
// The error result returned by the function controls how WalkDir
// continues. If the function returns the special value SkipDir, WalkDir
// skips the current directory (path if d.IsDir() is true, otherwise
// path's parent directory). If the function returns the special value
// SkipAll, WalkDir skips all remaining files and directories. Otherwise,
// if the function returns a non-nil error, WalkDir stops entirely and
// returns that error.
//
// The err argument reports an error related to path, signaling that
// WalkDir will not walk into that directory. The function can decide how
// to handle that error; as described earlier, returning the error will
// cause WalkDir to stop walking the entire tree.
//
// WalkDir calls the function with a non-nil err argument in two cases.
//
// First, if the initial fs.Stat on the root directory fails, WalkDir
// calls the function with path set to root, d set to nil, and err set to
// the error from fs.Stat.
//
// Second, if a directory's ReadDir method fails, WalkDir calls the
// function with path set to the directory's path, d set to an
// fs.DirEntry describing the directory, and err set to the error from
// ReadDir. In this second case, the function is called twice with the
// path of the directory: the first call is before the directory read is
// attempted and has err set to nil, giving the function a chance to
// return SkipDir or SkipAll and avoid the ReadDir entirely. The second call
// is after a failed ReadDir and reports the error from ReadDir.
// (If ReadDir succeeds, there is no second call.)
//
// The differences between WalkDirFunc and filepath.WalkFunc are:
//
// - The second argument has type fs.DirEntry instead of fs.FileInfo.
// - The function is called before reading a directory, to allow SkipDir
// or SkipAll to bypass the directory read entirely or skip all remaining
// files and directories respectively.
// - If a directory read fails, the function is called a second time
// for that directory to report the error.
type WalkDirFunc func(path string, d DirEntry, err error) error
// walkDir recursively descends path, calling walkDirFn.
func walkDir(fsys FS, name string, d DirEntry, walkDirFn WalkDirFunc) error {
if err := walkDirFn(name, d, nil); err != nil || !d.IsDir() {
if err == SkipDir && d.IsDir() {
// Successfully skipped directory.
err = nil
}
return err
}
dirs, err := ReadDir(fsys, name)
if err != nil {
// Second call, to report ReadDir error.
err = walkDirFn(name, d, err)
if err != nil {
if err == SkipDir && d.IsDir() {
err = nil
}
return err
}
}
for _, d1 := range dirs {
name1 := path.Join(name, d1.Name())
if err := walkDir(fsys, name1, d1, walkDirFn); err != nil {
if err == SkipDir {
break
}
return err
}
}
return nil
}
// WalkDir walks the file tree rooted at root, calling fn for each file or
// directory in the tree, including root.
//
// All errors that arise visiting files and directories are filtered by fn:
// see the fs.WalkDirFunc documentation for details.
//
// The files are walked in lexical order, which makes the output deterministic
// but requires WalkDir to read an entire directory into memory before proceeding
// to walk that directory.
//
// WalkDir does not follow symbolic links found in directories,
// but if root itself is a symbolic link, its target will be walked.
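//
// A hedged usage sketch: print every path, skipping any directory named
// "testdata" (a hypothetical filter):
//
//	err := WalkDir(fsys, ".", func(p string, d DirEntry, err error) error {
//		if err != nil {
//			return err
//		}
//		if d.IsDir() && d.Name() == "testdata" {
//			return SkipDir
//		}
//		fmt.Println(p)
//		return nil
//	})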
func WalkDir(fsys FS, root string, fn WalkDirFunc) error {
info, err := Stat(fsys, root)
if err != nil {
err = fn(root, nil, err)
} else {
err = walkDir(fsys, root, &statDirEntry{info}, fn)
}
if err == SkipDir || err == SkipAll {
return nil
}
return err
}
type statDirEntry struct {
info FileInfo
}
func (d *statDirEntry) Name() string { return d.info.Name() }
func (d *statDirEntry) IsDir() bool { return d.info.IsDir() }
func (d *statDirEntry) Type() FileMode { return d.info.Mode().Type() }
func (d *statDirEntry) Info() (FileInfo, error) { return d.info, nil }
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package io provides basic interfaces to I/O primitives.
// Its primary job is to wrap existing implementations of such primitives,
// such as those in package os, into shared public interfaces that
// abstract the functionality, plus some other related primitives.
//
// Because these interfaces and primitives wrap lower-level operations with
// various implementations, unless otherwise informed clients should not
// assume they are safe for parallel execution.
package io
import (
"errors"
"sync"
)
// Seek whence values.
const (
SeekStart = 0 // seek relative to the origin of the file
SeekCurrent = 1 // seek relative to the current offset
SeekEnd = 2 // seek relative to the end
)
// ErrShortWrite means that a write accepted fewer bytes than requested
// but failed to return an explicit error.
var ErrShortWrite = errors.New("short write")
// errInvalidWrite means that a write returned an impossible count.
var errInvalidWrite = errors.New("invalid write result")
// ErrShortBuffer means that a read required a longer buffer than was provided.
var ErrShortBuffer = errors.New("short buffer")
// EOF is the error returned by Read when no more input is available.
// (Read must return EOF itself, not an error wrapping EOF,
// because callers will test for EOF using ==.)
// Functions should return EOF only to signal a graceful end of input.
// If the EOF occurs unexpectedly in a structured data stream,
// the appropriate error is either ErrUnexpectedEOF or some other error
// giving more detail.
var EOF = errors.New("EOF")
// ErrUnexpectedEOF means that EOF was encountered in the
// middle of reading a fixed-size block or data structure.
var ErrUnexpectedEOF = errors.New("unexpected EOF")
// ErrNoProgress is returned by some clients of a Reader when
// many calls to Read have failed to return any data or error,
// usually the sign of a broken Reader implementation.
var ErrNoProgress = errors.New("multiple Read calls return no data or error")
// Reader is the interface that wraps the basic Read method.
//
// Read reads up to len(p) bytes into p. It returns the number of bytes
// read (0 <= n <= len(p)) and any error encountered. Even if Read
// returns n < len(p), it may use all of p as scratch space during the call.
// If some data is available but not len(p) bytes, Read conventionally
// returns what is available instead of waiting for more.
//
// When Read encounters an error or end-of-file condition after
// successfully reading n > 0 bytes, it returns the number of
// bytes read. It may return the (non-nil) error from the same call
// or return the error (and n == 0) from a subsequent call.
// An instance of this general case is that a Reader returning
// a non-zero number of bytes at the end of the input stream may
// return either err == EOF or err == nil. The next Read should
// return 0, EOF.
//
// Callers should always process the n > 0 bytes returned before
// considering the error err. Doing so correctly handles I/O errors
// that happen after reading some bytes and also both of the
// allowed EOF behaviors.
//
// Implementations of Read are discouraged from returning a
// zero byte count with a nil error, except when len(p) == 0.
// Callers should treat a return of 0 and nil as indicating that
// nothing happened; in particular it does not indicate EOF.
//
// Implementations must not retain p.
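//
// A hedged sketch of the canonical read loop, processing the n > 0 bytes
// before inspecting err as required above (process is a hypothetical
// consumer):
//
//	buf := make([]byte, 4096)
//	for {
//		n, err := r.Read(buf)
//		if n > 0 {
//			process(buf[:n])
//		}
//		if err == EOF {
//			break
//		}
//		if err != nil {
//			return err
//		}
//	}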
type Reader interface {
Read(p []byte) (n int, err error)
}
// Writer is the interface that wraps the basic Write method.
//
// Write writes len(p) bytes from p to the underlying data stream.
// It returns the number of bytes written from p (0 <= n <= len(p))
// and any error encountered that caused the write to stop early.
// Write must return a non-nil error if it returns n < len(p).
// Write must not modify the slice data, even temporarily.
//
// Implementations must not retain p.
type Writer interface {
Write(p []byte) (n int, err error)
}
// Closer is the interface that wraps the basic Close method.
//
// The behavior of Close after the first call is undefined.
// Specific implementations may document their own behavior.
type Closer interface {
Close() error
}
// Seeker is the interface that wraps the basic Seek method.
//
// Seek sets the offset for the next Read or Write to offset,
// interpreted according to whence:
// SeekStart means relative to the start of the file,
// SeekCurrent means relative to the current offset, and
// SeekEnd means relative to the end
// (for example, offset = -2 specifies the penultimate byte of the file).
// Seek returns the new offset relative to the start of the
// file or an error, if any.
//
// Seeking to an offset before the start of the file is an error.
// Seeking to any positive offset may be allowed, but if the new offset exceeds
// the size of the underlying object the behavior of subsequent I/O operations
// is implementation-dependent.
type Seeker interface {
Seek(offset int64, whence int) (int64, error)
}
// ReadWriter is the interface that groups the basic Read and Write methods.
type ReadWriter interface {
Reader
Writer
}
// ReadCloser is the interface that groups the basic Read and Close methods.
type ReadCloser interface {
Reader
Closer
}
// WriteCloser is the interface that groups the basic Write and Close methods.
type WriteCloser interface {
Writer
Closer
}
// ReadWriteCloser is the interface that groups the basic Read, Write and Close methods.
type ReadWriteCloser interface {
Reader
Writer
Closer
}
// ReadSeeker is the interface that groups the basic Read and Seek methods.
type ReadSeeker interface {
Reader
Seeker
}
// ReadSeekCloser is the interface that groups the basic Read, Seek and Close
// methods.
type ReadSeekCloser interface {
Reader
Seeker
Closer
}
// WriteSeeker is the interface that groups the basic Write and Seek methods.
type WriteSeeker interface {
Writer
Seeker
}
// ReadWriteSeeker is the interface that groups the basic Read, Write and Seek methods.
type ReadWriteSeeker interface {
Reader
Writer
Seeker
}
// ReaderFrom is the interface that wraps the ReadFrom method.
//
// ReadFrom reads data from r until EOF or error.
// The return value n is the number of bytes read.
// Any error except EOF encountered during the read is also returned.
//
// The Copy function uses ReaderFrom if available.
type ReaderFrom interface {
ReadFrom(r Reader) (n int64, err error)
}
// WriterTo is the interface that wraps the WriteTo method.
//
// WriteTo writes data to w until there's no more data to write or
// when an error occurs. The return value n is the number of bytes
// written. Any error encountered during the write is also returned.
//
// The Copy function uses WriterTo if available.
type WriterTo interface {
WriteTo(w Writer) (n int64, err error)
}
// ReaderAt is the interface that wraps the basic ReadAt method.
//
// ReadAt reads len(p) bytes into p starting at offset off in the
// underlying input source. It returns the number of bytes
// read (0 <= n <= len(p)) and any error encountered.
//
// When ReadAt returns n < len(p), it returns a non-nil error
// explaining why more bytes were not returned. In this respect,
// ReadAt is stricter than Read.
//
// Even if ReadAt returns n < len(p), it may use all of p as scratch
// space during the call. If some data is available but not len(p) bytes,
// ReadAt blocks until either all the data is available or an error occurs.
// In this respect ReadAt is different from Read.
//
// If the n = len(p) bytes returned by ReadAt are at the end of the
// input source, ReadAt may return either err == EOF or err == nil.
//
// If ReadAt is reading from an input source with a seek offset,
// ReadAt should not affect nor be affected by the underlying
// seek offset.
//
// Clients of ReadAt can execute parallel ReadAt calls on the
// same input source.
//
// Implementations must not retain p.
type ReaderAt interface {
ReadAt(p []byte, off int64) (n int, err error)
}
// WriterAt is the interface that wraps the basic WriteAt method.
//
// WriteAt writes len(p) bytes from p to the underlying data stream
// at offset off. It returns the number of bytes written from p (0 <= n <= len(p))
// and any error encountered that caused the write to stop early.
// WriteAt must return a non-nil error if it returns n < len(p).
//
// If WriteAt is writing to a destination with a seek offset,
// WriteAt should not affect nor be affected by the underlying
// seek offset.
//
// Clients of WriteAt can execute parallel WriteAt calls on the same
// destination if the ranges do not overlap.
//
// Implementations must not retain p.
type WriterAt interface {
WriteAt(p []byte, off int64) (n int, err error)
}
// ByteReader is the interface that wraps the ReadByte method.
//
// ReadByte reads and returns the next byte from the input or
// any error encountered. If ReadByte returns an error, no input
// byte was consumed, and the returned byte value is undefined.
//
// ReadByte provides an efficient interface for byte-at-a-time
// processing. A Reader that does not implement ByteReader
// can be wrapped using bufio.NewReader to add this method.
type ByteReader interface {
ReadByte() (byte, error)
}
// ByteScanner is the interface that adds the UnreadByte method to the
// basic ReadByte method.
//
// UnreadByte causes the next call to ReadByte to return the last byte read.
// If the last operation was not a successful call to ReadByte, UnreadByte may
// return an error, unread the last byte read (or the byte prior to the
// last-unread byte), or (in implementations that support the Seeker interface)
// seek to one byte before the current offset.
type ByteScanner interface {
ByteReader
UnreadByte() error
}
// ByteWriter is the interface that wraps the WriteByte method.
type ByteWriter interface {
WriteByte(c byte) error
}
// RuneReader is the interface that wraps the ReadRune method.
//
// ReadRune reads a single encoded Unicode character
// and returns the rune and its size in bytes. If no character is
// available, err will be set.
type RuneReader interface {
ReadRune() (r rune, size int, err error)
}
// RuneScanner is the interface that adds the UnreadRune method to the
// basic ReadRune method.
//
// UnreadRune causes the next call to ReadRune to return the last rune read.
// If the last operation was not a successful call to ReadRune, UnreadRune may
// return an error, unread the last rune read (or the rune prior to the
// last-unread rune), or (in implementations that support the Seeker interface)
// seek to the start of the rune before the current offset.
type RuneScanner interface {
RuneReader
UnreadRune() error
}
// StringWriter is the interface that wraps the WriteString method.
type StringWriter interface {
WriteString(s string) (n int, err error)
}
// WriteString writes the contents of the string s to w, which accepts a slice of bytes.
// If w implements StringWriter, its WriteString method is invoked directly.
// Otherwise, w.Write is called exactly once.
func WriteString(w Writer, s string) (n int, err error) {
if sw, ok := w.(StringWriter); ok {
return sw.WriteString(s)
}
return w.Write([]byte(s))
}
// ReadAtLeast reads from r into buf until it has read at least min bytes.
// It returns the number of bytes copied and an error if fewer bytes were read.
// The error is EOF only if no bytes were read.
// If an EOF happens after reading fewer than min bytes,
// ReadAtLeast returns ErrUnexpectedEOF.
// If min is greater than the length of buf, ReadAtLeast returns ErrShortBuffer.
// On return, n >= min if and only if err == nil.
// If r returns an error having read at least min bytes, the error is dropped.
func ReadAtLeast(r Reader, buf []byte, min int) (n int, err error) {
if len(buf) < min {
return 0, ErrShortBuffer
}
for n < min && err == nil {
var nn int
nn, err = r.Read(buf[n:])
n += nn
}
if n >= min {
err = nil
} else if n > 0 && err == EOF {
err = ErrUnexpectedEOF
}
return
}
// ReadFull reads exactly len(buf) bytes from r into buf.
// It returns the number of bytes copied and an error if fewer bytes were read.
// The error is EOF only if no bytes were read.
// If an EOF happens after reading some but not all the bytes,
// ReadFull returns ErrUnexpectedEOF.
// On return, n == len(buf) if and only if err == nil.
// If r returns an error having read at least len(buf) bytes, the error is dropped.
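//
// A hedged usage sketch: reading a fixed-size 8-byte header from r:
//
//	header := make([]byte, 8)
//	if _, err := ReadFull(r, header); err != nil {
//		// err is ErrUnexpectedEOF if r ended after 1 to 7 bytes,
//		// or EOF if it ended before any byte was read.
//		return err
//	}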
func ReadFull(r Reader, buf []byte) (n int, err error) {
return ReadAtLeast(r, buf, len(buf))
}
// CopyN copies n bytes (or until an error) from src to dst.
// It returns the number of bytes copied and the earliest
// error encountered while copying.
// On return, written == n if and only if err == nil.
//
// If dst implements the ReaderFrom interface,
// the copy is implemented using it.
func CopyN(dst Writer, src Reader, n int64) (written int64, err error) {
written, err = Copy(dst, LimitReader(src, n))
if written == n {
return n, nil
}
if written < n && err == nil {
// src stopped early; must have been EOF.
err = EOF
}
return
}
// Copy copies from src to dst until either EOF is reached
// on src or an error occurs. It returns the number of bytes
// copied and the first error encountered while copying, if any.
//
// A successful Copy returns err == nil, not err == EOF.
// Because Copy is defined to read from src until EOF, it does
// not treat an EOF from Read as an error to be reported.
//
// If src implements the WriterTo interface,
// the copy is implemented by calling src.WriteTo(dst).
// Otherwise, if dst implements the ReaderFrom interface,
// the copy is implemented by calling dst.ReadFrom(src).
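//
// A hedged usage sketch (the file name is hypothetical):
//
//	f, err := os.Open("data.txt")
//	if err != nil { ... }
//	defer f.Close()
//	n, err := Copy(os.Stdout, f) // n is the number of bytes streamed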
func Copy(dst Writer, src Reader) (written int64, err error) {
return copyBuffer(dst, src, nil)
}
// CopyBuffer is identical to Copy except that it stages through the
// provided buffer (if one is required) rather than allocating a
// temporary one. If buf is nil, one is allocated; otherwise if it has
// zero length, CopyBuffer panics.
//
// If either src implements WriterTo or dst implements ReaderFrom,
// buf will not be used to perform the copy.
func CopyBuffer(dst Writer, src Reader, buf []byte) (written int64, err error) {
if buf != nil && len(buf) == 0 {
panic("empty buffer in CopyBuffer")
}
return copyBuffer(dst, src, buf)
}
// copyBuffer is the actual implementation of Copy and CopyBuffer.
// If buf is nil, one is allocated.
func copyBuffer(dst Writer, src Reader, buf []byte) (written int64, err error) {
// If the reader has a WriteTo method, use it to do the copy.
// Avoids an allocation and a copy.
if wt, ok := src.(WriterTo); ok {
return wt.WriteTo(dst)
}
// Similarly, if the writer has a ReadFrom method, use it to do the copy.
if rt, ok := dst.(ReaderFrom); ok {
return rt.ReadFrom(src)
}
if buf == nil {
size := 32 * 1024
if l, ok := src.(*LimitedReader); ok && int64(size) > l.N {
if l.N < 1 {
size = 1
} else {
size = int(l.N)
}
}
buf = make([]byte, size)
}
for {
nr, er := src.Read(buf)
if nr > 0 {
nw, ew := dst.Write(buf[0:nr])
if nw < 0 || nr < nw {
nw = 0
if ew == nil {
ew = errInvalidWrite
}
}
written += int64(nw)
if ew != nil {
err = ew
break
}
if nr != nw {
err = ErrShortWrite
break
}
}
if er != nil {
if er != EOF {
err = er
}
break
}
}
return written, err
}
// LimitReader returns a Reader that reads from r
// but stops with EOF after n bytes.
// The underlying implementation is a *LimitedReader.
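//
// A hedged usage sketch: capping an untrusted stream at 1 MiB
// (r is a hypothetical network reader):
//
//	body, err := ReadAll(LimitReader(r, 1<<20))
//	// body holds at most 1<<20 bytes; longer input is silently truncated.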
func LimitReader(r Reader, n int64) Reader { return &LimitedReader{r, n} }
// A LimitedReader reads from R but limits the amount of
// data returned to just N bytes. Each call to Read
// updates N to reflect the new amount remaining.
// Read returns EOF when N <= 0 or when the underlying R returns EOF.
type LimitedReader struct {
R Reader // underlying reader
N int64 // max bytes remaining
}
func (l *LimitedReader) Read(p []byte) (n int, err error) {
if l.N <= 0 {
return 0, EOF
}
if int64(len(p)) > l.N {
p = p[0:l.N]
}
n, err = l.R.Read(p)
l.N -= int64(n)
return
}
// NewSectionReader returns a SectionReader that reads from r
// starting at offset off and stops with EOF after n bytes.
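//
// A hedged usage sketch (strings.NewReader implements ReaderAt):
//
//	s := NewSectionReader(strings.NewReader("hello world"), 6, 5)
//	b, _ := ReadAll(s)
//	// string(b) == "world" and s.Size() == 5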
func NewSectionReader(r ReaderAt, off int64, n int64) *SectionReader {
var remaining int64
const maxint64 = 1<<63 - 1
if off <= maxint64-n {
remaining = n + off
} else {
// Overflow, with no way to return error.
// Assume we can read up to an offset of 1<<63 - 1.
remaining = maxint64
}
return &SectionReader{r, off, off, remaining}
}
// SectionReader implements Read, Seek, and ReadAt on a section
// of an underlying ReaderAt.
type SectionReader struct {
r ReaderAt
base int64
off int64
limit int64
}
func (s *SectionReader) Read(p []byte) (n int, err error) {
if s.off >= s.limit {
return 0, EOF
}
if max := s.limit - s.off; int64(len(p)) > max {
p = p[0:max]
}
n, err = s.r.ReadAt(p, s.off)
s.off += int64(n)
return
}
var errWhence = errors.New("Seek: invalid whence")
var errOffset = errors.New("Seek: invalid offset")
func (s *SectionReader) Seek(offset int64, whence int) (int64, error) {
switch whence {
default:
return 0, errWhence
case SeekStart:
offset += s.base
case SeekCurrent:
offset += s.off
case SeekEnd:
offset += s.limit
}
if offset < s.base {
return 0, errOffset
}
s.off = offset
return offset - s.base, nil
}
func (s *SectionReader) ReadAt(p []byte, off int64) (n int, err error) {
if off < 0 || off >= s.limit-s.base {
return 0, EOF
}
off += s.base
if max := s.limit - off; int64(len(p)) > max {
p = p[0:max]
n, err = s.r.ReadAt(p, off)
if err == nil {
err = EOF
}
return n, err
}
return s.r.ReadAt(p, off)
}
// Size returns the size of the section in bytes.
func (s *SectionReader) Size() int64 { return s.limit - s.base }
// An OffsetWriter maps writes at offset base to offset base+off in the underlying writer.
type OffsetWriter struct {
w WriterAt
base int64 // the original offset
off int64 // the current offset
}
// NewOffsetWriter returns an OffsetWriter that writes to w
// starting at offset off.
func NewOffsetWriter(w WriterAt, off int64) *OffsetWriter {
return &OffsetWriter{w, off, off}
}
func (o *OffsetWriter) Write(p []byte) (n int, err error) {
n, err = o.w.WriteAt(p, o.off)
o.off += int64(n)
return
}
func (o *OffsetWriter) WriteAt(p []byte, off int64) (n int, err error) {
off += o.base
return o.w.WriteAt(p, off)
}
func (o *OffsetWriter) Seek(offset int64, whence int) (int64, error) {
switch whence {
default:
return 0, errWhence
case SeekStart:
offset += o.base
case SeekCurrent:
offset += o.off
}
if offset < o.base {
return 0, errOffset
}
o.off = offset
return offset - o.base, nil
}
// TeeReader returns a Reader that writes to w what it reads from r.
// All reads from r performed through it are matched with
// corresponding writes to w. There is no internal buffering;
// the write must complete before the read completes.
// Any error encountered while writing is reported as a read error.
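//
// A hedged usage sketch: hashing a stream while consuming it
// (sha256.New returns a hash.Hash, which is a Writer):
//
//	h := sha256.New()
//	data, err := ReadAll(TeeReader(r, h))
//	// data holds everything read from r; h.Sum(nil) is its digest.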
func TeeReader(r Reader, w Writer) Reader {
return &teeReader{r, w}
}
type teeReader struct {
r Reader
w Writer
}
func (t *teeReader) Read(p []byte) (n int, err error) {
n, err = t.r.Read(p)
if n > 0 {
if n, err := t.w.Write(p[:n]); err != nil {
return n, err
}
}
return
}
// Discard is a Writer on which all Write calls succeed
// without doing anything.
var Discard Writer = discard{}
type discard struct{}
// discard implements ReaderFrom as an optimization so Copy to
// io.Discard can avoid doing unnecessary work.
var _ ReaderFrom = discard{}
func (discard) Write(p []byte) (int, error) {
return len(p), nil
}
func (discard) WriteString(s string) (int, error) {
return len(s), nil
}
var blackHolePool = sync.Pool{
New: func() any {
b := make([]byte, 8192)
return &b
},
}
func (discard) ReadFrom(r Reader) (n int64, err error) {
bufp := blackHolePool.Get().(*[]byte)
readSize := 0
for {
readSize, err = r.Read(*bufp)
n += int64(readSize)
if err != nil {
blackHolePool.Put(bufp)
if err == EOF {
return n, nil
}
return
}
}
}
// NopCloser returns a ReadCloser with a no-op Close method wrapping
// the provided Reader r.
// If r implements WriterTo, the returned ReadCloser will implement WriterTo
// by forwarding calls to r.
func NopCloser(r Reader) ReadCloser {
if _, ok := r.(WriterTo); ok {
return nopCloserWriterTo{r}
}
return nopCloser{r}
}
type nopCloser struct {
Reader
}
func (nopCloser) Close() error { return nil }
type nopCloserWriterTo struct {
Reader
}
func (nopCloserWriterTo) Close() error { return nil }
func (c nopCloserWriterTo) WriteTo(w Writer) (n int64, err error) {
return c.Reader.(WriterTo).WriteTo(w)
}
// ReadAll reads from r until an error or EOF and returns the data it read.
// A successful call returns err == nil, not err == EOF. Because ReadAll is
// defined to read from src until EOF, it does not treat an EOF from Read
// as an error to be reported.
func ReadAll(r Reader) ([]byte, error) {
b := make([]byte, 0, 512)
for {
if len(b) == cap(b) {
// Add more capacity (let append pick how much).
b = append(b, 0)[:len(b)]
}
n, err := r.Read(b[len(b):cap(b)])
b = b[:len(b)+n]
if err != nil {
if err == EOF {
err = nil
}
return b, err
}
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package ioutil implements some I/O utility functions.
//
// Deprecated: As of Go 1.16, the same functionality is now provided
// by package [io] or package [os], and those implementations
// should be preferred in new code.
// See the specific function documentation for details.
package ioutil
import (
"io"
"io/fs"
"os"
"sort"
)
// ReadAll reads from r until an error or EOF and returns the data it read.
// A successful call returns err == nil, not err == EOF. Because ReadAll is
// defined to read from src until EOF, it does not treat an EOF from Read
// as an error to be reported.
//
// Deprecated: As of Go 1.16, this function simply calls [io.ReadAll].
func ReadAll(r io.Reader) ([]byte, error) {
return io.ReadAll(r)
}
// ReadFile reads the file named by filename and returns the contents.
// A successful call returns err == nil, not err == EOF. Because ReadFile
// reads the whole file, it does not treat an EOF from Read as an error
// to be reported.
//
// Deprecated: As of Go 1.16, this function simply calls [os.ReadFile].
func ReadFile(filename string) ([]byte, error) {
return os.ReadFile(filename)
}
// WriteFile writes data to a file named by filename.
// If the file does not exist, WriteFile creates it with permissions perm
// (before umask); otherwise WriteFile truncates it before writing, without changing permissions.
//
// Deprecated: As of Go 1.16, this function simply calls [os.WriteFile].
func WriteFile(filename string, data []byte, perm fs.FileMode) error {
return os.WriteFile(filename, data, perm)
}
// ReadDir reads the directory named by dirname and returns
// a list of fs.FileInfo for the directory's contents,
// sorted by filename. If an error occurs reading the directory,
// ReadDir returns no directory entries along with the error.
//
// Deprecated: As of Go 1.16, [os.ReadDir] is a more efficient and correct choice:
// it returns a list of [fs.DirEntry] instead of [fs.FileInfo],
// and it returns partial results in the case of an error
// midway through reading a directory.
//
// If you must continue obtaining a list of [fs.FileInfo], you still can:
//
// entries, err := os.ReadDir(dirname)
// if err != nil { ... }
// infos := make([]fs.FileInfo, 0, len(entries))
// for _, entry := range entries {
// info, err := entry.Info()
// if err != nil { ... }
// infos = append(infos, info)
// }
func ReadDir(dirname string) ([]fs.FileInfo, error) {
f, err := os.Open(dirname)
if err != nil {
return nil, err
}
list, err := f.Readdir(-1)
f.Close()
if err != nil {
return nil, err
}
sort.Slice(list, func(i, j int) bool { return list[i].Name() < list[j].Name() })
return list, nil
}
// NopCloser returns a ReadCloser with a no-op Close method wrapping
// the provided Reader r.
//
// Deprecated: As of Go 1.16, this function simply calls [io.NopCloser].
func NopCloser(r io.Reader) io.ReadCloser {
return io.NopCloser(r)
}
// Discard is an io.Writer on which all Write calls succeed
// without doing anything.
//
// Deprecated: As of Go 1.16, this value is simply [io.Discard].
var Discard io.Writer = io.Discard
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ioutil
import (
"os"
)
// TempFile creates a new temporary file in the directory dir,
// opens the file for reading and writing, and returns the resulting *os.File.
// The filename is generated by taking pattern and adding a random
// string to the end. If pattern includes a "*", the random string
// replaces the last "*".
// If dir is the empty string, TempFile uses the default directory
// for temporary files (see os.TempDir).
// Multiple programs calling TempFile simultaneously
// will not choose the same file. The caller can use f.Name()
// to find the pathname of the file. It is the caller's responsibility
// to remove the file when no longer needed.
//
// Deprecated: As of Go 1.17, this function simply calls [os.CreateTemp].
func TempFile(dir, pattern string) (f *os.File, err error) {
return os.CreateTemp(dir, pattern)
}
// TempDir creates a new temporary directory in the directory dir.
// The directory name is generated by taking pattern and adding a
// random string to the end. If pattern includes a "*", the random string
// replaces the last "*". TempDir returns the name of the new directory.
// If dir is the empty string, TempDir uses the
// default directory for temporary files (see os.TempDir).
// Multiple programs calling TempDir simultaneously
// will not choose the same directory. It is the caller's responsibility
// to remove the directory when no longer needed.
//
// Deprecated: As of Go 1.17, this function simply calls [os.MkdirTemp].
func TempDir(dir, pattern string) (name string, err error) {
return os.MkdirTemp(dir, pattern)
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package io
type eofReader struct{}
func (eofReader) Read([]byte) (int, error) {
return 0, EOF
}
type multiReader struct {
readers []Reader
}
func (mr *multiReader) Read(p []byte) (n int, err error) {
for len(mr.readers) > 0 {
// Optimization to flatten nested multiReaders (Issue 13558).
if len(mr.readers) == 1 {
if r, ok := mr.readers[0].(*multiReader); ok {
mr.readers = r.readers
continue
}
}
n, err = mr.readers[0].Read(p)
if err == EOF {
// Use eofReader instead of nil to avoid nil panic
// after performing flatten (Issue 18232).
mr.readers[0] = eofReader{} // permit earlier GC
mr.readers = mr.readers[1:]
}
if n > 0 || err != EOF {
if err == EOF && len(mr.readers) > 0 {
// Don't return EOF yet. More readers remain.
err = nil
}
return
}
}
return 0, EOF
}
func (mr *multiReader) WriteTo(w Writer) (sum int64, err error) {
return mr.writeToWithBuffer(w, make([]byte, 1024*32))
}
func (mr *multiReader) writeToWithBuffer(w Writer, buf []byte) (sum int64, err error) {
for i, r := range mr.readers {
var n int64
if subMr, ok := r.(*multiReader); ok { // reuse buffer with nested multiReaders
n, err = subMr.writeToWithBuffer(w, buf)
} else {
n, err = copyBuffer(w, r, buf)
}
sum += n
if err != nil {
mr.readers = mr.readers[i:] // permit resume / retry after error
return sum, err
}
mr.readers[i] = nil // permit early GC
}
mr.readers = nil
return sum, nil
}
var _ WriterTo = (*multiReader)(nil)
// MultiReader returns a Reader that's the logical concatenation of
// the provided input readers. They're read sequentially. Once all
// inputs have returned EOF, Read will return EOF. If any of the readers
// return a non-nil, non-EOF error, Read will return that error.
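//
// A hedged usage sketch:
//
//	r := MultiReader(strings.NewReader("Go"), strings.NewReader("pher"))
//	b, _ := ReadAll(r)
//	// string(b) == "Gopher"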
func MultiReader(readers ...Reader) Reader {
r := make([]Reader, len(readers))
copy(r, readers)
return &multiReader{r}
}
type multiWriter struct {
writers []Writer
}
func (t *multiWriter) Write(p []byte) (n int, err error) {
for _, w := range t.writers {
n, err = w.Write(p)
if err != nil {
return
}
if n != len(p) {
err = ErrShortWrite
return
}
}
return len(p), nil
}
var _ StringWriter = (*multiWriter)(nil)
func (t *multiWriter) WriteString(s string) (n int, err error) {
var p []byte // lazily initialized if/when needed
for _, w := range t.writers {
if sw, ok := w.(StringWriter); ok {
n, err = sw.WriteString(s)
} else {
if p == nil {
p = []byte(s)
}
n, err = w.Write(p)
}
if err != nil {
return
}
if n != len(s) {
err = ErrShortWrite
return
}
}
return len(s), nil
}
// MultiWriter creates a writer that duplicates its writes to all the
// provided writers, similar to the Unix tee(1) command.
//
// Each write is written to each listed writer, one at a time.
// If a listed writer returns an error, that overall write operation
// stops and returns the error; it does not continue down the list.
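//
// A hedged usage sketch: duplicating output to a buffer and stderr:
//
//	var buf bytes.Buffer
//	w := MultiWriter(&buf, os.Stderr)
//	fmt.Fprintln(w, "hello") // "hello\n" reaches both destinations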
func MultiWriter(writers ...Writer) Writer {
allWriters := make([]Writer, 0, len(writers))
for _, w := range writers {
if mw, ok := w.(*multiWriter); ok {
allWriters = append(allWriters, mw.writers...)
} else {
allWriters = append(allWriters, w)
}
}
return &multiWriter{allWriters}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Pipe adapter to connect code expecting an io.Reader
// with code expecting an io.Writer.
package io
import (
"errors"
"sync"
)
// onceError is an object that will only store an error once.
type onceError struct {
sync.Mutex // guards following
err error
}
func (a *onceError) Store(err error) {
a.Lock()
defer a.Unlock()
if a.err != nil {
return
}
a.err = err
}
func (a *onceError) Load() error {
a.Lock()
defer a.Unlock()
return a.err
}
// ErrClosedPipe is the error used for read or write operations on a closed pipe.
var ErrClosedPipe = errors.New("io: read/write on closed pipe")
// A pipe is the shared pipe structure underlying PipeReader and PipeWriter.
type pipe struct {
wrMu sync.Mutex // Serializes Write operations
wrCh chan []byte
rdCh chan int
once sync.Once // Protects closing done
done chan struct{}
rerr onceError
werr onceError
}
func (p *pipe) read(b []byte) (n int, err error) {
select {
case <-p.done:
return 0, p.readCloseError()
default:
}
select {
case bw := <-p.wrCh:
nr := copy(b, bw)
p.rdCh <- nr
return nr, nil
case <-p.done:
return 0, p.readCloseError()
}
}
func (p *pipe) closeRead(err error) error {
if err == nil {
err = ErrClosedPipe
}
p.rerr.Store(err)
p.once.Do(func() { close(p.done) })
return nil
}
func (p *pipe) write(b []byte) (n int, err error) {
select {
case <-p.done:
return 0, p.writeCloseError()
default:
p.wrMu.Lock()
defer p.wrMu.Unlock()
}
for once := true; once || len(b) > 0; once = false {
select {
case p.wrCh <- b:
nw := <-p.rdCh
b = b[nw:]
n += nw
case <-p.done:
return n, p.writeCloseError()
}
}
return n, nil
}
func (p *pipe) closeWrite(err error) error {
if err == nil {
err = EOF
}
p.werr.Store(err)
p.once.Do(func() { close(p.done) })
return nil
}
// readCloseError is considered internal to the pipe type.
func (p *pipe) readCloseError() error {
rerr := p.rerr.Load()
if werr := p.werr.Load(); rerr == nil && werr != nil {
return werr
}
return ErrClosedPipe
}
// writeCloseError is considered internal to the pipe type.
func (p *pipe) writeCloseError() error {
werr := p.werr.Load()
if rerr := p.rerr.Load(); werr == nil && rerr != nil {
return rerr
}
return ErrClosedPipe
}
// A PipeReader is the read half of a pipe.
type PipeReader struct {
p *pipe
}
// Read implements the standard Read interface:
// it reads data from the pipe, blocking until a writer
// arrives or the write end is closed.
// If the write end is closed with an error, that error is
// returned as err; otherwise err is EOF.
func (r *PipeReader) Read(data []byte) (n int, err error) {
return r.p.read(data)
}
// Close closes the reader; subsequent writes to the
// write half of the pipe will return the error ErrClosedPipe.
func (r *PipeReader) Close() error {
return r.CloseWithError(nil)
}
// CloseWithError closes the reader; subsequent writes
// to the write half of the pipe will return the error err.
//
// CloseWithError never overwrites the previous error if it exists
// and always returns nil.
func (r *PipeReader) CloseWithError(err error) error {
return r.p.closeRead(err)
}
// A PipeWriter is the write half of a pipe.
type PipeWriter struct {
p *pipe
}
// Write implements the standard Write interface:
// it writes data to the pipe, blocking until one or more readers
// have consumed all the data or the read end is closed.
// If the read end is closed with an error, that error is
// returned as err; otherwise err is ErrClosedPipe.
func (w *PipeWriter) Write(data []byte) (n int, err error) {
return w.p.write(data)
}
// Close closes the writer; subsequent reads from the
// read half of the pipe will return no bytes and EOF.
func (w *PipeWriter) Close() error {
return w.CloseWithError(nil)
}
// CloseWithError closes the writer; subsequent reads from the
// read half of the pipe will return no bytes and the error err,
// or EOF if err is nil.
//
// CloseWithError never overwrites the previous error if it exists
// and always returns nil.
func (w *PipeWriter) CloseWithError(err error) error {
return w.p.closeWrite(err)
}
// Pipe creates a synchronous in-memory pipe.
// It can be used to connect code expecting an io.Reader
// with code expecting an io.Writer.
//
// Reads and Writes on the pipe are matched one to one
// except when multiple Reads are needed to consume a single Write.
// That is, each Write to the PipeWriter blocks until it has satisfied
// one or more Reads from the PipeReader that fully consume
// the written data.
// The data is copied directly from the Write to the corresponding
// Read (or Reads); there is no internal buffering.
//
// It is safe to call Read and Write in parallel with each other or with Close.
// Parallel calls to Read and parallel calls to Write are also safe:
// the individual calls will be gated sequentially.
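//
// A hedged usage sketch: the writer runs in its own goroutine because
// each Write blocks until the data is consumed:
//
//	r, w := Pipe()
//	go func() {
//		fmt.Fprint(w, "some data")
//		w.Close() // the reader sees EOF after the data
//	}()
//	b, _ := ReadAll(r)
//	// string(b) == "some data"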
func Pipe() (*PipeReader, *PipeWriter) {
p := &pipe{
wrCh: make(chan []byte),
rdCh: make(chan int),
done: make(chan struct{}),
}
return &PipeReader{p}, &PipeWriter{p}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package log implements a simple logging package. It defines a type, Logger,
// with methods for formatting output. It also has a predefined 'standard'
// Logger accessible through helper functions Print[f|ln], Fatal[f|ln], and
// Panic[f|ln], which are easier to use than creating a Logger manually.
// That logger writes to standard error and prints the date and time
// of each logged message.
// Every log message is output on a separate line: if the message being
// printed does not end in a newline, the logger will add one.
// The Fatal functions call os.Exit(1) after writing the log message.
// The Panic functions call panic after writing the log message.
package log
import (
"fmt"
"io"
"os"
"runtime"
"sync"
"sync/atomic"
"time"
)
// These flags define which text to prefix to each log entry generated by the Logger.
// Bits are or'ed together to control what's printed.
// With the exception of the Lmsgprefix flag, there is no
// control over the order they appear (the order listed here)
// or the format they present (as described in the comments).
// The prefix is followed by a colon only when Llongfile or Lshortfile
// is specified.
// For example, flags Ldate | Ltime (or LstdFlags) produce,
//
// 2009/01/23 01:23:23 message
//
// while flags Ldate | Ltime | Lmicroseconds | Llongfile produce,
//
// 2009/01/23 01:23:23.123123 /a/b/c/d.go:23: message
const (
Ldate = 1 << iota // the date in the local time zone: 2009/01/23
Ltime // the time in the local time zone: 01:23:23
Lmicroseconds // microsecond resolution: 01:23:23.123123. assumes Ltime.
Llongfile // full file name and line number: /a/b/c/d.go:23
Lshortfile // final file name element and line number: d.go:23. overrides Llongfile
LUTC // if Ldate or Ltime is set, use UTC rather than the local time zone
Lmsgprefix // move the "prefix" from the beginning of the line to before the message
LstdFlags = Ldate | Ltime // initial values for the standard logger
)
// A Logger represents an active logging object that generates lines of
// output to an io.Writer. Each logging operation makes a single call to
// the Writer's Write method. A Logger can be used simultaneously from
// multiple goroutines; it guarantees to serialize access to the Writer.
type Logger struct {
outMu sync.Mutex
out io.Writer // destination for output
prefix atomic.Pointer[string] // prefix on each line to identify the logger (but see Lmsgprefix)
flag atomic.Int32 // properties
isDiscard atomic.Bool
}
// New creates a new Logger. The out variable sets the
// destination to which log data will be written.
// The prefix appears at the beginning of each generated log line, or
// after the log header if the Lmsgprefix flag is provided.
// The flag argument defines the logging properties.
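//
// A hedged usage sketch:
//
//	logger := New(os.Stderr, "app: ", LstdFlags|Lshortfile)
//	logger.Println("starting")
//	// e.g. "app: 2009/01/23 01:23:23 main.go:12: starting"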
func New(out io.Writer, prefix string, flag int) *Logger {
l := new(Logger)
l.SetOutput(out)
l.SetPrefix(prefix)
l.SetFlags(flag)
return l
}
// SetOutput sets the output destination for the logger.
func (l *Logger) SetOutput(w io.Writer) {
l.outMu.Lock()
defer l.outMu.Unlock()
l.out = w
l.isDiscard.Store(w == io.Discard)
}
var std = New(os.Stderr, "", LstdFlags)
// Default returns the standard logger used by the package-level output functions.
func Default() *Logger { return std }
// Cheap integer to fixed-width decimal ASCII. Give a negative width to avoid zero-padding.
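// For example, given buf *[]byte, itoa(buf, 7, 2) appends "07",
// itoa(buf, 2009, 4) appends "2009", and itoa(buf, 23, -1) appends
// "23" with no padding.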
func itoa(buf *[]byte, i int, wid int) {
// Assemble decimal in reverse order.
var b [20]byte
bp := len(b) - 1
for i >= 10 || wid > 1 {
wid--
q := i / 10
b[bp] = byte('0' + i - q*10)
bp--
i = q
}
// i < 10
b[bp] = byte('0' + i)
*buf = append(*buf, b[bp:]...)
}
// formatHeader writes log header to buf in following order:
// - l.prefix (if it's not blank and Lmsgprefix is unset),
// - date and/or time (if corresponding flags are provided),
// - file and line number (if corresponding flags are provided),
// - l.prefix (if it's not blank and Lmsgprefix is set).
func formatHeader(buf *[]byte, t time.Time, prefix string, flag int, file string, line int) {
if flag&Lmsgprefix == 0 {
*buf = append(*buf, prefix...)
}
if flag&(Ldate|Ltime|Lmicroseconds) != 0 {
if flag&LUTC != 0 {
t = t.UTC()
}
if flag&Ldate != 0 {
year, month, day := t.Date()
itoa(buf, year, 4)
*buf = append(*buf, '/')
itoa(buf, int(month), 2)
*buf = append(*buf, '/')
itoa(buf, day, 2)
*buf = append(*buf, ' ')
}
if flag&(Ltime|Lmicroseconds) != 0 {
hour, min, sec := t.Clock()
itoa(buf, hour, 2)
*buf = append(*buf, ':')
itoa(buf, min, 2)
*buf = append(*buf, ':')
itoa(buf, sec, 2)
if flag&Lmicroseconds != 0 {
*buf = append(*buf, '.')
itoa(buf, t.Nanosecond()/1e3, 6)
}
*buf = append(*buf, ' ')
}
}
if flag&(Lshortfile|Llongfile) != 0 {
if flag&Lshortfile != 0 {
short := file
for i := len(file) - 1; i > 0; i-- {
if file[i] == '/' {
short = file[i+1:]
break
}
}
file = short
}
*buf = append(*buf, file...)
*buf = append(*buf, ':')
itoa(buf, line, -1)
*buf = append(*buf, ": "...)
}
if flag&Lmsgprefix != 0 {
*buf = append(*buf, prefix...)
}
}
var bufferPool = sync.Pool{New: func() any { return new([]byte) }}
func getBuffer() *[]byte {
p := bufferPool.Get().(*[]byte)
*p = (*p)[:0]
return p
}
func putBuffer(p *[]byte) {
// Proper usage of a sync.Pool requires each entry to have approximately
// the same memory cost. To obtain this property when the stored type
// contains a variably-sized buffer, we add a hard limit on the maximum buffer
// to place back in the pool.
//
// See https://go.dev/issue/23199
if cap(*p) > 64<<10 {
*p = nil
}
bufferPool.Put(p)
}
// Output writes the output for a logging event. The string s contains
// the text to print after the prefix specified by the flags of the
// Logger. A newline is appended if the last character of s is not
// already a newline. Calldepth is used to recover the PC and is
// provided for generality, although at the moment on all pre-defined
// paths it will be 2.
func (l *Logger) Output(calldepth int, s string) error {
calldepth++ // +1 for this frame.
return l.output(calldepth, func(b []byte) []byte {
return append(b, s...)
})
}
func (l *Logger) output(calldepth int, appendOutput func([]byte) []byte) error {
if l.isDiscard.Load() {
return nil
}
now := time.Now() // get this early.
// Load prefix and flag once so that their value is consistent within
// this call regardless of any concurrent changes to their value.
prefix := l.Prefix()
flag := l.Flags()
var file string
var line int
if flag&(Lshortfile|Llongfile) != 0 {
var ok bool
_, file, line, ok = runtime.Caller(calldepth)
if !ok {
file = "???"
line = 0
}
}
buf := getBuffer()
defer putBuffer(buf)
formatHeader(buf, now, prefix, flag, file, line)
*buf = appendOutput(*buf)
if len(*buf) == 0 || (*buf)[len(*buf)-1] != '\n' {
*buf = append(*buf, '\n')
}
l.outMu.Lock()
defer l.outMu.Unlock()
_, err := l.out.Write(*buf)
return err
}
// Print calls l.Output to print to the logger.
// Arguments are handled in the manner of fmt.Print.
func (l *Logger) Print(v ...any) {
l.output(2, func(b []byte) []byte {
return fmt.Append(b, v...)
})
}
// Printf calls l.Output to print to the logger.
// Arguments are handled in the manner of fmt.Printf.
func (l *Logger) Printf(format string, v ...any) {
l.output(2, func(b []byte) []byte {
return fmt.Appendf(b, format, v...)
})
}
// Println calls l.Output to print to the logger.
// Arguments are handled in the manner of fmt.Println.
func (l *Logger) Println(v ...any) {
l.output(2, func(b []byte) []byte {
return fmt.Appendln(b, v...)
})
}
// Fatal is equivalent to l.Print() followed by a call to os.Exit(1).
func (l *Logger) Fatal(v ...any) {
l.Output(2, fmt.Sprint(v...))
os.Exit(1)
}
// Fatalf is equivalent to l.Printf() followed by a call to os.Exit(1).
func (l *Logger) Fatalf(format string, v ...any) {
l.Output(2, fmt.Sprintf(format, v...))
os.Exit(1)
}
// Fatalln is equivalent to l.Println() followed by a call to os.Exit(1).
func (l *Logger) Fatalln(v ...any) {
l.Output(2, fmt.Sprintln(v...))
os.Exit(1)
}
// Panic is equivalent to l.Print() followed by a call to panic().
func (l *Logger) Panic(v ...any) {
s := fmt.Sprint(v...)
l.Output(2, s)
panic(s)
}
// Panicf is equivalent to l.Printf() followed by a call to panic().
func (l *Logger) Panicf(format string, v ...any) {
s := fmt.Sprintf(format, v...)
l.Output(2, s)
panic(s)
}
// Panicln is equivalent to l.Println() followed by a call to panic().
func (l *Logger) Panicln(v ...any) {
s := fmt.Sprintln(v...)
l.Output(2, s)
panic(s)
}
// Flags returns the output flags for the logger.
// The flag bits are Ldate, Ltime, and so on.
func (l *Logger) Flags() int {
return int(l.flag.Load())
}
// SetFlags sets the output flags for the logger.
// The flag bits are Ldate, Ltime, and so on.
func (l *Logger) SetFlags(flag int) {
l.flag.Store(int32(flag))
}
// Prefix returns the output prefix for the logger.
func (l *Logger) Prefix() string {
if p := l.prefix.Load(); p != nil {
return *p
}
return ""
}
// SetPrefix sets the output prefix for the logger.
func (l *Logger) SetPrefix(prefix string) {
l.prefix.Store(&prefix)
}
// Writer returns the output destination for the logger.
func (l *Logger) Writer() io.Writer {
l.outMu.Lock()
defer l.outMu.Unlock()
return l.out
}
// SetOutput sets the output destination for the standard logger.
func SetOutput(w io.Writer) {
std.SetOutput(w)
}
// Flags returns the output flags for the standard logger.
// The flag bits are Ldate, Ltime, and so on.
func Flags() int {
return std.Flags()
}
// SetFlags sets the output flags for the standard logger.
// The flag bits are Ldate, Ltime, and so on.
func SetFlags(flag int) {
std.SetFlags(flag)
}
// Prefix returns the output prefix for the standard logger.
func Prefix() string {
return std.Prefix()
}
// SetPrefix sets the output prefix for the standard logger.
func SetPrefix(prefix string) {
std.SetPrefix(prefix)
}
// Writer returns the output destination for the standard logger.
func Writer() io.Writer {
return std.Writer()
}
// These functions write to the standard logger.
// Print calls Output to print to the standard logger.
// Arguments are handled in the manner of fmt.Print.
func Print(v ...any) {
std.output(2, func(b []byte) []byte {
return fmt.Append(b, v...)
})
}
// Printf calls Output to print to the standard logger.
// Arguments are handled in the manner of fmt.Printf.
func Printf(format string, v ...any) {
std.output(2, func(b []byte) []byte {
return fmt.Appendf(b, format, v...)
})
}
// Println calls Output to print to the standard logger.
// Arguments are handled in the manner of fmt.Println.
func Println(v ...any) {
std.output(2, func(b []byte) []byte {
return fmt.Appendln(b, v...)
})
}
// Fatal is equivalent to Print() followed by a call to os.Exit(1).
func Fatal(v ...any) {
std.Output(2, fmt.Sprint(v...))
os.Exit(1)
}
// Fatalf is equivalent to Printf() followed by a call to os.Exit(1).
func Fatalf(format string, v ...any) {
std.Output(2, fmt.Sprintf(format, v...))
os.Exit(1)
}
// Fatalln is equivalent to Println() followed by a call to os.Exit(1).
func Fatalln(v ...any) {
std.Output(2, fmt.Sprintln(v...))
os.Exit(1)
}
// Panic is equivalent to Print() followed by a call to panic().
func Panic(v ...any) {
s := fmt.Sprint(v...)
std.Output(2, s)
panic(s)
}
// Panicf is equivalent to Printf() followed by a call to panic().
func Panicf(format string, v ...any) {
s := fmt.Sprintf(format, v...)
std.Output(2, s)
panic(s)
}
// Panicln is equivalent to Println() followed by a call to panic().
func Panicln(v ...any) {
s := fmt.Sprintln(v...)
std.Output(2, s)
panic(s)
}
// Output writes the output for a logging event. The string s contains
// the text to print after the prefix specified by the flags of the
// Logger. A newline is appended if the last character of s is not
// already a newline. Calldepth is the number of
// frames to skip when computing the file name and line number
// if Llongfile or Lshortfile is set; a value of 1 will print the details
// for the caller of Output.
func Output(calldepth int, s string) error {
return std.Output(calldepth+1, s) // +1 for this frame.
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !windows && !plan9
package syslog
import (
"errors"
"fmt"
"log"
"net"
"os"
"strings"
"sync"
"time"
)
// The Priority is a combination of the syslog facility and
// severity. For example, LOG_ALERT | LOG_FTP sends an alert severity
// message from the FTP facility. The default severity is LOG_EMERG;
// the default facility is LOG_KERN.
type Priority int
const severityMask = 0x07
const facilityMask = 0xf8
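// As a worked example of how a Priority splits: LOG_FTP is 11<<3 == 0x58 and
// LOG_ALERT is 0x01, so LOG_FTP|LOG_ALERT == 0x59; masking recovers the parts,
// 0x59&facilityMask == 0x58 (LOG_FTP) and 0x59&severityMask == 0x01 (LOG_ALERT).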
const (
// Severity.
// From /usr/include/sys/syslog.h.
// These are the same on Linux, BSD, and OS X.
LOG_EMERG Priority = iota
LOG_ALERT
LOG_CRIT
LOG_ERR
LOG_WARNING
LOG_NOTICE
LOG_INFO
LOG_DEBUG
)
const (
// Facility.
// From /usr/include/sys/syslog.h.
// These are the same up to LOG_FTP on Linux, BSD, and OS X.
LOG_KERN Priority = iota << 3
LOG_USER
LOG_MAIL
LOG_DAEMON
LOG_AUTH
LOG_SYSLOG
LOG_LPR
LOG_NEWS
LOG_UUCP
LOG_CRON
LOG_AUTHPRIV
LOG_FTP
_ // unused
_ // unused
_ // unused
_ // unused
LOG_LOCAL0
LOG_LOCAL1
LOG_LOCAL2
LOG_LOCAL3
LOG_LOCAL4
LOG_LOCAL5
LOG_LOCAL6
LOG_LOCAL7
)
// A Writer is a connection to a syslog server.
type Writer struct {
priority Priority
tag string
hostname string
network string
raddr string
mu sync.Mutex // guards conn
conn serverConn
}
// This interface and the separate syslog_unix.go file exist for
// Solaris support as implemented by gccgo. On Solaris you cannot
// simply open a TCP connection to the syslog daemon. The gccgo
// sources have a syslog_solaris.go file that implements unixSyslog to
// return a type that satisfies this interface and simply calls the C
// library syslog function.
type serverConn interface {
writeString(p Priority, hostname, tag, s, nl string) error
close() error
}
type netConn struct {
local bool
conn net.Conn
}
// New establishes a new connection to the system log daemon. Each
// write to the returned writer sends a log message with the given
// priority (a combination of the syslog facility and severity) and
// prefix tag. If tag is empty, os.Args[0] is used.
func New(priority Priority, tag string) (*Writer, error) {
return Dial("", "", priority, tag)
}
// Dial establishes a connection to a log daemon by connecting to
// address raddr on the specified network. Each write to the returned
// writer sends a log message with the facility and severity
// (from priority) and tag. If tag is empty, os.Args[0] is used.
// If network is empty, Dial will connect to the local syslog server.
// Otherwise, see the documentation for net.Dial for valid values
// of network and raddr.
func Dial(network, raddr string, priority Priority, tag string) (*Writer, error) {
if priority < 0 || priority > LOG_LOCAL7|LOG_DEBUG {
return nil, errors.New("log/syslog: invalid priority")
}
if tag == "" {
tag = os.Args[0]
}
hostname, _ := os.Hostname()
w := &Writer{
priority: priority,
tag: tag,
hostname: hostname,
network: network,
raddr: raddr,
}
w.mu.Lock()
defer w.mu.Unlock()
err := w.connect()
if err != nil {
return nil, err
}
return w, err
}
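// A minimal sketch of dialing a remote syslog daemon over UDP (the address
// "logs.example.com:514" and the tag are assumed placeholders):
//
//	w, err := syslog.Dial("udp", "logs.example.com:514",
//		syslog.LOG_WARNING|syslog.LOG_DAEMON, "demo")
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer w.Close()
//	fmt.Fprintln(w, "this message is sent at warning severity")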
// connect makes a connection to the syslog server.
// It must be called with w.mu held.
func (w *Writer) connect() (err error) {
if w.conn != nil {
// ignore err from close; it makes sense to continue anyway
w.conn.close()
w.conn = nil
}
if w.network == "" {
w.conn, err = unixSyslog()
if w.hostname == "" {
w.hostname = "localhost"
}
} else {
var c net.Conn
c, err = net.Dial(w.network, w.raddr)
if err == nil {
w.conn = &netConn{
conn: c,
local: w.network == "unixgram" || w.network == "unix",
}
if w.hostname == "" {
w.hostname = c.LocalAddr().String()
}
}
}
return
}
// Write sends a log message to the syslog daemon.
func (w *Writer) Write(b []byte) (int, error) {
return w.writeAndRetry(w.priority, string(b))
}
// Close closes a connection to the syslog daemon.
func (w *Writer) Close() error {
w.mu.Lock()
defer w.mu.Unlock()
if w.conn != nil {
err := w.conn.close()
w.conn = nil
return err
}
return nil
}
// Emerg logs a message with severity LOG_EMERG, ignoring the severity
// passed to New.
func (w *Writer) Emerg(m string) error {
_, err := w.writeAndRetry(LOG_EMERG, m)
return err
}
// Alert logs a message with severity LOG_ALERT, ignoring the severity
// passed to New.
func (w *Writer) Alert(m string) error {
_, err := w.writeAndRetry(LOG_ALERT, m)
return err
}
// Crit logs a message with severity LOG_CRIT, ignoring the severity
// passed to New.
func (w *Writer) Crit(m string) error {
_, err := w.writeAndRetry(LOG_CRIT, m)
return err
}
// Err logs a message with severity LOG_ERR, ignoring the severity
// passed to New.
func (w *Writer) Err(m string) error {
_, err := w.writeAndRetry(LOG_ERR, m)
return err
}
// Warning logs a message with severity LOG_WARNING, ignoring the
// severity passed to New.
func (w *Writer) Warning(m string) error {
_, err := w.writeAndRetry(LOG_WARNING, m)
return err
}
// Notice logs a message with severity LOG_NOTICE, ignoring the
// severity passed to New.
func (w *Writer) Notice(m string) error {
_, err := w.writeAndRetry(LOG_NOTICE, m)
return err
}
// Info logs a message with severity LOG_INFO, ignoring the severity
// passed to New.
func (w *Writer) Info(m string) error {
_, err := w.writeAndRetry(LOG_INFO, m)
return err
}
// Debug logs a message with severity LOG_DEBUG, ignoring the severity
// passed to New.
func (w *Writer) Debug(m string) error {
_, err := w.writeAndRetry(LOG_DEBUG, m)
return err
}
func (w *Writer) writeAndRetry(p Priority, s string) (int, error) {
pr := (w.priority & facilityMask) | (p & severityMask)
w.mu.Lock()
defer w.mu.Unlock()
if w.conn != nil {
if n, err := w.write(pr, s); err == nil {
return n, nil
}
}
if err := w.connect(); err != nil {
return 0, err
}
return w.write(pr, s)
}
// write generates and writes a syslog formatted string. The
// format is as follows: <PRI>TIMESTAMP HOSTNAME TAG[PID]: MSG
func (w *Writer) write(p Priority, msg string) (int, error) {
// ensure it ends in a \n
nl := ""
if !strings.HasSuffix(msg, "\n") {
nl = "\n"
}
err := w.conn.writeString(p, w.hostname, w.tag, msg, nl)
if err != nil {
return 0, err
}
// Note: return the length of the input, not the number of
// bytes printed by Fprintf, because this must behave like
// an io.Writer.
return len(msg), nil
}
func (n *netConn) writeString(p Priority, hostname, tag, msg, nl string) error {
if n.local {
// Compared to the network form below, the changes are:
// 1. Use time.Stamp instead of time.RFC3339.
// 2. Drop the hostname field from the Fprintf.
timestamp := time.Now().Format(time.Stamp)
_, err := fmt.Fprintf(n.conn, "<%d>%s %s[%d]: %s%s",
p, timestamp,
tag, os.Getpid(), msg, nl)
return err
}
timestamp := time.Now().Format(time.RFC3339)
_, err := fmt.Fprintf(n.conn, "<%d>%s %s %s[%d]: %s%s",
p, timestamp, hostname,
tag, os.Getpid(), msg, nl)
return err
}
func (n *netConn) close() error {
return n.conn.Close()
}
// NewLogger creates a log.Logger whose output is written to the
// system log service with the specified priority, a combination of
// the syslog facility and severity. The logFlag argument is the flag
// set passed through to log.New to create the Logger.
func NewLogger(p Priority, logFlag int) (*log.Logger, error) {
s, err := New(p, "")
if err != nil {
return nil, err
}
return log.New(s, "", logFlag), nil
}
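// A minimal sketch of routing the standard log package through syslog
// (the facility choice and message are assumptions for illustration):
//
//	logger, err := syslog.NewLogger(syslog.LOG_NOTICE|syslog.LOG_USER, log.LstdFlags)
//	if err != nil {
//		log.Fatal(err)
//	}
//	logger.Println("service started") // delivered at LOG_NOTICE severity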
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !windows && !plan9
package syslog
import (
"errors"
"net"
)
// unixSyslog opens a connection to the syslog daemon running on the
// local machine using a Unix domain socket.
func unixSyslog() (conn serverConn, err error) {
logTypes := []string{"unixgram", "unix"}
logPaths := []string{"/dev/log", "/var/run/syslog", "/var/run/log"}
for _, network := range logTypes {
for _, path := range logPaths {
conn, err := net.Dial(network, path)
if err == nil {
return &netConn{conn: conn, local: true}, nil
}
}
}
return nil, errors.New("Unix syslog delivery error")
}
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package maps defines various functions useful with maps of any type.
package maps
// Keys returns the keys of the map m.
// The keys will be in an indeterminate order.
func Keys[M ~map[K]V, K comparable, V any](m M) []K {
r := make([]K, 0, len(m))
for k := range m {
r = append(r, k)
}
return r
}
// Values returns the values of the map m.
// The values will be in an indeterminate order.
func Values[M ~map[K]V, K comparable, V any](m M) []V {
r := make([]V, 0, len(m))
for _, v := range m {
r = append(r, v)
}
return r
}
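// A minimal sketch of Keys (map contents assumed; sort.Strings, from the
// standard sort package, supplies the ordering that Keys does not guarantee):
//
//	ages := map[string]int{"ann": 40, "bob": 35}
//	names := maps.Keys(ages) // e.g. ["bob" "ann"]; order is indeterminate
//	sort.Strings(names)      // ["ann" "bob"]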
// Equal reports whether two maps contain the same key/value pairs.
// Values are compared using ==.
func Equal[M1, M2 ~map[K]V, K, V comparable](m1 M1, m2 M2) bool {
if len(m1) != len(m2) {
return false
}
for k, v1 := range m1 {
if v2, ok := m2[k]; !ok || v1 != v2 {
return false
}
}
return true
}
// EqualFunc is like Equal, but compares values using eq.
// Keys are still compared with ==.
func EqualFunc[M1 ~map[K]V1, M2 ~map[K]V2, K comparable, V1, V2 any](m1 M1, m2 M2, eq func(V1, V2) bool) bool {
if len(m1) != len(m2) {
return false
}
for k, v1 := range m1 {
if v2, ok := m2[k]; !ok || !eq(v1, v2) {
return false
}
}
return true
}
// Clone returns a copy of m. This is a shallow clone:
// the new keys and values are set using ordinary assignment.
func Clone[M ~map[K]V, K comparable, V any](m M) M {
// Preserve nil in case it matters.
if m == nil {
return nil
}
r := make(M, len(m))
for k, v := range m {
r[k] = v
}
return r
}
// Copy copies all key/value pairs in src, adding them to dst.
// When a key in src is already present in dst,
// the value in dst will be overwritten by the value associated
// with the key in src.
func Copy[M1 ~map[K]V, M2 ~map[K]V, K comparable, V any](dst M1, src M2) {
for k, v := range src {
dst[k] = v
}
}
// DeleteFunc deletes any key/value pairs from m for which del returns true.
func DeleteFunc[M ~map[K]V, K comparable, V any](m M, del func(K, V) bool) {
for k, v := range m {
if del(k, v) {
delete(m, k)
}
}
}
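// A minimal sketch combining Clone, Copy, and DeleteFunc (map contents are
// assumed for illustration):
//
//	src := map[string]int{"a": 1, "b": 2}
//	dst := maps.Clone(src)                          // independent shallow copy
//	maps.Copy(dst, map[string]int{"b": 20, "c": 3}) // dst: {"a": 1, "b": 20, "c": 3}
//	maps.DeleteFunc(dst, func(k string, v int) bool { return v > 2 })
//	// dst is now {"a": 1}; src is unchanged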
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
// Abs returns the absolute value of x.
//
// Special cases are:
//
// Abs(±Inf) = +Inf
// Abs(NaN) = NaN
func Abs(x float64) float64 {
return Float64frombits(Float64bits(x) &^ (1 << 63))
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
// The original C code, the long comment, and the constants
// below are from FreeBSD's /usr/src/lib/msun/src/e_acosh.c
// and came with this notice. The go code is a simplified
// version of the original C.
//
// ====================================================
// Copyright (C) 1993 by Sun Microsystems, Inc. All rights reserved.
//
// Developed at SunPro, a Sun Microsystems, Inc. business.
// Permission to use, copy, modify, and distribute this
// software is freely granted, provided that this notice
// is preserved.
// ====================================================
//
//
// __ieee754_acosh(x)
// Method :
// Based on
// acosh(x) = log [ x + sqrt(x*x-1) ]
// we have
// acosh(x) := log(x)+ln2, if x is large; else
// acosh(x) := log(2x-1/(sqrt(x*x-1)+x)) if x>2; else
// acosh(x) := log1p(t+sqrt(2.0*t+t*t)); where t=x-1.
//
// Special cases:
// acosh(x) is NaN with signal if x<1.
// acosh(NaN) is NaN without signal.
//
// Acosh returns the inverse hyperbolic cosine of x.
//
// Special cases are:
//
// Acosh(+Inf) = +Inf
// Acosh(x) = NaN if x < 1
// Acosh(NaN) = NaN
func Acosh(x float64) float64 {
if haveArchAcosh {
return archAcosh(x)
}
return acosh(x)
}
func acosh(x float64) float64 {
const Large = 1 << 28 // 2**28
// the first case is the special case
switch {
case x < 1 || IsNaN(x):
return NaN()
case x == 1:
return 0
case x >= Large:
return Log(x) + Ln2 // x > 2**28
case x > 2:
return Log(2*x - 1/(x+Sqrt(x*x-1))) // 2**28 > x > 2
}
t := x - 1
return Log1p(t + Sqrt(2*t+t*t)) // 2 >= x > 1
}
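// As a quick numeric check of the branches above: Acosh(1) == 0 exactly, and
// since Cosh(1) ≈ 1.5430806348, Acosh(1.5430806348) ≈ 1 via the Log1p branch.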
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
/*
Floating-point arcsine and arccosine.
They are implemented by computing the arctangent
after appropriate range reduction.
*/
// Asin returns the arcsine, in radians, of x.
//
// Special cases are:
//
// Asin(±0) = ±0
// Asin(x) = NaN if x < -1 or x > 1
func Asin(x float64) float64 {
if haveArchAsin {
return archAsin(x)
}
return asin(x)
}
func asin(x float64) float64 {
if x == 0 {
return x // special case
}
sign := false
if x < 0 {
x = -x
sign = true
}
if x > 1 {
return NaN() // special case
}
temp := Sqrt(1 - x*x)
if x > 0.7 {
temp = Pi/2 - satan(temp/x)
} else {
temp = satan(x / temp)
}
if sign {
temp = -temp
}
return temp
}
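// As a quick numeric check: Asin(0.5) takes the small-argument branch,
// satan(0.5/Sqrt(0.75)) == satan(0.5773...) ≈ Pi/6; Asin(1) takes the
// x > 0.7 branch with temp == 0, giving exactly Pi/2.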
// Acos returns the arccosine, in radians, of x.
//
// Special case is:
//
// Acos(x) = NaN if x < -1 or x > 1
func Acos(x float64) float64 {
if haveArchAcos {
return archAcos(x)
}
return acos(x)
}
func acos(x float64) float64 {
return Pi/2 - Asin(x)
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
// The original C code, the long comment, and the constants
// below are from FreeBSD's /usr/src/lib/msun/src/s_asinh.c
// and came with this notice. The go code is a simplified
// version of the original C.
//
// ====================================================
// Copyright (C) 1993 by Sun Microsystems, Inc. All rights reserved.
//
// Developed at SunPro, a Sun Microsystems, Inc. business.
// Permission to use, copy, modify, and distribute this
// software is freely granted, provided that this notice
// is preserved.
// ====================================================
//
//
// asinh(x)
// Method :
// Based on
// asinh(x) = sign(x) * log [ |x| + sqrt(x*x+1) ]
// we have
// asinh(x) := x if 1+x*x=1,
// := sign(x)*(log(|x|)+ln2) for large |x|, else
// := sign(x)*log(2|x|+1/(|x|+sqrt(x*x+1))) if |x|>2, else
// := sign(x)*log1p(|x| + x**2/(1 + sqrt(1+x**2)))
//
// Asinh returns the inverse hyperbolic sine of x.
//
// Special cases are:
//
// Asinh(±0) = ±0
// Asinh(±Inf) = ±Inf
// Asinh(NaN) = NaN
func Asinh(x float64) float64 {
if haveArchAsinh {
return archAsinh(x)
}
return asinh(x)
}
func asinh(x float64) float64 {
const (
Ln2 = 6.93147180559945286227e-01 // 0x3FE62E42FEFA39EF
NearZero = 1.0 / (1 << 28) // 2**-28
Large = 1 << 28 // 2**28
)
// special cases
if IsNaN(x) || IsInf(x, 0) {
return x
}
sign := false
if x < 0 {
x = -x
sign = true
}
var temp float64
switch {
case x > Large:
temp = Log(x) + Ln2 // |x| > 2**28
case x > 2:
temp = Log(2*x + 1/(Sqrt(x*x+1)+x)) // 2**28 > |x| > 2.0
case x < NearZero:
temp = x // |x| < 2**-28
default:
temp = Log1p(x + x*x/(1+Sqrt(1+x*x))) // 2.0 > |x| > 2**-28
}
if sign {
temp = -temp
}
return temp
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
/*
Floating-point arctangent.
*/
// The original C code, the long comment, and the constants below were
// from http://netlib.sandia.gov/cephes/cmath/atan.c, available from
// http://www.netlib.org/cephes/cmath.tgz.
// The go code is a version of the original C.
//
// atan.c
// Inverse circular tangent (arctangent)
//
// SYNOPSIS:
// double x, y, atan();
// y = atan( x );
//
// DESCRIPTION:
// Returns radian angle between -pi/2 and +pi/2 whose tangent is x.
//
// Range reduction is from three intervals into the interval from zero to 0.66.
// The approximant uses a rational function of degree 4/5 of the form
// x + x**3 P(x)/Q(x).
//
// ACCURACY:
// Relative error:
//    arithmetic   domain    # trials   peak       rms
//    DEC          -10, 10   50000      2.4e-17    8.3e-18
//    IEEE         -10, 10   10^6       1.8e-16    5.0e-17
//
// Cephes Math Library Release 2.8: June, 2000
// Copyright 1984, 1987, 1989, 1992, 2000 by Stephen L. Moshier
//
// The readme file at http://netlib.sandia.gov/cephes/ says:
// Some software in this archive may be from the book _Methods and
// Programs for Mathematical Functions_ (Prentice-Hall or Simon & Schuster
// International, 1989) or from the Cephes Mathematical Library, a
// commercial product. In either event, it is copyrighted by the author.
// What you see here may be used freely but it comes with no support or
// guarantee.
//
// The two known misprints in the book are repaired here in the
// source listings for the gamma function and the incomplete beta
// integral.
//
// Stephen L. Moshier
// moshier@na-net.ornl.gov
// xatan evaluates a series valid in the range [0, 0.66].
func xatan(x float64) float64 {
const (
P0 = -8.750608600031904122785e-01
P1 = -1.615753718733365076637e+01
P2 = -7.500855792314704667340e+01
P3 = -1.228866684490136173410e+02
P4 = -6.485021904942025371773e+01
Q0 = +2.485846490142306297962e+01
Q1 = +1.650270098316988542046e+02
Q2 = +4.328810604912902668951e+02
Q3 = +4.853903996359136964868e+02
Q4 = +1.945506571482613964425e+02
)
z := x * x
z = z * ((((P0*z+P1)*z+P2)*z+P3)*z + P4) / (((((z+Q0)*z+Q1)*z+Q2)*z+Q3)*z + Q4)
z = x*z + x
return z
}
// satan reduces its argument (known to be positive)
// to the range [0, 0.66] and calls xatan.
func satan(x float64) float64 {
const (
Morebits = 6.123233995736765886130e-17 // pi/2 = PIO2 + Morebits
Tan3pio8 = 2.41421356237309504880 // tan(3*pi/8)
)
if x <= 0.66 {
return xatan(x)
}
if x > Tan3pio8 {
return Pi/2 - xatan(1/x) + Morebits
}
return Pi/4 + xatan((x-1)/(x+1)) + 0.5*Morebits
}
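// The two reductions above rely on the identities atan(x) = Pi/2 - atan(1/x)
// (for x > 0) and atan(x) = Pi/4 + atan((x-1)/(x+1)); Morebits compensates
// for the representation error in the Pi/2 and Pi/4 constants. For example,
// satan(1) takes the (x-1)/(x+1) reduction: Pi/4 + xatan(0) + 0.5*Morebits ≈ Pi/4.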
// Atan returns the arctangent, in radians, of x.
//
// Special cases are:
//
// Atan(±0) = ±0
// Atan(±Inf) = ±Pi/2
func Atan(x float64) float64 {
if haveArchAtan {
return archAtan(x)
}
return atan(x)
}
func atan(x float64) float64 {
if x == 0 {
return x
}
if x > 0 {
return satan(x)
}
return -satan(-x)
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
// Atan2 returns the arc tangent of y/x, using
// the signs of the two to determine the quadrant
// of the return value.
//
// Special cases are (in order):
//
// Atan2(y, NaN) = NaN
// Atan2(NaN, x) = NaN
// Atan2(+0, x>=0) = +0
// Atan2(-0, x>=0) = -0
// Atan2(+0, x<=-0) = +Pi
// Atan2(-0, x<=-0) = -Pi
// Atan2(y>0, 0) = +Pi/2
// Atan2(y<0, 0) = -Pi/2
// Atan2(+Inf, +Inf) = +Pi/4
// Atan2(-Inf, +Inf) = -Pi/4
// Atan2(+Inf, -Inf) = 3Pi/4
// Atan2(-Inf, -Inf) = -3Pi/4
// Atan2(y, +Inf) = 0
// Atan2(y>0, -Inf) = +Pi
// Atan2(y<0, -Inf) = -Pi
// Atan2(+Inf, x) = +Pi/2
// Atan2(-Inf, x) = -Pi/2
func Atan2(y, x float64) float64 {
if haveArchAtan2 {
return archAtan2(y, x)
}
return atan2(y, x)
}
func atan2(y, x float64) float64 {
// special cases
switch {
case IsNaN(y) || IsNaN(x):
return NaN()
case y == 0:
if x >= 0 && !Signbit(x) {
return Copysign(0, y)
}
return Copysign(Pi, y)
case x == 0:
return Copysign(Pi/2, y)
case IsInf(x, 0):
if IsInf(x, 1) {
switch {
case IsInf(y, 0):
return Copysign(Pi/4, y)
default:
return Copysign(0, y)
}
}
switch {
case IsInf(y, 0):
return Copysign(3*Pi/4, y)
default:
return Copysign(Pi, y)
}
case IsInf(y, 0):
return Copysign(Pi/2, y)
}
// Call atan and determine the quadrant.
q := Atan(y / x)
if x < 0 {
if q <= 0 {
return q + Pi
}
return q - Pi
}
return q
}
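// As a worked example of the quadrant fix-up: Atan2(1, -1) computes
// q = Atan(-1) == -Pi/4; since x < 0 and q <= 0, the result is
// q + Pi == 3*Pi/4, placing the angle in the second quadrant.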
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
// The original C code, the long comment, and the constants
// below are from FreeBSD's /usr/src/lib/msun/src/e_atanh.c
// and came with this notice. The go code is a simplified
// version of the original C.
//
// ====================================================
// Copyright (C) 1993 by Sun Microsystems, Inc. All rights reserved.
//
// Developed at SunPro, a Sun Microsystems, Inc. business.
// Permission to use, copy, modify, and distribute this
// software is freely granted, provided that this notice
// is preserved.
// ====================================================
//
//
// __ieee754_atanh(x)
// Method :
// 1. Reduce x to positive by atanh(-x) = -atanh(x)
// 2. For x>=0.5
// atanh(x) = (1/2) * log(1 + 2x/(1-x)) = 0.5 * log1p(2 * x/(1-x))
//
// For x<0.5
// atanh(x) = 0.5*log1p(2x+2x*x/(1-x))
//
// Special cases:
// atanh(x) is NaN if |x| > 1 with signal;
// atanh(NaN) is that NaN with no signal;
// atanh(+-1) is +-INF with signal.
//
// Atanh returns the inverse hyperbolic tangent of x.
//
// Special cases are:
//
// Atanh(1) = +Inf
// Atanh(±0) = ±0
// Atanh(-1) = -Inf
// Atanh(x) = NaN if x < -1 or x > 1
// Atanh(NaN) = NaN
func Atanh(x float64) float64 {
if haveArchAtanh {
return archAtanh(x)
}
return atanh(x)
}
func atanh(x float64) float64 {
const NearZero = 1.0 / (1 << 28) // 2**-28
// special cases
switch {
case x < -1 || x > 1 || IsNaN(x):
return NaN()
case x == 1:
return Inf(1)
case x == -1:
return Inf(-1)
}
sign := false
if x < 0 {
x = -x
sign = true
}
var temp float64
switch {
case x < NearZero:
temp = x
case x < 0.5:
temp = x + x
temp = 0.5 * Log1p(temp+temp*x/(1-x))
default:
temp = 0.5 * Log1p((x+x)/(1-x))
}
if sign {
temp = -temp
}
return temp
}
// Code generated by "stringer -type=Accuracy"; DO NOT EDIT.
package big
import "strconv"
const _Accuracy_name = "BelowExactAbove"
var _Accuracy_index = [...]uint8{0, 5, 10, 15}
func (i Accuracy) String() string {
i -= -1
if i < 0 || i >= Accuracy(len(_Accuracy_index)-1) {
return "Accuracy(" + strconv.FormatInt(int64(i+-1), 10) + ")"
}
return _Accuracy_name[_Accuracy_index[i]:_Accuracy_index[i+1]]
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file provides Go implementations of elementary multi-precision
// arithmetic operations on word vectors. These have the suffix _g.
// These are needed for platforms without assembly implementations of these routines.
// This file also contains elementary operations that can be implemented
// sufficiently efficiently in Go.
package big
import "math/bits"
// A Word represents a single digit of a multi-precision unsigned integer.
type Word uint
const (
_S = _W / 8 // word size in bytes
_W = bits.UintSize // word size in bits
_B = 1 << _W // digit base
_M = _B - 1 // digit mask
)
// Many of the loops in this file are of the form
// for i := 0; i < len(z) && i < len(x) && i < len(y); i++
// i < len(z) is the real condition.
// However, checking i < len(x) && i < len(y) as well is faster than
// having the compiler do a bounds check in the body of the loop;
// remarkably it is even faster than hoisting the bounds check
// out of the loop, by doing something like
// _, _ = x[len(z)-1], y[len(z)-1]
// There are other ways to hoist the bounds check out of the loop,
// but the compiler's BCE isn't powerful enough for them (yet?).
// See the discussion in CL 164966.
// ----------------------------------------------------------------------------
// Elementary operations on words
//
// These operations are used by the vector operations below.
// z1<<_W + z0 = x*y
func mulWW(x, y Word) (z1, z0 Word) {
hi, lo := bits.Mul(uint(x), uint(y))
return Word(hi), Word(lo)
}
// z1<<_W + z0 = x*y + c
func mulAddWWW_g(x, y, c Word) (z1, z0 Word) {
hi, lo := bits.Mul(uint(x), uint(y))
var cc uint
lo, cc = bits.Add(lo, uint(c), 0)
return Word(hi + cc), Word(lo)
}
// nlz returns the number of leading zeros in x.
// Wraps bits.LeadingZeros call for convenience.
func nlz(x Word) uint {
return uint(bits.LeadingZeros(uint(x)))
}
// The resulting carry c is either 0 or 1.
func addVV_g(z, x, y []Word) (c Word) {
// The comment near the top of this file discusses this for loop condition.
for i := 0; i < len(z) && i < len(x) && i < len(y); i++ {
zi, cc := bits.Add(uint(x[i]), uint(y[i]), uint(c))
z[i] = Word(zi)
c = Word(cc)
}
return
}
// The resulting carry c is either 0 or 1.
func subVV_g(z, x, y []Word) (c Word) {
// The comment near the top of this file discusses this for loop condition.
for i := 0; i < len(z) && i < len(x) && i < len(y); i++ {
zi, cc := bits.Sub(uint(x[i]), uint(y[i]), uint(c))
z[i] = Word(zi)
c = Word(cc)
}
return
}
// The resulting carry c is either 0 or 1.
func addVW_g(z, x []Word, y Word) (c Word) {
c = y
// The comment near the top of this file discusses this for loop condition.
for i := 0; i < len(z) && i < len(x); i++ {
zi, cc := bits.Add(uint(x[i]), uint(c), 0)
z[i] = Word(zi)
c = Word(cc)
}
return
}
// addVWlarge is addVW, but intended for large z.
// The only difference is that we check on every iteration
// whether we are done with carries,
// and if so, switch to a much faster copy instead.
// This is only a good idea for large z,
// because the overhead of the check and the function call
// outweigh the benefits when z is small.
func addVWlarge(z, x []Word, y Word) (c Word) {
c = y
// The comment near the top of this file discusses this for loop condition.
for i := 0; i < len(z) && i < len(x); i++ {
if c == 0 {
copy(z[i:], x[i:])
return
}
zi, cc := bits.Add(uint(x[i]), uint(c), 0)
z[i] = Word(zi)
c = Word(cc)
}
return
}
func subVW_g(z, x []Word, y Word) (c Word) {
c = y
// The comment near the top of this file discusses this for loop condition.
for i := 0; i < len(z) && i < len(x); i++ {
zi, cc := bits.Sub(uint(x[i]), uint(c), 0)
z[i] = Word(zi)
c = Word(cc)
}
return
}
// subVWlarge is to subVW as addVWlarge is to addVW.
func subVWlarge(z, x []Word, y Word) (c Word) {
c = y
// The comment near the top of this file discusses this for loop condition.
for i := 0; i < len(z) && i < len(x); i++ {
if c == 0 {
copy(z[i:], x[i:])
return
}
zi, cc := bits.Sub(uint(x[i]), uint(c), 0)
z[i] = Word(zi)
c = Word(cc)
}
return
}
func shlVU_g(z, x []Word, s uint) (c Word) {
if s == 0 {
copy(z, x)
return
}
if len(z) == 0 {
return
}
s &= _W - 1 // hint to the compiler that shifts by s don't need guard code
ŝ := _W - s
ŝ &= _W - 1 // ditto
c = x[len(z)-1] >> ŝ
for i := len(z) - 1; i > 0; i-- {
z[i] = x[i]<<s | x[i-1]>>ŝ
}
z[0] = x[0] << s
return
}
func shrVU_g(z, x []Word, s uint) (c Word) {
if s == 0 {
copy(z, x)
return
}
if len(z) == 0 {
return
}
if len(x) != len(z) {
// This is an invariant guaranteed by the caller.
panic("len(x) != len(z)")
}
s &= _W - 1 // hint to the compiler that shifts by s don't need guard code
ŝ := _W - s
ŝ &= _W - 1 // ditto
c = x[0] << ŝ
for i := 1; i < len(z); i++ {
z[i-1] = x[i-1]>>s | x[i]<<ŝ
}
z[len(z)-1] = x[len(z)-1] >> s
return
}
func mulAddVWW_g(z, x []Word, y, r Word) (c Word) {
c = r
// The comment near the top of this file discusses this for loop condition.
for i := 0; i < len(z) && i < len(x); i++ {
c, z[i] = mulAddWWW_g(x[i], y, c)
}
return
}
func addMulVVW_g(z, x []Word, y Word) (c Word) {
// The comment near the top of this file discusses this for loop condition.
for i := 0; i < len(z) && i < len(x); i++ {
z1, z0 := mulAddWWW_g(x[i], y, z[i])
lo, cc := bits.Add(uint(z0), uint(c), 0)
c, z[i] = Word(cc), Word(lo)
c += z1
}
return
}
// divWW returns q and r such that q = (x1<<_W + x0 - r)/y, requiring x1 < y.
// m is the approximate reciprocal floor((_B^2 - 1)/d) - _B, computed as described in
// "Improved Division by Invariant Integers" (IEEE Transactions on Computers, 11 Jun. 2010).
func divWW(x1, x0, y, m Word) (q, r Word) {
s := nlz(y)
if s != 0 {
x1 = x1<<s | x0>>(_W-s)
x0 <<= s
y <<= s
}
d := uint(y)
// We know that
// m = ⎣(B^2-1)/d⎦-B
// ⎣(B^2-1)/d⎦ = m+B
// (B^2-1)/d = m+B+delta1, where 0 <= delta1 <= (d-1)/d
// B^2/d = m+B+delta2, where 0 <= delta2 <= 1
// The quotient we're trying to compute is
// quotient = ⎣(x1*B+x0)/d⎦
// = ⎣(x1*B*(B^2/d)+x0*(B^2/d))/B^2⎦
// = ⎣(x1*B*(m+B+delta2)+x0*(m+B+delta2))/B^2⎦
// = ⎣(x1*m+x1*B+x0)/B + x0*m/B^2 + delta2*(x1*B+x0)/B^2⎦
// The latter two terms of this three-term sum are between 0 and 1.
// So we can compute just the first term, and we will be low by at most 2.
t1, t0 := bits.Mul(uint(m), uint(x1))
_, c := bits.Add(t0, uint(x0), 0)
t1, _ = bits.Add(t1, uint(x1), c)
// The quotient is either t1, t1+1, or t1+2.
// We'll try t1 and adjust if needed.
qq := t1
// compute remainder r=x-d*q.
dq1, dq0 := bits.Mul(d, qq)
r0, b := bits.Sub(uint(x0), dq0, 0)
r1, _ := bits.Sub(uint(x1), dq1, b)
// The remainder we just computed is bounded above by B+d:
// r = x1*B + x0 - d*q.
// = x1*B + x0 - d*⎣(x1*m+x1*B+x0)/B⎦
// = x1*B + x0 - d*((x1*m+x1*B+x0)/B-alpha), where 0 <= alpha < 1
// = x1*B + x0 - x1*d/B*m - x1*d - x0*d/B + d*alpha
// = x1*B + x0 - x1*d/B*⎣(B^2-1)/d-B⎦ - x1*d - x0*d/B + d*alpha
// = x1*B + x0 - x1*d/B*((B^2-1)/d-B-beta) - x1*d - x0*d/B + d*alpha, where 0 <= beta < 1
// = x1*B + x0 - x1*B + x1/B + x1*d + x1*d/B*beta - x1*d - x0*d/B + d*alpha
// = x0 + x1/B + x1*d/B*beta - x0*d/B + d*alpha
// = x0*(1-d/B) + x1*(1+d*beta)/B + d*alpha
// < B*(1-d/B) + d*B/B + d because x0<B (and 1-d/B>0), x1<d, 1+d*beta<=B, alpha<1
// = B - d + d + d
// = B+d
// So r1 can only be 0 or 1. If r1 is 1, then we know q was too small.
// Add 1 to q and subtract d from r. That guarantees that r is <B, so
// we no longer need to keep track of r1.
if r1 != 0 {
qq++
r0 -= d
}
// If the remainder is still too large, increment q one more time.
if r0 >= d {
qq++
r0 -= d
}
return Word(qq), Word(r0 >> s)
}
// reciprocalWord returns the reciprocal of the divisor: rec = floor((_B^2 - 1)/u) - _B, where u = d1 << nlz(d1).
func reciprocalWord(d1 Word) Word {
u := uint(d1 << nlz(d1))
x1 := ^u
x0 := uint(_M)
rec, _ := bits.Div(x1, x0, u) // (_B^2-1)/U-_B = (_B*(_M-C)+_M)/U
return Word(rec)
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements multi-precision decimal numbers.
// The implementation is for float to decimal conversion only;
// not general purpose use.
// The only operations are precise conversion from binary to
// decimal and rounding.
//
// The key observation and some code (shr) is borrowed from
// strconv/decimal.go: conversion of binary fractional values can be done
// precisely in multi-precision decimal because 2 divides 10 (required for
// >> of mantissa); but conversion of decimal floating-point values cannot
// be done precisely in binary representation.
//
// In contrast to strconv/decimal.go, only right shift is implemented in
// decimal format - left shift can be done precisely in binary format.
package big
// A decimal represents an unsigned floating-point number in decimal representation.
// The value of a non-zero decimal d is d.mant * 10**d.exp with 0.1 <= d.mant < 1,
// with the most-significant mantissa digit at index 0. For the zero decimal, the
// mantissa length and exponent are 0.
// The zero value for decimal represents a ready-to-use 0.0.
type decimal struct {
mant []byte // mantissa ASCII digits, big-endian
exp int // exponent
}
// at returns the i'th mantissa digit, starting with the most significant digit at 0.
func (d *decimal) at(i int) byte {
if 0 <= i && i < len(d.mant) {
return d.mant[i]
}
return '0'
}
// Maximum shift amount that can be done in one pass without overflow.
// A Word has _W bits and (1<<maxShift - 1)*10 + 9 must fit into Word.
const maxShift = _W - 4
// TODO(gri) Since we know the desired decimal precision when converting
// a floating-point number, we may be able to limit the number of decimal
// digits that need to be computed by init by providing an additional
// precision argument and keeping track of when a number was truncated early
// (equivalent of "sticky bit" in binary rounding).
// TODO(gri) Along the same lines, enforce some limit to shift magnitudes
// to avoid "infinitely" long running conversions (until we run out of space).
// Init initializes x to the decimal representation of m << shift (for
// shift >= 0), or m >> -shift (for shift < 0).
func (x *decimal) init(m nat, shift int) {
// special case 0
if len(m) == 0 {
x.mant = x.mant[:0]
x.exp = 0
return
}
// Optimization: If we need to shift right, first remove any trailing
// zero bits from m to reduce the shift amount that needs to be done in
// decimal format (since that is likely slower).
if shift < 0 {
ntz := m.trailingZeroBits()
s := uint(-shift)
if s >= ntz {
s = ntz // shift at most ntz bits
}
m = nat(nil).shr(m, s)
shift += int(s)
}
// Do any shift left in binary representation.
if shift > 0 {
m = nat(nil).shl(m, uint(shift))
shift = 0
}
// Convert mantissa into decimal representation.
s := m.utoa(10)
n := len(s)
x.exp = n
// Trim trailing zeros; the exponent tracks the decimal point
// independently of the number of digits.
for n > 0 && s[n-1] == '0' {
n--
}
x.mant = append(x.mant[:0], s[:n]...)
// Do any (remaining) shift right in decimal representation.
if shift < 0 {
for shift < -maxShift {
shr(x, maxShift)
shift += maxShift
}
shr(x, uint(-shift))
}
}
// shr implements x >> s, for s <= maxShift.
func shr(x *decimal, s uint) {
// Division by 1<<s using shift-and-subtract algorithm.
// pick up enough leading digits to cover first shift
r := 0 // read index
var n Word
for n>>s == 0 && r < len(x.mant) {
ch := Word(x.mant[r])
r++
n = n*10 + ch - '0'
}
if n == 0 {
// x == 0; shouldn't get here, but handle anyway
x.mant = x.mant[:0]
return
}
for n>>s == 0 {
r++
n *= 10
}
x.exp += 1 - r
// read a digit, write a digit
w := 0 // write index
mask := Word(1)<<s - 1
for r < len(x.mant) {
ch := Word(x.mant[r])
r++
d := n >> s
n &= mask // n -= d << s
x.mant[w] = byte(d + '0')
w++
n = n*10 + ch - '0'
}
// write extra digits that still fit
for n > 0 && w < len(x.mant) {
d := n >> s
n &= mask
x.mant[w] = byte(d + '0')
w++
n = n * 10
}
x.mant = x.mant[:w] // the number may be shorter (e.g. 1024 >> 10)
// append additional digits that didn't fit
for n > 0 {
d := n >> s
n &= mask
x.mant = append(x.mant, byte(d+'0'))
n = n * 10
}
trim(x)
}
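// As a worked example: shifting the decimal 12 (mant "12", exp 2) right by
// s == 1 reads digits until n == 12 (so r == 2), adjusts exp to 2+1-2 == 1,
// then emits d == 12>>1 == 6, leaving mant "6", exp 1, i.e. the value 6.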
func (x *decimal) String() string {
if len(x.mant) == 0 {
return "0"
}
var buf []byte
switch {
case x.exp <= 0:
// 0.00ddd
buf = make([]byte, 0, 2+(-x.exp)+len(x.mant))
buf = append(buf, "0."...)
buf = appendZeros(buf, -x.exp)
buf = append(buf, x.mant...)
case /* 0 < */ x.exp < len(x.mant):
// dd.ddd
buf = make([]byte, 0, 1+len(x.mant))
buf = append(buf, x.mant[:x.exp]...)
buf = append(buf, '.')
buf = append(buf, x.mant[x.exp:]...)
default: // len(x.mant) <= x.exp
// ddd00
buf = make([]byte, 0, x.exp)
buf = append(buf, x.mant...)
buf = appendZeros(buf, x.exp-len(x.mant))
}
return string(buf)
}
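// For example, with mant "123": exp == -1 prints "0.0123", exp == 2 prints
// "12.3", and exp == 5 prints "12300".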
// appendZeros appends n 0 digits to buf and returns buf.
func appendZeros(buf []byte, n int) []byte {
for ; n > 0; n-- {
buf = append(buf, '0')
}
return buf
}
// shouldRoundUp reports whether x should be rounded up
// if shortened to n digits. n must be a valid index
// for x.mant.
func shouldRoundUp(x *decimal, n int) bool {
if x.mant[n] == '5' && n+1 == len(x.mant) {
// exactly halfway - round to even
return n > 0 && (x.mant[n-1]-'0')&1 != 0
}
// not halfway - digit tells all (x.mant has no trailing zeros)
return x.mant[n] >= '5'
}
// round sets x to (at most) n mantissa digits by rounding it
// to the nearest even value with n (or fewer) mantissa digits.
// If n < 0, x remains unchanged.
func (x *decimal) round(n int) {
if n < 0 || n >= len(x.mant) {
return // nothing to do
}
if shouldRoundUp(x, n) {
x.roundUp(n)
} else {
x.roundDown(n)
}
}
func (x *decimal) roundUp(n int) {
if n < 0 || n >= len(x.mant) {
return // nothing to do
}
// 0 <= n < len(x.mant)
// find first digit < '9'
for n > 0 && x.mant[n-1] >= '9' {
n--
}
if n == 0 {
// all digits are '9's => round up to '1' and update exponent
x.mant[0] = '1' // ok since len(x.mant) > n
x.mant = x.mant[:1]
x.exp++
return
}
// n > 0 && x.mant[n-1] < '9'
x.mant[n-1]++
x.mant = x.mant[:n]
// x already trimmed
}
func (x *decimal) roundDown(n int) {
if n < 0 || n >= len(x.mant) {
return // nothing to do
}
x.mant = x.mant[:n]
trim(x)
}
// trim cuts off any trailing zeros from x's mantissa;
// they are meaningless for the value of x.
func trim(x *decimal) {
i := len(x.mant)
for i > 0 && x.mant[i-1] == '0' {
i--
}
x.mant = x.mant[:i]
if i == 0 {
x.exp = 0
}
}
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements multi-precision floating-point numbers.
// Like in the GNU MPFR library (https://www.mpfr.org/), operands
// can be of mixed precision. Unlike MPFR, the rounding mode is
// not specified with each operation, but with each operand. The
// rounding mode of the result operand determines the rounding
// mode of an operation. This is a from-scratch implementation.
package big
import (
"fmt"
"math"
"math/bits"
)
const debugFloat = false // enable for debugging
// A nonzero finite Float represents a multi-precision floating point number
//
// sign × mantissa × 2**exponent
//
// with 0.5 <= mantissa < 1.0, and MinExp <= exponent <= MaxExp.
// A Float may also be zero (+0, -0) or infinite (+Inf, -Inf).
// All Floats are ordered, and the ordering of two Floats x and y
// is defined by x.Cmp(y).
//
// Each Float value also has a precision, rounding mode, and accuracy.
// The precision is the maximum number of mantissa bits available to
// represent the value. The rounding mode specifies how a result should
// be rounded to fit into the mantissa bits, and accuracy describes the
// rounding error with respect to the exact result.
//
// Unless specified otherwise, all operations (including setters) that
// specify a *Float variable for the result (usually via the receiver
// with the exception of MantExp), round the numeric result according
// to the precision and rounding mode of the result variable.
//
// If the provided result precision is 0 (see below), it is set to the
// precision of the argument with the largest precision value before any
// rounding takes place, and the rounding mode remains unchanged. Thus,
// uninitialized Floats provided as result arguments will have their
// precision set to a reasonable value determined by the operands, and
// their mode is the zero value for RoundingMode (ToNearestEven).
//
// By setting the desired precision to 24 or 53 and using matching rounding
// mode (typically ToNearestEven), Float operations produce the same results
// as the corresponding float32 or float64 IEEE-754 arithmetic for operands
// that correspond to normal (i.e., not denormal) float32 or float64 numbers.
// Exponent underflow and overflow lead to a 0 or an Infinity for different
// values than IEEE-754 because Float exponents have a much larger range.
//
// The zero (uninitialized) value for a Float is ready to use and represents
// the number +0.0 exactly, with precision 0 and rounding mode ToNearestEven.
//
// Operations always take pointer arguments (*Float) rather
// than Float values, and each unique Float value requires
// its own unique *Float pointer. To "copy" a Float value,
// an existing (or newly allocated) Float must be set to
// a new value using the Float.Set method; shallow copies
// of Floats are not supported and may lead to errors.
type Float struct {
prec uint32
mode RoundingMode
acc Accuracy
form form
neg bool
mant nat
exp int32
}
// An ErrNaN panic is raised by a Float operation that would lead to
// a NaN under IEEE-754 rules. An ErrNaN implements the error interface.
type ErrNaN struct {
msg string
}
func (err ErrNaN) Error() string {
return err.msg
}
// NewFloat allocates and returns a new Float set to x,
// with precision 53 and rounding mode ToNearestEven.
// NewFloat panics with ErrNaN if x is a NaN.
func NewFloat(x float64) *Float {
if math.IsNaN(x) {
panic(ErrNaN{"NewFloat(NaN)"})
}
return new(Float).SetFloat64(x)
}
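// A minimal usage sketch (the precision, operation, and formatting verb
// below are illustrative choices, not requirements):
//
//	x := big.NewFloat(2)             // precision 53, like a float64
//	z := new(big.Float).SetPrec(200) // z rounds results to 200 bits
//	z.Sqrt(x)                        // sqrt(2), correct to z's precision
//	fmt.Println(z.Text('g', 40))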
// Exponent and precision limits.
const (
MaxExp = math.MaxInt32 // largest supported exponent
MinExp = math.MinInt32 // smallest supported exponent
MaxPrec = math.MaxUint32 // largest (theoretically) supported precision; likely memory-limited
)
// Internal representation: The mantissa bits x.mant of a nonzero finite
// Float x are stored in a nat slice long enough to hold up to x.prec bits;
// the slice may (but doesn't have to) be shorter if the mantissa contains
// trailing 0 bits. x.mant is normalized if the msb of x.mant == 1 (i.e.,
// the msb is shifted all the way "to the left"). Thus, if the mantissa has
// trailing 0 bits or x.prec is not a multiple of the Word size _W,
// x.mant[0] has trailing zero bits. The msb of the mantissa corresponds
// to the value 0.5; the exponent x.exp shifts the binary point as needed.
//
// A zero or non-finite Float x ignores x.mant and x.exp.
//
//	x                 form    neg    mant        exp
//	----------------------------------------------------------
//	±0                zero    sign   -           -
//	0 < |x| < +Inf    finite  sign   mantissa    exponent
//	±Inf              inf     sign   -           -
// A form value describes the internal representation.
type form byte
// The form value order is relevant - do not change!
const (
zero form = iota
finite
inf
)
// RoundingMode determines how a Float value is rounded to the
// desired precision. Rounding may change the Float value; the
// rounding error is described by the Float's Accuracy.
type RoundingMode byte
// These constants define supported rounding modes.
const (
ToNearestEven RoundingMode = iota // == IEEE 754-2008 roundTiesToEven
ToNearestAway // == IEEE 754-2008 roundTiesToAway
ToZero // == IEEE 754-2008 roundTowardZero
AwayFromZero // no IEEE 754-2008 equivalent
ToNegativeInf // == IEEE 754-2008 roundTowardNegative
ToPositiveInf // == IEEE 754-2008 roundTowardPositive
)
//go:generate stringer -type=RoundingMode
// Accuracy describes the rounding error produced by the most recent
// operation that generated a Float value, relative to the exact value.
type Accuracy int8
// Constants describing the Accuracy of a Float.
const (
Below Accuracy = -1
Exact Accuracy = 0
Above Accuracy = +1
)
//go:generate stringer -type=Accuracy
// SetPrec sets z's precision to prec and returns the (possibly) rounded
// value of z. Rounding occurs according to z's rounding mode if the mantissa
// cannot be represented in prec bits without loss of precision.
// SetPrec(0) maps all finite values to ±0; infinite values remain unchanged.
// If prec > MaxPrec, it is set to MaxPrec.
func (z *Float) SetPrec(prec uint) *Float {
z.acc = Exact // optimistically assume no rounding is needed
// special case
if prec == 0 {
z.prec = 0
if z.form == finite {
// truncate z to 0
z.acc = makeAcc(z.neg)
z.form = zero
}
return z
}
// general case
if prec > MaxPrec {
prec = MaxPrec
}
old := z.prec
z.prec = uint32(prec)
if z.prec < old {
z.round(0)
}
return z
}
func makeAcc(above bool) Accuracy {
if above {
return Above
}
return Below
}
// SetMode sets z's rounding mode to mode and returns an exact z.
// z remains unchanged otherwise.
// z.SetMode(z.Mode()) is a cheap way to set z's accuracy to Exact.
func (z *Float) SetMode(mode RoundingMode) *Float {
z.mode = mode
z.acc = Exact
return z
}
// Prec returns the mantissa precision of x in bits.
// The result may be 0 for |x| == 0 and |x| == Inf.
func (x *Float) Prec() uint {
return uint(x.prec)
}
// MinPrec returns the minimum precision required to represent x exactly
// (i.e., the smallest prec before x.SetPrec(prec) would start rounding x).
// The result is 0 for |x| == 0 and |x| == Inf.
func (x *Float) MinPrec() uint {
if x.form != finite {
return 0
}
return uint(len(x.mant))*_W - x.mant.trailingZeroBits()
}
// Mode returns the rounding mode of x.
func (x *Float) Mode() RoundingMode {
return x.mode
}
// Acc returns the accuracy of x produced by the most recent
// operation, unless explicitly documented otherwise by that
// operation.
func (x *Float) Acc() Accuracy {
return x.acc
}
// Sign returns:
//
// -1 if x < 0
// 0 if x is ±0
// +1 if x > 0
func (x *Float) Sign() int {
if debugFloat {
x.validate()
}
if x.form == zero {
return 0
}
if x.neg {
return -1
}
return 1
}
// MantExp breaks x into its mantissa and exponent components
// and returns the exponent. If a non-nil mant argument is
// provided its value is set to the mantissa of x, with the
// same precision and rounding mode as x. The components
// satisfy x == mant × 2**exp, with 0.5 <= |mant| < 1.0.
// Calling MantExp with a nil argument is an efficient way to
// get the exponent of the receiver.
//
// Special cases are:
//
// ( ±0).MantExp(mant) = 0, with mant set to ±0
// (±Inf).MantExp(mant) = 0, with mant set to ±Inf
//
// x and mant may be the same, in which case x is set to its
// mantissa value.
func (x *Float) MantExp(mant *Float) (exp int) {
if debugFloat {
x.validate()
}
if x.form == finite {
exp = int(x.exp)
}
if mant != nil {
mant.Copy(x)
if mant.form == finite {
mant.exp = 0
}
}
return
}
func (z *Float) setExpAndRound(exp int64, sbit uint) {
if exp < MinExp {
// underflow
z.acc = makeAcc(z.neg)
z.form = zero
return
}
if exp > MaxExp {
// overflow
z.acc = makeAcc(!z.neg)
z.form = inf
return
}
z.form = finite
z.exp = int32(exp)
z.round(sbit)
}
// SetMantExp sets z to mant × 2**exp and returns z.
// The result z has the same precision and rounding mode
// as mant. SetMantExp is an inverse of MantExp but does
// not require 0.5 <= |mant| < 1.0. Specifically, for a
// given x of type *Float, SetMantExp relates to MantExp
// as follows:
//
// mant := new(Float)
// new(Float).SetMantExp(mant, x.MantExp(mant)).Cmp(x) == 0
//
// Special cases are:
//
// z.SetMantExp( ±0, exp) = ±0
// z.SetMantExp(±Inf, exp) = ±Inf
//
// z and mant may be the same, in which case z's exponent
// is set to exp.
func (z *Float) SetMantExp(mant *Float, exp int) *Float {
if debugFloat {
z.validate()
mant.validate()
}
z.Copy(mant)
if z.form == finite {
// 0 < |mant| < +Inf
z.setExpAndRound(int64(z.exp)+int64(exp), 0)
}
return z
}
// Signbit reports whether x is negative or negative zero.
func (x *Float) Signbit() bool {
return x.neg
}
// IsInf reports whether x is +Inf or -Inf.
func (x *Float) IsInf() bool {
return x.form == inf
}
// IsInt reports whether x is an integer.
// ±Inf values are not integers.
func (x *Float) IsInt() bool {
if debugFloat {
x.validate()
}
// special cases
if x.form != finite {
return x.form == zero
}
// x.form == finite
if x.exp <= 0 {
return false
}
// x.exp > 0
return x.prec <= uint32(x.exp) || x.MinPrec() <= uint(x.exp) // not enough bits for fractional mantissa
}
// debugging support
func (x *Float) validate() {
if !debugFloat {
// avoid performance bugs
panic("validate called but debugFloat is not set")
}
if msg := x.validate0(); msg != "" {
panic(msg)
}
}
func (x *Float) validate0() string {
if x.form != finite {
return ""
}
m := len(x.mant)
if m == 0 {
return "nonzero finite number with empty mantissa"
}
const msb = 1 << (_W - 1)
if x.mant[m-1]&msb == 0 {
return fmt.Sprintf("msb not set in last word %#x of %s", x.mant[m-1], x.Text('p', 0))
}
if x.prec == 0 {
return "zero precision finite number"
}
return ""
}
// round rounds z according to z.mode to z.prec bits and sets z.acc accordingly.
// sbit must be 0 or 1 and summarizes any "sticky bit" information one might
// have before calling round. z's mantissa must be normalized (with the msb set)
// or empty.
//
// CAUTION: The rounding modes ToNegativeInf, ToPositiveInf are affected by the
// sign of z. For correct rounding, the sign of z must be set correctly before
// calling round.
func (z *Float) round(sbit uint) {
if debugFloat {
z.validate()
}
z.acc = Exact
if z.form != finite {
// ±0 or ±Inf => nothing left to do
return
}
// z.form == finite && len(z.mant) > 0
// m > 0 implies z.prec > 0 (checked by validate)
m := uint32(len(z.mant)) // present mantissa length in words
bits := m * _W // present mantissa bits; bits > 0
if bits <= z.prec {
// mantissa fits => nothing to do
return
}
// bits > z.prec
// Rounding is based on two bits: the rounding bit (rbit) and the
// sticky bit (sbit). The rbit is the bit immediately before the
// z.prec leading mantissa bits (the "0.5"). The sbit is set if any
// of the bits before the rbit are set (the "0.25", "0.125", etc.):
//
//	rbit  sbit  => "fractional part"
//
//	0     0        == 0
//	0     1        > 0  , < 0.5
//	1     0        == 0.5
//	1     1        > 0.5, < 1.0
// bits > z.prec: mantissa too large => round
r := uint(bits - z.prec - 1) // rounding bit position; r >= 0
rbit := z.mant.bit(r) & 1 // rounding bit; be safe and ensure it's a single bit
// The sticky bit is only needed for rounding ToNearestEven
// or when the rounding bit is zero. Avoid computation otherwise.
if sbit == 0 && (rbit == 0 || z.mode == ToNearestEven) {
sbit = z.mant.sticky(r)
}
sbit &= 1 // be safe and ensure it's a single bit
// cut off extra words
n := (z.prec + (_W - 1)) / _W // mantissa length in words for desired precision
if m > n {
copy(z.mant, z.mant[m-n:]) // move n last words to front
z.mant = z.mant[:n]
}
// determine number of trailing zero bits (ntz) and compute lsb mask of mantissa's least-significant word
ntz := n*_W - z.prec // 0 <= ntz < _W
lsb := Word(1) << ntz
// round if result is inexact
if rbit|sbit != 0 {
// Make rounding decision: The result mantissa is truncated ("rounded down")
// by default. Decide if we need to increment, or "round up", the (unsigned)
// mantissa.
inc := false
switch z.mode {
case ToNegativeInf:
inc = z.neg
case ToZero:
// nothing to do
case ToNearestEven:
inc = rbit != 0 && (sbit != 0 || z.mant[0]&lsb != 0)
case ToNearestAway:
inc = rbit != 0
case AwayFromZero:
inc = true
case ToPositiveInf:
inc = !z.neg
default:
panic("unreachable")
}
// A positive result (!z.neg) is Above the exact result if we increment,
// and it's Below if we truncate (Exact results require no rounding).
// For a negative result (z.neg) it is exactly the opposite.
z.acc = makeAcc(inc != z.neg)
if inc {
// add 1 to mantissa
if addVW(z.mant, z.mant, lsb) != 0 {
// mantissa overflow => adjust exponent
if z.exp >= MaxExp {
// exponent overflow
z.form = inf
return
}
z.exp++
// adjust mantissa: divide by 2 to compensate for exponent adjustment
shrVU(z.mant, z.mant, 1)
// set msb == carry == 1 from the mantissa overflow above
const msb = 1 << (_W - 1)
z.mant[n-1] |= msb
}
}
}
// zero out trailing bits in least-significant word
z.mant[0] &^= lsb - 1
if debugFloat {
z.validate()
}
}
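// As a worked ToNearestEven example: rounding the binary mantissa 0.1011 to
// prec == 3 is a tie (rbit == 1, sbit == 0) with an odd lsb, so it rounds up
// to 0.110; rounding 0.1001 is also a tie, but the lsb is even, so it
// truncates to 0.100.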
func (z *Float) setBits64(neg bool, x uint64) *Float {
if z.prec == 0 {
z.prec = 64
}
z.acc = Exact
z.neg = neg
if x == 0 {
z.form = zero
return z
}
// x != 0
z.form = finite
s := bits.LeadingZeros64(x)
z.mant = z.mant.setUint64(x << uint(s))
z.exp = int32(64 - s) // always fits
if z.prec < 64 {
z.round(0)
}
return z
}
// SetUint64 sets z to the (possibly rounded) value of x and returns z.
// If z's precision is 0, it is changed to 64 (and rounding will have
// no effect).
func (z *Float) SetUint64(x uint64) *Float {
return z.setBits64(false, x)
}
// SetInt64 sets z to the (possibly rounded) value of x and returns z.
// If z's precision is 0, it is changed to 64 (and rounding will have
// no effect).
func (z *Float) SetInt64(x int64) *Float {
u := x
if u < 0 {
u = -u
}
// We cannot simply call z.SetUint64(uint64(u)) and change
// the sign afterwards because the sign affects rounding.
return z.setBits64(x < 0, uint64(u))
}
// SetFloat64 sets z to the (possibly rounded) value of x and returns z.
// If z's precision is 0, it is changed to 53 (and rounding will have
// no effect). SetFloat64 panics with ErrNaN if x is a NaN.
func (z *Float) SetFloat64(x float64) *Float {
if z.prec == 0 {
z.prec = 53
}
if math.IsNaN(x) {
panic(ErrNaN{"Float.SetFloat64(NaN)"})
}
z.acc = Exact
z.neg = math.Signbit(x) // handle -0, -Inf correctly
if x == 0 {
z.form = zero
return z
}
if math.IsInf(x, 0) {
z.form = inf
return z
}
// normalized x != 0
z.form = finite
fmant, exp := math.Frexp(x) // get normalized mantissa
z.mant = z.mant.setUint64(1<<63 | math.Float64bits(fmant)<<11)
z.exp = int32(exp) // always fits
if z.prec < 53 {
z.round(0)
}
return z
}
// fnorm normalizes mantissa m by shifting it to the left
// such that the msb of the most-significant word (msw) is 1.
// It returns the shift amount. It assumes that len(m) != 0.
func fnorm(m nat) int64 {
if debugFloat && (len(m) == 0 || m[len(m)-1] == 0) {
panic("msw of mantissa is 0")
}
s := nlz(m[len(m)-1])
if s > 0 {
c := shlVU(m, m, s)
if debugFloat && c != 0 {
panic("nlz or shlVU incorrect")
}
}
return int64(s)
}
// SetInt sets z to the (possibly rounded) value of x and returns z.
// If z's precision is 0, it is changed to the larger of x.BitLen()
// or 64 (and rounding will have no effect).
func (z *Float) SetInt(x *Int) *Float {
// TODO(gri) can be more efficient if z.prec > 0
// but small compared to the size of x, or if there
// are many trailing 0's.
bits := uint32(x.BitLen())
if z.prec == 0 {
z.prec = umax32(bits, 64)
}
z.acc = Exact
z.neg = x.neg
if len(x.abs) == 0 {
z.form = zero
return z
}
// x != 0
z.mant = z.mant.set(x.abs)
fnorm(z.mant)
z.setExpAndRound(int64(bits), 0)
return z
}
// SetRat sets z to the (possibly rounded) value of x and returns z.
// If z's precision is 0, it is changed to the largest of a.BitLen(),
// b.BitLen(), or 64; with x = a/b.
func (z *Float) SetRat(x *Rat) *Float {
if x.IsInt() {
return z.SetInt(x.Num())
}
var a, b Float
a.SetInt(x.Num())
b.SetInt(x.Denom())
if z.prec == 0 {
z.prec = umax32(a.prec, b.prec)
}
return z.Quo(&a, &b)
}
// SetInf sets z to the infinite Float -Inf if signbit is
// set, or +Inf if signbit is not set, and returns z. The
// precision of z is unchanged and the result is always
// Exact.
func (z *Float) SetInf(signbit bool) *Float {
z.acc = Exact
z.form = inf
z.neg = signbit
return z
}
// Set sets z to the (possibly rounded) value of x and returns z.
// If z's precision is 0, it is changed to the precision of x
// before setting z (and rounding will have no effect).
// Rounding is performed according to z's precision and rounding
// mode; and z's accuracy reports the result error relative to the
// exact (not rounded) result.
func (z *Float) Set(x *Float) *Float {
if debugFloat {
x.validate()
}
z.acc = Exact
if z != x {
z.form = x.form
z.neg = x.neg
if x.form == finite {
z.exp = x.exp
z.mant = z.mant.set(x.mant)
}
if z.prec == 0 {
z.prec = x.prec
} else if z.prec < x.prec {
z.round(0)
}
}
return z
}
// Copy sets z to x, with the same precision, rounding mode, and
// accuracy as x, and returns z. x is not changed even if z and
// x are the same.
func (z *Float) Copy(x *Float) *Float {
if debugFloat {
x.validate()
}
if z != x {
z.prec = x.prec
z.mode = x.mode
z.acc = x.acc
z.form = x.form
z.neg = x.neg
if z.form == finite {
z.mant = z.mant.set(x.mant)
z.exp = x.exp
}
}
return z
}
// msb32 returns the 32 most significant bits of x.
func msb32(x nat) uint32 {
i := len(x) - 1
if i < 0 {
return 0
}
if debugFloat && x[i]&(1<<(_W-1)) == 0 {
panic("x not normalized")
}
switch _W {
case 32:
return uint32(x[i])
case 64:
return uint32(x[i] >> 32)
}
panic("unreachable")
}
// msb64 returns the 64 most significant bits of x.
func msb64(x nat) uint64 {
i := len(x) - 1
if i < 0 {
return 0
}
if debugFloat && x[i]&(1<<(_W-1)) == 0 {
panic("x not normalized")
}
switch _W {
case 32:
v := uint64(x[i]) << 32
if i > 0 {
v |= uint64(x[i-1])
}
return v
case 64:
return uint64(x[i])
}
panic("unreachable")
}
// Uint64 returns the unsigned integer resulting from truncating x
// towards zero. If 0 <= x <= math.MaxUint64, the result is Exact
// if x is an integer and Below otherwise.
// The result is (0, Above) for x < 0, and (math.MaxUint64, Below)
// for x > math.MaxUint64.
func (x *Float) Uint64() (uint64, Accuracy) {
if debugFloat {
x.validate()
}
switch x.form {
case finite:
if x.neg {
return 0, Above
}
// 0 < x < +Inf
if x.exp <= 0 {
// 0 < x < 1
return 0, Below
}
// 1 <= x < Inf
if x.exp <= 64 {
// u = trunc(x) fits into a uint64
u := msb64(x.mant) >> (64 - uint32(x.exp))
if x.MinPrec() <= 64 {
return u, Exact
}
return u, Below // x truncated
}
// x too large
return math.MaxUint64, Below
case zero:
return 0, Exact
case inf:
if x.neg {
return 0, Above
}
return math.MaxUint64, Below
}
panic("unreachable")
}
// Int64 returns the integer resulting from truncating x towards zero.
// If math.MinInt64 <= x <= math.MaxInt64, the result is Exact if x is
// an integer, and Above (x < 0) or Below (x > 0) otherwise.
// The result is (math.MinInt64, Above) for x < math.MinInt64,
// and (math.MaxInt64, Below) for x > math.MaxInt64.
func (x *Float) Int64() (int64, Accuracy) {
if debugFloat {
x.validate()
}
switch x.form {
case finite:
// 0 < |x| < +Inf
acc := makeAcc(x.neg)
if x.exp <= 0 {
// 0 < |x| < 1
return 0, acc
}
// x.exp > 0
// 1 <= |x| < +Inf
if x.exp <= 63 {
// i = trunc(x) fits into an int64 (excluding math.MinInt64)
i := int64(msb64(x.mant) >> (64 - uint32(x.exp)))
if x.neg {
i = -i
}
if x.MinPrec() <= uint(x.exp) {
return i, Exact
}
return i, acc // x truncated
}
if x.neg {
// check for special case x == math.MinInt64 (i.e., x == -(0.5 << 64))
if x.exp == 64 && x.MinPrec() == 1 {
acc = Exact
}
return math.MinInt64, acc
}
// x too large
return math.MaxInt64, Below
case zero:
return 0, Exact
case inf:
if x.neg {
return math.MinInt64, Above
}
return math.MaxInt64, Below
}
panic("unreachable")
}
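// Illustrative sketch (editor's addition): truncation toward zero and the
// returned Accuracy for Int64, including saturation for out-of-range values.
//
//	package main
//
//	import (
//		"fmt"
//		"math/big"
//	)
//
//	func main() {
//		i, acc := big.NewFloat(-1.5).Int64()
//		fmt.Println(i, acc) // -1 Above (truncated toward zero, so result > x)
//
//		j, acc := big.NewFloat(1e30).Int64()
//		fmt.Println(j, acc) // 9223372036854775807 Below (saturated at MaxInt64)
//	}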
// Float32 returns the float32 value nearest to x. If x is too small to be
// represented by a float32 (|x| < math.SmallestNonzeroFloat32), the result
// is (0, Below) or (-0, Above), respectively, depending on the sign of x.
// If x is too large to be represented by a float32 (|x| > math.MaxFloat32),
// the result is (+Inf, Above) or (-Inf, Below), depending on the sign of x.
func (x *Float) Float32() (float32, Accuracy) {
if debugFloat {
x.validate()
}
switch x.form {
case finite:
// 0 < |x| < +Inf
const (
fbits = 32 // float size
mbits = 23 // mantissa size (excluding implicit msb)
ebits = fbits - mbits - 1 // 8 exponent size
bias = 1<<(ebits-1) - 1 // 127 exponent bias
dmin = 1 - bias - mbits // -149 smallest unbiased exponent (denormal)
emin = 1 - bias // -126 smallest unbiased exponent (normal)
emax = bias // 127 largest unbiased exponent (normal)
)
// Float mantissa m is 0.5 <= m < 1.0; compute exponent e for float32 mantissa.
e := x.exp - 1 // exponent for normal mantissa m with 1.0 <= m < 2.0
// Compute precision p for float32 mantissa.
// If the exponent is too small, we have a denormal number before
// rounding and fewer than p mantissa bits of precision available
// (the exponent remains fixed but the mantissa gets shifted right).
p := mbits + 1 // precision of normal float
if e < emin {
// recompute precision
p = mbits + 1 - emin + int(e)
// If p == 0, the mantissa of x is shifted so much to the right
// that its msb falls immediately to the right of the float32
// mantissa space. In other words, if the smallest denormal is
// considered "1.0", for p == 0, the mantissa value m is >= 0.5.
// If m > 0.5, it is rounded up to 1.0; i.e., the smallest denormal.
// If m == 0.5, it is rounded down to even, i.e., 0.0.
// If p < 0, the mantissa value m is <= "0.25" which is never rounded up.
if p < 0 /* m <= 0.25 */ || p == 0 && x.mant.sticky(uint(len(x.mant))*_W-1) == 0 /* m == 0.5 */ {
// underflow to ±0
if x.neg {
var z float32
return -z, Above
}
return 0.0, Below
}
// otherwise, round up
// We handle p == 0 explicitly because it's easy and because
// Float.round doesn't support rounding to 0 bits of precision.
if p == 0 {
if x.neg {
return -math.SmallestNonzeroFloat32, Below
}
return math.SmallestNonzeroFloat32, Above
}
}
// p > 0
// round
var r Float
r.prec = uint32(p)
r.Set(x)
e = r.exp - 1
// Rounding may have caused r to overflow to ±Inf
// (rounding never causes underflows to 0).
// If the exponent is too large, also overflow to ±Inf.
if r.form == inf || e > emax {
// overflow
if x.neg {
return float32(math.Inf(-1)), Below
}
return float32(math.Inf(+1)), Above
}
// e <= emax
// Determine sign, biased exponent, and mantissa.
var sign, bexp, mant uint32
if x.neg {
sign = 1 << (fbits - 1)
}
// Rounding may have caused a denormal number to
// become normal. Check again.
if e < emin {
// denormal number: recompute precision
// Since rounding may have at best increased precision
// and we have eliminated p <= 0 early, we know p > 0.
// bexp == 0 for denormals
p = mbits + 1 - emin + int(e)
mant = msb32(r.mant) >> uint(fbits-p)
} else {
// normal number: emin <= e <= emax
bexp = uint32(e+bias) << mbits
mant = msb32(r.mant) >> ebits & (1<<mbits - 1) // cut off msb (implicit 1 bit)
}
return math.Float32frombits(sign | bexp | mant), r.acc
case zero:
if x.neg {
var z float32
return -z, Exact
}
return 0.0, Exact
case inf:
if x.neg {
return float32(math.Inf(-1)), Exact
}
return float32(math.Inf(+1)), Exact
}
panic("unreachable")
}
// Float64 returns the float64 value nearest to x. If x is too small to be
// represented by a float64 (|x| < math.SmallestNonzeroFloat64), the result
// is (0, Below) or (-0, Above), respectively, depending on the sign of x.
// If x is too large to be represented by a float64 (|x| > math.MaxFloat64),
// the result is (+Inf, Above) or (-Inf, Below), depending on the sign of x.
func (x *Float) Float64() (float64, Accuracy) {
if debugFloat {
x.validate()
}
switch x.form {
case finite:
// 0 < |x| < +Inf
const (
fbits = 64 // float size
mbits = 52 // mantissa size (excluding implicit msb)
ebits = fbits - mbits - 1 // 11 exponent size
bias = 1<<(ebits-1) - 1 // 1023 exponent bias
dmin = 1 - bias - mbits // -1074 smallest unbiased exponent (denormal)
emin = 1 - bias // -1022 smallest unbiased exponent (normal)
emax = bias // 1023 largest unbiased exponent (normal)
)
// Float mantissa m is 0.5 <= m < 1.0; compute exponent e for float64 mantissa.
e := x.exp - 1 // exponent for normal mantissa m with 1.0 <= m < 2.0
// Compute precision p for float64 mantissa.
// If the exponent is too small, we have a denormal number before
// rounding and fewer than p mantissa bits of precision available
// (the exponent remains fixed but the mantissa gets shifted right).
p := mbits + 1 // precision of normal float
if e < emin {
// recompute precision
p = mbits + 1 - emin + int(e)
// If p == 0, the mantissa of x is shifted so much to the right
// that its msb falls immediately to the right of the float64
// mantissa space. In other words, if the smallest denormal is
// considered "1.0", for p == 0, the mantissa value m is >= 0.5.
// If m > 0.5, it is rounded up to 1.0; i.e., the smallest denormal.
// If m == 0.5, it is rounded down to even, i.e., 0.0.
// If p < 0, the mantissa value m is <= "0.25" which is never rounded up.
if p < 0 /* m <= 0.25 */ || p == 0 && x.mant.sticky(uint(len(x.mant))*_W-1) == 0 /* m == 0.5 */ {
// underflow to ±0
if x.neg {
var z float64
return -z, Above
}
return 0.0, Below
}
// otherwise, round up
// We handle p == 0 explicitly because it's easy and because
// Float.round doesn't support rounding to 0 bits of precision.
if p == 0 {
if x.neg {
return -math.SmallestNonzeroFloat64, Below
}
return math.SmallestNonzeroFloat64, Above
}
}
// p > 0
// round
var r Float
r.prec = uint32(p)
r.Set(x)
e = r.exp - 1
// Rounding may have caused r to overflow to ±Inf
// (rounding never causes underflows to 0).
// If the exponent is too large, also overflow to ±Inf.
if r.form == inf || e > emax {
// overflow
if x.neg {
return math.Inf(-1), Below
}
return math.Inf(+1), Above
}
// e <= emax
// Determine sign, biased exponent, and mantissa.
var sign, bexp, mant uint64
if x.neg {
sign = 1 << (fbits - 1)
}
// Rounding may have caused a denormal number to
// become normal. Check again.
if e < emin {
// denormal number: recompute precision
// Since rounding may have at best increased precision
// and we have eliminated p <= 0 early, we know p > 0.
// bexp == 0 for denormals
p = mbits + 1 - emin + int(e)
mant = msb64(r.mant) >> uint(fbits-p)
} else {
// normal number: emin <= e <= emax
bexp = uint64(e+bias) << mbits
mant = msb64(r.mant) >> ebits & (1<<mbits - 1) // cut off msb (implicit 1 bit)
}
return math.Float64frombits(sign | bexp | mant), r.acc
case zero:
if x.neg {
var z float64
return -z, Exact
}
return 0.0, Exact
case inf:
if x.neg {
return math.Inf(-1), Exact
}
return math.Inf(+1), Exact
}
panic("unreachable")
}
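// Illustrative sketch (editor's addition): converting a high-precision value
// to float64 and inspecting the rounding direction.
//
//	package main
//
//	import (
//		"fmt"
//		"math/big"
//	)
//
//	func main() {
//		// Decimal 0.1 at 200 bits; the nearest float64 is slightly larger.
//		x, _, _ := big.ParseFloat("0.1", 10, 200, big.ToNearestEven)
//		f, acc := x.Float64()
//		fmt.Println(f, acc) // 0.1 Above
//	}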
// Int returns the result of truncating x towards zero;
// or nil if x is an infinity.
// The result is Exact if x.IsInt(); otherwise it is Below
// for x > 0, and Above for x < 0.
// If a non-nil *Int argument z is provided, Int stores
// the result in z instead of allocating a new Int.
func (x *Float) Int(z *Int) (*Int, Accuracy) {
if debugFloat {
x.validate()
}
if z == nil && x.form <= finite {
z = new(Int)
}
switch x.form {
case finite:
// 0 < |x| < +Inf
acc := makeAcc(x.neg)
if x.exp <= 0 {
// 0 < |x| < 1
return z.SetInt64(0), acc
}
// x.exp > 0
// 1 <= |x| < +Inf
// determine minimum required precision for x
allBits := uint(len(x.mant)) * _W
exp := uint(x.exp)
if x.MinPrec() <= exp {
acc = Exact
}
// shift mantissa as needed
if z == nil {
z = new(Int)
}
z.neg = x.neg
switch {
case exp > allBits:
z.abs = z.abs.shl(x.mant, exp-allBits)
default:
z.abs = z.abs.set(x.mant)
case exp < allBits:
z.abs = z.abs.shr(x.mant, allBits-exp)
}
return z, acc
case zero:
return z.SetInt64(0), Exact
case inf:
return nil, makeAcc(x.neg)
}
panic("unreachable")
}
// Rat returns the rational number corresponding to x;
// or nil if x is an infinity.
// The result is Exact if x is not an Inf.
// If a non-nil *Rat argument z is provided, Rat stores
// the result in z instead of allocating a new Rat.
func (x *Float) Rat(z *Rat) (*Rat, Accuracy) {
if debugFloat {
x.validate()
}
if z == nil && x.form <= finite {
z = new(Rat)
}
switch x.form {
case finite:
// 0 < |x| < +Inf
allBits := int32(len(x.mant)) * _W
// build up numerator and denominator
z.a.neg = x.neg
switch {
case x.exp > allBits:
z.a.abs = z.a.abs.shl(x.mant, uint(x.exp-allBits))
z.b.abs = z.b.abs[:0] // == 1 (see Rat)
// z already in normal form
default:
z.a.abs = z.a.abs.set(x.mant)
z.b.abs = z.b.abs[:0] // == 1 (see Rat)
// z already in normal form
case x.exp < allBits:
z.a.abs = z.a.abs.set(x.mant)
t := z.b.abs.setUint64(1)
z.b.abs = t.shl(t, uint(allBits-x.exp))
z.norm()
}
return z, Exact
case zero:
return z.SetInt64(0), Exact
case inf:
return nil, makeAcc(x.neg)
}
panic("unreachable")
}
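// Illustrative sketch (editor's addition): extracting the integer part and
// the exact rational value of a finite Float.
//
//	package main
//
//	import (
//		"fmt"
//		"math/big"
//	)
//
//	func main() {
//		x := big.NewFloat(-2.75)
//
//		i, acc := x.Int(nil)
//		fmt.Println(i, acc) // -2 Above (truncated toward zero)
//
//		r, acc := x.Rat(nil)
//		fmt.Println(r, acc) // -11/4 Exact (finite binary values are rational)
//	}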
// Abs sets z to the (possibly rounded) value |x| (the absolute value of x)
// and returns z.
func (z *Float) Abs(x *Float) *Float {
z.Set(x)
z.neg = false
return z
}
// Neg sets z to the (possibly rounded) value of x with its sign negated,
// and returns z.
func (z *Float) Neg(x *Float) *Float {
z.Set(x)
z.neg = !z.neg
return z
}
func validateBinaryOperands(x, y *Float) {
if !debugFloat {
// avoid performance bugs
panic("validateBinaryOperands called but debugFloat is not set")
}
if len(x.mant) == 0 {
panic("empty mantissa for x")
}
if len(y.mant) == 0 {
panic("empty mantissa for y")
}
}
// z = x + y, ignoring signs of x and y for the addition
// but using the sign of z for rounding the result.
// x and y must have a non-empty mantissa and valid exponent.
func (z *Float) uadd(x, y *Float) {
// Note: This implementation requires 2 shifts most of the
// time. It is also inefficient if exponents or precisions
// differ by wide margins. The following article describes
// an efficient (but much more complicated) implementation
// compatible with the internal representation used here:
//
// Vincent Lefèvre: "The Generic Multiple-Precision Floating-
// Point Addition With Exact Rounding (as in the MPFR Library)"
// http://www.vinc17.net/research/papers/rnc6.pdf
if debugFloat {
validateBinaryOperands(x, y)
}
// compute exponents ex, ey for mantissa with "binary point"
// on the right (mantissa.0) - use int64 to avoid overflow
ex := int64(x.exp) - int64(len(x.mant))*_W
ey := int64(y.exp) - int64(len(y.mant))*_W
al := alias(z.mant, x.mant) || alias(z.mant, y.mant)
// TODO(gri) having a combined add-and-shift primitive
// could make this code significantly faster
switch {
case ex < ey:
if al {
t := nat(nil).shl(y.mant, uint(ey-ex))
z.mant = z.mant.add(x.mant, t)
} else {
z.mant = z.mant.shl(y.mant, uint(ey-ex))
z.mant = z.mant.add(x.mant, z.mant)
}
default:
// ex == ey, no shift needed
z.mant = z.mant.add(x.mant, y.mant)
case ex > ey:
if al {
t := nat(nil).shl(x.mant, uint(ex-ey))
z.mant = z.mant.add(t, y.mant)
} else {
z.mant = z.mant.shl(x.mant, uint(ex-ey))
z.mant = z.mant.add(z.mant, y.mant)
}
ex = ey
}
// len(z.mant) > 0
z.setExpAndRound(ex+int64(len(z.mant))*_W-fnorm(z.mant), 0)
}
// z = x - y for |x| > |y|, ignoring signs of x and y for the subtraction
// but using the sign of z for rounding the result.
// x and y must have a non-empty mantissa and valid exponent.
func (z *Float) usub(x, y *Float) {
// This code is symmetric to uadd.
// We have not factored the common code out because
// eventually uadd (and usub) should be optimized
// by special-casing, and the code will diverge.
if debugFloat {
validateBinaryOperands(x, y)
}
ex := int64(x.exp) - int64(len(x.mant))*_W
ey := int64(y.exp) - int64(len(y.mant))*_W
al := alias(z.mant, x.mant) || alias(z.mant, y.mant)
switch {
case ex < ey:
if al {
t := nat(nil).shl(y.mant, uint(ey-ex))
z.mant = t.sub(x.mant, t)
} else {
z.mant = z.mant.shl(y.mant, uint(ey-ex))
z.mant = z.mant.sub(x.mant, z.mant)
}
default:
// ex == ey, no shift needed
z.mant = z.mant.sub(x.mant, y.mant)
case ex > ey:
if al {
t := nat(nil).shl(x.mant, uint(ex-ey))
z.mant = t.sub(t, y.mant)
} else {
z.mant = z.mant.shl(x.mant, uint(ex-ey))
z.mant = z.mant.sub(z.mant, y.mant)
}
ex = ey
}
// operands may have canceled each other out
if len(z.mant) == 0 {
z.acc = Exact
z.form = zero
z.neg = false
return
}
// len(z.mant) > 0
z.setExpAndRound(ex+int64(len(z.mant))*_W-fnorm(z.mant), 0)
}
// z = x * y, ignoring signs of x and y for the multiplication
// but using the sign of z for rounding the result.
// x and y must have a non-empty mantissa and valid exponent.
func (z *Float) umul(x, y *Float) {
if debugFloat {
validateBinaryOperands(x, y)
}
// Note: This is doing too much work if the precision
// of z is less than the sum of the precisions of x
// and y which is often the case (e.g., if all floats
// have the same precision).
// TODO(gri) Optimize this for the common case.
e := int64(x.exp) + int64(y.exp)
if x == y {
z.mant = z.mant.sqr(x.mant)
} else {
z.mant = z.mant.mul(x.mant, y.mant)
}
z.setExpAndRound(e-fnorm(z.mant), 0)
}
// z = x / y, ignoring signs of x and y for the division
// but using the sign of z for rounding the result.
// x and y must have a non-empty mantissa and valid exponent.
func (z *Float) uquo(x, y *Float) {
if debugFloat {
validateBinaryOperands(x, y)
}
// mantissa length in words for desired result precision + 1
// (at least one extra bit so we get the rounding bit after
// the division)
n := int(z.prec/_W) + 1
// compute adjusted x.mant such that we get enough result precision
xadj := x.mant
if d := n - len(x.mant) + len(y.mant); d > 0 {
// d extra words needed => add d "0 digits" to x
xadj = make(nat, len(x.mant)+d)
copy(xadj[d:], x.mant)
}
// TODO(gri): If we have too many digits (d < 0), we should be able
// to shorten x for faster division. But we must be extra careful
// with rounding in that case.
// Compute d before division since there may be aliasing of x.mant
// (via xadj) or y.mant with z.mant.
d := len(xadj) - len(y.mant)
// divide
var r nat
z.mant, r = z.mant.div(nil, xadj, y.mant)
e := int64(x.exp) - int64(y.exp) - int64(d-len(z.mant))*_W
// The result is long enough to include (at least) the rounding bit.
// If there's a non-zero remainder, the corresponding fractional part
// (if it were computed) would have a non-zero sticky bit (if it were
// zero, it couldn't have a non-zero remainder).
var sbit uint
if len(r) > 0 {
sbit = 1
}
z.setExpAndRound(e-fnorm(z.mant), sbit)
}
// ucmp returns -1, 0, or +1, depending on whether
// |x| < |y|, |x| == |y|, or |x| > |y|.
// x and y must have a non-empty mantissa and valid exponent.
func (x *Float) ucmp(y *Float) int {
if debugFloat {
validateBinaryOperands(x, y)
}
switch {
case x.exp < y.exp:
return -1
case x.exp > y.exp:
return +1
}
// x.exp == y.exp
// compare mantissas
i := len(x.mant)
j := len(y.mant)
for i > 0 || j > 0 {
var xm, ym Word
if i > 0 {
i--
xm = x.mant[i]
}
if j > 0 {
j--
ym = y.mant[j]
}
switch {
case xm < ym:
return -1
case xm > ym:
return +1
}
}
return 0
}
// Handling of sign bit as defined by IEEE 754-2008, section 6.3:
//
// When neither the inputs nor result are NaN, the sign of a product or
// quotient is the exclusive OR of the operands’ signs; the sign of a sum,
// or of a difference x−y regarded as a sum x+(−y), differs from at most
// one of the addends’ signs; and the sign of the result of conversions,
// the quantize operation, the roundToIntegral operations, and the
// roundToIntegralExact (see 5.3.1) is the sign of the first or only operand.
// These rules shall apply even when operands or results are zero or infinite.
//
// When the sum of two operands with opposite signs (or the difference of
// two operands with like signs) is exactly zero, the sign of that sum (or
// difference) shall be +0 in all rounding-direction attributes except
// roundTowardNegative; under that attribute, the sign of an exact zero
// sum (or difference) shall be −0. However, x+x = x−(−x) retains the same
// sign as x even when x is zero.
//
// See also: https://play.golang.org/p/RtH3UCt5IH
// Add sets z to the rounded sum x+y and returns z. If z's precision is 0,
// it is changed to the larger of x's or y's precision before the operation.
// Rounding is performed according to z's precision and rounding mode; and
// z's accuracy reports the result error relative to the exact (not rounded)
// result. Add panics with ErrNaN if x and y are infinities with opposite
// signs. The value of z is undefined in that case.
func (z *Float) Add(x, y *Float) *Float {
if debugFloat {
x.validate()
y.validate()
}
if z.prec == 0 {
z.prec = umax32(x.prec, y.prec)
}
if x.form == finite && y.form == finite {
// x + y (common case)
// Below we set z.neg = x.neg, and when z aliases y this will
// change the y operand's sign. This is fine, because if an
// operand aliases the receiver it'll be overwritten, but we still
// want the original x.neg and y.neg values when we evaluate
// x.neg != y.neg, so we need to save y.neg before setting z.neg.
yneg := y.neg
z.neg = x.neg
if x.neg == yneg {
// x + y == x + y
// (-x) + (-y) == -(x + y)
z.uadd(x, y)
} else {
// x + (-y) == x - y == -(y - x)
// (-x) + y == y - x == -(x - y)
if x.ucmp(y) > 0 {
z.usub(x, y)
} else {
z.neg = !z.neg
z.usub(y, x)
}
}
if z.form == zero && z.mode == ToNegativeInf && z.acc == Exact {
z.neg = true
}
return z
}
if x.form == inf && y.form == inf && x.neg != y.neg {
// +Inf + -Inf
// -Inf + +Inf
// value of z is undefined but make sure it's valid
z.acc = Exact
z.form = zero
z.neg = false
panic(ErrNaN{"addition of infinities with opposite signs"})
}
if x.form == zero && y.form == zero {
// ±0 + ±0
z.acc = Exact
z.form = zero
z.neg = x.neg && y.neg // -0 + -0 == -0
return z
}
if x.form == inf || y.form == zero {
// ±Inf + y
// x + ±0
return z.Set(x)
}
// ±0 + y
// x + ±Inf
return z.Set(y)
}
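// Illustrative sketch (editor's addition): the IEEE 754 zero-sign rule
// quoted above, and the ErrNaN panic for Inf + -Inf. Assumes only the
// exported math/big API.
//
//	package main
//
//	import (
//		"fmt"
//		"math/big"
//	)
//
//	func main() {
//		x, y := big.NewFloat(2), big.NewFloat(-2)
//
//		// An exact zero sum is +0 ...
//		fmt.Println(new(big.Float).Add(x, y).Signbit()) // false
//
//		// ... except under ToNegativeInf, where it is -0.
//		z := new(big.Float).SetMode(big.ToNegativeInf).Add(x, y)
//		fmt.Println(z.Signbit()) // true
//
//		// Inf + -Inf has no useful result and panics with ErrNaN.
//		defer func() { fmt.Println("recovered:", recover()) }()
//		new(big.Float).Add(new(big.Float).SetInf(false), new(big.Float).SetInf(true))
//	}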
// Sub sets z to the rounded difference x-y and returns z.
// Precision, rounding, and accuracy reporting are as for Add.
// Sub panics with ErrNaN if x and y are infinities with equal
// signs. The value of z is undefined in that case.
func (z *Float) Sub(x, y *Float) *Float {
if debugFloat {
x.validate()
y.validate()
}
if z.prec == 0 {
z.prec = umax32(x.prec, y.prec)
}
if x.form == finite && y.form == finite {
// x - y (common case)
yneg := y.neg
z.neg = x.neg
if x.neg != yneg {
// x - (-y) == x + y
// (-x) - y == -(x + y)
z.uadd(x, y)
} else {
// x - y == x - y == -(y - x)
// (-x) - (-y) == y - x == -(x - y)
if x.ucmp(y) > 0 {
z.usub(x, y)
} else {
z.neg = !z.neg
z.usub(y, x)
}
}
if z.form == zero && z.mode == ToNegativeInf && z.acc == Exact {
z.neg = true
}
return z
}
if x.form == inf && y.form == inf && x.neg == y.neg {
// +Inf - +Inf
// -Inf - -Inf
// value of z is undefined but make sure it's valid
z.acc = Exact
z.form = zero
z.neg = false
panic(ErrNaN{"subtraction of infinities with equal signs"})
}
if x.form == zero && y.form == zero {
// ±0 - ±0
z.acc = Exact
z.form = zero
z.neg = x.neg && !y.neg // -0 - +0 == -0
return z
}
if x.form == inf || y.form == zero {
// ±Inf - y
// x - ±0
return z.Set(x)
}
// ±0 - y
// x - ±Inf
return z.Neg(y)
}
// Mul sets z to the rounded product x*y and returns z.
// Precision, rounding, and accuracy reporting are as for Add.
// Mul panics with ErrNaN if one operand is zero and the other
// operand an infinity. The value of z is undefined in that case.
func (z *Float) Mul(x, y *Float) *Float {
if debugFloat {
x.validate()
y.validate()
}
if z.prec == 0 {
z.prec = umax32(x.prec, y.prec)
}
z.neg = x.neg != y.neg
if x.form == finite && y.form == finite {
// x * y (common case)
z.umul(x, y)
return z
}
z.acc = Exact
if x.form == zero && y.form == inf || x.form == inf && y.form == zero {
// ±0 * ±Inf
// ±Inf * ±0
// value of z is undefined but make sure it's valid
z.form = zero
z.neg = false
panic(ErrNaN{"multiplication of zero with infinity"})
}
if x.form == inf || y.form == inf {
// ±Inf * y
// x * ±Inf
z.form = inf
return z
}
// ±0 * y
// x * ±0
z.form = zero
return z
}
// Quo sets z to the rounded quotient x/y and returns z.
// Precision, rounding, and accuracy reporting are as for Add.
// Quo panics with ErrNaN if both operands are zero or infinities.
// The value of z is undefined in that case.
func (z *Float) Quo(x, y *Float) *Float {
if debugFloat {
x.validate()
y.validate()
}
if z.prec == 0 {
z.prec = umax32(x.prec, y.prec)
}
z.neg = x.neg != y.neg
if x.form == finite && y.form == finite {
// x / y (common case)
z.uquo(x, y)
return z
}
z.acc = Exact
if x.form == zero && y.form == zero || x.form == inf && y.form == inf {
// ±0 / ±0
// ±Inf / ±Inf
// value of z is undefined but make sure it's valid
z.form = zero
z.neg = false
panic(ErrNaN{"division of zero by zero or infinity by infinity"})
}
if x.form == zero || y.form == inf {
// ±0 / y
// x / ±Inf
z.form = zero
return z
}
// x / ±0
// ±Inf / y
z.form = inf
return z
}
// Cmp compares x and y and returns:
//
// -1 if x < y
// 0 if x == y (incl. -0 == 0, -Inf == -Inf, and +Inf == +Inf)
// +1 if x > y
func (x *Float) Cmp(y *Float) int {
if debugFloat {
x.validate()
y.validate()
}
mx := x.ord()
my := y.ord()
switch {
case mx < my:
return -1
case mx > my:
return +1
}
// mx == my
// only if |mx| == 1 do we have to compare the mantissae
switch mx {
case -1:
return y.ucmp(x)
case +1:
return x.ucmp(y)
}
return 0
}
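// Illustrative sketch (editor's addition): Cmp treats -0 and +0 as equal,
// and orders infinities relative to all finite values via ord.
//
//	package main
//
//	import (
//		"fmt"
//		"math/big"
//	)
//
//	func main() {
//		negZero := new(big.Float).Neg(new(big.Float)) // -0
//		fmt.Println(negZero.Cmp(new(big.Float)))      // 0 (-0 == +0)
//
//		inf := new(big.Float).SetInf(false)
//		fmt.Println(big.NewFloat(1e300).Cmp(inf)) // -1 (finite < +Inf)
//	}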
// ord classifies x and returns:
//
// -2 if -Inf == x
// -1 if -Inf < x < 0
// 0 if x == 0 (signed or unsigned)
// +1 if 0 < x < +Inf
// +2 if x == +Inf
func (x *Float) ord() int {
var m int
switch x.form {
case finite:
m = 1
case zero:
return 0
case inf:
m = 2
}
if x.neg {
m = -m
}
return m
}
func umax32(x, y uint32) uint32 {
if x > y {
return x
}
return y
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements string-to-Float conversion functions.
package big
import (
"fmt"
"io"
"strings"
)
var floatZero Float
// SetString sets z to the value of s and returns z and a boolean indicating
// success. s must be a floating-point number of the same format as accepted
// by Parse, with base argument 0. The entire string (not just a prefix) must
// be valid for success. If the operation failed, the value of z is undefined
// but the returned value is nil.
func (z *Float) SetString(s string) (*Float, bool) {
if f, _, err := z.Parse(s, 0); err == nil {
return f, true
}
return nil, false
}
// scan is like Parse but reads the longest possible prefix representing a valid
// floating point number from an io.ByteScanner rather than a string. It serves
// as the implementation of Parse. It does not recognize ±Inf and does not expect
// EOF at the end.
func (z *Float) scan(r io.ByteScanner, base int) (f *Float, b int, err error) {
prec := z.prec
if prec == 0 {
prec = 64
}
// A reasonable value in case of an error.
z.form = zero
// sign
z.neg, err = scanSign(r)
if err != nil {
return
}
// mantissa
var fcount int // fractional digit count; valid if <= 0
z.mant, b, fcount, err = z.mant.scan(r, base, true)
if err != nil {
return
}
// exponent
var exp int64
var ebase int
exp, ebase, err = scanExponent(r, true, base == 0)
if err != nil {
return
}
// special-case 0
if len(z.mant) == 0 {
z.prec = prec
z.acc = Exact
z.form = zero
f = z
return
}
// len(z.mant) > 0
// The mantissa may have a radix point (fcount <= 0) and there
// may be a nonzero exponent exp. The radix point amounts to a
// division by b**(-fcount). An exponent means multiplication by
// ebase**exp. Finally, mantissa normalization (shift left) requires
// a correcting multiplication by 2**(-shiftcount). Multiplications
// are commutative, so we can apply them in any order as long as there
// is no loss of precision. We only have powers of 2 and 10, and
// we split powers of 10 into the product of the same powers of
// 2 and 5. This reduces the size of the multiplication factor
// needed for base-10 exponents.
// normalize mantissa and determine initial exponent contributions
exp2 := int64(len(z.mant))*_W - fnorm(z.mant)
exp5 := int64(0)
// determine binary or decimal exponent contribution of radix point
if fcount < 0 {
// The mantissa has a radix point ddd.dddd; and
// -fcount is the number of digits to the right
// of '.'. Adjust relevant exponent accordingly.
d := int64(fcount)
switch b {
case 10:
exp5 = d
fallthrough // 10**e == 5**e * 2**e
case 2:
exp2 += d
case 8:
exp2 += d * 3 // octal digits are 3 bits each
case 16:
exp2 += d * 4 // hexadecimal digits are 4 bits each
default:
panic("unexpected mantissa base")
}
// fcount consumed - not needed anymore
}
// take actual exponent into account
switch ebase {
case 10:
exp5 += exp
fallthrough // see fallthrough above
case 2:
exp2 += exp
default:
panic("unexpected exponent base")
}
// exp consumed - not needed anymore
// apply 2**exp2
if MinExp <= exp2 && exp2 <= MaxExp {
z.prec = prec
z.form = finite
z.exp = int32(exp2)
f = z
} else {
err = fmt.Errorf("exponent overflow")
return
}
if exp5 == 0 {
// no decimal exponent contribution
z.round(0)
return
}
// exp5 != 0
// apply 5**exp5
p := new(Float).SetPrec(z.Prec() + 64) // use more bits for p -- TODO(gri) what is the right number?
if exp5 < 0 {
z.Quo(z, p.pow5(uint64(-exp5)))
} else {
z.Mul(z, p.pow5(uint64(exp5)))
}
return
}
// These powers of 5 fit into a uint64.
//
// for p, q := uint64(0), uint64(1); p < q; p, q = q, q*5 {
// fmt.Println(q)
// }
var pow5tab = [...]uint64{
1,
5,
25,
125,
625,
3125,
15625,
78125,
390625,
1953125,
9765625,
48828125,
244140625,
1220703125,
6103515625,
30517578125,
152587890625,
762939453125,
3814697265625,
19073486328125,
95367431640625,
476837158203125,
2384185791015625,
11920928955078125,
59604644775390625,
298023223876953125,
1490116119384765625,
7450580596923828125,
}
// pow5 sets z to 5**n and returns z.
// n must not be negative.
func (z *Float) pow5(n uint64) *Float {
const m = uint64(len(pow5tab) - 1)
if n <= m {
return z.SetUint64(pow5tab[n])
}
// n > m
z.SetUint64(pow5tab[m])
n -= m
// use more bits for f than for z
// TODO(gri) what is the right number?
f := new(Float).SetPrec(z.Prec() + 64).SetUint64(5)
for n > 0 {
if n&1 != 0 {
z.Mul(z, f)
}
f.Mul(f, f)
n >>= 1
}
return z
}
// Parse parses s which must contain a text representation of a floating-
// point number with a mantissa in the given conversion base (the exponent
// is always a decimal number), or a string representing an infinite value.
//
// For base 0, an underscore character “_” may appear between a base
// prefix and an adjacent digit, and between successive digits; such
// underscores do not change the value of the number, or the returned
// digit count. Incorrect placement of underscores is reported as an
// error if there are no other errors. If base != 0, underscores are
// not recognized and thus terminate scanning like any other character
// that is not a valid radix point or digit.
//
// It sets z to the (possibly rounded) value of the corresponding floating-
// point value, and returns z, the actual base b, and an error err, if any.
// The entire string (not just a prefix) must be consumed for success.
// If z's precision is 0, it is changed to 64 before rounding takes effect.
// The number must be of the form:
//
// number = [ sign ] ( float | "inf" | "Inf" ) .
// sign = "+" | "-" .
// float = ( mantissa | prefix pmantissa ) [ exponent ] .
// prefix = "0" [ "b" | "B" | "o" | "O" | "x" | "X" ] .
// mantissa = digits "." [ digits ] | digits | "." digits .
// pmantissa = [ "_" ] digits "." [ digits ] | [ "_" ] digits | "." digits .
// exponent = ( "e" | "E" | "p" | "P" ) [ sign ] digits .
// digits = digit { [ "_" ] digit } .
// digit = "0" ... "9" | "a" ... "z" | "A" ... "Z" .
//
// The base argument must be 0, 2, 8, 10, or 16. Providing an invalid base
// argument will lead to a run-time panic.
//
// For base 0, the number prefix determines the actual base: A prefix of
// “0b” or “0B” selects base 2, “0o” or “0O” selects base 8, and
// “0x” or “0X” selects base 16. Otherwise, the actual base is 10 and
// no prefix is accepted. The octal prefix "0" is not supported (a leading
// "0" is simply considered a "0").
//
// A "p" or "P" exponent indicates a base 2 (rather then base 10) exponent;
// for instance, "0x1.fffffffffffffp1023" (using base 0) represents the
// maximum float64 value. For hexadecimal mantissae, the exponent character
// must be one of 'p' or 'P', if present (an "e" or "E" exponent indicator
// cannot be distinguished from a mantissa digit).
//
// The returned *Float f is nil and the value of z is valid but not
// defined if an error is reported.
func (z *Float) Parse(s string, base int) (f *Float, b int, err error) {
// scan doesn't handle ±Inf
if len(s) == 3 && (s == "Inf" || s == "inf") {
f = z.SetInf(false)
return
}
if len(s) == 4 && (s[0] == '+' || s[0] == '-') && (s[1:] == "Inf" || s[1:] == "inf") {
f = z.SetInf(s[0] == '-')
return
}
r := strings.NewReader(s)
if f, b, err = z.scan(r, base); err != nil {
return
}
// entire string must have been consumed
if ch, err2 := r.ReadByte(); err2 == nil {
err = fmt.Errorf("expected end of string, found %q", ch)
} else if err2 != io.EOF {
err = err2
}
return
}
// ParseFloat is like f.Parse(s, base) with f set to the given precision
// and rounding mode.
func ParseFloat(s string, base int, prec uint, mode RoundingMode) (f *Float, b int, err error) {
return new(Float).SetPrec(prec).SetMode(mode).Parse(s, base)
}
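// Illustrative sketch (editor's addition): parsing the hexadecimal float
// literal mentioned in the Parse documentation. With base 0 the 0x prefix
// selects base 16, and "p" introduces a base-2 exponent.
//
//	package main
//
//	import (
//		"fmt"
//		"math"
//		"math/big"
//	)
//
//	func main() {
//		f, b, err := big.ParseFloat("0x1.fffffffffffffp1023", 0, 53, big.ToNearestEven)
//		fmt.Println(b, err) // 16 <nil>
//		v, acc := f.Float64()
//		fmt.Println(v == math.MaxFloat64, acc) // true Exact
//	}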
var _ fmt.Scanner = (*Float)(nil) // *Float must implement fmt.Scanner
// Scan is a support routine for fmt.Scanner; it sets z to the value of
// the scanned number. It accepts formats whose verbs are supported by
// fmt.Scan for floating point values, which are:
// 'b' (binary), 'e', 'E', 'f', 'F', 'g' and 'G'.
// Scan doesn't handle ±Inf.
func (z *Float) Scan(s fmt.ScanState, ch rune) error {
s.SkipSpace()
_, _, err := z.scan(byteReader{s}, 0)
return err
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements encoding/decoding of Floats.
package big
import (
"encoding/binary"
"errors"
"fmt"
)
// Gob codec version. Permits backward-compatible changes to the encoding.
const floatGobVersion byte = 1
// GobEncode implements the gob.GobEncoder interface.
// The Float value and all its attributes (precision,
// rounding mode, accuracy) are marshaled.
func (x *Float) GobEncode() ([]byte, error) {
if x == nil {
return nil, nil
}
// determine max. space (bytes) required for encoding
sz := 1 + 1 + 4 // version + mode|acc|form|neg (3+2+2+1bit) + prec
n := 0 // number of mantissa words
if x.form == finite {
// add space for mantissa and exponent
n = int((x.prec + (_W - 1)) / _W) // required mantissa length in words for given precision
// actual mantissa slice could be shorter (trailing 0's) or longer (unused bits):
// - if shorter, only encode the words present
// - if longer, cut off unused words when encoding in bytes
// (in practice, this should never happen since rounding
// takes care of it, but be safe and do it always)
if len(x.mant) < n {
n = len(x.mant)
}
// len(x.mant) >= n
sz += 4 + n*_S // exp + mant
}
buf := make([]byte, sz)
buf[0] = floatGobVersion
b := byte(x.mode&7)<<5 | byte((x.acc+1)&3)<<3 | byte(x.form&3)<<1
if x.neg {
b |= 1
}
buf[1] = b
binary.BigEndian.PutUint32(buf[2:], x.prec)
if x.form == finite {
binary.BigEndian.PutUint32(buf[6:], uint32(x.exp))
x.mant[len(x.mant)-n:].bytes(buf[10:]) // cut off unused trailing words
}
return buf, nil
}
// GobDecode implements the gob.GobDecoder interface.
// The result is rounded per the precision and rounding mode of
// z unless z's precision is 0, in which case z is set exactly
// to the decoded value.
func (z *Float) GobDecode(buf []byte) error {
if len(buf) == 0 {
// Other side sent a nil or default value.
*z = Float{}
return nil
}
if len(buf) < 6 {
return errors.New("Float.GobDecode: buffer too small")
}
if buf[0] != floatGobVersion {
return fmt.Errorf("Float.GobDecode: encoding version %d not supported", buf[0])
}
oldPrec := z.prec
oldMode := z.mode
b := buf[1]
z.mode = RoundingMode((b >> 5) & 7)
z.acc = Accuracy((b>>3)&3) - 1
z.form = form((b >> 1) & 3)
z.neg = b&1 != 0
z.prec = binary.BigEndian.Uint32(buf[2:])
if z.form == finite {
if len(buf) < 10 {
return errors.New("Float.GobDecode: buffer too small for finite form float")
}
z.exp = int32(binary.BigEndian.Uint32(buf[6:]))
z.mant = z.mant.setBytes(buf[10:])
}
if oldPrec != 0 {
z.mode = oldMode
z.SetPrec(uint(oldPrec))
}
if msg := z.validate0(); msg != "" {
return errors.New("Float.GobDecode: " + msg)
}
return nil
}
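// Illustrative sketch (editor's addition): a gob round trip. Encoding
// preserves precision, rounding mode, and accuracy; a zero-precision
// destination decodes the value exactly. Assumes only the exported API.
//
//	package main
//
//	import (
//		"bytes"
//		"encoding/gob"
//		"fmt"
//		"math/big"
//	)
//
//	func main() {
//		x := new(big.Float).SetPrec(100)
//		x.SetString("1.0000000000000000000000000001")
//
//		var buf bytes.Buffer
//		if err := gob.NewEncoder(&buf).Encode(x); err != nil {
//			panic(err)
//		}
//		y := new(big.Float) // precision 0: decode exactly
//		if err := gob.NewDecoder(&buf).Decode(y); err != nil {
//			panic(err)
//		}
//		fmt.Println(x.Cmp(y) == 0, y.Prec()) // true 100
//	}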
// MarshalText implements the encoding.TextMarshaler interface.
// Only the Float value is marshaled (in full precision), other
// attributes such as precision or accuracy are ignored.
func (x *Float) MarshalText() (text []byte, err error) {
if x == nil {
return []byte("<nil>"), nil
}
var buf []byte
return x.Append(buf, 'g', -1), nil
}
// UnmarshalText implements the encoding.TextUnmarshaler interface.
// The result is rounded per the precision and rounding mode of z.
// If z's precision is 0, it is changed to 64 before rounding takes
// effect.
func (z *Float) UnmarshalText(text []byte) error {
// TODO(gri): get rid of the []byte/string conversion
_, _, err := z.Parse(string(text), 0)
if err != nil {
err = fmt.Errorf("math/big: cannot unmarshal %q into a *big.Float (%v)", text, err)
}
return err
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements Float-to-string conversion functions.
// It is closely following the corresponding implementation
// in strconv/ftoa.go, but modified and simplified for Float.
package big
import (
"bytes"
"fmt"
"strconv"
)
// Text converts the floating-point number x to a string according
// to the given format and precision prec. The format is one of:
//
// 'e' -d.dddde±dd, decimal exponent, at least two (possibly 0) exponent digits
// 'E' -d.ddddE±dd, decimal exponent, at least two (possibly 0) exponent digits
// 'f' -ddddd.dddd, no exponent
// 'g' like 'e' for large exponents, like 'f' otherwise
// 'G' like 'E' for large exponents, like 'f' otherwise
// 'x' -0xd.dddddp±dd, hexadecimal mantissa, decimal power of two exponent
// 'p' -0x.dddp±dd, hexadecimal mantissa, decimal power of two exponent (non-standard)
// 'b' -ddddddp±dd, decimal mantissa, decimal power of two exponent (non-standard)
//
// For the power-of-two exponent formats, the mantissa is printed in normalized form:
//
// 'x' hexadecimal mantissa in [1, 2), or 0
// 'p' hexadecimal mantissa in [½, 1), or 0
// 'b' decimal integer mantissa using x.Prec() bits, or 0
//
// Note that the 'x' form is the one used by most other languages and libraries.
//
// If format is a different character, Text returns a "%" followed by the
// unrecognized format character.
//
// The precision prec controls the number of digits (excluding the exponent)
// printed by the 'e', 'E', 'f', 'g', 'G', and 'x' formats.
// For 'e', 'E', 'f', and 'x', it is the number of digits after the decimal point.
// For 'g' and 'G' it is the total number of digits. A negative precision selects
// the smallest number of decimal digits necessary to identify the value x uniquely
// using x.Prec() mantissa bits.
// The prec value is ignored for the 'b' and 'p' formats.
func (x *Float) Text(format byte, prec int) string {
cap := 10 // TODO(gri) determine a good/better value here
if prec > 0 {
cap += prec
}
return string(x.Append(make([]byte, 0, cap), format, prec))
}
// String formats x like x.Text('g', 10).
// (String must be called explicitly, Float.Format does not support %s verb.)
func (x *Float) String() string {
return x.Text('g', 10)
}
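// Illustrative sketch (editor's addition): the main Text formats at work.
//
//	package main
//
//	import (
//		"fmt"
//		"math/big"
//	)
//
//	func main() {
//		x := big.NewFloat(1234.5678)
//		fmt.Println(x.Text('e', 4))  // 1.2346e+03
//		fmt.Println(x.Text('f', 2))  // 1234.57
//		fmt.Println(x.Text('g', -1)) // 1234.5678 (shortest form that round-trips)
//		fmt.Println(x.Text('x', -1)) // hexadecimal mantissa with binary exponent
//	}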
// Append appends to buf the string form of the floating-point number x,
// as generated by x.Text, and returns the extended buffer.
func (x *Float) Append(buf []byte, fmt byte, prec int) []byte {
// sign
if x.neg {
buf = append(buf, '-')
}
// Inf
if x.form == inf {
if !x.neg {
buf = append(buf, '+')
}
return append(buf, "Inf"...)
}
// pick off easy formats
switch fmt {
case 'b':
return x.fmtB(buf)
case 'p':
return x.fmtP(buf)
case 'x':
return x.fmtX(buf, prec)
}
// Algorithm:
// 1) convert Float to multiprecision decimal
// 2) round to desired precision
// 3) read digits out and format
// 1) convert Float to multiprecision decimal
var d decimal // == 0.0
if x.form == finite {
// x != 0
d.init(x.mant, int(x.exp)-x.mant.bitLen())
}
// 2) round to desired precision
shortest := false
if prec < 0 {
shortest = true
roundShortest(&d, x)
// Precision for shortest representation mode.
switch fmt {
case 'e', 'E':
prec = len(d.mant) - 1
case 'f':
prec = max(len(d.mant)-d.exp, 0)
case 'g', 'G':
prec = len(d.mant)
}
} else {
// round appropriately
switch fmt {
case 'e', 'E':
// one digit before and number of digits after decimal point
d.round(1 + prec)
case 'f':
// number of digits before and after decimal point
d.round(d.exp + prec)
case 'g', 'G':
if prec == 0 {
prec = 1
}
d.round(prec)
}
}
// 3) read digits out and format
switch fmt {
case 'e', 'E':
return fmtE(buf, fmt, prec, d)
case 'f':
return fmtF(buf, prec, d)
case 'g', 'G':
// trim trailing fractional zeros in %e format
eprec := prec
if eprec > len(d.mant) && len(d.mant) >= d.exp {
eprec = len(d.mant)
}
// %e is used if the exponent from the conversion
// is less than -4 or greater than or equal to the precision.
// If precision was the shortest possible, use eprec = 6 for
// this decision.
if shortest {
eprec = 6
}
exp := d.exp - 1
if exp < -4 || exp >= eprec {
if prec > len(d.mant) {
prec = len(d.mant)
}
return fmtE(buf, fmt+'e'-'g', prec-1, d)
}
if prec > d.exp {
prec = len(d.mant)
}
return fmtF(buf, max(prec-d.exp, 0), d)
}
// unknown format
if x.neg {
buf = buf[:len(buf)-1] // sign was added prematurely - remove it again
}
return append(buf, '%', fmt)
}
func roundShortest(d *decimal, x *Float) {
// if the mantissa is zero, the number is zero - stop now
if len(d.mant) == 0 {
return
}
// Approach: All numbers in the interval [x - 1/2ulp, x + 1/2ulp]
// (possibly exclusive) round to x for the given precision of x.
// Compute the lower and upper bound in decimal form and find the
// shortest decimal number d such that lower <= d <= upper.
// TODO(gri) strconv/ftoa.go describes a shortcut in some cases.
// See if we can use it (in adjusted form) here as well.
// 1) Compute normalized mantissa mant and exponent exp for x such
// that the lsb of mant corresponds to 1/2 ulp for the precision of
// x (i.e., for mant we want x.prec + 1 bits).
mant := nat(nil).set(x.mant)
exp := int(x.exp) - mant.bitLen()
s := mant.bitLen() - int(x.prec+1)
switch {
case s < 0:
mant = mant.shl(mant, uint(-s))
case s > 0:
mant = mant.shr(mant, uint(+s))
}
exp += s
// x = mant * 2**exp with lsb(mant) == 1/2 ulp of x.prec
// 2) Compute lower bound by subtracting 1/2 ulp.
var lower decimal
var tmp nat
lower.init(tmp.sub(mant, natOne), exp)
// 3) Compute upper bound by adding 1/2 ulp.
var upper decimal
upper.init(tmp.add(mant, natOne), exp)
// The upper and lower bounds are possible outputs only if
// the original mantissa is even, so that ToNearestEven rounding
// would round to the original mantissa and not the neighbors.
inclusive := mant[0]&2 == 0 // test bit 1 since original mantissa was shifted by 1
// Now we can figure out the minimum number of digits required.
// Walk along until d has distinguished itself from upper and lower.
for i, m := range d.mant {
l := lower.at(i)
u := upper.at(i)
// Okay to round down (truncate) if lower has a different digit
// or if lower is inclusive and is exactly the result of rounding
// down (i.e., we have reached the final digit of lower).
okdown := l != m || inclusive && i+1 == len(lower.mant)
// Okay to round up if upper has a different digit and either upper
// is inclusive or upper is bigger than the result of rounding up.
okup := m != u && (inclusive || m+1 < u || i+1 < len(upper.mant))
// If it's okay to do either, then round to the nearest one.
// If it's okay to do only one, do it.
switch {
case okdown && okup:
d.round(i + 1)
return
case okdown:
d.roundDown(i + 1)
return
case okup:
d.roundUp(i + 1)
return
}
}
}
// %e: d.ddddde±dd
func fmtE(buf []byte, fmt byte, prec int, d decimal) []byte {
// first digit
ch := byte('0')
if len(d.mant) > 0 {
ch = d.mant[0]
}
buf = append(buf, ch)
// .moredigits
if prec > 0 {
buf = append(buf, '.')
i := 1
m := min(len(d.mant), prec+1)
if i < m {
buf = append(buf, d.mant[i:m]...)
i = m
}
for ; i <= prec; i++ {
buf = append(buf, '0')
}
}
// e±
buf = append(buf, fmt)
var exp int64
if len(d.mant) > 0 {
exp = int64(d.exp) - 1 // -1 because first digit was printed before '.'
}
if exp < 0 {
ch = '-'
exp = -exp
} else {
ch = '+'
}
buf = append(buf, ch)
// dd...d
if exp < 10 {
buf = append(buf, '0') // at least 2 exponent digits
}
return strconv.AppendInt(buf, exp, 10)
}
// %f: ddddddd.ddddd
func fmtF(buf []byte, prec int, d decimal) []byte {
// integer, padded with zeros as needed
if d.exp > 0 {
m := min(len(d.mant), d.exp)
buf = append(buf, d.mant[:m]...)
for ; m < d.exp; m++ {
buf = append(buf, '0')
}
} else {
buf = append(buf, '0')
}
// fraction
if prec > 0 {
buf = append(buf, '.')
for i := 0; i < prec; i++ {
buf = append(buf, d.at(d.exp+i))
}
}
return buf
}
// fmtB appends the string of x in the format mantissa "p" exponent
// with a decimal mantissa and a binary exponent, or "0" if x is zero,
// and returns the extended buffer.
// The mantissa is normalized such that it uses x.Prec() bits in binary
// representation.
// The sign of x is ignored, and x must not be an Inf.
// (The caller handles Inf before invoking fmtB.)
func (x *Float) fmtB(buf []byte) []byte {
if x.form == zero {
return append(buf, '0')
}
if debugFloat && x.form != finite {
panic("non-finite float")
}
// x != 0
// adjust mantissa to use exactly x.prec bits
m := x.mant
switch w := uint32(len(x.mant)) * _W; {
case w < x.prec:
m = nat(nil).shl(m, uint(x.prec-w))
case w > x.prec:
m = nat(nil).shr(m, uint(w-x.prec))
}
buf = append(buf, m.utoa(10)...)
buf = append(buf, 'p')
e := int64(x.exp) - int64(x.prec)
if e >= 0 {
buf = append(buf, '+')
}
return strconv.AppendInt(buf, e, 10)
}
// fmtX appends the string of x in the format "0x1." mantissa "p" exponent
// with a hexadecimal mantissa and a binary exponent, or "0x0p0" if x is zero,
// and returns the extended buffer.
// A non-zero mantissa is normalized such that 1.0 <= mantissa < 2.0.
// The sign of x is ignored, and x must not be an Inf.
// (The caller handles Inf before invoking fmtX.)
func (x *Float) fmtX(buf []byte, prec int) []byte {
if x.form == zero {
buf = append(buf, "0x0"...)
if prec > 0 {
buf = append(buf, '.')
for i := 0; i < prec; i++ {
buf = append(buf, '0')
}
}
buf = append(buf, "p+00"...)
return buf
}
if debugFloat && x.form != finite {
panic("non-finite float")
}
// round mantissa to n bits
var n uint
if prec < 0 {
n = 1 + (x.MinPrec()-1+3)/4*4 // round MinPrec up to 1 mod 4
} else {
n = 1 + 4*uint(prec)
}
// n%4 == 1
x = new(Float).SetPrec(n).SetMode(x.mode).Set(x)
// adjust mantissa to use exactly n bits
m := x.mant
switch w := uint(len(x.mant)) * _W; {
case w < n:
m = nat(nil).shl(m, n-w)
case w > n:
m = nat(nil).shr(m, w-n)
}
exp64 := int64(x.exp) - 1 // avoid wrap-around
hm := m.utoa(16)
if debugFloat && hm[0] != '1' {
panic("incorrect mantissa: " + string(hm))
}
buf = append(buf, "0x1"...)
if len(hm) > 1 {
buf = append(buf, '.')
buf = append(buf, hm[1:]...)
}
buf = append(buf, 'p')
if exp64 >= 0 {
buf = append(buf, '+')
} else {
exp64 = -exp64
buf = append(buf, '-')
}
// Force at least two exponent digits, to match fmt.
if exp64 < 10 {
buf = append(buf, '0')
}
return strconv.AppendInt(buf, exp64, 10)
}
// fmtP appends the string of x in the format "0x." mantissa "p" exponent
// with a hexadecimal mantissa and a binary exponent, or "0" if x is zero,
// and returns the extended buffer.
// The mantissa is normalized such that 0.5 <= 0.mantissa < 1.0.
// The sign of x is ignored, and x must not be an Inf.
// (The caller handles Inf before invoking fmtP.)
func (x *Float) fmtP(buf []byte) []byte {
if x.form == zero {
return append(buf, '0')
}
if debugFloat && x.form != finite {
panic("non-finite float")
}
// x != 0
// remove trailing 0 words early
// (no need to convert to hex 0's and trim later)
m := x.mant
i := 0
for i < len(m) && m[i] == 0 {
i++
}
m = m[i:]
buf = append(buf, "0x."...)
buf = append(buf, bytes.TrimRight(m.utoa(16), "0")...)
buf = append(buf, 'p')
if x.exp >= 0 {
buf = append(buf, '+')
}
return strconv.AppendInt(buf, int64(x.exp), 10)
}
func min(x, y int) int {
if x < y {
return x
}
return y
}
var _ fmt.Formatter = &floatZero // *Float must implement fmt.Formatter
// Format implements fmt.Formatter. It accepts all the regular
// formats for floating-point numbers ('b', 'e', 'E', 'f', 'F',
// 'g', 'G', 'x') as well as 'p' and 'v'. See (*Float).Text for the
// interpretation of 'p'. The 'v' format is handled like 'g'.
// Format also supports specification of the minimum precision
// in digits, the output field width, as well as the format flags
// '+' and ' ' for sign control, '0' for space or zero padding,
// and '-' for left or right justification. See the fmt package
// for details.
func (x *Float) Format(s fmt.State, format rune) {
prec, hasPrec := s.Precision()
if !hasPrec {
prec = 6 // default precision for 'e', 'f'
}
switch format {
case 'e', 'E', 'f', 'b', 'p', 'x':
// nothing to do
case 'F':
// (*Float).Text doesn't support 'F'; handle like 'f'
format = 'f'
case 'v':
// handle like 'g'
format = 'g'
fallthrough
case 'g', 'G':
if !hasPrec {
prec = -1 // default precision for 'g', 'G'
}
default:
fmt.Fprintf(s, "%%!%c(*big.Float=%s)", format, x.String())
return
}
var buf []byte
buf = x.Append(buf, byte(format), prec)
if len(buf) == 0 {
buf = []byte("?") // should never happen, but don't crash
}
// len(buf) > 0
var sign string
switch {
case buf[0] == '-':
sign = "-"
buf = buf[1:]
case buf[0] == '+':
// +Inf
sign = "+"
if s.Flag(' ') {
sign = " "
}
buf = buf[1:]
case s.Flag('+'):
sign = "+"
case s.Flag(' '):
sign = " "
}
var padding int
if width, hasWidth := s.Width(); hasWidth && width > len(sign)+len(buf) {
padding = width - len(sign) - len(buf)
}
switch {
case s.Flag('0') && !x.IsInf():
// 0-padding on left
writeMultiple(s, sign, 1)
writeMultiple(s, "0", padding)
s.Write(buf)
case s.Flag('-'):
// padding on right
writeMultiple(s, sign, 1)
s.Write(buf)
writeMultiple(s, " ", padding)
default:
// padding on left
writeMultiple(s, " ", padding)
writeMultiple(s, sign, 1)
s.Write(buf)
}
}
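// Illustrative sketch (editor's addition): width, precision, and sign flags
// handled by Format.
//
//	package main
//
//	import (
//		"fmt"
//		"math/big"
//	)
//
//	func main() {
//		x := big.NewFloat(3.14159)
//		fmt.Printf("|%10.4f|\n", x)  // |    3.1416|
//		fmt.Printf("|%-10.4f|\n", x) // |3.1416    |
//		fmt.Printf("|%+.4f|\n", x)   // |+3.1416|
//		fmt.Printf("%v\n", x)        // 3.14159 ('v' is handled like 'g')
//	}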
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements signed multi-precision integers.
package big
import (
"fmt"
"io"
"math/rand"
"strings"
)
// An Int represents a signed multi-precision integer.
// The zero value for an Int represents the value 0.
//
// Operations always take pointer arguments (*Int) rather
// than Int values, and each unique Int value requires
// its own unique *Int pointer. To "copy" an Int value,
// an existing (or newly allocated) Int must be set to
// a new value using the Int.Set method; shallow copies
// of Ints are not supported and may lead to errors.
//
// Note that methods may leak the Int's value through timing side-channels.
// Because of this and because of the scope and complexity of the
// implementation, Int is not well-suited to implement cryptographic operations.
// The standard library avoids exposing non-trivial Int methods to
// attacker-controlled inputs and the determination of whether a bug in math/big
// is considered a security vulnerability might depend on the impact on the
// standard library.
type Int struct {
neg bool // sign
abs nat // absolute value of the integer
}
var intOne = &Int{false, natOne}
// Sign returns:
//
// -1 if x < 0
// 0 if x == 0
// +1 if x > 0
func (x *Int) Sign() int {
// This function is used in cryptographic operations. It must not leak
// anything but the Int's sign and bit size through side-channels. Any
// changes must be reviewed by a security expert.
if len(x.abs) == 0 {
return 0
}
if x.neg {
return -1
}
return 1
}
// SetInt64 sets z to x and returns z.
func (z *Int) SetInt64(x int64) *Int {
neg := false
if x < 0 {
neg = true
x = -x
}
z.abs = z.abs.setUint64(uint64(x))
z.neg = neg
return z
}
// SetUint64 sets z to x and returns z.
func (z *Int) SetUint64(x uint64) *Int {
z.abs = z.abs.setUint64(x)
z.neg = false
return z
}
// NewInt allocates and returns a new Int set to x.
func NewInt(x int64) *Int {
// This code is arranged to be inlineable and produce
// zero allocations when inlined. See issue 29951.
u := uint64(x)
if x < 0 {
u = -u
}
var abs []Word
if x == 0 {
} else if _W == 32 && u>>32 != 0 {
abs = []Word{Word(u), Word(u >> 32)}
} else {
abs = []Word{Word(u)}
}
return &Int{neg: x < 0, abs: abs}
}
// Set sets z to x and returns z.
func (z *Int) Set(x *Int) *Int {
if z != x {
z.abs = z.abs.set(x.abs)
z.neg = x.neg
}
return z
}
// Bits provides raw (unchecked but fast) access to x by returning its
// absolute value as a little-endian Word slice. The result and x share
// the same underlying array.
// Bits is intended to support implementation of missing low-level Int
// functionality outside this package; it should be avoided otherwise.
func (x *Int) Bits() []Word {
// This function is used in cryptographic operations. It must not leak
// anything but the Int's sign and bit size through side-channels. Any
// changes must be reviewed by a security expert.
return x.abs
}
// SetBits provides raw (unchecked but fast) access to z by setting its
// value to abs, interpreted as a little-endian Word slice, and returning
// z. The result and abs share the same underlying array.
// SetBits is intended to support implementation of missing low-level Int
// functionality outside this package; it should be avoided otherwise.
func (z *Int) SetBits(abs []Word) *Int {
z.abs = nat(abs).norm()
z.neg = false
return z
}
// Abs sets z to |x| (the absolute value of x) and returns z.
func (z *Int) Abs(x *Int) *Int {
z.Set(x)
z.neg = false
return z
}
// Neg sets z to -x and returns z.
func (z *Int) Neg(x *Int) *Int {
z.Set(x)
z.neg = len(z.abs) > 0 && !z.neg // 0 has no sign
return z
}
// Add sets z to the sum x+y and returns z.
func (z *Int) Add(x, y *Int) *Int {
neg := x.neg
if x.neg == y.neg {
// x + y == x + y
// (-x) + (-y) == -(x + y)
z.abs = z.abs.add(x.abs, y.abs)
} else {
// x + (-y) == x - y == -(y - x)
// (-x) + y == y - x == -(x - y)
if x.abs.cmp(y.abs) >= 0 {
z.abs = z.abs.sub(x.abs, y.abs)
} else {
neg = !neg
z.abs = z.abs.sub(y.abs, x.abs)
}
}
z.neg = len(z.abs) > 0 && neg // 0 has no sign
return z
}
// Sub sets z to the difference x-y and returns z.
func (z *Int) Sub(x, y *Int) *Int {
neg := x.neg
if x.neg != y.neg {
// x - (-y) == x + y
// (-x) - y == -(x + y)
z.abs = z.abs.add(x.abs, y.abs)
} else {
// x - y == x - y == -(y - x)
// (-x) - (-y) == y - x == -(x - y)
if x.abs.cmp(y.abs) >= 0 {
z.abs = z.abs.sub(x.abs, y.abs)
} else {
neg = !neg
z.abs = z.abs.sub(y.abs, x.abs)
}
}
z.neg = len(z.abs) > 0 && neg // 0 has no sign
return z
}
// Mul sets z to the product x*y and returns z.
func (z *Int) Mul(x, y *Int) *Int {
// x * y == x * y
// x * (-y) == -(x * y)
// (-x) * y == -(x * y)
// (-x) * (-y) == x * y
if x == y {
z.abs = z.abs.sqr(x.abs)
z.neg = false
return z
}
z.abs = z.abs.mul(x.abs, y.abs)
z.neg = len(z.abs) > 0 && x.neg != y.neg // 0 has no sign
return z
}
// MulRange sets z to the product of all integers
// in the range [a, b] inclusively and returns z.
// If a > b (empty range), the result is 1.
func (z *Int) MulRange(a, b int64) *Int {
switch {
case a > b:
return z.SetInt64(1) // empty range
case a <= 0 && b >= 0:
return z.SetInt64(0) // range includes 0
}
// a <= b && (b < 0 || a > 0)
neg := false
if a < 0 {
neg = (b-a)&1 == 0
a, b = -b, -a
}
z.abs = z.abs.mulRange(uint64(a), uint64(b))
z.neg = neg
return z
}
// Binomial sets z to the binomial coefficient C(n, k) and returns z.
func (z *Int) Binomial(n, k int64) *Int {
if k > n {
return z.SetInt64(0)
}
// reduce the number of multiplications by reducing k
if k > n-k {
k = n - k // C(n, k) == C(n, n-k)
}
// C(n, k) == n * (n-1) * ... * (n-k+1) / k * (k-1) * ... * 1
// == n * (n-1) * ... * (n-k+1) / 1 * (1+1) * ... * k
//
// Using the multiplicative formula produces smaller values
// at each step, requiring fewer allocations and computations:
//
// z = 1
// for i := 0; i < k; i = i+1 {
// z *= n-i
// z /= i+1
// }
//
// finally to avoid computing i+1 twice per loop:
//
// z = 1
// i := 0
// for i < k {
// z *= n-i
// i++
// z /= i
// }
var N, K, i, t Int
N.SetInt64(n)
K.SetInt64(k)
z.Set(intOne)
for i.Cmp(&K) < 0 {
z.Mul(z, t.Sub(&N, &i))
i.Add(&i, intOne)
z.Quo(z, &i)
}
return z
}
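// Illustrative sketch (editor's addition): MulRange as a factorial and
// Binomial for a coefficient whose intermediate products overflow int64.
//
//	package main
//
//	import (
//		"fmt"
//		"math/big"
//	)
//
//	func main() {
//		fmt.Println(new(big.Int).MulRange(1, 10))  // 3628800 (10!)
//		fmt.Println(new(big.Int).Binomial(50, 25)) // 126410606437752
//	}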
// Quo sets z to the quotient x/y for y != 0 and returns z.
// If y == 0, a division-by-zero run-time panic occurs.
// Quo implements truncated division (like Go); see QuoRem for more details.
func (z *Int) Quo(x, y *Int) *Int {
z.abs, _ = z.abs.div(nil, x.abs, y.abs)
z.neg = len(z.abs) > 0 && x.neg != y.neg // 0 has no sign
return z
}
// Rem sets z to the remainder x%y for y != 0 and returns z.
// If y == 0, a division-by-zero run-time panic occurs.
// Rem implements truncated modulus (like Go); see QuoRem for more details.
func (z *Int) Rem(x, y *Int) *Int {
_, z.abs = nat(nil).div(z.abs, x.abs, y.abs)
z.neg = len(z.abs) > 0 && x.neg // 0 has no sign
return z
}
// QuoRem sets z to the quotient x/y and r to the remainder x%y
// and returns the pair (z, r) for y != 0.
// If y == 0, a division-by-zero run-time panic occurs.
//
// QuoRem implements T-division and modulus (like Go):
//
// q = x/y with the result truncated to zero
// r = x - y*q
//
// (See Daan Leijen, “Division and Modulus for Computer Scientists”.)
// See DivMod for Euclidean division and modulus (unlike Go).
func (z *Int) QuoRem(x, y, r *Int) (*Int, *Int) {
z.abs, r.abs = z.abs.div(r.abs, x.abs, y.abs)
z.neg, r.neg = len(z.abs) > 0 && x.neg != y.neg, len(r.abs) > 0 && x.neg // 0 has no sign
return z, r
}
// Div sets z to the quotient x/y for y != 0 and returns z.
// If y == 0, a division-by-zero run-time panic occurs.
// Div implements Euclidean division (unlike Go); see DivMod for more details.
func (z *Int) Div(x, y *Int) *Int {
y_neg := y.neg // z may be an alias for y
var r Int
z.QuoRem(x, y, &r)
if r.neg {
if y_neg {
z.Add(z, intOne)
} else {
z.Sub(z, intOne)
}
}
return z
}
// Mod sets z to the modulus x%y for y != 0 and returns z.
// If y == 0, a division-by-zero run-time panic occurs.
// Mod implements Euclidean modulus (unlike Go); see DivMod for more details.
func (z *Int) Mod(x, y *Int) *Int {
y0 := y // save y
if z == y || alias(z.abs, y.abs) {
y0 = new(Int).Set(y)
}
var q Int
q.QuoRem(x, y, z)
if z.neg {
if y0.neg {
z.Sub(z, y0)
} else {
z.Add(z, y0)
}
}
return z
}
// DivMod sets z to the quotient x div y and m to the modulus x mod y
// and returns the pair (z, m) for y != 0.
// If y == 0, a division-by-zero run-time panic occurs.
//
// DivMod implements Euclidean division and modulus (unlike Go):
//
// q = x div y such that
// m = x - y*q with 0 <= m < |y|
//
// (See Raymond T. Boute, “The Euclidean definition of the functions
// div and mod”. ACM Transactions on Programming Languages and
// Systems (TOPLAS), 14(2):127-144, New York, NY, USA, 4/1992.
// ACM press.)
// See QuoRem for T-division and modulus (like Go).
func (z *Int) DivMod(x, y, m *Int) (*Int, *Int) {
y0 := y // save y
if z == y || alias(z.abs, y.abs) {
y0 = new(Int).Set(y)
}
z.QuoRem(x, y, m)
if m.neg {
if y0.neg {
z.Add(z, intOne)
m.Sub(m, y0)
} else {
z.Sub(z, intOne)
m.Add(m, y0)
}
}
return z, m
}
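// Illustrative sketch (not in the upstream source; assumes "fmt" and
// "math/big" imports) contrasting truncated (QuoRem) and Euclidean (DivMod)
// division on a negative dividend; the Euclidean modulus is always in [0, |y|):
//
//	x, y := big.NewInt(-7), big.NewInt(3)
//	q, r := new(big.Int).QuoRem(x, y, new(big.Int))
//	fmt.Println(q, r) // -2 -1
//	d, m := new(big.Int).DivMod(x, y, new(big.Int))
//	fmt.Println(d, m) // -3 2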
// Cmp compares x and y and returns:
//
// -1 if x < y
// 0 if x == y
// +1 if x > y
func (x *Int) Cmp(y *Int) (r int) {
// x cmp y == x cmp y
// x cmp (-y) == x
// (-x) cmp y == y
// (-x) cmp (-y) == -(x cmp y)
switch {
case x == y:
// nothing to do
case x.neg == y.neg:
r = x.abs.cmp(y.abs)
if x.neg {
r = -r
}
case x.neg:
r = -1
default:
r = 1
}
return
}
// CmpAbs compares the absolute values of x and y and returns:
//
// -1 if |x| < |y|
// 0 if |x| == |y|
// +1 if |x| > |y|
func (x *Int) CmpAbs(y *Int) int {
return x.abs.cmp(y.abs)
}
// low32 returns the least significant 32 bits of x.
func low32(x nat) uint32 {
if len(x) == 0 {
return 0
}
return uint32(x[0])
}
// low64 returns the least significant 64 bits of x.
func low64(x nat) uint64 {
if len(x) == 0 {
return 0
}
v := uint64(x[0])
if _W == 32 && len(x) > 1 {
return uint64(x[1])<<32 | v
}
return v
}
// Int64 returns the int64 representation of x.
// If x cannot be represented in an int64, the result is undefined.
func (x *Int) Int64() int64 {
v := int64(low64(x.abs))
if x.neg {
v = -v
}
return v
}
// Uint64 returns the uint64 representation of x.
// If x cannot be represented in a uint64, the result is undefined.
func (x *Int) Uint64() uint64 {
return low64(x.abs)
}
// IsInt64 reports whether x can be represented as an int64.
func (x *Int) IsInt64() bool {
if len(x.abs) <= 64/_W {
w := int64(low64(x.abs))
return w >= 0 || x.neg && w == -w
}
return false
}
// IsUint64 reports whether x can be represented as a uint64.
func (x *Int) IsUint64() bool {
return !x.neg && len(x.abs) <= 64/_W
}
// Float64 returns the float64 value nearest x,
// and an indication of any rounding that occurred.
func (x *Int) Float64() (float64, Accuracy) {
n := x.abs.bitLen() // NB: still uses slow crypto impl!
if n == 0 {
return 0.0, Exact
}
// Fast path: no more than 53 significant bits.
if n <= 53 || n < 64 && n-int(x.abs.trailingZeroBits()) <= 53 {
f := float64(low64(x.abs))
if x.neg {
f = -f
}
return f, Exact
}
return new(Float).SetInt(x).Float64()
}
// SetString sets z to the value of s, interpreted in the given base,
// and returns z and a boolean indicating success. The entire string
// (not just a prefix) must be valid for success. If SetString fails,
// the value of z is undefined but the returned value is nil.
//
// The base argument must be 0 or a value between 2 and MaxBase.
// For base 0, the number prefix determines the actual base: A prefix of
// “0b” or “0B” selects base 2, “0”, “0o” or “0O” selects base 8,
// and “0x” or “0X” selects base 16. Otherwise, the selected base is 10
// and no prefix is accepted.
//
// For bases <= 36, lower and upper case letters are considered the same:
// The letters 'a' to 'z' and 'A' to 'Z' represent digit values 10 to 35.
// For bases > 36, the upper case letters 'A' to 'Z' represent the digit
// values 36 to 61.
//
// For base 0, an underscore character “_” may appear between a base
// prefix and an adjacent digit, and between successive digits; such
// underscores do not change the value of the number.
// Incorrect placement of underscores is reported as an error if there
// are no other errors. If base != 0, underscores are not recognized
// and act like any other character that is not a valid digit.
func (z *Int) SetString(s string, base int) (*Int, bool) {
return z.setFromScanner(strings.NewReader(s), base)
}
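// Illustrative sketch (not in the upstream source; assumes "fmt" and
// "math/big" imports): base-0 prefix and underscore handling, and a rejected
// input:
//
//	z, ok := new(big.Int).SetString("0x1_F", 0)
//	fmt.Println(z, ok) // 31 true
//	z, ok = new(big.Int).SetString("99", 8)
//	fmt.Println(z, ok) // <nil> false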
// setFromScanner implements SetString given an io.ByteScanner.
// For documentation see comments of SetString.
func (z *Int) setFromScanner(r io.ByteScanner, base int) (*Int, bool) {
if _, _, err := z.scan(r, base); err != nil {
return nil, false
}
// entire content must have been consumed
if _, err := r.ReadByte(); err != io.EOF {
return nil, false
}
return z, true // err == io.EOF => scan consumed all content of r
}
// SetBytes interprets buf as the bytes of a big-endian unsigned
// integer, sets z to that value, and returns z.
func (z *Int) SetBytes(buf []byte) *Int {
z.abs = z.abs.setBytes(buf)
z.neg = false
return z
}
// Bytes returns the absolute value of x as a big-endian byte slice.
//
// To use a fixed length slice, or a preallocated one, use FillBytes.
func (x *Int) Bytes() []byte {
// This function is used in cryptographic operations. It must not leak
// anything but the Int's sign and bit size through side-channels. Any
// changes must be reviewed by a security expert.
buf := make([]byte, len(x.abs)*_S)
return buf[x.abs.bytes(buf):]
}
// FillBytes sets buf to the absolute value of x, storing it as a zero-extended
// big-endian byte slice, and returns buf.
//
// If the absolute value of x doesn't fit in buf, FillBytes will panic.
func (x *Int) FillBytes(buf []byte) []byte {
// Clear whole buffer. (This gets optimized into a memclr.)
for i := range buf {
buf[i] = 0
}
x.abs.bytes(buf)
return buf
}
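// Illustrative sketch (not in the upstream source; assumes "fmt" and
// "math/big" imports): Bytes allocates a minimal slice, while FillBytes
// zero-extends into a caller-chosen length:
//
//	x := big.NewInt(0x1234)
//	fmt.Printf("% x\n", x.Bytes())                    // 12 34
//	fmt.Printf("% x\n", x.FillBytes(make([]byte, 4))) // 00 00 12 34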
// BitLen returns the length of the absolute value of x in bits.
// The bit length of 0 is 0.
func (x *Int) BitLen() int {
// This function is used in cryptographic operations. It must not leak
// anything but the Int's sign and bit size through side-channels. Any
// changes must be reviewed by a security expert.
return x.abs.bitLen()
}
// TrailingZeroBits returns the number of consecutive least significant zero
// bits of |x|.
func (x *Int) TrailingZeroBits() uint {
return x.abs.trailingZeroBits()
}
// Exp sets z = x**y mod |m| (i.e. the sign of m is ignored), and returns z.
// If m == nil or m == 0, z = x**y; in that case, if y <= 0, the result is 1.
// If m != 0, y < 0, and x and m are not relatively prime, z is unchanged and
// nil is returned.
//
// Modular exponentiation of inputs of a particular size is not a
// cryptographically constant-time operation.
func (z *Int) Exp(x, y, m *Int) *Int {
return z.exp(x, y, m, false)
}
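// Illustrative sketch (not in the upstream source; assumes "fmt" and
// "math/big" imports): the last three decimal digits of 2**100:
//
//	z := new(big.Int).Exp(big.NewInt(2), big.NewInt(100), big.NewInt(1000))
//	fmt.Println(z) // 376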
func (z *Int) expSlow(x, y, m *Int) *Int {
return z.exp(x, y, m, true)
}
func (z *Int) exp(x, y, m *Int, slow bool) *Int {
// See Knuth, volume 2, section 4.6.3.
xWords := x.abs
if y.neg {
if m == nil || len(m.abs) == 0 {
return z.SetInt64(1)
}
// for y < 0: x**y mod m == (x**(-1))**|y| mod m
inverse := new(Int).ModInverse(x, m)
if inverse == nil {
return nil
}
xWords = inverse.abs
}
yWords := y.abs
var mWords nat
if m != nil {
if z == m || alias(z.abs, m.abs) {
m = new(Int).Set(m)
}
mWords = m.abs // m.abs may be nil for m == 0
}
z.abs = z.abs.expNN(xWords, yWords, mWords, slow)
z.neg = len(z.abs) > 0 && x.neg && len(yWords) > 0 && yWords[0]&1 == 1 // 0 has no sign
if z.neg && len(mWords) > 0 {
// make modulus result positive
z.abs = z.abs.sub(mWords, z.abs) // z == x**y mod |m| && 0 <= z < |m|
z.neg = false
}
return z
}
// GCD sets z to the greatest common divisor of a and b and returns z.
// If x or y are not nil, GCD sets their value such that z = a*x + b*y.
//
// a and b may be positive, zero or negative. (Before Go 1.14 both had
// to be > 0.) Regardless of the signs of a and b, z is always >= 0.
//
// If a == b == 0, GCD sets z = x = y = 0.
//
// If a == 0 and b != 0, GCD sets z = |b|, x = 0, y = sign(b) * 1.
//
// If a != 0 and b == 0, GCD sets z = |a|, x = sign(a) * 1, y = 0.
func (z *Int) GCD(x, y, a, b *Int) *Int {
if len(a.abs) == 0 || len(b.abs) == 0 {
lenA, lenB, negA, negB := len(a.abs), len(b.abs), a.neg, b.neg
if lenA == 0 {
z.Set(b)
} else {
z.Set(a)
}
z.neg = false
if x != nil {
if lenA == 0 {
x.SetUint64(0)
} else {
x.SetUint64(1)
x.neg = negA
}
}
if y != nil {
if lenB == 0 {
y.SetUint64(0)
} else {
y.SetUint64(1)
y.neg = negB
}
}
return z
}
return z.lehmerGCD(x, y, a, b)
}
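// Illustrative sketch (not in the upstream source; assumes "fmt" and
// "math/big" imports): the Bézout identity z = a*x + b*y holds for the
// returned cofactors:
//
//	a, b := big.NewInt(252), big.NewInt(198)
//	x, y := new(big.Int), new(big.Int)
//	z := new(big.Int).GCD(x, y, a, b)
//	t := new(big.Int).Add(new(big.Int).Mul(a, x), new(big.Int).Mul(b, y))
//	fmt.Println(z, t.Cmp(z) == 0) // 18 true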
// lehmerSimulate attempts to simulate several Euclidean update steps
// using the leading digits of A and B. It returns u0, u1, v0, v1
// such that A and B can be updated as:
//
// A = u0*A + v0*B
// B = u1*A + v1*B
//
// Requirements: A >= B and len(B.abs) >= 2
// Since we are calculating with full words to avoid overflow,
// we use 'even' to track the sign of the cosequences.
// For even iterations: u0, v1 >= 0 && u1, v0 <= 0
// For odd iterations: u0, v1 <= 0 && u1, v0 >= 0
func lehmerSimulate(A, B *Int) (u0, u1, v0, v1 Word, even bool) {
// initialize the digits
var a1, a2, u2, v2 Word
m := len(B.abs) // m >= 2
n := len(A.abs) // n >= m >= 2
// extract the top Word of bits from A and B
h := nlz(A.abs[n-1])
a1 = A.abs[n-1]<<h | A.abs[n-2]>>(_W-h)
// B may have implicit zero words in the high bits if the lengths differ
switch {
case n == m:
a2 = B.abs[n-1]<<h | B.abs[n-2]>>(_W-h)
case n == m+1:
a2 = B.abs[n-2] >> (_W - h)
default:
a2 = 0
}
// Since we are calculating with full words to avoid overflow,
// we use 'even' to track the sign of the cosequences.
// For even iterations: u0, v1 >= 0 && u1, v0 <= 0
// For odd iterations: u0, v1 <= 0 && u1, v0 >= 0
// The first iteration starts with k=1 (odd).
even = false
// variables to track the cosequences
u0, u1, u2 = 0, 1, 0
v0, v1, v2 = 0, 0, 1
// Calculate the quotient and cosequences using Collins' stopping condition.
// Note that overflow of a Word is not possible when computing the remainder
// sequence and cosequences since the cosequence size is bounded by the input size.
// See section 4.2 of Jebelean for details.
for a2 >= v2 && a1-a2 >= v1+v2 {
q, r := a1/a2, a1%a2
a1, a2 = a2, r
u0, u1, u2 = u1, u2, u1+q*u2
v0, v1, v2 = v1, v2, v1+q*v2
even = !even
}
return
}
// lehmerUpdate updates the inputs A and B such that:
//
// A = u0*A + v0*B
// B = u1*A + v1*B
//
// where the signs of u0, u1, v0, v1 are given by even
// For even == true: u0, v1 >= 0 && u1, v0 <= 0
// For even == false: u0, v1 <= 0 && u1, v0 >= 0
// q, r, s, t are temporary variables to avoid allocations in the multiplication.
func lehmerUpdate(A, B, q, r, s, t *Int, u0, u1, v0, v1 Word, even bool) {
t.abs = t.abs.setWord(u0)
s.abs = s.abs.setWord(v0)
t.neg = !even
s.neg = even
t.Mul(A, t)
s.Mul(B, s)
r.abs = r.abs.setWord(u1)
q.abs = q.abs.setWord(v1)
r.neg = even
q.neg = !even
r.Mul(A, r)
q.Mul(B, q)
A.Add(t, s)
B.Add(r, q)
}
// euclidUpdate performs a single step of the Euclidean GCD algorithm.
// If extended is true, it also updates the cosequences Ua, Ub.
func euclidUpdate(A, B, Ua, Ub, q, r, s, t *Int, extended bool) {
q, r = q.QuoRem(A, B, r)
*A, *B, *r = *B, *r, *A
if extended {
// Ua, Ub = Ub, Ua - q*Ub
t.Set(Ub)
s.Mul(Ub, q)
Ub.Sub(Ua, s)
Ua.Set(t)
}
}
// lehmerGCD sets z to the greatest common divisor of a and b,
// which both must be != 0, and returns z.
// If x or y are not nil, their values are set such that z = a*x + b*y.
// See Knuth, The Art of Computer Programming, Vol. 2, Section 4.5.2, Algorithm L.
// This implementation uses the improved condition by Collins requiring only one
// quotient and avoiding the possibility of single Word overflow.
// See Jebelean, "Improving the multiprecision Euclidean algorithm",
// Design and Implementation of Symbolic Computation Systems, pp 45-58.
// The cosequences are updated according to Algorithm 10.45 from
// Cohen et al. "Handbook of Elliptic and Hyperelliptic Curve Cryptography" pp 192.
func (z *Int) lehmerGCD(x, y, a, b *Int) *Int {
var A, B, Ua, Ub *Int
A = new(Int).Abs(a)
B = new(Int).Abs(b)
extended := x != nil || y != nil
if extended {
// Ua (Ub) tracks how many times input a has been accumulated into A (B).
Ua = new(Int).SetInt64(1)
Ub = new(Int)
}
// temp variables for multiprecision update
q := new(Int)
r := new(Int)
s := new(Int)
t := new(Int)
// ensure A >= B
if A.abs.cmp(B.abs) < 0 {
A, B = B, A
Ub, Ua = Ua, Ub
}
// loop invariant A >= B
for len(B.abs) > 1 {
// Attempt to calculate in single-precision using leading words of A and B.
u0, u1, v0, v1, even := lehmerSimulate(A, B)
// multiprecision Step
if v0 != 0 {
// Simulate the effect of the single-precision steps using the cosequences.
// A = u0*A + v0*B
// B = u1*A + v1*B
lehmerUpdate(A, B, q, r, s, t, u0, u1, v0, v1, even)
if extended {
// Ua = u0*Ua + v0*Ub
// Ub = u1*Ua + v1*Ub
lehmerUpdate(Ua, Ub, q, r, s, t, u0, u1, v0, v1, even)
}
} else {
// Single-digit calculations failed to simulate any quotients.
// Do a standard Euclidean step.
euclidUpdate(A, B, Ua, Ub, q, r, s, t, extended)
}
}
if len(B.abs) > 0 {
// extended Euclidean algorithm base case if B is a single Word
if len(A.abs) > 1 {
// A is longer than a single Word, so one update is needed.
euclidUpdate(A, B, Ua, Ub, q, r, s, t, extended)
}
if len(B.abs) > 0 {
// A and B are both a single Word.
aWord, bWord := A.abs[0], B.abs[0]
if extended {
var ua, ub, va, vb Word
ua, ub = 1, 0
va, vb = 0, 1
even := true
for bWord != 0 {
q, r := aWord/bWord, aWord%bWord
aWord, bWord = bWord, r
ua, ub = ub, ua+q*ub
va, vb = vb, va+q*vb
even = !even
}
t.abs = t.abs.setWord(ua)
s.abs = s.abs.setWord(va)
t.neg = !even
s.neg = even
t.Mul(Ua, t)
s.Mul(Ub, s)
Ua.Add(t, s)
} else {
for bWord != 0 {
aWord, bWord = bWord, aWord%bWord
}
}
A.abs[0] = aWord
}
}
negA := a.neg
if y != nil {
// avoid aliasing b needed in the division below
if y == b {
B.Set(b)
} else {
B = b
}
// y = (z - a*x)/b
y.Mul(a, Ua) // y can safely alias a
if negA {
y.neg = !y.neg
}
y.Sub(A, y)
y.Div(y, B)
}
if x != nil {
*x = *Ua
if negA {
x.neg = !x.neg
}
}
*z = *A
return z
}
// Rand sets z to a pseudo-random number in [0, n) and returns z.
//
// As this uses the math/rand package, it must not be used for
// security-sensitive work. Use crypto/rand.Int instead.
func (z *Int) Rand(rnd *rand.Rand, n *Int) *Int {
// z.neg is not modified before the if check, because z and n might alias.
if n.neg || len(n.abs) == 0 {
z.neg = false
z.abs = nil
return z
}
z.neg = false
z.abs = z.abs.random(rnd, n.abs, n.abs.bitLen())
return z
}
// ModInverse sets z to the multiplicative inverse of g in the ring ℤ/nℤ
// and returns z. If g and n are not relatively prime, g has no multiplicative
// inverse in the ring ℤ/nℤ. In this case, z is unchanged and the return value
// is nil. If n == 0, a division-by-zero run-time panic occurs.
func (z *Int) ModInverse(g, n *Int) *Int {
// GCD expects parameters a and b to be > 0.
if n.neg {
var n2 Int
n = n2.Neg(n)
}
if g.neg {
var g2 Int
g = g2.Mod(g, n)
}
var d, x Int
d.GCD(&x, nil, g, n)
// if and only if d==1, g and n are relatively prime
if d.Cmp(intOne) != 0 {
return nil
}
// x and y are such that g*x + n*y = 1, therefore x is the inverse element,
// but it may be negative, so convert to the range 0 <= z < |n|
if x.neg {
z.Add(&x, n)
} else {
z.Set(&x)
}
return z
}
func (z nat) modInverse(g, n nat) nat {
// TODO(rsc): ModInverse should be implemented in terms of this function.
return (&Int{abs: z}).ModInverse(&Int{abs: g}, &Int{abs: n}).abs
}
// Jacobi returns the Jacobi symbol (x/y), either +1, -1, or 0.
// The y argument must be an odd integer.
func Jacobi(x, y *Int) int {
if len(y.abs) == 0 || y.abs[0]&1 == 0 {
panic(fmt.Sprintf("big: invalid 2nd argument to Int.Jacobi: need odd integer but got %s", y.String()))
}
// We use the formulation described in chapter 2, section 2.4,
// "The Yacas Book of Algorithms":
// http://yacas.sourceforge.net/Algo.book.pdf
var a, b, c Int
a.Set(x)
b.Set(y)
j := 1
if b.neg {
if a.neg {
j = -1
}
b.neg = false
}
for {
if b.Cmp(intOne) == 0 {
return j
}
if len(a.abs) == 0 {
return 0
}
a.Mod(&a, &b)
if len(a.abs) == 0 {
return 0
}
// a > 0
// handle factors of 2 in 'a'
s := a.abs.trailingZeroBits()
if s&1 != 0 {
bmod8 := b.abs[0] & 7
if bmod8 == 3 || bmod8 == 5 {
j = -j
}
}
c.Rsh(&a, s) // a = 2^s*c
// swap numerator and denominator
if b.abs[0]&3 == 3 && c.abs[0]&3 == 3 {
j = -j
}
a.Set(&b)
b.Set(&c)
}
}
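// Illustrative sketch (not in the upstream source; assumes "fmt" and
// "math/big" imports): 2 is a quadratic residue mod 7 (3*3 == 2 mod 7),
// so the Jacobi symbol is +1:
//
//	fmt.Println(big.Jacobi(big.NewInt(2), big.NewInt(7))) // 1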
// modSqrt3Mod4 uses the identity
//
// (a^((p+1)/4))^2 mod p
// == u^(p+1) mod p
// == u^2 mod p
//
// to calculate the square root of any quadratic residue mod p quickly for 3
// mod 4 primes.
func (z *Int) modSqrt3Mod4Prime(x, p *Int) *Int {
e := new(Int).Add(p, intOne) // e = p + 1
e.Rsh(e, 2) // e = (p + 1) / 4
z.Exp(x, e, p) // z = x^e mod p
return z
}
// modSqrt5Mod8Prime uses Atkin's observation that 2 is not a square mod p
//
// alpha == (2*a)^((p-5)/8) mod p
// beta == 2*a*alpha^2 mod p is a square root of -1
// b == a*alpha*(beta-1) mod p is a square root of a
//
// to calculate the square root of any quadratic residue mod p quickly for 5
// mod 8 primes.
func (z *Int) modSqrt5Mod8Prime(x, p *Int) *Int {
// p == 5 mod 8 implies p = e*8 + 5
// e is the quotient and 5 the remainder on division by 8
e := new(Int).Rsh(p, 3) // e = (p - 5) / 8
tx := new(Int).Lsh(x, 1) // tx = 2*x
alpha := new(Int).Exp(tx, e, p)
beta := new(Int).Mul(alpha, alpha)
beta.Mod(beta, p)
beta.Mul(beta, tx)
beta.Mod(beta, p)
beta.Sub(beta, intOne)
beta.Mul(beta, x)
beta.Mod(beta, p)
beta.Mul(beta, alpha)
z.Mod(beta, p)
return z
}
// modSqrtTonelliShanks uses the Tonelli-Shanks algorithm to find the square
// root of a quadratic residue modulo any prime.
func (z *Int) modSqrtTonelliShanks(x, p *Int) *Int {
// Break p-1 into s*2^e such that s is odd.
var s Int
s.Sub(p, intOne)
e := s.abs.trailingZeroBits()
s.Rsh(&s, e)
// find some non-square n
var n Int
n.SetInt64(2)
for Jacobi(&n, p) != -1 {
n.Add(&n, intOne)
}
// Core of the Tonelli-Shanks algorithm. Follows the description in
// section 6 of "Square roots from 1; 24, 51, 10 to Dan Shanks" by Ezra
// Brown:
// https://www.maa.org/sites/default/files/pdf/upload_library/22/Polya/07468342.di020786.02p0470a.pdf
var y, b, g, t Int
y.Add(&s, intOne)
y.Rsh(&y, 1)
y.Exp(x, &y, p) // y = x^((s+1)/2)
b.Exp(x, &s, p) // b = x^s
g.Exp(&n, &s, p) // g = n^s
r := e
for {
// find the least m such that ord_p(b) = 2^m
var m uint
t.Set(&b)
for t.Cmp(intOne) != 0 {
t.Mul(&t, &t).Mod(&t, p)
m++
}
if m == 0 {
return z.Set(&y)
}
t.SetInt64(0).SetBit(&t, int(r-m-1), 1).Exp(&g, &t, p)
// t = g^(2^(r-m-1)) mod p
g.Mul(&t, &t).Mod(&g, p) // g = g^(2^(r-m)) mod p
y.Mul(&y, &t).Mod(&y, p)
b.Mul(&b, &g).Mod(&b, p)
r = m
}
}
// ModSqrt sets z to a square root of x mod p if such a square root exists, and
// returns z. The modulus p must be an odd prime. If x is not a square mod p,
// ModSqrt leaves z unchanged and returns nil. This function panics if p is
// not an odd integer; its behavior is undefined if p is odd but not prime.
func (z *Int) ModSqrt(x, p *Int) *Int {
switch Jacobi(x, p) {
case -1:
return nil // x is not a square mod p
case 0:
return z.SetInt64(0) // sqrt(0) mod p = 0
case 1:
break
}
if x.neg || x.Cmp(p) >= 0 { // ensure 0 <= x < p
x = new(Int).Mod(x, p)
}
switch {
case p.abs[0]%4 == 3:
// Check whether p is 3 mod 4, and if so, use the faster algorithm.
return z.modSqrt3Mod4Prime(x, p)
case p.abs[0]%8 == 5:
// Check whether p is 5 mod 8, use Atkin's algorithm.
return z.modSqrt5Mod8Prime(x, p)
default:
// Otherwise, use Tonelli-Shanks.
return z.modSqrtTonelliShanks(x, p)
}
}
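// Illustrative sketch (not in the upstream source; assumes "fmt" and
// "math/big" imports): both 6 and 7 square to 10 mod 13, so rather than
// asserting which root is returned, verify the result by squaring:
//
//	r := new(big.Int).ModSqrt(big.NewInt(10), big.NewInt(13))
//	s := new(big.Int).Mul(r, r)
//	fmt.Println(s.Mod(s, big.NewInt(13))) // 10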
// Lsh sets z = x << n and returns z.
func (z *Int) Lsh(x *Int, n uint) *Int {
z.abs = z.abs.shl(x.abs, n)
z.neg = x.neg
return z
}
// Rsh sets z = x >> n and returns z.
func (z *Int) Rsh(x *Int, n uint) *Int {
if x.neg {
// (-x) >> s == ^(x-1) >> s == ^((x-1) >> s) == -(((x-1) >> s) + 1)
t := z.abs.sub(x.abs, natOne) // no underflow because |x| > 0
t = t.shr(t, n)
z.abs = t.add(t, natOne)
z.neg = true // z cannot be zero if x is negative
return z
}
z.abs = z.abs.shr(x.abs, n)
z.neg = false
return z
}
// Bit returns the value of the i'th bit of x. That is, it
// returns (x>>i)&1. The bit index i must be >= 0.
func (x *Int) Bit(i int) uint {
if i == 0 {
// optimization for common case: odd/even test of x
if len(x.abs) > 0 {
return uint(x.abs[0] & 1) // bit 0 is same for -x
}
return 0
}
if i < 0 {
panic("negative bit index")
}
if x.neg {
t := nat(nil).sub(x.abs, natOne)
return t.bit(uint(i)) ^ 1
}
return x.abs.bit(uint(i))
}
// SetBit sets z to x, with x's i'th bit set to b (0 or 1).
// That is, if b is 1 SetBit sets z = x | (1 << i);
// if b is 0 SetBit sets z = x &^ (1 << i). If b is not 0 or 1,
// SetBit will panic.
func (z *Int) SetBit(x *Int, i int, b uint) *Int {
if i < 0 {
panic("negative bit index")
}
if x.neg {
t := z.abs.sub(x.abs, natOne)
t = t.setBit(t, uint(i), b^1)
z.abs = t.add(t, natOne)
z.neg = len(z.abs) > 0
return z
}
z.abs = z.abs.setBit(x.abs, uint(i), b)
z.neg = false
return z
}
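// Illustrative sketch (not in the upstream source; assumes "fmt" and
// "math/big" imports): bit operations treat negative values as infinite
// two's-complement bit strings, so -2 reads as ...111110:
//
//	x := big.NewInt(-2)
//	fmt.Println(x.Bit(0), x.Bit(1), x.Bit(100))           // 0 1 1
//	fmt.Println(new(big.Int).SetBit(big.NewInt(0), 3, 1)) // 8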
// And sets z = x & y and returns z.
func (z *Int) And(x, y *Int) *Int {
if x.neg == y.neg {
if x.neg {
// (-x) & (-y) == ^(x-1) & ^(y-1) == ^((x-1) | (y-1)) == -(((x-1) | (y-1)) + 1)
x1 := nat(nil).sub(x.abs, natOne)
y1 := nat(nil).sub(y.abs, natOne)
z.abs = z.abs.add(z.abs.or(x1, y1), natOne)
z.neg = true // z cannot be zero if x and y are negative
return z
}
// x & y == x & y
z.abs = z.abs.and(x.abs, y.abs)
z.neg = false
return z
}
// x.neg != y.neg
if x.neg {
x, y = y, x // & is symmetric
}
// x & (-y) == x & ^(y-1) == x &^ (y-1)
y1 := nat(nil).sub(y.abs, natOne)
z.abs = z.abs.andNot(x.abs, y1)
z.neg = false
return z
}
// AndNot sets z = x &^ y and returns z.
func (z *Int) AndNot(x, y *Int) *Int {
if x.neg == y.neg {
if x.neg {
// (-x) &^ (-y) == ^(x-1) &^ ^(y-1) == ^(x-1) & (y-1) == (y-1) &^ (x-1)
x1 := nat(nil).sub(x.abs, natOne)
y1 := nat(nil).sub(y.abs, natOne)
z.abs = z.abs.andNot(y1, x1)
z.neg = false
return z
}
// x &^ y == x &^ y
z.abs = z.abs.andNot(x.abs, y.abs)
z.neg = false
return z
}
if x.neg {
// (-x) &^ y == ^(x-1) &^ y == ^(x-1) & ^y == ^((x-1) | y) == -(((x-1) | y) + 1)
x1 := nat(nil).sub(x.abs, natOne)
z.abs = z.abs.add(z.abs.or(x1, y.abs), natOne)
z.neg = true // z cannot be zero if x is negative and y is positive
return z
}
// x &^ (-y) == x &^ ^(y-1) == x & (y-1)
y1 := nat(nil).sub(y.abs, natOne)
z.abs = z.abs.and(x.abs, y1)
z.neg = false
return z
}
// Or sets z = x | y and returns z.
func (z *Int) Or(x, y *Int) *Int {
if x.neg == y.neg {
if x.neg {
// (-x) | (-y) == ^(x-1) | ^(y-1) == ^((x-1) & (y-1)) == -(((x-1) & (y-1)) + 1)
x1 := nat(nil).sub(x.abs, natOne)
y1 := nat(nil).sub(y.abs, natOne)
z.abs = z.abs.add(z.abs.and(x1, y1), natOne)
z.neg = true // z cannot be zero if x and y are negative
return z
}
// x | y == x | y
z.abs = z.abs.or(x.abs, y.abs)
z.neg = false
return z
}
// x.neg != y.neg
if x.neg {
x, y = y, x // | is symmetric
}
// x | (-y) == x | ^(y-1) == ^((y-1) &^ x) == -(^((y-1) &^ x) + 1)
y1 := nat(nil).sub(y.abs, natOne)
z.abs = z.abs.add(z.abs.andNot(y1, x.abs), natOne)
z.neg = true // z cannot be zero if one of x or y is negative
return z
}
// Xor sets z = x ^ y and returns z.
func (z *Int) Xor(x, y *Int) *Int {
if x.neg == y.neg {
if x.neg {
// (-x) ^ (-y) == ^(x-1) ^ ^(y-1) == (x-1) ^ (y-1)
x1 := nat(nil).sub(x.abs, natOne)
y1 := nat(nil).sub(y.abs, natOne)
z.abs = z.abs.xor(x1, y1)
z.neg = false
return z
}
// x ^ y == x ^ y
z.abs = z.abs.xor(x.abs, y.abs)
z.neg = false
return z
}
// x.neg != y.neg
if x.neg {
x, y = y, x // ^ is symmetric
}
// x ^ (-y) == x ^ ^(y-1) == ^(x ^ (y-1)) == -((x ^ (y-1)) + 1)
y1 := nat(nil).sub(y.abs, natOne)
z.abs = z.abs.add(z.abs.xor(x.abs, y1), natOne)
z.neg = true // z cannot be zero if only one of x or y is negative
return z
}
// Not sets z = ^x and returns z.
func (z *Int) Not(x *Int) *Int {
if x.neg {
// ^(-x) == ^(^(x-1)) == x-1
z.abs = z.abs.sub(x.abs, natOne)
z.neg = false
return z
}
// ^x == -x-1 == -(x+1)
z.abs = z.abs.add(x.abs, natOne)
z.neg = true // z cannot be zero if x is positive
return z
}
// Sqrt sets z to ⌊√x⌋, the largest integer such that z² ≤ x, and returns z.
// It panics if x is negative.
func (z *Int) Sqrt(x *Int) *Int {
if x.neg {
panic("square root of negative number")
}
z.neg = false
z.abs = z.abs.sqrt(x.abs)
return z
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements int-to-string conversion functions.
package big
import (
"errors"
"fmt"
"io"
)
// Text returns the string representation of x in the given base.
// Base must be between 2 and 62, inclusive. The result uses the
// lower-case letters 'a' to 'z' for digit values 10 to 35, and
// the upper-case letters 'A' to 'Z' for digit values 36 to 61.
// No prefix (such as "0x") is added to the string. If x is a nil
// pointer, it returns "<nil>".
func (x *Int) Text(base int) string {
if x == nil {
return "<nil>"
}
return string(x.abs.itoa(x.neg, base))
}
// Append appends the string representation of x, as generated by
// x.Text(base), to buf and returns the extended buffer.
func (x *Int) Append(buf []byte, base int) []byte {
if x == nil {
return append(buf, "<nil>"...)
}
return append(buf, x.abs.itoa(x.neg, base)...)
}
// String returns the decimal representation of x as generated by
// x.Text(10).
func (x *Int) String() string {
return x.Text(10)
}
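// Illustrative sketch (not in the upstream source; assumes "fmt" and
// "math/big" imports): one value rendered in bases 2, 16, and 62:
//
//	x := big.NewInt(255)
//	fmt.Println(x.Text(2), x.Text(16), x.Text(62)) // 11111111 ff 47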
// write count copies of text to s.
func writeMultiple(s fmt.State, text string, count int) {
if len(text) > 0 {
b := []byte(text)
for ; count > 0; count-- {
s.Write(b)
}
}
}
var _ fmt.Formatter = intOne // *Int must implement fmt.Formatter
// Format implements fmt.Formatter. It accepts the formats
// 'b' (binary), 'o' (octal with 0 prefix), 'O' (octal with 0o prefix),
// 'd' (decimal), 'x' (lowercase hexadecimal), and
// 'X' (uppercase hexadecimal).
// Also supported are the full suite of package fmt's format
// flags for integral types, including '+' and ' ' for sign
// control, '#' for leading zero in octal and for hexadecimal,
// a leading "0x" or "0X" for "%#x" and "%#X" respectively,
// specification of minimum digits precision, output field
// width, space or zero padding, and '-' for left or right
// justification.
func (x *Int) Format(s fmt.State, ch rune) {
// determine base
var base int
switch ch {
case 'b':
base = 2
case 'o', 'O':
base = 8
case 'd', 's', 'v':
base = 10
case 'x', 'X':
base = 16
default:
// unknown format
fmt.Fprintf(s, "%%!%c(big.Int=%s)", ch, x.String())
return
}
if x == nil {
fmt.Fprint(s, "<nil>")
return
}
// determine sign character
sign := ""
switch {
case x.neg:
sign = "-"
case s.Flag('+'): // supersedes ' ' when both specified
sign = "+"
case s.Flag(' '):
sign = " "
}
// determine prefix characters for indicating output base
prefix := ""
if s.Flag('#') {
switch ch {
case 'b': // binary
prefix = "0b"
case 'o': // octal
prefix = "0"
case 'x': // hexadecimal
prefix = "0x"
case 'X':
prefix = "0X"
}
}
if ch == 'O' {
prefix = "0o"
}
digits := x.abs.utoa(base)
if ch == 'X' {
// faster than bytes.ToUpper
for i, d := range digits {
if 'a' <= d && d <= 'z' {
digits[i] = 'A' + (d - 'a')
}
}
}
// number of characters for the three classes of number padding
var left int // space characters to left of digits for right justification ("%8d")
var zeros int // zero characters (actually cs[0]) as left-most digits ("%.8d")
var right int // space characters to right of digits for left justification ("%-8d")
// determine number padding from precision: the least number of digits to output
precision, precisionSet := s.Precision()
if precisionSet {
switch {
case len(digits) < precision:
zeros = precision - len(digits) // count of zero padding
case len(digits) == 1 && digits[0] == '0' && precision == 0:
return // print nothing if zero value (x == 0) and zero precision ("." or ".0")
}
}
// determine field pad from width: the least number of characters to output
length := len(sign) + len(prefix) + zeros + len(digits)
if width, widthSet := s.Width(); widthSet && length < width { // pad as specified
switch d := width - length; {
case s.Flag('-'):
// pad on the right with spaces; supersedes '0' when both specified
right = d
case s.Flag('0') && !precisionSet:
// pad with zeros unless precision also specified
zeros = d
default:
// pad on the left with spaces
left = d
}
}
// print number as [left pad][sign][prefix][zero pad][digits][right pad]
writeMultiple(s, " ", left)
writeMultiple(s, sign, 1)
writeMultiple(s, prefix, 1)
writeMultiple(s, "0", zeros)
s.Write(digits)
writeMultiple(s, " ", right)
}
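// Illustrative sketch (not in the upstream source; assumes "fmt" and
// "math/big" imports) showing a few of the supported verbs and flags:
//
//	x := big.NewInt(1000)
//	fmt.Printf("%d %x %#x %08d\n", x, x, x, x)
//	// Output: 1000 3e8 0x3e8 00001000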
// scan sets z to the integer value corresponding to the longest possible prefix
// read from r representing a signed integer number in a given conversion base.
// It returns z, the actual conversion base used, and an error, if any. In the
// error case, the value of z is undefined but the returned value is nil.
// The syntax follows that of Go integer literals.
//
// The base argument must be 0 or a value from 2 through MaxBase. If the base
// is 0, the string prefix determines the actual conversion base. A prefix of
// “0b” or “0B” selects base 2; a “0”, “0o”, or “0O” prefix selects
// base 8, and a “0x” or “0X” prefix selects base 16. Otherwise the selected
// base is 10.
func (z *Int) scan(r io.ByteScanner, base int) (*Int, int, error) {
// determine sign
neg, err := scanSign(r)
if err != nil {
return nil, 0, err
}
// determine mantissa
z.abs, base, _, err = z.abs.scan(r, base, false)
if err != nil {
return nil, base, err
}
z.neg = len(z.abs) > 0 && neg // 0 has no sign
return z, base, nil
}
func scanSign(r io.ByteScanner) (neg bool, err error) {
var ch byte
if ch, err = r.ReadByte(); err != nil {
return false, err
}
switch ch {
case '-':
neg = true
case '+':
// nothing to do
default:
r.UnreadByte()
}
return
}
// byteReader is a local wrapper around fmt.ScanState;
// it implements the ByteReader interface.
type byteReader struct {
fmt.ScanState
}
func (r byteReader) ReadByte() (byte, error) {
ch, size, err := r.ReadRune()
if size != 1 && err == nil {
err = fmt.Errorf("invalid rune %#U", ch)
}
return byte(ch), err
}
func (r byteReader) UnreadByte() error {
return r.UnreadRune()
}
var _ fmt.Scanner = intOne // *Int must implement fmt.Scanner
// Scan is a support routine for fmt.Scanner; it sets z to the value of
// the scanned number. It accepts the formats 'b' (binary), 'o' (octal),
// 'd' (decimal), 'x' (lowercase hexadecimal), and 'X' (uppercase hexadecimal).
func (z *Int) Scan(s fmt.ScanState, ch rune) error {
s.SkipSpace() // skip leading space characters
base := 0
switch ch {
case 'b':
base = 2
case 'o':
base = 8
case 'd':
base = 10
case 'x', 'X':
base = 16
case 's', 'v':
// let scan determine the base
default:
return errors.New("Int.Scan: invalid verb")
}
_, _, err := z.scan(byteReader{s}, base)
return err
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements encoding/decoding of Ints.
package big
import (
"bytes"
"fmt"
)
// Gob codec version. Permits backward-compatible changes to the encoding.
const intGobVersion byte = 1
// GobEncode implements the gob.GobEncoder interface.
func (x *Int) GobEncode() ([]byte, error) {
if x == nil {
return nil, nil
}
buf := make([]byte, 1+len(x.abs)*_S) // extra byte for version and sign bit
i := x.abs.bytes(buf) - 1 // i >= 0
b := intGobVersion << 1 // make space for sign bit
if x.neg {
b |= 1
}
buf[i] = b
return buf[i:], nil
}
// GobDecode implements the gob.GobDecoder interface.
func (z *Int) GobDecode(buf []byte) error {
if len(buf) == 0 {
// Other side sent a nil or default value.
*z = Int{}
return nil
}
b := buf[0]
if b>>1 != intGobVersion {
return fmt.Errorf("Int.GobDecode: encoding version %d not supported", b>>1)
}
z.neg = b&1 != 0
z.abs = z.abs.setBytes(buf[1:])
return nil
}
// MarshalText implements the encoding.TextMarshaler interface.
func (x *Int) MarshalText() (text []byte, err error) {
if x == nil {
return []byte("<nil>"), nil
}
return x.abs.itoa(x.neg, 10), nil
}
// UnmarshalText implements the encoding.TextUnmarshaler interface.
func (z *Int) UnmarshalText(text []byte) error {
if _, ok := z.setFromScanner(bytes.NewReader(text), 0); !ok {
return fmt.Errorf("math/big: cannot unmarshal %q into a *big.Int", text)
}
return nil
}
// The JSON marshalers are only here for API backward compatibility
// (programs that explicitly look for these two methods). JSON works
// fine with the TextMarshaler only.
// MarshalJSON implements the json.Marshaler interface.
func (x *Int) MarshalJSON() ([]byte, error) {
if x == nil {
return []byte("null"), nil
}
return x.abs.itoa(x.neg, 10), nil
}
// UnmarshalJSON implements the json.Unmarshaler interface.
func (z *Int) UnmarshalJSON(text []byte) error {
// Ignore null, like in the main JSON package.
if string(text) == "null" {
return nil
}
return z.UnmarshalText(text)
}
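// Illustrative sketch (not in the upstream source; assumes "encoding/json",
// "fmt", and "math/big" imports): a *Int round-trips through JSON as a bare
// number:
//
//	in := big.NewInt(1 << 40)
//	data, _ := json.Marshal(in)
//	fmt.Println(string(data)) // 1099511627776
//	out := new(big.Int)
//	_ = json.Unmarshal(data, out)
//	fmt.Println(out.Cmp(in) == 0) // true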
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements unsigned multi-precision integers (natural
// numbers). They are the building blocks for the implementation
// of signed integers, rationals, and floating-point numbers.
//
// Caution: This implementation relies on the function "alias"
// which assumes that (nat) slice capacities are never
// changed (no 3-operand slice expressions). If that
// changes, alias needs to be updated for correctness.
package big
import (
"encoding/binary"
"math/bits"
"math/rand"
"sync"
)
// An unsigned integer x of the form
//
// x = x[n-1]*_B^(n-1) + x[n-2]*_B^(n-2) + ... + x[1]*_B + x[0]
//
// with 0 <= x[i] < _B and 0 <= i < n is stored in a slice of length n,
// with the digits x[i] as the slice elements.
//
// A number is normalized if the slice contains no leading 0 digits.
// During arithmetic operations, denormalized values may occur but are
// always normalized before returning the final result. The normalized
// representation of 0 is the empty or nil slice (length = 0).
type nat []Word
var (
natOne = nat{1}
natTwo = nat{2}
natFive = nat{5}
natTen = nat{10}
)
func (z nat) String() string {
return "0x" + string(z.itoa(false, 16))
}
func (z nat) clear() {
for i := range z {
z[i] = 0
}
}
func (z nat) norm() nat {
i := len(z)
for i > 0 && z[i-1] == 0 {
i--
}
return z[0:i]
}
func (z nat) make(n int) nat {
if n <= cap(z) {
return z[:n] // reuse z
}
if n == 1 {
// Most nats start small and stay that way; don't over-allocate.
return make(nat, 1)
}
// Choosing a good value for e has significant performance impact
// because it increases the chance that a value can be reused.
const e = 4 // extra capacity
return make(nat, n, n+e)
}
func (z nat) setWord(x Word) nat {
if x == 0 {
return z[:0]
}
z = z.make(1)
z[0] = x
return z
}
func (z nat) setUint64(x uint64) nat {
// single-word value
if w := Word(x); uint64(w) == x {
return z.setWord(w)
}
// 2-word value
z = z.make(2)
z[1] = Word(x >> 32)
z[0] = Word(x)
return z
}
func (z nat) set(x nat) nat {
z = z.make(len(x))
copy(z, x)
return z
}
func (z nat) add(x, y nat) nat {
m := len(x)
n := len(y)
switch {
case m < n:
return z.add(y, x)
case m == 0:
// n == 0 because m >= n; result is 0
return z[:0]
case n == 0:
// result is x
return z.set(x)
}
// m > 0
z = z.make(m + 1)
c := addVV(z[0:n], x, y)
if m > n {
c = addVW(z[n:m], x[n:], c)
}
z[m] = c
return z.norm()
}
func (z nat) sub(x, y nat) nat {
m := len(x)
n := len(y)
switch {
case m < n:
panic("underflow")
case m == 0:
// n == 0 because m >= n; result is 0
return z[:0]
case n == 0:
// result is x
return z.set(x)
}
// m > 0
z = z.make(m)
c := subVV(z[0:n], x, y)
if m > n {
c = subVW(z[n:], x[n:], c)
}
if c != 0 {
panic("underflow")
}
return z.norm()
}
func (x nat) cmp(y nat) (r int) {
m := len(x)
n := len(y)
if m != n || m == 0 {
switch {
case m < n:
r = -1
case m > n:
r = 1
}
return
}
i := m - 1
for i > 0 && x[i] == y[i] {
i--
}
switch {
case x[i] < y[i]:
r = -1
case x[i] > y[i]:
r = 1
}
return
}
func (z nat) mulAddWW(x nat, y, r Word) nat {
m := len(x)
if m == 0 || y == 0 {
return z.setWord(r) // result is r
}
// m > 0
z = z.make(m + 1)
z[m] = mulAddVWW(z[0:m], x, y, r)
return z.norm()
}
// basicMul multiplies x and y and leaves the result in z.
// The (non-normalized) result is placed in z[0 : len(x) + len(y)].
func basicMul(z, x, y nat) {
z[0 : len(x)+len(y)].clear() // initialize z
for i, d := range y {
if d != 0 {
z[len(x)+i] = addMulVVW(z[i:i+len(x)], x, d)
}
}
}
// montgomery computes z mod m = x*y*2**(-n*_W) mod m,
// assuming k = -1/m mod 2**_W.
// z is used for storing the result which is returned;
// z must not alias x, y or m.
// See Gueron, "Efficient Software Implementations of Modular Exponentiation".
// https://eprint.iacr.org/2011/239.pdf
// In the terminology of that paper, this is an "Almost Montgomery Multiplication":
// x and y are required to satisfy 0 <= x, y < 2**(n*_W) and then the result
// z is guaranteed to satisfy 0 <= z < 2**(n*_W), but it may not be < m.
func (z nat) montgomery(x, y, m nat, k Word, n int) nat {
// This code assumes x, y, m are all the same length, n.
// (required by addMulVVW and the for loop).
// It also assumes that x, y are already reduced mod m,
// or else the result will not be properly reduced.
if len(x) != n || len(y) != n || len(m) != n {
panic("math/big: mismatched montgomery number lengths")
}
z = z.make(n * 2)
z.clear()
var c Word
for i := 0; i < n; i++ {
d := y[i]
c2 := addMulVVW(z[i:n+i], x, d)
t := z[i] * k
c3 := addMulVVW(z[i:n+i], m, t)
cx := c + c2
cy := cx + c3
z[n+i] = cy
if cx < c2 || cy < c3 {
c = 1
} else {
c = 0
}
}
if c != 0 {
subVV(z[:n], z[n:], m)
} else {
copy(z[:n], z[n:])
}
return z[:n]
}
// Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks.
// Factored out for readability - do not use outside karatsuba.
func karatsubaAdd(z, x nat, n int) {
if c := addVV(z[0:n], z, x); c != 0 {
addVW(z[n:n+n>>1], z[n:], c)
}
}
// Like karatsubaAdd, but does subtract.
func karatsubaSub(z, x nat, n int) {
if c := subVV(z[0:n], z, x); c != 0 {
subVW(z[n:n+n>>1], z[n:], c)
}
}
// Operands that are shorter than karatsubaThreshold are multiplied using
// "grade school" multiplication; for longer operands the Karatsuba algorithm
// is used.
var karatsubaThreshold = 40 // computed by calibrate_test.go
// karatsuba multiplies x and y and leaves the result in z.
// Both x and y must have the same length n and n must be a
// power of 2. The result vector z must have len(z) >= 6*n.
// The (non-normalized) result is placed in z[0 : 2*n].
func karatsuba(z, x, y nat) {
n := len(y)
// Switch to basic multiplication if numbers are odd or small.
// (n is always even if karatsubaThreshold is even, but be
// conservative)
if n&1 != 0 || n < karatsubaThreshold || n < 2 {
basicMul(z, x, y)
return
}
// n&1 == 0 && n >= karatsubaThreshold && n >= 2
// Karatsuba multiplication is based on the observation that
// for two numbers x and y with:
//
// x = x1*b + x0
// y = y1*b + y0
//
// the product x*y can be obtained with 3 products z2, z1, z0
// instead of 4:
//
// x*y = x1*y1*b*b + (x1*y0 + x0*y1)*b + x0*y0
// = z2*b*b + z1*b + z0
//
// with:
//
// xd = x1 - x0
// yd = y0 - y1
//
// z1 = xd*yd + z2 + z0
// = (x1-x0)*(y0 - y1) + z2 + z0
// = x1*y0 - x1*y1 - x0*y0 + x0*y1 + z2 + z0
// = x1*y0 - z2 - z0 + x0*y1 + z2 + z0
// = x1*y0 + x0*y1
// split x, y into "digits"
n2 := n >> 1 // n2 >= 1
x1, x0 := x[n2:], x[0:n2] // x = x1*b + x0
y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0
// z is used for the result and temporary storage:
//
// 6*n 5*n 4*n 3*n 2*n 1*n 0*n
// z = [z2 copy|z0 copy| xd*yd | yd:xd | x1*y1 | x0*y0 ]
//
// For each recursive call of karatsuba, an unused slice of
// z is passed in that has (at least) half the length of the
// caller's z.
// compute z0 and z2 with the result "in place" in z
karatsuba(z, x0, y0) // z0 = x0*y0
karatsuba(z[n:], x1, y1) // z2 = x1*y1
// compute xd (or the negative value if underflow occurs)
s := 1 // sign of product xd*yd
xd := z[2*n : 2*n+n2]
if subVV(xd, x1, x0) != 0 { // x1-x0
s = -s
subVV(xd, x0, x1) // x0-x1
}
// compute yd (or the negative value if underflow occurs)
yd := z[2*n+n2 : 3*n]
if subVV(yd, y0, y1) != 0 { // y0-y1
s = -s
subVV(yd, y1, y0) // y1-y0
}
// p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0
// p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0
p := z[n*3:]
karatsuba(p, xd, yd)
// save original z2:z0
// (ok to use upper half of z since we're done recurring)
r := z[n*4:]
copy(r, z[:n*2])
// add up all partial products
//
// 2*n n 0
// z = [ z2 | z0 ]
// + [ z0 ]
// + [ z2 ]
// + [ p ]
//
karatsubaAdd(z[n2:], r, n)
karatsubaAdd(z[n2:], r[n:], n)
if s > 0 {
karatsubaAdd(z[n2:], p, n)
} else {
karatsubaSub(z[n2:], p, n)
}
}
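// The three-multiplication identity above can be checked directly on machine
// words. Illustrative sketch (not in the upstream source; assumes "fmt"),
// with b = 10^4 standing in for the word base:
//
//	var x, y int64 = 12345678, 87654321
//	const b = 10000
//	x1, x0 := x/b, x%b
//	y1, y0 := y/b, y%b
//	z2, z0 := x1*y1, x0*y0
//	z1 := (x1-x0)*(y0-y1) + z2 + z0    // == x1*y0 + x0*y1, using only 3 multiplies
//	fmt.Println(z2*b*b+z1*b+z0 == x*y) // true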
// alias reports whether x and y share the same base array.
//
// Note: alias assumes that the capacity of underlying arrays
// is never changed for nat values; i.e. that there are
// no 3-operand slice expressions in this code (or worse,
// reflect-based operations to the same effect).
func alias(x, y nat) bool {
return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1]
}
// addAt implements z += x<<(_W*i); z must be long enough.
// (we don't use nat.add because we need z to stay the same
// slice, and we don't need to normalize z after each addition)
func addAt(z, x nat, i int) {
if n := len(x); n > 0 {
if c := addVV(z[i:i+n], z[i:], x); c != 0 {
j := i + n
if j < len(z) {
addVW(z[j:], z[j:], c)
}
}
}
}
func max(x, y int) int {
if x > y {
return x
}
return y
}
// karatsubaLen computes an approximation to the maximum k <= n such that
// k = p<<i for a number p <= threshold and an i >= 0. Thus, the
// result is the largest number that can be divided repeatedly by 2 before
// becoming about the value of threshold.
func karatsubaLen(n, threshold int) int {
i := uint(0)
for n > threshold {
n >>= 1
i++
}
return n << i
}
func (z nat) mul(x, y nat) nat {
m := len(x)
n := len(y)
switch {
case m < n:
return z.mul(y, x)
case m == 0 || n == 0:
return z[:0]
case n == 1:
return z.mulAddWW(x, y[0], 0)
}
// m >= n > 1
// determine if z can be reused
if alias(z, x) || alias(z, y) {
z = nil // z is an alias for x or y - cannot reuse
}
// use basic multiplication if the numbers are small
if n < karatsubaThreshold {
z = z.make(m + n)
basicMul(z, x, y)
return z.norm()
}
// m >= n && n >= karatsubaThreshold && n >= 2
// determine Karatsuba length k such that
//
// x = xh*b + x0 (0 <= x0 < b)
// y = yh*b + y0 (0 <= y0 < b)
// b = 1<<(_W*k) ("base" of digits xi, yi)
//
k := karatsubaLen(n, karatsubaThreshold)
// k <= n
// multiply x0 and y0 via Karatsuba
x0 := x[0:k] // x0 is not normalized
y0 := y[0:k] // y0 is not normalized
z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y
karatsuba(z, x0, y0)
z = z[0 : m+n] // z has final length but may be incomplete
z[2*k:].clear() // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m)
// If xh != 0 or yh != 0, add the missing terms to z. For
//
// xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b)
// yh = y1*b (0 <= y1 < b)
//
// the missing terms are
//
// x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0
//
// since all the yi for i > 1 are 0 by choice of k: If any of them
// were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would
// be a larger valid threshold contradicting the assumption about k.
//
if k < n || m != n {
tp := getNat(3 * k)
t := *tp
// add x0*y1*b
x0 := x0.norm()
y1 := y[k:] // y1 is normalized because y is
t = t.mul(x0, y1) // update t so we don't lose t's underlying array
addAt(z, t, k)
// add xi*y0<<i, xi*y1*b<<(i+k)
y0 := y0.norm()
for i := k; i < len(x); i += k {
xi := x[i:]
if len(xi) > k {
xi = xi[:k]
}
xi = xi.norm()
t = t.mul(xi, y0)
addAt(z, t, i)
t = t.mul(xi, y1)
addAt(z, t, i+k)
}
putNat(tp)
}
return z.norm()
}
// basicSqr sets z = x*x and is asymptotically faster than basicMul
// by about a factor of 2, but slower for small arguments due to overhead.
// Requirements: len(x) > 0, len(z) == 2*len(x)
// The (non-normalized) result is placed in z.
func basicSqr(z, x nat) {
n := len(x)
tp := getNat(2 * n)
t := *tp // temporary variable to hold the products
t.clear()
z[1], z[0] = mulWW(x[0], x[0]) // the initial square
for i := 1; i < n; i++ {
d := x[i]
// z collects the squares x[i] * x[i]
z[2*i+1], z[2*i] = mulWW(d, d)
// t collects the products x[i] * x[j] where j < i
t[2*i] = addMulVVW(t[i:2*i], x[0:i], d)
}
t[2*n-1] = shlVU(t[1:2*n-1], t[1:2*n-1], 1) // double the j < i products
addVV(z, z, t) // combine the result
putNat(tp)
}
// karatsubaSqr squares x and leaves the result in z.
// len(x) must be a power of 2 and len(z) >= 6*len(x).
// The (non-normalized) result is placed in z[0 : 2*len(x)].
//
// The algorithm and the layout of z are the same as for karatsuba.
func karatsubaSqr(z, x nat) {
n := len(x)
if n&1 != 0 || n < karatsubaSqrThreshold || n < 2 {
basicSqr(z[:2*n], x)
return
}
n2 := n >> 1
x1, x0 := x[n2:], x[0:n2]
karatsubaSqr(z, x0)
karatsubaSqr(z[n:], x1)
// s = sign(xd*yd) == -1 for xd != 0; s == 1 for xd == 0
xd := z[2*n : 2*n+n2]
if subVV(xd, x1, x0) != 0 {
subVV(xd, x0, x1)
}
p := z[n*3:]
karatsubaSqr(p, xd)
r := z[n*4:]
copy(r, z[:n*2])
karatsubaAdd(z[n2:], r, n)
karatsubaAdd(z[n2:], r[n:], n)
karatsubaSub(z[n2:], p, n) // s == -1 for p != 0; s == 1 for p == 0
}
// Operands that are shorter than basicSqrThreshold are squared using
// "grade school" multiplication; for operands longer than karatsubaSqrThreshold
// we use the Karatsuba algorithm optimized for x == y.
var basicSqrThreshold = 20 // computed by calibrate_test.go
var karatsubaSqrThreshold = 260 // computed by calibrate_test.go
// z = x*x
func (z nat) sqr(x nat) nat {
n := len(x)
switch {
case n == 0:
return z[:0]
case n == 1:
d := x[0]
z = z.make(2)
z[1], z[0] = mulWW(d, d)
return z.norm()
}
if alias(z, x) {
z = nil // z is an alias for x - cannot reuse
}
if n < basicSqrThreshold {
z = z.make(2 * n)
basicMul(z, x, x)
return z.norm()
}
if n < karatsubaSqrThreshold {
z = z.make(2 * n)
basicSqr(z, x)
return z.norm()
}
// Use Karatsuba multiplication optimized for x == y.
// The algorithm and layout of z are the same as for mul.
// z = (x1*b + x0)^2 = x1^2*b^2 + 2*x1*x0*b + x0^2
k := karatsubaLen(n, karatsubaSqrThreshold)
x0 := x[0:k]
z = z.make(max(6*k, 2*n))
karatsubaSqr(z, x0) // z = x0^2
z = z[0 : 2*n]
z[2*k:].clear()
if k < n {
tp := getNat(2 * k)
t := *tp
x0 := x0.norm()
x1 := x[k:]
t = t.mul(x0, x1)
addAt(z, t, k)
addAt(z, t, k) // z = 2*x1*x0*b + x0^2
t = t.sqr(x1)
addAt(z, t, 2*k) // z = x1^2*b^2 + 2*x1*x0*b + x0^2
putNat(tp)
}
return z.norm()
}
// mulRange computes the product of all the unsigned integers in the
// range [a, b] inclusively. If a > b (empty range), the result is 1.
func (z nat) mulRange(a, b uint64) nat {
switch {
case a == 0:
// cut long ranges short (optimization)
return z.setUint64(0)
case a > b:
return z.setUint64(1)
case a == b:
return z.setUint64(a)
case a+1 == b:
return z.mul(nat(nil).setUint64(a), nat(nil).setUint64(b))
}
m := (a + b) / 2
return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b))
}
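// Illustrative sketch (not in the upstream source; assumes "fmt" and
// "math/big" imports): the exported MulRange wrapper computes factorials as
// a balanced product tree:
//
//	fmt.Println(new(big.Int).MulRange(1, 10)) // 3628800 (== 10!)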
// getNat returns a *nat of len n. The contents are not guaranteed to be zero.
// The pool holds *nat to avoid allocation when converting to interface{}.
func getNat(n int) *nat {
var z *nat
if v := natPool.Get(); v != nil {
z = v.(*nat)
}
if z == nil {
z = new(nat)
}
*z = z.make(n)
if n > 0 {
(*z)[0] = 0xfedcb // break code expecting zero
}
return z
}
func putNat(x *nat) {
natPool.Put(x)
}
var natPool sync.Pool
// bitLen returns the length of x in bits.
// Unlike most methods, it works even if x is not normalized.
func (x nat) bitLen() int {
// This function is used in cryptographic operations. It must not leak
// anything but the Int's sign and bit size through side-channels. Any
// changes must be reviewed by a security expert.
if i := len(x) - 1; i >= 0 {
// bits.Len uses a lookup table for the low-order bits on some
// architectures. Neutralize any input-dependent behavior by setting all
// bits after the first one bit.
top := uint(x[i])
top |= top >> 1
top |= top >> 2
top |= top >> 4
top |= top >> 8
top |= top >> 16
top |= top >> 16 >> 16 // ">> 32" doesn't compile on 32-bit architectures
return i*_W + bits.Len(top)
}
return 0
}
// trailingZeroBits returns the number of consecutive least significant zero
// bits of x.
func (x nat) trailingZeroBits() uint {
if len(x) == 0 {
return 0
}
var i uint
for x[i] == 0 {
i++
}
// x[i] != 0
return i*_W + uint(bits.TrailingZeros(uint(x[i])))
}
// isPow2 returns i, true when x == 2**i and 0, false otherwise.
func (x nat) isPow2() (uint, bool) {
var i uint
for x[i] == 0 {
i++
}
if i == uint(len(x))-1 && x[i]&(x[i]-1) == 0 {
return i*_W + uint(bits.TrailingZeros(uint(x[i]))), true
}
return 0, false
}
func same(x, y nat) bool {
return len(x) == len(y) && len(x) > 0 && &x[0] == &y[0]
}
// z = x << s
func (z nat) shl(x nat, s uint) nat {
if s == 0 {
if same(z, x) {
return z
}
if !alias(z, x) {
return z.set(x)
}
}
m := len(x)
if m == 0 {
return z[:0]
}
// m > 0
n := m + int(s/_W)
z = z.make(n + 1)
z[n] = shlVU(z[n-m:n], x, s%_W)
z[0 : n-m].clear()
return z.norm()
}
// z = x >> s
func (z nat) shr(x nat, s uint) nat {
if s == 0 {
if same(z, x) {
return z
}
if !alias(z, x) {
return z.set(x)
}
}
m := len(x)
n := m - int(s/_W)
if n <= 0 {
return z[:0]
}
// n > 0
z = z.make(n)
shrVU(z, x[m-n:], s%_W)
return z.norm()
}
func (z nat) setBit(x nat, i uint, b uint) nat {
j := int(i / _W)
m := Word(1) << (i % _W)
n := len(x)
switch b {
case 0:
z = z.make(n)
copy(z, x)
if j >= n {
// no need to grow
return z
}
z[j] &^= m
return z.norm()
case 1:
if j >= n {
z = z.make(j + 1)
z[n:].clear()
} else {
z = z.make(n)
}
copy(z, x)
z[j] |= m
// no need to normalize
return z
}
panic("set bit is not 0 or 1")
}
// bit returns the value of the i'th bit, with lsb == bit 0.
func (x nat) bit(i uint) uint {
j := i / _W
if j >= uint(len(x)) {
return 0
}
// 0 <= j < len(x)
return uint(x[j] >> (i % _W) & 1)
}
// sticky returns 1 if there's a 1 bit within the
// i least significant bits, otherwise it returns 0.
func (x nat) sticky(i uint) uint {
j := i / _W
if j >= uint(len(x)) {
if len(x) == 0 {
return 0
}
return 1
}
// 0 <= j < len(x)
for _, x := range x[:j] {
if x != 0 {
return 1
}
}
if x[j]<<(_W-i%_W) != 0 {
return 1
}
return 0
}
func (z nat) and(x, y nat) nat {
m := len(x)
n := len(y)
if m > n {
m = n
}
// m <= n
z = z.make(m)
for i := 0; i < m; i++ {
z[i] = x[i] & y[i]
}
return z.norm()
}
// trunc returns z = x mod 2ⁿ.
func (z nat) trunc(x nat, n uint) nat {
w := (n + _W - 1) / _W
if uint(len(x)) < w {
return z.set(x)
}
z = z.make(int(w))
copy(z, x)
if n%_W != 0 {
z[len(z)-1] &= 1<<(n%_W) - 1
}
return z.norm()
}
func (z nat) andNot(x, y nat) nat {
m := len(x)
n := len(y)
if n > m {
n = m
}
// m >= n
z = z.make(m)
for i := 0; i < n; i++ {
z[i] = x[i] &^ y[i]
}
copy(z[n:m], x[n:m])
return z.norm()
}
func (z nat) or(x, y nat) nat {
m := len(x)
n := len(y)
s := x
if m < n {
n, m = m, n
s = y
}
// m >= n
z = z.make(m)
for i := 0; i < n; i++ {
z[i] = x[i] | y[i]
}
copy(z[n:m], s[n:m])
return z.norm()
}
func (z nat) xor(x, y nat) nat {
m := len(x)
n := len(y)
s := x
if m < n {
n, m = m, n
s = y
}
// m >= n
z = z.make(m)
for i := 0; i < n; i++ {
z[i] = x[i] ^ y[i]
}
copy(z[n:m], s[n:m])
return z.norm()
}
// random creates a random integer in [0..limit), using the space in z if
// possible. n is the bit length of limit.
func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
if alias(z, limit) {
z = nil // z is an alias for limit - cannot reuse
}
z = z.make(len(limit))
bitLengthOfMSW := uint(n % _W)
if bitLengthOfMSW == 0 {
bitLengthOfMSW = _W
}
mask := Word((1 << bitLengthOfMSW) - 1)
for {
switch _W {
case 32:
for i := range z {
z[i] = Word(rand.Uint32())
}
case 64:
for i := range z {
z[i] = Word(rand.Uint32()) | Word(rand.Uint32())<<32
}
default:
panic("unknown word size")
}
z[len(limit)-1] &= mask
if z.cmp(limit) < 0 {
break
}
}
return z.norm()
}
// If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m;
// otherwise it sets z to x**y. The result is the value of z.
func (z nat) expNN(x, y, m nat, slow bool) nat {
if alias(z, x) || alias(z, y) {
// We cannot allow in-place modification of x or y.
z = nil
}
// x**y mod 1 == 0
if len(m) == 1 && m[0] == 1 {
return z.setWord(0)
}
// m == 0 || m > 1
// x**0 == 1
if len(y) == 0 {
return z.setWord(1)
}
// y > 0
// 0**y = 0
if len(x) == 0 {
return z.setWord(0)
}
// x > 0
// 1**y = 1
if len(x) == 1 && x[0] == 1 {
return z.setWord(1)
}
// x > 1
// x**1 == x
if len(y) == 1 && y[0] == 1 {
if len(m) != 0 {
return z.rem(x, m)
}
return z.set(x)
}
// y > 1
if len(m) != 0 {
// We likely end up being as long as the modulus.
z = z.make(len(m))
// If the exponent is large, we use the Montgomery method for odd values,
// and a 4-bit, windowed exponentiation for powers of two,
// and a CRT-decomposed Montgomery method for the remaining values
// (even values times non-trivial odd values, which decompose into one
// instance of each of the first two cases).
if len(y) > 1 && !slow {
if m[0]&1 == 1 {
return z.expNNMontgomery(x, y, m)
}
if logM, ok := m.isPow2(); ok {
return z.expNNWindowed(x, y, logM)
}
return z.expNNMontgomeryEven(x, y, m)
}
}
z = z.set(x)
v := y[len(y)-1] // v > 0 because y is normalized and y > 0
shift := nlz(v) + 1
v <<= shift
var q nat
const mask = 1 << (_W - 1)
// We walk through the bits of the exponent one by one. Each time we
// see a bit, we square, thus doubling the power. If the bit is a one,
// we also multiply by x, thus adding one to the power.
w := _W - int(shift)
// zz and r are used to avoid allocating in mul and div as
// otherwise the arguments would alias.
var zz, r nat
for j := 0; j < w; j++ {
zz = zz.sqr(z)
zz, z = z, zz
if v&mask != 0 {
zz = zz.mul(z, x)
zz, z = z, zz
}
if len(m) != 0 {
zz, r = zz.div(r, z, m)
zz, r, q, z = q, z, zz, r
}
v <<= 1
}
for i := len(y) - 2; i >= 0; i-- {
v = y[i]
for j := 0; j < _W; j++ {
zz = zz.sqr(z)
zz, z = z, zz
if v&mask != 0 {
zz = zz.mul(z, x)
zz, z = z, zz
}
if len(m) != 0 {
zz, r = zz.div(r, z, m)
zz, r, q, z = q, z, zz, r
}
v <<= 1
}
}
return z.norm()
}
// expNNMontgomeryEven calculates x**y mod m where m = m1 × m2 for m1 = 2ⁿ and m2 odd.
// It uses two recursive calls to expNN for x**y mod m1 and x**y mod m2
// and then uses the Chinese Remainder Theorem to combine the results.
// The recursive call using m1 will use expNNWindowed,
// while the recursive call using m2 will use expNNMontgomery.
// For more details, see Ç. K. Koç, “Montgomery Reduction with Even Modulus”,
// IEE Proceedings: Computers and Digital Techniques, 141(5) 314-316, September 1994.
// http://www.people.vcu.edu/~jwang3/CMSC691/j34monex.pdf
func (z nat) expNNMontgomeryEven(x, y, m nat) nat {
// Split m = m₁ × m₂ where m₁ = 2ⁿ
n := m.trailingZeroBits()
m1 := nat(nil).shl(natOne, n)
m2 := nat(nil).shr(m, n)
// We want z = x**y mod m.
// z₁ = x**y mod m1 = (x**y mod m) mod m1 = z mod m1
// z₂ = x**y mod m2 = (x**y mod m) mod m2 = z mod m2
// (We are using the math/big convention for names here,
// where the computation is z = x**y mod m, so its parts are z1 and z2.
// The paper is computing x = a**e mod n; it refers to these as x2 and z1.)
z1 := nat(nil).expNN(x, y, m1, false)
z2 := nat(nil).expNN(x, y, m2, false)
// Reconstruct z from z₁, z₂ using CRT, using algorithm from paper,
// which uses only a single modInverse (and an easy one at that).
// p = (z₁ - z₂) × m₂⁻¹ (mod m₁)
// z = z₂ + p × m₂
// The final addition is in range because:
// z = z₂ + p × m₂
// ≤ z₂ + (m₁-1) × m₂
// < m₂ + (m₁-1) × m₂
// = m₁ × m₂
// = m.
z = z.set(z2)
// Compute (z₁ - z₂) mod m1 [m1 == 2**n] into z1.
z1 = z1.subMod2N(z1, z2, n)
// Reuse z2 for p = (z₁ - z₂) [in z1] * m2⁻¹ (mod m₁ [= 2ⁿ]).
m2inv := nat(nil).modInverse(m2, m1)
z2 = z2.mul(z1, m2inv)
z2 = z2.trunc(z2, n)
// Reuse z1 for p * m2.
z = z.add(z, z1.mul(z2, m2))
return z
}
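// The CRT reconstruction above can be sanity-checked with small numbers.
// Illustrative sketch (not in the upstream source; assumes "fmt"), with
// m1 = 2^4 = 16 and m2 = 9, recovering z = 100 from its two residues:
//
//	const m1, m2 = 16, 9
//	z1, z2 := 100%m1, 100%m2 // 4, 1
//	const m2inv = 9          // 9*9 == 81 ≡ 1 (mod 16)
//	p := ((z1-z2)*m2inv%m1 + m1) % m1
//	fmt.Println(z2 + p*m2) // 100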
// expNNWindowed calculates x**y mod m using a fixed, 4-bit window,
// where m = 2**logM.
func (z nat) expNNWindowed(x, y nat, logM uint) nat {
if len(y) <= 1 {
panic("big: misuse of expNNWindowed")
}
if x[0]&1 == 0 {
// len(y) > 1, so y > logM.
// x is even, so x**y is a multiple of 2**y which is a multiple of 2**logM.
return z.setWord(0)
}
if logM == 1 {
return z.setWord(1)
}
// zz is used to avoid allocating in mul as otherwise
// the arguments would alias.
w := int((logM + _W - 1) / _W)
zzp := getNat(w)
zz := *zzp
const n = 4
// powers[i] contains x^i.
var powers [1 << n]*nat
for i := range powers {
powers[i] = getNat(w)
}
*powers[0] = powers[0].set(natOne)
*powers[1] = powers[1].trunc(x, logM)
for i := 2; i < 1<<n; i += 2 {
p2, p, p1 := powers[i/2], powers[i], powers[i+1]
*p = p.sqr(*p2)
*p = p.trunc(*p, logM)
*p1 = p1.mul(*p, x)
*p1 = p1.trunc(*p1, logM)
}
// Because phi(2**logM) = 2**(logM-1), x**(2**(logM-1)) = 1,
// so we can compute x**(y mod 2**(logM-1)) instead of x**y.
// That is, we can throw away all but the bottom logM-1 bits of y.
// Instead of allocating a new y, we start reading y at the right word
// and truncate it appropriately at the start of the loop.
i := len(y) - 1
mtop := int((logM - 2) / _W) // -2 because the top word of N bits is the (N-1)/W'th word.
mmask := ^Word(0)
if mbits := (logM - 1) & (_W - 1); mbits != 0 {
mmask = (1 << mbits) - 1
}
if i > mtop {
i = mtop
}
advance := false
z = z.setWord(1)
for ; i >= 0; i-- {
yi := y[i]
if i == mtop {
yi &= mmask
}
for j := 0; j < _W; j += n {
if advance {
// Account for use of 4 bits in previous iteration.
// Unrolled loop for significant performance
// gain. Use go test -bench=".*" in crypto/rsa
// to check performance before making changes.
zz = zz.sqr(z)
zz, z = z, zz
z = z.trunc(z, logM)
zz = zz.sqr(z)
zz, z = z, zz
z = z.trunc(z, logM)
zz = zz.sqr(z)
zz, z = z, zz
z = z.trunc(z, logM)
zz = zz.sqr(z)
zz, z = z, zz
z = z.trunc(z, logM)
}
zz = zz.mul(z, *powers[yi>>(_W-n)])
zz, z = z, zz
z = z.trunc(z, logM)
yi <<= n
advance = true
}
}
*zzp = zz
putNat(zzp)
for i := range powers {
putNat(powers[i])
}
return z.norm()
}
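// Illustrative example of the window scan above, not from the original
// source: for a window word y = 0xA3, the loop computes
//
//	z = 1·x^0xA = x^10          (first window; advance is still false)
//	z = (x^10)^16 · x^3 = x^163 (four squarings, then one multiply)
//
// and 0xA3 = 163, so z = x^y mod 2**logM as required. Each 4-bit window
// costs four squarings plus one multiply by a precomputed power.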
// expNNMontgomery calculates x**y mod m using a fixed, 4-bit window.
// Uses Montgomery representation.
func (z nat) expNNMontgomery(x, y, m nat) nat {
numWords := len(m)
// We want the lengths of x and m to be equal.
// It is OK if x >= m as long as len(x) == len(m).
if len(x) > numWords {
_, x = nat(nil).div(nil, x, m)
// Note: now len(x) <= numWords, not guaranteed ==.
}
if len(x) < numWords {
rr := make(nat, numWords)
copy(rr, x)
x = rr
}
// Ideally the precomputations would be performed outside, and reused
// k0 = -m**-1 mod 2**_W. Algorithm from: Dumas, J.G. "On Newton–Raphson
// Iteration for Multiplicative Inverses Modulo Prime Powers".
k0 := 2 - m[0]
t := m[0] - 1
for i := 1; i < _W; i <<= 1 {
t *= t
k0 *= (t + 1)
}
k0 = -k0
// RR = 2**(2*_W*len(m)) mod m
RR := nat(nil).setWord(1)
zz := nat(nil).shl(RR, uint(2*numWords*_W))
_, RR = nat(nil).div(RR, zz, m)
if len(RR) < numWords {
zz = zz.make(numWords)
copy(zz, RR)
RR = zz
}
// one = 1, with equal length to that of m
one := make(nat, numWords)
one[0] = 1
const n = 4
// powers[i] contains x^i
var powers [1 << n]nat
powers[0] = powers[0].montgomery(one, RR, m, k0, numWords)
powers[1] = powers[1].montgomery(x, RR, m, k0, numWords)
for i := 2; i < 1<<n; i++ {
powers[i] = powers[i].montgomery(powers[i-1], powers[1], m, k0, numWords)
}
// initialize z = 1 (Montgomery 1)
z = z.make(numWords)
copy(z, powers[0])
zz = zz.make(numWords)
// same windowed exponent, but with Montgomery multiplications
for i := len(y) - 1; i >= 0; i-- {
yi := y[i]
for j := 0; j < _W; j += n {
if i != len(y)-1 || j != 0 {
zz = zz.montgomery(z, z, m, k0, numWords)
z = z.montgomery(zz, zz, m, k0, numWords)
zz = zz.montgomery(z, z, m, k0, numWords)
z = z.montgomery(zz, zz, m, k0, numWords)
}
zz = zz.montgomery(z, powers[yi>>(_W-n)], m, k0, numWords)
z, zz = zz, z
yi <<= n
}
}
// convert to regular number
zz = zz.montgomery(z, one, m, k0, numWords)
// One last reduction, just in case.
// See golang.org/issue/13907.
if zz.cmp(m) >= 0 {
// Common case is m has high bit set; in that case,
// since zz is the same length as m, there can be just
// one multiple of m to remove. Just subtract.
// We think that the subtract should be sufficient in general,
// so do that unconditionally, but double-check,
// in case our beliefs are wrong.
// The div is not expected to be reached.
zz = zz.sub(zz, m)
if zz.cmp(m) >= 0 {
_, zz = nat(nil).div(nil, zz, m)
}
}
return zz.norm()
}
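// Illustrative example of Montgomery form, not from the original source
// (it pretends a single 4-bit word for readability): with m = 13, R = 2⁴ = 16,
// and mont(a, b) = a·b·R⁻¹ mod m, we get RR = R² mod m = 256 mod 13 = 9, and
//
//	powers[0] = mont(1, RR) = 1·9·R⁻¹ mod 13 = 3 (= R mod 13, Montgomery 1)
//	powers[1] = mont(7, RR) = 7·16 mod 13 = 8    (x = 7 in Montgomery form)
//
// All intermediate products stay in Montgomery form; the final mont(z, one)
// above multiplies by R⁻¹ once more to convert back to a regular number.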
// bytes writes the value of z into buf using big-endian encoding.
// The value of z is encoded in the slice buf[i:]. If the value of z
// cannot be represented in buf, bytes panics. The number i of unused
// bytes at the beginning of buf is returned as result.
func (z nat) bytes(buf []byte) (i int) {
// This function is used in cryptographic operations. It must not leak
// anything but the Int's sign and bit size through side-channels. Any
// changes must be reviewed by a security expert.
i = len(buf)
for _, d := range z {
for j := 0; j < _S; j++ {
i--
if i >= 0 {
buf[i] = byte(d)
} else if byte(d) != 0 {
panic("math/big: buffer too small to fit value")
}
d >>= 8
}
}
if i < 0 {
i = 0
}
for i < len(buf) && buf[i] == 0 {
i++
}
return
}
// bigEndianWord returns the contents of buf interpreted as a big-endian encoded Word value.
func bigEndianWord(buf []byte) Word {
if _W == 64 {
return Word(binary.BigEndian.Uint64(buf))
}
return Word(binary.BigEndian.Uint32(buf))
}
// setBytes interprets buf as the bytes of a big-endian unsigned
// integer, sets z to that value, and returns z.
func (z nat) setBytes(buf []byte) nat {
z = z.make((len(buf) + _S - 1) / _S)
i := len(buf)
for k := 0; i >= _S; k++ {
z[k] = bigEndianWord(buf[i-_S : i])
i -= _S
}
if i > 0 {
var d Word
for s := uint(0); i > 0; s += 8 {
d |= Word(buf[i-1]) << s
i--
}
z[len(z)-1] = d
}
return z.norm()
}
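// Illustrative example, not from the original source: on a 64-bit system
// (_S = 8), setBytes([]byte{0x01, 0x02, 0x03}) finds no full 8-byte group,
// so the leftover loop assembles d = 0x03 | 0x02<<8 | 0x01<<16 and returns
// nat{0x010203}. The bytes method above is the inverse of this encoding.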
// sqrt sets z = ⌊√x⌋
func (z nat) sqrt(x nat) nat {
if x.cmp(natOne) <= 0 {
return z.set(x)
}
if alias(z, x) {
z = nil
}
// Start with value known to be too large and repeat "z = ⌊(z + ⌊x/z⌋)/2⌋" until it stops getting smaller.
// See Brent and Zimmermann, Modern Computer Arithmetic, Algorithm 1.13 (SqrtInt).
// https://members.loria.fr/PZimmermann/mca/pub226.html
// If x is one less than a perfect square, the sequence oscillates between the correct z and z+1;
// otherwise it converges to the correct z and stays there.
var z1, z2 nat
z1 = z
z1 = z1.setUint64(1)
z1 = z1.shl(z1, uint(x.bitLen()+1)/2) // must be ≥ √x
for n := 0; ; n++ {
z2, _ = z2.div(nil, x, z1)
z2 = z2.add(z2, z1)
z2 = z2.shr(z2, 1)
if z2.cmp(z1) >= 0 {
// z1 is answer.
// Figure out whether z1 or z2 is currently aliased to z by looking at loop count.
if n&1 == 0 {
return z1
}
return z.set(z1)
}
z1, z2 = z2, z1
}
}
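// isqrtExample is an illustrative sketch, not part of the original source
// (the name and the plain uint64 arithmetic are hypothetical): it mirrors
// the iteration above. For x = 99 it starts at z1 = 1<<4 = 16 and steps
// 16 → 11 → 10 → 9; the next value is 10 ≥ 9, so it returns ⌊√99⌋ = 9.
func isqrtExample(x uint64) uint64 {
	if x <= 1 {
		return x
	}
	z1 := uint64(1) << ((uint(bits.Len64(x)) + 1) / 2) // must be ≥ √x
	for {
		z2 := (z1 + x/z1) / 2 // z2 = ⌊(z1 + ⌊x/z1⌋)/2⌋
		if z2 >= z1 {
			return z1 // stopped getting smaller: z1 is the answer
		}
		z1 = z2
	}
}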
// subMod2N returns z = (x - y) mod 2ⁿ.
func (z nat) subMod2N(x, y nat, n uint) nat {
if uint(x.bitLen()) > n {
if alias(z, x) {
// ok to overwrite x in place
x = x.trunc(x, n)
} else {
x = nat(nil).trunc(x, n)
}
}
if uint(y.bitLen()) > n {
if alias(z, y) {
// ok to overwrite y in place
y = y.trunc(y, n)
} else {
y = nat(nil).trunc(y, n)
}
}
if x.cmp(y) >= 0 {
return z.sub(x, y)
}
// x - y < 0; x - y mod 2ⁿ = x - y + 2ⁿ = 2ⁿ - (y - x) = 1 + 2ⁿ-1 - (y - x) = 1 + ^(y - x).
z = z.sub(y, x)
for uint(len(z))*_W < n {
z = append(z, 0)
}
for i := range z {
z[i] = ^z[i]
}
z = z.trunc(z, n)
return z.add(z, natOne)
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements nat-to-string conversion functions.
package big
import (
"errors"
"fmt"
"io"
"math"
"math/bits"
"sync"
)
const digits = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
// Note: MaxBase = len(digits), but it must remain an untyped rune constant
// for API compatibility.
// MaxBase is the largest number base accepted for string conversions.
const MaxBase = 10 + ('z' - 'a' + 1) + ('Z' - 'A' + 1)
const maxBaseSmall = 10 + ('z' - 'a' + 1)
// maxPow returns (b**n, n) such that b**n is the largest power b**n <= _M.
// For instance maxPow(10) == (1e19, 19) for 19 decimal digits in a 64-bit Word.
// In other words, at most n digits in base b fit into a Word.
// TODO(gri) replace this with a table, generated at build time.
func maxPow(b Word) (p Word, n int) {
p, n = b, 1 // assuming b <= _M
for max := _M / b; p <= max; {
// p == b**n && p <= max
p *= b
n++
}
// p == b**n && p <= _M
return
}
// pow returns x**n for n > 0, and 1 otherwise.
func pow(x Word, n int) (p Word) {
// n == sum of bi * 2**i, for 0 <= i < imax, and bi is 0 or 1
// thus x**n == product of x**(2**i) for all i where bi == 1
// (Russian Peasant Method for exponentiation)
p = 1
for n > 0 {
if n&1 != 0 {
p *= x
}
x *= x
n >>= 1
}
return
}
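// Illustrative example, not from the original source: pow(3, 5) with
// 5 = 101₂ proceeds as
//
//	bit 0 set:   p = 1·3 = 3;   x = 3² = 9
//	bit 1 clear: p unchanged;   x = 9² = 81
//	bit 2 set:   p = 3·81 = 243
//
// giving 3⁵ = 243 with three squarings and two multiplies.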
// scan errors
var (
errNoDigits = errors.New("number has no digits")
errInvalSep = errors.New("'_' must separate successive digits")
)
// scan scans the number corresponding to the longest possible prefix
// from r representing an unsigned number in a given conversion base.
// scan returns the corresponding natural number res, the actual base b,
// a digit count, and a read or syntax error err, if any.
//
// For base 0, an underscore character “_” may appear between a base
// prefix and an adjacent digit, and between successive digits; such
// underscores do not change the value of the number, or the returned
// digit count. Incorrect placement of underscores is reported as an
// error if there are no other errors. If base != 0, underscores are
// not recognized and thus terminate scanning like any other character
// that is not a valid radix point or digit.
//
// number = mantissa | prefix pmantissa .
// prefix = "0" [ "b" | "B" | "o" | "O" | "x" | "X" ] .
// mantissa = digits "." [ digits ] | digits | "." digits .
// pmantissa = [ "_" ] digits "." [ digits ] | [ "_" ] digits | "." digits .
// digits = digit { [ "_" ] digit } .
// digit = "0" ... "9" | "a" ... "z" | "A" ... "Z" .
//
// Unless fracOk is set, the base argument must be 0 or a value between
// 2 and MaxBase. If fracOk is set, the base argument must be one of
// 0, 2, 8, 10, or 16. Providing an invalid base argument leads to a run-
// time panic.
//
// For base 0, the number prefix determines the actual base: A prefix of
// “0b” or “0B” selects base 2, “0o” or “0O” selects base 8, and
// “0x” or “0X” selects base 16. If fracOk is false, a “0” prefix
// (immediately followed by digits) selects base 8 as well. Otherwise,
// the selected base is 10 and no prefix is accepted.
//
// If fracOk is set, a period followed by a fractional part is permitted.
// The result value is computed as if there were no period present; and
// the count value is used to determine the fractional part.
//
// For bases <= 36, lower and upper case letters are considered the same:
// The letters 'a' to 'z' and 'A' to 'Z' represent digit values 10 to 35.
// For bases > 36, the upper case letters 'A' to 'Z' represent the digit
// values 36 to 61.
//
// A result digit count > 0 corresponds to the number of (non-prefix) digits
// parsed. A digit count <= 0 indicates the presence of a period (if fracOk
// is set, only), and -count is the number of fractional digits found.
// In this case, the actual value of the scanned number is res * b**count.
func (z nat) scan(r io.ByteScanner, base int, fracOk bool) (res nat, b, count int, err error) {
// reject invalid bases
baseOk := base == 0 ||
!fracOk && 2 <= base && base <= MaxBase ||
fracOk && (base == 2 || base == 8 || base == 10 || base == 16)
if !baseOk {
panic(fmt.Sprintf("invalid number base %d", base))
}
// prev encodes the previously seen char: it is one
// of '_', '0' (a digit), or '.' (anything else). A
// valid separator '_' may only occur after a digit
// and if base == 0.
prev := '.'
invalSep := false
// one char look-ahead
ch, err := r.ReadByte()
// determine actual base
b, prefix := base, 0
if base == 0 {
// actual base is 10 unless there's a base prefix
b = 10
if err == nil && ch == '0' {
prev = '0'
count = 1
ch, err = r.ReadByte()
if err == nil {
// possibly one of 0b, 0B, 0o, 0O, 0x, 0X
switch ch {
case 'b', 'B':
b, prefix = 2, 'b'
case 'o', 'O':
b, prefix = 8, 'o'
case 'x', 'X':
b, prefix = 16, 'x'
default:
if !fracOk {
b, prefix = 8, '0'
}
}
if prefix != 0 {
count = 0 // prefix is not counted
if prefix != '0' {
ch, err = r.ReadByte()
}
}
}
}
}
// convert string
// Algorithm: Collect digits in groups of at most n digits in di
// and then use mulAddWW for every such group to add them to the
// result.
z = z[:0]
b1 := Word(b)
bn, n := maxPow(b1) // at most n digits in base b1 fit into Word
di := Word(0) // 0 <= di < b1**i < bn
i := 0 // 0 <= i < n
dp := -1 // position of decimal point
for err == nil {
if ch == '.' && fracOk {
fracOk = false
if prev == '_' {
invalSep = true
}
prev = '.'
dp = count
} else if ch == '_' && base == 0 {
if prev != '0' {
invalSep = true
}
prev = '_'
} else {
// convert rune into digit value d1
var d1 Word
switch {
case '0' <= ch && ch <= '9':
d1 = Word(ch - '0')
case 'a' <= ch && ch <= 'z':
d1 = Word(ch - 'a' + 10)
case 'A' <= ch && ch <= 'Z':
if b <= maxBaseSmall {
d1 = Word(ch - 'A' + 10)
} else {
d1 = Word(ch - 'A' + maxBaseSmall)
}
default:
d1 = MaxBase + 1
}
if d1 >= b1 {
r.UnreadByte() // ch does not belong to number anymore
break
}
prev = '0'
count++
// collect d1 in di
di = di*b1 + d1
i++
// if di is "full", add it to the result
if i == n {
z = z.mulAddWW(z, bn, di)
di = 0
i = 0
}
}
ch, err = r.ReadByte()
}
if err == io.EOF {
err = nil
}
// other errors take precedence over invalid separators
if err == nil && (invalSep || prev == '_') {
err = errInvalSep
}
if count == 0 {
// no digits found
if prefix == '0' {
// there was only the octal prefix 0 (possibly followed by separators and digits > 7);
// interpret as decimal 0
return z[:0], 10, 1, err
}
err = errNoDigits // fall through; result will be 0
}
// add remaining digits to result
if i > 0 {
z = z.mulAddWW(z, pow(b1, i), di)
}
res = z.norm()
// adjust count for fraction, if any
if dp >= 0 {
// 0 <= dp <= count
count = dp - count
}
return
}
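// Illustrative example of the group-wise accumulation above, not from the
// original source: in base 10 on 64-bit words, maxPow(10) = (10¹⁹, 19), so
// scanning a 25-digit number first fills di with 19 digits and flushes
// z = z·10¹⁹ + di, then collects the remaining 6 digits and finishes with
// z = z·10⁶ + di via the pow(b1, i) call above: two mulAddWW calls instead
// of 25.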
// utoa converts x to an ASCII representation in the given base;
// base must be between 2 and MaxBase, inclusive.
func (x nat) utoa(base int) []byte {
return x.itoa(false, base)
}
// itoa is like utoa but it prepends a '-' if neg && x != 0.
func (x nat) itoa(neg bool, base int) []byte {
if base < 2 || base > MaxBase {
panic("invalid base")
}
// x == 0
if len(x) == 0 {
return []byte("0")
}
// len(x) > 0
// allocate buffer for conversion
i := int(float64(x.bitLen())/math.Log2(float64(base))) + 1 // off by 1 at most
if neg {
i++
}
s := make([]byte, i)
// convert power of two and non power of two bases separately
if b := Word(base); b == b&-b {
// shift is base b digit size in bits
shift := uint(bits.TrailingZeros(uint(b))) // shift > 0 because b >= 2
mask := Word(1<<shift - 1)
w := x[0] // current word
nbits := uint(_W) // number of unprocessed bits in w
// convert less-significant words (include leading zeros)
for k := 1; k < len(x); k++ {
// convert full digits
for nbits >= shift {
i--
s[i] = digits[w&mask]
w >>= shift
nbits -= shift
}
// convert any partial leading digit and advance to next word
if nbits == 0 {
// no partial digit remaining, just advance
w = x[k]
nbits = _W
} else {
// partial digit in current word w (== x[k-1]) and next word x[k]
w |= x[k] << nbits
i--
s[i] = digits[w&mask]
// advance
w = x[k] >> (shift - nbits)
nbits = _W - (shift - nbits)
}
}
// convert digits of most-significant word w (omit leading zeros)
for w != 0 {
i--
s[i] = digits[w&mask]
w >>= shift
}
} else {
bb, ndigits := maxPow(b)
// construct table of successive squares of bb**leafSize to use in subdivisions
// result (table != nil) <=> (len(x) > leafSize > 0)
table := divisors(len(x), b, ndigits, bb)
// preserve x, create local copy for use by convertWords
q := nat(nil).set(x)
// convert q to string s in base b
q.convertWords(s, b, ndigits, bb, table)
// strip leading zeros
// (x != 0; thus s must contain at least one non-zero digit
// and the loop will terminate)
i = 0
for s[i] == '0' {
i++
}
}
if neg {
i--
s[i] = '-'
}
return s[i:]
}
// Convert words of q to base b digits in s. If q is large, it is recursively "split in half"
// by nat/nat division using tabulated divisors. Otherwise, it is converted iteratively using
// repeated nat/Word division.
//
// The iterative method processes n Words by n divW() calls, each of which visits every Word in the
// incrementally shortened q for a total of n + (n-1) + (n-2) ... + 2 + 1, or n(n+1)/2 divW()'s.
// Recursive conversion divides q by its approximate square root, yielding two parts, each half
// the size of q. Using the iterative method on both halves means 2 * (n/2)(n/2 + 1)/2 divW()'s
// plus the expensive long div(). Asymptotically, the ratio is favorable at 1/2 the divW()'s, and
// is made better by splitting the subblocks recursively. Best is to split blocks until one more
// split would take longer (because of the nat/nat div()) than the twice as many divW()'s of the
// iterative approach. This threshold is represented by leafSize. Benchmarking of leafSize in the
// range 2..64 shows that values of 8 and 16 work well, with a 4x speedup at medium lengths and
// ~30x for 20000 digits. Use nat_test.go's BenchmarkLeafSize tests to optimize leafSize for
// specific hardware.
func (q nat) convertWords(s []byte, b Word, ndigits int, bb Word, table []divisor) {
// split larger blocks recursively
if table != nil {
// len(q) > leafSize > 0
var r nat
index := len(table) - 1
for len(q) > leafSize {
// find divisor close to sqrt(q) if possible, but in any case < q
maxLength := q.bitLen() // ~= log2 q, or at least that of the largest possible q of this bit length
minLength := maxLength >> 1 // ~= log2 sqrt(q)
for index > 0 && table[index-1].nbits > minLength {
index-- // desired
}
if table[index].nbits >= maxLength && table[index].bbb.cmp(q) >= 0 {
index--
if index < 0 {
panic("internal inconsistency")
}
}
// split q into the two digit number (q'*bbb + r) to form independent subblocks
q, r = q.div(r, q, table[index].bbb)
// convert subblocks and collect results in s[:h] and s[h:]
h := len(s) - table[index].ndigits
r.convertWords(s[h:], b, ndigits, bb, table[0:index])
s = s[:h] // == q.convertWords(s, b, ndigits, bb, table[0:index+1])
}
}
// having split any large blocks now process the remaining (small) block iteratively
i := len(s)
var r Word
if b == 10 {
// hard-coding for 10 here speeds this up by 1.25x (allows for / and % by constants)
for len(q) > 0 {
// extract least significant, base bb "digit"
q, r = q.divW(q, bb)
for j := 0; j < ndigits && i > 0; j++ {
i--
// avoid % computation since r%10 == r - int(r/10)*10;
// this appears to be faster for BenchmarkString10000Base10
// and smaller strings (but a bit slower for larger ones)
t := r / 10
s[i] = '0' + byte(r-t*10)
r = t
}
}
} else {
for len(q) > 0 {
// extract least significant, base bb "digit"
q, r = q.divW(q, bb)
for j := 0; j < ndigits && i > 0; j++ {
i--
s[i] = digits[r%b]
r /= b
}
}
}
// prepend high-order zeros
for i > 0 { // while need more leading zeros
i--
s[i] = '0'
}
}
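// Illustrative trace of the iterative path above, not from the original
// source (it pretends bb = 10⁴ and ndigits = 4 for readability): converting
// q = 123456789 performs
//
//	divW: q = 12345, r = 6789 → "6789"
//	divW: q = 1,     r = 2345 → "2345"
//	divW: q = 0,     r = 1    → "0001"
//
// writing groups right to left into s; itoa then strips the leading zeros,
// yielding "123456789".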
// Split blocks greater than leafSize Words (or set to 0 to disable recursive conversion)
// Benchmark and configure leafSize using: go test -bench="Leaf"
//
// 8 and 16 effective on 3.0 GHz Xeon "Clovertown" CPU (128 byte cache lines)
// 8 and 16 effective on 2.66 GHz Core 2 Duo "Penryn" CPU
var leafSize int = 8 // number of Word-size binary values treated as a monolithic block
type divisor struct {
bbb nat // divisor
nbits int // bit length of divisor (discounting leading zeros) ~= log2(bbb)
ndigits int // digit length of divisor in terms of output base digits
}
var cacheBase10 struct {
sync.Mutex
table [64]divisor // cached divisors for base 10
}
// expWW computes x**y
func (z nat) expWW(x, y Word) nat {
return z.expNN(nat(nil).setWord(x), nat(nil).setWord(y), nil, false)
}
// construct table of successive squares of bb**leafSize to use in subdivisions.
func divisors(m int, b Word, ndigits int, bb Word) []divisor {
// only compute table when recursive conversion is enabled and x is large
if leafSize == 0 || m <= leafSize {
return nil
}
// determine k where (bb**leafSize)**(2**k) >= sqrt(x)
k := 1
for words := leafSize; words < m>>1 && k < len(cacheBase10.table); words <<= 1 {
k++
}
// reuse and extend existing table of divisors or create new table as appropriate
var table []divisor // for b == 10, table overlaps with cacheBase10.table
if b == 10 {
cacheBase10.Lock()
table = cacheBase10.table[0:k] // reuse old table for this conversion
} else {
table = make([]divisor, k) // create new table for this conversion
}
// extend table
if table[k-1].ndigits == 0 {
// add new entries as needed
var larger nat
for i := 0; i < k; i++ {
if table[i].ndigits == 0 {
if i == 0 {
table[0].bbb = nat(nil).expWW(bb, Word(leafSize))
table[0].ndigits = ndigits * leafSize
} else {
table[i].bbb = nat(nil).sqr(table[i-1].bbb)
table[i].ndigits = 2 * table[i-1].ndigits
}
// optimization: exploit aggregated extra bits in macro blocks
larger = nat(nil).set(table[i].bbb)
for mulAddVWW(larger, larger, b, 0) == 0 {
table[i].bbb = table[i].bbb.set(larger)
table[i].ndigits++
}
table[i].nbits = table[i].bbb.bitLen()
}
}
}
if b == 10 {
cacheBase10.Unlock()
}
return table
}
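// Illustrative example, not from the original source: for b = 10 on 64-bit
// words (ndigits = 19, bb = 10¹⁹) with leafSize = 8, table[0].bbb starts as
// (10¹⁹)⁸ = 10¹⁵² and each following entry squares the previous one, so
// table[i].bbb ≈ 10^(152·2ⁱ). The mulAddVWW loop then multiplies in extra
// factors of b while the value still fits in the same number of Words.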
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Multi-precision division. Here be dragons.
Given u and v, where u is n+m digits, and v is n digits (with no leading zeros),
the goal is to return quo, rem such that u = quo*v + rem, where 0 ≤ rem < v.
That is, quo = ⌊u/v⌋ where ⌊x⌋ denotes the floor (truncation to integer) of x,
and rem = u - quo·v.
Long Division
Division in a computer proceeds the same as long division in elementary school,
but computers are not as good as schoolchildren at following vague directions,
so we have to be much more precise about the actual steps and what can happen.
We work from most to least significant digit of the quotient, doing:
• Guess a digit q, the number of v to subtract from the current
section of u to zero out the topmost digit.
• Check the guess by multiplying q·v and comparing it against
the current section of u, adjusting the guess as needed.
• Subtract q·v from the current section of u.
• Add q to the corresponding section of the result quo.
When all digits have been processed, the final remainder is left in u
and returned as rem.
For example, here is a sketch of dividing 5 digits by 3 digits (n=3, m=2).
q₂ q₁ q₀
_________________
v₂ v₁ v₀ ) u₄ u₃ u₂ u₁ u₀
↓ ↓ ↓ | |
[u₄ u₃ u₂]| |
- [ q₂·v ]| |
----------- ↓ |
[ rem | u₁]|
- [ q₁·v ]|
----------- ↓
[ rem | u₀]
- [ q₀·v ]
------------
[ rem ]
Instead of creating new storage for the remainders and copying digits from u
as indicated by the arrows, we use u's storage directly as both the source
and destination of the subtractions, so that the remainders overwrite
successive overlapping sections of u as the division proceeds, using a slice
of u to identify the current section. This avoids all the copying as well as
shifting of remainders.
Division of u with n+m digits by v with n digits (in base B) can in general
produce at most m+1 digits, because:
• u < B^(n+m) [B^(n+m) has n+m+1 digits]
• v ≥ B^(n-1) [B^(n-1) is the smallest n-digit number]
• u/v < B^(n+m) / B^(n-1) [divide bounds for u, v]
• u/v < B^(m+1) [simplify]
The first step is special: it takes the top n digits of u and divides them by
the n digits of v, producing the first quotient digit and an n-digit remainder.
In the example, q₂ = ⌊u₄u₃u₂ / v⌋.
The first step divides n digits by n digits to ensure that it produces only a
single digit.
Each subsequent step appends the next digit from u to the remainder and divides
those n+1 digits by the n digits of v, producing another quotient digit and a
new n-digit remainder.
Subsequent steps divide n+1 digits by n digits, an operation that in general
might produce two digits. However, as used in the algorithm, that division is
guaranteed to produce only a single digit. The dividend is of the form
rem·B + d, where rem is a remainder from the previous step and d is a single
digit, so:
• rem ≤ v - 1 [rem is a remainder from dividing by v]
• rem·B ≤ v·B - B [multiply by B]
• d ≤ B - 1 [d is a single digit]
• rem·B + d ≤ v·B - 1 [add]
• rem·B + d < v·B [change ≤ to <]
• (rem·B + d)/v < B [divide by v]
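For a concrete instance (an added example) in base B = 10 with v = 57:
rem ≤ 56, so rem·B + d ≤ 560 + 9 = 569 < 570 = v·B, and the next quotient
digit indeed fits in a single digit.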
Guess and Check
At each step we need to divide n+1 digits by n digits, but this is for the
implementation of division by n digits, so we can't just invoke a division
routine: we _are_ the division routine. Instead, we guess at the answer and
then check it using multiplication. If the guess is wrong, we correct it.
How can this guessing possibly be efficient? It turns out that the following
statement (let's call it the Good Guess Guarantee) is true.
If
• q = ⌊u/v⌋ where u is n+1 digits and v is n digits,
• q < B, and
• the topmost digit of v = vₙ₋₁ ≥ B/2,
then q̂ = ⌊uₙuₙ₋₁ / vₙ₋₁⌋ satisfies q ≤ q̂ ≤ q+2. (Proof below.)
That is, if we know the answer has only a single digit and we guess an answer
by ignoring the bottom n-1 digits of u and v, using a 2-by-1-digit division,
then that guess is at least as large as the correct answer. It is also not
too much larger: it is off by at most two from the correct answer.
Note that in the first step of the overall division, which is an n-by-n-digit
division, the 2-by-1 guess uses an implicit uₙ = 0.
Note that using a 2-by-1-digit division here does not mean calling ourselves
recursively. Instead, we use an efficient direct hardware implementation of
that operation.
Note that because q is u/v rounded down, q·v must not exceed u: u ≥ q·v.
If a guess q̂ is too big, it will not satisfy this test. Viewed a different way,
the remainder r̂ for a given q̂ is u - q̂·v, which must be positive. If it is
negative, then the guess q̂ is too big.
This gives us a way to compute q. First compute q̂ with 2-by-1-digit division.
Then, while u < q̂·v, decrement q̂; this loop executes at most twice, because
q̂ ≤ q+2.
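As a worked example (added for illustration) in base 10: dividing u = 509 by
v = 51 (so vₙ₋₁ = 5 ≥ B/2), the correct digit is q = ⌊509/51⌋ = 9, while the
guess is q̂ = ⌊50/5⌋ = 10. The check finds 509 < 10·51 = 510, so q̂ is
decremented once to 9, and 9·51 = 459 ≤ 509 confirms the final digit.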
Scaling Inputs
The Good Guess Guarantee requires that the top digit of v (vₙ₋₁) be at least B/2.
For example in base 10, ⌊172/19⌋ = 9, but ⌊17/1⌋ = 17: the guess is wildly off
because the first digit 1 is smaller than B/2 = 5.
We can ensure that v has a large top digit by multiplying both u and v by the
right amount. Continuing the example, if we multiply both 172 and 19 by 3, we
now have ⌊516/57⌋, the leading digit of v is now ≥ 5, and sure enough
⌊51/5⌋ = 10 is much closer to the correct answer 9. It would be easier here
to multiply by 4, because that can be done with a shift. Specifically, we can
always count the number of leading zeros i in the first digit of v and then
shift both u and v left by i bits.
Having scaled u and v, the value ⌊u/v⌋ is unchanged, but the remainder will
be scaled: 172 mod 19 is 1, but 516 mod 57 is 3. We have to divide the remainder
by the scaling factor (shifting right i bits) when we finish.
Note that these shifts happen before and after the entire division algorithm,
not at each step in the per-digit iteration.
Note the effect of scaling inputs on the size of the possible quotient.
In the scaled u/v, u can gain a digit from scaling; v never does, because we
pick the scaling factor to make v's top digit larger but without overflowing.
If u and v have n+m and n digits after scaling, then:
• u < B^(n+m) [B^(n+m) has n+m+1 digits]
• v ≥ B^n / 2 [vₙ₋₁ ≥ B/2, so vₙ₋₁·B^(n-1) ≥ B^n/2]
• u/v < B^(n+m) / (B^n / 2) [divide bounds for u, v]
• u/v < 2 B^m [simplify]
The quotient can still have m+1 significant digits, but if so the top digit
must be a 1. This provides a different way to handle the first digit of the
result: compare the top n digits of u against v and fill in either a 0 or a 1.
Refining Guesses
Before we check whether u < q̂·v, we can adjust our guess to change it from
q̂ = ⌊uₙuₙ₋₁ / vₙ₋₁⌋ into the refined guess ⌊uₙuₙ₋₁uₙ₋₂ / vₙ₋₁vₙ₋₂⌋.
Although not mentioned above, the Good Guess Guarantee also promises that this
3-by-2-digit division guess is more precise and at most one away from the real
answer q. The improvement from the 2-by-1 to the 3-by-2 guess can also be done
without n-digit math.
If we have a guess q̂ = ⌊uₙuₙ₋₁ / vₙ₋₁⌋ and we want to see if it is also equal to
⌊uₙuₙ₋₁uₙ₋₂ / vₙ₋₁vₙ₋₂⌋, we can use the same check we would for the full division:
if uₙuₙ₋₁uₙ₋₂ < q̂·vₙ₋₁vₙ₋₂, then the guess is too large and should be reduced.
Checking uₙuₙ₋₁uₙ₋₂ < q̂·vₙ₋₁vₙ₋₂ is the same as uₙuₙ₋₁uₙ₋₂ - q̂·vₙ₋₁vₙ₋₂ < 0,
and
uₙuₙ₋₁uₙ₋₂ - q̂·vₙ₋₁vₙ₋₂ = (uₙuₙ₋₁·B + uₙ₋₂) - q̂·(vₙ₋₁·B + vₙ₋₂)
[splitting off the bottom digit]
= (uₙuₙ₋₁ - q̂·vₙ₋₁)·B + uₙ₋₂ - q̂·vₙ₋₂
[regrouping]
The expression (uₙuₙ₋₁ - q̂·vₙ₋₁) is the remainder of uₙuₙ₋₁ / vₙ₋₁.
If the initial guess returns both q̂ and its remainder r̂, then checking
whether uₙuₙ₋₁uₙ₋₂ < q̂·vₙ₋₁vₙ₋₂ is the same as checking r̂·B + uₙ₋₂ < q̂·vₙ₋₂.
If we find that r̂·B + uₙ₋₂ < q̂·vₙ₋₂, then we can adjust the guess by
decrementing q̂ and adding vₙ₋₁ to r̂. We repeat until r̂·B + uₙ₋₂ ≥ q̂·vₙ₋₂.
(As before, this fixup is only needed at most twice.)
Now that q̂ = ⌊uₙuₙ₋₁uₙ₋₂ / vₙ₋₁vₙ₋₂⌋, as mentioned above it is at most one
away from the correct q, and we've avoided doing any n-digit math.
(If we need the new remainder, it can be computed as r̂·B + uₙ₋₂ - q̂·vₙ₋₂.)
The final check u < q̂·v and the possible fixup must be done at full precision.
For random inputs, a fixup at this step is exceedingly rare: the 3-by-2 guess
is not often wrong at all. But still we must do the check. Note that since the
3-by-2 guess is off by at most 1, it can be convenient to perform the final
u < q̂·v as part of the computation of the remainder r = u - q̂·v. If the
subtraction underflows, decrementing q̂ and adding one v back to r is enough to
arrive at the final q, r.
That's the entirety of long division: scale the inputs, and then loop over
each output position, guessing, checking, and correcting the next output digit.
For a 2n-digit number divided by an n-digit number (the worst size-n case for
division complexity), this algorithm uses n+1 iterations, each of which must do
at least the 1-by-n-digit multiplication q̂·v. That's O(n) iterations of
O(n) time each, so O(n²) time overall.
Recursive Division
For very large inputs, it is possible to improve on the O(n²) algorithm.
Let's call a group of n/2 real digits a (very) “wide digit”. We can run the
standard long division algorithm explained above over the wide digits instead of
the actual digits. This will result in many fewer steps, but the math involved in
each step is more work.
Where basic long division uses a 2-by-1-digit division to guess the initial q̂,
the new algorithm must use a 2-by-1-wide-digit division, which is of course
really an n-by-n/2-digit division. That's OK: if we implement n-digit division
in terms of n/2-digit division, the recursion will terminate when the divisor
becomes small enough to handle with standard long division or even with the
2-by-1 hardware instruction.
For example, here is a sketch of dividing 10 digits by 4, proceeding with
wide digits corresponding to two regular digits. The first step, still special,
must leave off a (regular) digit, dividing 5 by 4 and producing a 4-digit
remainder less than v. The middle steps divide 6 digits by 4, guaranteed to
produce two output digits each (one wide digit) with 4-digit remainders.
The final step must use what it has: the 4-digit remainder plus one more,
5 digits to divide by 4.
q₆ q₅ q₄ q₃ q₂ q₁ q₀
_______________________________
v₃ v₂ v₁ v₀ ) u₉ u₈ u₇ u₆ u₅ u₄ u₃ u₂ u₁ u₀
↓ ↓ ↓ ↓ ↓ | | | | |
[u₉ u₈ u₇ u₆ u₅]| | | | |
- [ q₆q₅·v ]| | | | |
----------------- ↓ ↓ | | |
[ rem |u₄ u₃]| | |
- [ q₄q₃·v ]| | |
-------------------- ↓ ↓ |
[ rem |u₂ u₁]|
- [ q₂q₁·v ]|
-------------------- ↓
[ rem |u₀]
- [ q₀·v ]
------------------
[ rem ]
An alternative would be to look ahead to how well n/2 divides into n+m and
adjust the first step to use fewer digits as needed, making the first step
more special to make the last step not special at all. For example, using the
same input, we could choose to use only 4 digits in the first step, leaving
a full wide digit for the last step:
q₆ q₅ q₄ q₃ q₂ q₁ q₀
_______________________________
v₃ v₂ v₁ v₀ ) u₉ u₈ u₇ u₆ u₅ u₄ u₃ u₂ u₁ u₀
↓ ↓ ↓ ↓ | | | | | |
[u₉ u₈ u₇ u₆]| | | | | |
- [ q₆·v ]| | | | | |
-------------- ↓ ↓ | | | |
[ rem |u₅ u₄]| | | |
- [ q₅q₄·v ]| | | |
-------------------- ↓ ↓ | |
[ rem |u₃ u₂]| |
- [ q₃q₂·v ]| |
-------------------- ↓ ↓
[ rem |u₁ u₀]
- [ q₁q₀·v ]
---------------------
[ rem ]
Today, the code in divRecursiveStep works like the first example. Perhaps in
the future we will make it work like the alternative, to avoid a special case
in the final iteration.
Either way, each step is a 3-by-2-wide-digit division approximated first by
a 2-by-1-wide-digit division, just as we did for regular digits in long division.
Because the actual answer we want is a 3-by-2-wide-digit division, instead of
multiplying q̂·v directly during the fixup, we can use the quick refinement
from long division (an n/2-by-n/2 multiply) to correct q to its actual value
and also compute the remainder (as mentioned above), and then stop after that,
never doing a full n-by-n multiply.
Instead of using an n-by-n/2-digit division to produce n/2 digits, we can add
(not discard) one more real digit, doing an (n+1)-by-(n/2+1)-digit division that
produces n/2+1 digits. That single extra digit tightens the Good Guess Guarantee
to q ≤ q̂ ≤ q+1 and lets us drop long division's special treatment of the first
digit. These benefits are discussed more after the Good Guess Guarantee proof
below.
How Fast is Recursive Division?
For a 2n-by-n-digit division, this algorithm runs a 4-by-2 long division over
wide digits, producing two wide digits plus a possible leading regular digit 1,
which can be handled without a recursive call. That is, the algorithm uses two
full iterations, each using an n-by-n/2-digit division and an n/2-by-n/2-digit
multiplication, along with a few n-digit additions and subtractions. The standard
n-by-n-digit multiplication algorithm requires O(n²) time, making the overall
algorithm require time T(n) where
T(n) = 2T(n/2) + O(n) + O(n²)
which, by the Bentley-Haken-Saxe theorem, ends up reducing to T(n) = O(n²).
This is not an improvement over regular long division.
When the number of digits n becomes large enough, Karatsuba's algorithm for
multiplication can be used instead, which takes O(n^log₂3) = O(n^1.6) time.
(Karatsuba multiplication is implemented in func karatsuba in nat.go.)
That makes the overall recursive division algorithm take O(n^1.6) time as well,
which is an improvement, but again only for large enough numbers.
It is not critical to make sure that every recursion does only two recursive
calls. While in general the number of recursive calls can change the time
analysis, in this case doing three calls does not change the analysis:
T(n) = 3T(n/2) + O(n) + O(n^log₂3)
ends up being T(n) = O(n^log₂3). Because the Karatsuba multiplication taking
time O(n^log₂3) is itself doing 3 half-sized recursions, doing three for the
division does not hurt the asymptotic performance. Of course, it is likely
still faster in practice to do two.
Proof of the Good Guess Guarantee
Given numbers x, y, let us break them into the quotients and remainders when
divided by some scaling factor S, with the added constraints that the quotient
x/y and the high part of y are both less than some limit T, and that the high
part of y is at least half as big as T.
x₁ = ⌊x/S⌋ y₁ = ⌊y/S⌋
x₀ = x mod S y₀ = y mod S
x = x₁·S + x₀ 0 ≤ x₀ < S x/y < T
y = y₁·S + y₀ 0 ≤ y₀ < S T/2 ≤ y₁ < T
And consider the two truncated quotients:
q = ⌊x/y⌋
q̂ = ⌊x₁/y₁⌋
We will prove that q ≤ q̂ ≤ q+2.
The guarantee makes no real demands on the scaling factor S: it is simply the
magnitude of the digits cut from both x and y to produce x₁ and y₁.
The guarantee makes only limited demands on T: it must be large enough to hold
the quotient x/y, and y₁ must have roughly the same size.
To apply to the earlier discussion of 2-by-1 guesses in long division,
we would choose:
S = Bⁿ⁻¹
T = B
x = u
x₁ = uₙuₙ₋₁
x₀ = uₙ₋₂...u₀
y = v
y₁ = vₙ₋₁
y₀ = vₙ₋₂...v₀
These simpler variables avoid repeating those longer expressions in the proof.
Note also that, by definition, truncating division ⌊x/y⌋ satisfies
x/y - 1 < ⌊x/y⌋ ≤ x/y.
This fact will be used a few times in the proofs.
Proof that q ≤ q̂:
q̂·y₁ = ⌊x₁/y₁⌋·y₁ [by definition, q̂ = ⌊x₁/y₁⌋]
> (x₁/y₁ - 1)·y₁ [x₁/y₁ - 1 < ⌊x₁/y₁⌋]
= x₁ - y₁ [distribute y₁]
So q̂·y₁ > x₁ - y₁.
Since q̂·y₁ is an integer, q̂·y₁ ≥ x₁ - y₁ + 1.
q̂ - q = q̂ - ⌊x/y⌋ [by definition, q = ⌊x/y⌋]
≥ q̂ - x/y [⌊x/y⌋ < x/y]
= (1/y)·(q̂·y - x) [factor out 1/y]
≥ (1/y)·(q̂·y₁·S - x) [y = y₁·S + y₀ ≥ y₁·S]
≥ (1/y)·((x₁ - y₁ + 1)·S - x) [above: q̂·y₁ ≥ x₁ - y₁ + 1]
= (1/y)·(x₁·S - y₁·S + S - x) [distribute S]
= (1/y)·(S - x₀ - y₁·S) [-x = -x₁·S - x₀]
> -y₁·S / y [x₀ < S, so S - x₀ > 0; drop it]
≥ -1 [y₁·S ≤ y]
So q̂ - q > -1.
Since q̂ - q is an integer, q̂ - q ≥ 0, or equivalently q ≤ q̂.
Proof that q̂ ≤ q+2:
x₁/y₁ - x/y = x₁·S/y₁·S - x/y [multiply left term by S/S]
≤ x/y₁·S - x/y [x₁S ≤ x]
= (x/y)·(y/y₁·S - 1) [factor out x/y]
= (x/y)·((y - y₁·S)/y₁·S) [move -1 into y/y₁·S fraction]
= (x/y)·(y₀/y₁·S) [y - y₁·S = y₀]
= (x/y)·(1/y₁)·(y₀/S) [factor out 1/y₁]
< (x/y)·(1/y₁) [y₀ < S, so y₀/S < 1]
≤ (x/y)·(2/T) [y₁ ≥ T/2, so 1/y₁ ≤ 2/T]
< T·(2/T) [x/y < T]
= 2 [T·(2/T) = 2]
So x₁/y₁ - x/y < 2.
q̂ - q = ⌊x₁/y₁⌋ - q [by definition, q̂ = ⌊x₁/y₁⌋]
= ⌊x₁/y₁⌋ - ⌊x/y⌋ [by definition, q = ⌊x/y⌋]
≤ x₁/y₁ - ⌊x/y⌋ [⌊x₁/y₁⌋ ≤ x₁/y₁]
< x₁/y₁ - (x/y - 1) [⌊x/y⌋ > x/y - 1]
= (x₁/y₁ - x/y) + 1 [regrouping]
< 2 + 1 [above: x₁/y₁ - x/y < 2]
= 3
So q̂ - q < 3.
Since q̂ - q is an integer, q̂ - q ≤ 2.
Note that when x/y < T/2, the bounds tighten to x₁/y₁ - x/y < 1 and therefore
q̂ - q ≤ 1.
Note also that in the general case 2n-by-n division where we don't know that
x/y < T, we do know that x/y < 2T, yielding the bound q̂ - q ≤ 4. So we could
remove the special case first step of long division as long as we allow the
first fixup loop to run up to four times. (Using a simple comparison to decide
whether the first digit is 0 or 1 is still more efficient, though.)
Finally, note that when dividing three leading base-B digits by two (scaled),
we have T = B² and x/y < B = T/B, a much tighter bound than x/y < T.
This in turn yields the much tighter bound x₁/y₁ - x/y < 2/B. This means that
⌊x₁/y₁⌋ and ⌊x/y⌋ can only differ when x/y is less than 2/B greater than an
integer. For random x and y, the chance of this is 2/B, or, for large B,
approximately zero. This means that after we produce the 3-by-2 guess in the
long division algorithm, the fixup loop essentially never runs.
In the recursive algorithm, the extra digit in (2·⌊n/2⌋+1)-by-(⌊n/2⌋+1)-digit
division has exactly the same effect: the probability of needing a fixup is the
same 2/B. Even better, we can allow the general case x/y < 2T and the fixup
probability only grows to 4/B, still essentially zero.
References
There are no great references for implementing long division; thus this comment.
Here are some notes about what to expect from the obvious references.
Knuth Volume 2 (Seminumerical Algorithms) section 4.3.1 is the usual canonical
reference for long division, but that entire series is highly compressed, never
repeating a necessary fact and leaving important insights to the exercises.
For example, no rationale whatsoever is given for the calculation that extends
q̂ from a 2-by-1 to a 3-by-2 guess, nor why it reduces the error bound.
The proof that the calculation even has the desired effect is left to exercises.
The solutions to those exercises provided at the back of the book are entirely
calculations, still with no explanation as to what is going on or how you would
arrive at the idea of doing those exact calculations. Nowhere is it mentioned
that this test extends the 2-by-1 guess into a 3-by-2 guess. The proof of the
Good Guess Guarantee is only for the 2-by-1 guess and argues by contradiction,
making it difficult to understand how modifications like adding another digit
or adjusting the quotient range affect the overall bound.
All that said, Knuth remains the canonical reference. It is dense but packed
full of information and references, and the proofs are simpler than many other
presentations. The proofs above are reworkings of Knuth's to remove the
arguments by contradiction and add explanations or steps that Knuth omitted.
But beware of errors in older printings. Take the published errata with you.
Brinch Hansen's “Multiple-length Division Revisited: a Tour of the Minefield”
starts with a blunt critique of Knuth's presentation (among others) and then
presents a more detailed and easier to follow treatment of long division,
including an implementation in Pascal. But the algorithm and implementation
work entirely in terms of 3-by-2 division, which is much less useful on modern
hardware than an algorithm using 2-by-1 division. The proofs are a bit too
focused on digit counting and seem needlessly complex, especially compared to
the ones given above.
Burnikel and Ziegler's “Fast Recursive Division” introduced the key insight of
implementing division by an n-digit divisor using recursive calls to division
by an n/2-digit divisor, relying on Karatsuba multiplication to yield a
sub-quadratic run time. However, the presentation decisions are made almost
entirely for the purpose of simplifying the run-time analysis, rather than
simplifying the presentation. Instead of a single algorithm that loops over
quotient digits, the paper presents two mutually-recursive algorithms, for
2n-by-n and 3n-by-2n. The paper also does not present any general (n+m)-by-n
algorithm.
The proofs in the paper are remarkably complex, especially considering that
the algorithm is at its core just long division on wide digits, so that the
usual long division proofs apply essentially unaltered.
*/
package big
import "math/bits"
// rem returns r such that r = u%v.
// It uses z as the storage for r.
func (z nat) rem(u, v nat) (r nat) {
if alias(z, u) {
z = nil
}
qp := getNat(0)
q, r := qp.div(z, u, v)
*qp = q
putNat(qp)
return r
}
// div returns q, r such that q = ⌊u/v⌋ and r = u%v = u - q·v.
// It uses z and z2 as the storage for q and r.
func (z nat) div(z2, u, v nat) (q, r nat) {
if len(v) == 0 {
panic("division by zero")
}
if u.cmp(v) < 0 {
q = z[:0]
r = z2.set(u)
return
}
if len(v) == 1 {
// Short division: long optimized for a single-word divisor.
// In that case, the 2-by-1 guess is all we need at each step.
var r2 Word
q, r2 = z.divW(u, v[0])
r = z2.setWord(r2)
return
}
q, r = z.divLarge(z2, u, v)
return
}
// divW returns q, r such that q = ⌊x/y⌋ and r = x%y = x - q·y.
// It uses z as the storage for q.
// Note that y is a single digit (Word), not a big number.
func (z nat) divW(x nat, y Word) (q nat, r Word) {
m := len(x)
switch {
case y == 0:
panic("division by zero")
case y == 1:
q = z.set(x) // result is x
return
case m == 0:
q = z[:0] // result is 0
return
}
// m > 0
z = z.make(m)
r = divWVW(z, 0, x, y)
q = z.norm()
return
}
// modW returns x % d.
func (x nat) modW(d Word) (r Word) {
// TODO(agl): we don't actually need to store the q value.
var q nat
q = q.make(len(x))
return divWVW(q, 0, x, d)
}
// divWVW overwrites z with ⌊x/y⌋, returning the remainder r.
// The caller must ensure that len(z) = len(x).
func divWVW(z []Word, xn Word, x []Word, y Word) (r Word) {
r = xn
if len(x) == 1 {
qq, rr := bits.Div(uint(r), uint(x[0]), uint(y))
z[0] = Word(qq)
return Word(rr)
}
rec := reciprocalWord(y)
for i := len(z) - 1; i >= 0; i-- {
z[i], r = divWW(r, x[i], y, rec)
}
return r
}
// divLarge returns q, r such that q = ⌊uIn/vIn⌋ and r = uIn%vIn = uIn - q·vIn.
// It uses z and u as the storage for q and r.
// The caller must ensure that len(vIn) ≥ 2 (use divW otherwise)
// and that len(uIn) ≥ len(vIn) (the answer is 0, uIn otherwise).
func (z nat) divLarge(u, uIn, vIn nat) (q, r nat) {
n := len(vIn)
m := len(uIn) - n
// Scale the inputs so vIn's top bit is 1 (see “Scaling Inputs” above).
// vIn is treated as a read-only input (it may be in use by another
// goroutine), so we must make a copy.
// uIn is copied to u.
shift := nlz(vIn[n-1])
vp := getNat(n)
v := *vp
shlVU(v, vIn, shift)
u = u.make(len(uIn) + 1)
u[len(uIn)] = shlVU(u[0:len(uIn)], uIn, shift)
// The caller should not pass aliased z and u, since those are
// the two different outputs, but correct just in case.
if alias(z, u) {
z = nil
}
q = z.make(m + 1)
// Use basic or recursive long division depending on size.
if n < divRecursiveThreshold {
q.divBasic(u, v)
} else {
q.divRecursive(u, v)
}
putNat(vp)
q = q.norm()
// Undo scaling of remainder.
shrVU(u, u, shift)
r = u.norm()
return q, r
}
// divBasic implements long division as described above.
// It overwrites q with ⌊u/v⌋ and overwrites u with the remainder r.
// q must be large enough to hold ⌊u/v⌋.
func (q nat) divBasic(u, v nat) {
n := len(v)
m := len(u) - n
qhatvp := getNat(n + 1)
qhatv := *qhatvp
// Set up for divWW below, precomputing reciprocal argument.
vn1 := v[n-1]
rec := reciprocalWord(vn1)
// Compute each digit of quotient.
for j := m; j >= 0; j-- {
// Compute the 2-by-1 guess q̂.
// The first iteration must invent a leading 0 for u.
qhat := Word(_M)
var ujn Word
if j+n < len(u) {
ujn = u[j+n]
}
// ujn ≤ vn1, or else q̂ would be more than one digit.
// For ujn == vn1, we set q̂ to the max digit M above.
// Otherwise, we compute the 2-by-1 guess.
if ujn != vn1 {
var rhat Word
qhat, rhat = divWW(ujn, u[j+n-1], vn1, rec)
// Refine q̂ to a 3-by-2 guess. See “Refining Guesses” above.
vn2 := v[n-2]
x1, x2 := mulWW(qhat, vn2)
ujn2 := u[j+n-2]
for greaterThan(x1, x2, rhat, ujn2) { // x1x2 > r̂ u[j+n-2]
qhat--
prevRhat := rhat
rhat += vn1
// If r̂ overflows, then
// r̂ u[j+n-2]v[n-1] is now definitely > x1 x2.
if rhat < prevRhat {
break
}
// TODO(rsc): No need for a full mulWW.
// x2 += vn2; if x2 overflows, x1++
x1, x2 = mulWW(qhat, vn2)
}
}
// Compute q̂·v.
qhatv[n] = mulAddVWW(qhatv[0:n], v, qhat, 0)
qhl := len(qhatv)
if j+qhl > len(u) && qhatv[n] == 0 {
qhl--
}
// Subtract q̂·v from the current section of u.
// If it underflows, q̂·v > u, which we fix up
// by decrementing q̂ and adding v back.
c := subVV(u[j:j+qhl], u[j:], qhatv)
if c != 0 {
c := addVV(u[j:j+n], u[j:], v)
// If n == qhl, the carry from subVV and the carry from addVV
// cancel out and don't affect u[j+n].
if n < qhl {
u[j+n] += c
}
qhat--
}
// Save quotient digit.
// Caller may know the top digit is zero and not leave room for it.
if j == m && m == len(q) && qhat == 0 {
continue
}
q[j] = qhat
}
putNat(qhatvp)
}
// greaterThan reports whether the two digit numbers x1 x2 > y1 y2.
// TODO(rsc): In contradiction to most of this file, x1 is the high
// digit and x2 is the low digit. This should be fixed.
func greaterThan(x1, x2, y1, y2 Word) bool {
return x1 > y1 || x1 == y1 && x2 > y2
}
// divRecursiveThreshold is the number of divisor digits
// at which point divRecursive is faster than divBasic.
const divRecursiveThreshold = 100
// divRecursive implements recursive division as described above.
// It overwrites z with ⌊u/v⌋ and overwrites u with the remainder r.
// z must be large enough to hold ⌊u/v⌋.
// This function is just for allocating and freeing temporaries
// around divRecursiveStep, the real implementation.
func (z nat) divRecursive(u, v nat) {
// Recursion depth is (much) less than 2 log₂(len(v)).
// Allocate a slice of temporaries to be reused across recursion,
// plus one extra temporary not live across the recursion.
recDepth := 2 * bits.Len(uint(len(v)))
tmp := getNat(3 * len(v))
temps := make([]*nat, recDepth)
z.clear()
z.divRecursiveStep(u, v, 0, tmp, temps)
// Free temporaries.
for _, n := range temps {
if n != nil {
putNat(n)
}
}
putNat(tmp)
}
// divRecursiveStep is the actual implementation of recursive division.
// It adds ⌊u/v⌋ to z and overwrites u with the remainder r.
// z must be large enough to hold ⌊u/v⌋.
// It uses temps[depth] (allocating if needed) as a temporary live across
// the recursive call. It also uses tmp, but not live across the recursion.
func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) {
// u is a subsection of the original and may have leading zeros.
// TODO(rsc): The v = v.norm() is useless and should be removed.
// We know (and require) that v's top digit is ≥ B/2.
u = u.norm()
v = v.norm()
if len(u) == 0 {
z.clear()
return
}
// Fall back to basic division if the problem is now small enough.
n := len(v)
if n < divRecursiveThreshold {
z.divBasic(u, v)
return
}
// Nothing to do if u is shorter than v (implies u < v).
m := len(u) - n
if m < 0 {
return
}
// We consider B digits in a row as a single wide digit.
// (See “Recursive Division” above.)
//
// TODO(rsc): rename B to Wide, to avoid confusion with _B,
// which is something entirely different.
// TODO(rsc): Look into whether using ⌈n/2⌉ is better than ⌊n/2⌋.
B := n / 2
// Allocate a nat for qhat below.
if temps[depth] == nil {
temps[depth] = getNat(n) // TODO(rsc): Can be just B+1.
} else {
*temps[depth] = temps[depth].make(B + 1)
}
// Compute each wide digit of the quotient.
//
// TODO(rsc): Change the loop to be
// for j := (m+B-1)/B*B; j > 0; j -= B {
// which will make the final step a regular step, letting us
// delete what amounts to an extra copy of the loop body below.
j := m
for j > B {
// Divide u[j-B:j+n] (3 wide digits) by v (2 wide digits).
// First make the 2-by-1-wide-digit guess using a recursive call.
// Then extend the guess to the full 3-by-2 (see “Refining Guesses”).
//
// For the 2-by-1-wide-digit guess, instead of doing 2B-by-B-digit,
// we use a (2B+1)-by-(B+1) digit, which handles the possibility that
// the result has an extra leading 1 digit as well as guaranteeing
// that the computed q̂ will be off by at most 1 instead of 2.
// s is the number of digits to drop from the 3B- and 2B-digit chunks.
// We drop B-1 to be left with 2B+1 and B+1.
s := (B - 1)
// uu is the up-to-3B-digit section of u we are working on.
uu := u[j-B:]
// Compute the 2-by-1 guess q̂, leaving r̂ in uu[s:B+n].
qhat := *temps[depth]
qhat.clear()
qhat.divRecursiveStep(uu[s:B+n], v[s:], depth+1, tmp, temps)
qhat = qhat.norm()
// Extend to a 3-by-2 quotient and remainder.
// Because divRecursiveStep overwrote the top part of uu with
// the remainder r̂, the full uu already contains the equivalent
// of r̂·B + uₙ₋₂ from the “Refining Guesses” discussion.
// Subtracting q̂·vₙ₋₂ from it will compute the full-length remainder.
// If that subtraction underflows, q̂·v > u, which we fix up
// by decrementing q̂ and adding v back, same as in long division.
// TODO(rsc): Instead of subtract and fix-up, this code is computing
// q̂·vₙ₋₂ and decrementing q̂ until that product is ≤ u.
// But we can do the subtraction directly, as in the comment above
// and in long division, because we know that q̂ is wrong by at most one.
qhatv := tmp.make(3 * n)
qhatv.clear()
qhatv = qhatv.mul(qhat, v[:s])
for i := 0; i < 2; i++ {
e := qhatv.cmp(uu.norm())
if e <= 0 {
break
}
subVW(qhat, qhat, 1)
c := subVV(qhatv[:s], qhatv[:s], v[:s])
if len(qhatv) > s {
subVW(qhatv[s:], qhatv[s:], c)
}
addAt(uu[s:], v[s:], 0)
}
if qhatv.cmp(uu.norm()) > 0 {
panic("impossible")
}
c := subVV(uu[:len(qhatv)], uu[:len(qhatv)], qhatv)
if c > 0 {
subVW(uu[len(qhatv):], uu[len(qhatv):], c)
}
addAt(z, qhat, j-B)
j -= B
}
// TODO(rsc): Rewrite loop as described above and delete all this code.
// Now u < (v<<B), compute lower bits in the same way.
// Choose shift = B-1 again.
s := B - 1
qhat := *temps[depth]
qhat.clear()
qhat.divRecursiveStep(u[s:].norm(), v[s:], depth+1, tmp, temps)
qhat = qhat.norm()
qhatv := tmp.make(3 * n)
qhatv.clear()
qhatv = qhatv.mul(qhat, v[:s])
// Set the correct remainder as before.
for i := 0; i < 2; i++ {
if e := qhatv.cmp(u.norm()); e > 0 {
subVW(qhat, qhat, 1)
c := subVV(qhatv[:s], qhatv[:s], v[:s])
if len(qhatv) > s {
subVW(qhatv[s:], qhatv[s:], c)
}
addAt(u[s:], v[s:], 0)
}
}
if qhatv.cmp(u.norm()) > 0 {
panic("impossible")
}
c := subVV(u[0:len(qhatv)], u[0:len(qhatv)], qhatv)
if c > 0 {
c = subVW(u[len(qhatv):], u[len(qhatv):], c)
}
if c > 0 {
panic("impossible")
}
// Done!
addAt(z, qhat.norm(), 0)
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package big
import "math/rand"
// ProbablyPrime reports whether x is probably prime,
// applying the Miller-Rabin test with n pseudorandomly chosen bases
// as well as a Baillie-PSW test.
//
// If x is prime, ProbablyPrime returns true.
// If x is chosen randomly and not prime, ProbablyPrime probably returns false.
// The probability of returning true for a randomly chosen non-prime is at most ¼ⁿ.
//
// ProbablyPrime is 100% accurate for inputs less than 2⁶⁴.
// See Menezes et al., Handbook of Applied Cryptography, 1997, pp. 145-149,
// and FIPS 186-4 Appendix F for further discussion of the error probabilities.
//
// ProbablyPrime is not suitable for judging primes that an adversary may
// have crafted to fool the test.
//
// As of Go 1.8, ProbablyPrime(0) is allowed and applies only a Baillie-PSW test.
// Before Go 1.8, ProbablyPrime applied only the Miller-Rabin tests, and ProbablyPrime(0) panicked.
func (x *Int) ProbablyPrime(n int) bool {
// Note regarding the doc comment above:
// It would be more precise to say that the Baillie-PSW test uses the
// extra strong Lucas test as its Lucas test, but since no one knows
// how to tell any of the Lucas tests apart inside a Baillie-PSW test
// (they all work equally well empirically), that detail need not be
// documented or implicitly guaranteed.
// The comment does avoid saying "the" Baillie-PSW test
// because of this general ambiguity.
if n < 0 {
panic("negative n for ProbablyPrime")
}
if x.neg || len(x.abs) == 0 {
return false
}
// primeBitMask records the primes < 64.
const primeBitMask uint64 = 1<<2 | 1<<3 | 1<<5 | 1<<7 |
1<<11 | 1<<13 | 1<<17 | 1<<19 | 1<<23 | 1<<29 | 1<<31 |
1<<37 | 1<<41 | 1<<43 | 1<<47 | 1<<53 | 1<<59 | 1<<61
w := x.abs[0]
if len(x.abs) == 1 && w < 64 {
return primeBitMask&(1<<w) != 0
}
if w&1 == 0 {
return false // x is even
}
const primesA = 3 * 5 * 7 * 11 * 13 * 17 * 19 * 23 * 37
const primesB = 29 * 31 * 41 * 43 * 47 * 53
var rA, rB uint32
switch _W {
case 32:
rA = uint32(x.abs.modW(primesA))
rB = uint32(x.abs.modW(primesB))
case 64:
r := x.abs.modW((primesA * primesB) & _M)
rA = uint32(r % primesA)
rB = uint32(r % primesB)
default:
panic("math/big: invalid word size")
}
if rA%3 == 0 || rA%5 == 0 || rA%7 == 0 || rA%11 == 0 || rA%13 == 0 || rA%17 == 0 || rA%19 == 0 || rA%23 == 0 || rA%37 == 0 ||
rB%29 == 0 || rB%31 == 0 || rB%41 == 0 || rB%43 == 0 || rB%47 == 0 || rB%53 == 0 {
return false
}
return x.abs.probablyPrimeMillerRabin(n+1, true) && x.abs.probablyPrimeLucas()
}
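// Illustrative example of the trial-division filter above, not from the
// original source: for x = 91 = 7·13, rA ≡ 91 (mod primesA), and since 7
// divides primesA, rA%7 == 91%7 == 0, so 91 is rejected before any
// Miller-Rabin or Lucas work is done.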
// probablyPrimeMillerRabin reports whether n passes reps rounds of the
// Miller-Rabin primality test, using pseudo-randomly chosen bases.
// If force2 is true, one of the rounds is forced to use base 2.
// See Handbook of Applied Cryptography, p. 139, Algorithm 4.24.
// The number n is known to be non-zero.
func (n nat) probablyPrimeMillerRabin(reps int, force2 bool) bool {
nm1 := nat(nil).sub(n, natOne)
// determine q, k such that nm1 = q << k
k := nm1.trailingZeroBits()
q := nat(nil).shr(nm1, k)
nm3 := nat(nil).sub(nm1, natTwo)
rand := rand.New(rand.NewSource(int64(n[0])))
var x, y, quotient nat
nm3Len := nm3.bitLen()
NextRandom:
for i := 0; i < reps; i++ {
if i == reps-1 && force2 {
x = x.set(natTwo)
} else {
x = x.random(rand, nm3, nm3Len)
x = x.add(x, natTwo)
}
y = y.expNN(x, q, n, false)
if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 {
continue
}
for j := uint(1); j < k; j++ {
y = y.sqr(y)
quotient, y = quotient.div(y, y, n)
if y.cmp(nm1) == 0 {
continue NextRandom
}
if y.cmp(natOne) == 0 {
return false
}
}
return false
}
return true
}
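// Illustrative example, not from the original source: for n = 221 = 13·17,
// nm1 = 220 = 2²·55, so k = 2 and q = 55. Base x = 174 is a strong liar:
// 174⁵⁵ ≡ 47 (mod 221) and 47² ≡ 220 = nm1, so that round passes.
// Base x = 137 is a witness: 137⁵⁵ ≡ 188 and 188² ≡ 205, which is neither
// 1 nor nm1, so that round correctly reports 221 composite.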
// probablyPrimeLucas reports whether n passes the "almost extra strong" Lucas probable prime test,
// using Baillie-OEIS parameter selection. This corresponds to "AESLPSP" on Jacobsen's tables (link below).
// The combination of this test and a Miller-Rabin/Fermat test with base 2 gives a Baillie-PSW test.
//
// References:
//
// Baillie and Wagstaff, "Lucas Pseudoprimes", Mathematics of Computation 35(152),
// October 1980, pp. 1391-1417, especially page 1401.
// https://www.ams.org/journals/mcom/1980-35-152/S0025-5718-1980-0583518-6/S0025-5718-1980-0583518-6.pdf
//
// Grantham, "Frobenius Pseudoprimes", Mathematics of Computation 70(234),
// March 2000, pp. 873-891.
// https://www.ams.org/journals/mcom/2001-70-234/S0025-5718-00-01197-2/S0025-5718-00-01197-2.pdf
//
// Baillie, "Extra strong Lucas pseudoprimes", OEIS A217719, https://oeis.org/A217719.
//
// Jacobsen, "Pseudoprime Statistics, Tables, and Data", http://ntheory.org/pseudoprimes.html.
//
// Nicely, "The Baillie-PSW Primality Test", https://web.archive.org/web/20191121062007/http://www.trnicely.net/misc/bpsw.html.
// (Note that Nicely's definition of the "extra strong" test gives the wrong Jacobi condition,
// as pointed out by Jacobsen.)
//
// Crandall and Pomerance, Prime Numbers: A Computational Perspective, 2nd ed.
// Springer, 2005.
func (n nat) probablyPrimeLucas() bool {
// Discard 0, 1.
if len(n) == 0 || n.cmp(natOne) == 0 {
return false
}
// Two is the only even prime.
// Already checked by caller, but here to allow testing in isolation.
if n[0]&1 == 0 {
return n.cmp(natTwo) == 0
}
// Baillie-OEIS "method C" for choosing D, P, Q,
// as in https://oeis.org/A217719/a217719.txt:
// try increasing P ≥ 3 such that D = P² - 4 (so Q = 1)
// until Jacobi(D, n) = -1.
// The search is expected to succeed for non-square n after just a few trials.
// After more than expected failures, check whether n is square
// (which would cause Jacobi(D, n) = 1 for all D not dividing n).
p := Word(3)
d := nat{1}
t1 := nat(nil) // temp
intD := &Int{abs: d}
intN := &Int{abs: n}
for ; ; p++ {
if p > 10000 {
// This is widely believed to be impossible.
// If we get a report, we'll want the exact number n.
panic("math/big: internal error: cannot find (D/n) = -1 for " + intN.String())
}
d[0] = p*p - 4
j := Jacobi(intD, intN)
if j == -1 {
break
}
if j == 0 {
// d = p²-4 = (p-2)(p+2).
// If (d/n) == 0 then d shares a prime factor with n.
// Since the loop proceeds in increasing p and starts with p-2==1,
// the shared prime factor must be p+2.
// If p+2 == n, then n is prime; otherwise p+2 is a proper factor of n.
return len(n) == 1 && n[0] == p+2
}
if p == 40 {
// We'll never find (d/n) = -1 if n is a square.
// If n is a non-square we expect to find a d in just a few attempts on average.
// After 40 attempts, take a moment to check if n is indeed a square.
t1 = t1.sqrt(n)
t1 = t1.sqr(t1)
if t1.cmp(n) == 0 {
return false
}
}
}
// Grantham definition of "extra strong Lucas pseudoprime", after Thm 2.3 on p. 876
// (D, P, Q above have become Δ, b, 1):
//
// Let U_n = U_n(b, 1), V_n = V_n(b, 1), and Δ = b²-4.
// An extra strong Lucas pseudoprime to base b is a composite n = 2^r s + Jacobi(Δ, n),
// where s is odd and gcd(n, 2*Δ) = 1, such that either (i) U_s ≡ 0 mod n and V_s ≡ ±2 mod n,
// or (ii) V_{2^t s} ≡ 0 mod n for some 0 ≤ t < r-1.
//
// We know gcd(n, Δ) = 1 or else we'd have found Jacobi(d, n) == 0 above.
// We know gcd(n, 2) = 1 because n is odd.
//
// Arrange s = (n - Jacobi(Δ, n)) / 2^r = (n+1) / 2^r.
s := nat(nil).add(n, natOne)
r := int(s.trailingZeroBits())
s = s.shr(s, uint(r))
nm2 := nat(nil).sub(n, natTwo) // n-2
// We apply the "almost extra strong" test, which checks the above conditions
// except for U_s ≡ 0 mod n, which allows us to avoid computing any U_k values.
// Jacobsen points out that maybe we should just do the full extra strong test:
// "It is also possible to recover U_n using Crandall and Pomerance equation 3.13:
// U_n = D^-1 (2V_{n+1} - PV_n) allowing us to run the full extra-strong test
// at the cost of a single modular inversion. This computation is easy and fast in GMP,
// so we can get the full extra-strong test at essentially the same performance as the
// almost extra strong test."
// Compute Lucas sequence V_s(b, 1), where:
//
// V(0) = 2
// V(1) = P
// V(k) = P V(k-1) - Q V(k-2).
//
// (Remember that due to method C above, P = b, Q = 1.)
//
// In general V(k) = α^k + β^k, where α and β are roots of x² - Px + Q.
// Crandall and Pomerance (p.147) observe that for 0 ≤ j ≤ k,
//
// V(j+k) = V(j)V(k) - V(k-j).
//
// So in particular, to quickly double the subscript:
//
// V(2k) = V(k)² - 2
// V(2k+1) = V(k) V(k+1) - P
//
// We can therefore start with k=0 and build up to k=s in log₂(s) steps.
natP := nat(nil).setWord(p)
vk := nat(nil).setWord(2)
vk1 := nat(nil).setWord(p)
t2 := nat(nil) // temp
for i := int(s.bitLen()); i >= 0; i-- {
if s.bit(uint(i)) != 0 {
// k' = 2k+1
// V(k') = V(2k+1) = V(k) V(k+1) - P.
t1 = t1.mul(vk, vk1)
t1 = t1.add(t1, n)
t1 = t1.sub(t1, natP)
t2, vk = t2.div(vk, t1, n)
// V(k'+1) = V(2k+2) = V(k+1)² - 2.
t1 = t1.sqr(vk1)
t1 = t1.add(t1, nm2)
t2, vk1 = t2.div(vk1, t1, n)
} else {
// k' = 2k
// V(k'+1) = V(2k+1) = V(k) V(k+1) - P.
t1 = t1.mul(vk, vk1)
t1 = t1.add(t1, n)
t1 = t1.sub(t1, natP)
t2, vk1 = t2.div(vk1, t1, n)
// V(k') = V(2k) = V(k)² - 2
t1 = t1.sqr(vk)
t1 = t1.add(t1, nm2)
t2, vk = t2.div(vk, t1, n)
}
}
// Now k=s, so vk = V(s). Check V(s) ≡ ±2 (mod n).
if vk.cmp(natTwo) == 0 || vk.cmp(nm2) == 0 {
// Check U(s) ≡ 0.
// As suggested by Jacobsen, apply Crandall and Pomerance equation 3.13:
//
// U(k) = D⁻¹ (2 V(k+1) - P V(k))
//
// Since we are checking for U(k) == 0 it suffices to check 2 V(k+1) == P V(k) mod n,
// or P V(k) - 2 V(k+1) == 0 mod n.
t1 := t1.mul(vk, natP)
t2 := t2.shl(vk1, 1)
if t1.cmp(t2) < 0 {
t1, t2 = t2, t1
}
t1 = t1.sub(t1, t2)
t3 := vk1 // steal vk1, no longer needed below
vk1 = nil
_ = vk1
t2, t3 = t2.div(t3, t1, n)
if len(t3) == 0 {
return true
}
}
// Check V(2^t s) ≡ 0 mod n for some 0 ≤ t < r-1.
for t := 0; t < r-1; t++ {
if len(vk) == 0 { // vk == 0
return true
}
// Optimization: V(k) = 2 is a fixed point for V(k') = V(k)² - 2,
// so if V(k) = 2, we can stop: we will never find a future V(k) == 0.
if len(vk) == 1 && vk[0] == 2 { // vk == 2
return false
}
// k' = 2k
// V(k') = V(2k) = V(k)² - 2
t1 = t1.sqr(vk)
t1 = t1.sub(t1, natTwo)
t2, vk = t2.div(vk, t1, n)
}
return false
}
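// As a sanity check on the doubling identities used above, a small sketch
// with plain int64 values and P = 3, Q = 1 (assuming "fmt" is imported):
//
//	const P = 3
//	v := []int64{2, P} // V(0) = 2, V(1) = P
//	for k := 2; k <= 9; k++ {
//		v = append(v, P*v[k-1]-v[k-2]) // V(k) = P·V(k-1) - Q·V(k-2)
//	}
//	k := 4
//	fmt.Println(v[2*k] == v[k]*v[k]-2)     // true: V(2k) = V(k)² - 2
//	fmt.Println(v[2*k+1] == v[k]*v[k+1]-P) // true: V(2k+1) = V(k)·V(k+1) - P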
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements multi-precision rational numbers.
package big
import (
"fmt"
"math"
)
// A Rat represents a quotient a/b of arbitrary precision.
// The zero value for a Rat represents the value 0.
//
// Operations always take pointer arguments (*Rat) rather
// than Rat values, and each unique Rat value requires
// its own unique *Rat pointer. To "copy" a Rat value,
// an existing (or newly allocated) Rat must be set to
// a new value using the Rat.Set method; shallow copies
// of Rats are not supported and may lead to errors.
type Rat struct {
// To make zero values for Rat work w/o initialization,
// a zero value of b (len(b) == 0) acts like b == 1. At
// the earliest opportunity (when an assignment to the Rat
// is made), such uninitialized denominators are set to 1.
// a.neg determines the sign of the Rat, b.neg is ignored.
a, b Int
}
// NewRat creates a new Rat with numerator a and denominator b.
func NewRat(a, b int64) *Rat {
return new(Rat).SetFrac64(a, b)
}
// SetFloat64 sets z to exactly f and returns z.
// If f is not finite, SetFloat64 returns nil.
func (z *Rat) SetFloat64(f float64) *Rat {
const expMask = 1<<11 - 1
bits := math.Float64bits(f)
mantissa := bits & (1<<52 - 1)
exp := int((bits >> 52) & expMask)
switch exp {
case expMask: // non-finite
return nil
case 0: // denormal
exp -= 1022
default: // normal
mantissa |= 1 << 52
exp -= 1023
}
shift := 52 - exp
// Optimization (?): partially pre-normalize.
for mantissa&1 == 0 && shift > 0 {
mantissa >>= 1
shift--
}
z.a.SetUint64(mantissa)
z.a.neg = f < 0
z.b.Set(intOne)
if shift > 0 {
z.b.Lsh(&z.b, uint(shift))
} else {
z.a.Lsh(&z.a, uint(-shift))
}
return z.norm()
}
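// A short sketch of the exactness this gives, via the exported API
// (assuming "fmt", "math", and "math/big" are imported):
//
//	r := new(big.Rat).SetFloat64(0.1)
//	fmt.Println(r) // 3602879701896397/36028797018963968, the exact value of float64(0.1)
//	if new(big.Rat).SetFloat64(math.Inf(1)) == nil {
//		fmt.Println("non-finite inputs yield nil")
//	}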
// quotToFloat32 returns the non-negative float32 value
// nearest to the quotient a/b, using round-to-even in
// halfway cases. It does not mutate its arguments.
// Preconditions: b is non-zero; a and b have no common factors.
func quotToFloat32(a, b nat) (f float32, exact bool) {
const (
// float size in bits
Fsize = 32
// mantissa
Msize = 23
Msize1 = Msize + 1 // incl. implicit 1
Msize2 = Msize1 + 1
// exponent
Esize = Fsize - Msize1
Ebias = 1<<(Esize-1) - 1
Emin = 1 - Ebias
Emax = Ebias
)
// TODO(adonovan): specialize common degenerate cases: 1.0, integers.
alen := a.bitLen()
if alen == 0 {
return 0, true
}
blen := b.bitLen()
if blen == 0 {
panic("division by zero")
}
// 1. Left-shift A or B such that quotient A/B is in [1<<Msize1, 1<<(Msize2+1))
// (Msize2 bits if A < B when they are left-aligned, Msize2+1 bits if A >= B).
// This is 2 or 3 more than the float32 mantissa field width of Msize:
// - the optional extra bit is shifted away in step 3 below.
// - the high-order 1 is omitted in "normal" representation;
// - the low-order 1 will be used during rounding then discarded.
exp := alen - blen
var a2, b2 nat
a2 = a2.set(a)
b2 = b2.set(b)
if shift := Msize2 - exp; shift > 0 {
a2 = a2.shl(a2, uint(shift))
} else if shift < 0 {
b2 = b2.shl(b2, uint(-shift))
}
// 2. Compute quotient and remainder (q, r). NB: due to the
// extra shift, the low-order bit of q is logically the
// high-order bit of r.
var q nat
q, r := q.div(a2, a2, b2) // (recycle a2)
mantissa := low32(q)
haveRem := len(r) > 0 // mantissa&1 && !haveRem => remainder is exactly half
// 3. If quotient didn't fit in Msize2 bits, redo division by b2<<1
// (in effect---we accomplish this incrementally).
if mantissa>>Msize2 == 1 {
if mantissa&1 == 1 {
haveRem = true
}
mantissa >>= 1
exp++
}
if mantissa>>Msize1 != 1 {
panic(fmt.Sprintf("expected exactly %d bits of result", Msize2))
}
// 4. Rounding.
if Emin-Msize <= exp && exp <= Emin {
// Denormal case; lose 'shift' bits of precision.
shift := uint(Emin - (exp - 1)) // [1..Esize1)
lostbits := mantissa & (1<<shift - 1)
haveRem = haveRem || lostbits != 0
mantissa >>= shift
exp = 2 - Ebias // == exp + shift
}
// Round q using round-half-to-even.
exact = !haveRem
if mantissa&1 != 0 {
exact = false
if haveRem || mantissa&2 != 0 {
if mantissa++; mantissa >= 1<<Msize2 {
// Complete rollover 11...1 => 100...0, so shift is safe
mantissa >>= 1
exp++
}
}
}
mantissa >>= 1 // discard rounding bit. Mantissa now scaled by 1<<Msize1.
f = float32(math.Ldexp(float64(mantissa), exp-Msize1))
if math.IsInf(float64(f), 0) {
exact = false
}
return
}
// quotToFloat64 returns the non-negative float64 value
// nearest to the quotient a/b, using round-to-even in
// halfway cases. It does not mutate its arguments.
// Preconditions: b is non-zero; a and b have no common factors.
func quotToFloat64(a, b nat) (f float64, exact bool) {
const (
// float size in bits
Fsize = 64
// mantissa
Msize = 52
Msize1 = Msize + 1 // incl. implicit 1
Msize2 = Msize1 + 1
// exponent
Esize = Fsize - Msize1
Ebias = 1<<(Esize-1) - 1
Emin = 1 - Ebias
Emax = Ebias
)
// TODO(adonovan): specialize common degenerate cases: 1.0, integers.
alen := a.bitLen()
if alen == 0 {
return 0, true
}
blen := b.bitLen()
if blen == 0 {
panic("division by zero")
}
// 1. Left-shift A or B such that quotient A/B is in [1<<Msize1, 1<<(Msize2+1))
// (Msize2 bits if A < B when they are left-aligned, Msize2+1 bits if A >= B).
// This is 2 or 3 more than the float64 mantissa field width of Msize:
// - the optional extra bit is shifted away in step 3 below.
// - the high-order 1 is omitted in "normal" representation;
// - the low-order 1 will be used during rounding then discarded.
exp := alen - blen
var a2, b2 nat
a2 = a2.set(a)
b2 = b2.set(b)
if shift := Msize2 - exp; shift > 0 {
a2 = a2.shl(a2, uint(shift))
} else if shift < 0 {
b2 = b2.shl(b2, uint(-shift))
}
// 2. Compute quotient and remainder (q, r). NB: due to the
// extra shift, the low-order bit of q is logically the
// high-order bit of r.
var q nat
q, r := q.div(a2, a2, b2) // (recycle a2)
mantissa := low64(q)
haveRem := len(r) > 0 // mantissa&1 && !haveRem => remainder is exactly half
// 3. If quotient didn't fit in Msize2 bits, redo division by b2<<1
// (in effect---we accomplish this incrementally).
if mantissa>>Msize2 == 1 {
if mantissa&1 == 1 {
haveRem = true
}
mantissa >>= 1
exp++
}
if mantissa>>Msize1 != 1 {
panic(fmt.Sprintf("expected exactly %d bits of result", Msize2))
}
// 4. Rounding.
if Emin-Msize <= exp && exp <= Emin {
// Denormal case; lose 'shift' bits of precision.
shift := uint(Emin - (exp - 1)) // [1..Esize1)
lostbits := mantissa & (1<<shift - 1)
haveRem = haveRem || lostbits != 0
mantissa >>= shift
exp = 2 - Ebias // == exp + shift
}
// Round q using round-half-to-even.
exact = !haveRem
if mantissa&1 != 0 {
exact = false
if haveRem || mantissa&2 != 0 {
if mantissa++; mantissa >= 1<<Msize2 {
// Complete rollover 11...1 => 100...0, so shift is safe
mantissa >>= 1
exp++
}
}
}
mantissa >>= 1 // discard rounding bit. Mantissa now scaled by 1<<Msize1.
f = math.Ldexp(float64(mantissa), exp-Msize1)
if math.IsInf(f, 0) {
exact = false
}
return
}
// Float32 returns the nearest float32 value for x and a bool indicating
// whether f represents x exactly. If the magnitude of x is too large to
// be represented by a float32, f is an infinity and exact is false.
// The sign of f always matches the sign of x, even if f == 0.
func (x *Rat) Float32() (f float32, exact bool) {
b := x.b.abs
if len(b) == 0 {
b = natOne
}
f, exact = quotToFloat32(x.a.abs, b)
if x.a.neg {
f = -f
}
return
}
// Float64 returns the nearest float64 value for x and a bool indicating
// whether f represents x exactly. If the magnitude of x is too large to
// be represented by a float64, f is an infinity and exact is false.
// The sign of f always matches the sign of x, even if f == 0.
func (x *Rat) Float64() (f float64, exact bool) {
b := x.b.abs
if len(b) == 0 {
b = natOne
}
f, exact = quotToFloat64(x.a.abs, b)
if x.a.neg {
f = -f
}
return
}
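// For illustration (assuming "fmt" and "math/big" are imported):
//
//	f, exact := big.NewRat(1, 4).Float64()
//	fmt.Println(f, exact) // 0.25 true: 1/4 is a power of two
//	f, exact = big.NewRat(1, 3).Float64()
//	fmt.Println(f, exact) // 0.3333333333333333 false: 1/3 has no finite binary expansion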
// SetFrac sets z to a/b and returns z.
// If b == 0, SetFrac panics.
func (z *Rat) SetFrac(a, b *Int) *Rat {
z.a.neg = a.neg != b.neg
babs := b.abs
if len(babs) == 0 {
panic("division by zero")
}
if &z.a == b || alias(z.a.abs, babs) {
babs = nat(nil).set(babs) // make a copy
}
z.a.abs = z.a.abs.set(a.abs)
z.b.abs = z.b.abs.set(babs)
return z.norm()
}
// SetFrac64 sets z to a/b and returns z.
// If b == 0, SetFrac64 panics.
func (z *Rat) SetFrac64(a, b int64) *Rat {
if b == 0 {
panic("division by zero")
}
z.a.SetInt64(a)
if b < 0 {
b = -b
z.a.neg = !z.a.neg
}
z.b.abs = z.b.abs.setUint64(uint64(b))
return z.norm()
}
// SetInt sets z to x (by making a copy of x) and returns z.
func (z *Rat) SetInt(x *Int) *Rat {
z.a.Set(x)
z.b.abs = z.b.abs.setWord(1)
return z
}
// SetInt64 sets z to x and returns z.
func (z *Rat) SetInt64(x int64) *Rat {
z.a.SetInt64(x)
z.b.abs = z.b.abs.setWord(1)
return z
}
// SetUint64 sets z to x and returns z.
func (z *Rat) SetUint64(x uint64) *Rat {
z.a.SetUint64(x)
z.b.abs = z.b.abs.setWord(1)
return z
}
// Set sets z to x (by making a copy of x) and returns z.
func (z *Rat) Set(x *Rat) *Rat {
if z != x {
z.a.Set(&x.a)
z.b.Set(&x.b)
}
if len(z.b.abs) == 0 {
z.b.abs = z.b.abs.setWord(1)
}
return z
}
// Abs sets z to |x| (the absolute value of x) and returns z.
func (z *Rat) Abs(x *Rat) *Rat {
z.Set(x)
z.a.neg = false
return z
}
// Neg sets z to -x and returns z.
func (z *Rat) Neg(x *Rat) *Rat {
z.Set(x)
z.a.neg = len(z.a.abs) > 0 && !z.a.neg // 0 has no sign
return z
}
// Inv sets z to 1/x and returns z.
// If x == 0, Inv panics.
func (z *Rat) Inv(x *Rat) *Rat {
if len(x.a.abs) == 0 {
panic("division by zero")
}
z.Set(x)
z.a.abs, z.b.abs = z.b.abs, z.a.abs
return z
}
// Sign returns:
//
// -1 if x < 0
// 0 if x == 0
// +1 if x > 0
func (x *Rat) Sign() int {
return x.a.Sign()
}
// IsInt reports whether the denominator of x is 1.
func (x *Rat) IsInt() bool {
return len(x.b.abs) == 0 || x.b.abs.cmp(natOne) == 0
}
// Num returns the numerator of x; it may be <= 0.
// The result is a reference to x's numerator; it
// may change if a new value is assigned to x, and vice versa.
// The sign of the numerator corresponds to the sign of x.
func (x *Rat) Num() *Int {
return &x.a
}
// Denom returns the denominator of x; it is always > 0.
// The result is a reference to x's denominator, unless
// x is an uninitialized (zero value) Rat, in which case
// the result is a new Int of value 1. (To initialize x,
// any operation that sets x will do, including x.Set(x).)
// If the result is a reference to x's denominator it
// may change if a new value is assigned to x, and vice versa.
func (x *Rat) Denom() *Int {
// Note that x.b.neg is guaranteed false.
if len(x.b.abs) == 0 {
// Note: If this proves problematic, we could
// panic instead and require the Rat to
// be explicitly initialized.
return &Int{abs: nat{1}}
}
return &x.b
}
func (z *Rat) norm() *Rat {
switch {
case len(z.a.abs) == 0:
// z == 0; normalize sign and denominator
z.a.neg = false
fallthrough
case len(z.b.abs) == 0:
// z is integer; normalize denominator
z.b.abs = z.b.abs.setWord(1)
default:
// z is fraction; normalize numerator and denominator
neg := z.a.neg
z.a.neg = false
z.b.neg = false
if f := NewInt(0).lehmerGCD(nil, nil, &z.a, &z.b); f.Cmp(intOne) != 0 {
z.a.abs, _ = z.a.abs.div(nil, z.a.abs, f.abs)
z.b.abs, _ = z.b.abs.div(nil, z.b.abs, f.abs)
}
z.a.neg = neg
}
return z
}
// mulDenom sets z to the denominator product x*y (by taking into
// account that 0 values for x or y must be interpreted as 1) and
// returns z.
func mulDenom(z, x, y nat) nat {
switch {
case len(x) == 0 && len(y) == 0:
return z.setWord(1)
case len(x) == 0:
return z.set(y)
case len(y) == 0:
return z.set(x)
}
return z.mul(x, y)
}
// scaleDenom sets z to the product x*f.
// If f == 0 (zero value of denominator), z is set to (a copy of) x.
func (z *Int) scaleDenom(x *Int, f nat) {
if len(f) == 0 {
z.Set(x)
return
}
z.abs = z.abs.mul(x.abs, f)
z.neg = x.neg
}
// Cmp compares x and y and returns:
//
// -1 if x < y
// 0 if x == y
// +1 if x > y
func (x *Rat) Cmp(y *Rat) int {
var a, b Int
a.scaleDenom(&x.a, y.b.abs)
b.scaleDenom(&y.a, x.b.abs)
return a.Cmp(&b)
}
// Add sets z to the sum x+y and returns z.
func (z *Rat) Add(x, y *Rat) *Rat {
var a1, a2 Int
a1.scaleDenom(&x.a, y.b.abs)
a2.scaleDenom(&y.a, x.b.abs)
z.a.Add(&a1, &a2)
z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs)
return z.norm()
}
// Sub sets z to the difference x-y and returns z.
func (z *Rat) Sub(x, y *Rat) *Rat {
var a1, a2 Int
a1.scaleDenom(&x.a, y.b.abs)
a2.scaleDenom(&y.a, x.b.abs)
z.a.Sub(&a1, &a2)
z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs)
return z.norm()
}
// Mul sets z to the product x*y and returns z.
func (z *Rat) Mul(x, y *Rat) *Rat {
if x == y {
// a squared Rat is positive and can't be reduced (no need to call norm())
z.a.neg = false
z.a.abs = z.a.abs.sqr(x.a.abs)
if len(x.b.abs) == 0 {
z.b.abs = z.b.abs.setWord(1)
} else {
z.b.abs = z.b.abs.sqr(x.b.abs)
}
return z
}
z.a.Mul(&x.a, &y.a)
z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs)
return z.norm()
}
// Quo sets z to the quotient x/y and returns z.
// If y == 0, Quo panics.
func (z *Rat) Quo(x, y *Rat) *Rat {
if len(y.a.abs) == 0 {
panic("division by zero")
}
var a, b Int
a.scaleDenom(&x.a, y.b.abs)
b.scaleDenom(&y.a, x.b.abs)
z.a.abs = a.abs
z.b.abs = b.abs
z.a.neg = a.neg != b.neg
return z.norm()
}
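// A minimal sketch of the arithmetic above, showing that results come back
// normalized (assuming "fmt" and "math/big" are imported):
//
//	x, y := big.NewRat(1, 6), big.NewRat(1, 3)
//	fmt.Println(new(big.Rat).Add(x, y)) // 1/2: 9/18 reduced by the gcd
//	fmt.Println(new(big.Rat).Quo(x, y)) // 1/2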
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements rat-to-string conversion functions.
package big
import (
"errors"
"fmt"
"io"
"strconv"
"strings"
)
func ratTok(ch rune) bool {
return strings.ContainsRune("+-/0123456789.eE", ch)
}
var ratZero Rat
var _ fmt.Scanner = &ratZero // *Rat must implement fmt.Scanner
// Scan is a support routine for fmt.Scanner. It accepts the formats
// 'e', 'E', 'f', 'F', 'g', 'G', and 'v'. All formats are equivalent.
func (z *Rat) Scan(s fmt.ScanState, ch rune) error {
tok, err := s.Token(true, ratTok)
if err != nil {
return err
}
if !strings.ContainsRune("efgEFGv", ch) {
return errors.New("Rat.Scan: invalid verb")
}
if _, ok := z.SetString(string(tok)); !ok {
return errors.New("Rat.Scan: invalid syntax")
}
return nil
}
// SetString sets z to the value of s and returns z and a boolean indicating
// success. s can be given as a (possibly signed) fraction "a/b", or as a
// floating-point number optionally followed by an exponent.
// If a fraction is provided, both the dividend and the divisor may be a
// decimal integer or independently use a prefix of “0b”, “0” or “0o”,
// or “0x” (or their upper-case variants) to denote a binary, octal, or
// hexadecimal integer, respectively. The divisor may not be signed.
// If a floating-point number is provided, it may be in decimal form or
// use any of the same prefixes as above but for “0” to denote a non-decimal
// mantissa. A leading “0” is considered a decimal leading 0; it does not
// indicate octal representation in this case.
// An optional base-10 “e” or base-2 “p” (or their upper-case variants)
// exponent may be provided as well, except for hexadecimal floats which
// only accept an (optional) “p” exponent (because an “e” or “E” cannot
// be distinguished from a mantissa digit). If the exponent's absolute value
// is too large, the operation may fail.
// The entire string, not just a prefix, must be valid for success. If the
// operation failed, the value of z is undefined but the returned value is nil.
func (z *Rat) SetString(s string) (*Rat, bool) {
if len(s) == 0 {
return nil, false
}
// len(s) > 0
// parse fraction a/b, if any
if sep := strings.Index(s, "/"); sep >= 0 {
if _, ok := z.a.SetString(s[:sep], 0); !ok {
return nil, false
}
r := strings.NewReader(s[sep+1:])
var err error
if z.b.abs, _, _, err = z.b.abs.scan(r, 0, false); err != nil {
return nil, false
}
// entire string must have been consumed
if _, err = r.ReadByte(); err != io.EOF {
return nil, false
}
if len(z.b.abs) == 0 {
return nil, false
}
return z.norm(), true
}
// parse floating-point number
r := strings.NewReader(s)
// sign
neg, err := scanSign(r)
if err != nil {
return nil, false
}
// mantissa
var base int
var fcount int // fractional digit count; valid if <= 0
z.a.abs, base, fcount, err = z.a.abs.scan(r, 0, true)
if err != nil {
return nil, false
}
// exponent
var exp int64
var ebase int
exp, ebase, err = scanExponent(r, true, true)
if err != nil {
return nil, false
}
// there should be no unread characters left
if _, err = r.ReadByte(); err != io.EOF {
return nil, false
}
// special-case 0 (see also issue #16176)
if len(z.a.abs) == 0 {
return z.norm(), true
}
// len(z.a.abs) > 0
// The mantissa may have a radix point (fcount <= 0) and there
// may be a nonzero exponent exp. The radix point amounts to a
// division by base**(-fcount), which equals a multiplication by
// base**fcount. An exponent means multiplication by ebase**exp.
// Multiplications are commutative, so we can apply them in any
// order. We only have powers of 2 and 10, and we split powers
// of 10 into the product of the same powers of 2 and 5. This
// may reduce the size of shift/multiplication factors or
// divisors required to create the final fraction, depending
// on the actual floating-point value.
// determine binary or decimal exponent contribution of radix point
var exp2, exp5 int64
if fcount < 0 {
// The mantissa has a radix point ddd.dddd; and
// -fcount is the number of digits to the right
// of '.'. Adjust relevant exponent accordingly.
d := int64(fcount)
switch base {
case 10:
exp5 = d
fallthrough // 10**e == 5**e * 2**e
case 2:
exp2 = d
case 8:
exp2 = d * 3 // octal digits are 3 bits each
case 16:
exp2 = d * 4 // hexadecimal digits are 4 bits each
default:
panic("unexpected mantissa base")
}
// fcount consumed - not needed anymore
}
// take actual exponent into account
switch ebase {
case 10:
exp5 += exp
fallthrough // see fallthrough above
case 2:
exp2 += exp
default:
panic("unexpected exponent base")
}
// exp consumed - not needed anymore
// apply exp5 contributions
// (start with exp5 so the numbers to multiply are smaller)
if exp5 != 0 {
n := exp5
if n < 0 {
n = -n
if n < 0 {
// This can occur if -n overflows. -(-1 << 63) would become
// -1 << 63, which is still negative.
return nil, false
}
}
if n > 1e6 {
return nil, false // avoid excessively large exponents
}
pow5 := z.b.abs.expNN(natFive, nat(nil).setWord(Word(n)), nil, false) // use underlying array of z.b.abs
if exp5 > 0 {
z.a.abs = z.a.abs.mul(z.a.abs, pow5)
z.b.abs = z.b.abs.setWord(1)
} else {
z.b.abs = pow5
}
} else {
z.b.abs = z.b.abs.setWord(1)
}
// apply exp2 contributions
if exp2 < -1e7 || exp2 > 1e7 {
return nil, false // avoid excessively large exponents
}
if exp2 > 0 {
z.a.abs = z.a.abs.shl(z.a.abs, uint(exp2))
} else if exp2 < 0 {
z.b.abs = z.b.abs.shl(z.b.abs, uint(-exp2))
}
z.a.neg = neg && len(z.a.abs) > 0 // 0 has no sign
return z.norm(), true
}
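// A few accepted forms, per the documentation above (assuming "fmt" and
// "math/big" are imported):
//
//	r, ok := new(big.Rat).SetString("1.5e-2")
//	fmt.Println(r, ok) // 3/200 true
//	r, _ = new(big.Rat).SetString("0x1p-2") // hexadecimal mantissa, binary exponent
//	fmt.Println(r) // 1/4
//	r, _ = new(big.Rat).SetString("0b101/0x1F") // independently prefixed dividend and divisor
//	fmt.Println(r) // 5/31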
// scanExponent scans the longest possible prefix of r representing a base 10
// (“e”, “E”) or a base 2 (“p”, “P”) exponent, if any. It returns the
// exponent, the exponent base (10 or 2), and a read or syntax error, if any.
//
// If sepOk is set, an underscore character “_” may appear between successive
// exponent digits; such underscores do not change the value of the exponent.
// Incorrect placement of underscores is reported as an error if there are no
// other errors. If sepOk is not set, underscores are not recognized and thus
// terminate scanning like any other character that is not a valid digit.
//
// exponent = ( "e" | "E" | "p" | "P" ) [ sign ] digits .
// sign = "+" | "-" .
// digits = digit { [ '_' ] digit } .
// digit = "0" ... "9" .
//
// A base 2 exponent is only permitted if base2ok is set.
func scanExponent(r io.ByteScanner, base2ok, sepOk bool) (exp int64, base int, err error) {
// one char look-ahead
ch, err := r.ReadByte()
if err != nil {
if err == io.EOF {
err = nil
}
return 0, 10, err
}
// exponent char
switch ch {
case 'e', 'E':
base = 10
case 'p', 'P':
if base2ok {
base = 2
break // ok
}
fallthrough // binary exponent not permitted
default:
r.UnreadByte() // ch does not belong to exponent anymore
return 0, 10, nil
}
// sign
var digits []byte
ch, err = r.ReadByte()
if err == nil && (ch == '+' || ch == '-') {
if ch == '-' {
digits = append(digits, '-')
}
ch, err = r.ReadByte()
}
// prev encodes the previously seen char: it is one
// of '_', '0' (a digit), or '.' (anything else). A
// valid separator '_' may only occur after a digit.
prev := '.'
invalSep := false
// exponent value
hasDigits := false
for err == nil {
if '0' <= ch && ch <= '9' {
digits = append(digits, ch)
prev = '0'
hasDigits = true
} else if ch == '_' && sepOk {
if prev != '0' {
invalSep = true
}
prev = '_'
} else {
r.UnreadByte() // ch does not belong to number anymore
break
}
ch, err = r.ReadByte()
}
if err == io.EOF {
err = nil
}
if err == nil && !hasDigits {
err = errNoDigits
}
if err == nil {
exp, err = strconv.ParseInt(string(digits), 10, 64)
}
// other errors take precedence over invalid separators
if err == nil && (invalSep || prev == '_') {
err = errInvalSep
}
return
}
// String returns a string representation of x in the form "a/b" (even if b == 1).
func (x *Rat) String() string {
return string(x.marshal())
}
// marshal implements String returning a slice of bytes
func (x *Rat) marshal() []byte {
var buf []byte
buf = x.a.Append(buf, 10)
buf = append(buf, '/')
if len(x.b.abs) != 0 {
buf = x.b.Append(buf, 10)
} else {
buf = append(buf, '1')
}
return buf
}
// RatString returns a string representation of x in the form "a/b" if b != 1,
// and in the form "a" if b == 1.
func (x *Rat) RatString() string {
if x.IsInt() {
return x.a.String()
}
return x.String()
}
// FloatString returns a string representation of x in decimal form with prec
// digits of precision after the radix point. The last digit is rounded to
// nearest, with halves rounded away from zero.
func (x *Rat) FloatString(prec int) string {
var buf []byte
if x.IsInt() {
buf = x.a.Append(buf, 10)
if prec > 0 {
buf = append(buf, '.')
for i := prec; i > 0; i-- {
buf = append(buf, '0')
}
}
return string(buf)
}
// x.b.abs != 0
q, r := nat(nil).div(nat(nil), x.a.abs, x.b.abs)
p := natOne
if prec > 0 {
p = nat(nil).expNN(natTen, nat(nil).setUint64(uint64(prec)), nil, false)
}
r = r.mul(r, p)
r, r2 := r.div(nat(nil), r, x.b.abs)
// see if we need to round up
r2 = r2.add(r2, r2)
if x.b.abs.cmp(r2) <= 0 {
r = r.add(r, natOne)
if r.cmp(p) >= 0 {
q = nat(nil).add(q, natOne)
r = nat(nil).sub(r, p)
}
}
if x.a.neg {
buf = append(buf, '-')
}
buf = append(buf, q.utoa(10)...) // itoa ignores sign if q == 0
if prec > 0 {
buf = append(buf, '.')
rs := r.utoa(10)
for i := prec - len(rs); i > 0; i-- {
buf = append(buf, '0')
}
buf = append(buf, rs...)
}
return string(buf)
}
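// For illustration of the rounding rule (assuming "fmt" and "math/big"
// are imported):
//
//	fmt.Println(big.NewRat(2, 3).FloatString(3)) // 0.667: last digit rounded to nearest
//	fmt.Println(big.NewRat(1, 2).FloatString(0)) // 1: the half is rounded away from zero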
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements encoding/decoding of Rats.
package big
import (
"encoding/binary"
"errors"
"fmt"
"math"
)
// Gob codec version. Permits backward-compatible changes to the encoding.
const ratGobVersion byte = 1
// GobEncode implements the gob.GobEncoder interface.
func (x *Rat) GobEncode() ([]byte, error) {
if x == nil {
return nil, nil
}
buf := make([]byte, 1+4+(len(x.a.abs)+len(x.b.abs))*_S) // extra bytes for version and sign bit (1), and numerator length (4)
i := x.b.abs.bytes(buf)
j := x.a.abs.bytes(buf[:i])
n := i - j
if int(uint32(n)) != n {
// this should never happen
return nil, errors.New("Rat.GobEncode: numerator too large")
}
binary.BigEndian.PutUint32(buf[j-4:j], uint32(n))
j -= 1 + 4
b := ratGobVersion << 1 // make space for sign bit
if x.a.neg {
b |= 1
}
buf[j] = b
return buf[j:], nil
}
// GobDecode implements the gob.GobDecoder interface.
func (z *Rat) GobDecode(buf []byte) error {
if len(buf) == 0 {
// Other side sent a nil or default value.
*z = Rat{}
return nil
}
if len(buf) < 5 {
return errors.New("Rat.GobDecode: buffer too small")
}
b := buf[0]
if b>>1 != ratGobVersion {
return fmt.Errorf("Rat.GobDecode: encoding version %d not supported", b>>1)
}
const j = 1 + 4
ln := binary.BigEndian.Uint32(buf[j-4 : j])
if uint64(ln) > math.MaxInt-j {
return errors.New("Rat.GobDecode: invalid length")
}
i := j + int(ln)
if len(buf) < i {
return errors.New("Rat.GobDecode: buffer too small")
}
z.a.neg = b&1 != 0
z.a.abs = z.a.abs.setBytes(buf[j:i])
z.b.abs = z.b.abs.setBytes(buf[i:])
return nil
}
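// A round-trip sketch through the encoding above (assuming "bytes",
// "encoding/gob", "fmt", and "math/big" are imported):
//
//	var buf bytes.Buffer
//	_ = gob.NewEncoder(&buf).Encode(big.NewRat(-3, 7))
//	var r big.Rat
//	_ = gob.NewDecoder(&buf).Decode(&r)
//	fmt.Println(r.String()) // -3/7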
// MarshalText implements the encoding.TextMarshaler interface.
func (x *Rat) MarshalText() (text []byte, err error) {
if x.IsInt() {
return x.a.MarshalText()
}
return x.marshal(), nil
}
// UnmarshalText implements the encoding.TextUnmarshaler interface.
func (z *Rat) UnmarshalText(text []byte) error {
// TODO(gri): get rid of the []byte/string conversion
if _, ok := z.SetString(string(text)); !ok {
return fmt.Errorf("math/big: cannot unmarshal %q into a *big.Rat", text)
}
return nil
}
// Code generated by "stringer -type=RoundingMode"; DO NOT EDIT.
package big
import "strconv"
const _RoundingMode_name = "ToNearestEvenToNearestAwayToZeroAwayFromZeroToNegativeInfToPositiveInf"
var _RoundingMode_index = [...]uint8{0, 13, 26, 32, 44, 57, 70}
func (i RoundingMode) String() string {
if i >= RoundingMode(len(_RoundingMode_index)-1) {
return "RoundingMode(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _RoundingMode_name[_RoundingMode_index[i]:_RoundingMode_index[i+1]]
}
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package big
import (
"math"
"sync"
)
var threeOnce struct {
sync.Once
v *Float
}
func three() *Float {
threeOnce.Do(func() {
threeOnce.v = NewFloat(3.0)
})
return threeOnce.v
}
// Sqrt sets z to the rounded square root of x, and returns it.
//
// If z's precision is 0, it is changed to x's precision before the
// operation. Rounding is performed according to z's precision and
// rounding mode, but z's accuracy is not computed. Specifically, the
// result of z.Acc() is undefined.
//
// The function panics if x < 0. The value of z is undefined in that
// case.
func (z *Float) Sqrt(x *Float) *Float {
if debugFloat {
x.validate()
}
if z.prec == 0 {
z.prec = x.prec
}
if x.Sign() == -1 {
// following IEEE754-2008 (section 7.2)
panic(ErrNaN{"square root of negative operand"})
}
// handle ±0 and +∞
if x.form != finite {
z.acc = Exact
z.form = x.form
z.neg = x.neg // IEEE754-2008 requires √±0 = ±0
return z
}
// MantExp sets the argument's precision to the receiver's, and
// when z.prec > x.prec this will lower z.prec. Restore it after
// the MantExp call.
prec := z.prec
b := x.MantExp(z)
z.prec = prec
// Compute √(z·2**b) as
// √( z)·2**(½b) if b is even
// √(2z)·2**(⌊½b⌋) if b > 0 is odd
// √(½z)·2**(⌈½b⌉) if b < 0 is odd
switch b % 2 {
case 0:
// nothing to do
case 1:
z.exp++
case -1:
z.exp--
}
// 0.25 <= z < 2.0
// Solving 1/x² - z = 0 avoids Quo calls and is faster, especially
// for high precisions.
z.sqrtInverse(z)
// re-attach halved exponent
return z.SetMantExp(z, b/2)
}
// Compute √x (to z.prec precision) by solving
//
// 1/t² - x = 0
//
// for t (using Newton's method), and then inverting.
func (z *Float) sqrtInverse(x *Float) {
// let
// f(t) = 1/t² - x
// then
// g(t) = f(t)/f'(t) = -½t(1 - xt²)
// and the next guess is given by
// t2 = t - g(t) = ½t(3 - xt²)
u := newFloat(z.prec)
v := newFloat(z.prec)
three := three()
ng := func(t *Float) *Float {
u.prec = t.prec
v.prec = t.prec
u.Mul(t, t) // u = t²
u.Mul(x, u) // = xt²
v.Sub(three, u) // v = 3 - xt²
u.Mul(t, v) // u = t(3 - xt²)
u.exp-- // = ½t(3 - xt²)
return t.Set(u)
}
xf, _ := x.Float64()
sqi := newFloat(z.prec)
sqi.SetFloat64(1 / math.Sqrt(xf))
for prec := z.prec + 32; sqi.prec < prec; {
sqi.prec *= 2
sqi = ng(sqi)
}
// sqi = 1/√x
// x/√x = √x
z.Mul(x, sqi)
}
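// The same iteration in plain float64, as a sketch of why it converges
// (assuming "fmt" is imported); each step roughly doubles the number of
// correct digits:
//
//	x, t := 2.0, 0.7 // rough initial guess for 1/√2
//	for i := 0; i < 5; i++ {
//		t = 0.5 * t * (3 - x*t*t) // t ← ½t(3 - xt²)
//	}
//	fmt.Printf("%.10f\n", x*t) // 1.4142135624, i.e. x·(1/√x) = √x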
// newFloat returns a new *Float with space for twice the given
// precision.
func newFloat(prec2 uint32) *Float {
z := new(Float)
// nat.make ensures the slice length is > 0
z.mant = z.mant.make(int(prec2/_W) * 2)
return z
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
const (
uvnan = 0x7FF8000000000001
uvinf = 0x7FF0000000000000
uvneginf = 0xFFF0000000000000
uvone = 0x3FF0000000000000
mask = 0x7FF
shift = 64 - 11 - 1
bias = 1023
signMask = 1 << 63
fracMask = 1<<shift - 1
)
// Inf returns positive infinity if sign >= 0, negative infinity if sign < 0.
func Inf(sign int) float64 {
var v uint64
if sign >= 0 {
v = uvinf
} else {
v = uvneginf
}
return Float64frombits(v)
}
// NaN returns an IEEE 754 “not-a-number” value.
func NaN() float64 { return Float64frombits(uvnan) }
// IsNaN reports whether f is an IEEE 754 “not-a-number” value.
func IsNaN(f float64) (is bool) {
// IEEE 754 says that only NaNs satisfy f != f.
// To avoid the floating-point hardware, could use:
// x := Float64bits(f);
// return uint32(x>>shift)&mask == mask && x != uvinf && x != uvneginf
return f != f
}
// IsInf reports whether f is an infinity, according to sign.
// If sign > 0, IsInf reports whether f is positive infinity.
// If sign < 0, IsInf reports whether f is negative infinity.
// If sign == 0, IsInf reports whether f is either infinity.
func IsInf(f float64, sign int) bool {
// Test for infinity by comparing against maximum float.
// To avoid the floating-point hardware, could use:
// x := Float64bits(f);
// return sign >= 0 && x == uvinf || sign <= 0 && x == uvneginf;
return sign >= 0 && f > MaxFloat64 || sign <= 0 && f < -MaxFloat64
}
// normalize returns a normal number y and exponent exp
// satisfying x == y × 2**exp. It assumes x is finite and non-zero.
func normalize(x float64) (y float64, exp int) {
const SmallestNormal = 2.2250738585072014e-308 // 2**-1022
if Abs(x) < SmallestNormal {
return x * (1 << 52), -52
}
return x, 0
}
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:generate go run make_tables.go
// Package bits implements bit counting and manipulation
// functions for the predeclared unsigned integer types.
//
// Functions in this package may be implemented directly by
// the compiler, for better performance. For those functions
// the code in this package will not be used. Which
// functions are implemented by the compiler depends on the
// architecture and the Go release.
package bits
const uintSize = 32 << (^uint(0) >> 63) // 32 or 64
// UintSize is the size of a uint in bits.
const UintSize = uintSize
// --- LeadingZeros ---
// LeadingZeros returns the number of leading zero bits in x; the result is UintSize for x == 0.
func LeadingZeros(x uint) int { return UintSize - Len(x) }
// LeadingZeros8 returns the number of leading zero bits in x; the result is 8 for x == 0.
func LeadingZeros8(x uint8) int { return 8 - Len8(x) }
// LeadingZeros16 returns the number of leading zero bits in x; the result is 16 for x == 0.
func LeadingZeros16(x uint16) int { return 16 - Len16(x) }
// LeadingZeros32 returns the number of leading zero bits in x; the result is 32 for x == 0.
func LeadingZeros32(x uint32) int { return 32 - Len32(x) }
// LeadingZeros64 returns the number of leading zero bits in x; the result is 64 for x == 0.
func LeadingZeros64(x uint64) int { return 64 - Len64(x) }
// --- TrailingZeros ---
// See http://supertech.csail.mit.edu/papers/debruijn.pdf
const deBruijn32 = 0x077CB531
var deBruijn32tab = [32]byte{
0, 1, 28, 2, 29, 14, 24, 3, 30, 22, 20, 15, 25, 17, 4, 8,
31, 27, 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9,
}
const deBruijn64 = 0x03f79d71b4ca8b09
var deBruijn64tab = [64]byte{
0, 1, 56, 2, 57, 49, 28, 3, 61, 58, 42, 50, 38, 29, 17, 4,
62, 47, 59, 36, 45, 43, 51, 22, 53, 39, 33, 30, 24, 18, 12, 5,
63, 55, 48, 27, 60, 41, 37, 16, 46, 35, 44, 21, 52, 32, 23, 11,
54, 26, 40, 15, 34, 20, 31, 10, 25, 14, 19, 9, 13, 8, 7, 6,
}
// TrailingZeros returns the number of trailing zero bits in x; the result is UintSize for x == 0.
func TrailingZeros(x uint) int {
if UintSize == 32 {
return TrailingZeros32(uint32(x))
}
return TrailingZeros64(uint64(x))
}
// TrailingZeros8 returns the number of trailing zero bits in x; the result is 8 for x == 0.
func TrailingZeros8(x uint8) int {
return int(ntz8tab[x])
}
// TrailingZeros16 returns the number of trailing zero bits in x; the result is 16 for x == 0.
func TrailingZeros16(x uint16) int {
if x == 0 {
return 16
}
// see comment in TrailingZeros64
return int(deBruijn32tab[uint32(x&-x)*deBruijn32>>(32-5)])
}
// TrailingZeros32 returns the number of trailing zero bits in x; the result is 32 for x == 0.
func TrailingZeros32(x uint32) int {
if x == 0 {
return 32
}
// see comment in TrailingZeros64
return int(deBruijn32tab[(x&-x)*deBruijn32>>(32-5)])
}
// TrailingZeros64 returns the number of trailing zero bits in x; the result is 64 for x == 0.
func TrailingZeros64(x uint64) int {
if x == 0 {
return 64
}
// If popcount is fast, replace code below with return popcount(^x & (x - 1)).
//
// x & -x leaves only the right-most bit set in the word. Let k be the
// index of that bit. Since only a single bit is set, the value is two
// to the power of k. Multiplying by a power of two is equivalent to
// left shifting, in this case by k bits. The de Bruijn (64 bit) constant
// is such that all consecutive six-bit substrings are distinct.
// Therefore, if we have a left shifted version of this constant we can
// find by how many bits it was shifted by looking at which six bit
// substring ended up at the top of the word.
// (Knuth, volume 4, section 7.3.1)
return int(deBruijn64tab[(x&-x)*deBruijn64>>(64-6)])
}
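// A quick cross-check of the de Bruijn lookup against the obvious loop
// (assuming "fmt" and "math/bits" are imported):
//
//	naive := func(x uint64) int {
//		n := 0
//		for x&1 == 0 && n < 64 {
//			x >>= 1
//			n++
//		}
//		return n
//	}
//	for _, x := range []uint64{0, 1, 8, 1 << 63} {
//		fmt.Println(bits.TrailingZeros64(x) == naive(x)) // true (×4)
//	}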
// --- OnesCount ---
const m0 = 0x5555555555555555 // 01010101 ...
const m1 = 0x3333333333333333 // 00110011 ...
const m2 = 0x0f0f0f0f0f0f0f0f // 00001111 ...
const m3 = 0x00ff00ff00ff00ff // etc.
const m4 = 0x0000ffff0000ffff
// OnesCount returns the number of one bits ("population count") in x.
func OnesCount(x uint) int {
if UintSize == 32 {
return OnesCount32(uint32(x))
}
return OnesCount64(uint64(x))
}
// OnesCount8 returns the number of one bits ("population count") in x.
func OnesCount8(x uint8) int {
return int(pop8tab[x])
}
// OnesCount16 returns the number of one bits ("population count") in x.
func OnesCount16(x uint16) int {
return int(pop8tab[x>>8] + pop8tab[x&0xff])
}
// OnesCount32 returns the number of one bits ("population count") in x.
func OnesCount32(x uint32) int {
return int(pop8tab[x>>24] + pop8tab[x>>16&0xff] + pop8tab[x>>8&0xff] + pop8tab[x&0xff])
}
// OnesCount64 returns the number of one bits ("population count") in x.
func OnesCount64(x uint64) int {
// Implementation: Parallel summing of adjacent bits.
// See "Hacker's Delight", Chap. 5: Counting Bits.
// The following pattern shows the general approach:
//
// x = x>>1&(m0&m) + x&(m0&m)
// x = x>>2&(m1&m) + x&(m1&m)
// x = x>>4&(m2&m) + x&(m2&m)
// x = x>>8&(m3&m) + x&(m3&m)
// x = x>>16&(m4&m) + x&(m4&m)
// x = x>>32&(m5&m) + x&(m5&m)
// return int(x)
//
// Masking (& operations) can be left away when there's no
// danger that a field's sum will carry over into the next
// field: Since the result cannot be > 64, 8 bits is enough
// and we can ignore the masks for the shifts by 8 and up.
// Per "Hacker's Delight", the first line can be simplified
// more, but it saves at best one instruction, so we leave
// it alone for clarity.
const m = 1<<64 - 1
x = x>>1&(m0&m) + x&(m0&m)
x = x>>2&(m1&m) + x&(m1&m)
x = (x>>4 + x) & (m2 & m)
x += x >> 8
x += x >> 16
x += x >> 32
return int(x) & (1<<7 - 1)
}
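// For illustration (assuming "fmt" and "math/bits" are imported):
//
//	fmt.Println(bits.OnesCount64(0))                  // 0
//	fmt.Println(bits.OnesCount64(1<<64-1))            // 64
//	fmt.Println(bits.OnesCount64(0xF0F0F0F0F0F0F0F0)) // 32: eight nibbles of four set bits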
// --- RotateLeft ---
// RotateLeft returns the value of x rotated left by (k mod UintSize) bits.
// To rotate x right by k bits, call RotateLeft(x, -k).
//
// This function's execution time does not depend on the inputs.
func RotateLeft(x uint, k int) uint {
if UintSize == 32 {
return uint(RotateLeft32(uint32(x), k))
}
return uint(RotateLeft64(uint64(x), k))
}
// RotateLeft8 returns the value of x rotated left by (k mod 8) bits.
// To rotate x right by k bits, call RotateLeft8(x, -k).
//
// This function's execution time does not depend on the inputs.
func RotateLeft8(x uint8, k int) uint8 {
const n = 8
s := uint(k) & (n - 1)
return x<<s | x>>(n-s)
}
// RotateLeft16 returns the value of x rotated left by (k mod 16) bits.
// To rotate x right by k bits, call RotateLeft16(x, -k).
//
// This function's execution time does not depend on the inputs.
func RotateLeft16(x uint16, k int) uint16 {
const n = 16
s := uint(k) & (n - 1)
return x<<s | x>>(n-s)
}
// RotateLeft32 returns the value of x rotated left by (k mod 32) bits.
// To rotate x right by k bits, call RotateLeft32(x, -k).
//
// This function's execution time does not depend on the inputs.
func RotateLeft32(x uint32, k int) uint32 {
const n = 32
s := uint(k) & (n - 1)
return x<<s | x>>(n-s)
}
// RotateLeft64 returns the value of x rotated left by (k mod 64) bits.
// To rotate x right by k bits, call RotateLeft64(x, -k).
//
// This function's execution time does not depend on the inputs.
func RotateLeft64(x uint64, k int) uint64 {
const n = 64
s := uint(k) & (n - 1)
return x<<s | x>>(n-s)
}
// --- Reverse ---
// Reverse returns the value of x with its bits in reversed order.
func Reverse(x uint) uint {
if UintSize == 32 {
return uint(Reverse32(uint32(x)))
}
return uint(Reverse64(uint64(x)))
}
// Reverse8 returns the value of x with its bits in reversed order.
func Reverse8(x uint8) uint8 {
return rev8tab[x]
}
// Reverse16 returns the value of x with its bits in reversed order.
func Reverse16(x uint16) uint16 {
return uint16(rev8tab[x>>8]) | uint16(rev8tab[x&0xff])<<8
}
// Reverse32 returns the value of x with its bits in reversed order.
func Reverse32(x uint32) uint32 {
const m = 1<<32 - 1
x = x>>1&(m0&m) | x&(m0&m)<<1
x = x>>2&(m1&m) | x&(m1&m)<<2
x = x>>4&(m2&m) | x&(m2&m)<<4
return ReverseBytes32(x)
}
// Reverse64 returns the value of x with its bits in reversed order.
func Reverse64(x uint64) uint64 {
const m = 1<<64 - 1
x = x>>1&(m0&m) | x&(m0&m)<<1
x = x>>2&(m1&m) | x&(m1&m)<<2
x = x>>4&(m2&m) | x&(m2&m)<<4
return ReverseBytes64(x)
}
// --- ReverseBytes ---
// ReverseBytes returns the value of x with its bytes in reversed order.
//
// This function's execution time does not depend on the inputs.
func ReverseBytes(x uint) uint {
if UintSize == 32 {
return uint(ReverseBytes32(uint32(x)))
}
return uint(ReverseBytes64(uint64(x)))
}
// ReverseBytes16 returns the value of x with its bytes in reversed order.
//
// This function's execution time does not depend on the inputs.
func ReverseBytes16(x uint16) uint16 {
return x>>8 | x<<8
}
// ReverseBytes32 returns the value of x with its bytes in reversed order.
//
// This function's execution time does not depend on the inputs.
func ReverseBytes32(x uint32) uint32 {
const m = 1<<32 - 1
x = x>>8&(m3&m) | x&(m3&m)<<8
return x>>16 | x<<16
}
// ReverseBytes64 returns the value of x with its bytes in reversed order.
//
// This function's execution time does not depend on the inputs.
func ReverseBytes64(x uint64) uint64 {
const m = 1<<64 - 1
x = x>>8&(m3&m) | x&(m3&m)<<8
x = x>>16&(m4&m) | x&(m4&m)<<16
return x>>32 | x<<32
}
// --- Len ---
// Len returns the minimum number of bits required to represent x; the result is 0 for x == 0.
func Len(x uint) int {
if UintSize == 32 {
return Len32(uint32(x))
}
return Len64(uint64(x))
}
// Len8 returns the minimum number of bits required to represent x; the result is 0 for x == 0.
func Len8(x uint8) int {
return int(len8tab[x])
}
// Len16 returns the minimum number of bits required to represent x; the result is 0 for x == 0.
func Len16(x uint16) (n int) {
if x >= 1<<8 {
x >>= 8
n = 8
}
return n + int(len8tab[x])
}
// Len32 returns the minimum number of bits required to represent x; the result is 0 for x == 0.
func Len32(x uint32) (n int) {
if x >= 1<<16 {
x >>= 16
n = 16
}
if x >= 1<<8 {
x >>= 8
n += 8
}
return n + int(len8tab[x])
}
// Len64 returns the minimum number of bits required to represent x; the result is 0 for x == 0.
func Len64(x uint64) (n int) {
if x >= 1<<32 {
x >>= 32
n = 32
}
if x >= 1<<16 {
x >>= 16
n += 16
}
if x >= 1<<8 {
x >>= 8
n += 8
}
return n + int(len8tab[x])
}
// --- Add with carry ---
// Add returns the sum with carry of x, y and carry: sum = x + y + carry.
// The carry input must be 0 or 1; otherwise the behavior is undefined.
// The carryOut output is guaranteed to be 0 or 1.
//
// This function's execution time does not depend on the inputs.
func Add(x, y, carry uint) (sum, carryOut uint) {
if UintSize == 32 {
s32, c32 := Add32(uint32(x), uint32(y), uint32(carry))
return uint(s32), uint(c32)
}
s64, c64 := Add64(uint64(x), uint64(y), uint64(carry))
return uint(s64), uint(c64)
}
// Add32 returns the sum with carry of x, y and carry: sum = x + y + carry.
// The carry input must be 0 or 1; otherwise the behavior is undefined.
// The carryOut output is guaranteed to be 0 or 1.
//
// This function's execution time does not depend on the inputs.
func Add32(x, y, carry uint32) (sum, carryOut uint32) {
sum64 := uint64(x) + uint64(y) + uint64(carry)
sum = uint32(sum64)
carryOut = uint32(sum64 >> 32)
return
}
// Add64 returns the sum with carry of x, y and carry: sum = x + y + carry.
// The carry input must be 0 or 1; otherwise the behavior is undefined.
// The carryOut output is guaranteed to be 0 or 1.
//
// This function's execution time does not depend on the inputs.
func Add64(x, y, carry uint64) (sum, carryOut uint64) {
sum = x + y + carry
// The sum will overflow if both top bits are set (x & y) or if one of them
// is (x | y), and a carry from the lower place happened. If such a carry
// happens, the top bit will be 1 + 0 + 1 = 0 (&^ sum).
carryOut = ((x & y) | ((x | y) &^ sum)) >> 63
return
}
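// A sketch of the intended use, chaining the carry through a two-word
// (128-bit) addition (assuming "fmt" and "math/bits" are imported):
//
//	var x1, x0 uint64 = 1, 1 << 63 // x = x1·2^64 + x0
//	var y1, y0 uint64 = 2, 1 << 63
//	s0, c := bits.Add64(x0, y0, 0)
//	s1, _ := bits.Add64(x1, y1, c)
//	fmt.Println(s1, s0) // 4 0: the low words overflow and carry into the high word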
// --- Subtract with borrow ---
// Sub returns the difference of x, y and borrow: diff = x - y - borrow.
// The borrow input must be 0 or 1; otherwise the behavior is undefined.
// The borrowOut output is guaranteed to be 0 or 1.
//
// This function's execution time does not depend on the inputs.
func Sub(x, y, borrow uint) (diff, borrowOut uint) {
if UintSize == 32 {
d32, b32 := Sub32(uint32(x), uint32(y), uint32(borrow))
return uint(d32), uint(b32)
}
d64, b64 := Sub64(uint64(x), uint64(y), uint64(borrow))
return uint(d64), uint(b64)
}
// Sub32 returns the difference of x, y and borrow, diff = x - y - borrow.
// The borrow input must be 0 or 1; otherwise the behavior is undefined.
// The borrowOut output is guaranteed to be 0 or 1.
//
// This function's execution time does not depend on the inputs.
func Sub32(x, y, borrow uint32) (diff, borrowOut uint32) {
diff = x - y - borrow
// The difference will underflow if the top bit of x is not set and the top
// bit of y is set (^x & y) or if they are the same (^(x ^ y)) and a borrow
// from the lower place happens. If that borrow happens, the result will be
// 1 - 1 - 1 = 0 - 0 - 1 = 1 (& diff).
borrowOut = ((^x & y) | (^(x ^ y) & diff)) >> 31
return
}
// Sub64 returns the difference of x, y and borrow: diff = x - y - borrow.
// The borrow input must be 0 or 1; otherwise the behavior is undefined.
// The borrowOut output is guaranteed to be 0 or 1.
//
// This function's execution time does not depend on the inputs.
func Sub64(x, y, borrow uint64) (diff, borrowOut uint64) {
diff = x - y - borrow
// See Sub32 for the bit logic.
borrowOut = ((^x & y) | (^(x ^ y) & diff)) >> 63
return
}
// --- Full-width multiply ---
// Mul returns the full-width product of x and y: (hi, lo) = x * y
// with the product bits' upper half returned in hi and the lower
// half returned in lo.
//
// This function's execution time does not depend on the inputs.
func Mul(x, y uint) (hi, lo uint) {
if UintSize == 32 {
h, l := Mul32(uint32(x), uint32(y))
return uint(h), uint(l)
}
h, l := Mul64(uint64(x), uint64(y))
return uint(h), uint(l)
}
// Mul32 returns the 64-bit product of x and y: (hi, lo) = x * y
// with the product bits' upper half returned in hi and the lower
// half returned in lo.
//
// This function's execution time does not depend on the inputs.
func Mul32(x, y uint32) (hi, lo uint32) {
tmp := uint64(x) * uint64(y)
hi, lo = uint32(tmp>>32), uint32(tmp)
return
}
// Mul64 returns the 128-bit product of x and y: (hi, lo) = x * y
// with the product bits' upper half returned in hi and the lower
// half returned in lo.
//
// This function's execution time does not depend on the inputs.
func Mul64(x, y uint64) (hi, lo uint64) {
const mask32 = 1<<32 - 1
x0 := x & mask32
x1 := x >> 32
y0 := y & mask32
y1 := y >> 32
w0 := x0 * y0
t := x1*y0 + w0>>32
w1 := t & mask32
w2 := t >> 32
w1 += x0 * y1
hi = x1*y1 + w2 + w1>>32
lo = x * y
return
}
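// For illustration (assuming "fmt", "math", and "math/bits" are imported):
//
//	hi, lo := bits.Mul64(math.MaxUint64, math.MaxUint64)
//	fmt.Printf("%#x %#x\n", hi, lo) // 0xfffffffffffffffe 0x1: (2^64-1)² = 2^128 - 2^65 + 1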
// --- Full-width divide ---
// Div returns the quotient and remainder of (hi, lo) divided by y:
// quo = (hi, lo)/y, rem = (hi, lo)%y with the dividend bits' upper
// half in parameter hi and the lower half in parameter lo.
// Div panics for y == 0 (division by zero) or y <= hi (quotient overflow).
func Div(hi, lo, y uint) (quo, rem uint) {
if UintSize == 32 {
q, r := Div32(uint32(hi), uint32(lo), uint32(y))
return uint(q), uint(r)
}
q, r := Div64(uint64(hi), uint64(lo), uint64(y))
return uint(q), uint(r)
}
// Div32 returns the quotient and remainder of (hi, lo) divided by y:
// quo = (hi, lo)/y, rem = (hi, lo)%y with the dividend bits' upper
// half in parameter hi and the lower half in parameter lo.
// Div32 panics for y == 0 (division by zero) or y <= hi (quotient overflow).
func Div32(hi, lo, y uint32) (quo, rem uint32) {
if y != 0 && y <= hi {
panic(overflowError)
}
z := uint64(hi)<<32 | uint64(lo)
quo, rem = uint32(z/uint64(y)), uint32(z%uint64(y))
return
}
// Div64 returns the quotient and remainder of (hi, lo) divided by y:
// quo = (hi, lo)/y, rem = (hi, lo)%y with the dividend bits' upper
// half in parameter hi and the lower half in parameter lo.
// Div64 panics for y == 0 (division by zero) or y <= hi (quotient overflow).
func Div64(hi, lo, y uint64) (quo, rem uint64) {
if y == 0 {
panic(divideError)
}
if y <= hi {
panic(overflowError)
}
// If high part is zero, we can directly return the results.
if hi == 0 {
return lo / y, lo % y
}
s := uint(LeadingZeros64(y))
y <<= s
const (
two32 = 1 << 32
mask32 = two32 - 1
)
yn1 := y >> 32
yn0 := y & mask32
un32 := hi<<s | lo>>(64-s)
un10 := lo << s
un1 := un10 >> 32
un0 := un10 & mask32
q1 := un32 / yn1
rhat := un32 - q1*yn1
for q1 >= two32 || q1*yn0 > two32*rhat+un1 {
q1--
rhat += yn1
if rhat >= two32 {
break
}
}
un21 := un32*two32 + un1 - q1*y
q0 := un21 / yn1
rhat = un21 - q0*yn1
for q0 >= two32 || q0*yn0 > two32*rhat+un0 {
q0--
rhat += yn1
if rhat >= two32 {
break
}
}
return q1*two32 + q0, (un21*two32 + un0 - q0*y) >> s
}
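// A sketch dividing the 128-bit value 2^64 (hi = 1, lo = 0) by 3
// (assuming "fmt" and "math/bits" are imported); note that y > hi is
// required to avoid the overflow panic:
//
//	quo, rem := bits.Div64(1, 0, 3)
//	fmt.Println(quo, rem) // 6148914691236517205 1, since 3·quo + 1 = 2^64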
// Rem returns the remainder of (hi, lo) divided by y. Rem panics for
// y == 0 (division by zero) but, unlike Div, it doesn't panic on a
// quotient overflow.
func Rem(hi, lo, y uint) uint {
if UintSize == 32 {
return uint(Rem32(uint32(hi), uint32(lo), uint32(y)))
}
return uint(Rem64(uint64(hi), uint64(lo), uint64(y)))
}
// Rem32 returns the remainder of (hi, lo) divided by y. Rem32 panics
// for y == 0 (division by zero) but, unlike Div32, it doesn't panic
// on a quotient overflow.
func Rem32(hi, lo, y uint32) uint32 {
return uint32((uint64(hi)<<32 | uint64(lo)) % uint64(y))
}
// Rem64 returns the remainder of (hi, lo) divided by y. Rem64 panics
// for y == 0 (division by zero) but, unlike Div64, it doesn't panic
// on a quotient overflow.
func Rem64(hi, lo, y uint64) uint64 {
// We scale down hi so that hi < y, then use Div64 to compute the
// rem with the guarantee that it won't panic on quotient overflow.
// Given that
// hi ≡ hi%y (mod y)
// we have
// hi<<64 + lo ≡ (hi%y)<<64 + lo (mod y)
_, rem := Div64(hi%y, lo, y)
return rem
}
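// For illustration (assuming "fmt" and "math/bits" are imported);
// Div64(5, 0, 3) would panic because the quotient overflows 64 bits:
//
//	fmt.Println(bits.Rem64(5, 0, 3)) // 2, since 2^64 ≡ 1 (mod 3) and so 5·2^64 ≡ 5 ≡ 2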
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
// The Go code is a modified version of the original C code from
// http://www.netlib.org/fdlibm/s_cbrt.c and came with this notice.
//
// ====================================================
// Copyright (C) 1993 by Sun Microsystems, Inc. All rights reserved.
//
// Developed at SunSoft, a Sun Microsystems, Inc. business.
// Permission to use, copy, modify, and distribute this
// software is freely granted, provided that this notice
// is preserved.
// ====================================================
// Cbrt returns the cube root of x.
//
// Special cases are:
//
// Cbrt(±0) = ±0
// Cbrt(±Inf) = ±Inf
// Cbrt(NaN) = NaN
func Cbrt(x float64) float64 {
if haveArchCbrt {
return archCbrt(x)
}
return cbrt(x)
}
func cbrt(x float64) float64 {
const (
B1 = 715094163 // (682-0.03306235651)*2**20
B2 = 696219795 // (664-0.03306235651)*2**20
C = 5.42857142857142815906e-01 // 19/35 = 0x3FE15F15F15F15F1
D = -7.05306122448979611050e-01 // -864/1225 = 0xBFE691DE2532C834
E = 1.41428571428571436819e+00 // 99/70 = 0x3FF6A0EA0EA0EA0F
F = 1.60714285714285720630e+00 // 45/28 = 0x3FF9B6DB6DB6DB6E
G = 3.57142857142857150787e-01 // 5/14 = 0x3FD6DB6DB6DB6DB7
SmallestNormal = 2.22507385850720138309e-308 // 2**-1022 = 0x0010000000000000
)
// special cases
switch {
case x == 0 || IsNaN(x) || IsInf(x, 0):
return x
}
sign := false
if x < 0 {
x = -x
sign = true
}
// rough cbrt to 5 bits
t := Float64frombits(Float64bits(x)/3 + B1<<32)
if x < SmallestNormal {
// subnormal number
t = float64(1 << 54) // set t= 2**54
t *= x
t = Float64frombits(Float64bits(t)/3 + B2<<32)
}
// new cbrt to 23 bits
r := t * t / x
s := C + r*t
t *= G + F/(s+E+D/s)
// chop to 22 bits, make larger than cbrt(x)
t = Float64frombits(Float64bits(t)&(0xFFFFFFFFC<<28) + 1<<30)
// one step newton iteration to 53 bits with error less than 0.667ulps
s = t * t // t*t is exact
r = x / s
w := t + t
r = (r - t) / (w + r) // r-s is exact
t = t + t*r
// restore the sign bit
if sign {
t = -t
}
return t
}
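// For illustration (assuming "fmt" and "math" are imported):
//
//	fmt.Printf("%.4f\n", math.Cbrt(27)) // 3.0000
//	fmt.Printf("%.4f\n", math.Cbrt(-8)) // -2.0000: odd function, sign restored at the end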
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package cmplx provides basic constants and mathematical functions for
// complex numbers. Special case handling conforms to the C99 standard
// Annex G IEC 60559-compatible complex arithmetic.
package cmplx
import "math"
// Abs returns the absolute value (also called the modulus) of x.
func Abs(x complex128) float64 { return math.Hypot(real(x), imag(x)) }
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cmplx
import "math"
// The original C code, the long comment, and the constants
// below are from http://netlib.sandia.gov/cephes/c9x-complex/clog.c.
// The Go code is a simplified version of the original C.
//
// Cephes Math Library Release 2.8: June, 2000
// Copyright 1984, 1987, 1989, 1992, 2000 by Stephen L. Moshier
//
// The readme file at http://netlib.sandia.gov/cephes/ says:
// Some software in this archive may be from the book _Methods and
// Programs for Mathematical Functions_ (Prentice-Hall or Simon & Schuster
// International, 1989) or from the Cephes Mathematical Library, a
// commercial product. In either event, it is copyrighted by the author.
// What you see here may be used freely but it comes with no support or
// guarantee.
//
// The two known misprints in the book are repaired here in the
// source listings for the gamma function and the incomplete beta
// integral.
//
// Stephen L. Moshier
// moshier@na-net.ornl.gov
// Complex circular arc sine
//
// DESCRIPTION:
//
// Inverse complex sine:
// w = -i clog(iz + csqrt(1 - z²)).
//
// casin(z) = -i casinh(iz)
//
// ACCURACY:
//
// Relative error:
// arithmetic domain # trials peak rms
// DEC -10,+10 10100 2.1e-15 3.4e-16
// IEEE -10,+10 30000 2.2e-14 2.7e-15
// Larger relative error can be observed for z near zero.
// Also tested by csin(casin(z)) = z.
// Asin returns the inverse sine of x.
func Asin(x complex128) complex128 {
switch re, im := real(x), imag(x); {
case im == 0 && math.Abs(re) <= 1:
return complex(math.Asin(re), im)
case re == 0 && math.Abs(im) <= 1:
return complex(re, math.Asinh(im))
case math.IsNaN(im):
switch {
case re == 0:
return complex(re, math.NaN())
case math.IsInf(re, 0):
return complex(math.NaN(), re)
default:
return NaN()
}
case math.IsInf(im, 0):
switch {
case math.IsNaN(re):
return x
case math.IsInf(re, 0):
return complex(math.Copysign(math.Pi/4, re), im)
default:
return complex(math.Copysign(0, re), im)
}
case math.IsInf(re, 0):
return complex(math.Copysign(math.Pi/2, re), math.Copysign(re, im))
}
ct := complex(-imag(x), real(x)) // i * x
xx := x * x
x1 := complex(1-real(xx), -imag(xx)) // 1 - x*x
x2 := Sqrt(x1) // x2 = sqrt(1 - x*x)
w := Log(ct + x2)
return complex(imag(w), -real(w)) // -i * w
}
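// A round-trip sketch (assuming "fmt" and "math/cmplx" are imported):
//
//	z := complex(0.5, 0.3)
//	fmt.Printf("%.4f\n", cmplx.Sin(cmplx.Asin(z))) // (0.5000+0.3000i)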
// Asinh returns the inverse hyperbolic sine of x.
func Asinh(x complex128) complex128 {
switch re, im := real(x), imag(x); {
case im == 0 && math.Abs(re) <= 1:
return complex(math.Asinh(re), im)
case re == 0 && math.Abs(im) <= 1:
return complex(re, math.Asin(im))
case math.IsInf(re, 0):
switch {
case math.IsInf(im, 0):
return complex(re, math.Copysign(math.Pi/4, im))
case math.IsNaN(im):
return x
default:
return complex(re, math.Copysign(0.0, im))
}
case math.IsNaN(re):
switch {
case im == 0:
return x
case math.IsInf(im, 0):
return complex(im, re)
default:
return NaN()
}
case math.IsInf(im, 0):
return complex(math.Copysign(im, re), math.Copysign(math.Pi/2, im))
}
xx := x * x
x1 := complex(1+real(xx), imag(xx)) // 1 + x*x
return Log(x + Sqrt(x1)) // log(x + sqrt(1 + x*x))
}
// Complex circular arc cosine
//
// DESCRIPTION:
//
// w = arccos z = PI/2 - arcsin z.
//
// ACCURACY:
//
// Relative error:
// arithmetic domain # trials peak rms
// DEC -10,+10 5200 1.6e-15 2.8e-16
// IEEE -10,+10 30000 1.8e-14 2.2e-15
// Acos returns the inverse cosine of x.
func Acos(x complex128) complex128 {
w := Asin(x)
return complex(math.Pi/2-real(w), -imag(w))
}
// Acosh returns the inverse hyperbolic cosine of x.
func Acosh(x complex128) complex128 {
if x == 0 {
return complex(0, math.Copysign(math.Pi/2, imag(x)))
}
w := Acos(x)
if imag(w) <= 0 {
return complex(-imag(w), real(w)) // i * w
}
return complex(imag(w), -real(w)) // -i * w
}
// Complex circular arc tangent
//
// DESCRIPTION:
//
// If
// z = x + iy,
//
// then
// Re w = (1/2) arctan( 2x / (1 - x**2 - y**2) ) + k PI
//
// Im w = (1/4) log( (x**2 + (y+1)**2) / (x**2 + (y-1)**2) )
//
// Where k is an arbitrary integer.
//
// catan(z) = -i catanh(iz).
//
// ACCURACY:
//
// Relative error:
// arithmetic domain # trials peak rms
// DEC -10,+10 5900 1.3e-16 7.8e-18
// IEEE -10,+10 30000 2.3e-15 8.5e-17
// The check catan( ctan(z) ) = z, with |x| and |y| < PI/2,
// had peak relative error 1.5e-16, rms relative error
// 2.9e-17. See also clog().
// Atan returns the inverse tangent of x.
func Atan(x complex128) complex128 {
switch re, im := real(x), imag(x); {
case im == 0:
return complex(math.Atan(re), im)
case re == 0 && math.Abs(im) <= 1:
return complex(re, math.Atanh(im))
case math.IsInf(im, 0) || math.IsInf(re, 0):
if math.IsNaN(re) {
return complex(math.NaN(), math.Copysign(0, im))
}
return complex(math.Copysign(math.Pi/2, re), math.Copysign(0, im))
case math.IsNaN(re) || math.IsNaN(im):
return NaN()
}
x2 := real(x) * real(x)
a := 1 - x2 - imag(x)*imag(x)
if a == 0 {
return NaN()
}
t := 0.5 * math.Atan2(2*real(x), a)
w := reducePi(t)
t = imag(x) - 1
b := x2 + t*t
if b == 0 {
return NaN()
}
t = imag(x) + 1
c := (x2 + t*t) / b
return complex(w, 0.25*math.Log(c))
}
// Atanh returns the inverse hyperbolic tangent of x.
func Atanh(x complex128) complex128 {
z := complex(-imag(x), real(x)) // z = i * x
z = Atan(z)
return complex(imag(z), -real(z)) // z = -i * z
}
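// Atan and Atanh illustrate a pattern used throughout this package:
// multiplying by i or -i is just a component swap, so each hyperbolic function
// rotates its argument, calls the circular function, and rotates back. A pair
// of hypothetical helpers makes the rotations explicit:
func mulI(z complex128) complex128 { return complex(-imag(z), real(z)) } // i*z
func divI(z complex128) complex128 { return complex(imag(z), -real(z)) } // -i*z, i.e. z/i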
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cmplx
// Conj returns the complex conjugate of x.
func Conj(x complex128) complex128 { return complex(real(x), -imag(x)) }
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cmplx
import "math"
// The original C code, the long comment, and the constants
// below are from http://netlib.sandia.gov/cephes/c9x-complex/clog.c.
// The go code is a simplified version of the original C.
//
// Cephes Math Library Release 2.8: June, 2000
// Copyright 1984, 1987, 1989, 1992, 2000 by Stephen L. Moshier
//
// The readme file at http://netlib.sandia.gov/cephes/ says:
// Some software in this archive may be from the book _Methods and
// Programs for Mathematical Functions_ (Prentice-Hall or Simon & Schuster
// International, 1989) or from the Cephes Mathematical Library, a
// commercial product. In either event, it is copyrighted by the author.
// What you see here may be used freely but it comes with no support or
// guarantee.
//
// The two known misprints in the book are repaired here in the
// source listings for the gamma function and the incomplete beta
// integral.
//
// Stephen L. Moshier
// moshier@na-net.ornl.gov
// Complex exponential function
//
// DESCRIPTION:
//
// Returns the complex exponential of the complex argument z.
//
// If
// z = x + iy,
// r = exp(x),
// then
// w = r cos y + i r sin y.
//
// ACCURACY:
//
// Relative error:
// arithmetic domain # trials peak rms
// DEC -10,+10 8700 3.7e-17 1.1e-17
// IEEE -10,+10 30000 3.0e-16 8.7e-17
// Exp returns e**x, the base-e exponential of x.
func Exp(x complex128) complex128 {
switch re, im := real(x), imag(x); {
case math.IsInf(re, 0):
switch {
case re > 0 && im == 0:
return x
case math.IsInf(im, 0) || math.IsNaN(im):
if re < 0 {
return complex(0, math.Copysign(0, im))
} else {
return complex(math.Inf(1), math.NaN())
}
}
case math.IsNaN(re):
if im == 0 {
return complex(math.NaN(), im)
}
}
r := math.Exp(real(x))
s, c := math.Sincos(imag(x))
return complex(r*c, r*s)
}
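// A quick sketch (hypothetical, not part of this package): for purely
// imaginary arguments r = exp(0) = 1 above, so Exp reduces to Euler's formula
// e**(iθ) = cos θ + i sin θ:
func eulerSketch(theta float64) complex128 {
	return Exp(complex(0, theta)) // == complex(math.Cos(theta), math.Sin(theta)) up to rounding
}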
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cmplx
import "math"
// IsInf reports whether either real(x) or imag(x) is an infinity.
func IsInf(x complex128) bool {
if math.IsInf(real(x), 0) || math.IsInf(imag(x), 0) {
return true
}
return false
}
// Inf returns a complex infinity, complex(+Inf, +Inf).
func Inf() complex128 {
inf := math.Inf(1)
return complex(inf, inf)
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cmplx
import "math"
// IsNaN reports whether either real(x) or imag(x) is NaN
// and neither is an infinity.
func IsNaN(x complex128) bool {
switch {
case math.IsInf(real(x), 0) || math.IsInf(imag(x), 0):
return false
case math.IsNaN(real(x)) || math.IsNaN(imag(x)):
return true
}
return false
}
// NaN returns a complex “not-a-number” value.
func NaN() complex128 {
nan := math.NaN()
return complex(nan, nan)
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cmplx
import "math"
// The original C code, the long comment, and the constants
// below are from http://netlib.sandia.gov/cephes/c9x-complex/clog.c.
// The go code is a simplified version of the original C.
//
// Cephes Math Library Release 2.8: June, 2000
// Copyright 1984, 1987, 1989, 1992, 2000 by Stephen L. Moshier
//
// The readme file at http://netlib.sandia.gov/cephes/ says:
// Some software in this archive may be from the book _Methods and
// Programs for Mathematical Functions_ (Prentice-Hall or Simon & Schuster
// International, 1989) or from the Cephes Mathematical Library, a
// commercial product. In either event, it is copyrighted by the author.
// What you see here may be used freely but it comes with no support or
// guarantee.
//
// The two known misprints in the book are repaired here in the
// source listings for the gamma function and the incomplete beta
// integral.
//
// Stephen L. Moshier
// moshier@na-net.ornl.gov
// Complex natural logarithm
//
// DESCRIPTION:
//
// Returns complex logarithm to the base e (2.718...) of
// the complex argument z.
//
// If
// z = x + iy, r = sqrt( x**2 + y**2 ),
// then
// w = log(r) + i arctan(y/x).
//
// The arctangent ranges from -PI to +PI.
//
// ACCURACY:
//
// Relative error:
// arithmetic domain # trials peak rms
// DEC -10,+10 7000 8.5e-17 1.9e-17
// IEEE -10,+10 30000 5.0e-15 1.1e-16
//
// Larger relative error can be observed for z near 1 +i0.
// In IEEE arithmetic the peak absolute error is 5.2e-16, rms
// absolute error 1.0e-16.
// Log returns the natural logarithm of x.
func Log(x complex128) complex128 {
return complex(math.Log(Abs(x)), Phase(x))
}
// Log10 returns the decimal logarithm of x.
func Log10(x complex128) complex128 {
z := Log(x)
return complex(math.Log10E*real(z), math.Log10E*imag(z))
}
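// Log10 above is the change-of-base rule log10(x) = log(x)/ln(10) = Log10E*log(x).
// A hypothetical logBase generalizes it to any real base b > 0, b != 1:
func logBase(b float64, x complex128) complex128 {
	z := Log(x)
	k := 1 / math.Log(b) // scale both components by 1/ln(b)
	return complex(k*real(z), k*imag(z))
}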
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cmplx
import "math"
// Phase returns the phase (also called the argument) of x.
// The returned value is in the range [-Pi, Pi].
func Phase(x complex128) float64 { return math.Atan2(imag(x), real(x)) }
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cmplx
// Polar returns the absolute value r and phase θ of x,
// such that x = r * e**θi.
// The phase is in the range [-Pi, Pi].
func Polar(x complex128) (r, θ float64) {
return Abs(x), Phase(x)
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cmplx
import "math"
// The original C code, the long comment, and the constants
// below are from http://netlib.sandia.gov/cephes/c9x-complex/clog.c.
// The go code is a simplified version of the original C.
//
// Cephes Math Library Release 2.8: June, 2000
// Copyright 1984, 1987, 1989, 1992, 2000 by Stephen L. Moshier
//
// The readme file at http://netlib.sandia.gov/cephes/ says:
// Some software in this archive may be from the book _Methods and
// Programs for Mathematical Functions_ (Prentice-Hall or Simon & Schuster
// International, 1989) or from the Cephes Mathematical Library, a
// commercial product. In either event, it is copyrighted by the author.
// What you see here may be used freely but it comes with no support or
// guarantee.
//
// The two known misprints in the book are repaired here in the
// source listings for the gamma function and the incomplete beta
// integral.
//
// Stephen L. Moshier
// moshier@na-net.ornl.gov
// Complex power function
//
// DESCRIPTION:
//
// Raises complex A to the complex Zth power.
// Definition is per AMS55 # 4.2.8,
// analytically equivalent to cpow(a,z) = cexp(z clog(a)).
//
// ACCURACY:
//
// Relative error:
// arithmetic domain # trials peak rms
// IEEE -10,+10 30000 9.4e-15 1.5e-15
// Pow returns x**y, the base-x exponential of y.
// For generalized compatibility with math.Pow:
//
// Pow(0, ±0) returns 1+0i
// Pow(0, c) for real(c)<0 returns Inf+0i if imag(c) is zero, otherwise Inf+Inf i.
func Pow(x, y complex128) complex128 {
if x == 0 { // Guaranteed also true for x == -0.
if IsNaN(y) {
return NaN()
}
r, i := real(y), imag(y)
switch {
case r == 0:
return 1
case r < 0:
if i == 0 {
return complex(math.Inf(1), 0)
}
return Inf()
case r > 0:
return 0
}
panic("not reached")
}
modulus := Abs(x)
if modulus == 0 {
return complex(0, 0)
}
r := math.Pow(modulus, real(y))
arg := Phase(x)
theta := real(y) * arg
if imag(y) != 0 {
r *= math.Exp(-imag(y) * arg)
theta += imag(y) * math.Log(modulus)
}
s, c := math.Sincos(theta)
return complex(r*c, r*s)
}
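// For reference (a hypothetical sketch, not part of this package): away from
// the x == 0 special cases, Pow agrees with the analytic definition
// cpow(x, y) = cexp(y*clog(x)) quoted above; the explicit modulus/phase code
// avoids Log's singularity at zero and saves work when imag(y) == 0.
func powNaive(x, y complex128) complex128 {
	return Exp(y * Log(x))
}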
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cmplx
import "math"
// Rect returns the complex number x with polar coordinates r, θ.
func Rect(r, θ float64) complex128 {
s, c := math.Sincos(θ)
return complex(r*c, r*s)
}
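// Rect and Polar are inverses up to rounding, so a value survives the polar
// round trip; a small hypothetical sketch:
func polarRoundTrip(x complex128) complex128 {
	r, θ := Polar(x)
	return Rect(r, θ) // ≈ x
}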
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cmplx
import "math"
// The original C code, the long comment, and the constants
// below are from http://netlib.sandia.gov/cephes/c9x-complex/clog.c.
// The go code is a simplified version of the original C.
//
// Cephes Math Library Release 2.8: June, 2000
// Copyright 1984, 1987, 1989, 1992, 2000 by Stephen L. Moshier
//
// The readme file at http://netlib.sandia.gov/cephes/ says:
// Some software in this archive may be from the book _Methods and
// Programs for Mathematical Functions_ (Prentice-Hall or Simon & Schuster
// International, 1989) or from the Cephes Mathematical Library, a
// commercial product. In either event, it is copyrighted by the author.
// What you see here may be used freely but it comes with no support or
// guarantee.
//
// The two known misprints in the book are repaired here in the
// source listings for the gamma function and the incomplete beta
// integral.
//
// Stephen L. Moshier
// moshier@na-net.ornl.gov
// Complex circular sine
//
// DESCRIPTION:
//
// If
// z = x + iy,
//
// then
//
// w = sin x cosh y + i cos x sinh y.
//
// csin(z) = -i csinh(iz).
//
// ACCURACY:
//
// Relative error:
// arithmetic domain # trials peak rms
// DEC -10,+10 8400 5.3e-17 1.3e-17
// IEEE -10,+10 30000 3.8e-16 1.0e-16
// Also tested by csin(casin(z)) = z.
// Sin returns the sine of x.
func Sin(x complex128) complex128 {
switch re, im := real(x), imag(x); {
case im == 0 && (math.IsInf(re, 0) || math.IsNaN(re)):
return complex(math.NaN(), im)
case math.IsInf(im, 0):
switch {
case re == 0:
return x
case math.IsInf(re, 0) || math.IsNaN(re):
return complex(math.NaN(), im)
}
case re == 0 && math.IsNaN(im):
return x
}
s, c := math.Sincos(real(x))
sh, ch := sinhcosh(imag(x))
return complex(s*ch, c*sh)
}
// Complex hyperbolic sine
//
// DESCRIPTION:
//
// csinh z = (cexp(z) - cexp(-z))/2
// = sinh x * cos y + i cosh x * sin y .
//
// ACCURACY:
//
// Relative error:
// arithmetic domain # trials peak rms
// IEEE -10,+10 30000 3.1e-16 8.2e-17
// Sinh returns the hyperbolic sine of x.
func Sinh(x complex128) complex128 {
switch re, im := real(x), imag(x); {
case re == 0 && (math.IsInf(im, 0) || math.IsNaN(im)):
return complex(re, math.NaN())
case math.IsInf(re, 0):
switch {
case im == 0:
return complex(re, im)
case math.IsInf(im, 0) || math.IsNaN(im):
return complex(re, math.NaN())
}
case im == 0 && math.IsNaN(re):
return complex(math.NaN(), im)
}
s, c := math.Sincos(imag(x))
sh, ch := sinhcosh(real(x))
return complex(c*sh, s*ch)
}
// Complex circular cosine
//
// DESCRIPTION:
//
// If
// z = x + iy,
//
// then
//
// w = cos x cosh y - i sin x sinh y.
//
// ACCURACY:
//
// Relative error:
// arithmetic domain # trials peak rms
// DEC -10,+10 8400 4.5e-17 1.3e-17
// IEEE -10,+10 30000 3.8e-16 1.0e-16
// Cos returns the cosine of x.
func Cos(x complex128) complex128 {
switch re, im := real(x), imag(x); {
case im == 0 && (math.IsInf(re, 0) || math.IsNaN(re)):
return complex(math.NaN(), -im*math.Copysign(0, re))
case math.IsInf(im, 0):
switch {
case re == 0:
return complex(math.Inf(1), -re*math.Copysign(0, im))
case math.IsInf(re, 0) || math.IsNaN(re):
return complex(math.Inf(1), math.NaN())
}
case re == 0 && math.IsNaN(im):
return complex(math.NaN(), 0)
}
s, c := math.Sincos(real(x))
sh, ch := sinhcosh(imag(x))
return complex(c*ch, -s*sh)
}
// Complex hyperbolic cosine
//
// DESCRIPTION:
//
// ccosh(z) = cosh x cos y + i sinh x sin y .
//
// ACCURACY:
//
// Relative error:
// arithmetic domain # trials peak rms
// IEEE -10,+10 30000 2.9e-16 8.1e-17
// Cosh returns the hyperbolic cosine of x.
func Cosh(x complex128) complex128 {
switch re, im := real(x), imag(x); {
case re == 0 && (math.IsInf(im, 0) || math.IsNaN(im)):
return complex(math.NaN(), re*math.Copysign(0, im))
case math.IsInf(re, 0):
switch {
case im == 0:
return complex(math.Inf(1), im*math.Copysign(0, re))
case math.IsInf(im, 0) || math.IsNaN(im):
return complex(math.Inf(1), math.NaN())
}
case im == 0 && math.IsNaN(re):
return complex(math.NaN(), im)
}
s, c := math.Sincos(imag(x))
sh, ch := sinhcosh(real(x))
return complex(c*ch, s*sh)
}
// sinhcosh returns sinh(x) and cosh(x) computed from a single exponential
// when |x| > 0.5, using sinh x = (e**x - e**-x)/2 and cosh x = (e**x + e**-x)/2;
// for small |x| the library routines are used directly, since the subtraction
// e**x - e**-x would cancel near zero.
func sinhcosh(x float64) (sh, ch float64) {
if math.Abs(x) <= 0.5 {
return math.Sinh(x), math.Cosh(x)
}
e := math.Exp(x)
ei := 0.5 / e
e *= 0.5
return e - ei, e + ei
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cmplx
import "math"
// The original C code, the long comment, and the constants
// below are from http://netlib.sandia.gov/cephes/c9x-complex/clog.c.
// The go code is a simplified version of the original C.
//
// Cephes Math Library Release 2.8: June, 2000
// Copyright 1984, 1987, 1989, 1992, 2000 by Stephen L. Moshier
//
// The readme file at http://netlib.sandia.gov/cephes/ says:
// Some software in this archive may be from the book _Methods and
// Programs for Mathematical Functions_ (Prentice-Hall or Simon & Schuster
// International, 1989) or from the Cephes Mathematical Library, a
// commercial product. In either event, it is copyrighted by the author.
// What you see here may be used freely but it comes with no support or
// guarantee.
//
// The two known misprints in the book are repaired here in the
// source listings for the gamma function and the incomplete beta
// integral.
//
// Stephen L. Moshier
// moshier@na-net.ornl.gov
// Complex square root
//
// DESCRIPTION:
//
// If z = x + iy, r = |z|, then
//
// Re w = [ (r + x)/2 ]**(1/2),
// Im w = [ (r - x)/2 ]**(1/2).
//
// Cancellation error in r-x or r+x is avoided by using the
// identity 2 Re w Im w = y.
//
// Note that -w is also a square root of z. The root chosen
// is always in the right half plane and Im w has the same sign as y.
//
// ACCURACY:
//
// Relative error:
// arithmetic domain # trials peak rms
// DEC -10,+10 25000 3.2e-17 9.6e-18
// IEEE -10,+10 1,000,000 2.9e-16 6.1e-17
// Sqrt returns the square root of x.
// The result r is chosen so that real(r) ≥ 0 and imag(r) has the same sign as imag(x).
func Sqrt(x complex128) complex128 {
if imag(x) == 0 {
// Ensure that imag(r) has the same sign as imag(x) when imag(x) is a signed zero.
if real(x) == 0 {
return complex(0, imag(x))
}
if real(x) < 0 {
return complex(0, math.Copysign(math.Sqrt(-real(x)), imag(x)))
}
return complex(math.Sqrt(real(x)), imag(x))
} else if math.IsInf(imag(x), 0) {
return complex(math.Inf(1), imag(x))
}
if real(x) == 0 {
if imag(x) < 0 {
r := math.Sqrt(-0.5 * imag(x))
return complex(r, -r)
}
r := math.Sqrt(0.5 * imag(x))
return complex(r, r)
}
a := real(x)
b := imag(x)
var scale float64
// Rescale to avoid internal overflow or underflow.
if math.Abs(a) > 4 || math.Abs(b) > 4 {
a *= 0.25
b *= 0.25
scale = 2
} else {
a *= 1.8014398509481984e16 // 2**54
b *= 1.8014398509481984e16
scale = 7.450580596923828125e-9 // 2**-27
}
r := math.Hypot(a, b)
var t float64
if a > 0 {
t = math.Sqrt(0.5*r + 0.5*a)
r = scale * math.Abs((0.5*b)/t)
t *= scale
} else {
r = math.Sqrt(0.5*r - 0.5*a)
t = scale * math.Abs((0.5*b)/r)
r *= scale
}
if b < 0 {
return complex(t, -r)
}
return complex(t, r)
}
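// A hypothetical usage sketch: squaring the result recovers the argument up to
// rounding, and the documented branch choice can be checked directly:
func sqrtCheck(x complex128) (square complex128, rightHalfPlane bool) {
	w := Sqrt(x)
	return w * w, real(w) >= 0 // square ≈ x; rightHalfPlane is always true
}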
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cmplx
import (
"math"
"math/bits"
)
// The original C code, the long comment, and the constants
// below are from http://netlib.sandia.gov/cephes/c9x-complex/clog.c.
// The go code is a simplified version of the original C.
//
// Cephes Math Library Release 2.8: June, 2000
// Copyright 1984, 1987, 1989, 1992, 2000 by Stephen L. Moshier
//
// The readme file at http://netlib.sandia.gov/cephes/ says:
// Some software in this archive may be from the book _Methods and
// Programs for Mathematical Functions_ (Prentice-Hall or Simon & Schuster
// International, 1989) or from the Cephes Mathematical Library, a
// commercial product. In either event, it is copyrighted by the author.
// What you see here may be used freely but it comes with no support or
// guarantee.
//
// The two known misprints in the book are repaired here in the
// source listings for the gamma function and the incomplete beta
// integral.
//
// Stephen L. Moshier
// moshier@na-net.ornl.gov
// Complex circular tangent
//
// DESCRIPTION:
//
// If
// z = x + iy,
//
// then
//
// w = (sin 2x + i sinh 2y) / (cos 2x + cosh 2y).
//
// On the real axis the denominator is zero at odd multiples
// of PI/2. The denominator is evaluated by its Taylor
// series near these points.
//
// ctan(z) = -i ctanh(iz).
//
// ACCURACY:
//
// Relative error:
// arithmetic domain # trials peak rms
// DEC -10,+10 5200 7.1e-17 1.6e-17
// IEEE -10,+10 30000 7.2e-16 1.2e-16
// Also tested by ctan * ccot = 1 and catan(ctan(z)) = z.
// Tan returns the tangent of x.
func Tan(x complex128) complex128 {
switch re, im := real(x), imag(x); {
case math.IsInf(im, 0):
switch {
case math.IsInf(re, 0) || math.IsNaN(re):
return complex(math.Copysign(0, re), math.Copysign(1, im))
}
return complex(math.Copysign(0, math.Sin(2*re)), math.Copysign(1, im))
case re == 0 && math.IsNaN(im):
return x
}
d := math.Cos(2*real(x)) + math.Cosh(2*imag(x))
if math.Abs(d) < 0.25 {
d = tanSeries(x)
}
if d == 0 {
return Inf()
}
return complex(math.Sin(2*real(x))/d, math.Sinh(2*imag(x))/d)
}
// Complex hyperbolic tangent
//
// DESCRIPTION:
//
// tanh z = (sinh 2x + i sin 2y) / (cosh 2x + cos 2y) .
//
// ACCURACY:
//
// Relative error:
// arithmetic domain # trials peak rms
// IEEE -10,+10 30000 1.7e-14 2.4e-16
// Tanh returns the hyperbolic tangent of x.
func Tanh(x complex128) complex128 {
switch re, im := real(x), imag(x); {
case math.IsInf(re, 0):
switch {
case math.IsInf(im, 0) || math.IsNaN(im):
return complex(math.Copysign(1, re), math.Copysign(0, im))
}
return complex(math.Copysign(1, re), math.Copysign(0, math.Sin(2*im)))
case im == 0 && math.IsNaN(re):
return x
}
d := math.Cosh(2*real(x)) + math.Cos(2*imag(x))
if d == 0 {
return Inf()
}
return complex(math.Sinh(2*real(x))/d, math.Sin(2*imag(x))/d)
}
// reducePi reduces the input argument x to the range (-Pi/2, Pi/2].
// x must be greater than or equal to 0. For small arguments it
// uses Cody-Waite reduction in 3 float64 parts based on:
// "Elementary Function Evaluation: Algorithms and Implementation"
// Jean-Michel Muller, 1997.
// For very large arguments it uses Payne-Hanek range reduction based on:
// "ARGUMENT REDUCTION FOR HUGE ARGUMENTS: Good to the Last Bit"
// K. C. Ng et al, March 24, 1992.
func reducePi(x float64) float64 {
// reduceThreshold is the maximum value of x where the reduction using
// Cody-Waite reduction still gives accurate results. This threshold
// is set by t*PIn being representable as a float64 without error
// where t is given by t = floor(x * (1 / Pi)) and PIn are the leading partial
// terms of Pi. Since the leading terms, PI1 and PI2 below, have 30 and 32
// trailing zero bits respectively, t should have less than 30 significant bits.
// t < 1<<30 -> floor(x*(1/Pi)+0.5) < 1<<30 -> x < (1<<30-1) * Pi - 0.5
// So, conservatively we can take x < 1<<30.
const reduceThreshold float64 = 1 << 30
if math.Abs(x) < reduceThreshold {
// Use Cody-Waite reduction in three parts.
const (
// PI1, PI2 and PI3 comprise an extended precision value of PI
// such that PI ~= PI1 + PI2 + PI3. The parts are chosen so
// that PI1 and PI2 have an approximately equal number of trailing
// zero bits. This ensures that t*PI1 and t*PI2 are exact for
// large integer values of t. The full precision PI3 ensures the
// approximation of PI is accurate to 102 bits to handle cancellation
// during subtraction.
PI1 = 3.141592502593994 // 0x400921fb40000000
PI2 = 1.5099578831723193e-07 // 0x3e84442d00000000
PI3 = 1.0780605716316238e-14 // 0x3d08469898cc5170
)
t := x / math.Pi
t += 0.5
t = float64(int64(t)) // int64(t) = the multiple
return ((x - t*PI1) - t*PI2) - t*PI3
}
// Must apply Payne-Hanek range reduction
const (
mask = 0x7FF
shift = 64 - 11 - 1
bias = 1023
fracMask = 1<<shift - 1
)
// Extract out the integer and exponent such that,
// x = ix * 2 ** exp.
ix := math.Float64bits(x)
exp := int(ix>>shift&mask) - bias - shift
ix &= fracMask
ix |= 1 << shift
// mPi is the binary digits of 1/Pi as a uint64 array,
// that is, 1/Pi = Sum mPi[i]*2^(-64*i).
// 19 64-bit digits give 1216 bits of precision
// to handle the largest possible float64 exponent.
var mPi = [...]uint64{
0x0000000000000000,
0x517cc1b727220a94,
0xfe13abe8fa9a6ee0,
0x6db14acc9e21c820,
0xff28b1d5ef5de2b0,
0xdb92371d2126e970,
0x0324977504e8c90e,
0x7f0ef58e5894d39f,
0x74411afa975da242,
0x74ce38135a2fbf20,
0x9cc8eb1cc1a99cfa,
0x4e422fc5defc941d,
0x8ffc4bffef02cc07,
0xf79788c5ad05368f,
0xb69b3f6793e584db,
0xa7a31fb34f2ff516,
0xba93dd63f5f2f8bd,
0x9e839cfbc5294975,
0x35fdafd88fc6ae84,
0x2b0198237e3db5d5,
}
// Use the exponent to extract the 3 appropriate uint64 digits from mPi,
// B ~ (z0, z1, z2), such that the product leading digit has the exponent -64.
// Note, exp >= 50 since x >= reduceThreshold and exp < 971 for maximum float64.
digit, bitshift := uint(exp+64)/64, uint(exp+64)%64
z0 := (mPi[digit] << bitshift) | (mPi[digit+1] >> (64 - bitshift))
z1 := (mPi[digit+1] << bitshift) | (mPi[digit+2] >> (64 - bitshift))
z2 := (mPi[digit+2] << bitshift) | (mPi[digit+3] >> (64 - bitshift))
// Multiply mantissa by the digits and extract the upper two digits (hi, lo).
z2hi, _ := bits.Mul64(z2, ix)
z1hi, z1lo := bits.Mul64(z1, ix)
z0lo := z0 * ix
lo, c := bits.Add64(z1lo, z2hi, 0)
hi, _ := bits.Add64(z0lo, z1hi, c)
// Find the magnitude of the fraction.
lz := uint(bits.LeadingZeros64(hi))
e := uint64(bias - (lz + 1))
// Clear implicit mantissa bit and shift into place.
hi = (hi << (lz + 1)) | (lo >> (64 - (lz + 1)))
hi >>= 64 - shift
// Include the exponent and convert to a float.
hi |= e << shift
x = math.Float64frombits(hi)
// map to (-Pi/2, Pi/2]
if x > 0.5 {
x--
}
return math.Pi * x
}
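// In effect, for x >= 0 reducePi computes the same value as this naive
// one-liner (a hypothetical sketch), minus the precision loss for large x
// that the multi-part constants and the Payne-Hanek path repair:
func reducePiNaive(x float64) float64 {
	return x - math.Round(x/math.Pi)*math.Pi // in (-Pi/2, Pi/2], up to boundary ties
}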
// Taylor series expansion for cosh(2y) - cos(2x)
func tanSeries(z complex128) float64 {
const MACHEP = 1.0 / (1 << 53)
x := math.Abs(2 * real(z))
y := math.Abs(2 * imag(z))
x = reducePi(x)
x = x * x
y = y * y
x2 := 1.0
y2 := 1.0
f := 1.0
rn := 0.0
d := 0.0
for {
rn++
f *= rn
rn++
f *= rn
x2 *= x
y2 *= y
t := y2 + x2
t /= f
d += t
rn++
f *= rn
rn++
f *= rn
x2 *= x
y2 *= y
t = y2 - x2
t /= f
d += t
if !(math.Abs(t/d) > MACHEP) {
// Caution: Use ! and > instead of <= for correct behavior if t/d is NaN.
// See issue 17577.
break
}
}
return d
}
// Complex circular cotangent
//
// DESCRIPTION:
//
// If
// z = x + iy,
//
// then
//
// w = (sin 2x - i sinh 2y) / (cosh 2y - cos 2x).
//
// On the real axis, the denominator has zeros at even
// multiples of PI/2. Near these points it is evaluated
// by a Taylor series.
//
// ACCURACY:
//
// Relative error:
// arithmetic domain # trials peak rms
// DEC -10,+10 3000 6.5e-17 1.6e-17
// IEEE -10,+10 30000 9.2e-16 1.2e-16
// Also tested by ctan * ccot = 1 + i0.
// Cot returns the cotangent of x.
func Cot(x complex128) complex128 {
d := math.Cosh(2*imag(x)) - math.Cos(2*real(x))
if math.Abs(d) < 0.25 {
d = tanSeries(x)
}
if d == 0 {
return Inf()
}
return complex(math.Sin(2*real(x))/d, -math.Sinh(2*imag(x))/d)
}
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
func sign[I int32 | int64](a, b I) int {
if a < b {
return -1
}
if a > b {
return 1
}
return 0
}
// Compare compares a and b such that
// -NaN is ordered before any other value,
// +NaN is ordered after any other value,
// and -0 is ordered before +0.
// In other words, it defines a total order over floats
// (according to the total-ordering predicate in IEEE-754, section 5.10).
// It returns 0 if a == b, -1 if a < b, and +1 if a > b.
func Compare(a, b float64) int {
// Perform a bitwise comparison (a < b) by reinterpreting the float64s as int64s.
x := int64(Float64bits(a))
y := int64(Float64bits(b))
// If a and b are both negative, flip the comparison so that we check a > b.
if x < 0 && y < 0 {
return sign(y, x)
}
return sign(x, y)
}
// Compare32 compares a and b such that
// -NaN is ordered before any other value,
// +NaN is ordered after any other value,
// and -0 is ordered before +0.
// In other words, it defines a total order over floats
// (according to the total-ordering predicate in IEEE-754, section 5.10).
// It returns 0 if a == b, -1 if a < b, and +1 if a > b.
func Compare32(a, b float32) int {
// Perform a bitwise comparison (a < b) by reinterpreting the float32s as int32s.
x := int32(Float32bits(a))
y := int32(Float32bits(b))
// If a and b are both negative, flip the comparison so that we check a > b.
if x < 0 && y < 0 {
return sign(y, x)
}
return sign(x, y)
}
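// A hypothetical usage sketch: the total order lets a sort place NaNs and
// signed zeros deterministically, which the < operator alone cannot do:
//
//	xs := []float64{math.NaN(), math.Copysign(0, -1), 0, -math.NaN(), math.Inf(-1)}
//	sort.Slice(xs, func(i, j int) bool { return math.Compare(xs[i], xs[j]) < 0 })
//	// sorted order: -NaN, -Inf, -0, +0, +NaN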
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
// Copysign returns a value with the magnitude of f
// and the sign of sign.
func Copysign(f, sign float64) float64 {
const signBit = 1 << 63
return Float64frombits(Float64bits(f)&^signBit | Float64bits(sign)&signBit)
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
// Dim returns the maximum of x-y or 0.
//
// Special cases are:
//
// Dim(+Inf, +Inf) = NaN
// Dim(-Inf, -Inf) = NaN
// Dim(x, NaN) = Dim(NaN, x) = NaN
func Dim(x, y float64) float64 {
// The special cases result in NaN after the subtraction:
// +Inf - +Inf = NaN
// -Inf - -Inf = NaN
// NaN - y = NaN
// x - NaN = NaN
v := x - y
if v <= 0 {
// v is negative or 0
return 0
}
// v is positive or NaN
return v
}
// Max returns the larger of x or y.
//
// Special cases are:
//
// Max(x, +Inf) = Max(+Inf, x) = +Inf
// Max(x, NaN) = Max(NaN, x) = NaN
// Max(+0, ±0) = Max(±0, +0) = +0
// Max(-0, -0) = -0
func Max(x, y float64) float64 {
if haveArchMax {
return archMax(x, y)
}
return max(x, y)
}
func max(x, y float64) float64 {
// special cases
switch {
case IsInf(x, 1) || IsInf(y, 1):
return Inf(1)
case IsNaN(x) || IsNaN(y):
return NaN()
case x == 0 && x == y:
if Signbit(x) {
return y
}
return x
}
if x > y {
return x
}
return y
}
// Min returns the smaller of x or y.
//
// Special cases are:
//
// Min(x, -Inf) = Min(-Inf, x) = -Inf
// Min(x, NaN) = Min(NaN, x) = NaN
// Min(-0, ±0) = Min(±0, -0) = -0
func Min(x, y float64) float64 {
if haveArchMin {
return archMin(x, y)
}
return min(x, y)
}
func min(x, y float64) float64 {
// special cases
switch {
case IsInf(x, -1) || IsInf(y, -1):
return Inf(-1)
case IsNaN(x) || IsNaN(y):
return NaN()
case x == 0 && x == y:
if Signbit(x) {
return x
}
return y
}
if x < y {
return x
}
return y
}
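// A quick sketch of the signed-zero handling above (hypothetical usage):
//
//	Max(Copysign(0, -1), 0) // +0: the x == 0 && x == y branch skips the negative zero
//	Min(Copysign(0, -1), 0) // -0: and Min keeps it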
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
/*
Floating-point error function and complementary error function.
*/
// The original C code and the long comment below are
// from FreeBSD's /usr/src/lib/msun/src/s_erf.c and
// came with this notice. The go code is a simplified
// version of the original C.
//
// ====================================================
// Copyright (C) 1993 by Sun Microsystems, Inc. All rights reserved.
//
// Developed at SunPro, a Sun Microsystems, Inc. business.
// Permission to use, copy, modify, and distribute this
// software is freely granted, provided that this notice
// is preserved.
// ====================================================
//
//
// double erf(double x)
// double erfc(double x)
// erf(x) = (2/sqrt(pi)) * integral from 0 to x of exp(-t*t) dt
//
// erfc(x) = 1-erf(x)
// Note that
// erf(-x) = -erf(x)
// erfc(-x) = 2 - erfc(x)
//
// Method:
// 1. For |x| in [0, 0.84375]
// erf(x) = x + x*R(x**2)
// erfc(x) = 1 - erf(x) if x in [-.84375,0.25]
// = 0.5 + ((0.5-x)-x*R) if x in [0.25,0.84375]
// where R = P/Q where P is an odd poly of degree 8 and
// Q is an odd poly of degree 10.
// | R - (erf(x)-x)/x | <= 2**-57.90
//
//
// Remark. The formula is derived by noting
// erf(x) = (2/sqrt(pi))*(x - x**3/3 + x**5/10 - x**7/42 + ....)
// and that
// 2/sqrt(pi) = 1.128379167095512573896158903121545171688
// is close to one. The interval is chosen because the fix
// point of erf(x) is near 0.6174 (i.e., erf(x)=x when x is
// near 0.6174), and by some experiment, 0.84375 is chosen to
// guarantee the error is less than one ulp for erf.
//
// 2. For |x| in [0.84375,1.25], let s = |x| - 1, and
// c = 0.84506291151 rounded to single (24 bits)
// erf(x) = sign(x) * (c + P1(s)/Q1(s))
// erfc(x) = (1-c) - P1(s)/Q1(s) if x > 0
// = 1 + (c+P1(s)/Q1(s)) if x < 0
// |P1/Q1 - (erf(|x|)-c)| <= 2**-59.06
// Remark: here we use the Taylor series expansion at x=1.
// erf(1+s) = erf(1) + s*Poly(s)
// = 0.845.. + P1(s)/Q1(s)
// That is, we use rational approximation to approximate
// erf(1+s) - (c = (single)0.84506291151)
// Note that |P1/Q1|< 0.078 for x in [0.84375,1.25]
// where
// P1(s) = degree 6 poly in s
// Q1(s) = degree 6 poly in s
//
// 3. For x in [1.25,1/0.35(~2.857143)],
// erfc(x) = (1/x)*exp(-x*x-0.5625+R1/S1)
// erf(x) = 1 - erfc(x)
// where
// R1(z) = degree 7 poly in z, (z=1/x**2)
// S1(z) = degree 8 poly in z
//
// 4. For x in [1/0.35,28]
// erfc(x) = (1/x)*exp(-x*x-0.5625+R2/S2) if x > 0
// = 2.0 - (1/x)*exp(-x*x-0.5625+R2/S2) if -6<x<0
// = 2.0 - tiny (if x <= -6)
// erf(x) = sign(x)*(1.0 - erfc(x)) if x < 6, else
// erf(x) = sign(x)*(1.0 - tiny)
// where
// R2(z) = degree 6 poly in z, (z=1/x**2)
// S2(z) = degree 7 poly in z
//
// Note1:
// To compute exp(-x*x-0.5625+R/S), let s be a single
// precision number and s := x; then
// -x*x = -s*s + (s-x)*(s+x)
// exp(-x*x-0.5625+R/S) =
// exp(-s*s-0.5625)*exp((s-x)*(s+x)+R/S);
// Note2:
// Here 4 and 5 make use of the asymptotic series
// erfc(x) ~ (exp(-x*x) / (x*sqrt(pi))) * ( 1 + Poly(1/x**2) )
// We use rational approximation to approximate
// g(s)=f(1/x**2) = log(erfc(x)*x) - x*x + 0.5625
// Here is the error bound for R1/S1 and R2/S2
// |R1/S1 - f(x)| < 2**(-62.57)
// |R2/S2 - f(x)| < 2**(-61.52)
//
// 5. For inf > x >= 28
// erf(x) = sign(x) *(1 - tiny) (raise inexact)
// erfc(x) = tiny*tiny (raise underflow) if x > 0
// = 2 - tiny if x<0
//
// 6. Special case:
// erf(0) = 0, erf(inf) = 1, erf(-inf) = -1,
// erfc(0) = 1, erfc(inf) = 0, erfc(-inf) = 2,
// erfc/erf(NaN) is NaN
const (
erx = 8.45062911510467529297e-01 // 0x3FEB0AC160000000
// Coefficients for approximation to erf in [0, 0.84375]
efx = 1.28379167095512586316e-01 // 0x3FC06EBA8214DB69
efx8 = 1.02703333676410069053e+00 // 0x3FF06EBA8214DB69
pp0 = 1.28379167095512558561e-01 // 0x3FC06EBA8214DB68
pp1 = -3.25042107247001499370e-01 // 0xBFD4CD7D691CB913
pp2 = -2.84817495755985104766e-02 // 0xBF9D2A51DBD7194F
pp3 = -5.77027029648944159157e-03 // 0xBF77A291236668E4
pp4 = -2.37630166566501626084e-05 // 0xBEF8EAD6120016AC
qq1 = 3.97917223959155352819e-01 // 0x3FD97779CDDADC09
qq2 = 6.50222499887672944485e-02 // 0x3FB0A54C5536CEBA
qq3 = 5.08130628187576562776e-03 // 0x3F74D022C4D36B0F
qq4 = 1.32494738004321644526e-04 // 0x3F215DC9221C1A10
qq5 = -3.96022827877536812320e-06 // 0xBED09C4342A26120
// Coefficients for approximation to erf in [0.84375, 1.25]
pa0 = -2.36211856075265944077e-03 // 0xBF6359B8BEF77538
pa1 = 4.14856118683748331666e-01 // 0x3FDA8D00AD92B34D
pa2 = -3.72207876035701323847e-01 // 0xBFD7D240FBB8C3F1
pa3 = 3.18346619901161753674e-01 // 0x3FD45FCA805120E4
pa4 = -1.10894694282396677476e-01 // 0xBFBC63983D3E28EC
pa5 = 3.54783043256182359371e-02 // 0x3FA22A36599795EB
pa6 = -2.16637559486879084300e-03 // 0xBF61BF380A96073F
qa1 = 1.06420880400844228286e-01 // 0x3FBB3E6618EEE323
qa2 = 5.40397917702171048937e-01 // 0x3FE14AF092EB6F33
qa3 = 7.18286544141962662868e-02 // 0x3FB2635CD99FE9A7
qa4 = 1.26171219808761642112e-01 // 0x3FC02660E763351F
qa5 = 1.36370839120290507362e-02 // 0x3F8BEDC26B51DD1C
qa6 = 1.19844998467991074170e-02 // 0x3F888B545735151D
// Coefficients for approximation to erfc in [1.25, 1/0.35]
ra0 = -9.86494403484714822705e-03 // 0xBF843412600D6435
ra1 = -6.93858572707181764372e-01 // 0xBFE63416E4BA7360
ra2 = -1.05586262253232909814e+01 // 0xC0251E0441B0E726
ra3 = -6.23753324503260060396e+01 // 0xC04F300AE4CBA38D
ra4 = -1.62396669462573470355e+02 // 0xC0644CB184282266
ra5 = -1.84605092906711035994e+02 // 0xC067135CEBCCABB2
ra6 = -8.12874355063065934246e+01 // 0xC054526557E4D2F2
ra7 = -9.81432934416914548592e+00 // 0xC023A0EFC69AC25C
sa1 = 1.96512716674392571292e+01 // 0x4033A6B9BD707687
sa2 = 1.37657754143519042600e+02 // 0x4061350C526AE721
sa3 = 4.34565877475229228821e+02 // 0x407B290DD58A1A71
sa4 = 6.45387271733267880336e+02 // 0x40842B1921EC2868
sa5 = 4.29008140027567833386e+02 // 0x407AD02157700314
sa6 = 1.08635005541779435134e+02 // 0x405B28A3EE48AE2C
sa7 = 6.57024977031928170135e+00 // 0x401A47EF8E484A93
sa8 = -6.04244152148580987438e-02 // 0xBFAEEFF2EE749A62
// Coefficients for approximation to erfc in [1/.35, 28]
rb0 = -9.86494292470009928597e-03 // 0xBF84341239E86F4A
rb1 = -7.99283237680523006574e-01 // 0xBFE993BA70C285DE
rb2 = -1.77579549177547519889e+01 // 0xC031C209555F995A
rb3 = -1.60636384855821916062e+02 // 0xC064145D43C5ED98
rb4 = -6.37566443368389627722e+02 // 0xC083EC881375F228
rb5 = -1.02509513161107724954e+03 // 0xC09004616A2E5992
rb6 = -4.83519191608651397019e+02 // 0xC07E384E9BDC383F
sb1 = 3.03380607434824582924e+01 // 0x403E568B261D5190
sb2 = 3.25792512996573918826e+02 // 0x40745CAE221B9F0A
sb3 = 1.53672958608443695994e+03 // 0x409802EB189D5118
sb4 = 3.19985821950859553908e+03 // 0x40A8FFB7688C246A
sb5 = 2.55305040643316442583e+03 // 0x40A3F219CEDF3BE6
sb6 = 4.74528541206955367215e+02 // 0x407DA874E79FE763
sb7 = -2.24409524465858183362e+01 // 0xC03670E242712D62
)
// Erf returns the error function of x.
//
// Special cases are:
//
// Erf(+Inf) = 1
// Erf(-Inf) = -1
// Erf(NaN) = NaN
func Erf(x float64) float64 {
if haveArchErf {
return archErf(x)
}
return erf(x)
}
func erf(x float64) float64 {
const (
VeryTiny = 2.848094538889218e-306 // 0x0080000000000000
Small = 1.0 / (1 << 28) // 2**-28
)
// special cases
switch {
case IsNaN(x):
return NaN()
case IsInf(x, 1):
return 1
case IsInf(x, -1):
return -1
}
sign := false
if x < 0 {
x = -x
sign = true
}
if x < 0.84375 { // |x| < 0.84375
var temp float64
if x < Small { // |x| < 2**-28
if x < VeryTiny {
temp = 0.125 * (8.0*x + efx8*x) // avoid underflow
} else {
temp = x + efx*x
}
} else {
z := x * x
r := pp0 + z*(pp1+z*(pp2+z*(pp3+z*pp4)))
s := 1 + z*(qq1+z*(qq2+z*(qq3+z*(qq4+z*qq5))))
y := r / s
temp = x + x*y
}
if sign {
return -temp
}
return temp
}
if x < 1.25 { // 0.84375 <= |x| < 1.25
s := x - 1
P := pa0 + s*(pa1+s*(pa2+s*(pa3+s*(pa4+s*(pa5+s*pa6)))))
Q := 1 + s*(qa1+s*(qa2+s*(qa3+s*(qa4+s*(qa5+s*qa6)))))
if sign {
return -erx - P/Q
}
return erx + P/Q
}
if x >= 6 { // inf > |x| >= 6
if sign {
return -1
}
return 1
}
s := 1 / (x * x)
var R, S float64
if x < 1/0.35 { // |x| < 1 / 0.35 ~ 2.857143
R = ra0 + s*(ra1+s*(ra2+s*(ra3+s*(ra4+s*(ra5+s*(ra6+s*ra7))))))
S = 1 + s*(sa1+s*(sa2+s*(sa3+s*(sa4+s*(sa5+s*(sa6+s*(sa7+s*sa8)))))))
} else { // |x| >= 1 / 0.35 ~ 2.857143
R = rb0 + s*(rb1+s*(rb2+s*(rb3+s*(rb4+s*(rb5+s*rb6)))))
S = 1 + s*(sb1+s*(sb2+s*(sb3+s*(sb4+s*(sb5+s*(sb6+s*sb7))))))
}
z := Float64frombits(Float64bits(x) & 0xffffffff00000000) // pseudo-single (20-bit) precision x
r := Exp(-z*z-0.5625) * Exp((z-x)*(z+x)+R/S)
if sign {
return r/x - 1
}
return 1 - r/x
}
// Erfc returns the complementary error function of x.
//
// Special cases are:
//
// Erfc(+Inf) = 0
// Erfc(-Inf) = 2
// Erfc(NaN) = NaN
func Erfc(x float64) float64 {
if haveArchErfc {
return archErfc(x)
}
return erfc(x)
}
func erfc(x float64) float64 {
const Tiny = 1.0 / (1 << 56) // 2**-56
// special cases
switch {
case IsNaN(x):
return NaN()
case IsInf(x, 1):
return 0
case IsInf(x, -1):
return 2
}
sign := false
if x < 0 {
x = -x
sign = true
}
if x < 0.84375 { // |x| < 0.84375
var temp float64
if x < Tiny { // |x| < 2**-56
temp = x
} else {
z := x * x
r := pp0 + z*(pp1+z*(pp2+z*(pp3+z*pp4)))
s := 1 + z*(qq1+z*(qq2+z*(qq3+z*(qq4+z*qq5))))
y := r / s
if x < 0.25 { // |x| < 1/4
temp = x + x*y
} else {
temp = 0.5 + (x*y + (x - 0.5))
}
}
if sign {
return 1 + temp
}
return 1 - temp
}
if x < 1.25 { // 0.84375 <= |x| < 1.25
s := x - 1
P := pa0 + s*(pa1+s*(pa2+s*(pa3+s*(pa4+s*(pa5+s*pa6)))))
Q := 1 + s*(qa1+s*(qa2+s*(qa3+s*(qa4+s*(qa5+s*qa6)))))
if sign {
return 1 + erx + P/Q
}
return 1 - erx - P/Q
}
if x < 28 { // |x| < 28
s := 1 / (x * x)
var R, S float64
if x < 1/0.35 { // |x| < 1 / 0.35 ~ 2.857143
R = ra0 + s*(ra1+s*(ra2+s*(ra3+s*(ra4+s*(ra5+s*(ra6+s*ra7))))))
S = 1 + s*(sa1+s*(sa2+s*(sa3+s*(sa4+s*(sa5+s*(sa6+s*(sa7+s*sa8)))))))
} else { // |x| >= 1 / 0.35 ~ 2.857143
if sign && x > 6 {
return 2 // x < -6
}
R = rb0 + s*(rb1+s*(rb2+s*(rb3+s*(rb4+s*(rb5+s*rb6)))))
S = 1 + s*(sb1+s*(sb2+s*(sb3+s*(sb4+s*(sb5+s*(sb6+s*sb7))))))
}
z := Float64frombits(Float64bits(x) & 0xffffffff00000000) // pseudo-single (20-bit) precision x
r := Exp(-z*z-0.5625) * Exp((z-x)*(z+x)+R/S)
if sign {
return 2 - r/x
}
return r / x
}
if sign {
return 2
}
return 0
}
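// A common application (a hypothetical sketch, not part of this file): the
// standard normal CDF expressed through Erfc, which stays accurate in the far
// negative tail where 1-Erf(x) would lose all precision to cancellation:
func normalCDF(x float64) float64 {
	return 0.5 * Erfc(-x/Sqrt2) // P(Z <= x) for Z ~ N(0, 1)
}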
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
/*
Inverse of the floating-point error function.
*/
// This implementation is based on the rational approximation
// of percentage points of normal distribution available from
// https://www.jstor.org/stable/2347330.
const (
// Coefficients for approximation to Erfinv in |x| <= 0.85
a0 = 1.1975323115670912564578e0
a1 = 4.7072688112383978012285e1
a2 = 6.9706266534389598238465e2
a3 = 4.8548868893843886794648e3
a4 = 1.6235862515167575384252e4
a5 = 2.3782041382114385731252e4
a6 = 1.1819493347062294404278e4
a7 = 8.8709406962545514830200e2
b0 = 1.0000000000000000000e0
b1 = 4.2313330701600911252e1
b2 = 6.8718700749205790830e2
b3 = 5.3941960214247511077e3
b4 = 2.1213794301586595867e4
b5 = 3.9307895800092710610e4
b6 = 2.8729085735721942674e4
b7 = 5.2264952788528545610e3
// Coefficients for approximation to Erfinv in 0.85 < |x| <= 1-2*exp(-25)
c0 = 1.42343711074968357734e0
c1 = 4.63033784615654529590e0
c2 = 5.76949722146069140550e0
c3 = 3.64784832476320460504e0
c4 = 1.27045825245236838258e0
c5 = 2.41780725177450611770e-1
c6 = 2.27238449892691845833e-2
c7 = 7.74545014278341407640e-4
d0 = 1.4142135623730950488016887e0
d1 = 2.9036514445419946173133295e0
d2 = 2.3707661626024532365971225e0
d3 = 9.7547832001787427186894837e-1
d4 = 2.0945065210512749128288442e-1
d5 = 2.1494160384252876777097297e-2
d6 = 7.7441459065157709165577218e-4
d7 = 1.4859850019840355905497876e-9
// Coefficients for approximation to Erfinv in 1-2*exp(-25) < |x| < 1
e0 = 6.65790464350110377720e0
e1 = 5.46378491116411436990e0
e2 = 1.78482653991729133580e0
e3 = 2.96560571828504891230e-1
e4 = 2.65321895265761230930e-2
e5 = 1.24266094738807843860e-3
e6 = 2.71155556874348757815e-5
e7 = 2.01033439929228813265e-7
f0 = 1.414213562373095048801689e0
f1 = 8.482908416595164588112026e-1
f2 = 1.936480946950659106176712e-1
f3 = 2.103693768272068968719679e-2
f4 = 1.112800997078859844711555e-3
f5 = 2.611088405080593625138020e-5
f6 = 2.010321207683943062279931e-7
f7 = 2.891024605872965461538222e-15
)
// Erfinv returns the inverse error function of x.
//
// Special cases are:
//
// Erfinv(1) = +Inf
// Erfinv(-1) = -Inf
// Erfinv(x) = NaN if x < -1 or x > 1
// Erfinv(NaN) = NaN
func Erfinv(x float64) float64 {
// special cases
if IsNaN(x) || x <= -1 || x >= 1 {
if x == -1 || x == 1 {
return Inf(int(x))
}
return NaN()
}
sign := false
if x < 0 {
x = -x
sign = true
}
var ans float64
if x <= 0.85 { // |x| <= 0.85
r := 0.180625 - 0.25*x*x
z1 := ((((((a7*r+a6)*r+a5)*r+a4)*r+a3)*r+a2)*r+a1)*r + a0
z2 := ((((((b7*r+b6)*r+b5)*r+b4)*r+b3)*r+b2)*r+b1)*r + b0
ans = (x * z1) / z2
} else {
var z1, z2 float64
r := Sqrt(Ln2 - Log(1.0-x))
if r <= 5.0 {
r -= 1.6
z1 = ((((((c7*r+c6)*r+c5)*r+c4)*r+c3)*r+c2)*r+c1)*r + c0
z2 = ((((((d7*r+d6)*r+d5)*r+d4)*r+d3)*r+d2)*r+d1)*r + d0
} else {
r -= 5.0
z1 = ((((((e7*r+e6)*r+e5)*r+e4)*r+e3)*r+e2)*r+e1)*r + e0
z2 = ((((((f7*r+f6)*r+f5)*r+f4)*r+f3)*r+f2)*r+f1)*r + f0
}
ans = z1 / z2
}
if sign {
return -ans
}
return ans
}
// Erfcinv returns the inverse of Erfc(x).
//
// Special cases are:
//
// Erfcinv(0) = +Inf
// Erfcinv(2) = -Inf
// Erfcinv(x) = NaN if x < 0 or x > 2
// Erfcinv(NaN) = NaN
func Erfcinv(x float64) float64 {
return Erfinv(1 - x)
}
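// A matching hypothetical sketch: the standard normal quantile function,
// inverting the Erfc-based CDF above:
func normalQuantile(p float64) float64 {
	return -Sqrt2 * Erfcinv(2 * p) // x with P(Z <= x) = p, for 0 < p < 1
}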
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
// Exp returns e**x, the base-e exponential of x.
//
// Special cases are:
//
// Exp(+Inf) = +Inf
// Exp(NaN) = NaN
//
// Very large values overflow to 0 or +Inf.
// Very small values underflow to 1.
func Exp(x float64) float64 {
if haveArchExp {
return archExp(x)
}
return exp(x)
}
// The original C code, the long comment, and the constants
// below are from FreeBSD's /usr/src/lib/msun/src/e_exp.c
// and came with this notice. The go code is a simplified
// version of the original C.
//
// ====================================================
// Copyright (C) 2004 by Sun Microsystems, Inc. All rights reserved.
//
// Permission to use, copy, modify, and distribute this
// software is freely granted, provided that this notice
// is preserved.
// ====================================================
//
//
// exp(x)
// Returns the exponential of x.
//
// Method
// 1. Argument reduction:
// Reduce x to an r so that |r| <= 0.5*ln2 ~ 0.34658.
// Given x, find r and integer k such that
//
// x = k*ln2 + r, |r| <= 0.5*ln2.
//
// Here r will be represented as r = hi-lo for better
// accuracy.
//
// 2. Approximation of exp(r) by a special rational function on
// the interval [0,0.34658]:
// Write
// R(r**2) = r*(exp(r)+1)/(exp(r)-1) = 2 + r*r/6 - r**4/360 + ...
// We use a special Remez algorithm on [0,0.34658] to generate
// a polynomial of degree 5 to approximate R. The maximum error
// of this polynomial approximation is bounded by 2**-59. In
// other words,
// R(z) ~ 2.0 + P1*z + P2*z**2 + P3*z**3 + P4*z**4 + P5*z**5
// (where z=r*r, and the values of P1 to P5 are listed below)
// and
// | 2.0 + P1*z + ... + P5*z**5 - R(z) | <= 2**-59
// The computation of exp(r) thus becomes
// exp(r) = 1 + 2*r/(R - r)
// = 1 + r + r*R1(r)/(2 - R1(r)) (for better accuracy)
// where
// R1(r) = r - (P1*r**2 + P2*r**4 + ... + P5*r**10).
//
// 3. Scale back to obtain exp(x):
// From step 1, we have
// exp(x) = 2**k * exp(r)
//
// Special cases:
// exp(INF) is INF, exp(NaN) is NaN;
// exp(-INF) is 0, and
// for finite argument, only exp(0)=1 is exact.
//
// Accuracy:
// according to an error analysis, the error is always less than
// 1 ulp (unit in the last place).
//
// Misc. info.
// For IEEE double
// if x > 7.09782712893383973096e+02 then exp(x) overflow
// if x < -7.45133219101941108420e+02 then exp(x) underflow
//
// Constants:
// The hexadecimal values are the intended ones for the following
// constants. The decimal values may be used, provided that the
// compiler will convert from decimal to binary accurately enough
// to produce the hexadecimal values shown.
func exp(x float64) float64 {
const (
Ln2Hi = 6.93147180369123816490e-01
Ln2Lo = 1.90821492927058770002e-10
Log2e = 1.44269504088896338700e+00
Overflow = 7.09782712893383973096e+02
Underflow = -7.45133219101941108420e+02
NearZero = 1.0 / (1 << 28) // 2**-28
)
// special cases
switch {
case IsNaN(x) || IsInf(x, 1):
return x
case IsInf(x, -1):
return 0
case x > Overflow:
return Inf(1)
case x < Underflow:
return 0
case -NearZero < x && x < NearZero:
return 1 + x
}
// reduce; computed as r = hi - lo for extra precision.
var k int
switch {
case x < 0:
k = int(Log2e*x - 0.5)
case x > 0:
k = int(Log2e*x + 0.5)
}
hi := x - float64(k)*Ln2Hi
lo := float64(k) * Ln2Lo
// compute
return expmulti(hi, lo, k)
}
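// The reduction above implements exp(x) = 2**k * exp(r) with r = x - k*ln2
// computed as hi-lo so that no bits of k*ln2 are lost. A naive version
// (hypothetical sketch) shows the identity without the careful constant split:
func expViaReduction(x float64) float64 {
	k := Round(x / Ln2)
	r := x - k*Ln2 // rounds k*Ln2, losing low bits for large |k|
	return Ldexp(Exp(r), int(k))
}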
// Exp2 returns 2**x, the base-2 exponential of x.
//
// Special cases are the same as Exp.
func Exp2(x float64) float64 {
if haveArchExp2 {
return archExp2(x)
}
return exp2(x)
}
func exp2(x float64) float64 {
const (
Ln2Hi = 6.93147180369123816490e-01
Ln2Lo = 1.90821492927058770002e-10
Overflow = 1.0239999999999999e+03
Underflow = -1.0740e+03
)
// special cases
switch {
case IsNaN(x) || IsInf(x, 1):
return x
case IsInf(x, -1):
return 0
case x > Overflow:
return Inf(1)
case x < Underflow:
return 0
}
// argument reduction; x = r×lg(e) + k with |r| ≤ ln(2)/2.
// computed as r = hi - lo for extra precision.
var k int
switch {
case x > 0:
k = int(x + 0.5)
case x < 0:
k = int(x - 0.5)
}
t := x - float64(k)
hi := t * Ln2Hi
lo := -t * Ln2Lo
// compute
return expmulti(hi, lo, k)
}
// expmulti returns e**r × 2**k where r = hi - lo and |r| ≤ ln(2)/2.
func expmulti(hi, lo float64, k int) float64 {
const (
P1 = 1.66666666666666657415e-01 /* 0x3FC55555; 0x55555555 */
P2 = -2.77777777770155933842e-03 /* 0xBF66C16C; 0x16BEBD93 */
P3 = 6.61375632143793436117e-05 /* 0x3F11566A; 0xAF25DE2C */
P4 = -1.65339022054652515390e-06 /* 0xBEBBBD41; 0xC5D26BF1 */
P5 = 4.13813679705723846039e-08 /* 0x3E663769; 0x72BEA4D0 */
)
r := hi - lo
t := r * r
c := r - t*(P1+t*(P2+t*(P3+t*(P4+t*P5))))
y := 1 - ((lo - (r*c)/(2-c)) - hi)
// TODO(rsc): make sure Ldexp can handle boundary k
return Ldexp(y, k)
}
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !arm64
package math
const haveArchExp2 = false
func archExp2(x float64) float64 {
panic("not implemented")
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
// The original C code, the long comment, and the constants
// below are from FreeBSD's /usr/src/lib/msun/src/s_expm1.c
// and came with this notice. The go code is a simplified
// version of the original C.
//
// ====================================================
// Copyright (C) 1993 by Sun Microsystems, Inc. All rights reserved.
//
// Developed at SunPro, a Sun Microsystems, Inc. business.
// Permission to use, copy, modify, and distribute this
// software is freely granted, provided that this notice
// is preserved.
// ====================================================
//
// expm1(x)
// Returns exp(x)-1, the exponential of x minus 1.
//
// Method
// 1. Argument reduction:
// Given x, find r and integer k such that
//
// x = k*ln2 + r, |r| <= 0.5*ln2 ~ 0.34658
//
// Here a correction term c will be computed to compensate
// the error in r when rounded to a floating-point number.
//
// 2. Approximating expm1(r) by a special rational function on
// the interval [0,0.34658]:
// Since
// r*(exp(r)+1)/(exp(r)-1) = 2+ r**2/6 - r**4/360 + ...
// we define R1(r*r) by
// r*(exp(r)+1)/(exp(r)-1) = 2+ r**2/6 * R1(r*r)
// That is,
// R1(r**2) = 6/r *((exp(r)+1)/(exp(r)-1) - 2/r)
// = 6/r * ( 1 + 2.0*(1/(exp(r)-1) - 1/r))
// = 1 - r**2/60 + r**4/2520 - r**6/100800 + ...
// We use a special Remez algorithm on [0,0.347] to generate
// a polynomial of degree 5 in r*r to approximate R1. The
// maximum error of this polynomial approximation is bounded
// by 2**-61. In other words,
// R1(z) ~ 1.0 + Q1*z + Q2*z**2 + Q3*z**3 + Q4*z**4 + Q5*z**5
// where Q1 = -1.6666666666666567384E-2,
// Q2 = 3.9682539681370365873E-4,
// Q3 = -9.9206344733435987357E-6,
// Q4 = 2.5051361420808517002E-7,
// Q5 = -6.2843505682382617102E-9;
// (where z=r*r, and the values of Q1 to Q5 are listed below)
// with error bounded by
// | 1.0 + Q1*z + ... + Q5*z**5 - R1(z) | <= 2**-61
//
// expm1(r) = exp(r)-1 is then computed in the following
// specific way, which minimizes the accumulated rounding error:
//
// expm1(r) = r + r**2/2 + (r**3/2)*(3 - (R1 + R1*r/2))/(6 - r*(3 - R1*r/2))
//
// To compensate the error in the argument reduction, we use
// expm1(r+c) = expm1(r) + c + expm1(r)*c
// ~ expm1(r) + c + r*c
// Thus c+r*c will be added in as the correction terms for
// expm1(r+c). Now rearrange the terms so that compiler optimization
// cannot reassociate the computation:
//
// expm1(r+c) ~ r - ({r*((r**2/2)*(R1 - (3 - R1*r/2))/(6 - r*(3 - R1*r/2)) - c) - c} - r**2/2)
// = r - E
// 3. Scale back to obtain expm1(x):
// From step 1, we have
// expm1(x) = either 2**k*[expm1(r)+1] - 1
// = or 2**k*[expm1(r) + (1-2**-k)]
// 4. Implementation notes:
// (A). To save one multiplication, we scale the coefficient Qi
// to Qi*2**i, and replace z by (x**2)/2.
// (B). To achieve maximum accuracy, we compute expm1(x) by
// (i) if x < -56*ln2, return -1.0, (raise inexact if x!=inf)
// (ii) if k=0, return r-E
// (iii) if k=-1, return 0.5*(r-E)-0.5
// (iv) if k=1 if r < -0.25, return 2*((r+0.5)- E)
// else return 1.0+2.0*(r-E);
// (v) if (k<-2||k>56) return 2**k(1-(E-r)) - 1 (or exp(x)-1)
// (vi) if k <= 20, return 2**k((1-2**-k)-(E-r)), else
// (vii) return 2**k(1-((E+2**-k)-r))
//
// Special cases:
// expm1(INF) is INF, expm1(NaN) is NaN;
// expm1(-INF) is -1, and
// for finite argument, only expm1(0)=0 is exact.
//
// Accuracy:
// according to an error analysis, the error is always less than
// 1 ulp (unit in the last place).
//
// Misc. info.
// For IEEE double
// if x > 7.09782712893383973096e+02 then expm1(x) overflow
//
// Constants:
// The hexadecimal values are the intended ones for the following
// constants. The decimal values may be used, provided that the
// compiler will convert from decimal to binary accurately enough
// to produce the hexadecimal values shown.
//
// Expm1 returns e**x - 1, the base-e exponential of x minus 1.
// It is more accurate than Exp(x) - 1 when x is near zero.
//
// Special cases are:
//
// Expm1(+Inf) = +Inf
// Expm1(-Inf) = -1
// Expm1(NaN) = NaN
//
// Very large values overflow to -1 or +Inf.
func Expm1(x float64) float64 {
if haveArchExpm1 {
return archExpm1(x)
}
return expm1(x)
}
func expm1(x float64) float64 {
const (
Othreshold = 7.09782712893383973096e+02 // 0x40862E42FEFA39EF
Ln2X56 = 3.88162421113569373274e+01 // 0x4043687a9f1af2b1
Ln2HalfX3 = 1.03972077083991796413e+00 // 0x3ff0a2b23f3bab73
Ln2Half = 3.46573590279972654709e-01 // 0x3fd62e42fefa39ef
Ln2Hi = 6.93147180369123816490e-01 // 0x3fe62e42fee00000
Ln2Lo = 1.90821492927058770002e-10 // 0x3dea39ef35793c76
InvLn2 = 1.44269504088896338700e+00 // 0x3ff71547652b82fe
Tiny = 1.0 / (1 << 54) // 2**-54 = 0x3c90000000000000
// scaled coefficients related to expm1
Q1 = -3.33333333333331316428e-02 // 0xBFA11111111110F4
Q2 = 1.58730158725481460165e-03 // 0x3F5A01A019FE5585
Q3 = -7.93650757867487942473e-05 // 0xBF14CE199EAADBB7
Q4 = 4.00821782732936239552e-06 // 0x3ED0CFCA86E65239
Q5 = -2.01099218183624371326e-07 // 0xBE8AFDB76E09C32D
)
// special cases
switch {
case IsInf(x, 1) || IsNaN(x):
return x
case IsInf(x, -1):
return -1
}
absx := x
sign := false
if x < 0 {
absx = -absx
sign = true
}
// filter out huge argument
if absx >= Ln2X56 { // if |x| >= 56 * ln2
if sign {
return -1 // x < -56*ln2, return -1
}
if absx >= Othreshold { // if |x| >= 709.78...
return Inf(1)
}
}
// argument reduction
var c float64
var k int
if absx > Ln2Half { // if |x| > 0.5 * ln2
var hi, lo float64
if absx < Ln2HalfX3 { // and |x| < 1.5 * ln2
if !sign {
hi = x - Ln2Hi
lo = Ln2Lo
k = 1
} else {
hi = x + Ln2Hi
lo = -Ln2Lo
k = -1
}
} else {
if !sign {
k = int(InvLn2*x + 0.5)
} else {
k = int(InvLn2*x - 0.5)
}
t := float64(k)
hi = x - t*Ln2Hi // t * Ln2Hi is exact here
lo = t * Ln2Lo
}
x = hi - lo
c = (hi - x) - lo
} else if absx < Tiny { // when |x| < 2**-54, return x
return x
} else {
k = 0
}
// x is now in primary range
hfx := 0.5 * x
hxs := x * hfx
r1 := 1 + hxs*(Q1+hxs*(Q2+hxs*(Q3+hxs*(Q4+hxs*Q5))))
t := 3 - r1*hfx
e := hxs * ((r1 - t) / (6.0 - x*t))
if k == 0 {
return x - (x*e - hxs) // c is 0
}
e = (x*(e-c) - c)
e -= hxs
switch {
case k == -1:
return 0.5*(x-e) - 0.5
case k == 1:
if x < -0.25 {
return -2 * (e - (x + 0.5))
}
return 1 + 2*(x-e)
case k <= -2 || k > 56: // it suffices to return exp(x)-1
y := 1 - (e - x)
y = Float64frombits(Float64bits(y) + uint64(k)<<52) // add k to y's exponent
return y - 1
}
if k < 20 {
t := Float64frombits(0x3ff0000000000000 - (0x20000000000000 >> uint(k))) // t=1-2**-k
y := t - (e - x)
y = Float64frombits(Float64bits(y) + uint64(k)<<52) // add k to y's exponent
return y
}
t = Float64frombits(uint64(0x3ff-k) << 52) // 2**-k
y := x - (e + t)
y++
y = Float64frombits(Float64bits(y) + uint64(k)<<52) // add k to y's exponent
return y
}
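// A sketch of why Expm1 matters (hypothetical usage): near zero, Exp(x)-1
// cancels catastrophically while Expm1 keeps full precision:
//
//	x := 1e-12
//	Exp(x) - 1 // loses roughly four significant digits to cancellation near 1
//	Expm1(x)   // ≈ x + x*x/2, accurate to full precision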
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
// Floor returns the greatest integer value less than or equal to x.
//
// Special cases are:
//
// Floor(±0) = ±0
// Floor(±Inf) = ±Inf
// Floor(NaN) = NaN
func Floor(x float64) float64 {
if haveArchFloor {
return archFloor(x)
}
return floor(x)
}
func floor(x float64) float64 {
if x == 0 || IsNaN(x) || IsInf(x, 0) {
return x
}
if x < 0 {
d, fract := Modf(-x)
if fract != 0.0 {
d = d + 1
}
return -d
}
d, _ := Modf(x)
return d
}
// Ceil returns the least integer value greater than or equal to x.
//
// Special cases are:
//
// Ceil(±0) = ±0
// Ceil(±Inf) = ±Inf
// Ceil(NaN) = NaN
func Ceil(x float64) float64 {
if haveArchCeil {
return archCeil(x)
}
return ceil(x)
}
func ceil(x float64) float64 {
return -Floor(-x)
}
// Trunc returns the integer value of x.
//
// Special cases are:
//
// Trunc(±0) = ±0
// Trunc(±Inf) = ±Inf
// Trunc(NaN) = NaN
func Trunc(x float64) float64 {
if haveArchTrunc {
return archTrunc(x)
}
return trunc(x)
}
func trunc(x float64) float64 {
if x == 0 || IsNaN(x) || IsInf(x, 0) {
return x
}
d, _ := Modf(x)
return d
}
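// A quick comparison of the three rounding directions on a negative argument
// (a sketch, assuming the exported math API):
//
//	fmt.Println(math.Floor(-1.5)) // -2 (toward -Inf)
//	fmt.Println(math.Ceil(-1.5))  // -1 (toward +Inf)
//	fmt.Println(math.Trunc(-1.5)) // -1 (toward zero)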
// Round returns the nearest integer, rounding half away from zero.
//
// Special cases are:
//
// Round(±0) = ±0
// Round(±Inf) = ±Inf
// Round(NaN) = NaN
func Round(x float64) float64 {
// Round is a faster implementation of:
//
// func Round(x float64) float64 {
// t := Trunc(x)
// if Abs(x-t) >= 0.5 {
// return t + Copysign(1, x)
// }
// return t
// }
bits := Float64bits(x)
e := uint(bits>>shift) & mask
if e < bias {
// Round abs(x) < 1 including denormals.
bits &= signMask // +-0
if e == bias-1 {
bits |= uvone // +-1
}
} else if e < bias+shift {
// Round any abs(x) >= 1 containing a fractional component [0,1).
//
// Numbers with larger exponents are returned unchanged since they
// must be either an integer, infinity, or NaN.
const half = 1 << (shift - 1)
e -= bias
bits += half >> e
bits &^= fracMask >> e
}
return Float64frombits(bits)
}
// RoundToEven returns the nearest integer, rounding ties to even.
//
// Special cases are:
//
// RoundToEven(±0) = ±0
// RoundToEven(±Inf) = ±Inf
// RoundToEven(NaN) = NaN
func RoundToEven(x float64) float64 {
// RoundToEven is a faster implementation of:
//
// func RoundToEven(x float64) float64 {
// t := math.Trunc(x)
// odd := math.Remainder(t, 2) != 0
// if d := math.Abs(x - t); d > 0.5 || (d == 0.5 && odd) {
// return t + math.Copysign(1, x)
// }
// return t
// }
bits := Float64bits(x)
e := uint(bits>>shift) & mask
if e >= bias {
// Round abs(x) >= 1.
// - Large numbers without fractional components, infinity, and NaN are unchanged.
// - Add 0.499.. or 0.5 before truncating depending on whether the truncated
// number is even or odd (respectively).
const halfMinusULP = (1 << (shift - 1)) - 1
e -= bias
bits += (halfMinusULP + (bits>>(shift-e))&1) >> e
bits &^= fracMask >> e
} else if e == bias-1 && bits&fracMask != 0 {
// Round 0.5 < abs(x) < 1.
bits = bits&signMask | uvone // +-1
} else {
// Round abs(x) <= 0.5 including denormals.
bits &= signMask // +-0
}
return Float64frombits(bits)
}
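// The tie-breaking difference between the two rounders, as a sketch
// (assuming the exported math API):
//
//	fmt.Println(math.Round(2.5))       // 3: halves round away from zero
//	fmt.Println(math.Round(-2.5))      // -3
//	fmt.Println(math.RoundToEven(2.5)) // 2: halves round to the even neighbor
//	fmt.Println(math.RoundToEven(3.5)) // 4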
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
import "math/bits"
func zero(x uint64) uint64 {
if x == 0 {
return 1
}
return 0
// branchless:
// return ((x>>1 | x&1) - 1) >> 63
}
func nonzero(x uint64) uint64 {
if x != 0 {
return 1
}
return 0
// branchless:
// return 1 - ((x>>1|x&1)-1)>>63
}
func shl(u1, u2 uint64, n uint) (r1, r2 uint64) {
r1 = u1<<n | u2>>(64-n) | u2<<(n-64)
r2 = u2 << n
return
}
func shr(u1, u2 uint64, n uint) (r1, r2 uint64) {
r2 = u2>>n | u1<<(64-n) | u1>>(n-64)
r1 = u1 >> n
return
}
// shrcompress compresses the bottom n+1 bits of the two-word
// value into a single bit. The result is equal to the value
// shifted to the right by n, except the result's 0th bit is
// set to the bitwise OR of the bottom n+1 bits.
func shrcompress(u1, u2 uint64, n uint) (r1, r2 uint64) {
// TODO: Performance here is really sensitive to the
// order/placement of these branches. n == 0 is common
// enough to be in the fast path. Perhaps more measurement
// needs to be done to find the optimal order/placement?
switch {
case n == 0:
return u1, u2
case n == 64:
return 0, u1 | nonzero(u2)
case n >= 128:
return 0, nonzero(u1 | u2)
case n < 64:
r1, r2 = shr(u1, u2, n)
r2 |= nonzero(u2 & (1<<n - 1))
case n < 128:
r1, r2 = shr(u1, u2, n)
r2 |= nonzero(u1&(1<<(n-64)-1) | u2)
}
return
}
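// For example (an in-package sketch, since shrcompress is unexported):
//
//	_, r2 := shrcompress(0, 0b1001, 2)
//	// 0b1001>>2 = 0b10, and the shifted-out bits 0b01 are nonzero,
//	// so bit 0 is forced on: r2 == 0b11 (the "sticky" bit for rounding).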
func lz(u1, u2 uint64) (l int32) {
l = int32(bits.LeadingZeros64(u1))
if l == 64 {
l += int32(bits.LeadingZeros64(u2))
}
return l
}
// split splits b into sign, biased exponent, and mantissa.
// It adds the implicit 1 bit to the mantissa for normal values,
// and normalizes subnormal values.
func split(b uint64) (sign uint32, exp int32, mantissa uint64) {
sign = uint32(b >> 63)
exp = int32(b>>52) & mask
mantissa = b & fracMask
if exp == 0 {
// Normalize value if subnormal.
shift := uint(bits.LeadingZeros64(mantissa) - 11)
mantissa <<= shift
exp = 1 - int32(shift)
} else {
// Add implicit 1 bit
mantissa |= 1 << 52
}
return
}
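// For example (an in-package sketch, since split is unexported):
//
//	s, e, m := split(Float64bits(-1.5))
//	// s == 1, e == 1023 (biased exponent for 2**0),
//	// m == 3<<51, i.e. 1.5 scaled to 52 fraction bits plus the implicit 1.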
// FMA returns x * y + z, computed with only one rounding.
// (That is, FMA returns the fused multiply-add of x, y, and z.)
func FMA(x, y, z float64) float64 {
bx, by, bz := Float64bits(x), Float64bits(y), Float64bits(z)
// Inf or NaN or zero involved. At most one rounding will occur.
if x == 0.0 || y == 0.0 || z == 0.0 || bx&uvinf == uvinf || by&uvinf == uvinf {
return x*y + z
}
// Handle non-finite z separately. Evaluating x*y+z where
// x and y are finite, but z is infinite, should always result in z.
if bz&uvinf == uvinf {
return z
}
// Inputs are (sub)normal.
// Split x, y, z into sign, exponent, mantissa.
xs, xe, xm := split(bx)
ys, ye, ym := split(by)
zs, ze, zm := split(bz)
// Compute product p = x*y as sign, exponent, two-word mantissa.
// Start with exponent. "is normal" bit isn't subtracted yet.
pe := xe + ye - bias + 1
// pm1:pm2 is the double-word mantissa for the product p.
// Shift left to leave top bit in product. Effectively
// shifts the 106-bit product to the left by 21.
pm1, pm2 := bits.Mul64(xm<<10, ym<<11)
zm1, zm2 := zm<<10, uint64(0)
ps := xs ^ ys // product sign
// normalize to 62nd bit
is62zero := uint((^pm1 >> 62) & 1)
pm1, pm2 = shl(pm1, pm2, is62zero)
pe -= int32(is62zero)
// Swap addition operands so |p| >= |z|
if pe < ze || pe == ze && pm1 < zm1 {
ps, pe, pm1, pm2, zs, ze, zm1, zm2 = zs, ze, zm1, zm2, ps, pe, pm1, pm2
}
// Align significands
zm1, zm2 = shrcompress(zm1, zm2, uint(pe-ze))
// Compute resulting significands, normalizing if necessary.
var m, c uint64
if ps == zs {
// Adding (pm1:pm2) + (zm1:zm2)
pm2, c = bits.Add64(pm2, zm2, 0)
pm1, _ = bits.Add64(pm1, zm1, c)
pe -= int32(^pm1 >> 63)
pm1, m = shrcompress(pm1, pm2, uint(64+pm1>>63))
} else {
// Subtracting (pm1:pm2) - (zm1:zm2)
// TODO: should we special-case cancellation?
pm2, c = bits.Sub64(pm2, zm2, 0)
pm1, _ = bits.Sub64(pm1, zm1, c)
nz := lz(pm1, pm2)
pe -= nz
m, pm2 = shl(pm1, pm2, uint(nz-1))
m |= nonzero(pm2)
}
// Round and break ties to even
if pe > 1022+bias || pe == 1022+bias && (m+1<<9)>>63 == 1 {
// rounded value overflows exponent range
return Float64frombits(uint64(ps)<<63 | uvinf)
}
if pe < 0 {
n := uint(-pe)
m = m>>n | nonzero(m&(1<<n-1))
pe = 0
}
m = ((m + 1<<9) >> 10) & ^zero((m&(1<<10-1))^1<<9)
pe &= -int32(nonzero(m))
return Float64frombits(uint64(ps)<<63 + uint64(pe)<<52 + m)
}
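// The single rounding is easy to observe (a sketch assuming the exported
// math.FMA; the variables prevent exact constant folding):
//
//	x, y, z := 0.1, 10.0, -1.0
//	fmt.Println(x*y + z)           // 0: x*y rounds to 1 before the add
//	fmt.Println(math.FMA(x, y, z)) // ~5.55e-17: the exact product survives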
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
// Frexp breaks f into a normalized fraction
// and an integral power of two.
// It returns frac and exp satisfying f == frac × 2**exp,
// with the absolute value of frac in the interval [½, 1).
//
// Special cases are:
//
// Frexp(±0) = ±0, 0
// Frexp(±Inf) = ±Inf, 0
// Frexp(NaN) = NaN, 0
func Frexp(f float64) (frac float64, exp int) {
if haveArchFrexp {
return archFrexp(f)
}
return frexp(f)
}
func frexp(f float64) (frac float64, exp int) {
// special cases
switch {
case f == 0:
return f, 0 // correctly return -0
case IsInf(f, 0) || IsNaN(f):
return f, 0
}
f, exp = normalize(f)
x := Float64bits(f)
exp += int((x>>shift)&mask) - bias + 1
x &^= mask << shift
x |= (-1 + bias) << shift
frac = Float64frombits(x)
return
}
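// For example (assuming the exported math API):
//
//	frac, exp := math.Frexp(12.5)
//	fmt.Println(frac, exp) // 0.78125 4, since 12.5 == 0.78125 * 2**4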
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
// The original C code, the long comment, and the constants
// below are from http://netlib.sandia.gov/cephes/cprob/gamma.c.
// The go code is a simplified version of the original C.
//
// tgamma.c
//
// Gamma function
//
// SYNOPSIS:
//
// double x, y, tgamma();
// extern int signgam;
//
// y = tgamma( x );
//
// DESCRIPTION:
//
// Returns gamma function of the argument. The result is
// correctly signed, and the sign (+1 or -1) is also
// returned in a global (extern) variable named signgam.
// This variable is also filled in by the logarithmic gamma
// function lgamma().
//
// Arguments |x| <= 34 are reduced by recurrence and the function
// approximated by a rational function of degree 6/7 in the
// interval (2,3). Large arguments are handled by Stirling's
// formula. Large negative arguments are made positive using
// a reflection formula.
//
// ACCURACY:
//
// Relative error:
// arithmetic domain # trials peak rms
// DEC -34, 34 10000 1.3e-16 2.5e-17
// IEEE -170,-33 20000 2.3e-15 3.3e-16
// IEEE -33, 33 20000 9.4e-16 2.2e-16
// IEEE 33, 171.6 20000 2.3e-15 3.2e-16
//
// Error for arguments outside the test range will be larger
// owing to error amplification by the exponential function.
//
// Cephes Math Library Release 2.8: June, 2000
// Copyright 1984, 1987, 1989, 1992, 2000 by Stephen L. Moshier
//
// The readme file at http://netlib.sandia.gov/cephes/ says:
// Some software in this archive may be from the book _Methods and
// Programs for Mathematical Functions_ (Prentice-Hall or Simon & Schuster
// International, 1989) or from the Cephes Mathematical Library, a
// commercial product. In either event, it is copyrighted by the author.
// What you see here may be used freely but it comes with no support or
// guarantee.
//
// The two known misprints in the book are repaired here in the
// source listings for the gamma function and the incomplete beta
// integral.
//
// Stephen L. Moshier
// moshier@na-net.ornl.gov
var _gamP = [...]float64{
1.60119522476751861407e-04,
1.19135147006586384913e-03,
1.04213797561761569935e-02,
4.76367800457137231464e-02,
2.07448227648435975150e-01,
4.94214826801497100753e-01,
9.99999999999999996796e-01,
}
var _gamQ = [...]float64{
-2.31581873324120129819e-05,
5.39605580493303397842e-04,
-4.45641913851797240494e-03,
1.18139785222060435552e-02,
3.58236398605498653373e-02,
-2.34591795718243348568e-01,
7.14304917030273074085e-02,
1.00000000000000000320e+00,
}
var _gamS = [...]float64{
7.87311395793093628397e-04,
-2.29549961613378126380e-04,
-2.68132617805781232825e-03,
3.47222221605458667310e-03,
8.33333333333482257126e-02,
}
// Gamma function computed by Stirling's formula.
// The pair of results must be multiplied together to get the actual answer.
// The multiplication is left to the caller so that, if careful, the caller can avoid
// infinity for 172 <= x <= 180.
// The polynomial is valid for 33 <= x <= 172; larger values are only used
// in reciprocal and produce denormalized floats. The lower precision there
// masks any imprecision in the polynomial.
func stirling(x float64) (float64, float64) {
if x > 200 {
return Inf(1), 1
}
const (
SqrtTwoPi = 2.506628274631000502417
MaxStirling = 143.01608
)
w := 1 / x
w = 1 + w*((((_gamS[0]*w+_gamS[1])*w+_gamS[2])*w+_gamS[3])*w+_gamS[4])
y1 := Exp(x)
y2 := 1.0
if x > MaxStirling { // avoid Pow() overflow
v := Pow(x, 0.5*x-0.25)
y1, y2 = v, v/y1
} else {
y1 = Pow(x, x-0.5) / y1
}
return y1, SqrtTwoPi * w * y2
}
// Gamma returns the Gamma function of x.
//
// Special cases are:
//
// Gamma(+Inf) = +Inf
// Gamma(+0) = +Inf
// Gamma(-0) = -Inf
// Gamma(x) = NaN for integer x < 0
// Gamma(-Inf) = NaN
// Gamma(NaN) = NaN
func Gamma(x float64) float64 {
const Euler = 0.57721566490153286060651209008240243104215933593992 // A001620
// special cases
switch {
case isNegInt(x) || IsInf(x, -1) || IsNaN(x):
return NaN()
case IsInf(x, 1):
return Inf(1)
case x == 0:
if Signbit(x) {
return Inf(-1)
}
return Inf(1)
}
q := Abs(x)
p := Floor(q)
if q > 33 {
if x >= 0 {
y1, y2 := stirling(x)
return y1 * y2
}
// Note: x is negative but (checked above) not a negative integer,
// so x must be small enough to be in range for conversion to int64.
// If |x| were >= 2⁶³ it would have to be an integer.
signgam := 1
if ip := int64(p); ip&1 == 0 {
signgam = -1
}
z := q - p
if z > 0.5 {
p = p + 1
z = q - p
}
z = q * Sin(Pi*z)
if z == 0 {
return Inf(signgam)
}
sq1, sq2 := stirling(q)
absz := Abs(z)
d := absz * sq1 * sq2
if IsInf(d, 0) {
z = Pi / absz / sq1 / sq2
} else {
z = Pi / d
}
return float64(signgam) * z
}
// Reduce argument
z := 1.0
for x >= 3 {
x = x - 1
z = z * x
}
for x < 0 {
if x > -1e-09 {
goto small
}
z = z / x
x = x + 1
}
for x < 2 {
if x < 1e-09 {
goto small
}
z = z / x
x = x + 1
}
if x == 2 {
return z
}
x = x - 2
p = (((((x*_gamP[0]+_gamP[1])*x+_gamP[2])*x+_gamP[3])*x+_gamP[4])*x+_gamP[5])*x + _gamP[6]
q = ((((((x*_gamQ[0]+_gamQ[1])*x+_gamQ[2])*x+_gamQ[3])*x+_gamQ[4])*x+_gamQ[5])*x+_gamQ[6])*x + _gamQ[7]
return z * p / q
small:
if x == 0 {
return Inf(1)
}
return z / ((1 + Euler*x) * x)
}
func isNegInt(x float64) bool {
if x < 0 {
_, xf := Modf(x)
return xf == 0
}
return false
}
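// A few spot checks (a sketch, assuming the exported math API):
//
//	fmt.Println(math.Gamma(5))    // 24, i.e. 4!
//	fmt.Println(math.Gamma(0.5))  // ~1.7724538509, i.e. Sqrt(Pi)
//	fmt.Println(math.Gamma(-0.5)) // ~-3.4449e... more precisely -2*Sqrt(Pi) ~ -3.5449077018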
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
/*
Hypot -- sqrt(p*p + q*q), but overflows only if the result does.
*/
// Hypot returns Sqrt(p*p + q*q), taking care to avoid
// unnecessary overflow and underflow.
//
// Special cases are:
//
// Hypot(±Inf, q) = +Inf
// Hypot(p, ±Inf) = +Inf
// Hypot(NaN, q) = NaN
// Hypot(p, NaN) = NaN
func Hypot(p, q float64) float64 {
if haveArchHypot {
return archHypot(p, q)
}
return hypot(p, q)
}
func hypot(p, q float64) float64 {
p, q = Abs(p), Abs(q)
// special cases
switch {
case IsInf(p, 1) || IsInf(q, 1):
return Inf(1)
case IsNaN(p) || IsNaN(q):
return NaN()
}
if p < q {
p, q = q, p
}
if p == 0 {
return 0
}
q = q / p
return p * Sqrt(1+q*q)
}
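// The effect of the rescaling (a sketch, assuming the exported math API):
//
//	p, q := 3e300, 4e300
//	fmt.Println(math.Sqrt(p*p + q*q)) // +Inf: p*p overflows float64
//	fmt.Println(math.Hypot(p, q))     // 5e300: no intermediate overflow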
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
/*
Bessel function of the first and second kinds of order zero.
*/
// The original C code and the long comment below are
// from FreeBSD's /usr/src/lib/msun/src/e_j0.c and
// came with this notice. The go code is a simplified
// version of the original C.
//
// ====================================================
// Copyright (C) 1993 by Sun Microsystems, Inc. All rights reserved.
//
// Developed at SunPro, a Sun Microsystems, Inc. business.
// Permission to use, copy, modify, and distribute this
// software is freely granted, provided that this notice
// is preserved.
// ====================================================
//
// __ieee754_j0(x), __ieee754_y0(x)
// Bessel function of the first and second kinds of order zero.
// Method -- j0(x):
// 1. For tiny x, we use j0(x) = 1 - x**2/4 + x**4/64 - ...
// 2. Reduce x to |x| since j0(x)=j0(-x), and
// for x in (0,2)
// j0(x) = 1-z/4+ z**2*R0/S0, where z = x*x;
// (precision: |j0-1+z/4-z**2R0/S0 |<2**-63.67 )
// for x in (2,inf)
// j0(x) = sqrt(2/(pi*x))*(p0(x)*cos(x0)-q0(x)*sin(x0))
// where x0 = x-pi/4. It is better to compute sin(x0),cos(x0)
// as follows:
// cos(x0) = cos(x)cos(pi/4)+sin(x)sin(pi/4)
// = 1/sqrt(2) * (cos(x) + sin(x))
// sin(x0) = sin(x)cos(pi/4)-cos(x)sin(pi/4)
// = 1/sqrt(2) * (sin(x) - cos(x))
// (To avoid cancellation, use
// sin(x) +- cos(x) = -cos(2x)/(sin(x) -+ cos(x))
// to compute the worse one.)
//
// 3. Special cases
// j0(nan)= nan
// j0(0) = 1
// j0(inf) = 0
//
// Method -- y0(x):
// 1. For x<2.
// Since
// y0(x) = 2/pi*(j0(x)*(ln(x/2)+Euler) + x**2/4 - ...)
// therefore y0(x)-2/pi*j0(x)*ln(x) is an even function.
// We use the following function to approximate y0,
// y0(x) = U(z)/V(z) + (2/pi)*(j0(x)*ln(x)), z= x**2
// where
// U(z) = u00 + u01*z + ... + u06*z**6
// V(z) = 1 + v01*z + ... + v04*z**4
// with absolute approximation error bounded by 2**-72.
// Note: For tiny x, U/V = u0 and j0(x)~1, hence
// y0(tiny) = u0 + (2/pi)*ln(tiny), (choose tiny<2**-27)
// 2. For x>=2.
// y0(x) = sqrt(2/(pi*x))*(p0(x)*sin(x0)+q0(x)*cos(x0))
// where x0 = x-pi/4. It is better to compute sin(x0),cos(x0)
// by the method mentioned above.
// 3. Special cases: y0(0)=-inf, y0(x<0)=NaN, y0(inf)=0.
//
// J0 returns the order-zero Bessel function of the first kind.
//
// Special cases are:
//
// J0(±Inf) = 0
// J0(0) = 1
// J0(NaN) = NaN
func J0(x float64) float64 {
const (
Huge = 1e300
TwoM27 = 1.0 / (1 << 27) // 2**-27 0x3e40000000000000
TwoM13 = 1.0 / (1 << 13) // 2**-13 0x3f20000000000000
Two129 = 1 << 129 // 2**129 0x4800000000000000
// R0/S0 on [0, 2]
R02 = 1.56249999999999947958e-02 // 0x3F8FFFFFFFFFFFFD
R03 = -1.89979294238854721751e-04 // 0xBF28E6A5B61AC6E9
R04 = 1.82954049532700665670e-06 // 0x3EBEB1D10C503919
R05 = -4.61832688532103189199e-09 // 0xBE33D5E773D63FCE
S01 = 1.56191029464890010492e-02 // 0x3F8FFCE882C8C2A4
S02 = 1.16926784663337450260e-04 // 0x3F1EA6D2DD57DBF4
S03 = 5.13546550207318111446e-07 // 0x3EA13B54CE84D5A9
S04 = 1.16614003333790000205e-09 // 0x3E1408BCF4745D8F
)
// special cases
switch {
case IsNaN(x):
return x
case IsInf(x, 0):
return 0
case x == 0:
return 1
}
x = Abs(x)
if x >= 2 {
s, c := Sincos(x)
ss := s - c
cc := s + c
// make sure x+x does not overflow
if x < MaxFloat64/2 {
z := -Cos(x + x)
if s*c < 0 {
cc = z / ss
} else {
ss = z / cc
}
}
// j0(x) = 1/sqrt(pi) * (P(0,x)*cc - Q(0,x)*ss) / sqrt(x)
// y0(x) = 1/sqrt(pi) * (P(0,x)*ss + Q(0,x)*cc) / sqrt(x)
var z float64
if x > Two129 { // |x| > ~6.8056e+38
z = (1 / SqrtPi) * cc / Sqrt(x)
} else {
u := pzero(x)
v := qzero(x)
z = (1 / SqrtPi) * (u*cc - v*ss) / Sqrt(x)
}
return z // |x| >= 2.0
}
if x < TwoM13 { // |x| < ~1.2207e-4
if x < TwoM27 {
return 1 // |x| < ~7.4506e-9
}
return 1 - 0.25*x*x // ~7.4506e-9 < |x| < ~1.2207e-4
}
z := x * x
r := z * (R02 + z*(R03+z*(R04+z*R05)))
s := 1 + z*(S01+z*(S02+z*(S03+z*S04)))
if x < 1 {
return 1 + z*(-0.25+(r/s)) // |x| < 1.00
}
u := 0.5 * x
return (1+u)*(1-u) + z*(r/s) // 1.0 < |x| < 2.0
}
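// A small-argument spot check against the series j0(x) = 1 - x**2/4 + ...
// (a sketch, assuming the exported math API):
//
//	fmt.Println(math.J0(1e-3)) // ~0.99999975, matching 1 - 0.25e-6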
// Y0 returns the order-zero Bessel function of the second kind.
//
// Special cases are:
//
// Y0(+Inf) = 0
// Y0(0) = -Inf
// Y0(x < 0) = NaN
// Y0(NaN) = NaN
func Y0(x float64) float64 {
const (
TwoM27 = 1.0 / (1 << 27) // 2**-27 0x3e40000000000000
Two129 = 1 << 129 // 2**129 0x4800000000000000
U00 = -7.38042951086872317523e-02 // 0xBFB2E4D699CBD01F
U01 = 1.76666452509181115538e-01 // 0x3FC69D019DE9E3FC
U02 = -1.38185671945596898896e-02 // 0xBF8C4CE8B16CFA97
U03 = 3.47453432093683650238e-04 // 0x3F36C54D20B29B6B
U04 = -3.81407053724364161125e-06 // 0xBECFFEA773D25CAD
U05 = 1.95590137035022920206e-08 // 0x3E5500573B4EABD4
U06 = -3.98205194132103398453e-11 // 0xBDC5E43D693FB3C8
V01 = 1.27304834834123699328e-02 // 0x3F8A127091C9C71A
V02 = 7.60068627350353253702e-05 // 0x3F13ECBBF578C6C1
V03 = 2.59150851840457805467e-07 // 0x3E91642D7FF202FD
V04 = 4.41110311332675467403e-10 // 0x3DFE50183BD6D9EF
)
// special cases
switch {
case x < 0 || IsNaN(x):
return NaN()
case IsInf(x, 1):
return 0
case x == 0:
return Inf(-1)
}
if x >= 2 { // |x| >= 2.0
// y0(x) = sqrt(2/(pi*x))*(p0(x)*sin(x0)+q0(x)*cos(x0))
// where x0 = x-pi/4
// Better formula:
// cos(x0) = cos(x)cos(pi/4)+sin(x)sin(pi/4)
// = 1/sqrt(2) * (sin(x) + cos(x))
// sin(x0) = sin(x)cos(pi/4)-cos(x)sin(pi/4)
// = 1/sqrt(2) * (sin(x) - cos(x))
// To avoid cancellation, use
// sin(x) +- cos(x) = -cos(2x)/(sin(x) -+ cos(x))
// to compute the worse one.
s, c := Sincos(x)
ss := s - c
cc := s + c
// j0(x) = 1/sqrt(pi) * (P(0,x)*cc - Q(0,x)*ss) / sqrt(x)
// y0(x) = 1/sqrt(pi) * (P(0,x)*ss + Q(0,x)*cc) / sqrt(x)
// make sure x+x does not overflow
if x < MaxFloat64/2 {
z := -Cos(x + x)
if s*c < 0 {
cc = z / ss
} else {
ss = z / cc
}
}
var z float64
if x > Two129 { // |x| > ~6.8056e+38
z = (1 / SqrtPi) * ss / Sqrt(x)
} else {
u := pzero(x)
v := qzero(x)
z = (1 / SqrtPi) * (u*ss + v*cc) / Sqrt(x)
}
return z // |x| >= 2.0
}
if x <= TwoM27 {
return U00 + (2/Pi)*Log(x) // |x| < ~7.4506e-9
}
z := x * x
u := U00 + z*(U01+z*(U02+z*(U03+z*(U04+z*(U05+z*U06)))))
v := 1 + z*(V01+z*(V02+z*(V03+z*V04)))
return u/v + (2/Pi)*J0(x)*Log(x) // ~7.4506e-9 < |x| < 2.0
}
// The asymptotic expansion of pzero is
// 1 - 9/128 s**2 + 11025/98304 s**4 - ..., where s = 1/x.
// For x >= 2, we approximate pzero by
// pzero(x) = 1 + (R/S)
// where R = pR0 + pR1*s**2 + pR2*s**4 + ... + pR5*s**10
// S = 1 + pS0*s**2 + ... + pS4*s**10
// and
// | pzero(x)-1-R/S | <= 2 ** ( -60.26)
// for x in [inf, 8]=1/[0,0.125]
var p0R8 = [6]float64{
0.00000000000000000000e+00, // 0x0000000000000000
-7.03124999999900357484e-02, // 0xBFB1FFFFFFFFFD32
-8.08167041275349795626e+00, // 0xC02029D0B44FA779
-2.57063105679704847262e+02, // 0xC07011027B19E863
-2.48521641009428822144e+03, // 0xC0A36A6ECD4DCAFC
-5.25304380490729545272e+03, // 0xC0B4850B36CC643D
}
var p0S8 = [5]float64{
1.16534364619668181717e+02, // 0x405D223307A96751
3.83374475364121826715e+03, // 0x40ADF37D50596938
4.05978572648472545552e+04, // 0x40E3D2BB6EB6B05F
1.16752972564375915681e+05, // 0x40FC810F8F9FA9BD
4.76277284146730962675e+04, // 0x40E741774F2C49DC
}
// for x in [8,4.5454]=1/[0.125,0.22001]
var p0R5 = [6]float64{
-1.14125464691894502584e-11, // 0xBDA918B147E495CC
-7.03124940873599280078e-02, // 0xBFB1FFFFE69AFBC6
-4.15961064470587782438e+00, // 0xC010A370F90C6BBF
-6.76747652265167261021e+01, // 0xC050EB2F5A7D1783
-3.31231299649172967747e+02, // 0xC074B3B36742CC63
-3.46433388365604912451e+02, // 0xC075A6EF28A38BD7
}
var p0S5 = [5]float64{
6.07539382692300335975e+01, // 0x404E60810C98C5DE
1.05125230595704579173e+03, // 0x40906D025C7E2864
5.97897094333855784498e+03, // 0x40B75AF88FBE1D60
9.62544514357774460223e+03, // 0x40C2CCB8FA76FA38
2.40605815922939109441e+03, // 0x40A2CC1DC70BE864
}
// for x in [4.547,2.8571]=1/[0.2199,0.35001]
var p0R3 = [6]float64{
-2.54704601771951915620e-09, // 0xBE25E1036FE1AA86
-7.03119616381481654654e-02, // 0xBFB1FFF6F7C0E24B
-2.40903221549529611423e+00, // 0xC00345B2AEA48074
-2.19659774734883086467e+01, // 0xC035F74A4CB94E14
-5.80791704701737572236e+01, // 0xC04D0A22420A1A45
-3.14479470594888503854e+01, // 0xC03F72ACA892D80F
}
var p0S3 = [5]float64{
3.58560338055209726349e+01, // 0x4041ED9284077DD3
3.61513983050303863820e+02, // 0x40769839464A7C0E
1.19360783792111533330e+03, // 0x4092A66E6D1061D6
1.12799679856907414432e+03, // 0x40919FFCB8C39B7E
1.73580930813335754692e+02, // 0x4065B296FC379081
}
// for x in [2.8570,2]=1/[0.3499,0.5]
var p0R2 = [6]float64{
-8.87534333032526411254e-08, // 0xBE77D316E927026D
-7.03030995483624743247e-02, // 0xBFB1FF62495E1E42
-1.45073846780952986357e+00, // 0xBFF736398A24A843
-7.63569613823527770791e+00, // 0xC01E8AF3EDAFA7F3
-1.11931668860356747786e+01, // 0xC02662E6C5246303
-3.23364579351335335033e+00, // 0xC009DE81AF8FE70F
}
var p0S2 = [5]float64{
2.22202997532088808441e+01, // 0x40363865908B5959
1.36206794218215208048e+02, // 0x4061069E0EE8878F
2.70470278658083486789e+02, // 0x4070E78642EA079B
1.53875394208320329881e+02, // 0x40633C033AB6FAFF
1.46576176948256193810e+01, // 0x402D50B344391809
}
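// Note that pzero (and likewise qzero, pone, and qone below) assumes x >= 2;
// for smaller x the interval selection below would leave p and q nil. All
// callers reduce to |x| >= 2 before calling.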
func pzero(x float64) float64 {
var p *[6]float64
var q *[5]float64
if x >= 8 {
p = &p0R8
q = &p0S8
} else if x >= 4.5454 {
p = &p0R5
q = &p0S5
} else if x >= 2.8571 {
p = &p0R3
q = &p0S3
} else if x >= 2 {
p = &p0R2
q = &p0S2
}
z := 1 / (x * x)
r := p[0] + z*(p[1]+z*(p[2]+z*(p[3]+z*(p[4]+z*p[5]))))
s := 1 + z*(q[0]+z*(q[1]+z*(q[2]+z*(q[3]+z*q[4]))))
return 1 + r/s
}
// For x >= 8, the asymptotic expansion of qzero is
// -1/8 s + 75/1024 s**3 - ..., where s = 1/x.
// We approximate qzero by
// qzero(x) = s*(-0.125 + (R/S))
// where R = qR0 + qR1*s**2 + qR2*s**4 + ... + qR5*s**10
// S = 1 + qS0*s**2 + ... + qS5*s**12
// and
// | qzero(x)/s + 0.125 - R/S | <= 2**(-61.22)
// for x in [inf, 8]=1/[0,0.125]
var q0R8 = [6]float64{
0.00000000000000000000e+00, // 0x0000000000000000
7.32421874999935051953e-02, // 0x3FB2BFFFFFFFFE2C
1.17682064682252693899e+01, // 0x402789525BB334D6
5.57673380256401856059e+02, // 0x40816D6315301825
8.85919720756468632317e+03, // 0x40C14D993E18F46D
3.70146267776887834771e+04, // 0x40E212D40E901566
}
var q0S8 = [6]float64{
1.63776026895689824414e+02, // 0x406478D5365B39BC
8.09834494656449805916e+03, // 0x40BFA2584E6B0563
1.42538291419120476348e+05, // 0x4101665254D38C3F
8.03309257119514397345e+05, // 0x412883DA83A52B43
8.40501579819060512818e+05, // 0x4129A66B28DE0B3D
-3.43899293537866615225e+05, // 0xC114FD6D2C9530C5
}
// for x in [8,4.5454]=1/[0.125,0.22001]
var q0R5 = [6]float64{
1.84085963594515531381e-11, // 0x3DB43D8F29CC8CD9
7.32421766612684765896e-02, // 0x3FB2BFFFD172B04C
5.83563508962056953777e+00, // 0x401757B0B9953DD3
1.35111577286449829671e+02, // 0x4060E3920A8788E9
1.02724376596164097464e+03, // 0x40900CF99DC8C481
1.98997785864605384631e+03, // 0x409F17E953C6E3A6
}
var q0S5 = [6]float64{
8.27766102236537761883e+01, // 0x4054B1B3FB5E1543
2.07781416421392987104e+03, // 0x40A03BA0DA21C0CE
1.88472887785718085070e+04, // 0x40D267D27B591E6D
5.67511122894947329769e+04, // 0x40EBB5E397E02372
3.59767538425114471465e+04, // 0x40E191181F7A54A0
-5.35434275601944773371e+03, // 0xC0B4EA57BEDBC609
}
// for x in [4.547,2.8571]=1/[0.2199,0.35001]
var q0R3 = [6]float64{
4.37741014089738620906e-09, // 0x3E32CD036ADECB82
7.32411180042911447163e-02, // 0x3FB2BFEE0E8D0842
3.34423137516170720929e+00, // 0x400AC0FC61149CF5
4.26218440745412650017e+01, // 0x40454F98962DAEDD
1.70808091340565596283e+02, // 0x406559DBE25EFD1F
1.66733948696651168575e+02, // 0x4064D77C81FA21E0
}
var q0S3 = [6]float64{
4.87588729724587182091e+01, // 0x40486122BFE343A6
7.09689221056606015736e+02, // 0x40862D8386544EB3
3.70414822620111362994e+03, // 0x40ACF04BE44DFC63
6.46042516752568917582e+03, // 0x40B93C6CD7C76A28
2.51633368920368957333e+03, // 0x40A3A8AAD94FB1C0
-1.49247451836156386662e+02, // 0xC062A7EB201CF40F
}
// for x in [2.8570,2]=1/[0.3499,0.5]
var q0R2 = [6]float64{
1.50444444886983272379e-07, // 0x3E84313B54F76BDB
7.32234265963079278272e-02, // 0x3FB2BEC53E883E34
1.99819174093815998816e+00, // 0x3FFFF897E727779C
1.44956029347885735348e+01, // 0x402CFDBFAAF96FE5
3.16662317504781540833e+01, // 0x403FAA8E29FBDC4A
1.62527075710929267416e+01, // 0x403040B171814BB4
}
var q0S2 = [6]float64{
3.03655848355219184498e+01, // 0x403E5D96F7C07AED
2.69348118608049844624e+02, // 0x4070D591E4D14B40
8.44783757595320139444e+02, // 0x408A664522B3BF22
8.82935845112488550512e+02, // 0x408B977C9C5CC214
2.12666388511798828631e+02, // 0x406A95530E001365
-5.31095493882666946917e+00, // 0xC0153E6AF8B32931
}
func qzero(x float64) float64 {
var p, q *[6]float64
if x >= 8 {
p = &q0R8
q = &q0S8
} else if x >= 4.5454 {
p = &q0R5
q = &q0S5
} else if x >= 2.8571 {
p = &q0R3
q = &q0S3
} else if x >= 2 {
p = &q0R2
q = &q0S2
}
z := 1 / (x * x)
r := p[0] + z*(p[1]+z*(p[2]+z*(p[3]+z*(p[4]+z*p[5]))))
s := 1 + z*(q[0]+z*(q[1]+z*(q[2]+z*(q[3]+z*(q[4]+z*q[5])))))
return (-0.125 + r/s) / x
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
/*
Bessel function of the first and second kinds of order one.
*/
// The original C code and the long comment below are
// from FreeBSD's /usr/src/lib/msun/src/e_j1.c and
// came with this notice. The go code is a simplified
// version of the original C.
//
// ====================================================
// Copyright (C) 1993 by Sun Microsystems, Inc. All rights reserved.
//
// Developed at SunPro, a Sun Microsystems, Inc. business.
// Permission to use, copy, modify, and distribute this
// software is freely granted, provided that this notice
// is preserved.
// ====================================================
//
// __ieee754_j1(x), __ieee754_y1(x)
// Bessel function of the first and second kinds of order one.
// Method -- j1(x):
// 1. For tiny x, we use j1(x) = x/2 - x**3/16 + x**5/384 - ...
// 2. Reduce x to |x| since j1(x)=-j1(-x), and
// for x in (0,2)
// j1(x) = x/2 + x*z*R0/S0, where z = x*x;
// (precision: |j1/x - 1/2 - R0/S0 |<2**-61.51 )
// for x in (2,inf)
// j1(x) = sqrt(2/(pi*x))*(p1(x)*cos(x1)-q1(x)*sin(x1))
// y1(x) = sqrt(2/(pi*x))*(p1(x)*sin(x1)+q1(x)*cos(x1))
// where x1 = x-3*pi/4. It is better to compute sin(x1),cos(x1)
// as follows:
// cos(x1) = cos(x)cos(3pi/4)+sin(x)sin(3pi/4)
// = 1/sqrt(2) * (sin(x) - cos(x))
// sin(x1) = sin(x)cos(3pi/4)-cos(x)sin(3pi/4)
// = -1/sqrt(2) * (sin(x) + cos(x))
// (To avoid cancellation, use
// sin(x) +- cos(x) = -cos(2x)/(sin(x) -+ cos(x))
// to compute the worse one.)
//
// 3. Special cases
// j1(nan)= nan
// j1(0) = 0
// j1(inf) = 0
//
// Method -- y1(x):
// 1. screen out x<=0 cases: y1(0)=-inf, y1(x<0)=NaN
// 2. For x<2.
// Since
// y1(x) = 2/pi*(j1(x)*(ln(x/2)+Euler)-1/x-x/2+5/64*x**3-...)
// therefore y1(x)-2/pi*j1(x)*ln(x)-1/x is an odd function.
// We use the following function to approximate y1,
// y1(x) = x*U(z)/V(z) + (2/pi)*(j1(x)*ln(x)-1/x), z= x**2
// where for x in [0,2] (abs err less than 2**-65.89)
// U(z) = U0[0] + U0[1]*z + ... + U0[4]*z**4
// V(z) = 1 + v0[0]*z + ... + v0[4]*z**5
// Note: For tiny x, 1/x dominates y1 and hence
// y1(tiny) = -2/pi/tiny, (choose tiny<2**-54)
// 3. For x>=2.
// y1(x) = sqrt(2/(pi*x))*(p1(x)*sin(x1)+q1(x)*cos(x1))
// where x1 = x-3*pi/4. It is better to compute sin(x1),cos(x1)
// by method mentioned above.
// J1 returns the order-one Bessel function of the first kind.
//
// Special cases are:
//
// J1(±Inf) = 0
// J1(NaN) = NaN
func J1(x float64) float64 {
const (
TwoM27 = 1.0 / (1 << 27) // 2**-27 0x3e40000000000000
Two129 = 1 << 129 // 2**129 0x4800000000000000
// R0/S0 on [0, 2]
R00 = -6.25000000000000000000e-02 // 0xBFB0000000000000
R01 = 1.40705666955189706048e-03 // 0x3F570D9F98472C61
R02 = -1.59955631084035597520e-05 // 0xBEF0C5C6BA169668
R03 = 4.96727999609584448412e-08 // 0x3E6AAAFA46CA0BD9
S01 = 1.91537599538363460805e-02 // 0x3F939D0B12637E53
S02 = 1.85946785588630915560e-04 // 0x3F285F56B9CDF664
S03 = 1.17718464042623683263e-06 // 0x3EB3BFF8333F8498
S04 = 5.04636257076217042715e-09 // 0x3E35AC88C97DFF2C
S05 = 1.23542274426137913908e-11 // 0x3DAB2ACFCFB97ED8
)
// special cases
switch {
case IsNaN(x):
return x
case IsInf(x, 0) || x == 0:
return 0
}
sign := false
if x < 0 {
x = -x
sign = true
}
if x >= 2 {
s, c := Sincos(x)
ss := -s - c
cc := s - c
// make sure x+x does not overflow
if x < MaxFloat64/2 {
z := Cos(x + x)
if s*c > 0 {
cc = z / ss
} else {
ss = z / cc
}
}
// j1(x) = 1/sqrt(pi) * (P(1,x)*cc - Q(1,x)*ss) / sqrt(x)
// y1(x) = 1/sqrt(pi) * (P(1,x)*ss + Q(1,x)*cc) / sqrt(x)
var z float64
if x > Two129 {
z = (1 / SqrtPi) * cc / Sqrt(x)
} else {
u := pone(x)
v := qone(x)
z = (1 / SqrtPi) * (u*cc - v*ss) / Sqrt(x)
}
if sign {
return -z
}
return z
}
if x < TwoM27 { // |x|<2**-27
return 0.5 * x // inexact if x!=0 necessary
}
z := x * x
r := z * (R00 + z*(R01+z*(R02+z*R03)))
s := 1.0 + z*(S01+z*(S02+z*(S03+z*(S04+z*S05))))
r *= x
z = 0.5*x + r/s
if sign {
return -z
}
return z
}
// Y1 returns the order-one Bessel function of the second kind.
//
// Special cases are:
//
// Y1(+Inf) = 0
// Y1(0) = -Inf
// Y1(x < 0) = NaN
// Y1(NaN) = NaN
func Y1(x float64) float64 {
const (
TwoM54 = 1.0 / (1 << 54) // 2**-54 0x3c90000000000000
Two129 = 1 << 129 // 2**129 0x4800000000000000
U00 = -1.96057090646238940668e-01 // 0xBFC91866143CBC8A
U01 = 5.04438716639811282616e-02 // 0x3FA9D3C776292CD1
U02 = -1.91256895875763547298e-03 // 0xBF5F55E54844F50F
U03 = 2.35252600561610495928e-05 // 0x3EF8AB038FA6B88E
U04 = -9.19099158039878874504e-08 // 0xBE78AC00569105B8
V00 = 1.99167318236649903973e-02 // 0x3F94650D3F4DA9F0
V01 = 2.02552581025135171496e-04 // 0x3F2A8C896C257764
V02 = 1.35608801097516229404e-06 // 0x3EB6C05A894E8CA6
V03 = 6.22741452364621501295e-09 // 0x3E3ABF1D5BA69A86
V04 = 1.66559246207992079114e-11 // 0x3DB25039DACA772A
)
// special cases
switch {
case x < 0 || IsNaN(x):
return NaN()
case IsInf(x, 1):
return 0
case x == 0:
return Inf(-1)
}
if x >= 2 {
s, c := Sincos(x)
ss := -s - c
cc := s - c
// make sure x+x does not overflow
if x < MaxFloat64/2 {
z := Cos(x + x)
if s*c > 0 {
cc = z / ss
} else {
ss = z / cc
}
}
// y1(x) = sqrt(2/(pi*x))*(p1(x)*sin(x0)+q1(x)*cos(x0))
// where x0 = x-3pi/4
// Better formula:
// cos(x0) = cos(x)cos(3pi/4)+sin(x)sin(3pi/4)
// = 1/sqrt(2) * (sin(x) - cos(x))
// sin(x0) = sin(x)cos(3pi/4)-cos(x)sin(3pi/4)
// = -1/sqrt(2) * (cos(x) + sin(x))
// To avoid cancellation, use
// sin(x) +- cos(x) = -cos(2x)/(sin(x) -+ cos(x))
// to compute the worse one.
var z float64
if x > Two129 {
z = (1 / SqrtPi) * ss / Sqrt(x)
} else {
u := pone(x)
v := qone(x)
z = (1 / SqrtPi) * (u*ss + v*cc) / Sqrt(x)
}
return z
}
if x <= TwoM54 { // x <= 2**-54
return -(2 / Pi) / x
}
z := x * x
u := U00 + z*(U01+z*(U02+z*(U03+z*U04)))
v := 1 + z*(V00+z*(V01+z*(V02+z*(V03+z*V04))))
return x*(u/v) + (2/Pi)*(J1(x)*Log(x)-1/x)
}
// For x >= 8, the asymptotic expansion of pone is
// 1 + 15/128 s**2 - 4725/2**15 s**4 - ..., where s = 1/x.
// We approximate pone by
// pone(x) = 1 + (R/S)
// where R = pr0 + pr1*s**2 + pr2*s**4 + ... + pr5*s**10
// S = 1 + ps0*s**2 + ... + ps4*s**10
// and
// | pone(x)-1-R/S | <= 2**(-60.06)
// for x in [inf, 8]=1/[0,0.125]
var p1R8 = [6]float64{
0.00000000000000000000e+00, // 0x0000000000000000
1.17187499999988647970e-01, // 0x3FBDFFFFFFFFFCCE
1.32394806593073575129e+01, // 0x402A7A9D357F7FCE
4.12051854307378562225e+02, // 0x4079C0D4652EA590
3.87474538913960532227e+03, // 0x40AE457DA3A532CC
7.91447954031891731574e+03, // 0x40BEEA7AC32782DD
}
var p1S8 = [5]float64{
1.14207370375678408436e+02, // 0x405C8D458E656CAC
3.65093083420853463394e+03, // 0x40AC85DC964D274F
3.69562060269033463555e+04, // 0x40E20B8697C5BB7F
9.76027935934950801311e+04, // 0x40F7D42CB28F17BB
3.08042720627888811578e+04, // 0x40DE1511697A0B2D
}
// for x in [8,4.5454] = 1/[0.125,0.22001]
var p1R5 = [6]float64{
1.31990519556243522749e-11, // 0x3DAD0667DAE1CA7D
1.17187493190614097638e-01, // 0x3FBDFFFFE2C10043
6.80275127868432871736e+00, // 0x401B36046E6315E3
1.08308182990189109773e+02, // 0x405B13B9452602ED
5.17636139533199752805e+02, // 0x40802D16D052D649
5.28715201363337541807e+02, // 0x408085B8BB7E0CB7
}
var p1S5 = [5]float64{
5.92805987221131331921e+01, // 0x404DA3EAA8AF633D
9.91401418733614377743e+02, // 0x408EFB361B066701
5.35326695291487976647e+03, // 0x40B4E9445706B6FB
7.84469031749551231769e+03, // 0x40BEA4B0B8A5BB15
1.50404688810361062679e+03, // 0x40978030036F5E51
}
// for x in[4.5453,2.8571] = 1/[0.2199,0.35001]
var p1R3 = [6]float64{
3.02503916137373618024e-09, // 0x3E29FC21A7AD9EDD
1.17186865567253592491e-01, // 0x3FBDFFF55B21D17B
3.93297750033315640650e+00, // 0x400F76BCE85EAD8A
3.51194035591636932736e+01, // 0x40418F489DA6D129
9.10550110750781271918e+01, // 0x4056C3854D2C1837
4.85590685197364919645e+01, // 0x4048478F8EA83EE5
}
var p1S3 = [5]float64{
3.47913095001251519989e+01, // 0x40416549A134069C
3.36762458747825746741e+02, // 0x40750C3307F1A75F
1.04687139975775130551e+03, // 0x40905B7C5037D523
8.90811346398256432622e+02, // 0x408BD67DA32E31E9
1.03787932439639277504e+02, // 0x4059F26D7C2EED53
}
// for x in [2.8570,2] = 1/[0.3499,0.5]
var p1R2 = [6]float64{
1.07710830106873743082e-07, // 0x3E7CE9D4F65544F4
1.17176219462683348094e-01, // 0x3FBDFF42BE760D83
2.36851496667608785174e+00, // 0x4002F2B7F98FAEC0
1.22426109148261232917e+01, // 0x40287C377F71A964
1.76939711271687727390e+01, // 0x4031B1A8177F8EE2
5.07352312588818499250e+00, // 0x40144B49A574C1FE
}
var p1S2 = [5]float64{
2.14364859363821409488e+01, // 0x40356FBD8AD5ECDC
1.25290227168402751090e+02, // 0x405F529314F92CD5
2.32276469057162813669e+02, // 0x406D08D8D5A2DBD9
1.17679373287147100768e+02, // 0x405D6B7ADA1884A9
8.36463893371618283368e+00, // 0x4020BAB1F44E5192
}
func pone(x float64) float64 {
var p *[6]float64
var q *[5]float64
if x >= 8 {
p = &p1R8
q = &p1S8
} else if x >= 4.5454 {
p = &p1R5
q = &p1S5
} else if x >= 2.8571 {
p = &p1R3
q = &p1S3
} else if x >= 2 {
p = &p1R2
q = &p1S2
}
z := 1 / (x * x)
r := p[0] + z*(p[1]+z*(p[2]+z*(p[3]+z*(p[4]+z*p[5]))))
s := 1.0 + z*(q[0]+z*(q[1]+z*(q[2]+z*(q[3]+z*q[4]))))
return 1 + r/s
}
// For x >= 8, the asymptotic expansion of qone is
// 3/8 s - 105/1024 s**3 - ..., where s = 1/x.
// We approximate qone by
// qone(x) = s*(0.375 + (R/S))
// where R = qr1*s**2 + qr2*s**4 + ... + qr5*s**10
// S = 1 + qs1*s**2 + ... + qs6*s**12
// and
// | qone(x)/s -0.375-R/S | <= 2**(-61.13)
// for x in [inf, 8] = 1/[0,0.125]
var q1R8 = [6]float64{
0.00000000000000000000e+00, // 0x0000000000000000
-1.02539062499992714161e-01, // 0xBFBA3FFFFFFFFDF3
-1.62717534544589987888e+01, // 0xC0304591A26779F7
-7.59601722513950107896e+02, // 0xC087BCD053E4B576
-1.18498066702429587167e+04, // 0xC0C724E740F87415
-4.84385124285750353010e+04, // 0xC0E7A6D065D09C6A
}
var q1S8 = [6]float64{
1.61395369700722909556e+02, // 0x40642CA6DE5BCDE5
7.82538599923348465381e+03, // 0x40BE9162D0D88419
1.33875336287249578163e+05, // 0x4100579AB0B75E98
7.19657723683240939863e+05, // 0x4125F65372869C19
6.66601232617776375264e+05, // 0x412457D27719AD5C
-2.94490264303834643215e+05, // 0xC111F9690EA5AA18
}
// for x in [8,4.5454] = 1/[0.125,0.22001]
var q1R5 = [6]float64{
-2.08979931141764104297e-11, // 0xBDB6FA431AA1A098
-1.02539050241375426231e-01, // 0xBFBA3FFFCB597FEF
-8.05644828123936029840e+00, // 0xC0201CE6CA03AD4B
-1.83669607474888380239e+02, // 0xC066F56D6CA7B9B0
-1.37319376065508163265e+03, // 0xC09574C66931734F
-2.61244440453215656817e+03, // 0xC0A468E388FDA79D
}
var q1S5 = [6]float64{
8.12765501384335777857e+01, // 0x405451B2FF5A11B2
1.99179873460485964642e+03, // 0x409F1F31E77BF839
1.74684851924908907677e+04, // 0x40D10F1F0D64CE29
4.98514270910352279316e+04, // 0x40E8576DAABAD197
2.79480751638918118260e+04, // 0x40DB4B04CF7C364B
-4.71918354795128470869e+03, // 0xC0B26F2EFCFFA004
}
// for x in [4.5454,2.8571] = 1/[0.2199,0.35001] ???
var q1R3 = [6]float64{
-5.07831226461766561369e-09, // 0xBE35CFA9D38FC84F
-1.02537829820837089745e-01, // 0xBFBA3FEB51AEED54
-4.61011581139473403113e+00, // 0xC01270C23302D9FF
-5.78472216562783643212e+01, // 0xC04CEC71C25D16DA
-2.28244540737631695038e+02, // 0xC06C87D34718D55F
-2.19210128478909325622e+02, // 0xC06B66B95F5C1BF6
}
var q1S3 = [6]float64{
4.76651550323729509273e+01, // 0x4047D523CCD367E4
6.73865112676699709482e+02, // 0x40850EEBC031EE3E
3.38015286679526343505e+03, // 0x40AA684E448E7C9A
5.54772909720722782367e+03, // 0x40B5ABBAA61D54A6
1.90311919338810798763e+03, // 0x409DBC7A0DD4DF4B
-1.35201191444307340817e+02, // 0xC060E670290A311F
}
// for x in [2.8570,2] = 1/[0.3499,0.5]
var q1R2 = [6]float64{
-1.78381727510958865572e-07, // 0xBE87F12644C626D2
-1.02517042607985553460e-01, // 0xBFBA3E8E9148B010
-2.75220568278187460720e+00, // 0xC006048469BB4EDA
-1.96636162643703720221e+01, // 0xC033A9E2C168907F
-4.23253133372830490089e+01, // 0xC04529A3DE104AAA
-2.13719211703704061733e+01, // 0xC0355F3639CF6E52
}
var q1S2 = [6]float64{
2.95333629060523854548e+01, // 0x403D888A78AE64FF
2.52981549982190529136e+02, // 0x406F9F68DB821CBA
7.57502834868645436472e+02, // 0x4087AC05CE49A0F7
7.39393205320467245656e+02, // 0x40871B2548D4C029
1.55949003336666123687e+02, // 0x40637E5E3C3ED8D4
-4.95949898822628210127e+00, // 0xC013D686E71BE86B
}
func qone(x float64) float64 {
var p, q *[6]float64
if x >= 8 {
p = &q1R8
q = &q1S8
} else if x >= 4.5454 {
p = &q1R5
q = &q1S5
} else if x >= 2.8571 {
p = &q1R3
q = &q1S3
} else if x >= 2 {
p = &q1R2
q = &q1S2
}
z := 1 / (x * x)
r := p[0] + z*(p[1]+z*(p[2]+z*(p[3]+z*(p[4]+z*p[5]))))
s := 1 + z*(q[0]+z*(q[1]+z*(q[2]+z*(q[3]+z*(q[4]+z*q[5])))))
return (0.375 + r/s) / x
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
/*
Bessel function of the first and second kinds of order n.
*/
// The original C code and the long comment below are
// from FreeBSD's /usr/src/lib/msun/src/e_jn.c and
// came with this notice. The go code is a simplified
// version of the original C.
//
// ====================================================
// Copyright (C) 1993 by Sun Microsystems, Inc. All rights reserved.
//
// Developed at SunPro, a Sun Microsystems, Inc. business.
// Permission to use, copy, modify, and distribute this
// software is freely granted, provided that this notice
// is preserved.
// ====================================================
//
// __ieee754_jn(n, x), __ieee754_yn(n, x)
// floating point Bessel's function of the 1st and 2nd kind
// of order n
//
// Special cases:
// y0(0)=y1(0)=yn(n,0) = -inf with division by zero signal;
// y0(-ve)=y1(-ve)=yn(n,-ve) are NaN with invalid signal.
// Note about jn(n,x), yn(n,x):
// For n=0, j0(x) is called,
// for n=1, j1(x) is called,
// for n<x, forward recursion is used starting
// from values of j0(x) and j1(x).
// for n>x, a continued fraction approximation to
// j(n,x)/j(n-1,x) is evaluated and then backward
// recursion is used starting from a supposed value
// for j(n,x). The resulting value of j(0,x) is
// compared with the actual value to correct the
// supposed value of j(n,x).
//
// yn(n,x) is similar in all respects, except
// that forward recursion is used for all
// values of n>1.
// Jn returns the order-n Bessel function of the first kind.
//
// Special cases are:
//
// Jn(n, ±Inf) = 0
// Jn(n, NaN) = NaN
func Jn(n int, x float64) float64 {
const (
TwoM29 = 1.0 / (1 << 29) // 2**-29 0x3e10000000000000
Two302 = 1 << 302 // 2**302 0x52D0000000000000
)
// special cases
switch {
case IsNaN(x):
return x
case IsInf(x, 0):
return 0
}
// J(-n, x) = (-1)**n * J(n, x), J(n, -x) = (-1)**n * J(n, x)
// Thus, J(-n, x) = J(n, -x)
if n == 0 {
return J0(x)
}
if x == 0 {
return 0
}
if n < 0 {
n, x = -n, -x
}
if n == 1 {
return J1(x)
}
sign := false
if x < 0 {
x = -x
if n&1 == 1 {
sign = true // odd n and negative x
}
}
var b float64
if float64(n) <= x {
// Safe to use J(n+1,x)=2n/x *J(n,x)-J(n-1,x)
if x >= Two302 { // x >= 2**302
// (x >> n**2)
// Jn(x) = cos(x-(2n+1)*pi/4)*sqrt(2/x*pi)
// Yn(x) = sin(x-(2n+1)*pi/4)*sqrt(2/x*pi)
// Let s=sin(x), c=cos(x),
// xn=x-(2n+1)*pi/4, sqt2 = sqrt(2), then
//
// n sin(xn)*sqt2 cos(xn)*sqt2
// ----------------------------------
// 0 s-c c+s
// 1 -s-c -c+s
// 2 -s+c -c-s
// 3 s+c c-s
var temp float64
switch s, c := Sincos(x); n & 3 {
case 0:
temp = c + s
case 1:
temp = -c + s
case 2:
temp = -c - s
case 3:
temp = c - s
}
b = (1 / SqrtPi) * temp / Sqrt(x)
} else {
b = J1(x)
for i, a := 1, J0(x); i < n; i++ {
a, b = b, b*(float64(i+i)/x)-a // avoid underflow
}
}
} else {
if x < TwoM29 { // x < 2**-29
// x is tiny, return the first Taylor expansion of J(n,x)
// J(n,x) = 1/n!*(x/2)**n - ...
if n > 33 { // underflow
b = 0
} else {
temp := x * 0.5
b = temp
a := 1.0
for i := 2; i <= n; i++ {
a *= float64(i) // a = n!
b *= temp // b = (x/2)**n
}
b /= a
}
} else {
// use backward recurrence
// J(n,x)/J(n-1,x) = x/(2n - x**2/(2(n+1) - x**2/(2(n+2) - ...)))
// or, dividing numerator and denominators through by x (useful for large x),
// = 1/(2n/x - 1/(2(n+1)/x - 1/(2(n+2)/x - ...)))
//
// Let w = 2n/x and h=2/x, then the above quotient
// is equal to the continued fraction:
// 1/(w - 1/(w+h - 1/(w+2h - ...)))
//
// To determine how many terms are needed, let
// Q(0) = w, Q(1) = w(w+h) - 1,
// Q(k) = (w+k*h)*Q(k-1) - Q(k-2),
// When Q(k) > 1e4 good for single
// When Q(k) > 1e9 good for double
// When Q(k) > 1e17 good for quadruple
// determine k
w := float64(n+n) / x
h := 2 / x
q0 := w
z := w + h
q1 := w*z - 1
k := 1
for q1 < 1e9 {
k++
z += h
q0, q1 = q1, z*q1-q0
}
m := n + n
t := 0.0
for i := 2 * (n + k); i >= m; i -= 2 {
t = 1 / (float64(i)/x - t)
}
a := t
b = 1
// estimate log((2/x)**n*n!) = n*log(2/x)+n*ln(n)
// Hence, if n*(log(2n/x)) > ...
// single 8.8722839355e+01
// double 7.09782712893383973096e+02
// long double 1.1356523406294143949491931077970765006170e+04
// then the recurrent value may overflow and the result is
// likely to underflow to zero
tmp := float64(n)
v := 2 / x
tmp = tmp * Log(Abs(v*tmp))
if tmp < 7.09782712893383973096e+02 {
for i := n - 1; i > 0; i-- {
di := float64(i + i)
a, b = b, b*di/x-a
}
} else {
for i := n - 1; i > 0; i-- {
di := float64(i + i)
a, b = b, b*di/x-a
// scale b to avoid spurious overflow
if b > 1e100 {
a /= b
t /= b
b = 1
}
}
}
b = t * J0(x) / b
}
}
if sign {
return -b
}
return b
}
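// The forward recurrence J(n+1,x) = (2n/x)*J(n,x) - J(n-1,x) can be checked
// directly (a sketch, assuming the exported math API):
//
//	x := 1.5
//	fmt.Println(math.Jn(2, x))               // ~0.2320877
//	fmt.Println(2/x*math.J1(x) - math.J0(x)) // agrees to within rounding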
// Yn returns the order-n Bessel function of the second kind.
//
// Special cases are:
//
// Yn(n, +Inf) = 0
// Yn(n ≥ 0, 0) = -Inf
// Yn(n < 0, 0) = +Inf if n is odd, -Inf if n is even
// Yn(n, x < 0) = NaN
// Yn(n, NaN) = NaN
func Yn(n int, x float64) float64 {
const Two302 = 1 << 302 // 2**302 0x52D0000000000000
// special cases
switch {
case x < 0 || IsNaN(x):
return NaN()
case IsInf(x, 1):
return 0
}
if n == 0 {
return Y0(x)
}
if x == 0 {
if n < 0 && n&1 == 1 {
return Inf(1)
}
return Inf(-1)
}
sign := false
if n < 0 {
n = -n
if n&1 == 1 {
sign = true // sign true if n < 0 && |n| odd
}
}
if n == 1 {
if sign {
return -Y1(x)
}
return Y1(x)
}
var b float64
if x >= Two302 { // x >= 2**302
// (x >> n**2)
// Jn(x) = cos(x-(2n+1)*pi/4)*sqrt(2/x*pi)
// Yn(x) = sin(x-(2n+1)*pi/4)*sqrt(2/x*pi)
// Let s=sin(x), c=cos(x),
// xn=x-(2n+1)*pi/4, sqt2 = sqrt(2), then
//
// n sin(xn)*sqt2 cos(xn)*sqt2
// ----------------------------------
// 0 s-c c+s
// 1 -s-c -c+s
// 2 -s+c -c-s
// 3 s+c c-s
var temp float64
switch s, c := Sincos(x); n & 3 {
case 0:
temp = s - c
case 1:
temp = -s - c
case 2:
temp = -s + c
case 3:
temp = s + c
}
b = (1 / SqrtPi) * temp / Sqrt(x)
} else {
a := Y0(x)
b = Y1(x)
// quit if b is -inf
for i := 1; i < n && !IsInf(b, -1); i++ {
a, b = b, (float64(i+i)/x)*b-a
}
}
if sign {
return -b
}
return b
}
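// The same recurrence holds for the second kind (a sketch, assuming the
// exported math API): Y(2,x) = (2/x)*Y(1,x) - Y(0,x):
//
//	x := 1.5
//	fmt.Println(math.Yn(2, x))               // ~-0.9321937
//	fmt.Println(2/x*math.Y1(x) - math.Y0(x)) // agrees to within rounding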
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
// Ldexp is the inverse of Frexp.
// It returns frac × 2**exp.
//
// Special cases are:
//
// Ldexp(±0, exp) = ±0
// Ldexp(±Inf, exp) = ±Inf
// Ldexp(NaN, exp) = NaN
func Ldexp(frac float64, exp int) float64 {
if haveArchLdexp {
return archLdexp(frac, exp)
}
return ldexp(frac, exp)
}
func ldexp(frac float64, exp int) float64 {
// special cases
switch {
case frac == 0:
return frac // correctly return -0
case IsInf(frac, 0) || IsNaN(frac):
return frac
}
frac, e := normalize(frac)
exp += e
x := Float64bits(frac)
exp += int(x>>shift)&mask - bias
if exp < -1075 {
return Copysign(0, frac) // underflow
}
if exp > 1023 { // overflow
if frac < 0 {
return Inf(-1)
}
return Inf(1)
}
var m float64 = 1
if exp < -1022 { // denormal
exp += 53
m = 1.0 / (1 << 53) // 2**-53
}
x &^= mask << shift
x |= uint64(exp+bias) << shift
return m * Float64frombits(x)
}
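// Ldexp round-trips with Frexp for any finite value (a sketch, assuming the
// exported math API):
//
//	frac, exp := math.Frexp(12.5)
//	fmt.Println(math.Ldexp(frac, exp)) // 12.5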
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
/*
Floating-point logarithm of the Gamma function.
*/
// The original C code and the long comment below are
// from FreeBSD's /usr/src/lib/msun/src/e_lgamma_r.c and
// came with this notice. The go code is a simplified
// version of the original C.
//
// ====================================================
// Copyright (C) 1993 by Sun Microsystems, Inc. All rights reserved.
//
// Developed at SunPro, a Sun Microsystems, Inc. business.
// Permission to use, copy, modify, and distribute this
// software is freely granted, provided that this notice
// is preserved.
// ====================================================
//
// __ieee754_lgamma_r(x, signgamp)
// Reentrant version of the logarithm of the Gamma function
// with user provided pointer for the sign of Gamma(x).
//
// Method:
// 1. Argument Reduction for 0 < x <= 8
// Since gamma(1+s)=s*gamma(s), for x in [0,8], we may
// reduce x to a number in [1.5,2.5] by
// lgamma(1+s) = log(s) + lgamma(s)
// for example,
// lgamma(7.3) = log(6.3) + lgamma(6.3)
// = log(6.3*5.3) + lgamma(5.3)
// = log(6.3*5.3*4.3*3.3*2.3) + lgamma(2.3)
// 2. Polynomial approximation of lgamma around its
// minimum (ymin=1.461632144968362245) to maintain monotonicity.
// On [ymin-0.23, ymin+0.27] (i.e., [1.23164,1.73163]), let z = x-ymin
// and use
// lgamma(x) = -1.214862905358496078218 + z**2*poly(z)
// where poly(z) is a degree-14 polynomial.
// 3. Rational approximation in the primary interval [2,3]
// We use the following approximation:
// s = x-2.0;
// lgamma(x) = 0.5*s + s*P(s)/Q(s)
// with accuracy
// |P/Q - (lgamma(x)-0.5s)| < 2**-61.71
// Our algorithms are based on the following observation
//
// lgamma(2+s) = s*(1-Euler) + (zeta(2)-1)/2 * s**2 - (zeta(3)-1)/3 * s**3 + ...
//
// where Euler = 0.5772156649... is the Euler constant, which
// is very close to 0.5.
//
// 4. For x>=8, we have
// lgamma(x)~(x-0.5)log(x)-x+0.5*log(2pi)+1/(12x)-1/(360x**3)+....
// (better formula:
// lgamma(x)~(x-0.5)*(log(x)-1)-.5*(log(2pi)-1) + ...)
// Let z = 1/x, then we approximate
// f(z) = lgamma(x) - (x-0.5)(log(x)-1)
// by
// w = w0 + w1*z + w2*z**3 + w3*z**5 + ... + w6*z**11
// where
// |w - f(z)| < 2**-58.74
//
// 5. For negative x, since (G is the gamma function)
// -x*G(-x)*G(x) = pi/sin(pi*x),
// we have
// G(x) = pi/(sin(pi*x)*(-x)*G(-x))
// since G(-x) is positive, sign(G(x)) = sign(sin(pi*x)) for x<0
// Hence, for x<0, signgam = sign(sin(pi*x)) and
// lgamma(x) = log(|Gamma(x)|)
// = log(pi/(|x*sin(pi*x)|)) - lgamma(-x);
// Note: one should avoid computing pi*(-x) directly in the
// computation of sin(pi*(-x)).
//
// 6. Special Cases
// lgamma(2+s) ~ s*(1-Euler) for tiny s
// lgamma(1)=lgamma(2)=0
// lgamma(x) ~ -log(x) for tiny x
// lgamma(0) = lgamma(inf) = inf
// lgamma(-integer) = +-inf
//
//
var _lgamA = [...]float64{
7.72156649015328655494e-02, // 0x3FB3C467E37DB0C8
3.22467033424113591611e-01, // 0x3FD4A34CC4A60FAD
6.73523010531292681824e-02, // 0x3FB13E001A5562A7
2.05808084325167332806e-02, // 0x3F951322AC92547B
7.38555086081402883957e-03, // 0x3F7E404FB68FEFE8
2.89051383673415629091e-03, // 0x3F67ADD8CCB7926B
1.19270763183362067845e-03, // 0x3F538A94116F3F5D
5.10069792153511336608e-04, // 0x3F40B6C689B99C00
2.20862790713908385557e-04, // 0x3F2CF2ECED10E54D
1.08011567247583939954e-04, // 0x3F1C5088987DFB07
2.52144565451257326939e-05, // 0x3EFA7074428CFA52
4.48640949618915160150e-05, // 0x3F07858E90A45837
}
var _lgamR = [...]float64{
1.0, // placeholder
1.39200533467621045958e+00, // 0x3FF645A762C4AB74
7.21935547567138069525e-01, // 0x3FE71A1893D3DCDC
1.71933865632803078993e-01, // 0x3FC601EDCCFBDF27
1.86459191715652901344e-02, // 0x3F9317EA742ED475
7.77942496381893596434e-04, // 0x3F497DDACA41A95B
7.32668430744625636189e-06, // 0x3EDEBAF7A5B38140
}
var _lgamS = [...]float64{
-7.72156649015328655494e-02, // 0xBFB3C467E37DB0C8
2.14982415960608852501e-01, // 0x3FCB848B36E20878
3.25778796408930981787e-01, // 0x3FD4D98F4F139F59
1.46350472652464452805e-01, // 0x3FC2BB9CBEE5F2F7
2.66422703033638609560e-02, // 0x3F9B481C7E939961
1.84028451407337715652e-03, // 0x3F5E26B67368F239
3.19475326584100867617e-05, // 0x3F00BFECDD17E945
}
var _lgamT = [...]float64{
4.83836122723810047042e-01, // 0x3FDEF72BC8EE38A2
-1.47587722994593911752e-01, // 0xBFC2E4278DC6C509
6.46249402391333854778e-02, // 0x3FB08B4294D5419B
-3.27885410759859649565e-02, // 0xBFA0C9A8DF35B713
1.79706750811820387126e-02, // 0x3F9266E7970AF9EC
-1.03142241298341437450e-02, // 0xBF851F9FBA91EC6A
6.10053870246291332635e-03, // 0x3F78FCE0E370E344
-3.68452016781138256760e-03, // 0xBF6E2EFFB3E914D7
2.25964780900612472250e-03, // 0x3F6282D32E15C915
-1.40346469989232843813e-03, // 0xBF56FE8EBF2D1AF1
8.81081882437654011382e-04, // 0x3F4CDF0CEF61A8E9
-5.38595305356740546715e-04, // 0xBF41A6109C73E0EC
3.15632070903625950361e-04, // 0x3F34AF6D6C0EBBF7
-3.12754168375120860518e-04, // 0xBF347F24ECC38C38
3.35529192635519073543e-04, // 0x3F35FD3EE8C2D3F4
}
var _lgamU = [...]float64{
-7.72156649015328655494e-02, // 0xBFB3C467E37DB0C8
6.32827064025093366517e-01, // 0x3FE4401E8B005DFF
1.45492250137234768737e+00, // 0x3FF7475CD119BD6F
9.77717527963372745603e-01, // 0x3FEF497644EA8450
2.28963728064692451092e-01, // 0x3FCD4EAEF6010924
1.33810918536787660377e-02, // 0x3F8B678BBF2BAB09
}
var _lgamV = [...]float64{
1.0,
2.45597793713041134822e+00, // 0x4003A5D7C2BD619C
2.12848976379893395361e+00, // 0x40010725A42B18F5
7.69285150456672783825e-01, // 0x3FE89DFBE45050AF
1.04222645593369134254e-01, // 0x3FBAAE55D6537C88
3.21709242282423911810e-03, // 0x3F6A5ABB57D0CF61
}
var _lgamW = [...]float64{
4.18938533204672725052e-01, // 0x3FDACFE390C97D69
8.33333333333329678849e-02, // 0x3FB555555555553B
-2.77777777728775536470e-03, // 0xBF66C16C16B02E5C
7.93650558643019558500e-04, // 0x3F4A019F98CF38B6
-5.95187557450339963135e-04, // 0xBF4380CB8C0FE741
8.36339918996282139126e-04, // 0x3F4B67BA4CDAD5D1
-1.63092934096575273989e-03, // 0xBF5AB89D0B9E43E4
}
// Lgamma returns the natural logarithm and sign (-1 or +1) of Gamma(x).
//
// Special cases are:
//
// Lgamma(+Inf) = +Inf
// Lgamma(0) = +Inf
// Lgamma(-integer) = +Inf
// Lgamma(-Inf) = -Inf
// Lgamma(NaN) = NaN
func Lgamma(x float64) (lgamma float64, sign int) {
const (
Ymin = 1.461632144968362245
Two52 = 1 << 52 // 0x4330000000000000 ~4.5036e+15
Two53 = 1 << 53 // 0x4340000000000000 ~9.0072e+15
Two58 = 1 << 58 // 0x4390000000000000 ~2.8823e+17
Tiny = 1.0 / (1 << 70) // 0x3b90000000000000 ~8.47033e-22
Tc = 1.46163214496836224576e+00 // 0x3FF762D86356BE3F
Tf = -1.21486290535849611461e-01 // 0xBFBF19B9BCC38A42
// Tt = -(tail of Tf)
Tt = -3.63867699703950536541e-18 // 0xBC50C7CAA48A971F
)
// special cases
sign = 1
switch {
case IsNaN(x):
lgamma = x
return
case IsInf(x, 0):
lgamma = x
return
case x == 0:
lgamma = Inf(1)
return
}
neg := false
if x < 0 {
x = -x
neg = true
}
if x < Tiny { // if |x| < 2**-70, return -log(|x|)
if neg {
sign = -1
}
lgamma = -Log(x)
return
}
var nadj float64
if neg {
if x >= Two52 { // |x| >= 2**52, must be -integer
lgamma = Inf(1)
return
}
t := sinPi(x)
if t == 0 {
lgamma = Inf(1) // -integer
return
}
nadj = Log(Pi / Abs(t*x))
if t < 0 {
sign = -1
}
}
switch {
case x == 1 || x == 2: // purge off 1 and 2
lgamma = 0
return
case x < 2: // use lgamma(x) = lgamma(x+1) - log(x)
var y float64
var i int
if x <= 0.9 {
lgamma = -Log(x)
switch {
case x >= (Ymin - 1 + 0.27): // 0.7316 <= x <= 0.9
y = 1 - x
i = 0
case x >= (Ymin - 1 - 0.27): // 0.2316 <= x < 0.7316
y = x - (Tc - 1)
i = 1
default: // 0 < x < 0.2316
y = x
i = 2
}
} else {
lgamma = 0
switch {
case x >= (Ymin + 0.27): // 1.7316 <= x < 2
y = 2 - x
i = 0
case x >= (Ymin - 0.27): // 1.2316 <= x < 1.7316
y = x - Tc
i = 1
default: // 0.9 < x < 1.2316
y = x - 1
i = 2
}
}
switch i {
case 0:
z := y * y
p1 := _lgamA[0] + z*(_lgamA[2]+z*(_lgamA[4]+z*(_lgamA[6]+z*(_lgamA[8]+z*_lgamA[10]))))
p2 := z * (_lgamA[1] + z*(+_lgamA[3]+z*(_lgamA[5]+z*(_lgamA[7]+z*(_lgamA[9]+z*_lgamA[11])))))
p := y*p1 + p2
lgamma += (p - 0.5*y)
case 1:
z := y * y
w := z * y
p1 := _lgamT[0] + w*(_lgamT[3]+w*(_lgamT[6]+w*(_lgamT[9]+w*_lgamT[12]))) // parallel comp
p2 := _lgamT[1] + w*(_lgamT[4]+w*(_lgamT[7]+w*(_lgamT[10]+w*_lgamT[13])))
p3 := _lgamT[2] + w*(_lgamT[5]+w*(_lgamT[8]+w*(_lgamT[11]+w*_lgamT[14])))
p := z*p1 - (Tt - w*(p2+y*p3))
lgamma += (Tf + p)
case 2:
p1 := y * (_lgamU[0] + y*(_lgamU[1]+y*(_lgamU[2]+y*(_lgamU[3]+y*(_lgamU[4]+y*_lgamU[5])))))
p2 := 1 + y*(_lgamV[1]+y*(_lgamV[2]+y*(_lgamV[3]+y*(_lgamV[4]+y*_lgamV[5]))))
lgamma += (-0.5*y + p1/p2)
}
case x < 8: // 2 <= x < 8
i := int(x)
y := x - float64(i)
p := y * (_lgamS[0] + y*(_lgamS[1]+y*(_lgamS[2]+y*(_lgamS[3]+y*(_lgamS[4]+y*(_lgamS[5]+y*_lgamS[6]))))))
q := 1 + y*(_lgamR[1]+y*(_lgamR[2]+y*(_lgamR[3]+y*(_lgamR[4]+y*(_lgamR[5]+y*_lgamR[6])))))
lgamma = 0.5*y + p/q
z := 1.0 // Lgamma(1+s) = Log(s) + Lgamma(s)
switch i {
case 7:
z *= (y + 6)
fallthrough
case 6:
z *= (y + 5)
fallthrough
case 5:
z *= (y + 4)
fallthrough
case 4:
z *= (y + 3)
fallthrough
case 3:
z *= (y + 2)
lgamma += Log(z)
}
case x < Two58: // 8 <= x < 2**58
t := Log(x)
z := 1 / x
y := z * z
w := _lgamW[0] + z*(_lgamW[1]+y*(_lgamW[2]+y*(_lgamW[3]+y*(_lgamW[4]+y*(_lgamW[5]+y*_lgamW[6])))))
lgamma = (x-0.5)*(t-1) + w
default: // 2**58 <= x <= Inf
lgamma = x * (Log(x) - 1)
}
if neg {
lgamma = nadj - lgamma
}
return
}
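// Gamma can be recovered from Lgamma when it does not overflow (a sketch,
// assuming the exported math API):
//
//	lg, sign := math.Lgamma(-0.5)
//	fmt.Println(float64(sign) * math.Exp(lg)) // ~-3.5449, i.e. Gamma(-0.5) = -2*Sqrt(Pi)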
// sinPi(x) is a helper function for negative x
func sinPi(x float64) float64 {
const (
Two52 = 1 << 52 // 0x4330000000000000 ~4.5036e+15
Two53 = 1 << 53 // 0x4340000000000000 ~9.0072e+15
)
if x < 0.25 {
return -Sin(Pi * x)
}
// argument reduction
z := Floor(x)
var n int
if z != x { // inexact
x = Mod(x, 2)
n = int(x * 4)
} else {
if x >= Two53 { // x must be even
x = 0
n = 0
} else {
if x < Two52 {
z = x + Two52 // exact
}
n = int(1 & Float64bits(z))
x = float64(n)
n <<= 2
}
}
switch n {
case 0:
x = Sin(Pi * x)
case 1, 2:
x = Cos(Pi * (0.5 - x))
case 3, 4:
x = Sin(Pi * (1 - x))
case 5, 6:
x = -Cos(Pi * (x - 1.5))
default:
x = Sin(Pi * (x - 2))
}
return -x
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
/*
Floating-point logarithm.
*/
// The original C code, the long comment, and the constants
// below are from FreeBSD's /usr/src/lib/msun/src/e_log.c
// and came with this notice. The go code is a simpler
// version of the original C.
//
// ====================================================
// Copyright (C) 1993 by Sun Microsystems, Inc. All rights reserved.
//
// Developed at SunPro, a Sun Microsystems, Inc. business.
// Permission to use, copy, modify, and distribute this
// software is freely granted, provided that this notice
// is preserved.
// ====================================================
//
// __ieee754_log(x)
// Return the logarithm of x
//
// Method :
// 1. Argument Reduction: find k and f such that
// x = 2**k * (1+f),
// where sqrt(2)/2 < 1+f < sqrt(2) .
//
// 2. Approximation of log(1+f).
// Let s = f/(2+f) ; based on log(1+f) = log(1+s) - log(1-s)
// = 2s + 2/3 s**3 + 2/5 s**5 + .....,
// = 2s + s*R
// We use a special Remez algorithm on [0,0.1716] to generate
// a polynomial of degree 14 to approximate R. The maximum error
// of this polynomial approximation is bounded by 2**-58.45. In
// other words,
// R(z) ~ L1*s**2 + L2*s**4 + L3*s**6 + L4*s**8 + L5*s**10 + L6*s**12 + L7*s**14
// (the values of L1 to L7 are listed in the program) and
// |L1*s**2 + ... + L7*s**14 - R(z)| <= 2**-58.45
// Note that 2s = f - s*f = f - hfsq + s*hfsq, where hfsq = f*f/2.
// In order to guarantee error in log below 1 ulp, we compute log by
// log(1+f) = f - s*(f - R) (if f is not too large)
// log(1+f) = f - (hfsq - s*(hfsq+R)). (better accuracy)
//
// 3. Finally, log(x) = k*Ln2 + log(1+f).
// = k*Ln2_hi+(f-(hfsq-(s*(hfsq+R)+k*Ln2_lo)))
// Here Ln2 is split into two floating point numbers:
// Ln2_hi + Ln2_lo,
// where n*Ln2_hi is always exact for |n| < 2000.
//
// Special cases:
// log(x) is NaN with signal if x < 0 (including -INF) ;
// log(+INF) is +INF; log(0) is -INF with signal;
// log(NaN) is that NaN with no signal.
//
// Accuracy:
// according to an error analysis, the error is always less than
// 1 ulp (unit in the last place).
//
// Constants:
// The hexadecimal values are the intended ones for the following
// constants. The decimal values may be used, provided that the
// compiler will convert from decimal to binary accurately enough
// to produce the hexadecimal values shown.
// Log returns the natural logarithm of x.
//
// Special cases are:
//
// Log(+Inf) = +Inf
// Log(0) = -Inf
// Log(x < 0) = NaN
// Log(NaN) = NaN
func Log(x float64) float64 {
if haveArchLog {
return archLog(x)
}
return log(x)
}
func log(x float64) float64 {
const (
Ln2Hi = 6.93147180369123816490e-01 /* 3fe62e42 fee00000 */
Ln2Lo = 1.90821492927058770002e-10 /* 3dea39ef 35793c76 */
L1 = 6.666666666666735130e-01 /* 3FE55555 55555593 */
L2 = 3.999999999940941908e-01 /* 3FD99999 9997FA04 */
L3 = 2.857142874366239149e-01 /* 3FD24924 94229359 */
L4 = 2.222219843214978396e-01 /* 3FCC71C5 1D8E78AF */
L5 = 1.818357216161805012e-01 /* 3FC74664 96CB03DE */
L6 = 1.531383769920937332e-01 /* 3FC39A09 D078C69F */
L7 = 1.479819860511658591e-01 /* 3FC2F112 DF3E5244 */
)
// special cases
switch {
case IsNaN(x) || IsInf(x, 1):
return x
case x < 0:
return NaN()
case x == 0:
return Inf(-1)
}
// reduce
f1, ki := Frexp(x)
if f1 < Sqrt2/2 {
f1 *= 2
ki--
}
f := f1 - 1
k := float64(ki)
// compute
s := f / (2 + f)
s2 := s * s
s4 := s2 * s2
t1 := s2 * (L1 + s4*(L3+s4*(L5+s4*L7)))
t2 := s4 * (L2 + s4*(L4+s4*L6))
R := t1 + t2
hfsq := 0.5 * f * f
return k*Ln2Hi - ((hfsq - (s*(hfsq+R) + k*Ln2Lo)) - f)
}
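// A minimal usage sketch (not part of the original source) exercising the
// documented special cases of Log; the function name is hypothetical.
func logSpecialCasesSketch() {
	_ = Log(E)                // 1
	_ = Log(1)                // 0
	_ = Log(0)                // -Inf
	_ = Log(-1)               // NaN
	_ = IsInf(Log(Inf(1)), 1) // true
}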
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
// Log10 returns the decimal logarithm of x.
// The special cases are the same as for Log.
func Log10(x float64) float64 {
if haveArchLog10 {
return archLog10(x)
}
return log10(x)
}
func log10(x float64) float64 {
return Log(x) * (1 / Ln10)
}
// Log2 returns the binary logarithm of x.
// The special cases are the same as for Log.
func Log2(x float64) float64 {
if haveArchLog2 {
return archLog2(x)
}
return log2(x)
}
func log2(x float64) float64 {
frac, exp := Frexp(x)
// Make sure exact powers of two give an exact answer.
// Don't depend on Log(0.5)*(1/Ln2)+exp being exactly exp-1.
if frac == 0.5 {
return float64(exp - 1)
}
return Log(frac)*(1/Ln2) + float64(exp)
}
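// Illustrative sketch (not part of the original source): Log10 and Log2
// derive from Log via the change-of-base identity log_b(x) = ln(x)/ln(b);
// log2 additionally special-cases exact powers of two so that Log2(1<<k)
// is exactly k. The function name is hypothetical.
func logChangeOfBaseSketch() {
	_ = Log10(1000) // 3, computed as Log(1000) * (1/Ln10)
	_ = Log2(8)     // exactly 3, via the frac == 0.5 fast path
}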
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
// The original C code, the long comment, and the constants
// below are from FreeBSD's /usr/src/lib/msun/src/s_log1p.c
// and came with this notice. The go code is a simplified
// version of the original C.
//
// ====================================================
// Copyright (C) 1993 by Sun Microsystems, Inc. All rights reserved.
//
// Developed at SunPro, a Sun Microsystems, Inc. business.
// Permission to use, copy, modify, and distribute this
// software is freely granted, provided that this notice
// is preserved.
// ====================================================
//
//
// double log1p(double x)
//
// Method :
// 1. Argument Reduction: find k and f such that
// 1+x = 2**k * (1+f),
// where sqrt(2)/2 < 1+f < sqrt(2) .
//
// Note. If k=0, then f=x is exact. However, if k!=0, then f
// may not be representable exactly. In that case, a correction
// term is needed. Let u=1+x rounded. Let c = (1+x)-u, then
// log(1+x) - log(u) ~ c/u. Thus, we proceed to compute log(u),
// and add back the correction term c/u.
// (Note: when x > 2**53, one can simply return log(x))
//
// 2. Approximation of log1p(f).
// Let s = f/(2+f) ; based on log(1+f) = log(1+s) - log(1-s)
// = 2s + 2/3 s**3 + 2/5 s**5 + .....,
// = 2s + s*R
// We use a special Remez algorithm on [0,0.1716] to generate
// a polynomial of degree 14 to approximate R. The maximum error
// of this polynomial approximation is bounded by 2**-58.45. In
// other words,
// R(z) ~ Lp1*s**2 + Lp2*s**4 + Lp3*s**6 + Lp4*s**8 + Lp5*s**10 + Lp6*s**12 + Lp7*s**14
// (the values of Lp1 to Lp7 are listed in the program)
// and
// |Lp1*s**2 + ... + Lp7*s**14 - R(z)| <= 2**-58.45
// Note that 2s = f - s*f = f - hfsq + s*hfsq, where hfsq = f*f/2.
// In order to guarantee error in log below 1 ulp, we compute log
// by
// log1p(f) = f - (hfsq - s*(hfsq+R)).
//
// 3. Finally, log1p(x) = k*ln2 + log1p(f).
// = k*ln2_hi+(f-(hfsq-(s*(hfsq+R)+k*ln2_lo)))
// Here ln2 is split into two floating point numbers:
// ln2_hi + ln2_lo,
// where n*ln2_hi is always exact for |n| < 2000.
//
// Special cases:
// log1p(x) is NaN with signal if x < -1 (including -INF) ;
// log1p(+INF) is +INF; log1p(-1) is -INF with signal;
// log1p(NaN) is that NaN with no signal.
//
// Accuracy:
// according to an error analysis, the error is always less than
// 1 ulp (unit in the last place).
//
// Constants:
// The hexadecimal values are the intended ones for the following
// constants. The decimal values may be used, provided that the
// compiler will convert from decimal to binary accurately enough
// to produce the hexadecimal values shown.
//
// Note: Assuming log() returns an accurate answer, the following
// algorithm can be used to compute log1p(x) to within a few ULP:
//
// u = 1+x;
// if(u==1.0) return x ; else
// return log(u)*(x/(u-1.0));
//
// See HP-15C Advanced Functions Handbook, p.193.
// Log1p returns the natural logarithm of 1 plus its argument x.
// It is more accurate than Log(1 + x) when x is near zero.
//
// Special cases are:
//
// Log1p(+Inf) = +Inf
// Log1p(±0) = ±0
// Log1p(-1) = -Inf
// Log1p(x < -1) = NaN
// Log1p(NaN) = NaN
func Log1p(x float64) float64 {
if haveArchLog1p {
return archLog1p(x)
}
return log1p(x)
}
func log1p(x float64) float64 {
const (
Sqrt2M1 = 4.142135623730950488017e-01 // Sqrt(2)-1 = 0x3fda827999fcef34
Sqrt2HalfM1 = -2.928932188134524755992e-01 // Sqrt(2)/2-1 = 0xbfd2bec333018866
Small = 1.0 / (1 << 29) // 2**-29 = 0x3e20000000000000
Tiny = 1.0 / (1 << 54) // 2**-54
Two53 = 1 << 53 // 2**53
Ln2Hi = 6.93147180369123816490e-01 // 3fe62e42fee00000
Ln2Lo = 1.90821492927058770002e-10 // 3dea39ef35793c76
Lp1 = 6.666666666666735130e-01 // 3FE5555555555593
Lp2 = 3.999999999940941908e-01 // 3FD999999997FA04
Lp3 = 2.857142874366239149e-01 // 3FD2492494229359
Lp4 = 2.222219843214978396e-01 // 3FCC71C51D8E78AF
Lp5 = 1.818357216161805012e-01 // 3FC7466496CB03DE
Lp6 = 1.531383769920937332e-01 // 3FC39A09D078C69F
Lp7 = 1.479819860511658591e-01 // 3FC2F112DF3E5244
)
// special cases
switch {
case x < -1 || IsNaN(x): // includes -Inf
return NaN()
case x == -1:
return Inf(-1)
case IsInf(x, 1):
return Inf(1)
}
absx := Abs(x)
var f float64
var iu uint64
k := 1
if absx < Sqrt2M1 { // |x| < Sqrt(2)-1
if absx < Small { // |x| < 2**-29
if absx < Tiny { // |x| < 2**-54
return x
}
return x - x*x*0.5
}
if x > Sqrt2HalfM1 { // Sqrt(2)/2-1 < x
// (Sqrt(2)/2-1) < x < (Sqrt(2)-1)
k = 0
f = x
iu = 1
}
}
var c float64
if k != 0 {
var u float64
if absx < Two53 { // 1<<53
u = 1.0 + x
iu = Float64bits(u)
k = int((iu >> 52) - 1023)
// correction term
if k > 0 {
c = 1.0 - (u - x)
} else {
c = x - (u - 1.0)
}
c /= u
} else {
u = x
iu = Float64bits(u)
k = int((iu >> 52) - 1023)
c = 0
}
iu &= 0x000fffffffffffff
if iu < 0x0006a09e667f3bcd { // mantissa of Sqrt(2)
u = Float64frombits(iu | 0x3ff0000000000000) // normalize u
} else {
k++
u = Float64frombits(iu | 0x3fe0000000000000) // normalize u/2
iu = (0x0010000000000000 - iu) >> 2
}
f = u - 1.0 // Sqrt(2)/2 < u < Sqrt(2)
}
hfsq := 0.5 * f * f
var s, R, z float64
if iu == 0 { // |f| < 2**-20
if f == 0 {
if k == 0 {
return 0
}
c += float64(k) * Ln2Lo
return float64(k)*Ln2Hi + c
}
R = hfsq * (1.0 - 0.66666666666666666*f) // avoid division
if k == 0 {
return f - R
}
return float64(k)*Ln2Hi - ((R - (float64(k)*Ln2Lo + c)) - f)
}
s = f / (2.0 + f)
z = s * s
R = z * (Lp1 + z*(Lp2+z*(Lp3+z*(Lp4+z*(Lp5+z*(Lp6+z*Lp7))))))
if k == 0 {
return f - (hfsq - s*(hfsq+R))
}
return float64(k)*Ln2Hi - ((hfsq - (s*(hfsq+R) + (float64(k)*Ln2Lo + c))) - f)
}
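// Illustrative sketch (not part of the original source): for tiny x the sum
// 1+x rounds to exactly 1, so Log(1+x) returns 0 while Log1p keeps full
// precision. The function name is hypothetical.
func log1pAccuracySketch() {
	x := 1e-20
	_ = Log(1 + x) // 0: 1+x rounds to 1 in float64
	_ = Log1p(x)   // ~1e-20, accurate
}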
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
// Logb returns the binary exponent of x.
//
// Special cases are:
//
// Logb(±Inf) = +Inf
// Logb(0) = -Inf
// Logb(NaN) = NaN
func Logb(x float64) float64 {
// special cases
switch {
case x == 0:
return Inf(-1)
case IsInf(x, 0):
return Inf(1)
case IsNaN(x):
return x
}
return float64(ilogb(x))
}
// Ilogb returns the binary exponent of x as an integer.
//
// Special cases are:
//
// Ilogb(±Inf) = MaxInt32
// Ilogb(0) = MinInt32
// Ilogb(NaN) = MaxInt32
func Ilogb(x float64) int {
// special cases
switch {
case x == 0:
return MinInt32
case IsNaN(x):
return MaxInt32
case IsInf(x, 0):
return MaxInt32
}
return ilogb(x)
}
// ilogb returns the binary exponent of x. It assumes x is finite and
// non-zero.
func ilogb(x float64) int {
x, exp := normalize(x)
return int((Float64bits(x)>>shift)&mask) - bias + exp
}
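// Illustrative sketch (not part of the original source): Ilogb extracts the
// unbiased binary exponent, so for finite non-zero x it equals
// Floor(Log2(Abs(x))). The function name is hypothetical.
func ilogbSketch() {
	_ = Ilogb(8)    // 3
	_ = Ilogb(0.25) // -2
	_ = Logb(10)    // 3, since 10 = 1.25 * 2**3
}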
// Copyright 2009-2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
/*
Floating-point mod function.
*/
// Mod returns the floating-point remainder of x/y.
// The magnitude of the result is less than y and its
// sign agrees with that of x.
//
// Special cases are:
//
// Mod(±Inf, y) = NaN
// Mod(NaN, y) = NaN
// Mod(x, 0) = NaN
// Mod(x, ±Inf) = x
// Mod(x, NaN) = NaN
func Mod(x, y float64) float64 {
if haveArchMod {
return archMod(x, y)
}
return mod(x, y)
}
func mod(x, y float64) float64 {
if y == 0 || IsInf(x, 0) || IsNaN(x) || IsNaN(y) {
return NaN()
}
y = Abs(y)
yfr, yexp := Frexp(y)
r := x
if x < 0 {
r = -x
}
for r >= y {
rfr, rexp := Frexp(r)
if rfr < yfr {
rexp = rexp - 1
}
r = r - Ldexp(y, rexp-yexp)
}
if x < 0 {
r = -r
}
return r
}
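// Illustrative sketch (not part of the original source): the result of Mod
// follows the sign of x and ignores the sign of y, per the doc comment
// above. The function name is hypothetical.
func modSignSketch() {
	_ = Mod(7, 3)  // 1
	_ = Mod(-7, 3) // -1: sign agrees with x
	_ = Mod(7, -3) // 1: y is taken by magnitude
}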
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
// Modf returns integer and fractional floating-point numbers
// that sum to f. Both values have the same sign as f.
//
// Special cases are:
//
// Modf(±Inf) = ±Inf, NaN
// Modf(NaN) = NaN, NaN
func Modf(f float64) (int float64, frac float64) {
if haveArchModf {
return archModf(f)
}
return modf(f)
}
func modf(f float64) (int float64, frac float64) {
if f < 1 {
switch {
case f < 0:
int, frac = Modf(-f)
return -int, -frac
case f == 0:
return f, f // Return -0, -0 when f == -0
}
return 0, f
}
x := Float64bits(f)
e := uint(x>>shift)&mask - bias
// Keep the top 12+e bits, the integer part; clear the rest.
if e < 64-12 {
x &^= 1<<(64-12-e) - 1
}
int = Float64frombits(x)
frac = f - int
return
}
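// Illustrative sketch (not part of the original source): both return values
// of Modf carry the sign of f. The function name is hypothetical.
func modfSketch() {
	i, frac := Modf(3.75) // 3, 0.75
	j, g := Modf(-3.75)   // -3, -0.75
	_, _, _, _ = i, frac, j, g
}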
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !arm64 && !ppc64 && !ppc64le
package math
const haveArchModf = false
func archModf(f float64) (int float64, frac float64) {
panic("not implemented")
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
// Nextafter32 returns the next representable float32 value after x towards y.
//
// Special cases are:
//
// Nextafter32(x, x) = x
// Nextafter32(NaN, y) = NaN
// Nextafter32(x, NaN) = NaN
func Nextafter32(x, y float32) (r float32) {
switch {
case IsNaN(float64(x)) || IsNaN(float64(y)): // special case
r = float32(NaN())
case x == y:
r = x
case x == 0:
r = float32(Copysign(float64(Float32frombits(1)), float64(y)))
case (y > x) == (x > 0):
r = Float32frombits(Float32bits(x) + 1)
default:
r = Float32frombits(Float32bits(x) - 1)
}
return
}
// Nextafter returns the next representable float64 value after x towards y.
//
// Special cases are:
//
// Nextafter(x, x) = x
// Nextafter(NaN, y) = NaN
// Nextafter(x, NaN) = NaN
func Nextafter(x, y float64) (r float64) {
switch {
case IsNaN(x) || IsNaN(y): // special case
r = NaN()
case x == y:
r = x
case x == 0:
r = Copysign(Float64frombits(1), y)
case (y > x) == (x > 0):
r = Float64frombits(Float64bits(x) + 1)
default:
r = Float64frombits(Float64bits(x) - 1)
}
return
}
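// Illustrative sketch (not part of the original source): Nextafter moves by
// exactly one ULP, which makes it useful for probing float64 spacing. The
// function name is hypothetical.
func nextafterSketch() {
	_ = Nextafter(1, 2) - 1 // 2**-52, the spacing just above 1
	_ = Nextafter(0, 1)     // smallest positive subnormal, Float64frombits(1)
}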
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
func isOddInt(x float64) bool {
if Abs(x) >= (1 << 53) {
// 1 << 53 is the largest exact integer in the float64 format.
// Any number outside this range will be truncated before the decimal point and therefore will always be
// an even integer.
// Without this check, if x overflows int64, the int64(xi) conversion below may produce incorrect results
// on some architectures (and does so on arm64). See issue #57465.
return false
}
xi, xf := Modf(x)
return xf == 0 && int64(xi)&1 == 1
}
// Special cases taken from FreeBSD's /usr/src/lib/msun/src/e_pow.c
// updated by IEEE Std. 754-2008 "Section 9.2.1 Special values".
// Pow returns x**y, the base-x exponential of y.
//
// Special cases are (in order):
//
// Pow(x, ±0) = 1 for any x
// Pow(1, y) = 1 for any y
// Pow(x, 1) = x for any x
// Pow(NaN, y) = NaN
// Pow(x, NaN) = NaN
// Pow(±0, y) = ±Inf for y an odd integer < 0
// Pow(±0, -Inf) = +Inf
// Pow(±0, +Inf) = +0
// Pow(±0, y) = +Inf for finite y < 0 and not an odd integer
// Pow(±0, y) = ±0 for y an odd integer > 0
// Pow(±0, y) = +0 for finite y > 0 and not an odd integer
// Pow(-1, ±Inf) = 1
// Pow(x, +Inf) = +Inf for |x| > 1
// Pow(x, -Inf) = +0 for |x| > 1
// Pow(x, +Inf) = +0 for |x| < 1
// Pow(x, -Inf) = +Inf for |x| < 1
// Pow(+Inf, y) = +Inf for y > 0
// Pow(+Inf, y) = +0 for y < 0
// Pow(-Inf, y) = Pow(-0, -y)
// Pow(x, y) = NaN for finite x < 0 and finite non-integer y
func Pow(x, y float64) float64 {
if haveArchPow {
return archPow(x, y)
}
return pow(x, y)
}
func pow(x, y float64) float64 {
switch {
case y == 0 || x == 1:
return 1
case y == 1:
return x
case IsNaN(x) || IsNaN(y):
return NaN()
case x == 0:
switch {
case y < 0:
if Signbit(x) && isOddInt(y) {
return Inf(-1)
}
return Inf(1)
case y > 0:
if Signbit(x) && isOddInt(y) {
return x
}
return 0
}
case IsInf(y, 0):
switch {
case x == -1:
return 1
case (Abs(x) < 1) == IsInf(y, 1):
return 0
default:
return Inf(1)
}
case IsInf(x, 0):
if IsInf(x, -1) {
return Pow(1/x, -y) // Pow(-0, -y)
}
switch {
case y < 0:
return 0
case y > 0:
return Inf(1)
}
case y == 0.5:
return Sqrt(x)
case y == -0.5:
return 1 / Sqrt(x)
}
yi, yf := Modf(Abs(y))
if yf != 0 && x < 0 {
return NaN()
}
if yi >= 1<<63 {
// yi is a large even int that will lead to overflow (or underflow to 0)
// for all x except -1 (x == 1 was handled earlier)
switch {
case x == -1:
return 1
case (Abs(x) < 1) == (y > 0):
return 0
default:
return Inf(1)
}
}
// ans = a1 * 2**ae (= 1 for now).
a1 := 1.0
ae := 0
// ans *= x**yf
if yf != 0 {
if yf > 0.5 {
yf--
yi++
}
a1 = Exp(yf * Log(x))
}
// ans *= x**yi
// by multiplying in successive squarings
// of x according to bits of yi.
// accumulate powers of two into exp.
x1, xe := Frexp(x)
for i := int64(yi); i != 0; i >>= 1 {
if xe < -1<<12 || 1<<12 < xe {
// catch xe before it overflows the left shift below.
// Since i != 0 it has at least one bit still set, so ae will accumulate xe
// on at least one more iteration; ae += xe is thus a lower bound on ae.
// That lower bound already exceeds the float64 exponent range,
// so the final call to Ldexp will produce under/overflow (0/Inf).
ae += xe
break
}
if i&1 == 1 {
a1 *= x1
ae += xe
}
x1 *= x1
xe <<= 1
if x1 < .5 {
x1 += x1
xe--
}
}
// ans = a1*2**ae
// if y < 0 { ans = 1 / ans }
// but in the opposite order
if y < 0 {
a1 = 1 / a1
ae = -ae
}
return Ldexp(a1, ae)
}
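// Illustrative sketch (not part of the original source) exercising a few of
// the documented Pow special cases; the function name is hypothetical.
func powSpecialCasesSketch() {
	_ = Pow(NaN(), 0)   // 1: Pow(x, ±0) = 1 for any x, even NaN
	_ = Pow(-2, 0.5)    // NaN: finite x < 0 with non-integer y
	_ = Pow(Inf(-1), 3) // -Inf: handled as Pow(-0, -3)
}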
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
// pow10tab stores the pre-computed values 10**i for i < 32.
var pow10tab = [...]float64{
1e00, 1e01, 1e02, 1e03, 1e04, 1e05, 1e06, 1e07, 1e08, 1e09,
1e10, 1e11, 1e12, 1e13, 1e14, 1e15, 1e16, 1e17, 1e18, 1e19,
1e20, 1e21, 1e22, 1e23, 1e24, 1e25, 1e26, 1e27, 1e28, 1e29,
1e30, 1e31,
}
// pow10postab32 stores the pre-computed value for 10**(i*32) at index i.
var pow10postab32 = [...]float64{
1e00, 1e32, 1e64, 1e96, 1e128, 1e160, 1e192, 1e224, 1e256, 1e288,
}
// pow10negtab32 stores the pre-computed value for 10**(-i*32) at index i.
var pow10negtab32 = [...]float64{
1e-00, 1e-32, 1e-64, 1e-96, 1e-128, 1e-160, 1e-192, 1e-224, 1e-256, 1e-288, 1e-320,
}
// Pow10 returns 10**n, the base-10 exponential of n.
//
// Special cases are:
//
// Pow10(n) = 0 for n < -323
// Pow10(n) = +Inf for n > 308
func Pow10(n int) float64 {
if 0 <= n && n <= 308 {
return pow10postab32[uint(n)/32] * pow10tab[uint(n)%32]
}
if -323 <= n && n <= 0 {
return pow10negtab32[uint(-n)/32] / pow10tab[uint(-n)%32]
}
// n < -323 || 308 < n
if n > 0 {
return Inf(1)
}
// n < -323
return 0
}
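// Illustrative sketch (not part of the original source): the table lookups
// above cover exactly the finite decade range of float64. The function name
// is hypothetical.
func pow10RangeSketch() {
	_ = Pow10(308)  // largest finite power of ten
	_ = Pow10(309)  // +Inf
	_ = Pow10(-323) // smallest positive (subnormal) power of ten
	_ = Pow10(-324) // 0
}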
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package rand
import (
"math"
)
/*
* Exponential distribution
*
* See "The Ziggurat Method for Generating Random Variables"
* (Marsaglia & Tsang, 2000)
* https://www.jstatsoft.org/v05/i08/paper [pdf]
*/
const (
re = 7.69711747013104972
)
// ExpFloat64 returns an exponentially distributed float64 in the range
// (0, +math.MaxFloat64] with an exponential distribution whose rate parameter
// (lambda) is 1 and whose mean is 1/lambda (1).
// To produce a distribution with a different rate parameter,
// callers can adjust the output using:
//
// sample = ExpFloat64() / desiredRateParameter
func (r *Rand) ExpFloat64() float64 {
for {
j := r.Uint32()
i := j & 0xFF
x := float64(j) * float64(we[i])
if j < ke[i] {
return x
}
if i == 0 {
return re - math.Log(r.Float64())
}
if fe[i]+float32(r.Float64())*(fe[i-1]-fe[i]) < float32(math.Exp(-x)) {
return x
}
}
}
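// Illustrative sketch (not part of the original source): rescaling the
// ziggurat output to a different rate parameter, per the doc comment above.
// The function name and lambda value are hypothetical.
func expFloat64RateSketch(r *Rand) float64 {
	const lambda = 2.5
	return r.ExpFloat64() / lambda // exponential with rate lambda, mean 1/lambda
}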
var ke = [256]uint32{
0xe290a139, 0x0, 0x9beadebc, 0xc377ac71, 0xd4ddb990,
0xde893fb8, 0xe4a8e87c, 0xe8dff16a, 0xebf2deab, 0xee49a6e8,
0xf0204efd, 0xf19bdb8e, 0xf2d458bb, 0xf3da104b, 0xf4b86d78,
0xf577ad8a, 0xf61de83d, 0xf6afb784, 0xf730a573, 0xf7a37651,
0xf80a5bb6, 0xf867189d, 0xf8bb1b4f, 0xf9079062, 0xf94d70ca,
0xf98d8c7d, 0xf9c8928a, 0xf9ff175b, 0xfa319996, 0xfa6085f8,
0xfa8c3a62, 0xfab5084e, 0xfadb36c8, 0xfaff0410, 0xfb20a6ea,
0xfb404fb4, 0xfb5e2951, 0xfb7a59e9, 0xfb95038c, 0xfbae44ba,
0xfbc638d8, 0xfbdcf892, 0xfbf29a30, 0xfc0731df, 0xfc1ad1ed,
0xfc2d8b02, 0xfc3f6c4d, 0xfc5083ac, 0xfc60ddd1, 0xfc708662,
0xfc7f8810, 0xfc8decb4, 0xfc9bbd62, 0xfca9027c, 0xfcb5c3c3,
0xfcc20864, 0xfccdd70a, 0xfcd935e3, 0xfce42ab0, 0xfceebace,
0xfcf8eb3b, 0xfd02c0a0, 0xfd0c3f59, 0xfd156b7b, 0xfd1e48d6,
0xfd26daff, 0xfd2f2552, 0xfd372af7, 0xfd3eeee5, 0xfd4673e7,
0xfd4dbc9e, 0xfd54cb85, 0xfd5ba2f2, 0xfd62451b, 0xfd68b415,
0xfd6ef1da, 0xfd750047, 0xfd7ae120, 0xfd809612, 0xfd8620b4,
0xfd8b8285, 0xfd90bcf5, 0xfd95d15e, 0xfd9ac10b, 0xfd9f8d36,
0xfda43708, 0xfda8bf9e, 0xfdad2806, 0xfdb17141, 0xfdb59c46,
0xfdb9a9fd, 0xfdbd9b46, 0xfdc170f6, 0xfdc52bd8, 0xfdc8ccac,
0xfdcc542d, 0xfdcfc30b, 0xfdd319ef, 0xfdd6597a, 0xfdd98245,
0xfddc94e5, 0xfddf91e6, 0xfde279ce, 0xfde54d1f, 0xfde80c52,
0xfdeab7de, 0xfded5034, 0xfdefd5be, 0xfdf248e3, 0xfdf4aa06,
0xfdf6f984, 0xfdf937b6, 0xfdfb64f4, 0xfdfd818d, 0xfdff8dd0,
0xfe018a08, 0xfe03767a, 0xfe05536c, 0xfe07211c, 0xfe08dfc9,
0xfe0a8fab, 0xfe0c30fb, 0xfe0dc3ec, 0xfe0f48b1, 0xfe10bf76,
0xfe122869, 0xfe1383b4, 0xfe14d17c, 0xfe1611e7, 0xfe174516,
0xfe186b2a, 0xfe19843e, 0xfe1a9070, 0xfe1b8fd6, 0xfe1c8289,
0xfe1d689b, 0xfe1e4220, 0xfe1f0f26, 0xfe1fcfbc, 0xfe2083ed,
0xfe212bc3, 0xfe21c745, 0xfe225678, 0xfe22d95f, 0xfe234ffb,
0xfe23ba4a, 0xfe241849, 0xfe2469f2, 0xfe24af3c, 0xfe24e81e,
0xfe25148b, 0xfe253474, 0xfe2547c7, 0xfe254e70, 0xfe25485a,
0xfe25356a, 0xfe251586, 0xfe24e88f, 0xfe24ae64, 0xfe2466e1,
0xfe2411df, 0xfe23af34, 0xfe233eb4, 0xfe22c02c, 0xfe22336b,
0xfe219838, 0xfe20ee58, 0xfe20358c, 0xfe1f6d92, 0xfe1e9621,
0xfe1daef0, 0xfe1cb7ac, 0xfe1bb002, 0xfe1a9798, 0xfe196e0d,
0xfe1832fd, 0xfe16e5fe, 0xfe15869d, 0xfe141464, 0xfe128ed3,
0xfe10f565, 0xfe0f478c, 0xfe0d84b1, 0xfe0bac36, 0xfe09bd73,
0xfe07b7b5, 0xfe059a40, 0xfe03644c, 0xfe011504, 0xfdfeab88,
0xfdfc26e9, 0xfdf98629, 0xfdf6c83b, 0xfdf3ec01, 0xfdf0f04a,
0xfdedd3d1, 0xfdea953d, 0xfde7331e, 0xfde3abe9, 0xfddffdfb,
0xfddc2791, 0xfdd826cd, 0xfdd3f9a8, 0xfdcf9dfc, 0xfdcb1176,
0xfdc65198, 0xfdc15bb3, 0xfdbc2ce2, 0xfdb6c206, 0xfdb117be,
0xfdab2a63, 0xfda4f5fd, 0xfd9e7640, 0xfd97a67a, 0xfd908192,
0xfd8901f2, 0xfd812182, 0xfd78d98e, 0xfd7022bb, 0xfd66f4ed,
0xfd5d4732, 0xfd530f9c, 0xfd48432b, 0xfd3cd59a, 0xfd30b936,
0xfd23dea4, 0xfd16349e, 0xfd07a7a3, 0xfcf8219b, 0xfce7895b,
0xfcd5c220, 0xfcc2aadb, 0xfcae1d5e, 0xfc97ed4e, 0xfc7fe6d4,
0xfc65ccf3, 0xfc495762, 0xfc2a2fc8, 0xfc07ee19, 0xfbe213c1,
0xfbb8051a, 0xfb890078, 0xfb5411a5, 0xfb180005, 0xfad33482,
0xfa839276, 0xfa263b32, 0xf9b72d1c, 0xf930a1a2, 0xf889f023,
0xf7b577d2, 0xf69c650c, 0xf51530f0, 0xf2cb0e3c, 0xeeefb15d,
0xe6da6ecf,
}
var we = [256]float32{
2.0249555e-09, 1.486674e-11, 2.4409617e-11, 3.1968806e-11,
3.844677e-11, 4.4228204e-11, 4.9516443e-11, 5.443359e-11,
5.905944e-11, 6.344942e-11, 6.7643814e-11, 7.1672945e-11,
7.556032e-11, 7.932458e-11, 8.298079e-11, 8.654132e-11,
9.0016515e-11, 9.3415074e-11, 9.674443e-11, 1.0001099e-10,
1.03220314e-10, 1.06377254e-10, 1.09486115e-10, 1.1255068e-10,
1.1557435e-10, 1.1856015e-10, 1.2151083e-10, 1.2442886e-10,
1.2731648e-10, 1.3017575e-10, 1.3300853e-10, 1.3581657e-10,
1.3860142e-10, 1.4136457e-10, 1.4410738e-10, 1.4683108e-10,
1.4953687e-10, 1.5222583e-10, 1.54899e-10, 1.5755733e-10,
1.6020171e-10, 1.6283301e-10, 1.6545203e-10, 1.6805951e-10,
1.7065617e-10, 1.732427e-10, 1.7581973e-10, 1.7838787e-10,
1.8094774e-10, 1.8349985e-10, 1.8604476e-10, 1.8858298e-10,
1.9111498e-10, 1.9364126e-10, 1.9616223e-10, 1.9867835e-10,
2.0119004e-10, 2.0369768e-10, 2.0620168e-10, 2.087024e-10,
2.1120022e-10, 2.136955e-10, 2.1618855e-10, 2.1867974e-10,
2.2116936e-10, 2.2365775e-10, 2.261452e-10, 2.2863202e-10,
2.311185e-10, 2.3360494e-10, 2.360916e-10, 2.3857874e-10,
2.4106667e-10, 2.4355562e-10, 2.4604588e-10, 2.485377e-10,
2.5103128e-10, 2.5352695e-10, 2.560249e-10, 2.585254e-10,
2.6102867e-10, 2.6353494e-10, 2.6604446e-10, 2.6855745e-10,
2.7107416e-10, 2.7359479e-10, 2.761196e-10, 2.7864877e-10,
2.8118255e-10, 2.8372119e-10, 2.8626485e-10, 2.888138e-10,
2.9136826e-10, 2.939284e-10, 2.9649452e-10, 2.9906677e-10,
3.016454e-10, 3.0423064e-10, 3.0682268e-10, 3.0942177e-10,
3.1202813e-10, 3.1464195e-10, 3.1726352e-10, 3.19893e-10,
3.2253064e-10, 3.251767e-10, 3.2783135e-10, 3.3049485e-10,
3.3316744e-10, 3.3584938e-10, 3.3854083e-10, 3.4124212e-10,
3.4395342e-10, 3.46675e-10, 3.4940711e-10, 3.5215003e-10,
3.5490397e-10, 3.5766917e-10, 3.6044595e-10, 3.6323455e-10,
3.660352e-10, 3.6884823e-10, 3.7167386e-10, 3.745124e-10,
3.773641e-10, 3.802293e-10, 3.8310827e-10, 3.860013e-10,
3.8890866e-10, 3.918307e-10, 3.9476775e-10, 3.9772008e-10,
4.0068804e-10, 4.0367196e-10, 4.0667217e-10, 4.09689e-10,
4.1272286e-10, 4.1577405e-10, 4.1884296e-10, 4.2192994e-10,
4.250354e-10, 4.281597e-10, 4.313033e-10, 4.3446652e-10,
4.3764986e-10, 4.408537e-10, 4.4407847e-10, 4.4732465e-10,
4.5059267e-10, 4.5388301e-10, 4.571962e-10, 4.6053267e-10,
4.6389292e-10, 4.6727755e-10, 4.70687e-10, 4.741219e-10,
4.7758275e-10, 4.810702e-10, 4.845848e-10, 4.8812715e-10,
4.9169796e-10, 4.9529775e-10, 4.989273e-10, 5.0258725e-10,
5.0627835e-10, 5.100013e-10, 5.1375687e-10, 5.1754584e-10,
5.21369e-10, 5.2522725e-10, 5.2912136e-10, 5.330522e-10,
5.370208e-10, 5.4102806e-10, 5.45075e-10, 5.491625e-10,
5.532918e-10, 5.5746385e-10, 5.616799e-10, 5.6594107e-10,
5.7024857e-10, 5.746037e-10, 5.7900773e-10, 5.834621e-10,
5.8796823e-10, 5.925276e-10, 5.971417e-10, 6.018122e-10,
6.065408e-10, 6.113292e-10, 6.1617933e-10, 6.2109295e-10,
6.260722e-10, 6.3111916e-10, 6.3623595e-10, 6.4142497e-10,
6.4668854e-10, 6.5202926e-10, 6.5744976e-10, 6.6295286e-10,
6.6854156e-10, 6.742188e-10, 6.79988e-10, 6.858526e-10,
6.9181616e-10, 6.978826e-10, 7.04056e-10, 7.103407e-10,
7.167412e-10, 7.2326256e-10, 7.2990985e-10, 7.366886e-10,
7.4360473e-10, 7.5066453e-10, 7.5787476e-10, 7.6524265e-10,
7.7277595e-10, 7.80483e-10, 7.883728e-10, 7.9645507e-10,
8.047402e-10, 8.1323964e-10, 8.219657e-10, 8.309319e-10,
8.401528e-10, 8.496445e-10, 8.594247e-10, 8.6951274e-10,
8.799301e-10, 8.9070046e-10, 9.018503e-10, 9.134092e-10,
9.254101e-10, 9.378904e-10, 9.508923e-10, 9.644638e-10,
9.786603e-10, 9.935448e-10, 1.0091913e-09, 1.025686e-09,
1.0431306e-09, 1.0616465e-09, 1.08138e-09, 1.1025096e-09,
1.1252564e-09, 1.1498986e-09, 1.1767932e-09, 1.206409e-09,
1.2393786e-09, 1.276585e-09, 1.3193139e-09, 1.3695435e-09,
1.4305498e-09, 1.508365e-09, 1.6160854e-09, 1.7921248e-09,
}
var fe = [256]float32{
1, 0.9381437, 0.90046996, 0.87170434, 0.8477855, 0.8269933,
0.8084217, 0.7915276, 0.77595687, 0.7614634, 0.7478686,
0.7350381, 0.72286767, 0.71127474, 0.70019263, 0.6895665,
0.67935055, 0.6695063, 0.66000086, 0.65080583, 0.6418967,
0.63325197, 0.6248527, 0.6166822, 0.60872537, 0.60096896,
0.5934009, 0.58601034, 0.5787874, 0.57172304, 0.5648092,
0.5580383, 0.5514034, 0.5448982, 0.5385169, 0.53225386,
0.5261042, 0.52006316, 0.5141264, 0.50828975, 0.5025495,
0.496902, 0.49134386, 0.485872, 0.48048335, 0.4751752,
0.46994483, 0.46478975, 0.45970762, 0.45469615, 0.44975325,
0.44487688, 0.44006512, 0.43531612, 0.43062815, 0.42599955,
0.42142874, 0.4169142, 0.41245446, 0.40804818, 0.403694,
0.3993907, 0.39513698, 0.39093173, 0.38677382, 0.38266218,
0.37859577, 0.37457356, 0.37059465, 0.3666581, 0.362763,
0.35890847, 0.35509375, 0.351318, 0.3475805, 0.34388044,
0.34021714, 0.3365899, 0.33299807, 0.32944095, 0.32591796,
0.3224285, 0.3189719, 0.31554767, 0.31215525, 0.30879408,
0.3054636, 0.3021634, 0.29889292, 0.2956517, 0.29243928,
0.28925523, 0.28609908, 0.28297043, 0.27986884, 0.27679393,
0.2737453, 0.2707226, 0.2677254, 0.26475343, 0.26180625,
0.25888354, 0.25598502, 0.2531103, 0.25025907, 0.24743107,
0.24462597, 0.24184346, 0.23908329, 0.23634516, 0.23362878,
0.23093392, 0.2282603, 0.22560766, 0.22297576, 0.22036438,
0.21777324, 0.21520215, 0.21265087, 0.21011916, 0.20760682,
0.20511365, 0.20263945, 0.20018397, 0.19774707, 0.19532852,
0.19292815, 0.19054577, 0.1881812, 0.18583426, 0.18350479,
0.1811926, 0.17889754, 0.17661946, 0.17435817, 0.17211354,
0.1698854, 0.16767362, 0.16547804, 0.16329853, 0.16113494,
0.15898713, 0.15685499, 0.15473837, 0.15263714, 0.15055119,
0.14848037, 0.14642459, 0.14438373, 0.14235765, 0.14034624,
0.13834943, 0.13636707, 0.13439907, 0.13244532, 0.13050574,
0.1285802, 0.12666863, 0.12477092, 0.12288698, 0.12101672,
0.119160056, 0.1173169, 0.115487166, 0.11367077, 0.11186763,
0.11007768, 0.10830083, 0.10653701, 0.10478614, 0.10304816,
0.101323, 0.09961058, 0.09791085, 0.09622374, 0.09454919,
0.09288713, 0.091237515, 0.08960028, 0.087975375, 0.08636274,
0.08476233, 0.083174095, 0.081597984, 0.08003395, 0.07848195,
0.076941945, 0.07541389, 0.07389775, 0.072393484, 0.07090106,
0.069420435, 0.06795159, 0.066494495, 0.06504912, 0.063615434,
0.062193416, 0.060783047, 0.059384305, 0.057997175,
0.05662164, 0.05525769, 0.053905312, 0.052564494, 0.051235236,
0.049917534, 0.048611384, 0.047316793, 0.046033762, 0.0447623,
0.043502413, 0.042254124, 0.041017443, 0.039792392,
0.038578995, 0.037377283, 0.036187284, 0.035009038,
0.033842582, 0.032687962, 0.031545233, 0.030414443, 0.02929566,
0.02818895, 0.027094385, 0.026012046, 0.024942026, 0.023884421,
0.022839336, 0.021806888, 0.020787204, 0.019780423, 0.0187867,
0.0178062, 0.016839107, 0.015885621, 0.014945968, 0.014020392,
0.013109165, 0.012212592, 0.011331013, 0.01046481, 0.009614414,
0.008780315, 0.007963077, 0.0071633533, 0.006381906,
0.0056196423, 0.0048776558, 0.004157295, 0.0034602648,
0.0027887989, 0.0021459677, 0.0015362998, 0.0009672693,
0.00045413437,
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package rand
import (
"math"
)
/*
* Normal distribution
*
* See "The Ziggurat Method for Generating Random Variables"
* (Marsaglia & Tsang, 2000)
* http://www.jstatsoft.org/v05/i08/paper [pdf]
*/
const (
rn = 3.442619855899
)
func absInt32(i int32) uint32 {
if i < 0 {
return uint32(-i)
}
return uint32(i)
}
// NormFloat64 returns a normally distributed float64 in
// the range -math.MaxFloat64 through +math.MaxFloat64 inclusive,
// with standard normal distribution (mean = 0, stddev = 1).
// To produce a different normal distribution, callers can
// adjust the output using:
//
// sample = NormFloat64() * desiredStdDev + desiredMean
func (r *Rand) NormFloat64() float64 {
for {
j := int32(r.Uint32()) // Possibly negative
i := j & 0x7F
x := float64(j) * float64(wn[i])
if absInt32(j) < kn[i] {
// This case should be hit better than 99% of the time.
return x
}
if i == 0 {
// This extra work is only required for the base strip.
for {
x = -math.Log(r.Float64()) * (1.0 / rn)
y := -math.Log(r.Float64())
if y+y >= x*x {
break
}
}
if j > 0 {
return rn + x
}
return -rn - x
}
if fn[i]+float32(r.Float64())*(fn[i-1]-fn[i]) < float32(math.Exp(-.5*x*x)) {
return x
}
}
}
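// Illustrative sketch (not part of the original source): shifting and
// scaling the standard normal output, per the doc comment above. The
// function name and parameters are hypothetical.
func normFloat64AdjustSketch(r *Rand) float64 {
	const mean, stddev = 10.0, 2.0
	return r.NormFloat64()*stddev + mean
}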
var kn = [128]uint32{
0x76ad2212, 0x0, 0x600f1b53, 0x6ce447a6, 0x725b46a2,
0x7560051d, 0x774921eb, 0x789a25bd, 0x799045c3, 0x7a4bce5d,
0x7adf629f, 0x7b5682a6, 0x7bb8a8c6, 0x7c0ae722, 0x7c50cce7,
0x7c8cec5b, 0x7cc12cd6, 0x7ceefed2, 0x7d177e0b, 0x7d3b8883,
0x7d5bce6c, 0x7d78dd64, 0x7d932886, 0x7dab0e57, 0x7dc0dd30,
0x7dd4d688, 0x7de73185, 0x7df81cea, 0x7e07c0a3, 0x7e163efa,
0x7e23b587, 0x7e303dfd, 0x7e3beec2, 0x7e46db77, 0x7e51155d,
0x7e5aabb3, 0x7e63abf7, 0x7e6c222c, 0x7e741906, 0x7e7b9a18,
0x7e82adfa, 0x7e895c63, 0x7e8fac4b, 0x7e95a3fb, 0x7e9b4924,
0x7ea0a0ef, 0x7ea5b00d, 0x7eaa7ac3, 0x7eaf04f3, 0x7eb3522a,
0x7eb765a5, 0x7ebb4259, 0x7ebeeafd, 0x7ec2620a, 0x7ec5a9c4,
0x7ec8c441, 0x7ecbb365, 0x7ece78ed, 0x7ed11671, 0x7ed38d62,
0x7ed5df12, 0x7ed80cb4, 0x7eda175c, 0x7edc0005, 0x7eddc78e,
0x7edf6ebf, 0x7ee0f647, 0x7ee25ebe, 0x7ee3a8a9, 0x7ee4d473,
0x7ee5e276, 0x7ee6d2f5, 0x7ee7a620, 0x7ee85c10, 0x7ee8f4cd,
0x7ee97047, 0x7ee9ce59, 0x7eea0eca, 0x7eea3147, 0x7eea3568,
0x7eea1aab, 0x7ee9e071, 0x7ee98602, 0x7ee90a88, 0x7ee86d08,
0x7ee7ac6a, 0x7ee6c769, 0x7ee5bc9c, 0x7ee48a67, 0x7ee32efc,
0x7ee1a857, 0x7edff42f, 0x7ede0ffa, 0x7edbf8d9, 0x7ed9ab94,
0x7ed7248d, 0x7ed45fae, 0x7ed1585c, 0x7ece095f, 0x7eca6ccb,
0x7ec67be2, 0x7ec22eee, 0x7ebd7d1a, 0x7eb85c35, 0x7eb2c075,
0x7eac9c20, 0x7ea5df27, 0x7e9e769f, 0x7e964c16, 0x7e8d44ba,
0x7e834033, 0x7e781728, 0x7e6b9933, 0x7e5d8a1a, 0x7e4d9ded,
0x7e3b737a, 0x7e268c2f, 0x7e0e3ff5, 0x7df1aa5d, 0x7dcf8c72,
0x7da61a1e, 0x7d72a0fb, 0x7d30e097, 0x7cd9b4ab, 0x7c600f1a,
0x7ba90bdc, 0x7a722176, 0x77d664e5,
}
var wn = [128]float32{
1.7290405e-09, 1.2680929e-10, 1.6897518e-10, 1.9862688e-10,
2.2232431e-10, 2.4244937e-10, 2.601613e-10, 2.7611988e-10,
2.9073963e-10, 3.042997e-10, 3.1699796e-10, 3.289802e-10,
3.4035738e-10, 3.5121603e-10, 3.616251e-10, 3.7164058e-10,
3.8130857e-10, 3.9066758e-10, 3.9975012e-10, 4.08584e-10,
4.1719309e-10, 4.2559822e-10, 4.338176e-10, 4.418672e-10,
4.497613e-10, 4.5751258e-10, 4.651324e-10, 4.7263105e-10,
4.8001775e-10, 4.87301e-10, 4.944885e-10, 5.015873e-10,
5.0860405e-10, 5.155446e-10, 5.2241467e-10, 5.2921934e-10,
5.359635e-10, 5.426517e-10, 5.4928817e-10, 5.5587696e-10,
5.624219e-10, 5.6892646e-10, 5.753941e-10, 5.818282e-10,
5.882317e-10, 5.946077e-10, 6.00959e-10, 6.072884e-10,
6.135985e-10, 6.19892e-10, 6.2617134e-10, 6.3243905e-10,
6.386974e-10, 6.449488e-10, 6.511956e-10, 6.5744005e-10,
6.6368433e-10, 6.699307e-10, 6.7618144e-10, 6.824387e-10,
6.8870465e-10, 6.949815e-10, 7.012715e-10, 7.075768e-10,
7.1389966e-10, 7.202424e-10, 7.266073e-10, 7.329966e-10,
7.394128e-10, 7.4585826e-10, 7.5233547e-10, 7.58847e-10,
7.653954e-10, 7.719835e-10, 7.7861395e-10, 7.852897e-10,
7.920138e-10, 7.987892e-10, 8.0561924e-10, 8.125073e-10,
8.194569e-10, 8.2647167e-10, 8.3355556e-10, 8.407127e-10,
8.479473e-10, 8.55264e-10, 8.6266755e-10, 8.7016316e-10,
8.777562e-10, 8.8545243e-10, 8.932582e-10, 9.0117996e-10,
9.09225e-10, 9.174008e-10, 9.2571584e-10, 9.341788e-10,
9.427997e-10, 9.515889e-10, 9.605579e-10, 9.697193e-10,
9.790869e-10, 9.88676e-10, 9.985036e-10, 1.0085882e-09,
1.0189509e-09, 1.0296151e-09, 1.0406069e-09, 1.0519566e-09,
1.063698e-09, 1.0758702e-09, 1.0885183e-09, 1.1016947e-09,
1.1154611e-09, 1.1298902e-09, 1.1450696e-09, 1.1611052e-09,
1.1781276e-09, 1.1962995e-09, 1.2158287e-09, 1.2369856e-09,
1.2601323e-09, 1.2857697e-09, 1.3146202e-09, 1.347784e-09,
1.3870636e-09, 1.4357403e-09, 1.5008659e-09, 1.6030948e-09,
}
var fn = [128]float32{
1, 0.9635997, 0.9362827, 0.9130436, 0.89228165, 0.87324303,
0.8555006, 0.8387836, 0.8229072, 0.8077383, 0.793177,
0.7791461, 0.7655842, 0.7524416, 0.73967725, 0.7272569,
0.7151515, 0.7033361, 0.69178915, 0.68049186, 0.6694277,
0.658582, 0.6479418, 0.63749546, 0.6272325, 0.6171434,
0.6072195, 0.5974532, 0.58783704, 0.5783647, 0.56903,
0.5598274, 0.5507518, 0.54179835, 0.5329627, 0.52424055,
0.5156282, 0.50712204, 0.49871865, 0.49041483, 0.48220766,
0.4740943, 0.46607214, 0.4581387, 0.45029163, 0.44252872,
0.43484783, 0.427247, 0.41972435, 0.41227803, 0.40490642,
0.39760786, 0.3903808, 0.3832238, 0.37613547, 0.36911446,
0.3621595, 0.35526937, 0.34844297, 0.34167916, 0.33497685,
0.3283351, 0.3217529, 0.3152294, 0.30876362, 0.30235484,
0.29600215, 0.28970486, 0.2834622, 0.2772735, 0.27113807,
0.2650553, 0.25902456, 0.2530453, 0.24711695, 0.241239,
0.23541094, 0.22963232, 0.2239027, 0.21822165, 0.21258877,
0.20700371, 0.20146611, 0.19597565, 0.19053204, 0.18513499,
0.17978427, 0.17447963, 0.1692209, 0.16400786, 0.15884037,
0.15371831, 0.14864157, 0.14361008, 0.13862377, 0.13368265,
0.12878671, 0.12393598, 0.119130544, 0.11437051, 0.10965602,
0.104987256, 0.10036444, 0.095787846, 0.0912578, 0.08677467,
0.0823389, 0.077950984, 0.073611505, 0.06932112, 0.06508058,
0.06089077, 0.056752663, 0.0526674, 0.048636295, 0.044660863,
0.040742867, 0.03688439, 0.033087887, 0.029356318,
0.025693292, 0.022103304, 0.018592102, 0.015167298,
0.011839478, 0.008624485, 0.005548995, 0.0026696292,
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package rand implements pseudo-random number generators suitable for tasks
// such as simulation, but it should not be used for security-sensitive work.
//
// Random numbers are generated by a [Source], usually wrapped in a [Rand].
// Both types should be used by a single goroutine at a time: sharing among
// multiple goroutines requires some kind of synchronization.
//
// Top-level functions, such as [Float64] and [Int],
// are safe for concurrent use by multiple goroutines.
//
// This package's outputs might be easily predictable regardless of how it's
// seeded. For random numbers suitable for security-sensitive work, see the
// crypto/rand package.
package rand
import (
"internal/godebug"
"sync"
"sync/atomic"
_ "unsafe" // for go:linkname
)
// A Source represents a source of uniformly-distributed
// pseudo-random int64 values in the range [0, 1<<63).
//
// A Source is not safe for concurrent use by multiple goroutines.
type Source interface {
Int63() int64
Seed(seed int64)
}
// A Source64 is a Source that can also generate
// uniformly-distributed pseudo-random uint64 values in
// the range [0, 1<<64) directly.
// If a Rand r's underlying Source s implements Source64,
// then r.Uint64 returns the result of one call to s.Uint64
// instead of making two calls to s.Int63.
type Source64 interface {
Source
Uint64() uint64
}
// NewSource returns a new pseudo-random Source seeded with the given value.
// Unlike the default Source used by top-level functions, this source is not
// safe for concurrent use by multiple goroutines.
// The returned Source implements Source64.
func NewSource(seed int64) Source {
return newSource(seed)
}
func newSource(seed int64) *rngSource {
var rng rngSource
rng.Seed(seed)
return &rng
}
// A Rand is a source of random numbers.
type Rand struct {
src Source
s64 Source64 // non-nil if src is source64
// readVal contains remainder of 63-bit integer used for bytes
// generation during most recent Read call.
// It is saved so next Read call can start where the previous
// one finished.
readVal int64
// readPos indicates the number of low-order bytes of readVal
// that are still valid.
readPos int8
}
// New returns a new Rand that uses random values from src
// to generate other random values.
func New(src Source) *Rand {
s64, _ := src.(Source64)
return &Rand{src: src, s64: s64}
}
// Seed uses the provided seed value to initialize the generator to a deterministic state.
// Seed should not be called concurrently with any other Rand method.
func (r *Rand) Seed(seed int64) {
if lk, ok := r.src.(*lockedSource); ok {
lk.seedPos(seed, &r.readPos)
return
}
r.src.Seed(seed)
r.readPos = 0
}
// Int63 returns a non-negative pseudo-random 63-bit integer as an int64.
func (r *Rand) Int63() int64 { return r.src.Int63() }
// Uint32 returns a pseudo-random 32-bit value as a uint32.
func (r *Rand) Uint32() uint32 { return uint32(r.Int63() >> 31) }
// Uint64 returns a pseudo-random 64-bit value as a uint64.
func (r *Rand) Uint64() uint64 {
if r.s64 != nil {
return r.s64.Uint64()
}
return uint64(r.Int63())>>31 | uint64(r.Int63())<<32
}
// Int31 returns a non-negative pseudo-random 31-bit integer as an int32.
func (r *Rand) Int31() int32 { return int32(r.Int63() >> 32) }
// Int returns a non-negative pseudo-random int.
func (r *Rand) Int() int {
u := uint(r.Int63())
return int(u << 1 >> 1) // clear sign bit if int == int32
}
// Int63n returns, as an int64, a non-negative pseudo-random number in the half-open interval [0,n).
// It panics if n <= 0.
func (r *Rand) Int63n(n int64) int64 {
if n <= 0 {
panic("invalid argument to Int63n")
}
if n&(n-1) == 0 { // n is power of two, can mask
return r.Int63() & (n - 1)
}
max := int64((1 << 63) - 1 - (1<<63)%uint64(n))
v := r.Int63()
for v > max {
v = r.Int63()
}
return v % n
}
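// Illustrative sketch (not part of the original source): the rejection loop
// in Int63n exists because a plain modulus is biased whenever n does not
// divide 1<<63; the final partial block of values would map extra
// probability onto small residues. The function name is hypothetical.
func int63nBiasSketch(r *Rand, n int64) int64 {
	return r.Int63() % n // biased for most n; prefer r.Int63n(n)
}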
// Int31n returns, as an int32, a non-negative pseudo-random number in the half-open interval [0,n).
// It panics if n <= 0.
func (r *Rand) Int31n(n int32) int32 {
if n <= 0 {
panic("invalid argument to Int31n")
}
if n&(n-1) == 0 { // n is power of two, can mask
return r.Int31() & (n - 1)
}
max := int32((1 << 31) - 1 - (1<<31)%uint32(n))
v := r.Int31()
for v > max {
v = r.Int31()
}
return v % n
}
// int31n returns, as an int32, a non-negative pseudo-random number in the half-open interval [0,n).
// n must be > 0, but int31n does not check this; the caller must ensure it.
// int31n exists because Int31n is inefficient, but Go 1 compatibility
// requires that the stream of values produced by math/rand remain unchanged.
// int31n can thus only be used internally, by newly introduced APIs.
//
// For implementation details, see:
// https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction
// https://lemire.me/blog/2016/06/30/fast-random-shuffling
func (r *Rand) int31n(n int32) int32 {
v := r.Uint32()
prod := uint64(v) * uint64(n)
low := uint32(prod)
if low < uint32(n) {
thresh := uint32(-n) % uint32(n)
for low < thresh {
v = r.Uint32()
prod = uint64(v) * uint64(n)
low = uint32(prod)
}
}
return int32(prod >> 32)
}
// Intn returns, as an int, a non-negative pseudo-random number in the half-open interval [0,n).
// It panics if n <= 0.
func (r *Rand) Intn(n int) int {
if n <= 0 {
panic("invalid argument to Intn")
}
if n <= 1<<31-1 {
return int(r.Int31n(int32(n)))
}
return int(r.Int63n(int64(n)))
}
// Float64 returns, as a float64, a pseudo-random number in the half-open interval [0.0,1.0).
func (r *Rand) Float64() float64 {
// A clearer, simpler implementation would be:
// return float64(r.Int63n(1<<53)) / (1<<53)
// However, Go 1 shipped with
// return float64(r.Int63()) / (1 << 63)
// and we want to preserve that value stream.
//
// There is one bug in the value stream: r.Int63() may be so close
// to 1<<63 that the division rounds up to 1.0, and we've guaranteed
// that the result is always less than 1.0.
//
// We tried to fix this by mapping 1.0 back to 0.0, but since float64
// values near 0 are much denser than near 1, mapping 1 to 0 caused
// a theoretically significant overshoot in the probability of returning 0.
// Instead of that, if we round up to 1, just try again.
// Getting 1 only happens 1/2⁵³ of the time, so most clients
// will not observe it anyway.
again:
f := float64(r.Int63()) / (1 << 63)
if f == 1 {
goto again // resample; this branch is taken O(never)
}
return f
}
// Float32 returns, as a float32, a pseudo-random number in the half-open interval [0.0,1.0).
func (r *Rand) Float32() float32 {
// Same rationale as in Float64: we want to preserve the Go 1 value
// stream except we want to fix it not to return 1.0.
// This only happens 1/2²⁴ of the time (plus the 1/2⁵³ of the time in Float64).
again:
f := float32(r.Float64())
if f == 1 {
goto again // resample; this branch is taken O(very rarely)
}
return f
}
// Perm returns, as a slice of n ints, a pseudo-random permutation of the integers
// in the half-open interval [0,n).
func (r *Rand) Perm(n int) []int {
m := make([]int, n)
// In the following loop, the iteration when i=0 always swaps m[0] with m[0].
// A change to remove this useless iteration is to assign 1 to i in the init
// statement. But Perm also affects r. Making this change will affect
// the final state of r. So this change can't be made for compatibility
// reasons for Go 1.
for i := 0; i < n; i++ {
j := r.Intn(i + 1)
m[i] = m[j]
m[j] = i
}
return m
}
// Shuffle pseudo-randomizes the order of elements.
// n is the number of elements. Shuffle panics if n < 0.
// swap swaps the elements with indexes i and j.
func (r *Rand) Shuffle(n int, swap func(i, j int)) {
if n < 0 {
panic("invalid argument to Shuffle")
}
// Fisher-Yates shuffle: https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle
// Shuffle really ought not be called with n that doesn't fit in 32 bits.
// Not only will it take a very long time, but with 2³¹! possible permutations,
// there's no way that any PRNG can have a big enough internal state to
// generate even a minuscule percentage of the possible permutations.
// Nevertheless, the right API signature accepts an int n, so handle it as best we can.
i := n - 1
for ; i > 1<<31-1-1; i-- {
j := int(r.Int63n(int64(i + 1)))
swap(i, j)
}
for ; i > 0; i-- {
j := int(r.int31n(int32(i + 1)))
swap(i, j)
}
}
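// Illustrative sketch (not part of the original source): a typical Shuffle
// call swapping elements of a slice in place. The function name is
// hypothetical.
func shuffleSliceSketch(r *Rand, s []int) {
	r.Shuffle(len(s), func(i, j int) { s[i], s[j] = s[j], s[i] })
}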
// Read generates len(p) random bytes and writes them into p. It
// always returns len(p) and a nil error.
// Read should not be called concurrently with any other Rand method.
func (r *Rand) Read(p []byte) (n int, err error) {
switch src := r.src.(type) {
case *lockedSource:
return src.read(p, &r.readVal, &r.readPos)
case *fastSource:
return src.read(p, &r.readVal, &r.readPos)
}
return read(p, r.src, &r.readVal, &r.readPos)
}
func read(p []byte, src Source, readVal *int64, readPos *int8) (n int, err error) {
pos := *readPos
val := *readVal
rng, _ := src.(*rngSource)
for n = 0; n < len(p); n++ {
if pos == 0 {
if rng != nil {
val = rng.Int63()
} else {
val = src.Int63()
}
pos = 7
}
p[n] = byte(val)
val >>= 8
pos--
}
*readPos = pos
*readVal = val
return
}
/*
* Top-level convenience functions
*/
// globalRandGenerator is the source of random numbers for the top-level
// convenience functions. When possible it uses the runtime fastrand64
// function to avoid locking. This is not possible if the user called Seed,
// either explicitly or implicitly via GODEBUG=randautoseed=0.
var globalRandGenerator atomic.Pointer[Rand]
var randautoseed = godebug.New("randautoseed")
// globalRand returns the generator to use for the top-level convenience
// functions.
func globalRand() *Rand {
if r := globalRandGenerator.Load(); r != nil {
return r
}
// This is the first call. Initialize based on GODEBUG.
var r *Rand
if randautoseed.Value() == "0" {
randautoseed.IncNonDefault()
r = New(new(lockedSource))
r.Seed(1)
} else {
r = &Rand{
src: &fastSource{},
s64: &fastSource{},
}
}
if !globalRandGenerator.CompareAndSwap(nil, r) {
// Two different goroutines called some top-level
// function at the same time. While the results in
// that case are unpredictable, if we just use r here,
// and we are using a seed, we will most likely return
// the same value for both calls. That doesn't seem ideal.
// Just use the first one to get in.
return globalRandGenerator.Load()
}
return r
}
//go:linkname fastrand64
func fastrand64() uint64
// fastSource is an implementation of Source64 that uses the runtime
// fastrand functions.
type fastSource struct {
// The mutex is used to avoid race conditions in Read.
mu sync.Mutex
}
func (*fastSource) Int63() int64 {
return int64(fastrand64() & rngMask)
}
func (*fastSource) Seed(int64) {
panic("internal error: call to fastSource.Seed")
}
func (*fastSource) Uint64() uint64 {
return fastrand64()
}
func (fs *fastSource) read(p []byte, readVal *int64, readPos *int8) (n int, err error) {
fs.mu.Lock()
n, err = read(p, fs, readVal, readPos)
fs.mu.Unlock()
return
}
// Seed uses the provided seed value to initialize the default Source to a
// deterministic state. Seed values that have the same remainder when
// divided by 2³¹-1 generate the same pseudo-random sequence.
// Seed, unlike the Rand.Seed method, is safe for concurrent use.
//
// If Seed is not called, the generator is seeded randomly at program startup.
//
// Prior to Go 1.20, the generator was seeded like Seed(1) at program startup.
// To force the old behavior, call Seed(1) at program startup.
// Alternately, set GODEBUG=randautoseed=0 in the environment
// before making any calls to functions in this package.
//
// Deprecated: Programs that call Seed and then expect a specific sequence
// of results from the global random source (using functions such as Int)
// can be broken when a dependency changes how much it consumes
// from the global random source. To avoid such breakages, programs
// that need a specific result sequence should use New(NewSource(seed))
// to obtain a random generator that other packages cannot access.
func Seed(seed int64) {
orig := globalRandGenerator.Load()
// If we are already using a lockedSource, we can just re-seed it.
if orig != nil {
if _, ok := orig.src.(*lockedSource); ok {
orig.Seed(seed)
return
}
}
// Otherwise either
// 1) orig == nil, which is the normal case when Seed is the first
// top-level function to be called, or
// 2) orig is already a fastSource, in which case we need to change
// to a lockedSource.
// Either way we do the same thing.
r := New(new(lockedSource))
r.Seed(seed)
if !globalRandGenerator.CompareAndSwap(orig, r) {
// Something changed underfoot. Retry to be safe.
Seed(seed)
}
}
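// Illustrative sketch (not part of the original source): the replacement
// pattern recommended by the deprecation note above, giving a private,
// reproducible generator. The function name and seed are hypothetical.
func seededGeneratorSketch() *Rand {
	return New(NewSource(42)) // other packages cannot consume from this stream
}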
// Int63 returns a non-negative pseudo-random 63-bit integer as an int64
// from the default Source.
func Int63() int64 { return globalRand().Int63() }
// Uint32 returns a pseudo-random 32-bit value as a uint32
// from the default Source.
func Uint32() uint32 { return globalRand().Uint32() }
// Uint64 returns a pseudo-random 64-bit value as a uint64
// from the default Source.
func Uint64() uint64 { return globalRand().Uint64() }
// Int31 returns a non-negative pseudo-random 31-bit integer as an int32
// from the default Source.
func Int31() int32 { return globalRand().Int31() }
// Int returns a non-negative pseudo-random int from the default Source.
func Int() int { return globalRand().Int() }
// Int63n returns, as an int64, a non-negative pseudo-random number in the half-open interval [0,n)
// from the default Source.
// It panics if n <= 0.
func Int63n(n int64) int64 { return globalRand().Int63n(n) }
// Int31n returns, as an int32, a non-negative pseudo-random number in the half-open interval [0,n)
// from the default Source.
// It panics if n <= 0.
func Int31n(n int32) int32 { return globalRand().Int31n(n) }
// Intn returns, as an int, a non-negative pseudo-random number in the half-open interval [0,n)
// from the default Source.
// It panics if n <= 0.
func Intn(n int) int { return globalRand().Intn(n) }
// Float64 returns, as a float64, a pseudo-random number in the half-open interval [0.0,1.0)
// from the default Source.
func Float64() float64 { return globalRand().Float64() }
// Float32 returns, as a float32, a pseudo-random number in the half-open interval [0.0,1.0)
// from the default Source.
func Float32() float32 { return globalRand().Float32() }
// Perm returns, as a slice of n ints, a pseudo-random permutation of the integers
// in the half-open interval [0,n) from the default Source.
func Perm(n int) []int { return globalRand().Perm(n) }
// Shuffle pseudo-randomizes the order of elements using the default Source.
// n is the number of elements. Shuffle panics if n < 0.
// swap swaps the elements with indexes i and j.
func Shuffle(n int, swap func(i, j int)) { globalRand().Shuffle(n, swap) }
// Read generates len(p) random bytes from the default Source and
// writes them into p. It always returns len(p) and a nil error.
// Read, unlike the Rand.Read method, is safe for concurrent use.
//
// Deprecated: For almost all use cases, crypto/rand.Read is more appropriate.
func Read(p []byte) (n int, err error) { return globalRand().Read(p) }
// NormFloat64 returns a normally distributed float64 in the range
// [-math.MaxFloat64, +math.MaxFloat64] with
// standard normal distribution (mean = 0, stddev = 1)
// from the default Source.
// To produce a different normal distribution, callers can
// adjust the output using:
//
// sample = NormFloat64() * desiredStdDev + desiredMean
func NormFloat64() float64 { return globalRand().NormFloat64() }
// ExpFloat64 returns an exponentially distributed float64 in the range
// (0, +math.MaxFloat64] with an exponential distribution whose rate parameter
// (lambda) is 1 and whose mean is 1/lambda (1) from the default Source.
// To produce a distribution with a different rate parameter,
// callers can adjust the output using:
//
// sample = ExpFloat64() / desiredRateParameter
func ExpFloat64() float64 { return globalRand().ExpFloat64() }
type lockedSource struct {
lk sync.Mutex
s *rngSource
}
func (r *lockedSource) Int63() (n int64) {
r.lk.Lock()
n = r.s.Int63()
r.lk.Unlock()
return
}
func (r *lockedSource) Uint64() (n uint64) {
r.lk.Lock()
n = r.s.Uint64()
r.lk.Unlock()
return
}
func (r *lockedSource) Seed(seed int64) {
r.lk.Lock()
r.seed(seed)
r.lk.Unlock()
}
// seedPos implements Seed for a lockedSource without a race condition.
func (r *lockedSource) seedPos(seed int64, readPos *int8) {
r.lk.Lock()
r.seed(seed)
*readPos = 0
r.lk.Unlock()
}
// seed seeds the underlying source.
// The caller must have locked r.lk.
func (r *lockedSource) seed(seed int64) {
if r.s == nil {
r.s = newSource(seed)
} else {
r.s.Seed(seed)
}
}
// read implements Read for a lockedSource without a race condition.
func (r *lockedSource) read(p []byte, readVal *int64, readPos *int8) (n int, err error) {
r.lk.Lock()
n, err = read(p, r.s, readVal, readPos)
r.lk.Unlock()
return
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package rand
/*
* Uniform distribution
*
* algorithm by
* DP Mitchell and JA Reeds
*/
const (
rngLen = 607
rngTap = 273
rngMax = 1 << 63
rngMask = rngMax - 1
int32max = (1 << 31) - 1
)
var (
// rngCooked used for seeding. See gen_cooked.go for details.
rngCooked [rngLen]int64 = [...]int64{
-4181792142133755926, -4576982950128230565, 1395769623340756751, 5333664234075297259,
-6347679516498800754, 9033628115061424579, 7143218595135194537, 4812947590706362721,
7937252194349799378, 5307299880338848416, 8209348851763925077, -7107630437535961764,
4593015457530856296, 8140875735541888011, -5903942795589686782, -603556388664454774,
-7496297993371156308, 113108499721038619, 4569519971459345583, -4160538177779461077,
-6835753265595711384, -6507240692498089696, 6559392774825876886, 7650093201692370310,
7684323884043752161, -8965504200858744418, -2629915517445760644, 271327514973697897,
-6433985589514657524, 1065192797246149621, 3344507881999356393, -4763574095074709175,
7465081662728599889, 1014950805555097187, -4773931307508785033, -5742262670416273165,
2418672789110888383, 5796562887576294778, 4484266064449540171, 3738982361971787048,
-4699774852342421385, 10530508058128498, -589538253572429690, -6598062107225984180,
8660405965245884302, 10162832508971942, -2682657355892958417, 7031802312784620857,
6240911277345944669, 831864355460801054, -1218937899312622917, 2116287251661052151,
2202309800992166967, 9161020366945053561, 4069299552407763864, 4936383537992622449,
457351505131524928, -8881176990926596454, -6375600354038175299, -7155351920868399290,
4368649989588021065, 887231587095185257, -3659780529968199312, -2407146836602825512,
5616972787034086048, -751562733459939242, 1686575021641186857, -5177887698780513806,
-4979215821652996885, -1375154703071198421, 5632136521049761902, -8390088894796940536,
-193645528485698615, -5979788902190688516, -4907000935050298721, -285522056888777828,
-2776431630044341707, 1679342092332374735, 6050638460742422078, -2229851317345194226,
-1582494184340482199, 5881353426285907985, 812786550756860885, 4541845584483343330,
-6497901820577766722, 4980675660146853729, -4012602956251539747, -329088717864244987,
-2896929232104691526, 1495812843684243920, -2153620458055647789, 7370257291860230865,
-2466442761497833547, 4706794511633873654, -1398851569026877145, 8549875090542453214,
-9189721207376179652, -7894453601103453165, 7297902601803624459, 1011190183918857495,
-6985347000036920864, 5147159997473910359, -8326859945294252826, 2659470849286379941,
6097729358393448602, -7491646050550022124, -5117116194870963097, -896216826133240300,
-745860416168701406, 5803876044675762232, -787954255994554146, -3234519180203704564,
-4507534739750823898, -1657200065590290694, 505808562678895611, -4153273856159712438,
-8381261370078904295, 572156825025677802, 1791881013492340891, 3393267094866038768,
-5444650186382539299, 2352769483186201278, -7930912453007408350, -325464993179687389,
-3441562999710612272, -6489413242825283295, 5092019688680754699, -227247482082248967,
4234737173186232084, 5027558287275472836, 4635198586344772304, -536033143587636457,
5907508150730407386, -8438615781380831356, 972392927514829904, -3801314342046600696,
-4064951393885491917, -174840358296132583, 2407211146698877100, -1640089820333676239,
3940796514530962282, -5882197405809569433, 3095313889586102949, -1818050141166537098,
5832080132947175283, 7890064875145919662, 8184139210799583195, -8073512175445549678,
-7758774793014564506, -4581724029666783935, 3516491885471466898, -8267083515063118116,
6657089965014657519, 5220884358887979358, 1796677326474620641, 5340761970648932916,
1147977171614181568, 5066037465548252321, 2574765911837859848, 1085848279845204775,
-5873264506986385449, 6116438694366558490, 2107701075971293812, -7420077970933506541,
2469478054175558874, -1855128755834809824, -5431463669011098282, -9038325065738319171,
-6966276280341336160, 7217693971077460129, -8314322083775271549, 7196649268545224266,
-3585711691453906209, -5267827091426810625, 8057528650917418961, -5084103596553648165,
-2601445448341207749, -7850010900052094367, 6527366231383600011, 3507654575162700890,
9202058512774729859, 1954818376891585542, -2582991129724600103, 8299563319178235687,
-5321504681635821435, 7046310742295574065, -2376176645520785576, -7650733936335907755,
8850422670118399721, 3631909142291992901, 5158881091950831288, -6340413719511654215,
4763258931815816403, 6280052734341785344, -4979582628649810958, 2043464728020827976,
-2678071570832690343, 4562580375758598164, 5495451168795427352, -7485059175264624713,
553004618757816492, 6895160632757959823, -989748114590090637, 7139506338801360852,
-672480814466784139, 5535668688139305547, 2430933853350256242, -3821430778991574732,
-1063731997747047009, -3065878205254005442, 7632066283658143750, 6308328381617103346,
3681878764086140361, 3289686137190109749, 6587997200611086848, 244714774258135476,
-5143583659437639708, 8090302575944624335, 2945117363431356361, -8359047641006034763,
3009039260312620700, -793344576772241777, 401084700045993341, -1968749590416080887,
4707864159563588614, -3583123505891281857, -3240864324164777915, -5908273794572565703,
-3719524458082857382, -5281400669679581926, 8118566580304798074, 3839261274019871296,
7062410411742090847, -8481991033874568140, 6027994129690250817, -6725542042704711878,
-2971981702428546974, -7854441788951256975, 8809096399316380241, 6492004350391900708,
2462145737463489636, -8818543617934476634, -5070345602623085213, -8961586321599299868,
-3758656652254704451, -8630661632476012791, 6764129236657751224, -709716318315418359,
-3403028373052861600, -8838073512170985897, -3999237033416576341, -2920240395515973663,
-2073249475545404416, 368107899140673753, -6108185202296464250, -6307735683270494757,
4782583894627718279, 6718292300699989587, 8387085186914375220, 3387513132024756289,
4654329375432538231, -292704475491394206, -3848998599978456535, 7623042350483453954,
7725442901813263321, 9186225467561587250, -5132344747257272453, -6865740430362196008,
2530936820058611833, 1636551876240043639, -3658707362519810009, 1452244145334316253,
-7161729655835084979, -7943791770359481772, 9108481583171221009, -3200093350120725999,
5007630032676973346, 2153168792952589781, 6720334534964750538, -3181825545719981703,
3433922409283786309, 2285479922797300912, 3110614940896576130, -2856812446131932915,
-3804580617188639299, 7163298419643543757, 4891138053923696990, 580618510277907015,
1684034065251686769, 4429514767357295841, -8893025458299325803, -8103734041042601133,
7177515271653460134, 4589042248470800257, -1530083407795771245, 143607045258444228,
246994305896273627, -8356954712051676521, 6473547110565816071, 3092379936208876896,
2058427839513754051, -4089587328327907870, 8785882556301281247, -3074039370013608197,
-637529855400303673, 6137678347805511274, -7152924852417805802, 5708223427705576541,
-3223714144396531304, 4358391411789012426, 325123008708389849, 6837621693887290924,
4843721905315627004, -3212720814705499393, -3825019837890901156, 4602025990114250980,
1044646352569048800, 9106614159853161675, -8394115921626182539, -4304087667751778808,
2681532557646850893, 3681559472488511871, -3915372517896561773, -2889241648411946534,
-6564663803938238204, -8060058171802589521, 581945337509520675, 3648778920718647903,
-4799698790548231394, -7602572252857820065, 220828013409515943, -1072987336855386047,
4287360518296753003, -4633371852008891965, 5513660857261085186, -2258542936462001533,
-8744380348503999773, 8746140185685648781, 228500091334420247, 1356187007457302238,
3019253992034194581, 3152601605678500003, -8793219284148773595, 5559581553696971176,
4916432985369275664, -8559797105120221417, -5802598197927043732, 2868348622579915573,
-7224052902810357288, -5894682518218493085, 2587672709781371173, -7706116723325376475,
3092343956317362483, -5561119517847711700, 972445599196498113, -1558506600978816441,
1708913533482282562, -2305554874185907314, -6005743014309462908, -6653329009633068701,
-483583197311151195, 2488075924621352812, -4529369641467339140, -4663743555056261452,
2997203966153298104, 1282559373026354493, 240113143146674385, 8665713329246516443,
628141331766346752, -4651421219668005332, -7750560848702540400, 7596648026010355826,
-3132152619100351065, 7834161864828164065, 7103445518877254909, 4390861237357459201,
-4780718172614204074, -319889632007444440, 622261699494173647, -3186110786557562560,
-8718967088789066690, -1948156510637662747, -8212195255998774408, -7028621931231314745,
2623071828615234808, -4066058308780939700, -5484966924888173764, -6683604512778046238,
-6756087640505506466, 5256026990536851868, 7841086888628396109, 6640857538655893162,
-8021284697816458310, -7109857044414059830, -1689021141511844405, -4298087301956291063,
-4077748265377282003, -998231156719803476, 2719520354384050532, 9132346697815513771,
4332154495710163773, -2085582442760428892, 6994721091344268833, -2556143461985726874,
-8567931991128098309, 59934747298466858, -3098398008776739403, -265597256199410390,
2332206071942466437, -7522315324568406181, 3154897383618636503, -7585605855467168281,
-6762850759087199275, 197309393502684135, -8579694182469508493, 2543179307861934850,
4350769010207485119, -4468719947444108136, -7207776534213261296, -1224312577878317200,
4287946071480840813, 8362686366770308971, 6486469209321732151, -5605644191012979782,
-1669018511020473564, 4450022655153542367, -7618176296641240059, -3896357471549267421,
-4596796223304447488, -6531150016257070659, -8982326463137525940, -4125325062227681798,
-1306489741394045544, -8338554946557245229, 5329160409530630596, 7790979528857726136,
4955070238059373407, -4304834761432101506, -6215295852904371179, 3007769226071157901,
-6753025801236972788, 8928702772696731736, 7856187920214445904, -4748497451462800923,
7900176660600710914, -7082800908938549136, -6797926979589575837, -6737316883512927978,
4186670094382025798, 1883939007446035042, -414705992779907823, 3734134241178479257,
4065968871360089196, 6953124200385847784, -7917685222115876751, -7585632937840318161,
-5567246375906782599, -5256612402221608788, 3106378204088556331, -2894472214076325998,
4565385105440252958, 1979884289539493806, -6891578849933910383, 3783206694208922581,
8464961209802336085, 2843963751609577687, 3030678195484896323, -4429654462759003204,
4459239494808162889, 402587895800087237, 8057891408711167515, 4541888170938985079,
1042662272908816815, -3666068979732206850, 2647678726283249984, 2144477441549833761,
-3417019821499388721, -2105601033380872185, 5916597177708541638, -8760774321402454447,
8833658097025758785, 5970273481425315300, 563813119381731307, -6455022486202078793,
1598828206250873866, -4016978389451217698, -2988328551145513985, -6071154634840136312,
8469693267274066490, 125672920241807416, -3912292412830714870, -2559617104544284221,
-486523741806024092, -4735332261862713930, 5923302823487327109, -9082480245771672572,
-1808429243461201518, 7990420780896957397, 4317817392807076702, 3625184369705367340,
-6482649271566653105, -3480272027152017464, -3225473396345736649, -368878695502291645,
-3981164001421868007, -8522033136963788610, 7609280429197514109, 3020985755112334161,
-2572049329799262942, 2635195723621160615, 5144520864246028816, -8188285521126945980,
1567242097116389047, 8172389260191636581, -2885551685425483535, -7060359469858316883,
-6480181133964513127, -7317004403633452381, 6011544915663598137, 5932255307352610768,
2241128460406315459, -8327867140638080220, 3094483003111372717, 4583857460292963101,
9079887171656594975, -384082854924064405, -3460631649611717935, 4225072055348026230,
-7385151438465742745, 3801620336801580414, -399845416774701952, -7446754431269675473,
7899055018877642622, 5421679761463003041, 5521102963086275121, -4975092593295409910,
8735487530905098534, -7462844945281082830, -2080886987197029914, -1000715163927557685,
-4253840471931071485, -5828896094657903328, 6424174453260338141, 359248545074932887,
-5949720754023045210, -2426265837057637212, 3030918217665093212, -9077771202237461772,
-3186796180789149575, 740416251634527158, -2142944401404840226, 6951781370868335478,
399922722363687927, -8928469722407522623, -1378421100515597285, -8343051178220066766,
-3030716356046100229, -8811767350470065420, 9026808440365124461, 6440783557497587732,
4615674634722404292, 539897290441580544, 2096238225866883852, 8751955639408182687,
-7316147128802486205, 7381039757301768559, 6157238513393239656, -1473377804940618233,
8629571604380892756, 5280433031239081479, 7101611890139813254, 2479018537985767835,
7169176924412769570, -1281305539061572506, -7865612307799218120, 2278447439451174845,
3625338785743880657, 6477479539006708521, 8976185375579272206, -3712000482142939688,
1326024180520890843, 7537449876596048829, 5464680203499696154, 3189671183162196045,
6346751753565857109, -8982212049534145501, -6127578587196093755, -245039190118465649,
-6320577374581628592, 7208698530190629697, 7276901792339343736, -7490986807540332668,
4133292154170828382, 2918308698224194548, -7703910638917631350, -3929437324238184044,
-4300543082831323144, -6344160503358350167, 5896236396443472108, -758328221503023383,
-1894351639983151068, -307900319840287220, -6278469401177312761, -2171292963361310674,
8382142935188824023, 9103922860780351547, 4152330101494654406,
}
)
type rngSource struct {
tap int // index into vec
feed int // index into vec
vec [rngLen]int64 // current feedback register
}
// seed rng x[n+1] = 48271 * x[n] mod (2**31 - 1)
func seedrand(x int32) int32 {
const (
A = 48271
Q = 44488
R = 3399
)
hi := x / Q
lo := x % Q
x = A*lo - R*hi
if x < 0 {
x += int32max
}
return x
}
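// seedrand implements Schrage's algorithm: writing x = hi*Q + lo with
// Q = 44488, and noting that 48271*44488 + 3399 = 2^31 - 1, the product
// A*x mod (2^31 - 1) can be computed as A*lo - R*hi without overflowing
// int32. A minimal sketch (not part of the original source) checking this
// against plain 64-bit arithmetic:
//
//	const A, M = 48271, 1<<31 - 1
//	x := int32(89482311)
//	want := int32((int64(A) * int64(x)) % M)
//	if got := seedrand(x); got != want {
//		panic("Schrage decomposition disagrees with 64-bit arithmetic")
//	}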
// Seed uses the provided seed value to initialize the generator to a deterministic state.
func (rng *rngSource) Seed(seed int64) {
rng.tap = 0
rng.feed = rngLen - rngTap
seed = seed % int32max
if seed < 0 {
seed += int32max
}
if seed == 0 {
seed = 89482311
}
x := int32(seed)
for i := -20; i < rngLen; i++ {
x = seedrand(x)
if i >= 0 {
var u int64
u = int64(x) << 40
x = seedrand(x)
u ^= int64(x) << 20
x = seedrand(x)
u ^= int64(x)
u ^= rngCooked[i]
rng.vec[i] = u
}
}
}
// Int63 returns a non-negative pseudo-random 63-bit integer as an int64.
func (rng *rngSource) Int63() int64 {
return int64(rng.Uint64() & rngMask)
}
// Uint64 returns a non-negative pseudo-random 64-bit integer as a uint64.
func (rng *rngSource) Uint64() uint64 {
rng.tap--
if rng.tap < 0 {
rng.tap += rngLen
}
rng.feed--
if rng.feed < 0 {
rng.feed += rngLen
}
x := rng.vec[rng.feed] + rng.vec[rng.tap]
rng.vec[rng.feed] = x
return uint64(x)
}
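// Uint64 above advances an additive lagged Fibonacci generator: on every
// call the two cursors step backwards around the ring (rngLen = 607 and
// rngTap = 273 in the upstream source) and the feed slot is overwritten
// with the sum of the two slots, so each output depends on two earlier
// outputs separated by fixed lags. A sketch of the same stepping over a
// plain slice (hypothetical, not part of the original source):
//
//	step := func(vec []uint64, feed, tap int) (uint64, int, int) {
//		n := len(vec)
//		tap--
//		if tap < 0 {
//			tap += n
//		}
//		feed--
//		if feed < 0 {
//			feed += n
//		}
//		vec[feed] += vec[tap] // the additive feedback
//		return vec[feed], feed, tap
//	}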
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// W.Hormann, G.Derflinger:
// "Rejection-Inversion to Generate Variates
// from Monotone Discrete Distributions"
// http://eeyore.wu-wien.ac.at/papers/96-04-04.wh-der.ps.gz
package rand
import "math"
// A Zipf generates Zipf distributed variates.
type Zipf struct {
r *Rand
imax float64
v float64
q float64
s float64
oneminusQ float64
oneminusQinv float64
hxm float64
hx0minusHxm float64
}
func (z *Zipf) h(x float64) float64 {
return math.Exp(z.oneminusQ*math.Log(z.v+x)) * z.oneminusQinv
}
func (z *Zipf) hinv(x float64) float64 {
return math.Exp(z.oneminusQinv*math.Log(z.oneminusQ*x)) - z.v
}
// NewZipf returns a Zipf variate generator.
// The generator generates values k ∈ [0, imax]
// such that P(k) is proportional to (v + k) ** (-s).
// Requirements: s > 1 and v >= 1.
func NewZipf(r *Rand, s float64, v float64, imax uint64) *Zipf {
z := new(Zipf)
if s <= 1.0 || v < 1 {
return nil
}
z.r = r
z.imax = float64(imax)
z.v = v
z.q = s
z.oneminusQ = 1.0 - z.q
z.oneminusQinv = 1.0 / z.oneminusQ
z.hxm = z.h(z.imax + 0.5)
z.hx0minusHxm = z.h(0.5) - math.Exp(math.Log(z.v)*(-z.q)) - z.hxm
z.s = 1 - z.hinv(z.h(1.5)-math.Exp(-z.q*math.Log(z.v+1.0)))
return z
}
// Uint64 returns a value drawn from the Zipf distribution described
// by the Zipf object.
func (z *Zipf) Uint64() uint64 {
if z == nil {
panic("rand: nil Zipf")
}
k := 0.0
for {
r := z.r.Float64() // r on [0,1]
ur := z.hxm + r*z.hx0minusHxm
x := z.hinv(ur)
k = math.Floor(x + 0.5)
if k-x <= z.s {
break
}
if ur >= z.h(k+0.5)-math.Exp(-math.Log(k+z.v)*z.q) {
break
}
}
return uint64(k)
}
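// A minimal usage sketch (not part of the original source): drawing a few
// Zipf variates through the exported math/rand API. With s = 1.5, v = 1,
// and imax = 100, small values of k dominate.
//
//	r := rand.New(rand.NewSource(1))
//	z := rand.NewZipf(r, 1.5, 1.0, 100)
//	for i := 0; i < 5; i++ {
//		fmt.Print(z.Uint64(), " ")
//	}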
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
// The original C code and the comment below are from
// FreeBSD's /usr/src/lib/msun/src/e_remainder.c and came
// with this notice. The go code is a simplified version of
// the original C.
//
// ====================================================
// Copyright (C) 1993 by Sun Microsystems, Inc. All rights reserved.
//
// Developed at SunPro, a Sun Microsystems, Inc. business.
// Permission to use, copy, modify, and distribute this
// software is freely granted, provided that this notice
// is preserved.
// ====================================================
//
// __ieee754_remainder(x,y)
// Return :
// returns x REM y = x - [x/y]*y as if in infinite
// precision arithmetic, where [x/y] is the (infinite bit)
// integer nearest x/y (in half way cases, choose the even one).
// Method :
// Based on Mod() returning x - [x/y]chopped * y exactly.
// Remainder returns the IEEE 754 floating-point remainder of x/y.
//
// Special cases are:
//
// Remainder(±Inf, y) = NaN
// Remainder(NaN, y) = NaN
// Remainder(x, 0) = NaN
// Remainder(x, ±Inf) = x
// Remainder(x, NaN) = NaN
func Remainder(x, y float64) float64 {
if haveArchRemainder {
return archRemainder(x, y)
}
return remainder(x, y)
}
func remainder(x, y float64) float64 {
const (
Tiny = 4.45014771701440276618e-308 // 0x0020000000000000
HalfMax = MaxFloat64 / 2
)
// special cases
switch {
case IsNaN(x) || IsNaN(y) || IsInf(x, 0) || y == 0:
return NaN()
case IsInf(y, 0):
return x
}
sign := false
if x < 0 {
x = -x
sign = true
}
if y < 0 {
y = -y
}
if x == y {
if sign {
zero := 0.0
return -zero
}
return 0
}
if y <= HalfMax {
x = Mod(x, y+y) // now x < 2y
}
if y < Tiny {
if x+x > y {
x -= y
if x+x >= y {
x -= y
}
}
} else {
yHalf := 0.5 * y
if x > yHalf {
x -= y
if x >= yHalf {
x -= y
}
}
}
if sign {
x = -x
}
return x
}
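// A short illustration (not part of the original source) of how Remainder
// differs from Mod: Mod truncates the quotient toward zero, while Remainder
// rounds it to the nearest integer, so the result can be negative even for
// positive operands.
//
//	fmt.Println(math.Mod(5.5, 2))       // 1.5 (chopped quotient is 2)
//	fmt.Println(math.Remainder(5.5, 2)) // -0.5 (nearest-integer quotient is 3)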
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
// Signbit reports whether x is negative or negative zero.
func Signbit(x float64) bool {
return Float64bits(x)&(1<<63) != 0
}
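// A small example (not part of the original source): Signbit is a portable
// way to distinguish -0 from +0, since the two compare equal. Note that the
// constant expression -0.0 is just 0; a negative zero must be produced at
// run time, e.g. with Copysign.
//
//	negZero := math.Copysign(0, -1)
//	fmt.Println(negZero == 0)          // true
//	fmt.Println(math.Signbit(negZero)) // true
//	fmt.Println(math.Signbit(0))       // false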
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
/*
Floating-point sine and cosine.
*/
// The original C code, the long comment, and the constants
// below were from http://netlib.sandia.gov/cephes/cmath/sin.c,
// available from http://www.netlib.org/cephes/cmath.tgz.
// The go code is a simplified version of the original C.
//
// sin.c
//
// Circular sine
//
// SYNOPSIS:
//
// double x, y, sin();
// y = sin( x );
//
// DESCRIPTION:
//
// Range reduction is into intervals of pi/4. The reduction error is nearly
// eliminated by contriving an extended precision modular arithmetic.
//
// Two polynomial approximating functions are employed.
// Between 0 and pi/4 the sine is approximated by
// x + x**3 P(x**2).
// Between pi/4 and pi/2 the cosine is represented as
// 1 - x**2 Q(x**2).
//
// ACCURACY:
//
// Relative error:
// arithmetic   domain            # trials   peak      rms
// DEC          0, 10             150000     3.0e-17   7.8e-18
// IEEE        -1.07e9,+1.07e9    130000     2.1e-16   5.4e-17
//
// Partial loss of accuracy begins to occur at x = 2**30 = 1.074e9. The loss
// is not gradual, but jumps suddenly to about 1 part in 10e7. Results may
// be meaningless for x > 2**49 = 5.6e14.
//
// cos.c
//
// Circular cosine
//
// SYNOPSIS:
//
// double x, y, cos();
// y = cos( x );
//
// DESCRIPTION:
//
// Range reduction is into intervals of pi/4. The reduction error is nearly
// eliminated by contriving an extended precision modular arithmetic.
//
// Two polynomial approximating functions are employed.
// Between 0 and pi/4 the cosine is approximated by
// 1 - x**2 Q(x**2).
// Between pi/4 and pi/2 the sine is represented as
// x + x**3 P(x**2).
//
// ACCURACY:
//
// Relative error:
// arithmetic   domain            # trials   peak      rms
// IEEE        -1.07e9,+1.07e9    130000     2.1e-16   5.4e-17
// DEC          0,+1.07e9          17000     3.0e-17   7.2e-18
//
// Cephes Math Library Release 2.8: June, 2000
// Copyright 1984, 1987, 1989, 1992, 2000 by Stephen L. Moshier
//
// The readme file at http://netlib.sandia.gov/cephes/ says:
// Some software in this archive may be from the book _Methods and
// Programs for Mathematical Functions_ (Prentice-Hall or Simon & Schuster
// International, 1989) or from the Cephes Mathematical Library, a
// commercial product. In either event, it is copyrighted by the author.
// What you see here may be used freely but it comes with no support or
// guarantee.
//
// The two known misprints in the book are repaired here in the
// source listings for the gamma function and the incomplete beta
// integral.
//
// Stephen L. Moshier
// moshier@na-net.ornl.gov
// sin coefficients
var _sin = [...]float64{
1.58962301576546568060e-10, // 0x3de5d8fd1fd19ccd
-2.50507477628578072866e-8, // 0xbe5ae5e5a9291f5d
2.75573136213857245213e-6, // 0x3ec71de3567d48a1
-1.98412698295895385996e-4, // 0xbf2a01a019bfdf03
8.33333333332211858878e-3, // 0x3f8111111110f7d0
-1.66666666666666307295e-1, // 0xbfc5555555555548
}
// cos coefficients
var _cos = [...]float64{
-1.13585365213876817300e-11, // 0xbda8fa49a0861a9b
2.08757008419747316778e-9, // 0x3e21ee9d7b4e3f05
-2.75573141792967388112e-7, // 0xbe927e4f7eac4bc6
2.48015872888517045348e-5, // 0x3efa01a019c844f5
-1.38888888888730564116e-3, // 0xbf56c16c16c14f91
4.16666666666665929218e-2, // 0x3fa555555555554b
}
// Cos returns the cosine of the radian argument x.
//
// Special cases are:
//
// Cos(±Inf) = NaN
// Cos(NaN) = NaN
func Cos(x float64) float64 {
if haveArchCos {
return archCos(x)
}
return cos(x)
}
func cos(x float64) float64 {
const (
PI4A = 7.85398125648498535156e-1 // 0x3fe921fb40000000, Pi/4 split into three parts
PI4B = 3.77489470793079817668e-8 // 0x3e64442d00000000,
PI4C = 2.69515142907905952645e-15 // 0x3ce8469898cc5170,
)
// special cases
switch {
case IsNaN(x) || IsInf(x, 0):
return NaN()
}
// make argument positive
sign := false
x = Abs(x)
var j uint64
var y, z float64
if x >= reduceThreshold {
j, z = trigReduce(x)
} else {
j = uint64(x * (4 / Pi)) // integer part of x/(Pi/4), as integer for tests on the phase angle
y = float64(j) // integer part of x/(Pi/4), as float
// map zeros to origin
if j&1 == 1 {
j++
y++
}
j &= 7 // octant modulo 2Pi radians (360 degrees)
z = ((x - y*PI4A) - y*PI4B) - y*PI4C // Extended precision modular arithmetic
}
if j > 3 {
j -= 4
sign = !sign
}
if j > 1 {
sign = !sign
}
zz := z * z
if j == 1 || j == 2 {
y = z + z*zz*((((((_sin[0]*zz)+_sin[1])*zz+_sin[2])*zz+_sin[3])*zz+_sin[4])*zz+_sin[5])
} else {
y = 1.0 - 0.5*zz + zz*zz*((((((_cos[0]*zz)+_cos[1])*zz+_cos[2])*zz+_cos[3])*zz+_cos[4])*zz+_cos[5])
}
if sign {
y = -y
}
return y
}
// Sin returns the sine of the radian argument x.
//
// Special cases are:
//
// Sin(±0) = ±0
// Sin(±Inf) = NaN
// Sin(NaN) = NaN
func Sin(x float64) float64 {
if haveArchSin {
return archSin(x)
}
return sin(x)
}
func sin(x float64) float64 {
const (
PI4A = 7.85398125648498535156e-1 // 0x3fe921fb40000000, Pi/4 split into three parts
PI4B = 3.77489470793079817668e-8 // 0x3e64442d00000000,
PI4C = 2.69515142907905952645e-15 // 0x3ce8469898cc5170,
)
// special cases
switch {
case x == 0 || IsNaN(x):
return x // return ±0 || NaN()
case IsInf(x, 0):
return NaN()
}
// make argument positive but save the sign
sign := false
if x < 0 {
x = -x
sign = true
}
var j uint64
var y, z float64
if x >= reduceThreshold {
j, z = trigReduce(x)
} else {
j = uint64(x * (4 / Pi)) // integer part of x/(Pi/4), as integer for tests on the phase angle
y = float64(j) // integer part of x/(Pi/4), as float
// map zeros to origin
if j&1 == 1 {
j++
y++
}
j &= 7 // octant modulo 2Pi radians (360 degrees)
z = ((x - y*PI4A) - y*PI4B) - y*PI4C // Extended precision modular arithmetic
}
// reflect in x axis
if j > 3 {
sign = !sign
j -= 4
}
zz := z * z
if j == 1 || j == 2 {
y = 1.0 - 0.5*zz + zz*zz*((((((_cos[0]*zz)+_cos[1])*zz+_cos[2])*zz+_cos[3])*zz+_cos[4])*zz+_cos[5])
} else {
y = z + z*zz*((((((_sin[0]*zz)+_sin[1])*zz+_sin[2])*zz+_sin[3])*zz+_sin[4])*zz+_sin[5])
}
if sign {
y = -y
}
return y
}
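// A quick sanity sketch (not part of the original source): the octant
// reduction used by both sin and cos preserves the Pythagorean identity
// up to rounding.
//
//	x := 1.2345
//	s, c := math.Sin(x), math.Cos(x)
//	fmt.Println(s*s + c*c) // ≈ 1, up to rounding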
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
// Coefficients _sin[] and _cos[] are found in pkg/math/sin.go.
// Sincos returns Sin(x), Cos(x).
//
// Special cases are:
//
// Sincos(±0) = ±0, 1
// Sincos(±Inf) = NaN, NaN
// Sincos(NaN) = NaN, NaN
func Sincos(x float64) (sin, cos float64) {
const (
PI4A = 7.85398125648498535156e-1 // 0x3fe921fb40000000, Pi/4 split into three parts
PI4B = 3.77489470793079817668e-8 // 0x3e64442d00000000,
PI4C = 2.69515142907905952645e-15 // 0x3ce8469898cc5170,
)
// special cases
switch {
case x == 0:
return x, 1 // return ±0.0, 1.0
case IsNaN(x) || IsInf(x, 0):
return NaN(), NaN()
}
// make argument positive
sinSign, cosSign := false, false
if x < 0 {
x = -x
sinSign = true
}
var j uint64
var y, z float64
if x >= reduceThreshold {
j, z = trigReduce(x)
} else {
j = uint64(x * (4 / Pi)) // integer part of x/(Pi/4), as integer for tests on the phase angle
y = float64(j) // integer part of x/(Pi/4), as float
if j&1 == 1 { // map zeros to origin
j++
y++
}
j &= 7 // octant modulo 2Pi radians (360 degrees)
z = ((x - y*PI4A) - y*PI4B) - y*PI4C // Extended precision modular arithmetic
}
if j > 3 { // reflect in x axis
j -= 4
sinSign, cosSign = !sinSign, !cosSign
}
if j > 1 {
cosSign = !cosSign
}
zz := z * z
cos = 1.0 - 0.5*zz + zz*zz*((((((_cos[0]*zz)+_cos[1])*zz+_cos[2])*zz+_cos[3])*zz+_cos[4])*zz+_cos[5])
sin = z + z*zz*((((((_sin[0]*zz)+_sin[1])*zz+_sin[2])*zz+_sin[3])*zz+_sin[4])*zz+_sin[5])
if j == 1 || j == 2 {
sin, cos = cos, sin
}
if cosSign {
cos = -cos
}
if sinSign {
sin = -sin
}
return
}
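// Sincos exists so that callers needing both values pay for the argument
// reduction only once. A minimal sketch (not part of the original source):
//
//	sin, cos := math.Sincos(math.Pi / 3)
//	// sin ≈ 0.8660 (√3/2), cos ≈ 0.5, matching math.Sin and math.Cos.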
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
/*
Floating-point hyperbolic sine and cosine.
The exponential func is called for arguments
greater in magnitude than 0.5.
A series is used for arguments smaller in magnitude than 0.5.
Cosh(x) is computed from the exponential func for
all arguments.
*/
// Sinh returns the hyperbolic sine of x.
//
// Special cases are:
//
// Sinh(±0) = ±0
// Sinh(±Inf) = ±Inf
// Sinh(NaN) = NaN
func Sinh(x float64) float64 {
if haveArchSinh {
return archSinh(x)
}
return sinh(x)
}
func sinh(x float64) float64 {
// The coefficients are #2029 from Hart & Cheney. (20.36D)
const (
P0 = -0.6307673640497716991184787251e+6
P1 = -0.8991272022039509355398013511e+5
P2 = -0.2894211355989563807284660366e+4
P3 = -0.2630563213397497062819489e+2
Q0 = -0.6307673640497716991212077277e+6
Q1 = 0.1521517378790019070696485176e+5
Q2 = -0.173678953558233699533450911e+3
)
sign := false
if x < 0 {
x = -x
sign = true
}
var temp float64
switch {
case x > 21:
temp = Exp(x) * 0.5
case x > 0.5:
ex := Exp(x)
temp = (ex - 1/ex) * 0.5
default:
sq := x * x
temp = (((P3*sq+P2)*sq+P1)*sq + P0) * x
temp = temp / (((sq+Q2)*sq+Q1)*sq + Q0)
}
if sign {
temp = -temp
}
return temp
}
// Cosh returns the hyperbolic cosine of x.
//
// Special cases are:
//
// Cosh(±0) = 1
// Cosh(±Inf) = +Inf
// Cosh(NaN) = NaN
func Cosh(x float64) float64 {
if haveArchCosh {
return archCosh(x)
}
return cosh(x)
}
func cosh(x float64) float64 {
x = Abs(x)
if x > 21 {
return Exp(x) * 0.5
}
ex := Exp(x)
return (ex + 1/ex) * 0.5
}
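// A small check (not part of the original source): the defining identity
// cosh²(x) − sinh²(x) = 1 holds up to rounding across the branches used
// above, at least for moderate x.
//
//	for _, x := range []float64{0.1, 0.5, 1} {
//		c, s := math.Cosh(x), math.Sinh(x)
//		fmt.Println(c*c - s*s) // ≈ 1 (cancellation makes this unreliable for large x)
//	}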
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
// The original C code and the long comment below are
// from FreeBSD's /usr/src/lib/msun/src/e_sqrt.c and
// came with this notice. The go code is a simplified
// version of the original C.
//
// ====================================================
// Copyright (C) 1993 by Sun Microsystems, Inc. All rights reserved.
//
// Developed at SunPro, a Sun Microsystems, Inc. business.
// Permission to use, copy, modify, and distribute this
// software is freely granted, provided that this notice
// is preserved.
// ====================================================
//
// __ieee754_sqrt(x)
// Return correctly rounded sqrt.
// -----------------------------------------
// | Use the hardware sqrt if you have one |
// -----------------------------------------
// Method:
// Bit by bit method using integer arithmetic. (Slow, but portable)
// 1. Normalization
// Scale x to y in [1,4) with even powers of 2:
// find an integer k such that 1 <= (y=x*2**(2k)) < 4, then
// sqrt(x) = 2**k * sqrt(y)
// 2. Bit by bit computation
// Let q_i = sqrt(y) truncated to i bits after the binary point (q_0 = 1),
// and define
//
//	s_i = 2*q_i, and y_i = 2^(i+1) * (y - q_i^2).                (1)
//
// To compute q_(i+1) from q_i, one checks whether
//
//	(q_i + 2^-(i+1))^2 <= y.                                     (2)
//
// If (2) is false, then q_(i+1) = q_i; otherwise q_(i+1) = q_i + 2^-(i+1).
//
// With some algebraic manipulation, it is not difficult to see
// that (2) is equivalent to
//
//	s_i + 2^-(i+1) <= y_i                                        (3)
//
// The advantage of (3) is that s_i and y_i can be computed by
// the following recurrence formula:
// if (3) is false
//
//	s_(i+1) = s_i, y_(i+1) = y_i;                                (4)
//
// otherwise,
//
//	s_(i+1) = s_i + 2^-i, y_(i+1) = y_i - s_i - 2^-(i+1).        (5)
//
// One may easily use induction to prove (4) and (5).
// Note. Since the left hand side of (3) contains only i+2 bits,
// it is not necessary to do a full (53-bit) comparison
// in (3).
// 3. Final rounding
// After generating the 53-bit result, we compute one more bit.
// Together with the remainder, we can decide whether the
// result is exact, bigger than 1/2 ulp, or less than 1/2 ulp
// (it will never equal 1/2 ulp).
// The rounding mode can be detected by checking whether
// huge + tiny is equal to huge, and whether huge - tiny is
// equal to huge for some floating point number "huge" and "tiny".
//
//
// Notes: Rounding mode detection omitted. The constants "mask", "shift",
// and "bias" are found in src/math/bits.go
// Sqrt returns the square root of x.
//
// Special cases are:
//
// Sqrt(+Inf) = +Inf
// Sqrt(±0) = ±0
// Sqrt(x < 0) = NaN
// Sqrt(NaN) = NaN
func Sqrt(x float64) float64 {
return sqrt(x)
}
// Note: On systems where Sqrt is a single instruction, the compiler
// may turn a direct call into a direct use of that instruction instead.
func sqrt(x float64) float64 {
// special cases
switch {
case x == 0 || IsNaN(x) || IsInf(x, 1):
return x
case x < 0:
return NaN()
}
ix := Float64bits(x)
// normalize x
exp := int((ix >> shift) & mask)
if exp == 0 { // subnormal x
for ix&(1<<shift) == 0 {
ix <<= 1
exp--
}
exp++
}
exp -= bias // unbias exponent
ix &^= mask << shift
ix |= 1 << shift
if exp&1 == 1 { // odd exp, double x to make it even
ix <<= 1
}
exp >>= 1 // exp = exp/2, exponent of square root
// generate sqrt(x) bit by bit
ix <<= 1
var q, s uint64 // q = sqrt(x)
r := uint64(1 << (shift + 1)) // r = moving bit from MSB to LSB
for r != 0 {
t := s + r
if t <= ix {
s = t + r
ix -= t
q += r
}
ix <<= 1
r >>= 1
}
// final rounding
if ix != 0 { // remainder, result not exact
q += q & 1 // round according to extra bit
}
ix = q>>1 + uint64(exp-1+bias)<<shift // significand + biased exponent
return Float64frombits(ix)
}
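// The integer analogue of the shift-and-subtract loop above, as a
// self-contained sketch (hypothetical, not part of the original source):
// the same recurrence computes floor(sqrt(x)) one bit at a time.
//
//	func isqrt(x uint64) uint64 {
//		var res uint64
//		bit := uint64(1) << 62 // highest power of four representable
//		for bit > x {
//			bit >>= 2
//		}
//		for bit != 0 {
//			if x >= res+bit { // the comparison (3) above
//				x -= res + bit
//				res = res>>1 + bit // analogous to the s/q updates in (5)
//			} else {
//				res >>= 1
//			}
//			bit >>= 2
//		}
//		return res
//	}
//
//	// isqrt(10) == 3, isqrt(1<<62) == 1<<31.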
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !s390x
// This is a large group of functions that most architectures don't
// implement in assembly.
package math
const haveArchAcos = false
func archAcos(x float64) float64 {
panic("not implemented")
}
const haveArchAcosh = false
func archAcosh(x float64) float64 {
panic("not implemented")
}
const haveArchAsin = false
func archAsin(x float64) float64 {
panic("not implemented")
}
const haveArchAsinh = false
func archAsinh(x float64) float64 {
panic("not implemented")
}
const haveArchAtan = false
func archAtan(x float64) float64 {
panic("not implemented")
}
const haveArchAtan2 = false
func archAtan2(y, x float64) float64 {
panic("not implemented")
}
const haveArchAtanh = false
func archAtanh(x float64) float64 {
panic("not implemented")
}
const haveArchCbrt = false
func archCbrt(x float64) float64 {
panic("not implemented")
}
const haveArchCos = false
func archCos(x float64) float64 {
panic("not implemented")
}
const haveArchCosh = false
func archCosh(x float64) float64 {
panic("not implemented")
}
const haveArchErf = false
func archErf(x float64) float64 {
panic("not implemented")
}
const haveArchErfc = false
func archErfc(x float64) float64 {
panic("not implemented")
}
const haveArchExpm1 = false
func archExpm1(x float64) float64 {
panic("not implemented")
}
const haveArchFrexp = false
func archFrexp(x float64) (float64, int) {
panic("not implemented")
}
const haveArchLdexp = false
func archLdexp(frac float64, exp int) float64 {
panic("not implemented")
}
const haveArchLog10 = false
func archLog10(x float64) float64 {
panic("not implemented")
}
const haveArchLog2 = false
func archLog2(x float64) float64 {
panic("not implemented")
}
const haveArchLog1p = false
func archLog1p(x float64) float64 {
panic("not implemented")
}
const haveArchMod = false
func archMod(x, y float64) float64 {
panic("not implemented")
}
const haveArchPow = false
func archPow(x, y float64) float64 {
panic("not implemented")
}
const haveArchRemainder = false
func archRemainder(x, y float64) float64 {
panic("not implemented")
}
const haveArchSin = false
func archSin(x float64) float64 {
panic("not implemented")
}
const haveArchSinh = false
func archSinh(x float64) float64 {
panic("not implemented")
}
const haveArchTan = false
func archTan(x float64) float64 {
panic("not implemented")
}
const haveArchTanh = false
func archTanh(x float64) float64 {
panic("not implemented")
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
/*
Floating-point tangent.
*/
// The original C code, the long comment, and the constants
// below were from http://netlib.sandia.gov/cephes/cmath/sin.c,
// available from http://www.netlib.org/cephes/cmath.tgz.
// The go code is a simplified version of the original C.
//
// tan.c
//
// Circular tangent
//
// SYNOPSIS:
//
// double x, y, tan();
// y = tan( x );
//
// DESCRIPTION:
//
// Returns the circular tangent of the radian argument x.
//
// Range reduction is modulo pi/4. A rational function
// x + x**3 P(x**2)/Q(x**2)
// is employed in the basic interval [0, pi/4].
//
// ACCURACY:
// Relative error:
// arithmetic   domain     # trials   peak      rms
// DEC          +-1.07e9    44000     4.1e-17   1.0e-17
// IEEE         +-1.07e9    30000     2.9e-16   8.1e-17
//
// Partial loss of accuracy begins to occur at x = 2**30 = 1.074e9. The loss
// is not gradual, but jumps suddenly to about 1 part in 10e7. Results may
// be meaningless for x > 2**49 = 5.6e14.
// [Accuracy loss statement from sin.go comments.]
//
// Cephes Math Library Release 2.8: June, 2000
// Copyright 1984, 1987, 1989, 1992, 2000 by Stephen L. Moshier
//
// The readme file at http://netlib.sandia.gov/cephes/ says:
// Some software in this archive may be from the book _Methods and
// Programs for Mathematical Functions_ (Prentice-Hall or Simon & Schuster
// International, 1989) or from the Cephes Mathematical Library, a
// commercial product. In either event, it is copyrighted by the author.
// What you see here may be used freely but it comes with no support or
// guarantee.
//
// The two known misprints in the book are repaired here in the
// source listings for the gamma function and the incomplete beta
// integral.
//
// Stephen L. Moshier
// moshier@na-net.ornl.gov
// tan coefficients
var _tanP = [...]float64{
-1.30936939181383777646e4, // 0xc0c992d8d24f3f38
1.15351664838587416140e6, // 0x413199eca5fc9ddd
-1.79565251976484877988e7, // 0xc1711fead3299176
}
var _tanQ = [...]float64{
1.00000000000000000000e0,
1.36812963470692954678e4, // 0x40cab8a5eeb36572
-1.32089234440210967447e6, // 0xc13427bc582abc96
2.50083801823357915839e7, // 0x4177d98fc2ead8ef
-5.38695755929454629881e7, // 0xc189afe03cbe5a31
}
// Tan returns the tangent of the radian argument x.
//
// Special cases are:
//
// Tan(±0) = ±0
// Tan(±Inf) = NaN
// Tan(NaN) = NaN
func Tan(x float64) float64 {
if haveArchTan {
return archTan(x)
}
return tan(x)
}
func tan(x float64) float64 {
const (
PI4A = 7.85398125648498535156e-1 // 0x3fe921fb40000000, Pi/4 split into three parts
PI4B = 3.77489470793079817668e-8 // 0x3e64442d00000000,
PI4C = 2.69515142907905952645e-15 // 0x3ce8469898cc5170,
)
// special cases
switch {
case x == 0 || IsNaN(x):
return x // return ±0 || NaN()
case IsInf(x, 0):
return NaN()
}
// make argument positive but save the sign
sign := false
if x < 0 {
x = -x
sign = true
}
var j uint64
var y, z float64
if x >= reduceThreshold {
j, z = trigReduce(x)
} else {
j = uint64(x * (4 / Pi)) // integer part of x/(Pi/4), as integer for tests on the phase angle
y = float64(j) // integer part of x/(Pi/4), as float
/* map zeros and singularities to origin */
if j&1 == 1 {
j++
y++
}
z = ((x - y*PI4A) - y*PI4B) - y*PI4C
}
zz := z * z
if zz > 1e-14 {
y = z + z*(zz*(((_tanP[0]*zz)+_tanP[1])*zz+_tanP[2])/((((zz+_tanQ[1])*zz+_tanQ[2])*zz+_tanQ[3])*zz+_tanQ[4]))
} else {
y = z
}
if j&2 == 2 {
y = -1 / y
}
if sign {
y = -y
}
return y
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
// The original C code, the long comment, and the constants
// below were from http://netlib.sandia.gov/cephes/cmath/sin.c,
// available from http://www.netlib.org/cephes/cmath.tgz.
// The go code is a simplified version of the original C.
// tanh.c
//
// Hyperbolic tangent
//
// SYNOPSIS:
//
// double x, y, tanh();
//
// y = tanh( x );
//
// DESCRIPTION:
//
// Returns hyperbolic tangent of argument in the range MINLOG to MAXLOG.
// MAXLOG = 8.8029691931113054295988e+01 = log(2**127)
// MINLOG = -8.872283911167299960540e+01 = log(2**-128)
//
// A rational function is used for |x| < 0.625. The form
// x + x**3 P(x)/Q(x) of Cody & Waite is employed.
// Otherwise,
// tanh(x) = sinh(x)/cosh(x) = 1 - 2/(exp(2x) + 1).
//
// ACCURACY:
//
// Relative error:
// arithmetic   domain   # trials   peak      rms
// IEEE         -2,2      30000     2.5e-16   5.8e-17
//
// Cephes Math Library Release 2.8: June, 2000
// Copyright 1984, 1987, 1989, 1992, 2000 by Stephen L. Moshier
//
// The readme file at http://netlib.sandia.gov/cephes/ says:
// Some software in this archive may be from the book _Methods and
// Programs for Mathematical Functions_ (Prentice-Hall or Simon & Schuster
// International, 1989) or from the Cephes Mathematical Library, a
// commercial product. In either event, it is copyrighted by the author.
// What you see here may be used freely but it comes with no support or
// guarantee.
//
// The two known misprints in the book are repaired here in the
// source listings for the gamma function and the incomplete beta
// integral.
//
// Stephen L. Moshier
// moshier@na-net.ornl.gov
//
var tanhP = [...]float64{
-9.64399179425052238628e-1,
-9.92877231001918586564e1,
-1.61468768441708447952e3,
}
var tanhQ = [...]float64{
1.12811678491632931402e2,
2.23548839060100448583e3,
4.84406305325125486048e3,
}
// Tanh returns the hyperbolic tangent of x.
//
// Special cases are:
//
// Tanh(±0) = ±0
// Tanh(±Inf) = ±1
// Tanh(NaN) = NaN
func Tanh(x float64) float64 {
if haveArchTanh {
return archTanh(x)
}
return tanh(x)
}
func tanh(x float64) float64 {
const MAXLOG = 8.8029691931113054295988e+01 // log(2**127)
z := Abs(x)
switch {
case z > 0.5*MAXLOG:
if x < 0 {
return -1
}
return 1
case z >= 0.625:
s := Exp(2 * z)
z = 1 - 2/(s+1)
if x < 0 {
z = -z
}
default:
if x == 0 {
return x
}
s := x * x
z = x + x*s*((tanhP[0]*s+tanhP[1])*s+tanhP[2])/(((s+tanhQ[0])*s+tanhQ[1])*s+tanhQ[2])
}
return z
}
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
import (
"math/bits"
)
// reduceThreshold is the maximum value of x where the reduction using Pi/4
// in 3 float64 parts still gives accurate results. This threshold
// is set by y*C being representable as a float64 without error
// where y is given by y = floor(x * (4 / Pi)) and C is the leading partial
// terms of 4/Pi. Since the leading terms (PI4A and PI4B in sin.go) have 30
// and 32 trailing zero bits, y should have less than 30 significant bits.
//
// y < 1<<30 -> floor(x*4/Pi) < 1<<30 -> x < (1<<30 - 1) * Pi/4
//
// So, conservatively we can take x < 1<<29.
// Above this threshold Payne-Hanek range reduction must be used.
const reduceThreshold = 1 << 29
// trigReduce implements Payne-Hanek range reduction by Pi/4
// for x > 0. It returns the integer part mod 8 (j) and
// the fractional part (z) of x / (Pi/4).
// The implementation is based on:
// "ARGUMENT REDUCTION FOR HUGE ARGUMENTS: Good to the Last Bit"
// K. C. Ng et al, March 24, 1992
// The simulated multi-precision calculation of x*B uses 64-bit integer arithmetic.
func trigReduce(x float64) (j uint64, z float64) {
const PI4 = Pi / 4
if x < PI4 {
return 0, x
}
// Extract out the integer and exponent such that,
// x = ix * 2 ** exp.
ix := Float64bits(x)
exp := int(ix>>shift&mask) - bias - shift
ix &^= mask << shift
ix |= 1 << shift
// Use the exponent to extract the 3 appropriate uint64 digits from mPi4,
// B ~ (z0, z1, z2), such that the product leading digit has the exponent -61.
// Note, exp >= -53 since x >= PI4 and exp < 971 for maximum float64.
digit, bitshift := uint(exp+61)/64, uint(exp+61)%64
z0 := (mPi4[digit] << bitshift) | (mPi4[digit+1] >> (64 - bitshift))
z1 := (mPi4[digit+1] << bitshift) | (mPi4[digit+2] >> (64 - bitshift))
z2 := (mPi4[digit+2] << bitshift) | (mPi4[digit+3] >> (64 - bitshift))
// Multiply mantissa by the digits and extract the upper two digits (hi, lo).
z2hi, _ := bits.Mul64(z2, ix)
z1hi, z1lo := bits.Mul64(z1, ix)
z0lo := z0 * ix
lo, c := bits.Add64(z1lo, z2hi, 0)
hi, _ := bits.Add64(z0lo, z1hi, c)
// The top 3 bits are j.
j = hi >> 61
// Extract the fraction and find its magnitude.
hi = hi<<3 | lo>>61
lz := uint(bits.LeadingZeros64(hi))
e := uint64(bias - (lz + 1))
// Clear implicit mantissa bit and shift into place.
hi = (hi << (lz + 1)) | (lo >> (64 - (lz + 1)))
hi >>= 64 - shift
// Include the exponent and convert to a float.
hi |= e << shift
z = Float64frombits(hi)
// Map zeros to origin.
if j&1 == 1 {
j++
j &= 7
z--
}
// Multiply the fractional part by pi/4.
return j, z * PI4
}
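// The practical effect of the Payne-Hanek path (an illustration, not part
// of the original source): argument reduction stays accurate even when x
// is far above reduceThreshold (≈ 5.4e8), where the naive 3-part
// subtraction in sin.go would have lost every significant bit.
//
//	big := 1e18
//	fmt.Println(math.Sin(big)) // a well-defined value in [-1, 1]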
// mPi4 is the binary digits of 4/pi as a uint64 array,
// that is, 4/pi = Sum mPi4[i]*2^(-64*i)
// 19 64-bit digits and the leading one bit give 1217 bits
// of precision to handle the largest possible float64 exponent.
var mPi4 = [...]uint64{
0x0000000000000001,
0x45f306dc9c882a53,
0xf84eafa3ea69bb81,
0xb6c52b3278872083,
0xfca2c757bd778ac3,
0x6e48dc74849ba5c0,
0x0c925dd413a32439,
0xfc3bd63962534e7d,
0xd1046bea5d768909,
0xd338e04d68befc82,
0x7323ac7306a673e9,
0x3908bf177bf25076,
0x3ff12fffbc0b301f,
0xde5e2316b414da3e,
0xda6cfd9e4f96136e,
0x9e8c7ecd3cbfd45a,
0xea4f758fd7cbe2f6,
0x7a0e73ef14a525d4,
0xd7f6bf623f1aba10,
0xac06608df8f6d757,
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
import "unsafe"
// Float32bits returns the IEEE 754 binary representation of f,
// with the sign bit of f and the result in the same bit position.
// Float32bits(Float32frombits(x)) == x.
func Float32bits(f float32) uint32 { return *(*uint32)(unsafe.Pointer(&f)) }
// Float32frombits returns the floating-point number corresponding
// to the IEEE 754 binary representation b, with the sign bit of b
// and the result in the same bit position.
// Float32frombits(Float32bits(x)) == x.
func Float32frombits(b uint32) float32 { return *(*float32)(unsafe.Pointer(&b)) }
// Float64bits returns the IEEE 754 binary representation of f,
// with the sign bit of f and the result in the same bit position,
// and Float64bits(Float64frombits(x)) == x.
func Float64bits(f float64) uint64 { return *(*uint64)(unsafe.Pointer(&f)) }
// Float64frombits returns the floating-point number corresponding
// to the IEEE 754 binary representation b, with the sign bit of b
// and the result in the same bit position.
// Float64frombits(Float64bits(x)) == x.
func Float64frombits(b uint64) float64 { return *(*float64)(unsafe.Pointer(&b)) }
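// A small illustration (not part of the original source) of the IEEE 754
// layout these helpers expose: 1 sign bit, 11 exponent bits (bias 1023),
// and 52 fraction bits.
//
//	fmt.Printf("%#016x\n", math.Float64bits(1.0))  // 0x3ff0000000000000
//	fmt.Printf("%#016x\n", math.Float64bits(-2.0)) // 0xc000000000000000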
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mime
import (
"bytes"
"encoding/base64"
"errors"
"fmt"
"io"
"strings"
"unicode"
"unicode/utf8"
)
// A WordEncoder is an RFC 2047 encoded-word encoder.
type WordEncoder byte
const (
// BEncoding represents Base64 encoding scheme as defined by RFC 2045.
BEncoding = WordEncoder('b')
// QEncoding represents the Q-encoding scheme as defined by RFC 2047.
QEncoding = WordEncoder('q')
)
var (
errInvalidWord = errors.New("mime: invalid RFC 2047 encoded-word")
)
// Encode returns the encoded-word form of s. If s is ASCII without special
// characters, it is returned unchanged. The provided charset is the IANA
// charset name of s. It is case insensitive.
func (e WordEncoder) Encode(charset, s string) string {
if !needsEncoding(s) {
return s
}
return e.encodeWord(charset, s)
}
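// A usage sketch (not part of the original source), mirroring the
// package's documented examples:
//
//	fmt.Println(mime.QEncoding.Encode("utf-8", "¡Hola, señor!"))
//	// =?utf-8?q?=C2=A1Hola,_se=C3=B1or!?=
//	fmt.Println(mime.QEncoding.Encode("utf-8", "Hello!"))
//	// Hello! (plain ASCII is returned unchanged)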
func needsEncoding(s string) bool {
for _, b := range s {
if (b < ' ' || b > '~') && b != '\t' {
return true
}
}
return false
}
// encodeWord encodes a string into an encoded-word.
func (e WordEncoder) encodeWord(charset, s string) string {
var buf strings.Builder
// Could use a hint like len(s)*3, but that's not enough for cases
// with word splits and too much for simpler inputs.
// 48 is close to maxEncodedWordLen/2, but adjusted to allocator size class.
buf.Grow(48)
e.openWord(&buf, charset)
if e == BEncoding {
e.bEncode(&buf, charset, s)
} else {
e.qEncode(&buf, charset, s)
}
closeWord(&buf)
return buf.String()
}
const (
// The maximum length of an encoded-word is 75 characters.
// See RFC 2047, section 2.
maxEncodedWordLen = 75
// maxContentLen is how much content can be encoded, ignoring the header and
// 2-byte footer.
maxContentLen = maxEncodedWordLen - len("=?UTF-8?q?") - len("?=")
)
var maxBase64Len = base64.StdEncoding.DecodedLen(maxContentLen)
// bEncode encodes s using base64 encoding and writes it to buf.
func (e WordEncoder) bEncode(buf *strings.Builder, charset, s string) {
w := base64.NewEncoder(base64.StdEncoding, buf)
// If the charset is not UTF-8 or if the content is short, do not bother
// splitting the encoded-word.
if !isUTF8(charset) || base64.StdEncoding.EncodedLen(len(s)) <= maxContentLen {
io.WriteString(w, s)
w.Close()
return
}
var currentLen, last, runeLen int
for i := 0; i < len(s); i += runeLen {
// Multi-byte characters must not be split across encoded-words.
// See RFC 2047, section 5.3.
_, runeLen = utf8.DecodeRuneInString(s[i:])
if currentLen+runeLen <= maxBase64Len {
currentLen += runeLen
} else {
io.WriteString(w, s[last:i])
w.Close()
e.splitWord(buf, charset)
last = i
currentLen = runeLen
}
}
io.WriteString(w, s[last:])
w.Close()
}
// qEncode encodes s using Q encoding and writes it to buf. It splits the
// encoded-words when necessary.
func (e WordEncoder) qEncode(buf *strings.Builder, charset, s string) {
// We only split encoded-words when the charset is UTF-8.
if !isUTF8(charset) {
writeQString(buf, s)
return
}
var currentLen, runeLen int
for i := 0; i < len(s); i += runeLen {
b := s[i]
// Multi-byte characters must not be split across encoded-words.
// See RFC 2047, section 5.3.
var encLen int
if b >= ' ' && b <= '~' && b != '=' && b != '?' && b != '_' {
runeLen, encLen = 1, 1
} else {
_, runeLen = utf8.DecodeRuneInString(s[i:])
encLen = 3 * runeLen
}
if currentLen+encLen > maxContentLen {
e.splitWord(buf, charset)
currentLen = 0
}
writeQString(buf, s[i:i+runeLen])
currentLen += encLen
}
}
// writeQString encodes s using Q encoding and writes it to buf.
func writeQString(buf *strings.Builder, s string) {
for i := 0; i < len(s); i++ {
switch b := s[i]; {
case b == ' ':
buf.WriteByte('_')
case b >= '!' && b <= '~' && b != '=' && b != '?' && b != '_':
buf.WriteByte(b)
default:
buf.WriteByte('=')
buf.WriteByte(upperhex[b>>4])
buf.WriteByte(upperhex[b&0x0f])
}
}
}
// openWord writes the beginning of an encoded-word into buf.
func (e WordEncoder) openWord(buf *strings.Builder, charset string) {
buf.WriteString("=?")
buf.WriteString(charset)
buf.WriteByte('?')
buf.WriteByte(byte(e))
buf.WriteByte('?')
}
// closeWord writes the end of an encoded-word into buf.
func closeWord(buf *strings.Builder) {
buf.WriteString("?=")
}
// splitWord closes the current encoded-word and opens a new one.
func (e WordEncoder) splitWord(buf *strings.Builder, charset string) {
closeWord(buf)
buf.WriteByte(' ')
e.openWord(buf, charset)
}
func isUTF8(charset string) bool {
return strings.EqualFold(charset, "UTF-8")
}
const upperhex = "0123456789ABCDEF"
// A WordDecoder decodes MIME headers containing RFC 2047 encoded-words.
type WordDecoder struct {
// CharsetReader, if non-nil, defines a function to generate
// charset-conversion readers, converting from the provided
// charset into UTF-8.
// Charsets are always lower-case. utf-8, iso-8859-1 and us-ascii charsets
// are handled by default.
// One of the CharsetReader's result values must be non-nil.
CharsetReader func(charset string, input io.Reader) (io.Reader, error)
}
// Decode decodes an RFC 2047 encoded-word.
func (d *WordDecoder) Decode(word string) (string, error) {
// See https://tools.ietf.org/html/rfc2047#section-2 for details.
// Our decoder is permissive; we accept empty encoded-text.
if len(word) < 8 || !strings.HasPrefix(word, "=?") || !strings.HasSuffix(word, "?=") || strings.Count(word, "?") != 4 {
return "", errInvalidWord
}
word = word[2 : len(word)-2]
// split word "UTF-8?q?text" into "UTF-8", 'q', and "text"
charset, text, _ := strings.Cut(word, "?")
if charset == "" {
return "", errInvalidWord
}
encoding, text, _ := strings.Cut(text, "?")
if len(encoding) != 1 {
return "", errInvalidWord
}
content, err := decode(encoding[0], text)
if err != nil {
return "", err
}
var buf strings.Builder
if err := d.convert(&buf, charset, content); err != nil {
return "", err
}
return buf.String(), nil
}
// DecodeHeader decodes all encoded-words of the given string. It returns an
// error if and only if CharsetReader of d returns an error.
func (d *WordDecoder) DecodeHeader(header string) (string, error) {
// If there is no encoded-word, return early without allocating a buffer.
i := strings.Index(header, "=?")
if i == -1 {
return header, nil
}
var buf strings.Builder
buf.WriteString(header[:i])
header = header[i:]
betweenWords := false
for {
start := strings.Index(header, "=?")
if start == -1 {
break
}
cur := start + len("=?")
i := strings.Index(header[cur:], "?")
if i == -1 {
break
}
charset := header[cur : cur+i]
cur += i + len("?")
if len(header) < cur+len("Q??=") {
break
}
encoding := header[cur]
cur++
if header[cur] != '?' {
break
}
cur++
j := strings.Index(header[cur:], "?=")
if j == -1 {
break
}
text := header[cur : cur+j]
end := cur + j + len("?=")
content, err := decode(encoding, text)
if err != nil {
betweenWords = false
buf.WriteString(header[:start+2])
header = header[start+2:]
continue
}
// Write characters before the encoded-word. White-space and newline
// characters separating two encoded-words must be deleted.
if start > 0 && (!betweenWords || hasNonWhitespace(header[:start])) {
buf.WriteString(header[:start])
}
if err := d.convert(&buf, charset, content); err != nil {
return "", err
}
header = header[end:]
betweenWords = true
}
if len(header) > 0 {
buf.WriteString(header)
}
return buf.String(), nil
}
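// A usage sketch (not part of the original source), mirroring the
// package's documented example: whitespace between adjacent encoded-words
// is deleted.
//
//	dec := new(mime.WordDecoder)
//	header, err := dec.DecodeHeader("=?utf-8?q?=C2=A1Hola,?= =?utf-8?q?_se=C3=B1or!?=")
//	if err == nil {
//		fmt.Println(header) // ¡Hola, señor!
//	}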
func decode(encoding byte, text string) ([]byte, error) {
switch encoding {
case 'B', 'b':
return base64.StdEncoding.DecodeString(text)
case 'Q', 'q':
return qDecode(text)
default:
return nil, errInvalidWord
}
}
func (d *WordDecoder) convert(buf *strings.Builder, charset string, content []byte) error {
switch {
case strings.EqualFold("utf-8", charset):
buf.Write(content)
case strings.EqualFold("iso-8859-1", charset):
for _, c := range content {
buf.WriteRune(rune(c))
}
case strings.EqualFold("us-ascii", charset):
for _, c := range content {
if c >= utf8.RuneSelf {
buf.WriteRune(unicode.ReplacementChar)
} else {
buf.WriteByte(c)
}
}
default:
if d.CharsetReader == nil {
return fmt.Errorf("mime: unhandled charset %q", charset)
}
r, err := d.CharsetReader(strings.ToLower(charset), bytes.NewReader(content))
if err != nil {
return err
}
if _, err = io.Copy(buf, r); err != nil {
return err
}
}
return nil
}
// hasNonWhitespace reports whether s (assumed to be ASCII) contains at least
// one byte of non-whitespace.
func hasNonWhitespace(s string) bool {
for _, b := range s {
switch b {
// Encoded-words can only be separated by linear white spaces, which do
// not include vertical tabs (\v).
case ' ', '\t', '\n', '\r':
default:
return true
}
}
return false
}
// qDecode decodes a Q encoded string.
func qDecode(s string) ([]byte, error) {
dec := make([]byte, len(s))
n := 0
for i := 0; i < len(s); i++ {
switch c := s[i]; {
case c == '_':
dec[n] = ' '
case c == '=':
if i+2 >= len(s) {
return nil, errInvalidWord
}
b, err := readHexByte(s[i+1], s[i+2])
if err != nil {
return nil, err
}
dec[n] = b
i += 2
case (c <= '~' && c >= ' ') || c == '\n' || c == '\r' || c == '\t':
dec[n] = c
default:
return nil, errInvalidWord
}
n++
}
return dec[:n], nil
}
// readHexByte returns the byte from its quoted-printable representation.
func readHexByte(a, b byte) (byte, error) {
var hb, lb byte
var err error
if hb, err = fromHex(a); err != nil {
return 0, err
}
if lb, err = fromHex(b); err != nil {
return 0, err
}
return hb<<4 | lb, nil
}
func fromHex(b byte) (byte, error) {
switch {
case b >= '0' && b <= '9':
return b - '0', nil
case b >= 'A' && b <= 'F':
return b - 'A' + 10, nil
// Accept badly encoded bytes.
case b >= 'a' && b <= 'f':
return b - 'a' + 10, nil
}
return 0, fmt.Errorf("mime: invalid hex byte %#02x", b)
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mime
import (
"strings"
)
// isTSpecial reports whether rune is in 'tspecials' as defined by RFC
// 1521 and RFC 2045.
func isTSpecial(r rune) bool {
return strings.ContainsRune(`()<>@,;:\"/[]?=`, r)
}
// isTokenChar reports whether rune is in 'token' as defined by RFC
// 1521 and RFC 2045.
func isTokenChar(r rune) bool {
// token := 1*<any (US-ASCII) CHAR except SPACE, CTLs,
// or tspecials>
return r > 0x20 && r < 0x7f && !isTSpecial(r)
}
// isToken reports whether s is a 'token' as defined by RFC 1521
// and RFC 2045.
func isToken(s string) bool {
if s == "" {
return false
}
return strings.IndexFunc(s, isNotTokenChar) < 0
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mime
import (
"errors"
"fmt"
"sort"
"strings"
"unicode"
)
// FormatMediaType serializes mediatype t and the parameters
// param as a media type conforming to RFC 2045 and RFC 2616.
// The type and parameter names are written in lower-case.
// If any of the arguments results in a standard violation,
// FormatMediaType returns the empty string.
func FormatMediaType(t string, param map[string]string) string {
var b strings.Builder
if major, sub, ok := strings.Cut(t, "/"); !ok {
if !isToken(t) {
return ""
}
b.WriteString(strings.ToLower(t))
} else {
if !isToken(major) || !isToken(sub) {
return ""
}
b.WriteString(strings.ToLower(major))
b.WriteByte('/')
b.WriteString(strings.ToLower(sub))
}
attrs := make([]string, 0, len(param))
for a := range param {
attrs = append(attrs, a)
}
sort.Strings(attrs)
for _, attribute := range attrs {
value := param[attribute]
b.WriteByte(';')
b.WriteByte(' ')
if !isToken(attribute) {
return ""
}
b.WriteString(strings.ToLower(attribute))
needEnc := needsEncoding(value)
if needEnc {
// RFC 2231 section 4
b.WriteByte('*')
}
b.WriteByte('=')
if needEnc {
b.WriteString("utf-8''")
offset := 0
for index := 0; index < len(value); index++ {
ch := value[index]
// {RFC 2231 section 7}
// attribute-char := <any (US-ASCII) CHAR except SPACE, CTLs, "*", "'", "%", or tspecials>
if ch <= ' ' || ch >= 0x7F ||
ch == '*' || ch == '\'' || ch == '%' ||
isTSpecial(rune(ch)) {
b.WriteString(value[offset:index])
offset = index + 1
b.WriteByte('%')
b.WriteByte(upperhex[ch>>4])
b.WriteByte(upperhex[ch&0x0F])
}
}
b.WriteString(value[offset:])
continue
}
if isToken(value) {
b.WriteString(value)
continue
}
b.WriteByte('"')
offset := 0
for index := 0; index < len(value); index++ {
character := value[index]
if character == '"' || character == '\\' {
b.WriteString(value[offset:index])
offset = index
b.WriteByte('\\')
}
}
b.WriteString(value[offset:])
b.WriteByte('"')
}
return b.String()
}
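// A usage sketch (not part of the original source): the type, subtype, and
// attribute names are lowered, and token values stay unquoted.
//
//	mt := mime.FormatMediaType("text/HTML", map[string]string{"charset": "utf-8"})
//	fmt.Println(mt) // text/html; charset=utf-8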
func checkMediaTypeDisposition(s string) error {
typ, rest := consumeToken(s)
if typ == "" {
return errors.New("mime: no media type")
}
if rest == "" {
return nil
}
if !strings.HasPrefix(rest, "/") {
return errors.New("mime: expected slash after first token")
}
subtype, rest := consumeToken(rest[1:])
if subtype == "" {
return errors.New("mime: expected token after slash")
}
if rest != "" {
return errors.New("mime: unexpected content after media subtype")
}
return nil
}
// ErrInvalidMediaParameter is returned by ParseMediaType if
// the media type value was found but there was an error parsing
// the optional parameters
var ErrInvalidMediaParameter = errors.New("mime: invalid media parameter")
// ParseMediaType parses a media type value and any optional
// parameters, per RFC 1521. Media types are the values in
// Content-Type and Content-Disposition headers (RFC 2183).
// On success, ParseMediaType returns the media type converted
// to lowercase and trimmed of white space, and a non-nil map.
// If there is an error parsing the optional parameter,
// the media type will be returned along with the error
// ErrInvalidMediaParameter.
// The returned map, params, maps from the lowercase
// attribute to the attribute value with its case preserved.
func ParseMediaType(v string) (mediatype string, params map[string]string, err error) {
base, _, _ := strings.Cut(v, ";")
mediatype = strings.TrimSpace(strings.ToLower(base))
err = checkMediaTypeDisposition(mediatype)
if err != nil {
return "", nil, err
}
params = make(map[string]string)
// Map of base parameter name -> parameter name -> value
// for parameters containing a '*' character.
// Lazily initialized.
var continuation map[string]map[string]string
v = v[len(base):]
for len(v) > 0 {
v = strings.TrimLeftFunc(v, unicode.IsSpace)
if len(v) == 0 {
break
}
key, value, rest := consumeMediaParam(v)
if key == "" {
if strings.TrimSpace(rest) == ";" {
// Ignore trailing semicolons.
// Not an error.
break
}
// Parse error.
return mediatype, nil, ErrInvalidMediaParameter
}
pmap := params
if baseName, _, ok := strings.Cut(key, "*"); ok {
if continuation == nil {
continuation = make(map[string]map[string]string)
}
var ok bool
if pmap, ok = continuation[baseName]; !ok {
continuation[baseName] = make(map[string]string)
pmap = continuation[baseName]
}
}
if v, exists := pmap[key]; exists && v != value {
// Duplicate parameter names are incorrect, but we allow them if they are equal.
return "", nil, errors.New("mime: duplicate parameter name")
}
pmap[key] = value
v = rest
}
// Stitch together any continuations or things with stars
// (i.e. RFC 2231 things with stars: "foo*0" or "foo*")
var buf strings.Builder
for key, pieceMap := range continuation {
singlePartKey := key + "*"
if v, ok := pieceMap[singlePartKey]; ok {
if decv, ok := decode2231Enc(v); ok {
params[key] = decv
}
continue
}
buf.Reset()
valid := false
for n := 0; ; n++ {
simplePart := fmt.Sprintf("%s*%d", key, n)
if v, ok := pieceMap[simplePart]; ok {
valid = true
buf.WriteString(v)
continue
}
encodedPart := simplePart + "*"
v, ok := pieceMap[encodedPart]
if !ok {
break
}
valid = true
if n == 0 {
if decv, ok := decode2231Enc(v); ok {
buf.WriteString(decv)
}
} else {
decv, _ := percentHexUnescape(v)
buf.WriteString(decv)
}
}
if valid {
params[key] = buf.String()
}
}
return
}
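// A usage sketch (not part of the original source): the inverse of
// FormatMediaType, with the parameter value's case preserved.
//
//	mediatype, params, err := mime.ParseMediaType("Text/HTML; charset=utf-8")
//	if err == nil {
//		fmt.Println(mediatype)         // text/html
//		fmt.Println(params["charset"]) // utf-8
//	}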
func decode2231Enc(v string) (string, bool) {
sv := strings.SplitN(v, "'", 3)
if len(sv) != 3 {
return "", false
}
// TODO: ignoring lang in sv[1] for now. If anybody needs it we'll
// need to decide how to expose it in the API. But I'm not sure
// anybody uses it in practice.
charset := strings.ToLower(sv[0])
if len(charset) == 0 {
return "", false
}
if charset != "us-ascii" && charset != "utf-8" {
// TODO: unsupported encoding
return "", false
}
encv, err := percentHexUnescape(sv[2])
if err != nil {
return "", false
}
return encv, true
}
func isNotTokenChar(r rune) bool {
return !isTokenChar(r)
}
// consumeToken consumes a token from the beginning of the provided
// string, per RFC 2045 section 5.1 (referenced from 2183), and returns
// the token consumed and the rest of the string. Returns ("", v) on
// failure to consume at least one character.
func consumeToken(v string) (token, rest string) {
notPos := strings.IndexFunc(v, isNotTokenChar)
if notPos == -1 {
return v, ""
}
if notPos == 0 {
return "", v
}
return v[0:notPos], v[notPos:]
}
// consumeValue consumes a "value" per RFC 2045, where a value is
// either a 'token' or a 'quoted-string'. On success, consumeValue
// returns the value consumed (and de-quoted/escaped, if a
// quoted-string) and the rest of the string. On failure, returns
// ("", v).
func consumeValue(v string) (value, rest string) {
if v == "" {
return
}
if v[0] != '"' {
return consumeToken(v)
}
// parse a quoted-string
buffer := new(strings.Builder)
for i := 1; i < len(v); i++ {
r := v[i]
if r == '"' {
return buffer.String(), v[i+1:]
}
// When MSIE sends a full file path (in "intranet mode"), it does not
// escape backslashes: "C:\dev\go\foo.txt", not "C:\\dev\\go\\foo.txt".
//
// No known MIME generators emit unnecessary backslash escapes
// for simple token characters like numbers and letters.
//
// If we see an unnecessary backslash escape, assume it is from MSIE
// and intended as a literal backslash. This makes Go servers deal better
// with MSIE without affecting the way they handle conforming MIME
// generators.
if r == '\\' && i+1 < len(v) && isTSpecial(rune(v[i+1])) {
buffer.WriteByte(v[i+1])
i++
continue
}
if r == '\r' || r == '\n' {
return "", v
}
buffer.WriteByte(v[i])
}
// Did not find end quote.
return "", v
}
func consumeMediaParam(v string) (param, value, rest string) {
rest = strings.TrimLeftFunc(v, unicode.IsSpace)
if !strings.HasPrefix(rest, ";") {
return "", "", v
}
rest = rest[1:] // consume semicolon
rest = strings.TrimLeftFunc(rest, unicode.IsSpace)
param, rest = consumeToken(rest)
param = strings.ToLower(param)
if param == "" {
return "", "", v
}
rest = strings.TrimLeftFunc(rest, unicode.IsSpace)
if !strings.HasPrefix(rest, "=") {
return "", "", v
}
rest = rest[1:] // consume equals sign
rest = strings.TrimLeftFunc(rest, unicode.IsSpace)
value, rest2 := consumeValue(rest)
if value == "" && rest2 == rest {
return "", "", v
}
rest = rest2
return param, value, rest
}
func percentHexUnescape(s string) (string, error) {
// Count %, check that they're well-formed.
percents := 0
for i := 0; i < len(s); {
if s[i] != '%' {
i++
continue
}
percents++
if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) {
s = s[i:]
if len(s) > 3 {
s = s[0:3]
}
return "", fmt.Errorf("mime: bogus characters after %%: %q", s)
}
i += 3
}
if percents == 0 {
return s, nil
}
t := make([]byte, len(s)-2*percents)
j := 0
for i := 0; i < len(s); {
switch s[i] {
case '%':
t[j] = unhex(s[i+1])<<4 | unhex(s[i+2])
j++
i += 3
default:
t[j] = s[i]
j++
i++
}
}
return string(t), nil
}
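// Worked examples (editor's note, not part of the original source):
//
//	percentHexUnescape("a%20b") // "a b", nil
//	percentHexUnescape("a%2")   // "", error: bogus characters after %
//	percentHexUnescape("a%zz")  // "", error: bogus characters after %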
func ishex(c byte) bool {
switch {
case '0' <= c && c <= '9':
return true
case 'a' <= c && c <= 'f':
return true
case 'A' <= c && c <= 'F':
return true
}
return false
}
func unhex(c byte) byte {
switch {
case '0' <= c && c <= '9':
return c - '0'
case 'a' <= c && c <= 'f':
return c - 'a' + 10
case 'A' <= c && c <= 'F':
return c - 'A' + 10
}
return 0
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package multipart
import (
"bytes"
"errors"
"internal/godebug"
"io"
"math"
"net/textproto"
"os"
)
// ErrMessageTooLarge is returned by ReadForm if the message form
// data is too large to be processed.
var ErrMessageTooLarge = errors.New("multipart: message too large")
// TODO(adg,bradfitz): find a way to unify the DoS-prevention strategy here
// with that of the http package's ParseForm.
// ReadForm parses an entire multipart message whose parts have
// a Content-Disposition of "form-data".
// It stores up to maxMemory bytes + 10MB (reserved for non-file parts)
// in memory. File parts which can't be stored in memory will be stored on
// disk in temporary files.
// It returns ErrMessageTooLarge if all non-file parts can't be stored in
// memory.
func (r *Reader) ReadForm(maxMemory int64) (*Form, error) {
return r.readForm(maxMemory)
}
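// Usage sketch (editor's example, not part of the original source),
// assuming body is an io.Reader over a multipart message and boundary
// came from the Content-Type header:
//
//	mr := multipart.NewReader(body, boundary)
//	form, err := mr.ReadForm(32 << 20) // keep up to 32 MB of file data in memory
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer form.RemoveAll() // clean up any temporary files
//	fmt.Println(form.Value["comment"])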
var multipartFiles = godebug.New("multipartfiles")
func (r *Reader) readForm(maxMemory int64) (_ *Form, err error) {
form := &Form{make(map[string][]string), make(map[string][]*FileHeader)}
var (
file *os.File
fileOff int64
)
numDiskFiles := 0
combineFiles := multipartFiles.Value() != "distinct"
defer func() {
if file != nil {
if cerr := file.Close(); err == nil {
err = cerr
}
}
if combineFiles && numDiskFiles > 1 {
for _, fhs := range form.File {
for _, fh := range fhs {
fh.tmpshared = true
}
}
}
if err != nil {
form.RemoveAll()
if file != nil {
os.Remove(file.Name())
}
}
}()
// maxFileMemoryBytes is the maximum bytes of file data we will store in memory.
// Data past this limit is written to disk.
// This limit strictly applies to content, not metadata (filenames, MIME headers, etc.),
// since metadata is always stored in memory, not disk.
//
// maxMemoryBytes is the maximum bytes we will store in memory, including file content,
// non-file part values, metadata, and map entry overhead.
//
// We reserve an additional 10 MB in maxMemoryBytes for non-file data.
//
// The relationship between these parameters, as well as the overly-large and
// unconfigurable 10 MB added on to maxMemory, is unfortunate but difficult to change
// within the constraints of the API as documented.
maxFileMemoryBytes := maxMemory
if maxFileMemoryBytes == math.MaxInt64 {
maxFileMemoryBytes--
}
maxMemoryBytes := maxMemory + int64(10<<20)
if maxMemoryBytes <= 0 {
if maxMemory < 0 {
maxMemoryBytes = 0
} else {
maxMemoryBytes = math.MaxInt64
}
}
for {
p, err := r.nextPart(false, maxMemoryBytes)
if err == io.EOF {
break
}
if err != nil {
return nil, err
}
name := p.FormName()
if name == "" {
continue
}
filename := p.FileName()
// Multiple values for the same key (one map entry, longer slice) are cheaper
// than the same number of values for different keys (many map entries), but
// using a consistent per-value cost for overhead is simpler.
maxMemoryBytes -= int64(len(name))
maxMemoryBytes -= 100 // map overhead
if maxMemoryBytes < 0 {
// We can't actually take this path, since nextPart would already have
// rejected the MIME headers for being too large. Check anyway.
return nil, ErrMessageTooLarge
}
var b bytes.Buffer
if filename == "" {
// value, store as string in memory
n, err := io.CopyN(&b, p, maxMemoryBytes+1)
if err != nil && err != io.EOF {
return nil, err
}
maxMemoryBytes -= n
if maxMemoryBytes < 0 {
return nil, ErrMessageTooLarge
}
form.Value[name] = append(form.Value[name], b.String())
continue
}
// file, store in memory or on disk
maxMemoryBytes -= mimeHeaderSize(p.Header)
if maxMemoryBytes < 0 {
return nil, ErrMessageTooLarge
}
fh := &FileHeader{
Filename: filename,
Header: p.Header,
}
n, err := io.CopyN(&b, p, maxFileMemoryBytes+1)
if err != nil && err != io.EOF {
return nil, err
}
if n > maxFileMemoryBytes {
if file == nil {
file, err = os.CreateTemp(r.tempDir, "multipart-")
if err != nil {
return nil, err
}
}
numDiskFiles++
size, err := io.Copy(file, io.MultiReader(&b, p))
if err != nil {
return nil, err
}
fh.tmpfile = file.Name()
fh.Size = size
fh.tmpoff = fileOff
fileOff += size
if !combineFiles {
if err := file.Close(); err != nil {
return nil, err
}
file = nil
}
} else {
fh.content = b.Bytes()
fh.Size = int64(len(fh.content))
maxFileMemoryBytes -= n
maxMemoryBytes -= n
}
form.File[name] = append(form.File[name], fh)
}
return form, nil
}
func mimeHeaderSize(h textproto.MIMEHeader) (size int64) {
for k, vs := range h {
size += int64(len(k))
size += 100 // map entry overhead
for _, v := range vs {
size += int64(len(v))
}
}
return size
}
// Form is a parsed multipart form.
// Its File parts are stored either in memory or on disk,
// and are accessible via the *FileHeader's Open method.
// Its Value parts are stored as strings.
// Both are keyed by field name.
type Form struct {
Value map[string][]string
File map[string][]*FileHeader
}
// RemoveAll removes any temporary files associated with a Form.
func (f *Form) RemoveAll() error {
var err error
for _, fhs := range f.File {
for _, fh := range fhs {
if fh.tmpfile != "" {
e := os.Remove(fh.tmpfile)
if e != nil && !errors.Is(e, os.ErrNotExist) && err == nil {
err = e
}
}
}
}
return err
}
// A FileHeader describes a file part of a multipart request.
type FileHeader struct {
Filename string
Header textproto.MIMEHeader
Size int64
content []byte
tmpfile string
tmpoff int64
tmpshared bool
}
// Open opens and returns the FileHeader's associated File.
func (fh *FileHeader) Open() (File, error) {
if b := fh.content; b != nil {
r := io.NewSectionReader(bytes.NewReader(b), 0, int64(len(b)))
return sectionReadCloser{r, nil}, nil
}
if fh.tmpshared {
f, err := os.Open(fh.tmpfile)
if err != nil {
return nil, err
}
r := io.NewSectionReader(f, fh.tmpoff, fh.Size)
return sectionReadCloser{r, f}, nil
}
return os.Open(fh.tmpfile)
}
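// Usage sketch (editor's example, not part of the original source):
// draining an uploaded file regardless of whether it was kept in memory
// or spilled to disk, assuming form was returned by ReadForm; the
// "upload" field name is hypothetical.
//
//	for _, fh := range form.File["upload"] {
//		f, err := fh.Open()
//		if err != nil {
//			log.Fatal(err)
//		}
//		n, _ := io.Copy(io.Discard, f)
//		f.Close()
//		fmt.Println(fh.Filename, n, "bytes")
//	}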
// File is an interface to access the file part of a multipart message.
// Its contents may be either stored in memory or on disk.
// If stored on disk, the File's underlying concrete type will be an *os.File.
type File interface {
io.Reader
io.ReaderAt
io.Seeker
io.Closer
}
// helper types to turn a []byte into a File
type sectionReadCloser struct {
*io.SectionReader
io.Closer
}
func (rc sectionReadCloser) Close() error {
if rc.Closer != nil {
return rc.Closer.Close()
}
return nil
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//
/*
Package multipart implements MIME multipart parsing, as defined in RFC
2046.
The implementation is sufficient for HTTP (RFC 2388) and the multipart
bodies generated by popular browsers.
*/
package multipart
import (
"bufio"
"bytes"
"fmt"
"io"
"mime"
"mime/quotedprintable"
"net/textproto"
"path/filepath"
"strings"
)
var emptyParams = make(map[string]string)
// This constant needs to be at least 76 for this package to work correctly.
// This is because \r\n--separator_of_len_70- would fill the buffer and it
// wouldn't be safe to consume a single byte from it.
const peekBufferSize = 4096
// A Part represents a single part in a multipart body.
type Part struct {
// The headers of the body, if any, with the keys canonicalized
// in the same fashion that the Go http.Request headers are.
// For example, "foo-bar" changes case to "Foo-Bar".
Header textproto.MIMEHeader
mr *Reader
disposition string
dispositionParams map[string]string
// r is either a reader directly reading from mr, or it's a
// wrapper around such a reader, decoding the
// Content-Transfer-Encoding
r io.Reader
n int // known data bytes waiting in mr.bufReader
total int64 // total data bytes read already
err error // error to return when n == 0
readErr error // read error observed from mr.bufReader
}
// FormName returns the name parameter if p has a Content-Disposition
// of type "form-data". Otherwise it returns the empty string.
func (p *Part) FormName() string {
// See https://tools.ietf.org/html/rfc2183 section 2 for EBNF
// of Content-Disposition value format.
if p.dispositionParams == nil {
p.parseContentDisposition()
}
if p.disposition != "form-data" {
return ""
}
return p.dispositionParams["name"]
}
// FileName returns the filename parameter of the Part's Content-Disposition
// header. If not empty, the filename is passed through filepath.Base (which is
// platform dependent) before being returned.
func (p *Part) FileName() string {
if p.dispositionParams == nil {
p.parseContentDisposition()
}
filename := p.dispositionParams["filename"]
if filename == "" {
return ""
}
// RFC 7578, Section 4.2 requires that if a filename is provided, the
// directory path information must not be used.
return filepath.Base(filename)
}
func (p *Part) parseContentDisposition() {
v := p.Header.Get("Content-Disposition")
var err error
p.disposition, p.dispositionParams, err = mime.ParseMediaType(v)
if err != nil {
p.dispositionParams = emptyParams
}
}
// NewReader creates a new multipart Reader reading from r using the
// given MIME boundary.
//
// The boundary is usually obtained from the "boundary" parameter of
// the message's "Content-Type" header. Use mime.ParseMediaType to
// parse such headers.
func NewReader(r io.Reader, boundary string) *Reader {
b := []byte("\r\n--" + boundary + "--")
return &Reader{
bufReader: bufio.NewReaderSize(&stickyErrorReader{r: r}, peekBufferSize),
nl: b[:2],
nlDashBoundary: b[:len(b)-2],
dashBoundaryDash: b[2:],
dashBoundary: b[2 : len(b)-2],
}
}
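// Usage sketch (editor's example, not part of the original source):
// iterating over the parts of a multipart body, with the boundary taken
// from a previous mime.ParseMediaType call on the Content-Type header.
//
//	mr := multipart.NewReader(body, params["boundary"])
//	for {
//		p, err := mr.NextPart()
//		if err == io.EOF {
//			break
//		}
//		if err != nil {
//			log.Fatal(err)
//		}
//		data, _ := io.ReadAll(p)
//		fmt.Printf("part %q: %d bytes\n", p.FormName(), len(data))
//	}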
// stickyErrorReader is an io.Reader which never calls Read on its
// underlying Reader once an error has been seen. (the io.Reader
// interface's contract promises nothing about the return values of
// Read calls after an error, yet this package does do multiple Reads
// after error)
type stickyErrorReader struct {
r io.Reader
err error
}
func (r *stickyErrorReader) Read(p []byte) (n int, _ error) {
if r.err != nil {
return 0, r.err
}
n, r.err = r.r.Read(p)
return n, r.err
}
func newPart(mr *Reader, rawPart bool, maxMIMEHeaderSize int64) (*Part, error) {
bp := &Part{
Header: make(map[string][]string),
mr: mr,
}
if err := bp.populateHeaders(maxMIMEHeaderSize); err != nil {
return nil, err
}
bp.r = partReader{bp}
// rawPart is used to switch between Part.NextPart and Part.NextRawPart.
if !rawPart {
const cte = "Content-Transfer-Encoding"
if strings.EqualFold(bp.Header.Get(cte), "quoted-printable") {
bp.Header.Del(cte)
bp.r = quotedprintable.NewReader(bp.r)
}
}
return bp, nil
}
func (p *Part) populateHeaders(maxMIMEHeaderSize int64) error {
r := textproto.NewReader(p.mr.bufReader)
header, err := readMIMEHeader(r, maxMIMEHeaderSize)
if err == nil {
p.Header = header
}
// TODO: Add a distinguishable error to net/textproto.
if err != nil && err.Error() == "message too large" {
err = ErrMessageTooLarge
}
return err
}
// Read reads the body of a part, after its headers and before the
// next part (if any) begins.
func (p *Part) Read(d []byte) (n int, err error) {
return p.r.Read(d)
}
// partReader implements io.Reader by reading raw bytes directly from the
// wrapped *Part, without doing any Transfer-Encoding decoding.
type partReader struct {
p *Part
}
func (pr partReader) Read(d []byte) (int, error) {
p := pr.p
br := p.mr.bufReader
// Read into buffer until we identify some data to return,
// or we find a reason to stop (boundary or read error).
for p.n == 0 && p.err == nil {
peek, _ := br.Peek(br.Buffered())
p.n, p.err = scanUntilBoundary(peek, p.mr.dashBoundary, p.mr.nlDashBoundary, p.total, p.readErr)
if p.n == 0 && p.err == nil {
// Force buffered I/O to read more into buffer.
_, p.readErr = br.Peek(len(peek) + 1)
if p.readErr == io.EOF {
p.readErr = io.ErrUnexpectedEOF
}
}
}
// Read out from "data to return" part of buffer.
if p.n == 0 {
return 0, p.err
}
n := len(d)
if n > p.n {
n = p.n
}
n, _ = br.Read(d[:n])
p.total += int64(n)
p.n -= n
if p.n == 0 {
return n, p.err
}
return n, nil
}
// scanUntilBoundary scans buf to identify how much of it can be safely
// returned as part of the Part body.
// dashBoundary is "--boundary".
// nlDashBoundary is "\r\n--boundary" or "\n--boundary", depending on what mode we are in.
// The comments below (and the name) assume "\n--boundary", but either is accepted.
// total is the number of bytes read out so far. If total == 0, then a leading "--boundary" is recognized.
// readErr is the read error, if any, that followed reading the bytes in buf.
// scanUntilBoundary returns the number of data bytes from buf that can be
// returned as part of the Part body and also the error to return (if any)
// once those data bytes are done.
func scanUntilBoundary(buf, dashBoundary, nlDashBoundary []byte, total int64, readErr error) (int, error) {
if total == 0 {
// At beginning of body, allow dashBoundary.
if bytes.HasPrefix(buf, dashBoundary) {
switch matchAfterPrefix(buf, dashBoundary, readErr) {
case -1:
return len(dashBoundary), nil
case 0:
return 0, nil
case +1:
return 0, io.EOF
}
}
if bytes.HasPrefix(dashBoundary, buf) {
return 0, readErr
}
}
// Search for "\n--boundary".
if i := bytes.Index(buf, nlDashBoundary); i >= 0 {
switch matchAfterPrefix(buf[i:], nlDashBoundary, readErr) {
case -1:
return i + len(nlDashBoundary), nil
case 0:
return i, nil
case +1:
return i, io.EOF
}
}
if bytes.HasPrefix(nlDashBoundary, buf) {
return 0, readErr
}
// Otherwise, anything up to the final \n is not part of the boundary
// and so must be part of the body.
// Also if the section from the final \n onward is not a prefix of the boundary,
// it too must be part of the body.
i := bytes.LastIndexByte(buf, nlDashBoundary[0])
if i >= 0 && bytes.HasPrefix(nlDashBoundary, buf[i:]) {
return i, nil
}
return len(buf), readErr
}
// matchAfterPrefix checks whether buf should be considered to match the boundary.
// The prefix is "--boundary" or "\r\n--boundary" or "\n--boundary",
// and the caller has verified already that bytes.HasPrefix(buf, prefix) is true.
//
// matchAfterPrefix returns +1 if the buffer does match the boundary,
// meaning the prefix is followed by a double dash, space, tab, cr, nl,
// or end of input.
// It returns -1 if the buffer definitely does NOT match the boundary,
// meaning the prefix is followed by some other character.
// For example, "--foobar" does not match "--foo".
// It returns 0 if more input needs to be read to make the decision,
// meaning that len(buf) == len(prefix) and readErr == nil.
func matchAfterPrefix(buf, prefix []byte, readErr error) int {
if len(buf) == len(prefix) {
if readErr != nil {
return +1
}
return 0
}
c := buf[len(prefix)]
if c == ' ' || c == '\t' || c == '\r' || c == '\n' {
return +1
}
// Try to detect boundaryDash
if c == '-' {
if len(buf) == len(prefix)+1 {
if readErr != nil {
// Prefix + "-" does not match
return -1
}
return 0
}
if buf[len(prefix)+1] == '-' {
return +1
}
}
return -1
}
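// Worked examples (editor's note, not part of the original source),
// with prefix "--foo" and no pending read error:
//
//	matchAfterPrefix([]byte("--foo--"), []byte("--foo"), nil)   // +1: followed by "--"
//	matchAfterPrefix([]byte("--foo\r\n"), []byte("--foo"), nil) // +1: followed by CR
//	matchAfterPrefix([]byte("--foobar"), []byte("--foo"), nil)  // -1: a different, longer boundary
//	matchAfterPrefix([]byte("--foo"), []byte("--foo"), nil)     //  0: need more input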
func (p *Part) Close() error {
io.Copy(io.Discard, p)
return nil
}
// Reader is an iterator over parts in a MIME multipart body.
// Reader's underlying parser consumes its input as needed. Seeking
// isn't supported.
type Reader struct {
bufReader *bufio.Reader
tempDir string // used in tests
currentPart *Part
partsRead int
nl []byte // "\r\n" or "\n" (set after seeing first boundary line)
nlDashBoundary []byte // nl + "--boundary"
dashBoundaryDash []byte // "--boundary--"
dashBoundary []byte // "--boundary"
}
// maxMIMEHeaderSize is the maximum size of a MIME header we will parse,
// including header keys, values, and map overhead.
const maxMIMEHeaderSize = 10 << 20
// NextPart returns the next part in the multipart or an error.
// When there are no more parts, the error io.EOF is returned.
//
// As a special case, if the "Content-Transfer-Encoding" header
// has a value of "quoted-printable", that header is instead
// hidden and the body is transparently decoded during Read calls.
func (r *Reader) NextPart() (*Part, error) {
return r.nextPart(false, maxMIMEHeaderSize)
}
// NextRawPart returns the next part in the multipart or an error.
// When there are no more parts, the error io.EOF is returned.
//
// Unlike NextPart, it does not have special handling for
// "Content-Transfer-Encoding: quoted-printable".
func (r *Reader) NextRawPart() (*Part, error) {
return r.nextPart(true, maxMIMEHeaderSize)
}
func (r *Reader) nextPart(rawPart bool, maxMIMEHeaderSize int64) (*Part, error) {
if r.currentPart != nil {
r.currentPart.Close()
}
if string(r.dashBoundary) == "--" {
return nil, fmt.Errorf("multipart: boundary is empty")
}
expectNewPart := false
for {
line, err := r.bufReader.ReadSlice('\n')
if err == io.EOF && r.isFinalBoundary(line) {
// If the buffer ends in "--boundary--" without the
// trailing "\r\n", ReadSlice will return an error
// (since it's missing the '\n'), but this is a valid
// multipart EOF so we need to return io.EOF instead of
// a fmt-wrapped one.
return nil, io.EOF
}
if err != nil {
return nil, fmt.Errorf("multipart: NextPart: %w", err)
}
if r.isBoundaryDelimiterLine(line) {
r.partsRead++
bp, err := newPart(r, rawPart, maxMIMEHeaderSize)
if err != nil {
return nil, err
}
r.currentPart = bp
return bp, nil
}
if r.isFinalBoundary(line) {
// Expected EOF
return nil, io.EOF
}
if expectNewPart {
return nil, fmt.Errorf("multipart: expecting a new Part; got line %q", string(line))
}
if r.partsRead == 0 {
// skip line
continue
}
// Consume the "\n" or "\r\n" separator between the
// body of the previous part and the boundary line we
// now expect will follow. (either a new part or the
// end boundary)
if bytes.Equal(line, r.nl) {
expectNewPart = true
continue
}
return nil, fmt.Errorf("multipart: unexpected line in Next(): %q", line)
}
}
// isFinalBoundary reports whether line is the final boundary line
// indicating that all parts are over.
// It matches `^--boundary--[ \t]*(\r\n)?$`
func (r *Reader) isFinalBoundary(line []byte) bool {
if !bytes.HasPrefix(line, r.dashBoundaryDash) {
return false
}
rest := line[len(r.dashBoundaryDash):]
rest = skipLWSPChar(rest)
return len(rest) == 0 || bytes.Equal(rest, r.nl)
}
func (r *Reader) isBoundaryDelimiterLine(line []byte) (ret bool) {
// https://tools.ietf.org/html/rfc2046#section-5.1
// The boundary delimiter line is then defined as a line
// consisting entirely of two hyphen characters ("-",
// decimal value 45) followed by the boundary parameter
// value from the Content-Type header field, optional linear
// whitespace, and a terminating CRLF.
if !bytes.HasPrefix(line, r.dashBoundary) {
return false
}
rest := line[len(r.dashBoundary):]
rest = skipLWSPChar(rest)
// On the first part, check whether our lines end in \n instead of \r\n
// and switch into that mode if so. This is a violation of the spec,
// but occurs in practice.
if r.partsRead == 0 && len(rest) == 1 && rest[0] == '\n' {
r.nl = r.nl[1:]
r.nlDashBoundary = r.nlDashBoundary[1:]
}
return bytes.Equal(rest, r.nl)
}
// skipLWSPChar returns b with leading spaces and tabs removed.
// RFC 822 defines:
//
// LWSP-char = SPACE / HTAB
func skipLWSPChar(b []byte) []byte {
for len(b) > 0 && (b[0] == ' ' || b[0] == '\t') {
b = b[1:]
}
return b
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package multipart
import (
"bytes"
"crypto/rand"
"errors"
"fmt"
"io"
"net/textproto"
"sort"
"strings"
)
// A Writer generates multipart messages.
type Writer struct {
w io.Writer
boundary string
lastpart *part
}
// NewWriter returns a new multipart Writer with a random boundary,
// writing to w.
func NewWriter(w io.Writer) *Writer {
return &Writer{
w: w,
boundary: randomBoundary(),
}
}
// Boundary returns the Writer's boundary.
func (w *Writer) Boundary() string {
return w.boundary
}
// SetBoundary overrides the Writer's default randomly-generated
// boundary separator with an explicit value.
//
// SetBoundary must be called before any parts are created. The boundary
// may only contain certain ASCII characters, and must be non-empty and
// at most 70 bytes long.
func (w *Writer) SetBoundary(boundary string) error {
if w.lastpart != nil {
return errors.New("mime: SetBoundary called after write")
}
// rfc2046#section-5.1.1
if len(boundary) < 1 || len(boundary) > 70 {
return errors.New("mime: invalid boundary length")
}
end := len(boundary) - 1
for i, b := range boundary {
if 'A' <= b && b <= 'Z' || 'a' <= b && b <= 'z' || '0' <= b && b <= '9' {
continue
}
switch b {
case '\'', '(', ')', '+', '_', ',', '-', '.', '/', ':', '=', '?':
continue
case ' ':
if i != end {
continue
}
}
return errors.New("mime: invalid boundary character")
}
w.boundary = boundary
return nil
}
// FormDataContentType returns the Content-Type for an HTTP
// multipart/form-data with this Writer's Boundary.
func (w *Writer) FormDataContentType() string {
b := w.boundary
// We must quote the boundary if it contains any of the
// tspecials characters defined by RFC 2045, or space.
if strings.ContainsAny(b, `()<>@,;:\"/[]?= `) {
b = `"` + b + `"`
}
return "multipart/form-data; boundary=" + b
}
func randomBoundary() string {
var buf [30]byte
_, err := io.ReadFull(rand.Reader, buf[:])
if err != nil {
panic(err)
}
return fmt.Sprintf("%x", buf[:])
}
// CreatePart creates a new multipart section with the provided
// header. The body of the part should be written to the returned
// Writer. After calling CreatePart, any previous part may no longer
// be written to.
func (w *Writer) CreatePart(header textproto.MIMEHeader) (io.Writer, error) {
if w.lastpart != nil {
if err := w.lastpart.close(); err != nil {
return nil, err
}
}
var b bytes.Buffer
if w.lastpart != nil {
fmt.Fprintf(&b, "\r\n--%s\r\n", w.boundary)
} else {
fmt.Fprintf(&b, "--%s\r\n", w.boundary)
}
keys := make([]string, 0, len(header))
for k := range header {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
for _, v := range header[k] {
fmt.Fprintf(&b, "%s: %s\r\n", k, v)
}
}
fmt.Fprintf(&b, "\r\n")
_, err := io.Copy(w.w, &b)
if err != nil {
return nil, err
}
p := &part{
mw: w,
}
w.lastpart = p
return p, nil
}
var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"")
func escapeQuotes(s string) string {
return quoteEscaper.Replace(s)
}
// CreateFormFile is a convenience wrapper around CreatePart. It creates
// a new form-data header with the provided field name and file name.
func (w *Writer) CreateFormFile(fieldname, filename string) (io.Writer, error) {
h := make(textproto.MIMEHeader)
h.Set("Content-Disposition",
fmt.Sprintf(`form-data; name="%s"; filename="%s"`,
escapeQuotes(fieldname), escapeQuotes(filename)))
h.Set("Content-Type", "application/octet-stream")
return w.CreatePart(h)
}
// CreateFormField calls CreatePart with a header using the
// given field name.
func (w *Writer) CreateFormField(fieldname string) (io.Writer, error) {
h := make(textproto.MIMEHeader)
h.Set("Content-Disposition",
fmt.Sprintf(`form-data; name="%s"`, escapeQuotes(fieldname)))
return w.CreatePart(h)
}
// WriteField calls CreateFormField and then writes the given value.
func (w *Writer) WriteField(fieldname, value string) error {
p, err := w.CreateFormField(fieldname)
if err != nil {
return err
}
_, err = p.Write([]byte(value))
return err
}
// Close finishes the multipart message and writes the trailing
// boundary end line to the output.
func (w *Writer) Close() error {
if w.lastpart != nil {
if err := w.lastpart.close(); err != nil {
return err
}
w.lastpart = nil
}
_, err := fmt.Fprintf(w.w, "\r\n--%s--\r\n", w.boundary)
return err
}
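// Usage sketch (editor's example, not part of the original source):
// building a multipart/form-data request body; the field names, file
// name, and url are hypothetical.
//
//	var body bytes.Buffer
//	w := multipart.NewWriter(&body)
//	w.WriteField("comment", "hello")
//	fw, err := w.CreateFormFile("upload", "a.txt")
//	if err != nil {
//		log.Fatal(err)
//	}
//	fw.Write([]byte("file contents"))
//	w.Close() // writes the trailing boundary line
//	req, _ := http.NewRequest("POST", url, &body)
//	req.Header.Set("Content-Type", w.FormDataContentType())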
type part struct {
mw *Writer
closed bool
we error // last error that occurred writing
}
func (p *part) close() error {
p.closed = true
return p.we
}
func (p *part) Write(d []byte) (n int, err error) {
if p.closed {
return 0, errors.New("multipart: can't write to finished part")
}
n, err = p.mw.w.Write(d)
if err != nil {
p.we = err
}
return
}
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package quotedprintable implements quoted-printable encoding as specified by
// RFC 2045.
package quotedprintable
import (
"bufio"
"bytes"
"fmt"
"io"
)
// Reader is a quoted-printable decoder.
type Reader struct {
br *bufio.Reader
rerr error // last read error
line []byte // to be consumed before more of br
}
// NewReader returns a quoted-printable reader, decoding from r.
func NewReader(r io.Reader) *Reader {
return &Reader{
br: bufio.NewReader(r),
}
}
func fromHex(b byte) (byte, error) {
switch {
case b >= '0' && b <= '9':
return b - '0', nil
case b >= 'A' && b <= 'F':
return b - 'A' + 10, nil
// Accept badly encoded bytes.
case b >= 'a' && b <= 'f':
return b - 'a' + 10, nil
}
return 0, fmt.Errorf("quotedprintable: invalid hex byte 0x%02x", b)
}
func readHexByte(v []byte) (b byte, err error) {
if len(v) < 2 {
return 0, io.ErrUnexpectedEOF
}
var hb, lb byte
if hb, err = fromHex(v[0]); err != nil {
return 0, err
}
if lb, err = fromHex(v[1]); err != nil {
return 0, err
}
return hb<<4 | lb, nil
}
func isQPDiscardWhitespace(r rune) bool {
switch r {
case '\n', '\r', ' ', '\t':
return true
}
return false
}
var (
crlf = []byte("\r\n")
lf = []byte("\n")
softSuffix = []byte("=")
)
// Read reads and decodes quoted-printable data from the underlying reader.
func (r *Reader) Read(p []byte) (n int, err error) {
// Deviations from RFC 2045:
// 1. in addition to "=\r\n", "=\n" is also treated as soft line break.
// 2. it will pass through a '\r' or '\n' not preceded by '=', consistent
// with other broken QP encoders & decoders.
// 3. it accepts soft line-break (=) at end of message (issue 15486); i.e.
// the final byte read from the underlying reader is allowed to be '=',
// and it will be silently ignored.
// 4. it takes = as literal = if not followed by two hex digits
// but not at end of line (issue 13219).
for len(p) > 0 {
if len(r.line) == 0 {
if r.rerr != nil {
return n, r.rerr
}
r.line, r.rerr = r.br.ReadSlice('\n')
// Does the line end in CRLF instead of just LF?
hasLF := bytes.HasSuffix(r.line, lf)
hasCR := bytes.HasSuffix(r.line, crlf)
wholeLine := r.line
r.line = bytes.TrimRightFunc(wholeLine, isQPDiscardWhitespace)
if bytes.HasSuffix(r.line, softSuffix) {
rightStripped := wholeLine[len(r.line):]
r.line = r.line[:len(r.line)-1]
if !bytes.HasPrefix(rightStripped, lf) && !bytes.HasPrefix(rightStripped, crlf) &&
!(len(rightStripped) == 0 && len(r.line) > 0 && r.rerr == io.EOF) {
r.rerr = fmt.Errorf("quotedprintable: invalid bytes after =: %q", rightStripped)
}
} else if hasLF {
if hasCR {
r.line = append(r.line, '\r', '\n')
} else {
r.line = append(r.line, '\n')
}
}
continue
}
b := r.line[0]
switch {
case b == '=':
b, err = readHexByte(r.line[1:])
if err != nil {
if len(r.line) >= 2 && r.line[1] != '\r' && r.line[1] != '\n' {
// Take the = as a literal =.
b = '='
break
}
return n, err
}
r.line = r.line[2:] // 2 of the 3; other 1 is done below
case b == '\t' || b == '\r' || b == '\n':
break
case b >= 0x80:
// As an extension to RFC 2045, we accept
// values >= 0x80 without complaint. Issue 22597.
break
case b < ' ' || b > '~':
return n, fmt.Errorf("quotedprintable: invalid unescaped byte 0x%02x in body", b)
}
p[0] = b
p = p[1:]
r.line = r.line[1:]
n++
}
return n, nil
}
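// Usage sketch (editor's example, not part of the original source):
// decoding a quoted-printable body.
//
//	r := quotedprintable.NewReader(strings.NewReader("Caf=C3=A9 au=\r\nlait"))
//	b, err := io.ReadAll(r)
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Printf("%s\n", b) // "Café aulait": =C3=A9 decodes to é, "=\r\n" is a soft line break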
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package quotedprintable
import "io"
const lineMaxLen = 76
// A Writer is a quoted-printable writer that implements io.WriteCloser.
type Writer struct {
// Binary mode treats the writer's input as pure binary and processes end of
// line bytes as binary data.
Binary bool
w io.Writer
i int
line [78]byte
cr bool
}
// NewWriter returns a new Writer that writes to w.
func NewWriter(w io.Writer) *Writer {
return &Writer{w: w}
}
// Write encodes p using quoted-printable encoding and writes it to the
// underlying io.Writer. It limits line length to 76 characters. The encoded
// bytes are not necessarily flushed until the Writer is closed.
func (w *Writer) Write(p []byte) (n int, err error) {
for i, b := range p {
switch {
// Simple writes are done in batch.
case b >= '!' && b <= '~' && b != '=':
continue
case isWhitespace(b) || !w.Binary && (b == '\n' || b == '\r'):
continue
}
if i > n {
if err := w.write(p[n:i]); err != nil {
return n, err
}
n = i
}
if err := w.encode(b); err != nil {
return n, err
}
n++
}
if n == len(p) {
return n, nil
}
if err := w.write(p[n:]); err != nil {
return n, err
}
return len(p), nil
}
// Close closes the Writer, flushing any unwritten data to the underlying
// io.Writer, but does not close the underlying io.Writer.
func (w *Writer) Close() error {
if err := w.checkLastByte(); err != nil {
return err
}
return w.flush()
}
// write limits text encoded in quoted-printable to 76 characters per line.
func (w *Writer) write(p []byte) error {
for _, b := range p {
if b == '\n' || b == '\r' {
// If the previous byte was \r, the CRLF has already been inserted.
if w.cr && b == '\n' {
w.cr = false
continue
}
if b == '\r' {
w.cr = true
}
if err := w.checkLastByte(); err != nil {
return err
}
if err := w.insertCRLF(); err != nil {
return err
}
continue
}
if w.i == lineMaxLen-1 {
if err := w.insertSoftLineBreak(); err != nil {
return err
}
}
w.line[w.i] = b
w.i++
w.cr = false
}
return nil
}
func (w *Writer) encode(b byte) error {
if lineMaxLen-1-w.i < 3 {
if err := w.insertSoftLineBreak(); err != nil {
return err
}
}
w.line[w.i] = '='
w.line[w.i+1] = upperhex[b>>4]
w.line[w.i+2] = upperhex[b&0x0f]
w.i += 3
return nil
}
const upperhex = "0123456789ABCDEF"
// checkLastByte encodes the last buffered byte if it is a space or a tab.
func (w *Writer) checkLastByte() error {
if w.i == 0 {
return nil
}
b := w.line[w.i-1]
if isWhitespace(b) {
w.i--
if err := w.encode(b); err != nil {
return err
}
}
return nil
}
func (w *Writer) insertSoftLineBreak() error {
w.line[w.i] = '='
w.i++
return w.insertCRLF()
}
func (w *Writer) insertCRLF() error {
w.line[w.i] = '\r'
w.line[w.i+1] = '\n'
w.i += 2
return w.flush()
}
func (w *Writer) flush() error {
if _, err := w.w.Write(w.line[:w.i]); err != nil {
return err
}
w.i = 0
return nil
}
func isWhitespace(b byte) bool {
return b == ' ' || b == '\t'
}
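// Usage sketch (editor's example, not part of the original source):
// encoding text as quoted-printable.
//
//	var buf bytes.Buffer
//	w := quotedprintable.NewWriter(&buf)
//	w.Write([]byte("Café\r\n"))
//	w.Close() // flushes buffered output
//	fmt.Print(buf.String()) // "Caf=C3=A9\r\n": the non-ASCII bytes are escaped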
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package mime implements parts of the MIME spec.
package mime
import (
"fmt"
"sort"
"strings"
"sync"
)
var (
mimeTypes sync.Map // map[string]string; ".Z" => "application/x-compress"
mimeTypesLower sync.Map // map[string]string; ".z" => "application/x-compress"
// extensions maps from MIME type to list of lowercase file
// extensions: "image/jpeg" => [".jpg", ".jpeg"]
extensionsMu sync.Mutex // Guards stores (but not loads) on extensions.
extensions sync.Map // map[string][]string; slice values are append-only.
)
func clearSyncMap(m *sync.Map) {
m.Range(func(k, _ any) bool {
m.Delete(k)
return true
})
}
// setMimeTypes is used by initMime's non-test path, and by tests.
func setMimeTypes(lowerExt, mixExt map[string]string) {
clearSyncMap(&mimeTypes)
clearSyncMap(&mimeTypesLower)
clearSyncMap(&extensions)
for k, v := range lowerExt {
mimeTypesLower.Store(k, v)
}
for k, v := range mixExt {
mimeTypes.Store(k, v)
}
extensionsMu.Lock()
defer extensionsMu.Unlock()
for k, v := range lowerExt {
justType, _, err := ParseMediaType(v)
if err != nil {
panic(err)
}
var exts []string
if ei, ok := extensions.Load(justType); ok {
exts = ei.([]string)
}
extensions.Store(justType, append(exts, k))
}
}
var builtinTypesLower = map[string]string{
".avif": "image/avif",
".css": "text/css; charset=utf-8",
".gif": "image/gif",
".htm": "text/html; charset=utf-8",
".html": "text/html; charset=utf-8",
".jpeg": "image/jpeg",
".jpg": "image/jpeg",
".js": "text/javascript; charset=utf-8",
".json": "application/json",
".mjs": "text/javascript; charset=utf-8",
".pdf": "application/pdf",
".png": "image/png",
".svg": "image/svg+xml",
".wasm": "application/wasm",
".webp": "image/webp",
".xml": "text/xml; charset=utf-8",
}
var once sync.Once // guards initMime
var testInitMime, osInitMime func()
func initMime() {
if fn := testInitMime; fn != nil {
fn()
} else {
setMimeTypes(builtinTypesLower, builtinTypesLower)
osInitMime()
}
}
// TypeByExtension returns the MIME type associated with the file extension ext.
// The extension ext should begin with a leading dot, as in ".html".
// When ext has no associated type, TypeByExtension returns "".
//
// Extensions are looked up first case-sensitively, then case-insensitively.
//
// The built-in table is small but on unix it is augmented by the local
// system's MIME-info database or mime.types file(s) if available under one or
// more of these names:
//
// /usr/local/share/mime/globs2
// /usr/share/mime/globs2
// /etc/mime.types
// /etc/apache2/mime.types
// /etc/apache/mime.types
//
// On Windows, MIME types are extracted from the registry.
//
// Text types have the charset parameter set to "utf-8" by default.
func TypeByExtension(ext string) string {
once.Do(initMime)
// Case-sensitive lookup.
if v, ok := mimeTypes.Load(ext); ok {
return v.(string)
}
// Case-insensitive lookup.
// Optimistically assume a short ASCII extension and be
// allocation-free in that case.
var buf [10]byte
lower := buf[:0]
const utf8RuneSelf = 0x80 // from utf8 package, but not importing it.
for i := 0; i < len(ext); i++ {
c := ext[i]
if c >= utf8RuneSelf {
// Slow path.
si, _ := mimeTypesLower.Load(strings.ToLower(ext))
s, _ := si.(string)
return s
}
if 'A' <= c && c <= 'Z' {
lower = append(lower, c+('a'-'A'))
} else {
lower = append(lower, c)
}
}
si, _ := mimeTypesLower.Load(string(lower))
s, _ := si.(string)
return s
}
// ExtensionsByType returns the extensions known to be associated with the MIME
// type typ. The returned extensions will each begin with a leading dot, as in
// ".html". When typ has no associated extensions, ExtensionsByType returns a
// nil slice.
func ExtensionsByType(typ string) ([]string, error) {
justType, _, err := ParseMediaType(typ)
if err != nil {
return nil, err
}
once.Do(initMime)
s, ok := extensions.Load(justType)
if !ok {
return nil, nil
}
ret := append([]string(nil), s.([]string)...)
sort.Strings(ret)
return ret, nil
}
// AddExtensionType sets the MIME type associated with
// the extension ext to typ. The extension should begin with
// a leading dot, as in ".html".
func AddExtensionType(ext, typ string) error {
if !strings.HasPrefix(ext, ".") {
return fmt.Errorf("mime: extension %q missing leading dot", ext)
}
once.Do(initMime)
return setExtensionType(ext, typ)
}
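// Usage sketch (editor's example, not part of the original source);
// the results shown assume the built-in table, before any additions
// from the local system's databases:
//
//	fmt.Println(mime.TypeByExtension(".html")) // "text/html; charset=utf-8"
//	exts, _ := mime.ExtensionsByType("image/jpeg")
//	fmt.Println(exts) // [.jpeg .jpg], sorted
//	mime.AddExtensionType(".md", "text/markdown")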
func setExtensionType(extension, mimeType string) error {
justType, param, err := ParseMediaType(mimeType)
if err != nil {
return err
}
if strings.HasPrefix(mimeType, "text/") && param["charset"] == "" {
param["charset"] = "utf-8"
mimeType = FormatMediaType(mimeType, param)
}
extLower := strings.ToLower(extension)
mimeTypes.Store(extension, mimeType)
mimeTypesLower.Store(extLower, mimeType)
extensionsMu.Lock()
defer extensionsMu.Unlock()
var exts []string
if ei, ok := extensions.Load(justType); ok {
exts = ei.([]string)
}
for _, v := range exts {
if v == extLower {
return nil
}
}
extensions.Store(justType, append(exts, extLower))
return nil
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || (js && wasm)
package mime
import (
"bufio"
"os"
"strings"
)
func init() {
osInitMime = initMimeUnix
}
// See https://specifications.freedesktop.org/shared-mime-info-spec/shared-mime-info-spec-0.21.html
// for the FreeDesktop Shared MIME-info Database specification.
var mimeGlobs = []string{
"/usr/local/share/mime/globs2",
"/usr/share/mime/globs2",
}
// Common locations for mime.types files on unix.
var typeFiles = []string{
"/etc/mime.types",
"/etc/apache2/mime.types",
"/etc/apache/mime.types",
"/etc/httpd/conf/mime.types",
}
func loadMimeGlobsFile(filename string) error {
f, err := os.Open(filename)
if err != nil {
return err
}
defer f.Close()
scanner := bufio.NewScanner(f)
for scanner.Scan() {
// Each line should be of format: weight:mimetype:glob[:morefields...]
fields := strings.Split(scanner.Text(), ":")
if len(fields) < 3 || len(fields[0]) < 1 || len(fields[2]) < 3 {
continue
} else if fields[0][0] == '#' || fields[2][0] != '*' || fields[2][1] != '.' {
continue
}
extension := fields[2][1:]
if strings.ContainsAny(extension, "?*[") {
// Not a bare extension, but a glob. Ignore for now:
// - we do not have an implementation for this glob
// syntax (translation to path/filepath.Match could
// be possible)
// - support for globs with weight ordering would have
// performance impact to all lookups to support the
// rarely seen glob entries
// - trying to match glob metacharacters literally is
// not useful
continue
}
if _, ok := mimeTypes.Load(extension); ok {
// We've already seen this extension.
// The file is in weight order, so we keep
// the first entry that we see.
continue
}
setExtensionType(extension, fields[1])
}
if err := scanner.Err(); err != nil {
panic(err)
}
return nil
}
func loadMimeFile(filename string) {
f, err := os.Open(filename)
if err != nil {
return
}
defer f.Close()
scanner := bufio.NewScanner(f)
for scanner.Scan() {
fields := strings.Fields(scanner.Text())
if len(fields) <= 1 || fields[0][0] == '#' {
continue
}
mimeType := fields[0]
for _, ext := range fields[1:] {
if ext[0] == '#' {
break
}
setExtensionType("."+ext, mimeType)
}
}
if err := scanner.Err(); err != nil {
panic(err)
}
}
func initMimeUnix() {
for _, filename := range mimeGlobs {
if err := loadMimeGlobsFile(filename); err == nil {
return // Stop checking more files if mimetype database is found.
}
}
// Fallback if no system-generated mimetype database exists.
for _, filename := range typeFiles {
loadMimeFile(filename)
}
}
func initMimeForTests() map[string]string {
mimeGlobs = []string{""}
typeFiles = []string{"testdata/test.types"}
return map[string]string{
".T1": "application/test",
".t2": "text/test; charset=utf-8",
".png": "image/png",
}
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Minimal RFC 6724 address selection.
package net
import (
"net/netip"
"sort"
)
func sortByRFC6724(addrs []IPAddr) {
if len(addrs) < 2 {
return
}
sortByRFC6724withSrcs(addrs, srcAddrs(addrs))
}
func sortByRFC6724withSrcs(addrs []IPAddr, srcs []netip.Addr) {
if len(addrs) != len(srcs) {
panic("internal error")
}
addrAttr := make([]ipAttr, len(addrs))
srcAttr := make([]ipAttr, len(srcs))
for i, v := range addrs {
addrAttrIP, _ := netip.AddrFromSlice(v.IP)
addrAttr[i] = ipAttrOf(addrAttrIP)
srcAttr[i] = ipAttrOf(srcs[i])
}
sort.Stable(&byRFC6724{
addrs: addrs,
addrAttr: addrAttr,
srcs: srcs,
srcAttr: srcAttr,
})
}
// srcAddrs tries to UDP-connect to each address to see if it has a
// route. (This doesn't send any packets.) The destination port
// number is irrelevant.
func srcAddrs(addrs []IPAddr) []netip.Addr {
srcs := make([]netip.Addr, len(addrs))
dst := UDPAddr{Port: 9}
for i := range addrs {
dst.IP = addrs[i].IP
dst.Zone = addrs[i].Zone
c, err := DialUDP("udp", nil, &dst)
if err == nil {
if src, ok := c.LocalAddr().(*UDPAddr); ok {
srcs[i], _ = netip.AddrFromSlice(src.IP)
}
c.Close()
}
}
return srcs
}
type ipAttr struct {
Scope scope
Precedence uint8
Label uint8
}
func ipAttrOf(ip netip.Addr) ipAttr {
if !ip.IsValid() {
return ipAttr{}
}
match := rfc6724policyTable.Classify(ip)
return ipAttr{
Scope: classifyScope(ip),
Precedence: match.Precedence,
Label: match.Label,
}
}
type byRFC6724 struct {
addrs []IPAddr // addrs to sort
addrAttr []ipAttr
srcs []netip.Addr // or not valid addr if unreachable
srcAttr []ipAttr
}
func (s *byRFC6724) Len() int { return len(s.addrs) }
func (s *byRFC6724) Swap(i, j int) {
s.addrs[i], s.addrs[j] = s.addrs[j], s.addrs[i]
s.srcs[i], s.srcs[j] = s.srcs[j], s.srcs[i]
s.addrAttr[i], s.addrAttr[j] = s.addrAttr[j], s.addrAttr[i]
s.srcAttr[i], s.srcAttr[j] = s.srcAttr[j], s.srcAttr[i]
}
// Less reports whether i is a better destination address for this
// host than j.
//
// The algorithm and variable names come from RFC 6724 section 6.
func (s *byRFC6724) Less(i, j int) bool {
DA := s.addrs[i].IP
DB := s.addrs[j].IP
SourceDA := s.srcs[i]
SourceDB := s.srcs[j]
attrDA := &s.addrAttr[i]
attrDB := &s.addrAttr[j]
attrSourceDA := &s.srcAttr[i]
attrSourceDB := &s.srcAttr[j]
const preferDA = true
const preferDB = false
// Rule 1: Avoid unusable destinations.
// If DB is known to be unreachable or if Source(DB) is undefined, then
// prefer DA. Similarly, if DA is known to be unreachable or if
// Source(DA) is undefined, then prefer DB.
if !SourceDA.IsValid() && !SourceDB.IsValid() {
return false // "equal"
}
if !SourceDB.IsValid() {
return preferDA
}
if !SourceDA.IsValid() {
return preferDB
}
// Rule 2: Prefer matching scope.
// If Scope(DA) = Scope(Source(DA)) and Scope(DB) <> Scope(Source(DB)),
// then prefer DA. Similarly, if Scope(DA) <> Scope(Source(DA)) and
// Scope(DB) = Scope(Source(DB)), then prefer DB.
if attrDA.Scope == attrSourceDA.Scope && attrDB.Scope != attrSourceDB.Scope {
return preferDA
}
if attrDA.Scope != attrSourceDA.Scope && attrDB.Scope == attrSourceDB.Scope {
return preferDB
}
// Rule 3: Avoid deprecated addresses.
// If Source(DA) is deprecated and Source(DB) is not, then prefer DB.
// Similarly, if Source(DA) is not deprecated and Source(DB) is
// deprecated, then prefer DA.
// TODO(bradfitz): implement? low priority for now.
// Rule 4: Prefer home addresses.
// If Source(DA) is simultaneously a home address and care-of address
// and Source(DB) is not, then prefer DA. Similarly, if Source(DB) is
// simultaneously a home address and care-of address and Source(DA) is
// not, then prefer DB.
// TODO(bradfitz): implement? low priority for now.
// Rule 5: Prefer matching label.
// If Label(Source(DA)) = Label(DA) and Label(Source(DB)) <> Label(DB),
// then prefer DA. Similarly, if Label(Source(DA)) <> Label(DA) and
// Label(Source(DB)) = Label(DB), then prefer DB.
if attrSourceDA.Label == attrDA.Label &&
attrSourceDB.Label != attrDB.Label {
return preferDA
}
if attrSourceDA.Label != attrDA.Label &&
attrSourceDB.Label == attrDB.Label {
return preferDB
}
// Rule 6: Prefer higher precedence.
// If Precedence(DA) > Precedence(DB), then prefer DA. Similarly, if
// Precedence(DA) < Precedence(DB), then prefer DB.
if attrDA.Precedence > attrDB.Precedence {
return preferDA
}
if attrDA.Precedence < attrDB.Precedence {
return preferDB
}
// Rule 7: Prefer native transport.
// If DA is reached via an encapsulating transition mechanism (e.g.,
// IPv6 in IPv4) and DB is not, then prefer DB. Similarly, if DB is
// reached via encapsulation and DA is not, then prefer DA.
// TODO(bradfitz): implement? low priority for now.
// Rule 8: Prefer smaller scope.
// If Scope(DA) < Scope(DB), then prefer DA. Similarly, if Scope(DA) >
// Scope(DB), then prefer DB.
if attrDA.Scope < attrDB.Scope {
return preferDA
}
if attrDA.Scope > attrDB.Scope {
return preferDB
}
// Rule 9: Use the longest matching prefix.
// When DA and DB belong to the same address family (both are IPv6 or
// both are IPv4 [but see below]): If CommonPrefixLen(Source(DA), DA) >
// CommonPrefixLen(Source(DB), DB), then prefer DA. Similarly, if
// CommonPrefixLen(Source(DA), DA) < CommonPrefixLen(Source(DB), DB),
// then prefer DB.
//
// However, applying this rule to IPv4 addresses causes
// problems (see issues 13283 and 18518), so limit to IPv6.
if DA.To4() == nil && DB.To4() == nil {
commonA := commonPrefixLen(SourceDA, DA)
commonB := commonPrefixLen(SourceDB, DB)
if commonA > commonB {
return preferDA
}
if commonA < commonB {
return preferDB
}
}
// Rule 10: Otherwise, leave the order unchanged.
// If DA preceded DB in the original list, prefer DA.
// Otherwise, prefer DB.
return false // "equal"
}
type policyTableEntry struct {
Prefix netip.Prefix
Precedence uint8
Label uint8
}
type policyTable []policyTableEntry
// RFC 6724 section 2.1.
// Items are sorted by prefix length, from longest to shortest.
var rfc6724policyTable = policyTable{
{
// "::1/128"
Prefix: netip.PrefixFrom(netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}), 128),
Precedence: 50,
Label: 0,
},
{
// "::ffff:0:0/96"
// IPv4-compatible, etc.
Prefix: netip.PrefixFrom(netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff}), 96),
Precedence: 35,
Label: 4,
},
{
// "::/96"
Prefix: netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 96),
Precedence: 1,
Label: 3,
},
{
// "2001::/32"
// Teredo
Prefix: netip.PrefixFrom(netip.AddrFrom16([16]byte{0x20, 0x01}), 32),
Precedence: 5,
Label: 5,
},
{
// "2002::/16"
// 6to4
Prefix: netip.PrefixFrom(netip.AddrFrom16([16]byte{0x20, 0x02}), 16),
Precedence: 30,
Label: 2,
},
{
// "3ffe::/16"
Prefix: netip.PrefixFrom(netip.AddrFrom16([16]byte{0x3f, 0xfe}), 16),
Precedence: 1,
Label: 12,
},
{
// "fec0::/10"
Prefix: netip.PrefixFrom(netip.AddrFrom16([16]byte{0xfe, 0xc0}), 10),
Precedence: 1,
Label: 11,
},
{
// "fc00::/7"
Prefix: netip.PrefixFrom(netip.AddrFrom16([16]byte{0xfc}), 7),
Precedence: 3,
Label: 13,
},
{
// "::/0"
Prefix: netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 0),
Precedence: 40,
Label: 1,
},
}
// Classify returns the policyTableEntry of the entry with the longest
// matching prefix that contains ip.
// The table t must be sorted from largest mask size to smallest.
func (t policyTable) Classify(ip netip.Addr) policyTableEntry {
// Prefix.Contains() will not match an IPv6 prefix for an IPv4 address.
if ip.Is4() {
ip = netip.AddrFrom16(ip.As16())
}
for _, ent := range t {
if ent.Prefix.Contains(ip) {
return ent
}
}
return policyTableEntry{}
}
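// Worked example (editor's note, not part of the original source): a
// 6to4 address matches the "2002::/16" row, while an ordinary global
// IPv6 address falls through to "::/0".
//
//	rfc6724policyTable.Classify(netip.MustParseAddr("2002::1"))
//	// → Precedence 30, Label 2
//	rfc6724policyTable.Classify(netip.MustParseAddr("2001:db8::1"))
//	// → Precedence 40, Label 1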
// RFC 6724 section 3.1.
type scope uint8
const (
scopeInterfaceLocal scope = 0x1
scopeLinkLocal scope = 0x2
scopeAdminLocal scope = 0x4
scopeSiteLocal scope = 0x5
scopeOrgLocal scope = 0x8
scopeGlobal scope = 0xe
)
func classifyScope(ip netip.Addr) scope {
if ip.IsLoopback() || ip.IsLinkLocalUnicast() {
return scopeLinkLocal
}
ipv6 := ip.Is6() && !ip.Is4In6()
ipv6AsBytes := ip.As16()
if ipv6 && ip.IsMulticast() {
return scope(ipv6AsBytes[1] & 0xf)
}
// Site-local addresses are defined in RFC 3513 section 2.5.6
// (and deprecated in RFC 3879).
if ipv6 && ipv6AsBytes[0] == 0xfe && ipv6AsBytes[1]&0xc0 == 0xc0 {
return scopeSiteLocal
}
return scopeGlobal
}
// commonPrefixLen reports the length of the longest prefix (looking
// at the most significant, or leftmost, bits) that the
// two addresses have in common, up to the length of a's prefix (i.e.,
// the portion of the address not including the interface ID).
//
// If a or b is an IPv4 address represented as an IPv6 address, the IPv4
// addresses are compared (with a max common prefix length of 32).
// If a and b are different IP versions, 0 is returned.
//
// See https://tools.ietf.org/html/rfc6724#section-2.2
func commonPrefixLen(a netip.Addr, b IP) (cpl int) {
if b4 := b.To4(); b4 != nil {
b = b4
}
aAsSlice := a.AsSlice()
if len(aAsSlice) != len(b) {
return 0
}
// If IPv6, only up to the prefix (first 64 bits)
if len(aAsSlice) > 8 {
aAsSlice = aAsSlice[:8]
b = b[:8]
}
for len(aAsSlice) > 0 {
if aAsSlice[0] == b[0] {
cpl += 8
aAsSlice = aAsSlice[1:]
b = b[1:]
continue
}
bits := 8
ab, bb := aAsSlice[0], b[0]
for {
ab >>= 1
bb >>= 1
bits--
if ab == bb {
cpl += bits
return
}
}
}
return
}
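// Worked example (editor's note, not part of the original source): only
// the first 8 bytes of an IPv6 pair are inspected. For a = 2001:db8::
// and b = 2001:dc8::, the first three bytes match (24 bits) and the
// fourth bytes 0xb8 and 0xc8 share only their top bit, so the result is
// 25. Addresses differing only in the interface ID (2001:db8::1 vs
// 2001:db8::2) yield the maximum of 64.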
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build cgo && !netgo && ((linux && !android) || netbsd || solaris)
package net
/*
#include <sys/types.h>
#include <sys/socket.h>
#include <netdb.h>
*/
import "C"
import "unsafe"
func cgoNameinfoPTR(b []byte, sa *C.struct_sockaddr, salen C.socklen_t) (int, error) {
gerrno, err := C.getnameinfo(sa, salen, (*C.char)(unsafe.Pointer(&b[0])), C.socklen_t(len(b)), nil, 0, C.NI_NAMEREQD)
return int(gerrno), err
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build cgo && !netgo && (android || linux || solaris)
package net
/*
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
*/
import "C"
import (
"syscall"
"unsafe"
)
func cgoSockaddrInet4(ip IP) *C.struct_sockaddr {
sa := syscall.RawSockaddrInet4{Family: syscall.AF_INET}
copy(sa.Addr[:], ip)
return (*C.struct_sockaddr)(unsafe.Pointer(&sa))
}
func cgoSockaddrInet6(ip IP, zone int) *C.struct_sockaddr {
sa := syscall.RawSockaddrInet6{Family: syscall.AF_INET6, Scope_id: uint32(zone)}
copy(sa.Addr[:], ip)
return (*C.struct_sockaddr)(unsafe.Pointer(&sa))
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file is called cgo_unix.go, but to allow syscalls-to-libc-based
// implementations to share the code, it does not use cgo directly.
// Instead of C.foo it uses _C_foo, which is defined in either
// cgo_unix_cgo.go or cgo_unix_syscall.go.
//go:build !netgo && ((cgo && unix) || darwin)
package net
import (
"context"
"errors"
"net/netip"
"syscall"
"unsafe"
"golang.org/x/net/dns/dnsmessage"
)
// An addrinfoErrno represents a getaddrinfo or getnameinfo-specific
// error number. It's a signed number and a zero value is a non-error
// by convention.
type addrinfoErrno int
func (eai addrinfoErrno) Error() string { return _C_gai_strerror(_C_int(eai)) }
func (eai addrinfoErrno) Temporary() bool { return eai == _C_EAI_AGAIN }
func (eai addrinfoErrno) Timeout() bool { return false }
// doBlockingWithCtx executes a blocking function in a separate goroutine when the provided
// context is cancellable. It is intended for use with calls that don't support context
// cancellation (cgo, syscalls). The blocking function may still be running after this function finishes.
func doBlockingWithCtx[T any](ctx context.Context, blocking func() (T, error)) (T, error) {
if ctx.Done() == nil {
return blocking()
}
type result struct {
res T
err error
}
res := make(chan result, 1)
go func() {
var r result
r.res, r.err = blocking()
res <- r
}()
select {
case r := <-res:
return r.res, r.err
case <-ctx.Done():
var zero T
return zero, mapErr(ctx.Err())
}
}
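// Usage sketch (editor's note, not part of the original source):
// wrapping a blocking call so a cancelable context can abandon it;
// slowLookup here is hypothetical.
//
//	names, err := doBlockingWithCtx(ctx, func() ([]string, error) {
//		return slowLookup(name) // runs in its own goroutine if ctx is cancelable
//	})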
func cgoLookupHost(ctx context.Context, name string) (hosts []string, err error, completed bool) {
addrs, err, completed := cgoLookupIP(ctx, "ip", name)
for _, addr := range addrs {
hosts = append(hosts, addr.String())
}
return
}
func cgoLookupPort(ctx context.Context, network, service string) (port int, err error, completed bool) {
var hints _C_struct_addrinfo
switch network {
case "": // no hints
case "tcp", "tcp4", "tcp6":
*_C_ai_socktype(&hints) = _C_SOCK_STREAM
*_C_ai_protocol(&hints) = _C_IPPROTO_TCP
case "udp", "udp4", "udp6":
*_C_ai_socktype(&hints) = _C_SOCK_DGRAM
*_C_ai_protocol(&hints) = _C_IPPROTO_UDP
default:
return 0, &DNSError{Err: "unknown network", Name: network + "/" + service}, true
}
switch ipVersion(network) {
case '4':
*_C_ai_family(&hints) = _C_AF_INET
case '6':
*_C_ai_family(&hints) = _C_AF_INET6
}
port, err = doBlockingWithCtx(ctx, func() (int, error) {
return cgoLookupServicePort(&hints, network, service)
})
return port, err, true
}
func cgoLookupServicePort(hints *_C_struct_addrinfo, network, service string) (port int, err error) {
cservice := make([]byte, len(service)+1)
copy(cservice, service)
// Lowercase the C service name.
for i, b := range cservice[:len(service)] {
cservice[i] = lowerASCII(b)
}
var res *_C_struct_addrinfo
gerrno, err := _C_getaddrinfo(nil, (*_C_char)(unsafe.Pointer(&cservice[0])), hints, &res)
if gerrno != 0 {
isTemporary := false
switch gerrno {
case _C_EAI_SYSTEM:
if err == nil { // see golang.org/issue/6232
err = syscall.EMFILE
}
default:
err = addrinfoErrno(gerrno)
isTemporary = addrinfoErrno(gerrno).Temporary()
}
return 0, &DNSError{Err: err.Error(), Name: network + "/" + service, IsTemporary: isTemporary}
}
defer _C_freeaddrinfo(res)
for r := res; r != nil; r = *_C_ai_next(r) {
switch *_C_ai_family(r) {
case _C_AF_INET:
sa := (*syscall.RawSockaddrInet4)(unsafe.Pointer(*_C_ai_addr(r)))
p := (*[2]byte)(unsafe.Pointer(&sa.Port))
return int(p[0])<<8 | int(p[1]), nil
case _C_AF_INET6:
sa := (*syscall.RawSockaddrInet6)(unsafe.Pointer(*_C_ai_addr(r)))
p := (*[2]byte)(unsafe.Pointer(&sa.Port))
return int(p[0])<<8 | int(p[1]), nil
}
}
return 0, &DNSError{Err: "unknown port", Name: network + "/" + service}
}
func cgoLookupHostIP(network, name string) (addrs []IPAddr, err error) {
acquireThread()
defer releaseThread()
var hints _C_struct_addrinfo
*_C_ai_flags(&hints) = cgoAddrInfoFlags
*_C_ai_socktype(&hints) = _C_SOCK_STREAM
*_C_ai_family(&hints) = _C_AF_UNSPEC
switch ipVersion(network) {
case '4':
*_C_ai_family(&hints) = _C_AF_INET
case '6':
*_C_ai_family(&hints) = _C_AF_INET6
}
h := make([]byte, len(name)+1)
copy(h, name)
var res *_C_struct_addrinfo
gerrno, err := _C_getaddrinfo((*_C_char)(unsafe.Pointer(&h[0])), nil, &hints, &res)
if gerrno != 0 {
isErrorNoSuchHost := false
isTemporary := false
switch gerrno {
case _C_EAI_SYSTEM:
if err == nil {
// err should not be nil, but sometimes getaddrinfo returns
// gerrno == _C_EAI_SYSTEM with err == nil on Linux.
// The report claims that it happens when we have too many
// open files, so use syscall.EMFILE (too many open files).
// Most system calls would return ENFILE (too many open files in system),
// so at the least EMFILE should be easy to recognize if this
// comes up again. golang.org/issue/6232.
err = syscall.EMFILE
}
case _C_EAI_NONAME:
err = errNoSuchHost
isErrorNoSuchHost = true
default:
err = addrinfoErrno(gerrno)
isTemporary = addrinfoErrno(gerrno).Temporary()
}
return nil, &DNSError{Err: err.Error(), Name: name, IsNotFound: isErrorNoSuchHost, IsTemporary: isTemporary}
}
defer _C_freeaddrinfo(res)
for r := res; r != nil; r = *_C_ai_next(r) {
// We only asked for SOCK_STREAM, but check anyhow.
if *_C_ai_socktype(r) != _C_SOCK_STREAM {
continue
}
switch *_C_ai_family(r) {
case _C_AF_INET:
sa := (*syscall.RawSockaddrInet4)(unsafe.Pointer(*_C_ai_addr(r)))
addr := IPAddr{IP: copyIP(sa.Addr[:])}
addrs = append(addrs, addr)
case _C_AF_INET6:
sa := (*syscall.RawSockaddrInet6)(unsafe.Pointer(*_C_ai_addr(r)))
addr := IPAddr{IP: copyIP(sa.Addr[:]), Zone: zoneCache.name(int(sa.Scope_id))}
addrs = append(addrs, addr)
}
}
return addrs, nil
}
func cgoLookupIP(ctx context.Context, network, name string) (addrs []IPAddr, err error, completed bool) {
addrs, err = doBlockingWithCtx(ctx, func() ([]IPAddr, error) {
return cgoLookupHostIP(network, name)
})
return addrs, err, true
}
// These are roughly enough for the following:
//
// Source          Encoding                       Maximum length of single name entry
// Unicast DNS     ASCII or                       <=253 + a NUL terminator
//                 Unicode in RFC 5892            252 * total number of labels + delimiters + a NUL terminator
// Multicast DNS   UTF-8 in RFC 5198 or           <=253 + a NUL terminator
//                 the same as unicast DNS ASCII  <=253 + a NUL terminator
// Local database  various                        depends on implementation
const (
nameinfoLen = 64
maxNameinfoLen = 4096
)
func cgoLookupPTR(ctx context.Context, addr string) (names []string, err error, completed bool) {
ip, err := netip.ParseAddr(addr)
if err != nil {
return nil, &DNSError{Err: "invalid address", Name: addr}, true
}
sa, salen := cgoSockaddr(IP(ip.AsSlice()), ip.Zone())
if sa == nil {
return nil, &DNSError{Err: "invalid address " + ip.String(), Name: addr}, true
}
names, err = doBlockingWithCtx(ctx, func() ([]string, error) {
return cgoLookupAddrPTR(addr, sa, salen)
})
return names, err, true
}
func cgoLookupAddrPTR(addr string, sa *_C_struct_sockaddr, salen _C_socklen_t) (names []string, err error) {
acquireThread()
defer releaseThread()
var gerrno int
var b []byte
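// Grow the buffer geometrically, from nameinfoLen (64 bytes) up to
// maxNameinfoLen (4096 bytes), while getnameinfo keeps reporting
// _C_EAI_OVERFLOW (result truncated).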
for l := nameinfoLen; l <= maxNameinfoLen; l *= 2 {
b = make([]byte, l)
gerrno, err = cgoNameinfoPTR(b, sa, salen)
if gerrno == 0 || gerrno != _C_EAI_OVERFLOW {
break
}
}
if gerrno != 0 {
isErrorNoSuchHost := false
isTemporary := false
switch gerrno {
case _C_EAI_SYSTEM:
if err == nil { // see golang.org/issue/6232
err = syscall.EMFILE
}
case _C_EAI_NONAME:
err = errNoSuchHost
isErrorNoSuchHost = true
default:
err = addrinfoErrno(gerrno)
isTemporary = addrinfoErrno(gerrno).Temporary()
}
return nil, &DNSError{Err: err.Error(), Name: addr, IsTemporary: isTemporary, IsNotFound: isErrorNoSuchHost}
}
for i := 0; i < len(b); i++ {
if b[i] == 0 {
b = b[:i]
break
}
}
return []string{absDomainName(string(b))}, nil
}
func cgoSockaddr(ip IP, zone string) (*_C_struct_sockaddr, _C_socklen_t) {
if ip4 := ip.To4(); ip4 != nil {
return cgoSockaddrInet4(ip4), _C_socklen_t(syscall.SizeofSockaddrInet4)
}
if ip6 := ip.To16(); ip6 != nil {
return cgoSockaddrInet6(ip6, zoneCache.index(zone)), _C_socklen_t(syscall.SizeofSockaddrInet6)
}
return nil, 0
}
func cgoLookupCNAME(ctx context.Context, name string) (cname string, err error, completed bool) {
resources, err := resSearch(ctx, name, int(dnsmessage.TypeCNAME), int(dnsmessage.ClassINET))
if err != nil {
return
}
cname, err = parseCNAMEFromResources(resources)
if err != nil {
return "", err, false
}
return cname, nil, true
}
// resSearch will make a call to the 'res_nsearch' routine in the C library
// and parse the output as a slice of DNS resources.
func resSearch(ctx context.Context, hostname string, rtype, class int) ([]dnsmessage.Resource, error) {
return doBlockingWithCtx(ctx, func() ([]dnsmessage.Resource, error) {
return cgoResSearch(hostname, rtype, class)
})
}
func cgoResSearch(hostname string, rtype, class int) ([]dnsmessage.Resource, error) {
acquireThread()
defer releaseThread()
state := (*_C_struct___res_state)(_C_malloc(unsafe.Sizeof(_C_struct___res_state{})))
defer _C_free(unsafe.Pointer(state))
if err := _C_res_ninit(state); err != nil {
return nil, errors.New("res_ninit failure: " + err.Error())
}
defer _C_res_nclose(state)
// Some res_nsearch implementations (like macOS) do not set errno.
// They set h_errno, which is not per-thread and useless to us.
// res_nsearch returns the size of the DNS response packet.
// But if the DNS response packet contains failure-like response codes,
// res_nsearch returns -1 even though it has copied the packet into buf,
// giving us no way to find out how big the packet is.
// For now, we are willing to take res_nsearch's word that there's nothing
// useful in the response, even though there *is* a response.
bufSize := maxDNSPacketSize
buf := (*_C_uchar)(_C_malloc(uintptr(bufSize)))
defer _C_free(unsafe.Pointer(buf))
s := _C_CString(hostname)
defer _C_FreeCString(s)
var size int
for {
size, _ = _C_res_nsearch(state, s, class, rtype, buf, bufSize)
if size <= 0 || size > 0xffff {
return nil, errors.New("res_nsearch failure")
}
if size <= bufSize {
break
}
// Allocate a bigger buffer to fit the entire msg.
_C_free(unsafe.Pointer(buf))
bufSize = size
buf = (*_C_uchar)(_C_malloc(uintptr(bufSize)))
}
var p dnsmessage.Parser
if _, err := p.Start(unsafe.Slice((*byte)(unsafe.Pointer(buf)), size)); err != nil {
return nil, err
}
p.SkipAllQuestions()
resources, err := p.AllAnswers()
if err != nil {
return nil, err
}
return resources, nil
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build cgo && !netgo && unix && !darwin
package net
/*
#cgo CFLAGS: -fno-stack-protector
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netdb.h>
#include <unistd.h>
#include <string.h>
#include <stdlib.h>
// If nothing else defined EAI_OVERFLOW, make sure it has a value.
#ifndef EAI_OVERFLOW
#define EAI_OVERFLOW -12
#endif
*/
import "C"
import "unsafe"
const (
_C_AF_INET = C.AF_INET
_C_AF_INET6 = C.AF_INET6
_C_AF_UNSPEC = C.AF_UNSPEC
_C_EAI_AGAIN = C.EAI_AGAIN
_C_EAI_NONAME = C.EAI_NONAME
_C_EAI_OVERFLOW = C.EAI_OVERFLOW
_C_EAI_SYSTEM = C.EAI_SYSTEM
_C_IPPROTO_TCP = C.IPPROTO_TCP
_C_IPPROTO_UDP = C.IPPROTO_UDP
_C_SOCK_DGRAM = C.SOCK_DGRAM
_C_SOCK_STREAM = C.SOCK_STREAM
)
type (
_C_char = C.char
_C_uchar = C.uchar
_C_int = C.int
_C_uint = C.uint
_C_socklen_t = C.socklen_t
_C_struct_addrinfo = C.struct_addrinfo
_C_struct_sockaddr = C.struct_sockaddr
)
func _C_GoString(p *_C_char) string { return C.GoString(p) }
func _C_CString(s string) *_C_char { return C.CString(s) }
func _C_FreeCString(p *_C_char) { C.free(unsafe.Pointer(p)) }
func _C_malloc(n uintptr) unsafe.Pointer { return C.malloc(C.size_t(n)) }
func _C_free(p unsafe.Pointer) { C.free(p) }
func _C_ai_addr(ai *_C_struct_addrinfo) **_C_struct_sockaddr { return &ai.ai_addr }
func _C_ai_family(ai *_C_struct_addrinfo) *_C_int { return &ai.ai_family }
func _C_ai_flags(ai *_C_struct_addrinfo) *_C_int { return &ai.ai_flags }
func _C_ai_next(ai *_C_struct_addrinfo) **_C_struct_addrinfo { return &ai.ai_next }
func _C_ai_protocol(ai *_C_struct_addrinfo) *_C_int { return &ai.ai_protocol }
func _C_ai_socktype(ai *_C_struct_addrinfo) *_C_int { return &ai.ai_socktype }
func _C_freeaddrinfo(ai *_C_struct_addrinfo) {
C.freeaddrinfo(ai)
}
func _C_gai_strerror(eai _C_int) string {
return C.GoString(C.gai_strerror(eai))
}
func _C_getaddrinfo(hostname, servname *_C_char, hints *_C_struct_addrinfo, res **_C_struct_addrinfo) (int, error) {
x, err := C.getaddrinfo(hostname, servname, hints, res)
return int(x), err
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// res_search, for cgo systems where that is thread-safe.
//go:build cgo && !netgo && (linux || openbsd)
package net
/*
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netdb.h>
#include <unistd.h>
#include <string.h>
#include <arpa/nameser.h>
#include <resolv.h>
#cgo !android,!openbsd LDFLAGS: -lresolv
*/
import "C"
type _C_struct___res_state = struct{}
func _C_res_ninit(state *_C_struct___res_state) error {
return nil
}
func _C_res_nclose(state *_C_struct___res_state) {
return
}
func _C_res_nsearch(state *_C_struct___res_state, dname *_C_char, class, typ int, ans *_C_uchar, anslen int) (int, error) {
x, err := C.res_search(dname, C.int(class), C.int(typ), ans, C.int(anslen))
return int(x), err
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !js
package net
import (
"internal/bytealg"
"internal/godebug"
"os"
"runtime"
"sync"
"syscall"
)
// conf represents a system's network configuration.
type conf struct {
// forceCgoLookupHost forces CGO to always be used, if available.
forceCgoLookupHost bool
netGo bool // go DNS resolution forced
netCgo bool // non-go DNS resolution forced (cgo, or win32)
// machine has an /etc/mdns.allow file
hasMDNSAllow bool
goos string // the runtime.GOOS, to ease testing
dnsDebugLevel int
}
var (
confOnce sync.Once // guards init of confVal via initConfVal
confVal = &conf{goos: runtime.GOOS}
)
// systemConf returns the machine's network configuration.
func systemConf() *conf {
confOnce.Do(initConfVal)
return confVal
}
func initConfVal() {
dnsMode, debugLevel := goDebugNetDNS()
confVal.dnsDebugLevel = debugLevel
confVal.netGo = netGo || dnsMode == "go"
confVal.netCgo = netCgo || dnsMode == "cgo"
if !confVal.netGo && !confVal.netCgo && (runtime.GOOS == "windows" || runtime.GOOS == "plan9") {
// Neither of these platforms actually uses cgo.
//
// The meaning of "cgo" mode in the net package is
// really "the native OS way", which for libc meant
// cgo on the original platforms that motivated
// PreferGo support before Windows and Plan9 got support,
// at which time the GODEBUG=netdns=go and GODEBUG=netdns=cgo
// names were already kinda locked in.
confVal.netCgo = true
}
if confVal.dnsDebugLevel > 0 {
defer func() {
if confVal.dnsDebugLevel > 1 {
println("go package net: confVal.netCgo =", confVal.netCgo, " netGo =", confVal.netGo)
}
switch {
case confVal.netGo:
if netGo {
println("go package net: built with netgo build tag; using Go's DNS resolver")
} else {
println("go package net: GODEBUG setting forcing use of Go's resolver")
}
case confVal.forceCgoLookupHost:
println("go package net: using cgo DNS resolver")
default:
println("go package net: dynamic selection of DNS resolver")
}
}()
}
// Darwin pops up annoying dialog boxes if programs try to do
// their own DNS requests. So always use cgo instead, which
// avoids that.
if runtime.GOOS == "darwin" || runtime.GOOS == "ios" {
confVal.forceCgoLookupHost = true
return
}
if runtime.GOOS == "windows" || runtime.GOOS == "plan9" {
return
}
// If any resolver options are specified via the environment, force
// cgo. Note that LOCALDOMAIN can change behavior merely by being
// specified with the empty string.
_, localDomainDefined := syscall.Getenv("LOCALDOMAIN")
if os.Getenv("RES_OPTIONS") != "" ||
os.Getenv("HOSTALIASES") != "" ||
confVal.netCgo ||
localDomainDefined {
confVal.forceCgoLookupHost = true
return
}
// OpenBSD apparently lets you override the location of resolv.conf
// with ASR_CONFIG. If we notice that, defer to libc.
if runtime.GOOS == "openbsd" && os.Getenv("ASR_CONFIG") != "" {
confVal.forceCgoLookupHost = true
return
}
if _, err := os.Stat("/etc/mdns.allow"); err == nil {
confVal.hasMDNSAllow = true
}
}
// canUseCgo reports whether calling cgo functions is allowed
// for non-hostname lookups.
func (c *conf) canUseCgo() bool {
ret, _ := c.hostLookupOrder(nil, "")
return ret == hostLookupCgo
}
// hostLookupOrder determines which strategy to use to resolve the hostname.
// The provided Resolver is optional; nil means to not consider its options.
// It also returns dnsConfig when it was used to determine the lookup order.
func (c *conf) hostLookupOrder(r *Resolver, hostname string) (ret hostLookupOrder, dnsConfig *dnsConfig) {
if c.dnsDebugLevel > 1 {
defer func() {
print("go package net: hostLookupOrder(", hostname, ") = ", ret.String(), "\n")
}()
}
fallbackOrder := hostLookupCgo
if c.netGo || r.preferGo() {
switch c.goos {
case "windows":
// TODO(bradfitz): implement files-based
// lookup on Windows too? I guess /etc/hosts
// kinda exists on Windows. But for now, only
// do DNS.
fallbackOrder = hostLookupDNS
default:
fallbackOrder = hostLookupFilesDNS
}
}
if c.forceCgoLookupHost || c.goos == "android" || c.goos == "windows" || c.goos == "plan9" {
return fallbackOrder, nil
}
if bytealg.IndexByteString(hostname, '\\') != -1 || bytealg.IndexByteString(hostname, '%') != -1 {
// Don't deal with special form hostnames with backslashes
// or '%'.
return fallbackOrder, nil
}
conf := getSystemDNSConfig()
if conf.err != nil && !os.IsNotExist(conf.err) && !os.IsPermission(conf.err) {
// If we can't read the resolv.conf file, assume it
// had something important in it and defer to cgo.
// libc's resolver might then fail too, but at least
// it wasn't our fault.
return fallbackOrder, conf
}
if conf.unknownOpt {
return fallbackOrder, conf
}
// OpenBSD is unique and doesn't use nsswitch.conf.
// It also doesn't support mDNS.
if c.goos == "openbsd" {
// OpenBSD's resolv.conf manpage says that a non-existent
// resolv.conf means "lookup" defaults to only "files",
// without DNS lookups.
if os.IsNotExist(conf.err) {
return hostLookupFiles, conf
}
lookup := conf.lookup
if len(lookup) == 0 {
// https://www.openbsd.org/cgi-bin/man.cgi/OpenBSD-current/man5/resolv.conf.5
// "If the lookup keyword is not used in the
// system's resolv.conf file then the assumed
// order is 'bind file'"
return hostLookupDNSFiles, conf
}
if len(lookup) < 1 || len(lookup) > 2 {
return fallbackOrder, conf
}
switch lookup[0] {
case "bind":
if len(lookup) == 2 {
if lookup[1] == "file" {
return hostLookupDNSFiles, conf
}
return fallbackOrder, conf
}
return hostLookupDNS, conf
case "file":
if len(lookup) == 2 {
if lookup[1] == "bind" {
return hostLookupFilesDNS, conf
}
return fallbackOrder, conf
}
return hostLookupFiles, conf
default:
return fallbackOrder, conf
}
}
// Canonicalize the hostname by removing any trailing dot.
if stringsHasSuffix(hostname, ".") {
hostname = hostname[:len(hostname)-1]
}
if stringsHasSuffixFold(hostname, ".local") {
// Per RFC 6762, the ".local" TLD is special. And
// because Go's native resolver doesn't do mDNS or
// similar local resolution mechanisms, assume that
// libc might (via Avahi, etc) and use cgo.
return fallbackOrder, conf
}
nss := getSystemNSS()
srcs := nss.sources["hosts"]
// If /etc/nsswitch.conf doesn't exist or doesn't specify any
// sources for "hosts", assume Go's DNS will work fine.
if os.IsNotExist(nss.err) || (nss.err == nil && len(srcs) == 0) {
if c.goos == "solaris" {
// illumos defaults to "nis [NOTFOUND=return] files"
return fallbackOrder, conf
}
return hostLookupFilesDNS, conf
}
if nss.err != nil {
// We failed to parse or open nsswitch.conf, so
// conservatively assume we should use cgo if it's
// available.
return fallbackOrder, conf
}
var mdnsSource, filesSource, dnsSource bool
var first string
for _, src := range srcs {
if src.source == "myhostname" {
if isLocalhost(hostname) || isGateway(hostname) || isOutbound(hostname) {
return fallbackOrder, conf
}
hn, err := getHostname()
if err != nil || stringsEqualFold(hostname, hn) {
return fallbackOrder, conf
}
continue
}
if src.source == "files" || src.source == "dns" {
if !src.standardCriteria() {
return fallbackOrder, conf // non-standard; let libc deal with it.
}
if src.source == "files" {
filesSource = true
} else if src.source == "dns" {
dnsSource = true
}
if first == "" {
first = src.source
}
continue
}
if stringsHasPrefix(src.source, "mdns") {
// e.g. "mdns4", "mdns4_minimal"
// We already returned above if the name was *.local.
// libc wouldn't have found a hit on this anyway.
mdnsSource = true
continue
}
// Some source we don't know how to deal with.
return fallbackOrder, conf
}
// We don't parse mdns.allow files. They're rare. If one
// exists, it might list other TLDs (besides .local) or even
// '*', so just let libc deal with it.
if mdnsSource && c.hasMDNSAllow {
return fallbackOrder, conf
}
// Cases where Go can handle it without cgo and C thread
// overhead.
switch {
case filesSource && dnsSource:
if first == "files" {
return hostLookupFilesDNS, conf
} else {
return hostLookupDNSFiles, conf
}
case filesSource:
return hostLookupFiles, conf
case dnsSource:
return hostLookupDNS, conf
}
// Something weird. Let libc deal with it.
return fallbackOrder, conf
}
var netdns = godebug.New("netdns")
// goDebugNetDNS parses the value of the GODEBUG "netdns" value.
// The netdns value can be of the form:
//
// 1 // debug level 1
// 2 // debug level 2
// cgo // use cgo for DNS lookups
// go // use go for DNS lookups
// cgo+1 // use cgo for DNS lookups + debug level 1
// 1+cgo // same
// cgo+2 // same, but debug level 2
//
// etc.
func goDebugNetDNS() (dnsMode string, debugLevel int) {
goDebug := netdns.Value()
parsePart := func(s string) {
if s == "" {
return
}
if '0' <= s[0] && s[0] <= '9' {
debugLevel, _, _ = dtoi(s)
} else {
dnsMode = s
}
}
if i := bytealg.IndexByteString(goDebug, '+'); i != -1 {
parsePart(goDebug[:i])
parsePart(goDebug[i+1:])
return
}
parsePart(goDebug)
return
}
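// For example (illustrative values):
//
// // With GODEBUG=netdns=cgo+2 in the environment:
// mode, level := goDebugNetDNS() // mode == "cgo", level == 2
// // With GODEBUG=netdns=1:
// mode, level = goDebugNetDNS() // mode == "", level == 1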
// isLocalhost reports whether h should be considered a "localhost"
// name for the myhostname NSS module.
func isLocalhost(h string) bool {
return stringsEqualFold(h, "localhost") || stringsEqualFold(h, "localhost.localdomain") || stringsHasSuffixFold(h, ".localhost") || stringsHasSuffixFold(h, ".localhost.localdomain")
}
// isGateway reports whether h should be considered a "gateway"
// name for the myhostname NSS module.
func isGateway(h string) bool {
return stringsEqualFold(h, "_gateway")
}
// isOutbound reports whether h should be considered an "outbound"
// name for the myhostname NSS module.
func isOutbound(h string) bool {
return stringsEqualFold(h, "_outbound")
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
import (
"context"
"internal/nettrace"
"syscall"
"time"
)
// defaultTCPKeepAlive is a default constant value for TCP keep-alive times.
// See golang.org/issue/31510.
const (
defaultTCPKeepAlive = 15 * time.Second
)
// A Dialer contains options for connecting to an address.
//
// The zero value for each field is equivalent to dialing
// without that option. Dialing with the zero value of Dialer
// is therefore equivalent to just calling the Dial function.
//
// It is safe to call Dialer's methods concurrently.
type Dialer struct {
// Timeout is the maximum amount of time a dial will wait for
// a connect to complete. If Deadline is also set, it may fail
// earlier.
//
// The default is no timeout.
//
// When using TCP and dialing a host name with multiple IP
// addresses, the timeout may be divided between them.
//
// With or without a timeout, the operating system may impose
// its own earlier timeout. For instance, TCP timeouts are
// often around 3 minutes.
Timeout time.Duration
// Deadline is the absolute point in time after which dials
// will fail. If Timeout is set, it may fail earlier.
// Zero means no deadline, or dependent on the operating system
// as with the Timeout option.
Deadline time.Time
// LocalAddr is the local address to use when dialing an
// address. The address must be of a compatible type for the
// network being dialed.
// If nil, a local address is automatically chosen.
LocalAddr Addr
// DualStack previously enabled RFC 6555 Fast Fallback
// support, also known as "Happy Eyeballs", in which IPv4 is
// tried soon if IPv6 appears to be misconfigured and
// hanging.
//
// Deprecated: Fast Fallback is enabled by default. To
// disable, set FallbackDelay to a negative value.
DualStack bool
// FallbackDelay specifies the length of time to wait before
// spawning a RFC 6555 Fast Fallback connection. That is, this
// is the amount of time to wait for IPv6 to succeed before
// assuming that IPv6 is misconfigured and falling back to
// IPv4.
//
// If zero, a default delay of 300ms is used.
// A negative value disables Fast Fallback support.
FallbackDelay time.Duration
// KeepAlive specifies the interval between keep-alive
// probes for an active network connection.
// If zero, keep-alive probes are sent with a default value
// (currently 15 seconds), if supported by the protocol and operating
// system. Network protocols or operating systems that do
// not support keep-alives ignore this field.
// If negative, keep-alive probes are disabled.
KeepAlive time.Duration
// Resolver optionally specifies an alternate resolver to use.
Resolver *Resolver
// Cancel is an optional channel whose closure indicates that
// the dial should be canceled. Not all types of dials support
// cancellation.
//
// Deprecated: Use DialContext instead.
Cancel <-chan struct{}
// If Control is not nil, it is called after creating the network
// connection but before actually dialing.
//
// Network and address parameters passed to Control method are not
// necessarily the ones passed to Dial. For example, passing "tcp" to Dial
// will cause the Control function to be called with "tcp4" or "tcp6".
//
// Control is ignored if ControlContext is not nil.
Control func(network, address string, c syscall.RawConn) error
// If ControlContext is not nil, it is called after creating the network
// connection but before actually dialing.
//
// Network and address parameters passed to Control method are not
// necessarily the ones passed to Dial. For example, passing "tcp" to Dial
// will cause the Control function to be called with "tcp4" or "tcp6".
//
// If ControlContext is not nil, Control is ignored.
ControlContext func(ctx context.Context, network, address string, c syscall.RawConn) error
}
func (d *Dialer) dualStack() bool { return d.FallbackDelay >= 0 }
func minNonzeroTime(a, b time.Time) time.Time {
if a.IsZero() {
return b
}
if b.IsZero() || a.Before(b) {
return a
}
return b
}
// deadline returns the earliest of:
// - now+Timeout
// - d.Deadline
// - the context's deadline
//
// Or zero, if none of Timeout, Deadline, or context's deadline is set.
func (d *Dialer) deadline(ctx context.Context, now time.Time) (earliest time.Time) {
if d.Timeout != 0 { // including negative, for historical reasons
earliest = now.Add(d.Timeout)
}
if d, ok := ctx.Deadline(); ok {
earliest = minNonzeroTime(earliest, d)
}
return minNonzeroTime(earliest, d.Deadline)
}
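// Illustrative sketch of how the sources combine (times are arbitrary):
//
// now := time.Now()
// d := &Dialer{Timeout: 5 * time.Second, Deadline: now.Add(3 * time.Second)}
// // With no context deadline, d.deadline(ctx, now) == now.Add(3*time.Second),
// // since the absolute Deadline is earlier than now+Timeout.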
func (d *Dialer) resolver() *Resolver {
if d.Resolver != nil {
return d.Resolver
}
return DefaultResolver
}
// partialDeadline returns the deadline to use for a single address,
// when multiple addresses are pending.
func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) {
if deadline.IsZero() {
return deadline, nil
}
timeRemaining := deadline.Sub(now)
if timeRemaining <= 0 {
return time.Time{}, errTimeout
}
// Tentatively allocate equal time to each remaining address.
timeout := timeRemaining / time.Duration(addrsRemaining)
// If the time per address is too short, steal from the end of the list.
const saneMinimum = 2 * time.Second
if timeout < saneMinimum {
if timeRemaining < saneMinimum {
timeout = timeRemaining
} else {
timeout = saneMinimum
}
}
return now.Add(timeout), nil
}
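// Worked example (illustrative): with 30s until the deadline and 3
// addresses remaining, each address is tentatively allotted 10s; with only
// 3s remaining and 3 addresses, saneMinimum applies and the first address
// gets 2s, stealing time from the later ones.
//
// now := time.Now()
// d, _ := partialDeadline(now, now.Add(30*time.Second), 3)
// // d == now.Add(10 * time.Second)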
func (d *Dialer) fallbackDelay() time.Duration {
if d.FallbackDelay > 0 {
return d.FallbackDelay
} else {
return 300 * time.Millisecond
}
}
func parseNetwork(ctx context.Context, network string, needsProto bool) (afnet string, proto int, err error) {
i := last(network, ':')
if i < 0 { // no colon
switch network {
case "tcp", "tcp4", "tcp6":
case "udp", "udp4", "udp6":
case "ip", "ip4", "ip6":
if needsProto {
return "", 0, UnknownNetworkError(network)
}
case "unix", "unixgram", "unixpacket":
default:
return "", 0, UnknownNetworkError(network)
}
return network, 0, nil
}
afnet = network[:i]
switch afnet {
case "ip", "ip4", "ip6":
protostr := network[i+1:]
proto, i, ok := dtoi(protostr)
if !ok || i != len(protostr) {
proto, err = lookupProtocol(ctx, protostr)
if err != nil {
return "", 0, err
}
}
return afnet, proto, nil
}
return "", 0, UnknownNetworkError(network)
}
// resolveAddrList resolves addr using hint and returns a list of
// addresses. The result contains at least one address when error is
// nil.
func (r *Resolver) resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (addrList, error) {
afnet, _, err := parseNetwork(ctx, network, true)
if err != nil {
return nil, err
}
if op == "dial" && addr == "" {
return nil, errMissingAddress
}
switch afnet {
case "unix", "unixgram", "unixpacket":
addr, err := ResolveUnixAddr(afnet, addr)
if err != nil {
return nil, err
}
if op == "dial" && hint != nil && addr.Network() != hint.Network() {
return nil, &AddrError{Err: "mismatched local address type", Addr: hint.String()}
}
return addrList{addr}, nil
}
addrs, err := r.internetAddrList(ctx, afnet, addr)
if err != nil || op != "dial" || hint == nil {
return addrs, err
}
var (
tcp *TCPAddr
udp *UDPAddr
ip *IPAddr
wildcard bool
)
switch hint := hint.(type) {
case *TCPAddr:
tcp = hint
wildcard = tcp.isWildcard()
case *UDPAddr:
udp = hint
wildcard = udp.isWildcard()
case *IPAddr:
ip = hint
wildcard = ip.isWildcard()
}
naddrs := addrs[:0]
for _, addr := range addrs {
if addr.Network() != hint.Network() {
return nil, &AddrError{Err: "mismatched local address type", Addr: hint.String()}
}
switch addr := addr.(type) {
case *TCPAddr:
if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(tcp.IP) {
continue
}
naddrs = append(naddrs, addr)
case *UDPAddr:
if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(udp.IP) {
continue
}
naddrs = append(naddrs, addr)
case *IPAddr:
if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(ip.IP) {
continue
}
naddrs = append(naddrs, addr)
}
}
if len(naddrs) == 0 {
return nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: hint.String()}
}
return naddrs, nil
}
// Dial connects to the address on the named network.
//
// Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only),
// "udp", "udp4" (IPv4-only), "udp6" (IPv6-only), "ip", "ip4"
// (IPv4-only), "ip6" (IPv6-only), "unix", "unixgram" and
// "unixpacket".
//
// For TCP and UDP networks, the address has the form "host:port".
// The host must be a literal IP address, or a host name that can be
// resolved to IP addresses.
// The port must be a literal port number or a service name.
// If the host is a literal IPv6 address it must be enclosed in square
// brackets, as in "[2001:db8::1]:80" or "[fe80::1%zone]:80".
// The zone specifies the scope of the literal IPv6 address as defined
// in RFC 4007.
// The functions JoinHostPort and SplitHostPort manipulate a pair of
// host and port in this form.
// When using TCP, and the host resolves to multiple IP addresses,
// Dial will try each IP address in order until one succeeds.
//
// Examples:
//
// Dial("tcp", "golang.org:http")
// Dial("tcp", "192.0.2.1:http")
// Dial("tcp", "198.51.100.1:80")
// Dial("udp", "[2001:db8::1]:domain")
// Dial("udp", "[fe80::1%lo0]:53")
// Dial("tcp", ":80")
//
// For IP networks, the network must be "ip", "ip4" or "ip6" followed
// by a colon and a literal protocol number or a protocol name, and
// the address has the form "host". The host must be a literal IP
// address or a literal IPv6 address with zone.
// Behavior with a non-well-known protocol number such as "0" or
// "255" depends on the operating system.
//
// Examples:
//
// Dial("ip4:1", "192.0.2.1")
// Dial("ip6:ipv6-icmp", "2001:db8::1")
// Dial("ip6:58", "fe80::1%lo0")
//
// For TCP, UDP and IP networks, if the host is empty or a literal
// unspecified IP address, as in ":80", "0.0.0.0:80" or "[::]:80" for
// TCP and UDP, "", "0.0.0.0" or "::" for IP, the local system is
// assumed.
//
// For Unix networks, the address must be a file system path.
func Dial(network, address string) (Conn, error) {
var d Dialer
return d.Dial(network, address)
}
// DialTimeout acts like Dial but takes a timeout.
//
// The timeout includes name resolution, if required.
// When using TCP, and the host in the address parameter resolves to
// multiple IP addresses, the timeout is spread over each consecutive
// dial, such that each is given an appropriate fraction of the time
// to connect.
//
// See func Dial for a description of the network and address
// parameters.
func DialTimeout(network, address string, timeout time.Duration) (Conn, error) {
d := Dialer{Timeout: timeout}
return d.Dial(network, address)
}
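// Illustrative usage:
//
// c, err := DialTimeout("tcp", "example.com:80", 5*time.Second)
// if err != nil {
// // handle error
// }
// defer c.Close()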
// sysDialer contains a Dial's parameters and configuration.
type sysDialer struct {
Dialer
network, address string
testHookDialTCP func(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error)
}
// Dial connects to the address on the named network.
//
// See func Dial for a description of the network and address
// parameters.
//
// Dial uses context.Background internally; to specify the context, use
// DialContext.
func (d *Dialer) Dial(network, address string) (Conn, error) {
return d.DialContext(context.Background(), network, address)
}
// DialContext connects to the address on the named network using
// the provided context.
//
// The provided Context must be non-nil. If the context expires before
// the connection is complete, an error is returned. Once successfully
// connected, any expiration of the context will not affect the
// connection.
//
// When using TCP, and the host in the address parameter resolves to multiple
// network addresses, any dial timeout (from d.Timeout or ctx) is spread
// over each consecutive dial, such that each is given an appropriate
// fraction of the time to connect.
// For example, if a host has 4 IP addresses and the timeout is 1 minute,
// the connect to each single address will be given 15 seconds to complete
// before trying the next one.
//
// See func Dial for a description of the network and address
// parameters.
func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn, error) {
if ctx == nil {
panic("nil context")
}
deadline := d.deadline(ctx, time.Now())
if !deadline.IsZero() {
if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
subCtx, cancel := context.WithDeadline(ctx, deadline)
defer cancel()
ctx = subCtx
}
}
if oldCancel := d.Cancel; oldCancel != nil {
subCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
select {
case <-oldCancel:
cancel()
case <-subCtx.Done():
}
}()
ctx = subCtx
}
// Shadow the nettrace (if any) during resolve so Connect events don't fire for DNS lookups.
resolveCtx := ctx
if trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace); trace != nil {
shadow := *trace
shadow.ConnectStart = nil
shadow.ConnectDone = nil
resolveCtx = context.WithValue(resolveCtx, nettrace.TraceKey{}, &shadow)
}
addrs, err := d.resolver().resolveAddrList(resolveCtx, "dial", network, address, d.LocalAddr)
if err != nil {
return nil, &OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err}
}
sd := &sysDialer{
Dialer: *d,
network: network,
address: address,
}
var primaries, fallbacks addrList
if d.dualStack() && network == "tcp" {
primaries, fallbacks = addrs.partition(isIPv4)
} else {
primaries = addrs
}
return sd.dialParallel(ctx, primaries, fallbacks)
}
// dialParallel races two copies of dialSerial, giving the first a
// head start. It returns the first established connection and
// closes the others. Otherwise it returns an error from the first
// primary address.
func (sd *sysDialer) dialParallel(ctx context.Context, primaries, fallbacks addrList) (Conn, error) {
if len(fallbacks) == 0 {
return sd.dialSerial(ctx, primaries)
}
returned := make(chan struct{})
defer close(returned)
type dialResult struct {
Conn
error
primary bool
done bool
}
results := make(chan dialResult) // unbuffered
startRacer := func(ctx context.Context, primary bool) {
ras := primaries
if !primary {
ras = fallbacks
}
c, err := sd.dialSerial(ctx, ras)
select {
case results <- dialResult{Conn: c, error: err, primary: primary, done: true}:
case <-returned:
if c != nil {
c.Close()
}
}
}
var primary, fallback dialResult
// Start the main racer.
primaryCtx, primaryCancel := context.WithCancel(ctx)
defer primaryCancel()
go startRacer(primaryCtx, true)
// Start the timer for the fallback racer.
fallbackTimer := time.NewTimer(sd.fallbackDelay())
defer fallbackTimer.Stop()
for {
select {
case <-fallbackTimer.C:
fallbackCtx, fallbackCancel := context.WithCancel(ctx)
defer fallbackCancel()
go startRacer(fallbackCtx, false)
case res := <-results:
if res.error == nil {
return res.Conn, nil
}
if res.primary {
primary = res
} else {
fallback = res
}
if primary.done && fallback.done {
return nil, primary.error
}
if res.primary && fallbackTimer.Stop() {
// If we were able to stop the timer, that means it
// was running (hadn't yet started the fallback), but
// we just got an error on the primary path, so start
// the fallback immediately (in 0 nanoseconds).
fallbackTimer.Reset(0)
}
}
}
}
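// Illustrative usage (values are arbitrary): a caller tunes the race via
// Dialer.FallbackDelay; here the fallback racer starts 100ms after the
// primary instead of the default 300ms:
//
// d := &Dialer{FallbackDelay: 100 * time.Millisecond}
// c, err := d.Dial("tcp", "example.com:443")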
// dialSerial connects to a list of addresses in sequence, returning
// either the first successful connection, or the first error.
func (sd *sysDialer) dialSerial(ctx context.Context, ras addrList) (Conn, error) {
var firstErr error // The error from the first address is most relevant.
for i, ra := range ras {
select {
case <-ctx.Done():
return nil, &OpError{Op: "dial", Net: sd.network, Source: sd.LocalAddr, Addr: ra, Err: mapErr(ctx.Err())}
default:
}
dialCtx := ctx
if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
partialDeadline, err := partialDeadline(time.Now(), deadline, len(ras)-i)
if err != nil {
// Ran out of time.
if firstErr == nil {
firstErr = &OpError{Op: "dial", Net: sd.network, Source: sd.LocalAddr, Addr: ra, Err: err}
}
break
}
if partialDeadline.Before(deadline) {
var cancel context.CancelFunc
dialCtx, cancel = context.WithDeadline(ctx, partialDeadline)
defer cancel()
}
}
c, err := sd.dialSingle(dialCtx, ra)
if err == nil {
return c, nil
}
if firstErr == nil {
firstErr = err
}
}
if firstErr == nil {
firstErr = &OpError{Op: "dial", Net: sd.network, Source: nil, Addr: nil, Err: errMissingAddress}
}
return nil, firstErr
}
// dialSingle attempts to establish and returns a single connection to
// the destination address.
func (sd *sysDialer) dialSingle(ctx context.Context, ra Addr) (c Conn, err error) {
trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace)
if trace != nil {
raStr := ra.String()
if trace.ConnectStart != nil {
trace.ConnectStart(sd.network, raStr)
}
if trace.ConnectDone != nil {
defer func() { trace.ConnectDone(sd.network, raStr, err) }()
}
}
la := sd.LocalAddr
switch ra := ra.(type) {
case *TCPAddr:
la, _ := la.(*TCPAddr)
c, err = sd.dialTCP(ctx, la, ra)
case *UDPAddr:
la, _ := la.(*UDPAddr)
c, err = sd.dialUDP(ctx, la, ra)
case *IPAddr:
la, _ := la.(*IPAddr)
c, err = sd.dialIP(ctx, la, ra)
case *UnixAddr:
la, _ := la.(*UnixAddr)
c, err = sd.dialUnix(ctx, la, ra)
default:
return nil, &OpError{Op: "dial", Net: sd.network, Source: la, Addr: ra, Err: &AddrError{Err: "unexpected address type", Addr: sd.address}}
}
if err != nil {
return nil, &OpError{Op: "dial", Net: sd.network, Source: la, Addr: ra, Err: err} // c is non-nil interface containing nil pointer
}
return c, nil
}
// ListenConfig contains options for listening to an address.
type ListenConfig struct {
// If Control is not nil, it is called after creating the network
// connection but before binding it to the operating system.
//
// Network and address parameters passed to Control method are not
// necessarily the ones passed to Listen. For example, passing "tcp" to
// Listen will cause the Control function to be called with "tcp4" or "tcp6".
Control func(network, address string, c syscall.RawConn) error
// KeepAlive specifies the keep-alive period for network
// connections accepted by this listener.
// If zero, keep-alives are enabled if supported by the protocol
// and operating system. Network protocols or operating systems
// that do not support keep-alives ignore this field.
// If negative, keep-alives are disabled.
KeepAlive time.Duration
}
// Listen announces on the local network address.
//
// See func Listen for a description of the network and address
// parameters.
func (lc *ListenConfig) Listen(ctx context.Context, network, address string) (Listener, error) {
addrs, err := DefaultResolver.resolveAddrList(ctx, "listen", network, address, nil)
if err != nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err}
}
sl := &sysListener{
ListenConfig: *lc,
network: network,
address: address,
}
var l Listener
la := addrs.first(isIPv4)
switch la := la.(type) {
case *TCPAddr:
l, err = sl.listenTCP(ctx, la)
case *UnixAddr:
l, err = sl.listenUnix(ctx, la)
default:
return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}}
}
if err != nil {
return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: err} // l is non-nil interface containing nil pointer
}
return l, nil
}
// ListenPacket announces on the local network address.
//
// See func ListenPacket for a description of the network and address
// parameters.
func (lc *ListenConfig) ListenPacket(ctx context.Context, network, address string) (PacketConn, error) {
addrs, err := DefaultResolver.resolveAddrList(ctx, "listen", network, address, nil)
if err != nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err}
}
sl := &sysListener{
ListenConfig: *lc,
network: network,
address: address,
}
var c PacketConn
la := addrs.first(isIPv4)
switch la := la.(type) {
case *UDPAddr:
c, err = sl.listenUDP(ctx, la)
case *IPAddr:
c, err = sl.listenIP(ctx, la)
case *UnixAddr:
c, err = sl.listenUnixgram(ctx, la)
default:
return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}}
}
if err != nil {
return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: err} // c is non-nil interface containing nil pointer
}
return c, nil
}
// sysListener contains a Listen's parameters and configuration.
type sysListener struct {
ListenConfig
network, address string
}
// Listen announces on the local network address.
//
// The network must be "tcp", "tcp4", "tcp6", "unix" or "unixpacket".
//
// For TCP networks, if the host in the address parameter is empty or
// a literal unspecified IP address, Listen listens on all available
// unicast and anycast IP addresses of the local system.
// To only use IPv4, use network "tcp4".
// The address can use a host name, but this is not recommended,
// because it will create a listener for at most one of the host's IP
// addresses.
// If the port in the address parameter is empty or "0", as in
// "127.0.0.1:" or "[::1]:0", a port number is automatically chosen.
// The Addr method of Listener can be used to discover the chosen
// port.
//
// See func Dial for a description of the network and address
// parameters.
//
// Listen uses context.Background internally; to specify the context, use
// ListenConfig.Listen.
func Listen(network, address string) (Listener, error) {
var lc ListenConfig
return lc.Listen(context.Background(), network, address)
}
// ListenPacket announces on the local network address.
//
// The network must be "udp", "udp4", "udp6", "unixgram", or an IP
// transport. The IP transports are "ip", "ip4", or "ip6" followed by
// a colon and a literal protocol number or a protocol name, as in
// "ip:1" or "ip:icmp".
//
// For UDP and IP networks, if the host in the address parameter is
// empty or a literal unspecified IP address, ListenPacket listens on
// all available IP addresses of the local system except multicast IP
// addresses.
// To only use IPv4, use network "udp4" or "ip4:proto".
// The address can use a host name, but this is not recommended,
// because it will create a listener for at most one of the host's IP
// addresses.
// If the port in the address parameter is empty or "0", as in
// "127.0.0.1:" or "[::1]:0", a port number is automatically chosen.
// The LocalAddr method of PacketConn can be used to discover the
// chosen port.
//
// See func Dial for a description of the network and address
// parameters.
//
// ListenPacket uses context.Background internally; to specify the context, use
// ListenConfig.ListenPacket.
func ListenPacket(network, address string) (PacketConn, error) {
var lc ListenConfig
return lc.ListenPacket(context.Background(), network, address)
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
import (
"internal/bytealg"
"internal/itoa"
"sort"
"golang.org/x/net/dns/dnsmessage"
)
// provided by runtime
func fastrandu() uint
func randInt() int {
return int(fastrandu() >> 1) // clear sign bit
}
func randIntn(n int) int {
return randInt() % n
}
// reverseaddr returns the in-addr.arpa. or ip6.arpa. hostname of the IP
// address addr suitable for rDNS (PTR) record lookup or an error if it fails
// to parse the IP address.
func reverseaddr(addr string) (arpa string, err error) {
ip := ParseIP(addr)
if ip == nil {
return "", &DNSError{Err: "unrecognized address", Name: addr}
}
if ip.To4() != nil {
return itoa.Uitoa(uint(ip[15])) + "." + itoa.Uitoa(uint(ip[14])) + "." + itoa.Uitoa(uint(ip[13])) + "." + itoa.Uitoa(uint(ip[12])) + ".in-addr.arpa.", nil
}
// Must be IPv6
buf := make([]byte, 0, len(ip)*4+len("ip6.arpa."))
// Add it, in reverse, to the buffer
for i := len(ip) - 1; i >= 0; i-- {
v := ip[i]
buf = append(buf, hexDigit[v&0xF],
'.',
hexDigit[v>>4],
'.')
}
// Append "ip6.arpa." and return (buf already has the final .)
buf = append(buf, "ip6.arpa."...)
return string(buf), nil
}
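// For example:
//
// arpa, _ := reverseaddr("192.0.2.1") // "1.2.0.192.in-addr.arpa."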
func equalASCIIName(x, y dnsmessage.Name) bool {
if x.Length != y.Length {
return false
}
for i := 0; i < int(x.Length); i++ {
a := x.Data[i]
b := y.Data[i]
if 'A' <= a && a <= 'Z' {
a += 0x20
}
if 'A' <= b && b <= 'Z' {
b += 0x20
}
if a != b {
return false
}
}
return true
}
// isDomainName checks if a string is a presentation-format domain name
// (currently restricted to hostname-compatible "preferred name" LDH labels and
// SRV-like "underscore labels"; see golang.org/issue/12421).
func isDomainName(s string) bool {
// The root domain name is valid. See golang.org/issue/45715.
if s == "." {
return true
}
// See RFC 1035, RFC 3696.
// Presentation format has dots before every label except the first, and the
// terminal empty label is optional here because we assume fully-qualified
// (absolute) input. We must therefore reserve space for the first and last
// labels' length octets in wire format, where they are necessary and the
// maximum total length is 255.
// So our _effective_ maximum is 253, but 254 is not rejected if the last
// character is a dot.
l := len(s)
if l == 0 || l > 254 || l == 254 && s[l-1] != '.' {
return false
}
last := byte('.')
nonNumeric := false // true once we've seen a letter or hyphen
partlen := 0
for i := 0; i < len(s); i++ {
c := s[i]
switch {
default:
return false
case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_':
nonNumeric = true
partlen++
case '0' <= c && c <= '9':
// fine
partlen++
case c == '-':
// Byte before dash cannot be dot.
if last == '.' {
return false
}
partlen++
nonNumeric = true
case c == '.':
// Byte before dot cannot be dot, dash.
if last == '.' || last == '-' {
return false
}
if partlen > 63 || partlen == 0 {
return false
}
partlen = 0
}
last = c
}
if last == '-' || partlen > 63 {
return false
}
return nonNumeric
}
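// For example:
//
// isDomainName("example.com") // true
// isDomainName("_sip._tcp.example.com") // true (underscore labels)
// isDomainName("exa mple.com") // false (space is not an LDH character)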
// absDomainName returns an absolute domain name which ends with a
// trailing dot to match pure Go reverse resolver and all other lookup
// routines.
// See golang.org/issue/12189.
// But we don't want to add dots for local names from /etc/hosts.
// It's hard to tell so we settle on the heuristic that names without dots
// (like "localhost" or "myhost") do not get trailing dots, but any other
// names do.
func absDomainName(s string) string {
if bytealg.IndexByteString(s, '.') != -1 && s[len(s)-1] != '.' {
s += "."
}
return s
}
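// For example:
//
// absDomainName("localhost") // "localhost" (no dot in name, left as-is)
// absDomainName("example.com") // "example.com."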
// An SRV represents a single DNS SRV record.
type SRV struct {
Target string
Port uint16
Priority uint16
Weight uint16
}
// byPriorityWeight sorts SRV records by ascending priority and weight.
type byPriorityWeight []*SRV
func (s byPriorityWeight) Len() int { return len(s) }
func (s byPriorityWeight) Less(i, j int) bool {
return s[i].Priority < s[j].Priority || (s[i].Priority == s[j].Priority && s[i].Weight < s[j].Weight)
}
func (s byPriorityWeight) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
// shuffleByWeight shuffles SRV records by weight using the algorithm
// described in RFC 2782.
func (addrs byPriorityWeight) shuffleByWeight() {
sum := 0
for _, addr := range addrs {
sum += int(addr.Weight)
}
for sum > 0 && len(addrs) > 1 {
s := 0
n := randIntn(sum)
for i := range addrs {
s += int(addrs[i].Weight)
if s > n {
if i > 0 {
addrs[0], addrs[i] = addrs[i], addrs[0]
}
break
}
}
sum -= int(addrs[0].Weight)
addrs = addrs[1:]
}
}
// sort reorders SRV records as specified in RFC 2782.
func (addrs byPriorityWeight) sort() {
sort.Sort(addrs)
i := 0
for j := 1; j < len(addrs); j++ {
if addrs[i].Priority != addrs[j].Priority {
addrs[i:j].shuffleByWeight()
i = j
}
}
addrs[i:].shuffleByWeight()
}
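// Worked example (illustrative): with priorities {10, 10, 20}, sort first
// orders by priority, then weight-shuffles the two priority-10 records
// among themselves; the priority-20 record always comes last.
//
// addrs := byPriorityWeight{
// {Target: "a.example.com.", Priority: 10, Weight: 60},
// {Target: "b.example.com.", Priority: 10, Weight: 40},
// {Target: "c.example.com.", Priority: 20, Weight: 100},
// }
// addrs.sort() // "a" first with probability 0.6, "b" with 0.4; "c" last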
// An MX represents a single DNS MX record.
type MX struct {
Host string
Pref uint16
}
// byPref implements sort.Interface to sort MX records by preference
type byPref []*MX
func (s byPref) Len() int { return len(s) }
func (s byPref) Less(i, j int) bool { return s[i].Pref < s[j].Pref }
func (s byPref) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
// sort reorders MX records as specified in RFC 5321.
func (s byPref) sort() {
for i := range s {
j := randIntn(i + 1)
s[i], s[j] = s[j], s[i]
}
sort.Sort(s)
}
// An NS represents a single DNS NS record.
type NS struct {
Host string
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !js
// DNS client: see RFC 1035.
// Has to be linked into package net for Dial.
// TODO(rsc):
// Could potentially handle many outstanding lookups faster.
// Random UDP source port (net.Dial should do that for us).
// Random request IDs.
package net
import (
"context"
"errors"
"internal/itoa"
"io"
"os"
"runtime"
"sync"
"sync/atomic"
"time"
"golang.org/x/net/dns/dnsmessage"
)
const (
// to be used as a useTCP parameter to exchange
useTCPOnly = true
useUDPOrTCP = false
// Maximum DNS packet size.
// Value taken from https://dnsflagday.net/2020/.
maxDNSPacketSize = 1232
)
var (
errLameReferral = errors.New("lame referral")
errCannotUnmarshalDNSMessage = errors.New("cannot unmarshal DNS message")
errCannotMarshalDNSMessage = errors.New("cannot marshal DNS message")
errServerMisbehaving = errors.New("server misbehaving")
errInvalidDNSResponse = errors.New("invalid DNS response")
errNoAnswerFromDNSServer = errors.New("no answer from DNS server")
// errServerTemporarilyMisbehaving is like errServerMisbehaving, except
// that when it gets translated to a DNSError, the IsTemporary field
// gets set to true.
errServerTemporarilyMisbehaving = errors.New("server misbehaving")
)
func newRequest(q dnsmessage.Question, ad bool) (id uint16, udpReq, tcpReq []byte, err error) {
id = uint16(randInt())
b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true, AuthenticData: ad})
if err := b.StartQuestions(); err != nil {
return 0, nil, nil, err
}
if err := b.Question(q); err != nil {
return 0, nil, nil, err
}
// Accept packets up to maxDNSPacketSize. RFC 6891.
if err := b.StartAdditionals(); err != nil {
return 0, nil, nil, err
}
var rh dnsmessage.ResourceHeader
if err := rh.SetEDNS0(maxDNSPacketSize, dnsmessage.RCodeSuccess, false); err != nil {
return 0, nil, nil, err
}
if err := b.OPTResource(rh, dnsmessage.OPTResource{}); err != nil {
return 0, nil, nil, err
}
tcpReq, err = b.Finish()
if err != nil {
return 0, nil, nil, err
}
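// DNS over TCP prefixes each message with its length as a 2-byte,
// big-endian integer (RFC 1035 section 4.2.2). The builder reserved two
// bytes up front, so fill in the prefix for tcpReq and let udpReq alias
// the unprefixed tail of the same buffer.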
udpReq = tcpReq[2:]
l := len(tcpReq) - 2
tcpReq[0] = byte(l >> 8)
tcpReq[1] = byte(l)
return id, udpReq, tcpReq, nil
}
func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool {
if !respHdr.Response {
return false
}
if reqID != respHdr.ID {
return false
}
if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) {
return false
}
return true
}
func dnsPacketRoundTrip(c Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
if _, err := c.Write(b); err != nil {
return dnsmessage.Parser{}, dnsmessage.Header{}, err
}
b = make([]byte, maxDNSPacketSize)
for {
n, err := c.Read(b)
if err != nil {
return dnsmessage.Parser{}, dnsmessage.Header{}, err
}
var p dnsmessage.Parser
// Ignore invalid responses as they may be malicious
// forgery attempts. Instead continue waiting until
// timeout. See golang.org/issue/13281.
h, err := p.Start(b[:n])
if err != nil {
continue
}
q, err := p.Question()
if err != nil || !checkResponse(id, query, h, q) {
continue
}
return p, h, nil
}
}
func dnsStreamRoundTrip(c Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
if _, err := c.Write(b); err != nil {
return dnsmessage.Parser{}, dnsmessage.Header{}, err
}
b = make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035
if _, err := io.ReadFull(c, b[:2]); err != nil {
return dnsmessage.Parser{}, dnsmessage.Header{}, err
}
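// The first two bytes carry the big-endian length of the reply
// (RFC 1035 section 4.2.2).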
l := int(b[0])<<8 | int(b[1])
if l > len(b) {
b = make([]byte, l)
}
n, err := io.ReadFull(c, b[:l])
if err != nil {
return dnsmessage.Parser{}, dnsmessage.Header{}, err
}
var p dnsmessage.Parser
h, err := p.Start(b[:n])
if err != nil {
return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage
}
q, err := p.Question()
if err != nil {
return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage
}
if !checkResponse(id, query, h, q) {
return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
}
return p, h, nil
}
// exchange sends a query on the connection and hopes for a response.
func (r *Resolver) exchange(ctx context.Context, server string, q dnsmessage.Question, timeout time.Duration, useTCP, ad bool) (dnsmessage.Parser, dnsmessage.Header, error) {
q.Class = dnsmessage.ClassINET
id, udpReq, tcpReq, err := newRequest(q, ad)
if err != nil {
return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotMarshalDNSMessage
}
var networks []string
if useTCP {
networks = []string{"tcp"}
} else {
networks = []string{"udp", "tcp"}
}
for _, network := range networks {
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
defer cancel()
c, err := r.dial(ctx, network, server)
if err != nil {
return dnsmessage.Parser{}, dnsmessage.Header{}, err
}
if d, ok := ctx.Deadline(); ok && !d.IsZero() {
c.SetDeadline(d)
}
var p dnsmessage.Parser
var h dnsmessage.Header
if _, ok := c.(PacketConn); ok {
p, h, err = dnsPacketRoundTrip(c, id, q, udpReq)
} else {
p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq)
}
c.Close()
if err != nil {
return dnsmessage.Parser{}, dnsmessage.Header{}, mapErr(err)
}
if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone {
return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
}
if h.Truncated { // see RFC 5966
continue
}
return p, h, nil
}
return dnsmessage.Parser{}, dnsmessage.Header{}, errNoAnswerFromDNSServer
}
// checkHeader performs basic sanity checks on the header.
func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error {
if h.RCode == dnsmessage.RCodeNameError {
return errNoSuchHost
}
_, err := p.AnswerHeader()
if err != nil && err != dnsmessage.ErrSectionDone {
return errCannotUnmarshalDNSMessage
}
// libresolv continues to the next server when it receives
// an invalid referral response. See golang.org/issue/15434.
if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone {
return errLameReferral
}
if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError {
// None of the error codes make sense
// for the query we sent. If we didn't get
// a name error and we didn't get success,
// the server is behaving incorrectly or
// having temporary trouble.
if h.RCode == dnsmessage.RCodeServerFailure {
return errServerTemporarilyMisbehaving
}
return errServerMisbehaving
}
return nil
}
func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error {
for {
h, err := p.AnswerHeader()
if err == dnsmessage.ErrSectionDone {
return errNoSuchHost
}
if err != nil {
return errCannotUnmarshalDNSMessage
}
if h.Type == qtype {
return nil
}
if err := p.SkipAnswer(); err != nil {
return errCannotUnmarshalDNSMessage
}
}
}
// Do a lookup for a single name, which must be rooted
// (otherwise the answers will not be found).
func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) {
var lastErr error
serverOffset := cfg.serverOffset()
sLen := uint32(len(cfg.servers))
n, err := dnsmessage.NewName(name)
if err != nil {
return dnsmessage.Parser{}, "", errCannotMarshalDNSMessage
}
q := dnsmessage.Question{
Name: n,
Type: qtype,
Class: dnsmessage.ClassINET,
}
for i := 0; i < cfg.attempts; i++ {
for j := uint32(0); j < sLen; j++ {
server := cfg.servers[(serverOffset+j)%sLen]
p, h, err := r.exchange(ctx, server, q, cfg.timeout, cfg.useTCP, cfg.trustAD)
if err != nil {
dnsErr := &DNSError{
Err: err.Error(),
Name: name,
Server: server,
}
if nerr, ok := err.(Error); ok && nerr.Timeout() {
dnsErr.IsTimeout = true
}
// Set IsTemporary for socket-level errors. Note that this flag
// may also be used to indicate a SERVFAIL response.
if _, ok := err.(*OpError); ok {
dnsErr.IsTemporary = true
}
lastErr = dnsErr
continue
}
if err := checkHeader(&p, h); err != nil {
dnsErr := &DNSError{
Err: err.Error(),
Name: name,
Server: server,
}
if err == errServerTemporarilyMisbehaving {
dnsErr.IsTemporary = true
}
if err == errNoSuchHost {
// The name does not exist, so trying
// another server won't help.
dnsErr.IsNotFound = true
return p, server, dnsErr
}
lastErr = dnsErr
continue
}
err = skipToAnswer(&p, qtype)
if err == nil {
return p, server, nil
}
lastErr = &DNSError{
Err: err.Error(),
Name: name,
Server: server,
}
if err == errNoSuchHost {
// The name does not exist, so trying another
// server won't help.
lastErr.(*DNSError).IsNotFound = true
return p, server, lastErr
}
}
}
return dnsmessage.Parser{}, "", lastErr
}
// A resolverConfig represents a DNS stub resolver configuration.
type resolverConfig struct {
initOnce sync.Once // guards init of resolverConfig
// ch is used as a semaphore that only allows one lookup at a
// time to recheck resolv.conf.
ch chan struct{} // guards lastChecked and modTime
lastChecked time.Time // last time resolv.conf was checked
dnsConfig atomic.Pointer[dnsConfig] // parsed resolv.conf structure used in lookups
}
var resolvConf resolverConfig
func getSystemDNSConfig() *dnsConfig {
resolvConf.tryUpdate("/etc/resolv.conf")
return resolvConf.dnsConfig.Load()
}
// init initializes conf and is only called via conf.initOnce.
func (conf *resolverConfig) init() {
// Set dnsConfig and lastChecked so we don't parse
// resolv.conf twice the first time.
conf.dnsConfig.Store(dnsReadConfig("/etc/resolv.conf"))
conf.lastChecked = time.Now()
// Prepare ch so that only one update of resolverConfig may
// run at once.
conf.ch = make(chan struct{}, 1)
}
// tryUpdate tries to update conf with the named resolv.conf file.
// The name variable only exists for testing. It is otherwise always
// "/etc/resolv.conf".
func (conf *resolverConfig) tryUpdate(name string) {
conf.initOnce.Do(conf.init)
if conf.dnsConfig.Load().noReload {
return
}
// Ensure only one update at a time checks resolv.conf.
if !conf.tryAcquireSema() {
return
}
defer conf.releaseSema()
now := time.Now()
if conf.lastChecked.After(now.Add(-5 * time.Second)) {
return
}
conf.lastChecked = now
switch runtime.GOOS {
case "windows":
// There's no file on disk, so don't bother checking
// and failing.
//
// The Windows implementation of dnsReadConfig (called
// below) ignores the name.
default:
var mtime time.Time
if fi, err := os.Stat(name); err == nil {
mtime = fi.ModTime()
}
if mtime.Equal(conf.dnsConfig.Load().mtime) {
return
}
}
dnsConf := dnsReadConfig(name)
conf.dnsConfig.Store(dnsConf)
}
func (conf *resolverConfig) tryAcquireSema() bool {
select {
case conf.ch <- struct{}{}:
return true
default:
return false
}
}
func (conf *resolverConfig) releaseSema() {
<-conf.ch
}
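// The two helpers above implement a non-blocking "try lock" with a
// buffered channel, a common Go idiom. A minimal, self-contained sketch of
// the same pattern (illustrative only; not part of this package's API):
//
//	sema := make(chan struct{}, 1) // capacity 1: at most one holder
//	tryAcquire := func() bool {
//		select {
//		case sema <- struct{}{}:
//			return true // acquired without blocking
//		default:
//			return false // held elsewhere; give up immediately
//		}
//	}
//	release := func() { <-sema }
//	if tryAcquire() {
//		defer release()
//		// single-flight work, e.g. re-checking a config file
//	}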
func (r *Resolver) lookup(ctx context.Context, name string, qtype dnsmessage.Type, conf *dnsConfig) (dnsmessage.Parser, string, error) {
if !isDomainName(name) {
// We used to use "invalid domain name" as the error,
// but that is a detail of the specific lookup mechanism.
// Other lookups might allow broader name syntax
// (for example Multicast DNS allows UTF-8; see RFC 6762).
// For consistency with libc resolvers, report no such host.
return dnsmessage.Parser{}, "", &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true}
}
if conf == nil {
conf = getSystemDNSConfig()
}
var (
p dnsmessage.Parser
server string
err error
)
for _, fqdn := range conf.nameList(name) {
p, server, err = r.tryOneName(ctx, conf, fqdn, qtype)
if err == nil {
break
}
if nerr, ok := err.(Error); ok && nerr.Temporary() && r.strictErrors() {
// If we hit a temporary error with StrictErrors enabled,
// stop immediately instead of trying more names.
break
}
}
if err == nil {
return p, server, nil
}
if err, ok := err.(*DNSError); ok {
// Show original name passed to lookup, not suffixed one.
// In general we might have tried many suffixes; showing
// just one is misleading. See also golang.org/issue/6324.
err.Name = name
}
return dnsmessage.Parser{}, "", err
}
// avoidDNS reports whether this is a hostname for which we should not
// use DNS. Currently this includes only .onion, per RFC 7686. See
// golang.org/issue/13705. Does not cover .local names (RFC 6762),
// see golang.org/issue/16739.
func avoidDNS(name string) bool {
if name == "" {
return true
}
if name[len(name)-1] == '.' {
name = name[:len(name)-1]
}
return stringsHasSuffixFold(name, ".onion")
}
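// For example (illustrative inputs, matching the logic above):
//
//	avoidDNS("foo.onion")   // true: .onion is special-cased per RFC 7686
//	avoidDNS("foo.onion.")  // true: the trailing root dot is stripped first
//	avoidDNS("example.com") // false: resolved via DNS as usual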
// nameList returns a list of names for sequential DNS queries.
func (conf *dnsConfig) nameList(name string) []string {
if avoidDNS(name) {
return nil
}
// Check name length (see isDomainName).
l := len(name)
rooted := l > 0 && name[l-1] == '.'
if l > 254 || l == 254 && !rooted {
return nil
}
// If name is rooted (trailing dot), try only that name.
if rooted {
return []string{name}
}
hasNdots := count(name, '.') >= conf.ndots
name += "."
l++
// Build list of search choices.
names := make([]string, 0, 1+len(conf.search))
// If name has enough dots, try unsuffixed first.
if hasNdots {
names = append(names, name)
}
// Try suffixes that are not too long (see isDomainName).
for _, suffix := range conf.search {
if l+len(suffix) <= 254 {
names = append(names, name+suffix)
}
}
// Try unsuffixed, if not tried first above.
if !hasNdots {
names = append(names, name)
}
return names
}
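// As a worked example, assuming a configuration with ndots:1 and a single
// search suffix "example.com." (values chosen for illustration):
//
//	nameList("host")      // ["host.example.com.", "host."] (too few dots: suffixes first)
//	nameList("host.sub")  // ["host.sub.", "host.sub.example.com."] (enough dots: as-is first)
//	nameList("host.sub.") // ["host.sub."] (rooted: tried verbatim only)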
// hostLookupOrder specifies the order of LookupHost lookup strategies.
// It is basically a simplified representation of nsswitch.conf.
// "files" means /etc/hosts.
type hostLookupOrder int
const (
// hostLookupCgo means defer to cgo.
hostLookupCgo hostLookupOrder = iota
hostLookupFilesDNS // files first
hostLookupDNSFiles // dns first
hostLookupFiles // only files
hostLookupDNS // only DNS
)
var lookupOrderName = map[hostLookupOrder]string{
hostLookupCgo: "cgo",
hostLookupFilesDNS: "files,dns",
hostLookupDNSFiles: "dns,files",
hostLookupFiles: "files",
hostLookupDNS: "dns",
}
func (o hostLookupOrder) String() string {
if s, ok := lookupOrderName[o]; ok {
return s
}
return "hostLookupOrder=" + itoa.Itoa(int(o)) + "??"
}
func (r *Resolver) goLookupHostOrder(ctx context.Context, name string, order hostLookupOrder, conf *dnsConfig) (addrs []string, err error) {
if order == hostLookupFilesDNS || order == hostLookupFiles {
// Use entries from /etc/hosts if they match.
addrs, _ = lookupStaticHost(name)
if len(addrs) > 0 {
return
}
if order == hostLookupFiles {
return nil, &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true}
}
}
ips, _, err := r.goLookupIPCNAMEOrder(ctx, "ip", name, order, conf)
if err != nil {
return
}
addrs = make([]string, 0, len(ips))
for _, ip := range ips {
addrs = append(addrs, ip.String())
}
return
}
// goLookupIPFiles looks up entries from /etc/hosts.
func goLookupIPFiles(name string) (addrs []IPAddr, canonical string) {
addr, canonical := lookupStaticHost(name)
for _, haddr := range addr {
haddr, zone := splitHostZone(haddr)
if ip := ParseIP(haddr); ip != nil {
addr := IPAddr{IP: ip, Zone: zone}
addrs = append(addrs, addr)
}
}
sortByRFC6724(addrs)
return addrs, canonical
}
// goLookupIP is the native Go implementation of LookupIP.
// The libc versions are in cgo_*.go.
func (r *Resolver) goLookupIP(ctx context.Context, network, host string) (addrs []IPAddr, err error) {
order, conf := systemConf().hostLookupOrder(r, host)
addrs, _, err = r.goLookupIPCNAMEOrder(ctx, network, host, order, conf)
return
}
func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, network, name string, order hostLookupOrder, conf *dnsConfig) (addrs []IPAddr, cname dnsmessage.Name, err error) {
if order == hostLookupFilesDNS || order == hostLookupFiles {
var canonical string
addrs, canonical = goLookupIPFiles(name)
if len(addrs) > 0 {
var err error
cname, err = dnsmessage.NewName(canonical)
if err != nil {
return nil, dnsmessage.Name{}, err
}
return addrs, cname, nil
}
if order == hostLookupFiles {
return nil, dnsmessage.Name{}, &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true}
}
}
if !isDomainName(name) {
// See comment in func lookup above about use of errNoSuchHost.
return nil, dnsmessage.Name{}, &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true}
}
type result struct {
p dnsmessage.Parser
server string
error
}
if conf == nil {
conf = getSystemDNSConfig()
}
lane := make(chan result, 1)
qtypes := []dnsmessage.Type{dnsmessage.TypeA, dnsmessage.TypeAAAA}
if network == "CNAME" {
qtypes = append(qtypes, dnsmessage.TypeCNAME)
}
switch ipVersion(network) {
case '4':
qtypes = []dnsmessage.Type{dnsmessage.TypeA}
case '6':
qtypes = []dnsmessage.Type{dnsmessage.TypeAAAA}
}
var queryFn func(fqdn string, qtype dnsmessage.Type)
var responseFn func(fqdn string, qtype dnsmessage.Type) result
if conf.singleRequest {
queryFn = func(fqdn string, qtype dnsmessage.Type) {}
responseFn = func(fqdn string, qtype dnsmessage.Type) result {
dnsWaitGroup.Add(1)
defer dnsWaitGroup.Done()
p, server, err := r.tryOneName(ctx, conf, fqdn, qtype)
return result{p, server, err}
}
} else {
queryFn = func(fqdn string, qtype dnsmessage.Type) {
dnsWaitGroup.Add(1)
go func(qtype dnsmessage.Type) {
p, server, err := r.tryOneName(ctx, conf, fqdn, qtype)
lane <- result{p, server, err}
dnsWaitGroup.Done()
}(qtype)
}
responseFn = func(fqdn string, qtype dnsmessage.Type) result {
return <-lane
}
}
var lastErr error
for _, fqdn := range conf.nameList(name) {
for _, qtype := range qtypes {
queryFn(fqdn, qtype)
}
hitStrictError := false
for _, qtype := range qtypes {
result := responseFn(fqdn, qtype)
if result.error != nil {
if nerr, ok := result.error.(Error); ok && nerr.Temporary() && r.strictErrors() {
// This error will abort the nameList loop.
hitStrictError = true
lastErr = result.error
} else if lastErr == nil || fqdn == name+"." {
// Prefer error for original name.
lastErr = result.error
}
continue
}
// Presotto says it's okay to assume that servers listed in
// /etc/resolv.conf are recursive resolvers.
//
// We asked for recursion, so it should have included all the
// answers we need in this one packet.
//
// Further, RFC 1034 section 4.3.1 says that "the recursive
// response to a query will be... The answer to the query,
// possibly preface by one or more CNAME RRs that specify
// aliases encountered on the way to an answer."
//
// Therefore, we should be able to assume that we can ignore
// CNAMEs and that the A and AAAA records we requested are
// for the canonical name.
loop:
for {
h, err := result.p.AnswerHeader()
if err != nil && err != dnsmessage.ErrSectionDone {
lastErr = &DNSError{
Err: "cannot marshal DNS message",
Name: name,
Server: result.server,
}
}
if err != nil {
break
}
switch h.Type {
case dnsmessage.TypeA:
a, err := result.p.AResource()
if err != nil {
lastErr = &DNSError{
Err: "cannot marshal DNS message",
Name: name,
Server: result.server,
}
break loop
}
addrs = append(addrs, IPAddr{IP: IP(a.A[:])})
if cname.Length == 0 && h.Name.Length != 0 {
cname = h.Name
}
case dnsmessage.TypeAAAA:
aaaa, err := result.p.AAAAResource()
if err != nil {
lastErr = &DNSError{
Err: "cannot marshal DNS message",
Name: name,
Server: result.server,
}
break loop
}
addrs = append(addrs, IPAddr{IP: IP(aaaa.AAAA[:])})
if cname.Length == 0 && h.Name.Length != 0 {
cname = h.Name
}
case dnsmessage.TypeCNAME:
c, err := result.p.CNAMEResource()
if err != nil {
lastErr = &DNSError{
Err: "cannot marshal DNS message",
Name: name,
Server: result.server,
}
break loop
}
if cname.Length == 0 && c.CNAME.Length > 0 {
cname = c.CNAME
}
default:
if err := result.p.SkipAnswer(); err != nil {
lastErr = &DNSError{
Err: "cannot marshal DNS message",
Name: name,
Server: result.server,
}
break loop
}
continue
}
}
}
if hitStrictError {
// If either family hit an error with StrictErrors enabled,
// discard all addresses. This ensures that network flakiness
// cannot turn a dualstack hostname into an IPv4- or IPv6-only one.
addrs = nil
break
}
if len(addrs) > 0 || network == "CNAME" && cname.Length > 0 {
break
}
}
if lastErr, ok := lastErr.(*DNSError); ok {
// Show original name passed to lookup, not suffixed one.
// In general we might have tried many suffixes; showing
// just one is misleading. See also golang.org/issue/6324.
lastErr.Name = name
}
sortByRFC6724(addrs)
if len(addrs) == 0 && !(network == "CNAME" && cname.Length > 0) {
if order == hostLookupDNSFiles {
var canonical string
addrs, canonical = goLookupIPFiles(name)
if len(addrs) > 0 {
var err error
cname, err = dnsmessage.NewName(canonical)
if err != nil {
return nil, dnsmessage.Name{}, err
}
return addrs, cname, nil
}
}
if lastErr != nil {
return nil, dnsmessage.Name{}, lastErr
}
}
return addrs, cname, nil
}
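// Callers normally reach this code through the public Resolver API rather
// than directly. A minimal usage sketch (standard library API; the host
// name is a placeholder):
//
//	r := &net.Resolver{PreferGo: true} // force the pure-Go resolver
//	addrs, err := r.LookupIPAddr(context.Background(), "example.com")
//	if err != nil {
//		log.Fatal(err)
//	}
//	for _, a := range addrs {
//		fmt.Println(a.String()) // A/AAAA results, sorted per RFC 6724
//	}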
// goLookupCNAME is the native Go (non-cgo) implementation of LookupCNAME.
func (r *Resolver) goLookupCNAME(ctx context.Context, host string, order hostLookupOrder, conf *dnsConfig) (string, error) {
_, cname, err := r.goLookupIPCNAMEOrder(ctx, "CNAME", host, order, conf)
return cname.String(), err
}
// goLookupPTR is the native Go implementation of LookupAddr.
// Used only if cgoLookupPTR refuses to handle the request (that is,
// only if cgoLookupPTR is the stub in cgo_stub.go).
// Normally we let cgo use the C library resolver instead of depending
// on our lookup code, so that Go and C get the same answers.
func (r *Resolver) goLookupPTR(ctx context.Context, addr string, conf *dnsConfig) ([]string, error) {
names := lookupStaticAddr(addr)
if len(names) > 0 {
return names, nil
}
arpa, err := reverseaddr(addr)
if err != nil {
return nil, err
}
p, server, err := r.lookup(ctx, arpa, dnsmessage.TypePTR, conf)
if err != nil {
return nil, err
}
var ptrs []string
for {
h, err := p.AnswerHeader()
if err == dnsmessage.ErrSectionDone {
break
}
if err != nil {
return nil, &DNSError{
Err: "cannot marshal DNS message",
Name: addr,
Server: server,
}
}
if h.Type != dnsmessage.TypePTR {
err := p.SkipAnswer()
if err != nil {
return nil, &DNSError{
Err: "cannot marshal DNS message",
Name: addr,
Server: server,
}
}
continue
}
ptr, err := p.PTRResource()
if err != nil {
return nil, &DNSError{
Err: "cannot marshal DNS message",
Name: addr,
Server: server,
}
}
ptrs = append(ptrs, ptr.PTR.String())
}
return ptrs, nil
}
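// The public entry point for this path is Resolver.LookupAddr. A short
// sketch (standard library API; the address and the printed result are
// placeholders):
//
//	r := &net.Resolver{PreferGo: true}
//	names, err := r.LookupAddr(context.Background(), "8.8.8.8")
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println(names) // PTR names, e.g. ["dns.google."]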
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
import (
"os"
"sync/atomic"
"time"
)
var (
defaultNS = []string{"127.0.0.1:53", "[::1]:53"}
getHostname = os.Hostname // variable for testing
)
type dnsConfig struct {
servers []string // server addresses (in host:port form) to use
search []string // rooted suffixes to append to local name
ndots int // number of dots in name to trigger absolute lookup
timeout time.Duration // wait before giving up on a query, including retries
attempts int // lost packets before giving up on server
rotate bool // round robin among servers
unknownOpt bool // anything unknown was encountered
lookup []string // OpenBSD top-level database "lookup" order
err error // any error that occurs during open of resolv.conf
mtime time.Time // time of resolv.conf modification
soffset uint32 // used by serverOffset
singleRequest bool // use sequential A and AAAA queries instead of parallel queries
useTCP bool // force usage of TCP for DNS resolutions
trustAD bool // add AD flag to queries
noReload bool // do not check for config file updates
}
// serverOffset returns an offset that can be used to determine
// indices of servers in c.servers when making queries.
// When the rotate option is enabled, this offset increases.
// Otherwise it is always 0.
func (c *dnsConfig) serverOffset() uint32 {
if c.rotate {
return atomic.AddUint32(&c.soffset, 1) - 1 // return 0 to start
}
return 0
}
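// For example, with rotate enabled and three servers, successive lookups
// see offsets 0, 1, 2, 3, ... and tryOneName indexes servers as
// (offset+j) % len(servers), spreading queries round-robin. A sketch
// within this package (server addresses are illustrative):
//
//	cfg := &dnsConfig{rotate: true, servers: []string{"a:53", "b:53", "c:53"}}
//	cfg.serverOffset() // 0: the next query starts at "a:53"
//	cfg.serverOffset() // 1: the next query starts at "b:53"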
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !js && !windows
// Read system DNS config from /etc/resolv.conf
package net
import (
"internal/bytealg"
"net/netip"
"time"
)
// See resolv.conf(5) on a Linux machine.
func dnsReadConfig(filename string) *dnsConfig {
conf := &dnsConfig{
ndots: 1,
timeout: 5 * time.Second,
attempts: 2,
}
file, err := open(filename)
if err != nil {
conf.servers = defaultNS
conf.search = dnsDefaultSearch()
conf.err = err
return conf
}
defer file.close()
if fi, err := file.file.Stat(); err == nil {
conf.mtime = fi.ModTime()
} else {
conf.servers = defaultNS
conf.search = dnsDefaultSearch()
conf.err = err
return conf
}
for line, ok := file.readLine(); ok; line, ok = file.readLine() {
if len(line) > 0 && (line[0] == ';' || line[0] == '#') {
// comment.
continue
}
f := getFields(line)
if len(f) < 1 {
continue
}
switch f[0] {
case "nameserver": // add one name server
if len(f) > 1 && len(conf.servers) < 3 { // small, but the standard limit
// One more check: make sure server name is
// just an IP address. Otherwise we need DNS
// to look it up.
if _, err := netip.ParseAddr(f[1]); err == nil {
conf.servers = append(conf.servers, JoinHostPort(f[1], "53"))
}
}
case "domain": // set search path to just this domain
if len(f) > 1 {
conf.search = []string{ensureRooted(f[1])}
}
case "search": // set search path to given servers
conf.search = make([]string, 0, len(f)-1)
for i := 1; i < len(f); i++ {
name := ensureRooted(f[i])
if name == "." {
continue
}
conf.search = append(conf.search, name)
}
case "options": // magic options
for _, s := range f[1:] {
switch {
case hasPrefix(s, "ndots:"):
n, _, _ := dtoi(s[6:])
if n < 0 {
n = 0
} else if n > 15 {
n = 15
}
conf.ndots = n
case hasPrefix(s, "timeout:"):
n, _, _ := dtoi(s[8:])
if n < 1 {
n = 1
}
conf.timeout = time.Duration(n) * time.Second
case hasPrefix(s, "attempts:"):
n, _, _ := dtoi(s[9:])
if n < 1 {
n = 1
}
conf.attempts = n
case s == "rotate":
conf.rotate = true
case s == "single-request" || s == "single-request-reopen":
// Linux option:
// http://man7.org/linux/man-pages/man5/resolv.conf.5.html
// "By default, glibc performs IPv4 and IPv6 lookups in parallel [...]
// This option disables the behavior and makes glibc
// perform the IPv6 and IPv4 requests sequentially."
conf.singleRequest = true
case s == "use-vc" || s == "usevc" || s == "tcp":
// Linux (use-vc), FreeBSD (usevc) and OpenBSD (tcp) option:
// http://man7.org/linux/man-pages/man5/resolv.conf.5.html
// "Sets RES_USEVC in _res.options.
// This option forces the use of TCP for DNS resolutions."
// https://www.freebsd.org/cgi/man.cgi?query=resolv.conf&sektion=5&manpath=freebsd-release-ports
// https://man.openbsd.org/resolv.conf.5
conf.useTCP = true
case s == "trust-ad":
conf.trustAD = true
case s == "edns0":
// We use EDNS by default.
// Ignore this option.
case s == "no-reload":
conf.noReload = true
default:
conf.unknownOpt = true
}
}
case "lookup":
// OpenBSD option:
// https://www.openbsd.org/cgi-bin/man.cgi/OpenBSD-current/man5/resolv.conf.5
// "the legal space-separated values are: bind, file, yp"
conf.lookup = f[1:]
default:
conf.unknownOpt = true
}
}
if len(conf.servers) == 0 {
conf.servers = defaultNS
}
if len(conf.search) == 0 {
conf.search = dnsDefaultSearch()
}
return conf
}
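// As a concrete example of the parsing above, a resolv.conf containing
// (illustrative values):
//
//	nameserver 10.0.0.1
//	nameserver 10.0.0.2
//	search corp.example.com example.com
//	options ndots:2 timeout:3 attempts:4 rotate
//
// yields servers ["10.0.0.1:53", "10.0.0.2:53"], search
// ["corp.example.com.", "example.com."], ndots 2, a 3-second timeout,
// 4 attempts, and rotate enabled.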
func dnsDefaultSearch() []string {
hn, err := getHostname()
if err != nil {
// best effort
return nil
}
if i := bytealg.IndexByteString(hn, '.'); i >= 0 && i < len(hn)-1 {
return []string{ensureRooted(hn[i+1:])}
}
return nil
}
func hasPrefix(s, prefix string) bool {
return len(s) >= len(prefix) && s[:len(prefix)] == prefix
}
func ensureRooted(s string) string {
if len(s) > 0 && s[len(s)-1] == '.' {
return s
}
return s + "."
}
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || (js && wasm) || windows
package net
import (
"os"
"syscall"
)
// wrapSyscallError takes an error and a syscall name. If the error is
// a syscall.Errno, it wraps it in an os.SyscallError using the syscall name.
func wrapSyscallError(name string, err error) error {
if _, ok := err.(syscall.Errno); ok {
err = os.NewSyscallError(name, err)
}
return err
}
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || js
package net
import "syscall"
func isConnError(err error) bool {
if se, ok := err.(syscall.Errno); ok {
return se == syscall.ECONNRESET || se == syscall.ECONNABORTED
}
return false
}
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || windows
package net
import (
"internal/poll"
"runtime"
"syscall"
"time"
)
// Network file descriptor.
type netFD struct {
pfd poll.FD
// immutable until Close
family int
sotype int
isConnected bool // handshake completed or use of association with peer
net string
laddr Addr
raddr Addr
}
func (fd *netFD) setAddr(laddr, raddr Addr) {
fd.laddr = laddr
fd.raddr = raddr
runtime.SetFinalizer(fd, (*netFD).Close)
}
func (fd *netFD) Close() error {
runtime.SetFinalizer(fd, nil)
return fd.pfd.Close()
}
func (fd *netFD) shutdown(how int) error {
err := fd.pfd.Shutdown(how)
runtime.KeepAlive(fd)
return wrapSyscallError("shutdown", err)
}
func (fd *netFD) closeRead() error {
return fd.shutdown(syscall.SHUT_RD)
}
func (fd *netFD) closeWrite() error {
return fd.shutdown(syscall.SHUT_WR)
}
func (fd *netFD) Read(p []byte) (n int, err error) {
n, err = fd.pfd.Read(p)
runtime.KeepAlive(fd)
return n, wrapSyscallError(readSyscallName, err)
}
func (fd *netFD) readFrom(p []byte) (n int, sa syscall.Sockaddr, err error) {
n, sa, err = fd.pfd.ReadFrom(p)
runtime.KeepAlive(fd)
return n, sa, wrapSyscallError(readFromSyscallName, err)
}
func (fd *netFD) readFromInet4(p []byte, from *syscall.SockaddrInet4) (n int, err error) {
n, err = fd.pfd.ReadFromInet4(p, from)
runtime.KeepAlive(fd)
return n, wrapSyscallError(readFromSyscallName, err)
}
func (fd *netFD) readFromInet6(p []byte, from *syscall.SockaddrInet6) (n int, err error) {
n, err = fd.pfd.ReadFromInet6(p, from)
runtime.KeepAlive(fd)
return n, wrapSyscallError(readFromSyscallName, err)
}
func (fd *netFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int, sa syscall.Sockaddr, err error) {
n, oobn, retflags, sa, err = fd.pfd.ReadMsg(p, oob, flags)
runtime.KeepAlive(fd)
return n, oobn, retflags, sa, wrapSyscallError(readMsgSyscallName, err)
}
func (fd *netFD) readMsgInet4(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet4) (n, oobn, retflags int, err error) {
n, oobn, retflags, err = fd.pfd.ReadMsgInet4(p, oob, flags, sa)
runtime.KeepAlive(fd)
return n, oobn, retflags, wrapSyscallError(readMsgSyscallName, err)
}
func (fd *netFD) readMsgInet6(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet6) (n, oobn, retflags int, err error) {
n, oobn, retflags, err = fd.pfd.ReadMsgInet6(p, oob, flags, sa)
runtime.KeepAlive(fd)
return n, oobn, retflags, wrapSyscallError(readMsgSyscallName, err)
}
func (fd *netFD) Write(p []byte) (nn int, err error) {
nn, err = fd.pfd.Write(p)
runtime.KeepAlive(fd)
return nn, wrapSyscallError(writeSyscallName, err)
}
func (fd *netFD) writeTo(p []byte, sa syscall.Sockaddr) (n int, err error) {
n, err = fd.pfd.WriteTo(p, sa)
runtime.KeepAlive(fd)
return n, wrapSyscallError(writeToSyscallName, err)
}
func (fd *netFD) writeToInet4(p []byte, sa *syscall.SockaddrInet4) (n int, err error) {
n, err = fd.pfd.WriteToInet4(p, sa)
runtime.KeepAlive(fd)
return n, wrapSyscallError(writeToSyscallName, err)
}
func (fd *netFD) writeToInet6(p []byte, sa *syscall.SockaddrInet6) (n int, err error) {
n, err = fd.pfd.WriteToInet6(p, sa)
runtime.KeepAlive(fd)
return n, wrapSyscallError(writeToSyscallName, err)
}
func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) {
n, oobn, err = fd.pfd.WriteMsg(p, oob, sa)
runtime.KeepAlive(fd)
return n, oobn, wrapSyscallError(writeMsgSyscallName, err)
}
func (fd *netFD) writeMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4) (n int, oobn int, err error) {
n, oobn, err = fd.pfd.WriteMsgInet4(p, oob, sa)
runtime.KeepAlive(fd)
return n, oobn, wrapSyscallError(writeMsgSyscallName, err)
}
func (fd *netFD) writeMsgInet6(p []byte, oob []byte, sa *syscall.SockaddrInet6) (n int, oobn int, err error) {
n, oobn, err = fd.pfd.WriteMsgInet6(p, oob, sa)
runtime.KeepAlive(fd)
return n, oobn, wrapSyscallError(writeMsgSyscallName, err)
}
func (fd *netFD) SetDeadline(t time.Time) error {
return fd.pfd.SetDeadline(t)
}
func (fd *netFD) SetReadDeadline(t time.Time) error {
return fd.pfd.SetReadDeadline(t)
}
func (fd *netFD) SetWriteDeadline(t time.Time) error {
return fd.pfd.SetWriteDeadline(t)
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix
package net
import (
"context"
"internal/poll"
"os"
"runtime"
"syscall"
)
const (
readSyscallName = "read"
readFromSyscallName = "recvfrom"
readMsgSyscallName = "recvmsg"
writeSyscallName = "write"
writeToSyscallName = "sendto"
writeMsgSyscallName = "sendmsg"
)
func newFD(sysfd, family, sotype int, net string) (*netFD, error) {
ret := &netFD{
pfd: poll.FD{
Sysfd: sysfd,
IsStream: sotype == syscall.SOCK_STREAM,
ZeroReadIsEOF: sotype != syscall.SOCK_DGRAM && sotype != syscall.SOCK_RAW,
},
family: family,
sotype: sotype,
net: net,
}
return ret, nil
}
func (fd *netFD) init() error {
return fd.pfd.Init(fd.net, true)
}
func (fd *netFD) name() string {
var ls, rs string
if fd.laddr != nil {
ls = fd.laddr.String()
}
if fd.raddr != nil {
rs = fd.raddr.String()
}
return fd.net + ":" + ls + "->" + rs
}
func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (rsa syscall.Sockaddr, ret error) {
// Do not need to call fd.writeLock here,
// because fd is not yet accessible to user,
// so no concurrent operations are possible.
switch err := connectFunc(fd.pfd.Sysfd, ra); err {
case syscall.EINPROGRESS, syscall.EALREADY, syscall.EINTR:
case nil, syscall.EISCONN:
select {
case <-ctx.Done():
return nil, mapErr(ctx.Err())
default:
}
if err := fd.pfd.Init(fd.net, true); err != nil {
return nil, err
}
runtime.KeepAlive(fd)
return nil, nil
case syscall.EINVAL:
// On Solaris and illumos we can see EINVAL if the socket has
// already been accepted and closed by the server. Treat this
// as a successful connection--writes to the socket will see
// EOF. For details and a test case in C see
// https://golang.org/issue/6828.
if runtime.GOOS == "solaris" || runtime.GOOS == "illumos" {
return nil, nil
}
fallthrough
default:
return nil, os.NewSyscallError("connect", err)
}
if err := fd.pfd.Init(fd.net, true); err != nil {
return nil, err
}
if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
fd.pfd.SetWriteDeadline(deadline)
defer fd.pfd.SetWriteDeadline(noDeadline)
}
// Start the "interrupter" goroutine, if this context might be canceled.
//
// The interrupter goroutine waits for the context to be done and
// interrupts the dial (by altering the fd's write deadline, which
// wakes up waitWrite).
ctxDone := ctx.Done()
if ctxDone != nil {
// Wait for the interrupter goroutine to exit before returning
// from connect.
done := make(chan struct{})
interruptRes := make(chan error)
defer func() {
close(done)
if ctxErr := <-interruptRes; ctxErr != nil && ret == nil {
// The interrupter goroutine called SetWriteDeadline,
// but the connect code below had returned from
// waitWrite already and did a successful connect (ret
// == nil). Because we've now poisoned the connection
// by making it unwritable, don't return a successful
// dial. This was issue 16523.
ret = mapErr(ctxErr)
fd.Close() // prevent a leak
}
}()
go func() {
select {
case <-ctxDone:
// Force the runtime's poller to immediately give up
// waiting for writability, unblocking waitWrite
// below.
fd.pfd.SetWriteDeadline(aLongTimeAgo)
testHookCanceledDial()
interruptRes <- ctx.Err()
case <-done:
interruptRes <- nil
}
}()
}
for {
// Performing multiple connect system calls on a
// non-blocking socket under Unix variants does not
// necessarily result in earlier errors being
// returned. Instead, once the runtime-integrated network
// poller tells us that the socket is ready, get the
// SO_ERROR socket option to see if the connection
// succeeded or failed. See issue 7474 for further
// details.
if err := fd.pfd.WaitWrite(); err != nil {
select {
case <-ctxDone:
return nil, mapErr(ctx.Err())
default:
}
return nil, err
}
nerr, err := getsockoptIntFunc(fd.pfd.Sysfd, syscall.SOL_SOCKET, syscall.SO_ERROR)
if err != nil {
return nil, os.NewSyscallError("getsockopt", err)
}
switch err := syscall.Errno(nerr); err {
case syscall.EINPROGRESS, syscall.EALREADY, syscall.EINTR:
case syscall.EISCONN:
return nil, nil
case syscall.Errno(0):
// The runtime poller can wake us up spuriously;
// see issues 14548 and 19289. Check that we are
// really connected; if not, wait again.
if rsa, err := syscall.Getpeername(fd.pfd.Sysfd); err == nil {
return rsa, nil
}
default:
return nil, os.NewSyscallError("connect", err)
}
runtime.KeepAlive(fd)
}
}
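// The cancellation technique above, poisoning the write deadline to wake a
// blocked poller, works with any net.Conn. A minimal sketch of the same
// idea at the public API level (assumes ctx is a context.Context and conn
// an established net.Conn; data is a placeholder):
//
//	done := make(chan struct{})
//	go func() {
//		select {
//		case <-ctx.Done():
//			// Unblock a pending Write immediately; it returns a
//			// timeout error instead of hanging.
//			conn.SetWriteDeadline(time.Unix(1, 0)) // long ago
//		case <-done:
//		}
//	}()
//	_, err := conn.Write(data)
//	close(done)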
func (fd *netFD) accept() (netfd *netFD, err error) {
d, rsa, errcall, err := fd.pfd.Accept()
if err != nil {
if errcall != "" {
err = wrapSyscallError(errcall, err)
}
return nil, err
}
if netfd, err = newFD(d, fd.family, fd.sotype, fd.net); err != nil {
poll.CloseFunc(d)
return nil, err
}
if err = netfd.init(); err != nil {
netfd.Close()
return nil, err
}
lsa, _ := syscall.Getsockname(netfd.pfd.Sysfd)
netfd.setAddr(netfd.addrFunc()(lsa), netfd.addrFunc()(rsa))
return netfd, nil
}
func (fd *netFD) dup() (f *os.File, err error) {
ns, call, err := fd.pfd.Dup()
if err != nil {
if call != "" {
err = os.NewSyscallError(call, err)
}
return nil, err
}
return os.NewFile(uintptr(ns), fd.name()), nil
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
import "os"
// BUG(mikio): On JS and Windows, the FileConn, FileListener and
// FilePacketConn functions are not implemented.
type fileAddr string
func (fileAddr) Network() string { return "file+net" }
func (f fileAddr) String() string { return string(f) }
// FileConn returns a copy of the network connection corresponding to
// the open file f.
// It is the caller's responsibility to close f when finished.
// Closing c does not affect f, and closing f does not affect c.
func FileConn(f *os.File) (c Conn, err error) {
c, err = fileConn(f)
if err != nil {
err = &OpError{Op: "file", Net: "file+net", Source: nil, Addr: fileAddr(f.Name()), Err: err}
}
return
}
// FileListener returns a copy of the network listener corresponding
// to the open file f.
// It is the caller's responsibility to close ln when finished.
// Closing ln does not affect f, and closing f does not affect ln.
func FileListener(f *os.File) (ln Listener, err error) {
ln, err = fileListener(f)
if err != nil {
err = &OpError{Op: "file", Net: "file+net", Source: nil, Addr: fileAddr(f.Name()), Err: err}
}
return
}
// FilePacketConn returns a copy of the packet network connection
// corresponding to the open file f.
// It is the caller's responsibility to close f when finished.
// Closing c does not affect f, and closing f does not affect c.
func FilePacketConn(f *os.File) (c PacketConn, err error) {
c, err = filePacketConn(f)
if err != nil {
err = &OpError{Op: "file", Net: "file+net", Source: nil, Addr: fileAddr(f.Name()), Err: err}
}
return
}
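// A typical use of these functions is adopting an inherited socket from a
// parent process as a net.Conn. A sketch from a client program, assuming
// file descriptor 3 was passed in as an already-connected socket (the fd
// number is an assumption):
//
//	f := os.NewFile(3, "inherited-socket")
//	defer f.Close() // closing f does not affect c
//	c, err := net.FileConn(f)
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer c.Close()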
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix
package net
import (
"internal/poll"
"os"
"syscall"
)
func dupSocket(f *os.File) (int, error) {
s, call, err := poll.DupCloseOnExec(int(f.Fd()))
if err != nil {
if call != "" {
err = os.NewSyscallError(call, err)
}
return -1, err
}
if err := syscall.SetNonblock(s, true); err != nil {
poll.CloseFunc(s)
return -1, os.NewSyscallError("setnonblock", err)
}
return s, nil
}
func newFileFD(f *os.File) (*netFD, error) {
s, err := dupSocket(f)
if err != nil {
return nil, err
}
family := syscall.AF_UNSPEC
sotype, err := syscall.GetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_TYPE)
if err != nil {
poll.CloseFunc(s)
return nil, os.NewSyscallError("getsockopt", err)
}
lsa, _ := syscall.Getsockname(s)
rsa, _ := syscall.Getpeername(s)
switch lsa.(type) {
case *syscall.SockaddrInet4:
family = syscall.AF_INET
case *syscall.SockaddrInet6:
family = syscall.AF_INET6
case *syscall.SockaddrUnix:
family = syscall.AF_UNIX
default:
poll.CloseFunc(s)
return nil, syscall.EPROTONOSUPPORT
}
fd, err := newFD(s, family, sotype, "")
if err != nil {
poll.CloseFunc(s)
return nil, err
}
laddr := fd.addrFunc()(lsa)
raddr := fd.addrFunc()(rsa)
fd.net = laddr.Network()
if err := fd.init(); err != nil {
fd.Close()
return nil, err
}
fd.setAddr(laddr, raddr)
return fd, nil
}
func fileConn(f *os.File) (Conn, error) {
fd, err := newFileFD(f)
if err != nil {
return nil, err
}
switch fd.laddr.(type) {
case *TCPAddr:
return newTCPConn(fd, defaultTCPKeepAlive, testHookSetKeepAlive), nil
case *UDPAddr:
return newUDPConn(fd), nil
case *IPAddr:
return newIPConn(fd), nil
case *UnixAddr:
return newUnixConn(fd), nil
}
fd.Close()
return nil, syscall.EINVAL
}
func fileListener(f *os.File) (Listener, error) {
fd, err := newFileFD(f)
if err != nil {
return nil, err
}
switch laddr := fd.laddr.(type) {
case *TCPAddr:
return &TCPListener{fd: fd}, nil
case *UnixAddr:
return &UnixListener{fd: fd, path: laddr.Name, unlink: false}, nil
}
fd.Close()
return nil, syscall.EINVAL
}
func filePacketConn(f *os.File) (PacketConn, error) {
fd, err := newFileFD(f)
if err != nil {
return nil, err
}
switch fd.laddr.(type) {
case *UDPAddr:
return newUDPConn(fd), nil
case *IPAddr:
return newIPConn(fd), nil
case *UnixAddr:
return newUnixConn(fd), nil
}
fd.Close()
return nil, syscall.EINVAL
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
import (
"context"
"time"
)
var (
// if non-nil, overrides dialTCP.
testHookDialTCP func(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error)
testHookHostsPath = "/etc/hosts"
testHookLookupIP = func(
ctx context.Context,
fn func(context.Context, string, string) ([]IPAddr, error),
network string,
host string,
) ([]IPAddr, error) {
return fn(ctx, network, host)
}
testHookSetKeepAlive = func(time.Duration) {}
)
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || (js && wasm)
package net
import "syscall"
var (
testHookDialChannel = func() {} // for golang.org/issue/5349
testHookCanceledDial = func() {} // for golang.org/issue/16523
// Placeholders for socket system calls.
socketFunc func(int, int, int) (int, error) = syscall.Socket
connectFunc func(int, syscall.Sockaddr) error = syscall.Connect
listenFunc func(int, int) error = syscall.Listen
getsockoptIntFunc func(int, int, int) (int, error) = syscall.GetsockoptInt
)
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
import (
"internal/bytealg"
"net/netip"
"sync"
"time"
)
const cacheMaxAge = 5 * time.Second
func parseLiteralIP(addr string) string {
ip, err := netip.ParseAddr(addr)
if err != nil {
return ""
}
return ip.String()
}
type byName struct {
addrs []string
canonicalName string
}
// hosts contains known host entries.
var hosts struct {
sync.Mutex
// The key for the list of literal IP addresses must be a host
// name: one or more DNS labels, an FQDN, or an absolute FQDN.
// For now the key is converted to lower case for convenience.
byName map[string]byName
// The key for the list of host names must be a literal IP address,
// including an IPv6 address with a zone identifier.
// We don't support old classful IP address notation.
byAddr map[string][]string
expire time.Time
path string
mtime time.Time
size int64
}
func readHosts() {
now := time.Now()
hp := testHookHostsPath
if now.Before(hosts.expire) && hosts.path == hp && len(hosts.byName) > 0 {
return
}
mtime, size, err := stat(hp)
if err == nil && hosts.path == hp && hosts.mtime.Equal(mtime) && hosts.size == size {
hosts.expire = now.Add(cacheMaxAge)
return
}
hs := make(map[string]byName)
is := make(map[string][]string)
var file *file
if file, _ = open(hp); file == nil {
return
}
for line, ok := file.readLine(); ok; line, ok = file.readLine() {
if i := bytealg.IndexByteString(line, '#'); i >= 0 {
// Discard comments.
line = line[0:i]
}
f := getFields(line)
if len(f) < 2 {
continue
}
addr := parseLiteralIP(f[0])
if addr == "" {
continue
}
var canonical string
for i := 1; i < len(f); i++ {
name := absDomainName(f[i])
h := []byte(f[i])
lowerASCIIBytes(h)
key := absDomainName(string(h))
if i == 1 {
canonical = key
}
is[addr] = append(is[addr], name)
if v, ok := hs[key]; ok {
hs[key] = byName{
addrs: append(v.addrs, addr),
canonicalName: v.canonicalName,
}
continue
}
hs[key] = byName{
addrs: []string{addr},
canonicalName: canonical,
}
}
}
// Update the data cache.
hosts.expire = now.Add(cacheMaxAge)
hosts.path = hp
hosts.byName = hs
hosts.byAddr = is
hosts.mtime = mtime
hosts.size = size
file.close()
}
// lookupStaticHost looks up the addresses and the canonical name for the given host from /etc/hosts.
func lookupStaticHost(host string) ([]string, string) {
hosts.Lock()
defer hosts.Unlock()
readHosts()
if len(hosts.byName) != 0 {
if hasUpperCase(host) {
lowerHost := []byte(host)
lowerASCIIBytes(lowerHost)
host = string(lowerHost)
}
if byName, ok := hosts.byName[absDomainName(host)]; ok {
ipsCp := make([]string, len(byName.addrs))
copy(ipsCp, byName.addrs)
return ipsCp, byName.canonicalName
}
}
return nil, ""
}
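// For example, given an /etc/hosts line (illustrative):
//
//	127.0.0.1 localhost localhost.localdomain
//
// lookupStaticHost("LOCALHOST") returns ["127.0.0.1"] with the canonical
// name "localhost.", since keys are lower-cased and rooted via absDomainName.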
// lookupStaticAddr looks up the hosts for the given address from /etc/hosts.
func lookupStaticAddr(addr string) []string {
hosts.Lock()
defer hosts.Unlock()
readHosts()
addr = parseLiteralIP(addr)
if addr == "" {
return nil
}
if len(hosts.byAddr) != 0 {
if hosts, ok := hosts.byAddr[addr]; ok {
hostsCp := make([]string, len(hosts))
copy(hostsCp, hosts)
return hostsCp
}
}
return nil
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements CGI from the perspective of a child
// process.
package cgi
import (
"bufio"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"strconv"
"strings"
)
// Request returns the HTTP request as represented in the current
// environment. This assumes the current program is being run
// by a web server in a CGI environment.
// The returned Request's Body is populated, if applicable.
func Request() (*http.Request, error) {
r, err := RequestFromMap(envMap(os.Environ()))
if err != nil {
return nil, err
}
if r.ContentLength > 0 {
r.Body = io.NopCloser(io.LimitReader(os.Stdin, r.ContentLength))
}
return r, nil
}
func envMap(env []string) map[string]string {
m := make(map[string]string)
for _, kv := range env {
if k, v, ok := strings.Cut(kv, "="); ok {
m[k] = v
}
}
return m
}
// RequestFromMap creates an http.Request from CGI variables.
// The returned Request's Body field is not populated.
func RequestFromMap(params map[string]string) (*http.Request, error) {
r := new(http.Request)
r.Method = params["REQUEST_METHOD"]
if r.Method == "" {
return nil, errors.New("cgi: no REQUEST_METHOD in environment")
}
r.Proto = params["SERVER_PROTOCOL"]
var ok bool
r.ProtoMajor, r.ProtoMinor, ok = http.ParseHTTPVersion(r.Proto)
if !ok {
return nil, errors.New("cgi: invalid SERVER_PROTOCOL version")
}
r.Close = true
r.Trailer = http.Header{}
r.Header = http.Header{}
r.Host = params["HTTP_HOST"]
if lenstr := params["CONTENT_LENGTH"]; lenstr != "" {
clen, err := strconv.ParseInt(lenstr, 10, 64)
if err != nil {
return nil, errors.New("cgi: bad CONTENT_LENGTH in environment: " + lenstr)
}
r.ContentLength = clen
}
if ct := params["CONTENT_TYPE"]; ct != "" {
r.Header.Set("Content-Type", ct)
}
// Copy "HTTP_FOO_BAR" variables to "Foo-Bar" Headers
for k, v := range params {
if k == "HTTP_HOST" {
continue
}
if after, found := strings.CutPrefix(k, "HTTP_"); found {
r.Header.Add(strings.ReplaceAll(after, "_", "-"), v)
}
}
uriStr := params["REQUEST_URI"]
if uriStr == "" {
// Fallback to SCRIPT_NAME, PATH_INFO and QUERY_STRING.
uriStr = params["SCRIPT_NAME"] + params["PATH_INFO"]
s := params["QUERY_STRING"]
if s != "" {
uriStr += "?" + s
}
}
// There's apparently a de-facto standard for this.
// https://web.archive.org/web/20170105004655/http://docstore.mik.ua/orelly/linux/cgi/ch03_02.htm#ch03-35636
if s := params["HTTPS"]; s == "on" || s == "ON" || s == "1" {
r.TLS = &tls.ConnectionState{HandshakeComplete: true}
}
if r.Host != "" {
// Hostname is provided, so we can reasonably construct a URL.
rawurl := r.Host + uriStr
if r.TLS == nil {
rawurl = "http://" + rawurl
} else {
rawurl = "https://" + rawurl
}
url, err := url.Parse(rawurl)
if err != nil {
return nil, errors.New("cgi: failed to parse host and REQUEST_URI into a URL: " + rawurl)
}
r.URL = url
}
// Fallback logic if we don't have a Host header or the URL
// failed to parse
if r.URL == nil {
url, err := url.Parse(uriStr)
if err != nil {
return nil, errors.New("cgi: failed to parse REQUEST_URI into a URL: " + uriStr)
}
r.URL = url
}
// Request.RemoteAddr has its port set by Go's standard http
// server, so we do here too.
remotePort, _ := strconv.Atoi(params["REMOTE_PORT"]) // zero if unset or invalid
r.RemoteAddr = net.JoinHostPort(params["REMOTE_ADDR"], strconv.Itoa(remotePort))
return r, nil
}
// Serve executes the provided Handler on the currently active CGI
// request, if any. If there's no current CGI environment,
// an error is returned. The provided handler may be nil to use
// http.DefaultServeMux.
func Serve(handler http.Handler) error {
req, err := Request()
if err != nil {
return err
}
if req.Body == nil {
req.Body = http.NoBody
}
if handler == nil {
handler = http.DefaultServeMux
}
rw := &response{
req: req,
header: make(http.Header),
bufw: bufio.NewWriter(os.Stdout),
}
handler.ServeHTTP(rw, req)
rw.Write(nil) // make sure a response is sent
if err = rw.bufw.Flush(); err != nil {
return err
}
return nil
}
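// A complete CGI child program built on Serve can be as small as the
// following sketch (the handler body is illustrative):
//
//	package main
//
//	import (
//		"fmt"
//		"net/http"
//		"net/http/cgi"
//		"os"
//	)
//
//	func main() {
//		err := cgi.Serve(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
//			w.Header().Set("Content-Type", "text/plain")
//			fmt.Fprintf(w, "Hello from CGI, path=%s\n", r.URL.Path)
//		}))
//		if err != nil {
//			fmt.Fprintln(os.Stderr, err)
//		}
//	}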
type response struct {
req *http.Request
header http.Header
code int
wroteHeader bool
wroteCGIHeader bool
bufw *bufio.Writer
}
func (r *response) Flush() {
r.bufw.Flush()
}
func (r *response) Header() http.Header {
return r.header
}
func (r *response) Write(p []byte) (n int, err error) {
if !r.wroteHeader {
r.WriteHeader(http.StatusOK)
}
if !r.wroteCGIHeader {
r.writeCGIHeader(p)
}
return r.bufw.Write(p)
}
func (r *response) WriteHeader(code int) {
if r.wroteHeader {
// Note: explicitly using Stderr, as Stdout is our HTTP output.
fmt.Fprintf(os.Stderr, "CGI attempted to write header twice on request for %s", r.req.URL)
return
}
r.wroteHeader = true
r.code = code
}
// writeCGIHeader finalizes the header sent to the client and writes it to the output.
// p is not written by writeCGIHeader, but is the first chunk of the body
// that will be written. It is sniffed for a Content-Type if none is
// set explicitly.
func (r *response) writeCGIHeader(p []byte) {
if r.wroteCGIHeader {
return
}
r.wroteCGIHeader = true
fmt.Fprintf(r.bufw, "Status: %d %s\r\n", r.code, http.StatusText(r.code))
if _, hasType := r.header["Content-Type"]; !hasType {
r.header.Set("Content-Type", http.DetectContentType(p))
}
r.header.Write(r.bufw)
r.bufw.WriteString("\r\n")
r.bufw.Flush()
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements the host side of CGI (being the webserver
// parent process).
// Package cgi implements CGI (Common Gateway Interface) as specified
// in RFC 3875.
//
// Note that using CGI means starting a new process to handle each
// request, which is typically less efficient than using a
// long-running server. This package is intended primarily for
// compatibility with existing systems.
package cgi
import (
"bufio"
"fmt"
"io"
"log"
"net"
"net/http"
"net/textproto"
"os"
"os/exec"
"path/filepath"
"regexp"
"runtime"
"strconv"
"strings"
"golang.org/x/net/http/httpguts"
)
var trailingPort = regexp.MustCompile(`:([0-9]+)$`)
var osDefaultInheritEnv = func() []string {
switch runtime.GOOS {
case "darwin", "ios":
return []string{"DYLD_LIBRARY_PATH"}
case "linux", "freebsd", "netbsd", "openbsd":
return []string{"LD_LIBRARY_PATH"}
case "hpux":
return []string{"LD_LIBRARY_PATH", "SHLIB_PATH"}
case "irix":
return []string{"LD_LIBRARY_PATH", "LD_LIBRARYN32_PATH", "LD_LIBRARY64_PATH"}
case "illumos", "solaris":
return []string{"LD_LIBRARY_PATH", "LD_LIBRARY_PATH_32", "LD_LIBRARY_PATH_64"}
case "windows":
return []string{"SystemRoot", "COMSPEC", "PATHEXT", "WINDIR"}
}
return nil
}()
// Handler runs an executable in a subprocess with a CGI environment.
type Handler struct {
Path string // path to the CGI executable
Root string // root URI prefix of handler or empty for "/"
// Dir specifies the CGI executable's working directory.
// If Dir is empty, the base directory of Path is used.
// If Path has no base directory, the current working
// directory is used.
Dir string
Env []string // extra environment variables to set, if any, as "key=value"
InheritEnv []string // environment variables to inherit from host, as "key"
Logger *log.Logger // optional log for errors or nil to use log.Print
Args []string // optional arguments to pass to child process
Stderr io.Writer // optional stderr for the child process; nil means os.Stderr
// PathLocationHandler specifies the root http Handler that
// should handle internal redirects when the CGI process
// returns a Location header value starting with a "/", as
// specified in RFC 3875 § 6.3.2. This will likely be
// http.DefaultServeMux.
//
// If nil, a CGI response with a local URI path is instead sent
// back to the client and not redirected internally.
PathLocationHandler http.Handler
}
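// Typical host-side wiring of a Handler (paths and addresses are
// illustrative placeholders):
//
//	http.Handle("/cgi-bin/hello", &cgi.Handler{
//		Path: "/usr/local/cgi-bin/hello", // executable to run per request
//		Root: "/cgi-bin/hello",           // URI prefix; becomes SCRIPT_NAME
//	})
//	log.Fatal(http.ListenAndServe(":8080", nil))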
func (h *Handler) stderr() io.Writer {
if h.Stderr != nil {
return h.Stderr
}
return os.Stderr
}
// removeLeadingDuplicates removes leading duplicates from the environment,
// so that a later entry overrides an earlier one with the same key.
// An environment variable can thus be overridden like the following:
//
// cgi.Handler{
// ...
// Env: []string{"SCRIPT_FILENAME=foo.php"},
// }
func removeLeadingDuplicates(env []string) (ret []string) {
for i, e := range env {
found := false
if eq := strings.IndexByte(e, '='); eq != -1 {
keq := e[:eq+1] // "key="
for _, e2 := range env[i+1:] {
if strings.HasPrefix(e2, keq) {
found = true
break
}
}
}
if !found {
ret = append(ret, e)
}
}
return
}
func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
root := h.Root
if root == "" {
root = "/"
}
if len(req.TransferEncoding) > 0 && req.TransferEncoding[0] == "chunked" {
rw.WriteHeader(http.StatusBadRequest)
rw.Write([]byte("Chunked request bodies are not supported by CGI."))
return
}
pathInfo := req.URL.Path
if root != "/" && strings.HasPrefix(pathInfo, root) {
pathInfo = pathInfo[len(root):]
}
port := "80"
if matches := trailingPort.FindStringSubmatch(req.Host); len(matches) != 0 {
port = matches[1]
}
env := []string{
"SERVER_SOFTWARE=go",
"SERVER_PROTOCOL=HTTP/1.1",
"HTTP_HOST=" + req.Host,
"GATEWAY_INTERFACE=CGI/1.1",
"REQUEST_METHOD=" + req.Method,
"QUERY_STRING=" + req.URL.RawQuery,
"REQUEST_URI=" + req.URL.RequestURI(),
"PATH_INFO=" + pathInfo,
"SCRIPT_NAME=" + root,
"SCRIPT_FILENAME=" + h.Path,
"SERVER_PORT=" + port,
}
if remoteIP, remotePort, err := net.SplitHostPort(req.RemoteAddr); err == nil {
env = append(env, "REMOTE_ADDR="+remoteIP, "REMOTE_HOST="+remoteIP, "REMOTE_PORT="+remotePort)
} else {
// could not parse ip:port, let's use whole RemoteAddr and leave REMOTE_PORT undefined
env = append(env, "REMOTE_ADDR="+req.RemoteAddr, "REMOTE_HOST="+req.RemoteAddr)
}
if hostDomain, _, err := net.SplitHostPort(req.Host); err == nil {
env = append(env, "SERVER_NAME="+hostDomain)
} else {
env = append(env, "SERVER_NAME="+req.Host)
}
if req.TLS != nil {
env = append(env, "HTTPS=on")
}
for k, v := range req.Header {
k = strings.Map(upperCaseAndUnderscore, k)
if k == "PROXY" {
// See Issue 16405
continue
}
joinStr := ", "
if k == "COOKIE" {
joinStr = "; "
}
env = append(env, "HTTP_"+k+"="+strings.Join(v, joinStr))
}
if req.ContentLength > 0 {
env = append(env, fmt.Sprintf("CONTENT_LENGTH=%d", req.ContentLength))
}
if ctype := req.Header.Get("Content-Type"); ctype != "" {
env = append(env, "CONTENT_TYPE="+ctype)
}
envPath := os.Getenv("PATH")
if envPath == "" {
envPath = "/bin:/usr/bin:/usr/ucb:/usr/bsd:/usr/local/bin"
}
env = append(env, "PATH="+envPath)
for _, e := range h.InheritEnv {
if v := os.Getenv(e); v != "" {
env = append(env, e+"="+v)
}
}
for _, e := range osDefaultInheritEnv {
if v := os.Getenv(e); v != "" {
env = append(env, e+"="+v)
}
}
if h.Env != nil {
env = append(env, h.Env...)
}
env = removeLeadingDuplicates(env)
var cwd, path string
if h.Dir != "" {
path = h.Path
cwd = h.Dir
} else {
cwd, path = filepath.Split(h.Path)
}
if cwd == "" {
cwd = "."
}
internalError := func(err error) {
rw.WriteHeader(http.StatusInternalServerError)
h.printf("CGI error: %v", err)
}
cmd := &exec.Cmd{
Path: path,
Args: append([]string{h.Path}, h.Args...),
Dir: cwd,
Env: env,
Stderr: h.stderr(),
}
if req.ContentLength != 0 {
cmd.Stdin = req.Body
}
stdoutRead, err := cmd.StdoutPipe()
if err != nil {
internalError(err)
return
}
err = cmd.Start()
if err != nil {
internalError(err)
return
}
if hook := testHookStartProcess; hook != nil {
hook(cmd.Process)
}
defer cmd.Wait()
defer stdoutRead.Close()
linebody := bufio.NewReaderSize(stdoutRead, 1024)
headers := make(http.Header)
statusCode := 0
headerLines := 0
sawBlankLine := false
for {
line, isPrefix, err := linebody.ReadLine()
if isPrefix {
rw.WriteHeader(http.StatusInternalServerError)
h.printf("cgi: long header line from subprocess.")
return
}
if err == io.EOF {
break
}
if err != nil {
rw.WriteHeader(http.StatusInternalServerError)
h.printf("cgi: error reading headers: %v", err)
return
}
if len(line) == 0 {
sawBlankLine = true
break
}
headerLines++
header, val, ok := strings.Cut(string(line), ":")
if !ok {
h.printf("cgi: bogus header line: %s", string(line))
continue
}
if !httpguts.ValidHeaderFieldName(header) {
h.printf("cgi: invalid header name: %q", header)
continue
}
val = textproto.TrimString(val)
switch {
case header == "Status":
if len(val) < 3 {
h.printf("cgi: bogus status (short): %q", val)
return
}
code, err := strconv.Atoi(val[0:3])
if err != nil {
h.printf("cgi: bogus status: %q", val)
h.printf("cgi: line was %q", line)
return
}
statusCode = code
default:
headers.Add(header, val)
}
}
if headerLines == 0 || !sawBlankLine {
rw.WriteHeader(http.StatusInternalServerError)
h.printf("cgi: no headers")
return
}
if loc := headers.Get("Location"); loc != "" {
if strings.HasPrefix(loc, "/") && h.PathLocationHandler != nil {
h.handleInternalRedirect(rw, req, loc)
return
}
if statusCode == 0 {
statusCode = http.StatusFound
}
}
if statusCode == 0 && headers.Get("Content-Type") == "" {
rw.WriteHeader(http.StatusInternalServerError)
h.printf("cgi: missing required Content-Type in headers")
return
}
if statusCode == 0 {
statusCode = http.StatusOK
}
// Copy headers to rw's headers, after we've decided not to
// go into handleInternalRedirect, which won't want its rw
// headers to have been touched.
for k, vv := range headers {
for _, v := range vv {
rw.Header().Add(k, v)
}
}
rw.WriteHeader(statusCode)
_, err = io.Copy(rw, linebody)
if err != nil {
h.printf("cgi: copy error: %v", err)
// And kill the child CGI process so we don't hang on
// the deferred cmd.Wait above if the error was just
// the client (rw) going away. If it was a read error
// (because the child died itself), then the extra
// kill of an already-dead process is harmless (the PID
// won't be reused until the Wait above).
cmd.Process.Kill()
}
}
func (h *Handler) printf(format string, v ...any) {
if h.Logger != nil {
h.Logger.Printf(format, v...)
} else {
log.Printf(format, v...)
}
}
func (h *Handler) handleInternalRedirect(rw http.ResponseWriter, req *http.Request, path string) {
url, err := req.URL.Parse(path)
if err != nil {
rw.WriteHeader(http.StatusInternalServerError)
h.printf("cgi: error resolving local URI path %q: %v", path, err)
return
}
// TODO: RFC 3875 isn't clear if only GET is supported, but it
// suggests so: "Note that any message-body attached to the
// request (such as for a POST request) may not be available
// to the resource that is the target of the redirect." We
// should do some tests against Apache to see how it handles
// POST, HEAD, etc. Does the internal redirect get the same
// method or just GET? What about incoming headers?
// (e.g. Cookies) Which headers, if any, are copied into the
// second request?
newReq := &http.Request{
Method: "GET",
URL: url,
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: make(http.Header),
Host: url.Host,
RemoteAddr: req.RemoteAddr,
TLS: req.TLS,
}
h.PathLocationHandler.ServeHTTP(rw, newReq)
}
func upperCaseAndUnderscore(r rune) rune {
switch {
case r >= 'a' && r <= 'z':
return r - ('a' - 'A')
case r == '-':
return '_'
case r == '=':
// Maybe not part of the CGI 'spec' but would mess up
// the environment in any case, as Go represents the
// environment as a slice of "key=value" strings.
return '_'
}
// TODO: other transformations in spec or practice?
return r
}
var testHookStartProcess func(*os.Process) // nil except for some tests
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// HTTP client. See RFC 7230 through 7235.
//
// This is the high-level Client interface.
// The low-level implementation is in transport.go.
package http
import (
"context"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"io"
"log"
"net/http/internal/ascii"
"net/url"
"reflect"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
)
// A Client is an HTTP client. Its zero value (DefaultClient) is a
// usable client that uses DefaultTransport.
//
// The Client's Transport typically has internal state (cached TCP
// connections), so Clients should be reused instead of created as
// needed. Clients are safe for concurrent use by multiple goroutines.
//
// A Client is higher-level than a RoundTripper (such as Transport)
// and additionally handles HTTP details such as cookies and
// redirects.
//
// When following redirects, the Client will forward all headers set on the
// initial Request except:
//
// • when forwarding sensitive headers like "Authorization",
// "WWW-Authenticate", and "Cookie" to untrusted targets.
// These headers will be ignored when following a redirect to a domain
// that is not a subdomain match or exact match of the initial domain.
// For example, a redirect from "foo.com" to either "foo.com" or "sub.foo.com"
// will forward the sensitive headers, but a redirect to "bar.com" will not.
//
// • when forwarding the "Cookie" header with a non-nil cookie Jar.
// Since each redirect may mutate the state of the cookie jar,
// a redirect may possibly alter a cookie set in the initial request.
// When forwarding the "Cookie" header, any mutated cookies will be omitted,
// with the expectation that the Jar will insert those mutated cookies
// with the updated values (assuming the origin matches).
// If Jar is nil, the initial cookies are forwarded without change.
type Client struct {
// Transport specifies the mechanism by which individual
// HTTP requests are made.
// If nil, DefaultTransport is used.
Transport RoundTripper
// CheckRedirect specifies the policy for handling redirects.
// If CheckRedirect is not nil, the client calls it before
// following an HTTP redirect. The arguments req and via are
// the upcoming request and the requests made already, oldest
// first. If CheckRedirect returns an error, the Client's Get
// method returns both the previous Response (with its Body
// closed) and CheckRedirect's error (wrapped in a url.Error)
// instead of issuing the Request req.
// As a special case, if CheckRedirect returns ErrUseLastResponse,
// then the most recent response is returned with its body
// unclosed, along with a nil error.
//
// If CheckRedirect is nil, the Client uses its default policy,
// which is to stop after 10 consecutive requests.
CheckRedirect func(req *Request, via []*Request) error
// Jar specifies the cookie jar.
//
// The Jar is used to insert relevant cookies into every
// outbound Request and is updated with the cookie values
// of every inbound Response. The Jar is consulted for every
// redirect that the Client follows.
//
// If Jar is nil, cookies are only sent if they are explicitly
// set on the Request.
Jar CookieJar
// Timeout specifies a time limit for requests made by this
// Client. The timeout includes connection time, any
// redirects, and reading the response body. The timer remains
// running after Get, Head, Post, or Do return and will
// interrupt reading of the Response.Body.
//
// A Timeout of zero means no timeout.
//
// The Client cancels requests to the underlying Transport
// as if the Request's Context ended.
//
// For compatibility, the Client will also use the deprecated
// CancelRequest method on Transport if found. New
// RoundTripper implementations should use the Request's Context
// for cancellation instead of implementing CancelRequest.
Timeout time.Duration
}
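// A typical configured Client, tying together Timeout and CheckRedirect
// (values are illustrative):
//
//	client := &http.Client{
//		Timeout: 10 * time.Second, // covers dial, redirects, and body read
//		CheckRedirect: func(req *http.Request, via []*http.Request) error {
//			if len(via) >= 3 {
//				// Stop redirecting; return the latest response
//				// with its body unclosed, per ErrUseLastResponse.
//				return http.ErrUseLastResponse
//			}
//			return nil
//		},
//	}
//	resp, err := client.Get("https://example.com/")
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer resp.Body.Close()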
// DefaultClient is the default Client and is used by Get, Head, and Post.
var DefaultClient = &Client{}
// RoundTripper is an interface representing the ability to execute a
// single HTTP transaction, obtaining the Response for a given Request.
//
// A RoundTripper must be safe for concurrent use by multiple
// goroutines.
type RoundTripper interface {
// RoundTrip executes a single HTTP transaction, returning
// a Response for the provided Request.
//
// RoundTrip should not attempt to interpret the response. In
// particular, RoundTrip must return err == nil if it obtained
// a response, regardless of the response's HTTP status code.
// A non-nil err should be reserved for failure to obtain a
// response. Similarly, RoundTrip should not attempt to
// handle higher-level protocol details such as redirects,
// authentication, or cookies.
//
// RoundTrip should not modify the request, except for
// consuming and closing the Request's Body. RoundTrip may
// read fields of the request in a separate goroutine. Callers
// should not mutate or reuse the request until the Response's
// Body has been closed.
//
// RoundTrip must always close the body, including on errors,
// but depending on the implementation may do so in a separate
// goroutine even after RoundTrip returns. This means that
// callers wanting to reuse the body for subsequent requests
// must arrange to wait for the Close call before doing so.
//
// The Request's URL and Header fields must be initialized.
RoundTrip(*Request) (*Response, error)
}
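// A hedged sketch of a RoundTripper that decorates another one; the type
// name headerRoundTripper and the X-Trace header are hypothetical:
//
//	type headerRoundTripper struct{ base RoundTripper }
//
//	func (h headerRoundTripper) RoundTrip(req *Request) (*Response, error) {
//		// Clone first: RoundTrip must not modify the caller's request.
//		r2 := req.Clone(req.Context())
//		r2.Header.Set("X-Trace", "1")
//		return h.base.RoundTrip(r2)
//	}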
// refererForURL returns a referer without any authentication info,
// or an empty string if lastReq's scheme is https and newReq's
// scheme is http.
func refererForURL(lastReq, newReq *url.URL) string {
// https://tools.ietf.org/html/rfc7231#section-5.5.2
// "Clients SHOULD NOT include a Referer header field in a
// (non-secure) HTTP request if the referring page was
// transferred with a secure protocol."
if lastReq.Scheme == "https" && newReq.Scheme == "http" {
return ""
}
referer := lastReq.String()
if lastReq.User != nil {
// This is not very efficient, but is the best we can
// do without:
// - introducing a new method on URL
// - creating a race condition
// - copying the URL struct manually, which would cause
// maintenance problems down the line
auth := lastReq.User.String() + "@"
referer = strings.Replace(referer, auth, "", 1)
}
return referer
}
// didTimeout is non-nil only if err != nil.
func (c *Client) send(req *Request, deadline time.Time) (resp *Response, didTimeout func() bool, err error) {
if c.Jar != nil {
for _, cookie := range c.Jar.Cookies(req.URL) {
req.AddCookie(cookie)
}
}
resp, didTimeout, err = send(req, c.transport(), deadline)
if err != nil {
return nil, didTimeout, err
}
if c.Jar != nil {
if rc := resp.Cookies(); len(rc) > 0 {
c.Jar.SetCookies(req.URL, rc)
}
}
return resp, nil, nil
}
func (c *Client) deadline() time.Time {
if c.Timeout > 0 {
return time.Now().Add(c.Timeout)
}
return time.Time{}
}
func (c *Client) transport() RoundTripper {
if c.Transport != nil {
return c.Transport
}
return DefaultTransport
}
// send issues an HTTP request.
// Caller should close resp.Body when done reading from it.
func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, didTimeout func() bool, err error) {
req := ireq // req is either the original request, or a modified fork
if rt == nil {
req.closeBody()
return nil, alwaysFalse, errors.New("http: no Client.Transport or DefaultTransport")
}
if req.URL == nil {
req.closeBody()
return nil, alwaysFalse, errors.New("http: nil Request.URL")
}
if req.RequestURI != "" {
req.closeBody()
return nil, alwaysFalse, errors.New("http: Request.RequestURI can't be set in client requests")
}
// forkReq forks req into a shallow clone of ireq the first
// time it's called.
forkReq := func() {
if ireq == req {
req = new(Request)
*req = *ireq // shallow clone
}
}
// Most of the callers of send (Get, Post, et al) don't need
// Headers, leaving it uninitialized. We guarantee to the
// Transport that this has been initialized, though.
if req.Header == nil {
forkReq()
req.Header = make(Header)
}
if u := req.URL.User; u != nil && req.Header.Get("Authorization") == "" {
username := u.Username()
password, _ := u.Password()
forkReq()
req.Header = cloneOrMakeHeader(ireq.Header)
req.Header.Set("Authorization", "Basic "+basicAuth(username, password))
}
if !deadline.IsZero() {
forkReq()
}
stopTimer, didTimeout := setRequestCancel(req, rt, deadline)
resp, err = rt.RoundTrip(req)
if err != nil {
stopTimer()
if resp != nil {
log.Printf("RoundTripper returned a response & error; ignoring response")
}
if tlsErr, ok := err.(tls.RecordHeaderError); ok {
// If we get a bad TLS record header, check to see if the
// response looks like HTTP and give a more helpful error.
// See golang.org/issue/11111.
if string(tlsErr.RecordHeader[:]) == "HTTP/" {
err = errors.New("http: server gave HTTP response to HTTPS client")
}
}
return nil, didTimeout, err
}
if resp == nil {
return nil, didTimeout, fmt.Errorf("http: RoundTripper implementation (%T) returned a nil *Response with a nil error", rt)
}
if resp.Body == nil {
// The documentation on the Body field says “The http Client and Transport
// guarantee that Body is always non-nil, even on responses without a body
// or responses with a zero-length body.” Unfortunately, we didn't document
// that same constraint for arbitrary RoundTripper implementations, and
// RoundTripper implementations in the wild (mostly in tests) assume that
// they can use a nil Body to mean an empty one (similar to Request.Body).
// (See https://golang.org/issue/38095.)
//
// If the ContentLength allows the Body to be empty, fill in an empty one
// here to ensure that it is non-nil.
if resp.ContentLength > 0 && req.Method != "HEAD" {
return nil, didTimeout, fmt.Errorf("http: RoundTripper implementation (%T) returned a *Response with content length %d but a nil Body", rt, resp.ContentLength)
}
resp.Body = io.NopCloser(strings.NewReader(""))
}
if !deadline.IsZero() {
resp.Body = &cancelTimerBody{
stop: stopTimer,
rc: resp.Body,
reqDidTimeout: didTimeout,
}
}
return resp, nil, nil
}
// timeBeforeContextDeadline reports whether the non-zero Time t is
// before ctx's deadline, if any. If ctx does not have a deadline, it
// always reports true (the deadline is considered infinite).
func timeBeforeContextDeadline(t time.Time, ctx context.Context) bool {
d, ok := ctx.Deadline()
if !ok {
return true
}
return t.Before(d)
}
// knownRoundTripperImpl reports whether rt is a RoundTripper that's
// maintained by the Go team and known to implement the latest
// optional semantics (notably contexts). The Request is used
// to check whether this particular request is using an alternate protocol,
// in which case we need to check the RoundTripper for that protocol.
func knownRoundTripperImpl(rt RoundTripper, req *Request) bool {
switch t := rt.(type) {
case *Transport:
if altRT := t.alternateRoundTripper(req); altRT != nil {
return knownRoundTripperImpl(altRT, req)
}
return true
case *http2Transport, http2noDialH2RoundTripper:
return true
}
// There's a very minor chance of a false positive with this.
// Instead of detecting our golang.org/x/net/http2.Transport,
// it might detect a Transport type in a different http2
// package. But I know of none, and the only problem would be
// some temporarily leaked goroutines if the transport didn't
// support contexts. So this is a good enough heuristic:
if reflect.TypeOf(rt).String() == "*http2.Transport" {
return true
}
return false
}
// setRequestCancel sets req.Cancel and adds a deadline context to req
// if deadline is non-zero. The RoundTripper's type is used to
// determine whether the legacy CancelRequest behavior should be used.
//
// As background, there are three ways to cancel a request:
// First was Transport.CancelRequest. (deprecated)
// Second was Request.Cancel.
// Third was Request.Context.
// This function populates the second and third, and uses the first if it really needs to.
func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTimer func(), didTimeout func() bool) {
if deadline.IsZero() {
return nop, alwaysFalse
}
knownTransport := knownRoundTripperImpl(rt, req)
oldCtx := req.Context()
if req.Cancel == nil && knownTransport {
// If they already had a Request.Context that's
// expiring sooner, do nothing:
if !timeBeforeContextDeadline(deadline, oldCtx) {
return nop, alwaysFalse
}
var cancelCtx func()
req.ctx, cancelCtx = context.WithDeadline(oldCtx, deadline)
return cancelCtx, func() bool { return time.Now().After(deadline) }
}
initialReqCancel := req.Cancel // the user's original Request.Cancel, if any
var cancelCtx func()
if timeBeforeContextDeadline(deadline, oldCtx) {
req.ctx, cancelCtx = context.WithDeadline(oldCtx, deadline)
}
cancel := make(chan struct{})
req.Cancel = cancel
doCancel := func() {
// The second way in the func comment above:
close(cancel)
// The first way, used only for RoundTripper
// implementations written before Go 1.5 or Go 1.6.
type canceler interface{ CancelRequest(*Request) }
if v, ok := rt.(canceler); ok {
v.CancelRequest(req)
}
}
stopTimerCh := make(chan struct{})
var once sync.Once
stopTimer = func() {
once.Do(func() {
close(stopTimerCh)
if cancelCtx != nil {
cancelCtx()
}
})
}
timer := time.NewTimer(time.Until(deadline))
var timedOut atomic.Bool
go func() {
select {
case <-initialReqCancel:
doCancel()
timer.Stop()
case <-timer.C:
timedOut.Store(true)
doCancel()
case <-stopTimerCh:
timer.Stop()
}
}()
return stopTimer, timedOut.Load
}
// See RFC 2617, Section 2 (end of page 4): https://www.ietf.org/rfc/rfc2617.txt
// "To receive authorization, the client sends the userid and password,
// separated by a single colon (":") character, within a base64
// encoded string in the credentials."
// It is not meant to be URL-encoded.
func basicAuth(username, password string) string {
auth := username + ":" + password
return base64.StdEncoding.EncodeToString([]byte(auth))
}
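// For example, with the classic credentials from RFC 2617 section 2:
//
//	basicAuth("Aladdin", "open sesame") // "QWxhZGRpbjpvcGVuIHNlc2FtZQ=="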
// Get issues a GET to the specified URL. If the response is one of
// the following redirect codes, Get follows the redirect, up to a
// maximum of 10 redirects:
//
// 301 (Moved Permanently)
// 302 (Found)
// 303 (See Other)
// 307 (Temporary Redirect)
// 308 (Permanent Redirect)
//
// An error is returned if there were too many redirects or if there
// was an HTTP protocol error. A non-2xx response doesn't cause an
// error. Any returned error will be of type *url.Error. The url.Error
// value's Timeout method will report true if the request timed out.
//
// When err is nil, resp always contains a non-nil resp.Body.
// Caller should close resp.Body when done reading from it.
//
// Get is a wrapper around DefaultClient.Get.
//
// To make a request with custom headers, use NewRequest and
// DefaultClient.Do.
//
// To make a request with a specified context.Context, use NewRequestWithContext
// and DefaultClient.Do.
func Get(url string) (resp *Response, err error) {
return DefaultClient.Get(url)
}
// Get issues a GET to the specified URL. If the response is one of the
// following redirect codes, Get follows the redirect after calling the
// Client's CheckRedirect function:
//
// 301 (Moved Permanently)
// 302 (Found)
// 303 (See Other)
// 307 (Temporary Redirect)
// 308 (Permanent Redirect)
//
// An error is returned if the Client's CheckRedirect function fails
// or if there was an HTTP protocol error. A non-2xx response doesn't
// cause an error. Any returned error will be of type *url.Error. The
// url.Error value's Timeout method will report true if the request
// timed out.
//
// When err is nil, resp always contains a non-nil resp.Body.
// Caller should close resp.Body when done reading from it.
//
// To make a request with custom headers, use NewRequest and Client.Do.
//
// To make a request with a specified context.Context, use NewRequestWithContext
// and Client.Do.
func (c *Client) Get(url string) (resp *Response, err error) {
req, err := NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
return c.Do(req)
}
func alwaysFalse() bool { return false }
// ErrUseLastResponse can be returned by Client.CheckRedirect hooks to
// control how redirects are processed. If returned, the next request
// is not sent and the most recent response is returned with its body
// unclosed.
var ErrUseLastResponse = errors.New("net/http: use last response")
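// A minimal sketch of a client that never follows redirects and instead
// returns the first 3xx response as-is (illustrative only):
//
//	client := &Client{
//		CheckRedirect: func(req *Request, via []*Request) error {
//			return ErrUseLastResponse
//		},
//	}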
// checkRedirect calls either the user's configured CheckRedirect
// function, or the default.
func (c *Client) checkRedirect(req *Request, via []*Request) error {
fn := c.CheckRedirect
if fn == nil {
fn = defaultCheckRedirect
}
return fn(req, via)
}
// redirectBehavior describes what should happen when the
// client encounters a 3xx status code from the server.
func redirectBehavior(reqMethod string, resp *Response, ireq *Request) (redirectMethod string, shouldRedirect, includeBody bool) {
switch resp.StatusCode {
case 301, 302, 303:
redirectMethod = reqMethod
shouldRedirect = true
includeBody = false
// RFC 2616 allowed automatic redirection only with GET and
// HEAD requests. RFC 7231 lifts this restriction, but we still
// restrict other methods to GET to maintain compatibility.
// See Issue 18570.
if reqMethod != "GET" && reqMethod != "HEAD" {
redirectMethod = "GET"
}
case 307, 308:
redirectMethod = reqMethod
shouldRedirect = true
includeBody = true
if ireq.GetBody == nil && ireq.outgoingLength() != 0 {
// We had a request body, and 307/308 require
// re-sending it, but GetBody is not defined. So just
// return this response to the user instead of an
// error, like we did in Go 1.7 and earlier.
shouldRedirect = false
}
}
return redirectMethod, shouldRedirect, includeBody
}
// urlErrorOp returns the (*url.Error).Op value to use for the
// provided (*Request).Method value.
func urlErrorOp(method string) string {
if method == "" {
return "Get"
}
if lowerMethod, ok := ascii.ToLower(method); ok {
return method[:1] + lowerMethod[1:]
}
return method
}
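// For example:
//
//	urlErrorOp("")     // "Get"
//	urlErrorOp("POST") // "Post"
//	urlErrorOp("HEAD") // "Head"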
// Do sends an HTTP request and returns an HTTP response, following
// policy (such as redirects, cookies, auth) as configured on the
// client.
//
// An error is returned if caused by client policy (such as
// CheckRedirect), or failure to speak HTTP (such as a network
// connectivity problem). A non-2xx status code doesn't cause an
// error.
//
// If the returned error is nil, the Response will contain a non-nil
// Body which the user is expected to close. If the Body is not both
// read to EOF and closed, the Client's underlying RoundTripper
// (typically Transport) may not be able to re-use a persistent TCP
// connection to the server for a subsequent "keep-alive" request.
//
// The request Body, if non-nil, will be closed by the underlying
// Transport, even on errors.
//
// On error, any Response can be ignored. A non-nil Response with a
// non-nil error only occurs when CheckRedirect fails, and even then
// the returned Response.Body is already closed.
//
// Generally Get, Post, or PostForm will be used instead of Do.
//
// If the server replies with a redirect, the Client first uses the
// CheckRedirect function to determine whether the redirect should be
// followed. If permitted, a 301, 302, or 303 redirect causes
// subsequent requests to use HTTP method GET
// (or HEAD if the original request was HEAD), with no body.
// A 307 or 308 redirect preserves the original HTTP method and body,
// provided that the Request.GetBody function is defined.
// The NewRequest function automatically sets GetBody for common
// standard library body types.
//
// Any returned error will be of type *url.Error. The url.Error
// value's Timeout method will report true if the request timed out.
func (c *Client) Do(req *Request) (*Response, error) {
return c.do(req)
}
var testHookClientDoResult func(retres *Response, reterr error)
func (c *Client) do(req *Request) (retres *Response, reterr error) {
if testHookClientDoResult != nil {
defer func() { testHookClientDoResult(retres, reterr) }()
}
if req.URL == nil {
req.closeBody()
return nil, &url.Error{
Op: urlErrorOp(req.Method),
Err: errors.New("http: nil Request.URL"),
}
}
var (
deadline = c.deadline()
reqs []*Request
resp *Response
copyHeaders = c.makeHeadersCopier(req)
reqBodyClosed = false // have we closed the current req.Body?
// Redirect behavior:
redirectMethod string
includeBody bool
)
uerr := func(err error) error {
// the body may have been closed already by c.send()
if !reqBodyClosed {
req.closeBody()
}
var urlStr string
if resp != nil && resp.Request != nil {
urlStr = stripPassword(resp.Request.URL)
} else {
urlStr = stripPassword(req.URL)
}
return &url.Error{
Op: urlErrorOp(reqs[0].Method),
URL: urlStr,
Err: err,
}
}
for {
// For all but the first request, create the next
// request hop and replace req.
if len(reqs) > 0 {
loc := resp.Header.Get("Location")
if loc == "" {
// While most 3xx responses include a Location, it is not
// required and 3xx responses without a Location have been
// observed in the wild. See issues #17773 and #49281.
return resp, nil
}
u, err := req.URL.Parse(loc)
if err != nil {
resp.closeBody()
return nil, uerr(fmt.Errorf("failed to parse Location header %q: %v", loc, err))
}
host := ""
if req.Host != "" && req.Host != req.URL.Host {
// If the caller specified a custom Host header and the
// redirect location is relative, preserve the Host header
// through the redirect. See issue #22233.
if u, _ := url.Parse(loc); u != nil && !u.IsAbs() {
host = req.Host
}
}
ireq := reqs[0]
req = &Request{
Method: redirectMethod,
Response: resp,
URL: u,
Header: make(Header),
Host: host,
Cancel: ireq.Cancel,
ctx: ireq.ctx,
}
if includeBody && ireq.GetBody != nil {
req.Body, err = ireq.GetBody()
if err != nil {
resp.closeBody()
return nil, uerr(err)
}
req.ContentLength = ireq.ContentLength
}
// Copy original headers before setting the Referer,
// in case the user set Referer on their first request.
// If they really want to override, they can do it in
// their CheckRedirect func.
copyHeaders(req)
// Add the Referer header from the most recent
// request URL to the new one, if it's not https->http:
if ref := refererForURL(reqs[len(reqs)-1].URL, req.URL); ref != "" {
req.Header.Set("Referer", ref)
}
err = c.checkRedirect(req, reqs)
// Sentinel error to let users select the
// previous response, without closing its
// body. See Issue 10069.
if err == ErrUseLastResponse {
return resp, nil
}
// Close the previous response's body. But
// read at least some of the body so if it's
// small the underlying TCP connection will be
// re-used. No need to check for errors: if it
// fails, the Transport won't reuse it anyway.
const maxBodySlurpSize = 2 << 10
if resp.ContentLength == -1 || resp.ContentLength <= maxBodySlurpSize {
io.CopyN(io.Discard, resp.Body, maxBodySlurpSize)
}
resp.Body.Close()
if err != nil {
// Special case for Go 1 compatibility: return both the response
// and an error if the CheckRedirect function failed.
// See https://golang.org/issue/3795
// The resp.Body has already been closed.
ue := uerr(err)
ue.(*url.Error).URL = loc
return resp, ue
}
}
reqs = append(reqs, req)
var err error
var didTimeout func() bool
if resp, didTimeout, err = c.send(req, deadline); err != nil {
// c.send() always closes req.Body
reqBodyClosed = true
if !deadline.IsZero() && didTimeout() {
err = &httpError{
err: err.Error() + " (Client.Timeout exceeded while awaiting headers)",
timeout: true,
}
}
return nil, uerr(err)
}
var shouldRedirect bool
redirectMethod, shouldRedirect, includeBody = redirectBehavior(req.Method, resp, reqs[0])
if !shouldRedirect {
return resp, nil
}
req.closeBody()
}
}
// makeHeadersCopier makes a function that copies headers from the
// initial Request, ireq. For every redirect, this function must be called
// so that it can copy headers into the upcoming Request.
func (c *Client) makeHeadersCopier(ireq *Request) func(*Request) {
// The headers to copy are from the very initial request.
// We use a closured callback to keep a reference to these original headers.
var (
ireqhdr = cloneOrMakeHeader(ireq.Header)
icookies map[string][]*Cookie
)
if c.Jar != nil && ireq.Header.Get("Cookie") != "" {
icookies = make(map[string][]*Cookie)
for _, c := range ireq.Cookies() {
icookies[c.Name] = append(icookies[c.Name], c)
}
}
preq := ireq // The previous request
return func(req *Request) {
// If Jar is present and there was some initial cookies provided
// via the request header, then we may need to alter the initial
// cookies as we follow redirects since each redirect may end up
// modifying a pre-existing cookie.
//
// Since cookies already set in the request header do not contain
// information about the original domain and path, the logic below
// assumes any new set cookies override the original cookie
// regardless of domain or path.
//
// See https://golang.org/issue/17494
if c.Jar != nil && icookies != nil {
var changed bool
resp := req.Response // The response that caused the upcoming redirect
for _, c := range resp.Cookies() {
if _, ok := icookies[c.Name]; ok {
delete(icookies, c.Name)
changed = true
}
}
if changed {
ireqhdr.Del("Cookie")
var ss []string
for _, cs := range icookies {
for _, c := range cs {
ss = append(ss, c.Name+"="+c.Value)
}
}
sort.Strings(ss) // Ensure deterministic headers
ireqhdr.Set("Cookie", strings.Join(ss, "; "))
}
}
// Copy the initial request's Header values
// (at least the safe ones).
for k, vv := range ireqhdr {
if shouldCopyHeaderOnRedirect(k, preq.URL, req.URL) {
req.Header[k] = vv
}
}
preq = req // Update previous Request with the current request
}
}
func defaultCheckRedirect(req *Request, via []*Request) error {
if len(via) >= 10 {
return errors.New("stopped after 10 redirects")
}
return nil
}
// Post issues a POST to the specified URL.
//
// Caller should close resp.Body when done reading from it.
//
// If the provided body is an io.Closer, it is closed after the
// request.
//
// Post is a wrapper around DefaultClient.Post.
//
// To set custom headers, use NewRequest and DefaultClient.Do.
//
// See the Client.Do method documentation for details on how redirects
// are handled.
//
// To make a request with a specified context.Context, use NewRequestWithContext
// and DefaultClient.Do.
func Post(url, contentType string, body io.Reader) (resp *Response, err error) {
return DefaultClient.Post(url, contentType, body)
}
// Post issues a POST to the specified URL.
//
// Caller should close resp.Body when done reading from it.
//
// If the provided body is an io.Closer, it is closed after the
// request.
//
// To set custom headers, use NewRequest and Client.Do.
//
// To make a request with a specified context.Context, use NewRequestWithContext
// and Client.Do.
//
// See the Client.Do method documentation for details on how redirects
// are handled.
func (c *Client) Post(url, contentType string, body io.Reader) (resp *Response, err error) {
req, err := NewRequest("POST", url, body)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", contentType)
return c.Do(req)
}
// PostForm issues a POST to the specified URL, with data's keys and
// values URL-encoded as the request body.
//
// The Content-Type header is set to application/x-www-form-urlencoded.
// To set other headers, use NewRequest and DefaultClient.Do.
//
// When err is nil, resp always contains a non-nil resp.Body.
// Caller should close resp.Body when done reading from it.
//
// PostForm is a wrapper around DefaultClient.PostForm.
//
// See the Client.Do method documentation for details on how redirects
// are handled.
//
// To make a request with a specified context.Context, use NewRequestWithContext
// and DefaultClient.Do.
func PostForm(url string, data url.Values) (resp *Response, err error) {
return DefaultClient.PostForm(url, data)
}
// PostForm issues a POST to the specified URL,
// with data's keys and values URL-encoded as the request body.
//
// The Content-Type header is set to application/x-www-form-urlencoded.
// To set other headers, use NewRequest and Client.Do.
//
// When err is nil, resp always contains a non-nil resp.Body.
// Caller should close resp.Body when done reading from it.
//
// See the Client.Do method documentation for details on how redirects
// are handled.
//
// To make a request with a specified context.Context, use NewRequestWithContext
// and Client.Do.
func (c *Client) PostForm(url string, data url.Values) (resp *Response, err error) {
return c.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
}
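// A minimal usage sketch (the URL and form fields are illustrative
// assumptions):
//
//	resp, err := client.PostForm("https://example.com/login",
//		url.Values{"user": {"gopher"}, "lang": {"go"}})
//	if err != nil {
//		return
//	}
//	defer resp.Body.Close()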
// Head issues a HEAD to the specified URL. If the response is one of
// the following redirect codes, Head follows the redirect, up to a
// maximum of 10 redirects:
//
// 301 (Moved Permanently)
// 302 (Found)
// 303 (See Other)
// 307 (Temporary Redirect)
// 308 (Permanent Redirect)
//
// Head is a wrapper around DefaultClient.Head.
//
// To make a request with a specified context.Context, use NewRequestWithContext
// and DefaultClient.Do.
func Head(url string) (resp *Response, err error) {
return DefaultClient.Head(url)
}
// Head issues a HEAD to the specified URL. If the response is one of the
// following redirect codes, Head follows the redirect after calling the
// Client's CheckRedirect function:
//
// 301 (Moved Permanently)
// 302 (Found)
// 303 (See Other)
// 307 (Temporary Redirect)
// 308 (Permanent Redirect)
//
// To make a request with a specified context.Context, use NewRequestWithContext
// and Client.Do.
func (c *Client) Head(url string) (resp *Response, err error) {
req, err := NewRequest("HEAD", url, nil)
if err != nil {
return nil, err
}
return c.Do(req)
}
// CloseIdleConnections closes any connections on its Transport which
// were previously connected from previous requests but are now
// sitting idle in a "keep-alive" state. It does not interrupt any
// connections currently in use.
//
// If the Client's Transport does not have a CloseIdleConnections method
// then this method does nothing.
func (c *Client) CloseIdleConnections() {
type closeIdler interface {
CloseIdleConnections()
}
if tr, ok := c.transport().(closeIdler); ok {
tr.CloseIdleConnections()
}
}
// cancelTimerBody is an io.ReadCloser that wraps rc with two features:
// 1. On Read error or close, the stop func is called.
// 2. On Read failure, if reqDidTimeout is true, the error is wrapped and
// marked as a net.Error that hit its timeout.
type cancelTimerBody struct {
stop func() // stops the time.Timer waiting to cancel the request
rc io.ReadCloser
reqDidTimeout func() bool
}
func (b *cancelTimerBody) Read(p []byte) (n int, err error) {
n, err = b.rc.Read(p)
if err == nil {
return n, nil
}
if err == io.EOF {
return n, err
}
if b.reqDidTimeout() {
err = &httpError{
err: err.Error() + " (Client.Timeout or context cancellation while reading body)",
timeout: true,
}
}
return n, err
}
func (b *cancelTimerBody) Close() error {
err := b.rc.Close()
b.stop()
return err
}
func shouldCopyHeaderOnRedirect(headerKey string, initial, dest *url.URL) bool {
switch CanonicalHeaderKey(headerKey) {
case "Authorization", "Www-Authenticate", "Cookie", "Cookie2":
// Permit sending auth/cookie headers from "foo.com"
// to "sub.foo.com".
// Note that we don't send all cookies to subdomains
// automatically. This function is only used for
// Cookies set explicitly on the initial outgoing
// client request. Cookies automatically added via the
// CookieJar mechanism continue to follow each
// cookie's scope as set by Set-Cookie. But for
// outgoing requests with the Cookie header set
// directly, we don't know their scope, so we assume
// it's for *.domain.com.
ihost := idnaASCIIFromURL(initial)
dhost := idnaASCIIFromURL(dest)
return isDomainOrSubdomain(dhost, ihost)
}
// All other headers are copied:
return true
}
// isDomainOrSubdomain reports whether sub is a subdomain (or exact
// match) of the parent domain.
//
// Both domains must already be in canonical form.
func isDomainOrSubdomain(sub, parent string) bool {
if sub == parent {
return true
}
// If sub is "foo.example.com" and parent is "example.com",
// that means sub must end in "."+parent.
// Do it without allocating.
if !strings.HasSuffix(sub, parent) {
return false
}
return sub[len(sub)-len(parent)-1] == '.'
}
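// For example:
//
//	isDomainOrSubdomain("foo.example.com", "example.com") // true
//	isDomainOrSubdomain("example.com", "example.com")     // true
//	isDomainOrSubdomain("badexample.com", "example.com")  // false (suffix, but no dot boundary)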
func stripPassword(u *url.URL) string {
_, passSet := u.User.Password()
if passSet {
return strings.Replace(u.String(), u.User.String()+"@", u.User.Username()+":***@", 1)
}
return u.String()
}
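// For example, with a hypothetical URL:
//
//	u, _ := url.Parse("https://user:secret@example.com/")
//	stripPassword(u) // "https://user:***@example.com/"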
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http
import (
"mime/multipart"
"net/textproto"
"net/url"
)
func cloneURLValues(v url.Values) url.Values {
if v == nil {
return nil
}
// http.Header and url.Values have the same representation, so temporarily
// treat it like http.Header, which does have a clone:
return url.Values(Header(v).Clone())
}
func cloneURL(u *url.URL) *url.URL {
if u == nil {
return nil
}
u2 := new(url.URL)
*u2 = *u
if u.User != nil {
u2.User = new(url.Userinfo)
*u2.User = *u.User
}
return u2
}
func cloneMultipartForm(f *multipart.Form) *multipart.Form {
if f == nil {
return nil
}
f2 := &multipart.Form{
Value: (map[string][]string)(Header(f.Value).Clone()),
}
if f.File != nil {
m := make(map[string][]*multipart.FileHeader)
for k, vv := range f.File {
vv2 := make([]*multipart.FileHeader, len(vv))
for i, v := range vv {
vv2[i] = cloneMultipartFileHeader(v)
}
m[k] = vv2
}
f2.File = m
}
return f2
}
func cloneMultipartFileHeader(fh *multipart.FileHeader) *multipart.FileHeader {
if fh == nil {
return nil
}
fh2 := new(multipart.FileHeader)
*fh2 = *fh
fh2.Header = textproto.MIMEHeader(Header(fh.Header).Clone())
return fh2
}
// cloneOrMakeHeader invokes Header.Clone but if the
// result is nil, it'll instead make and return a non-nil Header.
func cloneOrMakeHeader(hdr Header) Header {
clone := hdr.Clone()
if clone == nil {
clone = make(Header)
}
return clone
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http
import (
"errors"
"fmt"
"log"
"net"
"net/http/internal/ascii"
"net/textproto"
"strconv"
"strings"
"time"
)
// A Cookie represents an HTTP cookie as sent in the Set-Cookie header of an
// HTTP response or the Cookie header of an HTTP request.
//
// See https://tools.ietf.org/html/rfc6265 for details.
type Cookie struct {
Name string
Value string
Path string // optional
Domain string // optional
Expires time.Time // optional
RawExpires string // for reading cookies only
// MaxAge=0 means no 'Max-Age' attribute specified.
// MaxAge<0 means delete cookie now, equivalently 'Max-Age: 0'
// MaxAge>0 means Max-Age attribute present and given in seconds
MaxAge int
Secure bool
HttpOnly bool
SameSite SameSite
Raw string
Unparsed []string // Raw text of unparsed attribute-value pairs
}
// SameSite allows a server to define a cookie attribute making it impossible for
// the browser to send this cookie along with cross-site requests. The main
// goal is to mitigate the risk of cross-origin information leakage, and provide
// some protection against cross-site request forgery attacks.
//
// See https://tools.ietf.org/html/draft-ietf-httpbis-cookie-same-site-00 for details.
type SameSite int
const (
SameSiteDefaultMode SameSite = iota + 1
SameSiteLaxMode
SameSiteStrictMode
SameSiteNoneMode
)
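// A minimal sketch of emitting a SameSite cookie from a handler (the
// handler and cookie values are illustrative assumptions):
//
//	func handler(w ResponseWriter, r *Request) {
//		SetCookie(w, &Cookie{
//			Name:     "session",
//			Value:    "abc123",
//			Secure:   true,
//			HttpOnly: true,
//			SameSite: SameSiteStrictMode,
//		})
//	}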
// readSetCookies parses all "Set-Cookie" values from
// the header h and returns the successfully parsed Cookies.
func readSetCookies(h Header) []*Cookie {
cookieCount := len(h["Set-Cookie"])
if cookieCount == 0 {
return []*Cookie{}
}
cookies := make([]*Cookie, 0, cookieCount)
for _, line := range h["Set-Cookie"] {
parts := strings.Split(textproto.TrimString(line), ";")
if len(parts) == 1 && parts[0] == "" {
continue
}
parts[0] = textproto.TrimString(parts[0])
name, value, ok := strings.Cut(parts[0], "=")
if !ok {
continue
}
name = textproto.TrimString(name)
if !isCookieNameValid(name) {
continue
}
value, ok = parseCookieValue(value, true)
if !ok {
continue
}
c := &Cookie{
Name: name,
Value: value,
Raw: line,
}
for i := 1; i < len(parts); i++ {
parts[i] = textproto.TrimString(parts[i])
if len(parts[i]) == 0 {
continue
}
attr, val, _ := strings.Cut(parts[i], "=")
lowerAttr, isASCII := ascii.ToLower(attr)
if !isASCII {
continue
}
val, ok = parseCookieValue(val, false)
if !ok {
c.Unparsed = append(c.Unparsed, parts[i])
continue
}
switch lowerAttr {
case "samesite":
lowerVal, ascii := ascii.ToLower(val)
if !ascii {
c.SameSite = SameSiteDefaultMode
continue
}
switch lowerVal {
case "lax":
c.SameSite = SameSiteLaxMode
case "strict":
c.SameSite = SameSiteStrictMode
case "none":
c.SameSite = SameSiteNoneMode
default:
c.SameSite = SameSiteDefaultMode
}
continue
case "secure":
c.Secure = true
continue
case "httponly":
c.HttpOnly = true
continue
case "domain":
c.Domain = val
continue
case "max-age":
secs, err := strconv.Atoi(val)
if err != nil || secs != 0 && val[0] == '0' {
break
}
if secs <= 0 {
secs = -1
}
c.MaxAge = secs
continue
case "expires":
c.RawExpires = val
exptime, err := time.Parse(time.RFC1123, val)
if err != nil {
exptime, err = time.Parse("Mon, 02-Jan-2006 15:04:05 MST", val)
if err != nil {
c.Expires = time.Time{}
break
}
}
c.Expires = exptime.UTC()
continue
case "path":
c.Path = val
continue
}
c.Unparsed = append(c.Unparsed, parts[i])
}
cookies = append(cookies, c)
}
return cookies
}
// SetCookie adds a Set-Cookie header to the provided ResponseWriter's headers.
// The provided cookie must have a valid Name. Invalid cookies may be
// silently dropped.
func SetCookie(w ResponseWriter, cookie *Cookie) {
if v := cookie.String(); v != "" {
w.Header().Add("Set-Cookie", v)
}
}
// String returns the serialization of the cookie for use in a Cookie
// header (if only Name and Value are set) or a Set-Cookie response
// header (if other fields are set).
// If c is nil or c.Name is invalid, the empty string is returned.
func (c *Cookie) String() string {
if c == nil || !isCookieNameValid(c.Name) {
return ""
}
// extraCookieLength is derived from the typical length of cookie
// attributes; see RFC 6265 Sec. 4.1.
const extraCookieLength = 110
var b strings.Builder
b.Grow(len(c.Name) + len(c.Value) + len(c.Domain) + len(c.Path) + extraCookieLength)
b.WriteString(c.Name)
b.WriteRune('=')
b.WriteString(sanitizeCookieValue(c.Value))
if len(c.Path) > 0 {
b.WriteString("; Path=")
b.WriteString(sanitizeCookiePath(c.Path))
}
if len(c.Domain) > 0 {
if validCookieDomain(c.Domain) {
// A c.Domain containing illegal characters is not
// sanitized but simply dropped which turns the cookie
// into a host-only cookie. A leading dot is okay
// but won't be sent.
d := c.Domain
if d[0] == '.' {
d = d[1:]
}
b.WriteString("; Domain=")
b.WriteString(d)
} else {
log.Printf("net/http: invalid Cookie.Domain %q; dropping domain attribute", c.Domain)
}
}
var buf [len(TimeFormat)]byte
if validCookieExpires(c.Expires) {
b.WriteString("; Expires=")
b.Write(c.Expires.UTC().AppendFormat(buf[:0], TimeFormat))
}
if c.MaxAge > 0 {
b.WriteString("; Max-Age=")
b.Write(strconv.AppendInt(buf[:0], int64(c.MaxAge), 10))
} else if c.MaxAge < 0 {
b.WriteString("; Max-Age=0")
}
if c.HttpOnly {
b.WriteString("; HttpOnly")
}
if c.Secure {
b.WriteString("; Secure")
}
switch c.SameSite {
case SameSiteDefaultMode:
// Skip, default mode is obtained by not emitting the attribute.
case SameSiteNoneMode:
b.WriteString("; SameSite=None")
case SameSiteLaxMode:
b.WriteString("; SameSite=Lax")
case SameSiteStrictMode:
b.WriteString("; SameSite=Strict")
}
return b.String()
}
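// For example:
//
//	c := &Cookie{Name: "session", Value: "abc", Path: "/", HttpOnly: true}
//	c.String() // "session=abc; Path=/; HttpOnly"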
// Valid reports whether the cookie is valid.
func (c *Cookie) Valid() error {
if c == nil {
return errors.New("http: nil Cookie")
}
if !isCookieNameValid(c.Name) {
return errors.New("http: invalid Cookie.Name")
}
if !c.Expires.IsZero() && !validCookieExpires(c.Expires) {
return errors.New("http: invalid Cookie.Expires")
}
for i := 0; i < len(c.Value); i++ {
if !validCookieValueByte(c.Value[i]) {
return fmt.Errorf("http: invalid byte %q in Cookie.Value", c.Value[i])
}
}
if len(c.Path) > 0 {
for i := 0; i < len(c.Path); i++ {
if !validCookiePathByte(c.Path[i]) {
return fmt.Errorf("http: invalid byte %q in Cookie.Path", c.Path[i])
}
}
}
if len(c.Domain) > 0 {
if !validCookieDomain(c.Domain) {
return errors.New("http: invalid Cookie.Domain")
}
}
return nil
}
// readCookies parses all "Cookie" values from the header h and
// returns the successfully parsed Cookies.
//
// If filter isn't empty, only cookies of that name are returned.
func readCookies(h Header, filter string) []*Cookie {
lines := h["Cookie"]
if len(lines) == 0 {
return []*Cookie{}
}
cookies := make([]*Cookie, 0, len(lines)+strings.Count(lines[0], ";"))
for _, line := range lines {
line = textproto.TrimString(line)
var part string
for len(line) > 0 { // continue while there is more text to parse
part, line, _ = strings.Cut(line, ";")
part = textproto.TrimString(part)
if part == "" {
continue
}
name, val, _ := strings.Cut(part, "=")
name = textproto.TrimString(name)
if !isCookieNameValid(name) {
continue
}
if filter != "" && filter != name {
continue
}
val, ok := parseCookieValue(val, true)
if !ok {
continue
}
cookies = append(cookies, &Cookie{Name: name, Value: val})
}
}
return cookies
}
// validCookieDomain reports whether v is a valid cookie domain-value.
func validCookieDomain(v string) bool {
if isCookieDomainName(v) {
return true
}
if net.ParseIP(v) != nil && !strings.Contains(v, ":") {
return true
}
return false
}
// validCookieExpires reports whether v is a valid cookie expires-value.
func validCookieExpires(t time.Time) bool {
// IETF RFC 6265 Section 5.1.1.5, the year must not be less than 1601
return t.Year() >= 1601
}
// isCookieDomainName reports whether s is a valid domain name or a valid
// domain name with a leading dot '.'. It is almost a direct copy of
// package net's isDomainName.
func isCookieDomainName(s string) bool {
if len(s) == 0 {
return false
}
if len(s) > 255 {
return false
}
if s[0] == '.' {
// A cookie domain attribute may start with a leading dot.
s = s[1:]
}
last := byte('.')
ok := false // Ok once we've seen a letter.
partlen := 0
for i := 0; i < len(s); i++ {
c := s[i]
switch {
default:
return false
case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z':
// No '_' allowed here (in contrast to package net).
ok = true
partlen++
case '0' <= c && c <= '9':
// fine
partlen++
case c == '-':
// Byte before dash cannot be dot.
if last == '.' {
return false
}
partlen++
case c == '.':
// Byte before dot cannot be dot, dash.
if last == '.' || last == '-' {
return false
}
if partlen > 63 || partlen == 0 {
return false
}
partlen = 0
}
last = c
}
if last == '-' || partlen > 63 {
return false
}
return ok
}
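// For example:
//
//	isCookieDomainName("example.com")  // true
//	isCookieDomainName(".example.com") // true (leading dot is stripped)
//	isCookieDomainName("ex_ample.com") // false (no '_' allowed)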
var cookieNameSanitizer = strings.NewReplacer("\n", "-", "\r", "-")
func sanitizeCookieName(n string) string {
return cookieNameSanitizer.Replace(n)
}
// sanitizeCookieValue produces a suitable cookie-value from v.
// https://tools.ietf.org/html/rfc6265#section-4.1.1
//
// cookie-value = *cookie-octet / ( DQUOTE *cookie-octet DQUOTE )
// cookie-octet = %x21 / %x23-2B / %x2D-3A / %x3C-5B / %x5D-7E
// ; US-ASCII characters excluding CTLs,
// ; whitespace DQUOTE, comma, semicolon,
// ; and backslash
//
// We loosen this as spaces and commas are common in cookie values,
// but we produce a quoted cookie-value if and only if v contains
// commas or spaces.
// See https://golang.org/issue/7243 for the discussion.
func sanitizeCookieValue(v string) string {
v = sanitizeOrWarn("Cookie.Value", validCookieValueByte, v)
if len(v) == 0 {
return v
}
if strings.ContainsAny(v, " ,") {
return `"` + v + `"`
}
return v
}
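// For example:
//
//	sanitizeCookieValue("abc")         // abc
//	sanitizeCookieValue("hello world") // "hello world" (quoted because of the space)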
func validCookieValueByte(b byte) bool {
return 0x20 <= b && b < 0x7f && b != '"' && b != ';' && b != '\\'
}
// path-av = "Path=" path-value
// path-value = <any CHAR except CTLs or ";">
func sanitizeCookiePath(v string) string {
return sanitizeOrWarn("Cookie.Path", validCookiePathByte, v)
}
func validCookiePathByte(b byte) bool {
return 0x20 <= b && b < 0x7f && b != ';'
}
func sanitizeOrWarn(fieldName string, valid func(byte) bool, v string) string {
ok := true
for i := 0; i < len(v); i++ {
if valid(v[i]) {
continue
}
log.Printf("net/http: invalid byte %q in %s; dropping invalid bytes", v[i], fieldName)
ok = false
break
}
if ok {
return v
}
buf := make([]byte, 0, len(v))
for i := 0; i < len(v); i++ {
if b := v[i]; valid(b) {
buf = append(buf, b)
}
}
return string(buf)
}
func parseCookieValue(raw string, allowDoubleQuote bool) (string, bool) {
// Strip the quotes, if present.
if allowDoubleQuote && len(raw) > 1 && raw[0] == '"' && raw[len(raw)-1] == '"' {
raw = raw[1 : len(raw)-1]
}
for i := 0; i < len(raw); i++ {
if !validCookieValueByte(raw[i]) {
return "", false
}
}
return raw, true
}
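// For example:
//
//	parseCookieValue(`"abc"`, true) // "abc", true (surrounding quotes stripped)
//	parseCookieValue("a;b", true)   // "", false (';' is not a valid cookie-value byte)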
func isCookieNameValid(raw string) bool {
if raw == "" {
return false
}
return strings.IndexFunc(raw, isNotToken) < 0
}
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package cookiejar implements an in-memory RFC 6265-compliant http.CookieJar.
package cookiejar
import (
"errors"
"fmt"
"net"
"net/http"
"net/http/internal/ascii"
"net/url"
"sort"
"strings"
"sync"
"time"
)
// PublicSuffixList provides the public suffix of a domain. For example:
// - the public suffix of "example.com" is "com",
// - the public suffix of "foo1.foo2.foo3.co.uk" is "co.uk", and
// - the public suffix of "bar.pvt.k12.ma.us" is "pvt.k12.ma.us".
//
// Implementations of PublicSuffixList must be safe for concurrent use by
// multiple goroutines.
//
// An implementation that always returns "" is valid and may be useful for
// testing but it is not secure: it means that the HTTP server for foo.com can
// set a cookie for bar.com.
//
// A public suffix list implementation is in the package
// golang.org/x/net/publicsuffix.
type PublicSuffixList interface {
// PublicSuffix returns the public suffix of domain.
//
// TODO: specify which of the caller and callee is responsible for IP
// addresses, for leading and trailing dots, for case sensitivity, and
// for IDN/Punycode.
PublicSuffix(domain string) string
// String returns a description of the source of this public suffix
// list. The description will typically contain something like a time
// stamp or version number.
String() string
}
// Options are the options for creating a new Jar.
type Options struct {
// PublicSuffixList is the public suffix list that determines whether
// an HTTP server can set a cookie for a domain.
//
// A nil value is valid and may be useful for testing but it is not
// secure: it means that the HTTP server for foo.co.uk can set a cookie
// for bar.co.uk.
PublicSuffixList PublicSuffixList
}
// Jar implements the http.CookieJar interface from the net/http package.
type Jar struct {
psList PublicSuffixList
// mu locks the remaining fields.
mu sync.Mutex
// entries is a set of entries, keyed by their eTLD+1 and subkeyed by
// their name/domain/path.
entries map[string]map[string]entry
// nextSeqNum is the next sequence number assigned to a new cookie
// created by SetCookies.
nextSeqNum uint64
}
// New returns a new cookie jar. A nil *Options is equivalent to a zero
// Options.
func New(o *Options) (*Jar, error) {
jar := &Jar{
entries: make(map[string]map[string]entry),
}
if o != nil {
jar.psList = o.PublicSuffixList
}
return jar, nil
}
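// A minimal usage sketch with a real public suffix list; this assumes the
// golang.org/x/net/publicsuffix package, which is not imported here:
//
//	jar, err := cookiejar.New(&cookiejar.Options{
//		PublicSuffixList: publicsuffix.List,
//	})
//	if err != nil {
//		// New currently never fails, but check anyway.
//	}
//	client := &http.Client{Jar: jar}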
// entry is the internal representation of a cookie.
//
// This struct type is not used outside of this package per se, but the exported
// fields are those of RFC 6265.
type entry struct {
Name string
Value string
Domain string
Path string
SameSite string
Secure bool
HttpOnly bool
Persistent bool
HostOnly bool
Expires time.Time
Creation time.Time
LastAccess time.Time
// seqNum is a sequence number so that Cookies returns cookies in a
// deterministic order, even for cookies that have equal Path length and
// equal Creation time. This simplifies testing.
seqNum uint64
}
// id returns the domain;path;name triple of e as an id.
func (e *entry) id() string {
return fmt.Sprintf("%s;%s;%s", e.Domain, e.Path, e.Name)
}
// shouldSend determines whether e's cookie qualifies to be included in a
// request to host/path. It is the caller's responsibility to check if the
// cookie is expired.
func (e *entry) shouldSend(https bool, host, path string) bool {
return e.domainMatch(host) && e.pathMatch(path) && (https || !e.Secure)
}
// domainMatch checks whether e's Domain allows sending e back to host.
// It differs from "domain-match" of RFC 6265 section 5.1.3 because we
// always treat a cookie with an IP address in the Domain as a host cookie.
func (e *entry) domainMatch(host string) bool {
if e.Domain == host {
return true
}
return !e.HostOnly && hasDotSuffix(host, e.Domain)
}
// pathMatch implements "path-match" according to RFC 6265 section 5.1.4.
func (e *entry) pathMatch(requestPath string) bool {
if requestPath == e.Path {
return true
}
if strings.HasPrefix(requestPath, e.Path) {
if e.Path[len(e.Path)-1] == '/' {
return true // The "/any/" matches "/any/path" case.
} else if requestPath[len(e.Path)] == '/' {
return true // The "/any" matches "/any/path" case.
}
}
return false
}
// hasDotSuffix reports whether s ends in "."+suffix.
func hasDotSuffix(s, suffix string) bool {
return len(s) > len(suffix) && s[len(s)-len(suffix)-1] == '.' && s[len(s)-len(suffix):] == suffix
}
// Cookies implements the Cookies method of the http.CookieJar interface.
//
// It returns an empty slice if the URL's scheme is not HTTP or HTTPS.
func (j *Jar) Cookies(u *url.URL) (cookies []*http.Cookie) {
return j.cookies(u, time.Now())
}
// cookies is like Cookies but takes the current time as a parameter.
func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) {
if u.Scheme != "http" && u.Scheme != "https" {
return cookies
}
host, err := canonicalHost(u.Host)
if err != nil {
return cookies
}
key := jarKey(host, j.psList)
j.mu.Lock()
defer j.mu.Unlock()
submap := j.entries[key]
if submap == nil {
return cookies
}
https := u.Scheme == "https"
path := u.Path
if path == "" {
path = "/"
}
modified := false
var selected []entry
for id, e := range submap {
if e.Persistent && !e.Expires.After(now) {
delete(submap, id)
modified = true
continue
}
if !e.shouldSend(https, host, path) {
continue
}
e.LastAccess = now
submap[id] = e
selected = append(selected, e)
modified = true
}
if modified {
if len(submap) == 0 {
delete(j.entries, key)
} else {
j.entries[key] = submap
}
}
// sort according to RFC 6265 section 5.4 point 2: by longest
// path and then by earliest creation time.
sort.Slice(selected, func(i, j int) bool {
s := selected
if len(s[i].Path) != len(s[j].Path) {
return len(s[i].Path) > len(s[j].Path)
}
if ret := s[i].Creation.Compare(s[j].Creation); ret != 0 {
return ret < 0
}
return s[i].seqNum < s[j].seqNum
})
for _, e := range selected {
cookies = append(cookies, &http.Cookie{Name: e.Name, Value: e.Value})
}
return cookies
}
// SetCookies implements the SetCookies method of the http.CookieJar interface.
//
// It does nothing if the URL's scheme is not HTTP or HTTPS.
func (j *Jar) SetCookies(u *url.URL, cookies []*http.Cookie) {
j.setCookies(u, cookies, time.Now())
}
// setCookies is like SetCookies but takes the current time as a parameter.
func (j *Jar) setCookies(u *url.URL, cookies []*http.Cookie, now time.Time) {
if len(cookies) == 0 {
return
}
if u.Scheme != "http" && u.Scheme != "https" {
return
}
host, err := canonicalHost(u.Host)
if err != nil {
return
}
key := jarKey(host, j.psList)
defPath := defaultPath(u.Path)
j.mu.Lock()
defer j.mu.Unlock()
submap := j.entries[key]
modified := false
for _, cookie := range cookies {
e, remove, err := j.newEntry(cookie, now, defPath, host)
if err != nil {
continue
}
id := e.id()
if remove {
if submap != nil {
if _, ok := submap[id]; ok {
delete(submap, id)
modified = true
}
}
continue
}
if submap == nil {
submap = make(map[string]entry)
}
if old, ok := submap[id]; ok {
e.Creation = old.Creation
e.seqNum = old.seqNum
} else {
e.Creation = now
e.seqNum = j.nextSeqNum
j.nextSeqNum++
}
e.LastAccess = now
submap[id] = e
modified = true
}
if modified {
if len(submap) == 0 {
delete(j.entries, key)
} else {
j.entries[key] = submap
}
}
}
// canonicalHost strips port from host if present and returns the canonicalized
// host name.
func canonicalHost(host string) (string, error) {
var err error
if hasPort(host) {
host, _, err = net.SplitHostPort(host)
if err != nil {
return "", err
}
}
// Strip trailing dot from fully qualified domain names.
host = strings.TrimSuffix(host, ".")
encoded, err := toASCII(host)
if err != nil {
return "", err
}
// We know this is ASCII; no need to check.
lower, _ := ascii.ToLower(encoded)
return lower, nil
}
// hasPort reports whether host contains a port number. host may be a host
// name, an IPv4 or an IPv6 address.
func hasPort(host string) bool {
colons := strings.Count(host, ":")
if colons == 0 {
return false
}
if colons == 1 {
return true
}
return host[0] == '[' && strings.Contains(host, "]:")
}
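// For example:
//
//	hasPort("example.com")    // false
//	hasPort("example.com:80") // true
//	hasPort("::1")            // false (bare IPv6 address)
//	hasPort("[::1]:80")       // true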
// jarKey returns the key to use for a jar.
func jarKey(host string, psl PublicSuffixList) string {
if isIP(host) {
return host
}
var i int
if psl == nil {
i = strings.LastIndex(host, ".")
if i <= 0 {
return host
}
} else {
suffix := psl.PublicSuffix(host)
if suffix == host {
return host
}
i = len(host) - len(suffix)
if i <= 0 || host[i-1] != '.' {
// The provided public suffix list psl is broken.
// Storing cookies under host is a safe stopgap.
return host
}
// Only len(suffix) is used to determine the jar key from
// here on, so it is okay if psl.PublicSuffix("www.buggy.psl")
// returns "com" as the jar key is generated from host.
}
prevDot := strings.LastIndex(host[:i-1], ".")
return host[prevDot+1:]
}
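// For example, with a nil PublicSuffixList:
//
//	jarKey("www.example.com", nil) // "example.com"
//	jarKey("example.com", nil)     // "example.com"
//	jarKey("localhost", nil)       // "localhost"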
// isIP reports whether host is an IP address.
func isIP(host string) bool {
return net.ParseIP(host) != nil
}
// defaultPath returns the directory part of a URL's path according to
// RFC 6265 section 5.1.4.
func defaultPath(path string) string {
if len(path) == 0 || path[0] != '/' {
return "/" // Path is empty or malformed.
}
i := strings.LastIndex(path, "/") // Path starts with "/", so i != -1.
if i == 0 {
return "/" // Path has the form "/abc".
}
return path[:i] // Path is either of form "/abc/xyz" or "/abc/xyz/".
}
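// For example:
//
//	defaultPath("")          // "/"
//	defaultPath("/abc")      // "/"
//	defaultPath("/abc/xyz")  // "/abc"
//	defaultPath("/abc/xyz/") // "/abc/xyz"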
// newEntry creates an entry from a http.Cookie c. now is the current time and
// is compared to c.Expires to determine deletion of c. defPath and host are the
// default-path and the canonical host name of the URL c was received from.
//
// remove records whether the jar should delete this cookie, as it has already
// expired with respect to now. In this case, e may be incomplete, but it will
// be valid to call e.id (which depends on e's Name, Domain and Path).
//
// A malformed c.Domain will result in an error.
func (j *Jar) newEntry(c *http.Cookie, now time.Time, defPath, host string) (e entry, remove bool, err error) {
e.Name = c.Name
if c.Path == "" || c.Path[0] != '/' {
e.Path = defPath
} else {
e.Path = c.Path
}
e.Domain, e.HostOnly, err = j.domainAndType(host, c.Domain)
if err != nil {
return e, false, err
}
// MaxAge takes precedence over Expires.
if c.MaxAge < 0 {
return e, true, nil
} else if c.MaxAge > 0 {
e.Expires = now.Add(time.Duration(c.MaxAge) * time.Second)
e.Persistent = true
} else {
if c.Expires.IsZero() {
e.Expires = endOfTime
e.Persistent = false
} else {
if !c.Expires.After(now) {
return e, true, nil
}
e.Expires = c.Expires
e.Persistent = true
}
}
e.Value = c.Value
e.Secure = c.Secure
e.HttpOnly = c.HttpOnly
switch c.SameSite {
case http.SameSiteDefaultMode:
e.SameSite = "SameSite"
case http.SameSiteStrictMode:
e.SameSite = "SameSite=Strict"
case http.SameSiteLaxMode:
e.SameSite = "SameSite=Lax"
}
return e, false, nil
}
var (
errIllegalDomain = errors.New("cookiejar: illegal cookie domain attribute")
errMalformedDomain = errors.New("cookiejar: malformed cookie domain attribute")
errNoHostname = errors.New("cookiejar: no host name available (IP only)")
)
// endOfTime is the time when session (non-persistent) cookies expire.
// This instant is representable in most date/time formats (not just
// Go's time.Time) and should be far enough in the future.
var endOfTime = time.Date(9999, 12, 31, 23, 59, 59, 0, time.UTC)
// domainAndType determines the cookie's domain and hostOnly attribute.
func (j *Jar) domainAndType(host, domain string) (string, bool, error) {
if domain == "" {
// No domain attribute in the SetCookie header indicates a
// host cookie.
return host, true, nil
}
if isIP(host) {
// RFC 6265 is not super clear here; a sensible interpretation
// is that cookies with an IP address in the domain-attribute
// are allowed.
// RFC 6265 section 5.2.3 mandates to strip an optional leading
// dot in the domain-attribute before processing the cookie.
//
// Most browsers don't do that for IP addresses; only curl
// (version 7.54) and IE (version 11) do not reject a
// Set-Cookie: a=1; domain=.127.0.0.1
// This leading dot is optional and serves only as a hint for
// humans to indicate that a cookie with "domain=.bbc.co.uk"
// would be sent to every subdomain of bbc.co.uk.
// It just doesn't make sense on IP addresses.
// The other processing and validation steps in RFC 6265 just
// collapse to:
if host != domain {
return "", false, errIllegalDomain
}
// According to RFC 6265 such cookies should be treated as
// domain cookies.
// As there are no subdomains of an IP address, the treatment
// according to RFC 6265 would be exactly the same as that of
// a host-only cookie. Contemporary browsers (and curl) do
// allow such cookies but treat them as host-only cookies.
// So do we, as it just doesn't make sense to label them as
// domain cookies when there is no domain; the whole notion of
// domain cookies requires a domain name to be well defined.
return host, true, nil
}
// From here on: If the cookie is valid, it is a domain cookie (with
// the one exception of a public suffix below).
// See RFC 6265 section 5.2.3.
if domain[0] == '.' {
domain = domain[1:]
}
if len(domain) == 0 || domain[0] == '.' {
// Received either "Domain=." or "Domain=..some.thing",
// both are illegal.
return "", false, errMalformedDomain
}
domain, isASCII := ascii.ToLower(domain)
if !isASCII {
// Received non-ASCII domain, e.g. "perché.com" instead of "xn--perch-fsa.com"
return "", false, errMalformedDomain
}
if domain[len(domain)-1] == '.' {
// We received stuff like "Domain=www.example.com.".
// Browsers do handle such stuff (actually differently) but
// RFC 6265 seems to be clear here (e.g. section 4.1.2.3) in
// requiring a reject. 4.1.2.3 is not normative, but
// "Domain Matching" (5.1.3) and "Canonicalized Host Names"
// (5.1.2) are.
return "", false, errMalformedDomain
}
// See RFC 6265 section 5.3 #5.
if j.psList != nil {
if ps := j.psList.PublicSuffix(domain); ps != "" && !hasDotSuffix(domain, ps) {
if host == domain {
// This is the one exception in which a cookie
// with a domain attribute is a host cookie.
return host, true, nil
}
return "", false, errIllegalDomain
}
}
// The domain must domain-match host: www.mycompany.com cannot
// set cookies for .ourcompetitors.com.
if host != domain && !hasDotSuffix(host, domain) {
return "", false, errIllegalDomain
}
return domain, false, nil
}
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cookiejar
// This file implements the Punycode algorithm from RFC 3492.
import (
"fmt"
"net/http/internal/ascii"
"strings"
"unicode/utf8"
)
// These parameter values are specified in section 5.
//
// All computation is done with int32s, so that overflow behavior is identical
// regardless of whether int is 32-bit or 64-bit.
const (
base int32 = 36
damp int32 = 700
initialBias int32 = 72
initialN int32 = 128
skew int32 = 38
tmax int32 = 26
tmin int32 = 1
)
// encode encodes a string as specified in section 6.3 and prepends prefix to
// the result.
//
// The "while h < length(input)" line in the specification becomes "for
// remaining != 0" in the Go code, because len(s) in Go is in bytes, not runes.
func encode(prefix, s string) (string, error) {
output := make([]byte, len(prefix), len(prefix)+1+2*len(s))
copy(output, prefix)
delta, n, bias := int32(0), initialN, initialBias
b, remaining := int32(0), int32(0)
for _, r := range s {
if r < utf8.RuneSelf {
b++
output = append(output, byte(r))
} else {
remaining++
}
}
h := b
if b > 0 {
output = append(output, '-')
}
for remaining != 0 {
m := int32(0x7fffffff)
for _, r := range s {
if m > r && r >= n {
m = r
}
}
delta += (m - n) * (h + 1)
if delta < 0 {
return "", fmt.Errorf("cookiejar: invalid label %q", s)
}
n = m
for _, r := range s {
if r < n {
delta++
if delta < 0 {
return "", fmt.Errorf("cookiejar: invalid label %q", s)
}
continue
}
if r > n {
continue
}
q := delta
for k := base; ; k += base {
t := k - bias
if t < tmin {
t = tmin
} else if t > tmax {
t = tmax
}
if q < t {
break
}
output = append(output, encodeDigit(t+(q-t)%(base-t)))
q = (q - t) / (base - t)
}
output = append(output, encodeDigit(q))
bias = adapt(delta, h+1, h == b)
delta = 0
h++
remaining--
}
delta++
n++
}
return string(output), nil
}
func encodeDigit(digit int32) byte {
switch {
case 0 <= digit && digit < 26:
return byte(digit + 'a')
case 26 <= digit && digit < 36:
return byte(digit + ('0' - 26))
}
panic("cookiejar: internal error in punycode encoding")
}
// adapt is the bias adaptation function specified in section 6.1.
func adapt(delta, numPoints int32, firstTime bool) int32 {
if firstTime {
delta /= damp
} else {
delta /= 2
}
delta += delta / numPoints
k := int32(0)
for delta > ((base-tmin)*tmax)/2 {
delta /= base - tmin
k += base
}
return k + (base-tmin+1)*delta/(delta+skew)
}
// Strictly speaking, the remaining code below deals with IDNA (RFC 5890 and
// friends) and not Punycode (RFC 3492) per se.
// acePrefix is the ASCII Compatible Encoding prefix.
const acePrefix = "xn--"
// toASCII converts a domain or domain label to its ASCII form. For example,
// toASCII("bücher.example.com") is "xn--bcher-kva.example.com", and
// toASCII("golang") is "golang".
func toASCII(s string) (string, error) {
if ascii.Is(s) {
return s, nil
}
labels := strings.Split(s, ".")
for i, label := range labels {
if !ascii.Is(label) {
a, err := encode(acePrefix, label)
if err != nil {
return "", err
}
labels[i] = a
}
}
return strings.Join(labels, "."), nil
}
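// A minimal sketch of toASCII's behavior, using the mappings stated in its
// doc comment (package-internal, since toASCII is unexported):
//
//	host, _ := toASCII("bücher.example.com")
//	// host == "xn--bcher-kva.example.com"
//
//	host, _ = toASCII("golang")
//	// host == "golang": all-ASCII input is returned unchanged.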
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fcgi
// This file implements FastCGI from the perspective of a child process.
import (
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/cgi"
"os"
"strings"
"time"
)
// request holds the state for an in-progress request. As soon as it's complete,
// it's converted to an http.Request.
type request struct {
pw *io.PipeWriter
reqId uint16
params map[string]string
buf [1024]byte
rawParams []byte
keepConn bool
}
// envVarsContextKey uniquely identifies a mapping of CGI
// environment variables to their values in a request context.
type envVarsContextKey struct{}
func newRequest(reqId uint16, flags uint8) *request {
r := &request{
reqId: reqId,
params: map[string]string{},
keepConn: flags&flagKeepConn != 0,
}
r.rawParams = r.buf[:0]
return r
}
// parseParams reads an encoded []byte into r.params.
func (r *request) parseParams() {
text := r.rawParams
r.rawParams = nil
for len(text) > 0 {
keyLen, n := readSize(text)
if n == 0 {
return
}
text = text[n:]
valLen, n := readSize(text)
if n == 0 {
return
}
text = text[n:]
if int(keyLen)+int(valLen) > len(text) {
return
}
key := readString(text, keyLen)
text = text[keyLen:]
val := readString(text, valLen)
text = text[valLen:]
r.params[key] = val
}
}
// response implements http.ResponseWriter.
type response struct {
req *request
header http.Header
code int
wroteHeader bool
wroteCGIHeader bool
w *bufWriter
}
func newResponse(c *child, req *request) *response {
return &response{
req: req,
header: http.Header{},
w: newWriter(c.conn, typeStdout, req.reqId),
}
}
func (r *response) Header() http.Header {
return r.header
}
func (r *response) Write(p []byte) (n int, err error) {
if !r.wroteHeader {
r.WriteHeader(http.StatusOK)
}
if !r.wroteCGIHeader {
r.writeCGIHeader(p)
}
return r.w.Write(p)
}
func (r *response) WriteHeader(code int) {
if r.wroteHeader {
return
}
r.wroteHeader = true
r.code = code
if code == http.StatusNotModified {
// Must not have body.
r.header.Del("Content-Type")
r.header.Del("Content-Length")
r.header.Del("Transfer-Encoding")
}
if r.header.Get("Date") == "" {
r.header.Set("Date", time.Now().UTC().Format(http.TimeFormat))
}
}
// writeCGIHeader finalizes the header sent to the client and writes it to the output.
// p is not written by writeCGIHeader, but is the first chunk of the body
// that will be written. It is sniffed for a Content-Type if none is
// set explicitly.
func (r *response) writeCGIHeader(p []byte) {
if r.wroteCGIHeader {
return
}
r.wroteCGIHeader = true
fmt.Fprintf(r.w, "Status: %d %s\r\n", r.code, http.StatusText(r.code))
if _, hasType := r.header["Content-Type"]; r.code != http.StatusNotModified && !hasType {
r.header.Set("Content-Type", http.DetectContentType(p))
}
r.header.Write(r.w)
r.w.WriteString("\r\n")
r.w.Flush()
}
func (r *response) Flush() {
if !r.wroteHeader {
r.WriteHeader(http.StatusOK)
}
r.w.Flush()
}
func (r *response) Close() error {
r.Flush()
return r.w.Close()
}
type child struct {
conn *conn
handler http.Handler
requests map[uint16]*request // keyed by request ID
}
func newChild(rwc io.ReadWriteCloser, handler http.Handler) *child {
return &child{
conn: newConn(rwc),
handler: handler,
requests: make(map[uint16]*request),
}
}
func (c *child) serve() {
defer c.conn.Close()
defer c.cleanUp()
var rec record
for {
if err := rec.read(c.conn.rwc); err != nil {
return
}
if err := c.handleRecord(&rec); err != nil {
return
}
}
}
var errCloseConn = errors.New("fcgi: connection should be closed")
var emptyBody = io.NopCloser(strings.NewReader(""))
// ErrRequestAborted is returned by Read when a handler attempts to read the
// body of a request that has been aborted by the web server.
var ErrRequestAborted = errors.New("fcgi: request aborted by web server")
// ErrConnClosed is returned by Read when a handler attempts to read the body of
// a request after the connection to the web server has been closed.
var ErrConnClosed = errors.New("fcgi: connection to web server closed")
func (c *child) handleRecord(rec *record) error {
req, ok := c.requests[rec.h.Id]
if !ok && rec.h.Type != typeBeginRequest && rec.h.Type != typeGetValues {
// The spec says to ignore unknown request IDs.
return nil
}
switch rec.h.Type {
case typeBeginRequest:
if req != nil {
// The server is trying to begin a request with the same ID
// as an in-progress request. This is an error.
return errors.New("fcgi: received ID that is already in-flight")
}
var br beginRequest
if err := br.read(rec.content()); err != nil {
return err
}
if br.role != roleResponder {
c.conn.writeEndRequest(rec.h.Id, 0, statusUnknownRole)
return nil
}
req = newRequest(rec.h.Id, br.flags)
c.requests[rec.h.Id] = req
return nil
case typeParams:
// NOTE(eds): Technically a key-value pair can straddle the boundary
// between two packets. We buffer until we've received all parameters.
if len(rec.content()) > 0 {
req.rawParams = append(req.rawParams, rec.content()...)
return nil
}
req.parseParams()
return nil
case typeStdin:
content := rec.content()
if req.pw == nil {
var body io.ReadCloser
if len(content) > 0 {
// body could be an io.LimitReader, but it shouldn't matter
// as long as both sides are behaving.
body, req.pw = io.Pipe()
} else {
body = emptyBody
}
go c.serveRequest(req, body)
}
if len(content) > 0 {
// TODO(eds): This blocks until the handler reads from the pipe.
// If the handler takes a long time, it might be a problem.
req.pw.Write(content)
} else {
delete(c.requests, req.reqId)
if req.pw != nil {
req.pw.Close()
}
}
return nil
case typeGetValues:
values := map[string]string{"FCGI_MPXS_CONNS": "1"}
c.conn.writePairs(typeGetValuesResult, 0, values)
return nil
case typeData:
// If the filter role is implemented, read the data stream here.
return nil
case typeAbortRequest:
delete(c.requests, rec.h.Id)
c.conn.writeEndRequest(rec.h.Id, 0, statusRequestComplete)
if req.pw != nil {
req.pw.CloseWithError(ErrRequestAborted)
}
if !req.keepConn {
// connection will close upon return
return errCloseConn
}
return nil
default:
b := make([]byte, 8)
b[0] = byte(rec.h.Type)
c.conn.writeRecord(typeUnknownType, 0, b)
return nil
}
}
// filterOutUsedEnvVars returns a new map containing only the env vars from
// envVars that were not already consumed when building each http.Request.
func filterOutUsedEnvVars(envVars map[string]string) map[string]string {
withoutUsedEnvVars := make(map[string]string)
for k, v := range envVars {
if addFastCGIEnvToContext(k) {
withoutUsedEnvVars[k] = v
}
}
return withoutUsedEnvVars
}
func (c *child) serveRequest(req *request, body io.ReadCloser) {
r := newResponse(c, req)
httpReq, err := cgi.RequestFromMap(req.params)
if err != nil {
// there was an error reading the request
r.WriteHeader(http.StatusInternalServerError)
c.conn.writeRecord(typeStderr, req.reqId, []byte(err.Error()))
} else {
httpReq.Body = body
withoutUsedEnvVars := filterOutUsedEnvVars(req.params)
envVarCtx := context.WithValue(httpReq.Context(), envVarsContextKey{}, withoutUsedEnvVars)
httpReq = httpReq.WithContext(envVarCtx)
c.handler.ServeHTTP(r, httpReq)
}
// Make sure we serve something even if nothing was written to r
r.Write(nil)
r.Close()
c.conn.writeEndRequest(req.reqId, 0, statusRequestComplete)
// Consume the entire body, so the host isn't still writing to
// us when we close the socket below in the !keepConn case,
// otherwise we'd send a RST. (golang.org/issue/4183)
// TODO(bradfitz): also bound this copy in time. Or send
// some sort of abort request to the host, so the host
// can properly cut off the client sending all the data.
// For now just bound it a little.
io.CopyN(io.Discard, body, 100<<20) // 100 MB
body.Close()
if !req.keepConn {
c.conn.Close()
}
}
func (c *child) cleanUp() {
for _, req := range c.requests {
if req.pw != nil {
// race with call to Close in c.serveRequest doesn't matter because
// Pipe(Reader|Writer).Close are idempotent
req.pw.CloseWithError(ErrConnClosed)
}
}
}
// Serve accepts incoming FastCGI connections on the listener l, creating a new
// goroutine for each. The goroutine reads requests and then calls handler
// to reply to them.
// If l is nil, Serve accepts connections from os.Stdin.
// If handler is nil, http.DefaultServeMux is used.
func Serve(l net.Listener, handler http.Handler) error {
if l == nil {
var err error
l, err = net.FileListener(os.Stdin)
if err != nil {
return err
}
defer l.Close()
}
if handler == nil {
handler = http.DefaultServeMux
}
for {
rw, err := l.Accept()
if err != nil {
return err
}
c := newChild(rw, handler)
go c.serve()
}
}
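// A minimal usage sketch for Serve. The socket path and handler are
// hypothetical; a front-end web server (e.g. nginx) would be configured to
// forward FastCGI requests to the socket:
//
//	ln, err := net.Listen("unix", "/tmp/app.sock") // hypothetical path
//	if err != nil {
//		log.Fatal(err)
//	}
//	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
//		fmt.Fprintln(w, "Hello from FastCGI")
//	})
//	log.Fatal(fcgi.Serve(ln, nil)) // nil handler uses http.DefaultServeMux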
// ProcessEnv returns FastCGI environment variables associated with the request r
// that were not incorporated into the request itself; the data is hidden in the
// request's context. As an example, if REMOTE_USER is set for a request, it will
// not be found anywhere in r, but it will be included in ProcessEnv's response
// (via r's context).
func ProcessEnv(r *http.Request) map[string]string {
env, _ := r.Context().Value(envVarsContextKey{}).(map[string]string)
return env
}
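// A hypothetical sketch of reading a FastCGI-only variable inside a handler;
// REMOTE_USER is the example used in the doc comment above:
//
//	func handler(w http.ResponseWriter, r *http.Request) {
//		env := fcgi.ProcessEnv(r)
//		if user, ok := env["REMOTE_USER"]; ok {
//			fmt.Fprintf(w, "authenticated as %s\n", user)
//		}
//	}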
// addFastCGIEnvToContext reports whether to include the FastCGI environment variable s
// in the http.Request.Context, accessible via ProcessEnv.
func addFastCGIEnvToContext(s string) bool {
// Exclude things supported by net/http natively:
switch s {
case "CONTENT_LENGTH", "CONTENT_TYPE", "HTTPS",
"PATH_INFO", "QUERY_STRING", "REMOTE_ADDR",
"REMOTE_HOST", "REMOTE_PORT", "REQUEST_METHOD",
"REQUEST_URI", "SCRIPT_NAME", "SERVER_PROTOCOL":
return false
}
if strings.HasPrefix(s, "HTTP_") {
return false
}
// Explicitly include FastCGI-specific things.
// This list is redundant with the default "return true" below.
// Consider this documentation of the sorts of things we expect
// to maybe see.
switch s {
case "REMOTE_USER":
return true
}
// Unknown, so include it to be safe.
return true
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package fcgi implements the FastCGI protocol.
//
// See https://fast-cgi.github.io/ for an unofficial mirror of the
// original documentation.
//
// Currently only the responder role is supported.
package fcgi
// This file defines the raw protocol and some utilities used by the child and
// the host.
import (
"bufio"
"bytes"
"encoding/binary"
"errors"
"io"
"sync"
)
// recType is a record type, as defined by
// https://web.archive.org/web/20150420080736/http://www.fastcgi.com/drupal/node/6?q=node/22#S8
type recType uint8
const (
typeBeginRequest recType = 1
typeAbortRequest recType = 2
typeEndRequest recType = 3
typeParams recType = 4
typeStdin recType = 5
typeStdout recType = 6
typeStderr recType = 7
typeData recType = 8
typeGetValues recType = 9
typeGetValuesResult recType = 10
typeUnknownType recType = 11
)
// keep the connection between the web server and the responder open after the request
const flagKeepConn = 1
const (
maxWrite = 65535 // maximum record body
maxPad = 255
)
const (
roleResponder = iota + 1 // only Responders are implemented.
roleAuthorizer
roleFilter
)
const (
statusRequestComplete = iota
statusCantMultiplex
statusOverloaded
statusUnknownRole
)
type header struct {
Version uint8
Type recType
Id uint16
ContentLength uint16
PaddingLength uint8
Reserved uint8
}
type beginRequest struct {
role uint16
flags uint8
reserved [5]uint8
}
func (br *beginRequest) read(content []byte) error {
if len(content) != 8 {
return errors.New("fcgi: invalid begin request record")
}
br.role = binary.BigEndian.Uint16(content)
br.flags = content[2]
return nil
}
// for padding so we don't have to allocate all the time
// not synchronized because we don't care what the contents are
var pad [maxPad]byte
func (h *header) init(recType recType, reqId uint16, contentLength int) {
h.Version = 1
h.Type = recType
h.Id = reqId
h.ContentLength = uint16(contentLength)
h.PaddingLength = uint8(-contentLength & 7)
}
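// The padding computation above rounds each record body up to a multiple of
// eight bytes: for example, a contentLength of 5 yields
// uint8(-5&7) == 3 bytes of padding, so 5+3 = 8 bytes go on the wire.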
// conn sends records over rwc
type conn struct {
mutex sync.Mutex
rwc io.ReadWriteCloser
// to avoid allocations
buf bytes.Buffer
h header
}
func newConn(rwc io.ReadWriteCloser) *conn {
return &conn{rwc: rwc}
}
func (c *conn) Close() error {
c.mutex.Lock()
defer c.mutex.Unlock()
return c.rwc.Close()
}
type record struct {
h header
buf [maxWrite + maxPad]byte
}
func (rec *record) read(r io.Reader) (err error) {
if err = binary.Read(r, binary.BigEndian, &rec.h); err != nil {
return err
}
if rec.h.Version != 1 {
return errors.New("fcgi: invalid header version")
}
n := int(rec.h.ContentLength) + int(rec.h.PaddingLength)
if _, err = io.ReadFull(r, rec.buf[:n]); err != nil {
return err
}
return nil
}
func (r *record) content() []byte {
return r.buf[:r.h.ContentLength]
}
// writeRecord writes and sends a single record.
func (c *conn) writeRecord(recType recType, reqId uint16, b []byte) error {
c.mutex.Lock()
defer c.mutex.Unlock()
c.buf.Reset()
c.h.init(recType, reqId, len(b))
if err := binary.Write(&c.buf, binary.BigEndian, c.h); err != nil {
return err
}
if _, err := c.buf.Write(b); err != nil {
return err
}
if _, err := c.buf.Write(pad[:c.h.PaddingLength]); err != nil {
return err
}
_, err := c.rwc.Write(c.buf.Bytes())
return err
}
func (c *conn) writeEndRequest(reqId uint16, appStatus int, protocolStatus uint8) error {
b := make([]byte, 8)
binary.BigEndian.PutUint32(b, uint32(appStatus))
b[4] = protocolStatus
return c.writeRecord(typeEndRequest, reqId, b)
}
func (c *conn) writePairs(recType recType, reqId uint16, pairs map[string]string) error {
w := newWriter(c, recType, reqId)
b := make([]byte, 8)
for k, v := range pairs {
n := encodeSize(b, uint32(len(k)))
n += encodeSize(b[n:], uint32(len(v)))
if _, err := w.Write(b[:n]); err != nil {
return err
}
if _, err := w.WriteString(k); err != nil {
return err
}
if _, err := w.WriteString(v); err != nil {
return err
}
}
w.Close()
return nil
}
func readSize(s []byte) (uint32, int) {
if len(s) == 0 {
return 0, 0
}
size, n := uint32(s[0]), 1
if size&(1<<7) != 0 {
if len(s) < 4 {
return 0, 0
}
n = 4
size = binary.BigEndian.Uint32(s)
size &^= 1 << 31
}
return size, n
}
func readString(s []byte, size uint32) string {
if size > uint32(len(s)) {
return ""
}
return string(s[:size])
}
func encodeSize(b []byte, size uint32) int {
if size > 127 {
size |= 1 << 31
binary.BigEndian.PutUint32(b, size)
return 4
}
b[0] = byte(size)
return 1
}
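// readSize and encodeSize implement the FastCGI name-value length encoding:
// lengths up to 127 occupy one byte, while longer lengths use four bytes with
// the top bit set. A sketch of the round trip:
//
//	var b [4]byte
//	n := encodeSize(b[:], 130) // n == 4, b == [4]byte{0x80, 0x00, 0x00, 0x82}
//	size, m := readSize(b[:n]) // size == 130, m == 4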
// bufWriter encapsulates bufio.Writer but also closes the underlying stream when
// it is closed.
type bufWriter struct {
closer io.Closer
*bufio.Writer
}
func (w *bufWriter) Close() error {
if err := w.Writer.Flush(); err != nil {
w.closer.Close()
return err
}
return w.closer.Close()
}
func newWriter(c *conn, recType recType, reqId uint16) *bufWriter {
s := &streamWriter{c: c, recType: recType, reqId: reqId}
w := bufio.NewWriterSize(s, maxWrite)
return &bufWriter{s, w}
}
// streamWriter abstracts out the separation of a stream into discrete records.
// It only writes maxWrite bytes at a time.
type streamWriter struct {
c *conn
recType recType
reqId uint16
}
func (w *streamWriter) Write(p []byte) (int, error) {
nn := 0
for len(p) > 0 {
n := len(p)
if n > maxWrite {
n = maxWrite
}
if err := w.c.writeRecord(w.recType, w.reqId, p[:n]); err != nil {
return nn, err
}
nn += n
p = p[n:]
}
return nn, nil
}
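// Because Write emits at most maxWrite (65535) bytes per record, a single
// 70000-byte write, for example, becomes one 65535-byte record followed by
// one 4465-byte record.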
func (w *streamWriter) Close() error {
// send empty record to close the stream
return w.c.writeRecord(w.recType, w.reqId, nil)
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http
import (
"fmt"
"io"
)
// fileTransport implements RoundTripper for the 'file' protocol.
type fileTransport struct {
fh fileHandler
}
// NewFileTransport returns a new RoundTripper, serving the provided
// FileSystem. The returned RoundTripper ignores the URL host in its
// incoming requests, as well as most other properties of the
// request.
//
// The typical use case for NewFileTransport is to register the "file"
// protocol with a Transport, as in:
//
// t := &http.Transport{}
// t.RegisterProtocol("file", http.NewFileTransport(http.Dir("/")))
// c := &http.Client{Transport: t}
// res, err := c.Get("file:///etc/passwd")
// ...
func NewFileTransport(fs FileSystem) RoundTripper {
return fileTransport{fileHandler{fs}}
}
func (t fileTransport) RoundTrip(req *Request) (resp *Response, err error) {
// We start ServeHTTP in a goroutine, which may take a long
// time if the file is large. The newPopulateResponseWriter
// call returns a channel which either ServeHTTP or finish()
// sends our *Response on, once the *Response itself has been
// populated (even if the body itself is still being
// written to res.Body, a pipe).
rw, resc := newPopulateResponseWriter()
go func() {
t.fh.ServeHTTP(rw, req)
rw.finish()
}()
return <-resc, nil
}
func newPopulateResponseWriter() (*populateResponse, <-chan *Response) {
pr, pw := io.Pipe()
rw := &populateResponse{
ch: make(chan *Response),
pw: pw,
res: &Response{
Proto: "HTTP/1.0",
ProtoMajor: 1,
Header: make(Header),
Close: true,
Body: pr,
},
}
return rw, rw.ch
}
// populateResponse is a ResponseWriter that populates the *Response
// in res, and writes its body to a pipe connected to the response
// body. Once writes begin or finish() is called, the response is sent
// on ch.
type populateResponse struct {
res *Response
ch chan *Response
wroteHeader bool
hasContent bool
sentResponse bool
pw *io.PipeWriter
}
func (pr *populateResponse) finish() {
if !pr.wroteHeader {
pr.WriteHeader(500)
}
if !pr.sentResponse {
pr.sendResponse()
}
pr.pw.Close()
}
func (pr *populateResponse) sendResponse() {
if pr.sentResponse {
return
}
pr.sentResponse = true
if pr.hasContent {
pr.res.ContentLength = -1
}
pr.ch <- pr.res
}
func (pr *populateResponse) Header() Header {
return pr.res.Header
}
func (pr *populateResponse) WriteHeader(code int) {
if pr.wroteHeader {
return
}
pr.wroteHeader = true
pr.res.StatusCode = code
pr.res.Status = fmt.Sprintf("%d %s", code, StatusText(code))
}
func (pr *populateResponse) Write(p []byte) (n int, err error) {
if !pr.wroteHeader {
pr.WriteHeader(StatusOK)
}
pr.hasContent = true
if !pr.sentResponse {
pr.sendResponse()
}
return pr.pw.Write(p)
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// HTTP file system request handler
package http
import (
"errors"
"fmt"
"internal/safefilepath"
"io"
"io/fs"
"mime"
"mime/multipart"
"net/textproto"
"net/url"
"os"
"path"
"path/filepath"
"sort"
"strconv"
"strings"
"time"
)
// A Dir implements FileSystem using the native file system restricted to a
// specific directory tree.
//
// While the FileSystem.Open method takes '/'-separated paths, a Dir's string
// value is a filename on the native file system, not a URL, so it is separated
// by filepath.Separator, which isn't necessarily '/'.
//
// Note that Dir could expose sensitive files and directories. Dir will follow
// symlinks pointing out of the directory tree, which can be especially dangerous
// if serving from a directory in which users are able to create arbitrary symlinks.
// Dir will also allow access to files and directories starting with a period,
// which could expose sensitive directories like .git or sensitive files like
// .htpasswd. To exclude files with a leading period, remove the files/directories
// from the server or create a custom FileSystem implementation.
//
// An empty Dir is treated as ".".
type Dir string
// mapOpenError maps the provided non-nil error from opening name
// to a possibly better non-nil error. In particular, it turns OS-specific errors
// about opening files in non-directories into fs.ErrNotExist. See Issues 18984 and 49552.
func mapOpenError(originalErr error, name string, sep rune, stat func(string) (fs.FileInfo, error)) error {
if errors.Is(originalErr, fs.ErrNotExist) || errors.Is(originalErr, fs.ErrPermission) {
return originalErr
}
parts := strings.Split(name, string(sep))
for i := range parts {
if parts[i] == "" {
continue
}
fi, err := stat(strings.Join(parts[:i+1], string(sep)))
if err != nil {
return originalErr
}
if !fi.IsDir() {
return fs.ErrNotExist
}
}
return originalErr
}
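// For example, if "a" exists as a regular file, opening "a/b" fails with an
// OS-specific error on some platforms; statting each path component in turn
// lets mapOpenError report fs.ErrNotExist instead.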
// Open implements FileSystem using os.Open, opening files for reading rooted
// and relative to the directory d.
func (d Dir) Open(name string) (File, error) {
path, err := safefilepath.FromFS(path.Clean("/" + name))
if err != nil {
return nil, errors.New("http: invalid or unsafe file path")
}
dir := string(d)
if dir == "" {
dir = "."
}
fullName := filepath.Join(dir, path)
f, err := os.Open(fullName)
if err != nil {
return nil, mapOpenError(err, fullName, filepath.Separator, os.Stat)
}
return f, nil
}
// A FileSystem implements access to a collection of named files.
// The elements in a file path are separated by slash ('/', U+002F)
// characters, regardless of host operating system convention.
// See the FileServer function to convert a FileSystem to a Handler.
//
// This interface predates the fs.FS interface, which can be used instead:
// the FS adapter function converts an fs.FS to a FileSystem.
type FileSystem interface {
Open(name string) (File, error)
}
// A File is returned by a FileSystem's Open method and can be
// served by the FileServer implementation.
//
// The methods should behave the same as those on an *os.File.
type File interface {
io.Closer
io.Reader
io.Seeker
Readdir(count int) ([]fs.FileInfo, error)
Stat() (fs.FileInfo, error)
}
type anyDirs interface {
len() int
name(i int) string
isDir(i int) bool
}
type fileInfoDirs []fs.FileInfo
func (d fileInfoDirs) len() int { return len(d) }
func (d fileInfoDirs) isDir(i int) bool { return d[i].IsDir() }
func (d fileInfoDirs) name(i int) string { return d[i].Name() }
type dirEntryDirs []fs.DirEntry
func (d dirEntryDirs) len() int { return len(d) }
func (d dirEntryDirs) isDir(i int) bool { return d[i].IsDir() }
func (d dirEntryDirs) name(i int) string { return d[i].Name() }
func dirList(w ResponseWriter, r *Request, f File) {
// Prefer to use ReadDir instead of Readdir,
// because the former doesn't require calling
// Stat on every entry of a directory on Unix.
var dirs anyDirs
var err error
if d, ok := f.(fs.ReadDirFile); ok {
var list dirEntryDirs
list, err = d.ReadDir(-1)
dirs = list
} else {
var list fileInfoDirs
list, err = f.Readdir(-1)
dirs = list
}
if err != nil {
logf(r, "http: error reading directory: %v", err)
Error(w, "Error reading directory", StatusInternalServerError)
return
}
sort.Slice(dirs, func(i, j int) bool { return dirs.name(i) < dirs.name(j) })
w.Header().Set("Content-Type", "text/html; charset=utf-8")
fmt.Fprintf(w, "<pre>\n")
for i, n := 0, dirs.len(); i < n; i++ {
name := dirs.name(i)
if dirs.isDir(i) {
name += "/"
}
// name may contain '?' or '#', which must be escaped to remain
// part of the URL path, and not indicate the start of a query
// string or fragment.
url := url.URL{Path: name}
fmt.Fprintf(w, "<a href=\"%s\">%s</a>\n", url.String(), htmlReplacer.Replace(name))
}
fmt.Fprintf(w, "</pre>\n")
}
// ServeContent replies to the request using the content in the
// provided ReadSeeker. The main benefit of ServeContent over io.Copy
// is that it handles Range requests properly, sets the MIME type, and
// handles If-Match, If-Unmodified-Since, If-None-Match, If-Modified-Since,
// and If-Range requests.
//
// If the response's Content-Type header is not set, ServeContent
// first tries to deduce the type from name's file extension and,
// if that fails, falls back to reading the first block of the content
// and passing it to DetectContentType.
// The name is otherwise unused; in particular it can be empty and is
// never sent in the response.
//
// If modtime is not the zero time or Unix epoch, ServeContent
// includes it in a Last-Modified header in the response. If the
// request includes an If-Modified-Since header, ServeContent uses
// modtime to decide whether the content needs to be sent at all.
//
// The content's Seek method must work: ServeContent uses
// a seek to the end of the content to determine its size.
//
// If the caller has set w's ETag header formatted per RFC 7232, section 2.3,
// ServeContent uses it to handle requests using If-Match, If-None-Match, or If-Range.
//
// Note that *os.File implements the io.ReadSeeker interface.
func ServeContent(w ResponseWriter, req *Request, name string, modtime time.Time, content io.ReadSeeker) {
sizeFunc := func() (int64, error) {
size, err := content.Seek(0, io.SeekEnd)
if err != nil {
return 0, errSeeker
}
_, err = content.Seek(0, io.SeekStart)
if err != nil {
return 0, errSeeker
}
return size, nil
}
serveContent(w, req, name, modtime, sizeFunc, content)
}
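// A minimal sketch of calling ServeContent with an in-memory payload; the
// data, name, and modification time are hypothetical:
//
//	func handler(w http.ResponseWriter, r *http.Request) {
//		data := []byte("hello, range requests")
//		modtime := time.Now() // hypothetical; use the content's real mtime
//		http.ServeContent(w, r, "hello.txt", modtime, bytes.NewReader(data))
//	}
//
// bytes.NewReader satisfies io.ReadSeeker, so ServeContent can determine the
// size by seeking to the end, as described above.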
// errSeeker is returned by ServeContent's sizeFunc when the content
// doesn't seek properly. The underlying Seeker's error text isn't
// included in the sizeFunc reply so it's not sent over HTTP to end
// users.
var errSeeker = errors.New("seeker can't seek")
// errNoOverlap is returned by serveContent's parseRange if first-byte-pos of
// all of the byte-range-spec values is greater than the content size.
var errNoOverlap = errors.New("invalid range: failed to overlap")
// if name is empty, filename is unknown. (used for mime type, before sniffing)
// if modtime.IsZero(), modtime is unknown.
// content must be seeked to the beginning of the file.
// The sizeFunc is called at most once. Its error, if any, is sent in the HTTP response.
func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, sizeFunc func() (int64, error), content io.ReadSeeker) {
setLastModified(w, modtime)
done, rangeReq := checkPreconditions(w, r, modtime)
if done {
return
}
code := StatusOK
// If Content-Type isn't set, use the file's extension to find it, but
// if the Content-Type is unset explicitly, do not sniff the type.
ctypes, haveType := w.Header()["Content-Type"]
var ctype string
if !haveType {
ctype = mime.TypeByExtension(filepath.Ext(name))
if ctype == "" {
// read a chunk to decide between utf-8 text and binary
var buf [sniffLen]byte
n, _ := io.ReadFull(content, buf[:])
ctype = DetectContentType(buf[:n])
_, err := content.Seek(0, io.SeekStart) // rewind to output whole file
if err != nil {
Error(w, "seeker can't seek", StatusInternalServerError)
return
}
}
w.Header().Set("Content-Type", ctype)
} else if len(ctypes) > 0 {
ctype = ctypes[0]
}
size, err := sizeFunc()
if err != nil {
Error(w, err.Error(), StatusInternalServerError)
return
}
if size < 0 {
// Should never happen but just to be sure
Error(w, "negative content size computed", StatusInternalServerError)
return
}
// handle Content-Range header.
sendSize := size
var sendContent io.Reader = content
ranges, err := parseRange(rangeReq, size)
switch err {
case nil:
case errNoOverlap:
if size == 0 {
// Some clients add a Range header to all requests to
// limit the size of the response. If the file is empty,
// ignore the range header and respond with a 200 rather
// than a 416.
ranges = nil
break
}
w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", size))
fallthrough
default:
Error(w, err.Error(), StatusRequestedRangeNotSatisfiable)
return
}
if sumRangesSize(ranges) > size {
// The total number of bytes in all the ranges
// is larger than the size of the file by
// itself, so this is probably an attack, or a
// dumb client. Ignore the range request.
ranges = nil
}
switch {
case len(ranges) == 1:
// RFC 7233, Section 4.1:
// "If a single part is being transferred, the server
// generating the 206 response MUST generate a
// Content-Range header field, describing what range
// of the selected representation is enclosed, and a
// payload consisting of the range.
// ...
// A server MUST NOT generate a multipart response to
// a request for a single range, since a client that
// does not request multiple parts might not support
// multipart responses."
ra := ranges[0]
if _, err := content.Seek(ra.start, io.SeekStart); err != nil {
Error(w, err.Error(), StatusRequestedRangeNotSatisfiable)
return
}
sendSize = ra.length
code = StatusPartialContent
w.Header().Set("Content-Range", ra.contentRange(size))
case len(ranges) > 1:
sendSize = rangesMIMESize(ranges, ctype, size)
code = StatusPartialContent
pr, pw := io.Pipe()
mw := multipart.NewWriter(pw)
w.Header().Set("Content-Type", "multipart/byteranges; boundary="+mw.Boundary())
sendContent = pr
defer pr.Close() // cause writing goroutine to fail and exit if CopyN doesn't finish.
go func() {
for _, ra := range ranges {
part, err := mw.CreatePart(ra.mimeHeader(ctype, size))
if err != nil {
pw.CloseWithError(err)
return
}
if _, err := content.Seek(ra.start, io.SeekStart); err != nil {
pw.CloseWithError(err)
return
}
if _, err := io.CopyN(part, content, ra.length); err != nil {
pw.CloseWithError(err)
return
}
}
mw.Close()
pw.Close()
}()
}
w.Header().Set("Accept-Ranges", "bytes")
if w.Header().Get("Content-Encoding") == "" {
w.Header().Set("Content-Length", strconv.FormatInt(sendSize, 10))
}
w.WriteHeader(code)
if r.Method != "HEAD" {
io.CopyN(w, sendContent, sendSize)
}
}
// scanETag determines if a syntactically valid ETag is present at s. If so,
// the ETag and the remaining text after consuming the ETag are returned.
// Otherwise, it returns "", "".
func scanETag(s string) (etag string, remain string) {
s = textproto.TrimString(s)
start := 0
if strings.HasPrefix(s, "W/") {
start = 2
}
if len(s[start:]) < 2 || s[start] != '"' {
return "", ""
}
// ETag is either W/"text" or "text".
// See RFC 7232 2.3.
for i := start + 1; i < len(s); i++ {
c := s[i]
switch {
// Character values allowed in ETags.
case c == 0x21 || c >= 0x23 && c <= 0x7E || c >= 0x80:
case c == '"':
return s[:i+1], s[i+1:]
default:
return "", ""
}
}
return "", ""
}
// etagStrongMatch reports whether a and b match using strong ETag comparison.
// Assumes a and b are valid ETags.
func etagStrongMatch(a, b string) bool {
return a == b && a != "" && a[0] == '"'
}
// etagWeakMatch reports whether a and b match using weak ETag comparison.
// Assumes a and b are valid ETags.
func etagWeakMatch(a, b string) bool {
return strings.TrimPrefix(a, "W/") == strings.TrimPrefix(b, "W/")
}
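// Illustrative comparisons, per RFC 7232 section 2.3.2:
//
//	etagStrongMatch(`"1"`, `"1"`)     // true
//	etagStrongMatch(`W/"1"`, `W/"1"`) // false: weak tags never match strongly
//	etagWeakMatch(`W/"1"`, `"1"`)     // true: the W/ prefix is ignored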
// condResult is the result of an HTTP request precondition check.
// See https://tools.ietf.org/html/rfc7232 section 3.
type condResult int
const (
condNone condResult = iota
condTrue
condFalse
)
func checkIfMatch(w ResponseWriter, r *Request) condResult {
im := r.Header.Get("If-Match")
if im == "" {
return condNone
}
for {
im = textproto.TrimString(im)
if len(im) == 0 {
break
}
if im[0] == ',' {
im = im[1:]
continue
}
if im[0] == '*' {
return condTrue
}
etag, remain := scanETag(im)
if etag == "" {
break
}
if etagStrongMatch(etag, w.Header().get("Etag")) {
return condTrue
}
im = remain
}
return condFalse
}
func checkIfUnmodifiedSince(r *Request, modtime time.Time) condResult {
ius := r.Header.Get("If-Unmodified-Since")
if ius == "" || isZeroTime(modtime) {
return condNone
}
t, err := ParseTime(ius)
if err != nil {
return condNone
}
// The Last-Modified header truncates sub-second precision so
// the modtime needs to be truncated too.
modtime = modtime.Truncate(time.Second)
if ret := modtime.Compare(t); ret <= 0 {
return condTrue
}
return condFalse
}
func checkIfNoneMatch(w ResponseWriter, r *Request) condResult {
inm := r.Header.get("If-None-Match")
if inm == "" {
return condNone
}
buf := inm
for {
buf = textproto.TrimString(buf)
if len(buf) == 0 {
break
}
if buf[0] == ',' {
buf = buf[1:]
continue
}
if buf[0] == '*' {
return condFalse
}
etag, remain := scanETag(buf)
if etag == "" {
break
}
if etagWeakMatch(etag, w.Header().get("Etag")) {
return condFalse
}
buf = remain
}
return condTrue
}
func checkIfModifiedSince(r *Request, modtime time.Time) condResult {
if r.Method != "GET" && r.Method != "HEAD" {
return condNone
}
ims := r.Header.Get("If-Modified-Since")
if ims == "" || isZeroTime(modtime) {
return condNone
}
t, err := ParseTime(ims)
if err != nil {
return condNone
}
// The Last-Modified header truncates sub-second precision so
// the modtime needs to be truncated too.
modtime = modtime.Truncate(time.Second)
if ret := modtime.Compare(t); ret <= 0 {
return condFalse
}
return condTrue
}
func checkIfRange(w ResponseWriter, r *Request, modtime time.Time) condResult {
if r.Method != "GET" && r.Method != "HEAD" {
return condNone
}
ir := r.Header.get("If-Range")
if ir == "" {
return condNone
}
etag, _ := scanETag(ir)
if etag != "" {
if etagStrongMatch(etag, w.Header().Get("Etag")) {
return condTrue
} else {
return condFalse
}
}
// The If-Range value is typically the ETag value, but it may also be
// the modtime date. See golang.org/issue/8367.
if modtime.IsZero() {
return condFalse
}
t, err := ParseTime(ir)
if err != nil {
return condFalse
}
if t.Unix() == modtime.Unix() {
return condTrue
}
return condFalse
}
var unixEpochTime = time.Unix(0, 0)
// isZeroTime reports whether t is obviously unspecified (either zero or Unix()=0).
func isZeroTime(t time.Time) bool {
return t.IsZero() || t.Equal(unixEpochTime)
}
func setLastModified(w ResponseWriter, modtime time.Time) {
if !isZeroTime(modtime) {
w.Header().Set("Last-Modified", modtime.UTC().Format(TimeFormat))
}
}
func writeNotModified(w ResponseWriter) {
// RFC 7232 section 4.1:
// a sender SHOULD NOT generate representation metadata other than the
// above listed fields unless said metadata exists for the purpose of
// guiding cache updates (e.g., Last-Modified might be useful if the
// response does not have an ETag field).
h := w.Header()
delete(h, "Content-Type")
delete(h, "Content-Length")
delete(h, "Content-Encoding")
if h.Get("Etag") != "" {
delete(h, "Last-Modified")
}
w.WriteHeader(StatusNotModified)
}
// checkPreconditions evaluates request preconditions and reports whether a precondition
// resulted in sending StatusNotModified or StatusPreconditionFailed.
func checkPreconditions(w ResponseWriter, r *Request, modtime time.Time) (done bool, rangeHeader string) {
// This function carefully follows RFC 7232 section 6.
ch := checkIfMatch(w, r)
if ch == condNone {
ch = checkIfUnmodifiedSince(r, modtime)
}
if ch == condFalse {
w.WriteHeader(StatusPreconditionFailed)
return true, ""
}
switch checkIfNoneMatch(w, r) {
case condFalse:
if r.Method == "GET" || r.Method == "HEAD" {
writeNotModified(w)
return true, ""
} else {
w.WriteHeader(StatusPreconditionFailed)
return true, ""
}
case condNone:
if checkIfModifiedSince(r, modtime) == condFalse {
writeNotModified(w)
return true, ""
}
}
rangeHeader = r.Header.get("Range")
if rangeHeader != "" && checkIfRange(w, r, modtime) == condFalse {
rangeHeader = ""
}
return false, rangeHeader
}
// name is '/'-separated, not filepath.Separator.
func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirect bool) {
const indexPage = "/index.html"
// redirect .../index.html to .../
// can't use Redirect() because that would make the path absolute,
// which would be a problem running under StripPrefix
if strings.HasSuffix(r.URL.Path, indexPage) {
localRedirect(w, r, "./")
return
}
f, err := fs.Open(name)
if err != nil {
msg, code := toHTTPError(err)
Error(w, msg, code)
return
}
defer f.Close()
d, err := f.Stat()
if err != nil {
msg, code := toHTTPError(err)
Error(w, msg, code)
return
}
if redirect {
// redirect to canonical path: / at end of directory url
// r.URL.Path always begins with /
url := r.URL.Path
if d.IsDir() {
if url[len(url)-1] != '/' {
localRedirect(w, r, path.Base(url)+"/")
return
}
} else {
if url[len(url)-1] == '/' {
localRedirect(w, r, "../"+path.Base(url))
return
}
}
}
if d.IsDir() {
url := r.URL.Path
// redirect if the directory name doesn't end in a slash
if url == "" || url[len(url)-1] != '/' {
localRedirect(w, r, path.Base(url)+"/")
return
}
// use contents of index.html for directory, if present
index := strings.TrimSuffix(name, "/") + indexPage
ff, err := fs.Open(index)
if err == nil {
defer ff.Close()
dd, err := ff.Stat()
if err == nil {
d = dd
f = ff
}
}
}
// Still a directory? (we didn't find an index.html file)
if d.IsDir() {
if checkIfModifiedSince(r, d.ModTime()) == condFalse {
writeNotModified(w)
return
}
setLastModified(w, d.ModTime())
dirList(w, r, f)
return
}
// serveContent will check modification time
sizeFunc := func() (int64, error) { return d.Size(), nil }
serveContent(w, r, d.Name(), d.ModTime(), sizeFunc, f)
}
// toHTTPError returns a non-specific HTTP error message and status code
// for a given non-nil error value. It's important that toHTTPError does not
// actually return err.Error(), since msg and httpStatus are returned to users,
// and historically Go's ServeContent always returned just "404 Not Found" for
// all errors. We don't want to start leaking information in error messages.
func toHTTPError(err error) (msg string, httpStatus int) {
if errors.Is(err, fs.ErrNotExist) {
return "404 page not found", StatusNotFound
}
if errors.Is(err, fs.ErrPermission) {
return "403 Forbidden", StatusForbidden
}
// Default:
return "500 Internal Server Error", StatusInternalServerError
}
// localRedirect gives a Moved Permanently response.
// It does not convert relative paths to absolute paths like Redirect does.
func localRedirect(w ResponseWriter, r *Request, newPath string) {
if q := r.URL.RawQuery; q != "" {
newPath += "?" + q
}
w.Header().Set("Location", newPath)
w.WriteHeader(StatusMovedPermanently)
}
// ServeFile replies to the request with the contents of the named
// file or directory.
//
// If the provided file or directory name is a relative path, it is
// interpreted relative to the current directory and may ascend to
// parent directories. If the provided name is constructed from user
// input, it should be sanitized before calling ServeFile.
//
// As a precaution, ServeFile will reject requests where r.URL.Path
// contains a ".." path element; this protects against callers who
// might unsafely use filepath.Join on r.URL.Path without sanitizing
// it and then use that filepath.Join result as the name argument.
//
// As another special case, ServeFile redirects any request where r.URL.Path
// ends in "/index.html" to the same path, without the final
// "index.html". To avoid such redirects either modify the path or
// use ServeContent.
//
// Outside of those two special cases, ServeFile does not use
// r.URL.Path for selecting the file or directory to serve; only the
// file or directory provided in the name argument is used.
func ServeFile(w ResponseWriter, r *Request, name string) {
if containsDotDot(r.URL.Path) {
// Too many programs use r.URL.Path to construct the argument to
// serveFile. Reject the request under the assumption that happened
// here and ".." may not be wanted.
// Note that name might not contain "..", for example if code (still
// incorrectly) used filepath.Join(myDir, r.URL.Path).
Error(w, "invalid URL path", StatusBadRequest)
return
}
dir, file := filepath.Split(name)
serveFile(w, r, Dir(dir), file, false)
}
func containsDotDot(v string) bool {
if !strings.Contains(v, "..") {
return false
}
for _, ent := range strings.FieldsFunc(v, isSlashRune) {
if ent == ".." {
return true
}
}
return false
}
func isSlashRune(r rune) bool { return r == '/' || r == '\\' }
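// containsDotDot matches ".." only as a complete path element, with both
// slash kinds treated as separators. For example:
//
//	containsDotDot("/a/../b") // true
//	containsDotDot(`\..\b`)   // true: backslashes also separate elements
//	containsDotDot("/a..b")   // false: ".." is embedded in a larger element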
type fileHandler struct {
root FileSystem
}
type ioFS struct {
fsys fs.FS
}
type ioFile struct {
file fs.File
}
func (f ioFS) Open(name string) (File, error) {
if name == "/" {
name = "."
} else {
name = strings.TrimPrefix(name, "/")
}
file, err := f.fsys.Open(name)
if err != nil {
return nil, mapOpenError(err, name, '/', func(path string) (fs.FileInfo, error) {
return fs.Stat(f.fsys, path)
})
}
return ioFile{file}, nil
}
func (f ioFile) Close() error { return f.file.Close() }
func (f ioFile) Read(b []byte) (int, error) { return f.file.Read(b) }
func (f ioFile) Stat() (fs.FileInfo, error) { return f.file.Stat() }
var errMissingSeek = errors.New("io.File missing Seek method")
var errMissingReadDir = errors.New("io.File directory missing ReadDir method")
func (f ioFile) Seek(offset int64, whence int) (int64, error) {
s, ok := f.file.(io.Seeker)
if !ok {
return 0, errMissingSeek
}
return s.Seek(offset, whence)
}
func (f ioFile) ReadDir(count int) ([]fs.DirEntry, error) {
d, ok := f.file.(fs.ReadDirFile)
if !ok {
return nil, errMissingReadDir
}
return d.ReadDir(count)
}
func (f ioFile) Readdir(count int) ([]fs.FileInfo, error) {
d, ok := f.file.(fs.ReadDirFile)
if !ok {
return nil, errMissingReadDir
}
var list []fs.FileInfo
for {
dirs, err := d.ReadDir(count - len(list))
for _, dir := range dirs {
info, err := dir.Info()
if err != nil {
// Pretend it doesn't exist, like (*os.File).Readdir does.
continue
}
list = append(list, info)
}
if err != nil {
return list, err
}
if count < 0 || len(list) >= count {
break
}
}
return list, nil
}
// FS converts fsys to a FileSystem implementation,
// for use with FileServer and NewFileTransport.
// The files provided by fsys must implement io.Seeker.
func FS(fsys fs.FS) FileSystem {
return ioFS{fsys}
}
// FileServer returns a handler that serves HTTP requests
// with the contents of the file system rooted at root.
//
// As a special case, the returned file server redirects any request
// ending in "/index.html" to the same path, without the final
// "index.html".
//
// To use the operating system's file system implementation,
// use http.Dir:
//
// http.Handle("/", http.FileServer(http.Dir("/tmp")))
//
// To use an fs.FS implementation, use http.FS to convert it:
//
// http.Handle("/", http.FileServer(http.FS(fsys)))
func FileServer(root FileSystem) Handler {
return &fileHandler{root}
}
func (f *fileHandler) ServeHTTP(w ResponseWriter, r *Request) {
const options = MethodOptions + ", " + MethodGet + ", " + MethodHead
switch r.Method {
case MethodGet, MethodHead:
if !strings.HasPrefix(r.URL.Path, "/") {
r.URL.Path = "/" + r.URL.Path
}
serveFile(w, r, f.root, path.Clean(r.URL.Path), true)
case MethodOptions:
w.Header().Set("Allow", options)
default:
w.Header().Set("Allow", options)
Error(w, "read-only", StatusMethodNotAllowed)
}
}
// httpRange specifies the byte range to be sent to the client.
type httpRange struct {
start, length int64
}
func (r httpRange) contentRange(size int64) string {
return fmt.Sprintf("bytes %d-%d/%d", r.start, r.start+r.length-1, size)
}
func (r httpRange) mimeHeader(contentType string, size int64) textproto.MIMEHeader {
return textproto.MIMEHeader{
"Content-Range": {r.contentRange(size)},
"Content-Type": {contentType},
}
}
// parseRange parses a Range header string as per RFC 7233.
// errNoOverlap is returned if none of the ranges overlap.
func parseRange(s string, size int64) ([]httpRange, error) {
if s == "" {
return nil, nil // header not present
}
const b = "bytes="
if !strings.HasPrefix(s, b) {
return nil, errors.New("invalid range")
}
var ranges []httpRange
noOverlap := false
for _, ra := range strings.Split(s[len(b):], ",") {
ra = textproto.TrimString(ra)
if ra == "" {
continue
}
start, end, ok := strings.Cut(ra, "-")
if !ok {
return nil, errors.New("invalid range")
}
start, end = textproto.TrimString(start), textproto.TrimString(end)
var r httpRange
if start == "" {
// If no start is specified, end specifies the
// range start relative to the end of the file,
// and we are dealing with <suffix-length>
// which has to be a non-negative integer as per
// RFC 7233 Section 2.1 "Byte-Ranges".
if end == "" || end[0] == '-' {
return nil, errors.New("invalid range")
}
i, err := strconv.ParseInt(end, 10, 64)
if i < 0 || err != nil {
return nil, errors.New("invalid range")
}
if i > size {
i = size
}
r.start = size - i
r.length = size - r.start
} else {
i, err := strconv.ParseInt(start, 10, 64)
if err != nil || i < 0 {
return nil, errors.New("invalid range")
}
if i >= size {
// If the range begins after the size of the content,
// then it does not overlap.
noOverlap = true
continue
}
r.start = i
if end == "" {
// If no end is specified, range extends to end of the file.
r.length = size - r.start
} else {
i, err := strconv.ParseInt(end, 10, 64)
if err != nil || r.start > i {
return nil, errors.New("invalid range")
}
if i >= size {
i = size - 1
}
r.length = i - r.start + 1
}
}
ranges = append(ranges, r)
}
if noOverlap && len(ranges) == 0 {
// The specified ranges did not overlap with the content.
return nil, errNoOverlap
}
return ranges, nil
}
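// A few illustrative parses, assuming a 1000-byte resource:
//
//	parseRange("bytes=0-499", 1000) // [{start: 0, length: 500}]
//	parseRange("bytes=-500", 1000)  // [{start: 500, length: 500}]: a suffix range
//	parseRange("bytes=500-", 1000)  // [{start: 500, length: 500}]: open-ended
//	parseRange("bytes=9500-", 1000) // errNoOverlap: starts past the end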
// countingWriter counts how many bytes have been written to it.
type countingWriter int64
func (w *countingWriter) Write(p []byte) (n int, err error) {
*w += countingWriter(len(p))
return len(p), nil
}
// rangesMIMESize returns the number of bytes it takes to encode the
// provided ranges as a multipart response.
func rangesMIMESize(ranges []httpRange, contentType string, contentSize int64) (encSize int64) {
var w countingWriter
mw := multipart.NewWriter(&w)
for _, ra := range ranges {
mw.CreatePart(ra.mimeHeader(contentType, contentSize))
encSize += ra.length
}
mw.Close()
encSize += int64(w)
return
}
func sumRangesSize(ranges []httpRange) (size int64) {
for _, ra := range ranges {
size += ra.length
}
return
}
//go:build !nethttpomithttp2
// +build !nethttpomithttp2
// Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT.
// $ bundle -o=h2_bundle.go -prefix=http2 -tags=!nethttpomithttp2 golang.org/x/net/http2
// Package http2 implements the HTTP/2 protocol.
//
// This package is low-level and intended to be used directly by very
// few people. Most users will use it indirectly through the automatic
// use by the net/http package (from Go 1.6 and later).
// For use in earlier Go versions see ConfigureServer. (Transport support
// requires Go 1.6 or later)
//
// See https://http2.github.io/ for more information on HTTP/2.
//
// See https://http2.golang.org/ for a test server running this code.
//
package http
import (
"bufio"
"bytes"
"compress/gzip"
"context"
"crypto/rand"
"crypto/tls"
"encoding/binary"
"errors"
"fmt"
"io"
"io/fs"
"log"
"math"
mathrand "math/rand"
"net"
"net/http/httptrace"
"net/textproto"
"net/url"
"os"
"reflect"
"runtime"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"golang.org/x/net/http/httpguts"
"golang.org/x/net/http2/hpack"
"golang.org/x/net/idna"
)
// The HTTP protocols are defined in terms of ASCII, not Unicode. This file
// contains helper functions which may use Unicode-aware functions which would
// otherwise be unsafe and could introduce vulnerabilities if used improperly.
// asciiEqualFold is strings.EqualFold, ASCII only. It reports whether s and t
// are equal, ASCII-case-insensitively.
func http2asciiEqualFold(s, t string) bool {
if len(s) != len(t) {
return false
}
for i := 0; i < len(s); i++ {
if http2lower(s[i]) != http2lower(t[i]) {
return false
}
}
return true
}
// lower returns the ASCII lowercase version of b.
func http2lower(b byte) byte {
if 'A' <= b && b <= 'Z' {
return b + ('a' - 'A')
}
return b
}
// isASCIIPrint returns whether s is ASCII and printable according to
// https://tools.ietf.org/html/rfc20#section-4.2.
func http2isASCIIPrint(s string) bool {
for i := 0; i < len(s); i++ {
if s[i] < ' ' || s[i] > '~' {
return false
}
}
return true
}
// asciiToLower returns the lowercase version of s if s is ASCII and printable,
// and whether or not it was.
func http2asciiToLower(s string) (lower string, ok bool) {
if !http2isASCIIPrint(s) {
return "", false
}
return strings.ToLower(s), true
}
// A list of the possible cipher suite ids. Taken from
// https://www.iana.org/assignments/tls-parameters/tls-parameters.txt
const (
http2cipher_TLS_NULL_WITH_NULL_NULL uint16 = 0x0000
http2cipher_TLS_RSA_WITH_NULL_MD5 uint16 = 0x0001
http2cipher_TLS_RSA_WITH_NULL_SHA uint16 = 0x0002
http2cipher_TLS_RSA_EXPORT_WITH_RC4_40_MD5 uint16 = 0x0003
http2cipher_TLS_RSA_WITH_RC4_128_MD5 uint16 = 0x0004
http2cipher_TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005
http2cipher_TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5 uint16 = 0x0006
http2cipher_TLS_RSA_WITH_IDEA_CBC_SHA uint16 = 0x0007
http2cipher_TLS_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0008
http2cipher_TLS_RSA_WITH_DES_CBC_SHA uint16 = 0x0009
http2cipher_TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000A
http2cipher_TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x000B
http2cipher_TLS_DH_DSS_WITH_DES_CBC_SHA uint16 = 0x000C
http2cipher_TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0x000D
http2cipher_TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x000E
http2cipher_TLS_DH_RSA_WITH_DES_CBC_SHA uint16 = 0x000F
http2cipher_TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x0010
http2cipher_TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0011
http2cipher_TLS_DHE_DSS_WITH_DES_CBC_SHA uint16 = 0x0012
http2cipher_TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0x0013
http2cipher_TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0014
http2cipher_TLS_DHE_RSA_WITH_DES_CBC_SHA uint16 = 0x0015
http2cipher_TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x0016
http2cipher_TLS_DH_anon_EXPORT_WITH_RC4_40_MD5 uint16 = 0x0017
http2cipher_TLS_DH_anon_WITH_RC4_128_MD5 uint16 = 0x0018
http2cipher_TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0019
http2cipher_TLS_DH_anon_WITH_DES_CBC_SHA uint16 = 0x001A
http2cipher_TLS_DH_anon_WITH_3DES_EDE_CBC_SHA uint16 = 0x001B
// Reserved uint16 = 0x001C-1D
http2cipher_TLS_KRB5_WITH_DES_CBC_SHA uint16 = 0x001E
http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_SHA uint16 = 0x001F
http2cipher_TLS_KRB5_WITH_RC4_128_SHA uint16 = 0x0020
http2cipher_TLS_KRB5_WITH_IDEA_CBC_SHA uint16 = 0x0021
http2cipher_TLS_KRB5_WITH_DES_CBC_MD5 uint16 = 0x0022
http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_MD5 uint16 = 0x0023
http2cipher_TLS_KRB5_WITH_RC4_128_MD5 uint16 = 0x0024
http2cipher_TLS_KRB5_WITH_IDEA_CBC_MD5 uint16 = 0x0025
http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_SHA uint16 = 0x0026
http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_SHA uint16 = 0x0027
http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_SHA uint16 = 0x0028
http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_MD5 uint16 = 0x0029
http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_MD5 uint16 = 0x002A
http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_MD5 uint16 = 0x002B
http2cipher_TLS_PSK_WITH_NULL_SHA uint16 = 0x002C
http2cipher_TLS_DHE_PSK_WITH_NULL_SHA uint16 = 0x002D
http2cipher_TLS_RSA_PSK_WITH_NULL_SHA uint16 = 0x002E
http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002F
http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA uint16 = 0x0030
http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA uint16 = 0x0031
http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA uint16 = 0x0032
http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0x0033
http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA uint16 = 0x0034
http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035
http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA uint16 = 0x0036
http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0037
http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA uint16 = 0x0038
http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0039
http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA uint16 = 0x003A
http2cipher_TLS_RSA_WITH_NULL_SHA256 uint16 = 0x003B
http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003C
http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x003D
http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA256 uint16 = 0x003E
http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003F
http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA256 uint16 = 0x0040
http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0041
http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0042
http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0043
http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0044
http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0045
http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0046
// Reserved uint16 = 0x0047-4F
// Reserved uint16 = 0x0050-58
// Reserved uint16 = 0x0059-5C
// Unassigned uint16 = 0x005D-5F
// Reserved uint16 = 0x0060-66
http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x0067
http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA256 uint16 = 0x0068
http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x0069
http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA256 uint16 = 0x006A
http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x006B
http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA256 uint16 = 0x006C
http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA256 uint16 = 0x006D
// Unassigned uint16 = 0x006E-83
http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0084
http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0085
http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0086
http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0087
http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0088
http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0089
http2cipher_TLS_PSK_WITH_RC4_128_SHA uint16 = 0x008A
http2cipher_TLS_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x008B
http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA uint16 = 0x008C
http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA uint16 = 0x008D
http2cipher_TLS_DHE_PSK_WITH_RC4_128_SHA uint16 = 0x008E
http2cipher_TLS_DHE_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x008F
http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA uint16 = 0x0090
http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA uint16 = 0x0091
http2cipher_TLS_RSA_PSK_WITH_RC4_128_SHA uint16 = 0x0092
http2cipher_TLS_RSA_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x0093
http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA uint16 = 0x0094
http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA uint16 = 0x0095
http2cipher_TLS_RSA_WITH_SEED_CBC_SHA uint16 = 0x0096
http2cipher_TLS_DH_DSS_WITH_SEED_CBC_SHA uint16 = 0x0097
http2cipher_TLS_DH_RSA_WITH_SEED_CBC_SHA uint16 = 0x0098
http2cipher_TLS_DHE_DSS_WITH_SEED_CBC_SHA uint16 = 0x0099
http2cipher_TLS_DHE_RSA_WITH_SEED_CBC_SHA uint16 = 0x009A
http2cipher_TLS_DH_anon_WITH_SEED_CBC_SHA uint16 = 0x009B
http2cipher_TLS_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009C
http2cipher_TLS_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009D
http2cipher_TLS_DHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009E
http2cipher_TLS_DHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009F
http2cipher_TLS_DH_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x00A0
http2cipher_TLS_DH_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x00A1
http2cipher_TLS_DHE_DSS_WITH_AES_128_GCM_SHA256 uint16 = 0x00A2
http2cipher_TLS_DHE_DSS_WITH_AES_256_GCM_SHA384 uint16 = 0x00A3
http2cipher_TLS_DH_DSS_WITH_AES_128_GCM_SHA256 uint16 = 0x00A4
http2cipher_TLS_DH_DSS_WITH_AES_256_GCM_SHA384 uint16 = 0x00A5
http2cipher_TLS_DH_anon_WITH_AES_128_GCM_SHA256 uint16 = 0x00A6
http2cipher_TLS_DH_anon_WITH_AES_256_GCM_SHA384 uint16 = 0x00A7
http2cipher_TLS_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0x00A8
http2cipher_TLS_PSK_WITH_AES_256_GCM_SHA384 uint16 = 0x00A9
http2cipher_TLS_DHE_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0x00AA
http2cipher_TLS_DHE_PSK_WITH_AES_256_GCM_SHA384 uint16 = 0x00AB
http2cipher_TLS_RSA_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0x00AC
http2cipher_TLS_RSA_PSK_WITH_AES_256_GCM_SHA384 uint16 = 0x00AD
http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00AE
http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00AF
http2cipher_TLS_PSK_WITH_NULL_SHA256 uint16 = 0x00B0
http2cipher_TLS_PSK_WITH_NULL_SHA384 uint16 = 0x00B1
http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00B2
http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00B3
http2cipher_TLS_DHE_PSK_WITH_NULL_SHA256 uint16 = 0x00B4
http2cipher_TLS_DHE_PSK_WITH_NULL_SHA384 uint16 = 0x00B5
http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00B6
http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00B7
http2cipher_TLS_RSA_PSK_WITH_NULL_SHA256 uint16 = 0x00B8
http2cipher_TLS_RSA_PSK_WITH_NULL_SHA384 uint16 = 0x00B9
http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BA
http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BB
http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BC
http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BD
http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BE
http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BF
http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C0
http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C1
http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C2
http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C3
http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C4
http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C5
// Unassigned uint16 = 0x00C6-FE
http2cipher_TLS_EMPTY_RENEGOTIATION_INFO_SCSV uint16 = 0x00FF
// Unassigned uint16 = 0x01-55,*
http2cipher_TLS_FALLBACK_SCSV uint16 = 0x5600
// Unassigned uint16 = 0x5601 - 0xC000
http2cipher_TLS_ECDH_ECDSA_WITH_NULL_SHA uint16 = 0xC001
http2cipher_TLS_ECDH_ECDSA_WITH_RC4_128_SHA uint16 = 0xC002
http2cipher_TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC003
http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xC004
http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xC005
http2cipher_TLS_ECDHE_ECDSA_WITH_NULL_SHA uint16 = 0xC006
http2cipher_TLS_ECDHE_ECDSA_WITH_RC4_128_SHA uint16 = 0xC007
http2cipher_TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC008
http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xC009
http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xC00A
http2cipher_TLS_ECDH_RSA_WITH_NULL_SHA uint16 = 0xC00B
http2cipher_TLS_ECDH_RSA_WITH_RC4_128_SHA uint16 = 0xC00C
http2cipher_TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC00D
http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA uint16 = 0xC00E
http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA uint16 = 0xC00F
http2cipher_TLS_ECDHE_RSA_WITH_NULL_SHA uint16 = 0xC010
http2cipher_TLS_ECDHE_RSA_WITH_RC4_128_SHA uint16 = 0xC011
http2cipher_TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC012
http2cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0xC013
http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0xC014
http2cipher_TLS_ECDH_anon_WITH_NULL_SHA uint16 = 0xC015
http2cipher_TLS_ECDH_anon_WITH_RC4_128_SHA uint16 = 0xC016
http2cipher_TLS_ECDH_anon_WITH_3DES_EDE_CBC_SHA uint16 = 0xC017
http2cipher_TLS_ECDH_anon_WITH_AES_128_CBC_SHA uint16 = 0xC018
http2cipher_TLS_ECDH_anon_WITH_AES_256_CBC_SHA uint16 = 0xC019
http2cipher_TLS_SRP_SHA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC01A
http2cipher_TLS_SRP_SHA_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC01B
http2cipher_TLS_SRP_SHA_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0xC01C
http2cipher_TLS_SRP_SHA_WITH_AES_128_CBC_SHA uint16 = 0xC01D
http2cipher_TLS_SRP_SHA_RSA_WITH_AES_128_CBC_SHA uint16 = 0xC01E
http2cipher_TLS_SRP_SHA_DSS_WITH_AES_128_CBC_SHA uint16 = 0xC01F
http2cipher_TLS_SRP_SHA_WITH_AES_256_CBC_SHA uint16 = 0xC020
http2cipher_TLS_SRP_SHA_RSA_WITH_AES_256_CBC_SHA uint16 = 0xC021
http2cipher_TLS_SRP_SHA_DSS_WITH_AES_256_CBC_SHA uint16 = 0xC022
http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC023
http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC024
http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC025
http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC026
http2cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC027
http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC028
http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC029
http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC02A
http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC02B
http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC02C
http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC02D
http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC02E
http2cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC02F
http2cipher_TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC030
http2cipher_TLS_ECDH_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC031
http2cipher_TLS_ECDH_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC032
http2cipher_TLS_ECDHE_PSK_WITH_RC4_128_SHA uint16 = 0xC033
http2cipher_TLS_ECDHE_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0xC034
http2cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA uint16 = 0xC035
http2cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA uint16 = 0xC036
http2cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0xC037
http2cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0xC038
http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA uint16 = 0xC039
http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA256 uint16 = 0xC03A
http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA384 uint16 = 0xC03B
http2cipher_TLS_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC03C
http2cipher_TLS_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC03D
http2cipher_TLS_DH_DSS_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC03E
http2cipher_TLS_DH_DSS_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC03F
http2cipher_TLS_DH_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC040
http2cipher_TLS_DH_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC041
http2cipher_TLS_DHE_DSS_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC042
http2cipher_TLS_DHE_DSS_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC043
http2cipher_TLS_DHE_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC044
http2cipher_TLS_DHE_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC045
http2cipher_TLS_DH_anon_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC046
http2cipher_TLS_DH_anon_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC047
http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC048
http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC049
http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC04A
http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC04B
http2cipher_TLS_ECDHE_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC04C
http2cipher_TLS_ECDHE_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC04D
http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC04E
http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC04F
http2cipher_TLS_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC050
http2cipher_TLS_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC051
http2cipher_TLS_DHE_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC052
http2cipher_TLS_DHE_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC053
http2cipher_TLS_DH_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC054
http2cipher_TLS_DH_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC055
http2cipher_TLS_DHE_DSS_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC056
http2cipher_TLS_DHE_DSS_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC057
http2cipher_TLS_DH_DSS_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC058
http2cipher_TLS_DH_DSS_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC059
http2cipher_TLS_DH_anon_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC05A
http2cipher_TLS_DH_anon_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC05B
http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC05C
http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC05D
http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC05E
http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC05F
http2cipher_TLS_ECDHE_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC060
http2cipher_TLS_ECDHE_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC061
http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC062
http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC063
http2cipher_TLS_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC064
http2cipher_TLS_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC065
http2cipher_TLS_DHE_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC066
http2cipher_TLS_DHE_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC067
http2cipher_TLS_RSA_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC068
http2cipher_TLS_RSA_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC069
http2cipher_TLS_PSK_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC06A
http2cipher_TLS_PSK_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC06B
http2cipher_TLS_DHE_PSK_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC06C
http2cipher_TLS_DHE_PSK_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC06D
http2cipher_TLS_RSA_PSK_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC06E
http2cipher_TLS_RSA_PSK_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC06F
http2cipher_TLS_ECDHE_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC070
http2cipher_TLS_ECDHE_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC071
http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC072
http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC073
http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC074
http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC075
http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC076
http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC077
http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC078
http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC079
http2cipher_TLS_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC07A
http2cipher_TLS_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC07B
http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC07C
http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC07D
http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC07E
http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC07F
http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC080
http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC081
http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC082
http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC083
http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC084
http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC085
http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC086
http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC087
http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC088
http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC089
http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC08A
http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC08B
http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC08C
http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC08D
http2cipher_TLS_PSK_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC08E
http2cipher_TLS_PSK_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC08F
http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC090
http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC091
http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC092
http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC093
http2cipher_TLS_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC094
http2cipher_TLS_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC095
http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC096
http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC097
http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC098
http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC099
http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC09A
http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC09B
http2cipher_TLS_RSA_WITH_AES_128_CCM uint16 = 0xC09C
http2cipher_TLS_RSA_WITH_AES_256_CCM uint16 = 0xC09D
http2cipher_TLS_DHE_RSA_WITH_AES_128_CCM uint16 = 0xC09E
http2cipher_TLS_DHE_RSA_WITH_AES_256_CCM uint16 = 0xC09F
http2cipher_TLS_RSA_WITH_AES_128_CCM_8 uint16 = 0xC0A0
http2cipher_TLS_RSA_WITH_AES_256_CCM_8 uint16 = 0xC0A1
http2cipher_TLS_DHE_RSA_WITH_AES_128_CCM_8 uint16 = 0xC0A2
http2cipher_TLS_DHE_RSA_WITH_AES_256_CCM_8 uint16 = 0xC0A3
http2cipher_TLS_PSK_WITH_AES_128_CCM uint16 = 0xC0A4
http2cipher_TLS_PSK_WITH_AES_256_CCM uint16 = 0xC0A5
http2cipher_TLS_DHE_PSK_WITH_AES_128_CCM uint16 = 0xC0A6
http2cipher_TLS_DHE_PSK_WITH_AES_256_CCM uint16 = 0xC0A7
http2cipher_TLS_PSK_WITH_AES_128_CCM_8 uint16 = 0xC0A8
http2cipher_TLS_PSK_WITH_AES_256_CCM_8 uint16 = 0xC0A9
http2cipher_TLS_PSK_DHE_WITH_AES_128_CCM_8 uint16 = 0xC0AA
http2cipher_TLS_PSK_DHE_WITH_AES_256_CCM_8 uint16 = 0xC0AB
http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CCM uint16 = 0xC0AC
http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CCM uint16 = 0xC0AD
http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 uint16 = 0xC0AE
http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CCM_8 uint16 = 0xC0AF
// Unassigned uint16 = 0xC0B0-FF
// Unassigned uint16 = 0xC1-CB,*
// Unassigned uint16 = 0xCC00-A7
http2cipher_TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCA8
http2cipher_TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCA9
http2cipher_TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAA
http2cipher_TLS_PSK_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAB
http2cipher_TLS_ECDHE_PSK_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAC
http2cipher_TLS_DHE_PSK_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAD
http2cipher_TLS_RSA_PSK_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAE
)
// isBadCipher reports whether the cipher is blacklisted by the HTTP/2 spec.
// References:
// https://tools.ietf.org/html/rfc7540#appendix-A
// Reject cipher suites from Appendix A.
// "This list includes those cipher suites that do not
// offer an ephemeral key exchange and those that are
// based on the TLS null, stream or block cipher type"
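// As an illustrative sketch (not part of the original source), a server
// can consult this table against the suite negotiated by crypto/tls,
// where cs is an assumed tls.ConnectionState taken from the completed
// handshake:
//
//	if http2isBadCipher(cs.CipherSuite) {
//		// Reject the connection: RFC 7540 Appendix A forbids
//		// this suite for HTTP/2.
//	}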
func http2isBadCipher(cipher uint16) bool {
switch cipher {
case http2cipher_TLS_NULL_WITH_NULL_NULL,
http2cipher_TLS_RSA_WITH_NULL_MD5,
http2cipher_TLS_RSA_WITH_NULL_SHA,
http2cipher_TLS_RSA_EXPORT_WITH_RC4_40_MD5,
http2cipher_TLS_RSA_WITH_RC4_128_MD5,
http2cipher_TLS_RSA_WITH_RC4_128_SHA,
http2cipher_TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5,
http2cipher_TLS_RSA_WITH_IDEA_CBC_SHA,
http2cipher_TLS_RSA_EXPORT_WITH_DES40_CBC_SHA,
http2cipher_TLS_RSA_WITH_DES_CBC_SHA,
http2cipher_TLS_RSA_WITH_3DES_EDE_CBC_SHA,
http2cipher_TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA,
http2cipher_TLS_DH_DSS_WITH_DES_CBC_SHA,
http2cipher_TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA,
http2cipher_TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA,
http2cipher_TLS_DH_RSA_WITH_DES_CBC_SHA,
http2cipher_TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA,
http2cipher_TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA,
http2cipher_TLS_DHE_DSS_WITH_DES_CBC_SHA,
http2cipher_TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA,
http2cipher_TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA,
http2cipher_TLS_DHE_RSA_WITH_DES_CBC_SHA,
http2cipher_TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA,
http2cipher_TLS_DH_anon_EXPORT_WITH_RC4_40_MD5,
http2cipher_TLS_DH_anon_WITH_RC4_128_MD5,
http2cipher_TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA,
http2cipher_TLS_DH_anon_WITH_DES_CBC_SHA,
http2cipher_TLS_DH_anon_WITH_3DES_EDE_CBC_SHA,
http2cipher_TLS_KRB5_WITH_DES_CBC_SHA,
http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_SHA,
http2cipher_TLS_KRB5_WITH_RC4_128_SHA,
http2cipher_TLS_KRB5_WITH_IDEA_CBC_SHA,
http2cipher_TLS_KRB5_WITH_DES_CBC_MD5,
http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_MD5,
http2cipher_TLS_KRB5_WITH_RC4_128_MD5,
http2cipher_TLS_KRB5_WITH_IDEA_CBC_MD5,
http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_SHA,
http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_SHA,
http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_SHA,
http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_MD5,
http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_MD5,
http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_MD5,
http2cipher_TLS_PSK_WITH_NULL_SHA,
http2cipher_TLS_DHE_PSK_WITH_NULL_SHA,
http2cipher_TLS_RSA_PSK_WITH_NULL_SHA,
http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA,
http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA,
http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA,
http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA,
http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA,
http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA,
http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA,
http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA,
http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA,
http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA,
http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA,
http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA,
http2cipher_TLS_RSA_WITH_NULL_SHA256,
http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA256,
http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA256,
http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA256,
http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA256,
http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA256,
http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA,
http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA,
http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA,
http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA,
http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA,
http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA,
http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA256,
http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA256,
http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA256,
http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA256,
http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA256,
http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA256,
http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA256,
http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA,
http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA,
http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA,
http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA,
http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA,
http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA,
http2cipher_TLS_PSK_WITH_RC4_128_SHA,
http2cipher_TLS_PSK_WITH_3DES_EDE_CBC_SHA,
http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA,
http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA,
http2cipher_TLS_DHE_PSK_WITH_RC4_128_SHA,
http2cipher_TLS_DHE_PSK_WITH_3DES_EDE_CBC_SHA,
http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA,
http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA,
http2cipher_TLS_RSA_PSK_WITH_RC4_128_SHA,
http2cipher_TLS_RSA_PSK_WITH_3DES_EDE_CBC_SHA,
http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA,
http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA,
http2cipher_TLS_RSA_WITH_SEED_CBC_SHA,
http2cipher_TLS_DH_DSS_WITH_SEED_CBC_SHA,
http2cipher_TLS_DH_RSA_WITH_SEED_CBC_SHA,
http2cipher_TLS_DHE_DSS_WITH_SEED_CBC_SHA,
http2cipher_TLS_DHE_RSA_WITH_SEED_CBC_SHA,
http2cipher_TLS_DH_anon_WITH_SEED_CBC_SHA,
http2cipher_TLS_RSA_WITH_AES_128_GCM_SHA256,
http2cipher_TLS_RSA_WITH_AES_256_GCM_SHA384,
http2cipher_TLS_DH_RSA_WITH_AES_128_GCM_SHA256,
http2cipher_TLS_DH_RSA_WITH_AES_256_GCM_SHA384,
http2cipher_TLS_DH_DSS_WITH_AES_128_GCM_SHA256,
http2cipher_TLS_DH_DSS_WITH_AES_256_GCM_SHA384,
http2cipher_TLS_DH_anon_WITH_AES_128_GCM_SHA256,
http2cipher_TLS_DH_anon_WITH_AES_256_GCM_SHA384,
http2cipher_TLS_PSK_WITH_AES_128_GCM_SHA256,
http2cipher_TLS_PSK_WITH_AES_256_GCM_SHA384,
http2cipher_TLS_RSA_PSK_WITH_AES_128_GCM_SHA256,
http2cipher_TLS_RSA_PSK_WITH_AES_256_GCM_SHA384,
http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA256,
http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA384,
http2cipher_TLS_PSK_WITH_NULL_SHA256,
http2cipher_TLS_PSK_WITH_NULL_SHA384,
http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA256,
http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA384,
http2cipher_TLS_DHE_PSK_WITH_NULL_SHA256,
http2cipher_TLS_DHE_PSK_WITH_NULL_SHA384,
http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA256,
http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA384,
http2cipher_TLS_RSA_PSK_WITH_NULL_SHA256,
http2cipher_TLS_RSA_PSK_WITH_NULL_SHA384,
http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA256,
http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA256,
http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA256,
http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA256,
http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA256,
http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA256,
http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA256,
http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA256,
http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA256,
http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA256,
http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA256,
http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA256,
http2cipher_TLS_EMPTY_RENEGOTIATION_INFO_SCSV,
http2cipher_TLS_ECDH_ECDSA_WITH_NULL_SHA,
http2cipher_TLS_ECDH_ECDSA_WITH_RC4_128_SHA,
http2cipher_TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA,
http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA,
http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA,
http2cipher_TLS_ECDHE_ECDSA_WITH_NULL_SHA,
http2cipher_TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
http2cipher_TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA,
http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
http2cipher_TLS_ECDH_RSA_WITH_NULL_SHA,
http2cipher_TLS_ECDH_RSA_WITH_RC4_128_SHA,
http2cipher_TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA,
http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA,
http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA,
http2cipher_TLS_ECDHE_RSA_WITH_NULL_SHA,
http2cipher_TLS_ECDHE_RSA_WITH_RC4_128_SHA,
http2cipher_TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
http2cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
http2cipher_TLS_ECDH_anon_WITH_NULL_SHA,
http2cipher_TLS_ECDH_anon_WITH_RC4_128_SHA,
http2cipher_TLS_ECDH_anon_WITH_3DES_EDE_CBC_SHA,
http2cipher_TLS_ECDH_anon_WITH_AES_128_CBC_SHA,
http2cipher_TLS_ECDH_anon_WITH_AES_256_CBC_SHA,
http2cipher_TLS_SRP_SHA_WITH_3DES_EDE_CBC_SHA,
http2cipher_TLS_SRP_SHA_RSA_WITH_3DES_EDE_CBC_SHA,
http2cipher_TLS_SRP_SHA_DSS_WITH_3DES_EDE_CBC_SHA,
http2cipher_TLS_SRP_SHA_WITH_AES_128_CBC_SHA,
http2cipher_TLS_SRP_SHA_RSA_WITH_AES_128_CBC_SHA,
http2cipher_TLS_SRP_SHA_DSS_WITH_AES_128_CBC_SHA,
http2cipher_TLS_SRP_SHA_WITH_AES_256_CBC_SHA,
http2cipher_TLS_SRP_SHA_RSA_WITH_AES_256_CBC_SHA,
http2cipher_TLS_SRP_SHA_DSS_WITH_AES_256_CBC_SHA,
http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384,
http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA256,
http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA384,
http2cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384,
http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA256,
http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA384,
http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_GCM_SHA256,
http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_GCM_SHA384,
http2cipher_TLS_ECDH_RSA_WITH_AES_128_GCM_SHA256,
http2cipher_TLS_ECDH_RSA_WITH_AES_256_GCM_SHA384,
http2cipher_TLS_ECDHE_PSK_WITH_RC4_128_SHA,
http2cipher_TLS_ECDHE_PSK_WITH_3DES_EDE_CBC_SHA,
http2cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA,
http2cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA,
http2cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256,
http2cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384,
http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA,
http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA256,
http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA384,
http2cipher_TLS_RSA_WITH_ARIA_128_CBC_SHA256,
http2cipher_TLS_RSA_WITH_ARIA_256_CBC_SHA384,
http2cipher_TLS_DH_DSS_WITH_ARIA_128_CBC_SHA256,
http2cipher_TLS_DH_DSS_WITH_ARIA_256_CBC_SHA384,
http2cipher_TLS_DH_RSA_WITH_ARIA_128_CBC_SHA256,
http2cipher_TLS_DH_RSA_WITH_ARIA_256_CBC_SHA384,
http2cipher_TLS_DHE_DSS_WITH_ARIA_128_CBC_SHA256,
http2cipher_TLS_DHE_DSS_WITH_ARIA_256_CBC_SHA384,
http2cipher_TLS_DHE_RSA_WITH_ARIA_128_CBC_SHA256,
http2cipher_TLS_DHE_RSA_WITH_ARIA_256_CBC_SHA384,
http2cipher_TLS_DH_anon_WITH_ARIA_128_CBC_SHA256,
http2cipher_TLS_DH_anon_WITH_ARIA_256_CBC_SHA384,
http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_128_CBC_SHA256,
http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_256_CBC_SHA384,
http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_CBC_SHA256,
http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_CBC_SHA384,
http2cipher_TLS_ECDHE_RSA_WITH_ARIA_128_CBC_SHA256,
http2cipher_TLS_ECDHE_RSA_WITH_ARIA_256_CBC_SHA384,
http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_CBC_SHA256,
http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_CBC_SHA384,
http2cipher_TLS_RSA_WITH_ARIA_128_GCM_SHA256,
http2cipher_TLS_RSA_WITH_ARIA_256_GCM_SHA384,
http2cipher_TLS_DH_RSA_WITH_ARIA_128_GCM_SHA256,
http2cipher_TLS_DH_RSA_WITH_ARIA_256_GCM_SHA384,
http2cipher_TLS_DH_DSS_WITH_ARIA_128_GCM_SHA256,
http2cipher_TLS_DH_DSS_WITH_ARIA_256_GCM_SHA384,
http2cipher_TLS_DH_anon_WITH_ARIA_128_GCM_SHA256,
http2cipher_TLS_DH_anon_WITH_ARIA_256_GCM_SHA384,
http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_GCM_SHA256,
http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_GCM_SHA384,
http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_GCM_SHA256,
http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_GCM_SHA384,
http2cipher_TLS_PSK_WITH_ARIA_128_CBC_SHA256,
http2cipher_TLS_PSK_WITH_ARIA_256_CBC_SHA384,
http2cipher_TLS_DHE_PSK_WITH_ARIA_128_CBC_SHA256,
http2cipher_TLS_DHE_PSK_WITH_ARIA_256_CBC_SHA384,
http2cipher_TLS_RSA_PSK_WITH_ARIA_128_CBC_SHA256,
http2cipher_TLS_RSA_PSK_WITH_ARIA_256_CBC_SHA384,
http2cipher_TLS_PSK_WITH_ARIA_128_GCM_SHA256,
http2cipher_TLS_PSK_WITH_ARIA_256_GCM_SHA384,
http2cipher_TLS_RSA_PSK_WITH_ARIA_128_GCM_SHA256,
http2cipher_TLS_RSA_PSK_WITH_ARIA_256_GCM_SHA384,
http2cipher_TLS_ECDHE_PSK_WITH_ARIA_128_CBC_SHA256,
http2cipher_TLS_ECDHE_PSK_WITH_ARIA_256_CBC_SHA384,
http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_128_CBC_SHA256,
http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_256_CBC_SHA384,
http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_CBC_SHA256,
http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_CBC_SHA384,
http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_128_CBC_SHA256,
http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_256_CBC_SHA384,
http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_CBC_SHA256,
http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_CBC_SHA384,
http2cipher_TLS_RSA_WITH_CAMELLIA_128_GCM_SHA256,
http2cipher_TLS_RSA_WITH_CAMELLIA_256_GCM_SHA384,
http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_GCM_SHA256,
http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_GCM_SHA384,
http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_GCM_SHA256,
http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_GCM_SHA384,
http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_GCM_SHA256,
http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_GCM_SHA384,
http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_GCM_SHA256,
http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_GCM_SHA384,
http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_GCM_SHA256,
http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_GCM_SHA384,
http2cipher_TLS_PSK_WITH_CAMELLIA_128_GCM_SHA256,
http2cipher_TLS_PSK_WITH_CAMELLIA_256_GCM_SHA384,
http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_GCM_SHA256,
http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_GCM_SHA384,
http2cipher_TLS_PSK_WITH_CAMELLIA_128_CBC_SHA256,
http2cipher_TLS_PSK_WITH_CAMELLIA_256_CBC_SHA384,
http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_128_CBC_SHA256,
http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_256_CBC_SHA384,
http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_CBC_SHA256,
http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_CBC_SHA384,
http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_128_CBC_SHA256,
http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_256_CBC_SHA384,
http2cipher_TLS_RSA_WITH_AES_128_CCM,
http2cipher_TLS_RSA_WITH_AES_256_CCM,
http2cipher_TLS_RSA_WITH_AES_128_CCM_8,
http2cipher_TLS_RSA_WITH_AES_256_CCM_8,
http2cipher_TLS_PSK_WITH_AES_128_CCM,
http2cipher_TLS_PSK_WITH_AES_256_CCM,
http2cipher_TLS_PSK_WITH_AES_128_CCM_8,
http2cipher_TLS_PSK_WITH_AES_256_CCM_8:
return true
default:
return false
}
}
// ClientConnPool manages a pool of HTTP/2 client connections.
type http2ClientConnPool interface {
// GetClientConn returns a specific HTTP/2 connection (usually
// a TLS-TCP connection) to an HTTP/2 server. On success, the
// returned ClientConn accounts for the upcoming RoundTrip
// call, so the caller should not omit it. If the caller needs
// to, ClientConn.RoundTrip can be called with a bogus
// new(http.Request) to release the stream reservation.
GetClientConn(req *Request, addr string) (*http2ClientConn, error)
MarkDead(*http2ClientConn)
}
// clientConnPoolIdleCloser is the interface implemented by ClientConnPool
// implementations which can close their idle connections.
type http2clientConnPoolIdleCloser interface {
http2ClientConnPool
closeIdleConnections()
}
var (
_ http2clientConnPoolIdleCloser = (*http2clientConnPool)(nil)
_ http2clientConnPoolIdleCloser = http2noDialClientConnPool{}
)
// TODO: use singleflight for dialing and addConnCalls?
type http2clientConnPool struct {
t *http2Transport
mu sync.Mutex // TODO: maybe switch to RWMutex
// TODO: add support for sharing conns based on cert names
// (e.g. share conn for googleapis.com and appspot.com)
conns map[string][]*http2ClientConn // key is host:port
dialing map[string]*http2dialCall // currently in-flight dials
keys map[*http2ClientConn][]string
addConnCalls map[string]*http2addConnCall // in-flight addConnIfNeeded calls
}
func (p *http2clientConnPool) GetClientConn(req *Request, addr string) (*http2ClientConn, error) {
return p.getClientConn(req, addr, http2dialOnMiss)
}
const (
http2dialOnMiss = true
http2noDialOnMiss = false
)
func (p *http2clientConnPool) getClientConn(req *Request, addr string, dialOnMiss bool) (*http2ClientConn, error) {
// TODO(dneil): Dial a new connection when t.DisableKeepAlives is set?
if http2isConnectionCloseRequest(req) && dialOnMiss {
// It gets its own connection.
http2traceGetConn(req, addr)
const singleUse = true
cc, err := p.t.dialClientConn(req.Context(), addr, singleUse)
if err != nil {
return nil, err
}
return cc, nil
}
for {
p.mu.Lock()
for _, cc := range p.conns[addr] {
if cc.ReserveNewRequest() {
// When a connection is presented to us by the net/http package,
// the GetConn hook has already been called.
// Don't call it a second time here.
if !cc.getConnCalled {
http2traceGetConn(req, addr)
}
cc.getConnCalled = false
p.mu.Unlock()
return cc, nil
}
}
if !dialOnMiss {
p.mu.Unlock()
return nil, http2ErrNoCachedConn
}
http2traceGetConn(req, addr)
call := p.getStartDialLocked(req.Context(), addr)
p.mu.Unlock()
<-call.done
if http2shouldRetryDial(call, req) {
continue
}
cc, err := call.res, call.err
if err != nil {
return nil, err
}
if cc.ReserveNewRequest() {
return cc, nil
}
}
}
// dialCall is an in-flight Transport dial call to a host.
type http2dialCall struct {
_ http2incomparable
p *http2clientConnPool
// the context associated with the request
// that created this dialCall
ctx context.Context
done chan struct{} // closed when done
res *http2ClientConn // valid after done is closed
err error // valid after done is closed
}
// requires p.mu to be held.
func (p *http2clientConnPool) getStartDialLocked(ctx context.Context, addr string) *http2dialCall {
if call, ok := p.dialing[addr]; ok {
// A dial is already in-flight. Don't start another.
return call
}
call := &http2dialCall{p: p, done: make(chan struct{}), ctx: ctx}
if p.dialing == nil {
p.dialing = make(map[string]*http2dialCall)
}
p.dialing[addr] = call
go call.dial(call.ctx, addr)
return call
}
// run in its own goroutine.
func (c *http2dialCall) dial(ctx context.Context, addr string) {
const singleUse = false // shared conn
c.res, c.err = c.p.t.dialClientConn(ctx, addr, singleUse)
c.p.mu.Lock()
delete(c.p.dialing, addr)
if c.err == nil {
c.p.addConnLocked(addr, c.res)
}
c.p.mu.Unlock()
close(c.done)
}
// addConnIfNeeded makes a NewClientConn out of c if a connection for key doesn't
// already exist. It coalesces concurrent calls with the same key.
// This is used by the http1 Transport code when it creates a new connection. Because
// the http1 Transport doesn't de-dup TCP dials to outbound hosts (because it doesn't know
// the protocol), it can get into a situation where it has multiple TLS connections.
// This code decides which ones live or die.
// The return value reports whether c was used.
// c is never closed.
func (p *http2clientConnPool) addConnIfNeeded(key string, t *http2Transport, c *tls.Conn) (used bool, err error) {
p.mu.Lock()
for _, cc := range p.conns[key] {
if cc.CanTakeNewRequest() {
p.mu.Unlock()
return false, nil
}
}
call, dup := p.addConnCalls[key]
if !dup {
if p.addConnCalls == nil {
p.addConnCalls = make(map[string]*http2addConnCall)
}
call = &http2addConnCall{
p: p,
done: make(chan struct{}),
}
p.addConnCalls[key] = call
go call.run(t, key, c)
}
p.mu.Unlock()
<-call.done
if call.err != nil {
return false, call.err
}
return !dup, nil
}
type http2addConnCall struct {
_ http2incomparable
p *http2clientConnPool
done chan struct{} // closed when done
err error
}
func (c *http2addConnCall) run(t *http2Transport, key string, tc *tls.Conn) {
cc, err := t.NewClientConn(tc)
p := c.p
p.mu.Lock()
if err != nil {
c.err = err
} else {
cc.getConnCalled = true // already called by the net/http package
p.addConnLocked(key, cc)
}
delete(p.addConnCalls, key)
p.mu.Unlock()
close(c.done)
}
// p.mu must be held
func (p *http2clientConnPool) addConnLocked(key string, cc *http2ClientConn) {
for _, v := range p.conns[key] {
if v == cc {
return
}
}
if p.conns == nil {
p.conns = make(map[string][]*http2ClientConn)
}
if p.keys == nil {
p.keys = make(map[*http2ClientConn][]string)
}
p.conns[key] = append(p.conns[key], cc)
p.keys[cc] = append(p.keys[cc], key)
}
func (p *http2clientConnPool) MarkDead(cc *http2ClientConn) {
p.mu.Lock()
defer p.mu.Unlock()
for _, key := range p.keys[cc] {
vv, ok := p.conns[key]
if !ok {
continue
}
newList := http2filterOutClientConn(vv, cc)
if len(newList) > 0 {
p.conns[key] = newList
} else {
delete(p.conns, key)
}
}
delete(p.keys, cc)
}
func (p *http2clientConnPool) closeIdleConnections() {
p.mu.Lock()
defer p.mu.Unlock()
// TODO: don't close a cc if it was just added to the pool
// milliseconds ago and has never been used. There's currently
// a small race window with the HTTP/1 Transport's integration
// where it can add an idle conn just before using it, and
// somebody else can concurrently call CloseIdleConns and
// break some caller's RoundTrip.
for _, vv := range p.conns {
for _, cc := range vv {
cc.closeIfIdle()
}
}
}
func http2filterOutClientConn(in []*http2ClientConn, exclude *http2ClientConn) []*http2ClientConn {
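// Filter in place: out shares in's backing array, so building the
// filtered list allocates nothing.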
out := in[:0]
for _, v := range in {
if v != exclude {
out = append(out, v)
}
}
// If we filtered anything out, zero the now-unused last slot so the
// removed conn isn't kept reachable and can be garbage collected.
if len(in) != len(out) {
in[len(in)-1] = nil
}
return out
}
// noDialClientConnPool is an implementation of http2.ClientConnPool
// which never dials. We let the HTTP/1.1 client dial and use its TLS
// connection instead.
type http2noDialClientConnPool struct{ *http2clientConnPool }
func (p http2noDialClientConnPool) GetClientConn(req *Request, addr string) (*http2ClientConn, error) {
return p.getClientConn(req, addr, http2noDialOnMiss)
}
// shouldRetryDial reports whether the current request should
// retry dialing after the call finished unsuccessfully, for example
// if the dial was canceled because of a context cancellation or
// deadline expiry.
func http2shouldRetryDial(call *http2dialCall, req *Request) bool {
if call.err == nil {
// No error, no need to retry
return false
}
if call.ctx == req.Context() {
// If the call has the same context as the request, the dial
// should not be retried, since any cancellation will have come
// from this request.
return false
}
if !errors.Is(call.err, context.Canceled) && !errors.Is(call.err, context.DeadlineExceeded) {
// If the call error is not because of a context cancellation or a deadline expiry,
// the dial should not be retried.
return false
}
// Only retry if the error is a context cancellation error or deadline expiry
// and the context associated with the call was canceled or expired.
return call.ctx.Err() != nil
}
// Buffer chunks are allocated from a pool to reduce pressure on GC.
// The maximum wasted space per dataBuffer is 2x the largest size class,
// which happens when the dataBuffer has multiple chunks and there is
// one unread byte in both the first and last chunks. We use a few size
// classes to minimize overheads for servers that typically receive very
// small request bodies.
//
// TODO: Benchmark to determine if the pools are necessary. The GC may have
// improved enough that we can instead allocate chunks like this:
// make([]byte, max(16<<10, expectedBytesRemaining))
var (
http2dataChunkSizeClasses = []int{
1 << 10,
2 << 10,
4 << 10,
8 << 10,
16 << 10,
}
http2dataChunkPools = [...]sync.Pool{
{New: func() interface{} { return make([]byte, 1<<10) }},
{New: func() interface{} { return make([]byte, 2<<10) }},
{New: func() interface{} { return make([]byte, 4<<10) }},
{New: func() interface{} { return make([]byte, 8<<10) }},
{New: func() interface{} { return make([]byte, 16<<10) }},
}
)
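// getDataBufferChunk returns a pooled chunk from the smallest size class
// that fits size. For example (illustrative): a request for 3000 bytes
// skips the 1KB and 2KB classes and returns a 4KB chunk, while any size
// above 16KB falls through to the final class, so a large body is served
// by several 16KB chunks rather than one oversized allocation.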
func http2getDataBufferChunk(size int64) []byte {
i := 0
for ; i < len(http2dataChunkSizeClasses)-1; i++ {
if size <= int64(http2dataChunkSizeClasses[i]) {
break
}
}
return http2dataChunkPools[i].Get().([]byte)
}
func http2putDataBufferChunk(p []byte) {
for i, n := range http2dataChunkSizeClasses {
if len(p) == n {
http2dataChunkPools[i].Put(p)
return
}
}
panic(fmt.Sprintf("unexpected buffer len=%v", len(p)))
}
// dataBuffer is an io.ReadWriter backed by a list of data chunks.
// Each dataBuffer is used to read DATA frames on a single stream.
// The buffer is divided into chunks so the server can limit the
// total memory used by a single connection without limiting the
// request body size on any single stream.
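// A minimal usage sketch (illustrative only; expected is normally set by
// the caller when it knows how many more bytes are coming):
//
//	var b http2dataBuffer
//	b.expected = 10            // hypothetical hint: 10 more bytes expected
//	b.Write([]byte("hello"))   // allocates a 1KB pooled chunk, copies 5 bytes
//	p := make([]byte, 5)
//	b.Read(p)                  // copies "hello"; the chunk is kept for reuse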
type http2dataBuffer struct {
chunks [][]byte
r int // next byte to read is chunks[0][r]
w int // next byte to write is chunks[len(chunks)-1][w]
size int // total buffered bytes
expected int64 // we expect at least this many bytes in future Write calls (ignored if <= 0)
}
var http2errReadEmpty = errors.New("read from empty dataBuffer")
// Read copies bytes from the buffer into p.
// It is an error to read when no data is available.
func (b *http2dataBuffer) Read(p []byte) (int, error) {
if b.size == 0 {
return 0, http2errReadEmpty
}
var ntotal int
for len(p) > 0 && b.size > 0 {
readFrom := b.bytesFromFirstChunk()
n := copy(p, readFrom)
p = p[n:]
ntotal += n
b.r += n
b.size -= n
// If the first chunk has been consumed, advance to the next chunk.
if b.r == len(b.chunks[0]) {
http2putDataBufferChunk(b.chunks[0])
end := len(b.chunks) - 1
copy(b.chunks[:end], b.chunks[1:])
b.chunks[end] = nil
b.chunks = b.chunks[:end]
b.r = 0
}
}
return ntotal, nil
}
func (b *http2dataBuffer) bytesFromFirstChunk() []byte {
if len(b.chunks) == 1 {
return b.chunks[0][b.r:b.w]
}
return b.chunks[0][b.r:]
}
// Len returns the number of bytes of the unread portion of the buffer.
func (b *http2dataBuffer) Len() int {
return b.size
}
// Write appends p to the buffer.
func (b *http2dataBuffer) Write(p []byte) (int, error) {
ntotal := len(p)
for len(p) > 0 {
// If the last chunk is full (or there is no chunk yet), allocate a new
// chunk. Try to allocate enough to fully copy p plus any additional
// bytes we expect to receive. However, this may allocate less than len(p).
want := int64(len(p))
if b.expected > want {
want = b.expected
}
chunk := b.lastChunkOrAlloc(want)
n := copy(chunk[b.w:], p)
p = p[n:]
b.w += n
b.size += n
b.expected -= int64(n)
}
return ntotal, nil
}
func (b *http2dataBuffer) lastChunkOrAlloc(want int64) []byte {
if len(b.chunks) != 0 {
last := b.chunks[len(b.chunks)-1]
if b.w < len(last) {
return last
}
}
chunk := http2getDataBufferChunk(want)
b.chunks = append(b.chunks, chunk)
b.w = 0
return chunk
}
// An ErrCode is an unsigned 32-bit error code as defined in the HTTP/2 spec.
type http2ErrCode uint32
const (
http2ErrCodeNo http2ErrCode = 0x0
http2ErrCodeProtocol http2ErrCode = 0x1
http2ErrCodeInternal http2ErrCode = 0x2
http2ErrCodeFlowControl http2ErrCode = 0x3
http2ErrCodeSettingsTimeout http2ErrCode = 0x4
http2ErrCodeStreamClosed http2ErrCode = 0x5
http2ErrCodeFrameSize http2ErrCode = 0x6
http2ErrCodeRefusedStream http2ErrCode = 0x7
http2ErrCodeCancel http2ErrCode = 0x8
http2ErrCodeCompression http2ErrCode = 0x9
http2ErrCodeConnect http2ErrCode = 0xa
http2ErrCodeEnhanceYourCalm http2ErrCode = 0xb
http2ErrCodeInadequateSecurity http2ErrCode = 0xc
http2ErrCodeHTTP11Required http2ErrCode = 0xd
)
var http2errCodeName = map[http2ErrCode]string{
http2ErrCodeNo: "NO_ERROR",
http2ErrCodeProtocol: "PROTOCOL_ERROR",
http2ErrCodeInternal: "INTERNAL_ERROR",
http2ErrCodeFlowControl: "FLOW_CONTROL_ERROR",
http2ErrCodeSettingsTimeout: "SETTINGS_TIMEOUT",
http2ErrCodeStreamClosed: "STREAM_CLOSED",
http2ErrCodeFrameSize: "FRAME_SIZE_ERROR",
http2ErrCodeRefusedStream: "REFUSED_STREAM",
http2ErrCodeCancel: "CANCEL",
http2ErrCodeCompression: "COMPRESSION_ERROR",
http2ErrCodeConnect: "CONNECT_ERROR",
http2ErrCodeEnhanceYourCalm: "ENHANCE_YOUR_CALM",
http2ErrCodeInadequateSecurity: "INADEQUATE_SECURITY",
http2ErrCodeHTTP11Required: "HTTP_1_1_REQUIRED",
}
func (e http2ErrCode) String() string {
if s, ok := http2errCodeName[e]; ok {
return s
}
return fmt.Sprintf("unknown error code 0x%x", uint32(e))
}
func (e http2ErrCode) stringToken() string {
if s, ok := http2errCodeName[e]; ok {
return s
}
return fmt.Sprintf("ERR_UNKNOWN_%d", uint32(e))
}
// ConnectionError is an error that results in the termination of the
// entire connection.
type http2ConnectionError http2ErrCode
func (e http2ConnectionError) Error() string {
return fmt.Sprintf("connection error: %s", http2ErrCode(e))
}
// StreamError is an error that only affects one stream within an
// HTTP/2 connection.
type http2StreamError struct {
StreamID uint32
Code http2ErrCode
Cause error // optional additional detail
}
// errFromPeer is a sentinel error value for StreamError.Cause to
// indicate that the StreamError was sent from the peer over the wire
// and wasn't locally generated in the Transport.
var http2errFromPeer = errors.New("received from peer")
func http2streamError(id uint32, code http2ErrCode) http2StreamError {
return http2StreamError{StreamID: id, Code: code}
}
func (e http2StreamError) Error() string {
if e.Cause != nil {
return fmt.Sprintf("stream error: stream ID %d; %v; %v", e.StreamID, e.Code, e.Cause)
}
return fmt.Sprintf("stream error: stream ID %d; %v", e.StreamID, e.Code)
}
// 6.9.1 The Flow Control Window
// "If a sender receives a WINDOW_UPDATE that causes a flow control
// window to exceed this maximum it MUST terminate either the stream
// or the connection, as appropriate. For streams, [...]; for the
// connection, a GOAWAY frame with a FLOW_CONTROL_ERROR code."
type http2goAwayFlowError struct{}
func (http2goAwayFlowError) Error() string { return "connection exceeded flow control window size" }
// connError represents an HTTP/2 ConnectionError error code, along
// with a string (for debugging) explaining why.
//
// Errors of this type are only returned by the frame parser functions
// and converted into ConnectionError(Code), after stashing away
// the Reason into the Framer's errDetail field, accessible via
// the (*Framer).ErrorDetail method.
type http2connError struct {
Code http2ErrCode // the ConnectionError error code
Reason string // additional reason
}
func (e http2connError) Error() string {
return fmt.Sprintf("http2: connection error: %v: %v", e.Code, e.Reason)
}
type http2pseudoHeaderError string
func (e http2pseudoHeaderError) Error() string {
return fmt.Sprintf("invalid pseudo-header %q", string(e))
}
type http2duplicatePseudoHeaderError string
func (e http2duplicatePseudoHeaderError) Error() string {
return fmt.Sprintf("duplicate pseudo-header %q", string(e))
}
type http2headerFieldNameError string
func (e http2headerFieldNameError) Error() string {
return fmt.Sprintf("invalid header field name %q", string(e))
}
type http2headerFieldValueError string
func (e http2headerFieldValueError) Error() string {
return fmt.Sprintf("invalid header field value for %q", string(e))
}
var (
http2errMixPseudoHeaderTypes = errors.New("mix of request and response pseudo headers")
http2errPseudoAfterRegular = errors.New("pseudo header field after regular")
)
// inflowMinRefresh is the minimum number of bytes we'll send for a
// flow control window update.
const http2inflowMinRefresh = 4 << 10
// inflow accounts for an inbound flow control window.
// It tracks both the latest window sent to the peer (used for enforcement)
// and the accumulated unsent window.
type http2inflow struct {
avail int32
unsent int32
}
// init sets the initial window.
func (f *http2inflow) init(n int32) {
f.avail = n
}
// add adds n bytes to the window,
// indicating that the peer can now send us more data.
// For example, the user read from a {Request,Response} body and consumed
// some of the buffered data, so the peer can now send more.
// It returns the number of bytes to send in a WINDOW_UPDATE frame to the peer.
// Window updates are accumulated and sent when the unsent capacity
// is at least inflowMinRefresh or will at least double the peer's available window.
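// As a worked example (illustrative): after init(16384) and take(12288)
// the window is 4096. A first add(2048) only buffers, since 2048 is below
// both inflowMinRefresh and the remaining window, and returns 0. A second
// add(2048) brings unsent to 4096, meeting the refresh threshold, so it
// returns 4096 for a single WINDOW_UPDATE covering both reads.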
func (f *http2inflow) add(n int) (connAdd int32) {
if n < 0 {
panic("negative update")
}
unsent := int64(f.unsent) + int64(n)
// "A sender MUST NOT allow a flow-control window to exceed 2^31-1 octets."
// RFC 7540 Section 6.9.1.
const maxWindow = 1<<31 - 1
if unsent+int64(f.avail) > maxWindow {
panic("flow control update exceeds maximum window size")
}
f.unsent = int32(unsent)
if f.unsent < http2inflowMinRefresh && f.unsent < f.avail {
// If there aren't at least inflowMinRefresh bytes of window to send,
// and this update won't at least double the window, buffer the update for later.
return 0
}
f.avail += f.unsent
f.unsent = 0
return int32(unsent)
}
// take attempts to take n bytes from the peer's flow control window.
// It reports whether the window has available capacity.
func (f *http2inflow) take(n uint32) bool {
if n > uint32(f.avail) {
return false
}
f.avail -= int32(n)
return true
}
// takeInflows attempts to take n bytes from two inflows,
// typically connection-level and stream-level flows.
// It reports whether both windows have available capacity.
func http2takeInflows(f1, f2 *http2inflow, n uint32) bool {
if n > uint32(f1.avail) || n > uint32(f2.avail) {
return false
}
f1.avail -= int32(n)
f2.avail -= int32(n)
return true
}
// outflow is the outbound flow control window's size.
type http2outflow struct {
_ http2incomparable
// n is the number of DATA bytes we're allowed to send.
// An outflow is kept both on a conn and on each per-stream flow.
n int32
// conn points to the shared connection-level outflow that is
// shared by all streams on that conn. It is nil for the outflow
// that's on the conn directly.
conn *http2outflow
}
func (f *http2outflow) setConnFlow(cf *http2outflow) { f.conn = cf }
func (f *http2outflow) available() int32 {
n := f.n
if f.conn != nil && f.conn.n < n {
n = f.conn.n
}
return n
}
func (f *http2outflow) take(n int32) {
if n > f.available() {
panic("internal error: took too much")
}
f.n -= n
if f.conn != nil {
f.conn.n -= n
}
}
// add adds n bytes (positive or negative) to the flow control window.
// It returns false if the sum would exceed 2^31-1.
func (f *http2outflow) add(n int32) bool {
sum := f.n + n
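// Two's-complement overflow check: when f.n > 0, the wrapping sum
// exceeds n exactly when the addition did not overflow; when f.n <= 0,
// the sum is at most n unless the addition underflowed. So the
// comparison below holds iff the addition was safe.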
if (sum > n) == (f.n > 0) {
f.n = sum
return true
}
return false
}
const http2frameHeaderLen = 9
var http2padZeros = make([]byte, 255) // zeros for padding
// A FrameType is a registered frame type as defined in
// https://httpwg.org/specs/rfc7540.html#rfc.section.11.2
type http2FrameType uint8
const (
http2FrameData http2FrameType = 0x0
http2FrameHeaders http2FrameType = 0x1
http2FramePriority http2FrameType = 0x2
http2FrameRSTStream http2FrameType = 0x3
http2FrameSettings http2FrameType = 0x4
http2FramePushPromise http2FrameType = 0x5
http2FramePing http2FrameType = 0x6
http2FrameGoAway http2FrameType = 0x7
http2FrameWindowUpdate http2FrameType = 0x8
http2FrameContinuation http2FrameType = 0x9
)
var http2frameName = map[http2FrameType]string{
http2FrameData: "DATA",
http2FrameHeaders: "HEADERS",
http2FramePriority: "PRIORITY",
http2FrameRSTStream: "RST_STREAM",
http2FrameSettings: "SETTINGS",
http2FramePushPromise: "PUSH_PROMISE",
http2FramePing: "PING",
http2FrameGoAway: "GOAWAY",
http2FrameWindowUpdate: "WINDOW_UPDATE",
http2FrameContinuation: "CONTINUATION",
}
func (t http2FrameType) String() string {
if s, ok := http2frameName[t]; ok {
return s
}
return fmt.Sprintf("UNKNOWN_FRAME_TYPE_%d", uint8(t))
}
// Flags is a bitmask of HTTP/2 flags.
// The meaning of flags varies depending on the frame type.
type http2Flags uint8
// Has reports whether f contains all (0 or more) flags in v.
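// For example, f.Has(http2FlagHeadersEndStream|http2FlagHeadersPadded)
// reports whether both bits are set in f, and f.Has(0) is always true.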
func (f http2Flags) Has(v http2Flags) bool {
return (f & v) == v
}
// Frame-specific FrameHeader flag bits.
const (
// Data Frame
http2FlagDataEndStream http2Flags = 0x1
http2FlagDataPadded http2Flags = 0x8
// Headers Frame
http2FlagHeadersEndStream http2Flags = 0x1
http2FlagHeadersEndHeaders http2Flags = 0x4
http2FlagHeadersPadded http2Flags = 0x8
http2FlagHeadersPriority http2Flags = 0x20
// Settings Frame
http2FlagSettingsAck http2Flags = 0x1
// Ping Frame
http2FlagPingAck http2Flags = 0x1
// Continuation Frame
http2FlagContinuationEndHeaders http2Flags = 0x4
http2FlagPushPromiseEndHeaders http2Flags = 0x4
http2FlagPushPromisePadded http2Flags = 0x8
)
var http2flagName = map[http2FrameType]map[http2Flags]string{
http2FrameData: {
http2FlagDataEndStream: "END_STREAM",
http2FlagDataPadded: "PADDED",
},
http2FrameHeaders: {
http2FlagHeadersEndStream: "END_STREAM",
http2FlagHeadersEndHeaders: "END_HEADERS",
http2FlagHeadersPadded: "PADDED",
http2FlagHeadersPriority: "PRIORITY",
},
http2FrameSettings: {
http2FlagSettingsAck: "ACK",
},
http2FramePing: {
http2FlagPingAck: "ACK",
},
http2FrameContinuation: {
http2FlagContinuationEndHeaders: "END_HEADERS",
},
http2FramePushPromise: {
http2FlagPushPromiseEndHeaders: "END_HEADERS",
http2FlagPushPromisePadded: "PADDED",
},
}
// a frameParser parses a frame given its FrameHeader and payload
// bytes. The length of payload will always equal fh.Length (which
// might be 0).
type http2frameParser func(fc *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error)
var http2frameParsers = map[http2FrameType]http2frameParser{
http2FrameData: http2parseDataFrame,
http2FrameHeaders: http2parseHeadersFrame,
http2FramePriority: http2parsePriorityFrame,
http2FrameRSTStream: http2parseRSTStreamFrame,
http2FrameSettings: http2parseSettingsFrame,
http2FramePushPromise: http2parsePushPromise,
http2FramePing: http2parsePingFrame,
http2FrameGoAway: http2parseGoAwayFrame,
http2FrameWindowUpdate: http2parseWindowUpdateFrame,
http2FrameContinuation: http2parseContinuationFrame,
}
func http2typeFrameParser(t http2FrameType) http2frameParser {
if f := http2frameParsers[t]; f != nil {
return f
}
return http2parseUnknownFrame
}
// A FrameHeader is the 9 byte header of all HTTP/2 frames.
//
// See https://httpwg.org/specs/rfc7540.html#FrameHeader
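// The 9 bytes follow the layout from RFC 7540, Section 4.1:
//
//	+-----------------------------------------------+
//	|                 Length (24)                   |
//	+---------------+---------------+---------------+
//	|   Type (8)    |   Flags (8)   |
//	+-+-------------+---------------+-------------------------------+
//	|R|                 Stream Identifier (31)                      |
//	+=+=============================================================+
//
// The R bit is reserved, which is why readers mask StreamID with 1<<31 - 1.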
type http2FrameHeader struct {
valid bool // caller can access []byte fields in the Frame
// Type is the 1 byte frame type. There are ten standard frame
// types, but extension frame types may be written by WriteRawFrame
// and will be returned by ReadFrame (as UnknownFrame).
Type http2FrameType
// Flags are the 1 byte of 8 potential bit flags per frame.
// They are specific to the frame type.
Flags http2Flags
// Length is the length of the frame, not including the 9 byte header.
// The maximum size is one byte less than 16MB (uint24), but only
// frames up to 16KB are allowed without peer agreement.
Length uint32
// StreamID is which stream this frame is for. Certain frames
// are not stream-specific, in which case this field is 0.
StreamID uint32
}
// Header returns h. It exists so FrameHeaders can be embedded in other
// specific frame types and implement the Frame interface.
func (h http2FrameHeader) Header() http2FrameHeader { return h }
func (h http2FrameHeader) String() string {
var buf bytes.Buffer
buf.WriteString("[FrameHeader ")
h.writeDebug(&buf)
buf.WriteByte(']')
return buf.String()
}
func (h http2FrameHeader) writeDebug(buf *bytes.Buffer) {
buf.WriteString(h.Type.String())
if h.Flags != 0 {
buf.WriteString(" flags=")
set := 0
for i := uint8(0); i < 8; i++ {
if h.Flags&(1<<i) == 0 {
continue
}
set++
if set > 1 {
buf.WriteByte('|')
}
name := http2flagName[h.Type][http2Flags(1<<i)]
if name != "" {
buf.WriteString(name)
} else {
fmt.Fprintf(buf, "0x%x", 1<<i)
}
}
}
if h.StreamID != 0 {
fmt.Fprintf(buf, " stream=%d", h.StreamID)
}
fmt.Fprintf(buf, " len=%d", h.Length)
}
func (h *http2FrameHeader) checkValid() {
if !h.valid {
panic("Frame accessor called on non-owned Frame")
}
}
func (h *http2FrameHeader) invalidate() { h.valid = false }
// http2fhBytes pools 9-byte frame header buffers.
// Used only by ReadFrameHeader.
var http2fhBytes = sync.Pool{
New: func() interface{} {
buf := make([]byte, http2frameHeaderLen)
return &buf
},
}
// ReadFrameHeader reads 9 bytes from r and returns a FrameHeader.
// Most users should use Framer.ReadFrame instead.
func http2ReadFrameHeader(r io.Reader) (http2FrameHeader, error) {
bufp := http2fhBytes.Get().(*[]byte)
defer http2fhBytes.Put(bufp)
return http2readFrameHeader(*bufp, r)
}
func http2readFrameHeader(buf []byte, r io.Reader) (http2FrameHeader, error) {
_, err := io.ReadFull(r, buf[:http2frameHeaderLen])
if err != nil {
return http2FrameHeader{}, err
}
return http2FrameHeader{
Length: (uint32(buf[0])<<16 | uint32(buf[1])<<8 | uint32(buf[2])),
Type: http2FrameType(buf[3]),
Flags: http2Flags(buf[4]),
StreamID: binary.BigEndian.Uint32(buf[5:]) & (1<<31 - 1),
valid: true,
}, nil
}
// A Frame is the base interface implemented by all frame types.
// Callers will generally type-assert the specific frame type:
// *HeadersFrame, *SettingsFrame, *WindowUpdateFrame, etc.
//
// Frames are only valid until the next call to Framer.ReadFrame.
type http2Frame interface {
Header() http2FrameHeader
// invalidate is called by Framer.ReadFrame to mark this
// frame's buffers as invalid, since the subsequent
// frame will reuse them.
invalidate()
}
// A Framer reads and writes Frames.
type http2Framer struct {
r io.Reader
lastFrame http2Frame
errDetail error
// countError is a non-nil func that's called on a frame parse
// error with some unique error path token. It's initialized
// from Transport.CountError or Server.CountError.
countError func(errToken string)
// lastHeaderStream is non-zero if the last frame was an
// unfinished HEADERS/CONTINUATION.
lastHeaderStream uint32
maxReadSize uint32
headerBuf [http2frameHeaderLen]byte
// TODO: let getReadBuf be configurable, and use a less memory-pinning
// allocator in server.go to minimize memory pinned for many idle conns.
// Will probably also need to make frame invalidation have a hook too.
getReadBuf func(size uint32) []byte
readBuf []byte // cache for default getReadBuf
maxWriteSize uint32 // zero means unlimited; TODO: implement
w io.Writer
wbuf []byte
// AllowIllegalWrites permits the Framer's Write methods to
// write frames that do not conform to the HTTP/2 spec. This
// permits using the Framer to test other HTTP/2
// implementations' conformance to the spec.
// If false, the Write methods will prefer to return an error
// rather than comply.
AllowIllegalWrites bool
// AllowIllegalReads permits the Framer's ReadFrame method
// to return non-compliant frames or frame orders.
// This is for testing and permits using the Framer to test
// other HTTP/2 implementations' conformance to the spec.
// It is not compatible with ReadMetaHeaders.
AllowIllegalReads bool
// ReadMetaHeaders if non-nil causes ReadFrame to merge
// HEADERS and CONTINUATION frames together and return
// MetaHeadersFrame instead.
ReadMetaHeaders *hpack.Decoder
// MaxHeaderListSize is the http2 MAX_HEADER_LIST_SIZE.
// It's used only if ReadMetaHeaders is set; 0 means a sane default
// (currently 16MB).
// If the limit is hit, MetaHeadersFrame.Truncated is set true.
MaxHeaderListSize uint32
// TODO: track which type of frame & with which flags was sent
// last. Then return an error (unless AllowIllegalWrites) if
// we're in the middle of a header block and a
// non-Continuation or Continuation on a different stream is
// attempted to be written.
logReads, logWrites bool
debugFramer *http2Framer // only use for logging written writes
debugFramerBuf *bytes.Buffer
debugReadLoggerf func(string, ...interface{})
debugWriteLoggerf func(string, ...interface{})
frameCache *http2frameCache // nil if frames aren't reused (default)
}
func (fr *http2Framer) maxHeaderListSize() uint32 {
if fr.MaxHeaderListSize == 0 {
return 16 << 20 // sane default, per docs
}
return fr.MaxHeaderListSize
}
func (f *http2Framer) startWrite(ftype http2FrameType, flags http2Flags, streamID uint32) {
// Write the FrameHeader.
f.wbuf = append(f.wbuf[:0],
0, // 3 bytes of length, filled in by endWrite
0,
0,
byte(ftype),
byte(flags),
byte(streamID>>24),
byte(streamID>>16),
byte(streamID>>8),
byte(streamID))
}
func (f *http2Framer) endWrite() error {
// Now that we know the final size, fill in the FrameHeader in
// the space previously reserved for it. Abuse append: appending to
// f.wbuf[:0] rewrites the three reserved length bytes in place
// without growing or moving the slice.
length := len(f.wbuf) - http2frameHeaderLen
if length >= (1 << 24) {
return http2ErrFrameTooLarge
}
_ = append(f.wbuf[:0],
byte(length>>16),
byte(length>>8),
byte(length))
if f.logWrites {
f.logWrite()
}
n, err := f.w.Write(f.wbuf)
if err == nil && n != len(f.wbuf) {
err = io.ErrShortWrite
}
return err
}
func (f *http2Framer) logWrite() {
if f.debugFramer == nil {
f.debugFramerBuf = new(bytes.Buffer)
f.debugFramer = http2NewFramer(nil, f.debugFramerBuf)
f.debugFramer.logReads = false // we log it ourselves, saying "wrote" below
// Let us read anything, even if we accidentally wrote it
// in the wrong order:
f.debugFramer.AllowIllegalReads = true
}
f.debugFramerBuf.Write(f.wbuf)
fr, err := f.debugFramer.ReadFrame()
if err != nil {
f.debugWriteLoggerf("http2: Framer %p: failed to decode just-written frame", f)
return
}
f.debugWriteLoggerf("http2: Framer %p: wrote %v", f, http2summarizeFrame(fr))
}
func (f *http2Framer) writeByte(v byte) { f.wbuf = append(f.wbuf, v) }
func (f *http2Framer) writeBytes(v []byte) { f.wbuf = append(f.wbuf, v...) }
func (f *http2Framer) writeUint16(v uint16) { f.wbuf = append(f.wbuf, byte(v>>8), byte(v)) }
func (f *http2Framer) writeUint32(v uint32) {
f.wbuf = append(f.wbuf, byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
}
const (
http2minMaxFrameSize = 1 << 14
http2maxFrameSize = 1<<24 - 1
)
// SetReuseFrames allows the Framer to reuse Frames.
// If called on a Framer, Frames returned by calls to ReadFrame are only
// valid until the next call to ReadFrame.
func (fr *http2Framer) SetReuseFrames() {
if fr.frameCache != nil {
return
}
fr.frameCache = &http2frameCache{}
}
type http2frameCache struct {
dataFrame http2DataFrame
}
func (fc *http2frameCache) getDataFrame() *http2DataFrame {
if fc == nil {
return &http2DataFrame{}
}
return &fc.dataFrame
}
// NewFramer returns a Framer that writes frames to w and reads them from r.
func http2NewFramer(w io.Writer, r io.Reader) *http2Framer {
fr := &http2Framer{
w: w,
r: r,
countError: func(string) {},
logReads: http2logFrameReads,
logWrites: http2logFrameWrites,
debugReadLoggerf: log.Printf,
debugWriteLoggerf: log.Printf,
}
fr.getReadBuf = func(size uint32) []byte {
if cap(fr.readBuf) >= int(size) {
return fr.readBuf[:size]
}
fr.readBuf = make([]byte, size)
return fr.readBuf
}
fr.SetMaxReadFrameSize(http2maxFrameSize)
return fr
}
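// A minimal usage sketch, assuming conn is a net.Conn already speaking
// HTTP/2 (client preface sent); not a definitive handshake implementation:
//
//	fr := http2NewFramer(conn, conn) // write to conn, read from conn
//	if err := fr.WriteSettings(); err != nil {
//		return err
//	}
//	f, err := fr.ReadFrame() // the peer's SETTINGS frame should arrive first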
// SetMaxReadFrameSize sets the maximum size of a frame
// that will be read by a subsequent call to ReadFrame.
// It is the caller's responsibility to advertise this
// limit with a SETTINGS frame.
func (fr *http2Framer) SetMaxReadFrameSize(v uint32) {
if v > http2maxFrameSize {
v = http2maxFrameSize
}
fr.maxReadSize = v
}
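// Sketch of keeping the read limit and the advertised limit in step
// (assumption: the peer honors SETTINGS_MAX_FRAME_SIZE):
//
//	const maxFrame = 1 << 20
//	fr.SetMaxReadFrameSize(maxFrame)
//	err := fr.WriteSettings(http2Setting{
//		ID:  http2SettingMaxFrameSize,
//		Val: maxFrame,
//	})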
// ErrorDetail returns a more detailed error for the last error
// returned by Framer.ReadFrame. For instance, if ReadFrame
// returns a StreamError with code PROTOCOL_ERROR, ErrorDetail
// will say exactly what was invalid. ErrorDetail is not guaranteed
// to return a non-nil value and, like the rest of the http2 package,
// its return value is not protected by an API compatibility promise.
// ErrorDetail is reset by the next call to ReadFrame.
func (fr *http2Framer) ErrorDetail() error {
return fr.errDetail
}
// ErrFrameTooLarge is returned from Framer.ReadFrame when the peer
// sends a frame that is larger than declared with SetMaxReadFrameSize.
var http2ErrFrameTooLarge = errors.New("http2: frame too large")
// terminalReadFrameError reports whether err is an unrecoverable
// error from ReadFrame and no other frames should be read.
func http2terminalReadFrameError(err error) bool {
if _, ok := err.(http2StreamError); ok {
return false
}
return err != nil
}
// ReadFrame reads a single frame. The returned Frame is only valid
// until the next call to ReadFrame.
//
// If the frame is larger than previously set with SetMaxReadFrameSize, the
// returned error is ErrFrameTooLarge. Other errors may be of type
// ConnectionError, StreamError, or anything else from the underlying
// reader.
func (fr *http2Framer) ReadFrame() (http2Frame, error) {
fr.errDetail = nil
if fr.lastFrame != nil {
fr.lastFrame.invalidate()
}
fh, err := http2readFrameHeader(fr.headerBuf[:], fr.r)
if err != nil {
return nil, err
}
if fh.Length > fr.maxReadSize {
return nil, http2ErrFrameTooLarge
}
payload := fr.getReadBuf(fh.Length)
if _, err := io.ReadFull(fr.r, payload); err != nil {
return nil, err
}
f, err := http2typeFrameParser(fh.Type)(fr.frameCache, fh, fr.countError, payload)
if err != nil {
if ce, ok := err.(http2connError); ok {
return nil, fr.connError(ce.Code, ce.Reason)
}
return nil, err
}
if err := fr.checkFrameOrder(f); err != nil {
return nil, err
}
if fr.logReads {
fr.debugReadLoggerf("http2: Framer %p: read %v", fr, http2summarizeFrame(f))
}
if fh.Type == http2FrameHeaders && fr.ReadMetaHeaders != nil {
return fr.readMetaFrame(f.(*http2HeadersFrame))
}
return f, nil
}
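// Sketch of a read loop that distinguishes recoverable per-stream errors
// from terminal ones; resetStream and process are hypothetical helpers:
//
//	for {
//		f, err := fr.ReadFrame()
//		if err != nil {
//			if http2terminalReadFrameError(err) {
//				return err // the connection is no longer usable
//			}
//			resetStream(err.(http2StreamError)) // stream-level; keep reading
//			continue
//		}
//		process(f)
//	}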
// connError returns ConnectionError(code) but first
// stashes away a public reason so the caller can optionally relay it
// to the peer before hanging up on them. This might help others debug
// their implementations.
func (fr *http2Framer) connError(code http2ErrCode, reason string) error {
fr.errDetail = errors.New(reason)
return http2ConnectionError(code)
}
// checkFrameOrder reports an error if f is an invalid frame to return
// next from ReadFrame. Mostly it checks whether HEADERS and
// CONTINUATION frames are contiguous.
func (fr *http2Framer) checkFrameOrder(f http2Frame) error {
last := fr.lastFrame
fr.lastFrame = f
if fr.AllowIllegalReads {
return nil
}
fh := f.Header()
if fr.lastHeaderStream != 0 {
if fh.Type != http2FrameContinuation {
return fr.connError(http2ErrCodeProtocol,
fmt.Sprintf("got %s for stream %d; expected CONTINUATION following %s for stream %d",
fh.Type, fh.StreamID,
last.Header().Type, fr.lastHeaderStream))
}
if fh.StreamID != fr.lastHeaderStream {
return fr.connError(http2ErrCodeProtocol,
fmt.Sprintf("got CONTINUATION for stream %d; expected stream %d",
fh.StreamID, fr.lastHeaderStream))
}
} else if fh.Type == http2FrameContinuation {
return fr.connError(http2ErrCodeProtocol, fmt.Sprintf("unexpected CONTINUATION for stream %d", fh.StreamID))
}
switch fh.Type {
case http2FrameHeaders, http2FrameContinuation:
if fh.Flags.Has(http2FlagHeadersEndHeaders) {
fr.lastHeaderStream = 0
} else {
fr.lastHeaderStream = fh.StreamID
}
}
return nil
}
// A DataFrame conveys arbitrary, variable-length sequences of octets
// associated with a stream.
// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.1
type http2DataFrame struct {
http2FrameHeader
data []byte
}
func (f *http2DataFrame) StreamEnded() bool {
return f.http2FrameHeader.Flags.Has(http2FlagDataEndStream)
}
// Data returns the frame's data octets, not including any padding
// size byte or padding suffix bytes.
// The caller must not retain the returned memory past the next
// call to ReadFrame.
func (f *http2DataFrame) Data() []byte {
f.checkValid()
return f.data
}
func http2parseDataFrame(fc *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) {
if fh.StreamID == 0 {
// DATA frames MUST be associated with a stream. If a
// DATA frame is received whose stream identifier
// field is 0x0, the recipient MUST respond with a
// connection error (Section 5.4.1) of type
// PROTOCOL_ERROR.
countError("frame_data_stream_0")
return nil, http2connError{http2ErrCodeProtocol, "DATA frame with stream ID 0"}
}
f := fc.getDataFrame()
f.http2FrameHeader = fh
var padSize byte
if fh.Flags.Has(http2FlagDataPadded) {
var err error
payload, padSize, err = http2readByte(payload)
if err != nil {
countError("frame_data_pad_byte_short")
return nil, err
}
}
if int(padSize) > len(payload) {
// If the length of the padding is greater than the
// length of the frame payload, the recipient MUST
// treat this as a connection error.
// Filed: https://github.com/http2/http2-spec/issues/610
countError("frame_data_pad_too_big")
return nil, http2connError{http2ErrCodeProtocol, "pad size larger than data payload"}
}
f.data = payload[:len(payload)-int(padSize)]
return f, nil
}
var (
http2errStreamID = errors.New("invalid stream ID")
http2errDepStreamID = errors.New("invalid dependent stream ID")
http2errPadLength = errors.New("pad length too large")
http2errPadBytes = errors.New("padding bytes must all be zeros unless AllowIllegalWrites is enabled")
)
func http2validStreamIDOrZero(streamID uint32) bool {
return streamID&(1<<31) == 0
}
func http2validStreamID(streamID uint32) bool {
return streamID != 0 && streamID&(1<<31) == 0
}
// WriteData writes a DATA frame.
//
// It will perform exactly one Write to the underlying Writer.
// It is the caller's responsibility not to violate the maximum frame size
// and to not call other Write methods concurrently.
func (f *http2Framer) WriteData(streamID uint32, endStream bool, data []byte) error {
return f.WriteDataPadded(streamID, endStream, data, nil)
}
// WriteDataPadded writes a DATA frame with optional padding.
//
// If pad is nil, the padding bit is not sent.
// The length of pad must not exceed 255 bytes.
// The bytes of pad must all be zero, unless f.AllowIllegalWrites is set.
//
// It will perform exactly one Write to the underlying Writer.
// It is the caller's responsibility not to violate the maximum frame size
// and to not call other Write methods concurrently.
func (f *http2Framer) WriteDataPadded(streamID uint32, endStream bool, data, pad []byte) error {
if err := f.startWriteDataPadded(streamID, endStream, data, pad); err != nil {
return err
}
return f.endWrite()
}
// startWriteDataPadded is WriteDataPadded, but only writes the frame to the Framer's internal buffer.
// The caller should call endWrite to flush the frame to the underlying writer.
func (f *http2Framer) startWriteDataPadded(streamID uint32, endStream bool, data, pad []byte) error {
if !http2validStreamID(streamID) && !f.AllowIllegalWrites {
return http2errStreamID
}
if len(pad) > 0 {
if len(pad) > 255 {
return http2errPadLength
}
if !f.AllowIllegalWrites {
for _, b := range pad {
if b != 0 {
// "Padding octets MUST be set to zero when sending."
return http2errPadBytes
}
}
}
}
var flags http2Flags
if endStream {
flags |= http2FlagDataEndStream
}
if pad != nil {
flags |= http2FlagDataPadded
}
f.startWrite(http2FrameData, flags, streamID)
if pad != nil {
f.wbuf = append(f.wbuf, byte(len(pad)))
}
f.wbuf = append(f.wbuf, data...)
f.wbuf = append(f.wbuf, pad...)
return nil
}
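// Wire-layout sketch: WriteDataPadded(1, false, []byte("hi"), make([]byte, 3))
// sets the PADDED flag and emits a 6-byte payload: the pad length (0x03),
// the data ("hi"), then three zero padding bytes:
//
//	0x03 0x68 0x69 0x00 0x00 0x00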
// A SettingsFrame conveys configuration parameters that affect how
// endpoints communicate, such as preferences and constraints on peer
// behavior.
//
// See https://httpwg.org/specs/rfc7540.html#SETTINGS
type http2SettingsFrame struct {
http2FrameHeader
p []byte
}
func http2parseSettingsFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) {
if fh.Flags.Has(http2FlagSettingsAck) && fh.Length > 0 {
// When this (ACK 0x1) bit is set, the payload of the
// SETTINGS frame MUST be empty. Receipt of a
// SETTINGS frame with the ACK flag set and a length
// field value other than 0 MUST be treated as a
// connection error (Section 5.4.1) of type
// FRAME_SIZE_ERROR.
countError("frame_settings_ack_with_length")
return nil, http2ConnectionError(http2ErrCodeFrameSize)
}
if fh.StreamID != 0 {
// SETTINGS frames always apply to a connection,
// never a single stream. The stream identifier for a
// SETTINGS frame MUST be zero (0x0). If an endpoint
// receives a SETTINGS frame whose stream identifier
// field is anything other than 0x0, the endpoint MUST
// respond with a connection error (Section 5.4.1) of
// type PROTOCOL_ERROR.
countError("frame_settings_has_stream")
return nil, http2ConnectionError(http2ErrCodeProtocol)
}
if len(p)%6 != 0 {
countError("frame_settings_mod_6")
// Expecting a whole number of 6-byte settings.
return nil, http2ConnectionError(http2ErrCodeFrameSize)
}
f := &http2SettingsFrame{http2FrameHeader: fh, p: p}
if v, ok := f.Value(http2SettingInitialWindowSize); ok && v > (1<<31)-1 {
countError("frame_settings_window_size_too_big")
// Values above the maximum flow control window size of 2^31 - 1 MUST
// be treated as a connection error (Section 5.4.1) of type
// FLOW_CONTROL_ERROR.
return nil, http2ConnectionError(http2ErrCodeFlowControl)
}
return f, nil
}
func (f *http2SettingsFrame) IsAck() bool {
return f.http2FrameHeader.Flags.Has(http2FlagSettingsAck)
}
func (f *http2SettingsFrame) Value(id http2SettingID) (v uint32, ok bool) {
f.checkValid()
for i := 0; i < f.NumSettings(); i++ {
if s := f.Setting(i); s.ID == id {
return s.Val, true
}
}
return 0, false
}
// Setting returns the setting from the frame at the given 0-based index.
// The index must be >= 0 and less than f.NumSettings().
func (f *http2SettingsFrame) Setting(i int) http2Setting {
buf := f.p
return http2Setting{
ID: http2SettingID(binary.BigEndian.Uint16(buf[i*6 : i*6+2])),
Val: binary.BigEndian.Uint32(buf[i*6+2 : i*6+6]),
}
}
func (f *http2SettingsFrame) NumSettings() int { return len(f.p) / 6 }
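// Worked example: a SETTINGS payload of
//
//	0x00 0x04 0x00 0x00 0xff 0xff
//
// holds exactly one setting (NumSettings() == 1): the first two bytes give
// ID 0x4 (INITIAL_WINDOW_SIZE) and the next four give the value 65535.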
// HasDuplicates reports whether f contains any duplicate setting IDs.
func (f *http2SettingsFrame) HasDuplicates() bool {
num := f.NumSettings()
if num == 0 {
return false
}
// If it's small enough (the common case), just do the n^2
// thing and avoid a map allocation.
if num < 10 {
for i := 0; i < num; i++ {
idi := f.Setting(i).ID
for j := i + 1; j < num; j++ {
idj := f.Setting(j).ID
if idi == idj {
return true
}
}
}
return false
}
seen := map[http2SettingID]bool{}
for i := 0; i < num; i++ {
id := f.Setting(i).ID
if seen[id] {
return true
}
seen[id] = true
}
return false
}
// ForeachSetting runs fn for each setting.
// It stops and returns the first error.
func (f *http2SettingsFrame) ForeachSetting(fn func(http2Setting) error) error {
f.checkValid()
for i := 0; i < f.NumSettings(); i++ {
if err := fn(f.Setting(i)); err != nil {
return err
}
}
return nil
}
// WriteSettings writes a SETTINGS frame with zero or more settings
// specified and the ACK bit not set.
//
// It will perform exactly one Write to the underlying Writer.
// It is the caller's responsibility to not call other Write methods concurrently.
func (f *http2Framer) WriteSettings(settings ...http2Setting) error {
f.startWrite(http2FrameSettings, 0, 0)
for _, s := range settings {
f.writeUint16(uint16(s.ID))
f.writeUint32(s.Val)
}
return f.endWrite()
}
// WriteSettingsAck writes an empty SETTINGS frame with the ACK bit set.
//
// It will perform exactly one Write to the underlying Writer.
// It is the caller's responsibility to not call other Write methods concurrently.
func (f *http2Framer) WriteSettingsAck() error {
f.startWrite(http2FrameSettings, http2FlagSettingsAck, 0)
return f.endWrite()
}
// A PingFrame is a mechanism for measuring a minimal round trip time
// from the sender, as well as determining whether an idle connection
// is still functional.
// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.7
type http2PingFrame struct {
http2FrameHeader
Data [8]byte
}
func (f *http2PingFrame) IsAck() bool { return f.Flags.Has(http2FlagPingAck) }
func http2parsePingFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) {
if len(payload) != 8 {
countError("frame_ping_length")
return nil, http2ConnectionError(http2ErrCodeFrameSize)
}
if fh.StreamID != 0 {
countError("frame_ping_has_stream")
return nil, http2ConnectionError(http2ErrCodeProtocol)
}
f := &http2PingFrame{http2FrameHeader: fh}
copy(f.Data[:], payload)
return f, nil
}
func (f *http2Framer) WritePing(ack bool, data [8]byte) error {
var flags http2Flags
if ack {
flags = http2FlagPingAck
}
f.startWrite(http2FramePing, flags, 0)
f.writeBytes(data[:])
return f.endWrite()
}
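// RTT-measurement sketch: the opaque data is echoed back in the ACK, so a
// timestamp (or any nonce) can be matched up later. Names here are local
// to the sketch:
//
//	var data [8]byte
//	binary.BigEndian.PutUint64(data[:], uint64(time.Now().UnixNano()))
//	start := time.Now()
//	err := fr.WritePing(false, data)
//	// ...later, upon reading a *http2PingFrame pf where pf.IsAck() and
//	// pf.Data == data, the round trip time is time.Since(start).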
// A GoAwayFrame informs the remote peer to stop creating streams on this connection.
// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.8
type http2GoAwayFrame struct {
http2FrameHeader
LastStreamID uint32
ErrCode http2ErrCode
debugData []byte
}
// DebugData returns any debug data in the GOAWAY frame. Its contents
// are not defined.
// The caller must not retain the returned memory past the next
// call to ReadFrame.
func (f *http2GoAwayFrame) DebugData() []byte {
f.checkValid()
return f.debugData
}
func http2parseGoAwayFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) {
if fh.StreamID != 0 {
countError("frame_goaway_has_stream")
return nil, http2ConnectionError(http2ErrCodeProtocol)
}
if len(p) < 8 {
countError("frame_goaway_short")
return nil, http2ConnectionError(http2ErrCodeFrameSize)
}
return &http2GoAwayFrame{
http2FrameHeader: fh,
LastStreamID: binary.BigEndian.Uint32(p[:4]) & (1<<31 - 1),
ErrCode: http2ErrCode(binary.BigEndian.Uint32(p[4:8])),
debugData: p[8:],
}, nil
}
func (f *http2Framer) WriteGoAway(maxStreamID uint32, code http2ErrCode, debugData []byte) error {
f.startWrite(http2FrameGoAway, 0, 0)
f.writeUint32(maxStreamID & (1<<31 - 1))
f.writeUint32(uint32(code))
f.writeBytes(debugData)
return f.endWrite()
}
// An UnknownFrame is the frame type returned when the frame type is unknown
// or no specific frame type parser exists.
type http2UnknownFrame struct {
http2FrameHeader
p []byte
}
// Payload returns the frame's payload (after the header). It is not
// valid to call this method after a subsequent call to
// Framer.ReadFrame, nor is it valid to retain the returned slice.
// The memory is owned by the Framer and is invalidated when the next
// frame is read.
func (f *http2UnknownFrame) Payload() []byte {
f.checkValid()
return f.p
}
func http2parseUnknownFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) {
return &http2UnknownFrame{fh, p}, nil
}
// A WindowUpdateFrame is used to implement flow control.
// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.9
type http2WindowUpdateFrame struct {
http2FrameHeader
Increment uint32 // never read with high bit set
}
func http2parseWindowUpdateFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) {
if len(p) != 4 {
countError("frame_windowupdate_bad_len")
return nil, http2ConnectionError(http2ErrCodeFrameSize)
}
inc := binary.BigEndian.Uint32(p[:4]) & 0x7fffffff // mask off high reserved bit
if inc == 0 {
// A receiver MUST treat the receipt of a
// WINDOW_UPDATE frame with a flow control window
// increment of 0 as a stream error (Section 5.4.2) of
// type PROTOCOL_ERROR; errors on the connection flow
// control window MUST be treated as a connection
// error (Section 5.4.1).
if fh.StreamID == 0 {
countError("frame_windowupdate_zero_inc_conn")
return nil, http2ConnectionError(http2ErrCodeProtocol)
}
countError("frame_windowupdate_zero_inc_stream")
return nil, http2streamError(fh.StreamID, http2ErrCodeProtocol)
}
return &http2WindowUpdateFrame{
http2FrameHeader: fh,
Increment: inc,
}, nil
}
// WriteWindowUpdate writes a WINDOW_UPDATE frame.
// The increment value must be between 1 and 2,147,483,647, inclusive.
// If the Stream ID is zero, the window update applies to the
// connection as a whole.
func (f *http2Framer) WriteWindowUpdate(streamID, incr uint32) error {
// "The legal range for the increment to the flow control window is 1 to 2^31-1 (2,147,483,647) octets."
if (incr < 1 || incr > 2147483647) && !f.AllowIllegalWrites {
return errors.New("illegal window increment value")
}
f.startWrite(http2FrameWindowUpdate, 0, streamID)
f.writeUint32(incr)
return f.endWrite()
}
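// Sketch of replenishing flow control after the application has consumed n
// bytes of DATA (assumption: the caller tracks consumption itself):
//
//	if err := fr.WriteWindowUpdate(streamID, uint32(n)); err != nil {
//		return err // per-stream window
//	}
//	if err := fr.WriteWindowUpdate(0, uint32(n)); err != nil {
//		return err // connection-level window (stream 0)
//	}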
// A HeadersFrame is used to open a stream and additionally carries a
// header block fragment.
type http2HeadersFrame struct {
http2FrameHeader
// Priority is set if FlagHeadersPriority is set in the FrameHeader.
Priority http2PriorityParam
headerFragBuf []byte // not owned
}
func (f *http2HeadersFrame) HeaderBlockFragment() []byte {
f.checkValid()
return f.headerFragBuf
}
func (f *http2HeadersFrame) HeadersEnded() bool {
return f.http2FrameHeader.Flags.Has(http2FlagHeadersEndHeaders)
}
func (f *http2HeadersFrame) StreamEnded() bool {
return f.http2FrameHeader.Flags.Has(http2FlagHeadersEndStream)
}
func (f *http2HeadersFrame) HasPriority() bool {
return f.http2FrameHeader.Flags.Has(http2FlagHeadersPriority)
}
func http2parseHeadersFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (_ http2Frame, err error) {
hf := &http2HeadersFrame{
http2FrameHeader: fh,
}
if fh.StreamID == 0 {
// HEADERS frames MUST be associated with a stream. If a HEADERS frame
// is received whose stream identifier field is 0x0, the recipient MUST
// respond with a connection error (Section 5.4.1) of type
// PROTOCOL_ERROR.
countError("frame_headers_zero_stream")
return nil, http2connError{http2ErrCodeProtocol, "HEADERS frame with stream ID 0"}
}
var padLength uint8
if fh.Flags.Has(http2FlagHeadersPadded) {
if p, padLength, err = http2readByte(p); err != nil {
countError("frame_headers_pad_short")
return
}
}
if fh.Flags.Has(http2FlagHeadersPriority) {
var v uint32
p, v, err = http2readUint32(p)
if err != nil {
countError("frame_headers_prio_short")
return nil, err
}
hf.Priority.StreamDep = v & 0x7fffffff
hf.Priority.Exclusive = (v != hf.Priority.StreamDep) // high bit was set
p, hf.Priority.Weight, err = http2readByte(p)
if err != nil {
countError("frame_headers_prio_weight_short")
return nil, err
}
}
if len(p)-int(padLength) < 0 {
countError("frame_headers_pad_too_big")
return nil, http2streamError(fh.StreamID, http2ErrCodeProtocol)
}
hf.headerFragBuf = p[:len(p)-int(padLength)]
return hf, nil
}
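// Priority-field sketch: the optional 5 bytes are a big-endian uint32 whose
// high bit is the exclusive flag and whose low 31 bits are the stream
// dependency, followed by a weight byte. For example, the bytes
//
//	0x80 0x00 0x00 0x03 0x0f
//
// decode to StreamDep 3, Exclusive true, and Weight 15 (an effective weight
// of 16, since the spec adds one).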
// HeadersFrameParam are the parameters for writing a HEADERS frame.
type http2HeadersFrameParam struct {
// StreamID is the required Stream ID to initiate.
StreamID uint32
// BlockFragment is part (or all) of a Header Block.
BlockFragment []byte
// EndStream indicates that the header block is the last that
// the endpoint will send for the identified stream. Setting
// this flag causes the stream to enter one of "half closed"
// states.
EndStream bool
// EndHeaders indicates that this frame contains an entire
// header block and is not followed by any
// CONTINUATION frames.
EndHeaders bool
// PadLength is the optional number of bytes of zeros to add
// to this frame.
PadLength uint8
// Priority, if non-zero, includes stream priority information
// in the HEADER frame.
Priority http2PriorityParam
}
// WriteHeaders writes a single HEADERS frame.
//
// This is a low-level header writing method. Encoding headers and
// splitting them into any necessary CONTINUATION frames is handled
// elsewhere.
//
// It will perform exactly one Write to the underlying Writer.
// It is the caller's responsibility to not call other Write methods concurrently.
func (f *http2Framer) WriteHeaders(p http2HeadersFrameParam) error {
if !http2validStreamID(p.StreamID) && !f.AllowIllegalWrites {
return http2errStreamID
}
var flags http2Flags
if p.PadLength != 0 {
flags |= http2FlagHeadersPadded
}
if p.EndStream {
flags |= http2FlagHeadersEndStream
}
if p.EndHeaders {
flags |= http2FlagHeadersEndHeaders
}
if !p.Priority.IsZero() {
flags |= http2FlagHeadersPriority
}
f.startWrite(http2FrameHeaders, flags, p.StreamID)
if p.PadLength != 0 {
f.writeByte(p.PadLength)
}
if !p.Priority.IsZero() {
v := p.Priority.StreamDep
if !http2validStreamIDOrZero(v) && !f.AllowIllegalWrites {
return http2errDepStreamID
}
if p.Priority.Exclusive {
v |= 1 << 31
}
f.writeUint32(v)
f.writeByte(p.Priority.Weight)
}
f.wbuf = append(f.wbuf, p.BlockFragment...)
f.wbuf = append(f.wbuf, http2padZeros[:p.PadLength]...)
return f.endWrite()
}
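// Sketch of producing a BlockFragment with an hpack encoder (using the
// hpack package this file already relies on); a real request would also
// need :scheme and :authority:
//
//	var buf bytes.Buffer
//	enc := hpack.NewEncoder(&buf)
//	enc.WriteField(hpack.HeaderField{Name: ":method", Value: "GET"})
//	enc.WriteField(hpack.HeaderField{Name: ":path", Value: "/"})
//	err := fr.WriteHeaders(http2HeadersFrameParam{
//		StreamID:      1,
//		BlockFragment: buf.Bytes(),
//		EndHeaders:    true,
//	})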
// A PriorityFrame specifies the sender-advised priority of a stream.
// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.3
type http2PriorityFrame struct {
http2FrameHeader
http2PriorityParam
}
// PriorityParam are the stream prioritization parameters.
type http2PriorityParam struct {
// StreamDep is a 31-bit stream identifier for the
// stream that this stream depends on. Zero means no
// dependency.
StreamDep uint32
// Exclusive is whether the dependency is exclusive.
Exclusive bool
// Weight is the stream's zero-indexed weight. It should be
// set together with StreamDep, or neither should be set. Per
// the spec, "Add one to the value to obtain a weight between
// 1 and 256."
Weight uint8
}
func (p http2PriorityParam) IsZero() bool {
return p == http2PriorityParam{}
}
func http2parsePriorityFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) {
if fh.StreamID == 0 {
countError("frame_priority_zero_stream")
return nil, http2connError{http2ErrCodeProtocol, "PRIORITY frame with stream ID 0"}
}
if len(payload) != 5 {
countError("frame_priority_bad_length")
return nil, http2connError{http2ErrCodeFrameSize, fmt.Sprintf("PRIORITY frame payload size was %d; want 5", len(payload))}
}
v := binary.BigEndian.Uint32(payload[:4])
streamID := v & 0x7fffffff // mask off high bit
return &http2PriorityFrame{
http2FrameHeader: fh,
http2PriorityParam: http2PriorityParam{
Weight: payload[4],
StreamDep: streamID,
Exclusive: streamID != v, // was high bit set?
},
}, nil
}
// WritePriority writes a PRIORITY frame.
//
// It will perform exactly one Write to the underlying Writer.
// It is the caller's responsibility to not call other Write methods concurrently.
func (f *http2Framer) WritePriority(streamID uint32, p http2PriorityParam) error {
if !http2validStreamID(streamID) && !f.AllowIllegalWrites {
return http2errStreamID
}
if !http2validStreamIDOrZero(p.StreamDep) {
return http2errDepStreamID
}
f.startWrite(http2FramePriority, 0, streamID)
v := p.StreamDep
if p.Exclusive {
v |= 1 << 31
}
f.writeUint32(v)
f.writeByte(p.Weight)
return f.endWrite()
}
// A RSTStreamFrame allows for abnormal termination of a stream.
// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.4
type http2RSTStreamFrame struct {
http2FrameHeader
ErrCode http2ErrCode
}
func http2parseRSTStreamFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) {
if len(p) != 4 {
countError("frame_rststream_bad_len")
return nil, http2ConnectionError(http2ErrCodeFrameSize)
}
if fh.StreamID == 0 {
countError("frame_rststream_zero_stream")
return nil, http2ConnectionError(http2ErrCodeProtocol)
}
return &http2RSTStreamFrame{fh, http2ErrCode(binary.BigEndian.Uint32(p[:4]))}, nil
}
// WriteRSTStream writes a RST_STREAM frame.
//
// It will perform exactly one Write to the underlying Writer.
// It is the caller's responsibility to not call other Write methods concurrently.
func (f *http2Framer) WriteRSTStream(streamID uint32, code http2ErrCode) error {
if !http2validStreamID(streamID) && !f.AllowIllegalWrites {
return http2errStreamID
}
f.startWrite(http2FrameRSTStream, 0, streamID)
f.writeUint32(uint32(code))
return f.endWrite()
}
// A ContinuationFrame is used to continue a sequence of header block fragments.
// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.10
type http2ContinuationFrame struct {
http2FrameHeader
headerFragBuf []byte
}
func http2parseContinuationFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) {
if fh.StreamID == 0 {
countError("frame_continuation_zero_stream")
return nil, http2connError{http2ErrCodeProtocol, "CONTINUATION frame with stream ID 0"}
}
return &http2ContinuationFrame{fh, p}, nil
}
func (f *http2ContinuationFrame) HeaderBlockFragment() []byte {
f.checkValid()
return f.headerFragBuf
}
func (f *http2ContinuationFrame) HeadersEnded() bool {
return f.http2FrameHeader.Flags.Has(http2FlagContinuationEndHeaders)
}
// WriteContinuation writes a CONTINUATION frame.
//
// It will perform exactly one Write to the underlying Writer.
// It is the caller's responsibility to not call other Write methods concurrently.
func (f *http2Framer) WriteContinuation(streamID uint32, endHeaders bool, headerBlockFragment []byte) error {
if !http2validStreamID(streamID) && !f.AllowIllegalWrites {
return http2errStreamID
}
var flags http2Flags
if endHeaders {
flags |= http2FlagContinuationEndHeaders
}
f.startWrite(http2FrameContinuation, flags, streamID)
f.wbuf = append(f.wbuf, headerBlockFragment...)
return f.endWrite()
}
// A PushPromiseFrame is used to initiate a server stream.
// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.6
type http2PushPromiseFrame struct {
http2FrameHeader
PromiseID uint32
headerFragBuf []byte // not owned
}
func (f *http2PushPromiseFrame) HeaderBlockFragment() []byte {
f.checkValid()
return f.headerFragBuf
}
func (f *http2PushPromiseFrame) HeadersEnded() bool {
return f.http2FrameHeader.Flags.Has(http2FlagPushPromiseEndHeaders)
}
func http2parsePushPromise(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (_ http2Frame, err error) {
pp := &http2PushPromiseFrame{
http2FrameHeader: fh,
}
if pp.StreamID == 0 {
// PUSH_PROMISE frames MUST be associated with an existing,
// peer-initiated stream. The stream identifier of a
// PUSH_PROMISE frame indicates the stream it is associated
// with. If the stream identifier field specifies the value
// 0x0, a recipient MUST respond with a connection error
// (Section 5.4.1) of type PROTOCOL_ERROR.
countError("frame_pushpromise_zero_stream")
return nil, http2ConnectionError(http2ErrCodeProtocol)
}
// The PUSH_PROMISE frame includes optional padding.
// Padding fields and flags are identical to those defined for DATA frames.
var padLength uint8
if fh.Flags.Has(http2FlagPushPromisePadded) {
if p, padLength, err = http2readByte(p); err != nil {
countError("frame_pushpromise_pad_short")
return
}
}
p, pp.PromiseID, err = http2readUint32(p)
if err != nil {
countError("frame_pushpromise_promiseid_short")
return
}
pp.PromiseID = pp.PromiseID & (1<<31 - 1)
if int(padLength) > len(p) {
// Like the DATA frame, error out if the padding is longer than the body.
countError("frame_pushpromise_pad_too_big")
return nil, http2ConnectionError(http2ErrCodeProtocol)
}
pp.headerFragBuf = p[:len(p)-int(padLength)]
return pp, nil
}
// PushPromiseParam are the parameters for writing a PUSH_PROMISE frame.
type http2PushPromiseParam struct {
// StreamID is the required Stream ID to initiate.
StreamID uint32
// PromiseID is the required Stream ID of the stream being
// promised (reserved) to the peer by this frame.
PromiseID uint32
// BlockFragment is part (or all) of a Header Block.
BlockFragment []byte
// EndHeaders indicates that this frame contains an entire
// header block and is not followed by any
// CONTINUATION frames.
EndHeaders bool
// PadLength is the optional number of bytes of zeros to add
// to this frame.
PadLength uint8
}
// WritePushPromise writes a single PUSH_PROMISE frame.
//
// As with HEADERS frames, this is the low-level call for writing
// individual frames. CONTINUATION frames are handled elsewhere.
//
// It will perform exactly one Write to the underlying Writer.
// It is the caller's responsibility to not call other Write methods concurrently.
func (f *http2Framer) WritePushPromise(p http2PushPromiseParam) error {
if !http2validStreamID(p.StreamID) && !f.AllowIllegalWrites {
return http2errStreamID
}
var flags http2Flags
if p.PadLength != 0 {
flags |= http2FlagPushPromisePadded
}
if p.EndHeaders {
flags |= http2FlagPushPromiseEndHeaders
}
f.startWrite(http2FramePushPromise, flags, p.StreamID)
if p.PadLength != 0 {
f.writeByte(p.PadLength)
}
if !http2validStreamID(p.PromiseID) && !f.AllowIllegalWrites {
return http2errStreamID
}
f.writeUint32(p.PromiseID)
f.wbuf = append(f.wbuf, p.BlockFragment...)
f.wbuf = append(f.wbuf, http2padZeros[:p.PadLength]...)
return f.endWrite()
}
// WriteRawFrame writes a raw frame. This can be used to write
// extension frames unknown to this package.
func (f *http2Framer) WriteRawFrame(t http2FrameType, flags http2Flags, streamID uint32, payload []byte) error {
f.startWrite(t, flags, streamID)
f.writeBytes(payload)
return f.endWrite()
}
func http2readByte(p []byte) (remain []byte, b byte, err error) {
if len(p) == 0 {
return nil, 0, io.ErrUnexpectedEOF
}
return p[1:], p[0], nil
}
func http2readUint32(p []byte) (remain []byte, v uint32, err error) {
if len(p) < 4 {
return nil, 0, io.ErrUnexpectedEOF
}
return p[4:], binary.BigEndian.Uint32(p[:4]), nil
}
type http2streamEnder interface {
StreamEnded() bool
}
type http2headersEnder interface {
HeadersEnded() bool
}
type http2headersOrContinuation interface {
http2headersEnder
HeaderBlockFragment() []byte
}
// A MetaHeadersFrame is the representation of one HEADERS frame and
// zero or more contiguous CONTINUATION frames and the decoding of
// their HPACK-encoded contents.
//
// This type of frame does not appear on the wire and is only returned
// by the Framer when Framer.ReadMetaHeaders is set.
type http2MetaHeadersFrame struct {
*http2HeadersFrame
// Fields are the fields contained in the HEADERS and
// CONTINUATION frames. The underlying slice is owned by the
// Framer and must not be retained after the next call to
// ReadFrame.
//
// Fields are guaranteed to be in the correct http2 order and
// not have unknown pseudo header fields or invalid header
// field names or values. Required pseudo header fields may be
// missing, however. Use the MetaHeadersFrame.PseudoValue accessor
// method to access pseudo header fields.
Fields []hpack.HeaderField
// Truncated is whether the max header list size limit was hit
// and Fields is incomplete. The hpack decoder state is still
// valid, however.
Truncated bool
}
// PseudoValue returns the given pseudo header field's value.
// The provided pseudo field should not contain the leading colon.
func (mh *http2MetaHeadersFrame) PseudoValue(pseudo string) string {
for _, hf := range mh.Fields {
if !hf.IsPseudo() {
return ""
}
if hf.Name[1:] == pseudo {
return hf.Value
}
}
return ""
}
// RegularFields returns the regular (non-pseudo) header fields of mh.
// The caller does not own the returned slice.
func (mh *http2MetaHeadersFrame) RegularFields() []hpack.HeaderField {
for i, hf := range mh.Fields {
if !hf.IsPseudo() {
return mh.Fields[i:]
}
}
return nil
}
// PseudoFields returns the pseudo header fields of mh.
// The caller does not own the returned slice.
func (mh *http2MetaHeadersFrame) PseudoFields() []hpack.HeaderField {
for i, hf := range mh.Fields {
if !hf.IsPseudo() {
return mh.Fields[:i]
}
}
return mh.Fields
}
func (mh *http2MetaHeadersFrame) checkPseudos() error {
var isRequest, isResponse bool
pf := mh.PseudoFields()
for i, hf := range pf {
switch hf.Name {
case ":method", ":path", ":scheme", ":authority":
isRequest = true
case ":status":
isResponse = true
default:
return http2pseudoHeaderError(hf.Name)
}
// Check for duplicates.
// This would be a bad algorithm, but N is 4.
// And this doesn't allocate.
for _, hf2 := range pf[:i] {
if hf.Name == hf2.Name {
return http2duplicatePseudoHeaderError(hf.Name)
}
}
}
if isRequest && isResponse {
return http2errMixPseudoHeaderTypes
}
return nil
}
func (fr *http2Framer) maxHeaderStringLen() int {
v := fr.maxHeaderListSize()
if uint32(int(v)) == v {
return int(v)
}
// They had a crazy big number for MaxHeaderListSize anyway,
// so give them unlimited header lengths:
return 0
}
// readMetaFrame reads zero or more CONTINUATION frames from fr,
// merges their header block fragments into the provided hf, and
// returns a MetaHeadersFrame with the decoded hpack values.
func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame) (*http2MetaHeadersFrame, error) {
if fr.AllowIllegalReads {
return nil, errors.New("illegal use of AllowIllegalReads with ReadMetaHeaders")
}
mh := &http2MetaHeadersFrame{
http2HeadersFrame: hf,
}
var remainSize = fr.maxHeaderListSize()
var sawRegular bool
var invalid error // pseudo header field errors
hdec := fr.ReadMetaHeaders
hdec.SetEmitEnabled(true)
hdec.SetMaxStringLength(fr.maxHeaderStringLen())
hdec.SetEmitFunc(func(hf hpack.HeaderField) {
if http2VerboseLogs && fr.logReads {
fr.debugReadLoggerf("http2: decoded hpack field %+v", hf)
}
if !httpguts.ValidHeaderFieldValue(hf.Value) {
// Don't include the value in the error, because it may be sensitive.
invalid = http2headerFieldValueError(hf.Name)
}
isPseudo := strings.HasPrefix(hf.Name, ":")
if isPseudo {
if sawRegular {
invalid = http2errPseudoAfterRegular
}
} else {
sawRegular = true
if !http2validWireHeaderFieldName(hf.Name) {
invalid = http2headerFieldNameError(hf.Name)
}
}
if invalid != nil {
hdec.SetEmitEnabled(false)
return
}
size := hf.Size()
if size > remainSize {
hdec.SetEmitEnabled(false)
mh.Truncated = true
return
}
remainSize -= size
mh.Fields = append(mh.Fields, hf)
})
// Lose reference to MetaHeadersFrame:
defer hdec.SetEmitFunc(func(hf hpack.HeaderField) {})
var hc http2headersOrContinuation = hf
for {
frag := hc.HeaderBlockFragment()
if _, err := hdec.Write(frag); err != nil {
return nil, http2ConnectionError(http2ErrCodeCompression)
}
if hc.HeadersEnded() {
break
}
if f, err := fr.ReadFrame(); err != nil {
return nil, err
} else {
hc = f.(*http2ContinuationFrame) // guaranteed by checkFrameOrder
}
}
mh.http2HeadersFrame.headerFragBuf = nil
mh.http2HeadersFrame.invalidate()
if err := hdec.Close(); err != nil {
return nil, http2ConnectionError(http2ErrCodeCompression)
}
if invalid != nil {
fr.errDetail = invalid
if http2VerboseLogs {
log.Printf("http2: invalid header: %v", invalid)
}
return nil, http2StreamError{mh.StreamID, http2ErrCodeProtocol, invalid}
}
if err := mh.checkPseudos(); err != nil {
fr.errDetail = err
if http2VerboseLogs {
log.Printf("http2: invalid pseudo headers: %v", err)
}
return nil, http2StreamError{mh.StreamID, http2ErrCodeProtocol, err}
}
return mh, nil
}
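// Sketch of enabling meta-frame decoding: set ReadMetaHeaders to an
// hpack.Decoder before calling ReadFrame. The table size here is the
// protocol default; adjust to the advertised SETTINGS value as needed.
//
//	fr.ReadMetaHeaders = hpack.NewDecoder(http2initialHeaderTableSize, nil)
//	f, err := fr.ReadFrame()
//	if mh, ok := f.(*http2MetaHeadersFrame); err == nil && ok {
//		method := mh.PseudoValue("method") // e.g. "GET"
//		_ = method
//	}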
func http2summarizeFrame(f http2Frame) string {
var buf bytes.Buffer
f.Header().writeDebug(&buf)
switch f := f.(type) {
case *http2SettingsFrame:
n := 0
f.ForeachSetting(func(s http2Setting) error {
n++
if n == 1 {
buf.WriteString(", settings:")
}
fmt.Fprintf(&buf, " %v=%v,", s.ID, s.Val)
return nil
})
if n > 0 {
buf.Truncate(buf.Len() - 1) // remove trailing comma
}
case *http2DataFrame:
data := f.Data()
const max = 256
if len(data) > max {
data = data[:max]
}
fmt.Fprintf(&buf, " data=%q", data)
if len(f.Data()) > max {
fmt.Fprintf(&buf, " (%d bytes omitted)", len(f.Data())-max)
}
case *http2WindowUpdateFrame:
if f.StreamID == 0 {
buf.WriteString(" (conn)")
}
fmt.Fprintf(&buf, " incr=%v", f.Increment)
case *http2PingFrame:
fmt.Fprintf(&buf, " ping=%q", f.Data[:])
case *http2GoAwayFrame:
fmt.Fprintf(&buf, " LastStreamID=%v ErrCode=%v Debug=%q",
f.LastStreamID, f.ErrCode, f.debugData)
case *http2RSTStreamFrame:
fmt.Fprintf(&buf, " ErrCode=%v", f.ErrCode)
}
return buf.String()
}
func http2traceHasWroteHeaderField(trace *httptrace.ClientTrace) bool {
return trace != nil && trace.WroteHeaderField != nil
}
func http2traceWroteHeaderField(trace *httptrace.ClientTrace, k, v string) {
if trace != nil && trace.WroteHeaderField != nil {
trace.WroteHeaderField(k, []string{v})
}
}
func http2traceGot1xxResponseFunc(trace *httptrace.ClientTrace) func(int, textproto.MIMEHeader) error {
if trace != nil {
return trace.Got1xxResponse
}
return nil
}
// dialTLSWithContext uses tls.Dialer, added in Go 1.15, to open a TLS
// connection.
func (t *http2Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (*tls.Conn, error) {
dialer := &tls.Dialer{
Config: cfg,
}
cn, err := dialer.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
tlsCn := cn.(*tls.Conn) // DialContext comment promises this will always succeed
return tlsCn, nil
}
func http2tlsUnderlyingConn(tc *tls.Conn) net.Conn {
return tc.NetConn()
}
var http2DebugGoroutines = os.Getenv("DEBUG_HTTP2_GOROUTINES") == "1"
type http2goroutineLock uint64
func http2newGoroutineLock() http2goroutineLock {
if !http2DebugGoroutines {
return 0
}
return http2goroutineLock(http2curGoroutineID())
}
func (g http2goroutineLock) check() {
if !http2DebugGoroutines {
return
}
if http2curGoroutineID() != uint64(g) {
panic("running on the wrong goroutine")
}
}
func (g http2goroutineLock) checkNotOn() {
if !http2DebugGoroutines {
return
}
if http2curGoroutineID() == uint64(g) {
panic("running on the wrong goroutine")
}
}
var http2goroutineSpace = []byte("goroutine ")
func http2curGoroutineID() uint64 {
bp := http2littleBuf.Get().(*[]byte)
defer http2littleBuf.Put(bp)
b := *bp
b = b[:runtime.Stack(b, false)]
// Parse the 4707 out of "goroutine 4707 ["
b = bytes.TrimPrefix(b, http2goroutineSpace)
i := bytes.IndexByte(b, ' ')
if i < 0 {
panic(fmt.Sprintf("No space found in %q", b))
}
b = b[:i]
n, err := http2parseUintBytes(b, 10, 64)
if err != nil {
panic(fmt.Sprintf("Failed to parse goroutine ID out of %q: %v", b, err))
}
return n
}
var http2littleBuf = sync.Pool{
New: func() interface{} {
buf := make([]byte, 64)
return &buf
},
}
// parseUintBytes is like strconv.ParseUint, but using a []byte.
func http2parseUintBytes(s []byte, base int, bitSize int) (n uint64, err error) {
var cutoff, maxVal uint64
if bitSize == 0 {
bitSize = int(strconv.IntSize)
}
s0 := s
switch {
case len(s) < 1:
err = strconv.ErrSyntax
goto Error
case 2 <= base && base <= 36:
// valid base; nothing to do
case base == 0:
// Look for octal, hex prefix.
switch {
case s[0] == '0' && len(s) > 1 && (s[1] == 'x' || s[1] == 'X'):
base = 16
s = s[2:]
if len(s) < 1 {
err = strconv.ErrSyntax
goto Error
}
case s[0] == '0':
base = 8
default:
base = 10
}
default:
err = errors.New("invalid base " + strconv.Itoa(base))
goto Error
}
n = 0
cutoff = http2cutoff64(base)
maxVal = 1<<uint(bitSize) - 1
for i := 0; i < len(s); i++ {
var v byte
d := s[i]
switch {
case '0' <= d && d <= '9':
v = d - '0'
case 'a' <= d && d <= 'z':
v = d - 'a' + 10
case 'A' <= d && d <= 'Z':
v = d - 'A' + 10
default:
n = 0
err = strconv.ErrSyntax
goto Error
}
if int(v) >= base {
n = 0
err = strconv.ErrSyntax
goto Error
}
if n >= cutoff {
// n*base overflows
n = 1<<64 - 1
err = strconv.ErrRange
goto Error
}
n *= uint64(base)
n1 := n + uint64(v)
if n1 < n || n1 > maxVal {
// n+v overflows
n = 1<<64 - 1
err = strconv.ErrRange
goto Error
}
n = n1
}
return n, nil
Error:
return n, &strconv.NumError{Func: "ParseUint", Num: string(s0), Err: err}
}
// Return the first number n such that n*base >= 1<<64.
func http2cutoff64(base int) uint64 {
if base < 2 {
return 0
}
return (1<<64-1)/uint64(base) + 1
}
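// For example, with base 10 the cutoff is (1<<64-1)/10 + 1 =
// 1844674407370955162: multiplying any n at or above that value by 10
// would overflow a uint64.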
var (
http2commonBuildOnce sync.Once
http2commonLowerHeader map[string]string // Go-Canonical-Case -> lower-case
http2commonCanonHeader map[string]string // lower-case -> Go-Canonical-Case
)
func http2buildCommonHeaderMapsOnce() {
http2commonBuildOnce.Do(http2buildCommonHeaderMaps)
}
func http2buildCommonHeaderMaps() {
common := []string{
"accept",
"accept-charset",
"accept-encoding",
"accept-language",
"accept-ranges",
"age",
"access-control-allow-credentials",
"access-control-allow-headers",
"access-control-allow-methods",
"access-control-allow-origin",
"access-control-expose-headers",
"access-control-max-age",
"access-control-request-headers",
"access-control-request-method",
"allow",
"authorization",
"cache-control",
"content-disposition",
"content-encoding",
"content-language",
"content-length",
"content-location",
"content-range",
"content-type",
"cookie",
"date",
"etag",
"expect",
"expires",
"from",
"host",
"if-match",
"if-modified-since",
"if-none-match",
"if-unmodified-since",
"last-modified",
"link",
"location",
"max-forwards",
"origin",
"proxy-authenticate",
"proxy-authorization",
"range",
"referer",
"refresh",
"retry-after",
"server",
"set-cookie",
"strict-transport-security",
"trailer",
"transfer-encoding",
"user-agent",
"vary",
"via",
"www-authenticate",
"x-forwarded-for",
"x-forwarded-proto",
}
http2commonLowerHeader = make(map[string]string, len(common))
http2commonCanonHeader = make(map[string]string, len(common))
for _, v := range common {
chk := CanonicalHeaderKey(v)
http2commonLowerHeader[chk] = v
http2commonCanonHeader[v] = chk
}
}
func http2lowerHeader(v string) (lower string, ascii bool) {
http2buildCommonHeaderMapsOnce()
if s, ok := http2commonLowerHeader[v]; ok {
return s, true
}
return http2asciiToLower(v)
}
func http2canonicalHeader(v string) string {
http2buildCommonHeaderMapsOnce()
if s, ok := http2commonCanonHeader[v]; ok {
return s
}
return CanonicalHeaderKey(v)
}
var (
http2VerboseLogs bool
http2logFrameWrites bool
http2logFrameReads bool
http2inTests bool
)
func init() {
e := os.Getenv("GODEBUG")
if strings.Contains(e, "http2debug=1") {
http2VerboseLogs = true
}
if strings.Contains(e, "http2debug=2") {
http2VerboseLogs = true
http2logFrameWrites = true
http2logFrameReads = true
}
}
const (
// ClientPreface is the string that must be sent by new
// connections from clients.
http2ClientPreface = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"
// SETTINGS_MAX_FRAME_SIZE default
// https://httpwg.org/specs/rfc7540.html#rfc.section.6.5.2
http2initialMaxFrameSize = 16384
// NextProtoTLS is the NPN/ALPN protocol negotiated during
// HTTP/2's TLS setup.
http2NextProtoTLS = "h2"
// https://httpwg.org/specs/rfc7540.html#SettingValues
http2initialHeaderTableSize = 4096
http2initialWindowSize = 65535 // 6.9.2 Initial Flow Control Window Size
http2defaultMaxReadFrameSize = 1 << 20
)
var (
http2clientPreface = []byte(http2ClientPreface)
)
type http2streamState int
// HTTP/2 stream states.
//
// See http://tools.ietf.org/html/rfc7540#section-5.1.
//
// For simplicity, the server code merges "reserved (local)" into
// "half-closed (remote)". This is one less state transition to track.
// The only downside is that we send PUSH_PROMISEs slightly less
// liberally than allowable. More discussion here:
// https://lists.w3.org/Archives/Public/ietf-http-wg/2016JulSep/0599.html
//
// "reserved (remote)" is omitted since the client code does not
// support server push.
const (
http2stateIdle http2streamState = iota
http2stateOpen
http2stateHalfClosedLocal
http2stateHalfClosedRemote
http2stateClosed
)
var http2stateName = [...]string{
http2stateIdle: "Idle",
http2stateOpen: "Open",
http2stateHalfClosedLocal: "HalfClosedLocal",
http2stateHalfClosedRemote: "HalfClosedRemote",
http2stateClosed: "Closed",
}
func (st http2streamState) String() string {
return http2stateName[st]
}
// Setting is a setting parameter: which setting it is, and its value.
type http2Setting struct {
// ID is which setting is being set.
// See https://httpwg.org/specs/rfc7540.html#SettingFormat
ID http2SettingID
// Val is the value.
Val uint32
}
func (s http2Setting) String() string {
return fmt.Sprintf("[%v = %d]", s.ID, s.Val)
}
// Valid reports whether the setting is valid.
func (s http2Setting) Valid() error {
// Limits and error codes from 6.5.2 Defined SETTINGS Parameters
switch s.ID {
case http2SettingEnablePush:
if s.Val != 1 && s.Val != 0 {
return http2ConnectionError(http2ErrCodeProtocol)
}
case http2SettingInitialWindowSize:
if s.Val > 1<<31-1 {
return http2ConnectionError(http2ErrCodeFlowControl)
}
case http2SettingMaxFrameSize:
if s.Val < 16384 || s.Val > 1<<24-1 {
return http2ConnectionError(http2ErrCodeProtocol)
}
}
return nil
}
// A SettingID is an HTTP/2 setting as defined in
// https://httpwg.org/specs/rfc7540.html#iana-settings
type http2SettingID uint16
const (
http2SettingHeaderTableSize http2SettingID = 0x1
http2SettingEnablePush http2SettingID = 0x2
http2SettingMaxConcurrentStreams http2SettingID = 0x3
http2SettingInitialWindowSize http2SettingID = 0x4
http2SettingMaxFrameSize http2SettingID = 0x5
http2SettingMaxHeaderListSize http2SettingID = 0x6
)
var http2settingName = map[http2SettingID]string{
http2SettingHeaderTableSize: "HEADER_TABLE_SIZE",
http2SettingEnablePush: "ENABLE_PUSH",
http2SettingMaxConcurrentStreams: "MAX_CONCURRENT_STREAMS",
http2SettingInitialWindowSize: "INITIAL_WINDOW_SIZE",
http2SettingMaxFrameSize: "MAX_FRAME_SIZE",
http2SettingMaxHeaderListSize: "MAX_HEADER_LIST_SIZE",
}
func (s http2SettingID) String() string {
if v, ok := http2settingName[s]; ok {
return v
}
return fmt.Sprintf("UNKNOWN_SETTING_%d", uint16(s))
}
// validWireHeaderFieldName reports whether v is a valid header field
// name (key). See httpguts.ValidHeaderName for the base rules.
//
// Further, http2 says:
//
// "Just as in HTTP/1.x, header field names are strings of ASCII
// characters that are compared in a case-insensitive
// fashion. However, header field names MUST be converted to
// lowercase prior to their encoding in HTTP/2. "
func http2validWireHeaderFieldName(v string) bool {
if len(v) == 0 {
return false
}
for _, r := range v {
if !httpguts.IsTokenRune(r) {
return false
}
if 'A' <= r && r <= 'Z' {
return false
}
}
return true
}
func http2httpCodeString(code int) string {
switch code {
case 200:
return "200"
case 404:
return "404"
}
return strconv.Itoa(code)
}
// from pkg io
type http2stringWriter interface {
WriteString(s string) (n int, err error)
}
// A gate lets two goroutines coordinate their activities.
type http2gate chan struct{}
func (g http2gate) Done() { g <- struct{}{} }
func (g http2gate) Wait() { <-g }
// A closeWaiter is like a sync.WaitGroup but only goes 1 to 0 (open to closed).
type http2closeWaiter chan struct{}
// Init makes a closeWaiter usable.
// It exists so that a closeWaiter value can be placed inside a
// larger struct and initialized in place, without needing a
// separate allocation.
func (cw *http2closeWaiter) Init() {
*cw = make(chan struct{})
}
// Close marks the closeWaiter as closed and unblocks any waiters.
func (cw http2closeWaiter) Close() {
close(cw)
}
// Wait waits for the closeWaiter to become closed.
func (cw http2closeWaiter) Wait() {
<-cw
}
// bufferedWriter is a buffered writer that writes to w.
// Its buffered writer is lazily allocated as needed, to minimize
// idle memory usage with many connections.
type http2bufferedWriter struct {
_ http2incomparable
w io.Writer // immutable
bw *bufio.Writer // non-nil when data is buffered
}
func http2newBufferedWriter(w io.Writer) *http2bufferedWriter {
return &http2bufferedWriter{w: w}
}
// bufWriterPoolBufferSize is the size of bufio.Writer's
// buffers created using bufWriterPool.
//
// TODO: pick a less arbitrary value? this is a bit under
// (3 x typical 1500 byte MTU) at least. Other than that,
// not much thought went into it.
const http2bufWriterPoolBufferSize = 4 << 10
var http2bufWriterPool = sync.Pool{
New: func() interface{} {
return bufio.NewWriterSize(nil, http2bufWriterPoolBufferSize)
},
}
func (w *http2bufferedWriter) Available() int {
if w.bw == nil {
return http2bufWriterPoolBufferSize
}
return w.bw.Available()
}
func (w *http2bufferedWriter) Write(p []byte) (n int, err error) {
if w.bw == nil {
bw := http2bufWriterPool.Get().(*bufio.Writer)
bw.Reset(w.w)
w.bw = bw
}
return w.bw.Write(p)
}
func (w *http2bufferedWriter) Flush() error {
bw := w.bw
if bw == nil {
return nil
}
err := bw.Flush()
bw.Reset(nil)
http2bufWriterPool.Put(bw)
w.bw = nil
return err
}
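// Usage sketch: the bufio.Writer is borrowed from the pool on the first
// Write and returned on Flush, so idle connections hold no buffer. conn is
// a hypothetical io.Writer:
//
//	w := http2newBufferedWriter(conn)
//	w.Write([]byte("frame bytes")) // borrows a bufio.Writer from the pool
//	w.Flush()                      // flushes to conn and returns the buffer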
func http2mustUint31(v int32) uint32 {
if v < 0 || v > 2147483647 {
panic("out of range")
}
return uint32(v)
}
// bodyAllowedForStatus reports whether a given response status code
// permits a body. See RFC 7230, section 3.3.
func http2bodyAllowedForStatus(status int) bool {
switch {
case status >= 100 && status <= 199:
return false
case status == 204:
return false
case status == 304:
return false
}
return true
}
type http2httpError struct {
_ http2incomparable
msg string
timeout bool
}
func (e *http2httpError) Error() string { return e.msg }
func (e *http2httpError) Timeout() bool { return e.timeout }
func (e *http2httpError) Temporary() bool { return true }
var http2errTimeout error = &http2httpError{msg: "http2: timeout awaiting response headers", timeout: true}
type http2connectionStater interface {
ConnectionState() tls.ConnectionState
}
var http2sorterPool = sync.Pool{New: func() interface{} { return new(http2sorter) }}
type http2sorter struct {
v []string // owned by sorter
}
func (s *http2sorter) Len() int { return len(s.v) }
func (s *http2sorter) Swap(i, j int) { s.v[i], s.v[j] = s.v[j], s.v[i] }
func (s *http2sorter) Less(i, j int) bool { return s.v[i] < s.v[j] }
// Keys returns the sorted keys of h.
//
// The returned slice is only valid until s is used again or returned to
// its pool.
func (s *http2sorter) Keys(h Header) []string {
keys := s.v[:0]
for k := range h {
keys = append(keys, k)
}
s.v = keys
sort.Sort(s)
return keys
}
func (s *http2sorter) SortStrings(ss []string) {
// Our sorter works on s.v, which sorter owns, so
// stash it away while we sort the user's buffer.
save := s.v
s.v = ss
sort.Sort(s)
s.v = save
}
// validPseudoPath reports whether v is a valid :path pseudo-header
// value. It must be either:
//
// - a non-empty string starting with '/'
// - the string '*', for OPTIONS requests.
//
// For now this is only used as a quick check for deciding when to clean
// up Opaque URLs before sending requests from the Transport.
// See golang.org/issue/16847
//
// We used to enforce that the path also didn't start with "//", but
// Google's GFE accepts such paths and Chrome sends them, so ignore
// that part of the spec. See golang.org/issue/19103.
func http2validPseudoPath(v string) bool {
return (len(v) > 0 && v[0] == '/') || v == "*"
}
// incomparable is a zero-width, non-comparable type. Adding it to a struct
// makes that struct also non-comparable, and generally doesn't add
// any size (as long as it's first).
type http2incomparable [0]func()
// pipe is a goroutine-safe io.Reader/io.Writer pair. It's like
// io.Pipe except there are no PipeReader/PipeWriter halves, and the
// underlying buffer is an interface. (io.Pipe is always unbuffered)
type http2pipe struct {
mu sync.Mutex
c sync.Cond // c.L lazily initialized to &p.mu
b http2pipeBuffer // nil when done reading
unread int // bytes unread when done
err error // read error once empty. non-nil means closed.
breakErr error // immediate read error (caller doesn't see rest of b)
donec chan struct{} // closed on error
readFn func() // optional code to run in Read before error
}
type http2pipeBuffer interface {
Len() int
io.Writer
io.Reader
}
// setBuffer initializes the pipe buffer.
// It has no effect if the pipe is already closed.
func (p *http2pipe) setBuffer(b http2pipeBuffer) {
p.mu.Lock()
defer p.mu.Unlock()
if p.err != nil || p.breakErr != nil {
return
}
p.b = b
}
func (p *http2pipe) Len() int {
p.mu.Lock()
defer p.mu.Unlock()
if p.b == nil {
return p.unread
}
return p.b.Len()
}
// Read waits until data is available and copies bytes
// from the buffer into p.
func (p *http2pipe) Read(d []byte) (n int, err error) {
p.mu.Lock()
defer p.mu.Unlock()
if p.c.L == nil {
p.c.L = &p.mu
}
for {
if p.breakErr != nil {
return 0, p.breakErr
}
if p.b != nil && p.b.Len() > 0 {
return p.b.Read(d)
}
if p.err != nil {
if p.readFn != nil {
p.readFn() // e.g. copy trailers
p.readFn = nil // not sticky like p.err
}
p.b = nil
return 0, p.err
}
p.c.Wait()
}
}
var http2errClosedPipeWrite = errors.New("write on closed buffer")
// Write copies bytes from p into the buffer and wakes a reader.
// It is an error to write more data than the buffer can hold.
func (p *http2pipe) Write(d []byte) (n int, err error) {
p.mu.Lock()
defer p.mu.Unlock()
if p.c.L == nil {
p.c.L = &p.mu
}
defer p.c.Signal()
if p.err != nil {
return 0, http2errClosedPipeWrite
}
if p.breakErr != nil {
p.unread += len(d)
return len(d), nil // discard when there is no reader
}
return p.b.Write(d)
}
// CloseWithError causes the next Read (waking up a current blocked
// Read if needed) to return the provided err after all data has been
// read.
//
// The error must be non-nil.
func (p *http2pipe) CloseWithError(err error) { p.closeWithError(&p.err, err, nil) }
// BreakWithError causes the next Read (waking up a current blocked
// Read if needed) to return the provided err immediately, without
// waiting for unread data.
func (p *http2pipe) BreakWithError(err error) { p.closeWithError(&p.breakErr, err, nil) }
// closeWithErrorAndCode is like CloseWithError but also sets some code to run
// in the caller's goroutine before returning the error.
func (p *http2pipe) closeWithErrorAndCode(err error, fn func()) { p.closeWithError(&p.err, err, fn) }
func (p *http2pipe) closeWithError(dst *error, err error, fn func()) {
if err == nil {
panic("err must be non-nil")
}
p.mu.Lock()
defer p.mu.Unlock()
if p.c.L == nil {
p.c.L = &p.mu
}
defer p.c.Signal()
if *dst != nil {
// Already been done.
return
}
p.readFn = fn
if dst == &p.breakErr {
if p.b != nil {
p.unread += p.b.Len()
}
p.b = nil
}
*dst = err
p.closeDoneLocked()
}
// requires p.mu be held.
func (p *http2pipe) closeDoneLocked() {
if p.donec == nil {
return
}
// Close if unclosed. This isn't racy since we always
// hold p.mu while closing.
select {
case <-p.donec:
default:
close(p.donec)
}
}
// Err returns the error (if any) first set by BreakWithError or CloseWithError.
func (p *http2pipe) Err() error {
p.mu.Lock()
defer p.mu.Unlock()
if p.breakErr != nil {
return p.breakErr
}
return p.err
}
// Done returns a channel which is closed if and when this pipe is closed
// with CloseWithError or BreakWithError.
func (p *http2pipe) Done() <-chan struct{} {
p.mu.Lock()
defer p.mu.Unlock()
if p.donec == nil {
p.donec = make(chan struct{})
if p.err != nil || p.breakErr != nil {
// Already hit an error.
p.closeDoneLocked()
}
}
return p.donec
}
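// Usage sketch (names local to the sketch): a pipe backed by a
// bytes.Buffer, filled by one goroutine and drained by another:
//
//	p := new(http2pipe)
//	p.setBuffer(new(bytes.Buffer))
//	go func() {
//		p.Write([]byte("body"))
//		p.CloseWithError(io.EOF) // readers see the data, then io.EOF
//	}()
//	buf := make([]byte, 4)
//	n, err := p.Read(buf) // n == 4, err == nil; the next Read returns io.EOF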
const (
http2prefaceTimeout = 10 * time.Second
http2firstSettingsTimeout = 2 * time.Second // should be in-flight with preface anyway
http2handlerChunkWriteSize = 4 << 10
http2defaultMaxStreams = 250 // TODO: make this 100 as the GFE seems to?
http2maxQueuedControlFrames = 10000
)
var (
http2errClientDisconnected = errors.New("client disconnected")
http2errClosedBody = errors.New("body closed by handler")
http2errHandlerComplete = errors.New("http2: request body closed due to handler exiting")
http2errStreamClosed = errors.New("http2: stream closed")
)
var http2responseWriterStatePool = sync.Pool{
New: func() interface{} {
rws := &http2responseWriterState{}
rws.bw = bufio.NewWriterSize(http2chunkWriter{rws}, http2handlerChunkWriteSize)
return rws
},
}
// Test hooks.
var (
http2testHookOnConn func()
http2testHookGetServerConn func(*http2serverConn)
http2testHookOnPanicMu *sync.Mutex // nil except in tests
http2testHookOnPanic func(sc *http2serverConn, panicVal interface{}) (rePanic bool)
)
// Server is an HTTP/2 server.
type http2Server struct {
// MaxHandlers limits the number of http.Handler ServeHTTP goroutines
// which may run at a time over all connections.
// Negative or zero means no limit.
// TODO: implement
MaxHandlers int
// MaxConcurrentStreams optionally specifies the number of
// concurrent streams that each client may have open at a
// time. This is unrelated to the number of http.Handler goroutines
// which may be active globally, which is MaxHandlers.
// If zero, MaxConcurrentStreams defaults to at least 100, per
// the HTTP/2 spec's recommendations.
MaxConcurrentStreams uint32
// MaxDecoderHeaderTableSize optionally specifies the http2
// SETTINGS_HEADER_TABLE_SIZE to send in the initial settings frame. It
// informs the remote endpoint of the maximum size of the header compression
// table used to decode header blocks, in octets. If zero, the default value
// of 4096 is used.
MaxDecoderHeaderTableSize uint32
// MaxEncoderHeaderTableSize optionally specifies an upper limit for the
// header compression table used for encoding request headers. Received
// SETTINGS_HEADER_TABLE_SIZE settings are capped at this limit. If zero,
// the default value of 4096 is used.
MaxEncoderHeaderTableSize uint32
// MaxReadFrameSize optionally specifies the largest frame
// this server is willing to read. A valid value is between
// 16k and 16M, inclusive. If zero or otherwise invalid, a
// default value is used.
MaxReadFrameSize uint32
// PermitProhibitedCipherSuites, if true, permits the use of
// cipher suites prohibited by the HTTP/2 spec.
PermitProhibitedCipherSuites bool
// IdleTimeout specifies how long until idle clients should be
// closed with a GOAWAY frame. PING frames are not considered
// activity for the purposes of IdleTimeout.
IdleTimeout time.Duration
// MaxUploadBufferPerConnection is the size of the initial flow
// control window for each connection. The HTTP/2 spec does not
// allow this to be smaller than 65535 or larger than 2^32-1.
// If the value is outside this range, a default value will be
// used instead.
MaxUploadBufferPerConnection int32
// MaxUploadBufferPerStream is the size of the initial flow control
// window for each stream. The HTTP/2 spec does not allow this to
// be larger than 2^32-1. If the value is zero or larger than the
// maximum, a default value will be used instead.
MaxUploadBufferPerStream int32
// NewWriteScheduler constructs a write scheduler for a connection.
// If nil, a default scheduler is chosen.
NewWriteScheduler func() http2WriteScheduler
// CountError, if non-nil, is called on HTTP/2 server errors.
// It's intended to increment a metric for monitoring, such
// as an expvar or Prometheus metric.
// The errType consists of only ASCII word characters.
CountError func(errType string)
// Internal state. This is a pointer (rather than embedded directly)
// so that we don't embed a Mutex in this struct, which would make the
// struct non-copyable and might break some callers.
state *http2serverInternalState
}
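// The zero value of http2Server is usable. A minimal configuration sketch
// (illustrative values only; each field falls back to its documented default
// when zero). Outside this bundle the equivalent exported type is
// golang.org/x/net/http2.Server:
//
//	conf := &http2Server{
//		MaxConcurrentStreams:         250,
//		MaxReadFrameSize:             1 << 20, // within the valid 16k-16M range
//		IdleTimeout:                  5 * time.Minute,
//		MaxUploadBufferPerConnection: 1 << 20,
//	}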
func (s *http2Server) initialConnRecvWindowSize() int32 {
if s.MaxUploadBufferPerConnection >= http2initialWindowSize {
return s.MaxUploadBufferPerConnection
}
return 1 << 20
}
func (s *http2Server) initialStreamRecvWindowSize() int32 {
if s.MaxUploadBufferPerStream > 0 {
return s.MaxUploadBufferPerStream
}
return 1 << 20
}
func (s *http2Server) maxReadFrameSize() uint32 {
if v := s.MaxReadFrameSize; v >= http2minMaxFrameSize && v <= http2maxFrameSize {
return v
}
return http2defaultMaxReadFrameSize
}
func (s *http2Server) maxConcurrentStreams() uint32 {
if v := s.MaxConcurrentStreams; v > 0 {
return v
}
return http2defaultMaxStreams
}
func (s *http2Server) maxDecoderHeaderTableSize() uint32 {
if v := s.MaxDecoderHeaderTableSize; v > 0 {
return v
}
return http2initialHeaderTableSize
}
func (s *http2Server) maxEncoderHeaderTableSize() uint32 {
if v := s.MaxEncoderHeaderTableSize; v > 0 {
return v
}
return http2initialHeaderTableSize
}
// maxQueuedControlFrames is the maximum number of control frames like
// SETTINGS, PING and RST_STREAM that will be queued for writing before
// the connection is closed to prevent memory exhaustion attacks.
func (s *http2Server) maxQueuedControlFrames() int {
// TODO: if anybody asks, add a Server field, and remember to define the
// behavior of negative values.
return http2maxQueuedControlFrames
}
type http2serverInternalState struct {
mu sync.Mutex
activeConns map[*http2serverConn]struct{}
}
func (s *http2serverInternalState) registerConn(sc *http2serverConn) {
if s == nil {
return // if the Server was used without calling ConfigureServer
}
s.mu.Lock()
s.activeConns[sc] = struct{}{}
s.mu.Unlock()
}
func (s *http2serverInternalState) unregisterConn(sc *http2serverConn) {
if s == nil {
return // if the Server was used without calling ConfigureServer
}
s.mu.Lock()
delete(s.activeConns, sc)
s.mu.Unlock()
}
func (s *http2serverInternalState) startGracefulShutdown() {
if s == nil {
return // if the Server was used without calling ConfigureServer
}
s.mu.Lock()
for sc := range s.activeConns {
sc.startGracefulShutdown()
}
s.mu.Unlock()
}
// ConfigureServer adds HTTP/2 support to a net/http Server.
//
// The configuration conf may be nil.
//
// ConfigureServer must be called before s begins serving.
func http2ConfigureServer(s *Server, conf *http2Server) error {
if s == nil {
panic("nil *http.Server")
}
if conf == nil {
conf = new(http2Server)
}
conf.state = &http2serverInternalState{activeConns: make(map[*http2serverConn]struct{})}
if h1, h2 := s, conf; h2.IdleTimeout == 0 {
if h1.IdleTimeout != 0 {
h2.IdleTimeout = h1.IdleTimeout
} else {
h2.IdleTimeout = h1.ReadTimeout
}
}
s.RegisterOnShutdown(conf.state.startGracefulShutdown)
if s.TLSConfig == nil {
s.TLSConfig = new(tls.Config)
} else if s.TLSConfig.CipherSuites != nil && s.TLSConfig.MinVersion < tls.VersionTLS13 {
// If they already provided a TLS 1.0–1.2 CipherSuite list, return an
// error if it is missing ECDHE_RSA_WITH_AES_128_GCM_SHA256 or
// ECDHE_ECDSA_WITH_AES_128_GCM_SHA256.
haveRequired := false
for _, cs := range s.TLSConfig.CipherSuites {
switch cs {
case tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
// Alternative MTI cipher to not discourage ECDSA-only servers.
// See http://golang.org/cl/30721 for further information.
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
haveRequired = true
}
}
if !haveRequired {
return fmt.Errorf("http2: TLSConfig.CipherSuites is missing an HTTP/2-required AES_128_GCM_SHA256 cipher (need at least one of TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 or TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256)")
}
}
// Note: not setting MinVersion to tls.VersionTLS12,
// as we don't want to interfere with HTTP/1.1 traffic
// on the user's server. We enforce TLS 1.2 later once
// we accept a connection. Ideally this should be done
// during next-proto selection, but using TLS <1.2 with
// HTTP/2 is still the client's bug.
s.TLSConfig.PreferServerCipherSuites = true
if !http2strSliceContains(s.TLSConfig.NextProtos, http2NextProtoTLS) {
s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, http2NextProtoTLS)
}
if !http2strSliceContains(s.TLSConfig.NextProtos, "http/1.1") {
s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "http/1.1")
}
if s.TLSNextProto == nil {
s.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){}
}
protoHandler := func(hs *Server, c *tls.Conn, h Handler) {
if http2testHookOnConn != nil {
http2testHookOnConn()
}
// The TLSNextProto interface predates contexts, so
// the net/http package passes down its per-connection
// base context via an exported but unadvertised
// method on the Handler. This is for internal
// net/http<=>http2 use only.
var ctx context.Context
type baseContexter interface {
BaseContext() context.Context
}
if bc, ok := h.(baseContexter); ok {
ctx = bc.BaseContext()
}
conf.ServeConn(c, &http2ServeConnOpts{
Context: ctx,
Handler: h,
BaseConfig: hs,
})
}
s.TLSNextProto[http2NextProtoTLS] = protoHandler
return nil
}
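// Illustrative usage sketch (mux is an assumed Handler; cert.pem and key.pem
// are placeholder paths). Outside this bundle the equivalent exported
// function is golang.org/x/net/http2.ConfigureServer:
//
//	srv := &Server{Addr: ":8443", Handler: mux}
//	if err := http2ConfigureServer(srv, nil); err != nil {
//		log.Fatal(err)
//	}
//	log.Fatal(srv.ListenAndServeTLS("cert.pem", "key.pem"))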
// ServeConnOpts are options for the Server.ServeConn method.
type http2ServeConnOpts struct {
// Context is the base context to use.
// If nil, context.Background is used.
Context context.Context
// BaseConfig optionally sets the base configuration
// for values. If nil, defaults are used.
BaseConfig *Server
// Handler specifies which handler to use for processing
// requests. If nil, BaseConfig.Handler is used. If BaseConfig
// or BaseConfig.Handler is nil, http.DefaultServeMux is used.
Handler Handler
// UpgradeRequest is an initial request received on a connection
// undergoing an h2c upgrade. The request body must have been
// completely read from the connection before calling ServeConn,
// and the 101 Switching Protocols response written.
UpgradeRequest *Request
// Settings is the decoded contents of the HTTP2-Settings header
// in an h2c upgrade request.
Settings []byte
// SawClientPreface is set if the HTTP/2 connection preface
// has already been read from the connection.
SawClientPreface bool
}
func (o *http2ServeConnOpts) context() context.Context {
if o != nil && o.Context != nil {
return o.Context
}
return context.Background()
}
func (o *http2ServeConnOpts) baseConfig() *Server {
if o != nil && o.BaseConfig != nil {
return o.BaseConfig
}
return new(Server)
}
func (o *http2ServeConnOpts) handler() Handler {
if o != nil {
if o.Handler != nil {
return o.Handler
}
if o.BaseConfig != nil && o.BaseConfig.Handler != nil {
return o.BaseConfig.Handler
}
}
return DefaultServeMux
}
// ServeConn serves HTTP/2 requests on the provided connection and
// blocks until the connection is no longer readable.
//
// ServeConn starts speaking HTTP/2 assuming that c has not had any
// reads or writes. It writes its initial settings frame and expects
// to be able to read the preface and settings frame from the
// client. If c has a ConnectionState method like a *tls.Conn, the
// ConnectionState is used to verify the TLS ciphersuite and to set
// the Request.TLS field in Handlers.
//
// ServeConn does not support h2c by itself. Any h2c support must be
// implemented in terms of providing a suitably-behaving net.Conn.
//
// The opts parameter is optional. If nil, default values are used.
func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) {
baseCtx, cancel := http2serverConnBaseContext(c, opts)
defer cancel()
sc := &http2serverConn{
srv: s,
hs: opts.baseConfig(),
conn: c,
baseCtx: baseCtx,
remoteAddrStr: c.RemoteAddr().String(),
bw: http2newBufferedWriter(c),
handler: opts.handler(),
streams: make(map[uint32]*http2stream),
readFrameCh: make(chan http2readFrameResult),
wantWriteFrameCh: make(chan http2FrameWriteRequest, 8),
serveMsgCh: make(chan interface{}, 8),
wroteFrameCh: make(chan http2frameWriteResult, 1), // buffered; one send in writeFrameAsync
bodyReadCh: make(chan http2bodyReadMsg), // buffering doesn't matter either way
doneServing: make(chan struct{}),
clientMaxStreams: math.MaxUint32, // Section 6.5.2: "Initially, there is no limit to this value"
advMaxStreams: s.maxConcurrentStreams(),
initialStreamSendWindowSize: http2initialWindowSize,
maxFrameSize: http2initialMaxFrameSize,
serveG: http2newGoroutineLock(),
pushEnabled: true,
sawClientPreface: opts.SawClientPreface,
}
s.state.registerConn(sc)
defer s.state.unregisterConn(sc)
// The net/http package sets the write deadline from the
// http.Server.WriteTimeout during the TLS handshake, but then
// passes the connection off to us with the deadline already set.
// Write deadlines are set per stream in serverConn.newStream.
// Disarm the net.Conn write deadline here.
if sc.hs.WriteTimeout != 0 {
sc.conn.SetWriteDeadline(time.Time{})
}
if s.NewWriteScheduler != nil {
sc.writeSched = s.NewWriteScheduler()
} else {
sc.writeSched = http2NewPriorityWriteScheduler(nil)
}
// These start at the RFC-specified defaults. If there is a higher
// configured value for inflow, that will be updated when we send a
// WINDOW_UPDATE shortly after sending SETTINGS.
sc.flow.add(http2initialWindowSize)
sc.inflow.init(http2initialWindowSize)
sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
sc.hpackEncoder.SetMaxDynamicTableSizeLimit(s.maxEncoderHeaderTableSize())
fr := http2NewFramer(sc.bw, c)
if s.CountError != nil {
fr.countError = s.CountError
}
fr.ReadMetaHeaders = hpack.NewDecoder(s.maxDecoderHeaderTableSize(), nil)
fr.MaxHeaderListSize = sc.maxHeaderListSize()
fr.SetMaxReadFrameSize(s.maxReadFrameSize())
sc.framer = fr
if tc, ok := c.(http2connectionStater); ok {
sc.tlsState = new(tls.ConnectionState)
*sc.tlsState = tc.ConnectionState()
// 9.2 Use of TLS Features
// An implementation of HTTP/2 over TLS MUST use TLS
// 1.2 or higher with the restrictions on feature set
// and cipher suite described in this section. Due to
// implementation limitations, it might not be
// possible to fail TLS negotiation. An endpoint MUST
// immediately terminate an HTTP/2 connection that
// does not meet the TLS requirements described in
// this section with a connection error (Section
// 5.4.1) of type INADEQUATE_SECURITY.
if sc.tlsState.Version < tls.VersionTLS12 {
sc.rejectConn(http2ErrCodeInadequateSecurity, "TLS version too low")
return
}
if sc.tlsState.ServerName == "" {
// Client must use SNI, but we don't enforce that anymore,
// since it was causing problems when connecting to bare IP
// addresses during development.
//
// TODO: optionally enforce? Or enforce at the time we receive
// a new request, and verify the ServerName matches the :authority?
// But that precludes proxy situations, perhaps.
//
// So for now, do nothing here.
}
if !s.PermitProhibitedCipherSuites && http2isBadCipher(sc.tlsState.CipherSuite) {
// "Endpoints MAY choose to generate a connection error
// (Section 5.4.1) of type INADEQUATE_SECURITY if one of
// the prohibited cipher suites are negotiated."
//
// We choose that. In my opinion, the spec is weak
// here. It also says both parties must support at least
// TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 so there's no
// excuse here. If we really must, we could allow an
// "AllowInsecureWeakCiphers" option on the server later.
// Let's see how it plays out first.
sc.rejectConn(http2ErrCodeInadequateSecurity, fmt.Sprintf("Prohibited TLS 1.2 Cipher Suite: %x", sc.tlsState.CipherSuite))
return
}
}
if opts.Settings != nil {
fr := &http2SettingsFrame{
http2FrameHeader: http2FrameHeader{valid: true},
p: opts.Settings,
}
if err := fr.ForeachSetting(sc.processSetting); err != nil {
sc.rejectConn(http2ErrCodeProtocol, "invalid settings")
return
}
opts.Settings = nil
}
if hook := http2testHookGetServerConn; hook != nil {
hook(sc)
}
if opts.UpgradeRequest != nil {
sc.upgradeRequest(opts.UpgradeRequest)
opts.UpgradeRequest = nil
}
sc.serve()
}
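// A sketch of driving ServeConn directly (assumes ln is a net.Listener whose
// accepted connections already speak TLS with "h2" negotiated, conf is an
// *http2Server, and handler is a hypothetical Handler):
//
//	for {
//		c, err := ln.Accept()
//		if err != nil {
//			break
//		}
//		go conf.ServeConn(c, &http2ServeConnOpts{Handler: handler})
//	}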
func http2serverConnBaseContext(c net.Conn, opts *http2ServeConnOpts) (ctx context.Context, cancel func()) {
ctx, cancel = context.WithCancel(opts.context())
ctx = context.WithValue(ctx, LocalAddrContextKey, c.LocalAddr())
if hs := opts.baseConfig(); hs != nil {
ctx = context.WithValue(ctx, ServerContextKey, hs)
}
return
}
func (sc *http2serverConn) rejectConn(err http2ErrCode, debug string) {
sc.vlogf("http2: server rejecting conn: %v, %s", err, debug)
// Ignore errors; we're hanging up anyway.
sc.framer.WriteGoAway(0, err, []byte(debug))
sc.bw.Flush()
sc.conn.Close()
}
type http2serverConn struct {
// Immutable:
srv *http2Server
hs *Server
conn net.Conn
bw *http2bufferedWriter // writing to conn
handler Handler
baseCtx context.Context
framer *http2Framer
doneServing chan struct{} // closed when serverConn.serve ends
readFrameCh chan http2readFrameResult // written by serverConn.readFrames
wantWriteFrameCh chan http2FrameWriteRequest // from handlers -> serve
wroteFrameCh chan http2frameWriteResult // from writeFrameAsync -> serve, tickles more frame writes
bodyReadCh chan http2bodyReadMsg // from handlers -> serve
serveMsgCh chan interface{} // misc messages & code to send to / run on the serve loop
flow http2outflow // conn-wide (not stream-specific) outbound flow control
inflow http2inflow // conn-wide inbound flow control
tlsState *tls.ConnectionState // shared by all handlers, like net/http
remoteAddrStr string
writeSched http2WriteScheduler
// Everything following is owned by the serve loop; use serveG.check():
serveG http2goroutineLock // used to verify funcs are on serve()
pushEnabled bool
sawClientPreface bool // preface has already been read, used in h2c upgrade
sawFirstSettings bool // got the initial SETTINGS frame after the preface
needToSendSettingsAck bool
unackedSettings int // how many SETTINGS have we sent without ACKs?
queuedControlFrames int // control frames in the writeSched queue
clientMaxStreams uint32 // SETTINGS_MAX_CONCURRENT_STREAMS from client (our PUSH_PROMISE limit)
advMaxStreams uint32 // our SETTINGS_MAX_CONCURRENT_STREAMS advertised to the client
curClientStreams uint32 // number of open streams initiated by the client
curPushedStreams uint32 // number of open streams initiated by server push
maxClientStreamID uint32 // max ever seen from client (odd), or 0 if there have been no client requests
maxPushPromiseID uint32 // ID of the last push promise (even), or 0 if there have been no pushes
streams map[uint32]*http2stream
initialStreamSendWindowSize int32
maxFrameSize int32
peerMaxHeaderListSize uint32 // zero means unknown (default)
canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case
canonHeaderKeysSize int // canonHeader keys size in bytes
writingFrame bool // started writing a frame (on serve goroutine or separate)
writingFrameAsync bool // started a frame on its own goroutine but haven't heard back on wroteFrameCh
needsFrameFlush bool // last frame write wasn't a flush
inGoAway bool // we've started to send, or have sent, a GOAWAY
inFrameScheduleLoop bool // whether we're in the scheduleFrameWrite loop
needToSendGoAway bool // we need to schedule a GOAWAY frame write
goAwayCode http2ErrCode
shutdownTimer *time.Timer // nil until used
idleTimer *time.Timer // nil if unused
// Owned by the writeFrameAsync goroutine:
headerWriteBuf bytes.Buffer
hpackEncoder *hpack.Encoder
// Used by startGracefulShutdown.
shutdownOnce sync.Once
}
func (sc *http2serverConn) maxHeaderListSize() uint32 {
n := sc.hs.MaxHeaderBytes
if n <= 0 {
n = DefaultMaxHeaderBytes
}
// http2's count is in a slightly different unit and includes 32 bytes per pair.
// So, take the net/http.Server value and pad it up a bit, assuming 10 headers.
const perFieldOverhead = 32 // per http2 spec
const typicalHeaders = 10 // conservative
return uint32(n + typicalHeaders*perFieldOverhead)
}
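// For example, with net/http's DefaultMaxHeaderBytes (1 << 20 = 1048576),
// this advertises 1048576 + 10*32 = 1048896 octets.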
func (sc *http2serverConn) curOpenStreams() uint32 {
sc.serveG.check()
return sc.curClientStreams + sc.curPushedStreams
}
// stream represents a stream. This is the minimal metadata needed by
// the serve goroutine. Most of the actual stream state is owned by
// the http.Handler's goroutine in the responseWriter. Because the
// responseWriter's responseWriterState is recycled at the end of a
// handler, this struct intentionally has no pointer to the
// *responseWriter{,State} itself, as the Handler ending nils out the
// responseWriter's state field.
type http2stream struct {
// immutable:
sc *http2serverConn
id uint32
body *http2pipe // non-nil if expecting DATA frames
cw http2closeWaiter // closed when the stream transitions to the closed state
ctx context.Context
cancelCtx func()
// owned by serverConn's serve loop:
bodyBytes int64 // body bytes seen so far
declBodyBytes int64 // or -1 if undeclared
flow http2outflow // limits writing from Handler to client
inflow http2inflow // what the client is allowed to POST/etc to us
state http2streamState
resetQueued bool // RST_STREAM queued for write; set by sc.resetStream
gotTrailerHeader bool // HEADER frame for trailers was seen
wroteHeaders bool // whether we wrote headers (not status 100)
readDeadline *time.Timer // nil if unused
writeDeadline *time.Timer // nil if unused
closeErr error // set before cw is closed
trailer Header // accumulated trailers
reqTrailer Header // handler's Request.Trailer
}
func (sc *http2serverConn) Framer() *http2Framer { return sc.framer }
func (sc *http2serverConn) CloseConn() error { return sc.conn.Close() }
func (sc *http2serverConn) Flush() error { return sc.bw.Flush() }
func (sc *http2serverConn) HeaderEncoder() (*hpack.Encoder, *bytes.Buffer) {
return sc.hpackEncoder, &sc.headerWriteBuf
}
func (sc *http2serverConn) state(streamID uint32) (http2streamState, *http2stream) {
sc.serveG.check()
// http://tools.ietf.org/html/rfc7540#section-5.1
if st, ok := sc.streams[streamID]; ok {
return st.state, st
}
// "The first use of a new stream identifier implicitly closes all
// streams in the "idle" state that might have been initiated by
// that peer with a lower-valued stream identifier. For example, if
// a client sends a HEADERS frame on stream 7 without ever sending a
// frame on stream 5, then stream 5 transitions to the "closed"
// state when the first frame for stream 7 is sent or received."
if streamID%2 == 1 {
if streamID <= sc.maxClientStreamID {
return http2stateClosed, nil
}
} else {
if streamID <= sc.maxPushPromiseID {
return http2stateClosed, nil
}
}
return http2stateIdle, nil
}
// setConnState calls the net/http ConnState hook for this connection, if configured.
// Note that the net/http package does StateNew and StateClosed for us.
// There is currently no plan for StateHijacked or hijacking HTTP/2 connections.
func (sc *http2serverConn) setConnState(state ConnState) {
if sc.hs.ConnState != nil {
sc.hs.ConnState(sc.conn, state)
}
}
func (sc *http2serverConn) vlogf(format string, args ...interface{}) {
if http2VerboseLogs {
sc.logf(format, args...)
}
}
func (sc *http2serverConn) logf(format string, args ...interface{}) {
if lg := sc.hs.ErrorLog; lg != nil {
lg.Printf(format, args...)
} else {
log.Printf(format, args...)
}
}
// errno returns v's underlying uintptr, else 0.
//
// TODO: remove this helper function once http2 can use build
// tags. See comment in isClosedConnError.
func http2errno(v error) uintptr {
if rv := reflect.ValueOf(v); rv.Kind() == reflect.Uintptr {
return uintptr(rv.Uint())
}
return 0
}
// isClosedConnError reports whether err is an error from use of a closed
// network connection.
func http2isClosedConnError(err error) bool {
if err == nil {
return false
}
// TODO: remove this string search and be more like the Windows
// case below. That might involve modifying the standard library
// to return better error types.
str := err.Error()
if strings.Contains(str, "use of closed network connection") {
return true
}
// TODO(bradfitz): x/tools/cmd/bundle doesn't really support
// build tags, so I can't make an http2_windows.go file with
// Windows-specific stuff. Fix that and move this, once we
// have a way to bundle this into std's net/http somehow.
if runtime.GOOS == "windows" {
if oe, ok := err.(*net.OpError); ok && oe.Op == "read" {
if se, ok := oe.Err.(*os.SyscallError); ok && se.Syscall == "wsarecv" {
const WSAECONNABORTED = 10053
const WSAECONNRESET = 10054
if n := http2errno(se.Err); n == WSAECONNRESET || n == WSAECONNABORTED {
return true
}
}
}
}
return false
}
func (sc *http2serverConn) condlogf(err error, format string, args ...interface{}) {
if err == nil {
return
}
if err == io.EOF || err == io.ErrUnexpectedEOF || http2isClosedConnError(err) || err == http2errPrefaceTimeout {
// Boring, expected errors.
sc.vlogf(format, args...)
} else {
sc.logf(format, args...)
}
}
// maxCachedCanonicalHeadersKeysSize is an arbitrarily-chosen limit on the size
// of the entries in the canonHeader cache.
// This should be larger than the size of unique, uncommon header keys likely to
// be sent by the peer, while not so high as to permit unreasonable memory usage
// if the peer sends an unbounded number of unique header keys.
const http2maxCachedCanonicalHeadersKeysSize = 2048
func (sc *http2serverConn) canonicalHeader(v string) string {
sc.serveG.check()
http2buildCommonHeaderMapsOnce()
cv, ok := http2commonCanonHeader[v]
if ok {
return cv
}
cv, ok = sc.canonHeader[v]
if ok {
return cv
}
if sc.canonHeader == nil {
sc.canonHeader = make(map[string]string)
}
cv = CanonicalHeaderKey(v)
size := 100 + len(v)*2 // 100 bytes of map overhead + key + value
if sc.canonHeaderKeysSize+size <= http2maxCachedCanonicalHeadersKeysSize {
sc.canonHeader[v] = cv
sc.canonHeaderKeysSize += size
}
return cv
}
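// As a worked example of the accounting above: caching the key
// "x-custom-header" (15 bytes) charges 100 + 15*2 = 130 bytes against the
// 2048-byte http2maxCachedCanonicalHeadersKeysSize budget, so on the order
// of fifteen such uncommon keys are cached before the cache stops growing.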
type http2readFrameResult struct {
f http2Frame // valid until readMore is called
err error
// readMore should be called once the consumer no longer needs or
// retains f. After readMore, f is invalid and more frames can be
// read.
readMore func()
}
// readFrames is the loop that reads incoming frames.
// It takes care to only read one frame at a time, blocking until the
// consumer is done with the frame.
// It's run on its own goroutine.
func (sc *http2serverConn) readFrames() {
gate := make(http2gate)
gateDone := gate.Done
for {
f, err := sc.framer.ReadFrame()
select {
case sc.readFrameCh <- http2readFrameResult{f, err, gateDone}:
case <-sc.doneServing:
return
}
select {
case <-gate:
case <-sc.doneServing:
return
}
if http2terminalReadFrameError(err) {
return
}
}
}
// frameWriteResult is the message passed from writeFrameAsync to the serve goroutine.
type http2frameWriteResult struct {
_ http2incomparable
wr http2FrameWriteRequest // what was written (or attempted)
err error // result of the writeFrame call
}
// writeFrameAsync runs in its own goroutine and writes a single frame
// and then reports when it's done.
// At most one goroutine can be running writeFrameAsync at a time per
// serverConn.
func (sc *http2serverConn) writeFrameAsync(wr http2FrameWriteRequest, wd *http2writeData) {
var err error
if wd == nil {
err = wr.write.writeFrame(sc)
} else {
err = sc.framer.endWrite()
}
sc.wroteFrameCh <- http2frameWriteResult{wr: wr, err: err}
}
func (sc *http2serverConn) closeAllStreamsOnConnClose() {
sc.serveG.check()
for _, st := range sc.streams {
sc.closeStream(st, http2errClientDisconnected)
}
}
func (sc *http2serverConn) stopShutdownTimer() {
sc.serveG.check()
if t := sc.shutdownTimer; t != nil {
t.Stop()
}
}
func (sc *http2serverConn) notePanic() {
// Note: this is for serverConn.serve panicking, not http.Handler code.
if http2testHookOnPanicMu != nil {
http2testHookOnPanicMu.Lock()
defer http2testHookOnPanicMu.Unlock()
}
if http2testHookOnPanic != nil {
if e := recover(); e != nil {
if http2testHookOnPanic(sc, e) {
panic(e)
}
}
}
}
func (sc *http2serverConn) serve() {
sc.serveG.check()
defer sc.notePanic()
defer sc.conn.Close()
defer sc.closeAllStreamsOnConnClose()
defer sc.stopShutdownTimer()
defer close(sc.doneServing) // unblocks handlers trying to send
if http2VerboseLogs {
sc.vlogf("http2: server connection from %v on %p", sc.conn.RemoteAddr(), sc.hs)
}
sc.writeFrame(http2FrameWriteRequest{
write: http2writeSettings{
{http2SettingMaxFrameSize, sc.srv.maxReadFrameSize()},
{http2SettingMaxConcurrentStreams, sc.advMaxStreams},
{http2SettingMaxHeaderListSize, sc.maxHeaderListSize()},
{http2SettingHeaderTableSize, sc.srv.maxDecoderHeaderTableSize()},
{http2SettingInitialWindowSize, uint32(sc.srv.initialStreamRecvWindowSize())},
},
})
sc.unackedSettings++
// Each connection starts with initialWindowSize inflow tokens.
// If a higher value is configured, we add more tokens.
if diff := sc.srv.initialConnRecvWindowSize() - http2initialWindowSize; diff > 0 {
sc.sendWindowUpdate(nil, int(diff))
}
if err := sc.readPreface(); err != nil {
sc.condlogf(err, "http2: server: error reading preface from client %v: %v", sc.conn.RemoteAddr(), err)
return
}
// Now that we've got the preface, get us out of the
// "StateNew" state. We can't go directly to idle, though.
// Active means we read some data and anticipate a request. We'll
// do another Active when we get a HEADERS frame.
sc.setConnState(StateActive)
sc.setConnState(StateIdle)
if sc.srv.IdleTimeout != 0 {
sc.idleTimer = time.AfterFunc(sc.srv.IdleTimeout, sc.onIdleTimer)
defer sc.idleTimer.Stop()
}
go sc.readFrames() // closed by defer sc.conn.Close above
settingsTimer := time.AfterFunc(http2firstSettingsTimeout, sc.onSettingsTimer)
defer settingsTimer.Stop()
loopNum := 0
for {
loopNum++
select {
case wr := <-sc.wantWriteFrameCh:
if se, ok := wr.write.(http2StreamError); ok {
sc.resetStream(se)
break
}
sc.writeFrame(wr)
case res := <-sc.wroteFrameCh:
sc.wroteFrame(res)
case res := <-sc.readFrameCh:
// Process any written frames before reading new frames from the client since a
// written frame could have triggered a new stream to be started.
if sc.writingFrameAsync {
select {
case wroteRes := <-sc.wroteFrameCh:
sc.wroteFrame(wroteRes)
default:
}
}
if !sc.processFrameFromReader(res) {
return
}
res.readMore()
if settingsTimer != nil {
settingsTimer.Stop()
settingsTimer = nil
}
case m := <-sc.bodyReadCh:
sc.noteBodyRead(m.st, m.n)
case msg := <-sc.serveMsgCh:
switch v := msg.(type) {
case func(int):
v(loopNum) // for testing
case *http2serverMessage:
switch v {
case http2settingsTimerMsg:
sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr())
return
case http2idleTimerMsg:
sc.vlogf("connection is idle")
sc.goAway(http2ErrCodeNo)
case http2shutdownTimerMsg:
sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr())
return
case http2gracefulShutdownMsg:
sc.startGracefulShutdownInternal()
default:
panic("unknown timer")
}
case *http2startPushRequest:
sc.startPush(v)
case func(*http2serverConn):
v(sc)
default:
panic(fmt.Sprintf("unexpected type %T", v))
}
}
// If the peer is causing us to generate a lot of control frames,
// but not reading them from us, assume they are trying to make us
// run out of memory.
if sc.queuedControlFrames > sc.srv.maxQueuedControlFrames() {
sc.vlogf("http2: too many control frames in send queue, closing connection")
return
}
// Start the shutdown timer after sending a GOAWAY. When sending GOAWAY
// with no error code (graceful shutdown), don't start the timer until
// all open streams have been completed.
sentGoAway := sc.inGoAway && !sc.needToSendGoAway && !sc.writingFrame
gracefulShutdownComplete := sc.goAwayCode == http2ErrCodeNo && sc.curOpenStreams() == 0
if sentGoAway && sc.shutdownTimer == nil && (sc.goAwayCode != http2ErrCodeNo || gracefulShutdownComplete) {
sc.shutDownIn(http2goAwayTimeout)
}
}
}
func (sc *http2serverConn) awaitGracefulShutdown(sharedCh <-chan struct{}, privateCh chan struct{}) {
select {
case <-sc.doneServing:
case <-sharedCh:
close(privateCh)
}
}
type http2serverMessage int
// Message values sent to serveMsgCh.
var (
http2settingsTimerMsg = new(http2serverMessage)
http2idleTimerMsg = new(http2serverMessage)
http2shutdownTimerMsg = new(http2serverMessage)
http2gracefulShutdownMsg = new(http2serverMessage)
)
func (sc *http2serverConn) onSettingsTimer() { sc.sendServeMsg(http2settingsTimerMsg) }
func (sc *http2serverConn) onIdleTimer() { sc.sendServeMsg(http2idleTimerMsg) }
func (sc *http2serverConn) onShutdownTimer() { sc.sendServeMsg(http2shutdownTimerMsg) }
func (sc *http2serverConn) sendServeMsg(msg interface{}) {
sc.serveG.checkNotOn() // NOT
select {
case sc.serveMsgCh <- msg:
case <-sc.doneServing:
}
}
var http2errPrefaceTimeout = errors.New("timeout waiting for client preface")
// readPreface reads the ClientPreface greeting from the peer or
// returns errPrefaceTimeout on timeout, or an error if the greeting
// is invalid.
func (sc *http2serverConn) readPreface() error {
if sc.sawClientPreface {
return nil
}
errc := make(chan error, 1)
go func() {
// Read the client preface
buf := make([]byte, len(http2ClientPreface))
if _, err := io.ReadFull(sc.conn, buf); err != nil {
errc <- err
} else if !bytes.Equal(buf, http2clientPreface) {
errc <- fmt.Errorf("bogus greeting %q", buf)
} else {
errc <- nil
}
}()
timer := time.NewTimer(http2prefaceTimeout) // TODO: configurable on *Server?
defer timer.Stop()
select {
case <-timer.C:
return http2errPrefaceTimeout
case err := <-errc:
if err == nil {
if http2VerboseLogs {
sc.vlogf("http2: server: client %v said hello", sc.conn.RemoteAddr())
}
}
return err
}
}
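// For reference, the preface being matched is the fixed 24-byte string
// http2ClientPreface: "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n".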
var http2errChanPool = sync.Pool{
New: func() interface{} { return make(chan error, 1) },
}
var http2writeDataPool = sync.Pool{
New: func() interface{} { return new(http2writeData) },
}
// writeDataFromHandler writes DATA response frames from a handler on
// the given stream.
func (sc *http2serverConn) writeDataFromHandler(stream *http2stream, data []byte, endStream bool) error {
ch := http2errChanPool.Get().(chan error)
writeArg := http2writeDataPool.Get().(*http2writeData)
*writeArg = http2writeData{stream.id, data, endStream}
err := sc.writeFrameFromHandler(http2FrameWriteRequest{
write: writeArg,
stream: stream,
done: ch,
})
if err != nil {
return err
}
var frameWriteDone bool // the frame write is done (successfully or not)
select {
case err = <-ch:
frameWriteDone = true
case <-sc.doneServing:
return http2errClientDisconnected
case <-stream.cw:
// If both ch and stream.cw were ready (as might
// happen on the final Write after an http.Handler
// ends), prefer the write result. Otherwise this
// might just be us successfully closing the stream.
// The writeFrameAsync and serve goroutines guarantee
// that the ch send will happen before the stream.cw
// close.
select {
case err = <-ch:
frameWriteDone = true
default:
return http2errStreamClosed
}
}
http2errChanPool.Put(ch)
if frameWriteDone {
http2writeDataPool.Put(writeArg)
}
return err
}
// writeFrameFromHandler sends wr to sc.wantWriteFrameCh, but aborts
// if the connection has gone away.
//
// This must not be run from the serve goroutine itself, else it might
// deadlock writing to sc.wantWriteFrameCh (which is only mildly
// buffered and is read by serve itself). If you're on the serve
// goroutine, call writeFrame instead.
func (sc *http2serverConn) writeFrameFromHandler(wr http2FrameWriteRequest) error {
sc.serveG.checkNotOn() // NOT
select {
case sc.wantWriteFrameCh <- wr:
return nil
case <-sc.doneServing:
// Serve loop is gone.
// Client has closed their connection to the server.
return http2errClientDisconnected
}
}
// writeFrame schedules a frame to write and sends it if there's nothing
// already being written.
//
// There is no pushback here (the serve goroutine never blocks). It's
// the http.Handlers that block, waiting for their previous frames to
// make it onto the wire.
//
// If you're not on the serve goroutine, use writeFrameFromHandler instead.
func (sc *http2serverConn) writeFrame(wr http2FrameWriteRequest) {
sc.serveG.check()
// If true, wr will not be written and wr.done will not be signaled.
var ignoreWrite bool
// We are not allowed to write frames on closed streams. RFC 7540 Section
// 5.1.1 says: "An endpoint MUST NOT send frames other than PRIORITY on
// a closed stream." Our server never sends PRIORITY, so that exception
// does not apply.
//
// The serverConn might close an open stream while the stream's handler
// is still running. For example, the server might close a stream when it
// receives bad data from the client. If this happens, the handler might
// attempt to write a frame after the stream has been closed (since the
// handler hasn't yet been notified of the close). In this case, we simply
// ignore the frame. The handler will notice that the stream is closed when
// it waits for the frame to be written.
//
// As an exception to this rule, we allow sending RST_STREAM after close.
// This allows us to immediately reject new streams without tracking any
// state for those streams (except for the queued RST_STREAM frame). This
// may result in duplicate RST_STREAMs in some cases, but the client should
// ignore those.
if wr.StreamID() != 0 {
_, isReset := wr.write.(http2StreamError)
if state, _ := sc.state(wr.StreamID()); state == http2stateClosed && !isReset {
ignoreWrite = true
}
}
// Don't send a 100-continue response if we've already sent headers.
// See golang.org/issue/14030.
switch wr.write.(type) {
case *http2writeResHeaders:
wr.stream.wroteHeaders = true
case http2write100ContinueHeadersFrame:
if wr.stream.wroteHeaders {
// We do not need to notify wr.done because this frame is
// never written with wr.done != nil.
if wr.done != nil {
panic("wr.done != nil for write100ContinueHeadersFrame")
}
ignoreWrite = true
}
}
if !ignoreWrite {
if wr.isControl() {
sc.queuedControlFrames++
// For extra safety, detect wraparounds, which should not happen,
// and pull the plug.
if sc.queuedControlFrames < 0 {
sc.conn.Close()
}
}
sc.writeSched.Push(wr)
}
sc.scheduleFrameWrite()
}
// startFrameWrite starts a goroutine to write wr (in a separate
// goroutine since that might block on the network), and updates the
// serve goroutine's state about the world from the info in wr.
func (sc *http2serverConn) startFrameWrite(wr http2FrameWriteRequest) {
sc.serveG.check()
if sc.writingFrame {
panic("internal error: can only be writing one frame at a time")
}
st := wr.stream
if st != nil {
switch st.state {
case http2stateHalfClosedLocal:
switch wr.write.(type) {
case http2StreamError, http2handlerPanicRST, http2writeWindowUpdate:
// RFC 7540 Section 5.1 allows sending RST_STREAM, PRIORITY, and WINDOW_UPDATE
// in this state. (We never send PRIORITY from the server, so that is not checked.)
default:
panic(fmt.Sprintf("internal error: attempt to send frame on a half-closed-local stream: %v", wr))
}
case http2stateClosed:
panic(fmt.Sprintf("internal error: attempt to send frame on a closed stream: %v", wr))
}
}
if wpp, ok := wr.write.(*http2writePushPromise); ok {
var err error
wpp.promisedID, err = wpp.allocatePromisedID()
if err != nil {
sc.writingFrameAsync = false
wr.replyToWriter(err)
return
}
}
sc.writingFrame = true
sc.needsFrameFlush = true
if wr.write.staysWithinBuffer(sc.bw.Available()) {
sc.writingFrameAsync = false
err := wr.write.writeFrame(sc)
sc.wroteFrame(http2frameWriteResult{wr: wr, err: err})
} else if wd, ok := wr.write.(*http2writeData); ok {
// Encode the frame in the serve goroutine, to ensure we don't have
// any lingering asynchronous references to data passed to Write.
// See https://go.dev/issue/58446.
sc.framer.startWriteDataPadded(wd.streamID, wd.endStream, wd.p, nil)
sc.writingFrameAsync = true
go sc.writeFrameAsync(wr, wd)
} else {
sc.writingFrameAsync = true
go sc.writeFrameAsync(wr, nil)
}
}
// errHandlerPanicked is the error given to any callers blocked in a read from
// Request.Body when the main goroutine panics. Since most handlers read in the
// main ServeHTTP goroutine, this will show up rarely.
var http2errHandlerPanicked = errors.New("http2: handler panicked")
// wroteFrame is called on the serve goroutine with the result of
// whatever happened on writeFrameAsync.
func (sc *http2serverConn) wroteFrame(res http2frameWriteResult) {
sc.serveG.check()
if !sc.writingFrame {
panic("internal error: expected to be already writing a frame")
}
sc.writingFrame = false
sc.writingFrameAsync = false
wr := res.wr
if http2writeEndsStream(wr.write) {
st := wr.stream
if st == nil {
panic("internal error: expecting non-nil stream")
}
switch st.state {
case http2stateOpen:
// Here we would go to stateHalfClosedLocal in
// theory, but since our handler is done and
// the net/http package provides no mechanism
// for closing a ResponseWriter while still
// reading data (see possible TODO at top of
// this file), we go into closed state here
// anyway, after telling the peer we're
// hanging up on them. We'll transition to
// stateClosed after the RST_STREAM frame is
// written.
st.state = http2stateHalfClosedLocal
// Section 8.1: a server MAY request that the client abort
// transmission of a request without error by sending a
// RST_STREAM with an error code of NO_ERROR after sending
// a complete response.
sc.resetStream(http2streamError(st.id, http2ErrCodeNo))
case http2stateHalfClosedRemote:
sc.closeStream(st, http2errHandlerComplete)
}
} else {
switch v := wr.write.(type) {
case http2StreamError:
// st may be unknown if the RST_STREAM was generated to reject bad input.
if st, ok := sc.streams[v.StreamID]; ok {
sc.closeStream(st, v)
}
case http2handlerPanicRST:
sc.closeStream(wr.stream, http2errHandlerPanicked)
}
}
// Reply (if requested) to unblock the ServeHTTP goroutine.
wr.replyToWriter(res.err)
sc.scheduleFrameWrite()
}
// scheduleFrameWrite tickles the frame writing scheduler.
//
// If a frame is already being written, nothing happens. This will be called again
// when the frame is done being written.
//
// If a frame isn't being written and we need to send one, the best frame
// to send is selected by writeSched.
//
// If a frame isn't being written and there's nothing else to send, we
// flush the write buffer.
func (sc *http2serverConn) scheduleFrameWrite() {
sc.serveG.check()
if sc.writingFrame || sc.inFrameScheduleLoop {
return
}
sc.inFrameScheduleLoop = true
for !sc.writingFrameAsync {
if sc.needToSendGoAway {
sc.needToSendGoAway = false
sc.startFrameWrite(http2FrameWriteRequest{
write: &http2writeGoAway{
maxStreamID: sc.maxClientStreamID,
code: sc.goAwayCode,
},
})
continue
}
if sc.needToSendSettingsAck {
sc.needToSendSettingsAck = false
sc.startFrameWrite(http2FrameWriteRequest{write: http2writeSettingsAck{}})
continue
}
if !sc.inGoAway || sc.goAwayCode == http2ErrCodeNo {
if wr, ok := sc.writeSched.Pop(); ok {
if wr.isControl() {
sc.queuedControlFrames--
}
sc.startFrameWrite(wr)
continue
}
}
if sc.needsFrameFlush {
sc.startFrameWrite(http2FrameWriteRequest{write: http2flushFrameWriter{}})
sc.needsFrameFlush = false // after startFrameWrite, since it sets this true
continue
}
break
}
sc.inFrameScheduleLoop = false
}
// startGracefulShutdown gracefully shuts down a connection. This
// sends GOAWAY with ErrCodeNo to tell the client we're gracefully
// shutting down. The connection isn't closed until all current
// streams are done.
//
// startGracefulShutdown returns immediately; it does not wait until
// the connection has shut down.
func (sc *http2serverConn) startGracefulShutdown() {
sc.serveG.checkNotOn() // NOT
sc.shutdownOnce.Do(func() { sc.sendServeMsg(http2gracefulShutdownMsg) })
}
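// Graceful shutdown is wired up in http2ConfigureServer above via
// s.RegisterOnShutdown(conf.state.startGracefulShutdown), so a call to
// (*Server).Shutdown fans out to every active HTTP/2 connection.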
// After sending GOAWAY with an error code (non-graceful shutdown), the
// connection will close after goAwayTimeout.
//
// If we close the connection immediately after sending GOAWAY, there may
// be unsent data in our kernel receive buffer, which will cause the kernel
// to send a TCP RST on close() instead of a FIN. This RST will abort the
// connection immediately, whether or not the client had received the GOAWAY.
//
// Ideally we should delay for at least 1 RTT + epsilon so the client has
// a chance to read the GOAWAY and stop sending messages. Measuring RTT
// is hard, so we approximate with 1 second. See golang.org/issue/18701.
//
// This is a var so it can be shorter in tests, where all requests use the
// loopback interface making the expected RTT very small.
//
// TODO: configurable?
var http2goAwayTimeout = 1 * time.Second
func (sc *http2serverConn) startGracefulShutdownInternal() {
sc.goAway(http2ErrCodeNo)
}
func (sc *http2serverConn) goAway(code http2ErrCode) {
sc.serveG.check()
if sc.inGoAway {
if sc.goAwayCode == http2ErrCodeNo {
sc.goAwayCode = code
}
return
}
sc.inGoAway = true
sc.needToSendGoAway = true
sc.goAwayCode = code
sc.scheduleFrameWrite()
}
func (sc *http2serverConn) shutDownIn(d time.Duration) {
sc.serveG.check()
sc.shutdownTimer = time.AfterFunc(d, sc.onShutdownTimer)
}
func (sc *http2serverConn) resetStream(se http2StreamError) {
sc.serveG.check()
sc.writeFrame(http2FrameWriteRequest{write: se})
if st, ok := sc.streams[se.StreamID]; ok {
st.resetQueued = true
}
}
// processFrameFromReader processes the serve loop's read from readFrameCh from the
// frame-reading goroutine.
// processFrameFromReader returns whether the connection should be kept open.
func (sc *http2serverConn) processFrameFromReader(res http2readFrameResult) bool {
sc.serveG.check()
err := res.err
if err != nil {
if err == http2ErrFrameTooLarge {
sc.goAway(http2ErrCodeFrameSize)
return true // goAway will close the loop
}
clientGone := err == io.EOF || err == io.ErrUnexpectedEOF || http2isClosedConnError(err)
if clientGone {
// TODO: could we also get into this state if
// the peer does a half close
// (e.g. CloseWrite) because they're done
// sending frames but they're still wanting
// our open replies? Investigate.
// TODO: add CloseWrite to crypto/tls.Conn first
// so we have a way to test this? I suppose
// just for testing we could have a non-TLS mode.
return false
}
} else {
f := res.f
if http2VerboseLogs {
sc.vlogf("http2: server read frame %v", http2summarizeFrame(f))
}
err = sc.processFrame(f)
if err == nil {
return true
}
}
switch ev := err.(type) {
case http2StreamError:
sc.resetStream(ev)
return true
case http2goAwayFlowError:
sc.goAway(http2ErrCodeFlowControl)
return true
case http2ConnectionError:
sc.logf("http2: server connection error from %v: %v", sc.conn.RemoteAddr(), ev)
sc.goAway(http2ErrCode(ev))
return true // goAway will handle shutdown
default:
if res.err != nil {
sc.vlogf("http2: server closing client connection; error reading frame from client %s: %v", sc.conn.RemoteAddr(), err)
} else {
sc.logf("http2: server closing client connection: %v", err)
}
return false
}
}
func (sc *http2serverConn) processFrame(f http2Frame) error {
sc.serveG.check()
// First frame received must be SETTINGS.
if !sc.sawFirstSettings {
if _, ok := f.(*http2SettingsFrame); !ok {
return sc.countError("first_settings", http2ConnectionError(http2ErrCodeProtocol))
}
sc.sawFirstSettings = true
}
// Discard frames for streams initiated after the identified last
// stream sent in a GOAWAY, or all frames after sending an error.
// We still need to return connection-level flow control for DATA frames.
// RFC 9113 Section 6.8.
if sc.inGoAway && (sc.goAwayCode != http2ErrCodeNo || f.Header().StreamID > sc.maxClientStreamID) {
if f, ok := f.(*http2DataFrame); ok {
if !sc.inflow.take(f.Length) {
return sc.countError("data_flow", http2streamError(f.Header().StreamID, http2ErrCodeFlowControl))
}
sc.sendWindowUpdate(nil, int(f.Length)) // conn-level
}
return nil
}
switch f := f.(type) {
case *http2SettingsFrame:
return sc.processSettings(f)
case *http2MetaHeadersFrame:
return sc.processHeaders(f)
case *http2WindowUpdateFrame:
return sc.processWindowUpdate(f)
case *http2PingFrame:
return sc.processPing(f)
case *http2DataFrame:
return sc.processData(f)
case *http2RSTStreamFrame:
return sc.processResetStream(f)
case *http2PriorityFrame:
return sc.processPriority(f)
case *http2GoAwayFrame:
return sc.processGoAway(f)
case *http2PushPromiseFrame:
// A client cannot push. Thus, servers MUST treat the receipt of a PUSH_PROMISE
// frame as a connection error (Section 5.4.1) of type PROTOCOL_ERROR.
return sc.countError("push_promise", http2ConnectionError(http2ErrCodeProtocol))
default:
sc.vlogf("http2: server ignoring frame: %v", f.Header())
return nil
}
}
func (sc *http2serverConn) processPing(f *http2PingFrame) error {
sc.serveG.check()
if f.IsAck() {
// 6.7 PING: " An endpoint MUST NOT respond to PING frames
// containing this flag."
return nil
}
if f.StreamID != 0 {
// "PING frames are not associated with any individual
// stream. If a PING frame is received with a stream
// identifier field value other than 0x0, the recipient MUST
// respond with a connection error (Section 5.4.1) of type
// PROTOCOL_ERROR."
return sc.countError("ping_on_stream", http2ConnectionError(http2ErrCodeProtocol))
}
sc.writeFrame(http2FrameWriteRequest{write: http2writePingAck{f}})
return nil
}
func (sc *http2serverConn) processWindowUpdate(f *http2WindowUpdateFrame) error {
sc.serveG.check()
switch {
case f.StreamID != 0: // stream-level flow control
state, st := sc.state(f.StreamID)
if state == http2stateIdle {
// Section 5.1: "Receiving any frame other than HEADERS
// or PRIORITY on a stream in this state MUST be
// treated as a connection error (Section 5.4.1) of
// type PROTOCOL_ERROR."
return sc.countError("stream_idle", http2ConnectionError(http2ErrCodeProtocol))
}
if st == nil {
// "WINDOW_UPDATE can be sent by a peer that has sent a
// frame bearing the END_STREAM flag. This means that a
// receiver could receive a WINDOW_UPDATE frame on a "half
// closed (remote)" or "closed" stream. A receiver MUST
// NOT treat this as an error, see Section 5.1."
return nil
}
if !st.flow.add(int32(f.Increment)) {
return sc.countError("bad_flow", http2streamError(f.StreamID, http2ErrCodeFlowControl))
}
default: // connection-level flow control
if !sc.flow.add(int32(f.Increment)) {
return http2goAwayFlowError{}
}
}
sc.scheduleFrameWrite()
return nil
}
func (sc *http2serverConn) processResetStream(f *http2RSTStreamFrame) error {
sc.serveG.check()
state, st := sc.state(f.StreamID)
if state == http2stateIdle {
// 6.4 "RST_STREAM frames MUST NOT be sent for a
// stream in the "idle" state. If a RST_STREAM frame
// identifying an idle stream is received, the
// recipient MUST treat this as a connection error
// (Section 5.4.1) of type PROTOCOL_ERROR."
return sc.countError("reset_idle_stream", http2ConnectionError(http2ErrCodeProtocol))
}
if st != nil {
st.cancelCtx()
sc.closeStream(st, http2streamError(f.StreamID, f.ErrCode))
}
return nil
}
func (sc *http2serverConn) closeStream(st *http2stream, err error) {
sc.serveG.check()
if st.state == http2stateIdle || st.state == http2stateClosed {
panic(fmt.Sprintf("invariant; can't close stream in state %v", st.state))
}
st.state = http2stateClosed
if st.readDeadline != nil {
st.readDeadline.Stop()
}
if st.writeDeadline != nil {
st.writeDeadline.Stop()
}
if st.isPushed() {
sc.curPushedStreams--
} else {
sc.curClientStreams--
}
delete(sc.streams, st.id)
if len(sc.streams) == 0 {
sc.setConnState(StateIdle)
if sc.srv.IdleTimeout != 0 {
sc.idleTimer.Reset(sc.srv.IdleTimeout)
}
if http2h1ServerKeepAlivesDisabled(sc.hs) {
sc.startGracefulShutdownInternal()
}
}
if p := st.body; p != nil {
// Return any buffered unread bytes worth of conn-level flow control.
// See golang.org/issue/16481
sc.sendWindowUpdate(nil, p.Len())
p.CloseWithError(err)
}
if e, ok := err.(http2StreamError); ok {
if e.Cause != nil {
err = e.Cause
} else {
err = http2errStreamClosed
}
}
st.closeErr = err
st.cw.Close() // signals Handler's CloseNotifier, unblocks writes, etc
sc.writeSched.CloseStream(st.id)
}
func (sc *http2serverConn) processSettings(f *http2SettingsFrame) error {
sc.serveG.check()
if f.IsAck() {
sc.unackedSettings--
if sc.unackedSettings < 0 {
// Why is the peer ACKing settings we never sent?
// The spec doesn't mention this case, but
// hang up on them anyway.
return sc.countError("ack_mystery", http2ConnectionError(http2ErrCodeProtocol))
}
return nil
}
if f.NumSettings() > 100 || f.HasDuplicates() {
// This isn't actually in the spec, but hang up on
// suspiciously large settings frames or those with
// duplicate entries.
return sc.countError("settings_big_or_dups", http2ConnectionError(http2ErrCodeProtocol))
}
if err := f.ForeachSetting(sc.processSetting); err != nil {
return err
}
// TODO: judging by RFC 7540, Section 6.5.3 each SETTINGS frame should be
// acknowledged individually, even if multiple are received before the ACK.
sc.needToSendSettingsAck = true
sc.scheduleFrameWrite()
return nil
}
func (sc *http2serverConn) processSetting(s http2Setting) error {
sc.serveG.check()
if err := s.Valid(); err != nil {
return err
}
if http2VerboseLogs {
sc.vlogf("http2: server processing setting %v", s)
}
switch s.ID {
case http2SettingHeaderTableSize:
sc.hpackEncoder.SetMaxDynamicTableSize(s.Val)
case http2SettingEnablePush:
sc.pushEnabled = s.Val != 0
case http2SettingMaxConcurrentStreams:
sc.clientMaxStreams = s.Val
case http2SettingInitialWindowSize:
return sc.processSettingInitialWindowSize(s.Val)
case http2SettingMaxFrameSize:
sc.maxFrameSize = int32(s.Val) // the maximum valid s.Val is < 2^31
case http2SettingMaxHeaderListSize:
sc.peerMaxHeaderListSize = s.Val
default:
// Unknown setting: "An endpoint that receives a SETTINGS
// frame with any unknown or unsupported identifier MUST
// ignore that setting."
if http2VerboseLogs {
sc.vlogf("http2: server ignoring unknown setting %v", s)
}
}
return nil
}
func (sc *http2serverConn) processSettingInitialWindowSize(val uint32) error {
sc.serveG.check()
// Note: val already validated to be within range by
// processSetting's Valid call.
// "A SETTINGS frame can alter the initial flow control window
// size for all current streams. When the value of
// SETTINGS_INITIAL_WINDOW_SIZE changes, a receiver MUST
// adjust the size of all stream flow control windows that it
// maintains by the difference between the new value and the
// old value."
old := sc.initialStreamSendWindowSize
sc.initialStreamSendWindowSize = int32(val)
growth := int32(val) - old // may be negative
for _, st := range sc.streams {
if !st.flow.add(growth) {
// 6.9.2 Initial Flow Control Window Size
// "An endpoint MUST treat a change to
// SETTINGS_INITIAL_WINDOW_SIZE that causes any flow
// control window to exceed the maximum size as a
// connection error (Section 5.4.1) of type
// FLOW_CONTROL_ERROR."
return sc.countError("setting_win_size", http2ConnectionError(http2ErrCodeFlowControl))
}
}
return nil
}
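// Worked example: if the peer lowers SETTINGS_INITIAL_WINDOW_SIZE from the
// RFC default of 65535 to 16384, growth is 16384 - 65535 = -49151 and each
// open stream's send window shrinks by 49151. A negative window is permitted
// (RFC 7540 Section 6.9.2); the stream simply cannot send until the window
// becomes positive again.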
func (sc *http2serverConn) processData(f *http2DataFrame) error {
sc.serveG.check()
id := f.Header().StreamID
data := f.Data()
state, st := sc.state(id)
if id == 0 || state == http2stateIdle {
// Section 6.1: "DATA frames MUST be associated with a
// stream. If a DATA frame is received whose stream
// identifier field is 0x0, the recipient MUST respond
// with a connection error (Section 5.4.1) of type
// PROTOCOL_ERROR."
//
// Section 5.1: "Receiving any frame other than HEADERS
// or PRIORITY on a stream in this state MUST be
// treated as a connection error (Section 5.4.1) of
// type PROTOCOL_ERROR."
return sc.countError("data_on_idle", http2ConnectionError(http2ErrCodeProtocol))
}
// "If a DATA frame is received whose stream is not in "open"
// or "half closed (local)" state, the recipient MUST respond
// with a stream error (Section 5.4.2) of type STREAM_CLOSED."
if st == nil || state != http2stateOpen || st.gotTrailerHeader || st.resetQueued {
// This includes sending a RST_STREAM if the stream is
// in stateHalfClosedLocal (which currently means that
// the http.Handler returned, so it's done reading &
// done writing). Try to stop the client from sending
// more DATA.
// But still enforce their connection-level flow control,
// and return any flow control bytes since we're not going
// to consume them.
if !sc.inflow.take(f.Length) {
return sc.countError("data_flow", http2streamError(id, http2ErrCodeFlowControl))
}
sc.sendWindowUpdate(nil, int(f.Length)) // conn-level
if st != nil && st.resetQueued {
// Already have a stream error in flight. Don't send another.
return nil
}
return sc.countError("closed", http2streamError(id, http2ErrCodeStreamClosed))
}
if st.body == nil {
panic("internal error: should have a body in this state")
}
// Sender sending more than they'd declared?
if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes {
if !sc.inflow.take(f.Length) {
return sc.countError("data_flow", http2streamError(id, http2ErrCodeFlowControl))
}
sc.sendWindowUpdate(nil, int(f.Length)) // conn-level
st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes))
// RFC 7540, sec 8.1.2.6: A request or response is also malformed if the
// value of a content-length header field does not equal the sum of the
// DATA frame payload lengths that form the body.
return sc.countError("send_too_much", http2streamError(id, http2ErrCodeProtocol))
}
if f.Length > 0 {
// Check whether the client has flow control quota.
if !http2takeInflows(&sc.inflow, &st.inflow, f.Length) {
return sc.countError("flow_on_data_length", http2streamError(id, http2ErrCodeFlowControl))
}
if len(data) > 0 {
wrote, err := st.body.Write(data)
if err != nil {
sc.sendWindowUpdate(nil, int(f.Length)-wrote)
return sc.countError("body_write_err", http2streamError(id, http2ErrCodeStreamClosed))
}
if wrote != len(data) {
panic("internal error: bad Writer")
}
st.bodyBytes += int64(len(data))
}
// Return any padded flow control now, since we won't
// refund it later on body reads.
// Call sendWindowUpdate even if there is no padding,
// to return buffered flow control credit if the sent
// window has shrunk.
pad := int32(f.Length) - int32(len(data))
sc.sendWindowUpdate32(nil, pad)
sc.sendWindowUpdate32(st, pad)
}
if f.StreamEnded() {
st.endStream()
}
return nil
}
func (sc *http2serverConn) processGoAway(f *http2GoAwayFrame) error {
sc.serveG.check()
if f.ErrCode != http2ErrCodeNo {
sc.logf("http2: received GOAWAY %+v, starting graceful shutdown", f)
} else {
sc.vlogf("http2: received GOAWAY %+v, starting graceful shutdown", f)
}
sc.startGracefulShutdownInternal()
// http://tools.ietf.org/html/rfc7540#section-6.8
// We should not create any new streams, which means we should disable push.
sc.pushEnabled = false
return nil
}
// isPushed reports whether the stream is server-initiated.
func (st *http2stream) isPushed() bool {
return st.id%2 == 0
}
// endStream closes a Request.Body's pipe. It is called when a DATA
// frame says a request body is over (or after trailers).
func (st *http2stream) endStream() {
sc := st.sc
sc.serveG.check()
if st.declBodyBytes != -1 && st.declBodyBytes != st.bodyBytes {
st.body.CloseWithError(fmt.Errorf("request declared a Content-Length of %d but only wrote %d bytes",
st.declBodyBytes, st.bodyBytes))
} else {
st.body.closeWithErrorAndCode(io.EOF, st.copyTrailersToHandlerRequest)
st.body.CloseWithError(io.EOF)
}
st.state = http2stateHalfClosedRemote
}
// copyTrailersToHandlerRequest is run in the Handler's goroutine in
// its Request.Body.Read just before it gets io.EOF.
func (st *http2stream) copyTrailersToHandlerRequest() {
for k, vv := range st.trailer {
if _, ok := st.reqTrailer[k]; ok {
// Only copy it over if it was pre-declared.
st.reqTrailer[k] = vv
}
}
}
// onReadTimeout is run on its own goroutine (from time.AfterFunc)
// when the stream's ReadTimeout has fired.
func (st *http2stream) onReadTimeout() {
// Wrap the ErrDeadlineExceeded to avoid callers depending on us
// returning the bare error.
st.body.CloseWithError(fmt.Errorf("%w", os.ErrDeadlineExceeded))
}
// onWriteTimeout is run on its own goroutine (from time.AfterFunc)
// when the stream's WriteTimeout has fired.
func (st *http2stream) onWriteTimeout() {
st.sc.writeFrameFromHandler(http2FrameWriteRequest{write: http2StreamError{
StreamID: st.id,
Code: http2ErrCodeInternal,
Cause: os.ErrDeadlineExceeded,
}})
}
func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error {
sc.serveG.check()
id := f.StreamID
// http://tools.ietf.org/html/rfc7540#section-5.1.1
// Streams initiated by a client MUST use odd-numbered stream
// identifiers. [...] An endpoint that receives an unexpected
// stream identifier MUST respond with a connection error
// (Section 5.4.1) of type PROTOCOL_ERROR.
if id%2 != 1 {
return sc.countError("headers_even", http2ConnectionError(http2ErrCodeProtocol))
}
// A HEADERS frame can be used to create a new stream or
// send a trailer for an open one. If we already have a stream
// open, let it process its own HEADERS frame (trailers at this
// point, if it's valid).
if st := sc.streams[f.StreamID]; st != nil {
if st.resetQueued {
// We're sending RST_STREAM to close the stream, so don't bother
// processing this frame.
return nil
}
// RFC 7540, sec 5.1: If an endpoint receives additional frames, other than
// WINDOW_UPDATE, PRIORITY, or RST_STREAM, for a stream that is in
// this state, it MUST respond with a stream error (Section 5.4.2) of
// type STREAM_CLOSED.
if st.state == http2stateHalfClosedRemote {
return sc.countError("headers_half_closed", http2streamError(id, http2ErrCodeStreamClosed))
}
return st.processTrailerHeaders(f)
}
// [...] The identifier of a newly established stream MUST be
// numerically greater than all streams that the initiating
// endpoint has opened or reserved. [...] An endpoint that
// receives an unexpected stream identifier MUST respond with
// a connection error (Section 5.4.1) of type PROTOCOL_ERROR.
if id <= sc.maxClientStreamID {
return sc.countError("stream_went_down", http2ConnectionError(http2ErrCodeProtocol))
}
sc.maxClientStreamID = id
if sc.idleTimer != nil {
sc.idleTimer.Stop()
}
// http://tools.ietf.org/html/rfc7540#section-5.1.2
// [...] Endpoints MUST NOT exceed the limit set by their peer. An
// endpoint that receives a HEADERS frame that causes their
// advertised concurrent stream limit to be exceeded MUST treat
// this as a stream error (Section 5.4.2) of type PROTOCOL_ERROR
// or REFUSED_STREAM.
if sc.curClientStreams+1 > sc.advMaxStreams {
if sc.unackedSettings == 0 {
// They should know better.
return sc.countError("over_max_streams", http2streamError(id, http2ErrCodeProtocol))
}
// Assume it's a network race, where they just haven't
// received our last SETTINGS update. But actually
// this can't happen yet, because we don't yet provide
// a way for users to adjust server parameters at
// runtime.
return sc.countError("over_max_streams_race", http2streamError(id, http2ErrCodeRefusedStream))
}
initialState := http2stateOpen
if f.StreamEnded() {
initialState = http2stateHalfClosedRemote
}
st := sc.newStream(id, 0, initialState)
if f.HasPriority() {
if err := sc.checkPriority(f.StreamID, f.Priority); err != nil {
return err
}
sc.writeSched.AdjustStream(st.id, f.Priority)
}
rw, req, err := sc.newWriterAndRequest(st, f)
if err != nil {
return err
}
st.reqTrailer = req.Trailer
if st.reqTrailer != nil {
st.trailer = make(Header)
}
st.body = req.Body.(*http2requestBody).pipe // may be nil
st.declBodyBytes = req.ContentLength
handler := sc.handler.ServeHTTP
if f.Truncated {
// Their header list was too long. Send a 431 error.
handler = http2handleHeaderListTooLong
} else if err := http2checkValidHTTP2RequestHeaders(req.Header); err != nil {
handler = http2new400Handler(err)
}
// The net/http package sets the read deadline from the
// http.Server.ReadTimeout during the TLS handshake, but then
// passes the connection off to us with the deadline already
// set. Disarm it here after the request headers are read,
// similar to how the http1 server works. Here it's
// technically more like the http1 Server's ReadHeaderTimeout
// (in Go 1.8), though. That's a more sane option anyway.
if sc.hs.ReadTimeout != 0 {
sc.conn.SetReadDeadline(time.Time{})
if st.body != nil {
st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout)
}
}
go sc.runHandler(rw, req, handler)
return nil
}
func (sc *http2serverConn) upgradeRequest(req *Request) {
sc.serveG.check()
id := uint32(1)
sc.maxClientStreamID = id
st := sc.newStream(id, 0, http2stateHalfClosedRemote)
st.reqTrailer = req.Trailer
if st.reqTrailer != nil {
st.trailer = make(Header)
}
rw := sc.newResponseWriter(st, req)
// Disable any read deadline set by the net/http package
// prior to the upgrade.
if sc.hs.ReadTimeout != 0 {
sc.conn.SetReadDeadline(time.Time{})
}
go sc.runHandler(rw, req, sc.handler.ServeHTTP)
}
func (st *http2stream) processTrailerHeaders(f *http2MetaHeadersFrame) error {
sc := st.sc
sc.serveG.check()
if st.gotTrailerHeader {
return sc.countError("dup_trailers", http2ConnectionError(http2ErrCodeProtocol))
}
st.gotTrailerHeader = true
if !f.StreamEnded() {
return sc.countError("trailers_not_ended", http2streamError(st.id, http2ErrCodeProtocol))
}
if len(f.PseudoFields()) > 0 {
return sc.countError("trailers_pseudo", http2streamError(st.id, http2ErrCodeProtocol))
}
if st.trailer != nil {
for _, hf := range f.RegularFields() {
key := sc.canonicalHeader(hf.Name)
if !httpguts.ValidTrailerHeader(key) {
// TODO: send more details to the peer somehow. But http2 has
// no way to send debug data at a stream level. Discuss with
// HTTP folk.
return sc.countError("trailers_bogus", http2streamError(st.id, http2ErrCodeProtocol))
}
st.trailer[key] = append(st.trailer[key], hf.Value)
}
}
st.endStream()
return nil
}
func (sc *http2serverConn) checkPriority(streamID uint32, p http2PriorityParam) error {
if streamID == p.StreamDep {
// Section 5.3.1: "A stream cannot depend on itself. An endpoint MUST treat
// this as a stream error (Section 5.4.2) of type PROTOCOL_ERROR."
// Section 5.3.3 says that a stream can depend on one of its dependencies,
// so it's only self-dependencies that are forbidden.
return sc.countError("priority", http2streamError(streamID, http2ErrCodeProtocol))
}
return nil
}
func (sc *http2serverConn) processPriority(f *http2PriorityFrame) error {
if err := sc.checkPriority(f.StreamID, f.http2PriorityParam); err != nil {
return err
}
sc.writeSched.AdjustStream(f.StreamID, f.http2PriorityParam)
return nil
}
func (sc *http2serverConn) newStream(id, pusherID uint32, state http2streamState) *http2stream {
sc.serveG.check()
if id == 0 {
panic("internal error: cannot create stream with id 0")
}
ctx, cancelCtx := context.WithCancel(sc.baseCtx)
st := &http2stream{
sc: sc,
id: id,
state: state,
ctx: ctx,
cancelCtx: cancelCtx,
}
st.cw.Init()
st.flow.conn = &sc.flow // link to conn-level counter
st.flow.add(sc.initialStreamSendWindowSize)
st.inflow.init(sc.srv.initialStreamRecvWindowSize())
if sc.hs.WriteTimeout != 0 {
st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout)
}
sc.streams[id] = st
sc.writeSched.OpenStream(st.id, http2OpenStreamOptions{PusherID: pusherID})
if st.isPushed() {
sc.curPushedStreams++
} else {
sc.curClientStreams++
}
if sc.curOpenStreams() == 1 {
sc.setConnState(StateActive)
}
return st
}
func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHeadersFrame) (*http2responseWriter, *Request, error) {
sc.serveG.check()
rp := http2requestParam{
method: f.PseudoValue("method"),
scheme: f.PseudoValue("scheme"),
authority: f.PseudoValue("authority"),
path: f.PseudoValue("path"),
}
isConnect := rp.method == "CONNECT"
if isConnect {
if rp.path != "" || rp.scheme != "" || rp.authority == "" {
return nil, nil, sc.countError("bad_connect", http2streamError(f.StreamID, http2ErrCodeProtocol))
}
} else if rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") {
// See 8.1.2.6 Malformed Requests and Responses:
//
// Malformed requests or responses that are detected
// MUST be treated as a stream error (Section 5.4.2)
// of type PROTOCOL_ERROR."
//
// 8.1.2.3 Request Pseudo-Header Fields
// "All HTTP/2 requests MUST include exactly one valid
// value for the :method, :scheme, and :path
// pseudo-header fields"
return nil, nil, sc.countError("bad_path_method", http2streamError(f.StreamID, http2ErrCodeProtocol))
}
rp.header = make(Header)
for _, hf := range f.RegularFields() {
rp.header.Add(sc.canonicalHeader(hf.Name), hf.Value)
}
if rp.authority == "" {
rp.authority = rp.header.Get("Host")
}
rw, req, err := sc.newWriterAndRequestNoBody(st, rp)
if err != nil {
return nil, nil, err
}
bodyOpen := !f.StreamEnded()
if bodyOpen {
if vv, ok := rp.header["Content-Length"]; ok {
if cl, err := strconv.ParseUint(vv[0], 10, 63); err == nil {
req.ContentLength = int64(cl)
} else {
req.ContentLength = 0
}
} else {
req.ContentLength = -1
}
req.Body.(*http2requestBody).pipe = &http2pipe{
b: &http2dataBuffer{expected: req.ContentLength},
}
}
return rw, req, nil
}
type http2requestParam struct {
method string
scheme, authority, path string
header Header
}
func (sc *http2serverConn) newWriterAndRequestNoBody(st *http2stream, rp http2requestParam) (*http2responseWriter, *Request, error) {
sc.serveG.check()
var tlsState *tls.ConnectionState // nil if not scheme https
if rp.scheme == "https" {
tlsState = sc.tlsState
}
needsContinue := httpguts.HeaderValuesContainsToken(rp.header["Expect"], "100-continue")
if needsContinue {
rp.header.Del("Expect")
}
// Merge Cookie headers into one "; "-delimited value.
if cookies := rp.header["Cookie"]; len(cookies) > 1 {
rp.header.Set("Cookie", strings.Join(cookies, "; "))
}
// Setup Trailers
var trailer Header
for _, v := range rp.header["Trailer"] {
for _, key := range strings.Split(v, ",") {
key = CanonicalHeaderKey(textproto.TrimString(key))
switch key {
case "Transfer-Encoding", "Trailer", "Content-Length":
// Bogus. (copy of http1 rules)
// Ignore.
default:
if trailer == nil {
trailer = make(Header)
}
trailer[key] = nil
}
}
}
delete(rp.header, "Trailer")
var url_ *url.URL
var requestURI string
if rp.method == "CONNECT" {
url_ = &url.URL{Host: rp.authority}
requestURI = rp.authority // mimic HTTP/1 server behavior
} else {
var err error
url_, err = url.ParseRequestURI(rp.path)
if err != nil {
return nil, nil, sc.countError("bad_path", http2streamError(st.id, http2ErrCodeProtocol))
}
requestURI = rp.path
}
body := &http2requestBody{
conn: sc,
stream: st,
needsContinue: needsContinue,
}
req := &Request{
Method: rp.method,
URL: url_,
RemoteAddr: sc.remoteAddrStr,
Header: rp.header,
RequestURI: requestURI,
Proto: "HTTP/2.0",
ProtoMajor: 2,
ProtoMinor: 0,
TLS: tlsState,
Host: rp.authority,
Body: body,
Trailer: trailer,
}
req = req.WithContext(st.ctx)
rw := sc.newResponseWriter(st, req)
return rw, req, nil
}
func (sc *http2serverConn) newResponseWriter(st *http2stream, req *Request) *http2responseWriter {
rws := http2responseWriterStatePool.Get().(*http2responseWriterState)
bwSave := rws.bw
*rws = http2responseWriterState{} // zero all the fields
rws.conn = sc
rws.bw = bwSave
rws.bw.Reset(http2chunkWriter{rws})
rws.stream = st
rws.req = req
return &http2responseWriter{rws: rws}
}
// Run on its own goroutine.
func (sc *http2serverConn) runHandler(rw *http2responseWriter, req *Request, handler func(ResponseWriter, *Request)) {
didPanic := true
defer func() {
rw.rws.stream.cancelCtx()
if req.MultipartForm != nil {
req.MultipartForm.RemoveAll()
}
if didPanic {
e := recover()
sc.writeFrameFromHandler(http2FrameWriteRequest{
write: http2handlerPanicRST{rw.rws.stream.id},
stream: rw.rws.stream,
})
// Same as net/http:
if e != nil && e != ErrAbortHandler {
const size = 64 << 10
buf := make([]byte, size)
buf = buf[:runtime.Stack(buf, false)]
sc.logf("http2: panic serving %v: %v\n%s", sc.conn.RemoteAddr(), e, buf)
}
return
}
rw.handlerDone()
}()
handler(rw, req)
didPanic = false
}
func http2handleHeaderListTooLong(w ResponseWriter, r *Request) {
// 10.5.1 Limits on Header Block Size:
// .. "A server that receives a larger header block than it is
// willing to handle can send an HTTP 431 (Request Header Fields Too
// Large) status code"
const statusRequestHeaderFieldsTooLarge = 431 // only in Go 1.6+
w.WriteHeader(statusRequestHeaderFieldsTooLarge)
io.WriteString(w, "<h1>HTTP Error 431</h1><p>Request Header Field(s) Too Large</p>")
}
// called from handler goroutines.
// h may be nil.
func (sc *http2serverConn) writeHeaders(st *http2stream, headerData *http2writeResHeaders) error {
sc.serveG.checkNotOn() // NOT on
var errc chan error
if headerData.h != nil {
// If there's a header map (which we don't own), we have to block on
// waiting for this frame to be written, so that an http.Flush mid-handler
// writes out the correct value of keys before a handler later potentially
// mutates it.
errc = http2errChanPool.Get().(chan error)
}
if err := sc.writeFrameFromHandler(http2FrameWriteRequest{
write: headerData,
stream: st,
done: errc,
}); err != nil {
return err
}
if errc != nil {
select {
case err := <-errc:
http2errChanPool.Put(errc)
return err
case <-sc.doneServing:
return http2errClientDisconnected
case <-st.cw:
return http2errStreamClosed
}
}
return nil
}
// called from handler goroutines.
func (sc *http2serverConn) write100ContinueHeaders(st *http2stream) {
sc.writeFrameFromHandler(http2FrameWriteRequest{
write: http2write100ContinueHeadersFrame{st.id},
stream: st,
})
}
// A bodyReadMsg tells the server loop that the http.Handler read n
// bytes of the DATA from the client on the given stream.
type http2bodyReadMsg struct {
st *http2stream
n int
}
// called from handler goroutines.
// Notes that the handler for the given stream ID read n bytes of its body
// and schedules flow control tokens to be sent.
func (sc *http2serverConn) noteBodyReadFromHandler(st *http2stream, n int, err error) {
sc.serveG.checkNotOn() // NOT on
if n > 0 {
select {
case sc.bodyReadCh <- http2bodyReadMsg{st, n}:
case <-sc.doneServing:
}
}
}
func (sc *http2serverConn) noteBodyRead(st *http2stream, n int) {
sc.serveG.check()
sc.sendWindowUpdate(nil, n) // conn-level
if st.state != http2stateHalfClosedRemote && st.state != http2stateClosed {
// Don't send this WINDOW_UPDATE if the stream is closed
// remotely.
sc.sendWindowUpdate(st, n)
}
}
// st may be nil for conn-level
func (sc *http2serverConn) sendWindowUpdate32(st *http2stream, n int32) {
sc.sendWindowUpdate(st, int(n))
}
// st may be nil for conn-level
func (sc *http2serverConn) sendWindowUpdate(st *http2stream, n int) {
sc.serveG.check()
var streamID uint32
var send int32
if st == nil {
send = sc.inflow.add(n)
} else {
streamID = st.id
send = st.inflow.add(n)
}
if send == 0 {
return
}
sc.writeFrame(http2FrameWriteRequest{
write: http2writeWindowUpdate{streamID: streamID, n: uint32(send)},
stream: st,
})
}
// requestBody is the Handler's Request.Body type.
// Read and Close may be called concurrently.
type http2requestBody struct {
_ http2incomparable
stream *http2stream
conn *http2serverConn
closeOnce sync.Once // for use by Close only
sawEOF bool // for use by Read only
pipe *http2pipe // non-nil if we have an HTTP entity message body
needsContinue bool // need to send a 100-continue
}
func (b *http2requestBody) Close() error {
b.closeOnce.Do(func() {
if b.pipe != nil {
b.pipe.BreakWithError(http2errClosedBody)
}
})
return nil
}
func (b *http2requestBody) Read(p []byte) (n int, err error) {
if b.needsContinue {
b.needsContinue = false
b.conn.write100ContinueHeaders(b.stream)
}
if b.pipe == nil || b.sawEOF {
return 0, io.EOF
}
n, err = b.pipe.Read(p)
if err == io.EOF {
b.sawEOF = true
}
if b.conn == nil && http2inTests {
return
}
b.conn.noteBodyReadFromHandler(b.stream, n, err)
return
}
// responseWriter is the http.ResponseWriter implementation. It's
// intentionally small (1 pointer wide) to minimize garbage. The
// responseWriterState pointer inside is zeroed at the end of a
// request (in handlerDone) and calls on the responseWriter thereafter
// simply crash (caller's mistake), but the much larger responseWriterState
// and buffers are reused between multiple requests.
type http2responseWriter struct {
rws *http2responseWriterState
}
// Optional http.ResponseWriter interfaces implemented.
var (
_ CloseNotifier = (*http2responseWriter)(nil)
_ Flusher = (*http2responseWriter)(nil)
_ http2stringWriter = (*http2responseWriter)(nil)
)
type http2responseWriterState struct {
// immutable within a request:
stream *http2stream
req *Request
conn *http2serverConn
// TODO: adjust buffer writing sizes based on server config, frame size updates from peer, etc
bw *bufio.Writer // writing to a chunkWriter{this *responseWriterState}
// mutated by http.Handler goroutine:
handlerHeader Header // nil until called
snapHeader Header // snapshot of handlerHeader at WriteHeader time
trailers []string // set in writeChunk
status int // status code passed to WriteHeader
wroteHeader bool // WriteHeader called (explicitly or implicitly). Not necessarily sent to user yet.
sentHeader bool // have we sent the header frame?
handlerDone bool // handler has finished
dirty bool // a Write failed; don't reuse this responseWriterState
sentContentLen int64 // non-zero if handler set a Content-Length header
wroteBytes int64
closeNotifierMu sync.Mutex // guards closeNotifierCh
closeNotifierCh chan bool // nil until first used
}
type http2chunkWriter struct{ rws *http2responseWriterState }
func (cw http2chunkWriter) Write(p []byte) (n int, err error) {
n, err = cw.rws.writeChunk(p)
if err == http2errStreamClosed {
// If writing failed because the stream has been closed,
// return the reason it was closed.
err = cw.rws.stream.closeErr
}
return n, err
}
func (rws *http2responseWriterState) hasTrailers() bool { return len(rws.trailers) > 0 }
func (rws *http2responseWriterState) hasNonemptyTrailers() bool {
for _, trailer := range rws.trailers {
if _, ok := rws.handlerHeader[trailer]; ok {
return true
}
}
return false
}
// declareTrailer is called for each Trailer header when the
// response header is written. It notes that a header will need to be
// written in the trailers at the end of the response.
func (rws *http2responseWriterState) declareTrailer(k string) {
k = CanonicalHeaderKey(k)
if !httpguts.ValidTrailerHeader(k) {
// Forbidden by RFC 7230, section 4.1.2.
rws.conn.logf("ignoring invalid trailer %q", k)
return
}
if !http2strSliceContains(rws.trailers, k) {
rws.trailers = append(rws.trailers, k)
}
}
// writeChunk writes chunks from the bufio.Writer. But because
// bufio.Writer may bypass its chunking, sometimes p may be
// arbitrarily large.
//
// writeChunk is also responsible (on the first chunk) for sending the
// HEADERS frame of the response.
func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) {
if !rws.wroteHeader {
rws.writeHeader(200)
}
if rws.handlerDone {
rws.promoteUndeclaredTrailers()
}
isHeadResp := rws.req.Method == "HEAD"
if !rws.sentHeader {
rws.sentHeader = true
var ctype, clen string
if clen = rws.snapHeader.Get("Content-Length"); clen != "" {
rws.snapHeader.Del("Content-Length")
if cl, err := strconv.ParseUint(clen, 10, 63); err == nil {
rws.sentContentLen = int64(cl)
} else {
clen = ""
}
}
if clen == "" && rws.handlerDone && http2bodyAllowedForStatus(rws.status) && (len(p) > 0 || !isHeadResp) {
clen = strconv.Itoa(len(p))
}
_, hasContentType := rws.snapHeader["Content-Type"]
// If the Content-Encoding is non-blank, we shouldn't
// sniff the body. See Issue golang.org/issue/31753.
ce := rws.snapHeader.Get("Content-Encoding")
hasCE := len(ce) > 0
if !hasCE && !hasContentType && http2bodyAllowedForStatus(rws.status) && len(p) > 0 {
ctype = DetectContentType(p)
}
var date string
if _, ok := rws.snapHeader["Date"]; !ok {
// TODO(bradfitz): be faster here, like net/http? measure.
date = time.Now().UTC().Format(TimeFormat)
}
for _, v := range rws.snapHeader["Trailer"] {
http2foreachHeaderElement(v, rws.declareTrailer)
}
// "Connection" headers aren't allowed in HTTP/2 (RFC 7540, 8.1.2.2),
// but respect "Connection" == "close" to mean sending a GOAWAY and tearing
// down the TCP connection when idle, like we do for HTTP/1.
// TODO: remove more Connection-specific header fields here, in addition
// to "Connection".
if _, ok := rws.snapHeader["Connection"]; ok {
v := rws.snapHeader.Get("Connection")
delete(rws.snapHeader, "Connection")
if v == "close" {
rws.conn.startGracefulShutdown()
}
}
endStream := (rws.handlerDone && !rws.hasTrailers() && len(p) == 0) || isHeadResp
err = rws.conn.writeHeaders(rws.stream, &http2writeResHeaders{
streamID: rws.stream.id,
httpResCode: rws.status,
h: rws.snapHeader,
endStream: endStream,
contentType: ctype,
contentLength: clen,
date: date,
})
if err != nil {
rws.dirty = true
return 0, err
}
if endStream {
return 0, nil
}
}
if isHeadResp {
return len(p), nil
}
if len(p) == 0 && !rws.handlerDone {
return 0, nil
}
// only send trailers if they have actually been defined by the
// server handler.
hasNonemptyTrailers := rws.hasNonemptyTrailers()
endStream := rws.handlerDone && !hasNonemptyTrailers
if len(p) > 0 || endStream {
// only send a 0 byte DATA frame if we're ending the stream.
if err := rws.conn.writeDataFromHandler(rws.stream, p, endStream); err != nil {
rws.dirty = true
return 0, err
}
}
if rws.handlerDone && hasNonemptyTrailers {
err = rws.conn.writeHeaders(rws.stream, &http2writeResHeaders{
streamID: rws.stream.id,
h: rws.handlerHeader,
trailers: rws.trailers,
endStream: true,
})
if err != nil {
rws.dirty = true
}
return len(p), err
}
return len(p), nil
}
// TrailerPrefix is a magic prefix for ResponseWriter.Header map keys
// that, if present, signals that the map entry is actually for
// the response trailers, and not the response headers. The prefix
// is stripped after the ServeHTTP call finishes and the values are
// sent in the trailers.
//
// This mechanism is intended only for trailers that are not known
// prior to the headers being written. If the set of trailers is fixed
// or known before the header is written, the normal Go trailers mechanism
// is preferred:
//
// https://golang.org/pkg/net/http/#ResponseWriter
// https://golang.org/pkg/net/http/#example_ResponseWriter_trailers
const http2TrailerPrefix = "Trailer:"
// promoteUndeclaredTrailers permits http.Handlers to set trailers
// after the header has already been flushed. Because the Go
// ResponseWriter interface has no way to set Trailers (only the
// Header), and because we didn't want to expand the ResponseWriter
// interface, and because nobody used trailers, and because RFC 7230
// says you SHOULD (but not must) predeclare any trailers in the
// header, the official ResponseWriter rules said trailers in Go must
// be predeclared, and then we reuse the same ResponseWriter.Header()
// map to mean both Headers and Trailers. When it's time to write the
// Trailers, we pick out the fields of Headers that were declared as
// trailers. That worked for a while, until we found the first major
// user of Trailers in the wild: gRPC (using them only over http2),
// and gRPC libraries permit setting trailers mid-stream without
// predeclaring them. So: change of plans. We still permit the old
// way, but we also permit this hack: if a Header() key begins with
// "Trailer:", the suffix of that key is a Trailer. Because ':' is an
// invalid token byte anyway, there is no ambiguity. (And it's already
// filtered out.) It's mildly hacky, but not terrible.
//
// This method runs after the Handler is done and promotes any Header
// fields to be trailers.
func (rws *http2responseWriterState) promoteUndeclaredTrailers() {
for k, vv := range rws.handlerHeader {
if !strings.HasPrefix(k, http2TrailerPrefix) {
continue
}
trailerKey := strings.TrimPrefix(k, http2TrailerPrefix)
rws.declareTrailer(trailerKey)
rws.handlerHeader[CanonicalHeaderKey(trailerKey)] = vv
}
if len(rws.trailers) > 1 {
sorter := http2sorterPool.Get().(*http2sorter)
sorter.SortStrings(rws.trailers)
http2sorterPool.Put(sorter)
}
}
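// Illustrative sketch (not part of the original source): a handler that uses
// the TrailerPrefix mechanism above. The key is promoted to a real trailer by
// promoteUndeclaredTrailers after the handler returns; the "Grpc-Status" name
// is only an example.
var _ = HandlerFunc(func(w ResponseWriter, r *Request) {
	w.WriteHeader(StatusOK)
	io.WriteString(w, "body")
	// "Trailer:Grpc-Status" is stripped to the trailer "Grpc-Status".
	w.Header().Set(http2TrailerPrefix+"Grpc-Status", "0")
})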
func (w *http2responseWriter) SetReadDeadline(deadline time.Time) error {
st := w.rws.stream
if !deadline.IsZero() && deadline.Before(time.Now()) {
// If we're setting a deadline in the past, reset the stream immediately
// so reads after SetReadDeadline returns will fail.
st.onReadTimeout()
return nil
}
w.rws.conn.sendServeMsg(func(sc *http2serverConn) {
if st.readDeadline != nil {
if !st.readDeadline.Stop() {
// Deadline already exceeded, or stream has been closed.
return
}
}
if deadline.IsZero() {
st.readDeadline = nil
} else if st.readDeadline == nil {
st.readDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onReadTimeout)
} else {
st.readDeadline.Reset(deadline.Sub(time.Now()))
}
})
return nil
}
func (w *http2responseWriter) SetWriteDeadline(deadline time.Time) error {
st := w.rws.stream
if !deadline.IsZero() && deadline.Before(time.Now()) {
// If we're setting a deadline in the past, reset the stream immediately
// so writes after SetWriteDeadline returns will fail.
st.onWriteTimeout()
return nil
}
w.rws.conn.sendServeMsg(func(sc *http2serverConn) {
if st.writeDeadline != nil {
if !st.writeDeadline.Stop() {
// Deadline already exceeded, or stream has been closed.
return
}
}
if deadline.IsZero() {
st.writeDeadline = nil
} else if st.writeDeadline == nil {
st.writeDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onWriteTimeout)
} else {
st.writeDeadline.Reset(deadline.Sub(time.Now()))
}
})
return nil
}
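// Illustrative sketch (not part of the original source): a handler applying a
// per-request write deadline via the optional method implemented above; the
// 5-second budget is arbitrary.
var _ = HandlerFunc(func(w ResponseWriter, r *Request) {
	if d, ok := w.(interface{ SetWriteDeadline(time.Time) error }); ok {
		_ = d.SetWriteDeadline(time.Now().Add(5 * time.Second))
	}
})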
func (w *http2responseWriter) Flush() {
w.FlushError()
}
func (w *http2responseWriter) FlushError() error {
rws := w.rws
if rws == nil {
panic("Header called after Handler finished")
}
var err error
if rws.bw.Buffered() > 0 {
err = rws.bw.Flush()
} else {
// The bufio.Writer won't call chunkWriter.Write
// (writeChunk with zero bytes, so we have to do it
// ourselves to force the HTTP response header and/or
// final DATA frame (with END_STREAM) to be sent.
_, err = http2chunkWriter{rws}.Write(nil)
if err == nil {
select {
case <-rws.stream.cw:
err = rws.stream.closeErr
default:
}
}
}
return err
}
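// Illustrative sketch (not part of the original source): forcing the response
// header and any buffered body out early through the Flusher interface, which
// funnels into FlushError above.
var _ = HandlerFunc(func(w ResponseWriter, r *Request) {
	io.WriteString(w, "working...\n")
	if f, ok := w.(Flusher); ok {
		f.Flush()
	}
})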
func (w *http2responseWriter) CloseNotify() <-chan bool {
rws := w.rws
if rws == nil {
panic("CloseNotify called after Handler finished")
}
rws.closeNotifierMu.Lock()
ch := rws.closeNotifierCh
if ch == nil {
ch = make(chan bool, 1)
rws.closeNotifierCh = ch
cw := rws.stream.cw
go func() {
cw.Wait() // wait for close
ch <- true
}()
}
rws.closeNotifierMu.Unlock()
return ch
}
func (w *http2responseWriter) Header() Header {
rws := w.rws
if rws == nil {
panic("Header called after Handler finished")
}
if rws.handlerHeader == nil {
rws.handlerHeader = make(Header)
}
return rws.handlerHeader
}
// checkWriteHeaderCode is a copy of net/http's checkWriteHeaderCode.
func http2checkWriteHeaderCode(code int) {
// Issue 22880: require valid WriteHeader status codes.
// For now we only enforce that it's three digits.
// In the future we might block things over 599 (600 and above aren't defined
// at http://httpwg.org/specs/rfc7231.html#status.codes).
// But for now any three digits.
//
// We used to send "HTTP/1.1 000 0" on the wire in responses but there's
// no equivalent bogus thing we can realistically send in HTTP/2,
// so we'll consistently panic instead and help people find their bugs
// early. (We can't return an error from WriteHeader even if we wanted to.)
if code < 100 || code > 999 {
panic(fmt.Sprintf("invalid WriteHeader code %v", code))
}
}
func (w *http2responseWriter) WriteHeader(code int) {
rws := w.rws
if rws == nil {
panic("WriteHeader called after Handler finished")
}
rws.writeHeader(code)
}
func (rws *http2responseWriterState) writeHeader(code int) {
if rws.wroteHeader {
return
}
http2checkWriteHeaderCode(code)
// Handle informational headers
if code >= 100 && code <= 199 {
// Per RFC 8297 we must not clear the current header map
h := rws.handlerHeader
_, cl := h["Content-Length"]
_, te := h["Transfer-Encoding"]
if cl || te {
h = h.Clone()
h.Del("Content-Length")
h.Del("Transfer-Encoding")
}
if rws.conn.writeHeaders(rws.stream, &http2writeResHeaders{
streamID: rws.stream.id,
httpResCode: code,
h: h,
endStream: rws.handlerDone && !rws.hasTrailers(),
}) != nil {
rws.dirty = true
}
return
}
rws.wroteHeader = true
rws.status = code
if len(rws.handlerHeader) > 0 {
rws.snapHeader = http2cloneHeader(rws.handlerHeader)
}
}
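// Illustrative sketch (not part of the original source): emitting an
// informational 103 response before the final status, exercising the 1xx
// branch above; per RFC 8297 the header map survives the interim response.
var _ = HandlerFunc(func(w ResponseWriter, r *Request) {
	w.Header().Add("Link", "</style.css>; rel=preload; as=style")
	w.WriteHeader(103) // interim response; wroteHeader stays false
	w.WriteHeader(StatusOK)
})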
func http2cloneHeader(h Header) Header {
h2 := make(Header, len(h))
for k, vv := range h {
vv2 := make([]string, len(vv))
copy(vv2, vv)
h2[k] = vv2
}
return h2
}
// The Life Of A Write is like this:
//
// * Handler calls w.Write or w.WriteString ->
// * -> rws.bw (*bufio.Writer) ->
// * (Handler might call Flush)
// * -> chunkWriter{rws}
// * -> responseWriterState.writeChunk(p []byte)
// * -> responseWriterState.writeChunk (most of the magic; see comment there)
func (w *http2responseWriter) Write(p []byte) (n int, err error) {
return w.write(len(p), p, "")
}
func (w *http2responseWriter) WriteString(s string) (n int, err error) {
return w.write(len(s), nil, s)
}
// Exactly one of dataB or dataS is set; the other is its zero value.
func (w *http2responseWriter) write(lenData int, dataB []byte, dataS string) (n int, err error) {
rws := w.rws
if rws == nil {
panic("Write called after Handler finished")
}
if !rws.wroteHeader {
w.WriteHeader(200)
}
if !http2bodyAllowedForStatus(rws.status) {
return 0, ErrBodyNotAllowed
}
rws.wroteBytes += int64(len(dataB)) + int64(len(dataS)) // only one can be set
if rws.sentContentLen != 0 && rws.wroteBytes > rws.sentContentLen {
// TODO: send a RST_STREAM
return 0, errors.New("http2: handler wrote more than declared Content-Length")
}
if dataB != nil {
return rws.bw.Write(dataB)
} else {
return rws.bw.WriteString(dataS)
}
}
func (w *http2responseWriter) handlerDone() {
rws := w.rws
dirty := rws.dirty
rws.handlerDone = true
w.Flush()
w.rws = nil
if !dirty {
// Only recycle the pool if all prior Write calls to
// the serverConn goroutine completed successfully. If
// they returned earlier due to resets from the peer
// there might still be write goroutines outstanding
// from the serverConn referencing the rws memory. See
// issue 20704.
http2responseWriterStatePool.Put(rws)
}
}
// Push errors.
var (
http2ErrRecursivePush = errors.New("http2: recursive push not allowed")
http2ErrPushLimitReached = errors.New("http2: push would exceed peer's SETTINGS_MAX_CONCURRENT_STREAMS")
)
var _ Pusher = (*http2responseWriter)(nil)
func (w *http2responseWriter) Push(target string, opts *PushOptions) error {
st := w.rws.stream
sc := st.sc
sc.serveG.checkNotOn()
// No recursive pushes: "PUSH_PROMISE frames MUST only be sent on a peer-initiated stream."
// http://tools.ietf.org/html/rfc7540#section-6.6
if st.isPushed() {
return http2ErrRecursivePush
}
if opts == nil {
opts = new(PushOptions)
}
// Default options.
if opts.Method == "" {
opts.Method = "GET"
}
if opts.Header == nil {
opts.Header = Header{}
}
wantScheme := "http"
if w.rws.req.TLS != nil {
wantScheme = "https"
}
// Validate the request.
u, err := url.Parse(target)
if err != nil {
return err
}
if u.Scheme == "" {
if !strings.HasPrefix(target, "/") {
return fmt.Errorf("target must be an absolute URL or an absolute path: %q", target)
}
u.Scheme = wantScheme
u.Host = w.rws.req.Host
} else {
if u.Scheme != wantScheme {
return fmt.Errorf("cannot push URL with scheme %q from request with scheme %q", u.Scheme, wantScheme)
}
if u.Host == "" {
return errors.New("URL must have a host")
}
}
for k := range opts.Header {
if strings.HasPrefix(k, ":") {
return fmt.Errorf("promised request headers cannot include pseudo header %q", k)
}
// These headers are meaningful only if the request has a body,
// but PUSH_PROMISE requests cannot have a body.
// http://tools.ietf.org/html/rfc7540#section-8.2
// Also disallow Host, since the promised URL must be absolute.
if http2asciiEqualFold(k, "content-length") ||
http2asciiEqualFold(k, "content-encoding") ||
http2asciiEqualFold(k, "trailer") ||
http2asciiEqualFold(k, "te") ||
http2asciiEqualFold(k, "expect") ||
http2asciiEqualFold(k, "host") {
return fmt.Errorf("promised request headers cannot include %q", k)
}
}
if err := http2checkValidHTTP2RequestHeaders(opts.Header); err != nil {
return err
}
// The RFC effectively limits promised requests to GET and HEAD:
// "Promised requests MUST be cacheable [GET, HEAD, or POST], and MUST be safe [GET or HEAD]"
// http://tools.ietf.org/html/rfc7540#section-8.2
if opts.Method != "GET" && opts.Method != "HEAD" {
return fmt.Errorf("method %q must be GET or HEAD", opts.Method)
}
msg := &http2startPushRequest{
parent: st,
method: opts.Method,
url: u,
header: http2cloneHeader(opts.Header),
done: http2errChanPool.Get().(chan error),
}
select {
case <-sc.doneServing:
return http2errClientDisconnected
case <-st.cw:
return http2errStreamClosed
case sc.serveMsgCh <- msg:
}
select {
case <-sc.doneServing:
return http2errClientDisconnected
case <-st.cw:
return http2errStreamClosed
case err := <-msg.done:
http2errChanPool.Put(msg.done)
return err
}
}
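// Illustrative sketch (not part of the original source): a handler issuing a
// server push before writing its body. The pushed path is hypothetical, and
// errors such as ErrNotSupported are deliberately ignored.
var _ = HandlerFunc(func(w ResponseWriter, r *Request) {
	if p, ok := w.(Pusher); ok {
		_ = p.Push("/static/app.css", nil) // best-effort
	}
	io.WriteString(w, "<html>...</html>")
})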
type http2startPushRequest struct {
parent *http2stream
method string
url *url.URL
header Header
done chan error
}
func (sc *http2serverConn) startPush(msg *http2startPushRequest) {
sc.serveG.check()
// http://tools.ietf.org/html/rfc7540#section-6.6.
// PUSH_PROMISE frames MUST only be sent on a peer-initiated stream that
// is in either the "open" or "half-closed (remote)" state.
if msg.parent.state != http2stateOpen && msg.parent.state != http2stateHalfClosedRemote {
// responseWriter.Push checks that the stream is peer-initiated.
msg.done <- http2errStreamClosed
return
}
// http://tools.ietf.org/html/rfc7540#section-6.6.
if !sc.pushEnabled {
msg.done <- ErrNotSupported
return
}
// PUSH_PROMISE frames must be sent in increasing order by stream ID, so
// we allocate an ID for the promised stream lazily, when the PUSH_PROMISE
// is written. Once the ID is allocated, we start the request handler.
allocatePromisedID := func() (uint32, error) {
sc.serveG.check()
// Check this again, just in case. Technically, we might have received
// an updated SETTINGS by the time we got around to writing this frame.
if !sc.pushEnabled {
return 0, ErrNotSupported
}
// http://tools.ietf.org/html/rfc7540#section-6.5.2.
if sc.curPushedStreams+1 > sc.clientMaxStreams {
return 0, http2ErrPushLimitReached
}
// http://tools.ietf.org/html/rfc7540#section-5.1.1.
// Streams initiated by the server MUST use even-numbered identifiers.
// A server that is unable to establish a new stream identifier can send a GOAWAY
// frame so that the client is forced to open a new connection for new streams.
if sc.maxPushPromiseID+2 >= 1<<31 {
sc.startGracefulShutdownInternal()
return 0, http2ErrPushLimitReached
}
sc.maxPushPromiseID += 2
promisedID := sc.maxPushPromiseID
// http://tools.ietf.org/html/rfc7540#section-8.2.
// Strictly speaking, the new stream should start in "reserved (local)", then
// transition to "half closed (remote)" after sending the initial HEADERS, but
// we start in "half closed (remote)" for simplicity.
// See further comments at the definition of stateHalfClosedRemote.
promised := sc.newStream(promisedID, msg.parent.id, http2stateHalfClosedRemote)
rw, req, err := sc.newWriterAndRequestNoBody(promised, http2requestParam{
method: msg.method,
scheme: msg.url.Scheme,
authority: msg.url.Host,
path: msg.url.RequestURI(),
header: http2cloneHeader(msg.header), // clone since handler runs concurrently with writing the PUSH_PROMISE
})
if err != nil {
// Should not happen, since we've already validated msg.url.
panic(fmt.Sprintf("newWriterAndRequestNoBody(%+v): %v", msg.url, err))
}
go sc.runHandler(rw, req, sc.handler.ServeHTTP)
return promisedID, nil
}
sc.writeFrame(http2FrameWriteRequest{
write: &http2writePushPromise{
streamID: msg.parent.id,
method: msg.method,
url: msg.url,
h: msg.header,
allocatePromisedID: allocatePromisedID,
},
stream: msg.parent,
done: msg.done,
})
}
// foreachHeaderElement splits v according to the "#rule" construction
// in RFC 7230 section 7 and calls fn for each non-empty element.
func http2foreachHeaderElement(v string, fn func(string)) {
v = textproto.TrimString(v)
if v == "" {
return
}
if !strings.Contains(v, ",") {
fn(v)
return
}
for _, f := range strings.Split(v, ",") {
if f = textproto.TrimString(f); f != "" {
fn(f)
}
}
}
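// Illustrative sketch (not part of the original source; the function name is
// hypothetical): foreachHeaderElement trims whitespace and skips empty
// elements when splitting a comma-separated value.
func http2exampleForeachHeaderElement() []string {
	var got []string
	http2foreachHeaderElement(" X-A , , X-B ", func(f string) {
		got = append(got, f) // appends "X-A", then "X-B"
	})
	return got
}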
// From http://httpwg.org/specs/rfc7540.html#rfc.section.8.1.2.2
var http2connHeaders = []string{
"Connection",
"Keep-Alive",
"Proxy-Connection",
"Transfer-Encoding",
"Upgrade",
}
// checkValidHTTP2RequestHeaders checks whether h is a valid HTTP/2 request,
// per RFC 7540 Section 8.1.2.2.
// The returned error is reported to users.
func http2checkValidHTTP2RequestHeaders(h Header) error {
for _, k := range http2connHeaders {
if _, ok := h[k]; ok {
return fmt.Errorf("request header %q is not valid in HTTP/2", k)
}
}
te := h["Te"]
if len(te) > 0 && (len(te) > 1 || (te[0] != "trailers" && te[0] != "")) {
return errors.New(`request header "TE" may only be "trailers" in HTTP/2`)
}
return nil
}
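// Illustrative sketch (not part of the original source; the function name is
// hypothetical): any connection-specific header from the list above makes a
// request invalid in HTTP/2.
func http2exampleInvalidRequestHeaders() error {
	h := Header{"Keep-Alive": {"timeout=5"}}
	return http2checkValidHTTP2RequestHeaders(h) // non-nil error
}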
func http2new400Handler(err error) HandlerFunc {
return func(w ResponseWriter, r *Request) {
Error(w, err.Error(), StatusBadRequest)
}
}
// h1ServerKeepAlivesDisabled reports whether hs has its keep-alives
// disabled. See comments on h1ServerShutdownChan above for why
// the code is written this way.
func http2h1ServerKeepAlivesDisabled(hs *Server) bool {
var x interface{} = hs
type I interface {
doKeepAlives() bool
}
if hs, ok := x.(I); ok {
return !hs.doKeepAlives()
}
return false
}
func (sc *http2serverConn) countError(name string, err error) error {
if sc == nil || sc.srv == nil {
return err
}
f := sc.srv.CountError
if f == nil {
return err
}
var typ string
var code http2ErrCode
switch e := err.(type) {
case http2ConnectionError:
typ = "conn"
code = http2ErrCode(e)
case http2StreamError:
typ = "stream"
code = http2ErrCode(e.Code)
default:
return err
}
codeStr := http2errCodeName[code]
if codeStr == "" {
codeStr = strconv.Itoa(int(code))
}
f(fmt.Sprintf("%s_%s_%s", typ, codeStr, name))
return err
}
const (
// transportDefaultConnFlow is how many connection-level flow control
// tokens we give the server at start-up, past the default 64k.
http2transportDefaultConnFlow = 1 << 30
// transportDefaultStreamFlow is how many stream-level flow
// control tokens we announce to the peer, and how many bytes
// we buffer per stream.
http2transportDefaultStreamFlow = 4 << 20
http2defaultUserAgent = "Go-http-client/2.0"
// initialMaxConcurrentStreams is a connection's maxConcurrentStreams until
// it has received the server's initial SETTINGS frame, which corresponds
// with the spec's minimum recommended value.
http2initialMaxConcurrentStreams = 100
// defaultMaxConcurrentStreams is a connection's default maxConcurrentStreams
// if the server doesn't include one in its initial SETTINGS frame.
http2defaultMaxConcurrentStreams = 1000
)
// Transport is an HTTP/2 Transport.
//
// A Transport internally caches connections to servers. It is safe
// for concurrent use by multiple goroutines.
type http2Transport struct {
// DialTLSContext specifies an optional dial function with context for
// creating TLS connections for requests.
//
// If both DialTLSContext and DialTLS are nil, tls.Dial is used.
//
// If the returned net.Conn has a ConnectionState method like tls.Conn,
// it will be used to set http.Response.TLS.
DialTLSContext func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error)
// DialTLS specifies an optional dial function for creating
// TLS connections for requests.
//
// If both DialTLSContext and DialTLS are nil, tls.Dial is used.
//
// Deprecated: Use DialTLSContext instead, which allows the transport
// to cancel dials as soon as they are no longer needed.
// If both are set, DialTLSContext takes priority.
DialTLS func(network, addr string, cfg *tls.Config) (net.Conn, error)
// TLSClientConfig specifies the TLS configuration to use with
// tls.Client. If nil, the default configuration is used.
TLSClientConfig *tls.Config
// ConnPool optionally specifies an alternate connection pool to use.
// If nil, the default is used.
ConnPool http2ClientConnPool
// DisableCompression, if true, prevents the Transport from
// requesting compression with an "Accept-Encoding: gzip"
// request header when the Request contains no existing
// Accept-Encoding value. If the Transport requests gzip on
// its own and gets a gzipped response, it's transparently
// decoded in the Response.Body. However, if the user
// explicitly requested gzip it is not automatically
// uncompressed.
DisableCompression bool
// AllowHTTP, if true, permits HTTP/2 requests using the insecure,
// plain-text "http" scheme. Note that this does not enable h2c support.
AllowHTTP bool
// MaxHeaderListSize is the http2 SETTINGS_MAX_HEADER_LIST_SIZE to
// send in the initial settings frame. It is how many bytes
// of response headers are allowed. Unlike the http2 spec, zero here
// means to use a default limit (currently 10MB). If you actually
// want to advertise an unlimited value to the peer, Transport
// interprets the highest possible value here (0xffffffff or 1<<32-1)
// to mean no limit.
MaxHeaderListSize uint32
// MaxReadFrameSize is the http2 SETTINGS_MAX_FRAME_SIZE to send in the
// initial settings frame. It is the size in bytes of the largest frame
// payload that the sender is willing to receive. If 0, no setting is
// sent, and the value is provided by the peer, which should be 16384
// according to the spec:
// https://datatracker.ietf.org/doc/html/rfc7540#section-6.5.2.
// Values are bounded in the range 16k to 16M.
MaxReadFrameSize uint32
// MaxDecoderHeaderTableSize optionally specifies the http2
// SETTINGS_HEADER_TABLE_SIZE to send in the initial settings frame. It
// informs the remote endpoint of the maximum size of the header compression
// table used to decode header blocks, in octets. If zero, the default value
// of 4096 is used.
MaxDecoderHeaderTableSize uint32
// MaxEncoderHeaderTableSize optionally specifies an upper limit for the
// header compression table used for encoding request headers. Received
// SETTINGS_HEADER_TABLE_SIZE settings are capped at this limit. If zero,
// the default value of 4096 is used.
MaxEncoderHeaderTableSize uint32
// StrictMaxConcurrentStreams controls whether the server's
// SETTINGS_MAX_CONCURRENT_STREAMS should be respected
// globally. If false, new TCP connections are created to the
// server as needed to keep each under the per-connection
// SETTINGS_MAX_CONCURRENT_STREAMS limit. If true, the
// server's SETTINGS_MAX_CONCURRENT_STREAMS is interpreted as
// a global limit and callers of RoundTrip block when needed,
// waiting for their turn.
StrictMaxConcurrentStreams bool
// ReadIdleTimeout is the timeout after which a health check using ping
// frame will be carried out if no frame is received on the connection.
// Note that a ping response is considered a received frame, so if
// there is no other traffic on the connection, the health check will
// be performed every ReadIdleTimeout interval.
// If zero, no health check is performed.
ReadIdleTimeout time.Duration
// PingTimeout is the timeout after which the connection will be closed
// if a response to Ping is not received.
// Defaults to 15s.
PingTimeout time.Duration
// WriteByteTimeout is the timeout after which the connection will be
// closed if no data can be written to it. The timeout begins when data is
// available to write, and is extended whenever any bytes are written.
WriteByteTimeout time.Duration
// CountError, if non-nil, is called on HTTP/2 transport errors.
// It's intended to increment a metric for monitoring, such
// as an expvar or Prometheus metric.
// The errType consists of only ASCII word characters.
CountError func(errType string)
// t1, if non-nil, is the standard library Transport using
// this transport. Its settings are used (but not its
// RoundTrip method, etc).
t1 *Transport
connPoolOnce sync.Once
connPoolOrDef http2ClientConnPool // non-nil version of ConnPool
}
func (t *http2Transport) maxHeaderListSize() uint32 {
if t.MaxHeaderListSize == 0 {
return 10 << 20
}
if t.MaxHeaderListSize == 0xffffffff {
return 0
}
return t.MaxHeaderListSize
}
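// Illustrative sketch (not part of the original source; the function name is
// hypothetical) of the sentinel semantics above: zero selects the 10MB
// default, while 0xffffffff advertises "no limit" (returned as 0).
func http2exampleMaxHeaderListSize() {
	t := &http2Transport{}
	_ = t.maxHeaderListSize() // 10 << 20: the default
	t.MaxHeaderListSize = 0xffffffff
	_ = t.maxHeaderListSize() // 0: unlimited
}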
func (t *http2Transport) maxFrameReadSize() uint32 {
if t.MaxReadFrameSize == 0 {
return 0 // use the default provided by the peer
}
if t.MaxReadFrameSize < http2minMaxFrameSize {
return http2minMaxFrameSize
}
if t.MaxReadFrameSize > http2maxFrameSize {
return http2maxFrameSize
}
return t.MaxReadFrameSize
}
func (t *http2Transport) disableCompression() bool {
return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression)
}
func (t *http2Transport) pingTimeout() time.Duration {
if t.PingTimeout == 0 {
return 15 * time.Second
}
return t.PingTimeout
}
// ConfigureTransport configures a net/http HTTP/1 Transport to use HTTP/2.
// It returns an error if t1 has already been HTTP/2-enabled.
//
// Use ConfigureTransports instead to configure the HTTP/2 Transport.
func http2ConfigureTransport(t1 *Transport) error {
_, err := http2ConfigureTransports(t1)
return err
}
// ConfigureTransports configures a net/http HTTP/1 Transport to use HTTP/2.
// It returns a new HTTP/2 Transport for further configuration.
// It returns an error if t1 has already been HTTP/2-enabled.
func http2ConfigureTransports(t1 *Transport) (*http2Transport, error) {
return http2configureTransports(t1)
}
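// Illustrative sketch (not part of the original source; the function name is
// hypothetical): enabling HTTP/2 on an existing HTTP/1 Transport and using it
// through a Client.
func http2exampleConfigureTransports() (*Client, error) {
	t1 := &Transport{}
	if _, err := http2ConfigureTransports(t1); err != nil {
		return nil, err
	}
	return &Client{Transport: t1}, nil
}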
func http2configureTransports(t1 *Transport) (*http2Transport, error) {
connPool := new(http2clientConnPool)
t2 := &http2Transport{
ConnPool: http2noDialClientConnPool{connPool},
t1: t1,
}
connPool.t = t2
if err := http2registerHTTPSProtocol(t1, http2noDialH2RoundTripper{t2}); err != nil {
return nil, err
}
if t1.TLSClientConfig == nil {
t1.TLSClientConfig = new(tls.Config)
}
if !http2strSliceContains(t1.TLSClientConfig.NextProtos, "h2") {
t1.TLSClientConfig.NextProtos = append([]string{"h2"}, t1.TLSClientConfig.NextProtos...)
}
if !http2strSliceContains(t1.TLSClientConfig.NextProtos, "http/1.1") {
t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1")
}
upgradeFn := func(authority string, c *tls.Conn) RoundTripper {
addr := http2authorityAddr("https", authority)
if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil {
go c.Close()
return http2erringRoundTripper{err}
} else if !used {
// Turns out we don't need this c.
// For example, two goroutines made requests to the same host
// at the same time, both kicking off TCP dials. (since protocol
// was unknown)
go c.Close()
}
return t2
}
if m := t1.TLSNextProto; len(m) == 0 {
t1.TLSNextProto = map[string]func(string, *tls.Conn) RoundTripper{
"h2": upgradeFn,
}
} else {
m["h2"] = upgradeFn
}
return t2, nil
}
func (t *http2Transport) connPool() http2ClientConnPool {
t.connPoolOnce.Do(t.initConnPool)
return t.connPoolOrDef
}
func (t *http2Transport) initConnPool() {
if t.ConnPool != nil {
t.connPoolOrDef = t.ConnPool
} else {
t.connPoolOrDef = &http2clientConnPool{t: t}
}
}
// ClientConn is the state of a single HTTP/2 client connection to an
// HTTP/2 server.
type http2ClientConn struct {
t *http2Transport
tconn net.Conn // usually *tls.Conn, except specialized impls
tconnClosed bool
tlsState *tls.ConnectionState // nil only for specialized impls
reused uint32 // whether conn is being reused; atomic
singleUse bool // whether being used for a single http.Request
getConnCalled bool // used by clientConnPool
// readLoop goroutine fields:
readerDone chan struct{} // closed on error
readerErr error // set before readerDone is closed
idleTimeout time.Duration // or 0 for never
idleTimer *time.Timer
mu sync.Mutex // guards following
cond *sync.Cond // hold mu; broadcast on flow/closed changes
flow http2outflow // our conn-level flow control quota (cs.outflow is per stream)
inflow http2inflow // peer's conn-level flow control
doNotReuse bool // whether conn is marked to not be reused for any future requests
closing bool
closed bool
seenSettings bool // true if we've seen a settings frame, false otherwise
wantSettingsAck bool // we sent a SETTINGS frame and haven't heard back
goAway *http2GoAwayFrame // if non-nil, the GoAwayFrame we received
goAwayDebug string // goAway frame's debug data, retained as a string
streams map[uint32]*http2clientStream // client-initiated
streamsReserved int // incr by ReserveNewRequest; decr on RoundTrip
nextStreamID uint32
pendingRequests int // requests blocked and waiting to be sent because len(streams) == maxConcurrentStreams
pings map[[8]byte]chan struct{} // in flight ping data to notification channel
br *bufio.Reader
lastActive time.Time
lastIdle time.Time // time last idle
// Settings from peer: (also guarded by wmu)
maxFrameSize uint32
maxConcurrentStreams uint32
peerMaxHeaderListSize uint64
peerMaxHeaderTableSize uint32
initialWindowSize uint32
// reqHeaderMu is a 1-element semaphore channel controlling access to sending new requests.
// Write to reqHeaderMu to lock it, read from it to unlock.
// Lock reqHeaderMu BEFORE mu or wmu.
reqHeaderMu chan struct{}
// wmu is held while writing.
// Acquire BEFORE mu when holding both, to avoid blocking mu on network writes.
// Only acquire both at the same time when changing peer settings.
wmu sync.Mutex
bw *bufio.Writer
fr *http2Framer
werr error // first write error that has occurred
hbuf bytes.Buffer // HPACK encoder writes into this
henc *hpack.Encoder
}
// clientStream is the state for a single HTTP/2 stream. One of these
// is created for each Transport.RoundTrip call.
type http2clientStream struct {
cc *http2ClientConn
// Fields of Request that we may access even after the response body is closed.
ctx context.Context
reqCancel <-chan struct{}
trace *httptrace.ClientTrace // or nil
ID uint32
bufPipe http2pipe // buffered pipe with the flow-controlled response payload
requestedGzip bool
isHead bool
abortOnce sync.Once
abort chan struct{} // closed to signal stream should end immediately
abortErr error // set if abort is closed
peerClosed chan struct{} // closed when the peer sends an END_STREAM flag
donec chan struct{} // closed after the stream is in the closed state
on100 chan struct{} // buffered; written to if a 100 is received
respHeaderRecv chan struct{} // closed when headers are received
res *Response // set if respHeaderRecv is closed
flow http2outflow // guarded by cc.mu
inflow http2inflow // guarded by cc.mu
bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read
readErr error // sticky read error; owned by transportResponseBody.Read
reqBody io.ReadCloser
reqBodyContentLength int64 // -1 means unknown
reqBodyClosed chan struct{} // guarded by cc.mu; non-nil on Close, closed when done
// owned by writeRequest:
sentEndStream bool // sent an END_STREAM flag to the peer
sentHeaders bool
// owned by clientConnReadLoop:
firstByte bool // got the first response byte
pastHeaders bool // got first MetaHeadersFrame (actual headers)
pastTrailers bool // got optional second MetaHeadersFrame (trailers)
num1xx uint8 // number of 1xx responses seen
readClosed bool // peer sent an END_STREAM flag
readAborted bool // read loop reset the stream
trailer Header // accumulated trailers
resTrailer *Header // client's Response.Trailer
}
var http2got1xxFuncForTests func(int, textproto.MIMEHeader) error
// get1xxTraceFunc returns the value of request's httptrace.ClientTrace.Got1xxResponse func,
// if any. It returns nil if not set or if the Go version is too old.
func (cs *http2clientStream) get1xxTraceFunc() func(int, textproto.MIMEHeader) error {
if fn := http2got1xxFuncForTests; fn != nil {
return fn
}
return http2traceGot1xxResponseFunc(cs.trace)
}
func (cs *http2clientStream) abortStream(err error) {
cs.cc.mu.Lock()
defer cs.cc.mu.Unlock()
cs.abortStreamLocked(err)
}
func (cs *http2clientStream) abortStreamLocked(err error) {
cs.abortOnce.Do(func() {
cs.abortErr = err
close(cs.abort)
})
if cs.reqBody != nil {
cs.closeReqBodyLocked()
}
// TODO(dneil): Clean up tests where cs.cc.cond is nil.
if cs.cc.cond != nil {
// Wake up writeRequestBody if it is waiting on flow control.
cs.cc.cond.Broadcast()
}
}
func (cs *http2clientStream) abortRequestBodyWrite() {
cc := cs.cc
cc.mu.Lock()
defer cc.mu.Unlock()
if cs.reqBody != nil && cs.reqBodyClosed == nil {
cs.closeReqBodyLocked()
cc.cond.Broadcast()
}
}
func (cs *http2clientStream) closeReqBodyLocked() {
if cs.reqBodyClosed != nil {
return
}
cs.reqBodyClosed = make(chan struct{})
reqBodyClosed := cs.reqBodyClosed
go func() {
cs.reqBody.Close()
close(reqBodyClosed)
}()
}
type http2stickyErrWriter struct {
conn net.Conn
timeout time.Duration
err *error
}
func (sew http2stickyErrWriter) Write(p []byte) (n int, err error) {
if *sew.err != nil {
return 0, *sew.err
}
for {
if sew.timeout != 0 {
sew.conn.SetWriteDeadline(time.Now().Add(sew.timeout))
}
nn, err := sew.conn.Write(p[n:])
n += nn
if n < len(p) && nn > 0 && errors.Is(err, os.ErrDeadlineExceeded) {
// Keep extending the deadline so long as we're making progress.
continue
}
if sew.timeout != 0 {
sew.conn.SetWriteDeadline(time.Time{})
}
*sew.err = err
return n, err
}
}
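// Illustrative sketch (not part of the original source; the function name is
// hypothetical): a plausible construction of a sticky-error writer, assuming
// the caller owns werr and inspects it to detect the first write failure.
func http2exampleStickyErrWriter(conn net.Conn, timeout time.Duration, werr *error) *bufio.Writer {
	return bufio.NewWriter(http2stickyErrWriter{conn: conn, timeout: timeout, err: werr})
}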
// noCachedConnError is the concrete type of ErrNoCachedConn, which
// needs to be detected by net/http regardless of whether it's its
// bundled version (in h2_bundle.go with a rewritten type name) or
// from a user's x/net/http2. As such, it has a unique method name
// (IsHTTP2NoCachedConnError) that net/http sniffs for via func
// isNoCachedConnError.
type http2noCachedConnError struct{}
func (http2noCachedConnError) IsHTTP2NoCachedConnError() {}
func (http2noCachedConnError) Error() string { return "http2: no cached connection was available" }
// isNoCachedConnError reports whether err is of type noCachedConnError
// or its equivalent renamed type in net/http2's h2_bundle.go. Both types
// may coexist in the same running program.
func http2isNoCachedConnError(err error) bool {
_, ok := err.(interface{ IsHTTP2NoCachedConnError() })
return ok
}
var http2ErrNoCachedConn error = http2noCachedConnError{}
// RoundTripOpt are options for the Transport.RoundTripOpt method.
type http2RoundTripOpt struct {
// OnlyCachedConn controls whether RoundTripOpt may
// create a new TCP connection. If set true and
// no cached connection is available, RoundTripOpt
// will return ErrNoCachedConn.
OnlyCachedConn bool
}
func (t *http2Transport) RoundTrip(req *Request) (*Response, error) {
return t.RoundTripOpt(req, http2RoundTripOpt{})
}
// authorityAddr takes a given authority (a host/IP, or host:port / ip:port)
// and returns a host:port. The default port (443, or 80 for plain "http")
// is added if needed.
func http2authorityAddr(scheme string, authority string) (addr string) {
host, port, err := net.SplitHostPort(authority)
if err != nil { // authority didn't have a port
port = "443"
if scheme == "http" {
port = "80"
}
host = authority
}
if a, err := idna.ToASCII(host); err == nil {
host = a
}
// IPv6 address literal, without a port:
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
return host + ":" + port
}
return net.JoinHostPort(host, port)
}
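// Illustrative sketch (not part of the original source; the function name is
// hypothetical): default ports are filled in per scheme, and IPv6 literals
// keep their brackets.
func http2exampleAuthorityAddr() {
	_ = http2authorityAddr("https", "example.com")      // "example.com:443"
	_ = http2authorityAddr("http", "example.com")       // "example.com:80"
	_ = http2authorityAddr("https", "example.com:8443") // "example.com:8443"
	_ = http2authorityAddr("https", "[::1]")            // "[::1]:443"
}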
var http2retryBackoffHook func(time.Duration) *time.Timer
func http2backoffNewTimer(d time.Duration) *time.Timer {
if http2retryBackoffHook != nil {
return http2retryBackoffHook(d)
}
return time.NewTimer(d)
}
// RoundTripOpt is like RoundTrip, but takes options.
func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Response, error) {
if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) {
return nil, errors.New("http2: unsupported scheme")
}
addr := http2authorityAddr(req.URL.Scheme, req.URL.Host)
for retry := 0; ; retry++ {
cc, err := t.connPool().GetClientConn(req, addr)
if err != nil {
t.vlogf("http2: Transport failed to get client conn for %s: %v", addr, err)
return nil, err
}
reused := !atomic.CompareAndSwapUint32(&cc.reused, 0, 1)
http2traceGotConn(req, cc, reused)
res, err := cc.RoundTrip(req)
if err != nil && retry <= 6 {
if req, err = http2shouldRetryRequest(req, err); err == nil {
// After the first retry, do exponential backoff with 10% jitter.
if retry == 0 {
t.vlogf("RoundTrip retrying after failure: %v", err)
continue
}
backoff := float64(uint(1) << (uint(retry) - 1))
backoff += backoff * (0.1 * mathrand.Float64())
d := time.Second * time.Duration(backoff)
timer := http2backoffNewTimer(d)
select {
case <-timer.C:
t.vlogf("RoundTrip retrying after failure: %v", err)
continue
case <-req.Context().Done():
timer.Stop()
err = req.Context().Err()
}
}
}
if err != nil {
t.vlogf("RoundTrip failure: %v", err)
return nil, err
}
return res, nil
}
}
// CloseIdleConnections closes any connections that were established by
// previous requests but are now sitting idle.
// It does not interrupt any connections currently in use.
func (t *http2Transport) CloseIdleConnections() {
if cp, ok := t.connPool().(http2clientConnPoolIdleCloser); ok {
cp.closeIdleConnections()
}
}
var (
http2errClientConnClosed = errors.New("http2: client conn is closed")
http2errClientConnUnusable = errors.New("http2: client conn not usable")
http2errClientConnGotGoAway = errors.New("http2: Transport received Server's graceful shutdown GOAWAY")
)
// shouldRetryRequest is called by RoundTrip when a request fails to get
// response headers. It is always called with a non-nil error.
// It returns either a request to retry (either the same request, or a
// modified clone), or an error if the request can't be replayed.
func http2shouldRetryRequest(req *Request, err error) (*Request, error) {
if !http2canRetryError(err) {
return nil, err
}
// If the Body is nil (or http.NoBody), it's safe to reuse
// this request and its Body.
if req.Body == nil || req.Body == NoBody {
return req, nil
}
// If the request body can be reset back to its original
// state via the optional req.GetBody, do that.
if req.GetBody != nil {
body, err := req.GetBody()
if err != nil {
return nil, err
}
newReq := *req
newReq.Body = body
return &newReq, nil
}
// The Request.Body can't reset back to the beginning, but we
// don't seem to have started to read from it yet, so reuse
// the request directly.
if err == http2errClientConnUnusable {
return req, nil
}
return nil, fmt.Errorf("http2: Transport: cannot retry err [%v] after Request.Body was written; define Request.GetBody to avoid this error", err)
}
func http2canRetryError(err error) bool {
if err == http2errClientConnUnusable || err == http2errClientConnGotGoAway {
return true
}
if se, ok := err.(http2StreamError); ok {
if se.Code == http2ErrCodeProtocol && se.Cause == http2errFromPeer {
// See golang/go#47635, golang/go#42777
return true
}
return se.Code == http2ErrCodeRefusedStream
}
return false
}
func (t *http2Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*http2ClientConn, error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
tconn, err := t.dialTLS(ctx, "tcp", addr, t.newTLSConfig(host))
if err != nil {
return nil, err
}
return t.newClientConn(tconn, singleUse)
}
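// newTLSConfig clones the user's TLS config (if any) and ensures it
// negotiates "h2" via ALPN and carries a ServerName. Illustrative
// effect (comment added for exposition, not in the original source):
//
//	t := &http2Transport{TLSClientConfig: &tls.Config{NextProtos: []string{"foo"}}}
//	cfg := t.newTLSConfig("example.com")
//	// cfg.NextProtos == ["h2", "foo"]; cfg.ServerName == "example.com"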
func (t *http2Transport) newTLSConfig(host string) *tls.Config {
cfg := new(tls.Config)
if t.TLSClientConfig != nil {
*cfg = *t.TLSClientConfig.Clone()
}
if !http2strSliceContains(cfg.NextProtos, http2NextProtoTLS) {
cfg.NextProtos = append([]string{http2NextProtoTLS}, cfg.NextProtos...)
}
if cfg.ServerName == "" {
cfg.ServerName = host
}
return cfg
}
func (t *http2Transport) dialTLS(ctx context.Context, network, addr string, tlsCfg *tls.Config) (net.Conn, error) {
if t.DialTLSContext != nil {
return t.DialTLSContext(ctx, network, addr, tlsCfg)
} else if t.DialTLS != nil {
return t.DialTLS(network, addr, tlsCfg)
}
tlsCn, err := t.dialTLSWithContext(ctx, network, addr, tlsCfg)
if err != nil {
return nil, err
}
state := tlsCn.ConnectionState()
if p := state.NegotiatedProtocol; p != http2NextProtoTLS {
return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, http2NextProtoTLS)
}
if !state.NegotiatedProtocolIsMutual {
return nil, errors.New("http2: could not negotiate protocol mutually")
}
return tlsCn, nil
}
// disableKeepAlives reports whether connections should be closed as
// soon as possible after handling the first request.
func (t *http2Transport) disableKeepAlives() bool {
return t.t1 != nil && t.t1.DisableKeepAlives
}
func (t *http2Transport) expectContinueTimeout() time.Duration {
if t.t1 == nil {
return 0
}
return t.t1.ExpectContinueTimeout
}
func (t *http2Transport) maxDecoderHeaderTableSize() uint32 {
if v := t.MaxDecoderHeaderTableSize; v > 0 {
return v
}
return http2initialHeaderTableSize
}
func (t *http2Transport) maxEncoderHeaderTableSize() uint32 {
if v := t.MaxEncoderHeaderTableSize; v > 0 {
return v
}
return http2initialHeaderTableSize
}
func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) {
return t.newClientConn(c, t.disableKeepAlives())
}
func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2ClientConn, error) {
cc := &http2ClientConn{
t: t,
tconn: c,
readerDone: make(chan struct{}),
nextStreamID: 1,
maxFrameSize: 16 << 10, // spec default
initialWindowSize: 65535, // spec default
maxConcurrentStreams: http2initialMaxConcurrentStreams, // "infinite", per spec. Use a smaller value until we have received server settings.
peerMaxHeaderListSize: 0xffffffffffffffff, // "infinite", per spec. Use 2^64-1 instead.
streams: make(map[uint32]*http2clientStream),
singleUse: singleUse,
wantSettingsAck: true,
pings: make(map[[8]byte]chan struct{}),
reqHeaderMu: make(chan struct{}, 1),
}
if d := t.idleConnTimeout(); d != 0 {
cc.idleTimeout = d
cc.idleTimer = time.AfterFunc(d, cc.onIdleTimeout)
}
if http2VerboseLogs {
t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr())
}
cc.cond = sync.NewCond(&cc.mu)
cc.flow.add(int32(http2initialWindowSize))
// TODO: adjust this writer size to account for frame size +
// MTU + crypto/tls record padding.
cc.bw = bufio.NewWriter(http2stickyErrWriter{
conn: c,
timeout: t.WriteByteTimeout,
err: &cc.werr,
})
cc.br = bufio.NewReader(c)
cc.fr = http2NewFramer(cc.bw, cc.br)
if t.maxFrameReadSize() != 0 {
cc.fr.SetMaxReadFrameSize(t.maxFrameReadSize())
}
if t.CountError != nil {
cc.fr.countError = t.CountError
}
maxHeaderTableSize := t.maxDecoderHeaderTableSize()
cc.fr.ReadMetaHeaders = hpack.NewDecoder(maxHeaderTableSize, nil)
cc.fr.MaxHeaderListSize = t.maxHeaderListSize()
cc.henc = hpack.NewEncoder(&cc.hbuf)
cc.henc.SetMaxDynamicTableSizeLimit(t.maxEncoderHeaderTableSize())
cc.peerMaxHeaderTableSize = http2initialHeaderTableSize
if t.AllowHTTP {
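// Note (exposition, not in the original source): in an h2c upgrade the
// upgraded HTTP/1.1 request is assigned stream ID 1 (RFC 7540,
// Section 3.2), so client-initiated streams start at 3 to avoid
// colliding with it.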
cc.nextStreamID = 3
}
if cs, ok := c.(http2connectionStater); ok {
state := cs.ConnectionState()
cc.tlsState = &state
}
initialSettings := []http2Setting{
{ID: http2SettingEnablePush, Val: 0},
{ID: http2SettingInitialWindowSize, Val: http2transportDefaultStreamFlow},
}
if max := t.maxFrameReadSize(); max != 0 {
initialSettings = append(initialSettings, http2Setting{ID: http2SettingMaxFrameSize, Val: max})
}
if max := t.maxHeaderListSize(); max != 0 {
initialSettings = append(initialSettings, http2Setting{ID: http2SettingMaxHeaderListSize, Val: max})
}
if maxHeaderTableSize != http2initialHeaderTableSize {
initialSettings = append(initialSettings, http2Setting{ID: http2SettingHeaderTableSize, Val: maxHeaderTableSize})
}
cc.bw.Write(http2clientPreface)
cc.fr.WriteSettings(initialSettings...)
cc.fr.WriteWindowUpdate(0, http2transportDefaultConnFlow)
cc.inflow.init(http2transportDefaultConnFlow + http2initialWindowSize)
cc.bw.Flush()
if cc.werr != nil {
cc.Close()
return nil, cc.werr
}
go cc.readLoop()
return cc, nil
}
func (cc *http2ClientConn) healthCheck() {
pingTimeout := cc.t.pingTimeout()
// We don't need to ping periodically here: the readLoop of ClientConn
// re-arms this health check whenever no frame has been received within
// the read idle timeout.
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
defer cancel()
cc.vlogf("http2: Transport sending health check")
err := cc.Ping(ctx)
if err != nil {
cc.vlogf("http2: Transport health check failure: %v", err)
cc.closeForLostPing()
} else {
cc.vlogf("http2: Transport health check success")
}
}
// SetDoNotReuse marks cc as not reusable for future HTTP requests.
func (cc *http2ClientConn) SetDoNotReuse() {
cc.mu.Lock()
defer cc.mu.Unlock()
cc.doNotReuse = true
}
func (cc *http2ClientConn) setGoAway(f *http2GoAwayFrame) {
cc.mu.Lock()
defer cc.mu.Unlock()
old := cc.goAway
cc.goAway = f
// Merge the previous and current GoAway error frames.
if cc.goAwayDebug == "" {
cc.goAwayDebug = string(f.DebugData())
}
if old != nil && old.ErrCode != http2ErrCodeNo {
cc.goAway.ErrCode = old.ErrCode
}
last := f.LastStreamID
for streamID, cs := range cc.streams {
if streamID > last {
cs.abortStreamLocked(http2errClientConnGotGoAway)
}
}
}
// CanTakeNewRequest reports whether the connection can take a new request,
// meaning it has not been closed and has neither received nor sent a GOAWAY.
//
// If the caller is going to immediately make a new request on this
// connection, use ReserveNewRequest instead.
func (cc *http2ClientConn) CanTakeNewRequest() bool {
cc.mu.Lock()
defer cc.mu.Unlock()
return cc.canTakeNewRequestLocked()
}
// ReserveNewRequest is like CanTakeNewRequest but also reserves a
// concurrent stream in cc. The reservation is decremented on the
// next call to RoundTrip.
func (cc *http2ClientConn) ReserveNewRequest() bool {
cc.mu.Lock()
defer cc.mu.Unlock()
if st := cc.idleStateLocked(); !st.canTakeNewRequest {
return false
}
cc.streamsReserved++
return true
}
// ClientConnState describes the state of a ClientConn.
type http2ClientConnState struct {
// Closed is whether the connection is closed.
Closed bool
// Closing is whether the connection is in the process of
// closing. It may be closing due to shutdown, being a
// single-use connection, being marked as DoNotReuse, or
// having received a GOAWAY frame.
Closing bool
// StreamsActive is how many streams are active.
StreamsActive int
// StreamsReserved is how many streams have been reserved via
// ClientConn.ReserveNewRequest.
StreamsReserved int
// StreamsPending is how many requests have been sent in excess
// of the peer's advertised MaxConcurrentStreams setting and
// are waiting for other streams to complete.
StreamsPending int
// MaxConcurrentStreams is how many concurrent streams the
// peer advertised as acceptable. Zero means no SETTINGS
// frame has been received yet.
MaxConcurrentStreams uint32
// LastIdle, if non-zero, is when the connection last
// transitioned to idle state.
LastIdle time.Time
}
// State returns a snapshot of cc's state.
func (cc *http2ClientConn) State() http2ClientConnState {
cc.wmu.Lock()
maxConcurrent := cc.maxConcurrentStreams
if !cc.seenSettings {
maxConcurrent = 0
}
cc.wmu.Unlock()
cc.mu.Lock()
defer cc.mu.Unlock()
return http2ClientConnState{
Closed: cc.closed,
Closing: cc.closing || cc.singleUse || cc.doNotReuse || cc.goAway != nil,
StreamsActive: len(cc.streams),
StreamsReserved: cc.streamsReserved,
StreamsPending: cc.pendingRequests,
LastIdle: cc.lastIdle,
MaxConcurrentStreams: maxConcurrent,
}
}
// clientConnIdleState describes the suitability of a client
// connection to initiate a new RoundTrip request.
type http2clientConnIdleState struct {
canTakeNewRequest bool
}
func (cc *http2ClientConn) idleState() http2clientConnIdleState {
cc.mu.Lock()
defer cc.mu.Unlock()
return cc.idleStateLocked()
}
func (cc *http2ClientConn) idleStateLocked() (st http2clientConnIdleState) {
if cc.singleUse && cc.nextStreamID > 1 {
return
}
var maxConcurrentOkay bool
if cc.t.StrictMaxConcurrentStreams {
// We'll tell the caller we can take a new request to
// prevent the caller from dialing a new TCP
// connection, but then we'll block later before
// writing it.
maxConcurrentOkay = true
} else {
maxConcurrentOkay = int64(len(cc.streams)+cc.streamsReserved+1) <= int64(cc.maxConcurrentStreams)
}
st.canTakeNewRequest = cc.goAway == nil && !cc.closed && !cc.closing && maxConcurrentOkay &&
!cc.doNotReuse &&
int64(cc.nextStreamID)+2*int64(cc.pendingRequests) < math.MaxInt32 &&
!cc.tooIdleLocked()
return
}
func (cc *http2ClientConn) canTakeNewRequestLocked() bool {
st := cc.idleStateLocked()
return st.canTakeNewRequest
}
// tooIdleLocked reports whether this connection has been sitting idle
// for too much wall time.
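//
// For exposition (not in the original comment), Round(0) strips the
// monotonic clock reading so the comparison below uses wall time:
//
//	t0 := time.Now()  // carries a monotonic reading
//	t1 := t0.Round(0) // wall clock only
//	time.Since(t1)    // includes time spent suspended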
func (cc *http2ClientConn) tooIdleLocked() bool {
// The Round(0) strips the monotonic clock reading so the
// times are compared based on their wall time. We don't want
// to reuse a connection that's been sitting idle during
// VM/laptop suspend if monotonic time was also frozen.
return cc.idleTimeout != 0 && !cc.lastIdle.IsZero() && time.Since(cc.lastIdle.Round(0)) > cc.idleTimeout
}
// onIdleTimeout is called from a time.AfterFunc goroutine. It will
// only be called when we're idle, but because we're coming from a new
// goroutine, there could be a new request coming in at the same time,
// so this simply calls the synchronized closeIfIdle to shut down this
// connection. The timer could just call closeIfIdle, but this is more
// clear.
func (cc *http2ClientConn) onIdleTimeout() {
cc.closeIfIdle()
}
func (cc *http2ClientConn) closeConn() {
t := time.AfterFunc(250*time.Millisecond, cc.forceCloseConn)
defer t.Stop()
cc.tconn.Close()
}
// A tls.Conn.Close can hang for a long time if the peer is unresponsive.
// Try to shut it down more aggressively.
func (cc *http2ClientConn) forceCloseConn() {
tc, ok := cc.tconn.(*tls.Conn)
if !ok {
return
}
if nc := http2tlsUnderlyingConn(tc); nc != nil {
nc.Close()
}
}
func (cc *http2ClientConn) closeIfIdle() {
cc.mu.Lock()
if len(cc.streams) > 0 || cc.streamsReserved > 0 {
cc.mu.Unlock()
return
}
cc.closed = true
nextID := cc.nextStreamID
// TODO: do clients send GOAWAY too? maybe? Just Close:
cc.mu.Unlock()
if http2VerboseLogs {
cc.vlogf("http2: Transport closing idle conn %p (forSingleUse=%v, maxStream=%v)", cc, cc.singleUse, nextID-2)
}
cc.closeConn()
}
func (cc *http2ClientConn) isDoNotReuseAndIdle() bool {
cc.mu.Lock()
defer cc.mu.Unlock()
return cc.doNotReuse && len(cc.streams) == 0
}
var http2shutdownEnterWaitStateHook = func() {}
// Shutdown gracefully closes the client connection, waiting for running streams to complete.
func (cc *http2ClientConn) Shutdown(ctx context.Context) error {
if err := cc.sendGoAway(); err != nil {
return err
}
// Wait for all in-flight streams to complete or connection to close
done := make(chan struct{})
cancelled := false // guarded by cc.mu
go func() {
cc.mu.Lock()
defer cc.mu.Unlock()
for {
if len(cc.streams) == 0 || cc.closed {
cc.closed = true
close(done)
break
}
if cancelled {
break
}
cc.cond.Wait()
}
}()
http2shutdownEnterWaitStateHook()
select {
case <-done:
cc.closeConn()
return nil
case <-ctx.Done():
cc.mu.Lock()
// Free the goroutine above
cancelled = true
cc.cond.Broadcast()
cc.mu.Unlock()
return ctx.Err()
}
}
func (cc *http2ClientConn) sendGoAway() error {
cc.mu.Lock()
closing := cc.closing
cc.closing = true
maxStreamID := cc.nextStreamID
cc.mu.Unlock()
if closing {
// GOAWAY sent already
return nil
}
cc.wmu.Lock()
defer cc.wmu.Unlock()
// Send a graceful shutdown frame to server
if err := cc.fr.WriteGoAway(maxStreamID, http2ErrCodeNo, nil); err != nil {
return err
}
if err := cc.bw.Flush(); err != nil {
return err
}
// Prevent new requests
return nil
}
// closeForError closes the client connection immediately.
// In-flight requests are interrupted. err is sent to streams.
func (cc *http2ClientConn) closeForError(err error) {
cc.mu.Lock()
cc.closed = true
for _, cs := range cc.streams {
cs.abortStreamLocked(err)
}
cc.cond.Broadcast()
cc.mu.Unlock()
cc.closeConn()
}
// Close closes the client connection immediately.
//
// In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead.
func (cc *http2ClientConn) Close() error {
err := errors.New("http2: client connection force closed via ClientConn.Close")
cc.closeForError(err)
return nil
}
// closeForLostPing closes the client connection immediately.
// In-flight requests are interrupted.
func (cc *http2ClientConn) closeForLostPing() {
err := errors.New("http2: client connection lost")
if f := cc.t.CountError; f != nil {
f("conn_close_lost_ping")
}
cc.closeForError(err)
}
// errRequestCanceled is a copy of net/http's errRequestCanceled because it's not
// exported. At least they'll be DeepEqual for h1-vs-h2 comparison tests.
var http2errRequestCanceled = errors.New("net/http: request canceled")
func http2commaSeparatedTrailers(req *Request) (string, error) {
keys := make([]string, 0, len(req.Trailer))
for k := range req.Trailer {
k = http2canonicalHeader(k)
switch k {
case "Transfer-Encoding", "Trailer", "Content-Length":
return "", fmt.Errorf("invalid Trailer key %q", k)
}
keys = append(keys, k)
}
if len(keys) > 0 {
sort.Strings(keys)
return strings.Join(keys, ","), nil
}
return "", nil
}
func (cc *http2ClientConn) responseHeaderTimeout() time.Duration {
if cc.t.t1 != nil {
return cc.t.t1.ResponseHeaderTimeout
}
// No way to do this (yet?) with just an http2.Transport. Probably
// no need; Request.Cancel is the new way. We only need to support
// this for compatibility with the old http.Transport fields when
// we're doing transparent http2.
return 0
}
// checkConnHeaders checks whether req has any invalid connection-level
// headers, per RFC 7540 Section 8.1.2.2 (Connection-Specific Header Fields).
// Certain headers are special-cased as okay but not transmitted later.
func http2checkConnHeaders(req *Request) error {
if v := req.Header.Get("Upgrade"); v != "" {
return fmt.Errorf("http2: invalid Upgrade request header: %q", req.Header["Upgrade"])
}
if vv := req.Header["Transfer-Encoding"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "chunked") {
return fmt.Errorf("http2: invalid Transfer-Encoding request header: %q", vv)
}
if vv := req.Header["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && !http2asciiEqualFold(vv[0], "close") && !http2asciiEqualFold(vv[0], "keep-alive")) {
return fmt.Errorf("http2: invalid Connection request header: %q", vv)
}
return nil
}
// actualContentLength returns a sanitized version of
// req.ContentLength, where 0 actually means zero (not unknown) and -1
// means unknown.
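//
// Illustrative mapping (exposition, not in the original comment):
//
//	req.Body == nil or NoBody               -> 0 (definitely empty)
//	req.Body != nil, req.ContentLength == 5 -> 5
//	req.Body != nil, req.ContentLength == 0 -> -1 (unknown)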
func http2actualContentLength(req *Request) int64 {
if req.Body == nil || req.Body == NoBody {
return 0
}
if req.ContentLength != 0 {
return req.ContentLength
}
return -1
}
func (cc *http2ClientConn) decrStreamReservations() {
cc.mu.Lock()
defer cc.mu.Unlock()
cc.decrStreamReservationsLocked()
}
func (cc *http2ClientConn) decrStreamReservationsLocked() {
if cc.streamsReserved > 0 {
cc.streamsReserved--
}
}
func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) {
ctx := req.Context()
cs := &http2clientStream{
cc: cc,
ctx: ctx,
reqCancel: req.Cancel,
isHead: req.Method == "HEAD",
reqBody: req.Body,
reqBodyContentLength: http2actualContentLength(req),
trace: httptrace.ContextClientTrace(ctx),
peerClosed: make(chan struct{}),
abort: make(chan struct{}),
respHeaderRecv: make(chan struct{}),
donec: make(chan struct{}),
}
go cs.doRequest(req)
waitDone := func() error {
select {
case <-cs.donec:
return nil
case <-ctx.Done():
return ctx.Err()
case <-cs.reqCancel:
return http2errRequestCanceled
}
}
handleResponseHeaders := func() (*Response, error) {
res := cs.res
if res.StatusCode > 299 {
// On error or status code 3xx, 4xx, 5xx, etc abort any
// ongoing write, assuming that the server doesn't care
// about our request body. If the server replied with 1xx or
// 2xx, however, then assume the server DOES potentially
// want our body (e.g. full-duplex streaming:
// golang.org/issue/13444). If it turns out the server
// doesn't, they'll RST_STREAM us soon enough. This is a
// heuristic to avoid adding knobs to Transport. Hopefully
// we can keep it.
cs.abortRequestBodyWrite()
}
res.Request = req
res.TLS = cc.tlsState
if res.Body == http2noBody && http2actualContentLength(req) == 0 {
// If there isn't a request or response body still being
// written, then wait for the stream to be closed before
// RoundTrip returns.
if err := waitDone(); err != nil {
return nil, err
}
}
return res, nil
}
for {
select {
case <-cs.respHeaderRecv:
return handleResponseHeaders()
case <-cs.abort:
select {
case <-cs.respHeaderRecv:
// If both cs.respHeaderRecv and cs.abort are signaling,
// pick respHeaderRecv. The server probably wrote the
// response and immediately reset the stream.
// golang.org/issue/49645
return handleResponseHeaders()
default:
waitDone()
return nil, cs.abortErr
}
case <-ctx.Done():
err := ctx.Err()
cs.abortStream(err)
return nil, err
case <-cs.reqCancel:
cs.abortStream(http2errRequestCanceled)
return nil, http2errRequestCanceled
}
}
}
// doRequest runs for the duration of the request lifetime.
//
// It sends the request and performs post-request cleanup (closing Request.Body, etc.).
func (cs *http2clientStream) doRequest(req *Request) {
err := cs.writeRequest(req)
cs.cleanupWriteRequest(err)
}
// writeRequest sends a request.
//
// It returns nil after the request is written, the response read,
// and the request stream is half-closed by the peer.
//
// It returns a non-nil error if the request ends otherwise.
// If the returned error is a StreamError, its Code may be used to reset the stream.
func (cs *http2clientStream) writeRequest(req *Request) (err error) {
cc := cs.cc
ctx := cs.ctx
if err := http2checkConnHeaders(req); err != nil {
return err
}
// Acquire the new-request lock by writing to reqHeaderMu.
// This lock guards the critical section covering allocating a new stream ID
// (requires mu) and creating the stream (requires wmu).
if cc.reqHeaderMu == nil {
panic("RoundTrip on uninitialized ClientConn") // for tests
}
select {
case cc.reqHeaderMu <- struct{}{}:
case <-cs.reqCancel:
return http2errRequestCanceled
case <-ctx.Done():
return ctx.Err()
}
cc.mu.Lock()
if cc.idleTimer != nil {
cc.idleTimer.Stop()
}
cc.decrStreamReservationsLocked()
if err := cc.awaitOpenSlotForStreamLocked(cs); err != nil {
cc.mu.Unlock()
<-cc.reqHeaderMu
return err
}
cc.addStreamLocked(cs) // assigns stream ID
if http2isConnectionCloseRequest(req) {
cc.doNotReuse = true
}
cc.mu.Unlock()
// TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere?
if !cc.t.disableCompression() &&
req.Header.Get("Accept-Encoding") == "" &&
req.Header.Get("Range") == "" &&
!cs.isHead {
// Request gzip only, not deflate. Deflate is ambiguous and
// not as universally supported anyway.
// See: https://zlib.net/zlib_faq.html#faq39
//
// Note that we don't request this for HEAD requests,
// due to a bug in nginx:
// http://trac.nginx.org/nginx/ticket/358
// https://golang.org/issue/5522
//
// We don't request gzip if the request is for a range, since
// auto-decoding a portion of a gzipped document will just fail
// anyway. See https://golang.org/issue/8923
cs.requestedGzip = true
}
continueTimeout := cc.t.expectContinueTimeout()
if continueTimeout != 0 {
if !httpguts.HeaderValuesContainsToken(req.Header["Expect"], "100-continue") {
continueTimeout = 0
} else {
cs.on100 = make(chan struct{}, 1)
}
}
// Past this point (where we send request headers), it is possible for
// RoundTrip to return successfully. Since the RoundTrip contract permits
// the caller to "mutate or reuse" the Request after closing the Response's Body,
// we must take care when referencing the Request from here on.
err = cs.encodeAndWriteHeaders(req)
<-cc.reqHeaderMu
if err != nil {
return err
}
hasBody := cs.reqBodyContentLength != 0
if !hasBody {
cs.sentEndStream = true
} else {
if continueTimeout != 0 {
http2traceWait100Continue(cs.trace)
timer := time.NewTimer(continueTimeout)
select {
case <-timer.C:
err = nil
case <-cs.on100:
err = nil
case <-cs.abort:
err = cs.abortErr
case <-ctx.Done():
err = ctx.Err()
case <-cs.reqCancel:
err = http2errRequestCanceled
}
timer.Stop()
if err != nil {
http2traceWroteRequest(cs.trace, err)
return err
}
}
if err = cs.writeRequestBody(req); err != nil {
if err != http2errStopReqBodyWrite {
http2traceWroteRequest(cs.trace, err)
return err
}
} else {
cs.sentEndStream = true
}
}
http2traceWroteRequest(cs.trace, err)
var respHeaderTimer <-chan time.Time
var respHeaderRecv chan struct{}
if d := cc.responseHeaderTimeout(); d != 0 {
timer := time.NewTimer(d)
defer timer.Stop()
respHeaderTimer = timer.C
respHeaderRecv = cs.respHeaderRecv
}
// Wait until the peer half-closes its end of the stream,
// or until the request is aborted (via context, error, or otherwise),
// whichever comes first.
for {
select {
case <-cs.peerClosed:
return nil
case <-respHeaderTimer:
return http2errTimeout
case <-respHeaderRecv:
respHeaderRecv = nil
respHeaderTimer = nil // keep waiting for END_STREAM
case <-cs.abort:
return cs.abortErr
case <-ctx.Done():
return ctx.Err()
case <-cs.reqCancel:
return http2errRequestCanceled
}
}
}
func (cs *http2clientStream) encodeAndWriteHeaders(req *Request) error {
cc := cs.cc
ctx := cs.ctx
cc.wmu.Lock()
defer cc.wmu.Unlock()
// If the request was canceled while waiting for cc.mu, just quit.
select {
case <-cs.abort:
return cs.abortErr
case <-ctx.Done():
return ctx.Err()
case <-cs.reqCancel:
return http2errRequestCanceled
default:
}
// Encode headers.
//
// we send: HEADERS{1}, CONTINUATION{0,} + DATA{0,} (DATA is
// sent by writeRequestBody below, along with any Trailers,
// again in form HEADERS{1}, CONTINUATION{0,})
trailers, err := http2commaSeparatedTrailers(req)
if err != nil {
return err
}
hasTrailers := trailers != ""
contentLen := http2actualContentLength(req)
hasBody := contentLen != 0
hdrs, err := cc.encodeHeaders(req, cs.requestedGzip, trailers, contentLen)
if err != nil {
return err
}
// Write the request.
endStream := !hasBody && !hasTrailers
cs.sentHeaders = true
err = cc.writeHeaders(cs.ID, endStream, int(cc.maxFrameSize), hdrs)
http2traceWroteHeaders(cs.trace)
return err
}
// cleanupWriteRequest performs post-request tasks.
//
// If err (the result of writeRequest) is non-nil and the stream is not closed,
// cleanupWriteRequest will send a reset to the peer.
func (cs *http2clientStream) cleanupWriteRequest(err error) {
cc := cs.cc
if cs.ID == 0 {
// We were canceled before creating the stream, so return our reservation.
cc.decrStreamReservations()
}
// TODO: write h12Compare test showing whether
// Request.Body is closed by the Transport,
// and in multiple cases: server replies <=299 and >299
// while still writing request body
cc.mu.Lock()
mustCloseBody := false
if cs.reqBody != nil && cs.reqBodyClosed == nil {
mustCloseBody = true
cs.reqBodyClosed = make(chan struct{})
}
bodyClosed := cs.reqBodyClosed
cc.mu.Unlock()
if mustCloseBody {
cs.reqBody.Close()
close(bodyClosed)
}
if bodyClosed != nil {
<-bodyClosed
}
if err != nil && cs.sentEndStream {
// If the connection is closed immediately after the response is read,
// we may be aborted before finishing up here. If the stream was closed
// cleanly on both sides, there is no error.
select {
case <-cs.peerClosed:
err = nil
default:
}
}
if err != nil {
cs.abortStream(err) // possibly redundant, but harmless
if cs.sentHeaders {
if se, ok := err.(http2StreamError); ok {
if se.Cause != http2errFromPeer {
cc.writeStreamReset(cs.ID, se.Code, err)
}
} else {
cc.writeStreamReset(cs.ID, http2ErrCodeCancel, err)
}
}
cs.bufPipe.CloseWithError(err) // no-op if already closed
} else {
if cs.sentHeaders && !cs.sentEndStream {
cc.writeStreamReset(cs.ID, http2ErrCodeNo, nil)
}
cs.bufPipe.CloseWithError(http2errRequestCanceled)
}
if cs.ID != 0 {
cc.forgetStreamID(cs.ID)
}
cc.wmu.Lock()
werr := cc.werr
cc.wmu.Unlock()
if werr != nil {
cc.Close()
}
close(cs.donec)
}
// awaitOpenSlotForStreamLocked waits until len(streams) < maxConcurrentStreams.
// Must hold cc.mu.
func (cc *http2ClientConn) awaitOpenSlotForStreamLocked(cs *http2clientStream) error {
for {
cc.lastActive = time.Now()
if cc.closed || !cc.canTakeNewRequestLocked() {
return http2errClientConnUnusable
}
cc.lastIdle = time.Time{}
if int64(len(cc.streams)) < int64(cc.maxConcurrentStreams) {
return nil
}
cc.pendingRequests++
cc.cond.Wait()
cc.pendingRequests--
select {
case <-cs.abort:
return cs.abortErr
default:
}
}
}
// requires cc.wmu be held
func (cc *http2ClientConn) writeHeaders(streamID uint32, endStream bool, maxFrameSize int, hdrs []byte) error {
first := true // first frame written (HEADERS is first, then CONTINUATION)
for len(hdrs) > 0 && cc.werr == nil {
chunk := hdrs
if len(chunk) > maxFrameSize {
chunk = chunk[:maxFrameSize]
}
hdrs = hdrs[len(chunk):]
endHeaders := len(hdrs) == 0
if first {
cc.fr.WriteHeaders(http2HeadersFrameParam{
StreamID: streamID,
BlockFragment: chunk,
EndStream: endStream,
EndHeaders: endHeaders,
})
first = false
} else {
cc.fr.WriteContinuation(streamID, endHeaders, chunk)
}
}
cc.bw.Flush()
return cc.werr
}
// internal error values; they don't escape to callers
var (
// abort request body write; don't send cancel
http2errStopReqBodyWrite = errors.New("http2: aborting request body write")
// abort request body write, but send a stream reset with error code CANCEL.
http2errStopReqBodyWriteAndCancel = errors.New("http2: canceling request")
http2errReqBodyTooLong = errors.New("http2: request body larger than specified content length")
)
// frameScratchBufferLen returns the length of the scratch buffer used
// for reading from and writing to outgoing request bodies.
//
// It returns max(1, min(peer's advertised max frame size,
// Request.ContentLength+1, 512KB)).
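//
// Worked examples (exposition, not in the original comment):
//
//	maxFrameSize=16384, ContentLength=10 -> 11 (declared length + 1)
//	maxFrameSize=16384, ContentLength=-1 -> 16384
//	maxFrameSize=1<<20, ContentLength=-1 -> 512<<10 (the 512KB cap)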
func (cs *http2clientStream) frameScratchBufferLen(maxFrameSize int) int {
const max = 512 << 10
n := int64(maxFrameSize)
if n > max {
n = max
}
if cl := cs.reqBodyContentLength; cl != -1 && cl+1 < n {
// Add an extra byte past the declared content-length to
// give the caller's Request.Body io.Reader a chance to
// give us more bytes than they declared, so we can catch it
// early.
n = cl + 1
}
if n < 1 {
return 1
}
return int(n) // doesn't truncate; max is 512K
}
var http2bufPool sync.Pool // of *[]byte
func (cs *http2clientStream) writeRequestBody(req *Request) (err error) {
cc := cs.cc
body := cs.reqBody
sentEnd := false // whether we sent the final DATA frame w/ END_STREAM
hasTrailers := req.Trailer != nil
remainLen := cs.reqBodyContentLength
hasContentLen := remainLen != -1
cc.mu.Lock()
maxFrameSize := int(cc.maxFrameSize)
cc.mu.Unlock()
// Scratch buffer for reading into & writing from.
scratchLen := cs.frameScratchBufferLen(maxFrameSize)
var buf []byte
if bp, ok := http2bufPool.Get().(*[]byte); ok && len(*bp) >= scratchLen {
defer http2bufPool.Put(bp)
buf = *bp
} else {
buf = make([]byte, scratchLen)
defer http2bufPool.Put(&buf)
}
var sawEOF bool
for !sawEOF {
n, err := body.Read(buf)
if hasContentLen {
remainLen -= int64(n)
if remainLen == 0 && err == nil {
// The request body's Content-Length was predeclared and
// we just finished reading it all, but the underlying io.Reader
// returned the final chunk with a nil error (which is one of
// the two valid things a Reader can do at EOF). Because we'd prefer
// to send the END_STREAM bit early, double-check that we're actually
// at EOF. Subsequent reads should return (0, EOF) at this point.
// If either value is different, we return an error in one of two ways below.
var scratch [1]byte
var n1 int
n1, err = body.Read(scratch[:])
remainLen -= int64(n1)
}
if remainLen < 0 {
err = http2errReqBodyTooLong
return err
}
}
if err != nil {
cc.mu.Lock()
bodyClosed := cs.reqBodyClosed != nil
cc.mu.Unlock()
switch {
case bodyClosed:
return http2errStopReqBodyWrite
case err == io.EOF:
sawEOF = true
err = nil
default:
return err
}
}
remain := buf[:n]
for len(remain) > 0 && err == nil {
var allowed int32
allowed, err = cs.awaitFlowControl(len(remain))
if err != nil {
return err
}
cc.wmu.Lock()
data := remain[:allowed]
remain = remain[allowed:]
sentEnd = sawEOF && len(remain) == 0 && !hasTrailers
err = cc.fr.WriteData(cs.ID, sentEnd, data)
if err == nil {
// TODO(bradfitz): this flush is for latency, not bandwidth.
// Most requests won't need this. Make this opt-in or
// opt-out? Use some heuristic on the body type? Nagle-like
// timers? Based on 'n'? Only last chunk of this for loop,
// unless flow control tokens are low? For now, always.
// If we change this, see comment below.
err = cc.bw.Flush()
}
cc.wmu.Unlock()
}
if err != nil {
return err
}
}
if sentEnd {
// Already sent END_STREAM (which implies we have no
// trailers) and flushed, because currently all
// WriteData frames above get a flush. So we're done.
return nil
}
// Since the RoundTrip contract permits the caller to "mutate or reuse"
// a request after the Response's Body is closed, verify that this hasn't
// happened before accessing the trailers.
cc.mu.Lock()
trailer := req.Trailer
err = cs.abortErr
cc.mu.Unlock()
if err != nil {
return err
}
cc.wmu.Lock()
defer cc.wmu.Unlock()
var trls []byte
if len(trailer) > 0 {
trls, err = cc.encodeTrailers(trailer)
if err != nil {
return err
}
}
// Two ways to send END_STREAM: either with trailers, or
// with an empty DATA frame.
if len(trls) > 0 {
err = cc.writeHeaders(cs.ID, true, maxFrameSize, trls)
} else {
err = cc.fr.WriteData(cs.ID, true, nil)
}
if ferr := cc.bw.Flush(); ferr != nil && err == nil {
err = ferr
}
return err
}
// awaitFlowControl waits for [1, min(maxBytes, cc.maxFrameSize)] flow
// control tokens from the server.
// It returns either the non-zero number of tokens taken or an error
// if the stream is dead.
func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) {
cc := cs.cc
ctx := cs.ctx
cc.mu.Lock()
defer cc.mu.Unlock()
for {
if cc.closed {
return 0, http2errClientConnClosed
}
if cs.reqBodyClosed != nil {
return 0, http2errStopReqBodyWrite
}
select {
case <-cs.abort:
return 0, cs.abortErr
case <-ctx.Done():
return 0, ctx.Err()
case <-cs.reqCancel:
return 0, http2errRequestCanceled
default:
}
if a := cs.flow.available(); a > 0 {
take := a
if int(take) > maxBytes {
take = int32(maxBytes) // can't truncate int; take is int32
}
if take > int32(cc.maxFrameSize) {
take = int32(cc.maxFrameSize)
}
cs.flow.take(take)
return take, nil
}
cc.cond.Wait()
}
}
var http2errNilRequestURL = errors.New("http2: Request.URL is nil")
// requires cc.wmu be held.
func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) {
cc.hbuf.Reset()
if req.URL == nil {
return nil, http2errNilRequestURL
}
host := req.Host
if host == "" {
host = req.URL.Host
}
host, err := httpguts.PunycodeHostPort(host)
if err != nil {
return nil, err
}
var path string
if req.Method != "CONNECT" {
path = req.URL.RequestURI()
if !http2validPseudoPath(path) {
orig := path
path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host)
if !http2validPseudoPath(path) {
if req.URL.Opaque != "" {
return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
} else {
return nil, fmt.Errorf("invalid request :path %q", orig)
}
}
}
}
// Check for any invalid headers and return an error before we
// potentially pollute our hpack state. (We want to be able to
// continue to reuse the hpack encoder for future requests)
for k, vv := range req.Header {
if !httpguts.ValidHeaderFieldName(k) {
return nil, fmt.Errorf("invalid HTTP header name %q", k)
}
for _, v := range vv {
if !httpguts.ValidHeaderFieldValue(v) {
// Don't include the value in the error, because it may be sensitive.
return nil, fmt.Errorf("invalid HTTP header value for header %q", k)
}
}
}
enumerateHeaders := func(f func(name, value string)) {
// 8.1.2.3 Request Pseudo-Header Fields
// The :path pseudo-header field includes the path and query parts of the
// target URI (the path-absolute production and optionally a '?' character
// followed by the query production (see Sections 3.3 and 3.4 of
// [RFC3986]).
f(":authority", host)
m := req.Method
if m == "" {
m = MethodGet
}
f(":method", m)
if req.Method != "CONNECT" {
f(":path", path)
f(":scheme", req.URL.Scheme)
}
if trailers != "" {
f("trailer", trailers)
}
var didUA bool
for k, vv := range req.Header {
if http2asciiEqualFold(k, "host") || http2asciiEqualFold(k, "content-length") {
// Host is :authority, already sent.
// Content-Length is automatic, set below.
continue
} else if http2asciiEqualFold(k, "connection") ||
http2asciiEqualFold(k, "proxy-connection") ||
http2asciiEqualFold(k, "transfer-encoding") ||
http2asciiEqualFold(k, "upgrade") ||
http2asciiEqualFold(k, "keep-alive") {
// Per 8.1.2.2 Connection-Specific Header
// Fields, don't send connection-specific
// fields. We have already checked if any
// are error-worthy so just ignore the rest.
continue
} else if http2asciiEqualFold(k, "user-agent") {
// Match Go's http1 behavior: at most one
// User-Agent. If set to nil or empty string,
// then omit it. Otherwise if not mentioned,
// include the default (below).
didUA = true
if len(vv) < 1 {
continue
}
vv = vv[:1]
if vv[0] == "" {
continue
}
} else if http2asciiEqualFold(k, "cookie") {
// Per 8.1.2.5 To allow for better compression efficiency, the
// Cookie header field MAY be split into separate header fields,
// each with one or more cookie-pairs.
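// For example (exposition, not in the original source):
// "a=1; b=2" is emitted as f("cookie", "a=1") then f("cookie", "b=2").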
for _, v := range vv {
for {
p := strings.IndexByte(v, ';')
if p < 0 {
break
}
f("cookie", v[:p])
p++
// strip space after semicolon if any.
for p+1 <= len(v) && v[p] == ' ' {
p++
}
v = v[p:]
}
if len(v) > 0 {
f("cookie", v)
}
}
continue
}
for _, v := range vv {
f(k, v)
}
}
if http2shouldSendReqContentLength(req.Method, contentLength) {
f("content-length", strconv.FormatInt(contentLength, 10))
}
if addGzipHeader {
f("accept-encoding", "gzip")
}
if !didUA {
f("user-agent", http2defaultUserAgent)
}
}
// Do a first pass over the headers counting bytes to ensure
// we don't exceed cc.peerMaxHeaderListSize. This is done as a
// separate pass before encoding the headers to prevent
// modifying the hpack state.
hlSize := uint64(0)
enumerateHeaders(func(name, value string) {
hf := hpack.HeaderField{Name: name, Value: value}
hlSize += uint64(hf.Size())
})
if hlSize > cc.peerMaxHeaderListSize {
return nil, http2errRequestHeaderListSize
}
trace := httptrace.ContextClientTrace(req.Context())
traceHeaders := http2traceHasWroteHeaderField(trace)
// Header list size is ok. Write the headers.
enumerateHeaders(func(name, value string) {
name, ascii := http2lowerHeader(name)
if !ascii {
// Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header
// field names have to be ASCII characters (just as in HTTP/1.x).
return
}
cc.writeHeader(name, value)
if traceHeaders {
http2traceWroteHeaderField(trace, name, value)
}
})
return cc.hbuf.Bytes(), nil
}
// shouldSendReqContentLength reports whether the http2.Transport should send
// a "content-length" request header. This logic is basically a copy of the net/http
// transferWriter.shouldSendContentLength.
// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
// -1 means unknown.
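//
// Illustrative outcomes (exposition, not in the original comment):
//
//	contentLength > 0, any method -> true
//	contentLength < 0, any method -> false
//	contentLength == 0, "POST"    -> true (send "content-length: 0")
//	contentLength == 0, "GET"     -> false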
func http2shouldSendReqContentLength(method string, contentLength int64) bool {
if contentLength > 0 {
return true
}
if contentLength < 0 {
return false
}
// For zero bodies, whether we send a content-length depends on the method.
// It also kinda doesn't matter for http2 either way, with END_STREAM.
switch method {
case "POST", "PUT", "PATCH":
return true
default:
return false
}
}
// requires cc.wmu be held.
func (cc *http2ClientConn) encodeTrailers(trailer Header) ([]byte, error) {
cc.hbuf.Reset()
hlSize := uint64(0)
for k, vv := range trailer {
for _, v := range vv {
hf := hpack.HeaderField{Name: k, Value: v}
hlSize += uint64(hf.Size())
}
}
if hlSize > cc.peerMaxHeaderListSize {
return nil, http2errRequestHeaderListSize
}
for k, vv := range trailer {
lowKey, ascii := http2lowerHeader(k)
if !ascii {
// Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header
// field names have to be ASCII characters (just as in HTTP/1.x).
continue
}
// Transfer-Encoding, etc.. have already been filtered at the
// start of RoundTrip
for _, v := range vv {
cc.writeHeader(lowKey, v)
}
}
return cc.hbuf.Bytes(), nil
}
func (cc *http2ClientConn) writeHeader(name, value string) {
if http2VerboseLogs {
log.Printf("http2: Transport encoding header %q = %q", name, value)
}
cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
}
type http2resAndError struct {
_ http2incomparable
res *Response
err error
}
// requires cc.mu be held.
func (cc *http2ClientConn) addStreamLocked(cs *http2clientStream) {
cs.flow.add(int32(cc.initialWindowSize))
cs.flow.setConnFlow(&cc.flow)
cs.inflow.init(http2transportDefaultStreamFlow)
cs.ID = cc.nextStreamID
cc.nextStreamID += 2
cc.streams[cs.ID] = cs
if cs.ID == 0 {
panic("assigned stream ID 0")
}
}
func (cc *http2ClientConn) forgetStreamID(id uint32) {
cc.mu.Lock()
slen := len(cc.streams)
delete(cc.streams, id)
if len(cc.streams) != slen-1 {
panic("forgetting unknown stream id")
}
cc.lastActive = time.Now()
if len(cc.streams) == 0 && cc.idleTimer != nil {
cc.idleTimer.Reset(cc.idleTimeout)
cc.lastIdle = time.Now()
}
// Wake up writeRequestBody via clientStream.awaitFlowControl and
// wake up RoundTrip if there is a pending request.
cc.cond.Broadcast()
closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() || cc.goAway != nil
if closeOnIdle && cc.streamsReserved == 0 && len(cc.streams) == 0 {
if http2VerboseLogs {
cc.vlogf("http2: Transport closing idle conn %p (forSingleUse=%v, maxStream=%v)", cc, cc.singleUse, cc.nextStreamID-2)
}
cc.closed = true
defer cc.closeConn()
}
cc.mu.Unlock()
}
// clientConnReadLoop is the state owned by the clientConn's frame-reading readLoop.
type http2clientConnReadLoop struct {
_ http2incomparable
cc *http2ClientConn
}
// readLoop runs in its own goroutine and reads and dispatches frames.
func (cc *http2ClientConn) readLoop() {
rl := &http2clientConnReadLoop{cc: cc}
defer rl.cleanup()
cc.readerErr = rl.run()
if ce, ok := cc.readerErr.(http2ConnectionError); ok {
cc.wmu.Lock()
cc.fr.WriteGoAway(0, http2ErrCode(ce), nil)
cc.wmu.Unlock()
}
}
// GoAwayError is returned by the Transport when the server closes the
// TCP connection after sending a GOAWAY frame.
type http2GoAwayError struct {
LastStreamID uint32
ErrCode http2ErrCode
DebugData string
}
func (e http2GoAwayError) Error() string {
return fmt.Sprintf("http2: server sent GOAWAY and closed the connection; LastStreamID=%v, ErrCode=%v, debug=%q",
e.LastStreamID, e.ErrCode, e.DebugData)
}
func http2isEOFOrNetReadError(err error) bool {
if err == io.EOF {
return true
}
ne, ok := err.(*net.OpError)
return ok && ne.Op == "read"
}
func (rl *http2clientConnReadLoop) cleanup() {
cc := rl.cc
cc.t.connPool().MarkDead(cc)
defer cc.closeConn()
defer close(cc.readerDone)
if cc.idleTimer != nil {
cc.idleTimer.Stop()
}
// Close any response bodies if the server closes prematurely.
// TODO: also do this if we've written the headers but not
// gotten a response yet.
err := cc.readerErr
cc.mu.Lock()
if cc.goAway != nil && http2isEOFOrNetReadError(err) {
err = http2GoAwayError{
LastStreamID: cc.goAway.LastStreamID,
ErrCode: cc.goAway.ErrCode,
DebugData: cc.goAwayDebug,
}
} else if err == io.EOF {
err = io.ErrUnexpectedEOF
}
cc.closed = true
for _, cs := range cc.streams {
select {
case <-cs.peerClosed:
// The server closed the stream before closing the conn,
// so no need to interrupt it.
default:
cs.abortStreamLocked(err)
}
}
cc.cond.Broadcast()
cc.mu.Unlock()
}
// countReadFrameError calls Transport.CountError with a string
// representing err.
func (cc *http2ClientConn) countReadFrameError(err error) {
f := cc.t.CountError
if f == nil || err == nil {
return
}
if ce, ok := err.(http2ConnectionError); ok {
errCode := http2ErrCode(ce)
f(fmt.Sprintf("read_frame_conn_error_%s", errCode.stringToken()))
return
}
if errors.Is(err, io.EOF) {
f("read_frame_eof")
return
}
if errors.Is(err, io.ErrUnexpectedEOF) {
f("read_frame_unexpected_eof")
return
}
if errors.Is(err, http2ErrFrameTooLarge) {
f("read_frame_too_large")
return
}
f("read_frame_other")
}
func (rl *http2clientConnReadLoop) run() error {
cc := rl.cc
gotSettings := false
readIdleTimeout := cc.t.ReadIdleTimeout
var t *time.Timer
if readIdleTimeout != 0 {
t = time.AfterFunc(readIdleTimeout, cc.healthCheck)
defer t.Stop()
}
for {
f, err := cc.fr.ReadFrame()
if t != nil {
t.Reset(readIdleTimeout)
}
if err != nil {
cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err)
}
if se, ok := err.(http2StreamError); ok {
if cs := rl.streamByID(se.StreamID); cs != nil {
if se.Cause == nil {
se.Cause = cc.fr.errDetail
}
rl.endStreamError(cs, se)
}
continue
} else if err != nil {
cc.countReadFrameError(err)
return err
}
if http2VerboseLogs {
cc.vlogf("http2: Transport received %s", http2summarizeFrame(f))
}
if !gotSettings {
if _, ok := f.(*http2SettingsFrame); !ok {
cc.logf("protocol error: received %T before a SETTINGS frame", f)
return http2ConnectionError(http2ErrCodeProtocol)
}
gotSettings = true
}
switch f := f.(type) {
case *http2MetaHeadersFrame:
err = rl.processHeaders(f)
case *http2DataFrame:
err = rl.processData(f)
case *http2GoAwayFrame:
err = rl.processGoAway(f)
case *http2RSTStreamFrame:
err = rl.processResetStream(f)
case *http2SettingsFrame:
err = rl.processSettings(f)
case *http2PushPromiseFrame:
err = rl.processPushPromise(f)
case *http2WindowUpdateFrame:
err = rl.processWindowUpdate(f)
case *http2PingFrame:
err = rl.processPing(f)
default:
cc.logf("Transport: unhandled response frame type %T", f)
}
if err != nil {
if http2VerboseLogs {
cc.vlogf("http2: Transport conn %p received error from processing frame %v: %v", cc, http2summarizeFrame(f), err)
}
return err
}
}
}
func (rl *http2clientConnReadLoop) processHeaders(f *http2MetaHeadersFrame) error {
cs := rl.streamByID(f.StreamID)
if cs == nil {
// We'd get here if we canceled a request while the
// server had its response still in flight. So if this
// was just something we canceled, ignore it.
return nil
}
if cs.readClosed {
rl.endStreamError(cs, http2StreamError{
StreamID: f.StreamID,
Code: http2ErrCodeProtocol,
Cause: errors.New("protocol error: headers after END_STREAM"),
})
return nil
}
if !cs.firstByte {
if cs.trace != nil {
// TODO(bradfitz): move first response byte earlier,
// when we first read the 9 byte header, not waiting
// until all the HEADERS+CONTINUATION frames have been
// merged. This works for now.
http2traceFirstResponseByte(cs.trace)
}
cs.firstByte = true
}
if !cs.pastHeaders {
cs.pastHeaders = true
} else {
return rl.processTrailers(cs, f)
}
res, err := rl.handleResponse(cs, f)
if err != nil {
if _, ok := err.(http2ConnectionError); ok {
return err
}
// Any other error type is a stream error.
rl.endStreamError(cs, http2StreamError{
StreamID: f.StreamID,
Code: http2ErrCodeProtocol,
Cause: err,
})
return nil // return nil from process* funcs to keep conn alive
}
if res == nil {
// (nil, nil) special case. See handleResponse docs.
return nil
}
cs.resTrailer = &res.Trailer
cs.res = res
close(cs.respHeaderRecv)
if f.StreamEnded() {
rl.endStream(cs)
}
return nil
}
// handleResponse may return a nil error or a ConnectionError. Any other
// error value is converted by the caller into a StreamError of code
// ErrCodeProtocol, with the returned error as the detail.
//
// As a special case, handleResponse may return (nil, nil) to skip the
// frame (currently only used for 1xx responses).
func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http2MetaHeadersFrame) (*Response, error) {
if f.Truncated {
return nil, http2errResponseHeaderListSize
}
status := f.PseudoValue("status")
if status == "" {
return nil, errors.New("malformed response from server: missing status pseudo header")
}
statusCode, err := strconv.Atoi(status)
if err != nil {
return nil, errors.New("malformed response from server: malformed non-numeric status pseudo header")
}
regularFields := f.RegularFields()
strs := make([]string, len(regularFields))
header := make(Header, len(regularFields))
res := &Response{
Proto: "HTTP/2.0",
ProtoMajor: 2,
Header: header,
StatusCode: statusCode,
Status: status + " " + StatusText(statusCode),
}
for _, hf := range regularFields {
key := http2canonicalHeader(hf.Name)
if key == "Trailer" {
t := res.Trailer
if t == nil {
t = make(Header)
res.Trailer = t
}
http2foreachHeaderElement(hf.Value, func(v string) {
t[http2canonicalHeader(v)] = nil
})
} else {
vv := header[key]
if vv == nil && len(strs) > 0 {
// More than likely this will be a single-element key.
// Most headers aren't multi-valued.
// Set the capacity on strs[0] to 1, so any future append
// won't extend the slice into the other strings.
vv, strs = strs[:1:1], strs[1:]
vv[0] = hf.Value
header[key] = vv
} else {
header[key] = append(vv, hf.Value)
}
}
}
if statusCode >= 100 && statusCode <= 199 {
if f.StreamEnded() {
return nil, errors.New("1xx informational response with END_STREAM flag")
}
cs.num1xx++
const max1xxResponses = 5 // arbitrary bound on number of informational responses, same as net/http
if cs.num1xx > max1xxResponses {
return nil, errors.New("http2: too many 1xx informational responses")
}
if fn := cs.get1xxTraceFunc(); fn != nil {
if err := fn(statusCode, textproto.MIMEHeader(header)); err != nil {
return nil, err
}
}
if statusCode == 100 {
http2traceGot100Continue(cs.trace)
select {
case cs.on100 <- struct{}{}:
default:
}
}
cs.pastHeaders = false // do it all again
return nil, nil
}
res.ContentLength = -1
if clens := res.Header["Content-Length"]; len(clens) == 1 {
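// Note (exposition, not in the original source): bitSize 63 keeps
// the parsed value within int64's non-negative range.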
if cl, err := strconv.ParseUint(clens[0], 10, 63); err == nil {
res.ContentLength = int64(cl)
} else {
// TODO: care? unlike http/1, it won't mess up our framing, so it's
// safer smuggling-wise to ignore.
}
} else if len(clens) > 1 {
// TODO: care? unlike http/1, it won't mess up our framing, so it's
// safer smuggling-wise to ignore.
} else if f.StreamEnded() && !cs.isHead {
res.ContentLength = 0
}
if cs.isHead {
res.Body = http2noBody
return res, nil
}
if f.StreamEnded() {
if res.ContentLength > 0 {
res.Body = http2missingBody{}
} else {
res.Body = http2noBody
}
return res, nil
}
cs.bufPipe.setBuffer(&http2dataBuffer{expected: res.ContentLength})
cs.bytesRemain = res.ContentLength
res.Body = http2transportResponseBody{cs}
if cs.requestedGzip && http2asciiEqualFold(res.Header.Get("Content-Encoding"), "gzip") {
res.Header.Del("Content-Encoding")
res.Header.Del("Content-Length")
res.ContentLength = -1
res.Body = &http2gzipReader{body: res.Body}
res.Uncompressed = true
}
return res, nil
}
func (rl *http2clientConnReadLoop) processTrailers(cs *http2clientStream, f *http2MetaHeadersFrame) error {
if cs.pastTrailers {
// Too many HEADERS frames for this stream.
return http2ConnectionError(http2ErrCodeProtocol)
}
cs.pastTrailers = true
if !f.StreamEnded() {
// We expect that any HEADERS frame carrying trailers
// also has END_STREAM set.
return http2ConnectionError(http2ErrCodeProtocol)
}
if len(f.PseudoFields()) > 0 {
// No pseudo header fields are defined for trailers.
// TODO: ConnectionError might be overly harsh? Check.
return http2ConnectionError(http2ErrCodeProtocol)
}
trailer := make(Header)
for _, hf := range f.RegularFields() {
key := http2canonicalHeader(hf.Name)
trailer[key] = append(trailer[key], hf.Value)
}
cs.trailer = trailer
rl.endStream(cs)
return nil
}
// transportResponseBody is the concrete type of Transport.RoundTrip's
// Response.Body. It is an io.ReadCloser.
type http2transportResponseBody struct {
cs *http2clientStream
}
func (b http2transportResponseBody) Read(p []byte) (n int, err error) {
cs := b.cs
cc := cs.cc
if cs.readErr != nil {
return 0, cs.readErr
}
n, err = b.cs.bufPipe.Read(p)
if cs.bytesRemain != -1 {
if int64(n) > cs.bytesRemain {
n = int(cs.bytesRemain)
if err == nil {
err = errors.New("net/http: server replied with more than declared Content-Length; truncated")
cs.abortStream(err)
}
cs.readErr = err
return int(cs.bytesRemain), err
}
cs.bytesRemain -= int64(n)
if err == io.EOF && cs.bytesRemain > 0 {
err = io.ErrUnexpectedEOF
cs.readErr = err
return n, err
}
}
if n == 0 {
// No flow control tokens to send back.
return
}
cc.mu.Lock()
connAdd := cc.inflow.add(n)
var streamAdd int32
if err == nil { // No need to refresh if the stream is over or failed.
streamAdd = cs.inflow.add(n)
}
cc.mu.Unlock()
if connAdd != 0 || streamAdd != 0 {
cc.wmu.Lock()
defer cc.wmu.Unlock()
if connAdd != 0 {
cc.fr.WriteWindowUpdate(0, http2mustUint31(connAdd))
}
if streamAdd != 0 {
cc.fr.WriteWindowUpdate(cs.ID, http2mustUint31(streamAdd))
}
cc.bw.Flush()
}
return
}
var http2errClosedResponseBody = errors.New("http2: response body closed")
func (b http2transportResponseBody) Close() error {
cs := b.cs
cc := cs.cc
unread := cs.bufPipe.Len()
if unread > 0 {
cc.mu.Lock()
// Return connection-level flow control.
connAdd := cc.inflow.add(unread)
cc.mu.Unlock()
// TODO(dneil): Acquiring this mutex can block indefinitely.
// Move flow control return to a goroutine?
cc.wmu.Lock()
// Return connection-level flow control.
if connAdd > 0 {
cc.fr.WriteWindowUpdate(0, uint32(connAdd))
}
cc.bw.Flush()
cc.wmu.Unlock()
}
cs.bufPipe.BreakWithError(http2errClosedResponseBody)
cs.abortStream(http2errClosedResponseBody)
select {
case <-cs.donec:
case <-cs.ctx.Done():
// See golang/go#49366: The net/http package can cancel the
// request context after the response body is fully read.
// Don't treat this as an error.
return nil
case <-cs.reqCancel:
return http2errRequestCanceled
}
return nil
}
func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error {
cc := rl.cc
cs := rl.streamByID(f.StreamID)
data := f.Data()
if cs == nil {
cc.mu.Lock()
neverSent := cc.nextStreamID
cc.mu.Unlock()
if f.StreamID >= neverSent {
// We never asked for this.
cc.logf("http2: Transport received unsolicited DATA frame; closing connection")
return http2ConnectionError(http2ErrCodeProtocol)
}
// We probably did ask for this, but canceled. Just ignore it.
// TODO: be stricter here? only silently ignore things which
// we canceled, but not things which were closed normally
// by the peer? Tough without accumulating too much state.
// But at least return their flow control:
if f.Length > 0 {
cc.mu.Lock()
ok := cc.inflow.take(f.Length)
connAdd := cc.inflow.add(int(f.Length))
cc.mu.Unlock()
if !ok {
return http2ConnectionError(http2ErrCodeFlowControl)
}
if connAdd > 0 {
cc.wmu.Lock()
cc.fr.WriteWindowUpdate(0, uint32(connAdd))
cc.bw.Flush()
cc.wmu.Unlock()
}
}
return nil
}
if cs.readClosed {
cc.logf("protocol error: received DATA after END_STREAM")
rl.endStreamError(cs, http2StreamError{
StreamID: f.StreamID,
Code: http2ErrCodeProtocol,
})
return nil
}
if !cs.firstByte {
cc.logf("protocol error: received DATA before a HEADERS frame")
rl.endStreamError(cs, http2StreamError{
StreamID: f.StreamID,
Code: http2ErrCodeProtocol,
})
return nil
}
if f.Length > 0 {
if cs.isHead && len(data) > 0 {
cc.logf("protocol error: received DATA on a HEAD request")
rl.endStreamError(cs, http2StreamError{
StreamID: f.StreamID,
Code: http2ErrCodeProtocol,
})
return nil
}
// Check connection-level flow control.
cc.mu.Lock()
if !http2takeInflows(&cc.inflow, &cs.inflow, f.Length) {
cc.mu.Unlock()
return http2ConnectionError(http2ErrCodeFlowControl)
}
// Return any padded flow control now, since we won't
// refund it later on body reads.
var refund int
if pad := int(f.Length) - len(data); pad > 0 {
refund += pad
}
didReset := false
var err error
if len(data) > 0 {
if _, err = cs.bufPipe.Write(data); err != nil {
// Return len(data) now if the stream is already closed,
// since data will never be read.
didReset = true
refund += len(data)
}
}
sendConn := cc.inflow.add(refund)
var sendStream int32
if !didReset {
sendStream = cs.inflow.add(refund)
}
cc.mu.Unlock()
if sendConn > 0 || sendStream > 0 {
cc.wmu.Lock()
if sendConn > 0 {
cc.fr.WriteWindowUpdate(0, uint32(sendConn))
}
if sendStream > 0 {
cc.fr.WriteWindowUpdate(cs.ID, uint32(sendStream))
}
cc.bw.Flush()
cc.wmu.Unlock()
}
if err != nil {
rl.endStreamError(cs, err)
return nil
}
}
if f.StreamEnded() {
rl.endStream(cs)
}
return nil
}
func (rl *http2clientConnReadLoop) endStream(cs *http2clientStream) {
// TODO: check that any declared content-length matches, like
// server.go's (*stream).endStream method.
if !cs.readClosed {
cs.readClosed = true
// Close cs.bufPipe and cs.peerClosed with cc.mu held to avoid a
// race condition: The caller can read io.EOF from Response.Body
// and close the body before we close cs.peerClosed, causing
// cleanupWriteRequest to send a RST_STREAM.
rl.cc.mu.Lock()
defer rl.cc.mu.Unlock()
cs.bufPipe.closeWithErrorAndCode(io.EOF, cs.copyTrailers)
close(cs.peerClosed)
}
}
func (rl *http2clientConnReadLoop) endStreamError(cs *http2clientStream, err error) {
cs.readAborted = true
cs.abortStream(err)
}
func (rl *http2clientConnReadLoop) streamByID(id uint32) *http2clientStream {
rl.cc.mu.Lock()
defer rl.cc.mu.Unlock()
cs := rl.cc.streams[id]
if cs != nil && !cs.readAborted {
return cs
}
return nil
}
func (cs *http2clientStream) copyTrailers() {
for k, vv := range cs.trailer {
t := cs.resTrailer
if *t == nil {
*t = make(Header)
}
(*t)[k] = vv
}
}
func (rl *http2clientConnReadLoop) processGoAway(f *http2GoAwayFrame) error {
cc := rl.cc
cc.t.connPool().MarkDead(cc)
if f.ErrCode != 0 {
// TODO: deal with GOAWAY more. particularly the error code
cc.vlogf("transport got GOAWAY with error code = %v", f.ErrCode)
if fn := cc.t.CountError; fn != nil {
fn("recv_goaway_" + f.ErrCode.stringToken())
}
}
cc.setGoAway(f)
return nil
}
func (rl *http2clientConnReadLoop) processSettings(f *http2SettingsFrame) error {
cc := rl.cc
// Locking both mu and wmu here allows frame encoding to read settings with only wmu held.
// Acquiring wmu when f.IsAck() is unnecessary, but convenient and mostly harmless.
cc.wmu.Lock()
defer cc.wmu.Unlock()
if err := rl.processSettingsNoWrite(f); err != nil {
return err
}
if !f.IsAck() {
cc.fr.WriteSettingsAck()
cc.bw.Flush()
}
return nil
}
func (rl *http2clientConnReadLoop) processSettingsNoWrite(f *http2SettingsFrame) error {
cc := rl.cc
cc.mu.Lock()
defer cc.mu.Unlock()
if f.IsAck() {
if cc.wantSettingsAck {
cc.wantSettingsAck = false
return nil
}
return http2ConnectionError(http2ErrCodeProtocol)
}
var seenMaxConcurrentStreams bool
err := f.ForeachSetting(func(s http2Setting) error {
switch s.ID {
case http2SettingMaxFrameSize:
cc.maxFrameSize = s.Val
case http2SettingMaxConcurrentStreams:
cc.maxConcurrentStreams = s.Val
seenMaxConcurrentStreams = true
case http2SettingMaxHeaderListSize:
cc.peerMaxHeaderListSize = uint64(s.Val)
case http2SettingInitialWindowSize:
// Values above the maximum flow-control
// window size of 2^31-1 MUST be treated as a
// connection error (Section 5.4.1) of type
// FLOW_CONTROL_ERROR.
if s.Val > math.MaxInt32 {
return http2ConnectionError(http2ErrCodeFlowControl)
}
// Adjust flow control of currently-open
// frames by the difference of the old initial
// window size and this one.
delta := int32(s.Val) - int32(cc.initialWindowSize)
for _, cs := range cc.streams {
cs.flow.add(delta)
}
cc.cond.Broadcast()
cc.initialWindowSize = s.Val
case http2SettingHeaderTableSize:
cc.henc.SetMaxDynamicTableSize(s.Val)
cc.peerMaxHeaderTableSize = s.Val
default:
cc.vlogf("Unhandled Setting: %v", s)
}
return nil
})
if err != nil {
return err
}
if !cc.seenSettings {
if !seenMaxConcurrentStreams {
// This was the server's initial SETTINGS frame and it
// didn't contain a MAX_CONCURRENT_STREAMS field, so
// increase the number of concurrent streams this
// connection can establish to our default.
cc.maxConcurrentStreams = http2defaultMaxConcurrentStreams
}
cc.seenSettings = true
}
return nil
}
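// Worked example of the INITIAL_WINDOW_SIZE handling above (illustrative):
// if the peer lowers the initial window size from 65535 to 32768, then
// delta = 32768 - 65535 = -32767, and every currently-open stream's send
// window shrinks by 32767 bytes. A stream's window may go negative; it
// recovers only as WINDOW_UPDATE frames arrive.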
func (rl *http2clientConnReadLoop) processWindowUpdate(f *http2WindowUpdateFrame) error {
cc := rl.cc
cs := rl.streamByID(f.StreamID)
if f.StreamID != 0 && cs == nil {
return nil
}
cc.mu.Lock()
defer cc.mu.Unlock()
fl := &cc.flow
if cs != nil {
fl = &cs.flow
}
if !fl.add(int32(f.Increment)) {
return http2ConnectionError(http2ErrCodeFlowControl)
}
cc.cond.Broadcast()
return nil
}
func (rl *http2clientConnReadLoop) processResetStream(f *http2RSTStreamFrame) error {
cs := rl.streamByID(f.StreamID)
if cs == nil {
// TODO: return error if server tries to RST_STREAM an idle stream
return nil
}
serr := http2streamError(cs.ID, f.ErrCode)
serr.Cause = http2errFromPeer
if f.ErrCode == http2ErrCodeProtocol {
rl.cc.SetDoNotReuse()
}
if fn := cs.cc.t.CountError; fn != nil {
fn("recv_rststream_" + f.ErrCode.stringToken())
}
cs.abortStream(serr)
cs.bufPipe.CloseWithError(serr)
return nil
}
// Ping sends a PING frame to the server and waits for the ack.
func (cc *http2ClientConn) Ping(ctx context.Context) error {
c := make(chan struct{})
// Generate a random payload
var p [8]byte
for {
if _, err := rand.Read(p[:]); err != nil {
return err
}
cc.mu.Lock()
// check for dup before insert
if _, found := cc.pings[p]; !found {
cc.pings[p] = c
cc.mu.Unlock()
break
}
cc.mu.Unlock()
}
errc := make(chan error, 1)
go func() {
cc.wmu.Lock()
defer cc.wmu.Unlock()
if err := cc.fr.WritePing(false, p); err != nil {
errc <- err
return
}
if err := cc.bw.Flush(); err != nil {
errc <- err
return
}
}()
select {
case <-c:
return nil
case err := <-errc:
return err
case <-ctx.Done():
return ctx.Err()
case <-cc.readerDone:
// connection closed
return cc.readerErr
}
}
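// Illustrative usage sketch, assuming cc is an established *http2ClientConn
// (names in the sketch are for illustration only):
//
//	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
//	defer cancel()
//	if err := cc.Ping(ctx); err != nil {
//		// the peer never acked, the write failed, or the connection closed
//	}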
func (rl *http2clientConnReadLoop) processPing(f *http2PingFrame) error {
if f.IsAck() {
cc := rl.cc
cc.mu.Lock()
defer cc.mu.Unlock()
// If ack, notify listener if any
if c, ok := cc.pings[f.Data]; ok {
close(c)
delete(cc.pings, f.Data)
}
return nil
}
cc := rl.cc
cc.wmu.Lock()
defer cc.wmu.Unlock()
if err := cc.fr.WritePing(true, f.Data); err != nil {
return err
}
return cc.bw.Flush()
}
func (rl *http2clientConnReadLoop) processPushPromise(f *http2PushPromiseFrame) error {
// We told the peer we don't want them.
// Spec says:
// "PUSH_PROMISE MUST NOT be sent if the SETTINGS_ENABLE_PUSH
// setting of the peer endpoint is set to 0. An endpoint that
// has set this setting and has received acknowledgement MUST
// treat the receipt of a PUSH_PROMISE frame as a connection
// error (Section 5.4.1) of type PROTOCOL_ERROR."
return http2ConnectionError(http2ErrCodeProtocol)
}
func (cc *http2ClientConn) writeStreamReset(streamID uint32, code http2ErrCode, err error) {
// TODO: map err to more interesting error codes, once the
// HTTP community comes up with some. But currently for
// RST_STREAM there's no equivalent to GOAWAY frame's debug
// data, and the error codes are all pretty vague ("cancel").
cc.wmu.Lock()
cc.fr.WriteRSTStream(streamID, code)
cc.bw.Flush()
cc.wmu.Unlock()
}
var (
http2errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit")
http2errRequestHeaderListSize = errors.New("http2: request header list larger than peer's advertised limit")
)
func (cc *http2ClientConn) logf(format string, args ...interface{}) {
cc.t.logf(format, args...)
}
func (cc *http2ClientConn) vlogf(format string, args ...interface{}) {
cc.t.vlogf(format, args...)
}
func (t *http2Transport) vlogf(format string, args ...interface{}) {
if http2VerboseLogs {
t.logf(format, args...)
}
}
func (t *http2Transport) logf(format string, args ...interface{}) {
log.Printf(format, args...)
}
var http2noBody io.ReadCloser = http2noBodyReader{}
type http2noBodyReader struct{}
func (http2noBodyReader) Close() error { return nil }
func (http2noBodyReader) Read([]byte) (int, error) { return 0, io.EOF }
type http2missingBody struct{}
func (http2missingBody) Close() error { return nil }
func (http2missingBody) Read([]byte) (int, error) { return 0, io.ErrUnexpectedEOF }
func http2strSliceContains(ss []string, s string) bool {
for _, v := range ss {
if v == s {
return true
}
}
return false
}
type http2erringRoundTripper struct{ err error }
func (rt http2erringRoundTripper) RoundTripErr() error { return rt.err }
func (rt http2erringRoundTripper) RoundTrip(*Request) (*Response, error) { return nil, rt.err }
// gzipReader wraps a response body so it can lazily
// call gzip.NewReader on the first call to Read.
type http2gzipReader struct {
_ http2incomparable
body io.ReadCloser // underlying Response.Body
zr *gzip.Reader // lazily-initialized gzip reader
zerr error // sticky error
}
func (gz *http2gzipReader) Read(p []byte) (n int, err error) {
if gz.zerr != nil {
return 0, gz.zerr
}
if gz.zr == nil {
gz.zr, err = gzip.NewReader(gz.body)
if err != nil {
gz.zerr = err
return 0, err
}
}
return gz.zr.Read(p)
}
func (gz *http2gzipReader) Close() error {
if err := gz.body.Close(); err != nil {
return err
}
gz.zerr = fs.ErrClosed
return nil
}
type http2errorReader struct{ err error }
func (r http2errorReader) Read(p []byte) (int, error) { return 0, r.err }
// isConnectionCloseRequest reports whether req should use its own
// connection for a single request and then close the connection.
func http2isConnectionCloseRequest(req *Request) bool {
return req.Close || httpguts.HeaderValuesContainsToken(req.Header["Connection"], "close")
}
// registerHTTPSProtocol calls Transport.RegisterProtocol,
// converting panics into errors.
func http2registerHTTPSProtocol(t *Transport, rt http2noDialH2RoundTripper) (err error) {
defer func() {
if e := recover(); e != nil {
err = fmt.Errorf("%v", e)
}
}()
t.RegisterProtocol("https", rt)
return nil
}
// noDialH2RoundTripper is a RoundTripper which only tries to complete the request
// if there's already a cached connection to the host.
// (The field is exported so it can be accessed via reflect from net/http; tested
// by TestNoDialH2RoundTripperType)
type http2noDialH2RoundTripper struct{ *http2Transport }
func (rt http2noDialH2RoundTripper) RoundTrip(req *Request) (*Response, error) {
res, err := rt.http2Transport.RoundTrip(req)
if http2isNoCachedConnError(err) {
return nil, ErrSkipAltProtocol
}
return res, err
}
func (t *http2Transport) idleConnTimeout() time.Duration {
if t.t1 != nil {
return t.t1.IdleConnTimeout
}
return 0
}
func http2traceGetConn(req *Request, hostPort string) {
trace := httptrace.ContextClientTrace(req.Context())
if trace == nil || trace.GetConn == nil {
return
}
trace.GetConn(hostPort)
}
func http2traceGotConn(req *Request, cc *http2ClientConn, reused bool) {
trace := httptrace.ContextClientTrace(req.Context())
if trace == nil || trace.GotConn == nil {
return
}
ci := httptrace.GotConnInfo{Conn: cc.tconn}
ci.Reused = reused
cc.mu.Lock()
ci.WasIdle = len(cc.streams) == 0 && reused
if ci.WasIdle && !cc.lastActive.IsZero() {
ci.IdleTime = time.Since(cc.lastActive)
}
cc.mu.Unlock()
trace.GotConn(ci)
}
func http2traceWroteHeaders(trace *httptrace.ClientTrace) {
if trace != nil && trace.WroteHeaders != nil {
trace.WroteHeaders()
}
}
func http2traceGot100Continue(trace *httptrace.ClientTrace) {
if trace != nil && trace.Got100Continue != nil {
trace.Got100Continue()
}
}
func http2traceWait100Continue(trace *httptrace.ClientTrace) {
if trace != nil && trace.Wait100Continue != nil {
trace.Wait100Continue()
}
}
func http2traceWroteRequest(trace *httptrace.ClientTrace, err error) {
if trace != nil && trace.WroteRequest != nil {
trace.WroteRequest(httptrace.WroteRequestInfo{Err: err})
}
}
func http2traceFirstResponseByte(trace *httptrace.ClientTrace) {
if trace != nil && trace.GotFirstResponseByte != nil {
trace.GotFirstResponseByte()
}
}
// writeFramer is implemented by any type that is used to write frames.
type http2writeFramer interface {
writeFrame(http2writeContext) error
// staysWithinBuffer reports whether this writer promises that
// it will write at most size bytes, and that it won't Flush
// the write context.
staysWithinBuffer(size int) bool
}
// writeContext is the interface needed by the various frame writer
// types below. All the writeFrame methods below are scheduled via the
// frame writing scheduler (see writeScheduler in writesched.go).
//
// This interface is implemented by *serverConn.
//
// TODO: decide whether to a) use this in the client code (which didn't
// end up using this yet, because it has a simpler design, not
// currently implementing priorities), or b) delete this and
// make the server code a bit more concrete.
type http2writeContext interface {
Framer() *http2Framer
Flush() error
CloseConn() error
// HeaderEncoder returns an HPACK encoder that writes to the
// returned buffer.
HeaderEncoder() (*hpack.Encoder, *bytes.Buffer)
}
// writeEndsStream reports whether w writes a frame that will transition
// the stream to a half-closed local state. This returns false for RST_STREAM,
// which closes the entire stream (not just the local half).
func http2writeEndsStream(w http2writeFramer) bool {
switch v := w.(type) {
case *http2writeData:
return v.endStream
case *http2writeResHeaders:
return v.endStream
case nil:
// This can only happen if the caller reuses w after it's
// been intentionally nil'ed out to prevent use. Keep this
// here to catch future refactoring breaking it.
panic("writeEndsStream called on nil writeFramer")
}
return false
}
type http2flushFrameWriter struct{}
func (http2flushFrameWriter) writeFrame(ctx http2writeContext) error {
return ctx.Flush()
}
func (http2flushFrameWriter) staysWithinBuffer(max int) bool { return false }
type http2writeSettings []http2Setting
func (s http2writeSettings) staysWithinBuffer(max int) bool {
const settingSize = 6 // uint16 + uint32
return http2frameHeaderLen+settingSize*len(s) <= max
}
func (s http2writeSettings) writeFrame(ctx http2writeContext) error {
return ctx.Framer().WriteSettings([]http2Setting(s)...)
}
type http2writeGoAway struct {
maxStreamID uint32
code http2ErrCode
}
func (p *http2writeGoAway) writeFrame(ctx http2writeContext) error {
err := ctx.Framer().WriteGoAway(p.maxStreamID, p.code, nil)
ctx.Flush() // ignore error: we're hanging up on them anyway
return err
}
func (*http2writeGoAway) staysWithinBuffer(max int) bool { return false } // flushes
type http2writeData struct {
streamID uint32
p []byte
endStream bool
}
func (w *http2writeData) String() string {
return fmt.Sprintf("writeData(stream=%d, p=%d, endStream=%v)", w.streamID, len(w.p), w.endStream)
}
func (w *http2writeData) writeFrame(ctx http2writeContext) error {
return ctx.Framer().WriteData(w.streamID, w.endStream, w.p)
}
func (w *http2writeData) staysWithinBuffer(max int) bool {
return http2frameHeaderLen+len(w.p) <= max
}
// handlerPanicRST is the message sent from handler goroutines when
// the handler panics.
type http2handlerPanicRST struct {
StreamID uint32
}
func (hp http2handlerPanicRST) writeFrame(ctx http2writeContext) error {
return ctx.Framer().WriteRSTStream(hp.StreamID, http2ErrCodeInternal)
}
func (hp http2handlerPanicRST) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max }
func (se http2StreamError) writeFrame(ctx http2writeContext) error {
return ctx.Framer().WriteRSTStream(se.StreamID, se.Code)
}
func (se http2StreamError) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max }
type http2writePingAck struct{ pf *http2PingFrame }
func (w http2writePingAck) writeFrame(ctx http2writeContext) error {
return ctx.Framer().WritePing(true, w.pf.Data)
}
func (w http2writePingAck) staysWithinBuffer(max int) bool {
return http2frameHeaderLen+len(w.pf.Data) <= max
}
type http2writeSettingsAck struct{}
func (http2writeSettingsAck) writeFrame(ctx http2writeContext) error {
return ctx.Framer().WriteSettingsAck()
}
func (http2writeSettingsAck) staysWithinBuffer(max int) bool { return http2frameHeaderLen <= max }
// splitHeaderBlock splits headerBlock into fragments so that each fragment fits
// in a single frame, then calls fn for each fragment. firstFrag/lastFrag are true
// for the first/last fragment, respectively.
func http2splitHeaderBlock(ctx http2writeContext, headerBlock []byte, fn func(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error) error {
// For now we're lazy and just pick the minimum MAX_FRAME_SIZE
// that all peers must support (16KB). Later we could care
// more and send larger frames if the peer advertised it, but
// there's little point. Most headers are small anyway (so we
// generally won't have CONTINUATION frames), and extra frames
// only waste 9 bytes anyway.
const maxFrameSize = 16384
first := true
for len(headerBlock) > 0 {
frag := headerBlock
if len(frag) > maxFrameSize {
frag = frag[:maxFrameSize]
}
headerBlock = headerBlock[len(frag):]
if err := fn(ctx, frag, first, len(headerBlock) == 0); err != nil {
return err
}
first = false
}
return nil
}
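// Worked example (illustrative): a 40000-byte header block yields fragments
// of 16384, 16384, and 7232 bytes; fn is invoked with (firstFrag, lastFrag)
// equal to (true, false), (false, false), and (false, true), which maps to
// one HEADERS frame followed by two CONTINUATION frames.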
// writeResHeaders is a request to write a HEADERS and 0+ CONTINUATION frames
// for HTTP response headers or trailers from a server handler.
type http2writeResHeaders struct {
streamID uint32
httpResCode int // 0 means no ":status" line
h Header // may be nil
trailers []string // if non-nil, which keys of h to write. nil means all.
endStream bool
date string
contentType string
contentLength string
}
func http2encKV(enc *hpack.Encoder, k, v string) {
if http2VerboseLogs {
log.Printf("http2: server encoding header %q = %q", k, v)
}
enc.WriteField(hpack.HeaderField{Name: k, Value: v})
}
func (w *http2writeResHeaders) staysWithinBuffer(max int) bool {
// TODO: this is a common one. It'd be nice to return true
// here and get into the fast path if we could be clever and
// calculate the size fast enough, or at least a conservative
// upper bound that usually fires. (Maybe if w.h and
// w.trailers are nil, so we don't need to enumerate it.)
// Otherwise I'm afraid that just calculating the length to
// answer this question would be slower than the ~2µs benefit.
return false
}
func (w *http2writeResHeaders) writeFrame(ctx http2writeContext) error {
enc, buf := ctx.HeaderEncoder()
buf.Reset()
if w.httpResCode != 0 {
http2encKV(enc, ":status", http2httpCodeString(w.httpResCode))
}
http2encodeHeaders(enc, w.h, w.trailers)
if w.contentType != "" {
http2encKV(enc, "content-type", w.contentType)
}
if w.contentLength != "" {
http2encKV(enc, "content-length", w.contentLength)
}
if w.date != "" {
http2encKV(enc, "date", w.date)
}
headerBlock := buf.Bytes()
if len(headerBlock) == 0 && w.trailers == nil {
panic("unexpected empty hpack")
}
return http2splitHeaderBlock(ctx, headerBlock, w.writeHeaderBlock)
}
func (w *http2writeResHeaders) writeHeaderBlock(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error {
if firstFrag {
return ctx.Framer().WriteHeaders(http2HeadersFrameParam{
StreamID: w.streamID,
BlockFragment: frag,
EndStream: w.endStream,
EndHeaders: lastFrag,
})
} else {
return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag)
}
}
// writePushPromise is a request to write a PUSH_PROMISE and 0+ CONTINUATION frames.
type http2writePushPromise struct {
streamID uint32 // pusher stream
method string // for :method
url *url.URL // for :scheme, :authority, :path
h Header
// Creates an ID for a pushed stream. This runs on serveG just before
// the frame is written. The returned ID is copied to promisedID.
allocatePromisedID func() (uint32, error)
promisedID uint32
}
func (w *http2writePushPromise) staysWithinBuffer(max int) bool {
// TODO: see writeResHeaders.staysWithinBuffer
return false
}
func (w *http2writePushPromise) writeFrame(ctx http2writeContext) error {
enc, buf := ctx.HeaderEncoder()
buf.Reset()
http2encKV(enc, ":method", w.method)
http2encKV(enc, ":scheme", w.url.Scheme)
http2encKV(enc, ":authority", w.url.Host)
http2encKV(enc, ":path", w.url.RequestURI())
http2encodeHeaders(enc, w.h, nil)
headerBlock := buf.Bytes()
if len(headerBlock) == 0 {
panic("unexpected empty hpack")
}
return http2splitHeaderBlock(ctx, headerBlock, w.writeHeaderBlock)
}
func (w *http2writePushPromise) writeHeaderBlock(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error {
if firstFrag {
return ctx.Framer().WritePushPromise(http2PushPromiseParam{
StreamID: w.streamID,
PromiseID: w.promisedID,
BlockFragment: frag,
EndHeaders: lastFrag,
})
} else {
return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag)
}
}
type http2write100ContinueHeadersFrame struct {
streamID uint32
}
func (w http2write100ContinueHeadersFrame) writeFrame(ctx http2writeContext) error {
enc, buf := ctx.HeaderEncoder()
buf.Reset()
http2encKV(enc, ":status", "100")
return ctx.Framer().WriteHeaders(http2HeadersFrameParam{
StreamID: w.streamID,
BlockFragment: buf.Bytes(),
EndStream: false,
EndHeaders: true,
})
}
func (w http2write100ContinueHeadersFrame) staysWithinBuffer(max int) bool {
// Sloppy but conservative:
return 9+2*(len(":status")+len("100")) <= max
}
type http2writeWindowUpdate struct {
streamID uint32 // or 0 for conn-level
n uint32
}
func (wu http2writeWindowUpdate) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max }
func (wu http2writeWindowUpdate) writeFrame(ctx http2writeContext) error {
return ctx.Framer().WriteWindowUpdate(wu.streamID, wu.n)
}
// encodeHeaders encodes an http.Header. If keys is not nil, then (k, h[k])
// is encoded only if k is in keys.
func http2encodeHeaders(enc *hpack.Encoder, h Header, keys []string) {
if keys == nil {
sorter := http2sorterPool.Get().(*http2sorter)
// Using defer here, since the returned keys from the
// sorter.Keys method is only valid until the sorter
// is returned:
defer http2sorterPool.Put(sorter)
keys = sorter.Keys(h)
}
for _, k := range keys {
vv := h[k]
k, ascii := http2lowerHeader(k)
if !ascii {
// Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header
// field names have to be ASCII characters (just as in HTTP/1.x).
continue
}
if !http2validWireHeaderFieldName(k) {
// Skip it as backup paranoia. Per
// golang.org/issue/14048, these should
// already be rejected at a higher level.
continue
}
isTE := k == "transfer-encoding"
for _, v := range vv {
if !httpguts.ValidHeaderFieldValue(v) {
// TODO: return an error? golang.org/issue/14048
// For now just omit it.
continue
}
// TODO: more of "8.1.2.2 Connection-Specific Header Fields"
if isTE && v != "trailers" {
continue
}
http2encKV(enc, k, v)
}
}
}
// WriteScheduler is the interface implemented by HTTP/2 write schedulers.
// Methods are never called concurrently.
type http2WriteScheduler interface {
// OpenStream opens a new stream in the write scheduler.
// It is illegal to call this with streamID=0 or with a streamID that is
// already open -- the call may panic.
OpenStream(streamID uint32, options http2OpenStreamOptions)
// CloseStream closes a stream in the write scheduler. Any frames queued on
// this stream should be discarded. It is illegal to call this on a stream
// that is not open -- the call may panic.
CloseStream(streamID uint32)
// AdjustStream adjusts the priority of the given stream. This may be called
// on a stream that has not yet been opened or has been closed. Note that
// RFC 7540 allows PRIORITY frames to be sent on streams in any state. See:
// https://tools.ietf.org/html/rfc7540#section-5.1
AdjustStream(streamID uint32, priority http2PriorityParam)
// Push queues a frame in the scheduler. In most cases, this will not be
// called with wr.StreamID()!=0 unless that stream is currently open. The one
// exception is RST_STREAM frames, which may be sent on idle or closed streams.
Push(wr http2FrameWriteRequest)
// Pop dequeues the next frame to write. Returns false if no frames can
// be written. Frames with a given wr.StreamID() are Pop'd in the same
// order they are Push'd, except RST_STREAM frames. No frames should be
// discarded except by CloseStream.
Pop() (wr http2FrameWriteRequest, ok bool)
}
// OpenStreamOptions specifies extra options for WriteScheduler.OpenStream.
type http2OpenStreamOptions struct {
// PusherID is zero if the stream was initiated by the client. Otherwise,
// PusherID names the stream that pushed the newly opened stream.
PusherID uint32
}
// FrameWriteRequest is a request to write a frame.
type http2FrameWriteRequest struct {
// write is the interface value that does the writing, once the
// WriteScheduler has selected this frame to write. The write
// functions are all defined in write.go.
write http2writeFramer
// stream is the stream on which this frame will be written.
// nil for non-stream frames like PING and SETTINGS.
// nil for RST_STREAM streams, which use the StreamError.StreamID field instead.
stream *http2stream
// done, if non-nil, must be a buffered channel with space for
// 1 message and is sent the return value from write (or an
// earlier error) when the frame has been written.
done chan error
}
// StreamID returns the id of the stream this frame will be written to.
// 0 is used for non-stream frames such as PING and SETTINGS.
func (wr http2FrameWriteRequest) StreamID() uint32 {
if wr.stream == nil {
if se, ok := wr.write.(http2StreamError); ok {
// (*serverConn).resetStream doesn't set
// stream because it doesn't necessarily have
// one. So special case this type of write
// message.
return se.StreamID
}
return 0
}
return wr.stream.id
}
// isControl reports whether wr is a control frame for MaxQueuedControlFrames
// purposes. That includes non-stream frames and RST_STREAM frames.
func (wr http2FrameWriteRequest) isControl() bool {
return wr.stream == nil
}
// DataSize returns the number of flow control bytes that must be consumed
// to write this entire frame. This is 0 for non-DATA frames.
func (wr http2FrameWriteRequest) DataSize() int {
if wd, ok := wr.write.(*http2writeData); ok {
return len(wd.p)
}
return 0
}
// Consume consumes min(n, available) bytes from this frame, where available
// is the number of flow control bytes available on the stream. Consume returns
// 0, 1, or 2 frames, where the integer return value gives the number of frames
// returned.
//
// If flow control prevents consuming any bytes, this returns (_, _, 0). If
// the entire frame was consumed, this returns (wr, _, 1). Otherwise, this
// returns (consumed, rest, 2), where 'consumed' contains the consumed bytes and
// 'rest' contains the remaining bytes. The consumed bytes are deducted from the
// underlying stream's flow control budget.
func (wr http2FrameWriteRequest) Consume(n int32) (http2FrameWriteRequest, http2FrameWriteRequest, int) {
var empty http2FrameWriteRequest
// Non-DATA frames are always consumed whole.
wd, ok := wr.write.(*http2writeData)
if !ok || len(wd.p) == 0 {
return wr, empty, 1
}
// Might need to split after applying limits.
allowed := wr.stream.flow.available()
if n < allowed {
allowed = n
}
if wr.stream.sc.maxFrameSize < allowed {
allowed = wr.stream.sc.maxFrameSize
}
if allowed <= 0 {
return empty, empty, 0
}
if len(wd.p) > int(allowed) {
wr.stream.flow.take(allowed)
consumed := http2FrameWriteRequest{
stream: wr.stream,
write: &http2writeData{
streamID: wd.streamID,
p: wd.p[:allowed],
// Even if the original had endStream set, there
// are bytes remaining because len(wd.p) > allowed,
// so we know endStream is false.
endStream: false,
},
// Our caller is blocking on the final DATA frame, not
// this intermediate frame, so no need to wait.
done: nil,
}
rest := http2FrameWriteRequest{
stream: wr.stream,
write: &http2writeData{
streamID: wd.streamID,
p: wd.p[allowed:],
endStream: wd.endStream,
},
done: wr.done,
}
return consumed, rest, 2
}
// The frame is consumed whole.
// NB: This cast cannot overflow because allowed is <= math.MaxInt32.
wr.stream.flow.take(int32(len(wd.p)))
return wr, empty, 1
}
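// Worked example (illustrative): with a queued DATA frame of 100 bytes,
// 60 bytes of stream-level flow control available, n = 1024, and a
// maxFrameSize of at least 60, Consume takes 60 bytes and returns
// (consumed, rest, 2): consumed carries wd.p[:60] with endStream forced
// to false, and rest carries wd.p[60:] with the original endStream flag
// and done channel.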
// String is for debugging only.
func (wr http2FrameWriteRequest) String() string {
var des string
if s, ok := wr.write.(fmt.Stringer); ok {
des = s.String()
} else {
des = fmt.Sprintf("%T", wr.write)
}
return fmt.Sprintf("[FrameWriteRequest stream=%d, ch=%v, writer=%v]", wr.StreamID(), wr.done != nil, des)
}
// replyToWriter sends err to wr.done and panics if the send must block.
// This does nothing if wr.done is nil.
func (wr *http2FrameWriteRequest) replyToWriter(err error) {
if wr.done == nil {
return
}
select {
case wr.done <- err:
default:
panic(fmt.Sprintf("unbuffered done channel passed in for type %T", wr.write))
}
wr.write = nil // prevent use (assume it's tainted after wr.done send)
}
// writeQueue is used by implementations of WriteScheduler.
type http2writeQueue struct {
s []http2FrameWriteRequest
}
func (q *http2writeQueue) empty() bool { return len(q.s) == 0 }
func (q *http2writeQueue) push(wr http2FrameWriteRequest) {
q.s = append(q.s, wr)
}
func (q *http2writeQueue) shift() http2FrameWriteRequest {
if len(q.s) == 0 {
panic("invalid use of queue")
}
wr := q.s[0]
// TODO: less copy-happy queue.
copy(q.s, q.s[1:])
q.s[len(q.s)-1] = http2FrameWriteRequest{}
q.s = q.s[:len(q.s)-1]
return wr
}
// consume consumes up to n bytes from q.s[0]. If the frame is
// entirely consumed, it is removed from the queue. If the frame
// is partially consumed, the frame is kept with the consumed
// bytes removed. Returns true iff any bytes were consumed.
func (q *http2writeQueue) consume(n int32) (http2FrameWriteRequest, bool) {
if len(q.s) == 0 {
return http2FrameWriteRequest{}, false
}
consumed, rest, numresult := q.s[0].Consume(n)
switch numresult {
case 0:
return http2FrameWriteRequest{}, false
case 1:
q.shift()
case 2:
q.s[0] = rest
}
return consumed, true
}
type http2writeQueuePool []*http2writeQueue
// put inserts an unused writeQueue into the pool.
func (p *http2writeQueuePool) put(q *http2writeQueue) {
for i := range q.s {
q.s[i] = http2FrameWriteRequest{}
}
q.s = q.s[:0]
*p = append(*p, q)
}
// get returns an empty writeQueue.
func (p *http2writeQueuePool) get() *http2writeQueue {
ln := len(*p)
if ln == 0 {
return new(http2writeQueue)
}
x := ln - 1
q := (*p)[x]
(*p)[x] = nil
*p = (*p)[:x]
return q
}
// RFC 7540, Section 5.3.5: the default weight is 16.
const http2priorityDefaultWeight = 15 // 16 = 15 + 1
// PriorityWriteSchedulerConfig configures a priorityWriteScheduler.
type http2PriorityWriteSchedulerConfig struct {
// MaxClosedNodesInTree controls the maximum number of closed streams to
// retain in the priority tree. Setting this to zero saves a small amount
// of memory at the cost of performance.
//
// See RFC 7540, Section 5.3.4:
// "It is possible for a stream to become closed while prioritization
// information ... is in transit. ... This potentially creates suboptimal
// prioritization, since the stream could be given a priority that is
// different from what is intended. To avoid these problems, an endpoint
// SHOULD retain stream prioritization state for a period after streams
// become closed. The longer state is retained, the lower the chance that
// streams are assigned incorrect or default priority values."
MaxClosedNodesInTree int
// MaxIdleNodesInTree controls the maximum number of idle streams to
// retain in the priority tree. Setting this to zero saves a small amount
// of memory at the cost of performance.
//
// See RFC 7540, Section 5.3.4:
// Similarly, streams that are in the "idle" state can be assigned
// priority or become a parent of other streams. This allows for the
// creation of a grouping node in the dependency tree, which enables
// more flexible expressions of priority. Idle streams begin with a
// default priority (Section 5.3.5).
MaxIdleNodesInTree int
// ThrottleOutOfOrderWrites enables write throttling to help ensure that
// data is delivered in priority order. This works around a race where
// stream B depends on stream A and both streams are about to call Write
// to queue DATA frames. If B wins the race, a naive scheduler would eagerly
// write as much data from B as possible, but this is suboptimal because A
// is a higher-priority stream. With throttling enabled, we write a small
// amount of data from B to minimize the amount of bandwidth that B can
// steal from A.
ThrottleOutOfOrderWrites bool
}
// NewPriorityWriteScheduler constructs a WriteScheduler that schedules
// frames by following HTTP/2 priorities as described in RFC 7540 Section 5.3.
// If cfg is nil, default options are used.
func http2NewPriorityWriteScheduler(cfg *http2PriorityWriteSchedulerConfig) http2WriteScheduler {
if cfg == nil {
// For justification of these defaults, see:
// https://docs.google.com/document/d/1oLhNg1skaWD4_DtaoCxdSRN5erEXrH-KnLrMwEpOtFY
cfg = &http2PriorityWriteSchedulerConfig{
MaxClosedNodesInTree: 10,
MaxIdleNodesInTree: 10,
ThrottleOutOfOrderWrites: false,
}
}
ws := &http2priorityWriteScheduler{
nodes: make(map[uint32]*http2priorityNode),
maxClosedNodesInTree: cfg.MaxClosedNodesInTree,
maxIdleNodesInTree: cfg.MaxIdleNodesInTree,
enableWriteThrottle: cfg.ThrottleOutOfOrderWrites,
}
ws.nodes[0] = &ws.root
if cfg.ThrottleOutOfOrderWrites {
ws.writeThrottleLimit = 1024
} else {
ws.writeThrottleLimit = math.MaxInt32
}
return ws
}
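// Illustrative usage sketch (wr stands for some queued http2FrameWriteRequest):
//
//	ws := http2NewPriorityWriteScheduler(nil) // defaults from above
//	ws.OpenStream(1, http2OpenStreamOptions{})
//	ws.Push(wr)
//	if next, ok := ws.Pop(); ok {
//		// write next's frame to the connection
//	}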
type http2priorityNodeState int
const (
http2priorityNodeOpen http2priorityNodeState = iota
http2priorityNodeClosed
http2priorityNodeIdle
)
// priorityNode is a node in an HTTP/2 priority tree.
// Each node is associated with a single stream ID.
// See RFC 7540, Section 5.3.
type http2priorityNode struct {
q http2writeQueue // queue of pending frames to write
id uint32 // id of the stream, or 0 for the root of the tree
weight uint8 // the actual weight is weight+1, so the value is in [1,256]
state http2priorityNodeState // open | closed | idle
bytes int64 // number of bytes written by this node, or 0 if closed
subtreeBytes int64 // sum(node.bytes) of all nodes in this subtree
// These links form the priority tree.
parent *http2priorityNode
kids *http2priorityNode // start of the kids list
prev, next *http2priorityNode // doubly-linked list of siblings
}
func (n *http2priorityNode) setParent(parent *http2priorityNode) {
if n == parent {
panic("setParent to self")
}
if n.parent == parent {
return
}
// Unlink from current parent.
if parent := n.parent; parent != nil {
if n.prev == nil {
parent.kids = n.next
} else {
n.prev.next = n.next
}
if n.next != nil {
n.next.prev = n.prev
}
}
// Link to new parent.
// If parent=nil, remove n from the tree.
// Always insert at the head of parent.kids (this is assumed by walkReadyInOrder).
n.parent = parent
if parent == nil {
n.next = nil
n.prev = nil
} else {
n.next = parent.kids
n.prev = nil
if n.next != nil {
n.next.prev = n
}
parent.kids = n
}
}
func (n *http2priorityNode) addBytes(b int64) {
n.bytes += b
for ; n != nil; n = n.parent {
n.subtreeBytes += b
}
}
// walkReadyInOrder iterates over the tree in priority order, calling f for each node
// with a non-empty write queue. When f returns true, this function returns true and the
// walk halts. tmp is used as scratch space for sorting.
//
// f(n, openParent) takes two arguments: the node to visit, n, and a bool that is true
// if any ancestor p of n is still open (ignoring the root node).
func (n *http2priorityNode) walkReadyInOrder(openParent bool, tmp *[]*http2priorityNode, f func(*http2priorityNode, bool) bool) bool {
if !n.q.empty() && f(n, openParent) {
return true
}
if n.kids == nil {
return false
}
// Don't consider the root "open" when updating openParent since
// we can't send data frames on the root stream (only control frames).
if n.id != 0 {
openParent = openParent || (n.state == http2priorityNodeOpen)
}
// Common case: only one kid or all kids have the same weight.
// Some clients don't use weights; other clients (like web browsers)
// use mostly-linear priority trees.
w := n.kids.weight
needSort := false
for k := n.kids.next; k != nil; k = k.next {
if k.weight != w {
needSort = true
break
}
}
if !needSort {
for k := n.kids; k != nil; k = k.next {
if k.walkReadyInOrder(openParent, tmp, f) {
return true
}
}
return false
}
// Uncommon case: sort the child nodes. We remove the kids from the parent,
// then re-insert after sorting so we can reuse tmp for future sort calls.
*tmp = (*tmp)[:0]
for n.kids != nil {
*tmp = append(*tmp, n.kids)
n.kids.setParent(nil)
}
sort.Sort(http2sortPriorityNodeSiblings(*tmp))
for i := len(*tmp) - 1; i >= 0; i-- {
(*tmp)[i].setParent(n) // setParent inserts at the head of n.kids
}
for k := n.kids; k != nil; k = k.next {
if k.walkReadyInOrder(openParent, tmp, f) {
return true
}
}
return false
}
type http2sortPriorityNodeSiblings []*http2priorityNode
func (z http2sortPriorityNodeSiblings) Len() int { return len(z) }
func (z http2sortPriorityNodeSiblings) Swap(i, k int) { z[i], z[k] = z[k], z[i] }
func (z http2sortPriorityNodeSiblings) Less(i, k int) bool {
// Prefer the subtree that has sent fewer bytes relative to its weight.
// See sections 5.3.2 and 5.3.4.
wi, bi := float64(z[i].weight+1), float64(z[i].subtreeBytes)
wk, bk := float64(z[k].weight+1), float64(z[k].subtreeBytes)
if bi == 0 && bk == 0 {
return wi >= wk
}
if bk == 0 {
return false
}
return bi/bk <= wi/wk
}
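// Worked example (illustrative): with stored weights 7 and 3 (effective
// weights 8 and 4) and subtreeBytes 100 and 40, Less compares
// 100/40 = 2.5 against 8/4 = 2.0. Since 2.5 > 2.0, the first subtree has
// already used more than its proportional share, so it sorts after the
// second.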
type http2priorityWriteScheduler struct {
// root is the root of the priority tree, where root.id = 0.
// The root queues control frames that are not associated with any stream.
root http2priorityNode
// nodes maps stream ids to priority tree nodes.
nodes map[uint32]*http2priorityNode
// maxID is the maximum stream id in nodes.
maxID uint32
// lists of nodes that have been closed or are idle, but are kept in
// the tree for improved prioritization. When the lengths exceed either
// maxClosedNodesInTree or maxIdleNodesInTree, old nodes are discarded.
closedNodes, idleNodes []*http2priorityNode
// From the config.
maxClosedNodesInTree int
maxIdleNodesInTree int
writeThrottleLimit int32
enableWriteThrottle bool
// tmp is scratch space for priorityNode.walkReadyInOrder to reduce allocations.
tmp []*http2priorityNode
// pool of empty queues for reuse.
queuePool http2writeQueuePool
}
func (ws *http2priorityWriteScheduler) OpenStream(streamID uint32, options http2OpenStreamOptions) {
// The stream may currently be idle, but it must not already be open or closed.
if curr := ws.nodes[streamID]; curr != nil {
if curr.state != http2priorityNodeIdle {
panic(fmt.Sprintf("stream %d already opened", streamID))
}
curr.state = http2priorityNodeOpen
return
}
// RFC 7540, Section 5.3.5:
// "All streams are initially assigned a non-exclusive dependency on stream 0x0.
// Pushed streams initially depend on their associated stream. In both cases,
// streams are assigned a default weight of 16."
parent := ws.nodes[options.PusherID]
if parent == nil {
parent = &ws.root
}
n := &http2priorityNode{
q: *ws.queuePool.get(),
id: streamID,
weight: http2priorityDefaultWeight,
state: http2priorityNodeOpen,
}
n.setParent(parent)
ws.nodes[streamID] = n
if streamID > ws.maxID {
ws.maxID = streamID
}
}
func (ws *http2priorityWriteScheduler) CloseStream(streamID uint32) {
if streamID == 0 {
panic("violation of WriteScheduler interface: cannot close stream 0")
}
if ws.nodes[streamID] == nil {
panic(fmt.Sprintf("violation of WriteScheduler interface: unknown stream %d", streamID))
}
if ws.nodes[streamID].state != http2priorityNodeOpen {
panic(fmt.Sprintf("violation of WriteScheduler interface: stream %d already closed", streamID))
}
n := ws.nodes[streamID]
n.state = http2priorityNodeClosed
n.addBytes(-n.bytes)
q := n.q
ws.queuePool.put(&q)
n.q.s = nil
if ws.maxClosedNodesInTree > 0 {
ws.addClosedOrIdleNode(&ws.closedNodes, ws.maxClosedNodesInTree, n)
} else {
ws.removeNode(n)
}
}
func (ws *http2priorityWriteScheduler) AdjustStream(streamID uint32, priority http2PriorityParam) {
if streamID == 0 {
panic("adjustPriority on root")
}
// If streamID does not exist, there are two cases:
// - A closed stream that has been removed (this will have ID <= maxID)
// - An idle stream that is being used for "grouping" (this will have ID > maxID)
n := ws.nodes[streamID]
if n == nil {
if streamID <= ws.maxID || ws.maxIdleNodesInTree == 0 {
return
}
ws.maxID = streamID
n = &http2priorityNode{
q: *ws.queuePool.get(),
id: streamID,
weight: http2priorityDefaultWeight,
state: http2priorityNodeIdle,
}
n.setParent(&ws.root)
ws.nodes[streamID] = n
ws.addClosedOrIdleNode(&ws.idleNodes, ws.maxIdleNodesInTree, n)
}
// Section 5.3.1: A dependency on a stream that is not currently in the tree
// results in that stream being given a default priority (Section 5.3.5).
parent := ws.nodes[priority.StreamDep]
if parent == nil {
n.setParent(&ws.root)
n.weight = http2priorityDefaultWeight
return
}
// Ignore if the client tries to make a node its own parent.
if n == parent {
return
}
// Section 5.3.3:
// "If a stream is made dependent on one of its own dependencies, the
// formerly dependent stream is first moved to be dependent on the
// reprioritized stream's previous parent. The moved dependency retains
// its weight."
//
// That is: if parent depends on n, move parent to depend on n.parent.
for x := parent.parent; x != nil; x = x.parent {
if x == n {
parent.setParent(n.parent)
break
}
}
// Section 5.3.3: The exclusive flag causes the stream to become the sole
// dependency of its parent stream, causing other dependencies to become
// dependent on the exclusive stream.
if priority.Exclusive {
k := parent.kids
for k != nil {
next := k.next
if k != n {
k.setParent(n)
}
k = next
}
}
n.setParent(parent)
n.weight = priority.Weight
}
func (ws *http2priorityWriteScheduler) Push(wr http2FrameWriteRequest) {
var n *http2priorityNode
if wr.isControl() {
n = &ws.root
} else {
id := wr.StreamID()
n = ws.nodes[id]
if n == nil {
// id is an idle or closed stream. wr should not be a HEADERS or
// DATA frame. Any other frame (such as RST_STREAM) is pushed onto
// the root rather than creating a new priorityNode.
if wr.DataSize() > 0 {
panic("add DATA on non-open stream")
}
n = &ws.root
}
}
n.q.push(wr)
}
func (ws *http2priorityWriteScheduler) Pop() (wr http2FrameWriteRequest, ok bool) {
ws.root.walkReadyInOrder(false, &ws.tmp, func(n *http2priorityNode, openParent bool) bool {
limit := int32(math.MaxInt32)
if openParent {
limit = ws.writeThrottleLimit
}
wr, ok = n.q.consume(limit)
if !ok {
return false
}
n.addBytes(int64(wr.DataSize()))
// If B depends on A and B continuously has data available but A
// does not, gradually increase the throttling limit to allow B to
// steal more and more bandwidth from A.
if openParent {
ws.writeThrottleLimit += 1024
if ws.writeThrottleLimit < 0 {
ws.writeThrottleLimit = math.MaxInt32
}
} else if ws.enableWriteThrottle {
ws.writeThrottleLimit = 1024
}
return true
})
return wr, ok
}
func (ws *http2priorityWriteScheduler) addClosedOrIdleNode(list *[]*http2priorityNode, maxSize int, n *http2priorityNode) {
if maxSize == 0 {
return
}
if len(*list) == maxSize {
// Remove the oldest node, then shift left.
ws.removeNode((*list)[0])
x := (*list)[1:]
copy(*list, x)
*list = (*list)[:len(x)]
}
*list = append(*list, n)
}
func (ws *http2priorityWriteScheduler) removeNode(n *http2priorityNode) {
for k := n.kids; k != nil; k = k.next {
k.setParent(n.parent)
}
n.setParent(nil)
delete(ws.nodes, n.id)
}
// NewRandomWriteScheduler constructs a WriteScheduler that ignores HTTP/2
// priorities. Control frames like SETTINGS and PING are written before DATA
// frames, but if no control frames are queued and multiple streams have queued
// HEADERS or DATA frames, Pop selects a ready stream arbitrarily.
func http2NewRandomWriteScheduler() http2WriteScheduler {
return &http2randomWriteScheduler{sq: make(map[uint32]*http2writeQueue)}
}
type http2randomWriteScheduler struct {
// zero queues frames not associated with a specific stream.
zero http2writeQueue
// sq contains the stream-specific queues, keyed by stream ID.
// When a stream is idle, closed, or emptied, it's deleted
// from the map.
sq map[uint32]*http2writeQueue
// pool of empty queues for reuse.
queuePool http2writeQueuePool
}
func (ws *http2randomWriteScheduler) OpenStream(streamID uint32, options http2OpenStreamOptions) {
// no-op: idle streams are not tracked
}
func (ws *http2randomWriteScheduler) CloseStream(streamID uint32) {
q, ok := ws.sq[streamID]
if !ok {
return
}
delete(ws.sq, streamID)
ws.queuePool.put(q)
}
func (ws *http2randomWriteScheduler) AdjustStream(streamID uint32, priority http2PriorityParam) {
// no-op: priorities are ignored
}
func (ws *http2randomWriteScheduler) Push(wr http2FrameWriteRequest) {
if wr.isControl() {
ws.zero.push(wr)
return
}
id := wr.StreamID()
q, ok := ws.sq[id]
if !ok {
q = ws.queuePool.get()
ws.sq[id] = q
}
q.push(wr)
}
func (ws *http2randomWriteScheduler) Pop() (http2FrameWriteRequest, bool) {
// Control and RST_STREAM frames first.
if !ws.zero.empty() {
return ws.zero.shift(), true
}
// Iterate over all non-idle streams until finding one that can be consumed.
for streamID, q := range ws.sq {
if wr, ok := q.consume(math.MaxInt32); ok {
if q.empty() {
delete(ws.sq, streamID)
ws.queuePool.put(q)
}
return wr, true
}
}
return http2FrameWriteRequest{}, false
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !nethttpomithttp2
// +build !nethttpomithttp2
package http
import (
"reflect"
)
func (e http2StreamError) As(target any) bool {
dst := reflect.ValueOf(target).Elem()
dstType := dst.Type()
if dstType.Kind() != reflect.Struct {
return false
}
src := reflect.ValueOf(e)
srcType := src.Type()
numField := srcType.NumField()
if dstType.NumField() != numField {
return false
}
for i := 0; i < numField; i++ {
sf := srcType.Field(i)
df := dstType.Field(i)
if sf.Name != df.Name || !sf.Type.ConvertibleTo(df.Type) {
return false
}
}
for i := 0; i < numField; i++ {
df := dst.Field(i)
df.Set(src.Field(i).Convert(df.Type()))
}
return true
}
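// Illustrative usage sketch: this As method lets errors.As convert the
// bundled error into a structurally identical external type, e.g. the
// golang.org/x/net/http2 StreamError (assuming that import):
//
//	var se http2.StreamError
//	if errors.As(err, &se) {
//		// se.StreamID, se.Code, and se.Cause carry the converted fields
//	}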
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http
import (
"io"
"net/http/httptrace"
"net/http/internal/ascii"
"net/textproto"
"sort"
"strings"
"sync"
"time"
"golang.org/x/net/http/httpguts"
)
// A Header represents the key-value pairs in an HTTP header.
//
// The keys should be in canonical form, as returned by
// CanonicalHeaderKey.
type Header map[string][]string
// Add adds the key, value pair to the header.
// It appends to any existing values associated with key.
// The key is case insensitive; it is canonicalized by
// CanonicalHeaderKey.
func (h Header) Add(key, value string) {
textproto.MIMEHeader(h).Add(key, value)
}
// Set sets the header entries associated with key to the
// single element value. It replaces any existing values
// associated with key. The key is case insensitive; it is
// canonicalized by textproto.CanonicalMIMEHeaderKey.
// To use non-canonical keys, assign to the map directly.
func (h Header) Set(key, value string) {
textproto.MIMEHeader(h).Set(key, value)
}
// Get gets the first value associated with the given key. If
// there are no values associated with the key, Get returns "".
// It is case insensitive; textproto.CanonicalMIMEHeaderKey is
// used to canonicalize the provided key. Get assumes that all
// keys are stored in canonical form. To use non-canonical keys,
// access the map directly.
func (h Header) Get(key string) string {
return textproto.MIMEHeader(h).Get(key)
}
// Values returns all values associated with the given key.
// It is case insensitive; textproto.CanonicalMIMEHeaderKey is
// used to canonicalize the provided key. To use non-canonical
// keys, access the map directly.
// The returned slice is not a copy.
func (h Header) Values(key string) []string {
return textproto.MIMEHeader(h).Values(key)
}
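// Illustrative usage sketch:
//
//	h := Header{}
//	h.Add("Accept-Encoding", "gzip")
//	h.Add("Accept-Encoding", "br")
//	_ = h.Get("accept-encoding")    // "gzip": first value, key canonicalized
//	_ = h.Values("Accept-Encoding") // []string{"gzip", "br"}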
// get is like Get, but key must already be in CanonicalHeaderKey form.
func (h Header) get(key string) string {
if v := h[key]; len(v) > 0 {
return v[0]
}
return ""
}
// has reports whether h has the provided key defined, even if it's
// set to 0-length slice.
func (h Header) has(key string) bool {
_, ok := h[key]
return ok
}
// Del deletes the values associated with key.
// The key is case insensitive; it is canonicalized by
// CanonicalHeaderKey.
func (h Header) Del(key string) {
textproto.MIMEHeader(h).Del(key)
}
// Write writes a header in wire format.
func (h Header) Write(w io.Writer) error {
return h.write(w, nil)
}
func (h Header) write(w io.Writer, trace *httptrace.ClientTrace) error {
return h.writeSubset(w, nil, trace)
}
// Clone returns a copy of h or nil if h is nil.
func (h Header) Clone() Header {
if h == nil {
return nil
}
// Find total number of values.
nv := 0
for _, vv := range h {
nv += len(vv)
}
sv := make([]string, nv) // shared backing array for headers' values
h2 := make(Header, len(h))
for k, vv := range h {
if vv == nil {
// Preserve nil values. ReverseProxy distinguishes
// between nil and zero-length header values.
h2[k] = nil
continue
}
n := copy(sv, vv)
h2[k] = sv[:n:n]
sv = sv[n:]
}
return h2
}
var timeFormats = []string{
TimeFormat,
time.RFC850,
time.ANSIC,
}
// ParseTime parses a time header (such as the Date: header),
// trying each of the three formats allowed by HTTP/1.1:
// TimeFormat, time.RFC850, and time.ANSIC.
func ParseTime(text string) (t time.Time, err error) {
for _, layout := range timeFormats {
t, err = time.Parse(layout, text)
if err == nil {
return
}
}
return
}
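// Illustrative example: the layouts are tried in order, so both of these
// parse to the same instant:
//
//	t1, _ := ParseTime("Mon, 02 Jan 2006 15:04:05 GMT")  // TimeFormat
//	t2, _ := ParseTime("Monday, 02-Jan-06 15:04:05 GMT") // time.RFC850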
var headerNewlineToSpace = strings.NewReplacer("\n", " ", "\r", " ")
// stringWriter implements WriteString on a Writer.
type stringWriter struct {
w io.Writer
}
func (w stringWriter) WriteString(s string) (n int, err error) {
return w.w.Write([]byte(s))
}
type keyValues struct {
key string
values []string
}
// A headerSorter implements sort.Interface by sorting a []keyValues
// by key. It's used as a pointer, so it can fit in a sort.Interface
// interface value without allocation.
type headerSorter struct {
kvs []keyValues
}
func (s *headerSorter) Len() int { return len(s.kvs) }
func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] }
func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].key < s.kvs[j].key }
var headerSorterPool = sync.Pool{
New: func() any { return new(headerSorter) },
}
// sortedKeyValues returns h's keys sorted in the returned kvs
// slice. The headerSorter used to sort is also returned, for possible
// return to headerSorterPool.
func (h Header) sortedKeyValues(exclude map[string]bool) (kvs []keyValues, hs *headerSorter) {
hs = headerSorterPool.Get().(*headerSorter)
if cap(hs.kvs) < len(h) {
hs.kvs = make([]keyValues, 0, len(h))
}
kvs = hs.kvs[:0]
for k, vv := range h {
if !exclude[k] {
kvs = append(kvs, keyValues{k, vv})
}
}
hs.kvs = kvs
sort.Sort(hs)
return kvs, hs
}
// WriteSubset writes a header in wire format.
// If exclude is not nil, keys where exclude[key] == true are not written.
// Keys are not canonicalized before checking the exclude map.
func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error {
return h.writeSubset(w, exclude, nil)
}
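// Illustrative example: given
// h = Header{"Content-Type": {"text/html"}, "Connection": {"close"}},
// h.WriteSubset(w, map[string]bool{"Connection": true}) writes only
//
//	Content-Type: text/html\r\n
//
// to w; note that the exclude key must already be in canonical form.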
func (h Header) writeSubset(w io.Writer, exclude map[string]bool, trace *httptrace.ClientTrace) error {
ws, ok := w.(io.StringWriter)
if !ok {
ws = stringWriter{w}
}
kvs, sorter := h.sortedKeyValues(exclude)
var formattedVals []string
for _, kv := range kvs {
if !httpguts.ValidHeaderFieldName(kv.key) {
// This could be an error. In the common case of
// writing response headers, however, we have no good
// way to provide the error back to the server
// handler, so just drop invalid headers instead.
continue
}
for _, v := range kv.values {
v = headerNewlineToSpace.Replace(v)
v = textproto.TrimString(v)
for _, s := range []string{kv.key, ": ", v, "\r\n"} {
if _, err := ws.WriteString(s); err != nil {
headerSorterPool.Put(sorter)
return err
}
}
if trace != nil && trace.WroteHeaderField != nil {
formattedVals = append(formattedVals, v)
}
}
if trace != nil && trace.WroteHeaderField != nil {
trace.WroteHeaderField(kv.key, formattedVals)
formattedVals = nil
}
}
headerSorterPool.Put(sorter)
return nil
}
// CanonicalHeaderKey returns the canonical format of the
// header key s. The canonicalization converts the first
// letter and any letter following a hyphen to upper case;
// the rest are converted to lowercase. For example, the
// canonical key for "accept-encoding" is "Accept-Encoding".
// If s contains a space or invalid header field bytes, it is
// returned without modifications.
func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) }
// hasToken reports whether token appears within v, ASCII
// case-insensitive, with space or comma boundaries.
// token must be all lowercase.
// v may contain mixed case.
func hasToken(v, token string) bool {
if len(token) > len(v) || token == "" {
return false
}
if v == token {
return true
}
for sp := 0; sp <= len(v)-len(token); sp++ {
// Check that first character is good.
// The token is ASCII, so checking only a single byte
// is sufficient. We skip this potential starting
// position if both the first byte and its potential
// ASCII uppercase equivalent (b|0x20) don't match.
// False positives ('^' => '~') are caught by EqualFold.
if b := v[sp]; b != token[0] && b|0x20 != token[0] {
continue
}
// Check that start pos is on a valid token boundary.
if sp > 0 && !isTokenBoundary(v[sp-1]) {
continue
}
// Check that end pos is on a valid token boundary.
if endPos := sp + len(token); endPos != len(v) && !isTokenBoundary(v[endPos]) {
continue
}
if ascii.EqualFold(v[sp:sp+len(token)], token) {
return true
}
}
return false
}
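// Illustrative examples:
//
//	hasToken("close", "close")             // true: exact match
//	hasToken("Keep-Alive, Close", "close") // true: ASCII case-insensitive
//	hasToken("closed", "close")            // false: no token boundary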
func isTokenBoundary(b byte) bool {
return b == ' ' || b == ',' || b == '\t'
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:generate bundle -o=h2_bundle.go -prefix=http2 -tags=!nethttpomithttp2 golang.org/x/net/http2
package http
import (
"io"
"strconv"
"strings"
"time"
"unicode/utf8"
"golang.org/x/net/http/httpguts"
)
// incomparable is a zero-width, non-comparable type. Adding it to a struct
// makes that struct also non-comparable, and generally doesn't add
// any size (as long as it's first).
type incomparable [0]func()
// maxInt64 is the effective "infinite" value for the Server and
// Transport's byte-limiting readers.
const maxInt64 = 1<<63 - 1
// aLongTimeAgo is a non-zero time, far in the past, used for
// immediate cancellation of network operations.
var aLongTimeAgo = time.Unix(1, 0)
// omitBundledHTTP2 is set by omithttp2.go when the nethttpomithttp2
// build tag is set. That means h2_bundle.go isn't compiled in and we
// shouldn't try to use it.
var omitBundledHTTP2 bool
// TODO(bradfitz): move common stuff here. The other files have accumulated
// generic http stuff in random places.
// contextKey is a value for use with context.WithValue. It's used as
// a pointer so it fits in an interface{} without allocation.
type contextKey struct {
name string
}
func (k *contextKey) String() string { return "net/http context value " + k.name }
// Given a string of the form "host", "host:port", or "[ipv6::address]:port",
// return true if the string includes a port.
func hasPort(s string) bool { return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") }
// removeEmptyPort strips the empty port in ":port" to ""
// as mandated by RFC 3986 Section 6.2.3.
func removeEmptyPort(host string) string {
if hasPort(host) {
return strings.TrimSuffix(host, ":")
}
return host
}
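// Illustrative examples:
//
//	hasPort("example.com")          // false
//	hasPort("example.com:80")       // true
//	hasPort("[::1]:8080")           // true: colon after the bracket
//	removeEmptyPort("example.com:") // "example.com"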
func isNotToken(r rune) bool {
return !httpguts.IsTokenRune(r)
}
// stringContainsCTLByte reports whether s contains any ASCII control character.
func stringContainsCTLByte(s string) bool {
for i := 0; i < len(s); i++ {
b := s[i]
if b < ' ' || b == 0x7f {
return true
}
}
return false
}
func hexEscapeNonASCII(s string) string {
newLen := 0
for i := 0; i < len(s); i++ {
if s[i] >= utf8.RuneSelf {
newLen += 3
} else {
newLen++
}
}
if newLen == len(s) {
return s
}
b := make([]byte, 0, newLen)
var pos int
for i := 0; i < len(s); i++ {
if s[i] >= utf8.RuneSelf {
if pos < i {
b = append(b, s[pos:i]...)
}
b = append(b, '%')
b = strconv.AppendInt(b, int64(s[i]), 16)
pos = i + 1
}
}
if pos < len(s) {
b = append(b, s[pos:]...)
}
return string(b)
}
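// Illustrative example: hexEscapeNonASCII("café") returns "caf%c3%a9",
// since 'é' is the UTF-8 byte pair 0xC3 0xA9 and every byte >= 0x80 is
// percent-escaped in lowercase hex; pure-ASCII strings are returned as-is.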
// NoBody is an io.ReadCloser with no bytes. Read always returns EOF
// and Close always returns nil. It can be used in an outgoing client
// request to explicitly signal that a request has zero bytes.
// An alternative, however, is to simply set Request.Body to nil.
var NoBody = noBody{}
type noBody struct{}
func (noBody) Read([]byte) (int, error) { return 0, io.EOF }
func (noBody) Close() error { return nil }
func (noBody) WriteTo(io.Writer) (int64, error) { return 0, nil }
var (
// verify that an io.Copy from NoBody won't require a buffer:
_ io.WriterTo = NoBody
_ io.ReadCloser = NoBody
)
// PushOptions describes options for Pusher.Push.
type PushOptions struct {
// Method specifies the HTTP method for the promised request.
// If set, it must be "GET" or "HEAD". Empty means "GET".
Method string
// Header specifies additional promised request headers. This cannot
// include HTTP/2 pseudo header fields like ":path" and ":scheme",
// which will be added automatically.
Header Header
}
// Pusher is the interface implemented by ResponseWriters that support
// HTTP/2 server push. For more background, see
// https://tools.ietf.org/html/rfc7540#section-8.2.
type Pusher interface {
// Push initiates an HTTP/2 server push. This constructs a synthetic
// request using the given target and options, serializes that request
// into a PUSH_PROMISE frame, then dispatches that request using the
// server's request handler. If opts is nil, default options are used.
//
// The target must either be an absolute path (like "/path") or an absolute
// URL that contains a valid host and the same scheme as the parent request.
// If the target is a path, it will inherit the scheme and host of the
// parent request.
//
// The HTTP/2 spec disallows recursive pushes and cross-authority pushes.
// Push may or may not detect these invalid pushes; however, invalid
// pushes will be detected and canceled by conforming clients.
//
// Handlers that wish to push URL X should call Push before sending any
// data that may trigger a request for URL X. This avoids a race where the
// client issues requests for X before receiving the PUSH_PROMISE for X.
//
// Push will run in a separate goroutine making the order of arrival
// non-deterministic. Any required synchronization needs to be implemented
// by the caller.
//
// Push returns ErrNotSupported if the client has disabled push or if push
// is not supported on the underlying connection.
Push(target string, opts *PushOptions) error
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package httptest provides utilities for HTTP testing.
package httptest
import (
"bufio"
"bytes"
"crypto/tls"
"io"
"net/http"
"strings"
)
// NewRequest returns a new incoming server Request, suitable
// for passing to an http.Handler for testing.
//
// The target is the RFC 7230 "request-target": it may be either a
// path or an absolute URL. If target is an absolute URL, the host name
// from the URL is used. Otherwise, "example.com" is used.
//
// The TLS field is set to a non-nil dummy value if target has scheme
// "https".
//
// The Request.Proto is always HTTP/1.1.
//
// An empty method means "GET".
//
// The provided body may be nil. If the body is of type *bytes.Reader,
// *strings.Reader, or *bytes.Buffer, the Request.ContentLength is
// set.
//
// NewRequest panics on error for ease of use in testing, where a
// panic is acceptable.
//
// To generate a client HTTP request instead of a server request, see
// the NewRequest function in the net/http package.
func NewRequest(method, target string, body io.Reader) *http.Request {
if method == "" {
method = "GET"
}
req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(method + " " + target + " HTTP/1.0\r\n\r\n")))
if err != nil {
panic("invalid NewRequest arguments; " + err.Error())
}
// HTTP/1.0 was used above to avoid needing a Host field. Change it to 1.1 here.
req.Proto = "HTTP/1.1"
req.ProtoMinor = 1
req.Close = false
if body != nil {
switch v := body.(type) {
case *bytes.Buffer:
req.ContentLength = int64(v.Len())
case *bytes.Reader:
req.ContentLength = int64(v.Len())
case *strings.Reader:
req.ContentLength = int64(v.Len())
default:
req.ContentLength = -1
}
if rc, ok := body.(io.ReadCloser); ok {
req.Body = rc
} else {
req.Body = io.NopCloser(body)
}
}
// 192.0.2.0/24 is "TEST-NET" in RFC 5737 for use solely in
// documentation and example source code and should not be
// used publicly.
req.RemoteAddr = "192.0.2.1:1234"
if req.Host == "" {
req.Host = "example.com"
}
if strings.HasPrefix(target, "https://") {
req.TLS = &tls.ConnectionState{
Version: tls.VersionTLS12,
HandshakeComplete: true,
ServerName: req.Host,
}
}
return req
}
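// A minimal sketch (editor's addition) of the intended use: drive an
// http.Handler directly in a test. myHandler is an assumed handler under
// test, not part of this package.
//
//	req := httptest.NewRequest("GET", "http://example.com/foo", nil)
//	w := httptest.NewRecorder()
//	myHandler(w, req)
//	resp := w.Result()
//	body, _ := io.ReadAll(resp.Body)
//	fmt.Println(resp.StatusCode, resp.Header.Get("Content-Type"), string(body))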
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package httptest
import (
"bytes"
"fmt"
"io"
"net/http"
"net/textproto"
"strconv"
"strings"
"golang.org/x/net/http/httpguts"
)
// ResponseRecorder is an implementation of http.ResponseWriter that
// records its mutations for later inspection in tests.
type ResponseRecorder struct {
// Code is the HTTP response code set by WriteHeader.
//
// Note that if a Handler never calls WriteHeader or Write,
// this might end up being 0, rather than the implicit
// http.StatusOK. To get the implicit value, use the Result
// method.
Code int
// HeaderMap contains the headers explicitly set by the Handler.
// It is an internal detail.
//
// Deprecated: HeaderMap exists for historical compatibility
// and should not be used. To access the headers returned by a handler,
// use the Response.Header map as returned by the Result method.
HeaderMap http.Header
// Body is the buffer to which the Handler's Write calls are sent.
// If nil, the Writes are silently discarded.
Body *bytes.Buffer
// Flushed is whether the Handler called Flush.
Flushed bool
result *http.Response // cache of Result's return value
snapHeader http.Header // snapshot of HeaderMap at first Write
wroteHeader bool
}
// NewRecorder returns an initialized ResponseRecorder.
func NewRecorder() *ResponseRecorder {
return &ResponseRecorder{
HeaderMap: make(http.Header),
Body: new(bytes.Buffer),
Code: 200,
}
}
// DefaultRemoteAddr is the default remote address to return in RemoteAddr if
// an explicit remote address isn't set on the ResponseRecorder.
const DefaultRemoteAddr = "1.2.3.4"
// Header implements http.ResponseWriter. It returns the response
// headers to mutate within a handler. To test the headers that were
// written after a handler completes, use the Result method and see
// the returned Response value's Header.
func (rw *ResponseRecorder) Header() http.Header {
m := rw.HeaderMap
if m == nil {
m = make(http.Header)
rw.HeaderMap = m
}
return m
}
// writeHeader writes a header if it was not written yet and
// detects Content-Type if needed.
//
// bytes or str are the beginning of the response body.
// We pass both to avoid unnecessarily generating garbage
// in rw.WriteString, which exists for performance reasons.
// Non-nil bytes win.
func (rw *ResponseRecorder) writeHeader(b []byte, str string) {
if rw.wroteHeader {
return
}
if len(str) > 512 {
str = str[:512]
}
m := rw.Header()
_, hasType := m["Content-Type"]
hasTE := m.Get("Transfer-Encoding") != ""
if !hasType && !hasTE {
if b == nil {
b = []byte(str)
}
m.Set("Content-Type", http.DetectContentType(b))
}
rw.WriteHeader(200)
}
// Write implements http.ResponseWriter. The data in buf is written to
// rw.Body, if not nil.
func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
rw.writeHeader(buf, "")
if rw.Body != nil {
rw.Body.Write(buf)
}
return len(buf), nil
}
// WriteString implements io.StringWriter. The data in str is written
// to rw.Body, if not nil.
func (rw *ResponseRecorder) WriteString(str string) (int, error) {
rw.writeHeader(nil, str)
if rw.Body != nil {
rw.Body.WriteString(str)
}
return len(str), nil
}
func checkWriteHeaderCode(code int) {
// Issue 22880: require valid WriteHeader status codes.
// For now we only enforce that it's three digits.
// In the future we might block things over 599 (600 and above aren't defined
// at https://httpwg.org/specs/rfc7231.html#status.codes)
// and we might block under 200 (once we have more mature 1xx support).
// But for now any three digits.
//
// We used to send "HTTP/1.1 000 0" on the wire in responses but there's
// no equivalent bogus thing we can realistically send in HTTP/2,
// so we'll consistently panic instead and help people find their bugs
// early. (We can't return an error from WriteHeader even if we wanted to.)
if code < 100 || code > 999 {
panic(fmt.Sprintf("invalid WriteHeader code %v", code))
}
}
// WriteHeader implements http.ResponseWriter.
func (rw *ResponseRecorder) WriteHeader(code int) {
if rw.wroteHeader {
return
}
checkWriteHeaderCode(code)
rw.Code = code
rw.wroteHeader = true
if rw.HeaderMap == nil {
rw.HeaderMap = make(http.Header)
}
rw.snapHeader = rw.HeaderMap.Clone()
}
// Flush implements http.Flusher. To test whether Flush was
// called, see rw.Flushed.
func (rw *ResponseRecorder) Flush() {
if !rw.wroteHeader {
rw.WriteHeader(200)
}
rw.Flushed = true
}
// Result returns the response generated by the handler.
//
// The returned Response will have at least its StatusCode,
// Header, Body, and optionally Trailer populated.
// More fields may be populated in the future, so callers should
// not DeepEqual the result in tests.
//
// The Response.Header is a snapshot of the headers at the time of the
// first write call, or at the time of this call, if the handler never
// did a write.
//
// The Response.Body is guaranteed to be non-nil, and calls to Body.Read are
// guaranteed not to return any error other than io.EOF.
//
// Result must only be called after the handler has finished running.
func (rw *ResponseRecorder) Result() *http.Response {
if rw.result != nil {
return rw.result
}
if rw.snapHeader == nil {
rw.snapHeader = rw.HeaderMap.Clone()
}
res := &http.Response{
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
StatusCode: rw.Code,
Header: rw.snapHeader,
}
rw.result = res
if res.StatusCode == 0 {
res.StatusCode = 200
}
res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode))
if rw.Body != nil {
res.Body = io.NopCloser(bytes.NewReader(rw.Body.Bytes()))
} else {
res.Body = http.NoBody
}
res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))
if trailers, ok := rw.snapHeader["Trailer"]; ok {
res.Trailer = make(http.Header, len(trailers))
for _, k := range trailers {
for _, k := range strings.Split(k, ",") {
k = http.CanonicalHeaderKey(textproto.TrimString(k))
if !httpguts.ValidTrailerHeader(k) {
// Ignore since forbidden by RFC 7230, section 4.1.2.
continue
}
vv, ok := rw.HeaderMap[k]
if !ok {
continue
}
vv2 := make([]string, len(vv))
copy(vv2, vv)
res.Trailer[k] = vv2
}
}
}
for k, vv := range rw.HeaderMap {
if !strings.HasPrefix(k, http.TrailerPrefix) {
continue
}
if res.Trailer == nil {
res.Trailer = make(http.Header)
}
for _, v := range vv {
res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v)
}
}
return res
}
// parseContentLength trims whitespace from s and returns -1 if no value
// is set, or the value if it's >= 0.
//
// This is a modified version of the same function found in net/http/transfer.go.
// This one just ignores an invalid header.
func parseContentLength(cl string) int64 {
cl = textproto.TrimString(cl)
if cl == "" {
return -1
}
n, err := strconv.ParseUint(cl, 10, 63)
if err != nil {
return -1
}
return int64(n)
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Implementation of Server
package httptest
import (
"crypto/tls"
"crypto/x509"
"flag"
"fmt"
"log"
"net"
"net/http"
"net/http/internal/testcert"
"os"
"strings"
"sync"
"time"
)
// A Server is an HTTP server listening on a system-chosen port on the
// local loopback interface, for use in end-to-end HTTP tests.
type Server struct {
URL string // base URL of form http://ipaddr:port with no trailing slash
Listener net.Listener
// EnableHTTP2 controls whether HTTP/2 is enabled
// on the server. It must be set between calling
// NewUnstartedServer and calling Server.StartTLS.
EnableHTTP2 bool
// TLS is the optional TLS configuration, populated with a new config
// after TLS is started. If set on an unstarted server before StartTLS
// is called, existing fields are copied into the new config.
TLS *tls.Config
// Config may be changed after calling NewUnstartedServer and
// before Start or StartTLS.
Config *http.Server
// certificate is a parsed version of the TLS config certificate, if present.
certificate *x509.Certificate
// wg counts the number of outstanding HTTP requests on this server.
// Close blocks until all requests are finished.
wg sync.WaitGroup
mu sync.Mutex // guards closed and conns
closed bool
conns map[net.Conn]http.ConnState // except terminal states
// client is configured for use with the server.
// Its transport is automatically closed when Close is called.
client *http.Client
}
func newLocalListener() net.Listener {
if serveFlag != "" {
l, err := net.Listen("tcp", serveFlag)
if err != nil {
panic(fmt.Sprintf("httptest: failed to listen on %v: %v", serveFlag, err))
}
return l
}
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
}
}
return l
}
// When debugging a particular http server-based test,
// this flag lets you run
//
// go test -run=BrokenTest -httptest.serve=127.0.0.1:8000
//
// to start the broken server so you can interact with it manually.
// We only register this flag if it looks like the caller knows about it
// and is trying to use it, since we don't want to pollute the flag set;
// this isn't really part of our API. Don't depend on this.
var serveFlag string
func init() {
if strSliceContainsPrefix(os.Args, "-httptest.serve=") || strSliceContainsPrefix(os.Args, "--httptest.serve=") {
flag.StringVar(&serveFlag, "httptest.serve", "", "if non-empty, httptest.NewServer serves on this address and blocks.")
}
}
func strSliceContainsPrefix(v []string, pre string) bool {
for _, s := range v {
if strings.HasPrefix(s, pre) {
return true
}
}
return false
}
// NewServer starts and returns a new Server.
// The caller should call Close when finished, to shut it down.
func NewServer(handler http.Handler) *Server {
ts := NewUnstartedServer(handler)
ts.Start()
return ts
}
// NewUnstartedServer returns a new Server but doesn't start it.
//
// After changing its configuration, the caller should call Start or
// StartTLS.
//
// The caller should call Close when finished, to shut it down.
func NewUnstartedServer(handler http.Handler) *Server {
return &Server{
Listener: newLocalListener(),
Config: &http.Server{Handler: handler},
}
}
// Start starts a server from NewUnstartedServer.
func (s *Server) Start() {
if s.URL != "" {
panic("Server already started")
}
if s.client == nil {
s.client = &http.Client{Transport: &http.Transport{}}
}
s.URL = "http://" + s.Listener.Addr().String()
s.wrap()
s.goServe()
if serveFlag != "" {
fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL)
select {}
}
}
// StartTLS starts TLS on a server from NewUnstartedServer.
func (s *Server) StartTLS() {
if s.URL != "" {
panic("Server already started")
}
if s.client == nil {
s.client = &http.Client{Transport: &http.Transport{}}
}
cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
if err != nil {
panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
}
existingConfig := s.TLS
if existingConfig != nil {
s.TLS = existingConfig.Clone()
} else {
s.TLS = new(tls.Config)
}
if s.TLS.NextProtos == nil {
nextProtos := []string{"http/1.1"}
if s.EnableHTTP2 {
nextProtos = []string{"h2"}
}
s.TLS.NextProtos = nextProtos
}
if len(s.TLS.Certificates) == 0 {
s.TLS.Certificates = []tls.Certificate{cert}
}
s.certificate, err = x509.ParseCertificate(s.TLS.Certificates[0].Certificate[0])
if err != nil {
panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
}
certpool := x509.NewCertPool()
certpool.AddCert(s.certificate)
s.client.Transport = &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: certpool,
},
ForceAttemptHTTP2: s.EnableHTTP2,
}
s.Listener = tls.NewListener(s.Listener, s.TLS)
s.URL = "https://" + s.Listener.Addr().String()
s.wrap()
s.goServe()
}
// NewTLSServer starts and returns a new Server using TLS.
// The caller should call Close when finished, to shut it down.
func NewTLSServer(handler http.Handler) *Server {
ts := NewUnstartedServer(handler)
ts.StartTLS()
return ts
}
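// A minimal end-to-end sketch (editor's addition): spin up a server, hit it
// with the preconfigured client, and shut it down. The handler body is
// hypothetical.
//
//	ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
//		fmt.Fprintln(w, "hello, TLS client")
//	}))
//	defer ts.Close()
//
//	// ts.Client() already trusts the server's test certificate.
//	res, err := ts.Client().Get(ts.URL)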
type closeIdleTransport interface {
CloseIdleConnections()
}
// Close shuts down the server and blocks until all outstanding
// requests on this server have completed.
func (s *Server) Close() {
s.mu.Lock()
if !s.closed {
s.closed = true
s.Listener.Close()
s.Config.SetKeepAlivesEnabled(false)
for c, st := range s.conns {
// Force-close any idle connections (those between
// requests) and new connections (those which connected
// but never sent a request). StateNew connections are
// super rare and have only been seen (in
// previously-flaky tests) in the case of
// socket-late-binding races from the http Client
// dialing this server and then getting an idle
// connection before the dial completed. There is thus
// a connected connection in StateNew with no
// associated Request. We only close StateIdle and
// StateNew because they're not doing anything. It's
// possible StateNew is about to do something in a few
// milliseconds, but a previous CL to check again in a
// few milliseconds wasn't liked (early versions of
// https://golang.org/cl/15151) so now we just
// forcefully close StateNew. The docs for Server.Close say
// we wait for "outstanding requests", so we don't close things
// in StateActive.
if st == http.StateIdle || st == http.StateNew {
s.closeConn(c)
}
}
// If this server doesn't shut down in 5 seconds, tell the user why.
t := time.AfterFunc(5*time.Second, s.logCloseHangDebugInfo)
defer t.Stop()
}
s.mu.Unlock()
// Not part of httptest.Server's correctness, but assume most
// users of httptest.Server will be using the standard
// transport, so help them out and close any idle connections for them.
if t, ok := http.DefaultTransport.(closeIdleTransport); ok {
t.CloseIdleConnections()
}
// Also close the client idle connections.
if s.client != nil {
if t, ok := s.client.Transport.(closeIdleTransport); ok {
t.CloseIdleConnections()
}
}
s.wg.Wait()
}
func (s *Server) logCloseHangDebugInfo() {
s.mu.Lock()
defer s.mu.Unlock()
var buf strings.Builder
buf.WriteString("httptest.Server blocked in Close after 5 seconds, waiting for connections:\n")
for c, st := range s.conns {
fmt.Fprintf(&buf, " %T %p %v in state %v\n", c, c, c.RemoteAddr(), st)
}
log.Print(buf.String())
}
// CloseClientConnections closes any open HTTP connections to the test Server.
func (s *Server) CloseClientConnections() {
s.mu.Lock()
nconn := len(s.conns)
ch := make(chan struct{}, nconn)
for c := range s.conns {
go s.closeConnChan(c, ch)
}
s.mu.Unlock()
// Wait for outstanding closes to finish.
//
// Out of paranoia for making a late change in Go 1.6, we
// bound how long this can wait, since golang.org/issue/14291
// isn't fully understood yet. At least this should only be used
// in tests.
timer := time.NewTimer(5 * time.Second)
defer timer.Stop()
for i := 0; i < nconn; i++ {
select {
case <-ch:
case <-timer.C:
// Too slow. Give up.
return
}
}
}
// Certificate returns the certificate used by the server, or nil if
// the server doesn't use TLS.
func (s *Server) Certificate() *x509.Certificate {
return s.certificate
}
// Client returns an HTTP client configured for making requests to the server.
// It is configured to trust the server's TLS test certificate and will
// close its idle connections on Server.Close.
func (s *Server) Client() *http.Client {
return s.client
}
func (s *Server) goServe() {
s.wg.Add(1)
go func() {
defer s.wg.Done()
s.Config.Serve(s.Listener)
}()
}
// wrap installs the connection state-tracking hook to know which
// connections are idle.
func (s *Server) wrap() {
oldHook := s.Config.ConnState
s.Config.ConnState = func(c net.Conn, cs http.ConnState) {
s.mu.Lock()
defer s.mu.Unlock()
switch cs {
case http.StateNew:
if _, exists := s.conns[c]; exists {
panic("invalid state transition")
}
if s.conns == nil {
s.conns = make(map[net.Conn]http.ConnState)
}
// Add c to the set of tracked conns and add one to the
// waitgroup.
s.wg.Add(1)
s.conns[c] = cs
if s.closed {
// Probably just a socket-late-binding dial from
// the default transport that lost the race (and
// thus this connection is now idle and will
// never be used).
s.closeConn(c)
}
case http.StateActive:
if oldState, ok := s.conns[c]; ok {
if oldState != http.StateNew && oldState != http.StateIdle {
panic("invalid state transition")
}
s.conns[c] = cs
}
case http.StateIdle:
if oldState, ok := s.conns[c]; ok {
if oldState != http.StateActive {
panic("invalid state transition")
}
s.conns[c] = cs
}
if s.closed {
s.closeConn(c)
}
case http.StateHijacked, http.StateClosed:
// Remove c from the set of tracked conns and release its slot in the
// waitgroup, unless it was previously removed.
if _, ok := s.conns[c]; ok {
delete(s.conns, c)
// Keep Close from returning until the user's ConnState hook
// (if any) finishes.
defer s.wg.Done()
}
}
if oldHook != nil {
oldHook(c, cs)
}
}
}
// closeConn closes c.
// s.mu must be held.
func (s *Server) closeConn(c net.Conn) { s.closeConnChan(c, nil) }
// closeConnChan is like closeConn, but takes an optional channel to receive a value
// when the goroutine closing c is done.
func (s *Server) closeConnChan(c net.Conn, done chan<- struct{}) {
c.Close()
if done != nil {
done <- struct{}{}
}
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package httptrace provides mechanisms to trace the events within
// HTTP client requests.
package httptrace
import (
"context"
"crypto/tls"
"internal/nettrace"
"net"
"net/textproto"
"reflect"
"time"
)
// unique type to prevent assignment.
type clientEventContextKey struct{}
// ContextClientTrace returns the ClientTrace associated with the
// provided context. If none, it returns nil.
func ContextClientTrace(ctx context.Context) *ClientTrace {
trace, _ := ctx.Value(clientEventContextKey{}).(*ClientTrace)
return trace
}
// WithClientTrace returns a new context based on the provided parent
// ctx. HTTP client requests made with the returned context will use
// the provided trace hooks, in addition to any previous hooks
// registered with ctx. Any hooks defined in the provided trace will
// be called first.
func WithClientTrace(ctx context.Context, trace *ClientTrace) context.Context {
if trace == nil {
panic("nil trace")
}
old := ContextClientTrace(ctx)
trace.compose(old)
ctx = context.WithValue(ctx, clientEventContextKey{}, trace)
if trace.hasNetHooks() {
nt := &nettrace.Trace{
ConnectStart: trace.ConnectStart,
ConnectDone: trace.ConnectDone,
}
if trace.DNSStart != nil {
nt.DNSStart = func(name string) {
trace.DNSStart(DNSStartInfo{Host: name})
}
}
if trace.DNSDone != nil {
nt.DNSDone = func(netIPs []any, coalesced bool, err error) {
addrs := make([]net.IPAddr, len(netIPs))
for i, ip := range netIPs {
addrs[i] = ip.(net.IPAddr)
}
trace.DNSDone(DNSDoneInfo{
Addrs: addrs,
Coalesced: coalesced,
Err: err,
})
}
}
ctx = context.WithValue(ctx, nettrace.TraceKey{}, nt)
}
return ctx
}
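// A minimal sketch (editor's addition) of wiring a trace into a request;
// the target URL is hypothetical.
//
//	trace := &httptrace.ClientTrace{
//		DNSDone: func(info httptrace.DNSDoneInfo) {
//			log.Printf("resolved: %v", info.Addrs)
//		},
//		GotConn: func(info httptrace.GotConnInfo) {
//			log.Printf("reused connection: %v", info.Reused)
//		},
//	}
//	req, _ := http.NewRequest("GET", "https://example.com", nil)
//	req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
//	resp, err := http.DefaultClient.Do(req)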
// ClientTrace is a set of hooks to run at various stages of an outgoing
// HTTP request. Any particular hook may be nil. Functions may be
// called concurrently from different goroutines and some may be called
// after the request has completed or failed.
//
// ClientTrace currently traces a single HTTP request & response
// during a single round trip and has no hooks that span a series
// of redirected requests.
//
// See https://blog.golang.org/http-tracing for more.
type ClientTrace struct {
// GetConn is called before a connection is created or
// retrieved from an idle pool. The hostPort is the
// "host:port" of the target or proxy. GetConn is called even
// if there's already an idle cached connection available.
GetConn func(hostPort string)
// GotConn is called after a successful connection is
// obtained. There is no hook for failure to obtain a
// connection; instead, use the error from
// Transport.RoundTrip.
GotConn func(GotConnInfo)
// PutIdleConn is called when the connection is returned to
// the idle pool. If err is nil, the connection was
// successfully returned to the idle pool. If err is non-nil,
// it describes why not. PutIdleConn is not called if
// connection reuse is disabled via Transport.DisableKeepAlives.
// PutIdleConn is called before the caller's Response.Body.Close
// call returns.
// For HTTP/2, this hook is not currently used.
PutIdleConn func(err error)
// GotFirstResponseByte is called when the first byte of the response
// headers is available.
GotFirstResponseByte func()
// Got100Continue is called if the server replies with a "100
// Continue" response.
Got100Continue func()
// Got1xxResponse is called for each 1xx informational response header
// returned before the final non-1xx response. Got1xxResponse is called
// for "100 Continue" responses, even if Got100Continue is also defined.
// If it returns an error, the client request is aborted with that error value.
Got1xxResponse func(code int, header textproto.MIMEHeader) error
// DNSStart is called when a DNS lookup begins.
DNSStart func(DNSStartInfo)
// DNSDone is called when a DNS lookup ends.
DNSDone func(DNSDoneInfo)
// ConnectStart is called when a new connection's Dial begins.
// If net.Dialer.DualStack (IPv6 "Happy Eyeballs") support is
// enabled, this may be called multiple times.
ConnectStart func(network, addr string)
// ConnectDone is called when a new connection's Dial
// completes. The provided err indicates whether the
// connection completed successfully.
// If net.Dialer.DualStack ("Happy Eyeballs") support is
// enabled, this may be called multiple times.
ConnectDone func(network, addr string, err error)
// TLSHandshakeStart is called when the TLS handshake is started. When
// connecting to an HTTPS site via an HTTP proxy, the handshake happens
// after the CONNECT request is processed by the proxy.
TLSHandshakeStart func()
// TLSHandshakeDone is called after the TLS handshake with either the
// successful handshake's connection state, or a non-nil error on handshake
// failure.
TLSHandshakeDone func(tls.ConnectionState, error)
// WroteHeaderField is called after the Transport has written
// each request header. At the time of this call the values
// might be buffered and not yet written to the network.
WroteHeaderField func(key string, value []string)
// WroteHeaders is called after the Transport has written
// all request headers.
WroteHeaders func()
// Wait100Continue is called if the Request specified
// "Expect: 100-continue" and the Transport has written the
// request headers but is waiting for "100 Continue" from the
// server before writing the request body.
Wait100Continue func()
// WroteRequest is called with the result of writing the
// request and any body. It may be called multiple times
// in the case of retried requests.
WroteRequest func(WroteRequestInfo)
}
// WroteRequestInfo contains information provided to the WroteRequest
// hook.
type WroteRequestInfo struct {
// Err is any error encountered while writing the Request.
Err error
}
// compose modifies t such that it respects the previously-registered hooks in old:
// for any hook set in both, t's hook runs first, followed by old's.
func (t *ClientTrace) compose(old *ClientTrace) {
if old == nil {
return
}
tv := reflect.ValueOf(t).Elem()
ov := reflect.ValueOf(old).Elem()
structType := tv.Type()
for i := 0; i < structType.NumField(); i++ {
tf := tv.Field(i)
hookType := tf.Type()
if hookType.Kind() != reflect.Func {
continue
}
of := ov.Field(i)
if of.IsNil() {
continue
}
if tf.IsNil() {
tf.Set(of)
continue
}
// Make a copy of tf for the composed function to call. (Calling tf
// directly would create a recursive call cycle once the field is
// replaced below, overflowing the stack.)
tfCopy := reflect.ValueOf(tf.Interface())
// We need to call both tf and of in some order.
newFunc := reflect.MakeFunc(hookType, func(args []reflect.Value) []reflect.Value {
tfCopy.Call(args)
return of.Call(args)
})
tv.Field(i).Set(newFunc)
}
}
// DNSStartInfo contains information about a DNS request.
type DNSStartInfo struct {
Host string
}
// DNSDoneInfo contains information about the results of a DNS lookup.
type DNSDoneInfo struct {
// Addrs are the IPv4 and/or IPv6 addresses found in the DNS
// lookup. The contents of the slice should not be mutated.
Addrs []net.IPAddr
// Err is any error that occurred during the DNS lookup.
Err error
// Coalesced is whether the Addrs were shared with another
// caller who was doing the same DNS lookup concurrently.
Coalesced bool
}
func (t *ClientTrace) hasNetHooks() bool {
if t == nil {
return false
}
return t.DNSStart != nil || t.DNSDone != nil || t.ConnectStart != nil || t.ConnectDone != nil
}
// GotConnInfo is the argument to the ClientTrace.GotConn function and
// contains information about the obtained connection.
type GotConnInfo struct {
// Conn is the connection that was obtained. It is owned by
// the http.Transport and should not be read, written or
// closed by users of ClientTrace.
Conn net.Conn
// Reused is whether this connection has been previously
// used for another HTTP request.
Reused bool
// WasIdle is whether this connection was obtained from an
// idle pool.
WasIdle bool
// IdleTime reports how long the connection was previously
// idle, if WasIdle is true.
IdleTime time.Duration
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package httputil
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"time"
)
// drainBody reads all of b to memory and then returns two equivalent
// ReadClosers yielding the same bytes.
//
// It returns an error if the initial slurp of all bytes fails. It does not attempt
// to make the returned ReadClosers have identical error-matching behavior.
func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err error) {
if b == nil || b == http.NoBody {
// No copying needed. Preserve the magic sentinel meaning of NoBody.
return http.NoBody, http.NoBody, nil
}
var buf bytes.Buffer
if _, err = buf.ReadFrom(b); err != nil {
return nil, b, err
}
if err = b.Close(); err != nil {
return nil, b, err
}
return io.NopCloser(&buf), io.NopCloser(bytes.NewReader(buf.Bytes())), nil
}
// dumpConn is a net.Conn that writes to Writer and reads from Reader.
type dumpConn struct {
io.Writer
io.Reader
}
func (c *dumpConn) Close() error { return nil }
func (c *dumpConn) LocalAddr() net.Addr { return nil }
func (c *dumpConn) RemoteAddr() net.Addr { return nil }
func (c *dumpConn) SetDeadline(t time.Time) error { return nil }
func (c *dumpConn) SetReadDeadline(t time.Time) error { return nil }
func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil }
type neverEnding byte
func (b neverEnding) Read(p []byte) (n int, err error) {
for i := range p {
p[i] = byte(b)
}
return len(p), nil
}
// outgoingLength is a copy of the unexported
// (*http.Request).outgoingLength method.
func outgoingLength(req *http.Request) int64 {
if req.Body == nil || req.Body == http.NoBody {
return 0
}
if req.ContentLength != 0 {
return req.ContentLength
}
return -1
}
// DumpRequestOut is like DumpRequest but for outgoing client requests. It
// includes any headers that the standard http.Transport adds, such as
// User-Agent.
func DumpRequestOut(req *http.Request, body bool) ([]byte, error) {
save := req.Body
dummyBody := false
if !body {
contentLength := outgoingLength(req)
if contentLength != 0 {
req.Body = io.NopCloser(io.LimitReader(neverEnding('x'), contentLength))
dummyBody = true
}
} else {
var err error
save, req.Body, err = drainBody(req.Body)
if err != nil {
return nil, err
}
}
// Since we're using the actual Transport code to write the request,
// switch to http so the Transport doesn't try to do an SSL
// negotiation with our dumpConn and its bytes.Buffer & pipe.
// The wire format for https and http is the same, anyway.
reqSend := req
if req.URL.Scheme == "https" {
reqSend = new(http.Request)
*reqSend = *req
reqSend.URL = new(url.URL)
*reqSend.URL = *req.URL
reqSend.URL.Scheme = "http"
}
// Use the actual Transport code to record what we would send
// on the wire, but not using TCP. Use a Transport with a
// custom dialer that returns a fake net.Conn that waits
// for the full input (recording it along the way) and then
// responds with a dummy response.
var buf bytes.Buffer // records the output
pr, pw := io.Pipe()
defer pr.Close()
defer pw.Close()
dr := &delegateReader{c: make(chan io.Reader)}
t := &http.Transport{
Dial: func(network, addr string) (net.Conn, error) {
return &dumpConn{io.MultiWriter(&buf, pw), dr}, nil
},
}
defer t.CloseIdleConnections()
// We need this channel to ensure that the reader
// goroutine exits if t.RoundTrip returns an error.
// See golang.org/issue/32571.
quitReadCh := make(chan struct{})
// Wait for the request before replying with a dummy response:
go func() {
req, err := http.ReadRequest(bufio.NewReader(pr))
if err == nil {
// Ensure all the body is read; otherwise
// we'll get a partial dump.
io.Copy(io.Discard, req.Body)
req.Body.Close()
}
select {
case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"):
case <-quitReadCh:
// Ensure delegateReader.Read doesn't block forever if we get an error.
close(dr.c)
}
}()
_, err := t.RoundTrip(reqSend)
req.Body = save
if err != nil {
pw.Close()
dr.err = err
close(quitReadCh)
return nil, err
}
dump := buf.Bytes()
// If we used a dummy body above, remove it now.
// TODO: if the req.ContentLength is large, we allocate memory
// unnecessarily just to slice it off here. But this is just
// a debug function, so this is acceptable for now. We could
// discard the body earlier if this matters.
if dummyBody {
if i := bytes.Index(dump, []byte("\r\n\r\n")); i >= 0 {
dump = dump[:i+4]
}
}
return dump, nil
}
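// A short usage sketch (editor's addition); the request target and body are
// hypothetical.
//
//	req, _ := http.NewRequest("PUT", "https://www.example.org/path", strings.NewReader("body"))
//	dump, err := httputil.DumpRequestOut(req, true)
//	if err == nil {
//		fmt.Printf("%q", dump) // includes Transport-added headers such as User-Agent
//	}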
// delegateReader is a reader that delegates to another reader,
// once it arrives on a channel.
type delegateReader struct {
c chan io.Reader
err error // only used if r is nil and c is closed.
r io.Reader // nil until received from c
}
func (r *delegateReader) Read(p []byte) (int, error) {
if r.r == nil {
var ok bool
if r.r, ok = <-r.c; !ok {
return 0, r.err
}
}
return r.r.Read(p)
}
// valueOrDefault returns value if it is nonempty, def otherwise.
func valueOrDefault(value, def string) string {
if value != "" {
return value
}
return def
}
var reqWriteExcludeHeaderDump = map[string]bool{
"Host": true, // not in Header map anyway
"Transfer-Encoding": true,
"Trailer": true,
}
// DumpRequest returns the given request in its HTTP/1.x wire
// representation. It should only be used by servers to debug client
// requests. The returned representation is an approximation only;
// some details of the initial request are lost while parsing it into
// an http.Request. In particular, the order and case of header field
// names are lost. The order of values in multi-valued headers is kept
// intact. HTTP/2 requests are dumped in HTTP/1.x form, not in their
// original binary representations.
//
// If body is true, DumpRequest also returns the body. To do so, it
// consumes req.Body and then replaces it with a new io.ReadCloser
// that yields the same bytes. If DumpRequest returns an error,
// the state of req is undefined.
//
// The documentation for http.Request.Write details which fields
// of req are included in the dump.
func DumpRequest(req *http.Request, body bool) ([]byte, error) {
var err error
save := req.Body
if !body || req.Body == nil {
req.Body = nil
} else {
save, req.Body, err = drainBody(req.Body)
if err != nil {
return nil, err
}
}
var b bytes.Buffer
// By default, print out the unmodified req.RequestURI, which
// is always set for incoming server requests. But because we
// previously used req.URL.RequestURI and the docs weren't
// always so clear about when to use DumpRequest vs
// DumpRequestOut, fall back to the old way if the caller
// provides a non-server Request.
reqURI := req.RequestURI
if reqURI == "" {
reqURI = req.URL.RequestURI()
}
fmt.Fprintf(&b, "%s %s HTTP/%d.%d\r\n", valueOrDefault(req.Method, "GET"),
reqURI, req.ProtoMajor, req.ProtoMinor)
absRequestURI := strings.HasPrefix(req.RequestURI, "http://") || strings.HasPrefix(req.RequestURI, "https://")
if !absRequestURI {
host := req.Host
if host == "" && req.URL != nil {
host = req.URL.Host
}
if host != "" {
fmt.Fprintf(&b, "Host: %s\r\n", host)
}
}
chunked := len(req.TransferEncoding) > 0 && req.TransferEncoding[0] == "chunked"
if len(req.TransferEncoding) > 0 {
fmt.Fprintf(&b, "Transfer-Encoding: %s\r\n", strings.Join(req.TransferEncoding, ","))
}
err = req.Header.WriteSubset(&b, reqWriteExcludeHeaderDump)
if err != nil {
return nil, err
}
io.WriteString(&b, "\r\n")
if req.Body != nil {
var dest io.Writer = &b
if chunked {
dest = NewChunkedWriter(dest)
}
_, err = io.Copy(dest, req.Body)
if chunked {
dest.(io.Closer).Close()
io.WriteString(&b, "\r\n")
}
}
req.Body = save
if err != nil {
return nil, err
}
return b.Bytes(), nil
}
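// A minimal server-side sketch (editor's addition): echoing the dump back to
// the client for debugging. The "/debug" route and handler are hypothetical.
//
//	http.HandleFunc("/debug", func(w http.ResponseWriter, r *http.Request) {
//		dump, err := httputil.DumpRequest(r, true)
//		if err != nil {
//			http.Error(w, err.Error(), http.StatusInternalServerError)
//			return
//		}
//		fmt.Fprintf(w, "%q", dump)
//	})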
// errNoBody is a sentinel error value used by failureToReadBody so we
// can detect that the lack of body was intentional.
var errNoBody = errors.New("sentinel error value")
// failureToReadBody is an io.ReadCloser that just returns errNoBody on
// Read. It's swapped in when we don't actually want to consume
// the body, but need a non-nil one, and want to distinguish the
// error from reading the dummy body.
type failureToReadBody struct{}
func (failureToReadBody) Read([]byte) (int, error) { return 0, errNoBody }
func (failureToReadBody) Close() error { return nil }
// emptyBody is an instance of an empty reader.
var emptyBody = io.NopCloser(strings.NewReader(""))
// DumpResponse is like DumpRequest but dumps a response.
func DumpResponse(resp *http.Response, body bool) ([]byte, error) {
var b bytes.Buffer
var err error
save := resp.Body
savecl := resp.ContentLength
if !body {
// If the content length is zero, make sure the body is an empty
// reader instead of returning an error through failureToReadBody{}.
if resp.ContentLength == 0 {
resp.Body = emptyBody
} else {
resp.Body = failureToReadBody{}
}
} else if resp.Body == nil {
resp.Body = emptyBody
} else {
save, resp.Body, err = drainBody(resp.Body)
if err != nil {
return nil, err
}
}
err = resp.Write(&b)
if err == errNoBody {
err = nil
}
resp.Body = save
resp.ContentLength = savecl
if err != nil {
return nil, err
}
return b.Bytes(), nil
}
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package httputil provides HTTP utility functions, complementing the
// more common ones in the net/http package.
package httputil
import (
"io"
"net/http/internal"
)
// NewChunkedReader returns a new chunkedReader that translates the data read from r
// out of HTTP "chunked" format before returning it.
// The chunkedReader returns io.EOF when the final 0-length chunk is read.
//
// NewChunkedReader is not needed by normal applications. The http package
// automatically decodes chunking when reading response bodies.
func NewChunkedReader(r io.Reader) io.Reader {
return internal.NewChunkedReader(r)
}
// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP
// "chunked" format before writing them to w. Closing the returned chunkedWriter
// sends the final 0-length chunk that marks the end of the stream but does
// not send the final CRLF that appears after trailers; trailers and the last
// CRLF must be written separately.
//
// NewChunkedWriter is not needed by normal applications. The http
// package adds chunking automatically if handlers don't set a
// Content-Length header. Using NewChunkedWriter inside a handler
// would result in double chunking or chunking with a Content-Length,
// both of which are wrong.
func NewChunkedWriter(w io.Writer) io.WriteCloser {
return internal.NewChunkedWriter(w)
}
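// A round-trip sketch (editor's addition) showing the writer/reader pair and
// the separately-written final CRLF described above:
//
//	var buf bytes.Buffer
//	w := httputil.NewChunkedWriter(&buf)
//	io.WriteString(w, "hello")
//	w.Close()                    // writes the final 0-length chunk
//	io.WriteString(&buf, "\r\n") // the trailing CRLF after (empty) trailers
//
//	body, _ := io.ReadAll(httputil.NewChunkedReader(&buf))
//	fmt.Printf("%s", body) // hello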
// ErrLineTooLong is returned when reading malformed chunked data
// with lines that are too long.
var ErrLineTooLong = internal.ErrLineTooLong
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package httputil
import (
"bufio"
"errors"
"io"
"net"
"net/http"
"net/textproto"
"sync"
)
var (
// Deprecated: No longer used.
ErrPersistEOF = &http.ProtocolError{ErrorString: "persistent connection closed"}
// Deprecated: No longer used.
ErrClosed = &http.ProtocolError{ErrorString: "connection closed by user"}
// Deprecated: No longer used.
ErrPipeline = &http.ProtocolError{ErrorString: "pipeline error"}
)
// This is an API usage error - the local side is closed.
// ErrPersistEOF (above) reports that the remote side is closed.
var errClosed = errors.New("i/o operation on closed connection")
// ServerConn is an artifact of Go's early HTTP implementation.
// It is low-level, old, and unused by Go's current HTTP stack.
// We should have deleted it before Go 1.
//
// Deprecated: Use the Server in package net/http instead.
type ServerConn struct {
mu sync.Mutex // read-write protects the following fields
c net.Conn
r *bufio.Reader
re, we error // read/write errors
lastbody io.ReadCloser
nread, nwritten int
pipereq map[*http.Request]uint
pipe textproto.Pipeline
}
// NewServerConn is an artifact of Go's early HTTP implementation.
// It is low-level, old, and unused by Go's current HTTP stack.
// We should have deleted it before Go 1.
//
// Deprecated: Use the Server in package net/http instead.
func NewServerConn(c net.Conn, r *bufio.Reader) *ServerConn {
if r == nil {
r = bufio.NewReader(c)
}
return &ServerConn{c: c, r: r, pipereq: make(map[*http.Request]uint)}
}
// Hijack detaches the ServerConn and returns the underlying connection as well
// as the read-side bufio, which may have some leftover data. Hijack may be
// called before Read has signaled the end of the keep-alive logic. The user
// should not call Hijack while Read or Write is in progress.
func (sc *ServerConn) Hijack() (net.Conn, *bufio.Reader) {
sc.mu.Lock()
defer sc.mu.Unlock()
c := sc.c
r := sc.r
sc.c = nil
sc.r = nil
return c, r
}
// Close calls Hijack and then also closes the underlying connection.
func (sc *ServerConn) Close() error {
c, _ := sc.Hijack()
if c != nil {
return c.Close()
}
return nil
}
// Read returns the next request on the wire. An ErrPersistEOF is returned if
// it is gracefully determined that there are no more requests (e.g. after the
// first request on an HTTP/1.0 connection, or after a Connection: close on an
// HTTP/1.1 connection).
func (sc *ServerConn) Read() (*http.Request, error) {
var req *http.Request
var err error
// Ensure ordered execution of Reads and Writes
id := sc.pipe.Next()
sc.pipe.StartRequest(id)
defer func() {
sc.pipe.EndRequest(id)
if req == nil {
sc.pipe.StartResponse(id)
sc.pipe.EndResponse(id)
} else {
// Remember the pipeline id of this request
sc.mu.Lock()
sc.pipereq[req] = id
sc.mu.Unlock()
}
}()
sc.mu.Lock()
if sc.we != nil { // no point receiving if write-side broken or closed
defer sc.mu.Unlock()
return nil, sc.we
}
if sc.re != nil {
defer sc.mu.Unlock()
return nil, sc.re
}
if sc.r == nil { // connection closed by user in the meantime
defer sc.mu.Unlock()
return nil, errClosed
}
r := sc.r
lastbody := sc.lastbody
sc.lastbody = nil
sc.mu.Unlock()
// Make sure body is fully consumed, even if user does not call body.Close
if lastbody != nil {
// body.Close is assumed to be idempotent and multiple calls to
// it should return the error that its first invocation
// returned.
err = lastbody.Close()
if err != nil {
sc.mu.Lock()
defer sc.mu.Unlock()
sc.re = err
return nil, err
}
}
req, err = http.ReadRequest(r)
sc.mu.Lock()
defer sc.mu.Unlock()
if err != nil {
if err == io.ErrUnexpectedEOF {
// A close from the opposing client is treated as a
// graceful close, even if there was some unparseable
// data before the close.
sc.re = ErrPersistEOF
return nil, sc.re
} else {
sc.re = err
return req, err
}
}
sc.lastbody = req.Body
sc.nread++
if req.Close {
sc.re = ErrPersistEOF
return req, sc.re
}
return req, err
}
// Pending returns the number of unanswered requests
// that have been received on the connection.
func (sc *ServerConn) Pending() int {
sc.mu.Lock()
defer sc.mu.Unlock()
return sc.nread - sc.nwritten
}
// Write writes resp in response to req. To close the connection gracefully, set the
// Response.Close field to true. Write should be considered operational until
// it returns an error, regardless of any errors returned on the Read side.
func (sc *ServerConn) Write(req *http.Request, resp *http.Response) error {
// Retrieve the pipeline ID of this request/response pair
sc.mu.Lock()
id, ok := sc.pipereq[req]
delete(sc.pipereq, req)
if !ok {
sc.mu.Unlock()
return ErrPipeline
}
sc.mu.Unlock()
// Ensure pipeline order
sc.pipe.StartResponse(id)
defer sc.pipe.EndResponse(id)
sc.mu.Lock()
if sc.we != nil {
defer sc.mu.Unlock()
return sc.we
}
if sc.c == nil { // connection closed by user in the meantime
defer sc.mu.Unlock()
return ErrClosed
}
c := sc.c
if sc.nread <= sc.nwritten {
defer sc.mu.Unlock()
return errors.New("persist server pipe count")
}
if resp.Close {
// After signaling a keep-alive close, any pipelined unread
// requests will be lost. It is up to the user to drain them
// before signaling.
sc.re = ErrPersistEOF
}
sc.mu.Unlock()
err := resp.Write(c)
sc.mu.Lock()
defer sc.mu.Unlock()
if err != nil {
sc.we = err
return err
}
sc.nwritten++
return nil
}
// ClientConn is an artifact of Go's early HTTP implementation.
// It is low-level, old, and unused by Go's current HTTP stack.
// We should have deleted it before Go 1.
//
// Deprecated: Use Client or Transport in package net/http instead.
type ClientConn struct {
mu sync.Mutex // read-write protects the following fields
c net.Conn
r *bufio.Reader
re, we error // read/write errors
lastbody io.ReadCloser
nread, nwritten int
pipereq map[*http.Request]uint
pipe textproto.Pipeline
writeReq func(*http.Request, io.Writer) error
}
// NewClientConn is an artifact of Go's early HTTP implementation.
// It is low-level, old, and unused by Go's current HTTP stack.
// We should have deleted it before Go 1.
//
// Deprecated: Use the Client or Transport in package net/http instead.
func NewClientConn(c net.Conn, r *bufio.Reader) *ClientConn {
if r == nil {
r = bufio.NewReader(c)
}
return &ClientConn{
c: c,
r: r,
pipereq: make(map[*http.Request]uint),
writeReq: (*http.Request).Write,
}
}
// NewProxyClientConn is an artifact of Go's early HTTP implementation.
// It is low-level, old, and unused by Go's current HTTP stack.
// We should have deleted it before Go 1.
//
// Deprecated: Use the Client or Transport in package net/http instead.
func NewProxyClientConn(c net.Conn, r *bufio.Reader) *ClientConn {
cc := NewClientConn(c, r)
cc.writeReq = (*http.Request).WriteProxy
return cc
}
// Hijack detaches the ClientConn and returns the underlying connection as well
// as the read-side bufio, which may have some leftover data. Hijack may be
// called before the user or Read has signaled the end of the keep-alive
// logic. The user should not call Hijack while Read or Write is in progress.
func (cc *ClientConn) Hijack() (c net.Conn, r *bufio.Reader) {
cc.mu.Lock()
defer cc.mu.Unlock()
c = cc.c
r = cc.r
cc.c = nil
cc.r = nil
return
}
// Close calls Hijack and then also closes the underlying connection.
func (cc *ClientConn) Close() error {
c, _ := cc.Hijack()
if c != nil {
return c.Close()
}
return nil
}
// Write writes a request. An ErrPersistEOF error is returned if the connection
// has been closed in an HTTP keep-alive sense. If req.Close is true, the
// keep-alive connection is logically closed after this request and the opposing
// server is informed. An ErrUnexpectedEOF indicates the remote closed the
// underlying TCP connection, which is usually considered a graceful close.
func (cc *ClientConn) Write(req *http.Request) error {
var err error
// Ensure ordered execution of Writes
id := cc.pipe.Next()
cc.pipe.StartRequest(id)
defer func() {
cc.pipe.EndRequest(id)
if err != nil {
cc.pipe.StartResponse(id)
cc.pipe.EndResponse(id)
} else {
// Remember the pipeline id of this request
cc.mu.Lock()
cc.pipereq[req] = id
cc.mu.Unlock()
}
}()
cc.mu.Lock()
if cc.re != nil { // no point sending if read-side closed or broken
defer cc.mu.Unlock()
return cc.re
}
if cc.we != nil {
defer cc.mu.Unlock()
return cc.we
}
if cc.c == nil { // connection closed by user in the meantime
defer cc.mu.Unlock()
return errClosed
}
c := cc.c
if req.Close {
// We write the EOF to the write-side error, because there
// still might be some pipelined reads
cc.we = ErrPersistEOF
}
cc.mu.Unlock()
err = cc.writeReq(req, c)
cc.mu.Lock()
defer cc.mu.Unlock()
if err != nil {
cc.we = err
return err
}
cc.nwritten++
return nil
}
// Pending returns the number of unanswered requests
// that have been sent on the connection.
func (cc *ClientConn) Pending() int {
cc.mu.Lock()
defer cc.mu.Unlock()
return cc.nwritten - cc.nread
}
// Read reads the next response from the wire. A valid response might be
// returned together with an ErrPersistEOF, which means that the remote
// requested that this be the last request serviced. Read can be called
// concurrently with Write, but not with another Read.
func (cc *ClientConn) Read(req *http.Request) (resp *http.Response, err error) {
// Retrieve the pipeline ID of this request/response pair
cc.mu.Lock()
id, ok := cc.pipereq[req]
delete(cc.pipereq, req)
if !ok {
cc.mu.Unlock()
return nil, ErrPipeline
}
cc.mu.Unlock()
// Ensure pipeline order
cc.pipe.StartResponse(id)
defer cc.pipe.EndResponse(id)
cc.mu.Lock()
if cc.re != nil {
defer cc.mu.Unlock()
return nil, cc.re
}
if cc.r == nil { // connection closed by user in the meantime
defer cc.mu.Unlock()
return nil, errClosed
}
r := cc.r
lastbody := cc.lastbody
cc.lastbody = nil
cc.mu.Unlock()
// Make sure body is fully consumed, even if user does not call body.Close
if lastbody != nil {
// body.Close is assumed to be idempotent and multiple calls to
// it should return the error that its first invocation
// returned.
err = lastbody.Close()
if err != nil {
cc.mu.Lock()
defer cc.mu.Unlock()
cc.re = err
return nil, err
}
}
resp, err = http.ReadResponse(r, req)
cc.mu.Lock()
defer cc.mu.Unlock()
if err != nil {
cc.re = err
return resp, err
}
cc.lastbody = resp.Body
cc.nread++
if resp.Close {
cc.re = ErrPersistEOF // don't send any more requests
return resp, cc.re
}
return resp, err
}
// Do is a convenience method that writes a request and reads its response.
func (cc *ClientConn) Do(req *http.Request) (*http.Response, error) {
err := cc.Write(req)
if err != nil {
return nil, err
}
return cc.Read(req)
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// HTTP reverse proxy handler
package httputil
import (
"context"
"errors"
"fmt"
"io"
"log"
"mime"
"net"
"net/http"
"net/http/httptrace"
"net/http/internal/ascii"
"net/textproto"
"net/url"
"strings"
"sync"
"time"
"golang.org/x/net/http/httpguts"
)
// A ProxyRequest contains a request to be rewritten by a ReverseProxy.
type ProxyRequest struct {
// In is the request received by the proxy.
// The Rewrite function must not modify In.
In *http.Request
// Out is the request which will be sent by the proxy.
// The Rewrite function may modify or replace this request.
// Hop-by-hop headers are removed from this request
// before Rewrite is called.
Out *http.Request
}
// SetURL routes the outbound request to the scheme, host, and base path
// provided in target. If the target's path is "/base" and the incoming
// request was for "/dir", the target request will be for "/base/dir".
//
// SetURL rewrites the outbound Host header to match the target's host.
// To preserve the inbound request's Host header (the default behavior
// of NewSingleHostReverseProxy):
//
// rewriteFunc := func(r *httputil.ProxyRequest) {
// r.SetURL(url)
// r.Out.Host = r.In.Host
// }
func (r *ProxyRequest) SetURL(target *url.URL) {
rewriteRequestURL(r.Out, target)
r.Out.Host = ""
}
// SetXForwarded sets the X-Forwarded-For, X-Forwarded-Host, and
// X-Forwarded-Proto headers of the outbound request.
//
// - The X-Forwarded-For header is set to the client IP address.
// - The X-Forwarded-Host header is set to the host name requested
// by the client.
// - The X-Forwarded-Proto header is set to "http" or "https", depending
// on whether the inbound request was made on a TLS-enabled connection.
//
// If the outbound request contains an existing X-Forwarded-For header,
// SetXForwarded appends the client IP address to it. To append to the
// inbound request's X-Forwarded-For header (the default behavior of
// ReverseProxy when using a Director function), copy the header
// from the inbound request before calling SetXForwarded:
//
// rewriteFunc := func(r *httputil.ProxyRequest) {
// r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"]
// r.SetXForwarded()
// }
func (r *ProxyRequest) SetXForwarded() {
clientIP, _, err := net.SplitHostPort(r.In.RemoteAddr)
if err == nil {
prior := r.Out.Header["X-Forwarded-For"]
if len(prior) > 0 {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
r.Out.Header.Set("X-Forwarded-For", clientIP)
} else {
r.Out.Header.Del("X-Forwarded-For")
}
r.Out.Header.Set("X-Forwarded-Host", r.In.Host)
if r.In.TLS == nil {
r.Out.Header.Set("X-Forwarded-Proto", "http")
} else {
r.Out.Header.Set("X-Forwarded-Proto", "https")
}
}
// ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the
// client.
//
// 1xx responses are forwarded to the client if the underlying
// transport supports ClientTrace.Got1xxResponse.
type ReverseProxy struct {
// Rewrite must be a function which modifies
// the request into a new request to be sent
// using Transport. Its response is then copied
// back to the original client unmodified.
// Rewrite must not access the provided ProxyRequest
// or its contents after returning.
//
// The Forwarded, X-Forwarded, X-Forwarded-Host,
// and X-Forwarded-Proto headers are removed from the
// outbound request before Rewrite is called. See also
// the ProxyRequest.SetXForwarded method.
//
// Unparsable query parameters are removed from the
// outbound request before Rewrite is called.
// The Rewrite function may copy the inbound URL's
// RawQuery to the outbound URL to preserve the original
// parameter string. Note that this can lead to security
// issues if the proxy's interpretation of query parameters
// does not match that of the downstream server.
//
// At most one of Rewrite or Director may be set.
Rewrite func(*ProxyRequest)
// Director is a function which modifies
// the request into a new request to be sent
// using Transport. Its response is then copied
// back to the original client unmodified.
// Director must not access the provided Request
// after returning.
//
// By default, the X-Forwarded-For header is set to the
// value of the client IP address. If an X-Forwarded-For
// header already exists, the client IP is appended to the
// existing values. As a special case, if the header
// exists in the Request.Header map but has a nil value
// (such as when set by the Director func), the X-Forwarded-For
// header is not modified.
//
// To prevent IP spoofing, be sure to delete any pre-existing
// X-Forwarded-For header coming from the client or
// an untrusted proxy.
//
// Hop-by-hop headers are removed from the request after
// Director returns, which can remove headers added by
// Director. Use a Rewrite function instead to ensure
// modifications to the request are preserved.
//
// Unparsable query parameters are removed from the outbound
// request if Request.Form is set after Director returns.
//
// At most one of Rewrite or Director may be set.
Director func(*http.Request)
// The transport used to perform proxy requests.
// If nil, http.DefaultTransport is used.
Transport http.RoundTripper
// FlushInterval specifies the flush interval
// to flush to the client while copying the
// response body.
// If zero, no periodic flushing is done.
// A negative value means to flush immediately
// after each write to the client.
// The FlushInterval is ignored when ReverseProxy
// recognizes a response as a streaming response, or
// if its ContentLength is -1; for such responses, writes
// are flushed to the client immediately.
FlushInterval time.Duration
// ErrorLog specifies an optional logger for errors
// that occur when attempting to proxy the request.
// If nil, logging is done via the log package's standard logger.
ErrorLog *log.Logger
// BufferPool optionally specifies a buffer pool to
// get byte slices for use by io.CopyBuffer when
// copying HTTP response bodies.
BufferPool BufferPool
// ModifyResponse is an optional function that modifies the
// Response from the backend. It is called if the backend
// returns a response at all, with any HTTP status code.
// If the backend is unreachable, the optional ErrorHandler is
// called without any call to ModifyResponse.
//
// If ModifyResponse returns an error, ErrorHandler is called
// with its error value. If ErrorHandler is nil, its default
// implementation is used.
ModifyResponse func(*http.Response) error
// ErrorHandler is an optional function that handles errors
// reaching the backend or errors from ModifyResponse.
//
// If nil, the default is to log the provided error and return
// a 502 Status Bad Gateway response.
ErrorHandler func(http.ResponseWriter, *http.Request, error)
}
// A BufferPool is an interface for getting and returning temporary
// byte slices for use by io.CopyBuffer.
type BufferPool interface {
Get() []byte
Put([]byte)
}
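// A minimal BufferPool sketch (editor's addition) backed by sync.Pool. The
// 32 KiB buffer size is an arbitrary illustrative choice, not taken from
// this package.
//
//	type copyBufferPool struct{ pool sync.Pool }
//
//	func newCopyBufferPool() *copyBufferPool {
//		return &copyBufferPool{pool: sync.Pool{
//			New: func() any { return make([]byte, 32*1024) },
//		}}
//	}
//
//	func (p *copyBufferPool) Get() []byte  { return p.pool.Get().([]byte) }
//	func (p *copyBufferPool) Put(b []byte) { p.pool.Put(b) }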
func singleJoiningSlash(a, b string) string {
aslash := strings.HasSuffix(a, "/")
bslash := strings.HasPrefix(b, "/")
switch {
case aslash && bslash:
return a + b[1:]
case !aslash && !bslash:
return a + "/" + b
}
return a + b
}
func joinURLPath(a, b *url.URL) (path, rawpath string) {
if a.RawPath == "" && b.RawPath == "" {
return singleJoiningSlash(a.Path, b.Path), ""
}
// Same as singleJoiningSlash, but uses EscapedPath to determine
// whether a slash should be added
apath := a.EscapedPath()
bpath := b.EscapedPath()
aslash := strings.HasSuffix(apath, "/")
bslash := strings.HasPrefix(bpath, "/")
switch {
case aslash && bslash:
return a.Path + b.Path[1:], apath + bpath[1:]
case !aslash && !bslash:
return a.Path + "/" + b.Path, apath + "/" + bpath
}
return a.Path + b.Path, apath + bpath
}
// NewSingleHostReverseProxy returns a new ReverseProxy that routes
// URLs to the scheme, host, and base path provided in target. If the
// target's path is "/base" and the incoming request was for "/dir",
// the target request will be for /base/dir.
//
// NewSingleHostReverseProxy does not rewrite the Host header.
//
// To customize the ReverseProxy behavior beyond what
// NewSingleHostReverseProxy provides, use ReverseProxy directly
// with a Rewrite function. The ProxyRequest SetURL method
// may be used to route the outbound request. (Note that SetURL,
// unlike NewSingleHostReverseProxy, rewrites the Host header
// of the outbound request by default.)
//
// proxy := &ReverseProxy{
// Rewrite: func(r *ProxyRequest) {
// r.SetURL(target)
// r.Out.Host = r.In.Host // if desired
// },
// }
func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
director := func(req *http.Request) {
rewriteRequestURL(req, target)
}
return &ReverseProxy{Director: director}
}
func rewriteRequestURL(req *http.Request, target *url.URL) {
targetQuery := target.RawQuery
req.URL.Scheme = target.Scheme
req.URL.Host = target.Host
req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL)
if targetQuery == "" || req.URL.RawQuery == "" {
req.URL.RawQuery = targetQuery + req.URL.RawQuery
} else {
req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
}
}
func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
// Hop-by-hop headers. These are removed when sent to the backend.
// As of RFC 7230, hop-by-hop headers are required to appear in the
// Connection header field. These are the headers defined by the
// obsoleted RFC 2616 (section 13.5.1) and are used for backward
// compatibility.
var hopHeaders = []string{
"Connection",
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"Te", // canonicalized version of "TE"
"Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
"Transfer-Encoding",
"Upgrade",
}
func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) {
p.logf("http: proxy error: %v", err)
rw.WriteHeader(http.StatusBadGateway)
}
func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) {
if p.ErrorHandler != nil {
return p.ErrorHandler
}
return p.defaultErrorHandler
}
// modifyResponse conditionally runs the optional ModifyResponse hook
// and reports whether the request should proceed.
func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool {
if p.ModifyResponse == nil {
return true
}
if err := p.ModifyResponse(res); err != nil {
res.Body.Close()
p.getErrorHandler()(rw, req, err)
return false
}
return true
}
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
transport := p.Transport
if transport == nil {
transport = http.DefaultTransport
}
ctx := req.Context()
if ctx.Done() != nil {
// CloseNotifier predates context.Context, and has been
// entirely superseded by it. If the request contains
// a Context that carries a cancellation signal, don't
// bother spinning up a goroutine to watch the CloseNotify
// channel (if any).
//
// If the request Context has a nil Done channel (which
// means it is either context.Background, or a custom
// Context implementation with no cancellation signal),
// then consult the CloseNotifier if available.
} else if cn, ok := rw.(http.CloseNotifier); ok {
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
defer cancel()
notifyChan := cn.CloseNotify()
go func() {
select {
case <-notifyChan:
cancel()
case <-ctx.Done():
}
}()
}
outreq := req.Clone(ctx)
if req.ContentLength == 0 {
outreq.Body = nil // Issue 16036: nil Body for http.Transport retries
}
if outreq.Body != nil {
// Reading from the request body after returning from a handler is not
// allowed, and the RoundTrip goroutine that reads the Body can outlive
// this handler. This can lead to a crash if the handler panics (see
// Issue 46866). Although calling Close doesn't guarantee there isn't
// any Read in flight after the handler returns, in practice it's safe to
// read after closing it.
defer outreq.Body.Close()
}
if outreq.Header == nil {
outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
}
if (p.Director != nil) == (p.Rewrite != nil) {
p.getErrorHandler()(rw, req, errors.New("ReverseProxy must have exactly one of Director or Rewrite set"))
return
}
if p.Director != nil {
p.Director(outreq)
if outreq.Form != nil {
outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery)
}
}
outreq.Close = false
reqUpType := upgradeType(outreq.Header)
if !ascii.IsPrint(reqUpType) {
p.getErrorHandler()(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType))
return
}
removeHopByHopHeaders(outreq.Header)
// Issue 21096: tell backend applications that care about trailer support
// that we support trailers. (We do, but we don't go out of our way to
// advertise that unless the incoming client request thought it was worth
// mentioning.) Note that we look at req.Header, not outreq.Header, since
// the latter has passed through removeHopByHopHeaders.
if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
outreq.Header.Set("Te", "trailers")
}
// After stripping all the hop-by-hop connection headers above, add back any
// necessary for protocol upgrades, such as for websockets.
if reqUpType != "" {
outreq.Header.Set("Connection", "Upgrade")
outreq.Header.Set("Upgrade", reqUpType)
}
if p.Rewrite != nil {
// Strip client-provided forwarding headers.
// The Rewrite func may use SetXForwarded to set new values
// for these or copy the previous values from the inbound request.
outreq.Header.Del("Forwarded")
outreq.Header.Del("X-Forwarded-For")
outreq.Header.Del("X-Forwarded-Host")
outreq.Header.Del("X-Forwarded-Proto")
// Remove unparsable query parameters from the outbound request.
outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery)
pr := &ProxyRequest{
In: req,
Out: outreq,
}
p.Rewrite(pr)
outreq = pr.Out
} else {
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
// If we aren't the first proxy retain prior
// X-Forwarded-For information as a comma+space
// separated list and fold multiple headers into one.
prior, ok := outreq.Header["X-Forwarded-For"]
omit := ok && prior == nil // Issue 38079: nil now means don't populate the header
if len(prior) > 0 {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
if !omit {
outreq.Header.Set("X-Forwarded-For", clientIP)
}
}
}
if _, ok := outreq.Header["User-Agent"]; !ok {
// If the outbound request doesn't have a User-Agent header set,
// don't send the default Go HTTP client User-Agent.
outreq.Header.Set("User-Agent", "")
}
trace := &httptrace.ClientTrace{
Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
h := rw.Header()
copyHeader(h, http.Header(header))
rw.WriteHeader(code)
// Clear headers; it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses.
for k := range h {
delete(h, k)
}
return nil
},
}
outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))
res, err := transport.RoundTrip(outreq)
if err != nil {
p.getErrorHandler()(rw, outreq, err)
return
}
// Deal with 101 Switching Protocols responses (WebSocket, h2c, etc.).
if res.StatusCode == http.StatusSwitchingProtocols {
if !p.modifyResponse(rw, res, outreq) {
return
}
p.handleUpgradeResponse(rw, outreq, res)
return
}
removeHopByHopHeaders(res.Header)
if !p.modifyResponse(rw, res, outreq) {
return
}
copyHeader(rw.Header(), res.Header)
// The "Trailer" header isn't included in the Transport's response,
// at least for *http.Transport. Build it up from Trailer.
announcedTrailers := len(res.Trailer)
if announcedTrailers > 0 {
trailerKeys := make([]string, 0, len(res.Trailer))
for k := range res.Trailer {
trailerKeys = append(trailerKeys, k)
}
rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
}
rw.WriteHeader(res.StatusCode)
err = p.copyResponse(rw, res.Body, p.flushInterval(res))
if err != nil {
defer res.Body.Close()
// Since we're streaming the response, if we run into an error all we can do
// is abort the request. Issue 23643: ReverseProxy should use ErrAbortHandler
// on read error while copying body.
if !shouldPanicOnCopyError(req) {
p.logf("suppressing panic for copyResponse error in test; copy error: %v", err)
return
}
panic(http.ErrAbortHandler)
}
res.Body.Close() // close now, instead of defer, to populate res.Trailer
if len(res.Trailer) > 0 {
// Force chunking if we saw a response trailer.
// This prevents net/http from calculating the length for short
// bodies and adding a Content-Length.
if fl, ok := rw.(http.Flusher); ok {
fl.Flush()
}
}
if len(res.Trailer) == announcedTrailers {
copyHeader(rw.Header(), res.Trailer)
return
}
for k, vv := range res.Trailer {
k = http.TrailerPrefix + k
for _, v := range vv {
rw.Header().Add(k, v)
}
}
}
var inOurTests bool // whether we're in our own tests
// shouldPanicOnCopyError reports whether the reverse proxy should
// panic with http.ErrAbortHandler. This is the right thing to do by
// default, but Go 1.10 and earlier did not, so existing unit tests
// weren't expecting panics. Only panic in our own tests, or when
// running under the HTTP server.
func shouldPanicOnCopyError(req *http.Request) bool {
if inOurTests {
// Our tests know to handle this panic.
return true
}
if req.Context().Value(http.ServerContextKey) != nil {
// We seem to be running under an HTTP server, so
// it'll recover the panic.
return true
}
// Otherwise act like Go 1.10 and earlier to not break
// existing tests.
return false
}
// removeHopByHopHeaders removes hop-by-hop headers.
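// For example, given the fields
//
//	Connection: foo
//	Foo: bar
//	Keep-Alive: timeout=5
//
// both Foo (because it is named in Connection) and Keep-Alive (one of the
// known hop-by-hop headers listed above) are deleted, along with
// Connection itself.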
func removeHopByHopHeaders(h http.Header) {
// RFC 7230, section 6.1: Remove headers listed in the "Connection" header.
for _, f := range h["Connection"] {
for _, sf := range strings.Split(f, ",") {
if sf = textproto.TrimString(sf); sf != "" {
h.Del(sf)
}
}
}
// RFC 2616, section 13.5.1: Remove a set of known hop-by-hop headers.
// This behavior is superseded by the RFC 7230 Connection header, but
// preserve it for backwards compatibility.
for _, f := range hopHeaders {
h.Del(f)
}
}
// flushInterval returns the p.FlushInterval value, conditionally
// overriding its value for a specific request/response.
func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration {
resCT := res.Header.Get("Content-Type")
// For Server-Sent Events responses, flush immediately.
// The MIME type is defined in https://www.w3.org/TR/eventsource/#text-event-stream
if baseCT, _, _ := mime.ParseMediaType(resCT); baseCT == "text/event-stream" {
return -1 // negative means immediately
}
// We might have the case of streaming for which Content-Length might be unset.
if res.ContentLength == -1 {
return -1
}
return p.FlushInterval
}
func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error {
if flushInterval != 0 {
if wf, ok := dst.(writeFlusher); ok {
mlw := &maxLatencyWriter{
dst: wf,
latency: flushInterval,
}
defer mlw.stop()
// set up initial timer so headers get flushed even if body writes are delayed
mlw.flushPending = true
mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
dst = mlw
}
}
var buf []byte
if p.BufferPool != nil {
buf = p.BufferPool.Get()
defer p.BufferPool.Put(buf)
}
_, err := p.copyBuffer(dst, src, buf)
return err
}
// copyBuffer returns any write errors or non-EOF read errors, and the
// number of bytes written.
func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
if len(buf) == 0 {
buf = make([]byte, 32*1024)
}
var written int64
for {
nr, rerr := src.Read(buf)
if rerr != nil && rerr != io.EOF && rerr != context.Canceled {
p.logf("httputil: ReverseProxy read error during body copy: %v", rerr)
}
if nr > 0 {
nw, werr := dst.Write(buf[:nr])
if nw > 0 {
written += int64(nw)
}
if werr != nil {
return written, werr
}
if nr != nw {
return written, io.ErrShortWrite
}
}
if rerr != nil {
if rerr == io.EOF {
rerr = nil
}
return written, rerr
}
}
}
func (p *ReverseProxy) logf(format string, args ...any) {
if p.ErrorLog != nil {
p.ErrorLog.Printf(format, args...)
} else {
log.Printf(format, args...)
}
}
type writeFlusher interface {
io.Writer
http.Flusher
}
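// maxLatencyWriter wraps a writeFlusher and bounds how long written data
// may sit unflushed: a negative latency flushes after every Write, while
// a positive latency arms a timer on the first Write after a flush, and
// delayedFlush flushes when the timer fires (unless stop ran first).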
type maxLatencyWriter struct {
dst writeFlusher
latency time.Duration // non-zero; negative means to flush immediately
mu sync.Mutex // protects t, flushPending, and dst.Flush
t *time.Timer
flushPending bool
}
func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
m.mu.Lock()
defer m.mu.Unlock()
n, err = m.dst.Write(p)
if m.latency < 0 {
m.dst.Flush()
return
}
if m.flushPending {
return
}
if m.t == nil {
m.t = time.AfterFunc(m.latency, m.delayedFlush)
} else {
m.t.Reset(m.latency)
}
m.flushPending = true
return
}
func (m *maxLatencyWriter) delayedFlush() {
m.mu.Lock()
defer m.mu.Unlock()
if !m.flushPending { // if stop was called but AfterFunc already started this goroutine
return
}
m.dst.Flush()
m.flushPending = false
}
func (m *maxLatencyWriter) stop() {
m.mu.Lock()
defer m.mu.Unlock()
m.flushPending = false
if m.t != nil {
m.t.Stop()
}
}
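// upgradeType returns the value of the Upgrade header if the Connection
// header contains the "Upgrade" token, and the empty string otherwise.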
func upgradeType(h http.Header) string {
if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
return ""
}
return h.Get("Upgrade")
}
func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
reqUpType := upgradeType(req.Header)
resUpType := upgradeType(res.Header)
	if !ascii.IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller.
		p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType))
		return
	}
if !ascii.EqualFold(reqUpType, resUpType) {
p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
return
}
hj, ok := rw.(http.Hijacker)
if !ok {
p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
return
}
backConn, ok := res.Body.(io.ReadWriteCloser)
if !ok {
p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
return
}
backConnCloseCh := make(chan bool)
go func() {
// Ensure that the cancellation of a request closes the backend.
// See issue https://golang.org/issue/35559.
select {
case <-req.Context().Done():
case <-backConnCloseCh:
}
backConn.Close()
}()
defer close(backConnCloseCh)
conn, brw, err := hj.Hijack()
if err != nil {
p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err))
return
}
defer conn.Close()
copyHeader(rw.Header(), res.Header)
res.Header = rw.Header()
res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
if err := res.Write(brw); err != nil {
p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
return
}
if err := brw.Flush(); err != nil {
p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
return
}
errc := make(chan error, 1)
spc := switchProtocolCopier{user: conn, backend: backConn}
go spc.copyToBackend(errc)
go spc.copyFromBackend(errc)
<-errc
}
// switchProtocolCopier exists so goroutines proxying data back and
// forth have nice names in stacks.
type switchProtocolCopier struct {
user, backend io.ReadWriter
}
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
_, err := io.Copy(c.user, c.backend)
errc <- err
}
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
_, err := io.Copy(c.backend, c.user)
errc <- err
}
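// cleanQueryParams returns s unchanged when it is already a well-formed
// query string. If s contains a semicolon separator or an invalid
// percent-escape, the query is re-encoded via url.ParseQuery, which drops
// the unparsable parameters.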
func cleanQueryParams(s string) string {
reencode := func(s string) string {
v, _ := url.ParseQuery(s)
return v.Encode()
}
for i := 0; i < len(s); {
switch s[i] {
case ';':
return reencode(s)
case '%':
if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) {
return reencode(s)
}
i += 3
default:
i++
}
}
return s
}
func ishex(c byte) bool {
switch {
case '0' <= c && c <= '9':
return true
case 'a' <= c && c <= 'f':
return true
case 'A' <= c && c <= 'F':
return true
}
return false
}
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ascii
import (
"strings"
"unicode"
)
// EqualFold is strings.EqualFold, ASCII only. It reports whether s and t
// are equal, ASCII-case-insensitively.
func EqualFold(s, t string) bool {
if len(s) != len(t) {
return false
}
for i := 0; i < len(s); i++ {
if lower(s[i]) != lower(t[i]) {
return false
}
}
return true
}
// lower returns the ASCII lowercase version of b.
func lower(b byte) byte {
if 'A' <= b && b <= 'Z' {
return b + ('a' - 'A')
}
return b
}
// IsPrint reports whether s is ASCII and printable according to
// https://tools.ietf.org/html/rfc20#section-4.2.
func IsPrint(s string) bool {
for i := 0; i < len(s); i++ {
if s[i] < ' ' || s[i] > '~' {
return false
}
}
return true
}
// Is reports whether s is ASCII.
func Is(s string) bool {
for i := 0; i < len(s); i++ {
if s[i] > unicode.MaxASCII {
return false
}
}
return true
}
// ToLower returns the lowercase version of s if s is ASCII and printable.
func ToLower(s string) (lower string, ok bool) {
if !IsPrint(s) {
return "", false
}
return strings.ToLower(s), true
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// The wire protocol for HTTP's "chunked" Transfer-Encoding.
// Package internal contains HTTP internals shared by net/http and
// net/http/httputil.
package internal
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
)
const maxLineLength = 4096 // assumed <= bufio.defaultBufSize
var ErrLineTooLong = errors.New("header line too long")
// NewChunkedReader returns a new chunkedReader that translates the data read from r
// out of HTTP "chunked" format before returning it.
// The chunkedReader returns io.EOF when the final 0-length chunk is read.
//
// NewChunkedReader is not needed by normal applications. The http package
// automatically decodes chunking when reading response bodies.
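//
// For example (a sketch decoding one chunk followed by the terminating
// zero-length chunk):
//
//	r := NewChunkedReader(strings.NewReader("4\r\nWiki\r\n0\r\n"))
//	b, err := io.ReadAll(r) // b == []byte("Wiki"), err == nil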
func NewChunkedReader(r io.Reader) io.Reader {
br, ok := r.(*bufio.Reader)
if !ok {
br = bufio.NewReader(r)
}
return &chunkedReader{r: br}
}
type chunkedReader struct {
r *bufio.Reader
n uint64 // unread bytes in chunk
err error
buf [2]byte
checkEnd bool // whether need to check for \r\n chunk footer
}
func (cr *chunkedReader) beginChunk() {
// chunk-size CRLF
var line []byte
line, cr.err = readChunkLine(cr.r)
if cr.err != nil {
return
}
cr.n, cr.err = parseHexUint(line)
if cr.err != nil {
return
}
if cr.n == 0 {
cr.err = io.EOF
}
}
func (cr *chunkedReader) chunkHeaderAvailable() bool {
n := cr.r.Buffered()
if n > 0 {
peek, _ := cr.r.Peek(n)
return bytes.IndexByte(peek, '\n') >= 0
}
return false
}
func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
for cr.err == nil {
if cr.checkEnd {
if n > 0 && cr.r.Buffered() < 2 {
// We have some data. Return early (per the io.Reader
// contract) instead of potentially blocking while
// reading more.
break
}
if _, cr.err = io.ReadFull(cr.r, cr.buf[:2]); cr.err == nil {
if string(cr.buf[:]) != "\r\n" {
cr.err = errors.New("malformed chunked encoding")
break
}
} else {
if cr.err == io.EOF {
cr.err = io.ErrUnexpectedEOF
}
break
}
cr.checkEnd = false
}
if cr.n == 0 {
if n > 0 && !cr.chunkHeaderAvailable() {
// We've read enough. Don't potentially block
// reading a new chunk header.
break
}
cr.beginChunk()
continue
}
if len(b) == 0 {
break
}
rbuf := b
if uint64(len(rbuf)) > cr.n {
rbuf = rbuf[:cr.n]
}
var n0 int
n0, cr.err = cr.r.Read(rbuf)
n += n0
b = b[n0:]
cr.n -= uint64(n0)
// If we're at the end of a chunk, read the next two
// bytes to verify they are "\r\n".
if cr.n == 0 && cr.err == nil {
cr.checkEnd = true
} else if cr.err == io.EOF {
cr.err = io.ErrUnexpectedEOF
}
}
return n, cr.err
}
// Read a line of bytes (up to \n) from b.
// Give up if the line exceeds maxLineLength.
// The returned bytes are owned by the bufio.Reader
// so they are only valid until the next bufio read.
func readChunkLine(b *bufio.Reader) ([]byte, error) {
p, err := b.ReadSlice('\n')
if err != nil {
// We always know when EOF is coming.
// If the caller asked for a line, there should be a line.
if err == io.EOF {
err = io.ErrUnexpectedEOF
} else if err == bufio.ErrBufferFull {
err = ErrLineTooLong
}
return nil, err
}
if len(p) >= maxLineLength {
return nil, ErrLineTooLong
}
p = trimTrailingWhitespace(p)
p, err = removeChunkExtension(p)
if err != nil {
return nil, err
}
return p, nil
}
func trimTrailingWhitespace(b []byte) []byte {
for len(b) > 0 && isASCIISpace(b[len(b)-1]) {
b = b[:len(b)-1]
}
return b
}
func isASCIISpace(b byte) bool {
return b == ' ' || b == '\t' || b == '\n' || b == '\r'
}
var semi = []byte(";")
// removeChunkExtension removes any chunk-extension from p.
// For example,
//
// "0" => "0"
// "0;token" => "0"
// "0;token=val" => "0"
// `0;token="quoted string"` => "0"
func removeChunkExtension(p []byte) ([]byte, error) {
p, _, _ = bytes.Cut(p, semi)
// TODO: care about exact syntax of chunk extensions? We're
// ignoring and stripping them anyway. For now just never
// return an error.
return p, nil
}
// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP
// "chunked" format before writing them to w. Closing the returned chunkedWriter
// sends the final 0-length chunk that marks the end of the stream but does
// not send the final CRLF that appears after trailers; trailers and the last
// CRLF must be written separately.
//
// NewChunkedWriter is not needed by normal applications. The http
// package adds chunking automatically if handlers don't set a
// Content-Length header. Using newChunkedWriter inside a handler
// would result in double chunking or chunking with a Content-Length
// length, both of which are wrong.
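//
// For example (a sketch):
//
//	var buf bytes.Buffer
//	w := NewChunkedWriter(&buf)
//	io.WriteString(w, "Wiki")
//	w.Close()
//	// buf now holds "4\r\nWiki\r\n0\r\n"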
func NewChunkedWriter(w io.Writer) io.WriteCloser {
return &chunkedWriter{w}
}
// Writing to chunkedWriter translates to writing in HTTP chunked
// Transfer-Encoding wire format to the underlying Wire writer.
type chunkedWriter struct {
Wire io.Writer
}
// Write the contents of data as one chunk to Wire.
// NOTE: the corresponding chunk-writing procedure in Conn.Write has a bug
// since it does not check for success of io.WriteString.
func (cw *chunkedWriter) Write(data []byte) (n int, err error) {
// Don't send 0-length data. It looks like EOF for chunked encoding.
if len(data) == 0 {
return 0, nil
}
if _, err = fmt.Fprintf(cw.Wire, "%x\r\n", len(data)); err != nil {
return 0, err
}
if n, err = cw.Wire.Write(data); err != nil {
return
}
if n != len(data) {
err = io.ErrShortWrite
return
}
if _, err = io.WriteString(cw.Wire, "\r\n"); err != nil {
return
}
if bw, ok := cw.Wire.(*FlushAfterChunkWriter); ok {
err = bw.Flush()
}
return
}
func (cw *chunkedWriter) Close() error {
_, err := io.WriteString(cw.Wire, "0\r\n")
return err
}
// FlushAfterChunkWriter signals from the caller of NewChunkedWriter
// that each chunk should be followed by a flush. It is used by the
// http.Transport code to keep the buffering behavior for headers and
// trailers, but flush out chunks aggressively in the middle for
// request bodies which may be generated slowly. See Issue 6574.
type FlushAfterChunkWriter struct {
*bufio.Writer
}
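// parseHexUint parses v as a hexadecimal chunk size, rejecting input
// longer than 16 digits (which could overflow a uint64). For example,
// parseHexUint([]byte("1a")) returns 26.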
func parseHexUint(v []byte) (n uint64, err error) {
for i, b := range v {
switch {
case '0' <= b && b <= '9':
b = b - '0'
case 'a' <= b && b <= 'f':
b = b - 'a' + 10
case 'A' <= b && b <= 'F':
b = b - 'A' + 10
default:
return 0, errors.New("invalid byte in chunk length")
}
if i == 16 {
return 0, errors.New("http chunk length too large")
}
n <<= 4
n |= uint64(b)
}
return
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package pprof serves via its HTTP server runtime profiling data
// in the format expected by the pprof visualization tool.
//
// The package is typically only imported for the side effect of
// registering its HTTP handlers.
// The handled paths all begin with /debug/pprof/.
//
// To use pprof, link this package into your program:
//
// import _ "net/http/pprof"
//
// If your application is not already running an http server, you
// need to start one. Add "net/http" and "log" to your imports and
// the following code to your main function:
//
// go func() {
// log.Println(http.ListenAndServe("localhost:6060", nil))
// }()
//
// By default, all the profiles listed in [runtime/pprof.Profile] are
// available (via [Handler]), in addition to the [Cmdline], [Profile], [Symbol],
// and [Trace] profiles defined in this package.
// If you are not using DefaultServeMux, you will have to register handlers
// with the mux you are using.
//
// # Usage examples
//
// Use the pprof tool to look at the heap profile:
//
// go tool pprof http://localhost:6060/debug/pprof/heap
//
// Or to look at a 30-second CPU profile:
//
// go tool pprof http://localhost:6060/debug/pprof/profile?seconds=30
//
// Or to look at the goroutine blocking profile, after calling
// runtime.SetBlockProfileRate in your program:
//
// go tool pprof http://localhost:6060/debug/pprof/block
//
// Or to look at the holders of contended mutexes, after calling
// runtime.SetMutexProfileFraction in your program:
//
// go tool pprof http://localhost:6060/debug/pprof/mutex
//
// The package also exports a handler that serves execution trace data
// for the "go tool trace" command. To collect a 5-second execution trace:
//
// curl -o trace.out http://localhost:6060/debug/pprof/trace?seconds=5
// go tool trace trace.out
//
// To view all available profiles, open http://localhost:6060/debug/pprof/
// in your browser.
//
// For a study of the facility in action, visit
//
// https://blog.golang.org/2011/06/profiling-go-programs.html
package pprof
import (
"bufio"
"bytes"
"context"
"fmt"
"html"
"internal/profile"
"io"
"log"
"net/http"
"net/url"
"os"
"runtime"
"runtime/pprof"
"runtime/trace"
"sort"
"strconv"
"strings"
"time"
)
func init() {
http.HandleFunc("/debug/pprof/", Index)
http.HandleFunc("/debug/pprof/cmdline", Cmdline)
http.HandleFunc("/debug/pprof/profile", Profile)
http.HandleFunc("/debug/pprof/symbol", Symbol)
http.HandleFunc("/debug/pprof/trace", Trace)
}
// Cmdline responds with the running program's
// command line, with arguments separated by NUL bytes.
// The package initialization registers it as /debug/pprof/cmdline.
func Cmdline(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
fmt.Fprint(w, strings.Join(os.Args, "\x00"))
}
func sleep(r *http.Request, d time.Duration) {
select {
case <-time.After(d):
case <-r.Context().Done():
}
}
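// durationExceedsWriteTimeout reports whether a profile lasting the given
// number of seconds would outlive the server's WriteTimeout, when the
// request is being served by an *http.Server that sets one.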
func durationExceedsWriteTimeout(r *http.Request, seconds float64) bool {
srv, ok := r.Context().Value(http.ServerContextKey).(*http.Server)
return ok && srv.WriteTimeout != 0 && seconds >= srv.WriteTimeout.Seconds()
}
func serveError(w http.ResponseWriter, status int, txt string) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Go-Pprof", "1")
w.Header().Del("Content-Disposition")
w.WriteHeader(status)
fmt.Fprintln(w, txt)
}
// Profile responds with the pprof-formatted CPU profile.
// Profiling lasts for the duration specified in the seconds GET parameter,
// or for 30 seconds if not specified.
// The package initialization registers it as /debug/pprof/profile.
func Profile(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Content-Type-Options", "nosniff")
sec, err := strconv.ParseInt(r.FormValue("seconds"), 10, 64)
if sec <= 0 || err != nil {
sec = 30
}
if durationExceedsWriteTimeout(r, float64(sec)) {
serveError(w, http.StatusBadRequest, "profile duration exceeds server's WriteTimeout")
return
}
// Set Content Type assuming StartCPUProfile will work,
// because if it does it starts writing.
w.Header().Set("Content-Type", "application/octet-stream")
w.Header().Set("Content-Disposition", `attachment; filename="profile"`)
if err := pprof.StartCPUProfile(w); err != nil {
// StartCPUProfile failed, so no writes yet.
serveError(w, http.StatusInternalServerError,
fmt.Sprintf("Could not enable CPU profiling: %s", err))
return
}
sleep(r, time.Duration(sec)*time.Second)
pprof.StopCPUProfile()
}
// Trace responds with the execution trace in binary form.
// Tracing lasts for the duration specified in the seconds GET parameter, or for 1 second if not specified.
// The package initialization registers it as /debug/pprof/trace.
func Trace(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Content-Type-Options", "nosniff")
sec, err := strconv.ParseFloat(r.FormValue("seconds"), 64)
if sec <= 0 || err != nil {
sec = 1
}
if durationExceedsWriteTimeout(r, sec) {
serveError(w, http.StatusBadRequest, "profile duration exceeds server's WriteTimeout")
return
}
// Set Content Type assuming trace.Start will work,
// because if it does it starts writing.
w.Header().Set("Content-Type", "application/octet-stream")
w.Header().Set("Content-Disposition", `attachment; filename="trace"`)
if err := trace.Start(w); err != nil {
// trace.Start failed, so no writes yet.
serveError(w, http.StatusInternalServerError,
fmt.Sprintf("Could not enable tracing: %s", err))
return
}
sleep(r, time.Duration(sec*float64(time.Second)))
trace.Stop()
}
// Symbol looks up the program counters listed in the request,
// responding with a table mapping program counters to function names.
// The package initialization registers it as /debug/pprof/symbol.
func Symbol(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
// We have to read the whole POST body before
// writing any output. Buffer the output here.
var buf bytes.Buffer
// We don't know how many symbols we have, but we
// do have symbol information. Pprof only cares whether
// this number is 0 (no symbols available) or > 0.
fmt.Fprintf(&buf, "num_symbols: 1\n")
var b *bufio.Reader
if r.Method == "POST" {
b = bufio.NewReader(r.Body)
} else {
b = bufio.NewReader(strings.NewReader(r.URL.RawQuery))
}
for {
word, err := b.ReadSlice('+')
if err == nil {
word = word[0 : len(word)-1] // trim +
}
pc, _ := strconv.ParseUint(string(word), 0, 64)
if pc != 0 {
f := runtime.FuncForPC(uintptr(pc))
if f != nil {
fmt.Fprintf(&buf, "%#x %s\n", pc, f.Name())
}
}
// Wait until here to check for err; the last
// symbol will have an err because it doesn't end in +.
if err != nil {
if err != io.EOF {
fmt.Fprintf(&buf, "reading request: %v\n", err)
}
break
}
}
w.Write(buf.Bytes())
}
// Handler returns an HTTP handler that serves the named profile.
// Available profiles can be found in [runtime/pprof.Profile].
func Handler(name string) http.Handler {
return handler(name)
}
type handler string
func (name handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Content-Type-Options", "nosniff")
p := pprof.Lookup(string(name))
if p == nil {
serveError(w, http.StatusNotFound, "Unknown profile")
return
}
if sec := r.FormValue("seconds"); sec != "" {
name.serveDeltaProfile(w, r, p, sec)
return
}
gc, _ := strconv.Atoi(r.FormValue("gc"))
if name == "heap" && gc > 0 {
runtime.GC()
}
debug, _ := strconv.Atoi(r.FormValue("debug"))
if debug != 0 {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
} else {
w.Header().Set("Content-Type", "application/octet-stream")
w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, name))
}
p.WriteTo(w, debug)
}
func (name handler) serveDeltaProfile(w http.ResponseWriter, r *http.Request, p *pprof.Profile, secStr string) {
sec, err := strconv.ParseInt(secStr, 10, 64)
if err != nil || sec <= 0 {
serveError(w, http.StatusBadRequest, `invalid value for "seconds" - must be a positive integer`)
return
}
	// 'name' should be a key in profileSupportsDelta.
	if !profileSupportsDelta[name] {
		serveError(w, http.StatusBadRequest, `"seconds" parameter is not supported for this profile type`)
		return
	}
	if durationExceedsWriteTimeout(r, float64(sec)) {
serveError(w, http.StatusBadRequest, "profile duration exceeds server's WriteTimeout")
return
}
debug, _ := strconv.Atoi(r.FormValue("debug"))
if debug != 0 {
serveError(w, http.StatusBadRequest, "seconds and debug params are incompatible")
return
}
p0, err := collectProfile(p)
if err != nil {
serveError(w, http.StatusInternalServerError, "failed to collect profile")
return
}
t := time.NewTimer(time.Duration(sec) * time.Second)
defer t.Stop()
select {
case <-r.Context().Done():
err := r.Context().Err()
if err == context.DeadlineExceeded {
serveError(w, http.StatusRequestTimeout, err.Error())
} else { // TODO: what's a good status code for canceled requests? 400?
serveError(w, http.StatusInternalServerError, err.Error())
}
return
case <-t.C:
}
p1, err := collectProfile(p)
if err != nil {
serveError(w, http.StatusInternalServerError, "failed to collect profile")
return
}
ts := p1.TimeNanos
dur := p1.TimeNanos - p0.TimeNanos
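	// The delta is p1 - p0: scaling p0 by -1 negates its sample values,
	// so merging it with p1 leaves only what accumulated in between.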
p0.Scale(-1)
p1, err = profile.Merge([]*profile.Profile{p0, p1})
if err != nil {
serveError(w, http.StatusInternalServerError, "failed to compute delta")
return
}
p1.TimeNanos = ts // set since we don't know what profile.Merge set for TimeNanos.
p1.DurationNanos = dur
w.Header().Set("Content-Type", "application/octet-stream")
w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s-delta"`, name))
p1.Write(w)
}
func collectProfile(p *pprof.Profile) (*profile.Profile, error) {
var buf bytes.Buffer
if err := p.WriteTo(&buf, 0); err != nil {
return nil, err
}
ts := time.Now().UnixNano()
p0, err := profile.Parse(&buf)
if err != nil {
return nil, err
}
p0.TimeNanos = ts
return p0, nil
}
var profileSupportsDelta = map[handler]bool{
"allocs": true,
"block": true,
"goroutine": true,
"heap": true,
"mutex": true,
"threadcreate": true,
}
var profileDescriptions = map[string]string{
"allocs": "A sampling of all past memory allocations",
"block": "Stack traces that led to blocking on synchronization primitives",
"cmdline": "The command line invocation of the current program",
"goroutine": "Stack traces of all current goroutines. Use debug=2 as a query parameter to export in the same format as an unrecovered panic.",
"heap": "A sampling of memory allocations of live objects. You can specify the gc GET parameter to run GC before taking the heap sample.",
"mutex": "Stack traces of holders of contended mutexes",
"profile": "CPU profile. You can specify the duration in the seconds GET parameter. After you get the profile file, use the go tool pprof command to investigate the profile.",
"threadcreate": "Stack traces that led to the creation of new OS threads",
"trace": "A trace of execution of the current program. You can specify the duration in the seconds GET parameter. After you get the trace file, use the go tool trace command to investigate the trace.",
}
type profileEntry struct {
Name string
Href string
Desc string
Count int
}
// Index responds with the pprof-formatted profile named by the request.
// For example, "/debug/pprof/heap" serves the "heap" profile.
// Index responds to a request for "/debug/pprof/" with an HTML page
// listing the available profiles.
func Index(w http.ResponseWriter, r *http.Request) {
if name, found := strings.CutPrefix(r.URL.Path, "/debug/pprof/"); found {
if name != "" {
handler(name).ServeHTTP(w, r)
return
}
}
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("Content-Type", "text/html; charset=utf-8")
var profiles []profileEntry
for _, p := range pprof.Profiles() {
profiles = append(profiles, profileEntry{
Name: p.Name(),
Href: p.Name(),
Desc: profileDescriptions[p.Name()],
Count: p.Count(),
})
}
// Adding other profiles exposed from within this package
for _, p := range []string{"cmdline", "profile", "trace"} {
profiles = append(profiles, profileEntry{
Name: p,
Href: p,
Desc: profileDescriptions[p],
})
}
sort.Slice(profiles, func(i, j int) bool {
return profiles[i].Name < profiles[j].Name
})
if err := indexTmplExecute(w, profiles); err != nil {
log.Print(err)
}
}
func indexTmplExecute(w io.Writer, profiles []profileEntry) error {
var b bytes.Buffer
b.WriteString(`<html>
<head>
<title>/debug/pprof/</title>
<style>
.profile-name{
display:inline-block;
width:6rem;
}
</style>
</head>
<body>
/debug/pprof/
<br>
<p>Set debug=1 as a query parameter to export in legacy text format</p>
<br>
Types of profiles available:
<table>
<thead><td>Count</td><td>Profile</td></thead>
`)
for _, profile := range profiles {
link := &url.URL{Path: profile.Href, RawQuery: "debug=1"}
fmt.Fprintf(&b, "<tr><td>%d</td><td><a href='%s'>%s</a></td></tr>\n", profile.Count, link, html.EscapeString(profile.Name))
}
b.WriteString(`</table>
<a href="goroutine?debug=2">full goroutine stack dump</a>
<br>
<p>
Profile Descriptions:
<ul>
`)
for _, profile := range profiles {
fmt.Fprintf(&b, "<li><div class=profile-name>%s: </div> %s</li>\n", html.EscapeString(profile.Name), html.EscapeString(profile.Desc))
}
b.WriteString(`</ul>
</p>
</body>
</html>`)
_, err := w.Write(b.Bytes())
return err
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// HTTP Request reading and parsing.
package http
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"io"
"mime"
"mime/multipart"
"net"
"net/http/httptrace"
"net/http/internal/ascii"
"net/textproto"
"net/url"
urlpkg "net/url"
"strconv"
"strings"
"sync"
"golang.org/x/net/idna"
)
const (
defaultMaxMemory = 32 << 20 // 32 MB
)
// ErrMissingFile is returned by FormFile when the provided file field name
// is either not present in the request or not a file field.
var ErrMissingFile = errors.New("http: no such file")
// ProtocolError represents an HTTP protocol error.
//
// Deprecated: Not all errors in the http package related to protocol errors
// are of type ProtocolError.
type ProtocolError struct {
ErrorString string
}
func (pe *ProtocolError) Error() string { return pe.ErrorString }
var (
// ErrNotSupported indicates that a feature is not supported.
//
// It is returned by ResponseController methods to indicate that
// the handler does not support the method, and by the Push method
// of Pusher implementations to indicate that HTTP/2 Push support
// is not available.
ErrNotSupported = &ProtocolError{"feature not supported"}
// Deprecated: ErrUnexpectedTrailer is no longer returned by
// anything in the net/http package. Callers should not
// compare errors against this variable.
ErrUnexpectedTrailer = &ProtocolError{"trailer header without chunked transfer encoding"}
// ErrMissingBoundary is returned by Request.MultipartReader when the
// request's Content-Type does not include a "boundary" parameter.
ErrMissingBoundary = &ProtocolError{"no multipart boundary param in Content-Type"}
// ErrNotMultipart is returned by Request.MultipartReader when the
// request's Content-Type is not multipart/form-data.
ErrNotMultipart = &ProtocolError{"request Content-Type isn't multipart/form-data"}
// Deprecated: ErrHeaderTooLong is no longer returned by
// anything in the net/http package. Callers should not
// compare errors against this variable.
ErrHeaderTooLong = &ProtocolError{"header too long"}
// Deprecated: ErrShortBody is no longer returned by
// anything in the net/http package. Callers should not
// compare errors against this variable.
ErrShortBody = &ProtocolError{"entity body too short"}
// Deprecated: ErrMissingContentLength is no longer returned by
// anything in the net/http package. Callers should not
// compare errors against this variable.
ErrMissingContentLength = &ProtocolError{"missing ContentLength in HEAD response"}
)
func badStringError(what, val string) error { return fmt.Errorf("%s %q", what, val) }
// Headers that Request.Write handles itself and should be skipped.
var reqWriteExcludeHeader = map[string]bool{
"Host": true, // not in Header map anyway
"User-Agent": true,
"Content-Length": true,
"Transfer-Encoding": true,
"Trailer": true,
}
// A Request represents an HTTP request received by a server
// or to be sent by a client.
//
// The field semantics differ slightly between client and server
// usage. In addition to the notes on the fields below, see the
// documentation for Request.Write and RoundTripper.
type Request struct {
// Method specifies the HTTP method (GET, POST, PUT, etc.).
// For client requests, an empty string means GET.
//
// Go's HTTP client does not support sending a request with
// the CONNECT method. See the documentation on Transport for
// details.
Method string
// URL specifies either the URI being requested (for server
// requests) or the URL to access (for client requests).
//
// For server requests, the URL is parsed from the URI
// supplied on the Request-Line as stored in RequestURI. For
// most requests, fields other than Path and RawQuery will be
// empty. (See RFC 7230, Section 5.3)
//
// For client requests, the URL's Host specifies the server to
// connect to, while the Request's Host field optionally
// specifies the Host header value to send in the HTTP
// request.
URL *url.URL
// The protocol version for incoming server requests.
//
// For client requests, these fields are ignored. The HTTP
// client code always uses either HTTP/1.1 or HTTP/2.
// See the docs on Transport for details.
Proto string // "HTTP/1.0"
ProtoMajor int // 1
ProtoMinor int // 0
// Header contains the request header fields either received
// by the server or to be sent by the client.
//
// If a server received a request with header lines,
//
// Host: example.com
// accept-encoding: gzip, deflate
// Accept-Language: en-us
// fOO: Bar
// foo: two
//
// then
//
// Header = map[string][]string{
// "Accept-Encoding": {"gzip, deflate"},
// "Accept-Language": {"en-us"},
// "Foo": {"Bar", "two"},
// }
//
// For incoming requests, the Host header is promoted to the
// Request.Host field and removed from the Header map.
//
// HTTP defines that header names are case-insensitive. The
// request parser implements this by using CanonicalHeaderKey,
// making the first character and any characters following a
// hyphen uppercase and the rest lowercase.
//
// For client requests, certain headers such as Content-Length
// and Connection are automatically written when needed and
// values in Header may be ignored. See the documentation
// for the Request.Write method.
Header Header
// Body is the request's body.
//
// For client requests, a nil body means the request has no
// body, such as a GET request. The HTTP Client's Transport
// is responsible for calling the Close method.
//
// For server requests, the Request Body is always non-nil
// but will return EOF immediately when no body is present.
// The Server will close the request body. The ServeHTTP
// Handler does not need to.
//
// Body must allow Read to be called concurrently with Close.
// In particular, calling Close should unblock a Read waiting
// for input.
Body io.ReadCloser
// GetBody defines an optional func to return a new copy of
// Body. It is used for client requests when a redirect requires
// reading the body more than once. Use of GetBody still
// requires setting Body.
//
// For server requests, it is unused.
GetBody func() (io.ReadCloser, error)
// ContentLength records the length of the associated content.
// The value -1 indicates that the length is unknown.
// Values >= 0 indicate that the given number of bytes may
// be read from Body.
//
// For client requests, a value of 0 with a non-nil Body is
// also treated as unknown.
ContentLength int64
// TransferEncoding lists the transfer encodings from outermost to
// innermost. An empty list denotes the "identity" encoding.
// TransferEncoding can usually be ignored; chunked encoding is
// automatically added and removed as necessary when sending and
// receiving requests.
TransferEncoding []string
// Close indicates whether to close the connection after
// replying to this request (for servers) or after sending this
// request and reading its response (for clients).
//
// For server requests, the HTTP server handles this automatically
// and this field is not needed by Handlers.
//
// For client requests, setting this field prevents re-use of
// TCP connections between requests to the same hosts, as if
// Transport.DisableKeepAlives were set.
Close bool
// For server requests, Host specifies the host on which the
// URL is sought. For HTTP/1 (per RFC 7230, section 5.4), this
// is either the value of the "Host" header or the host name
// given in the URL itself. For HTTP/2, it is the value of the
// ":authority" pseudo-header field.
// It may be of the form "host:port". For international domain
// names, Host may be in Punycode or Unicode form. Use
// golang.org/x/net/idna to convert it to either format if
// needed.
// To prevent DNS rebinding attacks, server Handlers should
// validate that the Host header has a value for which the
// Handler considers itself authoritative. The included
// ServeMux supports patterns registered to particular host
// names and thus protects its registered Handlers.
//
// For client requests, Host optionally overrides the Host
// header to send. If empty, the Request.Write method uses
// the value of URL.Host. Host may contain an international
// domain name.
Host string
// Form contains the parsed form data, including both the URL
// field's query parameters and the PATCH, POST, or PUT form data.
// This field is only available after ParseForm is called.
// The HTTP client ignores Form and uses Body instead.
Form url.Values
// PostForm contains the parsed form data from PATCH, POST
// or PUT body parameters.
//
// This field is only available after ParseForm is called.
// The HTTP client ignores PostForm and uses Body instead.
PostForm url.Values
// MultipartForm is the parsed multipart form, including file uploads.
// This field is only available after ParseMultipartForm is called.
// The HTTP client ignores MultipartForm and uses Body instead.
MultipartForm *multipart.Form
// Trailer specifies additional headers that are sent after the request
// body.
//
// For server requests, the Trailer map initially contains only the
// trailer keys, with nil values. (The client declares which trailers it
// will later send.) While the handler is reading from Body, it must
// not reference Trailer. After reading from Body returns EOF, Trailer
// can be read again and will contain non-nil values, if they were sent
// by the client.
//
// For client requests, Trailer must be initialized to a map containing
// the trailer keys to later send. The values may be nil or their final
// values. The ContentLength must be 0 or -1, to send a chunked request.
// After the HTTP request is sent the map values can be updated while
// the request body is read. Once the body returns EOF, the caller must
// not mutate Trailer.
//
// Few HTTP clients, servers, or proxies support HTTP trailers.
Trailer Header
// RemoteAddr allows HTTP servers and other software to record
// the network address that sent the request, usually for
// logging. This field is not filled in by ReadRequest and
// has no defined format. The HTTP server in this package
// sets RemoteAddr to an "IP:port" address before invoking a
// handler.
// This field is ignored by the HTTP client.
RemoteAddr string
// RequestURI is the unmodified request-target of the
// Request-Line (RFC 7230, Section 3.1.1) as sent by the client
// to a server. Usually the URL field should be used instead.
// It is an error to set this field in an HTTP client request.
RequestURI string
// TLS allows HTTP servers and other software to record
// information about the TLS connection on which the request
// was received. This field is not filled in by ReadRequest.
// The HTTP server in this package sets the field for
// TLS-enabled connections before invoking a handler;
// otherwise it leaves the field nil.
// This field is ignored by the HTTP client.
TLS *tls.ConnectionState
// Cancel is an optional channel whose closure indicates that the client
// request should be regarded as canceled. Not all implementations of
// RoundTripper may support Cancel.
//
// For server requests, this field is not applicable.
//
// Deprecated: Set the Request's context with NewRequestWithContext
// instead. If a Request's Cancel field and context are both
// set, it is undefined whether Cancel is respected.
Cancel <-chan struct{}
// Response is the redirect response which caused this request
// to be created. This field is only populated during client
// redirects.
Response *Response
// ctx is either the client or server context. It should only
// be modified via copying the whole Request using Clone or WithContext.
// It is unexported to prevent people from using Context wrong
// and mutating the contexts held by callers of the same request.
ctx context.Context
}
// Context returns the request's context. To change the context, use
// Clone or WithContext.
//
// The returned context is always non-nil; it defaults to the
// background context.
//
// For outgoing client requests, the context controls cancellation.
//
// For incoming server requests, the context is canceled when the
// client's connection closes, the request is canceled (with HTTP/2),
// or when the ServeHTTP method returns.
func (r *Request) Context() context.Context {
if r.ctx != nil {
return r.ctx
}
return context.Background()
}
// WithContext returns a shallow copy of r with its context changed
// to ctx. The provided ctx must be non-nil.
//
// For outgoing client request, the context controls the entire
// lifetime of a request and its response: obtaining a connection,
// sending the request, and reading the response headers and body.
//
// To create a new request with a context, use NewRequestWithContext.
// To make a deep copy of a request with a new context, use Request.Clone.
func (r *Request) WithContext(ctx context.Context) *Request {
if ctx == nil {
panic("nil context")
}
r2 := new(Request)
*r2 = *r
r2.ctx = ctx
return r2
}
// Clone returns a deep copy of r with its context changed to ctx.
// The provided ctx must be non-nil.
//
// For an outgoing client request, the context controls the entire
// lifetime of a request and its response: obtaining a connection,
// sending the request, and reading the response headers and body.
func (r *Request) Clone(ctx context.Context) *Request {
if ctx == nil {
panic("nil context")
}
r2 := new(Request)
*r2 = *r
r2.ctx = ctx
r2.URL = cloneURL(r.URL)
if r.Header != nil {
r2.Header = r.Header.Clone()
}
if r.Trailer != nil {
r2.Trailer = r.Trailer.Clone()
}
if s := r.TransferEncoding; s != nil {
s2 := make([]string, len(s))
copy(s2, s)
r2.TransferEncoding = s2
}
r2.Form = cloneURLValues(r.Form)
r2.PostForm = cloneURLValues(r.PostForm)
r2.MultipartForm = cloneMultipartForm(r.MultipartForm)
return r2
}
// ProtoAtLeast reports whether the HTTP protocol used
// in the request is at least major.minor.
func (r *Request) ProtoAtLeast(major, minor int) bool {
return r.ProtoMajor > major ||
r.ProtoMajor == major && r.ProtoMinor >= minor
}
// UserAgent returns the client's User-Agent, if sent in the request.
func (r *Request) UserAgent() string {
return r.Header.Get("User-Agent")
}
// Cookies parses and returns the HTTP cookies sent with the request.
func (r *Request) Cookies() []*Cookie {
return readCookies(r.Header, "")
}
// ErrNoCookie is returned by Request's Cookie method when a cookie is not found.
var ErrNoCookie = errors.New("http: named cookie not present")
// Cookie returns the named cookie provided in the request or
// ErrNoCookie if not found.
// If multiple cookies match the given name, only one cookie will
// be returned.
func (r *Request) Cookie(name string) (*Cookie, error) {
if name == "" {
return nil, ErrNoCookie
}
for _, c := range readCookies(r.Header, name) {
return c, nil
}
return nil, ErrNoCookie
}
// AddCookie adds a cookie to the request. Per RFC 6265 section 5.4,
// AddCookie does not attach more than one Cookie header field. That
// means all cookies, if any, are written into the same line,
// separated by semicolons.
// AddCookie only sanitizes c's name and value, and does not sanitize
// a Cookie header already present in the request.
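//
// For example, adding two cookies in turn results in a single header
// line (the names and values here are illustrative):
//
//	Cookie: session=abc123; theme=dark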
func (r *Request) AddCookie(c *Cookie) {
s := fmt.Sprintf("%s=%s", sanitizeCookieName(c.Name), sanitizeCookieValue(c.Value))
if c := r.Header.Get("Cookie"); c != "" {
r.Header.Set("Cookie", c+"; "+s)
} else {
r.Header.Set("Cookie", s)
}
}
// Referer returns the referring URL, if sent in the request.
//
// Referer is misspelled as in the request itself, a mistake from the
// earliest days of HTTP. This value can also be fetched from the
// Header map as Header["Referer"]; the benefit of making it available
// as a method is that the compiler can diagnose programs that use the
// alternate (correct English) spelling req.Referrer() but cannot
// diagnose programs that use Header["Referrer"].
func (r *Request) Referer() string {
return r.Header.Get("Referer")
}
// multipartByReader is a sentinel value.
// Its presence in Request.MultipartForm indicates that parsing of the request
// body has been handed off to a MultipartReader instead of ParseMultipartForm.
var multipartByReader = &multipart.Form{
Value: make(map[string][]string),
File: make(map[string][]*multipart.FileHeader),
}
// MultipartReader returns a MIME multipart reader if this is a
// multipart/form-data or a multipart/mixed POST request, else returns nil and an error.
// Use this function instead of ParseMultipartForm to
// process the request body as a stream.
func (r *Request) MultipartReader() (*multipart.Reader, error) {
if r.MultipartForm == multipartByReader {
return nil, errors.New("http: MultipartReader called twice")
}
if r.MultipartForm != nil {
return nil, errors.New("http: multipart handled by ParseMultipartForm")
}
r.MultipartForm = multipartByReader
return r.multipartReader(true)
}
func (r *Request) multipartReader(allowMixed bool) (*multipart.Reader, error) {
v := r.Header.Get("Content-Type")
if v == "" {
return nil, ErrNotMultipart
}
if r.Body == nil {
return nil, errors.New("missing form body")
}
d, params, err := mime.ParseMediaType(v)
if err != nil || !(d == "multipart/form-data" || allowMixed && d == "multipart/mixed") {
return nil, ErrNotMultipart
}
boundary, ok := params["boundary"]
if !ok {
return nil, ErrMissingBoundary
}
return multipart.NewReader(r.Body, boundary), nil
}
// isH2Upgrade reports whether r represents the http2 "client preface"
// magic string.
func (r *Request) isH2Upgrade() bool {
return r.Method == "PRI" && len(r.Header) == 0 && r.URL.Path == "*" && r.Proto == "HTTP/2.0"
}
// Return value if nonempty, def otherwise.
func valueOrDefault(value, def string) string {
if value != "" {
return value
}
return def
}
// NOTE: This is not intended to reflect the actual Go version being used.
// It was changed at the time of Go 1.1 release because the former User-Agent
// had ended up blocked by some intrusion detection systems.
// See https://codereview.appspot.com/7532043.
const defaultUserAgent = "Go-http-client/1.1"
// Write writes an HTTP/1.1 request, which is the header and body, in wire format.
// This method consults the following fields of the request:
//
// Host
// URL
// Method (defaults to "GET")
// Header
// ContentLength
// TransferEncoding
// Body
//
// If Body is present, Content-Length is <= 0 and TransferEncoding
// hasn't been set to "identity", Write adds "Transfer-Encoding:
// chunked" to the header. Body is closed after it is sent.
func (r *Request) Write(w io.Writer) error {
return r.write(w, false, nil, nil)
}
// WriteProxy is like Write but writes the request in the form
// expected by an HTTP proxy. In particular, WriteProxy writes the
// initial Request-URI line of the request with an absolute URI, per
// section 5.3 of RFC 7230, including the scheme and host.
// In either case, WriteProxy also writes a Host header, using
// either r.Host or r.URL.Host.
func (r *Request) WriteProxy(w io.Writer) error {
return r.write(w, true, nil, nil)
}
// errMissingHost is returned by Write when there is no Host or URL present in
// the Request.
var errMissingHost = errors.New("http: Request.Write on Request with no Host or URL set")
// extraHeaders may be nil
// waitForContinue may be nil
// always closes body
func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitForContinue func() bool) (err error) {
trace := httptrace.ContextClientTrace(r.Context())
if trace != nil && trace.WroteRequest != nil {
defer func() {
trace.WroteRequest(httptrace.WroteRequestInfo{
Err: err,
})
}()
}
closed := false
defer func() {
if closed {
return
}
if closeErr := r.closeBody(); closeErr != nil && err == nil {
err = closeErr
}
}()
// Find the target host. Prefer the Host: header, but if that
// is not given, use the host from the request URL.
//
// Clean the host, in case it arrives with unexpected stuff in it.
host := cleanHost(r.Host)
if host == "" {
if r.URL == nil {
return errMissingHost
}
host = cleanHost(r.URL.Host)
}
// According to RFC 6874, an HTTP client, proxy, or other
// intermediary must remove any IPv6 zone identifier attached
// to an outgoing URI.
host = removeZone(host)
ruri := r.URL.RequestURI()
if usingProxy && r.URL.Scheme != "" && r.URL.Opaque == "" {
ruri = r.URL.Scheme + "://" + host + ruri
} else if r.Method == "CONNECT" && r.URL.Path == "" {
// CONNECT requests normally give just the host and port, not a full URL.
ruri = host
if r.URL.Opaque != "" {
ruri = r.URL.Opaque
}
}
if stringContainsCTLByte(ruri) {
return errors.New("net/http: can't write control character in Request.URL")
}
// TODO: validate r.Method too? At least it's less likely to
// come from an attacker (more likely to be a constant in
// code).
// Wrap the writer in a bufio Writer if it's not already buffered.
// Don't always call NewWriter, as that forces a bytes.Buffer
// and other small bufio Writers to have a minimum 4k buffer
// size.
var bw *bufio.Writer
if _, ok := w.(io.ByteWriter); !ok {
bw = bufio.NewWriter(w)
w = bw
}
_, err = fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(r.Method, "GET"), ruri)
if err != nil {
return err
}
// Header lines
_, err = fmt.Fprintf(w, "Host: %s\r\n", host)
if err != nil {
return err
}
if trace != nil && trace.WroteHeaderField != nil {
trace.WroteHeaderField("Host", []string{host})
}
// Use the defaultUserAgent unless the Header contains one, which
// may be blank to suppress sending the header.
userAgent := defaultUserAgent
if r.Header.has("User-Agent") {
userAgent = r.Header.Get("User-Agent")
}
if userAgent != "" {
_, err = fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent)
if err != nil {
return err
}
if trace != nil && trace.WroteHeaderField != nil {
trace.WroteHeaderField("User-Agent", []string{userAgent})
}
}
// Process Body, ContentLength, Close, Trailer.
tw, err := newTransferWriter(r)
if err != nil {
return err
}
err = tw.writeHeader(w, trace)
if err != nil {
return err
}
err = r.Header.writeSubset(w, reqWriteExcludeHeader, trace)
if err != nil {
return err
}
if extraHeaders != nil {
err = extraHeaders.write(w, trace)
if err != nil {
return err
}
}
_, err = io.WriteString(w, "\r\n")
if err != nil {
return err
}
if trace != nil && trace.WroteHeaders != nil {
trace.WroteHeaders()
}
// Flush and wait for 100-continue if expected.
if waitForContinue != nil {
if bw, ok := w.(*bufio.Writer); ok {
err = bw.Flush()
if err != nil {
return err
}
}
if trace != nil && trace.Wait100Continue != nil {
trace.Wait100Continue()
}
if !waitForContinue() {
closed = true
r.closeBody()
return nil
}
}
if bw, ok := w.(*bufio.Writer); ok && tw.FlushHeaders {
if err := bw.Flush(); err != nil {
return err
}
}
// Write body and trailer
closed = true
err = tw.writeBody(w)
if err != nil {
if tw.bodyReadError == err {
err = requestBodyReadError{err}
}
return err
}
if bw != nil {
return bw.Flush()
}
return nil
}
// requestBodyReadError wraps an error from (*Request).write to indicate
// that the error came from a Read call on the Request.Body.
// This error type should not escape the net/http package to users.
type requestBodyReadError struct{ error }
func idnaASCII(v string) (string, error) {
// TODO: Consider removing this check after verifying performance is okay.
// Right now punycode verification, length checks, context checks, and the
// permissible character tests are all omitted. It also prevents the ToASCII
// call from salvaging an invalid IDN, when possible. As a result it may be
// possible to have two IDNs that appear identical to the user where the
// ASCII-only version causes an error downstream whereas the non-ASCII
// version does not.
// Note that for correct ASCII IDNs, ToASCII only does considerably more
// work; it will not cause an allocation.
if ascii.Is(v) {
return v, nil
}
return idna.Lookup.ToASCII(v)
}
// cleanHost cleans up the host sent in request's Host header.
//
// It both strips anything after '/' or ' ', and puts the value
// into Punycode form, if necessary.
//
// Ideally we'd clean the Host header according to the spec:
//
// https://tools.ietf.org/html/rfc7230#section-5.4 (Host = uri-host [ ":" port ]")
// https://tools.ietf.org/html/rfc7230#section-2.7 (uri-host -> rfc3986's host)
// https://tools.ietf.org/html/rfc3986#section-3.2.2 (definition of host)
//
// But practically, what we are trying to avoid is the situation in
// issue 11206, where a malformed Host header used in the proxy context
// would create a bad request. So it is enough to just truncate at the
// first offending character.
func cleanHost(in string) string {
if i := strings.IndexAny(in, " /"); i != -1 {
in = in[:i]
}
host, port, err := net.SplitHostPort(in)
if err != nil { // input was just a host
a, err := idnaASCII(in)
if err != nil {
return in // garbage in, garbage out
}
return a
}
a, err := idnaASCII(host)
if err != nil {
return in // garbage in, garbage out
}
return net.JoinHostPort(a, port)
}
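// Illustrative inputs and outputs (not exercised in this file): cleanHost
// truncates at the first offending byte and, when needed, converts an
// international name to Punycode.
//
//	cleanHost("example.com/etc") // "example.com"
//	cleanHost("bücher.example")  // "xn--bcher-kva.example"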
// removeZone removes IPv6 zone identifier from host.
// E.g., "[fe80::1%en0]:8080" to "[fe80::1]:8080"
func removeZone(host string) string {
if !strings.HasPrefix(host, "[") {
return host
}
i := strings.LastIndex(host, "]")
if i < 0 {
return host
}
j := strings.LastIndex(host[:i], "%")
if j < 0 {
return host
}
return host[:j] + host[i:]
}
// ParseHTTPVersion parses an HTTP version string according to RFC 7230, section 2.6.
// "HTTP/1.0" returns (1, 0, true). Note that strings without
// a minor version, such as "HTTP/2", are not valid.
func ParseHTTPVersion(vers string) (major, minor int, ok bool) {
switch vers {
case "HTTP/1.1":
return 1, 1, true
case "HTTP/1.0":
return 1, 0, true
}
if !strings.HasPrefix(vers, "HTTP/") {
return 0, 0, false
}
if len(vers) != len("HTTP/X.Y") {
return 0, 0, false
}
if vers[6] != '.' {
return 0, 0, false
}
maj, err := strconv.ParseUint(vers[5:6], 10, 0)
if err != nil {
return 0, 0, false
}
min, err := strconv.ParseUint(vers[7:8], 10, 0)
if err != nil {
return 0, 0, false
}
return int(maj), int(min), true
}
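// For example (values shown for illustration):
//
//	major, minor, ok := ParseHTTPVersion("HTTP/1.1") // 1, 1, true
//	_, _, ok = ParseHTTPVersion("HTTP/2")            // 0, 0, false: no minor version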
func validMethod(method string) bool {
/*
Method = "OPTIONS" ; Section 9.2
| "GET" ; Section 9.3
| "HEAD" ; Section 9.4
| "POST" ; Section 9.5
| "PUT" ; Section 9.6
| "DELETE" ; Section 9.7
| "TRACE" ; Section 9.8
| "CONNECT" ; Section 9.9
| extension-method
extension-method = token
token = 1*<any CHAR except CTLs or separators>
*/
return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
}
// NewRequest wraps NewRequestWithContext using context.Background.
func NewRequest(method, url string, body io.Reader) (*Request, error) {
return NewRequestWithContext(context.Background(), method, url, body)
}
// NewRequestWithContext returns a new Request given a method, URL, and
// optional body.
//
// If the provided body is also an io.Closer, the returned
// Request.Body is set to body and will be closed by the Client
// methods Do, Post, and PostForm, and Transport.RoundTrip.
//
// NewRequestWithContext returns a Request suitable for use with
// Client.Do or Transport.RoundTrip. To create a request for use with
// testing a Server Handler, either use the NewRequest function in the
// net/http/httptest package, use ReadRequest, or manually update the
// Request fields. For an outgoing client request, the context
// controls the entire lifetime of a request and its response:
// obtaining a connection, sending the request, and reading the
// response headers and body. See the Request type's documentation for
// the difference between inbound and outbound request fields.
//
// If body is of type *bytes.Buffer, *bytes.Reader, or
// *strings.Reader, the returned request's ContentLength is set to its
// exact value (instead of -1), GetBody is populated (so 307 and 308
// redirects can replay the body), and Body is set to NoBody if the
// ContentLength is 0.
func NewRequestWithContext(ctx context.Context, method, url string, body io.Reader) (*Request, error) {
if method == "" {
// We document that "" means "GET" for Request.Method, and people have
// relied on that from NewRequest, so keep that working.
// We still enforce validMethod for non-empty methods.
method = "GET"
}
if !validMethod(method) {
return nil, fmt.Errorf("net/http: invalid method %q", method)
}
if ctx == nil {
return nil, errors.New("net/http: nil Context")
}
u, err := urlpkg.Parse(url)
if err != nil {
return nil, err
}
rc, ok := body.(io.ReadCloser)
if !ok && body != nil {
rc = io.NopCloser(body)
}
// The host's colon:port should be normalized. See Issue 14836.
u.Host = removeEmptyPort(u.Host)
req := &Request{
ctx: ctx,
Method: method,
URL: u,
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: make(Header),
Body: rc,
Host: u.Host,
}
if body != nil {
switch v := body.(type) {
case *bytes.Buffer:
req.ContentLength = int64(v.Len())
buf := v.Bytes()
req.GetBody = func() (io.ReadCloser, error) {
r := bytes.NewReader(buf)
return io.NopCloser(r), nil
}
case *bytes.Reader:
req.ContentLength = int64(v.Len())
snapshot := *v
req.GetBody = func() (io.ReadCloser, error) {
r := snapshot
return io.NopCloser(&r), nil
}
case *strings.Reader:
req.ContentLength = int64(v.Len())
snapshot := *v
req.GetBody = func() (io.ReadCloser, error) {
r := snapshot
return io.NopCloser(&r), nil
}
default:
// This is where we'd set it to -1 (at least
// if body != NoBody) to mean unknown, but
// that broke people during the Go 1.8 testing
// period. People depend on it being 0 I
// guess. Maybe retry later. See Issue 18117.
}
// For client requests, Request.ContentLength of 0
// means either actually 0, or unknown. The only way
// to explicitly say that the ContentLength is zero is
// to set the Body to nil. But it turns out too much code
// depends on NewRequest returning a non-nil Body,
// so we use a well-known ReadCloser variable instead
// and have the http package also treat that sentinel
// variable to mean explicitly zero.
if req.GetBody != nil && req.ContentLength == 0 {
req.Body = NoBody
req.GetBody = func() (io.ReadCloser, error) { return NoBody, nil }
}
}
return req, nil
}
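// A small sketch of the special-cased body types (hypothetical values):
// with a *strings.Reader body, ContentLength is known exactly and GetBody
// lets redirects replay the body.
//
//	req, _ := NewRequest("POST", "http://example.com/", strings.NewReader("abc"))
//	// req.ContentLength == 3 and req.GetBody != nil
//	rc, _ := req.GetBody() // a fresh reader over the same bytes
//	_ = rc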
// BasicAuth returns the username and password provided in the request's
// Authorization header, if the request uses HTTP Basic Authentication.
// See RFC 2617, Section 2.
func (r *Request) BasicAuth() (username, password string, ok bool) {
auth := r.Header.Get("Authorization")
if auth == "" {
return "", "", false
}
return parseBasicAuth(auth)
}
// parseBasicAuth parses an HTTP Basic Authentication string.
// "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==" returns ("Aladdin", "open sesame", true).
func parseBasicAuth(auth string) (username, password string, ok bool) {
const prefix = "Basic "
// Case insensitive prefix match. See Issue 22736.
if len(auth) < len(prefix) || !ascii.EqualFold(auth[:len(prefix)], prefix) {
return "", "", false
}
c, err := base64.StdEncoding.DecodeString(auth[len(prefix):])
if err != nil {
return "", "", false
}
cs := string(c)
username, password, ok = strings.Cut(cs, ":")
if !ok {
return "", "", false
}
return username, password, true
}
// SetBasicAuth sets the request's Authorization header to use HTTP
// Basic Authentication with the provided username and password.
//
// With HTTP Basic Authentication the provided username and password
// are not encrypted. It should generally only be used in an HTTPS
// request.
//
// The username may not contain a colon. Some protocols may impose
// additional requirements on pre-escaping the username and
// password. For instance, when used with OAuth2, both arguments must
// be URL encoded first with url.QueryEscape.
func (r *Request) SetBasicAuth(username, password string) {
r.Header.Set("Authorization", "Basic "+basicAuth(username, password))
}
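// For illustration (made-up credentials), SetBasicAuth base64-encodes
// "username:password", matching the example in parseBasicAuth above:
//
//	req, _ := NewRequest("GET", "http://example.com/", nil)
//	req.SetBasicAuth("Aladdin", "open sesame")
//	// req.Header.Get("Authorization") == "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="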
// parseRequestLine parses "GET /foo HTTP/1.1" into its three parts.
func parseRequestLine(line string) (method, requestURI, proto string, ok bool) {
method, rest, ok1 := strings.Cut(line, " ")
requestURI, proto, ok2 := strings.Cut(rest, " ")
if !ok1 || !ok2 {
return "", "", "", false
}
return method, requestURI, proto, true
}
var textprotoReaderPool sync.Pool
func newTextprotoReader(br *bufio.Reader) *textproto.Reader {
if v := textprotoReaderPool.Get(); v != nil {
tr := v.(*textproto.Reader)
tr.R = br
return tr
}
return textproto.NewReader(br)
}
func putTextprotoReader(r *textproto.Reader) {
r.R = nil
textprotoReaderPool.Put(r)
}
// ReadRequest reads and parses an incoming request from b.
//
// ReadRequest is a low-level function and should only be used for
// specialized applications; most code should use the Server to read
// requests and handle them via the Handler interface. ReadRequest
// only supports HTTP/1.x requests. For HTTP/2, use golang.org/x/net/http2.
func ReadRequest(b *bufio.Reader) (*Request, error) {
req, err := readRequest(b)
if err != nil {
return nil, err
}
delete(req.Header, "Host")
return req, err
}
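// A minimal sketch of using ReadRequest directly (hypothetical raw bytes):
//
//	raw := "GET /index.html HTTP/1.1\r\nHost: example.com\r\n\r\n"
//	req, err := ReadRequest(bufio.NewReader(strings.NewReader(raw)))
//	if err == nil {
//		// req.Method == "GET" and req.Host == "example.com"
//	}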
func readRequest(b *bufio.Reader) (req *Request, err error) {
tp := newTextprotoReader(b)
defer putTextprotoReader(tp)
req = new(Request)
// First line: GET /index.html HTTP/1.0
var s string
if s, err = tp.ReadLine(); err != nil {
return nil, err
}
defer func() {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
}()
var ok bool
req.Method, req.RequestURI, req.Proto, ok = parseRequestLine(s)
if !ok {
return nil, badStringError("malformed HTTP request", s)
}
if !validMethod(req.Method) {
return nil, badStringError("invalid method", req.Method)
}
rawurl := req.RequestURI
if req.ProtoMajor, req.ProtoMinor, ok = ParseHTTPVersion(req.Proto); !ok {
return nil, badStringError("malformed HTTP version", req.Proto)
}
// CONNECT requests are used two different ways, and neither uses a full URL:
// The standard use is to tunnel HTTPS through an HTTP proxy.
// It looks like "CONNECT www.google.com:443 HTTP/1.1", and the parameter is
// just the authority section of a URL. This information should go in req.URL.Host.
//
// The net/rpc package also uses CONNECT, but there the parameter is a path
// that starts with a slash. It can be parsed with the regular URL parser,
// and the path will end up in req.URL.Path, where it needs to be in order for
// RPC to work.
justAuthority := req.Method == "CONNECT" && !strings.HasPrefix(rawurl, "/")
if justAuthority {
rawurl = "http://" + rawurl
}
if req.URL, err = url.ParseRequestURI(rawurl); err != nil {
return nil, err
}
if justAuthority {
// Strip the bogus "http://" back off.
req.URL.Scheme = ""
}
// Subsequent lines: Key: value.
mimeHeader, err := tp.ReadMIMEHeader()
if err != nil {
return nil, err
}
req.Header = Header(mimeHeader)
if len(req.Header["Host"]) > 1 {
return nil, fmt.Errorf("too many Host headers")
}
// RFC 7230, section 5.3: Must treat
// GET /index.html HTTP/1.1
// Host: www.google.com
// and
// GET http://www.google.com/index.html HTTP/1.1
// Host: doesntmatter
// the same. In the second case, any Host line is ignored.
req.Host = req.URL.Host
if req.Host == "" {
req.Host = req.Header.get("Host")
}
fixPragmaCacheControl(req.Header)
req.Close = shouldClose(req.ProtoMajor, req.ProtoMinor, req.Header, false)
err = readTransfer(req, b)
if err != nil {
return nil, err
}
if req.isH2Upgrade() {
// Because it's neither chunked, nor declared:
req.ContentLength = -1
// We want to give handlers a chance to hijack the
// connection, but we need to prevent the Server from
// dealing with the connection further if it's not
// hijacked. Set Close to ensure that:
req.Close = true
}
return req, nil
}
// MaxBytesReader is similar to io.LimitReader but is intended for
// limiting the size of incoming request bodies. In contrast to
// io.LimitReader, MaxBytesReader's result is a ReadCloser, returns a
// non-nil error of type *MaxBytesError for a Read beyond the limit,
// and closes the underlying reader when its Close method is called.
//
// MaxBytesReader prevents clients from accidentally or maliciously
// sending a large request and wasting server resources. If possible,
// it tells the ResponseWriter to close the connection after the limit
// has been reached.
func MaxBytesReader(w ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser {
if n < 0 { // Treat negative limits as equivalent to 0.
n = 0
}
return &maxBytesReader{w: w, r: r, i: n, n: n}
}
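// A sketch of typical use inside a handler (the 1 MB limit is illustrative):
// wrap the body, then detect the limit with errors.As on the read error.
//
//	r.Body = MaxBytesReader(w, r.Body, 1<<20)
//	body, err := io.ReadAll(r.Body)
//	var mbe *MaxBytesError
//	if errors.As(err, &mbe) {
//		Error(w, "request body too large", StatusRequestEntityTooLarge)
//		return
//	}
//	_ = body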
// MaxBytesError is returned by MaxBytesReader when its read limit is exceeded.
type MaxBytesError struct {
Limit int64
}
func (e *MaxBytesError) Error() string {
// Due to Hyrum's law, this text cannot be changed.
return "http: request body too large"
}
type maxBytesReader struct {
w ResponseWriter
r io.ReadCloser // underlying reader
i int64 // max bytes initially, for MaxBytesError
n int64 // max bytes remaining
err error // sticky error
}
func (l *maxBytesReader) Read(p []byte) (n int, err error) {
if l.err != nil {
return 0, l.err
}
if len(p) == 0 {
return 0, nil
}
// If they asked for a 32KB read but only 5 bytes are
// remaining, no need to read 32KB. 6 bytes will answer the
// question of whether we hit the limit or go past it.
// 0 < len(p) < 2^63
if int64(len(p))-1 > l.n {
p = p[:l.n+1]
}
n, err = l.r.Read(p)
if int64(n) <= l.n {
l.n -= int64(n)
l.err = err
return n, err
}
n = int(l.n)
l.n = 0
// The server code and client code both use
// maxBytesReader. This "requestTooLarge" check is
// only used by the server code. To prevent binaries
// that use only the HTTP Client code (such as
// cmd/go) from also linking in the HTTP server, don't
// use a static type assertion to the server
// "*response" type. Check this interface instead:
type requestTooLarger interface {
requestTooLarge()
}
if res, ok := l.w.(requestTooLarger); ok {
res.requestTooLarge()
}
l.err = &MaxBytesError{l.i}
return n, l.err
}
func (l *maxBytesReader) Close() error {
return l.r.Close()
}
func copyValues(dst, src url.Values) {
for k, vs := range src {
dst[k] = append(dst[k], vs...)
}
}
func parsePostForm(r *Request) (vs url.Values, err error) {
if r.Body == nil {
err = errors.New("missing form body")
return
}
ct := r.Header.Get("Content-Type")
// RFC 7231, section 3.1.1.5 - empty type
// MAY be treated as application/octet-stream
if ct == "" {
ct = "application/octet-stream"
}
ct, _, err = mime.ParseMediaType(ct)
switch {
case ct == "application/x-www-form-urlencoded":
var reader io.Reader = r.Body
maxFormSize := int64(1<<63 - 1)
if _, ok := r.Body.(*maxBytesReader); !ok {
maxFormSize = int64(10 << 20) // 10 MB is a lot of text.
reader = io.LimitReader(r.Body, maxFormSize+1)
}
b, e := io.ReadAll(reader)
if e != nil {
if err == nil {
err = e
}
break
}
if int64(len(b)) > maxFormSize {
err = errors.New("http: POST too large")
return
}
vs, e = url.ParseQuery(string(b))
if err == nil {
err = e
}
case ct == "multipart/form-data":
// handled by ParseMultipartForm (which is calling us, or should be)
// TODO(bradfitz): there are too many possible
// orders to call too many functions here.
// Clean this up and write more tests.
// request_test.go contains the start of this,
// in TestParseMultipartFormOrder and others.
}
return
}
// ParseForm populates r.Form and r.PostForm.
//
// For all requests, ParseForm parses the raw query from the URL and updates
// r.Form.
//
// For POST, PUT, and PATCH requests, it also reads the request body, parses it
// as a form and puts the results into both r.PostForm and r.Form. Request body
// parameters take precedence over URL query string values in r.Form.
//
// If the request Body's size has not already been limited by MaxBytesReader,
// the size is capped at 10MB.
//
// For other HTTP methods, or when the Content-Type is not
// application/x-www-form-urlencoded, the request Body is not read, and
// r.PostForm is initialized to a non-nil, empty value.
//
// ParseMultipartForm calls ParseForm automatically.
// ParseForm is idempotent.
func (r *Request) ParseForm() error {
var err error
if r.PostForm == nil {
if r.Method == "POST" || r.Method == "PUT" || r.Method == "PATCH" {
r.PostForm, err = parsePostForm(r)
}
if r.PostForm == nil {
r.PostForm = make(url.Values)
}
}
if r.Form == nil {
if len(r.PostForm) > 0 {
r.Form = make(url.Values)
copyValues(r.Form, r.PostForm)
}
var newValues url.Values
if r.URL != nil {
var e error
newValues, e = url.ParseQuery(r.URL.RawQuery)
if err == nil {
err = e
}
}
if newValues == nil {
newValues = make(url.Values)
}
if r.Form == nil {
r.Form = newValues
} else {
copyValues(r.Form, newValues)
}
}
return err
}
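// For example (field names are illustrative), after ParseForm both sources
// are merged into r.Form, with body parameters first:
//
//	// POST /search?q=url with body "q=body"
//	// (Content-Type: application/x-www-form-urlencoded)
//	if err := r.ParseForm(); err == nil {
//		_ = r.Form["q"]     // []string{"body", "url"}
//		_ = r.PostForm["q"] // []string{"body"}
//	}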
// ParseMultipartForm parses a request body as multipart/form-data.
// The whole request body is parsed and up to a total of maxMemory bytes of
// its file parts are stored in memory, with the remainder stored on
// disk in temporary files.
// ParseMultipartForm calls ParseForm if necessary.
// If ParseForm returns an error, ParseMultipartForm returns it but also
// continues parsing the request body.
// After one call to ParseMultipartForm, subsequent calls have no effect.
func (r *Request) ParseMultipartForm(maxMemory int64) error {
if r.MultipartForm == multipartByReader {
return errors.New("http: multipart handled by MultipartReader")
}
var parseFormErr error
if r.Form == nil {
// Let an error from ParseForm fall through, and just
// return it at the end.
parseFormErr = r.ParseForm()
}
if r.MultipartForm != nil {
return nil
}
mr, err := r.multipartReader(false)
if err != nil {
return err
}
f, err := mr.ReadForm(maxMemory)
if err != nil {
return err
}
if r.PostForm == nil {
r.PostForm = make(url.Values)
}
for k, v := range f.Value {
r.Form[k] = append(r.Form[k], v...)
// r.PostForm should also be populated. See Issue 9305.
r.PostForm[k] = append(r.PostForm[k], v...)
}
r.MultipartForm = f
return parseFormErr
}
// FormValue returns the first value for the named component of the query.
// POST and PUT body parameters take precedence over URL query string values.
// FormValue calls ParseMultipartForm and ParseForm if necessary and ignores
// any errors returned by these functions.
// If key is not present, FormValue returns the empty string.
// To access multiple values of the same key, call ParseForm and
// then inspect Request.Form directly.
func (r *Request) FormValue(key string) string {
if r.Form == nil {
r.ParseMultipartForm(defaultMaxMemory)
}
if vs := r.Form[key]; len(vs) > 0 {
return vs[0]
}
return ""
}
// PostFormValue returns the first value for the named component of the POST,
// PATCH, or PUT request body. URL query parameters are ignored.
// PostFormValue calls ParseMultipartForm and ParseForm if necessary and ignores
// any errors returned by these functions.
// If key is not present, PostFormValue returns the empty string.
func (r *Request) PostFormValue(key string) string {
if r.PostForm == nil {
r.ParseMultipartForm(defaultMaxMemory)
}
if vs := r.PostForm[key]; len(vs) > 0 {
return vs[0]
}
return ""
}
// FormFile returns the first file for the provided form key.
// FormFile calls ParseMultipartForm and ParseForm if necessary.
func (r *Request) FormFile(key string) (multipart.File, *multipart.FileHeader, error) {
if r.MultipartForm == multipartByReader {
return nil, nil, errors.New("http: multipart handled by MultipartReader")
}
if r.MultipartForm == nil {
err := r.ParseMultipartForm(defaultMaxMemory)
if err != nil {
return nil, nil, err
}
}
if r.MultipartForm != nil && r.MultipartForm.File != nil {
if fhs := r.MultipartForm.File[key]; len(fhs) > 0 {
f, err := fhs[0].Open()
return f, fhs[0], err
}
}
return nil, nil, ErrMissingFile
}
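// A sketch of receiving an upload (the form field name is hypothetical):
//
//	file, header, err := r.FormFile("avatar")
//	if err != nil {
//		Error(w, "bad upload", StatusBadRequest)
//		return
//	}
//	defer file.Close()
//	// header.Filename and header.Size describe the uploaded file.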
func (r *Request) expectsContinue() bool {
return hasToken(r.Header.get("Expect"), "100-continue")
}
func (r *Request) wantsHttp10KeepAlive() bool {
if r.ProtoMajor != 1 || r.ProtoMinor != 0 {
return false
}
return hasToken(r.Header.get("Connection"), "keep-alive")
}
func (r *Request) wantsClose() bool {
if r.Close {
return true
}
return hasToken(r.Header.get("Connection"), "close")
}
func (r *Request) closeBody() error {
if r.Body == nil {
return nil
}
return r.Body.Close()
}
func (r *Request) isReplayable() bool {
if r.Body == nil || r.Body == NoBody || r.GetBody != nil {
switch valueOrDefault(r.Method, "GET") {
case "GET", "HEAD", "OPTIONS", "TRACE":
return true
}
// The Idempotency-Key, while non-standard, is widely used to
// mean a POST or other request is idempotent. See
// https://golang.org/issue/19943#issuecomment-421092421
if r.Header.has("Idempotency-Key") || r.Header.has("X-Idempotency-Key") {
return true
}
}
return false
}
// outgoingLength reports the Content-Length of this outgoing (Client) request.
// It maps 0 into -1 (unknown) when the Body is non-nil.
func (r *Request) outgoingLength() int64 {
if r.Body == nil || r.Body == NoBody {
return 0
}
if r.ContentLength != 0 {
return r.ContentLength
}
return -1
}
// requestMethodUsuallyLacksBody reports whether the given request
// method is one that typically does not involve a request body.
// This is used by the Transport (via
// transferWriter.shouldSendChunkedRequestBody) to determine whether
// we try to test-read a byte from a non-nil Request.Body when
// Request.outgoingLength() returns -1. See the comments in
// shouldSendChunkedRequestBody.
func requestMethodUsuallyLacksBody(method string) bool {
switch method {
case "GET", "HEAD", "DELETE", "OPTIONS", "PROPFIND", "SEARCH":
return true
}
return false
}
// requiresHTTP1 reports whether this request requires being sent on
// an HTTP/1 connection.
func (r *Request) requiresHTTP1() bool {
return hasToken(r.Header.Get("Connection"), "upgrade") &&
ascii.EqualFold(r.Header.Get("Upgrade"), "websocket")
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// HTTP Response reading and parsing.
package http
import (
"bufio"
"bytes"
"crypto/tls"
"errors"
"fmt"
"io"
"net/textproto"
"net/url"
"strconv"
"strings"
"golang.org/x/net/http/httpguts"
)
var respExcludeHeader = map[string]bool{
"Content-Length": true,
"Transfer-Encoding": true,
"Trailer": true,
}
// Response represents the response from an HTTP request.
//
// The Client and Transport return Responses from servers once
// the response headers have been received. The response body
// is streamed on demand as the Body field is read.
type Response struct {
Status string // e.g. "200 OK"
StatusCode int // e.g. 200
Proto string // e.g. "HTTP/1.0"
ProtoMajor int // e.g. 1
ProtoMinor int // e.g. 0
// Header maps header keys to values. If the response had multiple
// headers with the same key, they may be concatenated, with comma
// delimiters. (RFC 7230, section 3.2.2 requires that multiple headers
// be semantically equivalent to a comma-delimited sequence.) When
// Header values are duplicated by other fields in this struct (e.g.,
// ContentLength, TransferEncoding, Trailer), the field values are
// authoritative.
//
// Keys in the map are canonicalized (see CanonicalHeaderKey).
Header Header
// Body represents the response body.
//
// The response body is streamed on demand as the Body field
// is read. If the network connection fails or the server
// terminates the response, Body.Read calls return an error.
//
// The http Client and Transport guarantee that Body is always
// non-nil, even on responses without a body or responses with
// a zero-length body. It is the caller's responsibility to
// close Body. The default HTTP client's Transport may not
// reuse HTTP/1.x "keep-alive" TCP connections if the Body is
// not read to completion and closed.
//
// The Body is automatically dechunked if the server replied
// with a "chunked" Transfer-Encoding.
//
// As of Go 1.12, the Body will also implement io.Writer
// on a successful "101 Switching Protocols" response,
// as used by WebSockets and HTTP/2's "h2c" mode.
Body io.ReadCloser
// ContentLength records the length of the associated content. The
// value -1 indicates that the length is unknown. Unless Request.Method
// is "HEAD", values >= 0 indicate that the given number of bytes may
// be read from Body.
ContentLength int64
// TransferEncoding contains the transfer encodings from outermost to
// innermost. A nil value means that the "identity" encoding is used.
TransferEncoding []string
// Close records whether the header directed that the connection be
// closed after reading Body. The value is advice for clients: neither
// ReadResponse nor Response.Write ever closes a connection.
Close bool
// Uncompressed reports whether the response was sent compressed but
// was decompressed by the http package. When true, reading from
// Body yields the uncompressed content instead of the compressed
// content actually sent from the server, ContentLength is set to -1,
// and the "Content-Length" and "Content-Encoding" fields are deleted
// from the response Header. To get the original response from
// the server, set Transport.DisableCompression to true.
Uncompressed bool
// Trailer maps trailer keys to values in the same
// format as Header.
//
// The Trailer initially contains only nil values, one for
// each key specified in the server's "Trailer" header
// value. Those values are not added to Header.
//
// Trailer must not be accessed concurrently with Read calls
// on the Body.
//
// After Body.Read has returned io.EOF, Trailer will contain
// any trailer values sent by the server.
Trailer Header
// Request is the request that was sent to obtain this Response.
// Request's Body is nil (having already been consumed).
// This is only populated for Client requests.
Request *Request
// TLS contains information about the TLS connection on which the
// response was received. It is nil for unencrypted responses.
// The pointer is shared between responses and should not be
// modified.
TLS *tls.ConnectionState
}
// Cookies parses and returns the cookies set in the Set-Cookie headers.
func (r *Response) Cookies() []*Cookie {
return readSetCookies(r.Header)
}
// ErrNoLocation is returned by Response's Location method
// when no Location header is present.
var ErrNoLocation = errors.New("http: no Location header in response")
// Location returns the URL of the response's "Location" header,
// if present. Relative redirects are resolved relative to
// the Response's Request. ErrNoLocation is returned if no
// Location header is present.
func (r *Response) Location() (*url.URL, error) {
lv := r.Header.Get("Location")
if lv == "" {
return nil, ErrNoLocation
}
if r.Request != nil && r.Request.URL != nil {
return r.Request.URL.Parse(lv)
}
return url.Parse(lv)
}
// ReadResponse reads and returns an HTTP response from r.
// The req parameter optionally specifies the Request that corresponds
// to this Response. If nil, a GET request is assumed.
// Clients must call resp.Body.Close when finished reading resp.Body.
// After that call, clients can inspect resp.Trailer to find key/value
// pairs included in the response trailer.
func ReadResponse(r *bufio.Reader, req *Request) (*Response, error) {
tp := textproto.NewReader(r)
resp := &Response{
Request: req,
}
// Parse the first line of the response.
line, err := tp.ReadLine()
if err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return nil, err
}
proto, status, ok := strings.Cut(line, " ")
if !ok {
return nil, badStringError("malformed HTTP response", line)
}
resp.Proto = proto
resp.Status = strings.TrimLeft(status, " ")
statusCode, _, _ := strings.Cut(resp.Status, " ")
if len(statusCode) != 3 {
return nil, badStringError("malformed HTTP status code", statusCode)
}
resp.StatusCode, err = strconv.Atoi(statusCode)
if err != nil || resp.StatusCode < 0 {
return nil, badStringError("malformed HTTP status code", statusCode)
}
if resp.ProtoMajor, resp.ProtoMinor, ok = ParseHTTPVersion(resp.Proto); !ok {
return nil, badStringError("malformed HTTP version", resp.Proto)
}
// Parse the response headers.
mimeHeader, err := tp.ReadMIMEHeader()
if err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return nil, err
}
resp.Header = Header(mimeHeader)
fixPragmaCacheControl(resp.Header)
err = readTransfer(resp, r)
if err != nil {
return nil, err
}
return resp, nil
}
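// A minimal sketch (hypothetical wire bytes) of parsing a raw response:
//
//	raw := "HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nhi"
//	resp, err := ReadResponse(bufio.NewReader(strings.NewReader(raw)), nil)
//	if err == nil {
//		defer resp.Body.Close()
//		// resp.StatusCode == 200 and resp.ContentLength == 2
//	}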
// RFC 7234, section 5.4: Should treat
//
// Pragma: no-cache
//
// like
//
// Cache-Control: no-cache
func fixPragmaCacheControl(header Header) {
if hp, ok := header["Pragma"]; ok && len(hp) > 0 && hp[0] == "no-cache" {
if _, presentcc := header["Cache-Control"]; !presentcc {
header["Cache-Control"] = []string{"no-cache"}
}
}
}
// ProtoAtLeast reports whether the HTTP protocol used
// in the response is at least major.minor.
func (r *Response) ProtoAtLeast(major, minor int) bool {
return r.ProtoMajor > major ||
r.ProtoMajor == major && r.ProtoMinor >= minor
}
// Write writes r to w in the HTTP/1.x server response format,
// including the status line, headers, body, and optional trailer.
//
// This method consults the following fields of the response r:
//
// StatusCode
// ProtoMajor
// ProtoMinor
// Request.Method
// TransferEncoding
// Trailer
// Body
// ContentLength
// Header, values for non-canonical keys will have unpredictable behavior
//
// The Response Body is closed after it is sent.
func (r *Response) Write(w io.Writer) error {
// Status line
text := r.Status
if text == "" {
text = StatusText(r.StatusCode)
if text == "" {
text = "status code " + strconv.Itoa(r.StatusCode)
}
} else {
// Just to reduce stutter, if user set r.Status to "200 OK" and StatusCode to 200.
// Not important.
text = strings.TrimPrefix(text, strconv.Itoa(r.StatusCode)+" ")
}
if _, err := fmt.Fprintf(w, "HTTP/%d.%d %03d %s\r\n", r.ProtoMajor, r.ProtoMinor, r.StatusCode, text); err != nil {
return err
}
// Clone it, so we can modify r1 as needed.
r1 := new(Response)
*r1 = *r
if r1.ContentLength == 0 && r1.Body != nil {
// Is it actually 0 length? Or just unknown?
var buf [1]byte
n, err := r1.Body.Read(buf[:])
if err != nil && err != io.EOF {
return err
}
if n == 0 {
// Reset it to a known zero reader, in case underlying one
// is unhappy being read repeatedly.
r1.Body = NoBody
} else {
r1.ContentLength = -1
r1.Body = struct {
io.Reader
io.Closer
}{
io.MultiReader(bytes.NewReader(buf[:1]), r.Body),
r.Body,
}
}
}
// If we're sending a non-chunked HTTP/1.1 response without a
// content-length, the only way to do that is the old HTTP/1.0
// way, by noting the EOF with a connection close, so we need
// to set Close.
if r1.ContentLength == -1 && !r1.Close && r1.ProtoAtLeast(1, 1) && !chunked(r1.TransferEncoding) && !r1.Uncompressed {
r1.Close = true
}
// Process Body, ContentLength, Close, Trailer.
tw, err := newTransferWriter(r1)
if err != nil {
return err
}
err = tw.writeHeader(w, nil)
if err != nil {
return err
}
// Rest of header
err = r.Header.WriteSubset(w, respExcludeHeader)
if err != nil {
return err
}
// A Content-Length header may already have been sent for
// POST/PUT requests, even if zero length. See Issue 8180.
contentLengthAlreadySent := tw.shouldSendContentLength()
if r1.ContentLength == 0 && !chunked(r1.TransferEncoding) && !contentLengthAlreadySent && bodyAllowedForStatus(r.StatusCode) {
if _, err := io.WriteString(w, "Content-Length: 0\r\n"); err != nil {
return err
}
}
// End-of-header
if _, err := io.WriteString(w, "\r\n"); err != nil {
return err
}
// Write body and trailer
err = tw.writeBody(w)
if err != nil {
return err
}
// Success
return nil
}
func (r *Response) closeBody() {
if r.Body != nil {
r.Body.Close()
}
}
// bodyIsWritable reports whether the Body supports writing. The
// Transport returns Writable bodies for 101 Switching Protocols
// responses.
// The Transport uses this method to determine whether a persistent
// connection is done being managed from its perspective. Once we
// return a writable response body to a user, the net/http package is
// done managing that connection.
func (r *Response) bodyIsWritable() bool {
_, ok := r.Body.(io.Writer)
return ok
}
// isProtocolSwitch reports whether the response code and header
// indicate a successful protocol upgrade response.
func (r *Response) isProtocolSwitch() bool {
return isProtocolSwitchResponse(r.StatusCode, r.Header)
}
// isProtocolSwitchResponse reports whether the response code and
// response header indicate a successful protocol upgrade response.
func isProtocolSwitchResponse(code int, h Header) bool {
return code == StatusSwitchingProtocols && isProtocolSwitchHeader(h)
}
// isProtocolSwitchHeader reports whether the request or response header
// is for a protocol switch.
func isProtocolSwitchHeader(h Header) bool {
return h.Get("Upgrade") != "" &&
httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade")
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http
import (
"bufio"
"fmt"
"net"
"time"
)
// A ResponseController is used by an HTTP handler to control the response.
//
// A ResponseController may not be used after the Handler.ServeHTTP method has returned.
type ResponseController struct {
rw ResponseWriter
}
// NewResponseController creates a ResponseController for a request.
//
// The ResponseWriter should be the original value passed to the Handler.ServeHTTP method,
// or have an Unwrap method returning the original ResponseWriter.
//
// If the ResponseWriter implements any of the following methods, the ResponseController
// will call them as appropriate:
//
// Flush()
// FlushError() error // alternative Flush returning an error
// Hijack() (net.Conn, *bufio.ReadWriter, error)
// SetReadDeadline(deadline time.Time) error
// SetWriteDeadline(deadline time.Time) error
//
// If the ResponseWriter does not support a method, ResponseController returns
// an error matching ErrNotSupported.
func NewResponseController(rw ResponseWriter) *ResponseController {
return &ResponseController{rw}
}
type rwUnwrapper interface {
Unwrap() ResponseWriter
}
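// A sketch of the wrapper pattern this interface supports (the wrapper type
// is hypothetical): middleware that wraps a ResponseWriter can expose the
// original via Unwrap so a ResponseController can still reach Flush, Hijack,
// and the deadline methods.
//
//	type statusRecorder struct {
//		ResponseWriter
//		status int
//	}
//
//	func (w *statusRecorder) Unwrap() ResponseWriter { return w.ResponseWriter }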
// Flush flushes buffered data to the client.
func (c *ResponseController) Flush() error {
rw := c.rw
for {
switch t := rw.(type) {
case interface{ FlushError() error }:
return t.FlushError()
case Flusher:
t.Flush()
return nil
case rwUnwrapper:
rw = t.Unwrap()
default:
return errNotSupported()
}
}
}
// Hijack lets the caller take over the connection.
// See the Hijacker interface for details.
func (c *ResponseController) Hijack() (net.Conn, *bufio.ReadWriter, error) {
rw := c.rw
for {
switch t := rw.(type) {
case Hijacker:
return t.Hijack()
case rwUnwrapper:
rw = t.Unwrap()
default:
return nil, nil, errNotSupported()
}
}
}
// SetReadDeadline sets the deadline for reading the entire request, including the body.
// Reads from the request body after the deadline has been exceeded will return an error.
// A zero value means no deadline.
//
// Setting the read deadline after it has been exceeded will not extend it.
func (c *ResponseController) SetReadDeadline(deadline time.Time) error {
rw := c.rw
for {
switch t := rw.(type) {
case interface{ SetReadDeadline(time.Time) error }:
return t.SetReadDeadline(deadline)
case rwUnwrapper:
rw = t.Unwrap()
default:
return errNotSupported()
}
}
}
// SetWriteDeadline sets the deadline for writing the response.
// Writes to the response body after the deadline has been exceeded will not block,
// but may succeed if the data has been buffered.
// A zero value means no deadline.
//
// Setting the write deadline after it has been exceeded will not extend it.
func (c *ResponseController) SetWriteDeadline(deadline time.Time) error {
rw := c.rw
for {
switch t := rw.(type) {
case interface{ SetWriteDeadline(time.Time) error }:
return t.SetWriteDeadline(deadline)
case rwUnwrapper:
rw = t.Unwrap()
default:
return errNotSupported()
}
}
}
// errNotSupported returns an error that Is ErrNotSupported,
// but is not == to it.
func errNotSupported() error {
return fmt.Errorf("%w", ErrNotSupported)
}
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !js || !wasm
package http
// RoundTrip implements the RoundTripper interface.
//
// For higher-level HTTP client support (such as handling of cookies
// and redirects), see Get, Post, and the Client type.
//
// Like the RoundTripper interface, the error types returned
// by RoundTrip are unspecified.
func (t *Transport) RoundTrip(req *Request) (*Response, error) {
return t.roundTrip(req)
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// HTTP server. See RFC 7230 through 7235.
package http
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"internal/godebug"
"io"
"log"
"math/rand"
"net"
"net/textproto"
"net/url"
urlpkg "net/url"
"path"
"runtime"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"golang.org/x/net/http/httpguts"
)
// Errors used by the HTTP server.
var (
// ErrBodyNotAllowed is returned by ResponseWriter.Write calls
// when the HTTP method or response code does not permit a
// body.
ErrBodyNotAllowed = errors.New("http: request method or response status code does not allow body")
// ErrHijacked is returned by ResponseWriter.Write calls when
// the underlying connection has been hijacked using the
// Hijacker interface. A zero-byte write on a hijacked
// connection will return ErrHijacked without any other side
// effects.
ErrHijacked = errors.New("http: connection has been hijacked")
// ErrContentLength is returned by ResponseWriter.Write calls
// when a Handler set a Content-Length response header with a
// declared size and then attempted to write more bytes than
// declared.
ErrContentLength = errors.New("http: wrote more than the declared Content-Length")
// Deprecated: ErrWriteAfterFlush is no longer returned by
// anything in the net/http package. Callers should not
// compare errors against this variable.
ErrWriteAfterFlush = errors.New("unused")
)
// A Handler responds to an HTTP request.
//
// ServeHTTP should write reply headers and data to the ResponseWriter
// and then return. Returning signals that the request is finished; it
// is not valid to use the ResponseWriter or read from the
// Request.Body after or concurrently with the completion of the
// ServeHTTP call.
//
// Depending on the HTTP client software, HTTP protocol version, and
// any intermediaries between the client and the Go server, it may not
// be possible to read from the Request.Body after writing to the
// ResponseWriter. Cautious handlers should read the Request.Body
// first, and then reply.
//
// Except for reading the body, handlers should not modify the
// provided Request.
//
// If ServeHTTP panics, the server (the caller of ServeHTTP) assumes
// that the effect of the panic was isolated to the active request.
// It recovers the panic, logs a stack trace to the server error log,
// and either closes the network connection or sends an HTTP/2
// RST_STREAM, depending on the HTTP protocol. To abort a handler so
// the client sees an interrupted response but the server doesn't log
// an error, panic with the value ErrAbortHandler.
type Handler interface {
ServeHTTP(ResponseWriter, *Request)
}
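// A minimal sketch of a Handler implementation (the type name is
// hypothetical):
//
//	type greeter struct{}
//
//	func (greeter) ServeHTTP(w ResponseWriter, r *Request) {
//		io.WriteString(w, "hello\n")
//	}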
// A ResponseWriter interface is used by an HTTP handler to
// construct an HTTP response.
//
// A ResponseWriter may not be used after the Handler.ServeHTTP method
// has returned.
type ResponseWriter interface {
// Header returns the header map that will be sent by
// WriteHeader. The Header map also is the mechanism with which
// Handlers can set HTTP trailers.
//
// Changing the header map after a call to WriteHeader (or
// Write) has no effect unless the HTTP status code was of the
// 1xx class or the modified headers are trailers.
//
// There are two ways to set Trailers. The preferred way is to
// predeclare in the headers which trailers you will later
// send by setting the "Trailer" header to the names of the
// trailer keys which will come later. In this case, those
// keys of the Header map are treated as if they were
// trailers. See the example. The second way, for trailer
// keys not known to the Handler until after the first Write,
// is to prefix the Header map keys with the TrailerPrefix
// constant value. See TrailerPrefix.
//
// To suppress automatic response headers (such as "Date"), set
// their value to nil.
Header() Header
// Write writes the data to the connection as part of an HTTP reply.
//
// If WriteHeader has not yet been called, Write calls
// WriteHeader(http.StatusOK) before writing the data. If the Header
// does not contain a Content-Type line, Write adds a Content-Type set
// to the result of passing the initial 512 bytes of written data to
// DetectContentType. Additionally, if the total size of all written
// data is under a few KB and there are no Flush calls, the
// Content-Length header is added automatically.
//
// Depending on the HTTP protocol version and the client, calling
// Write or WriteHeader may prevent future reads on the
// Request.Body. For HTTP/1.x requests, handlers should read any
// needed request body data before writing the response. Once the
// headers have been flushed (due to either an explicit Flusher.Flush
// call or writing enough data to trigger a flush), the request body
// may be unavailable. For HTTP/2 requests, the Go HTTP server permits
// handlers to continue to read the request body while concurrently
// writing the response. However, such behavior may not be supported
// by all HTTP/2 clients. Handlers should read before writing if
// possible to maximize compatibility.
Write([]byte) (int, error)
// WriteHeader sends an HTTP response header with the provided
// status code.
//
// If WriteHeader is not called explicitly, the first call to Write
// will trigger an implicit WriteHeader(http.StatusOK).
// Thus explicit calls to WriteHeader are mainly used to
// send error codes or 1xx informational responses.
//
// The provided code must be a valid HTTP 1xx-5xx status code.
// Any number of 1xx headers may be written, followed by at most
// one 2xx-5xx header. 1xx headers are sent immediately, but 2xx-5xx
// headers may be buffered. Use the Flusher interface to send
// buffered data. The header map is cleared when 2xx-5xx headers are
// sent, but not with 1xx headers.
//
// The server will automatically send a 100 (Continue) header
// on the first read from the request body if the request has
// an "Expect: 100-continue" header.
WriteHeader(statusCode int)
}
// The Flusher interface is implemented by ResponseWriters that allow
// an HTTP handler to flush buffered data to the client.
//
// The default HTTP/1.x and HTTP/2 ResponseWriter implementations
// support Flusher, but ResponseWriter wrappers may not. Handlers
// should always test for this ability at runtime.
//
// Note that even for ResponseWriters that support Flush,
// if the client is connected through an HTTP proxy,
// the buffered data may not reach the client until the response
// completes.
type Flusher interface {
// Flush sends any buffered data to the client.
Flush()
}
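// A sketch of streaming from a handler (the payload is illustrative): test
// for the ability at runtime, then flush after each write.
//
//	f, ok := w.(Flusher)
//	if !ok {
//		return // this ResponseWriter cannot flush
//	}
//	for i := 0; i < 3; i++ {
//		fmt.Fprintf(w, "chunk %d\n", i)
//		f.Flush() // push this chunk to the client now
//	}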
// The Hijacker interface is implemented by ResponseWriters that allow
// an HTTP handler to take over the connection.
//
// The default ResponseWriter for HTTP/1.x connections supports
// Hijacker, but HTTP/2 connections intentionally do not.
// ResponseWriter wrappers may also not support Hijacker. Handlers
// should always test for this ability at runtime.
type Hijacker interface {
// Hijack lets the caller take over the connection.
// After a call to Hijack the HTTP server library
// will not do anything else with the connection.
//
// It becomes the caller's responsibility to manage
// and close the connection.
//
// The returned net.Conn may have read or write deadlines
// already set, depending on the configuration of the
// Server. It is the caller's responsibility to set
// or clear those deadlines as needed.
//
// The returned bufio.Reader may contain unprocessed buffered
// data from the client.
//
// After a call to Hijack, the original Request.Body must not
// be used. The original Request's Context remains valid and
// is not canceled until the Request's ServeHTTP method
// returns.
Hijack() (net.Conn, *bufio.ReadWriter, error)
}
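// A sketch of taking over the connection from a handler (the 101 response
// line is illustrative):
//
//	hj, ok := w.(Hijacker)
//	if !ok {
//		Error(w, "hijacking not supported", StatusInternalServerError)
//		return
//	}
//	conn, bufrw, err := hj.Hijack()
//	if err != nil {
//		return
//	}
//	defer conn.Close()
//	bufrw.WriteString("HTTP/1.1 101 Switching Protocols\r\n\r\n")
//	bufrw.Flush()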
// The CloseNotifier interface is implemented by ResponseWriters which
// allow detecting when the underlying connection has gone away.
//
// This mechanism can be used to cancel long operations on the server
// if the client has disconnected before the response is ready.
//
// Deprecated: the CloseNotifier interface predates Go's context package.
// New code should use Request.Context instead.
type CloseNotifier interface {
// CloseNotify returns a channel that receives at most a
// single value (true) when the client connection has gone
// away.
//
// CloseNotify may wait to notify until Request.Body has been
// fully read.
//
// After the Handler has returned, there is no guarantee
// that the channel receives a value.
//
// If the protocol is HTTP/1.1 and CloseNotify is called while
// processing an idempotent request (such as a GET) while
// HTTP/1.1 pipelining is in use, the arrival of a subsequent
// pipelined request may cause a value to be sent on the
// returned channel. In practice HTTP/1.1 pipelining is not
// enabled in browsers and not seen often in the wild. If this
// is a problem, use HTTP/2 or only use CloseNotify on methods
// such as POST.
CloseNotify() <-chan bool
}
var (
// ServerContextKey is a context key. It can be used in HTTP
// handlers with Context.Value to access the server that
// started the handler. The associated value will be of
// type *Server.
ServerContextKey = &contextKey{"http-server"}
// LocalAddrContextKey is a context key. It can be used in
// HTTP handlers with Context.Value to access the local
// address the connection arrived on.
// The associated value will be of type net.Addr.
LocalAddrContextKey = &contextKey{"local-addr"}
)
// A conn represents the server side of an HTTP connection.
type conn struct {
// server is the server on which the connection arrived.
// Immutable; never nil.
server *Server
// cancelCtx cancels the connection-level context.
cancelCtx context.CancelFunc
// rwc is the underlying network connection.
// This is never wrapped by other types and is the value given out
// to CloseNotifier callers. It is usually of type *net.TCPConn or
// *tls.Conn.
rwc net.Conn
// remoteAddr is rwc.RemoteAddr().String(). It is not populated synchronously
// inside the Listener's Accept goroutine, as some implementations block.
// It is populated immediately inside the (*conn).serve goroutine.
// This is the value of a Handler's (*Request).RemoteAddr.
remoteAddr string
// tlsState is the TLS connection state when using TLS.
// nil means not TLS.
tlsState *tls.ConnectionState
// werr is set to the first write error to rwc.
// It is set via checkConnErrorWriter{w}, where bufw writes.
werr error
// r is bufr's read source. It's a wrapper around rwc that provides
// io.LimitedReader-style limiting (while reading request headers)
// and functionality to support CloseNotifier. See *connReader docs.
r *connReader
// bufr reads from r.
bufr *bufio.Reader
// bufw writes to checkConnErrorWriter{c}, which populates werr on error.
bufw *bufio.Writer
// lastMethod is the method of the most recent request
// on this connection, if any.
lastMethod string
curReq atomic.Pointer[response] // (which has a Request in it)
curState atomic.Uint64 // packed (unixtime<<8|uint8(ConnState))
// mu guards hijackedv
mu sync.Mutex
// hijackedv is whether this connection has been hijacked
// by a Handler with the Hijacker interface.
// It is guarded by mu.
hijackedv bool
}
func (c *conn) hijacked() bool {
c.mu.Lock()
defer c.mu.Unlock()
return c.hijackedv
}
// c.mu must be held.
func (c *conn) hijackLocked() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
if c.hijackedv {
return nil, nil, ErrHijacked
}
c.r.abortPendingRead()
c.hijackedv = true
rwc = c.rwc
rwc.SetDeadline(time.Time{})
buf = bufio.NewReadWriter(c.bufr, bufio.NewWriter(rwc))
if c.r.hasByte {
if _, err := c.bufr.Peek(c.bufr.Buffered() + 1); err != nil {
return nil, nil, fmt.Errorf("unexpected Peek failure reading buffered byte: %v", err)
}
}
c.setState(rwc, StateHijacked, runHooks)
return
}
// This should be >= 512 bytes for DetectContentType,
// but otherwise it's somewhat arbitrary.
const bufferBeforeChunkingSize = 2048
// chunkWriter writes to a response's conn buffer, and is the writer
// wrapped by the response.w buffered writer.
//
// chunkWriter also is responsible for finalizing the Header, including
// conditionally setting the Content-Type and setting a Content-Length
// in cases where the handler's final output is smaller than the buffer
// size. It also conditionally adds chunk headers, when in chunking mode.
//
// See the comment above (*response).Write for the entire write flow.
type chunkWriter struct {
res *response
// header is either nil or a deep clone of res.handlerHeader
// at the time of res.writeHeader, if res.writeHeader is
// called and extra buffering is being done to calculate
// Content-Type and/or Content-Length.
header Header
// wroteHeader tells whether the header's been written to "the
// wire" (or rather: w.conn.buf). this is unlike
// (*response).wroteHeader, which tells only whether it was
// logically written.
wroteHeader bool
// set by the writeHeader method:
chunking bool // using chunked transfer encoding for reply body
}
var (
crlf = []byte("\r\n")
colonSpace = []byte(": ")
)
func (cw *chunkWriter) Write(p []byte) (n int, err error) {
if !cw.wroteHeader {
cw.writeHeader(p)
}
if cw.res.req.Method == "HEAD" {
// Eat writes.
return len(p), nil
}
if cw.chunking {
_, err = fmt.Fprintf(cw.res.conn.bufw, "%x\r\n", len(p))
if err != nil {
cw.res.conn.rwc.Close()
return
}
}
n, err = cw.res.conn.bufw.Write(p)
if cw.chunking && err == nil {
_, err = cw.res.conn.bufw.Write(crlf)
}
if err != nil {
cw.res.conn.rwc.Close()
}
return
}
func (cw *chunkWriter) flush() error {
if !cw.wroteHeader {
cw.writeHeader(nil)
}
return cw.res.conn.bufw.Flush()
}
func (cw *chunkWriter) close() {
if !cw.wroteHeader {
cw.writeHeader(nil)
}
if cw.chunking {
bw := cw.res.conn.bufw // conn's bufio writer
// zero chunk to mark EOF
bw.WriteString("0\r\n")
if trailers := cw.res.finalTrailers(); trailers != nil {
trailers.Write(bw) // the writer handles noting errors
}
// final blank line after the trailers (whether
// present or not)
bw.WriteString("\r\n")
}
}
// A response represents the server side of an HTTP response.
type response struct {
conn *conn
req *Request // request for this response
reqBody io.ReadCloser
cancelCtx context.CancelFunc // when ServeHTTP exits
wroteHeader bool // a non-1xx header has been (logically) written
wroteContinue bool // 100 Continue response was written
wants10KeepAlive bool // HTTP/1.0 w/ Connection "keep-alive"
wantsClose bool // HTTP request has Connection "close"
// canWriteContinue is an atomic boolean that says whether or
// not a 100 Continue header can be written to the
// connection.
// writeContinueMu must be held while writing the header.
// These two fields together synchronize the body reader (the
// expectContinueReader, which wants to write 100 Continue)
// against the main writer.
canWriteContinue atomic.Bool
writeContinueMu sync.Mutex
w *bufio.Writer // buffers output in chunks to chunkWriter
cw chunkWriter
// handlerHeader is the Header that Handlers get access to,
// which may be retained and mutated even after WriteHeader.
// handlerHeader is copied into cw.header at WriteHeader
// time, and privately mutated thereafter.
handlerHeader Header
calledHeader bool // handler accessed handlerHeader via Header
written int64 // number of bytes written in body
contentLength int64 // explicitly-declared Content-Length; or -1
status int // status code passed to WriteHeader
// close connection after this reply. set on request and
// updated after response from handler if there's a
// "Connection: keep-alive" response header and a
// Content-Length.
closeAfterReply bool
// requestBodyLimitHit is set by requestTooLarge when
// maxBytesReader hits its max size. It is checked in
// WriteHeader, to make sure we don't consume the
// remaining request body to try to advance to the next HTTP
// request. Instead, when this is set, we stop reading
// subsequent requests on this connection and stop reading
// input from it.
requestBodyLimitHit bool
// trailers are the headers to be sent after the handler
// finishes writing the body. This field is initialized from
// the Trailer response header when the response header is
// written.
trailers []string
handlerDone atomic.Bool // set true when the handler exits
// Buffers for Date, Content-Length, and status code
dateBuf [len(TimeFormat)]byte
clenBuf [10]byte
statusBuf [3]byte
// closeNotifyCh is the channel returned by CloseNotify.
// TODO(bradfitz): this is currently (for Go 1.8) always
// non-nil. Make this lazily-created again as it used to be?
closeNotifyCh chan bool
didCloseNotify atomic.Bool // atomic (only false->true winner should send)
}
func (c *response) SetReadDeadline(deadline time.Time) error {
return c.conn.rwc.SetReadDeadline(deadline)
}
func (c *response) SetWriteDeadline(deadline time.Time) error {
return c.conn.rwc.SetWriteDeadline(deadline)
}
// TrailerPrefix is a magic prefix for ResponseWriter.Header map keys
// that, if present, signals that the map entry is actually for
// the response trailers, and not the response headers. The prefix
// is stripped after the ServeHTTP call finishes and the values are
// sent in the trailers.
//
// This mechanism is intended only for trailers that are not known
// prior to the headers being written. If the set of trailers is fixed
// or known before the header is written, the normal Go trailers mechanism
// is preferred:
//
// https://pkg.go.dev/net/http#ResponseWriter
// https://pkg.go.dev/net/http#example-ResponseWriter-Trailers
const TrailerPrefix = "Trailer:"
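// A sketch of the prefix mechanism (the trailer name is illustrative): a
// handler that learns a trailer value only while writing the body sets it
// with the prefixed key, and it is emitted after the body.
//
//	io.WriteString(w, "body data") // headers flushed; trailers still settable
//	sum := "abc123"                // computed while writing the body
//	w.Header().Set(TrailerPrefix+"My-Checksum", sum)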
// finalTrailers is called after the Handler exits and returns a non-nil
// value if the Handler set any trailers.
func (w *response) finalTrailers() Header {
var t Header
for k, vv := range w.handlerHeader {
if kk, found := strings.CutPrefix(k, TrailerPrefix); found {
if t == nil {
t = make(Header)
}
t[kk] = vv
}
}
for _, k := range w.trailers {
if t == nil {
t = make(Header)
}
for _, v := range w.handlerHeader[k] {
t.Add(k, v)
}
}
return t
}
// declareTrailer is called for each Trailer header when the
// response header is written. It notes that a header will need to be
// written in the trailers at the end of the response.
func (w *response) declareTrailer(k string) {
k = CanonicalHeaderKey(k)
if !httpguts.ValidTrailerHeader(k) {
// Forbidden by RFC 7230, section 4.1.2
return
}
w.trailers = append(w.trailers, k)
}
// requestTooLarge is called by maxBytesReader when too much input has
// been read from the client.
func (w *response) requestTooLarge() {
w.closeAfterReply = true
w.requestBodyLimitHit = true
if !w.wroteHeader {
w.Header().Set("Connection", "close")
}
}
// writerOnly hides an io.Writer value's optional ReadFrom method
// from io.Copy.
type writerOnly struct {
io.Writer
}
// ReadFrom is here to optimize copying from an *os.File regular file
// to a *net.TCPConn with sendfile, or from a supported src type such
// as a *net.TCPConn on Linux with splice.
func (w *response) ReadFrom(src io.Reader) (n int64, err error) {
bufp := copyBufPool.Get().(*[]byte)
buf := *bufp
defer copyBufPool.Put(bufp)
// Our underlying w.conn.rwc is usually a *TCPConn (with its
// own ReadFrom method). If not, just fall back to the normal
// copy method.
rf, ok := w.conn.rwc.(io.ReaderFrom)
if !ok {
return io.CopyBuffer(writerOnly{w}, src, buf)
}
// Copy the first sniffLen bytes before switching to ReadFrom.
// This ensures we don't start writing the response before the
// source is available (see golang.org/issue/5660) and provides
// enough bytes to perform Content-Type sniffing when required.
if !w.cw.wroteHeader {
n0, err := io.CopyBuffer(writerOnly{w}, io.LimitReader(src, sniffLen), buf)
n += n0
if err != nil || n0 < sniffLen {
return n, err
}
}
w.w.Flush() // get rid of any previous writes
w.cw.flush() // make sure Header is written; flush data to rwc
// Now that cw has been flushed, its chunking field is guaranteed initialized.
if !w.cw.chunking && w.bodyAllowed() {
n0, err := rf.ReadFrom(src)
n += n0
w.written += n0
return n, err
}
n0, err := io.CopyBuffer(writerOnly{w}, src, buf)
n += n0
return n, err
}
// debugServerConnections controls whether all server connections are wrapped
// with a verbose logging wrapper.
const debugServerConnections = false
// Create new connection from rwc.
func (srv *Server) newConn(rwc net.Conn) *conn {
c := &conn{
server: srv,
rwc: rwc,
}
if debugServerConnections {
c.rwc = newLoggingConn("server", c.rwc)
}
return c
}
type readResult struct {
_ incomparable
n int
err error
b byte // byte read, if n == 1
}
// connReader is the io.Reader wrapper used by *conn. It combines a
// selectively-activated io.LimitedReader (to bound request header
// read sizes) with support for selectively keeping an io.Reader.Read
// call blocked in a background goroutine to wait for activity and
// trigger a CloseNotifier channel.
type connReader struct {
conn *conn
mu sync.Mutex // guards following
hasByte bool
byteBuf [1]byte
cond *sync.Cond
inRead bool
aborted bool // set true before conn.rwc deadline is set to past
remain int64 // bytes remaining
}
func (cr *connReader) lock() {
cr.mu.Lock()
if cr.cond == nil {
cr.cond = sync.NewCond(&cr.mu)
}
}
func (cr *connReader) unlock() { cr.mu.Unlock() }
func (cr *connReader) startBackgroundRead() {
cr.lock()
defer cr.unlock()
if cr.inRead {
panic("invalid concurrent Body.Read call")
}
if cr.hasByte {
return
}
cr.inRead = true
cr.conn.rwc.SetReadDeadline(time.Time{})
go cr.backgroundRead()
}
func (cr *connReader) backgroundRead() {
n, err := cr.conn.rwc.Read(cr.byteBuf[:])
cr.lock()
if n == 1 {
cr.hasByte = true
// We were past the end of the previous request's body already
// (since we wouldn't be in a background read otherwise), so
// this is a pipelined HTTP request. Prior to Go 1.11 we used to
// send on the CloseNotify channel and cancel the context here,
// but the behavior was documented as only "may", and we only
// did that because that's how CloseNotify accidentally behaved
// in very early Go releases prior to context support. Once we
// added context support, people used a Handler's
// Request.Context() and passed it along. Having that context
// cancel on pipelined HTTP requests caused problems.
// Fortunately, almost nothing uses HTTP/1.x pipelining.
// Unfortunately, apt-get does, or sometimes does.
// New Go 1.11 behavior: don't fire CloseNotify or cancel
// contexts on pipelined requests. Shouldn't affect people, but
// fixes cases like Issue 23921. This does mean that a client
// closing their TCP connection after sending a pipelined
// request won't cancel the context, but we'll catch that on any
// write failure (in checkConnErrorWriter.Write).
// If the server never writes, yes, there are still contrived
// server & client behaviors where this fails to ever cancel the
// context, but that's kinda why HTTP/1.x pipelining died
// anyway.
}
if ne, ok := err.(net.Error); ok && cr.aborted && ne.Timeout() {
// Ignore this error. It's the expected error from
// another goroutine calling abortPendingRead.
} else if err != nil {
cr.handleReadError(err)
}
cr.aborted = false
cr.inRead = false
cr.unlock()
cr.cond.Broadcast()
}
func (cr *connReader) abortPendingRead() {
cr.lock()
defer cr.unlock()
if !cr.inRead {
return
}
cr.aborted = true
cr.conn.rwc.SetReadDeadline(aLongTimeAgo)
for cr.inRead {
cr.cond.Wait()
}
cr.conn.rwc.SetReadDeadline(time.Time{})
}
func (cr *connReader) setReadLimit(remain int64) { cr.remain = remain }
func (cr *connReader) setInfiniteReadLimit() { cr.remain = maxInt64 }
func (cr *connReader) hitReadLimit() bool { return cr.remain <= 0 }
// handleReadError is called whenever a Read from the client returns a
// non-nil error.
//
// The provided non-nil err is almost always io.EOF or a "use of
// closed network connection". In any case, the error is not
// particularly interesting, except perhaps for debugging during
// development. Any error means the connection is dead and we should
// down its context.
//
// It may be called from multiple goroutines.
func (cr *connReader) handleReadError(_ error) {
cr.conn.cancelCtx()
cr.closeNotify()
}
// closeNotify may be called from multiple goroutines.
func (cr *connReader) closeNotify() {
res := cr.conn.curReq.Load()
if res != nil && !res.didCloseNotify.Swap(true) {
res.closeNotifyCh <- true
}
}
func (cr *connReader) Read(p []byte) (n int, err error) {
cr.lock()
if cr.inRead {
cr.unlock()
if cr.conn.hijacked() {
panic("invalid Body.Read call. After hijacked, the original Request must not be used")
}
panic("invalid concurrent Body.Read call")
}
if cr.hitReadLimit() {
cr.unlock()
return 0, io.EOF
}
if len(p) == 0 {
cr.unlock()
return 0, nil
}
if int64(len(p)) > cr.remain {
p = p[:cr.remain]
}
if cr.hasByte {
p[0] = cr.byteBuf[0]
cr.hasByte = false
cr.unlock()
return 1, nil
}
cr.inRead = true
cr.unlock()
n, err = cr.conn.rwc.Read(p)
cr.lock()
cr.inRead = false
if err != nil {
cr.handleReadError(err)
}
cr.remain -= int64(n)
cr.unlock()
cr.cond.Broadcast()
return n, err
}
var (
bufioReaderPool sync.Pool
bufioWriter2kPool sync.Pool
bufioWriter4kPool sync.Pool
)
var copyBufPool = sync.Pool{
New: func() any {
b := make([]byte, 32*1024)
return &b
},
}
func bufioWriterPool(size int) *sync.Pool {
switch size {
case 2 << 10:
return &bufioWriter2kPool
case 4 << 10:
return &bufioWriter4kPool
}
return nil
}
func newBufioReader(r io.Reader) *bufio.Reader {
if v := bufioReaderPool.Get(); v != nil {
br := v.(*bufio.Reader)
br.Reset(r)
return br
}
// Note: if this reader size is ever changed, update
// TestHandlerBodyClose's assumptions.
return bufio.NewReader(r)
}
func putBufioReader(br *bufio.Reader) {
br.Reset(nil)
bufioReaderPool.Put(br)
}
func newBufioWriterSize(w io.Writer, size int) *bufio.Writer {
pool := bufioWriterPool(size)
if pool != nil {
if v := pool.Get(); v != nil {
bw := v.(*bufio.Writer)
bw.Reset(w)
return bw
}
}
return bufio.NewWriterSize(w, size)
}
func putBufioWriter(bw *bufio.Writer) {
bw.Reset(nil)
if pool := bufioWriterPool(bw.Available()); pool != nil {
pool.Put(bw)
}
}
// DefaultMaxHeaderBytes is the maximum permitted size of the headers
// in an HTTP request.
// This can be overridden by setting Server.MaxHeaderBytes.
const DefaultMaxHeaderBytes = 1 << 20 // 1 MB
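// A configuration sketch (illustrative): lowering the per-request header
// limit on a Server. The address and limit are arbitrary example values;
// assumes `import "net/http"`:
//
//	srv := &http.Server{
//		Addr:           ":8080",
//		MaxHeaderBytes: 64 << 10, // 64 KB instead of the 1 MB default
//	}
//	// log.Fatal(srv.ListenAndServe())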
func (srv *Server) maxHeaderBytes() int {
if srv.MaxHeaderBytes > 0 {
return srv.MaxHeaderBytes
}
return DefaultMaxHeaderBytes
}
func (srv *Server) initialReadLimitSize() int64 {
return int64(srv.maxHeaderBytes()) + 4096 // bufio slop
}
// tlsHandshakeTimeout returns the time limit permitted for the TLS
// handshake, or zero for unlimited.
//
// It returns the minimum of any positive ReadHeaderTimeout,
// ReadTimeout, or WriteTimeout.
func (srv *Server) tlsHandshakeTimeout() time.Duration {
var ret time.Duration
for _, v := range [...]time.Duration{
srv.ReadHeaderTimeout,
srv.ReadTimeout,
srv.WriteTimeout,
} {
if v <= 0 {
continue
}
if ret == 0 || v < ret {
ret = v
}
}
return ret
}
// expectContinueReader is a wrapper around io.ReadCloser which, on first
// read, sends an HTTP/1.1 100 Continue header.
type expectContinueReader struct {
resp *response
readCloser io.ReadCloser
closed atomic.Bool
sawEOF atomic.Bool
}
func (ecr *expectContinueReader) Read(p []byte) (n int, err error) {
if ecr.closed.Load() {
return 0, ErrBodyReadAfterClose
}
w := ecr.resp
if !w.wroteContinue && w.canWriteContinue.Load() && !w.conn.hijacked() {
w.wroteContinue = true
w.writeContinueMu.Lock()
if w.canWriteContinue.Load() {
w.conn.bufw.WriteString("HTTP/1.1 100 Continue\r\n\r\n")
w.conn.bufw.Flush()
w.canWriteContinue.Store(false)
}
w.writeContinueMu.Unlock()
}
n, err = ecr.readCloser.Read(p)
if err == io.EOF {
ecr.sawEOF.Store(true)
}
return
}
func (ecr *expectContinueReader) Close() error {
ecr.closed.Store(true)
return ecr.readCloser.Close()
}
// TimeFormat is the time format to use when generating times in HTTP
// headers. It is like time.RFC1123 but hard-codes GMT as the time
// zone. The time being formatted must be in UTC for Format to
// generate the correct format.
//
// For parsing this time format, see ParseTime.
const TimeFormat = "Mon, 02 Jan 2006 15:04:05 GMT"
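// A usage sketch (illustrative); assumes `import ("net/http"; "time")`.
// Note the required conversion to UTC before formatting:
//
//	date := time.Now().UTC().Format(http.TimeFormat)
//	// ParseTime accepts this format as well as the older
//	// RFC 850 and ANSI C formats.
//	t, err := http.ParseTime(date)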
// appendTime is a non-allocating version of []byte(t.UTC().Format(TimeFormat))
func appendTime(b []byte, t time.Time) []byte {
const days = "SunMonTueWedThuFriSat"
const months = "JanFebMarAprMayJunJulAugSepOctNovDec"
t = t.UTC()
yy, mm, dd := t.Date()
hh, mn, ss := t.Clock()
day := days[3*t.Weekday():]
mon := months[3*(mm-1):]
return append(b,
day[0], day[1], day[2], ',', ' ',
byte('0'+dd/10), byte('0'+dd%10), ' ',
mon[0], mon[1], mon[2], ' ',
byte('0'+yy/1000), byte('0'+(yy/100)%10), byte('0'+(yy/10)%10), byte('0'+yy%10), ' ',
byte('0'+hh/10), byte('0'+hh%10), ':',
byte('0'+mn/10), byte('0'+mn%10), ':',
byte('0'+ss/10), byte('0'+ss%10), ' ',
'G', 'M', 'T')
}
var errTooLarge = errors.New("http: request too large")
// Read next request from connection.
func (c *conn) readRequest(ctx context.Context) (w *response, err error) {
if c.hijacked() {
return nil, ErrHijacked
}
var (
wholeReqDeadline time.Time // or zero if none
hdrDeadline time.Time // or zero if none
)
t0 := time.Now()
if d := c.server.readHeaderTimeout(); d > 0 {
hdrDeadline = t0.Add(d)
}
if d := c.server.ReadTimeout; d > 0 {
wholeReqDeadline = t0.Add(d)
}
c.rwc.SetReadDeadline(hdrDeadline)
if d := c.server.WriteTimeout; d > 0 {
defer func() {
c.rwc.SetWriteDeadline(time.Now().Add(d))
}()
}
c.r.setReadLimit(c.server.initialReadLimitSize())
if c.lastMethod == "POST" {
// RFC 7230 section 3 tolerance for old buggy clients.
peek, _ := c.bufr.Peek(4) // ReadRequest will get err below
c.bufr.Discard(numLeadingCRorLF(peek))
}
req, err := readRequest(c.bufr)
if err != nil {
if c.r.hitReadLimit() {
return nil, errTooLarge
}
return nil, err
}
if !http1ServerSupportsRequest(req) {
return nil, statusError{StatusHTTPVersionNotSupported, "unsupported protocol version"}
}
c.lastMethod = req.Method
c.r.setInfiniteReadLimit()
hosts, haveHost := req.Header["Host"]
isH2Upgrade := req.isH2Upgrade()
if req.ProtoAtLeast(1, 1) && (!haveHost || len(hosts) == 0) && !isH2Upgrade && req.Method != "CONNECT" {
return nil, badRequestError("missing required Host header")
}
if len(hosts) == 1 && !httpguts.ValidHostHeader(hosts[0]) {
return nil, badRequestError("malformed Host header")
}
for k, vv := range req.Header {
if !httpguts.ValidHeaderFieldName(k) {
return nil, badRequestError("invalid header name")
}
for _, v := range vv {
if !httpguts.ValidHeaderFieldValue(v) {
return nil, badRequestError("invalid header value")
}
}
}
delete(req.Header, "Host")
ctx, cancelCtx := context.WithCancel(ctx)
req.ctx = ctx
req.RemoteAddr = c.remoteAddr
req.TLS = c.tlsState
if body, ok := req.Body.(*body); ok {
body.doEarlyClose = true
}
// Adjust the read deadline if necessary.
if !hdrDeadline.Equal(wholeReqDeadline) {
c.rwc.SetReadDeadline(wholeReqDeadline)
}
w = &response{
conn: c,
cancelCtx: cancelCtx,
req: req,
reqBody: req.Body,
handlerHeader: make(Header),
contentLength: -1,
closeNotifyCh: make(chan bool, 1),
// We populate these ahead of time so we're not
// reading from req.Header after their Handler starts
// and maybe mutates it (Issue 14940)
wants10KeepAlive: req.wantsHttp10KeepAlive(),
wantsClose: req.wantsClose(),
}
if isH2Upgrade {
w.closeAfterReply = true
}
w.cw.res = w
w.w = newBufioWriterSize(&w.cw, bufferBeforeChunkingSize)
return w, nil
}
// http1ServerSupportsRequest reports whether Go's HTTP/1.x server
// supports the given request.
func http1ServerSupportsRequest(req *Request) bool {
if req.ProtoMajor == 1 {
return true
}
// Accept "PRI * HTTP/2.0" upgrade requests, so Handlers can
// wire up their own HTTP/2 upgrades.
if req.ProtoMajor == 2 && req.ProtoMinor == 0 &&
req.Method == "PRI" && req.RequestURI == "*" {
return true
}
// Reject HTTP/0.x, and all other HTTP/2+ requests (which
// aren't encoded in ASCII anyway).
return false
}
func (w *response) Header() Header {
if w.cw.header == nil && w.wroteHeader && !w.cw.wroteHeader {
// Accessing the header between logically writing it
// and physically writing it means we need to allocate
// a clone to snapshot the logically written state.
w.cw.header = w.handlerHeader.Clone()
}
w.calledHeader = true
return w.handlerHeader
}
// maxPostHandlerReadBytes is the max number of Request.Body bytes not
// consumed by a handler that the server will read from the client
// in order to keep a connection alive. If there are more bytes than
// this, the server, to be paranoid, instead sends a "Connection:
// close" response.
//
// This number is approximately what a typical machine's TCP buffer
// size is anyway. (If we have the bytes on the machine, we might as
// well read them.)
const maxPostHandlerReadBytes = 256 << 10
func checkWriteHeaderCode(code int) {
// Issue 22880: require valid WriteHeader status codes.
// For now we only enforce that it's three digits.
// In the future we might block things over 599 (600 and above aren't defined
// at https://httpwg.org/specs/rfc7231.html#status.codes).
// But for now any three digits.
//
// We used to send "HTTP/1.1 000 0" on the wire in responses but there's
// no equivalent bogus thing we can realistically send in HTTP/2,
// so we'll consistently panic instead and help people find their bugs
// early. (We can't return an error from WriteHeader even if we wanted to.)
if code < 100 || code > 999 {
panic(fmt.Sprintf("invalid WriteHeader code %v", code))
}
}
// relevantCaller searches the call stack for the first function outside of net/http.
// The purpose of this function is to provide more helpful error messages.
func relevantCaller() runtime.Frame {
pc := make([]uintptr, 16)
n := runtime.Callers(1, pc)
frames := runtime.CallersFrames(pc[:n])
var frame runtime.Frame
for {
// Assign (not :=) so the fallback return below sees the last
// frame examined instead of a zero runtime.Frame.
var more bool
frame, more = frames.Next()
if !strings.HasPrefix(frame.Function, "net/http.") {
return frame
}
if !more {
break
}
}
return frame
}
func (w *response) WriteHeader(code int) {
if w.conn.hijacked() {
caller := relevantCaller()
w.conn.server.logf("http: response.WriteHeader on hijacked connection from %s (%s:%d)", caller.Function, path.Base(caller.File), caller.Line)
return
}
if w.wroteHeader {
caller := relevantCaller()
w.conn.server.logf("http: superfluous response.WriteHeader call from %s (%s:%d)", caller.Function, path.Base(caller.File), caller.Line)
return
}
checkWriteHeaderCode(code)
// Handle informational headers
if code >= 100 && code <= 199 {
// Prevent a potential race with an automatically-sent 100 Continue triggered by Request.Body.Read()
if code == 100 && w.canWriteContinue.Load() {
w.writeContinueMu.Lock()
w.canWriteContinue.Store(false)
w.writeContinueMu.Unlock()
}
writeStatusLine(w.conn.bufw, w.req.ProtoAtLeast(1, 1), code, w.statusBuf[:])
// Per RFC 8297 we must not clear the current header map
w.handlerHeader.WriteSubset(w.conn.bufw, excludedHeadersNoBody)
w.conn.bufw.Write(crlf)
w.conn.bufw.Flush()
return
}
w.wroteHeader = true
w.status = code
if w.calledHeader && w.cw.header == nil {
w.cw.header = w.handlerHeader.Clone()
}
if cl := w.handlerHeader.get("Content-Length"); cl != "" {
v, err := strconv.ParseInt(cl, 10, 64)
if err == nil && v >= 0 {
w.contentLength = v
} else {
w.conn.server.logf("http: invalid Content-Length of %q", cl)
w.handlerHeader.Del("Content-Length")
}
}
}
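// A sketch of the informational-response path above (illustrative): a
// handler that sends a 103 Early Hints before the final status. The
// handler name and Link value are hypothetical; assumes
// `import ("io"; "net/http")`. Per RFC 8297 the header map is not
// cleared by the 1xx write, so it can be reused for the final header:
//
//	func hintsHandler(w http.ResponseWriter, r *http.Request) {
//		w.Header().Set("Link", "</style.css>; rel=preload; as=style")
//		w.WriteHeader(http.StatusEarlyHints) // 1xx: written immediately
//		w.Header().Set("Content-Type", "text/html; charset=utf-8")
//		w.WriteHeader(http.StatusOK) // the real (logical) header write
//		io.WriteString(w, "<!doctype html><p>hello</p>")
//	}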
// extraHeader is the set of headers sometimes added by chunkWriter.writeHeader.
// This type is used to avoid extra allocations from cloning and/or populating
// the response Header map and all its 1-element slices.
type extraHeader struct {
contentType string
connection string
transferEncoding string
date []byte // written if not nil
contentLength []byte // written if not nil
}
// Sorted the same as extraHeader.Write's loop.
var extraHeaderKeys = [][]byte{
[]byte("Content-Type"),
[]byte("Connection"),
[]byte("Transfer-Encoding"),
}
var (
headerContentLength = []byte("Content-Length: ")
headerDate = []byte("Date: ")
)
// Write writes the headers described in h to w.
//
// This method has a value receiver, despite the somewhat large size
// of h, because it prevents an allocation. The escape analysis isn't
// smart enough to realize this function doesn't mutate h.
func (h extraHeader) Write(w *bufio.Writer) {
if h.date != nil {
w.Write(headerDate)
w.Write(h.date)
w.Write(crlf)
}
if h.contentLength != nil {
w.Write(headerContentLength)
w.Write(h.contentLength)
w.Write(crlf)
}
for i, v := range []string{h.contentType, h.connection, h.transferEncoding} {
if v != "" {
w.Write(extraHeaderKeys[i])
w.Write(colonSpace)
w.WriteString(v)
w.Write(crlf)
}
}
}
// writeHeader finalizes the header sent to the client and writes it
// to cw.res.conn.bufw.
//
// p is not written by writeHeader, but is the first chunk of the body
// that will be written. It is sniffed for a Content-Type if none is
// set explicitly. It's also used to set the Content-Length, if the
// total body size was small and the handler has already finished
// running.
func (cw *chunkWriter) writeHeader(p []byte) {
if cw.wroteHeader {
return
}
cw.wroteHeader = true
w := cw.res
keepAlivesEnabled := w.conn.server.doKeepAlives()
isHEAD := w.req.Method == "HEAD"
// header is written out to w.conn.buf below. Depending on the
// state of the handler, we either own the map or not. If we
// don't own it, the exclude map is created lazily for
// WriteSubset to remove headers. The setHeader struct holds
// headers we need to add.
header := cw.header
owned := header != nil
if !owned {
header = w.handlerHeader
}
var excludeHeader map[string]bool
delHeader := func(key string) {
if owned {
header.Del(key)
return
}
if _, ok := header[key]; !ok {
return
}
if excludeHeader == nil {
excludeHeader = make(map[string]bool)
}
excludeHeader[key] = true
}
var setHeader extraHeader
// Don't write out the fake "Trailer:foo" keys. See TrailerPrefix.
trailers := false
for k := range cw.header {
if strings.HasPrefix(k, TrailerPrefix) {
if excludeHeader == nil {
excludeHeader = make(map[string]bool)
}
excludeHeader[k] = true
trailers = true
}
}
for _, v := range cw.header["Trailer"] {
trailers = true
foreachHeaderElement(v, cw.res.declareTrailer)
}
te := header.get("Transfer-Encoding")
hasTE := te != ""
// If the handler is done but never sent a Content-Length
// response header and this is our first (and last) write, set
// it, even to zero. This helps HTTP/1.0 clients keep their
// "keep-alive" connections alive.
// Exceptions: 304/204/1xx responses never get Content-Length, and if
// it was a HEAD request, we don't know the difference between
// 0 actual bytes and 0 bytes because the handler noticed it
// was a HEAD request and chose not to write anything. So for
// HEAD, the handler should either write the Content-Length or
// write non-zero bytes. If it's actually 0 bytes and the
// handler never looked at the Request.Method, we just don't
// send a Content-Length header.
// Further, we don't send an automatic Content-Length if they
// set a Transfer-Encoding, because they're generally incompatible.
if w.handlerDone.Load() && !trailers && !hasTE && bodyAllowedForStatus(w.status) && header.get("Content-Length") == "" && (!isHEAD || len(p) > 0) {
w.contentLength = int64(len(p))
setHeader.contentLength = strconv.AppendInt(cw.res.clenBuf[:0], int64(len(p)), 10)
}
// If this was an HTTP/1.0 request with keep-alive and we sent a
// Content-Length back, we can make this a keep-alive response ...
if w.wants10KeepAlive && keepAlivesEnabled {
sentLength := header.get("Content-Length") != ""
if sentLength && header.get("Connection") == "keep-alive" {
w.closeAfterReply = false
}
}
// Check for an explicit (and valid) Content-Length header.
hasCL := w.contentLength != -1
if w.wants10KeepAlive && (isHEAD || hasCL || !bodyAllowedForStatus(w.status)) {
_, connectionHeaderSet := header["Connection"]
if !connectionHeaderSet {
setHeader.connection = "keep-alive"
}
} else if !w.req.ProtoAtLeast(1, 1) || w.wantsClose {
w.closeAfterReply = true
}
if header.get("Connection") == "close" || !keepAlivesEnabled {
w.closeAfterReply = true
}
// If the client wanted a 100-continue but we never sent it to
// them (or, more strictly: we never finished reading their
// request body), don't reuse this connection because it's now
// in an unknown state: we might be sending this response at
// the same time the client is now sending its request body
// after a timeout. (Some HTTP clients send Expect:
// 100-continue but knowing that some servers don't support
// it, the clients set a timer and send the body later anyway)
// If we haven't seen EOF, we can't skip over the unread body
// because we don't know if the next bytes on the wire will be
// the body-following-the-timer or the subsequent request.
// See Issue 11549.
if ecr, ok := w.req.Body.(*expectContinueReader); ok && !ecr.sawEOF.Load() {
w.closeAfterReply = true
}
// Per RFC 2616, we should consume the request body before
// replying, if the handler hasn't already done so. But we
// don't want to do an unbounded amount of reading here for
// DoS reasons, so we only try up to a threshold.
// TODO(bradfitz): where does RFC 2616 say that? See Issue 15527
// about HTTP/1.x Handlers concurrently reading and writing, like
// HTTP/2 handlers can do. Maybe this code should be relaxed?
if w.req.ContentLength != 0 && !w.closeAfterReply {
var discard, tooBig bool
switch bdy := w.req.Body.(type) {
case *expectContinueReader:
if bdy.resp.wroteContinue {
discard = true
}
case *body:
bdy.mu.Lock()
switch {
case bdy.closed:
if !bdy.sawEOF {
// Body was closed in handler with non-EOF error.
w.closeAfterReply = true
}
case bdy.unreadDataSizeLocked() >= maxPostHandlerReadBytes:
tooBig = true
default:
discard = true
}
bdy.mu.Unlock()
default:
discard = true
}
if discard {
_, err := io.CopyN(io.Discard, w.reqBody, maxPostHandlerReadBytes+1)
switch err {
case nil:
// There must be even more data left over.
tooBig = true
case ErrBodyReadAfterClose:
// Body was already consumed and closed.
case io.EOF:
// The remaining body was just consumed, close it.
err = w.reqBody.Close()
if err != nil {
w.closeAfterReply = true
}
default:
// Some other kind of error occurred, like a read timeout, or
// corrupt chunked encoding. In any case, whatever remains
// on the wire must not be parsed as another HTTP request.
w.closeAfterReply = true
}
}
if tooBig {
w.requestTooLarge()
delHeader("Connection")
setHeader.connection = "close"
}
}
code := w.status
if bodyAllowedForStatus(code) {
// If no content type, apply sniffing algorithm to body.
_, haveType := header["Content-Type"]
// If the Content-Encoding was set and is non-blank,
// we shouldn't sniff the body. See Issue 31753.
ce := header.Get("Content-Encoding")
hasCE := len(ce) > 0
if !hasCE && !haveType && !hasTE && len(p) > 0 {
setHeader.contentType = DetectContentType(p)
}
} else {
for _, k := range suppressedHeaders(code) {
delHeader(k)
}
}
if !header.has("Date") {
setHeader.date = appendTime(cw.res.dateBuf[:0], time.Now())
}
if hasCL && hasTE && te != "identity" {
// TODO: return an error if WriteHeader gets a return parameter
// For now just ignore the Content-Length.
w.conn.server.logf("http: WriteHeader called with both Transfer-Encoding of %q and a Content-Length of %d",
te, w.contentLength)
delHeader("Content-Length")
hasCL = false
}
if w.req.Method == "HEAD" || !bodyAllowedForStatus(code) || code == StatusNoContent {
// Response has no body.
delHeader("Transfer-Encoding")
} else if hasCL {
// Content-Length has been provided, so no chunking is to be done.
delHeader("Transfer-Encoding")
} else if w.req.ProtoAtLeast(1, 1) {
// HTTP/1.1 or greater: Transfer-Encoding has been set to identity, and no
// content-length has been provided. The connection must be closed after the
// reply is written, and no chunking is to be done. This is the setup
// recommended in the Server-Sent Events candidate recommendation 11,
// section 8.
if hasTE && te == "identity" {
cw.chunking = false
w.closeAfterReply = true
delHeader("Transfer-Encoding")
} else {
// HTTP/1.1 or greater: use chunked transfer encoding
// to avoid closing the connection at EOF.
cw.chunking = true
setHeader.transferEncoding = "chunked"
if hasTE && te == "chunked" {
// We will send the chunked Transfer-Encoding header later.
delHeader("Transfer-Encoding")
}
}
} else {
// HTTP version < 1.1: cannot do chunked transfer
// encoding and we don't know the Content-Length so
// signal EOF by closing connection.
w.closeAfterReply = true
delHeader("Transfer-Encoding") // in case already set
}
// Cannot use Content-Length with non-identity Transfer-Encoding.
if cw.chunking {
delHeader("Content-Length")
}
if !w.req.ProtoAtLeast(1, 0) {
return
}
// Only override the Connection header if it is not a successful
// protocol switch response and if KeepAlives are not enabled.
// See https://golang.org/issue/36381.
delConnectionHeader := w.closeAfterReply &&
(!keepAlivesEnabled || !hasToken(cw.header.get("Connection"), "close")) &&
!isProtocolSwitchResponse(w.status, header)
if delConnectionHeader {
delHeader("Connection")
if w.req.ProtoAtLeast(1, 1) {
setHeader.connection = "close"
}
}
writeStatusLine(w.conn.bufw, w.req.ProtoAtLeast(1, 1), code, w.statusBuf[:])
cw.header.WriteSubset(w.conn.bufw, excludeHeader)
setHeader.Write(w.conn.bufw)
w.conn.bufw.Write(crlf)
}
// foreachHeaderElement splits v according to the "#rule" construction
// in RFC 7230 section 7 and calls fn for each non-empty element.
func foreachHeaderElement(v string, fn func(string)) {
v = textproto.TrimString(v)
if v == "" {
return
}
if !strings.Contains(v, ",") {
fn(v)
return
}
for _, f := range strings.Split(v, ",") {
if f = textproto.TrimString(f); f != "" {
fn(f)
}
}
}
// writeStatusLine writes an HTTP/1.x Status-Line (RFC 7230 Section 3.1.2)
// to bw. is11 is whether the HTTP request is HTTP/1.1. false means HTTP/1.0.
// code is the response status code.
// scratch is an optional scratch buffer. If it has at least capacity 3, it's used.
func writeStatusLine(bw *bufio.Writer, is11 bool, code int, scratch []byte) {
if is11 {
bw.WriteString("HTTP/1.1 ")
} else {
bw.WriteString("HTTP/1.0 ")
}
if text := StatusText(code); text != "" {
bw.Write(strconv.AppendInt(scratch[:0], int64(code), 10))
bw.WriteByte(' ')
bw.WriteString(text)
bw.WriteString("\r\n")
} else {
// don't worry about performance
fmt.Fprintf(bw, "%03d status code %d\r\n", code, code)
}
}
// bodyAllowed reports whether a Write is allowed for this response type.
// It's illegal to call this before the header has been flushed.
func (w *response) bodyAllowed() bool {
if !w.wroteHeader {
panic("")
}
return bodyAllowedForStatus(w.status)
}
// The Life Of A Write is like this:
//
// Handler starts. No header has been sent. The handler can either
// write a header, or just start writing. Writing before sending a header
// sends an implicitly empty 200 OK header.
//
// If the handler didn't declare a Content-Length up front, we either
// go into chunking mode or, if the handler finishes running before
// the chunking buffer size, we compute a Content-Length and send that
// in the header instead.
//
// Likewise, if the handler didn't set a Content-Type, we sniff that
// from the initial chunk of output.
//
// The Writers are wired together like:
//
// 1. *response (the ResponseWriter) ->
// 2. (*response).w, a *bufio.Writer of bufferBeforeChunkingSize bytes ->
// 3. chunkWriter.Write (whose writeHeader finalizes Content-Length/Type)
// and which writes the chunk headers, if needed ->
// 4. conn.bufw, a *bufio.Writer of default (4kB) bytes, writing to ->
// 5. checkConnErrorWriter{c}, which notes any non-nil error on Write
// and populates c.werr with it if so, but otherwise writes to ->
// 6. the rwc, the net.Conn.
//
// TODO(bradfitz): short-circuit some of the buffering when the
// initial header contains both a Content-Type and Content-Length.
// Also short-circuit in (1) when the header's been sent and not in
// chunking mode, writing directly to (4) instead, if (2) has no
// buffered data. More generally, we could short-circuit from (1) to
// (3) even in chunking mode if the write size from (1) is over some
// threshold and nothing is in (2). The answer might be mostly making
// bufferBeforeChunkingSize smaller and having bufio's fast-paths deal
// with this instead.
func (w *response) Write(data []byte) (n int, err error) {
return w.write(len(data), data, "")
}
func (w *response) WriteString(data string) (n int, err error) {
return w.write(len(data), nil, data)
}
// either dataB or dataS is non-zero.
func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err error) {
if w.conn.hijacked() {
if lenData > 0 {
caller := relevantCaller()
w.conn.server.logf("http: response.Write on hijacked connection from %s (%s:%d)", caller.Function, path.Base(caller.File), caller.Line)
}
return 0, ErrHijacked
}
if w.canWriteContinue.Load() {
// Body reader wants to write 100 Continue but hasn't yet.
// Tell it not to. The store must be done while holding the lock
// because the lock makes sure that there is not an active write
// this very moment.
w.writeContinueMu.Lock()
w.canWriteContinue.Store(false)
w.writeContinueMu.Unlock()
}
if !w.wroteHeader {
w.WriteHeader(StatusOK)
}
if lenData == 0 {
return 0, nil
}
if !w.bodyAllowed() {
return 0, ErrBodyNotAllowed
}
w.written += int64(lenData) // ignoring errors, for errorKludge
if w.contentLength != -1 && w.written > w.contentLength {
return 0, ErrContentLength
}
if dataB != nil {
return w.w.Write(dataB)
} else {
return w.w.WriteString(dataS)
}
}
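// A sketch of the buffering described in "The Life Of A Write" above
// (illustrative; the handler name is hypothetical, assumes
// `import ("io"; "net/http")`). There is no explicit WriteHeader, so the
// first Write sends an implicit 200 OK; the body fits in the pre-chunking
// buffer and the handler returns before it fills, so the server emits a
// computed Content-Length instead of chunked framing:
//
//	func smallHandler(w http.ResponseWriter, r *http.Request) {
//		io.WriteString(w, "hello, world\n") // implicit 200, Content-Length: 13
//	}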
func (w *response) finishRequest() {
w.handlerDone.Store(true)
if !w.wroteHeader {
w.WriteHeader(StatusOK)
}
w.w.Flush()
putBufioWriter(w.w)
w.cw.close()
w.conn.bufw.Flush()
w.conn.r.abortPendingRead()
// Close the body (regardless of w.closeAfterReply) so we can
// re-use its bufio.Reader later safely.
w.reqBody.Close()
if w.req.MultipartForm != nil {
w.req.MultipartForm.RemoveAll()
}
}
// shouldReuseConnection reports whether the underlying TCP connection can be reused.
// It must only be called after the handler is done executing.
func (w *response) shouldReuseConnection() bool {
if w.closeAfterReply {
// The request or something set while executing the
// handler indicated we shouldn't reuse this
// connection.
return false
}
if w.req.Method != "HEAD" && w.contentLength != -1 && w.bodyAllowed() && w.contentLength != w.written {
// Did not write enough. Avoid getting out of sync.
return false
}
// There was some error writing to the underlying connection
// during the request, so don't re-use this conn.
if w.conn.werr != nil {
return false
}
if w.closedRequestBodyEarly() {
return false
}
return true
}
func (w *response) closedRequestBodyEarly() bool {
body, ok := w.req.Body.(*body)
return ok && body.didEarlyClose()
}
func (w *response) Flush() {
w.FlushError()
}
func (w *response) FlushError() error {
if !w.wroteHeader {
w.WriteHeader(StatusOK)
}
err := w.w.Flush()
e2 := w.cw.flush()
if err == nil {
err = e2
}
return err
}
func (c *conn) finalFlush() {
if c.bufr != nil {
// Steal the bufio.Reader (~4KB worth of memory) and its associated
// reader for a future connection.
putBufioReader(c.bufr)
c.bufr = nil
}
if c.bufw != nil {
c.bufw.Flush()
// Steal the bufio.Writer (~4KB worth of memory) and its associated
// writer for a future connection.
putBufioWriter(c.bufw)
c.bufw = nil
}
}
// Close the connection.
func (c *conn) close() {
c.finalFlush()
c.rwc.Close()
}
// rstAvoidanceDelay is the amount of time we sleep after closing the
// write side of a TCP connection before closing the entire socket.
// By sleeping, we increase the chances that the client sees our FIN
// and processes its final data before they process the subsequent RST
// from closing a connection with known unread data.
// This RST seems to occur mostly on BSD systems. (And Windows?)
// This timeout is somewhat arbitrary (~latency around the planet).
const rstAvoidanceDelay = 500 * time.Millisecond
type closeWriter interface {
CloseWrite() error
}
var _ closeWriter = (*net.TCPConn)(nil)
// closeWriteAndWait flushes any outstanding data and sends a FIN packet (if
// client is connected via TCP), signaling that we're done. We then
// pause for a bit, hoping the client processes it before any
// subsequent RST.
//
// See https://golang.org/issue/3595
func (c *conn) closeWriteAndWait() {
c.finalFlush()
if tcp, ok := c.rwc.(closeWriter); ok {
tcp.CloseWrite()
}
time.Sleep(rstAvoidanceDelay)
}
// validNextProto reports whether the proto is a valid ALPN protocol name.
// Everything is valid except the empty string and built-in protocol types,
// so that those can't be overridden with alternate implementations.
func validNextProto(proto string) bool {
switch proto {
case "", "http/1.1", "http/1.0":
return false
}
return true
}
const (
runHooks = true
skipHooks = false
)
func (c *conn) setState(nc net.Conn, state ConnState, runHook bool) {
srv := c.server
switch state {
case StateNew:
srv.trackConn(c, true)
case StateHijacked, StateClosed:
srv.trackConn(c, false)
}
if state > 0xff || state < 0 {
panic("internal error")
}
packedState := uint64(time.Now().Unix()<<8) | uint64(state)
c.curState.Store(packedState)
if !runHook {
return
}
if hook := srv.ConnState; hook != nil {
hook(nc, state)
}
}
func (c *conn) getState() (state ConnState, unixSec int64) {
packedState := c.curState.Load()
return ConnState(packedState & 0xff), int64(packedState >> 8)
}
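// A usage sketch for the ConnState hook that setState runs (illustrative;
// assumes `import ("log"; "net"; "net/http")`). The callback observes the
// state transitions tracked above:
//
//	srv := &http.Server{
//		Addr: ":8080",
//		ConnState: func(c net.Conn, state http.ConnState) {
//			log.Printf("conn %v -> %v", c.RemoteAddr(), state)
//		},
//	}
//	// log.Fatal(srv.ListenAndServe())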
// badRequestError is a literal string (used by the server in HTML,
// unescaped) to tell the user why their request was bad. It should
// be plain text without user info or other embedded errors.
func badRequestError(e string) error { return statusError{StatusBadRequest, e} }
// statusError is an error used to respond to a request with an HTTP status.
// The text should be plain text without user info or other embedded errors.
type statusError struct {
code int
text string
}
func (e statusError) Error() string { return StatusText(e.code) + ": " + e.text }
// ErrAbortHandler is a sentinel panic value to abort a handler.
// While any panic from ServeHTTP aborts the response to the client,
// panicking with ErrAbortHandler also suppresses logging of a stack
// trace to the server's error log.
var ErrAbortHandler = errors.New("net/http: abort Handler")
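// A usage sketch (illustrative; the handler and condition are hypothetical,
// assumes `import "net/http"`): aborting a response without the stack-trace
// logging that other panic values trigger:
//
//	func proxyHandler(w http.ResponseWriter, r *http.Request) {
//		if r.Header.Get("X-Sketch-Upstream") == "" { // hypothetical check
//			panic(http.ErrAbortHandler) // drop the response quietly
//		}
//		// ... normal handling ...
//	}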
// isCommonNetReadError reports whether err is a common error
// encountered during reading a request off the network when the
// client has gone away or had its read fail somehow. This is used to
// determine which logs are interesting enough to log about.
func isCommonNetReadError(err error) bool {
if err == io.EOF {
return true
}
if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
return true
}
if oe, ok := err.(*net.OpError); ok && oe.Op == "read" {
return true
}
return false
}
// Serve a new connection.
func (c *conn) serve(ctx context.Context) {
c.remoteAddr = c.rwc.RemoteAddr().String()
ctx = context.WithValue(ctx, LocalAddrContextKey, c.rwc.LocalAddr())
var inFlightResponse *response
defer func() {
if err := recover(); err != nil && err != ErrAbortHandler {
const size = 64 << 10
buf := make([]byte, size)
buf = buf[:runtime.Stack(buf, false)]
c.server.logf("http: panic serving %v: %v\n%s", c.remoteAddr, err, buf)
}
if inFlightResponse != nil {
inFlightResponse.cancelCtx()
}
if !c.hijacked() {
if inFlightResponse != nil {
inFlightResponse.conn.r.abortPendingRead()
inFlightResponse.reqBody.Close()
}
c.close()
c.setState(c.rwc, StateClosed, runHooks)
}
}()
if tlsConn, ok := c.rwc.(*tls.Conn); ok {
tlsTO := c.server.tlsHandshakeTimeout()
if tlsTO > 0 {
dl := time.Now().Add(tlsTO)
c.rwc.SetReadDeadline(dl)
c.rwc.SetWriteDeadline(dl)
}
if err := tlsConn.HandshakeContext(ctx); err != nil {
// If the handshake failed due to the client not speaking
// TLS, assume they're speaking plaintext HTTP and write a
// 400 response on the TLS conn's underlying net.Conn.
if re, ok := err.(tls.RecordHeaderError); ok && re.Conn != nil && tlsRecordHeaderLooksLikeHTTP(re.RecordHeader) {
io.WriteString(re.Conn, "HTTP/1.0 400 Bad Request\r\n\r\nClient sent an HTTP request to an HTTPS server.\n")
re.Conn.Close()
return
}
c.server.logf("http: TLS handshake error from %s: %v", c.rwc.RemoteAddr(), err)
return
}
// Restore Conn-level deadlines.
if tlsTO > 0 {
c.rwc.SetReadDeadline(time.Time{})
c.rwc.SetWriteDeadline(time.Time{})
}
c.tlsState = new(tls.ConnectionState)
*c.tlsState = tlsConn.ConnectionState()
if proto := c.tlsState.NegotiatedProtocol; validNextProto(proto) {
if fn := c.server.TLSNextProto[proto]; fn != nil {
h := initALPNRequest{ctx, tlsConn, serverHandler{c.server}}
// Mark freshly created HTTP/2 as active and prevent any server state hooks
// from being run on these connections. This prevents closeIdleConns from
// closing such connections. See issue https://golang.org/issue/39776.
c.setState(c.rwc, StateActive, skipHooks)
fn(c.server, tlsConn, h)
}
return
}
}
// HTTP/1.x from here on.
ctx, cancelCtx := context.WithCancel(ctx)
c.cancelCtx = cancelCtx
defer cancelCtx()
c.r = &connReader{conn: c}
c.bufr = newBufioReader(c.r)
c.bufw = newBufioWriterSize(checkConnErrorWriter{c}, 4<<10)
for {
w, err := c.readRequest(ctx)
if c.r.remain != c.server.initialReadLimitSize() {
// If we read any bytes off the wire, we're active.
c.setState(c.rwc, StateActive, runHooks)
}
if err != nil {
const errorHeaders = "\r\nContent-Type: text/plain; charset=utf-8\r\nConnection: close\r\n\r\n"
switch {
case err == errTooLarge:
// Their HTTP client may or may not be
// able to read this if we're
// responding to them and hanging up
// while they're still writing their
// request. Undefined behavior.
const publicErr = "431 Request Header Fields Too Large"
fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr)
c.closeWriteAndWait()
return
case isUnsupportedTEError(err):
// Respond as per RFC 7230 Section 3.3.1 which says,
// A server that receives a request message with a
// transfer coding it does not understand SHOULD
// respond with 501 (Not Implemented).
code := StatusNotImplemented
// We purposefully aren't echoing back the transfer-encoding's value,
// so as to mitigate the risk of cross-site scripting by an attacker.
fmt.Fprintf(c.rwc, "HTTP/1.1 %d %s%sUnsupported transfer encoding", code, StatusText(code), errorHeaders)
return
case isCommonNetReadError(err):
return // don't reply
default:
if v, ok := err.(statusError); ok {
fmt.Fprintf(c.rwc, "HTTP/1.1 %d %s: %s%s%d %s: %s", v.code, StatusText(v.code), v.text, errorHeaders, v.code, StatusText(v.code), v.text)
return
}
publicErr := "400 Bad Request"
fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr)
return
}
}
// Expect 100 Continue support
req := w.req
if req.expectsContinue() {
if req.ProtoAtLeast(1, 1) && req.ContentLength != 0 {
// Wrap the Body reader with one that replies on the connection
req.Body = &expectContinueReader{readCloser: req.Body, resp: w}
w.canWriteContinue.Store(true)
}
} else if req.Header.get("Expect") != "" {
w.sendExpectationFailed()
return
}
c.curReq.Store(w)
if requestBodyRemains(req.Body) {
registerOnHitEOF(req.Body, w.conn.r.startBackgroundRead)
} else {
w.conn.r.startBackgroundRead()
}
// HTTP cannot have multiple simultaneous active requests.[*]
// Until the server replies to this request, it can't read another,
// so we might as well run the handler in this goroutine.
// [*] Not strictly true: HTTP pipelining. We could let them all process
// in parallel even if their responses need to be serialized.
// But we're not going to implement HTTP pipelining because it
// was never deployed in the wild and the answer is HTTP/2.
inFlightResponse = w
serverHandler{c.server}.ServeHTTP(w, w.req)
inFlightResponse = nil
w.cancelCtx()
if c.hijacked() {
return
}
w.finishRequest()
c.rwc.SetWriteDeadline(time.Time{})
if !w.shouldReuseConnection() {
if w.requestBodyLimitHit || w.closedRequestBodyEarly() {
c.closeWriteAndWait()
}
return
}
c.setState(c.rwc, StateIdle, runHooks)
c.curReq.Store(nil)
if !w.conn.server.doKeepAlives() {
// We're in shutdown mode. We might've replied
// to the user without "Connection: close" and
// they might think they can send another
// request, but such is life with HTTP/1.1.
return
}
if d := c.server.idleTimeout(); d != 0 {
c.rwc.SetReadDeadline(time.Now().Add(d))
} else {
c.rwc.SetReadDeadline(time.Time{})
}
// Wait for the connection to become readable again before trying to
// read the next request. This prevents a ReadHeaderTimeout or
// ReadTimeout from starting until the first bytes of the next request
// have been received.
if _, err := c.bufr.Peek(4); err != nil {
return
}
c.rwc.SetReadDeadline(time.Time{})
}
}
func (w *response) sendExpectationFailed() {
// TODO(bradfitz): let ServeHTTP handlers handle
// requests with non-standard expectation[s]? Seems
// theoretical at best, and doesn't fit into the
// current ServeHTTP model anyway. We'd need to
// make the ResponseWriter an optional
// "ExpectReplier" interface or something.
//
// For now we'll just obey RFC 7231 5.1.1 which says
// "A server that receives an Expect field-value other
// than 100-continue MAY respond with a 417 (Expectation
// Failed) status code to indicate that the unexpected
// expectation cannot be met."
w.Header().Set("Connection", "close")
w.WriteHeader(StatusExpectationFailed)
w.finishRequest()
}
// Hijack implements the Hijacker.Hijack method. Our response is both a ResponseWriter
// and a Hijacker.
func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
if w.handlerDone.Load() {
panic("net/http: Hijack called after ServeHTTP finished")
}
if w.wroteHeader {
w.cw.flush()
}
c := w.conn
c.mu.Lock()
defer c.mu.Unlock()
// Release the bufioWriter that writes to the chunk writer, it is not
// used after a connection has been hijacked.
rwc, buf, err = c.hijackLocked()
if err == nil {
putBufioWriter(w.w)
w.w = nil
}
return rwc, buf, err
}
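// A usage sketch (illustrative; the handler name is hypothetical, assumes
// `import "net/http"`): taking over the connection for a raw protocol.
// After Hijack returns, the caller owns the connection and must close it:
//
//	func hijackHandler(w http.ResponseWriter, r *http.Request) {
//		hj, ok := w.(http.Hijacker)
//		if !ok {
//			http.Error(w, "hijacking not supported", http.StatusInternalServerError)
//			return
//		}
//		conn, bufrw, err := hj.Hijack()
//		if err != nil {
//			http.Error(w, err.Error(), http.StatusInternalServerError)
//			return
//		}
//		defer conn.Close()
//		bufrw.WriteString("HTTP/1.1 101 Switching Protocols\r\n\r\n")
//		bufrw.Flush()
//	}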
func (w *response) CloseNotify() <-chan bool {
if w.handlerDone.Load() {
panic("net/http: CloseNotify called after ServeHTTP finished")
}
return w.closeNotifyCh
}
func registerOnHitEOF(rc io.ReadCloser, fn func()) {
switch v := rc.(type) {
case *expectContinueReader:
registerOnHitEOF(v.readCloser, fn)
case *body:
v.registerOnHitEOF(fn)
default:
panic("unexpected type " + fmt.Sprintf("%T", rc))
}
}
// requestBodyRemains reports whether future calls to Read
// on rc might yield more data.
func requestBodyRemains(rc io.ReadCloser) bool {
if rc == NoBody {
return false
}
switch v := rc.(type) {
case *expectContinueReader:
return requestBodyRemains(v.readCloser)
case *body:
return v.bodyRemains()
default:
panic("unexpected type " + fmt.Sprintf("%T", rc))
}
}
// The HandlerFunc type is an adapter to allow the use of
// ordinary functions as HTTP handlers. If f is a function
// with the appropriate signature, HandlerFunc(f) is a
// Handler that calls f.
type HandlerFunc func(ResponseWriter, *Request)
// ServeHTTP calls f(w, r).
func (f HandlerFunc) ServeHTTP(w ResponseWriter, r *Request) {
f(w, r)
}
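// A usage sketch (illustrative; assumes `import ("fmt"; "net/http")`):
// an ordinary function adapted into a Handler:
//
//	var h http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
//		fmt.Fprintf(w, "you visited %s\n", r.URL.Path)
//	})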
// Helper handlers
// Error replies to the request with the specified error message and HTTP code.
// It does not otherwise end the request; the caller should ensure no further
// writes are done to w.
// The error message should be plain text.
func Error(w ResponseWriter, error string, code int) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
w.WriteHeader(code)
fmt.Fprintln(w, error)
}
// NotFound replies to the request with an HTTP 404 not found error.
func NotFound(w ResponseWriter, r *Request) { Error(w, "404 page not found", StatusNotFound) }
// NotFoundHandler returns a simple request handler
// that replies to each request with a “404 page not found” reply.
func NotFoundHandler() Handler { return HandlerFunc(NotFound) }
// StripPrefix returns a handler that serves HTTP requests by removing the
// given prefix from the request URL's Path (and RawPath if set) and invoking
// the handler h. StripPrefix handles a request for a path that doesn't begin
// with prefix by replying with an HTTP 404 not found error. The prefix must
// match exactly: if the prefix in the request contains escaped characters
// the reply is also an HTTP 404 not found error.
func StripPrefix(prefix string, h Handler) Handler {
if prefix == "" {
return h
}
return HandlerFunc(func(w ResponseWriter, r *Request) {
p := strings.TrimPrefix(r.URL.Path, prefix)
rp := strings.TrimPrefix(r.URL.RawPath, prefix)
if len(p) < len(r.URL.Path) && (r.URL.RawPath == "" || len(rp) < len(r.URL.RawPath)) {
r2 := new(Request)
*r2 = *r
r2.URL = new(url.URL)
*r2.URL = *r.URL
r2.URL.Path = p
r2.URL.RawPath = rp
h.ServeHTTP(w, r2)
} else {
NotFound(w, r)
}
})
}
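// A usage sketch (the /tmpfiles/ mapping is illustrative; assumes
// `import "net/http"`): serving the /tmp directory under a URL prefix:
//
//	http.Handle("/tmpfiles/", http.StripPrefix("/tmpfiles/", http.FileServer(http.Dir("/tmp"))))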
// Redirect replies to the request with a redirect to url,
// which may be a path relative to the request path.
//
// The provided code should be in the 3xx range and is usually
// StatusMovedPermanently, StatusFound or StatusSeeOther.
//
// If the Content-Type header has not been set, Redirect sets it
// to "text/html; charset=utf-8" and writes a small HTML body.
// Setting the Content-Type header to any value, including nil,
// disables that behavior.
func Redirect(w ResponseWriter, r *Request, url string, code int) {
if u, err := urlpkg.Parse(url); err == nil {
// If url was relative, make its path absolute by
// combining with request path.
// The client would probably do this for us,
// but doing it ourselves is more reliable.
// See RFC 7231, section 7.1.2
if u.Scheme == "" && u.Host == "" {
oldpath := r.URL.Path
if oldpath == "" { // should not happen, but avoid a crash if it does
oldpath = "/"
}
// no leading http://server
if url == "" || url[0] != '/' {
// make relative path absolute
olddir, _ := path.Split(oldpath)
url = olddir + url
}
var query string
if i := strings.Index(url, "?"); i != -1 {
url, query = url[:i], url[i:]
}
// clean up but preserve trailing slash
trailing := strings.HasSuffix(url, "/")
url = path.Clean(url)
if trailing && !strings.HasSuffix(url, "/") {
url += "/"
}
url += query
}
}
h := w.Header()
// RFC 7231 notes that a short HTML body is usually included in
// the response because older user agents may not understand 301/307.
// Do it only if the request didn't already have a Content-Type header.
_, hadCT := h["Content-Type"]
h.Set("Location", hexEscapeNonASCII(url))
if !hadCT && (r.Method == "GET" || r.Method == "HEAD") {
h.Set("Content-Type", "text/html; charset=utf-8")
}
w.WriteHeader(code)
// Shouldn't send the body for POST or HEAD; that leaves GET.
if !hadCT && r.Method == "GET" {
body := "<a href=\"" + htmlEscape(url) + "\">" + StatusText(code) + "</a>.\n"
fmt.Fprintln(w, body)
}
}
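// A usage sketch (illustrative; the paths are hypothetical, assumes
// `import "net/http"`): a handler that redirects permanently, relying on
// Redirect to set Location and write the small HTML fallback body:
//
//	func movedHandler(w http.ResponseWriter, r *http.Request) {
//		http.Redirect(w, r, "/new-location", http.StatusMovedPermanently)
//	}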
var htmlReplacer = strings.NewReplacer(
"&", "&",
"<", "<",
">", ">",
// """ is shorter than """.
`"`, """,
// "'" is shorter than "'" and apos was not in HTML until HTML5.
"'", "'",
)
func htmlEscape(s string) string {
return htmlReplacer.Replace(s)
}
// Redirect to a fixed URL
type redirectHandler struct {
url string
code int
}
func (rh *redirectHandler) ServeHTTP(w ResponseWriter, r *Request) {
Redirect(w, r, rh.url, rh.code)
}
// RedirectHandler returns a request handler that redirects
// each request it receives to the given url using the given
// status code.
//
// The provided code should be in the 3xx range and is usually
// StatusMovedPermanently, StatusFound or StatusSeeOther.
func RedirectHandler(url string, code int) Handler {
return &redirectHandler{url, code}
}
// ServeMux is an HTTP request multiplexer.
// It matches the URL of each incoming request against a list of registered
// patterns and calls the handler for the pattern that
// most closely matches the URL.
//
// Patterns name fixed, rooted paths, like "/favicon.ico",
// or rooted subtrees, like "/images/" (note the trailing slash).
// Longer patterns take precedence over shorter ones, so that
// if there are handlers registered for both "/images/"
// and "/images/thumbnails/", the latter handler will be
// called for paths beginning with "/images/thumbnails/" and the
// former will receive requests for any other paths in the
// "/images/" subtree.
//
// Note that since a pattern ending in a slash names a rooted subtree,
// the pattern "/" matches all paths not matched by other registered
// patterns, not just the URL with Path == "/".
//
// If a subtree has been registered and a request is received naming the
// subtree root without its trailing slash, ServeMux redirects that
// request to the subtree root (adding the trailing slash). This behavior can
// be overridden with a separate registration for the path without
// the trailing slash. For example, registering "/images/" causes ServeMux
// to redirect a request for "/images" to "/images/", unless "/images" has
// been registered separately.
//
// Patterns may optionally begin with a host name, restricting matches to
// URLs on that host only. Host-specific patterns take precedence over
// general patterns, so that a handler might register for the two patterns
// "/codesearch" and "codesearch.google.com/" without also taking over
// requests for "http://www.google.com/".
//
// ServeMux also takes care of sanitizing the URL request path and the Host
// header, stripping the port number and redirecting any request containing . or
// .. elements or repeated slashes to an equivalent, cleaner URL.
type ServeMux struct {
mu sync.RWMutex
m map[string]muxEntry
es []muxEntry // slice of entries sorted from longest to shortest.
hosts bool // whether any patterns contain hostnames
}
type muxEntry struct {
h Handler
pattern string
}
// NewServeMux allocates and returns a new ServeMux.
func NewServeMux() *ServeMux { return new(ServeMux) }
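// A registration sketch showing the precedence rules described above
// (illustrative; the handler values are hypothetical, assumes
// `import "net/http"`):
//
//	mux := http.NewServeMux()
//	mux.Handle("/favicon.ico", iconHandler)         // fixed, rooted path
//	mux.Handle("/images/", imageHandler)            // rooted subtree
//	mux.Handle("/images/thumbnails/", thumbHandler) // longer pattern wins
//	// A request for /images/cat.png goes to imageHandler; one for
//	// /images/thumbnails/cat.png goes to thumbHandler.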
// DefaultServeMux is the default ServeMux used by Serve.
var DefaultServeMux = &defaultServeMux
var defaultServeMux ServeMux
// cleanPath returns the canonical path for p, eliminating . and .. elements.
func cleanPath(p string) string {
if p == "" {
return "/"
}
if p[0] != '/' {
p = "/" + p
}
np := path.Clean(p)
// path.Clean removes trailing slash except for root;
// put the trailing slash back if necessary.
if p[len(p)-1] == '/' && np != "/" {
// Fast path for common case of p being the string we want:
if len(p) == len(np)+1 && strings.HasPrefix(p, np) {
np = p
} else {
np += "/"
}
}
return np
}
// stripHostPort returns h without any trailing ":<port>".
func stripHostPort(h string) string {
// If no port on host, return unchanged
if !strings.Contains(h, ":") {
return h
}
host, _, err := net.SplitHostPort(h)
if err != nil {
return h // on error, return unchanged
}
return host
}
// Find a handler on a handler map given a path string.
// Most-specific (longest) pattern wins.
func (mux *ServeMux) match(path string) (h Handler, pattern string) {
// Check for exact match first.
v, ok := mux.m[path]
if ok {
return v.h, v.pattern
}
// Check for longest valid match. mux.es contains all patterns
// that end in / sorted from longest to shortest.
for _, e := range mux.es {
if strings.HasPrefix(path, e.pattern) {
return e.h, e.pattern
}
}
return nil, ""
}
// redirectToPathSlash reports whether the given path needs a trailing
// slash appended. That is the case when a handler for path + "/" was
// already registered, but not for path itself. If a redirect is needed,
// it returns a new URL with the path set to u.Path + "/" and true;
// otherwise it returns u unchanged and false.
func (mux *ServeMux) redirectToPathSlash(host, path string, u *url.URL) (*url.URL, bool) {
mux.mu.RLock()
shouldRedirect := mux.shouldRedirectRLocked(host, path)
mux.mu.RUnlock()
if !shouldRedirect {
return u, false
}
path = path + "/"
u = &url.URL{Path: path, RawQuery: u.RawQuery}
return u, true
}
// shouldRedirectRLocked reports whether the given path and host should be redirected to
// path+"/". This should happen if a handler is registered for path+"/" but
// not path -- see comments at ServeMux.
func (mux *ServeMux) shouldRedirectRLocked(host, path string) bool {
p := []string{path, host + path}
for _, c := range p {
if _, exist := mux.m[c]; exist {
return false
}
}
n := len(path)
if n == 0 {
return false
}
for _, c := range p {
if _, exist := mux.m[c+"/"]; exist {
return path[n-1] != '/'
}
}
return false
}
// Handler returns the handler to use for the given request,
// consulting r.Method, r.Host, and r.URL.Path. It always returns
// a non-nil handler. If the path is not in its canonical form, the
// handler will be an internally-generated handler that redirects
// to the canonical path. If the host contains a port, it is ignored
// when matching handlers.
//
// The path and host are used unchanged for CONNECT requests.
//
// Handler also returns the registered pattern that matches the
// request or, in the case of internally-generated redirects,
// the pattern that will match after following the redirect.
//
// If there is no registered handler that applies to the request,
// Handler returns a “page not found” handler and an empty pattern.
func (mux *ServeMux) Handler(r *Request) (h Handler, pattern string) {
// CONNECT requests are not canonicalized.
if r.Method == "CONNECT" {
// If r.URL.Path is /tree and its handler is not registered,
// the /tree -> /tree/ redirect applies to CONNECT requests
// but the path canonicalization does not.
if u, ok := mux.redirectToPathSlash(r.URL.Host, r.URL.Path, r.URL); ok {
return RedirectHandler(u.String(), StatusMovedPermanently), u.Path
}
return mux.handler(r.Host, r.URL.Path)
}
// All other requests have any port stripped and path cleaned
// before passing to mux.handler.
host := stripHostPort(r.Host)
path := cleanPath(r.URL.Path)
// If the given path is /tree and its handler is not registered,
// redirect for /tree/.
if u, ok := mux.redirectToPathSlash(host, path, r.URL); ok {
return RedirectHandler(u.String(), StatusMovedPermanently), u.Path
}
if path != r.URL.Path {
_, pattern = mux.handler(host, path)
u := &url.URL{Path: path, RawQuery: r.URL.RawQuery}
return RedirectHandler(u.String(), StatusMovedPermanently), pattern
}
return mux.handler(host, r.URL.Path)
}
// handler is the main implementation of Handler.
// The path is known to be in canonical form, except for CONNECT methods.
func (mux *ServeMux) handler(host, path string) (h Handler, pattern string) {
mux.mu.RLock()
defer mux.mu.RUnlock()
// Host-specific pattern takes precedence over generic ones
if mux.hosts {
h, pattern = mux.match(host + path)
}
if h == nil {
h, pattern = mux.match(path)
}
if h == nil {
h, pattern = NotFoundHandler(), ""
}
return
}
// ServeHTTP dispatches the request to the handler whose
// pattern most closely matches the request URL.
func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) {
if r.RequestURI == "*" {
if r.ProtoAtLeast(1, 1) {
w.Header().Set("Connection", "close")
}
w.WriteHeader(StatusBadRequest)
return
}
h, _ := mux.Handler(r)
h.ServeHTTP(w, r)
}
// Handle registers the handler for the given pattern.
// If a handler already exists for pattern, Handle panics.
func (mux *ServeMux) Handle(pattern string, handler Handler) {
mux.mu.Lock()
defer mux.mu.Unlock()
if pattern == "" {
panic("http: invalid pattern")
}
if handler == nil {
panic("http: nil handler")
}
if _, exist := mux.m[pattern]; exist {
panic("http: multiple registrations for " + pattern)
}
if mux.m == nil {
mux.m = make(map[string]muxEntry)
}
e := muxEntry{h: handler, pattern: pattern}
mux.m[pattern] = e
if pattern[len(pattern)-1] == '/' {
mux.es = appendSorted(mux.es, e)
}
if pattern[0] != '/' {
mux.hosts = true
}
}
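// A minimal registration sketch (apiHandler, pageHandler, and hostHandler
// are hypothetical Handler values supplied by the caller):
//
//	mux := NewServeMux()
//	mux.Handle("/api/", apiHandler)         // subtree pattern: matches /api/ and below
//	mux.Handle("/index.html", pageHandler)  // exact-path pattern
//	mux.Handle("example.com/", hostHandler) // host-specific pattern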
func appendSorted(es []muxEntry, e muxEntry) []muxEntry {
n := len(es)
i := sort.Search(n, func(i int) bool {
return len(es[i].pattern) < len(e.pattern)
})
if i == n {
return append(es, e)
}
// we now know that i points at where we want to insert
es = append(es, muxEntry{}) // try to grow the slice in place, any entry works.
copy(es[i+1:], es[i:]) // Move shorter entries down
es[i] = e
return es
}
// HandleFunc registers the handler function for the given pattern.
func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Request)) {
if handler == nil {
panic("http: nil handler")
}
mux.Handle(pattern, HandlerFunc(handler))
}
// Handle registers the handler for the given pattern
// in the DefaultServeMux.
// The documentation for ServeMux explains how patterns are matched.
func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) }
// HandleFunc registers the handler function for the given pattern
// in the DefaultServeMux.
// The documentation for ServeMux explains how patterns are matched.
func HandleFunc(pattern string, handler func(ResponseWriter, *Request)) {
DefaultServeMux.HandleFunc(pattern, handler)
}
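// A minimal end-to-end sketch using the default mux, written from a client
// package's point of view (assumes the usual fmt and log imports there):
//
//	http.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) {
//		fmt.Fprintln(w, "hello")
//	})
//	log.Fatal(http.ListenAndServe(":8080", nil)) // nil handler means DefaultServeMux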
// Serve accepts incoming HTTP connections on the listener l,
// creating a new service goroutine for each. The service goroutines
// read requests and then call handler to reply to them.
//
// The handler is typically nil, in which case the DefaultServeMux is used.
//
// HTTP/2 support is only enabled if the Listener returns *tls.Conn
// connections and they were configured with "h2" in the TLS
// Config.NextProtos.
//
// Serve always returns a non-nil error.
func Serve(l net.Listener, handler Handler) error {
srv := &Server{Handler: handler}
return srv.Serve(l)
}
// ServeTLS accepts incoming HTTPS connections on the listener l,
// creating a new service goroutine for each. The service goroutines
// read requests and then call handler to reply to them.
//
// The handler is typically nil, in which case the DefaultServeMux is used.
//
// Additionally, files containing a certificate and matching private key
// for the server must be provided. If the certificate is signed by a
// certificate authority, the certFile should be the concatenation
// of the server's certificate, any intermediates, and the CA's certificate.
//
// ServeTLS always returns a non-nil error.
func ServeTLS(l net.Listener, handler Handler, certFile, keyFile string) error {
srv := &Server{Handler: handler}
return srv.ServeTLS(l, certFile, keyFile)
}
// A Server defines parameters for running an HTTP server.
// The zero value for Server is a valid configuration.
type Server struct {
// Addr optionally specifies the TCP address for the server to listen on,
// in the form "host:port". If empty, ":http" (port 80) is used.
// The service names are defined in RFC 6335 and assigned by IANA.
// See net.Dial for details of the address format.
Addr string
Handler Handler // handler to invoke, http.DefaultServeMux if nil
// DisableGeneralOptionsHandler, if true, passes "OPTIONS *" requests to the Handler,
// otherwise responds with 200 OK and Content-Length: 0.
DisableGeneralOptionsHandler bool
// TLSConfig optionally provides a TLS configuration for use
// by ServeTLS and ListenAndServeTLS. Note that this value is
// cloned by ServeTLS and ListenAndServeTLS, so it's not
// possible to modify the configuration with methods like
// tls.Config.SetSessionTicketKeys. To use
// SetSessionTicketKeys, use Server.Serve with a TLS Listener
// instead.
TLSConfig *tls.Config
// ReadTimeout is the maximum duration for reading the entire
// request, including the body. A zero or negative value means
// there will be no timeout.
//
// Because ReadTimeout does not let Handlers make per-request
// decisions on each request body's acceptable deadline or
// upload rate, most users will prefer to use
// ReadHeaderTimeout. It is valid to use them both.
ReadTimeout time.Duration
// ReadHeaderTimeout is the amount of time allowed to read
// request headers. The connection's read deadline is reset
// after reading the headers and the Handler can decide what
// is considered too slow for the body. If ReadHeaderTimeout
// is zero, the value of ReadTimeout is used. If both are
// zero, there is no timeout.
ReadHeaderTimeout time.Duration
// WriteTimeout is the maximum duration before timing out
// writes of the response. It is reset whenever a new
// request's header is read. Like ReadTimeout, it does not
// let Handlers make decisions on a per-request basis.
// A zero or negative value means there will be no timeout.
WriteTimeout time.Duration
// IdleTimeout is the maximum amount of time to wait for the
// next request when keep-alives are enabled. If IdleTimeout
// is zero, the value of ReadTimeout is used. If both are
// zero, there is no timeout.
IdleTimeout time.Duration
// MaxHeaderBytes controls the maximum number of bytes the
// server will read parsing the request header's keys and
// values, including the request line. It does not limit the
// size of the request body.
// If zero, DefaultMaxHeaderBytes is used.
MaxHeaderBytes int
// TLSNextProto optionally specifies a function to take over
// ownership of the provided TLS connection when an ALPN
// protocol upgrade has occurred. The map key is the protocol
// name negotiated. The Handler argument should be used to
// handle HTTP requests and will initialize the Request's TLS
// and RemoteAddr if not already set. The connection is
// automatically closed when the function returns.
// If TLSNextProto is not nil, HTTP/2 support is not enabled
// automatically.
TLSNextProto map[string]func(*Server, *tls.Conn, Handler)
// ConnState specifies an optional callback function that is
// called when a client connection changes state. See the
// ConnState type and associated constants for details.
ConnState func(net.Conn, ConnState)
// ErrorLog specifies an optional logger for errors accepting
// connections, unexpected behavior from handlers, and
// underlying FileSystem errors.
// If nil, logging is done via the log package's standard logger.
ErrorLog *log.Logger
// BaseContext optionally specifies a function that returns
// the base context for incoming requests on this server.
// The provided Listener is the specific Listener that's
// about to start accepting requests.
// If BaseContext is nil, the default is context.Background().
// If non-nil, it must return a non-nil context.
BaseContext func(net.Listener) context.Context
// ConnContext optionally specifies a function that modifies
// the context used for a new connection c. The provided ctx
// is derived from the base context and has a ServerContextKey
// value.
ConnContext func(ctx context.Context, c net.Conn) context.Context
inShutdown atomic.Bool // true when server is in shutdown
disableKeepAlives atomic.Bool
nextProtoOnce sync.Once // guards setupHTTP2_* init
nextProtoErr error // result of http2.ConfigureServer if used
mu sync.Mutex
listeners map[*net.Listener]struct{}
activeConn map[*conn]struct{}
onShutdown []func()
listenerGroup sync.WaitGroup
}
// Close immediately closes all active net.Listeners and any
// connections in state StateNew, StateActive, or StateIdle. For a
// graceful shutdown, use Shutdown.
//
// Close does not attempt to close (and does not even know about)
// any hijacked connections, such as WebSockets.
//
// Close returns any error returned from closing the Server's
// underlying Listener(s).
func (srv *Server) Close() error {
srv.inShutdown.Store(true)
srv.mu.Lock()
defer srv.mu.Unlock()
err := srv.closeListenersLocked()
// Unlock srv.mu while waiting for listenerGroup.
// The group Add and Done calls are made with srv.mu held,
// to avoid adding a new listener in the window between
// us setting inShutdown above and waiting here.
srv.mu.Unlock()
srv.listenerGroup.Wait()
srv.mu.Lock()
for c := range srv.activeConn {
c.rwc.Close()
delete(srv.activeConn, c)
}
return err
}
// shutdownPollIntervalMax is the max polling interval when checking
// quiescence during Server.Shutdown. Polling starts with a small
// interval and backs off to the max.
// Ideally we could find a solution that doesn't involve polling,
// but which also doesn't have a high runtime cost (and doesn't
// involve any contentious mutexes), but that is left as an
// exercise for the reader.
const shutdownPollIntervalMax = 500 * time.Millisecond
// Shutdown gracefully shuts down the server without interrupting any
// active connections. Shutdown works by first closing all open
// listeners, then closing all idle connections, and then waiting
// indefinitely for connections to return to idle and then shut down.
// If the provided context expires before the shutdown is complete,
// Shutdown returns the context's error, otherwise it returns any
// error returned from closing the Server's underlying Listener(s).
//
// When Shutdown is called, Serve, ListenAndServe, and
// ListenAndServeTLS immediately return ErrServerClosed. Make sure the
// program doesn't exit; instead, wait for Shutdown to return.
//
// Shutdown does not attempt to close nor wait for hijacked
// connections such as WebSockets. The caller of Shutdown should
// separately notify such long-lived connections of shutdown and wait
// for them to close, if desired. See RegisterOnShutdown for a way to
// register shutdown notification functions.
//
// Once Shutdown has been called on a server, it may not be reused;
// future calls to methods such as Serve will return ErrServerClosed.
func (srv *Server) Shutdown(ctx context.Context) error {
srv.inShutdown.Store(true)
srv.mu.Lock()
lnerr := srv.closeListenersLocked()
for _, f := range srv.onShutdown {
go f()
}
srv.mu.Unlock()
srv.listenerGroup.Wait()
pollIntervalBase := time.Millisecond
nextPollInterval := func() time.Duration {
// Add 10% jitter.
interval := pollIntervalBase + time.Duration(rand.Intn(int(pollIntervalBase/10)))
// Double and clamp for next time.
pollIntervalBase *= 2
if pollIntervalBase > shutdownPollIntervalMax {
pollIntervalBase = shutdownPollIntervalMax
}
return interval
}
timer := time.NewTimer(nextPollInterval())
defer timer.Stop()
for {
if srv.closeIdleConns() {
return lnerr
}
select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
timer.Reset(nextPollInterval())
}
}
}
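// A graceful-shutdown sketch (the stop channel, signal wiring, and the
// 5-second budget are caller-side assumptions, not part of this package):
//
//	srv := &http.Server{Addr: ":8080", Handler: mux}
//	go func() {
//		if err := srv.ListenAndServe(); err != http.ErrServerClosed {
//			log.Fatalf("listen: %v", err)
//		}
//	}()
//	<-stop // e.g. a channel fed by signal.Notify
//	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
//	defer cancel()
//	if err := srv.Shutdown(ctx); err != nil {
//		log.Printf("shutdown: %v", err)
//	}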
// RegisterOnShutdown registers a function to call on Shutdown.
// This can be used to gracefully shut down connections that have
// undergone ALPN protocol upgrade or that have been hijacked.
// This function should start protocol-specific graceful shutdown,
// but should not wait for shutdown to complete.
func (srv *Server) RegisterOnShutdown(f func()) {
srv.mu.Lock()
srv.onShutdown = append(srv.onShutdown, f)
srv.mu.Unlock()
}
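// A notification sketch (wsConns is a hypothetical caller-owned registry of
// hijacked, long-lived connections; the callback must not block):
//
//	srv.RegisterOnShutdown(func() {
//		wsConns.CloseAll() // begin closing; do not wait here
//	})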
// closeIdleConns closes all idle connections and reports whether the
// server is quiescent.
func (s *Server) closeIdleConns() bool {
s.mu.Lock()
defer s.mu.Unlock()
quiescent := true
for c := range s.activeConn {
st, unixSec := c.getState()
// Issue 22682: treat StateNew connections as if
// they're idle if we haven't read the first request's
// header in over 5 seconds.
if st == StateNew && unixSec < time.Now().Unix()-5 {
st = StateIdle
}
if st != StateIdle || unixSec == 0 {
// Assume unixSec == 0 means it's a very new
// connection, without state set yet.
quiescent = false
continue
}
c.rwc.Close()
delete(s.activeConn, c)
}
return quiescent
}
func (s *Server) closeListenersLocked() error {
var err error
for ln := range s.listeners {
if cerr := (*ln).Close(); cerr != nil && err == nil {
err = cerr
}
}
return err
}
// A ConnState represents the state of a client connection to a server.
// It's used by the optional Server.ConnState hook.
type ConnState int
const (
// StateNew represents a new connection that is expected to
// send a request immediately. Connections begin at this
// state and then transition to either StateActive or
// StateClosed.
StateNew ConnState = iota
// StateActive represents a connection that has read 1 or more
// bytes of a request. The Server.ConnState hook for
// StateActive fires before the request has entered a handler
// and doesn't fire again until the request has been
// handled. After the request is handled, the state
// transitions to StateClosed, StateHijacked, or StateIdle.
// For HTTP/2, StateActive fires on the transition from zero
// to one active request, and only transitions away once all
// active requests are complete. That means that ConnState
// cannot be used to do per-request work; ConnState only notes
// the overall state of the connection.
StateActive
// StateIdle represents a connection that has finished
// handling a request and is in the keep-alive state, waiting
// for a new request. Connections transition from StateIdle
// to either StateActive or StateClosed.
StateIdle
// StateHijacked represents a hijacked connection.
// This is a terminal state. It does not transition to StateClosed.
StateHijacked
// StateClosed represents a closed connection.
// This is a terminal state. Hijacked connections do not
// transition to StateClosed.
StateClosed
)
var stateName = map[ConnState]string{
StateNew: "new",
StateActive: "active",
StateIdle: "idle",
StateHijacked: "hijacked",
StateClosed: "closed",
}
func (c ConnState) String() string {
return stateName[c]
}
// serverHandler delegates to either the server's Handler or
// DefaultServeMux and also handles "OPTIONS *" requests.
type serverHandler struct {
srv *Server
}
func (sh serverHandler) ServeHTTP(rw ResponseWriter, req *Request) {
handler := sh.srv.Handler
if handler == nil {
handler = DefaultServeMux
}
if !sh.srv.DisableGeneralOptionsHandler && req.RequestURI == "*" && req.Method == "OPTIONS" {
handler = globalOptionsHandler{}
}
handler.ServeHTTP(rw, req)
}
// AllowQuerySemicolons returns a handler that serves requests by converting any
// unescaped semicolons in the URL query to ampersands, and invoking the handler h.
//
// This restores the pre-Go 1.17 behavior of splitting query parameters on both
// semicolons and ampersands. (See golang.org/issue/25192). Note that this
// behavior doesn't match that of many proxies, and the mismatch can lead to
// security issues.
//
// AllowQuerySemicolons should be invoked before Request.ParseForm is called.
func AllowQuerySemicolons(h Handler) Handler {
return HandlerFunc(func(w ResponseWriter, r *Request) {
if strings.Contains(r.URL.RawQuery, ";") {
r2 := new(Request)
*r2 = *r
r2.URL = new(url.URL)
*r2.URL = *r.URL
r2.URL.RawQuery = strings.ReplaceAll(r.URL.RawQuery, ";", "&")
h.ServeHTTP(w, r2)
} else {
h.ServeHTTP(w, r)
}
})
}
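// A wrapping sketch (mux is any Handler the caller already has; this
// restores pre-Go 1.17 semicolon splitting, with the caveats noted above):
//
//	srv := &http.Server{
//		Addr:    ":8080",
//		Handler: http.AllowQuerySemicolons(mux),
//	}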
// ListenAndServe listens on the TCP network address srv.Addr and then
// calls Serve to handle requests on incoming connections.
// Accepted connections are configured to enable TCP keep-alives.
//
// If srv.Addr is blank, ":http" is used.
//
// ListenAndServe always returns a non-nil error. After Shutdown or Close,
// the returned error is ErrServerClosed.
func (srv *Server) ListenAndServe() error {
if srv.shuttingDown() {
return ErrServerClosed
}
addr := srv.Addr
if addr == "" {
addr = ":http"
}
ln, err := net.Listen("tcp", addr)
if err != nil {
return err
}
return srv.Serve(ln)
}
var testHookServerServe func(*Server, net.Listener) // used if non-nil
// shouldConfigureHTTP2ForServe reports whether Server.Serve should configure
// automatic HTTP/2 (which sets up the srv.TLSNextProto map).
func (srv *Server) shouldConfigureHTTP2ForServe() bool {
if srv.TLSConfig == nil {
// Compatibility with Go 1.6:
// If there's no TLSConfig, it's possible that the user just
// didn't set it on the http.Server, but did pass it to
// tls.NewListener and passed that listener to Serve.
// So we should configure HTTP/2 (to set up srv.TLSNextProto)
// in case the listener returns an "h2" *tls.Conn.
return true
}
// The user specified a TLSConfig on their http.Server.
// In this case, only configure HTTP/2 if their tls.Config
// explicitly mentions "h2". Otherwise http2.ConfigureServer
// would modify the tls.Config to add it, but they probably already
// passed this tls.Config to tls.NewListener. And if they did,
// it's too late anyway to fix it. It would only be potentially racy.
// See Issue 15908.
return strSliceContains(srv.TLSConfig.NextProtos, http2NextProtoTLS)
}
// ErrServerClosed is returned by the Server's Serve, ServeTLS, ListenAndServe,
// and ListenAndServeTLS methods after a call to Shutdown or Close.
var ErrServerClosed = errors.New("http: Server closed")
// Serve accepts incoming connections on the Listener l, creating a
// new service goroutine for each. The service goroutines read requests and
// then call srv.Handler to reply to them.
//
// HTTP/2 support is only enabled if the Listener returns *tls.Conn
// connections and they were configured with "h2" in the TLS
// Config.NextProtos.
//
// Serve always returns a non-nil error and closes l.
// After Shutdown or Close, the returned error is ErrServerClosed.
func (srv *Server) Serve(l net.Listener) error {
if fn := testHookServerServe; fn != nil {
fn(srv, l) // call hook with unwrapped listener
}
origListener := l
l = &onceCloseListener{Listener: l}
defer l.Close()
if err := srv.setupHTTP2_Serve(); err != nil {
return err
}
if !srv.trackListener(&l, true) {
return ErrServerClosed
}
defer srv.trackListener(&l, false)
baseCtx := context.Background()
if srv.BaseContext != nil {
baseCtx = srv.BaseContext(origListener)
if baseCtx == nil {
panic("BaseContext returned a nil context")
}
}
var tempDelay time.Duration // how long to sleep on accept failure
ctx := context.WithValue(baseCtx, ServerContextKey, srv)
for {
rw, err := l.Accept()
if err != nil {
if srv.shuttingDown() {
return ErrServerClosed
}
if ne, ok := err.(net.Error); ok && ne.Temporary() {
if tempDelay == 0 {
tempDelay = 5 * time.Millisecond
} else {
tempDelay *= 2
}
if max := 1 * time.Second; tempDelay > max {
tempDelay = max
}
srv.logf("http: Accept error: %v; retrying in %v", err, tempDelay)
time.Sleep(tempDelay)
continue
}
return err
}
connCtx := ctx
if cc := srv.ConnContext; cc != nil {
connCtx = cc(connCtx, rw)
if connCtx == nil {
panic("ConnContext returned nil")
}
}
tempDelay = 0
c := srv.newConn(rw)
c.setState(c.rwc, StateNew, runHooks) // before Serve can return
go c.serve(connCtx)
}
}
// ServeTLS accepts incoming connections on the Listener l, creating a
// new service goroutine for each. The service goroutines perform TLS
// setup and then read requests, calling srv.Handler to reply to them.
//
// Files containing a certificate and matching private key for the
// server must be provided if neither the Server's
// TLSConfig.Certificates nor TLSConfig.GetCertificate are populated.
// If the certificate is signed by a certificate authority, the
// certFile should be the concatenation of the server's certificate,
// any intermediates, and the CA's certificate.
//
// ServeTLS always returns a non-nil error. After Shutdown or Close, the
// returned error is ErrServerClosed.
func (srv *Server) ServeTLS(l net.Listener, certFile, keyFile string) error {
// Setup HTTP/2 before srv.Serve, to initialize srv.TLSConfig
// before we clone it and create the TLS Listener.
if err := srv.setupHTTP2_ServeTLS(); err != nil {
return err
}
config := cloneTLSConfig(srv.TLSConfig)
if !strSliceContains(config.NextProtos, "http/1.1") {
config.NextProtos = append(config.NextProtos, "http/1.1")
}
configHasCert := len(config.Certificates) > 0 || config.GetCertificate != nil
if !configHasCert || certFile != "" || keyFile != "" {
var err error
config.Certificates = make([]tls.Certificate, 1)
config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return err
}
}
tlsListener := tls.NewListener(l, config)
return srv.Serve(tlsListener)
}
// trackListener adds or removes a net.Listener to the set of tracked
// listeners.
//
// We store a pointer to interface in the map set, in case the
// net.Listener is not comparable. This is safe because we only call
// trackListener via Serve and can track+defer untrack the same
// pointer to local variable there. We never need to compare a
// Listener from another caller.
//
// It reports whether the server is still up (not Shutdown or Closed).
func (s *Server) trackListener(ln *net.Listener, add bool) bool {
s.mu.Lock()
defer s.mu.Unlock()
if s.listeners == nil {
s.listeners = make(map[*net.Listener]struct{})
}
if add {
if s.shuttingDown() {
return false
}
s.listeners[ln] = struct{}{}
s.listenerGroup.Add(1)
} else {
delete(s.listeners, ln)
s.listenerGroup.Done()
}
return true
}
func (s *Server) trackConn(c *conn, add bool) {
s.mu.Lock()
defer s.mu.Unlock()
if s.activeConn == nil {
s.activeConn = make(map[*conn]struct{})
}
if add {
s.activeConn[c] = struct{}{}
} else {
delete(s.activeConn, c)
}
}
func (s *Server) idleTimeout() time.Duration {
if s.IdleTimeout != 0 {
return s.IdleTimeout
}
return s.ReadTimeout
}
func (s *Server) readHeaderTimeout() time.Duration {
if s.ReadHeaderTimeout != 0 {
return s.ReadHeaderTimeout
}
return s.ReadTimeout
}
func (s *Server) doKeepAlives() bool {
return !s.disableKeepAlives.Load() && !s.shuttingDown()
}
func (s *Server) shuttingDown() bool {
return s.inShutdown.Load()
}
// SetKeepAlivesEnabled controls whether HTTP keep-alives are enabled.
// By default, keep-alives are always enabled. Only very
// resource-constrained environments or servers in the process of
// shutting down should disable them.
func (srv *Server) SetKeepAlivesEnabled(v bool) {
if v {
srv.disableKeepAlives.Store(false)
return
}
srv.disableKeepAlives.Store(true)
// Close idle HTTP/1 conns:
srv.closeIdleConns()
// TODO: Issue 26303: close HTTP/2 conns as soon as they become idle.
}
func (s *Server) logf(format string, args ...any) {
if s.ErrorLog != nil {
s.ErrorLog.Printf(format, args...)
} else {
log.Printf(format, args...)
}
}
// logf prints to the ErrorLog of the *Server associated with request r
// via ServerContextKey. If there's no associated server, or if ErrorLog
// is nil, logging is done via the log package's standard logger.
func logf(r *Request, format string, args ...any) {
s, _ := r.Context().Value(ServerContextKey).(*Server)
if s != nil && s.ErrorLog != nil {
s.ErrorLog.Printf(format, args...)
} else {
log.Printf(format, args...)
}
}
// ListenAndServe listens on the TCP network address addr and then calls
// Serve with handler to handle requests on incoming connections.
// Accepted connections are configured to enable TCP keep-alives.
//
// The handler is typically nil, in which case the DefaultServeMux is used.
//
// ListenAndServe always returns a non-nil error.
func ListenAndServe(addr string, handler Handler) error {
server := &Server{Addr: addr, Handler: handler}
return server.ListenAndServe()
}
// ListenAndServeTLS acts identically to ListenAndServe, except that it
// expects HTTPS connections. Additionally, files containing a certificate and
// matching private key for the server must be provided. If the certificate
// is signed by a certificate authority, the certFile should be the concatenation
// of the server's certificate, any intermediates, and the CA's certificate.
func ListenAndServeTLS(addr, certFile, keyFile string, handler Handler) error {
server := &Server{Addr: addr, Handler: handler}
return server.ListenAndServeTLS(certFile, keyFile)
}
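// A TLS sketch ("cert.pem" and "key.pem" are placeholder file names; see
// the certFile concatenation note above when a CA signed the certificate):
//
//	log.Fatal(http.ListenAndServeTLS(":443", "cert.pem", "key.pem", nil))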
// ListenAndServeTLS listens on the TCP network address srv.Addr and
// then calls ServeTLS to handle requests on incoming TLS connections.
// Accepted connections are configured to enable TCP keep-alives.
//
// Filenames containing a certificate and matching private key for the
// server must be provided if neither the Server's TLSConfig.Certificates
// nor TLSConfig.GetCertificate are populated. If the certificate is
// signed by a certificate authority, the certFile should be the
// concatenation of the server's certificate, any intermediates, and
// the CA's certificate.
//
// If srv.Addr is blank, ":https" is used.
//
// ListenAndServeTLS always returns a non-nil error. After Shutdown or
// Close, the returned error is ErrServerClosed.
func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error {
if srv.shuttingDown() {
return ErrServerClosed
}
addr := srv.Addr
if addr == "" {
addr = ":https"
}
ln, err := net.Listen("tcp", addr)
if err != nil {
return err
}
defer ln.Close()
return srv.ServeTLS(ln, certFile, keyFile)
}
// setupHTTP2_ServeTLS conditionally configures HTTP/2 on
// srv and reports whether there was an error setting it up. If it is
// not configured for policy reasons, nil is returned.
func (srv *Server) setupHTTP2_ServeTLS() error {
srv.nextProtoOnce.Do(srv.onceSetNextProtoDefaults)
return srv.nextProtoErr
}
// setupHTTP2_Serve is called from (*Server).Serve and conditionally
// configures HTTP/2 on srv using a more conservative policy than
// setupHTTP2_ServeTLS because Serve is called after tls.Listen,
// and may be called concurrently. See shouldConfigureHTTP2ForServe.
//
// The tests named TestTransportAutomaticHTTP2* and
// TestConcurrentServerServe in server_test.go demonstrate some
// of the supported use cases and motivations.
func (srv *Server) setupHTTP2_Serve() error {
srv.nextProtoOnce.Do(srv.onceSetNextProtoDefaults_Serve)
return srv.nextProtoErr
}
func (srv *Server) onceSetNextProtoDefaults_Serve() {
if srv.shouldConfigureHTTP2ForServe() {
srv.onceSetNextProtoDefaults()
}
}
var http2server = godebug.New("http2server")
// onceSetNextProtoDefaults configures HTTP/2, if the user hasn't
// configured otherwise (by setting srv.TLSNextProto non-nil).
// It must only be called via srv.nextProtoOnce (use srv.setupHTTP2_*).
func (srv *Server) onceSetNextProtoDefaults() {
if omitBundledHTTP2 {
return
}
if http2server.Value() == "0" {
http2server.IncNonDefault()
return
}
// Enable HTTP/2 by default if the user hasn't otherwise
// configured their TLSNextProto map.
if srv.TLSNextProto == nil {
conf := &http2Server{
NewWriteScheduler: func() http2WriteScheduler { return http2NewPriorityWriteScheduler(nil) },
}
srv.nextProtoErr = http2ConfigureServer(srv, conf)
}
}
// TimeoutHandler returns a Handler that runs h with the given time limit.
//
// The new Handler calls h.ServeHTTP to handle each request, but if a
// call runs for longer than its time limit, the handler responds with
// a 503 Service Unavailable error and the given message in its body.
// (If msg is empty, a suitable default message will be sent.)
// After such a timeout, writes by h to its ResponseWriter will return
// ErrHandlerTimeout.
//
// TimeoutHandler supports the Pusher interface but does not support
// the Hijacker or Flusher interfaces.
func TimeoutHandler(h Handler, dt time.Duration, msg string) Handler {
return &timeoutHandler{
handler: h,
body: msg,
dt: dt,
}
}
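// A usage sketch (slowHandler is a hypothetical Handler; requests running
// longer than two seconds get a 503 with the given message):
//
//	h := http.TimeoutHandler(slowHandler, 2*time.Second, "request timed out")
//	http.Handle("/slow", h)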
// ErrHandlerTimeout is returned on ResponseWriter Write calls
// in handlers which have timed out.
var ErrHandlerTimeout = errors.New("http: Handler timeout")
type timeoutHandler struct {
handler Handler
body string
dt time.Duration
// When set, no context will be created and this context will
// be used instead.
testContext context.Context
}
func (h *timeoutHandler) errorBody() string {
if h.body != "" {
return h.body
}
return "<html><head><title>Timeout</title></head><body><h1>Timeout</h1></body></html>"
}
func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) {
ctx := h.testContext
if ctx == nil {
var cancelCtx context.CancelFunc
ctx, cancelCtx = context.WithTimeout(r.Context(), h.dt)
defer cancelCtx()
}
r = r.WithContext(ctx)
done := make(chan struct{})
tw := &timeoutWriter{
w: w,
h: make(Header),
req: r,
}
panicChan := make(chan any, 1)
go func() {
defer func() {
if p := recover(); p != nil {
panicChan <- p
}
}()
h.handler.ServeHTTP(tw, r)
close(done)
}()
select {
case p := <-panicChan:
panic(p)
case <-done:
tw.mu.Lock()
defer tw.mu.Unlock()
dst := w.Header()
for k, vv := range tw.h {
dst[k] = vv
}
if !tw.wroteHeader {
tw.code = StatusOK
}
w.WriteHeader(tw.code)
w.Write(tw.wbuf.Bytes())
case <-ctx.Done():
tw.mu.Lock()
defer tw.mu.Unlock()
switch err := ctx.Err(); err {
case context.DeadlineExceeded:
w.WriteHeader(StatusServiceUnavailable)
io.WriteString(w, h.errorBody())
tw.err = ErrHandlerTimeout
default:
w.WriteHeader(StatusServiceUnavailable)
tw.err = err
}
}
}
type timeoutWriter struct {
w ResponseWriter
h Header
wbuf bytes.Buffer
req *Request
mu sync.Mutex
err error
wroteHeader bool
code int
}
var _ Pusher = (*timeoutWriter)(nil)
// Push implements the Pusher interface.
func (tw *timeoutWriter) Push(target string, opts *PushOptions) error {
if pusher, ok := tw.w.(Pusher); ok {
return pusher.Push(target, opts)
}
return ErrNotSupported
}
func (tw *timeoutWriter) Header() Header { return tw.h }
func (tw *timeoutWriter) Write(p []byte) (int, error) {
tw.mu.Lock()
defer tw.mu.Unlock()
if tw.err != nil {
return 0, tw.err
}
if !tw.wroteHeader {
tw.writeHeaderLocked(StatusOK)
}
return tw.wbuf.Write(p)
}
func (tw *timeoutWriter) writeHeaderLocked(code int) {
checkWriteHeaderCode(code)
switch {
case tw.err != nil:
return
case tw.wroteHeader:
if tw.req != nil {
caller := relevantCaller()
logf(tw.req, "http: superfluous response.WriteHeader call from %s (%s:%d)", caller.Function, path.Base(caller.File), caller.Line)
}
default:
tw.wroteHeader = true
tw.code = code
}
}
func (tw *timeoutWriter) WriteHeader(code int) {
tw.mu.Lock()
defer tw.mu.Unlock()
tw.writeHeaderLocked(code)
}
// onceCloseListener wraps a net.Listener, protecting it from
// multiple Close calls.
type onceCloseListener struct {
net.Listener
once sync.Once
closeErr error
}
func (oc *onceCloseListener) Close() error {
oc.once.Do(oc.close)
return oc.closeErr
}
func (oc *onceCloseListener) close() { oc.closeErr = oc.Listener.Close() }
// globalOptionsHandler responds to "OPTIONS *" requests.
type globalOptionsHandler struct{}
func (globalOptionsHandler) ServeHTTP(w ResponseWriter, r *Request) {
w.Header().Set("Content-Length", "0")
if r.ContentLength != 0 {
// Read up to 4KB of OPTIONS body (as mentioned in the
// spec as being reserved for future use), but anything
// over that is considered a waste of server resources
// (or an attack) and we abort and close the connection,
// courtesy of MaxBytesReader's EOF behavior.
mb := MaxBytesReader(w, r.Body, 4<<10)
io.Copy(io.Discard, mb)
}
}
// initALPNRequest is an HTTP handler that initializes certain
// uninitialized fields in its *Request. Such partially-initialized
// Requests come from ALPN protocol handlers.
type initALPNRequest struct {
ctx context.Context
c *tls.Conn
h serverHandler
}
// BaseContext is an exported but unadvertised http.Handler method
// recognized by x/net/http2 to pass down a context; the TLSNextProto
// API predates context support so we shoehorn through the only
// interface we have available.
func (h initALPNRequest) BaseContext() context.Context { return h.ctx }
func (h initALPNRequest) ServeHTTP(rw ResponseWriter, req *Request) {
if req.TLS == nil {
req.TLS = &tls.ConnectionState{}
*req.TLS = h.c.ConnectionState()
}
if req.Body == nil {
req.Body = NoBody
}
if req.RemoteAddr == "" {
req.RemoteAddr = h.c.RemoteAddr().String()
}
h.h.ServeHTTP(rw, req)
}
// loggingConn is used for debugging.
type loggingConn struct {
name string
net.Conn
}
var (
uniqNameMu sync.Mutex
uniqNameNext = make(map[string]int)
)
func newLoggingConn(baseName string, c net.Conn) net.Conn {
uniqNameMu.Lock()
defer uniqNameMu.Unlock()
uniqNameNext[baseName]++
return &loggingConn{
name: fmt.Sprintf("%s-%d", baseName, uniqNameNext[baseName]),
Conn: c,
}
}
func (c *loggingConn) Write(p []byte) (n int, err error) {
log.Printf("%s.Write(%d) = ....", c.name, len(p))
n, err = c.Conn.Write(p)
log.Printf("%s.Write(%d) = %d, %v", c.name, len(p), n, err)
return
}
func (c *loggingConn) Read(p []byte) (n int, err error) {
log.Printf("%s.Read(%d) = ....", c.name, len(p))
n, err = c.Conn.Read(p)
log.Printf("%s.Read(%d) = %d, %v", c.name, len(p), n, err)
return
}
func (c *loggingConn) Close() (err error) {
log.Printf("%s.Close() = ...", c.name)
err = c.Conn.Close()
log.Printf("%s.Close() = %v", c.name, err)
return
}
// checkConnErrorWriter writes to c.rwc and records any write errors to c.werr.
// It only contains one field (and a pointer field at that), so it
// fits in an interface value without an extra allocation.
type checkConnErrorWriter struct {
c *conn
}
func (w checkConnErrorWriter) Write(p []byte) (n int, err error) {
n, err = w.c.rwc.Write(p)
if err != nil && w.c.werr == nil {
w.c.werr = err
w.c.cancelCtx()
}
return
}
func numLeadingCRorLF(v []byte) (n int) {
for _, b := range v {
if b == '\r' || b == '\n' {
n++
continue
}
break
}
return
}
func strSliceContains(ss []string, s string) bool {
for _, v := range ss {
if v == s {
return true
}
}
return false
}
// tlsRecordHeaderLooksLikeHTTP reports whether a TLS record header
// looks like it might've been a misdirected plaintext HTTP request.
func tlsRecordHeaderLooksLikeHTTP(hdr [5]byte) bool {
switch string(hdr[:]) {
case "GET /", "HEAD ", "POST ", "PUT /", "OPTIO":
return true
}
return false
}
// MaxBytesHandler returns a Handler that runs h with its ResponseWriter and Request.Body wrapped by a MaxBytesReader.
func MaxBytesHandler(h Handler, n int64) Handler {
return HandlerFunc(func(w ResponseWriter, r *Request) {
r2 := *r
r2.Body = MaxBytesReader(w, r.Body, n)
h.ServeHTTP(w, &r2)
})
}
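// A body-limiting sketch (uploadHandler is hypothetical; bodies beyond the
// 1 MiB cap cause reads inside the handler to fail):
//
//	limited := http.MaxBytesHandler(uploadHandler, 1<<20)
//	http.Handle("/upload", limited)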
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http
import (
"bytes"
"encoding/binary"
)
// The algorithm uses at most sniffLen bytes to make its decision.
const sniffLen = 512
// DetectContentType implements the algorithm described
// at https://mimesniff.spec.whatwg.org/ to determine the
// Content-Type of the given data. It considers at most the
// first 512 bytes of data. DetectContentType always returns
// a valid MIME type: if it cannot determine a more specific one, it
// returns "application/octet-stream".
func DetectContentType(data []byte) string {
if len(data) > sniffLen {
data = data[:sniffLen]
}
// Index of the first non-whitespace byte in data.
firstNonWS := 0
for ; firstNonWS < len(data) && isWS(data[firstNonWS]); firstNonWS++ {
}
for _, sig := range sniffSignatures {
if ct := sig.match(data, firstNonWS); ct != "" {
return ct
}
}
return "application/octet-stream" // fallback
}
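// A sniffing sketch (buf holds the leading bytes of some stream; only the
// first 512 bytes influence the result):
//
//	ct := http.DetectContentType(buf)
//	w.Header().Set("Content-Type", ct) // e.g. "image/png"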
// isWS reports whether the provided byte is a whitespace byte (0xWS)
// as defined in https://mimesniff.spec.whatwg.org/#terminology.
func isWS(b byte) bool {
switch b {
case '\t', '\n', '\x0c', '\r', ' ':
return true
}
return false
}
// isTT reports whether the provided byte is a tag-terminating byte (0xTT)
// as defined in https://mimesniff.spec.whatwg.org/#terminology.
func isTT(b byte) bool {
switch b {
case ' ', '>':
return true
}
return false
}
type sniffSig interface {
// match returns the MIME type of the data, or "" if unknown.
match(data []byte, firstNonWS int) string
}
// Data matching the table in section 6.
var sniffSignatures = []sniffSig{
htmlSig("<!DOCTYPE HTML"),
htmlSig("<HTML"),
htmlSig("<HEAD"),
htmlSig("<SCRIPT"),
htmlSig("<IFRAME"),
htmlSig("<H1"),
htmlSig("<DIV"),
htmlSig("<FONT"),
htmlSig("<TABLE"),
htmlSig("<A"),
htmlSig("<STYLE"),
htmlSig("<TITLE"),
htmlSig("<B"),
htmlSig("<BODY"),
htmlSig("<BR"),
htmlSig("<P"),
htmlSig("<!--"),
&maskedSig{
mask: []byte("\xFF\xFF\xFF\xFF\xFF"),
pat: []byte("<?xml"),
skipWS: true,
ct: "text/xml; charset=utf-8"},
&exactSig{[]byte("%PDF-"), "application/pdf"},
&exactSig{[]byte("%!PS-Adobe-"), "application/postscript"},
// UTF BOMs.
&maskedSig{
mask: []byte("\xFF\xFF\x00\x00"),
pat: []byte("\xFE\xFF\x00\x00"),
ct: "text/plain; charset=utf-16be",
},
&maskedSig{
mask: []byte("\xFF\xFF\x00\x00"),
pat: []byte("\xFF\xFE\x00\x00"),
ct: "text/plain; charset=utf-16le",
},
&maskedSig{
mask: []byte("\xFF\xFF\xFF\x00"),
pat: []byte("\xEF\xBB\xBF\x00"),
ct: "text/plain; charset=utf-8",
},
// Image types
// For posterity, we originally returned "image/vnd.microsoft.icon" from
// https://tools.ietf.org/html/draft-ietf-websec-mime-sniff-03#section-7
// https://codereview.appspot.com/4746042
// but that has since been replaced with "image/x-icon" in Section 6.2
// of https://mimesniff.spec.whatwg.org/#matching-an-image-type-pattern
&exactSig{[]byte("\x00\x00\x01\x00"), "image/x-icon"},
&exactSig{[]byte("\x00\x00\x02\x00"), "image/x-icon"},
&exactSig{[]byte("BM"), "image/bmp"},
&exactSig{[]byte("GIF87a"), "image/gif"},
&exactSig{[]byte("GIF89a"), "image/gif"},
&maskedSig{
mask: []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF\xFF\xFF"),
pat: []byte("RIFF\x00\x00\x00\x00WEBPVP"),
ct: "image/webp",
},
&exactSig{[]byte("\x89PNG\x0D\x0A\x1A\x0A"), "image/png"},
&exactSig{[]byte("\xFF\xD8\xFF"), "image/jpeg"},
// Audio and Video types
// Enforce the pattern match ordering as prescribed in
// https://mimesniff.spec.whatwg.org/#matching-an-audio-or-video-type-pattern
&maskedSig{
mask: []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF"),
pat: []byte("FORM\x00\x00\x00\x00AIFF"),
ct: "audio/aiff",
},
&maskedSig{
mask: []byte("\xFF\xFF\xFF"),
pat: []byte("ID3"),
ct: "audio/mpeg",
},
&maskedSig{
mask: []byte("\xFF\xFF\xFF\xFF\xFF"),
pat: []byte("OggS\x00"),
ct: "application/ogg",
},
&maskedSig{
mask: []byte("\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF"),
pat: []byte("MThd\x00\x00\x00\x06"),
ct: "audio/midi",
},
&maskedSig{
mask: []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF"),
pat: []byte("RIFF\x00\x00\x00\x00AVI "),
ct: "video/avi",
},
&maskedSig{
mask: []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF"),
pat: []byte("RIFF\x00\x00\x00\x00WAVE"),
ct: "audio/wave",
},
// 6.2.0.2. video/mp4
mp4Sig{},
// 6.2.0.3. video/webm
&exactSig{[]byte("\x1A\x45\xDF\xA3"), "video/webm"},
// Font types
&maskedSig{
// 34 NULL bytes followed by the string "LP"
pat: []byte("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00LP"),
// 34 NULL bytes followed by \xFF\xFF
mask: []byte("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF"),
ct: "application/vnd.ms-fontobject",
},
&exactSig{[]byte("\x00\x01\x00\x00"), "font/ttf"},
&exactSig{[]byte("OTTO"), "font/otf"},
&exactSig{[]byte("ttcf"), "font/collection"},
&exactSig{[]byte("wOFF"), "font/woff"},
&exactSig{[]byte("wOF2"), "font/woff2"},
// Archive types
&exactSig{[]byte("\x1F\x8B\x08"), "application/x-gzip"},
&exactSig{[]byte("PK\x03\x04"), "application/zip"},
// RAR's signatures are incorrectly defined by the MIME spec as per
// https://github.com/whatwg/mimesniff/issues/63
// However, RAR Labs correctly defines it at:
// https://www.rarlab.com/technote.htm#rarsign
// so we use the definition from RAR Labs.
// TODO: do whatever the spec ends up doing.
&exactSig{[]byte("Rar!\x1A\x07\x00"), "application/x-rar-compressed"}, // RAR v1.5-v4.0
&exactSig{[]byte("Rar!\x1A\x07\x01\x00"), "application/x-rar-compressed"}, // RAR v5+
&exactSig{[]byte("\x00\x61\x73\x6D"), "application/wasm"},
textSig{}, // should be last
}
type exactSig struct {
sig []byte
ct string
}
func (e *exactSig) match(data []byte, firstNonWS int) string {
if bytes.HasPrefix(data, e.sig) {
return e.ct
}
return ""
}
type maskedSig struct {
mask, pat []byte
skipWS bool
ct string
}
func (m *maskedSig) match(data []byte, firstNonWS int) string {
// pattern matching algorithm section 6
// https://mimesniff.spec.whatwg.org/#pattern-matching-algorithm
if m.skipWS {
data = data[firstNonWS:]
}
if len(m.pat) != len(m.mask) {
return ""
}
if len(data) < len(m.pat) {
return ""
}
for i, pb := range m.pat {
maskedData := data[i] & m.mask[i]
if maskedData != pb {
return ""
}
}
return m.ct
}
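// As a worked example from the table above: the WEBP entry's mask zeroes
// bytes 4-7 (the RIFF chunk length), so "RIFF????WEBPVP" matches for any
// length value, while every other byte must equal the pattern exactly
// after AND-ing with \xFF, the identity mask.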
type htmlSig []byte
func (h htmlSig) match(data []byte, firstNonWS int) string {
data = data[firstNonWS:]
if len(data) < len(h)+1 {
return ""
}
for i, b := range h {
db := data[i]
if 'A' <= b && b <= 'Z' {
db &= 0xDF
}
if b != db {
return ""
}
}
// Next byte must be a tag-terminating byte (0xTT).
if !isTT(data[len(h)]) {
return ""
}
return "text/html; charset=utf-8"
}
var mp4ftype = []byte("ftyp")
var mp4 = []byte("mp4")
type mp4Sig struct{}
func (mp4Sig) match(data []byte, firstNonWS int) string {
// https://mimesniff.spec.whatwg.org/#signature-for-mp4
// cf. section 6.2.1
if len(data) < 12 {
return ""
}
boxSize := int(binary.BigEndian.Uint32(data[:4]))
if len(data) < boxSize || boxSize%4 != 0 {
return ""
}
if !bytes.Equal(data[4:8], mp4ftype) {
return ""
}
for st := 8; st < boxSize; st += 4 {
if st == 12 {
// Ignores the four bytes that correspond to the version number of the "major brand".
continue
}
if bytes.Equal(data[st:st+3], mp4) {
return "video/mp4"
}
}
return ""
}
type textSig struct{}
func (textSig) match(data []byte, firstNonWS int) string {
// cf. section 5, step 4.
for _, b := range data[firstNonWS:] {
switch {
case b <= 0x08,
b == 0x0B,
0x0E <= b && b <= 0x1A,
0x1C <= b && b <= 0x1F:
return ""
}
}
return "text/plain; charset=utf-8"
}
// Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT.
//go:generate bundle -o socks_bundle.go -prefix socks golang.org/x/net/internal/socks
// Package socks provides a SOCKS version 5 client implementation.
//
// SOCKS protocol version 5 is defined in RFC 1928.
// Username/Password authentication for SOCKS version 5 is defined in
// RFC 1929.
package http
import (
"context"
"errors"
"io"
"net"
"strconv"
"time"
)
var (
socksnoDeadline = time.Time{}
socksaLongTimeAgo = time.Unix(1, 0)
)
func (d *socksDialer) connect(ctx context.Context, c net.Conn, address string) (_ net.Addr, ctxErr error) {
host, port, err := sockssplitHostPort(address)
if err != nil {
return nil, err
}
if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() {
c.SetDeadline(deadline)
defer c.SetDeadline(socksnoDeadline)
}
if ctx != context.Background() {
errCh := make(chan error, 1)
done := make(chan struct{})
defer func() {
close(done)
if ctxErr == nil {
ctxErr = <-errCh
}
}()
go func() {
select {
case <-ctx.Done():
c.SetDeadline(socksaLongTimeAgo)
errCh <- ctx.Err()
case <-done:
errCh <- nil
}
}()
}
b := make([]byte, 0, 6+len(host)) // the size here is just an estimate
b = append(b, socksVersion5)
if len(d.AuthMethods) == 0 || d.Authenticate == nil {
b = append(b, 1, byte(socksAuthMethodNotRequired))
} else {
ams := d.AuthMethods
if len(ams) > 255 {
return nil, errors.New("too many authentication methods")
}
b = append(b, byte(len(ams)))
for _, am := range ams {
b = append(b, byte(am))
}
}
if _, ctxErr = c.Write(b); ctxErr != nil {
return
}
if _, ctxErr = io.ReadFull(c, b[:2]); ctxErr != nil {
return
}
if b[0] != socksVersion5 {
return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
}
am := socksAuthMethod(b[1])
if am == socksAuthMethodNoAcceptableMethods {
return nil, errors.New("no acceptable authentication methods")
}
if d.Authenticate != nil {
if ctxErr = d.Authenticate(ctx, c, am); ctxErr != nil {
return
}
}
b = b[:0]
b = append(b, socksVersion5, byte(d.cmd), 0)
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
b = append(b, socksAddrTypeIPv4)
b = append(b, ip4...)
} else if ip6 := ip.To16(); ip6 != nil {
b = append(b, socksAddrTypeIPv6)
b = append(b, ip6...)
} else {
return nil, errors.New("unknown address type")
}
} else {
if len(host) > 255 {
return nil, errors.New("FQDN too long")
}
b = append(b, socksAddrTypeFQDN)
b = append(b, byte(len(host)))
b = append(b, host...)
}
b = append(b, byte(port>>8), byte(port))
if _, ctxErr = c.Write(b); ctxErr != nil {
return
}
if _, ctxErr = io.ReadFull(c, b[:4]); ctxErr != nil {
return
}
if b[0] != socksVersion5 {
return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
}
if cmdErr := socksReply(b[1]); cmdErr != socksStatusSucceeded {
return nil, errors.New("unknown error " + cmdErr.String())
}
if b[2] != 0 {
return nil, errors.New("non-zero reserved field")
}
l := 2
var a socksAddr
switch b[3] {
case socksAddrTypeIPv4:
l += net.IPv4len
a.IP = make(net.IP, net.IPv4len)
case socksAddrTypeIPv6:
l += net.IPv6len
a.IP = make(net.IP, net.IPv6len)
case socksAddrTypeFQDN:
if _, err := io.ReadFull(c, b[:1]); err != nil {
return nil, err
}
l += int(b[0])
default:
return nil, errors.New("unknown address type " + strconv.Itoa(int(b[3])))
}
if cap(b) < l {
b = make([]byte, l)
} else {
b = b[:l]
}
if _, ctxErr = io.ReadFull(c, b); ctxErr != nil {
return
}
if a.IP != nil {
copy(a.IP, b)
} else {
a.Name = string(b[:len(b)-2])
}
a.Port = int(b[len(b)-2])<<8 | int(b[len(b)-1])
return &a, nil
}
func sockssplitHostPort(address string) (string, int, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return "", 0, err
}
portnum, err := strconv.Atoi(port)
if err != nil {
return "", 0, err
}
if 1 > portnum || portnum > 0xffff {
return "", 0, errors.New("port number out of range " + port)
}
return host, portnum, nil
}
// A Command represents a SOCKS command.
type socksCommand int
func (cmd socksCommand) String() string {
switch cmd {
case socksCmdConnect:
return "socks connect"
case sockscmdBind:
return "socks bind"
default:
return "socks " + strconv.Itoa(int(cmd))
}
}
// An AuthMethod represents a SOCKS authentication method.
type socksAuthMethod int
// A Reply represents a SOCKS command reply code.
type socksReply int
func (code socksReply) String() string {
switch code {
case socksStatusSucceeded:
return "succeeded"
case 0x01:
return "general SOCKS server failure"
case 0x02:
return "connection not allowed by ruleset"
case 0x03:
return "network unreachable"
case 0x04:
return "host unreachable"
case 0x05:
return "connection refused"
case 0x06:
return "TTL expired"
case 0x07:
return "command not supported"
case 0x08:
return "address type not supported"
default:
return "unknown code: " + strconv.Itoa(int(code))
}
}
// Wire protocol constants.
const (
socksVersion5 = 0x05
socksAddrTypeIPv4 = 0x01
socksAddrTypeFQDN = 0x03
socksAddrTypeIPv6 = 0x04
socksCmdConnect socksCommand = 0x01 // establishes an active-open forward proxy connection
sockscmdBind socksCommand = 0x02 // establishes a passive-open forward proxy connection
socksAuthMethodNotRequired socksAuthMethod = 0x00 // no authentication required
socksAuthMethodUsernamePassword socksAuthMethod = 0x02 // use username/password
socksAuthMethodNoAcceptableMethods socksAuthMethod = 0xff // no acceptable authentication methods
socksStatusSucceeded socksReply = 0x00
)
// An Addr represents a SOCKS-specific address.
// Either Name or IP is used exclusively.
type socksAddr struct {
Name string // fully-qualified domain name
IP net.IP
Port int
}
func (a *socksAddr) Network() string { return "socks" }
func (a *socksAddr) String() string {
if a == nil {
return "<nil>"
}
port := strconv.Itoa(a.Port)
if a.IP == nil {
return net.JoinHostPort(a.Name, port)
}
return net.JoinHostPort(a.IP.String(), port)
}
// A Conn represents a forward proxy connection.
type socksConn struct {
net.Conn
boundAddr net.Addr
}
// BoundAddr returns the address assigned by the proxy server for
// connecting to the command target address from the proxy server.
func (c *socksConn) BoundAddr() net.Addr {
if c == nil {
return nil
}
return c.boundAddr
}
// A Dialer holds SOCKS-specific options.
type socksDialer struct {
cmd socksCommand // either CmdConnect or cmdBind
proxyNetwork string // network between a proxy server and a client
proxyAddress string // proxy server address
// ProxyDial specifies the optional dial function for
// establishing the transport connection.
ProxyDial func(context.Context, string, string) (net.Conn, error)
// AuthMethods specifies the list of request authentication
// methods.
// If empty, SOCKS client requests only AuthMethodNotRequired.
AuthMethods []socksAuthMethod
// Authenticate specifies the optional authentication
// function. It must be non-nil when AuthMethods is not empty.
// It must return an error when authentication fails.
Authenticate func(context.Context, io.ReadWriter, socksAuthMethod) error
}
// DialContext connects to the provided address on the provided
// network.
//
// The returned error value may be a net.OpError. When the Op field of
// net.OpError contains "socks", the Source field contains a proxy
// server address and the Addr field contains a command target
// address.
//
// See func Dial of the net package of standard library for a
// description of the network and address parameters.
func (d *socksDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
if err := d.validateTarget(network, address); err != nil {
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
}
if ctx == nil {
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")}
}
var err error
var c net.Conn
if d.ProxyDial != nil {
c, err = d.ProxyDial(ctx, d.proxyNetwork, d.proxyAddress)
} else {
var dd net.Dialer
c, err = dd.DialContext(ctx, d.proxyNetwork, d.proxyAddress)
}
if err != nil {
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
}
a, err := d.connect(ctx, c, address)
if err != nil {
c.Close()
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
}
return &socksConn{Conn: c, boundAddr: a}, nil
}
// DialWithConn initiates a connection from SOCKS server to the target
// network and address using the connection c that is already
// connected to the SOCKS server.
//
// It returns the connection's local address assigned by the SOCKS
// server.
func (d *socksDialer) DialWithConn(ctx context.Context, c net.Conn, network, address string) (net.Addr, error) {
if err := d.validateTarget(network, address); err != nil {
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
}
if ctx == nil {
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")}
}
a, err := d.connect(ctx, c, address)
if err != nil {
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
}
return a, nil
}
// Dial connects to the provided address on the provided network.
//
// Unlike DialContext, it returns a raw transport connection instead
// of a forward proxy connection.
//
// Deprecated: Use DialContext or DialWithConn instead.
func (d *socksDialer) Dial(network, address string) (net.Conn, error) {
if err := d.validateTarget(network, address); err != nil {
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
}
var err error
var c net.Conn
if d.ProxyDial != nil {
c, err = d.ProxyDial(context.Background(), d.proxyNetwork, d.proxyAddress)
} else {
c, err = net.Dial(d.proxyNetwork, d.proxyAddress)
}
if err != nil {
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
}
if _, err := d.DialWithConn(context.Background(), c, network, address); err != nil {
c.Close()
return nil, err
}
return c, nil
}
func (d *socksDialer) validateTarget(network, address string) error {
switch network {
case "tcp", "tcp6", "tcp4":
default:
return errors.New("network not implemented")
}
switch d.cmd {
case socksCmdConnect, sockscmdBind:
default:
return errors.New("command not implemented")
}
return nil
}
func (d *socksDialer) pathAddrs(address string) (proxy, dst net.Addr, err error) {
for i, s := range []string{d.proxyAddress, address} {
host, port, err := sockssplitHostPort(s)
if err != nil {
return nil, nil, err
}
a := &socksAddr{Port: port}
a.IP = net.ParseIP(host)
if a.IP == nil {
a.Name = host
}
if i == 0 {
proxy = a
} else {
dst = a
}
}
return
}
// NewDialer returns a new Dialer that dials through the provided
// proxy server's network and address.
func socksNewDialer(network, address string) *socksDialer {
return &socksDialer{proxyNetwork: network, proxyAddress: address, cmd: socksCmdConnect}
}
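// A dialing sketch using the bundled client ("proxy.example:1080" is a
// placeholder proxy address):
//
//	d := socksNewDialer("tcp", "proxy.example:1080")
//	c, err := d.DialContext(context.Background(), "tcp", "example.com:80")
//	if err != nil {
//		// handle error
//	}
//	defer c.Close()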
const (
socksauthUsernamePasswordVersion = 0x01
socksauthStatusSucceeded = 0x00
)
// UsernamePassword are the credentials for the username/password
// authentication method.
type socksUsernamePassword struct {
Username string
Password string
}
// Authenticate authenticates a pair of username and password with the
// proxy server.
func (up *socksUsernamePassword) Authenticate(ctx context.Context, rw io.ReadWriter, auth socksAuthMethod) error {
switch auth {
case socksAuthMethodNotRequired:
return nil
case socksAuthMethodUsernamePassword:
if len(up.Username) == 0 || len(up.Username) > 255 || len(up.Password) == 0 || len(up.Password) > 255 {
return errors.New("invalid username/password")
}
b := []byte{socksauthUsernamePasswordVersion}
b = append(b, byte(len(up.Username)))
b = append(b, up.Username...)
b = append(b, byte(len(up.Password)))
b = append(b, up.Password...)
// TODO(mikio): handle IO deadlines and cancelation if
// necessary
if _, err := rw.Write(b); err != nil {
return err
}
if _, err := io.ReadFull(rw, b[:2]); err != nil {
return err
}
if b[0] != socksauthUsernamePasswordVersion {
return errors.New("invalid username/password version")
}
if b[1] != socksauthStatusSucceeded {
return errors.New("username/password authentication failed")
}
return nil
}
return errors.New("unsupported authentication method " + strconv.Itoa(int(auth)))
}
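// A wiring sketch for username/password authentication (the credentials and
// proxy address are placeholders):
//
//	up := &socksUsernamePassword{Username: "user", Password: "secret"}
//	d := socksNewDialer("tcp", "proxy.example:1080")
//	d.AuthMethods = []socksAuthMethod{
//		socksAuthMethodNotRequired,
//		socksAuthMethodUsernamePassword,
//	}
//	d.Authenticate = up.Authenticate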
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http
// HTTP status codes as registered with IANA.
// See: https://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml
const (
StatusContinue = 100 // RFC 9110, 15.2.1
StatusSwitchingProtocols = 101 // RFC 9110, 15.2.2
StatusProcessing = 102 // RFC 2518, 10.1
StatusEarlyHints = 103 // RFC 8297
StatusOK = 200 // RFC 9110, 15.3.1
StatusCreated = 201 // RFC 9110, 15.3.2
StatusAccepted = 202 // RFC 9110, 15.3.3
StatusNonAuthoritativeInfo = 203 // RFC 9110, 15.3.4
StatusNoContent = 204 // RFC 9110, 15.3.5
StatusResetContent = 205 // RFC 9110, 15.3.6
StatusPartialContent = 206 // RFC 9110, 15.3.7
StatusMultiStatus = 207 // RFC 4918, 11.1
StatusAlreadyReported = 208 // RFC 5842, 7.1
StatusIMUsed = 226 // RFC 3229, 10.4.1
StatusMultipleChoices = 300 // RFC 9110, 15.4.1
StatusMovedPermanently = 301 // RFC 9110, 15.4.2
StatusFound = 302 // RFC 9110, 15.4.3
StatusSeeOther = 303 // RFC 9110, 15.4.4
StatusNotModified = 304 // RFC 9110, 15.4.5
StatusUseProxy = 305 // RFC 9110, 15.4.6
_ = 306 // RFC 9110, 15.4.7 (Unused)
StatusTemporaryRedirect = 307 // RFC 9110, 15.4.8
StatusPermanentRedirect = 308 // RFC 9110, 15.4.9
StatusBadRequest = 400 // RFC 9110, 15.5.1
StatusUnauthorized = 401 // RFC 9110, 15.5.2
StatusPaymentRequired = 402 // RFC 9110, 15.5.3
StatusForbidden = 403 // RFC 9110, 15.5.4
StatusNotFound = 404 // RFC 9110, 15.5.5
StatusMethodNotAllowed = 405 // RFC 9110, 15.5.6
StatusNotAcceptable = 406 // RFC 9110, 15.5.7
StatusProxyAuthRequired = 407 // RFC 9110, 15.5.8
StatusRequestTimeout = 408 // RFC 9110, 15.5.9
StatusConflict = 409 // RFC 9110, 15.5.10
StatusGone = 410 // RFC 9110, 15.5.11
StatusLengthRequired = 411 // RFC 9110, 15.5.12
StatusPreconditionFailed = 412 // RFC 9110, 15.5.13
StatusRequestEntityTooLarge = 413 // RFC 9110, 15.5.14
StatusRequestURITooLong = 414 // RFC 9110, 15.5.15
StatusUnsupportedMediaType = 415 // RFC 9110, 15.5.16
StatusRequestedRangeNotSatisfiable = 416 // RFC 9110, 15.5.17
StatusExpectationFailed = 417 // RFC 9110, 15.5.18
StatusTeapot = 418 // RFC 9110, 15.5.19 (Unused)
StatusMisdirectedRequest = 421 // RFC 9110, 15.5.20
StatusUnprocessableEntity = 422 // RFC 9110, 15.5.21
StatusLocked = 423 // RFC 4918, 11.3
StatusFailedDependency = 424 // RFC 4918, 11.4
StatusTooEarly = 425 // RFC 8470, 5.2
StatusUpgradeRequired = 426 // RFC 9110, 15.5.22
StatusPreconditionRequired = 428 // RFC 6585, 3
StatusTooManyRequests = 429 // RFC 6585, 4
StatusRequestHeaderFieldsTooLarge = 431 // RFC 6585, 5
StatusUnavailableForLegalReasons = 451 // RFC 7725, 3
StatusInternalServerError = 500 // RFC 9110, 15.6.1
StatusNotImplemented = 501 // RFC 9110, 15.6.2
StatusBadGateway = 502 // RFC 9110, 15.6.3
StatusServiceUnavailable = 503 // RFC 9110, 15.6.4
StatusGatewayTimeout = 504 // RFC 9110, 15.6.5
StatusHTTPVersionNotSupported = 505 // RFC 9110, 15.6.6
StatusVariantAlsoNegotiates = 506 // RFC 2295, 8.1
StatusInsufficientStorage = 507 // RFC 4918, 11.5
StatusLoopDetected = 508 // RFC 5842, 7.2
StatusNotExtended = 510 // RFC 2774, 7
StatusNetworkAuthenticationRequired = 511 // RFC 6585, 6
)
// StatusText returns a text for the HTTP status code. It returns the empty
// string if the code is unknown.
func StatusText(code int) string {
switch code {
case StatusContinue:
return "Continue"
case StatusSwitchingProtocols:
return "Switching Protocols"
case StatusProcessing:
return "Processing"
case StatusEarlyHints:
return "Early Hints"
case StatusOK:
return "OK"
case StatusCreated:
return "Created"
case StatusAccepted:
return "Accepted"
case StatusNonAuthoritativeInfo:
return "Non-Authoritative Information"
case StatusNoContent:
return "No Content"
case StatusResetContent:
return "Reset Content"
case StatusPartialContent:
return "Partial Content"
case StatusMultiStatus:
return "Multi-Status"
case StatusAlreadyReported:
return "Already Reported"
case StatusIMUsed:
return "IM Used"
case StatusMultipleChoices:
return "Multiple Choices"
case StatusMovedPermanently:
return "Moved Permanently"
case StatusFound:
return "Found"
case StatusSeeOther:
return "See Other"
case StatusNotModified:
return "Not Modified"
case StatusUseProxy:
return "Use Proxy"
case StatusTemporaryRedirect:
return "Temporary Redirect"
case StatusPermanentRedirect:
return "Permanent Redirect"
case StatusBadRequest:
return "Bad Request"
case StatusUnauthorized:
return "Unauthorized"
case StatusPaymentRequired:
return "Payment Required"
case StatusForbidden:
return "Forbidden"
case StatusNotFound:
return "Not Found"
case StatusMethodNotAllowed:
return "Method Not Allowed"
case StatusNotAcceptable:
return "Not Acceptable"
case StatusProxyAuthRequired:
return "Proxy Authentication Required"
case StatusRequestTimeout:
return "Request Timeout"
case StatusConflict:
return "Conflict"
case StatusGone:
return "Gone"
case StatusLengthRequired:
return "Length Required"
case StatusPreconditionFailed:
return "Precondition Failed"
case StatusRequestEntityTooLarge:
return "Request Entity Too Large"
case StatusRequestURITooLong:
return "Request URI Too Long"
case StatusUnsupportedMediaType:
return "Unsupported Media Type"
case StatusRequestedRangeNotSatisfiable:
return "Requested Range Not Satisfiable"
case StatusExpectationFailed:
return "Expectation Failed"
case StatusTeapot:
return "I'm a teapot"
case StatusMisdirectedRequest:
return "Misdirected Request"
case StatusUnprocessableEntity:
return "Unprocessable Entity"
case StatusLocked:
return "Locked"
case StatusFailedDependency:
return "Failed Dependency"
case StatusTooEarly:
return "Too Early"
case StatusUpgradeRequired:
return "Upgrade Required"
case StatusPreconditionRequired:
return "Precondition Required"
case StatusTooManyRequests:
return "Too Many Requests"
case StatusRequestHeaderFieldsTooLarge:
return "Request Header Fields Too Large"
case StatusUnavailableForLegalReasons:
return "Unavailable For Legal Reasons"
case StatusInternalServerError:
return "Internal Server Error"
case StatusNotImplemented:
return "Not Implemented"
case StatusBadGateway:
return "Bad Gateway"
case StatusServiceUnavailable:
return "Service Unavailable"
case StatusGatewayTimeout:
return "Gateway Timeout"
case StatusHTTPVersionNotSupported:
return "HTTP Version Not Supported"
case StatusVariantAlsoNegotiates:
return "Variant Also Negotiates"
case StatusInsufficientStorage:
return "Insufficient Storage"
case StatusLoopDetected:
return "Loop Detected"
case StatusNotExtended:
return "Not Extended"
case StatusNetworkAuthenticationRequired:
return "Network Authentication Required"
default:
return ""
}
}
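// An illustrative, user-side sketch of pairing the status constants with
// StatusText when composing a status line by hand (the statusLine helper is
// hypothetical, not part of this package):
//
//	func statusLine(code int) string {
//		text := http.StatusText(code)
//		if text == "" {
//			text = "status code " + strconv.Itoa(code)
//		}
//		return fmt.Sprintf("HTTP/1.1 %d %s", code, text)
//	}
//
//	// statusLine(http.StatusTeapot) == "HTTP/1.1 418 I'm a teapot"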
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"net/http/httptrace"
"net/http/internal"
"net/http/internal/ascii"
"net/textproto"
"reflect"
"sort"
"strconv"
"strings"
"sync"
"time"
"golang.org/x/net/http/httpguts"
)
// ErrLineTooLong is returned when reading request or response bodies
// with malformed chunked encoding.
var ErrLineTooLong = internal.ErrLineTooLong
type errorReader struct {
err error
}
func (r errorReader) Read(p []byte) (n int, err error) {
return 0, r.err
}
type byteReader struct {
b byte
done bool
}
func (br *byteReader) Read(p []byte) (n int, err error) {
if br.done {
return 0, io.EOF
}
if len(p) == 0 {
return 0, nil
}
br.done = true
p[0] = br.b
return 1, io.EOF
}
// transferWriter inspects the fields of a user-supplied Request or Response,
// sanitizes them without changing the user object and provides methods for
// writing the respective header, body and trailer in wire format.
type transferWriter struct {
Method string
Body io.Reader
BodyCloser io.Closer
ResponseToHEAD bool
ContentLength int64 // -1 means unknown, 0 means exactly none
Close bool
TransferEncoding []string
Header Header
Trailer Header
IsResponse bool
bodyReadError error // any non-EOF error from reading Body
FlushHeaders bool // flush headers to network before body
ByteReadCh chan readResult // non-nil if probeRequestBody called
}
func newTransferWriter(r any) (t *transferWriter, err error) {
t = &transferWriter{}
// Extract relevant fields
atLeastHTTP11 := false
switch rr := r.(type) {
case *Request:
if rr.ContentLength != 0 && rr.Body == nil {
return nil, fmt.Errorf("http: Request.ContentLength=%d with nil Body", rr.ContentLength)
}
t.Method = valueOrDefault(rr.Method, "GET")
t.Close = rr.Close
t.TransferEncoding = rr.TransferEncoding
t.Header = rr.Header
t.Trailer = rr.Trailer
t.Body = rr.Body
t.BodyCloser = rr.Body
t.ContentLength = rr.outgoingLength()
if t.ContentLength < 0 && len(t.TransferEncoding) == 0 && t.shouldSendChunkedRequestBody() {
t.TransferEncoding = []string{"chunked"}
}
// If there's a body, conservatively flush the headers
// to any bufio.Writer we're writing to, just in case
// the server needs the headers early, before we copy
// the body and possibly block. We make an exception
// for the common standard library in-memory types,
// though, to avoid unnecessary TCP packets on the
// wire. (Issue 22088.)
if t.ContentLength != 0 && !isKnownInMemoryReader(t.Body) {
t.FlushHeaders = true
}
atLeastHTTP11 = true // Transport requests are always 1.1 or 2.0
case *Response:
t.IsResponse = true
if rr.Request != nil {
t.Method = rr.Request.Method
}
t.Body = rr.Body
t.BodyCloser = rr.Body
t.ContentLength = rr.ContentLength
t.Close = rr.Close
t.TransferEncoding = rr.TransferEncoding
t.Header = rr.Header
t.Trailer = rr.Trailer
atLeastHTTP11 = rr.ProtoAtLeast(1, 1)
t.ResponseToHEAD = noResponseBodyExpected(t.Method)
}
// Sanitize Body, ContentLength, and TransferEncoding
if t.ResponseToHEAD {
t.Body = nil
if chunked(t.TransferEncoding) {
t.ContentLength = -1
}
} else {
if !atLeastHTTP11 || t.Body == nil {
t.TransferEncoding = nil
}
if chunked(t.TransferEncoding) {
t.ContentLength = -1
} else if t.Body == nil { // no chunking, no body
t.ContentLength = 0
}
}
// Sanitize Trailer
if !chunked(t.TransferEncoding) {
t.Trailer = nil
}
return t, nil
}
// shouldSendChunkedRequestBody reports whether we should try to send a
// chunked request body to the server. In particular, the case we really
// want to prevent is sending a GET or other typically-bodyless request to a
// server with a chunked body when the body has zero bytes, since GETs with
// bodies (while acceptable according to specs), even zero-byte chunked
// bodies, are approximately never seen in the wild and confuse most
// servers. See Issue 18257, as one example.
//
// The only reason we'd send such a request is if the user set the Body to a
// non-nil value (say, io.NopCloser(bytes.NewReader(nil))) and didn't
// set ContentLength, or NewRequest set it to -1 (unknown), so then we assume
// there's bytes to send.
//
// This code tries to read a byte from the Request.Body in such cases to see
// whether the body actually has content (super rare) or is actually just
// a non-nil content-less ReadCloser (the more common case). In that more
// common case, we act as if their Body were nil instead, and don't send
// a body.
func (t *transferWriter) shouldSendChunkedRequestBody() bool {
// Note that t.ContentLength is the corrected content length
// from rr.outgoingLength, so 0 actually means zero, not unknown.
if t.ContentLength >= 0 || t.Body == nil { // redundant checks; caller did them
return false
}
if t.Method == "CONNECT" {
return false
}
if requestMethodUsuallyLacksBody(t.Method) {
// Only probe the Request.Body for GET/HEAD/DELETE/etc
// requests, because it's only those types of requests
// that confuse servers.
t.probeRequestBody() // adjusts t.Body, t.ContentLength
return t.Body != nil
}
// For all other request types (PUT, POST, PATCH, or anything
// made-up we've never heard of), assume it's normal and the server
// can deal with a chunked request body. Maybe we'll adjust this
// later.
return true
}
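// The situation the probe guards against looks, from the client side, like
// the sketch below (illustrative only; example.com is a placeholder):
//
//	req, _ := http.NewRequest("GET", "https://example.com/", nil)
//	req.Body = io.NopCloser(bytes.NewReader(nil)) // non-nil, but zero bytes
//	// req.ContentLength is 0, so outgoingLength reports -1 (unknown),
//	// and without probing the body this GET would be sent chunked.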
// probeRequestBody reads a byte from t.Body to see whether it's empty
// (returns io.EOF right away).
//
// But because we've had problems with this blocking users in the past
// (issue 17480) when the body is a pipe (perhaps waiting on the response
// headers before the pipe is fed data), we need to be careful and bound how
// long we wait for it. This delay will only affect users if all the following
// are true:
// - the request body blocks
// - the content length is not set (or set to -1)
// - the method doesn't usually have a body (GET, HEAD, DELETE, ...)
// - there is no transfer-encoding=chunked already set.
//
// In other words, this delay will not normally affect anybody, and there
// are workarounds if it does.
func (t *transferWriter) probeRequestBody() {
t.ByteReadCh = make(chan readResult, 1)
go func(body io.Reader) {
var buf [1]byte
var rres readResult
rres.n, rres.err = body.Read(buf[:])
if rres.n == 1 {
rres.b = buf[0]
}
t.ByteReadCh <- rres
close(t.ByteReadCh)
}(t.Body)
timer := time.NewTimer(200 * time.Millisecond)
select {
case rres := <-t.ByteReadCh:
timer.Stop()
if rres.n == 0 && rres.err == io.EOF {
// It was empty.
t.Body = nil
t.ContentLength = 0
} else if rres.n == 1 {
if rres.err != nil {
t.Body = io.MultiReader(&byteReader{b: rres.b}, errorReader{rres.err})
} else {
t.Body = io.MultiReader(&byteReader{b: rres.b}, t.Body)
}
} else if rres.err != nil {
t.Body = errorReader{rres.err}
}
case <-timer.C:
// Too slow. Don't wait. Read it later, and keep
// assuming that this is ContentLength == -1
// (unknown), which means we'll send a
// "Transfer-Encoding: chunked" header.
t.Body = io.MultiReader(finishAsyncByteRead{t}, t.Body)
// Request that Request.Write flush the headers to the
// network before writing the body, since our body may not
// become readable until it's seen the response headers.
t.FlushHeaders = true
}
}
func noResponseBodyExpected(requestMethod string) bool {
return requestMethod == "HEAD"
}
func (t *transferWriter) shouldSendContentLength() bool {
if chunked(t.TransferEncoding) {
return false
}
if t.ContentLength > 0 {
return true
}
if t.ContentLength < 0 {
return false
}
// Many servers expect a Content-Length for these methods
if t.Method == "POST" || t.Method == "PUT" || t.Method == "PATCH" {
return true
}
if t.ContentLength == 0 && isIdentity(t.TransferEncoding) {
if t.Method == "GET" || t.Method == "HEAD" {
return false
}
return true
}
return false
}
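// A few illustrative outcomes of the rules above:
//
//	chunked Transfer-Encoding        -> false (chunked wins)
//	ContentLength > 0                -> true
//	ContentLength < 0 (unknown)      -> false
//	POST/PUT/PATCH with length 0     -> true (many servers expect it)
//	GET/HEAD, length 0, "identity"   -> false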
func (t *transferWriter) writeHeader(w io.Writer, trace *httptrace.ClientTrace) error {
if t.Close && !hasToken(t.Header.get("Connection"), "close") {
if _, err := io.WriteString(w, "Connection: close\r\n"); err != nil {
return err
}
if trace != nil && trace.WroteHeaderField != nil {
trace.WroteHeaderField("Connection", []string{"close"})
}
}
// Write Content-Length and/or Transfer-Encoding whose values are a
// function of the sanitized field triple (Body, ContentLength,
// TransferEncoding)
if t.shouldSendContentLength() {
if _, err := io.WriteString(w, "Content-Length: "); err != nil {
return err
}
if _, err := io.WriteString(w, strconv.FormatInt(t.ContentLength, 10)+"\r\n"); err != nil {
return err
}
if trace != nil && trace.WroteHeaderField != nil {
trace.WroteHeaderField("Content-Length", []string{strconv.FormatInt(t.ContentLength, 10)})
}
} else if chunked(t.TransferEncoding) {
if _, err := io.WriteString(w, "Transfer-Encoding: chunked\r\n"); err != nil {
return err
}
if trace != nil && trace.WroteHeaderField != nil {
trace.WroteHeaderField("Transfer-Encoding", []string{"chunked"})
}
}
// Write Trailer header
if t.Trailer != nil {
keys := make([]string, 0, len(t.Trailer))
for k := range t.Trailer {
k = CanonicalHeaderKey(k)
switch k {
case "Transfer-Encoding", "Trailer", "Content-Length":
return badStringError("invalid Trailer key", k)
}
keys = append(keys, k)
}
if len(keys) > 0 {
sort.Strings(keys)
// TODO: could do better allocation-wise here, but trailers are rare,
// so being lazy for now.
if _, err := io.WriteString(w, "Trailer: "+strings.Join(keys, ",")+"\r\n"); err != nil {
return err
}
if trace != nil && trace.WroteHeaderField != nil {
trace.WroteHeaderField("Trailer", keys)
}
}
}
return nil
}
// always closes t.BodyCloser
func (t *transferWriter) writeBody(w io.Writer) (err error) {
var ncopy int64
closed := false
defer func() {
if closed || t.BodyCloser == nil {
return
}
if closeErr := t.BodyCloser.Close(); closeErr != nil && err == nil {
err = closeErr
}
}()
// Write body. We "unwrap" the body first if it was wrapped in a
// nopCloser or readTrackingBody. This is to ensure that we can take advantage of
// OS-level optimizations in the event that the body is an
// *os.File.
if t.Body != nil {
var body = t.unwrapBody()
if chunked(t.TransferEncoding) {
if bw, ok := w.(*bufio.Writer); ok && !t.IsResponse {
w = &internal.FlushAfterChunkWriter{Writer: bw}
}
cw := internal.NewChunkedWriter(w)
_, err = t.doBodyCopy(cw, body)
if err == nil {
err = cw.Close()
}
} else if t.ContentLength == -1 {
dst := w
if t.Method == "CONNECT" {
dst = bufioFlushWriter{dst}
}
ncopy, err = t.doBodyCopy(dst, body)
} else {
ncopy, err = t.doBodyCopy(w, io.LimitReader(body, t.ContentLength))
if err != nil {
return err
}
var nextra int64
nextra, err = t.doBodyCopy(io.Discard, body)
ncopy += nextra
}
if err != nil {
return err
}
}
if t.BodyCloser != nil {
closed = true
if err := t.BodyCloser.Close(); err != nil {
return err
}
}
if !t.ResponseToHEAD && t.ContentLength != -1 && t.ContentLength != ncopy {
return fmt.Errorf("http: ContentLength=%d with Body length %d",
t.ContentLength, ncopy)
}
if chunked(t.TransferEncoding) {
// Write Trailer header
if t.Trailer != nil {
if err := t.Trailer.Write(w); err != nil {
return err
}
}
// Last chunk, empty trailer
_, err = io.WriteString(w, "\r\n")
}
return err
}
// doBodyCopy wraps a copy operation, with any resulting error also
// being saved in bodyReadError.
//
// This function is only intended for use in writeBody.
func (t *transferWriter) doBodyCopy(dst io.Writer, src io.Reader) (n int64, err error) {
n, err = io.Copy(dst, src)
if err != nil && err != io.EOF {
t.bodyReadError = err
}
return
}
// unwrapBody unwraps the body's inner reader if it's a
// nopCloser or readTrackingBody. This is to ensure that body writes sourced from local
// files (*os.File types) are properly optimized.
//
// This function is only intended for use in writeBody.
func (t *transferWriter) unwrapBody() io.Reader {
if r, ok := unwrapNopCloser(t.Body); ok {
return r
}
if r, ok := t.Body.(*readTrackingBody); ok {
r.didRead = true
return r.ReadCloser
}
return t.Body
}
type transferReader struct {
// Input
Header Header
StatusCode int
RequestMethod string
ProtoMajor int
ProtoMinor int
// Output
Body io.ReadCloser
ContentLength int64
Chunked bool
Close bool
Trailer Header
}
func (t *transferReader) protoAtLeast(m, n int) bool {
return t.ProtoMajor > m || (t.ProtoMajor == m && t.ProtoMinor >= n)
}
// bodyAllowedForStatus reports whether a given response status code
// permits a body. See RFC 7230, section 3.3.
func bodyAllowedForStatus(status int) bool {
switch {
case status >= 100 && status <= 199:
return false
case status == 204:
return false
case status == 304:
return false
}
return true
}
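// For example:
//
//	bodyAllowedForStatus(StatusContinue)    == false // 1xx
//	bodyAllowedForStatus(StatusNoContent)   == false // 204
//	bodyAllowedForStatus(StatusNotModified) == false // 304
//	bodyAllowedForStatus(StatusOK)          == true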
var (
suppressedHeaders304 = []string{"Content-Type", "Content-Length", "Transfer-Encoding"}
suppressedHeadersNoBody = []string{"Content-Length", "Transfer-Encoding"}
excludedHeadersNoBody = map[string]bool{"Content-Length": true, "Transfer-Encoding": true}
)
func suppressedHeaders(status int) []string {
switch {
case status == 304:
// RFC 7232 section 4.1
return suppressedHeaders304
case !bodyAllowedForStatus(status):
return suppressedHeadersNoBody
}
return nil
}
// msg is *Request or *Response.
func readTransfer(msg any, r *bufio.Reader) (err error) {
t := &transferReader{RequestMethod: "GET"}
// Unify input
isResponse := false
switch rr := msg.(type) {
case *Response:
t.Header = rr.Header
t.StatusCode = rr.StatusCode
t.ProtoMajor = rr.ProtoMajor
t.ProtoMinor = rr.ProtoMinor
t.Close = shouldClose(t.ProtoMajor, t.ProtoMinor, t.Header, true)
isResponse = true
if rr.Request != nil {
t.RequestMethod = rr.Request.Method
}
case *Request:
t.Header = rr.Header
t.RequestMethod = rr.Method
t.ProtoMajor = rr.ProtoMajor
t.ProtoMinor = rr.ProtoMinor
// Transfer semantics for Requests are exactly like those for
// Responses with status code 200, responding to a GET method
t.StatusCode = 200
t.Close = rr.Close
default:
panic("unexpected type")
}
// Default to HTTP/1.1
if t.ProtoMajor == 0 && t.ProtoMinor == 0 {
t.ProtoMajor, t.ProtoMinor = 1, 1
}
// Handle Transfer-Encoding: chunked, which overrides Content-Length.
if err := t.parseTransferEncoding(); err != nil {
return err
}
realLength, err := fixLength(isResponse, t.StatusCode, t.RequestMethod, t.Header, t.Chunked)
if err != nil {
return err
}
if isResponse && t.RequestMethod == "HEAD" {
if n, err := parseContentLength(t.Header.get("Content-Length")); err != nil {
return err
} else {
t.ContentLength = n
}
} else {
t.ContentLength = realLength
}
// Trailer
t.Trailer, err = fixTrailer(t.Header, t.Chunked)
if err != nil {
return err
}
// If there is no Content-Length or chunked Transfer-Encoding on a *Response
// and the status is not 1xx, 204 or 304, then the body is unbounded.
// See RFC 7230, section 3.3.
switch msg.(type) {
case *Response:
if realLength == -1 && !t.Chunked && bodyAllowedForStatus(t.StatusCode) {
// Unbounded body.
t.Close = true
}
}
// Prepare body reader. ContentLength < 0 means chunked encoding
// or close connection when finished, since multipart is not supported yet
switch {
case t.Chunked:
if isResponse && (noResponseBodyExpected(t.RequestMethod) || !bodyAllowedForStatus(t.StatusCode)) {
t.Body = NoBody
} else {
t.Body = &body{src: internal.NewChunkedReader(r), hdr: msg, r: r, closing: t.Close}
}
case realLength == 0:
t.Body = NoBody
case realLength > 0:
t.Body = &body{src: io.LimitReader(r, realLength), closing: t.Close}
default:
// realLength < 0, i.e. "Content-Length" not mentioned in header
if t.Close {
// Close semantics (i.e. HTTP/1.0)
t.Body = &body{src: r, closing: t.Close}
} else {
// Persistent connection (i.e. HTTP/1.1)
t.Body = NoBody
}
}
// Unify output
switch rr := msg.(type) {
case *Request:
rr.Body = t.Body
rr.ContentLength = t.ContentLength
if t.Chunked {
rr.TransferEncoding = []string{"chunked"}
}
rr.Close = t.Close
rr.Trailer = t.Trailer
case *Response:
rr.Body = t.Body
rr.ContentLength = t.ContentLength
if t.Chunked {
rr.TransferEncoding = []string{"chunked"}
}
rr.Close = t.Close
rr.Trailer = t.Trailer
}
return nil
}
// chunked reports whether chunked is part of the encodings stack.
func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" }
// isIdentity reports whether the encoding is explicitly "identity".
func isIdentity(te []string) bool { return len(te) == 1 && te[0] == "identity" }
// unsupportedTEError reports unsupported transfer-encodings.
type unsupportedTEError struct {
err string
}
func (uste *unsupportedTEError) Error() string {
return uste.err
}
// isUnsupportedTEError reports whether the error is of type
// unsupportedTEError. It is usually invoked with a non-nil err.
func isUnsupportedTEError(err error) bool {
_, ok := err.(*unsupportedTEError)
return ok
}
// parseTransferEncoding sets t.Chunked based on the Transfer-Encoding header.
func (t *transferReader) parseTransferEncoding() error {
raw, present := t.Header["Transfer-Encoding"]
if !present {
return nil
}
delete(t.Header, "Transfer-Encoding")
// Issue 12785; ignore Transfer-Encoding on HTTP/1.0 requests.
if !t.protoAtLeast(1, 1) {
return nil
}
// Like nginx, we only support a single Transfer-Encoding header field, and
// only if set to "chunked". This is one of the most security sensitive
// surfaces in HTTP/1.1 due to the risk of request smuggling, so we keep it
// strict and simple.
if len(raw) != 1 {
return &unsupportedTEError{fmt.Sprintf("too many transfer encodings: %q", raw)}
}
if !ascii.EqualFold(raw[0], "chunked") {
return &unsupportedTEError{fmt.Sprintf("unsupported transfer encoding: %q", raw[0])}
}
// RFC 7230 3.3.2 says "A sender MUST NOT send a Content-Length header field
// in any message that contains a Transfer-Encoding header field."
//
// but also: "If a message is received with both a Transfer-Encoding and a
// Content-Length header field, the Transfer-Encoding overrides the
// Content-Length. Such a message might indicate an attempt to perform
// request smuggling (Section 9.5) or response splitting (Section 9.4) and
// ought to be handled as an error. A sender MUST remove the received
// Content-Length field prior to forwarding such a message downstream."
//
// Reportedly, these appear in the wild.
delete(t.Header, "Content-Length")
t.Chunked = true
return nil
}
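// Concretely, under this strict policy:
//
//	Transfer-Encoding: chunked         -> accepted; t.Chunked is set
//	Transfer-Encoding: CHUNKED         -> accepted (ASCII case-insensitive)
//	Transfer-Encoding: gzip, chunked   -> rejected (unsupported encoding)
//	two Transfer-Encoding header lines -> rejected (too many encodings)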
// Determine the expected body length, using RFC 7230 Section 3.3. This
// function is not a method, because ultimately it should be shared by
// ReadResponse and ReadRequest.
func fixLength(isResponse bool, status int, requestMethod string, header Header, chunked bool) (int64, error) {
isRequest := !isResponse
contentLens := header["Content-Length"]
// Hardening against HTTP request smuggling
if len(contentLens) > 1 {
// Per RFC 7230 Section 3.3.2, prevent multiple
// Content-Length headers if they differ in value.
// If there are dups of the value, remove the dups.
// See Issue 16490.
first := textproto.TrimString(contentLens[0])
for _, ct := range contentLens[1:] {
if first != textproto.TrimString(ct) {
return 0, fmt.Errorf("http: message cannot contain multiple Content-Length headers; got %q", contentLens)
}
}
// deduplicate Content-Length
header.Del("Content-Length")
header.Add("Content-Length", first)
contentLens = header["Content-Length"]
}
// Logic based on response type or status
if isResponse && noResponseBodyExpected(requestMethod) {
return 0, nil
}
if status/100 == 1 {
return 0, nil
}
switch status {
case 204, 304:
return 0, nil
}
// Logic based on Transfer-Encoding
if chunked {
return -1, nil
}
// Logic based on Content-Length
var cl string
if len(contentLens) == 1 {
cl = textproto.TrimString(contentLens[0])
}
if cl != "" {
n, err := parseContentLength(cl)
if err != nil {
return -1, err
}
return n, nil
}
header.Del("Content-Length")
if isRequest {
// RFC 7230 neither explicitly permits nor forbids an
// entity-body on a GET request so we permit one if
// declared, but we default to 0 here (not -1 below)
// if there's no mention of a body.
// Likewise, all other request methods are assumed to have
// no body if neither Transfer-Encoding chunked nor a
// Content-Length are set.
return 0, nil
}
// Body-EOF logic based on other methods (like closing, or chunked coding)
return -1, nil
}
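// A summary of the outcomes above (illustrative):
//
//	response to HEAD               -> 0
//	1xx, 204, or 304 response      -> 0
//	chunked Transfer-Encoding      -> -1 (read until the terminal chunk)
//	Content-Length: N (valid)      -> N
//	request with no length info    -> 0 (no body assumed)
//	response with no length info   -> -1 (read until connection close)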
// Determine whether to hang up after sending a request and body, or
// receiving a response and body.
// 'header' is the request headers.
func shouldClose(major, minor int, header Header, removeCloseHeader bool) bool {
if major < 1 {
return true
}
conv := header["Connection"]
hasClose := httpguts.HeaderValuesContainsToken(conv, "close")
if major == 1 && minor == 0 {
return hasClose || !httpguts.HeaderValuesContainsToken(conv, "keep-alive")
}
if hasClose && removeCloseHeader {
header.Del("Connection")
}
return hasClose
}
// Parse the trailer header.
func fixTrailer(header Header, chunked bool) (Header, error) {
vv, ok := header["Trailer"]
if !ok {
return nil, nil
}
if !chunked {
// Trailer and no chunking:
// this is an invalid use case for trailer header.
// Nevertheless, no error will be returned and we
// let users decide if this is a valid HTTP message.
// The Trailer header will be kept in Response.Header
// but not populate Response.Trailer.
// See issue #27197.
return nil, nil
}
header.Del("Trailer")
trailer := make(Header)
var err error
for _, v := range vv {
foreachHeaderElement(v, func(key string) {
key = CanonicalHeaderKey(key)
switch key {
case "Transfer-Encoding", "Trailer", "Content-Length":
if err == nil {
err = badStringError("bad trailer key", key)
return
}
}
trailer[key] = nil
})
}
if err != nil {
return nil, err
}
if len(trailer) == 0 {
return nil, nil
}
return trailer, nil
}
// body turns a Reader into a ReadCloser.
// Close ensures that the body has been fully read
// and then reads the trailer if necessary.
type body struct {
src io.Reader
hdr any // non-nil (Response or Request) value means read trailer
r *bufio.Reader // underlying wire-format reader for the trailer
closing bool // is the connection to be closed after reading body?
doEarlyClose bool // whether Close should stop early
mu sync.Mutex // guards following, and calls to Read and Close
sawEOF bool
closed bool
earlyClose bool // Close called and we didn't read to the end of src
onHitEOF func() // if non-nil, func to call when EOF is Read
}
// ErrBodyReadAfterClose is returned when reading a Request or Response
// Body after the body has been closed. This typically happens when the body is
// read after an HTTP Handler calls WriteHeader or Write on its
// ResponseWriter.
var ErrBodyReadAfterClose = errors.New("http: invalid Read on closed Body")
func (b *body) Read(p []byte) (n int, err error) {
b.mu.Lock()
defer b.mu.Unlock()
if b.closed {
return 0, ErrBodyReadAfterClose
}
return b.readLocked(p)
}
// Must hold b.mu.
func (b *body) readLocked(p []byte) (n int, err error) {
if b.sawEOF {
return 0, io.EOF
}
n, err = b.src.Read(p)
if err == io.EOF {
b.sawEOF = true
// Chunked case. Read the trailer.
if b.hdr != nil {
if e := b.readTrailer(); e != nil {
err = e
// Something went wrong in the trailer, we must not allow any
// further reads of any kind to succeed from body, nor any
// subsequent requests on the server connection. See
// golang.org/issue/12027
b.sawEOF = false
b.closed = true
}
b.hdr = nil
} else {
// If the server declared the Content-Length, our body is a LimitedReader
// and we need to check whether this EOF arrived early.
if lr, ok := b.src.(*io.LimitedReader); ok && lr.N > 0 {
err = io.ErrUnexpectedEOF
}
}
}
// If we can return an EOF here along with the read data, do
// so. This is optional per the io.Reader contract, but doing
// so helps the HTTP transport code recycle its connection
// earlier (since it will see this EOF itself), even if the
// client doesn't do future reads or Close.
if err == nil && n > 0 {
if lr, ok := b.src.(*io.LimitedReader); ok && lr.N == 0 {
err = io.EOF
b.sawEOF = true
}
}
if b.sawEOF && b.onHitEOF != nil {
b.onHitEOF()
}
return n, err
}
var (
singleCRLF = []byte("\r\n")
doubleCRLF = []byte("\r\n\r\n")
)
func seeUpcomingDoubleCRLF(r *bufio.Reader) bool {
for peekSize := 4; ; peekSize++ {
// This loop stops when Peek returns an error,
// which it does when r's buffer has been filled.
buf, err := r.Peek(peekSize)
if bytes.HasSuffix(buf, doubleCRLF) {
return true
}
if err != nil {
break
}
}
return false
}
var errTrailerEOF = errors.New("http: unexpected EOF reading trailer")
func (b *body) readTrailer() error {
// The common case, since nobody uses trailers.
buf, err := b.r.Peek(2)
if bytes.Equal(buf, singleCRLF) {
b.r.Discard(2)
return nil
}
if len(buf) < 2 {
return errTrailerEOF
}
if err != nil {
return err
}
// Make sure there's a header terminator coming up, to prevent
// a DoS with an unbounded size Trailer. It's not easy to
// slip in a LimitReader here, as textproto.NewReader requires
// a concrete *bufio.Reader. Also, we can't get all the way
// back up to our conn's LimitedReader that *might* be backing
// this bufio.Reader. Instead, a hack: we iteratively Peek up
// to the bufio.Reader's max size, looking for a double CRLF.
// This limits the trailer to the underlying buffer size, typically 4kB.
if !seeUpcomingDoubleCRLF(b.r) {
return errors.New("http: suspiciously long trailer after chunked body")
}
hdr, err := textproto.NewReader(b.r).ReadMIMEHeader()
if err != nil {
if err == io.EOF {
return errTrailerEOF
}
return err
}
switch rr := b.hdr.(type) {
case *Request:
mergeSetHeader(&rr.Trailer, Header(hdr))
case *Response:
mergeSetHeader(&rr.Trailer, Header(hdr))
}
return nil
}
func mergeSetHeader(dst *Header, src Header) {
if *dst == nil {
*dst = src
return
}
for k, vv := range src {
(*dst)[k] = vv
}
}
// unreadDataSizeLocked returns the number of bytes of unread input.
// It returns -1 if unknown.
// b.mu must be held.
func (b *body) unreadDataSizeLocked() int64 {
if lr, ok := b.src.(*io.LimitedReader); ok {
return lr.N
}
return -1
}
func (b *body) Close() error {
b.mu.Lock()
defer b.mu.Unlock()
if b.closed {
return nil
}
var err error
switch {
case b.sawEOF:
// Already saw EOF, so no need to go looking for it.
case b.hdr == nil && b.closing:
// no trailer and closing the connection next.
// no point in reading to EOF.
case b.doEarlyClose:
// Read up to maxPostHandlerReadBytes bytes of the body, looking
// for EOF (and trailers), so we can re-use this connection.
if lr, ok := b.src.(*io.LimitedReader); ok && lr.N > maxPostHandlerReadBytes {
// There was a declared Content-Length, and we have more bytes remaining
// than our maxPostHandlerReadBytes tolerance. So, give up.
b.earlyClose = true
} else {
var n int64
// Consume the body, which will also lead to us reading
// the trailer headers after the body, if present.
n, err = io.CopyN(io.Discard, bodyLocked{b}, maxPostHandlerReadBytes)
if err == io.EOF {
err = nil
}
if n == maxPostHandlerReadBytes {
b.earlyClose = true
}
}
default:
// Fully consume the body, which will also lead to us reading
// the trailer headers after the body, if present.
_, err = io.Copy(io.Discard, bodyLocked{b})
}
b.closed = true
return err
}
func (b *body) didEarlyClose() bool {
b.mu.Lock()
defer b.mu.Unlock()
return b.earlyClose
}
// bodyRemains reports whether future Read calls might
// yield data.
func (b *body) bodyRemains() bool {
b.mu.Lock()
defer b.mu.Unlock()
return !b.sawEOF
}
func (b *body) registerOnHitEOF(fn func()) {
b.mu.Lock()
defer b.mu.Unlock()
b.onHitEOF = fn
}
// bodyLocked is an io.Reader reading from a *body when its mutex is
// already held.
type bodyLocked struct {
b *body
}
func (bl bodyLocked) Read(p []byte) (n int, err error) {
if bl.b.closed {
return 0, ErrBodyReadAfterClose
}
return bl.b.readLocked(p)
}
// parseContentLength trims whitespace from cl and returns -1 if no value
// is set, or the value if it's >= 0.
func parseContentLength(cl string) (int64, error) {
cl = textproto.TrimString(cl)
if cl == "" {
return -1, nil
}
n, err := strconv.ParseUint(cl, 10, 63)
if err != nil {
return 0, badStringError("bad Content-Length", cl)
}
return int64(n), nil
}
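// For example:
//
//	parseContentLength("42")  -> 42, nil
//	parseContentLength(" 7 ") -> 7, nil (surrounding whitespace trimmed)
//	parseContentLength("")    -> -1, nil (no value set)
//	parseContentLength("-1")  -> 0, error (signs and non-digits rejected)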
// finishAsyncByteRead finishes reading the 1-byte sniff
// from the ContentLength==0, Body!=nil case.
type finishAsyncByteRead struct {
tw *transferWriter
}
func (fr finishAsyncByteRead) Read(p []byte) (n int, err error) {
if len(p) == 0 {
return
}
rres := <-fr.tw.ByteReadCh
n, err = rres.n, rres.err
if n == 1 {
p[0] = rres.b
}
if err == nil {
err = io.EOF
}
return
}
var nopCloserType = reflect.TypeOf(io.NopCloser(nil))
var nopCloserWriterToType = reflect.TypeOf(io.NopCloser(struct {
io.Reader
io.WriterTo
}{}))
// unwrapNopCloser returns the underlying reader and true if r is a NopCloser;
// otherwise it returns false.
func unwrapNopCloser(r io.Reader) (underlyingReader io.Reader, isNopCloser bool) {
switch reflect.TypeOf(r) {
case nopCloserType, nopCloserWriterToType:
return reflect.ValueOf(r).Field(0).Interface().(io.Reader), true
default:
return nil, false
}
}
// isKnownInMemoryReader reports whether r is a type known to not
// block on Read. Its caller uses this as an optional optimization to
// send fewer TCP packets.
func isKnownInMemoryReader(r io.Reader) bool {
switch r.(type) {
case *bytes.Reader, *bytes.Buffer, *strings.Reader:
return true
}
if r, ok := unwrapNopCloser(r); ok {
return isKnownInMemoryReader(r)
}
if r, ok := r.(*readTrackingBody); ok {
return isKnownInMemoryReader(r.ReadCloser)
}
return false
}
// bufioFlushWriter is an io.Writer wrapper that flushes all writes
// on its wrapped writer if it's a *bufio.Writer.
type bufioFlushWriter struct{ w io.Writer }
func (fw bufioFlushWriter) Write(p []byte) (n int, err error) {
n, err = fw.w.Write(p)
if bw, ok := fw.w.(*bufio.Writer); n > 0 && ok {
ferr := bw.Flush()
if ferr != nil && err == nil {
err = ferr
}
}
return
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// HTTP client implementation. See RFC 7230 through 7235.
//
// This is the low-level Transport implementation of RoundTripper.
// The high-level interface is in client.go.
package http
import (
"bufio"
"compress/gzip"
"container/list"
"context"
"crypto/tls"
"errors"
"fmt"
"internal/godebug"
"io"
"log"
"net"
"net/http/httptrace"
"net/http/internal/ascii"
"net/textproto"
"net/url"
"reflect"
"strings"
"sync"
"sync/atomic"
"time"
"golang.org/x/net/http/httpguts"
"golang.org/x/net/http/httpproxy"
)
// DefaultTransport is the default implementation of Transport and is
// used by DefaultClient. It establishes network connections as needed
// and caches them for reuse by subsequent calls. It uses HTTP proxies
// as directed by the environment variables HTTP_PROXY, HTTPS_PROXY
// and NO_PROXY (or the lowercase versions thereof).
var DefaultTransport RoundTripper = &Transport{
Proxy: ProxyFromEnvironment,
DialContext: defaultTransportDialContext(&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}),
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
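// Callers who need different behavior typically build their own Transport
// rather than mutating DefaultTransport. An illustrative sketch (the timeout
// values are arbitrary):
//
//	t := &http.Transport{
//		Proxy:               http.ProxyFromEnvironment,
//		MaxIdleConns:        50,
//		IdleConnTimeout:     30 * time.Second,
//		TLSHandshakeTimeout: 5 * time.Second,
//	}
//	client := &http.Client{Transport: t}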
// DefaultMaxIdleConnsPerHost is the default value of Transport's
// MaxIdleConnsPerHost.
const DefaultMaxIdleConnsPerHost = 2
// Transport is an implementation of RoundTripper that supports HTTP,
// HTTPS, and HTTP proxies (for either HTTP or HTTPS with CONNECT).
//
// By default, Transport caches connections for future re-use.
// This may leave many open connections when accessing many hosts.
// This behavior can be managed using Transport's CloseIdleConnections method
// and the MaxIdleConnsPerHost and DisableKeepAlives fields.
//
// Transports should be reused instead of created as needed.
// Transports are safe for concurrent use by multiple goroutines.
//
// A Transport is a low-level primitive for making HTTP and HTTPS requests.
// For high-level functionality, such as cookies and redirects, see Client.
//
// Transport uses HTTP/1.1 for HTTP URLs and either HTTP/1.1 or HTTP/2
// for HTTPS URLs, depending on whether the server supports HTTP/2,
// and how the Transport is configured. The DefaultTransport supports HTTP/2.
// To explicitly enable HTTP/2 on a transport, use golang.org/x/net/http2
// and call ConfigureTransport. See the package docs for more about HTTP/2.
//
// Responses with status codes in the 1xx range are either handled
// automatically (100 expect-continue) or ignored. The one
// exception is HTTP status code 101 (Switching Protocols), which is
// considered a terminal status and returned by RoundTrip. To see the
// ignored 1xx responses, use the httptrace trace package's
// ClientTrace.Got1xxResponse.
//
// Transport only retries a request upon encountering a network error
// if the request is idempotent and either has no body or has its
// Request.GetBody defined. HTTP requests are considered idempotent if
// they have HTTP methods GET, HEAD, OPTIONS, or TRACE; or if their
// Header map contains an "Idempotency-Key" or "X-Idempotency-Key"
// entry. If the idempotency key value is a zero-length slice, the
// request is treated as idempotent but the header is not sent on the
// wire.
type Transport struct {
idleMu sync.Mutex
closeIdle bool // user has requested to close all idle conns
idleConn map[connectMethodKey][]*persistConn // most recently used at end
idleConnWait map[connectMethodKey]wantConnQueue // waiting getConns
idleLRU connLRU
reqMu sync.Mutex
reqCanceler map[cancelKey]func(error)
altMu sync.Mutex // guards changing altProto only
altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme
connsPerHostMu sync.Mutex
connsPerHost map[connectMethodKey]int
connsPerHostWait map[connectMethodKey]wantConnQueue // waiting getConns
// Proxy specifies a function to return a proxy for a given
// Request. If the function returns a non-nil error, the
// request is aborted with the provided error.
//
// The proxy type is determined by the URL scheme. "http",
// "https", and "socks5" are supported. If the scheme is empty,
// "http" is assumed.
//
// If Proxy is nil or returns a nil *URL, no proxy is used.
Proxy func(*Request) (*url.URL, error)
// OnProxyConnectResponse is called when the Transport gets an HTTP response from
// a proxy for a CONNECT request. It's called before the check for a 200 OK response.
// If it returns an error, the request fails with that error.
OnProxyConnectResponse func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error
// DialContext specifies the dial function for creating unencrypted TCP connections.
// If DialContext is nil (and the deprecated Dial below is also nil),
// then the transport dials using package net.
//
// DialContext runs concurrently with calls to RoundTrip.
// A RoundTrip call that initiates a dial may end up using
// a connection dialed previously when the earlier connection
// becomes idle before the later DialContext completes.
DialContext func(ctx context.Context, network, addr string) (net.Conn, error)
// Dial specifies the dial function for creating unencrypted TCP connections.
//
// Dial runs concurrently with calls to RoundTrip.
// A RoundTrip call that initiates a dial may end up using
// a connection dialed previously when the earlier connection
// becomes idle before the later Dial completes.
//
// Deprecated: Use DialContext instead, which allows the transport
// to cancel dials as soon as they are no longer needed.
// If both are set, DialContext takes priority.
Dial func(network, addr string) (net.Conn, error)
// DialTLSContext specifies an optional dial function for creating
// TLS connections for non-proxied HTTPS requests.
//
// If DialTLSContext is nil (and the deprecated DialTLS below is also nil),
// DialContext and TLSClientConfig are used.
//
// If DialTLSContext is set, the Dial and DialContext hooks are not used for HTTPS
// requests and the TLSClientConfig and TLSHandshakeTimeout
// are ignored. The returned net.Conn is assumed to already be
// past the TLS handshake.
DialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
// DialTLS specifies an optional dial function for creating
// TLS connections for non-proxied HTTPS requests.
//
// Deprecated: Use DialTLSContext instead, which allows the transport
// to cancel dials as soon as they are no longer needed.
// If both are set, DialTLSContext takes priority.
DialTLS func(network, addr string) (net.Conn, error)
// TLSClientConfig specifies the TLS configuration to use with
// tls.Client.
// If nil, the default configuration is used.
// If non-nil, HTTP/2 support may not be enabled by default.
TLSClientConfig *tls.Config
// TLSHandshakeTimeout specifies the maximum amount of time to wait
// for a TLS handshake. Zero means no timeout.
TLSHandshakeTimeout time.Duration
// DisableKeepAlives, if true, disables HTTP keep-alives and
// will only use the connection to the server for a single
// HTTP request.
//
// This is unrelated to the similarly named TCP keep-alives.
DisableKeepAlives bool
// DisableCompression, if true, prevents the Transport from
// requesting compression with an "Accept-Encoding: gzip"
// request header when the Request contains no existing
// Accept-Encoding value. If the Transport requests gzip on
// its own and gets a gzipped response, it's transparently
// decoded in the Response.Body. However, if the user
// explicitly requested gzip it is not automatically
// uncompressed.
DisableCompression bool
// MaxIdleConns controls the maximum number of idle (keep-alive)
// connections across all hosts. Zero means no limit.
MaxIdleConns int
// MaxIdleConnsPerHost, if non-zero, controls the maximum idle
// (keep-alive) connections to keep per-host. If zero,
// DefaultMaxIdleConnsPerHost is used.
MaxIdleConnsPerHost int
// MaxConnsPerHost optionally limits the total number of
// connections per host, including connections in the dialing,
// active, and idle states. On limit violation, dials will block.
//
// Zero means no limit.
MaxConnsPerHost int
// IdleConnTimeout is the maximum amount of time an idle
// (keep-alive) connection will remain idle before closing
// itself.
// Zero means no limit.
IdleConnTimeout time.Duration
// ResponseHeaderTimeout, if non-zero, specifies the amount of
// time to wait for a server's response headers after fully
// writing the request (including its body, if any). This
// time does not include the time to read the response body.
ResponseHeaderTimeout time.Duration
// ExpectContinueTimeout, if non-zero, specifies the amount of
// time to wait for a server's first response headers after fully
// writing the request headers if the request has an
// "Expect: 100-continue" header. Zero means no timeout and
// causes the body to be sent immediately, without
// waiting for the server to approve.
// This time does not include the time to send the request header.
ExpectContinueTimeout time.Duration
// TLSNextProto specifies how the Transport switches to an
// alternate protocol (such as HTTP/2) after a TLS ALPN
// protocol negotiation. If Transport dials a TLS connection
// with a non-empty protocol name and TLSNextProto contains a
// map entry for that key (such as "h2"), then the func is
// called with the request's authority (such as "example.com"
// or "example.com:1234") and the TLS connection. The function
// must return a RoundTripper that then handles the request.
// If TLSNextProto is not nil, HTTP/2 support is not enabled
// automatically.
TLSNextProto map[string]func(authority string, c *tls.Conn) RoundTripper
// ProxyConnectHeader optionally specifies headers to send to
// proxies during CONNECT requests.
// To set the header dynamically, see GetProxyConnectHeader.
ProxyConnectHeader Header
// GetProxyConnectHeader optionally specifies a func to return
// headers to send to proxyURL during a CONNECT request to the
// ip:port target.
// If it returns an error, the Transport's RoundTrip fails with
// that error. It can return (nil, nil) to not add headers.
// If GetProxyConnectHeader is non-nil, ProxyConnectHeader is
// ignored.
GetProxyConnectHeader func(ctx context.Context, proxyURL *url.URL, target string) (Header, error)
// MaxResponseHeaderBytes specifies a limit on how many
// response bytes are allowed in the server's response
// header.
//
// Zero means to use a default limit.
MaxResponseHeaderBytes int64
// WriteBufferSize specifies the size of the write buffer used
// when writing to the transport.
// If zero, a default (currently 4KB) is used.
WriteBufferSize int
// ReadBufferSize specifies the size of the read buffer used
// when reading from the transport.
// If zero, a default (currently 4KB) is used.
ReadBufferSize int
// nextProtoOnce guards initialization of TLSNextProto and
// h2transport (via onceSetNextProtoDefaults)
nextProtoOnce sync.Once
h2transport h2Transport // non-nil if http2 wired up
tlsNextProtoWasNil bool // whether TLSNextProto was nil when the Once fired
// ForceAttemptHTTP2 controls whether HTTP/2 is enabled when a non-zero
// Dial, DialTLS, or DialContext func or TLSClientConfig is provided.
// By default, use of any of those fields conservatively disables HTTP/2.
// To use a custom dialer or TLS config and still attempt HTTP/2
// upgrades, set this to true.
ForceAttemptHTTP2 bool
}
// A cancelKey is the key of the reqCanceler map.
// We wrap the *Request in this type since we want to use the original request,
// not any transient one created by roundTrip.
type cancelKey struct {
req *Request
}
func (t *Transport) writeBufferSize() int {
if t.WriteBufferSize > 0 {
return t.WriteBufferSize
}
return 4 << 10
}
func (t *Transport) readBufferSize() int {
if t.ReadBufferSize > 0 {
return t.ReadBufferSize
}
return 4 << 10
}
// Clone returns a deep copy of t's exported fields.
func (t *Transport) Clone() *Transport {
t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
t2 := &Transport{
Proxy: t.Proxy,
OnProxyConnectResponse: t.OnProxyConnectResponse,
DialContext: t.DialContext,
Dial: t.Dial,
DialTLS: t.DialTLS,
DialTLSContext: t.DialTLSContext,
TLSHandshakeTimeout: t.TLSHandshakeTimeout,
DisableKeepAlives: t.DisableKeepAlives,
DisableCompression: t.DisableCompression,
MaxIdleConns: t.MaxIdleConns,
MaxIdleConnsPerHost: t.MaxIdleConnsPerHost,
MaxConnsPerHost: t.MaxConnsPerHost,
IdleConnTimeout: t.IdleConnTimeout,
ResponseHeaderTimeout: t.ResponseHeaderTimeout,
ExpectContinueTimeout: t.ExpectContinueTimeout,
ProxyConnectHeader: t.ProxyConnectHeader.Clone(),
GetProxyConnectHeader: t.GetProxyConnectHeader,
MaxResponseHeaderBytes: t.MaxResponseHeaderBytes,
ForceAttemptHTTP2: t.ForceAttemptHTTP2,
WriteBufferSize: t.WriteBufferSize,
ReadBufferSize: t.ReadBufferSize,
}
if t.TLSClientConfig != nil {
t2.TLSClientConfig = t.TLSClientConfig.Clone()
}
if !t.tlsNextProtoWasNil {
npm := map[string]func(authority string, c *tls.Conn) RoundTripper{}
for k, v := range t.TLSNextProto {
npm[k] = v
}
t2.TLSNextProto = npm
}
return t2
}
// h2Transport is the interface we expect to be able to call from
// net/http against an *http2.Transport that's either bundled into
// h2_bundle.go or supplied by the user via x/net/http2.
//
// We name it with the "h2" prefix to stay out of the "http2" prefix
// namespace used by x/tools/cmd/bundle for h2_bundle.go.
type h2Transport interface {
CloseIdleConnections()
}
func (t *Transport) hasCustomTLSDialer() bool {
return t.DialTLS != nil || t.DialTLSContext != nil
}
var http2client = godebug.New("http2client")
// onceSetNextProtoDefaults initializes TLSNextProto.
// It must be called via t.nextProtoOnce.Do.
func (t *Transport) onceSetNextProtoDefaults() {
t.tlsNextProtoWasNil = (t.TLSNextProto == nil)
if http2client.Value() == "0" {
http2client.IncNonDefault()
return
}
// If they've already configured http2 with
// golang.org/x/net/http2 instead of the bundled copy, try to
// get at its http2.Transport value (via the "https"
// altproto map) so we can call CloseIdleConnections on it if
// requested. (Issue 22891)
altProto, _ := t.altProto.Load().(map[string]RoundTripper)
if rv := reflect.ValueOf(altProto["https"]); rv.IsValid() && rv.Type().Kind() == reflect.Struct && rv.Type().NumField() == 1 {
if v := rv.Field(0); v.CanInterface() {
if h2i, ok := v.Interface().(h2Transport); ok {
t.h2transport = h2i
return
}
}
}
if t.TLSNextProto != nil {
// This is the documented way to disable http2 on a
// Transport.
return
}
if !t.ForceAttemptHTTP2 && (t.TLSClientConfig != nil || t.Dial != nil || t.DialContext != nil || t.hasCustomTLSDialer()) {
// Be conservative and don't automatically enable
// http2 if they've specified a custom TLS config or
// custom dialers. Let them opt-in themselves via
// http2.ConfigureTransport so we don't surprise them
// by modifying their tls.Config. Issue 14275.
// However, if ForceAttemptHTTP2 is true, it overrides the above checks.
return
}
if omitBundledHTTP2 {
return
}
t2, err := http2configureTransports(t)
if err != nil {
log.Printf("Error enabling Transport HTTP/2 support: %v", err)
return
}
t.h2transport = t2
// Auto-configure the http2.Transport's MaxHeaderListSize from
// the http.Transport's MaxResponseHeaderBytes. They don't
// exactly mean the same thing, but they're close.
//
// TODO: also add this to x/net/http2.ConfigureTransports, behind
// a +build go1.7 build tag:
if limit1 := t.MaxResponseHeaderBytes; limit1 != 0 && t2.MaxHeaderListSize == 0 {
const h2max = 1<<32 - 1
if limit1 >= h2max {
t2.MaxHeaderListSize = h2max
} else {
t2.MaxHeaderListSize = uint32(limit1)
}
}
}
// ProxyFromEnvironment returns the URL of the proxy to use for a
// given request, as indicated by the environment variables
// HTTP_PROXY, HTTPS_PROXY and NO_PROXY (or the lowercase versions
// thereof). Requests use the proxy from the environment variable
// matching their scheme, unless excluded by NO_PROXY.
//
// The environment values may be either a complete URL or a
// "host[:port]", in which case the "http" scheme is assumed.
// The schemes "http", "https", and "socks5" are supported.
// An error is returned if the value is a different form.
//
// A nil URL and nil error are returned if no proxy is defined in the
// environment, or a proxy should not be used for the given request,
// as defined by NO_PROXY.
//
// As a special case, if req.URL.Host is "localhost" (with or without
// a port number), then a nil URL and nil error will be returned.
func ProxyFromEnvironment(req *Request) (*url.URL, error) {
return envProxyFunc()(req.URL)
}
// ProxyURL returns a proxy function (for use in a Transport)
// that always returns the same URL.
func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, error) {
return func(*Request) (*url.URL, error) {
return fixedURL, nil
}
}
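// An illustrative use of ProxyURL, sending every request through one fixed
// proxy (the proxy address is a placeholder):
//
//	proxyURL, _ := url.Parse("http://proxy.example.com:8080")
//	t := &http.Transport{Proxy: http.ProxyURL(proxyURL)}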
// transportRequest is a wrapper around a *Request that adds
// optional extra headers to write and stores any error to return
// from roundTrip.
type transportRequest struct {
*Request // original request, not to be mutated
extra Header // extra headers to write, or nil
trace *httptrace.ClientTrace // optional
cancelKey cancelKey
mu sync.Mutex // guards err
err error // first setError value for mapRoundTripError to consider
}
func (tr *transportRequest) extraHeaders() Header {
if tr.extra == nil {
tr.extra = make(Header)
}
return tr.extra
}
func (tr *transportRequest) setError(err error) {
tr.mu.Lock()
if tr.err == nil {
tr.err = err
}
tr.mu.Unlock()
}
// useRegisteredProtocol reports whether an alternate protocol (as registered
// with Transport.RegisterProtocol) should be respected for this request.
func (t *Transport) useRegisteredProtocol(req *Request) bool {
if req.URL.Scheme == "https" && req.requiresHTTP1() {
// If this request requires HTTP/1, don't use the
// "https" alternate protocol, which is used by the
// HTTP/2 code to take over requests if there's an
// existing cached HTTP/2 connection.
return false
}
return true
}
// alternateRoundTripper returns the alternate RoundTripper to use
// for this request if the Request's URL scheme requires one,
// or nil for the normal case of using the Transport.
func (t *Transport) alternateRoundTripper(req *Request) RoundTripper {
if !t.useRegisteredProtocol(req) {
return nil
}
altProto, _ := t.altProto.Load().(map[string]RoundTripper)
return altProto[req.URL.Scheme]
}
// roundTrip implements a RoundTripper over HTTP.
func (t *Transport) roundTrip(req *Request) (*Response, error) {
t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
ctx := req.Context()
trace := httptrace.ContextClientTrace(ctx)
if req.URL == nil {
req.closeBody()
return nil, errors.New("http: nil Request.URL")
}
if req.Header == nil {
req.closeBody()
return nil, errors.New("http: nil Request.Header")
}
scheme := req.URL.Scheme
isHTTP := scheme == "http" || scheme == "https"
if isHTTP {
for k, vv := range req.Header {
if !httpguts.ValidHeaderFieldName(k) {
req.closeBody()
return nil, fmt.Errorf("net/http: invalid header field name %q", k)
}
for _, v := range vv {
if !httpguts.ValidHeaderFieldValue(v) {
req.closeBody()
// Don't include the value in the error, because it may be sensitive.
return nil, fmt.Errorf("net/http: invalid header field value for %q", k)
}
}
}
}
origReq := req
cancelKey := cancelKey{origReq}
req = setupRewindBody(req)
if altRT := t.alternateRoundTripper(req); altRT != nil {
if resp, err := altRT.RoundTrip(req); err != ErrSkipAltProtocol {
return resp, err
}
var err error
req, err = rewindBody(req)
if err != nil {
return nil, err
}
}
if !isHTTP {
req.closeBody()
return nil, badStringError("unsupported protocol scheme", scheme)
}
if req.Method != "" && !validMethod(req.Method) {
req.closeBody()
return nil, fmt.Errorf("net/http: invalid method %q", req.Method)
}
if req.URL.Host == "" {
req.closeBody()
return nil, errors.New("http: no Host in request URL")
}
for {
select {
case <-ctx.Done():
req.closeBody()
return nil, ctx.Err()
default:
}
// treq gets modified by roundTrip, so we need to recreate for each retry.
treq := &transportRequest{Request: req, trace: trace, cancelKey: cancelKey}
cm, err := t.connectMethodForRequest(treq)
if err != nil {
req.closeBody()
return nil, err
}
// Get the cached or newly-created connection to either the
// host (for http or https), the http proxy, or the http proxy
// pre-CONNECTed to https server. In any case, we'll be ready
// to send it requests.
pconn, err := t.getConn(treq, cm)
if err != nil {
t.setReqCanceler(cancelKey, nil)
req.closeBody()
return nil, err
}
var resp *Response
if pconn.alt != nil {
// HTTP/2 path.
t.setReqCanceler(cancelKey, nil) // not cancelable with CancelRequest
resp, err = pconn.alt.RoundTrip(req)
} else {
resp, err = pconn.roundTrip(treq)
}
if err == nil {
resp.Request = origReq
return resp, nil
}
// Failed. Clean up and determine whether to retry.
if http2isNoCachedConnError(err) {
if t.removeIdleConn(pconn) {
t.decConnsPerHost(pconn.cacheKey)
}
} else if !pconn.shouldRetryRequest(req, err) {
// Issue 16465: return underlying net.Conn.Read error from peek,
// as we've historically done.
if e, ok := err.(nothingWrittenError); ok {
err = e.error
}
if e, ok := err.(transportReadFromServerError); ok {
err = e.err
}
if b, ok := req.Body.(*readTrackingBody); ok && !b.didClose {
// Issue 49621: Close the request body if pconn.roundTrip
// didn't do so already. This can happen if the pconn
// write loop exits without reading the write request.
req.closeBody()
}
return nil, err
}
testHookRoundTripRetried()
// Rewind the body if we're able to.
req, err = rewindBody(req)
if err != nil {
return nil, err
}
}
}
var errCannotRewind = errors.New("net/http: cannot rewind body after connection loss")
type readTrackingBody struct {
io.ReadCloser
didRead bool
didClose bool
}
func (r *readTrackingBody) Read(data []byte) (int, error) {
r.didRead = true
return r.ReadCloser.Read(data)
}
func (r *readTrackingBody) Close() error {
r.didClose = true
return r.ReadCloser.Close()
}
// setupRewindBody returns a new request with a custom body wrapper
// that can report whether the body needs rewinding.
// This lets rewindBody avoid an error result when the request
// does not have GetBody but the body hasn't been read at all yet.
func setupRewindBody(req *Request) *Request {
if req.Body == nil || req.Body == NoBody {
return req
}
newReq := *req
newReq.Body = &readTrackingBody{ReadCloser: req.Body}
return &newReq
}
// rewindBody returns a new request with the body rewound.
// It returns req unmodified if the body does not need rewinding.
// rewindBody takes care of closing req.Body when appropriate
// (in all cases except when rewindBody returns req unmodified).
func rewindBody(req *Request) (rewound *Request, err error) {
if req.Body == nil || req.Body == NoBody || (!req.Body.(*readTrackingBody).didRead && !req.Body.(*readTrackingBody).didClose) {
return req, nil // nothing to rewind
}
if !req.Body.(*readTrackingBody).didClose {
req.closeBody()
}
if req.GetBody == nil {
return nil, errCannotRewind
}
body, err := req.GetBody()
if err != nil {
return nil, err
}
newReq := *req
newReq.Body = &readTrackingBody{ReadCloser: body}
return &newReq, nil
}
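// Whether rewindBody can succeed ultimately hinges on Request.GetBody.
// NewRequest sets GetBody automatically when the body is a *bytes.Buffer,
// *bytes.Reader, or *strings.Reader; other callers can supply it themselves
// so a request remains replayable. Illustrative sketch:
//
//	req, _ := http.NewRequest("POST", "https://example.com/", strings.NewReader("payload"))
//	// req.GetBody is non-nil here, so after a connection loss the body
//	// can be rewound instead of failing with errCannotRewind.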
// shouldRetryRequest reports whether we should retry sending a failed
// HTTP request on a new connection. The non-nil input error is the
// error from roundTrip.
func (pc *persistConn) shouldRetryRequest(req *Request, err error) bool {
if http2isNoCachedConnError(err) {
// Issue 16582: if the user started a bunch of
// requests at once, they can all pick the same conn
// and violate the server's max concurrent streams.
// Instead, match the HTTP/1 behavior for now and dial
// again to get a new TCP connection, rather than failing
// this request.
return true
}
if err == errMissingHost {
// User error.
return false
}
if !pc.isReused() {
// This was a fresh connection. There's no reason the server
// should've hung up on us.
//
// Also, if we retried now, we could loop forever
// creating new connections and retrying if the server
// is just hanging up on us because it doesn't like
// our request (as opposed to sending an error).
return false
}
if _, ok := err.(nothingWrittenError); ok {
// We never wrote anything, so it's safe to retry, if there's no body or we
// can "rewind" the body with GetBody.
return req.outgoingLength() == 0 || req.GetBody != nil
}
if !req.isReplayable() {
// Don't retry non-idempotent requests.
return false
}
if _, ok := err.(transportReadFromServerError); ok {
// We got some non-EOF net.Conn.Read failure reading
// the 1st response byte from the server.
return true
}
if err == errServerClosedIdle {
// The server replied with io.EOF while we were trying to
// read the response. Probably an unfortunate keep-alive
// timeout, just as the client was writing a request.
return true
}
return false // conservatively
}
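// A usage note (sketch; the key value is hypothetical): req.isReplayable
// treats bodyless GET, HEAD, OPTIONS, and TRACE requests as replayable,
// and also honors the non-standard Idempotency-Key conventions, so a
// caller can opt a POST into retries on a broken reused connection:
//
//	req.Header.Set("Idempotency-Key", "f0184db5-0286-4004-bd7f-473fd5b2e20e")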
// ErrSkipAltProtocol is a sentinel error value defined by Transport.RegisterProtocol.
var ErrSkipAltProtocol = errors.New("net/http: skip alternate protocol")
// RegisterProtocol registers a new protocol with scheme.
// The Transport will pass requests using the given scheme to rt.
// It is rt's responsibility to simulate HTTP request semantics.
//
// RegisterProtocol can be used by other packages to provide
// implementations of protocol schemes like "ftp" or "file".
//
// If rt.RoundTrip returns ErrSkipAltProtocol, the Transport will
// handle the RoundTrip itself for that one request, as if the
// protocol were not registered.
func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) {
t.altMu.Lock()
defer t.altMu.Unlock()
oldMap, _ := t.altProto.Load().(map[string]RoundTripper)
if _, exists := oldMap[scheme]; exists {
panic("protocol " + scheme + " already registered")
}
newMap := make(map[string]RoundTripper)
for k, v := range oldMap {
newMap[k] = v
}
newMap[scheme] = rt
t.altProto.Store(newMap)
}
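// For example, the file-serving RoundTripper that this package already
// provides can be registered to handle "file" URLs (illustrative sketch;
// the path is hypothetical):
//
//	t := &Transport{}
//	t.RegisterProtocol("file", NewFileTransport(Dir("/")))
//	c := &Client{Transport: t}
//	res, err := c.Get("file:///etc/passwd")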
// CloseIdleConnections closes any connections that were established by
// previous requests but are now sitting idle in
// a "keep-alive" state. It does not interrupt any connections currently
// in use.
func (t *Transport) CloseIdleConnections() {
t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
t.idleMu.Lock()
m := t.idleConn
t.idleConn = nil
t.closeIdle = true // close newly idle connections
t.idleLRU = connLRU{}
t.idleMu.Unlock()
for _, conns := range m {
for _, pconn := range conns {
pconn.close(errCloseIdleConns)
}
}
if t2 := t.h2transport; t2 != nil {
t2.CloseIdleConnections()
}
}
// CancelRequest cancels an in-flight request by closing its connection.
// CancelRequest should only be called after RoundTrip has returned.
//
// Deprecated: Use Request.WithContext to create a request with a
// cancelable context instead. CancelRequest cannot cancel HTTP/2
// requests.
func (t *Transport) CancelRequest(req *Request) {
t.cancelRequest(cancelKey{req}, errRequestCanceled)
}
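// A minimal sketch of the recommended replacement (URL hypothetical):
// cancel via the request's context rather than CancelRequest, which also
// works for HTTP/2 requests:
//
//	ctx, cancel := context.WithCancel(context.Background())
//	req, _ := NewRequestWithContext(ctx, "GET", "https://example.com", nil)
//	go func() { /* ... */ cancel() }()
//	resp, err := DefaultClient.Do(req)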
// Cancel an in-flight request, recording the error value.
// Returns whether the request was canceled.
func (t *Transport) cancelRequest(key cancelKey, err error) bool {
// This function must not return until the cancel func has completed.
// See: https://golang.org/issue/34658
t.reqMu.Lock()
defer t.reqMu.Unlock()
cancel := t.reqCanceler[key]
delete(t.reqCanceler, key)
if cancel != nil {
cancel(err)
}
return cancel != nil
}
//
// Private implementation past this point.
//
var (
envProxyOnce sync.Once
envProxyFuncValue func(*url.URL) (*url.URL, error)
)
// envProxyFunc returns a function that consults the proxy-related
// environment variables (HTTP_PROXY, HTTPS_PROXY, and NO_PROXY, or their
// lowercase equivalents) to determine the proxy address for a URL.
func envProxyFunc() func(*url.URL) (*url.URL, error) {
envProxyOnce.Do(func() {
envProxyFuncValue = httpproxy.FromEnvironment().ProxyFunc()
})
return envProxyFuncValue
}
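// The variables consulted are the same ones honored by
// ProxyFromEnvironment; illustrative values:
//
//	HTTP_PROXY=http://proxy.example.com:3128
//	HTTPS_PROXY=http://proxy.example.com:3128
//	NO_PROXY=localhost,.internal.example.com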
// resetProxyConfig is used by tests.
func resetProxyConfig() {
envProxyOnce = sync.Once{}
envProxyFuncValue = nil
}
func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectMethod, err error) {
cm.targetScheme = treq.URL.Scheme
cm.targetAddr = canonicalAddr(treq.URL)
if t.Proxy != nil {
cm.proxyURL, err = t.Proxy(treq.Request)
}
cm.onlyH1 = treq.requiresHTTP1()
return cm, err
}
// proxyAuth returns the Proxy-Authorization header to set
// on requests, if applicable.
func (cm *connectMethod) proxyAuth() string {
if cm.proxyURL == nil {
return ""
}
if u := cm.proxyURL.User; u != nil {
username := u.Username()
password, _ := u.Password()
return "Basic " + basicAuth(username, password)
}
return ""
}
// error values for debugging and testing, not seen by users.
var (
errKeepAlivesDisabled = errors.New("http: putIdleConn: keep alives disabled")
errConnBroken = errors.New("http: putIdleConn: connection is in bad state")
errCloseIdle = errors.New("http: putIdleConn: CloseIdleConnections was called")
errTooManyIdle = errors.New("http: putIdleConn: too many idle connections")
errTooManyIdleHost = errors.New("http: putIdleConn: too many idle connections for host")
errCloseIdleConns = errors.New("http: CloseIdleConnections called")
errReadLoopExiting = errors.New("http: persistConn.readLoop exiting")
errIdleConnTimeout = errors.New("http: idle connection timeout")
// errServerClosedIdle is not seen by users for idempotent requests, but may be
// seen by a user if the server shuts down an idle connection and sends its FIN
// in flight with already-written POST body bytes from the client.
// See https://github.com/golang/go/issues/19943#issuecomment-355607646
errServerClosedIdle = errors.New("http: server closed idle connection")
)
// transportReadFromServerError is used by Transport.readLoop when the
// 1 byte peek read fails and we're actually anticipating a response.
// Usually this is just due to the inherent keep-alive shut down race,
// where the server closed the connection at the same time the client
// wrote. The underlying err field is usually io.EOF or some
// ECONNRESET sort of thing which varies by platform. But it might be
// the user's custom net.Conn.Read error too, so we carry it along for
// them to return from Transport.RoundTrip.
type transportReadFromServerError struct {
err error
}
func (e transportReadFromServerError) Unwrap() error { return e.err }
func (e transportReadFromServerError) Error() string {
return fmt.Sprintf("net/http: Transport failed to read from server: %v", e.err)
}
func (t *Transport) putOrCloseIdleConn(pconn *persistConn) {
if err := t.tryPutIdleConn(pconn); err != nil {
pconn.close(err)
}
}
func (t *Transport) maxIdleConnsPerHost() int {
if v := t.MaxIdleConnsPerHost; v != 0 {
return v
}
return DefaultMaxIdleConnsPerHost
}
// tryPutIdleConn adds pconn to the list of idle persistent connections awaiting
// a new request.
// If pconn is no longer needed or not in a good state, tryPutIdleConn returns
// an error explaining why it wasn't registered.
// tryPutIdleConn does not close pconn. Use putOrCloseIdleConn instead for that.
func (t *Transport) tryPutIdleConn(pconn *persistConn) error {
if t.DisableKeepAlives || t.MaxIdleConnsPerHost < 0 {
return errKeepAlivesDisabled
}
if pconn.isBroken() {
return errConnBroken
}
pconn.markReused()
t.idleMu.Lock()
defer t.idleMu.Unlock()
// HTTP/2 (pconn.alt != nil) connections do not come out of the idle list,
// because multiple goroutines can use them simultaneously.
// If this is an HTTP/2 connection being “returned,” we're done.
if pconn.alt != nil && t.idleLRU.m[pconn] != nil {
return nil
}
// Deliver pconn to goroutine waiting for idle connection, if any.
// (They may be actively dialing, but this conn is ready first.
// Chrome calls this socket late binding.
// See https://www.chromium.org/developers/design-documents/network-stack#TOC-Connection-Management.)
key := pconn.cacheKey
if q, ok := t.idleConnWait[key]; ok {
done := false
if pconn.alt == nil {
// HTTP/1.
// Loop over the waiting list until we find a w that isn't done already, and hand it pconn.
for q.len() > 0 {
w := q.popFront()
if w.tryDeliver(pconn, nil) {
done = true
break
}
}
} else {
// HTTP/2.
// Can hand the same pconn to everyone in the waiting list,
// and we still won't be done: we want to put it in the idle
// list unconditionally, for any future clients too.
for q.len() > 0 {
w := q.popFront()
w.tryDeliver(pconn, nil)
}
}
if q.len() == 0 {
delete(t.idleConnWait, key)
} else {
t.idleConnWait[key] = q
}
if done {
return nil
}
}
if t.closeIdle {
return errCloseIdle
}
if t.idleConn == nil {
t.idleConn = make(map[connectMethodKey][]*persistConn)
}
idles := t.idleConn[key]
if len(idles) >= t.maxIdleConnsPerHost() {
return errTooManyIdleHost
}
for _, exist := range idles {
if exist == pconn {
log.Fatalf("dup idle pconn %p in freelist", pconn)
}
}
t.idleConn[key] = append(idles, pconn)
t.idleLRU.add(pconn)
if t.MaxIdleConns != 0 && t.idleLRU.len() > t.MaxIdleConns {
oldest := t.idleLRU.removeOldest()
oldest.close(errTooManyIdle)
t.removeIdleConnLocked(oldest)
}
// Set idle timer, but only for HTTP/1 (pconn.alt == nil).
// The HTTP/2 implementation manages the idle timer itself
// (see idleConnTimeout in h2_bundle.go).
if t.IdleConnTimeout > 0 && pconn.alt == nil {
if pconn.idleTimer != nil {
pconn.idleTimer.Reset(t.IdleConnTimeout)
} else {
pconn.idleTimer = time.AfterFunc(t.IdleConnTimeout, pconn.closeConnIfStillIdle)
}
}
pconn.idleAt = time.Now()
return nil
}
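// The idle pool that tryPutIdleConn feeds is bounded by three Transport
// knobs; a minimal tuning sketch (values hypothetical):
//
//	t := &Transport{
//		MaxIdleConns:        100,              // across all hosts; 0 means no limit
//		MaxIdleConnsPerHost: 10,               // per host (per connect method key)
//		IdleConnTimeout:     90 * time.Second, // close conns idle this long
//	}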
// queueForIdleConn queues w to receive the next idle connection for w.cm.
// As an optimization hint to the caller, queueForIdleConn reports whether
// it successfully delivered an already-idle connection.
func (t *Transport) queueForIdleConn(w *wantConn) (delivered bool) {
if t.DisableKeepAlives {
return false
}
t.idleMu.Lock()
defer t.idleMu.Unlock()
// Stop closing connections that become idle - we might want one.
// (That is, undo the effect of t.CloseIdleConnections.)
t.closeIdle = false
if w == nil {
// Happens in test hook.
return false
}
// If IdleConnTimeout is set, calculate the oldest
// persistConn.idleAt time we're willing to use a cached idle
// conn.
var oldTime time.Time
if t.IdleConnTimeout > 0 {
oldTime = time.Now().Add(-t.IdleConnTimeout)
}
// Look for most recently-used idle connection.
if list, ok := t.idleConn[w.key]; ok {
stop := false
delivered := false
for len(list) > 0 && !stop {
pconn := list[len(list)-1]
// See whether this connection has been idle too long, considering
// only the wall time (the Round(0)), in case this is a laptop or VM
// coming out of suspend with previously cached idle connections.
tooOld := !oldTime.IsZero() && pconn.idleAt.Round(0).Before(oldTime)
if tooOld {
// Async cleanup. Launch in its own goroutine (as if a
// time.AfterFunc called it); it acquires idleMu, which we're
// holding, and does a synchronous net.Conn.Close.
go pconn.closeConnIfStillIdle()
}
if pconn.isBroken() || tooOld {
// If either persistConn.readLoop has marked the connection
// broken, but Transport.removeIdleConn has not yet removed it
// from the idle list, or if this persistConn is too old (it was
// idle too long), then ignore it and look for another. In both
// cases it's already in the process of being closed.
list = list[:len(list)-1]
continue
}
delivered = w.tryDeliver(pconn, nil)
if delivered {
if pconn.alt != nil {
// HTTP/2: multiple clients can share pconn.
// Leave it in the list.
} else {
// HTTP/1: only one client can use pconn.
// Remove it from the list.
t.idleLRU.remove(pconn)
list = list[:len(list)-1]
}
}
stop = true
}
if len(list) > 0 {
t.idleConn[w.key] = list
} else {
delete(t.idleConn, w.key)
}
if stop {
return delivered
}
}
// Register to receive next connection that becomes idle.
if t.idleConnWait == nil {
t.idleConnWait = make(map[connectMethodKey]wantConnQueue)
}
q := t.idleConnWait[w.key]
q.cleanFront()
q.pushBack(w)
t.idleConnWait[w.key] = q
return false
}
// removeIdleConn marks pconn as dead.
func (t *Transport) removeIdleConn(pconn *persistConn) bool {
t.idleMu.Lock()
defer t.idleMu.Unlock()
return t.removeIdleConnLocked(pconn)
}
// t.idleMu must be held.
func (t *Transport) removeIdleConnLocked(pconn *persistConn) bool {
if pconn.idleTimer != nil {
pconn.idleTimer.Stop()
}
t.idleLRU.remove(pconn)
key := pconn.cacheKey
pconns := t.idleConn[key]
var removed bool
switch len(pconns) {
case 0:
// Nothing
case 1:
if pconns[0] == pconn {
delete(t.idleConn, key)
removed = true
}
default:
for i, v := range pconns {
if v != pconn {
continue
}
// Slide down, keeping most recently-used
// conns at the end.
copy(pconns[i:], pconns[i+1:])
t.idleConn[key] = pconns[:len(pconns)-1]
removed = true
break
}
}
return removed
}
func (t *Transport) setReqCanceler(key cancelKey, fn func(error)) {
t.reqMu.Lock()
defer t.reqMu.Unlock()
if t.reqCanceler == nil {
t.reqCanceler = make(map[cancelKey]func(error))
}
if fn != nil {
t.reqCanceler[key] = fn
} else {
delete(t.reqCanceler, key)
}
}
// replaceReqCanceler replaces an existing cancel function. If there is no cancel function
// for the request, we don't set the function and return false.
// Since CancelRequest will clear the canceler, we can use the return value to detect if
// the request was canceled since the last setReqCancel call.
func (t *Transport) replaceReqCanceler(key cancelKey, fn func(error)) bool {
t.reqMu.Lock()
defer t.reqMu.Unlock()
_, ok := t.reqCanceler[key]
if !ok {
return false
}
if fn != nil {
t.reqCanceler[key] = fn
} else {
delete(t.reqCanceler, key)
}
return true
}
var zeroDialer net.Dialer
func (t *Transport) dial(ctx context.Context, network, addr string) (net.Conn, error) {
if t.DialContext != nil {
return t.DialContext(ctx, network, addr)
}
if t.Dial != nil {
c, err := t.Dial(network, addr)
if c == nil && err == nil {
err = errors.New("net/http: Transport.Dial hook returned (nil, nil)")
}
return c, err
}
return zeroDialer.DialContext(ctx, network, addr)
}
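// A minimal sketch of overriding the dial step (values hypothetical);
// DialContext takes precedence over the legacy Dial hook above:
//
//	t := &Transport{
//		DialContext: (&net.Dialer{
//			Timeout:   30 * time.Second,
//			KeepAlive: 30 * time.Second,
//		}).DialContext,
//	}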
// A wantConn records state about a wanted connection
// (that is, an active call to getConn).
// The conn may be gotten by dialing or by finding an idle connection,
// or a cancellation may make the conn no longer wanted.
// These three options are racing against each other and use
// wantConn to coordinate and agree about the winning outcome.
type wantConn struct {
cm connectMethod
key connectMethodKey // cm.key()
ctx context.Context // context for dial
ready chan struct{} // closed when pc, err pair is delivered
// hooks for testing to know when dials are done
// beforeDial is called in the getConn goroutine when the dial is queued.
// afterDial is called when the dial is completed or canceled.
beforeDial func()
afterDial func()
mu sync.Mutex // protects pc, err, close(ready)
pc *persistConn
err error
}
// waiting reports whether w is still waiting for an answer (connection or error).
func (w *wantConn) waiting() bool {
select {
case <-w.ready:
return false
default:
return true
}
}
// tryDeliver attempts to deliver pc, err to w and reports whether it succeeded.
func (w *wantConn) tryDeliver(pc *persistConn, err error) bool {
w.mu.Lock()
defer w.mu.Unlock()
if w.pc != nil || w.err != nil {
return false
}
w.pc = pc
w.err = err
if w.pc == nil && w.err == nil {
panic("net/http: internal error: misuse of tryDeliver")
}
close(w.ready)
return true
}
// cancel marks w as no longer wanting a result (for example, due to cancellation).
// If a connection has been delivered already, cancel returns it with t.putOrCloseIdleConn.
func (w *wantConn) cancel(t *Transport, err error) {
w.mu.Lock()
if w.pc == nil && w.err == nil {
close(w.ready) // catch misbehavior in future delivery
}
pc := w.pc
w.pc = nil
w.err = err
w.mu.Unlock()
if pc != nil {
t.putOrCloseIdleConn(pc)
}
}
// A wantConnQueue is a queue of wantConns.
type wantConnQueue struct {
// This is a queue, not a deque.
// It is split into two stages - head[headPos:] and tail.
// popFront is trivial (headPos++) on the first stage, and
// pushBack is trivial (append) on the second stage.
// If the first stage is empty, popFront can swap the
// first and second stages to remedy the situation.
//
// This two-stage split is analogous to the use of two lists
// in Okasaki's purely functional queue but without the
// overhead of reversing the list when swapping stages.
head []*wantConn
headPos int
tail []*wantConn
}
// len returns the number of items in the queue.
func (q *wantConnQueue) len() int {
return len(q.head) - q.headPos + len(q.tail)
}
// pushBack adds w to the back of the queue.
func (q *wantConnQueue) pushBack(w *wantConn) {
q.tail = append(q.tail, w)
}
// popFront removes and returns the wantConn at the front of the queue.
func (q *wantConnQueue) popFront() *wantConn {
if q.headPos >= len(q.head) {
if len(q.tail) == 0 {
return nil
}
// Pick up tail as new head, clear tail.
q.head, q.headPos, q.tail = q.tail, 0, q.head[:0]
}
w := q.head[q.headPos]
q.head[q.headPos] = nil
q.headPos++
return w
}
// peekFront returns the wantConn at the front of the queue without removing it.
func (q *wantConnQueue) peekFront() *wantConn {
if q.headPos < len(q.head) {
return q.head[q.headPos]
}
if len(q.tail) > 0 {
return q.tail[0]
}
return nil
}
// cleanFront pops any wantConns that are no longer waiting from the head of the
// queue, reporting whether any were popped.
func (q *wantConnQueue) cleanFront() (cleaned bool) {
for {
w := q.peekFront()
if w == nil || w.waiting() {
return cleaned
}
q.popFront()
cleaned = true
}
}
func (t *Transport) customDialTLS(ctx context.Context, network, addr string) (conn net.Conn, err error) {
if t.DialTLSContext != nil {
conn, err = t.DialTLSContext(ctx, network, addr)
} else {
conn, err = t.DialTLS(network, addr)
}
if conn == nil && err == nil {
err = errors.New("net/http: Transport.DialTLS or DialTLSContext returned (nil, nil)")
}
return
}
// getConn dials and creates a new persistConn to the target as
// specified in the connectMethod. This includes doing a proxy CONNECT
// and/or setting up TLS. If this doesn't return an error, the persistConn
// is ready to write requests to.
func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persistConn, err error) {
req := treq.Request
trace := treq.trace
ctx := req.Context()
if trace != nil && trace.GetConn != nil {
trace.GetConn(cm.addr())
}
w := &wantConn{
cm: cm,
key: cm.key(),
ctx: ctx,
ready: make(chan struct{}, 1),
beforeDial: testHookPrePendingDial,
afterDial: testHookPostPendingDial,
}
defer func() {
if err != nil {
w.cancel(t, err)
}
}()
// Queue for idle connection.
if delivered := t.queueForIdleConn(w); delivered {
pc := w.pc
// Trace only for HTTP/1.
// HTTP/2 calls trace.GotConn itself.
if pc.alt == nil && trace != nil && trace.GotConn != nil {
trace.GotConn(pc.gotIdleConnTrace(pc.idleAt))
}
// set request canceler to some non-nil function so we
// can detect whether it was cleared between now and when
// we enter roundTrip
t.setReqCanceler(treq.cancelKey, func(error) {})
return pc, nil
}
cancelc := make(chan error, 1)
t.setReqCanceler(treq.cancelKey, func(err error) { cancelc <- err })
// Queue for permission to dial.
t.queueForDial(w)
// Wait for completion or cancellation.
select {
case <-w.ready:
// Trace success but only for HTTP/1.
// HTTP/2 calls trace.GotConn itself.
if w.pc != nil && w.pc.alt == nil && trace != nil && trace.GotConn != nil {
trace.GotConn(httptrace.GotConnInfo{Conn: w.pc.conn, Reused: w.pc.isReused()})
}
if w.err != nil {
// If the request has been canceled, that's probably
// what caused w.err; if so, prefer to return the
// cancellation error (see golang.org/issue/16049).
select {
case <-req.Cancel:
return nil, errRequestCanceledConn
case <-req.Context().Done():
return nil, req.Context().Err()
case err := <-cancelc:
if err == errRequestCanceled {
err = errRequestCanceledConn
}
return nil, err
default:
// return below
}
}
return w.pc, w.err
case <-req.Cancel:
return nil, errRequestCanceledConn
case <-req.Context().Done():
return nil, req.Context().Err()
case err := <-cancelc:
if err == errRequestCanceled {
err = errRequestCanceledConn
}
return nil, err
}
}
// queueForDial queues w to wait for permission to begin dialing.
// Once w receives permission to dial, it will do so in a separate goroutine.
func (t *Transport) queueForDial(w *wantConn) {
w.beforeDial()
if t.MaxConnsPerHost <= 0 {
go t.dialConnFor(w)
return
}
t.connsPerHostMu.Lock()
defer t.connsPerHostMu.Unlock()
if n := t.connsPerHost[w.key]; n < t.MaxConnsPerHost {
if t.connsPerHost == nil {
t.connsPerHost = make(map[connectMethodKey]int)
}
t.connsPerHost[w.key] = n + 1
go t.dialConnFor(w)
return
}
if t.connsPerHostWait == nil {
t.connsPerHostWait = make(map[connectMethodKey]wantConnQueue)
}
q := t.connsPerHostWait[w.key]
q.cleanFront()
q.pushBack(w)
t.connsPerHostWait[w.key] = q
}
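// queueForDial only gates dials when MaxConnsPerHost is positive; a sketch
// of capping the total (dialing, in-use, and idle) connections per host,
// with a hypothetical cap:
//
//	t := &Transport{MaxConnsPerHost: 2}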
// dialConnFor dials on behalf of w and delivers the result to w.
// dialConnFor has received permission to dial w.cm and is counted in t.connCount[w.cm.key()].
// If the dial is canceled or unsuccessful, dialConnFor decrements t.connCount[w.cm.key()].
func (t *Transport) dialConnFor(w *wantConn) {
defer w.afterDial()
pc, err := t.dialConn(w.ctx, w.cm)
delivered := w.tryDeliver(pc, err)
if err == nil && (!delivered || pc.alt != nil) {
// pconn was not passed to w,
// or it is HTTP/2 and can be shared.
// Add to the idle connection pool.
t.putOrCloseIdleConn(pc)
}
if err != nil {
t.decConnsPerHost(w.key)
}
}
// decConnsPerHost decrements the per-host connection count for key,
// which may in turn give a different waiting goroutine permission to dial.
func (t *Transport) decConnsPerHost(key connectMethodKey) {
if t.MaxConnsPerHost <= 0 {
return
}
t.connsPerHostMu.Lock()
defer t.connsPerHostMu.Unlock()
n := t.connsPerHost[key]
if n == 0 {
// Shouldn't happen, but if it does, the counting is buggy and could
// easily lead to a silent deadlock, so report the problem loudly.
panic("net/http: internal error: connCount underflow")
}
// Can we hand this count to a goroutine still waiting to dial?
// (Some goroutines on the wait list may have timed out or
// gotten a connection another way. If they're all gone,
// we don't want to kick off any spurious dial operations.)
if q := t.connsPerHostWait[key]; q.len() > 0 {
done := false
for q.len() > 0 {
w := q.popFront()
if w.waiting() {
go t.dialConnFor(w)
done = true
break
}
}
if q.len() == 0 {
delete(t.connsPerHostWait, key)
} else {
// q is a value (like a slice), so we have to store
// the updated q back into the map.
t.connsPerHostWait[key] = q
}
if done {
return
}
}
// Otherwise, decrement the recorded count.
if n--; n == 0 {
delete(t.connsPerHost, key)
} else {
t.connsPerHost[key] = n
}
}
// Add TLS to a persistent connection, i.e. negotiate a TLS session. If pconn is already a TLS
// tunnel, this function establishes a nested TLS session inside the encrypted channel.
// The remote endpoint's name may be overridden by TLSClientConfig.ServerName.
func (pconn *persistConn) addTLS(ctx context.Context, name string, trace *httptrace.ClientTrace) error {
// Initiate TLS and check remote host name against certificate.
cfg := cloneTLSConfig(pconn.t.TLSClientConfig)
if cfg.ServerName == "" {
cfg.ServerName = name
}
if pconn.cacheKey.onlyH1 {
cfg.NextProtos = nil
}
plainConn := pconn.conn
tlsConn := tls.Client(plainConn, cfg)
errc := make(chan error, 2)
var timer *time.Timer // for canceling TLS handshake
if d := pconn.t.TLSHandshakeTimeout; d != 0 {
timer = time.AfterFunc(d, func() {
errc <- tlsHandshakeTimeoutError{}
})
}
go func() {
if trace != nil && trace.TLSHandshakeStart != nil {
trace.TLSHandshakeStart()
}
err := tlsConn.HandshakeContext(ctx)
if timer != nil {
timer.Stop()
}
errc <- err
}()
if err := <-errc; err != nil {
plainConn.Close()
if trace != nil && trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(tls.ConnectionState{}, err)
}
return err
}
cs := tlsConn.ConnectionState()
if trace != nil && trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(cs, nil)
}
pconn.tlsState = &cs
pconn.conn = tlsConn
return nil
}
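// A sketch of the Transport fields that feed addTLS (values hypothetical):
// TLSHandshakeTimeout bounds the handshake above, and TLSClientConfig can
// override the ServerName used for certificate verification:
//
//	t := &Transport{
//		TLSHandshakeTimeout: 10 * time.Second,
//		TLSClientConfig:     &tls.Config{ServerName: "front.example.com"},
//	}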
type erringRoundTripper interface {
RoundTripErr() error
}
func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *persistConn, err error) {
pconn = &persistConn{
t: t,
cacheKey: cm.key(),
reqch: make(chan requestAndChan, 1),
writech: make(chan writeRequest, 1),
closech: make(chan struct{}),
writeErrCh: make(chan error, 1),
writeLoopDone: make(chan struct{}),
}
trace := httptrace.ContextClientTrace(ctx)
wrapErr := func(err error) error {
if cm.proxyURL != nil {
// Return a typed error, per Issue 16997
return &net.OpError{Op: "proxyconnect", Net: "tcp", Err: err}
}
return err
}
if cm.scheme() == "https" && t.hasCustomTLSDialer() {
var err error
pconn.conn, err = t.customDialTLS(ctx, "tcp", cm.addr())
if err != nil {
return nil, wrapErr(err)
}
if tc, ok := pconn.conn.(*tls.Conn); ok {
// Handshake here, in case DialTLS didn't. TLSNextProto below
// depends on it for knowing the connection state.
if trace != nil && trace.TLSHandshakeStart != nil {
trace.TLSHandshakeStart()
}
if err := tc.HandshakeContext(ctx); err != nil {
go pconn.conn.Close()
if trace != nil && trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(tls.ConnectionState{}, err)
}
return nil, err
}
cs := tc.ConnectionState()
if trace != nil && trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(cs, nil)
}
pconn.tlsState = &cs
}
} else {
conn, err := t.dial(ctx, "tcp", cm.addr())
if err != nil {
return nil, wrapErr(err)
}
pconn.conn = conn
if cm.scheme() == "https" {
var firstTLSHost string
if firstTLSHost, _, err = net.SplitHostPort(cm.addr()); err != nil {
return nil, wrapErr(err)
}
if err = pconn.addTLS(ctx, firstTLSHost, trace); err != nil {
return nil, wrapErr(err)
}
}
}
// Proxy setup.
switch {
case cm.proxyURL == nil:
// Do nothing. Not using a proxy.
case cm.proxyURL.Scheme == "socks5":
conn := pconn.conn
d := socksNewDialer("tcp", conn.RemoteAddr().String())
if u := cm.proxyURL.User; u != nil {
auth := &socksUsernamePassword{
Username: u.Username(),
}
auth.Password, _ = u.Password()
d.AuthMethods = []socksAuthMethod{
socksAuthMethodNotRequired,
socksAuthMethodUsernamePassword,
}
d.Authenticate = auth.Authenticate
}
if _, err := d.DialWithConn(ctx, conn, "tcp", cm.targetAddr); err != nil {
conn.Close()
return nil, err
}
case cm.targetScheme == "http":
pconn.isProxy = true
if pa := cm.proxyAuth(); pa != "" {
pconn.mutateHeaderFunc = func(h Header) {
h.Set("Proxy-Authorization", pa)
}
}
case cm.targetScheme == "https":
conn := pconn.conn
var hdr Header
if t.GetProxyConnectHeader != nil {
var err error
hdr, err = t.GetProxyConnectHeader(ctx, cm.proxyURL, cm.targetAddr)
if err != nil {
conn.Close()
return nil, err
}
} else {
hdr = t.ProxyConnectHeader
}
if hdr == nil {
hdr = make(Header)
}
if pa := cm.proxyAuth(); pa != "" {
hdr = hdr.Clone()
hdr.Set("Proxy-Authorization", pa)
}
connectReq := &Request{
Method: "CONNECT",
URL: &url.URL{Opaque: cm.targetAddr},
Host: cm.targetAddr,
Header: hdr,
}
// If there's no done channel (no deadline or cancellation
// from the caller possible), at least set some (long)
// timeout here. This will make sure we don't block forever
// and leak a goroutine if the connection stops replying
// after the TCP connect.
connectCtx := ctx
if ctx.Done() == nil {
newCtx, cancel := context.WithTimeout(ctx, 1*time.Minute)
defer cancel()
connectCtx = newCtx
}
didReadResponse := make(chan struct{}) // closed after CONNECT write+read is done or fails
var (
resp *Response
err error // write or read error
)
// Write the CONNECT request & read the response.
go func() {
defer close(didReadResponse)
err = connectReq.Write(conn)
if err != nil {
return
}
// Okay to use and discard buffered reader here, because
// TLS server will not speak until spoken to.
br := bufio.NewReader(conn)
resp, err = ReadResponse(br, connectReq)
}()
select {
case <-connectCtx.Done():
conn.Close()
<-didReadResponse
return nil, connectCtx.Err()
case <-didReadResponse:
// resp or err now set
}
if err != nil {
conn.Close()
return nil, err
}
if t.OnProxyConnectResponse != nil {
err = t.OnProxyConnectResponse(ctx, cm.proxyURL, connectReq, resp)
if err != nil {
return nil, err
}
}
if resp.StatusCode != 200 {
_, text, ok := strings.Cut(resp.Status, " ")
conn.Close()
if !ok {
return nil, errors.New("unknown status code")
}
return nil, errors.New(text)
}
}
if cm.proxyURL != nil && cm.targetScheme == "https" {
if err := pconn.addTLS(ctx, cm.tlsHost(), trace); err != nil {
return nil, err
}
}
if s := pconn.tlsState; s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" {
if next, ok := t.TLSNextProto[s.NegotiatedProtocol]; ok {
alt := next(cm.targetAddr, pconn.conn.(*tls.Conn))
if e, ok := alt.(erringRoundTripper); ok {
// pconn.conn was closed by next (http2configureTransports.upgradeFn).
return nil, e.RoundTripErr()
}
return &persistConn{t: t, cacheKey: pconn.cacheKey, alt: alt}, nil
}
}
pconn.br = bufio.NewReaderSize(pconn, t.readBufferSize())
pconn.bw = bufio.NewWriterSize(persistConnWriter{pconn}, t.writeBufferSize())
go pconn.readLoop()
go pconn.writeLoop()
return pconn, nil
}
// persistConnWriter is the io.Writer written to by pc.bw.
// It accumulates the number of bytes written to the underlying conn,
// so the retry logic can determine whether any bytes made it across
// the wire.
// This is exactly 1 pointer field wide so it can go into an interface
// without allocation.
type persistConnWriter struct {
pc *persistConn
}
func (w persistConnWriter) Write(p []byte) (n int, err error) {
n, err = w.pc.conn.Write(p)
w.pc.nwrite += int64(n)
return
}
// ReadFrom exposes persistConnWriter's underlying Conn to io.Copy so that,
// if the Conn implements io.ReaderFrom, io.Copy can take advantage of
// optimizations such as sendfile.
func (w persistConnWriter) ReadFrom(r io.Reader) (n int64, err error) {
n, err = io.Copy(w.pc.conn, r)
w.pc.nwrite += n
return
}
var _ io.ReaderFrom = (*persistConnWriter)(nil)
// connectMethod is the map key (in its String form) for keeping persistent
// TCP connections alive for subsequent HTTP requests.
//
// A connect method may be of the following types:
//
// connectMethod.key().String() Description
// ------------------------------ -------------------------
// |http|foo.com http directly to server, no proxy
// |https|foo.com https directly to server, no proxy
// |https,h1|foo.com https directly to server w/o HTTP/2, no proxy
// http://proxy.com|https|foo.com http to proxy, then CONNECT to foo.com
// http://proxy.com|http http to proxy, http to anywhere after that
// socks5://proxy.com|http|foo.com socks5 to proxy, then http to foo.com
// socks5://proxy.com|https|foo.com socks5 to proxy, then https to foo.com
// https://proxy.com|https|foo.com https to proxy, then CONNECT to foo.com
// https://proxy.com|http https to proxy, http to anywhere after that
type connectMethod struct {
_ incomparable
proxyURL *url.URL // nil for no proxy, else full proxy URL
targetScheme string // "http" or "https"
// If proxyURL specifies an http or https proxy, and targetScheme is http (not https),
// then targetAddr is not included in the connect method key, because the socket can
// be reused for different targetAddr values.
targetAddr string
onlyH1 bool // whether to disable HTTP/2 and force HTTP/1
}
func (cm *connectMethod) key() connectMethodKey {
proxyStr := ""
targetAddr := cm.targetAddr
if cm.proxyURL != nil {
proxyStr = cm.proxyURL.String()
if (cm.proxyURL.Scheme == "http" || cm.proxyURL.Scheme == "https") && cm.targetScheme == "http" {
targetAddr = ""
}
}
return connectMethodKey{
proxy: proxyStr,
scheme: cm.targetScheme,
addr: targetAddr,
onlyH1: cm.onlyH1,
}
}
// scheme returns the first hop scheme: http, https, or socks5
func (cm *connectMethod) scheme() string {
if cm.proxyURL != nil {
return cm.proxyURL.Scheme
}
return cm.targetScheme
}
// addr returns the first hop "host:port" to which we need to TCP connect.
func (cm *connectMethod) addr() string {
if cm.proxyURL != nil {
return canonicalAddr(cm.proxyURL)
}
return cm.targetAddr
}
// tlsHost returns the host name to match against the peer's
// TLS certificate.
func (cm *connectMethod) tlsHost() string {
h := cm.targetAddr
if hasPort(h) {
h = h[:strings.LastIndex(h, ":")]
}
return h
}
// connectMethodKey is the map key version of connectMethod, with a
// stringified proxy URL (or the empty string) instead of a pointer to
// a URL.
type connectMethodKey struct {
proxy, scheme, addr string
onlyH1 bool
}
func (k connectMethodKey) String() string {
// Only used by tests.
var h1 string
if k.onlyH1 {
h1 = ",h1"
}
return fmt.Sprintf("%s|%s%s|%s", k.proxy, k.scheme, h1, k.addr)
}
// persistConn wraps a connection, usually a persistent one
// (but may be used for non-keep-alive requests as well)
type persistConn struct {
// alt optionally specifies the TLS NextProto RoundTripper.
// This is used for HTTP/2 today and future protocols later.
// If it's non-nil, the rest of the fields are unused.
alt RoundTripper
t *Transport
cacheKey connectMethodKey
conn net.Conn
tlsState *tls.ConnectionState
br *bufio.Reader // from conn
bw *bufio.Writer // to conn
nwrite int64 // bytes written
reqch chan requestAndChan // written by roundTrip; read by readLoop
writech chan writeRequest // written by roundTrip; read by writeLoop
closech chan struct{} // closed when conn closed
isProxy bool
sawEOF bool // whether we've seen EOF from conn; owned by readLoop
readLimit int64 // bytes allowed to be read; owned by readLoop
// writeErrCh passes the request write error (usually nil)
// from the writeLoop goroutine to the readLoop which passes
// it off to the res.Body reader, which then uses it to decide
// whether or not a connection can be reused. Issue 7569.
writeErrCh chan error
writeLoopDone chan struct{} // closed when write loop ends
// Both guarded by Transport.idleMu:
idleAt time.Time // time it last became idle
idleTimer *time.Timer // holding an AfterFunc to close it
mu sync.Mutex // guards following fields
numExpectedResponses int
closed error // set non-nil when conn is closed, before closech is closed
canceledErr error // set non-nil if conn is canceled
broken bool // an error has happened on this connection; marked broken so it's not reused.
reused bool // whether conn has had successful request/response and is being reused.
// mutateHeaderFunc is an optional func to modify extra
// headers on each outbound request before it's written. (the
// original Request given to RoundTrip is not modified)
mutateHeaderFunc func(Header)
}
func (pc *persistConn) maxHeaderResponseSize() int64 {
if v := pc.t.MaxResponseHeaderBytes; v != 0 {
return v
}
return 10 << 20 // conservative default; same as http2
}
func (pc *persistConn) Read(p []byte) (n int, err error) {
if pc.readLimit <= 0 {
return 0, fmt.Errorf("read limit of %d bytes exhausted", pc.maxHeaderResponseSize())
}
if int64(len(p)) > pc.readLimit {
p = p[:pc.readLimit]
}
n, err = pc.conn.Read(p)
if err == io.EOF {
pc.sawEOF = true
}
pc.readLimit -= int64(n)
return
}
// isBroken reports whether this connection is in a known broken state.
func (pc *persistConn) isBroken() bool {
pc.mu.Lock()
b := pc.closed != nil
pc.mu.Unlock()
return b
}
// canceled returns non-nil if the connection was closed due to
// CancelRequest or due to context cancellation.
func (pc *persistConn) canceled() error {
pc.mu.Lock()
defer pc.mu.Unlock()
return pc.canceledErr
}
// isReused reports whether this connection has been used before.
func (pc *persistConn) isReused() bool {
pc.mu.Lock()
r := pc.reused
pc.mu.Unlock()
return r
}
func (pc *persistConn) gotIdleConnTrace(idleAt time.Time) (t httptrace.GotConnInfo) {
pc.mu.Lock()
defer pc.mu.Unlock()
t.Reused = pc.reused
t.Conn = pc.conn
t.WasIdle = true
if !idleAt.IsZero() {
t.IdleTime = time.Since(idleAt)
}
return
}
func (pc *persistConn) cancelRequest(err error) {
pc.mu.Lock()
defer pc.mu.Unlock()
pc.canceledErr = err
pc.closeLocked(errRequestCanceled)
}
// closeConnIfStillIdle closes the connection if it's still sitting idle.
// This is what's called by the persistConn's idleTimer, and is run in its
// own goroutine.
func (pc *persistConn) closeConnIfStillIdle() {
t := pc.t
t.idleMu.Lock()
defer t.idleMu.Unlock()
if _, ok := t.idleLRU.m[pc]; !ok {
// Not idle.
return
}
t.removeIdleConnLocked(pc)
pc.close(errIdleConnTimeout)
}
// mapRoundTripError returns the appropriate error value for
// persistConn.roundTrip.
//
// The provided err is the first error that (*persistConn).roundTrip
// happened to receive from its select statement.
//
// The startBytesWritten value should be the value of pc.nwrite before the roundTrip
// started writing the request.
func (pc *persistConn) mapRoundTripError(req *transportRequest, startBytesWritten int64, err error) error {
if err == nil {
return nil
}
// Wait for the writeLoop goroutine to terminate to avoid data
// races on callers who mutate the request on failure.
//
// When resc in pc.roundTrip and hence rc.ch receives a responseAndError
// with a non-nil error it implies that the persistConn is either closed
// or closing. Waiting on pc.writeLoopDone is hence safe as all callers
// close closech which in turn ensures writeLoop returns.
<-pc.writeLoopDone
// If the request was canceled, that's better than network
// failures that were likely the result of tearing down the
// connection.
if cerr := pc.canceled(); cerr != nil {
return cerr
}
// See if an error was set explicitly.
req.mu.Lock()
reqErr := req.err
req.mu.Unlock()
if reqErr != nil {
return reqErr
}
if err == errServerClosedIdle {
// Don't decorate
return err
}
if _, ok := err.(transportReadFromServerError); ok {
if pc.nwrite == startBytesWritten {
return nothingWrittenError{err}
}
// Don't decorate
return err
}
if pc.isBroken() {
if pc.nwrite == startBytesWritten {
return nothingWrittenError{err}
}
return fmt.Errorf("net/http: HTTP/1.x transport connection broken: %w", err)
}
return err
}
// errCallerOwnsConn is an internal sentinel error used when we hand
// off a writable response.Body to the caller. We use this to prevent
// closing a net.Conn that is now owned by the caller.
var errCallerOwnsConn = errors.New("read loop ending; caller owns writable underlying conn")
func (pc *persistConn) readLoop() {
closeErr := errReadLoopExiting // default value, if not changed below
defer func() {
pc.close(closeErr)
pc.t.removeIdleConn(pc)
}()
tryPutIdleConn := func(trace *httptrace.ClientTrace) bool {
if err := pc.t.tryPutIdleConn(pc); err != nil {
closeErr = err
if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled {
trace.PutIdleConn(err)
}
return false
}
if trace != nil && trace.PutIdleConn != nil {
trace.PutIdleConn(nil)
}
return true
}
// eofc is used to block caller goroutines reading from Response.Body
// at EOF until this goroutine has (potentially) added the connection
// back to the idle pool.
eofc := make(chan struct{})
defer close(eofc) // unblock reader on errors
// Read this once, before loop starts. (to avoid races in tests)
testHookMu.Lock()
testHookReadLoopBeforeNextRead := testHookReadLoopBeforeNextRead
testHookMu.Unlock()
alive := true
for alive {
pc.readLimit = pc.maxHeaderResponseSize()
_, err := pc.br.Peek(1)
pc.mu.Lock()
if pc.numExpectedResponses == 0 {
pc.readLoopPeekFailLocked(err)
pc.mu.Unlock()
return
}
pc.mu.Unlock()
rc := <-pc.reqch
trace := httptrace.ContextClientTrace(rc.req.Context())
var resp *Response
if err == nil {
resp, err = pc.readResponse(rc, trace)
} else {
err = transportReadFromServerError{err}
closeErr = err
}
if err != nil {
if pc.readLimit <= 0 {
err = fmt.Errorf("net/http: server response headers exceeded %d bytes; aborted", pc.maxHeaderResponseSize())
}
select {
case rc.ch <- responseAndError{err: err}:
case <-rc.callerGone:
return
}
return
}
pc.readLimit = maxInt64 // effectively no limit for response bodies
pc.mu.Lock()
pc.numExpectedResponses--
pc.mu.Unlock()
bodyWritable := resp.bodyIsWritable()
hasBody := rc.req.Method != "HEAD" && resp.ContentLength != 0
if resp.Close || rc.req.Close || resp.StatusCode <= 199 || bodyWritable {
// Don't do keep-alive on error if either party requested a close
// or we get an unexpected informational (1xx) response.
// StatusCode 100 is already handled above.
alive = false
}
if !hasBody || bodyWritable {
replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil)
// Put the idle conn back into the pool before we send the response
// so if they process it quickly and make another request, they'll
// get this same conn. But we use the unbuffered channel 'rc'
// to guarantee that persistConn.roundTrip got out of its select
// potentially waiting for this persistConn to close.
alive = alive &&
!pc.sawEOF &&
pc.wroteRequest() &&
replaced && tryPutIdleConn(trace)
if bodyWritable {
closeErr = errCallerOwnsConn
}
select {
case rc.ch <- responseAndError{res: resp}:
case <-rc.callerGone:
return
}
// Now that they've read from the unbuffered channel, they're safely
// out of the select that also waits on this goroutine to die, so
// we're allowed to exit now if needed (if alive is false)
testHookReadLoopBeforeNextRead()
continue
}
waitForBodyRead := make(chan bool, 2)
body := &bodyEOFSignal{
body: resp.Body,
earlyCloseFn: func() error {
waitForBodyRead <- false
<-eofc // will be closed by deferred call at the end of the function
return nil
},
fn: func(err error) error {
isEOF := err == io.EOF
waitForBodyRead <- isEOF
if isEOF {
<-eofc // see comment above eofc declaration
} else if err != nil {
if cerr := pc.canceled(); cerr != nil {
return cerr
}
}
return err
},
}
resp.Body = body
if rc.addedGzip && ascii.EqualFold(resp.Header.Get("Content-Encoding"), "gzip") {
resp.Body = &gzipReader{body: body}
resp.Header.Del("Content-Encoding")
resp.Header.Del("Content-Length")
resp.ContentLength = -1
resp.Uncompressed = true
}
select {
case rc.ch <- responseAndError{res: resp}:
case <-rc.callerGone:
return
}
// Before looping back to the top of this function and peeking on
// the bufio.Reader, wait for the caller goroutine to finish
// reading the response body. (or for cancellation or death)
select {
case bodyEOF := <-waitForBodyRead:
replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool
alive = alive &&
bodyEOF &&
!pc.sawEOF &&
pc.wroteRequest() &&
replaced && tryPutIdleConn(trace)
if bodyEOF {
eofc <- struct{}{}
}
case <-rc.req.Cancel:
alive = false
pc.t.CancelRequest(rc.req)
case <-rc.req.Context().Done():
alive = false
pc.t.cancelRequest(rc.cancelKey, rc.req.Context().Err())
case <-pc.closech:
alive = false
}
testHookReadLoopBeforeNextRead()
}
}
func (pc *persistConn) readLoopPeekFailLocked(peekErr error) {
if pc.closed != nil {
return
}
if n := pc.br.Buffered(); n > 0 {
buf, _ := pc.br.Peek(n)
if is408Message(buf) {
pc.closeLocked(errServerClosedIdle)
return
} else {
log.Printf("Unsolicited response received on idle HTTP channel starting with %q; err=%v", buf, peekErr)
}
}
if peekErr == io.EOF {
// common case.
pc.closeLocked(errServerClosedIdle)
} else {
pc.closeLocked(fmt.Errorf("readLoopPeekFailLocked: %w", peekErr))
}
}
// is408Message reports whether buf has the prefix of an
// HTTP 408 Request Timeout response.
// See golang.org/issue/32310.
func is408Message(buf []byte) bool {
if len(buf) < len("HTTP/1.x 408") {
return false
}
if string(buf[:7]) != "HTTP/1." {
return false
}
return string(buf[8:12]) == " 408"
}
// readResponse reads an HTTP response (or two, in the case of "Expect:
// 100-continue") from the server. It returns the final non-100 one.
// trace is optional.
func (pc *persistConn) readResponse(rc requestAndChan, trace *httptrace.ClientTrace) (resp *Response, err error) {
if trace != nil && trace.GotFirstResponseByte != nil {
if peek, err := pc.br.Peek(1); err == nil && len(peek) == 1 {
trace.GotFirstResponseByte()
}
}
num1xx := 0 // number of informational 1xx headers received
const max1xxResponses = 5 // arbitrary bound on number of informational responses
continueCh := rc.continueCh
for {
resp, err = ReadResponse(pc.br, rc.req)
if err != nil {
return
}
resCode := resp.StatusCode
if continueCh != nil {
if resCode == 100 {
if trace != nil && trace.Got100Continue != nil {
trace.Got100Continue()
}
continueCh <- struct{}{}
continueCh = nil
} else if resCode >= 200 {
close(continueCh)
continueCh = nil
}
}
is1xx := 100 <= resCode && resCode <= 199
// treat 101 as a terminal status, see issue 26161
is1xxNonTerminal := is1xx && resCode != StatusSwitchingProtocols
if is1xxNonTerminal {
num1xx++
if num1xx > max1xxResponses {
return nil, errors.New("net/http: too many 1xx informational responses")
}
pc.readLimit = pc.maxHeaderResponseSize() // reset the limit
if trace != nil && trace.Got1xxResponse != nil {
if err := trace.Got1xxResponse(resCode, textproto.MIMEHeader(resp.Header)); err != nil {
return nil, err
}
}
continue
}
break
}
if resp.isProtocolSwitch() {
resp.Body = newReadWriteCloserBody(pc.br, pc.conn)
}
resp.TLS = pc.tlsState
return
}
// waitForContinue returns the function to block until
// any response, timeout or connection close. After any of them,
// the function returns a bool which indicates if the body should be sent.
func (pc *persistConn) waitForContinue(continueCh <-chan struct{}) func() bool {
if continueCh == nil {
return nil
}
return func() bool {
timer := time.NewTimer(pc.t.ExpectContinueTimeout)
defer timer.Stop()
select {
case _, ok := <-continueCh:
return ok
case <-timer.C:
return true
case <-pc.closech:
return false
}
}
}
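// ExpectContinueTimeout is the timer armed above; a sketch of a request
// that exercises this path (URL and body are hypothetical):
//
//	body := strings.NewReader("payload")
//	t := &Transport{ExpectContinueTimeout: 1 * time.Second}
//	req, _ := NewRequest("PUT", "https://example.com/upload", body)
//	req.Header.Set("Expect", "100-continue")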
func newReadWriteCloserBody(br *bufio.Reader, rwc io.ReadWriteCloser) io.ReadWriteCloser {
body := &readWriteCloserBody{ReadWriteCloser: rwc}
if br.Buffered() != 0 {
body.br = br
}
return body
}
// readWriteCloserBody is the Response.Body type used when we want to
// give users write access to the Body through the underlying
// connection (TCP, unless using custom dialers). This is then
// the concrete type for a Response.Body on the 101 Switching
// Protocols response, as used by WebSockets, h2c, etc.
type readWriteCloserBody struct {
_ incomparable
br *bufio.Reader // used until empty
io.ReadWriteCloser
}
func (b *readWriteCloserBody) Read(p []byte) (n int, err error) {
if b.br != nil {
if n := b.br.Buffered(); len(p) > n {
p = p[:n]
}
n, err = b.br.Read(p)
if b.br.Buffered() == 0 {
b.br = nil
}
return n, err
}
return b.ReadWriteCloser.Read(p)
}
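// A usage sketch (the request req is hypothetical): on a successful
// "101 Switching Protocols" response, the Body handed to the caller is
// this type, so it can be asserted to io.ReadWriteCloser to speak the
// upgraded protocol:
//
//	resp, err := DefaultClient.Do(req)
//	if err == nil && resp.StatusCode == StatusSwitchingProtocols {
//		rwc := resp.Body.(io.ReadWriteCloser)
//		// read and write the new protocol on rwc
//	}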
// nothingWrittenError wraps a write error that ended up writing zero bytes.
type nothingWrittenError struct {
error
}
func (nwe nothingWrittenError) Unwrap() error {
return nwe.error
}
func (pc *persistConn) writeLoop() {
defer close(pc.writeLoopDone)
for {
select {
case wr := <-pc.writech:
startBytesWritten := pc.nwrite
err := wr.req.Request.write(pc.bw, pc.isProxy, wr.req.extra, pc.waitForContinue(wr.continueCh))
if bre, ok := err.(requestBodyReadError); ok {
err = bre.error
// Errors reading from the user's
// Request.Body are high priority.
// Set it here before sending on the
// channels below or calling
// pc.close() which tears down
// connections and causes other
// errors.
wr.req.setError(err)
}
if err == nil {
err = pc.bw.Flush()
}
if err != nil {
if pc.nwrite == startBytesWritten {
err = nothingWrittenError{err}
}
}
pc.writeErrCh <- err // to the body reader, which might recycle us
wr.ch <- err // to the roundTrip function
if err != nil {
pc.close(err)
return
}
case <-pc.closech:
return
}
}
}
// maxWriteWaitBeforeConnReuse is how long a Transport RoundTrip
// will wait to see the Request's Body.Write result after getting a
// response from the server. See comments in (*persistConn).wroteRequest.
const maxWriteWaitBeforeConnReuse = 50 * time.Millisecond
// wroteRequest is a check before recycling a connection that the previous write
// (from writeLoop above) happened and was successful.
func (pc *persistConn) wroteRequest() bool {
select {
case err := <-pc.writeErrCh:
// Common case: the write happened well before the response, so
// avoid creating a timer.
return err == nil
default:
// Rare case: the request was written in writeLoop above but
// before it could send to pc.writeErrCh, the reader read it
// all, processed it, and called us here. In this case, give the
// write goroutine a bit of time to finish its send.
//
// Less rare case: We also get here in the legitimate case of
// Issue 7569, where the writer is still writing (or stalled),
// but the server has already replied. In this case, we don't
// want to wait too long, and we want to return false so this
// connection isn't re-used.
t := time.NewTimer(maxWriteWaitBeforeConnReuse)
defer t.Stop()
select {
case err := <-pc.writeErrCh:
return err == nil
case <-t.C:
return false
}
}
}
// responseAndError is how the goroutine reading from an HTTP/1 server
// communicates with the goroutine doing the RoundTrip.
type responseAndError struct {
_ incomparable
res *Response // else use this response (see res method)
err error
}
type requestAndChan struct {
_ incomparable
req *Request
cancelKey cancelKey
ch chan responseAndError // unbuffered; always send in select on callerGone
// whether the Transport (as opposed to the user client code)
// added the Accept-Encoding gzip header. If the Transport
// set it, only then do we transparently decode the gzip.
addedGzip bool
// Optional blocking chan for Expect: 100-continue (for send).
// If the request has an "Expect: 100-continue" header and
// the server responds 100 Continue, readLoop sends a value
// to writeLoop via this chan.
continueCh chan<- struct{}
callerGone <-chan struct{} // closed when roundTrip caller has returned
}
// A writeRequest is sent by the caller's goroutine to the
// writeLoop's goroutine to write a request while the read loop
// concurrently waits on both the write response and the server's
// reply.
type writeRequest struct {
req *transportRequest
ch chan<- error
// Optional blocking chan for Expect: 100-continue (for receive).
// If not nil, writeLoop blocks sending request body until
// it receives from this chan.
continueCh <-chan struct{}
}
type httpError struct {
err string
timeout bool
}
func (e *httpError) Error() string { return e.err }
func (e *httpError) Timeout() bool { return e.timeout }
func (e *httpError) Temporary() bool { return true }
var errTimeout error = &httpError{err: "net/http: timeout awaiting response headers", timeout: true}
// errRequestCanceled is set to be identical to the one from h2 to facilitate
// testing.
var errRequestCanceled = http2errRequestCanceled
var errRequestCanceledConn = errors.New("net/http: request canceled while waiting for connection") // TODO: unify?
func nop() {}
// testHooks. Always non-nil.
var (
testHookEnterRoundTrip = nop
testHookWaitResLoop = nop
testHookRoundTripRetried = nop
testHookPrePendingDial = nop
testHookPostPendingDial = nop
testHookMu sync.Locker = fakeLocker{} // guards following
testHookReadLoopBeforeNextRead = nop
)
func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
testHookEnterRoundTrip()
if !pc.t.replaceReqCanceler(req.cancelKey, pc.cancelRequest) {
pc.t.putOrCloseIdleConn(pc)
return nil, errRequestCanceled
}
pc.mu.Lock()
pc.numExpectedResponses++
headerFn := pc.mutateHeaderFunc
pc.mu.Unlock()
if headerFn != nil {
headerFn(req.extraHeaders())
}
// Ask for a compressed version if the caller didn't set their
// own value for Accept-Encoding. We only attempt to
// uncompress the gzip stream if we were the layer that
// requested it.
requestedGzip := false
if !pc.t.DisableCompression &&
req.Header.Get("Accept-Encoding") == "" &&
req.Header.Get("Range") == "" &&
req.Method != "HEAD" {
// Request gzip only, not deflate. Deflate is ambiguous and
// not as universally supported anyway.
// See: https://zlib.net/zlib_faq.html#faq39
//
// Note that we don't request this for HEAD requests,
// due to a bug in nginx:
// https://trac.nginx.org/nginx/ticket/358
// https://golang.org/issue/5522
//
// We don't request gzip if the request is for a range, since
// auto-decoding a portion of a gzipped document will just fail
// anyway. See https://golang.org/issue/8923
requestedGzip = true
req.extraHeaders().Set("Accept-Encoding", "gzip")
}
var continueCh chan struct{}
if req.ProtoAtLeast(1, 1) && req.Body != nil && req.expectsContinue() {
continueCh = make(chan struct{}, 1)
}
if pc.t.DisableKeepAlives &&
!req.wantsClose() &&
!isProtocolSwitchHeader(req.Header) {
req.extraHeaders().Set("Connection", "close")
}
gone := make(chan struct{})
defer close(gone)
defer func() {
if err != nil {
pc.t.setReqCanceler(req.cancelKey, nil)
}
}()
const debugRoundTrip = false
// Write the request concurrently with waiting for a response,
// in case the server decides to reply before reading our full
// request body.
startBytesWritten := pc.nwrite
writeErrCh := make(chan error, 1)
pc.writech <- writeRequest{req, writeErrCh, continueCh}
resc := make(chan responseAndError)
pc.reqch <- requestAndChan{
req: req.Request,
cancelKey: req.cancelKey,
ch: resc,
addedGzip: requestedGzip,
continueCh: continueCh,
callerGone: gone,
}
var respHeaderTimer <-chan time.Time
cancelChan := req.Request.Cancel
ctxDoneChan := req.Context().Done()
pcClosed := pc.closech
canceled := false
for {
testHookWaitResLoop()
select {
case err := <-writeErrCh:
if debugRoundTrip {
req.logf("writeErrCh resv: %T/%#v", err, err)
}
if err != nil {
pc.close(fmt.Errorf("write error: %w", err))
return nil, pc.mapRoundTripError(req, startBytesWritten, err)
}
if d := pc.t.ResponseHeaderTimeout; d > 0 {
if debugRoundTrip {
req.logf("starting timer for %v", d)
}
timer := time.NewTimer(d)
defer timer.Stop() // prevent leaks
respHeaderTimer = timer.C
}
case <-pcClosed:
pcClosed = nil
if canceled || pc.t.replaceReqCanceler(req.cancelKey, nil) {
if debugRoundTrip {
req.logf("closech recv: %T %#v", pc.closed, pc.closed)
}
return nil, pc.mapRoundTripError(req, startBytesWritten, pc.closed)
}
case <-respHeaderTimer:
if debugRoundTrip {
req.logf("timeout waiting for response headers.")
}
pc.close(errTimeout)
return nil, errTimeout
case re := <-resc:
if (re.res == nil) == (re.err == nil) {
panic(fmt.Sprintf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil))
}
if debugRoundTrip {
req.logf("resc recv: %p, %T/%#v", re.res, re.err, re.err)
}
if re.err != nil {
return nil, pc.mapRoundTripError(req, startBytesWritten, re.err)
}
return re.res, nil
case <-cancelChan:
canceled = pc.t.cancelRequest(req.cancelKey, errRequestCanceled)
cancelChan = nil
case <-ctxDoneChan:
canceled = pc.t.cancelRequest(req.cancelKey, req.Context().Err())
cancelChan = nil
ctxDoneChan = nil
}
}
}
// tLogKey is a context WithValue key for test debugging contexts containing
// a t.Logf func. See export_test.go's Request.WithT method.
type tLogKey struct{}
func (tr *transportRequest) logf(format string, args ...any) {
if logf, ok := tr.Request.Context().Value(tLogKey{}).(func(string, ...any)); ok {
logf(time.Now().Format(time.RFC3339Nano)+": "+format, args...)
}
}
// markReused marks this connection as having been successfully used for a
// request and response.
func (pc *persistConn) markReused() {
pc.mu.Lock()
pc.reused = true
pc.mu.Unlock()
}
// close closes the underlying TCP connection and closes
// the pc.closech channel.
//
// The provided err is only for testing and debugging; in normal
// circumstances it should never be seen by users.
func (pc *persistConn) close(err error) {
pc.mu.Lock()
defer pc.mu.Unlock()
pc.closeLocked(err)
}
func (pc *persistConn) closeLocked(err error) {
if err == nil {
panic("nil error")
}
pc.broken = true
if pc.closed == nil {
pc.closed = err
pc.t.decConnsPerHost(pc.cacheKey)
// Close HTTP/1 (pc.alt == nil) connection.
// HTTP/2 closes its connection itself.
if pc.alt == nil {
if err != errCallerOwnsConn {
pc.conn.Close()
}
close(pc.closech)
}
}
pc.mutateHeaderFunc = nil
}
var portMap = map[string]string{
"http": "80",
"https": "443",
"socks5": "1080",
}
func idnaASCIIFromURL(url *url.URL) string {
addr := url.Hostname()
if v, err := idnaASCII(addr); err == nil {
addr = v
}
return addr
}
// canonicalAddr returns url.Host but always with a ":port" suffix.
func canonicalAddr(url *url.URL) string {
port := url.Port()
if port == "" {
port = portMap[url.Scheme]
}
return net.JoinHostPort(idnaASCIIFromURL(url), port)
}
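// For example (hypothetical URLs):
//
//	canonicalAddr(&url.URL{Scheme: "https", Host: "example.com"})     // "example.com:443"
//	canonicalAddr(&url.URL{Scheme: "http", Host: "example.com:8080"}) // "example.com:8080"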
// bodyEOFSignal is used by the HTTP/1 transport when reading response
// bodies to make sure we see the end of a response body before
// proceeding and reading on the connection again.
//
// It wraps a ReadCloser but runs fn (if non-nil) at most
// once, right before its final (error-producing) Read or Close call
// returns. fn should return the new error to return from Read or Close.
//
// If earlyCloseFn is non-nil and Close is called before io.EOF is
// seen, earlyCloseFn is called instead of fn, and its return value is
// the return value from Close.
type bodyEOFSignal struct {
body io.ReadCloser
mu sync.Mutex // guards following 4 fields
closed bool // whether Close has been called
rerr error // sticky Read error
fn func(error) error // runs at most once; receives the final Read or Close error (io.EOF included)
earlyCloseFn func() error // optional alt Close func used if io.EOF not seen
}
var errReadOnClosedResBody = errors.New("http: read on closed response body")
func (es *bodyEOFSignal) Read(p []byte) (n int, err error) {
es.mu.Lock()
closed, rerr := es.closed, es.rerr
es.mu.Unlock()
if closed {
return 0, errReadOnClosedResBody
}
if rerr != nil {
return 0, rerr
}
n, err = es.body.Read(p)
if err != nil {
es.mu.Lock()
defer es.mu.Unlock()
if es.rerr == nil {
es.rerr = err
}
err = es.condfn(err)
}
return
}
func (es *bodyEOFSignal) Close() error {
es.mu.Lock()
defer es.mu.Unlock()
if es.closed {
return nil
}
es.closed = true
if es.earlyCloseFn != nil && es.rerr != io.EOF {
return es.earlyCloseFn()
}
err := es.body.Close()
return es.condfn(err)
}
// caller must hold es.mu.
func (es *bodyEOFSignal) condfn(err error) error {
if es.fn == nil {
return err
}
err = es.fn(err)
es.fn = nil
return err
}
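// notifyOnceBody is an illustrative sketch (not part of the original source)
// of how a body can be wrapped with bodyEOFSignal so that a caller-supplied
// callback runs exactly once, whichever of the final Read or an early Close
// happens first. The done callback is a stand-in for the transport's real
// connection-reuse bookkeeping.
func notifyOnceBody(rc io.ReadCloser, done func(readToEOF bool)) io.ReadCloser {
	return &bodyEOFSignal{
		body: rc,
		fn: func(err error) error {
			// Runs at most once, just before the final Read or Close returns.
			done(err == nil || err == io.EOF)
			return err
		},
		earlyCloseFn: func() error {
			// Runs instead of fn if Close arrives before io.EOF was seen.
			done(false)
			return nil
		},
	}
}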
// gzipReader wraps a response body so it can lazily
// call gzip.NewReader on the first call to Read.
type gzipReader struct {
_ incomparable
body *bodyEOFSignal // underlying HTTP/1 response body framing
zr *gzip.Reader // lazily-initialized gzip reader
zerr error // any error from gzip.NewReader; sticky
}
func (gz *gzipReader) Read(p []byte) (n int, err error) {
if gz.zr == nil {
if gz.zerr == nil {
gz.zr, gz.zerr = gzip.NewReader(gz.body)
}
if gz.zerr != nil {
return 0, gz.zerr
}
}
gz.body.mu.Lock()
if gz.body.closed {
err = errReadOnClosedResBody
}
gz.body.mu.Unlock()
if err != nil {
return 0, err
}
return gz.zr.Read(p)
}
func (gz *gzipReader) Close() error {
return gz.body.Close()
}
type tlsHandshakeTimeoutError struct{}
func (tlsHandshakeTimeoutError) Timeout() bool { return true }
func (tlsHandshakeTimeoutError) Temporary() bool { return true }
func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" }
// fakeLocker is a sync.Locker which does nothing. It's used to guard
// test-only fields when not under test, to avoid runtime atomic
// overhead.
type fakeLocker struct{}
func (fakeLocker) Lock() {}
func (fakeLocker) Unlock() {}
// cloneTLSConfig returns a shallow clone of cfg, or a new zero tls.Config if
// cfg is nil. This is safe to call even if cfg is in active use by a TLS
// client or server.
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return cfg.Clone()
}
type connLRU struct {
ll *list.List // list.Element.Value type is of *persistConn
m map[*persistConn]*list.Element
}
// add adds pc to the head of the linked list.
func (cl *connLRU) add(pc *persistConn) {
if cl.ll == nil {
cl.ll = list.New()
cl.m = make(map[*persistConn]*list.Element)
}
ele := cl.ll.PushFront(pc)
if _, ok := cl.m[pc]; ok {
panic("persistConn was already in LRU")
}
cl.m[pc] = ele
}
func (cl *connLRU) removeOldest() *persistConn {
ele := cl.ll.Back()
pc := ele.Value.(*persistConn)
cl.ll.Remove(ele)
delete(cl.m, pc)
return pc
}
// remove removes pc from cl.
func (cl *connLRU) remove(pc *persistConn) {
if ele, ok := cl.m[pc]; ok {
cl.ll.Remove(ele)
delete(cl.m, pc)
}
}
// len returns the number of items in the cache.
func (cl *connLRU) len() int {
return len(cl.m)
}
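// evictIfFullSketch is an illustrative sketch (not part of the original
// source): once the LRU grows past a caller-chosen limit, the least recently
// used connection is evicted and closed. The limit and the error text are
// stand-ins, not the transport's real MaxIdleConns handling.
func evictIfFullSketch(cl *connLRU, limit int) {
	if limit > 0 && cl.len() >= limit {
		oldest := cl.removeOldest()
		oldest.close(errors.New("http: idle connection evicted (sketch)"))
	}
}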
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !(js && wasm)
// +build !js !wasm
package http
import (
"context"
"net"
)
func defaultTransportDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) {
return dialer.DialContext
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
import (
"errors"
"internal/itoa"
"sync"
"time"
)
// BUG(mikio): On JS, methods and functions related to
// Interface are not implemented.
// BUG(mikio): On AIX, DragonFly BSD, NetBSD, OpenBSD, Plan 9 and
// Solaris, the MulticastAddrs method of Interface is not implemented.
var (
errInvalidInterface = errors.New("invalid network interface")
errInvalidInterfaceIndex = errors.New("invalid network interface index")
errInvalidInterfaceName = errors.New("invalid network interface name")
errNoSuchInterface = errors.New("no such network interface")
errNoSuchMulticastInterface = errors.New("no such multicast network interface")
)
// Interface represents a mapping between network interface name
// and index. It also represents network interface facility
// information.
type Interface struct {
Index int // positive integer that starts at one, zero is never used
MTU int // maximum transmission unit
Name string // e.g., "en0", "lo0", "eth0.100"
HardwareAddr HardwareAddr // IEEE MAC-48, EUI-48 and EUI-64 form
Flags Flags // e.g., FlagUp, FlagLoopback, FlagMulticast
}
type Flags uint
const (
FlagUp Flags = 1 << iota // interface is administratively up
FlagBroadcast // interface supports broadcast access capability
FlagLoopback // interface is a loopback interface
FlagPointToPoint // interface belongs to a point-to-point link
FlagMulticast // interface supports multicast access capability
FlagRunning // interface is in running state
)
var flagNames = []string{
"up",
"broadcast",
"loopback",
"pointtopoint",
"multicast",
"running",
}
func (f Flags) String() string {
s := ""
for i, name := range flagNames {
if f&(1<<uint(i)) != 0 {
if s != "" {
s += "|"
}
s += name
}
}
if s == "" {
s = "0"
}
return s
}
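// For illustration (not part of the original source): flags combine with |
// and String joins the set bits in declaration order.
//
//	f := FlagUp | FlagBroadcast | FlagMulticast
//	fmt.Println(f) // "up|broadcast|multicast"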
// Addrs returns a list of unicast interface addresses for a specific
// interface.
func (ifi *Interface) Addrs() ([]Addr, error) {
if ifi == nil {
return nil, &OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errInvalidInterface}
}
ifat, err := interfaceAddrTable(ifi)
if err != nil {
err = &OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: err}
}
return ifat, err
}
// MulticastAddrs returns a list of multicast, joined group addresses
// for a specific interface.
func (ifi *Interface) MulticastAddrs() ([]Addr, error) {
if ifi == nil {
return nil, &OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errInvalidInterface}
}
ifat, err := interfaceMulticastAddrTable(ifi)
if err != nil {
err = &OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: err}
}
return ifat, err
}
// Interfaces returns a list of the system's network interfaces.
func Interfaces() ([]Interface, error) {
ift, err := interfaceTable(0)
if err != nil {
return nil, &OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: err}
}
if len(ift) != 0 {
zoneCache.update(ift, false)
}
return ift, nil
}
// InterfaceAddrs returns a list of the system's unicast interface
// addresses.
//
// The returned list does not identify the associated interface; use
// Interfaces and Interface.Addrs for more detail.
func InterfaceAddrs() ([]Addr, error) {
ifat, err := interfaceAddrTable(nil)
if err != nil {
err = &OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: err}
}
return ifat, err
}
// InterfaceByIndex returns the interface specified by index.
//
// On Solaris, it returns one of the logical network interfaces
// sharing the logical data link; for more precision use
// InterfaceByName.
func InterfaceByIndex(index int) (*Interface, error) {
if index <= 0 {
return nil, &OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errInvalidInterfaceIndex}
}
ift, err := interfaceTable(index)
if err != nil {
return nil, &OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: err}
}
ifi, err := interfaceByIndex(ift, index)
if err != nil {
err = &OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: err}
}
return ifi, err
}
func interfaceByIndex(ift []Interface, index int) (*Interface, error) {
for _, ifi := range ift {
if index == ifi.Index {
return &ifi, nil
}
}
return nil, errNoSuchInterface
}
// InterfaceByName returns the interface specified by name.
func InterfaceByName(name string) (*Interface, error) {
if name == "" {
return nil, &OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errInvalidInterfaceName}
}
ift, err := interfaceTable(0)
if err != nil {
return nil, &OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: err}
}
if len(ift) != 0 {
zoneCache.update(ift, false)
}
for _, ifi := range ift {
if name == ifi.Name {
return &ifi, nil
}
}
return nil, &OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errNoSuchInterface}
}
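// For illustration (not part of the original source): enumerating the
// system's interfaces and their unicast addresses with the exported API.
// Error handling is deliberately minimal.
//
//	ifaces, err := net.Interfaces()
//	if err != nil {
//		log.Fatal(err)
//	}
//	for _, ifi := range ifaces {
//		addrs, _ := ifi.Addrs()
//		fmt.Printf("%s (%v): %v\n", ifi.Name, ifi.Flags, addrs)
//	}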
// An ipv6ZoneCache represents a cache holding partial network
// interface information. It is used for reducing the cost of IPv6
// addressing scope zone resolution.
//
// Multiple names sharing an index are managed on a first-come,
// first-served basis for consistency.
type ipv6ZoneCache struct {
sync.RWMutex // guard the following
lastFetched time.Time // last time routing information was fetched
toIndex map[string]int // interface name to its index
toName map[int]string // interface index to its name
}
var zoneCache = ipv6ZoneCache{
toIndex: make(map[string]int),
toName: make(map[int]string),
}
// update refreshes the network interface information if the cache was last
// updated more than 1 minute ago, or if force is set. It reports whether the
// cache was updated.
func (zc *ipv6ZoneCache) update(ift []Interface, force bool) (updated bool) {
zc.Lock()
defer zc.Unlock()
now := time.Now()
if !force && zc.lastFetched.After(now.Add(-60*time.Second)) {
return false
}
zc.lastFetched = now
if len(ift) == 0 {
var err error
if ift, err = interfaceTable(0); err != nil {
return false
}
}
zc.toIndex = make(map[string]int, len(ift))
zc.toName = make(map[int]string, len(ift))
for _, ifi := range ift {
zc.toIndex[ifi.Name] = ifi.Index
if _, ok := zc.toName[ifi.Index]; !ok {
zc.toName[ifi.Index] = ifi.Name
}
}
return true
}
func (zc *ipv6ZoneCache) name(index int) string {
if index == 0 {
return ""
}
updated := zoneCache.update(nil, false)
zoneCache.RLock()
name, ok := zoneCache.toName[index]
zoneCache.RUnlock()
if !ok && !updated {
zoneCache.update(nil, true)
zoneCache.RLock()
name, ok = zoneCache.toName[index]
zoneCache.RUnlock()
}
if !ok { // last resort
name = itoa.Uitoa(uint(index))
}
return name
}
func (zc *ipv6ZoneCache) index(name string) int {
if name == "" {
return 0
}
updated := zoneCache.update(nil, false)
zoneCache.RLock()
index, ok := zoneCache.toIndex[name]
zoneCache.RUnlock()
if !ok && !updated {
zoneCache.update(nil, true)
zoneCache.RLock()
index, ok = zoneCache.toIndex[name]
zoneCache.RUnlock()
}
if !ok { // last resort
index, _, _ = dtoi(name)
}
return index
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
import (
"os"
"syscall"
"unsafe"
)
// If the ifindex is zero, interfaceTable returns mappings of all
// network interfaces. Otherwise it returns a mapping of a specific
// interface.
func interfaceTable(ifindex int) ([]Interface, error) {
tab, err := syscall.NetlinkRIB(syscall.RTM_GETLINK, syscall.AF_UNSPEC)
if err != nil {
return nil, os.NewSyscallError("netlinkrib", err)
}
msgs, err := syscall.ParseNetlinkMessage(tab)
if err != nil {
return nil, os.NewSyscallError("parsenetlinkmessage", err)
}
var ift []Interface
loop:
for _, m := range msgs {
switch m.Header.Type {
case syscall.NLMSG_DONE:
break loop
case syscall.RTM_NEWLINK:
ifim := (*syscall.IfInfomsg)(unsafe.Pointer(&m.Data[0]))
if ifindex == 0 || ifindex == int(ifim.Index) {
attrs, err := syscall.ParseNetlinkRouteAttr(&m)
if err != nil {
return nil, os.NewSyscallError("parsenetlinkrouteattr", err)
}
ift = append(ift, *newLink(ifim, attrs))
if ifindex == int(ifim.Index) {
break loop
}
}
}
}
return ift, nil
}
const (
// See linux/if_arp.h.
// Note that Linux doesn't support IPv4 over IPv6 tunneling.
sysARPHardwareIPv4IPv4 = 768 // IPv4 over IPv4 tunneling
sysARPHardwareIPv6IPv6 = 769 // IPv6 over IPv6 tunneling
sysARPHardwareIPv6IPv4 = 776 // IPv6 over IPv4 tunneling
sysARPHardwareGREIPv4 = 778 // any over GRE over IPv4 tunneling
sysARPHardwareGREIPv6 = 823 // any over GRE over IPv6 tunneling
)
func newLink(ifim *syscall.IfInfomsg, attrs []syscall.NetlinkRouteAttr) *Interface {
ifi := &Interface{Index: int(ifim.Index), Flags: linkFlags(ifim.Flags)}
for _, a := range attrs {
switch a.Attr.Type {
case syscall.IFLA_ADDRESS:
// We never return any /32 or /128 IP address
// prefix on any IP tunnel interface as the
// hardware address.
switch len(a.Value) {
case IPv4len:
switch ifim.Type {
case sysARPHardwareIPv4IPv4, sysARPHardwareGREIPv4, sysARPHardwareIPv6IPv4:
continue
}
case IPv6len:
switch ifim.Type {
case sysARPHardwareIPv6IPv6, sysARPHardwareGREIPv6:
continue
}
}
var nonzero bool
for _, b := range a.Value {
if b != 0 {
nonzero = true
break
}
}
if nonzero {
ifi.HardwareAddr = a.Value[:]
}
case syscall.IFLA_IFNAME:
ifi.Name = string(a.Value[:len(a.Value)-1])
case syscall.IFLA_MTU:
ifi.MTU = int(*(*uint32)(unsafe.Pointer(&a.Value[:4][0])))
}
}
return ifi
}
func linkFlags(rawFlags uint32) Flags {
var f Flags
if rawFlags&syscall.IFF_UP != 0 {
f |= FlagUp
}
if rawFlags&syscall.IFF_RUNNING != 0 {
f |= FlagRunning
}
if rawFlags&syscall.IFF_BROADCAST != 0 {
f |= FlagBroadcast
}
if rawFlags&syscall.IFF_LOOPBACK != 0 {
f |= FlagLoopback
}
if rawFlags&syscall.IFF_POINTOPOINT != 0 {
f |= FlagPointToPoint
}
if rawFlags&syscall.IFF_MULTICAST != 0 {
f |= FlagMulticast
}
return f
}
// If the ifi is nil, interfaceAddrTable returns addresses for all
// network interfaces. Otherwise it returns addresses for a specific
// interface.
func interfaceAddrTable(ifi *Interface) ([]Addr, error) {
tab, err := syscall.NetlinkRIB(syscall.RTM_GETADDR, syscall.AF_UNSPEC)
if err != nil {
return nil, os.NewSyscallError("netlinkrib", err)
}
msgs, err := syscall.ParseNetlinkMessage(tab)
if err != nil {
return nil, os.NewSyscallError("parsenetlinkmessage", err)
}
var ift []Interface
if ifi == nil {
var err error
ift, err = interfaceTable(0)
if err != nil {
return nil, err
}
}
ifat, err := addrTable(ift, ifi, msgs)
if err != nil {
return nil, err
}
return ifat, nil
}
func addrTable(ift []Interface, ifi *Interface, msgs []syscall.NetlinkMessage) ([]Addr, error) {
var ifat []Addr
loop:
for _, m := range msgs {
switch m.Header.Type {
case syscall.NLMSG_DONE:
break loop
case syscall.RTM_NEWADDR:
ifam := (*syscall.IfAddrmsg)(unsafe.Pointer(&m.Data[0]))
if len(ift) != 0 || ifi.Index == int(ifam.Index) {
if len(ift) != 0 {
var err error
ifi, err = interfaceByIndex(ift, int(ifam.Index))
if err != nil {
return nil, err
}
}
attrs, err := syscall.ParseNetlinkRouteAttr(&m)
if err != nil {
return nil, os.NewSyscallError("parsenetlinkrouteattr", err)
}
ifa := newAddr(ifam, attrs)
if ifa != nil {
ifat = append(ifat, ifa)
}
}
}
}
return ifat, nil
}
func newAddr(ifam *syscall.IfAddrmsg, attrs []syscall.NetlinkRouteAttr) Addr {
var ipPointToPoint bool
// We need to check whether the IP interface stack uses
// point-to-point numbered or unnumbered addressing.
for _, a := range attrs {
if a.Attr.Type == syscall.IFA_LOCAL {
ipPointToPoint = true
break
}
}
for _, a := range attrs {
if ipPointToPoint && a.Attr.Type == syscall.IFA_ADDRESS {
continue
}
switch ifam.Family {
case syscall.AF_INET:
return &IPNet{IP: IPv4(a.Value[0], a.Value[1], a.Value[2], a.Value[3]), Mask: CIDRMask(int(ifam.Prefixlen), 8*IPv4len)}
case syscall.AF_INET6:
ifa := &IPNet{IP: make(IP, IPv6len), Mask: CIDRMask(int(ifam.Prefixlen), 8*IPv6len)}
copy(ifa.IP, a.Value[:])
return ifa
}
}
return nil
}
// interfaceMulticastAddrTable returns addresses for a specific
// interface.
func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) {
ifmat4 := parseProcNetIGMP("/proc/net/igmp", ifi)
ifmat6 := parseProcNetIGMP6("/proc/net/igmp6", ifi)
return append(ifmat4, ifmat6...), nil
}
func parseProcNetIGMP(path string, ifi *Interface) []Addr {
fd, err := open(path)
if err != nil {
return nil
}
defer fd.close()
var (
ifmat []Addr
name string
)
fd.readLine() // skip first line
b := make([]byte, IPv4len)
for l, ok := fd.readLine(); ok; l, ok = fd.readLine() {
f := splitAtBytes(l, " :\r\t\n")
if len(f) < 4 {
continue
}
switch {
case l[0] != ' ' && l[0] != '\t': // new interface line
name = f[1]
case len(f[0]) == 8:
if ifi == nil || name == ifi.Name {
// The Linux kernel puts the IP
// address in /proc/net/igmp in native
// endianness.
for i := 0; i+1 < len(f[0]); i += 2 {
b[i/2], _ = xtoi2(f[0][i:i+2], 0)
}
i := *(*uint32)(unsafe.Pointer(&b[:4][0]))
ifma := &IPAddr{IP: IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i))}
ifmat = append(ifmat, ifma)
}
}
}
return ifmat
}
func parseProcNetIGMP6(path string, ifi *Interface) []Addr {
fd, err := open(path)
if err != nil {
return nil
}
defer fd.close()
var ifmat []Addr
b := make([]byte, IPv6len)
for l, ok := fd.readLine(); ok; l, ok = fd.readLine() {
f := splitAtBytes(l, " \r\t\n")
if len(f) < 6 {
continue
}
if ifi == nil || f[1] == ifi.Name {
for i := 0; i+1 < len(f[2]); i += 2 {
b[i/2], _ = xtoi2(f[2][i:i+2], 0)
}
ifma := &IPAddr{IP: IP{b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15]}}
ifmat = append(ifmat, ifma)
}
}
return ifmat
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package socktest provides utilities for socket testing.
package socktest
import (
"fmt"
"sync"
)
// A Switch represents a callpath point switch for socket system
// calls.
type Switch struct {
once sync.Once
fmu sync.RWMutex
fltab map[FilterType]Filter
smu sync.RWMutex
sotab Sockets
stats stats
}
func (sw *Switch) init() {
sw.fltab = make(map[FilterType]Filter)
sw.sotab = make(Sockets)
sw.stats = make(stats)
}
// Stats returns a list of per-cookie socket statistics.
func (sw *Switch) Stats() []Stat {
var st []Stat
sw.smu.RLock()
for _, s := range sw.stats {
ns := *s
st = append(st, ns)
}
sw.smu.RUnlock()
return st
}
// Sockets returns mappings of socket descriptor to socket status.
func (sw *Switch) Sockets() Sockets {
sw.smu.RLock()
tab := make(Sockets, len(sw.sotab))
for i, s := range sw.sotab {
tab[i] = s
}
sw.smu.RUnlock()
return tab
}
// A Cookie represents a 3-tuple of a socket: address family, socket
// type, and protocol number.
type Cookie uint64
// Family returns an address family.
func (c Cookie) Family() int { return int(c >> 48) }
// Type returns a socket type.
func (c Cookie) Type() int { return int(c << 16 >> 32) }
// Protocol returns a protocol number.
func (c Cookie) Protocol() int { return int(c & 0xff) }
func cookie(family, sotype, proto int) Cookie {
return Cookie(family)<<48 | Cookie(sotype)&0xffffffff<<16 | Cookie(proto)&0xff
}
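// For illustration (not part of the original source): a Cookie packs the
// family into the top 16 bits, the socket type into the middle 32 bits, and
// the protocol into the low 8 bits, and the accessors recover each field.
//
//	c := cookie(syscall.AF_INET, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
//	_ = c.Family()   // syscall.AF_INET
//	_ = c.Type()     // syscall.SOCK_STREAM
//	_ = c.Protocol() // syscall.IPPROTO_TCP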
// A Status represents the status of a socket.
type Status struct {
Cookie Cookie
Err error // error status of socket system call
SocketErr error // error status of socket by SO_ERROR
}
func (so Status) String() string {
return fmt.Sprintf("(%s, %s, %s): syscallerr=%v socketerr=%v", familyString(so.Cookie.Family()), typeString(so.Cookie.Type()), protocolString(so.Cookie.Protocol()), so.Err, so.SocketErr)
}
// A Stat represents per-cookie socket statistics.
type Stat struct {
Family int // address family
Type int // socket type
Protocol int // protocol number
Opened uint64 // number of sockets opened
Connected uint64 // number of sockets connected
Listened uint64 // number of sockets listened
Accepted uint64 // number of sockets accepted
Closed uint64 // number of sockets closed
OpenFailed uint64 // number of sockets open failed
ConnectFailed uint64 // number of sockets connect failed
ListenFailed uint64 // number of sockets listen failed
AcceptFailed uint64 // number of sockets accept failed
CloseFailed uint64 // number of sockets close failed
}
func (st Stat) String() string {
return fmt.Sprintf("(%s, %s, %s): opened=%d connected=%d listened=%d accepted=%d closed=%d openfailed=%d connectfailed=%d listenfailed=%d acceptfailed=%d closefailed=%d", familyString(st.Family), typeString(st.Type), protocolString(st.Protocol), st.Opened, st.Connected, st.Listened, st.Accepted, st.Closed, st.OpenFailed, st.ConnectFailed, st.ListenFailed, st.AcceptFailed, st.CloseFailed)
}
type stats map[Cookie]*Stat
func (st stats) getLocked(c Cookie) *Stat {
s, ok := st[c]
if !ok {
s = &Stat{Family: c.Family(), Type: c.Type(), Protocol: c.Protocol()}
st[c] = s
}
return s
}
// A FilterType represents a filter type.
type FilterType int
const (
FilterSocket FilterType = iota // for Socket
FilterConnect // for Connect or ConnectEx
FilterListen // for Listen
FilterAccept // for Accept, Accept4 or AcceptEx
FilterGetsockoptInt // for GetsockoptInt
FilterClose // for Close or Closesocket
)
// A Filter represents a socket system call filter.
//
// It is executed only before a system call for a socket that has an
// entry in the internal table.
// If the filter returns a non-nil error, execution of the system call
// is canceled and the system call function returns that error.
// It can return a non-nil AfterFilter for filtering after the system
// call executes.
type Filter func(*Status) (AfterFilter, error)
func (f Filter) apply(st *Status) (AfterFilter, error) {
if f == nil {
return nil, nil
}
return f(st)
}
// An AfterFilter represents a socket system call filter that runs
// after the execution of a system call.
//
// It is executed only after a system call for a socket that has an
// entry in the internal table.
// If the filter returns a non-nil error, the system call function
// returns that error.
type AfterFilter func(*Status) error
func (f AfterFilter) apply(st *Status) error {
if f == nil {
return nil
}
return f(st)
}
// Set deploys the socket system call filter f for the filter type t.
func (sw *Switch) Set(t FilterType, f Filter) {
sw.once.Do(sw.init)
sw.fmu.Lock()
sw.fltab[t] = f
sw.fmu.Unlock()
}
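// For illustration (not part of the original source): a test might install a
// Filter that makes every connect attempt fail, to exercise error paths.
//
//	var sw socktest.Switch
//	sw.Set(socktest.FilterConnect, func(so *socktest.Status) (socktest.AfterFilter, error) {
//		return nil, syscall.ECONNREFUSED // cancel the Connect system call
//	})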
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !plan9
package socktest
import (
"fmt"
"syscall"
)
func familyString(family int) string {
switch family {
case syscall.AF_INET:
return "inet4"
case syscall.AF_INET6:
return "inet6"
case syscall.AF_UNIX:
return "local"
default:
return fmt.Sprintf("%d", family)
}
}
func typeString(sotype int) string {
var s string
switch sotype & 0xff {
case syscall.SOCK_STREAM:
s = "stream"
case syscall.SOCK_DGRAM:
s = "datagram"
case syscall.SOCK_RAW:
s = "raw"
case syscall.SOCK_SEQPACKET:
s = "seqpacket"
default:
s = fmt.Sprintf("%d", sotype&0xff)
}
if flags := uint(sotype) & ^uint(0xff); flags != 0 {
s += fmt.Sprintf("|%#x", flags)
}
return s
}
func protocolString(proto int) string {
switch proto {
case 0:
return "default"
case syscall.IPPROTO_TCP:
return "tcp"
case syscall.IPPROTO_UDP:
return "udp"
default:
return fmt.Sprintf("%d", proto)
}
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || (js && wasm)
package socktest
// Sockets maps a socket descriptor to the status of the socket.
type Sockets map[int]Status
func (sw *Switch) sockso(s int) *Status {
sw.smu.RLock()
defer sw.smu.RUnlock()
so, ok := sw.sotab[s]
if !ok {
return nil
}
return &so
}
// addLocked returns a new Status without locking.
// The caller must hold sw.smu.
func (sw *Switch) addLocked(s, family, sotype, proto int) *Status {
sw.once.Do(sw.init)
so := Status{Cookie: cookie(family, sotype, proto)}
sw.sotab[s] = so
return &so
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build dragonfly || freebsd || linux || netbsd || openbsd || solaris
package socktest
import "syscall"
// Accept4 wraps syscall.Accept4.
func (sw *Switch) Accept4(s, flags int) (ns int, sa syscall.Sockaddr, err error) {
so := sw.sockso(s)
if so == nil {
return syscall.Accept4(s, flags)
}
sw.fmu.RLock()
f := sw.fltab[FilterAccept]
sw.fmu.RUnlock()
af, err := f.apply(so)
if err != nil {
return -1, nil, err
}
ns, sa, so.Err = syscall.Accept4(s, flags)
if err = af.apply(so); err != nil {
if so.Err == nil {
syscall.Close(ns)
}
return -1, nil, err
}
sw.smu.Lock()
defer sw.smu.Unlock()
if so.Err != nil {
sw.stats.getLocked(so.Cookie).AcceptFailed++
return -1, nil, so.Err
}
nso := sw.addLocked(ns, so.Cookie.Family(), so.Cookie.Type(), so.Cookie.Protocol())
sw.stats.getLocked(nso.Cookie).Accepted++
return ns, sa, nil
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || (js && wasm)
package socktest
import "syscall"
// Socket wraps syscall.Socket.
func (sw *Switch) Socket(family, sotype, proto int) (s int, err error) {
sw.once.Do(sw.init)
so := &Status{Cookie: cookie(family, sotype, proto)}
sw.fmu.RLock()
f := sw.fltab[FilterSocket]
sw.fmu.RUnlock()
af, err := f.apply(so)
if err != nil {
return -1, err
}
s, so.Err = syscall.Socket(family, sotype, proto)
if err = af.apply(so); err != nil {
if so.Err == nil {
syscall.Close(s)
}
return -1, err
}
sw.smu.Lock()
defer sw.smu.Unlock()
if so.Err != nil {
sw.stats.getLocked(so.Cookie).OpenFailed++
return -1, so.Err
}
nso := sw.addLocked(s, family, sotype, proto)
sw.stats.getLocked(nso.Cookie).Opened++
return s, nil
}
// Close wraps syscall.Close.
func (sw *Switch) Close(s int) (err error) {
so := sw.sockso(s)
if so == nil {
return syscall.Close(s)
}
sw.fmu.RLock()
f := sw.fltab[FilterClose]
sw.fmu.RUnlock()
af, err := f.apply(so)
if err != nil {
return err
}
so.Err = syscall.Close(s)
if err = af.apply(so); err != nil {
return err
}
sw.smu.Lock()
defer sw.smu.Unlock()
if so.Err != nil {
sw.stats.getLocked(so.Cookie).CloseFailed++
return so.Err
}
delete(sw.sotab, s)
sw.stats.getLocked(so.Cookie).Closed++
return nil
}
// Connect wraps syscall.Connect.
func (sw *Switch) Connect(s int, sa syscall.Sockaddr) (err error) {
so := sw.sockso(s)
if so == nil {
return syscall.Connect(s, sa)
}
sw.fmu.RLock()
f := sw.fltab[FilterConnect]
sw.fmu.RUnlock()
af, err := f.apply(so)
if err != nil {
return err
}
so.Err = syscall.Connect(s, sa)
if err = af.apply(so); err != nil {
return err
}
sw.smu.Lock()
defer sw.smu.Unlock()
if so.Err != nil {
sw.stats.getLocked(so.Cookie).ConnectFailed++
return so.Err
}
sw.stats.getLocked(so.Cookie).Connected++
return nil
}
// Listen wraps syscall.Listen.
func (sw *Switch) Listen(s, backlog int) (err error) {
so := sw.sockso(s)
if so == nil {
return syscall.Listen(s, backlog)
}
sw.fmu.RLock()
f := sw.fltab[FilterListen]
sw.fmu.RUnlock()
af, err := f.apply(so)
if err != nil {
return err
}
so.Err = syscall.Listen(s, backlog)
if err = af.apply(so); err != nil {
return err
}
sw.smu.Lock()
defer sw.smu.Unlock()
if so.Err != nil {
sw.stats.getLocked(so.Cookie).ListenFailed++
return so.Err
}
sw.stats.getLocked(so.Cookie).Listened++
return nil
}
// Accept wraps syscall.Accept.
func (sw *Switch) Accept(s int) (ns int, sa syscall.Sockaddr, err error) {
so := sw.sockso(s)
if so == nil {
return syscall.Accept(s)
}
sw.fmu.RLock()
f := sw.fltab[FilterAccept]
sw.fmu.RUnlock()
af, err := f.apply(so)
if err != nil {
return -1, nil, err
}
ns, sa, so.Err = syscall.Accept(s)
if err = af.apply(so); err != nil {
if so.Err == nil {
syscall.Close(ns)
}
return -1, nil, err
}
sw.smu.Lock()
defer sw.smu.Unlock()
if so.Err != nil {
sw.stats.getLocked(so.Cookie).AcceptFailed++
return -1, nil, so.Err
}
nso := sw.addLocked(ns, so.Cookie.Family(), so.Cookie.Type(), so.Cookie.Protocol())
sw.stats.getLocked(nso.Cookie).Accepted++
return ns, sa, nil
}
// GetsockoptInt wraps syscall.GetsockoptInt.
func (sw *Switch) GetsockoptInt(s, level, opt int) (soerr int, err error) {
so := sw.sockso(s)
if so == nil {
return syscall.GetsockoptInt(s, level, opt)
}
sw.fmu.RLock()
f := sw.fltab[FilterGetsockoptInt]
sw.fmu.RUnlock()
af, err := f.apply(so)
if err != nil {
return -1, err
}
soerr, so.Err = syscall.GetsockoptInt(s, level, opt)
so.SocketErr = syscall.Errno(soerr)
if err = af.apply(so); err != nil {
return -1, err
}
if so.Err != nil {
return -1, so.Err
}
if opt == syscall.SO_ERROR && (so.SocketErr == syscall.Errno(0) || so.SocketErr == syscall.EISCONN) {
sw.smu.Lock()
sw.stats.getLocked(so.Cookie).Connected++
sw.smu.Unlock()
}
return soerr, nil
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// IP address manipulations
//
// IPv4 addresses are 4 bytes; IPv6 addresses are 16 bytes.
// An IPv4 address can be converted to an IPv6 address by
// adding a canonical prefix (10 zeros, 2 0xFFs).
// This library accepts either size of byte slice but always
// returns 16-byte addresses.
package net
import (
"internal/bytealg"
"internal/itoa"
"net/netip"
)
// IP address lengths (bytes).
const (
IPv4len = 4
IPv6len = 16
)
// An IP is a single IP address, a slice of bytes.
// Functions in this package accept either 4-byte (IPv4)
// or 16-byte (IPv6) slices as input.
//
// Note that in this documentation, referring to an
// IP address as an IPv4 address or an IPv6 address
// is a semantic property of the address, not just the
// length of the byte slice: a 16-byte slice can still
// be an IPv4 address.
type IP []byte
// An IPMask is a bitmask that can be used to manipulate
// IP addresses for IP addressing and routing.
//
// See type IPNet and func ParseCIDR for details.
type IPMask []byte
// An IPNet represents an IP network.
type IPNet struct {
IP IP // network number
Mask IPMask // network mask
}
// IPv4 returns the IP address (in 16-byte form) of the
// IPv4 address a.b.c.d.
func IPv4(a, b, c, d byte) IP {
p := make(IP, IPv6len)
copy(p, v4InV6Prefix)
p[12] = a
p[13] = b
p[14] = c
p[15] = d
return p
}
var v4InV6Prefix = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff}
// IPv4Mask returns the IP mask (in 4-byte form) of the
// IPv4 mask a.b.c.d.
func IPv4Mask(a, b, c, d byte) IPMask {
p := make(IPMask, IPv4len)
p[0] = a
p[1] = b
p[2] = c
p[3] = d
return p
}
// CIDRMask returns an IPMask consisting of 'ones' 1 bits
// followed by 0s up to a total length of 'bits' bits.
// For a mask of this form, CIDRMask is the inverse of IPMask.Size.
func CIDRMask(ones, bits int) IPMask {
if bits != 8*IPv4len && bits != 8*IPv6len {
return nil
}
if ones < 0 || ones > bits {
return nil
}
l := bits / 8
m := make(IPMask, l)
n := uint(ones)
for i := 0; i < l; i++ {
if n >= 8 {
m[i] = 0xff
n -= 8
continue
}
m[i] = ^byte(0xff >> n)
n = 0
}
return m
}
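// For illustration (not part of the original source): CIDRMask and
// IPMask.Size are inverses for masks in canonical form.
//
//	m := net.CIDRMask(24, 32) // ffffff00, i.e. 255.255.255.0
//	ones, bits := m.Size()    // 24, 32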
// Well-known IPv4 addresses
var (
IPv4bcast = IPv4(255, 255, 255, 255) // limited broadcast
IPv4allsys = IPv4(224, 0, 0, 1) // all systems
IPv4allrouter = IPv4(224, 0, 0, 2) // all routers
IPv4zero = IPv4(0, 0, 0, 0) // all zeros
)
// Well-known IPv6 addresses
var (
IPv6zero = IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
IPv6unspecified = IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
IPv6loopback = IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
IPv6interfacelocalallnodes = IP{0xff, 0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}
IPv6linklocalallnodes = IP{0xff, 0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}
IPv6linklocalallrouters = IP{0xff, 0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x02}
)
// IsUnspecified reports whether ip is an unspecified address, either
// the IPv4 address "0.0.0.0" or the IPv6 address "::".
func (ip IP) IsUnspecified() bool {
return ip.Equal(IPv4zero) || ip.Equal(IPv6unspecified)
}
// IsLoopback reports whether ip is a loopback address.
func (ip IP) IsLoopback() bool {
if ip4 := ip.To4(); ip4 != nil {
return ip4[0] == 127
}
return ip.Equal(IPv6loopback)
}
// IsPrivate reports whether ip is a private address, according to
// RFC 1918 (IPv4 addresses) and RFC 4193 (IPv6 addresses).
func (ip IP) IsPrivate() bool {
if ip4 := ip.To4(); ip4 != nil {
// Following RFC 1918, Section 3. Private Address Space which says:
// The Internet Assigned Numbers Authority (IANA) has reserved the
// following three blocks of the IP address space for private internets:
// 10.0.0.0 - 10.255.255.255 (10/8 prefix)
// 172.16.0.0 - 172.31.255.255 (172.16/12 prefix)
// 192.168.0.0 - 192.168.255.255 (192.168/16 prefix)
return ip4[0] == 10 ||
(ip4[0] == 172 && ip4[1]&0xf0 == 16) ||
(ip4[0] == 192 && ip4[1] == 168)
}
// Following RFC 4193, Section 8. IANA Considerations which says:
// The IANA has assigned the FC00::/7 prefix to "Unique Local Unicast".
return len(ip) == IPv6len && ip[0]&0xfe == 0xfc
}
// IsMulticast reports whether ip is a multicast address.
func (ip IP) IsMulticast() bool {
if ip4 := ip.To4(); ip4 != nil {
return ip4[0]&0xf0 == 0xe0
}
return len(ip) == IPv6len && ip[0] == 0xff
}
// IsInterfaceLocalMulticast reports whether ip is
// an interface-local multicast address.
func (ip IP) IsInterfaceLocalMulticast() bool {
return len(ip) == IPv6len && ip[0] == 0xff && ip[1]&0x0f == 0x01
}
// IsLinkLocalMulticast reports whether ip is a link-local
// multicast address.
func (ip IP) IsLinkLocalMulticast() bool {
if ip4 := ip.To4(); ip4 != nil {
return ip4[0] == 224 && ip4[1] == 0 && ip4[2] == 0
}
return len(ip) == IPv6len && ip[0] == 0xff && ip[1]&0x0f == 0x02
}
// IsLinkLocalUnicast reports whether ip is a link-local
// unicast address.
func (ip IP) IsLinkLocalUnicast() bool {
if ip4 := ip.To4(); ip4 != nil {
return ip4[0] == 169 && ip4[1] == 254
}
return len(ip) == IPv6len && ip[0] == 0xfe && ip[1]&0xc0 == 0x80
}
// IsGlobalUnicast reports whether ip is a global unicast
// address.
//
// The identification of global unicast addresses uses address type
// identification as defined in RFC 1122, RFC 4632 and RFC 4291 with
// the exception of IPv4 directed broadcast addresses.
// It returns true even if ip is in IPv4 private address space or
// local IPv6 unicast address space.
func (ip IP) IsGlobalUnicast() bool {
return (len(ip) == IPv4len || len(ip) == IPv6len) &&
!ip.Equal(IPv4bcast) &&
!ip.IsUnspecified() &&
!ip.IsLoopback() &&
!ip.IsMulticast() &&
!ip.IsLinkLocalUnicast()
}
// isZeros reports whether p consists entirely of zero bytes.
func isZeros(p IP) bool {
for i := 0; i < len(p); i++ {
if p[i] != 0 {
return false
}
}
return true
}
// To4 converts the IPv4 address ip to a 4-byte representation.
// If ip is not an IPv4 address, To4 returns nil.
func (ip IP) To4() IP {
if len(ip) == IPv4len {
return ip
}
if len(ip) == IPv6len &&
isZeros(ip[0:10]) &&
ip[10] == 0xff &&
ip[11] == 0xff {
return ip[12:16]
}
return nil
}
// To16 converts the IP address ip to a 16-byte representation.
// If ip is not an IP address (it is the wrong length), To16 returns nil.
func (ip IP) To16() IP {
if len(ip) == IPv4len {
return IPv4(ip[0], ip[1], ip[2], ip[3])
}
if len(ip) == IPv6len {
return ip
}
return nil
}
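// For illustration (not part of the original source): To4 and To16 convert
// between the 4-byte and 16-byte representations of an IPv4 address.
//
//	ip := net.ParseIP("192.0.2.1") // 16-byte, IPv4-mapped form
//	ip4 := ip.To4()                // 4-byte form
//	_ = ip4.To16().Equal(ip)       // true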
// Default route masks for IPv4.
var (
classAMask = IPv4Mask(0xff, 0, 0, 0)
classBMask = IPv4Mask(0xff, 0xff, 0, 0)
classCMask = IPv4Mask(0xff, 0xff, 0xff, 0)
)
// DefaultMask returns the default IP mask for the IP address ip.
// Only IPv4 addresses have default masks; DefaultMask returns
// nil if ip is not a valid IPv4 address.
func (ip IP) DefaultMask() IPMask {
if ip = ip.To4(); ip == nil {
return nil
}
switch {
case ip[0] < 0x80:
return classAMask
case ip[0] < 0xC0:
return classBMask
default:
return classCMask
}
}
func allFF(b []byte) bool {
for _, c := range b {
if c != 0xff {
return false
}
}
return true
}
// Mask returns the result of masking the IP address ip with mask.
func (ip IP) Mask(mask IPMask) IP {
if len(mask) == IPv6len && len(ip) == IPv4len && allFF(mask[:12]) {
mask = mask[12:]
}
if len(mask) == IPv4len && len(ip) == IPv6len && bytealg.Equal(ip[:12], v4InV6Prefix) {
ip = ip[12:]
}
n := len(ip)
if n != len(mask) {
return nil
}
out := make(IP, n)
for i := 0; i < n; i++ {
out[i] = ip[i] & mask[i]
}
return out
}
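// For illustration (not part of the original source): masking an address to
// its network number; mixed 4-byte and 16-byte representations are handled.
//
//	ip := net.ParseIP("192.0.2.1")    // 16-byte form
//	_ = ip.Mask(net.CIDRMask(24, 32)) // 192.0.2.0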
// String returns the string form of the IP address ip.
// It returns one of 4 forms:
// - "<nil>", if ip has length 0
// - dotted decimal ("192.0.2.1"), if ip is an IPv4 or IP4-mapped IPv6 address
// - IPv6 conforming to RFC 5952 ("2001:db8::1"), if ip is a valid IPv6 address
// - the hexadecimal form of ip, without punctuation, if no other cases apply
func (ip IP) String() string {
if len(ip) == 0 {
return "<nil>"
}
if len(ip) != IPv4len && len(ip) != IPv6len {
return "?" + hexString(ip)
}
// If IPv4, use dotted notation.
if p4 := ip.To4(); len(p4) == IPv4len {
return netip.AddrFrom4([4]byte(p4)).String()
}
return netip.AddrFrom16([16]byte(ip)).String()
}
func hexString(b []byte) string {
s := make([]byte, len(b)*2)
for i, tn := range b {
s[i*2], s[i*2+1] = hexDigit[tn>>4], hexDigit[tn&0xf]
}
return string(s)
}
// ipEmptyString is like ip.String except that it returns
// an empty string when ip is unset.
func ipEmptyString(ip IP) string {
if len(ip) == 0 {
return ""
}
return ip.String()
}
// MarshalText implements the encoding.TextMarshaler interface.
// The encoding is the same as returned by String, with one exception:
// When len(ip) is zero, it returns an empty slice.
func (ip IP) MarshalText() ([]byte, error) {
if len(ip) == 0 {
return []byte(""), nil
}
if len(ip) != IPv4len && len(ip) != IPv6len {
return nil, &AddrError{Err: "invalid IP address", Addr: hexString(ip)}
}
return []byte(ip.String()), nil
}
// UnmarshalText implements the encoding.TextUnmarshaler interface.
// The IP address is expected in a form accepted by ParseIP.
func (ip *IP) UnmarshalText(text []byte) error {
if len(text) == 0 {
*ip = nil
return nil
}
s := string(text)
x := ParseIP(s)
if x == nil {
return &ParseError{Type: "IP address", Text: s}
}
*ip = x
return nil
}
// Equal reports whether ip and x are the same IP address.
// An IPv4 address and that same address in IPv6 form are
// considered to be equal.
func (ip IP) Equal(x IP) bool {
if len(ip) == len(x) {
return bytealg.Equal(ip, x)
}
if len(ip) == IPv4len && len(x) == IPv6len {
return bytealg.Equal(x[0:12], v4InV6Prefix) && bytealg.Equal(ip, x[12:])
}
if len(ip) == IPv6len && len(x) == IPv4len {
return bytealg.Equal(ip[0:12], v4InV6Prefix) && bytealg.Equal(ip[12:], x)
}
return false
}
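// For illustration (not part of the original source): Equal treats an IPv4
// address and its IPv6-mapped form as the same address.
//
//	net.ParseIP("192.0.2.1").Equal(net.IPv4(192, 0, 2, 1)) // true
//	net.ParseIP("192.0.2.1").Equal(net.ParseIP("::1"))     // false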
func (ip IP) matchAddrFamily(x IP) bool {
return ip.To4() != nil && x.To4() != nil || ip.To16() != nil && ip.To4() == nil && x.To16() != nil && x.To4() == nil
}
// simpleMaskLength returns the number of leading 1 bits in mask if the
// mask is a sequence of 1 bits followed by 0 bits, and -1 otherwise.
func simpleMaskLength(mask IPMask) int {
var n int
for i, v := range mask {
if v == 0xff {
n += 8
continue
}
// found non-ff byte
// count 1 bits
for v&0x80 != 0 {
n++
v <<= 1
}
// rest must be 0 bits
if v != 0 {
return -1
}
for i++; i < len(mask); i++ {
if mask[i] != 0 {
return -1
}
}
break
}
return n
}
// Size returns the number of leading ones and total bits in the mask.
// If the mask is not in the canonical form--ones followed by zeros--then
// Size returns 0, 0.
func (m IPMask) Size() (ones, bits int) {
ones, bits = simpleMaskLength(m), len(m)*8
if ones == -1 {
return 0, 0
}
return
}
// String returns the hexadecimal form of m, with no punctuation.
func (m IPMask) String() string {
if len(m) == 0 {
return "<nil>"
}
return hexString(m)
}
func networkNumberAndMask(n *IPNet) (ip IP, m IPMask) {
if ip = n.IP.To4(); ip == nil {
ip = n.IP
if len(ip) != IPv6len {
return nil, nil
}
}
m = n.Mask
switch len(m) {
case IPv4len:
if len(ip) != IPv4len {
return nil, nil
}
case IPv6len:
if len(ip) == IPv4len {
m = m[12:]
}
default:
return nil, nil
}
return
}
// Contains reports whether the network includes ip.
func (n *IPNet) Contains(ip IP) bool {
nn, m := networkNumberAndMask(n)
if x := ip.To4(); x != nil {
ip = x
}
l := len(ip)
if l != len(nn) {
return false
}
for i := 0; i < l; i++ {
if nn[i]&m[i] != ip[i]&m[i] {
return false
}
}
return true
}
// Network returns the address's network name, "ip+net".
func (n *IPNet) Network() string { return "ip+net" }
// String returns the CIDR notation of n like "192.0.2.0/24"
// or "2001:db8::/48" as defined in RFC 4632 and RFC 4291.
// If the mask is not in the canonical form, it returns the
// string which consists of an IP address, followed by a slash
// character and a mask expressed as hexadecimal form with no
// punctuation like "198.51.100.0/c000ff00".
func (n *IPNet) String() string {
if n == nil {
return "<nil>"
}
nn, m := networkNumberAndMask(n)
if nn == nil || m == nil {
return "<nil>"
}
l := simpleMaskLength(m)
if l == -1 {
return nn.String() + "/" + m.String()
}
return nn.String() + "/" + itoa.Uitoa(uint(l))
}
// ParseIP parses s as an IP address, returning the result.
// The string s can be in IPv4 dotted decimal ("192.0.2.1"), IPv6
// ("2001:db8::68"), or IPv4-mapped IPv6 ("::ffff:192.0.2.1") form.
// If s is not a valid textual representation of an IP address,
// ParseIP returns nil.
func ParseIP(s string) IP {
if addr, valid := parseIP(s); valid {
return IP(addr[:])
}
return nil
}
func parseIP(s string) ([16]byte, bool) {
ip, err := netip.ParseAddr(s)
if err != nil || ip.Zone() != "" {
return [16]byte{}, false
}
return ip.As16(), true
}
// ParseCIDR parses s as a CIDR notation IP address and prefix length,
// like "192.0.2.0/24" or "2001:db8::/32", as defined in
// RFC 4632 and RFC 4291.
//
// It returns the IP address and the network implied by the IP and
// prefix length.
// For example, ParseCIDR("192.0.2.1/24") returns the IP address
// 192.0.2.1 and the network 192.0.2.0/24.
func ParseCIDR(s string) (IP, *IPNet, error) {
i := bytealg.IndexByteString(s, '/')
if i < 0 {
return nil, nil, &ParseError{Type: "CIDR address", Text: s}
}
addr, mask := s[:i], s[i+1:]
ipAddr, err := netip.ParseAddr(addr)
if err != nil || ipAddr.Zone() != "" {
return nil, nil, &ParseError{Type: "CIDR address", Text: s}
}
n, i, ok := dtoi(mask)
if !ok || i != len(mask) || n < 0 || n > ipAddr.BitLen() {
return nil, nil, &ParseError{Type: "CIDR address", Text: s}
}
m := CIDRMask(n, ipAddr.BitLen())
addr16 := ipAddr.As16()
return IP(addr16[:]), &IPNet{IP: IP(addr16[:]).Mask(m), Mask: m}, nil
}
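// For illustration (not part of the original source): ParseCIDR returns both
// the address as written and the network it belongs to, and Contains then
// tests membership.
//
//	ip, ipnet, err := net.ParseCIDR("192.0.2.1/24")
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println(ip, ipnet)                                  // 192.0.2.1 192.0.2.0/24
//	fmt.Println(ipnet.Contains(net.ParseIP("192.0.2.42"))) // true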
func copyIP(x IP) IP {
y := make(IP, len(x))
copy(y, x)
return y
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
import (
"context"
"syscall"
)
// BUG(mikio): On every POSIX platform, reads from the "ip4" network
// using the ReadFrom or ReadFromIP method might not return a complete
// IPv4 packet, including its header, even if there is space
// available. This can occur even in cases where Read or ReadMsgIP
// could return a complete packet. For this reason, it is recommended
// that you do not use these methods if it is important to receive a
// full packet.
//
// The Go 1 compatibility guidelines make it impossible for us to
// change the behavior of these methods; use Read or ReadMsgIP
// instead.
// BUG(mikio): On JS and Plan 9, methods and functions related
// to IPConn are not implemented.
// BUG(mikio): On Windows, the File method of IPConn is not
// implemented.
// IPAddr represents the address of an IP end point.
type IPAddr struct {
IP IP
Zone string // IPv6 scoped addressing zone
}
// Network returns the address's network name, "ip".
func (a *IPAddr) Network() string { return "ip" }
func (a *IPAddr) String() string {
if a == nil {
return "<nil>"
}
ip := ipEmptyString(a.IP)
if a.Zone != "" {
return ip + "%" + a.Zone
}
return ip
}
func (a *IPAddr) isWildcard() bool {
if a == nil || a.IP == nil {
return true
}
return a.IP.IsUnspecified()
}
func (a *IPAddr) opAddr() Addr {
if a == nil {
return nil
}
return a
}
// ResolveIPAddr returns the address of an IP end point.
//
// The network must be an IP network name.
//
// If the host in the address parameter is not a literal IP address,
// ResolveIPAddr resolves the address to the address of an IP end
// point. Otherwise, it parses the address as a literal IP address.
// The address parameter can use a host name, but this is not
// recommended, because it will return at most one of the host name's
// IP addresses.
//
// See func Dial for a description of the network and address
// parameters.
func ResolveIPAddr(network, address string) (*IPAddr, error) {
if network == "" { // a hint wildcard for Go 1.0 undocumented behavior
network = "ip"
}
afnet, _, err := parseNetwork(context.Background(), network, false)
if err != nil {
return nil, err
}
switch afnet {
case "ip", "ip4", "ip6":
default:
return nil, UnknownNetworkError(network)
}
addrs, err := DefaultResolver.internetAddrList(context.Background(), afnet, address)
if err != nil {
return nil, err
}
return addrs.forResolve(network, address).(*IPAddr), nil
}
// IPConn is the implementation of the Conn and PacketConn interfaces
// for IP network connections.
type IPConn struct {
conn
}
// SyscallConn returns a raw network connection.
// This implements the syscall.Conn interface.
func (c *IPConn) SyscallConn() (syscall.RawConn, error) {
if !c.ok() {
return nil, syscall.EINVAL
}
return newRawConn(c.fd)
}
// ReadFromIP acts like ReadFrom but returns an IPAddr.
func (c *IPConn) ReadFromIP(b []byte) (int, *IPAddr, error) {
if !c.ok() {
return 0, nil, syscall.EINVAL
}
n, addr, err := c.readFrom(b)
if err != nil {
err = &OpError{Op: "read", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
}
return n, addr, err
}
// ReadFrom implements the PacketConn ReadFrom method.
func (c *IPConn) ReadFrom(b []byte) (int, Addr, error) {
if !c.ok() {
return 0, nil, syscall.EINVAL
}
n, addr, err := c.readFrom(b)
if err != nil {
err = &OpError{Op: "read", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
}
if addr == nil {
	// Return an untyped nil Addr rather than a (*IPAddr)(nil),
	// which would compare non-nil as an interface value.
	return n, nil, err
}
return n, addr, err
}
// ReadMsgIP reads a message from c, copying the payload into b and
// the associated out-of-band data into oob. It returns the number of
// bytes copied into b, the number of bytes copied into oob, the flags
// that were set on the message and the source address of the message.
//
// The packages golang.org/x/net/ipv4 and golang.org/x/net/ipv6 can be
// used to manipulate IP-level socket options in oob.
func (c *IPConn) ReadMsgIP(b, oob []byte) (n, oobn, flags int, addr *IPAddr, err error) {
if !c.ok() {
return 0, 0, 0, nil, syscall.EINVAL
}
n, oobn, flags, addr, err = c.readMsg(b, oob)
if err != nil {
err = &OpError{Op: "read", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
}
return
}
// WriteToIP acts like WriteTo but takes an IPAddr.
func (c *IPConn) WriteToIP(b []byte, addr *IPAddr) (int, error) {
if !c.ok() {
return 0, syscall.EINVAL
}
n, err := c.writeTo(b, addr)
if err != nil {
err = &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: addr.opAddr(), Err: err}
}
return n, err
}
// WriteTo implements the PacketConn WriteTo method.
func (c *IPConn) WriteTo(b []byte, addr Addr) (int, error) {
if !c.ok() {
return 0, syscall.EINVAL
}
a, ok := addr.(*IPAddr)
if !ok {
return 0, &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: addr, Err: syscall.EINVAL}
}
n, err := c.writeTo(b, a)
if err != nil {
err = &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: a.opAddr(), Err: err}
}
return n, err
}
// WriteMsgIP writes a message to addr via c, copying the payload from
// b and the associated out-of-band data from oob. It returns the
// number of payload and out-of-band bytes written.
//
// The packages golang.org/x/net/ipv4 and golang.org/x/net/ipv6 can be
// used to manipulate IP-level socket options in oob.
func (c *IPConn) WriteMsgIP(b, oob []byte, addr *IPAddr) (n, oobn int, err error) {
if !c.ok() {
return 0, 0, syscall.EINVAL
}
n, oobn, err = c.writeMsg(b, oob, addr)
if err != nil {
err = &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: addr.opAddr(), Err: err}
}
return
}
func newIPConn(fd *netFD) *IPConn { return &IPConn{conn{fd}} }
// DialIP acts like Dial for IP networks.
//
// The network must be an IP network name; see func Dial for details.
//
// If laddr is nil, a local address is automatically chosen.
// If the IP field of raddr is nil or an unspecified IP address, the
// local system is assumed.
func DialIP(network string, laddr, raddr *IPAddr) (*IPConn, error) {
if raddr == nil {
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
}
sd := &sysDialer{network: network, address: raddr.String()}
c, err := sd.dialIP(context.Background(), laddr, raddr)
if err != nil {
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
}
return c, nil
}
// ListenIP acts like ListenPacket for IP networks.
//
// The network must be an IP network name; see func Dial for details.
//
// If the IP field of laddr is nil or an unspecified IP address,
// ListenIP listens on all available IP addresses of the local system
// except multicast IP addresses.
func ListenIP(network string, laddr *IPAddr) (*IPConn, error) {
if laddr == nil {
laddr = &IPAddr{}
}
sl := &sysListener{network: network, address: laddr.String()}
c, err := sl.listenIP(context.Background(), laddr)
if err != nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: err}
}
return c, nil
}
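// For illustration (not part of the original source): opening a raw IP
// listener for ICMP on all local IPv4 addresses. Raw sockets typically
// require elevated privileges.
//
//	c, err := net.ListenIP("ip4:icmp", nil)
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer c.Close()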
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || (js && wasm) || windows
package net
import (
"context"
"syscall"
)
func sockaddrToIP(sa syscall.Sockaddr) Addr {
switch sa := sa.(type) {
case *syscall.SockaddrInet4:
return &IPAddr{IP: sa.Addr[0:]}
case *syscall.SockaddrInet6:
return &IPAddr{IP: sa.Addr[0:], Zone: zoneCache.name(int(sa.ZoneId))}
}
return nil
}
func (a *IPAddr) family() int {
if a == nil || len(a.IP) <= IPv4len {
return syscall.AF_INET
}
if a.IP.To4() != nil {
return syscall.AF_INET
}
return syscall.AF_INET6
}
func (a *IPAddr) sockaddr(family int) (syscall.Sockaddr, error) {
if a == nil {
return nil, nil
}
return ipToSockaddr(family, a.IP, 0, a.Zone)
}
func (a *IPAddr) toLocal(net string) sockaddr {
return &IPAddr{loopbackIP(net), a.Zone}
}
func (c *IPConn) readFrom(b []byte) (int, *IPAddr, error) {
// TODO(cw,rsc): consider using readv if we know the family
// type to avoid the header trim/copy
var addr *IPAddr
n, sa, err := c.fd.readFrom(b)
switch sa := sa.(type) {
case *syscall.SockaddrInet4:
addr = &IPAddr{IP: sa.Addr[0:]}
n = stripIPv4Header(n, b)
case *syscall.SockaddrInet6:
addr = &IPAddr{IP: sa.Addr[0:], Zone: zoneCache.name(int(sa.ZoneId))}
}
return n, addr, err
}
// stripIPv4Header removes the IPv4 header that raw "ip4" reads can
// prepend to the payload, returning the remaining payload length.
func stripIPv4Header(n int, b []byte) int {
	if len(b) < 20 {
		return n // too short to hold an IPv4 header
	}
	l := int(b[0]&0x0f) << 2 // IHL field: header length in 32-bit words, scaled to bytes
	if 20 > l || l > len(b) {
		return n // implausible header length
	}
	if b[0]>>4 != 4 {
		return n // not an IPv4 packet
	}
	copy(b, b[l:])
	return n - l
}
func (c *IPConn) readMsg(b, oob []byte) (n, oobn, flags int, addr *IPAddr, err error) {
var sa syscall.Sockaddr
n, oobn, flags, sa, err = c.fd.readMsg(b, oob, 0)
switch sa := sa.(type) {
case *syscall.SockaddrInet4:
addr = &IPAddr{IP: sa.Addr[0:]}
case *syscall.SockaddrInet6:
addr = &IPAddr{IP: sa.Addr[0:], Zone: zoneCache.name(int(sa.ZoneId))}
}
return
}
func (c *IPConn) writeTo(b []byte, addr *IPAddr) (int, error) {
if c.fd.isConnected {
return 0, ErrWriteToConnected
}
if addr == nil {
return 0, errMissingAddress
}
sa, err := addr.sockaddr(c.fd.family)
if err != nil {
return 0, err
}
return c.fd.writeTo(b, sa)
}
func (c *IPConn) writeMsg(b, oob []byte, addr *IPAddr) (n, oobn int, err error) {
if c.fd.isConnected {
return 0, 0, ErrWriteToConnected
}
if addr == nil {
return 0, 0, errMissingAddress
}
sa, err := addr.sockaddr(c.fd.family)
if err != nil {
return 0, 0, err
}
return c.fd.writeMsg(b, oob, sa)
}
func (sd *sysDialer) dialIP(ctx context.Context, laddr, raddr *IPAddr) (*IPConn, error) {
network, proto, err := parseNetwork(ctx, sd.network, true)
if err != nil {
return nil, err
}
switch network {
case "ip", "ip4", "ip6":
default:
return nil, UnknownNetworkError(sd.network)
}
ctrlCtxFn := sd.Dialer.ControlContext
if ctrlCtxFn == nil && sd.Dialer.Control != nil {
ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
return sd.Dialer.Control(network, address, c)
}
}
fd, err := internetSocket(ctx, network, laddr, raddr, syscall.SOCK_RAW, proto, "dial", ctrlCtxFn)
if err != nil {
return nil, err
}
return newIPConn(fd), nil
}
func (sl *sysListener) listenIP(ctx context.Context, laddr *IPAddr) (*IPConn, error) {
network, proto, err := parseNetwork(ctx, sl.network, true)
if err != nil {
return nil, err
}
switch network {
case "ip", "ip4", "ip6":
default:
return nil, UnknownNetworkError(sl.network)
}
var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
if sl.ListenConfig.Control != nil {
ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
return sl.ListenConfig.Control(network, address, c)
}
}
fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_RAW, proto, "listen", ctrlCtxFn)
if err != nil {
return nil, err
}
return newIPConn(fd), nil
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
import (
"context"
"internal/bytealg"
"runtime"
"sync"
)
// BUG(rsc,mikio): On DragonFly BSD and OpenBSD, listening on the
// "tcp" and "udp" networks does not listen for both IPv4 and IPv6
// connections. This is due to the fact that IPv4 traffic will not be
// routed to an IPv6 socket - two separate sockets are required if
// both address families are to be supported.
// See inet6(4) for details.
type ipStackCapabilities struct {
sync.Once // guards following
ipv4Enabled bool
ipv6Enabled bool
ipv4MappedIPv6Enabled bool
}
var ipStackCaps ipStackCapabilities
// supportsIPv4 reports whether the platform supports IPv4 networking
// functionality.
func supportsIPv4() bool {
ipStackCaps.Once.Do(ipStackCaps.probe)
return ipStackCaps.ipv4Enabled
}
// supportsIPv6 reports whether the platform supports IPv6 networking
// functionality.
func supportsIPv6() bool {
ipStackCaps.Once.Do(ipStackCaps.probe)
return ipStackCaps.ipv6Enabled
}
// supportsIPv4map reports whether the platform supports mapping an
// IPv4 address inside an IPv6 address at transport layer
// protocols. See RFC 4291, RFC 4038 and RFC 3493.
func supportsIPv4map() bool {
// Some operating systems provide no support for mapping IPv4
// addresses to IPv6, and a runtime check is unnecessary.
switch runtime.GOOS {
case "dragonfly", "openbsd":
return false
}
ipStackCaps.Once.Do(ipStackCaps.probe)
return ipStackCaps.ipv4MappedIPv6Enabled
}
// An addrList represents a list of network endpoint addresses.
type addrList []Addr
// isIPv4 reports whether addr contains an IPv4 address.
func isIPv4(addr Addr) bool {
switch addr := addr.(type) {
case *TCPAddr:
return addr.IP.To4() != nil
case *UDPAddr:
return addr.IP.To4() != nil
case *IPAddr:
return addr.IP.To4() != nil
}
return false
}
// isNotIPv4 reports whether addr does not contain an IPv4 address.
func isNotIPv4(addr Addr) bool { return !isIPv4(addr) }
// forResolve returns the most appropriate address in address for
// a call to ResolveTCPAddr, ResolveUDPAddr, or ResolveIPAddr.
// IPv4 is preferred, unless addr contains an IPv6 literal.
func (addrs addrList) forResolve(network, addr string) Addr {
var want6 bool
switch network {
case "ip":
// IPv6 literal (addr does NOT contain a port)
want6 = count(addr, ':') > 0
case "tcp", "udp":
// IPv6 literal. (addr contains a port, so look for '[')
want6 = count(addr, '[') > 0
}
if want6 {
return addrs.first(isNotIPv4)
}
return addrs.first(isIPv4)
}
// first returns the first address which satisfies strategy, or if
// none do, then the first address of any kind.
func (addrs addrList) first(strategy func(Addr) bool) Addr {
for _, addr := range addrs {
if strategy(addr) {
return addr
}
}
return addrs[0]
}
// partition divides an address list into two categories, using a
// strategy function to assign a boolean label to each address.
// The first address, and any with a matching label, are returned as
// primaries, while addresses with the opposite label are returned
// as fallbacks. For non-empty inputs, primaries is guaranteed to be
// non-empty.
func (addrs addrList) partition(strategy func(Addr) bool) (primaries, fallbacks addrList) {
var primaryLabel bool
for i, addr := range addrs {
label := strategy(addr)
if i == 0 || label == primaryLabel {
primaryLabel = label
primaries = append(primaries, addr)
} else {
fallbacks = append(fallbacks, addr)
}
}
return
}
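// For illustration (not part of the original source): a dialer that prefers
// the first address's family can split a resolved list into primaries and
// fallbacks in one pass.
//
//	primaries, fallbacks := addrs.partition(isIPv4)
//	// primaries share the first address's label; fallbacks hold the rest.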
// filterAddrList applies a filter to a list of IP addresses,
// yielding a list of Addr objects. Known filters are nil, ipv4only,
// and ipv6only. It returns every address when the filter is nil.
// The result contains at least one address when error is nil.
func filterAddrList(filter func(IPAddr) bool, ips []IPAddr, inetaddr func(IPAddr) Addr, originalAddr string) (addrList, error) {
var addrs addrList
for _, ip := range ips {
if filter == nil || filter(ip) {
addrs = append(addrs, inetaddr(ip))
}
}
if len(addrs) == 0 {
return nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: originalAddr}
}
return addrs, nil
}
// ipv4only reports whether addr is an IPv4 address.
func ipv4only(addr IPAddr) bool {
return addr.IP.To4() != nil
}
// ipv6only reports whether addr is an IPv6 address, excluding IPv4-mapped IPv6 addresses.
func ipv6only(addr IPAddr) bool {
return len(addr.IP) == IPv6len && addr.IP.To4() == nil
}
// SplitHostPort splits a network address of the form "host:port",
// "host%zone:port", "[host]:port" or "[host%zone]:port" into host or
// host%zone and port.
//
// A literal IPv6 address in hostport must be enclosed in square
// brackets, as in "[::1]:80", "[::1%lo0]:80".
//
// See func Dial for a description of the hostport parameter, and host
// and port results.
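//
// A short usage sketch:
//
//	host, port, err := net.SplitHostPort("[::1%lo0]:80")
//	// host == "::1%lo0", port == "80", err == nil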
func SplitHostPort(hostport string) (host, port string, err error) {
const (
missingPort = "missing port in address"
tooManyColons = "too many colons in address"
)
addrErr := func(addr, why string) (host, port string, err error) {
return "", "", &AddrError{Err: why, Addr: addr}
}
j, k := 0, 0
// The port starts after the last colon.
i := last(hostport, ':')
if i < 0 {
return addrErr(hostport, missingPort)
}
if hostport[0] == '[' {
// Expect the first ']' just before the last ':'.
end := bytealg.IndexByteString(hostport, ']')
if end < 0 {
return addrErr(hostport, "missing ']' in address")
}
switch end + 1 {
case len(hostport):
// There can't be a ':' behind the ']' now.
return addrErr(hostport, missingPort)
case i:
// The expected result.
default:
// Either ']' isn't followed by a colon, or it is
// followed by a colon that is not the last one.
if hostport[end+1] == ':' {
return addrErr(hostport, tooManyColons)
}
return addrErr(hostport, missingPort)
}
host = hostport[1:end]
j, k = 1, end+1 // there can't be a '[' or ']', respectively, before these positions
} else {
host = hostport[:i]
if bytealg.IndexByteString(host, ':') >= 0 {
return addrErr(hostport, tooManyColons)
}
}
if bytealg.IndexByteString(hostport[j:], '[') >= 0 {
return addrErr(hostport, "unexpected '[' in address")
}
if bytealg.IndexByteString(hostport[k:], ']') >= 0 {
return addrErr(hostport, "unexpected ']' in address")
}
port = hostport[i+1:]
return host, port, nil
}
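// splitHostZone splits s into a host and an IPv6 scoped addressing
// zone, if any. For illustration:
//
//	host, zone := splitHostZone("fe80::1%en0")
//	// host == "fe80::1", zone == "en0"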
func splitHostZone(s string) (host, zone string) {
// The IPv6 scoped addressing zone identifier starts after the
// last percent sign.
if i := last(s, '%'); i > 0 {
host, zone = s[:i], s[i+1:]
} else {
host = s
}
return
}
// JoinHostPort combines host and port into a network address of the
// form "host:port". If host contains a colon, as found in literal
// IPv6 addresses, then JoinHostPort returns "[host]:port".
//
// See func Dial for a description of the host and port parameters.
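//
// A short usage sketch:
//
//	net.JoinHostPort("::1", "80")       // "[::1]:80"
//	net.JoinHostPort("127.0.0.1", "80") // "127.0.0.1:80"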
func JoinHostPort(host, port string) string {
// We assume that host is a literal IPv6 address if host has
// colons.
if bytealg.IndexByteString(host, ':') >= 0 {
return "[" + host + "]:" + port
}
return host + ":" + port
}
// internetAddrList resolves addr, which may be a literal IP
// address or a DNS name, and returns a list of internet protocol
// family addresses. The result contains at least one address when
// error is nil.
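//
// For illustration, internetAddrList(ctx, "tcp", "www.example.com:http")
// resolves both the host name and the service name, and returns one
// *TCPAddr per resolved IP address.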
func (r *Resolver) internetAddrList(ctx context.Context, net, addr string) (addrList, error) {
var (
err error
host, port string
portnum int
)
switch net {
case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
if addr != "" {
if host, port, err = SplitHostPort(addr); err != nil {
return nil, err
}
if portnum, err = r.LookupPort(ctx, net, port); err != nil {
return nil, err
}
}
case "ip", "ip4", "ip6":
if addr != "" {
host = addr
}
default:
return nil, UnknownNetworkError(net)
}
inetaddr := func(ip IPAddr) Addr {
switch net {
case "tcp", "tcp4", "tcp6":
return &TCPAddr{IP: ip.IP, Port: portnum, Zone: ip.Zone}
case "udp", "udp4", "udp6":
return &UDPAddr{IP: ip.IP, Port: portnum, Zone: ip.Zone}
case "ip", "ip4", "ip6":
return &IPAddr{IP: ip.IP, Zone: ip.Zone}
default:
panic("unexpected network: " + net)
}
}
if host == "" {
return addrList{inetaddr(IPAddr{})}, nil
}
// Try as a literal IP address, then as a DNS name.
ips, err := r.lookupIPAddr(ctx, net, host)
if err != nil {
return nil, err
}
// Issue 18806: if the machine has halfway configured
// IPv6 such that it can bind on "::" (IPv6unspecified)
// but not connect back to that same address, fall
// back to dialing 0.0.0.0.
if len(ips) == 1 && ips[0].IP.Equal(IPv6unspecified) {
ips = append(ips, IPAddr{IP: IPv4zero})
}
var filter func(IPAddr) bool
if net != "" && net[len(net)-1] == '4' {
filter = ipv4only
}
if net != "" && net[len(net)-1] == '6' {
filter = ipv6only
}
return filterAddrList(filter, ips, inetaddr, host)
}
func loopbackIP(net string) IP {
if net != "" && net[len(net)-1] == '6' {
return IPv6loopback
}
return IP{127, 0, 0, 1}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || (js && wasm) || windows
package net
import (
"context"
"internal/poll"
"net/netip"
"runtime"
"syscall"
)
// probe probes IPv4, IPv6 and IPv4-mapped IPv6 communication
// capabilities which are controlled by the IPV6_V6ONLY socket option
// and kernel configuration.
//
// Should we try to use the IPv4 socket interface if we're only
// dealing with IPv4 sockets? As long as the host system understands
// IPv4-mapped IPv6, it's okay to pass IPv4-mapped IPv6 addresses to
// the IPv6 interface. That simplifies our code and is most
// general. Unfortunately, we need to run on kernels built without
// IPv6 support too. So probe the kernel to figure it out.
func (p *ipStackCapabilities) probe() {
s, err := sysSocket(syscall.AF_INET, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
switch err {
case syscall.EAFNOSUPPORT, syscall.EPROTONOSUPPORT:
case nil:
poll.CloseFunc(s)
p.ipv4Enabled = true
}
var probes = []struct {
laddr TCPAddr
value int
}{
// IPv6 communication capability
{laddr: TCPAddr{IP: ParseIP("::1")}, value: 1},
// IPv4-mapped IPv6 address communication capability
{laddr: TCPAddr{IP: IPv4(127, 0, 0, 1)}, value: 0},
}
switch runtime.GOOS {
case "dragonfly", "openbsd":
// The latest DragonFly BSD and OpenBSD kernels don't
// support IPV6_V6ONLY=0. They always return an error
// and we don't need to probe the capability.
probes = probes[:1]
}
for i := range probes {
s, err := sysSocket(syscall.AF_INET6, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
if err != nil {
continue
}
defer poll.CloseFunc(s)
syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, probes[i].value)
sa, err := probes[i].laddr.sockaddr(syscall.AF_INET6)
if err != nil {
continue
}
if err := syscall.Bind(s, sa); err != nil {
continue
}
if i == 0 {
p.ipv6Enabled = true
} else {
p.ipv4MappedIPv6Enabled = true
}
}
}
// favoriteAddrFamily returns the appropriate address family for the
// given network, laddr, raddr and mode.
//
// If mode indicates "listen" and laddr is a wildcard, we assume that
// the user wants to make a passive-open connection with a wildcard
// address family, both AF_INET and AF_INET6, and a wildcard address
// like the following:
//
// - A listen for a wildcard communication domain, "tcp" or
// "udp", with a wildcard address: If the platform supports
// both IPv6 and IPv4-mapped IPv6 communication capabilities,
// or does not support IPv4, we use a dual stack, AF_INET6 and
// IPV6_V6ONLY=0, wildcard address listen. The dual stack
// wildcard address listen may fall back to an IPv6-only,
// AF_INET6 and IPV6_V6ONLY=1, wildcard address listen.
// Otherwise we prefer an IPv4-only, AF_INET, wildcard address
// listen.
//
// - A listen for a wildcard communication domain, "tcp" or
// "udp", with an IPv4 wildcard address: same as above.
//
// - A listen for a wildcard communication domain, "tcp" or
// "udp", with an IPv6 wildcard address: same as above.
//
// - A listen for an IPv4 communication domain, "tcp4" or "udp4",
// with an IPv4 wildcard address: We use an IPv4-only, AF_INET,
// wildcard address listen.
//
// - A listen for an IPv6 communication domain, "tcp6" or "udp6",
// with an IPv6 wildcard address: We use an IPv6-only, AF_INET6
// and IPV6_V6ONLY=1, wildcard address listen.
//
// Otherwise guess: if the addresses are IPv4 then it returns AF_INET,
// or else it returns AF_INET6. It also returns a boolean value that
// reports whether the IPV6_V6ONLY socket option should be set.
//
// Note that the latest DragonFly BSD and OpenBSD kernels allow
// neither "net.inet6.ip6.v6only=1" change nor IPPROTO_IPV6 level
// IPV6_V6ONLY socket option setting.
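//
// For illustration:
//
//	favoriteAddrFamily("tcp4", nil, nil, "dial")   // AF_INET, false
//	favoriteAddrFamily("udp6", nil, nil, "listen") // AF_INET6, true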
func favoriteAddrFamily(network string, laddr, raddr sockaddr, mode string) (family int, ipv6only bool) {
switch network[len(network)-1] {
case '4':
return syscall.AF_INET, false
case '6':
return syscall.AF_INET6, true
}
if mode == "listen" && (laddr == nil || laddr.isWildcard()) {
if supportsIPv4map() || !supportsIPv4() {
return syscall.AF_INET6, false
}
if laddr == nil {
return syscall.AF_INET, false
}
return laddr.family(), false
}
if (laddr == nil || laddr.family() == syscall.AF_INET) &&
(raddr == nil || raddr.family() == syscall.AF_INET) {
return syscall.AF_INET, false
}
return syscall.AF_INET6, false
}
func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, sotype, proto int, mode string, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (fd *netFD, err error) {
if (runtime.GOOS == "aix" || runtime.GOOS == "windows" || runtime.GOOS == "openbsd") && mode == "dial" && raddr.isWildcard() {
raddr = raddr.toLocal(net)
}
family, ipv6only := favoriteAddrFamily(net, laddr, raddr, mode)
return socket(ctx, net, family, sotype, proto, ipv6only, laddr, raddr, ctrlCtxFn)
}
func ipToSockaddrInet4(ip IP, port int) (syscall.SockaddrInet4, error) {
if len(ip) == 0 {
ip = IPv4zero
}
ip4 := ip.To4()
if ip4 == nil {
return syscall.SockaddrInet4{}, &AddrError{Err: "non-IPv4 address", Addr: ip.String()}
}
sa := syscall.SockaddrInet4{Port: port}
copy(sa.Addr[:], ip4)
return sa, nil
}
func ipToSockaddrInet6(ip IP, port int, zone string) (syscall.SockaddrInet6, error) {
// In general, an IP wildcard address, which is either
// "0.0.0.0" or "::", means the entire IP addressing
// space. For some historical reason, it is used to
// specify "any available address" on some operations
// of IP node.
//
// When the IP node supports IPv4-mapped IPv6 address,
// we allow a listener to listen to the wildcard
// address of both IP addressing spaces by specifying
// IPv6 wildcard address.
if len(ip) == 0 || ip.Equal(IPv4zero) {
ip = IPv6zero
}
// We accept any IPv6 address including IPv4-mapped
// IPv6 address.
ip6 := ip.To16()
if ip6 == nil {
return syscall.SockaddrInet6{}, &AddrError{Err: "non-IPv6 address", Addr: ip.String()}
}
sa := syscall.SockaddrInet6{Port: port, ZoneId: uint32(zoneCache.index(zone))}
copy(sa.Addr[:], ip6)
return sa, nil
}
func ipToSockaddr(family int, ip IP, port int, zone string) (syscall.Sockaddr, error) {
switch family {
case syscall.AF_INET:
sa, err := ipToSockaddrInet4(ip, port)
if err != nil {
return nil, err
}
return &sa, nil
case syscall.AF_INET6:
sa, err := ipToSockaddrInet6(ip, port, zone)
if err != nil {
return nil, err
}
return &sa, nil
}
return nil, &AddrError{Err: "invalid address family", Addr: ip.String()}
}
func addrPortToSockaddrInet4(ap netip.AddrPort) (syscall.SockaddrInet4, error) {
// ipToSockaddrInet4 has special handling for zero-length slices.
// We do not, because netip has no concept of a generic zero IP address.
addr := ap.Addr()
if !addr.Is4() {
return syscall.SockaddrInet4{}, &AddrError{Err: "non-IPv4 address", Addr: addr.String()}
}
sa := syscall.SockaddrInet4{
Addr: addr.As4(),
Port: int(ap.Port()),
}
return sa, nil
}
func addrPortToSockaddrInet6(ap netip.AddrPort) (syscall.SockaddrInet6, error) {
// ipToSockaddrInet6 has special handling for zero-length slices.
// We do not, because netip has no concept of a generic zero IP address.
//
// addr is allowed to be an IPv4 address, because As16 will convert it
// to an IPv4-mapped IPv6 address.
// The error message is kept consistent with ipToSockaddrInet6.
addr := ap.Addr()
if !addr.IsValid() {
return syscall.SockaddrInet6{}, &AddrError{Err: "non-IPv6 address", Addr: addr.String()}
}
sa := syscall.SockaddrInet6{
Addr: addr.As16(),
Port: int(ap.Port()),
ZoneId: uint32(zoneCache.index(addr.Zone())),
}
return sa, nil
}
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
import (
"context"
"errors"
"internal/nettrace"
"internal/singleflight"
"net/netip"
"sync"
"golang.org/x/net/dns/dnsmessage"
)
// protocols contains minimal mappings between internet protocol
// names and numbers for platforms that don't have a complete list of
// protocol numbers.
//
// See https://www.iana.org/assignments/protocol-numbers
//
// On Unix, this map is augmented by readProtocols via lookupProtocol.
var protocols = map[string]int{
"icmp": 1,
"igmp": 2,
"tcp": 6,
"udp": 17,
"ipv6-icmp": 58,
}
// services contains minimal mappings between services names and port
// numbers for platforms that don't have a complete list of port numbers.
//
// See https://www.iana.org/assignments/service-names-port-numbers
//
// On Unix, this map is augmented by readServices via goLookupPort.
var services = map[string]map[string]int{
"udp": {
"domain": 53,
},
"tcp": {
"ftp": 21,
"ftps": 990,
"gopher": 70, // ʕ◔ϖ◔ʔ
"http": 80,
"https": 443,
"imap2": 143,
"imap3": 220,
"imaps": 993,
"pop3": 110,
"pop3s": 995,
"smtp": 25,
"ssh": 22,
"telnet": 23,
},
}
// dnsWaitGroup can be used by tests to wait for all DNS goroutines to
// complete. This avoids races on the test hooks.
var dnsWaitGroup sync.WaitGroup
const maxProtoLength = len("RSVP-E2E-IGNORE") + 10 // with room to grow
func lookupProtocolMap(name string) (int, error) {
var lowerProtocol [maxProtoLength]byte
n := copy(lowerProtocol[:], name)
lowerASCIIBytes(lowerProtocol[:n])
proto, found := protocols[string(lowerProtocol[:n])]
if !found || n != len(name) {
return 0, &AddrError{Err: "unknown IP protocol specified", Addr: name}
}
return proto, nil
}
// maxPortBufSize is the size of a buffer that can hold the longest
// reasonable name of a service (non-numeric port).
// Currently the longest known IANA-unregistered name is
// "mobility-header", so we use that length, plus some slop in case
// something longer is added in the future.
const maxPortBufSize = len("mobility-header") + 10
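// lookupPortMap looks up service in the baked-in services map for the
// given network. The match is ASCII case-insensitive, so, for
// illustration:
//
//	port, err := lookupPortMap("tcp4", "HTTP")
//	// port == 80, err == nil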
func lookupPortMap(network, service string) (port int, err error) {
switch network {
case "tcp4", "tcp6":
network = "tcp"
case "udp4", "udp6":
network = "udp"
}
if m, ok := services[network]; ok {
var lowerService [maxPortBufSize]byte
n := copy(lowerService[:], service)
lowerASCIIBytes(lowerService[:n])
if port, ok := m[string(lowerService[:n])]; ok && n == len(service) {
return port, nil
}
}
return 0, &AddrError{Err: "unknown port", Addr: network + "/" + service}
}
// ipVersion returns the provided network's IP version: '4', '6' or 0
// if network does not end in a '4' or '6' byte.
func ipVersion(network string) byte {
if network == "" {
return 0
}
n := network[len(network)-1]
if n != '4' && n != '6' {
n = 0
}
return n
}
// DefaultResolver is the resolver used by the package-level Lookup
// functions and by Dialers without a specified Resolver.
var DefaultResolver = &Resolver{}
// A Resolver looks up names and numbers.
//
// A nil *Resolver is equivalent to a zero Resolver.
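//
// A minimal sketch of a Resolver that forces the Go resolver and sends
// all queries to a single server (the server address below is purely
// illustrative):
//
//	r := &net.Resolver{
//		PreferGo: true,
//		Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
//			var d net.Dialer
//			return d.DialContext(ctx, network, "192.0.2.53:53")
//		},
//	}
//	addrs, err := r.LookupHost(context.Background(), "example.com")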
type Resolver struct {
// PreferGo controls whether Go's built-in DNS resolver is preferred
// on platforms where it's available. It is equivalent to setting
// GODEBUG=netdns=go, but scoped to just this resolver.
PreferGo bool
// StrictErrors controls the behavior of temporary errors
// (including timeout, socket errors, and SERVFAIL) when using
// Go's built-in resolver. For a query composed of multiple
// sub-queries (such as an A+AAAA address lookup, or walking the
// DNS search list), this option causes such errors to abort the
// whole query instead of returning a partial result. This is
// not enabled by default because it may affect compatibility
// with resolvers that process AAAA queries incorrectly.
StrictErrors bool
// Dial optionally specifies an alternate dialer for use by
// Go's built-in DNS resolver to make TCP and UDP connections
// to DNS services. The host in the address parameter will
// always be a literal IP address and not a host name, and the
// port in the address parameter will be a literal port number
// and not a service name.
// If the Conn returned is also a PacketConn, sent and received DNS
// messages must adhere to RFC 1035 section 4.2.1, "UDP usage".
// Otherwise, DNS messages transmitted over Conn must adhere
// to RFC 7766 section 5, "Transport Protocol Selection".
// If nil, the default dialer is used.
Dial func(ctx context.Context, network, address string) (Conn, error)
// lookupGroup merges LookupIPAddr calls together for lookups for the same
// host. The lookupGroup key is the LookupIPAddr.host argument.
// The return values are ([]IPAddr, error).
lookupGroup singleflight.Group
// TODO(bradfitz): optional interface impl override hook
// TODO(bradfitz): Timeout time.Duration?
}
func (r *Resolver) preferGo() bool { return r != nil && r.PreferGo }
func (r *Resolver) strictErrors() bool { return r != nil && r.StrictErrors }
func (r *Resolver) getLookupGroup() *singleflight.Group {
if r == nil {
return &DefaultResolver.lookupGroup
}
return &r.lookupGroup
}
// LookupHost looks up the given host using the local resolver.
// It returns a slice of that host's addresses.
//
// LookupHost uses context.Background internally; to specify the context, use
// Resolver.LookupHost.
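//
// A short usage sketch:
//
//	addrs, err := net.LookupHost("example.com")
//	// addrs holds the host's IPv4 and IPv6 addresses as strings,
//	// depending on the local resolver's answers.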
func LookupHost(host string) (addrs []string, err error) {
return DefaultResolver.LookupHost(context.Background(), host)
}
// LookupHost looks up the given host using the local resolver.
// It returns a slice of that host's addresses.
func (r *Resolver) LookupHost(ctx context.Context, host string) (addrs []string, err error) {
// Make sure that no matter what we do later, host=="" is rejected.
if host == "" {
return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
}
if _, err := netip.ParseAddr(host); err == nil {
return []string{host}, nil
}
return r.lookupHost(ctx, host)
}
// LookupIP looks up host using the local resolver.
// It returns a slice of that host's IPv4 and IPv6 addresses.
func LookupIP(host string) ([]IP, error) {
addrs, err := DefaultResolver.LookupIPAddr(context.Background(), host)
if err != nil {
return nil, err
}
ips := make([]IP, len(addrs))
for i, ia := range addrs {
ips[i] = ia.IP
}
return ips, nil
}
// LookupIPAddr looks up host using the local resolver.
// It returns a slice of that host's IPv4 and IPv6 addresses.
func (r *Resolver) LookupIPAddr(ctx context.Context, host string) ([]IPAddr, error) {
return r.lookupIPAddr(ctx, "ip", host)
}
// LookupIP looks up host for the given network using the local resolver.
// It returns a slice of that host's IP addresses of the type specified by
// network.
// network must be one of "ip", "ip4" or "ip6".
func (r *Resolver) LookupIP(ctx context.Context, network, host string) ([]IP, error) {
afnet, _, err := parseNetwork(ctx, network, false)
if err != nil {
return nil, err
}
switch afnet {
case "ip", "ip4", "ip6":
default:
return nil, UnknownNetworkError(network)
}
if host == "" {
return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
}
addrs, err := r.internetAddrList(ctx, afnet, host)
if err != nil {
return nil, err
}
ips := make([]IP, 0, len(addrs))
for _, addr := range addrs {
ips = append(ips, addr.(*IPAddr).IP)
}
return ips, nil
}
// LookupNetIP looks up host using the local resolver.
// It returns a slice of that host's IP addresses of the type specified by
// network.
// The network must be one of "ip", "ip4" or "ip6".
func (r *Resolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
// TODO(bradfitz): make this efficient, making the internal net package
// type throughout be netip.Addr and only converting to the net.IP slice
// version at the edge. But for now (2021-10-20), this is a wrapper around
// the old way.
ips, err := r.LookupIP(ctx, network, host)
if err != nil {
return nil, err
}
ret := make([]netip.Addr, 0, len(ips))
for _, ip := range ips {
if a, ok := netip.AddrFromSlice(ip); ok {
ret = append(ret, a)
}
}
return ret, nil
}
// onlyValuesCtx is a context that uses an underlying context
// for value lookup if the underlying context hasn't yet expired.
type onlyValuesCtx struct {
context.Context
lookupValues context.Context
}
var _ context.Context = (*onlyValuesCtx)(nil)
// Value performs a lookup if the original context hasn't expired.
func (ovc *onlyValuesCtx) Value(key any) any {
select {
case <-ovc.lookupValues.Done():
return nil
default:
return ovc.lookupValues.Value(key)
}
}
// withUnexpiredValuesPreserved returns a context.Context that uses
// lookupCtx for its values only; the returned context is never canceled
// and has no deadline.
// If the lookup context expires, any looked up values will return nil.
// See Issue 28600.
func withUnexpiredValuesPreserved(lookupCtx context.Context) context.Context {
return &onlyValuesCtx{Context: context.Background(), lookupValues: lookupCtx}
}
// lookupIPAddr looks up host using the local resolver and particular network.
// It returns a slice of that host's IPv4 and IPv6 addresses.
func (r *Resolver) lookupIPAddr(ctx context.Context, network, host string) ([]IPAddr, error) {
// Make sure that no matter what we do later, host=="" is rejected.
if host == "" {
return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
}
if ip, err := netip.ParseAddr(host); err == nil {
return []IPAddr{{IP: IP(ip.AsSlice()).To16(), Zone: ip.Zone()}}, nil
}
trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace)
if trace != nil && trace.DNSStart != nil {
trace.DNSStart(host)
}
// The underlying resolver func is lookupIP by default but it
// can be overridden by tests. This is needed by net/http, so it
// uses a context key instead of unexported variables.
resolverFunc := r.lookupIP
if alt, _ := ctx.Value(nettrace.LookupIPAltResolverKey{}).(func(context.Context, string, string) ([]IPAddr, error)); alt != nil {
resolverFunc = alt
}
// We don't want a cancellation of ctx to affect the
// lookupGroup operation. Otherwise if our context gets
// canceled it might cause an error to be returned to a lookup
// using a completely different context. However we need to preserve
// only the values in context. See Issue 28600.
lookupGroupCtx, lookupGroupCancel := context.WithCancel(withUnexpiredValuesPreserved(ctx))
lookupKey := network + "\000" + host
dnsWaitGroup.Add(1)
ch := r.getLookupGroup().DoChan(lookupKey, func() (any, error) {
return testHookLookupIP(lookupGroupCtx, resolverFunc, network, host)
})
dnsWaitGroupDone := func(ch <-chan singleflight.Result, cancelFn context.CancelFunc) {
<-ch
dnsWaitGroup.Done()
cancelFn()
}
select {
case <-ctx.Done():
// Our context was canceled. If we are the only
// goroutine looking up this key, then drop the key
// from the lookupGroup and cancel the lookup.
// If there are other goroutines looking up this key,
// let the lookup continue uncanceled, and let later
// lookups with the same key share the result.
// See issues 8602, 20703, 22724.
if r.getLookupGroup().ForgetUnshared(lookupKey) {
lookupGroupCancel()
go dnsWaitGroupDone(ch, func() {})
} else {
go dnsWaitGroupDone(ch, lookupGroupCancel)
}
ctxErr := ctx.Err()
err := &DNSError{
Err: mapErr(ctxErr).Error(),
Name: host,
IsTimeout: ctxErr == context.DeadlineExceeded,
}
if trace != nil && trace.DNSDone != nil {
trace.DNSDone(nil, false, err)
}
return nil, err
case r := <-ch:
dnsWaitGroup.Done()
lookupGroupCancel()
err := r.Err
if err != nil {
if _, ok := err.(*DNSError); !ok {
isTimeout := false
if err == context.DeadlineExceeded {
isTimeout = true
} else if terr, ok := err.(timeout); ok {
isTimeout = terr.Timeout()
}
err = &DNSError{
Err: err.Error(),
Name: host,
IsTimeout: isTimeout,
}
}
}
if trace != nil && trace.DNSDone != nil {
addrs, _ := r.Val.([]IPAddr)
trace.DNSDone(ipAddrsEface(addrs), r.Shared, err)
}
return lookupIPReturn(r.Val, err, r.Shared)
}
}
// lookupIPReturn turns the return values from singleflight.Do into
// the return values from LookupIP.
func lookupIPReturn(addrsi any, err error, shared bool) ([]IPAddr, error) {
if err != nil {
return nil, err
}
addrs := addrsi.([]IPAddr)
if shared {
clone := make([]IPAddr, len(addrs))
copy(clone, addrs)
addrs = clone
}
return addrs, nil
}
// ipAddrsEface returns addrs as a slice of empty interface values.
func ipAddrsEface(addrs []IPAddr) []any {
s := make([]any, len(addrs))
for i, v := range addrs {
s[i] = v
}
return s
}
// LookupPort looks up the port for the given network and service.
//
// LookupPort uses context.Background internally; to specify the context, use
// Resolver.LookupPort.
func LookupPort(network, service string) (port int, err error) {
return DefaultResolver.LookupPort(context.Background(), network, service)
}
// LookupPort looks up the port for the given network and service.
func (r *Resolver) LookupPort(ctx context.Context, network, service string) (port int, err error) {
port, needsLookup := parsePort(service)
if needsLookup {
switch network {
case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
case "": // a hint wildcard for Go 1.0 undocumented behavior
network = "ip"
default:
return 0, &AddrError{Err: "unknown network", Addr: network}
}
port, err = r.lookupPort(ctx, network, service)
if err != nil {
return 0, err
}
}
if 0 > port || port > 65535 {
return 0, &AddrError{Err: "invalid port", Addr: service}
}
return port, nil
}
// LookupCNAME returns the canonical name for the given host.
// Callers that do not care about the canonical name can call
// LookupHost or LookupIP directly; both take care of resolving
// the canonical name as part of the lookup.
//
// A canonical name is the final name after following zero
// or more CNAME records.
// LookupCNAME does not return an error if host does not
// contain DNS "CNAME" records, as long as host resolves to
// address records.
//
// The returned canonical name is validated to be a properly
// formatted presentation-format domain name.
//
// LookupCNAME uses context.Background internally; to specify the context, use
// Resolver.LookupCNAME.
func LookupCNAME(host string) (cname string, err error) {
return DefaultResolver.LookupCNAME(context.Background(), host)
}
// LookupCNAME returns the canonical name for the given host.
// Callers that do not care about the canonical name can call
// LookupHost or LookupIP directly; both take care of resolving
// the canonical name as part of the lookup.
//
// A canonical name is the final name after following zero
// or more CNAME records.
// LookupCNAME does not return an error if host does not
// contain DNS "CNAME" records, as long as host resolves to
// address records.
//
// The returned canonical name is validated to be a properly
// formatted presentation-format domain name.
func (r *Resolver) LookupCNAME(ctx context.Context, host string) (string, error) {
cname, err := r.lookupCNAME(ctx, host)
if err != nil {
return "", err
}
if !isDomainName(cname) {
return "", &DNSError{Err: errMalformedDNSRecordsDetail, Name: host}
}
return cname, nil
}
// LookupSRV tries to resolve an SRV query of the given service,
// protocol, and domain name. The proto is "tcp" or "udp".
// The returned records are sorted by priority and randomized
// by weight within a priority.
//
// LookupSRV constructs the DNS name to look up following RFC 2782.
// That is, it looks up _service._proto.name. To accommodate services
// publishing SRV records under non-standard names, if both service
// and proto are empty strings, LookupSRV looks up name directly.
//
// The returned service names are validated to be properly
// formatted presentation-format domain names. If the response contains
// invalid names, those records are filtered out and an error
// will be returned alongside the remaining results, if any.
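//
// For illustration, the following call looks up the name
// _sip._tcp.example.com:
//
//	cname, srvs, err := net.LookupSRV("sip", "tcp", "example.com")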
func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err error) {
return DefaultResolver.LookupSRV(context.Background(), service, proto, name)
}
// LookupSRV tries to resolve an SRV query of the given service,
// protocol, and domain name. The proto is "tcp" or "udp".
// The returned records are sorted by priority and randomized
// by weight within a priority.
//
// LookupSRV constructs the DNS name to look up following RFC 2782.
// That is, it looks up _service._proto.name. To accommodate services
// publishing SRV records under non-standard names, if both service
// and proto are empty strings, LookupSRV looks up name directly.
//
// The returned service names are validated to be properly
// formatted presentation-format domain names. If the response contains
// invalid names, those records are filtered out and an error
// will be returned alongside the remaining results, if any.
func (r *Resolver) LookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) {
cname, addrs, err := r.lookupSRV(ctx, service, proto, name)
if err != nil {
return "", nil, err
}
if cname != "" && !isDomainName(cname) {
return "", nil, &DNSError{Err: "SRV header name is invalid", Name: name}
}
filteredAddrs := make([]*SRV, 0, len(addrs))
for _, addr := range addrs {
if addr == nil {
continue
}
if !isDomainName(addr.Target) {
continue
}
filteredAddrs = append(filteredAddrs, addr)
}
if len(addrs) != len(filteredAddrs) {
return cname, filteredAddrs, &DNSError{Err: errMalformedDNSRecordsDetail, Name: name}
}
return cname, filteredAddrs, nil
}
// LookupMX returns the DNS MX records for the given domain name sorted by preference.
//
// The returned mail server names are validated to be properly
// formatted presentation-format domain names. If the response contains
// invalid names, those records are filtered out and an error
// will be returned alongside the remaining results, if any.
//
// LookupMX uses context.Background internally; to specify the context, use
// Resolver.LookupMX.
func LookupMX(name string) ([]*MX, error) {
return DefaultResolver.LookupMX(context.Background(), name)
}
// LookupMX returns the DNS MX records for the given domain name sorted by preference.
//
// The returned mail server names are validated to be properly
// formatted presentation-format domain names. If the response contains
// invalid names, those records are filtered out and an error
// will be returned alongside the remaining results, if any.
func (r *Resolver) LookupMX(ctx context.Context, name string) ([]*MX, error) {
records, err := r.lookupMX(ctx, name)
if err != nil {
return nil, err
}
filteredMX := make([]*MX, 0, len(records))
for _, mx := range records {
if mx == nil {
continue
}
if !isDomainName(mx.Host) {
continue
}
filteredMX = append(filteredMX, mx)
}
if len(records) != len(filteredMX) {
return filteredMX, &DNSError{Err: errMalformedDNSRecordsDetail, Name: name}
}
return filteredMX, nil
}
// LookupNS returns the DNS NS records for the given domain name.
//
// The returned name server names are validated to be properly
// formatted presentation-format domain names. If the response contains
// invalid names, those records are filtered out and an error
// will be returned alongside the remaining results, if any.
//
// LookupNS uses context.Background internally; to specify the context, use
// Resolver.LookupNS.
func LookupNS(name string) ([]*NS, error) {
return DefaultResolver.LookupNS(context.Background(), name)
}
// LookupNS returns the DNS NS records for the given domain name.
//
// The returned name server names are validated to be properly
// formatted presentation-format domain names. If the response contains
// invalid names, those records are filtered out and an error
// will be returned alongside the remaining results, if any.
func (r *Resolver) LookupNS(ctx context.Context, name string) ([]*NS, error) {
records, err := r.lookupNS(ctx, name)
if err != nil {
return nil, err
}
filteredNS := make([]*NS, 0, len(records))
for _, ns := range records {
if ns == nil {
continue
}
if !isDomainName(ns.Host) {
continue
}
filteredNS = append(filteredNS, ns)
}
if len(records) != len(filteredNS) {
return filteredNS, &DNSError{Err: errMalformedDNSRecordsDetail, Name: name}
}
return filteredNS, nil
}
// LookupTXT returns the DNS TXT records for the given domain name.
//
// LookupTXT uses context.Background internally; to specify the context, use
// Resolver.LookupTXT.
func LookupTXT(name string) ([]string, error) {
return DefaultResolver.lookupTXT(context.Background(), name)
}
// LookupTXT returns the DNS TXT records for the given domain name.
func (r *Resolver) LookupTXT(ctx context.Context, name string) ([]string, error) {
return r.lookupTXT(ctx, name)
}
// LookupAddr performs a reverse lookup for the given address, returning a list
// of names mapping to that address.
//
// The returned names are validated to be properly formatted presentation-format
// domain names. If the response contains invalid names, those records are filtered
// out and an error will be returned alongside the remaining results, if any.
//
// When using the host C library resolver, at most one result will be
// returned. To bypass the host resolver, use a custom Resolver.
//
// LookupAddr uses context.Background internally; to specify the context, use
// Resolver.LookupAddr.
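//
// A short usage sketch, using an address from the documentation range:
//
//	names, err := net.LookupAddr("192.0.2.1")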
func LookupAddr(addr string) (names []string, err error) {
return DefaultResolver.LookupAddr(context.Background(), addr)
}
// LookupAddr performs a reverse lookup for the given address, returning a list
// of names mapping to that address.
//
// The returned names are validated to be properly formatted presentation-format
// domain names. If the response contains invalid names, those records are filtered
// out and an error will be returned alongside the remaining results, if any.
func (r *Resolver) LookupAddr(ctx context.Context, addr string) ([]string, error) {
names, err := r.lookupAddr(ctx, addr)
if err != nil {
return nil, err
}
filteredNames := make([]string, 0, len(names))
for _, name := range names {
if isDomainName(name) {
filteredNames = append(filteredNames, name)
}
}
if len(names) != len(filteredNames) {
return filteredNames, &DNSError{Err: errMalformedDNSRecordsDetail, Name: addr}
}
return filteredNames, nil
}
// errMalformedDNSRecordsDetail is the DNSError detail which is returned when a Resolver.Lookup...
// method receives DNS records which contain invalid DNS names. This may be returned alongside
// results which have had the malformed records filtered out.
var errMalformedDNSRecordsDetail = "DNS response contained records which contain invalid names"
// dial makes a new connection to the provided server (which must be
// an IP address) with the provided network type, using either r.Dial
// (if both r and r.Dial are non-nil) or else Dialer.DialContext.
func (r *Resolver) dial(ctx context.Context, network, server string) (Conn, error) {
// Calling Dial here is scary -- we have to be sure not to
// dial a name that will require a DNS lookup, or Dial will
// call back here to translate it. The DNS config parser has
// already checked that all the cfg.servers are IP
// addresses, which Dial will use without a DNS lookup.
var c Conn
var err error
if r != nil && r.Dial != nil {
c, err = r.Dial(ctx, network, server)
} else {
var d Dialer
c, err = d.DialContext(ctx, network, server)
}
if err != nil {
return nil, mapErr(err)
}
return c, nil
}
// goLookupSRV returns the SRV records for a target name, built either
// from its component service ("sip"), protocol ("tcp"), and name
// ("example.com."), or from name directly (if service and proto are
// both empty).
//
// In either case, the returned target name ("_sip._tcp.example.com.")
// is also returned on success.
//
// The records are sorted by weight.
func (r *Resolver) goLookupSRV(ctx context.Context, service, proto, name string) (target string, srvs []*SRV, err error) {
if service == "" && proto == "" {
target = name
} else {
target = "_" + service + "._" + proto + "." + name
}
p, server, err := r.lookup(ctx, target, dnsmessage.TypeSRV, nil)
if err != nil {
return "", nil, err
}
var cname dnsmessage.Name
for {
h, err := p.AnswerHeader()
if err == dnsmessage.ErrSectionDone {
break
}
if err != nil {
return "", nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
if h.Type != dnsmessage.TypeSRV {
if err := p.SkipAnswer(); err != nil {
return "", nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
continue
}
if cname.Length == 0 && h.Name.Length != 0 {
cname = h.Name
}
srv, err := p.SRVResource()
if err != nil {
return "", nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
srvs = append(srvs, &SRV{Target: srv.Target.String(), Port: srv.Port, Priority: srv.Priority, Weight: srv.Weight})
}
byPriorityWeight(srvs).sort()
return cname.String(), srvs, nil
}
// goLookupMX returns the MX records for name.
func (r *Resolver) goLookupMX(ctx context.Context, name string) ([]*MX, error) {
p, server, err := r.lookup(ctx, name, dnsmessage.TypeMX, nil)
if err != nil {
return nil, err
}
var mxs []*MX
for {
h, err := p.AnswerHeader()
if err == dnsmessage.ErrSectionDone {
break
}
if err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
if h.Type != dnsmessage.TypeMX {
if err := p.SkipAnswer(); err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
continue
}
mx, err := p.MXResource()
if err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
mxs = append(mxs, &MX{Host: mx.MX.String(), Pref: mx.Pref})
}
byPref(mxs).sort()
return mxs, nil
}
// goLookupNS returns the NS records for name.
func (r *Resolver) goLookupNS(ctx context.Context, name string) ([]*NS, error) {
p, server, err := r.lookup(ctx, name, dnsmessage.TypeNS, nil)
if err != nil {
return nil, err
}
var nss []*NS
for {
h, err := p.AnswerHeader()
if err == dnsmessage.ErrSectionDone {
break
}
if err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
if h.Type != dnsmessage.TypeNS {
if err := p.SkipAnswer(); err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
continue
}
ns, err := p.NSResource()
if err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
nss = append(nss, &NS{Host: ns.NS.String()})
}
return nss, nil
}
// goLookupTXT returns the TXT records from name.
func (r *Resolver) goLookupTXT(ctx context.Context, name string) ([]string, error) {
p, server, err := r.lookup(ctx, name, dnsmessage.TypeTXT, nil)
if err != nil {
return nil, err
}
var txts []string
for {
h, err := p.AnswerHeader()
if err == dnsmessage.ErrSectionDone {
break
}
if err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
if h.Type != dnsmessage.TypeTXT {
if err := p.SkipAnswer(); err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
continue
}
txt, err := p.TXTResource()
if err != nil {
return nil, &DNSError{
Err: "cannot unmarshal DNS message",
Name: name,
Server: server,
}
}
// Multiple strings in one TXT record need to be
// concatenated without separator to be consistent
// with previous Go resolver.
n := 0
for _, s := range txt.TXT {
n += len(s)
}
txtJoin := make([]byte, 0, n)
for _, s := range txt.TXT {
txtJoin = append(txtJoin, s...)
}
if len(txts) == 0 {
txts = make([]string, 0, 1)
}
txts = append(txts, string(txtJoin))
}
return txts, nil
}
func parseCNAMEFromResources(resources []dnsmessage.Resource) (string, error) {
if len(resources) == 0 {
return "", errors.New("no CNAME record received")
}
c, ok := resources[0].Body.(*dnsmessage.CNAMEResource)
if !ok {
return "", errors.New("could not parse CNAME record")
}
return c.CNAME.String(), nil
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix
package net
import (
"context"
"internal/bytealg"
"sync"
"syscall"
)
var onceReadProtocols sync.Once
// readProtocols loads contents of /etc/protocols into protocols map
// for quick access.
func readProtocols() {
file, err := open("/etc/protocols")
if err != nil {
return
}
defer file.close()
for line, ok := file.readLine(); ok; line, ok = file.readLine() {
// tcp 6 TCP # transmission control protocol
if i := bytealg.IndexByteString(line, '#'); i >= 0 {
line = line[0:i]
}
f := getFields(line)
if len(f) < 2 {
continue
}
if proto, _, ok := dtoi(f[1]); ok {
if _, ok := protocols[f[0]]; !ok {
protocols[f[0]] = proto
}
for _, alias := range f[2:] {
if _, ok := protocols[alias]; !ok {
protocols[alias] = proto
}
}
}
}
}
// lookupProtocol looks up the IP protocol name in /etc/protocols and
// returns the corresponding protocol number.
func lookupProtocol(_ context.Context, name string) (int, error) {
onceReadProtocols.Do(readProtocols)
return lookupProtocolMap(name)
}
func (r *Resolver) lookupHost(ctx context.Context, host string) (addrs []string, err error) {
order, conf := systemConf().hostLookupOrder(r, host)
if !r.preferGo() && order == hostLookupCgo {
if addrs, err, ok := cgoLookupHost(ctx, host); ok {
return addrs, err
}
// cgo not available (or netgo); fall back to Go's DNS resolver
order = hostLookupFilesDNS
}
return r.goLookupHostOrder(ctx, host, order, conf)
}
func (r *Resolver) lookupIP(ctx context.Context, network, host string) (addrs []IPAddr, err error) {
if r.preferGo() {
return r.goLookupIP(ctx, network, host)
}
order, conf := systemConf().hostLookupOrder(r, host)
if order == hostLookupCgo {
if addrs, err, ok := cgoLookupIP(ctx, network, host); ok {
return addrs, err
}
// cgo not available (or netgo); fall back to Go's DNS resolver
order = hostLookupFilesDNS
}
ips, _, err := r.goLookupIPCNAMEOrder(ctx, network, host, order, conf)
return ips, err
}
func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int, error) {
if !r.preferGo() && systemConf().canUseCgo() {
if port, err, ok := cgoLookupPort(ctx, network, service); ok {
if err != nil {
// Issue 18213: if cgo fails, first check to see whether we
// have the answer baked-in to the net package.
if port, err := goLookupPort(network, service); err == nil {
return port, nil
}
}
return port, err
}
}
return goLookupPort(network, service)
}
func (r *Resolver) lookupCNAME(ctx context.Context, name string) (string, error) {
order, conf := systemConf().hostLookupOrder(r, name)
if !r.preferGo() && order == hostLookupCgo {
if cname, err, ok := cgoLookupCNAME(ctx, name); ok {
return cname, err
}
}
return r.goLookupCNAME(ctx, name, order, conf)
}
func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) {
return r.goLookupSRV(ctx, service, proto, name)
}
func (r *Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) {
return r.goLookupMX(ctx, name)
}
func (r *Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) {
return r.goLookupNS(ctx, name)
}
func (r *Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) {
return r.goLookupTXT(ctx, name)
}
func (r *Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error) {
order, conf := systemConf().hostLookupOrder(r, "")
if !r.preferGo() && order == hostLookupCgo {
if ptrs, err, ok := cgoLookupPTR(ctx, addr); ok {
return ptrs, err
}
}
return r.goLookupPTR(ctx, addr, conf)
}
// concurrentThreadsLimit returns the number of threads we permit to
// run concurrently doing DNS lookups via cgo. A DNS lookup may use a
// file descriptor so we limit this to less than the number of
// permitted open files. On some systems, notably Darwin, if
// getaddrinfo is unable to open a file descriptor it simply returns
// EAI_NONAME rather than a useful error. Limiting the number of
// concurrent getaddrinfo calls to less than the permitted number of
// file descriptors makes that error less likely. We don't bother to
// apply the same limit to DNS lookups run directly from Go, because
// there we will return a meaningful "too many open files" error.
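//
// For illustration, RLIMIT_NOFILE values of 1024, 256, and 20 yield
// limits of 500, 226, and 20 respectively.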
func concurrentThreadsLimit() int {
var rlim syscall.Rlimit
if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rlim); err != nil {
return 500
}
r := int(rlim.Cur)
if r > 500 {
r = 500
} else if r > 30 {
r -= 30
}
return r
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
const hexDigit = "0123456789abcdef"
// A HardwareAddr represents a physical hardware address.
type HardwareAddr []byte
func (a HardwareAddr) String() string {
if len(a) == 0 {
return ""
}
buf := make([]byte, 0, len(a)*3-1)
for i, b := range a {
if i > 0 {
buf = append(buf, ':')
}
buf = append(buf, hexDigit[b>>4])
buf = append(buf, hexDigit[b&0xF])
}
return string(buf)
}
// ParseMAC parses s as an IEEE 802 MAC-48, EUI-48, EUI-64, or a 20-octet
// IP over InfiniBand link-layer address using one of the following formats:
//
// 00:00:5e:00:53:01
// 02:00:5e:10:00:00:00:01
// 00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01
// 00-00-5e-00-53-01
// 02-00-5e-10-00-00-00-01
// 00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01
// 0000.5e00.5301
// 0200.5e10.0000.0001
// 0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001
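//
// A short usage sketch:
//
//	hw, err := net.ParseMAC("00:00:5e:00:53:01")
//	// hw.String() == "00:00:5e:00:53:01", err == nil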
func ParseMAC(s string) (hw HardwareAddr, err error) {
if len(s) < 14 {
goto error
}
if s[2] == ':' || s[2] == '-' {
if (len(s)+1)%3 != 0 {
goto error
}
n := (len(s) + 1) / 3
if n != 6 && n != 8 && n != 20 {
goto error
}
hw = make(HardwareAddr, n)
for x, i := 0, 0; i < n; i++ {
var ok bool
if hw[i], ok = xtoi2(s[x:], s[2]); !ok {
goto error
}
x += 3
}
} else if s[4] == '.' {
if (len(s)+1)%5 != 0 {
goto error
}
n := 2 * (len(s) + 1) / 5
if n != 6 && n != 8 && n != 20 {
goto error
}
hw = make(HardwareAddr, n)
for x, i := 0, 0; i < n; i += 2 {
var ok bool
if hw[i], ok = xtoi2(s[x:x+2], 0); !ok {
goto error
}
if hw[i+1], ok = xtoi2(s[x+2:], s[4]); !ok {
goto error
}
x += 5
}
} else {
goto error
}
return hw, nil
error:
return nil, &AddrError{Err: "invalid MAC address", Addr: s}
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package mail implements parsing of mail messages.
For the most part, this package follows the syntax as specified by RFC 5322 and
extended by RFC 6532.
Notable divergences:
- Obsolete address formats are not parsed, including addresses with
embedded route information.
- The full range of spacing (the CFWS syntax element) is not supported,
such as breaking addresses across lines.
- No unicode normalization is performed.
- The special characters ()[]:;@\, are allowed to appear unquoted in names.
*/
package mail
import (
"bufio"
"errors"
"fmt"
"io"
"log"
"mime"
"net/textproto"
"strings"
"sync"
"time"
"unicode/utf8"
)
var debug = debugT(false)
type debugT bool
func (d debugT) Printf(format string, args ...any) {
if d {
log.Printf(format, args...)
}
}
// A Message represents a parsed mail message.
type Message struct {
Header Header
Body io.Reader
}
// ReadMessage reads a message from r.
// The headers are parsed, and the body of the message will be available
// for reading from msg.Body.
func ReadMessage(r io.Reader) (msg *Message, err error) {
tp := textproto.NewReader(bufio.NewReader(r))
hdr, err := tp.ReadMIMEHeader()
if err != nil {
return nil, err
}
return &Message{
Header: Header(hdr),
Body: tp.R,
}, nil
}
// Layouts suitable for passing to time.Parse.
// These are tried in order.
var (
dateLayoutsBuildOnce sync.Once
dateLayouts []string
)
func buildDateLayouts() {
// Generate layouts based on RFC 5322, section 3.3.
dows := [...]string{"", "Mon, "} // day-of-week
days := [...]string{"2", "02"} // day = 1*2DIGIT
years := [...]string{"2006", "06"} // year = 4*DIGIT / 2*DIGIT
seconds := [...]string{":05", ""} // second
// "-0700 (MST)" is not in RFC 5322, but is common.
zones := [...]string{"-0700", "MST", "UT"} // zone = (("+" / "-") 4DIGIT) / "UT" / "GMT" / ...
for _, dow := range dows {
for _, day := range days {
for _, year := range years {
for _, second := range seconds {
for _, zone := range zones {
s := dow + day + " Jan " + year + " 15:04" + second + " " + zone
dateLayouts = append(dateLayouts, s)
}
}
}
}
}
}
// ParseDate parses an RFC 5322 date string.
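//
// A short usage sketch:
//
//	t, err := mail.ParseDate("Fri, 21 Nov 1997 09:55:06 -0600")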
func ParseDate(date string) (time.Time, error) {
dateLayoutsBuildOnce.Do(buildDateLayouts)
// Each CR must be paired with an LF; such CRLF pairs are tolerated anywhere in the date field.
date = strings.ReplaceAll(date, "\r\n", "")
if strings.Contains(date, "\r") {
return time.Time{}, errors.New("mail: header has a CR without LF")
}
// Re-using some addrParser methods which support obsolete text, i.e. non-printable ASCII
p := addrParser{date, nil}
p.skipSpace()
// RFC 5322: zone = (FWS ( "+" / "-" ) 4DIGIT) / obs-zone
// zone length is always 5 chars unless obsolete (obs-zone)
if ind := strings.IndexAny(p.s, "+-"); ind != -1 && len(p.s) >= ind+5 {
date = p.s[:ind+5]
p.s = p.s[ind+5:]
} else {
ind := strings.Index(p.s, "T")
if ind == 0 {
// In this case we have the following date formats:
// * Thu, 20 Nov 1997 09:55:06 MDT
// * Thu, 20 Nov 1997 09:55:06 MDT (MDT)
// * Thu, 20 Nov 1997 09:55:06 MDT (This comment)
ind = strings.Index(p.s[1:], "T")
if ind != -1 {
ind++
}
}
if ind != -1 && len(p.s) >= ind+5 {
// The last letter T of the obsolete time zone is checked when no standard time zone is found.
// If T is misplaced, the date to parse is garbage.
date = p.s[:ind+1]
p.s = p.s[ind+1:]
}
}
if !p.skipCFWS() {
return time.Time{}, errors.New("mail: misformatted parenthetical comment")
}
for _, layout := range dateLayouts {
t, err := time.Parse(layout, date)
if err == nil {
return t, nil
}
}
return time.Time{}, errors.New("mail: header could not be parsed")
}
// A Header represents the key-value pairs in a mail message header.
type Header map[string][]string
// Get gets the first value associated with the given key.
// It is case insensitive; CanonicalMIMEHeaderKey is used
// to canonicalize the provided key.
// If there are no values associated with the key, Get returns "".
// To access multiple values of a key, or to use non-canonical keys,
// access the map directly.
func (h Header) Get(key string) string {
return textproto.MIMEHeader(h).Get(key)
}
var ErrHeaderNotPresent = errors.New("mail: header not in message")
// Date parses the Date header field.
func (h Header) Date() (time.Time, error) {
hdr := h.Get("Date")
if hdr == "" {
return time.Time{}, ErrHeaderNotPresent
}
return ParseDate(hdr)
}
// AddressList parses the named header field as a list of addresses.
func (h Header) AddressList(key string) ([]*Address, error) {
hdr := h.Get(key)
if hdr == "" {
return nil, ErrHeaderNotPresent
}
return ParseAddressList(hdr)
}
// Address represents a single mail address.
// An address such as "Barry Gibbs <bg@example.com>" is represented
// as Address{Name: "Barry Gibbs", Address: "bg@example.com"}.
type Address struct {
Name string // Proper name; may be empty.
Address string // user@domain
}
// ParseAddress parses a single RFC 5322 address, e.g. "Barry Gibbs <bg@example.com>"
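//
// A short usage sketch:
//
//	addr, err := mail.ParseAddress("Barry Gibbs <bg@example.com>")
//	// addr.Name == "Barry Gibbs", addr.Address == "bg@example.com"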
func ParseAddress(address string) (*Address, error) {
return (&addrParser{s: address}).parseSingleAddress()
}
// ParseAddressList parses the given string as a list of addresses.
func ParseAddressList(list string) ([]*Address, error) {
return (&addrParser{s: list}).parseAddressList()
}
// An AddressParser is an RFC 5322 address parser.
type AddressParser struct {
// WordDecoder optionally specifies a decoder for RFC 2047 encoded-words.
WordDecoder *mime.WordDecoder
}
// Parse parses a single RFC 5322 address of the
// form "Gogh Fir <gf@example.com>" or "foo@example.com".
func (p *AddressParser) Parse(address string) (*Address, error) {
return (&addrParser{s: address, dec: p.WordDecoder}).parseSingleAddress()
}
// ParseList parses the given string as a list of comma-separated addresses
// of the form "Gogh Fir <gf@example.com>" or "foo@example.com".
func (p *AddressParser) ParseList(list string) ([]*Address, error) {
return (&addrParser{s: list, dec: p.WordDecoder}).parseAddressList()
}
// String formats the address as a valid RFC 5322 address.
// If the address's name contains non-ASCII characters
// the name will be rendered according to RFC 2047.
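//
// For illustration:
//
//	a := &mail.Address{Name: "Barry Gibbs", Address: "bg@example.com"}
//	// a.String() == `"Barry Gibbs" <bg@example.com>`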
func (a *Address) String() string {
// Format address local@domain
at := strings.LastIndex(a.Address, "@")
var local, domain string
if at < 0 {
// This is a malformed address ("@" is required in addr-spec);
// treat the whole address as local-part.
local = a.Address
} else {
local, domain = a.Address[:at], a.Address[at+1:]
}
// Add quotes if needed
quoteLocal := false
for i, r := range local {
if isAtext(r, false, false) {
continue
}
if r == '.' {
// Dots are okay if they are surrounded by atext.
// We only need to check that the previous byte is
// not a dot, and this isn't the end of the string.
if i > 0 && local[i-1] != '.' && i < len(local)-1 {
continue
}
}
quoteLocal = true
break
}
if quoteLocal {
local = quoteString(local)
}
s := "<" + local + "@" + domain + ">"
if a.Name == "" {
return s
}
// If every character is printable ASCII, quoting is simple.
allPrintable := true
for _, r := range a.Name {
// isWSP here should actually be isFWS,
// but we don't support folding yet.
if !isVchar(r) && !isWSP(r) || isMultibyte(r) {
allPrintable = false
break
}
}
if allPrintable {
return quoteString(a.Name) + " " + s
}
// Text in an encoded-word in a display-name must not contain certain
// characters like quotes or parentheses (see RFC 2047 section 5.3).
// When this is the case encode the name using base64 encoding.
if strings.ContainsAny(a.Name, "\"#$%&'(),.:;<>@[]^`{|}~") {
return mime.BEncoding.Encode("utf-8", a.Name) + " " + s
}
return mime.QEncoding.Encode("utf-8", a.Name) + " " + s
}
type addrParser struct {
s string
dec *mime.WordDecoder // may be nil
}
func (p *addrParser) parseAddressList() ([]*Address, error) {
var list []*Address
for {
p.skipSpace()
// allow skipping empty entries (RFC5322 obs-addr-list)
if p.consume(',') {
continue
}
addrs, err := p.parseAddress(true)
if err != nil {
return nil, err
}
list = append(list, addrs...)
if !p.skipCFWS() {
return nil, errors.New("mail: misformatted parenthetical comment")
}
if p.empty() {
break
}
if p.peek() != ',' {
return nil, errors.New("mail: expected comma")
}
// Skip empty entries for obs-addr-list.
for p.consume(',') {
p.skipSpace()
}
if p.empty() {
break
}
}
return list, nil
}
func (p *addrParser) parseSingleAddress() (*Address, error) {
addrs, err := p.parseAddress(true)
if err != nil {
return nil, err
}
if !p.skipCFWS() {
return nil, errors.New("mail: misformatted parenthetical comment")
}
if !p.empty() {
return nil, fmt.Errorf("mail: expected single address, got %q", p.s)
}
if len(addrs) == 0 {
return nil, errors.New("mail: empty group")
}
if len(addrs) > 1 {
return nil, errors.New("mail: group with multiple addresses")
}
return addrs[0], nil
}
// parseAddress parses a single RFC 5322 address at the start of p.
func (p *addrParser) parseAddress(handleGroup bool) ([]*Address, error) {
debug.Printf("parseAddress: %q", p.s)
p.skipSpace()
if p.empty() {
return nil, errors.New("mail: no address")
}
// address = mailbox / group
// mailbox = name-addr / addr-spec
// group = display-name ":" [group-list] ";" [CFWS]
// addr-spec has a more restricted grammar than name-addr,
// so try parsing it first, and fallback to name-addr.
// TODO(dsymonds): Is this really correct?
spec, err := p.consumeAddrSpec()
if err == nil {
var displayName string
p.skipSpace()
if !p.empty() && p.peek() == '(' {
displayName, err = p.consumeDisplayNameComment()
if err != nil {
return nil, err
}
}
return []*Address{{
Name: displayName,
Address: spec,
}}, err
}
debug.Printf("parseAddress: not an addr-spec: %v", err)
debug.Printf("parseAddress: state is now %q", p.s)
// display-name
var displayName string
if p.peek() != '<' {
displayName, err = p.consumePhrase()
if err != nil {
return nil, err
}
}
debug.Printf("parseAddress: displayName=%q", displayName)
p.skipSpace()
if handleGroup {
if p.consume(':') {
return p.consumeGroupList()
}
}
// angle-addr = "<" addr-spec ">"
if !p.consume('<') {
atext := true
for _, r := range displayName {
if !isAtext(r, true, false) {
atext = false
break
}
}
if atext {
// The input is like "foo.bar"; it's possible the input
// meant to be "foo.bar@domain", or "foo.bar <...>".
return nil, errors.New("mail: missing '@' or angle-addr")
}
// The input is like "Full Name", which couldn't possibly be a
// valid email address if followed by "@domain"; the input
// likely meant to be "Full Name <...>".
return nil, errors.New("mail: no angle-addr")
}
spec, err = p.consumeAddrSpec()
if err != nil {
return nil, err
}
if !p.consume('>') {
return nil, errors.New("mail: unclosed angle-addr")
}
debug.Printf("parseAddress: spec=%q", spec)
return []*Address{{
Name: displayName,
Address: spec,
}}, nil
}
func (p *addrParser) consumeGroupList() ([]*Address, error) {
var group []*Address
// handle empty group.
p.skipSpace()
if p.consume(';') {
p.skipCFWS()
return group, nil
}
for {
p.skipSpace()
// embedded groups not allowed.
addrs, err := p.parseAddress(false)
if err != nil {
return nil, err
}
group = append(group, addrs...)
if !p.skipCFWS() {
return nil, errors.New("mail: misformatted parenthetical comment")
}
if p.consume(';') {
p.skipCFWS()
break
}
if !p.consume(',') {
return nil, errors.New("mail: expected comma")
}
}
return group, nil
}
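// For illustration, a sketch of how a group address is flattened by this
// parser (via the exported ParseAddressList); note that the group's
// display-name is not retained on the member addresses:
//
//	addrs, err := mail.ParseAddressList("Team: alice@example.com, bob@example.com;")
//	if err != nil {
//		// handle error
//	}
//	// addrs holds the two member mailboxes.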
// consumeAddrSpec parses a single RFC 5322 addr-spec at the start of p.
func (p *addrParser) consumeAddrSpec() (spec string, err error) {
debug.Printf("consumeAddrSpec: %q", p.s)
orig := *p
defer func() {
if err != nil {
*p = orig
}
}()
// local-part = dot-atom / quoted-string
var localPart string
p.skipSpace()
if p.empty() {
return "", errors.New("mail: no addr-spec")
}
if p.peek() == '"' {
// quoted-string
debug.Printf("consumeAddrSpec: parsing quoted-string")
localPart, err = p.consumeQuotedString()
if localPart == "" {
err = errors.New("mail: empty quoted string in addr-spec")
}
} else {
// dot-atom
debug.Printf("consumeAddrSpec: parsing dot-atom")
localPart, err = p.consumeAtom(true, false)
}
if err != nil {
debug.Printf("consumeAddrSpec: failed: %v", err)
return "", err
}
if !p.consume('@') {
return "", errors.New("mail: missing @ in addr-spec")
}
// domain = dot-atom / domain-literal
var domain string
p.skipSpace()
if p.empty() {
return "", errors.New("mail: no domain in addr-spec")
}
// TODO(dsymonds): Handle domain-literal
domain, err = p.consumeAtom(true, false)
if err != nil {
return "", err
}
return localPart + "@" + domain, nil
}
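// For illustration, a quoted-string local-part that a plain dot-atom would
// reject still parses as an addr-spec; a sketch using the exported
// ParseAddress with a hypothetical address:
//
//	a, err := mail.ParseAddress(`"john..doe"@example.com`)
//	if err != nil {
//		// handle error
//	}
//	fmt.Println(a.Address) // john..doe@example.com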
// consumePhrase parses the RFC 5322 phrase at the start of p.
func (p *addrParser) consumePhrase() (phrase string, err error) {
debug.Printf("consumePhrase: [%s]", p.s)
// phrase = 1*word
var words []string
var isPrevEncoded bool
for {
// word = atom / quoted-string
var word string
p.skipSpace()
if p.empty() {
break
}
isEncoded := false
if p.peek() == '"' {
// quoted-string
word, err = p.consumeQuotedString()
} else {
// atom
// We actually parse dot-atom here to be more permissive
// than what RFC 5322 specifies.
word, err = p.consumeAtom(true, true)
if err == nil {
word, isEncoded, err = p.decodeRFC2047Word(word)
}
}
if err != nil {
break
}
debug.Printf("consumePhrase: consumed %q", word)
if isPrevEncoded && isEncoded {
words[len(words)-1] += word
} else {
words = append(words, word)
}
isPrevEncoded = isEncoded
}
// Ignore any error if we got at least one word.
if err != nil && len(words) == 0 {
debug.Printf("consumePhrase: hit err: %v", err)
return "", fmt.Errorf("mail: missing word in phrase: %v", err)
}
phrase = strings.Join(words, " ")
return phrase, nil
}
// consumeQuotedString parses the quoted string at the start of p.
func (p *addrParser) consumeQuotedString() (qs string, err error) {
// Assume first byte is '"'.
i := 1
qsb := make([]rune, 0, 10)
escaped := false
Loop:
for {
r, size := utf8.DecodeRuneInString(p.s[i:])
switch {
case size == 0:
return "", errors.New("mail: unclosed quoted-string")
case size == 1 && r == utf8.RuneError:
return "", fmt.Errorf("mail: invalid utf-8 in quoted-string: %q", p.s)
case escaped:
// quoted-pair = ("\" (VCHAR / WSP))
if !isVchar(r) && !isWSP(r) {
return "", fmt.Errorf("mail: bad character in quoted-string: %q", r)
}
qsb = append(qsb, r)
escaped = false
case isQtext(r) || isWSP(r):
// qtext (printable US-ASCII excluding " and \), or
// FWS (almost; we're ignoring CRLF)
qsb = append(qsb, r)
case r == '"':
break Loop
case r == '\\':
escaped = true
default:
return "", fmt.Errorf("mail: bad character in quoted-string: %q", r)
}
i += size
}
p.s = p.s[i+1:]
return string(qsb), nil
}
// consumeAtom parses an RFC 5322 atom at the start of p.
// If dot is true, consumeAtom parses an RFC 5322 dot-atom instead.
// If permissive is true, consumeAtom will not fail on:
// - leading/trailing/double dots in the atom (see golang.org/issue/4938)
// - special characters (RFC 5322 3.2.3) except '<', '>', ':' and '"' (see golang.org/issue/21018)
func (p *addrParser) consumeAtom(dot bool, permissive bool) (atom string, err error) {
i := 0
Loop:
for {
r, size := utf8.DecodeRuneInString(p.s[i:])
switch {
case size == 1 && r == utf8.RuneError:
return "", fmt.Errorf("mail: invalid utf-8 in address: %q", p.s)
case size == 0 || !isAtext(r, dot, permissive):
break Loop
default:
i += size
}
}
if i == 0 {
return "", errors.New("mail: invalid string")
}
atom, p.s = p.s[:i], p.s[i:]
if !permissive {
if strings.HasPrefix(atom, ".") {
return "", errors.New("mail: leading dot in atom")
}
if strings.Contains(atom, "..") {
return "", errors.New("mail: double dot in atom")
}
if strings.HasSuffix(atom, ".") {
return "", errors.New("mail: trailing dot in atom")
}
}
return atom, nil
}
func (p *addrParser) consumeDisplayNameComment() (string, error) {
if !p.consume('(') {
return "", errors.New("mail: comment does not start with (")
}
comment, ok := p.consumeComment()
if !ok {
return "", errors.New("mail: misformatted parenthetical comment")
}
// TODO(stapelberg): parse quoted-string within comment
words := strings.FieldsFunc(comment, func(r rune) bool { return r == ' ' || r == '\t' })
for idx, word := range words {
decoded, isEncoded, err := p.decodeRFC2047Word(word)
if err != nil {
return "", err
}
if isEncoded {
words[idx] = decoded
}
}
return strings.Join(words, " "), nil
}
func (p *addrParser) consume(c byte) bool {
if p.empty() || p.peek() != c {
return false
}
p.s = p.s[1:]
return true
}
// skipSpace skips the leading space and tab characters.
func (p *addrParser) skipSpace() {
p.s = strings.TrimLeft(p.s, " \t")
}
func (p *addrParser) peek() byte {
return p.s[0]
}
func (p *addrParser) empty() bool {
return p.len() == 0
}
func (p *addrParser) len() int {
return len(p.s)
}
// skipCFWS skips CFWS as defined in RFC 5322.
func (p *addrParser) skipCFWS() bool {
p.skipSpace()
for {
if !p.consume('(') {
break
}
if _, ok := p.consumeComment(); !ok {
return false
}
p.skipSpace()
}
return true
}
func (p *addrParser) consumeComment() (string, bool) {
// '(' already consumed.
depth := 1
var comment string
for {
if p.empty() || depth == 0 {
break
}
if p.peek() == '\\' && p.len() > 1 {
p.s = p.s[1:]
} else if p.peek() == '(' {
depth++
} else if p.peek() == ')' {
depth--
}
if depth > 0 {
comment += p.s[:1]
}
p.s = p.s[1:]
}
return comment, depth == 0
}
func (p *addrParser) decodeRFC2047Word(s string) (word string, isEncoded bool, err error) {
dec := p.dec
if dec == nil {
dec = &rfc2047Decoder
}
// Substitute our own CharsetReader function so that we can tell
// whether an error from the Decode method was due to the
// CharsetReader (meaning the charset is invalid).
// We used to look for the charsetError type in the error result,
// but that behaves badly with CharsetReaders other than the
// one in rfc2047Decoder.
adec := *dec
charsetReaderError := false
adec.CharsetReader = func(charset string, input io.Reader) (io.Reader, error) {
if dec.CharsetReader == nil {
charsetReaderError = true
return nil, charsetError(charset)
}
r, err := dec.CharsetReader(charset, input)
if err != nil {
charsetReaderError = true
}
return r, err
}
word, err = adec.Decode(s)
if err == nil {
return word, true, nil
}
// If the error came from the character set reader
// (meaning the character set itself is invalid
// but the decoding worked fine until then),
// return the original text and the error,
// with isEncoded=true.
if charsetReaderError {
return s, true, err
}
// Ignore invalid RFC 2047 encoded-word errors.
return s, false, nil
}
var rfc2047Decoder = mime.WordDecoder{
CharsetReader: func(charset string, input io.Reader) (io.Reader, error) {
return nil, charsetError(charset)
},
}
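// For illustration, a minimal sketch of the mime.WordDecoder machinery this
// parser builds on; the encoded-word below is a hypothetical example:
//
//	dec := new(mime.WordDecoder)
//	s, err := dec.Decode("=?utf-8?q?=C2=A1Hola,_se=C3=B1or!?=")
//	if err != nil {
//		// handle error
//	}
//	fmt.Println(s) // ¡Hola, señor!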
type charsetError string
func (e charsetError) Error() string {
return fmt.Sprintf("charset not supported: %q", string(e))
}
// isAtext reports whether r is an RFC 5322 atext character.
// If dot is true, period is included.
// If permissive is true, RFC 5322 3.2.3 specials are included,
// except '<', '>', ':' and '"'.
func isAtext(r rune, dot, permissive bool) bool {
switch r {
case '.':
return dot
// RFC 5322 3.2.3. specials
case '(', ')', '[', ']', ';', '@', '\\', ',':
return permissive
case '<', '>', '"', ':':
return false
}
return isVchar(r)
}
// isQtext reports whether r is an RFC 5322 qtext character.
func isQtext(r rune) bool {
// Printable US-ASCII, excluding backslash or quote.
if r == '\\' || r == '"' {
return false
}
return isVchar(r)
}
// quoteString renders a string as an RFC 5322 quoted-string.
func quoteString(s string) string {
var b strings.Builder
b.WriteByte('"')
for _, r := range s {
if isQtext(r) || isWSP(r) {
b.WriteRune(r)
} else if isVchar(r) {
b.WriteByte('\\')
b.WriteRune(r)
}
}
b.WriteByte('"')
return b.String()
}
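// For illustration, assuming the behavior above: a name containing a quote
// is escaped with a backslash, while plain qtext passes through unchanged:
//
//	quoteString(`Gopher "G"`) // `"Gopher \"G\""`
//	quoteString("Gopher")     // `"Gopher"`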
// isVchar reports whether r is an RFC 5322 VCHAR character.
func isVchar(r rune) bool {
// Visible (printing) characters.
return '!' <= r && r <= '~' || isMultibyte(r)
}
// isMultibyte reports whether r is a multi-byte UTF-8 character
// as supported by RFC 6532.
func isMultibyte(r rune) bool {
return r >= utf8.RuneSelf
}
// isWSP reports whether r is a WSP (white space).
// WSP is a space or horizontal tab (RFC 5234 Appendix B).
func isWSP(r rune) bool {
return r == ' ' || r == '\t'
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package net provides a portable interface for network I/O, including
TCP/IP, UDP, domain name resolution, and Unix domain sockets.
Although the package provides access to low-level networking
primitives, most clients will need only the basic interface provided
by the Dial, Listen, and Accept functions and the associated
Conn and Listener interfaces. The crypto/tls package uses
the same interfaces and similar Dial and Listen functions.
The Dial function connects to a server:
conn, err := net.Dial("tcp", "golang.org:80")
if err != nil {
// handle error
}
fmt.Fprintf(conn, "GET / HTTP/1.0\r\n\r\n")
status, err := bufio.NewReader(conn).ReadString('\n')
// ...
The Listen function creates servers:
ln, err := net.Listen("tcp", ":8080")
if err != nil {
// handle error
}
for {
conn, err := ln.Accept()
if err != nil {
// handle error
}
go handleConnection(conn)
}
# Name Resolution
The method for resolving domain names, whether indirectly with functions like Dial
or directly with functions like LookupHost and LookupAddr, varies by operating system.
On Unix systems, the resolver has two options for resolving names.
It can use a pure Go resolver that sends DNS requests directly to the servers
listed in /etc/resolv.conf, or it can use a cgo-based resolver that calls C
library routines such as getaddrinfo and getnameinfo.
By default the pure Go resolver is used, because a blocked DNS request consumes
only a goroutine, while a blocked C call consumes an operating system thread.
When cgo is available, the cgo-based resolver is used instead under a variety of
conditions: on systems that do not let programs make direct DNS requests (macOS),
when the LOCALDOMAIN environment variable is present (even if empty),
when the RES_OPTIONS or HOSTALIASES environment variable is non-empty,
when the ASR_CONFIG environment variable is non-empty (OpenBSD only),
when /etc/resolv.conf or /etc/nsswitch.conf specify the use of features that the
Go resolver does not implement, and when the name being looked up ends in .local
or is an mDNS name.
The resolver decision can be overridden by setting the netdns value of the
GODEBUG environment variable (see package runtime) to go or cgo, as in:
export GODEBUG=netdns=go # force pure Go resolver
export GODEBUG=netdns=cgo # force native resolver (cgo, win32)
The decision can also be forced while building the Go source tree
by setting the netgo or netcgo build tag.
A numeric netdns setting, as in GODEBUG=netdns=1, causes the resolver
to print debugging information about its decisions.
To force a particular resolver while also printing debugging information,
join the two settings by a plus sign, as in GODEBUG=netdns=go+1.
On macOS, if Go code that uses the net package is built with
-buildmode=c-archive, linking the resulting archive into a C program
requires passing -lresolv when linking the C code.
On Plan 9, the resolver always accesses /net/cs and /net/dns.
On Windows, in Go 1.18.x and earlier, the resolver always used C
library functions, such as GetAddrInfo and DnsQuery.
*/
package net
import (
"context"
"errors"
"internal/poll"
"io"
"os"
"sync"
"syscall"
"time"
)
// netGo and netCgo contain the state of the build tags used
// to build this binary, and whether cgo is available.
// conf.go mirrors these into conf for easier testing.
var (
netGo bool // set true in cgo_stub.go for build tag "netgo" (or no cgo)
netCgo bool // set true in conf_netcgo.go for build tag "netcgo"
)
// Addr represents a network end point address.
//
// The two methods Network and String conventionally return strings
// that can be passed as the arguments to Dial, but the exact form
// and meaning of the strings is up to the implementation.
type Addr interface {
Network() string // name of the network (for example, "tcp", "udp")
String() string // string form of address (for example, "192.0.2.1:25", "[2001:db8::1]:80")
}
// Conn is a generic stream-oriented network connection.
//
// Multiple goroutines may invoke methods on a Conn simultaneously.
type Conn interface {
// Read reads data from the connection.
// Read can be made to time out and return an error after a fixed
// time limit; see SetDeadline and SetReadDeadline.
Read(b []byte) (n int, err error)
// Write writes data to the connection.
// Write can be made to time out and return an error after a fixed
// time limit; see SetDeadline and SetWriteDeadline.
Write(b []byte) (n int, err error)
// Close closes the connection.
// Any blocked Read or Write operations will be unblocked and return errors.
Close() error
// LocalAddr returns the local network address, if known.
LocalAddr() Addr
// RemoteAddr returns the remote network address, if known.
RemoteAddr() Addr
// SetDeadline sets the read and write deadlines associated
// with the connection. It is equivalent to calling both
// SetReadDeadline and SetWriteDeadline.
//
// A deadline is an absolute time after which I/O operations
// fail instead of blocking. The deadline applies to all future
// and pending I/O, not just the immediately following call to
// Read or Write. After a deadline has been exceeded, the
// connection can be refreshed by setting a deadline in the future.
//
// If the deadline is exceeded a call to Read or Write or to other
// I/O methods will return an error that wraps os.ErrDeadlineExceeded.
// This can be tested using errors.Is(err, os.ErrDeadlineExceeded).
// The error's Timeout method will return true, but note that there
// are other possible errors for which the Timeout method will
// return true even if the deadline has not been exceeded.
//
// An idle timeout can be implemented by repeatedly extending
// the deadline after successful Read or Write calls.
//
// A zero value for t means I/O operations will not time out.
SetDeadline(t time.Time) error
// SetReadDeadline sets the deadline for future Read calls
// and any currently-blocked Read call.
// A zero value for t means Read will not time out.
SetReadDeadline(t time.Time) error
// SetWriteDeadline sets the deadline for future Write calls
// and any currently-blocked Write call.
// Even if write times out, it may return n > 0, indicating that
// some of the data was successfully written.
// A zero value for t means Write will not time out.
SetWriteDeadline(t time.Time) error
}
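// For illustration, a minimal sketch of the idle-timeout pattern described
// above: each successful Read pushes the deadline forward. The 30-second
// window and the process helper are hypothetical:
//
//	conn.SetReadDeadline(time.Now().Add(30 * time.Second))
//	buf := make([]byte, 4096)
//	for {
//		n, err := conn.Read(buf)
//		if err != nil {
//			break // timeout, EOF, or other error
//		}
//		conn.SetReadDeadline(time.Now().Add(30 * time.Second))
//		process(buf[:n]) // hypothetical handler
//	}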
type conn struct {
fd *netFD
}
func (c *conn) ok() bool { return c != nil && c.fd != nil }
// Implementation of the Conn interface.
// Read implements the Conn Read method.
func (c *conn) Read(b []byte) (int, error) {
if !c.ok() {
return 0, syscall.EINVAL
}
n, err := c.fd.Read(b)
if err != nil && err != io.EOF {
err = &OpError{Op: "read", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
}
return n, err
}
// Write implements the Conn Write method.
func (c *conn) Write(b []byte) (int, error) {
if !c.ok() {
return 0, syscall.EINVAL
}
n, err := c.fd.Write(b)
if err != nil {
err = &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
}
return n, err
}
// Close closes the connection.
func (c *conn) Close() error {
if !c.ok() {
return syscall.EINVAL
}
err := c.fd.Close()
if err != nil {
err = &OpError{Op: "close", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
}
return err
}
// LocalAddr returns the local network address.
// The Addr returned is shared by all invocations of LocalAddr, so
// do not modify it.
func (c *conn) LocalAddr() Addr {
if !c.ok() {
return nil
}
return c.fd.laddr
}
// RemoteAddr returns the remote network address.
// The Addr returned is shared by all invocations of RemoteAddr, so
// do not modify it.
func (c *conn) RemoteAddr() Addr {
if !c.ok() {
return nil
}
return c.fd.raddr
}
// SetDeadline implements the Conn SetDeadline method.
func (c *conn) SetDeadline(t time.Time) error {
if !c.ok() {
return syscall.EINVAL
}
if err := c.fd.SetDeadline(t); err != nil {
return &OpError{Op: "set", Net: c.fd.net, Source: nil, Addr: c.fd.laddr, Err: err}
}
return nil
}
// SetReadDeadline implements the Conn SetReadDeadline method.
func (c *conn) SetReadDeadline(t time.Time) error {
if !c.ok() {
return syscall.EINVAL
}
if err := c.fd.SetReadDeadline(t); err != nil {
return &OpError{Op: "set", Net: c.fd.net, Source: nil, Addr: c.fd.laddr, Err: err}
}
return nil
}
// SetWriteDeadline implements the Conn SetWriteDeadline method.
func (c *conn) SetWriteDeadline(t time.Time) error {
if !c.ok() {
return syscall.EINVAL
}
if err := c.fd.SetWriteDeadline(t); err != nil {
return &OpError{Op: "set", Net: c.fd.net, Source: nil, Addr: c.fd.laddr, Err: err}
}
return nil
}
// SetReadBuffer sets the size of the operating system's
// receive buffer associated with the connection.
func (c *conn) SetReadBuffer(bytes int) error {
if !c.ok() {
return syscall.EINVAL
}
if err := setReadBuffer(c.fd, bytes); err != nil {
return &OpError{Op: "set", Net: c.fd.net, Source: nil, Addr: c.fd.laddr, Err: err}
}
return nil
}
// SetWriteBuffer sets the size of the operating system's
// transmit buffer associated with the connection.
func (c *conn) SetWriteBuffer(bytes int) error {
if !c.ok() {
return syscall.EINVAL
}
if err := setWriteBuffer(c.fd, bytes); err != nil {
return &OpError{Op: "set", Net: c.fd.net, Source: nil, Addr: c.fd.laddr, Err: err}
}
return nil
}
// File returns a copy of the underlying os.File.
// It is the caller's responsibility to close f when finished.
// Closing c does not affect f, and closing f does not affect c.
//
// The returned os.File's file descriptor is different from the connection's.
// Attempting to change properties of the original using this duplicate
// may or may not have the desired effect.
func (c *conn) File() (f *os.File, err error) {
f, err = c.fd.dup()
if err != nil {
err = &OpError{Op: "file", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
}
return
}
// PacketConn is a generic packet-oriented network connection.
//
// Multiple goroutines may invoke methods on a PacketConn simultaneously.
type PacketConn interface {
// ReadFrom reads a packet from the connection,
// copying the payload into p. It returns the number of
// bytes copied into p and the return address that
// was on the packet.
// It returns the number of bytes read (0 <= n <= len(p))
// and any error encountered. Callers should always process
// the n > 0 bytes returned before considering the error err.
// ReadFrom can be made to time out and return an error after a
// fixed time limit; see SetDeadline and SetReadDeadline.
ReadFrom(p []byte) (n int, addr Addr, err error)
// WriteTo writes a packet with payload p to addr.
// WriteTo can be made to time out and return an Error after a
// fixed time limit; see SetDeadline and SetWriteDeadline.
// On packet-oriented connections, write timeouts are rare.
WriteTo(p []byte, addr Addr) (n int, err error)
// Close closes the connection.
// Any blocked ReadFrom or WriteTo operations will be unblocked and return errors.
Close() error
// LocalAddr returns the local network address, if known.
LocalAddr() Addr
// SetDeadline sets the read and write deadlines associated
// with the connection. It is equivalent to calling both
// SetReadDeadline and SetWriteDeadline.
//
// A deadline is an absolute time after which I/O operations
// fail instead of blocking. The deadline applies to all future
// and pending I/O, not just the immediately following call to
// Read or Write. After a deadline has been exceeded, the
// connection can be refreshed by setting a deadline in the future.
//
// If the deadline is exceeded a call to Read or Write or to other
// I/O methods will return an error that wraps os.ErrDeadlineExceeded.
// This can be tested using errors.Is(err, os.ErrDeadlineExceeded).
// The error's Timeout method will return true, but note that there
// are other possible errors for which the Timeout method will
// return true even if the deadline has not been exceeded.
//
// An idle timeout can be implemented by repeatedly extending
// the deadline after successful ReadFrom or WriteTo calls.
//
// A zero value for t means I/O operations will not time out.
SetDeadline(t time.Time) error
// SetReadDeadline sets the deadline for future ReadFrom calls
// and any currently-blocked ReadFrom call.
// A zero value for t means ReadFrom will not time out.
SetReadDeadline(t time.Time) error
// SetWriteDeadline sets the deadline for future WriteTo calls
// and any currently-blocked WriteTo call.
// Even if write times out, it may return n > 0, indicating that
// some of the data was successfully written.
// A zero value for t means WriteTo will not time out.
SetWriteDeadline(t time.Time) error
}
var listenerBacklogCache struct {
sync.Once
val int
}
// listenerBacklog is a caching wrapper around maxListenerBacklog.
func listenerBacklog() int {
listenerBacklogCache.Do(func() { listenerBacklogCache.val = maxListenerBacklog() })
return listenerBacklogCache.val
}
// A Listener is a generic network listener for stream-oriented protocols.
//
// Multiple goroutines may invoke methods on a Listener simultaneously.
type Listener interface {
// Accept waits for and returns the next connection to the listener.
Accept() (Conn, error)
// Close closes the listener.
// Any blocked Accept operations will be unblocked and return errors.
Close() error
// Addr returns the listener's network address.
Addr() Addr
}
// An Error represents a network error.
type Error interface {
error
Timeout() bool // Is the error a timeout?
// Deprecated: Temporary errors are not well-defined.
// Most "temporary" errors are timeouts, and the few exceptions are surprising.
// Do not use this method.
Temporary() bool
}
// Various errors contained in OpError.
var (
// For connection setup operations.
errNoSuitableAddress = errors.New("no suitable address found")
// For connection setup and write operations.
errMissingAddress = errors.New("missing address")
// For both read and write operations.
errCanceled = canceledError{}
ErrWriteToConnected = errors.New("use of WriteTo with pre-connected connection")
)
// canceledError lets us return the same error string we have always
// returned, while still matching context.Canceled via errors.Is.
type canceledError struct{}
func (canceledError) Error() string { return "operation was canceled" }
func (canceledError) Is(err error) bool { return err == context.Canceled }
// mapErr maps from the context errors to the historical internal net
// error values.
func mapErr(err error) error {
switch err {
case context.Canceled:
return errCanceled
case context.DeadlineExceeded:
return errTimeout
default:
return err
}
}
// OpError is the error type usually returned by functions in the net
// package. It describes the operation, network type, and address of
// an error.
type OpError struct {
// Op is the operation which caused the error, such as
// "read" or "write".
Op string
// Net is the network type on which this error occurred,
// such as "tcp" or "udp6".
Net string
// For operations involving a remote network connection, like
// Dial, Read, or Write, Source is the corresponding local
// network address.
Source Addr
// Addr is the network address for which this error occurred.
// For local operations, like Listen or SetDeadline, Addr is
// the address of the local endpoint being manipulated.
// For operations involving a remote network connection, like
// Dial, Read, or Write, Addr is the remote address of that
// connection.
Addr Addr
// Err is the error that occurred during the operation.
// The Error method panics if the error is nil.
Err error
}
func (e *OpError) Unwrap() error { return e.Err }
func (e *OpError) Error() string {
if e == nil {
return "<nil>"
}
s := e.Op
if e.Net != "" {
s += " " + e.Net
}
if e.Source != nil {
s += " " + e.Source.String()
}
if e.Addr != nil {
if e.Source != nil {
s += "->"
} else {
s += " "
}
s += e.Addr.String()
}
s += ": " + e.Err.Error()
return s
}
var (
// aLongTimeAgo is a non-zero time, far in the past, used for
// immediate cancellation of dials.
aLongTimeAgo = time.Unix(1, 0)
// noDeadline and noCancel are just zero values for
// readability with functions taking too many parameters.
noDeadline = time.Time{}
noCancel = (chan struct{})(nil)
)
type timeout interface {
Timeout() bool
}
func (e *OpError) Timeout() bool {
if ne, ok := e.Err.(*os.SyscallError); ok {
t, ok := ne.Err.(timeout)
return ok && t.Timeout()
}
t, ok := e.Err.(timeout)
return ok && t.Timeout()
}
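// For illustration, a sketch of how callers typically test for a timeout;
// for deadline-based timeouts the two checks below are equivalent:
//
//	n, err := conn.Read(buf)
//	if err != nil {
//		var ne net.Error
//		if errors.As(err, &ne) && ne.Timeout() {
//			// timed out
//		}
//		// or, for deadlines specifically:
//		if errors.Is(err, os.ErrDeadlineExceeded) {
//			// deadline exceeded
//		}
//	}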
type temporary interface {
Temporary() bool
}
func (e *OpError) Temporary() bool {
// Treat ECONNRESET and ECONNABORTED as temporary errors when
// they come from calling accept. See issue 6163.
if e.Op == "accept" && isConnError(e.Err) {
return true
}
if ne, ok := e.Err.(*os.SyscallError); ok {
t, ok := ne.Err.(temporary)
return ok && t.Temporary()
}
t, ok := e.Err.(temporary)
return ok && t.Temporary()
}
// A ParseError is the error type of literal network address parsers.
type ParseError struct {
// Type is the type of string that was expected, such as
// "IP address", "CIDR address".
Type string
// Text is the malformed text string.
Text string
}
func (e *ParseError) Error() string { return "invalid " + e.Type + ": " + e.Text }
func (e *ParseError) Timeout() bool { return false }
func (e *ParseError) Temporary() bool { return false }
type AddrError struct {
Err string
Addr string
}
func (e *AddrError) Error() string {
if e == nil {
return "<nil>"
}
s := e.Err
if e.Addr != "" {
s = "address " + e.Addr + ": " + s
}
return s
}
func (e *AddrError) Timeout() bool { return false }
func (e *AddrError) Temporary() bool { return false }
type UnknownNetworkError string
func (e UnknownNetworkError) Error() string { return "unknown network " + string(e) }
func (e UnknownNetworkError) Timeout() bool { return false }
func (e UnknownNetworkError) Temporary() bool { return false }
type InvalidAddrError string
func (e InvalidAddrError) Error() string { return string(e) }
func (e InvalidAddrError) Timeout() bool { return false }
func (e InvalidAddrError) Temporary() bool { return false }
// errTimeout exists to return the historical "i/o timeout" string
// for context.DeadlineExceeded. See mapErr.
// It is also used when Dialer.Deadline is exceeded.
// errors.Is(errTimeout, context.DeadlineExceeded) returns true.
//
// TODO(iant): We could consider changing this to os.ErrDeadlineExceeded
// in the future, if we make
//
// errors.Is(os.ErrDeadlineExceeded, context.DeadlineExceeded)
//
// return true.
var errTimeout error = &timeoutError{}
type timeoutError struct{}
func (e *timeoutError) Error() string { return "i/o timeout" }
func (e *timeoutError) Timeout() bool { return true }
func (e *timeoutError) Temporary() bool { return true }
func (e *timeoutError) Is(err error) bool {
return err == context.DeadlineExceeded
}
// DNSConfigError represents an error reading the machine's DNS configuration.
// (No longer used; kept for compatibility.)
type DNSConfigError struct {
Err error
}
func (e *DNSConfigError) Unwrap() error { return e.Err }
func (e *DNSConfigError) Error() string { return "error reading DNS config: " + e.Err.Error() }
func (e *DNSConfigError) Timeout() bool { return false }
func (e *DNSConfigError) Temporary() bool { return false }
// Various errors contained in DNSError.
var (
errNoSuchHost = errors.New("no such host")
)
// DNSError represents a DNS lookup error.
type DNSError struct {
Err string // description of the error
Name string // name looked for
Server string // server used
IsTimeout bool // if true, timed out; not all timeouts set this
IsTemporary bool // if true, error is temporary; not all errors set this
IsNotFound bool // if true, host could not be found
}
func (e *DNSError) Error() string {
if e == nil {
return "<nil>"
}
s := "lookup " + e.Name
if e.Server != "" {
s += " on " + e.Server
}
s += ": " + e.Err
return s
}
// Timeout reports whether the DNS lookup is known to have timed out.
// This is not always known; a DNS lookup may fail due to a timeout
// and return a DNSError for which Timeout returns false.
func (e *DNSError) Timeout() bool { return e.IsTimeout }
// Temporary reports whether the DNS error is known to be temporary.
// This is not always known; a DNS lookup may fail due to a temporary
// error and return a DNSError for which Temporary returns false.
func (e *DNSError) Temporary() bool { return e.IsTimeout || e.IsTemporary }
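// For illustration, a sketch of inspecting a lookup failure; the host name
// is hypothetical:
//
//	_, err := net.LookupHost("no-such-host.invalid")
//	var dnsErr *net.DNSError
//	if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
//		// the name does not exist
//	}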
// errClosed exists just so that the docs for ErrClosed don't mention
// the internal package poll.
var errClosed = poll.ErrNetClosing
// ErrClosed is the error returned by an I/O call on a network
// connection that has already been closed, or that is closed by
// another goroutine before the I/O is completed. This may be wrapped
// in another error, and should normally be tested using
// errors.Is(err, net.ErrClosed).
var ErrClosed error = errClosed
type writerOnly struct {
io.Writer
}
// Fallback implementation of io.ReaderFrom's ReadFrom, when sendfile isn't
// applicable.
func genericReadFrom(w io.Writer, r io.Reader) (n int64, err error) {
// Use wrapper to hide existing r.ReadFrom from io.Copy.
return io.Copy(writerOnly{w}, r)
}
// Limit the number of concurrent cgo-using goroutines, because
// each will block an entire operating system thread. The usual culprit
// is resolving many DNS names in separate goroutines but the DNS
// server is not responding. Then the many lookups each use a different
// thread, and the system or the program runs out of threads.
var threadLimit chan struct{}
var threadOnce sync.Once
func acquireThread() {
threadOnce.Do(func() {
threadLimit = make(chan struct{}, concurrentThreadsLimit())
})
threadLimit <- struct{}{}
}
func releaseThread() {
<-threadLimit
}
// buffersWriter is the interface implemented by Conns that support a
// "writev"-like batch write optimization.
// writeBuffers should fully consume and write all chunks from the
// provided Buffers, else it should report a non-nil error.
type buffersWriter interface {
writeBuffers(*Buffers) (int64, error)
}
// Buffers contains zero or more runs of bytes to write.
//
// On certain machines, for certain types of connections, this is
// optimized into an OS-specific batch write operation (such as
// "writev").
type Buffers [][]byte
var (
_ io.WriterTo = (*Buffers)(nil)
_ io.Reader = (*Buffers)(nil)
)
// WriteTo writes contents of the buffers to w.
//
// WriteTo implements io.WriterTo for Buffers.
//
// WriteTo modifies the slice v as well as v[i] for 0 <= i < len(v),
// but does not modify v[i][j] for any i, j.
func (v *Buffers) WriteTo(w io.Writer) (n int64, err error) {
if wv, ok := w.(buffersWriter); ok {
return wv.writeBuffers(v)
}
for _, b := range *v {
nb, err := w.Write(b)
n += int64(nb)
if err != nil {
v.consume(n)
return n, err
}
}
v.consume(n)
return n, nil
}
// Read from the buffers.
//
// Read implements io.Reader for Buffers.
//
// Read modifies the slice v as well as v[i] for 0 <= i < len(v),
// but does not modify v[i][j] for any i, j.
func (v *Buffers) Read(p []byte) (n int, err error) {
for len(p) > 0 && len(*v) > 0 {
n0 := copy(p, (*v)[0])
v.consume(int64(n0))
p = p[n0:]
n += n0
}
if len(*v) == 0 {
err = io.EOF
}
return
}
func (v *Buffers) consume(n int64) {
for len(*v) > 0 {
ln0 := int64(len((*v)[0]))
if ln0 > n {
(*v)[0] = (*v)[0][n:]
return
}
n -= ln0
(*v)[0] = nil
*v = (*v)[1:]
}
}
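// For illustration, a minimal sketch of writing several runs of bytes with
// one Buffers value; on supporting connections this becomes a single
// "writev"-style batch write. conn and body are hypothetical:
//
//	bufs := net.Buffers{
//		[]byte("HTTP/1.1 200 OK\r\n\r\n"),
//		body,
//	}
//	if _, err := bufs.WriteTo(conn); err != nil {
//		// handle error
//	}
//	// Note: WriteTo consumes bufs; it is empty afterwards.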
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Stuff that exists in std, but we can't use due to being a dependency
// of net, for go/build deps_test policy reasons.
package netip
func stringsLastIndexByte(s string, b byte) int {
for i := len(s) - 1; i >= 0; i-- {
if s[i] == b {
return i
}
}
return -1
}
func beUint64(b []byte) uint64 {
_ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
}
func bePutUint64(b []byte, v uint64) {
_ = b[7] // early bounds check to guarantee safety of writes below
b[0] = byte(v >> 56)
b[1] = byte(v >> 48)
b[2] = byte(v >> 40)
b[3] = byte(v >> 32)
b[4] = byte(v >> 24)
b[5] = byte(v >> 16)
b[6] = byte(v >> 8)
b[7] = byte(v)
}
func bePutUint32(b []byte, v uint32) {
_ = b[3] // early bounds check to guarantee safety of writes below
b[0] = byte(v >> 24)
b[1] = byte(v >> 16)
b[2] = byte(v >> 8)
b[3] = byte(v)
}
func leUint16(b []byte) uint16 {
_ = b[1] // bounds check hint to compiler; see golang.org/issue/14808
return uint16(b[0]) | uint16(b[1])<<8
}
func lePutUint16(b []byte, v uint16) {
_ = b[1] // early bounds check to guarantee safety of writes below
b[0] = byte(v)
b[1] = byte(v >> 8)
}
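// For illustration, the helpers above round-trip; e.g. for the big-endian
// 64-bit halves used by Addr:
//
//	var b [8]byte
//	bePutUint64(b[:], 0x0011223344556677)
//	_ = beUint64(b[:]) // 0x0011223344556677 again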
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package netip defines an IP address type that's a small value type.
// Building on that Addr type, the package also defines AddrPort (an
// IP address and a port), and Prefix (an IP address and a bit length
// prefix).
//
// Compared to the net.IP type, this package's Addr type takes less
// memory, is immutable, and is comparable (supports == and being a
// map key).
package netip
import (
"errors"
"math"
"strconv"
"internal/bytealg"
"internal/intern"
"internal/itoa"
)
// Sizes: (64-bit)
// net.IP: 24 byte slice header + {4, 16} = 28 to 40 bytes
// net.IPAddr: 40 byte slice header + {4, 16} = 44 to 56 bytes + zone length
// netip.Addr: 24 bytes (zone is per-name singleton, shared across all users)
// Addr represents an IPv4 or IPv6 address (with or without a scoped
// addressing zone), similar to net.IP or net.IPAddr.
//
// Unlike net.IP or net.IPAddr, Addr is a comparable value
// type (it supports == and can be a map key) and is immutable.
//
// The zero Addr is not a valid IP address.
// Addr{} is distinct from both 0.0.0.0 and ::.
type Addr struct {
// addr is the hi and lo bits of an IPv6 address. If z==z4,
// hi and lo contain the IPv4-mapped IPv6 address.
//
// hi and lo are constructed by interpreting a 16-byte IPv6
// address as a big-endian 128-bit number. The most significant
// bits of that number go into hi, the rest into lo.
//
// For example, 0011:2233:4455:6677:8899:aabb:ccdd:eeff is stored as:
// addr.hi = 0x0011223344556677
// addr.lo = 0x8899aabbccddeeff
//
// We store IPs like this, rather than as [16]byte, because it
// turns most operations on IPs into arithmetic and bit-twiddling
// operations on 64-bit registers, which is much faster than
// bytewise processing.
addr uint128
// z is a combination of the address family and the IPv6 zone.
//
// nil means invalid IP address (for a zero Addr).
// z4 means an IPv4 address.
// z6noz means an IPv6 address without a zone.
//
// Otherwise it's the interned zone name string.
z *intern.Value
}
// z0, z4, and z6noz are sentinel Addr.z values.
// See the Addr type's field docs.
var (
z0 = (*intern.Value)(nil)
z4 = new(intern.Value)
z6noz = new(intern.Value)
)
// IPv6LinkLocalAllNodes returns the IPv6 link-local all nodes multicast
// address ff02::1.
func IPv6LinkLocalAllNodes() Addr { return AddrFrom16([16]byte{0: 0xff, 1: 0x02, 15: 0x01}) }
// IPv6LinkLocalAllRouters returns the IPv6 link-local all routers multicast
// address ff02::2.
func IPv6LinkLocalAllRouters() Addr { return AddrFrom16([16]byte{0: 0xff, 1: 0x02, 15: 0x02}) }
// IPv6Loopback returns the IPv6 loopback address ::1.
func IPv6Loopback() Addr { return AddrFrom16([16]byte{15: 0x01}) }
// IPv6Unspecified returns the IPv6 unspecified address "::".
func IPv6Unspecified() Addr { return Addr{z: z6noz} }
// IPv4Unspecified returns the IPv4 unspecified address "0.0.0.0".
func IPv4Unspecified() Addr { return AddrFrom4([4]byte{}) }
// AddrFrom4 returns the address of the IPv4 address given by the bytes in addr.
func AddrFrom4(addr [4]byte) Addr {
return Addr{
addr: uint128{0, 0xffff00000000 | uint64(addr[0])<<24 | uint64(addr[1])<<16 | uint64(addr[2])<<8 | uint64(addr[3])},
z: z4,
}
}
// AddrFrom16 returns the IPv6 address given by the bytes in addr.
// An IPv4-mapped IPv6 address is left as an IPv6 address.
// (Use Unmap to convert them if needed.)
func AddrFrom16(addr [16]byte) Addr {
return Addr{
addr: uint128{
beUint64(addr[:8]),
beUint64(addr[8:]),
},
z: z6noz,
}
}
// ParseAddr parses s as an IP address, returning the result. The string
// s can be in dotted decimal ("192.0.2.1"), IPv6 ("2001:db8::68"),
// or IPv6 with a scoped addressing zone ("fe80::1cc0:3e8c:119f:c2e1%ens18").
func ParseAddr(s string) (Addr, error) {
for i := 0; i < len(s); i++ {
switch s[i] {
case '.':
return parseIPv4(s)
case ':':
return parseIPv6(s)
case '%':
// Assume that this was trying to be an IPv6 address with
// a zone specifier, but the address is missing.
return Addr{}, parseAddrError{in: s, msg: "missing IPv6 address"}
}
}
return Addr{}, parseAddrError{in: s, msg: "unable to parse IP"}
}
// MustParseAddr calls ParseAddr(s) and panics on error.
// It is intended for use in tests with hard-coded strings.
func MustParseAddr(s string) Addr {
ip, err := ParseAddr(s)
if err != nil {
panic(err)
}
return ip
}
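// For illustration, a minimal sketch of the parsing entry points; the
// addresses are hypothetical:
//
//	ip, err := netip.ParseAddr("2001:db8::68")
//	if err != nil {
//		// handle error
//	}
//	fmt.Println(ip.Is6(), ip.Zone()) // true ""
//
//	zoned := netip.MustParseAddr("fe80::1%eth0")
//	fmt.Println(zoned.Zone()) // eth0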
type parseAddrError struct {
in string // the string given to ParseAddr
msg string // an explanation of the parse failure
at string // optionally, the unparsed portion of in at which the error occurred.
}
func (err parseAddrError) Error() string {
q := strconv.Quote
if err.at != "" {
return "ParseAddr(" + q(err.in) + "): " + err.msg + " (at " + q(err.at) + ")"
}
return "ParseAddr(" + q(err.in) + "): " + err.msg
}
// parseIPv4 parses s as an IPv4 address (in form "192.168.0.1").
func parseIPv4(s string) (ip Addr, err error) {
var fields [4]uint8
var val, pos int
var digLen int // number of digits in current octet
for i := 0; i < len(s); i++ {
if s[i] >= '0' && s[i] <= '9' {
if digLen == 1 && val == 0 {
return Addr{}, parseAddrError{in: s, msg: "IPv4 field has octet with leading zero"}
}
val = val*10 + int(s[i]) - '0'
digLen++
if val > 255 {
return Addr{}, parseAddrError{in: s, msg: "IPv4 field has value >255"}
}
} else if s[i] == '.' {
// .1.2.3
// 1.2.3.
// 1..2.3
if i == 0 || i == len(s)-1 || s[i-1] == '.' {
return Addr{}, parseAddrError{in: s, msg: "IPv4 field must have at least one digit", at: s[i:]}
}
// 1.2.3.4.5
if pos == 3 {
return Addr{}, parseAddrError{in: s, msg: "IPv4 address too long"}
}
fields[pos] = uint8(val)
pos++
val = 0
digLen = 0
} else {
return Addr{}, parseAddrError{in: s, msg: "unexpected character", at: s[i:]}
}
}
if pos < 3 {
return Addr{}, parseAddrError{in: s, msg: "IPv4 address too short"}
}
fields[3] = uint8(val)
return AddrFrom4(fields), nil
}
// parseIPv6 parses s as an IPv6 address (in form "2001:db8::68").
func parseIPv6(in string) (Addr, error) {
s := in
// Split off the zone right from the start. Yes it's a second scan
// of the string, but trying to handle it inline makes a bunch of
// other inner loop conditionals more expensive, and it ends up
// being slower.
zone := ""
i := bytealg.IndexByteString(s, '%')
if i != -1 {
s, zone = s[:i], s[i+1:]
if zone == "" {
// Not allowed to have an empty zone if explicitly specified.
return Addr{}, parseAddrError{in: in, msg: "zone must be a non-empty string"}
}
}
var ip [16]byte
ellipsis := -1 // position of ellipsis in ip
// Might have leading ellipsis
if len(s) >= 2 && s[0] == ':' && s[1] == ':' {
ellipsis = 0
s = s[2:]
// Might be only ellipsis
if len(s) == 0 {
return IPv6Unspecified().WithZone(zone), nil
}
}
// Loop, parsing hex numbers followed by colon.
i = 0
for i < 16 {
// Hex number. Similar to parseIPv4, inlining the hex number
// parsing yields a significant performance increase.
off := 0
acc := uint32(0)
for ; off < len(s); off++ {
c := s[off]
if c >= '0' && c <= '9' {
acc = (acc << 4) + uint32(c-'0')
} else if c >= 'a' && c <= 'f' {
acc = (acc << 4) + uint32(c-'a'+10)
} else if c >= 'A' && c <= 'F' {
acc = (acc << 4) + uint32(c-'A'+10)
} else {
break
}
if acc > math.MaxUint16 {
// Overflow, fail.
return Addr{}, parseAddrError{in: in, msg: "IPv6 field has value >=2^16", at: s}
}
}
if off == 0 {
// No digits found, fail.
return Addr{}, parseAddrError{in: in, msg: "each colon-separated field must have at least one digit", at: s}
}
// If followed by dot, might be in trailing IPv4.
if off < len(s) && s[off] == '.' {
if ellipsis < 0 && i != 12 {
// Not the right place.
return Addr{}, parseAddrError{in: in, msg: "embedded IPv4 address must replace the final 2 fields of the address", at: s}
}
if i+4 > 16 {
// Not enough room.
return Addr{}, parseAddrError{in: in, msg: "too many hex fields to fit an embedded IPv4 at the end of the address", at: s}
}
// TODO: could make this a bit faster by having a helper
// that parses to a [4]byte, and have both parseIPv4 and
// parseIPv6 use it.
ip4, err := parseIPv4(s)
if err != nil {
return Addr{}, parseAddrError{in: in, msg: err.Error(), at: s}
}
ip[i] = ip4.v4(0)
ip[i+1] = ip4.v4(1)
ip[i+2] = ip4.v4(2)
ip[i+3] = ip4.v4(3)
s = ""
i += 4
break
}
// Save this 16-bit chunk.
ip[i] = byte(acc >> 8)
ip[i+1] = byte(acc)
i += 2
// Stop at end of string.
s = s[off:]
if len(s) == 0 {
break
}
// Otherwise must be followed by colon and more.
if s[0] != ':' {
return Addr{}, parseAddrError{in: in, msg: "unexpected character, want colon", at: s}
} else if len(s) == 1 {
return Addr{}, parseAddrError{in: in, msg: "colon must be followed by more characters", at: s}
}
s = s[1:]
// Look for ellipsis.
if s[0] == ':' {
if ellipsis >= 0 { // already have one
return Addr{}, parseAddrError{in: in, msg: "multiple :: in address", at: s}
}
ellipsis = i
s = s[1:]
if len(s) == 0 { // can be at end
break
}
}
}
// Must have used entire string.
if len(s) != 0 {
return Addr{}, parseAddrError{in: in, msg: "trailing garbage after address", at: s}
}
// If didn't parse enough, expand ellipsis.
if i < 16 {
if ellipsis < 0 {
return Addr{}, parseAddrError{in: in, msg: "address string too short"}
}
n := 16 - i
for j := i - 1; j >= ellipsis; j-- {
ip[j+n] = ip[j]
}
for j := ellipsis + n - 1; j >= ellipsis; j-- {
ip[j] = 0
}
} else if ellipsis >= 0 {
// Ellipsis must represent at least one 0 group.
return Addr{}, parseAddrError{in: in, msg: "the :: must expand to at least one field of zeros"}
}
return AddrFrom16(ip).WithZone(zone), nil
}
// AddrFromSlice parses the 4- or 16-byte byte slice as an IPv4 or IPv6 address.
// Note that a net.IP can be passed directly as the []byte argument.
// If slice's length is not 4 or 16, AddrFromSlice returns Addr{}, false.
func AddrFromSlice(slice []byte) (ip Addr, ok bool) {
switch len(slice) {
case 4:
return AddrFrom4([4]byte(slice)), true
case 16:
return AddrFrom16([16]byte(slice)), true
}
return Addr{}, false
}
// v4 returns the i'th byte of ip. If ip is not an IPv4 address, v4 returns
// unspecified garbage.
func (ip Addr) v4(i uint8) uint8 {
return uint8(ip.addr.lo >> ((3 - i) * 8))
}
// v6 returns the i'th byte of ip. If ip is an IPv4 address, this
// accesses the IPv4-mapped IPv6 address form of the IP.
func (ip Addr) v6(i uint8) uint8 {
return uint8(*(ip.addr.halves()[(i/8)%2]) >> ((7 - i%8) * 8))
}
// v6u16 returns the i'th 16-bit word of ip. If ip is an IPv4 address,
// this accesses the IPv4-mapped IPv6 address form of the IP.
func (ip Addr) v6u16(i uint8) uint16 {
return uint16(*(ip.addr.halves()[(i/4)%2]) >> ((3 - i%4) * 16))
}
// isZero reports whether ip is the zero value of the IP type.
// The zero value is not a valid IP address of any type.
//
// Note that "0.0.0.0" and "::" are not the zero value. Use IsUnspecified to
// check for these values instead.
func (ip Addr) isZero() bool {
// Faster than comparing ip == Addr{}, but effectively equivalent,
// as there's no way to make an IP with a nil z from this package.
return ip.z == z0
}
// IsValid reports whether the Addr is an initialized address (not the zero Addr).
//
// Note that "0.0.0.0" and "::" are both valid values.
func (ip Addr) IsValid() bool { return ip.z != z0 }
// BitLen returns the number of bits in the IP address:
// 128 for IPv6, 32 for IPv4, and 0 for the zero Addr.
//
// Note that IPv4-mapped IPv6 addresses are considered IPv6 addresses
// and therefore have bit length 128.
func (ip Addr) BitLen() int {
switch ip.z {
case z0:
return 0
case z4:
return 32
}
return 128
}
// Zone returns ip's IPv6 scoped addressing zone, if any.
func (ip Addr) Zone() string {
if ip.z == nil {
return ""
}
zone, _ := ip.z.Get().(string)
return zone
}
// Compare returns an integer comparing two IPs.
// The result will be 0 if ip == ip2, -1 if ip < ip2, and +1 if ip > ip2.
// The definition of "less than" is the same as the Less method.
func (ip Addr) Compare(ip2 Addr) int {
f1, f2 := ip.BitLen(), ip2.BitLen()
if f1 < f2 {
return -1
}
if f1 > f2 {
return 1
}
hi1, hi2 := ip.addr.hi, ip2.addr.hi
if hi1 < hi2 {
return -1
}
if hi1 > hi2 {
return 1
}
lo1, lo2 := ip.addr.lo, ip2.addr.lo
if lo1 < lo2 {
return -1
}
if lo1 > lo2 {
return 1
}
if ip.Is6() {
za, zb := ip.Zone(), ip2.Zone()
if za < zb {
return -1
}
if za > zb {
return 1
}
}
return 0
}
// Less reports whether ip sorts before ip2.
// IP addresses sort first by length, then by address.
// IPv6 addresses with zones sort just after the same address without a zone.
func (ip Addr) Less(ip2 Addr) bool { return ip.Compare(ip2) == -1 }
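// For illustration, a sketch of sorting addresses with Less; the IPv4
// address sorts first because addresses sort first by bit length:
//
//	ips := []netip.Addr{
//		netip.MustParseAddr("2001:db8::1"),
//		netip.MustParseAddr("10.0.0.1"),
//	}
//	sort.Slice(ips, func(i, j int) bool { return ips[i].Less(ips[j]) })
//	// ips is now [10.0.0.1 2001:db8::1]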
// Is4 reports whether ip is an IPv4 address.
//
// It returns false for IPv4-mapped IPv6 addresses. See Addr.Unmap.
func (ip Addr) Is4() bool {
return ip.z == z4
}
// Is4In6 reports whether ip is an IPv4-mapped IPv6 address.
func (ip Addr) Is4In6() bool {
return ip.Is6() && ip.addr.hi == 0 && ip.addr.lo>>32 == 0xffff
}
// Is6 reports whether ip is an IPv6 address, including IPv4-mapped
// IPv6 addresses.
func (ip Addr) Is6() bool {
return ip.z != z0 && ip.z != z4
}
// Unmap returns ip with any IPv4-mapped IPv6 address prefix removed.
//
// That is, if ip is an IPv6 address wrapping an IPv4 address, it
// returns the wrapped IPv4 address. Otherwise it returns ip unmodified.
func (ip Addr) Unmap() Addr {
if ip.Is4In6() {
ip.z = z4
}
return ip
}
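// For illustration, how Is4, Is4In6, and Unmap interact on an IPv4-mapped
// address:
//
//	ip := netip.MustParseAddr("::ffff:192.0.2.1")
//	ip.Is4()            // false
//	ip.Is4In6()         // true
//	ip.Unmap().Is4()    // true
//	ip.Unmap().String() // "192.0.2.1"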
// WithZone returns an IP that's the same as ip but with the provided
// zone. If zone is empty, the zone is removed. If ip is an IPv4
// address, WithZone is a no-op and returns ip unchanged.
func (ip Addr) WithZone(zone string) Addr {
if !ip.Is6() {
return ip
}
if zone == "" {
ip.z = z6noz
return ip
}
ip.z = intern.GetByString(zone)
return ip
}
// withoutZone unconditionally strips the zone from ip.
// It's similar to WithZone, but small enough to be inlinable.
func (ip Addr) withoutZone() Addr {
if !ip.Is6() {
return ip
}
ip.z = z6noz
return ip
}
// hasZone reports whether ip has an IPv6 zone.
func (ip Addr) hasZone() bool {
return ip.z != z0 && ip.z != z4 && ip.z != z6noz
}
// IsLinkLocalUnicast reports whether ip is a link-local unicast address.
func (ip Addr) IsLinkLocalUnicast() bool {
// Dynamic Configuration of IPv4 Link-Local Addresses
// https://datatracker.ietf.org/doc/html/rfc3927#section-2.1
if ip.Is4() {
return ip.v4(0) == 169 && ip.v4(1) == 254
}
// IP Version 6 Addressing Architecture (2.4 Address Type Identification)
// https://datatracker.ietf.org/doc/html/rfc4291#section-2.4
if ip.Is6() {
return ip.v6u16(0)&0xffc0 == 0xfe80
}
return false // zero value
}
// IsLoopback reports whether ip is a loopback address.
func (ip Addr) IsLoopback() bool {
// Requirements for Internet Hosts -- Communication Layers (3.2.1.3 Addressing)
// https://datatracker.ietf.org/doc/html/rfc1122#section-3.2.1.3
if ip.Is4() {
return ip.v4(0) == 127
}
// IP Version 6 Addressing Architecture (2.4 Address Type Identification)
// https://datatracker.ietf.org/doc/html/rfc4291#section-2.4
if ip.Is6() {
return ip.addr.hi == 0 && ip.addr.lo == 1
}
return false // zero value
}
// IsMulticast reports whether ip is a multicast address.
func (ip Addr) IsMulticast() bool {
// Host Extensions for IP Multicasting (4. HOST GROUP ADDRESSES)
// https://datatracker.ietf.org/doc/html/rfc1112#section-4
if ip.Is4() {
return ip.v4(0)&0xf0 == 0xe0
}
// IP Version 6 Addressing Architecture (2.4 Address Type Identification)
// https://datatracker.ietf.org/doc/html/rfc4291#section-2.4
if ip.Is6() {
return ip.addr.hi>>(64-8) == 0xff // ip.v6(0) == 0xff
}
return false // zero value
}
// IsInterfaceLocalMulticast reports whether ip is an IPv6 interface-local
// multicast address.
func (ip Addr) IsInterfaceLocalMulticast() bool {
// IPv6 Addressing Architecture (2.7.1. Pre-Defined Multicast Addresses)
// https://datatracker.ietf.org/doc/html/rfc4291#section-2.7.1
if ip.Is6() {
return ip.v6u16(0)&0xff0f == 0xff01
}
return false // zero value
}
// IsLinkLocalMulticast reports whether ip is a link-local multicast address.
func (ip Addr) IsLinkLocalMulticast() bool {
// IPv4 Multicast Guidelines (4. Local Network Control Block (224.0.0/24))
// https://datatracker.ietf.org/doc/html/rfc5771#section-4
if ip.Is4() {
return ip.v4(0) == 224 && ip.v4(1) == 0 && ip.v4(2) == 0
}
// IPv6 Addressing Architecture (2.7.1. Pre-Defined Multicast Addresses)
// https://datatracker.ietf.org/doc/html/rfc4291#section-2.7.1
if ip.Is6() {
return ip.v6u16(0)&0xff0f == 0xff02
}
return false // zero value
}
// IsGlobalUnicast reports whether ip is a global unicast address.
//
// It returns true for IPv6 addresses which fall outside of the current
// IANA-allocated 2000::/3 global unicast space, with the exception of the
// link-local address space. It also returns true even if ip is in the IPv4
// private address space or IPv6 unique local address space.
// It returns false for the zero Addr.
//
// For reference, see RFC 1122, RFC 4291, and RFC 4632.
func (ip Addr) IsGlobalUnicast() bool {
if ip.z == z0 {
// Invalid or zero-value.
return false
}
// Match package net's IsGlobalUnicast logic. Notably private IPv4 addresses
// and ULA IPv6 addresses are still considered "global unicast".
if ip.Is4() && (ip == IPv4Unspecified() || ip == AddrFrom4([4]byte{255, 255, 255, 255})) {
return false
}
return ip != IPv6Unspecified() &&
!ip.IsLoopback() &&
!ip.IsMulticast() &&
!ip.IsLinkLocalUnicast()
}
// IsPrivate reports whether ip is a private address, according to RFC 1918
// (IPv4 addresses) and RFC 4193 (IPv6 addresses). That is, it reports whether
// ip is in 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, or fc00::/7. This is the
// same as net.IP.IsPrivate.
func (ip Addr) IsPrivate() bool {
// Match the stdlib's IsPrivate logic.
if ip.Is4() {
// RFC 1918 allocates 10.0.0.0/8, 172.16.0.0/12, and 192.168.0.0/16 as
// private IPv4 address subnets.
return ip.v4(0) == 10 ||
(ip.v4(0) == 172 && ip.v4(1)&0xf0 == 16) ||
(ip.v4(0) == 192 && ip.v4(1) == 168)
}
if ip.Is6() {
// RFC 4193 allocates fc00::/7 as the unique local unicast IPv6 address
// subnet.
return ip.v6(0)&0xfe == 0xfc
}
return false // zero value
}
// IsUnspecified reports whether ip is an unspecified address, either the IPv4
// address "0.0.0.0" or the IPv6 address "::".
//
// Note that the zero Addr is not an unspecified address.
func (ip Addr) IsUnspecified() bool {
return ip == IPv4Unspecified() || ip == IPv6Unspecified()
}
// Prefix keeps only the top b bits of ip, producing a Prefix
// of the specified length.
// If ip is a zero Addr, Prefix always returns a zero Prefix and a nil error.
// Otherwise, if b is less than zero or greater than ip.BitLen(),
// Prefix returns an error.
func (ip Addr) Prefix(b int) (Prefix, error) {
if b < 0 {
return Prefix{}, errors.New("negative Prefix bits")
}
effectiveBits := b
switch ip.z {
case z0:
return Prefix{}, nil
case z4:
if b > 32 {
return Prefix{}, errors.New("prefix length " + itoa.Itoa(b) + " too large for IPv4")
}
effectiveBits += 96
default:
if b > 128 {
return Prefix{}, errors.New("prefix length " + itoa.Itoa(b) + " too large for IPv6")
}
}
ip.addr = ip.addr.and(mask6(effectiveBits))
return PrefixFrom(ip, b), nil
}
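// For illustration, a sketch of masking an address to a routing prefix;
// the address is hypothetical:
//
//	ip := netip.MustParseAddr("192.0.2.77")
//	p, err := ip.Prefix(24)
//	if err != nil {
//		// handle error
//	}
//	fmt.Println(p) // 192.0.2.0/24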
const (
netIPv4len = 4
netIPv6len = 16
)
// As16 returns the IP address in its 16-byte representation.
// IPv4 addresses are returned as IPv4-mapped IPv6 addresses.
// IPv6 addresses with zones are returned without their zone (use the
// Zone method to get it).
// The ip zero value returns all zeroes.
func (ip Addr) As16() (a16 [16]byte) {
bePutUint64(a16[:8], ip.addr.hi)
bePutUint64(a16[8:], ip.addr.lo)
return a16
}
// As4 returns an IPv4 or IPv4-in-IPv6 address in its 4-byte representation.
// If ip is the zero Addr or an IPv6 address, As4 panics.
// Note that 0.0.0.0 is not the zero Addr.
func (ip Addr) As4() (a4 [4]byte) {
if ip.z == z4 || ip.Is4In6() {
bePutUint32(a4[:], uint32(ip.addr.lo))
return a4
}
if ip.z == z0 {
panic("As4 called on IP zero value")
}
panic("As4 called on IPv6 address")
}
// AsSlice returns an IPv4 or IPv6 address in its respective 4-byte or 16-byte representation.
func (ip Addr) AsSlice() []byte {
switch ip.z {
case z0:
return nil
case z4:
var ret [4]byte
bePutUint32(ret[:], uint32(ip.addr.lo))
return ret[:]
default:
var ret [16]byte
bePutUint64(ret[:8], ip.addr.hi)
bePutUint64(ret[8:], ip.addr.lo)
return ret[:]
}
}
// Next returns the address following ip.
// If there is none, it returns the zero Addr.
func (ip Addr) Next() Addr {
ip.addr = ip.addr.addOne()
if ip.Is4() {
if uint32(ip.addr.lo) == 0 {
// Overflowed.
return Addr{}
}
} else {
if ip.addr.isZero() {
// Overflowed
return Addr{}
}
}
return ip
}
// Prev returns the address before ip.
// If there is none, it returns the zero Addr.
func (ip Addr) Prev() Addr {
if ip.Is4() {
if uint32(ip.addr.lo) == 0 {
return Addr{}
}
} else if ip.addr.isZero() {
return Addr{}
}
ip.addr = ip.addr.subOne()
return ip
}
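// For illustration, Next and Prev step through the address space and report
// overflow by returning the zero Addr:
//
//	netip.MustParseAddr("10.0.0.255").Next() // 10.0.1.0
//	netip.MustParseAddr("10.0.1.0").Prev()   // 10.0.0.255
//	netip.MustParseAddr("255.255.255.255").Next().IsValid() // false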
// String returns the string form of the IP address ip.
// It returns one of 5 forms:
//
// - "invalid IP", if ip is the zero Addr
// - IPv4 dotted decimal ("192.0.2.1")
// - IPv6 ("2001:db8::1")
// - "::ffff:1.2.3.4" (if Is4In6)
// - IPv6 with zone ("fe80:db8::1%eth0")
//
// Note that unlike package net's IP.String method,
// IPv4-mapped IPv6 addresses format with a "::ffff:"
// prefix before the dotted quad.
func (ip Addr) String() string {
switch ip.z {
case z0:
return "invalid IP"
case z4:
return ip.string4()
default:
if ip.Is4In6() {
if z := ip.Zone(); z != "" {
return "::ffff:" + ip.Unmap().string4() + "%" + z
} else {
return "::ffff:" + ip.Unmap().string4()
}
}
return ip.string6()
}
}
// AppendTo appends a text encoding of ip,
// as generated by MarshalText,
// to b and returns the extended buffer.
func (ip Addr) AppendTo(b []byte) []byte {
switch ip.z {
case z0:
return b
case z4:
return ip.appendTo4(b)
default:
if ip.Is4In6() {
b = append(b, "::ffff:"...)
b = ip.Unmap().appendTo4(b)
if z := ip.Zone(); z != "" {
b = append(b, '%')
b = append(b, z...)
}
return b
}
return ip.appendTo6(b)
}
}
// digits is a string of the hex digits from 0 to f. It's used in
// appendDecimal and appendHex to format IP addresses.
const digits = "0123456789abcdef"
// appendDecimal appends the decimal string representation of x to b.
func appendDecimal(b []byte, x uint8) []byte {
// Using this function rather than strconv.AppendUint makes IPv4
// string building 2x faster.
if x >= 100 {
b = append(b, digits[x/100])
}
if x >= 10 {
b = append(b, digits[x/10%10])
}
return append(b, digits[x%10])
}
// appendHex appends the hex string representation of x to b.
func appendHex(b []byte, x uint16) []byte {
// Using this function rather than strconv.AppendUint makes IPv6
// string building 2x faster.
if x >= 0x1000 {
b = append(b, digits[x>>12])
}
if x >= 0x100 {
b = append(b, digits[x>>8&0xf])
}
if x >= 0x10 {
b = append(b, digits[x>>4&0xf])
}
return append(b, digits[x&0xf])
}
// appendHexPad appends the fully padded hex string representation of x to b.
func appendHexPad(b []byte, x uint16) []byte {
return append(b, digits[x>>12], digits[x>>8&0xf], digits[x>>4&0xf], digits[x&0xf])
}
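// Editorial sketch (not part of the original source): appendDecimal and
// appendHex suppress leading zeros, while appendHexPad always emits four
// digits. The function name is hypothetical.
func exampleAppendHelpers() {
b := appendDecimal(nil, 7) // "7"
b = append(b, ' ')
b = appendHex(b, 0x2f) // "2f" (no leading zeros)
b = append(b, ' ')
b = appendHexPad(b, 0x2f) // "002f" (fully padded)
_ = b
}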
func (ip Addr) string4() string {
const max = len("255.255.255.255")
ret := make([]byte, 0, max)
ret = ip.appendTo4(ret)
return string(ret)
}
func (ip Addr) appendTo4(ret []byte) []byte {
ret = appendDecimal(ret, ip.v4(0))
ret = append(ret, '.')
ret = appendDecimal(ret, ip.v4(1))
ret = append(ret, '.')
ret = appendDecimal(ret, ip.v4(2))
ret = append(ret, '.')
ret = appendDecimal(ret, ip.v4(3))
return ret
}
// string6 formats ip in IPv6 textual representation. It follows the
// guidelines in section 4 of RFC 5952
// (https://tools.ietf.org/html/rfc5952#section-4): no unnecessary
// zeros, use :: to elide the longest run of zeros, and don't use ::
// to compact a single zero field.
func (ip Addr) string6() string {
// Use a zone with a "plausibly long" name, so that most zone-ful
// IP addresses won't require additional allocation.
//
// The compiler does a cool optimization here, where ret ends up
// stack-allocated and so the only allocation this function does
// is to construct the returned string. As such, it's okay to be a
// bit greedy here, size-wise.
const max = len("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff%enp5s0")
ret := make([]byte, 0, max)
ret = ip.appendTo6(ret)
return string(ret)
}
func (ip Addr) appendTo6(ret []byte) []byte {
zeroStart, zeroEnd := uint8(255), uint8(255)
for i := uint8(0); i < 8; i++ {
j := i
for j < 8 && ip.v6u16(j) == 0 {
j++
}
if l := j - i; l >= 2 && l > zeroEnd-zeroStart {
zeroStart, zeroEnd = i, j
}
}
for i := uint8(0); i < 8; i++ {
if i == zeroStart {
ret = append(ret, ':', ':')
i = zeroEnd
if i >= 8 {
break
}
} else if i > 0 {
ret = append(ret, ':')
}
ret = appendHex(ret, ip.v6u16(i))
}
if ip.z != z6noz {
ret = append(ret, '%')
ret = append(ret, ip.Zone()...)
}
return ret
}
// StringExpanded is like String but IPv6 addresses are expanded with leading
// zeroes and no "::" compression. For example, "2001:db8::1" becomes
// "2001:0db8:0000:0000:0000:0000:0000:0001".
func (ip Addr) StringExpanded() string {
switch ip.z {
case z0, z4:
return ip.String()
}
const size = len("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")
ret := make([]byte, 0, size)
for i := uint8(0); i < 8; i++ {
if i > 0 {
ret = append(ret, ':')
}
ret = appendHexPad(ret, ip.v6u16(i))
}
if ip.z != z6noz {
// The addition of a zone will cause a second allocation, but when there
// is no zone the ret slice will be stack allocated.
ret = append(ret, '%')
ret = append(ret, ip.Zone()...)
}
return string(ret)
}
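// Editorial sketch (not part of the original source): String compresses the
// longest zero run with "::"; StringExpanded never does. The function name
// is hypothetical.
func exampleExpanded() {
ip := MustParseAddr("2001:db8::1")
_ = ip.String()         // "2001:db8::1"
_ = ip.StringExpanded() // "2001:0db8:0000:0000:0000:0000:0000:0001"
}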
// MarshalText implements the encoding.TextMarshaler interface.
// The encoding is the same as returned by String, with one exception:
// If ip is the zero Addr, the encoding is the empty string.
func (ip Addr) MarshalText() ([]byte, error) {
switch ip.z {
case z0:
return []byte(""), nil
case z4:
max := len("255.255.255.255")
b := make([]byte, 0, max)
return ip.appendTo4(b), nil
default:
max := len("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff%enp5s0")
b := make([]byte, 0, max)
if ip.Is4In6() {
b = append(b, "::ffff:"...)
b = ip.Unmap().appendTo4(b)
if z := ip.Zone(); z != "" {
b = append(b, '%')
b = append(b, z...)
}
return b, nil
}
return ip.appendTo6(b), nil
}
}
// UnmarshalText implements the encoding.TextUnmarshaler interface.
// The IP address is expected in a form accepted by ParseAddr.
//
// If text is empty, UnmarshalText sets *ip to the zero Addr and
// returns no error.
func (ip *Addr) UnmarshalText(text []byte) error {
if len(text) == 0 {
*ip = Addr{}
return nil
}
var err error
*ip, err = ParseAddr(string(text))
return err
}
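// Editorial sketch (not part of the original source): MarshalText and
// UnmarshalText round-trip, with the zero Addr encoding as empty text.
// The function name is hypothetical.
func exampleAddrTextRoundTrip() {
in := MustParseAddr("2001:db8::1")
text, _ := in.MarshalText()
var out Addr
if err := out.UnmarshalText(text); err != nil || out != in {
panic("round trip failed")
}
}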
func (ip Addr) marshalBinaryWithTrailingBytes(trailingBytes int) []byte {
var b []byte
switch ip.z {
case z0:
b = make([]byte, trailingBytes)
case z4:
b = make([]byte, 4+trailingBytes)
bePutUint32(b, uint32(ip.addr.lo))
default:
z := ip.Zone()
b = make([]byte, 16+len(z)+trailingBytes)
bePutUint64(b[:8], ip.addr.hi)
bePutUint64(b[8:], ip.addr.lo)
copy(b[16:], z)
}
return b
}
// MarshalBinary implements the encoding.BinaryMarshaler interface.
// It returns a zero-length slice for the zero Addr,
// the 4-byte form for an IPv4 address,
// and the 16-byte form with zone appended for an IPv6 address.
func (ip Addr) MarshalBinary() ([]byte, error) {
return ip.marshalBinaryWithTrailingBytes(0), nil
}
// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface.
// It expects data in the form generated by MarshalBinary.
func (ip *Addr) UnmarshalBinary(b []byte) error {
n := len(b)
switch {
case n == 0:
*ip = Addr{}
return nil
case n == 4:
*ip = AddrFrom4([4]byte(b))
return nil
case n == 16:
*ip = AddrFrom16([16]byte(b))
return nil
case n > 16:
*ip = AddrFrom16([16]byte(b[:16])).WithZone(string(b[16:]))
return nil
}
return errors.New("unexpected slice size")
}
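// Editorial sketch (not part of the original source): the binary encoding
// lengths produced by MarshalBinary and accepted by UnmarshalBinary. The
// function name is hypothetical.
func exampleAddrBinary() {
b, _ := MustParseAddr("fe80::1%eth0").MarshalBinary()
_ = len(b) // 20: 16 address bytes plus the 4-byte zone "eth0"
var out Addr
_ = out.UnmarshalBinary(b) // restores both the address and the zone
}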
// AddrPort is an IP and a port number.
type AddrPort struct {
ip Addr
port uint16
}
// AddrPortFrom returns an AddrPort with the provided IP and port.
// It does not allocate.
func AddrPortFrom(ip Addr, port uint16) AddrPort { return AddrPort{ip: ip, port: port} }
// Addr returns p's IP address.
func (p AddrPort) Addr() Addr { return p.ip }
// Port returns p's port.
func (p AddrPort) Port() uint16 { return p.port }
// splitAddrPort splits s into an IP address string and a port
// string. It splits strings shaped like "foo:bar" or "[foo]:bar",
// without further validating the substrings. v6 indicates whether the
// ip string should parse as an IPv6 address or an IPv4 address, in
// order for s to be a valid ip:port string.
func splitAddrPort(s string) (ip, port string, v6 bool, err error) {
i := stringsLastIndexByte(s, ':')
if i == -1 {
return "", "", false, errors.New("not an ip:port")
}
ip, port = s[:i], s[i+1:]
if len(ip) == 0 {
return "", "", false, errors.New("no IP")
}
if len(port) == 0 {
return "", "", false, errors.New("no port")
}
if ip[0] == '[' {
if len(ip) < 2 || ip[len(ip)-1] != ']' {
return "", "", false, errors.New("missing ]")
}
ip = ip[1 : len(ip)-1]
v6 = true
}
return ip, port, v6, nil
}
// ParseAddrPort parses s as an AddrPort.
//
// It doesn't do any name resolution: both the address and the port
// must be numeric.
func ParseAddrPort(s string) (AddrPort, error) {
var ipp AddrPort
ip, port, v6, err := splitAddrPort(s)
if err != nil {
return ipp, err
}
port16, err := strconv.ParseUint(port, 10, 16)
if err != nil {
return ipp, errors.New("invalid port " + strconv.Quote(port) + " parsing " + strconv.Quote(s))
}
ipp.port = uint16(port16)
ipp.ip, err = ParseAddr(ip)
if err != nil {
return AddrPort{}, err
}
if v6 && ipp.ip.Is4() {
return AddrPort{}, errors.New("invalid ip:port " + strconv.Quote(s) + ", square brackets can only be used with IPv6 addresses")
} else if !v6 && ipp.ip.Is6() {
return AddrPort{}, errors.New("invalid ip:port " + strconv.Quote(s) + ", IPv6 addresses must be surrounded by square brackets")
}
return ipp, nil
}
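// Editorial sketch (not part of the original source): the bracket rules that
// ParseAddrPort enforces per address family. The function name is
// hypothetical.
func exampleParseAddrPort() {
_, err4 := ParseAddrPort("192.0.2.1:80")      // ok: IPv4 without brackets
_, err6 := ParseAddrPort("[2001:db8::1]:80")  // ok: IPv6 requires brackets
_, errBad := ParseAddrPort("2001:db8::1:80")  // error: unbracketed IPv6
_, _, _ = err4, err6, errBad
}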
// MustParseAddrPort calls ParseAddrPort(s) and panics on error.
// It is intended for use in tests with hard-coded strings.
func MustParseAddrPort(s string) AddrPort {
ip, err := ParseAddrPort(s)
if err != nil {
panic(err)
}
return ip
}
// IsValid reports whether p.Addr() is valid.
// All ports are valid, including zero.
func (p AddrPort) IsValid() bool { return p.ip.IsValid() }
func (p AddrPort) String() string {
switch p.ip.z {
case z0:
return "invalid AddrPort"
case z4:
a := p.ip.As4()
buf := make([]byte, 0, 21)
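// "...:"[i] picks the separator byte: '.' after each of the first three
// octets and ':' after the fourth, just before the port.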
for i := range a {
buf = strconv.AppendUint(buf, uint64(a[i]), 10)
buf = append(buf, "...:"[i])
}
buf = strconv.AppendUint(buf, uint64(p.port), 10)
return string(buf)
default:
// TODO: this could be more efficient allocation-wise:
return joinHostPort(p.ip.String(), itoa.Itoa(int(p.port)))
}
}
func joinHostPort(host, port string) string {
// We assume that host is a literal IPv6 address if host has
// colons.
if bytealg.IndexByteString(host, ':') >= 0 {
return "[" + host + "]:" + port
}
return host + ":" + port
}
// AppendTo appends a text encoding of p,
// as generated by MarshalText,
// to b and returns the extended buffer.
func (p AddrPort) AppendTo(b []byte) []byte {
switch p.ip.z {
case z0:
return b
case z4:
b = p.ip.appendTo4(b)
default:
if p.ip.Is4In6() {
b = append(b, "[::ffff:"...)
b = p.ip.Unmap().appendTo4(b)
if z := p.ip.Zone(); z != "" {
b = append(b, '%')
b = append(b, z...)
}
} else {
b = append(b, '[')
b = p.ip.appendTo6(b)
}
b = append(b, ']')
}
b = append(b, ':')
b = strconv.AppendUint(b, uint64(p.port), 10)
return b
}
// MarshalText implements the encoding.TextMarshaler interface. The
// encoding is the same as returned by String, with one exception: if
// p.Addr() is the zero Addr, the encoding is the empty string.
func (p AddrPort) MarshalText() ([]byte, error) {
var max int
switch p.ip.z {
case z0:
case z4:
max = len("255.255.255.255:65535")
default:
max = len("[ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff%enp5s0]:65535")
}
b := make([]byte, 0, max)
b = p.AppendTo(b)
return b, nil
}
// UnmarshalText implements the encoding.TextUnmarshaler
// interface. The AddrPort is expected in a form
// generated by MarshalText or accepted by ParseAddrPort.
func (p *AddrPort) UnmarshalText(text []byte) error {
if len(text) == 0 {
*p = AddrPort{}
return nil
}
var err error
*p, err = ParseAddrPort(string(text))
return err
}
// MarshalBinary implements the encoding.BinaryMarshaler interface.
// It returns Addr.MarshalBinary with an additional two bytes appended
// containing the port in little-endian.
func (p AddrPort) MarshalBinary() ([]byte, error) {
b := p.Addr().marshalBinaryWithTrailingBytes(2)
lePutUint16(b[len(b)-2:], p.Port())
return b, nil
}
// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface.
// It expects data in the form generated by MarshalBinary.
func (p *AddrPort) UnmarshalBinary(b []byte) error {
if len(b) < 2 {
return errors.New("unexpected slice size")
}
var addr Addr
err := addr.UnmarshalBinary(b[:len(b)-2])
if err != nil {
return err
}
*p = AddrPortFrom(addr, leUint16(b[len(b)-2:]))
return nil
}
// Prefix is an IP address prefix (CIDR) representing an IP network.
//
// The first Bits() of Addr() are specified. The remaining bits match any address.
// The range of Bits() is [0,32] for IPv4 or [0,128] for IPv6.
type Prefix struct {
ip Addr
// bitsPlusOne stores the prefix bit length plus one.
// A Prefix is valid if and only if bitsPlusOne is non-zero.
bitsPlusOne uint8
}
// PrefixFrom returns a Prefix with the provided IP address and bit
// prefix length.
//
// It does not allocate. Unlike Addr.Prefix, PrefixFrom does not mask
// off the host bits of ip.
//
// If bits is less than zero or greater than ip.BitLen(), Prefix.Bits
// will return an invalid value -1.
func PrefixFrom(ip Addr, bits int) Prefix {
var bitsPlusOne uint8
if !ip.isZero() && bits >= 0 && bits <= ip.BitLen() {
bitsPlusOne = uint8(bits) + 1
}
return Prefix{
ip: ip.withoutZone(),
bitsPlusOne: bitsPlusOne,
}
}
// Addr returns p's IP address.
func (p Prefix) Addr() Addr { return p.ip }
// Bits returns p's prefix length.
//
// It reports -1 if invalid.
func (p Prefix) Bits() int { return int(p.bitsPlusOne) - 1 }
// IsValid reports whether p.Bits() has a valid range for p.Addr().
// If p.Addr() is the zero Addr, IsValid returns false.
// Note that if p is the zero Prefix, then p.IsValid() == false.
func (p Prefix) IsValid() bool { return p.bitsPlusOne > 0 }
func (p Prefix) isZero() bool { return p == Prefix{} }
// IsSingleIP reports whether p contains exactly one IP.
func (p Prefix) IsSingleIP() bool { return p.IsValid() && p.Bits() == p.ip.BitLen() }
// ParsePrefix parses s as an IP address prefix.
// The string can be in the form "192.168.1.0/24" or "2001:db8::/32",
// the CIDR notation defined in RFC 4632 and RFC 4291.
// IPv6 zones are not permitted in prefixes, and an error will be returned if a
// zone is present.
//
// Note that masked address bits are not zeroed. Use Masked for that.
func ParsePrefix(s string) (Prefix, error) {
i := stringsLastIndexByte(s, '/')
if i < 0 {
return Prefix{}, errors.New("netip.ParsePrefix(" + strconv.Quote(s) + "): no '/'")
}
ip, err := ParseAddr(s[:i])
if err != nil {
return Prefix{}, errors.New("netip.ParsePrefix(" + strconv.Quote(s) + "): " + err.Error())
}
// IPv6 zones are not allowed: https://go.dev/issue/51899
if ip.Is6() && ip.z != z6noz {
return Prefix{}, errors.New("netip.ParsePrefix(" + strconv.Quote(s) + "): IPv6 zones cannot be present in a prefix")
}
bitsStr := s[i+1:]
bits, err := strconv.Atoi(bitsStr)
if err != nil {
return Prefix{}, errors.New("netip.ParsePrefix(" + strconv.Quote(s) + "): bad bits after slash: " + strconv.Quote(bitsStr))
}
maxBits := 32
if ip.Is6() {
maxBits = 128
}
if bits < 0 || bits > maxBits {
return Prefix{}, errors.New("netip.ParsePrefix(" + strconv.Quote(s) + "): prefix length out of range")
}
return PrefixFrom(ip, bits), nil
}
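// Editorial sketch (not part of the original source): ParsePrefix keeps the
// host bits of the address; Masked zeroes them. The function name is
// hypothetical.
func exampleParsePrefixMasked() {
p := MustParsePrefix("192.168.1.17/24")
_ = p.Addr()          // 192.168.1.17 (host bits retained)
_ = p.Masked().Addr() // 192.168.1.0 (canonical form)
}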
// MustParsePrefix calls ParsePrefix(s) and panics on error.
// It is intended for use in tests with hard-coded strings.
func MustParsePrefix(s string) Prefix {
ip, err := ParsePrefix(s)
if err != nil {
panic(err)
}
return ip
}
// Masked returns p in its canonical form, with all but the high
// p.Bits() bits of p.Addr() masked off.
//
// If p is zero or otherwise invalid, Masked returns the zero Prefix.
func (p Prefix) Masked() Prefix {
m, _ := p.ip.Prefix(p.Bits())
return m
}
// Contains reports whether the network p includes ip.
//
// An IPv4 address will not match an IPv6 prefix.
// An IPv4-mapped IPv6 address will not match an IPv4 prefix.
// A zero-value IP will not match any prefix.
// If ip has an IPv6 zone, Contains returns false,
// because Prefixes strip zones.
func (p Prefix) Contains(ip Addr) bool {
if !p.IsValid() || ip.hasZone() {
return false
}
if f1, f2 := p.ip.BitLen(), ip.BitLen(); f1 == 0 || f2 == 0 || f1 != f2 {
return false
}
if ip.Is4() {
// xor the IP addresses together; mismatched bits are now ones.
// Shift away the number of bits we don't care about.
// Shifts in Go are more efficient if the compiler can prove
// that the shift amount is smaller than the width of the shifted type (64 here).
// We know that p.Bits() is in the range 0..32 because p is Valid;
// the compiler doesn't know that, so mask with 63 to help it.
// Now truncate to 32 bits, because this is IPv4.
// If all the bits we care about are equal, the result will be zero.
return uint32((ip.addr.lo^p.ip.addr.lo)>>((32-p.Bits())&63)) == 0
} else {
// xor the IP addresses together.
// Mask away the bits we don't care about.
// If all the bits we care about are equal, the result will be zero.
return ip.addr.xor(p.ip.addr).and(mask6(p.Bits())).isZero()
}
}
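// Editorial sketch (not part of the original source): the address-family
// rules that Contains enforces. The function name is hypothetical.
func examplePrefixContains() {
p := MustParsePrefix("192.168.0.0/16")
_ = p.Contains(MustParseAddr("192.168.5.9"))        // true
_ = p.Contains(MustParseAddr("::ffff:192.168.5.9")) // false: 4-in-6 is IPv6
_ = p.Contains(Addr{})                              // false: zero Addr
}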
// Overlaps reports whether p and o contain any IP addresses in common.
//
// If p and o are of different address families or either have a zero
// IP, it reports false. Like the Contains method, a prefix with an
// IPv4-mapped IPv6 address is still treated as an IPv6 mask.
func (p Prefix) Overlaps(o Prefix) bool {
if !p.IsValid() || !o.IsValid() {
return false
}
if p == o {
return true
}
if p.ip.Is4() != o.ip.Is4() {
return false
}
var minBits int
if pb, ob := p.Bits(), o.Bits(); pb < ob {
minBits = pb
} else {
minBits = ob
}
if minBits == 0 {
return true
}
// One of these Prefix calls might look redundant, but we don't require
// that p and o values are normalized (via Prefix.Masked) first,
// so the Prefix call on the one that's already minBits serves to zero
// out any remaining bits in IP.
var err error
if p, err = p.ip.Prefix(minBits); err != nil {
return false
}
if o, err = o.ip.Prefix(minBits); err != nil {
return false
}
return p.ip == o.ip
}
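// Editorial sketch (not part of the original source): Overlaps does not
// require canonical (Masked) inputs, as the comment above explains. The
// function name is hypothetical.
func examplePrefixOverlaps() {
a := MustParsePrefix("10.0.0.0/8")
b := MustParsePrefix("10.64.1.9/24") // host bits set; still handled
_ = a.Overlaps(b)                    // true
_ = a.Overlaps(MustParsePrefix("2001:db8::/32")) // false: different family
}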
// AppendTo appends a text encoding of p,
// as generated by MarshalText,
// to b and returns the extended buffer.
func (p Prefix) AppendTo(b []byte) []byte {
if p.isZero() {
return b
}
if !p.IsValid() {
return append(b, "invalid Prefix"...)
}
// p.ip is non-nil, because p is valid.
if p.ip.z == z4 {
b = p.ip.appendTo4(b)
} else {
if p.ip.Is4In6() {
b = append(b, "::ffff:"...)
b = p.ip.Unmap().appendTo4(b)
} else {
b = p.ip.appendTo6(b)
}
}
b = append(b, '/')
b = appendDecimal(b, uint8(p.Bits()))
return b
}
// MarshalText implements the encoding.TextMarshaler interface.
// The encoding is the same as returned by String, with one exception:
// If p is the zero value, the encoding is the empty string.
func (p Prefix) MarshalText() ([]byte, error) {
var max int
switch p.ip.z {
case z0:
case z4:
max = len("255.255.255.255/32")
default:
max = len("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff%enp5s0/128")
}
b := make([]byte, 0, max)
b = p.AppendTo(b)
return b, nil
}
// UnmarshalText implements the encoding.TextUnmarshaler interface.
// The IP address is expected in a form accepted by ParsePrefix
// or generated by MarshalText.
func (p *Prefix) UnmarshalText(text []byte) error {
if len(text) == 0 {
*p = Prefix{}
return nil
}
var err error
*p, err = ParsePrefix(string(text))
return err
}
// MarshalBinary implements the encoding.BinaryMarshaler interface.
// It returns Addr.MarshalBinary with an additional byte appended
// containing the prefix bits.
func (p Prefix) MarshalBinary() ([]byte, error) {
b := p.Addr().withoutZone().marshalBinaryWithTrailingBytes(1)
b[len(b)-1] = uint8(p.Bits())
return b, nil
}
// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface.
// It expects data in the form generated by MarshalBinary.
func (p *Prefix) UnmarshalBinary(b []byte) error {
if len(b) < 1 {
return errors.New("unexpected slice size")
}
var addr Addr
err := addr.UnmarshalBinary(b[:len(b)-1])
if err != nil {
return err
}
*p = PrefixFrom(addr, int(b[len(b)-1]))
return nil
}
// String returns the CIDR notation of p: "<ip>/<bits>".
func (p Prefix) String() string {
if !p.IsValid() {
return "invalid Prefix"
}
return p.ip.String() + "/" + itoa.Itoa(p.Bits())
}
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package netip
import "math/bits"
// uint128 represents a uint128 using two uint64s.
//
// When the methods below mention a bit number, bit 0 is the most
// significant bit (in hi) and bit 127 is the lowest (lo&1).
type uint128 struct {
hi uint64
lo uint64
}
// mask6 returns a uint128 bitmask with the topmost n bits of a
// 128-bit number.
func mask6(n int) uint128 {
return uint128{^(^uint64(0) >> n), ^uint64(0) << (128 - n)}
}
// isZero reports whether u == 0.
//
// It's faster than u == (uint128{}) because the compiler (as of Go
// 1.15/1.16b1) doesn't do this trick and instead inserts a branch in
// its eq alg's generated code.
func (u uint128) isZero() bool { return u.hi|u.lo == 0 }
// and returns the bitwise AND of u and m (u&m).
func (u uint128) and(m uint128) uint128 {
return uint128{u.hi & m.hi, u.lo & m.lo}
}
// xor returns the bitwise XOR of u and m (u^m).
func (u uint128) xor(m uint128) uint128 {
return uint128{u.hi ^ m.hi, u.lo ^ m.lo}
}
// or returns the bitwise OR of u and m (u|m).
func (u uint128) or(m uint128) uint128 {
return uint128{u.hi | m.hi, u.lo | m.lo}
}
// not returns the bitwise NOT of u.
func (u uint128) not() uint128 {
return uint128{^u.hi, ^u.lo}
}
// subOne returns u - 1.
func (u uint128) subOne() uint128 {
lo, borrow := bits.Sub64(u.lo, 1, 0)
return uint128{u.hi - borrow, lo}
}
// addOne returns u + 1.
func (u uint128) addOne() uint128 {
lo, carry := bits.Add64(u.lo, 1, 0)
return uint128{u.hi + carry, lo}
}
// halves returns the two uint64 halves of the uint128.
//
// Logically, think of it as returning two uint64s.
// It only returns pointers for inlining reasons on 32-bit platforms.
func (u *uint128) halves() [2]*uint64 {
return [2]*uint64{&u.hi, &u.lo}
}
// bitsSetFrom returns a copy of u with the given bit
// and all subsequent ones set.
func (u uint128) bitsSetFrom(bit uint8) uint128 {
return u.or(mask6(int(bit)).not())
}
// bitsClearedFrom returns a copy of u with the given bit
// and all subsequent ones cleared.
func (u uint128) bitsClearedFrom(bit uint8) uint128 {
return u.and(mask6(int(bit)))
}
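// Editorial sketch (not part of the original source): bitsClearedFrom and
// bitsSetFrom compute the lowest and highest addresses covered by an n-bit
// prefix. The function name and the /64 choice are hypothetical.
func exampleUint128Range() {
var u uint128 // address bits of some IPv6 address
const bits = 64
first := u.bitsClearedFrom(bits) // host bits zeroed: start of the /64
last := u.bitsSetFrom(bits)      // host bits set: end of the /64
_, _ = first, last
}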
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
import (
"errors"
"internal/bytealg"
"os"
"sync"
"time"
)
const (
nssConfigPath = "/etc/nsswitch.conf"
)
var nssConfig nsswitchConfig
type nsswitchConfig struct {
initOnce sync.Once // guards init of nsswitchConfig
// ch is used as a semaphore that only allows one lookup at a
// time to recheck nsswitch.conf
ch chan struct{} // guards lastChecked and modTime
lastChecked time.Time // last time nsswitch.conf was checked
mu sync.Mutex // protects nssConf
nssConf *nssConf
}
func getSystemNSS() *nssConf {
nssConfig.tryUpdate()
nssConfig.mu.Lock()
conf := nssConfig.nssConf
nssConfig.mu.Unlock()
return conf
}
// init initializes conf and is only called via conf.initOnce.
func (conf *nsswitchConfig) init() {
conf.nssConf = parseNSSConfFile(nssConfigPath)
conf.lastChecked = time.Now()
conf.ch = make(chan struct{}, 1)
}
// tryUpdate tries to update conf.
func (conf *nsswitchConfig) tryUpdate() {
conf.initOnce.Do(conf.init)
// Ensure only one update at a time checks nsswitch.conf
if !conf.tryAcquireSema() {
return
}
defer conf.releaseSema()
now := time.Now()
if conf.lastChecked.After(now.Add(-5 * time.Second)) {
return
}
conf.lastChecked = now
var mtime time.Time
if fi, err := os.Stat(nssConfigPath); err == nil {
mtime = fi.ModTime()
}
if mtime.Equal(conf.nssConf.mtime) {
return
}
nssConf := parseNSSConfFile(nssConfigPath)
conf.mu.Lock()
conf.nssConf = nssConf
conf.mu.Unlock()
}
func (conf *nsswitchConfig) acquireSema() {
conf.ch <- struct{}{}
}
func (conf *nsswitchConfig) tryAcquireSema() bool {
select {
case conf.ch <- struct{}{}:
return true
default:
return false
}
}
func (conf *nsswitchConfig) releaseSema() {
<-conf.ch
}
// nssConf represents the state of the machine's /etc/nsswitch.conf file.
type nssConf struct {
mtime time.Time // time of nsswitch.conf modification
err error // any error encountered opening or parsing the file
sources map[string][]nssSource // keyed by database (e.g. "hosts")
}
type nssSource struct {
source string // e.g. "compat", "files", "mdns4_minimal"
criteria []nssCriterion
}
// standardCriteria reports whether all specified criteria have the default
// status actions.
func (s nssSource) standardCriteria() bool {
for i, crit := range s.criteria {
if !crit.standardStatusAction(i == len(s.criteria)-1) {
return false
}
}
return true
}
// nssCriterion is the parsed structure of one of the criteria in brackets
// after an NSS source name.
type nssCriterion struct {
negate bool // if "!" was present
status string // e.g. "success", "unavail" (lowercase)
action string // e.g. "return", "continue" (lowercase)
}
// standardStatusAction reports whether c is equivalent to not
// specifying the criterion at all. last indicates whether this criterion is
// the last in the list.
func (c nssCriterion) standardStatusAction(last bool) bool {
if c.negate {
return false
}
var def string
switch c.status {
case "success":
def = "return"
case "notfound", "unavail", "tryagain":
def = "continue"
default:
// Unknown status
return false
}
if last && c.action == "return" {
return true
}
return c.action == def
}
func parseNSSConfFile(file string) *nssConf {
f, err := open(file)
if err != nil {
return &nssConf{err: err}
}
defer f.close()
mtime, _, err := f.stat()
if err != nil {
return &nssConf{err: err}
}
conf := parseNSSConf(f)
conf.mtime = mtime
return conf
}
func parseNSSConf(f *file) *nssConf {
conf := new(nssConf)
for line, ok := f.readLine(); ok; line, ok = f.readLine() {
line = trimSpace(removeComment(line))
if len(line) == 0 {
continue
}
colon := bytealg.IndexByteString(line, ':')
if colon == -1 {
conf.err = errors.New("no colon on line")
return conf
}
db := trimSpace(line[:colon])
srcs := line[colon+1:]
for {
srcs = trimSpace(srcs)
if len(srcs) == 0 {
break
}
sp := bytealg.IndexByteString(srcs, ' ')
var src string
if sp == -1 {
src = srcs
srcs = "" // done
} else {
src = srcs[:sp]
srcs = trimSpace(srcs[sp+1:])
}
var criteria []nssCriterion
// See if there's a criteria block in brackets.
if len(srcs) > 0 && srcs[0] == '[' {
bclose := bytealg.IndexByteString(srcs, ']')
if bclose == -1 {
conf.err = errors.New("unclosed criterion bracket")
return conf
}
var err error
criteria, err = parseCriteria(srcs[1:bclose])
if err != nil {
conf.err = errors.New("invalid criteria: " + srcs[1:bclose])
return conf
}
srcs = srcs[bclose+1:]
}
if conf.sources == nil {
conf.sources = make(map[string][]nssSource)
}
conf.sources[db] = append(conf.sources[db], nssSource{
source: src,
criteria: criteria,
})
}
}
return conf
}
// parseCriteria parses a space-separated criteria list such as
// "foo=bar !foo=bar".
func parseCriteria(x string) (c []nssCriterion, err error) {
err = foreachField(x, func(f string) error {
not := false
if len(f) > 0 && f[0] == '!' {
not = true
f = f[1:]
}
if len(f) < 3 {
return errors.New("criterion too short")
}
eq := bytealg.IndexByteString(f, '=')
if eq == -1 {
return errors.New("criterion lacks equal sign")
}
if hasUpperCase(f) {
lower := []byte(f)
lowerASCIIBytes(lower)
f = string(lower)
}
c = append(c, nssCriterion{
negate: not,
status: f[:eq],
action: f[eq+1:],
})
return nil
})
return
}
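// Editorial sketch (not part of the original source): parsing the bracketed
// criteria from a line like "hosts: mdns4_minimal [NOTFOUND=return] dns".
// The function name is hypothetical.
func exampleParseCriteria() {
crit, err := parseCriteria("NOTFOUND=return")
if err == nil && len(crit) == 1 {
_ = crit[0].status // "notfound" (lowercased)
_ = crit[0].action // "return"
}
}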
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Simple file i/o and string manipulation, to avoid
// depending on strconv and bufio and strings.
package net
import (
"internal/bytealg"
"io"
"os"
"time"
)
type file struct {
file *os.File
data []byte
atEOF bool
}
func (f *file) close() { f.file.Close() }
func (f *file) getLineFromData() (s string, ok bool) {
data := f.data
i := 0
for i = 0; i < len(data); i++ {
if data[i] == '\n' {
s = string(data[0:i])
ok = true
// move data
i++
n := len(data) - i
copy(data[0:], data[i:])
f.data = data[0:n]
return
}
}
if f.atEOF && len(f.data) > 0 {
// EOF, return all we have
s = string(data)
f.data = f.data[0:0]
ok = true
}
return
}
func (f *file) readLine() (s string, ok bool) {
if s, ok = f.getLineFromData(); ok {
return
}
if len(f.data) < cap(f.data) {
ln := len(f.data)
n, err := io.ReadFull(f.file, f.data[ln:cap(f.data)])
if n >= 0 {
f.data = f.data[0 : ln+n]
}
if err == io.EOF || err == io.ErrUnexpectedEOF {
f.atEOF = true
}
}
s, ok = f.getLineFromData()
return
}
func (f *file) stat() (mtime time.Time, size int64, err error) {
st, err := f.file.Stat()
if err != nil {
return time.Time{}, 0, err
}
return st.ModTime(), st.Size(), nil
}
func open(name string) (*file, error) {
fd, err := os.Open(name)
if err != nil {
return nil, err
}
return &file{fd, make([]byte, 0, 64*1024), false}, nil
}
func stat(name string) (mtime time.Time, size int64, err error) {
st, err := os.Stat(name)
if err != nil {
return time.Time{}, 0, err
}
return st.ModTime(), st.Size(), nil
}
// Count occurrences in s of any bytes in t.
func countAnyByte(s string, t string) int {
n := 0
for i := 0; i < len(s); i++ {
if bytealg.IndexByteString(t, s[i]) >= 0 {
n++
}
}
return n
}
// Split s at any bytes in t.
func splitAtBytes(s string, t string) []string {
a := make([]string, 1+countAnyByte(s, t))
n := 0
last := 0
for i := 0; i < len(s); i++ {
if bytealg.IndexByteString(t, s[i]) >= 0 {
if last < i {
a[n] = s[last:i]
n++
}
last = i + 1
}
}
if last < len(s) {
a[n] = s[last:]
n++
}
return a[0:n]
}
func getFields(s string) []string { return splitAtBytes(s, " \r\t\n") }
// Bigger than we need, not too big to worry about overflow
const big = 0xFFFFFF
// Decimal to integer.
// Returns number, characters consumed, success.
func dtoi(s string) (n int, i int, ok bool) {
n = 0
for i = 0; i < len(s) && '0' <= s[i] && s[i] <= '9'; i++ {
n = n*10 + int(s[i]-'0')
if n >= big {
return big, i, false
}
}
if i == 0 {
return 0, 0, false
}
return n, i, true
}
// Hexadecimal to integer.
// Returns number, characters consumed, success.
func xtoi(s string) (n int, i int, ok bool) {
n = 0
for i = 0; i < len(s); i++ {
if '0' <= s[i] && s[i] <= '9' {
n *= 16
n += int(s[i] - '0')
} else if 'a' <= s[i] && s[i] <= 'f' {
n *= 16
n += int(s[i]-'a') + 10
} else if 'A' <= s[i] && s[i] <= 'F' {
n *= 16
n += int(s[i]-'A') + 10
} else {
break
}
if n >= big {
return 0, i, false
}
}
if i == 0 {
return 0, i, false
}
return n, i, true
}
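// Editorial sketch (not part of the original source): dtoi and xtoi stop at
// the first non-digit and report how many characters they consumed. The
// function name is hypothetical.
func exampleDtoiXtoi() {
n, i, ok := dtoi("80/tcp") // n=80, i=2, ok=true
_, _, _ = n, i, ok
h, j, ok2 := xtoi("ff02::") // h=0xff02, j=4, ok2=true
_, _, _ = h, j, ok2
}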
// xtoi2 converts the next two hex digits of s into a byte.
// If s is longer than 2 bytes then the third byte must be e.
// If the first two bytes of s are not hex digits or the third byte
// does not match e, false is returned.
func xtoi2(s string, e byte) (byte, bool) {
if len(s) > 2 && s[2] != e {
return 0, false
}
n, ei, ok := xtoi(s[:2])
return byte(n), ok && ei == 2
}
// Convert i to a hexadecimal string. Leading zeros are not printed.
func appendHex(dst []byte, i uint32) []byte {
if i == 0 {
return append(dst, '0')
}
for j := 7; j >= 0; j-- {
v := i >> uint(j*4)
if v > 0 {
dst = append(dst, hexDigit[v&0xf])
}
}
return dst
}
// Number of occurrences of b in s.
func count(s string, b byte) int {
n := 0
for i := 0; i < len(s); i++ {
if s[i] == b {
n++
}
}
return n
}
// Index of rightmost occurrence of b in s.
func last(s string, b byte) int {
i := len(s)
for i--; i >= 0; i-- {
if s[i] == b {
break
}
}
return i
}
// hasUpperCase reports whether the given string contains at least one
// upper-case letter.
func hasUpperCase(s string) bool {
for i := range s {
if 'A' <= s[i] && s[i] <= 'Z' {
return true
}
}
return false
}
// lowerASCIIBytes makes x ASCII lowercase in-place.
func lowerASCIIBytes(x []byte) {
for i, b := range x {
if 'A' <= b && b <= 'Z' {
x[i] += 'a' - 'A'
}
}
}
// lowerASCII returns the ASCII lowercase version of b.
func lowerASCII(b byte) byte {
if 'A' <= b && b <= 'Z' {
return b + ('a' - 'A')
}
return b
}
// trimSpace returns x without any leading or trailing ASCII whitespace.
func trimSpace(x string) string {
for len(x) > 0 && isSpace(x[0]) {
x = x[1:]
}
for len(x) > 0 && isSpace(x[len(x)-1]) {
x = x[:len(x)-1]
}
return x
}
// isSpace reports whether b is an ASCII space character.
func isSpace(b byte) bool {
return b == ' ' || b == '\t' || b == '\n' || b == '\r'
}
// removeComment returns line, removing any '#' byte and any following
// bytes.
func removeComment(line string) string {
if i := bytealg.IndexByteString(line, '#'); i != -1 {
return line[:i]
}
return line
}
// foreachField runs fn on each non-empty run of non-space bytes in x.
// It returns the first non-nil error returned by fn.
func foreachField(x string, fn func(field string) error) error {
x = trimSpace(x)
for len(x) > 0 {
sp := bytealg.IndexByteString(x, ' ')
if sp == -1 {
return fn(x)
}
if field := trimSpace(x[:sp]); len(field) > 0 {
if err := fn(field); err != nil {
return err
}
}
x = trimSpace(x[sp+1:])
}
return nil
}
// stringsHasSuffix is strings.HasSuffix. It reports whether s ends in
// suffix.
func stringsHasSuffix(s, suffix string) bool {
return len(s) >= len(suffix) && s[len(s)-len(suffix):] == suffix
}
// stringsHasSuffixFold reports whether s ends in suffix,
// ASCII-case-insensitively.
func stringsHasSuffixFold(s, suffix string) bool {
return len(s) >= len(suffix) && stringsEqualFold(s[len(s)-len(suffix):], suffix)
}
// stringsHasPrefix is strings.HasPrefix. It reports whether s begins with prefix.
func stringsHasPrefix(s, prefix string) bool {
return len(s) >= len(prefix) && s[:len(prefix)] == prefix
}
// stringsEqualFold is strings.EqualFold, ASCII only. It reports whether s and t
// are equal, ASCII-case-insensitively.
func stringsEqualFold(s, t string) bool {
if len(s) != len(t) {
return false
}
for i := 0; i < len(s); i++ {
if lowerASCII(s[i]) != lowerASCII(t[i]) {
return false
}
}
return true
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
import (
"io"
"os"
"sync"
"time"
)
// pipeDeadline is an abstraction for handling timeouts.
type pipeDeadline struct {
mu sync.Mutex // Guards timer and cancel
timer *time.Timer
cancel chan struct{} // Must be non-nil
}
func makePipeDeadline() pipeDeadline {
return pipeDeadline{cancel: make(chan struct{})}
}
// set sets the point in time when the deadline will time out.
// A timeout event is signaled by closing the channel returned by waiter.
// Once a timeout has occurred, the deadline can be refreshed by specifying a
// t value in the future.
//
// A zero value for t prevents timeout.
func (d *pipeDeadline) set(t time.Time) {
d.mu.Lock()
defer d.mu.Unlock()
if d.timer != nil && !d.timer.Stop() {
<-d.cancel // Wait for the timer callback to finish and close cancel
}
d.timer = nil
// If the time is zero, then there is no deadline.
closed := isClosedChan(d.cancel)
if t.IsZero() {
if closed {
d.cancel = make(chan struct{})
}
return
}
// Time is in the future; set up a timer to cancel at that time.
if dur := time.Until(t); dur > 0 {
if closed {
d.cancel = make(chan struct{})
}
d.timer = time.AfterFunc(dur, func() {
close(d.cancel)
})
return
}
// Time in the past, so close immediately.
if !closed {
close(d.cancel)
}
}
// wait returns a channel that is closed when the deadline is exceeded.
func (d *pipeDeadline) wait() chan struct{} {
d.mu.Lock()
defer d.mu.Unlock()
return d.cancel
}
func isClosedChan(c <-chan struct{}) bool {
select {
case <-c:
return true
default:
return false
}
}
type pipeAddr struct{}
func (pipeAddr) Network() string { return "pipe" }
func (pipeAddr) String() string { return "pipe" }
type pipe struct {
wrMu sync.Mutex // Serialize Write operations
// Used by local Read to interact with remote Write.
// Successful receive on rdRx is always followed by send on rdTx.
rdRx <-chan []byte
rdTx chan<- int
// Used by local Write to interact with remote Read.
// Successful send on wrTx is always followed by receive on wrRx.
wrTx chan<- []byte
wrRx <-chan int
once sync.Once // Protects closing localDone
localDone chan struct{}
remoteDone <-chan struct{}
readDeadline pipeDeadline
writeDeadline pipeDeadline
}
// Pipe creates a synchronous, in-memory, full duplex
// network connection; both ends implement the Conn interface.
// Reads on one end are matched with writes on the other,
// copying data directly between the two; there is no internal
// buffering.
func Pipe() (Conn, Conn) {
cb1 := make(chan []byte)
cb2 := make(chan []byte)
cn1 := make(chan int)
cn2 := make(chan int)
done1 := make(chan struct{})
done2 := make(chan struct{})
p1 := &pipe{
rdRx: cb1, rdTx: cn1,
wrTx: cb2, wrRx: cn2,
localDone: done1, remoteDone: done2,
readDeadline: makePipeDeadline(),
writeDeadline: makePipeDeadline(),
}
p2 := &pipe{
rdRx: cb2, rdTx: cn2,
wrTx: cb1, wrRx: cn1,
localDone: done2, remoteDone: done1,
readDeadline: makePipeDeadline(),
writeDeadline: makePipeDeadline(),
}
return p1, p2
}
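// Editorial sketch (not part of the original source): because the pipe has
// no internal buffering, every Write needs a concurrent Read on the other
// end. The function name is hypothetical.
func examplePipeUsage() {
c1, c2 := Pipe()
go func() {
c1.Write([]byte("ping")) // blocks until c2.Read receives it
c1.Close()
}()
buf := make([]byte, 4)
n, _ := c2.Read(buf)
_ = buf[:n] // "ping"
c2.Close()
}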
func (*pipe) LocalAddr() Addr { return pipeAddr{} }
func (*pipe) RemoteAddr() Addr { return pipeAddr{} }
func (p *pipe) Read(b []byte) (int, error) {
n, err := p.read(b)
if err != nil && err != io.EOF && err != io.ErrClosedPipe {
err = &OpError{Op: "read", Net: "pipe", Err: err}
}
return n, err
}
func (p *pipe) read(b []byte) (n int, err error) {
switch {
case isClosedChan(p.localDone):
return 0, io.ErrClosedPipe
case isClosedChan(p.remoteDone):
return 0, io.EOF
case isClosedChan(p.readDeadline.wait()):
return 0, os.ErrDeadlineExceeded
}
select {
case bw := <-p.rdRx:
nr := copy(b, bw)
p.rdTx <- nr
return nr, nil
case <-p.localDone:
return 0, io.ErrClosedPipe
case <-p.remoteDone:
return 0, io.EOF
case <-p.readDeadline.wait():
return 0, os.ErrDeadlineExceeded
}
}
func (p *pipe) Write(b []byte) (int, error) {
n, err := p.write(b)
if err != nil && err != io.ErrClosedPipe {
err = &OpError{Op: "write", Net: "pipe", Err: err}
}
return n, err
}
func (p *pipe) write(b []byte) (n int, err error) {
switch {
case isClosedChan(p.localDone):
return 0, io.ErrClosedPipe
case isClosedChan(p.remoteDone):
return 0, io.ErrClosedPipe
case isClosedChan(p.writeDeadline.wait()):
return 0, os.ErrDeadlineExceeded
}
p.wrMu.Lock() // Ensure entirety of b is written together
defer p.wrMu.Unlock()
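// The once flag forces at least one loop iteration so that a zero-length
// Write still synchronizes with a concurrent Read (or blocks on close and
// deadline state) instead of returning immediately.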
for once := true; once || len(b) > 0; once = false {
select {
case p.wrTx <- b:
nw := <-p.wrRx
b = b[nw:]
n += nw
case <-p.localDone:
return n, io.ErrClosedPipe
case <-p.remoteDone:
return n, io.ErrClosedPipe
case <-p.writeDeadline.wait():
return n, os.ErrDeadlineExceeded
}
}
return n, nil
}
func (p *pipe) SetDeadline(t time.Time) error {
if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) {
return io.ErrClosedPipe
}
p.readDeadline.set(t)
p.writeDeadline.set(t)
return nil
}
func (p *pipe) SetReadDeadline(t time.Time) error {
if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) {
return io.ErrClosedPipe
}
p.readDeadline.set(t)
return nil
}
func (p *pipe) SetWriteDeadline(t time.Time) error {
if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) {
return io.ErrClosedPipe
}
p.writeDeadline.set(t)
return nil
}
func (p *pipe) Close() error {
p.once.Do(func() { close(p.localDone) })
return nil
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
// parsePort parses service as a decimal integer and returns the
// corresponding value as port. It is the caller's responsibility to
// parse service as a non-decimal integer when needsLookup is true.
//
// Some system resolvers will return a valid port number when given a number
// over 65536 (see https://golang.org/issues/11715). Alas, the parser
// can't bail early on numbers > 65536. Therefore reasonably large/small
// numbers are parsed in full and rejected if invalid.
func parsePort(service string) (port int, needsLookup bool) {
if service == "" {
// Lock in the legacy behavior that an empty string
// means port 0. See golang.org/issue/13610.
return 0, false
}
const (
max = uint32(1<<32 - 1)
cutoff = uint32(1 << 30)
)
neg := false
if service[0] == '+' {
service = service[1:]
} else if service[0] == '-' {
neg = true
service = service[1:]
}
var n uint32
for _, d := range service {
if '0' <= d && d <= '9' {
d -= '0'
} else {
return 0, true
}
if n >= cutoff {
n = max
break
}
n *= 10
nn := n + uint32(d)
if nn < n || nn > max {
n = max
break
}
n = nn
}
if !neg && n >= cutoff {
port = int(cutoff - 1)
} else if neg && n > cutoff {
port = int(cutoff)
} else {
port = int(n)
}
if neg {
port = -port
}
return port, false
}
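// Editorial sketch (not part of the original source): parsePort clamps
// out-of-range numbers rather than failing, and defers non-numeric service
// names to a lookup. The function name is hypothetical.
func exampleParsePort() {
p, lookup := parsePort("80") // p=80, lookup=false
_, _ = p, lookup
_, lookup = parsePort("http") // lookup=true: caller resolves the name
_ = lookup
}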
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || (js && wasm)
// Read system port mappings from /etc/services
package net
import (
"internal/bytealg"
"sync"
)
var onceReadServices sync.Once
func readServices() {
file, err := open("/etc/services")
if err != nil {
return
}
defer file.close()
for line, ok := file.readLine(); ok; line, ok = file.readLine() {
// "http 80/tcp www www-http # World Wide Web HTTP"
if i := bytealg.IndexByteString(line, '#'); i >= 0 {
line = line[:i]
}
f := getFields(line)
if len(f) < 2 {
continue
}
portnet := f[1] // "80/tcp"
port, j, ok := dtoi(portnet)
if !ok || port <= 0 || j >= len(portnet) || portnet[j] != '/' {
continue
}
netw := portnet[j+1:] // "tcp"
m, ok1 := services[netw]
if !ok1 {
m = make(map[string]int)
services[netw] = m
}
for i := 0; i < len(f); i++ {
if i != 1 { // f[1] was port/net
m[f[i]] = port
}
}
}
}
// goLookupPort is the native Go implementation of LookupPort.
func goLookupPort(network, service string) (port int, err error) {
onceReadServices.Do(readServices)
return lookupPortMap(network, service)
}
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
import (
"internal/poll"
"runtime"
"syscall"
)
// BUG(tmm1): On Windows, the Write method of syscall.RawConn
// does not integrate with the runtime's network poller. It cannot
// wait for the connection to become writeable, and does not respect
// deadlines. If the user-provided callback returns false, the Write
// method will fail immediately.
// BUG(mikio): On JS and Plan 9, the Control, Read and Write
// methods of syscall.RawConn are not implemented.
type rawConn struct {
fd *netFD
}
func (c *rawConn) ok() bool { return c != nil && c.fd != nil }
func (c *rawConn) Control(f func(uintptr)) error {
if !c.ok() {
return syscall.EINVAL
}
err := c.fd.pfd.RawControl(f)
runtime.KeepAlive(c.fd)
if err != nil {
err = &OpError{Op: "raw-control", Net: c.fd.net, Source: nil, Addr: c.fd.laddr, Err: err}
}
return err
}
func (c *rawConn) Read(f func(uintptr) bool) error {
if !c.ok() {
return syscall.EINVAL
}
err := c.fd.pfd.RawRead(f)
runtime.KeepAlive(c.fd)
if err != nil {
err = &OpError{Op: "raw-read", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
}
return err
}
func (c *rawConn) Write(f func(uintptr) bool) error {
if !c.ok() {
return syscall.EINVAL
}
err := c.fd.pfd.RawWrite(f)
runtime.KeepAlive(c.fd)
if err != nil {
err = &OpError{Op: "raw-write", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
}
return err
}
// PollFD returns the poll.FD of the underlying connection.
//
// Other packages in std that also import internal/poll (such as os)
// can use a type assertion to access this extension method so that
// they can pass the *poll.FD to functions like poll.Splice.
//
// PollFD is not intended for use outside the standard library.
func (c *rawConn) PollFD() *poll.FD {
if !c.ok() {
return nil
}
return &c.fd.pfd
}
func newRawConn(fd *netFD) (*rawConn, error) {
return &rawConn{fd: fd}, nil
}
type rawListener struct {
rawConn
}
func (l *rawListener) Read(func(uintptr) bool) error {
return syscall.EINVAL
}
func (l *rawListener) Write(func(uintptr) bool) error {
return syscall.EINVAL
}
func newRawListener(fd *netFD) (*rawListener, error) {
return &rawListener{rawConn{fd: fd}}, nil
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package rpc
import (
"bufio"
"encoding/gob"
"errors"
"io"
"log"
"net"
"net/http"
"sync"
)
// ServerError represents an error that has been returned from
// the remote side of the RPC connection.
type ServerError string
func (e ServerError) Error() string {
return string(e)
}
var ErrShutdown = errors.New("connection is shut down")
// Call represents an active RPC.
type Call struct {
ServiceMethod string // The name of the service and method to call.
Args any // The argument to the function (*struct).
Reply any // The reply from the function (*struct).
Error error // After completion, the error status.
Done chan *Call // Receives *Call when Go is complete.
}
// Client represents an RPC Client.
// There may be multiple outstanding Calls associated
// with a single Client, and a Client may be used by
// multiple goroutines simultaneously.
type Client struct {
codec ClientCodec
reqMutex sync.Mutex // protects following
request Request
mutex sync.Mutex // protects following
seq uint64
pending map[uint64]*Call
closing bool // user has called Close
shutdown bool // server has told us to stop
}
// A ClientCodec implements writing of RPC requests and
// reading of RPC responses for the client side of an RPC session.
// The client calls WriteRequest to write a request to the connection
// and calls ReadResponseHeader and ReadResponseBody in pairs
// to read responses. The client calls Close when finished with the
// connection. ReadResponseBody may be called with a nil
// argument to force the body of the response to be read and then
// discarded.
// See NewClient's comment for information about concurrent access.
type ClientCodec interface {
WriteRequest(*Request, any) error
ReadResponseHeader(*Response) error
ReadResponseBody(any) error
Close() error
}
func (client *Client) send(call *Call) {
client.reqMutex.Lock()
defer client.reqMutex.Unlock()
// Register this call.
client.mutex.Lock()
if client.shutdown || client.closing {
client.mutex.Unlock()
call.Error = ErrShutdown
call.done()
return
}
seq := client.seq
client.seq++
client.pending[seq] = call
client.mutex.Unlock()
// Encode and send the request.
client.request.Seq = seq
client.request.ServiceMethod = call.ServiceMethod
err := client.codec.WriteRequest(&client.request, call.Args)
if err != nil {
client.mutex.Lock()
call = client.pending[seq]
delete(client.pending, seq)
client.mutex.Unlock()
if call != nil {
call.Error = err
call.done()
}
}
}
func (client *Client) input() {
var err error
var response Response
for err == nil {
response = Response{}
err = client.codec.ReadResponseHeader(&response)
if err != nil {
break
}
seq := response.Seq
client.mutex.Lock()
call := client.pending[seq]
delete(client.pending, seq)
client.mutex.Unlock()
switch {
case call == nil:
// We've got no pending call. That usually means that
// WriteRequest partially failed, and call was already
// removed; response is a server telling us about an
// error reading request body. We should still attempt
// to read error body, but there's no one to give it to.
err = client.codec.ReadResponseBody(nil)
if err != nil {
err = errors.New("reading error body: " + err.Error())
}
case response.Error != "":
// We've got an error response. Give this to the request;
// any subsequent requests will get the ReadResponseBody
// error if there is one.
call.Error = ServerError(response.Error)
err = client.codec.ReadResponseBody(nil)
if err != nil {
err = errors.New("reading error body: " + err.Error())
}
call.done()
default:
err = client.codec.ReadResponseBody(call.Reply)
if err != nil {
call.Error = errors.New("reading body " + err.Error())
}
call.done()
}
}
// Terminate pending calls.
client.reqMutex.Lock()
client.mutex.Lock()
client.shutdown = true
closing := client.closing
if err == io.EOF {
if closing {
err = ErrShutdown
} else {
err = io.ErrUnexpectedEOF
}
}
for _, call := range client.pending {
call.Error = err
call.done()
}
client.mutex.Unlock()
client.reqMutex.Unlock()
if debugLog && err != io.EOF && !closing {
log.Println("rpc: client protocol error:", err)
}
}
func (call *Call) done() {
select {
case call.Done <- call:
// ok
default:
// We don't want to block here. It is the caller's responsibility to make
// sure the channel has enough buffer space. See comment in Go().
if debugLog {
log.Println("rpc: discarding Call reply due to insufficient Done chan capacity")
}
}
}
// NewClient returns a new Client to handle requests to the
// set of services at the other end of the connection.
// It adds a buffer to the write side of the connection so
// the header and payload are sent as a unit.
//
// The read and write halves of the connection are serialized independently,
// so no interlocking is required. However each half may be accessed
// concurrently so the implementation of conn should protect against
// concurrent reads or concurrent writes.
func NewClient(conn io.ReadWriteCloser) *Client {
encBuf := bufio.NewWriter(conn)
client := &gobClientCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(encBuf), encBuf}
return NewClientWithCodec(client)
}
// NewClientWithCodec is like NewClient but uses the specified
// codec to encode requests and decode responses.
func NewClientWithCodec(codec ClientCodec) *Client {
client := &Client{
codec: codec,
pending: make(map[uint64]*Call),
}
go client.input()
return client
}
type gobClientCodec struct {
rwc io.ReadWriteCloser
dec *gob.Decoder
enc *gob.Encoder
encBuf *bufio.Writer
}
func (c *gobClientCodec) WriteRequest(r *Request, body any) (err error) {
if err = c.enc.Encode(r); err != nil {
return
}
if err = c.enc.Encode(body); err != nil {
return
}
return c.encBuf.Flush()
}
func (c *gobClientCodec) ReadResponseHeader(r *Response) error {
return c.dec.Decode(r)
}
func (c *gobClientCodec) ReadResponseBody(body any) error {
return c.dec.Decode(body)
}
func (c *gobClientCodec) Close() error {
return c.rwc.Close()
}
// DialHTTP connects to an HTTP RPC server at the specified network address
// listening on the default HTTP RPC path.
func DialHTTP(network, address string) (*Client, error) {
return DialHTTPPath(network, address, DefaultRPCPath)
}
// DialHTTPPath connects to an HTTP RPC server
// at the specified network address and path.
func DialHTTPPath(network, address, path string) (*Client, error) {
conn, err := net.Dial(network, address)
if err != nil {
return nil, err
}
io.WriteString(conn, "CONNECT "+path+" HTTP/1.0\n\n")
// Require successful HTTP response
// before switching to RPC protocol.
resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"})
if err == nil && resp.Status == connected {
return NewClient(conn), nil
}
if err == nil {
err = errors.New("unexpected HTTP response: " + resp.Status)
}
conn.Close()
return nil, &net.OpError{
Op: "dial-http",
Net: network + " " + address,
Addr: nil,
Err: err,
}
}
// Dial connects to an RPC server at the specified network address.
func Dial(network, address string) (*Client, error) {
conn, err := net.Dial(network, address)
if err != nil {
return nil, err
}
return NewClient(conn), nil
}
// Close calls the underlying codec's Close method. If the connection is already
// shutting down, ErrShutdown is returned.
func (client *Client) Close() error {
client.mutex.Lock()
if client.closing {
client.mutex.Unlock()
return ErrShutdown
}
client.closing = true
client.mutex.Unlock()
return client.codec.Close()
}
// Go invokes the function asynchronously. It returns the Call structure representing
// the invocation. The done channel will signal when the call is complete by returning
// the same Call object. If done is nil, Go will allocate a new channel.
// If non-nil, done must be buffered or Go will deliberately crash.
func (client *Client) Go(serviceMethod string, args any, reply any, done chan *Call) *Call {
call := new(Call)
call.ServiceMethod = serviceMethod
call.Args = args
call.Reply = reply
if done == nil {
done = make(chan *Call, 10) // buffered.
} else {
// If caller passes done != nil, it must arrange that
// done has enough buffer for the number of simultaneous
// RPCs that will be using that channel. If the channel
// is totally unbuffered, it's best not to run at all.
if cap(done) == 0 {
log.Panic("rpc: done channel is unbuffered")
}
}
call.Done = done
client.send(call)
return call
}
// Call invokes the named function, waits for it to complete, and returns its error status.
func (client *Client) Call(serviceMethod string, args any, reply any) error {
call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done
return call.Error
}
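// Editorial sketch (not part of the original source): the synchronous and
// asynchronous calling styles side by side. The address and the
// "Arith.Multiply" service name are hypothetical, as is the function name.
func exampleClientCalls() {
client, err := Dial("tcp", "127.0.0.1:1234")
if err != nil {
return
}
var reply int
// Synchronous: Call blocks until the reply arrives.
_ = client.Call("Arith.Multiply", [2]int{7, 8}, &reply)
// Asynchronous: Go returns immediately; receive on Done to finish.
call := client.Go("Arith.Multiply", [2]int{7, 8}, &reply, make(chan *Call, 1))
<-call.Done
}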
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package rpc
/*
Some HTML presented at http://machine:port/debug/rpc
Lists services, their methods, and some statistics, still rudimentary.
*/
import (
"fmt"
"html/template"
"net/http"
"sort"
)
const debugText = `<html>
<body>
<title>Services</title>
{{range .}}
<hr>
Service {{.Name}}
<hr>
<table>
<th align=center>Method</th><th align=center>Calls</th>
{{range .Method}}
<tr>
<td align=left font=fixed>{{.Name}}({{.Type.ArgType}}, {{.Type.ReplyType}}) error</td>
<td align=center>{{.Type.NumCalls}}</td>
</tr>
{{end}}
</table>
{{end}}
</body>
</html>`
var debug = template.Must(template.New("RPC debug").Parse(debugText))
// If set, print log statements for internal and I/O errors.
var debugLog = false
type debugMethod struct {
Type *methodType
Name string
}
type methodArray []debugMethod
type debugService struct {
Service *service
Name string
Method methodArray
}
type serviceArray []debugService
func (s serviceArray) Len() int { return len(s) }
func (s serviceArray) Less(i, j int) bool { return s[i].Name < s[j].Name }
func (s serviceArray) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func (m methodArray) Len() int { return len(m) }
func (m methodArray) Less(i, j int) bool { return m[i].Name < m[j].Name }
func (m methodArray) Swap(i, j int) { m[i], m[j] = m[j], m[i] }
type debugHTTP struct {
*Server
}
// Runs at /debug/rpc
func (server debugHTTP) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// Build a sorted version of the data.
var services serviceArray
server.serviceMap.Range(func(snamei, svci any) bool {
svc := svci.(*service)
ds := debugService{svc, snamei.(string), make(methodArray, 0, len(svc.method))}
for mname, method := range svc.method {
ds.Method = append(ds.Method, debugMethod{method, mname})
}
sort.Sort(ds.Method)
services = append(services, ds)
return true
})
sort.Sort(services)
err := debug.Execute(w, services)
if err != nil {
fmt.Fprintln(w, "rpc: error executing template:", err.Error())
}
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package jsonrpc implements a JSON-RPC 1.0 ClientCodec and ServerCodec
// for the rpc package.
// For JSON-RPC 2.0 support, see https://godoc.org/?q=json-rpc+2.0
package jsonrpc
import (
"encoding/json"
"fmt"
"io"
"net"
"net/rpc"
"sync"
)
type clientCodec struct {
dec *json.Decoder // for reading JSON values
enc *json.Encoder // for writing JSON values
c io.Closer
// temporary work space
req clientRequest
resp clientResponse
// JSON-RPC responses include the request id but not the request method.
// Package rpc expects both.
// We save the request method in pending when sending a request
// and then look it up by request ID when filling out the rpc Response.
mutex sync.Mutex // protects pending
pending map[uint64]string // map request id to method name
}
// NewClientCodec returns a new rpc.ClientCodec using JSON-RPC on conn.
func NewClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec {
return &clientCodec{
dec: json.NewDecoder(conn),
enc: json.NewEncoder(conn),
c: conn,
pending: make(map[uint64]string),
}
}
type clientRequest struct {
Method string `json:"method"`
Params [1]any `json:"params"`
Id uint64 `json:"id"`
}
func (c *clientCodec) WriteRequest(r *rpc.Request, param any) error {
c.mutex.Lock()
c.pending[r.Seq] = r.ServiceMethod
c.mutex.Unlock()
c.req.Method = r.ServiceMethod
c.req.Params[0] = param
c.req.Id = r.Seq
return c.enc.Encode(&c.req)
}
type clientResponse struct {
Id uint64 `json:"id"`
Result *json.RawMessage `json:"result"`
Error any `json:"error"`
}
func (r *clientResponse) reset() {
r.Id = 0
r.Result = nil
r.Error = nil
}
func (c *clientCodec) ReadResponseHeader(r *rpc.Response) error {
c.resp.reset()
if err := c.dec.Decode(&c.resp); err != nil {
return err
}
c.mutex.Lock()
r.ServiceMethod = c.pending[c.resp.Id]
delete(c.pending, c.resp.Id)
c.mutex.Unlock()
r.Error = ""
r.Seq = c.resp.Id
if c.resp.Error != nil || c.resp.Result == nil {
x, ok := c.resp.Error.(string)
if !ok {
return fmt.Errorf("invalid error %v", c.resp.Error)
}
if x == "" {
x = "unspecified error"
}
r.Error = x
}
return nil
}
func (c *clientCodec) ReadResponseBody(x any) error {
if x == nil {
return nil
}
return json.Unmarshal(*c.resp.Result, x)
}
func (c *clientCodec) Close() error {
return c.c.Close()
}
// NewClient returns a new rpc.Client to handle requests to the
// set of services at the other end of the connection.
func NewClient(conn io.ReadWriteCloser) *rpc.Client {
return rpc.NewClientWithCodec(NewClientCodec(conn))
}
// Dial connects to a JSON-RPC server at the specified network address.
func Dial(network, address string) (*rpc.Client, error) {
conn, err := net.Dial(network, address)
if err != nil {
return nil, err
}
return NewClient(conn), err
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package jsonrpc
import (
"encoding/json"
"errors"
"io"
"net/rpc"
"sync"
)
var errMissingParams = errors.New("jsonrpc: request body missing params")
type serverCodec struct {
dec *json.Decoder // for reading JSON values
enc *json.Encoder // for writing JSON values
c io.Closer
// temporary work space
req serverRequest
// JSON-RPC clients can use arbitrary json values as request IDs.
// Package rpc expects uint64 request IDs.
// We assign uint64 sequence numbers to incoming requests
// but save the original request ID in the pending map.
// When rpc responds, we use the sequence number in
// the response to find the original request ID.
mutex sync.Mutex // protects seq, pending
seq uint64
pending map[uint64]*json.RawMessage
}
// NewServerCodec returns a new rpc.ServerCodec using JSON-RPC on conn.
func NewServerCodec(conn io.ReadWriteCloser) rpc.ServerCodec {
return &serverCodec{
dec: json.NewDecoder(conn),
enc: json.NewEncoder(conn),
c: conn,
pending: make(map[uint64]*json.RawMessage),
}
}
type serverRequest struct {
Method string `json:"method"`
Params *json.RawMessage `json:"params"`
Id *json.RawMessage `json:"id"`
}
func (r *serverRequest) reset() {
r.Method = ""
r.Params = nil
r.Id = nil
}
type serverResponse struct {
Id *json.RawMessage `json:"id"`
Result any `json:"result"`
Error any `json:"error"`
}
func (c *serverCodec) ReadRequestHeader(r *rpc.Request) error {
c.req.reset()
if err := c.dec.Decode(&c.req); err != nil {
return err
}
r.ServiceMethod = c.req.Method
// JSON request id can be any JSON value;
// RPC package expects uint64. Translate to
// internal uint64 and save JSON on the side.
c.mutex.Lock()
c.seq++
c.pending[c.seq] = c.req.Id
c.req.Id = nil
r.Seq = c.seq
c.mutex.Unlock()
return nil
}
func (c *serverCodec) ReadRequestBody(x any) error {
if x == nil {
return nil
}
if c.req.Params == nil {
return errMissingParams
}
// JSON params is array value.
// RPC params is struct.
// Unmarshal into array containing struct for now.
// Should think about making RPC more general.
var params [1]any
params[0] = x
return json.Unmarshal(*c.req.Params, &params)
}
var null = json.RawMessage([]byte("null"))
func (c *serverCodec) WriteResponse(r *rpc.Response, x any) error {
c.mutex.Lock()
b, ok := c.pending[r.Seq]
if !ok {
c.mutex.Unlock()
return errors.New("invalid sequence number in response")
}
delete(c.pending, r.Seq)
c.mutex.Unlock()
if b == nil {
// Invalid request so no id. Use JSON null.
b = &null
}
resp := serverResponse{Id: b}
if r.Error == "" {
resp.Result = x
} else {
resp.Error = r.Error
}
return c.enc.Encode(resp)
}
func (c *serverCodec) Close() error {
return c.c.Close()
}
// ServeConn runs the JSON-RPC server on a single connection.
// ServeConn blocks, serving the connection until the client hangs up.
// The caller typically invokes ServeConn in a go statement.
func ServeConn(conn io.ReadWriteCloser) {
rpc.ServeCodec(NewServerCodec(conn))
}
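// A minimal serving sketch (assumes the services have already been
// registered with package rpc, e.g. via rpc.Register):
//
//	lis, err := net.Listen("tcp", ":1234")
//	if err != nil {
//		log.Fatal("listen error:", err)
//	}
//	for {
//		conn, err := lis.Accept()
//		if err != nil {
//			log.Fatal(err)
//		}
//		go jsonrpc.ServeConn(conn)
//	}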
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package rpc provides access to the exported methods of an object across a
network or other I/O connection. A server registers an object, making it visible
as a service with the name of the type of the object. After registration, exported
methods of the object will be accessible remotely. A server may register multiple
objects (services) of different types but it is an error to register multiple
objects of the same type.
Only methods that satisfy these criteria will be made available for remote access;
other methods will be ignored:
- the method's type is exported.
- the method is exported.
- the method has two arguments, both exported (or builtin) types.
- the method's second argument is a pointer.
- the method has return type error.
In effect, the method must look schematically like
func (t *T) MethodName(argType T1, replyType *T2) error
where T1 and T2 can be marshaled by encoding/gob.
These requirements apply even if a different codec is used.
(In the future, these requirements may soften for custom codecs.)
The method's first argument represents the arguments provided by the caller; the
second argument represents the result parameters to be returned to the caller.
The method's return value, if non-nil, is passed back as a string that the client
sees as if created by errors.New. If an error is returned, the reply parameter
will not be sent back to the client.
The server may handle requests on a single connection by calling ServeConn. More
typically it will create a network listener and call Accept or, for an HTTP
listener, HandleHTTP and http.Serve.
A client wishing to use the service establishes a connection and then invokes
NewClient on the connection. The convenience function Dial (DialHTTP) performs
both steps for a raw network connection (an HTTP connection). The resulting
Client object has two methods, Call and Go, that specify the service and method to
call, a pointer containing the arguments, and a pointer to receive the result
parameters.
The Call method waits for the remote call to complete while the Go method
launches the call asynchronously and signals completion using the Call
structure's Done channel.
Unless an explicit codec is set up, package encoding/gob is used to
transport the data.
Here is a simple example. A server wishes to export an object of type Arith:
package server
import "errors"
type Args struct {
A, B int
}
type Quotient struct {
Quo, Rem int
}
type Arith int
func (t *Arith) Multiply(args *Args, reply *int) error {
*reply = args.A * args.B
return nil
}
func (t *Arith) Divide(args *Args, quo *Quotient) error {
if args.B == 0 {
return errors.New("divide by zero")
}
quo.Quo = args.A / args.B
quo.Rem = args.A % args.B
return nil
}
The server calls (for HTTP service):
arith := new(Arith)
rpc.Register(arith)
rpc.HandleHTTP()
l, e := net.Listen("tcp", ":1234")
if e != nil {
log.Fatal("listen error:", e)
}
go http.Serve(l, nil)
At this point, clients can see a service "Arith" with methods "Arith.Multiply" and
"Arith.Divide". To invoke one, a client first dials the server:
client, err := rpc.DialHTTP("tcp", serverAddress + ":1234")
if err != nil {
log.Fatal("dialing:", err)
}
Then it can make a remote call:
// Synchronous call
args := &server.Args{7, 8}
var reply int
err = client.Call("Arith.Multiply", args, &reply)
if err != nil {
log.Fatal("arith error:", err)
}
fmt.Printf("Arith: %d*%d=%d", args.A, args.B, reply)
or
// Asynchronous call
quotient := new(Quotient)
divCall := client.Go("Arith.Divide", args, quotient, nil)
replyCall := <-divCall.Done // will be equal to divCall
// check errors, print, etc.
A server implementation will often provide a simple, type-safe wrapper for the
client.
The net/rpc package is frozen and is not accepting new features.
*/
package rpc
import (
"bufio"
"encoding/gob"
"errors"
"go/token"
"io"
"log"
"net"
"net/http"
"reflect"
"strings"
"sync"
)
const (
// Defaults used by HandleHTTP
DefaultRPCPath = "/_goRPC_"
DefaultDebugPath = "/debug/rpc"
)
// Precompute the reflect type for error. Can't use error directly
// because reflect.TypeOf takes an empty interface value. This is annoying.
var typeOfError = reflect.TypeOf((*error)(nil)).Elem()
type methodType struct {
sync.Mutex // protects counters
method reflect.Method
ArgType reflect.Type
ReplyType reflect.Type
numCalls uint
}
type service struct {
name string // name of service
rcvr reflect.Value // receiver of methods for the service
typ reflect.Type // type of the receiver
method map[string]*methodType // registered methods
}
// Request is a header written before every RPC call. It is used internally
// but documented here as an aid to debugging, such as when analyzing
// network traffic.
type Request struct {
ServiceMethod string // format: "Service.Method"
Seq uint64 // sequence number chosen by client
next *Request // for free list in Server
}
// Response is a header written before every RPC return. It is used internally
// but documented here as an aid to debugging, such as when analyzing
// network traffic.
type Response struct {
ServiceMethod string // echoes that of the Request
Seq uint64 // echoes that of the request
Error string // error, if any.
next *Response // for free list in Server
}
// Server represents an RPC Server.
type Server struct {
serviceMap sync.Map // map[string]*service
reqLock sync.Mutex // protects freeReq
freeReq *Request
respLock sync.Mutex // protects freeResp
freeResp *Response
}
// NewServer returns a new Server.
func NewServer() *Server {
return &Server{}
}
// DefaultServer is the default instance of *Server.
var DefaultServer = NewServer()
// Is this type exported or a builtin?
func isExportedOrBuiltinType(t reflect.Type) bool {
for t.Kind() == reflect.Pointer {
t = t.Elem()
}
// PkgPath will be non-empty even for an exported type,
// so we need to check the type name as well.
return token.IsExported(t.Name()) || t.PkgPath() == ""
}
// Register publishes in the server the set of methods of the
// receiver value that satisfy the following conditions:
// - exported method of exported type
// - two arguments, both of exported type
// - the second argument is a pointer
// - one return value, of type error
//
// It returns an error if the receiver is not an exported type or has
// no suitable methods. It also logs the error using package log.
// The client accesses each method using a string of the form "Type.Method",
// where Type is the receiver's concrete type.
func (server *Server) Register(rcvr any) error {
return server.register(rcvr, "", false)
}
// RegisterName is like Register but uses the provided name for the type
// instead of the receiver's concrete type.
func (server *Server) RegisterName(name string, rcvr any) error {
return server.register(rcvr, name, true)
}
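// For example (an illustrative sketch; Arith is the type from the package
// documentation), clients would then call "Calculator.Multiply" rather than
// "Arith.Multiply":
//
//	srv := rpc.NewServer()
//	if err := srv.RegisterName("Calculator", new(Arith)); err != nil {
//		log.Fatal(err)
//	}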
// logRegisterError specifies whether to log problems during method registration.
// To debug registration, recompile the package with this set to true.
const logRegisterError = false
func (server *Server) register(rcvr any, name string, useName bool) error {
s := new(service)
s.typ = reflect.TypeOf(rcvr)
s.rcvr = reflect.ValueOf(rcvr)
sname := name
if !useName {
sname = reflect.Indirect(s.rcvr).Type().Name()
}
if sname == "" {
s := "rpc.Register: no service name for type " + s.typ.String()
log.Print(s)
return errors.New(s)
}
if !useName && !token.IsExported(sname) {
s := "rpc.Register: type " + sname + " is not exported"
log.Print(s)
return errors.New(s)
}
s.name = sname
// Install the methods
s.method = suitableMethods(s.typ, logRegisterError)
if len(s.method) == 0 {
str := ""
// To help the user, see if a pointer receiver would work.
method := suitableMethods(reflect.PointerTo(s.typ), false)
if len(method) != 0 {
str = "rpc.Register: type " + sname + " has no exported methods of suitable type (hint: pass a pointer to value of that type)"
} else {
str = "rpc.Register: type " + sname + " has no exported methods of suitable type"
}
log.Print(str)
return errors.New(str)
}
if _, dup := server.serviceMap.LoadOrStore(sname, s); dup {
return errors.New("rpc: service already defined: " + sname)
}
return nil
}
// suitableMethods returns suitable RPC methods of typ. It will log
// errors if logErr is true.
func suitableMethods(typ reflect.Type, logErr bool) map[string]*methodType {
methods := make(map[string]*methodType)
for m := 0; m < typ.NumMethod(); m++ {
method := typ.Method(m)
mtype := method.Type
mname := method.Name
// Method must be exported.
if !method.IsExported() {
continue
}
// Method needs three ins: receiver, *args, *reply.
if mtype.NumIn() != 3 {
if logErr {
log.Printf("rpc.Register: method %q has %d input parameters; needs exactly three\n", mname, mtype.NumIn())
}
continue
}
// First arg need not be a pointer.
argType := mtype.In(1)
if !isExportedOrBuiltinType(argType) {
if logErr {
log.Printf("rpc.Register: argument type of method %q is not exported: %q\n", mname, argType)
}
continue
}
// Second arg must be a pointer.
replyType := mtype.In(2)
if replyType.Kind() != reflect.Pointer {
if logErr {
log.Printf("rpc.Register: reply type of method %q is not a pointer: %q\n", mname, replyType)
}
continue
}
// Reply type must be exported.
if !isExportedOrBuiltinType(replyType) {
if logErr {
log.Printf("rpc.Register: reply type of method %q is not exported: %q\n", mname, replyType)
}
continue
}
// Method needs one out.
if mtype.NumOut() != 1 {
if logErr {
log.Printf("rpc.Register: method %q has %d output parameters; needs exactly one\n", mname, mtype.NumOut())
}
continue
}
// The return type of the method must be error.
if returnType := mtype.Out(0); returnType != typeOfError {
if logErr {
log.Printf("rpc.Register: return type of method %q is %q, must be error\n", mname, returnType)
}
continue
}
methods[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType}
}
return methods
}
// A value sent as a placeholder for the server's response value when the server
// receives an invalid request. It is never decoded by the client since the Response
// contains an error when it is used.
var invalidRequest = struct{}{}
func (server *Server) sendResponse(sending *sync.Mutex, req *Request, reply any, codec ServerCodec, errmsg string) {
resp := server.getResponse()
// Encode the response header
resp.ServiceMethod = req.ServiceMethod
if errmsg != "" {
resp.Error = errmsg
reply = invalidRequest
}
resp.Seq = req.Seq
sending.Lock()
err := codec.WriteResponse(resp, reply)
if debugLog && err != nil {
log.Println("rpc: writing response:", err)
}
sending.Unlock()
server.freeResponse(resp)
}
func (m *methodType) NumCalls() (n uint) {
m.Lock()
n = m.numCalls
m.Unlock()
return n
}
func (s *service) call(server *Server, sending *sync.Mutex, wg *sync.WaitGroup, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
if wg != nil {
defer wg.Done()
}
mtype.Lock()
mtype.numCalls++
mtype.Unlock()
function := mtype.method.Func
// Invoke the method, providing a new value for the reply.
returnValues := function.Call([]reflect.Value{s.rcvr, argv, replyv})
// The return value for the method is an error.
errInter := returnValues[0].Interface()
errmsg := ""
if errInter != nil {
errmsg = errInter.(error).Error()
}
server.sendResponse(sending, req, replyv.Interface(), codec, errmsg)
server.freeRequest(req)
}
type gobServerCodec struct {
rwc io.ReadWriteCloser
dec *gob.Decoder
enc *gob.Encoder
encBuf *bufio.Writer
closed bool
}
func (c *gobServerCodec) ReadRequestHeader(r *Request) error {
return c.dec.Decode(r)
}
func (c *gobServerCodec) ReadRequestBody(body any) error {
return c.dec.Decode(body)
}
func (c *gobServerCodec) WriteResponse(r *Response, body any) (err error) {
if err = c.enc.Encode(r); err != nil {
if c.encBuf.Flush() == nil {
// Gob couldn't encode the header. Should not happen, so if it does,
// shut down the connection to signal that the connection is broken.
log.Println("rpc: gob error encoding response:", err)
c.Close()
}
return
}
if err = c.enc.Encode(body); err != nil {
if c.encBuf.Flush() == nil {
// Was a gob problem encoding the body but the header has been written.
// Shut down the connection to signal that the connection is broken.
log.Println("rpc: gob error encoding body:", err)
c.Close()
}
return
}
return c.encBuf.Flush()
}
func (c *gobServerCodec) Close() error {
if c.closed {
// Only call c.rwc.Close once; otherwise the semantics are undefined.
return nil
}
c.closed = true
return c.rwc.Close()
}
// ServeConn runs the server on a single connection.
// ServeConn blocks, serving the connection until the client hangs up.
// The caller typically invokes ServeConn in a go statement.
// ServeConn uses the gob wire format (see package gob) on the
// connection. To use an alternate codec, use ServeCodec.
// See NewClient's comment for information about concurrent access.
func (server *Server) ServeConn(conn io.ReadWriteCloser) {
buf := bufio.NewWriter(conn)
srv := &gobServerCodec{
rwc: conn,
dec: gob.NewDecoder(conn),
enc: gob.NewEncoder(buf),
encBuf: buf,
}
server.ServeCodec(srv)
}
// ServeCodec is like ServeConn but uses the specified codec to
// decode requests and encode responses.
func (server *Server) ServeCodec(codec ServerCodec) {
sending := new(sync.Mutex)
wg := new(sync.WaitGroup)
for {
service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
if err != nil {
if debugLog && err != io.EOF {
log.Println("rpc:", err)
}
if !keepReading {
break
}
// send a response if we actually managed to read a header.
if req != nil {
server.sendResponse(sending, req, invalidRequest, codec, err.Error())
server.freeRequest(req)
}
continue
}
wg.Add(1)
go service.call(server, sending, wg, mtype, req, argv, replyv, codec)
}
// We've seen that there are no more requests.
// Wait for responses to be sent before closing codec.
wg.Wait()
codec.Close()
}
// ServeRequest is like ServeCodec but synchronously serves a single request.
// It does not close the codec upon completion.
func (server *Server) ServeRequest(codec ServerCodec) error {
sending := new(sync.Mutex)
service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
if err != nil {
if !keepReading {
return err
}
// send a response if we actually managed to read a header.
if req != nil {
server.sendResponse(sending, req, invalidRequest, codec, err.Error())
server.freeRequest(req)
}
return err
}
service.call(server, sending, nil, mtype, req, argv, replyv, codec)
return nil
}
func (server *Server) getRequest() *Request {
server.reqLock.Lock()
req := server.freeReq
if req == nil {
req = new(Request)
} else {
server.freeReq = req.next
*req = Request{}
}
server.reqLock.Unlock()
return req
}
func (server *Server) freeRequest(req *Request) {
server.reqLock.Lock()
req.next = server.freeReq
server.freeReq = req
server.reqLock.Unlock()
}
func (server *Server) getResponse() *Response {
server.respLock.Lock()
resp := server.freeResp
if resp == nil {
resp = new(Response)
} else {
server.freeResp = resp.next
*resp = Response{}
}
server.respLock.Unlock()
return resp
}
func (server *Server) freeResponse(resp *Response) {
server.respLock.Lock()
resp.next = server.freeResp
server.freeResp = resp
server.respLock.Unlock()
}
func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *methodType, req *Request, argv, replyv reflect.Value, keepReading bool, err error) {
service, mtype, req, keepReading, err = server.readRequestHeader(codec)
if err != nil {
if !keepReading {
return
}
// discard body
codec.ReadRequestBody(nil)
return
}
// Decode the argument value.
argIsValue := false // if true, need to indirect before calling.
if mtype.ArgType.Kind() == reflect.Pointer {
argv = reflect.New(mtype.ArgType.Elem())
} else {
argv = reflect.New(mtype.ArgType)
argIsValue = true
}
// argv guaranteed to be a pointer now.
if err = codec.ReadRequestBody(argv.Interface()); err != nil {
return
}
if argIsValue {
argv = argv.Elem()
}
replyv = reflect.New(mtype.ReplyType.Elem())
switch mtype.ReplyType.Elem().Kind() {
case reflect.Map:
replyv.Elem().Set(reflect.MakeMap(mtype.ReplyType.Elem()))
case reflect.Slice:
replyv.Elem().Set(reflect.MakeSlice(mtype.ReplyType.Elem(), 0, 0))
}
return
}
func (server *Server) readRequestHeader(codec ServerCodec) (svc *service, mtype *methodType, req *Request, keepReading bool, err error) {
// Grab the request header.
req = server.getRequest()
err = codec.ReadRequestHeader(req)
if err != nil {
req = nil
if err == io.EOF || err == io.ErrUnexpectedEOF {
return
}
err = errors.New("rpc: server cannot decode request: " + err.Error())
return
}
// We read the header successfully. If we see an error now,
// we can still recover and move on to the next request.
keepReading = true
dot := strings.LastIndex(req.ServiceMethod, ".")
if dot < 0 {
err = errors.New("rpc: service/method request ill-formed: " + req.ServiceMethod)
return
}
serviceName := req.ServiceMethod[:dot]
methodName := req.ServiceMethod[dot+1:]
// Look up the request.
svci, ok := server.serviceMap.Load(serviceName)
if !ok {
err = errors.New("rpc: can't find service " + req.ServiceMethod)
return
}
svc = svci.(*service)
mtype = svc.method[methodName]
if mtype == nil {
err = errors.New("rpc: can't find method " + req.ServiceMethod)
}
return
}
// Accept accepts connections on the listener and serves requests
// for each incoming connection. Accept blocks until the listener
// returns a non-nil error. The caller typically invokes Accept in a
// go statement.
func (server *Server) Accept(lis net.Listener) {
for {
conn, err := lis.Accept()
if err != nil {
log.Print("rpc.Serve: accept:", err.Error())
return
}
go server.ServeConn(conn)
}
}
// Register publishes the receiver's methods in the DefaultServer.
func Register(rcvr any) error { return DefaultServer.Register(rcvr) }
// RegisterName is like Register but uses the provided name for the type
// instead of the receiver's concrete type.
func RegisterName(name string, rcvr any) error {
return DefaultServer.RegisterName(name, rcvr)
}
// A ServerCodec implements reading of RPC requests and writing of
// RPC responses for the server side of an RPC session.
// The server calls ReadRequestHeader and ReadRequestBody in pairs
// to read requests from the connection, and it calls WriteResponse to
// write a response back. The server calls Close when finished with the
// connection. ReadRequestBody may be called with a nil
// argument to force the body of the request to be read and discarded.
// See NewClient's comment for information about concurrent access.
type ServerCodec interface {
ReadRequestHeader(*Request) error
ReadRequestBody(any) error
WriteResponse(*Response, any) error
// Close can be called multiple times and must be idempotent.
Close() error
}
// ServeConn runs the DefaultServer on a single connection.
// ServeConn blocks, serving the connection until the client hangs up.
// The caller typically invokes ServeConn in a go statement.
// ServeConn uses the gob wire format (see package gob) on the
// connection. To use an alternate codec, use ServeCodec.
// See NewClient's comment for information about concurrent access.
func ServeConn(conn io.ReadWriteCloser) {
DefaultServer.ServeConn(conn)
}
// ServeCodec is like ServeConn but uses the specified codec to
// decode requests and encode responses.
func ServeCodec(codec ServerCodec) {
DefaultServer.ServeCodec(codec)
}
// ServeRequest is like ServeCodec but synchronously serves a single request.
// It does not close the codec upon completion.
func ServeRequest(codec ServerCodec) error {
return DefaultServer.ServeRequest(codec)
}
// Accept accepts connections on the listener and serves requests
// to DefaultServer for each incoming connection.
// Accept blocks; the caller typically invokes it in a go statement.
func Accept(lis net.Listener) { DefaultServer.Accept(lis) }
// Clients can connect to the RPC service using HTTP CONNECT to rpcPath.
var connected = "200 Connected to Go RPC"
// ServeHTTP implements an http.Handler that answers RPC requests.
func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if req.Method != "CONNECT" {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusMethodNotAllowed)
io.WriteString(w, "405 must CONNECT\n")
return
}
conn, _, err := w.(http.Hijacker).Hijack()
if err != nil {
log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error())
return
}
io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
server.ServeConn(conn)
}
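// The handshake on the wire looks roughly like this: the client issues an
// HTTP CONNECT to the RPC path and, after the 200 reply below, both ends
// speak the RPC wire format directly on the hijacked connection.
//
//	CONNECT /_goRPC_ HTTP/1.0
//
//	HTTP/1.0 200 Connected to Go RPC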
// HandleHTTP registers an HTTP handler for RPC messages on rpcPath,
// and a debugging handler on debugPath.
// It is still necessary to invoke http.Serve(), typically in a go statement.
func (server *Server) HandleHTTP(rpcPath, debugPath string) {
http.Handle(rpcPath, server)
http.Handle(debugPath, debugHTTP{server})
}
// HandleHTTP registers an HTTP handler for RPC messages to DefaultServer
// on DefaultRPCPath and a debugging handler on DefaultDebugPath.
// It is still necessary to invoke http.Serve(), typically in a go statement.
func HandleHTTP() {
DefaultServer.HandleHTTP(DefaultRPCPath, DefaultDebugPath)
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
import (
"internal/poll"
"io"
"os"
)
// sendFile copies the contents of r to c using the sendfile
// system call to minimize copies.
//
// If handled is true, sendFile returns the number (potentially zero) of
// bytes copied and any non-EOF error.
//
// If handled is false, sendFile performed no work.
func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) {
var remain int64 = 1<<63 - 1 // by default, copy until EOF
lr, ok := r.(*io.LimitedReader)
if ok {
remain, r = lr.N, lr.R
if remain <= 0 {
return 0, nil, true
}
}
f, ok := r.(*os.File)
if !ok {
return 0, nil, false
}
sc, err := f.SyscallConn()
if err != nil {
return 0, nil, false
}
var werr error
err = sc.Read(func(fd uintptr) bool {
written, werr, handled = poll.SendFile(&c.pfd, int(fd), remain)
return true
})
if err == nil {
err = werr
}
if lr != nil {
lr.N = remain - written
}
return written, wrapSyscallError("sendfile", err), handled
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package smtp
import (
"crypto/hmac"
"crypto/md5"
"errors"
"fmt"
)
// Auth is implemented by an SMTP authentication mechanism.
type Auth interface {
// Start begins an authentication with a server.
// It returns the name of the authentication protocol
// and optionally data to include in the initial AUTH message
// sent to the server.
// If it returns a non-nil error, the SMTP client aborts
// the authentication attempt and closes the connection.
Start(server *ServerInfo) (proto string, toServer []byte, err error)
// Next continues the authentication. The server has just sent
// the fromServer data. If more is true, the server expects a
// response, which Next should return as toServer; otherwise
// Next should return toServer == nil.
// If Next returns a non-nil error, the SMTP client aborts
// the authentication attempt and closes the connection.
Next(fromServer []byte, more bool) (toServer []byte, err error)
}
// ServerInfo records information about an SMTP server.
type ServerInfo struct {
Name string // SMTP server name
TLS bool // using TLS, with valid certificate for Name
Auth []string // advertised authentication mechanisms
}
type plainAuth struct {
identity, username, password string
host string
}
// PlainAuth returns an Auth that implements the PLAIN authentication
// mechanism as defined in RFC 4616. The returned Auth uses the given
// username and password to authenticate to host and act as identity.
// Usually identity should be the empty string, to act as username.
//
// PlainAuth will only send the credentials if the connection is using TLS
// or is connected to localhost. Otherwise authentication will fail with an
// error, without sending the credentials.
func PlainAuth(identity, username, password, host string) Auth {
return &plainAuth{identity, username, password, host}
}
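// A typical usage sketch (host name and credentials are placeholders):
//
//	auth := smtp.PlainAuth("", "user@example.com", "password", "mail.example.com")
//	err := smtp.SendMail("mail.example.com:25", auth, from, to, msg)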
func isLocalhost(name string) bool {
return name == "localhost" || name == "127.0.0.1" || name == "::1"
}
func (a *plainAuth) Start(server *ServerInfo) (string, []byte, error) {
// Must have TLS, or else localhost server.
// Note: If TLS is not true, then we can't trust ANYTHING in ServerInfo.
// In particular, it doesn't matter if the server advertises PLAIN auth.
// That might just be the attacker saying
// "it's ok, you can trust me with your password."
if !server.TLS && !isLocalhost(server.Name) {
return "", nil, errors.New("unencrypted connection")
}
if server.Name != a.host {
return "", nil, errors.New("wrong host name")
}
resp := []byte(a.identity + "\x00" + a.username + "\x00" + a.password)
return "PLAIN", resp, nil
}
func (a *plainAuth) Next(fromServer []byte, more bool) ([]byte, error) {
if more {
// We've already sent everything.
return nil, errors.New("unexpected server challenge")
}
return nil, nil
}
type cramMD5Auth struct {
username, secret string
}
// CRAMMD5Auth returns an Auth that implements the CRAM-MD5 authentication
// mechanism as defined in RFC 2195.
// The returned Auth uses the given username and secret to authenticate
// to the server using the challenge-response mechanism.
func CRAMMD5Auth(username, secret string) Auth {
return &cramMD5Auth{username, secret}
}
func (a *cramMD5Auth) Start(server *ServerInfo) (string, []byte, error) {
return "CRAM-MD5", nil, nil
}
func (a *cramMD5Auth) Next(fromServer []byte, more bool) ([]byte, error) {
if more {
d := hmac.New(md5.New, []byte(a.secret))
d.Write(fromServer)
s := make([]byte, 0, d.Size())
return fmt.Appendf(nil, "%s %x", a.username, d.Sum(s)), nil
}
return nil, nil
}
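// The reply produced above has the form "username", a space, and the
// 32-character hex HMAC-MD5 digest of the server's challenge keyed by the
// shared secret, e.g. (digest value is illustrative only):
//
//	joe 3dbc88f0624776a737b39093f6eb6427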
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package smtp implements the Simple Mail Transfer Protocol as defined in RFC 5321.
// It also implements the following extensions:
//
// 8BITMIME RFC 1652
// AUTH RFC 2554
// STARTTLS RFC 3207
//
// Additional extensions may be handled by clients.
//
// The smtp package is frozen and is not accepting new features.
// Some external packages provide more functionality. See:
//
// https://godoc.org/?q=smtp
package smtp
import (
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"io"
"net"
"net/textproto"
"strings"
)
// A Client represents a client connection to an SMTP server.
type Client struct {
// Text is the textproto.Conn used by the Client. It is exported to allow for
// clients to add extensions.
Text *textproto.Conn
// keep a reference to the connection so it can be used to create a TLS
// connection later
conn net.Conn
// whether the Client is using TLS
tls bool
serverName string
// map of supported extensions
ext map[string]string
// supported auth mechanisms
auth []string
localName string // the name to use in HELO/EHLO
didHello bool // whether we've said HELO/EHLO
helloError error // the error from the hello
}
// Dial returns a new Client connected to an SMTP server at addr.
// The addr must include a port, as in "mail.example.com:smtp".
func Dial(addr string) (*Client, error) {
conn, err := net.Dial("tcp", addr)
if err != nil {
return nil, err
}
host, _, _ := net.SplitHostPort(addr)
return NewClient(conn, host)
}
// NewClient returns a new Client using an existing connection and host as a
// server name to be used when authenticating.
func NewClient(conn net.Conn, host string) (*Client, error) {
text := textproto.NewConn(conn)
_, _, err := text.ReadResponse(220)
if err != nil {
text.Close()
return nil, err
}
c := &Client{Text: text, conn: conn, serverName: host, localName: "localhost"}
_, c.tls = conn.(*tls.Conn)
return c, nil
}
// Close closes the connection.
func (c *Client) Close() error {
return c.Text.Close()
}
// hello runs a hello exchange if needed.
func (c *Client) hello() error {
if !c.didHello {
c.didHello = true
err := c.ehlo()
if err != nil {
c.helloError = c.helo()
}
}
return c.helloError
}
// Hello sends a HELO or EHLO to the server as the given host name.
// Calling this method is only necessary if the client needs control
// over the host name used. The client will introduce itself as "localhost"
// automatically otherwise. If Hello is called, it must be called before
// any of the other methods.
func (c *Client) Hello(localName string) error {
if err := validateLine(localName); err != nil {
return err
}
if c.didHello {
return errors.New("smtp: Hello called after other methods")
}
c.localName = localName
return c.hello()
}
// cmd is a convenience function that sends a command and returns the response
func (c *Client) cmd(expectCode int, format string, args ...any) (int, string, error) {
id, err := c.Text.Cmd(format, args...)
if err != nil {
return 0, "", err
}
c.Text.StartResponse(id)
defer c.Text.EndResponse(id)
code, msg, err := c.Text.ReadResponse(expectCode)
return code, msg, err
}
// helo sends the HELO greeting to the server. It should be used only when the
// server does not support ehlo.
func (c *Client) helo() error {
c.ext = nil
_, _, err := c.cmd(250, "HELO %s", c.localName)
return err
}
// ehlo sends the EHLO (extended hello) greeting to the server. It
// should be the preferred greeting for servers that support it.
func (c *Client) ehlo() error {
_, msg, err := c.cmd(250, "EHLO %s", c.localName)
if err != nil {
return err
}
ext := make(map[string]string)
extList := strings.Split(msg, "\n")
if len(extList) > 1 {
extList = extList[1:]
for _, line := range extList {
k, v, _ := strings.Cut(line, " ")
ext[k] = v
}
}
if mechs, ok := ext["AUTH"]; ok {
c.auth = strings.Split(mechs, " ")
}
c.ext = ext
return err
}
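// For illustration, an EHLO reply whose message (as returned by
// ReadResponse, one line per 250- continuation, greeting line first) is
//
//	mail.example.com
//	8BITMIME
//	AUTH PLAIN LOGIN
//
// yields ext["8BITMIME"] == "" and ext["AUTH"] == "PLAIN LOGIN", and
// c.auth == []string{"PLAIN", "LOGIN"}.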
// StartTLS sends the STARTTLS command and encrypts all further communication.
// Only servers that advertise the STARTTLS extension support this function.
func (c *Client) StartTLS(config *tls.Config) error {
if err := c.hello(); err != nil {
return err
}
_, _, err := c.cmd(220, "STARTTLS")
if err != nil {
return err
}
c.conn = tls.Client(c.conn, config)
c.Text = textproto.NewConn(c.conn)
c.tls = true
return c.ehlo()
}
// TLSConnectionState returns the client's TLS connection state.
// The return values are their zero values if StartTLS did
// not succeed.
func (c *Client) TLSConnectionState() (state tls.ConnectionState, ok bool) {
tc, ok := c.conn.(*tls.Conn)
if !ok {
return
}
return tc.ConnectionState(), true
}
// Verify checks the validity of an email address on the server.
// If Verify returns nil, the address is valid. A non-nil return
// does not necessarily indicate an invalid address. Many servers
// will not verify addresses for security reasons.
func (c *Client) Verify(addr string) error {
if err := validateLine(addr); err != nil {
return err
}
if err := c.hello(); err != nil {
return err
}
_, _, err := c.cmd(250, "VRFY %s", addr)
return err
}
// Auth authenticates a client using the provided authentication mechanism.
// A failed authentication closes the connection.
// Only servers that advertise the AUTH extension support this function.
func (c *Client) Auth(a Auth) error {
if err := c.hello(); err != nil {
return err
}
encoding := base64.StdEncoding
mech, resp, err := a.Start(&ServerInfo{c.serverName, c.tls, c.auth})
if err != nil {
c.Quit()
return err
}
resp64 := make([]byte, encoding.EncodedLen(len(resp)))
encoding.Encode(resp64, resp)
code, msg64, err := c.cmd(0, strings.TrimSpace(fmt.Sprintf("AUTH %s %s", mech, resp64)))
for err == nil {
var msg []byte
switch code {
case 334:
msg, err = encoding.DecodeString(msg64)
case 235:
// the last message isn't base64 because it isn't a challenge
msg = []byte(msg64)
default:
err = &textproto.Error{Code: code, Msg: msg64}
}
if err == nil {
resp, err = a.Next(msg, code == 334)
}
if err != nil {
// abort the AUTH
c.cmd(501, "*")
c.Quit()
break
}
if resp == nil {
break
}
resp64 = make([]byte, encoding.EncodedLen(len(resp)))
encoding.Encode(resp64, resp)
code, msg64, err = c.cmd(0, string(resp64))
}
return err
}
// Mail issues a MAIL command to the server using the provided email address.
// If the server supports the 8BITMIME extension, Mail adds the BODY=8BITMIME
// parameter. If the server supports the SMTPUTF8 extension, Mail adds the
// SMTPUTF8 parameter.
// This initiates a mail transaction and is followed by one or more Rcpt calls.
func (c *Client) Mail(from string) error {
if err := validateLine(from); err != nil {
return err
}
if err := c.hello(); err != nil {
return err
}
cmdStr := "MAIL FROM:<%s>"
if c.ext != nil {
if _, ok := c.ext["8BITMIME"]; ok {
cmdStr += " BODY=8BITMIME"
}
if _, ok := c.ext["SMTPUTF8"]; ok {
cmdStr += " SMTPUTF8"
}
}
_, _, err := c.cmd(250, cmdStr, from)
return err
}
// Rcpt issues a RCPT command to the server using the provided email address.
// A call to Rcpt must be preceded by a call to Mail and may be followed by
// a Data call or another Rcpt call.
func (c *Client) Rcpt(to string) error {
if err := validateLine(to); err != nil {
return err
}
_, _, err := c.cmd(25, "RCPT TO:<%s>", to)
return err
}
type dataCloser struct {
c *Client
io.WriteCloser
}
func (d *dataCloser) Close() error {
d.WriteCloser.Close()
_, _, err := d.c.Text.ReadResponse(250)
return err
}
// Data issues a DATA command to the server and returns a writer that
// can be used to write the mail headers and body. The caller should
// close the writer before calling any more methods on c. A call to
// Data must be preceded by one or more calls to Rcpt.
func (c *Client) Data() (io.WriteCloser, error) {
_, _, err := c.cmd(354, "DATA")
if err != nil {
return nil, err
}
return &dataCloser{c, c.Text.DotWriter()}, nil
}
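// A usage sketch (c is a connected *Client that has already completed Mail
// and Rcpt; error handling elided):
//
//	w, _ := c.Data()
//	fmt.Fprintf(w, "Subject: test\r\n\r\nHello!\r\n")
//	w.Close()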
var testHookStartTLS func(*tls.Config) // nil, except for tests
// SendMail connects to the server at addr, switches to TLS if
// possible, authenticates with the optional mechanism a if possible,
// and then sends an email from address from, to addresses to, with
// message msg.
// The addr must include a port, as in "mail.example.com:smtp".
//
// The addresses in the to parameter are the SMTP RCPT addresses.
//
// The msg parameter should be an RFC 822-style email with headers
// first, a blank line, and then the message body. The lines of msg
// should be CRLF terminated. The msg headers should usually include
// fields such as "From", "To", "Subject", and "Cc". Sending "Bcc"
// messages is accomplished by including an email address in the to
// parameter but not including it in the msg headers.
//
// The SendMail function and the net/smtp package are low-level
// mechanisms and provide no support for DKIM signing, MIME
// attachments (see the mime/multipart package), or other mail
// functionality. Higher-level packages exist outside of the standard
// library.
func SendMail(addr string, a Auth, from string, to []string, msg []byte) error {
if err := validateLine(from); err != nil {
return err
}
for _, recp := range to {
if err := validateLine(recp); err != nil {
return err
}
}
c, err := Dial(addr)
if err != nil {
return err
}
defer c.Close()
if err = c.hello(); err != nil {
return err
}
if ok, _ := c.Extension("STARTTLS"); ok {
config := &tls.Config{ServerName: c.serverName}
if testHookStartTLS != nil {
testHookStartTLS(config)
}
if err = c.StartTLS(config); err != nil {
return err
}
}
if a != nil && c.ext != nil {
if _, ok := c.ext["AUTH"]; !ok {
return errors.New("smtp: server doesn't support AUTH")
}
if err = c.Auth(a); err != nil {
return err
}
}
if err = c.Mail(from); err != nil {
return err
}
for _, addr := range to {
if err = c.Rcpt(addr); err != nil {
return err
}
}
w, err := c.Data()
if err != nil {
return err
}
_, err = w.Write(msg)
if err != nil {
return err
}
err = w.Close()
if err != nil {
return err
}
return c.Quit()
}
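// A minimal usage sketch (host and addresses are placeholders):
//
//	msg := []byte("To: recipient@example.net\r\n" +
//		"Subject: hello\r\n" +
//		"\r\n" +
//		"This is the email body.\r\n")
//	err := smtp.SendMail("mail.example.com:25", nil,
//		"sender@example.org", []string{"recipient@example.net"}, msg)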
// Extension reports whether an extension is supported by the server.
// The extension name is case-insensitive. If the extension is supported,
// Extension also returns a string that contains any parameters the
// server specifies for the extension.
func (c *Client) Extension(ext string) (bool, string) {
if err := c.hello(); err != nil {
return false, ""
}
if c.ext == nil {
return false, ""
}
ext = strings.ToUpper(ext)
param, ok := c.ext[ext]
return ok, param
}
// Reset sends the RSET command to the server, aborting the current mail
// transaction.
func (c *Client) Reset() error {
if err := c.hello(); err != nil {
return err
}
_, _, err := c.cmd(250, "RSET")
return err
}
// Noop sends the NOOP command to the server. It does nothing but check
// that the connection to the server is okay.
func (c *Client) Noop() error {
if err := c.hello(); err != nil {
return err
}
_, _, err := c.cmd(250, "NOOP")
return err
}
// Quit sends the QUIT command and closes the connection to the server.
func (c *Client) Quit() error {
if err := c.hello(); err != nil {
return err
}
_, _, err := c.cmd(221, "QUIT")
if err != nil {
return err
}
return c.Text.Close()
}
// validateLine checks that a line does not contain CR or LF, as required by RFC 5321.
func validateLine(line string) error {
if strings.ContainsAny(line, "\n\r") {
return errors.New("smtp: A line must not contain CR or LF")
}
return nil
}
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements sysSocket for platforms that provide a fast path for
// setting SetNonblock and CloseOnExec.
//go:build dragonfly || freebsd || linux || netbsd || openbsd || solaris
package net
import (
"os"
"syscall"
)
// Wrapper around the socket system call that marks the returned file
// descriptor as nonblocking and close-on-exec.
func sysSocket(family, sotype, proto int) (int, error) {
s, err := socketFunc(family, sotype|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, proto)
if err != nil {
return -1, os.NewSyscallError("socket", err)
}
return s, nil
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
import (
"internal/syscall/unix"
"syscall"
)
// Linux stores the backlog as:
//
// - uint16 in kernel version < 4.1,
// - uint32 in kernel version >= 4.1
//
// Truncate number to avoid wrapping.
//
// See issue 5030 and 41470.
func maxAckBacklog(n int) int {
major, minor := unix.KernelVersion()
size := 16
if major > 4 || (major == 4 && minor >= 1) {
size = 32
}
var max uint = 1<<size - 1
if uint(n) > max {
n = int(max)
}
return n
}
func maxListenerBacklog() int {
fd, err := open("/proc/sys/net/core/somaxconn")
if err != nil {
return syscall.SOMAXCONN
}
defer fd.close()
l, ok := fd.readLine()
if !ok {
return syscall.SOMAXCONN
}
f := getFields(l)
n, _, ok := dtoi(f[0])
if n == 0 || !ok {
return syscall.SOMAXCONN
}
if n > 1<<16-1 {
return maxAckBacklog(n)
}
return n
}
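// Worked example: on a kernel >= 4.1 the backlog field is 32 bits wide, so
// a somaxconn value of 1<<18 is returned unchanged; on an older kernel
// maxAckBacklog caps it at 1<<16 - 1 (65535) to avoid wrapping.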
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || windows
package net
import (
"context"
"internal/poll"
"os"
"syscall"
)
// socket returns a network file descriptor that is ready for
// asynchronous I/O using the network poller.
func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (fd *netFD, err error) {
s, err := sysSocket(family, sotype, proto)
if err != nil {
return nil, err
}
if err = setDefaultSockopts(s, family, sotype, ipv6only); err != nil {
poll.CloseFunc(s)
return nil, err
}
if fd, err = newFD(s, family, sotype, net); err != nil {
poll.CloseFunc(s)
return nil, err
}
// This function makes a network file descriptor for the
// following applications:
//
// - An endpoint holder that opens a passive stream
// connection, known as a stream listener
//
// - An endpoint holder that opens a destination-unspecific
// datagram connection, known as a datagram listener
//
// - An endpoint holder that opens an active stream or a
// destination-specific datagram connection, known as a
// dialer
//
// - An endpoint holder that opens the other connection, such
// as talking to the protocol stack inside the kernel
//
// Stream and datagram listeners require only named sockets, so we
// can assume that it's just a request from a stream or datagram
// listener when laddr is not nil but raddr is nil. Otherwise we
// assume it's just for dialers or the other connection holders.
if laddr != nil && raddr == nil {
switch sotype {
case syscall.SOCK_STREAM, syscall.SOCK_SEQPACKET:
if err := fd.listenStream(ctx, laddr, listenerBacklog(), ctrlCtxFn); err != nil {
fd.Close()
return nil, err
}
return fd, nil
case syscall.SOCK_DGRAM:
if err := fd.listenDatagram(ctx, laddr, ctrlCtxFn); err != nil {
fd.Close()
return nil, err
}
return fd, nil
}
}
if err := fd.dial(ctx, laddr, raddr, ctrlCtxFn); err != nil {
fd.Close()
return nil, err
}
return fd, nil
}
func (fd *netFD) ctrlNetwork() string {
switch fd.net {
case "unix", "unixgram", "unixpacket":
return fd.net
}
switch fd.net[len(fd.net)-1] {
case '4', '6':
return fd.net
}
if fd.family == syscall.AF_INET {
return fd.net + "4"
}
return fd.net + "6"
}
func (fd *netFD) addrFunc() func(syscall.Sockaddr) Addr {
switch fd.family {
case syscall.AF_INET, syscall.AF_INET6:
switch fd.sotype {
case syscall.SOCK_STREAM:
return sockaddrToTCP
case syscall.SOCK_DGRAM:
return sockaddrToUDP
case syscall.SOCK_RAW:
return sockaddrToIP
}
case syscall.AF_UNIX:
switch fd.sotype {
case syscall.SOCK_STREAM:
return sockaddrToUnix
case syscall.SOCK_DGRAM:
return sockaddrToUnixgram
case syscall.SOCK_SEQPACKET:
return sockaddrToUnixpacket
}
}
return func(syscall.Sockaddr) Addr { return nil }
}
func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) error {
var c *rawConn
var err error
if ctrlCtxFn != nil {
c, err = newRawConn(fd)
if err != nil {
return err
}
var ctrlAddr string
if raddr != nil {
ctrlAddr = raddr.String()
} else if laddr != nil {
ctrlAddr = laddr.String()
}
if err := ctrlCtxFn(ctx, fd.ctrlNetwork(), ctrlAddr, c); err != nil {
return err
}
}
var lsa syscall.Sockaddr
if laddr != nil {
if lsa, err = laddr.sockaddr(fd.family); err != nil {
return err
} else if lsa != nil {
if err = syscall.Bind(fd.pfd.Sysfd, lsa); err != nil {
return os.NewSyscallError("bind", err)
}
}
}
var rsa syscall.Sockaddr // remote address from the user
var crsa syscall.Sockaddr // remote address we actually connected to
if raddr != nil {
if rsa, err = raddr.sockaddr(fd.family); err != nil {
return err
}
if crsa, err = fd.connect(ctx, lsa, rsa); err != nil {
return err
}
fd.isConnected = true
} else {
if err := fd.init(); err != nil {
return err
}
}
// Record the local and remote addresses from the actual socket.
// Get the local address by calling Getsockname.
// For the remote address, use
// 1) the one returned by the connect method, if any; or
// 2) the one from Getpeername, if it succeeds; or
// 3) the one passed to us as the raddr parameter.
lsa, _ = syscall.Getsockname(fd.pfd.Sysfd)
if crsa != nil {
fd.setAddr(fd.addrFunc()(lsa), fd.addrFunc()(crsa))
} else if rsa, _ = syscall.Getpeername(fd.pfd.Sysfd); rsa != nil {
fd.setAddr(fd.addrFunc()(lsa), fd.addrFunc()(rsa))
} else {
fd.setAddr(fd.addrFunc()(lsa), raddr)
}
return nil
}
func (fd *netFD) listenStream(ctx context.Context, laddr sockaddr, backlog int, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) error {
var err error
if err = setDefaultListenerSockopts(fd.pfd.Sysfd); err != nil {
return err
}
var lsa syscall.Sockaddr
if lsa, err = laddr.sockaddr(fd.family); err != nil {
return err
}
if ctrlCtxFn != nil {
c, err := newRawConn(fd)
if err != nil {
return err
}
if err := ctrlCtxFn(ctx, fd.ctrlNetwork(), laddr.String(), c); err != nil {
return err
}
}
if err = syscall.Bind(fd.pfd.Sysfd, lsa); err != nil {
return os.NewSyscallError("bind", err)
}
if err = listenFunc(fd.pfd.Sysfd, backlog); err != nil {
return os.NewSyscallError("listen", err)
}
if err = fd.init(); err != nil {
return err
}
lsa, _ = syscall.Getsockname(fd.pfd.Sysfd)
fd.setAddr(fd.addrFunc()(lsa), nil)
return nil
}
func (fd *netFD) listenDatagram(ctx context.Context, laddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) error {
switch addr := laddr.(type) {
case *UDPAddr:
// We provide a socket that listens to a wildcard
// address with reusable UDP port when the given laddr
// is an appropriate UDP multicast address prefix.
// This makes it possible for a single UDP listener to
// join multiple different group addresses, and for
// multiple UDP listeners that listen on the same UDP
// port to join the same group address.
if addr.IP != nil && addr.IP.IsMulticast() {
if err := setDefaultMulticastSockopts(fd.pfd.Sysfd); err != nil {
return err
}
addr := *addr
switch fd.family {
case syscall.AF_INET:
addr.IP = IPv4zero
case syscall.AF_INET6:
addr.IP = IPv6unspecified
}
laddr = &addr
}
}
var err error
var lsa syscall.Sockaddr
if lsa, err = laddr.sockaddr(fd.family); err != nil {
return err
}
if ctrlCtxFn != nil {
c, err := newRawConn(fd)
if err != nil {
return err
}
if err := ctrlCtxFn(ctx, fd.ctrlNetwork(), laddr.String(), c); err != nil {
return err
}
}
if err = syscall.Bind(fd.pfd.Sysfd, lsa); err != nil {
return os.NewSyscallError("bind", err)
}
if err = fd.init(); err != nil {
return err
}
lsa, _ = syscall.Getsockname(fd.pfd.Sysfd)
fd.setAddr(fd.addrFunc()(lsa), nil)
return nil
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
import (
"os"
"syscall"
)
func setDefaultSockopts(s, family, sotype int, ipv6only bool) error {
if family == syscall.AF_INET6 && sotype != syscall.SOCK_RAW {
// Allow both IP versions even if the OS default
// is otherwise. Note that some operating systems
// never admit this option.
syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, boolint(ipv6only))
}
if (sotype == syscall.SOCK_DGRAM || sotype == syscall.SOCK_RAW) && family != syscall.AF_UNIX {
// Allow broadcast.
return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1))
}
return nil
}
func setDefaultListenerSockopts(s int) error {
// Allow reuse of recently-used addresses.
return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1))
}
func setDefaultMulticastSockopts(s int) error {
// Allow multicast UDP and raw IP datagram sockets to listen
// concurrently across multiple listeners.
return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1))
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || windows
package net
import (
"internal/bytealg"
"runtime"
"syscall"
)
// Boolean to int.
func boolint(b bool) int {
if b {
return 1
}
return 0
}
func ipv4AddrToInterface(ip IP) (*Interface, error) {
ift, err := Interfaces()
if err != nil {
return nil, err
}
for _, ifi := range ift {
ifat, err := ifi.Addrs()
if err != nil {
return nil, err
}
for _, ifa := range ifat {
switch v := ifa.(type) {
case *IPAddr:
if ip.Equal(v.IP) {
return &ifi, nil
}
case *IPNet:
if ip.Equal(v.IP) {
return &ifi, nil
}
}
}
}
if ip.Equal(IPv4zero) {
return nil, nil
}
return nil, errNoSuchInterface
}
func interfaceToIPv4Addr(ifi *Interface) (IP, error) {
if ifi == nil {
return IPv4zero, nil
}
ifat, err := ifi.Addrs()
if err != nil {
return nil, err
}
for _, ifa := range ifat {
switch v := ifa.(type) {
case *IPAddr:
if v.IP.To4() != nil {
return v.IP, nil
}
case *IPNet:
if v.IP.To4() != nil {
return v.IP, nil
}
}
}
return nil, errNoSuchInterface
}
func setIPv4MreqToInterface(mreq *syscall.IPMreq, ifi *Interface) error {
if ifi == nil {
return nil
}
ifat, err := ifi.Addrs()
if err != nil {
return err
}
for _, ifa := range ifat {
switch v := ifa.(type) {
case *IPAddr:
if a := v.IP.To4(); a != nil {
copy(mreq.Interface[:], a)
goto done
}
case *IPNet:
if a := v.IP.To4(); a != nil {
copy(mreq.Interface[:], a)
goto done
}
}
}
done:
if bytealg.Equal(mreq.Multiaddr[:], IPv4zero.To4()) {
return errNoSuchMulticastInterface
}
return nil
}
func setReadBuffer(fd *netFD, bytes int) error {
err := fd.pfd.SetsockoptInt(syscall.SOL_SOCKET, syscall.SO_RCVBUF, bytes)
runtime.KeepAlive(fd)
return wrapSyscallError("setsockopt", err)
}
func setWriteBuffer(fd *netFD, bytes int) error {
err := fd.pfd.SetsockoptInt(syscall.SOL_SOCKET, syscall.SO_SNDBUF, bytes)
runtime.KeepAlive(fd)
return wrapSyscallError("setsockopt", err)
}
func setKeepAlive(fd *netFD, keepalive bool) error {
err := fd.pfd.SetsockoptInt(syscall.SOL_SOCKET, syscall.SO_KEEPALIVE, boolint(keepalive))
runtime.KeepAlive(fd)
return wrapSyscallError("setsockopt", err)
}
func setLinger(fd *netFD, sec int) error {
var l syscall.Linger
if sec >= 0 {
l.Onoff = 1
l.Linger = int32(sec)
} else {
l.Onoff = 0
l.Linger = 0
}
err := fd.pfd.SetsockoptLinger(syscall.SOL_SOCKET, syscall.SO_LINGER, &l)
runtime.KeepAlive(fd)
return wrapSyscallError("setsockopt", err)
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
import (
"runtime"
"syscall"
)
func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error {
var v int32
if ifi != nil {
v = int32(ifi.Index)
}
mreq := &syscall.IPMreqn{Ifindex: v}
err := fd.pfd.SetsockoptIPMreqn(syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, mreq)
runtime.KeepAlive(fd)
return wrapSyscallError("setsockopt", err)
}
func setIPv4MulticastLoopback(fd *netFD, v bool) error {
err := fd.pfd.SetsockoptInt(syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, boolint(v))
runtime.KeepAlive(fd)
return wrapSyscallError("setsockopt", err)
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || windows
package net
import (
"runtime"
"syscall"
)
func joinIPv4Group(fd *netFD, ifi *Interface, ip IP) error {
mreq := &syscall.IPMreq{Multiaddr: [4]byte{ip[0], ip[1], ip[2], ip[3]}}
if err := setIPv4MreqToInterface(mreq, ifi); err != nil {
return err
}
err := fd.pfd.SetsockoptIPMreq(syscall.IPPROTO_IP, syscall.IP_ADD_MEMBERSHIP, mreq)
runtime.KeepAlive(fd)
return wrapSyscallError("setsockopt", err)
}
func setIPv6MulticastInterface(fd *netFD, ifi *Interface) error {
var v int
if ifi != nil {
v = ifi.Index
}
err := fd.pfd.SetsockoptInt(syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_IF, v)
runtime.KeepAlive(fd)
return wrapSyscallError("setsockopt", err)
}
func setIPv6MulticastLoopback(fd *netFD, v bool) error {
err := fd.pfd.SetsockoptInt(syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_LOOP, boolint(v))
runtime.KeepAlive(fd)
return wrapSyscallError("setsockopt", err)
}
func joinIPv6Group(fd *netFD, ifi *Interface, ip IP) error {
mreq := &syscall.IPv6Mreq{}
copy(mreq.Multiaddr[:], ip)
if ifi != nil {
mreq.Interface = uint32(ifi.Index)
}
err := fd.pfd.SetsockoptIPv6Mreq(syscall.IPPROTO_IPV6, syscall.IPV6_JOIN_GROUP, mreq)
runtime.KeepAlive(fd)
return wrapSyscallError("setsockopt", err)
}
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
import (
"internal/poll"
"io"
)
// splice transfers data from r to c using the splice system call to minimize
// copies from and to userspace. c must be a TCP connection. Currently, splice
// is only enabled if r is a TCP or a stream-oriented Unix connection.
//
// If splice returns handled == false, it has performed no work.
func splice(c *netFD, r io.Reader) (written int64, err error, handled bool) {
var remain int64 = 1<<63 - 1 // by default, copy until EOF
lr, ok := r.(*io.LimitedReader)
if ok {
remain, r = lr.N, lr.R
if remain <= 0 {
return 0, nil, true
}
}
var s *netFD
if tc, ok := r.(*TCPConn); ok {
s = tc.fd
} else if uc, ok := r.(*UnixConn); ok {
if uc.fd.net != "unix" {
return 0, nil, false
}
s = uc.fd
} else {
return 0, nil, false
}
written, handled, sc, err := poll.Splice(&c.pfd, &s.pfd, remain)
if lr != nil {
lr.N -= written
}
return written, wrapSyscallError(sc, err), handled
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
import (
"context"
"internal/itoa"
"io"
"net/netip"
"os"
"syscall"
"time"
)
// BUG(mikio): On JS and Windows, the File method of TCPConn and
// TCPListener is not implemented.
// TCPAddr represents the address of a TCP end point.
type TCPAddr struct {
IP IP
Port int
Zone string // IPv6 scoped addressing zone
}
// AddrPort returns the TCPAddr a as a netip.AddrPort.
//
// If a.Port does not fit in a uint16, it's silently truncated.
//
// If a is nil, a zero value is returned.
func (a *TCPAddr) AddrPort() netip.AddrPort {
if a == nil {
return netip.AddrPort{}
}
na, _ := netip.AddrFromSlice(a.IP)
na = na.WithZone(a.Zone)
return netip.AddrPortFrom(na, uint16(a.Port))
}
// Network returns the address's network name, "tcp".
func (a *TCPAddr) Network() string { return "tcp" }
func (a *TCPAddr) String() string {
if a == nil {
return "<nil>"
}
ip := ipEmptyString(a.IP)
if a.Zone != "" {
return JoinHostPort(ip+"%"+a.Zone, itoa.Itoa(a.Port))
}
return JoinHostPort(ip, itoa.Itoa(a.Port))
}
func (a *TCPAddr) isWildcard() bool {
if a == nil || a.IP == nil {
return true
}
return a.IP.IsUnspecified()
}
func (a *TCPAddr) opAddr() Addr {
if a == nil {
return nil
}
return a
}
// ResolveTCPAddr returns the address of a TCP end point.
//
// The network must be a TCP network name.
//
// If the host in the address parameter is not a literal IP address or
// the port is not a literal port number, ResolveTCPAddr resolves the
// address to the address of a TCP end point.
// Otherwise, it parses the address as a pair of literal IP address
// and port number.
// The address parameter can use a host name, but this is not
// recommended, because it will return at most one of the host name's
// IP addresses.
//
// See func Dial for a description of the network and address
// parameters.
func ResolveTCPAddr(network, address string) (*TCPAddr, error) {
switch network {
case "tcp", "tcp4", "tcp6":
case "": // a hint wildcard for Go 1.0 undocumented behavior
network = "tcp"
default:
return nil, UnknownNetworkError(network)
}
addrs, err := DefaultResolver.internetAddrList(context.Background(), network, address)
if err != nil {
return nil, err
}
return addrs.forResolve(network, address).(*TCPAddr), nil
}
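// For example (an illustrative sketch; the host name is a placeholder):
//
//	addr, err := net.ResolveTCPAddr("tcp", "golang.org:http")
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println(addr.IP, addr.Port) // port 80, from the "http" service name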
// TCPAddrFromAddrPort returns addr as a TCPAddr. If addr.IsValid() is false,
// then the returned TCPAddr will contain a nil IP field, indicating an
// address family-agnostic unspecified address.
func TCPAddrFromAddrPort(addr netip.AddrPort) *TCPAddr {
return &TCPAddr{
IP: addr.Addr().AsSlice(),
Zone: addr.Addr().Zone(),
Port: int(addr.Port()),
}
}
// TCPConn is an implementation of the Conn interface for TCP network
// connections.
type TCPConn struct {
conn
}
// SyscallConn returns a raw network connection.
// This implements the syscall.Conn interface.
func (c *TCPConn) SyscallConn() (syscall.RawConn, error) {
if !c.ok() {
return nil, syscall.EINVAL
}
return newRawConn(c.fd)
}
// ReadFrom implements the io.ReaderFrom ReadFrom method.
func (c *TCPConn) ReadFrom(r io.Reader) (int64, error) {
if !c.ok() {
return 0, syscall.EINVAL
}
n, err := c.readFrom(r)
if err != nil && err != io.EOF {
err = &OpError{Op: "readfrom", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
}
return n, err
}
// CloseRead shuts down the reading side of the TCP connection.
// Most callers should just use Close.
func (c *TCPConn) CloseRead() error {
if !c.ok() {
return syscall.EINVAL
}
if err := c.fd.closeRead(); err != nil {
return &OpError{Op: "close", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
}
return nil
}
// CloseWrite shuts down the writing side of the TCP connection.
// Most callers should just use Close.
func (c *TCPConn) CloseWrite() error {
if !c.ok() {
return syscall.EINVAL
}
if err := c.fd.closeWrite(); err != nil {
return &OpError{Op: "close", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
}
return nil
}
// SetLinger sets the behavior of Close on a connection which still
// has data waiting to be sent or to be acknowledged.
//
// If sec < 0 (the default), the operating system finishes sending the
// data in the background.
//
// If sec == 0, the operating system discards any unsent or
// unacknowledged data.
//
// If sec > 0, the data is sent in the background as with sec < 0. On
// some operating systems after sec seconds have elapsed any remaining
// unsent data may be discarded.
func (c *TCPConn) SetLinger(sec int) error {
if !c.ok() {
return syscall.EINVAL
}
if err := setLinger(c.fd, sec); err != nil {
return &OpError{Op: "set", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
}
return nil
}
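// Editor's note: a usage sketch (not in the original source). With
// SetLinger(0), Close discards unsent and unacknowledged data; on most
// stacks this aborts the connection with a RST rather than a normal
// shutdown. Here tc is an established *TCPConn (hypothetical).
//
//	if err := tc.SetLinger(0); err != nil {
//		log.Print(err)
//	}
//	tc.Close() // queued data is dropped immediately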
// SetKeepAlive sets whether the operating system should send
// keep-alive messages on the connection.
func (c *TCPConn) SetKeepAlive(keepalive bool) error {
if !c.ok() {
return syscall.EINVAL
}
if err := setKeepAlive(c.fd, keepalive); err != nil {
return &OpError{Op: "set", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
}
return nil
}
// SetKeepAlivePeriod sets the period between keep-alives.
func (c *TCPConn) SetKeepAlivePeriod(d time.Duration) error {
if !c.ok() {
return syscall.EINVAL
}
if err := setKeepAlivePeriod(c.fd, d); err != nil {
return &OpError{Op: "set", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
}
return nil
}
// SetNoDelay controls whether the operating system should delay
// packet transmission in hopes of sending fewer packets (Nagle's
// algorithm). The default is true (no delay), meaning that data is
// sent as soon as possible after a Write.
func (c *TCPConn) SetNoDelay(noDelay bool) error {
if !c.ok() {
return syscall.EINVAL
}
if err := setNoDelay(c.fd, noDelay); err != nil {
return &OpError{Op: "set", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
}
return nil
}
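// Editor's note: a one-line sketch (not in the original source). Passing
// false re-enables Nagle's algorithm, coalescing small writes into fewer
// segments at the cost of latency; tc is an established *TCPConn
// (hypothetical).
//
//	tc.SetNoDelay(false) // batch small writes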
func newTCPConn(fd *netFD, keepAlive time.Duration, keepAliveHook func(time.Duration)) *TCPConn {
setNoDelay(fd, true)
if keepAlive == 0 {
keepAlive = defaultTCPKeepAlive
}
if keepAlive > 0 {
setKeepAlive(fd, true)
setKeepAlivePeriod(fd, keepAlive)
if keepAliveHook != nil {
keepAliveHook(keepAlive)
}
}
return &TCPConn{conn{fd}}
}
// DialTCP acts like Dial for TCP networks.
//
// The network must be a TCP network name; see func Dial for details.
//
// If laddr is nil, a local address is automatically chosen.
// If the IP field of raddr is nil or an unspecified IP address, the
// local system is assumed.
func DialTCP(network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
switch network {
case "tcp", "tcp4", "tcp6":
default:
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: UnknownNetworkError(network)}
}
if raddr == nil {
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
}
sd := &sysDialer{network: network, address: raddr.String()}
c, err := sd.dialTCP(context.Background(), laddr, raddr)
if err != nil {
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
}
return c, nil
}
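// Editor's note: a minimal dial sketch (not in the original source); the
// remote address is hypothetical.
//
//	raddr, err := ResolveTCPAddr("tcp", "127.0.0.1:8080")
//	if err != nil {
//		log.Fatal(err)
//	}
//	c, err := DialTCP("tcp", nil, raddr) // nil laddr: system picks one
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer c.Close()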
// TCPListener is a TCP network listener. Clients should typically
// use variables of type Listener instead of assuming TCP.
type TCPListener struct {
fd *netFD
lc ListenConfig
}
// SyscallConn returns a raw network connection.
// This implements the syscall.Conn interface.
//
// The returned RawConn only supports calling Control. Read and
// Write return an error.
func (l *TCPListener) SyscallConn() (syscall.RawConn, error) {
if !l.ok() {
return nil, syscall.EINVAL
}
return newRawListener(l.fd)
}
// AcceptTCP accepts the next incoming call and returns the new
// connection.
func (l *TCPListener) AcceptTCP() (*TCPConn, error) {
if !l.ok() {
return nil, syscall.EINVAL
}
c, err := l.accept()
if err != nil {
return nil, &OpError{Op: "accept", Net: l.fd.net, Source: nil, Addr: l.fd.laddr, Err: err}
}
return c, nil
}
// Accept implements the Accept method in the Listener interface; it
// waits for the next call and returns a generic Conn.
func (l *TCPListener) Accept() (Conn, error) {
if !l.ok() {
return nil, syscall.EINVAL
}
c, err := l.accept()
if err != nil {
return nil, &OpError{Op: "accept", Net: l.fd.net, Source: nil, Addr: l.fd.laddr, Err: err}
}
return c, nil
}
// Close stops listening on the TCP address.
// Already accepted connections are not closed.
func (l *TCPListener) Close() error {
if !l.ok() {
return syscall.EINVAL
}
if err := l.close(); err != nil {
return &OpError{Op: "close", Net: l.fd.net, Source: nil, Addr: l.fd.laddr, Err: err}
}
return nil
}
// Addr returns the listener's network address, a *TCPAddr.
// The Addr returned is shared by all invocations of Addr, so
// do not modify it.
func (l *TCPListener) Addr() Addr { return l.fd.laddr }
// SetDeadline sets the deadline associated with the listener.
// A zero time value disables the deadline.
func (l *TCPListener) SetDeadline(t time.Time) error {
if !l.ok() {
return syscall.EINVAL
}
if err := l.fd.pfd.SetDeadline(t); err != nil {
return &OpError{Op: "set", Net: l.fd.net, Source: nil, Addr: l.fd.laddr, Err: err}
}
return nil
}
// File returns a copy of the underlying os.File.
// It is the caller's responsibility to close f when finished.
// Closing l does not affect f, and closing f does not affect l.
//
// The returned os.File's file descriptor is different from the
// connection's. Attempting to change properties of the original
// using this duplicate may or may not have the desired effect.
func (l *TCPListener) File() (f *os.File, err error) {
if !l.ok() {
return nil, syscall.EINVAL
}
f, err = l.file()
if err != nil {
return nil, &OpError{Op: "file", Net: l.fd.net, Source: nil, Addr: l.fd.laddr, Err: err}
}
return
}
// ListenTCP acts like Listen for TCP networks.
//
// The network must be a TCP network name; see func Dial for details.
//
// If the IP field of laddr is nil or an unspecified IP address,
// ListenTCP listens on all available unicast and anycast IP addresses
// of the local system.
// If the Port field of laddr is 0, a port number is automatically
// chosen.
func ListenTCP(network string, laddr *TCPAddr) (*TCPListener, error) {
switch network {
case "tcp", "tcp4", "tcp6":
default:
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: UnknownNetworkError(network)}
}
if laddr == nil {
laddr = &TCPAddr{}
}
sl := &sysListener{network: network, address: laddr.String()}
ln, err := sl.listenTCP(context.Background(), laddr)
if err != nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: err}
}
return ln, nil
}
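// Editor's note: a minimal listen/accept sketch (not in the original
// source). Port 0 asks the kernel for a free port, per the comment above.
//
//	ln, err := ListenTCP("tcp", &TCPAddr{Port: 0})
//	if err != nil {
//		log.Fatal(err)
//	}
//	log.Print("listening on ", ln.Addr())
//	for {
//		c, err := ln.AcceptTCP()
//		if err != nil {
//			log.Print(err)
//			continue
//		}
//		go io.Copy(io.Discard, c) // serve each connection concurrently
//	}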
// roundDurationUp divides d by to, rounding up, and returns the number of
// to-sized units needed to cover d (e.g. 1500ms rounded up by 1s gives 2).
func roundDurationUp(d time.Duration, to time.Duration) time.Duration {
return (d + to - 1) / to
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || (js && wasm) || windows
package net
import (
"context"
"io"
"os"
"syscall"
)
func sockaddrToTCP(sa syscall.Sockaddr) Addr {
switch sa := sa.(type) {
case *syscall.SockaddrInet4:
return &TCPAddr{IP: sa.Addr[0:], Port: sa.Port}
case *syscall.SockaddrInet6:
return &TCPAddr{IP: sa.Addr[0:], Port: sa.Port, Zone: zoneCache.name(int(sa.ZoneId))}
}
return nil
}
func (a *TCPAddr) family() int {
if a == nil || len(a.IP) <= IPv4len {
return syscall.AF_INET
}
if a.IP.To4() != nil {
return syscall.AF_INET
}
return syscall.AF_INET6
}
func (a *TCPAddr) sockaddr(family int) (syscall.Sockaddr, error) {
if a == nil {
return nil, nil
}
return ipToSockaddr(family, a.IP, a.Port, a.Zone)
}
func (a *TCPAddr) toLocal(net string) sockaddr {
return &TCPAddr{loopbackIP(net), a.Port, a.Zone}
}
func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
if n, err, handled := splice(c.fd, r); handled {
return n, err
}
if n, err, handled := sendFile(c.fd, r); handled {
return n, err
}
return genericReadFrom(c, r)
}
func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
if h := sd.testHookDialTCP; h != nil {
return h(ctx, sd.network, laddr, raddr)
}
if h := testHookDialTCP; h != nil {
return h(ctx, sd.network, laddr, raddr)
}
return sd.doDialTCP(ctx, laddr, raddr)
}
func (sd *sysDialer) doDialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
ctrlCtxFn := sd.Dialer.ControlContext
if ctrlCtxFn == nil && sd.Dialer.Control != nil {
ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
return sd.Dialer.Control(network, address, c)
}
}
fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial", ctrlCtxFn)
// TCP has a rarely used mechanism called a 'simultaneous connection' in
// which Dial("tcp", addr1, addr2) run on the machine at addr1 can
// connect to a simultaneous Dial("tcp", addr2, addr1) run on the machine
// at addr2, without either machine executing Listen. If laddr == nil,
// it means we want the kernel to pick an appropriate originating local
// address. Some Linux kernels cycle blindly through a fixed range of
// local ports, regardless of destination port. If a kernel happens to
// pick local port 50001 as the source for a Dial("tcp", "", "localhost:50001"),
// then the Dial will succeed, having simultaneously connected to itself.
// This can only happen when we are letting the kernel pick a port (laddr == nil)
// and when there is no listener for the destination address.
// It's hard to argue this is anything other than a kernel bug. If we
// see this happen, rather than expose the buggy effect to users, we
// close the fd and try again. If it happens twice more, we relent and
// use the result. See also:
// https://golang.org/issue/2690
// https://stackoverflow.com/questions/4949858/
//
// The opposite can also happen: if we ask the kernel to pick an appropriate
// originating local address, sometimes it picks one that is already in use.
// So if the error is EADDRNOTAVAIL, we have to try again too, just for
// a different reason.
//
// The kernel socket code is no doubt enjoying watching us squirm.
for i := 0; i < 2 && (laddr == nil || laddr.Port == 0) && (selfConnect(fd, err) || spuriousENOTAVAIL(err)); i++ {
if err == nil {
fd.Close()
}
fd, err = internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial", ctrlCtxFn)
}
if err != nil {
return nil, err
}
return newTCPConn(fd, sd.Dialer.KeepAlive, testHookSetKeepAlive), nil
}
func selfConnect(fd *netFD, err error) bool {
// If the connect failed, we clearly didn't connect to ourselves.
if err != nil {
return false
}
// The socket constructor can return an fd with raddr nil under certain
// unknown conditions. The errors in the calls there to Getpeername
// are discarded, but we can't catch the problem there because those
// calls are sometimes legally erroneous with a "socket not connected".
// Since this code (selfConnect) is already trying to work around
// a problem, we make sure if this happens we recognize trouble and
// ask the DialTCP routine to try again.
// TODO: try to understand what's really going on.
if fd.laddr == nil || fd.raddr == nil {
return true
}
l := fd.laddr.(*TCPAddr)
r := fd.raddr.(*TCPAddr)
return l.Port == r.Port && l.IP.Equal(r.IP)
}
func spuriousENOTAVAIL(err error) bool {
if op, ok := err.(*OpError); ok {
err = op.Err
}
if sys, ok := err.(*os.SyscallError); ok {
err = sys.Err
}
return err == syscall.EADDRNOTAVAIL
}
func (ln *TCPListener) ok() bool { return ln != nil && ln.fd != nil }
func (ln *TCPListener) accept() (*TCPConn, error) {
fd, err := ln.fd.accept()
if err != nil {
return nil, err
}
return newTCPConn(fd, ln.lc.KeepAlive, nil), nil
}
func (ln *TCPListener) close() error {
return ln.fd.Close()
}
func (ln *TCPListener) file() (*os.File, error) {
f, err := ln.fd.dup()
if err != nil {
return nil, err
}
return f, nil
}
func (sl *sysListener) listenTCP(ctx context.Context, laddr *TCPAddr) (*TCPListener, error) {
var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
if sl.ListenConfig.Control != nil {
ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
return sl.ListenConfig.Control(network, address, c)
}
}
fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_STREAM, 0, "listen", ctrlCtxFn)
if err != nil {
return nil, err
}
return &TCPListener{fd: fd, lc: sl.ListenConfig}, nil
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || windows
package net
import (
"runtime"
"syscall"
)
func setNoDelay(fd *netFD, noDelay bool) error {
err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, syscall.TCP_NODELAY, boolint(noDelay))
runtime.KeepAlive(fd)
return wrapSyscallError("setsockopt", err)
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix || freebsd || linux || netbsd
package net
import (
"runtime"
"syscall"
"time"
)
func setKeepAlivePeriod(fd *netFD, d time.Duration) error {
// The kernel expects seconds, so round up to the next whole second.
secs := int(roundDurationUp(d, time.Second))
if err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, syscall.TCP_KEEPINTVL, secs); err != nil {
return wrapSyscallError("setsockopt", err)
}
err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, syscall.TCP_KEEPIDLE, secs)
runtime.KeepAlive(fd)
return wrapSyscallError("setsockopt", err)
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package textproto
// A MIMEHeader represents a MIME-style header mapping
// keys to sets of values.
type MIMEHeader map[string][]string
// Add adds the key, value pair to the header.
// It appends to any existing values associated with key.
func (h MIMEHeader) Add(key, value string) {
key = CanonicalMIMEHeaderKey(key)
h[key] = append(h[key], value)
}
// Set sets the header entries associated with key to
// the single element value. It replaces any existing
// values associated with key.
func (h MIMEHeader) Set(key, value string) {
h[CanonicalMIMEHeaderKey(key)] = []string{value}
}
// Get gets the first value associated with the given key.
// It is case insensitive; CanonicalMIMEHeaderKey is used
// to canonicalize the provided key.
// If there are no values associated with the key, Get returns "".
// To use non-canonical keys, access the map directly.
func (h MIMEHeader) Get(key string) string {
if h == nil {
return ""
}
v := h[CanonicalMIMEHeaderKey(key)]
if len(v) == 0 {
return ""
}
return v[0]
}
// Values returns all values associated with the given key.
// It is case insensitive; CanonicalMIMEHeaderKey is
// used to canonicalize the provided key. To use non-canonical
// keys, access the map directly.
// The returned slice is not a copy.
func (h MIMEHeader) Values(key string) []string {
if h == nil {
return nil
}
return h[CanonicalMIMEHeaderKey(key)]
}
// Del deletes the values associated with key.
func (h MIMEHeader) Del(key string) {
delete(h, CanonicalMIMEHeaderKey(key))
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package textproto
import (
"sync"
)
// A Pipeline manages a pipelined in-order request/response sequence.
//
// To use a Pipeline p to manage multiple clients on a connection,
// each client should run:
//
// id := p.Next() // take a number
//
// p.StartRequest(id) // wait for turn to send request
// «send request»
// p.EndRequest(id) // notify Pipeline that request is sent
//
// p.StartResponse(id) // wait for turn to read response
// «read response»
// p.EndResponse(id) // notify Pipeline that response is read
//
// A pipelined server can use the same calls to ensure that
// responses computed in parallel are written in the correct order.
type Pipeline struct {
mu sync.Mutex
id uint
request sequencer
response sequencer
}
// Next returns the next id for a request/response pair.
func (p *Pipeline) Next() uint {
p.mu.Lock()
id := p.id
p.id++
p.mu.Unlock()
return id
}
// StartRequest blocks until it is time to send (or, if this is a server, receive)
// the request with the given id.
func (p *Pipeline) StartRequest(id uint) {
p.request.Start(id)
}
// EndRequest notifies p that the request with the given id has been sent
// (or, if this is a server, received).
func (p *Pipeline) EndRequest(id uint) {
p.request.End(id)
}
// StartResponse blocks until it is time to receive (or, if this is a server, send)
// the response with the given id.
func (p *Pipeline) StartResponse(id uint) {
p.response.Start(id)
}
// EndResponse notifies p that the response with the given id has been received
// (or, if this is a server, sent).
func (p *Pipeline) EndResponse(id uint) {
p.response.End(id)
}
// A sequencer schedules a sequence of numbered events that must
// happen in order, one after the other. The event numbering must start
// at 0 and increment without skipping. The event number wraps around
// safely as long as there are fewer than 2^32 simultaneous events pending.
type sequencer struct {
mu sync.Mutex
id uint
wait map[uint]chan struct{}
}
// Start waits until it is time for the event numbered id to begin.
// That is, except for the first event, it waits until End(id-1) has
// been called.
func (s *sequencer) Start(id uint) {
s.mu.Lock()
if s.id == id {
s.mu.Unlock()
return
}
c := make(chan struct{})
if s.wait == nil {
s.wait = make(map[uint]chan struct{})
}
s.wait[id] = c
s.mu.Unlock()
<-c
}
// End notifies the sequencer that the event numbered id has completed,
// allowing it to schedule the event numbered id+1. It is a run-time error
// to call End with an id that is not the number of the active event.
func (s *sequencer) End(id uint) {
s.mu.Lock()
if s.id != id {
s.mu.Unlock()
panic("out of sync")
}
id++
s.id = id
if s.wait == nil {
s.wait = make(map[uint]chan struct{})
}
c, ok := s.wait[id]
if ok {
delete(s.wait, id)
}
s.mu.Unlock()
if ok {
close(c)
}
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package textproto
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"math"
"strconv"
"strings"
"sync"
)
// A Reader implements convenience methods for reading requests
// or responses from a text protocol network connection.
type Reader struct {
R *bufio.Reader
dot *dotReader
buf []byte // a re-usable buffer for readContinuedLineSlice
}
// NewReader returns a new Reader reading from r.
//
// To avoid denial of service attacks, the provided bufio.Reader
// should be reading from an io.LimitReader or similar Reader to bound
// the size of responses.
func NewReader(r *bufio.Reader) *Reader {
return &Reader{R: r}
}
// ReadLine reads a single line from r,
// eliding the final \n or \r\n from the returned string.
func (r *Reader) ReadLine() (string, error) {
line, err := r.readLineSlice()
return string(line), err
}
// ReadLineBytes is like ReadLine but returns a []byte instead of a string.
func (r *Reader) ReadLineBytes() ([]byte, error) {
line, err := r.readLineSlice()
if line != nil {
line = bytes.Clone(line)
}
return line, err
}
func (r *Reader) readLineSlice() ([]byte, error) {
r.closeDot()
var line []byte
for {
l, more, err := r.R.ReadLine()
if err != nil {
return nil, err
}
// Avoid the copy if the first call produced a full line.
if line == nil && !more {
return l, nil
}
line = append(line, l...)
if !more {
break
}
}
return line, nil
}
// ReadContinuedLine reads a possibly continued line from r,
// eliding the final trailing ASCII white space.
// Lines after the first are considered continuations if they
// begin with a space or tab character. In the returned data,
// continuation lines are separated from the previous line
// only by a single space: the newline and leading white space
// are removed.
//
// For example, consider this input:
//
// Line 1
// continued...
// Line 2
//
// The first call to ReadContinuedLine will return "Line 1 continued..."
// and the second will return "Line 2".
//
// Empty lines are never continued.
func (r *Reader) ReadContinuedLine() (string, error) {
line, err := r.readContinuedLineSlice(noValidation)
return string(line), err
}
// trim returns s with leading and trailing spaces and tabs removed.
// It does not assume Unicode or UTF-8.
func trim(s []byte) []byte {
i := 0
for i < len(s) && (s[i] == ' ' || s[i] == '\t') {
i++
}
n := len(s)
for n > i && (s[n-1] == ' ' || s[n-1] == '\t') {
n--
}
return s[i:n]
}
// ReadContinuedLineBytes is like ReadContinuedLine but
// returns a []byte instead of a string.
func (r *Reader) ReadContinuedLineBytes() ([]byte, error) {
line, err := r.readContinuedLineSlice(noValidation)
if line != nil {
line = bytes.Clone(line)
}
return line, err
}
// readContinuedLineSlice reads continued lines from the reader buffer,
// returning a byte slice with all lines. The validateFirstLine function
// is run on the first read line, and if it returns an error then this
// error is returned from readContinuedLineSlice.
func (r *Reader) readContinuedLineSlice(validateFirstLine func([]byte) error) ([]byte, error) {
if validateFirstLine == nil {
return nil, fmt.Errorf("missing validateFirstLine func")
}
// Read the first line.
line, err := r.readLineSlice()
if err != nil {
return nil, err
}
if len(line) == 0 { // blank line - no continuation
return line, nil
}
if err := validateFirstLine(line); err != nil {
return nil, err
}
// Optimistically assume that we have started to buffer the next line
// and it starts with an ASCII letter (the next header key), or a blank
// line, so we can avoid copying that buffered data around in memory
// and skipping over non-existent whitespace.
if r.R.Buffered() > 1 {
peek, _ := r.R.Peek(2)
if len(peek) > 0 && (isASCIILetter(peek[0]) || peek[0] == '\n') ||
len(peek) == 2 && peek[0] == '\r' && peek[1] == '\n' {
return trim(line), nil
}
}
// ReadByte or the next readLineSlice will flush the read buffer;
// copy the slice into buf.
r.buf = append(r.buf[:0], trim(line)...)
// Read continuation lines.
for r.skipSpace() > 0 {
line, err := r.readLineSlice()
if err != nil {
break
}
r.buf = append(r.buf, ' ')
r.buf = append(r.buf, trim(line)...)
}
return r.buf, nil
}
// skipSpace skips R over all spaces and returns the number of bytes skipped.
func (r *Reader) skipSpace() int {
n := 0
for {
c, err := r.R.ReadByte()
if err != nil {
// Bufio will keep err until next read.
break
}
if c != ' ' && c != '\t' {
r.R.UnreadByte()
break
}
n++
}
return n
}
func (r *Reader) readCodeLine(expectCode int) (code int, continued bool, message string, err error) {
line, err := r.ReadLine()
if err != nil {
return
}
return parseCodeLine(line, expectCode)
}
func parseCodeLine(line string, expectCode int) (code int, continued bool, message string, err error) {
if len(line) < 4 || line[3] != ' ' && line[3] != '-' {
err = ProtocolError("short response: " + line)
return
}
continued = line[3] == '-'
code, err = strconv.Atoi(line[0:3])
if err != nil || code < 100 {
err = ProtocolError("invalid response code: " + line)
return
}
message = line[4:]
if 1 <= expectCode && expectCode < 10 && code/100 != expectCode ||
10 <= expectCode && expectCode < 100 && code/10 != expectCode ||
100 <= expectCode && expectCode < 1000 && code != expectCode {
err = &Error{code, message}
}
return
}
// ReadCodeLine reads a response code line of the form
//
// code message
//
// where code is a three-digit status code and the message
// extends to the rest of the line. An example of such a line is:
//
// 220 plan9.bell-labs.com ESMTP
//
// If the prefix of the status does not match the digits in expectCode,
// ReadCodeLine returns with err set to &Error{code, message}.
// For example, if expectCode is 31, an error will be returned if
// the status is not in the range [310,319].
//
// If the response is multi-line, ReadCodeLine returns an error.
//
// An expectCode <= 0 disables the check of the status code.
func (r *Reader) ReadCodeLine(expectCode int) (code int, message string, err error) {
code, continued, message, err := r.readCodeLine(expectCode)
if err == nil && continued {
err = ProtocolError("unexpected multi-line response: " + message)
}
return
}
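// Editor's note: a sketch (not in the original source) of the expectCode
// matching described above, using parseCodeLine directly.
//
//	code, _, msg, err := parseCodeLine("220 plan9.bell-labs.com ESMTP", 2)
//	// code == 220, msg == "plan9.bell-labs.com ESMTP", err == nil:
//	// expectCode 2 accepts any 2xx status.
//	_, _, _, err = parseCodeLine("550 denied", 2)
//	// err == &Error{550, "denied"}: 550 is outside the 2xx range.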
// ReadResponse reads a multi-line response of the form:
//
// code-message line 1
// code-message line 2
// ...
// code message line n
//
// where code is a three-digit status code. The first line starts with the
// code and a hyphen. The response is terminated by a line that starts
// with the same code followed by a space. Each line in message is
// separated by a newline (\n).
//
// See page 36 of RFC 959 (https://www.ietf.org/rfc/rfc959.txt) for
// details of another form of response accepted:
//
// code-message line 1
// message line 2
// ...
// code message line n
//
// If the prefix of the status does not match the digits in expectCode,
// ReadResponse returns with err set to &Error{code, message}.
// For example, if expectCode is 31, an error will be returned if
// the status is not in the range [310,319].
//
// An expectCode <= 0 disables the check of the status code.
func (r *Reader) ReadResponse(expectCode int) (code int, message string, err error) {
code, continued, message, err := r.readCodeLine(expectCode)
multi := continued
for continued {
line, err := r.ReadLine()
if err != nil {
return 0, "", err
}
var code2 int
var moreMessage string
code2, continued, moreMessage, err = parseCodeLine(line, 0)
if err != nil || code2 != code {
message += "\n" + strings.TrimRight(line, "\r\n")
continued = true
continue
}
message += "\n" + moreMessage
}
if err != nil && multi && message != "" {
// Replace the one-line error message with the full multi-line message.
err = &Error{code, message}
}
return
}
// DotReader returns a new Reader that satisfies Reads using the
// decoded text of a dot-encoded block read from r.
// The returned Reader is only valid until the next call
// to a method on r.
//
// Dot encoding is a common framing used for data blocks
// in text protocols such as SMTP. The data consists of a sequence
// of lines, each of which ends in "\r\n". The sequence itself
// ends at a line containing just a dot: ".\r\n". Lines beginning
// with a dot are escaped with an additional dot to avoid
// looking like the end of the sequence.
//
// The decoded form returned by the Reader's Read method
// rewrites the "\r\n" line endings into the simpler "\n",
// removes leading dot escapes if present, and stops with error io.EOF
// after consuming (and discarding) the end-of-sequence line.
func (r *Reader) DotReader() io.Reader {
r.closeDot()
r.dot = &dotReader{r: r}
return r.dot
}
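// Editor's note: a decoding sketch (not in the original source). The input
// ends with the ".\r\n" terminator line, and the doubled dot on "..x" is
// unescaped to ".x".
//
//	r := NewReader(bufio.NewReader(strings.NewReader("a\r\n..x\r\n.\r\n")))
//	b, err := r.ReadDotBytes()
//	// b == []byte("a\n.x\n"), err == nil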
type dotReader struct {
r *Reader
state int
}
// Read satisfies reads by decoding dot-encoded data read from d.r.
func (d *dotReader) Read(b []byte) (n int, err error) {
// Run data through a simple state machine to
// elide leading dots, rewrite trailing \r\n into \n,
// and detect ending .\r\n line.
const (
stateBeginLine = iota // beginning of line; initial state; must be zero
stateDot // read . at beginning of line
stateDotCR // read .\r at beginning of line
stateCR // read \r (possibly at end of line)
stateData // reading data in middle of line
stateEOF // reached .\r\n end marker line
)
br := d.r.R
for n < len(b) && d.state != stateEOF {
var c byte
c, err = br.ReadByte()
if err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
break
}
switch d.state {
case stateBeginLine:
if c == '.' {
d.state = stateDot
continue
}
if c == '\r' {
d.state = stateCR
continue
}
d.state = stateData
case stateDot:
if c == '\r' {
d.state = stateDotCR
continue
}
if c == '\n' {
d.state = stateEOF
continue
}
d.state = stateData
case stateDotCR:
if c == '\n' {
d.state = stateEOF
continue
}
// Not part of .\r\n.
// Consume leading dot and emit saved \r.
br.UnreadByte()
c = '\r'
d.state = stateData
case stateCR:
if c == '\n' {
d.state = stateBeginLine
break
}
// Not part of \r\n. Emit saved \r
br.UnreadByte()
c = '\r'
d.state = stateData
case stateData:
if c == '\r' {
d.state = stateCR
continue
}
if c == '\n' {
d.state = stateBeginLine
}
}
b[n] = c
n++
}
if err == nil && d.state == stateEOF {
err = io.EOF
}
if err != nil && d.r.dot == d {
d.r.dot = nil
}
return
}
// closeDot drains the current DotReader if any,
// making sure that it reads until the ending dot line.
func (r *Reader) closeDot() {
if r.dot == nil {
return
}
buf := make([]byte, 128)
for r.dot != nil {
// When Read reaches EOF or an error,
// it will set r.dot == nil.
r.dot.Read(buf)
}
}
// ReadDotBytes reads a dot-encoding and returns the decoded data.
//
// See the documentation for the DotReader method for details about dot-encoding.
func (r *Reader) ReadDotBytes() ([]byte, error) {
return io.ReadAll(r.DotReader())
}
// ReadDotLines reads a dot-encoding and returns a slice
// containing the decoded lines, with the final \r\n or \n elided from each.
//
// See the documentation for the DotReader method for details about dot-encoding.
func (r *Reader) ReadDotLines() ([]string, error) {
// We could use ReadDotBytes and then Split it,
// but reading a line at a time avoids needing a
// large contiguous block of memory and is simpler.
var v []string
var err error
for {
var line string
line, err = r.ReadLine()
if err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
break
}
// Dot by itself marks end; otherwise cut one dot.
if len(line) > 0 && line[0] == '.' {
if len(line) == 1 {
break
}
line = line[1:]
}
v = append(v, line)
}
return v, err
}
var colon = []byte(":")
// ReadMIMEHeader reads a MIME-style header from r.
// The header is a sequence of possibly continued Key: Value lines
// ending in a blank line.
// The returned map m maps CanonicalMIMEHeaderKey(key) to a
// sequence of values in the same order encountered in the input.
//
// For example, consider this input:
//
// My-Key: Value 1
// Long-Key: Even
// Longer Value
// My-Key: Value 2
//
// Given that input, ReadMIMEHeader returns the map:
//
// map[string][]string{
// "My-Key": {"Value 1", "Value 2"},
// "Long-Key": {"Even Longer Value"},
// }
func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) {
return readMIMEHeader(r, math.MaxInt64)
}
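// Editor's note: a usage sketch (not in the original source), feeding the
// example from the comment above through a Reader.
//
//	in := "My-Key: Value 1\r\nLong-Key: Even\r\n  Longer Value\r\nMy-Key: Value 2\r\n\r\n"
//	r := NewReader(bufio.NewReader(strings.NewReader(in)))
//	m, err := r.ReadMIMEHeader()
//	// m["My-Key"] == []string{"Value 1", "Value 2"}
//	// m["Long-Key"] == []string{"Even Longer Value"}, err == nil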
// readMIMEHeader is a version of ReadMIMEHeader which takes a limit on the header size.
// It is called by the mime/multipart package.
func readMIMEHeader(r *Reader, lim int64) (MIMEHeader, error) {
// Avoid lots of small slice allocations later by allocating one
// large one ahead of time which we'll cut up into smaller
// slices. If this isn't big enough later, we allocate small ones.
var strs []string
hint := r.upcomingHeaderNewlines()
if hint > 0 {
strs = make([]string, hint)
}
m := make(MIMEHeader, hint)
// The first line cannot start with a leading space.
if buf, err := r.R.Peek(1); err == nil && (buf[0] == ' ' || buf[0] == '\t') {
line, err := r.readLineSlice()
if err != nil {
return m, err
}
return m, ProtocolError("malformed MIME header initial line: " + string(line))
}
for {
kv, err := r.readContinuedLineSlice(mustHaveFieldNameColon)
if len(kv) == 0 {
return m, err
}
// Key ends at first colon.
k, v, ok := bytes.Cut(kv, colon)
if !ok {
return m, ProtocolError("malformed MIME header line: " + string(kv))
}
key, ok := canonicalMIMEHeaderKey(k)
if !ok {
return m, ProtocolError("malformed MIME header line: " + string(kv))
}
for _, c := range v {
if !validHeaderValueByte(c) {
return m, ProtocolError("malformed MIME header line: " + string(kv))
}
}
// Per RFC 7230, a field-name is a token, and a token consists of one or more characters.
// We could return a ProtocolError here, but better to be liberal in what we
// accept, so if we get an empty key, skip it.
if key == "" {
continue
}
// Skip initial spaces in value.
value := string(bytes.TrimLeft(v, " \t"))
vv := m[key]
if vv == nil {
lim -= int64(len(key))
lim -= 100 // map entry overhead
}
lim -= int64(len(value))
if lim < 0 {
// TODO: This should be a distinguishable error (ErrMessageTooLarge)
// to allow mime/multipart to detect it.
return m, errors.New("message too large")
}
if vv == nil && len(strs) > 0 {
// More than likely this will be a single-element key.
// Most headers aren't multi-valued.
// Set the capacity on strs[0] to 1, so any future append
// won't extend the slice into the other strings.
vv, strs = strs[:1:1], strs[1:]
vv[0] = value
m[key] = vv
} else {
m[key] = append(vv, value)
}
if err != nil {
return m, err
}
}
}
// noValidation is a no-op validation func for readContinuedLineSlice
// that permits any lines.
func noValidation(_ []byte) error { return nil }
// mustHaveFieldNameColon ensures that, per RFC 7230, the
// field-name is on a single line, so the first line must
// contain a colon.
func mustHaveFieldNameColon(line []byte) error {
if bytes.IndexByte(line, ':') < 0 {
return ProtocolError(fmt.Sprintf("malformed MIME header: missing colon: %q", line))
}
return nil
}
var nl = []byte("\n")
// upcomingHeaderNewlines returns an approximation of the number of newlines
// that will be in this header. If it gets confused, it returns 0.
func (r *Reader) upcomingHeaderNewlines() (n int) {
// Try to determine the 'hint' size.
r.R.Peek(1) // force a buffer load if empty
s := r.R.Buffered()
if s == 0 {
return
}
peek, _ := r.R.Peek(s)
return bytes.Count(peek, nl)
}
// CanonicalMIMEHeaderKey returns the canonical format of the
// MIME header key s. The canonicalization converts the first
// letter and any letter following a hyphen to upper case;
// the rest are converted to lowercase. For example, the
// canonical key for "accept-encoding" is "Accept-Encoding".
// MIME header keys are assumed to be ASCII only.
// If s contains a space or invalid header field bytes, it is
// returned without modifications.
func CanonicalMIMEHeaderKey(s string) string {
// Quick check for canonical encoding.
upper := true
for i := 0; i < len(s); i++ {
c := s[i]
if !validHeaderFieldByte(c) {
return s
}
if upper && 'a' <= c && c <= 'z' {
s, _ = canonicalMIMEHeaderKey([]byte(s))
return s
}
if !upper && 'A' <= c && c <= 'Z' {
s, _ = canonicalMIMEHeaderKey([]byte(s))
return s
}
upper = c == '-'
}
return s
}
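// Editor's note: brief examples (not in the original source) of the rules
// above.
//
//	CanonicalMIMEHeaderKey("accept-encoding") // "Accept-Encoding"
//	CanonicalMIMEHeaderKey("USER-AGENT")      // "User-Agent"
//	CanonicalMIMEHeaderKey("bad key")         // "bad key" (space: unchanged)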
const toLower = 'a' - 'A'
// validHeaderFieldByte reports whether c is a valid byte in a header
// field name. RFC 7230 says:
//
// header-field = field-name ":" OWS field-value OWS
// field-name = token
// tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." /
// "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA
// token = 1*tchar
func validHeaderFieldByte(c byte) bool {
// mask is a 128-bit bitmap with 1s for allowed bytes,
// so that the byte c can be tested with a shift and an and.
// If c >= 128, then 1<<c and 1<<(c-64) will both be zero,
// and this function will return false.
const mask = 0 |
(1<<(10)-1)<<'0' |
(1<<(26)-1)<<'a' |
(1<<(26)-1)<<'A' |
1<<'!' |
1<<'#' |
1<<'$' |
1<<'%' |
1<<'&' |
1<<'\'' |
1<<'*' |
1<<'+' |
1<<'-' |
1<<'.' |
1<<'^' |
1<<'_' |
1<<'`' |
1<<'|' |
1<<'~'
return ((uint64(1)<<c)&(mask&(1<<64-1)) |
(uint64(1)<<(c-64))&(mask>>64)) != 0
}
// validHeaderValueByte reports whether c is a valid byte in a header
// field value. RFC 7230 says:
//
// field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ]
// field-vchar = VCHAR / obs-text
// obs-text = %x80-FF
//
// RFC 5234 says:
//
// HTAB = %x09
// SP = %x20
// VCHAR = %x21-7E
func validHeaderValueByte(c byte) bool {
// mask is a 128-bit bitmap with 1s for allowed bytes,
// so that the byte c can be tested with a shift and an and.
// If c >= 128, then 1<<c and 1<<(c-64) will both be zero.
// Since this is the obs-text range, we invert the mask to
// create a bitmap with 1s for disallowed bytes.
const mask = 0 |
(1<<(0x7f-0x21)-1)<<0x21 | // VCHAR: %x21-7E
1<<0x20 | // SP: %x20
1<<0x09 // HTAB: %x09
return ((uint64(1)<<c)&^(mask&(1<<64-1)) |
(uint64(1)<<(c-64))&^(mask>>64)) == 0
}
// canonicalMIMEHeaderKey is like CanonicalMIMEHeaderKey but is
// allowed to mutate the provided byte slice before returning the
// string.
//
// For invalid inputs (if a contains spaces or non-token bytes), a
// is unchanged and a string copy is returned.
//
// ok is true if the header key contains only valid characters and spaces.
// ReadMIMEHeader accepts header keys containing spaces, but does not
// canonicalize them.
func canonicalMIMEHeaderKey(a []byte) (_ string, ok bool) {
// See if a looks like a header key. If not, return it unchanged.
noCanon := false
for _, c := range a {
if validHeaderFieldByte(c) {
continue
}
// Don't canonicalize.
if c == ' ' {
// We accept invalid headers with a space before the
// colon, but must not canonicalize them.
// See https://go.dev/issue/34540.
noCanon = true
continue
}
return string(a), false
}
if noCanon {
return string(a), true
}
upper := true
for i, c := range a {
// Canonicalize: first letter upper case
// and upper case after each dash.
// (Host, User-Agent, If-Modified-Since).
// MIME headers are ASCII only, so no Unicode issues.
if upper && 'a' <= c && c <= 'z' {
c -= toLower
} else if !upper && 'A' <= c && c <= 'Z' {
c += toLower
}
a[i] = c
upper = c == '-' // for next time
}
commonHeaderOnce.Do(initCommonHeader)
// The compiler recognizes m[string(byteSlice)] as a special
// case, so a copy of a's bytes into a new string does not
// happen in this map lookup:
if v := commonHeader[string(a)]; v != "" {
return v, true
}
return string(a), true
}
// commonHeader interns common header strings.
var commonHeader map[string]string
var commonHeaderOnce sync.Once
func initCommonHeader() {
commonHeader = make(map[string]string)
for _, v := range []string{
"Accept",
"Accept-Charset",
"Accept-Encoding",
"Accept-Language",
"Accept-Ranges",
"Cache-Control",
"Cc",
"Connection",
"Content-Id",
"Content-Language",
"Content-Length",
"Content-Transfer-Encoding",
"Content-Type",
"Cookie",
"Date",
"Dkim-Signature",
"Etag",
"Expires",
"From",
"Host",
"If-Modified-Since",
"If-None-Match",
"In-Reply-To",
"Last-Modified",
"Location",
"Message-Id",
"Mime-Version",
"Pragma",
"Received",
"Return-Path",
"Server",
"Set-Cookie",
"Subject",
"To",
"User-Agent",
"Via",
"X-Forwarded-For",
"X-Imforwards",
"X-Powered-By",
} {
commonHeader[v] = v
}
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package textproto implements generic support for text-based request/response
// protocols in the style of HTTP, NNTP, and SMTP.
//
// The package provides:
//
// Error, which represents a numeric error response from
// a server.
//
// Pipeline, to manage pipelined requests and responses
// in a client.
//
// Reader, to read numeric response code lines,
// key: value headers, lines wrapped with leading spaces
// on continuation lines, and whole text blocks ending
// with a dot on a line by itself.
//
// Writer, to write dot-encoded text blocks.
//
// Conn, a convenient packaging of Reader, Writer, and Pipeline for use
// with a single network connection.
package textproto
import (
"bufio"
"fmt"
"io"
"net"
)
// An Error represents a numeric error response from a server.
type Error struct {
Code int
Msg string
}
func (e *Error) Error() string {
return fmt.Sprintf("%03d %s", e.Code, e.Msg)
}
// A ProtocolError describes a protocol violation such
// as an invalid response or a hung-up connection.
type ProtocolError string
func (p ProtocolError) Error() string {
return string(p)
}
// A Conn represents a textual network protocol connection.
// It consists of a Reader and Writer to manage I/O
// and a Pipeline to sequence concurrent requests on the connection.
// These embedded types carry methods with them;
// see the documentation of those types for details.
type Conn struct {
Reader
Writer
Pipeline
conn io.ReadWriteCloser
}
// NewConn returns a new Conn using conn for I/O.
func NewConn(conn io.ReadWriteCloser) *Conn {
return &Conn{
Reader: Reader{R: bufio.NewReader(conn)},
Writer: Writer{W: bufio.NewWriter(conn)},
conn: conn,
}
}
// Close closes the connection.
func (c *Conn) Close() error {
return c.conn.Close()
}
// Dial connects to the given address on the given network using net.Dial
// and then returns a new Conn for the connection.
func Dial(network, addr string) (*Conn, error) {
c, err := net.Dial(network, addr)
if err != nil {
return nil, err
}
return NewConn(c), nil
}
// Cmd is a convenience method that sends a command after
// waiting its turn in the pipeline. The command text is the
// result of formatting format with args and appending \r\n.
// Cmd returns the id of the command, for use with StartResponse and EndResponse.
//
// For example, a client might run a HELP command that returns a dot-body
// by using:
//
// id, err := c.Cmd("HELP")
// if err != nil {
// return nil, err
// }
//
// c.StartResponse(id)
// defer c.EndResponse(id)
//
// if _, _, err = c.ReadCodeLine(110); err != nil {
// return nil, err
// }
// text, err := c.ReadDotBytes()
// if err != nil {
// return nil, err
// }
// return c.ReadCodeLine(250)
func (c *Conn) Cmd(format string, args ...any) (id uint, err error) {
id = c.Next()
c.StartRequest(id)
err = c.PrintfLine(format, args...)
c.EndRequest(id)
if err != nil {
return 0, err
}
return id, nil
}
// TrimString returns s without leading and trailing ASCII space.
func TrimString(s string) string {
for len(s) > 0 && isASCIISpace(s[0]) {
s = s[1:]
}
for len(s) > 0 && isASCIISpace(s[len(s)-1]) {
s = s[:len(s)-1]
}
return s
}
// TrimBytes returns b without leading and trailing ASCII space.
func TrimBytes(b []byte) []byte {
for len(b) > 0 && isASCIISpace(b[0]) {
b = b[1:]
}
for len(b) > 0 && isASCIISpace(b[len(b)-1]) {
b = b[:len(b)-1]
}
return b
}
func isASCIISpace(b byte) bool {
return b == ' ' || b == '\t' || b == '\n' || b == '\r'
}
func isASCIILetter(b byte) bool {
b |= 0x20 // make lower case
return 'a' <= b && b <= 'z'
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package textproto
import (
"bufio"
"fmt"
"io"
)
// A Writer implements convenience methods for writing
// requests or responses to a text protocol network connection.
type Writer struct {
W *bufio.Writer
dot *dotWriter
}
// NewWriter returns a new Writer writing to w.
func NewWriter(w *bufio.Writer) *Writer {
return &Writer{W: w}
}
var crnl = []byte{'\r', '\n'}
var dotcrnl = []byte{'.', '\r', '\n'}
// PrintfLine writes the formatted output followed by \r\n.
func (w *Writer) PrintfLine(format string, args ...any) error {
w.closeDot()
fmt.Fprintf(w.W, format, args...)
w.W.Write(crnl)
return w.W.Flush()
}
// DotWriter returns a writer that can be used to write a dot-encoding to w.
// It takes care of inserting leading dots when necessary,
// translating line-ending \n into \r\n, and adding the final .\r\n line
// when the DotWriter is closed. The caller should close the
// DotWriter before the next call to a method on w.
//
// See the documentation for Reader's DotReader method for details about dot-encoding.
func (w *Writer) DotWriter() io.WriteCloser {
w.closeDot()
w.dot = &dotWriter{w: w}
return w.dot
}
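// Editor's note: an encoding sketch (not in the original source), the
// inverse of the DotReader example: "\n" becomes "\r\n", a leading dot is
// doubled, and Close appends the ".\r\n" terminator.
//
//	var buf bytes.Buffer
//	w := NewWriter(bufio.NewWriter(&buf))
//	d := w.DotWriter()
//	io.WriteString(d, "a\n.x\n")
//	d.Close()
//	// buf.String() == "a\r\n..x\r\n.\r\n"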
func (w *Writer) closeDot() {
if w.dot != nil {
w.dot.Close() // sets w.dot = nil
}
}
type dotWriter struct {
w *Writer
state int
}
const (
wstateBegin = iota // initial state; must be zero
wstateBeginLine // beginning of line
wstateCR // wrote \r (possibly at end of line)
wstateData // writing data in middle of line
)
func (d *dotWriter) Write(b []byte) (n int, err error) {
bw := d.w.W
for n < len(b) {
c := b[n]
switch d.state {
case wstateBegin, wstateBeginLine:
d.state = wstateData
if c == '.' {
// escape leading dot
bw.WriteByte('.')
}
fallthrough
case wstateData:
if c == '\r' {
d.state = wstateCR
}
if c == '\n' {
bw.WriteByte('\r')
d.state = wstateBeginLine
}
case wstateCR:
d.state = wstateData
if c == '\n' {
d.state = wstateBeginLine
}
}
if err = bw.WriteByte(c); err != nil {
break
}
n++
}
return
}
func (d *dotWriter) Close() error {
if d.w.dot == d {
d.w.dot = nil
}
bw := d.w.W
switch d.state {
default:
bw.WriteByte('\r')
fallthrough
case wstateCR:
bw.WriteByte('\n')
fallthrough
case wstateBeginLine:
bw.Write(dotcrnl)
}
return bw.Flush()
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
import (
"context"
"internal/itoa"
"net/netip"
"syscall"
)
// BUG(mikio): On Plan 9, the ReadMsgUDP and
// WriteMsgUDP methods of UDPConn are not implemented.
// BUG(mikio): On Windows, the File method of UDPConn is not
// implemented.
// BUG(mikio): On JS, methods and functions related to UDPConn are not
// implemented.
// UDPAddr represents the address of a UDP end point.
type UDPAddr struct {
IP IP
Port int
Zone string // IPv6 scoped addressing zone
}
// AddrPort returns the UDPAddr a as a netip.AddrPort.
//
// If a.Port does not fit in a uint16, it's silently truncated.
//
// If a is nil, a zero value is returned.
func (a *UDPAddr) AddrPort() netip.AddrPort {
if a == nil {
return netip.AddrPort{}
}
na, _ := netip.AddrFromSlice(a.IP)
na = na.WithZone(a.Zone)
return netip.AddrPortFrom(na, uint16(a.Port))
}
// Network returns the address's network name, "udp".
func (a *UDPAddr) Network() string { return "udp" }
func (a *UDPAddr) String() string {
if a == nil {
return "<nil>"
}
ip := ipEmptyString(a.IP)
if a.Zone != "" {
return JoinHostPort(ip+"%"+a.Zone, itoa.Itoa(a.Port))
}
return JoinHostPort(ip, itoa.Itoa(a.Port))
}
func (a *UDPAddr) isWildcard() bool {
if a == nil || a.IP == nil {
return true
}
return a.IP.IsUnspecified()
}
func (a *UDPAddr) opAddr() Addr {
if a == nil {
return nil
}
return a
}
// ResolveUDPAddr returns an address of a UDP end point.
//
// The network must be a UDP network name.
//
// If the host in the address parameter is not a literal IP address or
// the port is not a literal port number, ResolveUDPAddr resolves the
// address to an address of a UDP end point.
// Otherwise, it parses the address as a pair of literal IP address
// and port number.
// The address parameter can use a host name, but this is not
// recommended, because it will return at most one of the host name's
// IP addresses.
//
// See func Dial for a description of the network and address
// parameters.
func ResolveUDPAddr(network, address string) (*UDPAddr, error) {
switch network {
case "udp", "udp4", "udp6":
case "": // a hint wildcard for Go 1.0 undocumented behavior
network = "udp"
default:
return nil, UnknownNetworkError(network)
}
addrs, err := DefaultResolver.internetAddrList(context.Background(), network, address)
if err != nil {
return nil, err
}
return addrs.forResolve(network, address).(*UDPAddr), nil
}
// UDPAddrFromAddrPort returns addr as a UDPAddr. If addr.IsValid() is false,
// then the returned UDPAddr will contain a nil IP field, indicating an
// address family-agnostic unspecified address.
func UDPAddrFromAddrPort(addr netip.AddrPort) *UDPAddr {
return &UDPAddr{
IP: addr.Addr().AsSlice(),
Zone: addr.Addr().Zone(),
Port: int(addr.Port()),
}
}
// An addrPortUDPAddr is a netip.AddrPort-based UDP address that satisfies the Addr interface.
type addrPortUDPAddr struct {
netip.AddrPort
}
func (addrPortUDPAddr) Network() string { return "udp" }
// UDPConn is the implementation of the Conn and PacketConn interfaces
// for UDP network connections.
type UDPConn struct {
conn
}
// SyscallConn returns a raw network connection.
// This implements the syscall.Conn interface.
func (c *UDPConn) SyscallConn() (syscall.RawConn, error) {
if !c.ok() {
return nil, syscall.EINVAL
}
return newRawConn(c.fd)
}
// ReadFromUDP acts like ReadFrom but returns a UDPAddr.
func (c *UDPConn) ReadFromUDP(b []byte) (n int, addr *UDPAddr, err error) {
// This function is designed to allow the caller to control the lifetime
// of the returned *UDPAddr and thereby prevent an allocation.
// See https://blog.filippo.io/efficient-go-apis-with-the-inliner/.
// The real work is done by readFromUDP, below.
return c.readFromUDP(b, &UDPAddr{})
}
// readFromUDP implements ReadFromUDP.
func (c *UDPConn) readFromUDP(b []byte, addr *UDPAddr) (int, *UDPAddr, error) {
if !c.ok() {
return 0, nil, syscall.EINVAL
}
n, addr, err := c.readFrom(b, addr)
if err != nil {
err = &OpError{Op: "read", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
}
return n, addr, err
}
// ReadFrom implements the PacketConn ReadFrom method.
func (c *UDPConn) ReadFrom(b []byte) (int, Addr, error) {
n, addr, err := c.readFromUDP(b, &UDPAddr{})
if addr == nil {
// Return Addr(nil), not Addr(*UDPAddr(nil)).
return n, nil, err
}
return n, addr, err
}
// ReadFromUDPAddrPort acts like ReadFrom but returns a netip.AddrPort.
//
// If c is bound to an unspecified address, the returned
// netip.AddrPort's address might be an IPv4-mapped IPv6 address.
// Use netip.Addr.Unmap to get the address without the IPv6 prefix.
func (c *UDPConn) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) {
if !c.ok() {
return 0, netip.AddrPort{}, syscall.EINVAL
}
n, addr, err = c.readFromAddrPort(b)
if err != nil {
err = &OpError{Op: "read", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
}
return n, addr, err
}
// ReadMsgUDP reads a message from c, copying the payload into b and
// the associated out-of-band data into oob. It returns the number of
// bytes copied into b, the number of bytes copied into oob, the flags
// that were set on the message and the source address of the message.
//
// The packages golang.org/x/net/ipv4 and golang.org/x/net/ipv6 can be
// used to manipulate IP-level socket options in oob.
func (c *UDPConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *UDPAddr, err error) {
var ap netip.AddrPort
n, oobn, flags, ap, err = c.ReadMsgUDPAddrPort(b, oob)
if ap.IsValid() {
addr = UDPAddrFromAddrPort(ap)
}
return
}
// ReadMsgUDPAddrPort is like ReadMsgUDP but returns a netip.AddrPort instead of a UDPAddr.
func (c *UDPConn) ReadMsgUDPAddrPort(b, oob []byte) (n, oobn, flags int, addr netip.AddrPort, err error) {
if !c.ok() {
return 0, 0, 0, netip.AddrPort{}, syscall.EINVAL
}
n, oobn, flags, addr, err = c.readMsg(b, oob)
if err != nil {
err = &OpError{Op: "read", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
}
return
}
// WriteToUDP acts like WriteTo but takes a UDPAddr.
func (c *UDPConn) WriteToUDP(b []byte, addr *UDPAddr) (int, error) {
if !c.ok() {
return 0, syscall.EINVAL
}
n, err := c.writeTo(b, addr)
if err != nil {
err = &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: addr.opAddr(), Err: err}
}
return n, err
}
// WriteToUDPAddrPort acts like WriteTo but takes a netip.AddrPort.
func (c *UDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
if !c.ok() {
return 0, syscall.EINVAL
}
n, err := c.writeToAddrPort(b, addr)
if err != nil {
err = &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: addrPortUDPAddr{addr}, Err: err}
}
return n, err
}
// WriteTo implements the PacketConn WriteTo method.
func (c *UDPConn) WriteTo(b []byte, addr Addr) (int, error) {
if !c.ok() {
return 0, syscall.EINVAL
}
a, ok := addr.(*UDPAddr)
if !ok {
return 0, &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: addr, Err: syscall.EINVAL}
}
n, err := c.writeTo(b, a)
if err != nil {
err = &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: a.opAddr(), Err: err}
}
return n, err
}
// WriteMsgUDP writes a message to addr via c if c isn't connected, or
// to c's remote address if c is connected (in which case addr must be
// nil). The payload is copied from b and the associated out-of-band
// data is copied from oob. It returns the number of payload and
// out-of-band bytes written.
//
// The packages golang.org/x/net/ipv4 and golang.org/x/net/ipv6 can be
// used to manipulate IP-level socket options in oob.
func (c *UDPConn) WriteMsgUDP(b, oob []byte, addr *UDPAddr) (n, oobn int, err error) {
if !c.ok() {
return 0, 0, syscall.EINVAL
}
n, oobn, err = c.writeMsg(b, oob, addr)
if err != nil {
err = &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: addr.opAddr(), Err: err}
}
return
}
// WriteMsgUDPAddrPort is like WriteMsgUDP but takes a netip.AddrPort instead of a UDPAddr.
func (c *UDPConn) WriteMsgUDPAddrPort(b, oob []byte, addr netip.AddrPort) (n, oobn int, err error) {
if !c.ok() {
return 0, 0, syscall.EINVAL
}
n, oobn, err = c.writeMsgAddrPort(b, oob, addr)
if err != nil {
err = &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: addrPortUDPAddr{addr}, Err: err}
}
return
}
func newUDPConn(fd *netFD) *UDPConn { return &UDPConn{conn{fd}} }
// DialUDP acts like Dial for UDP networks.
//
// The network must be a UDP network name; see func Dial for details.
//
// If laddr is nil, a local address is automatically chosen.
// If the IP field of raddr is nil or an unspecified IP address, the
// local system is assumed.
func DialUDP(network string, laddr, raddr *UDPAddr) (*UDPConn, error) {
switch network {
case "udp", "udp4", "udp6":
default:
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: UnknownNetworkError(network)}
}
if raddr == nil {
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
}
sd := &sysDialer{network: network, address: raddr.String()}
c, err := sd.dialUDP(context.Background(), laddr, raddr)
if err != nil {
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
}
return c, nil
}
// ListenUDP acts like ListenPacket for UDP networks.
//
// The network must be a UDP network name; see func Dial for details.
//
// If the IP field of laddr is nil or an unspecified IP address,
// ListenUDP listens on all available IP addresses of the local system
// except multicast IP addresses.
// If the Port field of laddr is 0, a port number is automatically
// chosen.
func ListenUDP(network string, laddr *UDPAddr) (*UDPConn, error) {
switch network {
case "udp", "udp4", "udp6":
default:
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: UnknownNetworkError(network)}
}
if laddr == nil {
laddr = &UDPAddr{}
}
sl := &sysListener{network: network, address: laddr.String()}
c, err := sl.listenUDP(context.Background(), laddr)
if err != nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: err}
}
return c, nil
}
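// Editor's note: a minimal packet-echo sketch (not in the original
// source). A zero Port asks the kernel for a free port.
//
//	c, err := ListenUDP("udp", &UDPAddr{Port: 0})
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer c.Close()
//	buf := make([]byte, 1500)
//	n, addr, err := c.ReadFromUDP(buf)
//	if err == nil {
//		c.WriteToUDP(buf[:n], addr) // echo the datagram back
//	}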
// ListenMulticastUDP acts like ListenPacket for UDP networks but
// takes a group address on a specific network interface.
//
// The network must be a UDP network name; see func Dial for details.
//
// ListenMulticastUDP listens on all available IP addresses of the
// local system including the group (multicast) IP address.
// If ifi is nil, ListenMulticastUDP uses the system-assigned
// multicast interface, although this is not recommended because the
// assignment depends on platforms and sometimes it might require
// routing configuration.
// If the Port field of gaddr is 0, a port number is automatically
// chosen.
//
// ListenMulticastUDP is intended only as a convenience for simple, small
// applications. The golang.org/x/net/ipv4 and golang.org/x/net/ipv6
// packages are available for general-purpose uses.
//
// Note that ListenMulticastUDP will set the IP_MULTICAST_LOOP socket option
// to 0 under IPPROTO_IP, to disable loopback of multicast packets.
func ListenMulticastUDP(network string, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) {
switch network {
case "udp", "udp4", "udp6":
default:
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: gaddr.opAddr(), Err: UnknownNetworkError(network)}
}
if gaddr == nil || gaddr.IP == nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: gaddr.opAddr(), Err: errMissingAddress}
}
sl := &sysListener{network: network, address: gaddr.String()}
c, err := sl.listenMulticastUDP(context.Background(), ifi, gaddr)
if err != nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: gaddr.opAddr(), Err: err}
}
return c, nil
}
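// Editor's note: an illustrative sketch (not part of the original source) of
// joining an IPv4 multicast group. The mDNS group 224.0.0.251:5353 is an
// arbitrary assumption; as noted above, general-purpose applications should
// prefer golang.org/x/net/ipv4 and golang.org/x/net/ipv6.
//
//	gaddr := &net.UDPAddr{IP: net.IPv4(224, 0, 0, 251), Port: 5353}
//	c, err := net.ListenMulticastUDP("udp4", nil, gaddr) // nil ifi: system-assigned interface
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer c.Close()
//	buf := make([]byte, 1500)
//	n, src, err := c.ReadFromUDP(buf)
//	fmt.Println(n, src, err)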
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || (js && wasm) || windows
package net
import (
"context"
"net/netip"
"syscall"
)
func sockaddrToUDP(sa syscall.Sockaddr) Addr {
switch sa := sa.(type) {
case *syscall.SockaddrInet4:
return &UDPAddr{IP: sa.Addr[0:], Port: sa.Port}
case *syscall.SockaddrInet6:
return &UDPAddr{IP: sa.Addr[0:], Port: sa.Port, Zone: zoneCache.name(int(sa.ZoneId))}
}
return nil
}
func (a *UDPAddr) family() int {
if a == nil || len(a.IP) <= IPv4len {
return syscall.AF_INET
}
if a.IP.To4() != nil {
return syscall.AF_INET
}
return syscall.AF_INET6
}
func (a *UDPAddr) sockaddr(family int) (syscall.Sockaddr, error) {
if a == nil {
return nil, nil
}
return ipToSockaddr(family, a.IP, a.Port, a.Zone)
}
func (a *UDPAddr) toLocal(net string) sockaddr {
return &UDPAddr{loopbackIP(net), a.Port, a.Zone}
}
func (c *UDPConn) readFrom(b []byte, addr *UDPAddr) (int, *UDPAddr, error) {
var n int
var err error
switch c.fd.family {
case syscall.AF_INET:
var from syscall.SockaddrInet4
n, err = c.fd.readFromInet4(b, &from)
if err == nil {
ip := from.Addr // copy from.Addr; ip escapes, so this line allocates 4 bytes
*addr = UDPAddr{IP: ip[:], Port: from.Port}
}
case syscall.AF_INET6:
var from syscall.SockaddrInet6
n, err = c.fd.readFromInet6(b, &from)
if err == nil {
ip := from.Addr // copy from.Addr; ip escapes, so this line allocates 16 bytes
*addr = UDPAddr{IP: ip[:], Port: from.Port, Zone: zoneCache.name(int(from.ZoneId))}
}
}
if err != nil {
// No sockaddr, so don't return UDPAddr.
addr = nil
}
return n, addr, err
}
func (c *UDPConn) readFromAddrPort(b []byte) (n int, addr netip.AddrPort, err error) {
var ip netip.Addr
var port int
switch c.fd.family {
case syscall.AF_INET:
var from syscall.SockaddrInet4
n, err = c.fd.readFromInet4(b, &from)
if err == nil {
ip = netip.AddrFrom4(from.Addr)
port = from.Port
}
case syscall.AF_INET6:
var from syscall.SockaddrInet6
n, err = c.fd.readFromInet6(b, &from)
if err == nil {
ip = netip.AddrFrom16(from.Addr).WithZone(zoneCache.name(int(from.ZoneId)))
port = from.Port
}
}
if err == nil {
addr = netip.AddrPortFrom(ip, uint16(port))
}
return n, addr, err
}
func (c *UDPConn) readMsg(b, oob []byte) (n, oobn, flags int, addr netip.AddrPort, err error) {
switch c.fd.family {
case syscall.AF_INET:
var sa syscall.SockaddrInet4
n, oobn, flags, err = c.fd.readMsgInet4(b, oob, 0, &sa)
ip := netip.AddrFrom4(sa.Addr)
addr = netip.AddrPortFrom(ip, uint16(sa.Port))
case syscall.AF_INET6:
var sa syscall.SockaddrInet6
n, oobn, flags, err = c.fd.readMsgInet6(b, oob, 0, &sa)
ip := netip.AddrFrom16(sa.Addr).WithZone(zoneCache.name(int(sa.ZoneId)))
addr = netip.AddrPortFrom(ip, uint16(sa.Port))
}
return
}
func (c *UDPConn) writeTo(b []byte, addr *UDPAddr) (int, error) {
if c.fd.isConnected {
return 0, ErrWriteToConnected
}
if addr == nil {
return 0, errMissingAddress
}
switch c.fd.family {
case syscall.AF_INET:
sa, err := ipToSockaddrInet4(addr.IP, addr.Port)
if err != nil {
return 0, err
}
return c.fd.writeToInet4(b, &sa)
case syscall.AF_INET6:
sa, err := ipToSockaddrInet6(addr.IP, addr.Port, addr.Zone)
if err != nil {
return 0, err
}
return c.fd.writeToInet6(b, &sa)
default:
return 0, &AddrError{Err: "invalid address family", Addr: addr.IP.String()}
}
}
func (c *UDPConn) writeToAddrPort(b []byte, addr netip.AddrPort) (int, error) {
if c.fd.isConnected {
return 0, ErrWriteToConnected
}
if !addr.IsValid() {
return 0, errMissingAddress
}
switch c.fd.family {
case syscall.AF_INET:
sa, err := addrPortToSockaddrInet4(addr)
if err != nil {
return 0, err
}
return c.fd.writeToInet4(b, &sa)
case syscall.AF_INET6:
sa, err := addrPortToSockaddrInet6(addr)
if err != nil {
return 0, err
}
return c.fd.writeToInet6(b, &sa)
default:
return 0, &AddrError{Err: "invalid address family", Addr: addr.Addr().String()}
}
}
func (c *UDPConn) writeMsg(b, oob []byte, addr *UDPAddr) (n, oobn int, err error) {
if c.fd.isConnected && addr != nil {
return 0, 0, ErrWriteToConnected
}
if !c.fd.isConnected && addr == nil {
return 0, 0, errMissingAddress
}
sa, err := addr.sockaddr(c.fd.family)
if err != nil {
return 0, 0, err
}
return c.fd.writeMsg(b, oob, sa)
}
func (c *UDPConn) writeMsgAddrPort(b, oob []byte, addr netip.AddrPort) (n, oobn int, err error) {
if c.fd.isConnected && addr.IsValid() {
return 0, 0, ErrWriteToConnected
}
if !c.fd.isConnected && !addr.IsValid() {
return 0, 0, errMissingAddress
}
switch c.fd.family {
case syscall.AF_INET:
sa, err := addrPortToSockaddrInet4(addr)
if err != nil {
return 0, 0, err
}
return c.fd.writeMsgInet4(b, oob, &sa)
case syscall.AF_INET6:
sa, err := addrPortToSockaddrInet6(addr)
if err != nil {
return 0, 0, err
}
return c.fd.writeMsgInet6(b, oob, &sa)
default:
return 0, 0, &AddrError{Err: "invalid address family", Addr: addr.Addr().String()}
}
}
func (sd *sysDialer) dialUDP(ctx context.Context, laddr, raddr *UDPAddr) (*UDPConn, error) {
ctrlCtxFn := sd.Dialer.ControlContext
if ctrlCtxFn == nil && sd.Dialer.Control != nil {
ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
return sd.Dialer.Control(network, address, c)
}
}
fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_DGRAM, 0, "dial", ctrlCtxFn)
if err != nil {
return nil, err
}
return newUDPConn(fd), nil
}
func (sl *sysListener) listenUDP(ctx context.Context, laddr *UDPAddr) (*UDPConn, error) {
var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
if sl.ListenConfig.Control != nil {
ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
return sl.ListenConfig.Control(network, address, c)
}
}
fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_DGRAM, 0, "listen", ctrlCtxFn)
if err != nil {
return nil, err
}
return newUDPConn(fd), nil
}
func (sl *sysListener) listenMulticastUDP(ctx context.Context, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) {
var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
if sl.ListenConfig.Control != nil {
ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
return sl.ListenConfig.Control(network, address, c)
}
}
fd, err := internetSocket(ctx, sl.network, gaddr, nil, syscall.SOCK_DGRAM, 0, "listen", ctrlCtxFn)
if err != nil {
return nil, err
}
c := newUDPConn(fd)
if ip4 := gaddr.IP.To4(); ip4 != nil {
if err := listenIPv4MulticastUDP(c, ifi, ip4); err != nil {
c.Close()
return nil, err
}
} else {
if err := listenIPv6MulticastUDP(c, ifi, gaddr.IP); err != nil {
c.Close()
return nil, err
}
}
return c, nil
}
func listenIPv4MulticastUDP(c *UDPConn, ifi *Interface, ip IP) error {
if ifi != nil {
if err := setIPv4MulticastInterface(c.fd, ifi); err != nil {
return err
}
}
if err := setIPv4MulticastLoopback(c.fd, false); err != nil {
return err
}
if err := joinIPv4Group(c.fd, ifi, ip); err != nil {
return err
}
return nil
}
func listenIPv6MulticastUDP(c *UDPConn, ifi *Interface, ip IP) error {
if ifi != nil {
if err := setIPv6MulticastInterface(c.fd, ifi); err != nil {
return err
}
}
if err := setIPv6MulticastLoopback(c.fd, false); err != nil {
return err
}
if err := joinIPv6Group(c.fd, ifi, ip); err != nil {
return err
}
return nil
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package net
import (
"context"
"os"
"sync"
"syscall"
"time"
)
// BUG(mikio): On JS and Plan 9, methods and functions related
// to UnixConn and UnixListener are not implemented.
// BUG(mikio): On Windows, methods and functions related to UnixConn
// and UnixListener don't work for "unixgram" and "unixpacket".
// UnixAddr represents the address of a Unix domain socket end point.
type UnixAddr struct {
Name string
Net string
}
// Network returns the address's network name, "unix", "unixgram" or
// "unixpacket".
func (a *UnixAddr) Network() string {
return a.Net
}
func (a *UnixAddr) String() string {
if a == nil {
return "<nil>"
}
return a.Name
}
func (a *UnixAddr) isWildcard() bool {
return a == nil || a.Name == ""
}
func (a *UnixAddr) opAddr() Addr {
if a == nil {
return nil
}
return a
}
// ResolveUnixAddr returns an address of Unix domain socket end point.
//
// The network must be a Unix network name.
//
// See func Dial for a description of the network and address
// parameters.
func ResolveUnixAddr(network, address string) (*UnixAddr, error) {
switch network {
case "unix", "unixgram", "unixpacket":
return &UnixAddr{Name: address, Net: network}, nil
default:
return nil, UnknownNetworkError(network)
}
}
// UnixConn is an implementation of the Conn interface for connections
// to Unix domain sockets.
type UnixConn struct {
conn
}
// SyscallConn returns a raw network connection.
// This implements the syscall.Conn interface.
func (c *UnixConn) SyscallConn() (syscall.RawConn, error) {
if !c.ok() {
return nil, syscall.EINVAL
}
return newRawConn(c.fd)
}
// CloseRead shuts down the reading side of the Unix domain connection.
// Most callers should just use Close.
func (c *UnixConn) CloseRead() error {
if !c.ok() {
return syscall.EINVAL
}
if err := c.fd.closeRead(); err != nil {
return &OpError{Op: "close", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
}
return nil
}
// CloseWrite shuts down the writing side of the Unix domain connection.
// Most callers should just use Close.
func (c *UnixConn) CloseWrite() error {
if !c.ok() {
return syscall.EINVAL
}
if err := c.fd.closeWrite(); err != nil {
return &OpError{Op: "close", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
}
return nil
}
// ReadFromUnix acts like ReadFrom but returns a UnixAddr.
func (c *UnixConn) ReadFromUnix(b []byte) (int, *UnixAddr, error) {
if !c.ok() {
return 0, nil, syscall.EINVAL
}
n, addr, err := c.readFrom(b)
if err != nil {
err = &OpError{Op: "read", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
}
return n, addr, err
}
// ReadFrom implements the PacketConn ReadFrom method.
func (c *UnixConn) ReadFrom(b []byte) (int, Addr, error) {
if !c.ok() {
return 0, nil, syscall.EINVAL
}
n, addr, err := c.readFrom(b)
if err != nil {
err = &OpError{Op: "read", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
}
if addr == nil {
return n, nil, err
}
return n, addr, err
}
// ReadMsgUnix reads a message from c, copying the payload into b and
// the associated out-of-band data into oob. It returns the number of
// bytes copied into b, the number of bytes copied into oob, the flags
// that were set on the message and the source address of the message.
//
// Note that if len(b) == 0 and len(oob) > 0, this function will still
// read (and discard) 1 byte from the connection.
func (c *UnixConn) ReadMsgUnix(b, oob []byte) (n, oobn, flags int, addr *UnixAddr, err error) {
if !c.ok() {
return 0, 0, 0, nil, syscall.EINVAL
}
n, oobn, flags, addr, err = c.readMsg(b, oob)
if err != nil {
err = &OpError{Op: "read", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
}
return
}
// WriteToUnix acts like WriteTo but takes a UnixAddr.
func (c *UnixConn) WriteToUnix(b []byte, addr *UnixAddr) (int, error) {
if !c.ok() {
return 0, syscall.EINVAL
}
n, err := c.writeTo(b, addr)
if err != nil {
err = &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: addr.opAddr(), Err: err}
}
return n, err
}
// WriteTo implements the PacketConn WriteTo method.
func (c *UnixConn) WriteTo(b []byte, addr Addr) (int, error) {
if !c.ok() {
return 0, syscall.EINVAL
}
a, ok := addr.(*UnixAddr)
if !ok {
return 0, &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: addr, Err: syscall.EINVAL}
}
n, err := c.writeTo(b, a)
if err != nil {
err = &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: a.opAddr(), Err: err}
}
return n, err
}
// WriteMsgUnix writes a message to addr via c, copying the payload
// from b and the associated out-of-band data from oob. It returns the
// number of payload and out-of-band bytes written.
//
// Note that if len(b) == 0 and len(oob) > 0, this function will still
// write 1 byte to the connection.
func (c *UnixConn) WriteMsgUnix(b, oob []byte, addr *UnixAddr) (n, oobn int, err error) {
if !c.ok() {
return 0, 0, syscall.EINVAL
}
n, oobn, err = c.writeMsg(b, oob, addr)
if err != nil {
err = &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: addr.opAddr(), Err: err}
}
return
}
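// Editor's note: an illustrative sketch (not part of the original source) of
// the classic use of WriteMsgUnix/ReadMsgUnix: passing an open file descriptor
// across a "unix" socket as an SCM_RIGHTS control message. Here c is assumed
// to be a connected *net.UnixConn and f an *os.File.
//
//	// Sender: encode f's descriptor as out-of-band data.
//	oob := syscall.UnixRights(int(f.Fd()))
//	if _, _, err := c.WriteMsgUnix([]byte{0}, oob, nil); err != nil {
//		log.Fatal(err)
//	}
//
//	// Receiver: parse the control message back into descriptors.
//	buf, oobBuf := make([]byte, 1), make([]byte, 128)
//	_, oobn, _, _, err := c.ReadMsgUnix(buf, oobBuf)
//	if err != nil {
//		log.Fatal(err)
//	}
//	msgs, err := syscall.ParseSocketControlMessage(oobBuf[:oobn])
//	if err != nil {
//		log.Fatal(err)
//	}
//	fds, err := syscall.ParseUnixRights(&msgs[0])
//	fmt.Println(fds, err) // the received descriptor numbers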
func newUnixConn(fd *netFD) *UnixConn { return &UnixConn{conn{fd}} }
// DialUnix acts like Dial for Unix networks.
//
// The network must be a Unix network name; see func Dial for details.
//
// If laddr is non-nil, it is used as the local address for the
// connection.
func DialUnix(network string, laddr, raddr *UnixAddr) (*UnixConn, error) {
switch network {
case "unix", "unixgram", "unixpacket":
default:
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: UnknownNetworkError(network)}
}
sd := &sysDialer{network: network, address: raddr.String()}
c, err := sd.dialUnix(context.Background(), laddr, raddr)
if err != nil {
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
}
return c, nil
}
// UnixListener is a Unix domain socket listener. Clients should
// typically use variables of type Listener instead of assuming Unix
// domain sockets.
type UnixListener struct {
fd *netFD
path string
unlink bool
unlinkOnce sync.Once
}
func (ln *UnixListener) ok() bool { return ln != nil && ln.fd != nil }
// SyscallConn returns a raw network connection.
// This implements the syscall.Conn interface.
//
// The returned RawConn only supports calling Control. Read and
// Write return an error.
func (l *UnixListener) SyscallConn() (syscall.RawConn, error) {
if !l.ok() {
return nil, syscall.EINVAL
}
return newRawListener(l.fd)
}
// AcceptUnix accepts the next incoming call and returns the new
// connection.
func (l *UnixListener) AcceptUnix() (*UnixConn, error) {
if !l.ok() {
return nil, syscall.EINVAL
}
c, err := l.accept()
if err != nil {
return nil, &OpError{Op: "accept", Net: l.fd.net, Source: nil, Addr: l.fd.laddr, Err: err}
}
return c, nil
}
// Accept implements the Accept method in the Listener interface.
// Returned connections will be of type *UnixConn.
func (l *UnixListener) Accept() (Conn, error) {
if !l.ok() {
return nil, syscall.EINVAL
}
c, err := l.accept()
if err != nil {
return nil, &OpError{Op: "accept", Net: l.fd.net, Source: nil, Addr: l.fd.laddr, Err: err}
}
return c, nil
}
// Close stops listening on the Unix address. Already accepted
// connections are not closed.
func (l *UnixListener) Close() error {
if !l.ok() {
return syscall.EINVAL
}
if err := l.close(); err != nil {
return &OpError{Op: "close", Net: l.fd.net, Source: nil, Addr: l.fd.laddr, Err: err}
}
return nil
}
// Addr returns the listener's network address.
// The Addr returned is shared by all invocations of Addr, so
// do not modify it.
func (l *UnixListener) Addr() Addr { return l.fd.laddr }
// SetDeadline sets the deadline associated with the listener.
// A zero time value disables the deadline.
func (l *UnixListener) SetDeadline(t time.Time) error {
if !l.ok() {
return syscall.EINVAL
}
if err := l.fd.pfd.SetDeadline(t); err != nil {
return &OpError{Op: "set", Net: l.fd.net, Source: nil, Addr: l.fd.laddr, Err: err}
}
return nil
}
// File returns a copy of the underlying os.File.
// It is the caller's responsibility to close f when finished.
// Closing l does not affect f, and closing f does not affect l.
//
// The returned os.File's file descriptor is different from the
// connection's. Attempting to change properties of the original
// using this duplicate may or may not have the desired effect.
func (l *UnixListener) File() (f *os.File, err error) {
if !l.ok() {
return nil, syscall.EINVAL
}
f, err = l.file()
if err != nil {
err = &OpError{Op: "file", Net: l.fd.net, Source: nil, Addr: l.fd.laddr, Err: err}
}
return
}
// ListenUnix acts like Listen for Unix networks.
//
// The network must be "unix" or "unixpacket".
func ListenUnix(network string, laddr *UnixAddr) (*UnixListener, error) {
switch network {
case "unix", "unixpacket":
default:
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: UnknownNetworkError(network)}
}
if laddr == nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: errMissingAddress}
}
sl := &sysListener{network: network, address: laddr.String()}
ln, err := sl.listenUnix(context.Background(), laddr)
if err != nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: err}
}
return ln, nil
}
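// Editor's note: an illustrative sketch (not part of the original source) of a
// stream listener on a Unix domain socket. The path /tmp/echo.sock is an
// arbitrary assumption; closing the listener removes it (see SetUnlinkOnClose).
//
//	addr, err := net.ResolveUnixAddr("unix", "/tmp/echo.sock")
//	if err != nil {
//		log.Fatal(err)
//	}
//	ln, err := net.ListenUnix("unix", addr)
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer ln.Close()
//	conn, err := ln.AcceptUnix()
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer conn.Close()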
// ListenUnixgram acts like ListenPacket for Unix networks.
//
// The network must be "unixgram".
func ListenUnixgram(network string, laddr *UnixAddr) (*UnixConn, error) {
switch network {
case "unixgram":
default:
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: UnknownNetworkError(network)}
}
if laddr == nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: errMissingAddress}
}
sl := &sysListener{network: network, address: laddr.String()}
c, err := sl.listenUnixgram(context.Background(), laddr)
if err != nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: err}
}
return c, nil
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || (js && wasm) || windows
package net
import (
"context"
"errors"
"os"
"syscall"
)
func unixSocket(ctx context.Context, net string, laddr, raddr sockaddr, mode string, ctxCtrlFn func(context.Context, string, string, syscall.RawConn) error) (*netFD, error) {
var sotype int
switch net {
case "unix":
sotype = syscall.SOCK_STREAM
case "unixgram":
sotype = syscall.SOCK_DGRAM
case "unixpacket":
sotype = syscall.SOCK_SEQPACKET
default:
return nil, UnknownNetworkError(net)
}
switch mode {
case "dial":
if laddr != nil && laddr.isWildcard() {
laddr = nil
}
if raddr != nil && raddr.isWildcard() {
raddr = nil
}
if raddr == nil && (sotype != syscall.SOCK_DGRAM || laddr == nil) {
return nil, errMissingAddress
}
case "listen":
default:
return nil, errors.New("unknown mode: " + mode)
}
fd, err := socket(ctx, net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr, ctxCtrlFn)
if err != nil {
return nil, err
}
return fd, nil
}
func sockaddrToUnix(sa syscall.Sockaddr) Addr {
if s, ok := sa.(*syscall.SockaddrUnix); ok {
return &UnixAddr{Name: s.Name, Net: "unix"}
}
return nil
}
func sockaddrToUnixgram(sa syscall.Sockaddr) Addr {
if s, ok := sa.(*syscall.SockaddrUnix); ok {
return &UnixAddr{Name: s.Name, Net: "unixgram"}
}
return nil
}
func sockaddrToUnixpacket(sa syscall.Sockaddr) Addr {
if s, ok := sa.(*syscall.SockaddrUnix); ok {
return &UnixAddr{Name: s.Name, Net: "unixpacket"}
}
return nil
}
func sotypeToNet(sotype int) string {
switch sotype {
case syscall.SOCK_STREAM:
return "unix"
case syscall.SOCK_DGRAM:
return "unixgram"
case syscall.SOCK_SEQPACKET:
return "unixpacket"
default:
panic("sotypeToNet unknown socket type")
}
}
func (a *UnixAddr) family() int {
return syscall.AF_UNIX
}
func (a *UnixAddr) sockaddr(family int) (syscall.Sockaddr, error) {
if a == nil {
return nil, nil
}
return &syscall.SockaddrUnix{Name: a.Name}, nil
}
func (a *UnixAddr) toLocal(net string) sockaddr {
return a
}
func (c *UnixConn) readFrom(b []byte) (int, *UnixAddr, error) {
var addr *UnixAddr
n, sa, err := c.fd.readFrom(b)
switch sa := sa.(type) {
case *syscall.SockaddrUnix:
if sa.Name != "" {
addr = &UnixAddr{Name: sa.Name, Net: sotypeToNet(c.fd.sotype)}
}
}
return n, addr, err
}
func (c *UnixConn) readMsg(b, oob []byte) (n, oobn, flags int, addr *UnixAddr, err error) {
var sa syscall.Sockaddr
n, oobn, flags, sa, err = c.fd.readMsg(b, oob, readMsgFlags)
if readMsgFlags == 0 && err == nil && oobn > 0 {
setReadMsgCloseOnExec(oob[:oobn])
}
switch sa := sa.(type) {
case *syscall.SockaddrUnix:
if sa.Name != "" {
addr = &UnixAddr{Name: sa.Name, Net: sotypeToNet(c.fd.sotype)}
}
}
return
}
func (c *UnixConn) writeTo(b []byte, addr *UnixAddr) (int, error) {
if c.fd.isConnected {
return 0, ErrWriteToConnected
}
if addr == nil {
return 0, errMissingAddress
}
if addr.Net != sotypeToNet(c.fd.sotype) {
return 0, syscall.EAFNOSUPPORT
}
sa := &syscall.SockaddrUnix{Name: addr.Name}
return c.fd.writeTo(b, sa)
}
func (c *UnixConn) writeMsg(b, oob []byte, addr *UnixAddr) (n, oobn int, err error) {
if c.fd.sotype == syscall.SOCK_DGRAM && c.fd.isConnected {
return 0, 0, ErrWriteToConnected
}
var sa syscall.Sockaddr
if addr != nil {
if addr.Net != sotypeToNet(c.fd.sotype) {
return 0, 0, syscall.EAFNOSUPPORT
}
sa = &syscall.SockaddrUnix{Name: addr.Name}
}
return c.fd.writeMsg(b, oob, sa)
}
func (sd *sysDialer) dialUnix(ctx context.Context, laddr, raddr *UnixAddr) (*UnixConn, error) {
ctrlCtxFn := sd.Dialer.ControlContext
if ctrlCtxFn == nil && sd.Dialer.Control != nil {
ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
return sd.Dialer.Control(network, address, c)
}
}
fd, err := unixSocket(ctx, sd.network, laddr, raddr, "dial", ctrlCtxFn)
if err != nil {
return nil, err
}
return newUnixConn(fd), nil
}
func (ln *UnixListener) accept() (*UnixConn, error) {
fd, err := ln.fd.accept()
if err != nil {
return nil, err
}
return newUnixConn(fd), nil
}
func (ln *UnixListener) close() error {
// The operating system doesn't clean up
// the socket file that the listener created, so
// we have to clean it up ourselves.
// There's a race here--we can't know for
// sure whether someone else has come along
// and replaced our socket name already--
// but this sequence (remove then close)
// is at least compatible with the auto-remove
// sequence in ListenUnix. It's only non-Go
// programs that can mess us up.
// Even if there are racy calls to Close, we want to unlink only for the first one.
ln.unlinkOnce.Do(func() {
if ln.path[0] != '@' && ln.unlink {
syscall.Unlink(ln.path)
}
})
return ln.fd.Close()
}
func (ln *UnixListener) file() (*os.File, error) {
f, err := ln.fd.dup()
if err != nil {
return nil, err
}
return f, nil
}
// SetUnlinkOnClose sets whether the underlying socket file should be removed
// from the file system when the listener is closed.
//
// The default behavior is to unlink the socket file only when package net created it.
// That is, when the listener and the underlying socket file were created by a call to
// Listen or ListenUnix, then by default closing the listener will remove the socket file.
// But if the listener was created by a call to FileListener to use an already existing
// socket file, then by default closing the listener will not remove the socket file.
func (l *UnixListener) SetUnlinkOnClose(unlink bool) {
l.unlink = unlink
}
func (sl *sysListener) listenUnix(ctx context.Context, laddr *UnixAddr) (*UnixListener, error) {
var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
if sl.ListenConfig.Control != nil {
ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
return sl.ListenConfig.Control(network, address, c)
}
}
fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen", ctrlCtxFn)
if err != nil {
return nil, err
}
return &UnixListener{fd: fd, path: fd.laddr.String(), unlink: true}, nil
}
func (sl *sysListener) listenUnixgram(ctx context.Context, laddr *UnixAddr) (*UnixConn, error) {
var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
if sl.ListenConfig.Control != nil {
ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
return sl.ListenConfig.Control(network, address, c)
}
}
fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen", ctrlCtxFn)
if err != nil {
return nil, err
}
return newUnixConn(fd), nil
}
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build dragonfly || linux || netbsd || openbsd
package net
import "syscall"
// readMsgFlags includes MSG_CMSG_CLOEXEC so that the kernel atomically
// sets the close-on-exec flag on any file descriptors received via
// SCM_RIGHTS control messages.
const readMsgFlags = syscall.MSG_CMSG_CLOEXEC

// setReadMsgCloseOnExec is a no-op on these platforms; see readMsgFlags.
func setReadMsgCloseOnExec(oob []byte) {}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package url parses URLs and implements query escaping.
package url
// See RFC 3986. This package generally follows RFC 3986, except where
// it deviates for compatibility reasons. When sending changes, first
// search old issues for history on decisions. Unit tests should also
// contain references to issue numbers with details.
import (
"errors"
"fmt"
"path"
"sort"
"strconv"
"strings"
)
// Error reports an error and the operation and URL that caused it.
type Error struct {
Op string
URL string
Err error
}
func (e *Error) Unwrap() error { return e.Err }
func (e *Error) Error() string { return fmt.Sprintf("%s %q: %s", e.Op, e.URL, e.Err) }
func (e *Error) Timeout() bool {
t, ok := e.Err.(interface {
Timeout() bool
})
return ok && t.Timeout()
}
func (e *Error) Temporary() bool {
t, ok := e.Err.(interface {
Temporary() bool
})
return ok && t.Temporary()
}
const upperhex = "0123456789ABCDEF"
func ishex(c byte) bool {
switch {
case '0' <= c && c <= '9':
return true
case 'a' <= c && c <= 'f':
return true
case 'A' <= c && c <= 'F':
return true
}
return false
}
func unhex(c byte) byte {
switch {
case '0' <= c && c <= '9':
return c - '0'
case 'a' <= c && c <= 'f':
return c - 'a' + 10
case 'A' <= c && c <= 'F':
return c - 'A' + 10
}
return 0
}
type encoding int
const (
encodePath encoding = 1 + iota
encodePathSegment
encodeHost
encodeZone
encodeUserPassword
encodeQueryComponent
encodeFragment
)
type EscapeError string
func (e EscapeError) Error() string {
return "invalid URL escape " + strconv.Quote(string(e))
}
type InvalidHostError string
func (e InvalidHostError) Error() string {
return "invalid character " + strconv.Quote(string(e)) + " in host name"
}
// shouldEscape reports whether the specified character should be
// escaped when appearing in a URL string, according to RFC 3986.
//
// Note that for now shouldEscape does not check all
// reserved characters correctly. See golang.org/issue/5684.
func shouldEscape(c byte, mode encoding) bool {
// §2.3 Unreserved characters (alphanum)
if 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || '0' <= c && c <= '9' {
return false
}
if mode == encodeHost || mode == encodeZone {
// §3.2.2 Host allows
// sub-delims = "!" / "$" / "&" / "'" / "(" / ")" / "*" / "+" / "," / ";" / "="
// as part of reg-name.
// We add : because we include :port as part of host.
// We add [ ] because we include [ipv6]:port as part of host.
// We add < > because they're the only characters left that
// we could possibly allow, and Parse will reject them if we
// escape them (because hosts can't use %-encoding for
// ASCII bytes).
switch c {
case '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=', ':', '[', ']', '<', '>', '"':
return false
}
}
switch c {
case '-', '_', '.', '~': // §2.3 Unreserved characters (mark)
return false
case '$', '&', '+', ',', '/', ':', ';', '=', '?', '@': // §2.2 Reserved characters (reserved)
// Different sections of the URL allow a few of
// the reserved characters to appear unescaped.
switch mode {
case encodePath: // §3.3
// The RFC allows : @ & = + $ but saves / ; , for assigning
// meaning to individual path segments. This package
// only manipulates the path as a whole, so we allow those
// last three as well. That leaves only ? to escape.
return c == '?'
case encodePathSegment: // §3.3
// The RFC allows : @ & = + $ but saves / ; , for assigning
// meaning to individual path segments.
return c == '/' || c == ';' || c == ',' || c == '?'
case encodeUserPassword: // §3.2.1
// The RFC allows ';', ':', '&', '=', '+', '$', and ',' in
// userinfo, so we must escape only '@', '/', and '?'.
// The parsing of userinfo treats ':' as special so we must escape
// that too.
return c == '@' || c == '/' || c == '?' || c == ':'
case encodeQueryComponent: // §3.4
// The RFC reserves (so we must escape) everything.
return true
case encodeFragment: // §4.1
// The RFC text is silent but the grammar allows
// everything, so escape nothing.
return false
}
}
if mode == encodeFragment {
// RFC 3986 §2.2 allows not escaping sub-delims. A subset of sub-delims are
// included in reserved from RFC 2396 §2.2. The remaining sub-delims do not
// need to be escaped. To minimize potential breakage, we apply two restrictions:
// (1) we always escape sub-delims outside of the fragment, and (2) we always
// escape single quote to avoid breaking callers that had previously assumed that
// single quotes would be escaped. See issue #19917.
switch c {
case '!', '(', ')', '*':
return false
}
}
// Everything else must be escaped.
return true
}
// QueryUnescape does the inverse transformation of QueryEscape,
// converting each 3-byte encoded substring of the form "%AB" into the
// hex-decoded byte 0xAB.
// It returns an error if any % is not followed by two hexadecimal
// digits.
func QueryUnescape(s string) (string, error) {
return unescape(s, encodeQueryComponent)
}
// PathUnescape does the inverse transformation of PathEscape,
// converting each 3-byte encoded substring of the form "%AB" into the
// hex-decoded byte 0xAB. It returns an error if any % is not followed
// by two hexadecimal digits.
//
// PathUnescape is identical to QueryUnescape except that it does not
// unescape '+' to ' ' (space).
func PathUnescape(s string) (string, error) {
return unescape(s, encodePathSegment)
}
// unescape unescapes a string; the mode specifies
// which section of the URL string is being unescaped.
func unescape(s string, mode encoding) (string, error) {
// Count %, check that they're well-formed.
n := 0
hasPlus := false
for i := 0; i < len(s); {
switch s[i] {
case '%':
n++
if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) {
s = s[i:]
if len(s) > 3 {
s = s[:3]
}
return "", EscapeError(s)
}
// Per https://tools.ietf.org/html/rfc3986#page-21
// in the host component %-encoding can only be used
// for non-ASCII bytes.
// But https://tools.ietf.org/html/rfc6874#section-2
// introduces %25 being allowed to escape a percent sign
// in IPv6 scoped-address literals. Yay.
if mode == encodeHost && unhex(s[i+1]) < 8 && s[i:i+3] != "%25" {
return "", EscapeError(s[i : i+3])
}
if mode == encodeZone {
// RFC 6874 says basically "anything goes" for zone identifiers
// and that even non-ASCII can be redundantly escaped,
// but it seems prudent to restrict %-escaped bytes here to those
// that are valid host name bytes in their unescaped form.
// That is, you can use escaping in the zone identifier but not
// to introduce bytes you couldn't just write directly.
// But Windows puts spaces here! Yay.
v := unhex(s[i+1])<<4 | unhex(s[i+2])
if s[i:i+3] != "%25" && v != ' ' && shouldEscape(v, encodeHost) {
return "", EscapeError(s[i : i+3])
}
}
i += 3
case '+':
hasPlus = mode == encodeQueryComponent
i++
default:
if (mode == encodeHost || mode == encodeZone) && s[i] < 0x80 && shouldEscape(s[i], mode) {
return "", InvalidHostError(s[i : i+1])
}
i++
}
}
if n == 0 && !hasPlus {
return s, nil
}
var t strings.Builder
t.Grow(len(s) - 2*n)
for i := 0; i < len(s); i++ {
switch s[i] {
case '%':
t.WriteByte(unhex(s[i+1])<<4 | unhex(s[i+2]))
i += 2
case '+':
if mode == encodeQueryComponent {
t.WriteByte(' ')
} else {
t.WriteByte('+')
}
default:
t.WriteByte(s[i])
}
}
return t.String(), nil
}
// QueryEscape escapes the string so it can be safely placed
// inside a URL query.
func QueryEscape(s string) string {
return escape(s, encodeQueryComponent)
}
// PathEscape escapes the string so it can be safely placed inside a URL path segment,
// replacing special characters (including /) with %XX sequences as needed.
func PathEscape(s string) string {
return escape(s, encodePathSegment)
}
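// Editor's note: an illustrative sketch (not part of the original source)
// showing how the two escapers diverge on space and '+'.
//
//	fmt.Println(url.QueryEscape("a b+c")) // "a+b%2Bc": space becomes '+', '+' becomes %2B
//	fmt.Println(url.PathEscape("a b+c"))  // "a%20b+c": space becomes %20, '+' is left alone
//	s, _ := url.QueryUnescape("a+b%2Bc")
//	fmt.Println(s) // "a b+c"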
func escape(s string, mode encoding) string {
spaceCount, hexCount := 0, 0
for i := 0; i < len(s); i++ {
c := s[i]
if shouldEscape(c, mode) {
if c == ' ' && mode == encodeQueryComponent {
spaceCount++
} else {
hexCount++
}
}
}
if spaceCount == 0 && hexCount == 0 {
return s
}
var buf [64]byte
var t []byte
required := len(s) + 2*hexCount
if required <= len(buf) {
t = buf[:required]
} else {
t = make([]byte, required)
}
if hexCount == 0 {
copy(t, s)
for i := 0; i < len(s); i++ {
if s[i] == ' ' {
t[i] = '+'
}
}
return string(t)
}
j := 0
for i := 0; i < len(s); i++ {
switch c := s[i]; {
case c == ' ' && mode == encodeQueryComponent:
t[j] = '+'
j++
case shouldEscape(c, mode):
t[j] = '%'
t[j+1] = upperhex[c>>4]
t[j+2] = upperhex[c&15]
j += 3
default:
t[j] = s[i]
j++
}
}
return string(t)
}
// A URL represents a parsed URL (technically, a URI reference).
//
// The general form represented is:
//
// [scheme:][//[userinfo@]host][/]path[?query][#fragment]
//
// URLs that do not start with a slash after the scheme are interpreted as:
//
// scheme:opaque[?query][#fragment]
//
// Note that the Path field is stored in decoded form: /%47%6f%2f becomes /Go/.
// A consequence is that it is impossible to tell which slashes in the Path were
// slashes in the raw URL and which were %2f. This distinction is rarely important,
// but when it is, the code should use the EscapedPath method, which preserves
// the original encoding of Path.
//
// The RawPath field is an optional field which is only set when the default
// encoding of Path is different from the escaped path. See the EscapedPath method
// for more details.
//
// URL's String method uses the EscapedPath method to obtain the path.
type URL struct {
Scheme string
Opaque string // encoded opaque data
User *Userinfo // username and password information
Host string // host or host:port
Path string // path (relative paths may omit leading slash)
RawPath string // encoded path hint (see EscapedPath method)
OmitHost bool // do not emit empty host (authority)
ForceQuery bool // append a query ('?') even if RawQuery is empty
RawQuery string // encoded query values, without '?'
Fragment string // fragment for references, without '#'
RawFragment string // encoded fragment hint (see EscapedFragment method)
}
// User returns a Userinfo containing the provided username
// and no password set.
func User(username string) *Userinfo {
return &Userinfo{username, "", false}
}
// UserPassword returns a Userinfo containing the provided username
// and password.
//
// This functionality should only be used with legacy web sites.
// RFC 2396 warns that interpreting Userinfo this way
// “is NOT RECOMMENDED, because the passing of authentication
// information in clear text (such as URI) has proven to be a
// security risk in almost every case where it has been used.”
func UserPassword(username, password string) *Userinfo {
return &Userinfo{username, password, true}
}
// The Userinfo type is an immutable encapsulation of username and
// password details for a URL. An existing Userinfo value is guaranteed
// to have a username set (potentially empty, as allowed by RFC 2396),
// and optionally a password.
type Userinfo struct {
username string
password string
passwordSet bool
}
// Username returns the username.
func (u *Userinfo) Username() string {
if u == nil {
return ""
}
return u.username
}
// Password returns the password in case it is set, and whether it is set.
func (u *Userinfo) Password() (string, bool) {
if u == nil {
return "", false
}
return u.password, u.passwordSet
}
// String returns the encoded userinfo information in the standard form
// of "username[:password]".
func (u *Userinfo) String() string {
if u == nil {
return ""
}
s := escape(u.username, encodeUserPassword)
if u.passwordSet {
s += ":" + escape(u.password, encodeUserPassword)
}
return s
}
// Maybe rawURL is of the form scheme:path.
// (Scheme must be [a-zA-Z][a-zA-Z0-9+.-]*)
// If so, return scheme, path; else return "", rawURL.
func getScheme(rawURL string) (scheme, path string, err error) {
for i := 0; i < len(rawURL); i++ {
c := rawURL[i]
switch {
case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z':
// do nothing
case '0' <= c && c <= '9' || c == '+' || c == '-' || c == '.':
if i == 0 {
return "", rawURL, nil
}
case c == ':':
if i == 0 {
return "", "", errors.New("missing protocol scheme")
}
return rawURL[:i], rawURL[i+1:], nil
default:
// we have encountered an invalid character,
// so there is no valid scheme
return "", rawURL, nil
}
}
return "", rawURL, nil
}
// Parse parses a raw url into a URL structure.
//
// The url may be relative (a path, without a host) or absolute
// (starting with a scheme). Trying to parse a hostname and path
// without a scheme is invalid but may not necessarily return an
// error, due to parsing ambiguities.
func Parse(rawURL string) (*URL, error) {
// Cut off #frag
u, frag, _ := strings.Cut(rawURL, "#")
url, err := parse(u, false)
if err != nil {
return nil, &Error{"parse", u, err}
}
if frag == "" {
return url, nil
}
if err = url.setFragment(frag); err != nil {
return nil, &Error{"parse", rawURL, err}
}
return url, nil
}
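// Editor's note: an illustrative sketch (not part of the original source) of
// parsing an absolute URL and reading its components.
//
//	u, err := url.Parse("https://user:pw@example.com:8080/a/b?x=1#frag")
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println(u.Scheme)   // "https"
//	fmt.Println(u.Host)     // "example.com:8080"
//	fmt.Println(u.Path)     // "/a/b"
//	fmt.Println(u.RawQuery) // "x=1"
//	fmt.Println(u.Fragment) // "frag"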
// ParseRequestURI parses a raw url into a URL structure. It assumes that
// url was received in an HTTP request, so the url is interpreted
// only as an absolute URI or an absolute path.
// The string url is assumed not to have a #fragment suffix.
// (Web browsers strip #fragment before sending the URL to a web server.)
func ParseRequestURI(rawURL string) (*URL, error) {
url, err := parse(rawURL, true)
if err != nil {
return nil, &Error{"parse", rawURL, err}
}
return url, nil
}
// parse parses a URL from a string in one of two contexts. If
// viaRequest is true, the URL is assumed to have arrived via an HTTP request,
// in which case only absolute URLs or path-absolute relative URLs are allowed.
// If viaRequest is false, all forms of relative URLs are allowed.
func parse(rawURL string, viaRequest bool) (*URL, error) {
var rest string
var err error
if stringContainsCTLByte(rawURL) {
return nil, errors.New("net/url: invalid control character in URL")
}
if rawURL == "" && viaRequest {
return nil, errors.New("empty url")
}
url := new(URL)
if rawURL == "*" {
url.Path = "*"
return url, nil
}
// Split off possible leading "http:", "mailto:", etc.
// Cannot contain escaped characters.
if url.Scheme, rest, err = getScheme(rawURL); err != nil {
return nil, err
}
url.Scheme = strings.ToLower(url.Scheme)
if strings.HasSuffix(rest, "?") && strings.Count(rest, "?") == 1 {
url.ForceQuery = true
rest = rest[:len(rest)-1]
} else {
rest, url.RawQuery, _ = strings.Cut(rest, "?")
}
if !strings.HasPrefix(rest, "/") {
if url.Scheme != "" {
// We consider rootless paths per RFC 3986 as opaque.
url.Opaque = rest
return url, nil
}
if viaRequest {
return nil, errors.New("invalid URI for request")
}
// Avoid confusion with malformed schemes, like cache_object:foo/bar.
// See golang.org/issue/16822.
//
// RFC 3986, §3.3:
// In addition, a URI reference (Section 4.1) may be a relative-path reference,
// in which case the first path segment cannot contain a colon (":") character.
if segment, _, _ := strings.Cut(rest, "/"); strings.Contains(segment, ":") {
// First path segment has colon. Not allowed in relative URL.
return nil, errors.New("first path segment in URL cannot contain colon")
}
}
if (url.Scheme != "" || !viaRequest && !strings.HasPrefix(rest, "///")) && strings.HasPrefix(rest, "//") {
var authority string
authority, rest = rest[2:], ""
if i := strings.Index(authority, "/"); i >= 0 {
authority, rest = authority[:i], authority[i:]
}
url.User, url.Host, err = parseAuthority(authority)
if err != nil {
return nil, err
}
} else if url.Scheme != "" && strings.HasPrefix(rest, "/") {
// OmitHost is set to true when rawURL has an empty host (authority).
// See golang.org/issue/46059.
url.OmitHost = true
}
// Set Path and, optionally, RawPath.
// RawPath is a hint of the encoding of Path. We don't want to set it if
// the default escaping of Path is equivalent, to help make sure that people
// don't rely on it in general.
if err := url.setPath(rest); err != nil {
return nil, err
}
return url, nil
}
func parseAuthority(authority string) (user *Userinfo, host string, err error) {
i := strings.LastIndex(authority, "@")
if i < 0 {
host, err = parseHost(authority)
} else {
host, err = parseHost(authority[i+1:])
}
if err != nil {
return nil, "", err
}
if i < 0 {
return nil, host, nil
}
userinfo := authority[:i]
if !validUserinfo(userinfo) {
return nil, "", errors.New("net/url: invalid userinfo")
}
if !strings.Contains(userinfo, ":") {
if userinfo, err = unescape(userinfo, encodeUserPassword); err != nil {
return nil, "", err
}
user = User(userinfo)
} else {
username, password, _ := strings.Cut(userinfo, ":")
if username, err = unescape(username, encodeUserPassword); err != nil {
return nil, "", err
}
if password, err = unescape(password, encodeUserPassword); err != nil {
return nil, "", err
}
user = UserPassword(username, password)
}
return user, host, nil
}
// parseHost parses host as an authority without user
// information. That is, as host[:port].
func parseHost(host string) (string, error) {
if strings.HasPrefix(host, "[") {
// Parse an IP-Literal in RFC 3986 and RFC 6874.
// E.g., "[fe80::1]", "[fe80::1%25en0]", "[fe80::1]:80".
i := strings.LastIndex(host, "]")
if i < 0 {
return "", errors.New("missing ']' in host")
}
colonPort := host[i+1:]
if !validOptionalPort(colonPort) {
return "", fmt.Errorf("invalid port %q after host", colonPort)
}
// RFC 6874 defines that %25 (%-encoded percent) introduces
// the zone identifier, and the zone identifier can use basically
// any %-encoding it likes. That's different from the host, which
// can only %-encode non-ASCII bytes.
// We do impose some restrictions on the zone, to avoid stupidity
// like newlines.
zone := strings.Index(host[:i], "%25")
if zone >= 0 {
host1, err := unescape(host[:zone], encodeHost)
if err != nil {
return "", err
}
host2, err := unescape(host[zone:i], encodeZone)
if err != nil {
return "", err
}
host3, err := unescape(host[i:], encodeHost)
if err != nil {
return "", err
}
return host1 + host2 + host3, nil
}
} else if i := strings.LastIndex(host, ":"); i != -1 {
colonPort := host[i:]
if !validOptionalPort(colonPort) {
return "", fmt.Errorf("invalid port %q after host", colonPort)
}
}
var err error
if host, err = unescape(host, encodeHost); err != nil {
return "", err
}
return host, nil
}
// setPath sets the Path and RawPath fields of the URL based on the provided
// escaped path p. It maintains the invariant that RawPath is only specified
// when it differs from the default encoding of the path.
// For example:
// - setPath("/foo/bar") will set Path="/foo/bar" and RawPath=""
// - setPath("/foo%2fbar") will set Path="/foo/bar" and RawPath="/foo%2fbar"
// setPath will return an error only if the provided path contains an invalid
// escaping.
func (u *URL) setPath(p string) error {
path, err := unescape(p, encodePath)
if err != nil {
return err
}
u.Path = path
if escp := escape(path, encodePath); p == escp {
// Default encoding is fine.
u.RawPath = ""
} else {
u.RawPath = p
}
return nil
}
// EscapedPath returns the escaped form of u.Path.
// In general there are multiple possible escaped forms of any path.
// EscapedPath returns u.RawPath when it is a valid escaping of u.Path.
// Otherwise EscapedPath ignores u.RawPath and computes an escaped
// form on its own.
// The String and RequestURI methods use EscapedPath to construct
// their results.
// In general, code should call EscapedPath instead of
// reading u.RawPath directly.
func (u *URL) EscapedPath() string {
if u.RawPath != "" && validEncoded(u.RawPath, encodePath) {
p, err := unescape(u.RawPath, encodePath)
if err == nil && p == u.Path {
return u.RawPath
}
}
if u.Path == "*" {
return "*" // don't escape (Issue 11202)
}
return escape(u.Path, encodePath)
}
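// Editor's note: an illustrative sketch (not part of the original source).
// Because Path is stored decoded, EscapedPath is the only way to recover a
// %2F that appeared in the raw URL.
//
//	u, _ := url.Parse("https://example.com/foo%2Fbar")
//	fmt.Println(u.Path)          // "/foo/bar"
//	fmt.Println(u.RawPath)       // "/foo%2Fbar" (set because it differs from the default encoding)
//	fmt.Println(u.EscapedPath()) // "/foo%2Fbar"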
// validEncoded reports whether s is a valid encoded path or fragment,
// according to mode.
// It must not contain any bytes that require escaping during encoding.
func validEncoded(s string, mode encoding) bool {
for i := 0; i < len(s); i++ {
// RFC 3986, Appendix A.
// pchar = unreserved / pct-encoded / sub-delims / ":" / "@".
// shouldEscape is not quite compliant with the RFC,
// so we check the sub-delims ourselves and let
// shouldEscape handle the others.
switch s[i] {
case '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=', ':', '@':
// ok
case '[', ']':
// ok - not specified in RFC 3986 but left alone by modern browsers
case '%':
// ok - percent encoded, will decode
default:
if shouldEscape(s[i], mode) {
return false
}
}
}
return true
}
// setFragment is like setPath but for Fragment/RawFragment.
func (u *URL) setFragment(f string) error {
frag, err := unescape(f, encodeFragment)
if err != nil {
return err
}
u.Fragment = frag
if escf := escape(frag, encodeFragment); f == escf {
// Default encoding is fine.
u.RawFragment = ""
} else {
u.RawFragment = f
}
return nil
}
// EscapedFragment returns the escaped form of u.Fragment.
// In general there are multiple possible escaped forms of any fragment.
// EscapedFragment returns u.RawFragment when it is a valid escaping of u.Fragment.
// Otherwise EscapedFragment ignores u.RawFragment and computes an escaped
// form on its own.
// The String method uses EscapedFragment to construct its result.
// In general, code should call EscapedFragment instead of
// reading u.RawFragment directly.
func (u *URL) EscapedFragment() string {
if u.RawFragment != "" && validEncoded(u.RawFragment, encodeFragment) {
f, err := unescape(u.RawFragment, encodeFragment)
if err == nil && f == u.Fragment {
return u.RawFragment
}
}
return escape(u.Fragment, encodeFragment)
}
// validOptionalPort reports whether port is either an empty string
// or matches /^:\d*$/
func validOptionalPort(port string) bool {
if port == "" {
return true
}
if port[0] != ':' {
return false
}
for _, b := range port[1:] {
if b < '0' || b > '9' {
return false
}
}
return true
}
// String reassembles the URL into a valid URL string.
// The general form of the result is one of:
//
// scheme:opaque?query#fragment
// scheme://userinfo@host/path?query#fragment
//
// If u.Opaque is non-empty, String uses the first form;
// otherwise it uses the second form.
// Any non-ASCII characters in host are escaped.
// To obtain the path, String uses u.EscapedPath().
//
// In the second form, the following rules apply:
// - if u.Scheme is empty, scheme: is omitted.
// - if u.User is nil, userinfo@ is omitted.
// - if u.Host is empty, host/ is omitted.
// - if u.Scheme and u.Host are empty and u.User is nil,
// the entire scheme://userinfo@host/ is omitted.
// - if u.Host is non-empty and u.Path begins with a /,
// the form host/path does not add its own /.
// - if u.RawQuery is empty, ?query is omitted.
// - if u.Fragment is empty, #fragment is omitted.
func (u *URL) String() string {
var buf strings.Builder
if u.Scheme != "" {
buf.WriteString(u.Scheme)
buf.WriteByte(':')
}
if u.Opaque != "" {
buf.WriteString(u.Opaque)
} else {
if u.Scheme != "" || u.Host != "" || u.User != nil {
if u.OmitHost && u.Host == "" && u.User == nil {
// omit empty host
} else {
if u.Host != "" || u.Path != "" || u.User != nil {
buf.WriteString("//")
}
if ui := u.User; ui != nil {
buf.WriteString(ui.String())
buf.WriteByte('@')
}
if h := u.Host; h != "" {
buf.WriteString(escape(h, encodeHost))
}
}
}
path := u.EscapedPath()
if path != "" && path[0] != '/' && u.Host != "" {
buf.WriteByte('/')
}
if buf.Len() == 0 {
// RFC 3986 §4.2
// A path segment that contains a colon character (e.g., "this:that")
// cannot be used as the first segment of a relative-path reference, as
// it would be mistaken for a scheme name. Such a segment must be
// preceded by a dot-segment (e.g., "./this:that") to make a relative-
// path reference.
if segment, _, _ := strings.Cut(path, "/"); strings.Contains(segment, ":") {
buf.WriteString("./")
}
}
buf.WriteString(path)
}
if u.ForceQuery || u.RawQuery != "" {
buf.WriteByte('?')
buf.WriteString(u.RawQuery)
}
if u.Fragment != "" {
buf.WriteByte('#')
buf.WriteString(u.EscapedFragment())
}
return buf.String()
}
// Redacted is like String but replaces any password with "xxxxx".
// Only the password in u.User is redacted.
func (u *URL) Redacted() string {
if u == nil {
return ""
}
ru := *u
if _, has := ru.User.Password(); has {
ru.User = UserPassword(ru.User.Username(), "xxxxx")
}
return ru.String()
}
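// Editor's note: an illustrative sketch (not part of the original source).
//
//	u, _ := url.Parse("https://user:secret@example.com/")
//	fmt.Println(u.Redacted()) // "https://user:xxxxx@example.com/"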
// Values maps a string key to a list of values.
// It is typically used for query parameters and form values.
// Unlike in the http.Header map, the keys in a Values map
// are case-sensitive.
type Values map[string][]string
// Get gets the first value associated with the given key.
// If there are no values associated with the key, Get returns
// the empty string. To access multiple values, use the map
// directly.
func (v Values) Get(key string) string {
vs := v[key]
if len(vs) == 0 {
return ""
}
return vs[0]
}
// Set sets the key to value. It replaces any existing
// values.
func (v Values) Set(key, value string) {
v[key] = []string{value}
}
// Add adds the value to key. It appends to any existing
// values associated with key.
func (v Values) Add(key, value string) {
v[key] = append(v[key], value)
}
// Del deletes the values associated with key.
func (v Values) Del(key string) {
delete(v, key)
}
// Has checks whether a given key is set.
func (v Values) Has(key string) bool {
_, ok := v[key]
return ok
}
// ParseQuery parses the URL-encoded query string and returns
// a map listing the values specified for each key.
// ParseQuery always returns a non-nil map containing all the
// valid query parameters found; err describes the first decoding error
// encountered, if any.
//
// Query is expected to be a list of key=value settings separated by ampersands.
// A setting without an equals sign is interpreted as a key set to an empty
// value.
// Settings containing a non-URL-encoded semicolon are considered invalid.
func ParseQuery(query string) (Values, error) {
m := make(Values)
err := parseQuery(m, query)
return m, err
}
func parseQuery(m Values, query string) (err error) {
for query != "" {
var key string
key, query, _ = strings.Cut(query, "&")
if strings.Contains(key, ";") {
err = fmt.Errorf("invalid semicolon separator in query")
continue
}
if key == "" {
continue
}
key, value, _ := strings.Cut(key, "=")
key, err1 := QueryUnescape(key)
if err1 != nil {
if err == nil {
err = err1
}
continue
}
value, err1 = QueryUnescape(value)
if err1 != nil {
if err == nil {
err = err1
}
continue
}
m[key] = append(m[key], value)
}
return err
}
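// Editor's note: an illustrative sketch (not part of the original source).
// Repeated keys accumulate in order, and valid pairs survive a decoding error
// elsewhere in the query.
//
//	v, err := url.ParseQuery("a=1&b=two&a=3")
//	fmt.Println(err)        // <nil>
//	fmt.Println(v["a"])     // [1 3]
//	fmt.Println(v.Get("b")) // "two"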
// Encode encodes the values into “URL encoded” form
// ("bar=baz&foo=quux") sorted by key.
func (v Values) Encode() string {
if v == nil {
return ""
}
var buf strings.Builder
keys := make([]string, 0, len(v))
for k := range v {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
vs := v[k]
keyEscaped := QueryEscape(k)
for _, v := range vs {
if buf.Len() > 0 {
buf.WriteByte('&')
}
buf.WriteString(keyEscaped)
buf.WriteByte('=')
buf.WriteString(QueryEscape(v))
}
}
return buf.String()
}
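// Editor's note: an illustrative sketch (not part of the original source).
// Because Encode sorts by key, its output is deterministic.
//
//	v := url.Values{}
//	v.Set("foo", "quux")
//	v.Add("bar", "baz")
//	v.Add("bar", "baz2")
//	fmt.Println(v.Encode()) // "bar=baz&bar=baz2&foo=quux"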
// resolvePath applies the special path segments from ref to base,
// per RFC 3986.
func resolvePath(base, ref string) string {
var full string
if ref == "" {
full = base
} else if ref[0] != '/' {
i := strings.LastIndex(base, "/")
full = base[:i+1] + ref
} else {
full = ref
}
if full == "" {
return ""
}
var (
elem string
dst strings.Builder
)
first := true
remaining := full
// We want to return a leading '/', so write it now.
dst.WriteByte('/')
found := true
for found {
elem, remaining, found = strings.Cut(remaining, "/")
if elem == "." {
first = false
// drop
continue
}
if elem == ".." {
// Ignore the leading '/' we already wrote.
str := dst.String()[1:]
index := strings.LastIndexByte(str, '/')
dst.Reset()
dst.WriteByte('/')
if index == -1 {
first = true
} else {
dst.WriteString(str[:index])
}
} else {
if !first {
dst.WriteByte('/')
}
dst.WriteString(elem)
first = false
}
}
if elem == "." || elem == ".." {
dst.WriteByte('/')
}
// We wrote an initial '/', but we don't want two.
r := dst.String()
if len(r) > 1 && r[1] == '/' {
r = r[1:]
}
return r
}
// IsAbs reports whether the URL is absolute.
// Absolute means that it has a non-empty scheme.
func (u *URL) IsAbs() bool {
return u.Scheme != ""
}
// Parse parses a URL in the context of the receiver. The provided URL
// may be relative or absolute. Parse returns nil, err on parse
// failure, otherwise its return value is the same as ResolveReference.
func (u *URL) Parse(ref string) (*URL, error) {
refURL, err := Parse(ref)
if err != nil {
return nil, err
}
return u.ResolveReference(refURL), nil
}
// ResolveReference resolves a URI reference to an absolute URI from
// an absolute base URI u, per RFC 3986 Section 5.2. The URI reference
// may be relative or absolute. ResolveReference always returns a new
// URL instance, even if the returned URL is identical to either the
// base or reference. If ref is an absolute URL, then ResolveReference
// ignores base and returns a copy of ref.
func (u *URL) ResolveReference(ref *URL) *URL {
url := *ref
if ref.Scheme == "" {
url.Scheme = u.Scheme
}
if ref.Scheme != "" || ref.Host != "" || ref.User != nil {
// The "absoluteURI" or "net_path" cases.
// We can ignore the error from setPath since we know we provided a
// validly-escaped path.
url.setPath(resolvePath(ref.EscapedPath(), ""))
return &url
}
if ref.Opaque != "" {
url.User = nil
url.Host = ""
url.Path = ""
return &url
}
if ref.Path == "" && !ref.ForceQuery && ref.RawQuery == "" {
url.RawQuery = u.RawQuery
if ref.Fragment == "" {
url.Fragment = u.Fragment
url.RawFragment = u.RawFragment
}
}
// The "abs_path" or "rel_path" cases.
url.Host = u.Host
url.User = u.User
url.setPath(resolvePath(u.EscapedPath(), ref.EscapedPath()))
return &url
}
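// Editor's note: an illustrative sketch (not part of the original source) of
// resolving a relative reference against an absolute base, per RFC 3986 §5.2.
//
//	base, _ := url.Parse("http://example.com/a/b/c")
//	ref, _ := url.Parse("../d")
//	fmt.Println(base.ResolveReference(ref)) // "http://example.com/a/d"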
// Query parses RawQuery and returns the corresponding values.
// It silently discards malformed value pairs.
// To check errors use ParseQuery.
func (u *URL) Query() Values {
v, _ := ParseQuery(u.RawQuery)
return v
}
// RequestURI returns the encoded path?query or opaque?query
// string that would be used in an HTTP request for u.
func (u *URL) RequestURI() string {
result := u.Opaque
if result == "" {
result = u.EscapedPath()
if result == "" {
result = "/"
}
} else {
if strings.HasPrefix(result, "//") {
result = u.Scheme + ":" + result
}
}
if u.ForceQuery || u.RawQuery != "" {
result += "?" + u.RawQuery
}
return result
}
// Hostname returns u.Host, stripping any valid port number if present.
//
// If the result is enclosed in square brackets, as literal IPv6 addresses are,
// the square brackets are removed from the result.
func (u *URL) Hostname() string {
host, _ := splitHostPort(u.Host)
return host
}
// Port returns the port part of u.Host, without the leading colon.
//
// If u.Host doesn't contain a valid numeric port, Port returns an empty string.
func (u *URL) Port() string {
_, port := splitHostPort(u.Host)
return port
}
// splitHostPort separates host and port. If the port is not valid, it returns
// the entire input as host, and it doesn't check the validity of the host.
// Unlike net.SplitHostPort, it requires ports to be numeric, per RFC 3986.
func splitHostPort(hostPort string) (host, port string) {
host = hostPort
colon := strings.LastIndexByte(host, ':')
if colon != -1 && validOptionalPort(host[colon:]) {
host, port = host[:colon], host[colon+1:]
}
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
host = host[1 : len(host)-1]
}
return
}
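// Editor's note: an illustrative sketch (not part of the original source).
// Hostname strips the brackets from an IPv6 literal, and Port drops the colon.
//
//	u, _ := url.Parse("https://[::1]:8080/")
//	fmt.Println(u.Hostname()) // "::1"
//	fmt.Println(u.Port())     // "8080"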
// Marshaling interface implementations.
// Would like to implement MarshalText/UnmarshalText but that will change the JSON representation of URLs.
func (u *URL) MarshalBinary() (text []byte, err error) {
return []byte(u.String()), nil
}
func (u *URL) UnmarshalBinary(text []byte) error {
u1, err := Parse(string(text))
if err != nil {
return err
}
*u = *u1
return nil
}
// JoinPath returns a new URL with the provided path elements joined to
// any existing path and the resulting path cleaned of any ./ or ../ elements.
// Any sequences of multiple / characters will be reduced to a single /.
func (u *URL) JoinPath(elem ...string) *URL {
elem = append([]string{u.EscapedPath()}, elem...)
var p string
if !strings.HasPrefix(elem[0], "/") {
// Return a relative path if u is relative,
// but ensure that it contains no ../ elements.
elem[0] = "/" + elem[0]
p = path.Join(elem...)[1:]
} else {
p = path.Join(elem...)
}
// path.Join will remove any trailing slashes.
// Preserve at least one.
if strings.HasSuffix(elem[len(elem)-1], "/") && !strings.HasSuffix(p, "/") {
p += "/"
}
url := *u
url.setPath(p)
return &url
}
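// Illustrative sketch (editorial addition): JoinPath cleans dot segments
// and collapses duplicate slashes while preserving one trailing slash:
//
//	u, _ := url.Parse("https://example.com/api/")
//	fmt.Println(u.JoinPath("v1", "users/")) // https://example.com/api/v1/users/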
// validUserinfo reports whether s is a valid userinfo string per RFC 3986
// Section 3.2.1:
//
// userinfo = *( unreserved / pct-encoded / sub-delims / ":" )
// unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~"
// sub-delims = "!" / "$" / "&" / "'" / "(" / ")"
// / "*" / "+" / "," / ";" / "="
//
// It doesn't validate pct-encoded. The caller does that via func unescape.
func validUserinfo(s string) bool {
for _, r := range s {
if 'A' <= r && r <= 'Z' {
continue
}
if 'a' <= r && r <= 'z' {
continue
}
if '0' <= r && r <= '9' {
continue
}
switch r {
case '-', '.', '_', ':', '~', '!', '$', '&', '\'',
'(', ')', '*', '+', ',', ';', '=', '%', '@':
continue
default:
return false
}
}
return true
}
// stringContainsCTLByte reports whether s contains any ASCII control character.
func stringContainsCTLByte(s string) bool {
for i := 0; i < len(s); i++ {
b := s[i]
if b < ' ' || b == 0x7f {
return true
}
}
return false
}
// JoinPath returns a URL string with the provided path elements joined to
// the existing path of base and the resulting path cleaned of any ./ or ../ elements.
func JoinPath(base string, elem ...string) (result string, err error) {
url, err := Parse(base)
if err != nil {
return
}
result = url.JoinPath(elem...).String()
return
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix
package net
import (
"runtime"
"syscall"
)
func (c *conn) writeBuffers(v *Buffers) (int64, error) {
if !c.ok() {
return 0, syscall.EINVAL
}
n, err := c.fd.writeBuffers(v)
if err != nil {
return n, &OpError{Op: "writev", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
}
return n, nil
}
func (fd *netFD) writeBuffers(v *Buffers) (n int64, err error) {
n, err = fd.pfd.Writev((*[][]byte)(v))
runtime.KeepAlive(fd)
return n, wrapSyscallError("writev", err)
}
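// Illustrative sketch (editorial addition): callers reach this writev path
// through net.Buffers. Here conn is an assumed, already-established
// net.Conn; note that WriteTo consumes the buffers as it writes:
//
//	bufs := net.Buffers{
//		[]byte("HTTP/1.1 200 OK\r\n"),
//		[]byte("Content-Length: 2\r\n\r\n"),
//		[]byte("ok"),
//	}
//	if _, err := bufs.WriteTo(conn); err != nil {
//		log.Fatal(err)
//	}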
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package os
import (
"io/fs"
"sort"
)
type readdirMode int
const (
readdirName readdirMode = iota
readdirDirEntry
readdirFileInfo
)
// Readdir reads the contents of the directory associated with file and
// returns a slice of up to n FileInfo values, as would be returned
// by Lstat, in directory order. Subsequent calls on the same file will yield
// further FileInfos.
//
// If n > 0, Readdir returns at most n FileInfo structures. In this case, if
// Readdir returns an empty slice, it will return a non-nil error
// explaining why. At the end of a directory, the error is io.EOF.
//
// If n <= 0, Readdir returns all the FileInfo from the directory in
// a single slice. In this case, if Readdir succeeds (reads all
// the way to the end of the directory), it returns the slice and a
// nil error. If it encounters an error before the end of the
// directory, Readdir returns the FileInfo read until that point
// and a non-nil error.
//
// Most clients are better served by the more efficient ReadDir method.
func (f *File) Readdir(n int) ([]FileInfo, error) {
if f == nil {
return nil, ErrInvalid
}
_, _, infos, err := f.readdir(n, readdirFileInfo)
if infos == nil {
// Readdir has historically always returned a non-nil empty slice, never nil,
// even on error (except misuse with nil receiver above).
// Keep it that way to avoid breaking overly sensitive callers.
infos = []FileInfo{}
}
return infos, err
}
// Readdirnames reads the contents of the directory associated with file
// and returns a slice of up to n names of files in the directory,
// in directory order. Subsequent calls on the same file will yield
// further names.
//
// If n > 0, Readdirnames returns at most n names. In this case, if
// Readdirnames returns an empty slice, it will return a non-nil error
// explaining why. At the end of a directory, the error is io.EOF.
//
// If n <= 0, Readdirnames returns all the names from the directory in
// a single slice. In this case, if Readdirnames succeeds (reads all
// the way to the end of the directory), it returns the slice and a
// nil error. If it encounters an error before the end of the
// directory, Readdirnames returns the names read until that point and
// a non-nil error.
func (f *File) Readdirnames(n int) (names []string, err error) {
if f == nil {
return nil, ErrInvalid
}
names, _, _, err = f.readdir(n, readdirName)
if names == nil {
// Readdirnames has historically always returned a non-nil empty slice, never nil,
// even on error (except misuse with nil receiver above).
// Keep it that way to avoid breaking overly sensitive callers.
names = []string{}
}
return names, err
}
// A DirEntry is an entry read from a directory
// (using the ReadDir function or a File's ReadDir method).
type DirEntry = fs.DirEntry
// ReadDir reads the contents of the directory associated with the file f
// and returns a slice of DirEntry values in directory order.
// Subsequent calls on the same file will yield later DirEntry records in the directory.
//
// If n > 0, ReadDir returns at most n DirEntry records.
// In this case, if ReadDir returns an empty slice, it will return an error explaining why.
// At the end of a directory, the error is io.EOF.
//
// If n <= 0, ReadDir returns all the DirEntry records remaining in the directory.
// When it succeeds, it returns a nil error (not io.EOF).
func (f *File) ReadDir(n int) ([]DirEntry, error) {
if f == nil {
return nil, ErrInvalid
}
_, dirents, _, err := f.readdir(n, readdirDirEntry)
if dirents == nil {
// Match Readdir and Readdirnames: don't return nil slices.
dirents = []DirEntry{}
}
return dirents, err
}
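// Illustrative sketch (editorial addition): reading a directory in batches
// with a positive n; io.EOF marks the end of the directory:
//
//	f, err := os.Open("/tmp")
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer f.Close()
//	for {
//		entries, err := f.ReadDir(16)
//		for _, e := range entries {
//			fmt.Println(e.Name())
//		}
//		if err == io.EOF {
//			break
//		}
//		if err != nil {
//			log.Fatal(err)
//		}
//	}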
// testingForceReadDirLstat forces ReadDir to call Lstat, for testing that code path.
// This can be difficult to provoke on some Unix systems otherwise.
var testingForceReadDirLstat bool
// ReadDir reads the named directory,
// returning all its directory entries sorted by filename.
// If an error occurs reading the directory,
// ReadDir returns the entries it was able to read before the error,
// along with the error.
func ReadDir(name string) ([]DirEntry, error) {
f, err := Open(name)
if err != nil {
return nil, err
}
defer f.Close()
dirs, err := f.ReadDir(-1)
sort.Slice(dirs, func(i, j int) bool { return dirs[i].Name() < dirs[j].Name() })
return dirs, err
}
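// Illustrative sketch (editorial addition): the common one-shot form, which
// also sorts the entries by name:
//
//	entries, err := os.ReadDir(".")
//	if err != nil {
//		log.Fatal(err)
//	}
//	for _, e := range entries {
//		fmt.Println(e.Name(), e.IsDir())
//	}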
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix || dragonfly || freebsd || (js && wasm) || linux || netbsd || openbsd || solaris
package os
import (
"io"
"runtime"
"sync"
"syscall"
"unsafe"
)
// Auxiliary information if the File describes a directory
type dirInfo struct {
buf *[]byte // buffer for directory I/O
nbuf int // length of buf; return value from Getdirentries
bufp int // location of next record in buf.
}
const (
// More than 5760 to work around https://golang.org/issue/24015.
blockSize = 8192
)
var dirBufPool = sync.Pool{
New: func() any {
// The buffer must be at least a block long.
buf := make([]byte, blockSize)
return &buf
},
}
func (d *dirInfo) close() {
if d.buf != nil {
dirBufPool.Put(d.buf)
d.buf = nil
}
}
func (f *File) readdir(n int, mode readdirMode) (names []string, dirents []DirEntry, infos []FileInfo, err error) {
// If this file has no dirinfo, create one.
if f.dirinfo == nil {
f.dirinfo = new(dirInfo)
f.dirinfo.buf = dirBufPool.Get().(*[]byte)
}
d := f.dirinfo
// Change the meaning of n for the implementation below.
//
// The n above was for the public interface of "if n <= 0,
// Readdir returns all the FileInfo from the directory in a
// single slice".
//
// But below, we use only negative to mean looping until the
// end and positive to mean bounded, with positive
// terminating at 0.
if n == 0 {
n = -1
}
for n != 0 {
// Refill the buffer if necessary
if d.bufp >= d.nbuf {
d.bufp = 0
var errno error
d.nbuf, errno = f.pfd.ReadDirent(*d.buf)
runtime.KeepAlive(f)
if errno != nil {
return names, dirents, infos, &PathError{Op: "readdirent", Path: f.name, Err: errno}
}
if d.nbuf <= 0 {
break // EOF
}
}
// Drain the buffer
buf := (*d.buf)[d.bufp:d.nbuf]
reclen, ok := direntReclen(buf)
if !ok || reclen > uint64(len(buf)) {
break
}
rec := buf[:reclen]
d.bufp += int(reclen)
ino, ok := direntIno(rec)
if !ok {
break
}
if ino == 0 {
continue
}
const namoff = uint64(unsafe.Offsetof(syscall.Dirent{}.Name))
namlen, ok := direntNamlen(rec)
if !ok || namoff+namlen > uint64(len(rec)) {
break
}
name := rec[namoff : namoff+namlen]
for i, c := range name {
if c == 0 {
name = name[:i]
break
}
}
// Check for useless names before allocating a string.
if string(name) == "." || string(name) == ".." {
continue
}
if n > 0 { // see 'n == 0' comment above
n--
}
if mode == readdirName {
names = append(names, string(name))
} else if mode == readdirDirEntry {
de, err := newUnixDirent(f.name, string(name), direntType(rec))
if IsNotExist(err) {
// File disappeared between readdir and stat.
// Treat as if it didn't exist.
continue
}
if err != nil {
return nil, dirents, nil, err
}
dirents = append(dirents, de)
} else {
info, err := lstat(f.name + "/" + string(name))
if IsNotExist(err) {
// File disappeared between readdir and stat.
// Treat as if it didn't exist.
continue
}
if err != nil {
return nil, nil, infos, err
}
infos = append(infos, info)
}
}
if n > 0 && len(names)+len(dirents)+len(infos) == 0 {
return nil, nil, nil, io.EOF
}
return names, dirents, infos, nil
}
// readInt returns the size-byte unsigned integer in native byte order at offset off in b.
func readInt(b []byte, off, size uintptr) (u uint64, ok bool) {
if len(b) < int(off+size) {
return 0, false
}
if isBigEndian {
return readIntBE(b[off:], size), true
}
return readIntLE(b[off:], size), true
}
func readIntBE(b []byte, size uintptr) uint64 {
switch size {
case 1:
return uint64(b[0])
case 2:
_ = b[1] // bounds check hint to compiler; see golang.org/issue/14808
return uint64(b[1]) | uint64(b[0])<<8
case 4:
_ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
return uint64(b[3]) | uint64(b[2])<<8 | uint64(b[1])<<16 | uint64(b[0])<<24
case 8:
_ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
default:
panic("syscall: readInt with unsupported size")
}
}
func readIntLE(b []byte, size uintptr) uint64 {
switch size {
case 1:
return uint64(b[0])
case 2:
_ = b[1] // bounds check hint to compiler; see golang.org/issue/14808
return uint64(b[0]) | uint64(b[1])<<8
case 4:
_ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24
case 8:
_ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 |
uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56
default:
panic("syscall: readInt with unsupported size")
}
}
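// Editorial note (not part of the original source): for the fixed sizes
// handled above, readIntLE matches encoding/binary's little-endian
// decoding. For example, with encoding/binary imported:
//
//	b := []byte{0x78, 0x56, 0x34, 0x12}
//	fmt.Printf("%#x\n", binary.LittleEndian.Uint32(b)) // 0x12345678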
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package os
import (
"syscall"
"unsafe"
)
func direntIno(buf []byte) (uint64, bool) {
return readInt(buf, unsafe.Offsetof(syscall.Dirent{}.Ino), unsafe.Sizeof(syscall.Dirent{}.Ino))
}
func direntReclen(buf []byte) (uint64, bool) {
return readInt(buf, unsafe.Offsetof(syscall.Dirent{}.Reclen), unsafe.Sizeof(syscall.Dirent{}.Reclen))
}
func direntNamlen(buf []byte) (uint64, bool) {
reclen, ok := direntReclen(buf)
if !ok {
return 0, false
}
return reclen - uint64(unsafe.Offsetof(syscall.Dirent{}.Name)), true
}
func direntType(buf []byte) FileMode {
off := unsafe.Offsetof(syscall.Dirent{}.Type)
if off >= uintptr(len(buf)) {
return ^FileMode(0) // unknown
}
typ := buf[off]
switch typ {
case syscall.DT_BLK:
return ModeDevice
case syscall.DT_CHR:
return ModeDevice | ModeCharDevice
case syscall.DT_DIR:
return ModeDir
case syscall.DT_FIFO:
return ModeNamedPipe
case syscall.DT_LNK:
return ModeSymlink
case syscall.DT_REG:
return 0
case syscall.DT_SOCK:
return ModeSocket
}
return ^FileMode(0) // unknown
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// General environment variables.
package os
import (
"internal/testlog"
"syscall"
)
// Expand replaces ${var} or $var in the string based on the mapping function.
// For example, os.ExpandEnv(s) is equivalent to os.Expand(s, os.Getenv).
func Expand(s string, mapping func(string) string) string {
var buf []byte
// ${} is all ASCII, so bytes are fine for this operation.
i := 0
for j := 0; j < len(s); j++ {
if s[j] == '$' && j+1 < len(s) {
if buf == nil {
buf = make([]byte, 0, 2*len(s))
}
buf = append(buf, s[i:j]...)
name, w := getShellName(s[j+1:])
if name == "" && w > 0 {
// Encountered invalid syntax; eat the
// characters.
} else if name == "" {
// Valid syntax, but $ was not followed by a
// name. Leave the dollar character untouched.
buf = append(buf, s[j])
} else {
buf = append(buf, mapping(name)...)
}
j += w
i = j + 1
}
}
if buf == nil {
return s
}
return string(buf) + s[i:]
}
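// Illustrative sketch (editorial addition): Expand with a custom mapping
// function; mapping unknown names to "" here is a choice, not a requirement:
//
//	mapping := func(name string) string {
//		if name == "NAME" {
//			return "gopher"
//		}
//		return ""
//	}
//	fmt.Println(os.Expand("Hello, ${NAME}!", mapping)) // Hello, gopher!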
// ExpandEnv replaces ${var} or $var in the string according to the values
// of the current environment variables. References to undefined
// variables are replaced by the empty string.
func ExpandEnv(s string) string {
return Expand(s, Getenv)
}
// isShellSpecialVar reports whether the character identifies a special
// shell variable such as $*.
func isShellSpecialVar(c uint8) bool {
switch c {
case '*', '#', '$', '@', '!', '?', '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
return true
}
return false
}
// isAlphaNum reports whether the byte is an ASCII letter, number, or underscore.
func isAlphaNum(c uint8) bool {
return c == '_' || '0' <= c && c <= '9' || 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z'
}
// getShellName returns the name that begins the string and the number of bytes
// consumed to extract it. If the name is enclosed in {}, it's part of a ${}
// expansion and two more bytes are needed than the length of the name.
func getShellName(s string) (string, int) {
switch {
case s[0] == '{':
if len(s) > 2 && isShellSpecialVar(s[1]) && s[2] == '}' {
return s[1:2], 3
}
// Scan to closing brace
for i := 1; i < len(s); i++ {
if s[i] == '}' {
if i == 1 {
return "", 2 // Bad syntax; eat "${}"
}
return s[1:i], i + 1
}
}
return "", 1 // Bad syntax; eat "${"
case isShellSpecialVar(s[0]):
return s[0:1], 1
}
// Scan alphanumerics.
var i int
for i = 0; i < len(s) && isAlphaNum(s[i]); i++ {
}
return s[:i], i
}
// Getenv retrieves the value of the environment variable named by the key.
// It returns the value, which will be empty if the variable is not present.
// To distinguish between an empty value and an unset value, use LookupEnv.
func Getenv(key string) string {
testlog.Getenv(key)
v, _ := syscall.Getenv(key)
return v
}
// LookupEnv retrieves the value of the environment variable named
// by the key. If the variable is present in the environment the
// value (which may be empty) is returned and the boolean is true.
// Otherwise the returned value will be empty and the boolean will
// be false.
func LookupEnv(key string) (string, bool) {
testlog.Getenv(key)
return syscall.Getenv(key)
}
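// Illustrative sketch (editorial addition): LookupEnv distinguishes an
// unset variable from one set to the empty string, which Getenv cannot:
//
//	if v, ok := os.LookupEnv("PORT"); ok {
//		fmt.Println("PORT =", v)
//	} else {
//		fmt.Println("PORT is unset")
//	}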
// Setenv sets the value of the environment variable named by the key.
// It returns an error, if any.
func Setenv(key, value string) error {
err := syscall.Setenv(key, value)
if err != nil {
return NewSyscallError("setenv", err)
}
return nil
}
// Unsetenv unsets a single environment variable.
func Unsetenv(key string) error {
return syscall.Unsetenv(key)
}
// Clearenv deletes all environment variables.
func Clearenv() {
syscall.Clearenv()
}
// Environ returns a copy of strings representing the environment,
// in the form "key=value".
func Environ() []string {
return syscall.Environ()
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package os
import (
"internal/poll"
"io/fs"
)
// Portable analogs of some common system call errors.
//
// Errors returned from this package may be tested against these errors
// with errors.Is.
var (
// ErrInvalid indicates an invalid argument.
// Methods on File will return this error when the receiver is nil.
ErrInvalid = fs.ErrInvalid // "invalid argument"
ErrPermission = fs.ErrPermission // "permission denied"
ErrExist = fs.ErrExist // "file already exists"
ErrNotExist = fs.ErrNotExist // "file does not exist"
ErrClosed = fs.ErrClosed // "file already closed"
ErrNoDeadline = errNoDeadline() // "file type does not support deadline"
ErrDeadlineExceeded = errDeadlineExceeded() // "i/o timeout"
)
func errNoDeadline() error { return poll.ErrNoDeadline }
// errDeadlineExceeded returns the value for os.ErrDeadlineExceeded.
// This error comes from the internal/poll package, which is also
// used by package net. Doing it this way ensures that the net
// package will return os.ErrDeadlineExceeded for an exceeded deadline,
// as documented by net.Conn.SetDeadline, without requiring any extra
// work in the net package and without requiring the internal/poll
// package to import os (which it can't, because that would be circular).
func errDeadlineExceeded() error { return poll.ErrDeadlineExceeded }
type timeout interface {
Timeout() bool
}
// PathError records an error and the operation and file path that caused it.
type PathError = fs.PathError
// SyscallError records an error from a specific system call.
type SyscallError struct {
Syscall string
Err error
}
func (e *SyscallError) Error() string { return e.Syscall + ": " + e.Err.Error() }
func (e *SyscallError) Unwrap() error { return e.Err }
// Timeout reports whether this error represents a timeout.
func (e *SyscallError) Timeout() bool {
t, ok := e.Err.(timeout)
return ok && t.Timeout()
}
// NewSyscallError returns, as an error, a new SyscallError
// with the given system call name and error details.
// As a convenience, if err is nil, NewSyscallError returns nil.
func NewSyscallError(syscall string, err error) error {
if err == nil {
return nil
}
return &SyscallError{syscall, err}
}
// IsExist returns a boolean indicating whether the error is known to report
// that a file or directory already exists. It is satisfied by ErrExist as
// well as some syscall errors.
//
// This function predates errors.Is. It only supports errors returned by
// the os package. New code should use errors.Is(err, fs.ErrExist).
func IsExist(err error) bool {
return underlyingErrorIs(err, ErrExist)
}
// IsNotExist returns a boolean indicating whether the error is known to
// report that a file or directory does not exist. It is satisfied by
// ErrNotExist as well as some syscall errors.
//
// This function predates errors.Is. It only supports errors returned by
// the os package. New code should use errors.Is(err, fs.ErrNotExist).
func IsNotExist(err error) bool {
return underlyingErrorIs(err, ErrNotExist)
}
// IsPermission returns a boolean indicating whether the error is known to
// report that permission is denied. It is satisfied by ErrPermission as well
// as some syscall errors.
//
// This function predates errors.Is. It only supports errors returned by
// the os package. New code should use errors.Is(err, fs.ErrPermission).
func IsPermission(err error) bool {
return underlyingErrorIs(err, ErrPermission)
}
// IsTimeout returns a boolean indicating whether the error is known
// to report that a timeout occurred.
//
// This function predates errors.Is, and the notion of whether an
// error indicates a timeout can be ambiguous. For example, the Unix
// error EWOULDBLOCK sometimes indicates a timeout and sometimes does not.
// New code should use errors.Is with a value appropriate to the call
// returning the error, such as os.ErrDeadlineExceeded.
func IsTimeout(err error) bool {
terr, ok := underlyingError(err).(timeout)
return ok && terr.Timeout()
}
func underlyingErrorIs(err, target error) bool {
// Note that this function is not errors.Is:
// underlyingError only unwraps the specific error-wrapping types
// that it historically did, not all errors implementing Unwrap().
err = underlyingError(err)
if err == target {
return true
}
// To preserve prior behavior, only examine syscall errors.
e, ok := err.(syscallErrorType)
return ok && e.Is(target)
}
// underlyingError returns the underlying error for known os error types.
func underlyingError(err error) error {
switch err := err.(type) {
case *PathError:
return err.Err
case *LinkError:
return err.Err
case *SyscallError:
return err.Err
}
return err
}
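// Illustrative sketch (editorial addition): the legacy predicate and the
// preferred errors.Is form agree for errors returned by this package
// (assumes errors and io/fs imported):
//
//	_, err := os.Open("no-such-file")
//	fmt.Println(os.IsNotExist(err))             // true
//	fmt.Println(errors.Is(err, fs.ErrNotExist)) // true (preferred)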
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || (js && wasm) || windows
package os
import "syscall"
// wrapSyscallError takes an error and a syscall name. If the error is
// a syscall.Errno, it wraps it in an os.SyscallError using the syscall name.
func wrapSyscallError(name string, err error) error {
if _, ok := err.(syscall.Errno); ok {
err = NewSyscallError(name, err)
}
return err
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package os
import (
"errors"
"internal/testlog"
"runtime"
"sync"
"sync/atomic"
"syscall"
"time"
)
// ErrProcessDone indicates a Process has finished.
var ErrProcessDone = errors.New("os: process already finished")
// Process stores the information about a process created by StartProcess.
type Process struct {
Pid int
handle uintptr // handle is accessed atomically on Windows
isdone atomic.Bool // process has been successfully waited on
sigMu sync.RWMutex // avoid race between wait and signal
}
func newProcess(pid int, handle uintptr) *Process {
p := &Process{Pid: pid, handle: handle}
runtime.SetFinalizer(p, (*Process).Release)
return p
}
func (p *Process) setDone() {
p.isdone.Store(true)
}
func (p *Process) done() bool {
return p.isdone.Load()
}
// ProcAttr holds the attributes that will be applied to a new process
// started by StartProcess.
type ProcAttr struct {
// If Dir is non-empty, the child changes into the directory before
// creating the process.
Dir string
// If Env is non-nil, it gives the environment variables for the
// new process in the form returned by Environ.
// If it is nil, the result of Environ will be used.
Env []string
// Files specifies the open files inherited by the new process. The
// first three entries correspond to standard input, standard output, and
// standard error. An implementation may support additional entries,
// depending on the underlying operating system. A nil entry corresponds
// to that file being closed when the process starts.
// On Unix systems, StartProcess will change these File values
// to blocking mode, which means that SetDeadline will stop working
// and calling Close will not interrupt a Read or Write.
Files []*File
// Operating system-specific process creation attributes.
// Note that setting this field means that your program
// may not execute properly or even compile on some
// operating systems.
Sys *syscall.SysProcAttr
}
// A Signal represents an operating system signal.
// The usual underlying implementation is operating system-dependent:
// on Unix it is syscall.Signal.
type Signal interface {
String() string
Signal() // to distinguish from other Stringers
}
// Getpid returns the process id of the caller.
func Getpid() int { return syscall.Getpid() }
// Getppid returns the process id of the caller's parent.
func Getppid() int { return syscall.Getppid() }
// FindProcess looks for a running process by its pid.
//
// The Process it returns can be used to obtain information
// about the underlying operating system process.
//
// On Unix systems, FindProcess always succeeds and returns a Process
// for the given pid, regardless of whether the process exists.
func FindProcess(pid int) (*Process, error) {
return findProcess(pid)
}
// StartProcess starts a new process with the program, arguments and attributes
// specified by name, argv and attr. The argv slice will become os.Args in the
// new process, so it normally starts with the program name.
//
// If the calling goroutine has locked the operating system thread
// with runtime.LockOSThread and modified any inheritable OS-level
// thread state (for example, Linux or Plan 9 name spaces), the new
// process will inherit the caller's thread state.
//
// StartProcess is a low-level interface. The os/exec package provides
// higher-level interfaces.
//
// If there is an error, it will be of type *PathError.
func StartProcess(name string, argv []string, attr *ProcAttr) (*Process, error) {
testlog.Open(name)
return startProcess(name, argv, attr)
}
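// Illustrative sketch (editorial addition): direct use of StartProcess,
// assuming a Unix system with /bin/ls present; most code should prefer the
// higher-level os/exec package:
//
//	p, err := os.StartProcess("/bin/ls", []string{"ls", "-l"}, &os.ProcAttr{
//		Files: []*os.File{os.Stdin, os.Stdout, os.Stderr},
//	})
//	if err != nil {
//		log.Fatal(err)
//	}
//	state, err := p.Wait()
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println(state)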
// Release releases any resources associated with the Process p,
// rendering it unusable in the future.
// Release only needs to be called if Wait is not.
func (p *Process) Release() error {
return p.release()
}
// Kill causes the Process to exit immediately. Kill does not wait until
// the Process has actually exited. This only kills the Process itself,
// not any other processes it may have started.
func (p *Process) Kill() error {
return p.kill()
}
// Wait waits for the Process to exit, and then returns a
// ProcessState describing its status and an error, if any.
// Wait releases any resources associated with the Process.
// On most operating systems, the Process must be a child
// of the current process or an error will be returned.
func (p *Process) Wait() (*ProcessState, error) {
return p.wait()
}
// Signal sends a signal to the Process.
// Sending Interrupt on Windows is not implemented.
func (p *Process) Signal(sig Signal) error {
return p.signal(sig)
}
// UserTime returns the user CPU time of the exited process and its children.
func (p *ProcessState) UserTime() time.Duration {
return p.userTime()
}
// SystemTime returns the system CPU time of the exited process and its children.
func (p *ProcessState) SystemTime() time.Duration {
return p.systemTime()
}
// Exited reports whether the program has exited.
// On Unix systems this reports true if the program exited due to calling exit,
// but false if the program terminated due to a signal.
func (p *ProcessState) Exited() bool {
return p.exited()
}
// Success reports whether the program exited successfully,
// such as with exit status 0 on Unix.
func (p *ProcessState) Success() bool {
return p.success()
}
// Sys returns system-dependent exit information about
// the process. Convert it to the appropriate underlying
// type, such as syscall.WaitStatus on Unix, to access its contents.
func (p *ProcessState) Sys() any {
return p.sys()
}
// SysUsage returns system-dependent resource usage information about
// the exited process. Convert it to the appropriate underlying
// type, such as *syscall.Rusage on Unix, to access its contents.
// (On Unix, *syscall.Rusage matches struct rusage as defined in the
// getrusage(2) manual page.)
func (p *ProcessState) SysUsage() any {
return p.sysUsage()
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package exec runs external commands. It wraps os.StartProcess to make it
// easier to remap stdin and stdout, connect I/O with pipes, and do other
// adjustments.
//
// Unlike the "system" library call from C and other languages, the
// os/exec package intentionally does not invoke the system shell and
// does not expand any glob patterns or handle other expansions,
// pipelines, or redirections typically done by shells. The package
// behaves more like C's "exec" family of functions. To expand glob
// patterns, either call the shell directly, taking care to escape any
// dangerous input, or use the path/filepath package's Glob function.
// To expand environment variables, use package os's ExpandEnv.
//
// Note that the examples in this package assume a Unix system.
// They may not run on Windows, and they do not run in the Go Playground
// used by golang.org and godoc.org.
//
// # Executables in the current directory
//
// The functions Command and LookPath look for a program
// in the directories listed in the current path, following the
// conventions of the host operating system.
// Operating systems have for decades included the current
// directory in this search, sometimes implicitly and sometimes
// configured explicitly that way by default.
// Modern practice is that including the current directory
// is usually unexpected and often leads to security problems.
//
// To avoid those security problems, as of Go 1.19, this package will not resolve a program
// using an implicit or explicit path entry relative to the current directory.
// That is, if you run exec.LookPath("go"), it will not successfully return
// ./go on Unix nor .\go.exe on Windows, no matter how the path is configured.
// Instead, if the usual path algorithms would result in that answer,
// these functions return an error err satisfying errors.Is(err, ErrDot).
//
// For example, consider these two program snippets:
//
// path, err := exec.LookPath("prog")
// if err != nil {
// log.Fatal(err)
// }
// use(path)
//
// and
//
// cmd := exec.Command("prog")
// if err := cmd.Run(); err != nil {
// log.Fatal(err)
// }
//
// These will not find and run ./prog or .\prog.exe,
// no matter how the current path is configured.
//
// Code that always wants to run a program from the current directory
// can be rewritten to say "./prog" instead of "prog".
//
// Code that insists on including results from relative path entries
// can instead override the error using an errors.Is check:
//
// path, err := exec.LookPath("prog")
// if errors.Is(err, exec.ErrDot) {
// err = nil
// }
// if err != nil {
// log.Fatal(err)
// }
// use(path)
//
// and
//
// cmd := exec.Command("prog")
// if errors.Is(cmd.Err, exec.ErrDot) {
// cmd.Err = nil
// }
// if err := cmd.Run(); err != nil {
// log.Fatal(err)
// }
//
// Setting the environment variable GODEBUG=execerrdot=0
// disables generation of ErrDot entirely, temporarily restoring the pre-Go 1.19
// behavior for programs that are unable to apply more targeted fixes.
// A future version of Go may remove support for this variable.
//
// Before adding such overrides, make sure you understand the
// security implications of doing so.
// See https://go.dev/blog/path-security for more information.
package exec
import (
"bytes"
"context"
"errors"
"internal/godebug"
"internal/syscall/execenv"
"io"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"syscall"
"time"
)
// Error is returned by LookPath when it fails to classify a file as an
// executable.
type Error struct {
// Name is the file name for which the error occurred.
Name string
// Err is the underlying error.
Err error
}
func (e *Error) Error() string {
return "exec: " + strconv.Quote(e.Name) + ": " + e.Err.Error()
}
func (e *Error) Unwrap() error { return e.Err }
// ErrWaitDelay is returned by (*Cmd).Wait if the process exits with a
// successful status code but its output pipes are not closed before the
// command's WaitDelay expires.
var ErrWaitDelay = errors.New("exec: WaitDelay expired before I/O complete")
// wrappedError wraps an error without relying on fmt.Errorf.
type wrappedError struct {
prefix string
err error
}
func (w wrappedError) Error() string {
return w.prefix + ": " + w.err.Error()
}
func (w wrappedError) Unwrap() error {
return w.err
}
// Cmd represents an external command being prepared or run.
//
// A Cmd cannot be reused after calling its Run, Output or CombinedOutput
// methods.
type Cmd struct {
// Path is the path of the command to run.
//
// This is the only field that must be set to a non-zero
// value. If Path is relative, it is evaluated relative
// to Dir.
Path string
// Args holds command line arguments, including the command as Args[0].
// If the Args field is empty or nil, Run uses {Path}.
//
// In typical use, both Path and Args are set by calling Command.
Args []string
// Env specifies the environment of the process.
// Each entry is of the form "key=value".
// If Env is nil, the new process uses the current process's
// environment.
// If Env contains duplicate environment keys, only the last
// value in the slice for each duplicate key is used.
// As a special case on Windows, SYSTEMROOT is always added if
// missing and not explicitly set to the empty string.
Env []string
// Dir specifies the working directory of the command.
// If Dir is the empty string, Run runs the command in the
// calling process's current directory.
Dir string
// Stdin specifies the process's standard input.
//
// If Stdin is nil, the process reads from the null device (os.DevNull).
//
// If Stdin is an *os.File, the process's standard input is connected
// directly to that file.
//
// Otherwise, during the execution of the command a separate
// goroutine reads from Stdin and delivers that data to the command
// over a pipe. In this case, Wait does not complete until the goroutine
// stops copying, either because it has reached the end of Stdin
// (EOF or a read error), or because writing to the pipe returned an error,
// or because a nonzero WaitDelay was set and expired.
Stdin io.Reader
// Stdout and Stderr specify the process's standard output and error.
//
// If either is nil, Run connects the corresponding file descriptor
// to the null device (os.DevNull).
//
// If either is an *os.File, the corresponding output from the process
// is connected directly to that file.
//
// Otherwise, during the execution of the command a separate goroutine
// reads from the process over a pipe and delivers that data to the
// corresponding Writer. In this case, Wait does not complete until the
// goroutine reaches EOF or encounters an error or a nonzero WaitDelay
// expires.
//
// If Stdout and Stderr are the same writer, and have a type that can
// be compared with ==, at most one goroutine at a time will call Write.
Stdout io.Writer
Stderr io.Writer
// ExtraFiles specifies additional open files to be inherited by the
// new process. It does not include standard input, standard output, or
// standard error. If non-nil, entry i becomes file descriptor 3+i.
//
// ExtraFiles is not supported on Windows.
ExtraFiles []*os.File
// SysProcAttr holds optional, operating system-specific attributes.
// Run passes it to os.StartProcess as the os.ProcAttr's Sys field.
SysProcAttr *syscall.SysProcAttr
// Process is the underlying process, once started.
Process *os.Process
// ProcessState contains information about an exited process.
// If the process was started successfully, Wait or Run will
// populate its ProcessState when the command completes.
ProcessState *os.ProcessState
// ctx is the context passed to CommandContext, if any.
ctx context.Context
Err error // LookPath error, if any.
// If Cancel is non-nil, the command must have been created with
// CommandContext and Cancel will be called when the command's
// Context is done. By default, CommandContext sets Cancel to
// call the Kill method on the command's Process.
//
// Typically a custom Cancel will send a signal to the command's
// Process, but it may instead take other actions to initiate cancellation,
// such as closing a stdin or stdout pipe or sending a shutdown request on a
// network socket.
//
// If the command exits with a success status after Cancel is
// called, and Cancel does not return an error equivalent to
// os.ErrProcessDone, then Wait and similar methods will return a non-nil
// error: either an error wrapping the one returned by Cancel,
// or the error from the Context.
// (If the command exits with a non-success status, or Cancel
// returns an error that wraps os.ErrProcessDone, Wait and similar methods
// continue to return the command's usual exit status.)
//
// If Cancel is set to nil, nothing will happen immediately when the command's
// Context is done, but a nonzero WaitDelay will still take effect. That may
// be useful, for example, to work around deadlocks in commands that do not
// support shutdown signals but are expected to always finish quickly.
//
// Cancel will not be called if Start returns a non-nil error.
Cancel func() error
// If WaitDelay is non-zero, it bounds the time spent waiting on two sources
// of unexpected delay in Wait: a child process that fails to exit after the
// associated Context is canceled, and a child process that exits but leaves
// its I/O pipes unclosed.
//
// The WaitDelay timer starts when either the associated Context is done or a
// call to Wait observes that the child process has exited, whichever occurs
// first. When the delay has elapsed, the command shuts down the child process
// and/or its I/O pipes.
//
// If the child process has failed to exit — perhaps because it ignored or
// failed to receive a shutdown signal from a Cancel function, or because no
// Cancel function was set — then it will be terminated using os.Process.Kill.
//
// Then, if the I/O pipes communicating with the child process are still open,
// those pipes are closed in order to unblock any goroutines currently blocked
// on Read or Write calls.
//
// If pipes are closed due to WaitDelay, no Cancel call has occurred,
// and the command has otherwise exited with a successful status, Wait and
// similar methods will return ErrWaitDelay instead of nil.
//
// If WaitDelay is zero (the default), I/O pipes will be read until EOF,
// which might not occur until orphaned subprocesses of the command have
// also closed their descriptors for the pipes.
WaitDelay time.Duration
// childIOFiles holds closers for any of the child process's
// stdin, stdout, and/or stderr files that were opened by the Cmd itself
// (not supplied by the caller). These should be closed as soon as they
// are inherited by the child process.
childIOFiles []io.Closer
// parentIOPipes holds closers for the parent's end of any pipes
// connected to the child's stdin, stdout, and/or stderr streams
// that were opened by the Cmd itself (not supplied by the caller).
// These should be closed after Wait sees the command and copying
// goroutines exit, or after WaitDelay has expired.
parentIOPipes []io.Closer
// goroutine holds a set of closures to execute to copy data
// to and/or from the command's I/O pipes.
goroutine []func() error
// If goroutineErr is non-nil, it receives the first error from a copying
// goroutine once all such goroutines have completed.
// goroutineErr is set to nil once its error has been received.
goroutineErr <-chan error
// If ctxResult is non-nil, it receives the result of watchCtx exactly once.
ctxResult <-chan ctxResult
// The stack saved when the Command was created, if GODEBUG contains
// execwait=2. Used for debugging leaks.
createdByStack []byte
// For a security release long ago, we created x/sys/execabs,
// which manipulated the unexported lookPathErr error field
// in this struct. For Go 1.19 we exported the field as Err error,
// above, but we have to keep lookPathErr around for use by
// old programs building against new toolchains.
// The String and Start methods look for an error in lookPathErr
// in preference to Err, to preserve the errors that execabs sets.
//
// In general we don't guarantee misuse of reflect like this,
// but the misuse of reflect was by us, the best of various bad
// options to fix the security problem, and people depend on
// those old copies of execabs continuing to work.
// The result is that we have to leave this variable around for the
// rest of time, a compatibility scar.
//
// See https://go.dev/blog/path-security
// and https://go.dev/issue/43724 for more context.
lookPathErr error
}
// A ctxResult reports the result of watching the Context associated with a
// running command (and sending corresponding signals if needed).
type ctxResult struct {
err error
// If timer is non-nil, it expires after WaitDelay has elapsed after
// the Context is done.
//
// (If timer is nil, that means that the Context was not done before the
// command completed, or no WaitDelay was set, or the WaitDelay already
// expired and its effect was already applied.)
timer *time.Timer
}
var execwait = godebug.New("execwait")
var execerrdot = godebug.New("execerrdot")
// Command returns the Cmd struct to execute the named program with
// the given arguments.
//
// It sets only the Path and Args in the returned structure.
//
// If name contains no path separators, Command uses LookPath to
// resolve name to a complete path if possible. Otherwise it uses name
// directly as Path.
//
// The returned Cmd's Args field is constructed from the command name
// followed by the elements of arg, so arg should not include the
// command name itself. For example, Command("echo", "hello").
// Args[0] is always name, not the possibly resolved Path.
//
// On Windows, processes receive the whole command line as a single string
// and do their own parsing. Command combines and quotes Args into a command
// line string with an algorithm compatible with applications using
// CommandLineToArgvW (which is the most common way). Notable exceptions are
// msiexec.exe and cmd.exe (and thus, all batch files), which have a different
// unquoting algorithm. In these or other similar cases, you can do the
// quoting yourself and provide the full command line in SysProcAttr.CmdLine,
// leaving Args empty.
func Command(name string, arg ...string) *Cmd {
cmd := &Cmd{
Path: name,
Args: append([]string{name}, arg...),
}
if v := execwait.Value(); v != "" {
if v == "2" {
// Obtain the caller stack. (This is equivalent to runtime/debug.Stack,
// copied to avoid importing the whole package.)
stack := make([]byte, 1024)
for {
n := runtime.Stack(stack, false)
if n < len(stack) {
stack = stack[:n]
break
}
stack = make([]byte, 2*len(stack))
}
if i := bytes.Index(stack, []byte("\nos/exec.Command(")); i >= 0 {
stack = stack[i+1:]
}
cmd.createdByStack = stack
}
runtime.SetFinalizer(cmd, func(c *Cmd) {
if c.Process != nil && c.ProcessState == nil {
debugHint := ""
if c.createdByStack == nil {
debugHint = " (set GODEBUG=execwait=2 to capture stacks for debugging)"
} else {
os.Stderr.WriteString("GODEBUG=execwait=2 detected a leaked exec.Cmd created by:\n")
os.Stderr.Write(c.createdByStack)
os.Stderr.WriteString("\n")
debugHint = ""
}
panic("exec: Cmd started a Process but leaked without a call to Wait" + debugHint)
}
})
}
if filepath.Base(name) == name {
lp, err := LookPath(name)
if lp != "" {
// Update cmd.Path even if err is non-nil.
// If err is ErrDot (especially on Windows), lp may include a resolved
// extension (like .exe or .bat) that should be preserved.
cmd.Path = lp
}
if err != nil {
cmd.Err = err
}
}
return cmd
}
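// Illustrative sketch (editorial addition): typical Command usage with
// redirected stdin and stdout; assumes the tr utility is on PATH (Unix):
//
//	cmd := exec.Command("tr", "a-z", "A-Z")
//	cmd.Stdin = strings.NewReader("some input")
//	var out bytes.Buffer
//	cmd.Stdout = &out
//	if err := cmd.Run(); err != nil {
//		log.Fatal(err)
//	}
//	fmt.Printf("translated: %q\n", out.String())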
// CommandContext is like Command but includes a context.
//
// The provided context is used to interrupt the process
// (by calling cmd.Cancel or os.Process.Kill)
// if the context becomes done before the command completes on its own.
//
// CommandContext sets the command's Cancel function to invoke the Kill method
// on its Process, and leaves its WaitDelay unset. The caller may change the
// cancellation behavior by modifying those fields before starting the command.
func CommandContext(ctx context.Context, name string, arg ...string) *Cmd {
if ctx == nil {
panic("nil Context")
}
cmd := Command(name, arg...)
cmd.ctx = ctx
cmd.Cancel = func() error {
return cmd.Process.Kill()
}
return cmd
}
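// Illustrative sketch (editorial addition): a timeout with a gentler Cancel
// and a bounded cleanup window; assumes a Unix sleep command on PATH:
//
//	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
//	defer cancel()
//	cmd := exec.CommandContext(ctx, "sleep", "30")
//	cmd.Cancel = func() error { return cmd.Process.Signal(os.Interrupt) }
//	cmd.WaitDelay = 2 * time.Second
//	if err := cmd.Run(); err != nil {
//		log.Println(err) // e.g. context.DeadlineExceeded once ctx expires
//	}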
// String returns a human-readable description of c.
// It is intended only for debugging.
// In particular, it is not suitable for use as input to a shell.
// The output of String may vary across Go releases.
func (c *Cmd) String() string {
if c.Err != nil || c.lookPathErr != nil {
// failed to resolve path; report the original requested path (plus args)
return strings.Join(c.Args, " ")
}
// report the exact executable path (plus args)
b := new(strings.Builder)
b.WriteString(c.Path)
for _, a := range c.Args[1:] {
b.WriteByte(' ')
b.WriteString(a)
}
return b.String()
}
// interfaceEqual protects against panics from doing equality tests on
// two interfaces with non-comparable underlying types.
func interfaceEqual(a, b any) bool {
defer func() {
recover()
}()
return a == b
}
func (c *Cmd) argv() []string {
if len(c.Args) > 0 {
return c.Args
}
return []string{c.Path}
}
func (c *Cmd) childStdin() (*os.File, error) {
if c.Stdin == nil {
f, err := os.Open(os.DevNull)
if err != nil {
return nil, err
}
c.childIOFiles = append(c.childIOFiles, f)
return f, nil
}
if f, ok := c.Stdin.(*os.File); ok {
return f, nil
}
pr, pw, err := os.Pipe()
if err != nil {
return nil, err
}
c.childIOFiles = append(c.childIOFiles, pr)
c.parentIOPipes = append(c.parentIOPipes, pw)
c.goroutine = append(c.goroutine, func() error {
_, err := io.Copy(pw, c.Stdin)
if skipStdinCopyError(err) {
err = nil
}
if err1 := pw.Close(); err == nil {
err = err1
}
return err
})
return pr, nil
}
func (c *Cmd) childStdout() (*os.File, error) {
return c.writerDescriptor(c.Stdout)
}
func (c *Cmd) childStderr(childStdout *os.File) (*os.File, error) {
if c.Stderr != nil && interfaceEqual(c.Stderr, c.Stdout) {
return childStdout, nil
}
return c.writerDescriptor(c.Stderr)
}
// writerDescriptor returns an os.File to which the child process
// can write to send data to w.
//
// If w is nil, writerDescriptor returns a File that writes to os.DevNull.
func (c *Cmd) writerDescriptor(w io.Writer) (*os.File, error) {
if w == nil {
f, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0)
if err != nil {
return nil, err
}
c.childIOFiles = append(c.childIOFiles, f)
return f, nil
}
if f, ok := w.(*os.File); ok {
return f, nil
}
pr, pw, err := os.Pipe()
if err != nil {
return nil, err
}
c.childIOFiles = append(c.childIOFiles, pw)
c.parentIOPipes = append(c.parentIOPipes, pr)
c.goroutine = append(c.goroutine, func() error {
_, err := io.Copy(w, pr)
pr.Close() // in case io.Copy stopped due to write error
return err
})
return pw, nil
}
func closeDescriptors(closers []io.Closer) {
for _, fd := range closers {
fd.Close()
}
}
// Run starts the specified command and waits for it to complete.
//
// The returned error is nil if the command runs, has no problems
// copying stdin, stdout, and stderr, and exits with a zero exit
// status.
//
// If the command starts but does not complete successfully, the error is of
// type *ExitError. Other error types may be returned for other situations.
//
// If the calling goroutine has locked the operating system thread
// with runtime.LockOSThread and modified any inheritable OS-level
// thread state (for example, Linux or Plan 9 name spaces), the new
// process will inherit the caller's thread state.
func (c *Cmd) Run() error {
if err := c.Start(); err != nil {
return err
}
return c.Wait()
}
// lookExtensions finds a Windows executable by its dir and path.
// It uses LookPath to try appropriate extensions.
// lookExtensions does not search PATH; instead it converts `prog` into `.\prog`.
func lookExtensions(path, dir string) (string, error) {
if filepath.Base(path) == path {
path = "." + string(filepath.Separator) + path
}
if dir == "" {
return LookPath(path)
}
if filepath.VolumeName(path) != "" {
return LookPath(path)
}
if len(path) > 1 && os.IsPathSeparator(path[0]) {
return LookPath(path)
}
dirandpath := filepath.Join(dir, path)
// We assume that LookPath will only add a file extension.
lp, err := LookPath(dirandpath)
if err != nil {
return "", err
}
ext := strings.TrimPrefix(lp, dirandpath)
return path + ext, nil
}
// Start starts the specified command but does not wait for it to complete.
//
// If Start returns successfully, the c.Process field will be set.
//
// After a successful call to Start the Wait method must be called in
// order to release associated system resources.
func (c *Cmd) Start() error {
// Check for doubled Start calls before we defer failure cleanup. If the prior
// call to Start succeeded, we don't want to spuriously close its pipes.
if c.Process != nil {
return errors.New("exec: already started")
}
started := false
defer func() {
closeDescriptors(c.childIOFiles)
c.childIOFiles = nil
if !started {
closeDescriptors(c.parentIOPipes)
c.parentIOPipes = nil
}
}()
if c.Path == "" && c.Err == nil && c.lookPathErr == nil {
c.Err = errors.New("exec: no command")
}
if c.Err != nil || c.lookPathErr != nil {
if c.lookPathErr != nil {
return c.lookPathErr
}
return c.Err
}
if runtime.GOOS == "windows" {
lp, err := lookExtensions(c.Path, c.Dir)
if err != nil {
return err
}
c.Path = lp
}
if c.Cancel != nil && c.ctx == nil {
return errors.New("exec: command with a non-nil Cancel was not created with CommandContext")
}
if c.ctx != nil {
select {
case <-c.ctx.Done():
return c.ctx.Err()
default:
}
}
childFiles := make([]*os.File, 0, 3+len(c.ExtraFiles))
stdin, err := c.childStdin()
if err != nil {
return err
}
childFiles = append(childFiles, stdin)
stdout, err := c.childStdout()
if err != nil {
return err
}
childFiles = append(childFiles, stdout)
stderr, err := c.childStderr(stdout)
if err != nil {
return err
}
childFiles = append(childFiles, stderr)
childFiles = append(childFiles, c.ExtraFiles...)
env, err := c.environ()
if err != nil {
return err
}
c.Process, err = os.StartProcess(c.Path, c.argv(), &os.ProcAttr{
Dir: c.Dir,
Files: childFiles,
Env: env,
Sys: c.SysProcAttr,
})
if err != nil {
return err
}
started = true
// Don't allocate the goroutineErr channel unless there are goroutines to start.
if len(c.goroutine) > 0 {
goroutineErr := make(chan error, 1)
c.goroutineErr = goroutineErr
type goroutineStatus struct {
running int
firstErr error
}
statusc := make(chan goroutineStatus, 1)
statusc <- goroutineStatus{running: len(c.goroutine)}
for _, fn := range c.goroutine {
go func(fn func() error) {
err := fn()
status := <-statusc
if status.firstErr == nil {
status.firstErr = err
}
status.running--
if status.running == 0 {
goroutineErr <- status.firstErr
} else {
statusc <- status
}
}(fn)
}
c.goroutine = nil // Allow the goroutines' closures to be GC'd when they complete.
}
// If we have anything to do when the command's Context expires,
// start a goroutine to watch for cancellation.
//
// (Even if the command was created by CommandContext, a helper library may
// have explicitly set its Cancel field back to nil, indicating that it should
// be allowed to continue running after cancellation after all.)
if (c.Cancel != nil || c.WaitDelay != 0) && c.ctx != nil && c.ctx.Done() != nil {
resultc := make(chan ctxResult)
c.ctxResult = resultc
go c.watchCtx(resultc)
}
return nil
}
// watchCtx watches c.ctx until it is able to send a result to resultc.
//
// If c.ctx is done before a result can be sent, watchCtx calls c.Cancel,
// and/or kills cmd.Process after c.WaitDelay has elapsed.
//
// watchCtx manipulates c.goroutineErr, so its result must be received before
// c.awaitGoroutines is called.
func (c *Cmd) watchCtx(resultc chan<- ctxResult) {
select {
case resultc <- ctxResult{}:
return
case <-c.ctx.Done():
}
var err error
if c.Cancel != nil {
if interruptErr := c.Cancel(); interruptErr == nil {
// We appear to have successfully interrupted the command, so any
// program behavior from this point may be due to ctx even if the
// command exits with code 0.
err = c.ctx.Err()
} else if errors.Is(interruptErr, os.ErrProcessDone) {
// The process already finished: we just didn't notice it yet.
// (Perhaps c.Wait hadn't been called, or perhaps it happened to race with
// c.ctx being canceled.) Don't inject a needless error.
} else {
err = wrappedError{
prefix: "exec: canceling Cmd",
err: interruptErr,
}
}
}
if c.WaitDelay == 0 {
resultc <- ctxResult{err: err}
return
}
timer := time.NewTimer(c.WaitDelay)
select {
case resultc <- ctxResult{err: err, timer: timer}:
// c.Process.Wait returned and we've handed the timer off to c.Wait.
// It will take care of goroutine shutdown from here.
return
case <-timer.C:
}
killed := false
if killErr := c.Process.Kill(); killErr == nil {
// We appear to have killed the process. c.Process.Wait should return a
// non-nil error to c.Wait unless the Kill signal races with a successful
// exit, and if that does happen we shouldn't report a spurious error,
// so don't set err to anything here.
killed = true
} else if !errors.Is(killErr, os.ErrProcessDone) {
err = wrappedError{
prefix: "exec: killing Cmd",
err: killErr,
}
}
if c.goroutineErr != nil {
select {
case goroutineErr := <-c.goroutineErr:
// Forward goroutineErr only if we don't have reason to believe it was
// caused by a call to Cancel or Kill above.
if err == nil && !killed {
err = goroutineErr
}
default:
// Close the child process's I/O pipes, in case it abandoned some
// subprocess that inherited them and is still holding them open
// (see https://go.dev/issue/23019).
//
// We close the goroutine pipes only after we have sent any signals we're
// going to send to the process (via Signal or Kill above): if we send
// SIGKILL to the process, we would prefer for it to die of SIGKILL, not
// SIGPIPE. (However, this may still cause any orphaned subprocesses to
// terminate with SIGPIPE.)
closeDescriptors(c.parentIOPipes)
// Wait for the copying goroutines to finish, but report ErrWaitDelay for
// the error: any other error here could result from closing the pipes.
_ = <-c.goroutineErr
if err == nil {
err = ErrWaitDelay
}
}
// Since we have already received the only result from c.goroutineErr,
// set it to nil to prevent awaitGoroutines from blocking on it.
c.goroutineErr = nil
}
resultc <- ctxResult{err: err}
}
// An ExitError reports an unsuccessful exit by a command.
type ExitError struct {
*os.ProcessState
// Stderr holds a subset of the standard error output from the
// Cmd.Output method if standard error was not otherwise being
// collected.
//
// If the error output is long, Stderr may contain only a prefix
// and suffix of the output, with the middle replaced with
// text about the number of omitted bytes.
//
// Stderr is provided for debugging, for inclusion in error messages.
// Users with other needs should redirect Cmd.Stderr as needed.
Stderr []byte
}
func (e *ExitError) Error() string {
return e.ProcessState.String()
}
// Wait waits for the command to exit and waits for any copying to
// stdin or copying from stdout or stderr to complete.
//
// The command must have been started by Start.
//
// The returned error is nil if the command runs, has no problems
// copying stdin, stdout, and stderr, and exits with a zero exit
// status.
//
// If the command fails to run or doesn't complete successfully, the
// error is of type *ExitError. Other error types may be
// returned for I/O problems.
//
// If any of c.Stdin, c.Stdout or c.Stderr are not an *os.File, Wait also waits
// for the respective I/O loop copying to or from the process to complete.
//
// Wait releases any resources associated with the Cmd.
func (c *Cmd) Wait() error {
if c.Process == nil {
return errors.New("exec: not started")
}
if c.ProcessState != nil {
return errors.New("exec: Wait was already called")
}
state, err := c.Process.Wait()
if err == nil && !state.Success() {
err = &ExitError{ProcessState: state}
}
c.ProcessState = state
var timer *time.Timer
if c.ctxResult != nil {
watch := <-c.ctxResult
timer = watch.timer
// If c.Process.Wait returned an error, prefer that.
// Otherwise, report any error from the watchCtx goroutine,
// such as a Context cancellation or a WaitDelay overrun.
if err == nil && watch.err != nil {
err = watch.err
}
}
if goroutineErr := c.awaitGoroutines(timer); err == nil {
// Report an error from the copying goroutines only if the program otherwise
// exited normally on its own. Otherwise, the copying error may be due to the
// abnormal termination.
err = goroutineErr
}
closeDescriptors(c.parentIOPipes)
c.parentIOPipes = nil
return err
}
// awaitGoroutines waits for the results of the goroutines copying data to or
// from the command's I/O pipes.
//
// If c.WaitDelay elapses before the goroutines complete, awaitGoroutines
// forcibly closes their pipes and returns ErrWaitDelay.
//
// If timer is non-nil, it must send to timer.C at the end of c.WaitDelay.
func (c *Cmd) awaitGoroutines(timer *time.Timer) error {
defer func() {
if timer != nil {
timer.Stop()
}
c.goroutineErr = nil
}()
if c.goroutineErr == nil {
return nil // No running goroutines to await.
}
if timer == nil {
if c.WaitDelay == 0 {
return <-c.goroutineErr
}
select {
case err := <-c.goroutineErr:
// Avoid the overhead of starting a timer.
return err
default:
}
// No existing timer was started: either there is no Context associated with
// the command, or c.Process.Wait completed before the Context was done.
timer = time.NewTimer(c.WaitDelay)
}
select {
case <-timer.C:
closeDescriptors(c.parentIOPipes)
// Wait for the copying goroutines to finish, but ignore any error
// (since it was probably caused by closing the pipes).
_ = <-c.goroutineErr
return ErrWaitDelay
case err := <-c.goroutineErr:
return err
}
}
// Output runs the command and returns its standard output.
// Any returned error will usually be of type *ExitError.
// If c.Stderr was nil, Output populates ExitError.Stderr.
func (c *Cmd) Output() ([]byte, error) {
if c.Stdout != nil {
return nil, errors.New("exec: Stdout already set")
}
var stdout bytes.Buffer
c.Stdout = &stdout
captureErr := c.Stderr == nil
if captureErr {
c.Stderr = &prefixSuffixSaver{N: 32 << 10}
}
err := c.Run()
if err != nil && captureErr {
if ee, ok := err.(*ExitError); ok {
ee.Stderr = c.Stderr.(*prefixSuffixSaver).Bytes()
}
}
return stdout.Bytes(), err
}
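// For example, a caller might recover captured diagnostics on failure
// (a minimal sketch; the command is a placeholder):
//
//	out, err := exec.Command("git", "status").Output()
//	if err != nil {
//		var ee *exec.ExitError
//		if errors.As(err, &ee) {
//			// Stderr was captured because c.Stderr was nil.
//			log.Printf("command failed: %v; stderr: %s", err, ee.Stderr)
//		}
//	}
//	fmt.Printf("%s", out)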
// CombinedOutput runs the command and returns its combined standard
// output and standard error.
func (c *Cmd) CombinedOutput() ([]byte, error) {
if c.Stdout != nil {
return nil, errors.New("exec: Stdout already set")
}
if c.Stderr != nil {
return nil, errors.New("exec: Stderr already set")
}
var b bytes.Buffer
c.Stdout = &b
c.Stderr = &b
err := c.Run()
return b.Bytes(), err
}
// StdinPipe returns a pipe that will be connected to the command's
// standard input when the command starts.
// The pipe will be closed automatically after Wait sees the command exit.
// A caller need only call Close to force the pipe to close sooner.
// For example, if the command being run will not exit until standard input
// is closed, the caller must close the pipe.
func (c *Cmd) StdinPipe() (io.WriteCloser, error) {
if c.Stdin != nil {
return nil, errors.New("exec: Stdin already set")
}
if c.Process != nil {
return nil, errors.New("exec: StdinPipe after process started")
}
pr, pw, err := os.Pipe()
if err != nil {
return nil, err
}
c.Stdin = pr
c.childIOFiles = append(c.childIOFiles, pr)
c.parentIOPipes = append(c.parentIOPipes, pw)
return pw, nil
}
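// Idiomatic usage resembles this sketch ("cat" is a placeholder command):
//
//	cmd := exec.Command("cat")
//	stdin, err := cmd.StdinPipe()
//	if err != nil {
//		log.Fatal(err)
//	}
//	go func() {
//		defer stdin.Close()
//		io.WriteString(stdin, "values written to stdin are passed to cmd's standard input")
//	}()
//	out, err := cmd.CombinedOutput()
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Printf("%s\n", out)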
// StdoutPipe returns a pipe that will be connected to the command's
// standard output when the command starts.
//
// Wait will close the pipe after seeing the command exit, so most callers
// need not close the pipe themselves. It is thus incorrect to call Wait
// before all reads from the pipe have completed.
// For the same reason, it is incorrect to call Run when using StdoutPipe.
// See the example for idiomatic usage.
func (c *Cmd) StdoutPipe() (io.ReadCloser, error) {
if c.Stdout != nil {
return nil, errors.New("exec: Stdout already set")
}
if c.Process != nil {
return nil, errors.New("exec: StdoutPipe after process started")
}
pr, pw, err := os.Pipe()
if err != nil {
return nil, err
}
c.Stdout = pw
c.childIOFiles = append(c.childIOFiles, pw)
c.parentIOPipes = append(c.parentIOPipes, pr)
return pr, nil
}
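// The pattern referenced above: start the command, finish all reads from
// the pipe, and only then call Wait (a sketch; the command is a placeholder):
//
//	cmd := exec.Command("echo", "hello")
//	stdout, err := cmd.StdoutPipe()
//	if err != nil {
//		log.Fatal(err)
//	}
//	if err := cmd.Start(); err != nil {
//		log.Fatal(err)
//	}
//	out, err := io.ReadAll(stdout) // drain the pipe before calling Wait
//	if err != nil {
//		log.Fatal(err)
//	}
//	if err := cmd.Wait(); err != nil {
//		log.Fatal(err)
//	}
//	fmt.Printf("%s", out)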
// StderrPipe returns a pipe that will be connected to the command's
// standard error when the command starts.
//
// Wait will close the pipe after seeing the command exit, so most callers
// need not close the pipe themselves. It is thus incorrect to call Wait
// before all reads from the pipe have completed.
// For the same reason, it is incorrect to use Run when using StderrPipe.
// See the StdoutPipe example for idiomatic usage.
func (c *Cmd) StderrPipe() (io.ReadCloser, error) {
if c.Stderr != nil {
return nil, errors.New("exec: Stderr already set")
}
if c.Process != nil {
return nil, errors.New("exec: StderrPipe after process started")
}
pr, pw, err := os.Pipe()
if err != nil {
return nil, err
}
c.Stderr = pw
c.childIOFiles = append(c.childIOFiles, pw)
c.parentIOPipes = append(c.parentIOPipes, pr)
return pr, nil
}
// prefixSuffixSaver is an io.Writer which retains the first N bytes
// and the last N bytes written to it. The Bytes() method reconstructs
// it with a pretty error message.
type prefixSuffixSaver struct {
N int // max size of prefix or suffix
prefix []byte
suffix []byte // ring buffer once len(suffix) == N
suffixOff int // offset to write into suffix
skipped int64
// TODO(bradfitz): we could keep one large []byte and use part of it for
// the prefix, reserve space for the '... Omitting N bytes ...' message,
// then the ring buffer suffix, and just rearrange the ring buffer
// suffix when Bytes() is called, but it doesn't seem worth it for
// now just for error messages. It's only ~64KB anyway.
}
func (w *prefixSuffixSaver) Write(p []byte) (n int, err error) {
lenp := len(p)
p = w.fill(&w.prefix, p)
// Only keep the last w.N bytes of suffix data.
if overage := len(p) - w.N; overage > 0 {
p = p[overage:]
w.skipped += int64(overage)
}
p = w.fill(&w.suffix, p)
// w.suffix is full now if p is non-empty. Overwrite it in a circle.
for len(p) > 0 { // 0, 1, or 2 iterations.
n := copy(w.suffix[w.suffixOff:], p)
p = p[n:]
w.skipped += int64(n)
w.suffixOff += n
if w.suffixOff == w.N {
w.suffixOff = 0
}
}
return lenp, nil
}
// fill appends up to len(p) bytes of p to *dst, such that *dst does not
// grow larger than w.N. It returns the un-appended suffix of p.
func (w *prefixSuffixSaver) fill(dst *[]byte, p []byte) (pRemain []byte) {
if remain := w.N - len(*dst); remain > 0 {
add := minInt(len(p), remain)
*dst = append(*dst, p[:add]...)
p = p[add:]
}
return p
}
func (w *prefixSuffixSaver) Bytes() []byte {
if w.suffix == nil {
return w.prefix
}
if w.skipped == 0 {
return append(w.prefix, w.suffix...)
}
var buf bytes.Buffer
buf.Grow(len(w.prefix) + len(w.suffix) + 50)
buf.Write(w.prefix)
buf.WriteString("\n... omitting ")
buf.WriteString(strconv.FormatInt(w.skipped, 10))
buf.WriteString(" bytes ...\n")
buf.Write(w.suffix[w.suffixOff:])
buf.Write(w.suffix[:w.suffixOff])
return buf.Bytes()
}
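// To illustrate, writing more than 2*N bytes keeps only the first N and the
// last N (hypothetical values; N is tiny here purely for illustration):
//
//	w := &prefixSuffixSaver{N: 4}
//	io.WriteString(w, "abcdefghijklmnop") // 16 bytes with N = 4
//	// w.Bytes() yields "abcd\n... omitting 8 bytes ...\nmnop".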
func minInt(a, b int) int {
if a < b {
return a
}
return b
}
// environ returns a best-effort copy of the environment in which the command
// would be run as it is currently configured. If an error occurs in computing
// the environment, it is returned alongside the best-effort copy.
func (c *Cmd) environ() ([]string, error) {
var err error
env := c.Env
if env == nil {
env, err = execenv.Default(c.SysProcAttr)
if err != nil {
env = os.Environ()
// Note that the non-nil err is preserved despite env being overridden.
}
if c.Dir != "" {
switch runtime.GOOS {
case "windows", "plan9":
// Windows and Plan 9 do not use the PWD variable, so we don't need to
// keep it accurate.
default:
// On POSIX platforms, PWD represents “an absolute pathname of the
// current working directory.” Since we are changing the working
// directory for the command, we should also update PWD to reflect that.
//
// Unfortunately, we didn't always do that, so (as proposed in
// https://go.dev/issue/50599) to avoid unintended collateral damage we
// only implicitly update PWD when Env is nil. That way, we're much
// less likely to override an intentional change to the variable.
if pwd, absErr := filepath.Abs(c.Dir); absErr == nil {
env = append(env, "PWD="+pwd)
} else if err == nil {
err = absErr
}
}
}
}
env, dedupErr := dedupEnv(env)
if err == nil {
err = dedupErr
}
return addCriticalEnv(env), err
}
// Environ returns a copy of the environment in which the command would be run
// as it is currently configured.
func (c *Cmd) Environ() []string {
// Intentionally ignore errors: environ returns a best-effort environment no matter what.
env, _ := c.environ()
return env
}
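// The expected way to add or override a variable while inheriting the rest
// of the configured environment is to append to Environ, as in this sketch
// (the command and variable are placeholders):
//
//	cmd := exec.Command("prog")
//	cmd.Env = append(cmd.Environ(), "FOO=bar")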
// dedupEnv returns a copy of env with any duplicates removed, in favor of
// later values.
// Items not of the normal environment "key=value" form are preserved unchanged.
// Except on Plan 9, items containing NUL characters are removed, and
// an error is returned along with the remaining values.
func dedupEnv(env []string) ([]string, error) {
return dedupEnvCase(runtime.GOOS == "windows", runtime.GOOS == "plan9", env)
}
// dedupEnvCase is dedupEnv with a case option for testing.
// If caseInsensitive is true, the case of keys is ignored.
// If nulOK is false, items containing NUL characters are removed and an
// error is reported.
func dedupEnvCase(caseInsensitive, nulOK bool, env []string) ([]string, error) {
// Construct the output in reverse order, to preserve the
// last occurrence of each key.
var err error
out := make([]string, 0, len(env))
saw := make(map[string]bool, len(env))
for n := len(env); n > 0; n-- {
kv := env[n-1]
// Reject NUL in environment variables to prevent security issues (#56284);
// except on Plan 9, which uses NUL as os.PathListSeparator (#56544).
if !nulOK && strings.IndexByte(kv, 0) != -1 {
err = errors.New("exec: environment variable contains NUL")
continue
}
i := strings.Index(kv, "=")
if i == 0 {
// We observe in practice keys with a single leading "=" on Windows.
// TODO(#49886): Should we consume only the first leading "=" as part
// of the key, or parse through arbitrarily many of them until a non-"="?
i = strings.Index(kv[1:], "=") + 1
}
if i < 0 {
if kv != "" {
// The entry is not of the form "key=value" (as it is required to be).
// Leave it as-is for now.
// TODO(#52436): should we strip or reject these bogus entries?
out = append(out, kv)
}
continue
}
k := kv[:i]
if caseInsensitive {
k = strings.ToLower(k)
}
if saw[k] {
continue
}
saw[k] = true
out = append(out, kv)
}
// Now reverse the slice to restore the original order.
for i := 0; i < len(out)/2; i++ {
j := len(out) - i - 1
out[i], out[j] = out[j], out[i]
}
return out, err
}
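// To illustrate the last-occurrence-wins behavior (hypothetical input):
//
//	out, _ := dedupEnvCase(false, false, []string{"A=1", "B=2", "A=3"})
//	// out is []string{"B=2", "A=3"}: the later A wins; order is otherwise kept.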
// addCriticalEnv adds any critical environment variables that are required
// (or at least almost always required) on the operating system.
// Currently this is only used for Windows.
func addCriticalEnv(env []string) []string {
if runtime.GOOS != "windows" {
return env
}
for _, kv := range env {
k, _, ok := strings.Cut(kv, "=")
if !ok {
continue
}
if strings.EqualFold(k, "SYSTEMROOT") {
// We already have it.
return env
}
}
return append(env, "SYSTEMROOT="+os.Getenv("SYSTEMROOT"))
}
// ErrDot indicates that a path lookup resolved to an executable
// in the current directory due to ‘.’ being in the path, either
// implicitly or explicitly. See the package documentation for details.
//
// Note that functions in this package do not return ErrDot directly.
// Code should use errors.Is(err, ErrDot), not err == ErrDot,
// to test whether a returned error err is due to this condition.
var ErrDot = errors.New("cannot run executable found relative to current directory")
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !plan9 && !windows
package exec
import (
"io/fs"
"syscall"
)
// skipStdinCopyError reports whether the provided stdin copy error
// should be ignored.
func skipStdinCopyError(err error) bool {
// Ignore EPIPE errors copying to stdin if the program
// completed successfully otherwise.
// See Issue 9173.
pe, ok := err.(*fs.PathError)
return ok &&
pe.Op == "write" && pe.Path == "|1" &&
pe.Err == syscall.EPIPE
}
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix
// Package fdtest provides test helpers for working with file descriptors across exec.
package fdtest
import (
"syscall"
)
// Exists returns true if fd is a valid file descriptor.
func Exists(fd uintptr) bool {
var s syscall.Stat_t
err := syscall.Fstat(int(fd), &s)
return err != syscall.EBADF
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix
package exec
import (
"errors"
"internal/syscall/unix"
"io/fs"
"os"
"path/filepath"
"strings"
"syscall"
)
// ErrNotFound is the error resulting if a path search failed to find an executable file.
var ErrNotFound = errors.New("executable file not found in $PATH")
func findExecutable(file string) error {
d, err := os.Stat(file)
if err != nil {
return err
}
m := d.Mode()
if m.IsDir() {
return syscall.EISDIR
}
err = unix.Eaccess(file, unix.X_OK)
// ENOSYS means Eaccess is not available or not implemented.
// EPERM can be returned by Linux containers employing seccomp.
// In both cases, fall back to checking the permission bits.
if err == nil || (err != syscall.ENOSYS && err != syscall.EPERM) {
return err
}
if m&0111 != 0 {
return nil
}
return fs.ErrPermission
}
// LookPath searches for an executable named file in the
// directories named by the PATH environment variable.
// If file contains a slash, it is tried directly and the PATH is not consulted.
// Otherwise, on success, the result is an absolute path.
//
// In older versions of Go, LookPath could return a path relative to the current directory.
// As of Go 1.19, LookPath will instead return that path along with an error satisfying
// errors.Is(err, ErrDot). See the package documentation for more details.
func LookPath(file string) (string, error) {
// NOTE(rsc): I wish we could use the Plan 9 behavior here
// (only bypass the path if file begins with / or ./ or ../)
// but that would not match all the Unix shells.
if strings.Contains(file, "/") {
err := findExecutable(file)
if err == nil {
return file, nil
}
return "", &Error{file, err}
}
path := os.Getenv("PATH")
for _, dir := range filepath.SplitList(path) {
if dir == "" {
// Unix shell semantics: path element "" means "."
dir = "."
}
path := filepath.Join(dir, file)
if err := findExecutable(path); err == nil {
if !filepath.IsAbs(path) {
if execerrdot.Value() != "0" {
return path, &Error{file, ErrDot}
}
execerrdot.IncNonDefault()
}
return path, nil
}
}
return "", &Error{file, ErrNotFound}
}
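// A typical call distinguishes ErrDot from a genuine lookup failure,
// as in this sketch ("prog" is a placeholder):
//
//	path, err := exec.LookPath("prog")
//	if errors.Is(err, exec.ErrDot) {
//		// The executable was resolved relative to the current directory.
//		// Opt in explicitly only if that is acceptable for this program.
//		err = nil
//	}
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println("using", path)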
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || (js && wasm) || windows
package os
import (
"internal/itoa"
"internal/syscall/execenv"
"runtime"
"syscall"
)
// The only signal values guaranteed to be present in the os package on all
// systems are os.Interrupt (send the process an interrupt) and os.Kill (force
// the process to exit). On Windows, sending os.Interrupt to a process with
// os.Process.Signal is not implemented; it will return an error instead of
// sending a signal.
var (
Interrupt Signal = syscall.SIGINT
Kill Signal = syscall.SIGKILL
)
func startProcess(name string, argv []string, attr *ProcAttr) (p *Process, err error) {
// If there is no SysProcAttr (ie. no Chroot or changed
// UID/GID), double-check existence of the directory we want
// to chdir into. We can make the error clearer this way.
if attr != nil && attr.Sys == nil && attr.Dir != "" {
if _, err := Stat(attr.Dir); err != nil {
pe := err.(*PathError)
pe.Op = "chdir"
return nil, pe
}
}
sysattr := &syscall.ProcAttr{
Dir: attr.Dir,
Env: attr.Env,
Sys: attr.Sys,
}
if sysattr.Env == nil {
sysattr.Env, err = execenv.Default(sysattr.Sys)
if err != nil {
return nil, err
}
}
sysattr.Files = make([]uintptr, 0, len(attr.Files))
for _, f := range attr.Files {
sysattr.Files = append(sysattr.Files, f.Fd())
}
pid, h, e := syscall.StartProcess(name, argv, sysattr)
// Make sure we don't run the finalizers of attr.Files.
runtime.KeepAlive(attr)
if e != nil {
return nil, &PathError{Op: "fork/exec", Path: name, Err: e}
}
return newProcess(pid, h), nil
}
func (p *Process) kill() error {
return p.Signal(Kill)
}
// ProcessState stores information about a process, as reported by Wait.
type ProcessState struct {
pid int // The process's id.
status syscall.WaitStatus // System-dependent status info.
rusage *syscall.Rusage
}
// Pid returns the process id of the exited process.
func (p *ProcessState) Pid() int {
return p.pid
}
func (p *ProcessState) exited() bool {
return p.status.Exited()
}
func (p *ProcessState) success() bool {
return p.status.ExitStatus() == 0
}
func (p *ProcessState) sys() any {
return p.status
}
func (p *ProcessState) sysUsage() any {
return p.rusage
}
func (p *ProcessState) String() string {
if p == nil {
return "<nil>"
}
status := p.Sys().(syscall.WaitStatus)
res := ""
switch {
case status.Exited():
code := status.ExitStatus()
if runtime.GOOS == "windows" && uint(code) >= 1<<16 { // windows uses large hex numbers
res = "exit status " + uitox(uint(code))
} else { // unix systems use small decimal integers
res = "exit status " + itoa.Itoa(code) // unix
}
case status.Signaled():
res = "signal: " + status.Signal().String()
case status.Stopped():
res = "stop signal: " + status.StopSignal().String()
if status.StopSignal() == syscall.SIGTRAP && status.TrapCause() != 0 {
res += " (trap " + itoa.Itoa(status.TrapCause()) + ")"
}
case status.Continued():
res = "continued"
}
if status.CoreDump() {
res += " (core dumped)"
}
return res
}
// ExitCode returns the exit code of the exited process, or -1
// if the process hasn't exited or was terminated by a signal.
func (p *ProcessState) ExitCode() int {
// return -1 if the process hasn't started.
if p == nil {
return -1
}
return p.status.ExitStatus()
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || (js && wasm)
package os
import (
"errors"
"runtime"
"syscall"
"time"
)
func (p *Process) wait() (ps *ProcessState, err error) {
if p.Pid == -1 {
return nil, syscall.EINVAL
}
// If we can block until Wait4 will succeed immediately, do so.
ready, err := p.blockUntilWaitable()
if err != nil {
return nil, err
}
if ready {
// Mark the process done now, before the call to Wait4,
// so that Process.signal will not send a signal.
p.setDone()
// Acquire a write lock on sigMu to wait for any
// active call to the signal method to complete.
p.sigMu.Lock()
p.sigMu.Unlock()
}
var (
status syscall.WaitStatus
rusage syscall.Rusage
pid1 int
e error
)
for {
pid1, e = syscall.Wait4(p.Pid, &status, 0, &rusage)
if e != syscall.EINTR {
break
}
}
if e != nil {
return nil, NewSyscallError("wait", e)
}
if pid1 != 0 {
p.setDone()
}
ps = &ProcessState{
pid: pid1,
status: status,
rusage: &rusage,
}
return ps, nil
}
func (p *Process) signal(sig Signal) error {
if p.Pid == -1 {
return errors.New("os: process already released")
}
if p.Pid == 0 {
return errors.New("os: process not initialized")
}
p.sigMu.RLock()
defer p.sigMu.RUnlock()
if p.done() {
return ErrProcessDone
}
s, ok := sig.(syscall.Signal)
if !ok {
return errors.New("os: unsupported signal type")
}
if e := syscall.Kill(p.Pid, s); e != nil {
if e == syscall.ESRCH {
return ErrProcessDone
}
return e
}
return nil
}
func (p *Process) release() error {
// NOOP for unix.
p.Pid = -1
// no need for a finalizer anymore
runtime.SetFinalizer(p, nil)
return nil
}
func findProcess(pid int) (p *Process, err error) {
// NOOP for unix.
return newProcess(pid, 0), nil
}
func (p *ProcessState) userTime() time.Duration {
return time.Duration(p.rusage.Utime.Nano()) * time.Nanosecond
}
func (p *ProcessState) systemTime() time.Duration {
return time.Duration(p.rusage.Stime.Nano()) * time.Nanosecond
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package os
// Executable returns the path name for the executable that started
// the current process. There is no guarantee that the path is still
// pointing to the correct executable. If a symlink was used to start
// the process, depending on the operating system, the result might
// be the symlink or the path it pointed to. If a stable result is
// needed, path/filepath.EvalSymlinks might help.
//
// Executable returns an absolute path unless an error occurred.
//
// The main use case is finding resources located relative to an
// executable.
func Executable() (string, error) {
return executable()
}
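// For example, to locate a data file shipped alongside the binary
// (a sketch; the file name is a placeholder):
//
//	exe, err := os.Executable()
//	if err != nil {
//		log.Fatal(err)
//	}
//	cfg := filepath.Join(filepath.Dir(exe), "config.json")
//	fmt.Println("config expected at", cfg)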
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build linux || netbsd || (js && wasm)
package os
import (
"errors"
"runtime"
)
func executable() (string, error) {
var procfn string
switch runtime.GOOS {
default:
return "", errors.New("Executable not implemented for " + runtime.GOOS)
case "linux", "android":
procfn = "/proc/self/exe"
case "netbsd":
procfn = "/proc/curproc/exe"
}
path, err := Readlink(procfn)
// When the executable has been deleted then Readlink returns a
// path appended with " (deleted)".
return stringsTrimSuffix(path, " (deleted)"), err
}
// stringsTrimSuffix is the same as strings.TrimSuffix.
func stringsTrimSuffix(s, suffix string) string {
if len(s) >= len(suffix) && s[len(s)-len(suffix):] == suffix {
return s[:len(s)-len(suffix)]
}
return s
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package os provides a platform-independent interface to operating system
// functionality. The design is Unix-like, although the error handling is
// Go-like; failing calls return values of type error rather than error numbers.
// Often, more information is available within the error. For example,
// if a call that takes a file name fails, such as Open or Stat, the error
// will include the failing file name when printed and will be of type
// *PathError, which may be unpacked for more information.
//
// The os interface is intended to be uniform across all operating systems.
// Features not generally available appear in the system-specific package syscall.
//
// Here is a simple example, opening a file and reading some of it.
//
// file, err := os.Open("file.go") // For read access.
// if err != nil {
// log.Fatal(err)
// }
//
// If the open fails, the error string will be self-explanatory, like
//
// open file.go: no such file or directory
//
// The file's data can then be read into a slice of bytes. Read and
// Write take their byte counts from the length of the argument slice.
//
// data := make([]byte, 100)
// count, err := file.Read(data)
// if err != nil {
// log.Fatal(err)
// }
// fmt.Printf("read %d bytes: %q\n", count, data[:count])
//
// Note: The maximum number of concurrent operations on a File may be limited by
// the OS or the system. The number should be high, but exceeding it may degrade
// performance or cause other issues.
package os
import (
"errors"
"internal/poll"
"internal/safefilepath"
"internal/testlog"
"io"
"io/fs"
"runtime"
"syscall"
"time"
"unsafe"
)
// Name returns the name of the file as presented to Open.
func (f *File) Name() string { return f.name }
// Stdin, Stdout, and Stderr are open Files pointing to the standard input,
// standard output, and standard error file descriptors.
//
// Note that the Go runtime writes to standard error for panics and crashes;
// closing Stderr may cause those messages to go elsewhere, perhaps
// to a file opened later.
var (
Stdin = NewFile(uintptr(syscall.Stdin), "/dev/stdin")
Stdout = NewFile(uintptr(syscall.Stdout), "/dev/stdout")
Stderr = NewFile(uintptr(syscall.Stderr), "/dev/stderr")
)
// Flags to OpenFile wrapping those of the underlying system. Not all
// flags may be implemented on a given system.
const (
// Exactly one of O_RDONLY, O_WRONLY, or O_RDWR must be specified.
O_RDONLY int = syscall.O_RDONLY // open the file read-only.
O_WRONLY int = syscall.O_WRONLY // open the file write-only.
O_RDWR int = syscall.O_RDWR // open the file read-write.
// The remaining values may be or'ed in to control behavior.
O_APPEND int = syscall.O_APPEND // append data to the file when writing.
O_CREATE int = syscall.O_CREAT // create a new file if none exists.
O_EXCL int = syscall.O_EXCL // used with O_CREATE, file must not exist.
O_SYNC int = syscall.O_SYNC // open for synchronous I/O.
O_TRUNC int = syscall.O_TRUNC // truncate regular writable file when opened.
)
// Seek whence values.
//
// Deprecated: Use io.SeekStart, io.SeekCurrent, and io.SeekEnd.
const (
SEEK_SET int = 0 // seek relative to the origin of the file
SEEK_CUR int = 1 // seek relative to the current offset
SEEK_END int = 2 // seek relative to the end
)
// LinkError records an error during a link or symlink or rename
// system call and the paths that caused it.
type LinkError struct {
Op string
Old string
New string
Err error
}
func (e *LinkError) Error() string {
return e.Op + " " + e.Old + " " + e.New + ": " + e.Err.Error()
}
func (e *LinkError) Unwrap() error {
return e.Err
}
// Read reads up to len(b) bytes from the File and stores them in b.
// It returns the number of bytes read and any error encountered.
// At end of file, Read returns 0, io.EOF.
func (f *File) Read(b []byte) (n int, err error) {
if err := f.checkValid("read"); err != nil {
return 0, err
}
n, e := f.read(b)
return n, f.wrapErr("read", e)
}
// ReadAt reads len(b) bytes from the File starting at byte offset off.
// It returns the number of bytes read and the error, if any.
// ReadAt always returns a non-nil error when n < len(b).
// At end of file, that error is io.EOF.
func (f *File) ReadAt(b []byte, off int64) (n int, err error) {
if err := f.checkValid("read"); err != nil {
return 0, err
}
if off < 0 {
return 0, &PathError{Op: "readat", Path: f.name, Err: errors.New("negative offset")}
}
for len(b) > 0 {
m, e := f.pread(b, off)
if e != nil {
err = f.wrapErr("read", e)
break
}
n += m
b = b[m:]
off += int64(m)
}
return
}
// ReadFrom implements io.ReaderFrom.
func (f *File) ReadFrom(r io.Reader) (n int64, err error) {
if err := f.checkValid("write"); err != nil {
return 0, err
}
n, handled, e := f.readFrom(r)
if !handled {
return genericReadFrom(f, r) // without wrapping
}
return n, f.wrapErr("write", e)
}
func genericReadFrom(f *File, r io.Reader) (int64, error) {
return io.Copy(onlyWriter{f}, r)
}
type onlyWriter struct {
io.Writer
}
// Write writes len(b) bytes from b to the File.
// It returns the number of bytes written and an error, if any.
// Write returns a non-nil error when n != len(b).
func (f *File) Write(b []byte) (n int, err error) {
if err := f.checkValid("write"); err != nil {
return 0, err
}
n, e := f.write(b)
if n < 0 {
n = 0
}
if n != len(b) {
err = io.ErrShortWrite
}
epipecheck(f, e)
if e != nil {
err = f.wrapErr("write", e)
}
return n, err
}
var errWriteAtInAppendMode = errors.New("os: invalid use of WriteAt on file opened with O_APPEND")
// WriteAt writes len(b) bytes to the File starting at byte offset off.
// It returns the number of bytes written and an error, if any.
// WriteAt returns a non-nil error when n != len(b).
//
// If file was opened with the O_APPEND flag, WriteAt returns an error.
func (f *File) WriteAt(b []byte, off int64) (n int, err error) {
if err := f.checkValid("write"); err != nil {
return 0, err
}
if f.appendMode {
return 0, errWriteAtInAppendMode
}
if off < 0 {
return 0, &PathError{Op: "writeat", Path: f.name, Err: errors.New("negative offset")}
}
for len(b) > 0 {
m, e := f.pwrite(b, off)
if e != nil {
err = f.wrapErr("write", e)
break
}
n += m
b = b[m:]
off += int64(m)
}
return
}
// Seek sets the offset for the next Read or Write on file to offset, interpreted
// according to whence: 0 means relative to the origin of the file, 1 means
// relative to the current offset, and 2 means relative to the end.
// It returns the new offset and an error, if any.
// The behavior of Seek on a file opened with O_APPEND is not specified.
func (f *File) Seek(offset int64, whence int) (ret int64, err error) {
if err := f.checkValid("seek"); err != nil {
return 0, err
}
r, e := f.seek(offset, whence)
if e == nil && f.dirinfo != nil && r != 0 {
e = syscall.EISDIR
}
if e != nil {
return 0, f.wrapErr("seek", e)
}
return r, nil
}
// WriteString is like Write, but writes the contents of string s rather than
// a slice of bytes.
func (f *File) WriteString(s string) (n int, err error) {
b := unsafe.Slice(unsafe.StringData(s), len(s))
return f.Write(b)
}
// Mkdir creates a new directory with the specified name and permission
// bits (before umask).
// If there is an error, it will be of type *PathError.
func Mkdir(name string, perm FileMode) error {
longName := fixLongPath(name)
e := ignoringEINTR(func() error {
return syscall.Mkdir(longName, syscallMode(perm))
})
if e != nil {
return &PathError{Op: "mkdir", Path: name, Err: e}
}
// mkdir(2) itself won't handle the sticky bit on *BSD and Solaris
if !supportsCreateWithStickyBit && perm&ModeSticky != 0 {
e = setStickyBit(name)
if e != nil {
Remove(name)
return e
}
}
return nil
}
// setStickyBit adds ModeSticky to the permission bits of path, non-atomically.
func setStickyBit(name string) error {
fi, err := Stat(name)
if err != nil {
return err
}
return Chmod(name, fi.Mode()|ModeSticky)
}
// Chdir changes the current working directory to the named directory.
// If there is an error, it will be of type *PathError.
func Chdir(dir string) error {
if e := syscall.Chdir(dir); e != nil {
testlog.Open(dir) // observe likely non-existent directory
return &PathError{Op: "chdir", Path: dir, Err: e}
}
if log := testlog.Logger(); log != nil {
wd, err := Getwd()
if err == nil {
log.Chdir(wd)
}
}
return nil
}
// Open opens the named file for reading. If successful, methods on
// the returned file can be used for reading; the associated file
// descriptor has mode O_RDONLY.
// If there is an error, it will be of type *PathError.
func Open(name string) (*File, error) {
return OpenFile(name, O_RDONLY, 0)
}
// Create creates or truncates the named file. If the file already exists,
// it is truncated. If the file does not exist, it is created with mode 0666
// (before umask). If successful, methods on the returned File can
// be used for I/O; the associated file descriptor has mode O_RDWR.
// If there is an error, it will be of type *PathError.
func Create(name string) (*File, error) {
return OpenFile(name, O_RDWR|O_CREATE|O_TRUNC, 0666)
}
// OpenFile is the generalized open call; most users will use Open
// or Create instead. It opens the named file with specified flag
// (O_RDONLY etc.). If the file does not exist, and the O_CREATE flag
// is passed, it is created with mode perm (before umask). If successful,
// methods on the returned File can be used for I/O.
// If there is an error, it will be of type *PathError.
func OpenFile(name string, flag int, perm FileMode) (*File, error) {
testlog.Open(name)
f, err := openFileNolog(name, flag, perm)
if err != nil {
return nil, err
}
f.appendMode = flag&O_APPEND != 0
return f, nil
}
// lstat is overridden in tests.
var lstat = Lstat
// Rename renames (moves) oldpath to newpath.
// If newpath already exists and is not a directory, Rename replaces it.
// OS-specific restrictions may apply when oldpath and newpath are in different directories.
// Even within the same directory, on non-Unix platforms Rename is not an atomic operation.
// If there is an error, it will be of type *LinkError.
func Rename(oldpath, newpath string) error {
return rename(oldpath, newpath)
}
// Many functions in package syscall return a count of -1 instead of 0.
// Using fixCount(call()) instead of call() corrects the count.
func fixCount(n int, err error) (int, error) {
if n < 0 {
n = 0
}
return n, err
}
// checkWrapErr is the test hook to enable checking unexpected wrapped errors of poll.ErrFileClosing.
// It is set to true in the export_test.go for tests (including fuzz tests).
var checkWrapErr = false
// wrapErr wraps an error that occurred during an operation on an open file.
// It passes io.EOF through unchanged, otherwise converts
// poll.ErrFileClosing to ErrClosed and wraps the error in a PathError.
func (f *File) wrapErr(op string, err error) error {
if err == nil || err == io.EOF {
return err
}
if err == poll.ErrFileClosing {
err = ErrClosed
} else if checkWrapErr && errors.Is(err, poll.ErrFileClosing) {
panic("unexpected error wrapping poll.ErrFileClosing: " + err.Error())
}
return &PathError{Op: op, Path: f.name, Err: err}
}
// TempDir returns the default directory to use for temporary files.
//
// On Unix systems, it returns $TMPDIR if non-empty, else /tmp.
// On Windows, it uses GetTempPath, returning the first non-empty
// value from %TMP%, %TEMP%, %USERPROFILE%, or the Windows directory.
// On Plan 9, it returns /tmp.
//
// The directory is neither guaranteed to exist nor have accessible
// permissions.
func TempDir() string {
return tempDir()
}
// UserCacheDir returns the default root directory to use for user-specific
// cached data. Users should create their own application-specific subdirectory
// within this one and use that.
//
// On Unix systems, it returns $XDG_CACHE_HOME as specified by
// https://specifications.freedesktop.org/basedir-spec/basedir-spec-latest.html if
// non-empty, else $HOME/.cache.
// On Darwin, it returns $HOME/Library/Caches.
// On Windows, it returns %LocalAppData%.
// On Plan 9, it returns $home/lib/cache.
//
// If the location cannot be determined (for example, $HOME is not defined),
// then it will return an error.
func UserCacheDir() (string, error) {
var dir string
switch runtime.GOOS {
case "windows":
dir = Getenv("LocalAppData")
if dir == "" {
return "", errors.New("%LocalAppData% is not defined")
}
case "darwin", "ios":
dir = Getenv("HOME")
if dir == "" {
return "", errors.New("$HOME is not defined")
}
dir += "/Library/Caches"
case "plan9":
dir = Getenv("home")
if dir == "" {
return "", errors.New("$home is not defined")
}
dir += "/lib/cache"
default: // Unix
dir = Getenv("XDG_CACHE_HOME")
if dir == "" {
dir = Getenv("HOME")
if dir == "" {
return "", errors.New("neither $XDG_CACHE_HOME nor $HOME are defined")
}
dir += "/.cache"
}
}
return dir, nil
}
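// Callers are expected to create their own subdirectory, for example
// (a sketch; the application name and mode are placeholders):
//
//	dir, err := os.UserCacheDir()
//	if err != nil {
//		log.Fatal(err)
//	}
//	appDir := filepath.Join(dir, "myapp")
//	if err := os.MkdirAll(appDir, 0700); err != nil {
//		log.Fatal(err)
//	}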
// UserConfigDir returns the default root directory to use for user-specific
// configuration data. Users should create their own application-specific
// subdirectory within this one and use that.
//
// On Unix systems, it returns $XDG_CONFIG_HOME as specified by
// https://specifications.freedesktop.org/basedir-spec/basedir-spec-latest.html if
// non-empty, else $HOME/.config.
// On Darwin, it returns $HOME/Library/Application Support.
// On Windows, it returns %AppData%.
// On Plan 9, it returns $home/lib.
//
// If the location cannot be determined (for example, $HOME is not defined),
// then it will return an error.
func UserConfigDir() (string, error) {
var dir string
switch runtime.GOOS {
case "windows":
dir = Getenv("AppData")
if dir == "" {
return "", errors.New("%AppData% is not defined")
}
case "darwin", "ios":
dir = Getenv("HOME")
if dir == "" {
return "", errors.New("$HOME is not defined")
}
dir += "/Library/Application Support"
case "plan9":
dir = Getenv("home")
if dir == "" {
return "", errors.New("$home is not defined")
}
dir += "/lib"
default: // Unix
dir = Getenv("XDG_CONFIG_HOME")
if dir == "" {
dir = Getenv("HOME")
if dir == "" {
return "", errors.New("neither $XDG_CONFIG_HOME nor $HOME are defined")
}
dir += "/.config"
}
}
return dir, nil
}
// UserHomeDir returns the current user's home directory.
//
// On Unix, including macOS, it returns the $HOME environment variable.
// On Windows, it returns %USERPROFILE%.
// On Plan 9, it returns the $home environment variable.
//
// If the expected variable is not set in the environment, UserHomeDir
// returns either a platform-specific default value or a non-nil error.
func UserHomeDir() (string, error) {
env, enverr := "HOME", "$HOME"
switch runtime.GOOS {
case "windows":
env, enverr = "USERPROFILE", "%userprofile%"
case "plan9":
env, enverr = "home", "$home"
}
if v := Getenv(env); v != "" {
return v, nil
}
// On some geese the home directory is not always defined.
switch runtime.GOOS {
case "android":
return "/sdcard", nil
case "ios":
return "/", nil
}
return "", errors.New(enverr + " is not defined")
}
// Chmod changes the mode of the named file to mode.
// If the file is a symbolic link, it changes the mode of the link's target.
// If there is an error, it will be of type *PathError.
//
// A different subset of the mode bits are used, depending on the
// operating system.
//
// On Unix, the mode's permission bits, ModeSetuid, ModeSetgid, and
// ModeSticky are used.
//
// On Windows, only the 0200 bit (owner writable) of mode is used; it
// controls whether the file's read-only attribute is set or cleared.
// The other bits are currently unused. For compatibility with Go 1.12
// and earlier, use a non-zero mode. Use mode 0400 for a read-only
// file and 0600 for a readable+writable file.
//
// On Plan 9, the mode's permission bits, ModeAppend, ModeExclusive,
// and ModeTemporary are used.
func Chmod(name string, mode FileMode) error { return chmod(name, mode) }
// Chmod changes the mode of the file to mode.
// If there is an error, it will be of type *PathError.
func (f *File) Chmod(mode FileMode) error { return f.chmod(mode) }
// SetDeadline sets the read and write deadlines for a File.
// It is equivalent to calling both SetReadDeadline and SetWriteDeadline.
//
// Only some kinds of files support setting a deadline. Calls to SetDeadline
// for files that do not support deadlines will return ErrNoDeadline.
// On most systems ordinary files do not support deadlines, but pipes do.
//
// A deadline is an absolute time after which I/O operations fail with an
// error instead of blocking. The deadline applies to all future and pending
// I/O, not just the immediately following call to Read or Write.
// After a deadline has been exceeded, the connection can be refreshed
// by setting a deadline in the future.
//
// If the deadline is exceeded a call to Read or Write or to other I/O
// methods will return an error that wraps ErrDeadlineExceeded.
// This can be tested using errors.Is(err, os.ErrDeadlineExceeded).
// That error implements the Timeout method, and calling the Timeout
// method will return true, but there are other possible errors for which
// the Timeout will return true even if the deadline has not been exceeded.
//
// An idle timeout can be implemented by repeatedly extending
// the deadline after successful Read or Write calls.
//
// A zero value for t means I/O operations will not time out.
func (f *File) SetDeadline(t time.Time) error {
return f.setDeadline(t)
}
// SetReadDeadline sets the deadline for future Read calls and any
// currently-blocked Read call.
// A zero value for t means Read will not time out.
// Not all files support setting deadlines; see SetDeadline.
func (f *File) SetReadDeadline(t time.Time) error {
return f.setReadDeadline(t)
}
// SetWriteDeadline sets the deadline for any future Write calls and any
// currently-blocked Write call.
// Even if Write times out, it may return n > 0, indicating that
// some of the data was successfully written.
// A zero value for t means Write will not time out.
// Not all files support setting deadlines; see SetDeadline.
func (f *File) SetWriteDeadline(t time.Time) error {
return f.setWriteDeadline(t)
}
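// For example, a read from a pipe can be bounded in time, as in this sketch
// (the one-second budget is arbitrary):
//
//	r, w, err := os.Pipe()
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer r.Close()
//	defer w.Close()
//	r.SetReadDeadline(time.Now().Add(1 * time.Second))
//	buf := make([]byte, 1)
//	if _, err := r.Read(buf); errors.Is(err, os.ErrDeadlineExceeded) {
//		fmt.Println("read timed out")
//	}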
// SyscallConn returns a raw file.
// This implements the syscall.Conn interface.
func (f *File) SyscallConn() (syscall.RawConn, error) {
if err := f.checkValid("SyscallConn"); err != nil {
return nil, err
}
return newRawConn(f)
}
// DirFS returns a file system (an fs.FS) for the tree of files rooted at the directory dir.
//
// Note that DirFS("/prefix") only guarantees that the Open calls it makes to the
// operating system will begin with "/prefix": DirFS("/prefix").Open("file") is the
// same as os.Open("/prefix/file"). So if /prefix/file is a symbolic link pointing outside
// the /prefix tree, then using DirFS does not stop the access any more than using
// os.Open does. Additionally, the root of the fs.FS returned for a relative path,
// DirFS("prefix"), will be affected by later calls to Chdir. DirFS is therefore not
// a general substitute for a chroot-style security mechanism when the directory tree
// contains arbitrary content.
//
// The directory dir must not be "".
//
// The result implements fs.StatFS.
func DirFS(dir string) fs.FS {
return dirFS(dir)
}
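// For example, to read a file beneath a fixed root through the fs.FS
// interface (a sketch; the paths are placeholders):
//
//	fsys := os.DirFS("/prefix")
//	data, err := fs.ReadFile(fsys, "file") // reads /prefix/file
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Printf("%s", data)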
// containsAny reports whether any bytes in chars are within s.
func containsAny(s, chars string) bool {
for i := 0; i < len(s); i++ {
for j := 0; j < len(chars); j++ {
if s[i] == chars[j] {
return true
}
}
}
return false
}
type dirFS string
func (dir dirFS) Open(name string) (fs.File, error) {
fullname, err := dir.join(name)
if err != nil {
return nil, &PathError{Op: "stat", Path: name, Err: err}
}
f, err := Open(fullname)
if err != nil {
// DirFS takes a string appropriate for GOOS,
// while the name argument here is always slash separated.
// dir.join will have mixed the two; undo that for
// error reporting.
err.(*PathError).Path = name
return nil, err
}
return f, nil
}
func (dir dirFS) Stat(name string) (fs.FileInfo, error) {
fullname, err := dir.join(name)
if err != nil {
return nil, &PathError{Op: "stat", Path: name, Err: err}
}
f, err := Stat(fullname)
if err != nil {
// See comment in dirFS.Open.
err.(*PathError).Path = name
return nil, err
}
return f, nil
}
// join returns the path for name in dir.
func (dir dirFS) join(name string) (string, error) {
if dir == "" {
return "", errors.New("os: DirFS with empty root")
}
if !fs.ValidPath(name) {
return "", ErrInvalid
}
name, err := safefilepath.FromFS(name)
if err != nil {
return "", ErrInvalid
}
if IsPathSeparator(dir[len(dir)-1]) {
return string(dir) + name, nil
}
return string(dir) + string(PathSeparator) + name, nil
}
// ReadFile reads the named file and returns the contents.
// A successful call returns err == nil, not err == EOF.
// Because ReadFile reads the whole file, it does not treat an EOF from Read
// as an error to be reported.
func ReadFile(name string) ([]byte, error) {
f, err := Open(name)
if err != nil {
return nil, err
}
defer f.Close()
var size int
if info, err := f.Stat(); err == nil {
size64 := info.Size()
if int64(int(size64)) == size64 {
size = int(size64)
}
}
size++ // one byte for final read at EOF
// If a file claims a small size, read at least 512 bytes.
// In particular, files in Linux's /proc claim size 0 but
// then do not work right if read in small pieces,
// so an initial read of 1 byte would not work correctly.
if size < 512 {
size = 512
}
data := make([]byte, 0, size)
for {
if len(data) >= cap(data) {
d := append(data[:cap(data)], 0)
data = d[:len(data)]
}
n, err := f.Read(data[len(data):cap(data)])
data = data[:len(data)+n]
if err != nil {
if err == io.EOF {
err = nil
}
return data, err
}
}
}
// WriteFile writes data to the named file, creating it if necessary.
// If the file does not exist, WriteFile creates it with permissions perm (before umask);
// otherwise WriteFile truncates it before writing, without changing permissions.
// Since WriteFile requires multiple system calls to complete, a failure mid-operation
// can leave the file in a partially written state.
func WriteFile(name string, data []byte, perm FileMode) error {
f, err := OpenFile(name, O_WRONLY|O_CREATE|O_TRUNC, perm)
if err != nil {
return err
}
_, err = f.Write(data)
if err1 := f.Close(); err1 != nil && err == nil {
err = err1
}
return err
}
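// A typical round trip (a sketch; the name, contents, and mode are
// placeholders):
//
//	if err := os.WriteFile("notes.txt", []byte("hello\n"), 0644); err != nil {
//		log.Fatal(err)
//	}
//	data, err := os.ReadFile("notes.txt")
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Printf("%s", data)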
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || (js && wasm) || windows
package os
import (
"runtime"
"syscall"
"time"
)
func sigpipe() // implemented in package runtime
// Close closes the File, rendering it unusable for I/O.
// On files that support SetDeadline, any pending I/O operations will
// be canceled and return immediately with an ErrClosed error.
// Close will return an error if it has already been called.
func (f *File) Close() error {
if f == nil {
return ErrInvalid
}
return f.file.close()
}
// read reads up to len(b) bytes from the File.
// It returns the number of bytes read and an error, if any.
func (f *File) read(b []byte) (n int, err error) {
n, err = f.pfd.Read(b)
runtime.KeepAlive(f)
return n, err
}
// pread reads len(b) bytes from the File starting at byte offset off.
// It returns the number of bytes read and the error, if any.
// EOF is signaled by a zero count with err set to nil.
func (f *File) pread(b []byte, off int64) (n int, err error) {
n, err = f.pfd.Pread(b, off)
runtime.KeepAlive(f)
return n, err
}
// write writes len(b) bytes to the File.
// It returns the number of bytes written and an error, if any.
func (f *File) write(b []byte) (n int, err error) {
n, err = f.pfd.Write(b)
runtime.KeepAlive(f)
return n, err
}
// pwrite writes len(b) bytes to the File starting at byte offset off.
// It returns the number of bytes written and an error, if any.
func (f *File) pwrite(b []byte, off int64) (n int, err error) {
n, err = f.pfd.Pwrite(b, off)
runtime.KeepAlive(f)
return n, err
}
// syscallMode returns the syscall-specific mode bits from Go's portable mode bits.
func syscallMode(i FileMode) (o uint32) {
o |= uint32(i.Perm())
if i&ModeSetuid != 0 {
o |= syscall.S_ISUID
}
if i&ModeSetgid != 0 {
o |= syscall.S_ISGID
}
if i&ModeSticky != 0 {
o |= syscall.S_ISVTX
}
// No mapping for Go's ModeTemporary (plan9 only).
return
}
// See docs in file.go:Chmod.
func chmod(name string, mode FileMode) error {
longName := fixLongPath(name)
e := ignoringEINTR(func() error {
return syscall.Chmod(longName, syscallMode(mode))
})
if e != nil {
return &PathError{Op: "chmod", Path: name, Err: e}
}
return nil
}
// See docs in file.go:(*File).Chmod.
func (f *File) chmod(mode FileMode) error {
if err := f.checkValid("chmod"); err != nil {
return err
}
if e := f.pfd.Fchmod(syscallMode(mode)); e != nil {
return f.wrapErr("chmod", e)
}
return nil
}
// Chown changes the numeric uid and gid of the named file.
// If the file is a symbolic link, it changes the uid and gid of the link's target.
// A uid or gid of -1 means to not change that value.
// If there is an error, it will be of type *PathError.
//
// On Windows or Plan 9, Chown always returns the syscall.EWINDOWS or
// EPLAN9 error, wrapped in *PathError.
func Chown(name string, uid, gid int) error {
e := ignoringEINTR(func() error {
return syscall.Chown(name, uid, gid)
})
if e != nil {
return &PathError{Op: "chown", Path: name, Err: e}
}
return nil
}
// Lchown changes the numeric uid and gid of the named file.
// If the file is a symbolic link, it changes the uid and gid of the link itself.
// If there is an error, it will be of type *PathError.
//
// On Windows, it always returns the syscall.EWINDOWS error, wrapped
// in *PathError.
func Lchown(name string, uid, gid int) error {
e := ignoringEINTR(func() error {
return syscall.Lchown(name, uid, gid)
})
if e != nil {
return &PathError{Op: "lchown", Path: name, Err: e}
}
return nil
}
// Chown changes the numeric uid and gid of the named file.
// If there is an error, it will be of type *PathError.
//
// On Windows, it always returns the syscall.EWINDOWS error, wrapped
// in *PathError.
func (f *File) Chown(uid, gid int) error {
if err := f.checkValid("chown"); err != nil {
return err
}
if e := f.pfd.Fchown(uid, gid); e != nil {
return f.wrapErr("chown", e)
}
return nil
}
// Truncate changes the size of the file.
// It does not change the I/O offset.
// If there is an error, it will be of type *PathError.
func (f *File) Truncate(size int64) error {
if err := f.checkValid("truncate"); err != nil {
return err
}
if e := f.pfd.Ftruncate(size); e != nil {
return f.wrapErr("truncate", e)
}
return nil
}
// Sync commits the current contents of the file to stable storage.
// Typically, this means flushing the file system's in-memory copy
// of recently written data to disk.
func (f *File) Sync() error {
if err := f.checkValid("sync"); err != nil {
return err
}
if e := f.pfd.Fsync(); e != nil {
return f.wrapErr("sync", e)
}
return nil
}
// Chtimes changes the access and modification times of the named
// file, similar to the Unix utime() or utimes() functions.
//
// The underlying filesystem may truncate or round the values to a
// less precise time unit.
// If there is an error, it will be of type *PathError.
func Chtimes(name string, atime time.Time, mtime time.Time) error {
var utimes [2]syscall.Timespec
utimes[0] = syscall.NsecToTimespec(atime.UnixNano())
utimes[1] = syscall.NsecToTimespec(mtime.UnixNano())
if e := syscall.UtimesNano(fixLongPath(name), utimes[0:]); e != nil {
return &PathError{Op: "chtimes", Path: name, Err: e}
}
return nil
}
// Chdir changes the current working directory to the file,
// which must be a directory.
// If there is an error, it will be of type *PathError.
func (f *File) Chdir() error {
if err := f.checkValid("chdir"); err != nil {
return err
}
if e := f.pfd.Fchdir(); e != nil {
return f.wrapErr("chdir", e)
}
return nil
}
// setDeadline sets the read and write deadline.
func (f *File) setDeadline(t time.Time) error {
if err := f.checkValid("SetDeadline"); err != nil {
return err
}
return f.pfd.SetDeadline(t)
}
// setReadDeadline sets the read deadline.
func (f *File) setReadDeadline(t time.Time) error {
if err := f.checkValid("SetReadDeadline"); err != nil {
return err
}
return f.pfd.SetReadDeadline(t)
}
// setWriteDeadline sets the write deadline.
func (f *File) setWriteDeadline(t time.Time) error {
if err := f.checkValid("SetWriteDeadline"); err != nil {
return err
}
return f.pfd.SetWriteDeadline(t)
}
// checkValid checks whether f is valid for use.
// If not, it returns an appropriate error, perhaps incorporating the operation name op.
func (f *File) checkValid(op string) error {
if f == nil {
return ErrInvalid
}
return nil
}
// ignoringEINTR makes a function call and repeats it if it returns an
// EINTR error. This appears to be required even though we install all
// signal handlers with SA_RESTART: see #22838, #38033, #38836, #40846.
// Also #20400 and #36644 are issues in which a signal handler is
// installed without setting SA_RESTART. None of these are the common case,
// but there are enough of them that it seems that we can't avoid
// an EINTR loop.
func ignoringEINTR(fn func() error) error {
for {
err := fn()
if err != syscall.EINTR {
return err
}
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || (js && wasm)
package os
import (
"internal/poll"
"internal/syscall/unix"
"runtime"
"syscall"
)
// fixLongPath is a noop on non-Windows platforms.
func fixLongPath(path string) string {
return path
}
func rename(oldname, newname string) error {
fi, err := Lstat(newname)
if err == nil && fi.IsDir() {
// There are two independent errors this function can return:
// one for a bad oldname, and one for a bad newname.
// At this point we've determined the newname is bad.
// But just in case oldname is also bad, prioritize returning
// the oldname error because that's what we did historically.
// However, if the old name and new name are not the same, yet
// they refer to the same file, it implies a case-only
// rename on a case-insensitive filesystem, which is ok.
if ofi, err := Lstat(oldname); err != nil {
if pe, ok := err.(*PathError); ok {
err = pe.Err
}
return &LinkError{"rename", oldname, newname, err}
} else if newname == oldname || !SameFile(fi, ofi) {
return &LinkError{"rename", oldname, newname, syscall.EEXIST}
}
}
err = ignoringEINTR(func() error {
return syscall.Rename(oldname, newname)
})
if err != nil {
return &LinkError{"rename", oldname, newname, err}
}
return nil
}
// file is the real representation of *File.
// The extra level of indirection ensures that no clients of os
// can overwrite this data, which could cause the finalizer
// to close the wrong file descriptor.
type file struct {
pfd poll.FD
name string
dirinfo *dirInfo // nil unless directory being read
nonblock bool // whether we set nonblocking mode
stdoutOrErr bool // whether this is stdout or stderr
appendMode bool // whether file is opened for appending
}
// Fd returns the integer Unix file descriptor referencing the open file.
// If f is closed, the file descriptor becomes invalid.
// If f is garbage collected, a finalizer may close the file descriptor,
// making it invalid; see runtime.SetFinalizer for more information on when
// a finalizer might be run. On Unix systems this will cause the SetDeadline
// methods to stop working.
// Because file descriptors can be reused, the returned file descriptor may
// only be closed through the Close method of f, or by its finalizer during
// garbage collection. Otherwise, during garbage collection the finalizer
// may close an unrelated file descriptor with the same (reused) number.
//
// As an alternative, see the f.SyscallConn method.
func (f *File) Fd() uintptr {
if f == nil {
return ^(uintptr(0))
}
// If we put the file descriptor into nonblocking mode,
// then set it to blocking mode before we return it,
// because historically we have always returned a descriptor
// opened in blocking mode. The File will continue to work,
// but any blocking operation will tie up a thread.
if f.nonblock {
f.pfd.SetBlocking()
}
return uintptr(f.pfd.Sysfd)
}
// NewFile returns a new File with the given file descriptor and
// name. The returned value will be nil if fd is not a valid file
// descriptor. On Unix systems, if the file descriptor is in
// non-blocking mode, NewFile will attempt to return a pollable File
// (one for which the SetDeadline methods work).
//
// After passing it to NewFile, fd may become invalid under the same
// conditions described in the comments of the Fd method, and the same
// constraints apply.
func NewFile(fd uintptr, name string) *File {
kind := kindNewFile
if nb, err := unix.IsNonblock(int(fd)); err == nil && nb {
kind = kindNonBlock
}
return newFile(fd, name, kind)
}
// newFileKind describes the kind of file to newFile.
type newFileKind int
const (
// kindNewFile means that the descriptor was passed to us via NewFile.
kindNewFile newFileKind = iota
// kindOpenFile means that the descriptor was opened using
// Open, Create, or OpenFile.
kindOpenFile
// kindPipe means that the descriptor was opened using Pipe.
kindPipe
// kindNonBlock means that the descriptor was passed to us via NewFile,
// and the descriptor is already in non-blocking mode.
kindNonBlock
// kindNoPoll means that we should not put the descriptor into
// non-blocking mode, because we know it is not a pipe or FIFO.
// Used by openFdAt for directories.
kindNoPoll
)
// newFile is like NewFile, but if called from OpenFile or Pipe
// (as passed in the kind parameter) it tries to add the file to
// the runtime poller.
func newFile(fd uintptr, name string, kind newFileKind) *File {
fdi := int(fd)
if fdi < 0 {
return nil
}
f := &File{&file{
pfd: poll.FD{
Sysfd: fdi,
IsStream: true,
ZeroReadIsEOF: true,
},
name: name,
stdoutOrErr: fdi == 1 || fdi == 2,
}}
pollable := kind == kindOpenFile || kind == kindPipe || kind == kindNonBlock
// If the caller passed a non-blocking filedes (kindNonBlock),
// we assume they know what they are doing so we allow it to be
// used with kqueue.
if kind == kindOpenFile {
switch runtime.GOOS {
case "darwin", "ios", "dragonfly", "freebsd", "netbsd", "openbsd":
var st syscall.Stat_t
err := ignoringEINTR(func() error {
return syscall.Fstat(fdi, &st)
})
typ := st.Mode & syscall.S_IFMT
// Don't try to use kqueue with regular files on *BSDs.
// On FreeBSD a regular file is always
// reported as ready for writing.
// On Dragonfly, NetBSD and OpenBSD the fd is signaled
// only once as ready (both read and write).
// Issue 19093.
// Also don't add directories to the netpoller.
if err == nil && (typ == syscall.S_IFREG || typ == syscall.S_IFDIR) {
pollable = false
}
// In addition to the behavior described above for regular files,
// on Darwin, kqueue does not work properly with fifos:
// closing the last writer does not cause a kqueue event
// for any readers. See issue #24164.
if (runtime.GOOS == "darwin" || runtime.GOOS == "ios") && typ == syscall.S_IFIFO {
pollable = false
}
}
}
clearNonBlock := false
if pollable {
if kind == kindNonBlock {
f.nonblock = true
} else if err := syscall.SetNonblock(fdi, true); err == nil {
f.nonblock = true
clearNonBlock = true
} else {
pollable = false
}
}
// An error here indicates a failure to register
// with the netpoll system. That can happen for
// a file descriptor that is not supported by
// epoll/kqueue; for example, disk files on
// Linux systems. We assume that any real error
// will show up in later I/O.
// We do restore the blocking behavior if it was set by us.
if pollErr := f.pfd.Init("file", pollable); pollErr != nil && clearNonBlock {
if err := syscall.SetNonblock(fdi, false); err == nil {
f.nonblock = false
}
}
runtime.SetFinalizer(f.file, (*file).close)
return f
}
// epipecheck raises SIGPIPE if we get an EPIPE error on standard
// output or standard error. See the SIGPIPE docs in os/signal, and
// issue 11845.
func epipecheck(file *File, e error) {
if e == syscall.EPIPE && file.stdoutOrErr {
sigpipe()
}
}
// DevNull is the name of the operating system's “null device.”
// On Unix-like systems, it is "/dev/null"; on Windows, "NUL".
const DevNull = "/dev/null"
// openFileNolog is the Unix implementation of OpenFile.
// Changes here should be reflected in openFdAt, if relevant.
func openFileNolog(name string, flag int, perm FileMode) (*File, error) {
setSticky := false
if !supportsCreateWithStickyBit && flag&O_CREATE != 0 && perm&ModeSticky != 0 {
if _, err := Stat(name); IsNotExist(err) {
setSticky = true
}
}
var r int
for {
var e error
r, e = syscall.Open(name, flag|syscall.O_CLOEXEC, syscallMode(perm))
if e == nil {
break
}
// We have to check EINTR here, per issues 11180 and 39237.
if e == syscall.EINTR {
continue
}
return nil, &PathError{Op: "open", Path: name, Err: e}
}
// open(2) itself won't handle the sticky bit on *BSD and Solaris
if setSticky {
setStickyBit(name)
}
// There's a race here with fork/exec, which we are
// content to live with. See ../syscall/exec_unix.go.
if !supportsCloseOnExec {
syscall.CloseOnExec(r)
}
return newFile(uintptr(r), name, kindOpenFile), nil
}
func (file *file) close() error {
if file == nil {
return syscall.EINVAL
}
if file.dirinfo != nil {
file.dirinfo.close()
file.dirinfo = nil
}
var err error
if e := file.pfd.Close(); e != nil {
if e == poll.ErrFileClosing {
e = ErrClosed
}
err = &PathError{Op: "close", Path: file.name, Err: e}
}
// no need for a finalizer anymore
runtime.SetFinalizer(file, nil)
return err
}
// seek sets the offset for the next Read or Write on file to offset, interpreted
// according to whence: 0 means relative to the origin of the file, 1 means
// relative to the current offset, and 2 means relative to the end.
// It returns the new offset and an error, if any.
func (f *File) seek(offset int64, whence int) (ret int64, err error) {
if f.dirinfo != nil {
// Free cached dirinfo, so we allocate a new one if we
// access this file as a directory again. See #35767 and #37161.
f.dirinfo.close()
f.dirinfo = nil
}
ret, err = f.pfd.Seek(offset, whence)
runtime.KeepAlive(f)
return ret, err
}
// Truncate changes the size of the named file.
// If the file is a symbolic link, it changes the size of the link's target.
// If there is an error, it will be of type *PathError.
func Truncate(name string, size int64) error {
e := ignoringEINTR(func() error {
return syscall.Truncate(name, size)
})
if e != nil {
return &PathError{Op: "truncate", Path: name, Err: e}
}
return nil
}
// Remove removes the named file or (empty) directory.
// If there is an error, it will be of type *PathError.
func Remove(name string) error {
// System call interface forces us to know
// whether name is a file or directory.
// Try both: it is cheaper on average than
// doing a Stat plus the right one.
e := ignoringEINTR(func() error {
return syscall.Unlink(name)
})
if e == nil {
return nil
}
e1 := ignoringEINTR(func() error {
return syscall.Rmdir(name)
})
if e1 == nil {
return nil
}
// Both failed: figure out which error to return.
// OS X and Linux differ on whether unlink(dir)
// returns EISDIR, so can't use that. However,
// both agree that rmdir(file) returns ENOTDIR,
// so we can use that to decide which error is real.
// Rmdir might also return ENOTDIR if given a bad
// file path, like /etc/passwd/foo, but in that case,
// both errors will be ENOTDIR, so it's okay to
// use the error from unlink.
if e1 != syscall.ENOTDIR {
e = e1
}
return &PathError{Op: "remove", Path: name, Err: e}
}
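// Illustrative sketch (not part of the original source): Remove deletes a
// single file or an empty directory; it does not recurse. The path is an
// assumption for the example.
//
//	if err := os.Remove("/tmp/scratch.txt"); err != nil {
//		log.Fatal(err) // err is a *PathError
//	}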
func tempDir() string {
dir := Getenv("TMPDIR")
if dir == "" {
if runtime.GOOS == "android" {
dir = "/data/local/tmp"
} else {
dir = "/tmp"
}
}
return dir
}
// Link creates newname as a hard link to the oldname file.
// If there is an error, it will be of type *LinkError.
func Link(oldname, newname string) error {
e := ignoringEINTR(func() error {
return syscall.Link(oldname, newname)
})
if e != nil {
return &LinkError{"link", oldname, newname, e}
}
return nil
}
// Symlink creates newname as a symbolic link to oldname.
// On Windows, a symlink to a non-existent oldname creates a file symlink;
// if oldname is later created as a directory the symlink will not work.
// If there is an error, it will be of type *LinkError.
func Symlink(oldname, newname string) error {
e := ignoringEINTR(func() error {
return syscall.Symlink(oldname, newname)
})
if e != nil {
return &LinkError{"symlink", oldname, newname, e}
}
return nil
}
// Readlink returns the destination of the named symbolic link.
// If there is an error, it will be of type *PathError.
func Readlink(name string) (string, error) {
for size := 128; ; size *= 2 {
b := make([]byte, size)
var (
n int
e error
)
for {
n, e = fixCount(syscall.Readlink(name, b))
if e != syscall.EINTR {
break
}
}
// buffer too small
if runtime.GOOS == "aix" && e == syscall.ERANGE {
continue
}
if e != nil {
return "", &PathError{Op: "readlink", Path: name, Err: e}
}
if n < size {
return string(b[0:n]), nil
}
}
}
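// Illustrative sketch (not part of the original source): reading a symlink
// target from client code. The link path is an assumption for the example.
//
//	target, err := os.Readlink("/usr/local/bin/python3")
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println("link points to:", target) // may itself be a relative path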
type unixDirent struct {
parent string
name string
typ FileMode
info FileInfo
}
func (d *unixDirent) Name() string { return d.name }
func (d *unixDirent) IsDir() bool { return d.typ.IsDir() }
func (d *unixDirent) Type() FileMode { return d.typ }
func (d *unixDirent) Info() (FileInfo, error) {
if d.info != nil {
return d.info, nil
}
return lstat(d.parent + "/" + d.name)
}
func newUnixDirent(parent, name string, typ FileMode) (DirEntry, error) {
ude := &unixDirent{
parent: parent,
name: name,
typ: typ,
}
if typ != ^FileMode(0) && !testingForceReadDirLstat {
return ude, nil
}
info, err := lstat(parent + "/" + name)
if err != nil {
return nil, err
}
ude.typ = info.Mode().Type()
ude.info = info
return ude, nil
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package os
import (
"runtime"
"sync"
"syscall"
)
var getwdCache struct {
sync.Mutex
dir string
}
// Getwd returns a rooted path name corresponding to the
// current directory. If the current directory can be
// reached via multiple paths (due to symbolic links),
// Getwd may return any one of them.
func Getwd() (dir string, err error) {
if runtime.GOOS == "windows" || runtime.GOOS == "plan9" {
return syscall.Getwd()
}
// Clumsy but widespread kludge:
// if $PWD is set and matches ".", use it.
dot, err := statNolog(".")
if err != nil {
return "", err
}
dir = Getenv("PWD")
if len(dir) > 0 && dir[0] == '/' {
d, err := statNolog(dir)
if err == nil && SameFile(dot, d) {
return dir, nil
}
}
// If the operating system provides a Getwd call, use it.
// Otherwise, we're trying to find our way back to ".".
if syscall.ImplementsGetwd {
var (
s string
e error
)
for {
s, e = syscall.Getwd()
if e != syscall.EINTR {
break
}
}
return s, NewSyscallError("getwd", e)
}
// Apply same kludge but to cached dir instead of $PWD.
getwdCache.Lock()
dir = getwdCache.dir
getwdCache.Unlock()
if len(dir) > 0 {
d, err := statNolog(dir)
if err == nil && SameFile(dot, d) {
return dir, nil
}
}
// Root is a special case because it has no parent
// and ends in a slash.
root, err := statNolog("/")
if err != nil {
// Can't stat root - no hope.
return "", err
}
if SameFile(root, dot) {
return "/", nil
}
// General algorithm: find name in parent
// and then find name of parent. Each iteration
// adds /name to the beginning of dir.
dir = ""
for parent := ".."; ; parent = "../" + parent {
if len(parent) >= 1024 { // Sanity check
return "", syscall.ENAMETOOLONG
}
fd, err := openFileNolog(parent, O_RDONLY, 0)
if err != nil {
return "", err
}
for {
names, err := fd.Readdirnames(100)
if err != nil {
fd.Close()
return "", err
}
for _, name := range names {
d, _ := lstatNolog(parent + "/" + name)
if SameFile(d, dot) {
dir = "/" + name + dir
goto Found
}
}
}
Found:
pd, err := fd.Stat()
fd.Close()
if err != nil {
return "", err
}
if SameFile(pd, root) {
break
}
// Set up for next round.
dot = pd
}
// Save answer as hint to avoid the expensive path next time.
getwdCache.Lock()
getwdCache.dir = dir
getwdCache.Unlock()
return dir, nil
}
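// Illustrative sketch (not part of the original source): typical Getwd use
// from client code, assuming imports "fmt" and "log".
//
//	wd, err := os.Getwd()
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println("working directory:", wd)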
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package os
import (
"syscall"
)
// MkdirAll creates a directory named path,
// along with any necessary parents, and returns nil,
// or else returns an error.
// The permission bits perm (before umask) are used for all
// directories that MkdirAll creates.
// If path is already a directory, MkdirAll does nothing
// and returns nil.
func MkdirAll(path string, perm FileMode) error {
// Fast path: if we can tell whether path is a directory or file, stop with success or error.
dir, err := Stat(path)
if err == nil {
if dir.IsDir() {
return nil
}
return &PathError{Op: "mkdir", Path: path, Err: syscall.ENOTDIR}
}
// Slow path: make sure parent exists and then call Mkdir for path.
i := len(path)
for i > 0 && IsPathSeparator(path[i-1]) { // Skip trailing path separator.
i--
}
j := i
for j > 0 && !IsPathSeparator(path[j-1]) { // Scan backward over element.
j--
}
if j > 1 {
// Create parent.
err = MkdirAll(fixRootDirectory(path[:j-1]), perm)
if err != nil {
return err
}
}
// Parent now exists; invoke Mkdir and use its result.
err = Mkdir(path, perm)
if err != nil {
// Handle arguments like "foo/." by
// double-checking that directory doesn't exist.
dir, err1 := Lstat(path)
if err1 == nil && dir.IsDir() {
return nil
}
return err
}
return nil
}
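// Illustrative sketch (not part of the original source): creating a nested
// directory tree in one call. The path is an assumption for the example; the
// 0o755 permission bits apply (before umask) to every directory created.
//
//	if err := os.MkdirAll("/tmp/a/b/c", 0o755); err != nil {
//		log.Fatal(err)
//	}
//	// Calling it again is a no-op: the path is already a directory.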
// RemoveAll removes path and any children it contains.
// It removes everything it can but returns the first error
// it encounters. If the path does not exist, RemoveAll
// returns nil (no error).
// If there is an error, it will be of type *PathError.
func RemoveAll(path string) error {
return removeAll(path)
}
// endsWithDot reports whether the final component of path is ".".
func endsWithDot(path string) bool {
if path == "." {
return true
}
if len(path) >= 2 && path[len(path)-1] == '.' && IsPathSeparator(path[len(path)-2]) {
return true
}
return false
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || (js && wasm)
package os
const (
PathSeparator = '/' // OS-specific path separator
PathListSeparator = ':' // OS-specific path list separator
)
// IsPathSeparator reports whether c is a directory separator character.
func IsPathSeparator(c uint8) bool {
return PathSeparator == c
}
// basename removes trailing slashes and the leading directory name from path name.
func basename(name string) string {
i := len(name) - 1
// Remove trailing slashes
for ; i > 0 && name[i] == '/'; i-- {
name = name[:i]
}
// Remove leading directory name
for i--; i >= 0; i-- {
if name[i] == '/' {
name = name[i+1:]
break
}
}
return name
}
// splitPath returns the base name and parent directory.
func splitPath(path string) (string, string) {
// if no better parent is found, the path is relative from "here"
dirname := "."
// Remove all but one leading slash.
for len(path) > 1 && path[0] == '/' && path[1] == '/' {
path = path[1:]
}
i := len(path) - 1
// Remove trailing slashes.
for ; i > 0 && path[i] == '/'; i-- {
path = path[:i]
}
// if no slashes in path, base is path
basename := path
// Remove leading directory path
for i--; i >= 0; i-- {
if path[i] == '/' {
if i == 0 {
dirname = path[:1]
} else {
dirname = path[:i]
}
basename = path[i+1:]
break
}
}
return dirname, basename
}
func fixRootDirectory(p string) string {
return p
}
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build dragonfly || freebsd || linux || netbsd || openbsd || solaris
package os
import "syscall"
// Pipe returns a connected pair of Files; reads from r return bytes written to w.
// It returns the files and an error, if any.
func Pipe() (r *File, w *File, err error) {
var p [2]int
e := syscall.Pipe2(p[0:], syscall.O_CLOEXEC)
if e != nil {
return nil, nil, NewSyscallError("pipe2", e)
}
return newFile(uintptr(p[0]), "|0", kindPipe), newFile(uintptr(p[1]), "|1", kindPipe), nil
}
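// Illustrative sketch (not part of the original source): moving bytes through
// an os.Pipe pair, assuming imports "fmt", "io", and "log".
//
//	r, w, err := os.Pipe()
//	if err != nil {
//		log.Fatal(err)
//	}
//	go func() {
//		w.Write([]byte("hello"))
//		w.Close() // close the writer so the reader sees EOF
//	}()
//	data, _ := io.ReadAll(r)
//	fmt.Printf("%s\n", data) // hello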
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Process etc.
package os
import (
"internal/testlog"
"runtime"
"syscall"
)
// Args hold the command-line arguments, starting with the program name.
var Args []string
func init() {
if runtime.GOOS == "windows" {
// Initialized in exec_windows.go.
return
}
Args = runtime_args()
}
func runtime_args() []string // in package runtime
// Getuid returns the numeric user id of the caller.
//
// On Windows, it returns -1.
func Getuid() int { return syscall.Getuid() }
// Geteuid returns the numeric effective user id of the caller.
//
// On Windows, it returns -1.
func Geteuid() int { return syscall.Geteuid() }
// Getgid returns the numeric group id of the caller.
//
// On Windows, it returns -1.
func Getgid() int { return syscall.Getgid() }
// Getegid returns the numeric effective group id of the caller.
//
// On Windows, it returns -1.
func Getegid() int { return syscall.Getegid() }
// Getgroups returns a list of the numeric ids of groups that the caller belongs to.
//
// On Windows, it returns syscall.EWINDOWS. See the os/user package
// for a possible alternative.
func Getgroups() ([]int, error) {
gids, e := syscall.Getgroups()
return gids, NewSyscallError("getgroups", e)
}
// Exit causes the current program to exit with the given status code.
// Conventionally, code zero indicates success, non-zero an error.
// The program terminates immediately; deferred functions are not run.
//
// For portability, the status code should be in the range [0, 125].
func Exit(code int) {
if code == 0 && testlog.PanicOnExit0() {
// We were told to panic on calls to os.Exit(0).
// This is used to fail tests that make an early
// unexpected call to os.Exit(0).
panic("unexpected call to os.Exit(0) during test")
}
// Inform the runtime that os.Exit is being called. If -race is
// enabled, this will give the race detector a chance to fail the
// program (racy programs do not have the right to finish
// successfully). If coverage is enabled, then this call will
// enable us to write out a coverage data file.
runtime_beforeExit(code)
syscall.Exit(code)
}
func runtime_beforeExit(exitCode int) // implemented in runtime
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !plan9
package os
import (
"runtime"
)
// rawConn implements syscall.RawConn.
type rawConn struct {
file *File
}
func (c *rawConn) Control(f func(uintptr)) error {
if err := c.file.checkValid("SyscallConn.Control"); err != nil {
return err
}
err := c.file.pfd.RawControl(f)
runtime.KeepAlive(c.file)
return err
}
func (c *rawConn) Read(f func(uintptr) bool) error {
if err := c.file.checkValid("SyscallConn.Read"); err != nil {
return err
}
err := c.file.pfd.RawRead(f)
runtime.KeepAlive(c.file)
return err
}
func (c *rawConn) Write(f func(uintptr) bool) error {
if err := c.file.checkValid("SyscallConn.Write"); err != nil {
return err
}
err := c.file.pfd.RawWrite(f)
runtime.KeepAlive(c.file)
return err
}
func newRawConn(file *File) (*rawConn, error) {
return &rawConn{file: file}, nil
}
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package os
import (
"internal/poll"
"io"
"syscall"
)
var (
pollCopyFileRange = poll.CopyFileRange
pollSplice = poll.Splice
)
func (f *File) readFrom(r io.Reader) (written int64, handled bool, err error) {
// Neither copy_file_range(2) nor splice(2) supports destinations opened with
// O_APPEND, so don't bother to try zero-copy with these system calls.
//
// Visit https://man7.org/linux/man-pages/man2/copy_file_range.2.html#ERRORS and
// https://man7.org/linux/man-pages/man2/splice.2.html#ERRORS for details.
if f.appendMode {
return 0, false, nil
}
written, handled, err = f.copyFileRange(r)
if handled {
return
}
return f.spliceToFile(r)
}
func (f *File) spliceToFile(r io.Reader) (written int64, handled bool, err error) {
var (
remain int64
lr *io.LimitedReader
)
if lr, r, remain = tryLimitedReader(r); remain <= 0 {
return 0, true, nil
}
pfd := getPollFD(r)
// TODO(panjf2000): run some tests to see if we should unlock the non-streams for splice.
// Streams benefit the most from splice(2); non-streams are not even supported in old kernels,
// where splice(2) just returns EINVAL. Newer kernels do support non-streams such as UDP, but
// splice(2) is unlikely to help them: they usually send small, individual frames, so one
// splice call would move only one frame. splice(2) suits large transfers, and that
// fragmentation defeats its advantage here. Therefore, don't bother to try splice if r is
// not a streaming descriptor.
if pfd == nil || !pfd.IsStream {
return
}
var syscallName string
written, handled, syscallName, err = pollSplice(&f.pfd, pfd, remain)
if lr != nil {
lr.N = remain - written
}
return written, handled, wrapSyscallError(syscallName, err)
}
// getPollFD tries to get the poll.FD from the given io.Reader by expecting
// the underlying type of r to be the implementation of syscall.Conn that contains
// a *net.rawConn.
func getPollFD(r io.Reader) *poll.FD {
sc, ok := r.(syscall.Conn)
if !ok {
return nil
}
rc, err := sc.SyscallConn()
if err != nil {
return nil
}
ipfd, ok := rc.(interface{ PollFD() *poll.FD })
if !ok {
return nil
}
return ipfd.PollFD()
}
func (f *File) copyFileRange(r io.Reader) (written int64, handled bool, err error) {
var (
remain int64
lr *io.LimitedReader
)
if lr, r, remain = tryLimitedReader(r); remain <= 0 {
return 0, true, nil
}
src, ok := r.(*File)
if !ok {
return 0, false, nil
}
if src.checkValid("ReadFrom") != nil {
// Avoid returning the error, since we report handled as false;
// further error handling is left as the responsibility of the caller.
return 0, false, nil
}
written, handled, err = pollCopyFileRange(&f.pfd, &src.pfd, remain)
if lr != nil {
lr.N -= written
}
return written, handled, wrapSyscallError("copy_file_range", err)
}
// tryLimitedReader tries to assert the io.Reader to an *io.LimitedReader. If the
// assertion succeeds, it returns the io.LimitedReader, the underlying io.Reader,
// and the remaining number of bytes; otherwise it returns the original io.Reader
// and a theoretically unlimited remaining byte count.
func tryLimitedReader(r io.Reader) (*io.LimitedReader, io.Reader, int64) {
var remain int64 = 1<<63 - 1 // by default, copy until EOF
lr, ok := r.(*io.LimitedReader)
if !ok {
return nil, r, remain
}
remain = lr.N
return lr, lr.R, remain
}
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix
package os
import (
"internal/syscall/unix"
"io"
"syscall"
)
func removeAll(path string) error {
if path == "" {
// fail silently to retain compatibility with previous behavior
// of RemoveAll. See issue 28830.
return nil
}
// The rmdir system call does not permit removing ".",
// so we don't permit it either.
if endsWithDot(path) {
return &PathError{Op: "RemoveAll", Path: path, Err: syscall.EINVAL}
}
// Simple case: if Remove works, we're done.
err := Remove(path)
if err == nil || IsNotExist(err) {
return nil
}
// RemoveAll recurses by deleting the path base from
// its parent directory.
parentDir, base := splitPath(path)
parent, err := Open(parentDir)
if IsNotExist(err) {
// If parent does not exist, base cannot exist. Fail silently.
return nil
}
if err != nil {
return err
}
defer parent.Close()
if err := removeAllFrom(parent, base); err != nil {
if pathErr, ok := err.(*PathError); ok {
pathErr.Path = parentDir + string(PathSeparator) + pathErr.Path
err = pathErr
}
return err
}
return nil
}
func removeAllFrom(parent *File, base string) error {
parentFd := int(parent.Fd())
// Simple case: if Unlink (aka remove) works, we're done.
err := ignoringEINTR(func() error {
return unix.Unlinkat(parentFd, base, 0)
})
if err == nil || IsNotExist(err) {
return nil
}
// EISDIR means that we have a directory, and we need to
// remove its contents.
// EPERM or EACCES means that we don't have write permission on
// the parent directory, but this entry might still be a directory
// whose contents need to be removed.
// Otherwise just return the error.
if err != syscall.EISDIR && err != syscall.EPERM && err != syscall.EACCES {
return &PathError{Op: "unlinkat", Path: base, Err: err}
}
// Is this a directory we need to recurse into?
var statInfo syscall.Stat_t
statErr := ignoringEINTR(func() error {
return unix.Fstatat(parentFd, base, &statInfo, unix.AT_SYMLINK_NOFOLLOW)
})
if statErr != nil {
if IsNotExist(statErr) {
return nil
}
return &PathError{Op: "fstatat", Path: base, Err: statErr}
}
if statInfo.Mode&syscall.S_IFMT != syscall.S_IFDIR {
// Not a directory; return the error from the unix.Unlinkat.
return &PathError{Op: "unlinkat", Path: base, Err: err}
}
// Remove the directory's entries.
var recurseErr error
for {
const reqSize = 1024
var respSize int
// Open the directory to recurse into
file, err := openFdAt(parentFd, base)
if err != nil {
if IsNotExist(err) {
return nil
}
recurseErr = &PathError{Op: "openfdat", Path: base, Err: err}
break
}
for {
numErr := 0
names, readErr := file.Readdirnames(reqSize)
// Errors other than EOF should stop us from continuing.
if readErr != nil && readErr != io.EOF {
file.Close()
if IsNotExist(readErr) {
return nil
}
return &PathError{Op: "readdirnames", Path: base, Err: readErr}
}
respSize = len(names)
for _, name := range names {
err := removeAllFrom(file, name)
if err != nil {
if pathErr, ok := err.(*PathError); ok {
pathErr.Path = base + string(PathSeparator) + pathErr.Path
}
numErr++
if recurseErr == nil {
recurseErr = err
}
}
}
// If we could delete any entry, break to start a new iteration.
// Otherwise, discard the current names, fetch the next batch of
// entries, and try deleting those.
if numErr != reqSize {
break
}
}
// Removing files from the directory may have caused
// the OS to reshuffle it. Simply calling Readdirnames
// again may skip some entries. The only reliable way
// to avoid this is to close and re-open the
// directory. See issue 20841.
file.Close()
// Finish when the end of the directory is reached
if respSize < reqSize {
break
}
}
// Remove the directory itself.
unlinkError := ignoringEINTR(func() error {
return unix.Unlinkat(parentFd, base, unix.AT_REMOVEDIR)
})
if unlinkError == nil || IsNotExist(unlinkError) {
return nil
}
if recurseErr != nil {
return recurseErr
}
return &PathError{Op: "unlinkat", Path: base, Err: unlinkError}
}
// openFdAt opens path relative to the directory in fd.
// Other than that this should act like openFileNolog.
// This acts like openFileNolog rather than OpenFile because
// we are going to (try to) remove the file.
// The contents of this file are not relevant for test caching.
func openFdAt(dirfd int, name string) (*File, error) {
var r int
for {
var e error
r, e = unix.Openat(dirfd, name, O_RDONLY|syscall.O_CLOEXEC, 0)
if e == nil {
break
}
// See comment in openFileNolog.
if e == syscall.EINTR {
continue
}
return nil, e
}
if !supportsCloseOnExec {
syscall.CloseOnExec(r)
}
// We use kindNoPoll because we know that this is a directory.
return newFile(uintptr(r), name, kindNoPoll), nil
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix
package os
import "syscall"
// Some systems set an artificially low soft limit on open file count, for compatibility
// with code that uses select and its hard-coded maximum file descriptor
// (limited by the size of fd_set).
//
// Go does not use select, so it should not be subject to these limits.
// On some systems the limit is 256, which is very easy to run into,
// even in simple programs like gofmt when they parallelize walking
// a file tree.
//
// After a long discussion on go.dev/issue/46279, we decided the
// best approach was for Go to raise the limit unconditionally for itself,
// and then leave old software to set the limit back as needed.
// Code that really wants Go to leave the limit alone can set the hard limit,
// which Go of course has no choice but to respect.
func init() {
var lim syscall.Rlimit
if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &lim); err == nil && lim.Cur != lim.Max {
lim.Cur = lim.Max
adjustFileLimit(&lim)
syscall.Setrlimit(syscall.RLIMIT_NOFILE, &lim)
}
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix || dragonfly || freebsd || linux || netbsd || openbsd || solaris
package os
import "syscall"
// adjustFileLimit adds per-OS limitations on the Rlimit used for RLIMIT_NOFILE. See rlimit.go.
func adjustFileLimit(lim *syscall.Rlimit) {}
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package signal
import (
"context"
"os"
"sync"
)
var handlers struct {
sync.Mutex
// Map a channel to the signals that should be sent to it.
m map[chan<- os.Signal]*handler
// Map a signal to the number of channels receiving it.
ref [numSig]int64
// Map channels to signals while the channel is being stopped.
// Not a map because entries live here only very briefly.
// We need a separate container because we need m to correspond to ref
// at all times, and we also need to keep track of the *handler
// value for a channel being stopped. See the Stop function.
stopping []stopping
}
type stopping struct {
c chan<- os.Signal
h *handler
}
type handler struct {
mask [(numSig + 31) / 32]uint32
}
func (h *handler) want(sig int) bool {
return (h.mask[sig/32]>>uint(sig&31))&1 != 0
}
func (h *handler) set(sig int) {
h.mask[sig/32] |= 1 << uint(sig&31)
}
func (h *handler) clear(sig int) {
h.mask[sig/32] &^= 1 << uint(sig&31)
}
// Stop relaying the signals, sigs, to any channels previously registered to
// receive them and either reset the signal handlers to their original values
// (action=disableSignal) or ignore the signals (action=ignoreSignal).
func cancel(sigs []os.Signal, action func(int)) {
handlers.Lock()
defer handlers.Unlock()
remove := func(n int) {
var zerohandler handler
for c, h := range handlers.m {
if h.want(n) {
handlers.ref[n]--
h.clear(n)
if h.mask == zerohandler.mask {
delete(handlers.m, c)
}
}
}
action(n)
}
if len(sigs) == 0 {
for n := 0; n < numSig; n++ {
remove(n)
}
} else {
for _, s := range sigs {
remove(signum(s))
}
}
}
// Ignore causes the provided signals to be ignored. If they are received by
// the program, nothing will happen. Ignore undoes the effect of any prior
// calls to Notify for the provided signals.
// If no signals are provided, all incoming signals will be ignored.
func Ignore(sig ...os.Signal) {
cancel(sig, ignoreSignal)
}
// Ignored reports whether sig is currently ignored.
func Ignored(sig os.Signal) bool {
sn := signum(sig)
return sn >= 0 && signalIgnored(sn)
}
var (
// watchSignalLoopOnce guards calling the conditionally
// initialized watchSignalLoop. If watchSignalLoop is non-nil,
// it will be run in a goroutine lazily once Notify is invoked.
// See Issue 21576.
watchSignalLoopOnce sync.Once
watchSignalLoop func()
)
// Notify causes package signal to relay incoming signals to c.
// If no signals are provided, all incoming signals will be relayed to c.
// Otherwise, just the provided signals will.
//
// Package signal will not block sending to c: the caller must ensure
// that c has sufficient buffer space to keep up with the expected
// signal rate. For a channel used for notification of just one signal value,
// a buffer of size 1 is sufficient.
//
// It is allowed to call Notify multiple times with the same channel:
// each call expands the set of signals sent to that channel.
// The only way to remove signals from the set is to call Stop.
//
// It is allowed to call Notify multiple times with different channels
// and the same signals: each channel receives copies of incoming
// signals independently.
func Notify(c chan<- os.Signal, sig ...os.Signal) {
if c == nil {
panic("os/signal: Notify using nil channel")
}
handlers.Lock()
defer handlers.Unlock()
h := handlers.m[c]
if h == nil {
if handlers.m == nil {
handlers.m = make(map[chan<- os.Signal]*handler)
}
h = new(handler)
handlers.m[c] = h
}
add := func(n int) {
if n < 0 {
return
}
if !h.want(n) {
h.set(n)
if handlers.ref[n] == 0 {
enableSignal(n)
// The runtime requires that we enable a
// signal before starting the watcher.
watchSignalLoopOnce.Do(func() {
if watchSignalLoop != nil {
go watchSignalLoop()
}
})
}
handlers.ref[n]++
}
}
if len(sig) == 0 {
for n := 0; n < numSig; n++ {
add(n)
}
} else {
for _, s := range sig {
add(signum(s))
}
}
}
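// Illustrative sketch (not part of the original source): the canonical Notify
// pattern with a buffered channel, assuming imports "fmt", "os", and
// "os/signal" on the caller's side.
//
//	c := make(chan os.Signal, 1) // buffer of 1 suffices for one signal value
//	signal.Notify(c, os.Interrupt)
//	s := <-c
//	fmt.Println("got signal:", s)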
// Reset undoes the effect of any prior calls to Notify for the provided
// signals.
// If no signals are provided, all signal handlers will be reset.
func Reset(sig ...os.Signal) {
cancel(sig, disableSignal)
}
// Stop causes package signal to stop relaying incoming signals to c.
// It undoes the effect of all prior calls to Notify using c.
// When Stop returns, it is guaranteed that c will receive no more signals.
func Stop(c chan<- os.Signal) {
handlers.Lock()
h := handlers.m[c]
if h == nil {
handlers.Unlock()
return
}
delete(handlers.m, c)
for n := 0; n < numSig; n++ {
if h.want(n) {
handlers.ref[n]--
if handlers.ref[n] == 0 {
disableSignal(n)
}
}
}
// Signals will no longer be delivered to the channel.
// We want to avoid a race for a signal such as SIGINT:
// it should be either delivered to the channel,
// or the program should take the default action (that is, exit).
// To avoid the possibility that the signal is delivered,
// and the signal handler invoked, and then Stop deregisters
// the channel before the process function below has a chance
// to send it on the channel, put the channel on a list of
// channels being stopped and wait for signal delivery to
// quiesce before fully removing it.
handlers.stopping = append(handlers.stopping, stopping{c, h})
handlers.Unlock()
signalWaitUntilIdle()
handlers.Lock()
for i, s := range handlers.stopping {
if s.c == c {
handlers.stopping = append(handlers.stopping[:i], handlers.stopping[i+1:]...)
break
}
}
handlers.Unlock()
}
// Wait until there are no more signals waiting to be delivered.
// Defined by the runtime package.
func signalWaitUntilIdle()
func process(sig os.Signal) {
n := signum(sig)
if n < 0 {
return
}
handlers.Lock()
defer handlers.Unlock()
for c, h := range handlers.m {
if h.want(n) {
// send but do not block for it
select {
case c <- sig:
default:
}
}
}
// Avoid the race mentioned in Stop.
for _, d := range handlers.stopping {
if d.h.want(n) {
select {
case d.c <- sig:
default:
}
}
}
}
// NotifyContext returns a copy of the parent context that is marked done
// (its Done channel is closed) when one of the listed signals arrives,
// when the returned stop function is called, or when the parent context's
// Done channel is closed, whichever happens first.
//
// The stop function unregisters the signal behavior, which, like signal.Reset,
// may restore the default behavior for a given signal. For example, the default
// behavior of a Go program receiving os.Interrupt is to exit. Calling
// NotifyContext(parent, os.Interrupt) will change the behavior to cancel
// the returned context. Future interrupts received will not trigger the default
// (exit) behavior until the returned stop function is called.
//
// The stop function releases resources associated with it, so code should
// call stop as soon as the operations running in this Context complete and
// signals no longer need to be diverted to the context.
func NotifyContext(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) {
ctx, cancel := context.WithCancel(parent)
c := &signalCtx{
Context: ctx,
cancel: cancel,
signals: signals,
}
c.ch = make(chan os.Signal, 1)
Notify(c.ch, c.signals...)
if ctx.Err() == nil {
go func() {
select {
case <-c.ch:
c.cancel()
case <-c.Done():
}
}()
}
return c, c.stop
}
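// Illustrative sketch (not part of the original source): using NotifyContext
// to cancel work on interrupt, assuming imports "context" and "os".
//
//	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt)
//	defer stop() // restores default signal behavior
//	<-ctx.Done() // interrupted, stop called, or parent canceled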
type signalCtx struct {
context.Context
cancel context.CancelFunc
signals []os.Signal
ch chan os.Signal
}
func (c *signalCtx) stop() {
c.cancel()
Stop(c.ch)
}
type stringer interface {
String() string
}
func (c *signalCtx) String() string {
var buf []byte
// We know that the type of c.Context is context.cancelCtx, and we know that the
// String method of cancelCtx returns a string that ends with ".WithCancel".
name := c.Context.(stringer).String()
name = name[:len(name)-len(".WithCancel")]
buf = append(buf, "signal.NotifyContext("+name...)
if len(c.signals) != 0 {
buf = append(buf, ", ["...)
for i, s := range c.signals {
buf = append(buf, s.String()...)
if i != len(c.signals)-1 {
buf = append(buf, ' ')
}
}
buf = append(buf, ']')
}
buf = append(buf, ')')
return string(buf)
}
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || (js && wasm) || windows
package signal
import (
"os"
"syscall"
)
// Defined by the runtime package.
func signal_disable(uint32)
func signal_enable(uint32)
func signal_ignore(uint32)
func signal_ignored(uint32) bool
func signal_recv() uint32
func loop() {
for {
process(syscall.Signal(signal_recv()))
}
}
func init() {
watchSignalLoop = loop
}
const (
numSig = 65 // max across all systems
)
func signum(sig os.Signal) int {
switch sig := sig.(type) {
case syscall.Signal:
i := int(sig)
if i < 0 || i >= numSig {
return -1
}
return i
default:
return -1
}
}
func enableSignal(sig int) {
signal_enable(uint32(sig))
}
func disableSignal(sig int) {
signal_disable(uint32(sig))
}
func ignoreSignal(sig int) {
signal_ignore(uint32(sig))
}
func signalIgnored(sig int) bool {
return signal_ignored(uint32(sig))
}
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package os
import "internal/testlog"
// Stat returns a FileInfo describing the named file.
// If there is an error, it will be of type *PathError.
func Stat(name string) (FileInfo, error) {
testlog.Stat(name)
return statNolog(name)
}
// Lstat returns a FileInfo describing the named file.
// If the file is a symbolic link, the returned FileInfo
// describes the symbolic link. Lstat makes no attempt to follow the link.
// If there is an error, it will be of type *PathError.
func Lstat(name string) (FileInfo, error) {
testlog.Stat(name)
return lstatNolog(name)
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package os
import (
"syscall"
"time"
)
func fillFileStatFromSys(fs *fileStat, name string) {
fs.name = basename(name)
fs.size = fs.sys.Size
fs.modTime = time.Unix(fs.sys.Mtim.Unix())
fs.mode = FileMode(fs.sys.Mode & 0777)
switch fs.sys.Mode & syscall.S_IFMT {
case syscall.S_IFBLK:
fs.mode |= ModeDevice
case syscall.S_IFCHR:
fs.mode |= ModeDevice | ModeCharDevice
case syscall.S_IFDIR:
fs.mode |= ModeDir
case syscall.S_IFIFO:
fs.mode |= ModeNamedPipe
case syscall.S_IFLNK:
fs.mode |= ModeSymlink
case syscall.S_IFREG:
// nothing to do
case syscall.S_IFSOCK:
fs.mode |= ModeSocket
}
if fs.sys.Mode&syscall.S_ISGID != 0 {
fs.mode |= ModeSetgid
}
if fs.sys.Mode&syscall.S_ISUID != 0 {
fs.mode |= ModeSetuid
}
if fs.sys.Mode&syscall.S_ISVTX != 0 {
fs.mode |= ModeSticky
}
}
// For testing.
func atime(fi FileInfo) time.Time {
return time.Unix(fi.Sys().(*syscall.Stat_t).Atim.Unix())
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || (js && wasm)
package os
import (
"syscall"
)
// Stat returns the FileInfo structure describing file.
// If there is an error, it will be of type *PathError.
func (f *File) Stat() (FileInfo, error) {
if f == nil {
return nil, ErrInvalid
}
var fs fileStat
err := f.pfd.Fstat(&fs.sys)
if err != nil {
return nil, &PathError{Op: "stat", Path: f.name, Err: err}
}
fillFileStatFromSys(&fs, f.name)
return &fs, nil
}
// statNolog stats a file with no test logging.
func statNolog(name string) (FileInfo, error) {
var fs fileStat
err := ignoringEINTR(func() error {
return syscall.Stat(name, &fs.sys)
})
if err != nil {
return nil, &PathError{Op: "stat", Path: name, Err: err}
}
fillFileStatFromSys(&fs, name)
return &fs, nil
}
// lstatNolog lstats a file with no test logging.
func lstatNolog(name string) (FileInfo, error) {
var fs fileStat
err := ignoringEINTR(func() error {
return syscall.Lstat(name, &fs.sys)
})
if err != nil {
return nil, &PathError{Op: "lstat", Path: name, Err: err}
}
fillFileStatFromSys(&fs, name)
return &fs, nil
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Simple conversions to avoid depending on strconv.
package os
// itox converts val (an int) to a hexadecimal string.
func itox(val int) string {
if val < 0 {
return "-" + uitox(uint(-val))
}
return uitox(uint(val))
}
const hex = "0123456789abcdef"
// uitox converts val (a uint) to a hexadecimal string.
func uitox(val uint) string {
if val == 0 { // avoid string allocation
return "0x0"
}
var buf [20]byte // big enough for a 64-bit value in base 16, plus the "0x" prefix
i := len(buf) - 1
for val >= 16 {
q := val / 16
buf[i] = hex[val%16]
i--
val = q
}
// val < 16
buf[i] = hex[val%16]
i--
buf[i] = 'x'
i--
buf[i] = '0'
return string(buf[i:])
}
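// For example (illustrative, not part of the original source):
//
//	itox(255) // "0xff"
//	itox(-16) // "-0x10"
//	uitox(0)  // "0x0"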
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package os
// Hostname returns the host name reported by the kernel.
func Hostname() (name string, err error) {
return hostname()
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package os
import (
"runtime"
"syscall"
)
func hostname() (name string, err error) {
// Try uname first, as it's only one system call and reading
// from /proc is not allowed on Android.
var un syscall.Utsname
err = syscall.Uname(&un)
var buf [512]byte // Enough for a DNS name.
for i, b := range un.Nodename[:] {
buf[i] = uint8(b)
if b == 0 {
name = string(buf[:i])
break
}
}
// If we got a name and it's not potentially truncated
// (Nodename is 65 bytes), return it.
if err == nil && len(name) > 0 && len(name) < 64 {
return name, nil
}
if runtime.GOOS == "android" {
if name != "" {
return name, nil
}
return "localhost", nil
}
f, err := Open("/proc/sys/kernel/hostname")
if err != nil {
return "", err
}
defer f.Close()
n, err := f.Read(buf[:])
if err != nil {
return "", err
}
if n > 0 && buf[n-1] == '\n' {
n--
}
return string(buf[:n]), nil
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package os
import (
"errors"
"internal/itoa"
)
// fastrand provided by runtime.
// We generate random temporary file names so that there's a good
// chance the file doesn't exist yet - this keeps the number of tries in
// CreateTemp to a minimum.
func fastrand() uint32
func nextRandom() string {
return itoa.Uitoa(uint(fastrand()))
}
// CreateTemp creates a new temporary file in the directory dir,
// opens the file for reading and writing, and returns the resulting file.
// The filename is generated by taking pattern and adding a random string to the end.
// If pattern includes a "*", the random string replaces the last "*".
// If dir is the empty string, CreateTemp uses the default directory for temporary files, as returned by TempDir.
// Multiple programs or goroutines calling CreateTemp simultaneously will not choose the same file.
// The caller can use the file's Name method to find the pathname of the file.
// It is the caller's responsibility to remove the file when it is no longer needed.
func CreateTemp(dir, pattern string) (*File, error) {
if dir == "" {
dir = TempDir()
}
prefix, suffix, err := prefixAndSuffix(pattern)
if err != nil {
return nil, &PathError{Op: "createtemp", Path: pattern, Err: err}
}
prefix = joinPath(dir, prefix)
try := 0
for {
name := prefix + nextRandom() + suffix
f, err := OpenFile(name, O_RDWR|O_CREATE|O_EXCL, 0600)
if IsExist(err) {
if try++; try < 10000 {
continue
}
return nil, &PathError{Op: "createtemp", Path: prefix + "*" + suffix, Err: ErrExist}
}
return f, err
}
}
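// Illustrative sketch (not part of the original source): creating and cleaning
// up a temporary file; the pattern string is an assumption for the example.
//
//	f, err := os.CreateTemp("", "example-*.txt")
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer os.Remove(f.Name()) // cleanup is the caller's responsibility
//	fmt.Println("created", f.Name())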
var errPatternHasSeparator = errors.New("pattern contains path separator")
// prefixAndSuffix splits pattern by the last wildcard "*", if applicable,
// returning prefix as the part before "*" and suffix as the part after "*".
func prefixAndSuffix(pattern string) (prefix, suffix string, err error) {
for i := 0; i < len(pattern); i++ {
if IsPathSeparator(pattern[i]) {
return "", "", errPatternHasSeparator
}
}
if pos := lastIndex(pattern, '*'); pos != -1 {
prefix, suffix = pattern[:pos], pattern[pos+1:]
} else {
prefix = pattern
}
return prefix, suffix, nil
}
// MkdirTemp creates a new temporary directory in the directory dir
// and returns the pathname of the new directory.
// The new directory's name is generated by adding a random string to the end of pattern.
// If pattern includes a "*", the random string replaces the last "*" instead.
// If dir is the empty string, MkdirTemp uses the default directory for temporary files, as returned by TempDir.
// Multiple programs or goroutines calling MkdirTemp simultaneously will not choose the same directory.
// It is the caller's responsibility to remove the directory when it is no longer needed.
func MkdirTemp(dir, pattern string) (string, error) {
if dir == "" {
dir = TempDir()
}
prefix, suffix, err := prefixAndSuffix(pattern)
if err != nil {
return "", &PathError{Op: "mkdirtemp", Path: pattern, Err: err}
}
prefix = joinPath(dir, prefix)
try := 0
for {
name := prefix + nextRandom() + suffix
err := Mkdir(name, 0700)
if err == nil {
return name, nil
}
if IsExist(err) {
if try++; try < 10000 {
continue
}
return "", &PathError{Op: "mkdirtemp", Path: dir + string(PathSeparator) + prefix + "*" + suffix, Err: ErrExist}
}
if IsNotExist(err) {
if _, err := Stat(dir); IsNotExist(err) {
return "", err
}
}
return "", err
}
}
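// Illustrative sketch (not part of the original source): a scratch directory
// removed when the caller is done; the pattern string is an assumption.
//
//	dir, err := os.MkdirTemp("", "build-*")
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer os.RemoveAll(dir)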
func joinPath(dir, name string) string {
if len(dir) > 0 && IsPathSeparator(dir[len(dir)-1]) {
return dir + name
}
return dir + string(PathSeparator) + name
}
// lastIndex from the strings package.
func lastIndex(s string, sep byte) int {
for i := len(s) - 1; i >= 0; i-- {
if s[i] == sep {
return i
}
}
return -1
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package os
import (
"io/fs"
"syscall"
)
// Getpagesize returns the underlying system's memory page size.
func Getpagesize() int { return syscall.Getpagesize() }
// File represents an open file descriptor.
type File struct {
*file // os specific
}
// A FileInfo describes a file and is returned by Stat and Lstat.
type FileInfo = fs.FileInfo
// A FileMode represents a file's mode and permission bits.
// The bits have the same definition on all systems, so that
// information about files can be moved from one system
// to another portably. Not all bits apply to all systems.
// The only required bit is ModeDir for directories.
type FileMode = fs.FileMode
// The defined file mode bits are the most significant bits of the FileMode.
// The nine least-significant bits are the standard Unix rwxrwxrwx permissions.
// The values of these bits should be considered part of the public API and
// may be used in wire protocols or disk representations: they must not be
// changed, although new bits might be added.
const (
// The single letters are the abbreviations
// used by the String method's formatting.
ModeDir = fs.ModeDir // d: is a directory
ModeAppend = fs.ModeAppend // a: append-only
ModeExclusive = fs.ModeExclusive // l: exclusive use
ModeTemporary = fs.ModeTemporary // T: temporary file; Plan 9 only
ModeSymlink = fs.ModeSymlink // L: symbolic link
ModeDevice = fs.ModeDevice // D: device file
ModeNamedPipe = fs.ModeNamedPipe // p: named pipe (FIFO)
ModeSocket = fs.ModeSocket // S: Unix domain socket
ModeSetuid = fs.ModeSetuid // u: setuid
ModeSetgid = fs.ModeSetgid // g: setgid
ModeCharDevice = fs.ModeCharDevice // c: Unix character device, when ModeDevice is set
ModeSticky = fs.ModeSticky // t: sticky
ModeIrregular = fs.ModeIrregular // ?: non-regular file; nothing else is known about this file
// Mask for the type bits. For regular files, none will be set.
ModeType = fs.ModeType
ModePerm = fs.ModePerm // Unix permission bits, 0o777
)
func (fs *fileStat) Name() string { return fs.name }
func (fs *fileStat) IsDir() bool { return fs.Mode().IsDir() }
// SameFile reports whether fi1 and fi2 describe the same file.
// For example, on Unix this means that the device and inode fields
// of the two underlying structures are identical; on other systems
// the decision may be based on the path names.
// SameFile only applies to results returned by this package's Stat.
// It returns false in other cases.
func SameFile(fi1, fi2 FileInfo) bool {
fs1, ok1 := fi1.(*fileStat)
fs2, ok2 := fi2.(*fileStat)
if !ok1 || !ok2 {
return false
}
return sameFile(fs1, fs2)
}
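// Illustrative sketch (not part of the original source): two Stat results for
// the same path compare equal under SameFile; the path is an assumption.
//
//	fi1, err1 := os.Stat("/etc/hosts")
//	fi2, err2 := os.Stat("/etc/hosts")
//	if err1 == nil && err2 == nil {
//		fmt.Println(os.SameFile(fi1, fi2)) // true
//	}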
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !windows && !plan9
package os
import (
"syscall"
"time"
)
// A fileStat is the implementation of FileInfo returned by Stat and Lstat.
type fileStat struct {
name string
size int64
mode FileMode
modTime time.Time
sys syscall.Stat_t
}
func (fs *fileStat) Size() int64 { return fs.size }
func (fs *fileStat) Mode() FileMode { return fs.mode }
func (fs *fileStat) ModTime() time.Time { return fs.modTime }
func (fs *fileStat) Sys() any { return &fs.sys }
func sameFile(fs1, fs2 *fileStat) bool {
return fs1.sys.Dev == fs2.sys.Dev && fs1.sys.Ino == fs2.sys.Ino
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build (cgo || darwin) && !osusergo && (darwin || dragonfly || freebsd || (linux && !android) || netbsd || openbsd || (solaris && !illumos))
package user
import (
"fmt"
"strconv"
"unsafe"
)
const maxGroups = 2048
func listGroups(u *User) ([]string, error) {
ug, err := strconv.Atoi(u.Gid)
if err != nil {
return nil, fmt.Errorf("user: list groups for %s: invalid gid %q", u.Username, u.Gid)
}
userGID := _C_gid_t(ug)
nameC := make([]byte, len(u.Username)+1)
copy(nameC, u.Username)
n := _C_int(256)
gidsC := make([]_C_gid_t, n)
rv := getGroupList((*_C_char)(unsafe.Pointer(&nameC[0])), userGID, &gidsC[0], &n)
if rv == -1 {
// Mac is the only Unix that does not set n properly when rv == -1, so
// we need to use different logic for Mac vs. the other OSes.
if err := groupRetry(u.Username, nameC, userGID, &gidsC, &n); err != nil {
return nil, err
}
}
gidsC = gidsC[:n]
gids := make([]string, 0, n)
for _, g := range gidsC {
gids = append(gids, strconv.Itoa(int(g)))
}
return gids, nil
}
// groupRetry retries getGroupList with much larger size for n. The result is
// stored in gids.
func groupRetry(username string, name []byte, userGID _C_gid_t, gids *[]_C_gid_t, n *_C_int) error {
// The initial buffer was too small, but n now contains the correct size.
if *n > maxGroups {
return fmt.Errorf("user: %q is a member of more than %d groups", username, maxGroups)
}
*gids = make([]_C_gid_t, *n)
rv := getGroupList((*_C_char)(unsafe.Pointer(&name[0])), userGID, &(*gids)[0], n)
if rv == -1 {
return fmt.Errorf("user: list groups for %s failed", username)
}
return nil
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build cgo && !osusergo && unix && !android && !darwin
package user
import (
"syscall"
)
/*
#cgo solaris CFLAGS: -D_POSIX_PTHREAD_SEMANTICS
#cgo CFLAGS: -fno-stack-protector
#include <unistd.h>
#include <sys/types.h>
#include <pwd.h>
#include <grp.h>
#include <stdlib.h>
#include <string.h>
static struct passwd mygetpwuid_r(int uid, char *buf, size_t buflen, int *found, int *perr) {
struct passwd pwd;
struct passwd *result;
memset (&pwd, 0, sizeof(pwd));
*perr = getpwuid_r(uid, &pwd, buf, buflen, &result);
*found = result != NULL;
return pwd;
}
static struct passwd mygetpwnam_r(const char *name, char *buf, size_t buflen, int *found, int *perr) {
struct passwd pwd;
struct passwd *result;
memset(&pwd, 0, sizeof(pwd));
*perr = getpwnam_r(name, &pwd, buf, buflen, &result);
*found = result != NULL;
return pwd;
}
static struct group mygetgrgid_r(int gid, char *buf, size_t buflen, int *found, int *perr) {
struct group grp;
struct group *result;
memset(&grp, 0, sizeof(grp));
*perr = getgrgid_r(gid, &grp, buf, buflen, &result);
*found = result != NULL;
return grp;
}
static struct group mygetgrnam_r(const char *name, char *buf, size_t buflen, int *found, int *perr) {
struct group grp;
struct group *result;
memset(&grp, 0, sizeof(grp));
*perr = getgrnam_r(name, &grp, buf, buflen, &result);
*found = result != NULL;
return grp;
}
*/
import "C"
type _C_char = C.char
type _C_int = C.int
type _C_gid_t = C.gid_t
type _C_uid_t = C.uid_t
type _C_size_t = C.size_t
type _C_struct_group = C.struct_group
type _C_struct_passwd = C.struct_passwd
type _C_long = C.long
func _C_pw_uid(p *_C_struct_passwd) _C_uid_t { return p.pw_uid }
func _C_pw_uidp(p *_C_struct_passwd) *_C_uid_t { return &p.pw_uid }
func _C_pw_gid(p *_C_struct_passwd) _C_gid_t { return p.pw_gid }
func _C_pw_gidp(p *_C_struct_passwd) *_C_gid_t { return &p.pw_gid }
func _C_pw_name(p *_C_struct_passwd) *_C_char { return p.pw_name }
func _C_pw_gecos(p *_C_struct_passwd) *_C_char { return p.pw_gecos }
func _C_pw_dir(p *_C_struct_passwd) *_C_char { return p.pw_dir }
func _C_gr_gid(g *_C_struct_group) _C_gid_t { return g.gr_gid }
func _C_gr_name(g *_C_struct_group) *_C_char { return g.gr_name }
func _C_GoString(p *_C_char) string { return C.GoString(p) }
func _C_getpwnam_r(name *_C_char, buf *_C_char, size _C_size_t) (pwd _C_struct_passwd, found bool, errno syscall.Errno) {
var f, e _C_int
pwd = C.mygetpwnam_r(name, buf, size, &f, &e)
return pwd, f != 0, syscall.Errno(e)
}
func _C_getpwuid_r(uid _C_uid_t, buf *_C_char, size _C_size_t) (pwd _C_struct_passwd, found bool, errno syscall.Errno) {
var f, e _C_int
pwd = C.mygetpwuid_r(_C_int(uid), buf, size, &f, &e)
return pwd, f != 0, syscall.Errno(e)
}
func _C_getgrnam_r(name *_C_char, buf *_C_char, size _C_size_t) (grp _C_struct_group, found bool, errno syscall.Errno) {
var f, e _C_int
grp = C.mygetgrnam_r(name, buf, size, &f, &e)
return grp, f != 0, syscall.Errno(e)
}
func _C_getgrgid_r(gid _C_gid_t, buf *_C_char, size _C_size_t) (grp _C_struct_group, found bool, errno syscall.Errno) {
var f, e _C_int
grp = C.mygetgrgid_r(_C_int(gid), buf, size, &f, &e)
return grp, f != 0, syscall.Errno(e)
}
const (
_C__SC_GETPW_R_SIZE_MAX = C._SC_GETPW_R_SIZE_MAX
_C__SC_GETGR_R_SIZE_MAX = C._SC_GETGR_R_SIZE_MAX
)
func _C_sysconf(key _C_int) _C_long { return C.sysconf(key) }
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build (cgo || darwin) && !osusergo && unix && !android
package user
import (
"fmt"
"runtime"
"strconv"
"strings"
"syscall"
"unsafe"
)
func current() (*User, error) {
return lookupUnixUid(syscall.Getuid())
}
func lookupUser(username string) (*User, error) {
var pwd _C_struct_passwd
var found bool
nameC := make([]byte, len(username)+1)
copy(nameC, username)
err := retryWithBuffer(userBuffer, func(buf []byte) syscall.Errno {
var errno syscall.Errno
pwd, found, errno = _C_getpwnam_r((*_C_char)(unsafe.Pointer(&nameC[0])),
(*_C_char)(unsafe.Pointer(&buf[0])), _C_size_t(len(buf)))
return errno
})
if err != nil {
return nil, fmt.Errorf("user: lookup username %s: %v", username, err)
}
if !found {
return nil, UnknownUserError(username)
}
return buildUser(&pwd), err
}
func lookupUserId(uid string) (*User, error) {
i, e := strconv.Atoi(uid)
if e != nil {
return nil, e
}
return lookupUnixUid(i)
}
func lookupUnixUid(uid int) (*User, error) {
var pwd _C_struct_passwd
var found bool
err := retryWithBuffer(userBuffer, func(buf []byte) syscall.Errno {
var errno syscall.Errno
pwd, found, errno = _C_getpwuid_r(_C_uid_t(uid),
(*_C_char)(unsafe.Pointer(&buf[0])), _C_size_t(len(buf)))
return errno
})
if err != nil {
return nil, fmt.Errorf("user: lookup userid %d: %v", uid, err)
}
if !found {
return nil, UnknownUserIdError(uid)
}
return buildUser(&pwd), nil
}
func buildUser(pwd *_C_struct_passwd) *User {
u := &User{
Uid: strconv.FormatUint(uint64(_C_pw_uid(pwd)), 10),
Gid: strconv.FormatUint(uint64(_C_pw_gid(pwd)), 10),
Username: _C_GoString(_C_pw_name(pwd)),
Name: _C_GoString(_C_pw_gecos(pwd)),
HomeDir: _C_GoString(_C_pw_dir(pwd)),
}
// The pw_gecos field isn't quite standardized. Some docs
// say: "It is expected to be a comma separated list of
// personal data where the first item is the full name of the
// user."
u.Name, _, _ = strings.Cut(u.Name, ",")
return u
}
func lookupGroup(groupname string) (*Group, error) {
var grp _C_struct_group
var found bool
cname := make([]byte, len(groupname)+1)
copy(cname, groupname)
err := retryWithBuffer(groupBuffer, func(buf []byte) syscall.Errno {
var errno syscall.Errno
grp, found, errno = _C_getgrnam_r((*_C_char)(unsafe.Pointer(&cname[0])),
(*_C_char)(unsafe.Pointer(&buf[0])), _C_size_t(len(buf)))
return errno
})
if err != nil {
return nil, fmt.Errorf("user: lookup groupname %s: %v", groupname, err)
}
if !found {
return nil, UnknownGroupError(groupname)
}
return buildGroup(&grp), nil
}
func lookupGroupId(gid string) (*Group, error) {
i, e := strconv.Atoi(gid)
if e != nil {
return nil, e
}
return lookupUnixGid(i)
}
func lookupUnixGid(gid int) (*Group, error) {
var grp _C_struct_group
var found bool
err := retryWithBuffer(groupBuffer, func(buf []byte) syscall.Errno {
var errno syscall.Errno
grp, found, errno = _C_getgrgid_r(_C_gid_t(gid),
(*_C_char)(unsafe.Pointer(&buf[0])), _C_size_t(len(buf)))
return errno
})
if err != nil {
return nil, fmt.Errorf("user: lookup groupid %d: %v", gid, err)
}
if !found {
return nil, UnknownGroupIdError(strconv.Itoa(gid))
}
return buildGroup(&grp), nil
}
func buildGroup(grp *_C_struct_group) *Group {
g := &Group{
Gid: strconv.Itoa(int(_C_gr_gid(grp))),
Name: _C_GoString(_C_gr_name(grp)),
}
return g
}
type bufferKind _C_int
var (
userBuffer = bufferKind(_C__SC_GETPW_R_SIZE_MAX)
groupBuffer = bufferKind(_C__SC_GETGR_R_SIZE_MAX)
)
func (k bufferKind) initialSize() _C_size_t {
sz := _C_sysconf(_C_int(k))
if sz == -1 {
// DragonFly and FreeBSD do not have _SC_GETPW_R_SIZE_MAX.
// Additionally, not all Linux systems have it, either. For
// example, the musl libc returns -1.
return 1024
}
if !isSizeReasonable(int64(sz)) {
// Truncate. If this truly isn't enough, retryWithBuffer will error on the first run.
return maxBufferSize
}
return _C_size_t(sz)
}
// retryWithBuffer repeatedly calls f(), increasing the size of the
// buffer each time, until f succeeds, fails with a non-ERANGE error,
// or the buffer exceeds a reasonable limit.
func retryWithBuffer(startSize bufferKind, f func([]byte) syscall.Errno) error {
buf := make([]byte, startSize)
for {
errno := f(buf)
if errno == 0 {
return nil
} else if runtime.GOOS == "aix" && errno+1 == 0 {
// On AIX getpwuid_r appears to return -1,
// not ERANGE, on buffer overflow.
} else if errno != syscall.ERANGE {
return errno
}
newSize := len(buf) * 2
if !isSizeReasonable(int64(newSize)) {
return fmt.Errorf("internal buffer exceeds %d bytes", maxBufferSize)
}
buf = make([]byte, newSize)
}
}
const maxBufferSize = 1 << 20
func isSizeReasonable(sz int64) bool {
return sz > 0 && sz <= maxBufferSize
}
// Because we can't use cgo in tests:
func structPasswdForNegativeTest() _C_struct_passwd {
sp := _C_struct_passwd{}
*_C_pw_uidp(&sp) = 1<<32 - 2
*_C_pw_gidp(&sp) = 1<<32 - 3
return sp
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build cgo && !osusergo && (dragonfly || freebsd || (!android && linux) || netbsd || openbsd || (solaris && !illumos))
package user
/*
#include <unistd.h>
#include <sys/types.h>
#include <grp.h>
static int mygetgrouplist(const char* user, gid_t group, gid_t* groups, int* ngroups) {
return getgrouplist(user, group, groups, ngroups);
}
*/
import "C"
func getGroupList(name *_C_char, userGID _C_gid_t, gids *_C_gid_t, n *_C_int) _C_int {
return C.mygetgrouplist(name, userGID, gids, n)
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package user
import "sync"
const (
userFile = "/etc/passwd"
groupFile = "/etc/group"
)
var colon = []byte{':'}
// Current returns the current user.
//
// The first call will cache the current user information.
// Subsequent calls will return the cached value and will not reflect
// changes to the current user.
func Current() (*User, error) {
cache.Do(func() { cache.u, cache.err = current() })
if cache.err != nil {
return nil, cache.err
}
u := *cache.u // copy
return &u, nil
}
// cache of the current user
var cache struct {
sync.Once
u *User
err error
}
// Lookup looks up a user by username. If the user cannot be found, the
// returned error is of type UnknownUserError.
func Lookup(username string) (*User, error) {
if u, err := Current(); err == nil && u.Username == username {
return u, err
}
return lookupUser(username)
}
// LookupId looks up a user by userid. If the user cannot be found, the
// returned error is of type UnknownUserIdError.
func LookupId(uid string) (*User, error) {
if u, err := Current(); err == nil && u.Uid == uid {
return u, err
}
return lookupUserId(uid)
}
// LookupGroup looks up a group by name. If the group cannot be found, the
// returned error is of type UnknownGroupError.
func LookupGroup(name string) (*Group, error) {
return lookupGroup(name)
}
// LookupGroupId looks up a group by groupid. If the group cannot be found, the
// returned error is of type UnknownGroupIdError.
func LookupGroupId(gid string) (*Group, error) {
return lookupGroupId(gid)
}
// GroupIds returns the list of group IDs that the user is a member of.
func (u *User) GroupIds() ([]string, error) {
return listGroups(u)
}
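// A hedged usage sketch of the lookup API above (log.Fatal is
// illustrative error handling):
//
//	u, err := user.Current()
//	if err != nil {
//		log.Fatal(err)
//	}
//	gids, err := u.GroupIds()
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println(u.Username, gids)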
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package user allows user account lookups by name or id.
For most Unix systems, this package has two internal implementations of
resolving user and group ids to names, and listing supplementary group IDs.
One is written in pure Go and parses /etc/passwd and /etc/group. The other
is cgo-based and relies on the standard C library (libc) routines such as
getpwuid_r, getgrnam_r, and getgrouplist.
When cgo is available, and the required routines are implemented in libc
for a particular platform, cgo-based (libc-backed) code is used.
This can be overridden by using the osusergo build tag, which enforces
the pure Go implementation.
*/
package user
import (
"strconv"
)
// These may be set to false in init() for a particular platform and/or
// build flags to let the tests know to skip tests of some features.
var (
userImplemented = true
groupImplemented = true
groupListImplemented = true
)
// User represents a user account.
type User struct {
// Uid is the user ID.
// On POSIX systems, this is a decimal number representing the uid.
// On Windows, this is a security identifier (SID) in a string format.
// On Plan 9, this is the contents of /dev/user.
Uid string
// Gid is the primary group ID.
// On POSIX systems, this is a decimal number representing the gid.
// On Windows, this is a SID in a string format.
// On Plan 9, this is the contents of /dev/user.
Gid string
// Username is the login name.
Username string
// Name is the user's real or display name.
// It might be blank.
// On POSIX systems, this is the first (or only) entry in the GECOS field
// list.
// On Windows, this is the user's display name.
// On Plan 9, this is the contents of /dev/user.
Name string
// HomeDir is the path to the user's home directory (if they have one).
HomeDir string
}
// Group represents a grouping of users.
//
// On POSIX systems Gid contains a decimal number representing the group ID.
type Group struct {
Gid string // group ID
Name string // group name
}
// UnknownUserIdError is returned by LookupId when a user cannot be found.
type UnknownUserIdError int
func (e UnknownUserIdError) Error() string {
return "user: unknown userid " + strconv.Itoa(int(e))
}
// UnknownUserError is returned by Lookup when
// a user cannot be found.
type UnknownUserError string
func (e UnknownUserError) Error() string {
return "user: unknown user " + string(e)
}
// UnknownGroupIdError is returned by LookupGroupId when
// a group cannot be found.
type UnknownGroupIdError string
func (e UnknownGroupIdError) Error() string {
return "group: unknown groupid " + string(e)
}
// UnknownGroupError is returned by LookupGroup when
// a group cannot be found.
type UnknownGroupError string
func (e UnknownGroupError) Error() string {
return "group: unknown group " + string(e)
}
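// A hedged sketch of testing for these error types with errors.As
// (the username is hypothetical):
//
//	_, err := user.Lookup("nosuchuser")
//	var unknown user.UnknownUserError
//	if errors.As(err, &unknown) {
//		fmt.Println("no such user:", string(unknown))
//	}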
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// We used to use this code for Darwin, but according to issue #19314
// waitid returns if the process is stopped, even when using WEXITED.
//go:build linux
package os
import (
"runtime"
"syscall"
"unsafe"
)
const _P_PID = 1
// blockUntilWaitable attempts to block until a call to p.Wait will
// succeed immediately, and reports whether it has done so.
// It does not actually call p.Wait.
func (p *Process) blockUntilWaitable() (bool, error) {
// The waitid system call expects a pointer to a siginfo_t,
// which is 128 bytes on all Linux systems.
// On darwin/amd64, it requires 104 bytes.
// We don't care about the values it returns.
var siginfo [16]uint64
psig := &siginfo[0]
var e syscall.Errno
for {
_, _, e = syscall.Syscall6(syscall.SYS_WAITID, _P_PID, uintptr(p.Pid), uintptr(unsafe.Pointer(psig)), syscall.WEXITED|syscall.WNOWAIT, 0, 0)
if e != syscall.EINTR {
break
}
}
runtime.KeepAlive(p)
if e != 0 {
// waitid has been available since Linux 2.6.9, but
// reportedly is not available in Ubuntu on Windows.
// See issue 16610.
if e == syscall.ENOSYS {
return false, nil
}
return false, NewSyscallError("waitid", e)
}
return true, nil
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package filepath
import (
"errors"
"os"
"runtime"
"sort"
"strings"
"unicode/utf8"
)
// ErrBadPattern indicates a pattern was malformed.
var ErrBadPattern = errors.New("syntax error in pattern")
// Match reports whether name matches the shell file name pattern.
// The pattern syntax is:
//
// pattern:
// { term }
// term:
// '*' matches any sequence of non-Separator characters
// '?' matches any single non-Separator character
// '[' [ '^' ] { character-range } ']'
// character class (must be non-empty)
// c matches character c (c != '*', '?', '\\', '[')
// '\\' c matches character c
//
// character-range:
// c matches character c (c != '\\', '-', ']')
// '\\' c matches character c
// lo '-' hi matches character c for lo <= c <= hi
//
// Match requires pattern to match all of name, not just a substring.
// The only possible returned error is ErrBadPattern, when pattern
// is malformed.
//
// On Windows, escaping is disabled. Instead, '\\' is treated as
// path separator.
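//
// For example (an illustrative sketch on a Unix-like system, where
// Separator is '/'):
//
//	matched, _ := filepath.Match("a*/b", "abc/b") // true
//	matched, _ = filepath.Match("a*", "a/b")      // false: '*' never matches Separator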
func Match(pattern, name string) (matched bool, err error) {
Pattern:
for len(pattern) > 0 {
var star bool
var chunk string
star, chunk, pattern = scanChunk(pattern)
if star && chunk == "" {
// Trailing * matches rest of string unless it has a /.
return !strings.Contains(name, string(Separator)), nil
}
// Look for match at current position.
t, ok, err := matchChunk(chunk, name)
// if we're the last chunk, make sure we've exhausted the name
// otherwise we'll give a false result even if we could still match
// using the star
if ok && (len(t) == 0 || len(pattern) > 0) {
name = t
continue
}
if err != nil {
return false, err
}
if star {
// Look for match skipping i+1 bytes.
// Cannot skip /.
for i := 0; i < len(name) && name[i] != Separator; i++ {
t, ok, err := matchChunk(chunk, name[i+1:])
if ok {
// if we're the last chunk, make sure we exhausted the name
if len(pattern) == 0 && len(t) > 0 {
continue
}
name = t
continue Pattern
}
if err != nil {
return false, err
}
}
}
// Before returning false with no error,
// check that the remainder of the pattern is syntactically valid,
// so that a malformed pattern is reported as ErrBadPattern
// (matching the behavior of package path's Match below).
for len(pattern) > 0 {
_, chunk, pattern = scanChunk(pattern)
if _, _, err := matchChunk(chunk, ""); err != nil {
return false, err
}
}
return false, nil
}
return len(name) == 0, nil
}
// scanChunk gets the next segment of pattern, which is a non-star string
// possibly preceded by a star.
func scanChunk(pattern string) (star bool, chunk, rest string) {
for len(pattern) > 0 && pattern[0] == '*' {
pattern = pattern[1:]
star = true
}
inrange := false
var i int
Scan:
for i = 0; i < len(pattern); i++ {
switch pattern[i] {
case '\\':
if runtime.GOOS != "windows" {
// error check handled in matchChunk: bad pattern.
if i+1 < len(pattern) {
i++
}
}
case '[':
inrange = true
case ']':
inrange = false
case '*':
if !inrange {
break Scan
}
}
}
return star, pattern[0:i], pattern[i:]
}
// matchChunk checks whether chunk matches the beginning of s.
// If so, it returns the remainder of s (after the match).
// Chunk is all single-character operators: literals, char classes, and ?.
func matchChunk(chunk, s string) (rest string, ok bool, err error) {
// failed records whether the match has failed.
// After the match fails, the loop continues on processing chunk,
// checking that the pattern is well-formed but no longer reading s.
failed := false
for len(chunk) > 0 {
if !failed && len(s) == 0 {
failed = true
}
switch chunk[0] {
case '[':
// character class
var r rune
if !failed {
var n int
r, n = utf8.DecodeRuneInString(s)
s = s[n:]
}
chunk = chunk[1:]
// possibly negated
negated := false
if len(chunk) > 0 && chunk[0] == '^' {
negated = true
chunk = chunk[1:]
}
// parse all ranges
match := false
nrange := 0
for {
if len(chunk) > 0 && chunk[0] == ']' && nrange > 0 {
chunk = chunk[1:]
break
}
var lo, hi rune
if lo, chunk, err = getEsc(chunk); err != nil {
return "", false, err
}
hi = lo
if chunk[0] == '-' {
if hi, chunk, err = getEsc(chunk[1:]); err != nil {
return "", false, err
}
}
if lo <= r && r <= hi {
match = true
}
nrange++
}
if match == negated {
failed = true
}
case '?':
if !failed {
if s[0] == Separator {
failed = true
}
_, n := utf8.DecodeRuneInString(s)
s = s[n:]
}
chunk = chunk[1:]
case '\\':
if runtime.GOOS != "windows" {
chunk = chunk[1:]
if len(chunk) == 0 {
return "", false, ErrBadPattern
}
}
fallthrough
default:
if !failed {
if chunk[0] != s[0] {
failed = true
}
s = s[1:]
}
chunk = chunk[1:]
}
}
if failed {
return "", false, nil
}
return s, true, nil
}
// getEsc gets a possibly-escaped character from chunk, for a character class.
func getEsc(chunk string) (r rune, nchunk string, err error) {
if len(chunk) == 0 || chunk[0] == '-' || chunk[0] == ']' {
err = ErrBadPattern
return
}
if chunk[0] == '\\' && runtime.GOOS != "windows" {
chunk = chunk[1:]
if len(chunk) == 0 {
err = ErrBadPattern
return
}
}
r, n := utf8.DecodeRuneInString(chunk)
if r == utf8.RuneError && n == 1 {
err = ErrBadPattern
}
nchunk = chunk[n:]
if len(nchunk) == 0 {
err = ErrBadPattern
}
return
}
// Glob returns the names of all files matching pattern or nil
// if there is no matching file. The syntax of patterns is the same
// as in Match. The pattern may describe hierarchical names such as
// /usr/*/bin/ed (assuming the Separator is '/').
//
// Glob ignores file system errors such as I/O errors reading directories.
// The only possible returned error is ErrBadPattern, when pattern
// is malformed.
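//
// For example (a hedged sketch; the pattern echoes the one above):
//
//	matches, err := filepath.Glob("/usr/*/bin/ed")
//	// matches lists the existing paths, such as "/usr/local/bin/ed", if any.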
func Glob(pattern string) (matches []string, err error) {
return globWithLimit(pattern, 0)
}
func globWithLimit(pattern string, depth int) (matches []string, err error) {
// This limit is used to prevent stack exhaustion issues. See CVE-2022-30632.
const pathSeparatorsLimit = 10000
if depth == pathSeparatorsLimit {
return nil, ErrBadPattern
}
// Check pattern is well-formed.
if _, err := Match(pattern, ""); err != nil {
return nil, err
}
if !hasMeta(pattern) {
if _, err = os.Lstat(pattern); err != nil {
return nil, nil
}
return []string{pattern}, nil
}
dir, file := Split(pattern)
volumeLen := 0
if runtime.GOOS == "windows" {
volumeLen, dir = cleanGlobPathWindows(dir)
} else {
dir = cleanGlobPath(dir)
}
if !hasMeta(dir[volumeLen:]) {
return glob(dir, file, nil)
}
// Prevent infinite recursion. See issue 15879.
if dir == pattern {
return nil, ErrBadPattern
}
var m []string
m, err = globWithLimit(dir, depth+1)
if err != nil {
return
}
for _, d := range m {
matches, err = glob(d, file, matches)
if err != nil {
return
}
}
return
}
// cleanGlobPath prepares path for glob matching.
func cleanGlobPath(path string) string {
switch path {
case "":
return "."
case string(Separator):
// do nothing to the path
return path
default:
return path[0 : len(path)-1] // chop off trailing separator
}
}
// cleanGlobPathWindows is windows version of cleanGlobPath.
func cleanGlobPathWindows(path string) (prefixLen int, cleaned string) {
vollen := volumeNameLen(path)
switch {
case path == "":
return 0, "."
case vollen+1 == len(path) && os.IsPathSeparator(path[len(path)-1]): // /, \, C:\ and C:/
// do nothing to the path
return vollen + 1, path
case vollen == len(path) && len(path) == 2: // C:
return vollen, path + "." // convert C: into C:.
default:
if vollen >= len(path) {
vollen = len(path) - 1
}
return vollen, path[0 : len(path)-1] // chop off trailing separator
}
}
// glob searches for files matching pattern in the directory dir
// and appends them to matches. If the directory cannot be
// opened, it returns the existing matches. New matches are
// added in lexicographical order.
func glob(dir, pattern string, matches []string) (m []string, e error) {
m = matches
fi, err := os.Stat(dir)
if err != nil {
return // ignore I/O error
}
if !fi.IsDir() {
return // ignore I/O error
}
d, err := os.Open(dir)
if err != nil {
return // ignore I/O error
}
defer d.Close()
names, _ := d.Readdirnames(-1)
sort.Strings(names)
for _, n := range names {
matched, err := Match(pattern, n)
if err != nil {
return m, err
}
if matched {
m = append(m, Join(dir, n))
}
}
return
}
// hasMeta reports whether path contains any of the magic characters
// recognized by Match.
func hasMeta(path string) bool {
magicChars := `*?[`
if runtime.GOOS != "windows" {
magicChars = `*?[\`
}
return strings.ContainsAny(path, magicChars)
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package filepath implements utility routines for manipulating filename paths
// in a way compatible with the target operating system-defined file paths.
//
// The filepath package uses either forward slashes or backslashes,
// depending on the operating system. To process paths such as URLs
// that always use forward slashes regardless of the operating
// system, see the path package.
package filepath
import (
"errors"
"io/fs"
"os"
"runtime"
"sort"
"strings"
)
// A lazybuf is a lazily constructed path buffer.
// It supports append, reading previously appended bytes,
// and retrieving the final string. It does not allocate a buffer
// to hold the output until that output diverges from s.
type lazybuf struct {
path string
buf []byte
w int
volAndPath string
volLen int
}
func (b *lazybuf) index(i int) byte {
if b.buf != nil {
return b.buf[i]
}
return b.path[i]
}
func (b *lazybuf) append(c byte) {
if b.buf == nil {
if b.w < len(b.path) && b.path[b.w] == c {
b.w++
return
}
b.buf = make([]byte, len(b.path))
copy(b.buf, b.path[:b.w])
}
b.buf[b.w] = c
b.w++
}
func (b *lazybuf) string() string {
if b.buf == nil {
return b.volAndPath[:b.volLen+b.w]
}
return b.volAndPath[:b.volLen] + string(b.buf[:b.w])
}
const (
Separator = os.PathSeparator
ListSeparator = os.PathListSeparator
)
// Clean returns the shortest path name equivalent to path
// by purely lexical processing. It applies the following rules
// iteratively until no further processing can be done:
//
// 1. Replace multiple Separator elements with a single one.
// 2. Eliminate each . path name element (the current directory).
// 3. Eliminate each inner .. path name element (the parent directory)
// along with the non-.. element that precedes it.
// 4. Eliminate .. elements that begin a rooted path:
// that is, replace "/.." by "/" at the beginning of a path,
// assuming Separator is '/'.
//
// The returned path ends in a slash only if it represents a root directory,
// such as "/" on Unix or `C:\` on Windows.
//
// Finally, any occurrences of slash are replaced by Separator.
//
// If the result of this process is an empty string, Clean
// returns the string ".".
//
// On Windows, Clean does not modify the volume name other than to replace
// occurrences of "/" with `\`.
// For example, Clean("//host/share/../x") returns `\\host\share\x`.
//
// See also Rob Pike, “Lexical File Names in Plan 9 or
// Getting Dot-Dot Right,”
// https://9p.io/sys/doc/lexnames.html
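//
// For example (an illustrative sketch on a Unix-like system):
//
//	filepath.Clean("a/b/../c//d") // "a/c/d"
//	filepath.Clean("")            // "."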
func Clean(path string) string {
originalPath := path
volLen := volumeNameLen(path)
path = path[volLen:]
if path == "" {
if volLen > 1 && os.IsPathSeparator(originalPath[0]) && os.IsPathSeparator(originalPath[1]) {
// should be UNC
return FromSlash(originalPath)
}
return originalPath + "."
}
rooted := os.IsPathSeparator(path[0])
// Invariants:
// reading from path; r is index of next byte to process.
// writing to buf; w is index of next byte to write.
// dotdot is index in buf where .. must stop, either because
// it is the leading slash or it is a leading ../../.. prefix.
n := len(path)
out := lazybuf{path: path, volAndPath: originalPath, volLen: volLen}
r, dotdot := 0, 0
if rooted {
out.append(Separator)
r, dotdot = 1, 1
}
for r < n {
switch {
case os.IsPathSeparator(path[r]):
// empty path element
r++
case path[r] == '.' && (r+1 == n || os.IsPathSeparator(path[r+1])):
// . element
r++
case path[r] == '.' && path[r+1] == '.' && (r+2 == n || os.IsPathSeparator(path[r+2])):
// .. element: remove to last separator
r += 2
switch {
case out.w > dotdot:
// can backtrack
out.w--
for out.w > dotdot && !os.IsPathSeparator(out.index(out.w)) {
out.w--
}
case !rooted:
// cannot backtrack, but not rooted, so append .. element.
if out.w > 0 {
out.append(Separator)
}
out.append('.')
out.append('.')
dotdot = out.w
}
default:
// real path element.
// add slash if needed
if rooted && out.w != 1 || !rooted && out.w != 0 {
out.append(Separator)
}
// If a ':' appears in the path element at the start of a Windows path,
// insert a .\ at the beginning to avoid converting relative paths
// like a/../c: into c:.
if runtime.GOOS == "windows" && out.w == 0 && out.volLen == 0 && r != 0 {
for i := r; i < n && !os.IsPathSeparator(path[i]); i++ {
if path[i] == ':' {
out.append('.')
out.append(Separator)
break
}
}
}
// copy element
for ; r < n && !os.IsPathSeparator(path[r]); r++ {
out.append(path[r])
}
}
}
// Turn empty string into "."
if out.w == 0 {
out.append('.')
}
return FromSlash(out.string())
}
// IsLocal reports whether path, using lexical analysis only, has all of these properties:
//
// - is within the subtree rooted at the directory in which path is evaluated
// - is not an absolute path
// - is not empty
// - on Windows, is not a reserved name such as "NUL"
//
// If IsLocal(path) returns true, then
// Join(base, path) will always produce a path contained within base and
// Clean(path) will always produce an unrooted path with no ".." path elements.
//
// IsLocal is a purely lexical operation.
// In particular, it does not account for the effect of any symbolic links
// that may exist in the filesystem.
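//
// For example (an illustrative sketch on a Unix-like system):
//
//	filepath.IsLocal("a/../b") // true
//	filepath.IsLocal("../b")   // false
//	filepath.IsLocal("/a")     // false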
func IsLocal(path string) bool {
return isLocal(path)
}
func unixIsLocal(path string) bool {
if IsAbs(path) || path == "" {
return false
}
hasDots := false
for p := path; p != ""; {
var part string
part, p, _ = strings.Cut(p, "/")
if part == "." || part == ".." {
hasDots = true
break
}
}
if hasDots {
path = Clean(path)
}
if path == ".." || strings.HasPrefix(path, "../") {
return false
}
return true
}
// ToSlash returns the result of replacing each separator character
// in path with a slash ('/') character. Multiple separators are
// replaced by multiple slashes.
func ToSlash(path string) string {
if Separator == '/' {
return path
}
return strings.ReplaceAll(path, string(Separator), "/")
}
// FromSlash returns the result of replacing each slash ('/') character
// in path with a separator character. Multiple slashes are replaced
// by multiple separators.
func FromSlash(path string) string {
if Separator == '/' {
return path
}
return strings.ReplaceAll(path, "/", string(Separator))
}
// SplitList splits a list of paths joined by the OS-specific ListSeparator,
// usually found in PATH or GOPATH environment variables.
// Unlike strings.Split, SplitList returns an empty slice when passed an empty
// string.
func SplitList(path string) []string {
return splitList(path)
}
// Split splits path immediately following the final Separator,
// separating it into a directory and file name component.
// If there is no Separator in path, Split returns an empty dir
// and file set to path.
// The returned values have the property that path = dir+file.
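//
// For example (an illustrative sketch on a Unix-like system):
//
//	dir, file := filepath.Split("/usr/bin/go") // dir == "/usr/bin/", file == "go"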
func Split(path string) (dir, file string) {
vol := VolumeName(path)
i := len(path) - 1
for i >= len(vol) && !os.IsPathSeparator(path[i]) {
i--
}
return path[:i+1], path[i+1:]
}
// Join joins any number of path elements into a single path,
// separating them with an OS specific Separator. Empty elements
// are ignored. The result is Cleaned. However, if the argument
// list is empty or all its elements are empty, Join returns
// an empty string.
// On Windows, the result will only be a UNC path if the first
// non-empty element is a UNC path.
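//
// For example (an illustrative sketch on a Unix-like system):
//
//	filepath.Join("a", "", "b//c") // "a/b/c"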
func Join(elem ...string) string {
return join(elem)
}
// Ext returns the file name extension used by path.
// The extension is the suffix beginning at the final dot
// in the final element of path; it is empty if there is
// no dot.
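//
// For example (an illustrative sketch):
//
//	filepath.Ext("index.html") // ".html"
//	filepath.Ext("archive")    // ""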
func Ext(path string) string {
for i := len(path) - 1; i >= 0 && !os.IsPathSeparator(path[i]); i-- {
if path[i] == '.' {
return path[i:]
}
}
return ""
}
// EvalSymlinks returns the path name after the evaluation of any symbolic
// links.
// If path is relative the result will be relative to the current directory,
// unless one of the components is an absolute symbolic link.
// EvalSymlinks calls Clean on the result.
func EvalSymlinks(path string) (string, error) {
return evalSymlinks(path)
}
// Abs returns an absolute representation of path.
// If the path is not absolute it will be joined with the current
// working directory to turn it into an absolute path. The absolute
// path name for a given file is not guaranteed to be unique.
// Abs calls Clean on the result.
func Abs(path string) (string, error) {
return abs(path)
}
func unixAbs(path string) (string, error) {
if IsAbs(path) {
return Clean(path), nil
}
wd, err := os.Getwd()
if err != nil {
return "", err
}
return Join(wd, path), nil
}
// Rel returns a relative path that is lexically equivalent to targpath when
// joined to basepath with an intervening separator. That is,
// Join(basepath, Rel(basepath, targpath)) is equivalent to targpath itself.
// On success, the returned path will always be relative to basepath,
// even if basepath and targpath share no elements.
// An error is returned if targpath can't be made relative to basepath or if
// knowing the current working directory would be necessary to compute it.
// Rel calls Clean on the result.
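//
// For example (an illustrative sketch on a Unix-like system):
//
//	filepath.Rel("/a", "/a/b/c") // "b/c", nil
//	filepath.Rel("/a", "/b/c")   // "../b/c", nil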
func Rel(basepath, targpath string) (string, error) {
baseVol := VolumeName(basepath)
targVol := VolumeName(targpath)
base := Clean(basepath)
targ := Clean(targpath)
if sameWord(targ, base) {
return ".", nil
}
base = base[len(baseVol):]
targ = targ[len(targVol):]
if base == "." {
base = ""
} else if base == "" && volumeNameLen(baseVol) > 2 /* isUNC */ {
// Treat any targpath matching a `\\host\share` basepath as an absolute path.
base = string(Separator)
}
// Can't use IsAbs - `\a` and `a` are both relative in Windows.
baseSlashed := len(base) > 0 && base[0] == Separator
targSlashed := len(targ) > 0 && targ[0] == Separator
if baseSlashed != targSlashed || !sameWord(baseVol, targVol) {
return "", errors.New("Rel: can't make " + targpath + " relative to " + basepath)
}
// Position base[b0:bi] and targ[t0:ti] at the first differing elements.
bl := len(base)
tl := len(targ)
var b0, bi, t0, ti int
for {
for bi < bl && base[bi] != Separator {
bi++
}
for ti < tl && targ[ti] != Separator {
ti++
}
if !sameWord(targ[t0:ti], base[b0:bi]) {
break
}
if bi < bl {
bi++
}
if ti < tl {
ti++
}
b0 = bi
t0 = ti
}
if base[b0:bi] == ".." {
return "", errors.New("Rel: can't make " + targpath + " relative to " + basepath)
}
if b0 != bl {
// Base elements left. Must go up before going down.
seps := strings.Count(base[b0:bl], string(Separator))
size := 2 + seps*3
if tl != t0 {
size += 1 + tl - t0
}
buf := make([]byte, size)
n := copy(buf, "..")
for i := 0; i < seps; i++ {
buf[n] = Separator
copy(buf[n+1:], "..")
n += 3
}
if t0 != tl {
buf[n] = Separator
copy(buf[n+1:], targ[t0:])
}
return string(buf), nil
}
return targ[t0:], nil
}
// SkipDir is used as a return value from WalkFuncs to indicate that
// the directory named in the call is to be skipped. It is not returned
// as an error by any function.
var SkipDir error = fs.SkipDir
// SkipAll is used as a return value from WalkFuncs to indicate that
// all remaining files and directories are to be skipped. It is not returned
// as an error by any function.
var SkipAll error = fs.SkipAll
// WalkFunc is the type of the function called by Walk to visit each
// file or directory.
//
// The path argument contains the argument to Walk as a prefix.
// That is, if Walk is called with root argument "dir" and finds a file
// named "a" in that directory, the walk function will be called with
// argument "dir/a".
//
// The directory and file are joined with Join, which may clean the
// directory name: if Walk is called with the root argument "x/../dir"
// and finds a file named "a" in that directory, the walk function will
// be called with argument "dir/a", not "x/../dir/a".
//
// The info argument is the fs.FileInfo for the named path.
//
// The error result returned by the function controls how Walk continues.
// If the function returns the special value SkipDir, Walk skips the
// current directory (path if info.IsDir() is true, otherwise path's
// parent directory). If the function returns the special value SkipAll,
// Walk skips all remaining files and directories. Otherwise, if the function
// returns a non-nil error, Walk stops entirely and returns that error.
//
// The err argument reports an error related to path, signaling that Walk
// will not walk into that directory. The function can decide how to
// handle that error; as described earlier, returning the error will
// cause Walk to stop walking the entire tree.
//
// Walk calls the function with a non-nil err argument in two cases.
//
// First, if an os.Lstat on the root directory or any directory or file
// in the tree fails, Walk calls the function with path set to that
// directory or file's path, info set to nil, and err set to the error
// from os.Lstat.
//
// Second, if a directory's Readdirnames method fails, Walk calls the
// function with path set to the directory's path, info set to an
// fs.FileInfo describing the directory, and err set to the error from
// Readdirnames.
type WalkFunc func(path string, info fs.FileInfo, err error) error
var lstat = os.Lstat // for testing
// walkDir recursively descends path, calling walkDirFn.
func walkDir(path string, d fs.DirEntry, walkDirFn fs.WalkDirFunc) error {
if err := walkDirFn(path, d, nil); err != nil || !d.IsDir() {
if err == SkipDir && d.IsDir() {
// Successfully skipped directory.
err = nil
}
return err
}
dirs, err := readDir(path)
if err != nil {
// Second call, to report ReadDir error.
err = walkDirFn(path, d, err)
if err != nil {
if err == SkipDir && d.IsDir() {
err = nil
}
return err
}
}
for _, d1 := range dirs {
path1 := Join(path, d1.Name())
if err := walkDir(path1, d1, walkDirFn); err != nil {
if err == SkipDir {
break
}
return err
}
}
return nil
}
// walk recursively descends path, calling walkFn.
func walk(path string, info fs.FileInfo, walkFn WalkFunc) error {
if !info.IsDir() {
return walkFn(path, info, nil)
}
names, err := readDirNames(path)
err1 := walkFn(path, info, err)
// If err != nil, walk can't walk into this directory.
// err1 != nil means walkFn wants walk to skip this directory or stop walking.
// Therefore, if one of err and err1 isn't nil, walk will return.
if err != nil || err1 != nil {
// The caller's behavior is controlled by the return value, which is decided
// by walkFn. walkFn may ignore err and return nil.
// If walkFn returns SkipDir or SkipAll, it will be handled by the caller.
// So walk should return whatever walkFn returns.
return err1
}
for _, name := range names {
filename := Join(path, name)
fileInfo, err := lstat(filename)
if err != nil {
if err := walkFn(filename, fileInfo, err); err != nil && err != SkipDir {
return err
}
} else {
err = walk(filename, fileInfo, walkFn)
if err != nil {
if !fileInfo.IsDir() || err != SkipDir {
return err
}
}
}
}
return nil
}
// WalkDir walks the file tree rooted at root, calling fn for each file or
// directory in the tree, including root.
//
// All errors that arise visiting files and directories are filtered by fn:
// see the fs.WalkDirFunc documentation for details.
//
// The files are walked in lexical order, which makes the output deterministic
// but requires WalkDir to read an entire directory into memory before proceeding
// to walk that directory.
//
// WalkDir does not follow symbolic links.
//
// WalkDir calls fn with paths that use the separator character appropriate
// for the operating system. This is unlike [io/fs.WalkDir], which always
// uses slash separated paths.
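//
// For example (a minimal sketch; root and the fmt call are illustrative):
//
//	err := filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error {
//		if err != nil {
//			return err
//		}
//		fmt.Println(path)
//		return nil
//	})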
func WalkDir(root string, fn fs.WalkDirFunc) error {
info, err := os.Lstat(root)
if err != nil {
err = fn(root, nil, err)
} else {
err = walkDir(root, &statDirEntry{info}, fn)
}
if err == SkipDir || err == SkipAll {
return nil
}
return err
}
type statDirEntry struct {
info fs.FileInfo
}
func (d *statDirEntry) Name() string { return d.info.Name() }
func (d *statDirEntry) IsDir() bool { return d.info.IsDir() }
func (d *statDirEntry) Type() fs.FileMode { return d.info.Mode().Type() }
func (d *statDirEntry) Info() (fs.FileInfo, error) { return d.info, nil }
// Walk walks the file tree rooted at root, calling fn for each file or
// directory in the tree, including root.
//
// All errors that arise visiting files and directories are filtered by fn:
// see the WalkFunc documentation for details.
//
// The files are walked in lexical order, which makes the output deterministic
// but requires Walk to read an entire directory into memory before proceeding
// to walk that directory.
//
// Walk does not follow symbolic links.
//
// Walk is less efficient than WalkDir, introduced in Go 1.16,
// which avoids calling os.Lstat on every visited file or directory.
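//
// For example (a minimal sketch; root and the fmt call are illustrative):
//
//	err := filepath.Walk(root, func(path string, info fs.FileInfo, err error) error {
//		if err != nil {
//			return err
//		}
//		fmt.Println(path, info.Size())
//		return nil
//	})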
func Walk(root string, fn WalkFunc) error {
info, err := os.Lstat(root)
if err != nil {
err = fn(root, nil, err)
} else {
err = walk(root, info, fn)
}
if err == SkipDir || err == SkipAll {
return nil
}
return err
}
// readDir reads the directory named by dirname and returns
// a sorted list of directory entries.
func readDir(dirname string) ([]fs.DirEntry, error) {
f, err := os.Open(dirname)
if err != nil {
return nil, err
}
dirs, err := f.ReadDir(-1)
f.Close()
if err != nil {
return nil, err
}
sort.Slice(dirs, func(i, j int) bool { return dirs[i].Name() < dirs[j].Name() })
return dirs, nil
}
// readDirNames reads the directory named by dirname and returns
// a sorted list of directory entry names.
func readDirNames(dirname string) ([]string, error) {
f, err := os.Open(dirname)
if err != nil {
return nil, err
}
names, err := f.Readdirnames(-1)
f.Close()
if err != nil {
return nil, err
}
sort.Strings(names)
return names, nil
}
// Base returns the last element of path.
// Trailing path separators are removed before extracting the last element.
// If the path is empty, Base returns ".".
// If the path consists entirely of separators, Base returns a single separator.
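//
// For example (an illustrative sketch on a Unix-like system):
//
//	filepath.Base("/foo/bar/") // "bar"
//	filepath.Base("///")       // "/"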
func Base(path string) string {
if path == "" {
return "."
}
// Strip trailing slashes.
for len(path) > 0 && os.IsPathSeparator(path[len(path)-1]) {
path = path[0 : len(path)-1]
}
// Throw away volume name
path = path[len(VolumeName(path)):]
// Find the last element
i := len(path) - 1
for i >= 0 && !os.IsPathSeparator(path[i]) {
i--
}
if i >= 0 {
path = path[i+1:]
}
// If empty now, it had only slashes.
if path == "" {
return string(Separator)
}
return path
}
// Dir returns all but the last element of path, typically the path's directory.
// After dropping the final element, Dir calls Clean on the path and trailing
// slashes are removed.
// If the path is empty, Dir returns ".".
// If the path consists entirely of separators, Dir returns a single separator.
// The returned path does not end in a separator unless it is the root directory.
func Dir(path string) string {
vol := VolumeName(path)
i := len(path) - 1
for i >= len(vol) && !os.IsPathSeparator(path[i]) {
i--
}
dir := Clean(path[len(vol) : i+1])
if dir == "." && len(vol) > 2 {
// must be UNC
return vol
}
return vol + dir
}
// VolumeName returns leading volume name.
// Given "C:\foo\bar" it returns "C:" on Windows.
// Given "\\host\share\foo" it returns "\\host\share".
// On other platforms it returns "".
func VolumeName(path string) string {
return FromSlash(path[:volumeNameLen(path)])
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || (js && wasm)
package filepath
import "strings"
func isLocal(path string) bool {
return unixIsLocal(path)
}
// IsAbs reports whether the path is absolute.
func IsAbs(path string) bool {
return strings.HasPrefix(path, "/")
}
// volumeNameLen returns length of the leading volume name on Windows.
// It returns 0 elsewhere.
func volumeNameLen(path string) int {
return 0
}
// HasPrefix exists for historical compatibility and should not be used.
//
// Deprecated: HasPrefix does not respect path boundaries and
// does not ignore case when required.
func HasPrefix(p, prefix string) bool {
return strings.HasPrefix(p, prefix)
}
func splitList(path string) []string {
if path == "" {
return []string{}
}
return strings.Split(path, string(ListSeparator))
}
func abs(path string) (string, error) {
return unixAbs(path)
}
func join(elem []string) string {
// If there's a bug here, fix the logic in ./path_plan9.go too.
for i, e := range elem {
if e != "" {
return Clean(strings.Join(elem[i:], string(Separator)))
}
}
return ""
}
func sameWord(a, b string) bool {
return a == b
}
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package filepath
import (
"errors"
"io/fs"
"os"
"runtime"
"syscall"
)
func walkSymlinks(path string) (string, error) {
volLen := volumeNameLen(path)
pathSeparator := string(os.PathSeparator)
if volLen < len(path) && os.IsPathSeparator(path[volLen]) {
volLen++
}
vol := path[:volLen]
dest := vol
linksWalked := 0
for start, end := volLen, volLen; start < len(path); start = end {
for start < len(path) && os.IsPathSeparator(path[start]) {
start++
}
end = start
for end < len(path) && !os.IsPathSeparator(path[end]) {
end++
}
// On Windows, "." can be a symlink.
// We look it up, and use the value if it is absolute.
// If not, we just return ".".
isWindowsDot := runtime.GOOS == "windows" && path[volumeNameLen(path):] == "."
// The next path component is in path[start:end].
if end == start {
// No more path components.
break
} else if path[start:end] == "." && !isWindowsDot {
// Ignore path component ".".
continue
} else if path[start:end] == ".." {
// Back up to previous component if possible.
// Note that volLen includes any leading slash.
// Set r to the index of the last slash in dest,
// after the volume.
var r int
for r = len(dest) - 1; r >= volLen; r-- {
if os.IsPathSeparator(dest[r]) {
break
}
}
if r < volLen || dest[r+1:] == ".." {
// Either path has no slashes
// (it's empty or just "C:")
// or it ends in a ".." we had to keep.
// Either way, keep this "..".
if len(dest) > volLen {
dest += pathSeparator
}
dest += ".."
} else {
// Discard everything since the last slash.
dest = dest[:r]
}
continue
}
// Ordinary path component. Add it to result.
if len(dest) > volumeNameLen(dest) && !os.IsPathSeparator(dest[len(dest)-1]) {
dest += pathSeparator
}
dest += path[start:end]
// Resolve symlink.
fi, err := os.Lstat(dest)
if err != nil {
return "", err
}
if fi.Mode()&fs.ModeSymlink == 0 {
if !fi.Mode().IsDir() && end < len(path) {
return "", syscall.ENOTDIR
}
continue
}
// Found symlink.
linksWalked++
if linksWalked > 255 {
return "", errors.New("EvalSymlinks: too many links")
}
link, err := os.Readlink(dest)
if err != nil {
return "", err
}
if isWindowsDot && !IsAbs(link) {
// On Windows, if "." is a relative symlink,
// just return ".".
break
}
path = link + path[end:]
v := volumeNameLen(link)
if v > 0 {
// Symlink to drive name is an absolute path.
if v < len(link) && os.IsPathSeparator(link[v]) {
v++
}
vol = link[:v]
dest = vol
end = len(vol)
} else if len(link) > 0 && os.IsPathSeparator(link[0]) {
// Symlink to absolute path.
dest = link[:1]
end = 1
vol = link[:1]
volLen = 1
} else {
// Symlink to relative path; replace last
// path component in dest.
var r int
for r = len(dest) - 1; r >= volLen; r-- {
if os.IsPathSeparator(dest[r]) {
break
}
}
if r < volLen {
dest = vol
} else {
dest = dest[:r]
}
end = 0
}
}
return Clean(dest), nil
}
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !windows && !plan9
package filepath
func evalSymlinks(path string) (string, error) {
return walkSymlinks(path)
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package path
import (
"errors"
"internal/bytealg"
"unicode/utf8"
)
// ErrBadPattern indicates a pattern was malformed.
var ErrBadPattern = errors.New("syntax error in pattern")
// Match reports whether name matches the shell pattern.
// The pattern syntax is:
//
// pattern:
// { term }
// term:
// '*' matches any sequence of non-/ characters
// '?' matches any single non-/ character
// '[' [ '^' ] { character-range } ']'
// character class (must be non-empty)
// c matches character c (c != '*', '?', '\\', '[')
// '\\' c matches character c
//
// character-range:
// c matches character c (c != '\\', '-', ']')
// '\\' c matches character c
// lo '-' hi matches character c for lo <= c <= hi
//
// Match requires pattern to match all of name, not just a substring.
// The only possible returned error is ErrBadPattern, when pattern
// is malformed.
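//
// For example (an illustrative sketch):
//
//	matched, _ := path.Match("a/*", "a/b")  // true
//	matched, _ = path.Match("a/*", "a/b/c") // false: '*' never matches '/'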
func Match(pattern, name string) (matched bool, err error) {
Pattern:
for len(pattern) > 0 {
var star bool
var chunk string
star, chunk, pattern = scanChunk(pattern)
if star && chunk == "" {
// Trailing * matches rest of string unless it has a /.
return bytealg.IndexByteString(name, '/') < 0, nil
}
// Look for match at current position.
t, ok, err := matchChunk(chunk, name)
// if we're the last chunk, make sure we've exhausted the name
// otherwise we'll give a false result even if we could still match
// using the star
if ok && (len(t) == 0 || len(pattern) > 0) {
name = t
continue
}
if err != nil {
return false, err
}
if star {
// Look for match skipping i+1 bytes.
// Cannot skip /.
for i := 0; i < len(name) && name[i] != '/'; i++ {
t, ok, err := matchChunk(chunk, name[i+1:])
if ok {
// if we're the last chunk, make sure we exhausted the name
if len(pattern) == 0 && len(t) > 0 {
continue
}
name = t
continue Pattern
}
if err != nil {
return false, err
}
}
}
// Before returning false with no error,
// check that the remainder of the pattern is syntactically valid.
for len(pattern) > 0 {
_, chunk, pattern = scanChunk(pattern)
if _, _, err := matchChunk(chunk, ""); err != nil {
return false, err
}
}
return false, nil
}
return len(name) == 0, nil
}
// scanChunk gets the next segment of pattern, which is a non-star string
// possibly preceded by a star.
func scanChunk(pattern string) (star bool, chunk, rest string) {
for len(pattern) > 0 && pattern[0] == '*' {
pattern = pattern[1:]
star = true
}
inrange := false
var i int
Scan:
for i = 0; i < len(pattern); i++ {
switch pattern[i] {
case '\\':
// error check handled in matchChunk: bad pattern.
if i+1 < len(pattern) {
i++
}
case '[':
inrange = true
case ']':
inrange = false
case '*':
if !inrange {
break Scan
}
}
}
return star, pattern[0:i], pattern[i:]
}
// matchChunk checks whether chunk matches the beginning of s.
// If so, it returns the remainder of s (after the match).
// Chunk is all single-character operators: literals, char classes, and ?.
func matchChunk(chunk, s string) (rest string, ok bool, err error) {
// failed records whether the match has failed.
// After the match fails, the loop continues on processing chunk,
// checking that the pattern is well-formed but no longer reading s.
failed := false
for len(chunk) > 0 {
if !failed && len(s) == 0 {
failed = true
}
switch chunk[0] {
case '[':
// character class
var r rune
if !failed {
var n int
r, n = utf8.DecodeRuneInString(s)
s = s[n:]
}
chunk = chunk[1:]
// possibly negated
negated := false
if len(chunk) > 0 && chunk[0] == '^' {
negated = true
chunk = chunk[1:]
}
// parse all ranges
match := false
nrange := 0
for {
if len(chunk) > 0 && chunk[0] == ']' && nrange > 0 {
chunk = chunk[1:]
break
}
var lo, hi rune
if lo, chunk, err = getEsc(chunk); err != nil {
return "", false, err
}
hi = lo
if chunk[0] == '-' {
if hi, chunk, err = getEsc(chunk[1:]); err != nil {
return "", false, err
}
}
if lo <= r && r <= hi {
match = true
}
nrange++
}
if match == negated {
failed = true
}
case '?':
if !failed {
if s[0] == '/' {
failed = true
}
_, n := utf8.DecodeRuneInString(s)
s = s[n:]
}
chunk = chunk[1:]
case '\\':
chunk = chunk[1:]
if len(chunk) == 0 {
return "", false, ErrBadPattern
}
fallthrough
default:
if !failed {
if chunk[0] != s[0] {
failed = true
}
s = s[1:]
}
chunk = chunk[1:]
}
}
if failed {
return "", false, nil
}
return s, true, nil
}
// getEsc gets a possibly-escaped character from chunk, for a character class.
func getEsc(chunk string) (r rune, nchunk string, err error) {
if len(chunk) == 0 || chunk[0] == '-' || chunk[0] == ']' {
err = ErrBadPattern
return
}
if chunk[0] == '\\' {
chunk = chunk[1:]
if len(chunk) == 0 {
err = ErrBadPattern
return
}
}
r, n := utf8.DecodeRuneInString(chunk)
if r == utf8.RuneError && n == 1 {
err = ErrBadPattern
}
nchunk = chunk[n:]
if len(nchunk) == 0 {
err = ErrBadPattern
}
return
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package path implements utility routines for manipulating slash-separated
// paths.
//
// The path package should only be used for paths separated by forward
// slashes, such as the paths in URLs. This package does not deal with
// Windows paths with drive letters or backslashes; to manipulate
// operating system paths, use the path/filepath package.
package path
// A lazybuf is a lazily constructed path buffer.
// It supports append, reading previously appended bytes,
// and retrieving the final string. It does not allocate a buffer
// to hold the output until that output diverges from s.
type lazybuf struct {
s string
buf []byte
w int
}
func (b *lazybuf) index(i int) byte {
if b.buf != nil {
return b.buf[i]
}
return b.s[i]
}
func (b *lazybuf) append(c byte) {
if b.buf == nil {
if b.w < len(b.s) && b.s[b.w] == c {
b.w++
return
}
b.buf = make([]byte, len(b.s))
copy(b.buf, b.s[:b.w])
}
b.buf[b.w] = c
b.w++
}
func (b *lazybuf) string() string {
if b.buf == nil {
return b.s[:b.w]
}
return string(b.buf[:b.w])
}
// Clean returns the shortest path name equivalent to path
// by purely lexical processing. It applies the following rules
// iteratively until no further processing can be done:
//
// 1. Replace multiple slashes with a single slash.
// 2. Eliminate each . path name element (the current directory).
// 3. Eliminate each inner .. path name element (the parent directory)
// along with the non-.. element that precedes it.
// 4. Eliminate .. elements that begin a rooted path:
// that is, replace "/.." by "/" at the beginning of a path.
//
// The returned path ends in a slash only if it is the root "/".
//
// If the result of this process is an empty string, Clean
// returns the string ".".
//
// See also Rob Pike, “Lexical File Names in Plan 9 or
// Getting Dot-Dot Right,”
// https://9p.io/sys/doc/lexnames.html
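//
// For example (an illustrative sketch):
//
//	path.Clean("a//b/./../c") // "a/c"
//	path.Clean("")            // "."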
func Clean(path string) string {
if path == "" {
return "."
}
rooted := path[0] == '/'
n := len(path)
// Invariants:
// reading from path; r is index of next byte to process.
// writing to buf; w is index of next byte to write.
// dotdot is index in buf where .. must stop, either because
// it is the leading slash or it is a leading ../../.. prefix.
out := lazybuf{s: path}
r, dotdot := 0, 0
if rooted {
out.append('/')
r, dotdot = 1, 1
}
for r < n {
switch {
case path[r] == '/':
// empty path element
r++
case path[r] == '.' && (r+1 == n || path[r+1] == '/'):
// . element
r++
case path[r] == '.' && path[r+1] == '.' && (r+2 == n || path[r+2] == '/'):
// .. element: remove to last /
r += 2
switch {
case out.w > dotdot:
// can backtrack
out.w--
for out.w > dotdot && out.index(out.w) != '/' {
out.w--
}
case !rooted:
// cannot backtrack, but not rooted, so append .. element.
if out.w > 0 {
out.append('/')
}
out.append('.')
out.append('.')
dotdot = out.w
}
default:
// real path element.
// add slash if needed
if rooted && out.w != 1 || !rooted && out.w != 0 {
out.append('/')
}
// copy element
for ; r < n && path[r] != '/'; r++ {
out.append(path[r])
}
}
}
// Turn empty string into "."
if out.w == 0 {
return "."
}
return out.string()
}
// lastSlash(s) is strings.LastIndex(s, "/") but we can't import strings.
func lastSlash(s string) int {
i := len(s) - 1
for i >= 0 && s[i] != '/' {
i--
}
return i
}
// Split splits path immediately following the final slash,
// separating it into a directory and file name component.
// If there is no slash in path, Split returns an empty dir and
// file set to path.
// The returned values have the property that path = dir+file.
func Split(path string) (dir, file string) {
i := lastSlash(path)
return path[:i+1], path[i+1:]
}
// Join joins any number of path elements into a single path,
// separating them with slashes. Empty elements are ignored.
// The result is Cleaned. However, if the argument list is
// empty or all its elements are empty, Join returns
// an empty string.
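//
// For example (an illustrative sketch):
//
//	path.Join("a", "b", "c") // "a/b/c"
//	path.Join("a", "", "c")  // "a/c"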
func Join(elem ...string) string {
size := 0
for _, e := range elem {
size += len(e)
}
if size == 0 {
return ""
}
buf := make([]byte, 0, size+len(elem)-1)
for _, e := range elem {
if len(buf) > 0 || e != "" {
if len(buf) > 0 {
buf = append(buf, '/')
}
buf = append(buf, e...)
}
}
return Clean(string(buf))
}
// Ext returns the file name extension used by path.
// The extension is the suffix beginning at the final dot
// in the final slash-separated element of path;
// it is empty if there is no dot.
func Ext(path string) string {
for i := len(path) - 1; i >= 0 && path[i] != '/'; i-- {
if path[i] == '.' {
return path[i:]
}
}
return ""
}
// Base returns the last element of path.
// Trailing slashes are removed before extracting the last element.
// If the path is empty, Base returns ".".
// If the path consists entirely of slashes, Base returns "/".
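//
// For example (an illustrative sketch):
//
//	path.Base("/a/b/") // "b"
//	path.Base("//")    // "/"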
func Base(path string) string {
if path == "" {
return "."
}
// Strip trailing slashes.
for len(path) > 0 && path[len(path)-1] == '/' {
path = path[0 : len(path)-1]
}
// Find the last element
if i := lastSlash(path); i >= 0 {
path = path[i+1:]
}
// If empty now, it had only slashes.
if path == "" {
return "/"
}
return path
}
// IsAbs reports whether the path is absolute.
func IsAbs(path string) bool {
return len(path) > 0 && path[0] == '/'
}
// Dir returns all but the last element of path, typically the path's directory.
// After dropping the final element using Split, the path is Cleaned and trailing
// slashes are removed.
// If the path is empty, Dir returns ".".
// If the path consists entirely of slashes followed by non-slash bytes, Dir
// returns a single slash. In any other case, the returned path does not end in a
// slash.
func Dir(path string) string {
dir, _ := Split(path)
return Clean(dir)
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package plugin implements loading and symbol resolution of Go plugins.
//
// A plugin is a Go main package with exported functions and variables that
// has been built with:
//
// go build -buildmode=plugin
//
// When a plugin is first opened, the init functions of all packages not
// already part of the program are called. The main function is not run.
// A plugin is only initialized once, and cannot be closed.
//
// # Warnings
//
// The ability to dynamically load parts of an application during
// execution, perhaps based on user-defined configuration, may be a
// useful building block in some designs. In particular, because
// applications and dynamically loaded functions can share data
// structures directly, plugins may enable very high-performance
// integration of separate parts.
//
// However, the plugin mechanism has many significant drawbacks that
// should be considered carefully during the design. For example:
//
// - Plugins are currently supported only on Linux, FreeBSD, and
// macOS, making them unsuitable for applications intended to be
// portable.
//
// - Applications that use plugins may require careful configuration
// to ensure that the various parts of the program be made available
// in the correct location in the file system (or container image).
// By contrast, deploying an application consisting of a single static
// executable is straightforward.
//
// - Reasoning about program initialization is more difficult when
// some packages may not be initialized until long after the
// application has started running.
//
// - Bugs in applications that load plugins could be exploited by
// an attacker to load dangerous or untrusted libraries.
//
// - Runtime crashes are likely to occur unless all parts of the
// program (the application and all its plugins) are compiled
// using exactly the same version of the toolchain, the same build
// tags, and the same values of certain flags and environment
// variables.
//
// - Similar crashing problems are likely to arise unless all common
// dependencies of the application and its plugins are built from
// exactly the same source code.
//
// - Together, these restrictions mean that, in practice, the
// application and its plugins must all be built together by a
// single person or component of a system. In that case, it may
// be simpler for that person or component to generate Go source
// files that blank-import the desired set of plugins and then
// compile a static executable in the usual way.
//
// For these reasons, many users decide that traditional interprocess
// communication (IPC) mechanisms such as sockets, pipes, remote
// procedure call (RPC), shared memory mappings, or file system
// operations may be more suitable despite the performance overheads.
package plugin
// Plugin is a loaded Go plugin.
type Plugin struct {
pluginpath string
err string // set if plugin failed to load
loaded chan struct{} // closed when loaded
syms map[string]any
}
// Open opens a Go plugin.
// If a path has already been opened, then the existing *Plugin is returned.
// It is safe for concurrent use by multiple goroutines.
func Open(path string) (*Plugin, error) {
return open(path)
}
// Lookup searches for a symbol named symName in plugin p.
// A symbol is any exported variable or function.
// It reports an error if the symbol is not found.
// It is safe for concurrent use by multiple goroutines.
func (p *Plugin) Lookup(symName string) (Symbol, error) {
return lookup(p, symName)
}
// A Symbol is a pointer to a variable or function.
//
// For example, a plugin defined as
//
// package main
//
// import "fmt"
//
// var V int
//
// func F() { fmt.Printf("Hello, number %d\n", V) }
//
// may be loaded with the Open function and then the exported package
// symbols V and F can be accessed
//
// p, err := plugin.Open("plugin_name.so")
// if err != nil {
// panic(err)
// }
// v, err := p.Lookup("V")
// if err != nil {
// panic(err)
// }
// f, err := p.Lookup("F")
// if err != nil {
// panic(err)
// }
// *v.(*int) = 7
// f.(func())() // prints "Hello, number 7"
type Symbol any
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build (linux && cgo) || (darwin && cgo) || (freebsd && cgo)
package plugin
/*
#cgo linux LDFLAGS: -ldl
#include <dlfcn.h>
#include <limits.h>
#include <stdlib.h>
#include <stdint.h>
#include <stdio.h>
static uintptr_t pluginOpen(const char* path, char** err) {
void* h = dlopen(path, RTLD_NOW|RTLD_GLOBAL);
if (h == NULL) {
*err = (char*)dlerror();
}
return (uintptr_t)h;
}
static void* pluginLookup(uintptr_t h, const char* name, char** err) {
void* r = dlsym((void*)h, name);
if (r == NULL) {
*err = (char*)dlerror();
}
return r;
}
*/
import "C"
import (
"errors"
"sync"
"unsafe"
)
func open(name string) (*Plugin, error) {
cPath := make([]byte, C.PATH_MAX+1)
cRelName := make([]byte, len(name)+1)
copy(cRelName, name)
if C.realpath(
(*C.char)(unsafe.Pointer(&cRelName[0])),
(*C.char)(unsafe.Pointer(&cPath[0]))) == nil {
return nil, errors.New(`plugin.Open("` + name + `"): realpath failed`)
}
filepath := C.GoString((*C.char)(unsafe.Pointer(&cPath[0])))
pluginsMu.Lock()
if p := plugins[filepath]; p != nil {
pluginsMu.Unlock()
if p.err != "" {
return nil, errors.New(`plugin.Open("` + name + `"): ` + p.err + ` (previous failure)`)
}
<-p.loaded
return p, nil
}
var cErr *C.char
h := C.pluginOpen((*C.char)(unsafe.Pointer(&cPath[0])), &cErr)
if h == 0 {
pluginsMu.Unlock()
return nil, errors.New(`plugin.Open("` + name + `"): ` + C.GoString(cErr))
}
// TODO(crawshaw): look for plugin note, confirm it is a Go plugin
// and it was built with the correct toolchain.
if len(name) > 3 && name[len(name)-3:] == ".so" {
name = name[:len(name)-3]
}
if plugins == nil {
plugins = make(map[string]*Plugin)
}
pluginpath, syms, initTasks, errstr := lastmoduleinit()
if errstr != "" {
plugins[filepath] = &Plugin{
pluginpath: pluginpath,
err: errstr,
}
pluginsMu.Unlock()
return nil, errors.New(`plugin.Open("` + name + `"): ` + errstr)
}
// This function can be called from the init function of a plugin.
// Drop a placeholder in the map so subsequent opens can wait on it.
p := &Plugin{
pluginpath: pluginpath,
loaded: make(chan struct{}),
}
plugins[filepath] = p
pluginsMu.Unlock()
doInit(initTasks)
// Fill out the value of each plugin symbol.
updatedSyms := map[string]any{}
for symName, sym := range syms {
isFunc := symName[0] == '.'
if isFunc {
delete(syms, symName)
symName = symName[1:]
}
fullName := pluginpath + "." + symName
cname := make([]byte, len(fullName)+1)
copy(cname, fullName)
p := C.pluginLookup(h, (*C.char)(unsafe.Pointer(&cname[0])), &cErr)
if p == nil {
return nil, errors.New(`plugin.Open("` + name + `"): could not find symbol ` + symName + `: ` + C.GoString(cErr))
}
valp := (*[2]unsafe.Pointer)(unsafe.Pointer(&sym))
if isFunc {
(*valp)[1] = unsafe.Pointer(&p)
} else {
(*valp)[1] = p
}
// We can't add to syms during iteration: we would end up processing
// some symbols twice, unable to tell whether a symbol is a function.
updatedSyms[symName] = sym
}
p.syms = updatedSyms
close(p.loaded)
return p, nil
}
func lookup(p *Plugin, symName string) (Symbol, error) {
if s := p.syms[symName]; s != nil {
return s, nil
}
return nil, errors.New("plugin: symbol " + symName + " not found in plugin " + p.pluginpath)
}
var (
pluginsMu sync.Mutex
plugins map[string]*Plugin
)
// lastmoduleinit is defined in package runtime.
func lastmoduleinit() (pluginpath string, syms map[string]any, initTasks []*initTask, errstr string)
// doInit is defined in package runtime.
//
//go:linkname doInit runtime.doInit
func doInit(t []*initTask)
type initTask struct {
// fields defined in runtime.initTask. We only handle pointers to an initTask
// in this package, so the contents are irrelevant.
}
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package reflect
import (
"internal/abi"
"internal/goarch"
"unsafe"
)
// These variables are used by the register assignment
// algorithm in this file.
//
// They should be modified with care (no other reflect code
// may be executing) and are generally only modified
// when testing this package.
//
// They should never be set higher than their internal/abi
// constant counterparts, because the system relies on a
// structure that is at least large enough to hold the
// registers the system supports.
//
// They are currently set to their internal/abi constant
// counterparts, so the register ABI is in effect. Setting them
// to zero forces every argument to be stack-assigned, which is
// occasionally useful when testing this package.
var (
intArgRegs = abi.IntArgRegs
floatArgRegs = abi.FloatArgRegs
floatRegSize = uintptr(abi.EffectiveFloatRegSize)
)
// abiStep represents an ABI "instruction." Each instruction
// describes one part of how to translate between a Go value
// in memory and a call frame.
type abiStep struct {
kind abiStepKind
// offset and size together describe a part of a Go value
// in memory.
offset uintptr
size uintptr // size in bytes of the part
// These fields describe the ABI side of the translation.
stkOff uintptr // stack offset, used if kind == abiStepStack
ireg int // integer register index, used if kind == abiStepIntReg or kind == abiStepPointer
freg int // FP register index, used if kind == abiStepFloatReg
}
// abiStepKind is the "op-code" for an abiStep instruction.
type abiStepKind int
const (
abiStepBad abiStepKind = iota
abiStepStack // copy to/from stack
abiStepIntReg // copy to/from integer register
abiStepPointer // copy pointer to/from integer register
abiStepFloatReg // copy to/from FP register
)
// abiSeq represents a sequence of ABI instructions for copying
// from a series of reflect.Values to a call frame (for call arguments)
// or vice-versa (for call results).
//
// An abiSeq should be populated by calling its addArg method.
type abiSeq struct {
// steps is the set of instructions.
//
// The instructions are grouped together by whole arguments,
// with the starting index for the instructions
// of the i'th Go value available in valueStart.
//
// For instance, if this abiSeq represents 3 arguments
// passed to a function, then the 2nd argument's steps
// begin at steps[valueStart[1]].
//
// Because reflect accepts Go arguments in distinct
// Values and each Value is stored separately, each abiStep
// that begins a new argument will have its offset
// field == 0.
steps []abiStep
valueStart []int
stackBytes uintptr // stack space used
iregs, fregs int // registers used
}
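// As a hypothetical illustration of the layout above: if steps is
// {s0, s1, s2, s3} and valueStart is {0, 1, 3}, then argument 0 is
// described by steps[0:1], argument 1 by steps[1:3], and argument 2
// by steps[3:4]; stepsForValue below computes exactly these slices.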
func (a *abiSeq) dump() {
for i, p := range a.steps {
println("part", i, p.kind, p.offset, p.size, p.stkOff, p.ireg, p.freg)
}
print("values ")
for _, i := range a.valueStart {
print(i, " ")
}
println()
println("stack", a.stackBytes)
println("iregs", a.iregs)
println("fregs", a.fregs)
}
// stepsForValue returns the ABI instructions for translating
// the i'th Go argument or return value represented by this
// abiSeq to the Go ABI.
func (a *abiSeq) stepsForValue(i int) []abiStep {
s := a.valueStart[i]
var e int
if i == len(a.valueStart)-1 {
e = len(a.steps)
} else {
e = a.valueStart[i+1]
}
return a.steps[s:e]
}
// addArg extends the abiSeq with a new Go value of type t.
//
// If the value was stack-assigned, returns the single
// abiStep describing that translation, and nil otherwise.
func (a *abiSeq) addArg(t *rtype) *abiStep {
// We'll always be adding a new value, so do that first.
pStart := len(a.steps)
a.valueStart = append(a.valueStart, pStart)
if t.size == 0 {
// If the size of the argument type is zero, then
// in order to degrade gracefully into ABI0, we need
// to stack-assign this type. The reason is that
// although zero-sized types take up no space on the
// stack, they do cause the next argument to be aligned.
// So just do that here, but don't bother actually
// generating a new ABI step for it (there's nothing to
// actually copy).
//
// We cannot handle this in the recursive case of
// regAssign because zero-sized *fields* of a
// non-zero-sized struct do not cause it to be
// stack-assigned. So we need a special case here
// at the top.
a.stackBytes = align(a.stackBytes, uintptr(t.align))
return nil
}
// Hold a copy of "a" so that we can roll back if
// register assignment fails.
aOld := *a
if !a.regAssign(t, 0) {
// Register assignment failed. Roll back any changes
// and stack-assign.
*a = aOld
a.stackAssign(t.size, uintptr(t.align))
return &a.steps[len(a.steps)-1]
}
return nil
}
// addRcvr extends the abiSeq with a new method call
// receiver according to the interface calling convention.
//
// If the receiver was stack-assigned, returns the single
// abiStep describing that translation, and nil otherwise.
// Returns true if the receiver is a pointer.
func (a *abiSeq) addRcvr(rcvr *rtype) (*abiStep, bool) {
// The receiver is always one word.
a.valueStart = append(a.valueStart, len(a.steps))
var ok, ptr bool
if ifaceIndir(rcvr) || rcvr.pointers() {
ok = a.assignIntN(0, goarch.PtrSize, 1, 0b1)
ptr = true
} else {
// TODO(mknyszek): Is this case even possible?
// The interface data word never contains a non-pointer
// value. This case was copied over from older code
// in the reflect package which only conditionally added
// a pointer bit to the reflect.(Value).Call stack frame's
// GC bitmap.
ok = a.assignIntN(0, goarch.PtrSize, 1, 0b0)
ptr = false
}
if !ok {
a.stackAssign(goarch.PtrSize, goarch.PtrSize)
return &a.steps[len(a.steps)-1], ptr
}
return nil, ptr
}
// regAssign attempts to reserve argument registers for a value of
// type t, stored at some offset.
//
// It returns whether or not the assignment succeeded, but
// leaves any changes it made to a.steps behind, so the caller
// must undo that work by adjusting a.steps if it fails.
//
// This method along with the assign* methods represent the
// complete register-assignment algorithm for the Go ABI.
func (a *abiSeq) regAssign(t *rtype, offset uintptr) bool {
switch t.Kind() {
case UnsafePointer, Pointer, Chan, Map, Func:
return a.assignIntN(offset, t.size, 1, 0b1)
case Bool, Int, Uint, Int8, Uint8, Int16, Uint16, Int32, Uint32, Uintptr:
return a.assignIntN(offset, t.size, 1, 0b0)
case Int64, Uint64:
switch goarch.PtrSize {
case 4:
return a.assignIntN(offset, 4, 2, 0b0)
case 8:
return a.assignIntN(offset, 8, 1, 0b0)
}
case Float32, Float64:
return a.assignFloatN(offset, t.size, 1)
case Complex64:
return a.assignFloatN(offset, 4, 2)
case Complex128:
return a.assignFloatN(offset, 8, 2)
case String:
return a.assignIntN(offset, goarch.PtrSize, 2, 0b01)
case Interface:
return a.assignIntN(offset, goarch.PtrSize, 2, 0b10)
case Slice:
return a.assignIntN(offset, goarch.PtrSize, 3, 0b001)
case Array:
tt := (*arrayType)(unsafe.Pointer(t))
switch tt.len {
case 0:
// There's nothing to assign, so don't modify
// a.steps but succeed so the caller doesn't
// try to stack-assign this value.
return true
case 1:
return a.regAssign(tt.elem, offset)
default:
return false
}
case Struct:
st := (*structType)(unsafe.Pointer(t))
for i := range st.fields {
f := &st.fields[i]
if !a.regAssign(f.typ, offset+f.offset) {
return false
}
}
return true
default:
print("t.Kind == ", t.Kind(), "\n")
panic("unknown type kind")
}
panic("unhandled register assignment path")
}
// assignIntN assigns n values to registers, each "size" bytes large,
// from the data at [offset, offset+n*size) in memory. Each value at
// [offset+i*size, offset+(i+1)*size) for i < n is assigned to the
// next available integer register.
//
// Bit i in ptrMap indicates whether the i'th value is a pointer.
// n must be <= 8.
//
// Returns whether assignment succeeded.
func (a *abiSeq) assignIntN(offset, size uintptr, n int, ptrMap uint8) bool {
if n > 8 || n < 0 {
panic("invalid n")
}
if ptrMap != 0 && size != goarch.PtrSize {
panic("non-empty pointer map passed for non-pointer-size values")
}
if a.iregs+n > intArgRegs {
return false
}
for i := 0; i < n; i++ {
kind := abiStepIntReg
if ptrMap&(uint8(1)<<i) != 0 {
kind = abiStepPointer
}
a.steps = append(a.steps, abiStep{
kind: kind,
offset: offset + uintptr(i)*size,
size: size,
ireg: a.iregs,
})
a.iregs++
}
return true
}
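// As an illustrative case (no behavior beyond the above): regAssign
// for a String calls assignIntN(offset, goarch.PtrSize, 2, 0b01),
// which emits an abiStepPointer step for the data word (ptrMap bit 0
// is set) followed by an abiStepIntReg step for the length word.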
// assignFloatN assigns n values to registers, each "size" bytes large,
// from the data at [offset, offset+n*size) in memory. Each value at
// [offset+i*size, offset+(i+1)*size) for i < n is assigned to the
// next available floating-point register.
//
// Returns whether assignment succeeded.
func (a *abiSeq) assignFloatN(offset, size uintptr, n int) bool {
if n < 0 {
panic("invalid n")
}
if a.fregs+n > floatArgRegs || floatRegSize < size {
return false
}
for i := 0; i < n; i++ {
a.steps = append(a.steps, abiStep{
kind: abiStepFloatReg,
offset: offset + uintptr(i)*size,
size: size,
freg: a.fregs,
})
a.fregs++
}
return true
}
// stackAssign reserves stack space for one value that is "size"
// bytes large and has alignment "alignment".
//
// Should not be called directly; use addArg instead.
func (a *abiSeq) stackAssign(size, alignment uintptr) {
a.stackBytes = align(a.stackBytes, alignment)
a.steps = append(a.steps, abiStep{
kind: abiStepStack,
offset: 0, // Only used for whole arguments, so the memory offset is 0.
size: size,
stkOff: a.stackBytes,
})
a.stackBytes += size
}
// abiDesc describes the ABI for a function or method.
type abiDesc struct {
// call and ret represent the translation steps for
// the call and return paths of a Go function.
call, ret abiSeq
// These fields describe the stack space allocated
// for the call. stackCallArgsSize is the amount of space
// reserved for arguments but not return values. retOffset
// is the offset at which return values begin, and
// spill is the size in bytes of additional space reserved
// to spill argument registers into in case of preemption in
// reflectcall's stack frame.
stackCallArgsSize, retOffset, spill uintptr
// stackPtrs is a bitmap that indicates whether
// each word in the ABI stack space (stack-assigned
// args + return values) is a pointer. Used
// as the heap pointer bitmap for stack space
// passed to reflectcall.
stackPtrs *bitVector
// inRegPtrs is a bitmap whose i'th bit indicates
// whether the i'th integer argument register contains
// a pointer. Used by makeFuncStub and methodValueCall
// to make result pointers visible to the GC.
//
// outRegPtrs is the same, but for result values.
// Used by reflectcall to make result pointers visible
// to the GC.
inRegPtrs, outRegPtrs abi.IntArgRegBitmap
}
func (a *abiDesc) dump() {
println("ABI")
println("call")
a.call.dump()
println("ret")
a.ret.dump()
println("stackCallArgsSize", a.stackCallArgsSize)
println("retOffset", a.retOffset)
println("spill", a.spill)
print("inRegPtrs:")
dumpPtrBitMap(a.inRegPtrs)
println()
print("outRegPtrs:")
dumpPtrBitMap(a.outRegPtrs)
println()
}
func dumpPtrBitMap(b abi.IntArgRegBitmap) {
for i := 0; i < intArgRegs; i++ {
x := 0
if b.Get(i) {
x = 1
}
print(" ", x)
}
}
func newAbiDesc(t *funcType, rcvr *rtype) abiDesc {
// We need to reserve spill space in the frame so that
// register-assigned arguments can be spilled into it.
//
// The size of this space is just the sum of the sizes
// of each register-allocated type.
//
// TODO(mknyszek): Remove this when we no longer have
// caller reserved spill space.
spill := uintptr(0)
// Compute gc program & stack bitmap for stack arguments
stackPtrs := new(bitVector)
// Compute the stack frame pointer bitmap and register
// pointer bitmap for arguments.
inRegPtrs := abi.IntArgRegBitmap{}
// Compute abiSeq for input parameters.
var in abiSeq
if rcvr != nil {
stkStep, isPtr := in.addRcvr(rcvr)
if stkStep != nil {
if isPtr {
stackPtrs.append(1)
} else {
stackPtrs.append(0)
}
} else {
spill += goarch.PtrSize
}
}
for i, arg := range t.in() {
stkStep := in.addArg(arg)
if stkStep != nil {
addTypeBits(stackPtrs, stkStep.stkOff, arg)
} else {
spill = align(spill, uintptr(arg.align))
spill += arg.size
for _, st := range in.stepsForValue(i) {
if st.kind == abiStepPointer {
inRegPtrs.Set(st.ireg)
}
}
}
}
spill = align(spill, goarch.PtrSize)
// From the input parameters alone, we now know
// the stackCallArgsSize and retOffset.
stackCallArgsSize := in.stackBytes
retOffset := align(in.stackBytes, goarch.PtrSize)
// Compute the stack frame pointer bitmap and register
// pointer bitmap for return values.
outRegPtrs := abi.IntArgRegBitmap{}
// Compute abiSeq for output parameters.
var out abiSeq
// Stack-assigned return values do not share
// space with arguments like they do with registers,
// so we need to inject a stack offset here.
// Fake it by artificially extending stackBytes by
// the return offset.
out.stackBytes = retOffset
for i, res := range t.out() {
stkStep := out.addArg(res)
if stkStep != nil {
addTypeBits(stackPtrs, stkStep.stkOff, res)
} else {
for _, st := range out.stepsForValue(i) {
if st.kind == abiStepPointer {
outRegPtrs.Set(st.ireg)
}
}
}
}
// Undo the faking from earlier so that stackBytes
// is accurate.
out.stackBytes -= retOffset
return abiDesc{in, out, stackCallArgsSize, retOffset, spill, stackPtrs, inRegPtrs, outRegPtrs}
}
// intFromReg loads an argSize sized integer from reg and places it at to.
//
// argSize must be non-zero, fit in a register, and be a power of two.
func intFromReg(r *abi.RegArgs, reg int, argSize uintptr, to unsafe.Pointer) {
memmove(to, r.IntRegArgAddr(reg, argSize), argSize)
}
// intToReg loads an argSize sized integer and stores it into reg.
//
// argSize must be non-zero, fit in a register, and be a power of two.
func intToReg(r *abi.RegArgs, reg int, argSize uintptr, from unsafe.Pointer) {
memmove(r.IntRegArgAddr(reg, argSize), from, argSize)
}
// floatFromReg loads a float value from its register representation in r.
//
// argSize must be 4 or 8.
func floatFromReg(r *abi.RegArgs, reg int, argSize uintptr, to unsafe.Pointer) {
switch argSize {
case 4:
*(*float32)(to) = archFloat32FromReg(r.Floats[reg])
case 8:
*(*float64)(to) = *(*float64)(unsafe.Pointer(&r.Floats[reg]))
default:
panic("bad argSize")
}
}
// floatToReg stores a float value in its register representation in r.
//
// argSize must be either 4 or 8.
func floatToReg(r *abi.RegArgs, reg int, argSize uintptr, from unsafe.Pointer) {
switch argSize {
case 4:
r.Floats[reg] = archFloat32ToReg(*(*float32)(from))
case 8:
r.Floats[reg] = *(*uint64)(from)
default:
panic("bad argSize")
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Deep equality test via reflection
package reflect
import (
"internal/bytealg"
"unsafe"
)
// During deepValueEqual, we must keep track of checks that are
// in progress. The comparison algorithm assumes that all
// checks in progress are true when it reencounters them.
// Visited comparisons are stored in a map indexed by visit.
type visit struct {
a1 unsafe.Pointer
a2 unsafe.Pointer
typ Type
}
// Tests for deep equality using reflected types. The map argument tracks
// comparisons that have already been seen, which allows short circuiting on
// recursive types.
func deepValueEqual(v1, v2 Value, visited map[visit]bool) bool {
if !v1.IsValid() || !v2.IsValid() {
return v1.IsValid() == v2.IsValid()
}
if v1.Type() != v2.Type() {
return false
}
// We want to avoid putting more in the visited map than we need to.
// For any possible reference cycle that might be encountered,
// hard(v1, v2) needs to return true for at least one of the types in the cycle,
// and it's safe and valid to get Value's internal pointer.
hard := func(v1, v2 Value) bool {
switch v1.Kind() {
case Pointer:
if v1.typ.ptrdata == 0 {
// not-in-heap pointers can't be cyclic.
// At least, all of our current uses of runtime/internal/sys.NotInHeap
// have that property. The runtime ones aren't cyclic (and we don't use
// DeepEqual on them anyway), and the cgo-generated ones are
// all empty structs.
return false
}
fallthrough
case Map, Slice, Interface:
// Nil pointers cannot be cyclic. Avoid putting them in the visited map.
return !v1.IsNil() && !v2.IsNil()
}
return false
}
if hard(v1, v2) {
// For a Pointer or Map value, we need to check flagIndir,
// which we do by calling the pointer method.
// For Slice or Interface, flagIndir is always set,
// and using v.ptr suffices.
ptrval := func(v Value) unsafe.Pointer {
switch v.Kind() {
case Pointer, Map:
return v.pointer()
default:
return v.ptr
}
}
addr1 := ptrval(v1)
addr2 := ptrval(v2)
if uintptr(addr1) > uintptr(addr2) {
// Canonicalize order to reduce number of entries in visited.
// Assumes non-moving garbage collector.
addr1, addr2 = addr2, addr1
}
// Short circuit if references are already seen.
typ := v1.Type()
v := visit{addr1, addr2, typ}
if visited[v] {
return true
}
// Remember for later.
visited[v] = true
}
switch v1.Kind() {
case Array:
for i := 0; i < v1.Len(); i++ {
if !deepValueEqual(v1.Index(i), v2.Index(i), visited) {
return false
}
}
return true
case Slice:
if v1.IsNil() != v2.IsNil() {
return false
}
if v1.Len() != v2.Len() {
return false
}
if v1.UnsafePointer() == v2.UnsafePointer() {
return true
}
// Special case for []byte, which is common.
if v1.Type().Elem().Kind() == Uint8 {
return bytealg.Equal(v1.Bytes(), v2.Bytes())
}
for i := 0; i < v1.Len(); i++ {
if !deepValueEqual(v1.Index(i), v2.Index(i), visited) {
return false
}
}
return true
case Interface:
if v1.IsNil() || v2.IsNil() {
return v1.IsNil() == v2.IsNil()
}
return deepValueEqual(v1.Elem(), v2.Elem(), visited)
case Pointer:
if v1.UnsafePointer() == v2.UnsafePointer() {
return true
}
return deepValueEqual(v1.Elem(), v2.Elem(), visited)
case Struct:
for i, n := 0, v1.NumField(); i < n; i++ {
if !deepValueEqual(v1.Field(i), v2.Field(i), visited) {
return false
}
}
return true
case Map:
if v1.IsNil() != v2.IsNil() {
return false
}
if v1.Len() != v2.Len() {
return false
}
if v1.UnsafePointer() == v2.UnsafePointer() {
return true
}
for _, k := range v1.MapKeys() {
val1 := v1.MapIndex(k)
val2 := v2.MapIndex(k)
if !val1.IsValid() || !val2.IsValid() || !deepValueEqual(val1, val2, visited) {
return false
}
}
return true
case Func:
if v1.IsNil() && v2.IsNil() {
return true
}
// Can't do better than this:
return false
case Int, Int8, Int16, Int32, Int64:
return v1.Int() == v2.Int()
case Uint, Uint8, Uint16, Uint32, Uint64, Uintptr:
return v1.Uint() == v2.Uint()
case String:
return v1.String() == v2.String()
case Bool:
return v1.Bool() == v2.Bool()
case Float32, Float64:
return v1.Float() == v2.Float()
case Complex64, Complex128:
return v1.Complex() == v2.Complex()
default:
// Normal equality suffices
return valueInterface(v1, false) == valueInterface(v2, false)
}
}
// DeepEqual reports whether x and y are “deeply equal,” defined as follows.
// Two values of identical type are deeply equal if one of the following cases applies.
// Values of distinct types are never deeply equal.
//
// Array values are deeply equal when their corresponding elements are deeply equal.
//
// Struct values are deeply equal if their corresponding fields,
// both exported and unexported, are deeply equal.
//
// Func values are deeply equal if both are nil; otherwise they are not deeply equal.
//
// Interface values are deeply equal if they hold deeply equal concrete values.
//
// Map values are deeply equal when all of the following are true:
// they are both nil or both non-nil, they have the same length,
// and either they are the same map object or their corresponding keys
// (matched using Go equality) map to deeply equal values.
//
// Pointer values are deeply equal if they are equal using Go's == operator
// or if they point to deeply equal values.
//
// Slice values are deeply equal when all of the following are true:
// they are both nil or both non-nil, they have the same length,
// and either they point to the same initial entry of the same underlying array
// (that is, &x[0] == &y[0]) or their corresponding elements (up to length) are deeply equal.
// Note that a non-nil empty slice and a nil slice (for example, []byte{} and []byte(nil))
// are not deeply equal.
//
// Other values - numbers, bools, strings, and channels - are deeply equal
// if they are equal using Go's == operator.
//
// In general DeepEqual is a recursive relaxation of Go's == operator.
// However, this idea is impossible to implement without some inconsistency.
// Specifically, it is possible for a value to be unequal to itself,
// either because it is of func type (uncomparable in general)
// or because it is a floating-point NaN value (not equal to itself in floating-point comparison),
// or because it is an array, struct, or interface containing
// such a value.
// On the other hand, pointer values are always equal to themselves,
// even if they point at or contain such problematic values,
// because they compare equal using Go's == operator, and that
// is a sufficient condition to be deeply equal, regardless of content.
// DeepEqual has been defined so that the same short-cut applies
// to slices and maps: if x and y are the same slice or the same map,
// they are deeply equal regardless of content.
//
// As DeepEqual traverses the data values it may find a cycle. The
// second and subsequent times that DeepEqual compares two pointer
// values that have been compared before, it treats the values as
// equal rather than examining the values to which they point.
// This ensures that DeepEqual terminates.
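//
// A minimal usage sketch (hypothetical values illustrating the rules
// above; assumes "math" is imported):
//
// f := math.NaN()
// DeepEqual(f, f) // false: NaN is not equal to itself
// DeepEqual(&f, &f) // true: the same-pointer short-cut applies
// DeepEqual([]byte{}, []byte(nil)) // false: non-nil empty vs. nil slice
// DeepEqual(map[string]int{"a": 1}, map[string]int{"a": 1}) // true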
func DeepEqual(x, y any) bool {
if x == nil || y == nil {
return x == y
}
v1 := ValueOf(x)
v2 := ValueOf(y)
if v1.Type() != v2.Type() {
return false
}
return deepValueEqual(v1, v2, make(map[visit]bool))
}
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !ppc64 && !ppc64le && !riscv64
package reflect
import "unsafe"
// This file implements a straightforward conversion of a float32
// value into its representation in a register. This conversion
// applies to amd64 and arm64. It is also chosen when there are zero
// floating-point argument registers, but in that case it is unused.
func archFloat32FromReg(reg uint64) float32 {
i := uint32(reg)
return *(*float32)(unsafe.Pointer(&i))
}
func archFloat32ToReg(val float32) uint64 {
return uint64(*(*uint32)(unsafe.Pointer(&val)))
}
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// MakeFunc implementation.
package reflect
import (
"internal/abi"
"unsafe"
)
// makeFuncImpl is the closure value implementing the function
// returned by MakeFunc.
// The first three words of this type must be kept in sync with
// methodValue and runtime.reflectMethodValue.
// Any changes should be reflected in all three.
type makeFuncImpl struct {
makeFuncCtxt
ftyp *funcType
fn func([]Value) []Value
}
// MakeFunc returns a new function of the given Type
// that wraps the function fn. When called, that new function
// does the following:
//
// - converts its arguments to a slice of Values.
// - runs results := fn(args).
// - returns the results as a slice of Values, one per formal result.
//
// The implementation fn can assume that the argument Value slice
// has the number and type of arguments given by typ.
// If typ describes a variadic function, the final Value is itself
// a slice representing the variadic arguments, as in the
// body of a variadic function. The result Value slice returned by fn
// must have the number and type of results given by typ.
//
// The Value.Call method allows the caller to invoke a typed function
// in terms of Values; in contrast, MakeFunc allows the caller to implement
// a typed function in terms of Values.
//
// The Examples section of the documentation includes an illustration
// of how to use MakeFunc to build a swap function for different types.
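//
// A minimal sketch (a hypothetical wrapper, not the package's Example):
//
// negate := MakeFunc(TypeOf(func(int) int { return 0 }), func(args []Value) []Value {
// return []Value{ValueOf(-int(args[0].Int()))}
// })
// fn := negate.Interface().(func(int) int)
// fn(2) // returns -2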
func MakeFunc(typ Type, fn func(args []Value) (results []Value)) Value {
if typ.Kind() != Func {
panic("reflect: call of MakeFunc with non-Func type")
}
t := typ.common()
ftyp := (*funcType)(unsafe.Pointer(t))
code := abi.FuncPCABI0(makeFuncStub)
// makeFuncImpl contains a stack map for use by the runtime
_, _, abid := funcLayout(ftyp, nil)
impl := &makeFuncImpl{
makeFuncCtxt: makeFuncCtxt{
fn: code,
stack: abid.stackPtrs,
argLen: abid.stackCallArgsSize,
regPtrs: abid.inRegPtrs,
},
ftyp: ftyp,
fn: fn,
}
return Value{t, unsafe.Pointer(impl), flag(Func)}
}
// makeFuncStub is an assembly function that is the code half of
// the function returned from MakeFunc. It expects a *makeFuncImpl
// as its context register, and its job is to invoke callReflect(ctxt, frame)
// where ctxt is the context register and frame is a pointer to the first
// word in the passed-in argument frame.
func makeFuncStub()
// The first 3 words of this type must be kept in sync with
// makeFuncImpl and runtime.reflectMethodValue.
// Any changes should be reflected in all three.
type methodValue struct {
makeFuncCtxt
method int
rcvr Value
}
// makeMethodValue converts v from the rcvr+method index representation
// of a method value (which is basically the receiver value with a
// special bit set) into a true func value - a value holding an actual
// func. The output is semantically equivalent to the input as far as
// the user of package reflect can tell, but the true func
// representation can be handled by code like Convert and Interface
// and Assign.
func makeMethodValue(op string, v Value) Value {
if v.flag&flagMethod == 0 {
panic("reflect: internal error: invalid use of makeMethodValue")
}
// Ignoring the flagMethod bit, v describes the receiver, not the method type.
fl := v.flag & (flagRO | flagAddr | flagIndir)
fl |= flag(v.typ.Kind())
rcvr := Value{v.typ, v.ptr, fl}
// v.Type returns the actual type of the method value.
ftyp := (*funcType)(unsafe.Pointer(v.Type().(*rtype)))
code := methodValueCallCodePtr()
// methodValue contains a stack map for use by the runtime
_, _, abid := funcLayout(ftyp, nil)
fv := &methodValue{
makeFuncCtxt: makeFuncCtxt{
fn: code,
stack: abid.stackPtrs,
argLen: abid.stackCallArgsSize,
regPtrs: abid.inRegPtrs,
},
method: int(v.flag) >> flagMethodShift,
rcvr: rcvr,
}
// Cause panic if method is not appropriate.
// The panic would still happen during the call if we omit this,
// but we want Interface() and other operations to fail early.
methodReceiver(op, fv.rcvr, fv.method)
return Value{&ftyp.rtype, unsafe.Pointer(fv), v.flag&flagRO | flag(Func)}
}
func methodValueCallCodePtr() uintptr {
return abi.FuncPCABI0(methodValueCall)
}
// methodValueCall is an assembly function that is the code half of
// the function returned from makeMethodValue. It expects a *methodValue
// as its context register, and its job is to invoke callMethod(ctxt, frame)
// where ctxt is the context register and frame is a pointer to the first
// word in the passed-in argument frame.
func methodValueCall()
// This structure must be kept in sync with runtime.reflectMethodValue.
// Any changes should be reflected in both.
type makeFuncCtxt struct {
fn uintptr
stack *bitVector // ptrmap for both stack args and results
argLen uintptr // just args
regPtrs abi.IntArgRegBitmap
}
// moveMakeFuncArgPtrs uses ctxt.regPtrs to copy integer pointer arguments
// in args.Ints to args.Ptrs where the GC can see them.
//
// This is similar to what reflectcallmove does in the runtime, except
// that happens on the return path, whereas this happens on the call path.
//
// nosplit because pointers are being held in uintptr slots in args, so
// having our stack scanned now could lead to accidentally freeing
// memory.
//
//go:nosplit
func moveMakeFuncArgPtrs(ctxt *makeFuncCtxt, args *abi.RegArgs) {
for i, arg := range args.Ints {
// Avoid write barriers! Because our write barrier enqueues what
// was there before, we might enqueue garbage.
if ctxt.regPtrs.Get(i) {
*(*uintptr)(unsafe.Pointer(&args.Ptrs[i])) = arg
} else {
// We *must* zero this space ourselves because it's defined in
// assembly code and the GC will scan these pointers. Otherwise,
// there will be garbage here.
*(*uintptr)(unsafe.Pointer(&args.Ptrs[i])) = 0
}
}
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package reflect
import (
"internal/goarch"
"internal/unsafeheader"
"unsafe"
)
// Swapper returns a function that swaps the elements in the provided
// slice.
//
// Swapper panics if the provided interface is not a slice.
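//
// A minimal usage sketch (hypothetical slice):
//
// data := []int{3, 1, 2}
// swap := Swapper(data)
// swap(0, 2) // data is now []int{2, 1, 3}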
func Swapper(slice any) func(i, j int) {
v := ValueOf(slice)
if v.Kind() != Slice {
panic(&ValueError{Method: "Swapper", Kind: v.Kind()})
}
// Fast path for slices of size 0 and 1. Nothing to swap.
switch v.Len() {
case 0:
return func(i, j int) { panic("reflect: slice index out of range") }
case 1:
return func(i, j int) {
if i != 0 || j != 0 {
panic("reflect: slice index out of range")
}
}
}
typ := v.Type().Elem().(*rtype)
size := typ.Size()
hasPtr := typ.ptrdata != 0
// Some common & small cases, without using memmove:
if hasPtr {
if size == goarch.PtrSize {
ps := *(*[]unsafe.Pointer)(v.ptr)
return func(i, j int) { ps[i], ps[j] = ps[j], ps[i] }
}
if typ.Kind() == String {
ss := *(*[]string)(v.ptr)
return func(i, j int) { ss[i], ss[j] = ss[j], ss[i] }
}
} else {
switch size {
case 8:
is := *(*[]int64)(v.ptr)
return func(i, j int) { is[i], is[j] = is[j], is[i] }
case 4:
is := *(*[]int32)(v.ptr)
return func(i, j int) { is[i], is[j] = is[j], is[i] }
case 2:
is := *(*[]int16)(v.ptr)
return func(i, j int) { is[i], is[j] = is[j], is[i] }
case 1:
is := *(*[]int8)(v.ptr)
return func(i, j int) { is[i], is[j] = is[j], is[i] }
}
}
s := (*unsafeheader.Slice)(v.ptr)
tmp := unsafe_New(typ) // swap scratch space
return func(i, j int) {
if uint(i) >= uint(s.Len) || uint(j) >= uint(s.Len) {
panic("reflect: slice index out of range")
}
val1 := arrayAt(s.Data, i, size, "i < s.Len")
val2 := arrayAt(s.Data, j, size, "j < s.Len")
typedmemmove(typ, tmp, val1)
typedmemmove(typ, val1, val2)
typedmemmove(typ, val2, tmp)
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package reflect implements run-time reflection, allowing a program to
// manipulate objects with arbitrary types. The typical use is to take a value
// with static type interface{} and extract its dynamic type information by
// calling TypeOf, which returns a Type.
//
// A call to ValueOf returns a Value representing the run-time data.
// Zero takes a Type and returns a Value representing a zero value
// for that type.
//
// See "The Laws of Reflection" for an introduction to reflection in Go:
// https://golang.org/doc/articles/laws_of_reflection.html
package reflect
import (
"internal/abi"
"internal/goarch"
"strconv"
"sync"
"unicode"
"unicode/utf8"
"unsafe"
)
// Type is the representation of a Go type.
//
// Not all methods apply to all kinds of types. Restrictions,
// if any, are noted in the documentation for each method.
// Use the Kind method to find out the kind of type before
// calling kind-specific methods. Calling a method
// inappropriate to the kind of type causes a run-time panic.
//
// Type values are comparable, such as with the == operator,
// so they can be used as map keys.
// Two Type values are equal if they represent identical types.
type Type interface {
// Methods applicable to all types.
// Align returns the alignment in bytes of a value of
// this type when allocated in memory.
Align() int
// FieldAlign returns the alignment in bytes of a value of
// this type when used as a field in a struct.
FieldAlign() int
// Method returns the i'th method in the type's method set.
// It panics if i is not in the range [0, NumMethod()).
//
// For a non-interface type T or *T, the returned Method's Type and Func
// fields describe a function whose first argument is the receiver,
// and only exported methods are accessible.
//
// For an interface type, the returned Method's Type field gives the
// method signature, without a receiver, and the Func field is nil.
//
// Methods are sorted in lexicographic order.
Method(int) Method
// MethodByName returns the method with that name in the type's
// method set and a boolean indicating if the method was found.
//
// For a non-interface type T or *T, the returned Method's Type and Func
// fields describe a function whose first argument is the receiver.
//
// For an interface type, the returned Method's Type field gives the
// method signature, without a receiver, and the Func field is nil.
MethodByName(string) (Method, bool)
// NumMethod returns the number of methods accessible using Method.
//
// For a non-interface type, it returns the number of exported methods.
//
// For an interface type, it returns the number of exported and unexported methods.
NumMethod() int
// Name returns the type's name within its package for a defined type.
// For other (non-defined) types it returns the empty string.
Name() string
// PkgPath returns a defined type's package path, that is, the import path
// that uniquely identifies the package, such as "encoding/base64".
// If the type was predeclared (string, error) or not defined (*T, struct{},
// []int, or A where A is an alias for a non-defined type), the package path
// will be the empty string.
PkgPath() string
// Size returns the number of bytes needed to store
// a value of the given type; it is analogous to unsafe.Sizeof.
Size() uintptr
// String returns a string representation of the type.
// The string representation may use shortened package names
// (e.g., base64 instead of "encoding/base64") and is not
// guaranteed to be unique among types. To test for type identity,
// compare the Types directly.
String() string
// Kind returns the specific kind of this type.
Kind() Kind
// Implements reports whether the type implements the interface type u.
Implements(u Type) bool
// AssignableTo reports whether a value of the type is assignable to type u.
AssignableTo(u Type) bool
// ConvertibleTo reports whether a value of the type is convertible to type u.
// Even if ConvertibleTo returns true, the conversion may still panic.
// For example, a slice of type []T is convertible to *[N]T,
// but the conversion will panic if its length is less than N.
ConvertibleTo(u Type) bool
// Comparable reports whether values of this type are comparable.
// Even if Comparable returns true, the comparison may still panic.
// For example, values of interface type are comparable,
// but the comparison will panic if their dynamic type is not comparable.
Comparable() bool
// Methods applicable only to some types, depending on Kind.
// The methods allowed for each kind are:
//
// Int*, Uint*, Float*, Complex*: Bits
// Array: Elem, Len
// Chan: ChanDir, Elem
// Func: In, NumIn, Out, NumOut, IsVariadic.
// Map: Key, Elem
// Pointer: Elem
// Slice: Elem
// Struct: Field, FieldByIndex, FieldByName, FieldByNameFunc, NumField
// Bits returns the size of the type in bits.
// It panics if the type's Kind is not one of the
// sized or unsized Int, Uint, Float, or Complex kinds.
Bits() int
// ChanDir returns a channel type's direction.
// It panics if the type's Kind is not Chan.
ChanDir() ChanDir
// IsVariadic reports whether a function type's final input parameter
// is a "..." parameter. If so, t.In(t.NumIn() - 1) returns the parameter's
// implicit actual type []T.
//
// For concreteness, if t represents func(x int, y ... float64), then
//
// t.NumIn() == 2
// t.In(0) is the reflect.Type for "int"
// t.In(1) is the reflect.Type for "[]float64"
// t.IsVariadic() == true
//
// IsVariadic panics if the type's Kind is not Func.
IsVariadic() bool
// Elem returns a type's element type.
// It panics if the type's Kind is not Array, Chan, Map, Pointer, or Slice.
Elem() Type
// Field returns a struct type's i'th field.
// It panics if the type's Kind is not Struct.
// It panics if i is not in the range [0, NumField()).
Field(i int) StructField
// FieldByIndex returns the nested field corresponding
// to the index sequence. It is equivalent to calling Field
// successively for each index i.
// It panics if the type's Kind is not Struct.
FieldByIndex(index []int) StructField
// FieldByName returns the struct field with the given name
// and a boolean indicating if the field was found.
FieldByName(name string) (StructField, bool)
// FieldByNameFunc returns the struct field with a name
// that satisfies the match function and a boolean indicating if
// the field was found.
//
// FieldByNameFunc considers the fields in the struct itself
// and then the fields in any embedded structs, in breadth first order,
// stopping at the shallowest nesting depth containing one or more
// fields satisfying the match function. If multiple fields at that depth
// satisfy the match function, they cancel each other
// and FieldByNameFunc returns no match.
// This behavior mirrors Go's handling of name lookup in
// structs containing embedded fields.
FieldByNameFunc(match func(string) bool) (StructField, bool)
// In returns the type of a function type's i'th input parameter.
// It panics if the type's Kind is not Func.
// It panics if i is not in the range [0, NumIn()).
In(i int) Type
// Key returns a map type's key type.
// It panics if the type's Kind is not Map.
Key() Type
// Len returns an array type's length.
// It panics if the type's Kind is not Array.
Len() int
// NumField returns a struct type's field count.
// It panics if the type's Kind is not Struct.
NumField() int
// NumIn returns a function type's input parameter count.
// It panics if the type's Kind is not Func.
NumIn() int
// NumOut returns a function type's output parameter count.
// It panics if the type's Kind is not Func.
NumOut() int
// Out returns the type of a function type's i'th output parameter.
// It panics if the type's Kind is not Func.
// It panics if i is not in the range [0, NumOut()).
Out(i int) Type
common() *rtype
uncommon() *uncommonType
}
// BUG(rsc): FieldByName and related functions consider struct field names to be equal
// if the names are equal, even if they are unexported names originating
// in different packages. The practical effect of this is that the result of
// t.FieldByName("x") is not well defined if the struct type t contains
// multiple fields named x (embedded from different packages).
// FieldByName may return one of the fields named x or may report that there are none.
// See https://golang.org/issue/4876 for more details.
/*
 * These data structures are known to the compiler (../cmd/compile/internal/reflectdata/reflect.go).
 * A few are also known to ../runtime/type.go to convey to debuggers.
 */
// A Kind represents the specific kind of type that a Type represents.
// The zero Kind is not a valid kind.
type Kind uint
const (
Invalid Kind = iota
Bool
Int
Int8
Int16
Int32
Int64
Uint
Uint8
Uint16
Uint32
Uint64
Uintptr
Float32
Float64
Complex64
Complex128
Array
Chan
Func
Interface
Map
Pointer
Slice
String
Struct
UnsafePointer
)
// Ptr is the old name for the Pointer kind.
const Ptr = Pointer
// tflag is used by an rtype to signal what extra type information is
// available in the memory directly following the rtype value.
//
// tflag values must be kept in sync with copies in:
//
// cmd/compile/internal/reflectdata/reflect.go
// cmd/link/internal/ld/decodesym.go
// runtime/type.go
type tflag uint8
const (
// tflagUncommon means that there is a pointer, *uncommonType,
// just beyond the outer type structure.
//
// For example, if t.Kind() == Struct and t.tflag&tflagUncommon != 0,
// then t has uncommonType data and it can be accessed as:
//
// type tUncommon struct {
// structType
// u uncommonType
// }
// u := &(*tUncommon)(unsafe.Pointer(t)).u
tflagUncommon tflag = 1 << 0
// tflagExtraStar means the name in the str field has an
// extraneous '*' prefix. This is because for most types T in
// a program, the type *T also exists and reusing the str data
// saves binary size.
tflagExtraStar tflag = 1 << 1
// tflagNamed means the type has a name.
tflagNamed tflag = 1 << 2
// tflagRegularMemory means that equal and hash functions can treat
// this type as a single region of t.size bytes.
tflagRegularMemory tflag = 1 << 3
)
// rtype is the common implementation of most values.
// It is embedded in other struct types.
//
// rtype must be kept in sync with ../runtime/type.go:/^type._type.
type rtype struct {
size uintptr
ptrdata uintptr // number of bytes in the type that can contain pointers
hash uint32 // hash of type; avoids computation in hash tables
tflag tflag // extra type information flags
align uint8 // alignment of variable with this type
fieldAlign uint8 // alignment of struct field with this type
kind uint8 // enumeration for C
// function for comparing objects of this type
// (ptr to object A, ptr to object B) -> ==?
equal func(unsafe.Pointer, unsafe.Pointer) bool
gcdata *byte // garbage collection data
str nameOff // string form
ptrToThis typeOff // type for pointer to this type, may be zero
}
// Method on non-interface type
type method struct {
name nameOff // name of method
mtyp typeOff // method type (without receiver)
ifn textOff // fn used in interface call (one-word receiver)
tfn textOff // fn used for normal method call
}
// uncommonType is present only for defined types or types with methods
// (if T is a defined type, the uncommonTypes for T and *T have methods).
// Using a pointer to this struct reduces the overall size required
// to describe a non-defined type with no methods.
type uncommonType struct {
pkgPath nameOff // import path; empty for built-in types like int, string
mcount uint16 // number of methods
xcount uint16 // number of exported methods
moff uint32 // offset from this uncommontype to [mcount]method
_ uint32 // unused
}
// ChanDir represents a channel type's direction.
type ChanDir int
const (
RecvDir ChanDir = 1 << iota // <-chan
SendDir // chan<-
BothDir = RecvDir | SendDir // chan
)
// arrayType represents a fixed array type.
type arrayType struct {
rtype
elem *rtype // array element type
slice *rtype // slice type
len uintptr
}
// chanType represents a channel type.
type chanType struct {
rtype
elem *rtype // channel element type
dir uintptr // channel direction (ChanDir)
}
// funcType represents a function type.
//
// A *rtype for each in and out parameter is stored in an array that
// directly follows the funcType (and possibly its uncommonType). So
// a function type with one method, one input, and one output is:
//
// struct {
// funcType
// uncommonType
// [2]*rtype // [0] is in, [1] is out
// }
type funcType struct {
rtype
inCount uint16
outCount uint16 // top bit is set if last input parameter is ...
}
// imethod represents a method on an interface type
type imethod struct {
name nameOff // name of method
typ typeOff // .(*FuncType) underneath
}
// interfaceType represents an interface type.
type interfaceType struct {
rtype
pkgPath name // import path
methods []imethod // sorted by hash
}
// mapType represents a map type.
type mapType struct {
rtype
key *rtype // map key type
elem *rtype // map element (value) type
bucket *rtype // internal bucket structure
// function for hashing keys (ptr to key, seed) -> hash
hasher func(unsafe.Pointer, uintptr) uintptr
keysize uint8 // size of key slot
valuesize uint8 // size of value slot
bucketsize uint16 // size of bucket
flags uint32
}
// ptrType represents a pointer type.
type ptrType struct {
rtype
elem *rtype // pointer element (pointed at) type
}
// sliceType represents a slice type.
type sliceType struct {
rtype
elem *rtype // slice element type
}
// Struct field
type structField struct {
name name // name is always non-empty
typ *rtype // type of field
offset uintptr // byte offset of field
}
func (f *structField) embedded() bool {
return f.name.embedded()
}
// structType represents a struct type.
type structType struct {
rtype
pkgPath name
fields []structField // sorted by offset
}
// name is an encoded type name with optional extra data.
//
// The first byte is a bit field containing:
//
// 1<<0 the name is exported
// 1<<1 tag data follows the name
// 1<<2 pkgPath nameOff follows the name and tag
// 1<<3 the name is of an embedded (a.k.a. anonymous) field
//
// Following that, there is a varint-encoded length of the name,
// followed by the name itself.
//
// If tag data is present, it also has a varint-encoded length
// followed by the tag itself.
//
// If the import path follows, then 4 bytes at the end of
// the data form a nameOff. The import path is only set for concrete
// methods that are defined in a different package than their type.
//
// If a name starts with "*", then the exported bit represents
// whether the pointed to type is exported.
//
// Note: this encoding must match here and in:
// cmd/compile/internal/reflectdata/reflect.go
// runtime/type.go
// internal/reflectlite/type.go
// cmd/link/internal/ld/decodesym.go
type name struct {
bytes *byte
}
func (n name) data(off int, whySafe string) *byte {
return (*byte)(add(unsafe.Pointer(n.bytes), uintptr(off), whySafe))
}
func (n name) isExported() bool {
return (*n.bytes)&(1<<0) != 0
}
func (n name) hasTag() bool {
return (*n.bytes)&(1<<1) != 0
}
func (n name) embedded() bool {
return (*n.bytes)&(1<<3) != 0
}
// readVarint parses a varint as encoded by encoding/binary.
// It returns the number of encoded bytes and the encoded value.
func (n name) readVarint(off int) (int, int) {
v := 0
for i := 0; ; i++ {
x := *n.data(off+i, "read varint")
v += int(x&0x7f) << (7 * i)
if x&0x80 == 0 {
return i + 1, v
}
}
}
// writeVarint writes n to buf in varint form. Returns the
// number of bytes written. n must be nonnegative.
// Writes at most 10 bytes.
func writeVarint(buf []byte, n int) int {
for i := 0; ; i++ {
b := byte(n & 0x7f)
n >>= 7
if n == 0 {
buf[i] = b
return i + 1
}
buf[i] = b | 0x80
}
}
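// As a worked example (illustrative only): writeVarint(buf, 300)
// stores 0xAC, 0x02 (300 = 0x2C + 2<<7; low 7-bit group first, with
// the continuation bit set) and returns 2; readVarint decodes those
// two bytes back to (2, 300).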
func (n name) name() string {
if n.bytes == nil {
return ""
}
i, l := n.readVarint(1)
return unsafe.String(n.data(1+i, "non-empty string"), l)
}
func (n name) tag() string {
if !n.hasTag() {
return ""
}
i, l := n.readVarint(1)
i2, l2 := n.readVarint(1 + i + l)
return unsafe.String(n.data(1+i+l+i2, "non-empty string"), l2)
}
func (n name) pkgPath() string {
if n.bytes == nil || *n.data(0, "name flag field")&(1<<2) == 0 {
return ""
}
i, l := n.readVarint(1)
off := 1 + i + l
if n.hasTag() {
i2, l2 := n.readVarint(off)
off += i2 + l2
}
var nameOff int32
// Note that this field may not be aligned in memory,
// so we cannot use a direct int32 assignment here.
copy((*[4]byte)(unsafe.Pointer(&nameOff))[:], (*[4]byte)(unsafe.Pointer(n.data(off, "name offset field")))[:])
pkgPathName := name{(*byte)(resolveTypeOff(unsafe.Pointer(n.bytes), nameOff))}
return pkgPathName.name()
}
func newName(n, tag string, exported, embedded bool) name {
if len(n) >= 1<<29 {
panic("reflect.nameFrom: name too long: " + n[:1024] + "...")
}
if len(tag) >= 1<<29 {
panic("reflect.nameFrom: tag too long: " + tag[:1024] + "...")
}
var nameLen [10]byte
var tagLen [10]byte
nameLenLen := writeVarint(nameLen[:], len(n))
tagLenLen := writeVarint(tagLen[:], len(tag))
var bits byte
l := 1 + nameLenLen + len(n)
if exported {
bits |= 1 << 0
}
if len(tag) > 0 {
l += tagLenLen + len(tag)
bits |= 1 << 1
}
if embedded {
bits |= 1 << 3
}
b := make([]byte, l)
b[0] = bits
copy(b[1:], nameLen[:nameLenLen])
copy(b[1+nameLenLen:], n)
if len(tag) > 0 {
tb := b[1+nameLenLen+len(n):]
copy(tb, tagLen[:tagLenLen])
copy(tb[tagLenLen:], tag)
}
return name{bytes: &b[0]}
}
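// For example (an illustrative layout): newName("Foo", "", true, false)
// produces the bytes {0x01, 0x03, 'F', 'o', 'o'} - the exported bit,
// a one-byte varint length, and the name itself - so name() reads the
// varint at offset 1 and returns "Foo".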
/*
* The compiler knows the exact layout of all the data structures above.
* The compiler does not know about the data structures and methods below.
*/
// Method represents a single method.
type Method struct {
// Name is the method name.
Name string
// PkgPath is the package path that qualifies a lower case (unexported)
// method name. It is empty for upper case (exported) method names.
// The combination of PkgPath and Name uniquely identifies a method
// in a method set.
// See https://golang.org/ref/spec#Uniqueness_of_identifiers
PkgPath string
Type Type // method type
Func Value // func with receiver as first argument
Index int // index for Type.Method
}
// IsExported reports whether the method is exported.
func (m Method) IsExported() bool {
return m.PkgPath == ""
}
const (
kindDirectIface = 1 << 5
kindGCProg = 1 << 6 // Type.gc points to GC program
kindMask = (1 << 5) - 1
)
// String returns the name of k.
func (k Kind) String() string {
if uint(k) < uint(len(kindNames)) {
return kindNames[uint(k)]
}
return "kind" + strconv.Itoa(int(k))
}
var kindNames = []string{
Invalid: "invalid",
Bool: "bool",
Int: "int",
Int8: "int8",
Int16: "int16",
Int32: "int32",
Int64: "int64",
Uint: "uint",
Uint8: "uint8",
Uint16: "uint16",
Uint32: "uint32",
Uint64: "uint64",
Uintptr: "uintptr",
Float32: "float32",
Float64: "float64",
Complex64: "complex64",
Complex128: "complex128",
Array: "array",
Chan: "chan",
Func: "func",
Interface: "interface",
Map: "map",
Pointer: "ptr",
Slice: "slice",
String: "string",
Struct: "struct",
UnsafePointer: "unsafe.Pointer",
}
func (t *uncommonType) methods() []method {
if t.mcount == 0 {
return nil
}
return (*[1 << 16]method)(add(unsafe.Pointer(t), uintptr(t.moff), "t.mcount > 0"))[:t.mcount:t.mcount]
}
func (t *uncommonType) exportedMethods() []method {
if t.xcount == 0 {
return nil
}
return (*[1 << 16]method)(add(unsafe.Pointer(t), uintptr(t.moff), "t.xcount > 0"))[:t.xcount:t.xcount]
}
// resolveNameOff resolves a name offset from a base pointer.
// The (*rtype).nameOff method is a convenience wrapper for this function.
// Implemented in the runtime package.
func resolveNameOff(ptrInModule unsafe.Pointer, off int32) unsafe.Pointer
// resolveTypeOff resolves an *rtype offset from a base type.
// The (*rtype).typeOff method is a convenience wrapper for this function.
// Implemented in the runtime package.
func resolveTypeOff(rtype unsafe.Pointer, off int32) unsafe.Pointer
// resolveTextOff resolves a function pointer offset from a base type.
// The (*rtype).textOff method is a convenience wrapper for this function.
// Implemented in the runtime package.
func resolveTextOff(rtype unsafe.Pointer, off int32) unsafe.Pointer
// addReflectOff adds a pointer to the reflection lookup map in the runtime.
// It returns a new ID that can be used as a typeOff or textOff, and will
// be resolved correctly. Implemented in the runtime package.
func addReflectOff(ptr unsafe.Pointer) int32
// resolveReflectName adds a name to the reflection lookup map in the runtime.
// It returns a new nameOff that can be used to refer to the pointer.
func resolveReflectName(n name) nameOff {
return nameOff(addReflectOff(unsafe.Pointer(n.bytes)))
}
// resolveReflectType adds a *rtype to the reflection lookup map in the runtime.
// It returns a new typeOff that can be used to refer to the pointer.
func resolveReflectType(t *rtype) typeOff {
return typeOff(addReflectOff(unsafe.Pointer(t)))
}
// resolveReflectText adds a function pointer to the reflection lookup map in
// the runtime. It returns a new textOff that can be used to refer to the
// pointer.
func resolveReflectText(ptr unsafe.Pointer) textOff {
return textOff(addReflectOff(ptr))
}
type nameOff int32 // offset to a name
type typeOff int32 // offset to an *rtype
type textOff int32 // offset from top of text section
func (t *rtype) nameOff(off nameOff) name {
return name{(*byte)(resolveNameOff(unsafe.Pointer(t), int32(off)))}
}
func (t *rtype) typeOff(off typeOff) *rtype {
return (*rtype)(resolveTypeOff(unsafe.Pointer(t), int32(off)))
}
func (t *rtype) textOff(off textOff) unsafe.Pointer {
return resolveTextOff(unsafe.Pointer(t), int32(off))
}
func (t *rtype) uncommon() *uncommonType {
if t.tflag&tflagUncommon == 0 {
return nil
}
switch t.Kind() {
case Struct:
return &(*structTypeUncommon)(unsafe.Pointer(t)).u
case Pointer:
type u struct {
ptrType
u uncommonType
}
return &(*u)(unsafe.Pointer(t)).u
case Func:
type u struct {
funcType
u uncommonType
}
return &(*u)(unsafe.Pointer(t)).u
case Slice:
type u struct {
sliceType
u uncommonType
}
return &(*u)(unsafe.Pointer(t)).u
case Array:
type u struct {
arrayType
u uncommonType
}
return &(*u)(unsafe.Pointer(t)).u
case Chan:
type u struct {
chanType
u uncommonType
}
return &(*u)(unsafe.Pointer(t)).u
case Map:
type u struct {
mapType
u uncommonType
}
return &(*u)(unsafe.Pointer(t)).u
case Interface:
type u struct {
interfaceType
u uncommonType
}
return &(*u)(unsafe.Pointer(t)).u
default:
type u struct {
rtype
u uncommonType
}
return &(*u)(unsafe.Pointer(t)).u
}
}
func (t *rtype) String() string {
s := t.nameOff(t.str).name()
if t.tflag&tflagExtraStar != 0 {
return s[1:]
}
return s
}
func (t *rtype) Size() uintptr { return t.size }
func (t *rtype) Bits() int {
if t == nil {
panic("reflect: Bits of nil Type")
}
k := t.Kind()
if k < Int || k > Complex128 {
panic("reflect: Bits of non-arithmetic Type " + t.String())
}
return int(t.size) * 8
}
func (t *rtype) Align() int { return int(t.align) }
func (t *rtype) FieldAlign() int { return int(t.fieldAlign) }
func (t *rtype) Kind() Kind { return Kind(t.kind & kindMask) }
func (t *rtype) pointers() bool { return t.ptrdata != 0 }
func (t *rtype) common() *rtype { return t }
func (t *rtype) exportedMethods() []method {
ut := t.uncommon()
if ut == nil {
return nil
}
return ut.exportedMethods()
}
func (t *rtype) NumMethod() int {
if t.Kind() == Interface {
tt := (*interfaceType)(unsafe.Pointer(t))
return tt.NumMethod()
}
return len(t.exportedMethods())
}
func (t *rtype) Method(i int) (m Method) {
if t.Kind() == Interface {
tt := (*interfaceType)(unsafe.Pointer(t))
return tt.Method(i)
}
methods := t.exportedMethods()
if i < 0 || i >= len(methods) {
panic("reflect: Method index out of range")
}
p := methods[i]
pname := t.nameOff(p.name)
m.Name = pname.name()
fl := flag(Func)
mtyp := t.typeOff(p.mtyp)
ft := (*funcType)(unsafe.Pointer(mtyp))
in := make([]Type, 0, 1+len(ft.in()))
in = append(in, t)
for _, arg := range ft.in() {
in = append(in, arg)
}
out := make([]Type, 0, len(ft.out()))
for _, ret := range ft.out() {
out = append(out, ret)
}
mt := FuncOf(in, out, ft.IsVariadic())
m.Type = mt
tfn := t.textOff(p.tfn)
fn := unsafe.Pointer(&tfn)
m.Func = Value{mt.(*rtype), fn, fl}
m.Index = i
return m
}
func (t *rtype) MethodByName(name string) (m Method, ok bool) {
if t.Kind() == Interface {
tt := (*interfaceType)(unsafe.Pointer(t))
return tt.MethodByName(name)
}
ut := t.uncommon()
if ut == nil {
return Method{}, false
}
methods := ut.exportedMethods()
// We are looking for the first index i where the string becomes >= name.
// This is a copy of sort.Search, with f(h) replaced by (t.nameOff(methods[h].name).name() >= name).
i, j := 0, len(methods)
for i < j {
h := int(uint(i+j) >> 1) // avoid overflow when computing h
// i ≤ h < j
if !(t.nameOff(methods[h].name).name() >= name) {
i = h + 1 // preserves f(i-1) == false
} else {
j = h // preserves f(j) == true
}
}
// i == j, f(i-1) == false, and f(j) (= f(i)) == true => answer is i.
if i < len(methods) && name == t.nameOff(methods[i].name).name() {
return t.Method(i), true
}
return Method{}, false
}
func (t *rtype) PkgPath() string {
if t.tflag&tflagNamed == 0 {
return ""
}
ut := t.uncommon()
if ut == nil {
return ""
}
return t.nameOff(ut.pkgPath).name()
}
func (t *rtype) hasName() bool {
return t.tflag&tflagNamed != 0
}
func (t *rtype) Name() string {
if !t.hasName() {
return ""
}
s := t.String()
i := len(s) - 1
sqBrackets := 0
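// Walk backwards to the last '.' that is not inside square brackets,
// so that dots within type parameter lists (e.g. "Pair[a.A, b.B]")
// do not cut the name short.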
for i >= 0 && (s[i] != '.' || sqBrackets != 0) {
switch s[i] {
case ']':
sqBrackets++
case '[':
sqBrackets--
}
i--
}
return s[i+1:]
}
func (t *rtype) ChanDir() ChanDir {
if t.Kind() != Chan {
panic("reflect: ChanDir of non-chan type " + t.String())
}
tt := (*chanType)(unsafe.Pointer(t))
return ChanDir(tt.dir)
}
func (t *rtype) IsVariadic() bool {
if t.Kind() != Func {
panic("reflect: IsVariadic of non-func type " + t.String())
}
tt := (*funcType)(unsafe.Pointer(t))
return tt.outCount&(1<<15) != 0
}
func (t *rtype) Elem() Type {
switch t.Kind() {
case Array:
tt := (*arrayType)(unsafe.Pointer(t))
return toType(tt.elem)
case Chan:
tt := (*chanType)(unsafe.Pointer(t))
return toType(tt.elem)
case Map:
tt := (*mapType)(unsafe.Pointer(t))
return toType(tt.elem)
case Pointer:
tt := (*ptrType)(unsafe.Pointer(t))
return toType(tt.elem)
case Slice:
tt := (*sliceType)(unsafe.Pointer(t))
return toType(tt.elem)
}
panic("reflect: Elem of invalid type " + t.String())
}
func (t *rtype) Field(i int) StructField {
if t.Kind() != Struct {
panic("reflect: Field of non-struct type " + t.String())
}
tt := (*structType)(unsafe.Pointer(t))
return tt.Field(i)
}
func (t *rtype) FieldByIndex(index []int) StructField {
if t.Kind() != Struct {
panic("reflect: FieldByIndex of non-struct type " + t.String())
}
tt := (*structType)(unsafe.Pointer(t))
return tt.FieldByIndex(index)
}
func (t *rtype) FieldByName(name string) (StructField, bool) {
if t.Kind() != Struct {
panic("reflect: FieldByName of non-struct type " + t.String())
}
tt := (*structType)(unsafe.Pointer(t))
return tt.FieldByName(name)
}
func (t *rtype) FieldByNameFunc(match func(string) bool) (StructField, bool) {
if t.Kind() != Struct {
panic("reflect: FieldByNameFunc of non-struct type " + t.String())
}
tt := (*structType)(unsafe.Pointer(t))
return tt.FieldByNameFunc(match)
}
func (t *rtype) In(i int) Type {
if t.Kind() != Func {
panic("reflect: In of non-func type " + t.String())
}
tt := (*funcType)(unsafe.Pointer(t))
return toType(tt.in()[i])
}
func (t *rtype) Key() Type {
if t.Kind() != Map {
panic("reflect: Key of non-map type " + t.String())
}
tt := (*mapType)(unsafe.Pointer(t))
return toType(tt.key)
}
func (t *rtype) Len() int {
if t.Kind() != Array {
panic("reflect: Len of non-array type " + t.String())
}
tt := (*arrayType)(unsafe.Pointer(t))
return int(tt.len)
}
func (t *rtype) NumField() int {
if t.Kind() != Struct {
panic("reflect: NumField of non-struct type " + t.String())
}
tt := (*structType)(unsafe.Pointer(t))
return len(tt.fields)
}
func (t *rtype) NumIn() int {
if t.Kind() != Func {
panic("reflect: NumIn of non-func type " + t.String())
}
tt := (*funcType)(unsafe.Pointer(t))
return int(tt.inCount)
}
func (t *rtype) NumOut() int {
if t.Kind() != Func {
panic("reflect: NumOut of non-func type " + t.String())
}
tt := (*funcType)(unsafe.Pointer(t))
return len(tt.out())
}
func (t *rtype) Out(i int) Type {
if t.Kind() != Func {
panic("reflect: Out of non-func type " + t.String())
}
tt := (*funcType)(unsafe.Pointer(t))
return toType(tt.out()[i])
}
func (t *funcType) in() []*rtype {
uadd := unsafe.Sizeof(*t)
if t.tflag&tflagUncommon != 0 {
uadd += unsafe.Sizeof(uncommonType{})
}
if t.inCount == 0 {
return nil
}
return (*[1 << 20]*rtype)(add(unsafe.Pointer(t), uadd, "t.inCount > 0"))[:t.inCount:t.inCount]
}
func (t *funcType) out() []*rtype {
uadd := unsafe.Sizeof(*t)
if t.tflag&tflagUncommon != 0 {
uadd += unsafe.Sizeof(uncommonType{})
}
outCount := t.outCount & (1<<15 - 1)
if outCount == 0 {
return nil
}
return (*[1 << 20]*rtype)(add(unsafe.Pointer(t), uadd, "outCount > 0"))[t.inCount : t.inCount+outCount : t.inCount+outCount]
}
// add returns p+x.
//
// The whySafe string is ignored, so that the function still inlines
// as efficiently as p+x, but all call sites should use the string to
// record why the addition is safe, which is to say why the addition
// does not cause x to advance to the very end of p's allocation
// and therefore point incorrectly at the next block in memory.
func add(p unsafe.Pointer, x uintptr, whySafe string) unsafe.Pointer {
return unsafe.Pointer(uintptr(p) + x)
}
func (d ChanDir) String() string {
switch d {
case SendDir:
return "chan<-"
case RecvDir:
return "<-chan"
case BothDir:
return "chan"
}
return "ChanDir" + strconv.Itoa(int(d))
}
// Method returns the i'th method in the type's method set.
func (t *interfaceType) Method(i int) (m Method) {
if i < 0 || i >= len(t.methods) {
return
}
p := &t.methods[i]
pname := t.nameOff(p.name)
m.Name = pname.name()
if !pname.isExported() {
m.PkgPath = pname.pkgPath()
if m.PkgPath == "" {
m.PkgPath = t.pkgPath.name()
}
}
m.Type = toType(t.typeOff(p.typ))
m.Index = i
return
}
// NumMethod returns the number of interface methods in the type's method set.
func (t *interfaceType) NumMethod() int { return len(t.methods) }
// MethodByName returns the method with the given name in the type's method set.
func (t *interfaceType) MethodByName(name string) (m Method, ok bool) {
if t == nil {
return
}
var p *imethod
for i := range t.methods {
p = &t.methods[i]
if t.nameOff(p.name).name() == name {
return t.Method(i), true
}
}
return
}
// A StructField describes a single field in a struct.
type StructField struct {
// Name is the field name.
Name string
// PkgPath is the package path that qualifies a lower case (unexported)
// field name. It is empty for upper case (exported) field names.
// See https://golang.org/ref/spec#Uniqueness_of_identifiers
PkgPath string
Type Type // field type
Tag StructTag // field tag string
Offset uintptr // offset within struct, in bytes
Index []int // index sequence for Type.FieldByIndex
Anonymous bool // is an embedded field
}
// IsExported reports whether the field is exported.
func (f StructField) IsExported() bool {
return f.PkgPath == ""
}
// A StructTag is the tag string in a struct field.
//
// By convention, tag strings are a concatenation of
// optionally space-separated key:"value" pairs.
// Each key is a non-empty string consisting of non-control
// characters other than space (U+0020 ' '), quote (U+0022 '"'),
// and colon (U+003A ':'). Each value is quoted using U+0022 '"'
// characters and Go string literal syntax.
type StructTag string
// Get returns the value associated with key in the tag string.
// If there is no such key in the tag, Get returns the empty string.
// If the tag does not have the conventional format, the value
// returned by Get is unspecified. To determine whether a tag is
// explicitly set to the empty string, use Lookup.
func (tag StructTag) Get(key string) string {
v, _ := tag.Lookup(key)
return v
}
// Lookup returns the value associated with key in the tag string.
// If the key is present in the tag the value (which may be empty)
// is returned. Otherwise the returned value will be the empty string.
// The ok return value reports whether the value was explicitly set in
// the tag string. If the tag does not have the conventional format,
// the value returned by Lookup is unspecified.
func (tag StructTag) Lookup(key string) (value string, ok bool) {
// When modifying this code, also update the validateStructTag code
// in cmd/vet/structtag.go.
for tag != "" {
// Skip leading space.
i := 0
for i < len(tag) && tag[i] == ' ' {
i++
}
tag = tag[i:]
if tag == "" {
break
}
// Scan to colon. A space, a quote or a control character is a syntax error.
// Strictly speaking, control chars include the range [0x7f, 0x9f], not just
// [0x00, 0x1f], but in practice, we ignore the multi-byte control characters
// as it is simpler to inspect the tag's bytes than the tag's runes.
i = 0
for i < len(tag) && tag[i] > ' ' && tag[i] != ':' && tag[i] != '"' && tag[i] != 0x7f {
i++
}
if i == 0 || i+1 >= len(tag) || tag[i] != ':' || tag[i+1] != '"' {
break
}
name := string(tag[:i])
tag = tag[i+1:]
// Scan quoted string to find value.
i = 1
for i < len(tag) && tag[i] != '"' {
if tag[i] == '\\' {
i++
}
i++
}
if i >= len(tag) {
break
}
qvalue := string(tag[:i+1])
tag = tag[i+1:]
if key == name {
value, err := strconv.Unquote(qvalue)
if err != nil {
break
}
return value, true
}
}
return "", false
}
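// A short usage sketch (client code; only the exported reflect API assumed):
//
//	tag := reflect.StructTag(`json:"name,omitempty" xml:"name"`)
//	v, ok := tag.Lookup("json") // "name,omitempty", true
//	_ = tag.Get("yaml")         // "" (key absent; Lookup would report ok == false)
//	_, _ = v, ok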
// Field returns the i'th struct field.
func (t *structType) Field(i int) (f StructField) {
if i < 0 || i >= len(t.fields) {
panic("reflect: Field index out of bounds")
}
p := &t.fields[i]
f.Type = toType(p.typ)
f.Name = p.name.name()
f.Anonymous = p.embedded()
if !p.name.isExported() {
f.PkgPath = t.pkgPath.name()
}
if tag := p.name.tag(); tag != "" {
f.Tag = StructTag(tag)
}
f.Offset = p.offset
// NOTE(rsc): This is the only allocation in the interface
// presented by a reflect.Type. It would be nice to avoid,
// at least in the common cases, but we need to make sure
// that misbehaving clients of reflect cannot affect other
// uses of reflect. One possibility is CL 5371098, but we
// postponed that ugliness until there is a demonstrated
// need for the performance. This is issue 2320.
f.Index = []int{i}
return
}
// TODO(gri): Should there be an error/bool indicator if the index
// is wrong for FieldByIndex?
// FieldByIndex returns the nested field corresponding to index.
func (t *structType) FieldByIndex(index []int) (f StructField) {
f.Type = toType(&t.rtype)
for i, x := range index {
if i > 0 {
ft := f.Type
if ft.Kind() == Pointer && ft.Elem().Kind() == Struct {
ft = ft.Elem()
}
f.Type = ft
}
f = f.Type.Field(x)
}
return
}
// A fieldScan represents an item on the fieldByNameFunc scan work list.
type fieldScan struct {
typ *structType
index []int
}
// FieldByNameFunc returns the struct field with a name that satisfies the
// match function and a boolean to indicate if the field was found.
func (t *structType) FieldByNameFunc(match func(string) bool) (result StructField, ok bool) {
// This uses the same condition that the Go language does: there must be a unique instance
// of the match at a given depth level. If there are multiple instances of a match at the
// same depth, they annihilate each other and inhibit any possible match at a lower level.
// The algorithm is breadth first search, one depth level at a time.
// The current and next slices are work queues:
// current lists the fields to visit on this depth level,
// and next lists the fields on the next lower level.
current := []fieldScan{}
next := []fieldScan{{typ: t}}
// nextCount records the number of times an embedded type has been
// encountered and considered for queueing in the 'next' slice.
// We only queue the first one, but we increment the count on each.
// If a struct type T can be reached more than once at a given depth level,
// then it annihilates itself and need not be considered at all when we
// process that next depth level.
var nextCount map[*structType]int
// visited records the structs that have been considered already.
// Embedded pointer fields can create cycles in the graph of
// reachable embedded types; visited avoids following those cycles.
// It also avoids duplicated effort: if we didn't find the field in an
// embedded type T at level 2, we won't find it in one at level 4 either.
visited := map[*structType]bool{}
for len(next) > 0 {
current, next = next, current[:0]
count := nextCount
nextCount = nil
// Process all the fields at this depth, now listed in 'current'.
// The loop queues embedded fields found in 'next', for processing during the next
// iteration. The multiplicity of the 'current' field counts is recorded
// in 'count'; the multiplicity of the 'next' field counts is recorded in 'nextCount'.
for _, scan := range current {
t := scan.typ
if visited[t] {
// We've looked through this type before, at a higher level.
// That higher level would shadow the lower level we're now at,
// so this one can't be useful to us. Ignore it.
continue
}
visited[t] = true
for i := range t.fields {
f := &t.fields[i]
// Find name and (for embedded field) type for field f.
fname := f.name.name()
var ntyp *rtype
if f.embedded() {
// Embedded field of type T or *T.
ntyp = f.typ
if ntyp.Kind() == Pointer {
ntyp = ntyp.Elem().common()
}
}
// Does it match?
if match(fname) {
// Potential match
if count[t] > 1 || ok {
// Name appeared multiple times at this level: annihilate.
return StructField{}, false
}
result = t.Field(i)
result.Index = nil
result.Index = append(result.Index, scan.index...)
result.Index = append(result.Index, i)
ok = true
continue
}
// Queue embedded struct fields for processing with next level,
// but only if we haven't seen a match yet at this level and only
// if the embedded types haven't already been queued.
if ok || ntyp == nil || ntyp.Kind() != Struct {
continue
}
styp := (*structType)(unsafe.Pointer(ntyp))
if nextCount[styp] > 0 {
nextCount[styp] = 2 // exact multiple doesn't matter
continue
}
if nextCount == nil {
nextCount = map[*structType]int{}
}
nextCount[styp] = 1
if count[t] > 1 {
nextCount[styp] = 2 // exact multiple doesn't matter
}
var index []int
index = append(index, scan.index...)
index = append(index, i)
next = append(next, fieldScan{styp, index})
}
}
if ok {
break
}
}
return
}
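// A hedged sketch of the annihilation rule from client code: two fields with
// the same name at the same depth cancel each other, so the lookup fails
// rather than picking one arbitrarily.
//
//	type A struct{ X int }
//	type B struct{ X int }
//	type S struct {
//		A
//		B
//	}
//	_, ok := reflect.TypeOf(S{}).FieldByName("X") // ok == false: X is ambiguous at depth 1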
// FieldByName returns the struct field with the given name
// and a boolean to indicate if the field was found.
func (t *structType) FieldByName(name string) (f StructField, present bool) {
// Quick check for top-level name, or struct without embedded fields.
hasEmbeds := false
if name != "" {
for i := range t.fields {
tf := &t.fields[i]
if tf.name.name() == name {
return t.Field(i), true
}
if tf.embedded() {
hasEmbeds = true
}
}
}
if !hasEmbeds {
return
}
return t.FieldByNameFunc(func(s string) bool { return s == name })
}
// TypeOf returns the reflection Type that represents the dynamic type of i.
// If i is a nil interface value, TypeOf returns nil.
func TypeOf(i any) Type {
eface := *(*emptyInterface)(unsafe.Pointer(&i))
return toType(eface.typ)
}
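// A minimal usage sketch (client code; "fmt" and "reflect" assumed):
//
//	t := reflect.TypeOf(3.14)
//	fmt.Println(t.Kind(), t.Name()) // float64 float64
//
//	var err error
//	fmt.Println(reflect.TypeOf(err)) // <nil>: the interface value is nil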
// rtypeOf directly extracts the *rtype of the provided value.
func rtypeOf(i any) *rtype {
eface := *(*emptyInterface)(unsafe.Pointer(&i))
return eface.typ
}
// ptrMap is the cache for PointerTo.
var ptrMap sync.Map // map[*rtype]*ptrType
// PtrTo returns the pointer type with element t.
// For example, if t represents type Foo, PtrTo(t) represents *Foo.
//
// PtrTo is the old spelling of PointerTo.
// The two functions behave identically.
func PtrTo(t Type) Type { return PointerTo(t) }
// PointerTo returns the pointer type with element t.
// For example, if t represents type Foo, PointerTo(t) represents *Foo.
func PointerTo(t Type) Type {
return t.(*rtype).ptrTo()
}
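// A minimal usage sketch (client code):
//
//	t := reflect.TypeOf(0)      // int
//	pt := reflect.PointerTo(t)  // *int
//	fmt.Println(pt.Elem() == t) // true: Elem round-trips to the same Type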
func (t *rtype) ptrTo() *rtype {
if t.ptrToThis != 0 {
return t.typeOff(t.ptrToThis)
}
// Check the cache.
if pi, ok := ptrMap.Load(t); ok {
return &pi.(*ptrType).rtype
}
// Look in known types.
s := "*" + t.String()
for _, tt := range typesByString(s) {
p := (*ptrType)(unsafe.Pointer(tt))
if p.elem != t {
continue
}
pi, _ := ptrMap.LoadOrStore(t, p)
return &pi.(*ptrType).rtype
}
// Create a new ptrType starting with the description
// of an *unsafe.Pointer.
var iptr any = (*unsafe.Pointer)(nil)
prototype := *(**ptrType)(unsafe.Pointer(&iptr))
pp := *prototype
pp.str = resolveReflectName(newName(s, "", false, false))
pp.ptrToThis = 0
// For the type structures linked into the binary, the
// compiler provides a good hash of the string.
// Create a good hash for the new string by using
// the FNV-1 hash's mixing function to combine the
// old hash and the new "*".
pp.hash = fnv1(t.hash, '*')
pp.elem = t
pi, _ := ptrMap.LoadOrStore(t, &pp)
return &pi.(*ptrType).rtype
}
// fnv1 incorporates the list of bytes into the hash x using the FNV-1 hash function.
func fnv1(x uint32, list ...byte) uint32 {
for _, b := range list {
x = x*16777619 ^ uint32(b)
}
return x
}
func (t *rtype) Implements(u Type) bool {
if u == nil {
panic("reflect: nil type passed to Type.Implements")
}
if u.Kind() != Interface {
panic("reflect: non-interface type passed to Type.Implements")
}
return implements(u.(*rtype), t)
}
func (t *rtype) AssignableTo(u Type) bool {
if u == nil {
panic("reflect: nil type passed to Type.AssignableTo")
}
uu := u.(*rtype)
return directlyAssignable(uu, t) || implements(uu, t)
}
func (t *rtype) ConvertibleTo(u Type) bool {
if u == nil {
panic("reflect: nil type passed to Type.ConvertibleTo")
}
uu := u.(*rtype)
return convertOp(uu, t) != nil
}
func (t *rtype) Comparable() bool {
return t.equal != nil
}
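// A hedged sketch of the three predicates from client code. The
// (*error)(nil) idiom is the usual way to obtain an interface Type:
//
//	errType := reflect.TypeOf((*error)(nil)).Elem()
//	t := reflect.TypeOf(&os.PathError{}) // import "os"
//	fmt.Println(t.Implements(errType))   // true
//	fmt.Println(t.AssignableTo(errType)) // true: concrete value to interface
//	fmt.Println(reflect.TypeOf(0).ConvertibleTo(reflect.TypeOf(""))) // true: int -> string conversion exists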
// implements reports whether the type V implements the interface type T.
func implements(T, V *rtype) bool {
if T.Kind() != Interface {
return false
}
t := (*interfaceType)(unsafe.Pointer(T))
if len(t.methods) == 0 {
return true
}
// The same algorithm applies in both cases, but the
// method tables for an interface type and a concrete type
// are different, so the code is duplicated.
// In both cases the algorithm is a linear scan over the two
// lists - T's methods and V's methods - simultaneously.
// Since method tables are stored in a unique sorted order
// (alphabetical, with no duplicate method names), the scan
// through V's methods must hit a match for each of T's
// methods along the way, or else V does not implement T.
// This lets us run the scan in overall linear time instead of
// the quadratic time a naive search would require.
// See also ../runtime/iface.go.
if V.Kind() == Interface {
v := (*interfaceType)(unsafe.Pointer(V))
i := 0
for j := 0; j < len(v.methods); j++ {
tm := &t.methods[i]
tmName := t.nameOff(tm.name)
vm := &v.methods[j]
vmName := V.nameOff(vm.name)
if vmName.name() == tmName.name() && V.typeOff(vm.typ) == t.typeOff(tm.typ) {
if !tmName.isExported() {
tmPkgPath := tmName.pkgPath()
if tmPkgPath == "" {
tmPkgPath = t.pkgPath.name()
}
vmPkgPath := vmName.pkgPath()
if vmPkgPath == "" {
vmPkgPath = v.pkgPath.name()
}
if tmPkgPath != vmPkgPath {
continue
}
}
if i++; i >= len(t.methods) {
return true
}
}
}
return false
}
v := V.uncommon()
if v == nil {
return false
}
i := 0
vmethods := v.methods()
for j := 0; j < int(v.mcount); j++ {
tm := &t.methods[i]
tmName := t.nameOff(tm.name)
vm := vmethods[j]
vmName := V.nameOff(vm.name)
if vmName.name() == tmName.name() && V.typeOff(vm.mtyp) == t.typeOff(tm.typ) {
if !tmName.isExported() {
tmPkgPath := tmName.pkgPath()
if tmPkgPath == "" {
tmPkgPath = t.pkgPath.name()
}
vmPkgPath := vmName.pkgPath()
if vmPkgPath == "" {
vmPkgPath = V.nameOff(v.pkgPath).name()
}
if tmPkgPath != vmPkgPath {
continue
}
}
if i++; i >= len(t.methods) {
return true
}
}
}
return false
}
// specialChannelAssignability reports whether a value x of channel type V
// can be directly assigned (using memmove) to another channel type T.
// https://golang.org/doc/go_spec.html#Assignability
// T and V must be both of Chan kind.
func specialChannelAssignability(T, V *rtype) bool {
// Special case:
// x is a bidirectional channel value, T is a channel type,
// x's type V and T have identical element types,
// and at least one of V or T is not a defined type.
return V.ChanDir() == BothDir && (T.Name() == "" || V.Name() == "") && haveIdenticalType(T.Elem(), V.Elem(), true)
}
// directlyAssignable reports whether a value x of type V can be directly
// assigned (using memmove) to a value of type T.
// https://golang.org/doc/go_spec.html#Assignability
// Ignoring the interface rules (implemented elsewhere)
// and the ideal constant rules (no ideal constants at run time).
func directlyAssignable(T, V *rtype) bool {
// x's type V is identical to T?
if T == V {
return true
}
// Otherwise at least one of T and V must not be defined
// and they must have the same kind.
if T.hasName() && V.hasName() || T.Kind() != V.Kind() {
return false
}
if T.Kind() == Chan && specialChannelAssignability(T, V) {
return true
}
// x's type T and V must have identical underlying types.
return haveIdenticalUnderlyingType(T, V, true)
}
func haveIdenticalType(T, V Type, cmpTags bool) bool {
if cmpTags {
return T == V
}
if T.Name() != V.Name() || T.Kind() != V.Kind() || T.PkgPath() != V.PkgPath() {
return false
}
return haveIdenticalUnderlyingType(T.common(), V.common(), false)
}
func haveIdenticalUnderlyingType(T, V *rtype, cmpTags bool) bool {
if T == V {
return true
}
kind := T.Kind()
if kind != V.Kind() {
return false
}
// Non-composite types of equal kind have same underlying type
// (the predefined instance of the type).
if Bool <= kind && kind <= Complex128 || kind == String || kind == UnsafePointer {
return true
}
// Composite types.
switch kind {
case Array:
return T.Len() == V.Len() && haveIdenticalType(T.Elem(), V.Elem(), cmpTags)
case Chan:
return V.ChanDir() == T.ChanDir() && haveIdenticalType(T.Elem(), V.Elem(), cmpTags)
case Func:
t := (*funcType)(unsafe.Pointer(T))
v := (*funcType)(unsafe.Pointer(V))
if t.outCount != v.outCount || t.inCount != v.inCount {
return false
}
for i := 0; i < t.NumIn(); i++ {
if !haveIdenticalType(t.In(i), v.In(i), cmpTags) {
return false
}
}
for i := 0; i < t.NumOut(); i++ {
if !haveIdenticalType(t.Out(i), v.Out(i), cmpTags) {
return false
}
}
return true
case Interface:
t := (*interfaceType)(unsafe.Pointer(T))
v := (*interfaceType)(unsafe.Pointer(V))
if len(t.methods) == 0 && len(v.methods) == 0 {
return true
}
// Might have the same methods but still
// need a run time conversion.
return false
case Map:
return haveIdenticalType(T.Key(), V.Key(), cmpTags) && haveIdenticalType(T.Elem(), V.Elem(), cmpTags)
case Pointer, Slice:
return haveIdenticalType(T.Elem(), V.Elem(), cmpTags)
case Struct:
t := (*structType)(unsafe.Pointer(T))
v := (*structType)(unsafe.Pointer(V))
if len(t.fields) != len(v.fields) {
return false
}
if t.pkgPath.name() != v.pkgPath.name() {
return false
}
for i := range t.fields {
tf := &t.fields[i]
vf := &v.fields[i]
if tf.name.name() != vf.name.name() {
return false
}
if !haveIdenticalType(tf.typ, vf.typ, cmpTags) {
return false
}
if cmpTags && tf.name.tag() != vf.name.tag() {
return false
}
if tf.offset != vf.offset {
return false
}
if tf.embedded() != vf.embedded() {
return false
}
}
return true
}
return false
}
// typelinks is implemented in package runtime.
// It returns a slice of the sections in each module,
// and a slice of *rtype offsets in each module.
//
// The types in each module are sorted by string. That is, the first
// two linked types of the first module are:
//
// d0 := sections[0]
// t1 := (*rtype)(add(d0, offset[0][0]))
// t2 := (*rtype)(add(d0, offset[0][1]))
//
// and
//
// t1.String() < t2.String()
//
// Note that strings are not unique identifiers for types:
// there can be more than one with a given string.
// Only types we might want to look up are included:
// pointers, channels, maps, slices, and arrays.
func typelinks() (sections []unsafe.Pointer, offset [][]int32)
func rtypeOff(section unsafe.Pointer, off int32) *rtype {
return (*rtype)(add(section, uintptr(off), "sizeof(rtype) > 0"))
}
// typesByString returns the subslice of typelinks() whose elements have
// the given string representation.
// It may be empty (no known types with that string) or may have
// multiple elements (multiple types with that string).
func typesByString(s string) []*rtype {
sections, offset := typelinks()
var ret []*rtype
for offsI, offs := range offset {
section := sections[offsI]
// We are looking for the first index i where the string becomes >= s.
// This is a copy of sort.Search, with f(h) replaced by (*typ[h].String() >= s).
i, j := 0, len(offs)
for i < j {
h := i + (j-i)>>1 // avoid overflow when computing h
// i ≤ h < j
if !(rtypeOff(section, offs[h]).String() >= s) {
i = h + 1 // preserves f(i-1) == false
} else {
j = h // preserves f(j) == true
}
}
// i == j, f(i-1) == false, and f(j) (= f(i)) == true => answer is i.
// Having found the first, linear scan forward to find the last.
// We could do a second binary search, but the caller is going
// to do a linear scan anyway.
for j := i; j < len(offs); j++ {
typ := rtypeOff(section, offs[j])
if typ.String() != s {
break
}
ret = append(ret, typ)
}
}
return ret
}
// The lookupCache caches ArrayOf, ChanOf, MapOf and SliceOf lookups.
var lookupCache sync.Map // map[cacheKey]*rtype
// A cacheKey is the key for use in the lookupCache.
// Four values describe any of the types we are looking for:
// type kind, one or two subtypes, and an extra integer.
type cacheKey struct {
kind Kind
t1 *rtype
t2 *rtype
extra uintptr
}
// The funcLookupCache caches FuncOf lookups.
// FuncOf does not share the common lookupCache since cacheKey is not
// sufficient to represent functions unambiguously.
var funcLookupCache struct {
sync.Mutex // Guards stores (but not loads) on m.
// m is a map[uint32][]*rtype keyed by the hash calculated in FuncOf.
// Elements of m are append-only and thus safe for concurrent reading.
m sync.Map
}
// ChanOf returns the channel type with the given direction and element type.
// For example, if t represents int, ChanOf(RecvDir, t) represents <-chan int.
//
// The gc runtime imposes a limit of 64 kB on channel element types.
// If t's size is equal to or exceeds this limit, ChanOf panics.
func ChanOf(dir ChanDir, t Type) Type {
typ := t.(*rtype)
// Look in cache.
ckey := cacheKey{Chan, typ, nil, uintptr(dir)}
if ch, ok := lookupCache.Load(ckey); ok {
return ch.(*rtype)
}
// This restriction is imposed by the gc compiler and the runtime.
if typ.size >= 1<<16 {
panic("reflect.ChanOf: element size too large")
}
// Look in known types.
var s string
switch dir {
default:
panic("reflect.ChanOf: invalid dir")
case SendDir:
s = "chan<- " + typ.String()
case RecvDir:
s = "<-chan " + typ.String()
case BothDir:
typeStr := typ.String()
if typeStr[0] == '<' {
// typ is recv chan, need parentheses as "<-" associates with leftmost
// chan possible, see:
// * https://golang.org/ref/spec#Channel_types
// * https://github.com/golang/go/issues/39897
s = "chan (" + typeStr + ")"
} else {
s = "chan " + typeStr
}
}
for _, tt := range typesByString(s) {
ch := (*chanType)(unsafe.Pointer(tt))
if ch.elem == typ && ch.dir == uintptr(dir) {
ti, _ := lookupCache.LoadOrStore(ckey, tt)
return ti.(Type)
}
}
// Make a channel type.
var ichan any = (chan unsafe.Pointer)(nil)
prototype := *(**chanType)(unsafe.Pointer(&ichan))
ch := *prototype
ch.tflag = tflagRegularMemory
ch.dir = uintptr(dir)
ch.str = resolveReflectName(newName(s, "", false, false))
ch.hash = fnv1(typ.hash, 'c', byte(dir))
ch.elem = typ
ti, _ := lookupCache.LoadOrStore(ckey, &ch.rtype)
return ti.(Type)
}
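// A minimal usage sketch (client code):
//
//	ct := reflect.ChanOf(reflect.RecvDir, reflect.TypeOf(0))
//	fmt.Println(ct) // <-chan int
//
//	// MakeChan requires a bidirectional channel type.
//	ch := reflect.MakeChan(reflect.ChanOf(reflect.BothDir, reflect.TypeOf(0)), 1)
//	ch.Send(reflect.ValueOf(42))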
// MapOf returns the map type with the given key and element types.
// For example, if k represents int and e represents string,
// MapOf(k, e) represents map[int]string.
//
// If the key type is not a valid map key type (that is, if it does
// not implement Go's == operator), MapOf panics.
func MapOf(key, elem Type) Type {
ktyp := key.(*rtype)
etyp := elem.(*rtype)
if ktyp.equal == nil {
panic("reflect.MapOf: invalid key type " + ktyp.String())
}
// Look in cache.
ckey := cacheKey{Map, ktyp, etyp, 0}
if mt, ok := lookupCache.Load(ckey); ok {
return mt.(Type)
}
// Look in known types.
s := "map[" + ktyp.String() + "]" + etyp.String()
for _, tt := range typesByString(s) {
mt := (*mapType)(unsafe.Pointer(tt))
if mt.key == ktyp && mt.elem == etyp {
ti, _ := lookupCache.LoadOrStore(ckey, tt)
return ti.(Type)
}
}
// Make a map type.
// Note: flag values must match those used in the TMAP case
// in ../cmd/compile/internal/reflectdata/reflect.go:writeType.
var imap any = (map[unsafe.Pointer]unsafe.Pointer)(nil)
mt := **(**mapType)(unsafe.Pointer(&imap))
mt.str = resolveReflectName(newName(s, "", false, false))
mt.tflag = 0
mt.hash = fnv1(etyp.hash, 'm', byte(ktyp.hash>>24), byte(ktyp.hash>>16), byte(ktyp.hash>>8), byte(ktyp.hash))
mt.key = ktyp
mt.elem = etyp
mt.bucket = bucketOf(ktyp, etyp)
mt.hasher = func(p unsafe.Pointer, seed uintptr) uintptr {
return typehash(ktyp, p, seed)
}
mt.flags = 0
if ktyp.size > maxKeySize {
mt.keysize = uint8(goarch.PtrSize)
mt.flags |= 1 // indirect key
} else {
mt.keysize = uint8(ktyp.size)
}
if etyp.size > maxValSize {
mt.valuesize = uint8(goarch.PtrSize)
mt.flags |= 2 // indirect value
} else {
mt.valuesize = uint8(etyp.size)
}
mt.bucketsize = uint16(mt.bucket.size)
if isReflexive(ktyp) {
mt.flags |= 4
}
if needKeyUpdate(ktyp) {
mt.flags |= 8
}
if hashMightPanic(ktyp) {
mt.flags |= 16
}
mt.ptrToThis = 0
ti, _ := lookupCache.LoadOrStore(ckey, &mt.rtype)
return ti.(Type)
}
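// A minimal usage sketch (client code). A key type whose == is not defined,
// such as a slice, would make MapOf panic as documented above:
//
//	mt := reflect.MapOf(reflect.TypeOf(""), reflect.TypeOf(0))
//	m := reflect.MakeMap(mt)
//	m.SetMapIndex(reflect.ValueOf("a"), reflect.ValueOf(1))
//	fmt.Println(m.Interface()) // map[a:1]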
var funcTypes []Type
var funcTypesMutex sync.Mutex
func initFuncTypes(n int) Type {
funcTypesMutex.Lock()
defer funcTypesMutex.Unlock()
if n >= len(funcTypes) {
newFuncTypes := make([]Type, n+1)
copy(newFuncTypes, funcTypes)
funcTypes = newFuncTypes
}
if funcTypes[n] != nil {
return funcTypes[n]
}
funcTypes[n] = StructOf([]StructField{
{
Name: "FuncType",
Type: TypeOf(funcType{}),
},
{
Name: "Args",
Type: ArrayOf(n, TypeOf(&rtype{})),
},
})
return funcTypes[n]
}
// FuncOf returns the function type with the given argument and result types.
// For example, if k represents int and e represents string,
// FuncOf([]Type{k}, []Type{e}, false) represents func(int) string.
//
// The variadic argument controls whether the function is variadic. FuncOf
// panics if variadic is true and in[len(in)-1] does not represent a slice.
func FuncOf(in, out []Type, variadic bool) Type {
if variadic && (len(in) == 0 || in[len(in)-1].Kind() != Slice) {
panic("reflect.FuncOf: last arg of variadic func must be slice")
}
// Make a func type.
var ifunc any = (func())(nil)
prototype := *(**funcType)(unsafe.Pointer(&ifunc))
n := len(in) + len(out)
if n > 128 {
panic("reflect.FuncOf: too many arguments")
}
o := New(initFuncTypes(n)).Elem()
ft := (*funcType)(unsafe.Pointer(o.Field(0).Addr().Pointer()))
args := unsafe.Slice((**rtype)(unsafe.Pointer(o.Field(1).Addr().Pointer())), n)[0:0:n]
*ft = *prototype
// Build a hash and minimally populate ft.
var hash uint32
for _, in := range in {
t := in.(*rtype)
args = append(args, t)
hash = fnv1(hash, byte(t.hash>>24), byte(t.hash>>16), byte(t.hash>>8), byte(t.hash))
}
if variadic {
hash = fnv1(hash, 'v')
}
hash = fnv1(hash, '.')
for _, out := range out {
t := out.(*rtype)
args = append(args, t)
hash = fnv1(hash, byte(t.hash>>24), byte(t.hash>>16), byte(t.hash>>8), byte(t.hash))
}
ft.tflag = 0
ft.hash = hash
ft.inCount = uint16(len(in))
ft.outCount = uint16(len(out))
if variadic {
ft.outCount |= 1 << 15
}
// Look in cache.
if ts, ok := funcLookupCache.m.Load(hash); ok {
for _, t := range ts.([]*rtype) {
if haveIdenticalUnderlyingType(&ft.rtype, t, true) {
return t
}
}
}
// Not in cache, lock and retry.
funcLookupCache.Lock()
defer funcLookupCache.Unlock()
if ts, ok := funcLookupCache.m.Load(hash); ok {
for _, t := range ts.([]*rtype) {
if haveIdenticalUnderlyingType(&ft.rtype, t, true) {
return t
}
}
}
addToCache := func(tt *rtype) Type {
var rts []*rtype
if rti, ok := funcLookupCache.m.Load(hash); ok {
rts = rti.([]*rtype)
}
funcLookupCache.m.Store(hash, append(rts, tt))
return tt
}
// Look in known types for the same string representation.
str := funcStr(ft)
for _, tt := range typesByString(str) {
if haveIdenticalUnderlyingType(&ft.rtype, tt, true) {
return addToCache(tt)
}
}
// Populate the remaining fields of ft and store in cache.
ft.str = resolveReflectName(newName(str, "", false, false))
ft.ptrToThis = 0
return addToCache(&ft.rtype)
}
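// A hedged sketch combining FuncOf with MakeFunc from client code
// ("strconv", "fmt", and "reflect" assumed):
//
//	ft := reflect.FuncOf(
//		[]reflect.Type{reflect.TypeOf(0)},  // in: (int)
//		[]reflect.Type{reflect.TypeOf("")}, // out: (string)
//		false,                              // not variadic
//	)
//	fn := reflect.MakeFunc(ft, func(args []reflect.Value) []reflect.Value {
//		return []reflect.Value{reflect.ValueOf(strconv.Itoa(int(args[0].Int())))}
//	})
//	fmt.Println(fn.Call([]reflect.Value{reflect.ValueOf(42)})[0]) // 42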
// funcStr builds a string representation of a funcType.
func funcStr(ft *funcType) string {
repr := make([]byte, 0, 64)
repr = append(repr, "func("...)
for i, t := range ft.in() {
if i > 0 {
repr = append(repr, ", "...)
}
if ft.IsVariadic() && i == int(ft.inCount)-1 {
repr = append(repr, "..."...)
repr = append(repr, (*sliceType)(unsafe.Pointer(t)).elem.String()...)
} else {
repr = append(repr, t.String()...)
}
}
repr = append(repr, ')')
out := ft.out()
if len(out) == 1 {
repr = append(repr, ' ')
} else if len(out) > 1 {
repr = append(repr, " ("...)
}
for i, t := range out {
if i > 0 {
repr = append(repr, ", "...)
}
repr = append(repr, t.String()...)
}
if len(out) > 1 {
repr = append(repr, ')')
}
return string(repr)
}
// isReflexive reports whether the == operation on the type is reflexive.
// That is, x == x for all values x of type t.
func isReflexive(t *rtype) bool {
switch t.Kind() {
case Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Uintptr, Chan, Pointer, String, UnsafePointer:
return true
case Float32, Float64, Complex64, Complex128, Interface:
return false
case Array:
tt := (*arrayType)(unsafe.Pointer(t))
return isReflexive(tt.elem)
case Struct:
tt := (*structType)(unsafe.Pointer(t))
for _, f := range tt.fields {
if !isReflexive(f.typ) {
return false
}
}
return true
default:
// Func, Map, Slice, Invalid
panic("isReflexive called on non-key type " + t.String())
}
}
// needKeyUpdate reports whether map overwrites require the key to be copied.
func needKeyUpdate(t *rtype) bool {
switch t.Kind() {
case Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Uintptr, Chan, Pointer, UnsafePointer:
return false
case Float32, Float64, Complex64, Complex128, Interface, String:
// Float keys can be updated from +0 to -0.
// String keys can be updated to use a smaller backing store.
// Interfaces might have floats or strings in them.
return true
case Array:
tt := (*arrayType)(unsafe.Pointer(t))
return needKeyUpdate(tt.elem)
case Struct:
tt := (*structType)(unsafe.Pointer(t))
for _, f := range tt.fields {
if needKeyUpdate(f.typ) {
return true
}
}
return false
default:
// Func, Map, Slice, Invalid
panic("needKeyUpdate called on non-key type " + t.String())
}
}
// hashMightPanic reports whether the hash of a map key of type t might panic.
func hashMightPanic(t *rtype) bool {
switch t.Kind() {
case Interface:
return true
case Array:
tt := (*arrayType)(unsafe.Pointer(t))
return hashMightPanic(tt.elem)
case Struct:
tt := (*structType)(unsafe.Pointer(t))
for _, f := range tt.fields {
if hashMightPanic(f.typ) {
return true
}
}
return false
default:
return false
}
}
// Make sure these routines stay in sync with ../runtime/map.go!
// These types exist only for GC, so we only fill out GC relevant info.
// Currently, that's just size and the GC program. We also fill in string
// for possible debugging use.
const (
bucketSize uintptr = abi.MapBucketCount
maxKeySize uintptr = abi.MapMaxKeyBytes
maxValSize uintptr = abi.MapMaxElemBytes
)
func bucketOf(ktyp, etyp *rtype) *rtype {
if ktyp.size > maxKeySize {
ktyp = PointerTo(ktyp).(*rtype)
}
if etyp.size > maxValSize {
etyp = PointerTo(etyp).(*rtype)
}
// Prepare GC data if any.
// A bucket is at most bucketSize*(1+maxKeySize+maxValSize)+ptrSize bytes,
// or 2064 bytes, or 258 pointer-size words, or 33 bytes of pointer bitmap.
// Note that since the key and value are known to be <= 128 bytes,
// they're guaranteed to have bitmaps instead of GC programs.
var gcdata *byte
var ptrdata uintptr
size := bucketSize*(1+ktyp.size+etyp.size) + goarch.PtrSize
if size&uintptr(ktyp.align-1) != 0 || size&uintptr(etyp.align-1) != 0 {
panic("reflect: bad size computation in MapOf")
}
if ktyp.ptrdata != 0 || etyp.ptrdata != 0 {
nptr := (bucketSize*(1+ktyp.size+etyp.size) + goarch.PtrSize) / goarch.PtrSize
n := (nptr + 7) / 8
// Runtime needs pointer masks to be a multiple of uintptr in size.
n = (n + goarch.PtrSize - 1) &^ (goarch.PtrSize - 1)
mask := make([]byte, n)
base := bucketSize / goarch.PtrSize
if ktyp.ptrdata != 0 {
emitGCMask(mask, base, ktyp, bucketSize)
}
base += bucketSize * ktyp.size / goarch.PtrSize
if etyp.ptrdata != 0 {
emitGCMask(mask, base, etyp, bucketSize)
}
base += bucketSize * etyp.size / goarch.PtrSize
word := base
mask[word/8] |= 1 << (word % 8)
gcdata = &mask[0]
ptrdata = (word + 1) * goarch.PtrSize
// overflow word must be last
if ptrdata != size {
panic("reflect: bad layout computation in MapOf")
}
}
b := &rtype{
align: goarch.PtrSize,
size: size,
kind: uint8(Struct),
ptrdata: ptrdata,
gcdata: gcdata,
}
s := "bucket(" + ktyp.String() + "," + etyp.String() + ")"
b.str = resolveReflectName(newName(s, "", false, false))
return b
}
func (t *rtype) gcSlice(begin, end uintptr) []byte {
return (*[1 << 30]byte)(unsafe.Pointer(t.gcdata))[begin:end:end]
}
// emitGCMask writes the GC mask for [n]typ into out, starting at bit
// offset base.
func emitGCMask(out []byte, base uintptr, typ *rtype, n uintptr) {
if typ.kind&kindGCProg != 0 {
panic("reflect: unexpected GC program")
}
ptrs := typ.ptrdata / goarch.PtrSize
words := typ.size / goarch.PtrSize
mask := typ.gcSlice(0, (ptrs+7)/8)
for j := uintptr(0); j < ptrs; j++ {
if (mask[j/8]>>(j%8))&1 != 0 {
for i := uintptr(0); i < n; i++ {
k := base + i*words + j
out[k/8] |= 1 << (k % 8)
}
}
}
}
// appendGCProg appends the GC program for the first ptrdata bytes of
// typ to dst and returns the extended slice.
func appendGCProg(dst []byte, typ *rtype) []byte {
if typ.kind&kindGCProg != 0 {
// Element has GC program; emit one element.
n := uintptr(*(*uint32)(unsafe.Pointer(typ.gcdata)))
prog := typ.gcSlice(4, 4+n-1)
return append(dst, prog...)
}
// Element is small with pointer mask; use as literal bits.
ptrs := typ.ptrdata / goarch.PtrSize
mask := typ.gcSlice(0, (ptrs+7)/8)
// Emit 120-bit chunks of full bytes (max is 127 but we avoid using partial bytes).
for ; ptrs > 120; ptrs -= 120 {
dst = append(dst, 120)
dst = append(dst, mask[:15]...)
mask = mask[15:]
}
dst = append(dst, byte(ptrs))
dst = append(dst, mask...)
return dst
}
// SliceOf returns the slice type with element type t.
// For example, if t represents int, SliceOf(t) represents []int.
func SliceOf(t Type) Type {
typ := t.(*rtype)
// Look in cache.
ckey := cacheKey{Slice, typ, nil, 0}
if slice, ok := lookupCache.Load(ckey); ok {
return slice.(Type)
}
// Look in known types.
s := "[]" + typ.String()
for _, tt := range typesByString(s) {
slice := (*sliceType)(unsafe.Pointer(tt))
if slice.elem == typ {
ti, _ := lookupCache.LoadOrStore(ckey, tt)
return ti.(Type)
}
}
// Make a slice type.
var islice any = ([]unsafe.Pointer)(nil)
prototype := *(**sliceType)(unsafe.Pointer(&islice))
slice := *prototype
slice.tflag = 0
slice.str = resolveReflectName(newName(s, "", false, false))
slice.hash = fnv1(typ.hash, '[')
slice.elem = typ
slice.ptrToThis = 0
ti, _ := lookupCache.LoadOrStore(ckey, &slice.rtype)
return ti.(Type)
}
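// A minimal usage sketch (client code):
//
//	st := reflect.SliceOf(reflect.TypeOf(byte(0))) // []uint8
//	s := reflect.MakeSlice(st, 0, 4)
//	s = reflect.Append(s, reflect.ValueOf(byte('A')))
//	fmt.Println(s.Interface().([]byte)) // [65]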
// The structLookupCache caches StructOf lookups.
// StructOf does not share the common lookupCache since we need to pin
// the memory associated with *structTypeFixedN.
var structLookupCache struct {
sync.Mutex // Guards stores (but not loads) on m.
// m is a map[uint32][]Type keyed by the hash calculated in StructOf.
// Elements in m are append-only and thus safe for concurrent reading.
m sync.Map
}
type structTypeUncommon struct {
structType
u uncommonType
}
// isLetter reports whether a given 'rune' is classified as a Letter.
func isLetter(ch rune) bool {
return 'a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z' || ch == '_' || ch >= utf8.RuneSelf && unicode.IsLetter(ch)
}
// isValidFieldName checks if a string is a valid (struct) field name or not.
//
// According to the language spec, a field name should be an identifier.
//
// identifier = letter { letter | unicode_digit } .
// letter = unicode_letter | "_" .
func isValidFieldName(fieldName string) bool {
for i, c := range fieldName {
if i == 0 && !isLetter(c) {
return false
}
if !(isLetter(c) || unicode.IsDigit(c)) {
return false
}
}
return len(fieldName) > 0
}
// StructOf returns the struct type containing fields.
// The Offset and Index fields are ignored and computed as they would be
// by the compiler.
//
// StructOf currently does not generate wrapper methods for embedded
// fields and panics if passed unexported StructFields.
// These limitations may be lifted in a future version.
func StructOf(fields []StructField) Type {
var (
hash = fnv1(0, []byte("struct {")...)
size uintptr
typalign uint8
comparable = true
methods []method
fs = make([]structField, len(fields))
repr = make([]byte, 0, 64)
fset = map[string]struct{}{} // fields' names
hasGCProg = false // records whether a struct-field type has a GCProg
)
lastzero := uintptr(0)
repr = append(repr, "struct {"...)
pkgpath := ""
for i, field := range fields {
if field.Name == "" {
panic("reflect.StructOf: field " + strconv.Itoa(i) + " has no name")
}
if !isValidFieldName(field.Name) {
panic("reflect.StructOf: field " + strconv.Itoa(i) + " has invalid name")
}
if field.Type == nil {
panic("reflect.StructOf: field " + strconv.Itoa(i) + " has no type")
}
f, fpkgpath := runtimeStructField(field)
ft := f.typ
if ft.kind&kindGCProg != 0 {
hasGCProg = true
}
if fpkgpath != "" {
if pkgpath == "" {
pkgpath = fpkgpath
} else if pkgpath != fpkgpath {
panic("reflect.Struct: fields with different PkgPath " + pkgpath + " and " + fpkgpath)
}
}
// Update string and hash
name := f.name.name()
hash = fnv1(hash, []byte(name)...)
repr = append(repr, (" " + name)...)
if f.embedded() {
// Embedded field
if f.typ.Kind() == Pointer {
// Embedded ** and *interface{} are illegal
elem := ft.Elem()
if k := elem.Kind(); k == Pointer || k == Interface {
panic("reflect.StructOf: illegal embedded field type " + ft.String())
}
}
switch f.typ.Kind() {
case Interface:
ift := (*interfaceType)(unsafe.Pointer(ft))
for im, m := range ift.methods {
if ift.nameOff(m.name).pkgPath() != "" {
// TODO(sbinet). Issue 15924.
panic("reflect: embedded interface with unexported method(s) not implemented")
}
var (
mtyp = ift.typeOff(m.typ)
ifield = i
imethod = im
ifn Value
tfn Value
)
if ft.kind&kindDirectIface != 0 {
tfn = MakeFunc(mtyp, func(in []Value) []Value {
var args []Value
var recv = in[0]
if len(in) > 1 {
args = in[1:]
}
return recv.Field(ifield).Method(imethod).Call(args)
})
ifn = MakeFunc(mtyp, func(in []Value) []Value {
var args []Value
var recv = in[0]
if len(in) > 1 {
args = in[1:]
}
return recv.Field(ifield).Method(imethod).Call(args)
})
} else {
tfn = MakeFunc(mtyp, func(in []Value) []Value {
var args []Value
var recv = in[0]
if len(in) > 1 {
args = in[1:]
}
return recv.Field(ifield).Method(imethod).Call(args)
})
ifn = MakeFunc(mtyp, func(in []Value) []Value {
var args []Value
var recv = Indirect(in[0])
if len(in) > 1 {
args = in[1:]
}
return recv.Field(ifield).Method(imethod).Call(args)
})
}
methods = append(methods, method{
name: resolveReflectName(ift.nameOff(m.name)),
mtyp: resolveReflectType(mtyp),
ifn: resolveReflectText(unsafe.Pointer(&ifn)),
tfn: resolveReflectText(unsafe.Pointer(&tfn)),
})
}
case Pointer:
ptr := (*ptrType)(unsafe.Pointer(ft))
if unt := ptr.uncommon(); unt != nil {
if i > 0 && unt.mcount > 0 {
// Issue 15924.
panic("reflect: embedded type with methods not implemented if type is not first field")
}
if len(fields) > 1 {
panic("reflect: embedded type with methods not implemented if there is more than one field")
}
for _, m := range unt.methods() {
mname := ptr.nameOff(m.name)
if mname.pkgPath() != "" {
// TODO(sbinet).
// Issue 15924.
panic("reflect: embedded interface with unexported method(s) not implemented")
}
methods = append(methods, method{
name: resolveReflectName(mname),
mtyp: resolveReflectType(ptr.typeOff(m.mtyp)),
ifn: resolveReflectText(ptr.textOff(m.ifn)),
tfn: resolveReflectText(ptr.textOff(m.tfn)),
})
}
}
if unt := ptr.elem.uncommon(); unt != nil {
for _, m := range unt.methods() {
mname := ptr.nameOff(m.name)
if mname.pkgPath() != "" {
// TODO(sbinet)
// Issue 15924.
panic("reflect: embedded interface with unexported method(s) not implemented")
}
methods = append(methods, method{
name: resolveReflectName(mname),
mtyp: resolveReflectType(ptr.elem.typeOff(m.mtyp)),
ifn: resolveReflectText(ptr.elem.textOff(m.ifn)),
tfn: resolveReflectText(ptr.elem.textOff(m.tfn)),
})
}
}
default:
if unt := ft.uncommon(); unt != nil {
if i > 0 && unt.mcount > 0 {
// Issue 15924.
panic("reflect: embedded type with methods not implemented if type is not first field")
}
if len(fields) > 1 && ft.kind&kindDirectIface != 0 {
panic("reflect: embedded type with methods not implemented for non-pointer type")
}
for _, m := range unt.methods() {
mname := ft.nameOff(m.name)
if mname.pkgPath() != "" {
// TODO(sbinet)
// Issue 15924.
panic("reflect: embedded interface with unexported method(s) not implemented")
}
methods = append(methods, method{
name: resolveReflectName(mname),
mtyp: resolveReflectType(ft.typeOff(m.mtyp)),
ifn: resolveReflectText(ft.textOff(m.ifn)),
tfn: resolveReflectText(ft.textOff(m.tfn)),
})
}
}
}
}
if _, dup := fset[name]; dup && name != "_" {
panic("reflect.StructOf: duplicate field " + name)
}
fset[name] = struct{}{}
hash = fnv1(hash, byte(ft.hash>>24), byte(ft.hash>>16), byte(ft.hash>>8), byte(ft.hash))
repr = append(repr, (" " + ft.String())...)
if f.name.hasTag() {
hash = fnv1(hash, []byte(f.name.tag())...)
repr = append(repr, (" " + strconv.Quote(f.name.tag()))...)
}
if i < len(fields)-1 {
repr = append(repr, ';')
}
comparable = comparable && (ft.equal != nil)
offset := align(size, uintptr(ft.align))
if offset < size {
panic("reflect.StructOf: struct size would exceed virtual address space")
}
if ft.align > typalign {
typalign = ft.align
}
size = offset + ft.size
if size < offset {
panic("reflect.StructOf: struct size would exceed virtual address space")
}
f.offset = offset
if ft.size == 0 {
lastzero = size
}
fs[i] = f
}
if size > 0 && lastzero == size {
// This is a non-zero sized struct that ends in a
// zero-sized field. We add an extra byte of padding,
// to ensure that taking the address of the final
// zero-sized field can't manufacture a pointer to the
// next object in the heap. See issue 9401.
size++
if size == 0 {
panic("reflect.StructOf: struct size would exceed virtual address space")
}
}
var typ *structType
var ut *uncommonType
if len(methods) == 0 {
t := new(structTypeUncommon)
typ = &t.structType
ut = &t.u
} else {
// A *rtype representing a struct is followed directly in memory by an
// array of method objects representing the methods attached to the
// struct. To get the same layout for a run time generated type, we
// need an array directly following the uncommonType memory.
// A similar strategy is used for funcTypeFixed4, ...funcTypeFixedN.
tt := New(StructOf([]StructField{
{Name: "S", Type: TypeOf(structType{})},
{Name: "U", Type: TypeOf(uncommonType{})},
{Name: "M", Type: ArrayOf(len(methods), TypeOf(methods[0]))},
}))
typ = (*structType)(tt.Elem().Field(0).Addr().UnsafePointer())
ut = (*uncommonType)(tt.Elem().Field(1).Addr().UnsafePointer())
copy(tt.Elem().Field(2).Slice(0, len(methods)).Interface().([]method), methods)
}
// TODO(sbinet): Once we allow embedding multiple types,
// methods will need to be sorted like the compiler does.
// TODO(sbinet): Once we allow non-exported methods, we will
// need to compute xcount as the number of exported methods.
ut.mcount = uint16(len(methods))
ut.xcount = ut.mcount
ut.moff = uint32(unsafe.Sizeof(uncommonType{}))
if len(fs) > 0 {
repr = append(repr, ' ')
}
repr = append(repr, '}')
hash = fnv1(hash, '}')
str := string(repr)
// Round the size up to be a multiple of the alignment.
s := align(size, uintptr(typalign))
if s < size {
panic("reflect.StructOf: struct size would exceed virtual address space")
}
size = s
// Make the struct type.
var istruct any = struct{}{}
prototype := *(**structType)(unsafe.Pointer(&istruct))
*typ = *prototype
typ.fields = fs
if pkgpath != "" {
typ.pkgPath = newName(pkgpath, "", false, false)
}
// Look in cache.
if ts, ok := structLookupCache.m.Load(hash); ok {
for _, st := range ts.([]Type) {
t := st.common()
if haveIdenticalUnderlyingType(&typ.rtype, t, true) {
return t
}
}
}
// Not in cache, lock and retry.
structLookupCache.Lock()
defer structLookupCache.Unlock()
if ts, ok := structLookupCache.m.Load(hash); ok {
for _, st := range ts.([]Type) {
t := st.common()
if haveIdenticalUnderlyingType(&typ.rtype, t, true) {
return t
}
}
}
addToCache := func(t Type) Type {
var ts []Type
if ti, ok := structLookupCache.m.Load(hash); ok {
ts = ti.([]Type)
}
structLookupCache.m.Store(hash, append(ts, t))
return t
}
// Look in known types.
for _, t := range typesByString(str) {
if haveIdenticalUnderlyingType(&typ.rtype, t, true) {
// even if 't' wasn't a structType with methods, we should be ok
// as the 'u uncommonType' field won't be accessed except when
// tflag&tflagUncommon is set.
return addToCache(t)
}
}
typ.str = resolveReflectName(newName(str, "", false, false))
typ.tflag = 0 // TODO: set tflagRegularMemory
typ.hash = hash
typ.size = size
typ.ptrdata = typeptrdata(typ.common())
typ.align = typalign
typ.fieldAlign = typalign
typ.ptrToThis = 0
if len(methods) > 0 {
typ.tflag |= tflagUncommon
}
if hasGCProg {
lastPtrField := 0
for i, ft := range fs {
if ft.typ.pointers() {
lastPtrField = i
}
}
prog := []byte{0, 0, 0, 0} // will be length of prog
var off uintptr
for i, ft := range fs {
if i > lastPtrField {
// gcprog should not include anything for any field after
// the last field that contains pointer data
break
}
if !ft.typ.pointers() {
// Ignore pointerless fields.
continue
}
// Pad to start of this field with zeros.
if ft.offset > off {
n := (ft.offset - off) / goarch.PtrSize
prog = append(prog, 0x01, 0x00) // emit a 0 bit
if n > 1 {
prog = append(prog, 0x81) // repeat previous bit
prog = appendVarint(prog, n-1) // n-1 times
}
off = ft.offset
}
prog = appendGCProg(prog, ft.typ)
off += ft.typ.ptrdata
}
prog = append(prog, 0)
*(*uint32)(unsafe.Pointer(&prog[0])) = uint32(len(prog) - 4)
typ.kind |= kindGCProg
typ.gcdata = &prog[0]
} else {
typ.kind &^= kindGCProg
bv := new(bitVector)
addTypeBits(bv, 0, typ.common())
if len(bv.data) > 0 {
typ.gcdata = &bv.data[0]
}
}
typ.equal = nil
if comparable {
typ.equal = func(p, q unsafe.Pointer) bool {
for _, ft := range typ.fields {
pi := add(p, ft.offset, "&x.field safe")
qi := add(q, ft.offset, "&x.field safe")
if !ft.typ.equal(pi, qi) {
return false
}
}
return true
}
}
switch {
case len(fs) == 1 && !ifaceIndir(fs[0].typ):
// structs of 1 direct iface type can be direct
typ.kind |= kindDirectIface
default:
typ.kind &^= kindDirectIface
}
return addToCache(&typ.rtype)
}
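// A hedged sketch of building and instantiating a struct type at run time
// from client code (the field names and tags are illustrative):
//
//	t := reflect.StructOf([]reflect.StructField{
//		{Name: "Name", Type: reflect.TypeOf(""), Tag: `json:"name"`},
//		{Name: "Age", Type: reflect.TypeOf(0), Tag: `json:"age"`},
//	})
//	v := reflect.New(t).Elem()
//	v.Field(0).SetString("gopher")
//	fmt.Println(v.Interface()) // {gopher 0}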
// runtimeStructField takes a StructField value passed to StructOf and
// returns both the corresponding internal representation, of type
// structField, and the pkgpath value to use for this field.
func runtimeStructField(field StructField) (structField, string) {
if field.Anonymous && field.PkgPath != "" {
panic("reflect.StructOf: field \"" + field.Name + "\" is anonymous but has PkgPath set")
}
if field.IsExported() {
// Best-effort check for misuse.
// Since this field will be treated as exported, not much harm done if Unicode lowercase slips through.
c := field.Name[0]
if 'a' <= c && c <= 'z' || c == '_' {
panic("reflect.StructOf: field \"" + field.Name + "\" is unexported but missing PkgPath")
}
}
resolveReflectType(field.Type.common()) // install in runtime
f := structField{
name: newName(field.Name, string(field.Tag), field.IsExported(), field.Anonymous),
typ: field.Type.common(),
offset: 0,
}
return f, field.PkgPath
}
// typeptrdata returns the length in bytes of the prefix of t
// containing pointer data. Anything after this offset is scalar data.
// keep in sync with ../cmd/compile/internal/reflectdata/reflect.go
func typeptrdata(t *rtype) uintptr {
switch t.Kind() {
case Struct:
st := (*structType)(unsafe.Pointer(t))
// find the last field that has pointers.
field := -1
for i := range st.fields {
ft := st.fields[i].typ
if ft.pointers() {
field = i
}
}
if field == -1 {
return 0
}
f := st.fields[field]
return f.offset + f.typ.ptrdata
default:
panic("reflect.typeptrdata: unexpected type, " + t.String())
}
}
// See cmd/compile/internal/reflectdata/reflect.go for derivation of constant.
const maxPtrmaskBytes = 2048
// ArrayOf returns the array type with the given length and element type.
// For example, if t represents int, ArrayOf(5, t) represents [5]int.
//
// If the resulting type would be larger than the available address space,
// ArrayOf panics.
func ArrayOf(length int, elem Type) Type {
if length < 0 {
panic("reflect: negative length passed to ArrayOf")
}
typ := elem.(*rtype)
// Look in cache.
ckey := cacheKey{Array, typ, nil, uintptr(length)}
if array, ok := lookupCache.Load(ckey); ok {
return array.(Type)
}
// Look in known types.
s := "[" + strconv.Itoa(length) + "]" + typ.String()
for _, tt := range typesByString(s) {
array := (*arrayType)(unsafe.Pointer(tt))
if array.elem == typ {
ti, _ := lookupCache.LoadOrStore(ckey, tt)
return ti.(Type)
}
}
// Make an array type.
var iarray any = [1]unsafe.Pointer{}
prototype := *(**arrayType)(unsafe.Pointer(&iarray))
array := *prototype
array.tflag = typ.tflag & tflagRegularMemory
array.str = resolveReflectName(newName(s, "", false, false))
array.hash = fnv1(typ.hash, '[')
for n := uint32(length); n > 0; n >>= 8 {
array.hash = fnv1(array.hash, byte(n))
}
array.hash = fnv1(array.hash, ']')
array.elem = typ
array.ptrToThis = 0
if typ.size > 0 {
max := ^uintptr(0) / typ.size
if uintptr(length) > max {
panic("reflect.ArrayOf: array size would exceed virtual address space")
}
}
array.size = typ.size * uintptr(length)
if length > 0 && typ.ptrdata != 0 {
array.ptrdata = typ.size*uintptr(length-1) + typ.ptrdata
}
array.align = typ.align
array.fieldAlign = typ.fieldAlign
array.len = uintptr(length)
array.slice = SliceOf(elem).(*rtype)
switch {
case typ.ptrdata == 0 || array.size == 0:
// No pointers.
array.gcdata = nil
array.ptrdata = 0
case length == 1:
// In memory, 1-element array looks just like the element.
array.kind |= typ.kind & kindGCProg
array.gcdata = typ.gcdata
array.ptrdata = typ.ptrdata
case typ.kind&kindGCProg == 0 && array.size <= maxPtrmaskBytes*8*goarch.PtrSize:
// Element is small with pointer mask; array is still small.
// Create direct pointer mask by turning each 1 bit in elem
// into length 1 bits in larger mask.
n := (array.ptrdata/goarch.PtrSize + 7) / 8
// Runtime needs pointer masks to be a multiple of uintptr in size.
n = (n + goarch.PtrSize - 1) &^ (goarch.PtrSize - 1)
mask := make([]byte, n)
emitGCMask(mask, 0, typ, array.len)
array.gcdata = &mask[0]
default:
// Create program that emits one element
// and then repeats to make the array.
prog := []byte{0, 0, 0, 0} // will be length of prog
prog = appendGCProg(prog, typ)
// Pad from ptrdata to size.
elemPtrs := typ.ptrdata / goarch.PtrSize
elemWords := typ.size / goarch.PtrSize
if elemPtrs < elemWords {
// Emit literal 0 bit, then repeat as needed.
prog = append(prog, 0x01, 0x00)
if elemPtrs+1 < elemWords {
prog = append(prog, 0x81)
prog = appendVarint(prog, elemWords-elemPtrs-1)
}
}
// Repeat length-1 times.
if elemWords < 0x80 {
prog = append(prog, byte(elemWords|0x80))
} else {
prog = append(prog, 0x80)
prog = appendVarint(prog, elemWords)
}
prog = appendVarint(prog, uintptr(length)-1)
prog = append(prog, 0)
*(*uint32)(unsafe.Pointer(&prog[0])) = uint32(len(prog) - 4)
array.kind |= kindGCProg
array.gcdata = &prog[0]
array.ptrdata = array.size // overestimate but ok; must match program
}
etyp := typ.common()
esize := etyp.Size()
array.equal = nil
if eequal := etyp.equal; eequal != nil {
array.equal = func(p, q unsafe.Pointer) bool {
for i := 0; i < length; i++ {
pi := arrayAt(p, i, esize, "i < length")
qi := arrayAt(q, i, esize, "i < length")
if !eequal(pi, qi) {
return false
}
}
return true
}
}
switch {
case length == 1 && !ifaceIndir(typ):
// array of 1 direct iface type can be direct
array.kind |= kindDirectIface
default:
array.kind &^= kindDirectIface
}
ti, _ := lookupCache.LoadOrStore(ckey, &array.rtype)
return ti.(Type)
}
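// A minimal usage sketch (client code):
//
//	at := reflect.ArrayOf(3, reflect.TypeOf(0))
//	fmt.Println(at)              // [3]int
//	fmt.Println(at.Comparable()) // true: int is comparable, so [3]int is too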
func appendVarint(x []byte, v uintptr) []byte {
for ; v >= 0x80; v >>= 7 {
x = append(x, byte(v|0x80))
}
x = append(x, byte(v))
return x
}
// toType converts from a *rtype to a Type that can be returned
// to the client of package reflect. In gc, the only concern is that
// a nil *rtype must be replaced by a nil Type, but in gccgo this
// function takes care of ensuring that multiple *rtype for the same
// type are coalesced into a single Type.
func toType(t *rtype) Type {
if t == nil {
return nil
}
return t
}
type layoutKey struct {
ftyp *funcType // function signature
rcvr *rtype // receiver type, or nil if none
}
type layoutType struct {
t *rtype
framePool *sync.Pool
abid abiDesc
}
var layoutCache sync.Map // map[layoutKey]layoutType
// funcLayout computes a struct type representing the layout of the
// stack-assigned function arguments and return values for the function
// type t.
// If rcvr != nil, rcvr specifies the type of the receiver.
// The returned type exists only for GC, so we only fill out GC relevant info.
// Currently, that's just size and the GC program. We also fill in
// the name for possible debugging use.
func funcLayout(t *funcType, rcvr *rtype) (frametype *rtype, framePool *sync.Pool, abid abiDesc) {
if t.Kind() != Func {
panic("reflect: funcLayout of non-func type " + t.String())
}
if rcvr != nil && rcvr.Kind() == Interface {
panic("reflect: funcLayout with interface receiver " + rcvr.String())
}
k := layoutKey{t, rcvr}
if lti, ok := layoutCache.Load(k); ok {
lt := lti.(layoutType)
return lt.t, lt.framePool, lt.abid
}
// Compute the ABI layout.
abid = newAbiDesc(t, rcvr)
// build dummy rtype holding gc program
x := &rtype{
align: goarch.PtrSize,
// Don't add spill space here; it's only necessary in
// reflectcall's frame, not in the allocated frame.
// TODO(mknyszek): Remove this comment when register
// spill space in the frame is no longer required.
size: align(abid.retOffset+abid.ret.stackBytes, goarch.PtrSize),
ptrdata: uintptr(abid.stackPtrs.n) * goarch.PtrSize,
}
if abid.stackPtrs.n > 0 {
x.gcdata = &abid.stackPtrs.data[0]
}
var s string
if rcvr != nil {
s = "methodargs(" + rcvr.String() + ")(" + t.String() + ")"
} else {
s = "funcargs(" + t.String() + ")"
}
x.str = resolveReflectName(newName(s, "", false, false))
// cache result for future callers
framePool = &sync.Pool{New: func() any {
return unsafe_New(x)
}}
lti, _ := layoutCache.LoadOrStore(k, layoutType{
t: x,
framePool: framePool,
abid: abid,
})
lt := lti.(layoutType)
return lt.t, lt.framePool, lt.abid
}
// ifaceIndir reports whether t is stored indirectly in an interface value.
func ifaceIndir(t *rtype) bool {
return t.kind&kindDirectIface == 0
}
// Note: this type must agree with runtime.bitvector.
type bitVector struct {
n uint32 // number of bits
data []byte
}
// append a bit to the bitmap.
func (bv *bitVector) append(bit uint8) {
if bv.n%(8*goarch.PtrSize) == 0 {
// Runtime needs pointer masks to be a multiple of uintptr in size.
// Since reflect passes bv.data directly to the runtime as a pointer mask,
// we append a full uintptr of zeros at a time.
for i := 0; i < goarch.PtrSize; i++ {
bv.data = append(bv.data, 0)
}
}
bv.data[bv.n/8] |= bit << (bv.n % 8)
bv.n++
}
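// addTypeBits records the pointer bitmap of a value of type t, placed at
// the given byte offset, into bv (one bit per pointer-sized word).
// An illustrative sketch, assuming a 64-bit layout: for
// struct{ x uintptr; p *int; s string }
// the pointer words are p and the string's data pointer, so the bits
// appended are 0, 1, 1.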
func addTypeBits(bv *bitVector, offset uintptr, t *rtype) {
if t.ptrdata == 0 {
return
}
switch Kind(t.kind & kindMask) {
case Chan, Func, Map, Pointer, Slice, String, UnsafePointer:
// 1 pointer at start of representation
for bv.n < uint32(offset/uintptr(goarch.PtrSize)) {
bv.append(0)
}
bv.append(1)
case Interface:
// 2 pointers
for bv.n < uint32(offset/uintptr(goarch.PtrSize)) {
bv.append(0)
}
bv.append(1)
bv.append(1)
case Array:
// repeat inner type
tt := (*arrayType)(unsafe.Pointer(t))
for i := 0; i < int(tt.len); i++ {
addTypeBits(bv, offset+uintptr(i)*tt.elem.size, tt.elem)
}
case Struct:
// apply fields
tt := (*structType)(unsafe.Pointer(t))
for i := range tt.fields {
f := &tt.fields[i]
addTypeBits(bv, offset+f.offset, f.typ)
}
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package reflect
import (
"errors"
"internal/abi"
"internal/goarch"
"internal/itoa"
"internal/unsafeheader"
"math"
"runtime"
"unsafe"
)
// Value is the reflection interface to a Go value.
//
// Not all methods apply to all kinds of values. Restrictions,
// if any, are noted in the documentation for each method.
// Use the Kind method to find out the kind of value before
// calling kind-specific methods. Calling a method
// inappropriate to the kind of type causes a run time panic.
//
// The zero Value represents no value.
// Its IsValid method returns false, its Kind method returns Invalid,
// its String method returns "<invalid Value>", and all other methods panic.
// Most functions and methods never return an invalid value.
// If one does, its documentation states the conditions explicitly.
//
// A Value can be used concurrently by multiple goroutines provided that
// the underlying Go value can be used concurrently for the equivalent
// direct operations.
//
// To compare two Values, compare the results of the Interface method.
// Using == on two Values does not compare the underlying values
// they represent.
type Value struct {
// typ holds the type of the value represented by a Value.
typ *rtype
// Pointer-valued data or, if flagIndir is set, pointer to data.
// Valid when either flagIndir is set or typ.pointers() is true.
ptr unsafe.Pointer
// flag holds metadata about the value.
//
// The lowest five bits give the Kind of the value, mirroring typ.Kind().
//
// The next set of bits are flag bits:
// - flagStickyRO: obtained via unexported not embedded field, so read-only
// - flagEmbedRO: obtained via unexported embedded field, so read-only
// - flagIndir: val holds a pointer to the data
// - flagAddr: v.CanAddr is true (implies flagIndir and ptr is non-nil)
// - flagMethod: v is a method value.
// If ifaceIndir(typ), code can assume that flagIndir is set.
//
// The remaining 22+ bits give a method number for method values.
// If flag.kind() != Func, code can assume that flagMethod is unset.
flag
// A method value represents a curried method invocation
// like r.Read for some receiver r. The typ+val+flag bits describe
// the receiver r, but the flag's Kind bits say Func (methods are
// functions), and the top bits of the flag give the method number
// in r's type's method table.
}
type flag uintptr
const (
flagKindWidth = 5 // there are 27 kinds
flagKindMask flag = 1<<flagKindWidth - 1
flagStickyRO flag = 1 << 5
flagEmbedRO flag = 1 << 6
flagIndir flag = 1 << 7
flagAddr flag = 1 << 8
flagMethod flag = 1 << 9
flagMethodShift = 10
flagRO flag = flagStickyRO | flagEmbedRO
)
func (f flag) kind() Kind {
return Kind(f & flagKindMask)
}
func (f flag) ro() flag {
if f&flagRO != 0 {
return flagStickyRO
}
return 0
}
// pointer returns the underlying pointer represented by v.
// v.Kind() must be Pointer, Map, Chan, Func, or UnsafePointer.
// If v.Kind() == Pointer, the base type must not be not-in-heap.
func (v Value) pointer() unsafe.Pointer {
if v.typ.size != goarch.PtrSize || !v.typ.pointers() {
panic("can't call pointer on a non-pointer Value")
}
if v.flag&flagIndir != 0 {
return *(*unsafe.Pointer)(v.ptr)
}
return v.ptr
}
// packEface converts v to the empty interface.
func packEface(v Value) any {
t := v.typ
var i any
e := (*emptyInterface)(unsafe.Pointer(&i))
// First, fill in the data portion of the interface.
switch {
case ifaceIndir(t):
if v.flag&flagIndir == 0 {
panic("bad indir")
}
// Value is indirect, and so is the interface we're making.
ptr := v.ptr
if v.flag&flagAddr != 0 {
// TODO: pass safe boolean from valueInterface so
// we don't need to copy if safe==true?
c := unsafe_New(t)
typedmemmove(t, c, ptr)
ptr = c
}
e.word = ptr
case v.flag&flagIndir != 0:
// Value is indirect, but interface is direct. We need
// to load the data at v.ptr into the interface data word.
e.word = *(*unsafe.Pointer)(v.ptr)
default:
// Value is direct, and so is the interface.
e.word = v.ptr
}
// Now, fill in the type portion. We're very careful here not
// to have any operation between the e.word and e.typ assignments
// that would let the garbage collector observe the partially-built
// interface value.
e.typ = t
return i
}
// unpackEface converts the empty interface i to a Value.
func unpackEface(i any) Value {
e := (*emptyInterface)(unsafe.Pointer(&i))
// NOTE: don't read e.word until we know whether it is really a pointer or not.
t := e.typ
if t == nil {
return Value{}
}
f := flag(t.Kind())
if ifaceIndir(t) {
f |= flagIndir
}
return Value{t, e.word, f}
}
// A ValueError occurs when a Value method is invoked on
// a Value that does not support it. Such cases are documented
// in the description of each method.
type ValueError struct {
Method string
Kind Kind
}
func (e *ValueError) Error() string {
if e.Kind == 0 {
return "reflect: call of " + e.Method + " on zero Value"
}
return "reflect: call of " + e.Method + " on " + e.Kind.String() + " Value"
}
// valueMethodName returns the name of the exported calling method on Value.
func valueMethodName() string {
var pc [5]uintptr
n := runtime.Callers(1, pc[:])
frames := runtime.CallersFrames(pc[:n])
var frame runtime.Frame
for more := true; more; {
const prefix = "reflect.Value."
frame, more = frames.Next()
name := frame.Function
if len(name) > len(prefix) && name[:len(prefix)] == prefix {
methodName := name[len(prefix):]
if len(methodName) > 0 && 'A' <= methodName[0] && methodName[0] <= 'Z' {
return name
}
}
}
return "unknown method"
}
// emptyInterface is the header for an interface{} value.
type emptyInterface struct {
typ *rtype
word unsafe.Pointer
}
// nonEmptyInterface is the header for an interface value with methods.
type nonEmptyInterface struct {
// see ../runtime/iface.go:/Itab
itab *struct {
ityp *rtype // static interface type
typ *rtype // dynamic concrete type
hash uint32 // copy of typ.hash
_ [4]byte
fun [100000]unsafe.Pointer // method table
}
word unsafe.Pointer
}
// mustBe panics if f's kind is not expected.
// Making this a method on flag instead of on Value
// (and embedding flag in Value) means that we can write
// the very clear v.mustBe(Bool) and have it compile into
// v.flag.mustBe(Bool), which will only bother to copy the
// single important word for the receiver.
func (f flag) mustBe(expected Kind) {
// TODO(mvdan): use f.kind() again once mid-stack inlining gets better
if Kind(f&flagKindMask) != expected {
panic(&ValueError{valueMethodName(), f.kind()})
}
}
// mustBeExported panics if f records that the value was obtained using
// an unexported field.
func (f flag) mustBeExported() {
if f == 0 || f&flagRO != 0 {
f.mustBeExportedSlow()
}
}
func (f flag) mustBeExportedSlow() {
if f == 0 {
panic(&ValueError{valueMethodName(), Invalid})
}
if f&flagRO != 0 {
panic("reflect: " + valueMethodName() + " using value obtained using unexported field")
}
}
// mustBeAssignable panics if f records that the value is not assignable,
// which is to say that either it was obtained using an unexported field
// or it is not addressable.
func (f flag) mustBeAssignable() {
if f&flagRO != 0 || f&flagAddr == 0 {
f.mustBeAssignableSlow()
}
}
func (f flag) mustBeAssignableSlow() {
if f == 0 {
panic(&ValueError{valueMethodName(), Invalid})
}
// Assignable if addressable and not read-only.
if f&flagRO != 0 {
panic("reflect: " + valueMethodName() + " using value obtained using unexported field")
}
if f&flagAddr == 0 {
panic("reflect: " + valueMethodName() + " using unaddressable value")
}
}
// Addr returns a pointer value representing the address of v.
// It panics if CanAddr() returns false.
// Addr is typically used to obtain a pointer to a struct field
// or slice element in order to call a method that requires a
// pointer receiver.
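//
// A minimal sketch of typical use (names here are illustrative):
//
// x := struct{ N int }{}
// v := reflect.ValueOf(&x).Elem() // addressable
// p := v.Addr()                   // Value of Kind Pointer, pointing at x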
func (v Value) Addr() Value {
if v.flag&flagAddr == 0 {
panic("reflect.Value.Addr of unaddressable value")
}
// Preserve flagRO instead of using v.flag.ro() so that
// v.Addr().Elem() is equivalent to v (#32772)
fl := v.flag & flagRO
return Value{v.typ.ptrTo(), v.ptr, fl | flag(Pointer)}
}
// Bool returns v's underlying value.
// It panics if v's kind is not Bool.
func (v Value) Bool() bool {
// panicNotBool is split out to keep Bool inlineable.
if v.kind() != Bool {
v.panicNotBool()
}
return *(*bool)(v.ptr)
}
func (v Value) panicNotBool() {
v.mustBe(Bool)
}
var bytesType = rtypeOf(([]byte)(nil))
// Bytes returns v's underlying value.
// It panics if v's underlying value is not a slice of bytes or
// an addressable array of bytes.
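//
// A minimal sketch (illustrative; the array case requires addressability):
//
// var a [4]byte
// b := reflect.ValueOf(&a).Elem().Bytes() // aliases a's storage
// b[0] = 'x'                              // a[0] is now 'x'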
func (v Value) Bytes() []byte {
// bytesSlow is split out to keep Bytes inlineable for unnamed []byte.
if v.typ == bytesType {
return *(*[]byte)(v.ptr)
}
return v.bytesSlow()
}
func (v Value) bytesSlow() []byte {
switch v.kind() {
case Slice:
if v.typ.Elem().Kind() != Uint8 {
panic("reflect.Value.Bytes of non-byte slice")
}
// Slice is always bigger than a word; assume flagIndir.
return *(*[]byte)(v.ptr)
case Array:
if v.typ.Elem().Kind() != Uint8 {
panic("reflect.Value.Bytes of non-byte array")
}
if !v.CanAddr() {
panic("reflect.Value.Bytes of unaddressable byte array")
}
p := (*byte)(v.ptr)
n := int((*arrayType)(unsafe.Pointer(v.typ)).len)
return unsafe.Slice(p, n)
}
panic(&ValueError{"reflect.Value.Bytes", v.kind()})
}
// runes returns v's underlying value.
// It panics if v's underlying value is not a slice of runes (int32s).
func (v Value) runes() []rune {
v.mustBe(Slice)
if v.typ.Elem().Kind() != Int32 {
panic("reflect.Value.Bytes of non-rune slice")
}
// Slice is always bigger than a word; assume flagIndir.
return *(*[]rune)(v.ptr)
}
// CanAddr reports whether the value's address can be obtained with Addr.
// Such values are called addressable. A value is addressable if it is
// an element of a slice, an element of an addressable array,
// a field of an addressable struct, or the result of dereferencing a pointer.
// If CanAddr returns false, calling Addr will panic.
func (v Value) CanAddr() bool {
return v.flag&flagAddr != 0
}
// CanSet reports whether the value of v can be changed.
// A Value can be changed only if it is addressable and was not
// obtained by the use of unexported struct fields.
// If CanSet returns false, calling Set or any type-specific
// setter (e.g., SetBool, SetInt) will panic.
func (v Value) CanSet() bool {
return v.flag&(flagAddr|flagRO) == flagAddr
}
// Call calls the function v with the input arguments in.
// For example, if len(in) == 3, v.Call(in) represents the Go call v(in[0], in[1], in[2]).
// Call panics if v's Kind is not Func.
// It returns the output results as Values.
// As in Go, each input argument must be assignable to the
// type of the function's corresponding input parameter.
// If v is a variadic function, Call creates the variadic slice parameter
// itself, copying in the corresponding values.
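//
// A minimal sketch of typical use (names here are illustrative):
//
// f := reflect.ValueOf(func(a, b int) int { return a + b })
// in := []reflect.Value{reflect.ValueOf(1), reflect.ValueOf(2)}
// out := f.Call(in)
// _ = out[0].Int() // 3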
func (v Value) Call(in []Value) []Value {
v.mustBe(Func)
v.mustBeExported()
return v.call("Call", in)
}
// CallSlice calls the variadic function v with the input arguments in,
// assigning the slice in[len(in)-1] to v's final variadic argument.
// For example, if len(in) == 3, v.CallSlice(in) represents the Go call v(in[0], in[1], in[2]...).
// CallSlice panics if v's Kind is not Func or if v is not variadic.
// It returns the output results as Values.
// As in Go, each input argument must be assignable to the
// type of the function's corresponding input parameter.
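//
// A minimal sketch of typical use (names here are illustrative):
//
// f := reflect.ValueOf(fmt.Sprint) // variadic: func(...any) string
// in := []reflect.Value{reflect.ValueOf([]any{1, 2})}
// out := f.CallSlice(in) // like fmt.Sprint(1, 2)
// _ = out[0].String()    // "1 2"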
func (v Value) CallSlice(in []Value) []Value {
v.mustBe(Func)
v.mustBeExported()
return v.call("CallSlice", in)
}
var callGC bool // for testing; see TestCallMethodJump and TestCallArgLive
const debugReflectCall = false
func (v Value) call(op string, in []Value) []Value {
// Get function pointer, type.
t := (*funcType)(unsafe.Pointer(v.typ))
var (
fn unsafe.Pointer
rcvr Value
rcvrtype *rtype
)
if v.flag&flagMethod != 0 {
rcvr = v
rcvrtype, t, fn = methodReceiver(op, v, int(v.flag)>>flagMethodShift)
} else if v.flag&flagIndir != 0 {
fn = *(*unsafe.Pointer)(v.ptr)
} else {
fn = v.ptr
}
if fn == nil {
panic("reflect.Value.Call: call of nil function")
}
isSlice := op == "CallSlice"
n := t.NumIn()
isVariadic := t.IsVariadic()
if isSlice {
if !isVariadic {
panic("reflect: CallSlice of non-variadic function")
}
if len(in) < n {
panic("reflect: CallSlice with too few input arguments")
}
if len(in) > n {
panic("reflect: CallSlice with too many input arguments")
}
} else {
if isVariadic {
n--
}
if len(in) < n {
panic("reflect: Call with too few input arguments")
}
if !isVariadic && len(in) > n {
panic("reflect: Call with too many input arguments")
}
}
for _, x := range in {
if x.Kind() == Invalid {
panic("reflect: " + op + " using zero Value argument")
}
}
for i := 0; i < n; i++ {
if xt, targ := in[i].Type(), t.In(i); !xt.AssignableTo(targ) {
panic("reflect: " + op + " using " + xt.String() + " as type " + targ.String())
}
}
if !isSlice && isVariadic {
// prepare slice for remaining values
m := len(in) - n
slice := MakeSlice(t.In(n), m, m)
elem := t.In(n).Elem()
for i := 0; i < m; i++ {
x := in[n+i]
if xt := x.Type(); !xt.AssignableTo(elem) {
panic("reflect: cannot use " + xt.String() + " as type " + elem.String() + " in " + op)
}
slice.Index(i).Set(x)
}
origIn := in
in = make([]Value, n+1)
copy(in[:n], origIn)
in[n] = slice
}
nin := len(in)
if nin != t.NumIn() {
panic("reflect.Value.Call: wrong argument count")
}
nout := t.NumOut()
// Register argument space.
var regArgs abi.RegArgs
// Compute frame type.
frametype, framePool, abid := funcLayout(t, rcvrtype)
// Allocate a chunk of memory for frame if needed.
var stackArgs unsafe.Pointer
if frametype.size != 0 {
if nout == 0 {
stackArgs = framePool.Get().(unsafe.Pointer)
} else {
// Can't use pool if the function has return values.
// We will leak pointer to args in ret, so its lifetime is not scoped.
stackArgs = unsafe_New(frametype)
}
}
frameSize := frametype.size
if debugReflectCall {
println("reflect.call", t.String())
abid.dump()
}
// Copy inputs into args.
// Handle receiver.
inStart := 0
if rcvrtype != nil {
// Guaranteed to only be one word in size,
// so it will only take up exactly 1 abiStep (either
// in a register or on the stack).
switch st := abid.call.steps[0]; st.kind {
case abiStepStack:
storeRcvr(rcvr, stackArgs)
case abiStepPointer:
storeRcvr(rcvr, unsafe.Pointer(&regArgs.Ptrs[st.ireg]))
fallthrough
case abiStepIntReg:
storeRcvr(rcvr, unsafe.Pointer(&regArgs.Ints[st.ireg]))
case abiStepFloatReg:
storeRcvr(rcvr, unsafe.Pointer(&regArgs.Floats[st.freg]))
default:
panic("unknown ABI parameter kind")
}
inStart = 1
}
// Handle arguments.
for i, v := range in {
v.mustBeExported()
targ := t.In(i).(*rtype)
// TODO(mknyszek): Figure out if it's possible to get some
// scratch space for this assignment check. Previously, it
// was possible to use space in the argument frame.
v = v.assignTo("reflect.Value.Call", targ, nil)
stepsLoop:
for _, st := range abid.call.stepsForValue(i + inStart) {
switch st.kind {
case abiStepStack:
// Copy values to the "stack."
addr := add(stackArgs, st.stkOff, "precomputed stack arg offset")
if v.flag&flagIndir != 0 {
typedmemmove(targ, addr, v.ptr)
} else {
*(*unsafe.Pointer)(addr) = v.ptr
}
// There's only one step for a stack-allocated value.
break stepsLoop
case abiStepIntReg, abiStepPointer:
// Copy values to "integer registers."
if v.flag&flagIndir != 0 {
offset := add(v.ptr, st.offset, "precomputed value offset")
if st.kind == abiStepPointer {
// Duplicate this pointer in the pointer area of the
// register space. Otherwise, there's the potential for
// this to be the last reference to v.ptr.
regArgs.Ptrs[st.ireg] = *(*unsafe.Pointer)(offset)
}
intToReg(&regArgs, st.ireg, st.size, offset)
} else {
if st.kind == abiStepPointer {
// See the comment in abiStepPointer case above.
regArgs.Ptrs[st.ireg] = v.ptr
}
regArgs.Ints[st.ireg] = uintptr(v.ptr)
}
case abiStepFloatReg:
// Copy values to "float registers."
if v.flag&flagIndir == 0 {
panic("attempted to copy pointer to FP register")
}
offset := add(v.ptr, st.offset, "precomputed value offset")
floatToReg(&regArgs, st.freg, st.size, offset)
default:
panic("unknown ABI part kind")
}
}
}
// TODO(mknyszek): Remove this when we no longer have
// caller reserved spill space.
frameSize = align(frameSize, goarch.PtrSize)
frameSize += abid.spill
// Mark pointers in registers for the return path.
regArgs.ReturnIsPtr = abid.outRegPtrs
if debugReflectCall {
regArgs.Dump()
}
// For testing; see TestCallArgLive.
if callGC {
runtime.GC()
}
// Call.
call(frametype, fn, stackArgs, uint32(frametype.size), uint32(abid.retOffset), uint32(frameSize), &regArgs)
// For testing; see TestCallMethodJump.
if callGC {
runtime.GC()
}
var ret []Value
if nout == 0 {
if stackArgs != nil {
typedmemclr(frametype, stackArgs)
framePool.Put(stackArgs)
}
} else {
if stackArgs != nil {
// Zero the now unused input area of args,
// because the Values returned by this function contain pointers to the args object,
// and will thus keep the args object alive indefinitely.
typedmemclrpartial(frametype, stackArgs, 0, abid.retOffset)
}
// Wrap Values around return values in args.
ret = make([]Value, nout)
for i := 0; i < nout; i++ {
tv := t.Out(i)
if tv.Size() == 0 {
// For zero-sized return value, args+off may point to the next object.
// In this case, return the zero value instead.
ret[i] = Zero(tv)
continue
}
steps := abid.ret.stepsForValue(i)
if st := steps[0]; st.kind == abiStepStack {
// This value is on the stack. If part of a value is stack
// allocated, the entire value is according to the ABI. So
// just make an indirection into the allocated frame.
fl := flagIndir | flag(tv.Kind())
ret[i] = Value{tv.common(), add(stackArgs, st.stkOff, "tv.Size() != 0"), fl}
// Note: this does introduce false sharing between results -
// if any result is live, they are all live.
// (And the space for the args is live as well, but as we've
// cleared that space it isn't as big a deal.)
continue
}
// Handle pointers passed in registers.
if !ifaceIndir(tv.common()) {
// Pointer-valued data gets put directly
// into v.ptr.
if steps[0].kind != abiStepPointer {
print("kind=", steps[0].kind, ", type=", tv.String(), "\n")
panic("mismatch between ABI description and types")
}
ret[i] = Value{tv.common(), regArgs.Ptrs[steps[0].ireg], flag(tv.Kind())}
continue
}
// All that's left is values passed in registers that we need to
// create space for and copy values back into.
//
// TODO(mknyszek): We make a new allocation for each register-allocated
// value, but previously we could always point into the heap-allocated
// stack frame. This is a regression that could be fixed by adding
// additional space to the allocated stack frame and storing the
// register-allocated return values into the allocated stack frame and
// referring there in the resulting Value.
s := unsafe_New(tv.common())
for _, st := range steps {
switch st.kind {
case abiStepIntReg:
offset := add(s, st.offset, "precomputed value offset")
intFromReg(&regArgs, st.ireg, st.size, offset)
case abiStepPointer:
s := add(s, st.offset, "precomputed value offset")
*((*unsafe.Pointer)(s)) = regArgs.Ptrs[st.ireg]
case abiStepFloatReg:
offset := add(s, st.offset, "precomputed value offset")
floatFromReg(&regArgs, st.freg, st.size, offset)
case abiStepStack:
panic("register-based return value has stack component")
default:
panic("unknown ABI part kind")
}
}
ret[i] = Value{tv.common(), s, flagIndir | flag(tv.Kind())}
}
}
return ret
}
// callReflect is the call implementation used by a function
// returned by MakeFunc. In many ways it is the opposite of the
// method Value.call above. The method above converts a call using Values
// into a call of a function with a concrete argument frame, while
// callReflect converts a call of a function with a concrete argument
// frame into a call using Values.
// It is in this file so that it can be next to the call method above.
// The remainder of the MakeFunc implementation is in makefunc.go.
//
// NOTE: This function must be marked as a "wrapper" in the generated code,
// so that the linker can make it work correctly for panic and recover.
// The gc compilers know to do that for the name "reflect.callReflect".
//
// ctxt is the "closure" generated by MakeFunc.
// frame is a pointer to the arguments to that closure on the stack.
// retValid points to a boolean which should be set when the results
// section of frame is set.
//
// regs contains the argument values passed in registers and will contain
// the values returned from ctxt.fn in registers.
func callReflect(ctxt *makeFuncImpl, frame unsafe.Pointer, retValid *bool, regs *abi.RegArgs) {
if callGC {
// Call GC upon entry during testing.
// Getting our stack scanned here is the biggest hazard, because
// our caller (makeFuncStub) could have failed to place the last
// pointer to a value in regs' pointer space, in which case it
// won't be visible to the GC.
runtime.GC()
}
ftyp := ctxt.ftyp
f := ctxt.fn
_, _, abid := funcLayout(ftyp, nil)
// Copy arguments into Values.
ptr := frame
in := make([]Value, 0, int(ftyp.inCount))
for i, typ := range ftyp.in() {
if typ.Size() == 0 {
in = append(in, Zero(typ))
continue
}
v := Value{typ, nil, flag(typ.Kind())}
steps := abid.call.stepsForValue(i)
if st := steps[0]; st.kind == abiStepStack {
if ifaceIndir(typ) {
// value cannot be inlined in interface data.
// Must make a copy, because f might keep a reference to it,
// and we cannot let f keep a reference to the stack frame
// after this function returns, not even a read-only reference.
v.ptr = unsafe_New(typ)
if typ.size > 0 {
typedmemmove(typ, v.ptr, add(ptr, st.stkOff, "typ.size > 0"))
}
v.flag |= flagIndir
} else {
v.ptr = *(*unsafe.Pointer)(add(ptr, st.stkOff, "1-ptr"))
}
} else {
if ifaceIndir(typ) {
// All that's left is values passed in registers that we need to
// create space for.
v.flag |= flagIndir
v.ptr = unsafe_New(typ)
for _, st := range steps {
switch st.kind {
case abiStepIntReg:
offset := add(v.ptr, st.offset, "precomputed value offset")
intFromReg(regs, st.ireg, st.size, offset)
case abiStepPointer:
s := add(v.ptr, st.offset, "precomputed value offset")
*((*unsafe.Pointer)(s)) = regs.Ptrs[st.ireg]
case abiStepFloatReg:
offset := add(v.ptr, st.offset, "precomputed value offset")
floatFromReg(regs, st.freg, st.size, offset)
case abiStepStack:
panic("register-based return value has stack component")
default:
panic("unknown ABI part kind")
}
}
} else {
// Pointer-valued data gets put directly
// into v.ptr.
if steps[0].kind != abiStepPointer {
print("kind=", steps[0].kind, ", type=", typ.String(), "\n")
panic("mismatch between ABI description and types")
}
v.ptr = regs.Ptrs[steps[0].ireg]
}
}
in = append(in, v)
}
// Call underlying function.
out := f(in)
numOut := ftyp.NumOut()
if len(out) != numOut {
panic("reflect: wrong return count from function created by MakeFunc")
}
// Copy results back into argument frame and register space.
if numOut > 0 {
for i, typ := range ftyp.out() {
v := out[i]
if v.typ == nil {
panic("reflect: function created by MakeFunc using " + funcName(f) +
" returned zero Value")
}
if v.flag&flagRO != 0 {
panic("reflect: function created by MakeFunc using " + funcName(f) +
" returned value obtained from unexported field")
}
if typ.size == 0 {
continue
}
// Convert v to type typ if v is assignable to a variable
// of type t in the language spec.
// See issue 28761.
//
// TODO(mknyszek): In the switch to the register ABI we lost
// the scratch space here for the register cases (and
// temporarily for all the cases).
//
// If/when this happens, take note of the following:
//
// We must clear the destination before calling assignTo,
// in case assignTo writes (with memory barriers) to the
// target location used as scratch space. See issue 39541.
v = v.assignTo("reflect.MakeFunc", typ, nil)
stepsLoop:
for _, st := range abid.ret.stepsForValue(i) {
switch st.kind {
case abiStepStack:
// Copy values to the "stack."
addr := add(ptr, st.stkOff, "precomputed stack arg offset")
// Do not use write barriers. The stack space used
// for this call is not adequately zeroed, and we
// are careful to keep the arguments alive until we
// return to makeFuncStub's caller.
if v.flag&flagIndir != 0 {
memmove(addr, v.ptr, st.size)
} else {
// This case must be a pointer type.
*(*uintptr)(addr) = uintptr(v.ptr)
}
// There's only one step for a stack-allocated value.
break stepsLoop
case abiStepIntReg, abiStepPointer:
// Copy values to "integer registers."
if v.flag&flagIndir != 0 {
offset := add(v.ptr, st.offset, "precomputed value offset")
intToReg(regs, st.ireg, st.size, offset)
} else {
// Only populate the Ints space on the return path.
// This is safe because out is kept alive until the
// end of this function, and the return path through
// makeFuncStub has no preemption, so these pointers
// are always visible to the GC.
regs.Ints[st.ireg] = uintptr(v.ptr)
}
case abiStepFloatReg:
// Copy values to "float registers."
if v.flag&flagIndir == 0 {
panic("attempted to copy pointer to FP register")
}
offset := add(v.ptr, st.offset, "precomputed value offset")
floatToReg(regs, st.freg, st.size, offset)
default:
panic("unknown ABI part kind")
}
}
}
}
// Announce that the return values are valid.
// After this point the runtime can depend on the return values being valid.
*retValid = true
// We have to make sure that the out slice lives at least until
// the runtime knows the return values are valid. Otherwise, the
// return values might not be scanned by anyone during a GC.
// (out would be dead, and the return slots not yet alive.)
runtime.KeepAlive(out)
// runtime.getArgInfo expects to be able to find ctxt on the
// stack when it finds our caller, makeFuncStub. Make sure it
// doesn't get garbage collected.
runtime.KeepAlive(ctxt)
}
// methodReceiver returns information about the receiver
// described by v. The Value v may or may not have the
// flagMethod bit set, so the kind cached in v.flag should
// not be used.
// The return value rcvrtype gives the method's actual receiver type.
// The return value t gives the method type signature (without the receiver).
// The return value fn is a pointer to the method code.
func methodReceiver(op string, v Value, methodIndex int) (rcvrtype *rtype, t *funcType, fn unsafe.Pointer) {
i := methodIndex
if v.typ.Kind() == Interface {
tt := (*interfaceType)(unsafe.Pointer(v.typ))
if uint(i) >= uint(len(tt.methods)) {
panic("reflect: internal error: invalid method index")
}
m := &tt.methods[i]
if !tt.nameOff(m.name).isExported() {
panic("reflect: " + op + " of unexported method")
}
iface := (*nonEmptyInterface)(v.ptr)
if iface.itab == nil {
panic("reflect: " + op + " of method on nil interface value")
}
rcvrtype = iface.itab.typ
fn = unsafe.Pointer(&iface.itab.fun[i])
t = (*funcType)(unsafe.Pointer(tt.typeOff(m.typ)))
} else {
rcvrtype = v.typ
ms := v.typ.exportedMethods()
if uint(i) >= uint(len(ms)) {
panic("reflect: internal error: invalid method index")
}
m := ms[i]
if !v.typ.nameOff(m.name).isExported() {
panic("reflect: " + op + " of unexported method")
}
ifn := v.typ.textOff(m.ifn)
fn = unsafe.Pointer(&ifn)
t = (*funcType)(unsafe.Pointer(v.typ.typeOff(m.mtyp)))
}
return
}
// v is a method receiver. Store at p the word which is used to
// encode that receiver at the start of the argument list.
// Reflect uses the "interface" calling convention for
// methods, which always uses one word to record the receiver.
func storeRcvr(v Value, p unsafe.Pointer) {
t := v.typ
if t.Kind() == Interface {
// the interface data word becomes the receiver word
iface := (*nonEmptyInterface)(v.ptr)
*(*unsafe.Pointer)(p) = iface.word
} else if v.flag&flagIndir != 0 && !ifaceIndir(t) {
*(*unsafe.Pointer)(p) = *(*unsafe.Pointer)(v.ptr)
} else {
*(*unsafe.Pointer)(p) = v.ptr
}
}
// align returns the result of rounding x up to a multiple of n.
// n must be a power of two.
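// For example, align(10, 8) == 16 and align(16, 8) == 16.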
func align(x, n uintptr) uintptr {
return (x + n - 1) &^ (n - 1)
}
// callMethod is the call implementation used by a function returned
// by makeMethodValue (used by v.Method(i).Interface()).
// It is a streamlined version of the usual reflect call: the caller has
// already laid out the argument frame for us, so we don't have
// to deal with individual Values for each argument.
// It is in this file so that it can be next to the two similar functions above.
// The remainder of the makeMethodValue implementation is in makefunc.go.
//
// NOTE: This function must be marked as a "wrapper" in the generated code,
// so that the linker can make it work correctly for panic and recover.
// The gc compilers know to do that for the name "reflect.callMethod".
//
// ctxt is the "closure" generated by makeVethodValue.
// frame is a pointer to the arguments to that closure on the stack.
// retValid points to a boolean which should be set when the results
// section of frame is set.
//
// regs contains the argument values passed in registers and will contain
// the values returned from ctxt.fn in registers.
func callMethod(ctxt *methodValue, frame unsafe.Pointer, retValid *bool, regs *abi.RegArgs) {
rcvr := ctxt.rcvr
rcvrType, valueFuncType, methodFn := methodReceiver("call", rcvr, ctxt.method)
// There are two ABIs at play here.
//
// methodValueCall was invoked with the ABI assuming there was no
// receiver ("value ABI") and that's what frame and regs are holding.
//
// Meanwhile, we need to actually call the method with a receiver, which
// has its own ABI ("method ABI"). Everything that follows is a translation
// between the two.
_, _, valueABI := funcLayout(valueFuncType, nil)
valueFrame, valueRegs := frame, regs
methodFrameType, methodFramePool, methodABI := funcLayout(valueFuncType, rcvrType)
// Make a new frame that is one word bigger so we can store the receiver.
// This space is used for both arguments and return values.
methodFrame := methodFramePool.Get().(unsafe.Pointer)
var methodRegs abi.RegArgs
// Deal with the receiver. It's guaranteed to only be one word in size.
switch st := methodABI.call.steps[0]; st.kind {
case abiStepStack:
// Only copy the receiver to the stack if the ABI says so.
// Otherwise, it'll be in a register already.
storeRcvr(rcvr, methodFrame)
case abiStepPointer:
// Put the receiver in a register.
storeRcvr(rcvr, unsafe.Pointer(&methodRegs.Ptrs[st.ireg]))
fallthrough
case abiStepIntReg:
storeRcvr(rcvr, unsafe.Pointer(&methodRegs.Ints[st.ireg]))
case abiStepFloatReg:
storeRcvr(rcvr, unsafe.Pointer(&methodRegs.Floats[st.freg]))
default:
panic("unknown ABI parameter kind")
}
// Translate the rest of the arguments.
for i, t := range valueFuncType.in() {
valueSteps := valueABI.call.stepsForValue(i)
methodSteps := methodABI.call.stepsForValue(i + 1)
// Zero-sized types are trivial: nothing to do.
if len(valueSteps) == 0 {
if len(methodSteps) != 0 {
panic("method ABI and value ABI do not align")
}
continue
}
// There are four cases to handle in translating each
// argument:
// 1. Stack -> stack translation.
// 2. Stack -> registers translation.
// 3. Registers -> stack translation.
// 4. Registers -> registers translation.
// If the value ABI passes the value on the stack,
// then the method ABI does too, because it has strictly
// fewer arguments. Simply copy between the two.
if vStep := valueSteps[0]; vStep.kind == abiStepStack {
mStep := methodSteps[0]
// Handle stack -> stack translation.
if mStep.kind == abiStepStack {
if vStep.size != mStep.size {
panic("method ABI and value ABI do not align")
}
typedmemmove(t,
add(methodFrame, mStep.stkOff, "precomputed stack offset"),
add(valueFrame, vStep.stkOff, "precomputed stack offset"))
continue
}
// Handle stack -> register translation.
for _, mStep := range methodSteps {
from := add(valueFrame, vStep.stkOff+mStep.offset, "precomputed stack offset")
switch mStep.kind {
case abiStepPointer:
// Do the pointer copy directly so we get a write barrier.
methodRegs.Ptrs[mStep.ireg] = *(*unsafe.Pointer)(from)
fallthrough // We need to make sure this ends up in Ints, too.
case abiStepIntReg:
intToReg(&methodRegs, mStep.ireg, mStep.size, from)
case abiStepFloatReg:
floatToReg(&methodRegs, mStep.freg, mStep.size, from)
default:
panic("unexpected method step")
}
}
continue
}
// Handle register -> stack translation.
if mStep := methodSteps[0]; mStep.kind == abiStepStack {
for _, vStep := range valueSteps {
to := add(methodFrame, mStep.stkOff+vStep.offset, "precomputed stack offset")
switch vStep.kind {
case abiStepPointer:
// Do the pointer copy directly so we get a write barrier.
*(*unsafe.Pointer)(to) = valueRegs.Ptrs[vStep.ireg]
case abiStepIntReg:
intFromReg(valueRegs, vStep.ireg, vStep.size, to)
case abiStepFloatReg:
floatFromReg(valueRegs, vStep.freg, vStep.size, to)
default:
panic("unexpected value step")
}
}
continue
}
// Handle register -> register translation.
if len(valueSteps) != len(methodSteps) {
// Because it's the same type for the value, and it's assigned
// to registers both times, it should always take up the same
// number of registers for each ABI.
panic("method ABI and value ABI don't align")
}
for i, vStep := range valueSteps {
mStep := methodSteps[i]
if mStep.kind != vStep.kind {
panic("method ABI and value ABI don't align")
}
switch vStep.kind {
case abiStepPointer:
// Copy this too, so we get a write barrier.
methodRegs.Ptrs[mStep.ireg] = valueRegs.Ptrs[vStep.ireg]
fallthrough
case abiStepIntReg:
methodRegs.Ints[mStep.ireg] = valueRegs.Ints[vStep.ireg]
case abiStepFloatReg:
methodRegs.Floats[mStep.freg] = valueRegs.Floats[vStep.freg]
default:
panic("unexpected value step")
}
}
}
methodFrameSize := methodFrameType.size
// TODO(mknyszek): Remove this when we no longer have
// caller reserved spill space.
methodFrameSize = align(methodFrameSize, goarch.PtrSize)
methodFrameSize += methodABI.spill
// Mark pointers in registers for the return path.
methodRegs.ReturnIsPtr = methodABI.outRegPtrs
// Call.
// Call copies the arguments from scratch to the stack, calls fn,
// and then copies the results back into scratch.
call(methodFrameType, methodFn, methodFrame, uint32(methodFrameType.size), uint32(methodABI.retOffset), uint32(methodFrameSize), &methodRegs)
// Copy return values.
//
// This is somewhat simpler because both ABIs have an identical
// return value ABI (the types are identical). As a result, register
// results can simply be copied over. Stack-allocated values are laid
// out the same, but are at different offsets from the start of the frame
// because the arguments may be laid out differently.
// Ignore any changes to args.
// Avoid constructing out-of-bounds pointers if there are no return values.
if valueRegs != nil {
*valueRegs = methodRegs
}
if retSize := methodFrameType.size - methodABI.retOffset; retSize > 0 {
valueRet := add(valueFrame, valueABI.retOffset, "valueFrame's size > retOffset")
methodRet := add(methodFrame, methodABI.retOffset, "methodFrame's size > retOffset")
// This copies to the stack. Write barriers are not needed.
memmove(valueRet, methodRet, retSize)
}
// Tell the runtime it can now depend on the return values
// being properly initialized.
*retValid = true
// Clear the scratch space and put it back in the pool.
// This must happen after the statement above, so that the return
// values will always be scanned by someone.
typedmemclr(methodFrameType, methodFrame)
methodFramePool.Put(methodFrame)
// See the comment in callReflect.
runtime.KeepAlive(ctxt)
// Keep valueRegs alive because it may hold live pointer results.
// The caller (methodValueCall) has it as a stack object, which is only
// scanned when there is a reference to it.
runtime.KeepAlive(valueRegs)
}
// funcName returns the name of f, for use in error messages.
func funcName(f func([]Value) []Value) string {
pc := *(*uintptr)(unsafe.Pointer(&f))
rf := runtime.FuncForPC(pc)
if rf != nil {
return rf.Name()
}
return "closure"
}
// Cap returns v's capacity.
// It panics if v's Kind is not Array, Chan, Slice or pointer to Array.
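//
// A minimal sketch (illustrative):
//
// a := [3]int{}
// _ = reflect.ValueOf(&a).Cap() // 3, via the pointer-to-array case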
func (v Value) Cap() int {
// capNonSlice is split out to keep Cap inlineable for slice kinds.
if v.kind() == Slice {
return (*unsafeheader.Slice)(v.ptr).Cap
}
return v.capNonSlice()
}
func (v Value) capNonSlice() int {
k := v.kind()
switch k {
case Array:
return v.typ.Len()
case Chan:
return chancap(v.pointer())
case Ptr:
if v.typ.Elem().Kind() == Array {
return v.typ.Elem().Len()
}
panic("reflect: call of reflect.Value.Cap on ptr to non-array Value")
}
panic(&ValueError{"reflect.Value.Cap", v.kind()})
}
// Close closes the channel v.
// It panics if v's Kind is not Chan.
func (v Value) Close() {
v.mustBe(Chan)
v.mustBeExported()
chanclose(v.pointer())
}
// CanComplex reports whether Complex can be used without panicking.
func (v Value) CanComplex() bool {
switch v.kind() {
case Complex64, Complex128:
return true
default:
return false
}
}
// Complex returns v's underlying value, as a complex128.
// It panics if v's Kind is not Complex64 or Complex128.
func (v Value) Complex() complex128 {
k := v.kind()
switch k {
case Complex64:
return complex128(*(*complex64)(v.ptr))
case Complex128:
return *(*complex128)(v.ptr)
}
panic(&ValueError{"reflect.Value.Complex", v.kind()})
}
// Elem returns the value that the interface v contains
// or that the pointer v points to.
// It panics if v's Kind is not Interface or Pointer.
// It returns the zero Value if v is nil.
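//
// A minimal sketch of typical use (names here are illustrative):
//
// x := 7
// v := reflect.ValueOf(&x).Elem() // Kind Int, addressable
// v.SetInt(8)                     // x is now 8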
func (v Value) Elem() Value {
k := v.kind()
switch k {
case Interface:
var eface any
if v.typ.NumMethod() == 0 {
eface = *(*any)(v.ptr)
} else {
eface = (any)(*(*interface {
M()
})(v.ptr))
}
x := unpackEface(eface)
if x.flag != 0 {
x.flag |= v.flag.ro()
}
return x
case Pointer:
ptr := v.ptr
if v.flag&flagIndir != 0 {
if ifaceIndir(v.typ) {
// This is a pointer to a not-in-heap object. ptr points to a uintptr
// in the heap. That uintptr is the address of a not-in-heap object.
// In general, pointers to not-in-heap objects can be total junk.
// But Elem() is asking to dereference it, so the user has asserted
// that at least it is a valid pointer (not just an integer stored in
// a pointer slot). So let's check, to make sure that it isn't a pointer
// that the runtime will crash on if it sees it during GC or write barriers.
// Since it is a not-in-heap pointer, all pointers to the heap are
// forbidden! That makes the test pretty easy.
// See issue 48399.
if !verifyNotInHeapPtr(*(*uintptr)(ptr)) {
panic("reflect: reflect.Value.Elem on an invalid notinheap pointer")
}
}
ptr = *(*unsafe.Pointer)(ptr)
}
// The returned value's address is v's value.
if ptr == nil {
return Value{}
}
tt := (*ptrType)(unsafe.Pointer(v.typ))
typ := tt.elem
fl := v.flag&flagRO | flagIndir | flagAddr
fl |= flag(typ.Kind())
return Value{typ, ptr, fl}
}
panic(&ValueError{"reflect.Value.Elem", v.kind()})
}
// Field returns the i'th field of the struct v.
// It panics if v's Kind is not Struct or i is out of range.
func (v Value) Field(i int) Value {
if v.kind() != Struct {
panic(&ValueError{"reflect.Value.Field", v.kind()})
}
tt := (*structType)(unsafe.Pointer(v.typ))
if uint(i) >= uint(len(tt.fields)) {
panic("reflect: Field index out of range")
}
field := &tt.fields[i]
typ := field.typ
// Inherit permission bits from v, but clear flagEmbedRO.
fl := v.flag&(flagStickyRO|flagIndir|flagAddr) | flag(typ.Kind())
// Using an unexported field forces flagRO.
if !field.name.isExported() {
if field.embedded() {
fl |= flagEmbedRO
} else {
fl |= flagStickyRO
}
}
// Either flagIndir is set and v.ptr points at struct,
// or flagIndir is not set and v.ptr is the actual struct data.
// In the former case, we want v.ptr + offset.
// In the latter case, we must have field.offset = 0,
// so v.ptr + field.offset is still the correct address.
ptr := add(v.ptr, field.offset, "same as non-reflect &v.field")
return Value{typ, ptr, fl}
}
// FieldByIndex returns the nested field corresponding to index.
// It panics if evaluation requires stepping through a nil
// pointer or a field that is not a struct.
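//
// A minimal sketch (types here are illustrative):
//
// type Inner struct{ N int }
// type Outer struct{ In Inner }
// v := reflect.ValueOf(Outer{Inner{42}})
// _ = v.FieldByIndex([]int{0, 0}).Int() // 42, i.e. Outer.In.N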
func (v Value) FieldByIndex(index []int) Value {
if len(index) == 1 {
return v.Field(index[0])
}
v.mustBe(Struct)
for i, x := range index {
if i > 0 {
if v.Kind() == Pointer && v.typ.Elem().Kind() == Struct {
if v.IsNil() {
panic("reflect: indirection through nil pointer to embedded struct")
}
v = v.Elem()
}
}
v = v.Field(x)
}
return v
}
// FieldByIndexErr returns the nested field corresponding to index.
// It returns an error if evaluation requires stepping through a nil
// pointer, but panics if it must step through a field that
// is not a struct.
func (v Value) FieldByIndexErr(index []int) (Value, error) {
if len(index) == 1 {
return v.Field(index[0]), nil
}
v.mustBe(Struct)
for i, x := range index {
if i > 0 {
if v.Kind() == Ptr && v.typ.Elem().Kind() == Struct {
if v.IsNil() {
return Value{}, errors.New("reflect: indirection through nil pointer to embedded struct field " + v.typ.Elem().Name())
}
v = v.Elem()
}
}
v = v.Field(x)
}
return v, nil
}
// FieldByName returns the struct field with the given name.
// It returns the zero Value if no field was found.
// It panics if v's Kind is not struct.
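//
// A minimal sketch (illustrative):
//
// v := reflect.ValueOf(struct{ Name string }{"gopher"})
// _ = v.FieldByName("Name").String() // "gopher"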
func (v Value) FieldByName(name string) Value {
v.mustBe(Struct)
if f, ok := v.typ.FieldByName(name); ok {
return v.FieldByIndex(f.Index)
}
return Value{}
}
// FieldByNameFunc returns the struct field with a name
// that satisfies the match function.
// It panics if v's Kind is not struct.
// It returns the zero Value if no field was found.
func (v Value) FieldByNameFunc(match func(string) bool) Value {
if f, ok := v.typ.FieldByNameFunc(match); ok {
return v.FieldByIndex(f.Index)
}
return Value{}
}
// CanFloat reports whether Float can be used without panicking.
func (v Value) CanFloat() bool {
switch v.kind() {
case Float32, Float64:
return true
default:
return false
}
}
// Float returns v's underlying value, as a float64.
// It panics if v's Kind is not Float32 or Float64.
func (v Value) Float() float64 {
k := v.kind()
switch k {
case Float32:
return float64(*(*float32)(v.ptr))
case Float64:
return *(*float64)(v.ptr)
}
panic(&ValueError{"reflect.Value.Float", v.kind()})
}
var uint8Type = rtypeOf(uint8(0))
// Index returns v's i'th element.
// It panics if v's Kind is not Array, Slice, or String or i is out of range.
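//
// A minimal sketch (illustrative):
//
// s := []int{10, 20, 30}
// _ = reflect.ValueOf(s).Index(1).Int() // 20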
func (v Value) Index(i int) Value {
switch v.kind() {
case Array:
tt := (*arrayType)(unsafe.Pointer(v.typ))
if uint(i) >= uint(tt.len) {
panic("reflect: array index out of range")
}
typ := tt.elem
offset := uintptr(i) * typ.size
// Either flagIndir is set and v.ptr points at array,
// or flagIndir is not set and v.ptr is the actual array data.
// In the former case, we want v.ptr + offset.
// In the latter case, we must be doing Index(0), so offset = 0,
// so v.ptr + offset is still the correct address.
val := add(v.ptr, offset, "same as &v[i], i < tt.len")
fl := v.flag&(flagIndir|flagAddr) | v.flag.ro() | flag(typ.Kind()) // bits same as overall array
return Value{typ, val, fl}
case Slice:
// Element flag same as Elem of Pointer.
// Addressable, indirect, possibly read-only.
s := (*unsafeheader.Slice)(v.ptr)
if uint(i) >= uint(s.Len) {
panic("reflect: slice index out of range")
}
tt := (*sliceType)(unsafe.Pointer(v.typ))
typ := tt.elem
val := arrayAt(s.Data, i, typ.size, "i < s.Len")
fl := flagAddr | flagIndir | v.flag.ro() | flag(typ.Kind())
return Value{typ, val, fl}
case String:
s := (*unsafeheader.String)(v.ptr)
if uint(i) >= uint(s.Len) {
panic("reflect: string index out of range")
}
p := arrayAt(s.Data, i, 1, "i < s.Len")
fl := v.flag.ro() | flag(Uint8) | flagIndir
return Value{uint8Type, p, fl}
}
panic(&ValueError{"reflect.Value.Index", v.kind()})
}
// CanInt reports whether Int can be used without panicking.
func (v Value) CanInt() bool {
switch v.kind() {
case Int, Int8, Int16, Int32, Int64:
return true
default:
return false
}
}
// Int returns v's underlying value, as an int64.
// It panics if v's Kind is not Int, Int8, Int16, Int32, or Int64.
func (v Value) Int() int64 {
k := v.kind()
p := v.ptr
switch k {
case Int:
return int64(*(*int)(p))
case Int8:
return int64(*(*int8)(p))
case Int16:
return int64(*(*int16)(p))
case Int32:
return int64(*(*int32)(p))
case Int64:
return *(*int64)(p)
}
panic(&ValueError{"reflect.Value.Int", v.kind()})
}
// CanInterface reports whether Interface can be used without panicking.
func (v Value) CanInterface() bool {
if v.flag == 0 {
panic(&ValueError{"reflect.Value.CanInterface", Invalid})
}
return v.flag&flagRO == 0
}
// Interface returns v's current value as an interface{}.
// It is equivalent to:
//
// var i interface{} = (v's underlying value)
//
// It panics if the Value was obtained by accessing
// unexported struct fields.
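//
// A minimal sketch (illustrative):
//
// v := reflect.ValueOf(3)
// i := v.Interface() // i has dynamic type int
// n := i.(int)       // n == 3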
func (v Value) Interface() (i any) {
return valueInterface(v, true)
}
func valueInterface(v Value, safe bool) any {
if v.flag == 0 {
panic(&ValueError{"reflect.Value.Interface", Invalid})
}
if safe && v.flag&flagRO != 0 {
// Do not allow access to unexported values via Interface,
// because they might be pointers that should not be
// writable, or methods or functions that should not be callable.
panic("reflect.Value.Interface: cannot return value obtained from unexported field or method")
}
if v.flag&flagMethod != 0 {
v = makeMethodValue("Interface", v)
}
if v.kind() == Interface {
// Special case: return the element inside the interface.
// Empty interface has one layout, all interfaces with
// methods have a second layout.
if v.NumMethod() == 0 {
return *(*any)(v.ptr)
}
return *(*interface {
M()
})(v.ptr)
}
// TODO: pass safe to packEface so we don't need to copy if safe==true?
return packEface(v)
}
// InterfaceData returns a pair of unspecified uintptr values.
// It panics if v's Kind is not Interface.
//
// In earlier versions of Go, this function returned the interface's
// value as a uintptr pair. As of Go 1.4, the implementation of
// interface values precludes any defined use of InterfaceData.
//
// Deprecated: The memory representation of interface values is not
// compatible with InterfaceData.
func (v Value) InterfaceData() [2]uintptr {
v.mustBe(Interface)
// We treat this as a read operation, so we allow
// it even for unexported data, because the caller
// has to import "unsafe" to turn it into something
// that can be abused.
// Interface value is always bigger than a word; assume flagIndir.
return *(*[2]uintptr)(v.ptr)
}
// IsNil reports whether its argument v is nil. The argument must be
// a chan, func, interface, map, pointer, or slice value; if it is
// not, IsNil panics. Note that IsNil is not always equivalent to a
// regular comparison with nil in Go. For example, if v was created
// by calling ValueOf with an uninitialized interface variable i,
// i==nil will be true but v.IsNil will panic as v will be the zero
// Value.
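//
// A minimal sketch (illustrative):
//
// var p *int
// _ = reflect.ValueOf(p).IsNil() // true
// var i any
// reflect.ValueOf(i).IsNil() // panics: v is the zero Value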
func (v Value) IsNil() bool {
k := v.kind()
switch k {
case Chan, Func, Map, Pointer, UnsafePointer:
if v.flag&flagMethod != 0 {
return false
}
ptr := v.ptr
if v.flag&flagIndir != 0 {
ptr = *(*unsafe.Pointer)(ptr)
}
return ptr == nil
case Interface, Slice:
// Both interface and slice are nil if first word is 0.
// Both are always bigger than a word; assume flagIndir.
return *(*unsafe.Pointer)(v.ptr) == nil
}
panic(&ValueError{"reflect.Value.IsNil", v.kind()})
}
// IsValid reports whether v represents a value.
// It returns false if v is the zero Value.
// If IsValid returns false, all other methods except String panic.
// Most functions and methods never return an invalid Value.
// If one does, its documentation states the conditions explicitly.
func (v Value) IsValid() bool {
return v.flag != 0
}
// IsZero reports whether v is the zero value for its type.
// It panics if the argument is invalid.
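//
// A minimal sketch (illustrative):
//
// _ = reflect.ValueOf(0).IsZero()        // true
// _ = reflect.ValueOf([2]int{}).IsZero() // true
// _ = reflect.ValueOf("x").IsZero()      // false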
func (v Value) IsZero() bool {
switch v.kind() {
case Bool:
return !v.Bool()
case Int, Int8, Int16, Int32, Int64:
return v.Int() == 0
case Uint, Uint8, Uint16, Uint32, Uint64, Uintptr:
return v.Uint() == 0
case Float32, Float64:
return math.Float64bits(v.Float()) == 0
case Complex64, Complex128:
c := v.Complex()
return math.Float64bits(real(c)) == 0 && math.Float64bits(imag(c)) == 0
case Array:
// If the type is comparable, then compare directly with zero.
if v.typ.equal != nil && v.typ.size <= maxZero {
if v.flag&flagIndir == 0 {
return v.ptr == nil
}
return v.typ.equal(v.ptr, unsafe.Pointer(&zeroVal[0]))
}
n := v.Len()
for i := 0; i < n; i++ {
if !v.Index(i).IsZero() {
return false
}
}
return true
case Chan, Func, Interface, Map, Pointer, Slice, UnsafePointer:
return v.IsNil()
case String:
return v.Len() == 0
case Struct:
// If the type is comparable, then compare directly with zero.
if v.typ.equal != nil && v.typ.size <= maxZero {
if v.flag&flagIndir == 0 {
return v.ptr == nil
}
return v.typ.equal(v.ptr, unsafe.Pointer(&zeroVal[0]))
}
n := v.NumField()
for i := 0; i < n; i++ {
if !v.Field(i).IsZero() {
return false
}
}
return true
default:
// This should never happen, but will act as a safeguard for later,
// as a default value doesn't make sense here.
panic(&ValueError{"reflect.Value.IsZero", v.Kind()})
}
}
// SetZero sets v to be the zero value of v's type.
// It panics if CanSet returns false.
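//
// A minimal sketch (illustrative):
//
// x := 5
// reflect.ValueOf(&x).Elem().SetZero() // x is now 0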
func (v Value) SetZero() {
v.mustBeAssignable()
switch v.kind() {
case Bool:
*(*bool)(v.ptr) = false
case Int:
*(*int)(v.ptr) = 0
case Int8:
*(*int8)(v.ptr) = 0
case Int16:
*(*int16)(v.ptr) = 0
case Int32:
*(*int32)(v.ptr) = 0
case Int64:
*(*int64)(v.ptr) = 0
case Uint:
*(*uint)(v.ptr) = 0
case Uint8:
*(*uint8)(v.ptr) = 0
case Uint16:
*(*uint16)(v.ptr) = 0
case Uint32:
*(*uint32)(v.ptr) = 0
case Uint64:
*(*uint64)(v.ptr) = 0
case Uintptr:
*(*uintptr)(v.ptr) = 0
case Float32:
*(*float32)(v.ptr) = 0
case Float64:
*(*float64)(v.ptr) = 0
case Complex64:
*(*complex64)(v.ptr) = 0
case Complex128:
*(*complex128)(v.ptr) = 0
case String:
*(*string)(v.ptr) = ""
case Slice:
*(*unsafeheader.Slice)(v.ptr) = unsafeheader.Slice{}
case Interface:
*(*[2]unsafe.Pointer)(v.ptr) = [2]unsafe.Pointer{}
case Chan, Func, Map, Pointer, UnsafePointer:
*(*unsafe.Pointer)(v.ptr) = nil
case Array, Struct:
typedmemclr(v.typ, v.ptr)
default:
// This should never happen, but will act as a safeguard for later,
// as a default value doesn't make sense here.
panic(&ValueError{"reflect.Value.SetZero", v.Kind()})
}
}
// Kind returns v's Kind.
// If v is the zero Value (IsValid returns false), Kind returns Invalid.
func (v Value) Kind() Kind {
return v.kind()
}
// Len returns v's length.
// It panics if v's Kind is not Array, Chan, Map, Slice, String, or pointer to Array.
func (v Value) Len() int {
// lenNonSlice is split out to keep Len inlineable for slice kinds.
if v.kind() == Slice {
return (*unsafeheader.Slice)(v.ptr).Len
}
return v.lenNonSlice()
}
func (v Value) lenNonSlice() int {
switch k := v.kind(); k {
case Array:
tt := (*arrayType)(unsafe.Pointer(v.typ))
return int(tt.len)
case Chan:
return chanlen(v.pointer())
case Map:
return maplen(v.pointer())
case String:
// String is bigger than a word; assume flagIndir.
return (*unsafeheader.String)(v.ptr).Len
case Ptr:
if v.typ.Elem().Kind() == Array {
return v.typ.Elem().Len()
}
panic("reflect: call of reflect.Value.Len on ptr to non-array Value")
}
panic(&ValueError{"reflect.Value.Len", v.kind()})
}
var stringType = rtypeOf("")
// MapIndex returns the value associated with key in the map v.
// It panics if v's Kind is not Map.
// It returns the zero Value if key is not found in the map or if v represents a nil map.
// As in Go, the key's value must be assignable to the map's key type.
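//
// A minimal sketch (illustrative):
//
// m := map[string]int{"a": 1}
// v := reflect.ValueOf(m)
// _ = v.MapIndex(reflect.ValueOf("a")).Int() // 1
// _ = v.MapIndex(reflect.ValueOf("b"))       // zero Value: key absent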
func (v Value) MapIndex(key Value) Value {
v.mustBe(Map)
tt := (*mapType)(unsafe.Pointer(v.typ))
// Do not require key to be exported, so that DeepEqual
// and other programs can use all the keys returned by
// MapKeys as arguments to MapIndex. If either the map
// or the key is unexported, though, the result will be
// considered unexported. This is consistent with the
// behavior for structs, which allow read but not write
// of unexported fields.
var e unsafe.Pointer
if (tt.key == stringType || key.kind() == String) && tt.key == key.typ && tt.elem.size <= maxValSize {
k := *(*string)(key.ptr)
e = mapaccess_faststr(v.typ, v.pointer(), k)
} else {
key = key.assignTo("reflect.Value.MapIndex", tt.key, nil)
var k unsafe.Pointer
if key.flag&flagIndir != 0 {
k = key.ptr
} else {
k = unsafe.Pointer(&key.ptr)
}
e = mapaccess(v.typ, v.pointer(), k)
}
if e == nil {
return Value{}
}
typ := tt.elem
fl := (v.flag | key.flag).ro()
fl |= flag(typ.Kind())
return copyVal(typ, fl, e)
}
// MapKeys returns a slice containing all the keys present in the map,
// in unspecified order.
// It panics if v's Kind is not Map.
// It returns an empty slice if v represents a nil map.
func (v Value) MapKeys() []Value {
v.mustBe(Map)
tt := (*mapType)(unsafe.Pointer(v.typ))
keyType := tt.key
fl := v.flag.ro() | flag(keyType.Kind())
m := v.pointer()
mlen := int(0)
if m != nil {
mlen = maplen(m)
}
var it hiter
mapiterinit(v.typ, m, &it)
a := make([]Value, mlen)
var i int
for i = 0; i < len(a); i++ {
key := mapiterkey(&it)
if key == nil {
// Someone deleted an entry from the map since we
// called maplen above. It's a data race, but nothing
// we can do about it.
break
}
a[i] = copyVal(keyType, fl, key)
mapiternext(&it)
}
return a[:i]
}
// hiter's structure matches runtime.hiter's structure.
// Having a clone here allows us to embed a map iterator
// inside type MapIter so that MapIters can be re-used
// without doing any allocations.
type hiter struct {
key unsafe.Pointer
elem unsafe.Pointer
t unsafe.Pointer
h unsafe.Pointer
buckets unsafe.Pointer
bptr unsafe.Pointer
overflow *[]unsafe.Pointer
oldoverflow *[]unsafe.Pointer
startBucket uintptr
offset uint8
wrapped bool
B uint8
i uint8
bucket uintptr
checkBucket uintptr
}
func (h *hiter) initialized() bool {
return h.t != nil
}
// A MapIter is an iterator for ranging over a map.
// See Value.MapRange.
type MapIter struct {
m Value
hiter hiter
}
// Key returns the key of iter's current map entry.
func (iter *MapIter) Key() Value {
if !iter.hiter.initialized() {
panic("MapIter.Key called before Next")
}
iterkey := mapiterkey(&iter.hiter)
if iterkey == nil {
panic("MapIter.Key called on exhausted iterator")
}
t := (*mapType)(unsafe.Pointer(iter.m.typ))
ktype := t.key
return copyVal(ktype, iter.m.flag.ro()|flag(ktype.Kind()), iterkey)
}
// SetIterKey assigns to v the key of iter's current map entry.
// It is equivalent to v.Set(iter.Key()), but it avoids allocating a new Value.
// As in Go, the key must be assignable to v's type and
// must not be derived from an unexported field.
func (v Value) SetIterKey(iter *MapIter) {
if !iter.hiter.initialized() {
panic("reflect: Value.SetIterKey called before Next")
}
iterkey := mapiterkey(&iter.hiter)
if iterkey == nil {
panic("reflect: Value.SetIterKey called on exhausted iterator")
}
v.mustBeAssignable()
var target unsafe.Pointer
if v.kind() == Interface {
target = v.ptr
}
t := (*mapType)(unsafe.Pointer(iter.m.typ))
ktype := t.key
iter.m.mustBeExported() // do not let unexported m leak
key := Value{ktype, iterkey, iter.m.flag | flag(ktype.Kind()) | flagIndir}
key = key.assignTo("reflect.MapIter.SetKey", v.typ, target)
typedmemmove(v.typ, v.ptr, key.ptr)
}
// Value returns the value of iter's current map entry.
func (iter *MapIter) Value() Value {
if !iter.hiter.initialized() {
panic("MapIter.Value called before Next")
}
iterelem := mapiterelem(&iter.hiter)
if iterelem == nil {
panic("MapIter.Value called on exhausted iterator")
}
t := (*mapType)(unsafe.Pointer(iter.m.typ))
vtype := t.elem
return copyVal(vtype, iter.m.flag.ro()|flag(vtype.Kind()), iterelem)
}
// SetIterValue assigns to v the value of iter's current map entry.
// It is equivalent to v.Set(iter.Value()), but it avoids allocating a new Value.
// As in Go, the value must be assignable to v's type and
// must not be derived from an unexported field.
func (v Value) SetIterValue(iter *MapIter) {
if !iter.hiter.initialized() {
panic("reflect: Value.SetIterValue called before Next")
}
iterelem := mapiterelem(&iter.hiter)
if iterelem == nil {
panic("reflect: Value.SetIterValue called on exhausted iterator")
}
v.mustBeAssignable()
var target unsafe.Pointer
if v.kind() == Interface {
target = v.ptr
}
t := (*mapType)(unsafe.Pointer(iter.m.typ))
vtype := t.elem
iter.m.mustBeExported() // do not let unexported m leak
elem := Value{vtype, iterelem, iter.m.flag | flag(vtype.Kind()) | flagIndir}
elem = elem.assignTo("reflect.MapIter.SetValue", v.typ, target)
typedmemmove(v.typ, v.ptr, elem.ptr)
}
// Next advances the map iterator and reports whether there is another
// entry. It returns false when iter is exhausted; subsequent
// calls to Key, Value, or Next will panic.
func (iter *MapIter) Next() bool {
if !iter.m.IsValid() {
panic("MapIter.Next called on an iterator that does not have an associated map Value")
}
if !iter.hiter.initialized() {
mapiterinit(iter.m.typ, iter.m.pointer(), &iter.hiter)
} else {
if mapiterkey(&iter.hiter) == nil {
panic("MapIter.Next called on exhausted iterator")
}
mapiternext(&iter.hiter)
}
return mapiterkey(&iter.hiter) != nil
}
// Reset modifies iter to iterate over v.
// It panics if v's Kind is not Map and v is not the zero Value.
// Reset(Value{}) causes iter not to refer to any map,
// which may allow the previously iterated-over map to be garbage collected.
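// A minimal sketch of iterator reuse (m1 and m2 are illustrative maps):
//
//	var iter reflect.MapIter
//	iter.Reset(reflect.ValueOf(m1))
//	for iter.Next() {
//		// ...
//	}
//	iter.Reset(reflect.ValueOf(m2)) // reuse without allocating a new MapIter
//	iter.Reset(reflect.Value{})     // drop the reference to m2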
func (iter *MapIter) Reset(v Value) {
if v.IsValid() {
v.mustBe(Map)
}
iter.m = v
iter.hiter = hiter{}
}
// MapRange returns a range iterator for a map.
// It panics if v's Kind is not Map.
//
// Call Next to advance the iterator, and Key/Value to access each entry.
// Next returns false when the iterator is exhausted.
// MapRange follows the same iteration semantics as a range statement.
//
// Example:
//
// iter := reflect.ValueOf(m).MapRange()
// for iter.Next() {
// k := iter.Key()
// v := iter.Value()
// ...
// }
func (v Value) MapRange() *MapIter {
// This is inlinable to take advantage of "function outlining".
// The MapIter can be stack allocated if the caller
// does not allow it to escape.
// See https://blog.filippo.io/efficient-go-apis-with-the-inliner/
if v.kind() != Map {
v.panicNotMap()
}
return &MapIter{m: v}
}
func (f flag) panicNotMap() {
f.mustBe(Map)
}
// copyVal returns a Value containing the map key or value at ptr,
// allocating a new variable as needed.
func copyVal(typ *rtype, fl flag, ptr unsafe.Pointer) Value {
if ifaceIndir(typ) {
// Copy result so future changes to the map
// won't change the underlying value.
c := unsafe_New(typ)
typedmemmove(typ, c, ptr)
return Value{typ, c, fl | flagIndir}
}
return Value{typ, *(*unsafe.Pointer)(ptr), fl}
}
// Method returns a function value corresponding to v's i'th method.
// The arguments to a Call on the returned function should not include
// a receiver; the returned function will always use v as the receiver.
// Method panics if i is out of range or if v is a nil interface value.
func (v Value) Method(i int) Value {
if v.typ == nil {
panic(&ValueError{"reflect.Value.Method", Invalid})
}
if v.flag&flagMethod != 0 || uint(i) >= uint(v.typ.NumMethod()) {
panic("reflect: Method index out of range")
}
if v.typ.Kind() == Interface && v.IsNil() {
panic("reflect: Method on nil interface value")
}
fl := v.flag.ro() | (v.flag & flagIndir)
fl |= flag(Func)
fl |= flag(i)<<flagMethodShift | flagMethod
return Value{v.typ, v.ptr, fl}
}
// NumMethod returns the number of methods in the value's method set.
//
// For a non-interface type, it returns the number of exported methods.
//
// For an interface type, it returns the number of exported and unexported methods.
func (v Value) NumMethod() int {
if v.typ == nil {
panic(&ValueError{"reflect.Value.NumMethod", Invalid})
}
if v.flag&flagMethod != 0 {
return 0
}
return v.typ.NumMethod()
}
// MethodByName returns a function value corresponding to the method
// of v with the given name.
// The arguments to a Call on the returned function should not include
// a receiver; the returned function will always use v as the receiver.
// It returns the zero Value if no method was found.
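// A minimal sketch (T and its Greet method are illustrative):
//
//	v := reflect.ValueOf(T{})
//	if m := v.MethodByName("Greet"); m.IsValid() {
//		m.Call(nil)
//	}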
func (v Value) MethodByName(name string) Value {
if v.typ == nil {
panic(&ValueError{"reflect.Value.MethodByName", Invalid})
}
if v.flag&flagMethod != 0 {
return Value{}
}
m, ok := v.typ.MethodByName(name)
if !ok {
return Value{}
}
return v.Method(m.Index)
}
// NumField returns the number of fields in the struct v.
// It panics if v's Kind is not Struct.
func (v Value) NumField() int {
v.mustBe(Struct)
tt := (*structType)(unsafe.Pointer(v.typ))
return len(tt.fields)
}
// OverflowComplex reports whether the complex128 x cannot be represented by v's type.
// It panics if v's Kind is not Complex64 or Complex128.
func (v Value) OverflowComplex(x complex128) bool {
k := v.kind()
switch k {
case Complex64:
return overflowFloat32(real(x)) || overflowFloat32(imag(x))
case Complex128:
return false
}
panic(&ValueError{"reflect.Value.OverflowComplex", v.kind()})
}
// OverflowFloat reports whether the float64 x cannot be represented by v's type.
// It panics if v's Kind is not Float32 or Float64.
func (v Value) OverflowFloat(x float64) bool {
k := v.kind()
switch k {
case Float32:
return overflowFloat32(x)
case Float64:
return false
}
panic(&ValueError{"reflect.Value.OverflowFloat", v.kind()})
}
func overflowFloat32(x float64) bool {
if x < 0 {
x = -x
}
return math.MaxFloat32 < x && x <= math.MaxFloat64
}
// OverflowInt reports whether the int64 x cannot be represented by v's type.
// It panics if v's Kind is not Int, Int8, Int16, Int32, or Int64.
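// For example (a sketch; the values are illustrative):
//
//	v := reflect.ValueOf(int8(0))
//	fmt.Println(v.OverflowInt(127)) // false: 127 fits in an int8
//	fmt.Println(v.OverflowInt(200)) // true: 200 does not fit in an int8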
func (v Value) OverflowInt(x int64) bool {
k := v.kind()
switch k {
case Int, Int8, Int16, Int32, Int64:
bitSize := v.typ.size * 8
trunc := (x << (64 - bitSize)) >> (64 - bitSize)
return x != trunc
}
panic(&ValueError{"reflect.Value.OverflowInt", v.kind()})
}
// OverflowUint reports whether the uint64 x cannot be represented by v's type.
// It panics if v's Kind is not Uint, Uintptr, Uint8, Uint16, Uint32, or Uint64.
func (v Value) OverflowUint(x uint64) bool {
k := v.kind()
switch k {
case Uint, Uintptr, Uint8, Uint16, Uint32, Uint64:
bitSize := v.typ.size * 8
trunc := (x << (64 - bitSize)) >> (64 - bitSize)
return x != trunc
}
panic(&ValueError{"reflect.Value.OverflowUint", v.kind()})
}
//go:nocheckptr
// This prevents inlining Value.Pointer when -d=checkptr is enabled,
// which ensures cmd/compile can recognize unsafe.Pointer(v.Pointer())
// and make an exception.
// Pointer returns v's value as a uintptr.
// It panics if v's Kind is not Chan, Func, Map, Pointer, Slice, or UnsafePointer.
//
// If v's Kind is Func, the returned pointer is an underlying
// code pointer, but not necessarily enough to identify a
// single function uniquely. The only guarantee is that the
// result is zero if and only if v is a nil func Value.
//
// If v's Kind is Slice, the returned pointer is to the first
// element of the slice. If the slice is nil the returned value
// is 0. If the slice is empty but non-nil the return value is non-zero.
//
// It's preferred to use uintptr(Value.UnsafePointer()) to get the equivalent result.
func (v Value) Pointer() uintptr {
k := v.kind()
switch k {
case Pointer:
if v.typ.ptrdata == 0 {
val := *(*uintptr)(v.ptr)
// Since it is a not-in-heap pointer, all pointers to the heap are
// forbidden! See comment in Value.Elem and issue #48399.
if !verifyNotInHeapPtr(val) {
panic("reflect: reflect.Value.Pointer on an invalid notinheap pointer")
}
return val
}
fallthrough
case Chan, Map, UnsafePointer:
return uintptr(v.pointer())
case Func:
if v.flag&flagMethod != 0 {
// As the doc comment says, the returned pointer is an
// underlying code pointer but not necessarily enough to
// identify a single function uniquely. All method expressions
// created via reflect have the same underlying code pointer,
// so their Pointers are equal. The function used here must
// match the one used in makeMethodValue.
return methodValueCallCodePtr()
}
p := v.pointer()
// Non-nil func value points at data block.
// First word of data block is actual code.
if p != nil {
p = *(*unsafe.Pointer)(p)
}
return uintptr(p)
case Slice:
return uintptr((*unsafeheader.Slice)(v.ptr).Data)
}
panic(&ValueError{"reflect.Value.Pointer", v.kind()})
}
// Recv receives and returns a value from the channel v.
// It panics if v's Kind is not Chan.
// The receive blocks until a value is ready.
// The boolean value ok is true if the value x corresponds to a send
// on the channel, false if it is a zero value received because the channel is closed.
func (v Value) Recv() (x Value, ok bool) {
v.mustBe(Chan)
v.mustBeExported()
return v.recv(false)
}
// internal recv, possibly non-blocking (nb).
// v is known to be a channel.
func (v Value) recv(nb bool) (val Value, ok bool) {
tt := (*chanType)(unsafe.Pointer(v.typ))
if ChanDir(tt.dir)&RecvDir == 0 {
panic("reflect: recv on send-only channel")
}
t := tt.elem
val = Value{t, nil, flag(t.Kind())}
var p unsafe.Pointer
if ifaceIndir(t) {
p = unsafe_New(t)
val.ptr = p
val.flag |= flagIndir
} else {
p = unsafe.Pointer(&val.ptr)
}
selected, ok := chanrecv(v.pointer(), nb, p)
if !selected {
val = Value{}
}
return
}
// Send sends x on the channel v.
// It panics if v's kind is not Chan or if x's type is not the same type as v's element type.
// As in Go, x's value must be assignable to the channel's element type.
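// A minimal sketch:
//
//	ch := make(chan int, 1)
//	reflect.ValueOf(ch).Send(reflect.ValueOf(42))
//	fmt.Println(<-ch) // 42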
func (v Value) Send(x Value) {
v.mustBe(Chan)
v.mustBeExported()
v.send(x, false)
}
// internal send, possibly non-blocking.
// v is known to be a channel.
func (v Value) send(x Value, nb bool) (selected bool) {
tt := (*chanType)(unsafe.Pointer(v.typ))
if ChanDir(tt.dir)&SendDir == 0 {
panic("reflect: send on recv-only channel")
}
x.mustBeExported()
x = x.assignTo("reflect.Value.Send", tt.elem, nil)
var p unsafe.Pointer
if x.flag&flagIndir != 0 {
p = x.ptr
} else {
p = unsafe.Pointer(&x.ptr)
}
return chansend(v.pointer(), p, nb)
}
// Set assigns x to the value v.
// It panics if CanSet returns false.
// As in Go, x's value must be assignable to v's type and
// must not be derived from an unexported field.
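// A minimal sketch (the target must be addressable, hence &x):
//
//	x := 1
//	v := reflect.ValueOf(&x).Elem()
//	v.Set(reflect.ValueOf(2)) // x is now 2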
func (v Value) Set(x Value) {
v.mustBeAssignable()
x.mustBeExported() // do not let unexported x leak
var target unsafe.Pointer
if v.kind() == Interface {
target = v.ptr
}
x = x.assignTo("reflect.Set", v.typ, target)
if x.flag&flagIndir != 0 {
if x.ptr == unsafe.Pointer(&zeroVal[0]) {
typedmemclr(v.typ, v.ptr)
} else {
typedmemmove(v.typ, v.ptr, x.ptr)
}
} else {
*(*unsafe.Pointer)(v.ptr) = x.ptr
}
}
// SetBool sets v's underlying value.
// It panics if v's Kind is not Bool or if CanSet() is false.
func (v Value) SetBool(x bool) {
v.mustBeAssignable()
v.mustBe(Bool)
*(*bool)(v.ptr) = x
}
// SetBytes sets v's underlying value.
// It panics if v's underlying value is not a slice of bytes.
func (v Value) SetBytes(x []byte) {
v.mustBeAssignable()
v.mustBe(Slice)
if v.typ.Elem().Kind() != Uint8 {
panic("reflect.Value.SetBytes of non-byte slice")
}
*(*[]byte)(v.ptr) = x
}
// setRunes sets v's underlying value.
// It panics if v's underlying value is not a slice of runes (int32s).
func (v Value) setRunes(x []rune) {
v.mustBeAssignable()
v.mustBe(Slice)
if v.typ.Elem().Kind() != Int32 {
panic("reflect.Value.setRunes of non-rune slice")
}
*(*[]rune)(v.ptr) = x
}
// SetComplex sets v's underlying value to x.
// It panics if v's Kind is not Complex64 or Complex128, or if CanSet() is false.
func (v Value) SetComplex(x complex128) {
v.mustBeAssignable()
switch k := v.kind(); k {
default:
panic(&ValueError{"reflect.Value.SetComplex", v.kind()})
case Complex64:
*(*complex64)(v.ptr) = complex64(x)
case Complex128:
*(*complex128)(v.ptr) = x
}
}
// SetFloat sets v's underlying value to x.
// It panics if v's Kind is not Float32 or Float64, or if CanSet() is false.
func (v Value) SetFloat(x float64) {
v.mustBeAssignable()
switch k := v.kind(); k {
default:
panic(&ValueError{"reflect.Value.SetFloat", v.kind()})
case Float32:
*(*float32)(v.ptr) = float32(x)
case Float64:
*(*float64)(v.ptr) = x
}
}
// SetInt sets v's underlying value to x.
// It panics if v's Kind is not Int, Int8, Int16, Int32, or Int64, or if CanSet() is false.
func (v Value) SetInt(x int64) {
v.mustBeAssignable()
switch k := v.kind(); k {
default:
panic(&ValueError{"reflect.Value.SetInt", v.kind()})
case Int:
*(*int)(v.ptr) = int(x)
case Int8:
*(*int8)(v.ptr) = int8(x)
case Int16:
*(*int16)(v.ptr) = int16(x)
case Int32:
*(*int32)(v.ptr) = int32(x)
case Int64:
*(*int64)(v.ptr) = x
}
}
// SetLen sets v's length to n.
// It panics if v's Kind is not Slice or if n is negative or
// greater than the capacity of the slice.
func (v Value) SetLen(n int) {
v.mustBeAssignable()
v.mustBe(Slice)
s := (*unsafeheader.Slice)(v.ptr)
if uint(n) > uint(s.Cap) {
panic("reflect: slice length out of range in SetLen")
}
s.Len = n
}
// SetCap sets v's capacity to n.
// It panics if v's Kind is not Slice or if n is smaller than the length or
// greater than the capacity of the slice.
func (v Value) SetCap(n int) {
v.mustBeAssignable()
v.mustBe(Slice)
s := (*unsafeheader.Slice)(v.ptr)
if n < s.Len || n > s.Cap {
panic("reflect: slice capacity out of range in SetCap")
}
s.Cap = n
}
// SetMapIndex sets the element associated with key in the map v to elem.
// It panics if v's Kind is not Map.
// If elem is the zero Value, SetMapIndex deletes the key from the map.
// Otherwise if v holds a nil map, SetMapIndex will panic.
// As in Go, key's value must be assignable to the map's key type,
// and elem's value must be assignable to the map's elem type.
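// A minimal sketch:
//
//	m := reflect.ValueOf(map[string]int{"a": 1})
//	m.SetMapIndex(reflect.ValueOf("b"), reflect.ValueOf(2)) // m["b"] = 2
//	m.SetMapIndex(reflect.ValueOf("a"), reflect.Value{})    // delete(m, "a")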
func (v Value) SetMapIndex(key, elem Value) {
v.mustBe(Map)
v.mustBeExported()
key.mustBeExported()
tt := (*mapType)(unsafe.Pointer(v.typ))
if (tt.key == stringType || key.kind() == String) && tt.key == key.typ && tt.elem.size <= maxValSize {
k := *(*string)(key.ptr)
if elem.typ == nil {
mapdelete_faststr(v.typ, v.pointer(), k)
return
}
elem.mustBeExported()
elem = elem.assignTo("reflect.Value.SetMapIndex", tt.elem, nil)
var e unsafe.Pointer
if elem.flag&flagIndir != 0 {
e = elem.ptr
} else {
e = unsafe.Pointer(&elem.ptr)
}
mapassign_faststr(v.typ, v.pointer(), k, e)
return
}
key = key.assignTo("reflect.Value.SetMapIndex", tt.key, nil)
var k unsafe.Pointer
if key.flag&flagIndir != 0 {
k = key.ptr
} else {
k = unsafe.Pointer(&key.ptr)
}
if elem.typ == nil {
mapdelete(v.typ, v.pointer(), k)
return
}
elem.mustBeExported()
elem = elem.assignTo("reflect.Value.SetMapIndex", tt.elem, nil)
var e unsafe.Pointer
if elem.flag&flagIndir != 0 {
e = elem.ptr
} else {
e = unsafe.Pointer(&elem.ptr)
}
mapassign(v.typ, v.pointer(), k, e)
}
// SetUint sets v's underlying value to x.
// It panics if v's Kind is not Uint, Uintptr, Uint8, Uint16, Uint32, or Uint64, or if CanSet() is false.
func (v Value) SetUint(x uint64) {
v.mustBeAssignable()
switch k := v.kind(); k {
default:
panic(&ValueError{"reflect.Value.SetUint", v.kind()})
case Uint:
*(*uint)(v.ptr) = uint(x)
case Uint8:
*(*uint8)(v.ptr) = uint8(x)
case Uint16:
*(*uint16)(v.ptr) = uint16(x)
case Uint32:
*(*uint32)(v.ptr) = uint32(x)
case Uint64:
*(*uint64)(v.ptr) = x
case Uintptr:
*(*uintptr)(v.ptr) = uintptr(x)
}
}
// SetPointer sets the [unsafe.Pointer] value v to x.
// It panics if v's Kind is not UnsafePointer.
func (v Value) SetPointer(x unsafe.Pointer) {
v.mustBeAssignable()
v.mustBe(UnsafePointer)
*(*unsafe.Pointer)(v.ptr) = x
}
// SetString sets v's underlying value to x.
// It panics if v's Kind is not String or if CanSet() is false.
func (v Value) SetString(x string) {
v.mustBeAssignable()
v.mustBe(String)
*(*string)(v.ptr) = x
}
// Slice returns v[i:j].
// It panics if v's Kind is not Array, Slice or String, or if v is an unaddressable array,
// or if the indexes are out of bounds.
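// A minimal sketch:
//
//	v := reflect.ValueOf([]int{1, 2, 3, 4})
//	w := v.Slice(1, 3) // w holds []int{2, 3}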
func (v Value) Slice(i, j int) Value {
var (
cap int
typ *sliceType
base unsafe.Pointer
)
switch kind := v.kind(); kind {
default:
panic(&ValueError{"reflect.Value.Slice", v.kind()})
case Array:
if v.flag&flagAddr == 0 {
panic("reflect.Value.Slice: slice of unaddressable array")
}
tt := (*arrayType)(unsafe.Pointer(v.typ))
cap = int(tt.len)
typ = (*sliceType)(unsafe.Pointer(tt.slice))
base = v.ptr
case Slice:
typ = (*sliceType)(unsafe.Pointer(v.typ))
s := (*unsafeheader.Slice)(v.ptr)
base = s.Data
cap = s.Cap
case String:
s := (*unsafeheader.String)(v.ptr)
if i < 0 || j < i || j > s.Len {
panic("reflect.Value.Slice: string slice index out of bounds")
}
var t unsafeheader.String
if i < s.Len {
t = unsafeheader.String{Data: arrayAt(s.Data, i, 1, "i < s.Len"), Len: j - i}
}
return Value{v.typ, unsafe.Pointer(&t), v.flag}
}
if i < 0 || j < i || j > cap {
panic("reflect.Value.Slice: slice index out of bounds")
}
// Declare slice so that gc can see the base pointer in it.
var x []unsafe.Pointer
// Reinterpret as *unsafeheader.Slice to edit.
s := (*unsafeheader.Slice)(unsafe.Pointer(&x))
s.Len = j - i
s.Cap = cap - i
if cap-i > 0 {
s.Data = arrayAt(base, i, typ.elem.Size(), "i < cap")
} else {
// do not advance pointer, to avoid pointing beyond end of slice
s.Data = base
}
fl := v.flag.ro() | flagIndir | flag(Slice)
return Value{typ.common(), unsafe.Pointer(&x), fl}
}
// Slice3 is the 3-index form of the slice operation: it returns v[i:j:k].
// It panics if v's Kind is not Array or Slice, or if v is an unaddressable array,
// or if the indexes are out of bounds.
func (v Value) Slice3(i, j, k int) Value {
var (
cap int
typ *sliceType
base unsafe.Pointer
)
switch kind := v.kind(); kind {
default:
panic(&ValueError{"reflect.Value.Slice3", v.kind()})
case Array:
if v.flag&flagAddr == 0 {
panic("reflect.Value.Slice3: slice of unaddressable array")
}
tt := (*arrayType)(unsafe.Pointer(v.typ))
cap = int(tt.len)
typ = (*sliceType)(unsafe.Pointer(tt.slice))
base = v.ptr
case Slice:
typ = (*sliceType)(unsafe.Pointer(v.typ))
s := (*unsafeheader.Slice)(v.ptr)
base = s.Data
cap = s.Cap
}
if i < 0 || j < i || k < j || k > cap {
panic("reflect.Value.Slice3: slice index out of bounds")
}
// Declare slice so that the garbage collector
// can see the base pointer in it.
var x []unsafe.Pointer
// Reinterpret as *unsafeheader.Slice to edit.
s := (*unsafeheader.Slice)(unsafe.Pointer(&x))
s.Len = j - i
s.Cap = k - i
if k-i > 0 {
s.Data = arrayAt(base, i, typ.elem.Size(), "i < k <= cap")
} else {
// do not advance pointer, to avoid pointing beyond end of slice
s.Data = base
}
fl := v.flag.ro() | flagIndir | flag(Slice)
return Value{typ.common(), unsafe.Pointer(&x), fl}
}
// String returns v's underlying value, as a string.
// String is a special case because of Go's String method convention.
// Unlike the other getters, it does not panic if v's Kind is not String.
// Instead, it returns a string of the form "<T value>" where T is v's type.
// The fmt package treats Values specially. It does not call their String
// method implicitly but instead prints the concrete values they hold.
func (v Value) String() string {
// stringNonString is split out to keep String inlineable for string kinds.
if v.kind() == String {
return *(*string)(v.ptr)
}
return v.stringNonString()
}
func (v Value) stringNonString() string {
if v.kind() == Invalid {
return "<invalid Value>"
}
// If you call String on a reflect.Value of another type, it's better to
// print something than to panic. Useful in debugging.
return "<" + v.Type().String() + " Value>"
}
// TryRecv attempts to receive a value from the channel v but will not block.
// It panics if v's Kind is not Chan.
// If the receive delivers a value, x is the transferred value and ok is true.
// If the receive cannot finish without blocking, x is the zero Value and ok is false.
// If the channel is closed, x is the zero value for the channel's element type and ok is false.
func (v Value) TryRecv() (x Value, ok bool) {
v.mustBe(Chan)
v.mustBeExported()
return v.recv(true)
}
// TrySend attempts to send x on the channel v but will not block.
// It panics if v's Kind is not Chan.
// It reports whether the value was sent.
// As in Go, x's value must be assignable to the channel's element type.
func (v Value) TrySend(x Value) bool {
v.mustBe(Chan)
v.mustBeExported()
return v.send(x, true)
}
// Type returns v's type.
func (v Value) Type() Type {
if v.flag != 0 && v.flag&flagMethod == 0 {
return v.typ
}
return v.typeSlow()
}
func (v Value) typeSlow() Type {
if v.flag == 0 {
panic(&ValueError{"reflect.Value.Type", Invalid})
}
if v.flag&flagMethod == 0 {
return v.typ
}
// Method value.
// v.typ describes the receiver, not the method type.
i := int(v.flag) >> flagMethodShift
if v.typ.Kind() == Interface {
// Method on interface.
tt := (*interfaceType)(unsafe.Pointer(v.typ))
if uint(i) >= uint(len(tt.methods)) {
panic("reflect: internal error: invalid method index")
}
m := &tt.methods[i]
return v.typ.typeOff(m.typ)
}
// Method on concrete type.
ms := v.typ.exportedMethods()
if uint(i) >= uint(len(ms)) {
panic("reflect: internal error: invalid method index")
}
m := ms[i]
return v.typ.typeOff(m.mtyp)
}
// CanUint reports whether Uint can be used without panicking.
func (v Value) CanUint() bool {
switch v.kind() {
case Uint, Uint8, Uint16, Uint32, Uint64, Uintptr:
return true
default:
return false
}
}
// Uint returns v's underlying value, as a uint64.
// It panics if v's Kind is not Uint, Uintptr, Uint8, Uint16, Uint32, or Uint64.
func (v Value) Uint() uint64 {
k := v.kind()
p := v.ptr
switch k {
case Uint:
return uint64(*(*uint)(p))
case Uint8:
return uint64(*(*uint8)(p))
case Uint16:
return uint64(*(*uint16)(p))
case Uint32:
return uint64(*(*uint32)(p))
case Uint64:
return *(*uint64)(p)
case Uintptr:
return uint64(*(*uintptr)(p))
}
panic(&ValueError{"reflect.Value.Uint", v.kind()})
}
//go:nocheckptr
// This prevents inlining Value.UnsafeAddr when -d=checkptr is enabled,
// which ensures cmd/compile can recognize unsafe.Pointer(v.UnsafeAddr())
// and make an exception.
// UnsafeAddr returns a pointer to v's data, as a uintptr.
// It panics if v is not addressable.
//
// It's preferred to use uintptr(Value.Addr().UnsafePointer()) to get the equivalent result.
func (v Value) UnsafeAddr() uintptr {
if v.typ == nil {
panic(&ValueError{"reflect.Value.UnsafeAddr", Invalid})
}
if v.flag&flagAddr == 0 {
panic("reflect.Value.UnsafeAddr of unaddressable value")
}
return uintptr(v.ptr)
}
// UnsafePointer returns v's value as a [unsafe.Pointer].
// It panics if v's Kind is not Chan, Func, Map, Pointer, Slice, or UnsafePointer.
//
// If v's Kind is Func, the returned pointer is an underlying
// code pointer, but not necessarily enough to identify a
// single function uniquely. The only guarantee is that the
// result is zero if and only if v is a nil func Value.
//
// If v's Kind is Slice, the returned pointer is to the first
// element of the slice. If the slice is nil the returned value
// is nil. If the slice is empty but non-nil the return value is non-nil.
func (v Value) UnsafePointer() unsafe.Pointer {
k := v.kind()
switch k {
case Pointer:
if v.typ.ptrdata == 0 {
// Since it is a not-in-heap pointer, all pointers to the heap are
// forbidden! See comment in Value.Elem and issue #48399.
if !verifyNotInHeapPtr(*(*uintptr)(v.ptr)) {
panic("reflect: reflect.Value.UnsafePointer on an invalid notinheap pointer")
}
return *(*unsafe.Pointer)(v.ptr)
}
fallthrough
case Chan, Map, UnsafePointer:
return v.pointer()
case Func:
if v.flag&flagMethod != 0 {
// As the doc comment says, the returned pointer is an
// underlying code pointer but not necessarily enough to
// identify a single function uniquely. All method expressions
// created via reflect have the same underlying code pointer,
// so their Pointers are equal. The function used here must
// match the one used in makeMethodValue.
code := methodValueCallCodePtr()
return *(*unsafe.Pointer)(unsafe.Pointer(&code))
}
p := v.pointer()
// Non-nil func value points at data block.
// First word of data block is actual code.
if p != nil {
p = *(*unsafe.Pointer)(p)
}
return p
case Slice:
return (*unsafeheader.Slice)(v.ptr).Data
}
panic(&ValueError{"reflect.Value.UnsafePointer", v.kind()})
}
// StringHeader is the runtime representation of a string.
// It cannot be used safely or portably and its representation may
// change in a later release.
// Moreover, the Data field is not sufficient to guarantee the data
// it references will not be garbage collected, so programs must keep
// a separate, correctly typed pointer to the underlying data.
//
// In new code, use unsafe.String or unsafe.StringData instead.
type StringHeader struct {
Data uintptr
Len int
}
// SliceHeader is the runtime representation of a slice.
// It cannot be used safely or portably and its representation may
// change in a later release.
// Moreover, the Data field is not sufficient to guarantee the data
// it references will not be garbage collected, so programs must keep
// a separate, correctly typed pointer to the underlying data.
//
// In new code, use unsafe.Slice or unsafe.SliceData instead.
type SliceHeader struct {
Data uintptr
Len int
Cap int
}
func typesMustMatch(what string, t1, t2 Type) {
if t1 != t2 {
panic(what + ": " + t1.String() + " != " + t2.String())
}
}
// arrayAt returns the i-th element of p,
// an array whose elements are eltSize bytes wide.
// The array pointed at by p must have at least i+1 elements:
// it is invalid (but impossible to check here) to pass i >= len,
// because then the result will point outside the array.
// whySafe must explain why i < len. (Passing "i < len" is fine;
// the benefit is to surface this assumption at the call site.)
func arrayAt(p unsafe.Pointer, i int, eltSize uintptr, whySafe string) unsafe.Pointer {
return add(p, uintptr(i)*eltSize, "i < len")
}
// Grow increases the slice's capacity, if necessary, to guarantee space for
// another n elements. After Grow(n), at least n elements can be appended
// to the slice without another allocation.
//
// It panics if v's Kind is not Slice or if n is negative or too large to
// allocate the memory.
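// A minimal sketch (s must be addressable, hence the pointer):
//
//	s := make([]int, 0)
//	v := reflect.ValueOf(&s).Elem()
//	v.Grow(10) // cap(s) is now at least 10; len(s) is unchanged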
func (v Value) Grow(n int) {
v.mustBeAssignable()
v.mustBe(Slice)
v.grow(n)
}
// grow is identical to Grow but does not check for assignability.
func (v Value) grow(n int) {
p := (*unsafeheader.Slice)(v.ptr)
switch {
case n < 0:
panic("reflect.Value.Grow: negative len")
case p.Len+n < 0:
panic("reflect.Value.Grow: slice overflow")
case p.Len+n > p.Cap:
t := v.typ.Elem().(*rtype)
*p = growslice(t, *p, n)
}
}
// extendSlice extends a slice by n elements.
//
// Unlike Value.grow, which modifies the slice in place and
// leaves its length unchanged, extendSlice returns a new slice
// value with the length incremented by the specified number of
// elements.
func (v Value) extendSlice(n int) Value {
v.mustBeExported()
v.mustBe(Slice)
// Shallow copy the slice header to avoid mutating the source slice.
sh := *(*unsafeheader.Slice)(v.ptr)
s := &sh
v.ptr = unsafe.Pointer(s)
v.flag = flagIndir | flag(Slice) // equivalent flag to MakeSlice
v.grow(n) // fine to treat as assignable since we allocate a new slice header
s.Len += n
return v
}
// Clear clears the contents of a map or zeros the contents of a slice.
//
// It panics if v's Kind is not Map or Slice.
func (v Value) Clear() {
switch v.Kind() {
case Slice:
sh := *(*unsafeheader.Slice)(v.ptr)
st := (*sliceType)(unsafe.Pointer(v.typ))
typedarrayclear(st.elem, sh.Data, sh.Len)
case Map:
mapclear(v.typ, v.pointer())
default:
panic(&ValueError{"reflect.Value.Clear", v.Kind()})
}
}
// Append appends the values x to a slice s and returns the resulting slice.
// As in Go, each x's value must be assignable to the slice's element type.
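// A minimal sketch:
//
//	s := reflect.ValueOf([]int{1, 2})
//	s = reflect.Append(s, reflect.ValueOf(3)) // s holds []int{1, 2, 3}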
func Append(s Value, x ...Value) Value {
s.mustBe(Slice)
n := s.Len()
s = s.extendSlice(len(x))
for i, v := range x {
s.Index(n + i).Set(v)
}
return s
}
// AppendSlice appends a slice t to a slice s and returns the resulting slice.
// The slices s and t must have the same element type.
func AppendSlice(s, t Value) Value {
s.mustBe(Slice)
t.mustBe(Slice)
typesMustMatch("reflect.AppendSlice", s.Type().Elem(), t.Type().Elem())
ns := s.Len()
nt := t.Len()
s = s.extendSlice(nt)
Copy(s.Slice(ns, ns+nt), t)
return s
}
// Copy copies the contents of src into dst until either
// dst has been filled or src has been exhausted.
// It returns the number of elements copied.
// Dst and src each must have kind Slice or Array, and
// dst and src must have the same element type.
//
// As a special case, src can have kind String if the element type of dst is kind Uint8.
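// A minimal sketch, including the string special case:
//
//	dst := make([]byte, 3)
//	n := reflect.Copy(reflect.ValueOf(dst), reflect.ValueOf("abcd"))
//	// n == 3 and dst holds "abc"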
func Copy(dst, src Value) int {
dk := dst.kind()
if dk != Array && dk != Slice {
panic(&ValueError{"reflect.Copy", dk})
}
if dk == Array {
dst.mustBeAssignable()
}
dst.mustBeExported()
sk := src.kind()
var stringCopy bool
if sk != Array && sk != Slice {
stringCopy = sk == String && dst.typ.Elem().Kind() == Uint8
if !stringCopy {
panic(&ValueError{"reflect.Copy", sk})
}
}
src.mustBeExported()
de := dst.typ.Elem()
if !stringCopy {
se := src.typ.Elem()
typesMustMatch("reflect.Copy", de, se)
}
var ds, ss unsafeheader.Slice
if dk == Array {
ds.Data = dst.ptr
ds.Len = dst.Len()
ds.Cap = ds.Len
} else {
ds = *(*unsafeheader.Slice)(dst.ptr)
}
if sk == Array {
ss.Data = src.ptr
ss.Len = src.Len()
ss.Cap = ss.Len
} else if sk == Slice {
ss = *(*unsafeheader.Slice)(src.ptr)
} else {
sh := *(*unsafeheader.String)(src.ptr)
ss.Data = sh.Data
ss.Len = sh.Len
ss.Cap = sh.Len
}
return typedslicecopy(de.common(), ds, ss)
}
// A runtimeSelect is a single case passed to rselect.
// This must match ../runtime/select.go:/runtimeSelect
type runtimeSelect struct {
dir SelectDir // SelectSend, SelectRecv or SelectDefault
typ *rtype // channel type
ch unsafe.Pointer // channel
val unsafe.Pointer // ptr to data (SendDir) or ptr to receive buffer (RecvDir)
}
// rselect runs a select. It returns the index of the chosen case.
// If the case was a receive, val is filled in with the received value.
// The conventional OK bool indicates whether the receive corresponds
// to a sent value.
//
//go:noescape
func rselect([]runtimeSelect) (chosen int, recvOK bool)
// A SelectDir describes the communication direction of a select case.
type SelectDir int
// NOTE: These values must match ../runtime/select.go:/selectDir.
const (
_ SelectDir = iota
SelectSend // case Chan <- Send
SelectRecv // case <-Chan:
SelectDefault // default
)
// A SelectCase describes a single case in a select operation.
// The kind of case depends on Dir, the communication direction.
//
// If Dir is SelectDefault, the case represents a default case.
// Chan and Send must be zero Values.
//
// If Dir is SelectSend, the case represents a send operation.
// Normally Chan's underlying value must be a channel, and Send's underlying value must be
// assignable to the channel's element type. As a special case, if Chan is a zero Value,
// then the case is ignored, and the field Send will also be ignored and may be either zero
// or non-zero.
//
// If Dir is SelectRecv, the case represents a receive operation.
// Normally Chan's underlying value must be a channel and Send must be a zero Value.
// If Chan is a zero Value, then the case is ignored, but Send must still be a zero Value.
// When a receive operation is selected, the received Value is returned by Select.
type SelectCase struct {
Dir SelectDir // direction of case
Chan Value // channel to use (for send or receive)
Send Value // value to send (for send)
}
// Select executes a select operation described by the list of cases.
// Like the Go select statement, it blocks until at least one of the cases
// can proceed, makes a uniform pseudo-random choice,
// and then executes that case. It returns the index of the chosen case
// and, if that case was a receive operation, the value received and a
// boolean indicating whether the value corresponds to a send on the channel
// (as opposed to a zero value received because the channel is closed).
// Select supports a maximum of 65536 cases.
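// A minimal sketch (ch1 and ch2 are illustrative channels):
//
//	cases := []reflect.SelectCase{
//		{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch1)},
//		{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch2)},
//	}
//	chosen, recv, ok := reflect.Select(cases)
//	// chosen is the index of the ready case; recv and ok are
//	// meaningful here because both cases are receives.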
func Select(cases []SelectCase) (chosen int, recv Value, recvOK bool) {
if len(cases) > 65536 {
panic("reflect.Select: too many cases (max 65536)")
}
// NOTE: Do not trust that caller is not modifying cases data underfoot.
// The range is safe because the caller cannot modify our copy of the len
// and each iteration makes its own copy of the value c.
var runcases []runtimeSelect
if len(cases) > 4 {
// Slice is heap allocated due to runtime dependent capacity.
runcases = make([]runtimeSelect, len(cases))
} else {
// Slice can be stack allocated due to constant capacity.
runcases = make([]runtimeSelect, len(cases), 4)
}
haveDefault := false
for i, c := range cases {
rc := &runcases[i]
rc.dir = c.Dir
switch c.Dir {
default:
panic("reflect.Select: invalid Dir")
case SelectDefault: // default
if haveDefault {
panic("reflect.Select: multiple default cases")
}
haveDefault = true
if c.Chan.IsValid() {
panic("reflect.Select: default case has Chan value")
}
if c.Send.IsValid() {
panic("reflect.Select: default case has Send value")
}
case SelectSend:
ch := c.Chan
if !ch.IsValid() {
break
}
ch.mustBe(Chan)
ch.mustBeExported()
tt := (*chanType)(unsafe.Pointer(ch.typ))
if ChanDir(tt.dir)&SendDir == 0 {
panic("reflect.Select: SendDir case using recv-only channel")
}
rc.ch = ch.pointer()
rc.typ = &tt.rtype
v := c.Send
if !v.IsValid() {
panic("reflect.Select: SendDir case missing Send value")
}
v.mustBeExported()
v = v.assignTo("reflect.Select", tt.elem, nil)
if v.flag&flagIndir != 0 {
rc.val = v.ptr
} else {
rc.val = unsafe.Pointer(&v.ptr)
}
case SelectRecv:
if c.Send.IsValid() {
panic("reflect.Select: RecvDir case has Send value")
}
ch := c.Chan
if !ch.IsValid() {
break
}
ch.mustBe(Chan)
ch.mustBeExported()
tt := (*chanType)(unsafe.Pointer(ch.typ))
if ChanDir(tt.dir)&RecvDir == 0 {
panic("reflect.Select: RecvDir case using send-only channel")
}
rc.ch = ch.pointer()
rc.typ = &tt.rtype
rc.val = unsafe_New(tt.elem)
}
}
chosen, recvOK = rselect(runcases)
if runcases[chosen].dir == SelectRecv {
tt := (*chanType)(unsafe.Pointer(runcases[chosen].typ))
t := tt.elem
p := runcases[chosen].val
fl := flag(t.Kind())
if ifaceIndir(t) {
recv = Value{t, p, fl | flagIndir}
} else {
recv = Value{t, *(*unsafe.Pointer)(p), fl}
}
}
return chosen, recv, recvOK
}
/*
* constructors
*/
// implemented in package runtime
func unsafe_New(*rtype) unsafe.Pointer
func unsafe_NewArray(*rtype, int) unsafe.Pointer
// MakeSlice creates a new zero-initialized slice value
// for the specified slice type, length, and capacity.
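// A minimal sketch:
//
//	s := reflect.MakeSlice(reflect.TypeOf([]int(nil)), 3, 10)
//	// s.Len() == 3, s.Cap() == 10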
func MakeSlice(typ Type, len, cap int) Value {
if typ.Kind() != Slice {
panic("reflect.MakeSlice of non-slice type")
}
if len < 0 {
panic("reflect.MakeSlice: negative len")
}
if cap < 0 {
panic("reflect.MakeSlice: negative cap")
}
if len > cap {
panic("reflect.MakeSlice: len > cap")
}
s := unsafeheader.Slice{Data: unsafe_NewArray(typ.Elem().(*rtype), cap), Len: len, Cap: cap}
return Value{typ.(*rtype), unsafe.Pointer(&s), flagIndir | flag(Slice)}
}
// MakeChan creates a new channel with the specified type and buffer size.
func MakeChan(typ Type, buffer int) Value {
if typ.Kind() != Chan {
panic("reflect.MakeChan of non-chan type")
}
if buffer < 0 {
panic("reflect.MakeChan: negative buffer size")
}
if typ.ChanDir() != BothDir {
panic("reflect.MakeChan: unidirectional channel type")
}
t := typ.(*rtype)
ch := makechan(t, buffer)
return Value{t, ch, flag(Chan)}
}
// MakeMap creates a new map with the specified type.
func MakeMap(typ Type) Value {
return MakeMapWithSize(typ, 0)
}
// MakeMapWithSize creates a new map with the specified type
// and initial space for approximately n elements.
func MakeMapWithSize(typ Type, n int) Value {
if typ.Kind() != Map {
panic("reflect.MakeMapWithSize of non-map type")
}
t := typ.(*rtype)
m := makemap(t, n)
return Value{t, m, flag(Map)}
}
// Indirect returns the value that v points to.
// If v is a nil pointer, Indirect returns a zero Value.
// If v is not a pointer, Indirect returns v.
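// A minimal sketch:
//
//	x := 7
//	v := reflect.Indirect(reflect.ValueOf(&x))
//	// v.Int() == 7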
func Indirect(v Value) Value {
if v.Kind() != Pointer {
return v
}
return v.Elem()
}
// ValueOf returns a new Value initialized to the concrete value
// stored in the interface i. ValueOf(nil) returns the zero Value.
func ValueOf(i any) Value {
if i == nil {
return Value{}
}
// TODO: Maybe allow contents of a Value to live on the stack.
// For now we make the contents always escape to the heap. It
// makes life easier in a few places (see chanrecv/mapassign
// comment below).
escapes(i)
return unpackEface(i)
}
// Zero returns a Value representing the zero value for the specified type.
// The result is different from the zero value of the Value struct,
// which represents no value at all.
// For example, Zero(TypeOf(42)) returns a Value with Kind Int and value 0.
// The returned value is neither addressable nor settable.
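// A minimal sketch:
//
//	z := reflect.Zero(reflect.TypeOf(42))
//	// z.Int() == 0, and z.CanSet() == false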
func Zero(typ Type) Value {
if typ == nil {
panic("reflect: Zero(nil)")
}
t := typ.(*rtype)
fl := flag(t.Kind())
if ifaceIndir(t) {
var p unsafe.Pointer
if t.size <= maxZero {
p = unsafe.Pointer(&zeroVal[0])
} else {
p = unsafe_New(t)
}
return Value{t, p, fl | flagIndir}
}
return Value{t, nil, fl}
}
// must match declarations in runtime/map.go.
const maxZero = 1024
//go:linkname zeroVal runtime.zeroVal
var zeroVal [maxZero]byte
// New returns a Value representing a pointer to a new zero value
// for the specified type. That is, the returned Value's Type is PointerTo(typ).
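// A minimal sketch:
//
//	p := reflect.New(reflect.TypeOf(0)) // p.Type() is *int
//	p.Elem().SetInt(9)                  // the pointed-to int is now 9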
func New(typ Type) Value {
if typ == nil {
panic("reflect: New(nil)")
}
t := typ.(*rtype)
pt := t.ptrTo()
if ifaceIndir(pt) {
// This is a pointer to a not-in-heap type.
panic("reflect: New of type that may not be allocated in heap (possibly undefined cgo C type)")
}
ptr := unsafe_New(t)
fl := flag(Pointer)
return Value{pt, ptr, fl}
}
// NewAt returns a Value representing a pointer to a value of the
// specified type, using p as that pointer.
func NewAt(typ Type, p unsafe.Pointer) Value {
fl := flag(Pointer)
t := typ.(*rtype)
return Value{t.ptrTo(), p, fl}
}
// assignTo returns a value v that can be assigned directly to dst.
// It panics if v is not assignable to dst.
// For a conversion to an interface type, target, if not nil,
// is a suggested scratch space to use.
// target must be initialized memory (or nil).
func (v Value) assignTo(context string, dst *rtype, target unsafe.Pointer) Value {
if v.flag&flagMethod != 0 {
v = makeMethodValue(context, v)
}
switch {
case directlyAssignable(dst, v.typ):
// Overwrite type so that they match.
// Same memory layout, so no harm done.
fl := v.flag&(flagAddr|flagIndir) | v.flag.ro()
fl |= flag(dst.Kind())
return Value{dst, v.ptr, fl}
case implements(dst, v.typ):
if v.Kind() == Interface && v.IsNil() {
// A nil ReadWriter passed to nil Reader is OK,
// but using ifaceE2I below will panic.
// Avoid the panic by returning a nil dst (e.g., Reader) explicitly.
return Value{dst, nil, flag(Interface)}
}
x := valueInterface(v, false)
if target == nil {
target = unsafe_New(dst)
}
if dst.NumMethod() == 0 {
*(*any)(target) = x
} else {
ifaceE2I(dst, x, target)
}
return Value{dst, target, flagIndir | flag(Interface)}
}
// Failed.
panic(context + ": value of type " + v.typ.String() + " is not assignable to type " + dst.String())
}
// Convert returns the value v converted to type t.
// If the usual Go conversion rules do not allow conversion
// of the value v to type t, or if converting v to type t panics, Convert panics.
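// A minimal sketch:
//
//	v := reflect.ValueOf(3).Convert(reflect.TypeOf(float64(0)))
//	// v.Float() == 3.0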
func (v Value) Convert(t Type) Value {
if v.flag&flagMethod != 0 {
v = makeMethodValue("Convert", v)
}
op := convertOp(t.common(), v.typ)
if op == nil {
panic("reflect.Value.Convert: value of type " + v.typ.String() + " cannot be converted to type " + t.String())
}
return op(v, t)
}
// CanConvert reports whether the value v can be converted to type t.
// If v.CanConvert(t) returns true then v.Convert(t) will not panic.
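// For example (a sketch of the slice-to-array case):
//
//	v := reflect.ValueOf([]int{1, 2, 3})
//	fmt.Println(v.CanConvert(reflect.TypeOf([4]int{}))) // false: array longer than slice
//	fmt.Println(v.CanConvert(reflect.TypeOf([3]int{}))) // true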
func (v Value) CanConvert(t Type) bool {
vt := v.Type()
if !vt.ConvertibleTo(t) {
return false
}
// Converting from slice to array or to pointer-to-array can panic
// depending on the value.
switch {
case vt.Kind() == Slice && t.Kind() == Array:
if t.Len() > v.Len() {
return false
}
case vt.Kind() == Slice && t.Kind() == Pointer && t.Elem().Kind() == Array:
n := t.Elem().Len()
if n > v.Len() {
return false
}
}
return true
}
// Comparable reports whether the value v is comparable.
// If the type of v is an interface, this checks the dynamic type.
// If this reports true then v.Interface() == x will not panic for any x,
// nor will v.Equal(u) for any Value u.
func (v Value) Comparable() bool {
k := v.Kind()
switch k {
case Invalid:
return false
case Array:
switch v.Type().Elem().Kind() {
case Interface, Array, Struct:
for i := 0; i < v.Type().Len(); i++ {
if !v.Index(i).Comparable() {
return false
}
}
return true
}
return v.Type().Comparable()
case Interface:
return v.Elem().Comparable()
case Struct:
for i := 0; i < v.NumField(); i++ {
if !v.Field(i).Comparable() {
return false
}
}
return true
default:
return v.Type().Comparable()
}
}
// Equal reports true if v is equal to u.
// For two invalid values, Equal will report true.
// For an interface value, Equal will compare the value within the interface.
// Otherwise, if the values have different types, Equal will report false.
// Otherwise, for arrays and structs Equal will compare each element in order,
// and report false if it finds non-equal elements.
// During all comparisons, if values of the same type are compared,
// and the type is not comparable, Equal will panic.
func (v Value) Equal(u Value) bool {
if v.Kind() == Interface {
v = v.Elem()
}
if u.Kind() == Interface {
u = u.Elem()
}
if !v.IsValid() || !u.IsValid() {
return v.IsValid() == u.IsValid()
}
if v.Kind() != u.Kind() || v.Type() != u.Type() {
return false
}
// Handle each Kind directly rather than calling valueInterface
// to avoid allocating.
switch v.Kind() {
default:
panic("reflect.Value.Equal: invalid Kind")
case Bool:
return v.Bool() == u.Bool()
case Int, Int8, Int16, Int32, Int64:
return v.Int() == u.Int()
case Uint, Uint8, Uint16, Uint32, Uint64, Uintptr:
return v.Uint() == u.Uint()
case Float32, Float64:
return v.Float() == u.Float()
case Complex64, Complex128:
return v.Complex() == u.Complex()
case String:
return v.String() == u.String()
case Chan, Pointer, UnsafePointer:
return v.Pointer() == u.Pointer()
case Array:
// u and v have the same type so they have the same length
vl := v.Len()
if vl == 0 {
// panic on [0]func()
if !v.Type().Elem().Comparable() {
break
}
return true
}
for i := 0; i < vl; i++ {
if !v.Index(i).Equal(u.Index(i)) {
return false
}
}
return true
case Struct:
// u and v have the same type so they have the same fields
nf := v.NumField()
for i := 0; i < nf; i++ {
if !v.Field(i).Equal(u.Field(i)) {
return false
}
}
return true
case Func, Map, Slice:
break
}
panic("reflect.Value.Equal: values of type " + v.Type().String() + " are not comparable")
}
// convertOp returns the function to convert a value of type src
// to a value of type dst. If the conversion is illegal, convertOp returns nil.
func convertOp(dst, src *rtype) func(Value, Type) Value {
switch src.Kind() {
case Int, Int8, Int16, Int32, Int64:
switch dst.Kind() {
case Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Uintptr:
return cvtInt
case Float32, Float64:
return cvtIntFloat
case String:
return cvtIntString
}
case Uint, Uint8, Uint16, Uint32, Uint64, Uintptr:
switch dst.Kind() {
case Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, Uint64, Uintptr:
return cvtUint
case Float32, Float64:
return cvtUintFloat
case String:
return cvtUintString
}
case Float32, Float64:
switch dst.Kind() {
case Int, Int8, Int16, Int32, Int64:
return cvtFloatInt
case Uint, Uint8, Uint16, Uint32, Uint64, Uintptr:
return cvtFloatUint
case Float32, Float64:
return cvtFloat
}
case Complex64, Complex128:
switch dst.Kind() {
case Complex64, Complex128:
return cvtComplex
}
case String:
if dst.Kind() == Slice && dst.Elem().PkgPath() == "" {
switch dst.Elem().Kind() {
case Uint8:
return cvtStringBytes
case Int32:
return cvtStringRunes
}
}
case Slice:
if dst.Kind() == String && src.Elem().PkgPath() == "" {
switch src.Elem().Kind() {
case Uint8:
return cvtBytesString
case Int32:
return cvtRunesString
}
}
// "x is a slice, T is a pointer-to-array type,
// and the slice and array types have identical element types."
if dst.Kind() == Pointer && dst.Elem().Kind() == Array && src.Elem() == dst.Elem().Elem() {
return cvtSliceArrayPtr
}
// "x is a slice, T is a array type,
// and the slice and array types have identical element types."
if dst.Kind() == Array && src.Elem() == dst.Elem() {
return cvtSliceArray
}
case Chan:
if dst.Kind() == Chan && specialChannelAssignability(dst, src) {
return cvtDirect
}
}
// dst and src have same underlying type.
if haveIdenticalUnderlyingType(dst, src, false) {
return cvtDirect
}
// dst and src are non-defined pointer types with same underlying base type.
if dst.Kind() == Pointer && dst.Name() == "" &&
src.Kind() == Pointer && src.Name() == "" &&
haveIdenticalUnderlyingType(dst.Elem().common(), src.Elem().common(), false) {
return cvtDirect
}
if implements(dst, src) {
if src.Kind() == Interface {
return cvtI2I
}
return cvtT2I
}
return nil
}
// makeInt returns a Value of type t equal to bits (possibly truncated),
// where t is a signed or unsigned int type.
func makeInt(f flag, bits uint64, t Type) Value {
typ := t.common()
ptr := unsafe_New(typ)
switch typ.size {
case 1:
*(*uint8)(ptr) = uint8(bits)
case 2:
*(*uint16)(ptr) = uint16(bits)
case 4:
*(*uint32)(ptr) = uint32(bits)
case 8:
*(*uint64)(ptr) = bits
}
return Value{typ, ptr, f | flagIndir | flag(typ.Kind())}
}
// makeFloat returns a Value of type t equal to v (possibly truncated to float32),
// where t is a float32 or float64 type.
func makeFloat(f flag, v float64, t Type) Value {
typ := t.common()
ptr := unsafe_New(typ)
switch typ.size {
case 4:
*(*float32)(ptr) = float32(v)
case 8:
*(*float64)(ptr) = v
}
return Value{typ, ptr, f | flagIndir | flag(typ.Kind())}
}
// makeFloat32 returns a Value of type t equal to v, where t is a float32 type.
func makeFloat32(f flag, v float32, t Type) Value {
typ := t.common()
ptr := unsafe_New(typ)
*(*float32)(ptr) = v
return Value{typ, ptr, f | flagIndir | flag(typ.Kind())}
}
// makeComplex returns a Value of type t equal to v (possibly truncated to complex64),
// where t is a complex64 or complex128 type.
func makeComplex(f flag, v complex128, t Type) Value {
typ := t.common()
ptr := unsafe_New(typ)
switch typ.size {
case 8:
*(*complex64)(ptr) = complex64(v)
case 16:
*(*complex128)(ptr) = v
}
return Value{typ, ptr, f | flagIndir | flag(typ.Kind())}
}
func makeString(f flag, v string, t Type) Value {
ret := New(t).Elem()
ret.SetString(v)
ret.flag = ret.flag&^flagAddr | f
return ret
}
func makeBytes(f flag, v []byte, t Type) Value {
ret := New(t).Elem()
ret.SetBytes(v)
ret.flag = ret.flag&^flagAddr | f
return ret
}
func makeRunes(f flag, v []rune, t Type) Value {
ret := New(t).Elem()
ret.setRunes(v)
ret.flag = ret.flag&^flagAddr | f
return ret
}
// These conversion functions are returned by convertOp
// for classes of conversions. For example, the first function, cvtInt,
// takes any value v of signed int type and returns the value converted
// to type t, where t is any signed or unsigned int type.
// convertOp: intXX -> [u]intXX
func cvtInt(v Value, t Type) Value {
return makeInt(v.flag.ro(), uint64(v.Int()), t)
}
// convertOp: uintXX -> [u]intXX
func cvtUint(v Value, t Type) Value {
return makeInt(v.flag.ro(), v.Uint(), t)
}
// convertOp: floatXX -> intXX
func cvtFloatInt(v Value, t Type) Value {
return makeInt(v.flag.ro(), uint64(int64(v.Float())), t)
}
// convertOp: floatXX -> uintXX
func cvtFloatUint(v Value, t Type) Value {
return makeInt(v.flag.ro(), uint64(v.Float()), t)
}
// convertOp: intXX -> floatXX
func cvtIntFloat(v Value, t Type) Value {
return makeFloat(v.flag.ro(), float64(v.Int()), t)
}
// convertOp: uintXX -> floatXX
func cvtUintFloat(v Value, t Type) Value {
return makeFloat(v.flag.ro(), float64(v.Uint()), t)
}
// convertOp: floatXX -> floatXX
func cvtFloat(v Value, t Type) Value {
if v.Type().Kind() == Float32 && t.Kind() == Float32 {
// Don't do any conversion if both types have underlying type float32.
// This avoids converting to float64 and back, which will
// convert a signaling NaN to a quiet NaN. See issue 36400.
return makeFloat32(v.flag.ro(), *(*float32)(v.ptr), t)
}
return makeFloat(v.flag.ro(), v.Float(), t)
}
// convertOp: complexXX -> complexXX
func cvtComplex(v Value, t Type) Value {
return makeComplex(v.flag.ro(), v.Complex(), t)
}
// convertOp: intXX -> string
func cvtIntString(v Value, t Type) Value {
s := "\uFFFD"
if x := v.Int(); int64(rune(x)) == x {
s = string(rune(x))
}
return makeString(v.flag.ro(), s, t)
}
// convertOp: uintXX -> string
func cvtUintString(v Value, t Type) Value {
s := "\uFFFD"
if x := v.Uint(); uint64(rune(x)) == x {
s = string(rune(x))
}
return makeString(v.flag.ro(), s, t)
}
// convertOp: []byte -> string
func cvtBytesString(v Value, t Type) Value {
return makeString(v.flag.ro(), string(v.Bytes()), t)
}
// convertOp: string -> []byte
func cvtStringBytes(v Value, t Type) Value {
return makeBytes(v.flag.ro(), []byte(v.String()), t)
}
// convertOp: []rune -> string
func cvtRunesString(v Value, t Type) Value {
return makeString(v.flag.ro(), string(v.runes()), t)
}
// convertOp: string -> []rune
func cvtStringRunes(v Value, t Type) Value {
return makeRunes(v.flag.ro(), []rune(v.String()), t)
}
// convertOp: []T -> *[N]T
func cvtSliceArrayPtr(v Value, t Type) Value {
n := t.Elem().Len()
if n > v.Len() {
panic("reflect: cannot convert slice with length " + itoa.Itoa(v.Len()) + " to pointer to array with length " + itoa.Itoa(n))
}
h := (*unsafeheader.Slice)(v.ptr)
return Value{t.common(), h.Data, v.flag&^(flagIndir|flagAddr|flagKindMask) | flag(Pointer)}
}
// convertOp: []T -> [N]T
func cvtSliceArray(v Value, t Type) Value {
n := t.Len()
if n > v.Len() {
panic("reflect: cannot convert slice with length " + itoa.Itoa(v.Len()) + " to array with length " + itoa.Itoa(n))
}
h := (*unsafeheader.Slice)(v.ptr)
typ := t.common()
ptr := h.Data
c := unsafe_New(typ)
typedmemmove(typ, c, ptr)
ptr = c
return Value{typ, ptr, v.flag&^(flagAddr|flagKindMask) | flag(Array)}
}
// convertOp: direct copy
func cvtDirect(v Value, typ Type) Value {
f := v.flag
t := typ.common()
ptr := v.ptr
if f&flagAddr != 0 {
// indirect, mutable word - make a copy
c := unsafe_New(t)
typedmemmove(t, c, ptr)
ptr = c
f &^= flagAddr
}
return Value{t, ptr, v.flag.ro() | f} // v.flag.ro()|f == f?
}
// convertOp: concrete -> interface
func cvtT2I(v Value, typ Type) Value {
target := unsafe_New(typ.common())
x := valueInterface(v, false)
if typ.NumMethod() == 0 {
*(*any)(target) = x
} else {
ifaceE2I(typ.(*rtype), x, target)
}
return Value{typ.common(), target, v.flag.ro() | flagIndir | flag(Interface)}
}
// convertOp: interface -> interface
func cvtI2I(v Value, typ Type) Value {
if v.IsNil() {
ret := Zero(typ)
ret.flag |= v.flag.ro()
return ret
}
return cvtT2I(v.Elem(), typ)
}
// implemented in ../runtime
func chancap(ch unsafe.Pointer) int
func chanclose(ch unsafe.Pointer)
func chanlen(ch unsafe.Pointer) int
// Note: some of the noescape annotations below are technically a lie,
// but safe in the context of this package. Functions like chansend
// and mapassign don't escape the referent, but may escape anything
// the referent points to (they do shallow copies of the referent).
// It is safe in this package because the referent may only point
// to something a Value may point to, and that is always in the heap
// (due to the escapes() call in ValueOf).
//go:noescape
func chanrecv(ch unsafe.Pointer, nb bool, val unsafe.Pointer) (selected, received bool)
//go:noescape
func chansend(ch unsafe.Pointer, val unsafe.Pointer, nb bool) bool
func makechan(typ *rtype, size int) (ch unsafe.Pointer)
func makemap(t *rtype, cap int) (m unsafe.Pointer)
//go:noescape
func mapaccess(t *rtype, m unsafe.Pointer, key unsafe.Pointer) (val unsafe.Pointer)
//go:noescape
func mapaccess_faststr(t *rtype, m unsafe.Pointer, key string) (val unsafe.Pointer)
//go:noescape
func mapassign(t *rtype, m unsafe.Pointer, key, val unsafe.Pointer)
//go:noescape
func mapassign_faststr(t *rtype, m unsafe.Pointer, key string, val unsafe.Pointer)
//go:noescape
func mapdelete(t *rtype, m unsafe.Pointer, key unsafe.Pointer)
//go:noescape
func mapdelete_faststr(t *rtype, m unsafe.Pointer, key string)
//go:noescape
func mapiterinit(t *rtype, m unsafe.Pointer, it *hiter)
//go:noescape
func mapiterkey(it *hiter) (key unsafe.Pointer)
//go:noescape
func mapiterelem(it *hiter) (elem unsafe.Pointer)
//go:noescape
func mapiternext(it *hiter)
//go:noescape
func maplen(m unsafe.Pointer) int
func mapclear(t *rtype, m unsafe.Pointer)
// call calls fn with "stackArgsSize" bytes of stack arguments laid out
// at stackArgs and register arguments laid out in regArgs. frameSize is
// the total amount of stack space that will be reserved by call, so this
// should include enough space to spill register arguments to the stack in
// case of preemption.
//
// After fn returns, call copies stackArgsSize-stackRetOffset result bytes
// back into stackArgs+stackRetOffset before returning, for any return
// values passed on the stack. Register-based return values will be found
// in the same regArgs structure.
//
// regArgs must also be prepared with an appropriate ReturnIsPtr bitmap
// indicating which registers will contain pointer-valued return values. The
// purpose of this bitmap is to keep pointers visible to the GC between
// returning from reflectcall and actually using them.
//
// If copying result bytes back from the stack, the caller must pass the
// argument frame type as stackArgsType, so that call can execute appropriate
// write barriers during the copy.
//
// Arguments passed through to call do not escape. The type is used only in a
// very limited callee of call, the stackArgs are copied, and regArgs is only
// used in the call frame.
//
//go:noescape
//go:linkname call runtime.reflectcall
func call(stackArgsType *rtype, f, stackArgs unsafe.Pointer, stackArgsSize, stackRetOffset, frameSize uint32, regArgs *abi.RegArgs)
func ifaceE2I(t *rtype, src any, dst unsafe.Pointer)
// memmove copies size bytes to dst from src. No write barriers are used.
//
//go:noescape
func memmove(dst, src unsafe.Pointer, size uintptr)
// typedmemmove copies a value of type t to dst from src.
//
//go:noescape
func typedmemmove(t *rtype, dst, src unsafe.Pointer)
// typedmemmovepartial is like typedmemmove but assumes that
// dst and src point off bytes into the value and only copies size bytes.
//
//go:noescape
func typedmemmovepartial(t *rtype, dst, src unsafe.Pointer, off, size uintptr)
// typedmemclr zeros the value at ptr of type t.
//
//go:noescape
func typedmemclr(t *rtype, ptr unsafe.Pointer)
// typedmemclrpartial is like typedmemclr but assumes that
// dst points off bytes into the value and only clears size bytes.
//
//go:noescape
func typedmemclrpartial(t *rtype, ptr unsafe.Pointer, off, size uintptr)
// typedslicecopy copies a slice of elemType values from src to dst,
// returning the number of elements copied.
//
//go:noescape
func typedslicecopy(elemType *rtype, dst, src unsafeheader.Slice) int
// typedarrayclear zeroes the first len elements of the array of
// elemType starting at ptr.
//
//go:noescape
func typedarrayclear(elemType *rtype, ptr unsafe.Pointer, len int)
//go:noescape
func typehash(t *rtype, p unsafe.Pointer, h uintptr) uintptr
func verifyNotInHeapPtr(p uintptr) bool
//go:noescape
func growslice(t *rtype, old unsafeheader.Slice, num int) unsafeheader.Slice
// Dummy annotation marking that the value x escapes,
// for use in cases where the reflect code is so clever that
// the compiler cannot follow.
func escapes(x any) {
if dummy.b {
dummy.x = x
}
}
var dummy struct {
b bool
x any
}
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package reflect
// VisibleFields returns all the visible fields in t, which must be a
// struct type. A field is defined as visible if it's accessible
// directly with a FieldByName call. The returned fields include fields
// inside anonymous struct members and unexported fields. They follow
// the same order found in the struct, with anonymous fields followed
// immediately by their promoted fields.
//
// For each element e of the returned slice, the corresponding field
// can be retrieved from a value v of type t by calling v.FieldByIndex(e.Index).
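// A minimal sketch (Base and User are illustrative types):
//
//	type Base struct{ ID int }
//	type User struct {
//		Base
//		Name string
//	}
//	for _, f := range reflect.VisibleFields(reflect.TypeOf(User{})) {
//		fmt.Println(f.Name) // Base, ID, Name
//	}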
func VisibleFields(t Type) []StructField {
if t == nil {
panic("reflect: VisibleFields(nil)")
}
if t.Kind() != Struct {
panic("reflect.VisibleFields of non-struct type")
}
w := &visibleFieldsWalker{
byName: make(map[string]int),
visiting: make(map[Type]bool),
fields: make([]StructField, 0, t.NumField()),
index: make([]int, 0, 2),
}
w.walk(t)
// Remove all the fields that have been hidden.
// Use an in-place removal that avoids copying in
// the common case that there are no hidden fields.
j := 0
for i := range w.fields {
f := &w.fields[i]
if f.Name == "" {
continue
}
if i != j {
// A field has been removed. We need to shuffle
// all the subsequent elements up.
w.fields[j] = *f
}
j++
}
return w.fields[:j]
}
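// A minimal usage sketch (using only the exported API): for a struct with
// an embedded struct, the embedded field is followed immediately by its
// promoted fields, and each Index can be fed back to FieldByIndex.
//
//	type Inner struct{ A, B int }
//	type Outer struct {
//		Inner
//		C string
//	}
//	for _, f := range reflect.VisibleFields(reflect.TypeOf(Outer{})) {
//		fmt.Println(f.Name, f.Index)
//	}
//	// Prints, one per line: Inner [0], A [0 0], B [0 1], C [1]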
type visibleFieldsWalker struct {
byName map[string]int
visiting map[Type]bool
fields []StructField
index []int
}
// walk walks all the fields in the struct type t, visiting
// fields in index preorder and appending them to w.fields
// (this maintains the required ordering).
// Fields that have been overridden have their
// Name field cleared.
func (w *visibleFieldsWalker) walk(t Type) {
if w.visiting[t] {
return
}
w.visiting[t] = true
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
w.index = append(w.index, i)
add := true
if oldIndex, ok := w.byName[f.Name]; ok {
old := &w.fields[oldIndex]
if len(w.index) == len(old.Index) {
// Fields with the same name at the same depth
// cancel one another out. Set the field name
// to empty to signify that has happened, and
// there's no need to add this field.
old.Name = ""
add = false
} else if len(w.index) < len(old.Index) {
// The old field loses because it's deeper than the new one.
old.Name = ""
} else {
// The old field wins because it's shallower than the new one.
add = false
}
}
if add {
// Copy the index so that it's not overwritten
// by the other appends.
f.Index = append([]int(nil), w.index...)
w.byName[f.Name] = len(w.fields)
w.fields = append(w.fields, f)
}
if f.Anonymous {
if f.Type.Kind() == Pointer {
f.Type = f.Type.Elem()
}
if f.Type.Kind() == Struct {
w.walk(f.Type)
}
}
w.index = w.index[:len(w.index)-1]
}
delete(w.visiting, t)
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// backtrack is a regular expression search with submatch
// tracking for small regular expressions and texts. It allocates
// a bit vector with (length of input) * (length of prog) bits,
// to make sure it never explores the same (character position, instruction)
// state multiple times. This limits the search to run in time linear in
// the length of the input.
//
// backtrack is a fast replacement for the NFA code on small
// regexps when onepass cannot be used.
package regexp
import (
"regexp/syntax"
"sync"
)
// A job is an entry on the backtracker's job stack. It holds
// the instruction pc and the position in the input.
type job struct {
pc uint32
arg bool
pos int
}
const (
visitedBits = 32
maxBacktrackProg = 500 // len(prog.Inst) <= max
maxBacktrackVector = 256 * 1024 // bit vector size <= max (bits)
)
// bitState holds state for the backtracker.
type bitState struct {
end int
cap []int
matchcap []int
jobs []job
visited []uint32
inputs inputs
}
var bitStatePool sync.Pool
func newBitState() *bitState {
b, ok := bitStatePool.Get().(*bitState)
if !ok {
b = new(bitState)
}
return b
}
func freeBitState(b *bitState) {
b.inputs.clear()
bitStatePool.Put(b)
}
// maxBitStateLen returns the maximum length of a string to search with
// the backtracker using prog.
func maxBitStateLen(prog *syntax.Prog) int {
if !shouldBacktrack(prog) {
return 0
}
return maxBacktrackVector / len(prog.Inst)
}
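// Worked example of the budget above: with maxBacktrackVector = 256*1024
// bits, a program at the maxBacktrackProg limit of 500 instructions can
// backtrack over inputs of at most 262144/500 = 524 bytes, while a small
// 10-instruction program can handle inputs of up to 26214 bytes.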
// shouldBacktrack reports whether the program is small enough
// for the backtracker to run.
func shouldBacktrack(prog *syntax.Prog) bool {
return len(prog.Inst) <= maxBacktrackProg
}
// reset resets the state of the backtracker.
// end is the end position in the input.
// ncap is the number of captures.
func (b *bitState) reset(prog *syntax.Prog, end int, ncap int) {
b.end = end
if cap(b.jobs) == 0 {
b.jobs = make([]job, 0, 256)
} else {
b.jobs = b.jobs[:0]
}
visitedSize := (len(prog.Inst)*(end+1) + visitedBits - 1) / visitedBits
if cap(b.visited) < visitedSize {
b.visited = make([]uint32, visitedSize, maxBacktrackVector/visitedBits)
} else {
b.visited = b.visited[:visitedSize]
for i := range b.visited {
b.visited[i] = 0
}
}
if cap(b.cap) < ncap {
b.cap = make([]int, ncap)
} else {
b.cap = b.cap[:ncap]
}
for i := range b.cap {
b.cap[i] = -1
}
if cap(b.matchcap) < ncap {
b.matchcap = make([]int, ncap)
} else {
b.matchcap = b.matchcap[:ncap]
}
for i := range b.matchcap {
b.matchcap[i] = -1
}
}
// shouldVisit reports whether the combination of (pc, pos) has not
// been visited yet.
func (b *bitState) shouldVisit(pc uint32, pos int) bool {
n := uint(int(pc)*(b.end+1) + pos)
if b.visited[n/visitedBits]&(1<<(n&(visitedBits-1))) != 0 {
return false
}
b.visited[n/visitedBits] |= 1 << (n & (visitedBits - 1))
return true
}
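// A worked sketch of the indexing above: with a 4-instruction prog and
// end = 9, the pair (pc=2, pos=5) maps to bit n = 2*(9+1)+5 = 25, which
// lives in word n/32 = 0 at bit n&31 = 25 of b.visited.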
// push pushes (pc, pos, arg) onto the job stack if it should be
// visited.
func (b *bitState) push(re *Regexp, pc uint32, pos int, arg bool) {
// Only check shouldVisit when arg is false.
// When arg is true, we are continuing a previous visit.
if re.prog.Inst[pc].Op != syntax.InstFail && (arg || b.shouldVisit(pc, pos)) {
b.jobs = append(b.jobs, job{pc: pc, arg: arg, pos: pos})
}
}
// tryBacktrack runs a backtracking search starting at pos.
func (re *Regexp) tryBacktrack(b *bitState, i input, pc uint32, pos int) bool {
longest := re.longest
b.push(re, pc, pos, false)
for len(b.jobs) > 0 {
l := len(b.jobs) - 1
// Pop job off the stack.
pc := b.jobs[l].pc
pos := b.jobs[l].pos
arg := b.jobs[l].arg
b.jobs = b.jobs[:l]
// Optimization: rather than push and pop,
// code that is going to Push and continue
// the loop simply updates ip, p, and arg
// and jumps to CheckAndLoop. We have to
// do the ShouldVisit check that Push
// would have, but we avoid the stack
// manipulation.
goto Skip
CheckAndLoop:
if !b.shouldVisit(pc, pos) {
continue
}
Skip:
inst := &re.prog.Inst[pc]
switch inst.Op {
default:
panic("bad inst")
case syntax.InstFail:
panic("unexpected InstFail")
case syntax.InstAlt:
// Cannot just
// b.push(inst.Out, pos, false)
// b.push(inst.Arg, pos, false)
// If during the processing of inst.Out, we encounter
// inst.Arg via another path, we want to process it then.
// Pushing it here will inhibit that. Instead, re-push
// inst with arg==true as a reminder to push inst.Arg out
// later.
if arg {
// Finished inst.Out; try inst.Arg.
arg = false
pc = inst.Arg
goto CheckAndLoop
} else {
b.push(re, pc, pos, true)
pc = inst.Out
goto CheckAndLoop
}
case syntax.InstAltMatch:
// One opcode consumes runes; the other leads to match.
switch re.prog.Inst[inst.Out].Op {
case syntax.InstRune, syntax.InstRune1, syntax.InstRuneAny, syntax.InstRuneAnyNotNL:
// inst.Arg is the match.
b.push(re, inst.Arg, pos, false)
pc = inst.Arg
pos = b.end
goto CheckAndLoop
}
// inst.Out is the match - non-greedy
b.push(re, inst.Out, b.end, false)
pc = inst.Out
pos = b.end
goto CheckAndLoop
case syntax.InstRune:
r, width := i.step(pos)
if !inst.MatchRune(r) {
continue
}
pos += width
pc = inst.Out
goto CheckAndLoop
case syntax.InstRune1:
r, width := i.step(pos)
if r != inst.Rune[0] {
continue
}
pos += width
pc = inst.Out
goto CheckAndLoop
case syntax.InstRuneAnyNotNL:
r, width := i.step(pos)
if r == '\n' || r == endOfText {
continue
}
pos += width
pc = inst.Out
goto CheckAndLoop
case syntax.InstRuneAny:
r, width := i.step(pos)
if r == endOfText {
continue
}
pos += width
pc = inst.Out
goto CheckAndLoop
case syntax.InstCapture:
if arg {
// Finished inst.Out; restore the old value.
b.cap[inst.Arg] = pos
continue
} else {
if inst.Arg < uint32(len(b.cap)) {
// Capture pos to register, but save old value.
b.push(re, pc, b.cap[inst.Arg], true) // come back when we're done.
b.cap[inst.Arg] = pos
}
pc = inst.Out
goto CheckAndLoop
}
case syntax.InstEmptyWidth:
flag := i.context(pos)
if !flag.match(syntax.EmptyOp(inst.Arg)) {
continue
}
pc = inst.Out
goto CheckAndLoop
case syntax.InstNop:
pc = inst.Out
goto CheckAndLoop
case syntax.InstMatch:
// We found a match. If the caller doesn't care
// where the match is, no point going further.
if len(b.cap) == 0 {
return true
}
// Record best match so far.
// Only need to check end point, because this entire
// call is only considering one start position.
if len(b.cap) > 1 {
b.cap[1] = pos
}
if old := b.matchcap[1]; old == -1 || (longest && pos > 0 && pos > old) {
copy(b.matchcap, b.cap)
}
// If going for first match, we're done.
if !longest {
return true
}
// If we used the entire text, a longer match is not possible.
if pos == b.end {
return true
}
// Otherwise, continue on in hope of a longer match.
continue
}
}
return longest && len(b.matchcap) > 1 && b.matchcap[1] >= 0
}
// backtrack runs a backtracking search of prog on the input starting at pos.
func (re *Regexp) backtrack(ib []byte, is string, pos int, ncap int, dstCap []int) []int {
startCond := re.cond
if startCond == ^syntax.EmptyOp(0) { // impossible
return nil
}
if startCond&syntax.EmptyBeginText != 0 && pos != 0 {
// Anchored match, past beginning of text.
return nil
}
b := newBitState()
i, end := b.inputs.init(nil, ib, is)
b.reset(re.prog, end, ncap)
// Anchored search must start at the beginning of the input
if startCond&syntax.EmptyBeginText != 0 {
if len(b.cap) > 0 {
b.cap[0] = pos
}
if !re.tryBacktrack(b, i, uint32(re.prog.Start), pos) {
freeBitState(b)
return nil
}
} else {
// Unanchored search, starting from each possible text position.
// Notice that we have to try the empty string at the end of
// the text, so the loop condition is pos <= end, not pos < end.
// This looks like it's quadratic in the size of the text,
// but we are not clearing visited between calls to TrySearch,
// so no work is duplicated and it ends up still being linear.
width := -1
for ; pos <= end && width != 0; pos += width {
if len(re.prefix) > 0 {
// Match requires literal prefix; fast search for it.
advance := i.index(re, pos)
if advance < 0 {
freeBitState(b)
return nil
}
pos += advance
}
if len(b.cap) > 0 {
b.cap[0] = pos
}
if re.tryBacktrack(b, i, uint32(re.prog.Start), pos) {
// Match must be leftmost; done.
goto Match
}
_, width = i.step(pos)
}
freeBitState(b)
return nil
}
Match:
dstCap = append(dstCap, b.matchcap...)
freeBitState(b)
return dstCap
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package regexp
import (
"io"
"regexp/syntax"
"sync"
)
// A queue is a 'sparse array' holding pending threads of execution.
// See https://research.swtch.com/2008/03/using-uninitialized-memory-for-fun-and.html
type queue struct {
sparse []uint32
dense []entry
}
// An entry is an entry on a queue.
// It holds both the instruction pc and the actual thread.
// Some queue entries are just place holders so that the machine
// knows it has considered that pc. Such entries have t == nil.
type entry struct {
pc uint32
t *thread
}
// A thread is the state of a single path through the machine:
// an instruction and a corresponding capture array.
// See https://swtch.com/~rsc/regexp/regexp2.html
type thread struct {
inst *syntax.Inst
cap []int
}
// A machine holds all the state during an NFA simulation for p.
type machine struct {
re *Regexp // corresponding Regexp
p *syntax.Prog // compiled program
q0, q1 queue // two queues for runq, nextq
pool []*thread // pool of available threads
matched bool // whether a match was found
matchcap []int // capture information for the match
inputs inputs
}
type inputs struct {
// cached inputs, to avoid allocation
bytes inputBytes
string inputString
reader inputReader
}
func (i *inputs) newBytes(b []byte) input {
i.bytes.str = b
return &i.bytes
}
func (i *inputs) newString(s string) input {
i.string.str = s
return &i.string
}
func (i *inputs) newReader(r io.RuneReader) input {
i.reader.r = r
i.reader.atEOT = false
i.reader.pos = 0
return &i.reader
}
func (i *inputs) clear() {
// We need to clear 1 of these.
// Avoid the expense of clearing the others (pointer write barrier).
if i.bytes.str != nil {
i.bytes.str = nil
} else if i.reader.r != nil {
i.reader.r = nil
} else {
i.string.str = ""
}
}
func (i *inputs) init(r io.RuneReader, b []byte, s string) (input, int) {
if r != nil {
return i.newReader(r), 0
}
if b != nil {
return i.newBytes(b), len(b)
}
return i.newString(s), len(s)
}
func (m *machine) init(ncap int) {
for _, t := range m.pool {
t.cap = t.cap[:ncap]
}
m.matchcap = m.matchcap[:ncap]
}
// alloc allocates a new thread with the given instruction.
// It uses the free pool if possible.
func (m *machine) alloc(i *syntax.Inst) *thread {
var t *thread
if n := len(m.pool); n > 0 {
t = m.pool[n-1]
m.pool = m.pool[:n-1]
} else {
t = new(thread)
t.cap = make([]int, len(m.matchcap), cap(m.matchcap))
}
t.inst = i
return t
}
// A lazyFlag is a lazily-evaluated syntax.EmptyOp,
// for checking zero-width flags like ^ $ \A \z \B \b.
// It records the pair of relevant runes and does not
// determine the implied flags until absolutely necessary
// (most of the time, that means never).
type lazyFlag uint64
func newLazyFlag(r1, r2 rune) lazyFlag {
return lazyFlag(uint64(r1)<<32 | uint64(uint32(r2)))
}
func (f lazyFlag) match(op syntax.EmptyOp) bool {
if op == 0 {
return true
}
r1 := rune(f >> 32)
if op&syntax.EmptyBeginLine != 0 {
if r1 != '\n' && r1 >= 0 {
return false
}
op &^= syntax.EmptyBeginLine
}
if op&syntax.EmptyBeginText != 0 {
if r1 >= 0 {
return false
}
op &^= syntax.EmptyBeginText
}
if op == 0 {
return true
}
r2 := rune(f)
if op&syntax.EmptyEndLine != 0 {
if r2 != '\n' && r2 >= 0 {
return false
}
op &^= syntax.EmptyEndLine
}
if op&syntax.EmptyEndText != 0 {
if r2 >= 0 {
return false
}
op &^= syntax.EmptyEndText
}
if op == 0 {
return true
}
if syntax.IsWordChar(r1) != syntax.IsWordChar(r2) {
op &^= syntax.EmptyWordBoundary
} else {
op &^= syntax.EmptyNoWordBoundary
}
return op == 0
}
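// A worked sketch: at the space in "ab cd" (pos == 2), context yields
// newLazyFlag('b', ' '). Since IsWordChar('b') != IsWordChar(' '), the
// EmptyWordBoundary bit is cleared and \b matches there, while \B
// (EmptyNoWordBoundary) does not.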
// match runs the machine over the input starting at pos.
// It reports whether a match was found.
// If so, m.matchcap holds the submatch information.
func (m *machine) match(i input, pos int) bool {
startCond := m.re.cond
if startCond == ^syntax.EmptyOp(0) { // impossible
return false
}
m.matched = false
for i := range m.matchcap {
m.matchcap[i] = -1
}
runq, nextq := &m.q0, &m.q1
r, r1 := endOfText, endOfText
width, width1 := 0, 0
r, width = i.step(pos)
if r != endOfText {
r1, width1 = i.step(pos + width)
}
var flag lazyFlag
if pos == 0 {
flag = newLazyFlag(-1, r)
} else {
flag = i.context(pos)
}
for {
if len(runq.dense) == 0 {
if startCond&syntax.EmptyBeginText != 0 && pos != 0 {
// Anchored match, past beginning of text.
break
}
if m.matched {
// Have match; finished exploring alternatives.
break
}
if len(m.re.prefix) > 0 && r1 != m.re.prefixRune && i.canCheckPrefix() {
// Match requires literal prefix; fast search for it.
advance := i.index(m.re, pos)
if advance < 0 {
break
}
pos += advance
r, width = i.step(pos)
r1, width1 = i.step(pos + width)
}
}
if !m.matched {
if len(m.matchcap) > 0 {
m.matchcap[0] = pos
}
m.add(runq, uint32(m.p.Start), pos, m.matchcap, &flag, nil)
}
flag = newLazyFlag(r, r1)
m.step(runq, nextq, pos, pos+width, r, &flag)
if width == 0 {
break
}
if len(m.matchcap) == 0 && m.matched {
// Found a match and not paying attention
// to where it is, so any match will do.
break
}
pos += width
r, width = r1, width1
if r != endOfText {
r1, width1 = i.step(pos + width)
}
runq, nextq = nextq, runq
}
m.clear(nextq)
return m.matched
}
// clear frees all threads on the thread queue.
func (m *machine) clear(q *queue) {
for _, d := range q.dense {
if d.t != nil {
m.pool = append(m.pool, d.t)
}
}
q.dense = q.dense[:0]
}
// step executes one step of the machine, running each of the threads
// on runq and appending new threads to nextq.
// The step processes the rune c (which may be endOfText),
// which starts at position pos and ends at nextPos.
// nextCond gives the setting for the empty-width flags after c.
func (m *machine) step(runq, nextq *queue, pos, nextPos int, c rune, nextCond *lazyFlag) {
longest := m.re.longest
for j := 0; j < len(runq.dense); j++ {
d := &runq.dense[j]
t := d.t
if t == nil {
continue
}
if longest && m.matched && len(t.cap) > 0 && m.matchcap[0] < t.cap[0] {
m.pool = append(m.pool, t)
continue
}
i := t.inst
add := false
switch i.Op {
default:
panic("bad inst")
case syntax.InstMatch:
if len(t.cap) > 0 && (!longest || !m.matched || m.matchcap[1] < pos) {
t.cap[1] = pos
copy(m.matchcap, t.cap)
}
if !longest {
// First-match mode: cut off all lower-priority threads.
for _, d := range runq.dense[j+1:] {
if d.t != nil {
m.pool = append(m.pool, d.t)
}
}
runq.dense = runq.dense[:0]
}
m.matched = true
case syntax.InstRune:
add = i.MatchRune(c)
case syntax.InstRune1:
add = c == i.Rune[0]
case syntax.InstRuneAny:
add = true
case syntax.InstRuneAnyNotNL:
add = c != '\n'
}
if add {
t = m.add(nextq, i.Out, nextPos, t.cap, nextCond, t)
}
if t != nil {
m.pool = append(m.pool, t)
}
}
runq.dense = runq.dense[:0]
}
// add adds an entry to q for pc, unless the q already has such an entry.
// It also recursively adds an entry for all instructions reachable from pc by following
// empty-width conditions satisfied by cond. pos gives the current position
// in the input.
func (m *machine) add(q *queue, pc uint32, pos int, cap []int, cond *lazyFlag, t *thread) *thread {
Again:
if pc == 0 {
return t
}
if j := q.sparse[pc]; j < uint32(len(q.dense)) && q.dense[j].pc == pc {
return t
}
j := len(q.dense)
q.dense = q.dense[:j+1]
d := &q.dense[j]
d.t = nil
d.pc = pc
q.sparse[pc] = uint32(j)
i := &m.p.Inst[pc]
switch i.Op {
default:
panic("unhandled")
case syntax.InstFail:
// nothing
case syntax.InstAlt, syntax.InstAltMatch:
t = m.add(q, i.Out, pos, cap, cond, t)
pc = i.Arg
goto Again
case syntax.InstEmptyWidth:
if cond.match(syntax.EmptyOp(i.Arg)) {
pc = i.Out
goto Again
}
case syntax.InstNop:
pc = i.Out
goto Again
case syntax.InstCapture:
if int(i.Arg) < len(cap) {
opos := cap[i.Arg]
cap[i.Arg] = pos
m.add(q, i.Out, pos, cap, cond, nil)
cap[i.Arg] = opos
} else {
pc = i.Out
goto Again
}
case syntax.InstMatch, syntax.InstRune, syntax.InstRune1, syntax.InstRuneAny, syntax.InstRuneAnyNotNL:
if t == nil {
t = m.alloc(i)
} else {
t.inst = i
}
if len(cap) > 0 && &t.cap[0] != &cap[0] {
copy(t.cap, cap)
}
d.t = t
t = nil
}
return t
}
type onePassMachine struct {
inputs inputs
matchcap []int
}
var onePassPool sync.Pool
func newOnePassMachine() *onePassMachine {
m, ok := onePassPool.Get().(*onePassMachine)
if !ok {
m = new(onePassMachine)
}
return m
}
func freeOnePassMachine(m *onePassMachine) {
m.inputs.clear()
onePassPool.Put(m)
}
// doOnePass implements re.doExecute using the one-pass execution engine.
func (re *Regexp) doOnePass(ir io.RuneReader, ib []byte, is string, pos, ncap int, dstCap []int) []int {
startCond := re.cond
if startCond == ^syntax.EmptyOp(0) { // impossible
return nil
}
m := newOnePassMachine()
if cap(m.matchcap) < ncap {
m.matchcap = make([]int, ncap)
} else {
m.matchcap = m.matchcap[:ncap]
}
matched := false
for i := range m.matchcap {
m.matchcap[i] = -1
}
i, _ := m.inputs.init(ir, ib, is)
r, r1 := endOfText, endOfText
width, width1 := 0, 0
r, width = i.step(pos)
if r != endOfText {
r1, width1 = i.step(pos + width)
}
var flag lazyFlag
if pos == 0 {
flag = newLazyFlag(-1, r)
} else {
flag = i.context(pos)
}
pc := re.onepass.Start
inst := &re.onepass.Inst[pc]
// If there is a simple literal prefix, skip over it.
if pos == 0 && flag.match(syntax.EmptyOp(inst.Arg)) &&
len(re.prefix) > 0 && i.canCheckPrefix() {
// Match requires literal prefix; fast search for it.
if !i.hasPrefix(re) {
goto Return
}
pos += len(re.prefix)
r, width = i.step(pos)
r1, width1 = i.step(pos + width)
flag = i.context(pos)
pc = int(re.prefixEnd)
}
for {
inst = &re.onepass.Inst[pc]
pc = int(inst.Out)
switch inst.Op {
default:
panic("bad inst")
case syntax.InstMatch:
matched = true
if len(m.matchcap) > 0 {
m.matchcap[0] = 0
m.matchcap[1] = pos
}
goto Return
case syntax.InstRune:
if !inst.MatchRune(r) {
goto Return
}
case syntax.InstRune1:
if r != inst.Rune[0] {
goto Return
}
case syntax.InstRuneAny:
// Nothing
case syntax.InstRuneAnyNotNL:
if r == '\n' {
goto Return
}
// peek at the input rune to see which branch of the Alt to take
case syntax.InstAlt, syntax.InstAltMatch:
pc = int(onePassNext(inst, r))
continue
case syntax.InstFail:
goto Return
case syntax.InstNop:
continue
case syntax.InstEmptyWidth:
if !flag.match(syntax.EmptyOp(inst.Arg)) {
goto Return
}
continue
case syntax.InstCapture:
if int(inst.Arg) < len(m.matchcap) {
m.matchcap[inst.Arg] = pos
}
continue
}
if width == 0 {
break
}
flag = newLazyFlag(r, r1)
pos += width
r, width = r1, width1
if r != endOfText {
r1, width1 = i.step(pos + width)
}
}
Return:
if !matched {
freeOnePassMachine(m)
return nil
}
dstCap = append(dstCap, m.matchcap...)
freeOnePassMachine(m)
return dstCap
}
// doMatch reports whether r, b, or s matches the regexp.
func (re *Regexp) doMatch(r io.RuneReader, b []byte, s string) bool {
return re.doExecute(r, b, s, 0, 0, nil) != nil
}
// doExecute finds the leftmost match in the input, appends the position
// of its subexpressions to dstCap and returns dstCap.
//
// nil is returned if no matches are found and non-nil if matches are found.
func (re *Regexp) doExecute(r io.RuneReader, b []byte, s string, pos int, ncap int, dstCap []int) []int {
if dstCap == nil {
// Make sure 'return dstCap' is non-nil.
dstCap = arrayNoInts[:0:0]
}
if r == nil && len(b)+len(s) < re.minInputLen {
return nil
}
if re.onepass != nil {
return re.doOnePass(r, b, s, pos, ncap, dstCap)
}
if r == nil && len(b)+len(s) < re.maxBitStateLen {
return re.backtrack(b, s, pos, ncap, dstCap)
}
m := re.get()
i, _ := m.inputs.init(r, b, s)
m.init(ncap)
if !m.match(i, pos) {
re.put(m)
return nil
}
dstCap = append(dstCap, m.matchcap...)
re.put(m)
return dstCap
}
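// Engine selection in doExecute, sketched: a one-pass program
// (re.onepass != nil) always uses doOnePass; otherwise byte-slice or
// string input shorter than re.maxBitStateLen uses the backtracker; all
// remaining cases, including any RuneReader input, use the pooled NFA
// machine.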
// arrayNoInts is returned by doExecute when a nil dstCap is passed
// to it with ncap=0.
var arrayNoInts [0]int
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package regexp
import (
"regexp/syntax"
"sort"
"strings"
"unicode"
"unicode/utf8"
)
// "One-pass" regexp execution.
// Some regexps can be analyzed to determine that they never need
// backtracking: they are guaranteed to run in one pass over the string
// without bothering to save all the usual NFA state.
// Detect those and execute them more quickly.
// A onePassProg is a compiled one-pass regular expression program.
// It is the same as syntax.Prog except for the use of onePassInst.
type onePassProg struct {
Inst []onePassInst
Start int // index of start instruction
NumCap int // number of InstCapture insts in re
}
// A onePassInst is a single instruction in a one-pass regular expression program.
// It is the same as syntax.Inst except for the new 'Next' field.
type onePassInst struct {
syntax.Inst
Next []uint32
}
// onePassPrefix returns a literal string that all matches for the
// regexp must start with. Complete is true if the prefix
// is the entire match. Pc is the index of the last rune instruction
// in the string. The onePassPrefix skips over the mandatory
// EmptyBeginText.
func onePassPrefix(p *syntax.Prog) (prefix string, complete bool, pc uint32) {
i := &p.Inst[p.Start]
if i.Op != syntax.InstEmptyWidth || (syntax.EmptyOp(i.Arg))&syntax.EmptyBeginText == 0 {
return "", i.Op == syntax.InstMatch, uint32(p.Start)
}
pc = i.Out
i = &p.Inst[pc]
for i.Op == syntax.InstNop {
pc = i.Out
i = &p.Inst[pc]
}
// Avoid allocation of buffer if prefix is empty.
if iop(i) != syntax.InstRune || len(i.Rune) != 1 {
return "", i.Op == syntax.InstMatch, uint32(p.Start)
}
// Have prefix; gather characters.
var buf strings.Builder
for iop(i) == syntax.InstRune && len(i.Rune) == 1 && syntax.Flags(i.Arg)&syntax.FoldCase == 0 && i.Rune[0] != utf8.RuneError {
buf.WriteRune(i.Rune[0])
pc, i = i.Out, &p.Inst[i.Out]
}
if i.Op == syntax.InstEmptyWidth &&
syntax.EmptyOp(i.Arg)&syntax.EmptyEndText != 0 &&
p.Inst[i.Out].Op == syntax.InstMatch {
complete = true
}
return buf.String(), complete, pc
}
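// Worked examples: for `^done$` the loop above gathers the prefix "done",
// and the trailing EmptyEndText before InstMatch sets complete; for
// `^go+d$` the walk stops at the alternation compiled for o+, returning
// prefix "go" with complete false.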
// onePassNext selects the next actionable state of the prog, based on the input character.
// It should only be called when i.Op == InstAlt or InstAltMatch, and from the one-pass machine.
// One of the alternates may ultimately lead without input to end of line. If the instruction
// is InstAltMatch the path to the InstMatch is in i.Out, the normal node in i.Next.
func onePassNext(i *onePassInst, r rune) uint32 {
next := i.MatchRunePos(r)
if next >= 0 {
return i.Next[next]
}
if i.Op == syntax.InstAltMatch {
return i.Out
}
return 0
}
func iop(i *syntax.Inst) syntax.InstOp {
op := i.Op
switch op {
case syntax.InstRune1, syntax.InstRuneAny, syntax.InstRuneAnyNotNL:
op = syntax.InstRune
}
return op
}
// Sparse Array implementation is used as a queueOnePass.
type queueOnePass struct {
sparse []uint32
dense []uint32
size, nextIndex uint32
}
func (q *queueOnePass) empty() bool {
return q.nextIndex >= q.size
}
func (q *queueOnePass) next() (n uint32) {
n = q.dense[q.nextIndex]
q.nextIndex++
return
}
func (q *queueOnePass) clear() {
q.size = 0
q.nextIndex = 0
}
func (q *queueOnePass) contains(u uint32) bool {
if u >= uint32(len(q.sparse)) {
return false
}
return q.sparse[u] < q.size && q.dense[q.sparse[u]] == u
}
func (q *queueOnePass) insert(u uint32) {
if !q.contains(u) {
q.insertNew(u)
}
}
func (q *queueOnePass) insertNew(u uint32) {
if u >= uint32(len(q.sparse)) {
return
}
q.sparse[u] = q.size
q.dense[q.size] = u
q.size++
}
func newQueue(size int) (q *queueOnePass) {
return &queueOnePass{
sparse: make([]uint32, size),
dense: make([]uint32, size),
}
}
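// A usage sketch of the sparse-set trick: membership tests need no
// initialized memory, because a stale q.sparse entry cannot point at a
// dense slot that points back to it.
//
//	q := newQueue(8)
//	q.insert(3)   // sparse[3] = 0, dense[0] = 3
//	q.insert(5)   // sparse[5] = 1, dense[1] = 5
//	q.contains(3) // true: dense[sparse[3]] == 3
//	q.contains(4) // false: dense[sparse[4]] == 3 != 4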
// mergeRuneSets merges two non-intersecting runesets, and returns the merged result,
// and a NextIp array. The idea is that if a rune matches the OnePassRunes at index
// i, NextIp[i/2] is the target. If the input sets intersect, an empty runeset and a
// NextIp array with the single element mergeFailed is returned.
// The code assumes that both inputs contain ordered and non-intersecting rune pairs.
const mergeFailed = uint32(0xffffffff)
var (
noRune = []rune{}
noNext = []uint32{mergeFailed}
)
func mergeRuneSets(leftRunes, rightRunes *[]rune, leftPC, rightPC uint32) ([]rune, []uint32) {
leftLen := len(*leftRunes)
rightLen := len(*rightRunes)
if leftLen&0x1 != 0 || rightLen&0x1 != 0 {
panic("mergeRuneSets odd length []rune")
}
var (
lx, rx int
)
merged := make([]rune, 0)
next := make([]uint32, 0)
ok := true
defer func() {
if !ok {
merged = nil
next = nil
}
}()
ix := -1
extend := func(newLow *int, newArray *[]rune, pc uint32) bool {
if ix > 0 && (*newArray)[*newLow] <= merged[ix] {
return false
}
merged = append(merged, (*newArray)[*newLow], (*newArray)[*newLow+1])
*newLow += 2
ix += 2
next = append(next, pc)
return true
}
for lx < leftLen || rx < rightLen {
switch {
case rx >= rightLen:
ok = extend(&lx, leftRunes, leftPC)
case lx >= leftLen:
ok = extend(&rx, rightRunes, rightPC)
case (*rightRunes)[rx] < (*leftRunes)[lx]:
ok = extend(&rx, rightRunes, rightPC)
default:
ok = extend(&lx, leftRunes, leftPC)
}
if !ok {
return noRune, noNext
}
}
return merged, next
}
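// A worked sketch: merging the disjoint inclusive ranges ['a','b'] (pc 1)
// and ['d','e'] (pc 2) yields merged = ['a','b','d','e'] and
// next = [1, 2]; a rune matching the pair at indices 2,3 dispatches to
// next[3/2] = 2. Merging overlapping ranges such as ['a','c'] and
// ['b','d'] returns noRune and noNext instead.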
// cleanupOnePass drops working memory, and restores certain shortcut instructions.
func cleanupOnePass(prog *onePassProg, original *syntax.Prog) {
for ix, instOriginal := range original.Inst {
switch instOriginal.Op {
case syntax.InstAlt, syntax.InstAltMatch, syntax.InstRune:
case syntax.InstCapture, syntax.InstEmptyWidth, syntax.InstNop, syntax.InstMatch, syntax.InstFail:
prog.Inst[ix].Next = nil
case syntax.InstRune1, syntax.InstRuneAny, syntax.InstRuneAnyNotNL:
prog.Inst[ix].Next = nil
prog.Inst[ix] = onePassInst{Inst: instOriginal}
}
}
}
// onePassCopy creates a copy of the original Prog, as we'll be modifying it.
func onePassCopy(prog *syntax.Prog) *onePassProg {
p := &onePassProg{
Start: prog.Start,
NumCap: prog.NumCap,
Inst: make([]onePassInst, len(prog.Inst)),
}
for i, inst := range prog.Inst {
p.Inst[i] = onePassInst{Inst: inst}
}
// rewrites one or more common Prog constructs that enable some otherwise
// non-onepass Progs to be onepass. A:BD (for example) means an InstAlt at
// ip A, that points to ips B & D.
// A:BC + B:DA => A:BC + B:DC
// A:BC + B:DC => A:DC + B:DC
for pc := range p.Inst {
switch p.Inst[pc].Op {
default:
continue
case syntax.InstAlt, syntax.InstAltMatch:
// A:Bx + B:Ay
p_A_Other := &p.Inst[pc].Out
p_A_Alt := &p.Inst[pc].Arg
// make sure a target is another Alt
instAlt := p.Inst[*p_A_Alt]
if !(instAlt.Op == syntax.InstAlt || instAlt.Op == syntax.InstAltMatch) {
p_A_Alt, p_A_Other = p_A_Other, p_A_Alt
instAlt = p.Inst[*p_A_Alt]
if !(instAlt.Op == syntax.InstAlt || instAlt.Op == syntax.InstAltMatch) {
continue
}
}
instOther := p.Inst[*p_A_Other]
// Analyzing both legs pointing to Alts is for another day
if instOther.Op == syntax.InstAlt || instOther.Op == syntax.InstAltMatch {
// too complicated
continue
}
// simple empty transition loop
// A:BC + B:DA => A:BC + B:DC
p_B_Alt := &p.Inst[*p_A_Alt].Out
p_B_Other := &p.Inst[*p_A_Alt].Arg
patch := false
if instAlt.Out == uint32(pc) {
patch = true
} else if instAlt.Arg == uint32(pc) {
patch = true
p_B_Alt, p_B_Other = p_B_Other, p_B_Alt
}
if patch {
*p_B_Alt = *p_A_Other
}
// empty transition to common target
// A:BC + B:DC => A:DC + B:DC
if *p_A_Other == *p_B_Alt {
*p_A_Alt = *p_B_Other
}
}
}
return p
}
// runeSlice exists to permit sorting the case-folded rune sets.
type runeSlice []rune
func (p runeSlice) Len() int { return len(p) }
func (p runeSlice) Less(i, j int) bool { return p[i] < p[j] }
func (p runeSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
var anyRuneNotNL = []rune{0, '\n' - 1, '\n' + 1, unicode.MaxRune}
var anyRune = []rune{0, unicode.MaxRune}
// makeOnePass creates a onepass Prog, if possible. It is possible if at any alt,
// the match engine can always tell which branch to take. The routine may modify
// p if it is turned into a onepass Prog. If it isn't possible for this to be a
// onepass Prog, nil is returned. makeOnePass is recursive
// to the size of the Prog.
func makeOnePass(p *onePassProg) *onePassProg {
// If the machine is very long, it's not worth the time to check if we can use one pass.
if len(p.Inst) >= 1000 {
return nil
}
var (
instQueue = newQueue(len(p.Inst))
visitQueue = newQueue(len(p.Inst))
check func(uint32, []bool) bool
onePassRunes = make([][]rune, len(p.Inst))
)
// check that paths from Alt instructions are unambiguous, and rebuild the new
// program as a onepass program
check = func(pc uint32, m []bool) (ok bool) {
ok = true
inst := &p.Inst[pc]
if visitQueue.contains(pc) {
return
}
visitQueue.insert(pc)
switch inst.Op {
case syntax.InstAlt, syntax.InstAltMatch:
ok = check(inst.Out, m) && check(inst.Arg, m)
// check no-input paths to InstMatch
matchOut := m[inst.Out]
matchArg := m[inst.Arg]
if matchOut && matchArg {
ok = false
break
}
// Match on empty goes in inst.Out
if matchArg {
inst.Out, inst.Arg = inst.Arg, inst.Out
matchOut, matchArg = matchArg, matchOut
}
if matchOut {
m[pc] = true
inst.Op = syntax.InstAltMatch
}
// build a dispatch operator from the two legs of the alt.
onePassRunes[pc], inst.Next = mergeRuneSets(
&onePassRunes[inst.Out], &onePassRunes[inst.Arg], inst.Out, inst.Arg)
if len(inst.Next) > 0 && inst.Next[0] == mergeFailed {
ok = false
break
}
case syntax.InstCapture, syntax.InstNop:
ok = check(inst.Out, m)
m[pc] = m[inst.Out]
// pass matching runes back through these no-ops.
onePassRunes[pc] = append([]rune{}, onePassRunes[inst.Out]...)
inst.Next = make([]uint32, len(onePassRunes[pc])/2+1)
for i := range inst.Next {
inst.Next[i] = inst.Out
}
case syntax.InstEmptyWidth:
ok = check(inst.Out, m)
m[pc] = m[inst.Out]
onePassRunes[pc] = append([]rune{}, onePassRunes[inst.Out]...)
inst.Next = make([]uint32, len(onePassRunes[pc])/2+1)
for i := range inst.Next {
inst.Next[i] = inst.Out
}
case syntax.InstMatch, syntax.InstFail:
m[pc] = inst.Op == syntax.InstMatch
case syntax.InstRune:
m[pc] = false
if len(inst.Next) > 0 {
break
}
instQueue.insert(inst.Out)
if len(inst.Rune) == 0 {
onePassRunes[pc] = []rune{}
inst.Next = []uint32{inst.Out}
break
}
runes := make([]rune, 0)
if len(inst.Rune) == 1 && syntax.Flags(inst.Arg)&syntax.FoldCase != 0 {
r0 := inst.Rune[0]
runes = append(runes, r0, r0)
for r1 := unicode.SimpleFold(r0); r1 != r0; r1 = unicode.SimpleFold(r1) {
runes = append(runes, r1, r1)
}
sort.Sort(runeSlice(runes))
} else {
runes = append(runes, inst.Rune...)
}
onePassRunes[pc] = runes
inst.Next = make([]uint32, len(onePassRunes[pc])/2+1)
for i := range inst.Next {
inst.Next[i] = inst.Out
}
inst.Op = syntax.InstRune
case syntax.InstRune1:
m[pc] = false
if len(inst.Next) > 0 {
break
}
instQueue.insert(inst.Out)
runes := []rune{}
// expand case-folded runes
if syntax.Flags(inst.Arg)&syntax.FoldCase != 0 {
r0 := inst.Rune[0]
runes = append(runes, r0, r0)
for r1 := unicode.SimpleFold(r0); r1 != r0; r1 = unicode.SimpleFold(r1) {
runes = append(runes, r1, r1)
}
sort.Sort(runeSlice(runes))
} else {
runes = append(runes, inst.Rune[0], inst.Rune[0])
}
onePassRunes[pc] = runes
inst.Next = make([]uint32, len(onePassRunes[pc])/2+1)
for i := range inst.Next {
inst.Next[i] = inst.Out
}
inst.Op = syntax.InstRune
case syntax.InstRuneAny:
m[pc] = false
if len(inst.Next) > 0 {
break
}
instQueue.insert(inst.Out)
onePassRunes[pc] = append([]rune{}, anyRune...)
inst.Next = []uint32{inst.Out}
case syntax.InstRuneAnyNotNL:
m[pc] = false
if len(inst.Next) > 0 {
break
}
instQueue.insert(inst.Out)
onePassRunes[pc] = append([]rune{}, anyRuneNotNL...)
inst.Next = make([]uint32, len(onePassRunes[pc])/2+1)
for i := range inst.Next {
inst.Next[i] = inst.Out
}
}
return
}
instQueue.clear()
instQueue.insert(uint32(p.Start))
m := make([]bool, len(p.Inst))
for !instQueue.empty() {
visitQueue.clear()
pc := instQueue.next()
if !check(pc, m) {
p = nil
break
}
}
if p != nil {
for i := range p.Inst {
p.Inst[i].Rune = onePassRunes[i]
}
}
return p
}
// compileOnePass returns a new onePassProg suitable for one-pass execution if the original Prog
// can be recharacterized as a one-pass regexp program, or nil if the
// Prog cannot be converted. For a one pass prog, the fundamental condition that must
// be true is: at any InstAlt, there must be no ambiguity about what branch to take.
func compileOnePass(prog *syntax.Prog) (p *onePassProg) {
if prog.Start == 0 {
return nil
}
// onepass regexp is anchored
if prog.Inst[prog.Start].Op != syntax.InstEmptyWidth ||
syntax.EmptyOp(prog.Inst[prog.Start].Arg)&syntax.EmptyBeginText != syntax.EmptyBeginText {
return nil
}
// every instruction leading to InstMatch must be EmptyEndText
for _, inst := range prog.Inst {
opOut := prog.Inst[inst.Out].Op
switch inst.Op {
default:
if opOut == syntax.InstMatch {
return nil
}
case syntax.InstAlt, syntax.InstAltMatch:
if opOut == syntax.InstMatch || prog.Inst[inst.Arg].Op == syntax.InstMatch {
return nil
}
case syntax.InstEmptyWidth:
if opOut == syntax.InstMatch {
if syntax.EmptyOp(inst.Arg)&syntax.EmptyEndText == syntax.EmptyEndText {
continue
}
return nil
}
}
}
// Creates a slightly optimized copy of the original Prog
// that cleans up some Prog idioms that block valid onepass programs
p = onePassCopy(prog)
// checkAmbiguity on InstAlts, build onepass Prog if possible
p = makeOnePass(p)
if p != nil {
cleanupOnePass(p, prog)
}
return p
}
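// Illustrative cases (a sketch, not an exhaustive rule): `^x*y$` compiles
// to a one-pass program, because the alternation for x* is decided by
// whether the next rune is 'x' or 'y'; `^(?:a|ab)$` does not, since both
// branches begin with 'a', so mergeRuneSets fails and doExecute falls
// back to the NFA or the backtracker.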
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package regexp implements regular expression search.
//
// The syntax of the regular expressions accepted is the same
// general syntax used by Perl, Python, and other languages.
// More precisely, it is the syntax accepted by RE2 and described at
// https://golang.org/s/re2syntax, except for \C.
// For an overview of the syntax, run
//
// go doc regexp/syntax
//
// The regexp implementation provided by this package is
// guaranteed to run in time linear in the size of the input.
// (This is a property not guaranteed by most open source
// implementations of regular expressions.) For more information
// about this property, see
//
// https://swtch.com/~rsc/regexp/regexp1.html
//
// or any book about automata theory.
//
// All characters are UTF-8-encoded code points.
// Following utf8.DecodeRune, each byte of an invalid UTF-8 sequence
// is treated as if it encoded utf8.RuneError (U+FFFD).
//
// There are 16 methods of Regexp that match a regular expression and identify
// the matched text. Their names are matched by this regular expression:
//
// Find(All)?(String)?(Submatch)?(Index)?
//
// If 'All' is present, the routine matches successive non-overlapping
// matches of the entire expression. Empty matches abutting a preceding
// match are ignored. The return value is a slice containing the successive
// return values of the corresponding non-'All' routine. These routines take
// an extra integer argument, n. If n >= 0, the function returns at most n
// matches/submatches; otherwise, it returns all of them.
//
// If 'String' is present, the argument is a string; otherwise it is a slice
// of bytes; return values are adjusted as appropriate.
//
// If 'Submatch' is present, the return value is a slice identifying the
// successive submatches of the expression. Submatches are matches of
// parenthesized subexpressions (also known as capturing groups) within the
// regular expression, numbered from left to right in order of opening
// parenthesis. Submatch 0 is the match of the entire expression, submatch 1 is
// the match of the first parenthesized subexpression, and so on.
//
// If 'Index' is present, matches and submatches are identified by byte index
// pairs within the input string: result[2*n:2*n+2] identifies the indexes of
// the nth submatch. The pair for n==0 identifies the match of the entire
// expression. If 'Index' is not present, the match is identified by the text
// of the match/submatch. If an index is negative or text is nil, it means that
// subexpression did not match any string in the input. For 'String' versions
// an empty string means either no match or an empty match.
//
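// For example, with re = MustCompile(`(\w+):(\d+)`):
//
//	re.FindString("host:80 db:5432")        // "host:80"
//	re.FindStringSubmatch("host:80")        // ["host:80", "host", "80"]
//	re.FindAllString("host:80 db:5432", -1) // ["host:80", "db:5432"]
//	re.FindStringSubmatchIndex("host:80")   // [0 7 0 4 5 7]
//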
// There is also a subset of the methods that can be applied to text read
// from a RuneReader:
//
// MatchReader, FindReaderIndex, FindReaderSubmatchIndex
//
// This set may grow. Note that regular expression matches may need to
// examine text beyond the text returned by a match, so the methods that
// match text from a RuneReader may read arbitrarily far into the input
// before returning.
//
// (There are a few other methods that do not match this pattern.)
package regexp
import (
"bytes"
"io"
"regexp/syntax"
"strconv"
"strings"
"sync"
"unicode"
"unicode/utf8"
)
// Regexp is the representation of a compiled regular expression.
// A Regexp is safe for concurrent use by multiple goroutines,
// except for configuration methods, such as Longest.
type Regexp struct {
expr string // as passed to Compile
prog *syntax.Prog // compiled program
onepass *onePassProg // onepass program or nil
numSubexp int
maxBitStateLen int
subexpNames []string
prefix string // required prefix in unanchored matches
prefixBytes []byte // prefix, as a []byte
prefixRune rune // first rune in prefix
prefixEnd uint32 // pc for last rune in prefix
mpool int // pool for machines
matchcap int // size of recorded match lengths
prefixComplete bool // prefix is the entire regexp
cond syntax.EmptyOp // empty-width conditions required at start of match
minInputLen int // minimum length of the input in bytes
// This field can be modified by the Longest method,
// but it is otherwise read-only.
longest bool // whether regexp prefers leftmost-longest match
}
// String returns the source text used to compile the regular expression.
func (re *Regexp) String() string {
return re.expr
}
// Copy returns a new Regexp object copied from re.
// Calling Longest on one copy does not affect another.
//
// Deprecated: In earlier releases, when using a Regexp in multiple goroutines,
// giving each goroutine its own copy helped to avoid lock contention.
// As of Go 1.12, using Copy is no longer necessary to avoid lock contention.
// Copy may still be appropriate if the reason for its use is to make
// two copies with different Longest settings.
func (re *Regexp) Copy() *Regexp {
re2 := *re
return &re2
}
// Compile parses a regular expression and returns, if successful,
// a Regexp object that can be used to match against text.
//
// When matching against text, the regexp returns a match that
// begins as early as possible in the input (leftmost), and among those
// it chooses the one that a backtracking search would have found first.
// This so-called leftmost-first matching is the same semantics
// that Perl, Python, and other implementations use, although this
// package implements it without the expense of backtracking.
// For POSIX leftmost-longest matching, see CompilePOSIX.
func Compile(expr string) (*Regexp, error) {
return compile(expr, syntax.Perl, false)
}
// CompilePOSIX is like Compile but restricts the regular expression
// to POSIX ERE (egrep) syntax and changes the match semantics to
// leftmost-longest.
//
// That is, when matching against text, the regexp returns a match that
// begins as early as possible in the input (leftmost), and among those
// it chooses a match that is as long as possible.
// This so-called leftmost-longest matching is the same semantics
// that early regular expression implementations used and that POSIX
// specifies.
//
// However, there can be multiple leftmost-longest matches, with different
// submatch choices, and here this package diverges from POSIX.
// Among the possible leftmost-longest matches, this package chooses
// the one that a backtracking search would have found first, while POSIX
// specifies that the match be chosen to maximize the length of the first
// subexpression, then the second, and so on from left to right.
// The POSIX rule is computationally prohibitive and not even well-defined.
// See https://swtch.com/~rsc/regexp/regexp2.html#posix for details.
func CompilePOSIX(expr string) (*Regexp, error) {
return compile(expr, syntax.POSIX, true)
}
// Longest makes future searches prefer the leftmost-longest match.
// That is, when matching against text, the regexp returns a match that
// begins as early as possible in the input (leftmost), and among those
// it chooses a match that is as long as possible.
// This method modifies the Regexp and may not be called concurrently
// with any other methods.
func (re *Regexp) Longest() {
re.longest = true
}
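// A short sketch of the effect, using the default leftmost-first (Perl)
// semantics and then leftmost-longest:
//
//	re := regexp.MustCompile(`a|ab`)
//	re.FindString("ab") // "a"
//	re.Longest()
//	re.FindString("ab") // "ab"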
func compile(expr string, mode syntax.Flags, longest bool) (*Regexp, error) {
re, err := syntax.Parse(expr, mode)
if err != nil {
return nil, err
}
maxCap := re.MaxCap()
capNames := re.CapNames()
re = re.Simplify()
prog, err := syntax.Compile(re)
if err != nil {
return nil, err
}
matchcap := prog.NumCap
if matchcap < 2 {
matchcap = 2
}
regexp := &Regexp{
expr: expr,
prog: prog,
onepass: compileOnePass(prog),
numSubexp: maxCap,
subexpNames: capNames,
cond: prog.StartCond(),
longest: longest,
matchcap: matchcap,
minInputLen: minInputLen(re),
}
if regexp.onepass == nil {
regexp.prefix, regexp.prefixComplete = prog.Prefix()
regexp.maxBitStateLen = maxBitStateLen(prog)
} else {
regexp.prefix, regexp.prefixComplete, regexp.prefixEnd = onePassPrefix(prog)
}
if regexp.prefix != "" {
// TODO(rsc): Remove this allocation by adding
// IndexString to package bytes.
regexp.prefixBytes = []byte(regexp.prefix)
regexp.prefixRune, _ = utf8.DecodeRuneInString(regexp.prefix)
}
n := len(prog.Inst)
i := 0
for matchSize[i] != 0 && matchSize[i] < n {
i++
}
regexp.mpool = i
return regexp, nil
}
// Pools of *machine for use during (*Regexp).doExecute,
// split up by the size of the execution queues.
// matchPool[i] machines have queue size matchSize[i].
// On a 64-bit system each queue entry is 16 bytes,
// so matchPool[0] has 16*2*128 = 4kB queues, etc.
// The final matchPool is a catch-all for very large queues.
var (
matchSize = [...]int{128, 512, 2048, 16384, 0}
matchPool [len(matchSize)]sync.Pool
)
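// Worked example of the selection loop in compile: a program with 300
// instructions skips matchSize[0] = 128 and stops at matchSize[1] = 512,
// so its machines come from matchPool[1]; a program with more than 16384
// instructions reaches the trailing 0 and uses the catch-all pool.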
// get returns a machine to use for matching re.
// It uses the re's machine cache if possible, to avoid
// unnecessary allocation.
func (re *Regexp) get() *machine {
m, ok := matchPool[re.mpool].Get().(*machine)
if !ok {
m = new(machine)
}
m.re = re
m.p = re.prog
if cap(m.matchcap) < re.matchcap {
m.matchcap = make([]int, re.matchcap)
for _, t := range m.pool {
t.cap = make([]int, re.matchcap)
}
}
// Allocate queues if needed.
// Or reallocate, for "large" match pool.
n := matchSize[re.mpool]
if n == 0 { // large pool
n = len(re.prog.Inst)
}
if len(m.q0.sparse) < n {
m.q0 = queue{make([]uint32, n), make([]entry, 0, n)}
m.q1 = queue{make([]uint32, n), make([]entry, 0, n)}
}
return m
}
// put returns a machine to the correct machine pool.
func (re *Regexp) put(m *machine) {
m.re = nil
m.p = nil
m.inputs.clear()
matchPool[re.mpool].Put(m)
}
// minInputLen walks the regexp to find the minimum length of any matchable input.
func minInputLen(re *syntax.Regexp) int {
switch re.Op {
default:
return 0
case syntax.OpAnyChar, syntax.OpAnyCharNotNL, syntax.OpCharClass:
return 1
case syntax.OpLiteral:
l := 0
for _, r := range re.Rune {
if r == utf8.RuneError {
l++
} else {
l += utf8.RuneLen(r)
}
}
return l
case syntax.OpCapture, syntax.OpPlus:
return minInputLen(re.Sub[0])
case syntax.OpRepeat:
return re.Min * minInputLen(re.Sub[0])
case syntax.OpConcat:
l := 0
for _, sub := range re.Sub {
l += minInputLen(sub)
}
return l
case syntax.OpAlternate:
l := minInputLen(re.Sub[0])
var lnext int
for _, sub := range re.Sub[1:] {
lnext = minInputLen(sub)
if lnext < l {
l = lnext
}
}
return l
}
}
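// A worked sketch: for `a{3}(b|cd)` the minimum is 3 bytes for a{3} plus
// min(1, 2) = 1 byte for the shorter alternative, so minInputLen reports
// 4 and doExecute can reject any shorter input without running a machine.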
// MustCompile is like Compile but panics if the expression cannot be parsed.
// It simplifies safe initialization of global variables holding compiled regular
// expressions.
func MustCompile(str string) *Regexp {
regexp, err := Compile(str)
if err != nil {
panic(`regexp: Compile(` + quote(str) + `): ` + err.Error())
}
return regexp
}
// MustCompilePOSIX is like CompilePOSIX but panics if the expression cannot be parsed.
// It simplifies safe initialization of global variables holding compiled regular
// expressions.
func MustCompilePOSIX(str string) *Regexp {
regexp, err := CompilePOSIX(str)
if err != nil {
panic(`regexp: CompilePOSIX(` + quote(str) + `): ` + err.Error())
}
return regexp
}
func quote(s string) string {
if strconv.CanBackquote(s) {
return "`" + s + "`"
}
return strconv.Quote(s)
}
// NumSubexp returns the number of parenthesized subexpressions in this Regexp.
func (re *Regexp) NumSubexp() int {
return re.numSubexp
}
// SubexpNames returns the names of the parenthesized subexpressions
// in this Regexp. The name for the first sub-expression is names[1],
// so that if m is a match slice, the name for m[i] is SubexpNames()[i].
// Since the Regexp as a whole cannot be named, names[0] is always
// the empty string. The slice should not be modified.
func (re *Regexp) SubexpNames() []string {
return re.subexpNames
}
// SubexpIndex returns the index of the first subexpression with the given name,
// or -1 if there is no subexpression with that name.
//
// Note that multiple subexpressions can be written using the same name, as in
// (?P<bob>a+)(?P<bob>b+), which declares two subexpressions named "bob".
// In this case, SubexpIndex returns the index of the leftmost such subexpression
// in the regular expression.
func (re *Regexp) SubexpIndex(name string) int {
if name != "" {
for i, s := range re.subexpNames {
if name == s {
return i
}
}
}
return -1
}
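// A minimal usage sketch:
//
//	re := regexp.MustCompile(`(?P<year>\d{4})-(?P<month>\d{2})`)
//	m := re.FindStringSubmatch("2009-11")
//	year := m[re.SubexpIndex("year")] // "2009"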
const endOfText rune = -1
// input abstracts different representations of the input text. It provides
// one-character lookahead.
type input interface {
step(pos int) (r rune, width int) // advance one rune
canCheckPrefix() bool // can we look ahead without losing info?
hasPrefix(re *Regexp) bool
index(re *Regexp, pos int) int
context(pos int) lazyFlag
}
// inputString scans a string.
type inputString struct {
str string
}
func (i *inputString) step(pos int) (rune, int) {
if pos < len(i.str) {
c := i.str[pos]
if c < utf8.RuneSelf {
return rune(c), 1
}
return utf8.DecodeRuneInString(i.str[pos:])
}
return endOfText, 0
}
func (i *inputString) canCheckPrefix() bool {
return true
}
func (i *inputString) hasPrefix(re *Regexp) bool {
return strings.HasPrefix(i.str, re.prefix)
}
func (i *inputString) index(re *Regexp, pos int) int {
return strings.Index(i.str[pos:], re.prefix)
}
func (i *inputString) context(pos int) lazyFlag {
r1, r2 := endOfText, endOfText
// 0 < pos && pos <= len(i.str)
if uint(pos-1) < uint(len(i.str)) {
r1 = rune(i.str[pos-1])
if r1 >= utf8.RuneSelf {
r1, _ = utf8.DecodeLastRuneInString(i.str[:pos])
}
}
// 0 <= pos && pos < len(i.str)
if uint(pos) < uint(len(i.str)) {
r2 = rune(i.str[pos])
if r2 >= utf8.RuneSelf {
r2, _ = utf8.DecodeRuneInString(i.str[pos:])
}
}
return newLazyFlag(r1, r2)
}
// inputBytes scans a byte slice.
type inputBytes struct {
str []byte
}
func (i *inputBytes) step(pos int) (rune, int) {
if pos < len(i.str) {
c := i.str[pos]
if c < utf8.RuneSelf {
return rune(c), 1
}
return utf8.DecodeRune(i.str[pos:])
}
return endOfText, 0
}
func (i *inputBytes) canCheckPrefix() bool {
return true
}
func (i *inputBytes) hasPrefix(re *Regexp) bool {
return bytes.HasPrefix(i.str, re.prefixBytes)
}
func (i *inputBytes) index(re *Regexp, pos int) int {
return bytes.Index(i.str[pos:], re.prefixBytes)
}
func (i *inputBytes) context(pos int) lazyFlag {
r1, r2 := endOfText, endOfText
// 0 < pos && pos <= len(i.str)
if uint(pos-1) < uint(len(i.str)) {
r1 = rune(i.str[pos-1])
if r1 >= utf8.RuneSelf {
r1, _ = utf8.DecodeLastRune(i.str[:pos])
}
}
// 0 <= pos && pos < len(i.str)
if uint(pos) < uint(len(i.str)) {
r2 = rune(i.str[pos])
if r2 >= utf8.RuneSelf {
r2, _ = utf8.DecodeRune(i.str[pos:])
}
}
return newLazyFlag(r1, r2)
}
// inputReader scans a RuneReader.
type inputReader struct {
r io.RuneReader
atEOT bool
pos int
}
func (i *inputReader) step(pos int) (rune, int) {
if !i.atEOT && pos != i.pos {
return endOfText, 0
}
r, w, err := i.r.ReadRune()
if err != nil {
i.atEOT = true
return endOfText, 0
}
i.pos += w
return r, w
}
func (i *inputReader) canCheckPrefix() bool {
return false
}
func (i *inputReader) hasPrefix(re *Regexp) bool {
return false
}
func (i *inputReader) index(re *Regexp, pos int) int {
return -1
}
func (i *inputReader) context(pos int) lazyFlag {
return 0 // not used
}
// LiteralPrefix returns a literal string that must begin any match
// of the regular expression re. It returns the boolean true if the
// literal string comprises the entire regular expression.
func (re *Regexp) LiteralPrefix() (prefix string, complete bool) {
return re.prefix, re.prefixComplete
}
// MatchReader reports whether the text returned by the RuneReader
// contains any match of the regular expression re.
func (re *Regexp) MatchReader(r io.RuneReader) bool {
return re.doMatch(r, nil, "")
}
// MatchString reports whether the string s
// contains any match of the regular expression re.
func (re *Regexp) MatchString(s string) bool {
return re.doMatch(nil, nil, s)
}
// Match reports whether the byte slice b
// contains any match of the regular expression re.
func (re *Regexp) Match(b []byte) bool {
return re.doMatch(nil, b, "")
}
// MatchReader reports whether the text returned by the RuneReader
// contains any match of the regular expression pattern.
// More complicated queries need to use Compile and the full Regexp interface.
func MatchReader(pattern string, r io.RuneReader) (matched bool, err error) {
re, err := Compile(pattern)
if err != nil {
return false, err
}
return re.MatchReader(r), nil
}
// MatchString reports whether the string s
// contains any match of the regular expression pattern.
// More complicated queries need to use Compile and the full Regexp interface.
func MatchString(pattern string, s string) (matched bool, err error) {
re, err := Compile(pattern)
if err != nil {
return false, err
}
return re.MatchString(s), nil
}
// Match reports whether the byte slice b
// contains any match of the regular expression pattern.
// More complicated queries need to use Compile and the full Regexp interface.
func Match(pattern string, b []byte) (matched bool, err error) {
re, err := Compile(pattern)
if err != nil {
return false, err
}
return re.Match(b), nil
}
// ReplaceAllString returns a copy of src, replacing matches of the Regexp
// with the replacement string repl. Inside repl, $ signs are interpreted as
// in Expand, so for instance $1 represents the text of the first submatch.
func (re *Regexp) ReplaceAllString(src, repl string) string {
n := 2
if strings.Contains(repl, "$") {
n = 2 * (re.numSubexp + 1)
}
b := re.replaceAll(nil, src, n, func(dst []byte, match []int) []byte {
return re.expand(dst, repl, nil, src, match)
})
return string(b)
}
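// A minimal usage sketch of the $ expansion:
//
//	re := regexp.MustCompile(`(\w+)@(\w+)`)
//	re.ReplaceAllString("user@host", "$2/$1") // "host/user"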
// ReplaceAllLiteralString returns a copy of src, replacing matches of the Regexp
// with the replacement string repl. The replacement repl is substituted directly,
// without using Expand.
func (re *Regexp) ReplaceAllLiteralString(src, repl string) string {
return string(re.replaceAll(nil, src, 2, func(dst []byte, match []int) []byte {
return append(dst, repl...)
}))
}
// ReplaceAllStringFunc returns a copy of src in which all matches of the
// Regexp have been replaced by the return value of function repl applied
// to the matched substring. The replacement returned by repl is substituted
// directly, without using Expand.
func (re *Regexp) ReplaceAllStringFunc(src string, repl func(string) string) string {
b := re.replaceAll(nil, src, 2, func(dst []byte, match []int) []byte {
return append(dst, repl(src[match[0]:match[1]])...)
})
return string(b)
}
func (re *Regexp) replaceAll(bsrc []byte, src string, nmatch int, repl func(dst []byte, m []int) []byte) []byte {
lastMatchEnd := 0 // end position of the most recent match
searchPos := 0 // position where we next look for a match
var buf []byte
var endPos int
if bsrc != nil {
endPos = len(bsrc)
} else {
endPos = len(src)
}
if nmatch > re.prog.NumCap {
nmatch = re.prog.NumCap
}
var dstCap [2]int
for searchPos <= endPos {
a := re.doExecute(nil, bsrc, src, searchPos, nmatch, dstCap[:0])
if len(a) == 0 {
break // no more matches
}
// Copy the unmatched characters before this match.
if bsrc != nil {
buf = append(buf, bsrc[lastMatchEnd:a[0]]...)
} else {
buf = append(buf, src[lastMatchEnd:a[0]]...)
}
// Now insert a copy of the replacement string, but not for a
// match of the empty string immediately after another match.
// (Otherwise, we get double replacement for patterns that
// match both empty and nonempty strings.)
if a[1] > lastMatchEnd || a[0] == 0 {
buf = repl(buf, a)
}
lastMatchEnd = a[1]
// Advance past this match; always advance at least one character.
var width int
if bsrc != nil {
_, width = utf8.DecodeRune(bsrc[searchPos:])
} else {
_, width = utf8.DecodeRuneInString(src[searchPos:])
}
if searchPos+width > a[1] {
searchPos += width
} else if searchPos+1 > a[1] {
// This clause is only needed at the end of the input
// string. In that case, DecodeRuneInString returns width=0.
searchPos++
} else {
searchPos = a[1]
}
}
// Copy the unmatched characters after the last match.
if bsrc != nil {
buf = append(buf, bsrc[lastMatchEnd:]...)
} else {
buf = append(buf, src[lastMatchEnd:]...)
}
return buf
}
// ReplaceAll returns a copy of src, replacing matches of the Regexp
// with the replacement text repl. Inside repl, $ signs are interpreted as
// in Expand, so for instance $1 represents the text of the first submatch.
func (re *Regexp) ReplaceAll(src, repl []byte) []byte {
n := 2
if bytes.IndexByte(repl, '$') >= 0 {
n = 2 * (re.numSubexp + 1)
}
srepl := ""
b := re.replaceAll(src, "", n, func(dst []byte, match []int) []byte {
if len(srepl) != len(repl) {
srepl = string(repl)
}
return re.expand(dst, srepl, src, "", match)
})
return b
}
// ReplaceAllLiteral returns a copy of src, replacing matches of the Regexp
// with the replacement bytes repl. The replacement repl is substituted directly,
// without using Expand.
func (re *Regexp) ReplaceAllLiteral(src, repl []byte) []byte {
return re.replaceAll(src, "", 2, func(dst []byte, match []int) []byte {
return append(dst, repl...)
})
}
// ReplaceAllFunc returns a copy of src in which all matches of the
// Regexp have been replaced by the return value of function repl applied
// to the matched byte slice. The replacement returned by repl is substituted
// directly, without using Expand.
func (re *Regexp) ReplaceAllFunc(src []byte, repl func([]byte) []byte) []byte {
return re.replaceAll(src, "", 2, func(dst []byte, match []int) []byte {
return append(dst, repl(src[match[0]:match[1]])...)
})
}
// Bitmap used by func special to check whether a character needs to be escaped.
var specialBytes [16]byte
// special reports whether byte b needs to be escaped by QuoteMeta.
func special(b byte) bool {
return b < utf8.RuneSelf && specialBytes[b%16]&(1<<(b/16)) != 0
}
func init() {
for _, b := range []byte(`\.+*?()|[]{}^$`) {
specialBytes[b%16] |= 1 << (b / 16)
}
}
// QuoteMeta returns a string that escapes all regular expression metacharacters
// inside the argument text; the returned string is a regular expression matching
// the literal text.
func QuoteMeta(s string) string {
// A byte loop is correct because all metacharacters are ASCII.
var i int
for i = 0; i < len(s); i++ {
if special(s[i]) {
break
}
}
// No meta characters found, so return original string.
if i >= len(s) {
return s
}
b := make([]byte, 2*len(s)-i)
copy(b, s[:i])
j := i
for ; i < len(s); i++ {
if special(s[i]) {
b[j] = '\\'
j++
}
b[j] = s[i]
j++
}
return string(b[:j])
}
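// exampleQuoteMeta is an editor's sketch (hypothetical name, not part of the
// original source): QuoteMeta lets untrusted text be embedded in a pattern
// as a literal.
func exampleQuoteMeta() bool {
re := MustCompile(QuoteMeta("1+1=2")) // compiles the literal pattern 1\+1=2
return re.MatchString("1+1=2")        // true; without QuoteMeta, the pattern
// 1+1=2 would match "11=2" instead.
}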
// The number of capture values in the program may correspond
// to fewer capturing expressions than are in the regexp.
// For example, "(a){0}" turns into an empty program, so the
// maximum capture in the program is 0 but we need to return
// an expression for \1. Pad appends -1s to the slice a as needed.
func (re *Regexp) pad(a []int) []int {
if a == nil {
// No match.
return nil
}
n := (1 + re.numSubexp) * 2
for len(a) < n {
a = append(a, -1)
}
return a
}
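// Editor's note (not part of the original source): this is observable from
// the public API: MustCompile(`(a){0}`).FindStringSubmatchIndex("b") returns
// [0 0 -1 -1], where the trailing -1 pair for group 1 comes from pad.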
// allMatches calls deliver at most n times
// with the location of successive matches in the input text.
// The input text is b if non-nil, otherwise s.
func (re *Regexp) allMatches(s string, b []byte, n int, deliver func([]int)) {
var end int
if b == nil {
end = len(s)
} else {
end = len(b)
}
for pos, i, prevMatchEnd := 0, 0, -1; i < n && pos <= end; {
matches := re.doExecute(nil, b, s, pos, re.prog.NumCap, nil)
if len(matches) == 0 {
break
}
accept := true
if matches[1] == pos {
// We've found an empty match.
if matches[0] == prevMatchEnd {
// We don't allow an empty match right
// after a previous match, so ignore it.
accept = false
}
var width int
if b == nil {
is := inputString{str: s}
_, width = is.step(pos)
} else {
ib := inputBytes{str: b}
_, width = ib.step(pos)
}
if width > 0 {
pos += width
} else {
pos = end + 1
}
} else {
pos = matches[1]
}
prevMatchEnd = matches[1]
if accept {
deliver(re.pad(matches))
i++
}
}
}
// Find returns a slice holding the text of the leftmost match in b of the regular expression.
// A return value of nil indicates no match.
func (re *Regexp) Find(b []byte) []byte {
var dstCap [2]int
a := re.doExecute(nil, b, "", 0, 2, dstCap[:0])
if a == nil {
return nil
}
return b[a[0]:a[1]:a[1]]
}
// FindIndex returns a two-element slice of integers defining the location of
// the leftmost match in b of the regular expression. The match itself is at
// b[loc[0]:loc[1]].
// A return value of nil indicates no match.
func (re *Regexp) FindIndex(b []byte) (loc []int) {
a := re.doExecute(nil, b, "", 0, 2, nil)
if a == nil {
return nil
}
return a[0:2]
}
// FindString returns a string holding the text of the leftmost match in s of the regular
// expression. If there is no match, the return value is an empty string,
// but it will also be empty if the regular expression successfully matches
// an empty string. Use FindStringIndex or FindStringSubmatch if it is
// necessary to distinguish these cases.
func (re *Regexp) FindString(s string) string {
var dstCap [2]int
a := re.doExecute(nil, nil, s, 0, 2, dstCap[:0])
if a == nil {
return ""
}
return s[a[0]:a[1]]
}
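// Editor's note (not part of the original source): for example, both
// MustCompile(`a+`).FindString("xyz") and MustCompile(`a*`).FindString("xyz")
// return ""; the first found no match, the second matched the empty string.
// FindStringIndex distinguishes them, returning nil and [0 0] respectively.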
// FindStringIndex returns a two-element slice of integers defining the
// location of the leftmost match in s of the regular expression. The match
// itself is at s[loc[0]:loc[1]].
// A return value of nil indicates no match.
func (re *Regexp) FindStringIndex(s string) (loc []int) {
a := re.doExecute(nil, nil, s, 0, 2, nil)
if a == nil {
return nil
}
return a[0:2]
}
// FindReaderIndex returns a two-element slice of integers defining the
// location of the leftmost match of the regular expression in text read from
// the RuneReader. The match text was found in the input stream at
// byte offset loc[0] through loc[1]-1.
// A return value of nil indicates no match.
func (re *Regexp) FindReaderIndex(r io.RuneReader) (loc []int) {
a := re.doExecute(r, nil, "", 0, 2, nil)
if a == nil {
return nil
}
return a[0:2]
}
// FindSubmatch returns a slice of slices holding the text of the leftmost
// match of the regular expression in b and the matches, if any, of its
// subexpressions, as defined by the 'Submatch' descriptions in the package
// comment.
// A return value of nil indicates no match.
func (re *Regexp) FindSubmatch(b []byte) [][]byte {
var dstCap [4]int
a := re.doExecute(nil, b, "", 0, re.prog.NumCap, dstCap[:0])
if a == nil {
return nil
}
ret := make([][]byte, 1+re.numSubexp)
for i := range ret {
if 2*i < len(a) && a[2*i] >= 0 {
ret[i] = b[a[2*i]:a[2*i+1]:a[2*i+1]]
}
}
return ret
}
// Expand appends template to dst and returns the result; during the
// append, Expand replaces variables in the template with corresponding
// matches drawn from src. The match slice should have been returned by
// FindSubmatchIndex.
//
// In the template, a variable is denoted by a substring of the form
// $name or ${name}, where name is a non-empty sequence of letters,
// digits, and underscores. A purely numeric name like $1 refers to
// the submatch with the corresponding index; other names refer to
// capturing parentheses named with the (?P<name>...) syntax. A
// reference to an out of range or unmatched index or a name that is not
// present in the regular expression is replaced with an empty slice.
//
// In the $name form, name is taken to be as long as possible: $1x is
// equivalent to ${1x}, not ${1}x, and $10 is equivalent to ${10}, not ${1}0.
//
// To insert a literal $ in the output, use $$ in the template.
func (re *Regexp) Expand(dst []byte, template []byte, src []byte, match []int) []byte {
return re.expand(dst, string(template), src, "", match)
}
// ExpandString is like Expand but the template and source are strings.
// It appends to and returns a byte slice in order to give the calling
// code control over allocation.
func (re *Regexp) ExpandString(dst []byte, template string, src string, match []int) []byte {
return re.expand(dst, template, nil, src, match)
}
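// exampleExpand is an editor's sketch (hypothetical name, not part of the
// original source) of the intended pairing: feed each match slice from
// FindAllSubmatchIndex into Expand.
func exampleExpand() []byte {
re := MustCompile(`(?P<key>\w+):(?P<value>\w+)`)
src := []byte("host:example port:80")
var dst []byte
for _, match := range re.FindAllSubmatchIndex(src, -1) {
dst = re.Expand(dst, []byte("$key=$value\n"), src, match)
}
return dst // "host=example\nport=80\n"
}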
func (re *Regexp) expand(dst []byte, template string, bsrc []byte, src string, match []int) []byte {
for len(template) > 0 {
before, after, ok := strings.Cut(template, "$")
if !ok {
break
}
dst = append(dst, before...)
template = after
if template != "" && template[0] == '$' {
// Treat $$ as $.
dst = append(dst, '$')
template = template[1:]
continue
}
name, num, rest, ok := extract(template)
if !ok {
// Malformed; treat $ as raw text.
dst = append(dst, '$')
continue
}
template = rest
if num >= 0 {
if 2*num+1 < len(match) && match[2*num] >= 0 {
if bsrc != nil {
dst = append(dst, bsrc[match[2*num]:match[2*num+1]]...)
} else {
dst = append(dst, src[match[2*num]:match[2*num+1]]...)
}
}
} else {
for i, namei := range re.subexpNames {
if name == namei && 2*i+1 < len(match) && match[2*i] >= 0 {
if bsrc != nil {
dst = append(dst, bsrc[match[2*i]:match[2*i+1]]...)
} else {
dst = append(dst, src[match[2*i]:match[2*i+1]]...)
}
break
}
}
}
}
dst = append(dst, template...)
return dst
}
// extract returns the name from a leading "name" or "{name}" in str.
// (The $ has already been removed by the caller.)
// If it is a number, extract returns num set to that number; otherwise num = -1.
func extract(str string) (name string, num int, rest string, ok bool) {
if str == "" {
return
}
brace := false
if str[0] == '{' {
brace = true
str = str[1:]
}
i := 0
for i < len(str) {
rune, size := utf8.DecodeRuneInString(str[i:])
if !unicode.IsLetter(rune) && !unicode.IsDigit(rune) && rune != '_' {
break
}
i += size
}
if i == 0 {
// empty name is not okay
return
}
name = str[:i]
if brace {
if i >= len(str) || str[i] != '}' {
// missing closing brace
return
}
i++
}
// Parse number.
num = 0
for i := 0; i < len(name); i++ {
if name[i] < '0' || '9' < name[i] || num >= 1e8 {
num = -1
break
}
num = num*10 + int(name[i]) - '0'
}
// Disallow leading zeros.
if name[0] == '0' && len(name) > 1 {
num = -1
}
rest = str[i:]
ok = true
return
}
// FindSubmatchIndex returns a slice holding the index pairs identifying the
// leftmost match of the regular expression in b and the matches, if any, of
// its subexpressions, as defined by the 'Submatch' and 'Index' descriptions
// in the package comment.
// A return value of nil indicates no match.
func (re *Regexp) FindSubmatchIndex(b []byte) []int {
return re.pad(re.doExecute(nil, b, "", 0, re.prog.NumCap, nil))
}
// FindStringSubmatch returns a slice of strings holding the text of the
// leftmost match of the regular expression in s and the matches, if any, of
// its subexpressions, as defined by the 'Submatch' description in the
// package comment.
// A return value of nil indicates no match.
func (re *Regexp) FindStringSubmatch(s string) []string {
var dstCap [4]int
a := re.doExecute(nil, nil, s, 0, re.prog.NumCap, dstCap[:0])
if a == nil {
return nil
}
ret := make([]string, 1+re.numSubexp)
for i := range ret {
if 2*i < len(a) && a[2*i] >= 0 {
ret[i] = s[a[2*i]:a[2*i+1]]
}
}
return ret
}
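// exampleSubmatch is an editor's sketch (hypothetical name, not part of the
// original source): index 0 holds the whole match and each later index holds
// one capturing group, or "" if the group did not match.
func exampleSubmatch() []string {
return MustCompile(`(\w+)@(\w+)`).FindStringSubmatch("user@host")
// returns []string{"user@host", "user", "host"}
}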
// FindStringSubmatchIndex returns a slice holding the index pairs
// identifying the leftmost match of the regular expression in s and the
// matches, if any, of its subexpressions, as defined by the 'Submatch' and
// 'Index' descriptions in the package comment.
// A return value of nil indicates no match.
func (re *Regexp) FindStringSubmatchIndex(s string) []int {
return re.pad(re.doExecute(nil, nil, s, 0, re.prog.NumCap, nil))
}
// FindReaderSubmatchIndex returns a slice holding the index pairs
// identifying the leftmost match of the regular expression of text read by
// the RuneReader, and the matches, if any, of its subexpressions, as defined
// by the 'Submatch' and 'Index' descriptions in the package comment. A
// return value of nil indicates no match.
func (re *Regexp) FindReaderSubmatchIndex(r io.RuneReader) []int {
return re.pad(re.doExecute(r, nil, "", 0, re.prog.NumCap, nil))
}
const startSize = 10 // The size at which to start a slice in the 'All' routines.
// FindAll is the 'All' version of Find; it returns a slice of all successive
// matches of the expression, as defined by the 'All' description in the
// package comment.
// A return value of nil indicates no match.
func (re *Regexp) FindAll(b []byte, n int) [][]byte {
if n < 0 {
n = len(b) + 1
}
var result [][]byte
re.allMatches("", b, n, func(match []int) {
if result == nil {
result = make([][]byte, 0, startSize)
}
result = append(result, b[match[0]:match[1]:match[1]])
})
return result
}
// FindAllIndex is the 'All' version of FindIndex; it returns a slice of all
// successive matches of the expression, as defined by the 'All' description
// in the package comment.
// A return value of nil indicates no match.
func (re *Regexp) FindAllIndex(b []byte, n int) [][]int {
if n < 0 {
n = len(b) + 1
}
var result [][]int
re.allMatches("", b, n, func(match []int) {
if result == nil {
result = make([][]int, 0, startSize)
}
result = append(result, match[0:2])
})
return result
}
// FindAllString is the 'All' version of FindString; it returns a slice of all
// successive matches of the expression, as defined by the 'All' description
// in the package comment.
// A return value of nil indicates no match.
func (re *Regexp) FindAllString(s string, n int) []string {
if n < 0 {
n = len(s) + 1
}
var result []string
re.allMatches(s, nil, n, func(match []int) {
if result == nil {
result = make([]string, 0, startSize)
}
result = append(result, s[match[0]:match[1]])
})
return result
}
// FindAllStringIndex is the 'All' version of FindStringIndex; it returns a
// slice of all successive matches of the expression, as defined by the 'All'
// description in the package comment.
// A return value of nil indicates no match.
func (re *Regexp) FindAllStringIndex(s string, n int) [][]int {
if n < 0 {
n = len(s) + 1
}
var result [][]int
re.allMatches(s, nil, n, func(match []int) {
if result == nil {
result = make([][]int, 0, startSize)
}
result = append(result, match[0:2])
})
return result
}
// FindAllSubmatch is the 'All' version of FindSubmatch; it returns a slice
// of all successive matches of the expression, as defined by the 'All'
// description in the package comment.
// A return value of nil indicates no match.
func (re *Regexp) FindAllSubmatch(b []byte, n int) [][][]byte {
if n < 0 {
n = len(b) + 1
}
var result [][][]byte
re.allMatches("", b, n, func(match []int) {
if result == nil {
result = make([][][]byte, 0, startSize)
}
slice := make([][]byte, len(match)/2)
for j := range slice {
if match[2*j] >= 0 {
slice[j] = b[match[2*j]:match[2*j+1]:match[2*j+1]]
}
}
result = append(result, slice)
})
return result
}
// FindAllSubmatchIndex is the 'All' version of FindSubmatchIndex; it returns
// a slice of all successive matches of the expression, as defined by the
// 'All' description in the package comment.
// A return value of nil indicates no match.
func (re *Regexp) FindAllSubmatchIndex(b []byte, n int) [][]int {
if n < 0 {
n = len(b) + 1
}
var result [][]int
re.allMatches("", b, n, func(match []int) {
if result == nil {
result = make([][]int, 0, startSize)
}
result = append(result, match)
})
return result
}
// FindAllStringSubmatch is the 'All' version of FindStringSubmatch; it
// returns a slice of all successive matches of the expression, as defined by
// the 'All' description in the package comment.
// A return value of nil indicates no match.
func (re *Regexp) FindAllStringSubmatch(s string, n int) [][]string {
if n < 0 {
n = len(s) + 1
}
var result [][]string
re.allMatches(s, nil, n, func(match []int) {
if result == nil {
result = make([][]string, 0, startSize)
}
slice := make([]string, len(match)/2)
for j := range slice {
if match[2*j] >= 0 {
slice[j] = s[match[2*j]:match[2*j+1]]
}
}
result = append(result, slice)
})
return result
}
// FindAllStringSubmatchIndex is the 'All' version of
// FindStringSubmatchIndex; it returns a slice of all successive matches of
// the expression, as defined by the 'All' description in the package
// comment.
// A return value of nil indicates no match.
func (re *Regexp) FindAllStringSubmatchIndex(s string, n int) [][]int {
if n < 0 {
n = len(s) + 1
}
var result [][]int
re.allMatches(s, nil, n, func(match []int) {
if result == nil {
result = make([][]int, 0, startSize)
}
result = append(result, match)
})
return result
}
// Split slices s into substrings separated by the expression and returns a slice of
// the substrings between those expression matches.
//
// The slice returned by this method consists of all the substrings of s
// not contained in the slice returned by FindAllString. When called on an expression
// that contains no metacharacters, it is equivalent to strings.SplitN.
//
// Example:
//
// s := regexp.MustCompile("a*").Split("abaabaccadaaae", 5)
// // s: ["", "b", "b", "c", "cadaaae"]
//
// The count determines the number of substrings to return:
//
// n > 0: at most n substrings; the last substring will be the unsplit remainder.
// n == 0: the result is nil (zero substrings)
// n < 0: all substrings
func (re *Regexp) Split(s string, n int) []string {
if n == 0 {
return nil
}
if len(re.expr) > 0 && len(s) == 0 {
return []string{""}
}
matches := re.FindAllStringIndex(s, n)
strings := make([]string, 0, len(matches))
beg := 0
end := 0
for _, match := range matches {
if n > 0 && len(strings) >= n-1 {
break
}
end = match[0]
if match[1] != 0 {
strings = append(strings, s[beg:end])
}
beg = match[1]
}
if end != len(s) {
strings = append(strings, s[beg:])
}
return strings
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package syntax
import "unicode"
// A patchList is a list of instruction pointers that need to be filled in (patched).
// Because the pointers haven't been filled in yet, we can reuse their storage
// to hold the list. It's kind of sleazy, but works well in practice.
// See https://swtch.com/~rsc/regexp/regexp1.html for inspiration.
//
// These aren't really pointers: they're integers, so we can reinterpret them
// this way without using package unsafe. A value l.head denotes
// p.inst[l.head>>1].Out (l.head&1==0) or .Arg (l.head&1==1).
// head == 0 denotes the empty list, okay because we start every program
// with a fail instruction, so we'll never want to point at its output link.
type patchList struct {
head, tail uint32
}
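// Editor's note (not part of the original source): for example, a head value
// of 8 (binary 1000) denotes p.Inst[4].Out, while 9 (binary 1001) denotes
// p.Inst[4].Arg.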
func makePatchList(n uint32) patchList {
return patchList{n, n}
}
func (l patchList) patch(p *Prog, val uint32) {
head := l.head
for head != 0 {
i := &p.Inst[head>>1]
if head&1 == 0 {
head = i.Out
i.Out = val
} else {
head = i.Arg
i.Arg = val
}
}
}
func (l1 patchList) append(p *Prog, l2 patchList) patchList {
if l1.head == 0 {
return l2
}
if l2.head == 0 {
return l1
}
i := &p.Inst[l1.tail>>1]
if l1.tail&1 == 0 {
i.Out = l2.head
} else {
i.Arg = l2.head
}
return patchList{l1.head, l2.tail}
}
// A frag represents a compiled program fragment.
type frag struct {
i uint32 // index of first instruction
out patchList // where to record end instruction
nullable bool // whether fragment can match empty string
}
type compiler struct {
p *Prog
}
// Compile compiles the regexp into a program to be executed.
// The regexp should have been simplified already (returned from re.Simplify).
func Compile(re *Regexp) (*Prog, error) {
var c compiler
c.init()
f := c.compile(re)
f.out.patch(c.p, c.inst(InstMatch).i)
c.p.Start = int(f.i)
return c.p, nil
}
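// exampleCompile is an editor's sketch (hypothetical name, not part of the
// original source) of the intended pipeline: Parse, then Simplify, then
// Compile.
func exampleCompile(pattern string) (*Prog, error) {
re, err := Parse(pattern, Perl)
if err != nil {
return nil, err
}
return Compile(re.Simplify())
}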
func (c *compiler) init() {
c.p = new(Prog)
c.p.NumCap = 2 // implicit ( and ) for whole match $0
c.inst(InstFail)
}
var anyRuneNotNL = []rune{0, '\n' - 1, '\n' + 1, unicode.MaxRune}
var anyRune = []rune{0, unicode.MaxRune}
func (c *compiler) compile(re *Regexp) frag {
switch re.Op {
case OpNoMatch:
return c.fail()
case OpEmptyMatch:
return c.nop()
case OpLiteral:
if len(re.Rune) == 0 {
return c.nop()
}
var f frag
for j := range re.Rune {
f1 := c.rune(re.Rune[j:j+1], re.Flags)
if j == 0 {
f = f1
} else {
f = c.cat(f, f1)
}
}
return f
case OpCharClass:
return c.rune(re.Rune, re.Flags)
case OpAnyCharNotNL:
return c.rune(anyRuneNotNL, 0)
case OpAnyChar:
return c.rune(anyRune, 0)
case OpBeginLine:
return c.empty(EmptyBeginLine)
case OpEndLine:
return c.empty(EmptyEndLine)
case OpBeginText:
return c.empty(EmptyBeginText)
case OpEndText:
return c.empty(EmptyEndText)
case OpWordBoundary:
return c.empty(EmptyWordBoundary)
case OpNoWordBoundary:
return c.empty(EmptyNoWordBoundary)
case OpCapture:
bra := c.cap(uint32(re.Cap << 1))
sub := c.compile(re.Sub[0])
ket := c.cap(uint32(re.Cap<<1 | 1))
return c.cat(c.cat(bra, sub), ket)
case OpStar:
return c.star(c.compile(re.Sub[0]), re.Flags&NonGreedy != 0)
case OpPlus:
return c.plus(c.compile(re.Sub[0]), re.Flags&NonGreedy != 0)
case OpQuest:
return c.quest(c.compile(re.Sub[0]), re.Flags&NonGreedy != 0)
case OpConcat:
if len(re.Sub) == 0 {
return c.nop()
}
var f frag
for i, sub := range re.Sub {
if i == 0 {
f = c.compile(sub)
} else {
f = c.cat(f, c.compile(sub))
}
}
return f
case OpAlternate:
var f frag
for _, sub := range re.Sub {
f = c.alt(f, c.compile(sub))
}
return f
}
panic("regexp: unhandled case in compile")
}
func (c *compiler) inst(op InstOp) frag {
// TODO: impose length limit
f := frag{i: uint32(len(c.p.Inst)), nullable: true}
c.p.Inst = append(c.p.Inst, Inst{Op: op})
return f
}
func (c *compiler) nop() frag {
f := c.inst(InstNop)
f.out = makePatchList(f.i << 1)
return f
}
func (c *compiler) fail() frag {
return frag{}
}
func (c *compiler) cap(arg uint32) frag {
f := c.inst(InstCapture)
f.out = makePatchList(f.i << 1)
c.p.Inst[f.i].Arg = arg
if c.p.NumCap < int(arg)+1 {
c.p.NumCap = int(arg) + 1
}
return f
}
func (c *compiler) cat(f1, f2 frag) frag {
// concat of failure is failure
if f1.i == 0 || f2.i == 0 {
return frag{}
}
// TODO: elide nop
f1.out.patch(c.p, f2.i)
return frag{f1.i, f2.out, f1.nullable && f2.nullable}
}
func (c *compiler) alt(f1, f2 frag) frag {
// alt of failure is other
if f1.i == 0 {
return f2
}
if f2.i == 0 {
return f1
}
f := c.inst(InstAlt)
i := &c.p.Inst[f.i]
i.Out = f1.i
i.Arg = f2.i
f.out = f1.out.append(c.p, f2.out)
f.nullable = f1.nullable || f2.nullable
return f
}
func (c *compiler) quest(f1 frag, nongreedy bool) frag {
f := c.inst(InstAlt)
i := &c.p.Inst[f.i]
if nongreedy {
i.Arg = f1.i
f.out = makePatchList(f.i << 1)
} else {
i.Out = f1.i
f.out = makePatchList(f.i<<1 | 1)
}
f.out = f.out.append(c.p, f1.out)
return f
}
// loop returns the fragment for the main loop of a plus or star.
// For plus, it can be used after changing the entry to f1.i.
// For star, it can be used directly when f1 can't match an empty string.
// (When f1 can match an empty string, f1* must be implemented as (f1+)?
// to get the priority match order correct.)
func (c *compiler) loop(f1 frag, nongreedy bool) frag {
f := c.inst(InstAlt)
i := &c.p.Inst[f.i]
if nongreedy {
i.Arg = f1.i
f.out = makePatchList(f.i << 1)
} else {
i.Out = f1.i
f.out = makePatchList(f.i<<1 | 1)
}
f1.out.patch(c.p, f.i)
return f
}
func (c *compiler) star(f1 frag, nongreedy bool) frag {
if f1.nullable {
// Use (f1+)? to get priority match order correct.
// See golang.org/issue/46123.
return c.quest(c.plus(f1, nongreedy), nongreedy)
}
return c.loop(f1, nongreedy)
}
func (c *compiler) plus(f1 frag, nongreedy bool) frag {
return frag{f1.i, c.loop(f1, nongreedy).out, f1.nullable}
}
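// Editor's note (not part of the original source): the nullable check in star
// matters for patterns like (a*)*: the inner fragment can match the empty
// string, so the outer star compiles as ((a*)+)? instead of a plain loop.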
func (c *compiler) empty(op EmptyOp) frag {
f := c.inst(InstEmptyWidth)
c.p.Inst[f.i].Arg = uint32(op)
f.out = makePatchList(f.i << 1)
return f
}
func (c *compiler) rune(r []rune, flags Flags) frag {
f := c.inst(InstRune)
f.nullable = false
i := &c.p.Inst[f.i]
i.Rune = r
flags &= FoldCase // only relevant flag is FoldCase
if len(r) != 1 || unicode.SimpleFold(r[0]) == r[0] {
// and sometimes not even that
flags &^= FoldCase
}
i.Arg = uint32(flags)
f.out = makePatchList(f.i << 1)
// Special cases for exec machine.
switch {
case flags&FoldCase == 0 && (len(r) == 1 || len(r) == 2 && r[0] == r[1]):
i.Op = InstRune1
case len(r) == 2 && r[0] == 0 && r[1] == unicode.MaxRune:
i.Op = InstRuneAny
case len(r) == 4 && r[0] == 0 && r[1] == '\n'-1 && r[2] == '\n'+1 && r[3] == unicode.MaxRune:
i.Op = InstRuneAnyNotNL
}
return f
}
// Code generated by "stringer -type Op -trimprefix Op"; DO NOT EDIT.
package syntax
import "strconv"
const (
_Op_name_0 = "NoMatchEmptyMatchLiteralCharClassAnyCharNotNLAnyCharBeginLineEndLineBeginTextEndTextWordBoundaryNoWordBoundaryCaptureStarPlusQuestRepeatConcatAlternate"
_Op_name_1 = "opPseudo"
)
var (
_Op_index_0 = [...]uint8{0, 7, 17, 24, 33, 45, 52, 61, 68, 77, 84, 96, 110, 117, 121, 125, 130, 136, 142, 151}
)
func (i Op) String() string {
switch {
case 1 <= i && i <= 19:
i -= 1
return _Op_name_0[_Op_index_0[i]:_Op_index_0[i+1]]
case i == 128:
return _Op_name_1
default:
return "Op(" + strconv.FormatInt(int64(i), 10) + ")"
}
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package syntax
import (
"sort"
"strings"
"unicode"
"unicode/utf8"
)
// An Error describes a failure to parse a regular expression
// and gives the offending expression.
type Error struct {
Code ErrorCode
Expr string
}
func (e *Error) Error() string {
return "error parsing regexp: " + e.Code.String() + ": `" + e.Expr + "`"
}
// An ErrorCode describes a failure to parse a regular expression.
type ErrorCode string
const (
// Unexpected error
ErrInternalError ErrorCode = "regexp/syntax: internal error"
// Parse errors
ErrInvalidCharClass ErrorCode = "invalid character class"
ErrInvalidCharRange ErrorCode = "invalid character class range"
ErrInvalidEscape ErrorCode = "invalid escape sequence"
ErrInvalidNamedCapture ErrorCode = "invalid named capture"
ErrInvalidPerlOp ErrorCode = "invalid or unsupported Perl syntax"
ErrInvalidRepeatOp ErrorCode = "invalid nested repetition operator"
ErrInvalidRepeatSize ErrorCode = "invalid repeat count"
ErrInvalidUTF8 ErrorCode = "invalid UTF-8"
ErrMissingBracket ErrorCode = "missing closing ]"
ErrMissingParen ErrorCode = "missing closing )"
ErrMissingRepeatArgument ErrorCode = "missing argument to repetition operator"
ErrTrailingBackslash ErrorCode = "trailing backslash at end of expression"
ErrUnexpectedParen ErrorCode = "unexpected )"
ErrNestingDepth ErrorCode = "expression nests too deeply"
ErrLarge ErrorCode = "expression too large"
)
func (e ErrorCode) String() string {
return string(e)
}
// Flags control the behavior of the parser and record information about regexp context.
type Flags uint16
const (
FoldCase Flags = 1 << iota // case-insensitive match
Literal // treat pattern as literal string
ClassNL // allow character classes like [^a-z] and [[:space:]] to match newline
DotNL // allow . to match newline
OneLine // treat ^ and $ as only matching at beginning and end of text
NonGreedy // make repetition operators default to non-greedy
PerlX // allow Perl extensions
UnicodeGroups // allow \p{Han}, \P{Han} for Unicode group and negation
WasDollar // regexp OpEndText was $, not \z
Simple // regexp contains no counted repetition
MatchNL = ClassNL | DotNL
Perl = ClassNL | OneLine | PerlX | UnicodeGroups // as close to Perl as possible
POSIX Flags = 0 // POSIX syntax
)
// Pseudo-ops for parsing stack.
const (
opLeftParen = opPseudo + iota
opVerticalBar
)
// maxHeight is the maximum height of a regexp parse tree.
// It is somewhat arbitrarily chosen, but the idea is to be large enough
// that no one will actually hit it in real use but at the same time small enough
// that recursion on the Regexp tree will not hit the 1GB Go stack limit.
// The maximum amount of stack for a single recursive frame is probably
// closer to 1kB, so this could potentially be raised, but it seems unlikely
// that people have regexps nested even this deeply.
// We ran a test on Google's C++ code base and turned up only
// a single use case with depth > 100; it had depth 128.
// Using depth 1000 should be plenty of margin.
// As an optimization, we don't even bother calculating heights
// until we've allocated at least maxHeight Regexp structures.
const maxHeight = 1000
// maxSize is the maximum size of a compiled regexp in Insts.
// It too is somewhat arbitrarily chosen, but the idea is to be large enough
// to allow significant regexps while at the same time small enough that
// the compiled form will not take up too much memory.
// 128 MB is enough for 3.3 million Inst structures, which roughly
// corresponds to a 3.3 MB regexp.
const (
maxSize = 128 << 20 / instSize
instSize = 5 * 8 // byte, 2 uint32s, slice is 5 64-bit words
)
// maxRunes is the maximum number of runes allowed in a regexp tree
// counting the runes in all the nodes.
// Ignoring character classes, p.numRunes is always less than the length of the regexp.
// Character classes can make it much larger: each \pL adds 1292 runes.
// 128 MB is enough for 32M runes, which is over 26k \pL instances.
// Note that repetitions do not make copies of the rune slices,
// so \pL{1000} is only one rune slice, not 1000.
// We could keep a cache of character classes we've seen,
// so that all the \pL we see use the same rune list,
// but that doesn't remove the problem entirely:
// consider something like [\pL01234][\pL01235][\pL01236]...[\pL^&*()].
// And because the Rune slice is exposed directly in the Regexp,
// there is not an opportunity to change the representation to allow
// partial sharing between different character classes.
// So the limit is the best we can do.
const (
maxRunes = 128 << 20 / runeSize
runeSize = 4 // rune is int32
)
type parser struct {
flags Flags // parse mode flags
stack []*Regexp // stack of parsed expressions
free *Regexp
numCap int // number of capturing groups seen
wholeRegexp string
tmpClass []rune // temporary char class work space
numRegexp int // number of regexps allocated
numRunes int // number of runes in char classes
repeats int64 // product of all repetitions seen
height map[*Regexp]int // regexp height, for height limit check
size map[*Regexp]int64 // regexp compiled size, for size limit check
}
func (p *parser) newRegexp(op Op) *Regexp {
re := p.free
if re != nil {
p.free = re.Sub0[0]
*re = Regexp{}
} else {
re = new(Regexp)
p.numRegexp++
}
re.Op = op
return re
}
func (p *parser) reuse(re *Regexp) {
if p.height != nil {
delete(p.height, re)
}
re.Sub0[0] = p.free
p.free = re
}
func (p *parser) checkLimits(re *Regexp) {
if p.numRunes > maxRunes {
panic(ErrLarge)
}
p.checkSize(re)
p.checkHeight(re)
}
func (p *parser) checkSize(re *Regexp) {
if p.size == nil {
// We haven't started tracking size yet.
// Do a relatively cheap check to see if we need to start.
// Maintain the product of all the repeats we've seen
// and don't track if the total number of regexp nodes
// we've seen times the repeat product is in budget.
if p.repeats == 0 {
p.repeats = 1
}
if re.Op == OpRepeat {
n := re.Max
if n == -1 {
n = re.Min
}
if n <= 0 {
n = 1
}
if int64(n) > maxSize/p.repeats {
p.repeats = maxSize
} else {
p.repeats *= int64(n)
}
}
if int64(p.numRegexp) < maxSize/p.repeats {
return
}
// We need to start tracking size.
// Make the map and belatedly populate it
// with info about everything we've constructed so far.
p.size = make(map[*Regexp]int64)
for _, re := range p.stack {
p.checkSize(re)
}
}
if p.calcSize(re, true) > maxSize {
panic(ErrLarge)
}
}
func (p *parser) calcSize(re *Regexp, force bool) int64 {
if !force {
if size, ok := p.size[re]; ok {
return size
}
}
var size int64
switch re.Op {
case OpLiteral:
size = int64(len(re.Rune))
case OpCapture, OpStar:
// star can be 1+ or 2+; assume 2 pessimistically
size = 2 + p.calcSize(re.Sub[0], false)
case OpPlus, OpQuest:
size = 1 + p.calcSize(re.Sub[0], false)
case OpConcat:
for _, sub := range re.Sub {
size += p.calcSize(sub, false)
}
case OpAlternate:
for _, sub := range re.Sub {
size += p.calcSize(sub, false)
}
if len(re.Sub) > 1 {
size += int64(len(re.Sub)) - 1
}
case OpRepeat:
sub := p.calcSize(re.Sub[0], false)
if re.Max == -1 {
if re.Min == 0 {
size = 2 + sub // x*
} else {
size = 1 + int64(re.Min)*sub // xxx+
}
break
}
// x{2,5} = xx(x(x(x)?)?)?
size = int64(re.Max)*sub + int64(re.Max-re.Min)
}
if size < 1 {
size = 1
}
p.size[re] = size
return size
}
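// Editor's note (not part of the original source): for x{2,5} with a
// single-rune x, the OpRepeat formula gives 5*1 + (5-2) = 8, matching the
// expansion xx(x(x(x)?)?)?: five copies of x plus three alternations.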
func (p *parser) checkHeight(re *Regexp) {
if p.numRegexp < maxHeight {
return
}
if p.height == nil {
p.height = make(map[*Regexp]int)
for _, re := range p.stack {
p.checkHeight(re)
}
}
if p.calcHeight(re, true) > maxHeight {
panic(ErrNestingDepth)
}
}
func (p *parser) calcHeight(re *Regexp, force bool) int {
if !force {
if h, ok := p.height[re]; ok {
return h
}
}
h := 1
for _, sub := range re.Sub {
hsub := p.calcHeight(sub, false)
if h < 1+hsub {
h = 1 + hsub
}
}
p.height[re] = h
return h
}
// Parse stack manipulation.
// push pushes the regexp re onto the parse stack and returns the regexp.
func (p *parser) push(re *Regexp) *Regexp {
p.numRunes += len(re.Rune)
if re.Op == OpCharClass && len(re.Rune) == 2 && re.Rune[0] == re.Rune[1] {
// Single rune.
if p.maybeConcat(re.Rune[0], p.flags&^FoldCase) {
return nil
}
re.Op = OpLiteral
re.Rune = re.Rune[:1]
re.Flags = p.flags &^ FoldCase
} else if re.Op == OpCharClass && len(re.Rune) == 4 &&
re.Rune[0] == re.Rune[1] && re.Rune[2] == re.Rune[3] &&
unicode.SimpleFold(re.Rune[0]) == re.Rune[2] &&
unicode.SimpleFold(re.Rune[2]) == re.Rune[0] ||
re.Op == OpCharClass && len(re.Rune) == 2 &&
re.Rune[0]+1 == re.Rune[1] &&
unicode.SimpleFold(re.Rune[0]) == re.Rune[1] &&
unicode.SimpleFold(re.Rune[1]) == re.Rune[0] {
// Case-insensitive rune like [Aa] or [Δδ].
if p.maybeConcat(re.Rune[0], p.flags|FoldCase) {
return nil
}
// Rewrite as (case-insensitive) literal.
re.Op = OpLiteral
re.Rune = re.Rune[:1]
re.Flags = p.flags | FoldCase
} else {
// Incremental concatenation.
p.maybeConcat(-1, 0)
}
p.stack = append(p.stack, re)
p.checkLimits(re)
return re
}
// maybeConcat implements incremental concatenation
// of literal runes into string nodes. The parser calls this
// before each push, so only the top fragment of the stack
// might need processing. Since this is called before a push,
// the topmost literal is no longer subject to operators like *.
// (Otherwise ab* would turn into (ab)*.)
// If r >= 0 and there's a node left over, maybeConcat uses it
// to push r with the given flags.
// maybeConcat reports whether r was pushed.
func (p *parser) maybeConcat(r rune, flags Flags) bool {
n := len(p.stack)
if n < 2 {
return false
}
re1 := p.stack[n-1]
re2 := p.stack[n-2]
if re1.Op != OpLiteral || re2.Op != OpLiteral || re1.Flags&FoldCase != re2.Flags&FoldCase {
return false
}
// Push re1 into re2.
re2.Rune = append(re2.Rune, re1.Rune...)
// Reuse re1 if possible.
if r >= 0 {
re1.Rune = re1.Rune0[:1]
re1.Rune[0] = r
re1.Flags = flags
return true
}
p.stack = p.stack[:n-1]
p.reuse(re1)
return false // did not push r
}
// literal pushes a literal regexp for the rune r on the stack.
func (p *parser) literal(r rune) {
re := p.newRegexp(OpLiteral)
re.Flags = p.flags
if p.flags&FoldCase != 0 {
r = minFoldRune(r)
}
re.Rune0[0] = r
re.Rune = re.Rune0[:1]
p.push(re)
}
// minFoldRune returns the minimum rune fold-equivalent to r.
func minFoldRune(r rune) rune {
if r < minFold || r > maxFold {
return r
}
min := r
r0 := r
for r = unicode.SimpleFold(r); r != r0; r = unicode.SimpleFold(r) {
if min > r {
min = r
}
}
return min
}
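// Editor's note (not part of the original source): for example, the simple
// fold cycle of 'K' is U+004B ('K'), U+006B ('k'), and U+212A (Kelvin sign),
// so minFoldRune returns U+004B for any of the three.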
// op pushes a regexp with the given op onto the stack
// and returns that regexp.
func (p *parser) op(op Op) *Regexp {
re := p.newRegexp(op)
re.Flags = p.flags
return p.push(re)
}
// repeat replaces the top stack element with itself repeated according to op, min, max.
// before is the regexp suffix starting at the repetition operator.
// after is the regexp suffix following after the repetition operator.
// repeat returns an updated 'after' and an error, if any.
func (p *parser) repeat(op Op, min, max int, before, after, lastRepeat string) (string, error) {
flags := p.flags
if p.flags&PerlX != 0 {
if len(after) > 0 && after[0] == '?' {
after = after[1:]
flags ^= NonGreedy
}
if lastRepeat != "" {
// In Perl it is not allowed to stack repetition operators:
// a** is a syntax error, not a doubled star, and a++ means
// something else entirely, which we don't support!
return "", &Error{ErrInvalidRepeatOp, lastRepeat[:len(lastRepeat)-len(after)]}
}
}
n := len(p.stack)
if n == 0 {
return "", &Error{ErrMissingRepeatArgument, before[:len(before)-len(after)]}
}
sub := p.stack[n-1]
if sub.Op >= opPseudo {
return "", &Error{ErrMissingRepeatArgument, before[:len(before)-len(after)]}
}
re := p.newRegexp(op)
re.Min = min
re.Max = max
re.Flags = flags
re.Sub = re.Sub0[:1]
re.Sub[0] = sub
p.stack[n-1] = re
p.checkLimits(re)
if op == OpRepeat && (min >= 2 || max >= 2) && !repeatIsValid(re, 1000) {
return "", &Error{ErrInvalidRepeatSize, before[:len(before)-len(after)]}
}
return after, nil
}
// repeatIsValid reports whether the repetition re is valid.
// Valid means that the combination of the top-level repetition
// and any inner repetitions does not exceed n copies of the
// innermost thing.
// This function rewalks the regexp tree and is called for every repetition,
// so we have to worry about inducing quadratic behavior in the parser.
// We avoid this by only calling repeatIsValid when min or max >= 2.
// In that case the depth of any >= 2 nesting can only get to 9 without
// triggering a parse error, so each subtree can only be rewalked 9 times.
func repeatIsValid(re *Regexp, n int) bool {
if re.Op == OpRepeat {
m := re.Max
if m == 0 {
return true
}
if m < 0 {
m = re.Min
}
if m > n {
return false
}
if m > 0 {
n /= m
}
}
for _, sub := range re.Sub {
if !repeatIsValid(sub, n) {
return false
}
}
return true
}
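// Editor's note (not part of the original source): for example, a{1000} is
// accepted, but (a{1000}){2} is not: the outer repeat divides the budget of
// 1000 by 2, and the inner maximum of 1000 exceeds the remaining 500.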
// concat replaces the top of the stack (above the topmost '|' or '(') with its concatenation.
func (p *parser) concat() *Regexp {
p.maybeConcat(-1, 0)
// Scan down to find pseudo-operator | or (.
i := len(p.stack)
for i > 0 && p.stack[i-1].Op < opPseudo {
i--
}
subs := p.stack[i:]
p.stack = p.stack[:i]
// An empty concatenation is a special case.
if len(subs) == 0 {
return p.push(p.newRegexp(OpEmptyMatch))
}
return p.push(p.collapse(subs, OpConcat))
}
// alternate replaces the top of the stack (above the topmost '(') with its alternation.
func (p *parser) alternate() *Regexp {
// Scan down to find pseudo-operator (.
// There are no | above (.
i := len(p.stack)
for i > 0 && p.stack[i-1].Op < opPseudo {
i--
}
subs := p.stack[i:]
p.stack = p.stack[:i]
// Make sure top class is clean.
// All the others already are (see swapVerticalBar).
if len(subs) > 0 {
cleanAlt(subs[len(subs)-1])
}
// An empty alternation is a special case
// (it shouldn't happen, but it is easy to handle).
if len(subs) == 0 {
return p.push(p.newRegexp(OpNoMatch))
}
return p.push(p.collapse(subs, OpAlternate))
}
// cleanAlt cleans re for eventual inclusion in an alternation.
func cleanAlt(re *Regexp) {
switch re.Op {
case OpCharClass:
re.Rune = cleanClass(&re.Rune)
if len(re.Rune) == 2 && re.Rune[0] == 0 && re.Rune[1] == unicode.MaxRune {
re.Rune = nil
re.Op = OpAnyChar
return
}
if len(re.Rune) == 4 && re.Rune[0] == 0 && re.Rune[1] == '\n'-1 && re.Rune[2] == '\n'+1 && re.Rune[3] == unicode.MaxRune {
re.Rune = nil
re.Op = OpAnyCharNotNL
return
}
if cap(re.Rune)-len(re.Rune) > 100 {
// re.Rune will not grow any more.
// Make a copy or inline to reclaim storage.
re.Rune = append(re.Rune0[:0], re.Rune...)
}
}
}
// collapse returns the result of applying op to sub.
// If sub contains op nodes, they all get hoisted up
// so that there is never a concat of a concat or an
// alternate of an alternate.
func (p *parser) collapse(subs []*Regexp, op Op) *Regexp {
if len(subs) == 1 {
return subs[0]
}
re := p.newRegexp(op)
re.Sub = re.Sub0[:0]
for _, sub := range subs {
if sub.Op == op {
re.Sub = append(re.Sub, sub.Sub...)
p.reuse(sub)
} else {
re.Sub = append(re.Sub, sub)
}
}
if op == OpAlternate {
re.Sub = p.factor(re.Sub)
if len(re.Sub) == 1 {
old := re
re = re.Sub[0]
p.reuse(old)
}
}
return re
}
// factor factors common prefixes from the alternation list sub.
// It returns a replacement list that reuses the same storage and
// frees (passes to p.reuse) any removed *Regexps.
//
// For example,
//
// ABC|ABD|AEF|BCX|BCY
//
// simplifies by literal prefix extraction to
//
// A(B(C|D)|EF)|BC(X|Y)
//
// which simplifies by character class introduction to
//
// A(B[CD]|EF)|BC[XY]
func (p *parser) factor(sub []*Regexp) []*Regexp {
if len(sub) < 2 {
return sub
}
// Round 1: Factor out common literal prefixes.
var str []rune
var strflags Flags
start := 0
out := sub[:0]
for i := 0; i <= len(sub); i++ {
// Invariant: the Regexps that were in sub[0:start] have been
// used or marked for reuse, and the slice space has been reused
// for out (len(out) <= start).
//
// Invariant: sub[start:i] consists of regexps that all begin
// with str as modified by strflags.
var istr []rune
var iflags Flags
if i < len(sub) {
istr, iflags = p.leadingString(sub[i])
if iflags == strflags {
same := 0
for same < len(str) && same < len(istr) && str[same] == istr[same] {
same++
}
if same > 0 {
// Matches at least one rune in current range.
// Keep going around.
str = str[:same]
continue
}
}
}
// Found end of a run with common leading literal string:
// sub[start:i] all begin with str[0:len(str)], but sub[i]
// does not even begin with str[0].
//
// Factor out common string and append factored expression to out.
if i == start {
// Nothing to do - run of length 0.
} else if i == start+1 {
// Just one: don't bother factoring.
out = append(out, sub[start])
} else {
// Construct factored form: prefix(suffix1|suffix2|...)
prefix := p.newRegexp(OpLiteral)
prefix.Flags = strflags
prefix.Rune = append(prefix.Rune[:0], str...)
for j := start; j < i; j++ {
sub[j] = p.removeLeadingString(sub[j], len(str))
p.checkLimits(sub[j])
}
suffix := p.collapse(sub[start:i], OpAlternate) // recurse
re := p.newRegexp(OpConcat)
re.Sub = append(re.Sub[:0], prefix, suffix)
out = append(out, re)
}
// Prepare for next iteration.
start = i
str = istr
strflags = iflags
}
sub = out
// Round 2: Factor out common simple prefixes,
// just the first piece of each concatenation.
// This will be good enough a lot of the time.
//
// Complex subexpressions (e.g. involving quantifiers)
// are not safe to factor because that collapses their
// distinct paths through the automaton, which affects
// correctness in some cases.
start = 0
out = sub[:0]
var first *Regexp
for i := 0; i <= len(sub); i++ {
// Invariant: the Regexps that were in sub[0:start] have been
// used or marked for reuse, and the slice space has been reused
// for out (len(out) <= start).
//
// Invariant: sub[start:i] consists of regexps that all begin with ifirst.
var ifirst *Regexp
if i < len(sub) {
ifirst = p.leadingRegexp(sub[i])
if first != nil && first.Equal(ifirst) &&
// first must be a character class OR a fixed repeat of a character class.
(isCharClass(first) || (first.Op == OpRepeat && first.Min == first.Max && isCharClass(first.Sub[0]))) {
continue
}
}
// Found end of a run with common leading regexp:
// sub[start:i] all begin with first but sub[i] does not.
//
// Factor out common regexp and append factored expression to out.
if i == start {
// Nothing to do - run of length 0.
} else if i == start+1 {
// Just one: don't bother factoring.
out = append(out, sub[start])
} else {
// Construct factored form: prefix(suffix1|suffix2|...)
prefix := first
for j := start; j < i; j++ {
reuse := j != start // prefix came from sub[start]
sub[j] = p.removeLeadingRegexp(sub[j], reuse)
p.checkLimits(sub[j])
}
suffix := p.collapse(sub[start:i], OpAlternate) // recurse
re := p.newRegexp(OpConcat)
re.Sub = append(re.Sub[:0], prefix, suffix)
out = append(out, re)
}
// Prepare for next iteration.
start = i
first = ifirst
}
sub = out
// Round 3: Collapse runs of single literals into character classes.
start = 0
out = sub[:0]
for i := 0; i <= len(sub); i++ {
// Invariant: the Regexps that were in sub[0:start] have been
// used or marked for reuse, and the slice space has been reused
// for out (len(out) <= start).
//
// Invariant: sub[start:i] consists of regexps that are either
// literal runes or character classes.
if i < len(sub) && isCharClass(sub[i]) {
continue
}
// sub[i] is not a char or char class;
// emit char class for sub[start:i]...
if i == start {
// Nothing to do - run of length 0.
} else if i == start+1 {
out = append(out, sub[start])
} else {
// Make new char class.
// Start with most complex regexp in sub[start].
max := start
for j := start + 1; j < i; j++ {
if sub[max].Op < sub[j].Op || sub[max].Op == sub[j].Op && len(sub[max].Rune) < len(sub[j].Rune) {
max = j
}
}
sub[start], sub[max] = sub[max], sub[start]
for j := start + 1; j < i; j++ {
mergeCharClass(sub[start], sub[j])
p.reuse(sub[j])
}
cleanAlt(sub[start])
out = append(out, sub[start])
}
// ... and then emit sub[i].
if i < len(sub) {
out = append(out, sub[i])
}
start = i + 1
}
sub = out
// Round 4: Collapse runs of empty matches into a single empty match.
start = 0
out = sub[:0]
for i := range sub {
if i+1 < len(sub) && sub[i].Op == OpEmptyMatch && sub[i+1].Op == OpEmptyMatch {
continue
}
out = append(out, sub[i])
}
sub = out
return sub
}
// leadingString returns the leading literal string that re begins with.
// The string refers to storage in re or its children.
func (p *parser) leadingString(re *Regexp) ([]rune, Flags) {
if re.Op == OpConcat && len(re.Sub) > 0 {
re = re.Sub[0]
}
if re.Op != OpLiteral {
return nil, 0
}
return re.Rune, re.Flags & FoldCase
}
// removeLeadingString removes the first n leading runes
// from the beginning of re. It returns the replacement for re.
func (p *parser) removeLeadingString(re *Regexp, n int) *Regexp {
if re.Op == OpConcat && len(re.Sub) > 0 {
// Removing a leading string in a concatenation
// might simplify the concatenation.
sub := re.Sub[0]
sub = p.removeLeadingString(sub, n)
re.Sub[0] = sub
if sub.Op == OpEmptyMatch {
p.reuse(sub)
switch len(re.Sub) {
case 0, 1:
// Impossible but handle.
re.Op = OpEmptyMatch
re.Sub = nil
case 2:
old := re
re = re.Sub[1]
p.reuse(old)
default:
copy(re.Sub, re.Sub[1:])
re.Sub = re.Sub[:len(re.Sub)-1]
}
}
return re
}
if re.Op == OpLiteral {
re.Rune = re.Rune[:copy(re.Rune, re.Rune[n:])]
if len(re.Rune) == 0 {
re.Op = OpEmptyMatch
}
}
return re
}
// leadingRegexp returns the leading regexp that re begins with.
// The regexp refers to storage in re or its children.
func (p *parser) leadingRegexp(re *Regexp) *Regexp {
if re.Op == OpEmptyMatch {
return nil
}
if re.Op == OpConcat && len(re.Sub) > 0 {
sub := re.Sub[0]
if sub.Op == OpEmptyMatch {
return nil
}
return sub
}
return re
}
// removeLeadingRegexp removes the leading regexp in re.
// It returns the replacement for re.
// If reuse is true, it passes the removed regexp (if no longer needed) to p.reuse.
func (p *parser) removeLeadingRegexp(re *Regexp, reuse bool) *Regexp {
if re.Op == OpConcat && len(re.Sub) > 0 {
if reuse {
p.reuse(re.Sub[0])
}
re.Sub = re.Sub[:copy(re.Sub, re.Sub[1:])]
switch len(re.Sub) {
case 0:
re.Op = OpEmptyMatch
re.Sub = nil
case 1:
old := re
re = re.Sub[0]
p.reuse(old)
}
return re
}
if reuse {
p.reuse(re)
}
return p.newRegexp(OpEmptyMatch)
}
func literalRegexp(s string, flags Flags) *Regexp {
re := &Regexp{Op: OpLiteral}
re.Flags = flags
re.Rune = re.Rune0[:0] // use local storage for small strings
for _, c := range s {
if len(re.Rune) >= cap(re.Rune) {
// string is too long to fit in Rune0. Let Go handle it.
re.Rune = []rune(s)
break
}
re.Rune = append(re.Rune, c)
}
return re
}
// Parsing.
// Parse parses a regular expression string s, controlled by the specified
// Flags, and returns a regular expression parse tree. The syntax is
// described in the top-level comment.
func Parse(s string, flags Flags) (*Regexp, error) {
return parse(s, flags)
}
func parse(s string, flags Flags) (_ *Regexp, err error) {
defer func() {
switch r := recover(); r {
default:
panic(r)
case nil:
// ok
case ErrLarge: // too big
err = &Error{Code: ErrLarge, Expr: s}
case ErrNestingDepth:
err = &Error{Code: ErrNestingDepth, Expr: s}
}
}()
if flags&Literal != 0 {
// Trivial parser for literal string.
if err := checkUTF8(s); err != nil {
return nil, err
}
return literalRegexp(s, flags), nil
}
// Otherwise, must do real work.
var (
p parser
c rune
op Op
lastRepeat string
)
p.flags = flags
p.wholeRegexp = s
t := s
for t != "" {
repeat := ""
BigSwitch:
switch t[0] {
default:
if c, t, err = nextRune(t); err != nil {
return nil, err
}
p.literal(c)
case '(':
if p.flags&PerlX != 0 && len(t) >= 2 && t[1] == '?' {
// Flag changes and non-capturing groups.
if t, err = p.parsePerlFlags(t); err != nil {
return nil, err
}
break
}
p.numCap++
p.op(opLeftParen).Cap = p.numCap
t = t[1:]
case '|':
if err = p.parseVerticalBar(); err != nil {
return nil, err
}
t = t[1:]
case ')':
if err = p.parseRightParen(); err != nil {
return nil, err
}
t = t[1:]
case '^':
if p.flags&OneLine != 0 {
p.op(OpBeginText)
} else {
p.op(OpBeginLine)
}
t = t[1:]
case '$':
if p.flags&OneLine != 0 {
p.op(OpEndText).Flags |= WasDollar
} else {
p.op(OpEndLine)
}
t = t[1:]
case '.':
if p.flags&DotNL != 0 {
p.op(OpAnyChar)
} else {
p.op(OpAnyCharNotNL)
}
t = t[1:]
case '[':
if t, err = p.parseClass(t); err != nil {
return nil, err
}
case '*', '+', '?':
before := t
switch t[0] {
case '*':
op = OpStar
case '+':
op = OpPlus
case '?':
op = OpQuest
}
after := t[1:]
if after, err = p.repeat(op, 0, 0, before, after, lastRepeat); err != nil {
return nil, err
}
repeat = before
t = after
case '{':
op = OpRepeat
before := t
min, max, after, ok := p.parseRepeat(t)
if !ok {
// If the repeat cannot be parsed, { is a literal.
p.literal('{')
t = t[1:]
break
}
if min < 0 || min > 1000 || max > 1000 || max >= 0 && min > max {
// Numbers were too big, or max is present and min > max.
return nil, &Error{ErrInvalidRepeatSize, before[:len(before)-len(after)]}
}
if after, err = p.repeat(op, min, max, before, after, lastRepeat); err != nil {
return nil, err
}
repeat = before
t = after
case '\\':
if p.flags&PerlX != 0 && len(t) >= 2 {
switch t[1] {
case 'A':
p.op(OpBeginText)
t = t[2:]
break BigSwitch
case 'b':
p.op(OpWordBoundary)
t = t[2:]
break BigSwitch
case 'B':
p.op(OpNoWordBoundary)
t = t[2:]
break BigSwitch
case 'C':
// any byte; not supported
return nil, &Error{ErrInvalidEscape, t[:2]}
case 'Q':
// \Q ... \E: the ... is always literals
var lit string
lit, t, _ = strings.Cut(t[2:], `\E`)
for lit != "" {
c, rest, err := nextRune(lit)
if err != nil {
return nil, err
}
p.literal(c)
lit = rest
}
break BigSwitch
case 'z':
p.op(OpEndText)
t = t[2:]
break BigSwitch
}
}
re := p.newRegexp(OpCharClass)
re.Flags = p.flags
// Look for Unicode character group like \p{Han}
if len(t) >= 2 && (t[1] == 'p' || t[1] == 'P') {
r, rest, err := p.parseUnicodeClass(t, re.Rune0[:0])
if err != nil {
return nil, err
}
if r != nil {
re.Rune = r
t = rest
p.push(re)
break BigSwitch
}
}
// Perl character class escape.
if r, rest := p.parsePerlClassEscape(t, re.Rune0[:0]); r != nil {
re.Rune = r
t = rest
p.push(re)
break BigSwitch
}
p.reuse(re)
// Ordinary single-character escape.
if c, t, err = p.parseEscape(t); err != nil {
return nil, err
}
p.literal(c)
}
lastRepeat = repeat
}
p.concat()
if p.swapVerticalBar() {
// pop vertical bar
p.stack = p.stack[:len(p.stack)-1]
}
p.alternate()
n := len(p.stack)
if n != 1 {
return nil, &Error{ErrMissingParen, s}
}
return p.stack[0], nil
}
// parseRepeat parses {min} (max=min) or {min,} (max=-1) or {min,max}.
// If s is not of that form, it returns ok == false.
// If s has the right form but the values are too big, it returns min == -1, ok == true.
func (p *parser) parseRepeat(s string) (min, max int, rest string, ok bool) {
if s == "" || s[0] != '{' {
return
}
s = s[1:]
var ok1 bool
if min, s, ok1 = p.parseInt(s); !ok1 {
return
}
if s == "" {
return
}
if s[0] != ',' {
max = min
} else {
s = s[1:]
if s == "" {
return
}
if s[0] == '}' {
max = -1
} else if max, s, ok1 = p.parseInt(s); !ok1 {
return
} else if max < 0 {
// parseInt found too big a number
min = -1
}
}
if s == "" || s[0] != '}' {
return
}
rest = s[1:]
ok = true
return
}
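// Editor's note (not part of the original source): for example,
// parseRepeat("{2,5}x") returns (2, 5, "x", true), parseRepeat("{2,}x")
// returns max == -1, and parseRepeat("{,5}") returns ok == false because
// the form requires a leading min.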
// parsePerlFlags parses a Perl flag setting or non-capturing group or both,
// like (?i) or (?: or (?i:. It removes the prefix from s and updates the parse state.
// The caller must have ensured that s begins with "(?".
func (p *parser) parsePerlFlags(s string) (rest string, err error) {
t := s
// Check for named captures, first introduced in Python's regexp library.
// As usual, there are three slightly different syntaxes:
//
// (?P<name>expr) the original, introduced by Python
// (?<name>expr) the .NET alteration, adopted by Perl 5.10
// (?'name'expr) another .NET alteration, adopted by Perl 5.10
//
// Perl 5.10 gave in and implemented the Python version too,
// but they claim that the last two are the preferred forms.
// PCRE and languages based on it (specifically, PHP and Ruby)
// support all three as well. EcmaScript 4 uses only the Python form.
//
// In both the open source world (via Code Search) and the
// Google source tree, (?P<expr>name) is the dominant form,
// so that's the one we implement. One is enough.
if len(t) > 4 && t[2] == 'P' && t[3] == '<' {
// Pull out name.
end := strings.IndexRune(t, '>')
if end < 0 {
if err = checkUTF8(t); err != nil {
return "", err
}
return "", &Error{ErrInvalidNamedCapture, s}
}
capture := t[:end+1] // "(?P<name>"
name := t[4:end] // "name"
if err = checkUTF8(name); err != nil {
return "", err
}
if !isValidCaptureName(name) {
return "", &Error{ErrInvalidNamedCapture, capture}
}
// Like ordinary capture, but named.
p.numCap++
re := p.op(opLeftParen)
re.Cap = p.numCap
re.Name = name
return t[end+1:], nil
}
// Non-capturing group. Might also twiddle Perl flags.
var c rune
t = t[2:] // skip (?
flags := p.flags
sign := +1
sawFlag := false
Loop:
for t != "" {
if c, t, err = nextRune(t); err != nil {
return "", err
}
switch c {
default:
break Loop
// Flags.
case 'i':
flags |= FoldCase
sawFlag = true
case 'm':
flags &^= OneLine
sawFlag = true
case 's':
flags |= DotNL
sawFlag = true
case 'U':
flags |= NonGreedy
sawFlag = true
// Switch to negation.
case '-':
if sign < 0 {
break Loop
}
sign = -1
// Invert flags so that | above turn into &^ and vice versa.
// We'll invert flags again before using it below.
flags = ^flags
sawFlag = false
// End of flags, starting group or not.
case ':', ')':
if sign < 0 {
if !sawFlag {
break Loop
}
flags = ^flags
}
if c == ':' {
// Open new group
p.op(opLeftParen)
}
p.flags = flags
return t, nil
}
}
return "", &Error{ErrInvalidPerlOp, s[:len(s)-len(t)]}
}
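// Editor's note (not part of the original source): for example, "(?i)" sets
// FoldCase for the rest of the enclosing group, "(?i:" opens a non-capturing
// group with FoldCase set, and "(?-i:" opens one with FoldCase cleared.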
// isValidCaptureName reports whether name
// is a valid capture name: [A-Za-z0-9_]+.
// PCRE limits names to 32 bytes.
// Python rejects names starting with digits.
// We don't enforce either of those.
func isValidCaptureName(name string) bool {
if name == "" {
return false
}
for _, c := range name {
if c != '_' && !isalnum(c) {
return false
}
}
return true
}
// parseInt parses a decimal integer.
func (p *parser) parseInt(s string) (n int, rest string, ok bool) {
if s == "" || s[0] < '0' || '9' < s[0] {
return
}
// Disallow leading zeros.
if len(s) >= 2 && s[0] == '0' && '0' <= s[1] && s[1] <= '9' {
return
}
t := s
for s != "" && '0' <= s[0] && s[0] <= '9' {
s = s[1:]
}
rest = s
ok = true
// Have digits, compute value.
t = t[:len(t)-len(s)]
for i := 0; i < len(t); i++ {
// Avoid overflow.
if n >= 1e8 {
n = -1
break
}
n = n*10 + int(t[i]) - '0'
}
return
}
// isCharClass reports whether re can be represented as a character class:
// a single-rune literal string, a char class, . (AnyCharNotNL), or .|\n (AnyChar).
func isCharClass(re *Regexp) bool {
return re.Op == OpLiteral && len(re.Rune) == 1 ||
re.Op == OpCharClass ||
re.Op == OpAnyCharNotNL ||
re.Op == OpAnyChar
}
// matchRune reports whether re matches the single rune r.
func matchRune(re *Regexp, r rune) bool {
switch re.Op {
case OpLiteral:
return len(re.Rune) == 1 && re.Rune[0] == r
case OpCharClass:
for i := 0; i < len(re.Rune); i += 2 {
if re.Rune[i] <= r && r <= re.Rune[i+1] {
return true
}
}
return false
case OpAnyCharNotNL:
return r != '\n'
case OpAnyChar:
return true
}
return false
}
// parseVerticalBar handles a | in the input.
func (p *parser) parseVerticalBar() error {
p.concat()
// The concatenation we just parsed is on top of the stack.
// If it sits above an opVerticalBar, swap it below
// (things below an opVerticalBar become an alternation).
// Otherwise, push a new vertical bar.
if !p.swapVerticalBar() {
p.op(opVerticalBar)
}
return nil
}
// mergeCharClass makes dst = dst|src.
// The caller must ensure that dst.Op >= src.Op,
// to reduce the amount of copying.
func mergeCharClass(dst, src *Regexp) {
switch dst.Op {
case OpAnyChar:
// src doesn't add anything.
case OpAnyCharNotNL:
// src might add \n
if matchRune(src, '\n') {
dst.Op = OpAnyChar
}
case OpCharClass:
// src is simpler, so either literal or char class
if src.Op == OpLiteral {
dst.Rune = appendLiteral(dst.Rune, src.Rune[0], src.Flags)
} else {
dst.Rune = appendClass(dst.Rune, src.Rune)
}
case OpLiteral:
// both literal
if src.Rune[0] == dst.Rune[0] && src.Flags == dst.Flags {
break
}
dst.Op = OpCharClass
dst.Rune = appendLiteral(dst.Rune[:0], dst.Rune[0], dst.Flags)
dst.Rune = appendLiteral(dst.Rune, src.Rune[0], src.Flags)
}
}
// If the top of the stack is an element followed by an opVerticalBar,
// swapVerticalBar swaps the two and returns true.
// Otherwise it returns false.
func (p *parser) swapVerticalBar() bool {
// If above and below vertical bar are literal or char class,
// can merge into a single char class.
n := len(p.stack)
if n >= 3 && p.stack[n-2].Op == opVerticalBar && isCharClass(p.stack[n-1]) && isCharClass(p.stack[n-3]) {
re1 := p.stack[n-1]
re3 := p.stack[n-3]
// Make re3 the more complex of the two.
if re1.Op > re3.Op {
re1, re3 = re3, re1
p.stack[n-3] = re3
}
mergeCharClass(re3, re1)
p.reuse(re1)
p.stack = p.stack[:n-1]
return true
}
if n >= 2 {
re1 := p.stack[n-1]
re2 := p.stack[n-2]
if re2.Op == opVerticalBar {
if n >= 3 {
// Now out of reach.
// Clean opportunistically.
cleanAlt(p.stack[n-3])
}
p.stack[n-2] = re1
p.stack[n-1] = re2
return true
}
}
return false
}
// parseRightParen handles a ) in the input.
func (p *parser) parseRightParen() error {
p.concat()
if p.swapVerticalBar() {
// pop vertical bar
p.stack = p.stack[:len(p.stack)-1]
}
p.alternate()
n := len(p.stack)
if n < 2 {
return &Error{ErrUnexpectedParen, p.wholeRegexp}
}
re1 := p.stack[n-1]
re2 := p.stack[n-2]
p.stack = p.stack[:n-2]
if re2.Op != opLeftParen {
return &Error{ErrUnexpectedParen, p.wholeRegexp}
}
// Restore flags at time of paren.
p.flags = re2.Flags
if re2.Cap == 0 {
// Just for grouping.
p.push(re1)
} else {
re2.Op = OpCapture
re2.Sub = re2.Sub0[:1]
re2.Sub[0] = re1
p.push(re2)
}
return nil
}
// parseEscape parses an escape sequence at the beginning of s
// and returns the rune.
func (p *parser) parseEscape(s string) (r rune, rest string, err error) {
t := s[1:]
if t == "" {
return 0, "", &Error{ErrTrailingBackslash, ""}
}
c, t, err := nextRune(t)
if err != nil {
return 0, "", err
}
Switch:
switch c {
default:
if c < utf8.RuneSelf && !isalnum(c) {
// Escaped non-word characters are always themselves.
// PCRE is not quite so rigorous: it accepts things like
// \q, but we don't. We once rejected \_, but too many
// programs and people insist on using it, so allow \_.
return c, t, nil
}
// Octal escapes.
case '1', '2', '3', '4', '5', '6', '7':
// Single non-zero digit is a backreference; not supported
if t == "" || t[0] < '0' || t[0] > '7' {
break
}
fallthrough
case '0':
// Consume up to three octal digits; already have one.
r = c - '0'
for i := 1; i < 3; i++ {
if t == "" || t[0] < '0' || t[0] > '7' {
break
}
r = r*8 + rune(t[0]) - '0'
t = t[1:]
}
return r, t, nil
// Hexadecimal escapes.
case 'x':
if t == "" {
break
}
if c, t, err = nextRune(t); err != nil {
return 0, "", err
}
if c == '{' {
// Any number of digits in braces.
// Perl accepts any text at all; it ignores all text
// after the first non-hex digit. We require only hex digits,
// and at least one.
nhex := 0
r = 0
for {
if t == "" {
break Switch
}
if c, t, err = nextRune(t); err != nil {
return 0, "", err
}
if c == '}' {
break
}
v := unhex(c)
if v < 0 {
break Switch
}
r = r*16 + v
if r > unicode.MaxRune {
break Switch
}
nhex++
}
if nhex == 0 {
break Switch
}
return r, t, nil
}
// Easy case: two hex digits.
x := unhex(c)
if c, t, err = nextRune(t); err != nil {
return 0, "", err
}
y := unhex(c)
if x < 0 || y < 0 {
break
}
return x*16 + y, t, nil
// C escapes. There is no case 'b', to avoid misparsing
// the Perl word-boundary \b as the C backspace \b
// when in POSIX mode. In Perl, /\b/ means word-boundary
// but /[\b]/ means backspace. We don't support that.
// If you want a backspace, embed a literal backspace
// character or use \x08.
case 'a':
return '\a', t, err
case 'f':
return '\f', t, err
case 'n':
return '\n', t, err
case 'r':
return '\r', t, err
case 't':
return '\t', t, err
case 'v':
return '\v', t, err
}
return 0, "", &Error{ErrInvalidEscape, s[:len(s)-len(t)]}
}
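// Editor's note: an illustrative sketch (not part of the original source)
// exercising the escape forms handled above. parseEscape reads only its
// string argument, so a zero-value parser suffices here.
func exampleParseEscape() (rune, rune) {
var p parser
o, _, _ := p.parseEscape(`\101rest`) // octal: 0o101 == 'A'
h, _, _ := p.parseEscape(`\x{1F600}`) // braced hex: U+1F600
return o, h
}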
// parseClassChar parses a character class character at the beginning of s
// and returns it.
func (p *parser) parseClassChar(s, wholeClass string) (r rune, rest string, err error) {
if s == "" {
return 0, "", &Error{Code: ErrMissingBracket, Expr: wholeClass}
}
// Allow regular escape sequences even though
// many need not be escaped in this context.
if s[0] == '\\' {
return p.parseEscape(s)
}
return nextRune(s)
}
type charGroup struct {
sign int
class []rune
}
// parsePerlClassEscape parses a leading Perl character class escape like \d
// from the beginning of s. If one is present, it appends the characters to r
// and returns the new slice r and the remainder of the string.
func (p *parser) parsePerlClassEscape(s string, r []rune) (out []rune, rest string) {
if p.flags&PerlX == 0 || len(s) < 2 || s[0] != '\\' {
return
}
g := perlGroup[s[0:2]]
if g.sign == 0 {
return
}
return p.appendGroup(r, g), s[2:]
}
// parseNamedClass parses a leading POSIX named character class like [:alnum:]
// from the beginning of s. If one is present, it appends the characters to r
// and returns the new slice r and the remainder of the string.
func (p *parser) parseNamedClass(s string, r []rune) (out []rune, rest string, err error) {
if len(s) < 2 || s[0] != '[' || s[1] != ':' {
return
}
i := strings.Index(s[2:], ":]")
if i < 0 {
return
}
i += 2
name, s := s[0:i+2], s[i+2:]
g := posixGroup[name]
if g.sign == 0 {
return nil, "", &Error{ErrInvalidCharRange, name}
}
return p.appendGroup(r, g), s, nil
}
func (p *parser) appendGroup(r []rune, g charGroup) []rune {
if p.flags&FoldCase == 0 {
if g.sign < 0 {
r = appendNegatedClass(r, g.class)
} else {
r = appendClass(r, g.class)
}
} else {
tmp := p.tmpClass[:0]
tmp = appendFoldedClass(tmp, g.class)
p.tmpClass = tmp
tmp = cleanClass(&p.tmpClass)
if g.sign < 0 {
r = appendNegatedClass(r, tmp)
} else {
r = appendClass(r, tmp)
}
}
return r
}
var anyTable = &unicode.RangeTable{
R16: []unicode.Range16{{Lo: 0, Hi: 1<<16 - 1, Stride: 1}},
R32: []unicode.Range32{{Lo: 1 << 16, Hi: unicode.MaxRune, Stride: 1}},
}
// unicodeTable returns the unicode.RangeTable identified by name
// and the table of additional fold-equivalent code points.
func unicodeTable(name string) (*unicode.RangeTable, *unicode.RangeTable) {
// Special case: "Any" means any.
if name == "Any" {
return anyTable, anyTable
}
if t := unicode.Categories[name]; t != nil {
return t, unicode.FoldCategory[name]
}
if t := unicode.Scripts[name]; t != nil {
return t, unicode.FoldScript[name]
}
return nil, nil
}
// parseUnicodeClass parses a leading Unicode character class like \p{Han}
// from the beginning of s. If one is present, it appends the characters to r
// and returns the new slice r and the remainder of the string.
func (p *parser) parseUnicodeClass(s string, r []rune) (out []rune, rest string, err error) {
if p.flags&UnicodeGroups == 0 || len(s) < 2 || s[0] != '\\' || s[1] != 'p' && s[1] != 'P' {
return
}
// Committed to parse or return error.
sign := +1
if s[1] == 'P' {
sign = -1
}
t := s[2:]
c, t, err := nextRune(t)
if err != nil {
return
}
var seq, name string
if c != '{' {
// Single-letter name.
seq = s[:len(s)-len(t)]
name = seq[2:]
} else {
// Name is in braces.
end := strings.IndexRune(s, '}')
if end < 0 {
if err = checkUTF8(s); err != nil {
return
}
return nil, "", &Error{ErrInvalidCharRange, s}
}
seq, t = s[:end+1], s[end+1:]
name = s[3:end]
if err = checkUTF8(name); err != nil {
return
}
}
// Group can have leading negation too. \p{^Han} == \P{Han}, \P{^Han} == \p{Han}.
if name != "" && name[0] == '^' {
sign = -sign
name = name[1:]
}
tab, fold := unicodeTable(name)
if tab == nil {
return nil, "", &Error{ErrInvalidCharRange, seq}
}
if p.flags&FoldCase == 0 || fold == nil {
if sign > 0 {
r = appendTable(r, tab)
} else {
r = appendNegatedTable(r, tab)
}
} else {
// Merge and clean tab and fold in a temporary buffer.
// This is necessary for the negative case and just tidy
// for the positive case.
tmp := p.tmpClass[:0]
tmp = appendTable(tmp, tab)
tmp = appendTable(tmp, fold)
p.tmpClass = tmp
tmp = cleanClass(&p.tmpClass)
if sign > 0 {
r = appendClass(r, tmp)
} else {
r = appendNegatedClass(r, tmp)
}
}
return r, t, nil
}
// parseClass parses a character class at the beginning of s
// and pushes it onto the parse stack.
func (p *parser) parseClass(s string) (rest string, err error) {
t := s[1:] // chop [
re := p.newRegexp(OpCharClass)
re.Flags = p.flags
re.Rune = re.Rune0[:0]
sign := +1
if t != "" && t[0] == '^' {
sign = -1
t = t[1:]
// If character class does not match \n, add it here,
// so that negation later will do the right thing.
if p.flags&ClassNL == 0 {
re.Rune = append(re.Rune, '\n', '\n')
}
}
class := re.Rune
first := true // ] and - are okay as first char in class
for t == "" || t[0] != ']' || first {
// POSIX: - is only okay unescaped as first or last in class.
// Perl: - is okay anywhere.
if t != "" && t[0] == '-' && p.flags&PerlX == 0 && !first && (len(t) == 1 || t[1] != ']') {
_, size := utf8.DecodeRuneInString(t[1:])
return "", &Error{Code: ErrInvalidCharRange, Expr: t[:1+size]}
}
first = false
// Look for POSIX [:alnum:] etc.
if len(t) > 2 && t[0] == '[' && t[1] == ':' {
nclass, nt, err := p.parseNamedClass(t, class)
if err != nil {
return "", err
}
if nclass != nil {
class, t = nclass, nt
continue
}
}
// Look for Unicode character group like \p{Han}.
nclass, nt, err := p.parseUnicodeClass(t, class)
if err != nil {
return "", err
}
if nclass != nil {
class, t = nclass, nt
continue
}
// Look for Perl character class symbols (extension).
if nclass, nt := p.parsePerlClassEscape(t, class); nclass != nil {
class, t = nclass, nt
continue
}
// Single character or simple range.
rng := t
var lo, hi rune
if lo, t, err = p.parseClassChar(t, s); err != nil {
return "", err
}
hi = lo
// [a-] means (a|-) so check for final ].
if len(t) >= 2 && t[0] == '-' && t[1] != ']' {
t = t[1:]
if hi, t, err = p.parseClassChar(t, s); err != nil {
return "", err
}
if hi < lo {
rng = rng[:len(rng)-len(t)]
return "", &Error{Code: ErrInvalidCharRange, Expr: rng}
}
}
if p.flags&FoldCase == 0 {
class = appendRange(class, lo, hi)
} else {
class = appendFoldedRange(class, lo, hi)
}
}
t = t[1:] // chop ]
// Use &re.Rune instead of &class to avoid allocation.
re.Rune = class
class = cleanClass(&re.Rune)
if sign < 0 {
class = negateClass(class)
}
re.Rune = class
p.push(re)
return t, nil
}
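// Editor's note: an illustrative sketch (not part of the original source).
// Parse drives parseClass for bracket expressions; a negated class is
// stored with 0 and MaxRune as its outer bounds, and String prints the gaps.
func exampleParseClass() string {
re, err := Parse(`[^a-c]`, Perl)
if err != nil {
panic(err)
}
return re.String() // "[^a-c]"
}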
// cleanClass sorts the ranges (pairs of elements of r),
// merges them, and eliminates duplicates.
func cleanClass(rp *[]rune) []rune {
// Sort by lo increasing, hi decreasing to break ties.
sort.Sort(ranges{rp})
r := *rp
if len(r) < 2 {
return r
}
// Merge abutting, overlapping.
w := 2 // write index
for i := 2; i < len(r); i += 2 {
lo, hi := r[i], r[i+1]
if lo <= r[w-1]+1 {
// merge with previous range
if hi > r[w-1] {
r[w-1] = hi
}
continue
}
// new disjoint range
r[w] = lo
r[w+1] = hi
w += 2
}
return r[:w]
}
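// Editor's note: a minimal sketch (not part of the original source) of
// cleanClass merging overlapping range pairs after sorting.
func exampleCleanClass() []rune {
r := []rune{'b', 'f', 'a', 'e'} // ranges b-f and a-e, out of order
return cleanClass(&r) // sorted and merged to {'a', 'f'}
}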
// appendLiteral returns the result of appending the literal x to the class r.
func appendLiteral(r []rune, x rune, flags Flags) []rune {
if flags&FoldCase != 0 {
return appendFoldedRange(r, x, x)
}
return appendRange(r, x, x)
}
// appendRange returns the result of appending the range lo-hi to the class r.
func appendRange(r []rune, lo, hi rune) []rune {
// Expand last range or next to last range if it overlaps or abuts.
// Checking two ranges helps when appending case-folded
// alphabets, so that one range can be expanding A-Z and the
// other expanding a-z.
n := len(r)
for i := 2; i <= 4; i += 2 { // twice, using i=2, i=4
if n >= i {
rlo, rhi := r[n-i], r[n-i+1]
if lo <= rhi+1 && rlo <= hi+1 {
if lo < rlo {
r[n-i] = lo
}
if hi > rhi {
r[n-i+1] = hi
}
return r
}
}
}
return append(r, lo, hi)
}
const (
// minimum and maximum runes involved in folding.
// checked during test.
minFold = 0x0041
maxFold = 0x1e943
)
// appendFoldedRange returns the result of appending the range lo-hi
// and its case folding-equivalent runes to the class r.
func appendFoldedRange(r []rune, lo, hi rune) []rune {
// Optimizations.
if lo <= minFold && hi >= maxFold {
// Range is full: folding can't add more.
return appendRange(r, lo, hi)
}
if hi < minFold || lo > maxFold {
// Range is outside folding possibilities.
return appendRange(r, lo, hi)
}
if lo < minFold {
// [lo, minFold-1] needs no folding.
r = appendRange(r, lo, minFold-1)
lo = minFold
}
if hi > maxFold {
// [maxFold+1, hi] needs no folding.
r = appendRange(r, maxFold+1, hi)
hi = maxFold
}
// Brute force. Depend on appendRange to coalesce ranges on the fly.
for c := lo; c <= hi; c++ {
r = appendRange(r, c, c)
f := unicode.SimpleFold(c)
for f != c {
r = appendRange(r, f, f)
f = unicode.SimpleFold(f)
}
}
return r
}
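// Editor's note: an illustrative sketch (not part of the original source).
// Folding a single rune can pull in distant equivalents: 'k' folds to 'K'
// and to the Kelvin sign U+212A, so all three land in the class.
func exampleFoldedRange() []rune {
return appendFoldedRange(nil, 'k', 'k') // ranges for 'k', U+212A, and 'K'
}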
// appendClass returns the result of appending the class x to the class r.
// It assumes x is clean.
func appendClass(r []rune, x []rune) []rune {
for i := 0; i < len(x); i += 2 {
r = appendRange(r, x[i], x[i+1])
}
return r
}
// appendFoldedClass returns the result of appending the case folding of the class x to the class r.
func appendFoldedClass(r []rune, x []rune) []rune {
for i := 0; i < len(x); i += 2 {
r = appendFoldedRange(r, x[i], x[i+1])
}
return r
}
// appendNegatedClass returns the result of appending the negation of the class x to the class r.
// It assumes x is clean.
func appendNegatedClass(r []rune, x []rune) []rune {
nextLo := '\u0000'
for i := 0; i < len(x); i += 2 {
lo, hi := x[i], x[i+1]
if nextLo <= lo-1 {
r = appendRange(r, nextLo, lo-1)
}
nextLo = hi + 1
}
if nextLo <= unicode.MaxRune {
r = appendRange(r, nextLo, unicode.MaxRune)
}
return r
}
// appendTable returns the result of appending x to the class r.
func appendTable(r []rune, x *unicode.RangeTable) []rune {
for _, xr := range x.R16 {
lo, hi, stride := rune(xr.Lo), rune(xr.Hi), rune(xr.Stride)
if stride == 1 {
r = appendRange(r, lo, hi)
continue
}
for c := lo; c <= hi; c += stride {
r = appendRange(r, c, c)
}
}
for _, xr := range x.R32 {
lo, hi, stride := rune(xr.Lo), rune(xr.Hi), rune(xr.Stride)
if stride == 1 {
r = appendRange(r, lo, hi)
continue
}
for c := lo; c <= hi; c += stride {
r = appendRange(r, c, c)
}
}
return r
}
// appendNegatedTable returns the result of appending the negation of x to the class r.
func appendNegatedTable(r []rune, x *unicode.RangeTable) []rune {
nextLo := '\u0000' // lo end of next class to add
for _, xr := range x.R16 {
lo, hi, stride := rune(xr.Lo), rune(xr.Hi), rune(xr.Stride)
if stride == 1 {
if nextLo <= lo-1 {
r = appendRange(r, nextLo, lo-1)
}
nextLo = hi + 1
continue
}
for c := lo; c <= hi; c += stride {
if nextLo <= c-1 {
r = appendRange(r, nextLo, c-1)
}
nextLo = c + 1
}
}
for _, xr := range x.R32 {
lo, hi, stride := rune(xr.Lo), rune(xr.Hi), rune(xr.Stride)
if stride == 1 {
if nextLo <= lo-1 {
r = appendRange(r, nextLo, lo-1)
}
nextLo = hi + 1
continue
}
for c := lo; c <= hi; c += stride {
if nextLo <= c-1 {
r = appendRange(r, nextLo, c-1)
}
nextLo = c + 1
}
}
if nextLo <= unicode.MaxRune {
r = appendRange(r, nextLo, unicode.MaxRune)
}
return r
}
// negateClass overwrites r and returns r's negation.
// It assumes the class r is already clean.
func negateClass(r []rune) []rune {
nextLo := '\u0000' // lo end of next class to add
w := 0 // write index
for i := 0; i < len(r); i += 2 {
lo, hi := r[i], r[i+1]
if nextLo <= lo-1 {
r[w] = nextLo
r[w+1] = lo - 1
w += 2
}
nextLo = hi + 1
}
r = r[:w]
if nextLo <= unicode.MaxRune {
// It's possible for the negation to have one more
// range - this one - than the original class, so use append.
r = append(r, nextLo, unicode.MaxRune)
}
return r
}
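// Editor's note: a minimal sketch (not part of the original source) of
// negateClass inverting a clean class in place.
func exampleNegateClass() []rune {
r := []rune{'a', 'z'} // the class [a-z]
// Becomes {0, 'a'-1, 'z'+1, unicode.MaxRune}: everything but a-z.
return negateClass(r)
}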
// ranges implements sort.Interface on a []rune.
// The choice of receiver type definition is strange
// but avoids an allocation since we already have
// a *[]rune.
type ranges struct {
p *[]rune
}
func (ra ranges) Less(i, j int) bool {
p := *ra.p
i *= 2
j *= 2
return p[i] < p[j] || p[i] == p[j] && p[i+1] > p[j+1]
}
func (ra ranges) Len() int {
return len(*ra.p) / 2
}
func (ra ranges) Swap(i, j int) {
p := *ra.p
i *= 2
j *= 2
p[i], p[i+1], p[j], p[j+1] = p[j], p[j+1], p[i], p[i+1]
}
func checkUTF8(s string) error {
for s != "" {
rune, size := utf8.DecodeRuneInString(s)
if rune == utf8.RuneError && size == 1 {
return &Error{Code: ErrInvalidUTF8, Expr: s}
}
s = s[size:]
}
return nil
}
func nextRune(s string) (c rune, t string, err error) {
c, size := utf8.DecodeRuneInString(s)
if c == utf8.RuneError && size == 1 {
return 0, "", &Error{Code: ErrInvalidUTF8, Expr: s}
}
return c, s[size:], nil
}
func isalnum(c rune) bool {
return '0' <= c && c <= '9' || 'A' <= c && c <= 'Z' || 'a' <= c && c <= 'z'
}
func unhex(c rune) rune {
if '0' <= c && c <= '9' {
return c - '0'
}
if 'a' <= c && c <= 'f' {
return c - 'a' + 10
}
if 'A' <= c && c <= 'F' {
return c - 'A' + 10
}
return -1
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package syntax
import (
"strconv"
"strings"
"unicode"
"unicode/utf8"
)
// Compiled program.
// May not belong in this package, but convenient for now.
// A Prog is a compiled regular expression program.
type Prog struct {
Inst []Inst
Start int // index of start instruction
NumCap int // number of InstCapture insts in re
}
// An InstOp is an instruction opcode.
type InstOp uint8
const (
InstAlt InstOp = iota
InstAltMatch
InstCapture
InstEmptyWidth
InstMatch
InstFail
InstNop
InstRune
InstRune1
InstRuneAny
InstRuneAnyNotNL
)
var instOpNames = []string{
"InstAlt",
"InstAltMatch",
"InstCapture",
"InstEmptyWidth",
"InstMatch",
"InstFail",
"InstNop",
"InstRune",
"InstRune1",
"InstRuneAny",
"InstRuneAnyNotNL",
}
func (i InstOp) String() string {
if uint(i) >= uint(len(instOpNames)) {
return ""
}
return instOpNames[i]
}
// An EmptyOp specifies a kind or mixture of zero-width assertions.
type EmptyOp uint8
const (
EmptyBeginLine EmptyOp = 1 << iota
EmptyEndLine
EmptyBeginText
EmptyEndText
EmptyWordBoundary
EmptyNoWordBoundary
)
// EmptyOpContext returns the zero-width assertions
// satisfied at the position between the runes r1 and r2.
// Passing r1 == -1 indicates that the position is
// at the beginning of the text.
// Passing r2 == -1 indicates that the position is
// at the end of the text.
func EmptyOpContext(r1, r2 rune) EmptyOp {
var op EmptyOp = EmptyNoWordBoundary
var boundary byte
switch {
case IsWordChar(r1):
boundary = 1
case r1 == '\n':
op |= EmptyBeginLine
case r1 < 0:
op |= EmptyBeginText | EmptyBeginLine
}
switch {
case IsWordChar(r2):
boundary ^= 1
case r2 == '\n':
op |= EmptyEndLine
case r2 < 0:
op |= EmptyEndText | EmptyEndLine
}
if boundary != 0 { // IsWordChar(r1) != IsWordChar(r2)
op ^= (EmptyWordBoundary | EmptyNoWordBoundary)
}
return op
}
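// Editor's note: an illustrative sketch (not part of the original source).
// Between a word character and a space, \b holds and \B does not.
func exampleEmptyOpContext() bool {
op := EmptyOpContext('a', ' ')
return op&EmptyWordBoundary != 0 // true; op&EmptyNoWordBoundary is 0
}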
// IsWordChar reports whether r is considered a “word character”
// during the evaluation of the \b and \B zero-width assertions.
// These assertions are ASCII-only: the word characters are [A-Za-z0-9_].
func IsWordChar(r rune) bool {
return 'A' <= r && r <= 'Z' || 'a' <= r && r <= 'z' || '0' <= r && r <= '9' || r == '_'
}
// An Inst is a single instruction in a regular expression program.
type Inst struct {
Op InstOp
Out uint32 // all but InstMatch, InstFail
Arg uint32 // InstAlt, InstAltMatch, InstCapture, InstEmptyWidth
Rune []rune
}
func (p *Prog) String() string {
var b strings.Builder
dumpProg(&b, p)
return b.String()
}
// skipNop follows any no-op or capturing instructions.
func (p *Prog) skipNop(pc uint32) *Inst {
i := &p.Inst[pc]
for i.Op == InstNop || i.Op == InstCapture {
i = &p.Inst[i.Out]
}
return i
}
// op returns i.Op but merges all the Rune special cases into InstRune
func (i *Inst) op() InstOp {
op := i.Op
switch op {
case InstRune1, InstRuneAny, InstRuneAnyNotNL:
op = InstRune
}
return op
}
// Prefix returns a literal string that all matches for the
// regexp must start with. Complete is true if the prefix
// is the entire match.
func (p *Prog) Prefix() (prefix string, complete bool) {
i := p.skipNop(uint32(p.Start))
// Avoid allocation of buffer if prefix is empty.
if i.op() != InstRune || len(i.Rune) != 1 {
return "", i.Op == InstMatch
}
// Have prefix; gather characters.
var buf strings.Builder
for i.op() == InstRune && len(i.Rune) == 1 && Flags(i.Arg)&FoldCase == 0 && i.Rune[0] != utf8.RuneError {
buf.WriteRune(i.Rune[0])
i = p.skipNop(i.Out)
}
return buf.String(), i.Op == InstMatch
}
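// Editor's note: an illustrative sketch (not part of the original source),
// assuming this package's Parse and Compile entry points.
func examplePrefix() (string, bool) {
re, err := Parse(`abc.*`, Perl)
if err != nil {
panic(err)
}
p, err := Compile(re.Simplify())
if err != nil {
panic(err)
}
return p.Prefix() // "abc", false: the prefix is not the whole match
}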
// StartCond returns the leading empty-width conditions that must
// be true in any match. It returns ^EmptyOp(0) if no matches are possible.
func (p *Prog) StartCond() EmptyOp {
var flag EmptyOp
pc := uint32(p.Start)
i := &p.Inst[pc]
Loop:
for {
switch i.Op {
case InstEmptyWidth:
flag |= EmptyOp(i.Arg)
case InstFail:
return ^EmptyOp(0)
case InstCapture, InstNop:
// skip
default:
break Loop
}
pc = i.Out
i = &p.Inst[pc]
}
return flag
}
const noMatch = -1
// MatchRune reports whether the instruction matches (and consumes) r.
// It should only be called when i.Op == InstRune.
func (i *Inst) MatchRune(r rune) bool {
return i.MatchRunePos(r) != noMatch
}
// MatchRunePos checks whether the instruction matches (and consumes) r.
// If so, MatchRunePos returns the index of the matching rune pair
// (or, when len(i.Rune) == 1, rune singleton).
// If not, MatchRunePos returns -1.
// MatchRunePos should only be called when i.Op == InstRune.
func (i *Inst) MatchRunePos(r rune) int {
rune := i.Rune
switch len(rune) {
case 0:
return noMatch
case 1:
// Special case: single-rune slice is from literal string, not char class.
r0 := rune[0]
if r == r0 {
return 0
}
if Flags(i.Arg)&FoldCase != 0 {
for r1 := unicode.SimpleFold(r0); r1 != r0; r1 = unicode.SimpleFold(r1) {
if r == r1 {
return 0
}
}
}
return noMatch
case 2:
if r >= rune[0] && r <= rune[1] {
return 0
}
return noMatch
case 4, 6, 8:
// Linear search for a few pairs.
// Should handle ASCII well.
for j := 0; j < len(rune); j += 2 {
if r < rune[j] {
return noMatch
}
if r <= rune[j+1] {
return j / 2
}
}
return noMatch
}
// Otherwise binary search.
lo := 0
hi := len(rune) / 2
for lo < hi {
m := lo + (hi-lo)/2
if c := rune[2*m]; c <= r {
if r <= rune[2*m+1] {
return m
}
lo = m + 1
} else {
hi = m
}
}
return noMatch
}
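// Editor's note: a minimal sketch (not part of the original source) of the
// pair-indexed result: with two ranges a-c and x-z, 'y' matches pair 1.
func exampleMatchRunePos() int {
i := &Inst{Op: InstRune, Rune: []rune{'a', 'c', 'x', 'z'}}
return i.MatchRunePos('y') // 1
}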
// MatchEmptyWidth reports whether the instruction matches
// an empty string between the runes before and after.
// It should only be called when i.Op == InstEmptyWidth.
func (i *Inst) MatchEmptyWidth(before rune, after rune) bool {
switch EmptyOp(i.Arg) {
case EmptyBeginLine:
return before == '\n' || before == -1
case EmptyEndLine:
return after == '\n' || after == -1
case EmptyBeginText:
return before == -1
case EmptyEndText:
return after == -1
case EmptyWordBoundary:
return IsWordChar(before) != IsWordChar(after)
case EmptyNoWordBoundary:
return IsWordChar(before) == IsWordChar(after)
}
panic("unknown empty width arg")
}
func (i *Inst) String() string {
var b strings.Builder
dumpInst(&b, i)
return b.String()
}
func bw(b *strings.Builder, args ...string) {
for _, s := range args {
b.WriteString(s)
}
}
func dumpProg(b *strings.Builder, p *Prog) {
for j := range p.Inst {
i := &p.Inst[j]
pc := strconv.Itoa(j)
if len(pc) < 3 {
b.WriteString("   "[len(pc):]) // left-pad pc to width 3
}
if j == p.Start {
pc += "*"
}
bw(b, pc, "\t")
dumpInst(b, i)
bw(b, "\n")
}
}
func u32(i uint32) string {
return strconv.FormatUint(uint64(i), 10)
}
func dumpInst(b *strings.Builder, i *Inst) {
switch i.Op {
case InstAlt:
bw(b, "alt -> ", u32(i.Out), ", ", u32(i.Arg))
case InstAltMatch:
bw(b, "altmatch -> ", u32(i.Out), ", ", u32(i.Arg))
case InstCapture:
bw(b, "cap ", u32(i.Arg), " -> ", u32(i.Out))
case InstEmptyWidth:
bw(b, "empty ", u32(i.Arg), " -> ", u32(i.Out))
case InstMatch:
bw(b, "match")
case InstFail:
bw(b, "fail")
case InstNop:
bw(b, "nop -> ", u32(i.Out))
case InstRune:
if i.Rune == nil {
// shouldn't happen
bw(b, "rune <nil>")
}
bw(b, "rune ", strconv.QuoteToASCII(string(i.Rune)))
if Flags(i.Arg)&FoldCase != 0 {
bw(b, "/i")
}
bw(b, " -> ", u32(i.Out))
case InstRune1:
bw(b, "rune1 ", strconv.QuoteToASCII(string(i.Rune)), " -> ", u32(i.Out))
case InstRuneAny:
bw(b, "any -> ", u32(i.Out))
case InstRuneAnyNotNL:
bw(b, "anynotnl -> ", u32(i.Out))
}
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package syntax
// Note to implementers:
// In this package, re is always a *Regexp and r is always a rune.
import (
"strconv"
"strings"
"unicode"
)
// A Regexp is a node in a regular expression syntax tree.
type Regexp struct {
Op Op // operator
Flags Flags
Sub []*Regexp // subexpressions, if any
Sub0 [1]*Regexp // storage for short Sub
Rune []rune // matched runes, for OpLiteral, OpCharClass
Rune0 [2]rune // storage for short Rune
Min, Max int // min, max for OpRepeat
Cap int // capturing index, for OpCapture
Name string // capturing name, for OpCapture
}
//go:generate stringer -type Op -trimprefix Op
// An Op is a single regular expression operator.
type Op uint8
// Operators are listed in precedence order, tightest binding to weakest.
// Character class operators are listed simplest to most complex
// (OpLiteral, OpCharClass, OpAnyCharNotNL, OpAnyChar).
const (
OpNoMatch Op = 1 + iota // matches no strings
OpEmptyMatch // matches empty string
OpLiteral // matches Runes sequence
OpCharClass // matches Runes interpreted as range pair list
OpAnyCharNotNL // matches any character except newline
OpAnyChar // matches any character
OpBeginLine // matches empty string at beginning of line
OpEndLine // matches empty string at end of line
OpBeginText // matches empty string at beginning of text
OpEndText // matches empty string at end of text
OpWordBoundary // matches word boundary `\b`
OpNoWordBoundary // matches word non-boundary `\B`
OpCapture // capturing subexpression with index Cap, optional name Name
OpStar // matches Sub[0] zero or more times
OpPlus // matches Sub[0] one or more times
OpQuest // matches Sub[0] zero or one times
OpRepeat // matches Sub[0] at least Min times, at most Max (Max == -1 is no limit)
OpConcat // matches concatenation of Subs
OpAlternate // matches alternation of Subs
)
const opPseudo Op = 128 // where pseudo-ops start
// Equal reports whether x and y have identical structure.
func (x *Regexp) Equal(y *Regexp) bool {
if x == nil || y == nil {
return x == y
}
if x.Op != y.Op {
return false
}
switch x.Op {
case OpEndText:
// The parse flags remember whether this is \z or \Z.
if x.Flags&WasDollar != y.Flags&WasDollar {
return false
}
case OpLiteral, OpCharClass:
if len(x.Rune) != len(y.Rune) {
return false
}
for i, r := range x.Rune {
if r != y.Rune[i] {
return false
}
}
case OpAlternate, OpConcat:
if len(x.Sub) != len(y.Sub) {
return false
}
for i, sub := range x.Sub {
if !sub.Equal(y.Sub[i]) {
return false
}
}
case OpStar, OpPlus, OpQuest:
if x.Flags&NonGreedy != y.Flags&NonGreedy || !x.Sub[0].Equal(y.Sub[0]) {
return false
}
case OpRepeat:
if x.Flags&NonGreedy != y.Flags&NonGreedy || x.Min != y.Min || x.Max != y.Max || !x.Sub[0].Equal(y.Sub[0]) {
return false
}
case OpCapture:
if x.Cap != y.Cap || x.Name != y.Name || !x.Sub[0].Equal(y.Sub[0]) {
return false
}
}
return true
}
// writeRegexp writes the Perl syntax for the regular expression re to b.
func writeRegexp(b *strings.Builder, re *Regexp) {
switch re.Op {
default:
b.WriteString("<invalid op" + strconv.Itoa(int(re.Op)) + ">")
case OpNoMatch:
b.WriteString(`[^\x00-\x{10FFFF}]`)
case OpEmptyMatch:
b.WriteString(`(?:)`)
case OpLiteral:
if re.Flags&FoldCase != 0 {
b.WriteString(`(?i:`)
}
for _, r := range re.Rune {
escape(b, r, false)
}
if re.Flags&FoldCase != 0 {
b.WriteString(`)`)
}
case OpCharClass:
if len(re.Rune)%2 != 0 {
b.WriteString(`[invalid char class]`)
break
}
b.WriteRune('[')
if len(re.Rune) == 0 {
b.WriteString(`^\x00-\x{10FFFF}`)
} else if re.Rune[0] == 0 && re.Rune[len(re.Rune)-1] == unicode.MaxRune && len(re.Rune) > 2 {
// Contains 0 and MaxRune. Probably a negated class.
// Print the gaps.
b.WriteRune('^')
for i := 1; i < len(re.Rune)-1; i += 2 {
lo, hi := re.Rune[i]+1, re.Rune[i+1]-1
escape(b, lo, lo == '-')
if lo != hi {
b.WriteRune('-')
escape(b, hi, hi == '-')
}
}
} else {
for i := 0; i < len(re.Rune); i += 2 {
lo, hi := re.Rune[i], re.Rune[i+1]
escape(b, lo, lo == '-')
if lo != hi {
b.WriteRune('-')
escape(b, hi, hi == '-')
}
}
}
b.WriteRune(']')
case OpAnyCharNotNL:
b.WriteString(`(?-s:.)`)
case OpAnyChar:
b.WriteString(`(?s:.)`)
case OpBeginLine:
b.WriteString(`(?m:^)`)
case OpEndLine:
b.WriteString(`(?m:$)`)
case OpBeginText:
b.WriteString(`\A`)
case OpEndText:
if re.Flags&WasDollar != 0 {
b.WriteString(`(?-m:$)`)
} else {
b.WriteString(`\z`)
}
case OpWordBoundary:
b.WriteString(`\b`)
case OpNoWordBoundary:
b.WriteString(`\B`)
case OpCapture:
if re.Name != "" {
b.WriteString(`(?P<`)
b.WriteString(re.Name)
b.WriteRune('>')
} else {
b.WriteRune('(')
}
if re.Sub[0].Op != OpEmptyMatch {
writeRegexp(b, re.Sub[0])
}
b.WriteRune(')')
case OpStar, OpPlus, OpQuest, OpRepeat:
if sub := re.Sub[0]; sub.Op > OpCapture || sub.Op == OpLiteral && len(sub.Rune) > 1 {
b.WriteString(`(?:`)
writeRegexp(b, sub)
b.WriteString(`)`)
} else {
writeRegexp(b, sub)
}
switch re.Op {
case OpStar:
b.WriteRune('*')
case OpPlus:
b.WriteRune('+')
case OpQuest:
b.WriteRune('?')
case OpRepeat:
b.WriteRune('{')
b.WriteString(strconv.Itoa(re.Min))
if re.Max != re.Min {
b.WriteRune(',')
if re.Max >= 0 {
b.WriteString(strconv.Itoa(re.Max))
}
}
b.WriteRune('}')
}
if re.Flags&NonGreedy != 0 {
b.WriteRune('?')
}
case OpConcat:
for _, sub := range re.Sub {
if sub.Op == OpAlternate {
b.WriteString(`(?:`)
writeRegexp(b, sub)
b.WriteString(`)`)
} else {
writeRegexp(b, sub)
}
}
case OpAlternate:
for i, sub := range re.Sub {
if i > 0 {
b.WriteRune('|')
}
writeRegexp(b, sub)
}
}
}
func (re *Regexp) String() string {
var b strings.Builder
writeRegexp(&b, re)
return b.String()
}
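// Editor's note: an illustrative sketch (not part of the original source).
// The parser's swapVerticalBar/mergeCharClass machinery collapses an
// alternation of single runes into one class before printing.
func exampleRegexpString() string {
re, err := Parse(`a|b|c`, Perl)
if err != nil {
panic(err)
}
return re.String() // "[a-c]"
}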
const meta = `\.+*?()|[]{}^$`
func escape(b *strings.Builder, r rune, force bool) {
if unicode.IsPrint(r) {
if strings.ContainsRune(meta, r) || force {
b.WriteRune('\\')
}
b.WriteRune(r)
return
}
switch r {
case '\a':
b.WriteString(`\a`)
case '\f':
b.WriteString(`\f`)
case '\n':
b.WriteString(`\n`)
case '\r':
b.WriteString(`\r`)
case '\t':
b.WriteString(`\t`)
case '\v':
b.WriteString(`\v`)
default:
if r < 0x100 {
b.WriteString(`\x`)
s := strconv.FormatInt(int64(r), 16)
if len(s) == 1 {
b.WriteRune('0')
}
b.WriteString(s)
break
}
b.WriteString(`\x{`)
b.WriteString(strconv.FormatInt(int64(r), 16))
b.WriteString(`}`)
}
}
// MaxCap walks the regexp to find the maximum capture index.
func (re *Regexp) MaxCap() int {
m := 0
if re.Op == OpCapture {
m = re.Cap
}
for _, sub := range re.Sub {
if n := sub.MaxCap(); m < n {
m = n
}
}
return m
}
// CapNames walks the regexp to find the names of capturing groups.
func (re *Regexp) CapNames() []string {
names := make([]string, re.MaxCap()+1)
re.capNames(names)
return names
}
func (re *Regexp) capNames(names []string) {
if re.Op == OpCapture {
names[re.Cap] = re.Name
}
for _, sub := range re.Sub {
sub.capNames(names)
}
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package syntax
// Simplify returns a regexp equivalent to re but without counted repetitions
// and with various other simplifications, such as rewriting /(?:a+)+/ to /a+/.
// The resulting regexp will execute correctly but its string representation
// will not produce the same parse tree, because capturing parentheses
// may have been duplicated or removed. For example, the simplified form
// for /(x){1,2}/ is /(x)(x)?/ but both parentheses capture as $1.
// The returned regexp may share structure with or be the original.
func (re *Regexp) Simplify() *Regexp {
if re == nil {
return nil
}
switch re.Op {
case OpCapture, OpConcat, OpAlternate:
// Simplify children, building new Regexp if children change.
nre := re
for i, sub := range re.Sub {
nsub := sub.Simplify()
if nre == re && nsub != sub {
// Start a copy.
nre = new(Regexp)
*nre = *re
nre.Rune = nil
nre.Sub = append(nre.Sub0[:0], re.Sub[:i]...)
}
if nre != re {
nre.Sub = append(nre.Sub, nsub)
}
}
return nre
case OpStar, OpPlus, OpQuest:
sub := re.Sub[0].Simplify()
return simplify1(re.Op, re.Flags, sub, re)
case OpRepeat:
// Special case: x{0} matches the empty string
// and doesn't even need to consider x.
if re.Min == 0 && re.Max == 0 {
return &Regexp{Op: OpEmptyMatch}
}
// The fun begins.
sub := re.Sub[0].Simplify()
// x{n,} means at least n matches of x.
if re.Max == -1 {
// Special case: x{0,} is x*.
if re.Min == 0 {
return simplify1(OpStar, re.Flags, sub, nil)
}
// Special case: x{1,} is x+.
if re.Min == 1 {
return simplify1(OpPlus, re.Flags, sub, nil)
}
// General case: x{4,} is xxxx+.
nre := &Regexp{Op: OpConcat}
nre.Sub = nre.Sub0[:0]
for i := 0; i < re.Min-1; i++ {
nre.Sub = append(nre.Sub, sub)
}
nre.Sub = append(nre.Sub, simplify1(OpPlus, re.Flags, sub, nil))
return nre
}
// Special case x{0} handled above.
// Special case: x{1} is just x.
if re.Min == 1 && re.Max == 1 {
return sub
}
// General case: x{n,m} means n copies of x and m copies of x?
// The machine will do less work if we nest the final m copies,
// so that x{2,5} = xx(x(x(x)?)?)?
// Build leading prefix: xx.
var prefix *Regexp
if re.Min > 0 {
prefix = &Regexp{Op: OpConcat}
prefix.Sub = prefix.Sub0[:0]
for i := 0; i < re.Min; i++ {
prefix.Sub = append(prefix.Sub, sub)
}
}
// Build and attach suffix: (x(x(x)?)?)?
if re.Max > re.Min {
suffix := simplify1(OpQuest, re.Flags, sub, nil)
for i := re.Min + 1; i < re.Max; i++ {
nre2 := &Regexp{Op: OpConcat}
nre2.Sub = append(nre2.Sub0[:0], sub, suffix)
suffix = simplify1(OpQuest, re.Flags, nre2, nil)
}
if prefix == nil {
return suffix
}
prefix.Sub = append(prefix.Sub, suffix)
}
if prefix != nil {
return prefix
}
// Some degenerate case like min > max or min < max < 0.
// Handle as impossible match.
return &Regexp{Op: OpNoMatch}
}
return re
}
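// Editor's note: an illustrative sketch (not part of the original source).
// Counted repetition is rewritten as described above: a{2,5} becomes,
// structurally, aa(a(a(a)?)?)?, printed with non-capturing groups.
func exampleSimplify() string {
re, err := Parse(`a{2,5}`, Perl)
if err != nil {
panic(err)
}
return re.Simplify().String() // "aa(?:a(?:aa?)?)?"
}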
// simplify1 implements Simplify for the unary OpStar,
// OpPlus, and OpQuest operators. It returns the simple regexp
// equivalent to
//
// Regexp{Op: op, Flags: flags, Sub: {sub}}
//
// under the assumption that sub is already simple, and
// without first allocating that structure. If the regexp
// to be returned turns out to be equivalent to re, simplify1
// returns re instead.
//
// simplify1 is factored out of Simplify because the implementation
// for other operators generates these unary expressions.
// Letting them call simplify1 makes sure the expressions they
// generate are simple.
func simplify1(op Op, flags Flags, sub, re *Regexp) *Regexp {
// Special case: repeat the empty string as much as
// you want, but it's still the empty string.
if sub.Op == OpEmptyMatch {
return sub
}
// The operators are idempotent if the flags match.
if op == sub.Op && flags&NonGreedy == sub.Flags&NonGreedy {
return sub
}
if re != nil && re.Op == op && re.Flags&NonGreedy == flags&NonGreedy && sub == re.Sub[0] {
return re
}
re = &Regexp{Op: op, Flags: flags}
re.Sub = append(re.Sub0[:0], sub)
return re
}
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"internal/cpu"
"internal/goarch"
"unsafe"
)
const (
c0 = uintptr((8-goarch.PtrSize)/4*2860486313 + (goarch.PtrSize-4)/4*33054211828000289)
c1 = uintptr((8-goarch.PtrSize)/4*3267000013 + (goarch.PtrSize-4)/4*23344194077549503)
)
func memhash0(p unsafe.Pointer, h uintptr) uintptr {
return h
}
func memhash8(p unsafe.Pointer, h uintptr) uintptr {
return memhash(p, h, 1)
}
func memhash16(p unsafe.Pointer, h uintptr) uintptr {
return memhash(p, h, 2)
}
func memhash128(p unsafe.Pointer, h uintptr) uintptr {
return memhash(p, h, 16)
}
//go:nosplit
func memhash_varlen(p unsafe.Pointer, h uintptr) uintptr {
ptr := getclosureptr()
size := *(*uintptr)(unsafe.Pointer(ptr + unsafe.Sizeof(h)))
return memhash(p, h, size)
}
// runtime variable to check if the processor we're running on
// actually supports the instructions used by the AES-based
// hash implementation.
var useAeshash bool
// in asm_*.s
func memhash(p unsafe.Pointer, h, s uintptr) uintptr
func memhash32(p unsafe.Pointer, h uintptr) uintptr
func memhash64(p unsafe.Pointer, h uintptr) uintptr
func strhash(p unsafe.Pointer, h uintptr) uintptr
func strhashFallback(a unsafe.Pointer, h uintptr) uintptr {
x := (*stringStruct)(a)
return memhashFallback(x.str, h, uintptr(x.len))
}
// NOTE: Because NaN != NaN, a map can contain any
// number of (mostly useless) entries keyed with NaNs.
// To avoid long hash chains, we assign a random number
// as the hash value for a NaN.
func f32hash(p unsafe.Pointer, h uintptr) uintptr {
f := *(*float32)(p)
switch {
case f == 0:
return c1 * (c0 ^ h) // +0, -0
case f != f:
return c1 * (c0 ^ h ^ uintptr(fastrand())) // any kind of NaN
default:
return memhash(p, h, 4)
}
}
func f64hash(p unsafe.Pointer, h uintptr) uintptr {
f := *(*float64)(p)
switch {
case f == 0:
return c1 * (c0 ^ h) // +0, -0
case f != f:
return c1 * (c0 ^ h ^ uintptr(fastrand())) // any kind of NaN
default:
return memhash(p, h, 8)
}
}
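// Editor's note: a minimal sketch (not part of the original source).
// Because each NaN hash mixes in fastrand(), hashing the same NaN twice
// almost certainly yields different values, keeping NaN-keyed map chains short.
func exampleNaNHash() bool {
zero := 0.0
f := zero / zero // NaN
h1 := f64hash(unsafe.Pointer(&f), 0)
h2 := f64hash(unsafe.Pointer(&f), 0)
return h1 == h2 // almost always false
}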
func c64hash(p unsafe.Pointer, h uintptr) uintptr {
x := (*[2]float32)(p)
return f32hash(unsafe.Pointer(&x[1]), f32hash(unsafe.Pointer(&x[0]), h))
}
func c128hash(p unsafe.Pointer, h uintptr) uintptr {
x := (*[2]float64)(p)
return f64hash(unsafe.Pointer(&x[1]), f64hash(unsafe.Pointer(&x[0]), h))
}
func interhash(p unsafe.Pointer, h uintptr) uintptr {
a := (*iface)(p)
tab := a.tab
if tab == nil {
return h
}
t := tab._type
if t.equal == nil {
// Check hashability here. We could do this check inside
// typehash, but we want to report the topmost type in
// the error text (e.g. in a struct with a field of slice type
// we want to report the struct, not the slice).
panic(errorString("hash of unhashable type " + t.string()))
}
if isDirectIface(t) {
return c1 * typehash(t, unsafe.Pointer(&a.data), h^c0)
} else {
return c1 * typehash(t, a.data, h^c0)
}
}
func nilinterhash(p unsafe.Pointer, h uintptr) uintptr {
a := (*eface)(p)
t := a._type
if t == nil {
return h
}
if t.equal == nil {
// See comment in interhash above.
panic(errorString("hash of unhashable type " + t.string()))
}
if isDirectIface(t) {
return c1 * typehash(t, unsafe.Pointer(&a.data), h^c0)
} else {
return c1 * typehash(t, a.data, h^c0)
}
}
// typehash computes the hash of the object of type t at address p.
// h is the seed.
// This function is seldom used. Most maps use for hashing either
// fixed functions (e.g. f32hash) or compiler-generated functions
// (e.g. for a type like struct { x, y string }). This implementation
// is slower but more general and is used for hashing interface types
// (called from interhash or nilinterhash, above) or for hashing in
// maps generated by reflect.MapOf (reflect_typehash, below).
// Note: this function must match the compiler generated
// functions exactly. See issue 37716.
func typehash(t *_type, p unsafe.Pointer, h uintptr) uintptr {
if t.tflag&tflagRegularMemory != 0 {
// Handle ptr sizes specially, see issue 37086.
switch t.size {
case 4:
return memhash32(p, h)
case 8:
return memhash64(p, h)
default:
return memhash(p, h, t.size)
}
}
switch t.kind & kindMask {
case kindFloat32:
return f32hash(p, h)
case kindFloat64:
return f64hash(p, h)
case kindComplex64:
return c64hash(p, h)
case kindComplex128:
return c128hash(p, h)
case kindString:
return strhash(p, h)
case kindInterface:
i := (*interfacetype)(unsafe.Pointer(t))
if len(i.mhdr) == 0 {
return nilinterhash(p, h)
}
return interhash(p, h)
case kindArray:
a := (*arraytype)(unsafe.Pointer(t))
for i := uintptr(0); i < a.len; i++ {
h = typehash(a.elem, add(p, i*a.elem.size), h)
}
return h
case kindStruct:
s := (*structtype)(unsafe.Pointer(t))
for _, f := range s.fields {
if f.name.isBlank() {
continue
}
h = typehash(f.typ, add(p, f.offset), h)
}
return h
default:
// Should never happen, as typehash should only be called
// with comparable types.
panic(errorString("hash of unhashable type " + t.string()))
}
}
//go:linkname reflect_typehash reflect.typehash
func reflect_typehash(t *_type, p unsafe.Pointer, h uintptr) uintptr {
return typehash(t, p, h)
}
func memequal0(p, q unsafe.Pointer) bool {
return true
}
func memequal8(p, q unsafe.Pointer) bool {
return *(*int8)(p) == *(*int8)(q)
}
func memequal16(p, q unsafe.Pointer) bool {
return *(*int16)(p) == *(*int16)(q)
}
func memequal32(p, q unsafe.Pointer) bool {
return *(*int32)(p) == *(*int32)(q)
}
func memequal64(p, q unsafe.Pointer) bool {
return *(*int64)(p) == *(*int64)(q)
}
func memequal128(p, q unsafe.Pointer) bool {
return *(*[2]int64)(p) == *(*[2]int64)(q)
}
func f32equal(p, q unsafe.Pointer) bool {
return *(*float32)(p) == *(*float32)(q)
}
func f64equal(p, q unsafe.Pointer) bool {
return *(*float64)(p) == *(*float64)(q)
}
func c64equal(p, q unsafe.Pointer) bool {
return *(*complex64)(p) == *(*complex64)(q)
}
func c128equal(p, q unsafe.Pointer) bool {
return *(*complex128)(p) == *(*complex128)(q)
}
func strequal(p, q unsafe.Pointer) bool {
return *(*string)(p) == *(*string)(q)
}
func interequal(p, q unsafe.Pointer) bool {
x := *(*iface)(p)
y := *(*iface)(q)
return x.tab == y.tab && ifaceeq(x.tab, x.data, y.data)
}
func nilinterequal(p, q unsafe.Pointer) bool {
x := *(*eface)(p)
y := *(*eface)(q)
return x._type == y._type && efaceeq(x._type, x.data, y.data)
}
func efaceeq(t *_type, x, y unsafe.Pointer) bool {
if t == nil {
return true
}
eq := t.equal
if eq == nil {
panic(errorString("comparing uncomparable type " + t.string()))
}
if isDirectIface(t) {
// Direct interface types are ptr, chan, map, func, and single-element structs/arrays thereof.
// Maps and funcs are not comparable, so they can't reach here.
// Ptrs, chans, and single-element items can be compared directly using ==.
return x == y
}
return eq(x, y)
}
func ifaceeq(tab *itab, x, y unsafe.Pointer) bool {
if tab == nil {
return true
}
t := tab._type
eq := t.equal
if eq == nil {
panic(errorString("comparing uncomparable type " + t.string()))
}
if isDirectIface(t) {
// See comment in efaceeq.
return x == y
}
return eq(x, y)
}
// Testing adapters for hash quality tests (see hash_test.go)
func stringHash(s string, seed uintptr) uintptr {
return strhash(noescape(unsafe.Pointer(&s)), seed)
}
func bytesHash(b []byte, seed uintptr) uintptr {
s := (*slice)(unsafe.Pointer(&b))
return memhash(s.array, seed, uintptr(s.len))
}
func int32Hash(i uint32, seed uintptr) uintptr {
return memhash32(noescape(unsafe.Pointer(&i)), seed)
}
func int64Hash(i uint64, seed uintptr) uintptr {
return memhash64(noescape(unsafe.Pointer(&i)), seed)
}
func efaceHash(i any, seed uintptr) uintptr {
return nilinterhash(noescape(unsafe.Pointer(&i)), seed)
}
func ifaceHash(i interface {
F()
}, seed uintptr) uintptr {
return interhash(noescape(unsafe.Pointer(&i)), seed)
}
const hashRandomBytes = goarch.PtrSize / 4 * 64
// used in asm_{386,amd64,arm64}.s to seed the hash function
var aeskeysched [hashRandomBytes]byte
// used in hash{32,64}.go to seed the hash function
var hashkey [4]uintptr
func alginit() {
// Install AES hash algorithms if the instructions needed are present.
if (GOARCH == "386" || GOARCH == "amd64") &&
cpu.X86.HasAES && // AESENC
cpu.X86.HasSSSE3 && // PSHUFB
cpu.X86.HasSSE41 { // PINSR{D,Q}
initAlgAES()
return
}
if GOARCH == "arm64" && cpu.ARM64.HasAES {
initAlgAES()
return
}
getRandomData((*[len(hashkey) * goarch.PtrSize]byte)(unsafe.Pointer(&hashkey))[:])
hashkey[0] |= 1 // make sure these numbers are odd
hashkey[1] |= 1
hashkey[2] |= 1
hashkey[3] |= 1
}
func initAlgAES() {
useAeshash = true
// Initialize with random data so hash collisions will be hard to engineer.
getRandomData(aeskeysched[:])
}
// Note: These routines perform the read with a native endianness.
func readUnaligned32(p unsafe.Pointer) uint32 {
q := (*[4]byte)(p)
if goarch.BigEndian {
return uint32(q[3]) | uint32(q[2])<<8 | uint32(q[1])<<16 | uint32(q[0])<<24
}
return uint32(q[0]) | uint32(q[1])<<8 | uint32(q[2])<<16 | uint32(q[3])<<24
}
func readUnaligned64(p unsafe.Pointer) uint64 {
q := (*[8]byte)(p)
if goarch.BigEndian {
return uint64(q[7]) | uint64(q[6])<<8 | uint64(q[5])<<16 | uint64(q[4])<<24 |
uint64(q[3])<<32 | uint64(q[2])<<40 | uint64(q[1])<<48 | uint64(q[0])<<56
}
return uint64(q[0]) | uint64(q[1])<<8 | uint64(q[2])<<16 | uint64(q[3])<<24 | uint64(q[4])<<32 | uint64(q[5])<<40 | uint64(q[6])<<48 | uint64(q[7])<<56
}
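// Editor's note: a minimal sketch (not part of the original source).
// "Native endianness" means the result matches what an aligned native load
// of the same bytes would produce on the current architecture.
func exampleReadUnaligned32() uint32 {
b := [4]byte{0x01, 0x02, 0x03, 0x04}
// 0x04030201 on little-endian hosts, 0x01020304 on big-endian hosts.
return readUnaligned32(unsafe.Pointer(&b[0]))
}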
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Implementation of (safe) user arenas.
//
// This file contains the implementation of user arenas wherein Go values can
// be manually allocated and freed in bulk. The act of manually freeing memory,
// potentially before a GC cycle, means that a garbage collection cycle can be
// delayed, improving efficiency by reducing GC cycle frequency. There are other
// potential efficiency benefits, such as improved locality and access to a more
// efficient allocation strategy.
//
// What makes the arenas here safe is that once they are freed, accessing the
// arena's memory will cause an explicit program fault, and the arena's address
// space will not be reused until no more pointers into it are found. There's one
// exception to this: if the arena's active chunk isn't exhausted when the arena
// is freed, the chunk is placed back into a pool for reuse instead of being set
// to fault, so a crash on use-after-free is not always guaranteed.
//
// While this may seem unsafe, it still prevents memory corruption, and is in fact
// necessary in order to make new(T) a valid implementation of arenas. Such a property
// is desirable to allow for a trivial implementation. (It also avoids complexities
// that arise from synchronization with the GC when trying to set the arena chunks to
// fault while the GC is active.)
//
// The implementation works in layers. At the bottom, arenas are managed in chunks.
// Each chunk must be a multiple of the heap arena size, or the chunk size must
// evenly divide the heap arena size. The address space for each chunk, and each
// corresponding heapArena for that address space, are eternally reserved for use as
// arena chunks. That is, they can never be used for the general heap. Each chunk
// is also represented by a single mspan, and is modeled as a single large heap
// allocation. It must be, because each chunk contains ordinary Go values that may
// point into the heap, so it must be scanned just like any other object. Any
// pointer into a chunk will therefore always cause the whole chunk to be scanned
// while its corresponding arena is still live.
//
// Chunks may be allocated either from new memory mapped by the OS on our behalf,
// or by reusing old freed chunks. When chunks are freed, their underlying memory
// is returned to the OS, set to fault on access, and may not be reused until the
// program doesn't point into the chunk anymore (the code refers to this state as
// "quarantined"), a property checked by the GC.
//
// The sweeper handles moving chunks out of this quarantine state to be ready for
// reuse. When the chunk is placed into the quarantine state, its corresponding
// span is marked as noscan so that the GC doesn't try to scan memory that would
// cause a fault.
//
// At the next layer are the user arenas themselves. They consist of a single
// active chunk which new Go values are bump-allocated into and a list of chunks
// that were exhausted when allocating into the arena. Once the arena is freed,
// it frees all full chunks it references, and places the active one onto a reuse
// list for a future arena to use. Each arena keeps its list of referenced chunks
// explicitly live until it is freed. Each user arena also maps to an object which
// has a finalizer attached that ensures the arena's chunks are all freed even if
// the arena itself is never explicitly freed.
//
// Pointer-ful memory is bump-allocated from low addresses to high addresses in each
// chunk, while pointer-free memory is bump-allocated from high address to low
// addresses. The reason for this is to take advantage of a GC optimization wherein
// the GC will stop scanning an object when there are no more pointers in it, which
// also allows us to elide clearing the heap bitmap for pointer-free Go values
// allocated into arenas.
//
// Note that arenas are not safe to use concurrently.
//
// In summary, there are 2 resources: arenas, and arena chunks. They exist in the
// following lifecycle:
//
// (1) A new arena is created via newArena.
// (2) Chunks are allocated to hold memory allocated into the arena with new or slice.
// (a) Chunks are first allocated from the reuse list of partially-used chunks.
// (b) If there are no such chunks, then chunks on the ready list are taken.
// (c) Failing all the above, memory for a new chunk is mapped.
// (3) The arena is freed, or all references to it are dropped, triggering its finalizer.
// (a) If the GC is not active, exhausted chunks are set to fault and placed on a
// quarantine list.
// (b) If the GC is active, exhausted chunks are placed on a fault list and will
// go through step (a) at a later point in time.
// (c) Any remaining partially-used chunk is placed on a reuse list.
// (4) Once no more pointers are found into quarantined arena chunks, the sweeper
// takes these chunks out of quarantine and places them on the ready list.
package runtime
import (
"internal/goarch"
"runtime/internal/atomic"
"runtime/internal/math"
"unsafe"
)
// Functions starting with arena_ are meant to be exported to downstream users
// of arenas. They should wrap these functions in a higher-level API.
//
// The underlying arena and its resources are managed through an opaque unsafe.Pointer.
// arena_newArena is a wrapper around newUserArena.
//
//go:linkname arena_newArena arena.runtime_arena_newArena
func arena_newArena() unsafe.Pointer {
return unsafe.Pointer(newUserArena())
}
// arena_arena_New is a wrapper around (*userArena).new, except that typ
// is an any (must be a *_type, still) and typ must be a type descriptor
// for a pointer to the type to actually be allocated, i.e. pass a *T
// to allocate a T. This is necessary because this function returns a *T.
//
//go:linkname arena_arena_New arena.runtime_arena_arena_New
func arena_arena_New(arena unsafe.Pointer, typ any) any {
t := (*_type)(efaceOf(&typ).data)
if t.kind&kindMask != kindPtr {
throw("arena_New: non-pointer type")
}
te := (*ptrtype)(unsafe.Pointer(t)).elem
x := ((*userArena)(arena)).new(te)
var result any
e := efaceOf(&result)
e._type = t
e.data = x
return result
}
// arena_arena_Slice is a wrapper around (*userArena).slice.
//
//go:linkname arena_arena_Slice arena.runtime_arena_arena_Slice
func arena_arena_Slice(arena unsafe.Pointer, slice any, cap int) {
((*userArena)(arena)).slice(slice, cap)
}
// arena_arena_Free is a wrapper around (*userArena).free.
//
//go:linkname arena_arena_Free arena.runtime_arena_arena_Free
func arena_arena_Free(arena unsafe.Pointer) {
((*userArena)(arena)).free()
}
// arena_heapify takes a value that lives in an arena and makes a copy
// of it on the heap. Values that don't live in an arena are returned unmodified.
//
//go:linkname arena_heapify arena.runtime_arena_heapify
func arena_heapify(s any) any {
var v unsafe.Pointer
e := efaceOf(&s)
t := e._type
switch t.kind & kindMask {
case kindString:
v = stringStructOf((*string)(e.data)).str
case kindSlice:
v = (*slice)(e.data).array
case kindPtr:
v = e.data
default:
panic("arena: Clone only supports pointers, slices, and strings")
}
span := spanOf(uintptr(v))
if span == nil || !span.isUserArenaChunk {
// Not stored in a user arena chunk.
return s
}
// Heap-allocate storage for a copy.
var x any
switch t.kind & kindMask {
case kindString:
s1 := s.(string)
s2, b := rawstring(len(s1))
copy(b, s1)
x = s2
case kindSlice:
len := (*slice)(e.data).len
et := (*slicetype)(unsafe.Pointer(t)).elem
sl := new(slice)
*sl = slice{makeslicecopy(et, len, len, (*slice)(e.data).array), len, len}
xe := efaceOf(&x)
xe._type = t
xe.data = unsafe.Pointer(sl)
case kindPtr:
et := (*ptrtype)(unsafe.Pointer(t)).elem
e2 := newobject(et)
typedmemmove(et, e2, e.data)
xe := efaceOf(&x)
xe._type = t
xe.data = e2
}
return x
}
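// Editor's note: an illustrative sketch (not part of the original source)
// of how a downstream wrapper is expected to layer over these linknamed
// hooks. The arena package names below are assumptions, not defined here:
//
//	a := arena.NewArena()                // arena_newArena
//	p := arena.New[int](a)               // arena_arena_New with typ = *int
//	s := arena.MakeSlice[byte](a, 0, 64) // arena_arena_Slice
//	v := arena.Clone(*p)                 // arena_heapify copies values off-arena
//	a.Free()                             // arena_arena_Free; a is unusable afterward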
const (
// userArenaChunkBytes is the size of a user arena chunk.
userArenaChunkBytesMax = 8 << 20
userArenaChunkBytes = uintptr(int64(userArenaChunkBytesMax-heapArenaBytes)&(int64(userArenaChunkBytesMax-heapArenaBytes)>>63) + heapArenaBytes) // min(userArenaChunkBytesMax, heapArenaBytes)
// userArenaChunkPages is the number of pages a user arena chunk uses.
userArenaChunkPages = userArenaChunkBytes / pageSize
// userArenaChunkMaxAllocBytes is the maximum size of an object that can
// be allocated from an arena. This number is chosen to cap worst-case
// fragmentation of user arenas to 25%. Larger allocations are redirected
// to the heap.
userArenaChunkMaxAllocBytes = userArenaChunkBytes / 4
)
func init() {
if userArenaChunkPages*pageSize != userArenaChunkBytes {
throw("user arena chunk size is not a multiple of the page size")
}
if userArenaChunkBytes%physPageSize != 0 {
throw("user arena chunk size is not a multiple of the physical page size")
}
if userArenaChunkBytes < heapArenaBytes {
if heapArenaBytes%userArenaChunkBytes != 0 {
throw("user arena chunk size is smaller than a heap arena, but doesn't divide it")
}
} else {
if userArenaChunkBytes%heapArenaBytes != 0 {
throw("user arena chunks size is larger than a heap arena, but not a multiple")
}
}
lockInit(&userArenaState.lock, lockRankUserArenaState)
}
type userArena struct {
// fullList is a list of full chunks that don't have enough free memory left,
// and that we'll free once this user arena is freed.
//
// Can't use mSpanList here because it's not-in-heap.
fullList *mspan
// active is the user arena chunk we're currently allocating into.
active *mspan
// refs is a set of references to the arena chunks so that they're kept alive.
//
// The last reference in the list always refers to active, while the rest of
// them correspond to fullList. Specifically, the head of fullList is the
// second-to-last one, fullList.next is the third-to-last, and so on.
//
// In other words, every time a new chunk becomes active, it's appended to this
// list.
refs []unsafe.Pointer
// defunct is true if free has been called on this arena.
//
// This is just a best-effort way to discover a concurrent allocation
// and free. Also used to detect a double-free.
defunct atomic.Bool
}
// newUserArena creates a new userArena ready to be used.
func newUserArena() *userArena {
a := new(userArena)
SetFinalizer(a, func(a *userArena) {
// If the arena handle is dropped without being freed, call free on the
// arena so its chunks are returned; arena chunks are never reclaimed by
// the garbage collector on their own.
a.free()
})
a.refill()
return a
}
// new allocates a new object of the provided type into the arena, and returns
// its pointer.
//
// This operation is not safe to call concurrently with other operations on the
// same arena.
func (a *userArena) new(typ *_type) unsafe.Pointer {
return a.alloc(typ, -1)
}
// slice allocates a new slice backing store. sl must be a pointer to a slice
// (i.e. *[]T), because slice will update it directly.
//
// cap determines the capacity of the slice backing store and must be non-negative.
//
// This operation is not safe to call concurrently with other operations on the
// same arena.
func (a *userArena) slice(sl any, cap int) {
if cap < 0 {
panic("userArena.slice: negative cap")
}
i := efaceOf(&sl)
typ := i._type
if typ.kind&kindMask != kindPtr {
panic("slice result of non-ptr type")
}
typ = (*ptrtype)(unsafe.Pointer(typ)).elem
if typ.kind&kindMask != kindSlice {
panic("slice of non-ptr-to-slice type")
}
typ = (*slicetype)(unsafe.Pointer(typ)).elem
// typ is now the element type of the slice we want to allocate.
*((*slice)(i.data)) = slice{a.alloc(typ, cap), cap, cap}
}
// free returns the userArena's chunks back to mheap and marks it as defunct.
//
// Must be called at most once for any given arena.
//
// This operation is not safe to call concurrently with other operations on the
// same arena.
func (a *userArena) free() {
// Check for a double-free.
if a.defunct.Load() {
panic("arena double free")
}
// Mark ourselves as defunct.
a.defunct.Store(true)
SetFinalizer(a, nil)
// Free all the full arenas.
//
// The refs on this list are in reverse order from the second-to-last.
s := a.fullList
i := len(a.refs) - 2
for s != nil {
a.fullList = s.next
s.next = nil
freeUserArenaChunk(s, a.refs[i])
s = a.fullList
i--
}
if a.fullList != nil || i >= 0 {
// There's still something left on the full list, or we
// failed to actually iterate over the entire refs list.
throw("full list doesn't match refs list in length")
}
// Put the active chunk onto the reuse list.
//
// Note that active's reference is always the last reference in refs.
s = a.active
if s != nil {
if raceenabled || msanenabled || asanenabled {
// Don't reuse arenas with sanitizers enabled. We want to catch
// any use-after-free errors aggressively.
freeUserArenaChunk(s, a.refs[len(a.refs)-1])
} else {
lock(&userArenaState.lock)
userArenaState.reuse = append(userArenaState.reuse, liveUserArenaChunk{s, a.refs[len(a.refs)-1]})
unlock(&userArenaState.lock)
}
}
// nil out a.active so that a race with freeing will more likely cause a crash.
a.active = nil
a.refs = nil
}
// alloc reserves space in the current chunk or calls refill and reserves space
// in a new chunk. If cap is negative, the type will be taken literally, otherwise
// it will be considered as an element type for a slice backing store with capacity
// cap.
func (a *userArena) alloc(typ *_type, cap int) unsafe.Pointer {
s := a.active
var x unsafe.Pointer
for {
x = s.userArenaNextFree(typ, cap)
if x != nil {
break
}
s = a.refill()
}
return x
}
// refill inserts the current arena chunk onto the full list and obtains a new
// one, either from the partial list or allocating a new one, both from mheap.
func (a *userArena) refill() *mspan {
// If there's an active chunk, assume it's full.
s := a.active
if s != nil {
if s.userArenaChunkFree.size() > userArenaChunkMaxAllocBytes {
// It's difficult to tell when we're actually out of memory
// in a chunk because the allocation that failed may still leave
// some free space available. However, that amount of free space
// should never exceed the maximum allocation size.
throw("wasted too much memory in an arena chunk")
}
s.next = a.fullList
a.fullList = s
a.active = nil
s = nil
}
var x unsafe.Pointer
// Check the partially-used list.
lock(&userArenaState.lock)
if len(userArenaState.reuse) > 0 {
// Pick off the last arena chunk from the list.
n := len(userArenaState.reuse) - 1
x = userArenaState.reuse[n].x
s = userArenaState.reuse[n].mspan
userArenaState.reuse[n].x = nil
userArenaState.reuse[n].mspan = nil
userArenaState.reuse = userArenaState.reuse[:n]
}
unlock(&userArenaState.lock)
if s == nil {
// Allocate a new one.
x, s = newUserArenaChunk()
if s == nil {
throw("out of memory")
}
}
a.refs = append(a.refs, x)
a.active = s
return s
}
type liveUserArenaChunk struct {
*mspan // Must represent a user arena chunk.
// Reference to mspan.base() to keep the chunk alive.
x unsafe.Pointer
}
var userArenaState struct {
lock mutex
// reuse contains a list of partially-used and already-live
// user arena chunks that can be quickly reused for another
// arena.
//
// Protected by lock.
reuse []liveUserArenaChunk
// fault contains full user arena chunks that need to be faulted.
//
// Protected by lock.
fault []liveUserArenaChunk
}
// userArenaNextFree reserves space in the user arena for an item of the specified
// type. If cap is not -1, this is for an array of cap elements of type typ.
func (s *mspan) userArenaNextFree(typ *_type, cap int) unsafe.Pointer {
size := typ.size
if cap > 0 {
if size > ^uintptr(0)/uintptr(cap) {
// Overflow.
throw("out of memory")
}
size *= uintptr(cap)
}
if size == 0 || cap == 0 {
return unsafe.Pointer(&zerobase)
}
if size > userArenaChunkMaxAllocBytes {
// Redirect allocations that don't fit into a chunk well directly
// from the heap.
if cap >= 0 {
return newarray(typ, cap)
}
return newobject(typ)
}
// Prevent preemption as we set up the space for a new object.
//
// Act like we're allocating.
mp := acquirem()
if mp.mallocing != 0 {
throw("malloc deadlock")
}
if mp.gsignal == getg() {
throw("malloc during signal")
}
mp.mallocing = 1
var ptr unsafe.Pointer
if typ.ptrdata == 0 {
// Allocate pointer-less objects from the tail end of the chunk.
v, ok := s.userArenaChunkFree.takeFromBack(size, typ.align)
if ok {
ptr = unsafe.Pointer(v)
}
} else {
v, ok := s.userArenaChunkFree.takeFromFront(size, typ.align)
if ok {
ptr = unsafe.Pointer(v)
}
}
if ptr == nil {
// Failed to allocate.
mp.mallocing = 0
releasem(mp)
return nil
}
if s.needzero != 0 {
throw("arena chunk needs zeroing, but should already be zeroed")
}
// Set up heap bitmap and do extra accounting.
if typ.ptrdata != 0 {
if cap >= 0 {
userArenaHeapBitsSetSliceType(typ, cap, ptr, s.base())
} else {
userArenaHeapBitsSetType(typ, ptr, s.base())
}
c := getMCache(mp)
if c == nil {
throw("mallocgc called without a P or outside bootstrapping")
}
if cap > 0 {
c.scanAlloc += size - (typ.size - typ.ptrdata)
} else {
c.scanAlloc += typ.ptrdata
}
}
// Ensure that the stores above that initialize x to
// type-safe memory and set the heap bits occur before
// the caller can make ptr observable to the garbage
// collector. Otherwise, on weakly ordered machines,
// the garbage collector could follow a pointer to x,
// but see uninitialized memory or stale heap bits.
publicationBarrier()
mp.mallocing = 0
releasem(mp)
return ptr
}
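// Design note (a sketch of the rationale, not a guarantee): taking
// pointer-bearing objects from the front of the chunk and pointer-free
// objects from the back keeps the region that the heap bitmap must describe
// contiguous at the start of the chunk. Schematically:
//
//	[ base | ptr-ful objects -->   ...free...   <-- ptr-less objects | limit ]
//	        takeFromFront                           takeFromBack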
// userArenaHeapBitsSetType is the equivalent of heapBitsSetType but for
// non-slice-backing-store Go values allocated in a user arena chunk. It
// sets up the heap bitmap for the value with type typ allocated at address ptr.
// base is the base address of the arena chunk.
func userArenaHeapBitsSetType(typ *_type, ptr unsafe.Pointer, base uintptr) {
h := writeHeapBitsForAddr(uintptr(ptr))
// Our last allocation might have ended right at a noMorePtrs mark,
// which we would not have erased. We need to erase that mark here,
// because we're going to start adding new heap bitmap bits.
// We only need to clear one mark, because below we make sure to
// pad out the bits with zeroes and only write one noMorePtrs bit
// for each new object.
// (This is only necessary at noMorePtrs boundaries, as noMorePtrs
// marks within an object allocated with newAt will be erased by
// the normal writeHeapBitsForAddr mechanism.)
//
// Note that we skip this if this is the first allocation in the
// arena because there's definitely no previous noMorePtrs mark
// (in fact, we *must* do this, because we're going to try to back
// up a pointer to fix this up).
if uintptr(ptr)%(8*goarch.PtrSize*goarch.PtrSize) == 0 && uintptr(ptr) != base {
// Back up one pointer and rewrite that pointer. That will
// cause the writeHeapBits implementation to clear the
// noMorePtrs bit we need to clear.
r := heapBitsForAddr(uintptr(ptr)-goarch.PtrSize, goarch.PtrSize)
_, p := r.next()
b := uintptr(0)
if p == uintptr(ptr)-goarch.PtrSize {
b = 1
}
h = writeHeapBitsForAddr(uintptr(ptr) - goarch.PtrSize)
h = h.write(b, 1)
}
p := typ.gcdata // start of 1-bit pointer mask (or GC program)
var gcProgBits uintptr
if typ.kind&kindGCProg != 0 {
// Expand gc program, using the object itself for storage.
gcProgBits = runGCProg(addb(p, 4), (*byte)(ptr))
p = (*byte)(ptr)
}
nb := typ.ptrdata / goarch.PtrSize
for i := uintptr(0); i < nb; i += ptrBits {
k := nb - i
if k > ptrBits {
k = ptrBits
}
h = h.write(readUintptr(addb(p, i/8)), k)
}
// Note: we call pad here to ensure we emit explicit 0 bits
// for the pointerless tail of the object. This ensures that
// there's only a single noMorePtrs mark for the next object
// to clear. We don't need to do this to clear stale noMorePtrs
// markers from previous uses because arena chunk pointer bitmaps
// are always fully cleared when reused.
h = h.pad(typ.size - typ.ptrdata)
h.flush(uintptr(ptr), typ.size)
if typ.kind&kindGCProg != 0 {
// Zero out temporary ptrmask buffer inside object.
memclrNoHeapPointers(ptr, (gcProgBits+7)/8)
}
// Double-check that the bitmap was written out correctly.
//
// Derived from heapBitsSetType.
const doubleCheck = false
if doubleCheck {
size := typ.size
x := uintptr(ptr)
h := heapBitsForAddr(x, size)
for i := uintptr(0); i < size; i += goarch.PtrSize {
// Compute the pointer bit we want at offset i.
want := false
off := i % typ.size
if off < typ.ptrdata {
j := off / goarch.PtrSize
want = *addb(typ.gcdata, j/8)>>(j%8)&1 != 0
}
if want {
var addr uintptr
h, addr = h.next()
if addr != x+i {
throw("userArenaHeapBitsSetType: pointer entry not correct")
}
}
}
if _, addr := h.next(); addr != 0 {
throw("userArenaHeapBitsSetType: extra pointer")
}
}
}
// userArenaHeapBitsSetSliceType is the equivalent of heapBitsSetType but for
// Go slice backing store values allocated in a user arena chunk. It sets up the
// heap bitmap for n consecutive values with type typ allocated at address ptr.
func userArenaHeapBitsSetSliceType(typ *_type, n int, ptr unsafe.Pointer, base uintptr) {
mem, overflow := math.MulUintptr(typ.size, uintptr(n))
if overflow || n < 0 || mem > maxAlloc {
panic(plainError("runtime: allocation size out of range"))
}
for i := 0; i < n; i++ {
userArenaHeapBitsSetType(typ, add(ptr, uintptr(i)*typ.size), base)
}
}
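// For reference, a sketch of the overflow test that MulUintptr performs for
// typ.size*n above (an assumption about runtime/internal/math semantics, not
// a definition from this file):
//
//	overflow := typ.size != 0 && uintptr(n) > ^uintptr(0)/typ.size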
// newUserArenaChunk allocates a user arena chunk, which maps to a single
// heap arena and single span. Returns a pointer to the base of the chunk
// (this is really important: we need to keep the chunk alive) and the span.
func newUserArenaChunk() (unsafe.Pointer, *mspan) {
if gcphase == _GCmarktermination {
throw("newUserArenaChunk called with gcphase == _GCmarktermination")
}
// Deduct assist credit. Because user arena chunks are modeled as one
// giant heap object which counts toward heapLive, we're obligated to
// assist the GC proportionally (and it's worth noting that the arena
// does represent additional work for the GC, but we also have no idea
// what that looks like until we actually allocate things into the
// arena).
deductAssistCredit(userArenaChunkBytes)
// Set mp.mallocing to keep from being preempted by GC.
mp := acquirem()
if mp.mallocing != 0 {
throw("malloc deadlock")
}
if mp.gsignal == getg() {
throw("malloc during signal")
}
mp.mallocing = 1
// Allocate a new user arena.
var span *mspan
systemstack(func() {
span = mheap_.allocUserArenaChunk()
})
if span == nil {
throw("out of memory")
}
x := unsafe.Pointer(span.base())
// Allocate black during GC.
// All slots hold nil so no scanning is needed.
// This may be racing with GC so do it atomically if there can be
// a race marking the bit.
if gcphase != _GCoff {
gcmarknewobject(span, span.base(), span.elemsize)
}
if raceenabled {
// TODO(mknyszek): Track individual objects.
racemalloc(unsafe.Pointer(span.base()), span.elemsize)
}
if msanenabled {
// TODO(mknyszek): Track individual objects.
msanmalloc(unsafe.Pointer(span.base()), span.elemsize)
}
if asanenabled {
// TODO(mknyszek): Track individual objects.
rzSize := computeRZlog(span.elemsize)
span.elemsize -= rzSize
span.limit -= rzSize
span.userArenaChunkFree = makeAddrRange(span.base(), span.limit)
asanpoison(unsafe.Pointer(span.limit), span.npages*pageSize-span.elemsize)
asanunpoison(unsafe.Pointer(span.base()), span.elemsize)
}
if rate := MemProfileRate; rate > 0 {
c := getMCache(mp)
if c == nil {
throw("newUserArenaChunk called without a P or outside bootstrapping")
}
// Note cache c only valid while m acquired; see #47302
if rate != 1 && userArenaChunkBytes < c.nextSample {
c.nextSample -= userArenaChunkBytes
} else {
profilealloc(mp, unsafe.Pointer(span.base()), userArenaChunkBytes)
}
}
mp.mallocing = 0
releasem(mp)
// Again, because this chunk counts toward heapLive, potentially trigger a GC.
if t := (gcTrigger{kind: gcTriggerHeap}); t.test() {
gcStart(t)
}
if debug.malloc {
if debug.allocfreetrace != 0 {
tracealloc(unsafe.Pointer(span.base()), userArenaChunkBytes, nil)
}
if inittrace.active && inittrace.id == getg().goid {
// Init functions are executed sequentially in a single goroutine.
inittrace.bytes += uint64(userArenaChunkBytes)
}
}
// Double-check it's aligned to the physical page size. Based on the current
// implementation this is trivially true, but it need not be in the future.
// However, if it's not aligned to the physical page size then we can't properly
// set it to fault later.
if uintptr(x)%physPageSize != 0 {
throw("user arena chunk is not aligned to the physical page size")
}
return x, span
}
// isUnusedUserArenaChunk indicates that the arena chunk has been set to fault
// and doesn't contain any scannable memory anymore. However, it might still be
// mSpanInUse as it sits on the quarantine list, since it needs to be swept.
//
// This is not safe to execute unless the caller has ownership of the mspan or
// the world is stopped (preemption is prevented while the relevant state changes).
//
// This is really only meant to be used by accounting tests in the runtime to
// distinguish when a span shouldn't be counted (since mSpanInUse might not be
// enough).
func (s *mspan) isUnusedUserArenaChunk() bool {
return s.isUserArenaChunk && s.spanclass == makeSpanClass(0, true)
}
// setUserArenaChunkToFault sets the address space for the user arena chunk to fault
// and releases any underlying memory resources.
//
// Must be in a non-preemptible state to ensure the consistency of statistics
// exported to MemStats.
func (s *mspan) setUserArenaChunkToFault() {
if !s.isUserArenaChunk {
throw("invalid span in heapArena for user arena")
}
if s.npages*pageSize != userArenaChunkBytes {
throw("span on userArena.faultList has invalid size")
}
// Update the span class to be noscan. What we want to happen is that
// any pointer into the span keeps it from getting recycled, so we want
// the mark bit to get set, but we're about to set the address space to fault,
// so we have to prevent the GC from scanning this memory.
//
// It's OK to set it here because (1) a GC isn't in progress, so the scanning code
// won't make a bad decision, (2) we're currently non-preemptible and in the runtime,
// so a GC is blocked from starting. We might race with sweeping, which could
// put it on the "wrong" sweep list, but really don't care because the chunk is
// treated as a large object span and there's no meaningful difference between scan
// and noscan large objects in the sweeper. The STW at the start of the GC acts as a
// barrier for this update.
s.spanclass = makeSpanClass(0, true)
// Actually set the arena chunk to fault, so we'll get dangling pointer errors.
// sysFault currently uses a method on each OS that forces it to evacuate all
// memory backing the chunk.
sysFault(unsafe.Pointer(s.base()), s.npages*pageSize)
// Everything on the list is counted as in-use, however sysFault transitions to
// Reserved, not Prepared, so we skip updating heapFree or heapReleased and just
// remove the memory from the total altogether; it's just address space now.
gcController.heapInUse.add(-int64(s.npages * pageSize))
// Count this as a free of an object right now as opposed to when
// the span gets off the quarantine list. The main reason is so that the
// amount of bytes allocated doesn't exceed how much is counted as
// "mapped ready," which could cause a deadlock in the pacer.
gcController.totalFree.Add(int64(s.npages * pageSize))
// Update consistent stats to match.
//
// We're non-preemptible, so it's safe to update consistent stats (our P
// won't change out from under us).
stats := memstats.heapStats.acquire()
atomic.Xaddint64(&stats.committed, -int64(s.npages*pageSize))
atomic.Xaddint64(&stats.inHeap, -int64(s.npages*pageSize))
atomic.Xadd64(&stats.largeFreeCount, 1)
atomic.Xadd64(&stats.largeFree, int64(s.npages*pageSize))
memstats.heapStats.release()
// This counts as a free, so update heapLive.
gcController.update(-int64(s.npages*pageSize), 0)
// Mark it as free for the race detector.
if raceenabled {
racefree(unsafe.Pointer(s.base()), s.elemsize)
}
systemstack(func() {
// Add the user arena to the quarantine list.
lock(&mheap_.lock)
mheap_.userArena.quarantineList.insert(s)
unlock(&mheap_.lock)
})
}
// inUserArenaChunk returns true if p points to a user arena chunk.
func inUserArenaChunk(p uintptr) bool {
s := spanOf(p)
if s == nil {
return false
}
return s.isUserArenaChunk
}
// freeUserArenaChunk releases the user arena represented by s back to the runtime.
//
// x must be a live pointer within s.
//
// The runtime will set the user arena to fault once it's safe (the GC is no longer running)
// and then once the user arena is no longer referenced by the application, will allow it to
// be reused.
func freeUserArenaChunk(s *mspan, x unsafe.Pointer) {
if !s.isUserArenaChunk {
throw("span is not for a user arena")
}
if s.npages*pageSize != userArenaChunkBytes {
throw("invalid user arena span size")
}
// Mark the region as free to various sanitizers immediately instead
// of handling them at sweep time.
if raceenabled {
racefree(unsafe.Pointer(s.base()), s.elemsize)
}
if msanenabled {
msanfree(unsafe.Pointer(s.base()), s.elemsize)
}
if asanenabled {
asanpoison(unsafe.Pointer(s.base()), s.elemsize)
}
// Make ourselves non-preemptible as we manipulate state and statistics.
//
// Also required by setUserArenaChunksToFault.
mp := acquirem()
// We can only set user arenas to fault if we're in the _GCoff phase.
if gcphase == _GCoff {
lock(&userArenaState.lock)
faultList := userArenaState.fault
userArenaState.fault = nil
unlock(&userArenaState.lock)
s.setUserArenaChunkToFault()
for _, lc := range faultList {
lc.mspan.setUserArenaChunkToFault()
}
// Until the chunks are set to fault, keep them alive via the fault list.
KeepAlive(x)
KeepAlive(faultList)
} else {
// Put the user arena on the fault list.
lock(&userArenaState.lock)
userArenaState.fault = append(userArenaState.fault, liveUserArenaChunk{s, x})
unlock(&userArenaState.lock)
}
releasem(mp)
}
// allocUserArenaChunk attempts to reuse a free user arena chunk represented
// as a span.
//
// Must be in a non-preemptible state to ensure the consistency of statistics
// exported to MemStats.
//
// Acquires the heap lock. Must run on the system stack for that reason.
//
//go:systemstack
func (h *mheap) allocUserArenaChunk() *mspan {
var s *mspan
var base uintptr
// First check the free list.
lock(&h.lock)
if !h.userArena.readyList.isEmpty() {
s = h.userArena.readyList.first
h.userArena.readyList.remove(s)
base = s.base()
} else {
// Free list was empty, so allocate a new arena.
hintList := &h.userArena.arenaHints
if raceenabled {
// In race mode just use the regular heap hints. We might fragment
// the address space, but the race detector requires that the heap
// is mapped contiguously.
hintList = &h.arenaHints
}
v, size := h.sysAlloc(userArenaChunkBytes, hintList, false)
if size%userArenaChunkBytes != 0 {
throw("sysAlloc size is not divisible by userArenaChunkBytes")
}
if size > userArenaChunkBytes {
// We got more than we asked for. This can happen if
// heapArenaSize > userArenaChunkSize, or if sysAlloc just returns
// some extra as a result of trying to find an aligned region.
//
// Divide it up and put it on the ready list.
for i := uintptr(userArenaChunkBytes); i < size; i += userArenaChunkBytes {
s := h.allocMSpanLocked()
s.init(uintptr(v)+i, userArenaChunkPages)
h.userArena.readyList.insertBack(s)
}
size = userArenaChunkBytes
}
base = uintptr(v)
if base == 0 {
// Out of memory.
unlock(&h.lock)
return nil
}
s = h.allocMSpanLocked()
}
unlock(&h.lock)
// sysAlloc returns Reserved address space, and any span we're
// reusing is set to fault (so, also Reserved), so transition
// it to Prepared and then Ready.
//
// Unlike (*mheap).grow, just map in everything that we
// asked for. We're likely going to use it all.
sysMap(unsafe.Pointer(base), userArenaChunkBytes, &gcController.heapReleased)
sysUsed(unsafe.Pointer(base), userArenaChunkBytes, userArenaChunkBytes)
// Model the user arena as a heap span for a large object.
spc := makeSpanClass(0, false)
h.initSpan(s, spanAllocHeap, spc, base, userArenaChunkPages)
s.isUserArenaChunk = true
// Account for this new arena chunk memory.
gcController.heapInUse.add(int64(userArenaChunkBytes))
gcController.heapReleased.add(-int64(userArenaChunkBytes))
stats := memstats.heapStats.acquire()
atomic.Xaddint64(&stats.inHeap, int64(userArenaChunkBytes))
atomic.Xaddint64(&stats.committed, int64(userArenaChunkBytes))
// Model the arena as a single large malloc.
atomic.Xadd64(&stats.largeAlloc, int64(userArenaChunkBytes))
atomic.Xadd64(&stats.largeAllocCount, 1)
memstats.heapStats.release()
// Count the alloc in inconsistent, internal stats.
gcController.totalAlloc.Add(int64(userArenaChunkBytes))
// Update heapLive.
gcController.update(int64(userArenaChunkBytes), 0)
// Put the large span in the mcentral swept list so that it's
// visible to the background sweeper.
h.central[spc].mcentral.fullSwept(h.sweepgen).push(s)
s.limit = s.base() + userArenaChunkBytes
s.freeindex = 1
s.allocCount = 1
// This must clear the entire heap bitmap so that it's safe
// to allocate noscan data without writing anything out.
s.initHeapBits(true)
// Clear the span preemptively. It's an arena chunk, so let's assume
// everything is going to be used.
//
// This also seems to make a massive difference as to whether or
// not Linux decides to back this memory with transparent huge
// pages. There's latency involved in this zeroing, but the hugepage
// gains are almost always worth it. Note: it's important that we
// clear even if it's freshly mapped and we know there's no point
// to zeroing as *that* is the critical signal to use huge pages.
memclrNoHeapPointers(unsafe.Pointer(s.base()), s.elemsize)
s.needzero = 0
s.freeIndexForScan = 1
// Set up the range for allocation.
s.userArenaChunkFree = makeAddrRange(base, s.limit)
return s
}
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !asan
// Dummy ASan support API, used when not built with -asan.
package runtime
import (
"unsafe"
)
const asanenabled = false
// Because asanenabled is false, none of these functions should be called.
func asanread(addr unsafe.Pointer, sz uintptr) { throw("asan") }
func asanwrite(addr unsafe.Pointer, sz uintptr) { throw("asan") }
func asanunpoison(addr unsafe.Pointer, sz uintptr) { throw("asan") }
func asanpoison(addr unsafe.Pointer, sz uintptr) { throw("asan") }
func asanregisterglobals(addr unsafe.Pointer, sz uintptr) { throw("asan") }
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"internal/goexperiment"
"runtime/internal/atomic"
"unsafe"
)
// These functions cannot have go:noescape annotations,
// because while ptr does not escape, new does.
// If new is marked as not escaping, the compiler will make incorrect
// escape analysis decisions about the pointer value being stored.
// atomicwb performs a write barrier before an atomic pointer write.
// The caller should guard the call with "if writeBarrier.enabled".
//
//go:nosplit
func atomicwb(ptr *unsafe.Pointer, new unsafe.Pointer) {
slot := (*uintptr)(unsafe.Pointer(ptr))
buf := getg().m.p.ptr().wbBuf.get2()
buf[0] = *slot
buf[1] = uintptr(new)
}
// atomicstorep performs *ptr = new atomically and invokes a write barrier.
//
//go:nosplit
func atomicstorep(ptr unsafe.Pointer, new unsafe.Pointer) {
if writeBarrier.enabled {
atomicwb((*unsafe.Pointer)(ptr), new)
}
if goexperiment.CgoCheck2 {
cgoCheckPtrWrite((*unsafe.Pointer)(ptr), new)
}
atomic.StorepNoWB(noescape(ptr), new)
}
// atomic_storePointer is the implementation of runtime/internal/atomic.UnsafePointer.Store
// (like StoreNoWB but with the write barrier).
//
//go:nosplit
//go:linkname atomic_storePointer runtime/internal/atomic.storePointer
func atomic_storePointer(ptr *unsafe.Pointer, new unsafe.Pointer) {
atomicstorep(unsafe.Pointer(ptr), new)
}
// atomic_casPointer is the implementation of runtime/internal/atomic.UnsafePointer.CompareAndSwap
// (like CompareAndSwapNoWB but with the write barrier).
//
//go:nosplit
//go:linkname atomic_casPointer runtime/internal/atomic.casPointer
func atomic_casPointer(ptr *unsafe.Pointer, old, new unsafe.Pointer) bool {
if writeBarrier.enabled {
atomicwb(ptr, new)
}
if goexperiment.CgoCheck2 {
cgoCheckPtrWrite(ptr, new)
}
return atomic.Casp1(ptr, old, new)
}
// Like above, but implement in terms of sync/atomic's uintptr operations.
// We cannot just call the runtime routines, because the race detector expects
// to be able to intercept the sync/atomic forms but not the runtime forms.
//go:linkname sync_atomic_StoreUintptr sync/atomic.StoreUintptr
func sync_atomic_StoreUintptr(ptr *uintptr, new uintptr)
//go:linkname sync_atomic_StorePointer sync/atomic.StorePointer
//go:nosplit
func sync_atomic_StorePointer(ptr *unsafe.Pointer, new unsafe.Pointer) {
if writeBarrier.enabled {
atomicwb(ptr, new)
}
if goexperiment.CgoCheck2 {
cgoCheckPtrWrite(ptr, new)
}
sync_atomic_StoreUintptr((*uintptr)(unsafe.Pointer(ptr)), uintptr(new))
}
//go:linkname sync_atomic_SwapUintptr sync/atomic.SwapUintptr
func sync_atomic_SwapUintptr(ptr *uintptr, new uintptr) uintptr
//go:linkname sync_atomic_SwapPointer sync/atomic.SwapPointer
//go:nosplit
func sync_atomic_SwapPointer(ptr *unsafe.Pointer, new unsafe.Pointer) unsafe.Pointer {
if writeBarrier.enabled {
atomicwb(ptr, new)
}
if goexperiment.CgoCheck2 {
cgoCheckPtrWrite(ptr, new)
}
old := unsafe.Pointer(sync_atomic_SwapUintptr((*uintptr)(noescape(unsafe.Pointer(ptr))), uintptr(new)))
return old
}
//go:linkname sync_atomic_CompareAndSwapUintptr sync/atomic.CompareAndSwapUintptr
func sync_atomic_CompareAndSwapUintptr(ptr *uintptr, old, new uintptr) bool
//go:linkname sync_atomic_CompareAndSwapPointer sync/atomic.CompareAndSwapPointer
//go:nosplit
func sync_atomic_CompareAndSwapPointer(ptr *unsafe.Pointer, old, new unsafe.Pointer) bool {
if writeBarrier.enabled {
atomicwb(ptr, new)
}
if goexperiment.CgoCheck2 {
cgoCheckPtrWrite(ptr, new)
}
return sync_atomic_CompareAndSwapUintptr((*uintptr)(noescape(unsafe.Pointer(ptr))), uintptr(old), uintptr(new))
}
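// For reference, a minimal sketch of the public sync/atomic calls that the
// linknamed functions above implement (ordinary user code; nothing here is
// runtime-internal):
//
//	var p unsafe.Pointer
//	atomic.StorePointer(&p, unsafe.Pointer(new(int)))
//	old := atomic.SwapPointer(&p, nil)
//	_ = atomic.CompareAndSwapPointer(&p, nil, old)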
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import "unsafe"
//go:cgo_export_static main
// Filled in by runtime/cgo when linked into binary.
//go:linkname _cgo_init _cgo_init
//go:linkname _cgo_thread_start _cgo_thread_start
//go:linkname _cgo_sys_thread_create _cgo_sys_thread_create
//go:linkname _cgo_notify_runtime_init_done _cgo_notify_runtime_init_done
//go:linkname _cgo_callers _cgo_callers
//go:linkname _cgo_set_context_function _cgo_set_context_function
//go:linkname _cgo_yield _cgo_yield
var (
_cgo_init unsafe.Pointer
_cgo_thread_start unsafe.Pointer
_cgo_sys_thread_create unsafe.Pointer
_cgo_notify_runtime_init_done unsafe.Pointer
_cgo_callers unsafe.Pointer
_cgo_set_context_function unsafe.Pointer
_cgo_yield unsafe.Pointer
)
// iscgo is set to true by the runtime/cgo package
var iscgo bool
// cgoHasExtraM is set on startup when an extra M is created for cgo.
// The extra M must be created before any C/C++ code calls cgocallback.
var cgoHasExtraM bool
// cgoUse is called by cgo-generated code (using go:linkname to get at
// an unexported name). The calls serve two purposes:
// 1) they are opaque to escape analysis, so the argument is considered to
// escape to the heap.
// 2) they keep the argument alive until the call site; the call is emitted after
// the end of the (presumed) use of the argument by C.
// cgoUse should not actually be called (see cgoAlwaysFalse).
func cgoUse(any) { throw("cgoUse should not be called") }
// cgoAlwaysFalse is a boolean value that is always false.
// The cgo-generated code says if cgoAlwaysFalse { cgoUse(p) }.
// The compiler cannot see that cgoAlwaysFalse is always false,
// so it emits the test and keeps the call, giving the desired
// escape analysis result. The test is cheaper than the call.
var cgoAlwaysFalse bool
var cgo_yield = &_cgo_yield
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cgo
import "unsafe"
// These utility functions are available to be called from code
// compiled with gcc via crosscall2.
// The declaration of crosscall2 is:
// void crosscall2(void (*fn)(void *), void *, int);
//
// We need to export the symbol crosscall2 in order to support
// callbacks from shared libraries. This applies regardless of
// linking mode.
//
// Compatibility note: SWIG uses crosscall2 in exactly one situation:
// to call _cgo_panic using the pattern shown below. We need to keep
// that pattern working. In particular, crosscall2 actually takes four
// arguments, but it works to call it with three arguments when
// calling _cgo_panic.
//
//go:cgo_export_static crosscall2
//go:cgo_export_dynamic crosscall2
// Panic. The argument is converted into a Go string.
// Call like this in code compiled with gcc:
// struct { const char *p; } a;
// a.p = /* string to pass to panic */;
// crosscall2(_cgo_panic, &a, sizeof a);
// /* The function call will not return. */
// TODO: We should export a regular C function to panic, change SWIG
// to use that instead of the above pattern, and then we can drop
// backwards-compatibility from crosscall2 and stop exporting it.
//go:linkname _runtime_cgo_panic_internal runtime._cgo_panic_internal
func _runtime_cgo_panic_internal(p *byte)
//go:linkname _cgo_panic _cgo_panic
//go:cgo_export_static _cgo_panic
//go:cgo_export_dynamic _cgo_panic
func _cgo_panic(a *struct{ cstr *byte }) {
_runtime_cgo_panic_internal(a.cstr)
}
//go:cgo_import_static x_cgo_init
//go:linkname x_cgo_init x_cgo_init
//go:linkname _cgo_init _cgo_init
var x_cgo_init byte
var _cgo_init = &x_cgo_init
//go:cgo_import_static x_cgo_thread_start
//go:linkname x_cgo_thread_start x_cgo_thread_start
//go:linkname _cgo_thread_start _cgo_thread_start
var x_cgo_thread_start byte
var _cgo_thread_start = &x_cgo_thread_start
// Creates a new system thread without updating any Go state.
//
// This method is invoked during shared library loading to create a new OS
// thread to perform the runtime initialization. This method is similar to
// _cgo_sys_thread_start except that it doesn't update any Go state.
//go:cgo_import_static x_cgo_sys_thread_create
//go:linkname x_cgo_sys_thread_create x_cgo_sys_thread_create
//go:linkname _cgo_sys_thread_create _cgo_sys_thread_create
var x_cgo_sys_thread_create byte
var _cgo_sys_thread_create = &x_cgo_sys_thread_create
// Notifies that the runtime has been initialized.
//
// We currently block at every CGO entry point (via _cgo_wait_runtime_init_done)
// to ensure that the runtime has been initialized before the CGO call is
// executed. This is necessary for shared libraries where we kick off runtime
// initialization in a separate thread and return without waiting for this
// thread to complete the init.
//go:cgo_import_static x_cgo_notify_runtime_init_done
//go:linkname x_cgo_notify_runtime_init_done x_cgo_notify_runtime_init_done
//go:linkname _cgo_notify_runtime_init_done _cgo_notify_runtime_init_done
var x_cgo_notify_runtime_init_done byte
var _cgo_notify_runtime_init_done = &x_cgo_notify_runtime_init_done
// Sets the traceback context function. See runtime.SetCgoTraceback.
//go:cgo_import_static x_cgo_set_context_function
//go:linkname x_cgo_set_context_function x_cgo_set_context_function
//go:linkname _cgo_set_context_function _cgo_set_context_function
var x_cgo_set_context_function byte
var _cgo_set_context_function = &x_cgo_set_context_function
// Calls a libc function to execute background work injected via libc
// interceptors, such as processing pending signals under the thread
// sanitizer.
//
// Left as a nil pointer if no libc interceptors are expected.
//go:cgo_import_static _cgo_yield
//go:linkname _cgo_yield _cgo_yield
var _cgo_yield unsafe.Pointer
//go:cgo_export_static _cgo_topofstack
//go:cgo_export_dynamic _cgo_topofstack
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cgo
import (
"sync"
"sync/atomic"
)
// Handle provides a way to pass values that contain Go pointers
// (pointers to memory allocated by Go) between Go and C without
// breaking the cgo pointer passing rules. A Handle is an integer
// value that can represent any Go value. A Handle can be passed
// through C and back to Go, and Go code can use the Handle to
// retrieve the original Go value.
//
// The underlying type of Handle is guaranteed to fit in an integer type
// that is large enough to hold the bit pattern of any pointer. The zero
// value of a Handle is not valid, and thus is safe to use as a sentinel
// in C APIs.
//
// For instance, on the Go side:
//
// package main
//
// /*
// #include <stdint.h> // for uintptr_t
//
// extern void MyGoPrint(uintptr_t handle);
// void myprint(uintptr_t handle);
// */
// import "C"
// import "runtime/cgo"
//
// //export MyGoPrint
// func MyGoPrint(handle C.uintptr_t) {
// h := cgo.Handle(handle)
// val := h.Value().(string)
// println(val)
// h.Delete()
// }
//
// func main() {
// val := "hello Go"
// C.myprint(C.uintptr_t(cgo.NewHandle(val)))
// // Output: hello Go
// }
//
// and on the C side:
//
// #include <stdint.h> // for uintptr_t
//
// // A Go function
// extern void MyGoPrint(uintptr_t handle);
//
// // A C function
// void myprint(uintptr_t handle) {
// MyGoPrint(handle);
// }
//
// Some C functions accept a void* argument that points to an arbitrary
// data value supplied by the caller. It is not safe to coerce a cgo.Handle
// (an integer) to a Go unsafe.Pointer, but instead we can pass the address
// of the cgo.Handle to the void* parameter, as in this variant of the
// previous example:
//
// package main
//
// /*
// extern void MyGoPrint(void *context);
// static inline void myprint(void *context) {
// MyGoPrint(context);
// }
// */
// import "C"
// import (
// "runtime/cgo"
// "unsafe"
// )
//
// //export MyGoPrint
// func MyGoPrint(context unsafe.Pointer) {
// h := *(*cgo.Handle)(context)
// val := h.Value().(string)
// println(val)
// h.Delete()
// }
//
// func main() {
// val := "hello Go"
// h := cgo.NewHandle(val)
// C.myprint(unsafe.Pointer(&h))
// // Output: hello Go
// }
type Handle uintptr
// NewHandle returns a handle for a given value.
//
// The handle is valid until the program calls Delete on it. The handle
// uses resources, and this package assumes that C code may hold on to
// the handle, so a program must explicitly call Delete when the handle
// is no longer needed.
//
// The intended use is to pass the returned handle to C code, which
// passes it back to Go, which calls Value.
func NewHandle(v any) Handle {
h := atomic.AddUintptr(&handleIdx, 1)
if h == 0 {
panic("runtime/cgo: ran out of handle space")
}
handles.Store(h, v)
return Handle(h)
}
// Value returns the associated Go value for a valid handle.
//
// The method panics if the handle is invalid.
func (h Handle) Value() any {
v, ok := handles.Load(uintptr(h))
if !ok {
panic("runtime/cgo: misuse of an invalid Handle")
}
return v
}
// Delete invalidates a handle. This method should only be called once
// the program no longer needs to pass the handle to C and the C code
// no longer has a copy of the handle value.
//
// The method panics if the handle is invalid.
func (h Handle) Delete() {
_, ok := handles.LoadAndDelete(uintptr(h))
if !ok {
panic("runtime/cgo: misuse of an invalid Handle")
}
}
var (
handles = sync.Map{} // map[Handle]interface{}
handleIdx uintptr // atomic
)
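// A minimal sketch of the Handle lifecycle from pure Go (no C involved),
// following the NewHandle/Value/Delete contract documented above:
//
//	h := cgo.NewHandle("hello")
//	v := h.Value().(string) // "hello"
//	h.Delete()              // a second Delete or Value on h would panic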
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Support for memory sanitizer. See runtime/cgo/mmap.go.
//go:build (linux && amd64) || (linux && arm64) || (freebsd && amd64)
package runtime
import "unsafe"
// _cgo_mmap is filled in by runtime/cgo when it is linked into the
// program, so it is only non-nil when using cgo.
//
//go:linkname _cgo_mmap _cgo_mmap
var _cgo_mmap unsafe.Pointer
// _cgo_munmap is filled in by runtime/cgo when it is linked into the
// program, so it is only non-nil when using cgo.
//
//go:linkname _cgo_munmap _cgo_munmap
var _cgo_munmap unsafe.Pointer
// mmap is used to route the mmap system call through C code when using cgo, to
// support sanitizer interceptors. Don't allow stack splits, since this function
// (used by sysAlloc) is called in a lot of low-level parts of the runtime and
// callers often assume it won't acquire any locks.
//
//go:nosplit
func mmap(addr unsafe.Pointer, n uintptr, prot, flags, fd int32, off uint32) (unsafe.Pointer, int) {
if _cgo_mmap != nil {
// Make ret a uintptr so that writing to it in the
// function literal does not trigger a write barrier.
// A write barrier here could break because of the way
// that mmap uses the same value both as a pointer and
// an errno value.
var ret uintptr
systemstack(func() {
ret = callCgoMmap(addr, n, prot, flags, fd, off)
})
if ret < 4096 {
return nil, int(ret)
}
return unsafe.Pointer(ret), 0
}
return sysMmap(addr, n, prot, flags, fd, off)
}
func munmap(addr unsafe.Pointer, n uintptr) {
if _cgo_munmap != nil {
systemstack(func() { callCgoMunmap(addr, n) })
return
}
sysMunmap(addr, n)
}
// sysMmap calls the mmap system call. It is implemented in assembly.
func sysMmap(addr unsafe.Pointer, n uintptr, prot, flags, fd int32, off uint32) (p unsafe.Pointer, err int)
// callCgoMmap calls the mmap function in the runtime/cgo package
// using the GCC calling convention. It is implemented in assembly.
func callCgoMmap(addr unsafe.Pointer, n uintptr, prot, flags, fd int32, off uint32) uintptr
// sysMunmap calls the munmap system call. It is implemented in assembly.
func sysMunmap(addr unsafe.Pointer, n uintptr)
// callCgoMunmap calls the munmap function in the runtime/cgo package
// using the GCC calling convention. It is implemented in assembly.
func callCgoMunmap(addr unsafe.Pointer, n uintptr)
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Support for sanitizers. See runtime/cgo/sigaction.go.
//go:build (linux && amd64) || (freebsd && amd64) || (linux && arm64) || (linux && ppc64le)
package runtime
import "unsafe"
// _cgo_sigaction is filled in by runtime/cgo when it is linked into the
// program, so it is only non-nil when using cgo.
//
//go:linkname _cgo_sigaction _cgo_sigaction
var _cgo_sigaction unsafe.Pointer
//go:nosplit
//go:nowritebarrierrec
func sigaction(sig uint32, new, old *sigactiont) {
// racewalk.go avoids adding sanitizing instrumentation to package runtime,
// but we might be calling into instrumented C functions here,
// so we need the pointer parameters to be properly marked.
//
// Mark the input as having been written before the call
// and the output as read after.
if msanenabled && new != nil {
msanwrite(unsafe.Pointer(new), unsafe.Sizeof(*new))
}
if asanenabled && new != nil {
asanwrite(unsafe.Pointer(new), unsafe.Sizeof(*new))
}
if _cgo_sigaction == nil || inForkedChild {
sysSigaction(sig, new, old)
} else {
// We need to call _cgo_sigaction, which means we need a big enough stack
// for C. To complicate matters, we may be in libpreinit (before the
// runtime has been initialized) or in an asynchronous signal handler (with
// the current thread in transition between goroutines, or with the g0
// system stack already in use).
var ret int32
var g *g
if mainStarted {
g = getg()
}
sp := uintptr(unsafe.Pointer(&sig))
switch {
case g == nil:
// No g: we're on a C stack or a signal stack.
ret = callCgoSigaction(uintptr(sig), new, old)
case sp < g.stack.lo || sp >= g.stack.hi:
// We're no longer on g's stack, so we must be handling a signal. It's
// possible that we interrupted the thread during a transition between g
// and g0, so we should stay on the current stack to avoid corrupting g0.
ret = callCgoSigaction(uintptr(sig), new, old)
default:
// We're running on g's stack, so either we're not in a signal handler or
// the signal handler has set the correct g. If we're on gsignal or g0,
// systemstack will make the call directly; otherwise, it will switch to
// g0 to ensure we have enough room to call a libc function.
//
// The function literal that we pass to systemstack is not nosplit, but
// that's ok: we'll be running on a fresh, clean system stack so the stack
// check will always succeed anyway.
systemstack(func() {
ret = callCgoSigaction(uintptr(sig), new, old)
})
}
const EINVAL = 22
if ret == EINVAL {
// libc reserves certain signals — normally 32-33 — for pthreads, and
// returns EINVAL for sigaction calls on those signals. If we get EINVAL,
// fall back to making the syscall directly.
sysSigaction(sig, new, old)
}
}
if msanenabled && old != nil {
msanread(unsafe.Pointer(old), unsafe.Sizeof(*old))
}
if asanenabled && old != nil {
asanread(unsafe.Pointer(old), unsafe.Sizeof(*old))
}
}
// callCgoSigaction calls the sigaction function in the runtime/cgo package
// using the GCC calling convention. It is implemented in assembly.
//
//go:noescape
func callCgoSigaction(sig uintptr, new, old *sigactiont) int32
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Cgo call and callback support.
//
// To call into the C function f from Go, the cgo-generated code calls
// runtime.cgocall(_cgo_Cfunc_f, frame), where _cgo_Cfunc_f is a
// gcc-compiled function written by cgo.
//
// runtime.cgocall (below) calls entersyscall so as not to block
// other goroutines or the garbage collector, and then calls
// runtime.asmcgocall(_cgo_Cfunc_f, frame).
//
// runtime.asmcgocall (in asm_$GOARCH.s) switches to the m->g0 stack
// (assumed to be an operating system-allocated stack, so safe to run
// gcc-compiled code on) and calls _cgo_Cfunc_f(frame).
//
// _cgo_Cfunc_f invokes the actual C function f with arguments
// taken from the frame structure, records the results in the frame,
// and returns to runtime.asmcgocall.
//
// After it regains control, runtime.asmcgocall switches back to the
// original g (m->curg)'s stack and returns to runtime.cgocall.
//
// After it regains control, runtime.cgocall calls exitsyscall, which blocks
// until this m can run Go code without violating the $GOMAXPROCS limit,
// and then unlocks g from m.
//
// The above description skipped over the possibility of the gcc-compiled
// function f calling back into Go. If that happens, we continue down
// the rabbit hole during the execution of f.
//
// To make it possible for gcc-compiled C code to call a Go function p.GoF,
// cgo writes a gcc-compiled function named GoF (not p.GoF, since gcc doesn't
// know about packages). The gcc-compiled C function f calls GoF.
//
// GoF initializes "frame", a structure containing all of its
// arguments and slots for p.GoF's results. It calls
// crosscall2(_cgoexp_GoF, frame, framesize, ctxt) using the gcc ABI.
//
// crosscall2 (in cgo/asm_$GOARCH.s) is a four-argument adapter from
// the gcc function call ABI to the gc function call ABI. At this
// point we're in the Go runtime, but we're still running on m.g0's
// stack and outside the $GOMAXPROCS limit. crosscall2 calls
// runtime.cgocallback(_cgoexp_GoF, frame, ctxt) using the gc ABI.
// (crosscall2's framesize argument is no longer used, but there's one
// case where SWIG calls crosscall2 directly and expects to pass this
// argument. See _cgo_panic.)
//
// runtime.cgocallback (in asm_$GOARCH.s) switches from m.g0's stack
// to the original g (m.curg)'s stack, on which it calls
// runtime.cgocallbackg(_cgoexp_GoF, frame, ctxt). As part of the
// stack switch, runtime.cgocallback saves the current SP as
// m.g0.sched.sp, so that any use of m.g0's stack during the execution
// of the callback will be done below the existing stack frames.
// Before overwriting m.g0.sched.sp, it pushes the old value on the
// m.g0 stack, so that it can be restored later.
//
// runtime.cgocallbackg (below) is now running on a real goroutine
// stack (not an m.g0 stack). First it calls runtime.exitsyscall, which will
// block until the $GOMAXPROCS limit allows running this goroutine.
// Once exitsyscall has returned, it is safe to do things like call the memory
// allocator or invoke the Go callback function. runtime.cgocallbackg
// first defers a function to unwind m.g0.sched.sp, so that if p.GoF
// panics, m.g0.sched.sp will be restored to its old value: the m.g0 stack
// and the m.curg stack will be unwound in lock step.
// Then it calls _cgoexp_GoF(frame).
//
// _cgoexp_GoF, which was generated by cmd/cgo, unpacks the arguments
// from frame, calls p.GoF, writes the results back to frame, and
// returns. Now we start unwinding this whole process.
//
// runtime.cgocallbackg pops but does not execute the deferred
// function to unwind m.g0.sched.sp, calls runtime.entersyscall, and
// returns to runtime.cgocallback.
//
// After it regains control, runtime.cgocallback switches back to
// m.g0's stack (the pointer is still in m.g0.sched.sp), restores the old
// m.g0.sched.sp value from the stack, and returns to crosscall2.
//
// crosscall2 restores the callee-save registers for gcc and returns
// to GoF, which unpacks any result values and returns to f.
package runtime
import (
"internal/goarch"
"internal/goexperiment"
"runtime/internal/sys"
"unsafe"
)
// Addresses collected in a cgo backtrace when crashing.
// Length must match arg.Max in x_cgo_callers in runtime/cgo/gcc_traceback.c.
type cgoCallers [32]uintptr
// argset matches runtime/cgo/linux_syscall.c:argset_t
type argset struct {
args unsafe.Pointer
retval uintptr
}
// wrapper for syscall package to call cgocall for libc (cgo) calls.
//
//go:linkname syscall_cgocaller syscall.cgocaller
//go:nosplit
//go:uintptrescapes
func syscall_cgocaller(fn unsafe.Pointer, args ...uintptr) uintptr {
as := argset{args: unsafe.Pointer(&args[0])}
cgocall(fn, unsafe.Pointer(&as))
return as.retval
}
var ncgocall uint64 // number of cgo calls in total for dead m
// Call from Go to C.
//
// This must be nosplit because it's used for syscalls on some
// platforms. Syscalls may have untyped arguments on the stack, so
// it's not safe to grow or scan the stack.
//
//go:nosplit
func cgocall(fn, arg unsafe.Pointer) int32 {
if !iscgo && GOOS != "solaris" && GOOS != "illumos" && GOOS != "windows" {
throw("cgocall unavailable")
}
if fn == nil {
throw("cgocall nil")
}
if raceenabled {
racereleasemerge(unsafe.Pointer(&racecgosync))
}
mp := getg().m
mp.ncgocall++
mp.ncgo++
// Reset traceback.
mp.cgoCallers[0] = 0
// Announce we are entering a system call
// so that the scheduler knows to create another
// M to run goroutines while we are in the
// foreign code.
//
// The call to asmcgocall is guaranteed not to
// grow the stack and does not allocate memory,
// so it is safe to call while "in a system call", outside
// the $GOMAXPROCS accounting.
//
// fn may call back into Go code, in which case we'll exit the
// "system call", run the Go code (which may grow the stack),
// and then re-enter the "system call" reusing the PC and SP
// saved by entersyscall here.
entersyscall()
// Tell asynchronous preemption that we're entering external
// code. We do this after entersyscall because this may block
// and cause an async preemption to fail, but at this point a
// sync preemption will succeed (though this is not a matter
// of correctness).
osPreemptExtEnter(mp)
mp.incgo = true
errno := asmcgocall(fn, arg)
// Update accounting before exitsyscall because exitsyscall may
// reschedule us on to a different M.
mp.incgo = false
mp.ncgo--
osPreemptExtExit(mp)
exitsyscall()
// Note that raceacquire must be called only after exitsyscall has
// wired this M to a P.
if raceenabled {
raceacquire(unsafe.Pointer(&racecgosync))
}
// From the garbage collector's perspective, time can move
// backwards in the sequence above. If there's a callback into
// Go code, GC will see this function at the call to
// asmcgocall. When the Go call later returns to C, the
// syscall PC/SP is rolled back and the GC sees this function
// back at the call to entersyscall. Normally, fn and arg
// would be live at entersyscall and dead at asmcgocall, so if
// time moved backwards, GC would see these arguments as dead
// and then live. Prevent these undead arguments from crashing
// GC by forcing them to stay live across this time warp.
KeepAlive(fn)
KeepAlive(arg)
KeepAlive(mp)
return errno
}
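// For orientation, a minimal sketch of user code that reaches cgocall. The C
// function add is hypothetical, and _cgo_Cfunc_add names the gcc-compiled
// wrapper that cgo would generate for it:
//
//	/*
//	int add(int a, int b) { return a + b; }
//	*/
//	import "C"
//
//	func f() int {
//		return int(C.add(1, 2)) // compiled to runtime.cgocall(_cgo_Cfunc_add, &frame)
//	}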
// Call from C back to Go. fn must point to an ABIInternal Go entry-point.
//
//go:nosplit
func cgocallbackg(fn, frame unsafe.Pointer, ctxt uintptr) {
gp := getg()
if gp != gp.m.curg {
println("runtime: bad g in cgocallback")
exit(2)
}
// The call from C is on gp.m's g0 stack, so we must ensure
// that we stay on that M. We have to do this before calling
// exitsyscall, since it would otherwise be free to move us to
// a different M. The call to unlockOSThread is in unwindm.
lockOSThread()
checkm := gp.m
// Save current syscall parameters, so m.syscall can be
// used again if the callback decides to make a syscall.
syscall := gp.m.syscall
// entersyscall saves the caller's SP to allow the GC to trace the Go
// stack. However, since we're returning to an earlier stack frame and
// need to pair with the entersyscall() call made by cgocall, we must
// save syscall* and let reentersyscall restore them.
savedsp := unsafe.Pointer(gp.syscallsp)
savedpc := gp.syscallpc
exitsyscall() // coming out of cgo call
gp.m.incgo = false
osPreemptExtExit(gp.m)
cgocallbackg1(fn, frame, ctxt) // will call unlockOSThread
// At this point unlockOSThread has been called.
// The following code must not change to a different m.
// This is enforced by checking incgo in the schedule function.
gp.m.incgo = true
if gp.m != checkm {
throw("m changed unexpectedly in cgocallbackg")
}
osPreemptExtEnter(gp.m)
// going back to cgo call
reentersyscall(savedpc, uintptr(savedsp))
gp.m.syscall = syscall
}
func cgocallbackg1(fn, frame unsafe.Pointer, ctxt uintptr) {
gp := getg()
// When we return, undo the call to lockOSThread in cgocallbackg.
// We must still stay on the same m.
defer unlockOSThread()
if gp.m.needextram || extraMWaiters.Load() > 0 {
gp.m.needextram = false
systemstack(newextram)
}
if ctxt != 0 {
s := append(gp.cgoCtxt, ctxt)
// Now we need to set gp.cgoCtxt = s, but we could get
// a SIGPROF signal while manipulating the slice, and
// the SIGPROF handler could pick up gp.cgoCtxt while
// tracing up the stack. We need to ensure that the
// handler always sees a valid slice, so set the
// values in an order such that it always does.
p := (*slice)(unsafe.Pointer(&gp.cgoCtxt))
atomicstorep(unsafe.Pointer(&p.array), unsafe.Pointer(&s[0]))
p.cap = cap(s)
p.len = len(s)
defer func(gp *g) {
// Decrease the length of the slice by one, safely.
p := (*slice)(unsafe.Pointer(&gp.cgoCtxt))
p.len--
}(gp)
}
if gp.m.ncgo == 0 {
// The C call to Go came from a thread not currently running
// any Go. In the case of -buildmode=c-archive or c-shared,
// this call may be coming in before package initialization
// is complete. Wait until it is.
<-main_init_done
}
// Check whether the profiler needs to be turned on or off; this route to
// run Go code does not use runtime.execute, so bypasses the check there.
hz := sched.profilehz
if gp.m.profilehz != hz {
setThreadCPUProfiler(hz)
}
// Add entry to defer stack in case of panic.
restore := true
defer unwindm(&restore)
if raceenabled {
raceacquire(unsafe.Pointer(&racecgosync))
}
// Invoke callback. This function is generated by cmd/cgo and
// will unpack the argument frame and call the Go function.
var cb func(frame unsafe.Pointer)
cbFV := funcval{uintptr(fn)}
*(*unsafe.Pointer)(unsafe.Pointer(&cb)) = noescape(unsafe.Pointer(&cbFV))
cb(frame)
if raceenabled {
racereleasemerge(unsafe.Pointer(&racecgosync))
}
// Do not unwind m->g0->sched.sp.
// Our caller, cgocallback, will do that.
restore = false
}
func unwindm(restore *bool) {
if *restore {
// Restore sp saved by cgocallback during
// unwind of g's stack (see comment at top of file).
mp := acquirem()
sched := &mp.g0.sched
sched.sp = *(*uintptr)(unsafe.Pointer(sched.sp + alignUp(sys.MinFrameSize, sys.StackAlign)))
// Do the accounting that cgocall will not have a chance to do
// during an unwind.
//
// In the case where a Go call originates from C, ncgo is 0
// and there is no matching cgocall to end.
if mp.ncgo > 0 {
mp.incgo = false
mp.ncgo--
osPreemptExtExit(mp)
}
releasem(mp)
}
}
// called from assembly.
func badcgocallback() {
throw("misaligned stack in cgocallback")
}
// called from (incomplete) assembly.
func cgounimpl() {
throw("cgo not implemented")
}
var racecgosync uint64 // represents possible synchronization in C code
// Pointer checking for cgo code.
// We want to detect all cases where a program that does not use
// unsafe makes a cgo call passing a Go pointer to memory that
// contains a Go pointer. Here a Go pointer is defined as a pointer
// to memory allocated by the Go runtime. Programs that use unsafe
// can evade this restriction easily, so we don't try to catch them.
// The cgo program will rewrite all possibly bad pointer arguments to
// call cgoCheckPointer, where we can catch cases of a Go pointer
// pointing to a Go pointer.
// Complicating matters, taking the address of a slice or array
// element permits the C program to access all elements of the slice
// or array. In that case we will see a pointer to a single element,
// but we need to check the entire data structure.
// The cgoCheckPointer call takes additional arguments indicating that
// it was called on an address expression. An additional argument of
// true means that it only needs to check a single element. An
// additional argument of a slice or array means that it needs to
// check the entire slice/array, but nothing else. Otherwise, the
// pointer could be anything, and we check the entire heap object,
// which is conservative but safe.
// When and if we implement a moving garbage collector,
// cgoCheckPointer will pin the pointer for the duration of the cgo
// call. (This is necessary but not sufficient; the cgo program will
// also have to change to pin Go pointers that cannot point to Go
// pointers.)
// cgoCheckPointer checks if the argument contains a Go pointer that
// points to a Go pointer, and panics if it does.
func cgoCheckPointer(ptr any, arg any) {
if !goexperiment.CgoCheck2 && debug.cgocheck == 0 {
return
}
ep := efaceOf(&ptr)
t := ep._type
top := true
if arg != nil && (t.kind&kindMask == kindPtr || t.kind&kindMask == kindUnsafePointer) {
p := ep.data
if t.kind&kindDirectIface == 0 {
p = *(*unsafe.Pointer)(p)
}
if p == nil || !cgoIsGoPointer(p) {
return
}
aep := efaceOf(&arg)
switch aep._type.kind & kindMask {
case kindBool:
if t.kind&kindMask == kindUnsafePointer {
// We don't know the type of the element.
break
}
pt := (*ptrtype)(unsafe.Pointer(t))
cgoCheckArg(pt.elem, p, true, false, cgoCheckPointerFail)
return
case kindSlice:
// Check the slice rather than the pointer.
ep = aep
t = ep._type
case kindArray:
// Check the array rather than the pointer.
// Pass top as false since we have a pointer
// to the array.
ep = aep
t = ep._type
top = false
default:
throw("can't happen")
}
}
cgoCheckArg(t, ep.data, t.kind&kindDirectIface == 0, top, cgoCheckPointerFail)
}
const cgoCheckPointerFail = "cgo argument has Go pointer to Go pointer"
const cgoResultFail = "cgo result has Go pointer"
// cgoCheckArg is the real work of cgoCheckPointer. The argument p
// is either a pointer to the value (of type t), or the value itself,
// depending on indir. The top parameter is whether we are at the top
// level, where Go pointers are allowed.
func cgoCheckArg(t *_type, p unsafe.Pointer, indir, top bool, msg string) {
if t.ptrdata == 0 || p == nil {
// If the type has no pointers there is nothing to do.
return
}
switch t.kind & kindMask {
default:
throw("can't happen")
case kindArray:
at := (*arraytype)(unsafe.Pointer(t))
if !indir {
if at.len != 1 {
throw("can't happen")
}
cgoCheckArg(at.elem, p, at.elem.kind&kindDirectIface == 0, top, msg)
return
}
for i := uintptr(0); i < at.len; i++ {
cgoCheckArg(at.elem, p, true, top, msg)
p = add(p, at.elem.size)
}
case kindChan, kindMap:
// These types contain internal pointers that will
// always be allocated in the Go heap. It's never OK
// to pass them to C.
panic(errorString(msg))
case kindFunc:
if indir {
p = *(*unsafe.Pointer)(p)
}
if !cgoIsGoPointer(p) {
return
}
panic(errorString(msg))
case kindInterface:
it := *(**_type)(p)
if it == nil {
return
}
// A type known at compile time is OK since it's
// constant. A type not known at compile time will be
// in the heap and will not be OK.
if inheap(uintptr(unsafe.Pointer(it))) {
panic(errorString(msg))
}
p = *(*unsafe.Pointer)(add(p, goarch.PtrSize))
if !cgoIsGoPointer(p) {
return
}
if !top {
panic(errorString(msg))
}
cgoCheckArg(it, p, it.kind&kindDirectIface == 0, false, msg)
case kindSlice:
st := (*slicetype)(unsafe.Pointer(t))
s := (*slice)(p)
p = s.array
if p == nil || !cgoIsGoPointer(p) {
return
}
if !top {
panic(errorString(msg))
}
if st.elem.ptrdata == 0 {
return
}
for i := 0; i < s.cap; i++ {
cgoCheckArg(st.elem, p, true, false, msg)
p = add(p, st.elem.size)
}
case kindString:
ss := (*stringStruct)(p)
if !cgoIsGoPointer(ss.str) {
return
}
if !top {
panic(errorString(msg))
}
case kindStruct:
st := (*structtype)(unsafe.Pointer(t))
if !indir {
if len(st.fields) != 1 {
throw("can't happen")
}
cgoCheckArg(st.fields[0].typ, p, st.fields[0].typ.kind&kindDirectIface == 0, top, msg)
return
}
for _, f := range st.fields {
if f.typ.ptrdata == 0 {
continue
}
cgoCheckArg(f.typ, add(p, f.offset), true, top, msg)
}
case kindPtr, kindUnsafePointer:
if indir {
p = *(*unsafe.Pointer)(p)
if p == nil {
return
}
}
if !cgoIsGoPointer(p) {
return
}
if !top {
panic(errorString(msg))
}
cgoCheckUnknownPointer(p, msg)
}
}
// cgoCheckUnknownPointer is called for an arbitrary pointer into Go
// memory. It checks whether that Go memory contains any other
// pointer into Go memory. If it does, we panic.
// The return values are unused but useful to see in panic tracebacks.
func cgoCheckUnknownPointer(p unsafe.Pointer, msg string) (base, i uintptr) {
if inheap(uintptr(p)) {
b, span, _ := findObject(uintptr(p), 0, 0)
base = b
if base == 0 {
return
}
n := span.elemsize
hbits := heapBitsForAddr(base, n)
for {
var addr uintptr
if hbits, addr = hbits.next(); addr == 0 {
break
}
if cgoIsGoPointer(*(*unsafe.Pointer)(unsafe.Pointer(addr))) {
panic(errorString(msg))
}
}
return
}
for _, datap := range activeModules() {
if cgoInRange(p, datap.data, datap.edata) || cgoInRange(p, datap.bss, datap.ebss) {
// We have no way to know the size of the object.
// We have to assume that it might contain a pointer.
panic(errorString(msg))
}
// In the text or noptr sections, we know that the
// pointer does not point to a Go pointer.
}
return
}
// cgoIsGoPointer reports whether the pointer is a Go pointer--a
// pointer to Go memory. We only care about Go memory that might
// contain pointers.
//
//go:nosplit
//go:nowritebarrierrec
func cgoIsGoPointer(p unsafe.Pointer) bool {
if p == nil {
return false
}
if inHeapOrStack(uintptr(p)) {
return true
}
for _, datap := range activeModules() {
if cgoInRange(p, datap.data, datap.edata) || cgoInRange(p, datap.bss, datap.ebss) {
return true
}
}
return false
}
// cgoInRange reports whether p is in the half-open range [start, end).
//
//go:nosplit
//go:nowritebarrierrec
func cgoInRange(p unsafe.Pointer, start, end uintptr) bool {
return start <= uintptr(p) && uintptr(p) < end
}
// cgoCheckResult is called to check the result parameter of an
// exported Go function. It panics if the result is or contains a Go
// pointer.
func cgoCheckResult(val any) {
if !goexperiment.CgoCheck2 && debug.cgocheck == 0 {
return
}
ep := efaceOf(&val)
t := ep._type
cgoCheckArg(t, ep.data, t.kind&kindDirectIface == 0, false, cgoResultFail)
}
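// An illustrative sketch (not part of the runtime): the kind of user code
// the checks above reject. Passing C a Go pointer to memory that itself
// contains a Go pointer violates the cgo pointer passing rules; with
// GODEBUG=cgocheck=1 (the default), cgoCheckArg panics on the struct's
// pointer field. The C function here is a hypothetical no-op.
//
//	package main
//
//	// static void use(void *p) {}
//	import "C"
//	import "unsafe"
//
//	type T struct{ p *int }
//
//	func main() {
//		n := 42
//		C.use(unsafe.Pointer(&T{p: &n})) // panics: cgo argument has Go pointer to Go pointer
//	}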
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
// These functions are called from C code via cgo/callbacks.go.
// Panic.
func _cgo_panic_internal(p *byte) {
panic(gostringnocopy(p))
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Code to check that pointer writes follow the cgo rules.
// These functions are invoked when GOEXPERIMENT=cgocheck2 is enabled.
package runtime
import (
"internal/goarch"
"unsafe"
)
const cgoWriteBarrierFail = "Go pointer stored into non-Go memory"
// cgoCheckPtrWrite is called whenever a pointer is stored into memory.
// It throws if the program is storing a Go pointer into non-Go memory.
//
// This is called from generated code when GOEXPERIMENT=cgocheck2 is enabled.
//
//go:nosplit
//go:nowritebarrier
func cgoCheckPtrWrite(dst *unsafe.Pointer, src unsafe.Pointer) {
if !mainStarted {
// Something early in startup hates this function.
// Don't start doing any actual checking until the
// runtime has set itself up.
return
}
if !cgoIsGoPointer(src) {
return
}
if cgoIsGoPointer(unsafe.Pointer(dst)) {
return
}
// If we are running on the system stack then dst might be an
// address on the stack, which is OK.
gp := getg()
if gp == gp.m.g0 || gp == gp.m.gsignal {
return
}
// Allocating memory can write to various mfixalloc structs
// that look like they are non-Go memory.
if gp.m.mallocing != 0 {
return
}
// It's OK if writing to memory allocated by persistentalloc.
// Do this check last because it is more expensive and rarely true.
// If it is false the expense doesn't matter since we are crashing.
if inPersistentAlloc(uintptr(unsafe.Pointer(dst))) {
return
}
systemstack(func() {
println("write of Go pointer", hex(uintptr(src)), "to non-Go memory", hex(uintptr(unsafe.Pointer(dst))))
throw(cgoWriteBarrierFail)
})
}
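// An illustrative sketch (not part of the runtime) of a store that the
// check above reports, assuming a toolchain built with
// GOEXPERIMENT=cgocheck2. C.malloc returns non-Go memory, so storing a
// Go pointer through it trips cgoCheckPtrWrite.
//
//	package main
//
//	// #include <stdlib.h>
//	import "C"
//	import "unsafe"
//
//	func main() {
//		p := (*unsafe.Pointer)(C.malloc(C.size_t(unsafe.Sizeof(uintptr(0)))))
//		*p = unsafe.Pointer(new(int)) // throws: Go pointer stored into non-Go memory
//	}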
// cgoCheckMemmove is called when moving a block of memory.
// It throws if the program is copying a block that contains a Go pointer
// into non-Go memory.
//
// This is called from generated code when GOEXPERIMENT=cgocheck2 is enabled.
//
//go:nosplit
//go:nowritebarrier
func cgoCheckMemmove(typ *_type, dst, src unsafe.Pointer) {
cgoCheckMemmove2(typ, dst, src, 0, typ.size)
}
// cgoCheckMemmove2 is called when moving a block of memory.
// dst and src each point off bytes into the value being copied.
// size is the number of bytes to copy.
// It throws if the program is copying a block that contains a Go pointer
// into non-Go memory.
//
//go:nosplit
//go:nowritebarrier
func cgoCheckMemmove2(typ *_type, dst, src unsafe.Pointer, off, size uintptr) {
if typ.ptrdata == 0 {
return
}
if !cgoIsGoPointer(src) {
return
}
if cgoIsGoPointer(dst) {
return
}
cgoCheckTypedBlock(typ, src, off, size)
}
// cgoCheckSliceCopy is called when copying n elements of a slice.
// src and dst are pointers to the first element of the slice.
// typ is the element type of the slice.
// It throws if the program is copying slice elements that contain Go pointers
// into non-Go memory.
//
//go:nosplit
//go:nowritebarrier
func cgoCheckSliceCopy(typ *_type, dst, src unsafe.Pointer, n int) {
if typ.ptrdata == 0 {
return
}
if !cgoIsGoPointer(src) {
return
}
if cgoIsGoPointer(dst) {
return
}
p := src
for i := 0; i < n; i++ {
cgoCheckTypedBlock(typ, p, 0, typ.size)
p = add(p, typ.size)
}
}
// cgoCheckTypedBlock checks the block of memory at src, for up to size bytes,
// and throws if it finds a Go pointer. The type of the memory is typ,
// and src is off bytes into that type.
//
//go:nosplit
//go:nowritebarrier
func cgoCheckTypedBlock(typ *_type, src unsafe.Pointer, off, size uintptr) {
// Anything past typ.ptrdata is not a pointer.
if typ.ptrdata <= off {
return
}
if ptrdataSize := typ.ptrdata - off; size > ptrdataSize {
size = ptrdataSize
}
if typ.kind&kindGCProg == 0 {
cgoCheckBits(src, typ.gcdata, off, size)
return
}
// The type has a GC program. Try to find GC bits somewhere else.
for _, datap := range activeModules() {
if cgoInRange(src, datap.data, datap.edata) {
doff := uintptr(src) - datap.data
cgoCheckBits(add(src, -doff), datap.gcdatamask.bytedata, off+doff, size)
return
}
if cgoInRange(src, datap.bss, datap.ebss) {
boff := uintptr(src) - datap.bss
cgoCheckBits(add(src, -boff), datap.gcbssmask.bytedata, off+boff, size)
return
}
}
s := spanOfUnchecked(uintptr(src))
if s.state.get() == mSpanManual {
// There are no heap bits for values stored on the stack.
// For a channel receive src might be on the stack of some
// other goroutine, so we can't unwind the stack even if
// we wanted to.
// We can't expand the GC program without extra storage
// space we can't easily get.
// Fortunately we have the type information.
systemstack(func() {
cgoCheckUsingType(typ, src, off, size)
})
return
}
// src must be in the regular heap.
hbits := heapBitsForAddr(uintptr(src), size)
for {
var addr uintptr
if hbits, addr = hbits.next(); addr == 0 {
break
}
v := *(*unsafe.Pointer)(unsafe.Pointer(addr))
if cgoIsGoPointer(v) {
throw(cgoWriteBarrierFail)
}
}
}
// cgoCheckBits checks the block of memory at src, for up to size
// bytes, and throws if it finds a Go pointer. The gcbits mark each
// pointer value. The src pointer is off bytes into the gcbits.
//
//go:nosplit
//go:nowritebarrier
func cgoCheckBits(src unsafe.Pointer, gcbits *byte, off, size uintptr) {
skipMask := off / goarch.PtrSize / 8
skipBytes := skipMask * goarch.PtrSize * 8
ptrmask := addb(gcbits, skipMask)
src = add(src, skipBytes)
off -= skipBytes
size += off
var bits uint32
for i := uintptr(0); i < size; i += goarch.PtrSize {
if i&(goarch.PtrSize*8-1) == 0 {
bits = uint32(*ptrmask)
ptrmask = addb(ptrmask, 1)
} else {
bits >>= 1
}
if off > 0 {
off -= goarch.PtrSize
} else {
if bits&1 != 0 {
v := *(*unsafe.Pointer)(add(src, i))
if cgoIsGoPointer(v) {
throw(cgoWriteBarrierFail)
}
}
}
}
}
// cgoCheckUsingType is like cgoCheckTypedBlock, but is a last-ditch
// fallback that looks for pointers in src using the type information.
// We only use this when looking at a value on the stack when the type
// uses a GC program, because otherwise it's more efficient to use the
// GC bits. This is called on the system stack.
//
//go:nowritebarrier
//go:systemstack
func cgoCheckUsingType(typ *_type, src unsafe.Pointer, off, size uintptr) {
if typ.ptrdata == 0 {
return
}
// Anything past typ.ptrdata is not a pointer.
if typ.ptrdata <= off {
return
}
if ptrdataSize := typ.ptrdata - off; size > ptrdataSize {
size = ptrdataSize
}
if typ.kind&kindGCProg == 0 {
cgoCheckBits(src, typ.gcdata, off, size)
return
}
switch typ.kind & kindMask {
default:
throw("can't happen")
case kindArray:
at := (*arraytype)(unsafe.Pointer(typ))
for i := uintptr(0); i < at.len; i++ {
if off < at.elem.size {
cgoCheckUsingType(at.elem, src, off, size)
}
src = add(src, at.elem.size)
skipped := off
if skipped > at.elem.size {
skipped = at.elem.size
}
checked := at.elem.size - skipped
off -= skipped
if size <= checked {
return
}
size -= checked
}
case kindStruct:
st := (*structtype)(unsafe.Pointer(typ))
for _, f := range st.fields {
if off < f.typ.size {
cgoCheckUsingType(f.typ, src, off, size)
}
src = add(src, f.typ.size)
skipped := off
if skipped > f.typ.size {
skipped = f.typ.size
}
checked := f.typ.size - skipped
off -= skipped
if size <= checked {
return
}
size -= checked
}
}
}
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
// This file contains the implementation of Go channels.
// Invariants:
// At least one of c.sendq and c.recvq is empty,
// except for the case of an unbuffered channel with a single goroutine
// blocked on it for both sending and receiving using a select statement,
// in which case the length of c.sendq and c.recvq is limited only by the
// size of the select statement.
//
// For buffered channels, also:
// c.qcount > 0 implies that c.recvq is empty.
// c.qcount < c.dataqsiz implies that c.sendq is empty.
import (
"internal/abi"
"runtime/internal/atomic"
"runtime/internal/math"
"unsafe"
)
const (
maxAlign = 8
hchanSize = unsafe.Sizeof(hchan{}) + uintptr(-int(unsafe.Sizeof(hchan{}))&(maxAlign-1))
debugChan = false
)
type hchan struct {
qcount uint // total data in the queue
dataqsiz uint // size of the circular queue
buf unsafe.Pointer // points to an array of dataqsiz elements
elemsize uint16
closed uint32
elemtype *_type // element type
sendx uint // send index
recvx uint // receive index
recvq waitq // list of recv waiters
sendq waitq // list of send waiters
// lock protects all fields in hchan, as well as several
// fields in sudogs blocked on this channel.
//
// Do not change another G's status while holding this lock
// (in particular, do not ready a G), as this can deadlock
// with stack shrinking.
lock mutex
}
type waitq struct {
first *sudog
last *sudog
}
//go:linkname reflect_makechan reflect.makechan
func reflect_makechan(t *chantype, size int) *hchan {
return makechan(t, size)
}
func makechan64(t *chantype, size int64) *hchan {
if int64(int(size)) != size {
panic(plainError("makechan: size out of range"))
}
return makechan(t, int(size))
}
func makechan(t *chantype, size int) *hchan {
elem := t.elem
// compiler checks this but be safe.
if elem.size >= 1<<16 {
throw("makechan: invalid channel element type")
}
if hchanSize%maxAlign != 0 || elem.align > maxAlign {
throw("makechan: bad alignment")
}
mem, overflow := math.MulUintptr(elem.size, uintptr(size))
if overflow || mem > maxAlloc-hchanSize || size < 0 {
panic(plainError("makechan: size out of range"))
}
// Hchan does not contain pointers interesting for GC when elements stored in buf do not contain pointers.
// buf points into the same allocation, elemtype is persistent.
// SudoG's are referenced from their owning thread so they can't be collected.
// TODO(dvyukov,rlh): Rethink when collector can move allocated objects.
var c *hchan
switch {
case mem == 0:
// Queue or element size is zero.
c = (*hchan)(mallocgc(hchanSize, nil, true))
// Race detector uses this location for synchronization.
c.buf = c.raceaddr()
case elem.ptrdata == 0:
// Elements do not contain pointers.
// Allocate hchan and buf in one call.
c = (*hchan)(mallocgc(hchanSize+mem, nil, true))
c.buf = add(unsafe.Pointer(c), hchanSize)
default:
// Elements contain pointers.
c = new(hchan)
c.buf = mallocgc(mem, elem, true)
}
c.elemsize = uint16(elem.size)
c.elemtype = elem
c.dataqsiz = uint(size)
lockInit(&c.lock, lockRankHchan)
if debugChan {
print("makechan: chan=", c, "; elemsize=", elem.size, "; dataqsiz=", size, "\n")
}
return c
}
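// An illustrative sketch (not part of the runtime) mapping ordinary make
// calls to the three allocation cases in makechan above; the code only
// exercises the public API.
//
//	c0 := make(chan struct{}, 4) // mem == 0: zero-size elements, buf set to c.raceaddr()
//	c1 := make(chan int, 4)      // elem.ptrdata == 0: hchan and buf in a single allocation
//	c2 := make(chan *int, 4)     // default: pointer elements, buf allocated separately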
// chanbuf(c, i) is pointer to the i'th slot in the buffer.
func chanbuf(c *hchan, i uint) unsafe.Pointer {
return add(c.buf, uintptr(i)*uintptr(c.elemsize))
}
// full reports whether a send on c would block (that is, the channel is full).
// It uses a single word-sized read of mutable state, so although
// the answer is instantaneously true, the correct answer may have changed
// by the time the calling function receives the return value.
func full(c *hchan) bool {
// c.dataqsiz is immutable (never written after the channel is created)
// so it is safe to read at any time during channel operation.
if c.dataqsiz == 0 {
// Assumes that a pointer read is relaxed-atomic.
return c.recvq.first == nil
}
// Assumes that a uint read is relaxed-atomic.
return c.qcount == c.dataqsiz
}
// entry point for c <- x from compiled code.
//
//go:nosplit
func chansend1(c *hchan, elem unsafe.Pointer) {
chansend(c, elem, true, getcallerpc())
}
/*
* generic single channel send/recv
* If block is false,
* then the protocol will not
* sleep but return if it could
* not complete.
*
* sleep can wake up with g.param == nil
* when a channel involved in the sleep has
* been closed. it is easiest to loop and re-run
* the operation; we'll see that it's now closed.
*/
func chansend(c *hchan, ep unsafe.Pointer, block bool, callerpc uintptr) bool {
if c == nil {
if !block {
return false
}
gopark(nil, nil, waitReasonChanSendNilChan, traceEvGoStop, 2)
throw("unreachable")
}
if debugChan {
print("chansend: chan=", c, "\n")
}
if raceenabled {
racereadpc(c.raceaddr(), callerpc, abi.FuncPCABIInternal(chansend))
}
// Fast path: check for failed non-blocking operation without acquiring the lock.
//
// After observing that the channel is not closed, we observe that the channel is
// not ready for sending. Each of these observations is a single word-sized read
// (first c.closed and second full()).
// Because a closed channel cannot transition from 'ready for sending' to
// 'not ready for sending', even if the channel is closed between the two observations,
// they imply a moment between the two when the channel was both not yet closed
// and not ready for sending. We behave as if we observed the channel at that moment,
// and report that the send cannot proceed.
//
// It is okay if the reads are reordered here: if we observe that the channel is not
// ready for sending and then observe that it is not closed, that implies that the
// channel wasn't closed during the first observation. However, nothing here
// guarantees forward progress. We rely on the side effects of lock release in
// chanrecv() and closechan() to update this thread's view of c.closed and full().
if !block && c.closed == 0 && full(c) {
return false
}
var t0 int64
if blockprofilerate > 0 {
t0 = cputicks()
}
lock(&c.lock)
if c.closed != 0 {
unlock(&c.lock)
panic(plainError("send on closed channel"))
}
if sg := c.recvq.dequeue(); sg != nil {
// Found a waiting receiver. We pass the value we want to send
// directly to the receiver, bypassing the channel buffer (if any).
send(c, sg, ep, func() { unlock(&c.lock) }, 3)
return true
}
if c.qcount < c.dataqsiz {
// Space is available in the channel buffer. Enqueue the element to send.
qp := chanbuf(c, c.sendx)
if raceenabled {
racenotify(c, c.sendx, nil)
}
typedmemmove(c.elemtype, qp, ep)
c.sendx++
if c.sendx == c.dataqsiz {
c.sendx = 0
}
c.qcount++
unlock(&c.lock)
return true
}
if !block {
unlock(&c.lock)
return false
}
// Block on the channel. Some receiver will complete our operation for us.
gp := getg()
mysg := acquireSudog()
mysg.releasetime = 0
if t0 != 0 {
mysg.releasetime = -1
}
// No stack splits between assigning elem and enqueuing mysg
// on gp.waiting where copystack can find it.
mysg.elem = ep
mysg.waitlink = nil
mysg.g = gp
mysg.isSelect = false
mysg.c = c
gp.waiting = mysg
gp.param = nil
c.sendq.enqueue(mysg)
// Signal to anyone trying to shrink our stack that we're about
// to park on a channel. The window between when this G's status
// changes and when we set gp.activeStackChans is not safe for
// stack shrinking.
gp.parkingOnChan.Store(true)
gopark(chanparkcommit, unsafe.Pointer(&c.lock), waitReasonChanSend, traceEvGoBlockSend, 2)
// Ensure the value being sent is kept alive until the
// receiver copies it out. The sudog has a pointer to the
// stack object, but sudogs aren't considered as roots of the
// stack tracer.
KeepAlive(ep)
// someone woke us up.
if mysg != gp.waiting {
throw("G waiting list is corrupted")
}
gp.waiting = nil
gp.activeStackChans = false
closed := !mysg.success
gp.param = nil
if mysg.releasetime > 0 {
blockevent(mysg.releasetime-t0, 2)
}
mysg.c = nil
releaseSudog(mysg)
if closed {
if c.closed == 0 {
throw("chansend: spurious wakeup")
}
panic(plainError("send on closed channel"))
}
return true
}
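// An illustrative sketch (not part of the runtime) of chansend's
// user-visible cases:
//
//	c := make(chan int, 1)
//	c <- 1 // buffer slot free: enqueue and return
//	select {
//	case c <- 2: // non-blocking send: chansend sees full(c) and fails
//	default:
//	}
//	close(c)
//	c <- 3 // panics: send on closed channel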
// send processes a send operation on an empty channel c.
// The value ep sent by the sender is copied to the receiver sg.
// The receiver is then woken up to go on its merry way.
// Channel c must be empty and locked. send unlocks c with unlockf.
// sg must already be dequeued from c.
// ep must be non-nil and point to the heap or the caller's stack.
func send(c *hchan, sg *sudog, ep unsafe.Pointer, unlockf func(), skip int) {
if raceenabled {
if c.dataqsiz == 0 {
racesync(c, sg)
} else {
// Pretend we go through the buffer, even though
// we copy directly. Note that we need to increment
// the head/tail locations only when raceenabled.
racenotify(c, c.recvx, nil)
racenotify(c, c.recvx, sg)
c.recvx++
if c.recvx == c.dataqsiz {
c.recvx = 0
}
c.sendx = c.recvx // c.sendx = (c.sendx+1) % c.dataqsiz
}
}
if sg.elem != nil {
sendDirect(c.elemtype, sg, ep)
sg.elem = nil
}
gp := sg.g
unlockf()
gp.param = unsafe.Pointer(sg)
sg.success = true
if sg.releasetime != 0 {
sg.releasetime = cputicks()
}
goready(gp, skip+1)
}
// Sends and receives on unbuffered or empty-buffered channels are the
// only operations where one running goroutine writes to the stack of
// another running goroutine. The GC assumes that stack writes only
// happen when the goroutine is running and are only done by that
// goroutine. Using a write barrier is sufficient to make up for
// violating that assumption, but the write barrier has to work.
// typedmemmove will call bulkBarrierPreWrite, but the target bytes
// are not in the heap, so that will not help. We arrange to call
// memmove and typeBitsBulkBarrier instead.
func sendDirect(t *_type, sg *sudog, src unsafe.Pointer) {
// src is on our stack, dst is a slot on another stack.
// Once we read sg.elem out of sg, it will no longer
// be updated if the destination's stack gets copied (shrunk).
// So make sure that no preemption points can happen between read & use.
dst := sg.elem
typeBitsBulkBarrier(t, uintptr(dst), uintptr(src), t.size)
// No need for cgo write barrier checks because dst is always
// Go memory.
memmove(dst, src, t.size)
}
func recvDirect(t *_type, sg *sudog, dst unsafe.Pointer) {
// dst is on our stack or the heap, src is on another stack.
// The channel is locked, so src will not move during this
// operation.
src := sg.elem
typeBitsBulkBarrier(t, uintptr(dst), uintptr(src), t.size)
memmove(dst, src, t.size)
}
func closechan(c *hchan) {
if c == nil {
panic(plainError("close of nil channel"))
}
lock(&c.lock)
if c.closed != 0 {
unlock(&c.lock)
panic(plainError("close of closed channel"))
}
if raceenabled {
callerpc := getcallerpc()
racewritepc(c.raceaddr(), callerpc, abi.FuncPCABIInternal(closechan))
racerelease(c.raceaddr())
}
c.closed = 1
var glist gList
// release all readers
for {
sg := c.recvq.dequeue()
if sg == nil {
break
}
if sg.elem != nil {
typedmemclr(c.elemtype, sg.elem)
sg.elem = nil
}
if sg.releasetime != 0 {
sg.releasetime = cputicks()
}
gp := sg.g
gp.param = unsafe.Pointer(sg)
sg.success = false
if raceenabled {
raceacquireg(gp, c.raceaddr())
}
glist.push(gp)
}
// release all writers (they will panic)
for {
sg := c.sendq.dequeue()
if sg == nil {
break
}
sg.elem = nil
if sg.releasetime != 0 {
sg.releasetime = cputicks()
}
gp := sg.g
gp.param = unsafe.Pointer(sg)
sg.success = false
if raceenabled {
raceacquireg(gp, c.raceaddr())
}
glist.push(gp)
}
unlock(&c.lock)
// Ready all Gs now that we've dropped the channel lock.
for !glist.empty() {
gp := glist.pop()
gp.schedlink = 0
goready(gp, 3)
}
}
// empty reports whether a read from c would block (that is, the channel is
// empty). It uses a single atomic read of mutable state.
func empty(c *hchan) bool {
// c.dataqsiz is immutable.
if c.dataqsiz == 0 {
return atomic.Loadp(unsafe.Pointer(&c.sendq.first)) == nil
}
return atomic.Loaduint(&c.qcount) == 0
}
// entry points for <- c from compiled code.
//
//go:nosplit
func chanrecv1(c *hchan, elem unsafe.Pointer) {
chanrecv(c, elem, true)
}
//go:nosplit
func chanrecv2(c *hchan, elem unsafe.Pointer) (received bool) {
_, received = chanrecv(c, elem, true)
return
}
// chanrecv receives on channel c and writes the received data to ep.
// ep may be nil, in which case received data is ignored.
// If block == false and no elements are available, returns (false, false).
// Otherwise, if c is closed, zeros *ep and returns (true, false).
// Otherwise, fills in *ep with an element and returns (true, true).
// A non-nil ep must point to the heap or the caller's stack.
func chanrecv(c *hchan, ep unsafe.Pointer, block bool) (selected, received bool) {
// raceenabled: don't need to check ep, as it is always on the stack
// or is new memory allocated by reflect.
if debugChan {
print("chanrecv: chan=", c, "\n")
}
if c == nil {
if !block {
return
}
gopark(nil, nil, waitReasonChanReceiveNilChan, traceEvGoStop, 2)
throw("unreachable")
}
// Fast path: check for failed non-blocking operation without acquiring the lock.
if !block && empty(c) {
// After observing that the channel is not ready for receiving, we observe whether the
// channel is closed.
//
// Reordering of these checks could lead to incorrect behavior when racing with a close.
// For example, if the channel was open and not empty, was closed, and then drained,
// reordered reads could incorrectly indicate "open and empty". To prevent reordering,
// we use atomic loads for both checks, and rely on emptying and closing to happen in
// separate critical sections under the same lock. This assumption fails when closing
// an unbuffered channel with a blocked send, but that is an error condition anyway.
if atomic.Load(&c.closed) == 0 {
// Because a channel cannot be reopened, the later observation of the channel
// being not closed implies that it was also not closed at the moment of the
// first observation. We behave as if we observed the channel at that moment
// and report that the receive cannot proceed.
return
}
// The channel is irreversibly closed. Re-check whether the channel has any pending data
// to receive, which could have arrived between the empty and closed checks above.
// Sequential consistency is also required here, when racing with such a send.
if empty(c) {
// The channel is irreversibly closed and empty.
if raceenabled {
raceacquire(c.raceaddr())
}
if ep != nil {
typedmemclr(c.elemtype, ep)
}
return true, false
}
}
var t0 int64
if blockprofilerate > 0 {
t0 = cputicks()
}
lock(&c.lock)
if c.closed != 0 {
if c.qcount == 0 {
if raceenabled {
raceacquire(c.raceaddr())
}
unlock(&c.lock)
if ep != nil {
typedmemclr(c.elemtype, ep)
}
return true, false
}
// The channel has been closed, but the channel's buffer still has data.
} else {
// The channel is not closed; check for a waiting sender.
if sg := c.sendq.dequeue(); sg != nil {
// Found a waiting sender. If buffer is size 0, receive value
// directly from sender. Otherwise, receive from head of queue
// and add sender's value to the tail of the queue (both map to
// the same buffer slot because the queue is full).
recv(c, sg, ep, func() { unlock(&c.lock) }, 3)
return true, true
}
}
if c.qcount > 0 {
// Receive directly from queue
qp := chanbuf(c, c.recvx)
if raceenabled {
racenotify(c, c.recvx, nil)
}
if ep != nil {
typedmemmove(c.elemtype, ep, qp)
}
typedmemclr(c.elemtype, qp)
c.recvx++
if c.recvx == c.dataqsiz {
c.recvx = 0
}
c.qcount--
unlock(&c.lock)
return true, true
}
if !block {
unlock(&c.lock)
return false, false
}
// no sender available: block on this channel.
gp := getg()
mysg := acquireSudog()
mysg.releasetime = 0
if t0 != 0 {
mysg.releasetime = -1
}
// No stack splits between assigning elem and enqueuing mysg
// on gp.waiting where copystack can find it.
mysg.elem = ep
mysg.waitlink = nil
gp.waiting = mysg
mysg.g = gp
mysg.isSelect = false
mysg.c = c
gp.param = nil
c.recvq.enqueue(mysg)
// Signal to anyone trying to shrink our stack that we're about
// to park on a channel. The window between when this G's status
// changes and when we set gp.activeStackChans is not safe for
// stack shrinking.
gp.parkingOnChan.Store(true)
gopark(chanparkcommit, unsafe.Pointer(&c.lock), waitReasonChanReceive, traceEvGoBlockRecv, 2)
// someone woke us up
if mysg != gp.waiting {
throw("G waiting list is corrupted")
}
gp.waiting = nil
gp.activeStackChans = false
if mysg.releasetime > 0 {
blockevent(mysg.releasetime-t0, 2)
}
success := mysg.success
gp.param = nil
mysg.c = nil
releaseSudog(mysg)
return true, success
}
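// An illustrative sketch (not part of the runtime) of chanrecv's return
// values as seen through the v, ok receive form:
//
//	c := make(chan int, 1)
//	c <- 7
//	close(c)
//	v, ok := <-c // buffered data survives close: 7, true
//	v, ok = <-c  // closed and drained: zero value, false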
// recv processes a receive operation on a full channel c.
// There are 2 parts:
// 1. The value sent by the sender sg is put into the channel
// and the sender is woken up to go on its merry way.
// 2. The value received by the receiver (the current G) is
// written to ep.
//
// For synchronous channels, both values are the same.
// For asynchronous channels, the receiver gets its data from
// the channel buffer and the sender's data is put in the
// channel buffer.
// Channel c must be full and locked. recv unlocks c with unlockf.
// sg must already be dequeued from c.
// A non-nil ep must point to the heap or the caller's stack.
func recv(c *hchan, sg *sudog, ep unsafe.Pointer, unlockf func(), skip int) {
if c.dataqsiz == 0 {
if raceenabled {
racesync(c, sg)
}
if ep != nil {
// copy data from sender
recvDirect(c.elemtype, sg, ep)
}
} else {
// Queue is full. Take the item at the
// head of the queue. Make the sender enqueue
// its item at the tail of the queue. Since the
// queue is full, those are both the same slot.
qp := chanbuf(c, c.recvx)
if raceenabled {
racenotify(c, c.recvx, nil)
racenotify(c, c.recvx, sg)
}
// copy data from queue to receiver
if ep != nil {
typedmemmove(c.elemtype, ep, qp)
}
// copy data from sender to queue
typedmemmove(c.elemtype, qp, sg.elem)
c.recvx++
if c.recvx == c.dataqsiz {
c.recvx = 0
}
c.sendx = c.recvx // c.sendx = (c.sendx+1) % c.dataqsiz
}
sg.elem = nil
gp := sg.g
unlockf()
gp.param = unsafe.Pointer(sg)
sg.success = true
if sg.releasetime != 0 {
sg.releasetime = cputicks()
}
goready(gp, skip+1)
}
func chanparkcommit(gp *g, chanLock unsafe.Pointer) bool {
// There are unlocked sudogs that point into gp's stack. Stack
// copying must lock the channels of those sudogs.
// Set activeStackChans here instead of before we try parking
// because we could self-deadlock in stack growth on the
// channel lock.
gp.activeStackChans = true
// Mark that it's safe for stack shrinking to occur now,
// because any thread acquiring this G's stack for shrinking
// is guaranteed to observe activeStackChans after this store.
gp.parkingOnChan.Store(false)
// Make sure we unlock after setting activeStackChans and
// unsetting parkingOnChan. The moment we unlock chanLock
// we risk gp getting readied by a channel operation and
// so gp could continue running before everything before
// the unlock is visible (even to gp itself).
unlock((*mutex)(chanLock))
return true
}
// compiler implements
//
// select {
// case c <- v:
// ... foo
// default:
// ... bar
// }
//
// as
//
// if selectnbsend(c, v) {
// ... foo
// } else {
// ... bar
// }
func selectnbsend(c *hchan, elem unsafe.Pointer) (selected bool) {
return chansend(c, elem, false, getcallerpc())
}
// compiler implements
//
// select {
// case v, ok = <-c:
// ... foo
// default:
// ... bar
// }
//
// as
//
// if selected, ok = selectnbrecv(&v, c); selected {
// ... foo
// } else {
// ... bar
// }
func selectnbrecv(elem unsafe.Pointer, c *hchan) (selected, received bool) {
return chanrecv(c, elem, false)
}
//go:linkname reflect_chansend reflect.chansend
func reflect_chansend(c *hchan, elem unsafe.Pointer, nb bool) (selected bool) {
return chansend(c, elem, !nb, getcallerpc())
}
//go:linkname reflect_chanrecv reflect.chanrecv
func reflect_chanrecv(c *hchan, nb bool, elem unsafe.Pointer) (selected bool, received bool) {
return chanrecv(c, elem, !nb)
}
//go:linkname reflect_chanlen reflect.chanlen
func reflect_chanlen(c *hchan) int {
if c == nil {
return 0
}
return int(c.qcount)
}
//go:linkname reflectlite_chanlen internal/reflectlite.chanlen
func reflectlite_chanlen(c *hchan) int {
if c == nil {
return 0
}
return int(c.qcount)
}
//go:linkname reflect_chancap reflect.chancap
func reflect_chancap(c *hchan) int {
if c == nil {
return 0
}
return int(c.dataqsiz)
}
//go:linkname reflect_chanclose reflect.chanclose
func reflect_chanclose(c *hchan) {
closechan(c)
}
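// An illustrative sketch (not part of the runtime): the reflect entry
// points above back the channel methods of reflect.Value (imports of
// reflect and fmt assumed).
//
//	c := make(chan int, 1)
//	v := reflect.ValueOf(c)
//	v.TrySend(reflect.ValueOf(42)) // reflect_chansend with nb=true
//	fmt.Println(v.Len(), v.Cap())  // 1 1: reflect_chanlen, reflect_chancap
//	v.Close()                      // reflect_chanclose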
func (q *waitq) enqueue(sgp *sudog) {
sgp.next = nil
x := q.last
if x == nil {
sgp.prev = nil
q.first = sgp
q.last = sgp
return
}
sgp.prev = x
x.next = sgp
q.last = sgp
}
func (q *waitq) dequeue() *sudog {
for {
sgp := q.first
if sgp == nil {
return nil
}
y := sgp.next
if y == nil {
q.first = nil
q.last = nil
} else {
y.prev = nil
q.first = y
sgp.next = nil // mark as removed (see dequeueSudoG)
}
// if a goroutine was put on this queue because of a
// select, there is a small window between the goroutine
// being woken up by a different case and it grabbing the
// channel locks. Once it has the lock
// it removes itself from the queue, so we won't see it after that.
// We use a flag in the G struct to tell us when someone
// else has won the race to signal this goroutine but the goroutine
// hasn't removed itself from the queue yet.
if sgp.isSelect && !sgp.g.selectDone.CompareAndSwap(0, 1) {
continue
}
return sgp
}
}
func (c *hchan) raceaddr() unsafe.Pointer {
// Treat read-like and write-like operations on the channel to
// happen at this address. Avoid using the address of qcount
// or dataqsiz, because the len() and cap() builtins read
// those addresses, and we don't want them racing with
// operations like close().
return unsafe.Pointer(&c.buf)
}
func racesync(c *hchan, sg *sudog) {
racerelease(chanbuf(c, 0))
raceacquireg(sg.g, chanbuf(c, 0))
racereleaseg(sg.g, chanbuf(c, 0))
raceacquire(chanbuf(c, 0))
}
// Notify the race detector of a send or receive involving buffer entry idx
// and a channel c or its communicating partner sg.
// This function handles the special case of c.elemsize==0.
func racenotify(c *hchan, idx uint, sg *sudog) {
// We could have passed the unsafe.Pointer corresponding to entry idx
// instead of idx itself. However, in a future version of this function,
// we can use idx to better handle the case of elemsize==0.
// A future improvement to the detector is to call TSan with c and idx:
// this way, Go can continue to not allocate buffer entries for channels
// of elemsize==0, yet the race detector can be made to handle multiple
// sync objects underneath the hood (one sync object per idx).
qp := chanbuf(c, idx)
// When elemsize==0, we don't allocate a full buffer for the channel.
// Instead of individual buffer entries, the race detector uses the
// c.buf as the only buffer entry. This simplification prevents us from
// following the memory model's happens-before rules (rules that are
// implemented in racereleaseacquire). Instead, we accumulate happens-before
// information in the synchronization object associated with c.buf.
if c.elemsize == 0 {
if sg == nil {
raceacquire(qp)
racerelease(qp)
} else {
raceacquireg(sg.g, qp)
racereleaseg(sg.g, qp)
}
} else {
if sg == nil {
racereleaseacquire(qp)
} else {
racereleaseacquireg(sg.g, qp)
}
}
}
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import "unsafe"
func checkptrAlignment(p unsafe.Pointer, elem *_type, n uintptr) {
// nil pointer is always suitably aligned (#47430).
if p == nil {
return
}
// Check that (*[n]elem)(p) is appropriately aligned.
// Note that we allow unaligned pointers if the types they point to contain
// no pointers themselves. See issue 37298.
// TODO(mdempsky): What about fieldAlign?
if elem.ptrdata != 0 && uintptr(p)&(uintptr(elem.align)-1) != 0 {
throw("checkptr: misaligned pointer conversion")
}
// Check that (*[n]elem)(p) doesn't straddle multiple heap objects.
// TODO(mdempsky): Fix #46938 so we don't need to worry about overflow here.
if checkptrStraddles(p, n*elem.size) {
throw("checkptr: converted pointer straddles multiple allocations")
}
}
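// An illustrative sketch (not part of the runtime) of a conversion the
// alignment check above rejects when the program is built with
// -gcflags=all=-d=checkptr (enabled by default under -race). The target
// type must contain pointers for the check to apply; see issue 37298.
//
//	b := make([]byte, 16)
//	_ = (**int)(unsafe.Pointer(&b[1])) // throws: misaligned pointer conversion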
// checkptrStraddles reports whether the first size bytes of memory
// addressed by ptr are known to straddle more than one Go allocation.
func checkptrStraddles(ptr unsafe.Pointer, size uintptr) bool {
if size <= 1 {
return false
}
// Check that add(ptr, size-1) won't overflow. This avoids the risk
// of producing an illegal pointer value (assuming ptr is legal).
if uintptr(ptr) >= -(size - 1) {
return true
}
end := add(ptr, size-1)
// TODO(mdempsky): Detect when [ptr, end] contains Go allocations,
// but neither ptr nor end point into one themselves.
return checkptrBase(ptr) != checkptrBase(end)
}
func checkptrArithmetic(p unsafe.Pointer, originals []unsafe.Pointer) {
if 0 < uintptr(p) && uintptr(p) < minLegalPointer {
throw("checkptr: pointer arithmetic computed bad pointer value")
}
// Check that if the computed pointer p points into a heap
// object, then one of the original pointers must have pointed
// into the same object.
base := checkptrBase(p)
if base == 0 {
return
}
for _, original := range originals {
if base == checkptrBase(original) {
return
}
}
throw("checkptr: pointer arithmetic result points to invalid allocation")
}
// checkptrBase returns the base address for the allocation containing
// the address p.
//
// Importantly, if p1 and p2 point into the same variable, then
// checkptrBase(p1) == checkptrBase(p2). However, the converse/inverse
// is not necessarily true as allocations can have trailing padding,
// and multiple variables may be packed into a single allocation.
func checkptrBase(p unsafe.Pointer) uintptr {
// stack
if gp := getg(); gp.stack.lo <= uintptr(p) && uintptr(p) < gp.stack.hi {
// TODO(mdempsky): Walk the stack to identify the
// specific stack frame or even stack object that p
// points into.
//
// In the mean time, use "1" as a pseudo-address to
// represent the stack. This is an invalid address on
// all platforms, so it's guaranteed to be distinct
// from any of the addresses we might return below.
return 1
}
// heap (must check after stack because of #35068)
if base, _, _ := findObject(uintptr(p), 0, 0); base != 0 {
return base
}
// data or bss
for _, datap := range activeModules() {
if datap.data <= uintptr(p) && uintptr(p) < datap.edata {
return datap.data
}
if datap.bss <= uintptr(p) && uintptr(p) < datap.ebss {
return datap.bss
}
}
return 0
}
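// An illustrative sketch (not part of the runtime) of the base-identity
// property documented above, which checkptrArithmetic relies on: pointers
// into the same allocation share a base.
//
//	b := make([]byte, 64)
//	p1 := unsafe.Pointer(&b[8])
//	p2 := unsafe.Pointer(&b[32])
//	// checkptrBase(p1) == checkptrBase(p2): both resolve to b's allocation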
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
// inf2one returns a signed 1 if f is an infinity and a signed 0 otherwise.
// The sign of the result is the sign of f.
func inf2one(f float64) float64 {
g := 0.0
if isInf(f) {
g = 1.0
}
return copysign(g, f)
}
func complex128div(n complex128, m complex128) complex128 {
var e, f float64 // complex(e, f) = n/m
// Algorithm for robust complex division as described in
// Robert L. Smith: Algorithm 116: Complex division. Commun. ACM 5(8): 435 (1962).
if abs(real(m)) >= abs(imag(m)) {
ratio := imag(m) / real(m)
denom := real(m) + ratio*imag(m)
e = (real(n) + imag(n)*ratio) / denom
f = (imag(n) - real(n)*ratio) / denom
} else {
ratio := real(m) / imag(m)
denom := imag(m) + ratio*real(m)
e = (real(n)*ratio + imag(n)) / denom
f = (imag(n)*ratio - real(n)) / denom
}
if isNaN(e) && isNaN(f) {
// Correct final result to infinities and zeros if applicable.
// Matches C99: ISO/IEC 9899:1999 - G.5.1 Multiplicative operators.
a, b := real(n), imag(n)
c, d := real(m), imag(m)
switch {
case m == 0 && (!isNaN(a) || !isNaN(b)):
e = copysign(inf, c) * a
f = copysign(inf, c) * b
case (isInf(a) || isInf(b)) && isFinite(c) && isFinite(d):
a = inf2one(a)
b = inf2one(b)
e = inf * (a*c + b*d)
f = inf * (b*c - a*d)
case (isInf(c) || isInf(d)) && isFinite(a) && isFinite(b):
c = inf2one(c)
d = inf2one(d)
e = 0 * (a*c + b*d)
f = 0 * (b*c - a*d)
}
}
return complex(e, f)
}
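// An illustrative sketch (not part of the runtime) of why Smith's scaling
// matters: the textbook formula divides by real(m)*real(m)+imag(m)*imag(m),
// which overflows below even though the quotient is representable.
//
//	n := complex(1e308, 1e308)
//	m := complex(1e308, 1.0)
//	_ = n / m // complex128div returns approximately (1+1i); the naive
//	          // denominator would overflow float64 to +Inf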
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package coverage
import (
"fmt"
"internal/coverage"
"io"
"reflect"
"sync/atomic"
"unsafe"
)
// WriteMetaDir writes a coverage meta-data file for the currently
// running program to the directory specified in 'dir'. An error will
// be returned if the operation can't be completed successfully (for
// example, if the currently running program was not built with
// "-cover", or if the directory does not exist).
func WriteMetaDir(dir string) error {
if !finalHashComputed {
return fmt.Errorf("error: no meta-data available (binary not built with -cover?)")
}
return emitMetaDataToDirectory(dir, getCovMetaList())
}
// WriteMeta writes the meta-data content (the payload that would
// normally be emitted to a meta-data file) for the currently running
// program to the writer 'w'. An error will be returned if the
// operation can't be completed successfully (for example, if the
// currently running program was not built with "-cover", or if a
// write fails).
func WriteMeta(w io.Writer) error {
if w == nil {
return fmt.Errorf("error: nil writer in WriteMeta")
}
if !finalHashComputed {
return fmt.Errorf("error: no meta-data available (binary not built with -cover?)")
}
ml := getCovMetaList()
return writeMetaData(w, ml, cmode, cgran, finalHash)
}
// WriteCountersDir writes a coverage counter-data file for the
// currently running program to the directory specified in 'dir'. An
// error will be returned if the operation can't be completed
// successfully (for example, if the currently running program was not
// built with "-cover", or if the directory does not exist). The
// counter data written will be a snapshot taken at the point of the
// call.
func WriteCountersDir(dir string) error {
if cmode != coverage.CtrModeAtomic {
return fmt.Errorf("WriteCountersDir invoked for program built with -covermode=%s (please use -covermode=atomic)", cmode.String())
}
return emitCounterDataToDirectory(dir)
}
// WriteCounters writes coverage counter-data content for the
// currently running program to the writer 'w'. An error will be
// returned if the operation can't be completed successfully (for
// example, if the currently running program was not built with
// "-cover", or if a write fails). The counter data written will be a
// snapshot taken at the point of the invocation.
func WriteCounters(w io.Writer) error {
if w == nil {
return fmt.Errorf("error: nil writer in WriteCounters")
}
if cmode != coverage.CtrModeAtomic {
return fmt.Errorf("WriteCounters invoked for program built with -covermode=%s (please use -covermode=atomic)", cmode.String())
}
// Ask the runtime for the list of coverage counter symbols.
cl := getCovCounterList()
if len(cl) == 0 {
return fmt.Errorf("program not built with -cover")
}
if !finalHashComputed {
return fmt.Errorf("meta-data not written yet, unable to write counter data")
}
pm := getCovPkgMap()
s := &emitState{
counterlist: cl,
pkgmap: pm,
}
return s.emitCounterDataToWriter(w)
}
// ClearCounters clears/resets all coverage counter variables in the
// currently running program. It returns an error if the program in
// question was not built with the "-cover" flag. Clearing of coverage
// counters is also not supported for programs not using atomic
// counter mode (see more detailed comments below for the rationale
// here).
func ClearCounters() error {
cl := getCovCounterList()
if len(cl) == 0 {
return fmt.Errorf("program not built with -cover")
}
if cmode != coverage.CtrModeAtomic {
return fmt.Errorf("ClearCounters invoked for program built with -covermode=%s (please use -covermode=atomic)", cmode.String())
}
// Implementation note: this function would be faster and simpler
// if we could just zero out the entire counter array, but for the
// moment we go through and zero out just the slots in the array
// corresponding to the counter values. We do this to avoid the
// following bad scenario: suppose that a user builds their Go
// program with "-cover", and that program has a function (call it
// main.XYZ) that invokes ClearCounters:
//
// func XYZ() {
// ... do some stuff ...
// coverage.ClearCounters()
// if someCondition { <<--- HERE
// ...
// }
// }
//
// At the point where ClearCounters executes, main.XYZ has not yet
// finished running, thus as soon as the call returns the line
// marked "HERE" above will trigger the writing of a non-zero
// value into main.XYZ's counter slab. However since we've just
// finished clearing the entire counter segment, we will have lost
// the values in the prolog portion of main.XYZ's counter slab
// (nctrs, pkgid, funcid). This means that later on at the end of
// program execution as we walk through the entire counter array
// for the program looking for executed functions, we'll zoom past
// main.XYZ's prolog (which was zero'd) and hit the non-zero
// counter value corresponding to the "HERE" block, which will
// then be interpreted as the start of another live function.
// Things will go downhill from there.
//
// This same scenario is also a potential risk if the program is
// running on an architecture that permits reordering of
// writes/stores, since the inconsistency described above could
// arise here. Example scenario:
//
// func ABC() {
// ... // prolog
// if alwaysTrue() {
// XYZ() // counter update here
// }
// }
//
// In the instrumented version of ABC, the prolog of the function
// will contain a series of stores to the initial portion of the
// counter array to write number-of-counters, pkgid, funcid. Later
// in the function there is also a store to increment a counter
// for the block containing the call to XYZ(). If the CPU is
// allowed to reorder stores and decides to issue the XYZ store
// before the prolog stores, this could be observable as an
// inconsistency similar to the one above. Hence the requirement
// for atomic counter mode: according to package atomic docs,
// "...operations that happen in a specific order on one thread,
// will always be observed to happen in exactly that order by
// another thread". Thus we can be sure that there will be no
// inconsistency when reading the counter array from the thread
// running ClearCounters.
var sd []atomic.Uint32
bufHdr := (*reflect.SliceHeader)(unsafe.Pointer(&sd))
for _, c := range cl {
bufHdr.Data = uintptr(unsafe.Pointer(c.Counters))
bufHdr.Len = int(c.Len)
bufHdr.Cap = int(c.Len)
for i := 0; i < len(sd); i++ {
// Skip ahead until the next non-zero value.
sdi := sd[i].Load()
if sdi == 0 {
continue
}
// We found a function that was executed; clear its counters.
nCtrs := sdi
for j := 0; j < int(nCtrs); j++ {
sd[i+coverage.FirstCtrOffset+j].Store(0)
}
// Move to next function.
i += coverage.FirstCtrOffset + int(nCtrs) - 1
}
}
return nil
}
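// An illustrative sketch (not part of this package): a long-running
// program built with "go build -cover" and -covermode=atomic can use the
// APIs above to snapshot and reset its own coverage data. The directory
// path is hypothetical and must already exist; imports of
// runtime/coverage and log are assumed.
//
//	dir := "/tmp/covdata"
//	if err := coverage.WriteMetaDir(dir); err != nil {
//		log.Fatal(err)
//	}
//	if err := coverage.WriteCountersDir(dir); err != nil {
//		log.Fatal(err)
//	}
//	if err := coverage.ClearCounters(); err != nil {
//		log.Fatal(err)
//	}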
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package coverage
import (
"crypto/md5"
"fmt"
"internal/coverage"
"internal/coverage/encodecounter"
"internal/coverage/encodemeta"
"internal/coverage/rtcov"
"io"
"os"
"path/filepath"
"reflect"
"runtime"
"sync/atomic"
"time"
"unsafe"
)
// This file contains functions that support the writing of data files
// emitted at the end of code coverage testing runs, from instrumented
// executables.
// getCovMetaList returns a list of meta-data blobs registered
// for the currently executing instrumented program. It is defined in the
// runtime.
func getCovMetaList() []rtcov.CovMetaBlob
// getCovCounterList returns a list of counter-data blobs registered
// for the currently executing instrumented program. It is defined in the
// runtime.
func getCovCounterList() []rtcov.CovCounterBlob
// getCovPkgMap returns a map storing the remapped package IDs for
// hard-coded runtime packages (see internal/coverage/pkgid.go for
// more on why hard-coded package IDs are needed). This function
// is defined in the runtime.
func getCovPkgMap() map[int]int
// emitState holds useful state information during the emit process.
//
// When an instrumented program finishes execution and starts the
// process of writing out coverage data, it's possible that an
// existing meta-data file already exists in the output directory. In
// this case openOutputFiles() below will leave the 'mf' field below
// as nil. If a new meta-data file is needed, field 'mfname' will be
// the final desired path of the meta file, 'mftmp' will be a
// temporary file, and 'mf' will be an open os.File pointer for
// 'mftmp'. The meta-data file payload will be written to 'mf', the
// temp file will be then closed and renamed (from 'mftmp' to
// 'mfname'), so as to ensure that the meta-data file is created
// atomically; we want this so that things work smoothly in cases
// where there are several instances of a given instrumented program
// all terminating at the same time and trying to create meta-data
// files simultaneously.
//
// For counter data files there is less chance of a collision, hence
// openOutputFiles() stores the counter data file in 'cfname' and
// then places the open *os.File into 'cf'.
type emitState struct {
mfname string // path of final meta-data output file
mftmp string // path to meta-data temp file (if needed)
mf *os.File // open os.File for meta-data temp file
cfname string // path of final counter data file
cftmp string // path to counter data temp file
cf *os.File // open os.File for counter data file
outdir string // output directory
// List of meta-data symbols obtained from the runtime
metalist []rtcov.CovMetaBlob
// List of counter-data symbols obtained from the runtime
counterlist []rtcov.CovCounterBlob
// Table to use for remapping hard-coded pkg ids.
pkgmap map[int]int
// emit debug trace output
debug bool
}
var (
// finalHash is computed at init time from the list of meta-data
// symbols registered during init. It is used both for writing the
// meta-data file and counter-data files.
finalHash [16]byte
// Set to true when we've computed finalHash + finalMetaLen.
finalHashComputed bool
// Total meta-data length.
finalMetaLen uint64
// Records whether we've already attempted to write meta-data.
metaDataEmitAttempted bool
// Counter mode for this instrumented program run.
cmode coverage.CounterMode
// Counter granularity for this instrumented program run.
cgran coverage.CounterGranularity
// Cached value of GOCOVERDIR environment variable.
goCoverDir string
// Copy of os.Args made at init time, converted into map format.
capturedOsArgs map[string]string
// Flag used in tests to signal that coverage data already written.
covProfileAlreadyEmitted bool
)
// fileType is used to select between counter-data files and
// meta-data files.
type fileType int
const (
noFile = 1 << iota
metaDataFile
counterDataFile
)
// emitMetaData emits the meta-data output file for this coverage run.
// This entry point is intended to be invoked by the compiler from
// an instrumented program's main package init func.
func emitMetaData() {
if covProfileAlreadyEmitted {
return
}
ml, err := prepareForMetaEmit()
if err != nil {
fmt.Fprintf(os.Stderr, "error: coverage meta-data prep failed: %v\n", err)
if os.Getenv("GOCOVERDEBUG") != "" {
panic("meta-data write failure")
}
}
if len(ml) == 0 {
fmt.Fprintf(os.Stderr, "program not built with -cover\n")
return
}
goCoverDir = os.Getenv("GOCOVERDIR")
if goCoverDir == "" {
fmt.Fprintf(os.Stderr, "warning: GOCOVERDIR not set, no coverage data emitted\n")
return
}
if err := emitMetaDataToDirectory(goCoverDir, ml); err != nil {
fmt.Fprintf(os.Stderr, "error: coverage meta-data emit failed: %v\n", err)
if os.Getenv("GOCOVERDEBUG") != "" {
panic("meta-data write failure")
}
}
}
func modeClash(m coverage.CounterMode) bool {
if m == coverage.CtrModeRegOnly || m == coverage.CtrModeTestMain {
return false
}
if cmode == coverage.CtrModeInvalid {
cmode = m
return false
}
return cmode != m
}
func granClash(g coverage.CounterGranularity) bool {
if cgran == coverage.CtrGranularityInvalid {
cgran = g
return false
}
return cgran != g
}
// prepareForMetaEmit performs preparatory steps needed prior to
// emitting a meta-data file, notably computing a final hash of
// all meta-data blobs and capturing os args.
func prepareForMetaEmit() ([]rtcov.CovMetaBlob, error) {
// Ask the runtime for the list of coverage meta-data symbols.
ml := getCovMetaList()
// In the normal case (go build -o prog.exe ... ; ./prog.exe)
// len(ml) will always be non-zero, but we check here since at
// some point this function will be reachable via user-callable
// APIs (for example, to write out coverage data from a server
// program that doesn't ever call os.Exit).
if len(ml) == 0 {
return nil, nil
}
s := &emitState{
metalist: ml,
debug: os.Getenv("GOCOVERDEBUG") != "",
}
// Capture os.Args now so as to avoid issues if args
// are rewritten during program execution.
capturedOsArgs = captureOsArgs()
if s.debug {
fmt.Fprintf(os.Stderr, "=+= GOCOVERDIR is %s\n", os.Getenv("GOCOVERDIR"))
fmt.Fprintf(os.Stderr, "=+= contents of covmetalist:\n")
for k, b := range ml {
fmt.Fprintf(os.Stderr, "=+= slot: %d path: %s ", k, b.PkgPath)
if b.PkgID != -1 {
fmt.Fprintf(os.Stderr, " hcid: %d", b.PkgID)
}
fmt.Fprintf(os.Stderr, "\n")
}
pm := getCovPkgMap()
fmt.Fprintf(os.Stderr, "=+= remap table:\n")
for from, to := range pm {
fmt.Fprintf(os.Stderr, "=+= from %d to %d\n",
uint32(from), uint32(to))
}
}
h := md5.New()
tlen := uint64(unsafe.Sizeof(coverage.MetaFileHeader{}))
for _, entry := range ml {
if _, err := h.Write(entry.Hash[:]); err != nil {
return nil, err
}
tlen += uint64(entry.Len)
ecm := coverage.CounterMode(entry.CounterMode)
if modeClash(ecm) {
return nil, fmt.Errorf("coverage counter mode clash: package %s uses mode=%d, but package %s uses mode=%s\n", ml[0].PkgPath, cmode, entry.PkgPath, ecm)
}
ecg := coverage.CounterGranularity(entry.CounterGranularity)
if granClash(ecg) {
return nil, fmt.Errorf("coverage counter granularity clash: package %s uses gran=%d, but package %s uses gran=%s\n", ml[0].PkgPath, cgran, entry.PkgPath, ecg)
}
}
// Hash mode and granularity as well.
h.Write([]byte(cmode.String()))
h.Write([]byte(cgran.String()))
// Compute final digest.
fh := h.Sum(nil)
copy(finalHash[:], fh)
finalHashComputed = true
finalMetaLen = tlen
return ml, nil
}
// emitMetaDataToDirectory emits the meta-data output file to the specified
// directory, returning an error if something went wrong.
func emitMetaDataToDirectory(outdir string, ml []rtcov.CovMetaBlob) error {
ml, err := prepareForMetaEmit()
if err != nil {
return err
}
if len(ml) == 0 {
return nil
}
metaDataEmitAttempted = true
s := &emitState{
metalist: ml,
debug: os.Getenv("GOCOVERDEBUG") != "",
outdir: outdir,
}
// Open output files.
if err := s.openOutputFiles(finalHash, finalMetaLen, metaDataFile); err != nil {
return err
}
// Emit meta-data file only if needed (may already be present).
if s.needMetaDataFile() {
if err := s.emitMetaDataFile(finalHash, finalMetaLen); err != nil {
return err
}
}
return nil
}
// emitCounterData emits the counter data output file for this coverage run.
// This entry point is intended to be invoked by the runtime when an
// instrumented program is terminating or calling os.Exit().
func emitCounterData() {
if goCoverDir == "" || !finalHashComputed || covProfileAlreadyEmitted {
return
}
if err := emitCounterDataToDirectory(goCoverDir); err != nil {
fmt.Fprintf(os.Stderr, "error: coverage counter data emit failed: %v\n", err)
if os.Getenv("GOCOVERDEBUG") != "" {
panic("counter-data write failure")
}
}
}
// emitCounterDataToDirectory emits the counter-data output file for this coverage run.
func emitCounterDataToDirectory(outdir string) error {
// Ask the runtime for the list of coverage counter symbols.
cl := getCovCounterList()
if len(cl) == 0 {
// no work to do here.
return nil
}
if !finalHashComputed {
return fmt.Errorf("error: meta-data not available (binary not built with -cover?)")
}
// Ask the runtime for the remap table for hard-coded package IDs.
pm := getCovPkgMap()
s := &emitState{
counterlist: cl,
pkgmap: pm,
outdir: outdir,
debug: os.Getenv("GOCOVERDEBUG") != "",
}
// Open output file.
if err := s.openOutputFiles(finalHash, finalMetaLen, counterDataFile); err != nil {
return err
}
if s.cf == nil {
return fmt.Errorf("counter data output file open failed (no additional info")
}
// Emit counter data file.
if err := s.emitCounterDataFile(finalHash, s.cf); err != nil {
return err
}
if err := s.cf.Close(); err != nil {
return fmt.Errorf("closing counter data file: %v", err)
}
// Counter file has now been closed. Rename the temp to the
// final desired path.
if err := os.Rename(s.cftmp, s.cfname); err != nil {
return fmt.Errorf("writing %s: rename from %s failed: %v\n", s.cfname, s.cftmp, err)
}
return nil
}
// emitCounterDataToWriter emits counter data for this coverage run to an io.Writer.
func (s *emitState) emitCounterDataToWriter(w io.Writer) error {
if err := s.emitCounterDataFile(finalHash, w); err != nil {
return err
}
return nil
}
// openMetaFile determines whether we need to emit a meta-data output
// file, or whether we can reuse the existing file in the coverage out
// dir. It updates mfname/mftmp/mf fields in 's', returning an error
// if something went wrong. See the comment on the emitState type
// definition above for more on how file opening is managed.
func (s *emitState) openMetaFile(metaHash [16]byte, metaLen uint64) error {
// Open meta-outfile for reading to see if it exists.
fn := fmt.Sprintf("%s.%x", coverage.MetaFilePref, metaHash)
s.mfname = filepath.Join(s.outdir, fn)
fi, err := os.Stat(s.mfname)
if err != nil || fi.Size() != int64(metaLen) {
// We need a new meta-file.
tname := "tmp." + fn + fmt.Sprintf("%d", time.Now().UnixNano())
s.mftmp = filepath.Join(s.outdir, tname)
s.mf, err = os.Create(s.mftmp)
if err != nil {
return fmt.Errorf("creating meta-data file %s: %v", s.mftmp, err)
}
}
return nil
}
// openCounterFile opens an output file for the counter data portion
// of a test coverage run. It updates the 'cfname' and 'cf' fields in
// 's', returning an error if something went wrong.
func (s *emitState) openCounterFile(metaHash [16]byte) error {
processID := os.Getpid()
fn := fmt.Sprintf(coverage.CounterFileTempl, coverage.CounterFilePref, metaHash, processID, time.Now().UnixNano())
s.cfname = filepath.Join(s.outdir, fn)
s.cftmp = filepath.Join(s.outdir, "tmp."+fn)
var err error
s.cf, err = os.Create(s.cftmp)
if err != nil {
return fmt.Errorf("creating counter data file %s: %v", s.cftmp, err)
}
return nil
}
// openOutputFiles opens output files in preparation for emitting
// coverage data. In the case of the meta-data file, openOutputFiles
// may determine that we can reuse an existing meta-data file in the
// outdir, in which case it will leave the 'mf' field in the state
// struct as nil. If a new meta-file is needed, the field 'mfname'
// will be the final desired path of the meta file, 'mftmp' will be a
// temporary file, and 'mf' will be an open os.File pointer for
// 'mftmp'. The idea is that the client/caller will write content into
// 'mf', close it, and then rename 'mftmp' to 'mfname'. This function
// also opens the counter data output file, setting 'cf' and 'cfname'
// in the state struct.
func (s *emitState) openOutputFiles(metaHash [16]byte, metaLen uint64, which fileType) error {
fi, err := os.Stat(s.outdir)
if err != nil {
return fmt.Errorf("output directory %q inaccessible (err: %v); no coverage data written", s.outdir, err)
}
if !fi.IsDir() {
return fmt.Errorf("output directory %q not a directory; no coverage data written", s.outdir)
}
if (which & metaDataFile) != 0 {
if err := s.openMetaFile(metaHash, metaLen); err != nil {
return err
}
}
if (which & counterDataFile) != 0 {
if err := s.openCounterFile(metaHash); err != nil {
return err
}
}
return nil
}
// emitMetaDataFile emits coverage meta-data to a previously opened
// temporary file (s.mftmp), then renames the generated file to the
// final path (s.mfname).
func (s *emitState) emitMetaDataFile(finalHash [16]byte, tlen uint64) error {
if err := writeMetaData(s.mf, s.metalist, cmode, cgran, finalHash); err != nil {
return fmt.Errorf("writing %s: %v\n", s.mftmp, err)
}
if err := s.mf.Close(); err != nil {
return fmt.Errorf("closing meta data temp file: %v", err)
}
// Temp file has now been flushed and closed. Rename the temp to the
// final desired path.
if err := os.Rename(s.mftmp, s.mfname); err != nil {
return fmt.Errorf("writing %s: rename from %s failed: %v\n", s.mfname, s.mftmp, err)
}
return nil
}
// needMetaDataFile reports whether we need to emit a meta-data file
// for this program run. It should be used only after
// openOutputFiles() has been invoked.
func (s *emitState) needMetaDataFile() bool {
return s.mf != nil
}
func writeMetaData(w io.Writer, metalist []rtcov.CovMetaBlob, cmode coverage.CounterMode, gran coverage.CounterGranularity, finalHash [16]byte) error {
mfw := encodemeta.NewCoverageMetaFileWriter("<io.Writer>", w)
// Note: "sd" is re-initialized on each iteration of the loop
// below, and would normally be declared inside the loop, but
// placed here escape analysis since we capture it in bufHdr.
var sd []byte
bufHdr := (*reflect.SliceHeader)(unsafe.Pointer(&sd))
var blobs [][]byte
for _, e := range metalist {
bufHdr.Data = uintptr(unsafe.Pointer(e.P))
bufHdr.Len = int(e.Len)
bufHdr.Cap = int(e.Len)
blobs = append(blobs, sd)
}
return mfw.Write(finalHash, blobs, cmode, gran)
}
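// The bufHdr manipulation above aliases each registered (pointer,
// length) blob as a Go slice without copying. The same trick in a
// standalone sketch (in non-runtime code, unsafe.Slice is the simpler
// and vet-friendly equivalent):
//
//	var sd []byte
//	hdr := (*reflect.SliceHeader)(unsafe.Pointer(&sd))
//	hdr.Data = uintptr(unsafe.Pointer(p)) // p *byte pointing at n bytes
//	hdr.Len = int(n)
//	hdr.Cap = int(n)
//	// sd now views the same memory; equivalent: sd = unsafe.Slice(p, int(n))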
func (s *emitState) NumFuncs() (int, error) {
var sd []atomic.Uint32
bufHdr := (*reflect.SliceHeader)(unsafe.Pointer(&sd))
totalFuncs := 0
for _, c := range s.counterlist {
bufHdr.Data = uintptr(unsafe.Pointer(c.Counters))
bufHdr.Len = int(c.Len)
bufHdr.Cap = int(c.Len)
for i := 0; i < len(sd); i++ {
// Skip ahead until the next non-zero value.
sdi := sd[i].Load()
if sdi == 0 {
continue
}
// We found a function that was executed.
nCtrs := sdi
// Check to make sure that we have at least one live
// counter. See the implementation note in ClearCoverageCounters
// for a description of why this is needed.
isLive := false
st := i + coverage.FirstCtrOffset
counters := sd[st : st+int(nCtrs)]
for i := 0; i < len(counters); i++ {
if counters[i].Load() != 0 {
isLive = true
break
}
}
if !isLive {
// Skip this function.
i += coverage.FirstCtrOffset + int(nCtrs) - 1
continue
}
totalFuncs++
// Move to the next function.
i += coverage.FirstCtrOffset + int(nCtrs) - 1
}
}
return totalFuncs, nil
}
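// The scan in NumFuncs depends on the layout of each function's
// counter segment: a small header followed by the counters. In terms
// of the coverage.*Offset constants (as also used by VisitFuncs below):
//
//	sd[i+coverage.NumCtrsOffset]   number of counters (nCtrs)
//	sd[i+coverage.PkgIdOffset]     package ID
//	sd[i+coverage.FuncIdOffset]    function ID
//	sd[i+coverage.FirstCtrOffset]  first of the nCtrs counters
//
// A function that never executed leaves its entire segment zeroed,
// which is what lets the loop skip ahead on a zero word.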
func (s *emitState) VisitFuncs(f encodecounter.CounterVisitorFn) error {
var sd []atomic.Uint32
var tcounters []uint32
bufHdr := (*reflect.SliceHeader)(unsafe.Pointer(&sd))
rdCounters := func(actrs []atomic.Uint32, ctrs []uint32) []uint32 {
ctrs = ctrs[:0]
for i := range actrs {
ctrs = append(ctrs, actrs[i].Load())
}
return ctrs
}
dpkg := uint32(0)
for _, c := range s.counterlist {
bufHdr.Data = uintptr(unsafe.Pointer(c.Counters))
bufHdr.Len = int(c.Len)
bufHdr.Cap = int(c.Len)
for i := 0; i < len(sd); i++ {
// Skip ahead until the next non-zero value.
sdi := sd[i].Load()
if sdi == 0 {
continue
}
// We found a function that was executed.
nCtrs := sd[i+coverage.NumCtrsOffset].Load()
pkgId := sd[i+coverage.PkgIdOffset].Load()
funcId := sd[i+coverage.FuncIdOffset].Load()
cst := i + coverage.FirstCtrOffset
counters := sd[cst : cst+int(nCtrs)]
// Check to make sure that we have at least one live
// counter. See the implementation note in ClearCoverageCounters
// for a description of why this is needed.
isLive := false
for i := 0; i < len(counters); i++ {
if counters[i].Load() != 0 {
isLive = true
break
}
}
if !isLive {
// Skip this function.
i += coverage.FirstCtrOffset + int(nCtrs) - 1
continue
}
if s.debug {
if pkgId != dpkg {
dpkg = pkgId
fmt.Fprintf(os.Stderr, "\n=+= %d: pk=%d visit live fcn",
i, pkgId)
}
fmt.Fprintf(os.Stderr, " {i=%d F%d NC%d}", i, funcId, nCtrs)
}
// Vet and/or fix up package ID. A package ID of zero
// indicates that there is some new package X that is a
// runtime dependency, and this package has code that
// executes before its corresponding init package runs.
// This is a fatal error that we should only see during
// Go development (e.g. tip).
ipk := int32(pkgId)
if ipk == 0 {
fmt.Fprintf(os.Stderr, "\n")
reportErrorInHardcodedList(int32(i), ipk, funcId, nCtrs)
} else if ipk < 0 {
if newId, ok := s.pkgmap[int(ipk)]; ok {
pkgId = uint32(newId)
} else {
fmt.Fprintf(os.Stderr, "\n")
reportErrorInHardcodedList(int32(i), ipk, funcId, nCtrs)
}
} else {
// The package ID value stored in the counter array
// has 1 added to it (so as to preclude the
// possibility of a zero value; see
// runtime.addCovMeta), so subtract off 1 here to form
// the real package ID.
pkgId--
}
tcounters = rdCounters(counters, tcounters)
if err := f(pkgId, funcId, tcounters); err != nil {
return err
}
// Skip over this function.
i += coverage.FirstCtrOffset + int(nCtrs) - 1
}
if s.debug {
fmt.Fprintf(os.Stderr, "\n")
}
}
return nil
}
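// Summarizing the package-ID fixup above: the stored value is the real
// ID plus one, hard-coded runtime package IDs appear as negative
// values that are remapped through s.pkgmap, and zero always means the
// hard-coded list needs updating. A compressed sketch of the decision
// (fixupPkgID is a hypothetical helper, not part of this package):
//
//	func fixupPkgID(raw uint32, pkgmap map[int]int) (uint32, bool) {
//		ipk := int32(raw)
//		switch {
//		case ipk == 0:
//			return 0, false // counters ran before registration: fatal
//		case ipk < 0:
//			id, ok := pkgmap[int(ipk)]
//			return uint32(id), ok // remap hard-coded runtime ID
//		default:
//			return raw - 1, true // undo the +1 bias from addCovMeta
//		}
//	}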
// captureOsArgs converts os.Args into the format we use to store
// this info in the counter data file (the counter data file "args"
// section is a generic key-value collection). See the 'args' section
// in internal/coverage/defs.go for more info. The args map is also
// used to capture the GOOS and GOARCH values.
func captureOsArgs() map[string]string {
m := make(map[string]string)
m["argc"] = fmt.Sprintf("%d", len(os.Args))
for k, a := range os.Args {
m[fmt.Sprintf("argv%d", k)] = a
}
m["GOOS"] = runtime.GOOS
m["GOARCH"] = runtime.GOARCH
return m
}
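// For instance, running "prog -x foo" on linux/amd64 would yield a map
// along these lines (illustrative values):
//
//	{"argc": "3", "argv0": "prog", "argv1": "-x", "argv2": "foo",
//	 "GOOS": "linux", "GOARCH": "amd64"}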
// emitCounterDataFile emits the counter data portion of a
// coverage output file (to the file 's.cf').
func (s *emitState) emitCounterDataFile(finalHash [16]byte, w io.Writer) error {
cfw := encodecounter.NewCoverageDataWriter(w, coverage.CtrULeb128)
if err := cfw.Write(finalHash, capturedOsArgs, s); err != nil {
return err
}
return nil
}
// markProfileEmitted signals the runtime/coverage machinery that
// coverage data output files have already been written out, and there
// is no need to take any additional action at exit time. This
// function is called (via linknamed reference) from the
// coverage-related boilerplate code in _testmain.go emitted for Go
// unit tests.
func markProfileEmitted(val bool) {
covProfileAlreadyEmitted = val
}
func reportErrorInHardcodedList(slot, pkgID int32, fnID, nCtrs uint32) {
metaList := getCovMetaList()
pkgMap := getCovPkgMap()
println("internal error in coverage meta-data tracking:")
println("encountered bad pkgID:", pkgID, " at slot:", slot,
" fnID:", fnID, " numCtrs:", nCtrs)
println("list of hard-coded runtime package IDs needs revising.")
println("[see the comment on the 'rtPkgs' var in ")
println(" <goroot>/src/internal/coverage/pkid.go]")
println("registered list:")
for k, b := range metaList {
print("slot: ", k, " path='", b.PkgPath, "' ")
if b.PkgID != -1 {
print(" hard-coded id: ", b.PkgID)
}
println("")
}
println("remap table:")
for from, to := range pkgMap {
println("from ", from, " to ", to)
}
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package coverage
import _ "unsafe"
// initHook is invoked from the main package "init" routine in
// programs built with "-cover". This function is intended to be
// called only by the compiler.
//
// If 'istest' is false, it indicates we're building a regular program
// ("go build -cover ..."), in which case we immediately try to write
// out the meta-data file, and register emitCounterData as an exit
// hook.
//
// If 'istest' is true (indicating that the program in question is a
// Go test binary), then we tentatively queue up both emitMetaData and
// emitCounterData as exit hooks. In the normal case (e.g. regular "go
// test -cover" run) the testmain.go boilerplate will run at the end
// of the test, write out the coverage percentage, and then invoke
// markProfileEmitted() to indicate that no more work needs to be
// done. If however that call is never made, this is a sign that the
// test binary is being used as a replacement binary for the tool
// being tested, hence we do want to run exit hooks when the program
// terminates.
func initHook(istest bool) {
// Note: hooks are run in reverse registration order, so
// register the counter data hook before the meta-data hook
// (in the case where two hooks are needed).
runOnNonZeroExit := true
runtime_addExitHook(emitCounterData, runOnNonZeroExit)
if istest {
runtime_addExitHook(emitMetaData, runOnNonZeroExit)
} else {
emitMetaData()
}
}
//go:linkname runtime_addExitHook runtime.addExitHook
func runtime_addExitHook(f func(), runOnNonZeroExit bool)
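// Because exit hooks run in reverse registration order, registering
// the counter-data hook first means the meta-data hook fires first at
// exit. A tiny standalone model of that LIFO behavior (a sketch, not
// the runtime's implementation):
//
//	var hooks []func()
//	add := func(f func()) { hooks = append(hooks, f) }
//	add(emitCounterData) // registered first ...
//	add(emitMetaData)
//	for i := len(hooks) - 1; i >= 0; i-- {
//		hooks[i]() // ... so emitMetaData runs before emitCounterData
//	}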
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package coverage
import (
"fmt"
"internal/coverage"
"internal/coverage/calloc"
"internal/coverage/cformat"
"internal/coverage/cmerge"
"internal/coverage/decodecounter"
"internal/coverage/decodemeta"
"internal/coverage/pods"
"io"
"os"
"strings"
)
// processCoverTestDir is called (via a linknamed reference) from
// testmain code when "go test -cover" is in effect. It is not
// intended to be used other than internally by the Go command's
// generated code.
func processCoverTestDir(dir string, cfile string, cm string, cpkg string) error {
return processCoverTestDirInternal(dir, cfile, cm, cpkg, os.Stdout)
}
// processCoverTestDirInternal is an io.Writer version of processCoverTestDir,
// exposed for unit testing.
func processCoverTestDirInternal(dir string, cfile string, cm string, cpkg string, w io.Writer) error {
cmode := coverage.ParseCounterMode(cm)
if cmode == coverage.CtrModeInvalid {
return fmt.Errorf("invalid counter mode %q", cm)
}
// Emit meta-data and counter data.
ml := getCovMetaList()
if len(ml) == 0 {
// This corresponds to the case where we have a package that
// contains test code but no functions (which is fine). In this
// case there is no need to emit anything.
} else {
if err := emitMetaDataToDirectory(dir, ml); err != nil {
return err
}
if err := emitCounterDataToDirectory(dir); err != nil {
return err
}
}
// Collect pods from test run. For the majority of cases we would
// expect to see a single pod here, but allow for multiple pods in
// case the test harness is doing extra work to collect data files
// from builds that it kicks off as part of the testing.
podlist, err := pods.CollectPods([]string{dir}, false)
if err != nil {
return fmt.Errorf("reading from %s: %v", dir, err)
}
// Open text output file if appropriate.
var tf *os.File
var tfClosed bool
if cfile != "" {
var err error
tf, err = os.Create(cfile)
if err != nil {
return fmt.Errorf("internal error: opening coverage data output file %q: %v", cfile, err)
}
defer func() {
if !tfClosed {
tfClosed = true
tf.Close()
}
}()
}
// Read/process the pods.
ts := &tstate{
cm: &cmerge.Merger{},
cf: cformat.NewFormatter(cmode),
cmode: cmode,
}
// Generate the expected hash string based on the final meta-data
// hash for this test, then look only for pods that refer to that
// hash (just in case there are multiple instrumented executables
// in play). See issue #57924 for more on this.
hashstring := fmt.Sprintf("%x", finalHash)
for _, p := range podlist {
if !strings.Contains(p.MetaFile, hashstring) {
continue
}
if err := ts.processPod(p); err != nil {
return err
}
}
// Emit percent.
if err := ts.cf.EmitPercent(w, cpkg, true); err != nil {
return err
}
// Emit text output.
if tf != nil {
if err := ts.cf.EmitTextual(tf); err != nil {
return err
}
tfClosed = true
if err := tf.Close(); err != nil {
return fmt.Errorf("closing %s: %v", cfile, err)
}
}
return nil
}
type tstate struct {
calloc.BatchCounterAlloc
cm *cmerge.Merger
cf *cformat.Formatter
cmode coverage.CounterMode
}
// processPod reads coverage counter data for a specific pod.
func (ts *tstate) processPod(p pods.Pod) error {
// Open meta-data file
f, err := os.Open(p.MetaFile)
if err != nil {
return fmt.Errorf("unable to open meta-data file %s: %v", p.MetaFile, err)
}
defer func() {
f.Close()
}()
var mfr *decodemeta.CoverageMetaFileReader
mfr, err = decodemeta.NewCoverageMetaFileReader(f, nil)
if err != nil {
return fmt.Errorf("error reading meta-data file %s: %v", p.MetaFile, err)
}
newmode := mfr.CounterMode()
if newmode != ts.cmode {
return fmt.Errorf("internal error: counter mode clash: %q from test harness, %q from data file %s", ts.cmode.String(), newmode.String(), p.MetaFile)
}
newgran := mfr.CounterGranularity()
if err := ts.cm.SetModeAndGranularity(p.MetaFile, cmode, newgran); err != nil {
return err
}
// A map to store counter data, indexed by pkgid/fnid tuple.
pmm := make(map[pkfunc][]uint32)
// Helper to read a single counter data file.
readcdf := func(cdf string) error {
cf, err := os.Open(cdf)
if err != nil {
return fmt.Errorf("opening counter data file %s: %s", cdf, err)
}
defer cf.Close()
var cdr *decodecounter.CounterDataReader
cdr, err = decodecounter.NewCounterDataReader(cdf, cf)
if err != nil {
return fmt.Errorf("reading counter data file %s: %s", cdf, err)
}
var data decodecounter.FuncPayload
for {
ok, err := cdr.NextFunc(&data)
if err != nil {
return fmt.Errorf("reading counter data file %s: %v", cdf, err)
}
if !ok {
break
}
// NB: sanity check on pkg and func IDs?
key := pkfunc{pk: data.PkgIdx, fcn: data.FuncIdx}
if prev, found := pmm[key]; found {
// Note: no overflow reporting here.
if err, _ := ts.cm.MergeCounters(data.Counters, prev); err != nil {
return fmt.Errorf("processing counter data file %s: %v", cdf, err)
}
}
c := ts.AllocateCounters(len(data.Counters))
copy(c, data.Counters)
pmm[key] = c
}
return nil
}
// Read counter data files.
for _, cdf := range p.CounterDataFiles {
if err := readcdf(cdf); err != nil {
return err
}
}
// Visit meta-data file.
np := uint32(mfr.NumPackages())
payload := []byte{}
for pkIdx := uint32(0); pkIdx < np; pkIdx++ {
var pd *decodemeta.CoverageMetaDataDecoder
pd, payload, err = mfr.GetPackageDecoder(pkIdx, payload)
if err != nil {
return fmt.Errorf("reading pkg %d from meta-file %s: %s", pkIdx, p.MetaFile, err)
}
ts.cf.SetPackage(pd.PackagePath())
var fd coverage.FuncDesc
nf := pd.NumFuncs()
for fnIdx := uint32(0); fnIdx < nf; fnIdx++ {
if err := pd.ReadFunc(fnIdx, &fd); err != nil {
return fmt.Errorf("reading meta-data file %s: %v",
p.MetaFile, err)
}
key := pkfunc{pk: pkIdx, fcn: fnIdx}
counters, haveCounters := pmm[key]
for i := 0; i < len(fd.Units); i++ {
u := fd.Units[i]
// Skip units with non-zero parent (no way to represent
// these in the existing format).
if u.Parent != 0 {
continue
}
count := uint32(0)
if haveCounters {
count = counters[i]
}
ts.cf.AddUnit(fd.Srcfile, fd.Funcname, fd.Lit, u, count)
}
}
}
return nil
}
type pkfunc struct {
pk, fcn uint32
}
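// The pkfunc key lets counter payloads from multiple counter data
// files be accumulated per function. A minimal sketch of that pattern
// (hedged: the real path goes through cmerge.Merger, which also
// handles counter mode and overflow):
//
//	pmm := make(map[pkfunc][]uint32)
//	add := func(pk, fcn uint32, ctrs []uint32) {
//		key := pkfunc{pk: pk, fcn: fcn}
//		if prev, ok := pmm[key]; ok {
//			for i := range prev {
//				prev[i] += ctrs[i] // "sum" semantics, as in count mode
//			}
//			return
//		}
//		pmm[key] = append([]uint32(nil), ctrs...)
//	}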
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"internal/coverage/rtcov"
"unsafe"
)
//go:linkname runtime_coverage_getCovCounterList runtime/coverage.getCovCounterList
func runtime_coverage_getCovCounterList() []rtcov.CovCounterBlob {
res := []rtcov.CovCounterBlob{}
u32sz := unsafe.Sizeof(uint32(0))
for datap := &firstmoduledata; datap != nil; datap = datap.next {
if datap.covctrs == datap.ecovctrs {
continue
}
res = append(res, rtcov.CovCounterBlob{
Counters: (*uint32)(unsafe.Pointer(datap.covctrs)),
Len: uint64((datap.ecovctrs - datap.covctrs) / u32sz),
})
}
return res
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"internal/coverage/rtcov"
"unsafe"
)
// covMeta is the top-level container for bits of state related to
// code coverage meta-data in the runtime.
var covMeta struct {
// metaList contains the list of currently registered meta-data
// blobs for the running program.
metaList []rtcov.CovMetaBlob
// pkgMap records mappings from hard-coded package IDs to
// slots in the covMetaList above.
pkgMap map[int]int
// Set to true if we discover a package mapping glitch.
hardCodedListNeedsUpdating bool
}
// addCovMeta is invoked during package "init" functions by the
// compiler when compiling for coverage instrumentation; here 'p' is a
// meta-data blob of length 'dlen' for the package in question, 'hash'
// is a compiler-computed md5.sum for the blob, 'pkpath' is the
// package path, 'pkid' is the hard-coded ID that the compiler is
// using for the package (or -1 if the compiler doesn't think a
// hard-coded ID is needed), and 'cmode'/'cgran' are the coverage
// counter mode and granularity requested by the user. Return value is
// the ID for the package for use by the package code itself.
func addCovMeta(p unsafe.Pointer, dlen uint32, hash [16]byte, pkpath string, pkid int, cmode uint8, cgran uint8) uint32 {
slot := len(covMeta.metaList)
covMeta.metaList = append(covMeta.metaList,
rtcov.CovMetaBlob{
P: (*byte)(p),
Len: dlen,
Hash: hash,
PkgPath: pkpath,
PkgID: pkid,
CounterMode: cmode,
CounterGranularity: cgran,
})
if pkid != -1 {
if covMeta.pkgMap == nil {
covMeta.pkgMap = make(map[int]int)
}
if _, ok := covMeta.pkgMap[pkid]; ok {
throw("runtime.addCovMeta: coverage package map collision")
}
// Record the real slot (position on meta-list) for this
// package; we'll use the map to fix things up later on.
covMeta.pkgMap[pkid] = slot
}
// ID zero is reserved as invalid.
return uint32(slot + 1)
}
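// For example, the third blob registered lands in slot 2 and the call
// returns 3; consumers subtract one to recover the slot index. If the
// compiler passed a hard-coded pkid of -2 for that blob, the remap
// table would record pkgMap[-2] = 2 so that readers can later
// translate the negative ID back to the real slot.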
//go:linkname runtime_coverage_getCovMetaList runtime/coverage.getCovMetaList
func runtime_coverage_getCovMetaList() []rtcov.CovMetaBlob {
return covMeta.metaList
}
//go:linkname runtime_coverage_getCovPkgMap runtime/coverage.getCovPkgMap
func runtime_coverage_getCovPkgMap() map[int]int {
return covMeta.pkgMap
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"internal/cpu"
)
var useAVXmemmove bool
func init() {
// Let's remove stepping and reserved fields
processor := processorVersionInfo & 0x0FFF3FF0
isIntelBridgeFamily := isIntel &&
processor == 0x206A0 ||
processor == 0x206D0 ||
processor == 0x306A0 ||
processor == 0x306E0
useAVXmemmove = cpu.X86.HasAVX && !isIntelBridgeFamily
}
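// Note on the parenthesization above: && binds tighter than || in Go,
// so without the parentheses the expression would parse as
//
//	(isIntel && processor == 0x206A0) || processor == 0x206D0 || ...
//
// and non-Intel CPUs whose version bits happen to match would be
// misclassified as Bridge-family.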
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// CPU profiling.
//
// The signal handler for the profiling clock tick adds a new stack trace
// to a log of recent traces. The log is read by a user goroutine that
// turns it into formatted profile data. If the reader does not keep up
// with the log, those writes will be recorded as a count of lost records.
// The actual profile buffer is in profbuf.go.
package runtime
import (
"internal/abi"
"runtime/internal/sys"
"unsafe"
)
const (
maxCPUProfStack = 64
// profBufWordCount is the size of the CPU profile buffer's storage for the
// header and stack of each sample, measured in 64-bit words. Every sample
// has a required header of two words. With a small additional header (a
// word or two) and stacks at the profiler's maximum length of 64 frames,
// that capacity can support 1900 samples or 19 thread-seconds at a 100 Hz
// sample rate, at a cost of 1 MiB.
profBufWordCount = 1 << 17
// profBufTagCount is the size of the CPU profile buffer's storage for the
// goroutine tags associated with each sample. A capacity of 1<<14 means
// room for 16k samples, or 160 thread-seconds at a 100 Hz sample rate.
profBufTagCount = 1 << 14
)
type cpuProfile struct {
lock mutex
on bool // profiling is on
log *profBuf // profile events written here
// extra holds extra stacks accumulated in addNonGo
// corresponding to profiling signals arriving on
// non-Go-created threads. Those stacks are written
// to log the next time a normal Go thread gets the
// signal handler.
// Assuming the stacks are 2 words each (we don't get
// a full traceback from those threads), plus one word
// size for framing, 100 Hz profiling would generate
// 300 words per second.
// Hopefully a normal Go thread will get the profiling
// signal at least once every few seconds.
extra [1000]uintptr
numExtra int
lostExtra uint64 // count of frames lost because extra is full
lostAtomic uint64 // count of frames lost because of being in atomic64 on mips/arm; updated racily
}
var cpuprof cpuProfile
// SetCPUProfileRate sets the CPU profiling rate to hz samples per second.
// If hz <= 0, SetCPUProfileRate turns off profiling.
// If the profiler is on, the rate cannot be changed without first turning it off.
//
// Most clients should use the runtime/pprof package or
// the testing package's -test.cpuprofile flag instead of calling
// SetCPUProfileRate directly.
func SetCPUProfileRate(hz int) {
// Clamp hz to something reasonable.
if hz < 0 {
hz = 0
}
if hz > 1000000 {
hz = 1000000
}
lock(&cpuprof.lock)
if hz > 0 {
if cpuprof.on || cpuprof.log != nil {
print("runtime: cannot set cpu profile rate until previous profile has finished.\n")
unlock(&cpuprof.lock)
return
}
cpuprof.on = true
cpuprof.log = newProfBuf(1, profBufWordCount, profBufTagCount)
hdr := [1]uint64{uint64(hz)}
cpuprof.log.write(nil, nanotime(), hdr[:], nil)
setcpuprofilerate(int32(hz))
} else if cpuprof.on {
setcpuprofilerate(0)
cpuprof.on = false
cpuprof.addExtra()
cpuprof.log.close()
}
unlock(&cpuprof.lock)
}
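// As the doc comment says, most programs should drive CPU profiling
// through runtime/pprof rather than calling SetCPUProfileRate
// directly. A typical usage sketch:
//
//	f, err := os.Create("cpu.pprof")
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer f.Close()
//	if err := pprof.StartCPUProfile(f); err != nil {
//		log.Fatal(err)
//	}
//	defer pprof.StopCPUProfile()
//	// ... run the workload to be profiled ...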
// add adds the stack trace to the profile.
// It is called from signal handlers and other limited environments
// and cannot allocate memory or acquire locks that might be
// held at the time of the signal, nor can it use substantial amounts
// of stack.
//
//go:nowritebarrierrec
func (p *cpuProfile) add(tagPtr *unsafe.Pointer, stk []uintptr) {
// Simple cas-lock to coordinate with setcpuprofilerate.
for !prof.signalLock.CompareAndSwap(0, 1) {
// TODO: Is it safe to osyield here? https://go.dev/issue/52672
osyield()
}
if prof.hz.Load() != 0 { // implies cpuprof.log != nil
if p.numExtra > 0 || p.lostExtra > 0 || p.lostAtomic > 0 {
p.addExtra()
}
hdr := [1]uint64{1}
// Note: write "knows" that the argument is &gp.labels,
// because otherwise its write barrier behavior may not
// be correct. See the long comment there before
// changing the argument here.
cpuprof.log.write(tagPtr, nanotime(), hdr[:], stk)
}
prof.signalLock.Store(0)
}
// addNonGo adds the non-Go stack trace to the profile.
// It is called from a non-Go thread, so we cannot use much stack at all,
// nor do anything that needs a g or an m.
// In particular, we can't call cpuprof.log.write.
// Instead, we copy the stack into cpuprof.extra,
// which will be drained the next time a Go thread
// gets the signal handling event.
//
//go:nosplit
//go:nowritebarrierrec
func (p *cpuProfile) addNonGo(stk []uintptr) {
// Simple cas-lock to coordinate with SetCPUProfileRate.
// (Other calls to add or addNonGo should be blocked out
// by the fact that only one SIGPROF can be handled by the
// process at a time. If not, this lock will serialize those too.
// The use of timer_create(2) on Linux to request process-targeted
// signals may have changed this.)
for !prof.signalLock.CompareAndSwap(0, 1) {
// TODO: Is it safe to osyield here? https://go.dev/issue/52672
osyield()
}
if cpuprof.numExtra+1+len(stk) < len(cpuprof.extra) {
i := cpuprof.numExtra
cpuprof.extra[i] = uintptr(1 + len(stk))
copy(cpuprof.extra[i+1:], stk)
cpuprof.numExtra += 1 + len(stk)
} else {
cpuprof.lostExtra++
}
prof.signalLock.Store(0)
}
// addExtra adds the "extra" profiling events,
// queued by addNonGo, to the profile log.
// addExtra is called either from a signal handler on a Go thread
// or from an ordinary goroutine; either way it can use stack
// and has a g. The world may be stopped, though.
func (p *cpuProfile) addExtra() {
// Copy accumulated non-Go profile events.
hdr := [1]uint64{1}
for i := 0; i < p.numExtra; {
p.log.write(nil, 0, hdr[:], p.extra[i+1:i+int(p.extra[i])])
i += int(p.extra[i])
}
p.numExtra = 0
// Report any lost events.
if p.lostExtra > 0 {
hdr := [1]uint64{p.lostExtra}
lostStk := [2]uintptr{
abi.FuncPCABIInternal(_LostExternalCode) + sys.PCQuantum,
abi.FuncPCABIInternal(_ExternalCode) + sys.PCQuantum,
}
p.log.write(nil, 0, hdr[:], lostStk[:])
p.lostExtra = 0
}
if p.lostAtomic > 0 {
hdr := [1]uint64{p.lostAtomic}
lostStk := [2]uintptr{
abi.FuncPCABIInternal(_LostSIGPROFDuringAtomic64) + sys.PCQuantum,
abi.FuncPCABIInternal(_System) + sys.PCQuantum,
}
p.log.write(nil, 0, hdr[:], lostStk[:])
p.lostAtomic = 0
}
}
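// The extra buffer uses simple length-prefixed framing: each record is
// one word holding 1+len(stk) followed by the stack PCs, so
//
//	extra[i]                      = uintptr(1 + len(stk))
//	extra[i+1 : i+int(extra[i])]  = stk
//	next record starts at i + int(extra[i])
//
// which is exactly the stride the drain loop in addExtra walks.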
// CPUProfile panics.
// It formerly provided raw access to chunks of
// a pprof-format profile generated by the runtime.
// The details of generating that format have changed,
// so this functionality has been removed.
//
// Deprecated: Use the runtime/pprof package,
// or the handlers in the net/http/pprof package,
// or the testing package's -test.cpuprofile flag instead.
func CPUProfile() []byte {
panic("CPUProfile no longer available")
}
//go:linkname runtime_pprof_runtime_cyclesPerSecond runtime/pprof.runtime_cyclesPerSecond
func runtime_pprof_runtime_cyclesPerSecond() int64 {
return tickspersecond()
}
// readProfile, provided to runtime/pprof, returns the next chunk of
// binary CPU profiling stack trace data, blocking until data is available.
// If profiling is turned off and all the profile data accumulated while it was
// on has been returned, readProfile returns eof=true.
// The caller must save the returned data and tags before calling readProfile again.
// The returned data contains a whole number of records, and tags contains
// exactly one entry per record.
//
//go:linkname runtime_pprof_readProfile runtime/pprof.readProfile
func runtime_pprof_readProfile() ([]uint64, []unsafe.Pointer, bool) {
lock(&cpuprof.lock)
log := cpuprof.log
unlock(&cpuprof.lock)
data, tags, eof := log.read(profBufBlocking)
if len(data) == 0 && eof {
lock(&cpuprof.lock)
cpuprof.log = nil
unlock(&cpuprof.lock)
}
return data, tags, eof
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix
package runtime
const canCreateFile = true
// create returns an fd to a write-only file.
func create(name *byte, perm int32) int32 {
return open(name, _O_CREAT|_O_WRONLY|_O_TRUNC, perm)
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"runtime/internal/atomic"
"unsafe"
)
// GOMAXPROCS sets the maximum number of CPUs that can be executing
// simultaneously and returns the previous setting. It defaults to
// the value of runtime.NumCPU. If n < 1, it does not change the current setting.
// This call will go away when the scheduler improves.
func GOMAXPROCS(n int) int {
if GOARCH == "wasm" && n > 1 {
n = 1 // WebAssembly has no threads yet, so only one CPU is possible.
}
lock(&sched.lock)
ret := int(gomaxprocs)
unlock(&sched.lock)
if n <= 0 || n == ret {
return ret
}
stopTheWorldGC("GOMAXPROCS")
// newprocs will be processed by startTheWorld
newprocs = int32(n)
startTheWorldGC()
return ret
}
// NumCPU returns the number of logical CPUs usable by the current process.
//
// The set of available CPUs is checked by querying the operating system
// at process startup. Changes to operating system CPU allocation after
// process startup are not reflected.
func NumCPU() int {
return int(ncpu)
}
// NumCgoCall returns the number of cgo calls made by the current process.
func NumCgoCall() int64 {
var n = int64(atomic.Load64(&ncgocall))
for mp := (*m)(atomic.Loadp(unsafe.Pointer(&allm))); mp != nil; mp = mp.alllink {
n += int64(mp.ncgocall)
}
return n
}
// NumGoroutine returns the number of goroutines that currently exist.
func NumGoroutine() int {
return int(gcount())
}
//go:linkname debug_modinfo runtime/debug.modinfo
func debug_modinfo() string {
return modinfo
}
// mayMoreStackPreempt is a maymorestack hook that forces a preemption
// at every possible cooperative preemption point.
//
// This is valuable to apply to the runtime, which can be sensitive to
// preemption points. To apply this to all preemption points in the
// runtime and runtime-like code, use the following in bash or zsh:
//
// X=(-{gc,asm}flags={runtime/...,reflect,sync}=-d=maymorestack=runtime.mayMoreStackPreempt) GOFLAGS=${X[@]}
//
// This must be deeply nosplit because it is called from a function
// prologue before the stack is set up and because the compiler will
// call it from any splittable prologue (leading to infinite
// recursion).
//
// Ideally it should also use very little stack because the linker
// doesn't currently account for this in nosplit stack depth checking.
//
// Ensure mayMoreStackPreempt can be called for all ABIs.
//
//go:nosplit
//go:linkname mayMoreStackPreempt
func mayMoreStackPreempt() {
// Don't do anything on the g0 or gsignal stack.
gp := getg()
if gp == gp.m.g0 || gp == gp.m.gsignal {
return
}
// Force a preemption, unless the stack is already poisoned.
if gp.stackguard0 < stackPoisonMin {
gp.stackguard0 = stackPreempt
}
}
// mayMoreStackMove is a maymorestack hook that forces stack movement
// at every possible point.
//
// See mayMoreStackPreempt.
//
//go:nosplit
//go:linkname mayMoreStackMove
func mayMoreStackMove() {
// Don't do anything on the g0 or gsignal stack.
gp := getg()
if gp == gp.m.g0 || gp == gp.m.gsignal {
return
}
// Force stack movement, unless the stack is already poisoned.
if gp.stackguard0 < stackPoisonMin {
gp.stackguard0 = stackForceMove
}
}
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package debug
import (
"runtime"
"sort"
"time"
)
// GCStats collects information about recent garbage collections.
type GCStats struct {
LastGC time.Time // time of last collection
NumGC int64 // number of garbage collections
PauseTotal time.Duration // total pause for all collections
Pause []time.Duration // pause history, most recent first
PauseEnd []time.Time // pause end times history, most recent first
PauseQuantiles []time.Duration
}
// ReadGCStats reads statistics about garbage collection into stats.
// The number of entries in the pause history is system-dependent;
// the stats.Pause slice will be reused if it is large enough, or
// reallocated otherwise.
// ReadGCStats may use the full capacity of the stats.Pause slice.
// If stats.PauseQuantiles is non-empty, ReadGCStats fills it with quantiles
// summarizing the distribution of pause time. For example, if
// len(stats.PauseQuantiles) is 5, it will be filled with the minimum,
// 25%, 50%, 75%, and maximum pause times.
func ReadGCStats(stats *GCStats) {
// Create a buffer with space for at least two copies of the
// pause history tracked by the runtime. One will be returned
// to the caller and the other will be used as transfer buffer
// for end times history and as a temporary buffer for
// computing quantiles.
const maxPause = len(((*runtime.MemStats)(nil)).PauseNs)
if cap(stats.Pause) < 2*maxPause+3 {
stats.Pause = make([]time.Duration, 2*maxPause+3)
}
// readGCStats fills in the pause and end times histories (up to
// maxPause entries) and then three more: Unix ns time of last GC,
// number of GC, and total pause time in nanoseconds. Here we
// depend on the fact that time.Duration's native unit is
// nanoseconds, so the pauses and the total pause time do not need
// any conversion.
readGCStats(&stats.Pause)
n := len(stats.Pause) - 3
stats.LastGC = time.Unix(0, int64(stats.Pause[n]))
stats.NumGC = int64(stats.Pause[n+1])
stats.PauseTotal = stats.Pause[n+2]
n /= 2 // buffer holds pauses and end times
stats.Pause = stats.Pause[:n]
if cap(stats.PauseEnd) < maxPause {
stats.PauseEnd = make([]time.Time, 0, maxPause)
}
stats.PauseEnd = stats.PauseEnd[:0]
for _, ns := range stats.Pause[n : n+n] {
stats.PauseEnd = append(stats.PauseEnd, time.Unix(0, int64(ns)))
}
if len(stats.PauseQuantiles) > 0 {
if n == 0 {
for i := range stats.PauseQuantiles {
stats.PauseQuantiles[i] = 0
}
} else {
// There's room for a second copy of the data in stats.Pause.
// See the allocation at the top of the function.
sorted := stats.Pause[n : n+n]
copy(sorted, stats.Pause)
sort.Slice(sorted, func(i, j int) bool { return sorted[i] < sorted[j] })
nq := len(stats.PauseQuantiles) - 1
for i := 0; i < nq; i++ {
stats.PauseQuantiles[i] = sorted[len(sorted)*i/nq]
}
stats.PauseQuantiles[nq] = sorted[len(sorted)-1]
}
}
}
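// A typical use, including pause quantiles (a minimal sketch):
//
//	var stats debug.GCStats
//	stats.PauseQuantiles = make([]time.Duration, 5) // min, 25%, 50%, 75%, max
//	debug.ReadGCStats(&stats)
//	fmt.Printf("GCs: %d, total pause: %v, median pause: %v\n",
//		stats.NumGC, stats.PauseTotal, stats.PauseQuantiles[2])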
// SetGCPercent sets the garbage collection target percentage:
// a collection is triggered when the ratio of freshly allocated data
// to live data remaining after the previous collection reaches this percentage.
// SetGCPercent returns the previous setting.
// The initial setting is the value of the GOGC environment variable
// at startup, or 100 if the variable is not set.
// This setting may be effectively reduced in order to maintain a memory
// limit.
// A negative percentage effectively disables garbage collection, unless
// the memory limit is reached.
// See SetMemoryLimit for more details.
func SetGCPercent(percent int) int {
return int(setGCPercent(int32(percent)))
}
// FreeOSMemory forces a garbage collection followed by an
// attempt to return as much memory to the operating system
// as possible. (Even if this is not called, the runtime gradually
// returns memory to the operating system in a background task.)
func FreeOSMemory() {
freeOSMemory()
}
// SetMaxStack sets the maximum amount of memory that
// can be used by a single goroutine stack.
// If any goroutine exceeds this limit while growing its stack,
// the program crashes.
// SetMaxStack returns the previous setting.
// The initial setting is 1 GB on 64-bit systems, 250 MB on 32-bit systems.
// There may be a system-imposed maximum stack limit regardless
// of the value provided to SetMaxStack.
//
// SetMaxStack is useful mainly for limiting the damage done by
// goroutines that enter an infinite recursion. It only limits future
// stack growth.
func SetMaxStack(bytes int) int {
return setMaxStack(bytes)
}
// SetMaxThreads sets the maximum number of operating system
// threads that the Go program can use. If it attempts to use more than
// this many, the program crashes.
// SetMaxThreads returns the previous setting.
// The initial setting is 10,000 threads.
//
// The limit controls the number of operating system threads, not the number
// of goroutines. A Go program creates a new thread only when a goroutine
// is ready to run but all the existing threads are blocked in system calls, cgo calls,
// or are locked to other goroutines due to use of runtime.LockOSThread.
//
// SetMaxThreads is useful mainly for limiting the damage done by
// programs that create an unbounded number of threads. The idea is
// to take down the program before it takes down the operating system.
func SetMaxThreads(threads int) int {
return setMaxThreads(threads)
}
// SetPanicOnFault controls the runtime's behavior when a program faults
// at an unexpected (non-nil) address. Such faults are typically caused by
// bugs such as runtime memory corruption, so the default response is to crash
// the program. Programs working with memory-mapped files or unsafe
// manipulation of memory may cause faults at non-nil addresses in less
// dramatic situations; SetPanicOnFault allows such programs to request
// that the runtime trigger only a panic, not a crash.
// The runtime.Error that the runtime panics with may have an additional method:
//
// Addr() uintptr
//
// If that method exists, it returns the memory address which triggered the fault.
// The results of Addr are best-effort and the veracity of the result
// may depend on the platform.
// SetPanicOnFault applies only to the current goroutine.
// It returns the previous setting.
func SetPanicOnFault(enabled bool) bool {
return setPanicOnFault(enabled)
}
// WriteHeapDump writes a description of the heap and the objects in
// it to the given file descriptor.
//
// WriteHeapDump suspends the execution of all goroutines until the heap
// dump is completely written. Thus, the file descriptor must not be
// connected to a pipe or socket whose other end is in the same Go
// process; instead, use a temporary file or network socket.
//
// The heap dump format is defined at https://golang.org/s/go15heapdump.
func WriteHeapDump(fd uintptr)
// SetTraceback sets the amount of detail printed by the runtime in
// the traceback it prints before exiting due to an unrecovered panic
// or an internal runtime error.
// The level argument takes the same values as the GOTRACEBACK
// environment variable. For example, SetTraceback("all") ensures
// that the program prints all goroutines when it crashes.
// See the package runtime documentation for details.
// If SetTraceback is called with a level lower than that of the
// environment variable, the call is ignored.
func SetTraceback(level string)
// SetMemoryLimit provides the runtime with a soft memory limit.
//
// The runtime undertakes several processes to try to respect this
// memory limit, including adjustments to the frequency of garbage
// collections and returning memory to the underlying system more
// aggressively. This limit will be respected even if GOGC=off (or
// if SetGCPercent(-1) is executed).
//
// The input limit is provided as bytes, and includes all memory
// mapped, managed, and not released by the Go runtime. Notably, it
// does not account for space used by the Go binary and memory
// external to Go, such as memory managed by the underlying system
// on behalf of the process, or memory managed by non-Go code inside
// the same process. Examples of excluded memory sources include: OS
// kernel memory held on behalf of the process, memory allocated by
// C code, and memory mapped by syscall.Mmap (because it is not
// managed by the Go runtime).
//
// More specifically, the following expression accurately reflects
// the value the runtime attempts to maintain as the limit:
//
// runtime.MemStats.Sys - runtime.MemStats.HeapReleased
//
// or in terms of the runtime/metrics package:
//
// /memory/classes/total:bytes - /memory/classes/heap/released:bytes
//
// A zero limit or a limit that's lower than the amount of memory
// used by the Go runtime may cause the garbage collector to run
// nearly continuously. However, the application may still make
// progress.
//
// The memory limit is always respected by the Go runtime, so to
// effectively disable this behavior, set the limit very high.
// math.MaxInt64 is the canonical value for disabling the limit,
// but values much greater than the available memory on the underlying
// system work just as well.
//
// See https://go.dev/doc/gc-guide for a detailed guide explaining
// the soft memory limit in more detail, as well as a variety of common
// use-cases and scenarios.
//
// The initial setting is math.MaxInt64 unless the GOMEMLIMIT
// environment variable is set, in which case it provides the initial
// setting. GOMEMLIMIT is a numeric value in bytes with an optional
// unit suffix. The supported suffixes include B, KiB, MiB, GiB, and
// TiB. These suffixes represent quantities of bytes as defined by
// the IEC 80000-13 standard. That is, they are based on powers of
// two: KiB means 2^10 bytes, MiB means 2^20 bytes, and so on.
//
// SetMemoryLimit returns the previously set memory limit.
// A negative input does not adjust the limit, and allows for
// retrieval of the currently set memory limit.
func SetMemoryLimit(limit int64) int64 {
return setMemoryLimit(limit)
}
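// For example, to cap the runtime's footprint at roughly 1 GiB while
// leaving GOGC alone (a minimal sketch):
//
//	prev := debug.SetMemoryLimit(1 << 30) // limit in bytes
//	fmt.Println("previous limit:", prev)
//	cur := debug.SetMemoryLimit(-1) // negative input only queries
//	fmt.Println("current limit:", cur)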
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package debug
import (
"fmt"
"runtime"
"strconv"
"strings"
)
// exported from runtime.
func modinfo() string
// ReadBuildInfo returns the build information embedded
// in the running binary. The information is available only
// in binaries built with module support.
func ReadBuildInfo() (info *BuildInfo, ok bool) {
data := modinfo()
if len(data) < 32 {
return nil, false
}
data = data[16 : len(data)-16]
bi, err := ParseBuildInfo(data)
if err != nil {
return nil, false
}
// The go version is stored separately from other build info, mostly for
// historical reasons. It is not part of the modinfo() string, and
// ParseBuildInfo does not recognize it. We inject it here to hide this
// awkwardness from the user.
bi.GoVersion = runtime.Version()
return bi, true
}
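// Typical use (a minimal sketch):
//
//	if info, ok := debug.ReadBuildInfo(); ok {
//		fmt.Println("built with", info.GoVersion, "from", info.Main.Path)
//		for _, s := range info.Settings {
//			if s.Key == "vcs.revision" {
//				fmt.Println("revision:", s.Value)
//			}
//		}
//	}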
// BuildInfo represents the build information read from a Go binary.
type BuildInfo struct {
// GoVersion is the version of the Go toolchain that built the binary
// (for example, "go1.19.2").
GoVersion string
// Path is the package path of the main package for the binary
// (for example, "golang.org/x/tools/cmd/stringer").
Path string
// Main describes the module that contains the main package for the binary.
Main Module
// Deps describes all the dependency modules, both direct and indirect,
// that contributed packages to the build of this binary.
Deps []*Module
// Settings describes the build settings used to build the binary.
Settings []BuildSetting
}
// A Module describes a single module included in a build.
type Module struct {
Path string // module path
Version string // module version
Sum string // checksum
Replace *Module // replaced by this module
}
// A BuildSetting is a key-value pair describing one setting that influenced a build.
//
// Defined keys include:
//
// - -buildmode: the buildmode flag used (typically "exe")
// - -compiler: the compiler toolchain flag used (typically "gc")
// - CGO_ENABLED: the effective CGO_ENABLED environment variable
// - CGO_CFLAGS: the effective CGO_CFLAGS environment variable
// - CGO_CPPFLAGS: the effective CGO_CPPFLAGS environment variable
// - CGO_CXXFLAGS: the effective CGO_CXXFLAGS environment variable
// - CGO_LDFLAGS: the effective CGO_LDFLAGS environment variable
// - GOARCH: the architecture target
// - GOAMD64/GOARM64/GO386/etc: the architecture feature level for GOARCH
// - GOOS: the operating system target
// - vcs: the version control system for the source tree where the build ran
// - vcs.revision: the revision identifier for the current commit or checkout
// - vcs.time: the modification time associated with vcs.revision, in RFC3339 format
// - vcs.modified: true or false indicating whether the source tree had local modifications
type BuildSetting struct {
// Key and Value describe the build setting.
// Key must not contain an equals sign, space, tab, or newline.
// Value must not contain newlines ('\n').
Key, Value string
}
// quoteKey reports whether key is required to be quoted.
func quoteKey(key string) bool {
return len(key) == 0 || strings.ContainsAny(key, "= \t\r\n\"`")
}
// quoteValue reports whether value is required to be quoted.
func quoteValue(value string) bool {
return strings.ContainsAny(value, " \t\r\n\"`")
}
func (bi *BuildInfo) String() string {
buf := new(strings.Builder)
if bi.GoVersion != "" {
fmt.Fprintf(buf, "go\t%s\n", bi.GoVersion)
}
if bi.Path != "" {
fmt.Fprintf(buf, "path\t%s\n", bi.Path)
}
var formatMod func(string, Module)
formatMod = func(word string, m Module) {
buf.WriteString(word)
buf.WriteByte('\t')
buf.WriteString(m.Path)
buf.WriteByte('\t')
buf.WriteString(m.Version)
if m.Replace == nil {
buf.WriteByte('\t')
buf.WriteString(m.Sum)
} else {
buf.WriteByte('\n')
formatMod("=>", *m.Replace)
}
buf.WriteByte('\n')
}
if bi.Main != (Module{}) {
formatMod("mod", bi.Main)
}
for _, dep := range bi.Deps {
formatMod("dep", *dep)
}
for _, s := range bi.Settings {
key := s.Key
if quoteKey(key) {
key = strconv.Quote(key)
}
value := s.Value
if quoteValue(value) {
value = strconv.Quote(value)
}
fmt.Fprintf(buf, "build\t%s=%s\n", key, value)
}
return buf.String()
}
func ParseBuildInfo(data string) (bi *BuildInfo, err error) {
lineNum := 1
defer func() {
if err != nil {
err = fmt.Errorf("could not parse Go build info: line %d: %w", lineNum, err)
}
}()
var (
pathLine = "path\t"
modLine = "mod\t"
depLine = "dep\t"
repLine = "=>\t"
buildLine = "build\t"
newline = "\n"
tab = "\t"
)
readModuleLine := func(elem []string) (Module, error) {
if len(elem) != 2 && len(elem) != 3 {
return Module{}, fmt.Errorf("expected 2 or 3 columns; got %d", len(elem))
}
version := elem[1]
sum := ""
if len(elem) == 3 {
sum = elem[2]
}
return Module{
Path: elem[0],
Version: version,
Sum: sum,
}, nil
}
bi = new(BuildInfo)
var (
last *Module
line string
ok bool
)
// Reverse of BuildInfo.String(), except for go version.
for len(data) > 0 {
line, data, ok = strings.Cut(data, newline)
if !ok {
break
}
switch {
case strings.HasPrefix(line, pathLine):
elem := line[len(pathLine):]
bi.Path = string(elem)
case strings.HasPrefix(line, modLine):
elem := strings.Split(line[len(modLine):], tab)
last = &bi.Main
*last, err = readModuleLine(elem)
if err != nil {
return nil, err
}
case strings.HasPrefix(line, depLine):
elem := strings.Split(line[len(depLine):], tab)
last = new(Module)
bi.Deps = append(bi.Deps, last)
*last, err = readModuleLine(elem)
if err != nil {
return nil, err
}
case strings.HasPrefix(line, repLine):
elem := strings.Split(line[len(repLine):], tab)
if len(elem) != 3 {
return nil, fmt.Errorf("expected 3 columns for replacement; got %d", len(elem))
}
if last == nil {
return nil, fmt.Errorf("replacement with no module on previous line")
}
last.Replace = &Module{
Path: string(elem[0]),
Version: string(elem[1]),
Sum: string(elem[2]),
}
last = nil
case strings.HasPrefix(line, buildLine):
kv := line[len(buildLine):]
if len(kv) < 1 {
return nil, fmt.Errorf("build line missing '='")
}
var key, rawValue string
switch kv[0] {
case '=':
return nil, fmt.Errorf("build line with missing key")
case '`', '"':
rawKey, err := strconv.QuotedPrefix(kv)
if err != nil {
return nil, fmt.Errorf("invalid quoted key in build line")
}
if len(kv) == len(rawKey) {
return nil, fmt.Errorf("build line missing '=' after quoted key")
}
if c := kv[len(rawKey)]; c != '=' {
return nil, fmt.Errorf("unexpected character after quoted key: %q", c)
}
key, _ = strconv.Unquote(rawKey)
rawValue = kv[len(rawKey)+1:]
default:
var ok bool
key, rawValue, ok = strings.Cut(kv, "=")
if !ok {
return nil, fmt.Errorf("build line missing '=' after key")
}
if quoteKey(key) {
return nil, fmt.Errorf("unquoted key %q must be quoted", key)
}
}
var value string
if len(rawValue) > 0 {
switch rawValue[0] {
case '`', '"':
var err error
value, err = strconv.Unquote(rawValue)
if err != nil {
return nil, fmt.Errorf("invalid quoted value in build line")
}
default:
value = rawValue
if quoteValue(value) {
return nil, fmt.Errorf("unquoted value %q must be quoted", value)
}
}
}
bi.Settings = append(bi.Settings, BuildSetting{Key: key, Value: value})
}
lineNum++
}
return bi, nil
}
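// ParseBuildInfo is, apart from the go version line, the inverse of
// (*BuildInfo).String, so a parse/format round trip preserves the
// records. A small sketch:
//
//	bi, err := debug.ParseBuildInfo("path\texample.com/cmd\nmod\texample.com/cmd\tv1.0.0\th1:abc\n")
//	if err == nil {
//		fmt.Print(bi.String()) // prints the same path and mod lines back
//	}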
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package debug contains facilities for programs to debug themselves while
// they are running.
package debug
import (
"os"
"runtime"
)
// PrintStack prints to standard error the stack trace returned by runtime.Stack.
func PrintStack() {
os.Stderr.Write(Stack())
}
// Stack returns a formatted stack trace of the goroutine that calls it.
// It calls runtime.Stack with a large enough buffer to capture the entire trace.
func Stack() []byte {
buf := make([]byte, 1024)
for {
n := runtime.Stack(buf, false)
if n < len(buf) {
return buf[:n]
}
buf = make([]byte, 2*len(buf))
}
}
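// The doubling loop above is the standard pattern for APIs that signal
// "buffer too small" by filling the buffer completely. A common use of
// the result (a sketch):
//
//	defer func() {
//		if r := recover(); r != nil {
//			os.Stderr.Write(debug.Stack()) // full trace of this goroutine
//			panic(r)
//		}
//	}()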
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build amd64 || arm64
package runtime
import "unsafe"
const (
debugCallSystemStack = "executing on Go runtime stack"
debugCallUnknownFunc = "call from unknown function"
debugCallRuntime = "call from within the Go runtime"
debugCallUnsafePoint = "call not at safe point"
)
func debugCallV2()
func debugCallPanicked(val any)
// debugCallCheck checks whether it is safe to inject a debugger
// function call with return PC pc. If not, it returns a string
// explaining why.
//
//go:nosplit
func debugCallCheck(pc uintptr) string {
// No user calls from the system stack.
if getg() != getg().m.curg {
return debugCallSystemStack
}
if sp := getcallersp(); !(getg().stack.lo < sp && sp <= getg().stack.hi) {
// Fast syscalls (nanotime) and racecall switch to the
// g0 stack without switching g. We can't safely make
// a call in this state. (We can't even safely
// systemstack.)
return debugCallSystemStack
}
// Switch to the system stack to avoid overflowing the user
// stack.
var ret string
systemstack(func() {
f := findfunc(pc)
if !f.valid() {
ret = debugCallUnknownFunc
return
}
name := funcname(f)
switch name {
case "debugCall32",
"debugCall64",
"debugCall128",
"debugCall256",
"debugCall512",
"debugCall1024",
"debugCall2048",
"debugCall4096",
"debugCall8192",
"debugCall16384",
"debugCall32768",
"debugCall65536":
// These functions are allowed so that the debugger can initiate multiple function calls.
// See: https://golang.org/cl/161137/
return
}
// Disallow calls from the runtime. We could
// potentially make this condition tighter (e.g., not
// when locks are held), but there are enough tightly
// coded sequences (e.g., defer handling) that it's
// better to play it safe.
if pfx := "runtime."; len(name) > len(pfx) && name[:len(pfx)] == pfx {
ret = debugCallRuntime
return
}
// Check that this isn't an unsafe-point.
if pc != f.entry() {
pc--
}
up := pcdatavalue(f, _PCDATA_UnsafePoint, pc, nil)
if up != _PCDATA_UnsafePointSafe {
// Not at a safe point.
ret = debugCallUnsafePoint
}
})
return ret
}
// debugCallWrap starts a new goroutine to run a debug call and blocks
// the calling goroutine. On the goroutine, it prepares to recover
// panics from the debug call, and then calls the call dispatching
// function at PC dispatch.
//
// This must be deeply nosplit because there are untyped values on the
// stack from debugCallV2.
//
//go:nosplit
func debugCallWrap(dispatch uintptr) {
var lockedm bool
var lockedExt uint32
callerpc := getcallerpc()
gp := getg()
// Create a new goroutine to execute the call on. Run this on
// the system stack to avoid growing our stack.
systemstack(func() {
// TODO(mknyszek): It would be nice to wrap these arguments in an allocated
// closure and start the goroutine with that closure, but the compiler disallows
// implicit closure allocation in the runtime.
fn := debugCallWrap1
newg := newproc1(*(**funcval)(unsafe.Pointer(&fn)), gp, callerpc)
args := &debugCallWrapArgs{
dispatch: dispatch,
callingG: gp,
}
newg.param = unsafe.Pointer(args)
// If the current G is locked, then transfer that
// locked-ness to the new goroutine.
if gp.lockedm != 0 {
// Save lock state to restore later.
mp := gp.m
if mp != gp.lockedm.ptr() {
throw("inconsistent lockedm")
}
lockedm = true
lockedExt = mp.lockedExt
// Transfer external lock count to internal so
// it can't be unlocked from the debug call.
mp.lockedInt++
mp.lockedExt = 0
mp.lockedg.set(newg)
newg.lockedm.set(mp)
gp.lockedm = 0
}
// Mark the calling goroutine as being at an async
// safe-point, since it has a few conservative frames
// at the bottom of the stack. This also prevents
// stack shrinks.
gp.asyncSafePoint = true
// Stash newg away so we can execute it below (mcall's
// closure can't capture anything).
gp.schedlink.set(newg)
})
// Switch to the new goroutine.
mcall(func(gp *g) {
// Get newg.
newg := gp.schedlink.ptr()
gp.schedlink = 0
// Park the calling goroutine.
if trace.enabled {
traceGoPark(traceEvGoBlock, 1)
}
casGToWaiting(gp, _Grunning, waitReasonDebugCall)
dropg()
// Directly execute the new goroutine. The debug
// protocol will continue on the new goroutine, so
// it's important we not just let the scheduler do
// this or it may resume a different goroutine.
execute(newg, true)
})
// We'll resume here when the call returns.
// Restore locked state.
if lockedm {
mp := gp.m
mp.lockedExt = lockedExt
mp.lockedInt--
mp.lockedg.set(gp)
gp.lockedm.set(mp)
}
gp.asyncSafePoint = false
}
type debugCallWrapArgs struct {
dispatch uintptr
callingG *g
}
// debugCallWrap1 is the continuation of debugCallWrap on the callee
// goroutine.
func debugCallWrap1() {
gp := getg()
args := (*debugCallWrapArgs)(gp.param)
dispatch, callingG := args.dispatch, args.callingG
gp.param = nil
// Dispatch call and trap panics.
debugCallWrap2(dispatch)
// Resume the caller goroutine.
getg().schedlink.set(callingG)
mcall(func(gp *g) {
callingG := gp.schedlink.ptr()
gp.schedlink = 0
// Unlock this goroutine from the M if necessary. The
// calling G will relock.
if gp.lockedm != 0 {
gp.lockedm = 0
gp.m.lockedg = 0
}
// Switch back to the calling goroutine. At some point
// the scheduler will schedule us again and we'll
// finish exiting.
if trace.enabled {
traceGoSched()
}
casgstatus(gp, _Grunning, _Grunnable)
dropg()
lock(&sched.lock)
globrunqput(gp)
unlock(&sched.lock)
if trace.enabled {
traceGoUnpark(callingG, 0)
}
casgstatus(callingG, _Gwaiting, _Grunnable)
execute(callingG, true)
})
}
func debugCallWrap2(dispatch uintptr) {
// Call the dispatch function and trap panics.
var dispatchF func()
dispatchFV := funcval{dispatch}
*(*unsafe.Pointer)(unsafe.Pointer(&dispatchF)) = noescape(unsafe.Pointer(&dispatchFV))
var ok bool
defer func() {
if !ok {
err := recover()
debugCallPanicked(err)
}
}()
dispatchF()
ok = true
}
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file provides an internal debug logging facility. The debug
// log is a lightweight, in-memory, per-M ring buffer. By default, the
// runtime prints the debug log on panic.
//
// To print something to the debug log, call dlog to obtain a dlogger
// and use the methods on that to add values. The values will be
// space-separated in the output (much like println).
//
// This facility can be enabled by passing -tags debuglog when
// building. Without this tag, dlog calls compile to nothing.
package runtime
import (
"runtime/internal/atomic"
"runtime/internal/sys"
"unsafe"
)
// debugLogBytes is the size of each per-M ring buffer. This is
// allocated off-heap to avoid blowing up the M and hence the GC'd
// heap size.
const debugLogBytes = 16 << 10
// debugLogStringLimit is the maximum number of bytes in a string.
// Above this, the string will be truncated with "..(n more bytes).."
const debugLogStringLimit = debugLogBytes / 8
// dlog returns a debug logger. The caller can use methods on the
// returned logger to add values, which will be space-separated in the
// final output, much like println. The caller must call end() to
// finish the message.
//
// dlog can be used from highly-constrained corners of the runtime: it
// is safe to use in the signal handler, from within the write
// barrier, from within the stack implementation, and in places that
// must be recursively nosplit.
//
// This will be compiled away if built without the debuglog build tag.
// However, argument construction may not be. If any of the arguments
// are not literals or trivial expressions, consider protecting the
// call with "if dlogEnabled".
//
//go:nosplit
//go:nowritebarrierrec
func dlog() *dlogger {
if !dlogEnabled {
return nil
}
// Get the time.
tick, nano := uint64(cputicks()), uint64(nanotime())
// Try to get a cached logger.
l := getCachedDlogger()
// If we couldn't get a cached logger, try to get one from the
// global pool.
if l == nil {
allp := (*uintptr)(unsafe.Pointer(&allDloggers))
all := (*dlogger)(unsafe.Pointer(atomic.Loaduintptr(allp)))
for l1 := all; l1 != nil; l1 = l1.allLink {
if l1.owned.Load() == 0 && l1.owned.CompareAndSwap(0, 1) {
l = l1
break
}
}
}
// If that failed, allocate a new logger.
if l == nil {
// Use sysAllocOS instead of sysAlloc because we want to interfere
// with the runtime as little as possible, and sysAlloc updates accounting.
l = (*dlogger)(sysAllocOS(unsafe.Sizeof(dlogger{})))
if l == nil {
throw("failed to allocate debug log")
}
l.w.r.data = &l.w.data
l.owned.Store(1)
// Prepend to allDloggers list.
headp := (*uintptr)(unsafe.Pointer(&allDloggers))
for {
head := atomic.Loaduintptr(headp)
l.allLink = (*dlogger)(unsafe.Pointer(head))
if atomic.Casuintptr(headp, head, uintptr(unsafe.Pointer(l))) {
break
}
}
}
// If the time delta is getting too high, write a new sync
// packet. We set the limit so we don't write more than 6
// bytes of delta in the record header.
const deltaLimit = 1<<(3*7) - 1 // ~2ms between sync packets
if tick-l.w.tick > deltaLimit || nano-l.w.nano > deltaLimit {
l.w.writeSync(tick, nano)
}
// Reserve space for framing header.
l.w.ensure(debugLogHeaderSize)
l.w.write += debugLogHeaderSize
// Write record header.
l.w.uvarint(tick - l.w.tick)
l.w.uvarint(nano - l.w.nano)
gp := getg()
if gp != nil && gp.m != nil && gp.m.p != 0 {
l.w.varint(int64(gp.m.p.ptr().id))
} else {
l.w.varint(-1)
}
return l
}
// A dlogger writes to the debug log.
//
// To obtain a dlogger, call dlog(). When done with the dlogger, call
// end().
type dlogger struct {
_ sys.NotInHeap
w debugLogWriter
// allLink is the next dlogger in the allDloggers list.
allLink *dlogger
// owned indicates that this dlogger is owned by an M. This is
// accessed atomically.
owned atomic.Uint32
}
// allDloggers is a list of all dloggers, linked through
// dlogger.allLink. This is accessed atomically. This is prepend only,
// so it doesn't need to protect against ABA races.
var allDloggers *dlogger
//go:nosplit
func (l *dlogger) end() {
if !dlogEnabled {
return
}
// Fill in framing header.
size := l.w.write - l.w.r.end
if !l.w.writeFrameAt(l.w.r.end, size) {
throw("record too large")
}
// Commit the record.
l.w.r.end = l.w.write
// Attempt to return this logger to the cache.
if putCachedDlogger(l) {
return
}
// Return the logger to the global pool.
l.owned.Store(0)
}
const (
debugLogUnknown = 1 + iota
debugLogBoolTrue
debugLogBoolFalse
debugLogInt
debugLogUint
debugLogHex
debugLogPtr
debugLogString
debugLogConstString
debugLogStringOverflow
debugLogPC
debugLogTraceback
)
//go:nosplit
func (l *dlogger) b(x bool) *dlogger {
if !dlogEnabled {
return l
}
if x {
l.w.byte(debugLogBoolTrue)
} else {
l.w.byte(debugLogBoolFalse)
}
return l
}
//go:nosplit
func (l *dlogger) i(x int) *dlogger {
return l.i64(int64(x))
}
//go:nosplit
func (l *dlogger) i8(x int8) *dlogger {
return l.i64(int64(x))
}
//go:nosplit
func (l *dlogger) i16(x int16) *dlogger {
return l.i64(int64(x))
}
//go:nosplit
func (l *dlogger) i32(x int32) *dlogger {
return l.i64(int64(x))
}
//go:nosplit
func (l *dlogger) i64(x int64) *dlogger {
if !dlogEnabled {
return l
}
l.w.byte(debugLogInt)
l.w.varint(x)
return l
}
//go:nosplit
func (l *dlogger) u(x uint) *dlogger {
return l.u64(uint64(x))
}
//go:nosplit
func (l *dlogger) uptr(x uintptr) *dlogger {
return l.u64(uint64(x))
}
//go:nosplit
func (l *dlogger) u8(x uint8) *dlogger {
return l.u64(uint64(x))
}
//go:nosplit
func (l *dlogger) u16(x uint16) *dlogger {
return l.u64(uint64(x))
}
//go:nosplit
func (l *dlogger) u32(x uint32) *dlogger {
return l.u64(uint64(x))
}
//go:nosplit
func (l *dlogger) u64(x uint64) *dlogger {
if !dlogEnabled {
return l
}
l.w.byte(debugLogUint)
l.w.uvarint(x)
return l
}
//go:nosplit
func (l *dlogger) hex(x uint64) *dlogger {
if !dlogEnabled {
return l
}
l.w.byte(debugLogHex)
l.w.uvarint(x)
return l
}
//go:nosplit
func (l *dlogger) p(x any) *dlogger {
if !dlogEnabled {
return l
}
l.w.byte(debugLogPtr)
if x == nil {
l.w.uvarint(0)
} else {
v := efaceOf(&x)
switch v._type.kind & kindMask {
case kindChan, kindFunc, kindMap, kindPtr, kindUnsafePointer:
l.w.uvarint(uint64(uintptr(v.data)))
default:
throw("not a pointer type")
}
}
return l
}
//go:nosplit
func (l *dlogger) s(x string) *dlogger {
if !dlogEnabled {
return l
}
strData := unsafe.StringData(x)
datap := &firstmoduledata
if len(x) > 4 && datap.etext <= uintptr(unsafe.Pointer(strData)) && uintptr(unsafe.Pointer(strData)) < datap.end {
// String constants are in the rodata section, which
// isn't recorded in moduledata. But it has to be
// somewhere between etext and end.
l.w.byte(debugLogConstString)
l.w.uvarint(uint64(len(x)))
l.w.uvarint(uint64(uintptr(unsafe.Pointer(strData)) - datap.etext))
} else {
l.w.byte(debugLogString)
// We can't use unsafe.Slice as it may panic, which isn't safe
// in this (potentially) nowritebarrier context.
var b []byte
bb := (*slice)(unsafe.Pointer(&b))
bb.array = unsafe.Pointer(strData)
bb.len, bb.cap = len(x), len(x)
if len(b) > debugLogStringLimit {
b = b[:debugLogStringLimit]
}
l.w.uvarint(uint64(len(b)))
l.w.bytes(b)
if len(b) != len(x) {
l.w.byte(debugLogStringOverflow)
l.w.uvarint(uint64(len(x) - len(b)))
}
}
return l
}
//go:nosplit
func (l *dlogger) pc(x uintptr) *dlogger {
if !dlogEnabled {
return l
}
l.w.byte(debugLogPC)
l.w.uvarint(uint64(x))
return l
}
//go:nosplit
func (l *dlogger) traceback(x []uintptr) *dlogger {
if !dlogEnabled {
return l
}
l.w.byte(debugLogTraceback)
l.w.uvarint(uint64(len(x)))
for _, pc := range x {
l.w.uvarint(uint64(pc))
}
return l
}
// A debugLogWriter is a ring buffer of binary debug log records.
//
// A log record consists of a 2-byte framing header and a sequence of
// fields. The framing header gives the size of the record as a little
// endian 16-bit value. Each field starts with a byte indicating its
// type, followed by type-specific data. If the size in the framing
// header is 0, it's a sync record consisting of two little endian
// 64-bit values giving a new time base.
//
// Because this is a ring buffer, new records will eventually
// overwrite old records. Hence, it maintains a reader that consumes
// the log as it gets overwritten. That reader state is where an
// actual log reader would start.
type debugLogWriter struct {
_ sys.NotInHeap
write uint64
data debugLogBuf
// tick and nano are the time bases from the most recently
// written sync record.
tick, nano uint64
// r is a reader that consumes records as they get overwritten
// by the writer. It also acts as the initial reader state
// when printing the log.
r debugLogReader
// buf is a scratch buffer for encoding. This is here to
// reduce stack usage.
buf [10]byte
}
type debugLogBuf struct {
_ sys.NotInHeap
b [debugLogBytes]byte
}
const (
// debugLogHeaderSize is the number of bytes in the framing
// header of every dlog record.
debugLogHeaderSize = 2
// debugLogSyncSize is the number of bytes in a sync record.
debugLogSyncSize = debugLogHeaderSize + 2*8
)
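// Putting the constants above together, the byte layout is roughly (an
// informal sketch derived from dlog, writeSync, and header, not a separate
// specification):
//
//	record: [size lo][size hi][tick delta uvarint][nano delta uvarint][P varint][fields...]
//	sync:   [0][0][tick uint64 LE][nano uint64 LE]
//
// where size counts every byte of the record, including the 2-byte frame.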
//go:nosplit
func (l *debugLogWriter) ensure(n uint64) {
for l.write+n >= l.r.begin+uint64(len(l.data.b)) {
// Consume record at begin.
if l.r.skip() == ^uint64(0) {
// Wrapped around within a record.
//
// TODO(austin): It would be better to just
// eat the whole buffer at this point, but we
// have to communicate that to the reader
// somehow.
throw("record wrapped around")
}
}
}
//go:nosplit
func (l *debugLogWriter) writeFrameAt(pos, size uint64) bool {
l.data.b[pos%uint64(len(l.data.b))] = uint8(size)
l.data.b[(pos+1)%uint64(len(l.data.b))] = uint8(size >> 8)
return size <= 0xFFFF
}
//go:nosplit
func (l *debugLogWriter) writeSync(tick, nano uint64) {
l.tick, l.nano = tick, nano
l.ensure(debugLogHeaderSize)
l.writeFrameAt(l.write, 0)
l.write += debugLogHeaderSize
l.writeUint64LE(tick)
l.writeUint64LE(nano)
l.r.end = l.write
}
//go:nosplit
func (l *debugLogWriter) writeUint64LE(x uint64) {
var b [8]byte
b[0] = byte(x)
b[1] = byte(x >> 8)
b[2] = byte(x >> 16)
b[3] = byte(x >> 24)
b[4] = byte(x >> 32)
b[5] = byte(x >> 40)
b[6] = byte(x >> 48)
b[7] = byte(x >> 56)
l.bytes(b[:])
}
//go:nosplit
func (l *debugLogWriter) byte(x byte) {
l.ensure(1)
pos := l.write
l.write++
l.data.b[pos%uint64(len(l.data.b))] = x
}
//go:nosplit
func (l *debugLogWriter) bytes(x []byte) {
l.ensure(uint64(len(x)))
pos := l.write
l.write += uint64(len(x))
for len(x) > 0 {
n := copy(l.data.b[pos%uint64(len(l.data.b)):], x)
pos += uint64(n)
x = x[n:]
}
}
//go:nosplit
func (l *debugLogWriter) varint(x int64) {
var u uint64
if x < 0 {
u = (^uint64(x) << 1) | 1 // complement i, bit 0 is 1
} else {
u = (uint64(x) << 1) // do not complement i, bit 0 is 0
}
l.uvarint(u)
}
//go:nosplit
func (l *debugLogWriter) uvarint(u uint64) {
i := 0
for u >= 0x80 {
l.buf[i] = byte(u) | 0x80
u >>= 7
i++
}
l.buf[i] = byte(u)
i++
l.bytes(l.buf[:i])
}
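// Worked examples (illustrative): varint zig-zags the sign into bit 0,
// mapping 0, -1, 1, -2, 2 to 0, 1, 2, 3, 4, so small magnitudes of either
// sign encode in one byte. uvarint emits 7 bits per byte, low bits first,
// with the high bit as a continuation flag:
//
//	uvarint(300) // 300 = 0b1_0010_1100 -> bytes 0xAC, 0x02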
type debugLogReader struct {
data *debugLogBuf
// begin and end are the positions in the log of the beginning
// and end of the log data, modulo len(data).
begin, end uint64
// tick and nano are the current time base at begin.
tick, nano uint64
}
//go:nosplit
func (r *debugLogReader) skip() uint64 {
// Read size at pos.
if r.begin+debugLogHeaderSize > r.end {
return ^uint64(0)
}
size := uint64(r.readUint16LEAt(r.begin))
if size == 0 {
// Sync packet.
r.tick = r.readUint64LEAt(r.begin + debugLogHeaderSize)
r.nano = r.readUint64LEAt(r.begin + debugLogHeaderSize + 8)
size = debugLogSyncSize
}
if r.begin+size > r.end {
return ^uint64(0)
}
r.begin += size
return size
}
//go:nosplit
func (r *debugLogReader) readUint16LEAt(pos uint64) uint16 {
return uint16(r.data.b[pos%uint64(len(r.data.b))]) |
uint16(r.data.b[(pos+1)%uint64(len(r.data.b))])<<8
}
//go:nosplit
func (r *debugLogReader) readUint64LEAt(pos uint64) uint64 {
var b [8]byte
for i := range b {
b[i] = r.data.b[pos%uint64(len(r.data.b))]
pos++
}
return uint64(b[0]) | uint64(b[1])<<8 |
uint64(b[2])<<16 | uint64(b[3])<<24 |
uint64(b[4])<<32 | uint64(b[5])<<40 |
uint64(b[6])<<48 | uint64(b[7])<<56
}
func (r *debugLogReader) peek() (tick uint64) {
// Consume any sync records.
size := uint64(0)
for size == 0 {
if r.begin+debugLogHeaderSize > r.end {
return ^uint64(0)
}
size = uint64(r.readUint16LEAt(r.begin))
if size != 0 {
break
}
if r.begin+debugLogSyncSize > r.end {
return ^uint64(0)
}
// Sync packet.
r.tick = r.readUint64LEAt(r.begin + debugLogHeaderSize)
r.nano = r.readUint64LEAt(r.begin + debugLogHeaderSize + 8)
r.begin += debugLogSyncSize
}
// Peek tick delta.
if r.begin+size > r.end {
return ^uint64(0)
}
pos := r.begin + debugLogHeaderSize
var u uint64
for i := uint(0); ; i += 7 {
b := r.data.b[pos%uint64(len(r.data.b))]
pos++
u |= uint64(b&^0x80) << i
if b&0x80 == 0 {
break
}
}
if pos > r.begin+size {
return ^uint64(0)
}
return r.tick + u
}
func (r *debugLogReader) header() (end, tick, nano uint64, p int) {
// Read size. We've already skipped sync packets and checked
// bounds in peek.
size := uint64(r.readUint16LEAt(r.begin))
end = r.begin + size
r.begin += debugLogHeaderSize
// Read tick, nano, and p.
tick = r.uvarint() + r.tick
nano = r.uvarint() + r.nano
p = int(r.varint())
return
}
func (r *debugLogReader) uvarint() uint64 {
var u uint64
for i := uint(0); ; i += 7 {
b := r.data.b[r.begin%uint64(len(r.data.b))]
r.begin++
u |= uint64(b&^0x80) << i
if b&0x80 == 0 {
break
}
}
return u
}
func (r *debugLogReader) varint() int64 {
u := r.uvarint()
var v int64
if u&1 == 0 {
v = int64(u >> 1)
} else {
v = ^int64(u >> 1)
}
return v
}
func (r *debugLogReader) printVal() bool {
typ := r.data.b[r.begin%uint64(len(r.data.b))]
r.begin++
switch typ {
default:
print("<unknown field type ", hex(typ), " pos ", r.begin-1, " end ", r.end, ">\n")
return false
case debugLogUnknown:
print("<unknown kind>")
case debugLogBoolTrue:
print(true)
case debugLogBoolFalse:
print(false)
case debugLogInt:
print(r.varint())
case debugLogUint:
print(r.uvarint())
case debugLogHex, debugLogPtr:
print(hex(r.uvarint()))
case debugLogString:
sl := r.uvarint()
if r.begin+sl > r.end {
r.begin = r.end
print("<string length corrupted>")
break
}
for sl > 0 {
b := r.data.b[r.begin%uint64(len(r.data.b)):]
if uint64(len(b)) > sl {
b = b[:sl]
}
r.begin += uint64(len(b))
sl -= uint64(len(b))
gwrite(b)
}
case debugLogConstString:
len, ptr := int(r.uvarint()), uintptr(r.uvarint())
ptr += firstmoduledata.etext
// We can't use unsafe.String as it may panic, which isn't safe
// in this (potentially) nowritebarrier context.
str := stringStruct{
str: unsafe.Pointer(ptr),
len: len,
}
s := *(*string)(unsafe.Pointer(&str))
print(s)
case debugLogStringOverflow:
print("..(", r.uvarint(), " more bytes)..")
case debugLogPC:
printDebugLogPC(uintptr(r.uvarint()), false)
case debugLogTraceback:
n := int(r.uvarint())
for i := 0; i < n; i++ {
print("\n\t")
// gentraceback PCs are always return PCs.
// Convert them to call PCs.
//
// TODO(austin): Expand inlined frames.
printDebugLogPC(uintptr(r.uvarint()), true)
}
}
return true
}
// printDebugLog prints the debug log.
func printDebugLog() {
if !dlogEnabled {
return
}
// This function should not panic or throw since it is used in
// the fatal panic path and this may deadlock.
printlock()
// Get the list of all debug logs.
allp := (*uintptr)(unsafe.Pointer(&allDloggers))
all := (*dlogger)(unsafe.Pointer(atomic.Loaduintptr(allp)))
// Count the logs.
n := 0
for l := all; l != nil; l = l.allLink {
n++
}
if n == 0 {
printunlock()
return
}
// Prepare read state for all logs.
type readState struct {
debugLogReader
first bool
lost uint64
nextTick uint64
}
// Use sysAllocOS instead of sysAlloc because we want to interfere
// with the runtime as little as possible, and sysAlloc updates accounting.
state1 := sysAllocOS(unsafe.Sizeof(readState{}) * uintptr(n))
if state1 == nil {
println("failed to allocate read state for", n, "logs")
printunlock()
return
}
state := (*[1 << 20]readState)(state1)[:n]
{
l := all
for i := range state {
s := &state[i]
s.debugLogReader = l.w.r
s.first = true
s.lost = l.w.r.begin
s.nextTick = s.peek()
l = l.allLink
}
}
// Print records.
for {
// Find the next record.
var best struct {
tick uint64
i int
}
best.tick = ^uint64(0)
for i := range state {
if state[i].nextTick < best.tick {
best.tick = state[i].nextTick
best.i = i
}
}
if best.tick == ^uint64(0) {
break
}
// Print record.
s := &state[best.i]
if s.first {
print(">> begin log ", best.i)
if s.lost != 0 {
print("; lost first ", s.lost>>10, "KB")
}
print(" <<\n")
s.first = false
}
end, _, nano, p := s.header()
oldEnd := s.end
s.end = end
print("[")
var tmpbuf [21]byte
pnano := int64(nano) - runtimeInitTime
if pnano < 0 {
// Logged before runtimeInitTime was set.
pnano = 0
}
pnanoBytes := itoaDiv(tmpbuf[:], uint64(pnano), 9)
print(slicebytetostringtmp((*byte)(noescape(unsafe.Pointer(&pnanoBytes[0]))), len(pnanoBytes)))
print(" P ", p, "] ")
for i := 0; s.begin < s.end; i++ {
if i > 0 {
print(" ")
}
if !s.printVal() {
// Abort this P log.
print("<aborting P log>")
end = oldEnd
break
}
}
println()
// Move on to the next record.
s.begin = end
s.end = oldEnd
s.nextTick = s.peek()
}
printunlock()
}
// printDebugLogPC prints a single symbolized PC. If returnPC is true,
// pc is a return PC that must first be converted to a call PC.
func printDebugLogPC(pc uintptr, returnPC bool) {
fn := findfunc(pc)
if returnPC && (!fn.valid() || pc > fn.entry()) {
// TODO(austin): Don't back up if the previous frame
// was a sigpanic.
pc--
}
print(hex(pc))
if !fn.valid() {
print(" [unknown PC]")
} else {
name := funcname(fn)
file, line := funcline(fn, pc)
print(" [", name, "+", hex(pc-fn.entry()),
" ", file, ":", line, "]")
}
}
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !debuglog
package runtime
const dlogEnabled = false
type dlogPerM struct{}
func getCachedDlogger() *dlogger {
return nil
}
func putCachedDlogger(l *dlogger) bool {
return false
}
// created by cgo -cdefs and then converted to Go
// cgo -cdefs defs_linux.go defs1_linux.go
package runtime
import "unsafe"
const (
_EINTR = 0x4
_EAGAIN = 0xb
_ENOMEM = 0xc
_PROT_NONE = 0x0
_PROT_READ = 0x1
_PROT_WRITE = 0x2
_PROT_EXEC = 0x4
_MAP_ANON = 0x20
_MAP_PRIVATE = 0x2
_MAP_FIXED = 0x10
_MADV_DONTNEED = 0x4
_MADV_FREE = 0x8
_MADV_HUGEPAGE = 0xe
_MADV_NOHUGEPAGE = 0xf
_SA_RESTART = 0x10000000
_SA_ONSTACK = 0x8000000
_SA_RESTORER = 0x4000000
_SA_SIGINFO = 0x4
_SI_KERNEL = 0x80
_SI_TIMER = -0x2
_SIGHUP = 0x1
_SIGINT = 0x2
_SIGQUIT = 0x3
_SIGILL = 0x4
_SIGTRAP = 0x5
_SIGABRT = 0x6
_SIGBUS = 0x7
_SIGFPE = 0x8
_SIGKILL = 0x9
_SIGUSR1 = 0xa
_SIGSEGV = 0xb
_SIGUSR2 = 0xc
_SIGPIPE = 0xd
_SIGALRM = 0xe
_SIGSTKFLT = 0x10
_SIGCHLD = 0x11
_SIGCONT = 0x12
_SIGSTOP = 0x13
_SIGTSTP = 0x14
_SIGTTIN = 0x15
_SIGTTOU = 0x16
_SIGURG = 0x17
_SIGXCPU = 0x18
_SIGXFSZ = 0x19
_SIGVTALRM = 0x1a
_SIGPROF = 0x1b
_SIGWINCH = 0x1c
_SIGIO = 0x1d
_SIGPWR = 0x1e
_SIGSYS = 0x1f
_SIGRTMIN = 0x20
_FPE_INTDIV = 0x1
_FPE_INTOVF = 0x2
_FPE_FLTDIV = 0x3
_FPE_FLTOVF = 0x4
_FPE_FLTUND = 0x5
_FPE_FLTRES = 0x6
_FPE_FLTINV = 0x7
_FPE_FLTSUB = 0x8
_BUS_ADRALN = 0x1
_BUS_ADRERR = 0x2
_BUS_OBJERR = 0x3
_SEGV_MAPERR = 0x1
_SEGV_ACCERR = 0x2
_ITIMER_REAL = 0x0
_ITIMER_VIRTUAL = 0x1
_ITIMER_PROF = 0x2
_CLOCK_THREAD_CPUTIME_ID = 0x3
_SIGEV_THREAD_ID = 0x4
_AF_UNIX = 0x1
_SOCK_DGRAM = 0x2
)
type timespec struct {
tv_sec int64
tv_nsec int64
}
//go:nosplit
func (ts *timespec) setNsec(ns int64) {
ts.tv_sec = ns / 1e9
ts.tv_nsec = ns % 1e9
}
type timeval struct {
tv_sec int64
tv_usec int64
}
func (tv *timeval) set_usec(x int32) {
tv.tv_usec = int64(x)
}
type sigactiont struct {
sa_handler uintptr
sa_flags uint64
sa_restorer uintptr
sa_mask uint64
}
type siginfoFields struct {
si_signo int32
si_errno int32
si_code int32
// below here is a union; si_addr is the only field we use
si_addr uint64
}
type siginfo struct {
siginfoFields
// Pad struct to the max size in the kernel.
_ [_si_max_size - unsafe.Sizeof(siginfoFields{})]byte
}
type itimerspec struct {
it_interval timespec
it_value timespec
}
type itimerval struct {
it_interval timeval
it_value timeval
}
type sigeventFields struct {
value uintptr
signo int32
notify int32
// below here is a union; sigev_notify_thread_id is the only field we use
sigev_notify_thread_id int32
}
type sigevent struct {
sigeventFields
// Pad struct to the max size in the kernel.
_ [_sigev_max_size - unsafe.Sizeof(sigeventFields{})]byte
}
// created by cgo -cdefs and then converted to Go
// cgo -cdefs defs_linux.go defs1_linux.go
const (
_O_RDONLY = 0x0
_O_WRONLY = 0x1
_O_CREAT = 0x40
_O_TRUNC = 0x200
_O_NONBLOCK = 0x800
_O_CLOEXEC = 0x80000
)
type usigset struct {
__val [16]uint64
}
type fpxreg struct {
significand [4]uint16
exponent uint16
padding [3]uint16
}
type xmmreg struct {
element [4]uint32
}
type fpstate struct {
cwd uint16
swd uint16
ftw uint16
fop uint16
rip uint64
rdp uint64
mxcsr uint32
mxcr_mask uint32
_st [8]fpxreg
_xmm [16]xmmreg
padding [24]uint32
}
type fpxreg1 struct {
significand [4]uint16
exponent uint16
padding [3]uint16
}
type xmmreg1 struct {
element [4]uint32
}
type fpstate1 struct {
cwd uint16
swd uint16
ftw uint16
fop uint16
rip uint64
rdp uint64
mxcsr uint32
mxcr_mask uint32
_st [8]fpxreg1
_xmm [16]xmmreg1
padding [24]uint32
}
type fpreg1 struct {
significand [4]uint16
exponent uint16
}
type stackt struct {
ss_sp *byte
ss_flags int32
pad_cgo_0 [4]byte
ss_size uintptr
}
type mcontext struct {
gregs [23]uint64
fpregs *fpstate
__reserved1 [8]uint64
}
type ucontext struct {
uc_flags uint64
uc_link *ucontext
uc_stack stackt
uc_mcontext mcontext
uc_sigmask usigset
__fpregs_mem fpstate
}
type sigcontext struct {
r8 uint64
r9 uint64
r10 uint64
r11 uint64
r12 uint64
r13 uint64
r14 uint64
r15 uint64
rdi uint64
rsi uint64
rbp uint64
rbx uint64
rdx uint64
rax uint64
rcx uint64
rsp uint64
rip uint64
eflags uint64
cs uint16
gs uint16
fs uint16
__pad0 uint16
err uint64
trapno uint64
oldmask uint64
cr2 uint64
fpstate *fpstate1
__reserved1 [8]uint64
}
type sockaddr_un struct {
family uint16
path [108]byte
}
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import "unsafe"
func gogetenv(key string) string {
env := environ()
if env == nil {
throw("getenv before env init")
}
for _, s := range env {
if len(s) > len(key) && s[len(key)] == '=' && envKeyEqual(s[:len(key)], key) {
return s[len(key)+1:]
}
}
return ""
}
// envKeyEqual reports whether a == b, with ASCII-only case insensitivity
// on Windows. The two strings must have the same length.
func envKeyEqual(a, b string) bool {
if GOOS == "windows" { // case insensitive
for i := 0; i < len(a); i++ {
ca, cb := a[i], b[i]
if ca == cb || lowerASCII(ca) == lowerASCII(cb) {
continue
}
return false
}
return true
}
return a == b
}
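// For example (illustrative): envKeyEqual("Path", "PATH") reports true when
// GOOS == "windows"; on every other OS it is false, since the comparison is
// an exact string match.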
func lowerASCII(c byte) byte {
if 'A' <= c && c <= 'Z' {
return c + ('a' - 'A')
}
return c
}
var _cgo_setenv unsafe.Pointer // pointer to C function
var _cgo_unsetenv unsafe.Pointer // pointer to C function
// Update the C environment if cgo is loaded.
func setenv_c(k string, v string) {
if _cgo_setenv == nil {
return
}
arg := [2]unsafe.Pointer{cstring(k), cstring(v)}
asmcgocall(_cgo_setenv, unsafe.Pointer(&arg))
}
// Update the C environment if cgo is loaded.
func unsetenv_c(k string) {
if _cgo_unsetenv == nil {
return
}
arg := [1]unsafe.Pointer{cstring(k)}
asmcgocall(_cgo_unsetenv, unsafe.Pointer(&arg))
}
func cstring(s string) unsafe.Pointer {
p := make([]byte, len(s)+1)
copy(p, s)
return unsafe.Pointer(&p[0])
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import "internal/bytealg"
// The Error interface identifies a run time error.
type Error interface {
error
// RuntimeError is a no-op function but
// serves to distinguish types that are run time
// errors from ordinary errors: a type is a
// run time error if it has a RuntimeError method.
RuntimeError()
}
// A TypeAssertionError explains a failed type assertion.
type TypeAssertionError struct {
_interface *_type
concrete *_type
asserted *_type
missingMethod string // one method needed by Interface, missing from Concrete
}
func (*TypeAssertionError) RuntimeError() {}
func (e *TypeAssertionError) Error() string {
inter := "interface"
if e._interface != nil {
inter = e._interface.string()
}
as := e.asserted.string()
if e.concrete == nil {
return "interface conversion: " + inter + " is nil, not " + as
}
cs := e.concrete.string()
if e.missingMethod == "" {
msg := "interface conversion: " + inter + " is " + cs + ", not " + as
if cs == as {
// provide slightly clearer error message
if e.concrete.pkgpath() != e.asserted.pkgpath() {
msg += " (types from different packages)"
} else {
msg += " (types from different scopes)"
}
}
return msg
}
return "interface conversion: " + cs + " is not " + as +
": missing method " + e.missingMethod
}
// itoa converts val to a decimal representation. The result is
// written somewhere within buf and the location of the result is returned.
// buf must be at least 20 bytes.
//
//go:nosplit
func itoa(buf []byte, val uint64) []byte {
i := len(buf) - 1
for val >= 10 {
buf[i] = byte(val%10 + '0')
i--
val /= 10
}
buf[i] = byte(val + '0')
return buf[i:]
}
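// For example (illustrative): with a 20-byte buf, itoa(buf[:], 1234) writes
// '1' '2' '3' '4' into the last four bytes and returns that four-byte suffix.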
// An errorString represents a runtime error described by a single string.
type errorString string
func (e errorString) RuntimeError() {}
func (e errorString) Error() string {
return "runtime error: " + string(e)
}
type errorAddressString struct {
msg string // error message
addr uintptr // memory address where the error occurred
}
func (e errorAddressString) RuntimeError() {}
func (e errorAddressString) Error() string {
return "runtime error: " + e.msg
}
// Addr returns the memory address where a fault occurred.
// The address provided is best-effort.
// The veracity of the result may depend on the platform.
// Errors providing this method will only be returned as
// a result of using runtime/debug.SetPanicOnFault.
func (e errorAddressString) Addr() uintptr {
return e.addr
}
// plainError represents a runtime error described by a string, returned
// without the "runtime error: " prefix that errorString.Error() adds.
// See Issue #14965.
type plainError string
func (e plainError) RuntimeError() {}
func (e plainError) Error() string {
return string(e)
}
// A boundsError represents an indexing or slicing operation gone wrong.
type boundsError struct {
x int64
y int
// Values in an index or slice expression can be signed or unsigned.
// That means we'd need 65 bits to encode all possible indexes, from -2^63 to 2^64-1.
// Instead, we keep track of whether x should be interpreted as signed or unsigned.
// y is known to be nonnegative and to fit in an int.
signed bool
code boundsErrorCode
}
type boundsErrorCode uint8
const (
boundsIndex boundsErrorCode = iota // s[x], 0 <= x < len(s) failed
boundsSliceAlen // s[?:x], 0 <= x <= len(s) failed
boundsSliceAcap // s[?:x], 0 <= x <= cap(s) failed
boundsSliceB // s[x:y], 0 <= x <= y failed (but boundsSliceA didn't happen)
boundsSlice3Alen // s[?:?:x], 0 <= x <= len(s) failed
boundsSlice3Acap // s[?:?:x], 0 <= x <= cap(s) failed
boundsSlice3B // s[?:x:y], 0 <= x <= y failed (but boundsSlice3A didn't happen)
boundsSlice3C // s[x:y:?], 0 <= x <= y failed (but boundsSlice3A/B didn't happen)
boundsConvert // (*[x]T)(s), 0 <= x <= len(s) failed
// Note: in the above, len(s) and cap(s) are stored in y
)
// boundsErrorFmts provide error text for various out-of-bounds panics.
// Note: if you change these strings, you should adjust the size of the buffer
// in boundsError.Error below as well.
var boundsErrorFmts = [...]string{
boundsIndex: "index out of range [%x] with length %y",
boundsSliceAlen: "slice bounds out of range [:%x] with length %y",
boundsSliceAcap: "slice bounds out of range [:%x] with capacity %y",
boundsSliceB: "slice bounds out of range [%x:%y]",
boundsSlice3Alen: "slice bounds out of range [::%x] with length %y",
boundsSlice3Acap: "slice bounds out of range [::%x] with capacity %y",
boundsSlice3B: "slice bounds out of range [:%x:%y]",
boundsSlice3C: "slice bounds out of range [%x:%y:]",
boundsConvert: "cannot convert slice with length %y to array or pointer to array with length %x",
}
// boundsNegErrorFmts are overriding formats if x is negative. In this case there's no need to report y.
var boundsNegErrorFmts = [...]string{
boundsIndex: "index out of range [%x]",
boundsSliceAlen: "slice bounds out of range [:%x]",
boundsSliceAcap: "slice bounds out of range [:%x]",
boundsSliceB: "slice bounds out of range [%x:]",
boundsSlice3Alen: "slice bounds out of range [::%x]",
boundsSlice3Acap: "slice bounds out of range [::%x]",
boundsSlice3B: "slice bounds out of range [:%x:]",
boundsSlice3C: "slice bounds out of range [%x::]",
}
func (e boundsError) RuntimeError() {}
func appendIntStr(b []byte, v int64, signed bool) []byte {
if signed && v < 0 {
b = append(b, '-')
v = -v
}
var buf [20]byte
b = append(b, itoa(buf[:], uint64(v))...)
return b
}
func (e boundsError) Error() string {
fmt := boundsErrorFmts[e.code]
if e.signed && e.x < 0 {
fmt = boundsNegErrorFmts[e.code]
}
// max message length is 99: "runtime error: slice bounds out of range [::%x] with capacity %y"
// x can be at most 20 characters. y can be at most 19.
b := make([]byte, 0, 100)
b = append(b, "runtime error: "...)
for i := 0; i < len(fmt); i++ {
c := fmt[i]
if c != '%' {
b = append(b, c)
continue
}
i++
switch fmt[i] {
case 'x':
b = appendIntStr(b, e.x, e.signed)
case 'y':
b = appendIntStr(b, int64(e.y), true)
}
}
return string(b)
}
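// For example (illustrative): indexing a 3-element slice at 5 produces
// boundsError{x: 5, y: 3, signed: true, code: boundsIndex}, whose Error
// method renders "runtime error: index out of range [5] with length 3".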
type stringer interface {
String() string
}
// printany prints an argument passed to panic.
// If panic is called with a value that has a String or Error method,
// it has already been converted into a string by preprintpanics.
func printany(i any) {
switch v := i.(type) {
case nil:
print("nil")
case bool:
print(v)
case int:
print(v)
case int8:
print(v)
case int16:
print(v)
case int32:
print(v)
case int64:
print(v)
case uint:
print(v)
case uint8:
print(v)
case uint16:
print(v)
case uint32:
print(v)
case uint64:
print(v)
case uintptr:
print(v)
case float32:
print(v)
case float64:
print(v)
case complex64:
print(v)
case complex128:
print(v)
case string:
print(v)
default:
printanycustomtype(i)
}
}
func printanycustomtype(i any) {
eface := efaceOf(&i)
typestring := eface._type.string()
switch eface._type.kind {
case kindString:
print(typestring, `("`, *(*string)(eface.data), `")`)
case kindBool:
print(typestring, "(", *(*bool)(eface.data), ")")
case kindInt:
print(typestring, "(", *(*int)(eface.data), ")")
case kindInt8:
print(typestring, "(", *(*int8)(eface.data), ")")
case kindInt16:
print(typestring, "(", *(*int16)(eface.data), ")")
case kindInt32:
print(typestring, "(", *(*int32)(eface.data), ")")
case kindInt64:
print(typestring, "(", *(*int64)(eface.data), ")")
case kindUint:
print(typestring, "(", *(*uint)(eface.data), ")")
case kindUint8:
print(typestring, "(", *(*uint8)(eface.data), ")")
case kindUint16:
print(typestring, "(", *(*uint16)(eface.data), ")")
case kindUint32:
print(typestring, "(", *(*uint32)(eface.data), ")")
case kindUint64:
print(typestring, "(", *(*uint64)(eface.data), ")")
case kindUintptr:
print(typestring, "(", *(*uintptr)(eface.data), ")")
case kindFloat32:
print(typestring, "(", *(*float32)(eface.data), ")")
case kindFloat64:
print(typestring, "(", *(*float64)(eface.data), ")")
case kindComplex64:
print(typestring, *(*complex64)(eface.data))
case kindComplex128:
print(typestring, *(*complex128)(eface.data))
default:
print("(", typestring, ") ", eface.data)
}
}
// panicwrap generates a panic for a call to a wrapped value method
// with a nil pointer receiver.
//
// It is called from the generated wrapper code.
func panicwrap() {
pc := getcallerpc()
name := funcname(findfunc(pc))
// name is something like "main.(*T).F".
// We want to extract pkg ("main"), typ ("T"), and meth ("F").
// Do it by finding the parens.
i := bytealg.IndexByteString(name, '(')
if i < 0 {
throw("panicwrap: no ( in " + name)
}
pkg := name[:i-1]
if i+2 >= len(name) || name[i-1:i+2] != ".(*" {
throw("panicwrap: unexpected string after package name: " + name)
}
name = name[i+2:]
i = bytealg.IndexByteString(name, ')')
if i < 0 {
throw("panicwrap: no ) in " + name)
}
if i+2 >= len(name) || name[i:i+2] != ")." {
throw("panicwrap: unexpected string after type name: " + name)
}
typ := name[:i]
meth := name[i+2:]
panic(plainError("value method " + pkg + "." + typ + "." + meth + " called using nil *" + typ + " pointer"))
}
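// For example (illustrative): for a wrapper named "main.(*T).F", the parsing
// above yields pkg "main", typ "T", and meth "F", so the panic message reads:
//
//	value method main.T.F called using nil *T pointer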
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
// addExitHook registers the specified function 'f' to be run at
// program termination (e.g. when someone invokes os.Exit(), or when
// main.main returns). Hooks are run in reverse order of registration:
// the first hook added is the last one run.
//
// CAREFUL: the expectation is that addExitHook should only be called
// from a safe context (e.g. not an error/panic path or signal
// handler, preemption enabled, allocation allowed, write barriers
// allowed, etc), and that the exit function 'f' will be invoked under
// similar circumstances. That is to say, we are expecting that 'f'
// uses normal / high-level Go code as opposed to one of the more
// restricted dialects used for the trickier parts of the runtime.
func addExitHook(f func(), runOnNonZeroExit bool) {
exitHooks.hooks = append(exitHooks.hooks, exitHook{f: f, runOnNonZeroExit: runOnNonZeroExit})
}
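// For example (illustrative, with hypothetical hooks a and b): after
//
//	addExitHook(a, false)
//	addExitHook(b, false)
//
// a normal (exit code 0) termination runs b first and then a, while a
// nonzero exit code runs neither, since runOnNonZeroExit is false for both.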
// exitHook stores a function to be run on program exit, registered
// by the utility runtime.addExitHook.
type exitHook struct {
f func() // func to run
runOnNonZeroExit bool // whether to run on non-zero exit code
}
// exitHooks stores state related to hook functions registered to
// run when program execution terminates.
var exitHooks struct {
hooks []exitHook
runningExitHooks bool
}
// runExitHooks runs any registered exit hook functions (funcs
// previously registered using runtime.addExitHook). Here 'exitCode'
// is the status code being passed to os.Exit, or zero if the program
// is terminating normally without calling os.Exit.
func runExitHooks(exitCode int) {
if exitHooks.runningExitHooks {
throw("internal error: exit hook invoked exit")
}
exitHooks.runningExitHooks = true
runExitHook := func(f func()) (caughtPanic bool) {
defer func() {
if x := recover(); x != nil {
caughtPanic = true
}
}()
f()
return
}
finishPageTrace()
for i := range exitHooks.hooks {
h := exitHooks.hooks[len(exitHooks.hooks)-i-1]
if exitCode != 0 && !h.runOnNonZeroExit {
continue
}
if caughtPanic := runExitHook(h.f); caughtPanic {
throw("internal error: exit hook invoked panic")
}
}
exitHooks.hooks = nil
exitHooks.runningExitHooks = false
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package runtime contains operations that interact with Go's runtime system,
such as functions to control goroutines. It also includes the low-level type information
used by the reflect package; see reflect's documentation for the programmable
interface to the run-time type system.
# Environment Variables
The following environment variables ($name or %name%, depending on the host
operating system) control the run-time behavior of Go programs. The meanings
and use may change from release to release.
The GOGC variable sets the initial garbage collection target percentage.
A collection is triggered when the ratio of freshly allocated data to live data
remaining after the previous collection reaches this percentage. The default
is GOGC=100. Setting GOGC=off disables the garbage collector entirely.
[runtime/debug.SetGCPercent] allows changing this percentage at run time.
The GOMEMLIMIT variable sets a soft memory limit for the runtime. This memory limit
includes the Go heap and all other memory managed by the runtime, and excludes
external memory sources such as mappings of the binary itself, memory managed in
other languages, and memory held by the operating system on behalf of the Go
program. GOMEMLIMIT is a numeric value in bytes with an optional unit suffix.
The supported suffixes include B, KiB, MiB, GiB, and TiB. These suffixes
represent quantities of bytes as defined by the IEC 80000-13 standard. That is,
they are based on powers of two: KiB means 2^10 bytes, MiB means 2^20 bytes,
and so on. The default setting is math.MaxInt64, which effectively disables the
memory limit. [runtime/debug.SetMemoryLimit] allows changing this limit at run
time.
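For example (an illustrative value, not a default), GOMEMLIMIT=2GiB sets the
limit to 2^31 bytes; GOMEMLIMIT=2147483648 is equivalent.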
The GODEBUG variable controls debugging variables within the runtime.
It is a comma-separated list of name=val pairs setting these named variables:
allocfreetrace: setting allocfreetrace=1 causes every allocation to be
profiled and a stack trace printed on each object's allocation and free.
clobberfree: setting clobberfree=1 causes the garbage collector to
clobber the memory content of an object with bad content when it frees
the object.
cpu.*: cpu.all=off disables the use of all optional instruction set extensions.
cpu.extension=off disables use of instructions from the specified instruction set extension.
extension is the lower case name for the instruction set extension such as sse41 or avx
as listed in internal/cpu package. As an example cpu.avx=off disables runtime detection
and thereby use of AVX instructions.
cgocheck: setting cgocheck=0 disables all checks for packages
using cgo to incorrectly pass Go pointers to non-Go code.
Setting cgocheck=1 (the default) enables relatively cheap
checks that may miss some errors. A more complete, but slow,
cgocheck mode can be enabled using GOEXPERIMENT (which
requires a rebuild), see https://pkg.go.dev/internal/goexperiment for details.
efence: setting efence=1 causes the allocator to run in a mode
where each object is allocated on a unique page and addresses are
never recycled.
gccheckmark: setting gccheckmark=1 enables verification of the
garbage collector's concurrent mark phase by performing a
second mark pass while the world is stopped. If the second
pass finds a reachable object that was not found by concurrent
mark, the garbage collector will panic.
gcpacertrace: setting gcpacertrace=1 causes the garbage collector to
print information about the internal state of the concurrent pacer.
gcshrinkstackoff: setting gcshrinkstackoff=1 disables moving goroutines
onto smaller stacks. In this mode, a goroutine's stack can only grow.
gcstoptheworld: setting gcstoptheworld=1 disables concurrent garbage collection,
making every garbage collection a stop-the-world event. Setting gcstoptheworld=2
also disables concurrent sweeping after the garbage collection finishes.
gctrace: setting gctrace=1 causes the garbage collector to emit a single line to standard
error at each collection, summarizing the amount of memory collected and the
length of the pause. The format of this line is subject to change.
Currently, it is:
gc # @#s #%: #+#+# ms clock, #+#/#/#+# ms cpu, #->#-># MB, # MB goal, # MB stacks, # MB globals, # P
where the fields are as follows:
gc # the GC number, incremented at each GC
@#s time in seconds since program start
#% percentage of time spent in GC since program start
#+...+# wall-clock/CPU times for the phases of the GC
#->#-># MB heap size at GC start, at GC end, and live heap
# MB goal goal heap size
# MB stacks estimated scannable stack size
# MB globals scannable global size
# P number of processors used
The phases are stop-the-world (STW) sweep termination, concurrent
mark and scan, and STW mark termination. The CPU times
for mark/scan are broken down into assist time (GC performed in
line with allocation), background GC time, and idle GC time.
If the line ends with "(forced)", this GC was forced by a
runtime.GC() call.
harddecommit: setting harddecommit=1 causes memory that is returned to the OS to
also have protections removed on it. This is the only mode of operation on Windows,
but is helpful in debugging scavenger-related issues on other platforms. Currently,
only supported on Linux.
inittrace: setting inittrace=1 causes the runtime to emit a single line to standard
error for each package with init work, summarizing the execution time and memory
allocation. No information is printed for inits executed as part of plugin loading
and for packages without both user-defined and compiler-generated init work.
The format of this line is subject to change. Currently, it is:
init # @#ms, # ms clock, # bytes, # allocs
where the fields are as follows:
init # the package name
@# ms time in milliseconds since program start at which the init started
# clock wall-clock time for package initialization work
# bytes memory allocated on the heap
# allocs number of heap allocations
madvdontneed: setting madvdontneed=0 will use MADV_FREE
instead of MADV_DONTNEED on Linux when returning memory to the
kernel. This is more efficient, but means RSS numbers will
drop only when the OS is under memory pressure. On the BSDs and
Illumos/Solaris, setting madvdontneed=1 will use MADV_DONTNEED instead
of MADV_FREE. This is less efficient, but causes RSS numbers to drop
more quickly.
memprofilerate: setting memprofilerate=X will update the value of runtime.MemProfileRate.
When set to 0 memory profiling is disabled. Refer to the description of
MemProfileRate for the default value.
pagetrace: setting pagetrace=/path/to/file will write out a trace of page events
that can be viewed, analyzed, and visualized using the x/debug/cmd/pagetrace tool.
Build your program with GOEXPERIMENT=pagetrace to enable this functionality. Do not
enable this functionality if your program is a setuid binary as it introduces a security
risk in that scenario. Currently not supported on Windows, plan9 or js/wasm. Setting this
option for some applications can produce large traces, so use with care.
invalidptr: invalidptr=1 (the default) causes the garbage collector and stack
copier to crash the program if an invalid pointer value (for example, 1)
is found in a pointer-typed location. Setting invalidptr=0 disables this check.
This should only be used as a temporary workaround to diagnose buggy code.
The real fix is to not store integers in pointer-typed locations.
sbrk: setting sbrk=1 replaces the memory allocator and garbage collector
with a trivial allocator that obtains memory from the operating system and
never reclaims any memory.
scavtrace: setting scavtrace=1 causes the runtime to emit a single line to standard
error, roughly once per GC cycle, summarizing the amount of work done by the
scavenger as well as the total amount of memory returned to the operating system
and an estimate of physical memory utilization. The format of this line is subject
to change, but currently it is:
scav # KiB work, # KiB total, #% util
where the fields are as follows:
# KiB work the amount of memory returned to the OS since the last line
# KiB total the total amount of memory returned to the OS
#% util the fraction of all unscavenged memory which is in-use
If the line ends with "(forced)", then scavenging was forced by a
debug.FreeOSMemory() call.
scheddetail: setting schedtrace=X and scheddetail=1 causes the scheduler to emit
detailed multiline info every X milliseconds, describing the state of the scheduler,
processors, threads and goroutines.
schedtrace: setting schedtrace=X causes the scheduler to emit a single line to standard
error every X milliseconds, summarizing the scheduler state.
tracebackancestors: setting tracebackancestors=N extends tracebacks with the stacks at
which goroutines were created, where N limits the number of ancestor goroutines to
report. This also extends the information returned by runtime.Stack. Ancestors' goroutine
IDs will refer to the ID of the goroutine at the time of creation; it's possible for this
ID to be reused for another goroutine. Setting N to 0 will report no ancestry information.
asyncpreemptoff: asyncpreemptoff=1 disables signal-based
asynchronous goroutine preemption. This makes some loops
non-preemptible for long periods, which may delay GC and
goroutine scheduling. This is useful for debugging GC issues
because it also disables the conservative stack scanning used
for asynchronously preempted goroutines.
The net and net/http packages also refer to debugging variables in GODEBUG.
See the documentation for those packages for details.
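For example (an illustrative combination), GODEBUG=gctrace=1,schedtrace=1000
enables the GC trace and emits a scheduler summary every 1000 milliseconds.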
The GOMAXPROCS variable limits the number of operating system threads that
can execute user-level Go code simultaneously. There is no limit to the number of threads
that can be blocked in system calls on behalf of Go code; those do not count against
the GOMAXPROCS limit. This package's GOMAXPROCS function queries and changes
the limit.
The GORACE variable configures the race detector, for programs built using -race.
See https://golang.org/doc/articles/race_detector.html for details.
The GOTRACEBACK variable controls the amount of output generated when a Go
program fails due to an unrecovered panic or an unexpected runtime condition.
By default, a failure prints a stack trace for the current goroutine,
eliding functions internal to the run-time system, and then exits with exit code 2.
The failure prints stack traces for all goroutines if there is no current goroutine
or the failure is internal to the run-time.
GOTRACEBACK=none omits the goroutine stack traces entirely.
GOTRACEBACK=single (the default) behaves as described above.
GOTRACEBACK=all adds stack traces for all user-created goroutines.
GOTRACEBACK=system is like “all” but adds stack frames for run-time functions
and shows goroutines created internally by the run-time.
GOTRACEBACK=crash is like “system” but crashes in an operating system-specific
manner instead of exiting. For example, on Unix systems, the crash raises
SIGABRT to trigger a core dump.
For historical reasons, the GOTRACEBACK settings 0, 1, and 2 are synonyms for
none, all, and system, respectively.
The runtime/debug package's SetTraceback function allows increasing the
amount of output at run time, but it cannot reduce the amount below that
specified by the environment variable.
See https://golang.org/pkg/runtime/debug/#SetTraceback.
The GOARCH, GOOS, GOPATH, and GOROOT environment variables complete
the set of Go environment variables. They influence the building of Go programs
(see https://golang.org/cmd/go and https://golang.org/pkg/go/build).
GOARCH, GOOS, and GOROOT are recorded at compile time and made available by
constants or functions in this package, but they do not influence the execution
of the run-time system.
*/
package runtime
import (
"internal/goarch"
"internal/goos"
)
// Caller reports file and line number information about function invocations on
// the calling goroutine's stack. The argument skip is the number of stack frames
// to ascend, with 0 identifying the caller of Caller. (For historical reasons the
// meaning of skip differs between Caller and Callers.) The return values report the
// program counter, file name, and line number within the file of the corresponding
// call. The boolean ok is false if it was not possible to recover the information.
func Caller(skip int) (pc uintptr, file string, line int, ok bool) {
rpc := make([]uintptr, 1)
n := callers(skip+1, rpc[:])
if n < 1 {
return
}
frame, _ := CallersFrames(rpc).Next()
return frame.PC, frame.File, frame.Line, frame.PC != 0
}
// Callers fills the slice pc with the return program counters of function invocations
// on the calling goroutine's stack. The argument skip is the number of stack frames
// to skip before recording in pc, with 0 identifying the frame for Callers itself and
// 1 identifying the caller of Callers.
// It returns the number of entries written to pc.
//
// To translate these PCs into symbolic information such as function
// names and line numbers, use CallersFrames. CallersFrames accounts
// for inlined functions and adjusts the return program counters into
// call program counters. Iterating over the returned slice of PCs
// directly is discouraged, as is using FuncForPC on any of the
// returned PCs, since these cannot account for inlining or return
// program counter adjustment.
func Callers(skip int, pc []uintptr) int {
// runtime.callers uses pc.array==nil as a signal
// to print a stack trace. Pick off 0-length pc here
// so that we don't let a nil pc slice get to it.
if len(pc) == 0 {
return 0
}
return callers(skip, pc)
}
var defaultGOROOT string // set by cmd/link
// GOROOT returns the root of the Go tree. It uses the
// GOROOT environment variable, if set at process start,
// or else the root used during the Go build.
func GOROOT() string {
s := gogetenv("GOROOT")
if s != "" {
return s
}
return defaultGOROOT
}
// buildVersion is the Go tree's version string at build time.
//
// If any GOEXPERIMENTs are set to non-default values, it will include
// "X:<GOEXPERIMENT>".
//
// This is set by the linker.
//
// This is accessed by "go version <binary>".
var buildVersion string
// Version returns the Go tree's version string.
// It is either the commit hash and date at the time of the build or,
// when possible, a release tag like "go1.3".
func Version() string {
return buildVersion
}
// GOOS is the running program's operating system target:
// one of darwin, freebsd, linux, and so on.
// To view possible combinations of GOOS and GOARCH, run "go tool dist list".
const GOOS string = goos.GOOS
// GOARCH is the running program's architecture target:
// one of 386, amd64, arm, s390x, and so on.
const GOARCH string = goarch.GOARCH
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
// fastlog2 implements a fast approximation to the base 2 log of a
// float64. This is used to compute a geometric distribution for heap
// sampling, without introducing dependencies into package math. This
// uses a very rough approximation using the float64 exponent and the
// first 25 bits of the mantissa. The top 5 bits of the mantissa are
// used to load limits from a table of constants and the rest are used
// to scale linearly between them.
func fastlog2(x float64) float64 {
const fastlogScaleBits = 20
const fastlogScaleRatio = 1.0 / (1 << fastlogScaleBits)
xBits := float64bits(x)
// Extract the exponent from the IEEE float64, and index the constant
// table with the top fastlogNumBits bits of the mantissa.
xExp := int64((xBits>>52)&0x7FF) - 1023
xManIndex := (xBits >> (52 - fastlogNumBits)) % (1 << fastlogNumBits)
xManScale := (xBits >> (52 - fastlogNumBits - fastlogScaleBits)) % (1 << fastlogScaleBits)
low, high := fastlog2Table[xManIndex], fastlog2Table[xManIndex+1]
return float64(xExp) + low + (high-low)*float64(xManScale)*fastlogScaleRatio
}
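// In effect (a sketch of the computation above): writing x = 2^e * (1 + m)
// with m in [0, 1), the result is
//
//	e + fastlog2Table[i] + frac*(fastlog2Table[i+1] - fastlog2Table[i])
//
// where i is the top fastlogNumBits bits of the mantissa and frac is the
// next fastlogScaleBits bits scaled into [0, 1).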
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import "unsafe"
var inf = float64frombits(0x7FF0000000000000)
// isNaN reports whether f is an IEEE 754 “not-a-number” value.
func isNaN(f float64) (is bool) {
// IEEE 754 says that only NaNs satisfy f != f.
return f != f
}
// isFinite reports whether f is neither NaN nor an infinity.
func isFinite(f float64) bool {
return !isNaN(f - f)
}
// isInf reports whether f is an infinity.
func isInf(f float64) bool {
return !isNaN(f) && !isFinite(f)
}
// abs returns the absolute value of x.
//
// Special cases are:
//
// abs(±Inf) = +Inf
// abs(NaN) = NaN
func abs(x float64) float64 {
const sign = 1 << 63
return float64frombits(float64bits(x) &^ sign)
}
// copysign returns a value with the magnitude
// of x and the sign of y.
func copysign(x, y float64) float64 {
const sign = 1 << 63
return float64frombits(float64bits(x)&^sign | float64bits(y)&sign)
}
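// For example (illustrative): abs(-2.5) == 2.5 and copysign(3, -0.0) == -3,
// since both functions operate purely on bit 63.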
// float64bits returns the IEEE 754 binary representation of f.
func float64bits(f float64) uint64 {
return *(*uint64)(unsafe.Pointer(&f))
}
// float64frombits returns the floating point number corresponding to
// the IEEE 754 binary representation b.
func float64frombits(b uint64) float64 {
return *(*float64)(unsafe.Pointer(&b))
}
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Hashing algorithm inspired by
// wyhash: https://github.com/wangyi-fudan/wyhash
//go:build amd64 || arm64 || loong64 || mips64 || mips64le || ppc64 || ppc64le || riscv64 || s390x || wasm
package runtime
import (
"runtime/internal/math"
"unsafe"
)
const (
m1 = 0xa0761d6478bd642f
m2 = 0xe7037ed1a0b428db
m3 = 0x8ebc6af09c88c6e3
m4 = 0x589965cc75374cc3
m5 = 0x1d8e4e27c47d124f
)
func memhashFallback(p unsafe.Pointer, seed, s uintptr) uintptr {
var a, b uintptr
seed ^= hashkey[0] ^ m1
switch {
case s == 0:
return seed
case s < 4:
a = uintptr(*(*byte)(p))
a |= uintptr(*(*byte)(add(p, s>>1))) << 8
a |= uintptr(*(*byte)(add(p, s-1))) << 16
case s == 4:
a = r4(p)
b = a
case s < 8:
a = r4(p)
b = r4(add(p, s-4))
case s == 8:
a = r8(p)
b = a
case s <= 16:
a = r8(p)
b = r8(add(p, s-8))
default:
l := s
if l > 48 {
seed1 := seed
seed2 := seed
for ; l > 48; l -= 48 {
seed = mix(r8(p)^m2, r8(add(p, 8))^seed)
seed1 = mix(r8(add(p, 16))^m3, r8(add(p, 24))^seed1)
seed2 = mix(r8(add(p, 32))^m4, r8(add(p, 40))^seed2)
p = add(p, 48)
}
seed ^= seed1 ^ seed2
}
for ; l > 16; l -= 16 {
seed = mix(r8(p)^m2, r8(add(p, 8))^seed)
p = add(p, 16)
}
a = r8(add(p, l-16))
b = r8(add(p, l-8))
}
return mix(m5^s, mix(a^m2, b^seed))
}
func memhash32Fallback(p unsafe.Pointer, seed uintptr) uintptr {
a := r4(p)
return mix(m5^4, mix(a^m2, a^seed^hashkey[0]^m1))
}
func memhash64Fallback(p unsafe.Pointer, seed uintptr) uintptr {
a := r8(p)
return mix(m5^8, mix(a^m2, a^seed^hashkey[0]^m1))
}
func mix(a, b uintptr) uintptr {
hi, lo := math.Mul64(uint64(a), uint64(b))
return uintptr(hi ^ lo)
}
func r4(p unsafe.Pointer) uintptr {
return uintptr(readUnaligned32(p))
}
func r8(p unsafe.Pointer) uintptr {
return uintptr(readUnaligned64(p))
}
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Implementation of runtime/debug.WriteHeapDump. Writes all
// objects in the heap plus additional info (roots, threads,
// finalizers, etc.) to a file.
// The format of the dumped file is described at
// https://golang.org/s/go15heapdump.
package runtime
import (
"internal/goarch"
"unsafe"
)
//go:linkname runtime_debug_WriteHeapDump runtime/debug.WriteHeapDump
func runtime_debug_WriteHeapDump(fd uintptr) {
stopTheWorld("write heap dump")
// Keep m on this G's stack instead of the system stack.
// Both readmemstats_m and writeheapdump_m have pretty large
// peak stack depths and we risk blowing the system stack.
// This is safe because the world is stopped, so we don't
// need to worry about anyone shrinking and therefore moving
// our stack.
var m MemStats
systemstack(func() {
// Call readmemstats_m here instead of deeper in
// writeheapdump_m because we might blow the system stack
// otherwise.
readmemstats_m(&m)
writeheapdump_m(fd, &m)
})
startTheWorld()
}
const (
fieldKindEol = 0
fieldKindPtr = 1
fieldKindIface = 2
fieldKindEface = 3
tagEOF = 0
tagObject = 1
tagOtherRoot = 2
tagType = 3
tagGoroutine = 4
tagStackFrame = 5
tagParams = 6
tagFinalizer = 7
tagItab = 8
tagOSThread = 9
tagMemStats = 10
tagQueuedFinalizer = 11
tagData = 12
tagBSS = 13
tagDefer = 14
tagPanic = 15
tagMemProf = 16
tagAllocSample = 17
)
var dumpfd uintptr // fd to write the dump to.
var tmpbuf []byte
// buffer of pending write data
const (
bufSize = 4096
)
var buf [bufSize]byte
var nbuf uintptr
func dwrite(data unsafe.Pointer, len uintptr) {
if len == 0 {
return
}
if nbuf+len <= bufSize {
copy(buf[nbuf:], (*[bufSize]byte)(data)[:len])
nbuf += len
return
}
write(dumpfd, unsafe.Pointer(&buf), int32(nbuf))
if len >= bufSize {
write(dumpfd, data, int32(len))
nbuf = 0
} else {
copy(buf[:], (*[bufSize]byte)(data)[:len])
nbuf = len
}
}
func dwritebyte(b byte) {
dwrite(unsafe.Pointer(&b), 1)
}
func flush() {
write(dumpfd, unsafe.Pointer(&buf), int32(nbuf))
nbuf = 0
}
// Cache of types that have been serialized already.
// We use a type's hash field to pick a bucket.
// Inside a bucket, we keep a list of types that
// have been serialized so far, most recently used first.
// Note: when a bucket overflows we may end up
// serializing a type more than once. That's ok.
const (
typeCacheBuckets = 256
typeCacheAssoc = 4
)
type typeCacheBucket struct {
t [typeCacheAssoc]*_type
}
var typecache [typeCacheBuckets]typeCacheBucket
// dump a uint64 in a varint format parseable by encoding/binary.
func dumpint(v uint64) {
var buf [10]byte
var n int
for v >= 0x80 {
buf[n] = byte(v | 0x80)
n++
v >>= 7
}
buf[n] = byte(v)
n++
dwrite(unsafe.Pointer(&buf), uintptr(n))
}
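// A consumer can decode these values with the standard library (an
// illustrative sketch; data is a hypothetical byte slice read from the dump):
//
//	v, n := binary.Uvarint(data) // v is the value, n the bytes consumed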
func dumpbool(b bool) {
if b {
dumpint(1)
} else {
dumpint(0)
}
}
// dump varint uint64 length followed by memory contents.
func dumpmemrange(data unsafe.Pointer, len uintptr) {
dumpint(uint64(len))
dwrite(data, len)
}
func dumpslice(b []byte) {
dumpint(uint64(len(b)))
if len(b) > 0 {
dwrite(unsafe.Pointer(&b[0]), uintptr(len(b)))
}
}
func dumpstr(s string) {
dumpmemrange(unsafe.Pointer(unsafe.StringData(s)), uintptr(len(s)))
}
// dump information for a type.
func dumptype(t *_type) {
if t == nil {
return
}
// If we've definitely serialized the type before,
// no need to do it again.
b := &typecache[t.hash&(typeCacheBuckets-1)]
if t == b.t[0] {
return
}
for i := 1; i < typeCacheAssoc; i++ {
if t == b.t[i] {
// Move-to-front
for j := i; j > 0; j-- {
b.t[j] = b.t[j-1]
}
b.t[0] = t
return
}
}
// Might not have been dumped yet. Dump it and
// remember we did so.
for j := typeCacheAssoc - 1; j > 0; j-- {
b.t[j] = b.t[j-1]
}
b.t[0] = t
// dump the type
dumpint(tagType)
dumpint(uint64(uintptr(unsafe.Pointer(t))))
dumpint(uint64(t.size))
if x := t.uncommon(); x == nil || t.nameOff(x.pkgpath).name() == "" {
dumpstr(t.string())
} else {
pkgpath := t.nameOff(x.pkgpath).name()
name := t.name()
dumpint(uint64(uintptr(len(pkgpath)) + 1 + uintptr(len(name))))
dwrite(unsafe.Pointer(unsafe.StringData(pkgpath)), uintptr(len(pkgpath)))
dwritebyte('.')
dwrite(unsafe.Pointer(unsafe.StringData(name)), uintptr(len(name)))
}
dumpbool(t.kind&kindDirectIface == 0 || t.ptrdata != 0)
}
// dump an object.
func dumpobj(obj unsafe.Pointer, size uintptr, bv bitvector) {
dumpint(tagObject)
dumpint(uint64(uintptr(obj)))
dumpmemrange(obj, size)
dumpfields(bv)
}
func dumpotherroot(description string, to unsafe.Pointer) {
dumpint(tagOtherRoot)
dumpstr(description)
dumpint(uint64(uintptr(to)))
}
func dumpfinalizer(obj unsafe.Pointer, fn *funcval, fint *_type, ot *ptrtype) {
dumpint(tagFinalizer)
dumpint(uint64(uintptr(obj)))
dumpint(uint64(uintptr(unsafe.Pointer(fn))))
dumpint(uint64(uintptr(unsafe.Pointer(fn.fn))))
dumpint(uint64(uintptr(unsafe.Pointer(fint))))
dumpint(uint64(uintptr(unsafe.Pointer(ot))))
}
type childInfo struct {
// Information passed up from the callee frame about
// the layout of the outargs region.
argoff uintptr // where the arguments start in the frame
arglen uintptr // size of args region
args bitvector // if args.n >= 0, pointer map of args region
sp *uint8 // callee sp
depth uintptr // depth in call stack (0 == most recent)
}
// dump kinds & offsets of interesting fields in bv.
func dumpbv(cbv *bitvector, offset uintptr) {
for i := uintptr(0); i < uintptr(cbv.n); i++ {
if cbv.ptrbit(i) == 1 {
dumpint(fieldKindPtr)
dumpint(uint64(offset + i*goarch.PtrSize))
}
}
}
func dumpframe(s *stkframe, arg unsafe.Pointer) bool {
child := (*childInfo)(arg)
f := s.fn
// Figure out what we can about our stack map
pc := s.pc
pcdata := int32(-1) // Use the entry map at function entry
if pc != f.entry() {
pc--
pcdata = pcdatavalue(f, _PCDATA_StackMapIndex, pc, nil)
}
if pcdata == -1 {
// We do not have a valid pcdata value but there might be a
// stackmap for this function. It is likely that we are looking
// at the function prologue; assume so and hope for the best.
pcdata = 0
}
stkmap := (*stackmap)(funcdata(f, _FUNCDATA_LocalsPointerMaps))
var bv bitvector
if stkmap != nil && stkmap.n > 0 {
bv = stackmapdata(stkmap, pcdata)
} else {
bv.n = -1
}
// Dump main body of stack frame.
dumpint(tagStackFrame)
dumpint(uint64(s.sp)) // lowest address in frame
dumpint(uint64(child.depth)) // # of frames deep on the stack
dumpint(uint64(uintptr(unsafe.Pointer(child.sp)))) // sp of child, or 0 if bottom of stack
dumpmemrange(unsafe.Pointer(s.sp), s.fp-s.sp) // frame contents
dumpint(uint64(f.entry()))
dumpint(uint64(s.pc))
dumpint(uint64(s.continpc))
name := funcname(f)
if name == "" {
name = "unknown function"
}
dumpstr(name)
// Dump fields in the outargs section
if child.args.n >= 0 {
dumpbv(&child.args, child.argoff)
} else {
// conservative - everything might be a pointer
for off := child.argoff; off < child.argoff+child.arglen; off += goarch.PtrSize {
dumpint(fieldKindPtr)
dumpint(uint64(off))
}
}
// Dump fields in the local vars section
if stkmap == nil {
// No locals information, dump everything.
for off := child.arglen; off < s.varp-s.sp; off += goarch.PtrSize {
dumpint(fieldKindPtr)
dumpint(uint64(off))
}
} else if stkmap.n < 0 {
// Locals size information, dump just the locals.
size := uintptr(-stkmap.n)
for off := s.varp - size - s.sp; off < s.varp-s.sp; off += goarch.PtrSize {
dumpint(fieldKindPtr)
dumpint(uint64(off))
}
} else if stkmap.n > 0 {
// Locals bitmap information, scan just the pointers in
// locals.
dumpbv(&bv, s.varp-uintptr(bv.n)*goarch.PtrSize-s.sp)
}
dumpint(fieldKindEol)
// Record arg info for parent.
child.argoff = s.argp - s.fp
child.arglen = s.argBytes()
child.sp = (*uint8)(unsafe.Pointer(s.sp))
child.depth++
stkmap = (*stackmap)(funcdata(f, _FUNCDATA_ArgsPointerMaps))
if stkmap != nil {
child.args = stackmapdata(stkmap, pcdata)
} else {
child.args.n = -1
}
return true
}
func dumpgoroutine(gp *g) {
var sp, pc, lr uintptr
if gp.syscallsp != 0 {
sp = gp.syscallsp
pc = gp.syscallpc
lr = 0
} else {
sp = gp.sched.sp
pc = gp.sched.pc
lr = gp.sched.lr
}
dumpint(tagGoroutine)
dumpint(uint64(uintptr(unsafe.Pointer(gp))))
dumpint(uint64(sp))
dumpint(gp.goid)
dumpint(uint64(gp.gopc))
dumpint(uint64(readgstatus(gp)))
dumpbool(isSystemGoroutine(gp, false))
dumpbool(false) // isbackground
dumpint(uint64(gp.waitsince))
dumpstr(gp.waitreason.String())
dumpint(uint64(uintptr(gp.sched.ctxt)))
dumpint(uint64(uintptr(unsafe.Pointer(gp.m))))
dumpint(uint64(uintptr(unsafe.Pointer(gp._defer))))
dumpint(uint64(uintptr(unsafe.Pointer(gp._panic))))
// dump stack
var child childInfo
child.args.n = -1
child.arglen = 0
child.sp = nil
child.depth = 0
gentraceback(pc, sp, lr, gp, 0, nil, 0x7fffffff, dumpframe, noescape(unsafe.Pointer(&child)), 0)
// dump defer & panic records
for d := gp._defer; d != nil; d = d.link {
dumpint(tagDefer)
dumpint(uint64(uintptr(unsafe.Pointer(d))))
dumpint(uint64(uintptr(unsafe.Pointer(gp))))
dumpint(uint64(d.sp))
dumpint(uint64(d.pc))
fn := *(**funcval)(unsafe.Pointer(&d.fn))
dumpint(uint64(uintptr(unsafe.Pointer(fn))))
if d.fn == nil {
// d.fn can be nil for open-coded defers
dumpint(uint64(0))
} else {
dumpint(uint64(uintptr(unsafe.Pointer(fn.fn))))
}
dumpint(uint64(uintptr(unsafe.Pointer(d.link))))
}
for p := gp._panic; p != nil; p = p.link {
dumpint(tagPanic)
dumpint(uint64(uintptr(unsafe.Pointer(p))))
dumpint(uint64(uintptr(unsafe.Pointer(gp))))
eface := efaceOf(&p.arg)
dumpint(uint64(uintptr(unsafe.Pointer(eface._type))))
dumpint(uint64(uintptr(unsafe.Pointer(eface.data))))
dumpint(0) // was p->defer, no longer recorded
dumpint(uint64(uintptr(unsafe.Pointer(p.link))))
}
}
func dumpgs() {
assertWorldStopped()
// goroutines & stacks
forEachG(func(gp *g) {
status := readgstatus(gp) // The world is stopped so gp will not be in a scan state.
switch status {
default:
print("runtime: unexpected G.status ", hex(status), "\n")
throw("dumpgs in STW - bad status")
case _Gdead:
// ok
case _Grunnable,
_Gsyscall,
_Gwaiting:
dumpgoroutine(gp)
}
})
}
func finq_callback(fn *funcval, obj unsafe.Pointer, nret uintptr, fint *_type, ot *ptrtype) {
dumpint(tagQueuedFinalizer)
dumpint(uint64(uintptr(obj)))
dumpint(uint64(uintptr(unsafe.Pointer(fn))))
dumpint(uint64(uintptr(unsafe.Pointer(fn.fn))))
dumpint(uint64(uintptr(unsafe.Pointer(fint))))
dumpint(uint64(uintptr(unsafe.Pointer(ot))))
}
func dumproots() {
// To protect mheap_.allspans.
assertWorldStopped()
// TODO(mwhudson): dump datamask etc from all objects
// data segment
dumpint(tagData)
dumpint(uint64(firstmoduledata.data))
dumpmemrange(unsafe.Pointer(firstmoduledata.data), firstmoduledata.edata-firstmoduledata.data)
dumpfields(firstmoduledata.gcdatamask)
// bss segment
dumpint(tagBSS)
dumpint(uint64(firstmoduledata.bss))
dumpmemrange(unsafe.Pointer(firstmoduledata.bss), firstmoduledata.ebss-firstmoduledata.bss)
dumpfields(firstmoduledata.gcbssmask)
// mspan.types
for _, s := range mheap_.allspans {
if s.state.get() == mSpanInUse {
// Finalizers
for sp := s.specials; sp != nil; sp = sp.next {
if sp.kind != _KindSpecialFinalizer {
continue
}
spf := (*specialfinalizer)(unsafe.Pointer(sp))
p := unsafe.Pointer(s.base() + uintptr(spf.special.offset))
dumpfinalizer(p, spf.fn, spf.fint, spf.ot)
}
}
}
// Finalizer queue
iterate_finq(finq_callback)
}
// Bit vector of free marks.
// Needs to be as big as the largest number of objects per span.
var freemark [_PageSize / 8]bool
func dumpobjs() {
// To protect mheap_.allspans.
assertWorldStopped()
for _, s := range mheap_.allspans {
if s.state.get() != mSpanInUse {
continue
}
p := s.base()
size := s.elemsize
n := (s.npages << _PageShift) / size
if n > uintptr(len(freemark)) {
throw("freemark array doesn't have enough entries")
}
for freeIndex := uintptr(0); freeIndex < s.nelems; freeIndex++ {
if s.isFree(freeIndex) {
freemark[freeIndex] = true
}
}
for j := uintptr(0); j < n; j, p = j+1, p+size {
if freemark[j] {
freemark[j] = false
continue
}
dumpobj(unsafe.Pointer(p), size, makeheapobjbv(p, size))
}
}
}
func dumpparams() {
dumpint(tagParams)
x := uintptr(1)
if *(*byte)(unsafe.Pointer(&x)) == 1 {
dumpbool(false) // little-endian ptrs
} else {
dumpbool(true) // big-endian ptrs
}
dumpint(goarch.PtrSize)
var arenaStart, arenaEnd uintptr
for i1 := range mheap_.arenas {
if mheap_.arenas[i1] == nil {
continue
}
for i, ha := range mheap_.arenas[i1] {
if ha == nil {
continue
}
base := arenaBase(arenaIdx(i1)<<arenaL1Shift | arenaIdx(i))
if arenaStart == 0 || base < arenaStart {
arenaStart = base
}
if base+heapArenaBytes > arenaEnd {
arenaEnd = base + heapArenaBytes
}
}
}
dumpint(uint64(arenaStart))
dumpint(uint64(arenaEnd))
dumpstr(goarch.GOARCH)
dumpstr(buildVersion)
dumpint(uint64(ncpu))
}
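// Illustrative sketch (not part of the dump): the probe above detects byte
// order by storing 1 in a native word and inspecting its first byte. On a
// little-endian machine the low-order byte comes first, so that byte is 1;
// on a big-endian machine it is 0.
func isBigEndianSketch() bool {
x := uintptr(1)
return *(*byte)(unsafe.Pointer(&x)) == 0
}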
func itab_callback(tab *itab) {
t := tab._type
dumptype(t)
dumpint(tagItab)
dumpint(uint64(uintptr(unsafe.Pointer(tab))))
dumpint(uint64(uintptr(unsafe.Pointer(t))))
}
func dumpitabs() {
iterate_itabs(itab_callback)
}
func dumpms() {
for mp := allm; mp != nil; mp = mp.alllink {
dumpint(tagOSThread)
dumpint(uint64(uintptr(unsafe.Pointer(mp))))
dumpint(uint64(mp.id))
dumpint(mp.procid)
}
}
//go:systemstack
func dumpmemstats(m *MemStats) {
assertWorldStopped()
// These ints should be identical to the fields of the
// exported MemStats structure and must be written in the
// same order as they appear there.
dumpint(tagMemStats)
dumpint(m.Alloc)
dumpint(m.TotalAlloc)
dumpint(m.Sys)
dumpint(m.Lookups)
dumpint(m.Mallocs)
dumpint(m.Frees)
dumpint(m.HeapAlloc)
dumpint(m.HeapSys)
dumpint(m.HeapIdle)
dumpint(m.HeapInuse)
dumpint(m.HeapReleased)
dumpint(m.HeapObjects)
dumpint(m.StackInuse)
dumpint(m.StackSys)
dumpint(m.MSpanInuse)
dumpint(m.MSpanSys)
dumpint(m.MCacheInuse)
dumpint(m.MCacheSys)
dumpint(m.BuckHashSys)
dumpint(m.GCSys)
dumpint(m.OtherSys)
dumpint(m.NextGC)
dumpint(m.LastGC)
dumpint(m.PauseTotalNs)
for i := 0; i < 256; i++ {
dumpint(m.PauseNs[i])
}
dumpint(uint64(m.NumGC))
}
func dumpmemprof_callback(b *bucket, nstk uintptr, pstk *uintptr, size, allocs, frees uintptr) {
stk := (*[100000]uintptr)(unsafe.Pointer(pstk))
dumpint(tagMemProf)
dumpint(uint64(uintptr(unsafe.Pointer(b))))
dumpint(uint64(size))
dumpint(uint64(nstk))
for i := uintptr(0); i < nstk; i++ {
pc := stk[i]
f := findfunc(pc)
if !f.valid() {
var buf [64]byte
n := len(buf)
n--
buf[n] = ')'
if pc == 0 {
n--
buf[n] = '0'
} else {
for pc > 0 {
n--
buf[n] = "0123456789abcdef"[pc&15]
pc >>= 4
}
}
n--
buf[n] = 'x'
n--
buf[n] = '0'
n--
buf[n] = '('
dumpslice(buf[n:])
dumpstr("?")
dumpint(0)
} else {
dumpstr(funcname(f))
if i > 0 && pc > f.entry() {
pc--
}
file, line := funcline(f, pc)
dumpstr(file)
dumpint(uint64(line))
}
}
dumpint(uint64(allocs))
dumpint(uint64(frees))
}
func dumpmemprof() {
// To protect mheap_.allspans.
assertWorldStopped()
iterate_memprof(dumpmemprof_callback)
for _, s := range mheap_.allspans {
if s.state.get() != mSpanInUse {
continue
}
for sp := s.specials; sp != nil; sp = sp.next {
if sp.kind != _KindSpecialProfile {
continue
}
spp := (*specialprofile)(unsafe.Pointer(sp))
p := s.base() + uintptr(spp.special.offset)
dumpint(tagAllocSample)
dumpint(uint64(p))
dumpint(uint64(uintptr(unsafe.Pointer(spp.b))))
}
}
}
var dumphdr = []byte("go1.7 heap dump\n")
func mdump(m *MemStats) {
assertWorldStopped()
// make sure we're done sweeping
for _, s := range mheap_.allspans {
if s.state.get() == mSpanInUse {
s.ensureSwept()
}
}
memclrNoHeapPointers(unsafe.Pointer(&typecache), unsafe.Sizeof(typecache))
dwrite(unsafe.Pointer(&dumphdr[0]), uintptr(len(dumphdr)))
dumpparams()
dumpitabs()
dumpobjs()
dumpgs()
dumpms()
dumproots()
dumpmemstats(m)
dumpmemprof()
dumpint(tagEOF)
flush()
}
func writeheapdump_m(fd uintptr, m *MemStats) {
assertWorldStopped()
gp := getg()
casGToWaiting(gp.m.curg, _Grunning, waitReasonDumpingHeap)
// Set dump file.
dumpfd = fd
// Call dump routine.
mdump(m)
// Reset dump file.
dumpfd = 0
if tmpbuf != nil {
sysFree(unsafe.Pointer(&tmpbuf[0]), uintptr(len(tmpbuf)), &memstats.other_sys)
tmpbuf = nil
}
casgstatus(gp.m.curg, _Gwaiting, _Grunning)
}
// dumpint() the kind & offset of each field in an object.
func dumpfields(bv bitvector) {
dumpbv(&bv, 0)
dumpint(fieldKindEol)
}
func makeheapobjbv(p uintptr, size uintptr) bitvector {
// Extend the temp buffer if necessary.
nptr := size / goarch.PtrSize
if uintptr(len(tmpbuf)) < nptr/8+1 {
if tmpbuf != nil {
sysFree(unsafe.Pointer(&tmpbuf[0]), uintptr(len(tmpbuf)), &memstats.other_sys)
}
n := nptr/8 + 1
p := sysAlloc(n, &memstats.other_sys)
if p == nil {
throw("heapdump: out of memory")
}
tmpbuf = (*[1 << 30]byte)(p)[:n]
}
// Convert heap bitmap to pointer bitmap.
for i := uintptr(0); i < nptr/8+1; i++ {
tmpbuf[i] = 0
}
hbits := heapBitsForAddr(p, size)
for {
var addr uintptr
hbits, addr = hbits.next()
if addr == 0 {
break
}
i := (addr - p) / goarch.PtrSize
tmpbuf[i/8] |= 1 << (i % 8)
}
return bitvector{int32(nptr), &tmpbuf[0]}
}
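// Worked example (illustrative): for a 32-byte object on a 64-bit platform,
// nptr = 4 words. If the heap bitmap reports pointer words at offsets 0 and
// 2, the loop sets bits 0 and 2 of tmpbuf[0] (0b00000101), and the returned
// bitvector covers those 4 words.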
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"runtime/internal/atomic"
"runtime/internal/sys"
"unsafe"
)
const (
// For the time histogram type, we use an HDR histogram.
// Values are placed in buckets based solely on the most
// significant set bit. Thus, buckets are power-of-2 sized.
// Values are then placed into sub-buckets based on the value of
// the next timeHistSubBucketBits most significant bits. Thus,
// sub-buckets are linear within a bucket.
//
// Therefore, the number of sub-buckets (timeHistNumSubBuckets)
// defines the error. This error may be computed as
// 1/timeHistNumSubBuckets*100%. For example, for 16 sub-buckets
// per bucket the error is approximately 6%.
//
// The number of buckets (timeHistNumBuckets), on the
// other hand, defines the range. To avoid producing a large number
// of buckets that are close together, especially for small numbers
// (e.g. 1, 2, 3, 4, 5 ns) that aren't very useful, timeHistNumBuckets
// is defined in terms of the least significant bit (timeHistMinBucketBits)
// that needs to be set before we start bucketing and the most
// significant bit (timeHistMaxBucketBits) that we bucket before we just
// dump it into a catch-all bucket.
//
// As an example, consider the configuration:
//
// timeHistMinBucketBits = 9
// timeHistMaxBucketBits = 48
// timeHistSubBucketBits = 2
//
// Then:
//
// 011000001
// ^--
// │ ^
// │ └---- Next 2 bits -> sub-bucket 3
// └------- Bit 9 unset -> bucket 0
//
// 110000001
// ^--
// │ ^
// │ └---- Next 2 bits -> sub-bucket 2
// └------- Bit 9 set -> bucket 1
//
// 1000000010
// ^-- ^
// │ ^ └-- Lower bits ignored
// │ └---- Next 2 bits -> sub-bucket 0
// └------- Bit 10 set -> bucket 2
//
// Following this pattern, the last bucket, bucket 39, will have bit 47 set.
// We don't have any buckets for higher values, so we spill the rest into an
// overflow bucket containing values of 2^47 nanoseconds (approx. 1.6 days) or more.
// This range is more than enough to handle durations produced by the runtime.
timeHistMinBucketBits = 9
timeHistMaxBucketBits = 48 // Note that this is exclusive; 1 higher than the actual range.
timeHistSubBucketBits = 2
timeHistNumSubBuckets = 1 << timeHistSubBucketBits
timeHistNumBuckets = timeHistMaxBucketBits - timeHistMinBucketBits + 1
// Two extra buckets, one for underflow, one for overflow.
timeHistTotalBuckets = timeHistNumBuckets*timeHistNumSubBuckets + 2
)
// timeHistogram represents a distribution of durations in
// nanoseconds.
//
// The accuracy and range of the histogram is defined by the
// timeHistSubBucketBits and timeHistNumBuckets constants.
//
// It is an HDR histogram with exponentially-distributed
// buckets and linearly distributed sub-buckets.
//
// The histogram is safe for concurrent reads and writes.
type timeHistogram struct {
counts [timeHistNumBuckets * timeHistNumSubBuckets]atomic.Uint64
// underflow counts all the times we got a negative duration
// sample. Because of how time works on some platforms, it's
// possible to measure negative durations. We could ignore them,
// but we record them anyway because it's better to have some
// signal that it's happening than just missing samples.
underflow atomic.Uint64
// overflow counts all the times we got a duration that exceeded
// the range counts represents.
overflow atomic.Uint64
}
// record adds the given duration to the distribution.
//
// Disallow preemptions and stack growths because this function
// may run in sensitive locations.
//
//go:nosplit
func (h *timeHistogram) record(duration int64) {
// If the duration is negative, capture that in underflow.
if duration < 0 {
h.underflow.Add(1)
return
}
// bucketBit is the target bit for the bucket, which is usually the
// highest 1 bit; but if the duration is less than the minimum, it is
// the highest 1 bit of the minimum (which will be zero in the duration).
//
// bucket is the bucket index, which is the bucketBit minus the
// highest bit of the minimum, plus one to leave room for the catch-all
// bucket for samples lower than the minimum.
var bucketBit, bucket uint
if l := sys.Len64(uint64(duration)); l < timeHistMinBucketBits {
bucketBit = timeHistMinBucketBits
bucket = 0 // bucketBit - timeHistMinBucketBits
} else {
bucketBit = uint(l)
bucket = bucketBit - timeHistMinBucketBits + 1
}
// If the bucket we computed is greater than the number of buckets,
// count that in overflow.
if bucket >= timeHistNumBuckets {
h.overflow.Add(1)
return
}
// The sub-bucket index is just the next timeHistSubBucketBits bits after the bucketBit.
subBucket := uint(duration>>(bucketBit-1-timeHistSubBucketBits)) % timeHistNumSubBuckets
h.counts[bucket*timeHistNumSubBuckets+subBucket].Add(1)
}
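// exampleBucketIndex is an illustrative sketch (not used by the runtime): it
// recomputes the flat index into counts that record uses for a non-negative,
// non-overflowing duration. For example, a 1000ns sample has bit length 10,
// so bucket = 10 - 9 + 1 = 2 and sub-bucket = (1000>>7)%4 = 3, giving
// counts index 2*4 + 3 = 11.
func exampleBucketIndex(duration int64) uint {
var bucketBit, bucket uint
if l := sys.Len64(uint64(duration)); l < timeHistMinBucketBits {
bucketBit = timeHistMinBucketBits
bucket = 0
} else {
bucketBit = uint(l)
bucket = bucketBit - timeHistMinBucketBits + 1
}
subBucket := uint(duration>>(bucketBit-1-timeHistSubBucketBits)) % timeHistNumSubBuckets
return bucket*timeHistNumSubBuckets + subBucket
}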
const (
fInf = 0x7FF0000000000000
fNegInf = 0xFFF0000000000000
)
func float64Inf() float64 {
inf := uint64(fInf)
return *(*float64)(unsafe.Pointer(&inf))
}
func float64NegInf() float64 {
inf := uint64(fNegInf)
return *(*float64)(unsafe.Pointer(&inf))
}
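// Note (illustrative): 0x7FF0000000000000 and 0xFFF0000000000000 are the
// IEEE 754 bit patterns for +Inf and -Inf (all-ones exponent, zero mantissa,
// sign bit clear/set). Outside the runtime the same values would come from
// math.Inf(1) and math.Inf(-1); the unsafe reinterpretation here avoids a
// dependency on the math package.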
// timeHistogramMetricsBuckets generates a slice of boundaries for
// the timeHistogram. These boundaries are represented in seconds,
// not nanoseconds like the timeHistogram represents durations.
func timeHistogramMetricsBuckets() []float64 {
b := make([]float64, timeHistTotalBuckets+1)
// Underflow bucket.
b[0] = float64NegInf()
for j := 0; j < timeHistNumSubBuckets; j++ {
// No bucket bit for the first few buckets. Just sub-bucket bits after the
// min bucket bit.
bucketNanos := uint64(j) << (timeHistMinBucketBits - 1 - timeHistSubBucketBits)
// Convert nanoseconds to seconds via a division.
// These values will all be exactly representable by a float64.
b[j+1] = float64(bucketNanos) / 1e9
}
// Generate the rest of the buckets. It's easier to reason
// about if we cut out the 0'th bucket.
for i := timeHistMinBucketBits; i < timeHistMaxBucketBits; i++ {
for j := 0; j < timeHistNumSubBuckets; j++ {
// Set the bucket bit.
bucketNanos := uint64(1) << (i - 1)
// Set the sub-bucket bits.
bucketNanos |= uint64(j) << (i - 1 - timeHistSubBucketBits)
// The index for this bucket is going to be the (i+1)'th bucket
// (note that we're starting from zero, but handled the first bucket
// earlier, so we need to compensate), and the j'th sub bucket.
// Add 1 because we left space for -Inf.
bucketIndex := (i-timeHistMinBucketBits+1)*timeHistNumSubBuckets + j + 1
// Convert nanoseconds to seconds via a division.
// These values will all be exactly representable by a float64.
b[bucketIndex] = float64(bucketNanos) / 1e9
}
}
// Overflow bucket.
b[len(b)-2] = float64(uint64(1)<<(timeHistMaxBucketBits-1)) / 1e9
b[len(b)-1] = float64Inf()
return b
}
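// Worked example (illustrative): with the configuration above, the first few
// boundaries are b[0] = -Inf, then b[1..4] = 0ns, 64ns, 128ns and 192ns
// expressed in seconds, since j << (9-1-2) = j*64 for j = 0..3.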
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"internal/abi"
"internal/goarch"
"runtime/internal/atomic"
"unsafe"
)
const itabInitSize = 512
var (
itabLock mutex // lock for accessing itab table
itabTable = &itabTableInit // pointer to current table
itabTableInit = itabTableType{size: itabInitSize} // starter table
)
// Note: change the formula in the mallocgc call in itabAdd if you change these fields.
type itabTableType struct {
size uintptr // length of entries array. Always a power of 2.
count uintptr // current number of filled entries.
entries [itabInitSize]*itab // really [size] large
}
func itabHashFunc(inter *interfacetype, typ *_type) uintptr {
// compiler has provided some good hash codes for us.
return uintptr(inter.typ.hash ^ typ.hash)
}
func getitab(inter *interfacetype, typ *_type, canfail bool) *itab {
if len(inter.mhdr) == 0 {
throw("internal error - misuse of itab")
}
// easy case
if typ.tflag&tflagUncommon == 0 {
if canfail {
return nil
}
name := inter.typ.nameOff(inter.mhdr[0].name)
panic(&TypeAssertionError{nil, typ, &inter.typ, name.name()})
}
var m *itab
// First, look in the existing table to see if we can find the itab we need.
// This is by far the most common case, so do it without locks.
// Use atomic to ensure we see any previous writes done by the thread
// that updates the itabTable field (with atomic.Storep in itabAdd).
t := (*itabTableType)(atomic.Loadp(unsafe.Pointer(&itabTable)))
if m = t.find(inter, typ); m != nil {
goto finish
}
// Not found. Grab the lock and try again.
lock(&itabLock)
if m = itabTable.find(inter, typ); m != nil {
unlock(&itabLock)
goto finish
}
// Entry doesn't exist yet. Make a new entry & add it.
m = (*itab)(persistentalloc(unsafe.Sizeof(itab{})+uintptr(len(inter.mhdr)-1)*goarch.PtrSize, 0, &memstats.other_sys))
m.inter = inter
m._type = typ
// The hash is used in type switches. However, the compiler statically
// generates itabs for all interface/type pairs used in switches (these are
// added to itabTable in itabsinit). The dynamically-generated itabs never
// participate in type switches, and thus the hash is irrelevant.
// Note: m.hash is _not_ the hash used for the runtime itabTable hash table.
m.hash = 0
m.init()
itabAdd(m)
unlock(&itabLock)
finish:
if m.fun[0] != 0 {
return m
}
if canfail {
return nil
}
// this can only happen if the conversion
// was already done once using the , ok form
// and we have a cached negative result.
// The cached result doesn't record which
// interface function was missing, so initialize
// the itab again to get the missing function name.
panic(&TypeAssertionError{concrete: typ, asserted: &inter.typ, missingMethod: m.init()})
}
// find finds the given interface/type pair in t.
// Returns nil if the given interface/type pair isn't present.
func (t *itabTableType) find(inter *interfacetype, typ *_type) *itab {
// Implemented using quadratic probing.
// Probe sequence is h(i) = h0 + i*(i+1)/2 mod 2^k.
// We're guaranteed to hit all table entries using this probe sequence.
mask := t.size - 1
h := itabHashFunc(inter, typ) & mask
for i := uintptr(1); ; i++ {
p := (**itab)(add(unsafe.Pointer(&t.entries), h*goarch.PtrSize))
// Use atomic read here so if we see m != nil, we also see
// the initializations of the fields of m.
// m := *p
m := (*itab)(atomic.Loadp(unsafe.Pointer(p)))
if m == nil {
return nil
}
if m.inter == inter && m._type == typ {
return m
}
h += i
h &= mask
}
}
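// exampleProbeSequence is an illustrative sketch (not used by the runtime)
// of the quadratic probe order that find and add follow: h0, h0+1, h0+1+2,
// ... — that is, h(i) = h0 + i*(i+1)/2 mod 2^k. For a power-of-two table
// size the triangular-number offsets visit every slot exactly once within
// size steps.
func exampleProbeSequence(h0, size uintptr) []uintptr {
mask := size - 1 // size must be a power of two
seq := make([]uintptr, 0, size)
h := h0 & mask
for i := uintptr(1); uintptr(len(seq)) < size; i++ {
seq = append(seq, h)
h = (h + i) & mask
}
return seq
}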
// itabAdd adds the given itab to the itab hash table.
// itabLock must be held.
func itabAdd(m *itab) {
// Bugs can lead to calling this while mallocing is set,
// typically because this is called while panicking.
// Crash reliably, rather than only when we need to grow
// the hash table.
if getg().m.mallocing != 0 {
throw("malloc deadlock")
}
t := itabTable
if t.count >= 3*(t.size/4) { // 75% load factor
// Grow hash table.
// t2 = new(itabTableType) + some additional entries
// We lie and tell malloc we want pointer-free memory because
// all the pointed-to values are not in the heap.
t2 := (*itabTableType)(mallocgc((2+2*t.size)*goarch.PtrSize, nil, true))
t2.size = t.size * 2
// Copy over entries.
// Note: while copying, other threads may look for an itab and
// fail to find it. That's ok, they will then try to get the itab lock
// and as a consequence wait until this copying is complete.
iterate_itabs(t2.add)
if t2.count != t.count {
throw("mismatched count during itab table copy")
}
// Publish new hash table. Use an atomic write: see comment in getitab.
atomicstorep(unsafe.Pointer(&itabTable), unsafe.Pointer(t2))
// Adopt the new table as our own.
t = itabTable
// Note: the old table can be GC'ed here.
}
t.add(m)
}
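// Sizing note (illustrative): itabTableType is laid out as two uintptr
// fields (size, count) followed by the entries array, so a table with
// 2*t.size slots needs (2+2*t.size)*goarch.PtrSize bytes — e.g. growing the
// 512-entry starter table on a 64-bit platform allocates (2+1024)*8 = 8208
// bytes.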
// add adds the given itab to itab table t.
// itabLock must be held.
func (t *itabTableType) add(m *itab) {
// See comment in find about the probe sequence.
// Insert new itab in the first empty spot in the probe sequence.
mask := t.size - 1
h := itabHashFunc(m.inter, m._type) & mask
for i := uintptr(1); ; i++ {
p := (**itab)(add(unsafe.Pointer(&t.entries), h*goarch.PtrSize))
m2 := *p
if m2 == m {
// A given itab may be used in more than one module
// and thanks to the way global symbol resolution works, the
// pointed-to itab may already have been inserted into the
// global 'hash'.
return
}
if m2 == nil {
// Use atomic write here so if a reader sees m, it also
// sees the correctly initialized fields of m.
// NoWB is ok because m is not in heap memory.
// *p = m
atomic.StorepNoWB(unsafe.Pointer(p), unsafe.Pointer(m))
t.count++
return
}
h += i
h &= mask
}
}
// init fills in the m.fun array with all the code pointers for
// the m.inter/m._type pair. If the type does not implement the interface,
// it sets m.fun[0] to 0 and returns the name of an interface function that is missing.
// It is ok to call this multiple times on the same m, even concurrently.
func (m *itab) init() string {
inter := m.inter
typ := m._type
x := typ.uncommon()
// Both inter and typ have their methods sorted by name,
// and interface method names are unique,
// so we can iterate over both in lock step;
// the loop is O(ni+nt), not O(ni*nt).
ni := len(inter.mhdr)
nt := int(x.mcount)
xmhdr := (*[1 << 16]method)(add(unsafe.Pointer(x), uintptr(x.moff)))[:nt:nt]
j := 0
methods := (*[1 << 16]unsafe.Pointer)(unsafe.Pointer(&m.fun[0]))[:ni:ni]
var fun0 unsafe.Pointer
imethods:
for k := 0; k < ni; k++ {
i := &inter.mhdr[k]
itype := inter.typ.typeOff(i.ityp)
name := inter.typ.nameOff(i.name)
iname := name.name()
ipkg := name.pkgPath()
if ipkg == "" {
ipkg = inter.pkgpath.name()
}
for ; j < nt; j++ {
t := &xmhdr[j]
tname := typ.nameOff(t.name)
if typ.typeOff(t.mtyp) == itype && tname.name() == iname {
pkgPath := tname.pkgPath()
if pkgPath == "" {
pkgPath = typ.nameOff(x.pkgpath).name()
}
if tname.isExported() || pkgPath == ipkg {
if m != nil {
ifn := typ.textOff(t.ifn)
if k == 0 {
fun0 = ifn // we'll set m.fun[0] at the end
} else {
methods[k] = ifn
}
}
continue imethods
}
}
}
// didn't find method
m.fun[0] = 0
return iname
}
m.fun[0] = uintptr(fun0)
return ""
}
func itabsinit() {
lockInit(&itabLock, lockRankItab)
lock(&itabLock)
for _, md := range activeModules() {
for _, i := range md.itablinks {
itabAdd(i)
}
}
unlock(&itabLock)
}
// panicdottypeE is called when doing an e.(T) conversion and the conversion fails.
// have = the dynamic type we have.
// want = the static type we're trying to convert to.
// iface = the static type we're converting from.
func panicdottypeE(have, want, iface *_type) {
panic(&TypeAssertionError{iface, have, want, ""})
}
// panicdottypeI is called when doing an i.(T) conversion and the conversion fails.
// Same args as panicdottypeE, but "have" is the dynamic itab we have.
func panicdottypeI(have *itab, want, iface *_type) {
var t *_type
if have != nil {
t = have._type
}
panicdottypeE(t, want, iface)
}
// panicnildottype is called when doing a i.(T) conversion and the interface i is nil.
// want = the static type we're trying to convert to.
func panicnildottype(want *_type) {
panic(&TypeAssertionError{nil, nil, want, ""})
// TODO: Add the static type we're converting from as well.
// It might generate a better error message.
// Just to match other nil conversion errors, we don't for now.
}
// The specialized convTx routines need a type descriptor to use when calling mallocgc.
// We don't need the type to be exact, just to have the correct size, alignment, and pointer-ness.
// However, when debugging, it'd be nice to have some indication in mallocgc where the types came from,
// so we use named types here.
// We then construct interface values of these types,
// and then extract the type word to use as needed.
type (
uint16InterfacePtr uint16
uint32InterfacePtr uint32
uint64InterfacePtr uint64
stringInterfacePtr string
sliceInterfacePtr []byte
)
var (
uint16Eface any = uint16InterfacePtr(0)
uint32Eface any = uint32InterfacePtr(0)
uint64Eface any = uint64InterfacePtr(0)
stringEface any = stringInterfacePtr("")
sliceEface any = sliceInterfacePtr(nil)
uint16Type *_type = efaceOf(&uint16Eface)._type
uint32Type *_type = efaceOf(&uint32Eface)._type
uint64Type *_type = efaceOf(&uint64Eface)._type
stringType *_type = efaceOf(&stringEface)._type
sliceType *_type = efaceOf(&sliceEface)._type
)
// The conv and assert functions below do very similar things.
// The convXXX functions are guaranteed by the compiler to succeed.
// The assertXXX functions may fail (either panicking or returning false,
// depending on whether they are 1-result or 2-result).
// The convXXX functions succeed on a nil input, whereas the assertXXX
// functions fail on a nil input.
// convT converts a value of type t, which is pointed to by v, to a pointer that can
// be used as the second word of an interface value.
func convT(t *_type, v unsafe.Pointer) unsafe.Pointer {
if raceenabled {
raceReadObjectPC(t, v, getcallerpc(), abi.FuncPCABIInternal(convT))
}
if msanenabled {
msanread(v, t.size)
}
if asanenabled {
asanread(v, t.size)
}
x := mallocgc(t.size, t, true)
typedmemmove(t, x, v)
return x
}
func convTnoptr(t *_type, v unsafe.Pointer) unsafe.Pointer {
// TODO: maybe take size instead of type?
if raceenabled {
raceReadObjectPC(t, v, getcallerpc(), abi.FuncPCABIInternal(convTnoptr))
}
if msanenabled {
msanread(v, t.size)
}
if asanenabled {
asanread(v, t.size)
}
x := mallocgc(t.size, t, false)
memmove(x, v, t.size)
return x
}
func convT16(val uint16) (x unsafe.Pointer) {
if val < uint16(len(staticuint64s)) {
x = unsafe.Pointer(&staticuint64s[val])
if goarch.BigEndian {
x = add(x, 6)
}
} else {
x = mallocgc(2, uint16Type, false)
*(*uint16)(x) = val
}
return
}
func convT32(val uint32) (x unsafe.Pointer) {
if val < uint32(len(staticuint64s)) {
x = unsafe.Pointer(&staticuint64s[val])
if goarch.BigEndian {
x = add(x, 4)
}
} else {
x = mallocgc(4, uint32Type, false)
*(*uint32)(x) = val
}
return
}
func convT64(val uint64) (x unsafe.Pointer) {
if val < uint64(len(staticuint64s)) {
x = unsafe.Pointer(&staticuint64s[val])
} else {
x = mallocgc(8, uint64Type, false)
*(*uint64)(x) = val
}
return
}
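// Note on the small-value fast paths above (illustrative): values below 256
// share entries of the static staticuint64s table rather than allocating.
// Each entry is an 8-byte word, so on big-endian platforms the 2-byte and
// 4-byte payloads live in the last bytes of the word, hence the add(x, 6)
// and add(x, 4) adjustments; on little-endian platforms the low-order bytes
// come first and no offset is needed.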
func convTstring(val string) (x unsafe.Pointer) {
if val == "" {
x = unsafe.Pointer(&zeroVal[0])
} else {
x = mallocgc(unsafe.Sizeof(val), stringType, true)
*(*string)(x) = val
}
return
}
func convTslice(val []byte) (x unsafe.Pointer) {
// Note: this must work for any element type, not just byte.
if (*slice)(unsafe.Pointer(&val)).array == nil {
x = unsafe.Pointer(&zeroVal[0])
} else {
x = mallocgc(unsafe.Sizeof(val), sliceType, true)
*(*[]byte)(x) = val
}
return
}
// convI2I returns the new itab to be used for the destination value
// when converting a value with itab src to the dst interface.
func convI2I(dst *interfacetype, src *itab) *itab {
if src == nil {
return nil
}
if src.inter == dst {
return src
}
return getitab(dst, src._type, false)
}
func assertI2I(inter *interfacetype, tab *itab) *itab {
if tab == nil {
// explicit conversions require non-nil interface value.
panic(&TypeAssertionError{nil, nil, &inter.typ, ""})
}
if tab.inter == inter {
return tab
}
return getitab(inter, tab._type, false)
}
func assertI2I2(inter *interfacetype, i iface) (r iface) {
tab := i.tab
if tab == nil {
return
}
if tab.inter != inter {
tab = getitab(inter, tab._type, true)
if tab == nil {
return
}
}
r.tab = tab
r.data = i.data
return
}
func assertE2I(inter *interfacetype, t *_type) *itab {
if t == nil {
// explicit conversions require non-nil interface value.
panic(&TypeAssertionError{nil, nil, &inter.typ, ""})
}
return getitab(inter, t, false)
}
func assertE2I2(inter *interfacetype, e eface) (r iface) {
t := e._type
if t == nil {
return
}
tab := getitab(inter, t, true)
if tab == nil {
return
}
r.tab = tab
r.data = e.data
return
}
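// Illustrative mapping (sketch, not generated code): for an interface value
// i and a target interface type I, the compiler lowers
//
// v := i.(I)     // may panic        -> assertI2I / assertE2I
// v, ok := i.(I) // reports failure  -> assertI2I2 / assertE2I2
//
// while conversions that are statically known to succeed, such as assigning
// a concrete value to a variable of type I, go through the conv functions.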
//go:linkname reflect_ifaceE2I reflect.ifaceE2I
func reflect_ifaceE2I(inter *interfacetype, e eface, dst *iface) {
*dst = iface{assertE2I(inter, e._type), e.data}
}
//go:linkname reflectlite_ifaceE2I internal/reflectlite.ifaceE2I
func reflectlite_ifaceE2I(inter *interfacetype, e eface, dst *iface) {
*dst = iface{assertE2I(inter, e._type), e.data}
}
func iterate_itabs(fn func(*itab)) {
// Note: only runs during stop the world or with itabLock held,
// so no other locks/atomics needed.
t := itabTable
for i := uintptr(0); i < t.size; i++ {
m := *(**itab)(add(unsafe.Pointer(&t.entries), i*goarch.PtrSize))
if m != nil {
fn(m)
}
}
}
// staticuint64s is used to avoid allocating in convTx for small integer values.
var staticuint64s = [...]uint64{
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f,
0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47,
0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f,
0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57,
0x58, 0x59, 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f,
0x60, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67,
0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f,
0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77,
0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f,
0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87,
0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f,
0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97,
0x98, 0x99, 0x9a, 0x9b, 0x9c, 0x9d, 0x9e, 0x9f,
0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7,
0xa8, 0xa9, 0xaa, 0xab, 0xac, 0xad, 0xae, 0xaf,
0xb0, 0xb1, 0xb2, 0xb3, 0xb4, 0xb5, 0xb6, 0xb7,
0xb8, 0xb9, 0xba, 0xbb, 0xbc, 0xbd, 0xbe, 0xbf,
0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7,
0xc8, 0xc9, 0xca, 0xcb, 0xcc, 0xcd, 0xce, 0xcf,
0xd0, 0xd1, 0xd2, 0xd3, 0xd4, 0xd5, 0xd6, 0xd7,
0xd8, 0xd9, 0xda, 0xdb, 0xdc, 0xdd, 0xde, 0xdf,
0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7,
0xe8, 0xe9, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xef,
0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7,
0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff,
}
// The linker redirects a reference to a method that it determined
// unreachable to a reference to this function, so it will throw if
// ever called.
func unreachableMethod() {
throw("unreachable method called. linker bug?")
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package atomic
import "unsafe"
// Export some functions via linkname to assembly in sync/atomic.
//
//go:linkname Load
//go:linkname Loadp
//go:linkname Load64
//go:nosplit
//go:noinline
func Load(ptr *uint32) uint32 {
return *ptr
}
//go:nosplit
//go:noinline
func Loadp(ptr unsafe.Pointer) unsafe.Pointer {
return *(*unsafe.Pointer)(ptr)
}
//go:nosplit
//go:noinline
func Load64(ptr *uint64) uint64 {
return *ptr
}
//go:nosplit
//go:noinline
func LoadAcq(ptr *uint32) uint32 {
return *ptr
}
//go:nosplit
//go:noinline
func LoadAcq64(ptr *uint64) uint64 {
return *ptr
}
//go:nosplit
//go:noinline
func LoadAcquintptr(ptr *uintptr) uintptr {
return *ptr
}
//go:noescape
func Xadd(ptr *uint32, delta int32) uint32
//go:noescape
func Xadd64(ptr *uint64, delta int64) uint64
//go:noescape
func Xadduintptr(ptr *uintptr, delta uintptr) uintptr
//go:noescape
func Xchg(ptr *uint32, new uint32) uint32
//go:noescape
func Xchg64(ptr *uint64, new uint64) uint64
//go:noescape
func Xchguintptr(ptr *uintptr, new uintptr) uintptr
//go:nosplit
//go:noinline
func Load8(ptr *uint8) uint8 {
return *ptr
}
//go:noescape
func And8(ptr *uint8, val uint8)
//go:noescape
func Or8(ptr *uint8, val uint8)
//go:noescape
func And(ptr *uint32, val uint32)
//go:noescape
func Or(ptr *uint32, val uint32)
// NOTE: Do not add atomicxor8 (XOR is not idempotent).
//go:noescape
func Cas64(ptr *uint64, old, new uint64) bool
//go:noescape
func CasRel(ptr *uint32, old, new uint32) bool
//go:noescape
func Store(ptr *uint32, val uint32)
//go:noescape
func Store8(ptr *uint8, val uint8)
//go:noescape
func Store64(ptr *uint64, val uint64)
//go:noescape
func StoreRel(ptr *uint32, val uint32)
//go:noescape
func StoreRel64(ptr *uint64, val uint64)
//go:noescape
func StoreReluintptr(ptr *uintptr, val uintptr)
// StorepNoWB performs *ptr = val atomically and without a write
// barrier.
//
// NO go:noescape annotation; see atomic_pointer.go.
func StorepNoWB(ptr unsafe.Pointer, val unsafe.Pointer)
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package atomic
import "unsafe"
// Int32 is an atomically accessed int32 value.
//
// An Int32 must not be copied.
type Int32 struct {
noCopy noCopy
value int32
}
// Load accesses and returns the value atomically.
//
//go:nosplit
func (i *Int32) Load() int32 {
return Loadint32(&i.value)
}
// Store updates the value atomically.
//
//go:nosplit
func (i *Int32) Store(value int32) {
Storeint32(&i.value, value)
}
// CompareAndSwap atomically compares i's value with old,
// and if they're equal, swaps i's value with new.
// It reports whether the swap ran.
//
//go:nosplit
func (i *Int32) CompareAndSwap(old, new int32) bool {
return Casint32(&i.value, old, new)
}
// Swap replaces i's value with new, returning
// i's value before the replacement.
//
//go:nosplit
func (i *Int32) Swap(new int32) int32 {
return Xchgint32(&i.value, new)
}
// Add adds delta to i atomically, returning
// the new updated value.
//
// This operation wraps around in the usual
// two's-complement way.
//
//go:nosplit
func (i *Int32) Add(delta int32) int32 {
return Xaddint32(&i.value, delta)
}
// Int64 is an atomically accessed int64 value.
//
// 8-byte aligned on all platforms, unlike a regular int64.
//
// An Int64 must not be copied.
type Int64 struct {
noCopy noCopy
_ align64
value int64
}
// Load accesses and returns the value atomically.
//
//go:nosplit
func (i *Int64) Load() int64 {
return Loadint64(&i.value)
}
// Store updates the value atomically.
//
//go:nosplit
func (i *Int64) Store(value int64) {
Storeint64(&i.value, value)
}
// CompareAndSwap atomically compares i's value with old,
// and if they're equal, swaps i's value with new.
// It reports whether the swap ran.
//
//go:nosplit
func (i *Int64) CompareAndSwap(old, new int64) bool {
return Casint64(&i.value, old, new)
}
// Swap replaces i's value with new, returning
// i's value before the replacement.
//
//go:nosplit
func (i *Int64) Swap(new int64) int64 {
return Xchgint64(&i.value, new)
}
// Add adds delta to i atomically, returning
// the new updated value.
//
// This operation wraps around in the usual
// two's-complement way.
//
//go:nosplit
func (i *Int64) Add(delta int64) int64 {
return Xaddint64(&i.value, delta)
}
// Uint8 is an atomically accessed uint8 value.
//
// A Uint8 must not be copied.
type Uint8 struct {
noCopy noCopy
value uint8
}
// Load accesses and returns the value atomically.
//
//go:nosplit
func (u *Uint8) Load() uint8 {
return Load8(&u.value)
}
// Store updates the value atomically.
//
//go:nosplit
func (u *Uint8) Store(value uint8) {
Store8(&u.value, value)
}
// And takes value and performs a bit-wise
// "and" operation with the value of u, storing
// the result into u.
//
// The full process is performed atomically.
//
//go:nosplit
func (u *Uint8) And(value uint8) {
And8(&u.value, value)
}
// Or takes value and performs a bit-wise
// "or" operation with the value of u, storing
// the result into u.
//
// The full process is performed atomically.
//
//go:nosplit
func (u *Uint8) Or(value uint8) {
Or8(&u.value, value)
}
// Bool is an atomically accessed bool value.
//
// A Bool must not be copied.
type Bool struct {
// Inherits noCopy from Uint8.
u Uint8
}
// Load accesses and returns the value atomically.
//
//go:nosplit
func (b *Bool) Load() bool {
return b.u.Load() != 0
}
// Store updates the value atomically.
//
//go:nosplit
func (b *Bool) Store(value bool) {
s := uint8(0)
if value {
s = 1
}
b.u.Store(s)
}
// Uint32 is an atomically accessed uint32 value.
//
// A Uint32 must not be copied.
type Uint32 struct {
noCopy noCopy
value uint32
}
// Load accesses and returns the value atomically.
//
//go:nosplit
func (u *Uint32) Load() uint32 {
return Load(&u.value)
}
// LoadAcquire is a partially unsynchronized version
// of Load that relaxes ordering constraints. Other threads
// may observe operations that precede this operation to
// occur after it, but no operation that occurs after it
// on this thread can be observed to occur before it.
//
// WARNING: Use sparingly and with great care.
//
//go:nosplit
func (u *Uint32) LoadAcquire() uint32 {
return LoadAcq(&u.value)
}
// Store updates the value atomically.
//
//go:nosplit
func (u *Uint32) Store(value uint32) {
Store(&u.value, value)
}
// StoreRelease is a partially unsynchronized version
// of Store that relaxes ordering constraints. Other threads
// may observe operations that occur after this operation to
// precede it, but no operation that precedes it
// on this thread can be observed to occur after it.
//
// WARNING: Use sparingly and with great care.
//
//go:nosplit
func (u *Uint32) StoreRelease(value uint32) {
StoreRel(&u.value, value)
}
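// Illustrative pairing (sketch): a writer that fills in shared data and then
// publishes it with flag.StoreRelease(1) guarantees that a reader observing
// flag.LoadAcquire() == 1 also observes the writes made before the store.
// No ordering is promised for the writer's later operations or the reader's
// earlier ones.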
// CompareAndSwap atomically compares u's value with old,
// and if they're equal, swaps u's value with new.
// It reports whether the swap ran.
//
//go:nosplit
func (u *Uint32) CompareAndSwap(old, new uint32) bool {
return Cas(&u.value, old, new)
}
// CompareAndSwapRelease is a partially unsynchronized version
// of Cas that relaxes ordering constraints. Other threads
// may observe operations that occur after this operation to
// precede it, but no operation that precedes it
// on this thread can be observed to occur after it.
// It reports whether the swap ran.
//
// WARNING: Use sparingly and with great care.
//
//go:nosplit
func (u *Uint32) CompareAndSwapRelease(old, new uint32) bool {
return CasRel(&u.value, old, new)
}
// Swap replaces u's value with new, returning
// u's value before the replacement.
//
//go:nosplit
func (u *Uint32) Swap(value uint32) uint32 {
return Xchg(&u.value, value)
}
// And takes value and performs a bit-wise
// "and" operation with the value of u, storing
// the result into u.
//
// The full process is performed atomically.
//
//go:nosplit
func (u *Uint32) And(value uint32) {
And(&u.value, value)
}
// Or takes value and performs a bit-wise
// "or" operation with the value of u, storing
// the result into u.
//
// The full process is performed atomically.
//
//go:nosplit
func (u *Uint32) Or(value uint32) {
Or(&u.value, value)
}
// Add adds delta to u atomically, returning
// the new updated value.
//
// This operation wraps around in the usual
// two's-complement way.
//
//go:nosplit
func (u *Uint32) Add(delta int32) uint32 {
return Xadd(&u.value, delta)
}
// Uint64 is an atomically accessed uint64 value.
//
// 8-byte aligned on all platforms, unlike a regular uint64.
//
// A Uint64 must not be copied.
type Uint64 struct {
noCopy noCopy
_ align64
value uint64
}
// Load accesses and returns the value atomically.
//
//go:nosplit
func (u *Uint64) Load() uint64 {
return Load64(&u.value)
}
// Store updates the value atomically.
//
//go:nosplit
func (u *Uint64) Store(value uint64) {
Store64(&u.value, value)
}
// CompareAndSwap atomically compares u's value with old,
// and if they're equal, swaps u's value with new.
// It reports whether the swap ran.
//
//go:nosplit
func (u *Uint64) CompareAndSwap(old, new uint64) bool {
return Cas64(&u.value, old, new)
}
// Swap replaces u's value with new, returning
// u's value before the replacement.
//
//go:nosplit
func (u *Uint64) Swap(value uint64) uint64 {
return Xchg64(&u.value, value)
}
// Add adds delta to u atomically, returning
// the new updated value.
//
// This operation wraps around in the usual
// two's-complement way.
//
//go:nosplit
func (u *Uint64) Add(delta int64) uint64 {
return Xadd64(&u.value, delta)
}
// Uintptr is an atomically accessed uintptr value.
//
// A Uintptr must not be copied.
type Uintptr struct {
noCopy noCopy
value uintptr
}
// Load accesses and returns the value atomically.
//
//go:nosplit
func (u *Uintptr) Load() uintptr {
return Loaduintptr(&u.value)
}
// LoadAcquire is a partially unsynchronized version
// of Load that relaxes ordering constraints. Other threads
// may observe operations that precede this operation to
// occur after it, but no operation that occurs after it
// on this thread can be observed to occur before it.
//
// WARNING: Use sparingly and with great care.
//
//go:nosplit
func (u *Uintptr) LoadAcquire() uintptr {
return LoadAcquintptr(&u.value)
}
// Store updates the value atomically.
//
//go:nosplit
func (u *Uintptr) Store(value uintptr) {
Storeuintptr(&u.value, value)
}
// StoreRelease is a partially unsynchronized version
// of Store that relaxes ordering constraints. Other threads
// may observe operations that occur after this operation to
// precede it, but no operation that precedes it
// on this thread can be observed to occur after it.
//
// WARNING: Use sparingly and with great care.
//
//go:nosplit
func (u *Uintptr) StoreRelease(value uintptr) {
StoreReluintptr(&u.value, value)
}
// CompareAndSwap atomically compares u's value with old,
// and if they're equal, swaps u's value with new.
// It reports whether the swap ran.
//
//go:nosplit
func (u *Uintptr) CompareAndSwap(old, new uintptr) bool {
return Casuintptr(&u.value, old, new)
}
// Swap replaces u's value with new, returning
// u's value before the replacement.
//
//go:nosplit
func (u *Uintptr) Swap(value uintptr) uintptr {
return Xchguintptr(&u.value, value)
}
// Add adds delta to u atomically, returning
// the new updated value.
//
// This operation wraps around in the usual
// two's-complement way.
//
//go:nosplit
func (u *Uintptr) Add(delta uintptr) uintptr {
return Xadduintptr(&u.value, delta)
}
// Float64 is an atomically accessed float64 value.
//
// 8-byte aligned on all platforms, unlike a regular float64.
//
// A Float64 must not be copied.
type Float64 struct {
// Inherits noCopy and align64 from Uint64.
u Uint64
}
// Load accesses and returns the value atomically.
//
//go:nosplit
func (f *Float64) Load() float64 {
r := f.u.Load()
return *(*float64)(unsafe.Pointer(&r))
}
// Store updates the value atomically.
//
//go:nosplit
func (f *Float64) Store(value float64) {
f.u.Store(*(*uint64)(unsafe.Pointer(&value)))
}
// UnsafePointer is an atomically accessed unsafe.Pointer value.
//
// Note that because of the atomicity guarantees, stores to values
// of this type never trigger a write barrier, and the relevant
// methods are suffixed with "NoWB" to indicate that explicitly.
// As a result, this type should be used carefully, and sparingly,
// mostly with values that do not live in the Go heap anyway.
//
// An UnsafePointer must not be copied.
type UnsafePointer struct {
noCopy noCopy
value unsafe.Pointer
}
// Load accesses and returns the value atomically.
//
//go:nosplit
func (u *UnsafePointer) Load() unsafe.Pointer {
return Loadp(unsafe.Pointer(&u.value))
}
// StoreNoWB updates the value atomically.
//
// WARNING: As the name implies this operation does *not*
// perform a write barrier on value, and so this operation may
// hide pointers from the GC. Use with care and sparingly.
// It is safe to use with values not found in the Go heap.
// Prefer Store instead.
//
//go:nosplit
func (u *UnsafePointer) StoreNoWB(value unsafe.Pointer) {
StorepNoWB(unsafe.Pointer(&u.value), value)
}
// Store updates the value atomically.
func (u *UnsafePointer) Store(value unsafe.Pointer) {
storePointer(&u.value, value)
}
// provided by runtime
//go:linkname storePointer
func storePointer(ptr *unsafe.Pointer, new unsafe.Pointer)
// CompareAndSwapNoWB atomically (with respect to other methods)
// compares u's value with old, and if they're equal,
// swaps u's value with new.
// It reports whether the swap ran.
//
// WARNING: As the name implies this operation does *not*
// perform a write barrier on value, and so this operation may
// hide pointers from the GC. Use with care and sparingly.
// It is safe to use with values not found in the Go heap.
// Prefer CompareAndSwap instead.
//
//go:nosplit
func (u *UnsafePointer) CompareAndSwapNoWB(old, new unsafe.Pointer) bool {
return Casp1(&u.value, old, new)
}
// CompareAndSwap atomically compares u's value with old,
// and if they're equal, swaps u's value with new.
// It reports whether the swap ran.
func (u *UnsafePointer) CompareAndSwap(old, new unsafe.Pointer) bool {
return casPointer(&u.value, old, new)
}
func casPointer(ptr *unsafe.Pointer, old, new unsafe.Pointer) bool
// Pointer is an atomic pointer of type *T.
type Pointer[T any] struct {
u UnsafePointer
}
// Load accesses and returns the value atomically.
//
//go:nosplit
func (p *Pointer[T]) Load() *T {
return (*T)(p.u.Load())
}
// StoreNoWB updates the value atomically.
//
// WARNING: As the name implies this operation does *not*
// perform a write barrier on value, and so this operation may
// hide pointers from the GC. Use with care and sparingly.
// It is safe to use with values not found in the Go heap.
// Prefer Store instead.
//
//go:nosplit
func (p *Pointer[T]) StoreNoWB(value *T) {
p.u.StoreNoWB(unsafe.Pointer(value))
}
// Store updates the value atomically.
//go:nosplit
func (p *Pointer[T]) Store(value *T) {
p.u.Store(unsafe.Pointer(value))
}
// CompareAndSwapNoWB atomically (with respect to other methods)
// compares u's value with old, and if they're equal,
// swaps u's value with new.
// It reports whether the swap ran.
//
// WARNING: As the name implies this operation does *not*
// perform a write barrier on value, and so this operation may
// hide pointers from the GC. Use with care and sparingly.
// It is safe to use with values not found in the Go heap.
// Prefer CompareAndSwap instead.
//
//go:nosplit
func (p *Pointer[T]) CompareAndSwapNoWB(old, new *T) bool {
return p.u.CompareAndSwapNoWB(unsafe.Pointer(old), unsafe.Pointer(new))
}
// CompareAndSwap atomically (with respect to other methods)
// compares u's value with old, and if they're equal,
// swaps u's value with new.
// It reports whether the swap ran.
func (p *Pointer[T]) CompareAndSwap(old, new *T) bool {
return p.u.CompareAndSwap(unsafe.Pointer(old), unsafe.Pointer(new))
}
// noCopy may be embedded into structs which must not be copied
// after the first use.
//
// See https://golang.org/issues/8005#issuecomment-190753527
// for details.
type noCopy struct{}
// Lock is a no-op used by -copylocks checker from `go vet`.
func (*noCopy) Lock() {}
func (*noCopy) Unlock() {}
// align64 may be added to structs that must be 64-bit aligned.
// This struct is recognized by a special case in the compiler
// and will not work if copied to any other package.
type align64 struct{}
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build amd64 || arm64 || loong64 || mips64 || mips64le || ppc64 || ppc64le || riscv64 || s390x || wasm
package atomic
// LoadAcquire is a partially unsynchronized version
// of Load that relaxes ordering constraints. Other threads
// may observe operations that precede this operation to
// occur after it, but no operation that occurs after it
// on this thread can be observed to occur before it.
//
// WARNING: Use sparingly and with great care.
//
//go:nosplit
func (u *Uint64) LoadAcquire() uint64 {
return LoadAcq64(&u.value)
}
// StoreRelease is a partially unsynchronized version
// of Store that relaxes ordering constraints. Other threads
// may observe operations that occur after this operation to
// precede it, but no operation that precedes it
// on this thread can be observed to occur after it.
//
// WARNING: Use sparingly and with great care.
//
//go:nosplit
func (u *Uint64) StoreRelease(value uint64) {
StoreRel64(&u.value, value)
}
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package atomic
func panicUnaligned() {
panic("unaligned 64-bit atomic operation")
}
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math
import "internal/goarch"
const MaxUintptr = ^uintptr(0)
// MulUintptr returns a * b and whether the multiplication overflowed.
// On supported platforms this is an intrinsic lowered by the compiler.
func MulUintptr(a, b uintptr) (uintptr, bool) {
if a|b < 1<<(4*goarch.PtrSize) || a == 0 {
return a * b, false
}
overflow := b > MaxUintptr/a
return a * b, overflow
}
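// Worked example (illustrative, 64-bit): 1<<(4*goarch.PtrSize) is 1<<32, so
// when both operands fit in 32 bits the product fits in 64 bits and cannot
// overflow. Thus MulUintptr(1<<20, 1<<20) returns (1<<40, false), while
// MulUintptr(1<<40, 1<<40) returns (0, true), since 2^80 wraps to 0 mod 2^64
// and 1<<40 > MaxUintptr/(1<<40).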
// Mul64 returns the 128-bit product of x and y: (hi, lo) = x * y
// with the product bits' upper half returned in hi and the lower
// half returned in lo.
// This is a copy from math/bits.Mul64
// On supported platforms this is an intrinsic lowered by the compiler.
func Mul64(x, y uint64) (hi, lo uint64) {
const mask32 = 1<<32 - 1
x0 := x & mask32
x1 := x >> 32
y0 := y & mask32
y1 := y >> 32
w0 := x0 * y0
t := x1*y0 + w0>>32
w1 := t & mask32
w2 := t >> 32
w1 += x0 * y1
hi = x1*y1 + w2 + w1>>32
lo = x * y
return
}
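// Worked example (illustrative): for x = 1<<32 + 1 and y = 1<<32 + 2 the
// full product is 2^64 + 3*2^32 + 2, so Mul64 returns hi = 1 and
// lo = 0x0000000300000002.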
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !386
// TODO finish intrinsifying 386, deadcode the assembly, remove build tags, merge w/ intrinsics_common
package sys
// Copied from math/bits to avoid dependence.
var deBruijn32tab = [32]byte{
0, 1, 28, 2, 29, 14, 24, 3, 30, 22, 20, 15, 25, 17, 4, 8,
31, 27, 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9,
}
const deBruijn32 = 0x077CB531
var deBruijn64tab = [64]byte{
0, 1, 56, 2, 57, 49, 28, 3, 61, 58, 42, 50, 38, 29, 17, 4,
62, 47, 59, 36, 45, 43, 51, 22, 53, 39, 33, 30, 24, 18, 12, 5,
63, 55, 48, 27, 60, 41, 37, 16, 46, 35, 44, 21, 52, 32, 23, 11,
54, 26, 40, 15, 34, 20, 31, 10, 25, 14, 19, 9, 13, 8, 7, 6,
}
const deBruijn64 = 0x03f79d71b4ca8b09
const ntz8tab = "" +
"\x08\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" +
"\x04\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" +
"\x05\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" +
"\x04\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" +
"\x06\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" +
"\x04\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" +
"\x05\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" +
"\x04\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" +
"\x07\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" +
"\x04\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" +
"\x05\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" +
"\x04\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" +
"\x06\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" +
"\x04\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" +
"\x05\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00" +
"\x04\x00\x01\x00\x02\x00\x01\x00\x03\x00\x01\x00\x02\x00\x01\x00"
// TrailingZeros32 returns the number of trailing zero bits in x; the result is 32 for x == 0.
func TrailingZeros32(x uint32) int {
if x == 0 {
return 32
}
// see comment in TrailingZeros64
return int(deBruijn32tab[(x&-x)*deBruijn32>>(32-5)])
}
// TrailingZeros64 returns the number of trailing zero bits in x; the result is 64 for x == 0.
func TrailingZeros64(x uint64) int {
if x == 0 {
return 64
}
// If popcount is fast, replace code below with return popcount(^x & (x - 1)).
//
// x & -x leaves only the right-most bit set in the word. Let k be the
// index of that bit. Since only a single bit is set, the value is two
// to the power of k. Multiplying by a power of two is equivalent to
// left shifting, in this case by k bits. The de Bruijn (64-bit) constant
// is such that all six-bit consecutive substrings are distinct.
// Therefore, if we have a left shifted version of this constant we can
// find by how many bits it was shifted by looking at which six bit
// substring ended up at the top of the word.
// (Knuth, volume 4, section 7.3.1)
return int(deBruijn64tab[(x&-x)*deBruijn64>>(64-6)])
}
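// Worked example (illustrative): for x = 8 (only bit 3 set), x & -x = 8, so
// the multiplication shifts deBruijn64 left by 3. The top six bits of the
// product, (8*deBruijn64)>>58 = 7, select deBruijn64tab[7] = 3 — the number
// of trailing zeros.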
// TrailingZeros8 returns the number of trailing zero bits in x; the result is 8 for x == 0.
func TrailingZeros8(x uint8) int {
return int(ntz8tab[x])
}
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sys
// Copied from math/bits to avoid dependence.
const len8tab = "" +
"\x00\x01\x02\x02\x03\x03\x03\x03\x04\x04\x04\x04\x04\x04\x04\x04" +
"\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05\x05" +
"\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06" +
"\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06\x06" +
"\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07" +
"\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07" +
"\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07" +
"\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07\x07" +
"\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" +
"\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" +
"\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" +
"\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" +
"\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" +
"\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" +
"\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08" +
"\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08\x08"
// Len64 returns the minimum number of bits required to represent x; the result is 0 for x == 0.
//
// nosplit because this is used in src/runtime/histogram.go, which may run in sensitive contexts.
//
//go:nosplit
func Len64(x uint64) (n int) {
if x >= 1<<32 {
x >>= 32
n = 32
}
if x >= 1<<16 {
x >>= 16
n += 16
}
if x >= 1<<8 {
x >>= 8
n += 8
}
return n + int(len8tab[x])
}
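// Worked example (illustrative): for x = 1<<40, the 32-shift applies
// (n = 32, x becomes 1<<8), the 16-shift does not, the 8-shift applies
// (n = 40, x becomes 1), and len8tab[1] = 1 gives Len64(1<<40) = 41.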
// --- OnesCount ---
const m0 = 0x5555555555555555 // 01010101 ...
const m1 = 0x3333333333333333 // 00110011 ...
const m2 = 0x0f0f0f0f0f0f0f0f // 00001111 ...
// OnesCount64 returns the number of one bits ("population count") in x.
func OnesCount64(x uint64) int {
// Implementation: Parallel summing of adjacent bits.
// See "Hacker's Delight", Chap. 5: Counting Bits.
// The following pattern shows the general approach:
//
// x = x>>1&(m0&m) + x&(m0&m)
// x = x>>2&(m1&m) + x&(m1&m)
// x = x>>4&(m2&m) + x&(m2&m)
// x = x>>8&(m3&m) + x&(m3&m)
// x = x>>16&(m4&m) + x&(m4&m)
// x = x>>32&(m5&m) + x&(m5&m)
// return int(x)
//
// Masking (& operations) can be left away when there's no
// danger that a field's sum will carry over into the next
// field: Since the result cannot be > 64, 8 bits is enough
// and we can ignore the masks for the shifts by 8 and up.
// Per "Hacker's Delight", the first line can be simplified
// more, but it saves at best one instruction, so we leave
// it alone for clarity.
const m = 1<<64 - 1
x = x>>1&(m0&m) + x&(m0&m)
x = x>>2&(m1&m) + x&(m1&m)
x = (x>>4 + x) & (m2 & m)
x += x >> 8
x += x >> 16
x += x >> 32
return int(x) & (1<<7 - 1)
}
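// Worked example (illustrative): for x = 0xF0F0F0F0F0F0F0F0 each byte holds
// four one bits. After the masked shift-add stages every byte of x holds its
// own popcount (4), the byte-summing shifts accumulate 8*4 = 32 into the low
// byte, and the final mask extracts 32.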
// LeadingZeros64 returns the number of leading zero bits in x; the result is 64 for x == 0.
func LeadingZeros64(x uint64) int { return 64 - Len64(x) }
// LeadingZeros8 returns the number of leading zero bits in x; the result is 8 for x == 0.
func LeadingZeros8(x uint8) int { return 8 - Len8(x) }
// Len8 returns the minimum number of bits required to represent x; the result is 0 for x == 0.
func Len8(x uint8) int {
return int(len8tab[x])
}
// Bswap64 returns its input with byte order reversed
// 0x0102030405060708 -> 0x0807060504030201
func Bswap64(x uint64) uint64 {
c8 := uint64(0x00ff00ff00ff00ff)
a := x >> 8 & c8
b := (x & c8) << 8
x = a | b
c16 := uint64(0x0000ffff0000ffff)
a = x >> 16 & c16
b = (x & c16) << 16
x = a | b
c32 := uint64(0x00000000ffffffff)
a = x >> 32 & c32
b = (x & c32) << 32
x = a | b
return x
}
// Bswap32 returns its input with byte order reversed
// 0x01020304 -> 0x04030201
func Bswap32(x uint32) uint32 {
c8 := uint32(0x00ff00ff)
a := x >> 8 & c8
b := (x & c8) << 8
x = a | b
c16 := uint32(0x0000ffff)
a = x >> 16 & c16
b = (x & c16) << 16
x = a | b
return x
}
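// bswapRef is an illustrative, dependency-free reference for the parallel
// swaps above: it moves one byte at a time, where Bswap64 swaps
// ever-larger groups in O(log n) steps. (A cross-checking sketch, not
// part of the original source.)
func bswapRef(x uint64) (r uint64) {
	for i := 0; i < 8; i++ {
		r = r<<8 | x&0xff
		x >>= 8
	}
	return r
}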
// Prefetch prefetches data from memory addr to cache
//
// AMD64: Produce PREFETCHT0 instruction
//
// ARM64: Produce PRFM instruction with PLDL1KEEP option
func Prefetch(addr uintptr) {}
// PrefetchStreamed prefetches data from memory addr, with a hint that this data is being streamed.
// That is, it is likely to be accessed very soon, but only once. If possible, this will avoid polluting the cache.
//
// AMD64: Produce PREFETCHNTA instruction
//
// ARM64: Produce PRFM instruction with PLDL1STRM option
func PrefetchStreamed(addr uintptr) {}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package syscall provides the syscall primitives required for the runtime.
package syscall
import (
"unsafe"
)
// TODO(https://go.dev/issue/51087): This package is incomplete and currently
// only contains very minimal support for Linux.
// Syscall6 calls system call number 'num' with arguments a1-6.
func Syscall6(num, a1, a2, a3, a4, a5, a6 uintptr) (r1, r2, errno uintptr)
// syscall_RawSyscall6 is a push linkname to export Syscall6 as
// syscall.RawSyscall6.
//
// //go:uintptrkeepalive because the uintptr arguments may be converted from
// pointers that must be kept alive in the caller (this is implied for Syscall6
// since it has no body).
//
// //go:nosplit because stack copying does not account for uintptrkeepalive, so
// the stack must not grow. Stack copying cannot blindly assume that all
// uintptr arguments are pointers, because some values may look like pointers,
// but not really be pointers, and adjusting their value would break the call.
//
// This is a separate wrapper because we can't export one function as two
// names. The assembly implementations name themselves Syscall6, so they would
// not be affected by a linkname.
//
//go:uintptrkeepalive
//go:nosplit
//go:linkname syscall_RawSyscall6 syscall.RawSyscall6
func syscall_RawSyscall6(num, a1, a2, a3, a4, a5, a6 uintptr) (r1, r2, errno uintptr) {
return Syscall6(num, a1, a2, a3, a4, a5, a6)
}
func EpollCreate1(flags int32) (fd int32, errno uintptr) {
r1, _, e := Syscall6(SYS_EPOLL_CREATE1, uintptr(flags), 0, 0, 0, 0, 0)
return int32(r1), e
}
var _zero uintptr
func EpollWait(epfd int32, events []EpollEvent, maxev, waitms int32) (n int32, errno uintptr) {
var ev unsafe.Pointer
if len(events) > 0 {
ev = unsafe.Pointer(&events[0])
} else {
ev = unsafe.Pointer(&_zero)
}
r1, _, e := Syscall6(SYS_EPOLL_PWAIT, uintptr(epfd), uintptr(ev), uintptr(maxev), uintptr(waitms), 0, 0)
return int32(r1), e
}
func EpollCtl(epfd, op, fd int32, event *EpollEvent) (errno uintptr) {
_, _, e := Syscall6(SYS_EPOLL_CTL, uintptr(epfd), uintptr(op), uintptr(fd), uintptr(unsafe.Pointer(event)), 0, 0)
return e
}
func CloseOnExec(fd int32) {
Syscall6(SYS_FCNTL, uintptr(fd), F_SETFD, FD_CLOEXEC, 0, 0, 0)
}
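// Outside the runtime, the same kernel interface is reachable through the
// public syscall package on Linux. A minimal illustrative sketch (error
// handling elided; note the public API returns error values rather than
// the raw errno words used by the wrappers above):
//
//	package main
//
//	import "syscall"
//
//	func main() {
//		epfd, _ := syscall.EpollCreate1(0)
//		defer syscall.Close(epfd)
//
//		var fds [2]int
//		syscall.Pipe(fds[:]) // something to watch
//		ev := syscall.EpollEvent{Events: syscall.EPOLLIN, Fd: int32(fds[0])}
//		syscall.EpollCtl(epfd, syscall.EPOLL_CTL_ADD, fds[0], &ev)
//
//		events := make([]syscall.EpollEvent, 8)
//		n, _ := syscall.EpollWait(epfd, events, 1000) // 1s timeout
//		_ = n
//	}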
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Lock-free stack.
package runtime
import (
"runtime/internal/atomic"
"unsafe"
)
// lfstack is the head of a lock-free stack.
//
// The zero value of lfstack is an empty list.
//
// This stack is intrusive. Nodes must embed lfnode as the first field.
//
// The stack does not keep GC-visible pointers to nodes, so the caller
// must ensure the nodes are allocated outside the Go heap.
type lfstack uint64
func (head *lfstack) push(node *lfnode) {
node.pushcnt++
new := lfstackPack(node, node.pushcnt)
if node1 := lfstackUnpack(new); node1 != node {
print("runtime: lfstack.push invalid packing: node=", node, " cnt=", hex(node.pushcnt), " packed=", hex(new), " -> node=", node1, "\n")
throw("lfstack.push")
}
for {
old := atomic.Load64((*uint64)(head))
node.next = old
if atomic.Cas64((*uint64)(head), old, new) {
break
}
}
}
func (head *lfstack) pop() unsafe.Pointer {
for {
old := atomic.Load64((*uint64)(head))
if old == 0 {
return nil
}
node := lfstackUnpack(old)
next := atomic.Load64(&node.next)
if atomic.Cas64((*uint64)(head), old, next) {
return unsafe.Pointer(node)
}
}
}
func (head *lfstack) empty() bool {
return atomic.Load64((*uint64)(head)) == 0
}
// lfnodeValidate panics if node is not a valid address for use with
// lfstack.push. This only needs to be called when node is allocated.
func lfnodeValidate(node *lfnode) {
if base, _, _ := findObject(uintptr(unsafe.Pointer(node)), 0, 0); base != 0 {
throw("lfstack node allocated from the heap")
}
if lfstackUnpack(lfstackPack(node, ^uintptr(0))) != node {
printlock()
println("runtime: bad lfnode address", hex(uintptr(unsafe.Pointer(node))))
throw("bad lfnode address")
}
}
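// The same Treiber-stack algorithm can be written in ordinary Go with
// sync/atomic. An illustrative sketch: it omits the pushcnt ABA counter,
// which is safe in user code because the garbage collector will not
// recycle a node while another goroutine still holds a pointer to it.
//
//	import "sync/atomic"
//
//	type node struct {
//		next atomic.Pointer[node]
//	}
//
//	type stack struct {
//		head atomic.Pointer[node]
//	}
//
//	func (s *stack) push(n *node) {
//		for {
//			old := s.head.Load()
//			n.next.Store(old)
//			if s.head.CompareAndSwap(old, n) {
//				return
//			}
//		}
//	}
//
//	func (s *stack) pop() *node {
//		for {
//			old := s.head.Load()
//			if old == nil {
//				return nil
//			}
//			if s.head.CompareAndSwap(old, old.next.Load()) {
//				return old
//			}
//		}
//	}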
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build amd64 || arm64 || loong64 || mips64 || mips64le || ppc64 || ppc64le || riscv64 || s390x || wasm
package runtime
import "unsafe"
const (
// addrBits is the number of bits needed to represent a virtual address.
//
// See heapAddrBits for a table of address space sizes on
// various architectures. 48 bits is enough for all
// architectures except s390x.
//
// On AMD64, virtual addresses are 48-bit (or 57-bit) numbers sign extended to 64.
// We shift the address left 16 to eliminate the sign extended part and make
// room in the bottom for the count.
//
// On s390x, virtual addresses are 64-bit. There's not much we
// can do about this, so we just hope that the kernel doesn't
// get to really high addresses and panic if it does.
addrBits = 48
// In addition to the 16 bits taken from the top, we can take 3 from the
// bottom, because node must be pointer-aligned, giving a total of 19 bits
// of count.
cntBits = 64 - addrBits + 3
// On AIX, 64-bit addresses are split into 36-bit segment number and 28-bit
// offset in segment. Segment numbers in the range 0x0A0000000-0x0AFFFFFFF (LSA)
// are available for mmap.
// We assume all lfnode addresses are from memory allocated with mmap.
// We use one bit to distinguish between the two ranges.
aixAddrBits = 57
aixCntBits = 64 - aixAddrBits + 3
// riscv64 SV57 mode gives 56 bits of userspace VA.
// lfstack code supports it, but broader support for SV57 mode is incomplete,
// and there may be other issues (see #54104).
riscv64AddrBits = 56
riscv64CntBits = 64 - riscv64AddrBits + 3
)
func lfstackPack(node *lfnode, cnt uintptr) uint64 {
if GOARCH == "ppc64" && GOOS == "aix" {
return uint64(uintptr(unsafe.Pointer(node)))<<(64-aixAddrBits) | uint64(cnt&(1<<aixCntBits-1))
}
if GOARCH == "riscv64" {
return uint64(uintptr(unsafe.Pointer(node)))<<(64-riscv64AddrBits) | uint64(cnt&(1<<riscv64CntBits-1))
}
return uint64(uintptr(unsafe.Pointer(node)))<<(64-addrBits) | uint64(cnt&(1<<cntBits-1))
}
func lfstackUnpack(val uint64) *lfnode {
if GOARCH == "amd64" {
// amd64 systems can place the stack above the VA hole, so we need to sign extend
// val before unpacking.
return (*lfnode)(unsafe.Pointer(uintptr(int64(val) >> cntBits << 3)))
}
if GOARCH == "ppc64" && GOOS == "aix" {
return (*lfnode)(unsafe.Pointer(uintptr((val >> aixCntBits << 3) | 0xa<<56)))
}
if GOARCH == "riscv64" {
return (*lfnode)(unsafe.Pointer(uintptr(val >> riscv64CntBits << 3)))
}
return (*lfnode)(unsafe.Pointer(uintptr(val >> cntBits << 3)))
}
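// Worked example (illustrative) for the common addrBits=48 case, with a
// made-up 8-byte-aligned address:
//
//	addr := uintptr(0x00c000001040)
//	cnt := uintptr(5)
//	packed := uint64(addr)<<16 | uint64(cnt&(1<<19-1))
//	// packed>>19 drops the count and the three zero alignment bits;
//	// <<3 restores the alignment, recovering addr exactly.
//	back := uintptr(packed >> 19 << 3) // back == addr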
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build dragonfly || freebsd || linux
package runtime
import (
"runtime/internal/atomic"
"unsafe"
)
// This implementation depends on OS-specific implementations of
//
// futexsleep(addr *uint32, val uint32, ns int64)
// Atomically,
// if *addr == val { sleep }
// Might be woken up spuriously; that's allowed.
// Don't sleep longer than ns; ns < 0 means forever.
//
// futexwakeup(addr *uint32, cnt uint32)
// If any procs are sleeping on addr, wake up at most cnt.
const (
mutex_unlocked = 0
mutex_locked = 1
mutex_sleeping = 2
active_spin = 4
active_spin_cnt = 30
passive_spin = 1
)
// Possible lock states are mutex_unlocked, mutex_locked and mutex_sleeping.
// mutex_sleeping means that there is presumably at least one sleeping thread.
// Note that there can be spinning threads during all states - they do not
// affect mutex's state.
// We use the uintptr mutex.key and note.key as a uint32.
//
//go:nosplit
func key32(p *uintptr) *uint32 {
return (*uint32)(unsafe.Pointer(p))
}
func lock(l *mutex) {
lockWithRank(l, getLockRank(l))
}
func lock2(l *mutex) {
gp := getg()
if gp.m.locks < 0 {
throw("runtime·lock: lock count")
}
gp.m.locks++
// Speculative grab for lock.
v := atomic.Xchg(key32(&l.key), mutex_locked)
if v == mutex_unlocked {
return
}
// wait is either MUTEX_LOCKED or MUTEX_SLEEPING
// depending on whether there is a thread sleeping
// on this mutex. If we ever change l->key from
// MUTEX_SLEEPING to some other value, we must be
// careful to change it back to MUTEX_SLEEPING before
// returning, to ensure that the sleeping thread gets
// its wakeup call.
wait := v
// On uniprocessors, no point spinning.
// On multiprocessors, spin for ACTIVE_SPIN attempts.
spin := 0
if ncpu > 1 {
spin = active_spin
}
for {
// Try for lock, spinning.
for i := 0; i < spin; i++ {
for l.key == mutex_unlocked {
if atomic.Cas(key32(&l.key), mutex_unlocked, wait) {
return
}
}
procyield(active_spin_cnt)
}
// Try for lock, rescheduling.
for i := 0; i < passive_spin; i++ {
for l.key == mutex_unlocked {
if atomic.Cas(key32(&l.key), mutex_unlocked, wait) {
return
}
}
osyield()
}
// Sleep.
v = atomic.Xchg(key32(&l.key), mutex_sleeping)
if v == mutex_unlocked {
return
}
wait = mutex_sleeping
futexsleep(key32(&l.key), mutex_sleeping, -1)
}
}
func unlock(l *mutex) {
unlockWithRank(l)
}
func unlock2(l *mutex) {
v := atomic.Xchg(key32(&l.key), mutex_unlocked)
if v == mutex_unlocked {
throw("unlock of unlocked lock")
}
if v == mutex_sleeping {
futexwakeup(key32(&l.key), 1)
}
gp := getg()
gp.m.locks--
if gp.m.locks < 0 {
throw("runtime·unlock: lock count")
}
if gp.m.locks == 0 && gp.preempt { // restore the preemption request in case we've cleared it in newstack
gp.stackguard0 = stackPreempt
}
}
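// The same unlocked/locked/sleeping protocol can be approximated in
// portable Go, with a one-slot channel standing in for the futex wakeup
// token. An illustrative sketch only: the real futex waits on an arbitrary
// word with no per-lock kernel object and atomically rechecks the word
// before sleeping, which the channel can only imitate.
//
//	import "sync/atomic"
//
//	type fmutex struct {
//		state atomic.Uint32 // 0 unlocked, 1 locked, 2 locked with waiters
//		park  chan struct{} // make(chan struct{}, 1); futex stand-in
//	}
//
//	func (m *fmutex) lock() {
//		if m.state.CompareAndSwap(0, 1) {
//			return // uncontended fast path, like the speculative Xchg above
//		}
//		for m.state.Swap(2) != 0 {
//			<-m.park // futexsleep analogue
//		}
//	}
//
//	func (m *fmutex) unlock() {
//		if m.state.Swap(0) == 2 {
//			select {
//			case m.park <- struct{}{}: // futexwakeup analogue
//			default: // a wakeup token is already pending
//			}
//		}
//	}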
// One-time notifications.
func noteclear(n *note) {
n.key = 0
}
func notewakeup(n *note) {
old := atomic.Xchg(key32(&n.key), 1)
if old != 0 {
print("notewakeup - double wakeup (", old, ")\n")
throw("notewakeup - double wakeup")
}
futexwakeup(key32(&n.key), 1)
}
func notesleep(n *note) {
gp := getg()
if gp != gp.m.g0 {
throw("notesleep not on g0")
}
ns := int64(-1)
if *cgo_yield != nil {
// Sleep for an arbitrary-but-moderate interval to poll libc interceptors.
ns = 10e6
}
for atomic.Load(key32(&n.key)) == 0 {
gp.m.blocked = true
futexsleep(key32(&n.key), 0, ns)
if *cgo_yield != nil {
asmcgocall(*cgo_yield, nil)
}
gp.m.blocked = false
}
}
// May run with m.p==nil if called from notetsleep, so write barriers
// are not allowed.
//
//go:nosplit
//go:nowritebarrier
func notetsleep_internal(n *note, ns int64) bool {
gp := getg()
if ns < 0 {
if *cgo_yield != nil {
// Sleep for an arbitrary-but-moderate interval to poll libc interceptors.
ns = 10e6
}
for atomic.Load(key32(&n.key)) == 0 {
gp.m.blocked = true
futexsleep(key32(&n.key), 0, ns)
if *cgo_yield != nil {
asmcgocall(*cgo_yield, nil)
}
gp.m.blocked = false
}
return true
}
if atomic.Load(key32(&n.key)) != 0 {
return true
}
deadline := nanotime() + ns
for {
if *cgo_yield != nil && ns > 10e6 {
ns = 10e6
}
gp.m.blocked = true
futexsleep(key32(&n.key), 0, ns)
if *cgo_yield != nil {
asmcgocall(*cgo_yield, nil)
}
gp.m.blocked = false
if atomic.Load(key32(&n.key)) != 0 {
break
}
now := nanotime()
if now >= deadline {
break
}
ns = deadline - now
}
return atomic.Load(key32(&n.key)) != 0
}
func notetsleep(n *note, ns int64) bool {
gp := getg()
if gp != gp.m.g0 && gp.m.preemptoff != "" {
throw("notetsleep not on g0")
}
return notetsleep_internal(n, ns)
}
// same as runtime·notetsleep, but called on user g (not g0)
// calls only nosplit functions between entersyscallblock/exitsyscall.
func notetsleepg(n *note, ns int64) bool {
gp := getg()
if gp == gp.m.g0 {
throw("notetsleepg on g0")
}
entersyscallblock()
ok := notetsleep_internal(n, ns)
exitsyscall()
return ok
}
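// In ordinary Go, the same one-time notification is usually expressed with
// a channel that is closed exactly once: close covers notewakeup, a bare
// receive covers notesleep, and a select with a timer covers notetsleep.
// An illustrative sketch (names are made up for the example):
//
//	import "time"
//
//	type event struct{ done chan struct{} }
//
//	func newEvent() *event { return &event{done: make(chan struct{})} }
//
//	// wakeup panics if called twice, mirroring notewakeup's throw.
//	func (e *event) wakeup() { close(e.done) }
//
//	func (e *event) sleep() { <-e.done }
//
//	func (e *event) tsleep(d time.Duration) bool {
//		select {
//		case <-e.done:
//			return true
//		case <-time.After(d):
//			return false
//		}
//	}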
func beforeIdle(int64, int64) (*g, bool) {
return nil, false
}
func checkTimeouts() {}
// Code generated by mklockrank.go; DO NOT EDIT.
package runtime
type lockRank int
// Constants representing the ranks of all non-leaf runtime locks, in rank order.
// Locks with lower rank must be taken before locks with higher rank,
// in addition to satisfying the partial order in lockPartialOrder.
// A few ranks allow self-cycles, which are specified in lockPartialOrder.
const (
lockRankUnknown lockRank = iota
lockRankSysmon
lockRankScavenge
lockRankForcegc
lockRankDefer
lockRankSweepWaiters
lockRankAssistQueue
lockRankSweep
lockRankPollDesc
lockRankCpuprof
lockRankSched
lockRankAllg
lockRankAllp
lockRankTimers
lockRankNetpollInit
lockRankHchan
lockRankNotifyList
lockRankSudog
lockRankRwmutexW
lockRankRwmutexR
lockRankRoot
lockRankItab
lockRankReflectOffs
lockRankUserArenaState
// TRACEGLOBAL
lockRankTraceBuf
lockRankTraceStrings
// MALLOC
lockRankFin
lockRankGcBitsArenas
lockRankMheapSpecial
lockRankMspanSpecial
lockRankSpanSetSpine
// MPROF
lockRankProfInsert
lockRankProfBlock
lockRankProfMemActive
lockRankProfMemFuture
// STACKGROW
lockRankGscan
lockRankStackpool
lockRankStackLarge
lockRankHchanLeaf
// WB
lockRankWbufSpans
lockRankMheap
lockRankGlobalAlloc
// TRACE
lockRankTrace
lockRankTraceStackTab
lockRankPanic
lockRankDeadlock
)
// lockRankLeafRank is the rank of lock that does not have a declared rank,
// and hence is a leaf lock.
const lockRankLeafRank lockRank = 1000
// lockNames gives the names associated with each of the above ranks.
var lockNames = []string{
lockRankSysmon: "sysmon",
lockRankScavenge: "scavenge",
lockRankForcegc: "forcegc",
lockRankDefer: "defer",
lockRankSweepWaiters: "sweepWaiters",
lockRankAssistQueue: "assistQueue",
lockRankSweep: "sweep",
lockRankPollDesc: "pollDesc",
lockRankCpuprof: "cpuprof",
lockRankSched: "sched",
lockRankAllg: "allg",
lockRankAllp: "allp",
lockRankTimers: "timers",
lockRankNetpollInit: "netpollInit",
lockRankHchan: "hchan",
lockRankNotifyList: "notifyList",
lockRankSudog: "sudog",
lockRankRwmutexW: "rwmutexW",
lockRankRwmutexR: "rwmutexR",
lockRankRoot: "root",
lockRankItab: "itab",
lockRankReflectOffs: "reflectOffs",
lockRankUserArenaState: "userArenaState",
lockRankTraceBuf: "traceBuf",
lockRankTraceStrings: "traceStrings",
lockRankFin: "fin",
lockRankGcBitsArenas: "gcBitsArenas",
lockRankMheapSpecial: "mheapSpecial",
lockRankMspanSpecial: "mspanSpecial",
lockRankSpanSetSpine: "spanSetSpine",
lockRankProfInsert: "profInsert",
lockRankProfBlock: "profBlock",
lockRankProfMemActive: "profMemActive",
lockRankProfMemFuture: "profMemFuture",
lockRankGscan: "gscan",
lockRankStackpool: "stackpool",
lockRankStackLarge: "stackLarge",
lockRankHchanLeaf: "hchanLeaf",
lockRankWbufSpans: "wbufSpans",
lockRankMheap: "mheap",
lockRankGlobalAlloc: "globalAlloc",
lockRankTrace: "trace",
lockRankTraceStackTab: "traceStackTab",
lockRankPanic: "panic",
lockRankDeadlock: "deadlock",
}
func (rank lockRank) String() string {
if rank == 0 {
return "UNKNOWN"
}
if rank == lockRankLeafRank {
return "LEAF"
}
if rank < 0 || int(rank) >= len(lockNames) {
return "BAD RANK"
}
return lockNames[rank]
}
// lockPartialOrder is the transitive closure of the lock rank graph.
// An entry for rank X lists all of the ranks that can already be held
// when rank X is acquired.
//
// Lock ranks that allow self-cycles list themselves.
var lockPartialOrder [][]lockRank = [][]lockRank{
lockRankSysmon: {},
lockRankScavenge: {lockRankSysmon},
lockRankForcegc: {lockRankSysmon},
lockRankDefer: {},
lockRankSweepWaiters: {},
lockRankAssistQueue: {},
lockRankSweep: {},
lockRankPollDesc: {},
lockRankCpuprof: {},
lockRankSched: {lockRankSysmon, lockRankScavenge, lockRankForcegc, lockRankSweepWaiters, lockRankAssistQueue, lockRankSweep, lockRankPollDesc, lockRankCpuprof},
lockRankAllg: {lockRankSysmon, lockRankScavenge, lockRankForcegc, lockRankSweepWaiters, lockRankAssistQueue, lockRankSweep, lockRankPollDesc, lockRankCpuprof, lockRankSched},
lockRankAllp: {lockRankSysmon, lockRankScavenge, lockRankForcegc, lockRankSweepWaiters, lockRankAssistQueue, lockRankSweep, lockRankPollDesc, lockRankCpuprof, lockRankSched},
lockRankTimers: {lockRankSysmon, lockRankScavenge, lockRankForcegc, lockRankSweepWaiters, lockRankAssistQueue, lockRankSweep, lockRankPollDesc, lockRankCpuprof, lockRankSched, lockRankAllp, lockRankTimers},
lockRankNetpollInit: {lockRankSysmon, lockRankScavenge, lockRankForcegc, lockRankSweepWaiters, lockRankAssistQueue, lockRankSweep, lockRankPollDesc, lockRankCpuprof, lockRankSched, lockRankAllp, lockRankTimers},
lockRankHchan: {lockRankSysmon, lockRankScavenge, lockRankSweep, lockRankHchan},
lockRankNotifyList: {},
lockRankSudog: {lockRankSysmon, lockRankScavenge, lockRankSweep, lockRankHchan, lockRankNotifyList},
lockRankRwmutexW: {},
lockRankRwmutexR: {lockRankSysmon, lockRankRwmutexW},
lockRankRoot: {},
lockRankItab: {},
lockRankReflectOffs: {lockRankItab},
lockRankUserArenaState: {},
lockRankTraceBuf: {lockRankSysmon, lockRankScavenge},
lockRankTraceStrings: {lockRankSysmon, lockRankScavenge, lockRankTraceBuf},
lockRankFin: {lockRankSysmon, lockRankScavenge, lockRankForcegc, lockRankSweepWaiters, lockRankAssistQueue, lockRankSweep, lockRankPollDesc, lockRankCpuprof, lockRankSched, lockRankAllg, lockRankAllp, lockRankTimers, lockRankHchan, lockRankNotifyList, lockRankItab, lockRankReflectOffs, lockRankUserArenaState, lockRankTraceBuf, lockRankTraceStrings},
lockRankGcBitsArenas: {lockRankSysmon, lockRankScavenge, lockRankForcegc, lockRankSweepWaiters, lockRankAssistQueue, lockRankSweep, lockRankPollDesc, lockRankCpuprof, lockRankSched, lockRankAllg, lockRankAllp, lockRankTimers, lockRankHchan, lockRankNotifyList, lockRankItab, lockRankReflectOffs, lockRankUserArenaState, lockRankTraceBuf, lockRankTraceStrings},
lockRankMheapSpecial: {lockRankSysmon, lockRankScavenge, lockRankForcegc, lockRankSweepWaiters, lockRankAssistQueue, lockRankSweep, lockRankPollDesc, lockRankCpuprof, lockRankSched, lockRankAllg, lockRankAllp, lockRankTimers, lockRankHchan, lockRankNotifyList, lockRankItab, lockRankReflectOffs, lockRankUserArenaState, lockRankTraceBuf, lockRankTraceStrings},
lockRankMspanSpecial: {lockRankSysmon, lockRankScavenge, lockRankForcegc, lockRankSweepWaiters, lockRankAssistQueue, lockRankSweep, lockRankPollDesc, lockRankCpuprof, lockRankSched, lockRankAllg, lockRankAllp, lockRankTimers, lockRankHchan, lockRankNotifyList, lockRankItab, lockRankReflectOffs, lockRankUserArenaState, lockRankTraceBuf, lockRankTraceStrings},
lockRankSpanSetSpine: {lockRankSysmon, lockRankScavenge, lockRankForcegc, lockRankSweepWaiters, lockRankAssistQueue, lockRankSweep, lockRankPollDesc, lockRankCpuprof, lockRankSched, lockRankAllg, lockRankAllp, lockRankTimers, lockRankHchan, lockRankNotifyList, lockRankItab, lockRankReflectOffs, lockRankUserArenaState, lockRankTraceBuf, lockRankTraceStrings},
lockRankProfInsert: {lockRankSysmon, lockRankScavenge, lockRankForcegc, lockRankSweepWaiters, lockRankAssistQueue, lockRankSweep, lockRankPollDesc, lockRankCpuprof, lockRankSched, lockRankAllg, lockRankAllp, lockRankTimers, lockRankHchan, lockRankNotifyList, lockRankItab, lockRankReflectOffs, lockRankUserArenaState, lockRankTraceBuf, lockRankTraceStrings},
lockRankProfBlock: {lockRankSysmon, lockRankScavenge, lockRankForcegc, lockRankSweepWaiters, lockRankAssistQueue, lockRankSweep, lockRankPollDesc, lockRankCpuprof, lockRankSched, lockRankAllg, lockRankAllp, lockRankTimers, lockRankHchan, lockRankNotifyList, lockRankItab, lockRankReflectOffs, lockRankUserArenaState, lockRankTraceBuf, lockRankTraceStrings},
lockRankProfMemActive: {lockRankSysmon, lockRankScavenge, lockRankForcegc, lockRankSweepWaiters, lockRankAssistQueue, lockRankSweep, lockRankPollDesc, lockRankCpuprof, lockRankSched, lockRankAllg, lockRankAllp, lockRankTimers, lockRankHchan, lockRankNotifyList, lockRankItab, lockRankReflectOffs, lockRankUserArenaState, lockRankTraceBuf, lockRankTraceStrings},
lockRankProfMemFuture: {lockRankSysmon, lockRankScavenge, lockRankForcegc, lockRankSweepWaiters, lockRankAssistQueue, lockRankSweep, lockRankPollDesc, lockRankCpuprof, lockRankSched, lockRankAllg, lockRankAllp, lockRankTimers, lockRankHchan, lockRankNotifyList, lockRankItab, lockRankReflectOffs, lockRankUserArenaState, lockRankTraceBuf, lockRankTraceStrings, lockRankProfMemActive},
lockRankGscan: {lockRankSysmon, lockRankScavenge, lockRankForcegc, lockRankSweepWaiters, lockRankAssistQueue, lockRankSweep, lockRankPollDesc, lockRankCpuprof, lockRankSched, lockRankAllg, lockRankAllp, lockRankTimers, lockRankNetpollInit, lockRankHchan, lockRankNotifyList, lockRankRoot, lockRankItab, lockRankReflectOffs, lockRankUserArenaState, lockRankTraceBuf, lockRankTraceStrings, lockRankFin, lockRankGcBitsArenas, lockRankSpanSetSpine, lockRankProfInsert, lockRankProfBlock, lockRankProfMemActive, lockRankProfMemFuture},
lockRankStackpool: {lockRankSysmon, lockRankScavenge, lockRankForcegc, lockRankSweepWaiters, lockRankAssistQueue, lockRankSweep, lockRankPollDesc, lockRankCpuprof, lockRankSched, lockRankAllg, lockRankAllp, lockRankTimers, lockRankNetpollInit, lockRankHchan, lockRankNotifyList, lockRankRwmutexW, lockRankRwmutexR, lockRankRoot, lockRankItab, lockRankReflectOffs, lockRankUserArenaState, lockRankTraceBuf, lockRankTraceStrings, lockRankFin, lockRankGcBitsArenas, lockRankSpanSetSpine, lockRankProfInsert, lockRankProfBlock, lockRankProfMemActive, lockRankProfMemFuture, lockRankGscan},
lockRankStackLarge: {lockRankSysmon, lockRankScavenge, lockRankForcegc, lockRankSweepWaiters, lockRankAssistQueue, lockRankSweep, lockRankPollDesc, lockRankCpuprof, lockRankSched, lockRankAllg, lockRankAllp, lockRankTimers, lockRankNetpollInit, lockRankHchan, lockRankNotifyList, lockRankRoot, lockRankItab, lockRankReflectOffs, lockRankUserArenaState, lockRankTraceBuf, lockRankTraceStrings, lockRankFin, lockRankGcBitsArenas, lockRankSpanSetSpine, lockRankProfInsert, lockRankProfBlock, lockRankProfMemActive, lockRankProfMemFuture, lockRankGscan},
lockRankHchanLeaf: {lockRankSysmon, lockRankScavenge, lockRankForcegc, lockRankSweepWaiters, lockRankAssistQueue, lockRankSweep, lockRankPollDesc, lockRankCpuprof, lockRankSched, lockRankAllg, lockRankAllp, lockRankTimers, lockRankNetpollInit, lockRankHchan, lockRankNotifyList, lockRankRoot, lockRankItab, lockRankReflectOffs, lockRankUserArenaState, lockRankTraceBuf, lockRankTraceStrings, lockRankFin, lockRankGcBitsArenas, lockRankSpanSetSpine, lockRankProfInsert, lockRankProfBlock, lockRankProfMemActive, lockRankProfMemFuture, lockRankGscan, lockRankHchanLeaf},
lockRankWbufSpans: {lockRankSysmon, lockRankScavenge, lockRankForcegc, lockRankDefer, lockRankSweepWaiters, lockRankAssistQueue, lockRankSweep, lockRankPollDesc, lockRankCpuprof, lockRankSched, lockRankAllg, lockRankAllp, lockRankTimers, lockRankNetpollInit, lockRankHchan, lockRankNotifyList, lockRankSudog, lockRankRoot, lockRankItab, lockRankReflectOffs, lockRankUserArenaState, lockRankTraceBuf, lockRankTraceStrings, lockRankFin, lockRankGcBitsArenas, lockRankMspanSpecial, lockRankSpanSetSpine, lockRankProfInsert, lockRankProfBlock, lockRankProfMemActive, lockRankProfMemFuture, lockRankGscan},
lockRankMheap: {lockRankSysmon, lockRankScavenge, lockRankForcegc, lockRankDefer, lockRankSweepWaiters, lockRankAssistQueue, lockRankSweep, lockRankPollDesc, lockRankCpuprof, lockRankSched, lockRankAllg, lockRankAllp, lockRankTimers, lockRankNetpollInit, lockRankHchan, lockRankNotifyList, lockRankSudog, lockRankRwmutexW, lockRankRwmutexR, lockRankRoot, lockRankItab, lockRankReflectOffs, lockRankUserArenaState, lockRankTraceBuf, lockRankTraceStrings, lockRankFin, lockRankGcBitsArenas, lockRankMspanSpecial, lockRankSpanSetSpine, lockRankProfInsert, lockRankProfBlock, lockRankProfMemActive, lockRankProfMemFuture, lockRankGscan, lockRankStackpool, lockRankStackLarge, lockRankWbufSpans},
lockRankGlobalAlloc: {lockRankSysmon, lockRankScavenge, lockRankForcegc, lockRankDefer, lockRankSweepWaiters, lockRankAssistQueue, lockRankSweep, lockRankPollDesc, lockRankCpuprof, lockRankSched, lockRankAllg, lockRankAllp, lockRankTimers, lockRankNetpollInit, lockRankHchan, lockRankNotifyList, lockRankSudog, lockRankRwmutexW, lockRankRwmutexR, lockRankRoot, lockRankItab, lockRankReflectOffs, lockRankUserArenaState, lockRankTraceBuf, lockRankTraceStrings, lockRankFin, lockRankGcBitsArenas, lockRankMheapSpecial, lockRankMspanSpecial, lockRankSpanSetSpine, lockRankProfInsert, lockRankProfBlock, lockRankProfMemActive, lockRankProfMemFuture, lockRankGscan, lockRankStackpool, lockRankStackLarge, lockRankWbufSpans, lockRankMheap},
lockRankTrace: {lockRankSysmon, lockRankScavenge, lockRankForcegc, lockRankDefer, lockRankSweepWaiters, lockRankAssistQueue, lockRankSweep, lockRankPollDesc, lockRankCpuprof, lockRankSched, lockRankAllg, lockRankAllp, lockRankTimers, lockRankNetpollInit, lockRankHchan, lockRankNotifyList, lockRankSudog, lockRankRwmutexW, lockRankRwmutexR, lockRankRoot, lockRankItab, lockRankReflectOffs, lockRankUserArenaState, lockRankTraceBuf, lockRankTraceStrings, lockRankFin, lockRankGcBitsArenas, lockRankMspanSpecial, lockRankSpanSetSpine, lockRankProfInsert, lockRankProfBlock, lockRankProfMemActive, lockRankProfMemFuture, lockRankGscan, lockRankStackpool, lockRankStackLarge, lockRankWbufSpans, lockRankMheap},
lockRankTraceStackTab: {lockRankSysmon, lockRankScavenge, lockRankForcegc, lockRankDefer, lockRankSweepWaiters, lockRankAssistQueue, lockRankSweep, lockRankPollDesc, lockRankCpuprof, lockRankSched, lockRankAllg, lockRankAllp, lockRankTimers, lockRankNetpollInit, lockRankHchan, lockRankNotifyList, lockRankSudog, lockRankRwmutexW, lockRankRwmutexR, lockRankRoot, lockRankItab, lockRankReflectOffs, lockRankUserArenaState, lockRankTraceBuf, lockRankTraceStrings, lockRankFin, lockRankGcBitsArenas, lockRankMspanSpecial, lockRankSpanSetSpine, lockRankProfInsert, lockRankProfBlock, lockRankProfMemActive, lockRankProfMemFuture, lockRankGscan, lockRankStackpool, lockRankStackLarge, lockRankWbufSpans, lockRankMheap, lockRankTrace},
lockRankPanic: {},
lockRankDeadlock: {lockRankPanic, lockRankDeadlock},
}
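// Because lockPartialOrder is a transitive closure, checking whether a held
// rank permits acquiring another is a single scan of the candidate's row.
// An illustrative helper (a sketch only; the actual enforcement lives in
// the static lock ranking build of the runtime):
//
//	func mayAcquire(held, next lockRank) bool {
//		for _, r := range lockPartialOrder[next] {
//			if r == held {
//				return true
//			}
//		}
//		return false
//	}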
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !goexperiment.staticlockranking
package runtime
// lockRankStruct is embedded in mutex, but is empty when static lock ranking is
// disabled (the default).
type lockRankStruct struct {
}
func lockInit(l *mutex, rank lockRank) {
}
func getLockRank(l *mutex) lockRank {
return 0
}
func lockWithRank(l *mutex, rank lockRank) {
lock2(l)
}
// This function may be called in nosplit context and thus must be nosplit.
//
//go:nosplit
func acquireLockRank(rank lockRank) {
}
func unlockWithRank(l *mutex) {
unlock2(l)
}
// This function may be called in nosplit context and thus must be nosplit.
//
//go:nosplit
func releaseLockRank(rank lockRank) {
}
func lockWithRankMayAcquire(l *mutex, rank lockRank) {
}
//go:nosplit
func assertLockHeld(l *mutex) {
}
//go:nosplit
func assertRankHeld(r lockRank) {
}
//go:nosplit
func worldStopped() {
}
//go:nosplit
func worldStarted() {
}
//go:nosplit
func assertWorldStopped() {
}
//go:nosplit
func assertWorldStoppedOrLockHeld(l *mutex) {
}
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Memory allocator.
//
// This was originally based on tcmalloc, but has diverged quite a bit.
// http://goog-perftools.sourceforge.net/doc/tcmalloc.html
// The main allocator works in runs of pages.
// Small allocation sizes (up to and including 32 kB) are
// rounded to one of about 70 size classes, each of which
// has its own free set of objects of exactly that size.
// Any free page of memory can be split into a set of objects
// of one size class, which are then managed using a free bitmap.
//
// The allocator's data structures are:
//
// fixalloc: a free-list allocator for fixed-size off-heap objects,
// used to manage storage used by the allocator.
// mheap: the malloc heap, managed at page (8192-byte) granularity.
// mspan: a run of in-use pages managed by the mheap.
// mcentral: collects all spans of a given size class.
// mcache: a per-P cache of mspans with free space.
// mstats: allocation statistics.
//
// Allocating a small object proceeds up a hierarchy of caches:
//
// 1. Round the size up to one of the small size classes
// and look in the corresponding mspan in this P's mcache.
// Scan the mspan's free bitmap to find a free slot.
// If there is a free slot, allocate it.
// This can all be done without acquiring a lock.
//
// 2. If the mspan has no free slots, obtain a new mspan
// from the mcentral's list of mspans of the required size
// class that have free space.
// Obtaining a whole span amortizes the cost of locking
// the mcentral.
//
// 3. If the mcentral's mspan list is empty, obtain a run
// of pages from the mheap to use for the mspan.
//
// 4. If the mheap is empty or has no page runs large enough,
// allocate a new group of pages (at least 1MB) from the
// operating system. Allocating a large run of pages
// amortizes the cost of talking to the operating system.
//
// Sweeping an mspan and freeing objects on it proceeds up a similar
// hierarchy:
//
// 1. If the mspan is being swept in response to allocation, it
// is returned to the mcache to satisfy the allocation.
//
// 2. Otherwise, if the mspan still has allocated objects in it,
// it is placed on the mcentral free list for the mspan's size
// class.
//
// 3. Otherwise, if all objects in the mspan are free, the mspan's
// pages are returned to the mheap and the mspan is now dead.
//
// Allocating and freeing a large object uses the mheap
// directly, bypassing the mcache and mcentral.
//
// If mspan.needzero is false, then free object slots in the mspan are
// already zeroed. Otherwise if needzero is true, objects are zeroed as
// they are allocated. There are various benefits to delaying zeroing
// this way:
//
// 1. Stack frame allocation can avoid zeroing altogether.
//
// 2. It exhibits better temporal locality, since the program is
// probably about to write to the memory.
//
// 3. We don't zero pages that never get reused.
// Virtual memory layout
//
// The heap consists of a set of arenas, which are 64MB on 64-bit and
// 4MB on 32-bit (heapArenaBytes). Each arena's start address is also
// aligned to the arena size.
//
// Each arena has an associated heapArena object that stores the
// metadata for that arena: the heap bitmap for all words in the arena
// and the span map for all pages in the arena. heapArena objects are
// themselves allocated off-heap.
//
// Since arenas are aligned, the address space can be viewed as a
// series of arena frames. The arena map (mheap_.arenas) maps from
// arena frame number to *heapArena, or nil for parts of the address
// space not backed by the Go heap. The arena map is structured as a
// two-level array consisting of a "L1" arena map and many "L2" arena
// maps; however, since arenas are large, on many architectures, the
// arena map consists of a single, large L2 map.
//
// The arena map covers the entire possible address space, allowing
// the Go heap to use any part of the address space. The allocator
// attempts to keep arenas contiguous so that large spans (and hence
// large objects) can cross arenas.
package runtime
import (
"internal/goarch"
"internal/goos"
"runtime/internal/atomic"
"runtime/internal/math"
"runtime/internal/sys"
"unsafe"
)
const (
maxTinySize = _TinySize
tinySizeClass = _TinySizeClass
maxSmallSize = _MaxSmallSize
pageShift = _PageShift
pageSize = _PageSize
concurrentSweep = _ConcurrentSweep
_PageSize = 1 << _PageShift
_PageMask = _PageSize - 1
// _64bit = 1 on 64-bit systems, 0 on 32-bit systems
_64bit = 1 << (^uintptr(0) >> 63) / 2
// Tiny allocator parameters, see "Tiny allocator" comment in malloc.go.
_TinySize = 16
_TinySizeClass = int8(2)
_FixAllocChunk = 16 << 10 // Chunk size for FixAlloc
// Per-P, per order stack segment cache size.
_StackCacheSize = 32 * 1024
// Number of orders that get caching. Order 0 is FixedStack
// and each successive order is twice as large.
// We want to cache 2KB, 4KB, 8KB, and 16KB stacks. Larger stacks
// will be allocated directly.
// Since FixedStack is different on different systems, we
// must vary NumStackOrders to keep the same maximum cached size.
// OS | FixedStack | NumStackOrders
// -----------------+------------+---------------
// linux/darwin/bsd | 2KB | 4
// windows/32 | 4KB | 3
// windows/64 | 8KB | 2
// plan9 | 4KB | 3
_NumStackOrders = 4 - goarch.PtrSize/4*goos.IsWindows - 1*goos.IsPlan9
// heapAddrBits is the number of bits in a heap address. On
// amd64, addresses are sign-extended beyond heapAddrBits. On
// other arches, they are zero-extended.
//
// On most 64-bit platforms, we limit this to 48 bits based on a
// combination of hardware and OS limitations.
//
// amd64 hardware limits addresses to 48 bits, sign-extended
// to 64 bits. Addresses where the top 16 bits are not either
// all 0 or all 1 are "non-canonical" and invalid. Because of
// these "negative" addresses, we offset addresses by 1<<47
// (arenaBaseOffset) on amd64 before computing indexes into
// the heap arenas index. In 2017, amd64 hardware added
// support for 57 bit addresses; however, currently only Linux
// supports this extension and the kernel will never choose an
// address above 1<<47 unless mmap is called with a hint
// address above 1<<47 (which we never do).
//
// arm64 hardware (as of ARMv8) limits user addresses to 48
// bits, in the range [0, 1<<48).
//
// ppc64, mips64, and s390x support arbitrary 64 bit addresses
// in hardware. On Linux, Go leans on stricter OS limits. Based
// on Linux's processor.h, the user address space is limited as
// follows on 64-bit architectures:
//
// Architecture Name Maximum Value (exclusive)
// ---------------------------------------------------------------------
// amd64 TASK_SIZE_MAX 0x007ffffffff000 (47 bit addresses)
// arm64 TASK_SIZE_64 0x01000000000000 (48 bit addresses)
// ppc64{,le} TASK_SIZE_USER64 0x00400000000000 (46 bit addresses)
// mips64{,le} TASK_SIZE64 0x00010000000000 (40 bit addresses)
// s390x TASK_SIZE 1<<64 (64 bit addresses)
//
// These limits may increase over time, but are currently at
// most 48 bits except on s390x. On all architectures, Linux
// starts placing mmap'd regions at addresses that are
// significantly below 48 bits, so even if it's possible to
// exceed Go's 48 bit limit, it's extremely unlikely in
// practice.
//
// On 32-bit platforms, we accept the full 32-bit address
// space because doing so is cheap.
// mips32 only has access to the low 2GB of virtual memory, so
// we further limit it to 31 bits.
//
// On ios/arm64, although 64-bit pointers are presumably
// available, pointers are truncated to 33 bits in iOS <14.
// Furthermore, only the top 4 GiB of the address space are
// actually available to the application. In iOS >=14, more
// of the address space is available, and the OS can now
// provide addresses outside of those 33 bits. Pick 40 bits
// as a reasonable balance between address space usage by the
// page allocator, and flexibility for what mmap'd regions
// we'll accept for the heap. We can't just move to the full
// 48 bits because this uses too much address space for older
// iOS versions.
// TODO(mknyszek): Once iOS <14 is deprecated, promote ios/arm64
// to a 48-bit address space like every other arm64 platform.
//
// WebAssembly currently has a limit of 4GB linear memory.
heapAddrBits = (_64bit*(1-goarch.IsWasm)*(1-goos.IsIos*goarch.IsArm64))*48 + (1-_64bit+goarch.IsWasm)*(32-(goarch.IsMips+goarch.IsMipsle)) + 40*goos.IsIos*goarch.IsArm64
// maxAlloc is the maximum size of an allocation. On 64-bit,
// it's theoretically possible to allocate 1<<heapAddrBits bytes. On
// 32-bit, however, this is one less than 1<<32 because the
// number of bytes in the address space doesn't actually fit
// in a uintptr.
maxAlloc = (1 << heapAddrBits) - (1-_64bit)*1
// The number of bits in a heap address, the size of heap
// arenas, and the L1 and L2 arena map sizes are related by
//
// (1 << addr bits) = arena size * L1 entries * L2 entries
//
// Currently, we balance these as follows:
//
// Platform Addr bits Arena size L1 entries L2 entries
// -------------- --------- ---------- ---------- -----------
// */64-bit 48 64MB 1 4M (32MB)
// windows/64-bit 48 4MB 64 1M (8MB)
// ios/arm64 33 4MB 1 2048 (8KB)
// */32-bit 32 4MB 1 1024 (4KB)
// */mips(le) 31 4MB 1 512 (2KB)
// heapArenaBytes is the size of a heap arena. The heap
// consists of mappings of size heapArenaBytes, aligned to
// heapArenaBytes. The initial heap mapping is one arena.
//
// This is currently 64MB on 64-bit non-Windows and 4MB on
// 32-bit and on Windows. We use smaller arenas on Windows
// because all committed memory is charged to the process,
// even if it's not touched. Hence, for processes with small
// heaps, the mapped arena space needs to be commensurate.
// This is particularly important with the race detector,
// since it significantly amplifies the cost of committed
// memory.
heapArenaBytes = 1 << logHeapArenaBytes
heapArenaWords = heapArenaBytes / goarch.PtrSize
// logHeapArenaBytes is log_2 of heapArenaBytes. For clarity,
// prefer using heapArenaBytes where possible (we need the
// constant to compute some other constants).
logHeapArenaBytes = (6+20)*(_64bit*(1-goos.IsWindows)*(1-goarch.IsWasm)*(1-goos.IsIos*goarch.IsArm64)) + (2+20)*(_64bit*goos.IsWindows) + (2+20)*(1-_64bit) + (2+20)*goarch.IsWasm + (2+20)*goos.IsIos*goarch.IsArm64
// heapArenaBitmapWords is the size of each heap arena's bitmap in uintptrs.
heapArenaBitmapWords = heapArenaWords / (8 * goarch.PtrSize)
pagesPerArena = heapArenaBytes / pageSize
// arenaL1Bits is the number of bits of the arena number
// covered by the first level arena map.
//
// This number should be small, since the first level arena
// map requires PtrSize*(1<<arenaL1Bits) of space in the
// binary's BSS. It can be zero, in which case the first level
// index is effectively unused. There is a performance benefit
// to this, since the generated code can be more efficient,
// but comes at the cost of having a large L2 mapping.
//
// We use the L1 map on 64-bit Windows because the arena size
// is small, but the address space is still 48 bits, and
// there's a high cost to having a large L2.
arenaL1Bits = 6 * (_64bit * goos.IsWindows)
// arenaL2Bits is the number of bits of the arena number
// covered by the second level arena index.
//
// The size of each arena map allocation is proportional to
// 1<<arenaL2Bits, so it's important that this not be too
// large. 48 bits leads to 32MB arena index allocations, which
// is about the practical threshold.
arenaL2Bits = heapAddrBits - logHeapArenaBytes - arenaL1Bits
// arenaL1Shift is the number of bits to shift an arena frame
// number by to compute an index into the first level arena map.
arenaL1Shift = arenaL2Bits
// arenaBits is the total bits in a combined arena map index.
// This is split between the index into the L1 arena map and
// the L2 arena map.
arenaBits = arenaL1Bits + arenaL2Bits
// arenaBaseOffset is the pointer value that corresponds to
// index 0 in the heap arena map.
//
// On amd64, the address space is 48 bits, sign extended to 64
// bits. This offset lets us handle "negative" addresses (or
// high addresses if viewed as unsigned).
//
// On aix/ppc64, this offset allows us to keep heapAddrBits at
// 48. Otherwise, it would need to be 60 in order to handle mmap
// addresses (in the range 0x0a00000000000000 - 0x0affffffffffffff).
// But in that case, the memory reserved in (s *pageAlloc).init for
// chunks causes significant slowdowns.
//
// On other platforms, the user address space is contiguous
// and starts at 0, so no offset is necessary.
arenaBaseOffset = 0xffff800000000000*goarch.IsAmd64 + 0x0a00000000000000*goos.IsAix
// A typed version of this constant that will make it into DWARF (for viewcore).
arenaBaseOffsetUintptr = uintptr(arenaBaseOffset)
// Max number of threads to run garbage collection.
// 2, 3, and 4 are all plausible maximums depending
// on the hardware details of the machine. The garbage
// collector scales well to 32 cpus.
_MaxGcproc = 32
// minLegalPointer is the smallest possible legal pointer.
// This is the smallest possible architectural page size,
// since we assume that the first page is never mapped.
//
// This should agree with minZeroPage in the compiler.
minLegalPointer uintptr = 4096
)
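// Worked example (illustrative), for linux/amd64: heapAddrBits is 48 and
// logHeapArenaBytes is 6+20 = 26 (64MB arenas), so with arenaL1Bits = 0:
//
//	arenaL2Bits = 48 - 26 - 0 // = 22
//
// 1<<22 is 4M L2 entries; at 8 bytes per *heapArena pointer the single L2
// map is 32MB of (demand-paged) address space, matching the table above.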
// physPageSize is the size in bytes of the OS's physical pages.
// Mapping and unmapping operations must be done at multiples of
// physPageSize.
//
// This must be set by the OS init code (typically in osinit) before
// mallocinit.
var physPageSize uintptr
// physHugePageSize is the size in bytes of the OS's default physical huge
// page size whose allocation is opaque to the application. It is assumed
// and verified to be a power of two.
//
// If set, this must be set by the OS init code (typically in osinit) before
// mallocinit. However, setting it at all is optional, and leaving the default
// value is always safe (though potentially less efficient).
//
// Since physHugePageSize is always assumed to be a power of two,
// physHugePageShift is defined as physHugePageSize == 1 << physHugePageShift.
// The purpose of physHugePageShift is to avoid doing divisions in
// performance critical functions.
var (
physHugePageSize uintptr
physHugePageShift uint
)
func mallocinit() {
if class_to_size[_TinySizeClass] != _TinySize {
throw("bad TinySizeClass")
}
if heapArenaBitmapWords&(heapArenaBitmapWords-1) != 0 {
// heapBits expects modular arithmetic on bitmap
// addresses to work.
throw("heapArenaBitmapWords not a power of 2")
}
// Check physPageSize.
if physPageSize == 0 {
// The OS init code failed to fetch the physical page size.
throw("failed to get system page size")
}
if physPageSize > maxPhysPageSize {
print("system page size (", physPageSize, ") is larger than maximum page size (", maxPhysPageSize, ")\n")
throw("bad system page size")
}
if physPageSize < minPhysPageSize {
print("system page size (", physPageSize, ") is smaller than minimum page size (", minPhysPageSize, ")\n")
throw("bad system page size")
}
if physPageSize&(physPageSize-1) != 0 {
print("system page size (", physPageSize, ") must be a power of 2\n")
throw("bad system page size")
}
if physHugePageSize&(physHugePageSize-1) != 0 {
print("system huge page size (", physHugePageSize, ") must be a power of 2\n")
throw("bad system huge page size")
}
if physHugePageSize > maxPhysHugePageSize {
// physHugePageSize is greater than the maximum supported huge page size.
// Don't throw here, like in the other cases, since a system configured
// in this way isn't wrong, we just don't have the code to support them.
// Instead, silently set the huge page size to zero.
physHugePageSize = 0
}
if physHugePageSize != 0 {
// Since physHugePageSize is a power of 2, it suffices to increase
// physHugePageShift until 1<<physHugePageShift == physHugePageSize.
for 1<<physHugePageShift != physHugePageSize {
physHugePageShift++
}
}
if pagesPerArena%pagesPerSpanRoot != 0 {
print("pagesPerArena (", pagesPerArena, ") is not divisible by pagesPerSpanRoot (", pagesPerSpanRoot, ")\n")
throw("bad pagesPerSpanRoot")
}
if pagesPerArena%pagesPerReclaimerChunk != 0 {
print("pagesPerArena (", pagesPerArena, ") is not divisible by pagesPerReclaimerChunk (", pagesPerReclaimerChunk, ")\n")
throw("bad pagesPerReclaimerChunk")
}
// Initialize the heap.
mheap_.init()
mcache0 = allocmcache()
lockInit(&gcBitsArenas.lock, lockRankGcBitsArenas)
lockInit(&profInsertLock, lockRankProfInsert)
lockInit(&profBlockLock, lockRankProfBlock)
lockInit(&profMemActiveLock, lockRankProfMemActive)
for i := range profMemFutureLock {
lockInit(&profMemFutureLock[i], lockRankProfMemFuture)
}
lockInit(&globalAlloc.mutex, lockRankGlobalAlloc)
// Create initial arena growth hints.
if goarch.PtrSize == 8 {
// On a 64-bit machine, we pick the following hints
// because:
//
// 1. Starting from the middle of the address space
// makes it easier to grow out a contiguous range
// without running in to some other mapping.
//
// 2. This makes Go heap addresses more easily
// recognizable when debugging.
//
// 3. Stack scanning in gccgo is still conservative,
// so it's important that addresses be distinguishable
// from other data.
//
// Starting at 0x00c0 means that the valid memory addresses
// will begin with 0x00c0, 0x00c1, ...
// In little-endian, that's c0 00, c1 00, ... None of those are valid
// UTF-8 sequences, and they are otherwise as far away from
// ff (likely a common byte) as possible. If that fails, we try other 0xXXc0
// addresses. An earlier attempt to use 0x11f8 caused out of memory errors
// on OS X during thread allocations. 0x00c0 causes conflicts with
// AddressSanitizer which reserves all memory up to 0x0100.
// These choices reduce the odds of a conservative garbage collector
// not collecting memory because some non-pointer block of memory
// had a bit pattern that matched a memory address.
//
// However, on arm64, we ignore all this advice above and slam the
// allocation at 0x40 << 32 because when using 4k pages with 3-level
// translation buffers, the user address space is limited to 39 bits.
// On ios/arm64, the address space is even smaller.
//
// On AIX, mmap starts at 0x0A00000000000000 for 64-bit
// processes.
//
// Space mapped for user arenas comes immediately after the range
// originally reserved for the regular heap when race mode is not
// enabled because user arena chunks can never be used for regular heap
// allocations and we want to avoid fragmenting the address space.
//
// In race mode we have no choice but to just use the same hints because
// the race detector requires that the heap be mapped contiguously.
for i := 0x7f; i >= 0; i-- {
var p uintptr
switch {
case raceenabled:
// The TSAN runtime requires the heap
// to be in the range [0x00c000000000,
// 0x00e000000000).
p = uintptr(i)<<32 | uintptrMask&(0x00c0<<32)
if p >= uintptrMask&0x00e000000000 {
continue
}
case GOARCH == "arm64" && GOOS == "ios":
p = uintptr(i)<<40 | uintptrMask&(0x0013<<28)
case GOARCH == "arm64":
p = uintptr(i)<<40 | uintptrMask&(0x0040<<32)
case GOOS == "aix":
if i == 0 {
// We don't use addresses directly after 0x0A00000000000000
// to avoid collisions with others mmaps done by non-go programs.
continue
}
p = uintptr(i)<<40 | uintptrMask&(0xa0<<52)
default:
p = uintptr(i)<<40 | uintptrMask&(0x00c0<<32)
}
// Switch to generating hints for user arenas if we've gone
// through about half the hints. In race mode, take only about
// a quarter; we don't have very much space to work with.
hintList := &mheap_.arenaHints
if (!raceenabled && i > 0x3f) || (raceenabled && i > 0x5f) {
hintList = &mheap_.userArena.arenaHints
}
hint := (*arenaHint)(mheap_.arenaHintAlloc.alloc())
hint.addr = p
hint.next, *hintList = *hintList, hint
}
} else {
// On a 32-bit machine, we're much more concerned
// about keeping the usable heap contiguous.
// Hence:
//
// 1. We reserve space for all heapArenas up front so
// they don't get interleaved with the heap. They're
// ~258MB, so this isn't too bad. (We could reserve a
// smaller amount of space up front if this is a
// problem.)
//
// 2. We hint the heap to start right above the end of
// the binary so we have the best chance of keeping it
// contiguous.
//
// 3. We try to stake out a reasonably large initial
// heap reservation.
const arenaMetaSize = (1 << arenaBits) * unsafe.Sizeof(heapArena{})
meta := uintptr(sysReserve(nil, arenaMetaSize))
if meta != 0 {
mheap_.heapArenaAlloc.init(meta, arenaMetaSize, true)
}
// We want to start the arena low, but if we're linked
// against C code, it's possible global constructors
// have called malloc and adjusted the process' brk.
// Query the brk so we can avoid trying to map the
// region over it (which will cause the kernel to put
// the region somewhere else, likely at a high
// address).
procBrk := sbrk0()
// If we ask for the end of the data segment but the
// operating system requires a little more space
// before we can start allocating, it will give out a
// slightly higher pointer. Except QEMU, which is
// buggy, as usual: it won't adjust the pointer
// upward. So adjust it upward a little bit ourselves:
// 1/4 MB to get away from the running binary image.
p := firstmoduledata.end
if p < procBrk {
p = procBrk
}
if mheap_.heapArenaAlloc.next <= p && p < mheap_.heapArenaAlloc.end {
p = mheap_.heapArenaAlloc.end
}
p = alignUp(p+(256<<10), heapArenaBytes)
// Because we're worried about fragmentation on
// 32-bit, we try to make a large initial reservation.
arenaSizes := []uintptr{
512 << 20,
256 << 20,
128 << 20,
}
for _, arenaSize := range arenaSizes {
a, size := sysReserveAligned(unsafe.Pointer(p), arenaSize, heapArenaBytes)
if a != nil {
mheap_.arena.init(uintptr(a), size, false)
p = mheap_.arena.end // For hint below
break
}
}
hint := (*arenaHint)(mheap_.arenaHintAlloc.alloc())
hint.addr = p
hint.next, mheap_.arenaHints = mheap_.arenaHints, hint
// Place the hint for user arenas just after the large reservation.
//
// While this potentially competes with the hint above, in practice we probably
// aren't going to be getting this far anyway on 32-bit platforms.
userArenaHint := (*arenaHint)(mheap_.arenaHintAlloc.alloc())
userArenaHint.addr = p
userArenaHint.next, mheap_.userArena.arenaHints = mheap_.userArena.arenaHints, userArenaHint
}
}
// sysAlloc allocates heap arena space for at least n bytes. The
// returned pointer is always heapArenaBytes-aligned and backed by
// h.arenas metadata. The returned size is always a multiple of
// heapArenaBytes. sysAlloc returns nil on failure.
// There is no corresponding free function.
//
// hintList is a list of hint addresses for where to allocate new
// heap arenas. It must be non-nil.
//
// register indicates whether the heap arena should be registered
// in allArenas.
//
// sysAlloc returns a memory region in the Reserved state. This region must
// be transitioned to Prepared and then Ready before use.
//
// h must be locked.
func (h *mheap) sysAlloc(n uintptr, hintList **arenaHint, register bool) (v unsafe.Pointer, size uintptr) {
assertLockHeld(&h.lock)
n = alignUp(n, heapArenaBytes)
if hintList == &h.arenaHints {
// First, try the arena pre-reservation.
// Newly-used mappings are considered released.
//
// Only do this if we're using the regular heap arena hints.
// This behavior is only for the heap.
v = h.arena.alloc(n, heapArenaBytes, &gcController.heapReleased)
if v != nil {
size = n
goto mapped
}
}
// Try to grow the heap at a hint address.
for *hintList != nil {
hint := *hintList
p := hint.addr
if hint.down {
p -= n
}
if p+n < p {
// We can't use this, so don't ask.
v = nil
} else if arenaIndex(p+n-1) >= 1<<arenaBits {
// Outside addressable heap. Can't use.
v = nil
} else {
v = sysReserve(unsafe.Pointer(p), n)
}
if p == uintptr(v) {
// Success. Update the hint.
if !hint.down {
p += n
}
hint.addr = p
size = n
break
}
// Failed. Discard this hint and try the next.
//
// TODO: This would be cleaner if sysReserve could be
// told to only return the requested address. In
// particular, this is already how Windows behaves, so
// it would simplify things there.
if v != nil {
sysFreeOS(v, n)
}
*hintList = hint.next
h.arenaHintAlloc.free(unsafe.Pointer(hint))
}
if size == 0 {
if raceenabled {
// The race detector assumes the heap lives in
// [0x00c000000000, 0x00e000000000), but we
// just ran out of hints in this region. Give
// a nice failure.
throw("too many address space collisions for -race mode")
}
// All of the hints failed, so we'll take any
// (sufficiently aligned) address the kernel will give
// us.
v, size = sysReserveAligned(nil, n, heapArenaBytes)
if v == nil {
return nil, 0
}
// Create new hints for extending this region.
hint := (*arenaHint)(h.arenaHintAlloc.alloc())
hint.addr, hint.down = uintptr(v), true
hint.next, mheap_.arenaHints = mheap_.arenaHints, hint
hint = (*arenaHint)(h.arenaHintAlloc.alloc())
hint.addr = uintptr(v) + size
hint.next, mheap_.arenaHints = mheap_.arenaHints, hint
}
// Check for bad pointers or pointers we can't use.
{
var bad string
p := uintptr(v)
if p+size < p {
bad = "region exceeds uintptr range"
} else if arenaIndex(p) >= 1<<arenaBits {
bad = "base outside usable address space"
} else if arenaIndex(p+size-1) >= 1<<arenaBits {
bad = "end outside usable address space"
}
if bad != "" {
// This should be impossible on most architectures,
// but it would be really confusing to debug.
print("runtime: memory allocated by OS [", hex(p), ", ", hex(p+size), ") not in usable address space: ", bad, "\n")
throw("memory reservation exceeds address space limit")
}
}
if uintptr(v)&(heapArenaBytes-1) != 0 {
throw("misrounded allocation in sysAlloc")
}
mapped:
// Create arena metadata.
for ri := arenaIndex(uintptr(v)); ri <= arenaIndex(uintptr(v)+size-1); ri++ {
l2 := h.arenas[ri.l1()]
if l2 == nil {
// Allocate an L2 arena map.
//
// Use sysAllocOS instead of sysAlloc or persistentalloc because there's no
// statistic we can comfortably account for this space in. With this structure,
// we rely on demand paging to avoid large overheads, but tracking which memory
// is paged in is too expensive. Trying to account for the whole region means
// that it will appear like an enormous memory overhead in statistics, even though
// it is not.
l2 = (*[1 << arenaL2Bits]*heapArena)(sysAllocOS(unsafe.Sizeof(*l2)))
if l2 == nil {
throw("out of memory allocating heap arena map")
}
atomic.StorepNoWB(unsafe.Pointer(&h.arenas[ri.l1()]), unsafe.Pointer(l2))
}
if l2[ri.l2()] != nil {
throw("arena already initialized")
}
var r *heapArena
r = (*heapArena)(h.heapArenaAlloc.alloc(unsafe.Sizeof(*r), goarch.PtrSize, &memstats.gcMiscSys))
if r == nil {
r = (*heapArena)(persistentalloc(unsafe.Sizeof(*r), goarch.PtrSize, &memstats.gcMiscSys))
if r == nil {
throw("out of memory allocating heap arena metadata")
}
}
// Register the arena in allArenas if requested.
if register {
if len(h.allArenas) == cap(h.allArenas) {
size := 2 * uintptr(cap(h.allArenas)) * goarch.PtrSize
if size == 0 {
size = physPageSize
}
newArray := (*notInHeap)(persistentalloc(size, goarch.PtrSize, &memstats.gcMiscSys))
if newArray == nil {
throw("out of memory allocating allArenas")
}
oldSlice := h.allArenas
*(*notInHeapSlice)(unsafe.Pointer(&h.allArenas)) = notInHeapSlice{newArray, len(h.allArenas), int(size / goarch.PtrSize)}
copy(h.allArenas, oldSlice)
// Do not free the old backing array because
// there may be concurrent readers. Since we
// double the array each time, this can lead
// to at most 2x waste.
}
h.allArenas = h.allArenas[:len(h.allArenas)+1]
h.allArenas[len(h.allArenas)-1] = ri
}
// Store atomically just in case an object from the
// new heap arena becomes visible before the heap lock
// is released (which shouldn't happen, but there's
// little downside to this).
atomic.StorepNoWB(unsafe.Pointer(&l2[ri.l2()]), unsafe.Pointer(r))
}
// Tell the race detector about the new heap memory.
if raceenabled {
racemapshadow(v, size)
}
return
}
// sysReserveAligned is like sysReserve, but the returned pointer is
// aligned to align bytes. It may reserve either n or n+align bytes,
// so it returns the size that was reserved.
func sysReserveAligned(v unsafe.Pointer, size, align uintptr) (unsafe.Pointer, uintptr) {
// Since the alignment is rather large in uses of this
// function, we're not likely to get it by chance, so we ask
// for a larger region and remove the parts we don't need.
retries := 0
retry:
p := uintptr(sysReserve(v, size+align))
switch {
case p == 0:
return nil, 0
case p&(align-1) == 0:
return unsafe.Pointer(p), size + align
case GOOS == "windows":
// On Windows we can't release pieces of a
// reservation, so we release the whole thing and
// re-reserve the aligned sub-region. This may race,
// so we may have to try again.
sysFreeOS(unsafe.Pointer(p), size+align)
p = alignUp(p, align)
p2 := sysReserve(unsafe.Pointer(p), size)
if p != uintptr(p2) {
// Must have raced. Try again.
sysFreeOS(p2, size)
if retries++; retries == 100 {
throw("failed to allocate aligned heap memory; too many retries")
}
goto retry
}
// Success.
return p2, size
default:
// Trim off the unaligned parts.
pAligned := alignUp(p, align)
sysFreeOS(unsafe.Pointer(p), pAligned-p)
end := pAligned + size
endLen := (p + size + align) - end
if endLen > 0 {
sysFreeOS(unsafe.Pointer(end), endLen)
}
return unsafe.Pointer(pAligned), size
}
}
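// Illustrative sketch (not part of the runtime): a worked instance of the
// default (trimming) case above. Suppose align = 4 MiB and sysReserve hands
// back p = 0x7f0000100000 for size+align bytes:
//
//	pAligned := alignUp(p, align)                      // 0x7f0000400000, first aligned byte
//	sysFreeOS(unsafe.Pointer(p), pAligned-p)           // release the 3 MiB head
//	end := pAligned + size                             // one past the kept region
//	sysFreeOS(unsafe.Pointer(end), (p+size+align)-end) // release the 1 MiB tail
//
// The head and tail always sum to exactly align bytes, so precisely size
// bytes remain reserved at an aligned base.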
// base address for all 0-byte allocations
var zerobase uintptr
// nextFreeFast returns the next free object if one is quickly available.
// Otherwise it returns 0.
func nextFreeFast(s *mspan) gclinkptr {
theBit := sys.TrailingZeros64(s.allocCache) // Is there a free object in the allocCache?
if theBit < 64 {
result := s.freeindex + uintptr(theBit)
if result < s.nelems {
freeidx := result + 1
if freeidx%64 == 0 && freeidx != s.nelems {
return 0
}
s.allocCache >>= uint(theBit + 1)
s.freeindex = freeidx
s.allocCount++
return gclinkptr(result*s.elemsize + s.base())
}
}
return 0
}
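// Hedged worked example (illustrative only): allocCache caches the span's
// allocation bits inverted, so a 1 bit means "free". If
//
//	s.allocCache = ...10100 (binary) // bit i describes object s.freeindex+i
//
// then TrailingZeros64 returns 2 and the object at s.freeindex+2 is handed
// out. Shifting by theBit+1 = 3 and advancing freeindex by 3 preserves the
// invariant: the remaining ...10 has its set bit at position 1, which is the
// old position 4, the next free object.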
// nextFree returns the next free object from the cached span if one is available.
// Otherwise it refills the cache with a span with an available object and
// returns that object along with a flag indicating that this was a
// heavyweight allocation. If it is a heavyweight allocation the caller must
// determine whether a new GC cycle needs to be started or if the GC is active
// whether this goroutine needs to assist the GC.
//
// Must run in a non-preemptible context since otherwise the owner of
// c could change.
func (c *mcache) nextFree(spc spanClass) (v gclinkptr, s *mspan, shouldhelpgc bool) {
s = c.alloc[spc]
shouldhelpgc = false
freeIndex := s.nextFreeIndex()
if freeIndex == s.nelems {
// The span is full.
if uintptr(s.allocCount) != s.nelems {
println("runtime: s.allocCount=", s.allocCount, "s.nelems=", s.nelems)
throw("s.allocCount != s.nelems && freeIndex == s.nelems")
}
c.refill(spc)
shouldhelpgc = true
s = c.alloc[spc]
freeIndex = s.nextFreeIndex()
}
if freeIndex >= s.nelems {
throw("freeIndex is not valid")
}
v = gclinkptr(freeIndex*s.elemsize + s.base())
s.allocCount++
if uintptr(s.allocCount) > s.nelems {
println("s.allocCount=", s.allocCount, "s.nelems=", s.nelems)
throw("s.allocCount > s.nelems")
}
return
}
// Allocate an object of size bytes.
// Small objects are allocated from the per-P cache's free lists.
// Large objects (> 32 kB) are allocated straight from the heap.
func mallocgc(size uintptr, typ *_type, needzero bool) unsafe.Pointer {
if gcphase == _GCmarktermination {
throw("mallocgc called with gcphase == _GCmarktermination")
}
if size == 0 {
return unsafe.Pointer(&zerobase)
}
// It's possible for any malloc to trigger sweeping, which may in
// turn queue finalizers. Record this dynamic lock edge.
lockRankMayQueueFinalizer()
userSize := size
if asanenabled {
// Refer to the ASAN runtime library: its malloc() function allocates extra
// memory, the redzone, around the user-requested memory region, and the
// redzones are marked as unaddressable. We perform the same operations in
// Go to detect overflows and underflows.
size += computeRZlog(size)
}
if debug.malloc {
if debug.sbrk != 0 {
align := uintptr(16)
if typ != nil {
// TODO(austin): This should be just
// align = uintptr(typ.align)
// but that's only 4 on 32-bit platforms,
// even if there's a uint64 field in typ (see #599).
// This causes 64-bit atomic accesses to panic.
// Hence, we use stricter alignment that matches
// the normal allocator better.
if size&7 == 0 {
align = 8
} else if size&3 == 0 {
align = 4
} else if size&1 == 0 {
align = 2
} else {
align = 1
}
}
return persistentalloc(size, align, &memstats.other_sys)
}
if inittrace.active && inittrace.id == getg().goid {
// Init functions are executed sequentially in a single goroutine.
inittrace.allocs += 1
}
}
// assistG is the G to charge for this allocation, or nil if
// GC is not currently active.
assistG := deductAssistCredit(size)
// Set mp.mallocing to keep from being preempted by GC.
mp := acquirem()
if mp.mallocing != 0 {
throw("malloc deadlock")
}
if mp.gsignal == getg() {
throw("malloc during signal")
}
mp.mallocing = 1
shouldhelpgc := false
dataSize := userSize
c := getMCache(mp)
if c == nil {
throw("mallocgc called without a P or outside bootstrapping")
}
var span *mspan
var x unsafe.Pointer
noscan := typ == nil || typ.ptrdata == 0
// In some cases block zeroing can profitably (for latency reduction purposes)
// be delayed till preemption is possible; delayedZeroing tracks that state.
delayedZeroing := false
if size <= maxSmallSize {
if noscan && size < maxTinySize {
// Tiny allocator.
//
// The tiny allocator combines several tiny allocation requests
// into a single memory block. The resulting memory block
// is freed when all subobjects are unreachable. The subobjects
// must be noscan (have no pointers); this ensures that
// the amount of potentially wasted memory is bounded.
//
// The size of the memory block used for combining (maxTinySize) is tunable.
// The current setting is 16 bytes, which gives a 2x worst-case memory
// wastage (when all but one subobject is unreachable).
// 8 bytes would result in no wastage at all, but provides fewer
// opportunities for combining.
// 32 bytes provides more opportunities for combining,
// but can lead to 4x worst-case wastage.
// The best-case saving is 8x regardless of block size.
//
// Objects obtained from the tiny allocator must not be freed explicitly.
// So when an object will be freed explicitly, we ensure that
// its size is >= maxTinySize.
//
// SetFinalizer has a special case for objects potentially coming
// from the tiny allocator; in that case it allows setting finalizers
// for an inner byte of a memory block.
//
// The main targets of the tiny allocator are small strings and
// standalone escaping variables. On a json benchmark
// the allocator reduces the number of allocations by ~12% and
// reduces heap size by ~20%.
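//
// A hedged sketch of the packing this buys (illustrative numbers): two
// consecutive noscan allocations of 8 and then 4 bytes can share one
// 16-byte tiny block:
//
//	off = 0 -> 8-byte object at c.tiny+0, tinyoffset becomes 8
//	off = 8 -> 4-byte object at c.tiny+8, tinyoffset becomes 12
//
// The second allocation returns below without touching the span at all.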
off := c.tinyoffset
// Align tiny pointer for required (conservative) alignment.
if size&7 == 0 {
off = alignUp(off, 8)
} else if goarch.PtrSize == 4 && size == 12 {
// Conservatively align 12-byte objects to 8 bytes on 32-bit
// systems so that objects whose first field is a 64-bit
// value is aligned to 8 bytes and does not cause a fault on
// atomic access. See issue 37262.
// TODO(mknyszek): Remove this workaround if/when issue 36606
// is resolved.
off = alignUp(off, 8)
} else if size&3 == 0 {
off = alignUp(off, 4)
} else if size&1 == 0 {
off = alignUp(off, 2)
}
if off+size <= maxTinySize && c.tiny != 0 {
// The object fits into existing tiny block.
x = unsafe.Pointer(c.tiny + off)
c.tinyoffset = off + size
c.tinyAllocs++
mp.mallocing = 0
releasem(mp)
return x
}
// Allocate a new maxTinySize block.
span = c.alloc[tinySpanClass]
v := nextFreeFast(span)
if v == 0 {
v, span, shouldhelpgc = c.nextFree(tinySpanClass)
}
x = unsafe.Pointer(v)
(*[2]uint64)(x)[0] = 0
(*[2]uint64)(x)[1] = 0
// See if we need to replace the existing tiny block with the new one
// based on amount of remaining free space.
if !raceenabled && (size < c.tinyoffset || c.tiny == 0) {
// Note: disabled when race detector is on, see comment near end of this function.
c.tiny = uintptr(x)
c.tinyoffset = size
}
size = maxTinySize
} else {
var sizeclass uint8
if size <= smallSizeMax-8 {
sizeclass = size_to_class8[divRoundUp(size, smallSizeDiv)]
} else {
sizeclass = size_to_class128[divRoundUp(size-smallSizeMax, largeSizeDiv)]
}
size = uintptr(class_to_size[sizeclass])
spc := makeSpanClass(sizeclass, noscan)
span = c.alloc[spc]
v := nextFreeFast(span)
if v == 0 {
v, span, shouldhelpgc = c.nextFree(spc)
}
x = unsafe.Pointer(v)
if needzero && span.needzero != 0 {
memclrNoHeapPointers(x, size)
}
}
} else {
shouldhelpgc = true
// For large allocations, keep track of zeroed state so that
// bulk zeroing can happen later in a preemptible context.
span = c.allocLarge(size, noscan)
span.freeindex = 1
span.allocCount = 1
size = span.elemsize
x = unsafe.Pointer(span.base())
if needzero && span.needzero != 0 {
if noscan {
delayedZeroing = true
} else {
memclrNoHeapPointers(x, size)
// We've in theory cleared almost the whole span here,
// and could take the extra step of actually clearing
// the whole thing. However, don't. Any GC bits for the
// uncleared parts will be zero, and it's just going to
// be needzero = 1 once freed anyway.
}
}
}
if !noscan {
var scanSize uintptr
heapBitsSetType(uintptr(x), size, dataSize, typ)
if dataSize > typ.size {
// Array allocation. If there are any
// pointers, GC has to scan to the last
// element.
if typ.ptrdata != 0 {
scanSize = dataSize - typ.size + typ.ptrdata
}
} else {
scanSize = typ.ptrdata
}
c.scanAlloc += scanSize
}
// Ensure that the stores above that initialize x to
// type-safe memory and set the heap bits occur before
// the caller can make x observable to the garbage
// collector. Otherwise, on weakly ordered machines,
// the garbage collector could follow a pointer to x,
// but see uninitialized memory or stale heap bits.
publicationBarrier()
// As x and the heap bits are initialized, update
// freeIndexForScan now so x is seen by the GC
// (including conservative scan) as an allocated object.
// While this pointer can't escape into user code as a
// _live_ pointer until we return, conservative scanning
// may find a dead pointer that happens to point into this
// object. Delaying this update until now ensures that
// conservative scanning considers this pointer dead until
// this point.
span.freeIndexForScan = span.freeindex
// Allocate black during GC.
// All slots hold nil so no scanning is needed.
// This may be racing with GC so do it atomically if there can be
// a race marking the bit.
if gcphase != _GCoff {
gcmarknewobject(span, uintptr(x), size)
}
if raceenabled {
racemalloc(x, size)
}
if msanenabled {
msanmalloc(x, size)
}
if asanenabled {
// We should only read/write the memory with the size asked by the user.
// The rest of the allocated memory should be poisoned, so that we can report
// errors when accessing poisoned memory.
// The allocated memory is larger than the requested userSize; it also
// includes the redzone and some other padding bytes.
rzBeg := unsafe.Add(x, userSize)
asanpoison(rzBeg, size-userSize)
asanunpoison(x, userSize)
}
if rate := MemProfileRate; rate > 0 {
// Note cache c only valid while m acquired; see #47302
if rate != 1 && size < c.nextSample {
c.nextSample -= size
} else {
profilealloc(mp, x, size)
}
}
mp.mallocing = 0
releasem(mp)
// Pointer-free data can be zeroed late in a context where preemption can occur.
// x will keep the memory alive.
if delayedZeroing {
if !noscan {
throw("delayed zeroing on data that may contain pointers")
}
memclrNoHeapPointersChunked(size, x) // This is a possible preemption point: see #47302
}
if debug.malloc {
if debug.allocfreetrace != 0 {
tracealloc(x, size, typ)
}
if inittrace.active && inittrace.id == getg().goid {
// Init functions are executed sequentially in a single goroutine.
inittrace.bytes += uint64(size)
}
}
if assistG != nil {
// Account for internal fragmentation in the assist
// debt now that we know it.
assistG.gcAssistBytes -= int64(size - dataSize)
}
if shouldhelpgc {
if t := (gcTrigger{kind: gcTriggerHeap}); t.test() {
gcStart(t)
}
}
if raceenabled && noscan && dataSize < maxTinySize {
// Pad tinysize allocations so they are aligned with the end
// of the tinyalloc region. This ensures that any arithmetic
// that goes off the top end of the object will be detectable
// by checkptr (issue 38872).
// Note that we disable tinyalloc when raceenabled for this to work.
// TODO: This padding is only performed when the race detector
// is enabled. It would be nice to enable it if any package
// was compiled with checkptr, but there's no easy way to
// detect that (especially at compile time).
// TODO: enable this padding for all allocations, not just
// tinyalloc ones. It's tricky because of pointer maps.
// Maybe just all noscan objects?
x = add(x, size-dataSize)
}
return x
}
// deductAssistCredit reduces the current G's assist credit
// by size bytes, and assists the GC if necessary.
//
// Caller must be preemptible.
//
// Returns the G for which the assist credit was accounted.
func deductAssistCredit(size uintptr) *g {
var assistG *g
if gcBlackenEnabled != 0 {
// Charge the current user G for this allocation.
assistG = getg()
if assistG.m.curg != nil {
assistG = assistG.m.curg
}
// Charge the allocation against the G. We'll account
// for internal fragmentation at the end of mallocgc.
assistG.gcAssistBytes -= int64(size)
if assistG.gcAssistBytes < 0 {
// This G is in debt. Assist the GC to correct
// this before allocating. This must happen
// before disabling preemption.
gcAssistAlloc(assistG)
}
}
return assistG
}
// memclrNoHeapPointersChunked repeatedly calls memclrNoHeapPointers
// on chunks of the buffer to be zeroed, with opportunities for preemption
// along the way. memclrNoHeapPointers contains no safepoints and also
// cannot be preemptively scheduled, so this provides a still-efficient
// block copy that can also be preempted on a reasonable granularity.
//
// Use this with care; if the data being cleared is tagged to contain
// pointers, this allows the GC to run before it is all cleared.
func memclrNoHeapPointersChunked(size uintptr, x unsafe.Pointer) {
v := uintptr(x)
// got this from benchmarking. 128k is too small, 512k is too large.
const chunkBytes = 256 * 1024
vsize := v + size
for voff := v; voff < vsize; voff = voff + chunkBytes {
if getg().preempt {
// may hold locks, e.g., profiling
goschedguarded()
}
// clear min(avail, chunkBytes) bytes
n := vsize - voff
if n > chunkBytes {
n = chunkBytes
}
memclrNoHeapPointers(unsafe.Pointer(voff), n)
}
}
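// Illustrative arithmetic (not part of the runtime): for a 1 MiB buffer the
// loop runs four times, voff = v, v+256k, v+512k, v+768k, clearing
//
//	n = min(vsize-voff, chunkBytes) // 256 KiB on each pass here
//
// with a preemption check before each chunk, so a pending preemption waits
// for at most one 256 KiB memclrNoHeapPointers call.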
// newobject is the implementation of the new builtin.
// The compiler (both frontend and SSA backend) knows the signature
// of this function.
func newobject(typ *_type) unsafe.Pointer {
return mallocgc(typ.size, typ, true)
}
//go:linkname reflect_unsafe_New reflect.unsafe_New
func reflect_unsafe_New(typ *_type) unsafe.Pointer {
return mallocgc(typ.size, typ, true)
}
//go:linkname reflectlite_unsafe_New internal/reflectlite.unsafe_New
func reflectlite_unsafe_New(typ *_type) unsafe.Pointer {
return mallocgc(typ.size, typ, true)
}
// newarray allocates an array of n elements of type typ.
func newarray(typ *_type, n int) unsafe.Pointer {
if n == 1 {
return mallocgc(typ.size, typ, true)
}
mem, overflow := math.MulUintptr(typ.size, uintptr(n))
if overflow || mem > maxAlloc || n < 0 {
panic(plainError("runtime: allocation size out of range"))
}
return mallocgc(mem, typ, true)
}
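// The guard above is the standard overflow-checked multiply. A hedged sketch
// of the same pattern in ordinary Go (hypothetical code; elemSize, n, and
// limit are stand-ins, using math/bits instead of the runtime's internal
// math package):
//
//	hi, lo := bits.Mul64(uint64(elemSize), uint64(n))
//	if hi != 0 || lo > limit || n < 0 {
//		// reject: n*elemSize would overflow or exceed the limit
//	}
//
// math.MulUintptr reports the overflow bit directly for uintptr operands.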
//go:linkname reflect_unsafe_NewArray reflect.unsafe_NewArray
func reflect_unsafe_NewArray(typ *_type, n int) unsafe.Pointer {
return newarray(typ, n)
}
func profilealloc(mp *m, x unsafe.Pointer, size uintptr) {
c := getMCache(mp)
if c == nil {
throw("profilealloc called without a P or outside bootstrapping")
}
c.nextSample = nextSample()
mProf_Malloc(x, size)
}
// nextSample returns the next sampling point for heap profiling. The goal is
// to sample allocations on average every MemProfileRate bytes, but with a
// completely random distribution over the allocation timeline; this
// corresponds to a Poisson process with parameter MemProfileRate. In Poisson
// processes, the distance between two samples follows an exponential
// distribution, so the best return value is a random number drawn from an
// exponential distribution whose mean is MemProfileRate.
func nextSample() uintptr {
if MemProfileRate == 1 {
// Callers assign our return value to
// mcache.next_sample, but next_sample is not used
// when the rate is 1. So avoid the math below and
// just return something.
return 0
}
if GOOS == "plan9" {
// Plan 9 doesn't support floating point in note handler.
if gp := getg(); gp == gp.m.gsignal {
return nextSampleNoFP()
}
}
return uintptr(fastexprand(MemProfileRate))
}
// fastexprand returns a random number from an exponential distribution with
// the specified mean.
func fastexprand(mean int) int32 {
// Avoid overflow. Maximum possible step is
// -ln(1/(1<<randomBitCount)) * mean, approximately 18 * mean.
switch {
case mean > 0x7000000:
mean = 0x7000000
case mean == 0:
return 0
}
// Take a random sample of the exponential distribution exp(-mean*x).
// The probability distribution function is mean*exp(-mean*x), so the CDF is
// p = 1 - exp(-mean*x), so
// q = 1 - p == exp(-mean*x)
// log_e(q) = -mean*x
// -log_e(q)/mean = x
// x = -log_e(q) * mean
// x = log_2(q) * (-log_e(2)) * mean ; Using log_2 for efficiency
const randomBitCount = 26
q := fastrandn(1<<randomBitCount) + 1
qlog := fastlog2(float64(q)) - randomBitCount
if qlog > 0 {
qlog = 0
}
const minusLog2 = -0.6931471805599453 // -ln(2)
return int32(qlog*(minusLog2*float64(mean))) + 1
}
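// Worked example (illustrative): with mean = 512*1024 (the default
// MemProfileRate) and a draw of q = 1<<25, i.e. the median q/2^26 = 0.5:
//
//	qlog = log2(2^25) - 26 = -1
//	x    = -1 * (-ln 2) * mean = 0.693 * 524288, about 363k
//
// Half of all draws land at or below this q, so half of all steps are at
// least ~0.69*mean, which is exactly the median of an exponential
// distribution with mean MemProfileRate.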
// nextSampleNoFP is similar to nextSample, but uses older,
// simpler code to avoid floating point.
func nextSampleNoFP() uintptr {
// Set first allocation sample size.
rate := MemProfileRate
if rate > 0x3fffffff { // make 2*rate not overflow
rate = 0x3fffffff
}
if rate != 0 {
return uintptr(fastrandn(uint32(2 * rate)))
}
return 0
}
type persistentAlloc struct {
base *notInHeap
off uintptr
}
var globalAlloc struct {
mutex
persistentAlloc
}
// persistentChunkSize is the number of bytes we allocate when we grow
// a persistentAlloc.
const persistentChunkSize = 256 << 10
// persistentChunks is a list of all the persistent chunks we have
// allocated. The list is maintained through the first word in the
// persistent chunk. This is updated atomically.
var persistentChunks *notInHeap
// Wrapper around sysAlloc that can allocate small chunks.
// There is no associated free operation.
// Intended for things like function/type/debug-related persistent data.
// If align is 0, uses default align (currently 8).
// The returned memory will be zeroed.
// sysStat must be non-nil.
//
// Consider marking persistentalloc'd types not in heap by embedding
// runtime/internal/sys.NotInHeap.
func persistentalloc(size, align uintptr, sysStat *sysMemStat) unsafe.Pointer {
var p *notInHeap
systemstack(func() {
p = persistentalloc1(size, align, sysStat)
})
return unsafe.Pointer(p)
}
// Must run on system stack because stack growth can (re)invoke it.
// See issue 9174.
//
//go:systemstack
func persistentalloc1(size, align uintptr, sysStat *sysMemStat) *notInHeap {
const (
maxBlock = 64 << 10 // VM reservation granularity is 64K on windows
)
if size == 0 {
throw("persistentalloc: size == 0")
}
if align != 0 {
if align&(align-1) != 0 {
throw("persistentalloc: align is not a power of 2")
}
if align > _PageSize {
throw("persistentalloc: align is too large")
}
} else {
align = 8
}
if size >= maxBlock {
return (*notInHeap)(sysAlloc(size, sysStat))
}
mp := acquirem()
var persistent *persistentAlloc
if mp != nil && mp.p != 0 {
persistent = &mp.p.ptr().palloc
} else {
lock(&globalAlloc.mutex)
persistent = &globalAlloc.persistentAlloc
}
persistent.off = alignUp(persistent.off, align)
if persistent.off+size > persistentChunkSize || persistent.base == nil {
persistent.base = (*notInHeap)(sysAlloc(persistentChunkSize, &memstats.other_sys))
if persistent.base == nil {
if persistent == &globalAlloc.persistentAlloc {
unlock(&globalAlloc.mutex)
}
throw("runtime: cannot allocate memory")
}
// Add the new chunk to the persistentChunks list.
for {
chunks := uintptr(unsafe.Pointer(persistentChunks))
*(*uintptr)(unsafe.Pointer(persistent.base)) = chunks
if atomic.Casuintptr((*uintptr)(unsafe.Pointer(&persistentChunks)), chunks, uintptr(unsafe.Pointer(persistent.base))) {
break
}
}
persistent.off = alignUp(goarch.PtrSize, align)
}
p := persistent.base.add(persistent.off)
persistent.off += size
releasem(mp)
if persistent == &globalAlloc.persistentAlloc {
unlock(&globalAlloc.mutex)
}
if sysStat != &memstats.other_sys {
sysStat.add(int64(size))
memstats.other_sys.add(-int64(size))
}
return p
}
// inPersistentAlloc reports whether p points to memory allocated by
// persistentalloc. This must be nosplit because it is called by the
// cgo checker code, which is called by the write barrier code.
//
//go:nosplit
func inPersistentAlloc(p uintptr) bool {
chunk := atomic.Loaduintptr((*uintptr)(unsafe.Pointer(&persistentChunks)))
for chunk != 0 {
if p >= chunk && p < chunk+persistentChunkSize {
return true
}
chunk = *(*uintptr)(unsafe.Pointer(chunk))
}
return false
}
// linearAlloc is a simple linear allocator that pre-reserves a region
// of memory and then optionally maps that region into the Ready state
// as needed.
//
// The caller is responsible for locking.
type linearAlloc struct {
next uintptr // next free byte
mapped uintptr // one byte past end of mapped space
end uintptr // end of reserved space
mapMemory bool // transition memory from Reserved to Ready if true
}
func (l *linearAlloc) init(base, size uintptr, mapMemory bool) {
if base+size < base {
// Chop off the last byte. The runtime isn't prepared
// to deal with situations where the bounds could overflow.
// Leave that memory reserved, though, so we don't map it
// later.
size -= 1
}
l.next, l.mapped = base, base
l.end = base + size
l.mapMemory = mapMemory
}
func (l *linearAlloc) alloc(size, align uintptr, sysStat *sysMemStat) unsafe.Pointer {
p := alignUp(l.next, align)
if p+size > l.end {
return nil
}
l.next = p + size
if pEnd := alignUp(l.next-1, physPageSize); pEnd > l.mapped {
if l.mapMemory {
// Transition from Reserved to Prepared to Ready.
n := pEnd - l.mapped
sysMap(unsafe.Pointer(l.mapped), n, sysStat)
sysUsed(unsafe.Pointer(l.mapped), n, n)
}
l.mapped = pEnd
}
return unsafe.Pointer(p)
}
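// A hedged usage sketch (illustrative; this mirrors how h.heapArenaAlloc is
// consumed in sysAlloc above): reserve a region once, then carve aligned
// pieces out of it, mapping pages only as the cursor crosses them:
//
//	var la linearAlloc
//	la.init(base, 64<<20, true) // 64 MiB reserved, mapped on demand
//	p := la.alloc(unsafe.Sizeof(heapArena{}), goarch.PtrSize, &memstats.gcMiscSys)
//	// p == nil once the region is exhausted; callers fall back to
//	// persistentalloc in that case.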
// notInHeap is off-heap memory allocated by a lower-level allocator
// like sysAlloc or persistentAlloc.
//
// In general, it's better to use real types which embed
// runtime/internal/sys.NotInHeap, but this serves as a generic type
// for situations where that isn't possible (like in the allocators).
//
// TODO: Use this as the return type of sysAlloc, persistentAlloc, etc?
type notInHeap struct{ _ sys.NotInHeap }
func (p *notInHeap) add(bytes uintptr) *notInHeap {
return (*notInHeap)(unsafe.Pointer(uintptr(unsafe.Pointer(p)) + bytes))
}
// computeRZlog computes the size of the redzone.
// Refer to the implementation of compiler-rt.
func computeRZlog(userSize uintptr) uintptr {
switch {
case userSize <= (64 - 16):
return 16 << 0
case userSize <= (128 - 32):
return 16 << 1
case userSize <= (512 - 64):
return 16 << 2
case userSize <= (4096 - 128):
return 16 << 3
case userSize <= (1<<14)-256:
return 16 << 4
case userSize <= (1<<15)-512:
return 16 << 5
case userSize <= (1<<16)-1024:
return 16 << 6
default:
return 16 << 7
}
}
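// Worked examples (illustrative): a 40-byte user allocation satisfies
// userSize <= 64-16 and gets the minimum 16-byte redzone; a 100-byte
// allocation fails the 128-32 case (100 > 96) and falls into the 512-64
// case for a 64-byte redzone. mallocgc adds the result to size before
// allocating (see the asanenabled block in mallocgc above).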
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
// This file contains the implementation of Go's map type.
//
// A map is just a hash table. The data is arranged
// into an array of buckets. Each bucket contains up to
// 8 key/elem pairs. The low-order bits of the hash are
// used to select a bucket. Each bucket contains a few
// high-order bits of each hash to distinguish the entries
// within a single bucket.
//
// If more than 8 keys hash to a bucket, we chain on
// extra buckets.
//
// When the hashtable grows, we allocate a new array
// of buckets twice as big. Buckets are incrementally
// copied from the old bucket array to the new bucket array.
//
// Map iterators walk through the array of buckets and
// return the keys in walk order (bucket #, then overflow
// chain order, then bucket index). To maintain iteration
// semantics, we never move keys within their bucket (if
// we did, keys might be returned 0 or 2 times). When
// growing the table, iterators remain iterating through the
// old table and must check the new table if the bucket
// they are iterating through has been moved ("evacuated")
// to the new table.
// Picking loadFactor: too large and we have lots of overflow
// buckets, too small and we waste a lot of space. I wrote
// a simple program to check some stats for different loads:
// (64-bit, 8 byte keys and elems)
// loadFactor  %overflow  bytes/entry  hitprobe  missprobe
//       4.00       2.13        20.77      3.00       4.00
//       4.50       4.05        17.30      3.25       4.50
//       5.00       6.85        14.77      3.50       5.00
//       5.50      10.55        12.94      3.75       5.50
//       6.00      15.27        11.67      4.00       6.00
//       6.50      20.90        10.79      4.25       6.50
//       7.00      27.14        10.15      4.50       7.00
//       7.50      34.03         9.73      4.75       7.50
//       8.00      41.10         9.40      5.00       8.00
//
// %overflow = percentage of buckets which have an overflow bucket
// bytes/entry = overhead bytes used per key/elem pair
// hitprobe = # of entries to check when looking up a present key
// missprobe = # of entries to check when looking up an absent key
//
// Keep in mind this data is for maximally loaded tables, i.e. just
// before the table grows. Typical tables will be somewhat less loaded.
import (
"internal/abi"
"internal/goarch"
"runtime/internal/atomic"
"runtime/internal/math"
"unsafe"
)
const (
// Maximum number of key/elem pairs a bucket can hold.
bucketCntBits = abi.MapBucketCountBits
bucketCnt = abi.MapBucketCount
// Maximum average load of a bucket that triggers growth is bucketCnt*13/16 (about 80% full)
// Because of minimum alignment rules, bucketCnt is known to be at least 8.
// Represent as loadFactorNum/loadFactorDen, to allow integer math.
loadFactorDen = 2
loadFactorNum = loadFactorDen * bucketCnt * 13 / 16 // multiply before dividing so the constant division is exact (13, i.e. 6.5 per bucket, for bucketCnt == 8)
// Maximum key or elem size to keep inline (instead of mallocing per element).
// Must fit in a uint8.
// Fast versions cannot handle big elems - the cutoff size for
// fast versions in cmd/compile/internal/gc/walk.go must be at most this elem.
maxKeySize = abi.MapMaxKeyBytes
maxElemSize = abi.MapMaxElemBytes
// data offset should be the size of the bmap struct, but needs to be
// aligned correctly. For amd64p32 this means 64-bit alignment
// even though pointers are 32 bit.
dataOffset = unsafe.Offsetof(struct {
b bmap
v int64
}{}.v)
// Possible tophash values. We reserve a few possibilities for special marks.
// Each bucket (including its overflow buckets, if any) will have either all or none of its
// entries in the evacuated* states (except during the evacuate() method, which only happens
// during map writes and thus no one else can observe the map during that time).
emptyRest = 0 // this cell is empty, and there are no more non-empty cells at higher indexes or overflows.
emptyOne = 1 // this cell is empty
evacuatedX = 2 // key/elem is valid. Entry has been evacuated to first half of larger table.
evacuatedY = 3 // same as above, but evacuated to second half of larger table.
evacuatedEmpty = 4 // cell is empty, bucket is evacuated.
minTopHash = 5 // minimum tophash for a normal filled cell.
// flags
iterator = 1 // there may be an iterator using buckets
oldIterator = 2 // there may be an iterator using oldbuckets
hashWriting = 4 // a goroutine is writing to the map
sameSizeGrow = 8 // the current map growth is to a new map of the same size
// sentinel bucket ID for iterator checks
noCheck = 1<<(8*goarch.PtrSize) - 1
)
// isEmpty reports whether the given tophash array entry represents an empty bucket entry.
func isEmpty(x uint8) bool {
return x <= emptyOne
}
// A header for a Go map.
type hmap struct {
// Note: the format of the hmap is also encoded in cmd/compile/internal/reflectdata/reflect.go.
// Make sure this stays in sync with the compiler's definition.
count int // # live cells == size of map. Must be first (used by len() builtin)
flags uint8
B uint8 // log_2 of # of buckets (can hold up to loadFactor * 2^B items)
noverflow uint16 // approximate number of overflow buckets; see incrnoverflow for details
hash0 uint32 // hash seed
buckets unsafe.Pointer // array of 2^B Buckets. may be nil if count==0.
oldbuckets unsafe.Pointer // previous bucket array of half the size, non-nil only when growing
nevacuate uintptr // progress counter for evacuation (buckets less than this have been evacuated)
extra *mapextra // optional fields
}
// mapextra holds fields that are not present on all maps.
type mapextra struct {
// If both key and elem do not contain pointers and are inline, then we mark bucket
// type as containing no pointers. This avoids scanning such maps.
// However, bmap.overflow is a pointer. In order to keep overflow buckets
// alive, we store pointers to all overflow buckets in hmap.extra.overflow and hmap.extra.oldoverflow.
// overflow and oldoverflow are only used if key and elem do not contain pointers.
// overflow contains overflow buckets for hmap.buckets.
// oldoverflow contains overflow buckets for hmap.oldbuckets.
// The indirection allows storing a pointer to the slice in hiter.
overflow *[]*bmap
oldoverflow *[]*bmap
// nextOverflow holds a pointer to a free overflow bucket.
nextOverflow *bmap
}
// A bucket for a Go map.
type bmap struct {
// tophash generally contains the top byte of the hash value
// for each key in this bucket. If tophash[0] < minTopHash,
// tophash[0] is a bucket evacuation state instead.
tophash [bucketCnt]uint8
// Followed by bucketCnt keys and then bucketCnt elems.
// NOTE: packing all the keys together and then all the elems together makes the
// code a bit more complicated than alternating key/elem/key/elem/... but it allows
// us to eliminate padding which would be needed for, e.g., map[int64]int8.
// Followed by an overflow pointer.
}
// A hash iteration structure.
// If you modify hiter, also change cmd/compile/internal/reflectdata/reflect.go
// and reflect/value.go to match the layout of this structure.
type hiter struct {
key unsafe.Pointer // Must be in first position. Write nil to indicate iteration end (see cmd/compile/internal/walk/range.go).
elem unsafe.Pointer // Must be in second position (see cmd/compile/internal/walk/range.go).
t *maptype
h *hmap
buckets unsafe.Pointer // bucket ptr at hash_iter initialization time
bptr *bmap // current bucket
overflow *[]*bmap // keeps overflow buckets of hmap.buckets alive
oldoverflow *[]*bmap // keeps overflow buckets of hmap.oldbuckets alive
startBucket uintptr // bucket iteration started at
offset uint8 // intra-bucket offset to start from during iteration (should be big enough to hold bucketCnt-1)
wrapped bool // already wrapped around from end of bucket array to beginning
B uint8
i uint8
bucket uintptr
checkBucket uintptr
}
// bucketShift returns 1<<b, optimized for code generation.
func bucketShift(b uint8) uintptr {
// Masking the shift amount allows overflow checks to be elided.
return uintptr(1) << (b & (goarch.PtrSize*8 - 1))
}
// bucketMask returns 1<<b - 1, optimized for code generation.
func bucketMask(b uint8) uintptr {
return bucketShift(b) - 1
}
// tophash calculates the tophash value for hash.
func tophash(hash uintptr) uint8 {
top := uint8(hash >> (goarch.PtrSize*8 - 8))
if top < minTopHash {
top += minTopHash
}
return top
}
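// Worked example (illustrative, 64-bit): tophash keeps only the top byte of
// the hash. A hash whose top byte is 0x03 would collide with the sentinel
// values 0..4 above, so it is bumped to 0x03+minTopHash = 0x08. Trading a
// few tophash values for in-band sentinels keeps bmap.tophash a single byte
// per slot.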
func evacuated(b *bmap) bool {
h := b.tophash[0]
return h > emptyOne && h < minTopHash
}
func (b *bmap) overflow(t *maptype) *bmap {
return *(**bmap)(add(unsafe.Pointer(b), uintptr(t.bucketsize)-goarch.PtrSize))
}
func (b *bmap) setoverflow(t *maptype, ovf *bmap) {
*(**bmap)(add(unsafe.Pointer(b), uintptr(t.bucketsize)-goarch.PtrSize)) = ovf
}
func (b *bmap) keys() unsafe.Pointer {
return add(unsafe.Pointer(b), dataOffset)
}
// incrnoverflow increments h.noverflow.
// noverflow counts the number of overflow buckets.
// This is used to trigger same-size map growth.
// See also tooManyOverflowBuckets.
// To keep hmap small, noverflow is a uint16.
// When there are few buckets, noverflow is an exact count.
// When there are many buckets, noverflow is an approximate count.
func (h *hmap) incrnoverflow() {
// We trigger same-size map growth if there are
// as many overflow buckets as buckets.
// We need to be able to count to 1<<h.B.
if h.B < 16 {
h.noverflow++
return
}
// Increment with probability 1/(1<<(h.B-15)).
// When we reach 1<<15 - 1, we will have approximately
// as many overflow buckets as buckets.
mask := uint32(1)<<(h.B-15) - 1
// Example: if h.B == 18, then mask == 7,
// and fastrand & 7 == 0 with probability 1/8.
if fastrand()&mask == 0 {
h.noverflow++
}
}
func (h *hmap) newoverflow(t *maptype, b *bmap) *bmap {
var ovf *bmap
if h.extra != nil && h.extra.nextOverflow != nil {
// We have preallocated overflow buckets available.
// See makeBucketArray for more details.
ovf = h.extra.nextOverflow
if ovf.overflow(t) == nil {
// We're not at the end of the preallocated overflow buckets. Bump the pointer.
h.extra.nextOverflow = (*bmap)(add(unsafe.Pointer(ovf), uintptr(t.bucketsize)))
} else {
// This is the last preallocated overflow bucket.
// Reset the overflow pointer on this bucket,
// which was set to a non-nil sentinel value.
ovf.setoverflow(t, nil)
h.extra.nextOverflow = nil
}
} else {
ovf = (*bmap)(newobject(t.bucket))
}
h.incrnoverflow()
if t.bucket.ptrdata == 0 {
h.createOverflow()
*h.extra.overflow = append(*h.extra.overflow, ovf)
}
b.setoverflow(t, ovf)
return ovf
}
func (h *hmap) createOverflow() {
if h.extra == nil {
h.extra = new(mapextra)
}
if h.extra.overflow == nil {
h.extra.overflow = new([]*bmap)
}
}
func makemap64(t *maptype, hint int64, h *hmap) *hmap {
if int64(int(hint)) != hint {
hint = 0
}
return makemap(t, int(hint), h)
}
// makemap_small implements Go map creation for make(map[k]v) and
// make(map[k]v, hint) when hint is known to be at most bucketCnt
// at compile time and the map needs to be allocated on the heap.
func makemap_small() *hmap {
h := new(hmap)
h.hash0 = fastrand()
return h
}
// makemap implements Go map creation for make(map[k]v, hint).
// If the compiler has determined that the map or the first bucket
// can be created on the stack, h and/or bucket may be non-nil.
// If h != nil, the map can be created directly in h.
// If h.buckets != nil, bucket pointed to can be used as the first bucket.
func makemap(t *maptype, hint int, h *hmap) *hmap {
mem, overflow := math.MulUintptr(uintptr(hint), t.bucket.size)
if overflow || mem > maxAlloc {
hint = 0
}
// initialize Hmap
if h == nil {
h = new(hmap)
}
h.hash0 = fastrand()
// Find the size parameter B which will hold the requested # of elements.
// For hint < 0 overLoadFactor returns false since hint < bucketCnt.
B := uint8(0)
for overLoadFactor(hint, B) {
B++
}
h.B = B
// allocate initial hash table
// if B == 0, the buckets field is allocated lazily later (in mapassign)
// If hint is large zeroing this memory could take a while.
if h.B != 0 {
var nextOverflow *bmap
h.buckets, nextOverflow = makeBucketArray(t, h.B, nil)
if nextOverflow != nil {
h.extra = new(mapextra)
h.extra.nextOverflow = nextOverflow
}
}
return h
}
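// Worked example (illustrative): make(map[int]int, 100) arrives here with
// hint = 100. With bucketCnt = 8 and load factor 6.5, the loop settles on
// the smallest B whose 2^B buckets carry the hint:
//
//	B = 3: 100 > 6.5*8  = 52  -> keep growing
//	B = 4: 100 > 6.5*16 = 104 is false -> stop
//
// so the map starts with 16 buckets and room for ~104 entries before the
// first grow.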
// makeBucketArray initializes a backing array for map buckets.
// 1<<b is the minimum number of buckets to allocate.
// dirtyalloc should either be nil or a bucket array previously
// allocated by makeBucketArray with the same t and b parameters.
// If dirtyalloc is nil a new backing array will be allocated;
// otherwise dirtyalloc will be cleared and reused as the backing array.
func makeBucketArray(t *maptype, b uint8, dirtyalloc unsafe.Pointer) (buckets unsafe.Pointer, nextOverflow *bmap) {
base := bucketShift(b)
nbuckets := base
// For small b, overflow buckets are unlikely.
// Avoid the overhead of the calculation.
if b >= 4 {
// Add on the estimated number of overflow buckets
// required to insert the median number of elements
// used with this value of b.
nbuckets += bucketShift(b - 4)
sz := t.bucket.size * nbuckets
up := roundupsize(sz)
if up != sz {
nbuckets = up / t.bucket.size
}
}
if dirtyalloc == nil {
buckets = newarray(t.bucket, int(nbuckets))
} else {
// dirtyalloc was previously generated by
// the above newarray(t.bucket, int(nbuckets))
// but may not be empty.
buckets = dirtyalloc
size := t.bucket.size * nbuckets
if t.bucket.ptrdata != 0 {
memclrHasPointers(buckets, size)
} else {
memclrNoHeapPointers(buckets, size)
}
}
if base != nbuckets {
// We preallocated some overflow buckets.
// To keep the overhead of tracking these overflow buckets to a minimum,
// we use the convention that if a preallocated overflow bucket's overflow
// pointer is nil, then there are more available by bumping the pointer.
// We need a safe non-nil pointer for the last overflow bucket; just use buckets.
nextOverflow = (*bmap)(add(buckets, base*uintptr(t.bucketsize)))
last := (*bmap)(add(buckets, (nbuckets-1)*uintptr(t.bucketsize)))
last.setoverflow(t, (*bmap)(buckets))
}
return buckets, nextOverflow
}
// mapaccess1 returns a pointer to h[key]. It never returns nil; instead
// it will return a reference to the zero object for the elem type if
// the key is not in the map.
// NOTE: The returned pointer may keep the whole map live, so don't
// hold onto it for very long.
func mapaccess1(t *maptype, h *hmap, key unsafe.Pointer) unsafe.Pointer {
if raceenabled && h != nil {
callerpc := getcallerpc()
pc := abi.FuncPCABIInternal(mapaccess1)
racereadpc(unsafe.Pointer(h), callerpc, pc)
raceReadObjectPC(t.key, key, callerpc, pc)
}
if msanenabled && h != nil {
msanread(key, t.key.size)
}
if asanenabled && h != nil {
asanread(key, t.key.size)
}
if h == nil || h.count == 0 {
if t.hashMightPanic() {
t.hasher(key, 0) // see issue 23734
}
return unsafe.Pointer(&zeroVal[0])
}
if h.flags&hashWriting != 0 {
fatal("concurrent map read and map write")
}
hash := t.hasher(key, uintptr(h.hash0))
m := bucketMask(h.B)
b := (*bmap)(add(h.buckets, (hash&m)*uintptr(t.bucketsize)))
if c := h.oldbuckets; c != nil {
if !h.sameSizeGrow() {
// There used to be half as many buckets; mask down one more power of two.
m >>= 1
}
oldb := (*bmap)(add(c, (hash&m)*uintptr(t.bucketsize)))
if !evacuated(oldb) {
b = oldb
}
}
top := tophash(hash)
bucketloop:
for ; b != nil; b = b.overflow(t) {
for i := uintptr(0); i < bucketCnt; i++ {
if b.tophash[i] != top {
if b.tophash[i] == emptyRest {
break bucketloop
}
continue
}
k := add(unsafe.Pointer(b), dataOffset+i*uintptr(t.keysize))
if t.indirectkey() {
k = *((*unsafe.Pointer)(k))
}
if t.key.equal(key, k) {
e := add(unsafe.Pointer(b), dataOffset+bucketCnt*uintptr(t.keysize)+i*uintptr(t.elemsize))
if t.indirectelem() {
e = *((*unsafe.Pointer)(e))
}
return e
}
}
}
return unsafe.Pointer(&zeroVal[0])
}
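// Hedged usage note (illustrative): the compiler lowers map reads onto these
// entry points (or type-specialized fast variants), roughly:
//
//	v := m[k]     // mapaccess1; the result pointer is dereferenced
//	v, ok := m[k] // mapaccess2
//
// This is why mapaccess1 returns a pointer to the zero object instead of
// nil: the generated code loads through the result unconditionally.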
func mapaccess2(t *maptype, h *hmap, key unsafe.Pointer) (unsafe.Pointer, bool) {
if raceenabled && h != nil {
callerpc := getcallerpc()
pc := abi.FuncPCABIInternal(mapaccess2)
racereadpc(unsafe.Pointer(h), callerpc, pc)
raceReadObjectPC(t.key, key, callerpc, pc)
}
if msanenabled && h != nil {
msanread(key, t.key.size)
}
if asanenabled && h != nil {
asanread(key, t.key.size)
}
if h == nil || h.count == 0 {
if t.hashMightPanic() {
t.hasher(key, 0) // see issue 23734
}
return unsafe.Pointer(&zeroVal[0]), false
}
if h.flags&hashWriting != 0 {
fatal("concurrent map read and map write")
}
hash := t.hasher(key, uintptr(h.hash0))
m := bucketMask(h.B)
b := (*bmap)(add(h.buckets, (hash&m)*uintptr(t.bucketsize)))
if c := h.oldbuckets; c != nil {
if !h.sameSizeGrow() {
// There used to be half as many buckets; mask down one more power of two.
m >>= 1
}
oldb := (*bmap)(add(c, (hash&m)*uintptr(t.bucketsize)))
if !evacuated(oldb) {
b = oldb
}
}
top := tophash(hash)
bucketloop:
for ; b != nil; b = b.overflow(t) {
for i := uintptr(0); i < bucketCnt; i++ {
if b.tophash[i] != top {
if b.tophash[i] == emptyRest {
break bucketloop
}
continue
}
k := add(unsafe.Pointer(b), dataOffset+i*uintptr(t.keysize))
if t.indirectkey() {
k = *((*unsafe.Pointer)(k))
}
if t.key.equal(key, k) {
e := add(unsafe.Pointer(b), dataOffset+bucketCnt*uintptr(t.keysize)+i*uintptr(t.elemsize))
if t.indirectelem() {
e = *((*unsafe.Pointer)(e))
}
return e, true
}
}
}
return unsafe.Pointer(&zeroVal[0]), false
}
// returns both key and elem. Used by map iterator.
func mapaccessK(t *maptype, h *hmap, key unsafe.Pointer) (unsafe.Pointer, unsafe.Pointer) {
if h == nil || h.count == 0 {
return nil, nil
}
hash := t.hasher(key, uintptr(h.hash0))
m := bucketMask(h.B)
b := (*bmap)(add(h.buckets, (hash&m)*uintptr(t.bucketsize)))
if c := h.oldbuckets; c != nil {
if !h.sameSizeGrow() {
// There used to be half as many buckets; mask down one more power of two.
m >>= 1
}
oldb := (*bmap)(add(c, (hash&m)*uintptr(t.bucketsize)))
if !evacuated(oldb) {
b = oldb
}
}
top := tophash(hash)
bucketloop:
for ; b != nil; b = b.overflow(t) {
for i := uintptr(0); i < bucketCnt; i++ {
if b.tophash[i] != top {
if b.tophash[i] == emptyRest {
break bucketloop
}
continue
}
k := add(unsafe.Pointer(b), dataOffset+i*uintptr(t.keysize))
if t.indirectkey() {
k = *((*unsafe.Pointer)(k))
}
if t.key.equal(key, k) {
e := add(unsafe.Pointer(b), dataOffset+bucketCnt*uintptr(t.keysize)+i*uintptr(t.elemsize))
if t.indirectelem() {
e = *((*unsafe.Pointer)(e))
}
return k, e
}
}
}
return nil, nil
}
func mapaccess1_fat(t *maptype, h *hmap, key, zero unsafe.Pointer) unsafe.Pointer {
e := mapaccess1(t, h, key)
if e == unsafe.Pointer(&zeroVal[0]) {
return zero
}
return e
}
func mapaccess2_fat(t *maptype, h *hmap, key, zero unsafe.Pointer) (unsafe.Pointer, bool) {
e := mapaccess1(t, h, key)
if e == unsafe.Pointer(&zeroVal[0]) {
return zero, false
}
return e, true
}
// Like mapaccess, but allocates a slot for the key if it is not present in the map.
func mapassign(t *maptype, h *hmap, key unsafe.Pointer) unsafe.Pointer {
if h == nil {
panic(plainError("assignment to entry in nil map"))
}
if raceenabled {
callerpc := getcallerpc()
pc := abi.FuncPCABIInternal(mapassign)
racewritepc(unsafe.Pointer(h), callerpc, pc)
raceReadObjectPC(t.key, key, callerpc, pc)
}
if msanenabled {
msanread(key, t.key.size)
}
if asanenabled {
asanread(key, t.key.size)
}
if h.flags&hashWriting != 0 {
fatal("concurrent map writes")
}
hash := t.hasher(key, uintptr(h.hash0))
// Set hashWriting after calling t.hasher, since t.hasher may panic,
// in which case we have not actually done a write.
h.flags ^= hashWriting
if h.buckets == nil {
h.buckets = newobject(t.bucket) // newarray(t.bucket, 1)
}
again:
bucket := hash & bucketMask(h.B)
if h.growing() {
growWork(t, h, bucket)
}
b := (*bmap)(add(h.buckets, bucket*uintptr(t.bucketsize)))
top := tophash(hash)
var inserti *uint8
var insertk unsafe.Pointer
var elem unsafe.Pointer
bucketloop:
for {
for i := uintptr(0); i < bucketCnt; i++ {
if b.tophash[i] != top {
if isEmpty(b.tophash[i]) && inserti == nil {
inserti = &b.tophash[i]
insertk = add(unsafe.Pointer(b), dataOffset+i*uintptr(t.keysize))
elem = add(unsafe.Pointer(b), dataOffset+bucketCnt*uintptr(t.keysize)+i*uintptr(t.elemsize))
}
if b.tophash[i] == emptyRest {
break bucketloop
}
continue
}
k := add(unsafe.Pointer(b), dataOffset+i*uintptr(t.keysize))
if t.indirectkey() {
k = *((*unsafe.Pointer)(k))
}
if !t.key.equal(key, k) {
continue
}
// already have a mapping for key. Update it.
if t.needkeyupdate() {
typedmemmove(t.key, k, key)
}
elem = add(unsafe.Pointer(b), dataOffset+bucketCnt*uintptr(t.keysize)+i*uintptr(t.elemsize))
goto done
}
ovf := b.overflow(t)
if ovf == nil {
break
}
b = ovf
}
// Did not find mapping for key. Allocate new cell & add entry.
// If we hit the max load factor or we have too many overflow buckets,
// and we're not already in the middle of growing, start growing.
if !h.growing() && (overLoadFactor(h.count+1, h.B) || tooManyOverflowBuckets(h.noverflow, h.B)) {
hashGrow(t, h)
goto again // Growing the table invalidates everything, so try again
}
if inserti == nil {
// The current bucket and all the overflow buckets connected to it are full, allocate a new one.
newb := h.newoverflow(t, b)
inserti = &newb.tophash[0]
insertk = add(unsafe.Pointer(newb), dataOffset)
elem = add(insertk, bucketCnt*uintptr(t.keysize))
}
// store new key/elem at insert position
if t.indirectkey() {
kmem := newobject(t.key)
*(*unsafe.Pointer)(insertk) = kmem
insertk = kmem
}
if t.indirectelem() {
vmem := newobject(t.elem)
*(*unsafe.Pointer)(elem) = vmem
}
typedmemmove(t.key, insertk, key)
*inserti = top
h.count++
done:
if h.flags&hashWriting == 0 {
fatal("concurrent map writes")
}
h.flags &^= hashWriting
if t.indirectelem() {
elem = *((*unsafe.Pointer)(elem))
}
return elem
}
func mapdelete(t *maptype, h *hmap, key unsafe.Pointer) {
if raceenabled && h != nil {
callerpc := getcallerpc()
pc := abi.FuncPCABIInternal(mapdelete)
racewritepc(unsafe.Pointer(h), callerpc, pc)
raceReadObjectPC(t.key, key, callerpc, pc)
}
if msanenabled && h != nil {
msanread(key, t.key.size)
}
if asanenabled && h != nil {
asanread(key, t.key.size)
}
if h == nil || h.count == 0 {
if t.hashMightPanic() {
t.hasher(key, 0) // see issue 23734
}
return
}
if h.flags&hashWriting != 0 {
fatal("concurrent map writes")
}
hash := t.hasher(key, uintptr(h.hash0))
// Set hashWriting after calling t.hasher, since t.hasher may panic,
// in which case we have not actually done a write (delete).
h.flags ^= hashWriting
bucket := hash & bucketMask(h.B)
if h.growing() {
growWork(t, h, bucket)
}
b := (*bmap)(add(h.buckets, bucket*uintptr(t.bucketsize)))
bOrig := b
top := tophash(hash)
search:
for ; b != nil; b = b.overflow(t) {
for i := uintptr(0); i < bucketCnt; i++ {
if b.tophash[i] != top {
if b.tophash[i] == emptyRest {
break search
}
continue
}
k := add(unsafe.Pointer(b), dataOffset+i*uintptr(t.keysize))
k2 := k
if t.indirectkey() {
k2 = *((*unsafe.Pointer)(k2))
}
if !t.key.equal(key, k2) {
continue
}
// Only clear key if there are pointers in it.
if t.indirectkey() {
*(*unsafe.Pointer)(k) = nil
} else if t.key.ptrdata != 0 {
memclrHasPointers(k, t.key.size)
}
e := add(unsafe.Pointer(b), dataOffset+bucketCnt*uintptr(t.keysize)+i*uintptr(t.elemsize))
if t.indirectelem() {
*(*unsafe.Pointer)(e) = nil
} else if t.elem.ptrdata != 0 {
memclrHasPointers(e, t.elem.size)
} else {
memclrNoHeapPointers(e, t.elem.size)
}
b.tophash[i] = emptyOne
// If the bucket now ends in a bunch of emptyOne states,
// change those to emptyRest states.
// It would be nice to make this a separate function, but
// for loops are not currently inlineable.
if i == bucketCnt-1 {
if b.overflow(t) != nil && b.overflow(t).tophash[0] != emptyRest {
goto notLast
}
} else {
if b.tophash[i+1] != emptyRest {
goto notLast
}
}
for {
b.tophash[i] = emptyRest
if i == 0 {
if b == bOrig {
break // beginning of initial bucket, we're done.
}
// Find previous bucket, continue at its last entry.
c := b
for b = bOrig; b.overflow(t) != c; b = b.overflow(t) {
}
i = bucketCnt - 1
} else {
i--
}
if b.tophash[i] != emptyOne {
break
}
}
notLast:
h.count--
// Reset the hash seed to make it more difficult for attackers to
// repeatedly trigger hash collisions. See issue 25237.
if h.count == 0 {
h.hash0 = fastrand()
}
break search
}
}
if h.flags&hashWriting == 0 {
fatal("concurrent map writes")
}
h.flags &^= hashWriting
}
// mapiterinit initializes the hiter struct used for ranging over maps.
// The hiter struct pointed to by 'it' is allocated on the stack
// by the compiler's order pass or on the heap by reflect_mapiterinit.
// Both need to have zeroed hiter since the struct contains pointers.
func mapiterinit(t *maptype, h *hmap, it *hiter) {
if raceenabled && h != nil {
callerpc := getcallerpc()
racereadpc(unsafe.Pointer(h), callerpc, abi.FuncPCABIInternal(mapiterinit))
}
it.t = t
if h == nil || h.count == 0 {
return
}
if unsafe.Sizeof(hiter{})/goarch.PtrSize != 12 {
throw("hash_iter size incorrect") // see cmd/compile/internal/reflectdata/reflect.go
}
it.h = h
// grab snapshot of bucket state
it.B = h.B
it.buckets = h.buckets
if t.bucket.ptrdata == 0 {
// Allocate the current slice and remember pointers to both current and old.
// This preserves all relevant overflow buckets alive even if
// the table grows and/or overflow buckets are added to the table
// while we are iterating.
h.createOverflow()
it.overflow = h.extra.overflow
it.oldoverflow = h.extra.oldoverflow
}
// decide where to start
var r uintptr
if h.B > 31-bucketCntBits {
r = uintptr(fastrand64())
} else {
r = uintptr(fastrand())
}
it.startBucket = r & bucketMask(h.B)
it.offset = uint8(r >> h.B & (bucketCnt - 1))
// iterator state
it.bucket = it.startBucket
// Remember we have an iterator.
// Can run concurrently with another mapiterinit().
if old := h.flags; old&(iterator|oldIterator) != iterator|oldIterator {
atomic.Or8(&h.flags, iterator|oldIterator)
}
mapiternext(it)
}
func mapiternext(it *hiter) {
h := it.h
if raceenabled {
callerpc := getcallerpc()
racereadpc(unsafe.Pointer(h), callerpc, abi.FuncPCABIInternal(mapiternext))
}
if h.flags&hashWriting != 0 {
fatal("concurrent map iteration and map write")
}
t := it.t
bucket := it.bucket
b := it.bptr
i := it.i
checkBucket := it.checkBucket
next:
if b == nil {
if bucket == it.startBucket && it.wrapped {
// end of iteration
it.key = nil
it.elem = nil
return
}
if h.growing() && it.B == h.B {
// Iterator was started in the middle of a grow, and the grow isn't done yet.
// If the bucket we're looking at hasn't been filled in yet (i.e. the old
// bucket hasn't been evacuated) then we need to iterate through the old
// bucket and only return the ones that will be migrated to this bucket.
oldbucket := bucket & it.h.oldbucketmask()
b = (*bmap)(add(h.oldbuckets, oldbucket*uintptr(t.bucketsize)))
if !evacuated(b) {
checkBucket = bucket
} else {
b = (*bmap)(add(it.buckets, bucket*uintptr(t.bucketsize)))
checkBucket = noCheck
}
} else {
b = (*bmap)(add(it.buckets, bucket*uintptr(t.bucketsize)))
checkBucket = noCheck
}
bucket++
if bucket == bucketShift(it.B) {
bucket = 0
it.wrapped = true
}
i = 0
}
for ; i < bucketCnt; i++ {
offi := (i + it.offset) & (bucketCnt - 1)
if isEmpty(b.tophash[offi]) || b.tophash[offi] == evacuatedEmpty {
// TODO: emptyRest is hard to use here, as we start iterating
// in the middle of a bucket. It's feasible, just tricky.
continue
}
k := add(unsafe.Pointer(b), dataOffset+uintptr(offi)*uintptr(t.keysize))
if t.indirectkey() {
k = *((*unsafe.Pointer)(k))
}
e := add(unsafe.Pointer(b), dataOffset+bucketCnt*uintptr(t.keysize)+uintptr(offi)*uintptr(t.elemsize))
if checkBucket != noCheck && !h.sameSizeGrow() {
// Special case: iterator was started during a grow to a larger size
// and the grow is not done yet. We're working on a bucket whose
// oldbucket has not been evacuated yet. Or at least, it wasn't
// evacuated when we started the bucket. So we're iterating
// through the oldbucket, skipping any keys that will go
// to the other new bucket (each oldbucket expands to two
// buckets during a grow).
if t.reflexivekey() || t.key.equal(k, k) {
// If the item in the oldbucket is not destined for
// the current new bucket in the iteration, skip it.
hash := t.hasher(k, uintptr(h.hash0))
if hash&bucketMask(it.B) != checkBucket {
continue
}
} else {
// Hash isn't repeatable if k != k (NaNs). We need a
// repeatable and randomish choice of which direction
// to send NaNs during evacuation. We'll use the low
// bit of tophash to decide which way NaNs go.
// NOTE: this case is why we need two evacuate tophash
// values, evacuatedX and evacuatedY, that differ in
// their low bit.
if checkBucket>>(it.B-1) != uintptr(b.tophash[offi]&1) {
continue
}
}
}
if (b.tophash[offi] != evacuatedX && b.tophash[offi] != evacuatedY) ||
!(t.reflexivekey() || t.key.equal(k, k)) {
// This is the golden data, we can return it.
// OR
// key!=key, so the entry can't be deleted or updated, so we can just return it.
// That's lucky for us because when key!=key we can't look it up successfully.
it.key = k
if t.indirectelem() {
e = *((*unsafe.Pointer)(e))
}
it.elem = e
} else {
// The hash table has grown since the iterator was started.
// The golden data for this key is now somewhere else.
// Check the current hash table for the data.
// This code handles the case where the key
// has been deleted, updated, or deleted and reinserted.
// NOTE: we need to regrab the key as it has potentially been
// updated to an equal() but not identical key (e.g. +0.0 vs -0.0).
rk, re := mapaccessK(t, h, k)
if rk == nil {
continue // key has been deleted
}
it.key = rk
it.elem = re
}
it.bucket = bucket
if it.bptr != b { // avoid unnecessary write barrier; see issue 14921
it.bptr = b
}
it.i = i + 1
it.checkBucket = checkBucket
return
}
b = b.overflow(t)
i = 0
goto next
}
// mapclear deletes all keys from a map.
func mapclear(t *maptype, h *hmap) {
if raceenabled && h != nil {
callerpc := getcallerpc()
pc := abi.FuncPCABIInternal(mapclear)
racewritepc(unsafe.Pointer(h), callerpc, pc)
}
if h == nil || h.count == 0 {
return
}
if h.flags&hashWriting != 0 {
fatal("concurrent map writes")
}
h.flags ^= hashWriting
h.flags &^= sameSizeGrow
h.oldbuckets = nil
h.nevacuate = 0
h.noverflow = 0
h.count = 0
// Reset the hash seed to make it more difficult for attackers to
// repeatedly trigger hash collisions. See issue 25237.
h.hash0 = fastrand()
// Keep the mapextra allocation but clear any extra information.
if h.extra != nil {
*h.extra = mapextra{}
}
// makeBucketArray clears the memory pointed to by h.buckets
// and recovers any overflow buckets by generating them
// as if h.buckets was newly alloced.
_, nextOverflow := makeBucketArray(t, h.B, h.buckets)
if nextOverflow != nil {
// If overflow buckets are created then h.extra
// will have been allocated during initial bucket creation.
h.extra.nextOverflow = nextOverflow
}
if h.flags&hashWriting == 0 {
fatal("concurrent map writes")
}
h.flags &^= hashWriting
}
func hashGrow(t *maptype, h *hmap) {
// If we've hit the load factor, get bigger.
// Otherwise, there are too many overflow buckets,
// so keep the same number of buckets and "grow" laterally.
bigger := uint8(1)
if !overLoadFactor(h.count+1, h.B) {
bigger = 0
h.flags |= sameSizeGrow
}
oldbuckets := h.buckets
newbuckets, nextOverflow := makeBucketArray(t, h.B+bigger, nil)
flags := h.flags &^ (iterator | oldIterator)
if h.flags&iterator != 0 {
flags |= oldIterator
}
// commit the grow (atomic wrt gc)
h.B += bigger
h.flags = flags
h.oldbuckets = oldbuckets
h.buckets = newbuckets
h.nevacuate = 0
h.noverflow = 0
if h.extra != nil && h.extra.overflow != nil {
// Promote current overflow buckets to the old generation.
if h.extra.oldoverflow != nil {
throw("oldoverflow is not nil")
}
h.extra.oldoverflow = h.extra.overflow
h.extra.overflow = nil
}
if nextOverflow != nil {
if h.extra == nil {
h.extra = new(mapextra)
}
h.extra.nextOverflow = nextOverflow
}
// the actual copying of the hash table data is done incrementally
// by growWork() and evacuate().
}
// overLoadFactor reports whether count items placed in 1<<B buckets is over loadFactor.
func overLoadFactor(count int, B uint8) bool {
return count > bucketCnt && uintptr(count) > loadFactorNum*(bucketShift(B)/loadFactorDen)
}
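// Worked example (illustrative, bucketCnt = 8, load factor 6.5): a map with
// B = 5 has 32 buckets and tolerates up to 6.5*32 = 208 entries.
// overLoadFactor(209, 5) evaluates 209 > 13*(32/2) = 208 and reports true,
// triggering a grow in mapassign; counts of bucketCnt or fewer never grow
// regardless of B.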
// tooManyOverflowBuckets reports whether noverflow buckets is too many for a map with 1<<B buckets.
// Note that most of these overflow buckets must be in sparse use;
// if use was dense, then we'd have already triggered regular map growth.
func tooManyOverflowBuckets(noverflow uint16, B uint8) bool {
// If the threshold is too low, we do extraneous work.
// If the threshold is too high, maps that grow and shrink can hold on to lots of unused memory.
// "too many" means (approximately) as many overflow buckets as regular buckets.
// See incrnoverflow for more details.
if B > 15 {
B = 15
}
// The compiler doesn't see here that B < 16; mask B to generate shorter shift code.
return noverflow >= uint16(1)<<(B&15)
}
// growing reports whether h is growing. The growth may be to the same size or bigger.
func (h *hmap) growing() bool {
return h.oldbuckets != nil
}
// sameSizeGrow reports whether the current growth is to a map of the same size.
func (h *hmap) sameSizeGrow() bool {
return h.flags&sameSizeGrow != 0
}
// noldbuckets calculates the number of buckets prior to the current map growth.
func (h *hmap) noldbuckets() uintptr {
oldB := h.B
if !h.sameSizeGrow() {
oldB--
}
return bucketShift(oldB)
}
// oldbucketmask provides a mask that can be applied to calculate n % noldbuckets().
func (h *hmap) oldbucketmask() uintptr {
return h.noldbuckets() - 1
}
func growWork(t *maptype, h *hmap, bucket uintptr) {
// make sure we evacuate the oldbucket corresponding
// to the bucket we're about to use
evacuate(t, h, bucket&h.oldbucketmask())
// evacuate one more oldbucket to make progress on growing
if h.growing() {
evacuate(t, h, h.nevacuate)
}
}
func bucketEvacuated(t *maptype, h *hmap, bucket uintptr) bool {
b := (*bmap)(add(h.oldbuckets, bucket*uintptr(t.bucketsize)))
return evacuated(b)
}
// evacDst is an evacuation destination.
type evacDst struct {
b *bmap // current destination bucket
i int // key/elem index into b
k unsafe.Pointer // pointer to current key storage
e unsafe.Pointer // pointer to current elem storage
}
func evacuate(t *maptype, h *hmap, oldbucket uintptr) {
b := (*bmap)(add(h.oldbuckets, oldbucket*uintptr(t.bucketsize)))
newbit := h.noldbuckets()
if !evacuated(b) {
// TODO: reuse overflow buckets instead of using new ones, if there
// is no iterator using the old buckets. (If !oldIterator.)
// xy contains the x and y (low and high) evacuation destinations.
var xy [2]evacDst
x := &xy[0]
x.b = (*bmap)(add(h.buckets, oldbucket*uintptr(t.bucketsize)))
x.k = add(unsafe.Pointer(x.b), dataOffset)
x.e = add(x.k, bucketCnt*uintptr(t.keysize))
if !h.sameSizeGrow() {
// Only calculate y pointers if we're growing bigger.
// Otherwise GC can see bad pointers.
y := &xy[1]
y.b = (*bmap)(add(h.buckets, (oldbucket+newbit)*uintptr(t.bucketsize)))
y.k = add(unsafe.Pointer(y.b), dataOffset)
y.e = add(y.k, bucketCnt*uintptr(t.keysize))
}
for ; b != nil; b = b.overflow(t) {
k := add(unsafe.Pointer(b), dataOffset)
e := add(k, bucketCnt*uintptr(t.keysize))
for i := 0; i < bucketCnt; i, k, e = i+1, add(k, uintptr(t.keysize)), add(e, uintptr(t.elemsize)) {
top := b.tophash[i]
if isEmpty(top) {
b.tophash[i] = evacuatedEmpty
continue
}
if top < minTopHash {
throw("bad map state")
}
k2 := k
if t.indirectkey() {
k2 = *((*unsafe.Pointer)(k2))
}
var useY uint8
if !h.sameSizeGrow() {
// Compute hash to make our evacuation decision (whether we need
// to send this key/elem to bucket x or bucket y).
hash := t.hasher(k2, uintptr(h.hash0))
if h.flags&iterator != 0 && !t.reflexivekey() && !t.key.equal(k2, k2) {
// If key != key (NaNs), then the hash could be (and probably
// will be) entirely different from the old hash. Moreover,
// it isn't reproducible. Reproducibility is required in the
// presence of iterators, as our evacuation decision must
// match whatever decision the iterator made.
// Fortunately, we have the freedom to send these keys either
// way. Also, tophash is meaningless for these kinds of keys.
// We let the low bit of tophash drive the evacuation decision.
// We recompute a new random tophash for the next level so
// these keys will get evenly distributed across all buckets
// after multiple grows.
useY = top & 1
top = tophash(hash)
} else {
if hash&newbit != 0 {
useY = 1
}
}
}
if evacuatedX+1 != evacuatedY || evacuatedX^1 != evacuatedY {
throw("bad evacuatedN")
}
b.tophash[i] = evacuatedX + useY // evacuatedX + 1 == evacuatedY
dst := &xy[useY] // evacuation destination
if dst.i == bucketCnt {
dst.b = h.newoverflow(t, dst.b)
dst.i = 0
dst.k = add(unsafe.Pointer(dst.b), dataOffset)
dst.e = add(dst.k, bucketCnt*uintptr(t.keysize))
}
dst.b.tophash[dst.i&(bucketCnt-1)] = top // mask dst.i as an optimization, to avoid a bounds check
if t.indirectkey() {
*(*unsafe.Pointer)(dst.k) = k2 // copy pointer
} else {
typedmemmove(t.key, dst.k, k) // copy key
}
if t.indirectelem() {
*(*unsafe.Pointer)(dst.e) = *(*unsafe.Pointer)(e)
} else {
typedmemmove(t.elem, dst.e, e)
}
dst.i++
// These updates might push these pointers past the end of the
// key or elem arrays. That's ok, as we have the overflow pointer
// at the end of the bucket to protect against pointing past the
// end of the bucket.
dst.k = add(dst.k, uintptr(t.keysize))
dst.e = add(dst.e, uintptr(t.elemsize))
}
}
// Unlink the overflow buckets & clear key/elem to help GC.
if h.flags&oldIterator == 0 && t.bucket.ptrdata != 0 {
b := add(h.oldbuckets, oldbucket*uintptr(t.bucketsize))
// Preserve b.tophash because the evacuation
// state is maintained there.
ptr := add(b, dataOffset)
n := uintptr(t.bucketsize) - dataOffset
memclrHasPointers(ptr, n)
}
}
if oldbucket == h.nevacuate {
advanceEvacuationMark(h, t, newbit)
}
}
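// Hedged sketch of the X/Y decision implemented in evacuate above: when the
// table doubles, an entry in old bucket `old` moves to new bucket `old` (X)
// or `old+newbit` (Y), selected by the single new hash bit. Hypothetical
// helper, not part of the runtime:
func exampleEvacDest(hash, old, newbit uintptr) uintptr {
	if hash&newbit != 0 {
		return old + newbit // Y: the high half of the grown table
	}
	return old // X: the same index in the low half
}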
func advanceEvacuationMark(h *hmap, t *maptype, newbit uintptr) {
h.nevacuate++
// Experiments suggest that 1024 is overkill by at least an order of magnitude.
// Put it in there as a safeguard anyway, to ensure O(1) behavior.
stop := h.nevacuate + 1024
if stop > newbit {
stop = newbit
}
for h.nevacuate != stop && bucketEvacuated(t, h, h.nevacuate) {
h.nevacuate++
}
if h.nevacuate == newbit { // newbit == # of oldbuckets
// Growing is all done. Free old main bucket array.
h.oldbuckets = nil
// Can discard old overflow buckets as well.
// If they are still referenced by an iterator,
// then the iterator holds a pointer to the slice.
if h.extra != nil {
h.extra.oldoverflow = nil
}
h.flags &^= sameSizeGrow
}
}
// Reflect stubs. Called from ../reflect/asm_*.s
//go:linkname reflect_makemap reflect.makemap
func reflect_makemap(t *maptype, cap int) *hmap {
// Check invariants and reflect's math.
if t.key.equal == nil {
throw("runtime.reflect_makemap: unsupported map key type")
}
if t.key.size > maxKeySize && (!t.indirectkey() || t.keysize != uint8(goarch.PtrSize)) ||
t.key.size <= maxKeySize && (t.indirectkey() || t.keysize != uint8(t.key.size)) {
throw("key size wrong")
}
if t.elem.size > maxElemSize && (!t.indirectelem() || t.elemsize != uint8(goarch.PtrSize)) ||
t.elem.size <= maxElemSize && (t.indirectelem() || t.elemsize != uint8(t.elem.size)) {
throw("elem size wrong")
}
if t.key.align > bucketCnt {
throw("key align too big")
}
if t.elem.align > bucketCnt {
throw("elem align too big")
}
if t.key.size%uintptr(t.key.align) != 0 {
throw("key size not a multiple of key align")
}
if t.elem.size%uintptr(t.elem.align) != 0 {
throw("elem size not a multiple of elem align")
}
if bucketCnt < 8 {
throw("bucketsize too small for proper alignment")
}
if dataOffset%uintptr(t.key.align) != 0 {
throw("need padding in bucket (key)")
}
if dataOffset%uintptr(t.elem.align) != 0 {
throw("need padding in bucket (elem)")
}
return makemap(t, cap, nil)
}
//go:linkname reflect_mapaccess reflect.mapaccess
func reflect_mapaccess(t *maptype, h *hmap, key unsafe.Pointer) unsafe.Pointer {
elem, ok := mapaccess2(t, h, key)
if !ok {
// reflect wants nil for a missing element
elem = nil
}
return elem
}
//go:linkname reflect_mapaccess_faststr reflect.mapaccess_faststr
func reflect_mapaccess_faststr(t *maptype, h *hmap, key string) unsafe.Pointer {
elem, ok := mapaccess2_faststr(t, h, key)
if !ok {
// reflect wants nil for a missing element
elem = nil
}
return elem
}
//go:linkname reflect_mapassign reflect.mapassign
func reflect_mapassign(t *maptype, h *hmap, key unsafe.Pointer, elem unsafe.Pointer) {
p := mapassign(t, h, key)
typedmemmove(t.elem, p, elem)
}
//go:linkname reflect_mapassign_faststr reflect.mapassign_faststr
func reflect_mapassign_faststr(t *maptype, h *hmap, key string, elem unsafe.Pointer) {
p := mapassign_faststr(t, h, key)
typedmemmove(t.elem, p, elem)
}
//go:linkname reflect_mapdelete reflect.mapdelete
func reflect_mapdelete(t *maptype, h *hmap, key unsafe.Pointer) {
mapdelete(t, h, key)
}
//go:linkname reflect_mapdelete_faststr reflect.mapdelete_faststr
func reflect_mapdelete_faststr(t *maptype, h *hmap, key string) {
mapdelete_faststr(t, h, key)
}
//go:linkname reflect_mapiterinit reflect.mapiterinit
func reflect_mapiterinit(t *maptype, h *hmap, it *hiter) {
mapiterinit(t, h, it)
}
//go:linkname reflect_mapiternext reflect.mapiternext
func reflect_mapiternext(it *hiter) {
mapiternext(it)
}
//go:linkname reflect_mapiterkey reflect.mapiterkey
func reflect_mapiterkey(it *hiter) unsafe.Pointer {
return it.key
}
//go:linkname reflect_mapiterelem reflect.mapiterelem
func reflect_mapiterelem(it *hiter) unsafe.Pointer {
return it.elem
}
//go:linkname reflect_maplen reflect.maplen
func reflect_maplen(h *hmap) int {
if h == nil {
return 0
}
if raceenabled {
callerpc := getcallerpc()
racereadpc(unsafe.Pointer(h), callerpc, abi.FuncPCABIInternal(reflect_maplen))
}
return h.count
}
//go:linkname reflect_mapclear reflect.mapclear
func reflect_mapclear(t *maptype, h *hmap) {
mapclear(t, h)
}
//go:linkname reflectlite_maplen internal/reflectlite.maplen
func reflectlite_maplen(h *hmap) int {
if h == nil {
return 0
}
if raceenabled {
callerpc := getcallerpc()
racereadpc(unsafe.Pointer(h), callerpc, abi.FuncPCABIInternal(reflect_maplen))
}
return h.count
}
const maxZero = 1024 // must match value in reflect/value.go:maxZero cmd/compile/internal/gc/walk.go:zeroValSize
var zeroVal [maxZero]byte
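// Hedged illustration of what zeroVal buys: a lookup that misses returns a
// pointer into this shared zeroed buffer, so the two-value form below sees
// the zero value without any allocation. Ordinary map code, for context:
func exampleZeroOnMiss() (int, bool) {
	m := map[string]int{}
	v, ok := m["missing"] // v == 0 (read via the shared zero buffer), ok == false
	return v, ok
}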
// mapinitnoop is a no-op function known the Go linker; if a given global
// map (of the right size) is determined to be dead, the linker will
// rewrite the relocation (from the package init func) from the outlined
// map init function to this symbol. Defined in assembly so as to avoid
// complications with instrumentation (coverage, etc).
func mapinitnoop()
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"internal/abi"
"internal/goarch"
"unsafe"
)
func mapaccess1_fast32(t *maptype, h *hmap, key uint32) unsafe.Pointer {
if raceenabled && h != nil {
callerpc := getcallerpc()
racereadpc(unsafe.Pointer(h), callerpc, abi.FuncPCABIInternal(mapaccess1_fast32))
}
if h == nil || h.count == 0 {
return unsafe.Pointer(&zeroVal[0])
}
if h.flags&hashWriting != 0 {
fatal("concurrent map read and map write")
}
var b *bmap
if h.B == 0 {
// One-bucket table. No need to hash.
b = (*bmap)(h.buckets)
} else {
hash := t.hasher(noescape(unsafe.Pointer(&key)), uintptr(h.hash0))
m := bucketMask(h.B)
b = (*bmap)(add(h.buckets, (hash&m)*uintptr(t.bucketsize)))
if c := h.oldbuckets; c != nil {
if !h.sameSizeGrow() {
// There used to be half as many buckets; mask down one more power of two.
m >>= 1
}
oldb := (*bmap)(add(c, (hash&m)*uintptr(t.bucketsize)))
if !evacuated(oldb) {
b = oldb
}
}
}
for ; b != nil; b = b.overflow(t) {
for i, k := uintptr(0), b.keys(); i < bucketCnt; i, k = i+1, add(k, 4) {
if *(*uint32)(k) == key && !isEmpty(b.tophash[i]) {
return add(unsafe.Pointer(b), dataOffset+bucketCnt*4+i*uintptr(t.elemsize))
}
}
}
return unsafe.Pointer(&zeroVal[0])
}
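// Sketch (hypothetical names) of the grow-aware bucket choice above: while
// a doubling grow is in progress, an entry that has not been evacuated yet
// still lives in the old table, whose mask has one fewer bit.
func exampleOldIndex(hash, mask uintptr, sameSize bool) uintptr {
	if !sameSize {
		mask >>= 1 // the old table was half the size
	}
	return hash & mask
}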
func mapaccess2_fast32(t *maptype, h *hmap, key uint32) (unsafe.Pointer, bool) {
if raceenabled && h != nil {
callerpc := getcallerpc()
racereadpc(unsafe.Pointer(h), callerpc, abi.FuncPCABIInternal(mapaccess2_fast32))
}
if h == nil || h.count == 0 {
return unsafe.Pointer(&zeroVal[0]), false
}
if h.flags&hashWriting != 0 {
fatal("concurrent map read and map write")
}
var b *bmap
if h.B == 0 {
// One-bucket table. No need to hash.
b = (*bmap)(h.buckets)
} else {
hash := t.hasher(noescape(unsafe.Pointer(&key)), uintptr(h.hash0))
m := bucketMask(h.B)
b = (*bmap)(add(h.buckets, (hash&m)*uintptr(t.bucketsize)))
if c := h.oldbuckets; c != nil {
if !h.sameSizeGrow() {
// There used to be half as many buckets; mask down one more power of two.
m >>= 1
}
oldb := (*bmap)(add(c, (hash&m)*uintptr(t.bucketsize)))
if !evacuated(oldb) {
b = oldb
}
}
}
for ; b != nil; b = b.overflow(t) {
for i, k := uintptr(0), b.keys(); i < bucketCnt; i, k = i+1, add(k, 4) {
if *(*uint32)(k) == key && !isEmpty(b.tophash[i]) {
return add(unsafe.Pointer(b), dataOffset+bucketCnt*4+i*uintptr(t.elemsize)), true
}
}
}
return unsafe.Pointer(&zeroVal[0]), false
}
func mapassign_fast32(t *maptype, h *hmap, key uint32) unsafe.Pointer {
if h == nil {
panic(plainError("assignment to entry in nil map"))
}
if raceenabled {
callerpc := getcallerpc()
racewritepc(unsafe.Pointer(h), callerpc, abi.FuncPCABIInternal(mapassign_fast32))
}
if h.flags&hashWriting != 0 {
fatal("concurrent map writes")
}
hash := t.hasher(noescape(unsafe.Pointer(&key)), uintptr(h.hash0))
// Set hashWriting after calling t.hasher for consistency with mapassign.
h.flags ^= hashWriting
if h.buckets == nil {
h.buckets = newobject(t.bucket) // newarray(t.bucket, 1)
}
again:
bucket := hash & bucketMask(h.B)
if h.growing() {
growWork_fast32(t, h, bucket)
}
b := (*bmap)(add(h.buckets, bucket*uintptr(t.bucketsize)))
var insertb *bmap
var inserti uintptr
var insertk unsafe.Pointer
bucketloop:
for {
for i := uintptr(0); i < bucketCnt; i++ {
if isEmpty(b.tophash[i]) {
if insertb == nil {
inserti = i
insertb = b
}
if b.tophash[i] == emptyRest {
break bucketloop
}
continue
}
k := *((*uint32)(add(unsafe.Pointer(b), dataOffset+i*4)))
if k != key {
continue
}
inserti = i
insertb = b
goto done
}
ovf := b.overflow(t)
if ovf == nil {
break
}
b = ovf
}
// Did not find mapping for key. Allocate new cell & add entry.
// If we hit the max load factor or we have too many overflow buckets,
// and we're not already in the middle of growing, start growing.
if !h.growing() && (overLoadFactor(h.count+1, h.B) || tooManyOverflowBuckets(h.noverflow, h.B)) {
hashGrow(t, h)
goto again // Growing the table invalidates everything, so try again
}
if insertb == nil {
// The current bucket and all the overflow buckets connected to it are full, allocate a new one.
insertb = h.newoverflow(t, b)
inserti = 0 // not necessary, but avoids needlessly spilling inserti
}
insertb.tophash[inserti&(bucketCnt-1)] = tophash(hash) // mask inserti to avoid bounds checks
insertk = add(unsafe.Pointer(insertb), dataOffset+inserti*4)
// store new key at insert position
*(*uint32)(insertk) = key
h.count++
done:
elem := add(unsafe.Pointer(insertb), dataOffset+bucketCnt*4+inserti*uintptr(t.elemsize))
if h.flags&hashWriting == 0 {
fatal("concurrent map writes")
}
h.flags &^= hashWriting
return elem
}
func mapassign_fast32ptr(t *maptype, h *hmap, key unsafe.Pointer) unsafe.Pointer {
if h == nil {
panic(plainError("assignment to entry in nil map"))
}
if raceenabled {
callerpc := getcallerpc()
racewritepc(unsafe.Pointer(h), callerpc, abi.FuncPCABIInternal(mapassign_fast32))
}
if h.flags&hashWriting != 0 {
fatal("concurrent map writes")
}
hash := t.hasher(noescape(unsafe.Pointer(&key)), uintptr(h.hash0))
// Set hashWriting after calling t.hasher for consistency with mapassign.
h.flags ^= hashWriting
if h.buckets == nil {
h.buckets = newobject(t.bucket) // newarray(t.bucket, 1)
}
again:
bucket := hash & bucketMask(h.B)
if h.growing() {
growWork_fast32(t, h, bucket)
}
b := (*bmap)(add(h.buckets, bucket*uintptr(t.bucketsize)))
var insertb *bmap
var inserti uintptr
var insertk unsafe.Pointer
bucketloop:
for {
for i := uintptr(0); i < bucketCnt; i++ {
if isEmpty(b.tophash[i]) {
if insertb == nil {
inserti = i
insertb = b
}
if b.tophash[i] == emptyRest {
break bucketloop
}
continue
}
k := *((*unsafe.Pointer)(add(unsafe.Pointer(b), dataOffset+i*4)))
if k != key {
continue
}
inserti = i
insertb = b
goto done
}
ovf := b.overflow(t)
if ovf == nil {
break
}
b = ovf
}
// Did not find mapping for key. Allocate new cell & add entry.
// If we hit the max load factor or we have too many overflow buckets,
// and we're not already in the middle of growing, start growing.
if !h.growing() && (overLoadFactor(h.count+1, h.B) || tooManyOverflowBuckets(h.noverflow, h.B)) {
hashGrow(t, h)
goto again // Growing the table invalidates everything, so try again
}
if insertb == nil {
// The current bucket and all the overflow buckets connected to it are full, allocate a new one.
insertb = h.newoverflow(t, b)
inserti = 0 // not necessary, but avoids needlessly spilling inserti
}
insertb.tophash[inserti&(bucketCnt-1)] = tophash(hash) // mask inserti to avoid bounds checks
insertk = add(unsafe.Pointer(insertb), dataOffset+inserti*4)
// store new key at insert position
*(*unsafe.Pointer)(insertk) = key
h.count++
done:
elem := add(unsafe.Pointer(insertb), dataOffset+bucketCnt*4+inserti*uintptr(t.elemsize))
if h.flags&hashWriting == 0 {
fatal("concurrent map writes")
}
h.flags &^= hashWriting
return elem
}
func mapdelete_fast32(t *maptype, h *hmap, key uint32) {
if raceenabled && h != nil {
callerpc := getcallerpc()
racewritepc(unsafe.Pointer(h), callerpc, abi.FuncPCABIInternal(mapdelete_fast32))
}
if h == nil || h.count == 0 {
return
}
if h.flags&hashWriting != 0 {
fatal("concurrent map writes")
}
hash := t.hasher(noescape(unsafe.Pointer(&key)), uintptr(h.hash0))
// Set hashWriting after calling t.hasher for consistency with mapdelete
h.flags ^= hashWriting
bucket := hash & bucketMask(h.B)
if h.growing() {
growWork_fast32(t, h, bucket)
}
b := (*bmap)(add(h.buckets, bucket*uintptr(t.bucketsize)))
bOrig := b
search:
for ; b != nil; b = b.overflow(t) {
for i, k := uintptr(0), b.keys(); i < bucketCnt; i, k = i+1, add(k, 4) {
if key != *(*uint32)(k) || isEmpty(b.tophash[i]) {
continue
}
// Only clear key if there are pointers in it.
// This can only happen if pointers are 32 bit
// wide as 64 bit pointers do not fit into a 32 bit key.
if goarch.PtrSize == 4 && t.key.ptrdata != 0 {
// The key must be a pointer as we checked pointers are
// 32 bits wide and the key is 32 bits wide also.
*(*unsafe.Pointer)(k) = nil
}
e := add(unsafe.Pointer(b), dataOffset+bucketCnt*4+i*uintptr(t.elemsize))
if t.elem.ptrdata != 0 {
memclrHasPointers(e, t.elem.size)
} else {
memclrNoHeapPointers(e, t.elem.size)
}
b.tophash[i] = emptyOne
// If the bucket now ends in a bunch of emptyOne states,
// change those to emptyRest states.
if i == bucketCnt-1 {
if b.overflow(t) != nil && b.overflow(t).tophash[0] != emptyRest {
goto notLast
}
} else {
if b.tophash[i+1] != emptyRest {
goto notLast
}
}
for {
b.tophash[i] = emptyRest
if i == 0 {
if b == bOrig {
break // beginning of initial bucket, we're done.
}
// Find previous bucket, continue at its last entry.
c := b
for b = bOrig; b.overflow(t) != c; b = b.overflow(t) {
}
i = bucketCnt - 1
} else {
i--
}
if b.tophash[i] != emptyOne {
break
}
}
notLast:
h.count--
// Reset the hash seed to make it more difficult for attackers to
// repeatedly trigger hash collisions. See issue 25237.
if h.count == 0 {
h.hash0 = fastrand()
}
break search
}
}
if h.flags&hashWriting == 0 {
fatal("concurrent map writes")
}
h.flags &^= hashWriting
}
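// Self-contained sketch of the emptyOne/emptyRest coalescing performed by
// the delete above, restricted to a single bucket and using a plain slice
// in place of tophash (the constants are hypothetical stand-ins):
func exampleCoalesceEmpties(tops []uint8) {
	const emptyRestX, emptyOneX = 0, 1
	for i := len(tops) - 1; i >= 0 && tops[i] == emptyOneX; i-- {
		tops[i] = emptyRestX // grow the "rest of this bucket is empty" run
	}
}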
func growWork_fast32(t *maptype, h *hmap, bucket uintptr) {
// make sure we evacuate the oldbucket corresponding
// to the bucket we're about to use
evacuate_fast32(t, h, bucket&h.oldbucketmask())
// evacuate one more oldbucket to make progress on growing
if h.growing() {
evacuate_fast32(t, h, h.nevacuate)
}
}
func evacuate_fast32(t *maptype, h *hmap, oldbucket uintptr) {
b := (*bmap)(add(h.oldbuckets, oldbucket*uintptr(t.bucketsize)))
newbit := h.noldbuckets()
if !evacuated(b) {
// TODO: reuse overflow buckets instead of using new ones, if there
// is no iterator using the old buckets. (If !oldIterator.)
// xy contains the x and y (low and high) evacuation destinations.
var xy [2]evacDst
x := &xy[0]
x.b = (*bmap)(add(h.buckets, oldbucket*uintptr(t.bucketsize)))
x.k = add(unsafe.Pointer(x.b), dataOffset)
x.e = add(x.k, bucketCnt*4)
if !h.sameSizeGrow() {
// Only calculate y pointers if we're growing bigger.
// Otherwise GC can see bad pointers.
y := &xy[1]
y.b = (*bmap)(add(h.buckets, (oldbucket+newbit)*uintptr(t.bucketsize)))
y.k = add(unsafe.Pointer(y.b), dataOffset)
y.e = add(y.k, bucketCnt*4)
}
for ; b != nil; b = b.overflow(t) {
k := add(unsafe.Pointer(b), dataOffset)
e := add(k, bucketCnt*4)
for i := 0; i < bucketCnt; i, k, e = i+1, add(k, 4), add(e, uintptr(t.elemsize)) {
top := b.tophash[i]
if isEmpty(top) {
b.tophash[i] = evacuatedEmpty
continue
}
if top < minTopHash {
throw("bad map state")
}
var useY uint8
if !h.sameSizeGrow() {
// Compute hash to make our evacuation decision (whether we need
// to send this key/elem to bucket x or bucket y).
hash := t.hasher(k, uintptr(h.hash0))
if hash&newbit != 0 {
useY = 1
}
}
b.tophash[i] = evacuatedX + useY // evacuatedX + 1 == evacuatedY, enforced in makemap
dst := &xy[useY] // evacuation destination
if dst.i == bucketCnt {
dst.b = h.newoverflow(t, dst.b)
dst.i = 0
dst.k = add(unsafe.Pointer(dst.b), dataOffset)
dst.e = add(dst.k, bucketCnt*4)
}
dst.b.tophash[dst.i&(bucketCnt-1)] = top // mask dst.i as an optimization, to avoid a bounds check
// Copy key.
if goarch.PtrSize == 4 && t.key.ptrdata != 0 && writeBarrier.enabled {
// Write with a write barrier.
*(*unsafe.Pointer)(dst.k) = *(*unsafe.Pointer)(k)
} else {
*(*uint32)(dst.k) = *(*uint32)(k)
}
typedmemmove(t.elem, dst.e, e)
dst.i++
// These updates might push these pointers past the end of the
// key or elem arrays. That's ok, as we have the overflow pointer
// at the end of the bucket to protect against pointing past the
// end of the bucket.
dst.k = add(dst.k, 4)
dst.e = add(dst.e, uintptr(t.elemsize))
}
}
// Unlink the overflow buckets & clear key/elem to help GC.
if h.flags&oldIterator == 0 && t.bucket.ptrdata != 0 {
b := add(h.oldbuckets, oldbucket*uintptr(t.bucketsize))
// Preserve b.tophash because the evacuation
// state is maintained there.
ptr := add(b, dataOffset)
n := uintptr(t.bucketsize) - dataOffset
memclrHasPointers(ptr, n)
}
}
if oldbucket == h.nevacuate {
advanceEvacuationMark(h, t, newbit)
}
}
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"internal/abi"
"internal/goarch"
"unsafe"
)
func mapaccess1_fast64(t *maptype, h *hmap, key uint64) unsafe.Pointer {
if raceenabled && h != nil {
callerpc := getcallerpc()
racereadpc(unsafe.Pointer(h), callerpc, abi.FuncPCABIInternal(mapaccess1_fast64))
}
if h == nil || h.count == 0 {
return unsafe.Pointer(&zeroVal[0])
}
if h.flags&hashWriting != 0 {
fatal("concurrent map read and map write")
}
var b *bmap
if h.B == 0 {
// One-bucket table. No need to hash.
b = (*bmap)(h.buckets)
} else {
hash := t.hasher(noescape(unsafe.Pointer(&key)), uintptr(h.hash0))
m := bucketMask(h.B)
b = (*bmap)(add(h.buckets, (hash&m)*uintptr(t.bucketsize)))
if c := h.oldbuckets; c != nil {
if !h.sameSizeGrow() {
// There used to be half as many buckets; mask down one more power of two.
m >>= 1
}
oldb := (*bmap)(add(c, (hash&m)*uintptr(t.bucketsize)))
if !evacuated(oldb) {
b = oldb
}
}
}
for ; b != nil; b = b.overflow(t) {
for i, k := uintptr(0), b.keys(); i < bucketCnt; i, k = i+1, add(k, 8) {
if *(*uint64)(k) == key && !isEmpty(b.tophash[i]) {
return add(unsafe.Pointer(b), dataOffset+bucketCnt*8+i*uintptr(t.elemsize))
}
}
}
return unsafe.Pointer(&zeroVal[0])
}
func mapaccess2_fast64(t *maptype, h *hmap, key uint64) (unsafe.Pointer, bool) {
if raceenabled && h != nil {
callerpc := getcallerpc()
racereadpc(unsafe.Pointer(h), callerpc, abi.FuncPCABIInternal(mapaccess2_fast64))
}
if h == nil || h.count == 0 {
return unsafe.Pointer(&zeroVal[0]), false
}
if h.flags&hashWriting != 0 {
fatal("concurrent map read and map write")
}
var b *bmap
if h.B == 0 {
// One-bucket table. No need to hash.
b = (*bmap)(h.buckets)
} else {
hash := t.hasher(noescape(unsafe.Pointer(&key)), uintptr(h.hash0))
m := bucketMask(h.B)
b = (*bmap)(add(h.buckets, (hash&m)*uintptr(t.bucketsize)))
if c := h.oldbuckets; c != nil {
if !h.sameSizeGrow() {
// There used to be half as many buckets; mask down one more power of two.
m >>= 1
}
oldb := (*bmap)(add(c, (hash&m)*uintptr(t.bucketsize)))
if !evacuated(oldb) {
b = oldb
}
}
}
for ; b != nil; b = b.overflow(t) {
for i, k := uintptr(0), b.keys(); i < bucketCnt; i, k = i+1, add(k, 8) {
if *(*uint64)(k) == key && !isEmpty(b.tophash[i]) {
return add(unsafe.Pointer(b), dataOffset+bucketCnt*8+i*uintptr(t.elemsize)), true
}
}
}
return unsafe.Pointer(&zeroVal[0]), false
}
func mapassign_fast64(t *maptype, h *hmap, key uint64) unsafe.Pointer {
if h == nil {
panic(plainError("assignment to entry in nil map"))
}
if raceenabled {
callerpc := getcallerpc()
racewritepc(unsafe.Pointer(h), callerpc, abi.FuncPCABIInternal(mapassign_fast64))
}
if h.flags&hashWriting != 0 {
fatal("concurrent map writes")
}
hash := t.hasher(noescape(unsafe.Pointer(&key)), uintptr(h.hash0))
// Set hashWriting after calling t.hasher for consistency with mapassign.
h.flags ^= hashWriting
if h.buckets == nil {
h.buckets = newobject(t.bucket) // newarray(t.bucket, 1)
}
again:
bucket := hash & bucketMask(h.B)
if h.growing() {
growWork_fast64(t, h, bucket)
}
b := (*bmap)(add(h.buckets, bucket*uintptr(t.bucketsize)))
var insertb *bmap
var inserti uintptr
var insertk unsafe.Pointer
bucketloop:
for {
for i := uintptr(0); i < bucketCnt; i++ {
if isEmpty(b.tophash[i]) {
if insertb == nil {
insertb = b
inserti = i
}
if b.tophash[i] == emptyRest {
break bucketloop
}
continue
}
k := *((*uint64)(add(unsafe.Pointer(b), dataOffset+i*8)))
if k != key {
continue
}
insertb = b
inserti = i
goto done
}
ovf := b.overflow(t)
if ovf == nil {
break
}
b = ovf
}
// Did not find mapping for key. Allocate new cell & add entry.
// If we hit the max load factor or we have too many overflow buckets,
// and we're not already in the middle of growing, start growing.
if !h.growing() && (overLoadFactor(h.count+1, h.B) || tooManyOverflowBuckets(h.noverflow, h.B)) {
hashGrow(t, h)
goto again // Growing the table invalidates everything, so try again
}
if insertb == nil {
// The current bucket and all the overflow buckets connected to it are full, allocate a new one.
insertb = h.newoverflow(t, b)
inserti = 0 // not necessary, but avoids needlessly spilling inserti
}
insertb.tophash[inserti&(bucketCnt-1)] = tophash(hash) // mask inserti to avoid bounds checks
insertk = add(unsafe.Pointer(insertb), dataOffset+inserti*8)
// store new key at insert position
*(*uint64)(insertk) = key
h.count++
done:
elem := add(unsafe.Pointer(insertb), dataOffset+bucketCnt*8+inserti*uintptr(t.elemsize))
if h.flags&hashWriting == 0 {
fatal("concurrent map writes")
}
h.flags &^= hashWriting
return elem
}
func mapassign_fast64ptr(t *maptype, h *hmap, key unsafe.Pointer) unsafe.Pointer {
if h == nil {
panic(plainError("assignment to entry in nil map"))
}
if raceenabled {
callerpc := getcallerpc()
racewritepc(unsafe.Pointer(h), callerpc, abi.FuncPCABIInternal(mapassign_fast64))
}
if h.flags&hashWriting != 0 {
fatal("concurrent map writes")
}
hash := t.hasher(noescape(unsafe.Pointer(&key)), uintptr(h.hash0))
// Set hashWriting after calling t.hasher for consistency with mapassign.
h.flags ^= hashWriting
if h.buckets == nil {
h.buckets = newobject(t.bucket) // newarray(t.bucket, 1)
}
again:
bucket := hash & bucketMask(h.B)
if h.growing() {
growWork_fast64(t, h, bucket)
}
b := (*bmap)(add(h.buckets, bucket*uintptr(t.bucketsize)))
var insertb *bmap
var inserti uintptr
var insertk unsafe.Pointer
bucketloop:
for {
for i := uintptr(0); i < bucketCnt; i++ {
if isEmpty(b.tophash[i]) {
if insertb == nil {
insertb = b
inserti = i
}
if b.tophash[i] == emptyRest {
break bucketloop
}
continue
}
k := *((*unsafe.Pointer)(add(unsafe.Pointer(b), dataOffset+i*8)))
if k != key {
continue
}
insertb = b
inserti = i
goto done
}
ovf := b.overflow(t)
if ovf == nil {
break
}
b = ovf
}
// Did not find mapping for key. Allocate new cell & add entry.
// If we hit the max load factor or we have too many overflow buckets,
// and we're not already in the middle of growing, start growing.
if !h.growing() && (overLoadFactor(h.count+1, h.B) || tooManyOverflowBuckets(h.noverflow, h.B)) {
hashGrow(t, h)
goto again // Growing the table invalidates everything, so try again
}
if insertb == nil {
// The current bucket and all the overflow buckets connected to it are full, allocate a new one.
insertb = h.newoverflow(t, b)
inserti = 0 // not necessary, but avoids needlessly spilling inserti
}
insertb.tophash[inserti&(bucketCnt-1)] = tophash(hash) // mask inserti to avoid bounds checks
insertk = add(unsafe.Pointer(insertb), dataOffset+inserti*8)
// store new key at insert position
*(*unsafe.Pointer)(insertk) = key
h.count++
done:
elem := add(unsafe.Pointer(insertb), dataOffset+bucketCnt*8+inserti*uintptr(t.elemsize))
if h.flags&hashWriting == 0 {
fatal("concurrent map writes")
}
h.flags &^= hashWriting
return elem
}
func mapdelete_fast64(t *maptype, h *hmap, key uint64) {
if raceenabled && h != nil {
callerpc := getcallerpc()
racewritepc(unsafe.Pointer(h), callerpc, abi.FuncPCABIInternal(mapdelete_fast64))
}
if h == nil || h.count == 0 {
return
}
if h.flags&hashWriting != 0 {
fatal("concurrent map writes")
}
hash := t.hasher(noescape(unsafe.Pointer(&key)), uintptr(h.hash0))
// Set hashWriting after calling t.hasher for consistency with mapdelete
h.flags ^= hashWriting
bucket := hash & bucketMask(h.B)
if h.growing() {
growWork_fast64(t, h, bucket)
}
b := (*bmap)(add(h.buckets, bucket*uintptr(t.bucketsize)))
bOrig := b
search:
for ; b != nil; b = b.overflow(t) {
for i, k := uintptr(0), b.keys(); i < bucketCnt; i, k = i+1, add(k, 8) {
if key != *(*uint64)(k) || isEmpty(b.tophash[i]) {
continue
}
// Only clear key if there are pointers in it.
if t.key.ptrdata != 0 {
if goarch.PtrSize == 8 {
*(*unsafe.Pointer)(k) = nil
} else {
// There are three ways to squeeze one or more 32 bit pointers into 64 bits.
// Just call memclrHasPointers instead of trying to handle all cases here.
memclrHasPointers(k, 8)
}
}
e := add(unsafe.Pointer(b), dataOffset+bucketCnt*8+i*uintptr(t.elemsize))
if t.elem.ptrdata != 0 {
memclrHasPointers(e, t.elem.size)
} else {
memclrNoHeapPointers(e, t.elem.size)
}
b.tophash[i] = emptyOne
// If the bucket now ends in a bunch of emptyOne states,
// change those to emptyRest states.
if i == bucketCnt-1 {
if b.overflow(t) != nil && b.overflow(t).tophash[0] != emptyRest {
goto notLast
}
} else {
if b.tophash[i+1] != emptyRest {
goto notLast
}
}
for {
b.tophash[i] = emptyRest
if i == 0 {
if b == bOrig {
break // beginning of initial bucket, we're done.
}
// Find previous bucket, continue at its last entry.
c := b
for b = bOrig; b.overflow(t) != c; b = b.overflow(t) {
}
i = bucketCnt - 1
} else {
i--
}
if b.tophash[i] != emptyOne {
break
}
}
notLast:
h.count--
// Reset the hash seed to make it more difficult for attackers to
// repeatedly trigger hash collisions. See issue 25237.
if h.count == 0 {
h.hash0 = fastrand()
}
break search
}
}
if h.flags&hashWriting == 0 {
fatal("concurrent map writes")
}
h.flags &^= hashWriting
}
func growWork_fast64(t *maptype, h *hmap, bucket uintptr) {
// make sure we evacuate the oldbucket corresponding
// to the bucket we're about to use
evacuate_fast64(t, h, bucket&h.oldbucketmask())
// evacuate one more oldbucket to make progress on growing
if h.growing() {
evacuate_fast64(t, h, h.nevacuate)
}
}
func evacuate_fast64(t *maptype, h *hmap, oldbucket uintptr) {
b := (*bmap)(add(h.oldbuckets, oldbucket*uintptr(t.bucketsize)))
newbit := h.noldbuckets()
if !evacuated(b) {
// TODO: reuse overflow buckets instead of using new ones, if there
// is no iterator using the old buckets. (If !oldIterator.)
// xy contains the x and y (low and high) evacuation destinations.
var xy [2]evacDst
x := &xy[0]
x.b = (*bmap)(add(h.buckets, oldbucket*uintptr(t.bucketsize)))
x.k = add(unsafe.Pointer(x.b), dataOffset)
x.e = add(x.k, bucketCnt*8)
if !h.sameSizeGrow() {
// Only calculate y pointers if we're growing bigger.
// Otherwise GC can see bad pointers.
y := &xy[1]
y.b = (*bmap)(add(h.buckets, (oldbucket+newbit)*uintptr(t.bucketsize)))
y.k = add(unsafe.Pointer(y.b), dataOffset)
y.e = add(y.k, bucketCnt*8)
}
for ; b != nil; b = b.overflow(t) {
k := add(unsafe.Pointer(b), dataOffset)
e := add(k, bucketCnt*8)
for i := 0; i < bucketCnt; i, k, e = i+1, add(k, 8), add(e, uintptr(t.elemsize)) {
top := b.tophash[i]
if isEmpty(top) {
b.tophash[i] = evacuatedEmpty
continue
}
if top < minTopHash {
throw("bad map state")
}
var useY uint8
if !h.sameSizeGrow() {
// Compute hash to make our evacuation decision (whether we need
// to send this key/elem to bucket x or bucket y).
hash := t.hasher(k, uintptr(h.hash0))
if hash&newbit != 0 {
useY = 1
}
}
b.tophash[i] = evacuatedX + useY // evacuatedX + 1 == evacuatedY, enforced in makemap
dst := &xy[useY] // evacuation destination
if dst.i == bucketCnt {
dst.b = h.newoverflow(t, dst.b)
dst.i = 0
dst.k = add(unsafe.Pointer(dst.b), dataOffset)
dst.e = add(dst.k, bucketCnt*8)
}
dst.b.tophash[dst.i&(bucketCnt-1)] = top // mask dst.i as an optimization, to avoid a bounds check
// Copy key.
if t.key.ptrdata != 0 && writeBarrier.enabled {
if goarch.PtrSize == 8 {
// Write with a write barrier.
*(*unsafe.Pointer)(dst.k) = *(*unsafe.Pointer)(k)
} else {
// There are three ways to squeeze at least one 32 bit pointer into 64 bits.
// Give up and call typedmemmove.
typedmemmove(t.key, dst.k, k)
}
} else {
*(*uint64)(dst.k) = *(*uint64)(k)
}
typedmemmove(t.elem, dst.e, e)
dst.i++
// These updates might push these pointers past the end of the
// key or elem arrays. That's ok, as we have the overflow pointer
// at the end of the bucket to protect against pointing past the
// end of the bucket.
dst.k = add(dst.k, 8)
dst.e = add(dst.e, uintptr(t.elemsize))
}
}
// Unlink the overflow buckets & clear key/elem to help GC.
if h.flags&oldIterator == 0 && t.bucket.ptrdata != 0 {
b := add(h.oldbuckets, oldbucket*uintptr(t.bucketsize))
// Preserve b.tophash because the evacuation
// state is maintained there.
ptr := add(b, dataOffset)
n := uintptr(t.bucketsize) - dataOffset
memclrHasPointers(ptr, n)
}
}
if oldbucket == h.nevacuate {
advanceEvacuationMark(h, t, newbit)
}
}
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"internal/abi"
"internal/goarch"
"unsafe"
)
func mapaccess1_faststr(t *maptype, h *hmap, ky string) unsafe.Pointer {
if raceenabled && h != nil {
callerpc := getcallerpc()
racereadpc(unsafe.Pointer(h), callerpc, abi.FuncPCABIInternal(mapaccess1_faststr))
}
if h == nil || h.count == 0 {
return unsafe.Pointer(&zeroVal[0])
}
if h.flags&hashWriting != 0 {
fatal("concurrent map read and map write")
}
key := stringStructOf(&ky)
if h.B == 0 {
// One-bucket table.
b := (*bmap)(h.buckets)
if key.len < 32 {
// short key, doing lots of comparisons is ok
for i, kptr := uintptr(0), b.keys(); i < bucketCnt; i, kptr = i+1, add(kptr, 2*goarch.PtrSize) {
k := (*stringStruct)(kptr)
if k.len != key.len || isEmpty(b.tophash[i]) {
if b.tophash[i] == emptyRest {
break
}
continue
}
if k.str == key.str || memequal(k.str, key.str, uintptr(key.len)) {
return add(unsafe.Pointer(b), dataOffset+bucketCnt*2*goarch.PtrSize+i*uintptr(t.elemsize))
}
}
return unsafe.Pointer(&zeroVal[0])
}
// long key, try not to do more comparisons than necessary
keymaybe := uintptr(bucketCnt)
for i, kptr := uintptr(0), b.keys(); i < bucketCnt; i, kptr = i+1, add(kptr, 2*goarch.PtrSize) {
k := (*stringStruct)(kptr)
if k.len != key.len || isEmpty(b.tophash[i]) {
if b.tophash[i] == emptyRest {
break
}
continue
}
if k.str == key.str {
return add(unsafe.Pointer(b), dataOffset+bucketCnt*2*goarch.PtrSize+i*uintptr(t.elemsize))
}
// check first 4 bytes
if *((*[4]byte)(key.str)) != *((*[4]byte)(k.str)) {
continue
}
// check last 4 bytes
if *((*[4]byte)(add(key.str, uintptr(key.len)-4))) != *((*[4]byte)(add(k.str, uintptr(key.len)-4))) {
continue
}
if keymaybe != bucketCnt {
// Two keys are potential matches. Use hash to distinguish them.
goto dohash
}
keymaybe = i
}
if keymaybe != bucketCnt {
k := (*stringStruct)(add(unsafe.Pointer(b), dataOffset+keymaybe*2*goarch.PtrSize))
if memequal(k.str, key.str, uintptr(key.len)) {
return add(unsafe.Pointer(b), dataOffset+bucketCnt*2*goarch.PtrSize+keymaybe*uintptr(t.elemsize))
}
}
return unsafe.Pointer(&zeroVal[0])
}
dohash:
hash := t.hasher(noescape(unsafe.Pointer(&ky)), uintptr(h.hash0))
m := bucketMask(h.B)
b := (*bmap)(add(h.buckets, (hash&m)*uintptr(t.bucketsize)))
if c := h.oldbuckets; c != nil {
if !h.sameSizeGrow() {
// There used to be half as many buckets; mask down one more power of two.
m >>= 1
}
oldb := (*bmap)(add(c, (hash&m)*uintptr(t.bucketsize)))
if !evacuated(oldb) {
b = oldb
}
}
top := tophash(hash)
for ; b != nil; b = b.overflow(t) {
for i, kptr := uintptr(0), b.keys(); i < bucketCnt; i, kptr = i+1, add(kptr, 2*goarch.PtrSize) {
k := (*stringStruct)(kptr)
if k.len != key.len || b.tophash[i] != top {
continue
}
if k.str == key.str || memequal(k.str, key.str, uintptr(key.len)) {
return add(unsafe.Pointer(b), dataOffset+bucketCnt*2*goarch.PtrSize+i*uintptr(t.elemsize))
}
}
}
return unsafe.Pointer(&zeroVal[0])
}
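// Hedged sketch of the cheap pre-checks used above, written with ordinary
// strings instead of stringStruct (illustrative, not the runtime's code):
// compare lengths, then the first and last four bytes, and only then fall
// back to a full comparison.
func exampleFastStrEqual(a, b string) bool {
	if len(a) != len(b) {
		return false
	}
	if len(a) >= 4 && (a[:4] != b[:4] || a[len(a)-4:] != b[len(b)-4:]) {
		return false // a cheap mismatch; no full scan needed
	}
	return a == b
}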
func mapaccess2_faststr(t *maptype, h *hmap, ky string) (unsafe.Pointer, bool) {
if raceenabled && h != nil {
callerpc := getcallerpc()
racereadpc(unsafe.Pointer(h), callerpc, abi.FuncPCABIInternal(mapaccess2_faststr))
}
if h == nil || h.count == 0 {
return unsafe.Pointer(&zeroVal[0]), false
}
if h.flags&hashWriting != 0 {
fatal("concurrent map read and map write")
}
key := stringStructOf(&ky)
if h.B == 0 {
// One-bucket table.
b := (*bmap)(h.buckets)
if key.len < 32 {
// short key, doing lots of comparisons is ok
for i, kptr := uintptr(0), b.keys(); i < bucketCnt; i, kptr = i+1, add(kptr, 2*goarch.PtrSize) {
k := (*stringStruct)(kptr)
if k.len != key.len || isEmpty(b.tophash[i]) {
if b.tophash[i] == emptyRest {
break
}
continue
}
if k.str == key.str || memequal(k.str, key.str, uintptr(key.len)) {
return add(unsafe.Pointer(b), dataOffset+bucketCnt*2*goarch.PtrSize+i*uintptr(t.elemsize)), true
}
}
return unsafe.Pointer(&zeroVal[0]), false
}
// long key, try not to do more comparisons than necessary
keymaybe := uintptr(bucketCnt)
for i, kptr := uintptr(0), b.keys(); i < bucketCnt; i, kptr = i+1, add(kptr, 2*goarch.PtrSize) {
k := (*stringStruct)(kptr)
if k.len != key.len || isEmpty(b.tophash[i]) {
if b.tophash[i] == emptyRest {
break
}
continue
}
if k.str == key.str {
return add(unsafe.Pointer(b), dataOffset+bucketCnt*2*goarch.PtrSize+i*uintptr(t.elemsize)), true
}
// check first 4 bytes
if *((*[4]byte)(key.str)) != *((*[4]byte)(k.str)) {
continue
}
// check last 4 bytes
if *((*[4]byte)(add(key.str, uintptr(key.len)-4))) != *((*[4]byte)(add(k.str, uintptr(key.len)-4))) {
continue
}
if keymaybe != bucketCnt {
// Two keys are potential matches. Use hash to distinguish them.
goto dohash
}
keymaybe = i
}
if keymaybe != bucketCnt {
k := (*stringStruct)(add(unsafe.Pointer(b), dataOffset+keymaybe*2*goarch.PtrSize))
if memequal(k.str, key.str, uintptr(key.len)) {
return add(unsafe.Pointer(b), dataOffset+bucketCnt*2*goarch.PtrSize+keymaybe*uintptr(t.elemsize)), true
}
}
return unsafe.Pointer(&zeroVal[0]), false
}
dohash:
hash := t.hasher(noescape(unsafe.Pointer(&ky)), uintptr(h.hash0))
m := bucketMask(h.B)
b := (*bmap)(add(h.buckets, (hash&m)*uintptr(t.bucketsize)))
if c := h.oldbuckets; c != nil {
if !h.sameSizeGrow() {
// There used to be half as many buckets; mask down one more power of two.
m >>= 1
}
oldb := (*bmap)(add(c, (hash&m)*uintptr(t.bucketsize)))
if !evacuated(oldb) {
b = oldb
}
}
top := tophash(hash)
for ; b != nil; b = b.overflow(t) {
for i, kptr := uintptr(0), b.keys(); i < bucketCnt; i, kptr = i+1, add(kptr, 2*goarch.PtrSize) {
k := (*stringStruct)(kptr)
if k.len != key.len || b.tophash[i] != top {
continue
}
if k.str == key.str || memequal(k.str, key.str, uintptr(key.len)) {
return add(unsafe.Pointer(b), dataOffset+bucketCnt*2*goarch.PtrSize+i*uintptr(t.elemsize)), true
}
}
}
return unsafe.Pointer(&zeroVal[0]), false
}
func mapassign_faststr(t *maptype, h *hmap, s string) unsafe.Pointer {
if h == nil {
panic(plainError("assignment to entry in nil map"))
}
if raceenabled {
callerpc := getcallerpc()
racewritepc(unsafe.Pointer(h), callerpc, abi.FuncPCABIInternal(mapassign_faststr))
}
if h.flags&hashWriting != 0 {
fatal("concurrent map writes")
}
key := stringStructOf(&s)
hash := t.hasher(noescape(unsafe.Pointer(&s)), uintptr(h.hash0))
// Set hashWriting after calling t.hasher for consistency with mapassign.
h.flags ^= hashWriting
if h.buckets == nil {
h.buckets = newobject(t.bucket) // newarray(t.bucket, 1)
}
again:
bucket := hash & bucketMask(h.B)
if h.growing() {
growWork_faststr(t, h, bucket)
}
b := (*bmap)(add(h.buckets, bucket*uintptr(t.bucketsize)))
top := tophash(hash)
var insertb *bmap
var inserti uintptr
var insertk unsafe.Pointer
bucketloop:
for {
for i := uintptr(0); i < bucketCnt; i++ {
if b.tophash[i] != top {
if isEmpty(b.tophash[i]) && insertb == nil {
insertb = b
inserti = i
}
if b.tophash[i] == emptyRest {
break bucketloop
}
continue
}
k := (*stringStruct)(add(unsafe.Pointer(b), dataOffset+i*2*goarch.PtrSize))
if k.len != key.len {
continue
}
if k.str != key.str && !memequal(k.str, key.str, uintptr(key.len)) {
continue
}
// already have a mapping for key. Update it.
inserti = i
insertb = b
// Overwrite existing key, so it can be garbage collected.
// The size is already guaranteed to be set correctly.
k.str = key.str
goto done
}
ovf := b.overflow(t)
if ovf == nil {
break
}
b = ovf
}
// Did not find mapping for key. Allocate new cell & add entry.
// If we hit the max load factor or we have too many overflow buckets,
// and we're not already in the middle of growing, start growing.
if !h.growing() && (overLoadFactor(h.count+1, h.B) || tooManyOverflowBuckets(h.noverflow, h.B)) {
hashGrow(t, h)
goto again // Growing the table invalidates everything, so try again
}
if insertb == nil {
// The current bucket and all the overflow buckets connected to it are full, allocate a new one.
insertb = h.newoverflow(t, b)
inserti = 0 // not necessary, but avoids needlessly spilling inserti
}
insertb.tophash[inserti&(bucketCnt-1)] = top // mask inserti to avoid bounds checks
insertk = add(unsafe.Pointer(insertb), dataOffset+inserti*2*goarch.PtrSize)
// store new key at insert position
*((*stringStruct)(insertk)) = *key
h.count++
done:
elem := add(unsafe.Pointer(insertb), dataOffset+bucketCnt*2*goarch.PtrSize+inserti*uintptr(t.elemsize))
if h.flags&hashWriting == 0 {
fatal("concurrent map writes")
}
h.flags &^= hashWriting
return elem
}
func mapdelete_faststr(t *maptype, h *hmap, ky string) {
if raceenabled && h != nil {
callerpc := getcallerpc()
racewritepc(unsafe.Pointer(h), callerpc, abi.FuncPCABIInternal(mapdelete_faststr))
}
if h == nil || h.count == 0 {
return
}
if h.flags&hashWriting != 0 {
fatal("concurrent map writes")
}
key := stringStructOf(&ky)
hash := t.hasher(noescape(unsafe.Pointer(&ky)), uintptr(h.hash0))
// Set hashWriting after calling t.hasher for consistency with mapdelete
h.flags ^= hashWriting
bucket := hash & bucketMask(h.B)
if h.growing() {
growWork_faststr(t, h, bucket)
}
b := (*bmap)(add(h.buckets, bucket*uintptr(t.bucketsize)))
bOrig := b
top := tophash(hash)
search:
for ; b != nil; b = b.overflow(t) {
for i, kptr := uintptr(0), b.keys(); i < bucketCnt; i, kptr = i+1, add(kptr, 2*goarch.PtrSize) {
k := (*stringStruct)(kptr)
if k.len != key.len || b.tophash[i] != top {
continue
}
if k.str != key.str && !memequal(k.str, key.str, uintptr(key.len)) {
continue
}
// Clear key's pointer.
k.str = nil
e := add(unsafe.Pointer(b), dataOffset+bucketCnt*2*goarch.PtrSize+i*uintptr(t.elemsize))
if t.elem.ptrdata != 0 {
memclrHasPointers(e, t.elem.size)
} else {
memclrNoHeapPointers(e, t.elem.size)
}
b.tophash[i] = emptyOne
// If the bucket now ends in a bunch of emptyOne states,
// change those to emptyRest states.
if i == bucketCnt-1 {
if b.overflow(t) != nil && b.overflow(t).tophash[0] != emptyRest {
goto notLast
}
} else {
if b.tophash[i+1] != emptyRest {
goto notLast
}
}
for {
b.tophash[i] = emptyRest
if i == 0 {
if b == bOrig {
break // beginning of initial bucket, we're done.
}
// Find previous bucket, continue at its last entry.
c := b
for b = bOrig; b.overflow(t) != c; b = b.overflow(t) {
}
i = bucketCnt - 1
} else {
i--
}
if b.tophash[i] != emptyOne {
break
}
}
notLast:
h.count--
// Reset the hash seed to make it more difficult for attackers to
// repeatedly trigger hash collisions. See issue 25237.
if h.count == 0 {
h.hash0 = fastrand()
}
break search
}
}
if h.flags&hashWriting == 0 {
fatal("concurrent map writes")
}
h.flags &^= hashWriting
}
func growWork_faststr(t *maptype, h *hmap, bucket uintptr) {
// make sure we evacuate the oldbucket corresponding
// to the bucket we're about to use
evacuate_faststr(t, h, bucket&h.oldbucketmask())
// evacuate one more oldbucket to make progress on growing
if h.growing() {
evacuate_faststr(t, h, h.nevacuate)
}
}
func evacuate_faststr(t *maptype, h *hmap, oldbucket uintptr) {
b := (*bmap)(add(h.oldbuckets, oldbucket*uintptr(t.bucketsize)))
newbit := h.noldbuckets()
if !evacuated(b) {
// TODO: reuse overflow buckets instead of using new ones, if there
// is no iterator using the old buckets. (If !oldIterator.)
// xy contains the x and y (low and high) evacuation destinations.
var xy [2]evacDst
x := &xy[0]
x.b = (*bmap)(add(h.buckets, oldbucket*uintptr(t.bucketsize)))
x.k = add(unsafe.Pointer(x.b), dataOffset)
x.e = add(x.k, bucketCnt*2*goarch.PtrSize)
if !h.sameSizeGrow() {
// Only calculate y pointers if we're growing bigger.
// Otherwise GC can see bad pointers.
y := &xy[1]
y.b = (*bmap)(add(h.buckets, (oldbucket+newbit)*uintptr(t.bucketsize)))
y.k = add(unsafe.Pointer(y.b), dataOffset)
y.e = add(y.k, bucketCnt*2*goarch.PtrSize)
}
for ; b != nil; b = b.overflow(t) {
k := add(unsafe.Pointer(b), dataOffset)
e := add(k, bucketCnt*2*goarch.PtrSize)
for i := 0; i < bucketCnt; i, k, e = i+1, add(k, 2*goarch.PtrSize), add(e, uintptr(t.elemsize)) {
top := b.tophash[i]
if isEmpty(top) {
b.tophash[i] = evacuatedEmpty
continue
}
if top < minTopHash {
throw("bad map state")
}
var useY uint8
if !h.sameSizeGrow() {
// Compute hash to make our evacuation decision (whether we need
// to send this key/elem to bucket x or bucket y).
hash := t.hasher(k, uintptr(h.hash0))
if hash&newbit != 0 {
useY = 1
}
}
b.tophash[i] = evacuatedX + useY // evacuatedX + 1 == evacuatedY, enforced in makemap
dst := &xy[useY] // evacuation destination
if dst.i == bucketCnt {
dst.b = h.newoverflow(t, dst.b)
dst.i = 0
dst.k = add(unsafe.Pointer(dst.b), dataOffset)
dst.e = add(dst.k, bucketCnt*2*goarch.PtrSize)
}
dst.b.tophash[dst.i&(bucketCnt-1)] = top // mask dst.i as an optimization, to avoid a bounds check
// Copy key.
*(*string)(dst.k) = *(*string)(k)
typedmemmove(t.elem, dst.e, e)
dst.i++
// These updates might push these pointers past the end of the
// key or elem arrays. That's ok, as we have the overflow pointer
// at the end of the bucket to protect against pointing past the
// end of the bucket.
dst.k = add(dst.k, 2*goarch.PtrSize)
dst.e = add(dst.e, uintptr(t.elemsize))
}
}
// Unlink the overflow buckets & clear key/elem to help GC.
if h.flags&oldIterator == 0 && t.bucket.ptrdata != 0 {
b := add(h.oldbuckets, oldbucket*uintptr(t.bucketsize))
// Preserve b.tophash because the evacuation
// state is maintained there.
ptr := add(b, dataOffset)
n := uintptr(t.bucketsize) - dataOffset
memclrHasPointers(ptr, n)
}
}
if oldbucket == h.nevacuate {
advanceEvacuationMark(h, t, newbit)
}
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Garbage collector: write barriers.
//
// For the concurrent garbage collector, the Go compiler implements
// updates to pointer-valued fields that may be in heap objects by
// emitting calls to write barriers. The main write barrier for
// individual pointer writes is gcWriteBarrier and is implemented in
// assembly. This file contains write barrier entry points for bulk
// operations. See also mwbbuf.go.
package runtime
import (
"internal/abi"
"internal/goarch"
"internal/goexperiment"
"unsafe"
)
// Go uses a hybrid barrier that combines a Yuasa-style deletion
// barrier—which shades the object whose reference is being
// overwritten—with a Dijkstra-style insertion barrier—which shades the object
// whose reference is being written. The insertion part of the barrier
// is necessary while the calling goroutine's stack is grey. In
// pseudocode, the barrier is:
//
// writePointer(slot, ptr):
// shade(*slot)
// if current stack is grey:
// shade(ptr)
// *slot = ptr
//
// slot is the destination in Go code.
// ptr is the value that goes into the slot in Go code.
//
// Shade indicates that it has seen a white pointer by adding the referent
// to wbuf as well as marking it.
//
// The two shades and the condition work together to prevent a mutator
// from hiding an object from the garbage collector:
//
// 1. shade(*slot) prevents a mutator from hiding an object by moving
// the sole pointer to it from the heap to its stack. If it attempts
// to unlink an object from the heap, this will shade it.
//
// 2. shade(ptr) prevents a mutator from hiding an object by moving
// the sole pointer to it from its stack into a black object in the
// heap. If it attempts to install the pointer into a black object,
// this will shade it.
//
// 3. Once a goroutine's stack is black, the shade(ptr) becomes
// unnecessary. shade(ptr) prevents hiding an object by moving it from
// the stack to the heap, but this requires first having a pointer
// hidden on the stack. Immediately after a stack is scanned, it only
// points to shaded objects, so it's not hiding anything, and the
// shade(*slot) prevents it from hiding any other pointers on its
// stack.
//
// For a detailed description of this barrier and proof of
// correctness, see https://github.com/golang/proposal/blob/master/design/17503-eliminate-rescan.md
//
//
//
// Dealing with memory ordering:
//
// Both the Yuasa and Dijkstra barriers can be made conditional on the
// color of the object containing the slot. We chose not to make these
// conditional because the cost of ensuring that the object holding
// the slot doesn't concurrently change color without the mutator
// noticing seems prohibitive.
//
// Consider the following example where the mutator writes into
// a slot and then loads the slot's mark bit while the GC thread
// writes to the slot's mark bit and then as part of scanning reads
// the slot.
//
// Initially both [slot] and [slotmark] are 0 (nil)
//
// Mutator thread          GC thread
// st [slot], ptr          st [slotmark], 1
//
// ld r1, [slotmark]       ld r2, [slot]
//
// Without an expensive memory barrier between the st and the ld, the final
// result on most HW (including 386/amd64) can be r1==r2==0. This is a classic
// example of what can happen when loads are allowed to be reordered with older
// stores (avoiding such reorderings lies at the heart of the classic
// Peterson/Dekker algorithms for mutual exclusion). Rather than require memory
// barriers, which will slow down both the mutator and the GC, we always grey
// the ptr object regardless of the slot's color.
//
// Another place where we intentionally omit memory barriers is when
// accessing mheap_.arena_used to check if a pointer points into the
// heap. On relaxed memory machines, it's possible for a mutator to
// extend the size of the heap by updating arena_used, allocate an
// object from this new region, and publish a pointer to that object,
// but for tracing running on another processor to observe the pointer
// but use the old value of arena_used. In this case, tracing will not
// mark the object, even though it's reachable. However, the mutator
// is guaranteed to execute a write barrier when it publishes the
// pointer, so it will take care of marking the object. A general
// consequence of this is that the garbage collector may cache the
// value of mheap_.arena_used. (See issue #9984.)
//
//
// Stack writes:
//
// The compiler omits write barriers for writes to the current frame,
// but if a stack pointer has been passed down the call stack, the
// compiler will generate a write barrier for writes through that
// pointer (because it doesn't know it's not a heap pointer).
//
// One might be tempted to ignore the write barrier if slot points
// into the stack. Don't do it! Mark termination only re-scans
// frames that have potentially been active since the concurrent scan,
// so it depends on write barriers to track changes to pointers in
// stack frames that have not been active.
//
//
// Global writes:
//
// The Go garbage collector requires write barriers when heap pointers
// are stored in globals. Many garbage collectors ignore writes to
// globals and instead pick up global -> heap pointers during
// termination. This increases pause time, so we instead rely on write
// barriers for writes to globals so that we don't have to rescan
// globals during mark termination.
//
//
// Publication ordering:
//
// The write barrier is *pre-publication*, meaning that the write
// barrier happens prior to the *slot = ptr write that may make ptr
// reachable by some goroutine that currently cannot reach it.
//
//
// Signal handler pointer writes:
//
// In general, the signal handler cannot safely invoke the write
// barrier because it may run without a P or even during the write
// barrier.
//
// There is exactly one exception: profbuf.go omits a barrier during
// signal handler profile logging. That's safe only because of the
// deletion barrier. See profbuf.go for a detailed argument. If we
// remove the deletion barrier, we'll have to work out a new way to
// handle the profile logging.
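// A non-authoritative Go rendering of the writePointer pseudocode above;
// shade and the grey-stack flag are hypothetical stand-ins for runtime
// internals, so this is a sketch of the hybrid barrier, not its
// implementation:
func exampleHybridBarrier(slot *unsafe.Pointer, ptr unsafe.Pointer,
	shade func(unsafe.Pointer), stackIsGrey bool) {
	shade(*slot) // Yuasa-style deletion barrier: shade the old referent
	if stackIsGrey {
		shade(ptr) // Dijkstra-style insertion barrier while the stack is grey
	}
	*slot = ptr // the write itself happens after the barrier (pre-publication)
}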
// typedmemmove copies a value of type typ to dst from src.
// Must be nosplit, see #16026.
//
// TODO: Perfect for go:nosplitrec since we can't have a safe point
// anywhere in the bulk barrier or memmove.
//
//go:nosplit
func typedmemmove(typ *_type, dst, src unsafe.Pointer) {
if dst == src {
return
}
if writeBarrier.needed && typ.ptrdata != 0 {
bulkBarrierPreWrite(uintptr(dst), uintptr(src), typ.ptrdata)
}
// There's a race here: if some other goroutine can write to
// src, it may change some pointer in src after we've
// performed the write barrier but before we perform the
// memory copy. This is safe because the write performed by that
// other goroutine must also be accompanied by a write
// barrier, so at worst we've unnecessarily greyed the old
// pointer that was in src.
memmove(dst, src, typ.size)
if goexperiment.CgoCheck2 {
cgoCheckMemmove2(typ, dst, src, 0, typ.size)
}
}
// wbZero performs the write barrier operations necessary before
// zeroing a region of memory at address dst of type typ.
// Does not actually do the zeroing.
//go:nowritebarrierrec
//go:nosplit
func wbZero(typ *_type, dst unsafe.Pointer) {
bulkBarrierPreWrite(uintptr(dst), 0, typ.ptrdata)
}
// wbMove performs the write barrier operations necessary before
// copying a region of memory from src to dst of type typ.
// Does not actually do the copying.
//go:nowritebarrierrec
//go:nosplit
func wbMove(typ *_type, dst, src unsafe.Pointer) {
bulkBarrierPreWrite(uintptr(dst), uintptr(src), typ.ptrdata)
}
//go:linkname reflect_typedmemmove reflect.typedmemmove
func reflect_typedmemmove(typ *_type, dst, src unsafe.Pointer) {
if raceenabled {
raceWriteObjectPC(typ, dst, getcallerpc(), abi.FuncPCABIInternal(reflect_typedmemmove))
raceReadObjectPC(typ, src, getcallerpc(), abi.FuncPCABIInternal(reflect_typedmemmove))
}
if msanenabled {
msanwrite(dst, typ.size)
msanread(src, typ.size)
}
if asanenabled {
asanwrite(dst, typ.size)
asanread(src, typ.size)
}
typedmemmove(typ, dst, src)
}
//go:linkname reflectlite_typedmemmove internal/reflectlite.typedmemmove
func reflectlite_typedmemmove(typ *_type, dst, src unsafe.Pointer) {
reflect_typedmemmove(typ, dst, src)
}
// reflect_typedmemmovepartial is like typedmemmove but assumes that
// dst and src point off bytes into the value and only copies size bytes.
// off must be a multiple of goarch.PtrSize.
//
//go:linkname reflect_typedmemmovepartial reflect.typedmemmovepartial
func reflect_typedmemmovepartial(typ *_type, dst, src unsafe.Pointer, off, size uintptr) {
if writeBarrier.needed && typ.ptrdata > off && size >= goarch.PtrSize {
if off&(goarch.PtrSize-1) != 0 {
panic("reflect: internal error: misaligned offset")
}
pwsize := alignDown(size, goarch.PtrSize)
if poff := typ.ptrdata - off; pwsize > poff {
pwsize = poff
}
bulkBarrierPreWrite(uintptr(dst), uintptr(src), pwsize)
}
memmove(dst, src, size)
if goexperiment.CgoCheck2 {
cgoCheckMemmove2(typ, dst, src, off, size)
}
}
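// Illustrative local copy of the alignDown used above: rounding size down
// to a multiple of the pointer size keeps the bulk barrier on whole words.
// Hypothetical helper; a must be a power of two.
func exampleAlignDown(n, a uintptr) uintptr {
	return n &^ (a - 1) // e.g. exampleAlignDown(13, 8) == 8
}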
// reflectcallmove is invoked by reflectcall to copy the return values
// out of the stack and into the heap, invoking the necessary write
// barriers. dst, src, and size describe the return value area to
// copy. typ describes the entire frame (not just the return values).
// typ may be nil, which indicates write barriers are not needed.
//
// It must be nosplit and must only call nosplit functions because the
// stack map of reflectcall is wrong.
//
//go:nosplit
func reflectcallmove(typ *_type, dst, src unsafe.Pointer, size uintptr, regs *abi.RegArgs) {
if writeBarrier.needed && typ != nil && typ.ptrdata != 0 && size >= goarch.PtrSize {
bulkBarrierPreWrite(uintptr(dst), uintptr(src), size)
}
memmove(dst, src, size)
// Move pointers returned in registers to a place where the GC can see them.
for i := range regs.Ints {
if regs.ReturnIsPtr.Get(i) {
regs.Ptrs[i] = unsafe.Pointer(regs.Ints[i])
}
}
}
//go:nosplit
func typedslicecopy(typ *_type, dstPtr unsafe.Pointer, dstLen int, srcPtr unsafe.Pointer, srcLen int) int {
n := dstLen
if n > srcLen {
n = srcLen
}
if n == 0 {
return 0
}
// The compiler emits calls to typedslicecopy before
// instrumentation runs, so unlike the other copying and
// assignment operations, it's not instrumented in the calling
// code and needs its own instrumentation.
if raceenabled {
callerpc := getcallerpc()
pc := abi.FuncPCABIInternal(slicecopy)
racewriterangepc(dstPtr, uintptr(n)*typ.size, callerpc, pc)
racereadrangepc(srcPtr, uintptr(n)*typ.size, callerpc, pc)
}
if msanenabled {
msanwrite(dstPtr, uintptr(n)*typ.size)
msanread(srcPtr, uintptr(n)*typ.size)
}
if asanenabled {
asanwrite(dstPtr, uintptr(n)*typ.size)
asanread(srcPtr, uintptr(n)*typ.size)
}
if goexperiment.CgoCheck2 {
cgoCheckSliceCopy(typ, dstPtr, srcPtr, n)
}
if dstPtr == srcPtr {
return n
}
// Note: No point in checking typ.ptrdata here:
// the compiler only emits calls to typedslicecopy for types with pointers,
// and growslice and reflect_typedslicecopy check for pointers
// before calling typedslicecopy.
size := uintptr(n) * typ.size
if writeBarrier.needed {
pwsize := size - typ.size + typ.ptrdata
bulkBarrierPreWrite(uintptr(dstPtr), uintptr(srcPtr), pwsize)
}
// See typedmemmove for a discussion of the race between the
// barrier and memmove.
memmove(dstPtr, srcPtr, size)
return n
}
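// Worked example (illustrative): copying n elements of a type with
// typ.size = 24 and typ.ptrdata = 8 (one leading pointer word) gives
// size = n*24 and
//
//	pwsize = n*24 - 24 + 8 = (n-1)*24 + 8
//
// so barriers cover everything through the last element's pointer prefix,
// skipping the final element's 16-byte scalar tail.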
//go:linkname reflect_typedslicecopy reflect.typedslicecopy
func reflect_typedslicecopy(elemType *_type, dst, src slice) int {
if elemType.ptrdata == 0 {
return slicecopy(dst.array, dst.len, src.array, src.len, elemType.size)
}
return typedslicecopy(elemType, dst.array, dst.len, src.array, src.len)
}
// typedmemclr clears the typed memory at ptr with type typ. The
// memory at ptr must already be initialized (and hence in type-safe
// state). If the memory is being initialized for the first time, see
// memclrNoHeapPointers.
//
// If the caller knows that typ has pointers, it can alternatively
// call memclrHasPointers.
//
// TODO: A "go:nosplitrec" annotation would be perfect for this.
//
//go:nosplit
func typedmemclr(typ *_type, ptr unsafe.Pointer) {
if writeBarrier.needed && typ.ptrdata != 0 {
bulkBarrierPreWrite(uintptr(ptr), 0, typ.ptrdata)
}
memclrNoHeapPointers(ptr, typ.size)
}
//go:linkname reflect_typedmemclr reflect.typedmemclr
func reflect_typedmemclr(typ *_type, ptr unsafe.Pointer) {
typedmemclr(typ, ptr)
}
//go:linkname reflect_typedmemclrpartial reflect.typedmemclrpartial
func reflect_typedmemclrpartial(typ *_type, ptr unsafe.Pointer, off, size uintptr) {
if writeBarrier.needed && typ.ptrdata != 0 {
bulkBarrierPreWrite(uintptr(ptr), 0, size)
}
memclrNoHeapPointers(ptr, size)
}
//go:linkname reflect_typedarrayclear reflect.typedarrayclear
func reflect_typedarrayclear(typ *_type, ptr unsafe.Pointer, len int) {
size := typ.size * uintptr(len)
if writeBarrier.needed && typ.ptrdata != 0 {
bulkBarrierPreWrite(uintptr(ptr), 0, size)
}
memclrNoHeapPointers(ptr, size)
}
// memclrHasPointers clears n bytes of typed memory starting at ptr.
// The caller must ensure that the type of the object at ptr has
// pointers, usually by checking typ.ptrdata. However, ptr
// does not have to point to the start of the allocation.
//
//go:nosplit
func memclrHasPointers(ptr unsafe.Pointer, n uintptr) {
bulkBarrierPreWrite(uintptr(ptr), 0, n)
memclrNoHeapPointers(ptr, n)
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Garbage collector: type and heap bitmaps.
//
// Stack, data, and bss bitmaps
//
// Stack frames and global variables in the data and bss sections are
// described by bitmaps with 1 bit per pointer-sized word. A "1" bit
// means the word is a live pointer to be visited by the GC (referred to
// as "pointer"). A "0" bit means the word should be ignored by GC
// (referred to as "scalar", though it could be a dead pointer value).
//
// Heap bitmap
//
// The heap bitmap comprises 1 bit for each pointer-sized word in the heap,
// recording whether a pointer is stored in that word or not. This bitmap
// is stored in the heapArena metadata backing each heap arena.
// That is, if ha is the heapArena for the arena starting at "start",
// then ha.bitmap[0] holds the 64 bits for the 64 words "start"
// through start+63*ptrSize, ha.bitmap[1] holds the entries for
// start+64*ptrSize through start+127*ptrSize, and so on.
// Bits correspond to words in little-endian order. ha.bitmap[0]&1 represents
// the word at "start", ha.bitmap[0]>>1&1 represents the word at start+8, etc.
// (For 32-bit platforms, s/64/32/.)
//
// We also keep a noMorePtrs bitmap which allows us to stop scanning
// the heap bitmap early in certain situations. If ha.noMorePtrs[i]>>j&1
// is 1, then the object containing the last word described by ha.bitmap[8*i+j]
// has no more pointers beyond those described by ha.bitmap[8*i+j].
// If ha.noMorePtrs[i]>>j&1 is set, the entries in ha.bitmap[8*i+j+1] and
// beyond must all be zero until the start of the next object.
//
// The bitmap for noscan spans is set to all zero at span allocation time.
//
// The bitmap for unallocated objects in scannable spans is not maintained
// (can be junk).
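// Worked example (illustrative, 64-bit): for a pointer-aligned address p in
// the arena starting at "start" with heapArena ha,
//
//	word := (p - start) / 8 // pointer-word index within the arena
//	bit := ha.bitmap[word/64] >> (word % 64) & 1
//
// bit == 1 means the word at p holds a pointer; bit == 0 means scalar.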
package runtime
import (
"internal/goarch"
"runtime/internal/atomic"
"runtime/internal/sys"
"unsafe"
)
// addb returns the byte pointer p+n.
//
//go:nowritebarrier
//go:nosplit
func addb(p *byte, n uintptr) *byte {
// Note: wrote out full expression instead of calling add(p, n)
// to reduce the number of temporaries generated by the
// compiler for this trivial expression during inlining.
return (*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(p)) + n))
}
// subtractb returns the byte pointer p-n.
//
//go:nowritebarrier
//go:nosplit
func subtractb(p *byte, n uintptr) *byte {
// Note: wrote out full expression instead of calling add(p, -n)
// to reduce the number of temporaries generated by the
// compiler for this trivial expression during inlining.
return (*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(p)) - n))
}
// add1 returns the byte pointer p+1.
//
//go:nowritebarrier
//go:nosplit
func add1(p *byte) *byte {
// Note: wrote out full expression instead of calling addb(p, 1)
// to reduce the number of temporaries generated by the
// compiler for this trivial expression during inlining.
return (*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(p)) + 1))
}
// subtract1 returns the byte pointer p-1.
//
// nosplit because it is used during write barriers and must not be preempted.
//
//go:nowritebarrier
//go:nosplit
func subtract1(p *byte) *byte {
// Note: wrote out full expression instead of calling subtractb(p, 1)
// to reduce the number of temporaries generated by the
// compiler for this trivial expression during inlining.
return (*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(p)) - 1))
}
// markBits provides access to the mark bit for an object in the heap.
// bytep points to the byte holding the mark bit.
// mask is a byte with a single bit set that can be &ed with *bytep
// to see if the bit has been set.
// *m.bytep&m.mask != 0 indicates the mark bit is set.
// index can be used along with span information to generate
// the address of the object in the heap.
// We maintain one set of mark bits for allocation and one for
// marking purposes.
type markBits struct {
bytep *uint8
mask uint8
index uintptr
}
//go:nosplit
func (s *mspan) allocBitsForIndex(allocBitIndex uintptr) markBits {
bytep, mask := s.allocBits.bitp(allocBitIndex)
return markBits{bytep, mask, allocBitIndex}
}
// refillAllocCache takes 8 bytes of s.allocBits starting at whichByte
// and negates them so that ctz (count trailing zeros) instructions
// can be used. It then places these 8 bytes into the cached 64 bit
// s.allocCache.
func (s *mspan) refillAllocCache(whichByte uintptr) {
bytes := (*[8]uint8)(unsafe.Pointer(s.allocBits.bytep(whichByte)))
aCache := uint64(0)
aCache |= uint64(bytes[0])
aCache |= uint64(bytes[1]) << (1 * 8)
aCache |= uint64(bytes[2]) << (2 * 8)
aCache |= uint64(bytes[3]) << (3 * 8)
aCache |= uint64(bytes[4]) << (4 * 8)
aCache |= uint64(bytes[5]) << (5 * 8)
aCache |= uint64(bytes[6]) << (6 * 8)
aCache |= uint64(bytes[7]) << (7 * 8)
s.allocCache = ^aCache
}
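// Worked example (illustrative): if the 8 bytes read from s.allocBits are
// 0x07 0x00 0x00 0x00 0x00 0x00 0x00 0x00 (objects 0-2 allocated, the rest
// free), then
//
//	aCache  = 0x0000000000000007
//	^aCache = 0xfffffffffffffff8
//
// and sys.TrailingZeros64(s.allocCache) == 3, so the next free index is 3.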
// nextFreeIndex returns the index of the next free object in s at
// or after s.freeindex.
// There are hardware instructions that can be used to make this
// faster if profiling warrants it.
func (s *mspan) nextFreeIndex() uintptr {
sfreeindex := s.freeindex
snelems := s.nelems
if sfreeindex == snelems {
return sfreeindex
}
if sfreeindex > snelems {
throw("s.freeindex > s.nelems")
}
aCache := s.allocCache
bitIndex := sys.TrailingZeros64(aCache)
for bitIndex == 64 {
// Move index to start of next cached bits.
sfreeindex = (sfreeindex + 64) &^ (64 - 1)
if sfreeindex >= snelems {
s.freeindex = snelems
return snelems
}
whichByte := sfreeindex / 8
// Refill s.allocCache with the next 64 alloc bits.
s.refillAllocCache(whichByte)
aCache = s.allocCache
bitIndex = sys.TrailingZeros64(aCache)
// Nothing was available in the cached bits,
// so loop again with the next 8 bytes.
}
result := sfreeindex + uintptr(bitIndex)
if result >= snelems {
s.freeindex = snelems
return snelems
}
s.allocCache >>= uint(bitIndex + 1)
sfreeindex = result + 1
if sfreeindex%64 == 0 && sfreeindex != snelems {
// We just incremented s.freeindex so it isn't 0.
// As each 1 in s.allocCache was encountered and used for allocation
// it was shifted away. At this point s.allocCache contains all 0s.
// Refill s.allocCache so that it corresponds
// to the bits at s.allocBits starting at s.freeindex.
whichByte := sfreeindex / 8
s.refillAllocCache(whichByte)
}
s.freeindex = sfreeindex
return result
}
// isFree reports whether the index'th object in s is unallocated.
//
// The caller must ensure s.state is mSpanInUse, and there must have
// been no preemption points since ensuring this (which could allow a
// GC transition, which would allow the state to change).
func (s *mspan) isFree(index uintptr) bool {
if index < s.freeIndexForScan {
return false
}
bytep, mask := s.allocBits.bitp(index)
return *bytep&mask == 0
}
// divideByElemSize returns n/s.elemsize.
// n must be within [0, s.npages*_PageSize),
// or may be exactly s.npages*_PageSize
// if s.elemsize is from sizeclasses.go.
func (s *mspan) divideByElemSize(n uintptr) uintptr {
const doubleCheck = false
// See explanation in mksizeclasses.go's computeDivMagic.
q := uintptr((uint64(n) * uint64(s.divMul)) >> 32)
if doubleCheck && q != n/s.elemsize {
println(n, "/", s.elemsize, "should be", n/s.elemsize, "but got", q)
throw("bad magic division")
}
return q
}
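// Worked example (illustrative): for elemsize 48, mksizeclasses.go computes
// divMul = ceil(2^32/48) = 89478486, so for n = 96:
//
//	q = (96 * 89478486) >> 32 = 2
//
// which matches 96/48 exactly, without a hardware divide.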
func (s *mspan) objIndex(p uintptr) uintptr {
return s.divideByElemSize(p - s.base())
}
func markBitsForAddr(p uintptr) markBits {
s := spanOf(p)
objIndex := s.objIndex(p)
return s.markBitsForIndex(objIndex)
}
func (s *mspan) markBitsForIndex(objIndex uintptr) markBits {
bytep, mask := s.gcmarkBits.bitp(objIndex)
return markBits{bytep, mask, objIndex}
}
func (s *mspan) markBitsForBase() markBits {
return markBits{&s.gcmarkBits.x, uint8(1), 0}
}
// isMarked reports whether mark bit m is set.
func (m markBits) isMarked() bool {
return *m.bytep&m.mask != 0
}
// setMarked sets the marked bit in the markbits, atomically.
func (m markBits) setMarked() {
// Might be racing with other updates, so use atomic update always.
// We used to be clever here and use a non-atomic update in certain
// cases, but it's not worth the risk.
atomic.Or8(m.bytep, m.mask)
}
// setMarkedNonAtomic sets the marked bit in the markbits, non-atomically.
func (m markBits) setMarkedNonAtomic() {
*m.bytep |= m.mask
}
// clearMarked clears the marked bit in the markbits, atomically.
func (m markBits) clearMarked() {
// Might be racing with other updates, so use atomic update always.
// We used to be clever here and use a non-atomic update in certain
// cases, but it's not worth the risk.
atomic.And8(m.bytep, ^m.mask)
}
// markBitsForSpan returns the markBits for the span base address base.
func markBitsForSpan(base uintptr) (mbits markBits) {
mbits = markBitsForAddr(base)
if mbits.mask != 1 {
throw("markBitsForSpan: unaligned start")
}
return mbits
}
// advance advances the markBits to the next object in the span.
func (m *markBits) advance() {
if m.mask == 1<<7 {
m.bytep = (*uint8)(unsafe.Pointer(uintptr(unsafe.Pointer(m.bytep)) + 1))
m.mask = 1
} else {
m.mask = m.mask << 1
}
m.index++
}
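// For illustration only (a sketch, not part of the runtime): walking the
// mark bits of every object in a span s, in the style of the sweeper:
//
//	mbits := s.markBitsForBase()
//	for i := uintptr(0); i < s.nelems; i++ {
//		if mbits.isMarked() {
//			// object i is marked live this cycle
//		}
//		mbits.advance()
//	}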
// clobberdeadPtr is a special value that is used by the compiler to
// clobber dead stack slots, when -clobberdead flag is set.
const clobberdeadPtr = uintptr(0xdeaddead | 0xdeaddead<<((^uintptr(0)>>63)*32))
// badPointer prints diagnostics for the bad pointer p and throws.
func badPointer(s *mspan, p, refBase, refOff uintptr) {
// Typically this indicates an incorrect use
// of unsafe or cgo to store a bad pointer in
// the Go heap. It may also indicate a runtime
// bug.
//
// TODO(austin): We could be more aggressive
// and detect pointers to unallocated objects
// in allocated spans.
printlock()
print("runtime: pointer ", hex(p))
if s != nil {
state := s.state.get()
if state != mSpanInUse {
print(" to unallocated span")
} else {
print(" to unused region of span")
}
print(" span.base()=", hex(s.base()), " span.limit=", hex(s.limit), " span.state=", state)
}
print("\n")
if refBase != 0 {
print("runtime: found in object at *(", hex(refBase), "+", hex(refOff), ")\n")
gcDumpObject("object", refBase, refOff)
}
getg().m.traceback = 2
throw("found bad pointer in Go heap (incorrect use of unsafe or cgo?)")
}
// findObject returns the base address for the heap object containing
// the address p, the object's span, and the index of the object in s.
// If p does not point into a heap object, it returns base == 0.
//
// If p is an invalid heap pointer and debug.invalidptr != 0,
// findObject throws.
//
// refBase and refOff optionally give the base address of the object
// in which the pointer p was found and the byte offset at which it
// was found. These are used for error reporting.
//
// It is nosplit so it is safe for p to be a pointer to the current goroutine's stack.
// Since p is a uintptr, it would not be adjusted if the stack were to move.
//
//go:nosplit
func findObject(p, refBase, refOff uintptr) (base uintptr, s *mspan, objIndex uintptr) {
s = spanOf(p)
// If s is nil, the virtual address has never been part of the heap.
// This pointer may be to some mmap'd region, so we allow it.
if s == nil {
if (GOARCH == "amd64" || GOARCH == "arm64") && p == clobberdeadPtr && debug.invalidptr != 0 {
// Crash if clobberdeadPtr is seen. Only on AMD64 and ARM64 for now,
// as they are the only platform where compiler's clobberdead mode is
// implemented. On these platforms clobberdeadPtr cannot be a valid address.
badPointer(s, p, refBase, refOff)
}
return
}
// If p is a bad pointer, it may not be in s's bounds.
//
// Check s.state to synchronize with span initialization
// before checking other fields. See also spanOfHeap.
if state := s.state.get(); state != mSpanInUse || p < s.base() || p >= s.limit {
// Pointers into stacks are also ok, the runtime manages these explicitly.
if state == mSpanManual {
return
}
// The following ensures that we are rigorous about what data
// structures hold valid pointers.
if debug.invalidptr != 0 {
badPointer(s, p, refBase, refOff)
}
return
}
objIndex = s.objIndex(p)
base = s.base() + objIndex*s.elemsize
return
}
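// For illustration only (a sketch, not part of the runtime): resolving an
// interior pointer p to its enclosing object, as getgcmask does below:
//
//	if base, s, _ := findObject(p, 0, 0); base != 0 {
//		// p points into the heap object [base, base+s.elemsize).
//	}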
// reflect_verifyNotInHeapPtr reports whether converting the not-in-heap pointer into an unsafe.Pointer is ok.
//
//go:linkname reflect_verifyNotInHeapPtr reflect.verifyNotInHeapPtr
func reflect_verifyNotInHeapPtr(p uintptr) bool {
// Conversion to a pointer is ok as long as findObject above does not call badPointer.
// Since we're already promised that p doesn't point into the heap, just disallow heap
// pointers and the special clobbered pointer.
return spanOf(p) == nil && p != clobberdeadPtr
}
const ptrBits = 8 * goarch.PtrSize
// heapBits provides access to the bitmap bits for a single heap word.
// The methods on heapBits take value receivers so that the compiler
// can more easily inline calls to those methods and registerize the
// struct fields independently.
type heapBits struct {
// heapBits will report on pointers in the range [addr,addr+size).
// The low bit of mask contains the pointerness of the word at addr
// (assuming valid>0).
addr, size uintptr
// The next few pointer bits representing words starting at addr.
// Those bits already returned by next() are zeroed.
mask uintptr
// Number of bits in mask that are valid. mask is always less than 1<<valid.
valid uintptr
}
// heapBitsForAddr returns the heapBits for the address addr.
// The caller must ensure [addr,addr+size) is in an allocated span.
// In particular, be careful not to point past the end of an object.
//
// nosplit because it is used during write barriers and must not be preempted.
//
//go:nosplit
func heapBitsForAddr(addr, size uintptr) heapBits {
// Find arena
ai := arenaIndex(addr)
ha := mheap_.arenas[ai.l1()][ai.l2()]
// Word index in arena.
word := addr / goarch.PtrSize % heapArenaWords
// Word index and bit offset in bitmap array.
idx := word / ptrBits
off := word % ptrBits
// Grab relevant bits of bitmap.
mask := ha.bitmap[idx] >> off
valid := ptrBits - off
// Process depending on where the object ends.
nptr := size / goarch.PtrSize
if nptr < valid {
// Bits for this object end before the end of this bitmap word.
// Squash bits for the following objects.
mask &= 1<<(nptr&(ptrBits-1)) - 1
valid = nptr
} else if nptr == valid {
// Bits for this object end at exactly the end of this bitmap word.
// All good.
} else {
// Bits for this object extend into the next bitmap word. See if there
// may be any pointers recorded there.
if uintptr(ha.noMorePtrs[idx/8])>>(idx%8)&1 != 0 {
// No more pointers in this object after this bitmap word.
// Update size so we know not to look there.
size = valid * goarch.PtrSize
}
}
return heapBits{addr: addr, size: size, mask: mask, valid: valid}
}
// next returns the (absolute) address of the next known pointer and
// a heapBits iterator representing any remaining pointers.
// If there are no more pointers, returns address 0.
// Note that next does not modify h. The caller must record the result.
//
// nosplit because it is used during write barriers and must not be preempted.
//
//go:nosplit
func (h heapBits) next() (heapBits, uintptr) {
for {
if h.mask != 0 {
var i int
if goarch.PtrSize == 8 {
i = sys.TrailingZeros64(uint64(h.mask))
} else {
i = sys.TrailingZeros32(uint32(h.mask))
}
h.mask ^= uintptr(1) << (i & (ptrBits - 1))
return h, h.addr + uintptr(i)*goarch.PtrSize
}
// Skip words that we've already processed.
h.addr += h.valid * goarch.PtrSize
h.size -= h.valid * goarch.PtrSize
if h.size == 0 {
return h, 0 // no more pointers
}
// Grab more bits and try again.
h = heapBitsForAddr(h.addr, h.size)
}
}
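// For illustration only (a sketch, not part of the runtime): enumerating
// every pointer slot of an object at base with size bytes, as getgcmask
// does below:
//
//	h := heapBitsForAddr(base, size)
//	for {
//		var addr uintptr
//		if h, addr = h.next(); addr == 0 {
//			break
//		}
//		// the word at addr is a pointer slot
//	}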
// nextFast is like next, but can return 0 even when there are more pointers
// to be found. Callers should call next if nextFast returns 0 as its second
// return value.
//
// if h, addr = h.nextFast(); addr == 0 {
// if h, addr = h.next(); addr == 0 {
// ... no more pointers ...
// }
// }
// ... process pointer at addr ...
//
// nextFast is designed to be inlineable.
//
//go:nosplit
func (h heapBits) nextFast() (heapBits, uintptr) {
// TESTQ/JEQ
if h.mask == 0 {
return h, 0
}
// BSFQ
var i int
if goarch.PtrSize == 8 {
i = sys.TrailingZeros64(uint64(h.mask))
} else {
i = sys.TrailingZeros32(uint32(h.mask))
}
// BTCQ
h.mask ^= uintptr(1) << (i & (ptrBits - 1))
// LEAQ (XX)(XX*8)
return h, h.addr + uintptr(i)*goarch.PtrSize
}
// bulkBarrierPreWrite executes a write barrier
// for every pointer slot in the memory range [src, src+size),
// using pointer/scalar information from [dst, dst+size).
// This executes the write barriers necessary before a memmove.
// src, dst, and size must be pointer-aligned.
// The range [dst, dst+size) must lie within a single object.
// It does not perform the actual writes.
//
// As a special case, src == 0 indicates that this is being used for a
// memclr. bulkBarrierPreWrite will pass 0 for the src of each write
// barrier.
//
// Callers should call bulkBarrierPreWrite immediately before
// calling memmove(dst, src, size). This function is marked nosplit
// to avoid being preempted; the GC must not stop the goroutine
// between the memmove and the execution of the barriers.
// The caller is also responsible for cgo pointer checks if this
// may be writing Go pointers into non-Go memory.
//
// The pointer bitmap is not maintained for allocations containing
// no pointers at all; any caller of bulkBarrierPreWrite must first
// make sure the underlying allocation contains pointers, usually
// by checking typ.ptrdata.
//
// Callers must perform cgo checks if goexperiment.CgoCheck2.
//
//go:nosplit
func bulkBarrierPreWrite(dst, src, size uintptr) {
if (dst|src|size)&(goarch.PtrSize-1) != 0 {
throw("bulkBarrierPreWrite: unaligned arguments")
}
if !writeBarrier.needed {
return
}
if s := spanOf(dst); s == nil {
// If dst is a global, use the data or BSS bitmaps to
// execute write barriers.
for _, datap := range activeModules() {
if datap.data <= dst && dst < datap.edata {
bulkBarrierBitmap(dst, src, size, dst-datap.data, datap.gcdatamask.bytedata)
return
}
}
for _, datap := range activeModules() {
if datap.bss <= dst && dst < datap.ebss {
bulkBarrierBitmap(dst, src, size, dst-datap.bss, datap.gcbssmask.bytedata)
return
}
}
return
} else if s.state.get() != mSpanInUse || dst < s.base() || s.limit <= dst {
// dst was heap memory at some point, but isn't now.
// It can't be a global. It must be either our stack,
// or in the case of direct channel sends, it could be
// another stack. Either way, no need for barriers.
// This will also catch dst being in a freed span,
// though that should never happen.
return
}
buf := &getg().m.p.ptr().wbBuf
h := heapBitsForAddr(dst, size)
if src == 0 {
for {
var addr uintptr
if h, addr = h.next(); addr == 0 {
break
}
dstx := (*uintptr)(unsafe.Pointer(addr))
p := buf.get1()
p[0] = *dstx
}
} else {
for {
var addr uintptr
if h, addr = h.next(); addr == 0 {
break
}
dstx := (*uintptr)(unsafe.Pointer(addr))
srcx := (*uintptr)(unsafe.Pointer(src + (addr - dst)))
p := buf.get2()
p[0] = *dstx
p[1] = *srcx
}
}
}
// bulkBarrierPreWriteSrcOnly is like bulkBarrierPreWrite but
// does not execute write barriers for [dst, dst+size).
//
// In addition to the requirements of bulkBarrierPreWrite
// callers need to ensure [dst, dst+size) is zeroed.
//
// This is used for special cases where e.g. dst was just
// created and zeroed with malloc.
//
//go:nosplit
func bulkBarrierPreWriteSrcOnly(dst, src, size uintptr) {
if (dst|src|size)&(goarch.PtrSize-1) != 0 {
throw("bulkBarrierPreWrite: unaligned arguments")
}
if !writeBarrier.needed {
return
}
buf := &getg().m.p.ptr().wbBuf
h := heapBitsForAddr(dst, size)
for {
var addr uintptr
if h, addr = h.next(); addr == 0 {
break
}
srcx := (*uintptr)(unsafe.Pointer(addr - dst + src))
p := buf.get1()
p[0] = *srcx
}
}
// bulkBarrierBitmap executes write barriers for copying from [src,
// src+size) to [dst, dst+size) using a 1-bit pointer bitmap. src is
// assumed to start maskOffset bytes into the data covered by the
// bitmap in bits (which may not be a multiple of 8).
//
// This is used by bulkBarrierPreWrite for writes to data and BSS.
//
//go:nosplit
func bulkBarrierBitmap(dst, src, size, maskOffset uintptr, bits *uint8) {
word := maskOffset / goarch.PtrSize
bits = addb(bits, word/8)
mask := uint8(1) << (word % 8)
buf := &getg().m.p.ptr().wbBuf
for i := uintptr(0); i < size; i += goarch.PtrSize {
if mask == 0 {
bits = addb(bits, 1)
if *bits == 0 {
// Skip 8 words.
i += 7 * goarch.PtrSize
continue
}
mask = 1
}
if *bits&mask != 0 {
dstx := (*uintptr)(unsafe.Pointer(dst + i))
if src == 0 {
p := buf.get1()
p[0] = *dstx
} else {
srcx := (*uintptr)(unsafe.Pointer(src + i))
p := buf.get2()
p[0] = *dstx
p[1] = *srcx
}
}
mask <<= 1
}
}
// typeBitsBulkBarrier executes a write barrier for every
// pointer that would be copied from [src, src+size) to [dst,
// dst+size) by a memmove using the type bitmap to locate those
// pointer slots.
//
// The type typ must correspond exactly to [src, src+size) and [dst, dst+size).
// dst, src, and size must be pointer-aligned.
// The type typ must have a plain bitmap, not a GC program.
// The only use of this function is in channel sends, and the
// 64 kB channel element limit takes care of this for us.
//
// Must not be preempted because it typically runs right before memmove,
// and the GC must observe the barriers and the memmove as a single atomic action.
//
// Callers must perform cgo checks if goexperiment.CgoCheck2.
//
//go:nosplit
func typeBitsBulkBarrier(typ *_type, dst, src, size uintptr) {
if typ == nil {
throw("runtime: typeBitsBulkBarrier without type")
}
if typ.size != size {
println("runtime: typeBitsBulkBarrier with type ", typ.string(), " of size ", typ.size, " but memory size", size)
throw("runtime: invalid typeBitsBulkBarrier")
}
if typ.kind&kindGCProg != 0 {
println("runtime: typeBitsBulkBarrier with type ", typ.string(), " with GC prog")
throw("runtime: invalid typeBitsBulkBarrier")
}
if !writeBarrier.needed {
return
}
ptrmask := typ.gcdata
buf := &getg().m.p.ptr().wbBuf
var bits uint32
for i := uintptr(0); i < typ.ptrdata; i += goarch.PtrSize {
if i&(goarch.PtrSize*8-1) == 0 {
bits = uint32(*ptrmask)
ptrmask = addb(ptrmask, 1)
} else {
bits = bits >> 1
}
if bits&1 != 0 {
dstx := (*uintptr)(unsafe.Pointer(dst + i))
srcx := (*uintptr)(unsafe.Pointer(src + i))
p := buf.get2()
p[0] = *dstx
p[1] = *srcx
}
}
}
// initHeapBits initializes the heap bitmap for a span.
// If this is a span of single pointer allocations, it initializes all
// words to pointer. If forceClear is true, it clears all bits.
func (s *mspan) initHeapBits(forceClear bool) {
if forceClear || s.spanclass.noscan() {
// Set all the pointer bits to zero. We do this once
// when the span is allocated so we don't have to do it
// for each object allocation.
base := s.base()
size := s.npages * pageSize
h := writeHeapBitsForAddr(base)
h.flush(base, size)
return
}
isPtrs := goarch.PtrSize == 8 && s.elemsize == goarch.PtrSize
if !isPtrs {
return // nothing to do
}
h := writeHeapBitsForAddr(s.base())
size := s.npages * pageSize
nptrs := size / goarch.PtrSize
for i := uintptr(0); i < nptrs; i += ptrBits {
h = h.write(^uintptr(0), ptrBits)
}
h.flush(s.base(), size)
}
// countAlloc returns the number of objects allocated in span s by
// scanning the allocation bitmap.
func (s *mspan) countAlloc() int {
count := 0
bytes := divRoundUp(s.nelems, 8)
// Iterate over each 8-byte chunk and count allocations
// with an intrinsic. Note that newMarkBits guarantees that
// gcmarkBits will be 8-byte aligned, so we don't have to
// worry about edge cases; irrelevant bits will simply be zero.
for i := uintptr(0); i < bytes; i += 8 {
// Extract 64 bits from the byte pointer and get a OnesCount.
// Note that the unsafe cast here doesn't preserve endianness,
// but that's OK. We only care about how many bits are 1, not
// about the order we discover them in.
mrkBits := *(*uint64)(unsafe.Pointer(s.gcmarkBits.bytep(i)))
count += sys.OnesCount64(mrkBits)
}
return count
}
type writeHeapBits struct {
addr uintptr // address that the low bit of mask represents the pointer state of.
mask uintptr // some pointer bits starting at the address addr.
valid uintptr // number of bits in buf that are valid (including low)
low uintptr // number of low-order bits to not overwrite
}
func writeHeapBitsForAddr(addr uintptr) (h writeHeapBits) {
// We may start writing bits in the middle of a heap bitmap word.
// Remember how many bits into the word we started, so we can be sure
// not to overwrite the previous bits.
h.low = addr / goarch.PtrSize % ptrBits
// round down to heap word that starts the bitmap word.
h.addr = addr - h.low*goarch.PtrSize
// We don't have any bits yet.
h.mask = 0
h.valid = h.low
return
}
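// Worked example (illustrative, 64-bit): for an addr whose pointer-word
// index is 69, the bits land 5 bits into a bitmap word:
//
//	h.low   = 69 % 64 = 5  // bits 0-4 belong to earlier heap words
//	h.addr  = addr - 5*8   // heap address backing bit 0 of that bitmap word
//	h.valid = 5            // treat the 5 preserved bits as already written
//
// Subsequent write calls accumulate new bits starting at bit 5, and the
// masks in write and flush leave bits 0-4 untouched.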
// write appends the pointerness of the next valid pointer slots
// using the low valid bits of bits. 1=pointer, 0=scalar.
func (h writeHeapBits) write(bits, valid uintptr) writeHeapBits {
if h.valid+valid <= ptrBits {
// Fast path - just accumulate the bits.
h.mask |= bits << h.valid
h.valid += valid
return h
}
// Too many bits to fit in this word. Write the current word
// out and move on to the next word.
data := h.mask | bits<<h.valid // mask for this word
h.mask = bits >> (ptrBits - h.valid) // leftover for next word
h.valid += valid - ptrBits // have h.valid+valid bits, writing ptrBits of them
// Flush mask to the memory bitmap.
// TODO: figure out how to cache arena lookup.
ai := arenaIndex(h.addr)
ha := mheap_.arenas[ai.l1()][ai.l2()]
idx := h.addr / (ptrBits * goarch.PtrSize) % heapArenaBitmapWords
m := uintptr(1)<<h.low - 1
ha.bitmap[idx] = ha.bitmap[idx]&m | data
// Note: no synchronization required for this write because
// the allocator has exclusive access to the page, and the bitmap
// entries are all for a single page. Also, visibility of these
// writes is guaranteed by the publication barrier in mallocgc.
// Clear noMorePtrs bit, since we're going to be writing bits
// into the following word.
ha.noMorePtrs[idx/8] &^= uint8(1) << (idx % 8)
// Note: same as above
// Move to next word of bitmap.
h.addr += ptrBits * goarch.PtrSize
h.low = 0
return h
}
// pad adds padding of size bytes.
func (h writeHeapBits) pad(size uintptr) writeHeapBits {
if size == 0 {
return h
}
words := size / goarch.PtrSize
for words > ptrBits {
h = h.write(0, ptrBits)
words -= ptrBits
}
return h.write(0, words)
}
// flush writes out the bits that have been accumulated, and adds zeros as
// needed to cover the full object [addr, addr+size).
func (h writeHeapBits) flush(addr, size uintptr) {
// zeros counts the number of bits needed to represent the object minus the
// number of bits we've already written. This is the number of 0 bits
// that need to be added.
zeros := (addr+size-h.addr)/goarch.PtrSize - h.valid
// Add zero bits up to the bitmap word boundary
if zeros > 0 {
z := ptrBits - h.valid
if z > zeros {
z = zeros
}
h.valid += z
zeros -= z
}
// Find word in bitmap that we're going to write.
ai := arenaIndex(h.addr)
ha := mheap_.arenas[ai.l1()][ai.l2()]
idx := h.addr / (ptrBits * goarch.PtrSize) % heapArenaBitmapWords
// Write remaining bits.
if h.valid != h.low {
m := uintptr(1)<<h.low - 1 // don't clear existing bits below "low"
m |= ^(uintptr(1)<<h.valid - 1) // don't clear existing bits above "valid"
ha.bitmap[idx] = ha.bitmap[idx]&m | h.mask
}
if zeros == 0 {
return
}
// Record in the noMorePtrs map that there won't be any more 1 bits,
// so readers can stop early.
ha.noMorePtrs[idx/8] |= uint8(1) << (idx % 8)
// Advance to next bitmap word.
h.addr += ptrBits * goarch.PtrSize
// Continue on writing zeros for the rest of the object.
// For standard use of the ptr bits this is not required, as
// the bits are read from the beginning of the object. Some uses,
// like noscan spans, oblets, bulk write barriers, and cgocheck, might
// start mid-object, so these writes are still required.
for {
// Write zero bits.
ai := arenaIndex(h.addr)
ha := mheap_.arenas[ai.l1()][ai.l2()]
idx := h.addr / (ptrBits * goarch.PtrSize) % heapArenaBitmapWords
if zeros < ptrBits {
ha.bitmap[idx] &^= uintptr(1)<<zeros - 1
break
} else if zeros == ptrBits {
ha.bitmap[idx] = 0
break
} else {
ha.bitmap[idx] = 0
zeros -= ptrBits
}
ha.noMorePtrs[idx/8] |= uint8(1) << (idx % 8)
h.addr += ptrBits * goarch.PtrSize
}
}
// readUintptr reads the bytes starting at the aligned pointer p into a uintptr.
// The read is little-endian.
func readUintptr(p *byte) uintptr {
x := *(*uintptr)(unsafe.Pointer(p))
if goarch.BigEndian {
if goarch.PtrSize == 8 {
return uintptr(sys.Bswap64(uint64(x)))
}
return uintptr(sys.Bswap32(uint32(x)))
}
return x
}
// heapBitsSetType records that the new allocation [x, x+size)
// holds in [x, x+dataSize) one or more values of type typ.
// (The number of values is given by dataSize / typ.size.)
// If dataSize < size, the fragment [x+dataSize, x+size) is
// recorded as non-pointer data.
// It is known that the type has pointers somewhere;
// malloc does not call heapBitsSetType when there are no pointers,
// because all free objects are marked as noscan during
// heapBitsSweepSpan.
//
// There can only be one allocation from a given span active at a time,
// and the bitmap for a span always falls on word boundaries,
// so there are no write-write races for access to the heap bitmap.
// Hence, heapBitsSetType can access the bitmap without atomics.
//
// There can be read-write races between heapBitsSetType and things
// that read the heap bitmap like scanobject. However, since
// heapBitsSetType is only used for objects that have not yet been
// made reachable, readers will ignore bits being modified by this
// function. This does mean this function cannot transiently modify
// bits that belong to neighboring objects. Also, on weakly-ordered
// machines, callers must execute a store/store (publication) barrier
// between calling this function and making the object reachable.
func heapBitsSetType(x, size, dataSize uintptr, typ *_type) {
const doubleCheck = false // slow but helpful; enable to test modifications to this code
if doubleCheck && dataSize%typ.size != 0 {
throw("heapBitsSetType: dataSize not a multiple of typ.size")
}
if goarch.PtrSize == 8 && size == goarch.PtrSize {
// It's one word and it has pointers, so it must be a pointer.
// Since all allocated one-word objects are pointers
// (non-pointers are aggregated into tinySize allocations),
// (*mspan).initHeapBits sets the pointer bits for us.
// Nothing to do here.
if doubleCheck {
h, addr := heapBitsForAddr(x, size).next()
if addr != x {
throw("heapBitsSetType: pointer bit missing")
}
_, addr = h.next()
if addr != 0 {
throw("heapBitsSetType: second pointer bit found")
}
}
return
}
h := writeHeapBitsForAddr(x)
// Handle GC program.
if typ.kind&kindGCProg != 0 {
// Expand the gc program into the storage we're going to use for the actual object.
obj := (*uint8)(unsafe.Pointer(x))
n := runGCProg(addb(typ.gcdata, 4), obj)
// Use the expanded program to set the heap bits.
for i := uintptr(0); true; i += typ.size {
// Copy expanded program to heap bitmap.
p := obj
j := n
for j > 8 {
h = h.write(uintptr(*p), 8)
p = add1(p)
j -= 8
}
h = h.write(uintptr(*p), j)
if i+typ.size == dataSize {
break // no padding after last element
}
// Pad with zeros to the start of the next element.
h = h.pad(typ.size - n*goarch.PtrSize)
}
h.flush(x, size)
// Erase the expanded GC program.
memclrNoHeapPointers(unsafe.Pointer(obj), (n+7)/8)
return
}
// Note about sizes:
//
// typ.size is the number of words in the object,
// and typ.ptrdata is the number of words in the prefix
// of the object that contains pointers. That is, the final
// typ.size - typ.ptrdata words contain no pointers.
// This allows optimization of a common pattern where
// an object has a small header followed by a large scalar
// buffer. If we know the pointers are over, we don't have
// to scan the buffer's heap bitmap at all.
// The 1-bit ptrmasks are sized to contain only bits for
// the typ.ptrdata prefix, zero padded out to a full byte
// of bitmap. If there is more room in the allocated object,
// that space is pointerless. The noMorePtrs bitmap will prevent
// scanning large pointerless tails of an object.
//
// Replicated copies are not as nice: if there is an array of
// objects with scalar tails, all but the last tail do have to
// be initialized, because there is no way to say "skip forward".
ptrs := typ.ptrdata / goarch.PtrSize
if typ.size == dataSize { // Single element
if ptrs <= ptrBits { // Single small element
m := readUintptr(typ.gcdata)
h = h.write(m, ptrs)
} else { // Single large element
p := typ.gcdata
for {
h = h.write(readUintptr(p), ptrBits)
p = addb(p, ptrBits/8)
ptrs -= ptrBits
if ptrs <= ptrBits {
break
}
}
m := readUintptr(p)
h = h.write(m, ptrs)
}
} else { // Repeated element
words := typ.size / goarch.PtrSize // total words, including scalar tail
if words <= ptrBits { // Repeated small element
n := dataSize / typ.size
m := readUintptr(typ.gcdata)
// Make larger unit to repeat
for words <= ptrBits/2 {
if n&1 != 0 {
h = h.write(m, words)
}
n /= 2
m |= m << words
ptrs += words
words *= 2
if n == 1 {
break
}
}
for n > 1 {
h = h.write(m, words)
n--
}
h = h.write(m, ptrs)
} else { // Repeated large element
for i := uintptr(0); true; i += typ.size {
p := typ.gcdata
j := ptrs
for j > ptrBits {
h = h.write(readUintptr(p), ptrBits)
p = addb(p, ptrBits/8)
j -= ptrBits
}
m := readUintptr(p)
h = h.write(m, j)
if i+typ.size == dataSize {
break // don't need the trailing nonptr bits on the last element.
}
// Pad with zeros to the start of the next element.
h = h.pad(typ.size - typ.ptrdata)
}
}
}
h.flush(x, size)
if doubleCheck {
h := heapBitsForAddr(x, size)
for i := uintptr(0); i < size; i += goarch.PtrSize {
// Compute the pointer bit we want at offset i.
want := false
if i < dataSize {
off := i % typ.size
if off < typ.ptrdata {
j := off / goarch.PtrSize
want = *addb(typ.gcdata, j/8)>>(j%8)&1 != 0
}
}
if want {
var addr uintptr
h, addr = h.next()
if addr != x+i {
throw("heapBitsSetType: pointer entry not correct")
}
}
}
if _, addr := h.next(); addr != 0 {
throw("heapBitsSetType: extra pointer")
}
}
}
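// Worked example (illustrative) of the doubling loop in the repeated small
// element case above: a 2-word element with one leading pointer word
// (mask 0b01) repeated n = 8 times evolves as
//
//	m = 0b01               words = 2  n = 8  ptrs = 1
//	m = 0b0101             words = 4  n = 4  ptrs = 3
//	m = 0b01010101         words = 8  n = 2  ptrs = 7
//	m = 0b0101010101010101 words = 16 n = 1  ptrs = 15
//
// No writes happen while n stays even, and the single final h.write(m, 15)
// emits all 8 repetitions minus the last element's scalar tail word, so the
// cost is O(log n) writes rather than one write per element.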
var debugPtrmask struct {
lock mutex
data *byte
}
// progToPointerMask returns the 1-bit pointer mask output by the GC program prog.
// size is the size of the region described by prog, in bytes.
// The resulting bitvector will have no more than size/goarch.PtrSize bits.
func progToPointerMask(prog *byte, size uintptr) bitvector {
n := (size/goarch.PtrSize + 7) / 8
x := (*[1 << 30]byte)(persistentalloc(n+1, 1, &memstats.buckhash_sys))[:n+1]
x[len(x)-1] = 0xa1 // overflow check sentinel
n = runGCProg(prog, &x[0])
if x[len(x)-1] != 0xa1 {
throw("progToPointerMask: overflow")
}
return bitvector{int32(n), &x[0]}
}
// Packed GC pointer bitmaps, aka GC programs.
//
// For large types containing arrays, the type information has a
// natural repetition that can be encoded to save space in the
// binary and in the memory representation of the type information.
//
// The encoding is a simple Lempel-Ziv style bytecode machine
// with the following instructions:
//
// 00000000: stop
// 0nnnnnnn: emit n bits copied from the next (n+7)/8 bytes
// 10000000 n c: repeat the previous n bits c times; n, c are varints
// 1nnnnnnn c: repeat the previous n bits c times; c is a varint
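// Worked example (illustrative): the program
//
//	0x02 0x02  emit 2 bits from the next byte (0b10): scalar, then pointer
//	0x82 0x03  repeat the previous 2 bits 3 more times
//	0x00       stop
//
// expands to the mask byte 0b10101010: of 8 words, the odd-indexed ones
// hold pointers (bits are emitted least-significant first).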
// runGCProg returns the number of 1-bit entries written to memory.
func runGCProg(prog, dst *byte) uintptr {
dstStart := dst
// Bits waiting to be written to memory.
var bits uintptr
var nbits uintptr
p := prog
Run:
for {
// Flush accumulated full bytes.
// The rest of the loop assumes that nbits <= 7.
for ; nbits >= 8; nbits -= 8 {
*dst = uint8(bits)
dst = add1(dst)
bits >>= 8
}
// Process one instruction.
inst := uintptr(*p)
p = add1(p)
n := inst & 0x7F
if inst&0x80 == 0 {
// Literal bits; n == 0 means end of program.
if n == 0 {
// Program is over.
break Run
}
nbyte := n / 8
for i := uintptr(0); i < nbyte; i++ {
bits |= uintptr(*p) << nbits
p = add1(p)
*dst = uint8(bits)
dst = add1(dst)
bits >>= 8
}
if n %= 8; n > 0 {
bits |= uintptr(*p) << nbits
p = add1(p)
nbits += n
}
continue Run
}
// Repeat. If n == 0, it is encoded in a varint in the next bytes.
if n == 0 {
for off := uint(0); ; off += 7 {
x := uintptr(*p)
p = add1(p)
n |= (x & 0x7F) << off
if x&0x80 == 0 {
break
}
}
}
// Count is encoded in a varint in the next bytes.
c := uintptr(0)
for off := uint(0); ; off += 7 {
x := uintptr(*p)
p = add1(p)
c |= (x & 0x7F) << off
if x&0x80 == 0 {
break
}
}
c *= n // now total number of bits to copy
// If the number of bits being repeated is small, load them
// into a register and use that register for the entire loop
// instead of repeatedly reading from memory.
// Handling fewer than 8 bits here makes the general loop simpler.
// The cutoff is goarch.PtrSize*8 - 7 to guarantee that when we add
// the pattern to a bit buffer holding at most 7 bits (a partial byte)
// it will not overflow.
src := dst
const maxBits = goarch.PtrSize*8 - 7
if n <= maxBits {
// Start with bits in output buffer.
pattern := bits
npattern := nbits
// If we need more bits, fetch them from memory.
src = subtract1(src)
for npattern < n {
pattern <<= 8
pattern |= uintptr(*src)
src = subtract1(src)
npattern += 8
}
// We started with the whole bit output buffer,
// and then we loaded bits from whole bytes.
// Either way, we might now have too many instead of too few.
// Discard the extra.
if npattern > n {
pattern >>= npattern - n
npattern = n
}
// Replicate pattern to at most maxBits.
if npattern == 1 {
// One bit being repeated.
// If the bit is 1, make the pattern all 1s.
// If the bit is 0, the pattern is already all 0s,
// but we can claim that the number of bits
// in the word is equal to the number we need (c),
// because right shift of bits will zero fill.
if pattern == 1 {
pattern = 1<<maxBits - 1
npattern = maxBits
} else {
npattern = c
}
} else {
b := pattern
nb := npattern
if nb+nb <= maxBits {
// Double pattern until the whole uintptr is filled.
for nb <= goarch.PtrSize*8 {
b |= b << nb
nb += nb
}
// Trim away incomplete copy of original pattern in high bits.
// TODO(rsc): Replace with table lookup or loop on systems without divide?
nb = maxBits / npattern * npattern
b &= 1<<nb - 1
pattern = b
npattern = nb
}
}
// Add pattern to bit buffer and flush bit buffer, c/npattern times.
// Since pattern contains >8 bits, there will be full bytes to flush
// on each iteration.
for ; c >= npattern; c -= npattern {
bits |= pattern << nbits
nbits += npattern
for nbits >= 8 {
*dst = uint8(bits)
dst = add1(dst)
bits >>= 8
nbits -= 8
}
}
// Add final fragment to bit buffer.
if c > 0 {
pattern &= 1<<c - 1
bits |= pattern << nbits
nbits += c
}
continue Run
}
// Repeat; n too large to fit in a register.
// Since nbits <= 7, we know the first few bytes of repeated data
// are already written to memory.
off := n - nbits // n > nbits because n > maxBits and nbits <= 7
// Leading src fragment.
src = subtractb(src, (off+7)/8)
if frag := off & 7; frag != 0 {
bits |= uintptr(*src) >> (8 - frag) << nbits
src = add1(src)
nbits += frag
c -= frag
}
// Main loop: load one byte, write another.
// The bits are rotating through the bit buffer.
for i := c / 8; i > 0; i-- {
bits |= uintptr(*src) << nbits
src = add1(src)
*dst = uint8(bits)
dst = add1(dst)
bits >>= 8
}
// Final src fragment.
if c %= 8; c > 0 {
bits |= (uintptr(*src) & (1<<c - 1)) << nbits
nbits += c
}
}
// Write any final bits out, using full-byte writes, even for the final byte.
totalBits := (uintptr(unsafe.Pointer(dst))-uintptr(unsafe.Pointer(dstStart)))*8 + nbits
nbits += -nbits & 7
for ; nbits > 0; nbits -= 8 {
*dst = uint8(bits)
dst = add1(dst)
bits >>= 8
}
return totalBits
}
// materializeGCProg allocates space for the (1-bit) pointer bitmask
// for an object of size ptrdata. Then it fills that space with the
// pointer bitmask specified by the program prog.
// The bitmask starts at s.startAddr.
// The result must be deallocated with dematerializeGCProg.
func materializeGCProg(ptrdata uintptr, prog *byte) *mspan {
// Each word of ptrdata needs one bit in the bitmap.
bitmapBytes := divRoundUp(ptrdata, 8*goarch.PtrSize)
// Compute the number of pages needed for bitmapBytes.
pages := divRoundUp(bitmapBytes, pageSize)
s := mheap_.allocManual(pages, spanAllocPtrScalarBits)
runGCProg(addb(prog, 4), (*byte)(unsafe.Pointer(s.startAddr)))
return s
}
func dematerializeGCProg(s *mspan) {
mheap_.freeManual(s, spanAllocPtrScalarBits)
}
func dumpGCProg(p *byte) {
nptr := 0
for {
x := *p
p = add1(p)
if x == 0 {
print("\t", nptr, " end\n")
break
}
if x&0x80 == 0 {
print("\t", nptr, " lit ", x, ":")
n := int(x+7) / 8
for i := 0; i < n; i++ {
print(" ", hex(*p))
p = add1(p)
}
print("\n")
nptr += int(x)
} else {
nbit := int(x &^ 0x80)
if nbit == 0 {
for nb := uint(0); ; nb += 7 {
x := *p
p = add1(p)
nbit |= int(x&0x7f) << nb
if x&0x80 == 0 {
break
}
}
}
count := 0
for nb := uint(0); ; nb += 7 {
x := *p
p = add1(p)
count |= int(x&0x7f) << nb
if x&0x80 == 0 {
break
}
}
print("\t", nptr, " repeat ", nbit, " × ", count, "\n")
nptr += nbit * count
}
}
}
// Testing.
func getgcmaskcb(frame *stkframe, ctxt unsafe.Pointer) bool {
target := (*stkframe)(ctxt)
if frame.sp <= target.sp && target.sp < frame.varp {
*target = *frame
return false
}
return true
}
// reflect_gcbits returns the GC type info for x, for testing.
// The result is the bitmap entries (0 or 1), one entry per byte.
//
//go:linkname reflect_gcbits reflect.gcbits
func reflect_gcbits(x any) []byte {
return getgcmask(x)
}
// getgcmask returns the GC type info for the pointer stored in ep, for testing.
// If ep points to the stack, only static live information will be returned
// (i.e. not for objects which are only dynamically live stack objects).
func getgcmask(ep any) (mask []byte) {
e := *efaceOf(&ep)
p := e.data
t := e._type
// data or bss
for _, datap := range activeModules() {
// data
if datap.data <= uintptr(p) && uintptr(p) < datap.edata {
bitmap := datap.gcdatamask.bytedata
n := (*ptrtype)(unsafe.Pointer(t)).elem.size
mask = make([]byte, n/goarch.PtrSize)
for i := uintptr(0); i < n; i += goarch.PtrSize {
off := (uintptr(p) + i - datap.data) / goarch.PtrSize
mask[i/goarch.PtrSize] = (*addb(bitmap, off/8) >> (off % 8)) & 1
}
return
}
// bss
if datap.bss <= uintptr(p) && uintptr(p) < datap.ebss {
bitmap := datap.gcbssmask.bytedata
n := (*ptrtype)(unsafe.Pointer(t)).elem.size
mask = make([]byte, n/goarch.PtrSize)
for i := uintptr(0); i < n; i += goarch.PtrSize {
off := (uintptr(p) + i - datap.bss) / goarch.PtrSize
mask[i/goarch.PtrSize] = (*addb(bitmap, off/8) >> (off % 8)) & 1
}
return
}
}
// heap
if base, s, _ := findObject(uintptr(p), 0, 0); base != 0 {
if s.spanclass.noscan() {
return nil
}
n := s.elemsize
hbits := heapBitsForAddr(base, n)
mask = make([]byte, n/goarch.PtrSize)
for {
var addr uintptr
if hbits, addr = hbits.next(); addr == 0 {
break
}
mask[(addr-base)/goarch.PtrSize] = 1
}
// Callers expect this mask to end at the last pointer.
for len(mask) > 0 && mask[len(mask)-1] == 0 {
mask = mask[:len(mask)-1]
}
return
}
// stack
if gp := getg(); gp.m.curg.stack.lo <= uintptr(p) && uintptr(p) < gp.m.curg.stack.hi {
var frame stkframe
frame.sp = uintptr(p)
gentraceback(gp.m.curg.sched.pc, gp.m.curg.sched.sp, 0, gp.m.curg, 0, nil, 1000, getgcmaskcb, noescape(unsafe.Pointer(&frame)), 0)
if frame.fn.valid() {
locals, _, _ := frame.getStackMap(nil, false)
if locals.n == 0 {
return
}
size := uintptr(locals.n) * goarch.PtrSize
n := (*ptrtype)(unsafe.Pointer(t)).elem.size
mask = make([]byte, n/goarch.PtrSize)
for i := uintptr(0); i < n; i += goarch.PtrSize {
off := (uintptr(p) + i - frame.varp + size) / goarch.PtrSize
mask[i/goarch.PtrSize] = locals.ptrbit(off)
}
}
return
}
// otherwise, not something the GC knows about.
// possibly read-only data, like malloc(0).
// must not have pointers
return
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"runtime/internal/atomic"
"runtime/internal/sys"
"unsafe"
)
// Per-thread (in Go, per-P) cache for small objects.
// This includes a small object cache and local allocation stats.
// No locking needed because it is per-thread (per-P).
//
// mcaches are allocated from non-GC'd memory, so any heap pointers
// must be specially handled.
type mcache struct {
_ sys.NotInHeap
// The following members are accessed on every malloc,
// so they are grouped here for better caching.
nextSample uintptr // trigger heap sample after allocating this many bytes
scanAlloc uintptr // bytes of scannable heap allocated
// Allocator cache for tiny objects w/o pointers.
// See "Tiny allocator" comment in malloc.go.
// tiny points to the beginning of the current tiny block, or
// nil if there is no current tiny block.
//
// tiny is a heap pointer. Since mcache is in non-GC'd memory,
// we handle it by clearing it in releaseAll during mark
// termination.
//
// tinyAllocs is the number of tiny allocations performed
// by the P that owns this mcache.
tiny uintptr
tinyoffset uintptr
tinyAllocs uintptr
// The rest is not accessed on every malloc.
alloc [numSpanClasses]*mspan // spans to allocate from, indexed by spanClass
stackcache [_NumStackOrders]stackfreelist
// flushGen indicates the sweepgen during which this mcache
// was last flushed. If flushGen != mheap_.sweepgen, the spans
// in this mcache are stale and need to be flushed so they
// can be swept. This is done in acquirep.
flushGen atomic.Uint32
}
// A gclink is a node in a linked list of blocks, like mlink,
// but it is opaque to the garbage collector.
// The GC does not trace the pointers during collection,
// and the compiler does not emit write barriers for assignments
// of gclinkptr values. Code should store references to gclinks
// as gclinkptr, not as *gclink.
type gclink struct {
next gclinkptr
}
// A gclinkptr is a pointer to a gclink, but it is opaque
// to the garbage collector.
type gclinkptr uintptr
// ptr returns the *gclink form of p.
// The result should be used for accessing fields, not stored
// in other data structures.
func (p gclinkptr) ptr() *gclink {
return (*gclink)(unsafe.Pointer(p))
}
type stackfreelist struct {
list gclinkptr // linked list of free stacks
size uintptr // total size of stacks in list
}
// dummy mspan that contains no free objects.
var emptymspan mspan
func allocmcache() *mcache {
var c *mcache
systemstack(func() {
lock(&mheap_.lock)
c = (*mcache)(mheap_.cachealloc.alloc())
c.flushGen.Store(mheap_.sweepgen)
unlock(&mheap_.lock)
})
for i := range c.alloc {
c.alloc[i] = &emptymspan
}
c.nextSample = nextSample()
return c
}
// freemcache releases resources associated with this
// mcache and puts the object onto a free list.
//
// In some cases there is no way to simply release
// resources, such as statistics, so donate them to
// a different mcache (the recipient).
func freemcache(c *mcache) {
systemstack(func() {
c.releaseAll()
stackcache_clear(c)
// NOTE(rsc,rlh): If gcworkbuffree comes back, we need to coordinate
// with the stealing of gcworkbufs during garbage collection to avoid
// a race where the workbuf is double-freed.
// gcworkbuffree(c.gcworkbuf)
lock(&mheap_.lock)
mheap_.cachealloc.free(unsafe.Pointer(c))
unlock(&mheap_.lock)
})
}
// getMCache is a convenience function which tries to obtain an mcache.
//
// Returns nil if we don't have a P and bootstrapping has completed,
// since mcache0 is cleared by procresize. The caller's P must not
// change, so we must be in a non-preemptible state.
func getMCache(mp *m) *mcache {
// Grab the mcache, since that's where stats live.
pp := mp.p.ptr()
var c *mcache
if pp == nil {
// We will be called without a P while bootstrapping,
// in which case we use mcache0, which is set in mallocinit.
// mcache0 is cleared when bootstrapping is complete,
// by procresize.
c = mcache0
} else {
c = pp.mcache
}
return c
}
// refill acquires a new span of span class spc for c. This span will
// have at least one free object. The current span in c must be full.
//
// Must run in a non-preemptible context since otherwise the owner of
// c could change.
func (c *mcache) refill(spc spanClass) {
// Return the current cached span to the central lists.
s := c.alloc[spc]
if uintptr(s.allocCount) != s.nelems {
throw("refill of span with free space remaining")
}
if s != &emptymspan {
// Mark this span as no longer cached.
if s.sweepgen != mheap_.sweepgen+3 {
throw("bad sweepgen in refill")
}
mheap_.central[spc].mcentral.uncacheSpan(s)
// Count up how many slots were used and record it.
stats := memstats.heapStats.acquire()
slotsUsed := int64(s.allocCount) - int64(s.allocCountBeforeCache)
atomic.Xadd64(&stats.smallAllocCount[spc.sizeclass()], slotsUsed)
// Flush tinyAllocs.
if spc == tinySpanClass {
atomic.Xadd64(&stats.tinyAllocCount, int64(c.tinyAllocs))
c.tinyAllocs = 0
}
memstats.heapStats.release()
// Count the allocs in inconsistent, internal stats.
bytesAllocated := slotsUsed * int64(s.elemsize)
gcController.totalAlloc.Add(bytesAllocated)
// Clear the second allocCount just to be safe.
s.allocCountBeforeCache = 0
}
// Get a new cached span from the central lists.
s = mheap_.central[spc].mcentral.cacheSpan()
if s == nil {
throw("out of memory")
}
if uintptr(s.allocCount) == s.nelems {
throw("span has no free space")
}
// Indicate that this span is cached and prevent asynchronous
// sweeping in the next sweep phase.
s.sweepgen = mheap_.sweepgen + 3
// Store the current alloc count for accounting later.
s.allocCountBeforeCache = s.allocCount
// Update heapLive and flush scanAlloc.
//
// We have not yet allocated anything new into the span, but we
// assume that all of its slots will get used, so this makes
// heapLive an overestimate.
//
// When the span gets uncached, we'll fix up this overestimate
// if necessary (see releaseAll).
//
// We pick an overestimate here because an underestimate leads
// the pacer to believe that it's in better shape than it is,
// which appears to lead to more memory used. See #53738 for
// more details.
usedBytes := uintptr(s.allocCount) * s.elemsize
gcController.update(int64(s.npages*pageSize)-int64(usedBytes), int64(c.scanAlloc))
c.scanAlloc = 0
c.alloc[spc] = s
}
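// For reference, the sweepgen invariants documented on mspan in mheap.go
// (summarized here, relative to h = mheap_.sweepgen):
//
//	sweepgen == h - 2: the span needs sweeping
//	sweepgen == h - 1: the span is currently being swept
//	sweepgen == h:     the span is swept and ready to use
//	sweepgen == h + 1: the span was cached before sweep began and still needs sweeping
//	sweepgen == h + 3: the span was swept and then cached and is still cached
//
// This is why refill expects sweepgen+3 on the span it returns to the
// central lists and stamps sweepgen+3 on the span it caches.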
// allocLarge allocates a span for a large object.
func (c *mcache) allocLarge(size uintptr, noscan bool) *mspan {
if size+_PageSize < size {
throw("out of memory")
}
npages := size >> _PageShift
if size&_PageMask != 0 {
npages++
}
// Deduct credit for this span allocation and sweep if
// necessary. mHeap_Alloc will also sweep npages, so this only
// pays the debt down to npage pages.
deductSweepCredit(npages*_PageSize, npages)
spc := makeSpanClass(0, noscan)
s := mheap_.alloc(npages, spc)
if s == nil {
throw("out of memory")
}
// Count the alloc in consistent, external stats.
stats := memstats.heapStats.acquire()
atomic.Xadd64(&stats.largeAlloc, int64(npages*pageSize))
atomic.Xadd64(&stats.largeAllocCount, 1)
memstats.heapStats.release()
// Count the alloc in inconsistent, internal stats.
gcController.totalAlloc.Add(int64(npages * pageSize))
// Update heapLive.
gcController.update(int64(s.npages*pageSize), 0)
// Put the large span in the mcentral swept list so that it's
// visible to the background sweeper.
mheap_.central[spc].mcentral.fullSwept(mheap_.sweepgen).push(s)
s.limit = s.base() + size
s.initHeapBits(false)
return s
}
func (c *mcache) releaseAll() {
// Take this opportunity to flush scanAlloc.
scanAlloc := int64(c.scanAlloc)
c.scanAlloc = 0
sg := mheap_.sweepgen
dHeapLive := int64(0)
for i := range c.alloc {
s := c.alloc[i]
if s != &emptymspan {
slotsUsed := int64(s.allocCount) - int64(s.allocCountBeforeCache)
s.allocCountBeforeCache = 0
// Adjust smallAllocCount for whatever was allocated.
stats := memstats.heapStats.acquire()
atomic.Xadd64(&stats.smallAllocCount[spanClass(i).sizeclass()], slotsUsed)
memstats.heapStats.release()
// Adjust the actual allocs in inconsistent, internal stats.
// We assumed earlier that the full span gets allocated.
gcController.totalAlloc.Add(slotsUsed * int64(s.elemsize))
if s.sweepgen != sg+1 {
// refill conservatively counted unallocated slots in gcController.heapLive.
// Undo this.
//
// If this span was cached before sweep, then gcController.heapLive was totally
// recomputed since caching this span, so we don't do this for stale spans.
dHeapLive -= int64(uintptr(s.nelems)-uintptr(s.allocCount)) * int64(s.elemsize)
}
// Release the span to the mcentral.
mheap_.central[i].mcentral.uncacheSpan(s)
c.alloc[i] = &emptymspan
}
}
// Clear tinyalloc pool.
c.tiny = 0
c.tinyoffset = 0
// Flush tinyAllocs.
stats := memstats.heapStats.acquire()
atomic.Xadd64(&stats.tinyAllocCount, int64(c.tinyAllocs))
c.tinyAllocs = 0
memstats.heapStats.release()
// Update heapLive and heapScan.
gcController.update(dHeapLive, scanAlloc)
}
// prepareForSweep flushes c if the system has entered a new sweep phase
// since c was populated. This must happen between the sweep phase
// starting and the first allocation from c.
func (c *mcache) prepareForSweep() {
// Alternatively, instead of making sure we do this on every P
// between starting the world and allocating on that P, we
// could leave allocate-black on, allow allocation to continue
// as usual, use a ragged barrier at the beginning of sweep to
// ensure all cached spans are swept, and then disable
// allocate-black. However, with this approach it's difficult
// to avoid spilling mark bits into the *next* GC cycle.
sg := mheap_.sweepgen
flushGen := c.flushGen.Load()
if flushGen == sg {
return
} else if flushGen != sg-2 {
println("bad flushGen", flushGen, "in prepareForSweep; sweepgen", sg)
throw("bad flushGen")
}
c.releaseAll()
stackcache_clear(c)
c.flushGen.Store(mheap_.sweepgen) // Synchronizes with gcStart
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Central free lists.
//
// See malloc.go for an overview.
//
// The mcentral doesn't actually contain the list of free objects; the mspan does.
// Each mcentral holds two pairs of mspan sets: spans with free objects
// (partial) and spans that are completely allocated (full).
package runtime
import (
"runtime/internal/atomic"
"runtime/internal/sys"
)
// Central list of free objects of a given size.
type mcentral struct {
_ sys.NotInHeap
spanclass spanClass
// partial and full contain two mspan sets: one of swept in-use
// spans, and one of unswept in-use spans. These two trade
// roles on each GC cycle. The unswept set is drained either by
// allocation or by the background sweeper in every GC cycle,
// so only two roles are necessary.
//
// sweepgen is increased by 2 on each GC cycle, so the swept
// spans are in partial[sweepgen/2%2] and the unswept spans are in
// partial[1-sweepgen/2%2]. Sweeping pops spans from the
// unswept set and pushes spans that are still in-use on the
// swept set. Likewise, allocating an in-use span pushes it
// on the swept set.
//
// Some parts of the sweeper can sweep arbitrary spans, and hence
// can't remove them from the unswept set, but will add the span
// to the appropriate swept list. As a result, the parts of the
// sweeper and mcentral that do consume from the unswept list may
// encounter swept spans, and these should be ignored.
partial [2]spanSet // list of spans with a free object
full [2]spanSet // list of spans with no free objects
}
// Initialize a single central free list.
func (c *mcentral) init(spc spanClass) {
c.spanclass = spc
lockInit(&c.partial[0].spineLock, lockRankSpanSetSpine)
lockInit(&c.partial[1].spineLock, lockRankSpanSetSpine)
lockInit(&c.full[0].spineLock, lockRankSpanSetSpine)
lockInit(&c.full[1].spineLock, lockRankSpanSetSpine)
}
// partialUnswept returns the spanSet which holds partially-filled
// unswept spans for this sweepgen.
func (c *mcentral) partialUnswept(sweepgen uint32) *spanSet {
return &c.partial[1-sweepgen/2%2]
}
// partialSwept returns the spanSet which holds partially-filled
// swept spans for this sweepgen.
func (c *mcentral) partialSwept(sweepgen uint32) *spanSet {
return &c.partial[sweepgen/2%2]
}
// fullUnswept returns the spanSet which holds unswept spans without any
// free slots for this sweepgen.
func (c *mcentral) fullUnswept(sweepgen uint32) *spanSet {
return &c.full[1-sweepgen/2%2]
}
// fullSwept returns the spanSet which holds swept spans without any
// free slots for this sweepgen.
func (c *mcentral) fullSwept(sweepgen uint32) *spanSet {
return &c.full[sweepgen/2%2]
}
// Allocate a span to use in an mcache.
func (c *mcentral) cacheSpan() *mspan {
// Deduct credit for this span allocation and sweep if necessary.
spanBytes := uintptr(class_to_allocnpages[c.spanclass.sizeclass()]) * _PageSize
deductSweepCredit(spanBytes, 0)
traceDone := false
if trace.enabled {
traceGCSweepStart()
}
// If we sweep spanBudget spans without finding any free
// space, just allocate a fresh span. This limits the amount
// of time we can spend trying to find free space and
// amortizes the cost of small object sweeping over the
// benefit of having a full free span to allocate from. By
// setting this to 100, we limit the space overhead to 1%.
//
// TODO(austin,mknyszek): This still has bad worst-case
// throughput. For example, this could find just one free slot
// on the 100th swept span. That limits allocation latency, but
// still has very poor throughput. We could instead keep a
// running free-to-used budget and switch to fresh span
// allocation if the budget runs low.
spanBudget := 100
var s *mspan
var sl sweepLocker
// Try partial swept spans first.
sg := mheap_.sweepgen
if s = c.partialSwept(sg).pop(); s != nil {
goto havespan
}
sl = sweep.active.begin()
if sl.valid {
// Now try partial unswept spans.
for ; spanBudget >= 0; spanBudget-- {
s = c.partialUnswept(sg).pop()
if s == nil {
break
}
if s, ok := sl.tryAcquire(s); ok {
// We got ownership of the span, so let's sweep it and use it.
s.sweep(true)
sweep.active.end(sl)
goto havespan
}
// We failed to get ownership of the span, which means it's being or
// has been swept by an asynchronous sweeper that just couldn't remove it
// from the unswept list. That sweeper took ownership of the span and
// responsibility for either freeing it to the heap or putting it on the
// right swept list. Either way, we should just ignore it (and it's unsafe
// for us to do anything else).
}
// Now try full unswept spans, sweeping them and putting them into the
// right list if we fail to get a span.
for ; spanBudget >= 0; spanBudget-- {
s = c.fullUnswept(sg).pop()
if s == nil {
break
}
if s, ok := sl.tryAcquire(s); ok {
// We got ownership of the span, so let's sweep it.
s.sweep(true)
// Check if there's any free space.
freeIndex := s.nextFreeIndex()
if freeIndex != s.nelems {
s.freeindex = freeIndex
sweep.active.end(sl)
goto havespan
}
// Add it to the swept list, because sweeping didn't give us any free space.
c.fullSwept(sg).push(s.mspan)
}
// See comment for partial unswept spans.
}
sweep.active.end(sl)
}
if trace.enabled {
traceGCSweepDone()
traceDone = true
}
// We failed to get a span from the mcentral so get one from mheap.
s = c.grow()
if s == nil {
return nil
}
// At this point s is a span that should have free slots.
havespan:
if trace.enabled && !traceDone {
traceGCSweepDone()
}
n := int(s.nelems) - int(s.allocCount)
if n == 0 || s.freeindex == s.nelems || uintptr(s.allocCount) == s.nelems {
throw("span has no free objects")
}
freeByteBase := s.freeindex &^ (64 - 1)
whichByte := freeByteBase / 8
// Init alloc bits cache.
s.refillAllocCache(whichByte)
// Adjust the allocCache so that s.freeindex corresponds to the low bit in
// s.allocCache.
s.allocCache >>= s.freeindex % 64
return s
}
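// alignedAllocCache is an illustrative sketch (hypothetical; not used
// by the runtime) of the allocCache fixup above: refillAllocCache loads
// the 64-bit window starting at the 64-object boundary at or below
// freeindex, and shifting right by freeindex%64 makes bit 0 of the
// cache correspond to the object at freeindex.
func alignedAllocCache(cache uint64, freeindex uintptr) uint64 {
	return cache >> (freeindex % 64)
}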
// Return a span from an mcache to the mcentral.
//
// s must have a span class corresponding to this
// mcentral and it must not be empty.
func (c *mcentral) uncacheSpan(s *mspan) {
if s.allocCount == 0 {
throw("uncaching span but s.allocCount == 0")
}
sg := mheap_.sweepgen
stale := s.sweepgen == sg+1
// Fix up sweepgen.
if stale {
// Span was cached before sweep began. It's our
// responsibility to sweep it.
//
// Set sweepgen to indicate it's not cached but needs
// sweeping and can't be allocated from. sweep will
// set s.sweepgen to indicate s is swept.
atomic.Store(&s.sweepgen, sg-1)
} else {
// Indicate that s is no longer cached.
atomic.Store(&s.sweepgen, sg)
}
// Put the span in the appropriate place.
if stale {
// It's stale, so just sweep it. Sweeping will put it on
// the right list.
//
// We don't use a sweepLocker here. Stale cached spans
// aren't in the global sweep lists, so mark termination
// itself holds up sweep completion until all mcaches
// have been swept.
ss := sweepLocked{s}
ss.sweep(false)
} else {
if int(s.nelems)-int(s.allocCount) > 0 {
// Put it back on the partial swept list.
c.partialSwept(sg).push(s)
} else {
// There's no free space and it's not stale, so put it on the
// full swept list.
c.fullSwept(sg).push(s)
}
}
}
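// For reference, the mspan.sweepgen conventions used above (documented
// with mspan in mheap.go): if h.sweepgen is sg, then
//
//	s.sweepgen == sg-2: the span needs sweeping
//	s.sweepgen == sg-1: the span is currently being swept
//	s.sweepgen == sg:   the span is swept and ready to use
//	s.sweepgen == sg+1: the span was cached before sweep began and needs sweeping
//	s.sweepgen == sg+3: the span was swept and then cached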
// grow allocates a new empty span from the heap and initializes it for c's size class.
func (c *mcentral) grow() *mspan {
npages := uintptr(class_to_allocnpages[c.spanclass.sizeclass()])
size := uintptr(class_to_size[c.spanclass.sizeclass()])
s := mheap_.alloc(npages, c.spanclass)
if s == nil {
return nil
}
// Use division by multiplication and shifts to quickly compute:
// n := (npages << _PageShift) / size
n := s.divideByElemSize(npages << _PageShift)
s.limit = s.base() + size*n
s.initHeapBits(false)
return s
}
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// GC checkmarks
//
// In a concurrent garbage collector, one worries about failing to mark
// a live object due to mutations without write barriers or bugs in the
// collector implementation. As a sanity check, the GC has a 'checkmark'
// mode that retraverses the object graph with the world stopped, to make
// sure that everything that should be marked is marked.
package runtime
import (
"internal/goarch"
"runtime/internal/atomic"
"runtime/internal/sys"
"unsafe"
)
// A checkmarksMap stores the GC marks in "checkmarks" mode. It is a
// per-arena bitmap with a bit for every word in the arena. The mark
// is stored on the bit corresponding to the first word of the marked
// allocation.
type checkmarksMap struct {
_ sys.NotInHeap
b [heapArenaBytes / goarch.PtrSize / 8]uint8
}
// If useCheckmark is true, marking of an object uses the checkmark
// bits instead of the standard mark bits.
var useCheckmark = false
// startCheckmarks prepares for the checkmarks phase.
//
// The world must be stopped.
func startCheckmarks() {
assertWorldStopped()
// Clear all checkmarks.
for _, ai := range mheap_.allArenas {
arena := mheap_.arenas[ai.l1()][ai.l2()]
bitmap := arena.checkmarks
if bitmap == nil {
// Allocate bitmap on first use.
bitmap = (*checkmarksMap)(persistentalloc(unsafe.Sizeof(*bitmap), 0, &memstats.gcMiscSys))
if bitmap == nil {
throw("out of memory allocating checkmarks bitmap")
}
arena.checkmarks = bitmap
} else {
// Otherwise clear the existing bitmap.
for i := range bitmap.b {
bitmap.b[i] = 0
}
}
}
// Enable checkmarking.
useCheckmark = true
}
// endCheckmarks ends the checkmarks phase.
func endCheckmarks() {
if gcMarkWorkAvailable(nil) {
throw("GC work not flushed")
}
useCheckmark = false
}
// setCheckmark throws if marking obj would be a checkmarks violation,
// and otherwise sets obj's checkmark. It returns true if obj was
// already checkmarked.
func setCheckmark(obj, base, off uintptr, mbits markBits) bool {
if !mbits.isMarked() {
printlock()
print("runtime: checkmarks found unexpected unmarked object obj=", hex(obj), "\n")
print("runtime: found obj at *(", hex(base), "+", hex(off), ")\n")
// Dump the source (base) object
gcDumpObject("base", base, off)
// Dump the object
gcDumpObject("obj", obj, ^uintptr(0))
getg().m.traceback = 2
throw("checkmark found unmarked object")
}
ai := arenaIndex(obj)
arena := mheap_.arenas[ai.l1()][ai.l2()]
arenaWord := (obj / goarch.PtrSize / 8) % uintptr(len(arena.checkmarks.b))
mask := byte(1 << ((obj / goarch.PtrSize) % 8))
bytep := &arena.checkmarks.b[arenaWord]
if atomic.Load8(bytep)&mask != 0 {
// Already checkmarked.
return true
}
atomic.Or8(bytep, mask)
return false
}
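// checkmarkBitFor is an illustrative sketch (hypothetical; not used by
// the runtime) of the indexing above: the bitmap has one bit per heap
// word, so the word index is obj/goarch.PtrSize, the byte holding that
// bit is the word index divided by 8 (wrapped to the arena-local bitmap
// length), and the mask selects the word index mod 8 within the byte.
func checkmarkBitFor(obj, bitmapLen uintptr) (byteIdx uintptr, mask byte) {
	byteIdx = (obj / goarch.PtrSize / 8) % bitmapLen
	mask = byte(1 << ((obj / goarch.PtrSize) % 8))
	return
}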
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import "unsafe"
// OS memory management abstraction layer
//
// Regions of the address space managed by the runtime may be in one of four
// states at any given time:
// 1) None - Unreserved and unmapped, the default state of any region.
// 2) Reserved - Owned by the runtime, but accessing it would cause a fault.
// Does not count against the process' memory footprint.
// 3) Prepared - Reserved, intended not to be backed by physical memory (though
// an OS may implement this lazily). Can transition efficiently to
// Ready. Accessing memory in such a region is undefined (may
// fault, may give back unexpected zeroes, etc.).
// 4) Ready - may be accessed safely.
//
// This set of states is more than is strictly necessary to support all the
// currently supported platforms. One could get by with just None, Reserved, and
// Ready. However, the Prepared state gives us flexibility for performance
// purposes. For example, on POSIX-y operating systems, Reserved is usually a
// private anonymous mmap'd region with PROT_NONE set, and to transition
// to Ready would require setting PROT_READ|PROT_WRITE. However the
// underspecification of Prepared lets us use just MADV_FREE to transition from
// Ready to Prepared. Thus, with the Prepared state we can set the permission
// bits just once early on, and efficiently tell the OS that it's free to
// take pages away from us when we don't strictly need them.
//
// This file defines a cross-OS interface for a common set of helpers
// that transition memory regions between these states. The helpers call into
// OS-specific implementations that handle errors, while the interface boundary
// implements cross-OS functionality, like updating runtime accounting.
// sysAlloc transitions an OS-chosen region of memory from None to Ready.
// More specifically, it obtains a large chunk of zeroed memory from the
// operating system, typically on the order of a hundred kilobytes
// or a megabyte. This memory is always immediately available for use.
//
// sysStat must be non-nil.
//
// Don't split the stack as this function may be invoked without a valid G,
// which prevents us from allocating more stack.
//
//go:nosplit
func sysAlloc(n uintptr, sysStat *sysMemStat) unsafe.Pointer {
sysStat.add(int64(n))
gcController.mappedReady.Add(int64(n))
return sysAllocOS(n)
}
// sysUnused transitions a memory region from Ready to Prepared. It notifies the
// operating system that the physical pages backing this memory region are no
// longer needed and can be reused for other purposes. The contents of a
// sysUnused memory region are considered forfeit and the region must not be
// accessed again until sysUsed is called.
func sysUnused(v unsafe.Pointer, n uintptr) {
gcController.mappedReady.Add(-int64(n))
sysUnusedOS(v, n)
}
// sysUsed transitions a memory region from Prepared to Ready. It notifies the
// operating system that the memory region is needed and ensures that the region
// may be safely accessed. This is typically a no-op on systems that don't have
// an explicit commit step and hard over-commit limits, but is critical on
// Windows, for example.
//
// This operation is idempotent for memory already in the Prepared state, so
// it is safe to refer, with v and n, to a range of memory that includes both
// Prepared and Ready memory. However, the caller must provide the exact amount
// of Prepared memory for accounting purposes.
func sysUsed(v unsafe.Pointer, n, prepared uintptr) {
gcController.mappedReady.Add(int64(prepared))
sysUsedOS(v, n)
}
// sysHugePage does not transition memory regions, but instead provides a
// hint to the OS that it would be more efficient to back this memory region
// with pages of a larger size transparently.
func sysHugePage(v unsafe.Pointer, n uintptr) {
sysHugePageOS(v, n)
}
// sysFree transitions a memory region from any state to None. Therefore, it
// returns memory unconditionally. It is used if an out-of-memory error has been
// detected midway through an allocation or to carve out an aligned section of
// the address space. sysFree may be a no-op, but only if sysReserve always
// returns a memory region aligned to the heap allocator's alignment
// restrictions.
//
// sysStat must be non-nil.
//
// Don't split the stack as this function may be invoked without a valid G,
// which prevents us from allocating more stack.
//
//go:nosplit
func sysFree(v unsafe.Pointer, n uintptr, sysStat *sysMemStat) {
sysStat.add(-int64(n))
gcController.mappedReady.Add(-int64(n))
sysFreeOS(v, n)
}
// sysFault transitions a memory region from Ready to Reserved. It
// marks a region such that it will always fault if accessed. Used only for
// debugging the runtime.
//
// TODO(mknyszek): Currently it's true that all uses of sysFault transition
// memory from Ready to Reserved, but this may not be true in the future
// since on every platform the operation is much more general than that.
// If a transition from Prepared is ever introduced, create a new function
// that elides the Ready state accounting.
func sysFault(v unsafe.Pointer, n uintptr) {
gcController.mappedReady.Add(-int64(n))
sysFaultOS(v, n)
}
// sysReserve transitions a memory region from None to Reserved. It reserves
// address space in such a way that it would cause a fatal fault upon access
// (either via permissions or not committing the memory). Such a reservation is
// thus never backed by physical memory.
//
// If the pointer passed to it is non-nil, the caller wants the
// reservation there, but sysReserve can still choose another
// location if that one is unavailable.
//
// NOTE: sysReserve returns OS-aligned memory, but the heap allocator
// may use larger alignment, so the caller must be careful to realign the
// memory obtained by sysReserve.
func sysReserve(v unsafe.Pointer, n uintptr) unsafe.Pointer {
return sysReserveOS(v, n)
}
// sysMap transitions a memory region from Reserved to Prepared. It ensures the
// memory region can be efficiently transitioned to Ready.
//
// sysStat must be non-nil.
func sysMap(v unsafe.Pointer, n uintptr, sysStat *sysMemStat) {
sysStat.add(int64(n))
sysMapOS(v, n)
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"runtime/internal/atomic"
"unsafe"
)
const (
_EACCES = 13
_EINVAL = 22
)
// Don't split the stack as this method may be invoked without a valid G, which
// prevents us from allocating more stack.
//
//go:nosplit
func sysAllocOS(n uintptr) unsafe.Pointer {
p, err := mmap(nil, n, _PROT_READ|_PROT_WRITE, _MAP_ANON|_MAP_PRIVATE, -1, 0)
if err != 0 {
if err == _EACCES {
print("runtime: mmap: access denied\n")
exit(2)
}
if err == _EAGAIN {
print("runtime: mmap: too much locked memory (check 'ulimit -l').\n")
exit(2)
}
return nil
}
return p
}
var adviseUnused = uint32(_MADV_FREE)
func sysUnusedOS(v unsafe.Pointer, n uintptr) {
// By default, Linux's "transparent huge page" support will
// merge pages into a huge page if there's even a single
// present regular page, undoing the effects of madvise(adviseUnused)
// below. On amd64, that means khugepaged can turn a single
// 4KB page to 2MB, bloating the process's RSS by as much as
// 512X. (See issue #8832 and Linux kernel bug
// https://bugzilla.kernel.org/show_bug.cgi?id=93111)
//
// To work around this, we explicitly disable transparent huge
// pages when we release pages of the heap. However, we have
// to do this carefully because changing this flag tends to
// split the VMA (memory mapping) containing v in to three
// VMAs in order to track the different values of the
// MADV_NOHUGEPAGE flag in the different regions. There's a
// default limit of 65530 VMAs per address space (sysctl
// vm.max_map_count), so we must be careful not to create too
// many VMAs (see issue #12233).
//
// Since huge pages are huge, there's little use in adjusting
// the MADV_NOHUGEPAGE flag on a fine granularity, so we avoid
// exploding the number of VMAs by only adjusting the
// MADV_NOHUGEPAGE flag on a large granularity. This still
// gets most of the benefit of huge pages while keeping the
// number of VMAs under control. With hugePageSize = 2MB, even
// a pessimal heap can reach 128GB before running out of VMAs.
if physHugePageSize != 0 {
// If it's a large allocation, we want to leave huge
// pages enabled. Hence, we only adjust the huge page
// flag on the huge pages containing v and v+n-1, and
// only if those aren't aligned.
var head, tail uintptr
if uintptr(v)&(physHugePageSize-1) != 0 {
// Compute huge page containing v.
head = alignDown(uintptr(v), physHugePageSize)
}
if (uintptr(v)+n)&(physHugePageSize-1) != 0 {
// Compute huge page containing v+n-1.
tail = alignDown(uintptr(v)+n-1, physHugePageSize)
}
// Note that madvise will return EINVAL if the flag is
// already set, which is quite likely. We ignore
// errors.
if head != 0 && head+physHugePageSize == tail {
// head and tail are different but adjacent,
// so do this in one call.
madvise(unsafe.Pointer(head), 2*physHugePageSize, _MADV_NOHUGEPAGE)
} else {
// Advise the huge pages containing v and v+n-1.
if head != 0 {
madvise(unsafe.Pointer(head), physHugePageSize, _MADV_NOHUGEPAGE)
}
if tail != 0 && tail != head {
madvise(unsafe.Pointer(tail), physHugePageSize, _MADV_NOHUGEPAGE)
}
}
}
if uintptr(v)&(physPageSize-1) != 0 || n&(physPageSize-1) != 0 {
// madvise will round this to any physical page
// *covered* by this range, so an unaligned madvise
// will release more memory than intended.
throw("unaligned sysUnused")
}
var advise uint32
if debug.madvdontneed != 0 {
advise = _MADV_DONTNEED
} else {
advise = atomic.Load(&adviseUnused)
}
if errno := madvise(v, n, int32(advise)); advise == _MADV_FREE && errno != 0 {
// MADV_FREE was added in Linux 4.5. Fall back to MADV_DONTNEED if it is
// not supported.
atomic.Store(&adviseUnused, _MADV_DONTNEED)
madvise(v, n, _MADV_DONTNEED)
}
if debug.harddecommit > 0 {
p, err := mmap(v, n, _PROT_NONE, _MAP_ANON|_MAP_FIXED|_MAP_PRIVATE, -1, 0)
if p != v || err != 0 {
throw("runtime: cannot disable permissions in address space")
}
}
}
func sysUsedOS(v unsafe.Pointer, n uintptr) {
if debug.harddecommit > 0 {
p, err := mmap(v, n, _PROT_READ|_PROT_WRITE, _MAP_ANON|_MAP_FIXED|_MAP_PRIVATE, -1, 0)
if err == _ENOMEM {
throw("runtime: out of memory")
}
if p != v || err != 0 {
throw("runtime: cannot remap pages in address space")
}
// Don't do the sysHugePage optimization in hard decommit mode.
// We're breaking up pages everywhere, there's no point.
return
}
// Partially undo the NOHUGEPAGE marks from sysUnused
// for whole huge pages between v and v+n. This may
// leave huge pages off at the end points v and v+n
// even though allocations may cover these entire huge
// pages. We could detect this and undo NOHUGEPAGE on
// the end points as well, but it's probably not worth
// the cost because when neighboring allocations are
// freed sysUnused will just set NOHUGEPAGE again.
sysHugePageOS(v, n)
}
func sysHugePageOS(v unsafe.Pointer, n uintptr) {
if physHugePageSize != 0 {
// Round v up to a huge page boundary.
beg := alignUp(uintptr(v), physHugePageSize)
// Round v+n down to a huge page boundary.
end := alignDown(uintptr(v)+n, physHugePageSize)
if beg < end {
madvise(unsafe.Pointer(beg), end-beg, _MADV_HUGEPAGE)
}
}
}
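// alignDownSketch and alignUpSketch are illustrative versions
// (hypothetical names; the runtime has its own alignDown/alignUp
// helpers) of the rounding used above, valid for any power-of-two
// alignment a.
func alignDownSketch(p, a uintptr) uintptr { return p &^ (a - 1) }
func alignUpSketch(p, a uintptr) uintptr   { return (p + a - 1) &^ (a - 1) }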
// Don't split the stack as this function may be invoked without a valid G,
// which prevents us from allocating more stack.
//
//go:nosplit
func sysFreeOS(v unsafe.Pointer, n uintptr) {
munmap(v, n)
}
func sysFaultOS(v unsafe.Pointer, n uintptr) {
mmap(v, n, _PROT_NONE, _MAP_ANON|_MAP_PRIVATE|_MAP_FIXED, -1, 0)
}
func sysReserveOS(v unsafe.Pointer, n uintptr) unsafe.Pointer {
p, err := mmap(v, n, _PROT_NONE, _MAP_ANON|_MAP_PRIVATE, -1, 0)
if err != 0 {
return nil
}
return p
}
func sysMapOS(v unsafe.Pointer, n uintptr) {
p, err := mmap(v, n, _PROT_READ|_PROT_WRITE, _MAP_ANON|_MAP_FIXED|_MAP_PRIVATE, -1, 0)
if err == _ENOMEM {
throw("runtime: out of memory")
}
if p != v || err != 0 {
print("runtime: mmap(", v, ", ", n, ") returned ", p, ", ", err, "\n")
throw("runtime: cannot map pages in arena address space")
}
}
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
// Metrics implementation exported to runtime/metrics.
import (
"unsafe"
)
var (
// metrics is a map of runtime/metrics keys to data used by the runtime
// to sample each metric's value. metricsInit indicates it has been
// initialized.
//
// These fields are protected by metricsSema which should be
// locked/unlocked with metricsLock() / metricsUnlock().
metricsSema uint32 = 1
metricsInit bool
metrics map[string]metricData
sizeClassBuckets []float64
timeHistBuckets []float64
)
type metricData struct {
// deps is the set of runtime statistics that this metric
// depends on. Before compute is called, the statAggregate
// which will be passed must ensure() these dependencies.
deps statDepSet
// compute is a function that populates a metricValue
// given a populated statAggregate structure.
compute func(in *statAggregate, out *metricValue)
}
func metricsLock() {
// Acquire the metricsSema but with handoff. Operations are typically
// expensive enough that queueing up goroutines and handing off between
// them will be noticeably better-behaved.
semacquire1(&metricsSema, true, 0, 0, waitReasonSemacquire)
if raceenabled {
raceacquire(unsafe.Pointer(&metricsSema))
}
}
func metricsUnlock() {
if raceenabled {
racerelease(unsafe.Pointer(&metricsSema))
}
semrelease(&metricsSema)
}
// initMetrics initializes the metrics map if it hasn't been yet.
//
// metricsSema must be held.
func initMetrics() {
if metricsInit {
return
}
sizeClassBuckets = make([]float64, _NumSizeClasses, _NumSizeClasses+1)
// Skip size class 0 which is a stand-in for large objects, but large
// objects are tracked separately (and they actually get placed in
// the last bucket, not the first).
sizeClassBuckets[0] = 1 // The smallest allocation is 1 byte in size.
for i := 1; i < _NumSizeClasses; i++ {
// Size classes have an inclusive upper-bound
// and exclusive lower bound (e.g. 48-byte size class is
// (32, 48]) whereas we want an inclusive lower-bound
// and exclusive upper-bound (e.g. 48-byte size class is
// [33, 49). We can achieve this by shifting all bucket
// boundaries up by 1.
//
// Also, a float64 can precisely represent integers with
// value up to 2^53 and size classes are relatively small
// (nowhere near 2^48 even) so this will give us exact
// boundaries.
sizeClassBuckets[i] = float64(class_to_size[i] + 1)
}
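// For example (illustrative): if class_to_size[i] == 48, the class
// covers sizes in (32, 48], and storing 49 as the boundary yields the
// histogram bucket [33, 49).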
sizeClassBuckets = append(sizeClassBuckets, float64Inf())
timeHistBuckets = timeHistogramMetricsBuckets()
metrics = map[string]metricData{
"/cgo/go-to-c-calls:calls": {
compute: func(_ *statAggregate, out *metricValue) {
out.kind = metricKindUint64
out.scalar = uint64(NumCgoCall())
},
},
"/cpu/classes/gc/mark/assist:cpu-seconds": {
deps: makeStatDepSet(cpuStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindFloat64
out.scalar = float64bits(nsToSec(in.cpuStats.gcAssistTime))
},
},
"/cpu/classes/gc/mark/dedicated:cpu-seconds": {
deps: makeStatDepSet(cpuStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindFloat64
out.scalar = float64bits(nsToSec(in.cpuStats.gcDedicatedTime))
},
},
"/cpu/classes/gc/mark/idle:cpu-seconds": {
deps: makeStatDepSet(cpuStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindFloat64
out.scalar = float64bits(nsToSec(in.cpuStats.gcIdleTime))
},
},
"/cpu/classes/gc/pause:cpu-seconds": {
deps: makeStatDepSet(cpuStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindFloat64
out.scalar = float64bits(nsToSec(in.cpuStats.gcPauseTime))
},
},
"/cpu/classes/gc/total:cpu-seconds": {
deps: makeStatDepSet(cpuStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindFloat64
out.scalar = float64bits(nsToSec(in.cpuStats.gcTotalTime))
},
},
"/cpu/classes/idle:cpu-seconds": {
deps: makeStatDepSet(cpuStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindFloat64
out.scalar = float64bits(nsToSec(in.cpuStats.idleTime))
},
},
"/cpu/classes/scavenge/assist:cpu-seconds": {
deps: makeStatDepSet(cpuStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindFloat64
out.scalar = float64bits(nsToSec(in.cpuStats.scavengeAssistTime))
},
},
"/cpu/classes/scavenge/background:cpu-seconds": {
deps: makeStatDepSet(cpuStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindFloat64
out.scalar = float64bits(nsToSec(in.cpuStats.scavengeBgTime))
},
},
"/cpu/classes/scavenge/total:cpu-seconds": {
deps: makeStatDepSet(cpuStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindFloat64
out.scalar = float64bits(nsToSec(in.cpuStats.scavengeTotalTime))
},
},
"/cpu/classes/total:cpu-seconds": {
deps: makeStatDepSet(cpuStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindFloat64
out.scalar = float64bits(nsToSec(in.cpuStats.totalTime))
},
},
"/cpu/classes/user:cpu-seconds": {
deps: makeStatDepSet(cpuStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindFloat64
out.scalar = float64bits(nsToSec(in.cpuStats.userTime))
},
},
"/gc/cycles/automatic:gc-cycles": {
deps: makeStatDepSet(sysStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindUint64
out.scalar = in.sysStats.gcCyclesDone - in.sysStats.gcCyclesForced
},
},
"/gc/cycles/forced:gc-cycles": {
deps: makeStatDepSet(sysStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindUint64
out.scalar = in.sysStats.gcCyclesForced
},
},
"/gc/cycles/total:gc-cycles": {
deps: makeStatDepSet(sysStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindUint64
out.scalar = in.sysStats.gcCyclesDone
},
},
"/gc/heap/allocs-by-size:bytes": {
deps: makeStatDepSet(heapStatsDep),
compute: func(in *statAggregate, out *metricValue) {
hist := out.float64HistOrInit(sizeClassBuckets)
hist.counts[len(hist.counts)-1] = uint64(in.heapStats.largeAllocCount)
// Cut off the first index which is ostensibly for size class 0,
// but large objects are tracked separately so it's actually unused.
for i, count := range in.heapStats.smallAllocCount[1:] {
hist.counts[i] = uint64(count)
}
},
},
"/gc/heap/allocs:bytes": {
deps: makeStatDepSet(heapStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindUint64
out.scalar = in.heapStats.totalAllocated
},
},
"/gc/heap/allocs:objects": {
deps: makeStatDepSet(heapStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindUint64
out.scalar = in.heapStats.totalAllocs
},
},
"/gc/heap/frees-by-size:bytes": {
deps: makeStatDepSet(heapStatsDep),
compute: func(in *statAggregate, out *metricValue) {
hist := out.float64HistOrInit(sizeClassBuckets)
hist.counts[len(hist.counts)-1] = uint64(in.heapStats.largeFreeCount)
// Cut off the first index which is ostensibly for size class 0,
// but large objects are tracked separately so it's actually unused.
for i, count := range in.heapStats.smallFreeCount[1:] {
hist.counts[i] = uint64(count)
}
},
},
"/gc/heap/frees:bytes": {
deps: makeStatDepSet(heapStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindUint64
out.scalar = in.heapStats.totalFreed
},
},
"/gc/heap/frees:objects": {
deps: makeStatDepSet(heapStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindUint64
out.scalar = in.heapStats.totalFrees
},
},
"/gc/heap/goal:bytes": {
deps: makeStatDepSet(sysStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindUint64
out.scalar = in.sysStats.heapGoal
},
},
"/gc/heap/objects:objects": {
deps: makeStatDepSet(heapStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindUint64
out.scalar = in.heapStats.numObjects
},
},
"/gc/heap/tiny/allocs:objects": {
deps: makeStatDepSet(heapStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindUint64
out.scalar = uint64(in.heapStats.tinyAllocCount)
},
},
"/gc/limiter/last-enabled:gc-cycle": {
compute: func(_ *statAggregate, out *metricValue) {
out.kind = metricKindUint64
out.scalar = uint64(gcCPULimiter.lastEnabledCycle.Load())
},
},
"/gc/pauses:seconds": {
compute: func(_ *statAggregate, out *metricValue) {
hist := out.float64HistOrInit(timeHistBuckets)
// The bottom-most bucket, containing negative values, is tracked
// separately as underflow, so fill that in manually and then
// iterate over the rest.
hist.counts[0] = memstats.gcPauseDist.underflow.Load()
for i := range memstats.gcPauseDist.counts {
hist.counts[i+1] = memstats.gcPauseDist.counts[i].Load()
}
hist.counts[len(hist.counts)-1] = memstats.gcPauseDist.overflow.Load()
},
},
"/gc/stack/starting-size:bytes": {
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindUint64
out.scalar = uint64(startingStackSize)
},
},
"/godebug/non-default-behavior/execerrdot:events": {compute: compute0},
"/godebug/non-default-behavior/http2client:events": {compute: compute0},
"/godebug/non-default-behavior/http2server:events": {compute: compute0},
"/godebug/non-default-behavior/installgoroot:events": {compute: compute0},
"/godebug/non-default-behavior/panicnil:events": {compute: compute0},
"/godebug/non-default-behavior/randautoseed:events": {compute: compute0},
"/godebug/non-default-behavior/tarinsecurepath:events": {compute: compute0},
"/godebug/non-default-behavior/x509sha1:events": {compute: compute0},
"/godebug/non-default-behavior/x509usefallbackroots:events": {compute: compute0},
"/godebug/non-default-behavior/zipinsecurepath:events": {compute: compute0},
"/memory/classes/heap/free:bytes": {
deps: makeStatDepSet(heapStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindUint64
out.scalar = uint64(in.heapStats.committed - in.heapStats.inHeap -
in.heapStats.inStacks - in.heapStats.inWorkBufs -
in.heapStats.inPtrScalarBits)
},
},
"/memory/classes/heap/objects:bytes": {
deps: makeStatDepSet(heapStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindUint64
out.scalar = in.heapStats.inObjects
},
},
"/memory/classes/heap/released:bytes": {
deps: makeStatDepSet(heapStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindUint64
out.scalar = uint64(in.heapStats.released)
},
},
"/memory/classes/heap/stacks:bytes": {
deps: makeStatDepSet(heapStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindUint64
out.scalar = uint64(in.heapStats.inStacks)
},
},
"/memory/classes/heap/unused:bytes": {
deps: makeStatDepSet(heapStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindUint64
out.scalar = uint64(in.heapStats.inHeap) - in.heapStats.inObjects
},
},
"/memory/classes/metadata/mcache/free:bytes": {
deps: makeStatDepSet(sysStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindUint64
out.scalar = in.sysStats.mCacheSys - in.sysStats.mCacheInUse
},
},
"/memory/classes/metadata/mcache/inuse:bytes": {
deps: makeStatDepSet(sysStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindUint64
out.scalar = in.sysStats.mCacheInUse
},
},
"/memory/classes/metadata/mspan/free:bytes": {
deps: makeStatDepSet(sysStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindUint64
out.scalar = in.sysStats.mSpanSys - in.sysStats.mSpanInUse
},
},
"/memory/classes/metadata/mspan/inuse:bytes": {
deps: makeStatDepSet(sysStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindUint64
out.scalar = in.sysStats.mSpanInUse
},
},
"/memory/classes/metadata/other:bytes": {
deps: makeStatDepSet(heapStatsDep, sysStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindUint64
out.scalar = uint64(in.heapStats.inWorkBufs+in.heapStats.inPtrScalarBits) + in.sysStats.gcMiscSys
},
},
"/memory/classes/os-stacks:bytes": {
deps: makeStatDepSet(sysStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindUint64
out.scalar = in.sysStats.stacksSys
},
},
"/memory/classes/other:bytes": {
deps: makeStatDepSet(sysStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindUint64
out.scalar = in.sysStats.otherSys
},
},
"/memory/classes/profiling/buckets:bytes": {
deps: makeStatDepSet(sysStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindUint64
out.scalar = in.sysStats.buckHashSys
},
},
"/memory/classes/total:bytes": {
deps: makeStatDepSet(heapStatsDep, sysStatsDep),
compute: func(in *statAggregate, out *metricValue) {
out.kind = metricKindUint64
out.scalar = uint64(in.heapStats.committed+in.heapStats.released) +
in.sysStats.stacksSys + in.sysStats.mSpanSys +
in.sysStats.mCacheSys + in.sysStats.buckHashSys +
in.sysStats.gcMiscSys + in.sysStats.otherSys
},
},
"/sched/gomaxprocs:threads": {
compute: func(_ *statAggregate, out *metricValue) {
out.kind = metricKindUint64
out.scalar = uint64(gomaxprocs)
},
},
"/sched/goroutines:goroutines": {
compute: func(_ *statAggregate, out *metricValue) {
out.kind = metricKindUint64
out.scalar = uint64(gcount())
},
},
"/sched/latencies:seconds": {
compute: func(_ *statAggregate, out *metricValue) {
hist := out.float64HistOrInit(timeHistBuckets)
hist.counts[0] = sched.timeToRun.underflow.Load()
for i := range sched.timeToRun.counts {
hist.counts[i+1] = sched.timeToRun.counts[i].Load()
}
hist.counts[len(hist.counts)-1] = sched.timeToRun.overflow.Load()
},
},
"/sync/mutex/wait/total:seconds": {
compute: func(_ *statAggregate, out *metricValue) {
out.kind = metricKindFloat64
out.scalar = float64bits(nsToSec(sched.totalMutexWaitTime.Load()))
},
},
}
metricsInit = true
}
func compute0(_ *statAggregate, out *metricValue) {
out.kind = metricKindUint64
out.scalar = 0
}
type metricReader func() uint64
func (f metricReader) compute(_ *statAggregate, out *metricValue) {
out.kind = metricKindUint64
out.scalar = f()
}
var godebugNonDefaults = []string{
"panicnil",
}
//go:linkname godebug_registerMetric internal/godebug.registerMetric
func godebug_registerMetric(name string, read func() uint64) {
metricsLock()
initMetrics()
d, ok := metrics[name]
if !ok {
throw("runtime: unexpected metric registration for " + name)
}
d.compute = metricReader(read).compute
metrics[name] = d
metricsUnlock()
}
// statDep is a dependency on a group of statistics
// that a metric might have.
type statDep uint
const (
heapStatsDep statDep = iota // corresponds to heapStatsAggregate
sysStatsDep // corresponds to sysStatsAggregate
cpuStatsDep // corresponds to cpuStatsAggregate
numStatsDeps
)
// statDepSet represents a set of statDeps.
//
// Under the hood, it's a bitmap.
type statDepSet [1]uint64
// makeStatDepSet creates a new statDepSet from a list of statDeps.
func makeStatDepSet(deps ...statDep) statDepSet {
var s statDepSet
for _, d := range deps {
s[d/64] |= 1 << (d % 64)
}
return s
}
// difference returns set difference of s from b as a new set.
func (s statDepSet) difference(b statDepSet) statDepSet {
var c statDepSet
for i := range s {
c[i] = s[i] &^ b[i]
}
return c
}
// union returns the union of the two sets as a new set.
func (s statDepSet) union(b statDepSet) statDepSet {
var c statDepSet
for i := range s {
c[i] = s[i] | b[i]
}
return c
}
// empty returns true if there are no dependencies in the set.
func (s *statDepSet) empty() bool {
for _, c := range s {
if c != 0 {
return false
}
}
return true
}
// has returns true if the set contains a given statDep.
func (s *statDepSet) has(d statDep) bool {
return s[d/64]&(1<<(d%64)) != 0
}
// heapStatsAggregate represents memory stats obtained from the
// runtime. This set of stats is grouped together because they
// depend on each other in some way to make sense of the runtime's
// current heap memory use. They're also sharded across Ps, so it
// makes sense to grab them all at once.
type heapStatsAggregate struct {
heapStatsDelta
// Derived from values in heapStatsDelta.
// inObjects is the bytes of memory occupied by objects.
inObjects uint64
// numObjects is the number of live objects in the heap.
numObjects uint64
// totalAllocated is the total bytes of heap objects allocated
// over the lifetime of the program.
totalAllocated uint64
// totalFreed is the total bytes of heap objects freed
// over the lifetime of the program.
totalFreed uint64
// totalAllocs is the number of heap objects allocated over
// the lifetime of the program.
totalAllocs uint64
// totalFrees is the number of heap objects freed over
// the lifetime of the program.
totalFrees uint64
}
// compute populates the heapStatsAggregate with values from the runtime.
func (a *heapStatsAggregate) compute() {
memstats.heapStats.read(&a.heapStatsDelta)
// Calculate derived stats.
a.totalAllocs = a.largeAllocCount
a.totalFrees = a.largeFreeCount
a.totalAllocated = a.largeAlloc
a.totalFreed = a.largeFree
for i := range a.smallAllocCount {
na := a.smallAllocCount[i]
nf := a.smallFreeCount[i]
a.totalAllocs += na
a.totalFrees += nf
a.totalAllocated += na * uint64(class_to_size[i])
a.totalFreed += nf * uint64(class_to_size[i])
}
a.inObjects = a.totalAllocated - a.totalFreed
a.numObjects = a.totalAllocs - a.totalFrees
}
// sysStatsAggregate represents system memory stats obtained
// from the runtime. This set of stats is grouped together because
// they're all relatively cheap to acquire and generally independent
// of one another and other runtime memory stats. The fact that they
// may be acquired at different times, especially with respect to
// heapStatsAggregate, means there could be some skew, but because
// these stats are independent, there's no real consistency issue here.
type sysStatsAggregate struct {
stacksSys uint64
mSpanSys uint64
mSpanInUse uint64
mCacheSys uint64
mCacheInUse uint64
buckHashSys uint64
gcMiscSys uint64
otherSys uint64
heapGoal uint64
gcCyclesDone uint64
gcCyclesForced uint64
}
// compute populates the sysStatsAggregate with values from the runtime.
func (a *sysStatsAggregate) compute() {
a.stacksSys = memstats.stacks_sys.load()
a.buckHashSys = memstats.buckhash_sys.load()
a.gcMiscSys = memstats.gcMiscSys.load()
a.otherSys = memstats.other_sys.load()
a.heapGoal = gcController.heapGoal()
a.gcCyclesDone = uint64(memstats.numgc)
a.gcCyclesForced = uint64(memstats.numforcedgc)
systemstack(func() {
lock(&mheap_.lock)
a.mSpanSys = memstats.mspan_sys.load()
a.mSpanInUse = uint64(mheap_.spanalloc.inuse)
a.mCacheSys = memstats.mcache_sys.load()
a.mCacheInUse = uint64(mheap_.cachealloc.inuse)
unlock(&mheap_.lock)
})
}
// cpuStatsAggregate represents CPU stats obtained from the runtime
// acquired together to avoid skew and inconsistencies.
type cpuStatsAggregate struct {
cpuStats
}
// compute populates the cpuStatsAggregate with values from the runtime.
func (a *cpuStatsAggregate) compute() {
a.cpuStats = work.cpuStats
}
// nsToSec takes a duration in nanoseconds and converts it to seconds as
// a float64.
func nsToSec(ns int64) float64 {
return float64(ns) / 1e9
}
// statAggregate is the main driver of the metrics implementation.
//
// It contains multiple aggregates of runtime statistics, as well
// as a set of these aggregates that it has populated. The aggregates
// are populated lazily by its ensure method.
type statAggregate struct {
ensured statDepSet
heapStats heapStatsAggregate
sysStats sysStatsAggregate
cpuStats cpuStatsAggregate
}
// ensure populates statistics aggregates determined by deps if they
// haven't yet been populated.
func (a *statAggregate) ensure(deps *statDepSet) {
missing := deps.difference(a.ensured)
if missing.empty() {
return
}
for i := statDep(0); i < numStatsDeps; i++ {
if !missing.has(i) {
continue
}
switch i {
case heapStatsDep:
a.heapStats.compute()
case sysStatsDep:
a.sysStats.compute()
case cpuStatsDep:
a.cpuStats.compute()
}
}
a.ensured = a.ensured.union(missing)
}
// metricKind is a runtime copy of runtime/metrics.ValueKind and
// must be kept structurally identical to that type.
type metricKind int
const (
// These values must be kept identical to their corresponding Kind* values
// in the runtime/metrics package.
metricKindBad metricKind = iota
metricKindUint64
metricKindFloat64
metricKindFloat64Histogram
)
// metricSample is a runtime copy of runtime/metrics.Sample and
// must be kept structurally identical to that type.
type metricSample struct {
name string
value metricValue
}
// metricValue is a runtime copy of runtime/metrics.Value and
// must be kept structurally identical to that type.
type metricValue struct {
kind metricKind
scalar uint64 // contains scalar values for scalar Kinds.
pointer unsafe.Pointer // contains non-scalar values.
}
// float64HistOrInit tries to pull out an existing float64Histogram
// from the value, but if none exists, then it allocates one with
// the given buckets.
func (v *metricValue) float64HistOrInit(buckets []float64) *metricFloat64Histogram {
var hist *metricFloat64Histogram
if v.kind == metricKindFloat64Histogram && v.pointer != nil {
hist = (*metricFloat64Histogram)(v.pointer)
} else {
v.kind = metricKindFloat64Histogram
hist = new(metricFloat64Histogram)
v.pointer = unsafe.Pointer(hist)
}
hist.buckets = buckets
if len(hist.counts) != len(hist.buckets)-1 {
hist.counts = make([]uint64, len(buckets)-1)
}
return hist
}
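// Note (illustrative): reusing an existing histogram, including its
// counts slice when the bucket count matches, lets repeated readMetrics
// calls on the same samples avoid re-allocating on every read.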
// metricFloat64Histogram is a runtime copy of runtime/metrics.Float64Histogram
// and must be kept structurally identical to that type.
type metricFloat64Histogram struct {
counts []uint64
buckets []float64
}
// agg is used by readMetrics, and is protected by metricsSema.
//
// Managed as a global variable because its pointer will be
// an argument to a dynamically-defined function, and we'd
// like to avoid it escaping to the heap.
var agg statAggregate
type metricName struct {
name string
kind metricKind
}
// readMetricNames is the implementation of runtime/metrics.readMetricNames,
// used by the runtime/metrics test and otherwise unreferenced.
//
//go:linkname readMetricNames runtime/metrics_test.runtime_readMetricNames
func readMetricNames() []string {
metricsLock()
initMetrics()
n := len(metrics)
metricsUnlock()
list := make([]string, 0, n)
metricsLock()
for name := range metrics {
list = append(list, name)
}
metricsUnlock()
return list
}
// readMetrics is the implementation of runtime/metrics.Read.
//
//go:linkname readMetrics runtime/metrics.runtime_readMetrics
func readMetrics(samplesp unsafe.Pointer, len int, cap int) {
// Construct a slice from the args.
sl := slice{samplesp, len, cap}
samples := *(*[]metricSample)(unsafe.Pointer(&sl))
metricsLock()
// Ensure the map is initialized.
initMetrics()
// Clear agg defensively.
agg = statAggregate{}
// Sample.
for i := range samples {
sample := &samples[i]
data, ok := metrics[sample.name]
if !ok {
sample.value.kind = metricKindBad
continue
}
// Ensure we have all the stats we need.
// agg is populated lazily.
agg.ensure(&data.deps)
// Compute the value based on the stats we have.
data.compute(&agg, &sample.value)
}
metricsUnlock()
}
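// Illustrative usage sketch (user-level code, not part of the runtime):
// runtime/metrics.Read reaches this function via the linkname above.
//
//	samples := []metrics.Sample{{Name: "/memory/classes/total:bytes"}}
//	metrics.Read(samples)
//	if samples[0].Value.Kind() == metrics.KindUint64 {
//		total := samples[0].Value.Uint64() // current total mapped bytes
//		_ = total
//	}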
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package metrics
// Description describes a runtime metric.
type Description struct {
// Name is the full name of the metric which includes the unit.
//
// The format of the metric may be described by the following regular expression.
//
// ^(?P<name>/[^:]+):(?P<unit>[^:*/]+(?:[*/][^:*/]+)*)$
//
// The format splits the name into two components, separated by a colon: a path which always
// starts with a /, and a machine-parseable unit. The name may contain any valid Unicode
// codepoint in between / characters, but by convention will try to stick to lowercase
// characters and hyphens. An example of such a path might be "/memory/heap/free".
//
// The unit is by convention a series of lowercase English unit names (singular or plural)
// without prefixes delimited by '*' or '/'. The unit names may contain any valid Unicode
// codepoint that is not a delimiter.
// Examples of units might be "seconds", "bytes", "bytes/second", "cpu-seconds",
// "byte*cpu-seconds", and "bytes/second/second".
//
// For histograms, multiple units may apply. For instance, the units of the buckets and
// the count. By convention, for histograms, the units of the count are always "samples"
// with the type of sample evident by the metric's name, while the unit in the name
// specifies the buckets' unit.
//
// A complete name might look like "/memory/heap/free:bytes".
Name string
// Description is an English language sentence describing the metric.
Description string
// Kind is the kind of value for this metric.
//
// The purpose of this field is to allow users to filter out metrics whose values are
// types which their application may not understand.
Kind ValueKind
// Cumulative is whether or not the metric is cumulative. If a cumulative metric is just
// a single number, then it increases monotonically. If the metric is a distribution,
// then each bucket count increases monotonically.
//
// This flag thus indicates whether or not it's useful to compute a rate from this value.
Cumulative bool
}
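// Illustrative sketch (hypothetical helper, not part of this package):
// splitting a metric name into its path and unit at the final colon,
// per the format documented on Name.
//
//	name := "/memory/heap/free:bytes"
//	i := strings.LastIndex(name, ":")
//	path, unit := name[:i], name[i+1:] // "/memory/heap/free", "bytes"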
// The English language descriptions below must be kept in sync with the
// descriptions of each metric in doc.go.
var allDesc = []Description{
{
Name: "/cgo/go-to-c-calls:calls",
Description: "Count of calls made from Go to C by the current process.",
Kind: KindUint64,
Cumulative: true,
},
{
Name: "/cpu/classes/gc/mark/assist:cpu-seconds",
Description: "Estimated total CPU time goroutines spent performing GC tasks " +
"to assist the GC and prevent it from falling behind the application. " +
"This metric is an overestimate, and not directly comparable to " +
"system CPU time measurements. Compare only with other /cpu/classes " +
"metrics.",
Kind: KindFloat64,
Cumulative: true,
},
{
Name: "/cpu/classes/gc/mark/dedicated:cpu-seconds",
Description: "Estimated total CPU time spent performing GC tasks on " +
"processors (as defined by GOMAXPROCS) dedicated to those tasks. " +
"This includes time spent with the world stopped due to the GC. " +
"This metric is an overestimate, and not directly comparable to " +
"system CPU time measurements. Compare only with other /cpu/classes " +
"metrics.",
Kind: KindFloat64,
Cumulative: true,
},
{
Name: "/cpu/classes/gc/mark/idle:cpu-seconds",
Description: "Estimated total CPU time spent performing GC tasks on " +
"spare CPU resources that the Go scheduler could not otherwise find " +
"a use for. This should be subtracted from the total GC CPU time to " +
"obtain a measure of compulsory GC CPU time. " +
"This metric is an overestimate, and not directly comparable to " +
"system CPU time measurements. Compare only with other /cpu/classes " +
"metrics.",
Kind: KindFloat64,
Cumulative: true,
},
{
Name: "/cpu/classes/gc/pause:cpu-seconds",
Description: "Estimated total CPU time spent with the application paused by " +
"the GC. Even if only one thread is running during the pause, this is " +
"computed as GOMAXPROCS times the pause latency because nothing else " +
"can be executing. This is the exact sum of samples in /gc/pause:seconds " +
"if each sample is multiplied by GOMAXPROCS at the time it is taken. " +
"This metric is an overestimate, and not directly comparable to " +
"system CPU time measurements. Compare only with other /cpu/classes " +
"metrics.",
Kind: KindFloat64,
Cumulative: true,
},
{
Name: "/cpu/classes/gc/total:cpu-seconds",
Description: "Estimated total CPU time spent performing GC tasks. " +
"This metric is an overestimate, and not directly comparable to " +
"system CPU time measurements. Compare only with other /cpu/classes " +
"metrics. Sum of all metrics in /cpu/classes/gc.",
Kind: KindFloat64,
Cumulative: true,
},
{
Name: "/cpu/classes/idle:cpu-seconds",
Description: "Estimated total available CPU time not spent executing any Go or Go runtime code. " +
"In other words, the part of /cpu/classes/total:cpu-seconds that was unused. " +
"This metric is an overestimate, and not directly comparable to " +
"system CPU time measurements. Compare only with other /cpu/classes " +
"metrics.",
Kind: KindFloat64,
Cumulative: true,
},
{
Name: "/cpu/classes/scavenge/assist:cpu-seconds",
Description: "Estimated total CPU time spent returning unused memory to the " +
"underlying platform in response eagerly in response to memory pressure. " +
"This metric is an overestimate, and not directly comparable to " +
"system CPU time measurements. Compare only with other /cpu/classes " +
"metrics.",
Kind: KindFloat64,
Cumulative: true,
},
{
Name: "/cpu/classes/scavenge/background:cpu-seconds",
Description: "Estimated total CPU time spent performing background tasks " +
"to return unused memory to the underlying platform. " +
"This metric is an overestimate, and not directly comparable to " +
"system CPU time measurements. Compare only with other /cpu/classes " +
"metrics.",
Kind: KindFloat64,
Cumulative: true,
},
{
Name: "/cpu/classes/scavenge/total:cpu-seconds",
Description: "Estimated total CPU time spent performing tasks that return " +
"unused memory to the underlying platform. " +
"This metric is an overestimate, and not directly comparable to " +
"system CPU time measurements. Compare only with other /cpu/classes " +
"metrics. Sum of all metrics in /cpu/classes/scavenge.",
Kind: KindFloat64,
Cumulative: true,
},
{
Name: "/cpu/classes/total:cpu-seconds",
Description: "Estimated total available CPU time for user Go code " +
"or the Go runtime, as defined by GOMAXPROCS. In other words, GOMAXPROCS " +
"integrated over the wall-clock duration this process has been executing for. " +
"This metric is an overestimate, and not directly comparable to " +
"system CPU time measurements. Compare only with other /cpu/classes " +
"metrics. Sum of all metrics in /cpu/classes.",
Kind: KindFloat64,
Cumulative: true,
},
{
Name: "/cpu/classes/user:cpu-seconds",
Description: "Estimated total CPU time spent running user Go code. This may " +
"also include some small amount of time spent in the Go runtime. " +
"This metric is an overestimate, and not directly comparable to " +
"system CPU time measurements. Compare only with other /cpu/classes " +
"metrics.",
Kind: KindFloat64,
Cumulative: true,
},
{
Name: "/gc/cycles/automatic:gc-cycles",
Description: "Count of completed GC cycles generated by the Go runtime.",
Kind: KindUint64,
Cumulative: true,
},
{
Name: "/gc/cycles/forced:gc-cycles",
Description: "Count of completed GC cycles forced by the application.",
Kind: KindUint64,
Cumulative: true,
},
{
Name: "/gc/cycles/total:gc-cycles",
Description: "Count of all completed GC cycles.",
Kind: KindUint64,
Cumulative: true,
},
{
Name: "/gc/heap/allocs-by-size:bytes",
Description: "Distribution of heap allocations by approximate size. " +
"Note that this does not include tiny objects as defined by " +
"/gc/heap/tiny/allocs:objects, only tiny blocks.",
Kind: KindFloat64Histogram,
Cumulative: true,
},
{
Name: "/gc/heap/allocs:bytes",
Description: "Cumulative sum of memory allocated to the heap by the application.",
Kind: KindUint64,
Cumulative: true,
},
{
Name: "/gc/heap/allocs:objects",
Description: "Cumulative count of heap allocations triggered by the application. " +
"Note that this does not include tiny objects as defined by " +
"/gc/heap/tiny/allocs:objects, only tiny blocks.",
Kind: KindUint64,
Cumulative: true,
},
{
Name: "/gc/heap/frees-by-size:bytes",
Description: "Distribution of freed heap allocations by approximate size. " +
"Note that this does not include tiny objects as defined by " +
"/gc/heap/tiny/allocs:objects, only tiny blocks.",
Kind: KindFloat64Histogram,
Cumulative: true,
},
{
Name: "/gc/heap/frees:bytes",
Description: "Cumulative sum of heap memory freed by the garbage collector.",
Kind: KindUint64,
Cumulative: true,
},
{
Name: "/gc/heap/frees:objects",
Description: "Cumulative count of heap allocations whose storage was freed " +
"by the garbage collector. " +
"Note that this does not include tiny objects as defined by " +
"/gc/heap/tiny/allocs:objects, only tiny blocks.",
Kind: KindUint64,
Cumulative: true,
},
{
Name: "/gc/heap/goal:bytes",
Description: "Heap size target for the end of the GC cycle.",
Kind: KindUint64,
},
{
Name: "/gc/heap/objects:objects",
Description: "Number of objects, live or unswept, occupying heap memory.",
Kind: KindUint64,
},
{
Name: "/gc/heap/tiny/allocs:objects",
Description: "Count of small allocations that are packed together into blocks. " +
"These allocations are counted separately from other allocations " +
"because each individual allocation is not tracked by the runtime, " +
"only their block. Each block is already accounted for in " +
"allocs-by-size and frees-by-size.",
Kind: KindUint64,
Cumulative: true,
},
{
Name: "/gc/limiter/last-enabled:gc-cycle",
Description: "GC cycle the last time the GC CPU limiter was enabled. " +
"This metric is useful for diagnosing the root cause of an out-of-memory " +
"error, because the limiter trades memory for CPU time when the GC's CPU " +
"time gets too high. This is most likely to occur with use of SetMemoryLimit. " +
"The first GC cycle is cycle 1, so a value of 0 indicates that it was never enabled.",
Kind: KindUint64,
},
{
Name: "/gc/pauses:seconds",
Description: "Distribution individual GC-related stop-the-world pause latencies.",
Kind: KindFloat64Histogram,
Cumulative: true,
},
{
Name: "/gc/stack/starting-size:bytes",
Description: "The stack size of new goroutines.",
Kind: KindUint64,
Cumulative: false,
},
{
Name: "/godebug/non-default-behavior/execerrdot:events",
Description: "The number of non-default behaviors executed by the os/exec package " +
"due to a non-default GODEBUG=execerrdot=... setting.",
Kind: KindUint64,
Cumulative: true,
},
{
Name: "/godebug/non-default-behavior/http2client:events",
Description: "The number of non-default behaviors executed by the net/http package " +
"due to a non-default GODEBUG=http2client=... setting.",
Kind: KindUint64,
Cumulative: true,
},
{
Name: "/godebug/non-default-behavior/http2server:events",
Description: "The number of non-default behaviors executed by the net/http package " +
"due to a non-default GODEBUG=http2server=... setting.",
Kind: KindUint64,
Cumulative: true,
},
{
Name: "/godebug/non-default-behavior/installgoroot:events",
Description: "The number of non-default behaviors executed by the go/build package " +
"due to a non-default GODEBUG=installgoroot=... setting.",
Kind: KindUint64,
Cumulative: true,
},
{
Name: "/godebug/non-default-behavior/panicnil:events",
Description: "The number of non-default behaviors executed by the runtime package " +
"due to a non-default GODEBUG=panicnil=... setting.",
Kind: KindUint64,
Cumulative: true,
},
{
Name: "/godebug/non-default-behavior/randautoseed:events",
Description: "The number of non-default behaviors executed by the math/rand package " +
"due to a non-default GODEBUG=randautoseed=... setting.",
Kind: KindUint64,
Cumulative: true,
},
{
Name: "/godebug/non-default-behavior/tarinsecurepath:events",
Description: "The number of non-default behaviors executed by the archive/tar package " +
"due to a non-default GODEBUG=tarinsecurepath=... setting.",
Kind: KindUint64,
Cumulative: true,
},
{
Name: "/godebug/non-default-behavior/x509sha1:events",
Description: "The number of non-default behaviors executed by the crypto/x509 package " +
"due to a non-default GODEBUG=x509sha1=... setting.",
Kind: KindUint64,
Cumulative: true,
},
{
Name: "/godebug/non-default-behavior/x509usefallbackroots:events",
Description: "The number of non-default behaviors executed by the crypto/x509 package " +
"due to a non-default GODEBUG=x509usefallbackroots=... setting.",
Kind: KindUint64,
Cumulative: true,
},
{
Name: "/godebug/non-default-behavior/zipinsecurepath:events",
Description: "The number of non-default behaviors executed by the archive/zip package " +
"due to a non-default GODEBUG=zipinsecurepath=... setting.",
Kind: KindUint64,
Cumulative: true,
},
{
Name: "/memory/classes/heap/free:bytes",
Description: "Memory that is completely free and eligible to be returned to the underlying system, " +
"but has not been. This metric is the runtime's estimate of free address space that is backed by " +
"physical memory.",
Kind: KindUint64,
},
{
Name: "/memory/classes/heap/objects:bytes",
Description: "Memory occupied by live objects and dead objects that have not yet been marked free by the garbage collector.",
Kind: KindUint64,
},
{
Name: "/memory/classes/heap/released:bytes",
Description: "Memory that is completely free and has been returned to the underlying system. This " +
"metric is the runtime's estimate of free address space that is still mapped into the process, " +
"but is not backed by physical memory.",
Kind: KindUint64,
},
{
Name: "/memory/classes/heap/stacks:bytes",
Description: "Memory allocated from the heap that is reserved for stack space, whether or not it is currently in-use.",
Kind: KindUint64,
},
{
Name: "/memory/classes/heap/unused:bytes",
Description: "Memory that is reserved for heap objects but is not currently used to hold heap objects.",
Kind: KindUint64,
},
{
Name: "/memory/classes/metadata/mcache/free:bytes",
Description: "Memory that is reserved for runtime mcache structures, but not in-use.",
Kind: KindUint64,
},
{
Name: "/memory/classes/metadata/mcache/inuse:bytes",
Description: "Memory that is occupied by runtime mcache structures that are currently being used.",
Kind: KindUint64,
},
{
Name: "/memory/classes/metadata/mspan/free:bytes",
Description: "Memory that is reserved for runtime mspan structures, but not in-use.",
Kind: KindUint64,
},
{
Name: "/memory/classes/metadata/mspan/inuse:bytes",
Description: "Memory that is occupied by runtime mspan structures that are currently being used.",
Kind: KindUint64,
},
{
Name: "/memory/classes/metadata/other:bytes",
Description: "Memory that is reserved for or used to hold runtime metadata.",
Kind: KindUint64,
},
{
Name: "/memory/classes/os-stacks:bytes",
Description: "Stack memory allocated by the underlying operating system.",
Kind: KindUint64,
},
{
Name: "/memory/classes/other:bytes",
Description: "Memory used by execution trace buffers, structures for debugging the runtime, finalizer and profiler specials, and more.",
Kind: KindUint64,
},
{
Name: "/memory/classes/profiling/buckets:bytes",
Description: "Memory that is used by the stack trace hash map used for profiling.",
Kind: KindUint64,
},
{
Name: "/memory/classes/total:bytes",
Description: "All memory mapped by the Go runtime into the current process as read-write. Note that this does not include memory mapped by code called via cgo or via the syscall package. Sum of all metrics in /memory/classes.",
Kind: KindUint64,
},
{
Name: "/sched/gomaxprocs:threads",
Description: "The current runtime.GOMAXPROCS setting, or the number of operating system threads that can execute user-level Go code simultaneously.",
Kind: KindUint64,
},
{
Name: "/sched/goroutines:goroutines",
Description: "Count of live goroutines.",
Kind: KindUint64,
},
{
Name: "/sched/latencies:seconds",
Description: "Distribution of the time goroutines have spent in the scheduler in a runnable state before actually running.",
Kind: KindFloat64Histogram,
},
{
Name: "/sync/mutex/wait/total:seconds",
Description: "Approximate cumulative time goroutines have spent blocked on a sync.Mutex or sync.RWMutex. This metric is useful for identifying global changes in lock contention. Collect a mutex or block profile using the runtime/pprof package for more detailed contention data.",
Kind: KindFloat64,
Cumulative: true,
},
}
// All returns a slice containing metric descriptions for all supported metrics.
func All() []Description {
return allDesc
}
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package metrics
import (
_ "runtime" // depends on the runtime via a linkname'd function
"unsafe"
)
// Sample captures a single metric sample.
type Sample struct {
// Name is the name of the metric sampled.
//
// It must correspond to a name in one of the metric descriptions
// returned by All.
Name string
// Value is the value of the metric sample.
Value Value
}
// Implemented in the runtime.
func runtime_readMetrics(unsafe.Pointer, int, int)
// Read populates each Value field in the given slice of metric samples.
//
// Desired metrics should be present in the slice with the appropriate name.
// The user of this API is encouraged to re-use the same slice between calls for
// efficiency, but is not required to do so.
//
// Note that re-use has some caveats. Notably, Values should not be read or
// manipulated while a Read with that value is outstanding; that is a data race.
// This property includes pointer-typed Values (for example, Float64Histogram)
// whose underlying storage will be reused by Read when possible. To safely use
// such values in a concurrent setting, all data must be deep-copied.
//
// It is safe to execute multiple Read calls concurrently, but their arguments
// must share no underlying memory. When in doubt, create a new []Sample from
// scratch, which is always safe, though it may be inefficient.
//
// Sample values with names not appearing in All will have their Value populated
// as KindBad to indicate that the name is unknown.
func Read(m []Sample) {
runtime_readMetrics(unsafe.Pointer(&m[0]), len(m), cap(m))
}
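// Usage sketch (illustrative, not part of the original source): sample every
// supported metric once and dispatch on its kind, using only the API defined
// in this package (All, Read, Sample, ValueKind).
//
//	descs := All()
//	samples := make([]Sample, len(descs))
//	for i := range samples {
//		samples[i].Name = descs[i].Name
//	}
//	Read(samples)
//	for _, s := range samples {
//		switch s.Value.Kind() {
//		case KindUint64:
//			println(s.Name, s.Value.Uint64())
//		case KindFloat64:
//			println(s.Name, s.Value.Float64())
//		case KindFloat64Histogram:
//			println(s.Name, "histogram with", len(s.Value.Float64Histogram().Counts), "buckets")
//		default:
//			println(s.Name, "unsupported kind")
//		}
//	}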
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package metrics
import (
"math"
"unsafe"
)
// ValueKind is a tag for a metric Value which indicates its type.
type ValueKind int
const (
// KindBad indicates that the Value has no type and should not be used.
KindBad ValueKind = iota
// KindUint64 indicates that the type of the Value is a uint64.
KindUint64
// KindFloat64 indicates that the type of the Value is a float64.
KindFloat64
// KindFloat64Histogram indicates that the type of the Value is a *Float64Histogram.
KindFloat64Histogram
)
// Value represents a metric value returned by the runtime.
type Value struct {
kind ValueKind
scalar uint64 // contains scalar values for scalar Kinds.
pointer unsafe.Pointer // contains non-scalar values.
}
// Kind returns the tag representing the kind of value this is.
func (v Value) Kind() ValueKind {
return v.kind
}
// Uint64 returns the internal uint64 value for the metric.
//
// If v.Kind() != KindUint64, this method panics.
func (v Value) Uint64() uint64 {
if v.kind != KindUint64 {
panic("called Uint64 on non-uint64 metric value")
}
return v.scalar
}
// Float64 returns the internal float64 value for the metric.
//
// If v.Kind() != KindFloat64, this method panics.
func (v Value) Float64() float64 {
if v.kind != KindFloat64 {
panic("called Float64 on non-float64 metric value")
}
return math.Float64frombits(v.scalar)
}
// Float64Histogram returns the internal *Float64Histogram value for the metric.
//
// If v.Kind() != KindFloat64Histogram, this method panics.
func (v Value) Float64Histogram() *Float64Histogram {
if v.kind != KindFloat64Histogram {
panic("called Float64Histogram on non-Float64Histogram metric value")
}
return (*Float64Histogram)(v.pointer)
}
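// Deep-copy sketch (illustrative, not part of the original source): Read may
// reuse a *Float64Histogram's underlying storage, so a value handed to another
// goroutine needs its own copy. This assumes the exported Counts and Buckets
// slice fields of Float64Histogram defined elsewhere in this package.
//
//	func deepCopy(h *Float64Histogram) *Float64Histogram {
//		c := &Float64Histogram{
//			Counts:  make([]uint64, len(h.Counts)),
//			Buckets: make([]float64, len(h.Buckets)),
//		}
//		copy(c.Counts, h.Counts)
//		copy(c.Buckets, h.Buckets)
//		return c
//	}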
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Garbage collector: finalizers and block profiling.
package runtime
import (
"internal/abi"
"internal/goarch"
"runtime/internal/atomic"
"runtime/internal/sys"
"unsafe"
)
// finblock is an array of finalizers to be executed. finblocks are
// arranged in a linked list for the finalizer queue.
//
// finblock is allocated from non-GC'd memory, so any heap pointers
// must be specially handled. GC currently assumes that the finalizer
// queue does not grow during marking (but it can shrink).
type finblock struct {
_ sys.NotInHeap
alllink *finblock
next *finblock
cnt uint32
_ int32
fin [(_FinBlockSize - 2*goarch.PtrSize - 2*4) / unsafe.Sizeof(finalizer{})]finalizer
}
var fingStatus atomic.Uint32
// finalizer goroutine status.
const (
fingUninitialized uint32 = iota
fingCreated uint32 = 1 << (iota - 1)
fingRunningFinalizer
fingWait
fingWake
)
var finlock mutex // protects the following variables
var fing *g // goroutine that runs finalizers
var finq *finblock // list of finalizers that are to be executed
var finc *finblock // cache of free blocks
var finptrmask [_FinBlockSize / goarch.PtrSize / 8]byte
var allfin *finblock // list of all blocks
// NOTE: Layout known to queuefinalizer.
type finalizer struct {
fn *funcval // function to call (may be a heap pointer)
arg unsafe.Pointer // ptr to object (may be a heap pointer)
nret uintptr // bytes of return values from fn
fint *_type // type of first argument of fn
ot *ptrtype // type of ptr to object (may be a heap pointer)
}
var finalizer1 = [...]byte{
// Each Finalizer is 5 words, ptr ptr INT ptr ptr (INT = uintptr here)
// Each byte describes 8 words.
// Need 8 Finalizers described by 5 bytes before pattern repeats:
// ptr ptr INT ptr ptr
// ptr ptr INT ptr ptr
// ptr ptr INT ptr ptr
// ptr ptr INT ptr ptr
// ptr ptr INT ptr ptr
// ptr ptr INT ptr ptr
// ptr ptr INT ptr ptr
// ptr ptr INT ptr ptr
// aka
//
// ptr ptr INT ptr ptr ptr ptr INT
// ptr ptr ptr ptr INT ptr ptr ptr
// ptr INT ptr ptr ptr ptr INT ptr
// ptr ptr ptr INT ptr ptr ptr ptr
// INT ptr ptr ptr ptr INT ptr ptr
//
// Assumptions about Finalizer layout checked below.
1<<0 | 1<<1 | 0<<2 | 1<<3 | 1<<4 | 1<<5 | 1<<6 | 0<<7,
1<<0 | 1<<1 | 1<<2 | 1<<3 | 0<<4 | 1<<5 | 1<<6 | 1<<7,
1<<0 | 0<<1 | 1<<2 | 1<<3 | 1<<4 | 1<<5 | 0<<6 | 1<<7,
1<<0 | 1<<1 | 1<<2 | 0<<3 | 1<<4 | 1<<5 | 1<<6 | 1<<7,
0<<0 | 1<<1 | 1<<2 | 1<<3 | 1<<4 | 0<<5 | 1<<6 | 1<<7,
}
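// Derivation sketch (illustrative, not part of the original source): the five
// mask bytes above follow from the layout described in the comment. Each
// finalizer is 5 words (ptr ptr INT ptr ptr), each mask byte covers 8 words,
// and lcm(5, 8) = 40 words before the pattern repeats. Bit i of byte b is set
// iff word 8*b+i is a pointer, i.e. iff (8*b+i) mod 5 != 2.
//
//	var mask [5]byte
//	for w := 0; w < 40; w++ {
//		if w%5 != 2 { // word 2 of each finalizer is nret (uintptr, not a pointer)
//			mask[w/8] |= 1 << (w % 8)
//		}
//	}
//	// mask now equals finalizer1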
// lockRankMayQueueFinalizer records the lock ranking effects of a
// function that may call queuefinalizer.
func lockRankMayQueueFinalizer() {
lockWithRankMayAcquire(&finlock, getLockRank(&finlock))
}
func queuefinalizer(p unsafe.Pointer, fn *funcval, nret uintptr, fint *_type, ot *ptrtype) {
if gcphase != _GCoff {
// Currently we assume that the finalizer queue won't
// grow during marking so we don't have to rescan it
// during mark termination. If we ever need to lift
// this assumption, we can do it by adding the
// necessary barriers to queuefinalizer (which it may
// have automatically).
throw("queuefinalizer during GC")
}
lock(&finlock)
if finq == nil || finq.cnt == uint32(len(finq.fin)) {
if finc == nil {
finc = (*finblock)(persistentalloc(_FinBlockSize, 0, &memstats.gcMiscSys))
finc.alllink = allfin
allfin = finc
if finptrmask[0] == 0 {
// Build pointer mask for Finalizer array in block.
// Check assumptions made in finalizer1 array above.
if (unsafe.Sizeof(finalizer{}) != 5*goarch.PtrSize ||
unsafe.Offsetof(finalizer{}.fn) != 0 ||
unsafe.Offsetof(finalizer{}.arg) != goarch.PtrSize ||
unsafe.Offsetof(finalizer{}.nret) != 2*goarch.PtrSize ||
unsafe.Offsetof(finalizer{}.fint) != 3*goarch.PtrSize ||
unsafe.Offsetof(finalizer{}.ot) != 4*goarch.PtrSize) {
throw("finalizer out of sync")
}
for i := range finptrmask {
finptrmask[i] = finalizer1[i%len(finalizer1)]
}
}
}
block := finc
finc = block.next
block.next = finq
finq = block
}
f := &finq.fin[finq.cnt]
atomic.Xadd(&finq.cnt, +1) // Sync with markroots
f.fn = fn
f.nret = nret
f.fint = fint
f.ot = ot
f.arg = p
unlock(&finlock)
fingStatus.Or(fingWake)
}
//go:nowritebarrier
func iterate_finq(callback func(*funcval, unsafe.Pointer, uintptr, *_type, *ptrtype)) {
for fb := allfin; fb != nil; fb = fb.alllink {
for i := uint32(0); i < fb.cnt; i++ {
f := &fb.fin[i]
callback(f.fn, f.arg, f.nret, f.fint, f.ot)
}
}
}
func wakefing() *g {
if ok := fingStatus.CompareAndSwap(fingCreated|fingWait|fingWake, fingCreated); ok {
return fing
}
return nil
}
func createfing() {
// start the finalizer goroutine exactly once
if fingStatus.Load() == fingUninitialized && fingStatus.CompareAndSwap(fingUninitialized, fingCreated) {
go runfinq()
}
}
func finalizercommit(gp *g, lock unsafe.Pointer) bool {
unlock((*mutex)(lock))
// fingStatus should be modified after fing is put into a waiting state
// to avoid waking fing in running state, even if it is about to be parked.
fingStatus.Or(fingWait)
return true
}
// This is the goroutine that runs all of the finalizers.
func runfinq() {
var (
frame unsafe.Pointer
framecap uintptr
argRegs int
)
gp := getg()
lock(&finlock)
fing = gp
unlock(&finlock)
for {
lock(&finlock)
fb := finq
finq = nil
if fb == nil {
gopark(finalizercommit, unsafe.Pointer(&finlock), waitReasonFinalizerWait, traceEvGoBlock, 1)
continue
}
argRegs = intArgRegs
unlock(&finlock)
if raceenabled {
racefingo()
}
for fb != nil {
for i := fb.cnt; i > 0; i-- {
f := &fb.fin[i-1]
var regs abi.RegArgs
// The args may be passed in registers or on stack. Even for
// the register case, we still need the spill slots.
// TODO: revisit if we remove spill slots.
//
// Unfortunately, because we can have an arbitrary
// number of return values and it would be complex to
// figure out how many of those can be passed in registers,
// just conservatively assume none of them are.
framesz := unsafe.Sizeof((any)(nil)) + f.nret
if framecap < framesz {
// The frame does not contain pointers interesting for GC;
// all not-yet-finalized objects are stored in finq.
// If we do not mark it as FlagNoScan,
// the last finalized object is not collected.
frame = mallocgc(framesz, nil, true)
framecap = framesz
}
if f.fint == nil {
throw("missing type in runfinq")
}
r := frame
if argRegs > 0 {
r = unsafe.Pointer(&regs.Ints)
} else {
// frame is effectively uninitialized
// memory. That means we have to clear
// it before writing to it to avoid
// confusing the write barrier.
*(*[2]uintptr)(frame) = [2]uintptr{}
}
switch f.fint.kind & kindMask {
case kindPtr:
// direct use of pointer
*(*unsafe.Pointer)(r) = f.arg
case kindInterface:
ityp := (*interfacetype)(unsafe.Pointer(f.fint))
// set up with empty interface
(*eface)(r)._type = &f.ot.typ
(*eface)(r).data = f.arg
if len(ityp.mhdr) != 0 {
// convert to interface with methods
// this conversion is guaranteed to succeed - we checked in SetFinalizer
(*iface)(r).tab = assertE2I(ityp, (*eface)(r)._type)
}
default:
throw("bad kind in runfinq")
}
fingStatus.Or(fingRunningFinalizer)
reflectcall(nil, unsafe.Pointer(f.fn), frame, uint32(framesz), uint32(framesz), uint32(framesz), &regs)
fingStatus.And(^fingRunningFinalizer)
// Drop finalizer queue heap references
// before hiding them from markroot.
// This also ensures these will be
// clear if we reuse the finalizer.
f.fn = nil
f.arg = nil
f.ot = nil
atomic.Store(&fb.cnt, i-1)
}
next := fb.next
lock(&finlock)
fb.next = finc
finc = fb
unlock(&finlock)
fb = next
}
}
}
// SetFinalizer sets the finalizer associated with obj to the provided
// finalizer function. When the garbage collector finds an unreachable block
// with an associated finalizer, it clears the association and runs
// finalizer(obj) in a separate goroutine. This makes obj reachable again,
// but now without an associated finalizer. Assuming that SetFinalizer
// is not called again, the next time the garbage collector sees
// that obj is unreachable, it will free obj.
//
// SetFinalizer(obj, nil) clears any finalizer associated with obj.
//
// The argument obj must be a pointer to an object allocated by calling
// new, by taking the address of a composite literal, or by taking the
// address of a local variable.
// The argument finalizer must be a function that takes a single argument
// to which obj's type can be assigned, and can have arbitrary ignored return
// values. If either of these is not true, SetFinalizer may abort the
// program.
//
// Finalizers are run in dependency order: if A points at B, both have
// finalizers, and they are otherwise unreachable, only the finalizer
// for A runs; once A is freed, the finalizer for B can run.
// If a cyclic structure includes a block with a finalizer, that
// cycle is not guaranteed to be garbage collected and the finalizer
// is not guaranteed to run, because there is no ordering that
// respects the dependencies.
//
// The finalizer is scheduled to run at some arbitrary time after the
// program can no longer reach the object to which obj points.
// There is no guarantee that finalizers will run before a program exits,
// so typically they are useful only for releasing non-memory resources
// associated with an object during a long-running program.
// For example, an os.File object could use a finalizer to close the
// associated operating system file descriptor when a program discards
// an os.File without calling Close, but it would be a mistake
// to depend on a finalizer to flush an in-memory I/O buffer such as a
// bufio.Writer, because the buffer would not be flushed at program exit.
//
// It is not guaranteed that a finalizer will run if the size of *obj is
// zero bytes, because it may share the same address with other zero-size
// objects in memory. See https://go.dev/ref/spec#Size_and_alignment_guarantees.
//
// It is not guaranteed that a finalizer will run for objects allocated
// in initializers for package-level variables. Such objects may be
// linker-allocated, not heap-allocated.
//
// Note that because finalizers may execute arbitrarily far into the future
// after an object is no longer referenced, the runtime is allowed to perform
// a space-saving optimization that batches objects together in a single
// allocation slot. The finalizer for an unreferenced object in such an
// allocation may never run if it always exists in the same batch as a
// referenced object. Typically, this batching only happens for tiny
// (on the order of 16 bytes or less) and pointer-free objects.
//
// A finalizer may run as soon as an object becomes unreachable.
// In order to use finalizers correctly, the program must ensure that
// the object is reachable until it is no longer required.
// Objects stored in global variables, or that can be found by tracing
// pointers from a global variable, are reachable. For other objects,
// pass the object to a call of the KeepAlive function to mark the
// last point in the function where the object must be reachable.
//
// For example, if p points to a struct, such as os.File, that contains
// a file descriptor d, and p has a finalizer that closes that file
// descriptor, and if the last use of p in a function is a call to
// syscall.Write(p.d, buf, size), then p may be unreachable as soon as
// the program enters syscall.Write. The finalizer may run at that moment,
// closing p.d, causing syscall.Write to fail because it is writing to
// a closed file descriptor (or, worse, to an entirely different
// file descriptor opened by a different goroutine). To avoid this problem,
// call KeepAlive(p) after the call to syscall.Write.
//
// A single goroutine runs all finalizers for a program, sequentially.
// If a finalizer must run for a long time, it should do so by starting
// a new goroutine.
//
// In the terminology of the Go memory model, a call
// SetFinalizer(x, f) “synchronizes before” the finalization call f(x).
// However, there is no guarantee that KeepAlive(x) or any other use of x
// “synchronizes before” f(x), so in general a finalizer should use a mutex
// or other synchronization mechanism if it needs to access mutable state in x.
// For example, consider a finalizer that inspects a mutable field in x
// that is modified from time to time in the main program before x
// becomes unreachable and the finalizer is invoked.
// The modifications in the main program and the inspection in the finalizer
// need to use appropriate synchronization, such as mutexes or atomic updates,
// to avoid read-write races.
func SetFinalizer(obj any, finalizer any) {
if debug.sbrk != 0 {
// debug.sbrk never frees memory, so no finalizers run
// (and we don't have the data structures to record them).
return
}
e := efaceOf(&obj)
etyp := e._type
if etyp == nil {
throw("runtime.SetFinalizer: first argument is nil")
}
if etyp.kind&kindMask != kindPtr {
throw("runtime.SetFinalizer: first argument is " + etyp.string() + ", not pointer")
}
ot := (*ptrtype)(unsafe.Pointer(etyp))
if ot.elem == nil {
throw("nil elem type!")
}
if inUserArenaChunk(uintptr(e.data)) {
// Arena-allocated objects are not eligible for finalizers.
throw("runtime.SetFinalizer: first argument was allocated into an arena")
}
// find the containing object
base, _, _ := findObject(uintptr(e.data), 0, 0)
if base == 0 {
// 0-length objects are okay.
if e.data == unsafe.Pointer(&zerobase) {
return
}
// Global initializers might be linker-allocated.
// var Foo = &Object{}
// func main() {
// runtime.SetFinalizer(Foo, nil)
// }
// The relevant segments are: noptrdata, data, bss, noptrbss.
// We cannot assume they are in any order or even contiguous,
// due to external linking.
for datap := &firstmoduledata; datap != nil; datap = datap.next {
if datap.noptrdata <= uintptr(e.data) && uintptr(e.data) < datap.enoptrdata ||
datap.data <= uintptr(e.data) && uintptr(e.data) < datap.edata ||
datap.bss <= uintptr(e.data) && uintptr(e.data) < datap.ebss ||
datap.noptrbss <= uintptr(e.data) && uintptr(e.data) < datap.enoptrbss {
return
}
}
throw("runtime.SetFinalizer: pointer not in allocated block")
}
if uintptr(e.data) != base {
// As an implementation detail, we allow setting finalizers for an inner byte
// of an object if it could come from tiny alloc (see mallocgc for details).
if ot.elem == nil || ot.elem.ptrdata != 0 || ot.elem.size >= maxTinySize {
throw("runtime.SetFinalizer: pointer not at beginning of allocated block")
}
}
f := efaceOf(&finalizer)
ftyp := f._type
if ftyp == nil {
// switch to system stack and remove finalizer
systemstack(func() {
removefinalizer(e.data)
})
return
}
if ftyp.kind&kindMask != kindFunc {
throw("runtime.SetFinalizer: second argument is " + ftyp.string() + ", not a function")
}
ft := (*functype)(unsafe.Pointer(ftyp))
if ft.dotdotdot() {
throw("runtime.SetFinalizer: cannot pass " + etyp.string() + " to finalizer " + ftyp.string() + " because dotdotdot")
}
if ft.inCount != 1 {
throw("runtime.SetFinalizer: cannot pass " + etyp.string() + " to finalizer " + ftyp.string())
}
fint := ft.in()[0]
switch {
case fint == etyp:
// ok - same type
goto okarg
case fint.kind&kindMask == kindPtr:
if (fint.uncommon() == nil || etyp.uncommon() == nil) && (*ptrtype)(unsafe.Pointer(fint)).elem == ot.elem {
// ok - not same type, but both pointers,
// one or the other is unnamed, and same element type, so assignable.
goto okarg
}
case fint.kind&kindMask == kindInterface:
ityp := (*interfacetype)(unsafe.Pointer(fint))
if len(ityp.mhdr) == 0 {
// ok - satisfies empty interface
goto okarg
}
if iface := assertE2I2(ityp, *efaceOf(&obj)); iface.tab != nil {
goto okarg
}
}
throw("runtime.SetFinalizer: cannot pass " + etyp.string() + " to finalizer " + ftyp.string())
okarg:
// compute size needed for return parameters
nret := uintptr(0)
for _, t := range ft.out() {
nret = alignUp(nret, uintptr(t.align)) + uintptr(t.size)
}
nret = alignUp(nret, goarch.PtrSize)
// make sure we have a finalizer goroutine
createfing()
systemstack(func() {
if !addfinalizer(e.data, (*funcval)(f.data), nret, fint, ot) {
throw("runtime.SetFinalizer: finalizer already set")
}
})
}
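// Usage sketch (illustrative, not part of the original source; it mirrors the
// os.File scenario described in the doc comment above with a hypothetical
// wrapper type; error handling is minimal):
//
//	type file struct{ fd int }
//
//	func open(path string) (*file, error) {
//		fd, err := syscall.Open(path, syscall.O_RDONLY, 0)
//		if err != nil {
//			return nil, err
//		}
//		f := &file{fd}
//		runtime.SetFinalizer(f, func(f *file) { syscall.Close(f.fd) })
//		return f, nil
//	}
//
//	// Callers must use runtime.KeepAlive(f) after the last use of f.fd.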
// Mark KeepAlive as noinline so that it is easily detectable as an intrinsic.
//
//go:noinline
// KeepAlive marks its argument as currently reachable.
// This ensures that the object is not freed, and its finalizer is not run,
// before the point in the program where KeepAlive is called.
//
// A very simplified example showing where KeepAlive is required:
//
// type File struct { d int }
// d, err := syscall.Open("/file/path", syscall.O_RDONLY, 0)
// // ... do something if err != nil ...
// p := &File{d}
// runtime.SetFinalizer(p, func(p *File) { syscall.Close(p.d) })
// var buf [10]byte
// n, err := syscall.Read(p.d, buf[:])
// // Ensure p is not finalized until Read returns.
// runtime.KeepAlive(p)
// // No more uses of p after this point.
//
// Without the KeepAlive call, the finalizer could run at the start of
// syscall.Read, closing the file descriptor before syscall.Read makes
// the actual system call.
//
// Note: KeepAlive should only be used to prevent finalizers from
// running prematurely. In particular, when used with unsafe.Pointer,
// the rules for valid uses of unsafe.Pointer still apply.
func KeepAlive(x any) {
// Introduce a use of x that the compiler can't eliminate.
// This makes sure x is alive on entry. We need x to be alive
// on entry for "defer runtime.KeepAlive(x)"; see issue 21402.
if cgoAlwaysFalse {
println(x)
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Fixed-size object allocator. Returned memory is zeroed by default (see the
// zero flag below).
//
// See malloc.go for overview.
package runtime
import (
"runtime/internal/sys"
"unsafe"
)
// FixAlloc is a simple free-list allocator for fixed size objects.
// Malloc uses a FixAlloc wrapped around sysAlloc to manage its
// mcache and mspan objects.
//
// Memory returned by fixalloc.alloc is zeroed by default, but the
// caller may take responsibility for zeroing allocations by setting
// the zero flag to false. This is only safe if the memory never
// contains heap pointers.
//
// The caller is responsible for locking around FixAlloc calls.
// Callers can keep state in the object but the first word is
// smashed by freeing and reallocating.
//
// Consider marking fixalloc'd types not in heap by embedding
// runtime/internal/sys.NotInHeap.
type fixalloc struct {
size uintptr
first func(arg, p unsafe.Pointer) // called first time p is returned
arg unsafe.Pointer
list *mlink
chunk uintptr // use uintptr instead of unsafe.Pointer to avoid write barriers
nchunk uint32 // bytes remaining in current chunk
nalloc uint32 // size of new chunks in bytes
inuse uintptr // in-use bytes now
stat *sysMemStat
zero bool // zero allocations
}
// A generic linked list of blocks. (Typically the block is bigger than sizeof(MLink).)
// Since assignments to mlink.next will result in a write barrier being performed,
// this cannot be used by some of the internal GC structures. For example, when
// the sweeper is placing an unmarked object on the free list, it does not want the
// write barrier to be called, since that could result in the object being reachable.
type mlink struct {
_ sys.NotInHeap
next *mlink
}
// Initialize f to allocate objects of the given size,
// using the allocator to obtain chunks of memory.
func (f *fixalloc) init(size uintptr, first func(arg, p unsafe.Pointer), arg unsafe.Pointer, stat *sysMemStat) {
if size > _FixAllocChunk {
throw("runtime: fixalloc size too large")
}
if min := unsafe.Sizeof(mlink{}); size < min {
size = min
}
f.size = size
f.first = first
f.arg = arg
f.list = nil
f.chunk = 0
f.nchunk = 0
f.nalloc = uint32(_FixAllocChunk / size * size) // Round _FixAllocChunk down to an exact multiple of size to eliminate tail waste
f.inuse = 0
f.stat = stat
f.zero = true
}
func (f *fixalloc) alloc() unsafe.Pointer {
if f.size == 0 {
print("runtime: use of FixAlloc_Alloc before FixAlloc_Init\n")
throw("runtime: internal error")
}
if f.list != nil {
v := unsafe.Pointer(f.list)
f.list = f.list.next
f.inuse += f.size
if f.zero {
memclrNoHeapPointers(v, f.size)
}
return v
}
if uintptr(f.nchunk) < f.size {
f.chunk = uintptr(persistentalloc(uintptr(f.nalloc), 0, f.stat))
f.nchunk = f.nalloc
}
v := unsafe.Pointer(f.chunk)
if f.first != nil {
f.first(f.arg, v)
}
f.chunk = f.chunk + f.size
f.nchunk -= uint32(f.size)
f.inuse += f.size
return v
}
func (f *fixalloc) free(p unsafe.Pointer) {
f.inuse -= f.size
v := (*mlink)(p)
v.next = f.list
f.list = v
}
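// Usage sketch (illustrative, not part of the original source; it mirrors the
// pattern mheap uses for its mspan fixalloc — firstFn and arg here are
// placeholders for the optional first-use callback and its argument):
//
//	var alloc fixalloc
//	alloc.init(unsafe.Sizeof(mspan{}), firstFn, arg, &memstats.mspan_sys)
//	// All calls below must happen under the owner's lock.
//	s := (*mspan)(alloc.alloc()) // zeroed, since init sets zero = true
//	// ... use s ...
//	alloc.free(unsafe.Pointer(s)) // smashes the first word of *s via the free list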
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Garbage collector (GC).
//
// The GC runs concurrently with mutator threads, is type accurate (aka precise), and allows multiple
// GC threads to run in parallel. It is a concurrent mark and sweep that uses a write barrier. It is
// non-generational and non-compacting. Allocation is done using size segregated per P allocation
// areas to minimize fragmentation while eliminating locks in the common case.
//
// The algorithm decomposes into several steps.
// This is a high level description of the algorithm being used. For an overview of GC a good
// place to start is Richard Jones' gchandbook.org.
//
// The algorithm's intellectual heritage includes Dijkstra's on-the-fly algorithm, see
// Edsger W. Dijkstra, Leslie Lamport, A. J. Martin, C. S. Scholten, and E. F. M. Steffens. 1978.
// On-the-fly garbage collection: an exercise in cooperation. Commun. ACM 21, 11 (November 1978),
// 966-975.
// For journal quality proofs that these steps are complete, correct, and terminate see
// Hudson, R., and Moss, J.E.B. Copying Garbage Collection without stopping the world.
// Concurrency and Computation: Practice and Experience 15(3-5), 2003.
//
// 1. GC performs sweep termination.
//
// a. Stop the world. This causes all Ps to reach a GC safe-point.
//
// b. Sweep any unswept spans. There will only be unswept spans if
// this GC cycle was forced before the expected time.
//
// 2. GC performs the mark phase.
//
// a. Prepare for the mark phase by setting gcphase to _GCmark
// (from _GCoff), enabling the write barrier, enabling mutator
// assists, and enqueueing root mark jobs. No objects may be
// scanned until all Ps have enabled the write barrier, which is
// accomplished using STW.
//
// b. Start the world. From this point, GC work is done by mark
// workers started by the scheduler and by assists performed as
// part of allocation. The write barrier shades both the
// overwritten pointer and the new pointer value for any pointer
// writes (see mbarrier.go for details). Newly allocated objects
// are immediately marked black.
//
// c. GC performs root marking jobs. This includes scanning all
// stacks, shading all globals, and shading any heap pointers in
// off-heap runtime data structures. Scanning a stack stops a
// goroutine, shades any pointers found on its stack, and then
// resumes the goroutine.
//
// d. GC drains the work queue of grey objects, scanning each grey
// object to black and shading all pointers found in the object
// (which in turn may add those pointers to the work queue).
//
// e. Because GC work is spread across local caches, GC uses a
// distributed termination algorithm to detect when there are no
// more root marking jobs or grey objects (see gcMarkDone). At this
// point, GC transitions to mark termination.
//
// 3. GC performs mark termination.
//
// a. Stop the world.
//
// b. Set gcphase to _GCmarktermination, and disable workers and
// assists.
//
// c. Perform housekeeping like flushing mcaches.
//
// 4. GC performs the sweep phase.
//
// a. Prepare for the sweep phase by setting gcphase to _GCoff,
// setting up sweep state and disabling the write barrier.
//
// b. Start the world. From this point on, newly allocated objects
// are white, and allocating sweeps spans before use if necessary.
//
// c. GC does concurrent sweeping in the background and in response
// to allocation. See description below.
//
// 5. When sufficient allocation has taken place, replay the sequence
// starting with 1 above. See discussion of GC rate below.
// Concurrent sweep.
//
// The sweep phase proceeds concurrently with normal program execution.
// The heap is swept span-by-span both lazily (when a goroutine needs another span)
// and concurrently in a background goroutine (this helps programs that are not CPU bound).
// At the end of STW mark termination all spans are marked as "needs sweeping".
//
// The background sweeper goroutine simply sweeps spans one-by-one.
//
// To avoid requesting more OS memory while there are unswept spans, when a
// goroutine needs another span, it first attempts to reclaim that much memory
// by sweeping. When a goroutine needs to allocate a new small-object span, it
// sweeps small-object spans for the same object size until it frees at least
// one object. When a goroutine needs to allocate a large-object span from the heap,
// it sweeps spans until it frees at least that many pages into the heap. There is
// one case where this may not suffice: if a goroutine sweeps and frees two
// nonadjacent one-page spans to the heap, it will allocate a new two-page
// span, but there can still be other one-page unswept spans which could be
// combined into a two-page span.
//
// It's critical to ensure that no operations proceed on unswept spans (that would corrupt
// mark bits in GC bitmap). During GC all mcaches are flushed into the central cache,
// so they are empty. When a goroutine grabs a new span into mcache, it sweeps it.
// When a goroutine explicitly frees an object or sets a finalizer, it ensures that
// the span is swept (either by sweeping it, or by waiting for the concurrent sweep to finish).
// The finalizer goroutine is kicked off only when all spans are swept.
// When the next GC starts, it sweeps all not-yet-swept spans (if any).
// GC rate.
// Next GC is after we've allocated an extra amount of memory proportional to
// the amount already in use. The proportion is controlled by GOGC environment variable
// (100 by default). If GOGC=100 and we're using 4M, we'll GC again when we get to 8M
// (this mark is computed by the gcController.heapGoal method). This keeps the GC cost in
// linear proportion to the allocation cost. Adjusting GOGC just changes the linear constant
// (and also the amount of extra memory used).
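// Arithmetic sketch (simplified; the goal actually computed by
// gcController.heapGoal also accounts for GOMEMLIMIT and other terms):
//
//	goal = live + live*GOGC/100
//	// GOGC=100, live=4M => goal = 8M
//	// GOGC=50,  live=4M => goal = 6M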
// Oblets
//
// In order to prevent long pauses while scanning large objects and to
// improve parallelism, the garbage collector breaks up scan jobs for
// objects larger than maxObletBytes into "oblets" of at most
// maxObletBytes. When scanning encounters the beginning of a large
// object, it scans only the first oblet and enqueues the remaining
// oblets as new scan jobs.
package runtime
import (
"internal/cpu"
"runtime/internal/atomic"
"unsafe"
)
const (
_DebugGC = 0
_ConcurrentSweep = true
_FinBlockSize = 4 * 1024
// debugScanConservative enables debug logging for stack
// frames that are scanned conservatively.
debugScanConservative = false
// sweepMinHeapDistance is a lower bound on the heap distance
// (in bytes) reserved for concurrent sweeping between GC
// cycles.
sweepMinHeapDistance = 1024 * 1024
)
func gcinit() {
if unsafe.Sizeof(workbuf{}) != _WorkbufSize {
throw("size of Workbuf is suboptimal")
}
// No sweep on the first cycle.
sweep.active.state.Store(sweepDrainedMask)
// Initialize GC pacer state.
// Use the environment variable GOGC for the initial gcPercent value.
// Use the environment variable GOMEMLIMIT for the initial memoryLimit value.
gcController.init(readGOGC(), readGOMEMLIMIT())
work.startSema = 1
work.markDoneSema = 1
lockInit(&work.sweepWaiters.lock, lockRankSweepWaiters)
lockInit(&work.assistQueue.lock, lockRankAssistQueue)
lockInit(&work.wbufSpans.lock, lockRankWbufSpans)
}
// gcenable is called after the bulk of the runtime initialization,
// just before we're about to start letting user code run.
// It kicks off the background sweeper goroutine, the background
// scavenger goroutine, and enables GC.
func gcenable() {
// Kick off sweeping and scavenging.
c := make(chan int, 2)
go bgsweep(c)
go bgscavenge(c)
<-c
<-c
memstats.enablegc = true // now that runtime is initialized, GC is okay
}
// Garbage collector phase.
// It indicates to the write barrier and synchronization code which task to perform.
var gcphase uint32
// The compiler knows about this variable.
// If you change it, you must change builtin/runtime.go, too.
// If you change the first four bytes, you must also change the write
// barrier insertion code.
var writeBarrier struct {
enabled bool // compiler emits a check of this before calling write barrier
pad [3]byte // compiler uses 32-bit load for "enabled" field
needed bool // identical to enabled, for now (TODO: dedup)
alignme uint64 // guarantee alignment so that compiler can use a 32 or 64-bit load
}
// gcBlackenEnabled is 1 if mutator assists and background mark
// workers are allowed to blacken objects. This must only be set when
// gcphase == _GCmark.
var gcBlackenEnabled uint32
const (
_GCoff = iota // GC not running; sweeping in background, write barrier disabled
_GCmark // GC marking roots and workbufs: allocate black, write barrier ENABLED
_GCmarktermination // GC mark termination: allocate black, P's help GC, write barrier ENABLED
)
//go:nosplit
func setGCPhase(x uint32) {
atomic.Store(&gcphase, x)
writeBarrier.needed = gcphase == _GCmark || gcphase == _GCmarktermination
writeBarrier.enabled = writeBarrier.needed
}
// gcMarkWorkerMode represents the mode that a concurrent mark worker
// should operate in.
//
// Concurrent marking happens through four different mechanisms. One
// is mutator assists, which happen in response to allocations and are
// not scheduled. The other three are variations in the per-P mark
// workers and are distinguished by gcMarkWorkerMode.
type gcMarkWorkerMode int
const (
// gcMarkWorkerNotWorker indicates that the next scheduled G is not
// starting work and the mode should be ignored.
gcMarkWorkerNotWorker gcMarkWorkerMode = iota
// gcMarkWorkerDedicatedMode indicates that the P of a mark
// worker is dedicated to running that mark worker. The mark
// worker should run without preemption.
gcMarkWorkerDedicatedMode
// gcMarkWorkerFractionalMode indicates that a P is currently
// running the "fractional" mark worker. The fractional worker
// is necessary when GOMAXPROCS*gcBackgroundUtilization is not
// an integer and using only dedicated workers would result in
// utilization too far from the target of gcBackgroundUtilization.
// The fractional worker should run until it is preempted and
// will be scheduled to pick up the fractional part of
// GOMAXPROCS*gcBackgroundUtilization.
gcMarkWorkerFractionalMode
// gcMarkWorkerIdleMode indicates that a P is running the mark
// worker because it has nothing else to do. The idle worker
// should run until it is preempted and account its time
// against gcController.idleMarkTime.
gcMarkWorkerIdleMode
)
// gcMarkWorkerModeStrings are the string labels of gcMarkWorkerModes
// to use in execution traces.
var gcMarkWorkerModeStrings = [...]string{
"Not worker",
"GC (dedicated)",
"GC (fractional)",
"GC (idle)",
}
// pollFractionalWorkerExit reports whether a fractional mark worker
// should self-preempt. It assumes it is called from the fractional
// worker.
func pollFractionalWorkerExit() bool {
// This should be kept in sync with the fractional worker
// scheduler logic in findRunnableGCWorker.
now := nanotime()
delta := now - gcController.markStartTime
if delta <= 0 {
return true
}
p := getg().m.p.ptr()
selfTime := p.gcFractionalMarkTime + (now - p.gcMarkWorkerStartTime)
// Add some slack to the utilization goal so that the
// fractional worker isn't behind again the instant it exits.
return float64(selfTime)/float64(delta) > 1.2*gcController.fractionalUtilizationGoal
}
var work workType
type workType struct {
full lfstack // lock-free list of full blocks workbuf
empty lfstack // lock-free list of empty blocks workbuf
pad0 cpu.CacheLinePad // prevents false-sharing between full/empty and nproc/nwait
wbufSpans struct {
lock mutex
// free is a list of spans dedicated to workbufs, but
// that don't currently contain any workbufs.
free mSpanList
// busy is a list of all spans containing workbufs on
// one of the workbuf lists.
busy mSpanList
}
// Restore 64-bit alignment on 32-bit.
_ uint32
// bytesMarked is the number of bytes marked this cycle. This
// includes bytes blackened in scanned objects, noscan objects
// that go straight to black, and permagrey objects scanned by
// markroot during the concurrent scan phase. This is updated
// atomically during the cycle. Updates may be batched
// arbitrarily, since the value is only read at the end of the
// cycle.
//
// Because of benign races during marking, this number may not
// be the exact number of marked bytes, but it should be very
// close.
//
// Put this field here because it needs 64-bit atomic access
// (and thus 8-byte alignment even on 32-bit architectures).
bytesMarked uint64
markrootNext uint32 // next markroot job
markrootJobs uint32 // number of markroot jobs
nproc uint32
tstart int64
nwait uint32
// Number of roots of various root types. Set by gcMarkRootPrepare.
//
// nStackRoots == len(stackRoots), but we have nStackRoots for
// consistency.
nDataRoots, nBSSRoots, nSpanRoots, nStackRoots int
// Base indexes of each root type. Set by gcMarkRootPrepare.
baseData, baseBSS, baseSpans, baseStacks, baseEnd uint32
// stackRoots is a snapshot of all of the Gs that existed
// before the beginning of concurrent marking. The backing
// store of this must not be modified because it might be
// shared with allgs.
stackRoots []*g
// Each type of GC state transition is protected by a lock.
// Since multiple threads can simultaneously detect the state
// transition condition, any thread that detects a transition
// condition must acquire the appropriate transition lock,
// re-check the transition condition and return if it no
// longer holds or perform the transition if it does.
// Likewise, any transition must invalidate the transition
// condition before releasing the lock. This ensures that each
// transition is performed by exactly one thread and threads
// that need the transition to happen block until it has
// happened.
//
// startSema protects the transition from "off" to mark or
// mark termination.
startSema uint32
// markDoneSema protects transitions from mark to mark termination.
markDoneSema uint32
// Background mark completion signaling.
bgMarkReady note // signal background mark worker has started
bgMarkDone uint32 // cas to 1 when at a background mark completion point
// mode is the concurrency mode of the current GC cycle.
mode gcMode
// userForced indicates the current GC cycle was forced by an
// explicit user call.
userForced bool
// initialHeapLive is the value of gcController.heapLive at the
// beginning of this GC cycle.
initialHeapLive uint64
// assistQueue is a queue of assists that are blocked because
// there was neither enough credit to steal nor enough work to
// do.
assistQueue struct {
lock mutex
q gQueue
}
// sweepWaiters is a list of blocked goroutines to wake when
// we transition from mark termination to sweep.
sweepWaiters struct {
lock mutex
list gList
}
// cycles is the number of completed GC cycles, where a GC
// cycle is sweep termination, mark, mark termination, and
// sweep. This differs from memstats.numgc, which is
// incremented at mark termination.
cycles atomic.Uint32
// Timing/utilization stats for this cycle.
stwprocs, maxprocs int32
tSweepTerm, tMark, tMarkTerm, tEnd int64 // nanotime() of phase start
pauseNS int64 // total STW time this cycle
pauseStart int64 // nanotime() of last STW
// debug.gctrace heap sizes for this cycle.
heap0, heap1, heap2 uint64
// Cumulative estimated CPU usage.
cpuStats
}
// GC runs a garbage collection and blocks the caller until the
// garbage collection is complete. It may also block the entire
// program.
func GC() {
// We consider a cycle to be: sweep termination, mark, mark
// termination, and sweep. This function shouldn't return
// until a full cycle has been completed, from beginning to
// end. Hence, we always want to finish up the current cycle
// and start a new one. That means:
//
// 1. In sweep termination, mark, or mark termination of cycle
// N, wait until mark termination N completes and transitions
// to sweep N.
//
// 2. In sweep N, help with sweep N.
//
// At this point we can begin a full cycle N+1.
//
// 3. Trigger cycle N+1 by starting sweep termination N+1.
//
// 4. Wait for mark termination N+1 to complete.
//
// 5. Help with sweep N+1 until it's done.
//
// This all has to be written to deal with the fact that the
// GC may move ahead on its own. For example, when we block
// until mark termination N, we may wake up in cycle N+2.
// Wait until the current sweep termination, mark, and mark
// termination complete.
n := work.cycles.Load()
gcWaitOnMark(n)
// We're now in sweep N or later. Trigger GC cycle N+1, which
// will first finish sweep N if necessary and then enter sweep
// termination N+1.
gcStart(gcTrigger{kind: gcTriggerCycle, n: n + 1})
// Wait for mark termination N+1 to complete.
gcWaitOnMark(n + 1)
// Finish sweep N+1 before returning. We do this both to
// complete the cycle and because runtime.GC() is often used
// as part of tests and benchmarks to get the system into a
// relatively stable and isolated state.
for work.cycles.Load() == n+1 && sweepone() != ^uintptr(0) {
sweep.nbgsweep++
Gosched()
}
// Callers may assume that the heap profile reflects the
// just-completed cycle when this returns (historically this
// happened because this was a STW GC), but right now the
// profile still reflects mark termination N, not N+1.
//
// As soon as all of the sweep frees from cycle N+1 are done,
// we can go ahead and publish the heap profile.
//
// First, wait for sweeping to finish. (We know there are no
// more spans on the sweep queue, but we may be concurrently
// sweeping spans, so we have to wait.)
for work.cycles.Load() == n+1 && !isSweepDone() {
Gosched()
}
// Now we're really done with sweeping, so we can publish the
// stable heap profile. Only do this if we haven't already hit
// another mark termination.
mp := acquirem()
cycle := work.cycles.Load()
if cycle == n+1 || (gcphase == _GCmark && cycle == n+2) {
mProf_PostSweep()
}
releasem(mp)
}
// gcWaitOnMark blocks until GC finishes the Nth mark phase. If GC has
// already completed this mark phase, it returns immediately.
func gcWaitOnMark(n uint32) {
for {
// Disable phase transitions.
lock(&work.sweepWaiters.lock)
nMarks := work.cycles.Load()
if gcphase != _GCmark {
// We've already completed this cycle's mark.
nMarks++
}
if nMarks > n {
// We're done.
unlock(&work.sweepWaiters.lock)
return
}
// Wait until sweep termination, mark, and mark
// termination of cycle N complete.
work.sweepWaiters.list.push(getg())
goparkunlock(&work.sweepWaiters.lock, waitReasonWaitForGCCycle, traceEvGoBlock, 1)
}
}
// gcMode indicates how concurrent a GC cycle should be.
type gcMode int
const (
gcBackgroundMode gcMode = iota // concurrent GC and sweep
gcForceMode // stop-the-world GC now, concurrent sweep
gcForceBlockMode // stop-the-world GC now and STW sweep (forced by user)
)
// A gcTrigger is a predicate for starting a GC cycle. Specifically,
// it is an exit condition for the _GCoff phase.
type gcTrigger struct {
kind gcTriggerKind
now int64 // gcTriggerTime: current time
n uint32 // gcTriggerCycle: cycle number to start
}
type gcTriggerKind int
const (
// gcTriggerHeap indicates that a cycle should be started when
// the heap size reaches the trigger heap size computed by the
// controller.
gcTriggerHeap gcTriggerKind = iota
// gcTriggerTime indicates that a cycle should be started when
// it's been more than forcegcperiod nanoseconds since the
// previous GC cycle.
gcTriggerTime
// gcTriggerCycle indicates that a cycle should be started if
// we have not yet started cycle number gcTrigger.n (relative
// to work.cycles).
gcTriggerCycle
)
// test reports whether the trigger condition is satisfied, meaning
// that the exit condition for the _GCoff phase has been met. The exit
// condition should be tested when allocating.
func (t gcTrigger) test() bool {
if !memstats.enablegc || panicking.Load() != 0 || gcphase != _GCoff {
return false
}
switch t.kind {
case gcTriggerHeap:
// Non-atomic access to gcController.heapLive for performance. If
// we are going to trigger on this, this thread just
// atomically wrote gcController.heapLive anyway and we'll see our
// own write.
trigger, _ := gcController.trigger()
return gcController.heapLive.Load() >= trigger
case gcTriggerTime:
if gcController.gcPercent.Load() < 0 {
return false
}
lastgc := int64(atomic.Load64(&memstats.last_gc_nanotime))
return lastgc != 0 && t.now-lastgc > forcegcperiod
case gcTriggerCycle:
// t.n > work.cycles, but accounting for wraparound.
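// Wraparound sketch (illustrative numbers): if work.cycles has wrapped
// to ^uint32(0) (4294967295) and t.n is 2, then t.n-work.cycles.Load()
// is 3 and int32(3) > 0, so the trigger still fires across the wrap.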
return int32(t.n-work.cycles.Load()) > 0
}
return true
}
// gcStart starts the GC. It transitions from _GCoff to _GCmark (if
// debug.gcstoptheworld == 0) or performs all of GC (if
// debug.gcstoptheworld != 0).
//
// This may return without performing this transition in some cases,
// such as when called on a system stack or with locks held.
func gcStart(trigger gcTrigger) {
// Since this is called from malloc and malloc is called in
// the guts of a number of libraries that might be holding
// locks, don't attempt to start GC in non-preemptible or
// potentially unstable situations.
mp := acquirem()
if gp := getg(); gp == mp.g0 || mp.locks > 1 || mp.preemptoff != "" {
releasem(mp)
return
}
releasem(mp)
mp = nil
// Pick up the remaining unswept/not being swept spans concurrently
//
// This shouldn't happen if we're being invoked in background
// mode since proportional sweep should have just finished
// sweeping everything, but rounding errors, etc, may leave a
// few spans unswept. In forced mode, this is necessary since
// GC can be forced at any point in the sweeping cycle.
//
// We check the transition condition continuously here in case
// this G gets delayed into the next GC cycle.
for trigger.test() && sweepone() != ^uintptr(0) {
sweep.nbgsweep++
}
// Perform GC initialization and the sweep termination
// transition.
semacquire(&work.startSema)
// Re-check transition condition under transition lock.
if !trigger.test() {
semrelease(&work.startSema)
return
}
// In gcstoptheworld debug mode, upgrade the mode accordingly.
// We do this after re-checking the transition condition so
// that multiple goroutines that detect the heap trigger don't
// start multiple STW GCs.
mode := gcBackgroundMode
if debug.gcstoptheworld == 1 {
mode = gcForceMode
} else if debug.gcstoptheworld == 2 {
mode = gcForceBlockMode
}
// Ok, we're doing it! Stop everybody else
semacquire(&gcsema)
semacquire(&worldsema)
// For stats, check if this GC was forced by the user.
// Update it under gcsema to avoid gctrace getting wrong values.
work.userForced = trigger.kind == gcTriggerCycle
if trace.enabled {
traceGCStart()
}
// Check that all Ps have finished deferred mcache flushes.
for _, p := range allp {
if fg := p.mcache.flushGen.Load(); fg != mheap_.sweepgen {
println("runtime: p", p.id, "flushGen", fg, "!= sweepgen", mheap_.sweepgen)
throw("p mcache not flushed")
}
}
gcBgMarkStartWorkers()
systemstack(gcResetMarkState)
work.stwprocs, work.maxprocs = gomaxprocs, gomaxprocs
if work.stwprocs > ncpu {
// This is used to compute CPU time of the STW phases,
// so it can't be more than ncpu, even if GOMAXPROCS is.
work.stwprocs = ncpu
}
work.heap0 = gcController.heapLive.Load()
work.pauseNS = 0
work.mode = mode
now := nanotime()
work.tSweepTerm = now
work.pauseStart = now
if trace.enabled {
traceGCSTWStart(1)
}
systemstack(stopTheWorldWithSema)
// Finish sweep before we start concurrent scan.
systemstack(func() {
finishsweep_m()
})
// clearpools before we start the GC. If we wait, the memory will not be
// reclaimed until the next GC cycle.
clearpools()
work.cycles.Add(1)
// Assists and workers can start the moment we start
// the world.
gcController.startCycle(now, int(gomaxprocs), trigger)
// Notify the CPU limiter that assists may begin.
gcCPULimiter.startGCTransition(true, now)
// In STW mode, disable scheduling of user Gs. This may also
// disable scheduling of this goroutine, so it may block as
// soon as we start the world again.
if mode != gcBackgroundMode {
schedEnableUser(false)
}
// Enter concurrent mark phase and enable
// write barriers.
//
// Because the world is stopped, all Ps will
// observe that write barriers are enabled by
// the time we start the world and begin
// scanning.
//
// Write barriers must be enabled before assists are
// enabled because they must be enabled before
// any non-leaf heap objects are marked. Since
// allocations are blocked until assists can
// happen, we want to enable assists as early as
// possible.
setGCPhase(_GCmark)
gcBgMarkPrepare() // Must happen before assist enable.
gcMarkRootPrepare()
// Mark all active tinyalloc blocks. Since we're
// allocating from these, they need to be black like
// other allocations. The alternative is to blacken
// the tiny block on every allocation from it, which
// would slow down the tiny allocator.
gcMarkTinyAllocs()
// At this point all Ps have enabled the write
// barrier, thus maintaining the no white to
// black invariant. Enable mutator assists to
// put back-pressure on fast allocating
// mutators.
atomic.Store(&gcBlackenEnabled, 1)
// In STW mode, we could block the instant systemstack
// returns, so make sure we're not preemptible.
mp = acquirem()
// Concurrent mark.
systemstack(func() {
now = startTheWorldWithSema(trace.enabled)
work.pauseNS += now - work.pauseStart
work.tMark = now
memstats.gcPauseDist.record(now - work.pauseStart)
// Release the CPU limiter.
gcCPULimiter.finishGCTransition(now)
})
// Release the world sema before Gosched() in STW mode
// because we will need to reacquire it later but before
// this goroutine becomes runnable again, and we could
// self-deadlock otherwise.
semrelease(&worldsema)
releasem(mp)
// Make sure we block instead of returning to user code
// in STW mode.
if mode != gcBackgroundMode {
Gosched()
}
semrelease(&work.startSema)
}
// gcMarkDoneFlushed counts the number of P's with flushed work.
//
// Ideally this would be a captured local in gcMarkDone, but forEachP
// escapes its callback closure, so it can't capture anything.
//
// This is protected by markDoneSema.
var gcMarkDoneFlushed uint32
// gcMarkDone transitions the GC from mark to mark termination if all
// reachable objects have been marked (that is, there are no grey
// objects and there can be no more in the future). Otherwise, it flushes
// all local work to the global queues where it can be discovered by
// other workers.
//
// This should be called when all local mark work has been drained and
// there are no remaining workers. Specifically, when
//
// work.nwait == work.nproc && !gcMarkWorkAvailable(p)
//
// The calling context must be preemptible.
//
// Flushing local work is important because idle Ps may have local
// work queued. This is the only way to make that work visible and
// drive GC to completion.
//
// It is explicitly okay to have write barriers in this function. If
// it does transition to mark termination, then all reachable objects
// have been marked, so the write barrier cannot shade any more
// objects.
func gcMarkDone() {
// Ensure only one thread is running the ragged barrier at a
// time.
semacquire(&work.markDoneSema)
top:
// Re-check transition condition under transition lock.
//
// It's critical that this checks the global work queues are
// empty before performing the ragged barrier. Otherwise,
// there could be global work that a P could take after the P
// has passed the ragged barrier.
if !(gcphase == _GCmark && work.nwait == work.nproc && !gcMarkWorkAvailable(nil)) {
semrelease(&work.markDoneSema)
return
}
// forEachP needs worldsema to execute, and we'll need it to
// stop the world later, so acquire worldsema now.
semacquire(&worldsema)
// Flush all local buffers and collect flushedWork flags.
gcMarkDoneFlushed = 0
systemstack(func() {
gp := getg().m.curg
// Mark the user stack as preemptible so that it may be scanned.
// Otherwise, our attempt to force all P's to a safepoint could
// result in a deadlock as we attempt to preempt a worker that's
// trying to preempt us (e.g. for a stack scan).
casGToWaiting(gp, _Grunning, waitReasonGCMarkTermination)
forEachP(func(pp *p) {
// Flush the write barrier buffer, since this may add
// work to the gcWork.
wbBufFlush1(pp)
// Flush the gcWork, since this may create global work
// and set the flushedWork flag.
//
// TODO(austin): Break up these workbufs to
// better distribute work.
pp.gcw.dispose()
// Collect the flushedWork flag.
if pp.gcw.flushedWork {
atomic.Xadd(&gcMarkDoneFlushed, 1)
pp.gcw.flushedWork = false
}
})
casgstatus(gp, _Gwaiting, _Grunning)
})
if gcMarkDoneFlushed != 0 {
// More grey objects were discovered since the
// previous termination check, so there may be more
// work to do. Keep going. It's possible the
// transition condition became true again during the
// ragged barrier, so re-check it.
semrelease(&worldsema)
goto top
}
// There was no global work, no local work, and no Ps
// communicated work since we took markDoneSema. Therefore
// there are no grey objects and no more objects can be
// shaded. Transition to mark termination.
now := nanotime()
work.tMarkTerm = now
work.pauseStart = now
getg().m.preemptoff = "gcing"
if trace.enabled {
traceGCSTWStart(0)
}
systemstack(stopTheWorldWithSema)
// The gcphase is _GCmark; it will transition to _GCmarktermination
// below. The important thing is that the write barrier remains active until
// all marking is complete. This includes writes made by the GC.
// There is sometimes work left over when we enter mark termination due
// to write barriers performed after the completion barrier above.
// Detect this and resume concurrent mark. This is obviously
// unfortunate.
//
// See issue #27993 for details.
//
// Switch to the system stack to call wbBufFlush1, though in this case
// it doesn't matter because we're non-preemptible anyway.
restart := false
systemstack(func() {
for _, p := range allp {
wbBufFlush1(p)
if !p.gcw.empty() {
restart = true
break
}
}
})
if restart {
getg().m.preemptoff = ""
systemstack(func() {
now := startTheWorldWithSema(trace.enabled)
work.pauseNS += now - work.pauseStart
memstats.gcPauseDist.record(now - work.pauseStart)
})
semrelease(&worldsema)
goto top
}
gcComputeStartingStackSize()
// Disable assists and background workers. We must do
// this before waking blocked assists.
atomic.Store(&gcBlackenEnabled, 0)
// Notify the CPU limiter that GC assists will now cease.
gcCPULimiter.startGCTransition(false, now)
// Wake all blocked assists. These will run when we
// start the world again.
gcWakeAllAssists()
// Likewise, release the transition lock. Blocked
// workers and assists will run when we start the
// world again.
semrelease(&work.markDoneSema)
// In STW mode, re-enable user goroutines. These will be
// queued to run after we start the world.
schedEnableUser(true)
// endCycle depends on all gcWork cache stats being flushed.
// The termination algorithm above ensured that they are, modulo
// allocations made since the ragged barrier.
gcController.endCycle(now, int(gomaxprocs), work.userForced)
// Perform mark termination. This will restart the world.
gcMarkTermination()
}
// World must be stopped and mark assists and background workers must be
// disabled.
func gcMarkTermination() {
// Start marktermination (write barrier remains enabled for now).
setGCPhase(_GCmarktermination)
work.heap1 = gcController.heapLive.Load()
startTime := nanotime()
mp := acquirem()
mp.preemptoff = "gcing"
mp.traceback = 2
curgp := mp.curg
casGToWaiting(curgp, _Grunning, waitReasonGarbageCollection)
// Run gc on the g0 stack. We do this so that the g stack
// we're currently running on will no longer change. Cuts
// the root set down a bit (g0 stacks are not scanned, and
// we don't need to scan gc's internal state). We also
// need to switch to g0 so we can shrink the stack.
systemstack(func() {
gcMark(startTime)
// Must return immediately.
// The outer function's stack may have moved
// during gcMark (it shrinks stacks, including the
// outer function's stack), so we must not refer
// to any of its variables. Return back to the
// non-system stack to pick up the new addresses
// before continuing.
})
systemstack(func() {
work.heap2 = work.bytesMarked
if debug.gccheckmark > 0 {
// Run a full non-parallel, stop-the-world
// mark using checkmark bits, to check that we
// didn't forget to mark anything during the
// concurrent mark process.
startCheckmarks()
gcResetMarkState()
gcw := &getg().m.p.ptr().gcw
gcDrain(gcw, 0)
wbBufFlush1(getg().m.p.ptr())
gcw.dispose()
endCheckmarks()
}
// marking is complete so we can turn the write barrier off
setGCPhase(_GCoff)
gcSweep(work.mode)
})
mp.traceback = 0
casgstatus(curgp, _Gwaiting, _Grunning)
if trace.enabled {
traceGCDone()
}
// all done
mp.preemptoff = ""
if gcphase != _GCoff {
throw("gc done but gcphase != _GCoff")
}
// Record heapInUse for scavenger.
memstats.lastHeapInUse = gcController.heapInUse.load()
// Update GC trigger and pacing, as well as downstream consumers
// of this pacing information, for the next cycle.
systemstack(gcControllerCommit)
// Update timing memstats
now := nanotime()
sec, nsec, _ := time_now()
unixNow := sec*1e9 + int64(nsec)
work.pauseNS += now - work.pauseStart
work.tEnd = now
memstats.gcPauseDist.record(now - work.pauseStart)
atomic.Store64(&memstats.last_gc_unix, uint64(unixNow)) // must be Unix time to make sense to user
atomic.Store64(&memstats.last_gc_nanotime, uint64(now)) // monotonic time for us
memstats.pause_ns[memstats.numgc%uint32(len(memstats.pause_ns))] = uint64(work.pauseNS)
memstats.pause_end[memstats.numgc%uint32(len(memstats.pause_end))] = uint64(unixNow)
memstats.pause_total_ns += uint64(work.pauseNS)
sweepTermCpu := int64(work.stwprocs) * (work.tMark - work.tSweepTerm)
// We report idle marking time below, but omit it from the
// overall utilization here since it's "free".
markAssistCpu := gcController.assistTime.Load()
markDedicatedCpu := gcController.dedicatedMarkTime.Load()
markFractionalCpu := gcController.fractionalMarkTime.Load()
markIdleCpu := gcController.idleMarkTime.Load()
markTermCpu := int64(work.stwprocs) * (work.tEnd - work.tMarkTerm)
scavAssistCpu := scavenge.assistTime.Load()
scavBgCpu := scavenge.backgroundTime.Load()
// Update cumulative GC CPU stats.
work.cpuStats.gcAssistTime += markAssistCpu
work.cpuStats.gcDedicatedTime += markDedicatedCpu + markFractionalCpu
work.cpuStats.gcIdleTime += markIdleCpu
work.cpuStats.gcPauseTime += sweepTermCpu + markTermCpu
work.cpuStats.gcTotalTime += sweepTermCpu + markAssistCpu + markDedicatedCpu + markFractionalCpu + markIdleCpu + markTermCpu
// Update cumulative scavenge CPU stats.
work.cpuStats.scavengeAssistTime += scavAssistCpu
work.cpuStats.scavengeBgTime += scavBgCpu
work.cpuStats.scavengeTotalTime += scavAssistCpu + scavBgCpu
// Update total CPU.
work.cpuStats.totalTime = sched.totaltime + (now-sched.procresizetime)*int64(gomaxprocs)
work.cpuStats.idleTime += sched.idleTime.Load()
// Compute userTime. We compute this indirectly as everything that's not the above.
//
// Since time spent in _Pgcstop is covered by gcPauseTime, and time spent in _Pidle
// is covered by idleTime, what we're left with is time spent in _Prunning and _Psyscall,
// the latter of which is fine because the P will either go idle or get used for something
// else via sysmon. Meanwhile if we subtract GC time from whatever's left, we get non-GC
// _Prunning time. Note that this still leaves time spent in sweeping and in the scheduler,
// but that's fine. The overwhelming majority of this time will be actual user time.
work.cpuStats.userTime = work.cpuStats.totalTime - (work.cpuStats.gcTotalTime +
work.cpuStats.scavengeTotalTime + work.cpuStats.idleTime)
// Compute overall GC CPU utilization.
// Omit idle marking time from the overall utilization here since it's "free".
memstats.gc_cpu_fraction = float64(work.cpuStats.gcTotalTime-work.cpuStats.gcIdleTime) / float64(work.cpuStats.totalTime)
// Reset assist time and background time stats.
//
// Do this now, instead of at the start of the next GC cycle, because
// these two may keep accumulating even if the GC is not active.
scavenge.assistTime.Store(0)
scavenge.backgroundTime.Store(0)
// Reset idle time stat.
sched.idleTime.Store(0)
// Reset sweep state.
sweep.nbgsweep = 0
sweep.npausesweep = 0
if work.userForced {
memstats.numforcedgc++
}
// Bump GC cycle count and wake goroutines waiting on sweep.
lock(&work.sweepWaiters.lock)
memstats.numgc++
injectglist(&work.sweepWaiters.list)
unlock(&work.sweepWaiters.lock)
// Release the CPU limiter.
gcCPULimiter.finishGCTransition(now)
// Finish the current heap profiling cycle and start a new
// heap profiling cycle. We do this before starting the world
// so events don't leak into the wrong cycle.
mProf_NextCycle()
// There may be stale spans in mcaches that need to be swept.
// Those aren't tracked in any sweep lists, so we need to
// count them against sweep completion until we ensure all
// those spans have been forced out.
sl := sweep.active.begin()
if !sl.valid {
throw("failed to set sweep barrier")
}
systemstack(func() { startTheWorldWithSema(trace.enabled) })
// Flush the heap profile so we can start a new cycle next GC.
// This is relatively expensive, so we don't do it with the
// world stopped.
mProf_Flush()
// Prepare workbufs for freeing by the sweeper. We do this
// asynchronously because it can take non-trivial time.
prepareFreeWorkbufs()
// Free stack spans. This must be done between GC cycles.
systemstack(freeStackSpans)
// Ensure all mcaches are flushed. Each P will flush its own
// mcache before allocating, but idle Ps may not. Since this
// is necessary to sweep all spans, we need to ensure all
// mcaches are flushed before we start the next GC cycle.
systemstack(func() {
forEachP(func(pp *p) {
pp.mcache.prepareForSweep()
})
})
// Now that we've swept stale spans in mcaches, they don't
// count against unswept spans.
sweep.active.end(sl)
// Print gctrace before dropping worldsema. As soon as we drop
// worldsema another cycle could start and smash the stats
// we're trying to print.
if debug.gctrace > 0 {
util := int(memstats.gc_cpu_fraction * 100)
var sbuf [24]byte
printlock()
print("gc ", memstats.numgc,
" @", string(itoaDiv(sbuf[:], uint64(work.tSweepTerm-runtimeInitTime)/1e6, 3)), "s ",
util, "%: ")
prev := work.tSweepTerm
for i, ns := range []int64{work.tMark, work.tMarkTerm, work.tEnd} {
if i != 0 {
print("+")
}
print(string(fmtNSAsMS(sbuf[:], uint64(ns-prev))))
prev = ns
}
print(" ms clock, ")
for i, ns := range []int64{
sweepTermCpu,
gcController.assistTime.Load(),
gcController.dedicatedMarkTime.Load() + gcController.fractionalMarkTime.Load(),
gcController.idleMarkTime.Load(),
markTermCpu,
} {
if i == 2 || i == 3 {
// Separate mark time components with /.
print("/")
} else if i != 0 {
print("+")
}
print(string(fmtNSAsMS(sbuf[:], uint64(ns))))
}
print(" ms cpu, ",
work.heap0>>20, "->", work.heap1>>20, "->", work.heap2>>20, " MB, ",
gcController.lastHeapGoal>>20, " MB goal, ",
gcController.lastStackScan.Load()>>20, " MB stacks, ",
gcController.globalsScan.Load()>>20, " MB globals, ",
work.maxprocs, " P")
if work.userForced {
print(" (forced)")
}
print("\n")
printunlock()
}
// Set any arena chunks that were deferred to fault.
lock(&userArenaState.lock)
faultList := userArenaState.fault
userArenaState.fault = nil
unlock(&userArenaState.lock)
for _, lc := range faultList {
lc.mspan.setUserArenaChunkToFault()
}
semrelease(&worldsema)
semrelease(&gcsema)
// Careful: another GC cycle may start now.
releasem(mp)
mp = nil
// now that gc is done, kick off finalizer thread if needed
if !concurrentSweep {
// give the queued finalizers, if any, a chance to run
Gosched()
}
}
// gcBgMarkStartWorkers prepares background mark worker goroutines. These
// goroutines will not run until the mark phase, but they must be started while
// the world is not stopped and from a regular G stack. The caller must hold
// worldsema.
func gcBgMarkStartWorkers() {
// Background marking is performed by per-P G's. Ensure that each P has
// a background GC G.
//
// Worker Gs don't exit if gomaxprocs is reduced. If it is raised
// again, we can reuse the old workers; no need to create new workers.
for gcBgMarkWorkerCount < gomaxprocs {
go gcBgMarkWorker()
notetsleepg(&work.bgMarkReady, -1)
noteclear(&work.bgMarkReady)
// The worker is now guaranteed to be added to the pool before
// its P's next findRunnableGCWorker.
gcBgMarkWorkerCount++
}
}
// gcBgMarkPrepare sets up state for background marking.
// Mutator assists must not yet be enabled.
func gcBgMarkPrepare() {
// Background marking will stop when the work queues are empty
// and there are no more workers (note that, since this is
// concurrent, this may be a transient state, but mark
// termination will clean it up). Between background workers
// and assists, we don't really know how many workers there
// will be, so we pretend to have an arbitrarily large number
// of workers, almost all of which are "waiting". While a
// worker is working it decrements nwait. If nproc == nwait,
// there are no workers.
work.nproc = ^uint32(0)
work.nwait = ^uint32(0)
}
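// The ^uint32(0) trick above means completion detection never needs the
// true number of workers. A sketch (simplified from gcBgMarkWorker below)
// of the protocol each worker follows around a chunk of mark work:
//
//	atomic.Xadd(&work.nwait, -1) // claim a slot; now "working"
//	// ... drain mark work ...
//	incnwait := atomic.Xadd(&work.nwait, +1) // release the slot
//	if incnwait == work.nproc && !gcMarkWorkAvailable(nil) {
//		gcMarkDone() // last active worker and no work left
//	}
//
// Assists follow the same protocol, so nproc == nwait reliably means "no
// active workers" without ever counting them.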
// gcBgMarkWorkerNode is an entry in the gcBgMarkWorkerPool. It points to a single
// gcBgMarkWorker goroutine.
type gcBgMarkWorkerNode struct {
// Unused workers are managed in a lock-free stack. This field must be first.
node lfnode
// The g of this worker.
gp guintptr
// Release this m on park. This is used to communicate with the unlock
// function, which cannot access the G's stack. It is unused outside of
// gcBgMarkWorker().
m muintptr
}
func gcBgMarkWorker() {
gp := getg()
// We pass node to a gopark unlock function, so it can't be on
// the stack (see gopark). Prevent deadlock from recursively
// starting GC by disabling preemption.
gp.m.preemptoff = "GC worker init"
node := new(gcBgMarkWorkerNode)
gp.m.preemptoff = ""
node.gp.set(gp)
node.m.set(acquirem())
notewakeup(&work.bgMarkReady)
// After this point, the background mark worker is generally scheduled
// cooperatively by gcController.findRunnableGCWorker. While performing
// work on the P, preemption is disabled because we are working on
// P-local work buffers. When the preempt flag is set, this puts itself
// into _Gwaiting to be woken up by gcController.findRunnableGCWorker
// at the appropriate time.
//
// When preemption is enabled (e.g., while in gcMarkDone), this worker
// may be preempted and schedule as a _Grunnable G from a runq. That is
// fine; it will eventually gopark again for further scheduling via
// findRunnableGCWorker.
//
// Since we disable preemption before notifying bgMarkReady, we
// guarantee that this G will be in the worker pool for the next
// findRunnableGCWorker. This isn't strictly necessary, but it reduces
// latency between _GCmark starting and the workers starting.
for {
// Go to sleep until woken by
// gcController.findRunnableGCWorker.
gopark(func(g *g, nodep unsafe.Pointer) bool {
node := (*gcBgMarkWorkerNode)(nodep)
if mp := node.m.ptr(); mp != nil {
// The worker G is no longer running; release
// the M.
//
// N.B. it is _safe_ to release the M as soon
// as we are no longer performing P-local mark
// work.
//
// However, since we cooperatively stop work
// when gp.preempt is set, if we releasem in
// the loop then the following call to gopark
// would immediately preempt the G. This is
// also safe, but inefficient: the G must
// schedule again only to enter gopark and park
// again. Thus, we defer the release until
// after parking the G.
releasem(mp)
}
// Release this G to the pool.
gcBgMarkWorkerPool.push(&node.node)
// Note that at this point, the G may immediately be
// rescheduled and may be running.
return true
}, unsafe.Pointer(node), waitReasonGCWorkerIdle, traceEvGoBlock, 0)
// Preemption must not occur here, or another G might see
// p.gcMarkWorkerMode.
// Disable preemption so we can use the gcw. If the
// scheduler wants to preempt us, we'll stop draining,
// dispose the gcw, and then preempt.
node.m.set(acquirem())
pp := gp.m.p.ptr() // P can't change with preemption disabled.
if gcBlackenEnabled == 0 {
println("worker mode", pp.gcMarkWorkerMode)
throw("gcBgMarkWorker: blackening not enabled")
}
if pp.gcMarkWorkerMode == gcMarkWorkerNotWorker {
throw("gcBgMarkWorker: mode not set")
}
startTime := nanotime()
pp.gcMarkWorkerStartTime = startTime
var trackLimiterEvent bool
if pp.gcMarkWorkerMode == gcMarkWorkerIdleMode {
trackLimiterEvent = pp.limiterEvent.start(limiterEventIdleMarkWork, startTime)
}
decnwait := atomic.Xadd(&work.nwait, -1)
if decnwait == work.nproc {
println("runtime: work.nwait=", decnwait, "work.nproc=", work.nproc)
throw("work.nwait was > work.nproc")
}
systemstack(func() {
// Mark our goroutine preemptible so its stack
// can be scanned. This lets two mark workers
// scan each other (otherwise, they would
// deadlock). We must not modify anything on
// the G stack. However, stack shrinking is
// disabled for mark workers, so it is safe to
// read from the G stack.
casGToWaiting(gp, _Grunning, waitReasonGCWorkerActive)
switch pp.gcMarkWorkerMode {
default:
throw("gcBgMarkWorker: unexpected gcMarkWorkerMode")
case gcMarkWorkerDedicatedMode:
gcDrain(&pp.gcw, gcDrainUntilPreempt|gcDrainFlushBgCredit)
if gp.preempt {
// We were preempted. This is
// a useful signal to kick
// everything out of the run
// queue so it can run
// somewhere else.
if drainQ, n := runqdrain(pp); n > 0 {
lock(&sched.lock)
globrunqputbatch(&drainQ, int32(n))
unlock(&sched.lock)
}
}
// Go back to draining, this time
// without preemption.
gcDrain(&pp.gcw, gcDrainFlushBgCredit)
case gcMarkWorkerFractionalMode:
gcDrain(&pp.gcw, gcDrainFractional|gcDrainUntilPreempt|gcDrainFlushBgCredit)
case gcMarkWorkerIdleMode:
gcDrain(&pp.gcw, gcDrainIdle|gcDrainUntilPreempt|gcDrainFlushBgCredit)
}
casgstatus(gp, _Gwaiting, _Grunning)
})
// Account for time and mark us as stopped.
now := nanotime()
duration := now - startTime
gcController.markWorkerStop(pp.gcMarkWorkerMode, duration)
if trackLimiterEvent {
pp.limiterEvent.stop(limiterEventIdleMarkWork, now)
}
if pp.gcMarkWorkerMode == gcMarkWorkerFractionalMode {
atomic.Xaddint64(&pp.gcFractionalMarkTime, duration)
}
// Was this the last worker and did we run out
// of work?
incnwait := atomic.Xadd(&work.nwait, +1)
if incnwait > work.nproc {
println("runtime: p.gcMarkWorkerMode=", pp.gcMarkWorkerMode,
"work.nwait=", incnwait, "work.nproc=", work.nproc)
throw("work.nwait > work.nproc")
}
// We'll releasem after this point and thus this P may run
// something else. We must clear the worker mode to avoid
// attributing the mode to a different (non-worker) G in
// traceGoStart.
pp.gcMarkWorkerMode = gcMarkWorkerNotWorker
// If this worker reached a background mark completion
// point, signal the main GC goroutine.
if incnwait == work.nproc && !gcMarkWorkAvailable(nil) {
// We don't need the P-local buffers here; allow
// preemption because we may schedule like a regular
// goroutine in gcMarkDone (block on locks, etc).
releasem(node.m.ptr())
node.m.set(nil)
gcMarkDone()
}
}
}
// gcMarkWorkAvailable reports whether executing a mark worker
// on p is potentially useful. p may be nil, in which case it only
// checks the global sources of work.
func gcMarkWorkAvailable(p *p) bool {
if p != nil && !p.gcw.empty() {
return true
}
if !work.full.empty() {
return true // global work available
}
if work.markrootNext < work.markrootJobs {
return true // root scan work available
}
return false
}
// gcMark runs the mark (or, for concurrent GC, mark termination).
// All gcWork caches must be empty.
// STW is in effect at this point.
func gcMark(startTime int64) {
if debug.allocfreetrace > 0 {
tracegc()
}
if gcphase != _GCmarktermination {
throw("in gcMark expecting to see gcphase as _GCmarktermination")
}
work.tstart = startTime
// Check that there's no marking work remaining.
if work.full != 0 || work.markrootNext < work.markrootJobs {
print("runtime: full=", hex(work.full), " next=", work.markrootNext, " jobs=", work.markrootJobs, " nDataRoots=", work.nDataRoots, " nBSSRoots=", work.nBSSRoots, " nSpanRoots=", work.nSpanRoots, " nStackRoots=", work.nStackRoots, "\n")
panic("non-empty mark queue after concurrent mark")
}
if debug.gccheckmark > 0 {
// This is expensive when there's a large number of
// Gs, so only do it if checkmark is also enabled.
gcMarkRootCheck()
}
// Drop allg snapshot. allgs may have grown, in which case
// this is the only reference to the old backing store and
// there's no need to keep it around.
work.stackRoots = nil
// Clear out buffers and double-check that all gcWork caches
// are empty. This should be ensured by gcMarkDone before we
// enter mark termination.
//
// TODO: We could clear out buffers just before mark if this
// has a non-negligible impact on STW time.
for _, p := range allp {
// The write barrier may have buffered pointers since
// the gcMarkDone barrier. However, since the barrier
// ensured all reachable objects were marked, all of
// these must be pointers to black objects. Hence we
// can just discard the write barrier buffer.
if debug.gccheckmark > 0 {
// For debugging, flush the buffer and make
// sure it really was all marked.
wbBufFlush1(p)
} else {
p.wbBuf.reset()
}
gcw := &p.gcw
if !gcw.empty() {
printlock()
print("runtime: P ", p.id, " flushedWork ", gcw.flushedWork)
if gcw.wbuf1 == nil {
print(" wbuf1=<nil>")
} else {
print(" wbuf1.n=", gcw.wbuf1.nobj)
}
if gcw.wbuf2 == nil {
print(" wbuf2=<nil>")
} else {
print(" wbuf2.n=", gcw.wbuf2.nobj)
}
print("\n")
throw("P has cached GC work at end of mark termination")
}
// There may still be cached empty buffers, which we
// need to flush since we're going to free them. Also,
// there may be non-zero stats because we allocated
// black after the gcMarkDone barrier.
gcw.dispose()
}
// Flush scanAlloc from each mcache since we're about to modify
// heapScan directly. If we were to flush this later, then scanAlloc
// might have incorrect information.
//
// Note that it's not important to retain this information; we know
// exactly what heapScan is at this point via scanWork.
for _, p := range allp {
c := p.mcache
if c == nil {
continue
}
c.scanAlloc = 0
}
// Reset controller state.
gcController.resetLive(work.bytesMarked)
}
// gcSweep must be called on the system stack because it acquires the heap
// lock. See mheap for details.
//
// The world must be stopped.
//
//go:systemstack
func gcSweep(mode gcMode) {
assertWorldStopped()
if gcphase != _GCoff {
throw("gcSweep being done but phase is not GCoff")
}
lock(&mheap_.lock)
mheap_.sweepgen += 2
sweep.active.reset()
mheap_.pagesSwept.Store(0)
mheap_.sweepArenas = mheap_.allArenas
mheap_.reclaimIndex.Store(0)
mheap_.reclaimCredit.Store(0)
unlock(&mheap_.lock)
sweep.centralIndex.clear()
if !_ConcurrentSweep || mode == gcForceBlockMode {
// Special case synchronous sweep.
// Record that no proportional sweeping has to happen.
lock(&mheap_.lock)
mheap_.sweepPagesPerByte = 0
unlock(&mheap_.lock)
// Sweep all spans eagerly.
for sweepone() != ^uintptr(0) {
sweep.npausesweep++
}
// Free workbufs eagerly.
prepareFreeWorkbufs()
for freeSomeWbufs(false) {
}
// All "free" events for this mark/sweep cycle have
// now happened, so we can make this profile cycle
// available immediately.
mProf_NextCycle()
mProf_Flush()
return
}
// Background sweep.
lock(&sweep.lock)
if sweep.parked {
sweep.parked = false
ready(sweep.g, 0, true)
}
unlock(&sweep.lock)
}
// gcResetMarkState resets global state prior to marking (concurrent
// or STW) and resets the stack scan state of all Gs.
//
// This is safe to do without the world stopped because any Gs created
// during or after this will start out in the reset state.
//
// gcResetMarkState must be called on the system stack because it acquires
// the heap lock. See mheap for details.
//
//go:systemstack
func gcResetMarkState() {
// This may be called during a concurrent phase, so lock to make sure
// allgs doesn't change.
forEachG(func(gp *g) {
gp.gcscandone = false // set to true in gcphasework
gp.gcAssistBytes = 0
})
// Clear page marks. This is just 1MB per 64GB of heap, so the
// time here is pretty trivial.
lock(&mheap_.lock)
arenas := mheap_.allArenas
unlock(&mheap_.lock)
for _, ai := range arenas {
ha := mheap_.arenas[ai.l1()][ai.l2()]
for i := range ha.pageMarks {
ha.pageMarks[i] = 0
}
}
work.bytesMarked = 0
work.initialHeapLive = gcController.heapLive.Load()
}
// Hooks for other packages
var poolcleanup func()
var boringCaches []unsafe.Pointer // for crypto/internal/boring
//go:linkname sync_runtime_registerPoolCleanup sync.runtime_registerPoolCleanup
func sync_runtime_registerPoolCleanup(f func()) {
poolcleanup = f
}
//go:linkname boring_registerCache crypto/internal/boring/bcache.registerCache
func boring_registerCache(p unsafe.Pointer) {
boringCaches = append(boringCaches, p)
}
func clearpools() {
// clear sync.Pools
if poolcleanup != nil {
poolcleanup()
}
// clear boringcrypto caches
for _, p := range boringCaches {
atomicstorep(p, nil)
}
// Clear central sudog cache.
// Leave per-P caches alone, they have strictly bounded size.
// Disconnect cached list before dropping it on the floor,
// so that a dangling ref to one entry does not pin all of them.
lock(&sched.sudoglock)
var sg, sgnext *sudog
for sg = sched.sudogcache; sg != nil; sg = sgnext {
sgnext = sg.next
sg.next = nil
}
sched.sudogcache = nil
unlock(&sched.sudoglock)
// Clear central defer pool.
// Leave per-P pools alone, they have strictly bounded size.
lock(&sched.deferlock)
// disconnect cached list before dropping it on the floor,
// so that a dangling ref to one entry does not pin all of them.
var d, dlink *_defer
for d = sched.deferpool; d != nil; d = dlink {
dlink = d.link
d.link = nil
}
sched.deferpool = nil
unlock(&sched.deferlock)
}
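// The unlinking pattern above matters for GC: if the next/link pointers
// were left intact, a single stale reference to one cached node would keep
// the entire chain reachable. A minimal sketch of the pattern, with a
// hypothetical node type:
//
//	var next *node
//	for n := head; n != nil; n = next {
//		next = n.next
//		n.next = nil // unlink so each node is independently collectable
//	}
//	head = nil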
// Timing
// itoaDiv formats val/(10**dec) into buf.
func itoaDiv(buf []byte, val uint64, dec int) []byte {
i := len(buf) - 1
idec := i - dec
for val >= 10 || i >= idec {
buf[i] = byte(val%10 + '0')
i--
if i == idec {
buf[i] = '.'
i--
}
val /= 10
}
buf[i] = byte(val + '0')
return buf[i:]
}
// fmtNSAsMS nicely formats ns nanoseconds as milliseconds.
func fmtNSAsMS(buf []byte, ns uint64) []byte {
if ns >= 10e6 {
// Format as whole milliseconds.
return itoaDiv(buf, ns/1e6, 0)
}
// Format two digits of precision, with at most three decimal places.
x := ns / 1e3
if x == 0 {
buf[0] = '0'
return buf[:1]
}
dec := 3
for x >= 100 {
x /= 10
dec--
}
return itoaDiv(buf, x, dec)
}
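// A few sample outputs of these helpers, worked through by hand from the
// code above:
//
//	itoaDiv(buf, 12345, 3)    // "12.345"
//	itoaDiv(buf, 12, 1)       // "1.2"
//	fmtNSAsMS(buf, 1234567)   // "1.2" (1.234567ms, two significant digits)
//	fmtNSAsMS(buf, 123456789) // "123" (>= 10ms formats as whole milliseconds)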
// Helpers for testing GC.
// gcTestMoveStackOnNextCall causes the stack to be moved on a call
// immediately following the call to this. It may not work correctly
// if any other work appears after this call (such as returning).
// Typically the following call should be marked go:noinline so it
// performs a stack check.
//
// In rare cases this may not cause the stack to move, specifically if
// there's a preemption between this call and the next.
func gcTestMoveStackOnNextCall() {
gp := getg()
gp.stackguard0 = stackForceMove
}
// gcTestIsReachable performs a GC and returns a bit set where bit i
// is set if ptrs[i] is reachable.
func gcTestIsReachable(ptrs ...unsafe.Pointer) (mask uint64) {
// This takes the pointers as unsafe.Pointers in order to keep
// them live long enough for us to attach specials. After
// that, we drop our references to them.
if len(ptrs) > 64 {
panic("too many pointers for uint64 mask")
}
// Block GC while we attach specials and drop our references
// to ptrs. Otherwise, if a GC is in progress, it could mark
// them reachable via this function before we have a chance to
// drop them.
semacquire(&gcsema)
// Create reachability specials for ptrs.
specials := make([]*specialReachable, len(ptrs))
for i, p := range ptrs {
lock(&mheap_.speciallock)
s := (*specialReachable)(mheap_.specialReachableAlloc.alloc())
unlock(&mheap_.speciallock)
s.special.kind = _KindSpecialReachable
if !addspecial(p, &s.special) {
throw("already have a reachable special (duplicate pointer?)")
}
specials[i] = s
// Make sure we don't retain ptrs.
ptrs[i] = nil
}
semrelease(&gcsema)
// Force a full GC and sweep.
GC()
// Process specials.
for i, s := range specials {
if !s.done {
printlock()
println("runtime: object", i, "was not swept")
throw("IsReachable failed")
}
if s.reachable {
mask |= 1 << i
}
lock(&mheap_.speciallock)
mheap_.specialReachableAlloc.free(unsafe.Pointer(s))
unlock(&mheap_.speciallock)
}
return mask
}
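// A sketch of how a test might use this helper (hypothetical test code,
// not part of the runtime):
//
//	a, b := new(int), new(int)
//	keep := a // a stays reachable through keep; nothing else refers to b
//	mask := gcTestIsReachable(unsafe.Pointer(a), unsafe.Pointer(b))
//	// mask&1 != 0: a is reachable. mask&2 == 0: b was collected.
//	KeepAlive(keep)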
// gcTestPointerClass returns the category of what p points to, one of:
// "heap", "stack", "data", "bss", "other". This is useful for checking
// that a test is doing what it's intended to do.
//
// This is nosplit simply to avoid extra pointer shuffling that may
// complicate a test.
//
//go:nosplit
func gcTestPointerClass(p unsafe.Pointer) string {
p2 := uintptr(noescape(p))
gp := getg()
if gp.stack.lo <= p2 && p2 < gp.stack.hi {
return "stack"
}
if base, _, _ := findObject(p2, 0, 0); base != 0 {
return "heap"
}
for _, datap := range activeModules() {
if datap.data <= p2 && p2 < datap.edata || datap.noptrdata <= p2 && p2 < datap.enoptrdata {
return "data"
}
if datap.bss <= p2 && p2 < datap.ebss || datap.noptrbss <= p2 && p2 < datap.enoptrbss {
return "bss"
}
}
KeepAlive(p)
return "other"
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import "runtime/internal/atomic"
// gcCPULimiter is a mechanism to limit GC CPU utilization in situations
// where it might become excessive and inhibit application progress (e.g.
// a death spiral).
//
// The core of the limiter is a leaky bucket mechanism that fills with GC
// CPU time and drains with mutator time. Because the bucket fills and
// drains with time directly (i.e. without any weighting), this effectively
// sets a very conservative limit of 50%. This limit could be enforced directly;
// the purpose of the bucket, however, is to accommodate spikes in GC CPU
// utilization without hurting throughput.
//
// Note that the bucket in the leaky bucket mechanism can never go negative,
// so the GC never gets credit for a lot of CPU time spent without the GC
// running. This is intentional: an application that stays idle for, say,
// an entire day could otherwise build up enough credit that the limiter would
// fail to prevent a death spiral the following day. The bucket's capacity is
// the GC's only leeway.
//
// The capacity thus also sets the window the limiter considers. For example,
// if the capacity of the bucket is 1 cpu-second, then the limiter will not
// kick in until at least 1 full cpu-second in the last 2 cpu-second window
// is spent on GC CPU time.
var gcCPULimiter gcCPULimiterState
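// A minimal, self-contained sketch of the leaky bucket described above
// (illustrative only: the names are made up, and the real implementation
// below is lock-protected, signed, and tracks overflow explicitly):
//
//	type leakyBucket struct {
//		fill, capacity uint64
//	}
//
//	// add pours GC CPU time in and lets mutator CPU time drain out.
//	func (b *leakyBucket) add(mutatorTime, gcTime uint64) (limiting bool) {
//		b.fill += gcTime
//		if b.fill < mutatorTime {
//			b.fill = 0 // never negative: idle time banks no credit
//		} else {
//			b.fill -= mutatorTime
//		}
//		if b.fill > b.capacity {
//			b.fill = b.capacity // spill; capacity is the only leeway
//		}
//		return b.fill == b.capacity
//	}
//
// Because fill and drain are unweighted, GC CPU time above a 50% share
// fills the bucket faster than mutator time can drain it.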
type gcCPULimiterState struct {
lock atomic.Uint32
enabled atomic.Bool
bucket struct {
// Invariants:
// - fill >= 0
// - capacity >= 0
// - fill <= capacity
fill, capacity uint64
}
// overflow is the cumulative amount of GC CPU time that we tried to fill the
// bucket with but exceeded its capacity.
overflow uint64
// gcEnabled is an internal copy of gcBlackenEnabled that determines
// whether the limiter tracks total assist time.
//
// gcBlackenEnabled isn't used directly so as to keep this structure
// unit-testable.
gcEnabled bool
// transitioning is true when the GC is in a STW and transitioning between
// the mark and sweep phases.
transitioning bool
// assistTimePool is the accumulated assist time since the last update.
assistTimePool atomic.Int64
// idleMarkTimePool is the accumulated idle mark time since the last update.
idleMarkTimePool atomic.Int64
// idleTimePool is the accumulated time Ps spent on the idle list since the last update.
idleTimePool atomic.Int64
// lastUpdate is the nanotime timestamp of the last time update was called.
//
// Updated under lock, but may be read concurrently.
lastUpdate atomic.Int64
// lastEnabledCycle is the GC cycle that last had the limiter enabled.
lastEnabledCycle atomic.Uint32
// nprocs is an internal copy of gomaxprocs, used to determine total available
// CPU time.
//
// gomaxprocs isn't used directly so as to keep this structure unit-testable.
nprocs int32
// test indicates whether this instance of the struct was made for testing purposes.
test bool
}
// limiting returns true if the CPU limiter is currently enabled, meaning the Go GC
// should take action to limit CPU utilization.
//
// It is safe to call concurrently with other operations.
func (l *gcCPULimiterState) limiting() bool {
return l.enabled.Load()
}
// startGCTransition notifies the limiter of a GC transition.
//
// This call takes ownership of the limiter and disables all other means of
// updating the limiter. Release ownership by calling finishGCTransition.
//
// It is safe to call concurrently with other operations.
func (l *gcCPULimiterState) startGCTransition(enableGC bool, now int64) {
if !l.tryLock() {
// This must happen during a STW, so we can't fail to acquire the lock.
// If we did, something went wrong. Throw.
throw("failed to acquire lock to start a GC transition")
}
if l.gcEnabled == enableGC {
throw("transitioning GC to the same state as before?")
}
// Flush whatever was left between the last update and now.
l.updateLocked(now)
l.gcEnabled = enableGC
l.transitioning = true
// N.B. finishGCTransition releases the lock.
//
// We don't release here, to increase the chance that if there's a failure
// to finish the transition, we throw on failing to acquire the lock.
}
// finishGCTransition notifies the limiter that the GC transition is complete
// and releases ownership of it. It also accumulates STW time in the bucket.
// now must be the timestamp from the end of the STW pause.
func (l *gcCPULimiterState) finishGCTransition(now int64) {
if !l.transitioning {
throw("finishGCTransition called without starting one?")
}
// Count the full nprocs set of CPU time because the world is stopped
// between startGCTransition and finishGCTransition. Even though the GC
// isn't running on all CPUs, it is preventing user code from doing so,
// so it might as well be.
if lastUpdate := l.lastUpdate.Load(); now >= lastUpdate {
l.accumulate(0, (now-lastUpdate)*int64(l.nprocs))
}
l.lastUpdate.Store(now)
l.transitioning = false
l.unlock()
}
// gcCPULimiterUpdatePeriod dictates the maximum amount of wall-clock time
// we can go before updating the limiter.
const gcCPULimiterUpdatePeriod = 10e6 // 10ms
// needUpdate returns true if the limiter's maximum update period has been
// exceeded, and so would benefit from an update.
func (l *gcCPULimiterState) needUpdate(now int64) bool {
return now-l.lastUpdate.Load() > gcCPULimiterUpdatePeriod
}
// addAssistTime notifies the limiter of additional assist time. It will be
// included in the next update.
func (l *gcCPULimiterState) addAssistTime(t int64) {
l.assistTimePool.Add(t)
}
// addIdleTime notifies the limiter of additional time a P spent on the idle list. It will be
// subtracted from the total CPU time in the next update.
func (l *gcCPULimiterState) addIdleTime(t int64) {
l.idleTimePool.Add(t)
}
// update updates the bucket given runtime-specific information. now is the
// current monotonic time in nanoseconds.
//
// This is safe to call concurrently with other operations, except *GCTransition.
func (l *gcCPULimiterState) update(now int64) {
if !l.tryLock() {
// We failed to acquire the lock, which means something else is currently
// updating. Just drop our update, the next one to update will include
// our total assist time.
return
}
if l.transitioning {
throw("update during transition")
}
l.updateLocked(now)
l.unlock()
}
// updateLocked is the implementation of update. l.lock must be held.
func (l *gcCPULimiterState) updateLocked(now int64) {
lastUpdate := l.lastUpdate.Load()
if now < lastUpdate {
// Defensively avoid overflow. This isn't even the latest update anyway.
return
}
windowTotalTime := (now - lastUpdate) * int64(l.nprocs)
l.lastUpdate.Store(now)
// Drain the pool of assist time.
assistTime := l.assistTimePool.Load()
if assistTime != 0 {
l.assistTimePool.Add(-assistTime)
}
// Drain the pool of idle time.
idleTime := l.idleTimePool.Load()
if idleTime != 0 {
l.idleTimePool.Add(-idleTime)
}
if !l.test {
// Consume time from in-flight events. Make sure we're not preemptible so allp can't change.
//
// The reason we do this instead of just waiting for those events to finish and push updates
// is to ensure that all the time we're accounting for happened sometime between lastUpdate
// and now. This dramatically simplifies reasoning about the limiter because we're not at
// risk of accounting for more time in this window than actually elapsed in it,
// leading to all sorts of weird transient behavior.
mp := acquirem()
for _, pp := range allp {
typ, duration := pp.limiterEvent.consume(now)
switch typ {
case limiterEventIdleMarkWork:
fallthrough
case limiterEventIdle:
idleTime += duration
case limiterEventMarkAssist:
fallthrough
case limiterEventScavengeAssist:
assistTime += duration
case limiterEventNone:
break
default:
throw("invalid limiter event type found")
}
}
releasem(mp)
}
// Compute total GC time.
windowGCTime := assistTime
if l.gcEnabled {
windowGCTime += int64(float64(windowTotalTime) * gcBackgroundUtilization)
}
// Subtract out all idle time from the total time. Do this after computing
// GC time, because the background utilization is dependent on the *real*
// total time, not the total time after idle time is subtracted.
//
// Idle time is counted as any time that a P is on the P idle list plus idle mark
// time. Idle mark workers soak up time that the application spends idle.
//
// On a heavily undersubscribed system, any additional idle time can skew GC CPU
// utilization, because the GC might be executing continuously and thrashing,
// yet the CPU utilization with respect to GOMAXPROCS will be quite low, so
// the limiter fails to turn on. By subtracting idle time, we're removing time that
// we know the application was idle, giving a more accurate picture of whether
// the GC is thrashing.
//
// Note that this can cause the limiter to turn on even if it's not needed. For
// instance, on a system with 32 Ps but only 1 running goroutine, each GC will have
// 8 dedicated GC workers. Assuming the GC cycle is half mark phase and half sweep
// phase, then the GC CPU utilization over that cycle, with idle time removed, will
// be 8/(8+2) = 80%. Even though the limiter turns on, assists should be
// unnecessary, as the GC has far more CPU time than it needs to outpace the
// 1 goroutine that's running.
windowTotalTime -= idleTime
l.accumulate(windowTotalTime-windowGCTime, windowGCTime)
}
// accumulate adds time to the bucket and signals whether the limiter is enabled.
//
// This is an internal function that deals just with the bucket. Prefer update.
// l.lock must be held.
func (l *gcCPULimiterState) accumulate(mutatorTime, gcTime int64) {
headroom := l.bucket.capacity - l.bucket.fill
enabled := headroom == 0
// Let's be careful about three things here:
// 1. The addition and subtraction, for the invariants.
// 2. Overflow.
// 3. Excessive mutation of l.enabled, which is accessed
// by all assists, potentially more than once.
change := gcTime - mutatorTime
// Handle limiting case.
if change > 0 && headroom <= uint64(change) {
l.overflow += uint64(change) - headroom
l.bucket.fill = l.bucket.capacity
if !enabled {
l.enabled.Store(true)
l.lastEnabledCycle.Store(memstats.numgc + 1)
}
return
}
// Handle non-limiting cases.
if change < 0 && l.bucket.fill <= uint64(-change) {
// Bucket emptied.
l.bucket.fill = 0
} else {
// All other cases.
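// Note that when change > 0 here (GC time dominated but the
// bucket didn't saturate), uint64(-change) wraps around, so the
// subtraction below actually adds change to fill; the invariants
// guarantee the result stays within [0, capacity].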
l.bucket.fill -= uint64(-change)
}
if change != 0 && enabled {
l.enabled.Store(false)
}
}
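// A worked example of accumulate, assuming capacity = 1e9 (nprocs = 1)
// and starting from an empty bucket:
//
//	accumulate(mutatorTime=0.2e9, gcTime=1.5e9)
//		change = 1.3e9, headroom = 1e9 <= change
//		=> fill = capacity, overflow += 0.3e9, limiter turns on
//	accumulate(mutatorTime=2e9, gcTime=0)
//		change = -2e9, fill (1e9) <= -change
//		=> fill = 0 (clamped, no banked credit), limiter turns off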
// tryLock attempts to lock l. Returns true on success.
func (l *gcCPULimiterState) tryLock() bool {
return l.lock.CompareAndSwap(0, 1)
}
// unlock releases the lock on l. Must be called if tryLock returns true.
func (l *gcCPULimiterState) unlock() {
old := l.lock.Swap(0)
if old != 1 {
throw("double unlock")
}
}
// capacityPerProc is the limiter's bucket capacity for each P in GOMAXPROCS.
const capacityPerProc = 1e9 // 1 second in nanoseconds
// resetCapacity updates the capacity based on GOMAXPROCS. Must not be called
// while the GC is enabled.
//
// It is safe to call concurrently with other operations.
func (l *gcCPULimiterState) resetCapacity(now int64, nprocs int32) {
if !l.tryLock() {
// This must happen during a STW, so we can't fail to acquire the lock.
// If we did, something went wrong. Throw.
throw("failed to acquire lock to reset capacity")
}
// Flush the rest of the time for this period.
l.updateLocked(now)
l.nprocs = nprocs
l.bucket.capacity = uint64(nprocs) * capacityPerProc
if l.bucket.fill > l.bucket.capacity {
l.bucket.fill = l.bucket.capacity
l.enabled.Store(true)
l.lastEnabledCycle.Store(memstats.numgc + 1)
} else if l.bucket.fill < l.bucket.capacity {
l.enabled.Store(false)
}
l.unlock()
}
// limiterEventType indicates the type of an event occurring on some P.
//
// These events represent the full set of events that the GC CPU limiter tracks
// to execute its function.
//
// This type may use no more than limiterEventBits bits of information.
type limiterEventType uint8
const (
limiterEventNone limiterEventType = iota // None of the following events.
limiterEventIdleMarkWork // Refers to an idle mark worker (see gcMarkWorkerMode).
limiterEventMarkAssist // Refers to mark assist (see gcAssistAlloc).
limiterEventScavengeAssist // Refers to a scavenge assist (see allocSpan).
limiterEventIdle // Refers to time a P spent on the idle list.
limiterEventBits = 3
)
// limiterEventTypeMask is a mask for the bits in p.limiterEventStart that represent
// the event type. The rest of the bits of that field represent a timestamp.
const (
limiterEventTypeMask = uint64((1<<limiterEventBits)-1) << (64 - limiterEventBits)
limiterEventStampNone = limiterEventStamp(0)
)
// limiterEventStamp is a nanotime timestamp packed with a limiterEventType.
type limiterEventStamp uint64
// makeLimiterEventStamp creates a new stamp from the event type and the current timestamp.
func makeLimiterEventStamp(typ limiterEventType, now int64) limiterEventStamp {
return limiterEventStamp(uint64(typ)<<(64-limiterEventBits) | (uint64(now) &^ limiterEventTypeMask))
}
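// A worked example of the packing (limiterEventBits = 3, so the type
// occupies the top 3 bits and the timestamp keeps its low 61 bits):
//
//	now := int64(0x0123_4567_89ab_cdef) // top 3 bits are zero
//	s := makeLimiterEventStamp(limiterEventIdle, now)
//	// uint64(s) == 0x8123_4567_89ab_cdef: limiterEventIdle (4 = 100b)
//	// in the top 3 bits, the low 61 bits of now below it.
//	s.typ()           // limiterEventIdle
//	s.duration(now+5) // 5; duration borrows now's top 3 bits (see below)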
// duration computes the difference between now and the start time stored in the stamp.
//
// Returns 0 if the difference is negative, which may happen if now is stale or if the
// before and after timestamps cross a 2^(64-limiterEventBits) boundary.
func (s limiterEventStamp) duration(now int64) int64 {
// The top limiterEventBits bits of the timestamp are derived from the current time
// when computing a duration.
start := int64((uint64(now) & limiterEventTypeMask) | (uint64(s) &^ limiterEventTypeMask))
if now < start {
return 0
}
return now - start
}
// typ extracts the event type from the stamp.
func (s limiterEventStamp) typ() limiterEventType {
return limiterEventType(s >> (64 - limiterEventBits))
}
// limiterEvent represents tracking state for an event tracked by the GC CPU limiter.
type limiterEvent struct {
stamp atomic.Uint64 // Stores a limiterEventStamp.
}
// start begins tracking a new limiter event of type typ. If an event
// is already in flight, then a new event cannot begin because the current time is
// already being attributed to that event. In this case, this function returns false.
// Otherwise, it returns true.
//
// The caller must be non-preemptible until at least stop is called or this function
// returns false. Because this is trying to measure "on-CPU" time of some event, getting
// scheduled away during it can mean that whatever we're measuring isn't a reflection
// of "on-CPU" time. The OS could deschedule us at any time, but we want to maintain as
// close an approximation as we can.
func (e *limiterEvent) start(typ limiterEventType, now int64) bool {
if limiterEventStamp(e.stamp.Load()).typ() != limiterEventNone {
return false
}
e.stamp.Store(uint64(makeLimiterEventStamp(typ, now)))
return true
}
// consume acquires the partial event CPU time from any in-flight event.
// It achieves this by storing the current time as the new event time.
//
// Returns the type of the in-flight event, as well as how long it's currently been
// executing for. Returns limiterEventNone if no event is active.
func (e *limiterEvent) consume(now int64) (typ limiterEventType, duration int64) {
// Read the limiter event timestamp and update it to now.
for {
old := limiterEventStamp(e.stamp.Load())
typ = old.typ()
if typ == limiterEventNone {
// There's no in-flight event, so just push that up.
return
}
duration = old.duration(now)
if duration == 0 {
// We might have a stale now value, or this crossed the
// 2^(64-limiterEventBits) boundary in the clock readings.
// Just ignore it.
return limiterEventNone, 0
}
new := makeLimiterEventStamp(typ, now)
if e.stamp.CompareAndSwap(uint64(old), uint64(new)) {
break
}
}
return
}
// stop stops the active limiter event. Throws if the type of the active
// event does not match typ.
//
// The caller must be non-preemptible across the event. See start as to why.
func (e *limiterEvent) stop(typ limiterEventType, now int64) {
var stamp limiterEventStamp
for {
stamp = limiterEventStamp(e.stamp.Load())
if stamp.typ() != typ {
print("runtime: want=", typ, " got=", stamp.typ(), "\n")
throw("limiterEvent.stop: found wrong event in p's limiter event slot")
}
if e.stamp.CompareAndSwap(uint64(stamp), uint64(limiterEventStampNone)) {
break
}
}
duration := stamp.duration(now)
if duration == 0 {
// It's possible that we're missing time because we crossed a
// 2^(64-limiterEventBits) boundary between the start and end.
// In this case, we're dropping that information. This is OK because
// at worst it'll cause a transient hiccup that will quickly resolve
// itself as all new timestamps begin on the other side of the boundary.
// Such a hiccup should be incredibly rare.
return
}
// Account for the event.
switch typ {
case limiterEventIdleMarkWork:
gcCPULimiter.addIdleTime(duration)
case limiterEventIdle:
gcCPULimiter.addIdleTime(duration)
sched.idleTime.Add(duration)
case limiterEventMarkAssist:
fallthrough
case limiterEventScavengeAssist:
gcCPULimiter.addAssistTime(duration)
default:
throw("limiterEvent.stop: invalid limiter event type found")
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Garbage collector: marking and scanning
package runtime
import (
"internal/goarch"
"runtime/internal/atomic"
"runtime/internal/sys"
"unsafe"
)
const (
fixedRootFinalizers = iota
fixedRootFreeGStacks
fixedRootCount
// rootBlockBytes is the number of bytes to scan per data or
// BSS root.
rootBlockBytes = 256 << 10
// maxObletBytes is the maximum bytes of an object to scan at
// once. Larger objects will be split up into "oblets" of at
// most this size. Since we can scan 1–2 MB/ms, 128 KB bounds
// scan preemption at ~100 µs.
//
// This must be > _MaxSmallSize so that the object base is the
// span base.
maxObletBytes = 128 << 10
// drainCheckThreshold specifies how many units of work to do
// between self-preemption checks in gcDrain. Assuming a scan
// rate of 1 MB/ms, this is ~100 µs. Lower values have higher
// overhead in the scan loop (the scheduler check may perform
// a syscall, so its overhead is nontrivial). Higher values
// make the system less responsive to incoming work.
drainCheckThreshold = 100000
// pagesPerSpanRoot indicates how many pages to scan from a span root
// at a time. Used by special root marking.
//
// Higher values improve throughput by increasing locality, but
// increase the minimum latency of a marking operation.
//
// Must be a multiple of the pageInUse bitmap element size and
// must also evenly divide pagesPerArena.
pagesPerSpanRoot = 512
)
// gcMarkRootPrepare queues root scanning jobs (stacks, globals, and
// some miscellany) and initializes scanning-related state.
//
// The world must be stopped.
func gcMarkRootPrepare() {
assertWorldStopped()
// Compute how many data and BSS root blocks there are.
nBlocks := func(bytes uintptr) int {
return int(divRoundUp(bytes, rootBlockBytes))
}
work.nDataRoots = 0
work.nBSSRoots = 0
// Scan globals.
for _, datap := range activeModules() {
nDataRoots := nBlocks(datap.edata - datap.data)
if nDataRoots > work.nDataRoots {
work.nDataRoots = nDataRoots
}
}
for _, datap := range activeModules() {
nBSSRoots := nBlocks(datap.ebss - datap.bss)
if nBSSRoots > work.nBSSRoots {
work.nBSSRoots = nBSSRoots
}
}
// Scan span roots for finalizer specials.
//
// We depend on addfinalizer to mark objects that get
// finalizers after root marking.
//
// We're going to scan the whole heap (that was available at the time the
// mark phase started, i.e. markArenas) for in-use spans which have specials.
//
// Break up the work into arenas, and further into chunks.
//
// Snapshot allArenas as markArenas. This snapshot is safe because allArenas
// is append-only.
mheap_.markArenas = mheap_.allArenas[:len(mheap_.allArenas):len(mheap_.allArenas)]
work.nSpanRoots = len(mheap_.markArenas) * (pagesPerArena / pagesPerSpanRoot)
// Scan stacks.
//
// Gs may be created after this point, but it's okay that we
// ignore them because they begin life without any roots, so
// there's nothing to scan, and any roots they create during
// the concurrent phase will be caught by the write barrier.
work.stackRoots = allGsSnapshot()
work.nStackRoots = len(work.stackRoots)
work.markrootNext = 0
work.markrootJobs = uint32(fixedRootCount + work.nDataRoots + work.nBSSRoots + work.nSpanRoots + work.nStackRoots)
// Calculate base indexes of each root type
work.baseData = uint32(fixedRootCount)
work.baseBSS = work.baseData + uint32(work.nDataRoots)
work.baseSpans = work.baseBSS + uint32(work.nBSSRoots)
work.baseStacks = work.baseSpans + uint32(work.nSpanRoots)
work.baseEnd = work.baseStacks + uint32(work.nStackRoots)
}
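// A concrete example of the bookkeeping above: with rootBlockBytes =
// 256 KiB, a module with a 1 MiB data segment and a 300 KiB BSS yields
// nDataRoots = 4 and nBSSRoots = 2 (divRoundUp), so with fixedRootCount = 2
// the job space is laid out as:
//
//	[0, 2)            fixed roots (finalizers, free G stacks)
//	[2, 6)            data blocks
//	[6, 8)            BSS blocks
//	[8, 8+nSpanRoots) span shards
//	[.., baseEnd)     one job per snapshotted G stack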
// gcMarkRootCheck checks that all roots have been scanned. It is
// purely for debugging.
func gcMarkRootCheck() {
if work.markrootNext < work.markrootJobs {
print(work.markrootNext, " of ", work.markrootJobs, " markroot jobs done\n")
throw("left over markroot jobs")
}
// Check that stacks have been scanned.
//
// We only check the first nStackRoots Gs that we should have scanned.
// Since we don't care about newer Gs (see comment in
// gcMarkRootPrepare), no locking is required.
i := 0
forEachGRace(func(gp *g) {
if i >= work.nStackRoots {
return
}
if !gp.gcscandone {
println("gp", gp, "goid", gp.goid,
"status", readgstatus(gp),
"gcscandone", gp.gcscandone)
throw("scan missed a g")
}
i++
})
}
// ptrmask for an allocation containing a single pointer.
var oneptrmask = [...]uint8{1}
// markroot scans the i'th root.
//
// Preemption must be disabled (because this uses a gcWork).
//
// Returns the amount of GC work credit produced by the operation.
// If flushBgCredit is true, then that credit is also flushed
// to the background credit pool.
//
// nowritebarrier is only advisory here.
//
//go:nowritebarrier
func markroot(gcw *gcWork, i uint32, flushBgCredit bool) int64 {
// Note: if you add a case here, please also update heapdump.go:dumproots.
var workDone int64
var workCounter *atomic.Int64
switch {
case work.baseData <= i && i < work.baseBSS:
workCounter = &gcController.globalsScanWork
for _, datap := range activeModules() {
workDone += markrootBlock(datap.data, datap.edata-datap.data, datap.gcdatamask.bytedata, gcw, int(i-work.baseData))
}
case work.baseBSS <= i && i < work.baseSpans:
workCounter = &gcController.globalsScanWork
for _, datap := range activeModules() {
workDone += markrootBlock(datap.bss, datap.ebss-datap.bss, datap.gcbssmask.bytedata, gcw, int(i-work.baseBSS))
}
case i == fixedRootFinalizers:
for fb := allfin; fb != nil; fb = fb.alllink {
cnt := uintptr(atomic.Load(&fb.cnt))
scanblock(uintptr(unsafe.Pointer(&fb.fin[0])), cnt*unsafe.Sizeof(fb.fin[0]), &finptrmask[0], gcw, nil)
}
case i == fixedRootFreeGStacks:
// Switch to the system stack so we can call
// stackfree.
systemstack(markrootFreeGStacks)
case work.baseSpans <= i && i < work.baseStacks:
// mark mspan.specials
markrootSpans(gcw, int(i-work.baseSpans))
default:
// the rest is scanning goroutine stacks
workCounter = &gcController.stackScanWork
if i < work.baseStacks || work.baseEnd <= i {
printlock()
print("runtime: markroot index ", i, " not in stack roots range [", work.baseStacks, ", ", work.baseEnd, ")\n")
throw("markroot: bad index")
}
gp := work.stackRoots[i-work.baseStacks]
// Remember when we first observed the G blocked;
// this is needed only for traceback output.
status := readgstatus(gp) // We are not in a scan state
if (status == _Gwaiting || status == _Gsyscall) && gp.waitsince == 0 {
gp.waitsince = work.tstart
}
// scanstack must be done on the system stack in case
// we're trying to scan our own stack.
systemstack(func() {
// If this is a self-scan, put the user G in
// _Gwaiting to prevent self-deadlock. It may
// already be in _Gwaiting if this is a mark
// worker or we're in mark termination.
userG := getg().m.curg
selfScan := gp == userG && readgstatus(userG) == _Grunning
if selfScan {
casGToWaiting(userG, _Grunning, waitReasonGarbageCollectionScan)
}
// TODO: suspendG blocks (and spins) until gp
// stops, which may take a while for
// running goroutines. Consider doing this in
// two phases where the first is non-blocking:
// we scan the stacks we can and ask running
// goroutines to scan themselves; and the
// second blocks.
stopped := suspendG(gp)
if stopped.dead {
gp.gcscandone = true
return
}
if gp.gcscandone {
throw("g already scanned")
}
workDone += scanstack(gp, gcw)
gp.gcscandone = true
resumeG(stopped)
if selfScan {
casgstatus(userG, _Gwaiting, _Grunning)
}
})
}
if workCounter != nil && workDone != 0 {
workCounter.Add(workDone)
if flushBgCredit {
gcFlushBgCredit(workDone)
}
}
return workDone
}
// markrootBlock scans the shard'th shard of the block of memory [b0,
// b0+n0), with the given pointer mask.
//
// Returns the amount of work done.
//
//go:nowritebarrier
func markrootBlock(b0, n0 uintptr, ptrmask0 *uint8, gcw *gcWork, shard int) int64 {
if rootBlockBytes%(8*goarch.PtrSize) != 0 {
// This is necessary to pick byte offsets in ptrmask0.
throw("rootBlockBytes must be a multiple of 8*ptrSize")
}
// Note that if b0 is toward the end of the address space,
// then b0 + rootBlockBytes might wrap around.
// These tests are written to avoid any possible overflow.
off := uintptr(shard) * rootBlockBytes
if off >= n0 {
return 0
}
b := b0 + off
ptrmask := (*uint8)(add(unsafe.Pointer(ptrmask0), uintptr(shard)*(rootBlockBytes/(8*goarch.PtrSize))))
n := uintptr(rootBlockBytes)
if off+n > n0 {
n = n0 - off
}
// Scan this shard.
scanblock(b, n, ptrmask, gcw, nil)
return int64(n)
}
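// For instance, scanning a 600 KiB data segment (n0 = 600<<10) with
// rootBlockBytes = 256 KiB: shard 0 covers [b0, b0+256KiB), shard 1 covers
// [b0+256KiB, b0+512KiB), shard 2 covers the 88 KiB tail, and any higher
// shard returns 0. The pointer mask advances by one byte per 8 words, i.e.
// rootBlockBytes/(8*goarch.PtrSize) bytes per shard.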
// markrootFreeGStacks frees stacks of dead Gs.
//
// This does not free stacks of dead Gs cached on Ps, but having a few
// cached stacks around isn't a problem.
func markrootFreeGStacks() {
// Take list of dead Gs with stacks.
lock(&sched.gFree.lock)
list := sched.gFree.stack
sched.gFree.stack = gList{}
unlock(&sched.gFree.lock)
if list.empty() {
return
}
// Free stacks.
q := gQueue{list.head, list.head}
for gp := list.head.ptr(); gp != nil; gp = gp.schedlink.ptr() {
stackfree(gp.stack)
gp.stack.lo = 0
gp.stack.hi = 0
// Manipulate the queue directly since the Gs are
// already all linked the right way.
q.tail.set(gp)
}
// Put Gs back on the free list.
lock(&sched.gFree.lock)
sched.gFree.noStack.pushAll(q)
unlock(&sched.gFree.lock)
}
// markrootSpans marks roots for one shard of markArenas.
//
//go:nowritebarrier
func markrootSpans(gcw *gcWork, shard int) {
// Objects with finalizers have two GC-related invariants:
//
// 1) Everything reachable from the object must be marked.
// This ensures that when we pass the object to its finalizer,
// everything the finalizer can reach will be retained.
//
// 2) Finalizer specials (which are not in the garbage
// collected heap) are roots. In practice, this means the fn
// field must be scanned.
sg := mheap_.sweepgen
// Find the arena and page index into that arena for this shard.
ai := mheap_.markArenas[shard/(pagesPerArena/pagesPerSpanRoot)]
ha := mheap_.arenas[ai.l1()][ai.l2()]
arenaPage := uint(uintptr(shard) * pagesPerSpanRoot % pagesPerArena)
// Construct slice of bitmap which we'll iterate over.
specialsbits := ha.pageSpecials[arenaPage/8:]
specialsbits = specialsbits[:pagesPerSpanRoot/8]
for i := range specialsbits {
// Find set bits, which correspond to spans with specials.
specials := atomic.Load8(&specialsbits[i])
if specials == 0 {
continue
}
for j := uint(0); j < 8; j++ {
if specials&(1<<j) == 0 {
continue
}
// Find the span for this bit.
//
// This value is guaranteed to be non-nil because having
// specials implies that the span is in-use, and since we're
// currently marking we can be sure that we don't have to worry
// about the span being freed and re-used.
s := ha.spans[arenaPage+uint(i)*8+j]
// The state must be mSpanInUse if the specials bit is set, so
// sanity check that.
if state := s.state.get(); state != mSpanInUse {
print("s.state = ", state, "\n")
throw("non in-use span found with specials bit set")
}
// Check that this span was swept (it may be cached or uncached).
if !useCheckmark && !(s.sweepgen == sg || s.sweepgen == sg+3) {
// sweepgen was updated (+2) during non-checkmark GC pass
print("sweep ", s.sweepgen, " ", sg, "\n")
throw("gc: unswept span")
}
// Lock the specials to prevent a special from being
// removed from the list while we're traversing it.
lock(&s.speciallock)
for sp := s.specials; sp != nil; sp = sp.next {
if sp.kind != _KindSpecialFinalizer {
continue
}
// don't mark finalized object, but scan it so we
// retain everything it points to.
spf := (*specialfinalizer)(unsafe.Pointer(sp))
// A finalizer can be set for an inner byte of an object; find the object's beginning.
p := s.base() + uintptr(spf.special.offset)/s.elemsize*s.elemsize
// Mark everything that can be reached from
// the object (but *not* the object itself or
// we'll never collect it).
if !s.spanclass.noscan() {
scanobject(p, gcw)
}
// The special itself is a root.
scanblock(uintptr(unsafe.Pointer(&spf.fn)), goarch.PtrSize, &oneptrmask[0], gcw, nil)
}
unlock(&s.speciallock)
}
}
}
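// For illustration only (not part of the runtime): on a typical 64-bit
// configuration, assuming 8192 pages per arena and pagesPerSpanRoot = 512,
// each arena contributes 8192/512 = 16 shards, so the mapping above works
// out to
//
//	ai := mheap_.markArenas[shard/16]     // which arena this shard covers
//	arenaPage := (shard * 512) % 8192     // first page of its 512-page window
//
// and the loop over specialsbits walks 512/8 = 64 bytes of the
// pageSpecials bitmap, one bit per page.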
// gcAssistAlloc performs GC work to make gp's assist debt positive.
// gp must be the calling user goroutine.
//
// This must be called with preemption enabled.
func gcAssistAlloc(gp *g) {
// Don't assist in non-preemptible contexts. These are
// generally fragile and won't allow the assist to block.
if getg() == gp.m.g0 {
return
}
if mp := getg().m; mp.locks > 0 || mp.preemptoff != "" {
return
}
traced := false
retry:
if gcCPULimiter.limiting() {
// If the CPU limiter is enabled, intentionally don't
// assist to reduce the amount of CPU time spent in the GC.
if traced {
traceGCMarkAssistDone()
}
return
}
// Compute the amount of scan work we need to do to make the
// balance positive. When the required amount of work is low,
// we over-assist to build up credit for future allocations
// and amortize the cost of assisting.
assistWorkPerByte := gcController.assistWorkPerByte.Load()
assistBytesPerWork := gcController.assistBytesPerWork.Load()
debtBytes := -gp.gcAssistBytes
scanWork := int64(assistWorkPerByte * float64(debtBytes))
if scanWork < gcOverAssistWork {
scanWork = gcOverAssistWork
debtBytes = int64(assistBytesPerWork * float64(scanWork))
}
// Steal as much credit as we can from the background GC's
// scan credit. This is racy and may drop the background
// credit below 0 if two mutators steal at the same time. This
// will just cause steals to fail until credit is accumulated
// again, so in the long run it doesn't really matter, but we
// do have to handle the negative credit case.
bgScanCredit := gcController.bgScanCredit.Load()
stolen := int64(0)
if bgScanCredit > 0 {
if bgScanCredit < scanWork {
stolen = bgScanCredit
gp.gcAssistBytes += 1 + int64(assistBytesPerWork*float64(stolen))
} else {
stolen = scanWork
gp.gcAssistBytes += debtBytes
}
gcController.bgScanCredit.Add(-stolen)
scanWork -= stolen
if scanWork == 0 {
// We were able to steal all of the credit we
// needed.
if traced {
traceGCMarkAssistDone()
}
return
}
}
if trace.enabled && !traced {
traced = true
traceGCMarkAssistStart()
}
// Perform assist work
systemstack(func() {
gcAssistAlloc1(gp, scanWork)
// The user stack may have moved, so this can't touch
// anything on it until it returns from systemstack.
})
completed := gp.param != nil
gp.param = nil
if completed {
gcMarkDone()
}
if gp.gcAssistBytes < 0 {
// We were unable to steal enough credit or perform
// enough work to pay off the assist debt. We need to
// do one of these before letting the mutator allocate
// more to prevent over-allocation.
//
// If this is because we were preempted, reschedule
// and try some more.
if gp.preempt {
Gosched()
goto retry
}
// Add this G to an assist queue and park. When the GC
// has more background credit, it will satisfy queued
// assists before flushing to the global credit pool.
//
// Note that this does *not* get woken up when more
// work is added to the work list. The theory is that
// there wasn't enough work to do anyway, so we might
// as well let background marking take care of the
// work that is available.
if !gcParkAssist() {
goto retry
}
// At this point either background GC has satisfied
// this G's assist debt, or the GC cycle is over.
}
if traced {
traceGCMarkAssistDone()
}
}
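// A worked example of the debt arithmetic above (illustrative numbers,
// not part of the runtime): suppose assistWorkPerByte = 0.5 and the
// goroutine owes 1 KiB (gp.gcAssistBytes = -1024). Then
//
//	scanWork = int64(0.5 * 1024) = 512
//
// which is below gcOverAssistWork (64<<10), so the assist rounds up to
// 64 KiB of scan work and pre-pays debtBytes = 128 KiB of allocation
// credit. If the background scan has already banked 64 KiB or more of
// credit, the entire assist is satisfied by stealing and this goroutine
// does no scanning at all.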
// gcAssistAlloc1 is the part of gcAssistAlloc that runs on the system
// stack. This is a separate function to make it easier to see that
// we're not capturing anything from the user stack, since the user
// stack may move while we're in this function.
//
// gcAssistAlloc1 indicates whether this assist completed the mark
// phase by setting gp.param to non-nil. This can't be communicated on
// the stack since it may move.
//
//go:systemstack
func gcAssistAlloc1(gp *g, scanWork int64) {
// Clear the flag indicating that this assist completed the
// mark phase.
gp.param = nil
if atomic.Load(&gcBlackenEnabled) == 0 {
// The gcBlackenEnabled check in malloc races with the
// store that clears it but an atomic check in every malloc
// would be a performance hit.
// Instead we recheck it here on the non-preemptable system
// stack to determine if we should perform an assist.
// GC is done, so ignore any remaining debt.
gp.gcAssistBytes = 0
return
}
// Track time spent in this assist. Since we're on the
// system stack, this is non-preemptible, so we can
// just measure start and end time.
//
// Limiter event tracking might be disabled if we end up here
// while on a mark worker.
startTime := nanotime()
trackLimiterEvent := gp.m.p.ptr().limiterEvent.start(limiterEventMarkAssist, startTime)
decnwait := atomic.Xadd(&work.nwait, -1)
if decnwait == work.nproc {
println("runtime: work.nwait =", decnwait, "work.nproc=", work.nproc)
throw("nwait > work.nprocs")
}
// gcDrainN requires the caller to be preemptible.
casGToWaiting(gp, _Grunning, waitReasonGCAssistMarking)
// drain own cached work first in the hopes that it
// will be more cache friendly.
gcw := &getg().m.p.ptr().gcw
workDone := gcDrainN(gcw, scanWork)
casgstatus(gp, _Gwaiting, _Grunning)
// Record that we did this much scan work.
//
// Back out the number of bytes of assist credit that
// this scan work counts for. The "1+" is a poor man's
// round-up, to ensure this adds credit even if
// assistBytesPerWork is very low.
assistBytesPerWork := gcController.assistBytesPerWork.Load()
gp.gcAssistBytes += 1 + int64(assistBytesPerWork*float64(workDone))
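// For example (illustrative numbers): with assistBytesPerWork = 0.125
// and workDone = 4, int64(0.125*4) truncates to 0; the "1+" above still
// credits one byte, so repeated small assists cannot get permanently
// stuck making no progress on their debt.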
// If this is the last worker and we ran out of work,
// signal a completion point.
incnwait := atomic.Xadd(&work.nwait, +1)
if incnwait > work.nproc {
println("runtime: work.nwait=", incnwait,
"work.nproc=", work.nproc)
throw("work.nwait > work.nproc")
}
if incnwait == work.nproc && !gcMarkWorkAvailable(nil) {
// This has reached a background completion point. Set
// gp.param to a non-nil value to indicate this. It
// doesn't matter what we set it to (it just has to be
// a valid pointer).
gp.param = unsafe.Pointer(gp)
}
now := nanotime()
duration := now - startTime
pp := gp.m.p.ptr()
pp.gcAssistTime += duration
if trackLimiterEvent {
pp.limiterEvent.stop(limiterEventMarkAssist, now)
}
if pp.gcAssistTime > gcAssistTimeSlack {
gcController.assistTime.Add(pp.gcAssistTime)
gcCPULimiter.update(now)
pp.gcAssistTime = 0
}
}
// gcWakeAllAssists wakes all currently blocked assists. This is used
// at the end of a GC cycle. gcBlackenEnabled must be false to prevent
// new assists from going to sleep after this point.
func gcWakeAllAssists() {
lock(&work.assistQueue.lock)
list := work.assistQueue.q.popList()
injectglist(&list)
unlock(&work.assistQueue.lock)
}
// gcParkAssist puts the current goroutine on the assist queue and parks.
//
// gcParkAssist reports whether the assist is now satisfied. If it
// returns false, the caller must retry the assist.
func gcParkAssist() bool {
lock(&work.assistQueue.lock)
// If the GC cycle finished while we were getting the lock,
// exit the assist. The cycle can't finish while we hold the
// lock.
if atomic.Load(&gcBlackenEnabled) == 0 {
unlock(&work.assistQueue.lock)
return true
}
gp := getg()
oldList := work.assistQueue.q
work.assistQueue.q.pushBack(gp)
// Recheck for background credit now that this G is in
// the queue, but can still back out. This avoids a
// race in case background marking has flushed more
// credit since we checked above.
if gcController.bgScanCredit.Load() > 0 {
work.assistQueue.q = oldList
if oldList.tail != 0 {
oldList.tail.ptr().schedlink.set(nil)
}
unlock(&work.assistQueue.lock)
return false
}
// Park.
goparkunlock(&work.assistQueue.lock, waitReasonGCAssistWait, traceEvGoBlockGC, 2)
return true
}
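// The push-then-recheck sequence above is the standard pattern for
// avoiding a lost wakeup. A minimal sketch of the same idea (generic
// pseudocode, not runtime API):
//
//	lock(&l)
//	queue.push(g)
//	if satisfied() {      // recheck under the lock, after enqueueing
//		queue.undoPush(g)
//		unlock(&l)
//		return false  // retry without sleeping
//	}
//	parkUnlocking(&l)     // any flusher now observes g and wakes it
//	return true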
// gcFlushBgCredit flushes scanWork units of background scan work
// credit. This first satisfies blocked assists on the
// work.assistQueue and then flushes any remaining credit to
// gcController.bgScanCredit.
//
// Write barriers are disallowed because this is used by gcDrain after
// it has ensured that all work is drained and this must preserve that
// condition.
//
//go:nowritebarrierrec
func gcFlushBgCredit(scanWork int64) {
if work.assistQueue.q.empty() {
// Fast path; there are no blocked assists. There's a
// small window here where an assist may add itself to
// the blocked queue and park. If that happens, we'll
// just get it on the next flush.
gcController.bgScanCredit.Add(scanWork)
return
}
assistBytesPerWork := gcController.assistBytesPerWork.Load()
scanBytes := int64(float64(scanWork) * assistBytesPerWork)
lock(&work.assistQueue.lock)
for !work.assistQueue.q.empty() && scanBytes > 0 {
gp := work.assistQueue.q.pop()
// Note that gp.gcAssistBytes is negative because gp
// is in debt. Think carefully about the signs below.
if scanBytes+gp.gcAssistBytes >= 0 {
// Satisfy this entire assist debt.
scanBytes += gp.gcAssistBytes
gp.gcAssistBytes = 0
// It's important that we *not* put gp in
// runnext. Otherwise, it's possible for user
// code to exploit the GC worker's high
// scheduler priority to get itself always run
// before other goroutines and always in the
// fresh quantum started by GC.
ready(gp, 0, false)
} else {
// Partially satisfy this assist.
gp.gcAssistBytes += scanBytes
scanBytes = 0
// As a heuristic, we move this assist to the
// back of the queue so that large assists
// can't clog up the assist queue and
// substantially delay small assists.
work.assistQueue.q.pushBack(gp)
break
}
}
if scanBytes > 0 {
// Convert from scan bytes back to work.
assistWorkPerByte := gcController.assistWorkPerByte.Load()
scanWork = int64(float64(scanBytes) * assistWorkPerByte)
gcController.bgScanCredit.Add(scanWork)
}
unlock(&work.assistQueue.lock)
}
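// A worked example of the loop above (illustrative numbers): with two
// parked assists, G1 (gcAssistBytes = -100) and G2 (gcAssistBytes = -300),
// and a flush worth scanBytes = 250:
//
//	G1: 250 + (-100) >= 0 -> debt fully paid, G1 readied, scanBytes = 150
//	G2: 150 + (-300) < 0  -> debt partially paid (now -150), scanBytes = 0,
//	                         G2 moved to the back of the queue
//
// Nothing remains to convert back into gcController.bgScanCredit.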
// scanstack scans gp's stack, greying all pointers found on the stack.
//
// Returns the amount of scan work performed, but doesn't update
// gcController.stackScanWork or flush any credit. Any background credit produced
// by this function should be flushed by its caller. scanstack itself can't
// safely flush because it may result in trying to wake up a goroutine that
// was just scanned, resulting in a self-deadlock.
//
// scanstack will also shrink the stack if it is safe to do so. If it
// is not, it schedules a stack shrink for the next synchronous safe
// point.
//
// scanstack is marked go:systemstack because it must not be preempted
// while using a workbuf.
//
//go:nowritebarrier
//go:systemstack
func scanstack(gp *g, gcw *gcWork) int64 {
if readgstatus(gp)&_Gscan == 0 {
print("runtime:scanstack: gp=", gp, ", goid=", gp.goid, ", gp->atomicstatus=", hex(readgstatus(gp)), "\n")
throw("scanstack - bad status")
}
switch readgstatus(gp) &^ _Gscan {
default:
print("runtime: gp=", gp, ", goid=", gp.goid, ", gp->atomicstatus=", readgstatus(gp), "\n")
throw("mark - bad status")
case _Gdead:
return 0
case _Grunning:
print("runtime: gp=", gp, ", goid=", gp.goid, ", gp->atomicstatus=", readgstatus(gp), "\n")
throw("scanstack: goroutine not stopped")
case _Grunnable, _Gsyscall, _Gwaiting:
// ok
}
if gp == getg() {
throw("can't scan our own stack")
}
// scannedSize is the amount of work we'll be reporting.
//
// It is less than the allocated size (which is hi-lo).
var sp uintptr
if gp.syscallsp != 0 {
sp = gp.syscallsp // If in a system call this is the stack pointer (gp.sched.sp can be 0 in this case on Windows).
} else {
sp = gp.sched.sp
}
scannedSize := gp.stack.hi - sp
// Keep statistics for initial stack size calculation.
// Note that this accumulates the scanned size, not the allocated size.
p := getg().m.p.ptr()
p.scannedStackSize += uint64(scannedSize)
p.scannedStacks++
if isShrinkStackSafe(gp) {
// Shrink the stack if not much of it is being used.
shrinkstack(gp)
} else {
// Otherwise, shrink the stack at the next sync safe point.
gp.preemptShrink = true
}
var state stackScanState
state.stack = gp.stack
if stackTraceDebug {
println("stack trace goroutine", gp.goid)
}
if debugScanConservative && gp.asyncSafePoint {
print("scanning async preempted goroutine ", gp.goid, " stack [", hex(gp.stack.lo), ",", hex(gp.stack.hi), ")\n")
}
// Scan the saved context register. This is effectively a live
// register that gets moved back and forth between the
// register and sched.ctxt without a write barrier.
if gp.sched.ctxt != nil {
scanblock(uintptr(unsafe.Pointer(&gp.sched.ctxt)), goarch.PtrSize, &oneptrmask[0], gcw, &state)
}
// Scan the stack. Accumulate a list of stack objects.
scanframe := func(frame *stkframe, unused unsafe.Pointer) bool {
scanframeworker(frame, &state, gcw)
return true
}
gentraceback(^uintptr(0), ^uintptr(0), 0, gp, 0, nil, 0x7fffffff, scanframe, nil, 0)
// Find additional pointers that point into the stack from the heap.
// Currently this includes defers and panics. See also function copystack.
// Find and trace other pointers in defer records.
for d := gp._defer; d != nil; d = d.link {
if d.fn != nil {
// Scan the func value, which could be a stack allocated closure.
// See issue 30453.
scanblock(uintptr(unsafe.Pointer(&d.fn)), goarch.PtrSize, &oneptrmask[0], gcw, &state)
}
if d.link != nil {
// The link field of a stack-allocated defer record might point
// to a heap-allocated defer record. Keep that heap record live.
scanblock(uintptr(unsafe.Pointer(&d.link)), goarch.PtrSize, &oneptrmask[0], gcw, &state)
}
// Retain defers records themselves.
// Defer records might not be reachable from the G through regular heap
// tracing because the defer linked list might weave between the stack and the heap.
if d.heap {
scanblock(uintptr(unsafe.Pointer(&d)), goarch.PtrSize, &oneptrmask[0], gcw, &state)
}
}
if gp._panic != nil {
// Panics are always stack allocated.
state.putPtr(uintptr(unsafe.Pointer(gp._panic)), false)
}
// Find and scan all reachable stack objects.
//
// The state's pointer queue prioritizes precise pointers over
// conservative pointers so that we'll prefer scanning stack
// objects precisely.
state.buildIndex()
for {
p, conservative := state.getPtr()
if p == 0 {
break
}
obj := state.findObject(p)
if obj == nil {
continue
}
r := obj.r
if r == nil {
// We've already scanned this object.
continue
}
obj.setRecord(nil) // Don't scan it again.
if stackTraceDebug {
printlock()
print(" live stkobj at", hex(state.stack.lo+uintptr(obj.off)), "of size", obj.size)
if conservative {
print(" (conservative)")
}
println()
printunlock()
}
gcdata := r.gcdata()
var s *mspan
if r.useGCProg() {
// This path is pretty unlikely, an object large enough
// to have a GC program allocated on the stack.
// We need some space to unpack the program into a straight
// bitmask, which we allocate/free here.
// TODO: it would be nice if there were a way to run a GC
// program without having to store all its bits. We'd have
// to change from a Lempel-Ziv style program to something else.
// Or we can forbid putting objects on stacks if they require
// a gc program (see issue 27447).
s = materializeGCProg(r.ptrdata(), gcdata)
gcdata = (*byte)(unsafe.Pointer(s.startAddr))
}
b := state.stack.lo + uintptr(obj.off)
if conservative {
scanConservative(b, r.ptrdata(), gcdata, gcw, &state)
} else {
scanblock(b, r.ptrdata(), gcdata, gcw, &state)
}
if s != nil {
dematerializeGCProg(s)
}
}
// Deallocate object buffers.
// (Pointer buffers were all deallocated in the loop above.)
for state.head != nil {
x := state.head
state.head = x.next
if stackTraceDebug {
for i := 0; i < x.nobj; i++ {
obj := &x.obj[i]
if obj.r == nil { // reachable
continue
}
println(" dead stkobj at", hex(gp.stack.lo+uintptr(obj.off)), "of size", obj.r.size)
// Note: not necessarily really dead - only reachable-from-ptr dead.
}
}
x.nobj = 0
putempty((*workbuf)(unsafe.Pointer(x)))
}
if state.buf != nil || state.cbuf != nil || state.freeBuf != nil {
throw("remaining pointer buffers")
}
return int64(scannedSize)
}
// Scan a stack frame: local variables and function arguments/results.
//
//go:nowritebarrier
func scanframeworker(frame *stkframe, state *stackScanState, gcw *gcWork) {
if _DebugGC > 1 && frame.continpc != 0 {
print("scanframe ", funcname(frame.fn), "\n")
}
isAsyncPreempt := frame.fn.valid() && frame.fn.funcID == funcID_asyncPreempt
isDebugCall := frame.fn.valid() && frame.fn.funcID == funcID_debugCallV2
if state.conservative || isAsyncPreempt || isDebugCall {
if debugScanConservative {
println("conservatively scanning function", funcname(frame.fn), "at PC", hex(frame.continpc))
}
// Conservatively scan the frame. Unlike the precise
// case, this includes the outgoing argument space
// since we may have stopped while this function was
// setting up a call.
//
// TODO: We could narrow this down if the compiler
// produced a single map per function of stack slots
// and registers that ever contain a pointer.
if frame.varp != 0 {
size := frame.varp - frame.sp
if size > 0 {
scanConservative(frame.sp, size, nil, gcw, state)
}
}
// Scan arguments to this frame.
if n := frame.argBytes(); n != 0 {
// TODO: We could pass the entry argument map
// to narrow this down further.
scanConservative(frame.argp, n, nil, gcw, state)
}
if isAsyncPreempt || isDebugCall {
// This function's frame contained the
// registers for the asynchronously stopped
// parent frame. Scan the parent
// conservatively.
state.conservative = true
} else {
// We only wanted to scan those two frames
// conservatively. Clear the flag for future
// frames.
state.conservative = false
}
return
}
locals, args, objs := frame.getStackMap(&state.cache, false)
// Scan local variables if stack frame has been allocated.
if locals.n > 0 {
size := uintptr(locals.n) * goarch.PtrSize
scanblock(frame.varp-size, size, locals.bytedata, gcw, state)
}
// Scan arguments.
if args.n > 0 {
scanblock(frame.argp, uintptr(args.n)*goarch.PtrSize, args.bytedata, gcw, state)
}
// Add all stack objects to the stack object list.
if frame.varp != 0 {
// varp is 0 for defers, where there are no locals.
// In that case, there can't be a pointer to its args, either.
// (And all args would be scanned above anyway.)
for i := range objs {
obj := &objs[i]
off := obj.off
base := frame.varp // locals base pointer
if off >= 0 {
base = frame.argp // arguments and return values base pointer
}
ptr := base + uintptr(off)
if ptr < frame.sp {
// object hasn't been allocated in the frame yet.
continue
}
if stackTraceDebug {
println("stkobj at", hex(ptr), "of size", obj.size)
}
state.addObject(ptr, obj)
}
}
}
type gcDrainFlags int
const (
gcDrainUntilPreempt gcDrainFlags = 1 << iota
gcDrainFlushBgCredit
gcDrainIdle
gcDrainFractional
)
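// As a rough guide to how these flags combine (a sketch of typical call
// sites elsewhere in the runtime, not a definitive list): a dedicated
// background mark worker drains with
//
//	gcDrain(gcw, gcDrainUntilPreempt|gcDrainFlushBgCredit)
//
// while fractional and idle workers additionally pass gcDrainFractional
// or gcDrainIdle so they self-preempt when their time slice expires or
// other work appears.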
// gcDrain scans roots and objects in work buffers, blackening grey
// objects until it is unable to get more work. It may return before
// GC is done; it's the caller's responsibility to balance work from
// other Ps.
//
// If flags&gcDrainUntilPreempt != 0, gcDrain returns when g.preempt
// is set.
//
// If flags&gcDrainIdle != 0, gcDrain returns when there is other work
// to do.
//
// If flags&gcDrainFractional != 0, gcDrain self-preempts when
// pollFractionalWorkerExit() returns true.
//
// If flags&gcDrainFlushBgCredit != 0, gcDrain flushes scan work
// credit to gcController.bgScanCredit every gcCreditSlack units of
// scan work.
//
// gcDrain will always return if there is a pending STW.
//
//go:nowritebarrier
func gcDrain(gcw *gcWork, flags gcDrainFlags) {
if !writeBarrier.needed {
throw("gcDrain phase incorrect")
}
gp := getg().m.curg
preemptible := flags&gcDrainUntilPreempt != 0
flushBgCredit := flags&gcDrainFlushBgCredit != 0
idle := flags&gcDrainIdle != 0
initScanWork := gcw.heapScanWork
// checkWork is the scan work before performing the next
// self-preempt check.
checkWork := int64(1<<63 - 1)
var check func() bool
if flags&(gcDrainIdle|gcDrainFractional) != 0 {
checkWork = initScanWork + drainCheckThreshold
if idle {
check = pollWork
} else if flags&gcDrainFractional != 0 {
check = pollFractionalWorkerExit
}
}
// Drain root marking jobs.
if work.markrootNext < work.markrootJobs {
// Stop if we're preemptible or if someone wants to STW.
for !(gp.preempt && (preemptible || sched.gcwaiting.Load())) {
job := atomic.Xadd(&work.markrootNext, +1) - 1
if job >= work.markrootJobs {
break
}
markroot(gcw, job, flushBgCredit)
if check != nil && check() {
goto done
}
}
}
// Drain heap marking jobs.
// Stop if we're preemptible or if someone wants to STW.
for !(gp.preempt && (preemptible || sched.gcwaiting.Load())) {
// Try to keep work available on the global queue. We used to
// check if there were waiting workers, but it's better to
// just keep work available than to make workers wait. In the
// worst case, we'll do O(log(_WorkbufSize)) unnecessary
// balances.
if work.full == 0 {
gcw.balance()
}
b := gcw.tryGetFast()
if b == 0 {
b = gcw.tryGet()
if b == 0 {
// Flush the write barrier
// buffer; this may create
// more work.
wbBufFlush()
b = gcw.tryGet()
}
}
if b == 0 {
// Unable to get work.
break
}
scanobject(b, gcw)
// Flush background scan work credit to the global
// account if we've accumulated enough locally so
// mutator assists can draw on it.
if gcw.heapScanWork >= gcCreditSlack {
gcController.heapScanWork.Add(gcw.heapScanWork)
if flushBgCredit {
gcFlushBgCredit(gcw.heapScanWork - initScanWork)
initScanWork = 0
}
checkWork -= gcw.heapScanWork
gcw.heapScanWork = 0
if checkWork <= 0 {
checkWork += drainCheckThreshold
if check != nil && check() {
break
}
}
}
}
done:
// Flush remaining scan work credit.
if gcw.heapScanWork > 0 {
gcController.heapScanWork.Add(gcw.heapScanWork)
if flushBgCredit {
gcFlushBgCredit(gcw.heapScanWork - initScanWork)
}
gcw.heapScanWork = 0
}
}
// gcDrainN blackens grey objects until it has performed roughly
// scanWork units of scan work or the G is preempted. This is
// best-effort, so it may perform less work if it fails to get a work
// buffer. Otherwise, it will perform at least scanWork units of work, but
// may perform more because scanning is always done in whole object
// increments. It returns the amount of scan work performed.
//
// The caller goroutine must be in a preemptible state (e.g.,
// _Gwaiting) to prevent deadlocks during stack scanning. As a
// consequence, this must be called on the system stack.
//
//go:nowritebarrier
//go:systemstack
func gcDrainN(gcw *gcWork, scanWork int64) int64 {
if !writeBarrier.needed {
throw("gcDrainN phase incorrect")
}
// There may already be scan work on the gcw, which we don't
// want to claim was done by this call.
workFlushed := -gcw.heapScanWork
// In addition to backing out because of a preemption, back out
// if the GC CPU limiter is enabled.
gp := getg().m.curg
for !gp.preempt && !gcCPULimiter.limiting() && workFlushed+gcw.heapScanWork < scanWork {
// See gcDrain comment.
if work.full == 0 {
gcw.balance()
}
b := gcw.tryGetFast()
if b == 0 {
b = gcw.tryGet()
if b == 0 {
// Flush the write barrier buffer;
// this may create more work.
wbBufFlush()
b = gcw.tryGet()
}
}
if b == 0 {
// Try to do a root job.
if work.markrootNext < work.markrootJobs {
job := atomic.Xadd(&work.markrootNext, +1) - 1
if job < work.markrootJobs {
workFlushed += markroot(gcw, job, false)
continue
}
}
// No heap or root jobs.
break
}
scanobject(b, gcw)
// Flush background scan work credit.
if gcw.heapScanWork >= gcCreditSlack {
gcController.heapScanWork.Add(gcw.heapScanWork)
workFlushed += gcw.heapScanWork
gcw.heapScanWork = 0
}
}
// Unlike gcDrain, there's no need to flush remaining work
// here because this never flushes to bgScanCredit and
// gcw.dispose will flush any remaining work to scanWork.
return workFlushed + gcw.heapScanWork
}
// scanblock scans b as scanobject would, but using an explicit
// pointer bitmap instead of the heap bitmap.
//
// This is used to scan non-heap roots, so it does not update
// gcw.bytesMarked or gcw.heapScanWork.
//
// If stk != nil, possible stack pointers are also reported to stk.putPtr.
//
//go:nowritebarrier
func scanblock(b0, n0 uintptr, ptrmask *uint8, gcw *gcWork, stk *stackScanState) {
// Use local copies of original parameters, so that a stack trace
// due to one of the throws below shows the original block
// base and extent.
b := b0
n := n0
for i := uintptr(0); i < n; {
// Find bits for the next word.
bits := uint32(*addb(ptrmask, i/(goarch.PtrSize*8)))
if bits == 0 {
i += goarch.PtrSize * 8
continue
}
for j := 0; j < 8 && i < n; j++ {
if bits&1 != 0 {
// Same work as in scanobject; see comments there.
p := *(*uintptr)(unsafe.Pointer(b + i))
if p != 0 {
if obj, span, objIndex := findObject(p, b, i); obj != 0 {
greyobject(obj, b, i, span, gcw, objIndex)
} else if stk != nil && p >= stk.stack.lo && p < stk.stack.hi {
stk.putPtr(p, false)
}
}
}
bits >>= 1
i += goarch.PtrSize
}
}
}
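// For illustration only: on a 64-bit system each ptrmask byte covers
// eight words (64 bytes of memory), least significant bit first. A mask
// byte of 0b00000101 for the block starting at b means
//
//	word 0 (b+0):  pointer    -> loaded and checked
//	word 1 (b+8):  no pointer -> skipped
//	word 2 (b+16): pointer    -> loaded and checked
//
// and an all-zero mask byte lets the loop above skip eight words at once.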
// scanobject scans the object starting at b, adding pointers to gcw.
// b must point to the beginning of a heap object or an oblet.
// scanobject consults the GC bitmap for the pointer mask and the
// spans for the size of the object.
//
//go:nowritebarrier
func scanobject(b uintptr, gcw *gcWork) {
// Prefetch object before we scan it.
//
// This will overlap fetching the beginning of the object with initial
// setup before we start scanning the object.
sys.Prefetch(b)
// Find the bits for b and the size of the object at b.
//
// b is either the beginning of an object, in which case this
// is the size of the object to scan, or it points to an
// oblet, in which case we compute the size to scan below.
s := spanOfUnchecked(b)
n := s.elemsize
if n == 0 {
throw("scanobject n == 0")
}
if s.spanclass.noscan() {
// Correctness-wise this is ok, but it's inefficient
// if noscan objects reach here.
throw("scanobject of a noscan object")
}
if n > maxObletBytes {
// Large object. Break into oblets for better
// parallelism and lower latency.
if b == s.base() {
// Enqueue the other oblets to scan later.
// Some oblets may be in b's scalar tail, but
// these will be marked as "no more pointers",
// so we'll drop out immediately when we go to
// scan those.
for oblet := b + maxObletBytes; oblet < s.base()+s.elemsize; oblet += maxObletBytes {
if !gcw.putFast(oblet) {
gcw.put(oblet)
}
}
}
// Compute the size of the oblet. Since this object
// must be a large object, s.base() is the beginning
// of the object.
n = s.base() + s.elemsize - b
if n > maxObletBytes {
n = maxObletBytes
}
}
hbits := heapBitsForAddr(b, n)
var scanSize uintptr
for {
var addr uintptr
if hbits, addr = hbits.nextFast(); addr == 0 {
if hbits, addr = hbits.next(); addr == 0 {
break
}
}
// Keep track of farthest pointer we found, so we can
// update heapScanWork. TODO: is there a better metric,
// now that we can skip scalar portions pretty efficiently?
scanSize = addr - b + goarch.PtrSize
// Work here is duplicated in scanblock and above.
// If you make changes here, make changes there too.
obj := *(*uintptr)(unsafe.Pointer(addr))
// At this point we have extracted the next potential pointer.
// Quickly filter out nil and pointers back to the current object.
if obj != 0 && obj-b >= n {
// Test if obj points into the Go heap and, if so,
// mark the object.
//
// Note that it's possible for findObject to
// fail if obj points to a just-allocated heap
// object because of a race with growing the
// heap. In this case, we know the object was
// just allocated and hence will be marked by
// allocation itself.
if obj, span, objIndex := findObject(obj, b, addr-b); obj != 0 {
greyobject(obj, b, addr-b, span, gcw, objIndex)
}
}
}
gcw.bytesMarked += uint64(n)
gcw.heapScanWork += int64(scanSize)
}
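// A worked example of the oblet splitting above (illustrative numbers,
// assuming the usual maxObletBytes of 128 KiB): scanning a 300 KiB
// object at base b first enqueues oblets at b+128K and b+256K, then
// scans [b, b+128K) itself. When the oblet at b+256K is later dequeued,
//
//	n = s.base() + s.elemsize - oblet = 300 KiB - 256 KiB = 44 KiB
//
// so scanning stops at the object's end rather than at a full oblet.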
// scanConservative scans block [b, b+n) conservatively, treating any
// pointer-like value in the block as a pointer.
//
// If ptrmask != nil, only words that are marked in ptrmask are
// considered as potential pointers.
//
// If state != nil, it's assumed that [b, b+n) is a block in the stack
// and may contain pointers to stack objects.
func scanConservative(b, n uintptr, ptrmask *uint8, gcw *gcWork, state *stackScanState) {
if debugScanConservative {
printlock()
print("conservatively scanning [", hex(b), ",", hex(b+n), ")\n")
hexdumpWords(b, b+n, func(p uintptr) byte {
if ptrmask != nil {
word := (p - b) / goarch.PtrSize
bits := *addb(ptrmask, word/8)
if (bits>>(word%8))&1 == 0 {
return '$'
}
}
val := *(*uintptr)(unsafe.Pointer(p))
if state != nil && state.stack.lo <= val && val < state.stack.hi {
return '@'
}
span := spanOfHeap(val)
if span == nil {
return ' '
}
idx := span.objIndex(val)
if span.isFree(idx) {
return ' '
}
return '*'
})
printunlock()
}
for i := uintptr(0); i < n; i += goarch.PtrSize {
if ptrmask != nil {
word := i / goarch.PtrSize
bits := *addb(ptrmask, word/8)
if bits == 0 {
// Skip 8 words (the loop increment will do the 8th)
//
// This must be the first time we've
// seen this word of ptrmask, so i
// must be 8-word-aligned, but check
// our reasoning just in case.
if i%(goarch.PtrSize*8) != 0 {
throw("misaligned mask")
}
i += goarch.PtrSize*8 - goarch.PtrSize
continue
}
if (bits>>(word%8))&1 == 0 {
continue
}
}
val := *(*uintptr)(unsafe.Pointer(b + i))
// Check if val points into the stack.
if state != nil && state.stack.lo <= val && val < state.stack.hi {
// val may point to a stack object. This
// object may be dead from last cycle and
// hence may contain pointers to unallocated
// objects, but unlike heap objects we can't
// tell if it's already dead. Hence, if all
// pointers to this object are from
// conservative scanning, we have to scan it
// defensively, too.
state.putPtr(val, true)
continue
}
// Check if val points to a heap span.
span := spanOfHeap(val)
if span == nil {
continue
}
// Check if val points to an allocated object.
idx := span.objIndex(val)
if span.isFree(idx) {
continue
}
// val points to an allocated object. Mark it.
obj := span.base() + idx*span.elemsize
greyobject(obj, b, i, span, gcw, idx)
}
}
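// For illustration: conservative scanning retains anything that merely
// looks like a pointer. If a dead stack slot still holds a stale value,
// say 0xc000041000, and that address happens to fall within an in-use
// span's allocated slot, greyobject keeps the object alive for this
// cycle; precise scanning (scanblock) would have consulted the ptrmask
// and ignored the slot entirely.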
// Shade the object if it isn't already.
// The object is not nil and known to be in the heap.
// Preemption must be disabled.
//
//go:nowritebarrier
func shade(b uintptr) {
if obj, span, objIndex := findObject(b, 0, 0); obj != 0 {
gcw := &getg().m.p.ptr().gcw
greyobject(obj, 0, 0, span, gcw, objIndex)
}
}
// obj is the start of an object with mark mbits.
// If it isn't already marked, mark it and enqueue into gcw.
// base and off are for debugging only and could be removed.
//
// See also wbBufFlush1, which partially duplicates this logic.
//
//go:nowritebarrierrec
func greyobject(obj, base, off uintptr, span *mspan, gcw *gcWork, objIndex uintptr) {
// obj should be start of allocation, and so must be at least pointer-aligned.
if obj&(goarch.PtrSize-1) != 0 {
throw("greyobject: obj not pointer-aligned")
}
mbits := span.markBitsForIndex(objIndex)
if useCheckmark {
if setCheckmark(obj, base, off, mbits) {
// Already marked.
return
}
} else {
if debug.gccheckmark > 0 && span.isFree(objIndex) {
print("runtime: marking free object ", hex(obj), " found at *(", hex(base), "+", hex(off), ")\n")
gcDumpObject("base", base, off)
gcDumpObject("obj", obj, ^uintptr(0))
getg().m.traceback = 2
throw("marking free object")
}
// If marked we have nothing to do.
if mbits.isMarked() {
return
}
mbits.setMarked()
// Mark span.
arena, pageIdx, pageMask := pageIndexOf(span.base())
if arena.pageMarks[pageIdx]&pageMask == 0 {
atomic.Or8(&arena.pageMarks[pageIdx], pageMask)
}
// If this is a noscan object, fast-track it to black
// instead of greying it.
if span.spanclass.noscan() {
gcw.bytesMarked += uint64(span.elemsize)
return
}
}
// We're adding obj to P's local workbuf, so it's likely
// this object will be processed soon by the same P.
// Even if the workbuf gets flushed, there will likely still be
// some benefit on platforms with inclusive shared caches.
sys.Prefetch(obj)
// Queue the obj for scanning.
if !gcw.putFast(obj) {
gcw.put(obj)
}
}
// gcDumpObject dumps the contents of obj for debugging and marks the
// field at byte offset off in obj.
func gcDumpObject(label string, obj, off uintptr) {
s := spanOf(obj)
print(label, "=", hex(obj))
if s == nil {
print(" s=nil\n")
return
}
print(" s.base()=", hex(s.base()), " s.limit=", hex(s.limit), " s.spanclass=", s.spanclass, " s.elemsize=", s.elemsize, " s.state=")
if state := s.state.get(); 0 <= state && int(state) < len(mSpanStateNames) {
print(mSpanStateNames[state], "\n")
} else {
print("unknown(", state, ")\n")
}
skipped := false
size := s.elemsize
if s.state.get() == mSpanManual && size == 0 {
// We're printing something from a stack frame. We
// don't know how big it is, so just show up to and
// including off.
size = off + goarch.PtrSize
}
for i := uintptr(0); i < size; i += goarch.PtrSize {
// For big objects, just print the beginning (because
// that usually hints at the object's type) and the
// fields around off.
if !(i < 128*goarch.PtrSize || off-16*goarch.PtrSize < i && i < off+16*goarch.PtrSize) {
skipped = true
continue
}
if skipped {
print(" ...\n")
skipped = false
}
print(" *(", label, "+", i, ") = ", hex(*(*uintptr)(unsafe.Pointer(obj + i))))
if i == off {
print(" <==")
}
print("\n")
}
if skipped {
print(" ...\n")
}
}
// gcmarknewobject marks a newly allocated object black. obj must
// not contain any non-nil pointers.
//
// This is nosplit so it can manipulate a gcWork without preemption.
//
//go:nowritebarrier
//go:nosplit
func gcmarknewobject(span *mspan, obj, size uintptr) {
if useCheckmark { // The world should be stopped so this should not happen.
throw("gcmarknewobject called while doing checkmark")
}
// Mark object.
objIndex := span.objIndex(obj)
span.markBitsForIndex(objIndex).setMarked()
// Mark span.
arena, pageIdx, pageMask := pageIndexOf(span.base())
if arena.pageMarks[pageIdx]&pageMask == 0 {
atomic.Or8(&arena.pageMarks[pageIdx], pageMask)
}
gcw := &getg().m.p.ptr().gcw
gcw.bytesMarked += uint64(size)
}
// gcMarkTinyAllocs greys all active tiny alloc blocks.
//
// The world must be stopped.
func gcMarkTinyAllocs() {
assertWorldStopped()
for _, p := range allp {
c := p.mcache
if c == nil || c.tiny == 0 {
continue
}
_, span, objIndex := findObject(c.tiny, 0, 0)
gcw := &p.gcw
greyobject(c.tiny, 0, 0, span, gcw, objIndex)
}
}
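// Context for the above (a summary, not new behavior): the tiny
// allocator packs multiple small noscan allocations into a single
// 16-byte block, and c.tiny points at the block currently being filled.
// Because later allocations may still land in that block, it must be
// greyed here even though no scanned object points to it yet. For
// example, two 8-byte noscan allocations on the same P share one tiny
// block, and greying that block keeps both alive.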
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"internal/cpu"
"internal/goexperiment"
"runtime/internal/atomic"
_ "unsafe" // for go:linkname
)
const (
// gcGoalUtilization is the goal CPU utilization for
// marking as a fraction of GOMAXPROCS.
//
// Increasing the goal utilization will shorten GC cycles as the GC
// has more resources behind it, lessening costs from the write barrier,
// but comes at the cost of increasing mutator latency.
gcGoalUtilization = gcBackgroundUtilization
// gcBackgroundUtilization is the fixed CPU utilization for background
// marking. It must be <= gcGoalUtilization. The difference between
// gcGoalUtilization and gcBackgroundUtilization will be made up by
// mark assists. The scheduler will aim to use within 50% of this
// goal.
//
// As a general rule, there's little reason to set gcBackgroundUtilization
// < gcGoalUtilization. One reason might be in mostly idle applications,
// where goroutines are unlikely to assist at all, so the actual
// utilization will be lower than the goal. But this is a moot point
// because the idle mark workers already soak up idle CPU resources.
// These two values are still kept separate however because they are
// distinct conceptually, and in previous iterations of the pacer the
// distinction was more important.
gcBackgroundUtilization = 0.25
// gcCreditSlack is the amount of scan work credit that can
// accumulate locally before updating gcController.heapScanWork and,
// optionally, gcController.bgScanCredit. Lower values give a more
// accurate assist ratio and make it more likely that assists will
// successfully steal background credit. Higher values reduce memory
// contention.
gcCreditSlack = 2000
// gcAssistTimeSlack is the nanoseconds of mutator assist time that
// can accumulate on a P before updating gcController.assistTime.
gcAssistTimeSlack = 5000
// gcOverAssistWork determines how many extra units of scan work a GC
// assist does when an assist happens. This amortizes the cost of an
// assist by pre-paying for this many bytes of future allocations.
gcOverAssistWork = 64 << 10
// defaultHeapMinimum is the value of heapMinimum for GOGC==100.
defaultHeapMinimum = (goexperiment.HeapMinimum512KiBInt)*(512<<10) +
(1-goexperiment.HeapMinimum512KiBInt)*(4<<20)
// maxStackScanSlack is the bytes of stack space allocated or freed
// that can accumulate on a P before updating gcController.stackSize.
maxStackScanSlack = 8 << 10
// memoryLimitHeapGoalHeadroom is the amount of headroom the pacer gives to
// the heap goal when operating in the memory-limited regime. That is,
// it'll reduce the heap goal by this many extra bytes off of the base
// calculation.
memoryLimitHeapGoalHeadroom = 1 << 20
)
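// For illustration, the defaultHeapMinimum expression above is a
// branch-free select on the experiment flag: with
// goexperiment.HeapMinimum512KiBInt == 1 it evaluates to
//
//	1*(512<<10) + 0*(4<<20) = 512 KiB
//
// and with the flag at 0 it evaluates to 4 MiB.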
// gcController implements the GC pacing controller that determines
// when to trigger concurrent garbage collection and how much marking
// work to do in mutator assists and background marking.
//
// It calculates the ratio between the allocation rate (in terms of CPU
// time) and the GC scan throughput to determine the heap size at which to
// trigger a GC cycle such that no GC assists are required to finish on time.
// This algorithm thus optimizes GC CPU utilization to the dedicated background
// mark utilization of 25% of GOMAXPROCS by minimizing GC assists.
// The high-level design of this algorithm is documented
// at https://github.com/golang/proposal/blob/master/design/44167-gc-pacer-redesign.md.
// See https://golang.org/s/go15gcpacing for additional historical context.
var gcController gcControllerState
type gcControllerState struct {
// Initialized from GOGC. GOGC=off means no GC.
gcPercent atomic.Int32
// memoryLimit is the soft memory limit in bytes.
//
// Initialized from GOMEMLIMIT. GOMEMLIMIT=off is equivalent to MaxInt64
// which means no soft memory limit in practice.
//
// This is an int64 instead of a uint64 to more easily maintain parity with
// the SetMemoryLimit API, which sets a maximum at MaxInt64. This value
// should never be negative.
memoryLimit atomic.Int64
// heapMinimum is the minimum heap size at which to trigger GC.
// For small heaps, this overrides the usual GOGC*live set rule.
//
// When there is a very small live set but a lot of allocation, simply
// collecting when the heap reaches GOGC*live results in many GC
// cycles and high total per-GC overhead. This minimum amortizes this
// per-GC overhead while keeping the heap reasonably small.
//
// During initialization this is set to 4MB*GOGC/100. In the case of
// GOGC==0, this will set heapMinimum to 0, resulting in constant
// collection even when the heap size is small, which is useful for
// debugging.
heapMinimum uint64
// runway is the amount of runway in heap bytes allocated by the
// application that we want to give the GC once it starts.
//
// This is computed from consMark during mark termination.
runway atomic.Uint64
// consMark is the estimated per-CPU consMark ratio for the application.
//
// It represents the ratio between the application's allocation
// rate, as bytes allocated per CPU-time, and the GC's scan rate,
// as bytes scanned per CPU-time.
// The units of this ratio are (B / cpu-ns) / (B / cpu-ns).
//
// At a high level, this value is computed as the bytes of memory
// allocated (cons) per unit of scan work completed (mark) in a GC
// cycle, divided by the CPU time spent on each activity.
//
// Updated at the end of each GC cycle, in endCycle.
consMark float64
// lastConsMark is the computed cons/mark value for the previous GC
// cycle. Note that this is *not* the last value of cons/mark, but the
// actual computed value. See endCycle for details.
lastConsMark float64
// gcPercentHeapGoal is the goal heapLive for when next GC ends derived
// from gcPercent.
//
// Set to ^uint64(0) if gcPercent is disabled.
gcPercentHeapGoal atomic.Uint64
// sweepDistMinTrigger is the minimum trigger to ensure a minimum
// sweep distance.
//
// This bound is also special because it applies to both the trigger
// *and* the goal (all other trigger bounds must be based *on* the goal).
//
// It is computed ahead of time, at commit time. The theory is that,
// absent a sudden change to a parameter like gcPercent, the trigger
// will be chosen to always give the sweeper enough headroom. However,
// such a change might dramatically and suddenly move up the trigger,
// in which case we need to ensure the sweeper still has enough headroom.
sweepDistMinTrigger atomic.Uint64
// triggered is the point at which the current GC cycle actually triggered.
// Only valid during the mark phase of a GC cycle, otherwise set to ^uint64(0).
//
// Updated while the world is stopped.
triggered uint64
// lastHeapGoal is the value of heapGoal at the moment the last GC
// ended. Note that this is distinct from the last value heapGoal had,
// because it could change if e.g. gcPercent changes.
//
// Read and written with the world stopped or with mheap_.lock held.
lastHeapGoal uint64
// heapLive is the number of bytes considered live by the GC.
// That is: retained by the most recent GC plus allocated
// since then. heapLive ≤ memstats.totalAlloc-memstats.totalFree, since
// heapAlloc includes unmarked objects that have not yet been swept (and
// hence goes up as we allocate and down as we sweep) while heapLive
// excludes these objects (and hence only goes up between GCs).
//
// To reduce contention, this is updated only when obtaining a span
// from an mcentral and at this point it counts all of the unallocated
// slots in that span (which will be allocated before that mcache
// obtains another span from that mcentral). Hence, it slightly
// overestimates the "true" live heap size. It's better to overestimate
// than to underestimate because 1) this triggers the GC earlier than
// necessary rather than potentially too late and 2) this leads to a
// conservative GC rate rather than a GC rate that is potentially too
// low.
//
// Whenever this is updated, call traceHeapAlloc() and
// this gcControllerState's revise() method.
heapLive atomic.Uint64
// heapScan is the number of bytes of "scannable" heap. This is the
// live heap (as counted by heapLive), but omitting no-scan objects and
// no-scan tails of objects.
//
// This value is fixed at the start of a GC cycle. It represents the
// maximum scannable heap.
heapScan atomic.Uint64
// lastHeapScan is the number of bytes of heap that were scanned
// last GC cycle. It is the same as heapMarked, but only
// includes the "scannable" parts of objects.
//
// Updated when the world is stopped.
lastHeapScan uint64
// lastStackScan is the number of bytes of stack that were scanned
// last GC cycle.
lastStackScan atomic.Uint64
// maxStackScan is the amount of allocated goroutine stack space in
// use by goroutines.
//
// This number tracks allocated goroutine stack space rather than used
// goroutine stack space (i.e. what is actually scanned) because used
// goroutine stack space is much harder to measure cheaply. By using
// allocated space, we make an overestimate; this is OK, it's better
// to conservatively overcount than undercount.
maxStackScan atomic.Uint64
// globalsScan is the total amount of global variable space
// that is scannable.
globalsScan atomic.Uint64
// heapMarked is the number of bytes marked by the previous
// GC. After mark termination, heapLive == heapMarked, but
// unlike heapLive, heapMarked does not change until the
// next mark termination.
heapMarked uint64
// heapScanWork is the total heap scan work performed this cycle.
// stackScanWork is the total stack scan work performed this cycle.
// globalsScanWork is the total globals scan work performed this cycle.
//
// These are updated atomically during the cycle. Updates occur in
// bounded batches, since they are both written and read
// throughout the cycle. At the end of the cycle, heapScanWork is how
// much of the retained heap is scannable.
//
// Currently these are measured in bytes. For most uses, this is an
// opaque unit of work, but for estimation the definition is important.
//
// Note that stackScanWork includes only stack space scanned, not all
// of the allocated stack.
heapScanWork atomic.Int64
stackScanWork atomic.Int64
globalsScanWork atomic.Int64
// bgScanCredit is the scan work credit accumulated by the concurrent
// background scan. This credit is accumulated by the background scan
// and stolen by mutator assists. Updates occur in bounded batches,
// since it is both written and read throughout the cycle.
bgScanCredit atomic.Int64
// assistTime is the nanoseconds spent in mutator assists
// during this cycle. This is updated atomically, and must also
// be updated atomically even during a STW, because it is read
// by sysmon. Updates occur in bounded batches, since it is both
// written and read throughout the cycle.
assistTime atomic.Int64
// dedicatedMarkTime is the nanoseconds spent in dedicated mark workers
// during this cycle. This is updated at the end of the concurrent mark
// phase.
dedicatedMarkTime atomic.Int64
// fractionalMarkTime is the nanoseconds spent in the fractional mark
// worker during this cycle. This is updated throughout the cycle and
// will be up-to-date if the fractional mark worker is not currently
// running.
fractionalMarkTime atomic.Int64
// idleMarkTime is the nanoseconds spent in idle marking during this
// cycle. This is updated throughout the cycle.
idleMarkTime atomic.Int64
// markStartTime is the absolute start time in nanoseconds
// that assists and background mark workers started.
markStartTime int64
// dedicatedMarkWorkersNeeded is the number of dedicated mark workers
// that need to be started. This is computed at the beginning of each
// cycle and decremented as dedicated mark workers get started.
dedicatedMarkWorkersNeeded atomic.Int64
// idleMarkWorkers is two packed int32 values in a single uint64.
// These two values are always updated simultaneously.
//
// The bottom int32 is the current number of idle mark workers executing.
//
// The top int32 is the maximum number of idle mark workers allowed to
// execute concurrently. Normally, this number is just gomaxprocs. However,
// during periodic GC cycles it is set to 0 because the system is idle
// anyway; there's no need to go full blast on all of GOMAXPROCS.
//
// The maximum number of idle mark workers is used to prevent new workers
// from starting, but it is not a hard maximum. It is possible (but
// exceedingly rare) for the current number of idle mark workers to
// transiently exceed the maximum. This could happen if the maximum changes
// just after a GC ends, while an M with no P is concurrently claiming an
// idle worker slot.
//
// Note that if we have no dedicated mark workers, we set this value to
// 1, because in that case we only have fractional GC workers, which
// aren't scheduled strictly enough to ensure GC progress. As a result,
// idle-priority mark workers are vital to GC progress in these situations.
//
// For example, consider a situation in which goroutines block on the GC
// (such as via runtime.GOMAXPROCS) and only fractional mark workers are
// scheduled (e.g. GOMAXPROCS=1). Without idle-priority mark workers, the
// last running M might skip scheduling a fractional mark worker if its
// utilization goal is met, such that once it goes to sleep (because there's
// nothing to do), there will be nothing else to spin up a new M for the
// fractional worker in the future, stalling GC progress and causing a
// deadlock. However, idle-priority workers will *always* run when there is
// nothing left to do, ensuring the GC makes progress.
//
// See github.com/golang/go/issues/44163 for more details.
idleMarkWorkers atomic.Uint64
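// A sketch of how such a packed value is split and updated (illustrative
// only, not the runtime's actual accessor methods):
//
//	v := idleMarkWorkers.Load()
//	count := int32(v)       // bottom 32 bits: current idle workers
//	max := int32(v >> 32)   // top 32 bits: maximum allowed
//
// Updates write both halves with a single compare-and-swap so the pair
// is always observed consistently.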
// assistWorkPerByte is the ratio of scan work to allocated
// bytes that should be performed by mutator assists. This is
// computed at the beginning of each cycle and updated every
// time heapScan is updated.
assistWorkPerByte atomic.Float64
// assistBytesPerWork is 1/assistWorkPerByte.
//
// Note that because this is read and written independently
// from assistWorkPerByte users may notice a skew between
// the two values, and such a state should be safe.
assistBytesPerWork atomic.Float64
// fractionalUtilizationGoal is the fraction of wall clock
// time that should be spent in the fractional mark worker on
// each P that isn't running a dedicated worker.
//
// For example, if the utilization goal is 25% and there are
// no dedicated workers, this will be 0.25. If the goal is
// 25%, there is one dedicated worker, and GOMAXPROCS is 5,
// this will be 0.05 to make up the missing 5%.
//
// If this is zero, no fractional workers are needed.
fractionalUtilizationGoal float64
// These memory stats are effectively duplicates of fields from
// memstats.heapStats but are updated atomically or with the world
// stopped and don't provide the same consistency guarantees.
//
// Because the runtime is responsible for managing a memory limit, it's
// useful to couple these stats more tightly to the gcController, which
// is intimately connected to how that memory limit is maintained.
heapInUse sysMemStat // bytes in mSpanInUse spans
heapReleased sysMemStat // bytes released to the OS
heapFree sysMemStat // bytes not in any span, but not released to the OS
totalAlloc atomic.Uint64 // total bytes allocated
totalFree atomic.Uint64 // total bytes freed
mappedReady atomic.Uint64 // total virtual memory in the Ready state (see mem.go).
// test indicates that this is a test-only copy of gcControllerState.
test bool
_ cpu.CacheLinePad
}
func (c *gcControllerState) init(gcPercent int32, memoryLimit int64) {
c.heapMinimum = defaultHeapMinimum
c.triggered = ^uint64(0)
c.setGCPercent(gcPercent)
c.setMemoryLimit(memoryLimit)
c.commit(true) // No sweep phase in the first GC cycle.
// N.B. Don't bother calling traceHeapGoal. Tracing is never enabled at
// initialization time.
// N.B. No need to call revise; there's no GC enabled during
// initialization.
}
// startCycle resets the GC controller's state and computes estimates
// for a new GC cycle. The caller must hold worldsema and the world
// must be stopped.
func (c *gcControllerState) startCycle(markStartTime int64, procs int, trigger gcTrigger) {
c.heapScanWork.Store(0)
c.stackScanWork.Store(0)
c.globalsScanWork.Store(0)
c.bgScanCredit.Store(0)
c.assistTime.Store(0)
c.dedicatedMarkTime.Store(0)
c.fractionalMarkTime.Store(0)
c.idleMarkTime.Store(0)
c.markStartTime = markStartTime
c.triggered = c.heapLive.Load()
// Compute the background mark utilization goal. In general,
// this may not come out exactly. We round the number of
// dedicated workers so that the utilization is closest to
// 25%. For small GOMAXPROCS, this would introduce too much
// error, so we add fractional workers in that case.
totalUtilizationGoal := float64(procs) * gcBackgroundUtilization
dedicatedMarkWorkersNeeded := int64(totalUtilizationGoal + 0.5)
utilError := float64(dedicatedMarkWorkersNeeded)/totalUtilizationGoal - 1
const maxUtilError = 0.3
if utilError < -maxUtilError || utilError > maxUtilError {
// Rounding put us more than 30% off our goal. With
// gcBackgroundUtilization of 25%, this happens for
// GOMAXPROCS<=3 or GOMAXPROCS=6. Enable fractional
// workers to compensate.
if float64(dedicatedMarkWorkersNeeded) > totalUtilizationGoal {
// Too many dedicated workers.
dedicatedMarkWorkersNeeded--
}
c.fractionalUtilizationGoal = (totalUtilizationGoal - float64(dedicatedMarkWorkersNeeded)) / float64(procs)
} else {
c.fractionalUtilizationGoal = 0
}
// In STW mode, we just want dedicated workers.
if debug.gcstoptheworld > 0 {
dedicatedMarkWorkersNeeded = int64(procs)
c.fractionalUtilizationGoal = 0
}
// Clear per-P state
for _, p := range allp {
p.gcAssistTime = 0
p.gcFractionalMarkTime = 0
}
if trigger.kind == gcTriggerTime {
// During a periodic GC cycle, reduce the number of idle mark workers
// required. However, we need at least one dedicated mark worker or
// idle GC worker to ensure GC progress in some scenarios (see comment
// on maxIdleMarkWorkers).
if dedicatedMarkWorkersNeeded > 0 {
c.setMaxIdleMarkWorkers(0)
} else {
// TODO(mknyszek): The fundamental reason why we need this is because
// we can't count on the fractional mark worker to get scheduled.
// Fix that by ensuring it gets scheduled according to its quota even
// if the rest of the application is idle.
c.setMaxIdleMarkWorkers(1)
}
} else {
// N.B. gomaxprocs and dedicatedMarkWorkersNeeded are guaranteed not to
// change during a GC cycle.
c.setMaxIdleMarkWorkers(int32(procs) - int32(dedicatedMarkWorkersNeeded))
}
// Compute initial values for controls that are updated
// throughout the cycle.
c.dedicatedMarkWorkersNeeded.Store(dedicatedMarkWorkersNeeded)
c.revise()
if debug.gcpacertrace > 0 {
heapGoal := c.heapGoal()
assistRatio := c.assistWorkPerByte.Load()
print("pacer: assist ratio=", assistRatio,
" (scan ", gcController.heapScan.Load()>>20, " MB in ",
work.initialHeapLive>>20, "->",
heapGoal>>20, " MB)",
" workers=", dedicatedMarkWorkersNeeded,
"+", c.fractionalUtilizationGoal, "\n")
}
}
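// Two worked examples of the rounding above (illustrative): with
// procs = 4 the goal is 4*0.25 = 1.0, so int64(1.0+0.5) = 1 dedicated
// worker lands exactly on target and no fractional worker is needed.
// With procs = 6 the goal is 1.5 and int64(1.5+0.5) = 2 overshoots by
// about 33% (beyond maxUtilError), so one dedicated worker is dropped
// and
//
//	fractionalUtilizationGoal = (1.5 - 1) / 6 = 0.0833...
//
// makes up the missing half worker as roughly 8.3% of each P's time.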
// revise updates the assist ratio during the GC cycle to account for
// improved estimates. This should be called whenever gcController.heapScan,
// gcController.heapLive, or any of the inputs to gcController.heapGoal are
// updated. It is safe to call concurrently, but it may race with other
// calls to revise.
//
// The result of this race is that the two assist ratio values may not line
// up or may be stale. In practice this is OK because the assist ratio
// moves slowly throughout a GC cycle, and the assist ratio is a best-effort
// heuristic anyway. Furthermore, no part of the heuristic depends on
// the two assist ratio values being exact reciprocals of one another, since
// the two values are used to convert values from different sources.
//
// The worst case result of this raciness is that we may miss a larger shift
// in the ratio (say, if we decide to pace more aggressively against the
// hard heap goal) but even this "hard goal" is best-effort (see #40460).
// The dedicated GC should ensure we don't exceed the hard goal by too much
// in the rare case we do exceed it.
//
// It should only be called when gcBlackenEnabled != 0 (because this
// is when assists are enabled and the necessary statistics are
// available).
func (c *gcControllerState) revise() {
gcPercent := c.gcPercent.Load()
if gcPercent < 0 {
// If GC is disabled but we're running a forced GC,
// act like GOGC is huge for the below calculations.
gcPercent = 100000
}
live := c.heapLive.Load()
scan := c.heapScan.Load()
work := c.heapScanWork.Load() + c.stackScanWork.Load() + c.globalsScanWork.Load()
// Assume we're under the soft goal. Pace GC to complete at
// heapGoal assuming the heap is in steady-state.
heapGoal := int64(c.heapGoal())
// The expected scan work is computed as the amount of bytes scanned last
// GC cycle (both heap and stack), plus our estimate of globals work for this cycle.
scanWorkExpected := int64(c.lastHeapScan + c.lastStackScan.Load() + c.globalsScan.Load())
// maxScanWork is a worst-case estimate of the amount of scan work that
// needs to be performed in this GC cycle. Specifically, it represents
// the case where *all* scannable memory turns out to be live, and
// *all* allocated stack space is scannable.
maxStackScan := c.maxStackScan.Load()
maxScanWork := int64(scan + maxStackScan + c.globalsScan.Load())
if work > scanWorkExpected {
// We've already done more scan work than expected. Because our expectation
// is based on a steady-state scannable heap size, we assume this means our
// heap is growing. Compute a new heap goal that takes our existing runway
// computed for scanWorkExpected and extrapolates it to maxScanWork, the worst-case
// scan work. This keeps our assist ratio stable if the heap continues to grow.
//
// The effect of this mechanism is that assists stay flat in the face of heap
// growths. It's OK to use more memory this cycle to scan all the live heap,
// because the next GC cycle is inevitably going to use *at least* that much
// memory anyway.
extHeapGoal := int64(float64(heapGoal-int64(c.triggered))/float64(scanWorkExpected)*float64(maxScanWork)) + int64(c.triggered)
scanWorkExpected = maxScanWork
// hardGoal is a hard limit on the amount that we're willing to push back the
// heap goal, and that's twice the heap goal (i.e. if GOGC=100 and the heap and/or
// stacks and/or globals grow to twice their size, this limits the current GC cycle's
// growth to 4x the original live heap's size).
//
// This maintains the invariant that we use no more memory than the next GC cycle
// will anyway.
hardGoal := int64((1.0 + float64(gcPercent)/100.0) * float64(heapGoal))
if extHeapGoal > hardGoal {
extHeapGoal = hardGoal
}
heapGoal = extHeapGoal
}
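// Illustrative numbers (hypothetical, not from the source): with
// heapGoal = 100 MiB, c.triggered = 60 MiB, scanWorkExpected = 20 MiB, and
// maxScanWork = 40 MiB, the 40 MiB of runway is scaled by 40/20 = 2x,
// giving extHeapGoal = 60 + 80 = 140 MiB. With GOGC=100 the hard goal is
// 2*100 = 200 MiB, so the extrapolated goal stands.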
if int64(live) > heapGoal {
// We're already past our heap goal, even the extrapolated one.
// Leave ourselves some extra runway, so in the worst case we
// finish by that point.
const maxOvershoot = 1.1
heapGoal = int64(float64(heapGoal) * maxOvershoot)
// Compute the upper bound on the scan work remaining.
scanWorkExpected = maxScanWork
}
// Compute the remaining scan work estimate.
//
// Note that we currently count allocations during GC as both
// scannable heap (heapScan) and scan work completed
// (scanWork), so allocation will change this difference
// slowly in the soft regime and not at all in the hard
// regime.
scanWorkRemaining := scanWorkExpected - work
if scanWorkRemaining < 1000 {
// We set a somewhat arbitrary lower bound on
// remaining scan work since if we aim a little high,
// we can miss by a little.
//
// We *do* need to enforce that this is at least 1,
// since marking is racy and double-scanning objects
// may legitimately make the remaining scan work
// negative, even in the hard goal regime.
scanWorkRemaining = 1000
}
// Compute the heap distance remaining.
heapRemaining := heapGoal - int64(live)
if heapRemaining <= 0 {
// This shouldn't happen, but if it does, avoid
// dividing by zero or setting the assist negative.
heapRemaining = 1
}
// Compute the mutator assist ratio so by the time the mutator
// allocates the remaining heap bytes up to heapGoal, it will
// have done (or stolen) the remaining amount of scan work.
// Note that the assist ratio values are updated atomically
// but not together. This means there may be some degree of
// skew between the two values. This is generally OK as the
// values shift relatively slowly over the course of a GC
// cycle.
assistWorkPerByte := float64(scanWorkRemaining) / float64(heapRemaining)
assistBytesPerWork := float64(heapRemaining) / float64(scanWorkRemaining)
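// For a sense of scale (hypothetical numbers): scanWorkRemaining = 4 MiB
// of scan work and heapRemaining = 16 MiB of runway give
// assistWorkPerByte = 0.25 and assistBytesPerWork = 4, i.e. a mutator
// that allocates 1 MiB before the goal is reached owes 256 KiB of scan
// work.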
c.assistWorkPerByte.Store(assistWorkPerByte)
c.assistBytesPerWork.Store(assistBytesPerWork)
}
// endCycle computes the consMark estimate for the next cycle.
// userForced indicates whether the current GC cycle was forced
// by the application.
func (c *gcControllerState) endCycle(now int64, procs int, userForced bool) {
// Record last heap goal for the scavenger.
// We'll be updating the heap goal soon.
gcController.lastHeapGoal = c.heapGoal()
// Compute the duration of time for which assists were turned on.
assistDuration := now - c.markStartTime
// Assume background mark hit its utilization goal.
utilization := gcBackgroundUtilization
// Add assist utilization; avoid divide by zero.
if assistDuration > 0 {
utilization += float64(c.assistTime.Load()) / float64(assistDuration*int64(procs))
}
if c.heapLive.Load() <= c.triggered {
// Shouldn't happen, but let's be very safe about this in case the
// GC is somehow extremely short.
//
// In this case though, the only reasonable value for c.heapLive-c.triggered
// would be 0, which isn't really all that useful, i.e. the GC was so short
// that it didn't matter.
//
// Ignore this case and don't update anything.
return
}
idleUtilization := 0.0
if assistDuration > 0 {
idleUtilization = float64(c.idleMarkTime.Load()) / float64(assistDuration*int64(procs))
}
// Determine the cons/mark ratio.
//
// The units we want for the numerator and denominator are both B / cpu-ns.
// We get this by taking the bytes allocated or scanned, and divide by the amount of
// CPU time it took for those operations. For allocations, that CPU time is
//
// assistDuration * procs * (1 - utilization)
//
// Where utilization includes just background GC workers and assists. It does *not*
// include idle GC work time, because in theory the mutator is free to take that at
// any point.
//
// For scanning, that CPU time is
//
// assistDuration * procs * (utilization + idleUtilization)
//
// In this case, we *include* idle utilization, because that is additional CPU time that
// the GC had available to it.
//
// In effect, idle GC time is sort of double-counted here, but it's very weird compared
// to other kinds of GC work, because of how fluid it is. Namely, because the mutator is
// *always* free to take it.
//
// So this calculation is really:
//
// ((heapLive-trigger) / (assistDuration * procs * (1-utilization))) /
// (scanWork / (assistDuration * procs * (utilization+idleUtilization)))
//
// Note that because we only care about the ratio, assistDuration and procs cancel out.
scanWork := c.heapScanWork.Load() + c.stackScanWork.Load() + c.globalsScanWork.Load()
currentConsMark := (float64(c.heapLive.Load()-c.triggered) * (utilization + idleUtilization)) /
(float64(scanWork) * (1 - utilization))
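// Worked example (hypothetical numbers): if the mutator allocated 64 MiB
// between the trigger and now, scanWork = 32 MiB, utilization = 0.25, and
// idleUtilization = 0.05, then
// currentConsMark = (64 * (0.25+0.05)) / (32 * (1-0.25)) = 19.2/24 = 0.8,
// i.e. allocation proceeded at 0.8x the rate of scanning per unit of CPU
// time.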
// Update our cons/mark estimate. This is the raw value above, but averaged over 2 GC cycles
// because it tends to be jittery, even in the steady-state. The smoothing helps the GC to
// maintain much more stable cycle-by-cycle behavior.
oldConsMark := c.consMark
c.consMark = (currentConsMark + c.lastConsMark) / 2
c.lastConsMark = currentConsMark
if debug.gcpacertrace > 0 {
printlock()
goal := gcGoalUtilization * 100
print("pacer: ", int(utilization*100), "% CPU (", int(goal), " exp.) for ")
print(c.heapScanWork.Load(), "+", c.stackScanWork.Load(), "+", c.globalsScanWork.Load(), " B work (", c.lastHeapScan+c.lastStackScan.Load()+c.globalsScan.Load(), " B exp.) ")
live := c.heapLive.Load()
print("in ", c.triggered, " B -> ", live, " B (∆goal ", int64(live)-int64(c.lastHeapGoal), ", cons/mark ", oldConsMark, ")")
println()
printunlock()
}
}
// enlistWorker encourages another dedicated mark worker to start on
// another P if there are spare worker slots. It is used by putfull
// when more work is made available.
//
//go:nowritebarrier
func (c *gcControllerState) enlistWorker() {
// If there are idle Ps, wake one so it will run an idle worker.
// NOTE: This is suspected of causing deadlocks. See golang.org/issue/19112.
//
// if sched.npidle.Load() != 0 && sched.nmspinning.Load() == 0 {
// wakep()
// return
// }
// There are no idle Ps. If we need more dedicated workers,
// try to preempt a running P so it will switch to a worker.
if c.dedicatedMarkWorkersNeeded.Load() <= 0 {
return
}
// Pick a random other P to preempt.
if gomaxprocs <= 1 {
return
}
gp := getg()
if gp == nil || gp.m == nil || gp.m.p == 0 {
return
}
myID := gp.m.p.ptr().id
for tries := 0; tries < 5; tries++ {
id := int32(fastrandn(uint32(gomaxprocs - 1)))
if id >= myID {
id++
}
p := allp[id]
if p.status != _Prunning {
continue
}
if preemptone(p) {
return
}
}
}
// findRunnableGCWorker returns a background mark worker for pp if it
// should be run. This must only be called when gcBlackenEnabled != 0.
func (c *gcControllerState) findRunnableGCWorker(pp *p, now int64) (*g, int64) {
if gcBlackenEnabled == 0 {
throw("gcControllerState.findRunnable: blackening not enabled")
}
// Since we have the current time, check if the GC CPU limiter
// hasn't had an update in a while. This check is necessary in
// case the limiter is on but hasn't been checked in a while and
// so may have left sufficient headroom to turn off again.
if now == 0 {
now = nanotime()
}
if gcCPULimiter.needUpdate(now) {
gcCPULimiter.update(now)
}
if !gcMarkWorkAvailable(pp) {
// No work to be done right now. This can happen at
// the end of the mark phase when there are still
// assists tapering off. Don't bother running a worker
// now because it'll just return immediately.
return nil, now
}
// Grab a worker before we commit to running below.
node := (*gcBgMarkWorkerNode)(gcBgMarkWorkerPool.pop())
if node == nil {
// There is at least one worker per P, so normally there are
// enough workers to run on all Ps, if necessary. However, once
// a worker enters gcMarkDone it may park without rejoining the
// pool, thus freeing a P with no corresponding worker.
// gcMarkDone never depends on another worker doing work, so it
// is safe to simply do nothing here.
//
// If gcMarkDone bails out without completing the mark phase,
// it will always do so with queued global work. Thus, that P
// will be immediately eligible to re-run the worker G it was
// just using, ensuring work can complete.
return nil, now
}
decIfPositive := func(val *atomic.Int64) bool {
for {
v := val.Load()
if v <= 0 {
return false
}
if val.CompareAndSwap(v, v-1) {
return true
}
}
}
if decIfPositive(&c.dedicatedMarkWorkersNeeded) {
// This P is now dedicated to marking until the end of
// the concurrent mark phase.
pp.gcMarkWorkerMode = gcMarkWorkerDedicatedMode
} else if c.fractionalUtilizationGoal == 0 {
// No need for fractional workers.
gcBgMarkWorkerPool.push(&node.node)
return nil, now
} else {
// Is this P behind on the fractional utilization
// goal?
//
// This should be kept in sync with pollFractionalWorkerExit.
delta := now - c.markStartTime
if delta > 0 && float64(pp.gcFractionalMarkTime)/float64(delta) > c.fractionalUtilizationGoal {
// Nope. No need to run a fractional worker.
gcBgMarkWorkerPool.push(&node.node)
return nil, now
}
// Run a fractional worker.
pp.gcMarkWorkerMode = gcMarkWorkerFractionalMode
}
// Run the background mark worker.
gp := node.gp.ptr()
casgstatus(gp, _Gwaiting, _Grunnable)
if trace.enabled {
traceGoUnpark(gp, 0)
}
return gp, now
}
// resetLive sets up the controller state for the next mark phase after the end
// of the previous one. Must be called after endCycle and before commit, before
// the world is started.
//
// The world must be stopped.
func (c *gcControllerState) resetLive(bytesMarked uint64) {
c.heapMarked = bytesMarked
c.heapLive.Store(bytesMarked)
c.heapScan.Store(uint64(c.heapScanWork.Load()))
c.lastHeapScan = uint64(c.heapScanWork.Load())
c.lastStackScan.Store(uint64(c.stackScanWork.Load()))
c.triggered = ^uint64(0) // Reset triggered.
// heapLive was updated, so emit a trace event.
if trace.enabled {
traceHeapAlloc(bytesMarked)
}
}
// markWorkerStop must be called whenever a mark worker stops executing.
//
// It updates mark work accounting in the controller by the given duration
// of work in nanoseconds and performs other bookkeeping.
//
// Safe to execute at any time.
func (c *gcControllerState) markWorkerStop(mode gcMarkWorkerMode, duration int64) {
switch mode {
case gcMarkWorkerDedicatedMode:
c.dedicatedMarkTime.Add(duration)
c.dedicatedMarkWorkersNeeded.Add(1)
case gcMarkWorkerFractionalMode:
c.fractionalMarkTime.Add(duration)
case gcMarkWorkerIdleMode:
c.idleMarkTime.Add(duration)
c.removeIdleMarkWorker()
default:
throw("markWorkerStop: unknown mark worker mode")
}
}
func (c *gcControllerState) update(dHeapLive, dHeapScan int64) {
if dHeapLive != 0 {
live := gcController.heapLive.Add(dHeapLive)
if trace.enabled {
// gcController.heapLive changed.
traceHeapAlloc(live)
}
}
if gcBlackenEnabled == 0 {
// Update heapScan when we're not in a current GC. It is fixed
// at the beginning of a cycle.
if dHeapScan != 0 {
gcController.heapScan.Add(dHeapScan)
}
} else {
// gcController.heapLive changed.
c.revise()
}
}
func (c *gcControllerState) addScannableStack(pp *p, amount int64) {
if pp == nil {
c.maxStackScan.Add(amount)
return
}
pp.maxStackScanDelta += amount
if pp.maxStackScanDelta >= maxStackScanSlack || pp.maxStackScanDelta <= -maxStackScanSlack {
c.maxStackScan.Add(pp.maxStackScanDelta)
pp.maxStackScanDelta = 0
}
}
func (c *gcControllerState) addGlobals(amount int64) {
c.globalsScan.Add(amount)
}
// heapGoal returns the current heap goal.
func (c *gcControllerState) heapGoal() uint64 {
goal, _ := c.heapGoalInternal()
return goal
}
// heapGoalInternal is the implementation of heapGoal which returns additional
// information that is necessary for computing the trigger.
//
// The returned minTrigger is always <= goal.
func (c *gcControllerState) heapGoalInternal() (goal, minTrigger uint64) {
// Start with the goal calculated for gcPercent.
goal = c.gcPercentHeapGoal.Load()
// Check if the memory-limit-based goal is smaller, and if so, pick that.
if newGoal := c.memoryLimitHeapGoal(); newGoal < goal {
goal = newGoal
} else {
// We're not limited by the memory limit goal, so perform a series of
// adjustments that might move the goal forward in a variety of circumstances.
sweepDistTrigger := c.sweepDistMinTrigger.Load()
if sweepDistTrigger > goal {
// Set the goal to maintain a minimum sweep distance since
// the last call to commit. Note that we never want to do this
// if we're in the memory limit regime, because it could push
// the goal up.
goal = sweepDistTrigger
}
// Since we ignore the sweep distance trigger in the memory
// limit regime, we need to ensure we don't propagate it to
// the trigger, because it could cause a violation of the
// invariant that the trigger < goal.
minTrigger = sweepDistTrigger
// Ensure that the heap goal is at least a little larger than
// the point at which we triggered. This may not be the case if GC
// start is delayed or if the allocation that pushed gcController.heapLive
// over trigger is large or if the trigger is really close to
// GOGC. Assist is proportional to this distance, so enforce a
// minimum distance, even if it means going over the GOGC goal
// by a tiny bit.
//
// Ignore this if we're in the memory limit regime: we'd prefer to
// have the GC respond hard about how close we are to the goal than to
// push the goal back in such a manner that it could cause us to exceed
// the memory limit.
const minRunway = 64 << 10
if c.triggered != ^uint64(0) && goal < c.triggered+minRunway {
goal = c.triggered + minRunway
}
}
return
}
// memoryLimitHeapGoal returns a heap goal derived from memoryLimit.
func (c *gcControllerState) memoryLimitHeapGoal() uint64 {
// Start by pulling out some values we'll need. Be careful about overflow.
var heapFree, heapAlloc, mappedReady uint64
for {
heapFree = c.heapFree.load() // Free and unscavenged memory.
heapAlloc = c.totalAlloc.Load() - c.totalFree.Load() // Heap object bytes in use.
mappedReady = c.mappedReady.Load() // Total unreleased mapped memory.
if heapFree+heapAlloc <= mappedReady {
break
}
// It is impossible for total heap memory to exceed total unreleased mapped memory, but
// because these stats are updated independently, we may observe a partial update
// including only some values. Thus, we appear to break the invariant. However,
// this condition is necessarily transient, so just try again. In the case of a
// persistent accounting error, we'll deadlock here.
}
// Below we compute a goal from memoryLimit. There are a few things to be aware of.
// Firstly, the memoryLimit does not easily compare to the heap goal: the former
// is total mapped memory by the runtime that hasn't been released, while the latter is
// only heap object memory. Intuitively, the way we convert from one to the other is to
// subtract everything from memoryLimit that both contributes to the memory limit (so,
// ignore scavenged memory) and doesn't contain heap objects. This isn't quite what
// lines up with reality, but it's a good starting point.
//
// In practice this computation looks like the following:
//
// memoryLimit - ((mappedReady - heapFree - heapAlloc) + max(mappedReady - memoryLimit, 0)) - memoryLimitHeapGoalHeadroom
// ^1 ^2 ^3
//
// Let's break this down.
//
// The first term (marker 1) is everything that contributes to the memory limit and isn't
// or couldn't become heap objects. It represents, broadly speaking, non-heap overheads.
// One oddity you may have noticed is that we also subtract out heapFree, i.e. unscavenged
// memory that may contain heap objects in the future.
//
// Let's take a step back. In an ideal world, this term would look something like just
// the heap goal. That is, we "reserve" enough space for the heap to grow to the heap
// goal, and subtract out everything else. This is of course impossible; the definition
// is circular! However, this impossible definition contains a key insight: the amount
// we're *going* to use matters just as much as whatever we're currently using.
//
// Consider if the heap shrinks to 1/10th its size, leaving behind lots of free and
// unscavenged memory. mappedReady - heapAlloc will be quite large, because of that free
// and unscavenged memory, pushing the goal down significantly.
//
// heapFree is also safe to exclude from the memory limit because in the steady-state, it's
// just a pool of memory for future heap allocations, and making new allocations from heapFree
// memory doesn't increase overall memory use. In transient states, the scavenger and the
// allocator actively manage the pool of heapFree memory to maintain the memory limit.
//
// The second term (marker 2) is the amount of memory we've exceeded the limit by, and is
// intended to help recover from such a situation. By pushing the heap goal down, we also
// push the trigger down, triggering and finishing a GC sooner in order to make room for
// other memory sources. Note that since we're effectively reducing the heap goal by X bytes,
// we're actually giving more than X bytes of headroom back, because the heap goal is in
// terms of heap objects, but it takes more than X bytes (e.g. due to fragmentation) to store
// X bytes worth of objects.
//
// The third term (marker 3) subtracts an additional memoryLimitHeapGoalHeadroom bytes from the
// heap goal. As the name implies, this is to provide additional headroom in the face of pacing
// inaccuracies. This is a fixed number of bytes because these inaccuracies disproportionately
// affect small heaps: as heaps get smaller, the pacer's inputs get fuzzier. Shorter GC cycles
// and less GC work mean that noisy external factors like the OS scheduler have a greater impact.
memoryLimit := uint64(c.memoryLimit.Load())
// Compute term 1.
nonHeapMemory := mappedReady - heapFree - heapAlloc
// Compute term 2.
var overage uint64
if mappedReady > memoryLimit {
overage = mappedReady - memoryLimit
}
if nonHeapMemory+overage >= memoryLimit {
// We're at a point where non-heap memory exceeds the memory limit on its own.
// There's honestly not much we can do here but just trigger GCs continuously
// and let the CPU limiter rein that in. Something has to give at this point.
// Set it to heapMarked, the lowest possible goal.
return c.heapMarked
}
// Compute the goal.
goal := memoryLimit - (nonHeapMemory + overage)
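// Worked example (hypothetical numbers): with memoryLimit = 1024 MiB,
// mappedReady = 1126 MiB, and heapFree+heapAlloc = 819 MiB, we get
// nonHeapMemory = 1126-819 = 307 MiB and overage = 1126-1024 = 102 MiB,
// so goal = 1024 - (307+102) = 615 MiB before headroom is applied below.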
// Apply some headroom to the goal to account for pacing inaccuracies.
// Be careful about small limits.
if goal < memoryLimitHeapGoalHeadroom || goal-memoryLimitHeapGoalHeadroom < memoryLimitHeapGoalHeadroom {
goal = memoryLimitHeapGoalHeadroom
} else {
goal = goal - memoryLimitHeapGoalHeadroom
}
// Don't let us go below the live heap. A heap goal below the live heap doesn't make sense.
if goal < c.heapMarked {
goal = c.heapMarked
}
return goal
}
const (
// These constants determine the bounds on the GC trigger as a fraction
// of heap bytes allocated between the start of a GC (heapLive == heapMarked)
// and the end of a GC (heapLive == heapGoal).
//
// The constants are obscured in this way for efficiency. The denominator
// of the fraction is always a power-of-two for a quick division, so that
// the numerator is a single constant integer multiplication.
triggerRatioDen = 64
// The minimum trigger constant was chosen empirically: given a sufficiently
// fast/scalable allocator with 48 Ps that could drive the trigger ratio
// to <0.05, this constant causes applications to retain the same peak
// RSS compared to not having this allocator.
minTriggerRatioNum = 45 // ~0.7
// The maximum trigger constant is chosen somewhat arbitrarily, but the
// current constant has served us well over the years.
maxTriggerRatioNum = 61 // ~0.95
)
// trigger returns the current point at which a GC should trigger along with
// the heap goal.
//
// The returned value may be compared against heapLive to determine whether
// the GC should trigger. Thus, the GC trigger condition should be (but may
// not be, in the case of small movements for efficiency) checked whenever
// the heap goal may change.
func (c *gcControllerState) trigger() (uint64, uint64) {
goal, minTrigger := c.heapGoalInternal()
// Invariant: the trigger must always be less than the heap goal.
//
// Note that the memory limit sets a hard maximum on our heap goal,
// but the live heap may grow beyond it.
if c.heapMarked >= goal {
// The goal should never be smaller than heapMarked, but let's be
// defensive about it. The only reasonable trigger here is one that
// causes a continuous GC cycle at heapMarked, but respect the goal
// if it came out as smaller than that.
return goal, goal
}
// Below this point, c.heapMarked < goal.
// heapMarked is our absolute minimum, and it's possible the trigger
// bound we get from heapGoalInternal is less than that.
if minTrigger < c.heapMarked {
minTrigger = c.heapMarked
}
// If we let the trigger go too low, then if the application
// is allocating very rapidly we might end up in a situation
// where we're allocating black during a nearly always-on GC.
// The result of this is a growing heap and ultimately an
// increase in RSS. By capping us at a point >0, we're essentially
// saying that we're OK using more CPU during the GC to prevent
// this growth in RSS.
triggerLowerBound := uint64(((goal-c.heapMarked)/triggerRatioDen)*minTriggerRatioNum) + c.heapMarked
if minTrigger < triggerLowerBound {
minTrigger = triggerLowerBound
}
// For small heaps, set the max trigger point at maxTriggerRatio of the way
// from the live heap to the heap goal. This ensures we always have *some*
// headroom when the GC actually starts. For larger heaps, set the max trigger
// point at the goal, minus the minimum heap size.
//
// This choice follows from the fact that the minimum heap size is chosen
// to reflect the costs of a GC with no work to do. With a large heap but
// very little scan work to perform, this gives us exactly as much runway
// as we would need, in the worst case.
maxTrigger := uint64(((goal-c.heapMarked)/triggerRatioDen)*maxTriggerRatioNum) + c.heapMarked
if goal > defaultHeapMinimum && goal-defaultHeapMinimum > maxTrigger {
maxTrigger = goal - defaultHeapMinimum
}
if maxTrigger < minTrigger {
maxTrigger = minTrigger
}
// Compute the trigger from our bounds and the runway stored by commit.
var trigger uint64
runway := c.runway.Load()
if runway > goal {
trigger = minTrigger
} else {
trigger = goal - runway
}
if trigger < minTrigger {
trigger = minTrigger
}
if trigger > maxTrigger {
trigger = maxTrigger
}
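// Worked example (hypothetical numbers): with heapMarked = 64 MiB and
// goal = 128 MiB, the span between them is 64 MiB, so
// triggerLowerBound = (64 MiB/64)*45 + 64 MiB = 109 MiB and
// maxTrigger = (64 MiB/64)*61 + 64 MiB = 125 MiB (assuming
// goal-defaultHeapMinimum does not exceed that). With runway = 96 MiB,
// goal-runway = 32 MiB falls below the lower bound, so the trigger is
// clamped up to 109 MiB.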
if trigger > goal {
print("trigger=", trigger, " heapGoal=", goal, "\n")
print("minTrigger=", minTrigger, " maxTrigger=", maxTrigger, "\n")
throw("produced a trigger greater than the heap goal")
}
return trigger, goal
}
// commit recomputes all pacing parameters needed to derive the
// trigger and the heap goal. Namely, the gcPercent-based heap goal,
// and the amount of runway we want to give the GC this cycle.
//
// This can be called any time. If GC is in the middle of a
// concurrent phase, it will adjust the pacing of that phase.
//
// isSweepDone should be the result of calling isSweepDone(),
// unless we're testing or we know we're executing during a GC cycle.
//
// This depends on gcPercent, gcController.heapMarked, and
// gcController.heapLive. These must be up to date.
//
// Callers must call gcControllerState.revise after calling this
// function if the GC is enabled.
//
// mheap_.lock must be held or the world must be stopped.
func (c *gcControllerState) commit(isSweepDone bool) {
if !c.test {
assertWorldStoppedOrLockHeld(&mheap_.lock)
}
if isSweepDone {
// The sweep is done, so there aren't any restrictions on the trigger
// we need to think about.
c.sweepDistMinTrigger.Store(0)
} else {
// Concurrent sweep happens in the heap growth
// from gcController.heapLive to trigger. Make sure we
// give the sweeper some runway if it doesn't have enough.
c.sweepDistMinTrigger.Store(c.heapLive.Load() + sweepMinHeapDistance)
}
// Compute the next GC goal, which is when the allocated heap
// has grown by GOGC/100 over where it started the last cycle,
// plus additional runway for non-heap sources of GC work.
gcPercentHeapGoal := ^uint64(0)
if gcPercent := c.gcPercent.Load(); gcPercent >= 0 {
gcPercentHeapGoal = c.heapMarked + (c.heapMarked+c.lastStackScan.Load()+c.globalsScan.Load())*uint64(gcPercent)/100
}
// Apply the minimum heap size here. It's defined in terms of gcPercent
// and is only updated by functions that call commit.
if gcPercentHeapGoal < c.heapMinimum {
gcPercentHeapGoal = c.heapMinimum
}
c.gcPercentHeapGoal.Store(gcPercentHeapGoal)
// Compute the amount of runway we want the GC to have by using our
// estimate of the cons/mark ratio.
//
// The idea is to take our expected scan work, and multiply it by
// the cons/mark ratio to determine how long it'll take to complete
// that scan work in terms of bytes allocated. This gives us our GC's
// runway.
//
// However, the cons/mark ratio is a ratio of rates per CPU-second, but
// here we care about the relative rates for some division of CPU
// resources among the mutator and the GC.
//
// To summarize, we have B / cpu-ns, and we want B / ns. We get that
// by multiplying by our desired division of CPU resources. We choose
// to express CPU resources as GOMAXPROCS*fraction. Note that because
// we're working with a ratio here, we can omit the number of CPU cores,
// because they'll appear in the numerator and denominator and cancel out.
// As a result, this is basically just "weighing" the cons/mark ratio by
// our desired division of resources.
//
// Furthermore, by setting the runway so that CPU resources are divided
// this way, assuming that the cons/mark ratio is correct, we make that
// division a reality.
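//
// For example (hypothetical numbers, taking gcGoalUtilization = 0.25 for
// illustration): with consMark = 0.8 and 40 MiB of expected scan work,
// runway = 0.8 * (0.75/0.25) * 40 MiB = 96 MiB of allocation headroom.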
c.runway.Store(uint64((c.consMark * (1 - gcGoalUtilization) / (gcGoalUtilization)) * float64(c.lastHeapScan+c.lastStackScan.Load()+c.globalsScan.Load())))
}
// setGCPercent updates gcPercent. commit must be called after.
// Returns the old value of gcPercent.
//
// The world must be stopped, or mheap_.lock must be held.
func (c *gcControllerState) setGCPercent(in int32) int32 {
if !c.test {
assertWorldStoppedOrLockHeld(&mheap_.lock)
}
out := c.gcPercent.Load()
if in < 0 {
in = -1
}
c.heapMinimum = defaultHeapMinimum * uint64(in) / 100
c.gcPercent.Store(in)
return out
}
//go:linkname setGCPercent runtime/debug.setGCPercent
func setGCPercent(in int32) (out int32) {
// Run on the system stack since we grab the heap lock.
systemstack(func() {
lock(&mheap_.lock)
out = gcController.setGCPercent(in)
gcControllerCommit()
unlock(&mheap_.lock)
})
// If we just disabled GC, wait for any concurrent GC mark to
// finish so we always return with no GC running.
if in < 0 {
gcWaitOnMark(work.cycles.Load())
}
return out
}
func readGOGC() int32 {
p := gogetenv("GOGC")
if p == "off" {
return -1
}
if n, ok := atoi32(p); ok {
return n
}
return 100
}
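// Example behavior of readGOGC (illustrative): GOGC=off yields -1
// (disabling GC), GOGC=200 yields 200 (a heap goal of roughly 3x the
// live heap), and an unset or unparsable GOGC defaults to 100.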
// setMemoryLimit updates memoryLimit. commit must be called after.
// Returns the old value of memoryLimit.
//
// The world must be stopped, or mheap_.lock must be held.
func (c *gcControllerState) setMemoryLimit(in int64) int64 {
if !c.test {
assertWorldStoppedOrLockHeld(&mheap_.lock)
}
out := c.memoryLimit.Load()
if in >= 0 {
c.memoryLimit.Store(in)
}
return out
}
//go:linkname setMemoryLimit runtime/debug.setMemoryLimit
func setMemoryLimit(in int64) (out int64) {
// Run on the system stack since we grab the heap lock.
systemstack(func() {
lock(&mheap_.lock)
out = gcController.setMemoryLimit(in)
if in < 0 || out == in {
// If we're just checking the value or not changing
// it, there's no point in doing the rest.
unlock(&mheap_.lock)
return
}
gcControllerCommit()
unlock(&mheap_.lock)
})
return out
}
func readGOMEMLIMIT() int64 {
p := gogetenv("GOMEMLIMIT")
if p == "" || p == "off" {
return maxInt64
}
n, ok := parseByteCount(p)
if !ok {
print("GOMEMLIMIT=", p, "\n")
throw("malformed GOMEMLIMIT; see `go doc runtime/debug.SetMemoryLimit`")
}
return n
}
// addIdleMarkWorker attempts to add a new idle mark worker.
//
// If this returns true, the caller must become an idle mark worker unless
// there are no background mark worker goroutines in the pool. This case is
// harmless because there are already background mark workers running.
// If this returns false, the caller must NOT become an idle mark worker.
//
// nosplit because it may be called without a P.
//
//go:nosplit
func (c *gcControllerState) addIdleMarkWorker() bool {
for {
old := c.idleMarkWorkers.Load()
n, max := int32(old&uint64(^uint32(0))), int32(old>>32)
if n >= max {
// See the comment on idleMarkWorkers for why
// n > max is tolerated.
return false
}
if n < 0 {
print("n=", n, " max=", max, "\n")
throw("negative idle mark workers")
}
new := uint64(uint32(n+1)) | (uint64(max) << 32)
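// Illustrative encoding (restating the line above): idleMarkWorkers
// packs the worker count n into the low 32 bits and max into the high
// 32 bits. For example, (5<<32)|2 encodes max = 5 and n = 2; a
// successful CompareAndSwap replaces it with (5<<32)|3.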
if c.idleMarkWorkers.CompareAndSwap(old, new) {
return true
}
}
}
// needIdleMarkWorker is a hint as to whether another idle mark worker is needed.
//
// The caller must still call addIdleMarkWorker to become one. This is mainly
// useful for a quick check before an expensive operation.
//
// nosplit because it may be called without a P.
//
//go:nosplit
func (c *gcControllerState) needIdleMarkWorker() bool {
p := c.idleMarkWorkers.Load()
n, max := int32(p&uint64(^uint32(0))), int32(p>>32)
return n < max
}
// removeIdleMarkWorker must be called when an idle mark worker stops executing.
func (c *gcControllerState) removeIdleMarkWorker() {
for {
old := c.idleMarkWorkers.Load()
n, max := int32(old&uint64(^uint32(0))), int32(old>>32)
if n-1 < 0 {
print("n=", n, " max=", max, "\n")
throw("negative idle mark workers")
}
new := uint64(uint32(n-1)) | (uint64(max) << 32)
if c.idleMarkWorkers.CompareAndSwap(old, new) {
return
}
}
}
// setMaxIdleMarkWorkers sets the maximum number of idle mark workers allowed.
//
// This method is optimistic in that it does not wait for the number of
// idle mark workers to reduce to max before returning; it assumes the workers
// will deschedule themselves.
func (c *gcControllerState) setMaxIdleMarkWorkers(max int32) {
for {
old := c.idleMarkWorkers.Load()
n := int32(old & uint64(^uint32(0)))
if n < 0 {
print("n=", n, " max=", max, "\n")
throw("negative idle mark workers")
}
new := uint64(uint32(n)) | (uint64(max) << 32)
if c.idleMarkWorkers.CompareAndSwap(old, new) {
return
}
}
}
// gcControllerCommit is gcController.commit, but passes arguments from live
// (non-test) data. It also updates any consumers of the GC pacing, such as
// sweep pacing and the background scavenger.
//
// Calls gcController.commit.
//
// The heap lock must be held, so this must be executed on the system stack.
//
//go:systemstack
func gcControllerCommit() {
assertWorldStoppedOrLockHeld(&mheap_.lock)
gcController.commit(isSweepDone())
// Update mark pacing.
if gcphase != _GCoff {
gcController.revise()
}
// TODO(mknyszek): This isn't really accurate any longer because the heap
// goal is computed dynamically. Still useful to snapshot, but not as useful.
if trace.enabled {
traceHeapGoal()
}
trigger, heapGoal := gcController.trigger()
gcPaceSweeper(trigger)
gcPaceScavenger(gcController.memoryLimit.Load(), heapGoal, gcController.lastHeapGoal)
}
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Scavenging free pages.
//
// This file implements scavenging (the release of physical pages backing mapped
// memory) of free and unused pages in the heap as a way to deal with page-level
// fragmentation and reduce the RSS of Go applications.
//
// Scavenging in Go happens on two fronts: there's the background
// (asynchronous) scavenger and the heap-growth (synchronous) scavenger.
//
// The former happens on a goroutine much like the background sweeper which is
// soft-capped at using scavengePercent of the mutator's time, based on
// order-of-magnitude estimates of the costs of scavenging. The background
// scavenger's primary goal is to bring the estimated heap RSS of the
// application down to a goal.
//
// Before we consider what this looks like, we need to split the world into two
// halves: one in which a memory limit is not set, and one in which it is.
//
// For the former, the goal is defined as:
// (retainExtraPercent+100) / 100 * (heapGoal / lastHeapGoal) * lastHeapInUse
//
// Essentially, we wish to have the application's RSS track the heap goal, but
// the heap goal is defined in terms of bytes of objects, rather than pages like
// RSS. As a result, we need to account for fragmentation internal to
// spans. heapGoal / lastHeapGoal defines the ratio between the current heap goal
// and the last heap goal, which tells us by how much the heap is growing and
// shrinking. We estimate what the heap will grow to in terms of pages by taking
// this ratio and multiplying it by heapInUse at the end of the last GC, which
// allows us to account for this additional fragmentation. Note that this
// procedure makes the assumption that the degree of fragmentation won't change
// dramatically over the next GC cycle. Overestimating the amount of
// fragmentation simply results in higher memory use, which will be accounted
// for by the next pacing update. Underestimating the fragmentation, however,
// could lead to performance degradation. Handling this case is not within the
// scope of the scavenger. Situations where the amount of fragmentation balloons
// over the course of a single GC cycle should be considered pathologies,
// flagged as bugs, and fixed appropriately.
//
// An additional factor of retainExtraPercent is added as a buffer to help ensure
// that there's more unscavenged memory to allocate out of, since each allocation
// out of scavenged memory incurs a potentially expensive page fault.
//
// If a memory limit is set, then we wish to pick a scavenge goal that maintains
// that memory limit. For that, we look at total memory that has been committed
// (memstats.mappedReady) and try to bring that down below the limit. In this case,
// we want to give buffer space in the *opposite* direction. When the application
// is close to the limit, we want to make sure we push harder to keep it under, so
// if we target below the memory limit, we ensure that the background scavenger is
// giving the situation the urgency it deserves.
//
// In this case, the goal is defined as:
// (100-reduceExtraPercent) / 100 * memoryLimit
//
// We compute both of these goals, and check whether either of them has been met.
// The background scavenger continues operating as long as either one of the goals
// has not been met.
//
// The goals are updated after each GC.
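//
// For example (hypothetical numbers): with retainExtraPercent = 10, a
// heapGoal/lastHeapGoal ratio of 1.2, and lastHeapInUse = 100 MiB, the
// first goal works out to 1.10 * 1.2 * 100 MiB = 132 MiB of retained
// heap memory. With a 1 GiB memory limit and reduceExtraPercent = 5, the
// second goal is 0.95 * 1024 MiB ≈ 973 MiB of total mapped memory.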
//
// The synchronous heap-growth scavenging happens whenever the heap grows in
// size, for some definition of heap-growth. The intuition behind this is that
// the application had to grow the heap because existing fragments were
// not sufficiently large to satisfy a page-level memory allocation, so we
// scavenge those fragments eagerly to offset the growth in RSS that results.
package runtime
import (
"internal/goos"
"runtime/internal/atomic"
"runtime/internal/sys"
"unsafe"
)
const (
// The background scavenger is paced according to these parameters.
//
// scavengePercent represents the portion of mutator time we're willing
// to spend on scavenging in percent.
scavengePercent = 1 // 1%
// retainExtraPercent represents the amount of memory over the heap goal
// that the scavenger should keep as a buffer space for the allocator.
// This constant is used when we do not have a memory limit set.
//
// The purpose of maintaining this overhead is to have a greater pool of
// unscavenged memory available for allocation (since using scavenged memory
// incurs an additional cost), to account for heap fragmentation and
// the ever-changing layout of the heap.
retainExtraPercent = 10
// reduceExtraPercent represents the amount of memory under the limit
// that the scavenger should target. For example, 5 means we target 95%
// of the limit.
//
// The purpose of shooting lower than the limit is to ensure that, once
// close to the limit, the scavenger is working hard to maintain it. If
// we have a memory limit set but are far away from it, there's no harm
// in leaving up to 100-retainExtraPercent live, and it's more efficient
// anyway, for the same reasons that retainExtraPercent exists.
reduceExtraPercent = 5
// maxPagesPerPhysPage is the maximum number of supported runtime pages per
// physical page, based on maxPhysPageSize.
maxPagesPerPhysPage = maxPhysPageSize / pageSize
// scavengeCostRatio is the approximate ratio between the costs of using previously
// scavenged memory and scavenging memory.
//
// For most systems the cost of scavenging greatly outweighs the costs
// associated with using scavenged memory, making this constant 0. On other systems
// (especially ones where "sysUsed" is not just a no-op) this cost is non-trivial.
//
// This ratio is used as part of a multiplicative factor to help the scavenger account
// for the additional costs of using scavenged memory in its pacing.
scavengeCostRatio = 0.7 * (goos.IsDarwin + goos.IsIos)
)
// heapRetained returns an estimate of the current heap RSS.
func heapRetained() uint64 {
return gcController.heapInUse.load() + gcController.heapFree.load()
}
// gcPaceScavenger updates the scavenger's pacing, particularly
// its rate and RSS goal. For this, it requires the current heapGoal,
// and the heapGoal for the previous GC cycle.
//
// The RSS goal is based on the current heap goal with a small overhead
// to accommodate non-determinism in the allocator.
//
// The pacing is based on scavengePageRate, which applies to both regular and
// huge pages. See that constant for more information.
//
// Must be called whenever GC pacing is updated.
//
// mheap_.lock must be held or the world must be stopped.
func gcPaceScavenger(memoryLimit int64, heapGoal, lastHeapGoal uint64) {
assertWorldStoppedOrLockHeld(&mheap_.lock)
// As described at the top of this file, there are two scavenge goals here: one
// for gcPercent and one for memoryLimit. Let's handle the latter first because
// it's simpler.
// We want to target retaining (100-reduceExtraPercent)% of the heap.
memoryLimitGoal := uint64(float64(memoryLimit) * (100.0 - reduceExtraPercent) / 100.0)
// mappedReady is comparable to memoryLimit, and represents how much total memory
// the Go runtime has committed now (estimated).
mappedReady := gcController.mappedReady.Load()
// If we're below the goal already, indicate that we don't need the background
// scavenger for the memory limit. This may seem worrisome at first, but note
// that the allocator will assist the background scavenger in the face of a memory
// limit, so we'll be safe even if we stop the scavenger when we shouldn't have.
if mappedReady <= memoryLimitGoal {
scavenge.memoryLimitGoal.Store(^uint64(0))
} else {
scavenge.memoryLimitGoal.Store(memoryLimitGoal)
}
// Now handle the gcPercent goal.
// If we're called before the first GC completed, disable scavenging.
// We never scavenge before the 2nd GC cycle anyway (we don't have enough
// information about the heap yet) so this is fine, and avoids a fault
// or garbage data later.
if lastHeapGoal == 0 {
scavenge.gcPercentGoal.Store(^uint64(0))
return
}
// Compute our scavenging goal.
goalRatio := float64(heapGoal) / float64(lastHeapGoal)
gcPercentGoal := uint64(float64(memstats.lastHeapInUse) * goalRatio)
// Add retainExtraPercent overhead to retainedGoal. This calculation
// looks strange but the purpose is to arrive at an integer division
// (e.g. if retainExtraPercent = 12.5, then we get a divisor of 8)
// that also avoids the overflow from a multiplication.
gcPercentGoal += gcPercentGoal / (1.0 / (retainExtraPercent / 100.0))
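// For example, with retainExtraPercent = 10 the divisor constant-folds
// to 1.0/(10/100.0) = 10, so the line above adds gcPercentGoal/10 (a 10%
// overhead) using a single integer division and no multiplication that
// could overflow.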
// Align it to a physical page boundary to make the following calculations
// a bit more exact.
gcPercentGoal = (gcPercentGoal + uint64(physPageSize) - 1) &^ (uint64(physPageSize) - 1)
// Represents where we are now in the heap's contribution to RSS in bytes.
//
// Guaranteed to always be a multiple of physPageSize on systems where
// physPageSize <= pageSize since we map new heap memory at a size larger than
// any physPageSize and release memory in multiples of the physPageSize.
//
// However, certain functions recategorize heap memory as other stats (e.g.
// stacks) and this happens in multiples of pageSize, so on systems
// where physPageSize > pageSize the calculations below will not be exact.
// Generally this is OK since we'll be off by at most one regular
// physical page.
heapRetainedNow := heapRetained()
// If we're already below our goal, or within one page of our goal, then indicate
// that we don't need the background scavenger for maintaining a memory overhead
// proportional to the heap goal.
if heapRetainedNow <= gcPercentGoal || heapRetainedNow-gcPercentGoal < uint64(physPageSize) {
scavenge.gcPercentGoal.Store(^uint64(0))
} else {
scavenge.gcPercentGoal.Store(gcPercentGoal)
}
}
var scavenge struct {
// gcPercentGoal is the amount of retained heap memory (measured by
// heapRetained) that the runtime will try to maintain by returning
// memory to the OS. This goal is derived from gcController.gcPercent
// by choosing to retain enough memory to allocate heap memory up to
// the heap goal.
gcPercentGoal atomic.Uint64
// memoryLimitGoal is the amount of memory retained by the runtime
// (measured by gcController.mappedReady) that the runtime will try to
// maintain by returning memory to the OS. This goal is derived from
// gcController.memoryLimit by choosing to target the memory limit or
// some lower target to keep the scavenger working.
memoryLimitGoal atomic.Uint64
// assistTime is the time spent by the allocator scavenging in the last GC cycle.
//
// This is reset once a GC cycle ends.
assistTime atomic.Int64
// backgroundTime is the time spent by the background scavenger in the last GC cycle.
//
// This is reset once a GC cycle ends.
backgroundTime atomic.Int64
}
const (
// It doesn't really matter what value we start at, but we can't be zero, because
// that'll cause divide-by-zero issues. Pick something conservative which we'll
// also use as a fallback.
startingScavSleepRatio = 0.001
// Spend at least 1 ms scavenging, otherwise the corresponding
// sleep time to maintain our desired utilization is too low to
// be reliable.
minScavWorkTime = 1e6
)
// Sleep/wait state of the background scavenger.
var scavenger scavengerState
type scavengerState struct {
// lock protects all fields below.
lock mutex
// g is the goroutine the scavenger is bound to.
g *g
// parked is whether or not the scavenger is parked.
parked bool
// timer is the timer used for the scavenger to sleep.
timer *timer
// sysmonWake signals to sysmon that it should wake the scavenger.
sysmonWake atomic.Uint32
// targetCPUFraction is the target CPU overhead for the scavenger.
targetCPUFraction float64
// sleepRatio is the ratio of time spent doing scavenging work to
// time spent sleeping. This is used to decide how long the scavenger
// should sleep for in between batches of work. It is set by
// critSleepController in order to maintain a CPU overhead of
// targetCPUFraction.
//
// Lower means more sleep, higher means more aggressive scavenging.
sleepRatio float64
// sleepController controls sleepRatio.
//
// See sleepRatio for more details.
sleepController piController
// controllerCooldown is the time left in nanoseconds during which we avoid
// using the controller and we hold sleepRatio at a conservative
// value. Used if the controller's assumptions fail to hold.
controllerCooldown int64
// printControllerReset instructs printScavTrace to signal that
// the controller was reset.
printControllerReset bool
// sleepStub is a stub used for testing to avoid actually having
// the scavenger sleep.
//
// Unlike the other stubs, this is not populated if left nil.
// Instead, it is called when non-nil because any valid implementation
// of this function basically requires closing over this scavenger
// state, and allocating a closure is not allowed in the runtime as
// a matter of policy.
sleepStub func(n int64) int64
// scavenge is a function that scavenges n bytes of memory.
// Returns how many bytes of memory it actually scavenged, as
// well as the time it took in nanoseconds. Usually mheap.pages.scavenge
// with nanotime called around it, but stubbed out for testing.
// Like mheap.pages.scavenge, if it scavenges less than n bytes of
// memory, the caller may assume the heap is exhausted of scavengable
// memory for now.
//
// If this is nil, it is populated with the real thing in init.
scavenge func(n uintptr) (uintptr, int64)
// shouldStop is a callback called in the work loop and provides a
// point that can force the scavenger to stop early, for example because
// the scavenge policy dictates too much has been scavenged already.
//
// If this is nil, it is populated with the real thing in init.
shouldStop func() bool
// gomaxprocs returns the current value of gomaxprocs. Stub for testing.
//
// If this is nil, it is populated with the real thing in init.
gomaxprocs func() int32
}
// init initializes a scavenger state and wires to the current G.
//
// Must be called from a regular goroutine that can allocate.
func (s *scavengerState) init() {
if s.g != nil {
throw("scavenger state is already wired")
}
lockInit(&s.lock, lockRankScavenge)
s.g = getg()
s.timer = new(timer)
s.timer.arg = s
s.timer.f = func(s any, _ uintptr) {
s.(*scavengerState).wake()
}
// input: fraction of CPU time actually used.
// setpoint: ideal CPU fraction.
// output: ratio of time worked to time slept (determines sleep time).
//
// The output of this controller is somewhat indirect to what we actually
// want to achieve: how much time to sleep for. The reason for this definition
// is to ensure that the controller's outputs have a direct relationship with
// its inputs (as opposed to an inverse relationship), making it somewhat
// easier to reason about for tuning purposes.
s.sleepController = piController{
// Tuned loosely via Ziegler-Nichols process.
kp: 0.3375,
ti: 3.2e6,
tt: 1e9, // 1 second reset time.
// These ranges seem wide, but we want to give the controller plenty of
// room to hunt for the optimal value.
min: 0.001, // 1:1000
max: 1000.0, // 1000:1
}
s.sleepRatio = startingScavSleepRatio
// Install real functions if stubs aren't present.
if s.scavenge == nil {
s.scavenge = func(n uintptr) (uintptr, int64) {
start := nanotime()
r := mheap_.pages.scavenge(n, nil)
end := nanotime()
if start >= end {
return r, 0
}
scavenge.backgroundTime.Add(end - start)
return r, end - start
}
}
if s.shouldStop == nil {
s.shouldStop = func() bool {
// If background scavenging is disabled or if there's no work to do just stop.
return heapRetained() <= scavenge.gcPercentGoal.Load() &&
gcController.mappedReady.Load() <= scavenge.memoryLimitGoal.Load()
}
}
if s.gomaxprocs == nil {
s.gomaxprocs = func() int32 {
return gomaxprocs
}
}
}
// park parks the scavenger goroutine.
func (s *scavengerState) park() {
lock(&s.lock)
if getg() != s.g {
throw("tried to park scavenger from another goroutine")
}
s.parked = true
goparkunlock(&s.lock, waitReasonGCScavengeWait, traceEvGoBlock, 2)
}
// ready signals to sysmon that the scavenger should be awoken.
func (s *scavengerState) ready() {
s.sysmonWake.Store(1)
}
// wake immediately unparks the scavenger if necessary.
//
// Safe to run without a P.
func (s *scavengerState) wake() {
lock(&s.lock)
if s.parked {
// Unset sysmonWake, since the scavenger is now being awoken.
s.sysmonWake.Store(0)
// s.parked is unset to prevent a double wake-up.
s.parked = false
// Ready the goroutine by injecting it. We use injectglist instead
// of ready or goready in order to allow us to run this function
// without a P. injectglist also avoids placing the goroutine in
// the current P's runnext slot, which is desirable to prevent
// the scavenger from interfering with user goroutine scheduling
// too much.
var list gList
list.push(s.g)
injectglist(&list)
}
unlock(&s.lock)
}
// sleep puts the scavenger to sleep based on the amount of time that it worked
// in nanoseconds.
//
// Note that this function should only be called by the scavenger.
//
// The scavenger may be woken up earlier by a pacing change, and it may not go
// to sleep at all if there's a pending pacing change.
func (s *scavengerState) sleep(worked float64) {
lock(&s.lock)
if getg() != s.g {
throw("tried to sleep scavenger from another goroutine")
}
if worked < minScavWorkTime {
// This means there wasn't enough work to actually fill up minScavWorkTime.
// That's fine; we shouldn't try to do anything with this information
// because it's going to result in a short enough sleep request that things
// will get messy. Just assume we did at least this much work.
// All this means is that we'll sleep longer than we otherwise would have.
worked = minScavWorkTime
}
// Multiply the critical time by 1 + the ratio of the costs of using
// scavenged memory vs. scavenging memory. This forces us to pay down
// the cost of reusing this memory eagerly by sleeping for a longer period
// of time and scavenging less frequently. More concretely, we avoid situations
// where we end up scavenging so often that we hurt allocation performance
// because of the additional overheads of using scavenged memory.
worked *= 1 + scavengeCostRatio
// sleepTime is the amount of time we're going to sleep, based on the amount
// of time we worked, and the sleepRatio.
sleepTime := int64(worked / s.sleepRatio)
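// For example: at the conservative startingScavSleepRatio of 0.001, 1 ms
// of work (minScavWorkTime) yields sleepTime = 1e6/0.001 = 1e9 ns, i.e.
// the scavenger sleeps for a full second and consumes roughly 0.1% of
// one CPU.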
var slept int64
if s.sleepStub == nil {
// Set the timer.
//
// This must happen here instead of inside gopark
// because we can't close over any variables without
// failing escape analysis.
start := nanotime()
resetTimer(s.timer, start+sleepTime)
// Mark ourselves as asleep and go to sleep.
s.parked = true
goparkunlock(&s.lock, waitReasonSleep, traceEvGoSleep, 2)
// How long we actually slept for.
slept = nanotime() - start
lock(&s.lock)
// Stop the timer here because s.wake is unable to do it for us.
// We don't really care if we succeed in stopping the timer. One
// reason we might fail is that we've already woken up, but the timer
// might be in the process of firing on some other P; essentially we're
// racing with it. That's totally OK. Double wake-ups are perfectly safe.
stopTimer(s.timer)
unlock(&s.lock)
} else {
unlock(&s.lock)
slept = s.sleepStub(sleepTime)
}
// Stop here if we're cooling down from the controller.
if s.controllerCooldown > 0 {
// worked and slept aren't exact measures of time, but it's OK to be a bit
// sloppy here. We're just hoping we're avoiding some transient bad behavior.
t := slept + int64(worked)
if t > s.controllerCooldown {
s.controllerCooldown = 0
} else {
s.controllerCooldown -= t
}
return
}
// idealFraction is the ideal % of overall application CPU time that we
// spend scavenging.
idealFraction := float64(scavengePercent) / 100.0
// Calculate the CPU time spent.
//
// This may be slightly inaccurate with respect to GOMAXPROCS, but we're
// recomputing this often enough relative to GOMAXPROCS changes in general
// (it only changes when the world is stopped, and not during a GC) that
// that small inaccuracy is in the noise.
cpuFraction := worked / ((float64(slept) + worked) * float64(s.gomaxprocs()))
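// Hypothetical numbers: with worked = 1e6 ns, slept = 99e6 ns, and
// gomaxprocs() = 10, cpuFraction = 1e6/(100e6*10) = 0.1%. That is below
// the 1% idealFraction, so the controller will raise sleepRatio to make
// the scavenger work more aggressively.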
// Update the critSleepRatio, adjusting until we reach our ideal fraction.
var ok bool
s.sleepRatio, ok = s.sleepController.next(cpuFraction, idealFraction, float64(slept)+worked)
if !ok {
// The core assumption of the controller, that we can get a proportional
// response, broke down. This may be transient, so temporarily switch to
// sleeping a fixed, conservative amount.
s.sleepRatio = startingScavSleepRatio
s.controllerCooldown = 5e9 // 5 seconds.
// Signal the scav trace printer to output this.
s.controllerFailed()
}
}
// controllerFailed indicates that the scavenger's scheduling
// controller failed.
func (s *scavengerState) controllerFailed() {
lock(&s.lock)
s.printControllerReset = true
unlock(&s.lock)
}
// run is the body of the main scavenging loop.
//
// Returns the number of bytes released and the estimated time spent
// releasing those bytes.
//
// Must be run on the scavenger goroutine.
func (s *scavengerState) run() (released uintptr, worked float64) {
lock(&s.lock)
if getg() != s.g {
throw("tried to run scavenger from another goroutine")
}
unlock(&s.lock)
for worked < minScavWorkTime {
// If something from outside tells us to stop early, stop.
if s.shouldStop() {
break
}
// scavengeQuantum is the amount of memory we try to scavenge
// in one go. A smaller value means the scavenger is more responsive
// to the scheduler in case of e.g. preemption. A larger value means
// that the overheads of scavenging are better amortized, so better
// scavenging throughput.
//
// The current value is chosen assuming a cost of ~10µs/physical page
// (this is somewhat pessimistic), which implies a worst-case latency of
// about 160µs for 4 KiB physical pages. The current value is biased
// toward latency over throughput.
const scavengeQuantum = 64 << 10
// Accumulate the amount of time spent scavenging.
r, duration := s.scavenge(scavengeQuantum)
// On some platforms we may see start >= end if the time it takes to scavenge
// memory is less than the minimum granularity of its clock (e.g. Windows) or
// due to clock bugs.
//
// In this case, just assume scavenging takes 10 µs per regular physical page
// (determined empirically), and conservatively ignore the impact of huge pages
// on timing.
const approxWorkedNSPerPhysicalPage = 10e3
if duration == 0 {
worked += approxWorkedNSPerPhysicalPage * float64(r/physPageSize)
} else {
// TODO(mknyszek): If duration is small compared to worked, it could be
// rounded down to zero. Probably not a problem in practice because the
// values are all within a few orders of magnitude of each other but maybe
// worth worrying about.
worked += float64(duration)
}
released += r
// scavenge does not return until it either finds the requisite amount of
// memory to scavenge, or exhausts the heap. If we haven't found enough
// to scavenge, then the heap must be exhausted.
if r < scavengeQuantum {
break
}
// When using fake time just do one loop.
if faketime != 0 {
break
}
}
if released > 0 && released < physPageSize {
// If this happens, it means that we may have attempted to release part
// of a physical page, but the likely effect of that is that it released
// the whole physical page, some of which may have still been in-use.
// This could lead to memory corruption. Throw.
throw("released less than one physical page of memory")
}
return
}
// Background scavenger.
//
// The background scavenger maintains the RSS of the application below
// the line described by the proportional scavenging statistics in
// the mheap struct.
func bgscavenge(c chan int) {
scavenger.init()
c <- 1
scavenger.park()
for {
released, workTime := scavenger.run()
if released == 0 {
scavenger.park()
continue
}
atomic.Xadduintptr(&mheap_.pages.scav.released, released)
scavenger.sleep(workTime)
}
}
// scavenge scavenges nbytes worth of free pages, starting with the
// highest address first. Successive calls continue from where it left
// off until the heap is exhausted.
//
// Returns the amount of memory scavenged in bytes.
//
// scavenge always tries to scavenge nbytes worth of memory, and will
// only fail to do so if the heap is exhausted for now.
func (p *pageAlloc) scavenge(nbytes uintptr, shouldStop func() bool) uintptr {
released := uintptr(0)
for released < nbytes {
ci, pageIdx := p.scav.index.find()
if ci == 0 {
break
}
systemstack(func() {
released += p.scavengeOne(ci, pageIdx, nbytes-released)
})
if shouldStop != nil && shouldStop() {
break
}
}
return released
}
// printScavTrace prints a scavenge trace line to standard error.
//
// released should be the amount of memory released since the last time this
// was called, and forced indicates whether the scavenge was forced by the
// application.
//
// scavenger.lock must be held.
func printScavTrace(released uintptr, forced bool) {
assertLockHeld(&scavenger.lock)
printlock()
print("scav ",
released>>10, " KiB work, ",
gcController.heapReleased.load()>>10, " KiB total, ",
(gcController.heapInUse.load()*100)/heapRetained(), "% util",
)
if forced {
print(" (forced)")
} else if scavenger.printControllerReset {
print(" [controller reset]")
scavenger.printControllerReset = false
}
println()
printunlock()
}
// scavengeOne walks over the chunk at chunk index ci and searches for
// a contiguous run of pages to scavenge. It will try to scavenge
// at most max bytes at once, but may scavenge more to avoid
// breaking huge pages. Once it scavenges some memory it returns
// how much it scavenged in bytes.
//
// searchIdx is the page index to start searching from in ci.
//
// Returns the number of bytes scavenged.
//
// Must run on the systemstack because it acquires p.mheapLock.
//
//go:systemstack
func (p *pageAlloc) scavengeOne(ci chunkIdx, searchIdx uint, max uintptr) uintptr {
// Calculate the maximum number of pages to scavenge.
//
// This should be alignUp(max, pageSize) / pageSize but max can and will
// be ^uintptr(0), so we need to be very careful not to overflow here.
// Rather than use alignUp, calculate the number of pages rounded down
// first, then add back one if necessary.
maxPages := max / pageSize
if max%pageSize != 0 {
maxPages++
}
// Calculate the minimum number of pages we can scavenge.
//
// Because we can only scavenge whole physical pages, we must
// ensure that we scavenge at least minPages each time, aligned
// to minPages*pageSize.
minPages := physPageSize / pageSize
if minPages < 1 {
minPages = 1
}
lock(p.mheapLock)
if p.summary[len(p.summary)-1][ci].max() >= uint(minPages) {
// We only bother looking for a candidate if there are at least
// minPages free pages at all.
base, npages := p.chunkOf(ci).findScavengeCandidate(searchIdx, minPages, maxPages)
// If we found something, scavenge it and return!
if npages != 0 {
// Compute the full address for the start of the range.
addr := chunkBase(ci) + uintptr(base)*pageSize
// Mark the range we're about to scavenge as allocated, because
// we don't want any allocating goroutines to grab it while
// the scavenging is in progress.
if scav := p.allocRange(addr, uintptr(npages)); scav != 0 {
throw("double scavenge")
}
// With that done, it's safe to unlock.
unlock(p.mheapLock)
if !p.test {
pageTraceScav(getg().m.p.ptr(), 0, addr, uintptr(npages))
// Only perform the actual scavenging if we're not in a test.
// It's dangerous to do so otherwise.
sysUnused(unsafe.Pointer(addr), uintptr(npages)*pageSize)
// Update global accounting only when not in test, otherwise
// the runtime's accounting will be wrong.
nbytes := int64(npages) * pageSize
gcController.heapReleased.add(nbytes)
gcController.heapFree.add(-nbytes)
stats := memstats.heapStats.acquire()
atomic.Xaddint64(&stats.committed, -nbytes)
atomic.Xaddint64(&stats.released, nbytes)
memstats.heapStats.release()
}
// Relock the heap, because now we need to make these pages
// available for allocation. Free them back to the page allocator.
lock(p.mheapLock)
p.free(addr, uintptr(npages), true)
// Mark the range as scavenged.
p.chunkOf(ci).scavenged.setRange(base, npages)
unlock(p.mheapLock)
return uintptr(npages) * pageSize
}
}
// Mark this chunk as having no free pages.
p.scav.index.clear(ci)
unlock(p.mheapLock)
return 0
}
// fillAligned returns x but with all zeroes in m-aligned
// groups of m bits set to 1 if any bit in the group is non-zero.
//
// For example, fillAligned(0x0100a3, 8) == 0xff00ff.
//
// Note that if m == 1, this is a no-op.
//
// m must be a power of 2 <= maxPagesPerPhysPage.
func fillAligned(x uint64, m uint) uint64 {
apply := func(x uint64, c uint64) uint64 {
// The technique used here is derived from
// https://graphics.stanford.edu/~seander/bithacks.html#ZeroInWord
// and extended for more than just bytes (like nibbles
// and uint16s) by using an appropriate constant.
//
// To summarize the technique, quoting from that page:
// "[It] works by first zeroing the high bits of the [8]
// bytes in the word. Subsequently, it adds a number that
// will result in an overflow to the high bit of a byte if
// any of the low bits were initially set. Next the high
// bits of the original word are ORed with these values;
// thus, the high bit of a byte is set iff any bit in the
// byte was set. Finally, we determine if any of these high
// bits are zero by ORing with ones everywhere except the
// high bits and inverting the result."
return ^((((x & c) + c) | x) | c)
}
// Transform x to contain a 1 bit at the top of each m-aligned
// group of m zero bits.
switch m {
case 1:
return x
case 2:
x = apply(x, 0x5555555555555555)
case 4:
x = apply(x, 0x7777777777777777)
case 8:
x = apply(x, 0x7f7f7f7f7f7f7f7f)
case 16:
x = apply(x, 0x7fff7fff7fff7fff)
case 32:
x = apply(x, 0x7fffffff7fffffff)
case 64: // == maxPagesPerPhysPage
x = apply(x, 0x7fffffffffffffff)
default:
throw("bad m value")
}
// Now, the top bit of each m-aligned group in x is set
// if and only if that group was all zero in the original x.
// From each group of m bits subtract 1.
// Because we know only the top bits of each
// m-aligned group are set, we know this will
// set each group to have all the bits set except
// the top bit, so just OR with the original
// result to set all the bits.
return ^((x - (x >> (m - 1))) | x)
}
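// A few illustrative evaluations, worked out from the definition above
// (not part of the original source):
//
//	fillAligned(0x0100a3, 8) // == 0x0000000000ff00ff: the byte groups 0xa3 and 0x01 saturate to 0xff
//	fillAligned(0x8001, 16)  // == 0x000000000000ffff: only the lowest 16-bit group is non-zero
//	fillAligned(0, 8)        // == 0: all-zero groups stay zero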
// findScavengeCandidate returns a start index and a size for this pallocData
// segment which represents a contiguous region of free and unscavenged memory.
//
// searchIdx indicates the page index within this chunk to start the search, but
// note that findScavengeCandidate searches backwards through the pallocData. As
// a result, it will return the highest scavenge candidate in address order.
//
// min indicates a hard minimum size and alignment for runs of pages. That is,
// findScavengeCandidate will not return a region smaller than min pages in size,
// or that is min pages or greater in size but not aligned to min. min must be
// a non-zero power of 2 <= maxPagesPerPhysPage.
//
// max is a hint for how big of a region is desired. If max >= pallocChunkPages, then
// findScavengeCandidate effectively returns entire free and unscavenged regions.
// If max < pallocChunkPages, it may truncate the returned region such that size is
// max. However, findScavengeCandidate may still return a larger region if, for
// example, it chooses to preserve huge pages, or if max is not aligned to min (it
// will round up). That is, even if max is small, the returned size is not guaranteed
// to be equal to max. max is allowed to be less than min, in which case it is as if
// max == min.
func (m *pallocData) findScavengeCandidate(searchIdx uint, min, max uintptr) (uint, uint) {
if min&(min-1) != 0 || min == 0 {
print("runtime: min = ", min, "\n")
throw("min must be a non-zero power of 2")
} else if min > maxPagesPerPhysPage {
print("runtime: min = ", min, "\n")
throw("min too large")
}
// max may not be min-aligned, so we might accidentally truncate to
// a max value which causes us to return a non-min-aligned value.
// To prevent this, align max up to a multiple of min (which is always
// a power of 2). This also prevents max from ever being less than
// min, unless it's zero, so handle that explicitly.
if max == 0 {
max = min
} else {
max = alignUp(max, min)
}
i := int(searchIdx / 64)
// Start by quickly skipping over blocks of non-free or scavenged pages.
for ; i >= 0; i-- {
// 1s are scavenged OR non-free => 0s are unscavenged AND free
x := fillAligned(m.scavenged[i]|m.pallocBits[i], uint(min))
if x != ^uint64(0) {
break
}
}
if i < 0 {
// Failed to find any free/unscavenged pages.
return 0, 0
}
// We have something in the 64-bit chunk at i, but it could
// extend further. Loop until we find the extent of it.
// 1s are scavenged OR non-free => 0s are unscavenged AND free
x := fillAligned(m.scavenged[i]|m.pallocBits[i], uint(min))
z1 := uint(sys.LeadingZeros64(^x))
run, end := uint(0), uint(i)*64+(64-z1)
if x<<z1 != 0 {
// After shifting out z1 bits, we still have 1s,
// so the run ends inside this word.
run = uint(sys.LeadingZeros64(x << z1))
} else {
// After shifting out z1 bits, we have no more 1s.
// This means the run extends to the bottom of the
// word so it may extend into further words.
run = 64 - z1
for j := i - 1; j >= 0; j-- {
x := fillAligned(m.scavenged[j]|m.pallocBits[j], uint(min))
run += uint(sys.LeadingZeros64(x))
if x != 0 {
// The run stopped in this word.
break
}
}
}
// Split the run we found if it's larger than max but hold on to
// our original length, since we may need it later.
size := run
if size > uint(max) {
size = uint(max)
}
start := end - size
// Each huge page is guaranteed to fit in a single palloc chunk.
//
// TODO(mknyszek): Support larger huge page sizes.
// TODO(mknyszek): Consider taking pages-per-huge-page as a parameter
// so we can write tests for this.
if physHugePageSize > pageSize && physHugePageSize > physPageSize {
// We have huge pages, so let's ensure we don't break one by scavenging
// over a huge page boundary. If the range [start, start+size) overlaps with
// a free-and-unscavenged huge page, we want to grow the region we scavenge
// to include that huge page.
// Compute the huge page boundary above our candidate.
pagesPerHugePage := uintptr(physHugePageSize / pageSize)
hugePageAbove := uint(alignUp(uintptr(start), pagesPerHugePage))
// If that boundary is within our current candidate, then we may be breaking
// a huge page.
if hugePageAbove <= end {
// Compute the huge page boundary below our candidate.
hugePageBelow := uint(alignDown(uintptr(start), pagesPerHugePage))
if hugePageBelow >= end-run {
// We're in danger of breaking apart a huge page since start+size crosses
// a huge page boundary and rounding down start to the nearest huge
// page boundary is included in the full run we found. Include the entire
// huge page in the bound by rounding down to the huge page size.
size = size + (start - hugePageBelow)
start = hugePageBelow
}
}
}
return start, size
}
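// Worked example (hypothetical numbers, for illustration only): suppose the
// search finds a free and unscavenged run covering pages [8, 16), so
// end = 16 and run = 8. With min = 2 and max = 4, the run is truncated to
// size = 4 and start = end - size = 12, so (12, 4) is returned: the highest
// min-aligned portion of the run. If a huge page boundary fell inside
// [start, start+size) and the full run also covered the boundary below
// start, start would instead be rounded down to that boundary to avoid
// breaking the huge page, as the code above shows.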
// scavengeIndex is a structure for efficiently managing which pageAlloc chunks have
// memory available to scavenge.
type scavengeIndex struct {
// chunks is a bitmap representing the entire address space. Each bit represents
// a single chunk, and a 1 value indicates the presence of pages available for
// scavenging. Updates to the bitmap are serialized by the pageAlloc lock.
//
// The underlying storage of chunks is platform dependent and may not even be
// totally mapped read/write. min and max reflect the extent that is safe to access.
// min is inclusive, max is exclusive.
//
// searchAddr is the maximum address (in the offset address space, so we have a linear
// view of the address space; see mranges.go:offAddr) containing memory available to
// scavenge. It is a hint to the find operation to avoid O(n^2) behavior in repeated lookups.
//
// searchAddr is always inclusive and should be the base address of the highest runtime
// page available for scavenging.
//
// searchAddr is managed by both find and mark.
//
// Normally, find monotonically decreases searchAddr as it finds no more free pages to
// scavenge. However, mark, when marking a new chunk at an index greater than the current
// searchAddr, sets searchAddr to the *negative* index into chunks of that page. The trick here
// is that concurrent calls to find will fail to monotonically decrease searchAddr, and so they
// won't barge over new memory becoming available to scavenge. Furthermore, this ensures
// that some future caller of find *must* observe the new high index. That caller
// (or any other racing with it), then makes searchAddr positive before continuing, bringing
// us back to our monotonically decreasing steady-state.
//
// A pageAlloc lock serializes updates between min, max, and searchAddr, so abs(searchAddr)
// is always guaranteed to be >= min and < max (converted to heap addresses).
//
// TODO(mknyszek): Ideally we would use something bigger than a uint8 for faster
// iteration like uint32, but we lack the bit twiddling intrinsics. We'd need to either
// copy them from math/bits or fix the fact that we can't import math/bits' code from
// the runtime due to compiler instrumentation.
searchAddr atomicOffAddr
chunks []atomic.Uint8
minHeapIdx atomic.Int32
min, max atomic.Int32
}
// find returns the highest chunk index that may contain pages available to scavenge.
// It also returns an offset to start searching in the highest chunk.
func (s *scavengeIndex) find() (chunkIdx, uint) {
searchAddr, marked := s.searchAddr.Load()
if searchAddr == minOffAddr.addr() {
// We got a cleared search addr.
return 0, 0
}
// Starting from searchAddr's chunk, and moving down to minHeapIdx,
// iterate until we find a chunk with pages to scavenge.
min := s.minHeapIdx.Load()
searchChunk := chunkIndex(uintptr(searchAddr))
start := int32(searchChunk / 8)
for i := start; i >= min; i-- {
// Skip over irrelevant address space.
chunks := s.chunks[i].Load()
if chunks == 0 {
continue
}
// Note that we can't have 8 leading zeroes here because
// we necessarily skipped that case. So, what's left is
// an index. If there are no zeroes, we want the 7th
// index, if 1 zero, the 6th, and so on.
n := 7 - sys.LeadingZeros8(chunks)
ci := chunkIdx(uint(i)*8 + uint(n))
if searchChunk == ci {
return ci, chunkPageIndex(uintptr(searchAddr))
}
// Try to reduce searchAddr to newSearchAddr.
newSearchAddr := chunkBase(ci) + pallocChunkBytes - pageSize
if marked {
// Attempt to be the first one to decrease the searchAddr
// after an increase. If we fail, that means there was another
// increase, or somebody else got to it before us. Either way,
// it doesn't matter. We may lose some performance having an
// incorrect search address, but it's far more important that
// we don't miss updates.
s.searchAddr.StoreUnmark(searchAddr, newSearchAddr)
} else {
// Decrease searchAddr.
s.searchAddr.StoreMin(newSearchAddr)
}
return ci, pallocChunkPages - 1
}
// Clear searchAddr, because we've exhausted the heap.
s.searchAddr.Clear()
return 0, 0
}
// mark marks the chunks covering the address range [base, limit) as
// containing pages available to scavenge.
//
// Must be serialized with other mark, markRange, and clear calls.
func (s *scavengeIndex) mark(base, limit uintptr) {
start, end := chunkIndex(base), chunkIndex(limit-pageSize)
if start == end {
// Within a chunk.
mask := uint8(1 << (start % 8))
s.chunks[start/8].Or(mask)
} else if start/8 == end/8 {
// Within the same byte in the index.
mask := uint8(uint16(1<<(end-start+1))-1) << (start % 8)
s.chunks[start/8].Or(mask)
} else {
// Crosses multiple bytes in the index.
startAligned := chunkIdx(alignUp(uintptr(start), 8))
endAligned := chunkIdx(alignDown(uintptr(end), 8))
// Do the end of the first byte first.
if width := startAligned - start; width > 0 {
mask := uint8(uint16(1<<width)-1) << (start % 8)
s.chunks[start/8].Or(mask)
}
// Do the middle aligned sections that take up a whole
// byte.
for ci := startAligned; ci < endAligned; ci += 8 {
s.chunks[ci/8].Store(^uint8(0))
}
// Do the end of the last byte.
//
// This width check doesn't match the one above
// for start because aligning down into the endAligned
// block means we always have at least one chunk in this
// block (note that end is *inclusive*). This also means
// that if end == endAligned+n, then what we really want
// is to fill n+1 chunks, i.e. width n+1. By induction,
// this is true for all n.
if width := end - endAligned + 1; width > 0 {
mask := uint8(uint16(1<<width) - 1)
s.chunks[end/8].Or(mask)
}
}
newSearchAddr := limit - pageSize
searchAddr, _ := s.searchAddr.Load()
// N.B. Because mark is serialized, it's not necessary to do a
// full CAS here. mark only ever increases searchAddr, while
// find only ever decreases it. Since we only ever race with
// decreases, even if the value we loaded is stale, the actual
// value will never be larger.
if (offAddr{searchAddr}).lessThan(offAddr{newSearchAddr}) {
s.searchAddr.StoreMarked(newSearchAddr)
}
}
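// Worked example (illustrative indices, not from the original source):
// marking a range that maps to chunk indices start = 2 through end = 6
// stays within one byte of the index, so the mask is
// (1<<(6-2+1) - 1) << 2 = 0b01111100, setting bits 2..6 of chunks[0].
// A single-chunk range with start == end == 3 sets just bit 3. A range
// crossing byte boundaries is split into a leading partial byte, whole
// bytes stored as 0xff, and a trailing partial byte, as above.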
// clear sets the chunk at index ci as not containing pages available to scavenge.
//
// Must be serialized with other mark, markRange, and clear calls.
func (s *scavengeIndex) clear(ci chunkIdx) {
s.chunks[ci/8].And(^uint8(1 << (ci % 8)))
}
type piController struct {
kp float64 // Proportional constant.
ti float64 // Integral time constant.
tt float64 // Reset time.
min, max float64 // Output boundaries.
// PI controller state.
errIntegral float64 // Integral of the error from t=0 to now.
// Error flags.
errOverflow bool // Set if errIntegral ever overflowed.
inputOverflow bool // Set if an operation with the input overflowed.
}
// next provides a new sample to the controller.
//
// input is the sample, setpoint is the desired point, and period is how much
// time (in whatever unit makes the most sense) has passed since the last sample.
//
// Returns a new value for the variable it's controlling, and whether the operation
// completed successfully. One reason this might fail is if error has been growing
// in an unbounded manner, to the point of overflow.
//
// In the specific case where an error overflow occurs, the errOverflow field
// will be set and the rest of the controller's internal state will be fully reset.
func (c *piController) next(input, setpoint, period float64) (float64, bool) {
// Compute the raw output value.
prop := c.kp * (setpoint - input)
rawOutput := prop + c.errIntegral
// Clamp rawOutput into output.
output := rawOutput
if isInf(output) || isNaN(output) {
// The input had a large enough magnitude that either it was already
// overflowed, or some operation with it overflowed.
// Set a flag and reset. That's the safest thing to do.
c.reset()
c.inputOverflow = true
return c.min, false
}
if output < c.min {
output = c.min
} else if output > c.max {
output = c.max
}
// Update the controller's state.
if c.ti != 0 && c.tt != 0 {
c.errIntegral += (c.kp*period/c.ti)*(setpoint-input) + (period/c.tt)*(output-rawOutput)
if isInf(c.errIntegral) || isNaN(c.errIntegral) {
// So much error has accumulated that we managed to overflow.
// The assumptions around the controller have likely broken down.
// Set a flag and reset. That's the safest thing to do.
c.reset()
c.errOverflow = true
return c.min, false
}
}
return output, true
}
// reset resets the controller state, except for controller error flags.
func (c *piController) reset() {
c.errIntegral = 0
}
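// A minimal usage sketch (the constants below are hypothetical, chosen only
// to illustrate the API; see the scavenger's configuration for real values):
//
//	c := piController{
//		kp:  0.3375, // proportional gain (hypothetical)
//		ti:  3.2e6,  // integral time constant (hypothetical)
//		tt:  1e9,    // reset time for anti-windup (hypothetical)
//		min: 0.001,
//		max: 1000.0,
//	}
//	out, ok := c.next(measured, target, periodNS) // names are placeholders
//	if !ok {
//		// The controller overflowed and reset itself; out == c.min.
//	}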
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Garbage collector: stack objects and stack tracing
// See the design doc at https://docs.google.com/document/d/1un-Jn47yByHL7I0aVIP_uVCMxjdM5mpelJhiKlIqxkE/edit?usp=sharing
// Also see issue 22350.
// Stack tracing solves the problem of determining which parts of the
// stack are live and should be scanned. It runs as part of scanning
// a single goroutine stack.
//
// Normally determining which parts of the stack are live is easy to
// do statically, as user code has explicit references (reads and
// writes) to stack variables. The compiler can do a simple dataflow
// analysis to determine liveness of stack variables at every point in
// the code. See cmd/compile/internal/gc/plive.go for that analysis.
//
// However, when we take the address of a stack variable, determining
// whether that variable is still live is less clear. We can still
// look for static accesses, but accesses through a pointer to the
// variable are difficult in general to track statically. That pointer
// can be passed among functions on the stack, conditionally retained,
// etc.
//
// Instead, we will track pointers to stack variables dynamically.
// All pointers to stack-allocated variables will themselves be on the
// stack somewhere (or in associated locations, like defer records), so
// we can find them all efficiently.
//
// Stack tracing is organized as a mini garbage collection tracing
// pass. The objects in this garbage collection are all the variables
// on the stack whose address is taken, and which themselves contain a
// pointer. We call these variables "stack objects".
//
// We begin by determining all the stack objects on the stack and all
// the statically live pointers that may point into the stack. We then
// process each pointer to see if it points to a stack object. If it
// does, we scan that stack object. It may contain pointers into the
// heap, in which case those pointers are passed to the main garbage
// collection. It may also contain pointers into the stack, in which
// case we add them to our set of stack pointers.
//
// Once we're done processing all the pointers (including the ones we
// added during processing), we've found all the stack objects that
// are live. Any dead stack objects are not scanned and their contents
// will not keep heap objects live. Unlike the main garbage
// collection, we can't sweep the dead stack objects; they live on in
// a moribund state until the stack frame that contains them is
// popped.
//
// A stack can look like this:
//
// +----------+
// | foo() |
// | +------+ |
// | | A | | <---\
// | +------+ | |
// | | |
// | +------+ | |
// | | B | | |
// | +------+ | |
// | | |
// +----------+ |
// | bar() | |
// | +------+ | |
// | | C | | <-\ |
// | +----|-+ | | |
// | | | | |
// | +----v-+ | | |
// | | D ---------/
// | +------+ | |
// | | |
// +----------+ |
// | baz() | |
// | +------+ | |
// | | E -------/
// | +------+ |
// | ^ |
// | F: --/ |
// | |
// +----------+
//
// foo() calls bar() calls baz(). Each has a frame on the stack.
// foo() has stack objects A and B.
// bar() has stack objects C and D, with C pointing to D and D pointing to A.
// baz() has a stack object E pointing to C, and a local variable F pointing to E.
//
// Starting from the pointer in local variable F, we will eventually
// scan all of E, C, D, and A (in that order). B is never scanned
// because there is no live pointer to it. If B is also statically
// dead (meaning that foo() never accesses B again after it calls
// bar()), then B's pointers into the heap are not considered live.
package runtime
import (
"internal/goarch"
"runtime/internal/sys"
"unsafe"
)
const stackTraceDebug = false
// Buffer for pointers found during stack tracing.
// Must be smaller than or equal to workbuf.
type stackWorkBuf struct {
_ sys.NotInHeap
stackWorkBufHdr
obj [(_WorkbufSize - unsafe.Sizeof(stackWorkBufHdr{})) / goarch.PtrSize]uintptr
}
// Header declaration must come after the buf declaration above, because of issue #14620.
type stackWorkBufHdr struct {
_ sys.NotInHeap
workbufhdr
next *stackWorkBuf // linked list of workbufs
// Note: we could theoretically repurpose lfnode.next as this next pointer.
// It would save 1 word, but that probably isn't worth busting open
// the lfnode API.
}
// Buffer for stack objects found on a goroutine stack.
// Must be smaller than or equal to workbuf.
type stackObjectBuf struct {
_ sys.NotInHeap
stackObjectBufHdr
obj [(_WorkbufSize - unsafe.Sizeof(stackObjectBufHdr{})) / unsafe.Sizeof(stackObject{})]stackObject
}
type stackObjectBufHdr struct {
_ sys.NotInHeap
workbufhdr
next *stackObjectBuf
}
func init() {
if unsafe.Sizeof(stackWorkBuf{}) > unsafe.Sizeof(workbuf{}) {
panic("stackWorkBuf too big")
}
if unsafe.Sizeof(stackObjectBuf{}) > unsafe.Sizeof(workbuf{}) {
panic("stackObjectBuf too big")
}
}
// A stackObject represents a variable on the stack that has had
// its address taken.
type stackObject struct {
_ sys.NotInHeap
off uint32 // offset above stack.lo
size uint32 // size of object
r *stackObjectRecord // info of the object (for ptr/nonptr bits). nil if object has been scanned.
left *stackObject // objects with lower addresses
right *stackObject // objects with higher addresses
}
// obj.r = r, but with no write barrier.
//
//go:nowritebarrier
func (obj *stackObject) setRecord(r *stackObjectRecord) {
// Types of stack objects are always in read-only memory, not the heap.
// So not using a write barrier is ok.
*(*uintptr)(unsafe.Pointer(&obj.r)) = uintptr(unsafe.Pointer(r))
}
// A stackScanState keeps track of the state used during the GC walk
// of a goroutine.
type stackScanState struct {
cache pcvalueCache
// stack limits
stack stack
// conservative indicates that the next frame must be scanned conservatively.
// This applies only to the innermost frame at an async safe-point.
conservative bool
// buf contains the set of possible pointers to stack objects.
// Organized as a LIFO linked list of buffers.
// All buffers except possibly the head buffer are full.
buf *stackWorkBuf
freeBuf *stackWorkBuf // keep around one free buffer for allocation hysteresis
// cbuf contains conservative pointers to stack objects. If
// all pointers to a stack object are obtained via
// conservative scanning, then the stack object may be dead
// and may contain dead pointers, so it must be scanned
// defensively.
cbuf *stackWorkBuf
// list of stack objects
// Objects are in increasing address order.
head *stackObjectBuf
tail *stackObjectBuf
nobjs int
// root of binary tree for fast object lookup by address
// Initialized by buildIndex.
root *stackObject
}
// Add p as a potential pointer to a stack object.
// p must be a stack address.
func (s *stackScanState) putPtr(p uintptr, conservative bool) {
if p < s.stack.lo || p >= s.stack.hi {
throw("address not a stack address")
}
head := &s.buf
if conservative {
head = &s.cbuf
}
buf := *head
if buf == nil {
// Initial setup.
buf = (*stackWorkBuf)(unsafe.Pointer(getempty()))
buf.nobj = 0
buf.next = nil
*head = buf
} else if buf.nobj == len(buf.obj) {
if s.freeBuf != nil {
buf = s.freeBuf
s.freeBuf = nil
} else {
buf = (*stackWorkBuf)(unsafe.Pointer(getempty()))
}
buf.nobj = 0
buf.next = *head
*head = buf
}
buf.obj[buf.nobj] = p
buf.nobj++
}
// Remove and return a potential pointer to a stack object.
// Returns 0 if there are no more pointers available.
//
// This prefers non-conservative pointers so we scan stack objects
// precisely if there are any non-conservative pointers to them.
func (s *stackScanState) getPtr() (p uintptr, conservative bool) {
for _, head := range []**stackWorkBuf{&s.buf, &s.cbuf} {
buf := *head
if buf == nil {
// Never had any data.
continue
}
if buf.nobj == 0 {
if s.freeBuf != nil {
// Free old freeBuf.
putempty((*workbuf)(unsafe.Pointer(s.freeBuf)))
}
// Move buf to the freeBuf.
s.freeBuf = buf
buf = buf.next
*head = buf
if buf == nil {
// No more data in this list.
continue
}
}
buf.nobj--
return buf.obj[buf.nobj], head == &s.cbuf
}
// No more data in either list.
if s.freeBuf != nil {
putempty((*workbuf)(unsafe.Pointer(s.freeBuf)))
s.freeBuf = nil
}
return 0, false
}
// addObject adds a stack object at addr, described by record r, to the set of stack objects.
func (s *stackScanState) addObject(addr uintptr, r *stackObjectRecord) {
x := s.tail
if x == nil {
// initial setup
x = (*stackObjectBuf)(unsafe.Pointer(getempty()))
x.next = nil
s.head = x
s.tail = x
}
if x.nobj > 0 && uint32(addr-s.stack.lo) < x.obj[x.nobj-1].off+x.obj[x.nobj-1].size {
throw("objects added out of order or overlapping")
}
if x.nobj == len(x.obj) {
// full buffer - allocate a new buffer, add to end of linked list
y := (*stackObjectBuf)(unsafe.Pointer(getempty()))
y.next = nil
x.next = y
s.tail = y
x = y
}
obj := &x.obj[x.nobj]
x.nobj++
obj.off = uint32(addr - s.stack.lo)
obj.size = uint32(r.size)
obj.setRecord(r)
// obj.left and obj.right will be initialized by buildIndex before use.
s.nobjs++
}
// buildIndex initializes s.root to a binary search tree.
// It should be called after all addObject calls but before
// any call of findObject.
func (s *stackScanState) buildIndex() {
s.root, _, _ = binarySearchTree(s.head, 0, s.nobjs)
}
// Build a binary search tree with the n objects in the list
// x.obj[idx], x.obj[idx+1], ..., x.next.obj[0], ...
// Returns the root of that tree, and the buf+idx of the nth object after x.obj[idx].
// (The first object that was not included in the binary search tree.)
// If n == 0, returns nil, x, idx.
func binarySearchTree(x *stackObjectBuf, idx int, n int) (root *stackObject, restBuf *stackObjectBuf, restIdx int) {
if n == 0 {
return nil, x, idx
}
var left, right *stackObject
left, x, idx = binarySearchTree(x, idx, n/2)
root = &x.obj[idx]
idx++
if idx == len(x.obj) {
x = x.next
idx = 0
}
right, x, idx = binarySearchTree(x, idx, n-n/2-1)
root.left = left
root.right = right
return root, x, idx
}
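// For example (worked out from the recursion above): with n = 3 objects
// o0, o1, o2 stored in order, the left half {o0} is built first, o1
// becomes the root, and the right half {o2} becomes its right child,
// yielding a balanced tree with o1.left = o0 and o1.right = o2.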
// findObject returns the stack object containing address a, if any.
// Must have called buildIndex previously.
func (s *stackScanState) findObject(a uintptr) *stackObject {
off := uint32(a - s.stack.lo)
obj := s.root
for {
if obj == nil {
return nil
}
if off < obj.off {
obj = obj.left
continue
}
if off >= obj.off+obj.size {
obj = obj.right
continue
}
return obj
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Garbage collector: sweeping
// The sweeper consists of two different algorithms:
//
// * The object reclaimer finds and frees unmarked slots in spans. It
// can free a whole span if none of the objects are marked, but that
// isn't its goal. This can be driven either synchronously by
// mcentral.cacheSpan for mcentral spans, or asynchronously by
// sweepone, which looks at all the mcentral lists.
//
// * The span reclaimer looks for spans that contain no marked objects
// and frees whole spans. This is a separate algorithm because
// freeing whole spans is the hardest task for the object reclaimer,
// but is critical when allocating new spans. The entry point for
// this is mheap_.reclaim and it's driven by a sequential scan of
// the page marks bitmap in the heap arenas.
//
// Both algorithms ultimately call mspan.sweep, which sweeps a single
// heap span.
package runtime
import (
"runtime/internal/atomic"
"unsafe"
)
var sweep sweepdata
// State of background sweep.
type sweepdata struct {
lock mutex
g *g
parked bool
nbgsweep uint32
npausesweep uint32
// active tracks outstanding sweepers and the sweep
// termination condition.
active activeSweep
// centralIndex is the current unswept span class.
// It represents an index into the mcentral span
// sets. Accessed and updated via its load and
// update methods. Not protected by a lock.
//
// Reset at mark termination.
// Used by mheap.nextSpanForSweep.
centralIndex sweepClass
}
// sweepClass is a spanClass and one bit to represent whether we're currently
// sweeping partial or full spans.
type sweepClass uint32
const (
numSweepClasses = numSpanClasses * 2
sweepClassDone sweepClass = sweepClass(^uint32(0))
)
func (s *sweepClass) load() sweepClass {
return sweepClass(atomic.Load((*uint32)(s)))
}
func (s *sweepClass) update(sNew sweepClass) {
// Only update *s if its current value is less than sNew,
// since *s increases monotonically.
sOld := s.load()
for sOld < sNew && !atomic.Cas((*uint32)(s), uint32(sOld), uint32(sNew)) {
sOld = s.load()
}
// TODO(mknyszek): This isn't the only place we have
// an atomic monotonically increasing counter. It would
// be nice to have an "atomic max" which is just implemented
// as the above on most architectures. Some architectures
// like RISC-V however have native support for an atomic max.
}
func (s *sweepClass) clear() {
atomic.Store((*uint32)(s), 0)
}
// split returns the underlying span class as well as
// whether we're interested in the full or partial
// unswept lists for that class, indicated as a boolean
// (true means "full").
func (s sweepClass) split() (spc spanClass, full bool) {
return spanClass(s >> 1), s&1 == 0
}
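// For example, sweepClass(4) splits into spanClass(2) with full == true
// (low bit 0), while sweepClass(5) splits into spanClass(2) with
// full == false: each span class contributes two consecutive sweep
// classes, with the full list visited first.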
// nextSpanForSweep finds and pops the next span for sweeping from the
// central sweep buffers. It returns ownership of the span to the caller.
// Returns nil if no such span exists.
func (h *mheap) nextSpanForSweep() *mspan {
sg := h.sweepgen
for sc := sweep.centralIndex.load(); sc < numSweepClasses; sc++ {
spc, full := sc.split()
c := &h.central[spc].mcentral
var s *mspan
if full {
s = c.fullUnswept(sg).pop()
} else {
s = c.partialUnswept(sg).pop()
}
if s != nil {
// Write down that we found something so future sweepers
// can start from here.
sweep.centralIndex.update(sc)
return s
}
}
// Write down that we found nothing.
sweep.centralIndex.update(sweepClassDone)
return nil
}
const sweepDrainedMask = 1 << 31
// activeSweep is a type that captures whether sweeping
// is done, and whether there are any outstanding sweepers.
//
// Every potential sweeper must call begin() before they look
// for work, and end() after they've finished sweeping.
type activeSweep struct {
// state is divided into two parts.
//
// The top bit (masked by sweepDrainedMask) is a boolean
// value indicating whether all the sweep work has been
// drained from the queue.
//
// The rest of the bits are a counter, indicating the
// number of outstanding concurrent sweepers.
state atomic.Uint32
}
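// For example, state == 0 means sweep work remains and there are no
// active sweepers; state == 3 means three outstanding sweepers; and
// state == sweepDrainedMask means the queue is drained and no sweepers
// remain, which is exactly the condition isDone reports below.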
// begin registers a new sweeper. Returns a sweepLocker
// for acquiring spans for sweeping. Any outstanding sweeper blocks
// sweep termination.
//
// If the sweepLocker is invalid, the caller can be sure that all
// outstanding sweep work has been drained, so there is nothing left
// to sweep. Note that there may be sweepers currently running, so
// this does not indicate that all sweeping has completed.
//
// Even if the sweepLocker is invalid, its sweepGen is always valid.
func (a *activeSweep) begin() sweepLocker {
for {
state := a.state.Load()
if state&sweepDrainedMask != 0 {
return sweepLocker{mheap_.sweepgen, false}
}
if a.state.CompareAndSwap(state, state+1) {
return sweepLocker{mheap_.sweepgen, true}
}
}
}
// end deregisters a sweeper. Must be called once for each time
// begin is called if the sweepLocker is valid.
func (a *activeSweep) end(sl sweepLocker) {
if sl.sweepGen != mheap_.sweepgen {
throw("sweeper left outstanding across sweep generations")
}
for {
state := a.state.Load()
if (state&^sweepDrainedMask)-1 >= sweepDrainedMask {
throw("mismatched begin/end of activeSweep")
}
if a.state.CompareAndSwap(state, state-1) {
if state != sweepDrainedMask {
return
}
if debug.gcpacertrace > 0 {
live := gcController.heapLive.Load()
print("pacer: sweep done at heap size ", live>>20, "MB; allocated ", (live-mheap_.sweepHeapLiveBasis)>>20, "MB during sweep; swept ", mheap_.pagesSwept.Load(), " pages at ", mheap_.sweepPagesPerByte, " pages/byte\n")
}
return
}
}
}
// markDrained marks the active sweep cycle as having drained
// all remaining work. This is safe to be called concurrently
// with all other methods of activeSweep, though may race.
//
// Returns true if this call was the one that actually performed
// the mark.
func (a *activeSweep) markDrained() bool {
for {
state := a.state.Load()
if state&sweepDrainedMask != 0 {
return false
}
if a.state.CompareAndSwap(state, state|sweepDrainedMask) {
return true
}
}
}
// sweepers returns the current number of active sweepers.
func (a *activeSweep) sweepers() uint32 {
return a.state.Load() &^ sweepDrainedMask
}
// isDone returns true if all sweep work has been drained and no more
// outstanding sweepers exist. That is, when the sweep phase is
// completely done.
func (a *activeSweep) isDone() bool {
return a.state.Load() == sweepDrainedMask
}
// reset sets up the activeSweep for the next sweep cycle.
//
// The world must be stopped.
func (a *activeSweep) reset() {
assertWorldStopped()
a.state.Store(0)
}
// finishsweep_m ensures that all spans are swept.
//
// The world must be stopped. This ensures there are no sweeps in
// progress.
//
//go:nowritebarrier
func finishsweep_m() {
assertWorldStopped()
// Sweeping must be complete before marking commences, so
// sweep any unswept spans. If this is a concurrent GC, there
// shouldn't be any spans left to sweep, so this should finish
// instantly. If GC was forced before the concurrent sweep
// finished, there may be spans to sweep.
for sweepone() != ^uintptr(0) {
sweep.npausesweep++
}
// Make sure there aren't any outstanding sweepers left.
// At this point, with the world stopped, it means one of two
// things. Either we were able to preempt a sweeper, or a
// sweeper didn't call sweep.active.end when it should have.
// Both cases indicate a bug, so throw.
if sweep.active.sweepers() != 0 {
throw("active sweepers found at start of mark phase")
}
// Reset all the unswept buffers, which should be empty.
// Do this in sweep termination as opposed to mark termination
// so that we can catch unswept spans and reclaim blocks as
// soon as possible.
sg := mheap_.sweepgen
for i := range mheap_.central {
c := &mheap_.central[i].mcentral
c.partialUnswept(sg).reset()
c.fullUnswept(sg).reset()
}
// Sweeping is done, so if the scavenger isn't already awake,
// wake it up. There's definitely work for it to do at this
// point.
scavenger.wake()
nextMarkBitArenaEpoch()
}
func bgsweep(c chan int) {
sweep.g = getg()
lockInit(&sweep.lock, lockRankSweep)
lock(&sweep.lock)
sweep.parked = true
c <- 1
goparkunlock(&sweep.lock, waitReasonGCSweepWait, traceEvGoBlock, 1)
for {
// bgsweep attempts to be a "low priority" goroutine by intentionally
// yielding time. It's OK if it doesn't run, because goroutines allocating
// memory will sweep and ensure that all spans are swept before the next
// GC cycle. We really only want to run when we're idle.
//
// However, calling Gosched after each span swept produces a tremendous
// amount of tracing events, sometimes up to 50% of events in a trace. It's
// also inefficient to call into the scheduler so much because sweeping a
// single span is in general a very fast operation, taking as little as 30 ns
// on modern hardware. (See #54767.)
//
// As a result, bgsweep sweeps in batches, and only calls into the scheduler
// at the end of every batch. Furthermore, it only yields its time if there
// isn't spare idle time available on other cores. If there's available idle
// time, helping to sweep can reduce allocation latencies by getting ahead of
// the proportional sweeper and having spans ready to go for allocation.
const sweepBatchSize = 10
nSwept := 0
for sweepone() != ^uintptr(0) {
sweep.nbgsweep++
nSwept++
if nSwept%sweepBatchSize == 0 {
goschedIfBusy()
}
}
for freeSomeWbufs(true) {
// N.B. freeSomeWbufs is already batched internally.
goschedIfBusy()
}
lock(&sweep.lock)
if !isSweepDone() {
// This can happen if a GC runs between
// sweepone returning ^0 above
// and the lock being acquired.
unlock(&sweep.lock)
continue
}
sweep.parked = true
goparkunlock(&sweep.lock, waitReasonGCSweepWait, traceEvGoBlock, 1)
}
}
// sweepLocker acquires sweep ownership of spans.
type sweepLocker struct {
// sweepGen is the sweep generation of the heap.
sweepGen uint32
valid bool
}
// sweepLocked represents sweep ownership of a span.
type sweepLocked struct {
*mspan
}
// tryAcquire attempts to acquire sweep ownership of span s. If it
// successfully acquires ownership, it blocks sweep completion.
func (l *sweepLocker) tryAcquire(s *mspan) (sweepLocked, bool) {
if !l.valid {
throw("use of invalid sweepLocker")
}
// Check before attempting to CAS.
if atomic.Load(&s.sweepgen) != l.sweepGen-2 {
return sweepLocked{}, false
}
// Attempt to acquire sweep ownership of s.
if !atomic.Cas(&s.sweepgen, l.sweepGen-2, l.sweepGen-1) {
return sweepLocked{}, false
}
return sweepLocked{s}, true
}
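// For reference, the sweepgen values checked above follow the mspan
// convention documented in mheap.go: if s.sweepgen == h.sweepgen-2 the
// span needs sweeping, h.sweepgen-1 means it is currently being swept,
// and h.sweepgen means it is swept and ready to use. h.sweepgen+1 means
// the span was cached before sweeping began and still needs sweeping,
// while h.sweepgen+3 means it was swept and then cached. h.sweepgen is
// incremented by 2 after every GC.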
// sweepone sweeps some unswept heap span and returns the number of pages returned
// to the heap, or ^uintptr(0) if there was nothing to sweep.
func sweepone() uintptr {
gp := getg()
// Increment locks to ensure that the goroutine is not preempted
// in the middle of sweep thus leaving the span in an inconsistent state for next GC
gp.m.locks++
// TODO(austin): sweepone is almost always called in a loop;
// lift the sweepLocker into its callers.
sl := sweep.active.begin()
if !sl.valid {
gp.m.locks--
return ^uintptr(0)
}
// Find a span to sweep.
npages := ^uintptr(0)
var noMoreWork bool
for {
s := mheap_.nextSpanForSweep()
if s == nil {
noMoreWork = sweep.active.markDrained()
break
}
if state := s.state.get(); state != mSpanInUse {
// This can happen if direct sweeping already
// swept this span, but in that case the sweep
// generation should always be up-to-date.
if !(s.sweepgen == sl.sweepGen || s.sweepgen == sl.sweepGen+3) {
print("runtime: bad span s.state=", state, " s.sweepgen=", s.sweepgen, " sweepgen=", sl.sweepGen, "\n")
throw("non in-use span in unswept list")
}
continue
}
if s, ok := sl.tryAcquire(s); ok {
// Sweep the span we found.
npages = s.npages
if s.sweep(false) {
// Whole span was freed. Count it toward the
// page reclaimer credit since these pages can
// now be used for span allocation.
mheap_.reclaimCredit.Add(npages)
} else {
// Span is still in-use, so this returned no
// pages to the heap and the span needs to
// move to the swept in-use list.
npages = 0
}
break
}
}
sweep.active.end(sl)
if noMoreWork {
// The sweep list is empty. There may still be
// concurrent sweeps running, but we're at least very
// close to done sweeping.
// Move the scavenge gen forward (signaling
// that there's new work to do) and wake the scavenger.
//
// The scavenger is signaled by the last sweeper because once
// sweeping is done, we will definitely have useful work for
// the scavenger to do, since the scavenger only runs over the
// heap once per GC cycle. This update is not done during sweep
// termination because in some cases there may be a long delay
// between sweep done and sweep termination (e.g. not enough
// allocations to trigger a GC) which would be nice to fill in
// with scavenging work.
if debug.scavtrace > 0 {
systemstack(func() {
lock(&mheap_.lock)
released := atomic.Loaduintptr(&mheap_.pages.scav.released)
printScavTrace(released, false)
atomic.Storeuintptr(&mheap_.pages.scav.released, 0)
unlock(&mheap_.lock)
})
}
scavenger.ready()
}
gp.m.locks--
return npages
}
// isSweepDone reports whether all spans are swept.
//
// Note that this condition may transition from false to true at any
// time as the sweeper runs. It may transition from true to false if a
// GC runs; to prevent that the caller must be non-preemptible or must
// somehow block GC progress.
func isSweepDone() bool {
return sweep.active.isDone()
}
// Returns only when span s has been swept.
//
//go:nowritebarrier
func (s *mspan) ensureSwept() {
// Caller must disable preemption.
// Otherwise when this function returns the span can become unswept again
// (if GC is triggered on another goroutine).
gp := getg()
if gp.m.locks == 0 && gp.m.mallocing == 0 && gp != gp.m.g0 {
throw("mspan.ensureSwept: m is not locked")
}
// If this operation fails, then that means that there are
// no more spans to be swept. In this case, either s has already
// been swept, or is about to be acquired for sweeping and swept.
sl := sweep.active.begin()
if sl.valid {
// The caller must be sure that the span is a mSpanInUse span.
if s, ok := sl.tryAcquire(s); ok {
s.sweep(false)
sweep.active.end(sl)
return
}
sweep.active.end(sl)
}
// Unfortunately we can't sweep the span ourselves. Somebody else
// got to it first. We don't have efficient means to wait, but that's
// OK, it will be swept fairly soon.
for {
spangen := atomic.Load(&s.sweepgen)
if spangen == sl.sweepGen || spangen == sl.sweepGen+3 {
break
}
osyield()
}
}
// sweep frees or collects finalizers for blocks not marked in the mark phase.
// It clears the mark bits in preparation for the next GC round.
// Returns true if the span was returned to heap.
// If preserve=true, don't return it to heap nor relink in mcentral lists;
// caller takes care of it.
func (sl *sweepLocked) sweep(preserve bool) bool {
// It's critical that we enter this function with preemption disabled;
// GC must not start while we are in the middle of this function.
gp := getg()
if gp.m.locks == 0 && gp.m.mallocing == 0 && gp != gp.m.g0 {
throw("mspan.sweep: m is not locked")
}
s := sl.mspan
if !preserve {
// We'll release ownership of this span. Nil it out to
// prevent the caller from accidentally using it.
sl.mspan = nil
}
sweepgen := mheap_.sweepgen
if state := s.state.get(); state != mSpanInUse || s.sweepgen != sweepgen-1 {
print("mspan.sweep: state=", state, " sweepgen=", s.sweepgen, " mheap.sweepgen=", sweepgen, "\n")
throw("mspan.sweep: bad span state")
}
if trace.enabled {
traceGCSweepSpan(s.npages * _PageSize)
}
mheap_.pagesSwept.Add(int64(s.npages))
spc := s.spanclass
size := s.elemsize
// The allocBits indicate which unmarked objects don't need to be
// processed since they were free at the end of the last GC cycle
// and were not allocated since then.
// If the allocBits index is >= s.freeindex and the bit
// is not marked then the object remains unallocated
// since the last GC.
// This situation is analogous to being on a freelist.
// Unlink & free special records for any objects we're about to free.
// Two complications here:
// 1. An object can have both finalizer and profile special records.
// In such case we need to queue finalizer for execution,
// mark the object as live and preserve the profile special.
// 2. A tiny object can have several finalizers setup for different offsets.
// If such object is not marked, we need to queue all finalizers at once.
// Both 1 and 2 are possible at the same time.
hadSpecials := s.specials != nil
siter := newSpecialsIter(s)
for siter.valid() {
// A finalizer can be set for an inner byte of an object, find object beginning.
objIndex := uintptr(siter.s.offset) / size
p := s.base() + objIndex*size
mbits := s.markBitsForIndex(objIndex)
if !mbits.isMarked() {
// This object is not marked and has at least one special record.
// Pass 1: see if it has at least one finalizer.
hasFin := false
endOffset := p - s.base() + size
for tmp := siter.s; tmp != nil && uintptr(tmp.offset) < endOffset; tmp = tmp.next {
if tmp.kind == _KindSpecialFinalizer {
// Stop freeing of object if it has a finalizer.
mbits.setMarkedNonAtomic()
hasFin = true
break
}
}
// Pass 2: queue all finalizers _or_ handle profile record.
for siter.valid() && uintptr(siter.s.offset) < endOffset {
// Find the exact byte for which the special was setup
// (as opposed to object beginning).
special := siter.s
p := s.base() + uintptr(special.offset)
if special.kind == _KindSpecialFinalizer || !hasFin {
siter.unlinkAndNext()
freeSpecial(special, unsafe.Pointer(p), size)
} else {
// The object has finalizers, so we're keeping it alive.
// All other specials only apply when an object is freed,
// so just keep the special record.
siter.next()
}
}
} else {
// object is still live
if siter.s.kind == _KindSpecialReachable {
special := siter.unlinkAndNext()
(*specialReachable)(unsafe.Pointer(special)).reachable = true
freeSpecial(special, unsafe.Pointer(p), size)
} else {
// keep special record
siter.next()
}
}
}
if hadSpecials && s.specials == nil {
spanHasNoSpecials(s)
}
if debug.allocfreetrace != 0 || debug.clobberfree != 0 || raceenabled || msanenabled || asanenabled {
// Find all newly freed objects. This doesn't have to be
// efficient; allocfreetrace has massive overhead.
mbits := s.markBitsForBase()
abits := s.allocBitsForIndex(0)
for i := uintptr(0); i < s.nelems; i++ {
if !mbits.isMarked() && (abits.index < s.freeindex || abits.isMarked()) {
x := s.base() + i*s.elemsize
if debug.allocfreetrace != 0 {
tracefree(unsafe.Pointer(x), size)
}
if debug.clobberfree != 0 {
clobberfree(unsafe.Pointer(x), size)
}
// User arenas are handled on explicit free.
if raceenabled && !s.isUserArenaChunk {
racefree(unsafe.Pointer(x), size)
}
if msanenabled && !s.isUserArenaChunk {
msanfree(unsafe.Pointer(x), size)
}
if asanenabled && !s.isUserArenaChunk {
asanpoison(unsafe.Pointer(x), size)
}
}
mbits.advance()
abits.advance()
}
}
// Check for zombie objects.
if s.freeindex < s.nelems {
// Everything < freeindex is allocated and hence
// cannot be zombies.
//
// Check the first bitmap byte, where we have to be
// careful with freeindex.
obj := s.freeindex
if (*s.gcmarkBits.bytep(obj / 8)&^*s.allocBits.bytep(obj / 8))>>(obj%8) != 0 {
s.reportZombies()
}
// Check remaining bytes.
for i := obj/8 + 1; i < divRoundUp(s.nelems, 8); i++ {
if *s.gcmarkBits.bytep(i)&^*s.allocBits.bytep(i) != 0 {
s.reportZombies()
}
}
}
// Count the number of free objects in this span.
nalloc := uint16(s.countAlloc())
nfreed := s.allocCount - nalloc
if nalloc > s.allocCount {
// The zombie check above should have caught this in
// more detail.
print("runtime: nelems=", s.nelems, " nalloc=", nalloc, " previous allocCount=", s.allocCount, " nfreed=", nfreed, "\n")
throw("sweep increased allocation count")
}
s.allocCount = nalloc
s.freeindex = 0 // reset allocation index to start of span.
s.freeIndexForScan = 0
if trace.enabled {
getg().m.p.ptr().traceReclaimed += uintptr(nfreed) * s.elemsize
}
// gcmarkBits becomes the allocBits.
// get a fresh cleared gcmarkBits in preparation for next GC
s.allocBits = s.gcmarkBits
s.gcmarkBits = newMarkBits(s.nelems)
// Initialize alloc bits cache.
s.refillAllocCache(0)
// The span must be in our exclusive ownership until we update sweepgen,
// check for potential races.
if state := s.state.get(); state != mSpanInUse || s.sweepgen != sweepgen-1 {
print("mspan.sweep: state=", state, " sweepgen=", s.sweepgen, " mheap.sweepgen=", sweepgen, "\n")
throw("mspan.sweep: bad span state after sweep")
}
if s.sweepgen == sweepgen+1 || s.sweepgen == sweepgen+3 {
throw("swept cached span")
}
// We need to set s.sweepgen = h.sweepgen only when all blocks are swept,
// because of the potential for a concurrent free/SetFinalizer.
//
// But we need to set it before we make the span available for allocation
// (return it to heap or mcentral), because allocation code assumes that a
// span is already swept if available for allocation.
//
// Serialization point.
// At this point the mark bits are cleared and allocation ready
// to go so release the span.
atomic.Store(&s.sweepgen, sweepgen)
if s.isUserArenaChunk {
if preserve {
// This is a case that should never be handled by a sweeper that
// preserves the span for reuse.
throw("sweep: tried to preserve a user arena span")
}
if nalloc > 0 {
// There still exist pointers into the span or the span hasn't been
// freed yet. It's not ready to be reused. Put it back on the
// full swept list for the next cycle.
mheap_.central[spc].mcentral.fullSwept(sweepgen).push(s)
return false
}
// It's only at this point that the sweeper doesn't actually need to look
// at this arena anymore, so subtract from pagesInUse now.
mheap_.pagesInUse.Add(-s.npages)
s.state.set(mSpanDead)
// The arena is ready to be recycled. Remove it from the quarantine list
// and place it on the ready list. Don't add it back to any sweep lists.
systemstack(func() {
// It's the arena code's responsibility to get the chunk on the quarantine
// list by the time all references to the chunk are gone.
if s.list != &mheap_.userArena.quarantineList {
throw("user arena span is on the wrong list")
}
lock(&mheap_.lock)
mheap_.userArena.quarantineList.remove(s)
mheap_.userArena.readyList.insert(s)
unlock(&mheap_.lock)
})
return false
}
if spc.sizeclass() != 0 {
// Handle spans for small objects.
if nfreed > 0 {
// Only mark the span as needing zeroing if we've freed any
// objects, because a fresh span that had been allocated into,
// wasn't totally filled, but then swept, still has all of its
// free slots zeroed.
s.needzero = 1
stats := memstats.heapStats.acquire()
atomic.Xadd64(&stats.smallFreeCount[spc.sizeclass()], int64(nfreed))
memstats.heapStats.release()
// Count the frees in the inconsistent, internal stats.
gcController.totalFree.Add(int64(nfreed) * int64(s.elemsize))
}
if !preserve {
// The caller may not have removed this span from whatever
// unswept set it's on but taken ownership of the span for
// sweeping by updating sweepgen. If this span still is in
// an unswept set, then the mcentral will pop it off the
// set, check its sweepgen, and ignore it.
if nalloc == 0 {
// Free totally free span directly back to the heap.
mheap_.freeSpan(s)
return true
}
// Return span back to the right mcentral list.
if uintptr(nalloc) == s.nelems {
mheap_.central[spc].mcentral.fullSwept(sweepgen).push(s)
} else {
mheap_.central[spc].mcentral.partialSwept(sweepgen).push(s)
}
}
} else if !preserve {
// Handle spans for large objects.
if nfreed != 0 {
// Free large object span to heap.
// NOTE(rsc,dvyukov): The original implementation of efence
// in CL 22060046 used sysFree instead of sysFault, so that
// the operating system would eventually give the memory
// back to us again, so that an efence program could run
// longer without running out of memory. Unfortunately,
// calling sysFree here without any kind of adjustment of the
// heap data structures means that when the memory does
// come back to us, we have the wrong metadata for it, either in
// the mspan structures or in the garbage collection bitmap.
// Using sysFault here means that the program will run out of
// memory fairly quickly in efence mode, but at least it won't
// have mysterious crashes due to confused memory reuse.
// It should be possible to switch back to sysFree if we also
// implement and then call some kind of mheap.deleteSpan.
if debug.efence > 0 {
s.limit = 0 // prevent mlookup from finding this span
sysFault(unsafe.Pointer(s.base()), size)
} else {
mheap_.freeSpan(s)
}
// Count the free in the consistent, external stats.
stats := memstats.heapStats.acquire()
atomic.Xadd64(&stats.largeFreeCount, 1)
atomic.Xadd64(&stats.largeFree, int64(size))
memstats.heapStats.release()
// Count the free in the inconsistent, internal stats.
gcController.totalFree.Add(int64(size))
return true
}
// Add a large span directly onto the full+swept list.
mheap_.central[spc].mcentral.fullSwept(sweepgen).push(s)
}
return false
}
// reportZombies reports any marked but free objects in s and throws.
//
// This generally means one of the following:
//
// 1. User code converted a pointer to a uintptr and then back
// unsafely, and a GC ran while the uintptr was the only reference to
// an object.
//
// 2. User code (or a compiler bug) constructed a bad pointer that
// points to a free slot, often a past-the-end pointer.
//
// 3. The GC two cycles ago missed a pointer and freed a live object,
// but it was still live in the last cycle, so this GC cycle found a
// pointer to that object and marked it.
func (s *mspan) reportZombies() {
printlock()
print("runtime: marked free object in span ", s, ", elemsize=", s.elemsize, " freeindex=", s.freeindex, " (bad use of unsafe.Pointer? try -d=checkptr)\n")
mbits := s.markBitsForBase()
abits := s.allocBitsForIndex(0)
for i := uintptr(0); i < s.nelems; i++ {
addr := s.base() + i*s.elemsize
print(hex(addr))
alloc := i < s.freeindex || abits.isMarked()
if alloc {
print(" alloc")
} else {
print(" free ")
}
if mbits.isMarked() {
print(" marked ")
} else {
print(" unmarked")
}
zombie := mbits.isMarked() && !alloc
if zombie {
print(" zombie")
}
print("\n")
if zombie {
length := s.elemsize
if length > 1024 {
length = 1024
}
hexdumpWords(addr, addr+length, nil)
}
mbits.advance()
abits.advance()
}
throw("found pointer to free object")
}
// deductSweepCredit deducts sweep credit for allocating a span of
// size spanBytes. This must be performed *before* the span is
// allocated to ensure the system has enough credit. If necessary, it
// performs sweeping to prevent going in to debt. If the caller will
// also sweep pages (e.g., for a large allocation), it can pass a
// non-zero callerSweepPages to leave that many pages unswept.
//
// deductSweepCredit makes a worst-case assumption that all spanBytes
// bytes of the ultimately allocated span will be available for object
// allocation.
//
// deductSweepCredit is the core of the "proportional sweep" system.
// It uses statistics gathered by the garbage collector to perform
// enough sweeping so that all pages are swept during the concurrent
// sweep phase between GC cycles.
//
// mheap_ must NOT be locked.
func deductSweepCredit(spanBytes uintptr, callerSweepPages uintptr) {
if mheap_.sweepPagesPerByte == 0 {
// Proportional sweep is done or disabled.
return
}
if trace.enabled {
traceGCSweepStart()
}
// Fix debt if necessary.
retry:
sweptBasis := mheap_.pagesSweptBasis.Load()
live := gcController.heapLive.Load()
liveBasis := mheap_.sweepHeapLiveBasis
newHeapLive := spanBytes
if liveBasis < live {
// Only do this subtraction when we don't overflow. Otherwise, pagesTarget
// might be computed as something really huge, causing us to get stuck
// sweeping here until the next mark phase.
//
// Overflow can happen here if gcPaceSweeper is called concurrently with
// sweeping (i.e. not during a STW, like it usually is) because this code
// is intentionally racy. A concurrent call to gcPaceSweeper can happen
// if a GC tuning parameter is modified and we read an older value of
// heapLive than what was used to set the basis.
//
// This state should be transient, so it's fine to just let newHeapLive
// be a relatively small number. We'll probably just skip this attempt to
// sweep.
//
// See issue #57523.
newHeapLive += uintptr(live - liveBasis)
}
pagesTarget := int64(mheap_.sweepPagesPerByte*float64(newHeapLive)) - int64(callerSweepPages)
for pagesTarget > int64(mheap_.pagesSwept.Load()-sweptBasis) {
if sweepone() == ^uintptr(0) {
mheap_.sweepPagesPerByte = 0
break
}
if mheap_.pagesSweptBasis.Load() != sweptBasis {
// Sweep pacing changed. Recompute debt.
goto retry
}
}
if trace.enabled {
traceGCSweepDone()
}
}
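// To make the debt computation above concrete (illustrative numbers only,
// not constants from this file): with sweepPagesPerByte = 0.001,
// heapLive-sweepHeapLiveBasis = 1<<20, and spanBytes = 8192, pagesTarget is
// about 0.001*(1<<20+8192) ≈ 1056 pages, so sweepone is called until
// pagesSwept-pagesSweptBasis reaches that target, sweeping runs out of
// spans, or the basis changes and the debt is recomputed.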
// clobberfree sets the memory content at x to bad content, for debugging
// purposes.
func clobberfree(x unsafe.Pointer, size uintptr) {
// size (span.elemsize) is always a multiple of 4.
for i := uintptr(0); i < size; i += 4 {
*(*uint32)(add(x, i)) = 0xdeadbeef
}
}
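// For example, clobberfree on a 16-byte slot writes 0xdeadbeef at offsets
// 0, 4, 8, and 12, so a stale pointer into freed memory is likely to fault
// or fail sanity checks loudly rather than silently reuse old data.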
// gcPaceSweeper updates the sweeper's pacing parameters.
//
// Must be called whenever the GC's pacing is updated.
//
// The world must be stopped, or mheap_.lock must be held.
func gcPaceSweeper(trigger uint64) {
assertWorldStoppedOrLockHeld(&mheap_.lock)
// Update sweep pacing.
if isSweepDone() {
mheap_.sweepPagesPerByte = 0
} else {
// Concurrent sweep needs to sweep all of the in-use
// pages by the time the allocated heap reaches the GC
// trigger. Compute the ratio of in-use pages to sweep
// per byte allocated, accounting for the fact that
// some might already be swept.
heapLiveBasis := gcController.heapLive.Load()
heapDistance := int64(trigger) - int64(heapLiveBasis)
// Add a little margin so rounding errors and
// concurrent sweep are less likely to leave pages
// unswept when GC starts.
heapDistance -= 1024 * 1024
if heapDistance < _PageSize {
// Avoid setting the sweep ratio extremely high
heapDistance = _PageSize
}
pagesSwept := mheap_.pagesSwept.Load()
pagesInUse := mheap_.pagesInUse.Load()
sweepDistancePages := int64(pagesInUse) - int64(pagesSwept)
if sweepDistancePages <= 0 {
mheap_.sweepPagesPerByte = 0
} else {
mheap_.sweepPagesPerByte = float64(sweepDistancePages) / float64(heapDistance)
mheap_.sweepHeapLiveBasis = heapLiveBasis
// Write pagesSweptBasis last, since this
// signals concurrent sweeps to recompute
// their debt.
mheap_.pagesSweptBasis.Store(pagesSwept)
}
}
}
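// A worked example with hypothetical numbers: if trigger = 100<<20,
// heapLive = 60<<20, and 3072 in-use pages are still unswept, then
// heapDistance = 40<<20 - 1<<20 = 39<<20 bytes and
// sweepPagesPerByte = 3072/float64(39<<20) ≈ 7.5e-5, i.e. the sweeper owes
// roughly one page per ~13 KiB of new allocation.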
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"internal/goarch"
"runtime/internal/atomic"
"runtime/internal/sys"
"unsafe"
)
const (
_WorkbufSize = 2048 // in bytes; larger values result in less contention
// workbufAlloc is the number of bytes to allocate at a time
// for new workbufs. This must be a multiple of pageSize and
// should be a multiple of _WorkbufSize.
//
// Larger values reduce workbuf allocation overhead. Smaller
// values reduce heap fragmentation.
workbufAlloc = 32 << 10
)
func init() {
if workbufAlloc%pageSize != 0 || workbufAlloc%_WorkbufSize != 0 {
throw("bad workbufAlloc")
}
}
// Garbage collector work pool abstraction.
//
// This implements a producer/consumer model for pointers to grey
// objects. A grey object is one that is marked and on a work
// queue. A black object is marked and not on a work queue.
//
// Write barriers, root discovery, stack scanning, and object scanning
// produce pointers to grey objects. Scanning consumes pointers to
// grey objects, thus blackening them, and then scans them,
// potentially producing new pointers to grey objects.
// A gcWork provides the interface to produce and consume work for the
// garbage collector.
//
// A gcWork can be used on the stack as follows:
//
// (preemption must be disabled)
// gcw := &getg().m.p.ptr().gcw
// .. call gcw.put() to produce and gcw.tryGet() to consume ..
//
// It's important that any use of gcWork during the mark phase prevent
// the garbage collector from transitioning to mark termination since
// gcWork may locally hold GC work buffers. This can be done by
// disabling preemption (systemstack or acquirem).
type gcWork struct {
// wbuf1 and wbuf2 are the primary and secondary work buffers.
//
// This can be thought of as a stack of both work buffers'
// pointers concatenated. When we pop the last pointer, we
// shift the stack up by one work buffer by bringing in a new
// full buffer and discarding an empty one. When we fill both
// buffers, we shift the stack down by one work buffer by
// bringing in a new empty buffer and discarding a full one.
// This way we have one buffer's worth of hysteresis, which
// amortizes the cost of getting or putting a work buffer over
// at least one buffer of work and reduces contention on the
// global work lists.
//
// wbuf1 is always the buffer we're currently pushing to and
// popping from and wbuf2 is the buffer that will be discarded
// next.
//
// Invariant: Both wbuf1 and wbuf2 are nil or neither are.
wbuf1, wbuf2 *workbuf
// Bytes marked (blackened) on this gcWork. This is aggregated
// into work.bytesMarked by dispose.
bytesMarked uint64
// Heap scan work performed on this gcWork. This is aggregated into
// gcController by dispose and may also be flushed by callers.
// Other types of scan work are flushed immediately.
heapScanWork int64
// flushedWork indicates that a non-empty work buffer was
// flushed to the global work list since the last gcMarkDone
// termination check. Specifically, this indicates that this
// gcWork may have communicated work to another gcWork.
flushedWork bool
}
// Most of the methods of gcWork are go:nowritebarrierrec because the
// write barrier itself can invoke gcWork methods but the methods are
// not generally re-entrant. Hence, if a gcWork method invoked the
// write barrier while the gcWork was in an inconsistent state, and
// the write barrier in turn invoked a gcWork method, it could
// permanently corrupt the gcWork.
func (w *gcWork) init() {
w.wbuf1 = getempty()
wbuf2 := trygetfull()
if wbuf2 == nil {
wbuf2 = getempty()
}
w.wbuf2 = wbuf2
}
// put enqueues a pointer for the garbage collector to trace.
// obj must point to the beginning of a heap object or an oblet.
//
//go:nowritebarrierrec
func (w *gcWork) put(obj uintptr) {
flushed := false
wbuf := w.wbuf1
// Record that this may acquire the wbufSpans or heap lock to
// allocate a workbuf.
lockWithRankMayAcquire(&work.wbufSpans.lock, lockRankWbufSpans)
lockWithRankMayAcquire(&mheap_.lock, lockRankMheap)
if wbuf == nil {
w.init()
wbuf = w.wbuf1
// wbuf is empty at this point.
} else if wbuf.nobj == len(wbuf.obj) {
w.wbuf1, w.wbuf2 = w.wbuf2, w.wbuf1
wbuf = w.wbuf1
if wbuf.nobj == len(wbuf.obj) {
putfull(wbuf)
w.flushedWork = true
wbuf = getempty()
w.wbuf1 = wbuf
flushed = true
}
}
wbuf.obj[wbuf.nobj] = obj
wbuf.nobj++
// If we put a buffer on full, let the GC controller know so
// it can encourage more workers to run. We delay this until
// the end of put so that w is in a consistent state, since
// enlistWorker may itself manipulate w.
if flushed && gcphase == _GCmark {
gcController.enlistWorker()
}
}
// putFast attempts a put on the fast path and reports whether it succeeded.
// If it returns false, the put was not performed and the caller needs to
// call put.
//
//go:nowritebarrierrec
func (w *gcWork) putFast(obj uintptr) bool {
wbuf := w.wbuf1
if wbuf == nil || wbuf.nobj == len(wbuf.obj) {
return false
}
wbuf.obj[wbuf.nobj] = obj
wbuf.nobj++
return true
}
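// Callers pair putFast with put so the common case avoids the slow path;
// a minimal sketch:
//
//	if !gcw.putFast(obj) {
//		gcw.put(obj)
//	}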
// putBatch performs a put on every pointer in obj. See put for
// constraints on these pointers.
//
//go:nowritebarrierrec
func (w *gcWork) putBatch(obj []uintptr) {
if len(obj) == 0 {
return
}
flushed := false
wbuf := w.wbuf1
if wbuf == nil {
w.init()
wbuf = w.wbuf1
}
for len(obj) > 0 {
for wbuf.nobj == len(wbuf.obj) {
putfull(wbuf)
w.flushedWork = true
w.wbuf1, w.wbuf2 = w.wbuf2, getempty()
wbuf = w.wbuf1
flushed = true
}
n := copy(wbuf.obj[wbuf.nobj:], obj)
wbuf.nobj += n
obj = obj[n:]
}
if flushed && gcphase == _GCmark {
gcController.enlistWorker()
}
}
// tryGet dequeues a pointer for the garbage collector to trace.
//
// If there are no pointers remaining in this gcWork or in the global
// queue, tryGet returns 0. Note that there may still be pointers in
// other gcWork instances or other caches.
//
//go:nowritebarrierrec
func (w *gcWork) tryGet() uintptr {
wbuf := w.wbuf1
if wbuf == nil {
w.init()
wbuf = w.wbuf1
// wbuf is empty at this point.
}
if wbuf.nobj == 0 {
w.wbuf1, w.wbuf2 = w.wbuf2, w.wbuf1
wbuf = w.wbuf1
if wbuf.nobj == 0 {
owbuf := wbuf
wbuf = trygetfull()
if wbuf == nil {
return 0
}
putempty(owbuf)
w.wbuf1 = wbuf
}
}
wbuf.nobj--
return wbuf.obj[wbuf.nobj]
}
// tryGetFast dequeues a pointer for the garbage collector to trace
// if one is readily available. Otherwise it returns 0 and
// the caller is expected to call tryGet().
//
//go:nowritebarrierrec
func (w *gcWork) tryGetFast() uintptr {
wbuf := w.wbuf1
if wbuf == nil || wbuf.nobj == 0 {
return 0
}
wbuf.nobj--
return wbuf.obj[wbuf.nobj]
}
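// Consumers try the fast path first and fall back to tryGet, e.g.
// (sketch):
//
//	obj := gcw.tryGetFast()
//	if obj == 0 {
//		obj = gcw.tryGet()
//	}
//	if obj == 0 {
//		// No local or global work available.
//	}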
// dispose returns any cached pointers to the global queue.
// The buffers are being put on the full queue so that the
// write barriers will not simply reacquire them before the
// GC can inspect them. This helps reduce the mutator's
// ability to hide pointers during the concurrent mark phase.
//
//go:nowritebarrierrec
func (w *gcWork) dispose() {
if wbuf := w.wbuf1; wbuf != nil {
if wbuf.nobj == 0 {
putempty(wbuf)
} else {
putfull(wbuf)
w.flushedWork = true
}
w.wbuf1 = nil
wbuf = w.wbuf2
if wbuf.nobj == 0 {
putempty(wbuf)
} else {
putfull(wbuf)
w.flushedWork = true
}
w.wbuf2 = nil
}
if w.bytesMarked != 0 {
// dispose happens relatively infrequently. If this
// atomic becomes a problem, we should first try to
// dispose less and if necessary aggregate in a per-P
// counter.
atomic.Xadd64(&work.bytesMarked, int64(w.bytesMarked))
w.bytesMarked = 0
}
if w.heapScanWork != 0 {
gcController.heapScanWork.Add(w.heapScanWork)
w.heapScanWork = 0
}
}
// balance moves some work that's cached in this gcWork back on the
// global queue.
//
//go:nowritebarrierrec
func (w *gcWork) balance() {
if w.wbuf1 == nil {
return
}
if wbuf := w.wbuf2; wbuf.nobj != 0 {
putfull(wbuf)
w.flushedWork = true
w.wbuf2 = getempty()
} else if wbuf := w.wbuf1; wbuf.nobj > 4 {
w.wbuf1 = handoff(wbuf)
w.flushedWork = true // handoff did putfull
} else {
return
}
// We flushed a buffer to the full list, so wake a worker.
if gcphase == _GCmark {
gcController.enlistWorker()
}
}
// empty reports whether w has no mark work available.
//
//go:nowritebarrierrec
func (w *gcWork) empty() bool {
return w.wbuf1 == nil || (w.wbuf1.nobj == 0 && w.wbuf2.nobj == 0)
}
// Internally, the GC work pool is kept in arrays in work buffers.
// The gcWork interface caches a work buffer until full (or empty) to
// avoid contending on the global work buffer lists.
type workbufhdr struct {
node lfnode // must be first
nobj int
}
type workbuf struct {
_ sys.NotInHeap
workbufhdr
// account for the above fields
obj [(_WorkbufSize - unsafe.Sizeof(workbufhdr{})) / goarch.PtrSize]uintptr
}
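// On 64-bit, assuming no padding, workbufhdr is 24 bytes (a 16-byte lfnode
// plus an 8-byte nobj), so obj holds (2048-24)/8 = 253 pointers per
// workbuf. Other configurations follow the same formula.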
// workbuf factory routines. These funcs are used to manage the
// workbufs.
// If the GC asks for some work these are the only routines that
// make wbufs available to the GC.
func (b *workbuf) checknonempty() {
if b.nobj == 0 {
throw("workbuf is empty")
}
}
func (b *workbuf) checkempty() {
if b.nobj != 0 {
throw("workbuf is not empty")
}
}
// getempty pops an empty work buffer off the work.empty list,
// allocating new buffers if none are available.
//
//go:nowritebarrier
func getempty() *workbuf {
var b *workbuf
if work.empty != 0 {
b = (*workbuf)(work.empty.pop())
if b != nil {
b.checkempty()
}
}
// Record that this may acquire the wbufSpans or heap lock to
// allocate a workbuf.
lockWithRankMayAcquire(&work.wbufSpans.lock, lockRankWbufSpans)
lockWithRankMayAcquire(&mheap_.lock, lockRankMheap)
if b == nil {
// Allocate more workbufs.
var s *mspan
if work.wbufSpans.free.first != nil {
lock(&work.wbufSpans.lock)
s = work.wbufSpans.free.first
if s != nil {
work.wbufSpans.free.remove(s)
work.wbufSpans.busy.insert(s)
}
unlock(&work.wbufSpans.lock)
}
if s == nil {
systemstack(func() {
s = mheap_.allocManual(workbufAlloc/pageSize, spanAllocWorkBuf)
})
if s == nil {
throw("out of memory")
}
// Record the new span in the busy list.
lock(&work.wbufSpans.lock)
work.wbufSpans.busy.insert(s)
unlock(&work.wbufSpans.lock)
}
// Slice up the span into new workbufs. Return one and
// put the rest on the empty list.
for i := uintptr(0); i+_WorkbufSize <= workbufAlloc; i += _WorkbufSize {
newb := (*workbuf)(unsafe.Pointer(s.base() + i))
newb.nobj = 0
lfnodeValidate(&newb.node)
if i == 0 {
b = newb
} else {
putempty(newb)
}
}
}
return b
}
// putempty puts a workbuf onto the work.empty list.
// Upon entry this goroutine owns b. The lfstack.push relinquishes ownership.
//
//go:nowritebarrier
func putempty(b *workbuf) {
b.checkempty()
work.empty.push(&b.node)
}
// putfull puts the workbuf on the work.full list for the GC.
// putfull accepts partially full buffers so the GC can avoid competing
// with the mutators for ownership of partially full buffers.
//
//go:nowritebarrier
func putfull(b *workbuf) {
b.checknonempty()
work.full.push(&b.node)
}
// trygetfull tries to get a full or partially full workbuffer.
// If one is not immediately available, it returns nil.
//
//go:nowritebarrier
func trygetfull() *workbuf {
b := (*workbuf)(work.full.pop())
if b != nil {
b.checknonempty()
return b
}
return b
}
//go:nowritebarrier
func handoff(b *workbuf) *workbuf {
// Make new buffer with half of b's pointers.
b1 := getempty()
n := b.nobj / 2
b.nobj -= n
b1.nobj = n
memmove(unsafe.Pointer(&b1.obj[0]), unsafe.Pointer(&b.obj[b.nobj]), uintptr(n)*unsafe.Sizeof(b1.obj[0]))
// Put b on full list - let first half of b get stolen.
putfull(b)
return b1
}
// prepareFreeWorkbufs moves busy workbuf spans to free list so they
// can be freed to the heap. This must only be called when all
// workbufs are on the empty list.
func prepareFreeWorkbufs() {
lock(&work.wbufSpans.lock)
if work.full != 0 {
throw("cannot free workbufs when work.full != 0")
}
// Since all workbufs are on the empty list, we don't care
// which ones are in which spans. We can wipe the entire empty
// list and move all workbuf spans to the free list.
work.empty = 0
work.wbufSpans.free.takeAll(&work.wbufSpans.busy)
unlock(&work.wbufSpans.lock)
}
// freeSomeWbufs frees some workbufs back to the heap and returns
// true if it should be called again to free more.
func freeSomeWbufs(preemptible bool) bool {
const batchSize = 64 // ~1–2 µs per span.
lock(&work.wbufSpans.lock)
if gcphase != _GCoff || work.wbufSpans.free.isEmpty() {
unlock(&work.wbufSpans.lock)
return false
}
systemstack(func() {
gp := getg().m.curg
for i := 0; i < batchSize && !(preemptible && gp.preempt); i++ {
span := work.wbufSpans.free.first
if span == nil {
break
}
work.wbufSpans.free.remove(span)
mheap_.freeManual(span, spanAllocWorkBuf)
}
})
more := !work.wbufSpans.free.isEmpty()
unlock(&work.wbufSpans.lock)
return more
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Page heap.
//
// See malloc.go for overview.
package runtime
import (
"internal/cpu"
"internal/goarch"
"runtime/internal/atomic"
"runtime/internal/sys"
"unsafe"
)
const (
// minPhysPageSize is a lower-bound on the physical page size. The
// true physical page size may be larger than this. In contrast,
// sys.PhysPageSize is an upper-bound on the physical page size.
minPhysPageSize = 4096
// maxPhysPageSize is the maximum page size the runtime supports.
maxPhysPageSize = 512 << 10
// maxPhysHugePageSize sets an upper-bound on the maximum huge page size
// that the runtime supports.
maxPhysHugePageSize = pallocChunkBytes
// pagesPerReclaimerChunk indicates how many pages to scan from the
// pageInUse bitmap at a time. Used by the page reclaimer.
//
// Higher values reduce contention on scanning indexes (such as
// h.reclaimIndex), but increase the minimum latency of the
// operation.
//
// The time required to scan this many pages can vary a lot depending
// on how many spans are actually freed. Experimentally, it can
// scan for pages at ~300 GB/ms on a 2.6 GHz Core i7, but can only
// free spans at ~32 MB/ms. Using 512 pages bounds this at
// roughly 100µs.
//
// Must be a multiple of the pageInUse bitmap element size and
// must also evenly divide pagesPerArena.
pagesPerReclaimerChunk = 512
// physPageAlignedStacks indicates whether stack allocations must be
// physical page aligned. This is a requirement for MAP_STACK on
// OpenBSD.
physPageAlignedStacks = GOOS == "openbsd"
)
// Main malloc heap.
// The heap itself is the "free" and "scav" treaps,
// but all the other global data is here too.
//
// mheap must not be heap-allocated because it contains mSpanLists,
// which must not be heap-allocated.
type mheap struct {
_ sys.NotInHeap
// lock must only be acquired on the system stack, otherwise a g
// could self-deadlock if its stack grows with the lock held.
lock mutex
pages pageAlloc // page allocation data structure
sweepgen uint32 // sweep generation, see comment in mspan; written during STW
// allspans is a slice of all mspans ever created. Each mspan
// appears exactly once.
//
// The memory for allspans is manually managed and can be
// reallocated and moved as the heap grows.
//
// In general, allspans is protected by mheap_.lock, which
// prevents concurrent access as well as freeing the backing
// store. Accesses during STW might not hold the lock, but
// must ensure that allocation cannot happen around the
// access (since that may free the backing store).
allspans []*mspan // all spans out there
// Proportional sweep
//
// These parameters represent a linear function from gcController.heapLive
// to page sweep count. The proportional sweep system works to
// stay in the black by keeping the current page sweep count
// above this line at the current gcController.heapLive.
//
// The line has slope sweepPagesPerByte and passes through a
// basis point at (sweepHeapLiveBasis, pagesSweptBasis). At
// any given time, the system is at (gcController.heapLive,
// pagesSwept) in this space.
//
// It is important that the line pass through a point we
// control rather than simply starting at a 0,0 origin
// because that lets us adjust sweep pacing at any time while
// accounting for current progress. If we could only adjust
// the slope, it would create a discontinuity in debt if any
// progress has already been made.
pagesInUse atomic.Uintptr // pages of spans in stats mSpanInUse
pagesSwept atomic.Uint64 // pages swept this cycle
pagesSweptBasis atomic.Uint64 // pagesSwept to use as the origin of the sweep ratio
sweepHeapLiveBasis uint64 // value of gcController.heapLive to use as the origin of sweep ratio; written with lock, read without
sweepPagesPerByte float64 // proportional sweep ratio; written with lock, read without
// Page reclaimer state
// reclaimIndex is the page index in allArenas of next page to
// reclaim. Specifically, it refers to page (i %
// pagesPerArena) of arena allArenas[i / pagesPerArena].
//
// If this is >= 1<<63, the page reclaimer is done scanning
// the page marks.
reclaimIndex atomic.Uint64
// reclaimCredit is spare credit for extra pages swept. Since
// the page reclaimer works in large chunks, it may reclaim
// more than requested. Any spare pages released go to this
// credit pool.
reclaimCredit atomic.Uintptr
// arenas is the heap arena map. It points to the metadata for
// the heap for every arena frame of the entire usable virtual
// address space.
//
// Use arenaIndex to compute indexes into this array.
//
// For regions of the address space that are not backed by the
// Go heap, the arena map contains nil.
//
// Modifications are protected by mheap_.lock. Reads can be
// performed without locking; however, a given entry can
// transition from nil to non-nil at any time when the lock
// isn't held. (Entries never transition back to nil.)
//
// In general, this is a two-level mapping consisting of an L1
// map and possibly many L2 maps. This saves space when there
// are a huge number of arena frames. However, on many
// platforms (even 64-bit), arenaL1Bits is 0, making this
// effectively a single-level map. In this case, arenas[0]
// will never be nil.
arenas [1 << arenaL1Bits]*[1 << arenaL2Bits]*heapArena
// heapArenaAlloc is pre-reserved space for allocating heapArena
// objects. This is only used on 32-bit, where we pre-reserve
// this space to avoid interleaving it with the heap itself.
heapArenaAlloc linearAlloc
// arenaHints is a list of addresses at which to attempt to
// add more heap arenas. This is initially populated with a
// set of general hint addresses, and grown with the bounds of
// actual heap arena ranges.
arenaHints *arenaHint
// arena is a pre-reserved space for allocating heap arenas
// (the actual arenas). This is only used on 32-bit.
arena linearAlloc
// allArenas is the arenaIndex of every mapped arena. This can
// be used to iterate through the address space.
//
// Access is protected by mheap_.lock. However, since this is
// append-only and old backing arrays are never freed, it is
// safe to acquire mheap_.lock, copy the slice header, and
// then release mheap_.lock.
allArenas []arenaIdx
// sweepArenas is a snapshot of allArenas taken at the
// beginning of the sweep cycle. This can be read safely by
// simply blocking GC (by disabling preemption).
sweepArenas []arenaIdx
// markArenas is a snapshot of allArenas taken at the beginning
// of the mark cycle. Because allArenas is append-only, neither
// this slice nor its contents will change during the mark, so
// it can be read safely.
markArenas []arenaIdx
// curArena is the arena that the heap is currently growing
// into. This should always be physPageSize-aligned.
curArena struct {
base, end uintptr
}
// central free lists for small size classes.
// the padding makes sure that the mcentrals are
// spaced CacheLinePadSize bytes apart, so that each mcentral.lock
// gets its own cache line.
// central is indexed by spanClass.
central [numSpanClasses]struct {
mcentral mcentral
pad [(cpu.CacheLinePadSize - unsafe.Sizeof(mcentral{})%cpu.CacheLinePadSize) % cpu.CacheLinePadSize]byte
}
spanalloc fixalloc // allocator for span*
cachealloc fixalloc // allocator for mcache*
specialfinalizeralloc fixalloc // allocator for specialfinalizer*
specialprofilealloc fixalloc // allocator for specialprofile*
specialReachableAlloc fixalloc // allocator for specialReachable
speciallock mutex // lock for special record allocators.
arenaHintAlloc fixalloc // allocator for arenaHints
// User arena state.
//
// Protected by mheap_.lock.
userArena struct {
// arenaHints is a list of addresses at which to attempt to
// add more heap arenas for user arena chunks. This is initially
// populated with a set of general hint addresses, and grown with
// the bounds of actual heap arena ranges.
arenaHints *arenaHint
// quarantineList is a list of user arena spans that have been set to fault, but
// are waiting for all pointers into them to go away. Sweeping handles
// identifying when this is true, and moves the span to the ready list.
quarantineList mSpanList
// readyList is a list of empty user arena spans that are ready for reuse.
readyList mSpanList
}
unused *specialfinalizer // never set, just here to force the specialfinalizer type into DWARF
}
var mheap_ mheap
// A heapArena stores metadata for a heap arena. heapArenas are stored
// outside of the Go heap and accessed via the mheap_.arenas index.
type heapArena struct {
_ sys.NotInHeap
// bitmap stores the pointer/scalar bitmap for the words in
// this arena. See mbitmap.go for a description.
// This array uses 1 bit per word of heap, or 1.6% of the heap size (for 64-bit).
bitmap [heapArenaBitmapWords]uintptr
// If the ith bit of noMorePtrs is true, then there are no more
// pointers for the object containing the word described by the
// high bit of bitmap[i].
// In that case, bitmap[i+1], ... must be zero until the start
// of the next object.
// We never operate on these entries using bit-parallel techniques,
// so it is ok if they are small. Also, they can't be bigger than
// uint16 because at that size a single noMorePtrs entry
// represents 8K of memory, the minimum size of a span. Any larger
// and we'd have to worry about concurrent updates.
// This array uses 1 bit per word of bitmap, or .024% of the heap size (for 64-bit).
noMorePtrs [heapArenaBitmapWords / 8]uint8
// spans maps from virtual address page ID within this arena to *mspan.
// For allocated spans, their pages map to the span itself.
// For free spans, only the lowest and highest pages map to the span itself.
// Internal pages map to an arbitrary span.
// For pages that have never been allocated, spans entries are nil.
//
// Modifications are protected by mheap.lock. Reads can be
// performed without locking, but ONLY from indexes that are
// known to contain in-use or stack spans. This means there
// must not be a safe-point between establishing that an
// address is live and looking it up in the spans array.
spans [pagesPerArena]*mspan
// pageInUse is a bitmap that indicates which spans are in
// state mSpanInUse. This bitmap is indexed by page number,
// but only the bit corresponding to the first page in each
// span is used.
//
// Reads and writes are atomic.
pageInUse [pagesPerArena / 8]uint8
// pageMarks is a bitmap that indicates which spans have any
// marked objects on them. Like pageInUse, only the bit
// corresponding to the first page in each span is used.
//
// Writes are done atomically during marking. Reads are
// non-atomic and lock-free since they only occur during
// sweeping (and hence never race with writes).
//
// This is used to quickly find whole spans that can be freed.
//
// TODO(austin): It would be nice if this was uint64 for
// faster scanning, but we don't have 64-bit atomic bit
// operations.
pageMarks [pagesPerArena / 8]uint8
// pageSpecials is a bitmap that indicates which spans have
// specials (finalizers or other). Like pageInUse, only the bit
// corresponding to the first page in each span is used.
//
// Writes are done atomically whenever a special is added to
// a span and whenever the last special is removed from a span.
// Reads are done atomically to find spans containing specials
// during marking.
pageSpecials [pagesPerArena / 8]uint8
// checkmarks stores the debug.gccheckmark state. It is only
// used if debug.gccheckmark > 0.
checkmarks *checkmarksMap
// zeroedBase marks the first byte of the first page in this
// arena which hasn't been used yet and is therefore already
// zero. zeroedBase is relative to the arena base.
// Increases monotonically until it hits heapArenaBytes.
//
// This field is sufficient to determine if an allocation
// needs to be zeroed because the page allocator follows an
// address-ordered first-fit policy.
//
// Read atomically and written with an atomic CAS.
zeroedBase uintptr
}
// arenaHint is a hint for where to grow the heap arenas. See
// mheap_.arenaHints.
type arenaHint struct {
_ sys.NotInHeap
addr uintptr
down bool
next *arenaHint
}
// An mspan is a run of pages.
//
// When a mspan is in the heap free treap, state == mSpanFree
// and heapmap(s->start) == span, heapmap(s->start+s->npages-1) == span.
// If the mspan is in the heap scav treap, then in addition to the
// above scavenged == true. scavenged == false in all other cases.
//
// When a mspan is allocated, state == mSpanInUse or mSpanManual
// and heapmap(i) == span for all s->start <= i < s->start+s->npages.
// Every mspan is in one doubly-linked list, either in the mheap's
// busy list or one of the mcentral's span lists.
// An mspan representing actual memory has state mSpanInUse,
// mSpanManual, or mSpanFree. Transitions between these states are
// constrained as follows:
//
// - A span may transition from free to in-use or manual during any GC
// phase.
//
// - During sweeping (gcphase == _GCoff), a span may transition from
// in-use to free (as a result of sweeping) or manual to free (as a
// result of stacks being freed).
//
// - During GC (gcphase != _GCoff), a span *must not* transition from
// manual or in-use to free. Because concurrent GC may read a pointer
// and then look up its span, the span state must be monotonic.
//
// Setting mspan.state to mSpanInUse or mSpanManual must be done
// atomically and only after all other span fields are valid.
// Likewise, if inspecting a span is contingent on it being
// mSpanInUse, the state should be loaded atomically and checked
// before depending on other fields. This allows the garbage collector
// to safely deal with potentially invalid pointers, since resolving
// such pointers may race with a span being allocated.
type mSpanState uint8
const (
mSpanDead mSpanState = iota
mSpanInUse // allocated for garbage collected heap
mSpanManual // allocated for manual management (e.g., stack allocator)
)
// mSpanStateNames are the names of the span states, indexed by
// mSpanState.
var mSpanStateNames = []string{
"mSpanDead",
"mSpanInUse",
"mSpanManual",
}
// mSpanStateBox holds an atomic.Uint8 to provide atomic operations on
// an mSpanState. This is a separate type to disallow accidental comparison
// or assignment with mSpanState.
type mSpanStateBox struct {
s atomic.Uint8
}
// It is nosplit to match get, below.
//go:nosplit
func (b *mSpanStateBox) set(s mSpanState) {
b.s.Store(uint8(s))
}
// It is nosplit because it's called indirectly by typedmemclr,
// which must not be preempted.
//go:nosplit
func (b *mSpanStateBox) get() mSpanState {
return mSpanState(b.s.Load())
}
// mSpanList heads a linked list of spans.
type mSpanList struct {
_ sys.NotInHeap
first *mspan // first span in list, or nil if none
last *mspan // last span in list, or nil if none
}
type mspan struct {
_ sys.NotInHeap
next *mspan // next span in list, or nil if none
prev *mspan // previous span in list, or nil if none
list *mSpanList // For debugging. TODO: Remove.
startAddr uintptr // address of first byte of span aka s.base()
npages uintptr // number of pages in span
manualFreeList gclinkptr // list of free objects in mSpanManual spans
// freeindex is the slot index between 0 and nelems at which to begin scanning
// for the next free object in this span.
// Each allocation scans allocBits starting at freeindex until it encounters a 0
// indicating a free object. freeindex is then adjusted so that subsequent scans begin
// just past the newly discovered free object.
//
// If freeindex == nelems, this span has no free objects.
//
// allocBits is a bitmap of objects in this span.
// If n >= freeindex and allocBits[n/8] & (1<<(n%8)) is 0
// then object n is free;
// otherwise, object n is allocated. Bits starting at nelems are
// undefined and should never be referenced.
//
// Object n starts at address n*elemsize + (start << pageShift).
freeindex uintptr
// TODO: Look up nelems from sizeclass and remove this field if it
// helps performance.
nelems uintptr // number of objects in the span.
// Cache of the allocBits at freeindex. allocCache is shifted
// such that the lowest bit corresponds to the bit freeindex.
// allocCache holds the complement of allocBits, thus allowing
// ctz (count trailing zero) to use it directly.
// allocCache may contain bits beyond s.nelems; the caller must ignore
// these.
allocCache uint64
// allocBits and gcmarkBits hold pointers to a span's mark and
// allocation bits. The pointers are 8 byte aligned.
// There are three arenas where this data is held.
// free: Dirty arenas that are no longer accessed
// and can be reused.
// next: Holds information to be used in the next GC cycle.
// current: Information being used during this GC cycle.
// previous: Information being used during the last GC cycle.
// A new GC cycle starts with the call to finishsweep_m.
// finishsweep_m moves the previous arena to the free arena,
// the current arena to the previous arena, and
// the next arena to the current arena.
// The next arena is populated as the spans request
// memory to hold gcmarkBits for the next GC cycle as well
// as allocBits for newly allocated spans.
//
// The pointer arithmetic is done "by hand" instead of using
// arrays to avoid bounds checks along critical performance
// paths.
// The sweep will free the old allocBits and set allocBits to the
// gcmarkBits. The gcmarkBits are replaced with a fresh zeroed
// out memory.
allocBits *gcBits
gcmarkBits *gcBits
// sweep generation:
// if sweepgen == h->sweepgen - 2, the span needs sweeping
// if sweepgen == h->sweepgen - 1, the span is currently being swept
// if sweepgen == h->sweepgen, the span is swept and ready to use
// if sweepgen == h->sweepgen + 1, the span was cached before sweep began and is still cached, and needs sweeping
// if sweepgen == h->sweepgen + 3, the span was swept and then cached and is still cached
// h->sweepgen is incremented by 2 after every GC
sweepgen uint32
divMul uint32 // for divide by elemsize
allocCount uint16 // number of allocated objects
spanclass spanClass // size class and noscan (uint8)
state mSpanStateBox // mSpanInUse etc; accessed atomically (get/set methods)
needzero uint8 // needs to be zeroed before allocation
isUserArenaChunk bool // whether or not this span represents a user arena
allocCountBeforeCache uint16 // a copy of allocCount that is stored just before this span is cached
elemsize uintptr // computed from sizeclass or from npages
limit uintptr // end of data in span
speciallock mutex // guards specials list
specials *special // linked list of special records sorted by offset.
userArenaChunkFree addrRange // interval for managing chunk allocation
// freeIndexForScan is like freeindex, except that freeindex is
// used by the allocator whereas freeIndexForScan is used by the
// GC scanner. They are two fields so that the GC sees the object
// is allocated only when the object and the heap bits are
// initialized (see also the assignment of freeIndexForScan in
// mallocgc, and issue 54596).
freeIndexForScan uintptr
}
func (s *mspan) base() uintptr {
return s.startAddr
}
func (s *mspan) layout() (size, n, total uintptr) {
total = s.npages << _PageShift
size = s.elemsize
if size > 0 {
n = total / size
}
return
}
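// As a sketch of how spans are addressed: object i of span s starts at
//
//	addr := s.base() + i*s.elemsize // i in [0, n)
//
// and layout's n is how many such objects fit in the span's pages.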
// recordspan adds a newly allocated span to h.allspans.
//
// This only happens the first time a span is allocated from
// mheap.spanalloc (it is not called when a span is reused).
//
// Write barriers are disallowed here because it can be called from
// gcWork when allocating new workbufs. However, because it's an
// indirect call from the fixalloc initializer, the compiler can't see
// this.
//
// The heap lock must be held.
//
//go:nowritebarrierrec
func recordspan(vh unsafe.Pointer, p unsafe.Pointer) {
h := (*mheap)(vh)
s := (*mspan)(p)
assertLockHeld(&h.lock)
if len(h.allspans) >= cap(h.allspans) {
n := 64 * 1024 / goarch.PtrSize
if n < cap(h.allspans)*3/2 {
n = cap(h.allspans) * 3 / 2
}
var new []*mspan
sp := (*slice)(unsafe.Pointer(&new))
sp.array = sysAlloc(uintptr(n)*goarch.PtrSize, &memstats.other_sys)
if sp.array == nil {
throw("runtime: cannot allocate memory")
}
sp.len = len(h.allspans)
sp.cap = n
if len(h.allspans) > 0 {
copy(new, h.allspans)
}
oldAllspans := h.allspans
*(*notInHeapSlice)(unsafe.Pointer(&h.allspans)) = *(*notInHeapSlice)(unsafe.Pointer(&new))
if len(oldAllspans) != 0 {
sysFree(unsafe.Pointer(&oldAllspans[0]), uintptr(cap(oldAllspans))*unsafe.Sizeof(oldAllspans[0]), &memstats.other_sys)
}
}
h.allspans = h.allspans[:len(h.allspans)+1]
h.allspans[len(h.allspans)-1] = s
}
// A spanClass represents the size class and noscan-ness of a span.
//
// Each size class has a noscan spanClass and a scan spanClass. The
// noscan spanClass contains only noscan objects, which do not contain
// pointers and thus do not need to be scanned by the garbage
// collector.
type spanClass uint8
const (
numSpanClasses = _NumSizeClasses << 1
tinySpanClass = spanClass(tinySizeClass<<1 | 1)
)
func makeSpanClass(sizeclass uint8, noscan bool) spanClass {
return spanClass(sizeclass<<1) | spanClass(bool2int(noscan))
}
func (sc spanClass) sizeclass() int8 {
return int8(sc >> 1)
}
func (sc spanClass) noscan() bool {
return sc&1 != 0
}
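// For example, makeSpanClass(5, true) is spanClass 11 (5<<1 | 1);
// sizeclass() recovers 5 and noscan() reports true. Even span classes are
// scan classes, odd ones are noscan.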
// arenaIndex returns the index into mheap_.arenas of the arena
// containing metadata for p. This index combines an index into the
// L1 map and an index into the L2 map and should be used as
// mheap_.arenas[ai.l1()][ai.l2()].
//
// If p is outside the range of valid heap addresses, either l1() or
// l2() will be out of bounds.
//
// It is nosplit because it's called by spanOf and several other
// nosplit functions.
//
//go:nosplit
func arenaIndex(p uintptr) arenaIdx {
return arenaIdx((p - arenaBaseOffset) / heapArenaBytes)
}
// arenaBase returns the low address of the region covered by heap
// arena i.
func arenaBase(i arenaIdx) uintptr {
return uintptr(i)*heapArenaBytes + arenaBaseOffset
}
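// arenaIndex and arenaBase are inverses at arena granularity; a sketch:
//
//	ai := arenaIndex(p)
//	lo := arenaBase(ai)
//	hi := lo + heapArenaBytes
//	// now lo <= p < hi (when p is a valid heap address)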
type arenaIdx uint
// l1 returns the "l1" portion of an arenaIdx.
//
// Marked nosplit because it's called by spanOf and other nosplit
// functions.
//
//go:nosplit
func (i arenaIdx) l1() uint {
if arenaL1Bits == 0 {
// Let the compiler optimize this away if there's no
// L1 map.
return 0
} else {
return uint(i) >> arenaL1Shift
}
}
// l2 returns the "l2" portion of an arenaIdx.
//
// Marked nosplit because it's called by spanOf and other nosplit
// functions.
//
//go:nosplit
func (i arenaIdx) l2() uint {
if arenaL1Bits == 0 {
return uint(i)
} else {
return uint(i) & (1<<arenaL2Bits - 1)
}
}
// inheap reports whether b is a pointer into a (potentially dead) heap object.
// It returns false for pointers into mSpanManual spans.
// Non-preemptible because it is used by write barriers.
//
//go:nowritebarrier
//go:nosplit
func inheap(b uintptr) bool {
return spanOfHeap(b) != nil
}
// inHeapOrStack is a variant of inheap that returns true for pointers
// into any allocated heap span.
//
//go:nowritebarrier
//go:nosplit
func inHeapOrStack(b uintptr) bool {
s := spanOf(b)
if s == nil || b < s.base() {
return false
}
switch s.state.get() {
case mSpanInUse, mSpanManual:
return b < s.limit
default:
return false
}
}
// spanOf returns the span of p. If p does not point into the heap
// arena or no span has ever contained p, spanOf returns nil.
//
// If p does not point to allocated memory, this may return a non-nil
// span that does *not* contain p. If this is a possibility, the
// caller should either call spanOfHeap or check the span bounds
// explicitly.
//
// Must be nosplit because it has callers that are nosplit.
//
//go:nosplit
func spanOf(p uintptr) *mspan {
// This function looks big, but we use a lot of constant
// folding around arenaL1Bits to get it under the inlining
// budget. Also, many of the checks here are safety checks
// that Go needs to do anyway, so the generated code is quite
// short.
ri := arenaIndex(p)
if arenaL1Bits == 0 {
// If there's no L1, then ri.l1() can't be out of bounds but ri.l2() can.
if ri.l2() >= uint(len(mheap_.arenas[0])) {
return nil
}
} else {
// If there's an L1, then ri.l1() can be out of bounds but ri.l2() can't.
if ri.l1() >= uint(len(mheap_.arenas)) {
return nil
}
}
l2 := mheap_.arenas[ri.l1()]
if arenaL1Bits != 0 && l2 == nil { // Should never happen if there's no L1.
return nil
}
ha := l2[ri.l2()]
if ha == nil {
return nil
}
return ha.spans[(p/pageSize)%pagesPerArena]
}
// spanOfUnchecked is equivalent to spanOf, but the caller must ensure
// that p points into an allocated heap arena.
//
// Must be nosplit because it has callers that are nosplit.
//
//go:nosplit
func spanOfUnchecked(p uintptr) *mspan {
ai := arenaIndex(p)
return mheap_.arenas[ai.l1()][ai.l2()].spans[(p/pageSize)%pagesPerArena]
}
// spanOfHeap is like spanOf, but returns nil if p does not point to a
// heap object.
//
// Must be nosplit because it has callers that are nosplit.
//
//go:nosplit
func spanOfHeap(p uintptr) *mspan {
s := spanOf(p)
// s is nil if it's never been allocated. Otherwise, we check
// its state first because we don't trust this pointer, so we
// have to synchronize with span initialization. Then, it's
// still possible we picked up a stale span pointer, so we
// have to check the span's bounds.
if s == nil || s.state.get() != mSpanInUse || p < s.base() || p >= s.limit {
return nil
}
return s
}
// pageIndexOf returns the arena, page index, and page mask for pointer p.
// The caller must ensure p is in the heap.
func pageIndexOf(p uintptr) (arena *heapArena, pageIdx uintptr, pageMask uint8) {
ai := arenaIndex(p)
arena = mheap_.arenas[ai.l1()][ai.l2()]
pageIdx = ((p / pageSize) / 8) % uintptr(len(arena.pageInUse))
pageMask = byte(1 << ((p / pageSize) % 8))
return
}
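// A typical use (sketch) is addressing a single span's bit, for example
// reading the in-use bit for the span containing p:
//
//	ha, pageIdx, pageMask := pageIndexOf(p)
//	inUse := atomic.Load8(&ha.pageInUse[pageIdx])&pageMask != 0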
// Initialize the heap.
func (h *mheap) init() {
lockInit(&h.lock, lockRankMheap)
lockInit(&h.speciallock, lockRankMheapSpecial)
h.spanalloc.init(unsafe.Sizeof(mspan{}), recordspan, unsafe.Pointer(h), &memstats.mspan_sys)
h.cachealloc.init(unsafe.Sizeof(mcache{}), nil, nil, &memstats.mcache_sys)
h.specialfinalizeralloc.init(unsafe.Sizeof(specialfinalizer{}), nil, nil, &memstats.other_sys)
h.specialprofilealloc.init(unsafe.Sizeof(specialprofile{}), nil, nil, &memstats.other_sys)
h.specialReachableAlloc.init(unsafe.Sizeof(specialReachable{}), nil, nil, &memstats.other_sys)
h.arenaHintAlloc.init(unsafe.Sizeof(arenaHint{}), nil, nil, &memstats.other_sys)
// Don't zero mspan allocations. Background sweeping can
// inspect a span concurrently with allocating it, so it's
// important that the span's sweepgen survive across freeing
// and re-allocating a span to prevent background sweeping
// from improperly cas'ing it from 0.
//
// This is safe because mspan contains no heap pointers.
h.spanalloc.zero = false
// h->mapcache needs no init
for i := range h.central {
h.central[i].mcentral.init(spanClass(i))
}
h.pages.init(&h.lock, &memstats.gcMiscSys)
}
// reclaim sweeps and reclaims at least npage pages into the heap.
// It is called before allocating npage pages to keep growth in check.
//
// reclaim implements the page-reclaimer half of the sweeper.
//
// h.lock must NOT be held.
func (h *mheap) reclaim(npage uintptr) {
// TODO(austin): Half of the time spent freeing spans is in
// locking/unlocking the heap (even with low contention). We
// could make the slow path here several times faster by
// batching heap frees.
// Bail early if there's no more reclaim work.
if h.reclaimIndex.Load() >= 1<<63 {
return
}
// Disable preemption so the GC can't start while we're
// sweeping, so we can read h.sweepArenas, and so
// traceGCSweepStart/Done pair on the P.
mp := acquirem()
if trace.enabled {
traceGCSweepStart()
}
arenas := h.sweepArenas
locked := false
for npage > 0 {
// Pull from accumulated credit first.
if credit := h.reclaimCredit.Load(); credit > 0 {
take := credit
if take > npage {
// Take only what we need.
take = npage
}
if h.reclaimCredit.CompareAndSwap(credit, credit-take) {
npage -= take
}
continue
}
// Claim a chunk of work.
idx := uintptr(h.reclaimIndex.Add(pagesPerReclaimerChunk) - pagesPerReclaimerChunk)
if idx/pagesPerArena >= uintptr(len(arenas)) {
// Page reclaiming is done.
h.reclaimIndex.Store(1 << 63)
break
}
if !locked {
// Lock the heap for reclaimChunk.
lock(&h.lock)
locked = true
}
// Scan this chunk.
nfound := h.reclaimChunk(arenas, idx, pagesPerReclaimerChunk)
if nfound <= npage {
npage -= nfound
} else {
// Put spare pages toward global credit.
h.reclaimCredit.Add(nfound - npage)
npage = 0
}
}
if locked {
unlock(&h.lock)
}
if trace.enabled {
traceGCSweepDone()
}
releasem(mp)
}
// reclaimChunk sweeps unmarked spans that start at page indexes [pageIdx, pageIdx+n).
// It returns the number of pages returned to the heap.
//
// h.lock must be held and the caller must be non-preemptible. Note: h.lock may be
// temporarily unlocked and re-locked in order to do sweeping or if tracing is
// enabled.
func (h *mheap) reclaimChunk(arenas []arenaIdx, pageIdx, n uintptr) uintptr {
// The heap lock must be held because this accesses the
// heapArena.spans arrays using potentially non-live pointers.
// In particular, if a span were freed and merged concurrently
// with this probing heapArena.spans, it would be possible to
// observe arbitrary, stale span pointers.
assertLockHeld(&h.lock)
n0 := n
var nFreed uintptr
sl := sweep.active.begin()
if !sl.valid {
return 0
}
for n > 0 {
ai := arenas[pageIdx/pagesPerArena]
ha := h.arenas[ai.l1()][ai.l2()]
// Get a chunk of the bitmap to work on.
arenaPage := uint(pageIdx % pagesPerArena)
inUse := ha.pageInUse[arenaPage/8:]
marked := ha.pageMarks[arenaPage/8:]
if uintptr(len(inUse)) > n/8 {
inUse = inUse[:n/8]
marked = marked[:n/8]
}
// Scan this bitmap chunk for spans that are in-use
// but have no marked objects on them.
for i := range inUse {
inUseUnmarked := atomic.Load8(&inUse[i]) &^ marked[i]
if inUseUnmarked == 0 {
continue
}
for j := uint(0); j < 8; j++ {
if inUseUnmarked&(1<<j) != 0 {
s := ha.spans[arenaPage+uint(i)*8+j]
if s, ok := sl.tryAcquire(s); ok {
npages := s.npages
unlock(&h.lock)
if s.sweep(false) {
nFreed += npages
}
lock(&h.lock)
// Reload inUse. It's possible nearby
// spans were freed when we dropped the
// lock and we don't want to get stale
// pointers from the spans array.
inUseUnmarked = atomic.Load8(&inUse[i]) &^ marked[i]
}
}
}
}
// Advance.
pageIdx += uintptr(len(inUse) * 8)
n -= uintptr(len(inUse) * 8)
}
sweep.active.end(sl)
if trace.enabled {
unlock(&h.lock)
// Account for pages scanned but not reclaimed.
traceGCSweepSpan((n0 - nFreed) * pageSize)
lock(&h.lock)
}
assertLockHeld(&h.lock) // Must be locked on return.
return nFreed
}
// spanAllocType represents the type of allocation to make, or
// the type of allocation to be freed.
type spanAllocType uint8
const (
spanAllocHeap spanAllocType = iota // heap span
spanAllocStack // stack span
spanAllocPtrScalarBits // unrolled GC prog bitmap span
spanAllocWorkBuf // work buf span
)
// manual returns true if the span allocation is manually managed.
func (s spanAllocType) manual() bool {
return s != spanAllocHeap
}
// alloc allocates a new span of npages pages from the GC'd heap.
//
// spanclass indicates the span's size class and scannability.
//
// Returns a span that has been fully initialized. span.needzero indicates
// whether the span has been zeroed. Note that it may not be.
func (h *mheap) alloc(npages uintptr, spanclass spanClass) *mspan {
// Don't do any operations that lock the heap on the G stack.
// It might trigger stack growth, and the stack growth code needs
// to be able to allocate heap.
var s *mspan
systemstack(func() {
// To prevent excessive heap growth, before allocating n pages
// we need to sweep and reclaim at least n pages.
if !isSweepDone() {
h.reclaim(npages)
}
s = h.allocSpan(npages, spanAllocHeap, spanclass)
})
return s
}
// allocManual allocates a manually-managed span of npages pages.
// allocManual returns nil if allocation fails.
//
// allocManual adds the bytes used to the memstats in-use field
// corresponding to typ. Unlike allocations in the GC'd heap, the
// allocation does *not* count toward heapInUse.
//
// The memory backing the returned span may not be zeroed if
// span.needzero is set.
//
// allocManual must be called on the system stack because it may
// acquire the heap lock via allocSpan. See mheap for details.
//
// If new code is written to call allocManual, do NOT use an
// existing spanAllocType value and instead declare a new one.
//
//go:systemstack
func (h *mheap) allocManual(npages uintptr, typ spanAllocType) *mspan {
if !typ.manual() {
throw("manual span allocation called with non-manually-managed type")
}
return h.allocSpan(npages, typ, 0)
}
// setSpans modifies the span map so [spanOf(base), spanOf(base+npage*pageSize))
// is s.
func (h *mheap) setSpans(base, npage uintptr, s *mspan) {
p := base / pageSize
ai := arenaIndex(base)
ha := h.arenas[ai.l1()][ai.l2()]
for n := uintptr(0); n < npage; n++ {
i := (p + n) % pagesPerArena
if i == 0 {
ai = arenaIndex(base + n*pageSize)
ha = h.arenas[ai.l1()][ai.l2()]
}
ha.spans[i] = s
}
}
// allocNeedsZero checks if the region of address space [base, base+npage*pageSize),
// assumed to be allocated, needs to be zeroed, updating heap arena metadata for
// future allocations.
//
// This must be called each time pages are allocated from the heap, even if the page
// allocator can otherwise prove the memory it's allocating is already zero because
// they're fresh from the operating system. It updates heapArena metadata that is
// critical for future page allocations.
//
// There are no locking constraints on this method.
func (h *mheap) allocNeedsZero(base, npage uintptr) (needZero bool) {
for npage > 0 {
ai := arenaIndex(base)
ha := h.arenas[ai.l1()][ai.l2()]
zeroedBase := atomic.Loaduintptr(&ha.zeroedBase)
arenaBase := base % heapArenaBytes
if arenaBase < zeroedBase {
// We extended into the non-zeroed part of the
// arena, so this region needs to be zeroed before use.
//
// zeroedBase is monotonically increasing, so if we see this now then
// we can be sure we need to zero this memory region.
//
// We still need to update zeroedBase for this arena, and
// potentially more arenas.
needZero = true
}
// We may observe arenaBase > zeroedBase if we're racing with one or more
// allocations which are acquiring memory directly before us in the address
// space. But, because we know no one else is acquiring *this* memory, it's
// still safe to not zero.
// Compute how far we extend into the arena, capped
// at heapArenaBytes.
arenaLimit := arenaBase + npage*pageSize
if arenaLimit > heapArenaBytes {
arenaLimit = heapArenaBytes
}
// Increase ha.zeroedBase so it's >= arenaLimit.
// We may be racing with other updates.
for arenaLimit > zeroedBase {
if atomic.Casuintptr(&ha.zeroedBase, zeroedBase, arenaLimit) {
break
}
zeroedBase = atomic.Loaduintptr(&ha.zeroedBase)
// Double check basic conditions of zeroedBase.
if zeroedBase <= arenaLimit && zeroedBase > arenaBase {
// The zeroedBase moved into the space we were trying to
// claim. That's very bad, and indicates someone allocated
// the same region we did.
throw("potentially overlapping in-use allocations detected")
}
}
// Move base forward and subtract from npage to move into
// the next arena, or finish.
base += arenaLimit - arenaBase
npage -= (arenaLimit - arenaBase) / pageSize
}
return
}
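// Concretely (hypothetical numbers): if an arena's zeroedBase is 1<<20 and
// we allocate [1<<20, 2<<20), then arenaBase == zeroedBase, needZero stays
// false (the memory is known-zero), and zeroedBase is CAS'd up to 2<<20.
// Allocating [512<<10, 768<<10) instead starts below zeroedBase, so the
// region may have been used before and needZero is true.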
// tryAllocMSpan attempts to allocate an mspan object from
// the P-local cache, but may fail.
//
// h.lock need not be held.
//
// The caller must ensure that its P won't change underneath
// it during this function. Currently we enforce this by requiring
// that the function run on the system stack, because that's
// the only place it is used now. In the future, this requirement
// may be relaxed if its use is necessary elsewhere.
//
//go:systemstack
func (h *mheap) tryAllocMSpan() *mspan {
pp := getg().m.p.ptr()
// If we don't have a p or the cache is empty, we can't do
// anything here.
if pp == nil || pp.mspancache.len == 0 {
return nil
}
// Pull off the last entry in the cache.
s := pp.mspancache.buf[pp.mspancache.len-1]
pp.mspancache.len--
return s
}
// allocMSpanLocked allocates an mspan object.
//
// h.lock must be held.
//
// allocMSpanLocked must be called on the system stack because
// its caller holds the heap lock. See mheap for details.
// Running on the system stack also ensures that we won't
// switch Ps during this function. See tryAllocMSpan for details.
//
//go:systemstack
func (h *mheap) allocMSpanLocked() *mspan {
assertLockHeld(&h.lock)
pp := getg().m.p.ptr()
if pp == nil {
// We don't have a p so just do the normal thing.
return (*mspan)(h.spanalloc.alloc())
}
// Refill the cache if necessary.
if pp.mspancache.len == 0 {
const refillCount = len(pp.mspancache.buf) / 2
for i := 0; i < refillCount; i++ {
pp.mspancache.buf[i] = (*mspan)(h.spanalloc.alloc())
}
pp.mspancache.len = refillCount
}
// Pull off the last entry in the cache.
s := pp.mspancache.buf[pp.mspancache.len-1]
pp.mspancache.len--
return s
}
// freeMSpanLocked frees an mspan object.
//
// h.lock must be held.
//
// freeMSpanLocked must be called on the system stack because
// its caller holds the heap lock. See mheap for details.
// Running on the system stack also ensures that we won't
// switch Ps during this function. See tryAllocMSpan for details.
//
//go:systemstack
func (h *mheap) freeMSpanLocked(s *mspan) {
assertLockHeld(&h.lock)
pp := getg().m.p.ptr()
// First try to free the mspan directly to the cache.
if pp != nil && pp.mspancache.len < len(pp.mspancache.buf) {
pp.mspancache.buf[pp.mspancache.len] = s
pp.mspancache.len++
return
}
// Failing that (or if we don't have a p), just free it to
// the heap.
h.spanalloc.free(unsafe.Pointer(s))
}
// allocSpan allocates an mspan which owns npages worth of memory.
//
// If typ.manual() == false, allocSpan allocates a heap span of class spanclass
// and updates heap accounting. If typ.manual() == true, allocSpan allocates a
// manually-managed span (spanclass is ignored), and the caller is
// responsible for any accounting related to its use of the span. Either
// way, allocSpan will atomically add the bytes in the newly allocated
// span to *sysStat.
//
// The returned span is fully initialized.
//
// h.lock must not be held.
//
// allocSpan must be called on the system stack both because it acquires
// the heap lock and because it must block GC transitions.
//
//go:systemstack
func (h *mheap) allocSpan(npages uintptr, typ spanAllocType, spanclass spanClass) (s *mspan) {
// Function-global state.
gp := getg()
base, scav := uintptr(0), uintptr(0)
growth := uintptr(0)
// On some platforms we need to provide physical page aligned stack
// allocations. Where the page size is less than the physical page
// size, we already manage to do this by default.
needPhysPageAlign := physPageAlignedStacks && typ == spanAllocStack && pageSize < physPageSize
// If the allocation is small enough, try the page cache!
// The page cache does not support aligned allocations, so we cannot use
// it if we need to provide a physical page aligned stack allocation.
pp := gp.m.p.ptr()
if !needPhysPageAlign && pp != nil && npages < pageCachePages/4 {
c := &pp.pcache
// If the cache is empty, refill it.
if c.empty() {
lock(&h.lock)
*c = h.pages.allocToCache()
unlock(&h.lock)
}
// Try to allocate from the cache.
base, scav = c.alloc(npages)
if base != 0 {
s = h.tryAllocMSpan()
if s != nil {
goto HaveSpan
}
// We have a base but no mspan, so we need
// to lock the heap.
}
}
// For one reason or another, we couldn't get the
// whole job done without the heap lock.
lock(&h.lock)
if needPhysPageAlign {
// Overallocate by a physical page to allow for later alignment.
extraPages := physPageSize / pageSize
// Find a big enough region first, but then only allocate the
// aligned portion. We can't just allocate and then free the
// edges because we need to account for scavenged memory, and
// that's difficult with alloc.
//
// Note that we skip updates to searchAddr here. It's OK if
// it's stale and higher than normal; it'll operate correctly,
// just come with a performance cost.
base, _ = h.pages.find(npages + extraPages)
if base == 0 {
var ok bool
growth, ok = h.grow(npages + extraPages)
if !ok {
unlock(&h.lock)
return nil
}
base, _ = h.pages.find(npages + extraPages)
if base == 0 {
throw("grew heap, but no adequate free space found")
}
}
base = alignUp(base, physPageSize)
scav = h.pages.allocRange(base, npages)
}
if base == 0 {
// Try to acquire a base address.
base, scav = h.pages.alloc(npages)
if base == 0 {
var ok bool
growth, ok = h.grow(npages)
if !ok {
unlock(&h.lock)
return nil
}
base, scav = h.pages.alloc(npages)
if base == 0 {
throw("grew heap, but no adequate free space found")
}
}
}
if s == nil {
// We failed to get an mspan earlier, so grab
// one now that we have the heap lock.
s = h.allocMSpanLocked()
}
unlock(&h.lock)
HaveSpan:
// Decide if we need to scavenge in response to what we just allocated.
// Specifically, we track the maximum amount of memory to scavenge of all
// the alternatives below, assuming that the maximum satisfies *all*
// conditions we check (e.g. if we need to scavenge X to satisfy the
// memory limit and Y to satisfy heap-growth scavenging, and Y > X, then
// it's fine to pick Y, because the memory limit is still satisfied).
//
// It's fine to do this after allocating because we expect any scavenged
// pages not to get touched until we return. Simultaneously, it's important
// to do this before calling sysUsed because that may commit address space.
bytesToScavenge := uintptr(0)
if limit := gcController.memoryLimit.Load(); !gcCPULimiter.limiting() {
// Assist with scavenging to maintain the memory limit by the amount
// that we expect to page in.
inuse := gcController.mappedReady.Load()
// Be careful about overflow, especially with uintptrs. Even on 32-bit platforms
// someone can set a really big memory limit that isn't maxInt64.
if uint64(scav)+inuse > uint64(limit) {
bytesToScavenge = uintptr(uint64(scav) + inuse - uint64(limit))
}
}
if goal := scavenge.gcPercentGoal.Load(); goal != ^uint64(0) && growth > 0 {
// We just caused a heap growth, so scavenge down what will soon be used.
// By scavenging inline we deal with the failure to allocate out of
// memory fragments by scavenging the memory fragments that are least
// likely to be re-used.
//
// Only bother with this because we're not using a memory limit. We don't
// care about heap growths as long as we're under the memory limit, and
// the previous check for scavenging already handles that.
if retained := heapRetained(); retained+uint64(growth) > goal {
// The scavenging algorithm acquires the heap lock only sparingly, so it
// must be called with the lock dropped. Scavenging is a potentially
// expensive operation, so running it here also frees up other goroutines
// to allocate in the meantime; in fact, they can make use of the growth
// we just created.
todo := growth
if overage := uintptr(retained + uint64(growth) - goal); todo > overage {
todo = overage
}
if todo > bytesToScavenge {
bytesToScavenge = todo
}
}
}
// There are a few very limited circumstances where we won't have a P here.
// It's OK to simply skip scavenging in these cases. Something else will notice
// and pick up the tab.
var now int64
if pp != nil && bytesToScavenge > 0 {
// Measure how long we spent scavenging and add that measurement to the assist
// time so we can track it for the GC CPU limiter.
//
// Limiter event tracking might be disabled if we end up here
// while on a mark worker.
start := nanotime()
track := pp.limiterEvent.start(limiterEventScavengeAssist, start)
// Scavenge, but back out if the limiter turns on.
h.pages.scavenge(bytesToScavenge, func() bool {
return gcCPULimiter.limiting()
})
// Finish up accounting.
now = nanotime()
if track {
pp.limiterEvent.stop(limiterEventScavengeAssist, now)
}
scavenge.assistTime.Add(now - start)
}
// Initialize the span.
h.initSpan(s, typ, spanclass, base, npages)
// Commit and account for any scavenged memory that the span now owns.
nbytes := npages * pageSize
if scav != 0 {
// sysUsed all the pages that are actually available
// in the span since some of them might be scavenged.
sysUsed(unsafe.Pointer(base), nbytes, scav)
gcController.heapReleased.add(-int64(scav))
}
// Update stats.
gcController.heapFree.add(-int64(nbytes - scav))
if typ == spanAllocHeap {
gcController.heapInUse.add(int64(nbytes))
}
// Update consistent stats.
stats := memstats.heapStats.acquire()
atomic.Xaddint64(&stats.committed, int64(scav))
atomic.Xaddint64(&stats.released, -int64(scav))
switch typ {
case spanAllocHeap:
atomic.Xaddint64(&stats.inHeap, int64(nbytes))
case spanAllocStack:
atomic.Xaddint64(&stats.inStacks, int64(nbytes))
case spanAllocPtrScalarBits:
atomic.Xaddint64(&stats.inPtrScalarBits, int64(nbytes))
case spanAllocWorkBuf:
atomic.Xaddint64(&stats.inWorkBufs, int64(nbytes))
}
memstats.heapStats.release()
pageTraceAlloc(pp, now, base, npages)
return s
}
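// A hedged sketch of the fast-path threshold at the top of allocSpan
// (illustrative only; the concrete value assumes the pageCache bitmap is a
// uint64, which makes pageCachePages == 64):
//
//	npages := uintptr(4)
//	if npages < pageCachePages/4 { // 4 < 16: small enough for the per-P cache
//		// serviced from pp.pcache, normally without taking h.lock
//	}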
// initSpan initializes a blank span s which will represent the range
// [base, base+npages*pageSize). typ is the type of span being allocated.
func (h *mheap) initSpan(s *mspan, typ spanAllocType, spanclass spanClass, base, npages uintptr) {
// At this point, both s != nil and base != 0, and the heap
// lock is no longer held. Initialize the span.
s.init(base, npages)
if h.allocNeedsZero(base, npages) {
s.needzero = 1
}
nbytes := npages * pageSize
if typ.manual() {
s.manualFreeList = 0
s.nelems = 0
s.limit = s.base() + s.npages*pageSize
s.state.set(mSpanManual)
} else {
// We must set span properties before the span is published anywhere
// since we're not holding the heap lock.
s.spanclass = spanclass
if sizeclass := spanclass.sizeclass(); sizeclass == 0 {
s.elemsize = nbytes
s.nelems = 1
s.divMul = 0
} else {
s.elemsize = uintptr(class_to_size[sizeclass])
s.nelems = nbytes / s.elemsize
s.divMul = class_to_divmagic[sizeclass]
}
// Initialize mark and allocation structures.
s.freeindex = 0
s.freeIndexForScan = 0
s.allocCache = ^uint64(0) // all 1s indicating all free.
s.gcmarkBits = newMarkBits(s.nelems)
s.allocBits = newAllocBits(s.nelems)
// It's safe to access h.sweepgen without the heap lock because it's
// only ever updated with the world stopped and we run on the
// systemstack which blocks a STW transition.
atomic.Store(&s.sweepgen, h.sweepgen)
// Now that the span is filled in, set its state. This
// is a publication barrier for the other fields in
// the span. While valid pointers into this span
// should never be visible until the span is returned,
// if the garbage collector finds an invalid pointer,
// access to the span may race with initialization of
// the span. We resolve this race by atomically
// setting the state after the span is fully
// initialized, and atomically checking the state in
// any situation where a pointer is suspect.
s.state.set(mSpanInUse)
}
// Publish the span in various locations.
// This is safe to call without the lock held because the slots
// related to this span will only ever be read or modified by
// this thread until pointers into the span are published (and
// we execute a publication barrier at the end of this function
// before that happens) or pageInUse is updated.
h.setSpans(s.base(), npages, s)
if !typ.manual() {
// Mark in-use span in arena page bitmap.
//
// This publishes the span to the page sweeper, so
// it's imperative that the span be completely initialized
// prior to this line.
arena, pageIdx, pageMask := pageIndexOf(s.base())
atomic.Or8(&arena.pageInUse[pageIdx], pageMask)
// Update related page sweeper stats.
h.pagesInUse.Add(npages)
}
// Make sure the newly allocated span will be observed
// by the GC before pointers into the span are published.
publicationBarrier()
}
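// A worked example of the element math above (illustrative; assumes an
// 8 KiB page and a hypothetical 48-byte size class): for a one-page span,
// nbytes == 8192, so elemsize == 48 gives nelems == 8192/48 == 170 objects,
// with the trailing 32 bytes unused. A spanclass whose sizeclass() is 0
// instead denotes a large object: the span holds exactly one element of
// nbytes bytes.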
// Try to add at least npage pages of memory to the heap,
// returning how much the heap grew by and whether it worked.
//
// h.lock must be held.
func (h *mheap) grow(npage uintptr) (uintptr, bool) {
assertLockHeld(&h.lock)
// We must grow the heap in whole palloc chunks.
// We call sysMap below, but because we round up to
// pallocChunkPages, which is on the order of MiB
// (generally at least the huge page size), we won't
// be calling it very often.
ask := alignUp(npage, pallocChunkPages) * pageSize
totalGrowth := uintptr(0)
// This may overflow because ask could be very large
// and is otherwise unrelated to h.curArena.base.
end := h.curArena.base + ask
nBase := alignUp(end, physPageSize)
if nBase > h.curArena.end || /* overflow */ end < h.curArena.base {
// Not enough room in the current arena. Allocate more
// arena space. This may not be contiguous with the
// current arena, so we have to request the full ask.
av, asize := h.sysAlloc(ask, &h.arenaHints, true)
if av == nil {
inUse := gcController.heapFree.load() + gcController.heapReleased.load() + gcController.heapInUse.load()
print("runtime: out of memory: cannot allocate ", ask, "-byte block (", inUse, " in use)\n")
return 0, false
}
if uintptr(av) == h.curArena.end {
// The new space is contiguous with the old
// space, so just extend the current space.
h.curArena.end = uintptr(av) + asize
} else {
// The new space is discontiguous. Track what
// remains of the current space and switch to
// the new space. This should be rare.
if size := h.curArena.end - h.curArena.base; size != 0 {
// Transition this space from Reserved to Prepared and mark it
// as released, since we'll be able to start using it at any time
// after updating the page allocator and releasing the lock.
sysMap(unsafe.Pointer(h.curArena.base), size, &gcController.heapReleased)
// Update stats.
stats := memstats.heapStats.acquire()
atomic.Xaddint64(&stats.released, int64(size))
memstats.heapStats.release()
// Update the page allocator's structures to make this
// space ready for allocation.
h.pages.grow(h.curArena.base, size)
totalGrowth += size
}
// Switch to the new space.
h.curArena.base = uintptr(av)
h.curArena.end = uintptr(av) + asize
}
// Recalculate nBase.
// We know this won't overflow, because sysAlloc returned
// a valid region starting at h.curArena.base which is at
// least ask bytes in size.
nBase = alignUp(h.curArena.base+ask, physPageSize)
}
// Grow into the current arena.
v := h.curArena.base
h.curArena.base = nBase
// Transition the space we're going to use from Reserved to Prepared.
//
// The allocation is always aligned to the heap arena
// size, which is always > physPageSize, so it's safe to
// just add directly to heapReleased.
sysMap(unsafe.Pointer(v), nBase-v, &gcController.heapReleased)
// The memory just allocated counts as both released
// and idle, even though it's not yet backed by spans.
stats := memstats.heapStats.acquire()
atomic.Xaddint64(&stats.released, int64(nBase-v))
memstats.heapStats.release()
// Update the page allocator's structures to make this
// space ready for allocation.
h.pages.grow(v, nBase-v)
totalGrowth += nBase - v
return totalGrowth, true
}
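// A worked example of the rounding at the top of grow (illustrative; assumes
// 8 KiB pages, so pallocChunkPages == 512 gives 4 MiB chunks): a request for
// npage == 5 becomes
//
//	ask := alignUp(5, 512) * 8192 // = 512*8192 = 4 MiB
//
// so sysMap is called in multi-megabyte units rather than once per small
// allocation.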
// Free the span back into the heap.
func (h *mheap) freeSpan(s *mspan) {
systemstack(func() {
pageTraceFree(getg().m.p.ptr(), 0, s.base(), s.npages)
lock(&h.lock)
if msanenabled {
// Tell msan that this entire span is no longer in use.
base := unsafe.Pointer(s.base())
bytes := s.npages << _PageShift
msanfree(base, bytes)
}
if asanenabled {
// Tell asan that this entire span is no longer in use.
base := unsafe.Pointer(s.base())
bytes := s.npages << _PageShift
asanpoison(base, bytes)
}
h.freeSpanLocked(s, spanAllocHeap)
unlock(&h.lock)
})
}
// freeManual frees a manually-managed span returned by allocManual.
// typ must be the same as the spanAllocType passed to the allocManual that
// allocated s.
//
// This must only be called when gcphase == _GCoff. See mSpanState for
// an explanation.
//
// freeManual must be called on the system stack because it acquires
// the heap lock. See mheap for details.
//
//go:systemstack
func (h *mheap) freeManual(s *mspan, typ spanAllocType) {
pageTraceFree(getg().m.p.ptr(), 0, s.base(), s.npages)
s.needzero = 1
lock(&h.lock)
h.freeSpanLocked(s, typ)
unlock(&h.lock)
}
func (h *mheap) freeSpanLocked(s *mspan, typ spanAllocType) {
assertLockHeld(&h.lock)
switch s.state.get() {
case mSpanManual:
if s.allocCount != 0 {
throw("mheap.freeSpanLocked - invalid stack free")
}
case mSpanInUse:
if s.isUserArenaChunk {
throw("mheap.freeSpanLocked - invalid free of user arena chunk")
}
if s.allocCount != 0 || s.sweepgen != h.sweepgen {
print("mheap.freeSpanLocked - span ", s, " ptr ", hex(s.base()), " allocCount ", s.allocCount, " sweepgen ", s.sweepgen, "/", h.sweepgen, "\n")
throw("mheap.freeSpanLocked - invalid free")
}
h.pagesInUse.Add(-s.npages)
// Clear in-use bit in arena page bitmap.
arena, pageIdx, pageMask := pageIndexOf(s.base())
atomic.And8(&arena.pageInUse[pageIdx], ^pageMask)
default:
throw("mheap.freeSpanLocked - invalid span state")
}
// Update stats.
//
// Mirrors the code in allocSpan.
nbytes := s.npages * pageSize
gcController.heapFree.add(int64(nbytes))
if typ == spanAllocHeap {
gcController.heapInUse.add(-int64(nbytes))
}
// Update consistent stats.
stats := memstats.heapStats.acquire()
switch typ {
case spanAllocHeap:
atomic.Xaddint64(&stats.inHeap, -int64(nbytes))
case spanAllocStack:
atomic.Xaddint64(&stats.inStacks, -int64(nbytes))
case spanAllocPtrScalarBits:
atomic.Xaddint64(&stats.inPtrScalarBits, -int64(nbytes))
case spanAllocWorkBuf:
atomic.Xaddint64(&stats.inWorkBufs, -int64(nbytes))
}
memstats.heapStats.release()
// Mark the space as free.
h.pages.free(s.base(), s.npages, false)
// Free the span structure. We no longer have a use for it.
s.state.set(mSpanDead)
h.freeMSpanLocked(s)
}
// scavengeAll acquires the heap lock (blocking any additional
// manipulation of the page allocator) and iterates over the whole
// heap, scavenging every free page available.
func (h *mheap) scavengeAll() {
// Disallow malloc or panic while holding the heap lock. We do
// this here because this is a non-mallocgc entry-point to
// the mheap API.
gp := getg()
gp.m.mallocing++
released := h.pages.scavenge(^uintptr(0), nil)
gp.m.mallocing--
if debug.scavtrace > 0 {
printScavTrace(released, true)
}
}
//go:linkname runtime_debug_freeOSMemory runtime/debug.freeOSMemory
func runtime_debug_freeOSMemory() {
GC()
systemstack(func() { mheap_.scavengeAll() })
}
// Initialize a new span with the given start and npages.
func (span *mspan) init(base uintptr, npages uintptr) {
// span is *not* zeroed.
span.next = nil
span.prev = nil
span.list = nil
span.startAddr = base
span.npages = npages
span.allocCount = 0
span.spanclass = 0
span.elemsize = 0
span.speciallock.key = 0
span.specials = nil
span.needzero = 0
span.freeindex = 0
span.freeIndexForScan = 0
span.allocBits = nil
span.gcmarkBits = nil
span.state.set(mSpanDead)
lockInit(&span.speciallock, lockRankMspanSpecial)
}
func (span *mspan) inList() bool {
return span.list != nil
}
// Initialize an empty doubly-linked list.
func (list *mSpanList) init() {
list.first = nil
list.last = nil
}
func (list *mSpanList) remove(span *mspan) {
if span.list != list {
print("runtime: failed mSpanList.remove span.npages=", span.npages,
" span=", span, " prev=", span.prev, " span.list=", span.list, " list=", list, "\n")
throw("mSpanList.remove")
}
if list.first == span {
list.first = span.next
} else {
span.prev.next = span.next
}
if list.last == span {
list.last = span.prev
} else {
span.next.prev = span.prev
}
span.next = nil
span.prev = nil
span.list = nil
}
func (list *mSpanList) isEmpty() bool {
return list.first == nil
}
func (list *mSpanList) insert(span *mspan) {
if span.next != nil || span.prev != nil || span.list != nil {
println("runtime: failed mSpanList.insert", span, span.next, span.prev, span.list)
throw("mSpanList.insert")
}
span.next = list.first
if list.first != nil {
// The list contains at least one span; link it in.
// The last span in the list doesn't change.
list.first.prev = span
} else {
// The list contains no spans, so this is also the last span.
list.last = span
}
list.first = span
span.list = list
}
func (list *mSpanList) insertBack(span *mspan) {
if span.next != nil || span.prev != nil || span.list != nil {
println("runtime: failed mSpanList.insertBack", span, span.next, span.prev, span.list)
throw("mSpanList.insertBack")
}
span.prev = list.last
if list.last != nil {
// The list contains at least one span.
list.last.next = span
} else {
// The list contains no spans, so this is also the first span.
list.first = span
}
list.last = span
span.list = list
}
// takeAll removes all spans from other and inserts them at the front
// of list.
func (list *mSpanList) takeAll(other *mSpanList) {
if other.isEmpty() {
return
}
// Reparent everything in other to list.
for s := other.first; s != nil; s = s.next {
s.list = list
}
// Concatenate the lists.
if list.isEmpty() {
*list = *other
} else {
// Neither list is empty. Put other before list.
other.last.next = list.first
list.first.prev = other.last
list.first = other.first
}
other.first, other.last = nil, nil
}
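// An illustrative usage sketch (not runtime code) of the mSpanList
// operations above, for hypothetical spans s1 and s2:
//
//	var list, other mSpanList
//	list.init()
//	other.init()
//	other.insert(s1)     // other: s1
//	other.insertBack(s2) // other: s1, s2
//	list.takeAll(&other) // list: s1, s2; other is now empty
//	list.remove(s1)      // list: s2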
const (
_KindSpecialFinalizer = 1
_KindSpecialProfile = 2
// _KindSpecialReachable is a special used for tracking
// reachability during testing.
_KindSpecialReachable = 3
// Note: The finalizer special must be first because if we're freeing
// an object, a finalizer special will cause the freeing operation
// to abort, and we want to keep the other special records around
// if that happens.
)
type special struct {
_ sys.NotInHeap
next *special // linked list in span
offset uint16 // span offset of object
kind byte // kind of special
}
// spanHasSpecials marks a span as having specials in the arena bitmap.
func spanHasSpecials(s *mspan) {
arenaPage := (s.base() / pageSize) % pagesPerArena
ai := arenaIndex(s.base())
ha := mheap_.arenas[ai.l1()][ai.l2()]
atomic.Or8(&ha.pageSpecials[arenaPage/8], uint8(1)<<(arenaPage%8))
}
// spanHasNoSpecials marks a span as having no specials in the arena bitmap.
func spanHasNoSpecials(s *mspan) {
arenaPage := (s.base() / pageSize) % pagesPerArena
ai := arenaIndex(s.base())
ha := mheap_.arenas[ai.l1()][ai.l2()]
atomic.And8(&ha.pageSpecials[arenaPage/8], ^(uint8(1) << (arenaPage % 8)))
}
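// A worked example of the bitmap math above (illustrative): for a span whose
// base is 7 pages into its arena, arenaPage == 7, so its bit lives in
// pageSpecials[7/8] == pageSpecials[0] under mask 1<<(7%8) == 0x80.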
// Adds the special record s to the list of special records for
// the object p. All fields of s should be filled in except for
// offset & next, which this routine will fill in.
// Returns true if the special was successfully added, false otherwise.
// (The add will fail only if a record with the same p and s.kind
// already exists.)
func addspecial(p unsafe.Pointer, s *special) bool {
span := spanOfHeap(uintptr(p))
if span == nil {
throw("addspecial on invalid pointer")
}
// Ensure that the span is swept.
// Sweeping accesses the specials list w/o locks, so we have
// to synchronize with it. And it's just much safer.
mp := acquirem()
span.ensureSwept()
offset := uintptr(p) - span.base()
kind := s.kind
lock(&span.speciallock)
// Find splice point, check for existing record.
t := &span.specials
for {
x := *t
if x == nil {
break
}
if offset == uintptr(x.offset) && kind == x.kind {
unlock(&span.speciallock)
releasem(mp)
return false // already exists
}
if offset < uintptr(x.offset) || (offset == uintptr(x.offset) && kind < x.kind) {
break
}
t = &x.next
}
// Splice in record, fill in offset.
s.offset = uint16(offset)
s.next = *t
*t = s
spanHasSpecials(span)
unlock(&span.speciallock)
releasem(mp)
return true
}
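// The splice loop above keeps each span's specials list sorted by
// (offset, kind). For example (illustrative), adding specials of the same
// kind at offsets 48, 16, and 32, in that order, yields the list
// 16 -> 32 -> 48, and a second add at offset 32 with the same kind
// returns false.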
// Removes the special record of the given kind for the object p.
// Returns the record if the record existed, nil otherwise.
// The caller must FixAlloc_Free the result.
func removespecial(p unsafe.Pointer, kind uint8) *special {
span := spanOfHeap(uintptr(p))
if span == nil {
throw("removespecial on invalid pointer")
}
// Ensure that the span is swept.
// Sweeping accesses the specials list w/o locks, so we have
// to synchronize with it. And it's just much safer.
mp := acquirem()
span.ensureSwept()
offset := uintptr(p) - span.base()
var result *special
lock(&span.speciallock)
t := &span.specials
for {
s := *t
if s == nil {
break
}
// This function is used for finalizers only, so we don't check for
// "interior" specials (p must be exactly equal to s.offset).
if offset == uintptr(s.offset) && kind == s.kind {
*t = s.next
result = s
break
}
t = &s.next
}
if span.specials == nil {
spanHasNoSpecials(span)
}
unlock(&span.speciallock)
releasem(mp)
return result
}
// The described object has a finalizer set for it.
//
// specialfinalizer is allocated from non-GC'd memory, so any heap
// pointers must be specially handled.
type specialfinalizer struct {
_ sys.NotInHeap
special special
fn *funcval // May be a heap pointer.
nret uintptr
fint *_type // May be a heap pointer, but always live.
ot *ptrtype // May be a heap pointer, but always live.
}
// Adds a finalizer to the object p. Returns true if it succeeded.
func addfinalizer(p unsafe.Pointer, f *funcval, nret uintptr, fint *_type, ot *ptrtype) bool {
lock(&mheap_.speciallock)
s := (*specialfinalizer)(mheap_.specialfinalizeralloc.alloc())
unlock(&mheap_.speciallock)
s.special.kind = _KindSpecialFinalizer
s.fn = f
s.nret = nret
s.fint = fint
s.ot = ot
if addspecial(p, &s.special) {
// This is responsible for maintaining the same
// GC-related invariants as markrootSpans in any
// situation where it's possible that markrootSpans
// has already run but mark termination hasn't yet.
if gcphase != _GCoff {
base, span, _ := findObject(uintptr(p), 0, 0)
mp := acquirem()
gcw := &mp.p.ptr().gcw
// Mark everything reachable from the object
// so it's retained for the finalizer.
if !span.spanclass.noscan() {
scanobject(base, gcw)
}
// Mark the finalizer itself, since the
// special isn't part of the GC'd heap.
scanblock(uintptr(unsafe.Pointer(&s.fn)), goarch.PtrSize, &oneptrmask[0], gcw, nil)
releasem(mp)
}
return true
}
// There was an old finalizer
lock(&mheap_.speciallock)
mheap_.specialfinalizeralloc.free(unsafe.Pointer(s))
unlock(&mheap_.speciallock)
return false
}
// Removes the finalizer (if any) from the object p.
func removefinalizer(p unsafe.Pointer) {
s := (*specialfinalizer)(unsafe.Pointer(removespecial(p, _KindSpecialFinalizer)))
if s == nil {
return // there wasn't a finalizer to remove
}
lock(&mheap_.speciallock)
mheap_.specialfinalizeralloc.free(unsafe.Pointer(s))
unlock(&mheap_.speciallock)
}
// The described object is being heap profiled.
type specialprofile struct {
_ sys.NotInHeap
special special
b *bucket
}
// Set the heap profile bucket associated with addr to b.
func setprofilebucket(p unsafe.Pointer, b *bucket) {
lock(&mheap_.speciallock)
s := (*specialprofile)(mheap_.specialprofilealloc.alloc())
unlock(&mheap_.speciallock)
s.special.kind = _KindSpecialProfile
s.b = b
if !addspecial(p, &s.special) {
throw("setprofilebucket: profile already set")
}
}
// specialReachable tracks whether an object is reachable on the next
// GC cycle. This is used by testing.
type specialReachable struct {
special special
done bool
reachable bool
}
// specialsIter helps iterate over specials lists.
type specialsIter struct {
pprev **special
s *special
}
func newSpecialsIter(span *mspan) specialsIter {
return specialsIter{&span.specials, span.specials}
}
func (i *specialsIter) valid() bool {
return i.s != nil
}
func (i *specialsIter) next() {
i.pprev = &i.s.next
i.s = *i.pprev
}
// unlinkAndNext removes the current special from the list and moves
// the iterator to the next special. It returns the unlinked special.
func (i *specialsIter) unlinkAndNext() *special {
cur := i.s
i.s = cur.next
*i.pprev = i.s
return cur
}
// freeSpecial performs any cleanup on special s and deallocates it.
// s must already be unlinked from the specials list.
func freeSpecial(s *special, p unsafe.Pointer, size uintptr) {
switch s.kind {
case _KindSpecialFinalizer:
sf := (*specialfinalizer)(unsafe.Pointer(s))
queuefinalizer(p, sf.fn, sf.nret, sf.fint, sf.ot)
lock(&mheap_.speciallock)
mheap_.specialfinalizeralloc.free(unsafe.Pointer(sf))
unlock(&mheap_.speciallock)
case _KindSpecialProfile:
sp := (*specialprofile)(unsafe.Pointer(s))
mProf_Free(sp.b, size)
lock(&mheap_.speciallock)
mheap_.specialprofilealloc.free(unsafe.Pointer(sp))
unlock(&mheap_.speciallock)
case _KindSpecialReachable:
sp := (*specialReachable)(unsafe.Pointer(s))
sp.done = true
// The creator frees these.
default:
throw("bad special kind")
panic("not reached")
}
}
// gcBits is an alloc/mark bitmap. This is always used as gcBits.x.
type gcBits struct {
_ sys.NotInHeap
x uint8
}
// bytep returns a pointer to the n'th byte of b.
func (b *gcBits) bytep(n uintptr) *uint8 {
return addb(&b.x, n)
}
// bitp returns a pointer to the byte containing bit n and a mask for
// selecting that bit from *bytep.
func (b *gcBits) bitp(n uintptr) (bytep *uint8, mask uint8) {
return b.bytep(n / 8), 1 << (n % 8)
}
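// An illustrative sketch of reading a single bit via bitp: for n == 13,
// bytep points at byte 13/8 == 1 and mask is 1<<(13%8) == 0x20.
//
//	bytep, mask := b.bitp(13)
//	set := *bytep&mask != 0 // whether bit 13 is set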
const gcBitsChunkBytes = uintptr(64 << 10)
const gcBitsHeaderBytes = unsafe.Sizeof(gcBitsHeader{})
type gcBitsHeader struct {
free uintptr // free is the index into bits of the next free byte.
next uintptr // *gcBits triggers recursive type bug. (issue 14620)
}
type gcBitsArena struct {
_ sys.NotInHeap
// gcBitsHeader // side step recursive type bug (issue 14620) by including fields by hand.
free uintptr // free is the index into bits of the next free byte; read/write atomically
next *gcBitsArena
bits [gcBitsChunkBytes - gcBitsHeaderBytes]gcBits
}
var gcBitsArenas struct {
lock mutex
free *gcBitsArena
next *gcBitsArena // Read atomically. Write atomically under lock.
current *gcBitsArena
previous *gcBitsArena
}
// tryAlloc allocates from b or returns nil if b does not have enough room.
// This is safe to call concurrently.
func (b *gcBitsArena) tryAlloc(bytes uintptr) *gcBits {
if b == nil || atomic.Loaduintptr(&b.free)+bytes > uintptr(len(b.bits)) {
return nil
}
// Try to allocate from this block.
end := atomic.Xadduintptr(&b.free, bytes)
if end > uintptr(len(b.bits)) {
return nil
}
// There was enough room.
start := end - bytes
return &b.bits[start]
}
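// Note how the check-then-Xadd pattern above behaves under a race
// (illustrative numbers): with free == 52 and len(b.bits) == 64, two
// concurrent 8-byte callers may both pass the Loaduintptr check; one Xadd
// returns end == 60 and succeeds, the other returns end == 68 > 64 and
// fails, leaving free inflated at 68. That's harmless: the arena is simply
// retired early and a fresh one is allocated under gcBitsArenas.lock.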
// newMarkBits returns a pointer to 8 byte aligned bytes
// to be used for a span's mark bits.
func newMarkBits(nelems uintptr) *gcBits {
blocksNeeded := uintptr((nelems + 63) / 64)
bytesNeeded := blocksNeeded * 8
// Try directly allocating from the current head arena.
head := (*gcBitsArena)(atomic.Loadp(unsafe.Pointer(&gcBitsArenas.next)))
if p := head.tryAlloc(bytesNeeded); p != nil {
return p
}
// There's not enough room in the head arena. We may need to
// allocate a new arena.
lock(&gcBitsArenas.lock)
// Try the head arena again, since it may have changed. Now
// that we hold the lock, the list head can't change, but its
// free position still can.
if p := gcBitsArenas.next.tryAlloc(bytesNeeded); p != nil {
unlock(&gcBitsArenas.lock)
return p
}
// Allocate a new arena. This may temporarily drop the lock.
fresh := newArenaMayUnlock()
// If newArenaMayUnlock dropped the lock, another thread may
// have put a fresh arena on the "next" list. Try allocating
// from next again.
if p := gcBitsArenas.next.tryAlloc(bytesNeeded); p != nil {
// Put fresh back on the free list.
// TODO: Mark it "already zeroed"
fresh.next = gcBitsArenas.free
gcBitsArenas.free = fresh
unlock(&gcBitsArenas.lock)
return p
}
// Allocate from the fresh arena. We haven't linked it in yet, so
// this cannot race and is guaranteed to succeed.
p := fresh.tryAlloc(bytesNeeded)
if p == nil {
throw("markBits overflow")
}
// Add the fresh arena to the "next" list.
fresh.next = gcBitsArenas.next
atomic.StorepNoWB(unsafe.Pointer(&gcBitsArenas.next), unsafe.Pointer(fresh))
unlock(&gcBitsArenas.lock)
return p
}
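// A worked example of the sizing math above (illustrative): for
// nelems == 100, blocksNeeded == (100+63)/64 == 2 and bytesNeeded == 16,
// i.e. 128 bits of bitmap, the smallest multiple of 64-bit blocks that
// covers 100 objects.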
// newAllocBits returns a pointer to 8 byte aligned bytes
// to be used for this span's alloc bits.
// newAllocBits is used to provide newly initialized spans
// allocation bits. For spans not being initialized the
// mark bits are repurposed as allocation bits when
// the span is swept.
func newAllocBits(nelems uintptr) *gcBits {
return newMarkBits(nelems)
}
// nextMarkBitArenaEpoch establishes a new epoch for the arenas
// holding the mark bits. The arenas are named relative to the
// current GC cycle, which is demarcated by the call to finishweep_m.
//
// All current spans have been swept.
// During that sweep, each span allocated room for its gcmarkBits in
// the gcBitsArenas.next block. gcBitsArenas.next becomes gcBitsArenas.current,
// where the GC will mark objects; after each span is swept, these bits
// will be used to allocate objects.
// gcBitsArenas.current becomes gcBitsArenas.previous, where the spans'
// gcAllocBits live until all the spans have been swept during this GC cycle.
// Each span's sweep extinguishes its reference to gcBitsArenas.previous
// by pointing gcAllocBits into gcBitsArenas.current.
// gcBitsArenas.previous is then released to the gcBitsArenas.free list.
func nextMarkBitArenaEpoch() {
lock(&gcBitsArenas.lock)
if gcBitsArenas.previous != nil {
if gcBitsArenas.free == nil {
gcBitsArenas.free = gcBitsArenas.previous
} else {
// Find end of previous arenas.
last := gcBitsArenas.previous
for last = gcBitsArenas.previous; last.next != nil; last = last.next {
}
last.next = gcBitsArenas.free
gcBitsArenas.free = gcBitsArenas.previous
}
}
gcBitsArenas.previous = gcBitsArenas.current
gcBitsArenas.current = gcBitsArenas.next
atomic.StorepNoWB(unsafe.Pointer(&gcBitsArenas.next), nil) // newMarkBits calls newArena when needed
unlock(&gcBitsArenas.lock)
}
// newArenaMayUnlock allocates and zeroes a gcBits arena.
// The caller must hold gcBitsArenas.lock. This may temporarily release it.
func newArenaMayUnlock() *gcBitsArena {
var result *gcBitsArena
if gcBitsArenas.free == nil {
unlock(&gcBitsArenas.lock)
result = (*gcBitsArena)(sysAlloc(gcBitsChunkBytes, &memstats.gcMiscSys))
if result == nil {
throw("runtime: cannot allocate memory")
}
lock(&gcBitsArenas.lock)
} else {
result = gcBitsArenas.free
gcBitsArenas.free = gcBitsArenas.free.next
memclrNoHeapPointers(unsafe.Pointer(result), gcBitsChunkBytes)
}
result.next = nil
// If result.bits is not 8 byte aligned adjust index so
// that &result.bits[result.free] is 8 byte aligned.
if uintptr(unsafe.Offsetof(gcBitsArena{}.bits))&7 == 0 {
result.free = 0
} else {
result.free = 8 - (uintptr(unsafe.Pointer(&result.bits[0])) & 7)
}
return result
}
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Page allocator.
//
// The page allocator manages mapped pages (defined by pageSize, NOT
// physPageSize) for allocation and re-use. It is embedded into mheap.
//
// Pages are managed using a bitmap that is sharded into chunks.
// In the bitmap, 1 means in-use, and 0 means free. The bitmap spans the
// process's address space. Chunks are managed in a sparse-array-style structure
// similar to mheap.arenas, since the bitmap may be large on some systems.
//
// The bitmap is efficiently searched by using a radix tree in combination
// with fast bit-wise intrinsics. Allocation is performed using an address-ordered
// first-fit approach.
//
// Each entry in the radix tree is a summary that describes three properties of
// a particular region of the address space: the number of contiguous free pages
// at the start and end of the region it represents, and the maximum number of
// contiguous free pages found anywhere in that region.
//
// Each level of the radix tree is stored as one contiguous array, which represents
// a different granularity of subdivision of the process's address space. Thus, this
// radix tree is actually implicit in these large arrays, as opposed to having explicit
// dynamically-allocated pointer-based node structures. Naturally, these arrays may be
// quite large for systems with large address spaces, so in these cases they are mapped
// into memory as needed. The leaf summaries of the tree correspond to a bitmap chunk.
//
// The root level (referred to as L0 and index 0 in pageAlloc.summary) has each
// summary represent the largest section of address space (16 GiB on 64-bit systems),
// with each subsequent level representing successively smaller subsections until we
// reach the finest granularity at the leaves, a chunk.
//
// More specifically, each summary in each level (except for leaf summaries)
// represents some number of entries in the following level. For example, each
// summary in the root level may represent a 16 GiB region of address space,
// and in the next level there could be 8 corresponding entries which represent 2
// GiB subsections of that 16 GiB region, each of which could correspond to 8
// entries in the next level which each represent 256 MiB regions, and so on.
//
// Thus, this design scales only to heaps of a bounded size, but it can always be
// extended to larger heaps by simply adding levels to the radix tree, which mostly
// costs additional virtual address space. The choice of managing large arrays also
// means that a large amount of virtual address space may be reserved by the runtime.
package runtime
import (
"unsafe"
)
const (
// The size of a bitmap chunk, i.e. the number of bits (that is, pages) to consider
// in the bitmap at once.
pallocChunkPages = 1 << logPallocChunkPages
pallocChunkBytes = pallocChunkPages * pageSize
logPallocChunkPages = 9
logPallocChunkBytes = logPallocChunkPages + pageShift
// The number of radix bits for each level.
//
// The value of 3 is chosen such that the block of summaries we need to scan at
// each level fits in 64 bytes (2^3 summaries * 8 bytes per summary), which is
// close to the L1 cache line width on many systems. Also, a value of 3 fits 4 tree
// levels perfectly into the 21-bit pallocBits summary field at the root level.
//
// The following equation explains how each of the constants relate:
// summaryL0Bits + (summaryLevels-1)*summaryLevelBits + logPallocChunkBytes = heapAddrBits
//
// summaryLevels is an architecture-dependent value defined in mpagealloc_*.go.
summaryLevelBits = 3
summaryL0Bits = heapAddrBits - logPallocChunkBytes - (summaryLevels-1)*summaryLevelBits
// pallocChunksL2Bits is the number of bits of the chunk index number
// covered by the second level of the chunks map.
//
// See (*pageAlloc).chunks for more details. Update the documentation
// there should this change.
pallocChunksL2Bits = heapAddrBits - logPallocChunkBytes - pallocChunksL1Bits
pallocChunksL1Shift = pallocChunksL2Bits
)
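// A worked instance of the equation above (illustrative; the concrete values
// assume a common 64-bit configuration with heapAddrBits == 48, 8 KiB pages,
// and summaryLevels == 5):
//
//	logPallocChunkBytes = 9 + 13 = 22      // 4 MiB chunks
//	summaryL0Bits = 48 - 22 - (5-1)*3 = 14 // 2^14 root summaries
//
// Each root summary therefore covers 2^(48-14) = 2^34 bytes = 16 GiB of
// address space, matching the figure quoted in the package comment.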
// maxSearchAddr returns the maximum searchAddr value, which indicates
// that the heap has no free space.
//
// This function exists just to make it clear that this is the maximum address
// for the page allocator's search space. See maxOffAddr for details.
//
// It's a function (rather than a variable) because it needs to be
// usable before package runtime's dynamic initialization is complete.
// See #51913 for details.
func maxSearchAddr() offAddr { return maxOffAddr }
// Global chunk index.
//
// Represents an index into the leaf level of the radix tree.
// Similar to arenaIndex, except instead of arenas, it divides the address
// space into chunks.
type chunkIdx uint
// chunkIndex returns the global index of the palloc chunk containing the
// pointer p.
func chunkIndex(p uintptr) chunkIdx {
return chunkIdx((p - arenaBaseOffset) / pallocChunkBytes)
}
// chunkBase returns the base address of the palloc chunk at index ci.
func chunkBase(ci chunkIdx) uintptr {
return uintptr(ci)*pallocChunkBytes + arenaBaseOffset
}
// chunkPageIndex computes the index of the page that contains p,
// relative to the chunk which contains p.
func chunkPageIndex(p uintptr) uint {
return uint(p % pallocChunkBytes / pageSize)
}
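// A worked example tying the three helpers together (illustrative; assumes
// 4 MiB chunks, 8 KiB pages, and arenaBaseOffset == 0): for
// p = 10*pallocChunkBytes + 3*pageSize,
//
//	chunkIndex(p)     == 10       // 10 whole chunks precede p
//	chunkBase(10)     == 10 << 22 // 40 MiB
//	chunkPageIndex(p) == 3        // 3 pages into chunk 10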
// l1 returns the index into the first level of (*pageAlloc).chunks.
func (i chunkIdx) l1() uint {
if pallocChunksL1Bits == 0 {
// Let the compiler optimize this away if there's no
// L1 map.
return 0
} else {
return uint(i) >> pallocChunksL1Shift
}
}
// l2 returns the index into the second level of (*pageAlloc).chunks.
func (i chunkIdx) l2() uint {
if pallocChunksL1Bits == 0 {
return uint(i)
} else {
return uint(i) & (1<<pallocChunksL2Bits - 1)
}
}
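// An illustrative example of the l1/l2 split (values from the 48-bit
// configuration in the table in the chunks documentation below, where
// pallocChunksL1Bits == 13 and pallocChunksL2Bits == 13): a chunk index is
// simply cut into a high and a low 13-bit piece.
//
//	ci := chunkIdx(1<<13 | 5)
//	ci.l1() // == 1
//	ci.l2() // == 5
//
// On platforms where pallocChunksL1Bits == 0, the whole index is the l2
// index.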
// offAddrToLevelIndex converts an address in the offset address space
// to the index into summary[level] containing addr.
func offAddrToLevelIndex(level int, addr offAddr) int {
return int((addr.a - arenaBaseOffset) >> levelShift[level])
}
// levelIndexToOffAddr converts an index into summary[level] into
// the corresponding address in the offset address space.
func levelIndexToOffAddr(level, idx int) offAddr {
return offAddr{(uintptr(idx) << levelShift[level]) + arenaBaseOffset}
}
// addrsToSummaryRange converts base and limit pointers into a range
// of entries for the given summary level.
//
// The returned range is inclusive on the lower bound and exclusive on
// the upper bound.
func addrsToSummaryRange(level int, base, limit uintptr) (lo int, hi int) {
// This is slightly more nuanced than just a shift for the exclusive
// upper-bound. Note that the exclusive upper bound may be within a
// summary at this level, meaning if we just do the obvious computation
// hi will end up being an inclusive upper bound. Unfortunately, just
// adding 1 to that is too broad since we might be on the very edge
// of a summary's max page count boundary for this level
// (1 << levelLogPages[level]). So, make limit an inclusive upper bound
// then shift, then add 1, so we get an exclusive upper bound at the end.
lo = int((base - arenaBaseOffset) >> levelShift[level])
hi = int(((limit-1)-arenaBaseOffset)>>levelShift[level]) + 1
return
}
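// A worked example of the boundary handling above (illustrative; assumes a
// hypothetical leaf level "leaf" where each summary covers one 4 MiB chunk,
// i.e. a shift of 22, and arenaBaseOffset == 0):
//
//	lo, hi := addrsToSummaryRange(leaf, 0, 4<<20)         // lo = 0, hi = 1
//	lo, hi  = addrsToSummaryRange(leaf, 0, 4<<20 + 8<<10) // lo = 0, hi = 2
//
// Shifting limit directly would yield 1 in both cases (an inclusive bound in
// the second), while unconditionally adding 1 would yield 2 in both (too
// broad in the first); subtracting 1 before the shift handles both.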
// blockAlignSummaryRange aligns indices into the given level to that
// level's block width (1 << levelBits[level]). It assumes lo is inclusive
// and hi is exclusive, and so aligns them down and up respectively.
func blockAlignSummaryRange(level int, lo, hi int) (int, int) {
e := uintptr(1) << levelBits[level]
return int(alignDown(uintptr(lo), e)), int(alignUp(uintptr(hi), e))
}
type pageAlloc struct {
// Radix tree of summaries.
//
// Each slice's cap represents the whole memory reservation.
// Each slice's len reflects the allocator's maximum known
// mapped heap address for that level.
//
// The backing store of each summary level is reserved in init
// and may or may not be committed in grow (small address spaces
// may commit all the memory in init).
//
// The purpose of keeping len <= cap is to enforce bounds checks
// on the top end of the slice so that instead of an unknown
// runtime segmentation fault, we get a much friendlier out-of-bounds
// error.
//
// To iterate over a summary level, use inUse to determine which ranges
// are currently available. Otherwise one might try to access
// memory which is only Reserved which may result in a hard fault.
//
// We may still get segmentation faults < len since some of that
// memory may not be committed yet.
summary [summaryLevels][]pallocSum
// chunks is a slice of bitmap chunks.
//
// The total size of chunks is quite large on most 64-bit platforms
// (O(GiB) or more) if flattened, so rather than making one large mapping
// (which has problems on some platforms, even when PROT_NONE) we use a
// two-level sparse array approach similar to the arena index in mheap.
//
// To find the chunk containing a memory address `a`, do:
// chunkOf(chunkIndex(a))
//
// Below is a table describing the configuration for chunks for various
// heapAddrBits supported by the runtime.
//
// heapAddrBits | L1 Bits | L2 Bits | L2 Entry Size
// ------------------------------------------------
// 32 | 0 | 10 | 128 KiB
// 33 (iOS) | 0 | 11 | 256 KiB
// 48 | 13 | 13 | 1 MiB
//
// There's no reason to use the L1 part of chunks on 32-bit: the
// address space is small, so the L2 is small.
// 48-bit address space, we pick the L1 such that the L2 is 1 MiB
// in size, which is a good balance between low granularity without
// making the impact on BSS too high (note the L1 is stored directly
// in pageAlloc).
//
// To iterate over the bitmap, use inUse to determine which ranges
// are currently available. Otherwise one might iterate over unused
// ranges.
//
// Protected by mheapLock.
//
// TODO(mknyszek): Consider changing the definition of the bitmap
// such that 1 means free and 0 means in-use so that summaries and
// the bitmaps align better on zero-values.
chunks [1 << pallocChunksL1Bits]*[1 << pallocChunksL2Bits]pallocData
// The address to start an allocation search with. It must never
// point to any memory that is not contained in inUse, i.e.
// inUse.contains(searchAddr.addr()) must always be true. The one
// exception to this rule is that it may take on the value of
// maxOffAddr to indicate that the heap is exhausted.
//
// We guarantee that all valid heap addresses below this value
// are allocated and not worth searching.
searchAddr offAddr
// start and end represent the chunk indices
// which pageAlloc knows about. It assumes
// chunks in the range [start, end) are
// currently ready to use.
start, end chunkIdx
// inUse is a slice of ranges of address space which are
// known by the page allocator to be currently in-use (passed
// to grow).
//
// This field is currently unused on 32-bit architectures but
// is harmless to track. We care much more about having a
// contiguous heap in these cases and take additional measures
// to ensure that, so in nearly all cases this should have just
// 1 element.
//
// All access is protected by the mheapLock.
inUse addrRanges
// scav stores the scavenger state.
scav struct {
// index is an efficient index of chunks that have pages available to
// scavenge.
index scavengeIndex
// released is the amount of memory released this scavenge cycle.
//
// Updated atomically.
released uintptr
}
// mheapLock is a pointer to mheap_.lock. This level of indirection
// makes it possible to test pageAlloc independently of the runtime
// allocator.
mheapLock *mutex
// sysStat is the runtime memstat to update when new system
// memory is committed by the pageAlloc for allocation metadata.
sysStat *sysMemStat
// summaryMappedReady is the number of bytes mapped in the Ready state
// in the summary structure. Used only for testing currently.
//
// Protected by mheapLock.
summaryMappedReady uintptr
// Whether or not this struct is being used in tests.
test bool
}
func (p *pageAlloc) init(mheapLock *mutex, sysStat *sysMemStat) {
if levelLogPages[0] > logMaxPackedValue {
// We can't represent 1<<levelLogPages[0] pages, the maximum number
// of pages we need to represent at the root level, in a summary, which
// is a big problem. Throw.
print("runtime: root level max pages = ", 1<<levelLogPages[0], "\n")
print("runtime: summary max pages = ", maxPackedValue, "\n")
throw("root level max pages doesn't fit in summary")
}
p.sysStat = sysStat
// Initialize p.inUse.
p.inUse.init(sysStat)
// System-dependent initialization.
p.sysInit()
// Start with the searchAddr in a state indicating there's no free memory.
p.searchAddr = maxSearchAddr()
// Set the mheapLock.
p.mheapLock = mheapLock
}
// tryChunkOf returns the bitmap data for the given chunk.
//
// Returns nil if the chunk data has not been mapped.
func (p *pageAlloc) tryChunkOf(ci chunkIdx) *pallocData {
l2 := p.chunks[ci.l1()]
if l2 == nil {
return nil
}
return &l2[ci.l2()]
}
// chunkOf returns the chunk at the given chunk index.
//
// The chunk index must be valid or this method may throw.
func (p *pageAlloc) chunkOf(ci chunkIdx) *pallocData {
return &p.chunks[ci.l1()][ci.l2()]
}
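// An illustrative usage sketch of the two accessors above: tryChunkOf is for
// chunk indices that may not be mapped yet, chunkOf for indices already
// known to be valid.
//
//	if data := p.tryChunkOf(ci); data != nil {
//		// ci's bitmap is mapped; safe to inspect data.
//	}
//	data := p.chunkOf(ci) // only when ci is known to be mapped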
// grow sets up the metadata for the address range [base, base+size).
// It may allocate metadata, in which case *p.sysStat will be updated.
//
// p.mheapLock must be held.
func (p *pageAlloc) grow(base, size uintptr) {
assertLockHeld(p.mheapLock)
// Round up to chunks, since we can't deal with increments smaller
// than chunks. Also, sysGrow expects aligned values.
limit := alignUp(base+size, pallocChunkBytes)
base = alignDown(base, pallocChunkBytes)
// Grow the summary levels in a system-dependent manner.
// We just update a bunch of additional metadata here.
p.sysGrow(base, limit)
// Update p.start and p.end.
// If no growth happened yet, start == 0. This is generally
// safe since the zero page is unmapped.
firstGrowth := p.start == 0
start, end := chunkIndex(base), chunkIndex(limit)
if firstGrowth || start < p.start {
p.start = start
}
if end > p.end {
p.end = end
}
// Note that [base, limit) will never overlap with any existing
// range inUse because grow only ever adds never-used memory
// regions to the page allocator.
p.inUse.add(makeAddrRange(base, limit))
// A grow operation is a lot like a free operation, so if our
// chunk ends up below p.searchAddr, update p.searchAddr to the
// new address, just like in free.
if b := (offAddr{base}); b.lessThan(p.searchAddr) {
p.searchAddr = b
}
// Add entries into chunks, which is sparse, if needed. Then,
// initialize the bitmap.
//
// Newly-grown memory is always considered scavenged.
// Set all the bits in the scavenged bitmaps high.
for c := chunkIndex(base); c < chunkIndex(limit); c++ {
if p.chunks[c.l1()] == nil {
// Create the necessary l2 entry.
r := sysAlloc(unsafe.Sizeof(*p.chunks[0]), p.sysStat)
if r == nil {
throw("pageAlloc: out of memory")
}
// Store the new chunk block but avoid a write barrier.
// grow is used in call chains that disallow write barriers.
*(*uintptr)(unsafe.Pointer(&p.chunks[c.l1()])) = uintptr(r)
}
p.chunkOf(c).scavenged.setRange(0, pallocChunkPages)
}
// Update summaries accordingly. The grow acts like a free, so
// we need to ensure this newly-free memory is visible in the
// summaries.
p.update(base, size/pageSize, true, false)
}
// update updates heap metadata. It must be called each time the bitmap
// is updated.
//
// If contig is true, update does some optimizations assuming that there was
// a contiguous allocation or free between addr and addr+npages. alloc indicates
// whether the operation performed was an allocation or a free.
//
// p.mheapLock must be held.
func (p *pageAlloc) update(base, npages uintptr, contig, alloc bool) {
assertLockHeld(p.mheapLock)
// base, limit, start, and end are inclusive.
limit := base + npages*pageSize - 1
sc, ec := chunkIndex(base), chunkIndex(limit)
// Handle updating the lowest level first.
if sc == ec {
// Fast path: the allocation doesn't span more than one chunk,
// so update this one and if the summary didn't change, return.
x := p.summary[len(p.summary)-1][sc]
y := p.chunkOf(sc).summarize()
if x == y {
return
}
p.summary[len(p.summary)-1][sc] = y
} else if contig {
// Slow contiguous path: the allocation spans more than one chunk
// and at least one summary is guaranteed to change.
summary := p.summary[len(p.summary)-1]
// Update the summary for chunk sc.
summary[sc] = p.chunkOf(sc).summarize()
// Update the summaries for chunks in between, which are
// either totally allocated or freed.
whole := p.summary[len(p.summary)-1][sc+1 : ec]
if alloc {
// Should optimize into a memclr.
for i := range whole {
whole[i] = 0
}
} else {
for i := range whole {
whole[i] = freeChunkSum
}
}
// Update the summary for chunk ec.
summary[ec] = p.chunkOf(ec).summarize()
} else {
// Slow general path: the allocation spans more than one chunk
// and at least one summary is guaranteed to change.
//
// We can't assume a contiguous allocation happened, so walk over
// every chunk in the range and manually recompute the summary.
summary := p.summary[len(p.summary)-1]
for c := sc; c <= ec; c++ {
summary[c] = p.chunkOf(c).summarize()
}
}
// Walk up the radix tree and update the summaries appropriately.
changed := true
for l := len(p.summary) - 2; l >= 0 && changed; l-- {
// Update summaries at level l from summaries at level l+1.
changed = false
// "Constants" for level l+1, which we need in order
// to compute the summary at level l from that level.
logEntriesPerBlock := levelBits[l+1]
logMaxPages := levelLogPages[l+1]
// lo and hi describe all the parts of the level we need to look at.
lo, hi := addrsToSummaryRange(l, base, limit+1)
// Iterate over each block, updating the corresponding summary in the less-granular level.
for i := lo; i < hi; i++ {
children := p.summary[l+1][i<<logEntriesPerBlock : (i+1)<<logEntriesPerBlock]
sum := mergeSummaries(children, logMaxPages)
old := p.summary[l][i]
if old != sum {
changed = true
p.summary[l][i] = sum
}
}
}
}
// allocRange marks the range of memory [base, base+npages*pageSize) as
// allocated. It also updates the summaries to reflect the newly-updated
// bitmap.
//
// Returns the amount of scavenged memory in bytes present in the
// allocated range.
//
// p.mheapLock must be held.
func (p *pageAlloc) allocRange(base, npages uintptr) uintptr {
assertLockHeld(p.mheapLock)
limit := base + npages*pageSize - 1
sc, ec := chunkIndex(base), chunkIndex(limit)
si, ei := chunkPageIndex(base), chunkPageIndex(limit)
scav := uint(0)
if sc == ec {
// The range doesn't cross any chunk boundaries.
chunk := p.chunkOf(sc)
scav += chunk.scavenged.popcntRange(si, ei+1-si)
chunk.allocRange(si, ei+1-si)
} else {
// The range crosses at least one chunk boundary.
chunk := p.chunkOf(sc)
scav += chunk.scavenged.popcntRange(si, pallocChunkPages-si)
chunk.allocRange(si, pallocChunkPages-si)
for c := sc + 1; c < ec; c++ {
chunk := p.chunkOf(c)
scav += chunk.scavenged.popcntRange(0, pallocChunkPages)
chunk.allocAll()
}
chunk = p.chunkOf(ec)
scav += chunk.scavenged.popcntRange(0, ei+1)
chunk.allocRange(0, ei+1)
}
p.update(base, npages, true, true)
return uintptr(scav) * pageSize
}
// findMappedAddr returns the smallest mapped offAddr that is
// >= addr. That is, if addr refers to mapped memory, then it is
// returned. If addr is higher than any mapped region, then
// it returns maxOffAddr.
//
// p.mheapLock must be held.
func (p *pageAlloc) findMappedAddr(addr offAddr) offAddr {
assertLockHeld(p.mheapLock)
// If we're not in a test, validate first by checking mheap_.arenas.
// This is a fast path which is only safe to use outside of testing.
ai := arenaIndex(addr.addr())
if p.test || mheap_.arenas[ai.l1()] == nil || mheap_.arenas[ai.l1()][ai.l2()] == nil {
vAddr, ok := p.inUse.findAddrGreaterEqual(addr.addr())
if ok {
return offAddr{vAddr}
} else {
// The candidate search address is greater than any
// known address, which means we definitely have no
// free memory left.
return maxOffAddr
}
}
return addr
}
// find searches for the first (address-ordered) contiguous free region of
// npages in size and returns a base address for that region.
//
// It uses p.searchAddr to prune its search and assumes that no palloc chunks
// below chunkIndex(p.searchAddr) contain any free memory at all.
//
// find also computes and returns a candidate p.searchAddr, which may or
// may not prune more of the address space than p.searchAddr already does.
// This candidate is always a valid p.searchAddr.
//
// find represents the slow path and the full radix tree search.
//
// Returns a base address of 0 on failure, in which case the candidate
// searchAddr returned is invalid and must be ignored.
//
// p.mheapLock must be held.
func (p *pageAlloc) find(npages uintptr) (uintptr, offAddr) {
assertLockHeld(p.mheapLock)
// Search algorithm.
//
// This algorithm walks each level l of the radix tree from the root level
// to the leaf level. It iterates over at most 1 << levelBits[l] of entries
// in a given level in the radix tree, and uses the summary information to
// find either:
// 1) That a given subtree contains a large enough contiguous region, at
// which point it continues iterating on the next level, or
// 2) That there are enough contiguous boundary-crossing bits to satisfy
// the allocation, at which point it knows exactly where to start
// allocating from.
//
// i tracks the index into the current level l's structure for the
// contiguous 1 << levelBits[l] entries we're actually interested in.
//
// NOTE: Technically this search could allocate a region which crosses
// the arenaBaseOffset boundary, which when arenaBaseOffset != 0, is
// a discontinuity. However, the only way this could happen is if the
// page at the zero address is mapped, and this is impossible on
// every system we support where arenaBaseOffset != 0. So, the
// discontinuity is already encoded in the fact that the OS will never
// map the zero page for us, and this function doesn't try to handle
// this case in any way.
// i is the beginning of the block of entries we're searching at the
// current level.
i := 0
// firstFree is the region of address space in which we are certain to
// find the first free page in the heap. base and bound are the inclusive
// bounds of this window, and both are addresses in the linearized, contiguous
// view of the address space (with arenaBaseOffset pre-added). At each level,
// this window is narrowed as we find the memory region containing the
// first free page of memory. To begin with, the range reflects the
// full process address space.
//
// firstFree is updated by calling foundFree each time free space in the
// heap is discovered.
//
// At the end of the search, base.addr() is the best new
// searchAddr we could deduce in this search.
firstFree := struct {
base, bound offAddr
}{
base: minOffAddr,
bound: maxOffAddr,
}
// foundFree takes the given address range [addr, addr+size) and
// updates firstFree if it is a narrower range. The input range must
// either be fully contained within firstFree or not overlap with it
// at all.
//
// This way, we'll record the first summary we find with any free
// pages on the root level and narrow that down if we descend into
// that summary. But as soon as we need to iterate beyond that summary
// in a level to find a large enough range, we'll stop narrowing.
foundFree := func(addr offAddr, size uintptr) {
if firstFree.base.lessEqual(addr) && addr.add(size-1).lessEqual(firstFree.bound) {
// This range fits within the current firstFree window, so narrow
// down the firstFree window to the base and bound of this range.
firstFree.base = addr
firstFree.bound = addr.add(size - 1)
} else if !(addr.add(size-1).lessThan(firstFree.base) || firstFree.bound.lessThan(addr)) {
// This range only partially overlaps with the firstFree range,
// so throw.
print("runtime: addr = ", hex(addr.addr()), ", size = ", size, "\n")
print("runtime: base = ", hex(firstFree.base.addr()), ", bound = ", hex(firstFree.bound.addr()), "\n")
throw("range partially overlaps")
}
}
// lastSum is the summary which we saw on the previous level that made us
// move on to the next level. Used to print additional information in the
// case of a catastrophic failure.
// lastSumIdx is that summary's index in the previous level.
lastSum := packPallocSum(0, 0, 0)
lastSumIdx := -1
nextLevel:
for l := 0; l < len(p.summary); l++ {
// For the root level, entriesPerBlock is the whole level.
entriesPerBlock := 1 << levelBits[l]
logMaxPages := levelLogPages[l]
// We've moved into a new level, so let's update i to our new
// starting index. This is a no-op for level 0.
i <<= levelBits[l]
// Slice out the block of entries we care about.
entries := p.summary[l][i : i+entriesPerBlock]
// Determine j0, the first index we should start iterating from.
// The searchAddr may help us eliminate iterations if we followed the
// searchAddr on the previous level or we're on the root level, in which
// case the searchAddr should be the same as i after levelShift.
j0 := 0
if searchIdx := offAddrToLevelIndex(l, p.searchAddr); searchIdx&^(entriesPerBlock-1) == i {
j0 = searchIdx & (entriesPerBlock - 1)
}
// Run over the level entries looking for
// a contiguous run of at least npages either
// within an entry or across entries.
//
// base contains the page index (relative to
// the first entry's first page) of the currently
// considered run of consecutive pages.
//
// size contains the size of the currently considered
// run of consecutive pages.
var base, size uint
for j := j0; j < len(entries); j++ {
sum := entries[j]
if sum == 0 {
// A full entry means we broke any streak and
// that we should skip it altogether.
size = 0
continue
}
// We've encountered a non-zero summary which means
// free memory, so update firstFree.
foundFree(levelIndexToOffAddr(l, i+j), (uintptr(1)<<logMaxPages)*pageSize)
s := sum.start()
if size+s >= uint(npages) {
// If size == 0 we don't have a run yet,
// which means base isn't valid. So, set
// base to the first page in this block.
if size == 0 {
base = uint(j) << logMaxPages
}
// We hit npages; we're done!
size += s
break
}
if sum.max() >= uint(npages) {
// The entry itself contains npages contiguous
// free pages, so continue on the next level
// to find that run.
i += j
lastSumIdx = i
lastSum = sum
continue nextLevel
}
if size == 0 || s < 1<<logMaxPages {
// We either don't have a current run started, or this entry
// isn't totally free (meaning we can't continue the current
// one), so try to begin a new run by setting size and base
// based on sum.end.
size = sum.end()
base = uint(j+1)<<logMaxPages - size
continue
}
// The entry is completely free, so continue the run.
size += 1 << logMaxPages
}
if size >= uint(npages) {
// We found a sufficiently large run of free pages straddling
// some boundary, so compute the address and return it.
addr := levelIndexToOffAddr(l, i).add(uintptr(base) * pageSize).addr()
return addr, p.findMappedAddr(firstFree.base)
}
if l == 0 {
// We're at level zero, so that means we've exhausted our search.
return 0, maxSearchAddr()
}
// We're not at level zero, and we exhausted the level we were looking in.
// This means that either our calculations were wrong or the level above
// lied to us. In either case, dump some useful state and throw.
print("runtime: summary[", l-1, "][", lastSumIdx, "] = ", lastSum.start(), ", ", lastSum.max(), ", ", lastSum.end(), "\n")
print("runtime: level = ", l, ", npages = ", npages, ", j0 = ", j0, "\n")
print("runtime: p.searchAddr = ", hex(p.searchAddr.addr()), ", i = ", i, "\n")
print("runtime: levelShift[level] = ", levelShift[l], ", levelBits[level] = ", levelBits[l], "\n")
for j := 0; j < len(entries); j++ {
sum := entries[j]
print("runtime: summary[", l, "][", i+j, "] = (", sum.start(), ", ", sum.max(), ", ", sum.end(), ")\n")
}
throw("bad summary data")
}
// Since we've gotten to this point, that means we haven't found a
// sufficiently-sized free region straddling some boundary (chunk or larger).
// This means the last summary we inspected must have had a large enough "max"
// value, so look inside the chunk to find a suitable run.
//
// After iterating over all levels, i must contain a chunk index which
// is what the final level represents.
ci := chunkIdx(i)
j, searchIdx := p.chunkOf(ci).find(npages, 0)
if j == ^uint(0) {
// We couldn't find any space in this chunk despite the summaries telling
// us it should be there. There's likely a bug, so dump some state and throw.
sum := p.summary[len(p.summary)-1][i]
print("runtime: summary[", len(p.summary)-1, "][", i, "] = (", sum.start(), ", ", sum.max(), ", ", sum.end(), ")\n")
print("runtime: npages = ", npages, "\n")
throw("bad summary data")
}
// Compute the address at which the free space starts.
addr := chunkBase(ci) + uintptr(j)*pageSize
// Since we actually searched the chunk, we may have
// found an even narrower free window.
searchAddr := chunkBase(ci) + uintptr(searchIdx)*pageSize
foundFree(offAddr{searchAddr}, chunkBase(ci+1)-searchAddr)
return addr, p.findMappedAddr(firstFree.base)
}
// alloc allocates npages worth of memory from the page heap, returning the base
// address for the allocation and the amount of scavenged memory in bytes
// contained in the region [base address, base address + npages*pageSize).
//
// Returns a 0 base address on failure, in which case other returned values
// should be ignored.
//
// p.mheapLock must be held.
//
// Must run on the system stack because p.mheapLock must be held.
//
//go:systemstack
func (p *pageAlloc) alloc(npages uintptr) (addr uintptr, scav uintptr) {
assertLockHeld(p.mheapLock)
// If the searchAddr refers to a region which has a higher address than
// any known chunk, then we know we're out of memory.
if chunkIndex(p.searchAddr.addr()) >= p.end {
return 0, 0
}
// If npages has a chance of fitting in the chunk where the searchAddr is,
// search it directly.
searchAddr := minOffAddr
if pallocChunkPages-chunkPageIndex(p.searchAddr.addr()) >= uint(npages) {
// npages is guaranteed to be no greater than pallocChunkPages here.
i := chunkIndex(p.searchAddr.addr())
if max := p.summary[len(p.summary)-1][i].max(); max >= uint(npages) {
j, searchIdx := p.chunkOf(i).find(npages, chunkPageIndex(p.searchAddr.addr()))
if j == ^uint(0) {
print("runtime: max = ", max, ", npages = ", npages, "\n")
print("runtime: searchIdx = ", chunkPageIndex(p.searchAddr.addr()), ", p.searchAddr = ", hex(p.searchAddr.addr()), "\n")
throw("bad summary data")
}
addr = chunkBase(i) + uintptr(j)*pageSize
searchAddr = offAddr{chunkBase(i) + uintptr(searchIdx)*pageSize}
goto Found
}
}
// We failed to use a searchAddr for one reason or another, so try
// the slow path.
addr, searchAddr = p.find(npages)
if addr == 0 {
if npages == 1 {
// We failed to find a single free page, the smallest unit
// of allocation. This means we know the heap is completely
// exhausted. Otherwise, the heap still might have free
// space in it, just not enough contiguous space to
// accommodate npages.
p.searchAddr = maxSearchAddr()
}
return 0, 0
}
Found:
// Go ahead and actually mark the bits now that we have an address.
scav = p.allocRange(addr, npages)
// If we found a higher searchAddr, we know that all the
// heap memory before that searchAddr in an offset address space is
// allocated, so bump p.searchAddr up to the new one.
if p.searchAddr.lessThan(searchAddr) {
p.searchAddr = searchAddr
}
return addr, scav
}
// free returns npages worth of memory starting at base back to the page heap.
//
// p.mheapLock must be held.
//
// Must run on the system stack because p.mheapLock must be held.
//
//go:systemstack
func (p *pageAlloc) free(base, npages uintptr, scavenged bool) {
assertLockHeld(p.mheapLock)
// If we're freeing pages below the p.searchAddr, update searchAddr.
if b := (offAddr{base}); b.lessThan(p.searchAddr) {
p.searchAddr = b
}
limit := base + npages*pageSize - 1
if !scavenged {
p.scav.index.mark(base, limit+1)
}
if npages == 1 {
// Fast path: we're clearing a single bit, and we know exactly
// where it is, so mark it directly.
i := chunkIndex(base)
p.chunkOf(i).free1(chunkPageIndex(base))
} else {
// Slow path: we're clearing more bits so we may need to iterate.
sc, ec := chunkIndex(base), chunkIndex(limit)
si, ei := chunkPageIndex(base), chunkPageIndex(limit)
if sc == ec {
// The range doesn't cross any chunk boundaries.
p.chunkOf(sc).free(si, ei+1-si)
} else {
// The range crosses at least one chunk boundary.
p.chunkOf(sc).free(si, pallocChunkPages-si)
for c := sc + 1; c < ec; c++ {
p.chunkOf(c).freeAll()
}
p.chunkOf(ec).free(0, ei+1)
}
}
p.update(base, npages, true, false)
}
const (
pallocSumBytes = unsafe.Sizeof(pallocSum(0))
// maxPackedValue is the maximum value that any of the three fields in
// the pallocSum may take on.
maxPackedValue = 1 << logMaxPackedValue
logMaxPackedValue = logPallocChunkPages + (summaryLevels-1)*summaryLevelBits
freeChunkSum = pallocSum(uint64(pallocChunkPages) |
uint64(pallocChunkPages<<logMaxPackedValue) |
uint64(pallocChunkPages<<(2*logMaxPackedValue)))
)
// pallocSum is a packed summary type which packs three numbers: start, max,
// and end into a single 8-byte value. Each of these values is a summary of
// a bitmap and is thus a count, each of which may have a maximum value of
// 2^21 - 1, or all three may be equal to 2^21. The latter case is represented
// by just setting the 64th bit.
type pallocSum uint64
// packPallocSum takes a start, max, and end value and produces a pallocSum.
func packPallocSum(start, max, end uint) pallocSum {
if max == maxPackedValue {
return pallocSum(uint64(1 << 63))
}
return pallocSum((uint64(start) & (maxPackedValue - 1)) |
((uint64(max) & (maxPackedValue - 1)) << logMaxPackedValue) |
((uint64(end) & (maxPackedValue - 1)) << (2 * logMaxPackedValue)))
}
// start extracts the start value from a packed sum.
func (p pallocSum) start() uint {
if uint64(p)&uint64(1<<63) != 0 {
return maxPackedValue
}
return uint(uint64(p) & (maxPackedValue - 1))
}
// max extracts the max value from a packed sum.
func (p pallocSum) max() uint {
if uint64(p)&uint64(1<<63) != 0 {
return maxPackedValue
}
return uint((uint64(p) >> logMaxPackedValue) & (maxPackedValue - 1))
}
// end extracts the end value from a packed sum.
func (p pallocSum) end() uint {
if uint64(p)&uint64(1<<63) != 0 {
return maxPackedValue
}
return uint((uint64(p) >> (2 * logMaxPackedValue)) & (maxPackedValue - 1))
}
// unpack unpacks all three values from the summary.
func (p pallocSum) unpack() (uint, uint, uint) {
if uint64(p)&uint64(1<<63) != 0 {
return maxPackedValue, maxPackedValue, maxPackedValue
}
return uint(uint64(p) & (maxPackedValue - 1)),
uint((uint64(p) >> logMaxPackedValue) & (maxPackedValue - 1)),
uint((uint64(p) >> (2 * logMaxPackedValue)) & (maxPackedValue - 1))
}
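// For illustration, packing three counts and unpacking them round-trips
// the values so long as each stays below maxPackedValue (the inputs here
// are arbitrary):
//
//	sum := packPallocSum(5, 10, 7)
//	start, max, end := sum.unpack() // start == 5, max == 10, end == 7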
// mergeSummaries merges consecutive summaries, each of which may represent
// at most 1 << logMaxPagesPerSum pages, into one.
func mergeSummaries(sums []pallocSum, logMaxPagesPerSum uint) pallocSum {
// Merge the summaries in sums into one.
//
// We do this by keeping a running summary representing the merged
// summaries of sums[:i] in start, max, and end.
start, max, end := sums[0].unpack()
for i := 1; i < len(sums); i++ {
// Merge in sums[i].
si, mi, ei := sums[i].unpack()
// Merge in sums[i].start only if the running summary is
// completely free, otherwise this summary's start
// plays no role in the combined sum.
if start == uint(i)<<logMaxPagesPerSum {
start += si
}
// Recompute the max value of the running sum by looking
// across the boundary between the running sum and sums[i]
// and at the max sums[i], taking the greatest of those two
// and the max of the running sum.
if end+si > max {
max = end + si
}
if mi > max {
max = mi
}
// Merge in end by checking if this new summary is totally
// free. If it is, then we want to extend the running sum's
// end by the new summary. If not, then we have some alloc'd
// pages in there and we just want to take the end value in
// sums[i].
if ei == 1<<logMaxPagesPerSum {
end += 1 << logMaxPagesPerSum
} else {
end = ei
}
}
return packPallocSum(start, max, end)
}
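// As a worked sketch (values arbitrary): with logMaxPagesPerSum == 6, each
// summary covers 64 pages. Merging a fully-free summary with one whose
// first 10 pages are free extends the leading run across the boundary:
//
//	a := packPallocSum(64, 64, 64) // completely free
//	b := packPallocSum(10, 12, 4)  // 10 free at the start, 4 at the end
//	m := mergeSummaries([]pallocSum{a, b}, 6)
//	// m.start() == 74, m.max() == 74, m.end() == 4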
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build amd64 || arm64 || loong64 || mips64 || mips64le || ppc64 || ppc64le || riscv64 || s390x
package runtime
import (
"runtime/internal/atomic"
"unsafe"
)
const (
// The number of levels in the radix tree.
summaryLevels = 5
// Constants for testing.
pageAlloc32Bit = 0
pageAlloc64Bit = 1
// Number of bits needed to represent all indices into the L1 of the
// chunks map.
//
// See (*pageAlloc).chunks for more details. Update the documentation
// there should this number change.
pallocChunksL1Bits = 13
)
// levelBits is the number of bits in the radix for a given level in the super summary
// structure.
//
// The sum of all the entries of levelBits should equal heapAddrBits.
var levelBits = [summaryLevels]uint{
summaryL0Bits,
summaryLevelBits,
summaryLevelBits,
summaryLevelBits,
summaryLevelBits,
}
// levelShift is the number of bits to shift to acquire the radix for a given level
// in the super summary structure.
//
// With levelShift, one can compute the index of the summary at level l related to a
// pointer p by doing:
//
// p >> levelShift[l]
var levelShift = [summaryLevels]uint{
heapAddrBits - summaryL0Bits,
heapAddrBits - summaryL0Bits - 1*summaryLevelBits,
heapAddrBits - summaryL0Bits - 2*summaryLevelBits,
heapAddrBits - summaryL0Bits - 3*summaryLevelBits,
heapAddrBits - summaryL0Bits - 4*summaryLevelBits,
}
// levelLogPages is log2 the maximum number of runtime pages in the address space
// a summary in the given level represents.
//
// The leaf level always represents exactly log2 of 1 chunk's worth of pages.
var levelLogPages = [summaryLevels]uint{
logPallocChunkPages + 4*summaryLevelBits,
logPallocChunkPages + 3*summaryLevelBits,
logPallocChunkPages + 2*summaryLevelBits,
logPallocChunkPages + 1*summaryLevelBits,
logPallocChunkPages,
}
// sysInit performs architecture-dependent initialization of fields
// in pageAlloc. pageAlloc should be uninitialized except for sysStat,
// which should be set if any runtime statistic should be updated.
func (p *pageAlloc) sysInit() {
// Reserve memory for each level. This will get mapped in
// as R/W by setArenas.
for l, shift := range levelShift {
entries := 1 << (heapAddrBits - shift)
// Reserve b bytes of memory anywhere in the address space.
b := alignUp(uintptr(entries)*pallocSumBytes, physPageSize)
r := sysReserve(nil, b)
if r == nil {
throw("failed to reserve page summary memory")
}
// Put this reservation into a slice.
sl := notInHeapSlice{(*notInHeap)(r), 0, entries}
p.summary[l] = *(*[]pallocSum)(unsafe.Pointer(&sl))
}
// Set up the scavenge index.
nbytes := uintptr(1<<heapAddrBits) / pallocChunkBytes / 8
r := sysReserve(nil, nbytes)
sl := notInHeapSlice{(*notInHeap)(r), int(nbytes), int(nbytes)}
p.scav.index.chunks = *(*[]atomic.Uint8)(unsafe.Pointer(&sl))
}
// sysGrow performs architecture-dependent operations on heap
// growth for the page allocator, such as mapping in new memory
// for summaries. It also updates the length of the slices in
// p.summary.
//
// base is the base of the newly-added heap memory and limit is
// the first address past the end of the newly-added heap memory.
// Both must be aligned to pallocChunkBytes.
//
// The caller must update p.start and p.end after calling sysGrow.
func (p *pageAlloc) sysGrow(base, limit uintptr) {
if base%pallocChunkBytes != 0 || limit%pallocChunkBytes != 0 {
print("runtime: base = ", hex(base), ", limit = ", hex(limit), "\n")
throw("sysGrow bounds not aligned to pallocChunkBytes")
}
// addrRangeToSummaryRange converts a range of addresses into a range
// of summary indices which must be mapped to support those addresses
// in the summary range.
addrRangeToSummaryRange := func(level int, r addrRange) (int, int) {
sumIdxBase, sumIdxLimit := addrsToSummaryRange(level, r.base.addr(), r.limit.addr())
return blockAlignSummaryRange(level, sumIdxBase, sumIdxLimit)
}
// summaryRangeToSumAddrRange converts a range of indices in any
// level of p.summary into page-aligned addresses which cover that
// range of indices.
summaryRangeToSumAddrRange := func(level, sumIdxBase, sumIdxLimit int) addrRange {
baseOffset := alignDown(uintptr(sumIdxBase)*pallocSumBytes, physPageSize)
limitOffset := alignUp(uintptr(sumIdxLimit)*pallocSumBytes, physPageSize)
base := unsafe.Pointer(&p.summary[level][0])
return addrRange{
offAddr{uintptr(add(base, baseOffset))},
offAddr{uintptr(add(base, limitOffset))},
}
}
// addrRangeToSumAddrRange is a convenience function that converts
// an address range r to the address range of the given summary level
// that stores the summaries for r.
addrRangeToSumAddrRange := func(level int, r addrRange) addrRange {
sumIdxBase, sumIdxLimit := addrRangeToSummaryRange(level, r)
return summaryRangeToSumAddrRange(level, sumIdxBase, sumIdxLimit)
}
// Find the first inUse index which is strictly greater than base.
//
// Because this function will never be asked to remap the same memory
// twice, this index is effectively the index at which we would insert
// this new growth, and base will never overlap/be contained within
// any existing range.
//
// This will be used to look at what memory in the summary array is already
// mapped before and after this new range.
inUseIndex := p.inUse.findSucc(base)
// Walk up the radix tree and map summaries in as needed.
for l := range p.summary {
// Figure out what part of the summary array this new address space needs.
needIdxBase, needIdxLimit := addrRangeToSummaryRange(l, makeAddrRange(base, limit))
// Update the summary slices with a new upper-bound. This ensures
// we get tight bounds checks on at least the top bound.
//
// We must do this regardless of whether we map new memory.
if needIdxLimit > len(p.summary[l]) {
p.summary[l] = p.summary[l][:needIdxLimit]
}
// Compute the needed address range in the summary array for level l.
need := summaryRangeToSumAddrRange(l, needIdxBase, needIdxLimit)
// Prune need down to what needs to be newly mapped. Some parts of it may
// already be mapped by what inUse describes due to page alignment requirements
// for mapping. prune's invariants are guaranteed by the fact that this
// function will never be asked to remap the same memory twice.
if inUseIndex > 0 {
need = need.subtract(addrRangeToSumAddrRange(l, p.inUse.ranges[inUseIndex-1]))
}
if inUseIndex < len(p.inUse.ranges) {
need = need.subtract(addrRangeToSumAddrRange(l, p.inUse.ranges[inUseIndex]))
}
// It's possible that after our pruning above, there's nothing new to map.
if need.size() == 0 {
continue
}
// Map and commit need.
sysMap(unsafe.Pointer(need.base.addr()), need.size(), p.sysStat)
sysUsed(unsafe.Pointer(need.base.addr()), need.size(), need.size())
p.summaryMappedReady += need.size()
}
// Update the scavenge index.
p.summaryMappedReady += p.scav.index.grow(base, limit, p.sysStat)
}
// grow increases the index's backing store in response to a heap growth.
//
// Returns the amount of memory added to sysStat.
func (s *scavengeIndex) grow(base, limit uintptr, sysStat *sysMemStat) uintptr {
if base%pallocChunkBytes != 0 || limit%pallocChunkBytes != 0 {
print("runtime: base = ", hex(base), ", limit = ", hex(limit), "\n")
throw("sysGrow bounds not aligned to pallocChunkBytes")
}
// Map and commit the pieces of chunks that we need.
//
// We always map the full range of the minimum heap address to the
// maximum heap address. We don't do this for the summary structure
// because it's quite large and a discontiguous heap could cause a
// lot of memory to be used. In this situation, the worst case overhead
// is in the single-digit MiB if we map the whole thing.
//
// The base address of the backing store is always page-aligned,
// because it comes from the OS, so it's sufficient to align the
// index.
haveMin := s.min.Load()
haveMax := s.max.Load()
needMin := int32(alignDown(uintptr(chunkIndex(base)/8), physPageSize))
needMax := int32(alignUp(uintptr((chunkIndex(limit)+7)/8), physPageSize))
// Extend the range down to what we have, if there's no overlap.
if needMax < haveMin {
needMax = haveMin
}
if needMin > haveMax {
needMin = haveMax
}
have := makeAddrRange(
// Avoid a panic from indexing one past the last element.
uintptr(unsafe.Pointer(&s.chunks[0]))+uintptr(haveMin),
uintptr(unsafe.Pointer(&s.chunks[0]))+uintptr(haveMax),
)
need := makeAddrRange(
// Avoid a panic from indexing one past the last element.
uintptr(unsafe.Pointer(&s.chunks[0]))+uintptr(needMin),
uintptr(unsafe.Pointer(&s.chunks[0]))+uintptr(needMax),
)
// Subtract any overlap from rounding. We can't re-map memory because
// it'll be zeroed.
need = need.subtract(have)
// If we've got something to map, map it, and update the slice bounds.
if need.size() != 0 {
sysMap(unsafe.Pointer(need.base.addr()), need.size(), sysStat)
sysUsed(unsafe.Pointer(need.base.addr()), need.size(), need.size())
// Update the indices only after the new memory is valid.
if haveMin == 0 || needMin < haveMin {
s.min.Store(needMin)
}
if haveMax == 0 || needMax > haveMax {
s.max.Store(needMax)
}
}
// Update minHeapIdx. Note that even if there's no mapping work to do,
// we may still have a new, lower minimum heap address.
minHeapIdx := s.minHeapIdx.Load()
if baseIdx := int32(chunkIndex(base) / 8); minHeapIdx == 0 || baseIdx < minHeapIdx {
s.minHeapIdx.Store(baseIdx)
}
return need.size()
}
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"runtime/internal/sys"
"unsafe"
)
const pageCachePages = 8 * unsafe.Sizeof(pageCache{}.cache)
// pageCache represents a per-p cache of pages the allocator can
// allocate from without a lock. More specifically, it represents
// a pageCachePages*pageSize chunk of memory with 0 or more free
// pages in it.
type pageCache struct {
base uintptr // base address of the chunk
cache uint64 // 64-bit bitmap representing free pages (1 means free)
scav uint64 // 64-bit bitmap representing scavenged pages (1 means scavenged)
}
// empty reports whether the page cache has no free pages.
func (c *pageCache) empty() bool {
return c.cache == 0
}
// alloc allocates npages from the page cache and is the main entry
// point for allocation.
//
// Returns a base address and the amount of scavenged memory in the
// allocated region in bytes.
//
// Returns a base address of zero on failure, in which case the
// amount of scavenged memory should be ignored.
func (c *pageCache) alloc(npages uintptr) (uintptr, uintptr) {
if c.cache == 0 {
return 0, 0
}
if npages == 1 {
i := uintptr(sys.TrailingZeros64(c.cache))
scav := (c.scav >> i) & 1
c.cache &^= 1 << i // set bit to mark in-use
c.scav &^= 1 << i // clear bit to mark unscavenged
return c.base + i*pageSize, uintptr(scav) * pageSize
}
return c.allocN(npages)
}
// allocN is a helper which attempts to allocate npages worth of pages
// from the cache. It represents the general case for allocating from
// the page cache.
//
// Returns a base address and the amount of scavenged memory in the
// allocated region in bytes.
func (c *pageCache) allocN(npages uintptr) (uintptr, uintptr) {
i := findBitRange64(c.cache, uint(npages))
if i >= 64 {
return 0, 0
}
mask := ((uint64(1) << npages) - 1) << i
scav := sys.OnesCount64(c.scav & mask)
c.cache &^= mask // mark in-use bits
c.scav &^= mask // clear scavenged bits
return c.base + uintptr(i*pageSize), uintptr(scav) * pageSize
}
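// To make the mask arithmetic concrete (values arbitrary):
//
//	// npages == 3, run found at bit i == 5:
//	// mask == ((uint64(1)<<3)-1)<<5 == 0b1110_0000
//	// c.cache &^= mask // pages 5, 6, and 7 become in-use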
// flush empties out unallocated free pages in the given cache
// into s. Then, it clears the cache, such that empty returns
// true.
//
// p.mheapLock must be held.
//
// Must run on the system stack because p.mheapLock must be held.
//
//go:systemstack
func (c *pageCache) flush(p *pageAlloc) {
assertLockHeld(p.mheapLock)
if c.empty() {
return
}
ci := chunkIndex(c.base)
pi := chunkPageIndex(c.base)
// This method is called very infrequently, so just do the
// slower, safer thing by iterating over each bit individually.
for i := uint(0); i < 64; i++ {
if c.cache&(1<<i) != 0 {
p.chunkOf(ci).free1(pi + i)
}
if c.scav&(1<<i) != 0 {
p.chunkOf(ci).scavenged.setRange(pi+i, 1)
}
}
// Since this is a lot like a free, we need to make sure
// we update the searchAddr just like free does.
if b := (offAddr{c.base}); b.lessThan(p.searchAddr) {
p.searchAddr = b
}
p.update(c.base, pageCachePages, false, false)
*c = pageCache{}
}
// allocToCache acquires a pageCachePages-aligned chunk of free pages which
// may not be contiguous, and returns a pageCache structure which owns the
// chunk.
//
// p.mheapLock must be held.
//
// Must run on the system stack because p.mheapLock must be held.
//
//go:systemstack
func (p *pageAlloc) allocToCache() pageCache {
assertLockHeld(p.mheapLock)
// If the searchAddr refers to a region which has a higher address than
// any known chunk, then we know we're out of memory.
if chunkIndex(p.searchAddr.addr()) >= p.end {
return pageCache{}
}
c := pageCache{}
ci := chunkIndex(p.searchAddr.addr()) // chunk index
var chunk *pallocData
if p.summary[len(p.summary)-1][ci] != 0 {
// Fast path: there are free pages at or near the searchAddr address.
chunk = p.chunkOf(ci)
j, _ := chunk.find(1, chunkPageIndex(p.searchAddr.addr()))
if j == ^uint(0) {
throw("bad summary data")
}
c = pageCache{
base: chunkBase(ci) + alignDown(uintptr(j), 64)*pageSize,
cache: ^chunk.pages64(j),
scav: chunk.scavenged.block64(j),
}
} else {
// Slow path: the searchAddr address had nothing there, so go find
// the first free page the slow way.
addr, _ := p.find(1)
if addr == 0 {
// We failed to find adequate free space, so mark the searchAddr as OoM
// and return an empty pageCache.
p.searchAddr = maxSearchAddr()
return pageCache{}
}
ci := chunkIndex(addr)
chunk = p.chunkOf(ci)
c = pageCache{
base: alignDown(addr, 64*pageSize),
cache: ^chunk.pages64(chunkPageIndex(addr)),
scav: chunk.scavenged.block64(chunkPageIndex(addr)),
}
}
// Set the page bits as allocated and clear the scavenged bits, but
// be careful to only set and clear the relevant bits.
cpi := chunkPageIndex(c.base)
chunk.allocPages64(cpi, c.cache)
chunk.scavenged.clearBlock64(cpi, c.cache&c.scav /* free and scavenged */)
// Update as an allocation, but note that it's not contiguous.
p.update(c.base, pageCachePages, false, true)
// Set the search address to the last page represented by the cache.
// Since all of the pages in this block are going to the cache, and we
// searched for the first free page, we can confidently start at the
// next page.
//
// However, p.searchAddr is not allowed to point into unmapped heap memory
// unless it is maxSearchAddr, so make it the last page as opposed to
// the page after.
p.searchAddr = offAddr{c.base + pageSize*(pageCachePages-1)}
return c
}
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"runtime/internal/sys"
)
// pageBits is a bitmap representing one bit per page in a palloc chunk.
type pageBits [pallocChunkPages / 64]uint64
// get returns the value of the i'th bit in the bitmap.
func (b *pageBits) get(i uint) uint {
return uint((b[i/64] >> (i % 64)) & 1)
}
// block64 returns the 64-bit aligned block of bits containing the i'th bit.
func (b *pageBits) block64(i uint) uint64 {
return b[i/64]
}
// set sets bit i of pageBits.
func (b *pageBits) set(i uint) {
b[i/64] |= 1 << (i % 64)
}
// setRange sets bits in the range [i, i+n).
func (b *pageBits) setRange(i, n uint) {
_ = b[i/64]
if n == 1 {
// Fast path for the n == 1 case.
b.set(i)
return
}
// Set bits [i, j].
j := i + n - 1
if i/64 == j/64 {
b[i/64] |= ((uint64(1) << n) - 1) << (i % 64)
return
}
_ = b[j/64]
// Set leading bits.
b[i/64] |= ^uint64(0) << (i % 64)
for k := i/64 + 1; k < j/64; k++ {
b[k] = ^uint64(0)
}
// Set trailing bits.
b[j/64] |= (uint64(1) << (j%64 + 1)) - 1
}
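// A worked example of the single-word case (values arbitrary): setRange(3, 4)
// stays within b[0] and computes the mask
//
//	((uint64(1) << 4) - 1) << 3 == 0b0111_1000
//
// setting bits 3 through 6, i.e. pages 3, 4, 5, and 6.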
// setAll sets all the bits of b.
func (b *pageBits) setAll() {
for i := range b {
b[i] = ^uint64(0)
}
}
// setBlock64 sets the 64-bit aligned block of bits containing the i'th bit that
// are set in v.
func (b *pageBits) setBlock64(i uint, v uint64) {
b[i/64] |= v
}
// clear clears bit i of pageBits.
func (b *pageBits) clear(i uint) {
b[i/64] &^= 1 << (i % 64)
}
// clearRange clears bits in the range [i, i+n).
func (b *pageBits) clearRange(i, n uint) {
_ = b[i/64]
if n == 1 {
// Fast path for the n == 1 case.
b.clear(i)
return
}
// Clear bits [i, j].
j := i + n - 1
if i/64 == j/64 {
b[i/64] &^= ((uint64(1) << n) - 1) << (i % 64)
return
}
_ = b[j/64]
// Clear leading bits.
b[i/64] &^= ^uint64(0) << (i % 64)
for k := i/64 + 1; k < j/64; k++ {
b[k] = 0
}
// Clear trailing bits.
b[j/64] &^= (uint64(1) << (j%64 + 1)) - 1
}
// clearAll frees all the bits of b.
func (b *pageBits) clearAll() {
for i := range b {
b[i] = 0
}
}
// clearBlock64 clears the 64-bit aligned block of bits containing the i'th bit that
// are set in v.
func (b *pageBits) clearBlock64(i uint, v uint64) {
b[i/64] &^= v
}
// popcntRange counts the number of set bits in the
// range [i, i+n).
func (b *pageBits) popcntRange(i, n uint) (s uint) {
if n == 1 {
return uint((b[i/64] >> (i % 64)) & 1)
}
_ = b[i/64]
j := i + n - 1
if i/64 == j/64 {
return uint(sys.OnesCount64((b[i/64] >> (i % 64)) & ((1 << n) - 1)))
}
_ = b[j/64]
s += uint(sys.OnesCount64(b[i/64] >> (i % 64)))
for k := i/64 + 1; k < j/64; k++ {
s += uint(sys.OnesCount64(b[k]))
}
s += uint(sys.OnesCount64(b[j/64] & ((1 << (j%64 + 1)) - 1)))
return
}
// pallocBits is a bitmap that tracks page allocations for at most one
// palloc chunk.
//
// The precise representation is an implementation detail, but for the
// sake of documentation, 0s are free pages and 1s are allocated pages.
type pallocBits pageBits
// summarize returns a packed summary of the bitmap in pallocBits.
func (b *pallocBits) summarize() pallocSum {
var start, max, cur uint
const notSetYet = ^uint(0) // sentinel for start value
start = notSetYet
for i := 0; i < len(b); i++ {
x := b[i]
if x == 0 {
cur += 64
continue
}
t := uint(sys.TrailingZeros64(x))
l := uint(sys.LeadingZeros64(x))
// Finish any region spanning the uint64s
cur += t
if start == notSetYet {
start = cur
}
if cur > max {
max = cur
}
// Final region that might span to next uint64
cur = l
}
if start == notSetYet {
// Made it all the way through without finding a single 1 bit.
const n = uint(64 * len(b))
return packPallocSum(n, n, n)
}
if cur > max {
max = cur
}
if max >= 64-2 {
// There is no way an internal run of zeros could beat max.
return packPallocSum(start, max, cur)
}
// Now look inside each uint64 for runs of zeros.
// All uint64s must be nonzero, or we would have aborted above.
outer:
for i := 0; i < len(b); i++ {
x := b[i]
// Look inside this uint64. We have a pattern like
// 000000 1xxxxx1 000000
// We need to look inside the 1xxxxx1 for any contiguous
// region of zeros.
// We already know the trailing zeros are no larger than max. Remove them.
x >>= sys.TrailingZeros64(x) & 63
if x&(x+1) == 0 { // no more zeros (except at the top).
continue
}
// Strategy: shrink all runs of zeros by max. If any runs of zero
// remain, then we've identified a larger maximum zero run.
p := max // number of zeros we still need to shrink by.
k := uint(1) // current minimum length of runs of ones in x.
for {
// Shrink all runs of zeros by p places (except the top zeros).
for p > 0 {
if p <= k {
// Shift p ones down into the top of each run of zeros.
x |= x >> (p & 63)
if x&(x+1) == 0 { // no more zeros (except at the top).
continue outer
}
break
}
// Shift k ones down into the top of each run of zeros.
x |= x >> (k & 63)
if x&(x+1) == 0 { // no more zeros (except at the top).
continue outer
}
p -= k
// We've just doubled the minimum length of 1-runs.
// This allows us to shift farther in the next iteration.
k *= 2
}
// The length of the lowest-order zero run is an increment to our maximum.
j := uint(sys.TrailingZeros64(^x)) // count contiguous trailing ones
x >>= j & 63 // remove trailing ones
j = uint(sys.TrailingZeros64(x)) // count contiguous trailing zeros
x >>= j & 63 // remove zeros
max += j // we have a new maximum!
if x&(x+1) == 0 { // no more zeros (except at the top).
continue outer
}
p = j // remove j more zeros from each zero run.
}
}
return packPallocSum(start, max, cur)
}
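// As a sketch of what summarize computes: for a chunk of N pages in which
// only page 3 is allocated (bit 3 set), the packed result is
//
//	start == 3     // pages 0-2 are free
//	max   == N - 4 // the tail run of pages 4..N-1 is the longest
//	end   == N - 4
//
// and a completely free chunk yields packPallocSum(N, N, N).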
// find searches for npages contiguous free pages in pallocBits and returns
// the index where that run starts, as well as the index of the first free page
// it found in the search. searchIdx represents the first known free page and
// where to begin the next search from.
//
// If find fails to find any free space, it returns an index of ^uint(0) and
// the new searchIdx should be ignored.
//
// Note that if npages == 1, the two returned values will always be identical.
func (b *pallocBits) find(npages uintptr, searchIdx uint) (uint, uint) {
if npages == 1 {
addr := b.find1(searchIdx)
return addr, addr
} else if npages <= 64 {
return b.findSmallN(npages, searchIdx)
}
return b.findLargeN(npages, searchIdx)
}
// find1 is a helper for find which searches for a single free page
// in the pallocBits and returns the index.
//
// See find for an explanation of the searchIdx parameter.
func (b *pallocBits) find1(searchIdx uint) uint {
_ = b[0] // lift nil check out of loop
for i := searchIdx / 64; i < uint(len(b)); i++ {
x := b[i]
if ^x == 0 {
continue
}
return i*64 + uint(sys.TrailingZeros64(^x))
}
return ^uint(0)
}
// findSmallN is a helper for find which searches for npages contiguous free pages
// in this pallocBits and returns the index where that run of contiguous pages
// starts as well as the index of the first free page it finds in its search.
//
// See find for an explanation of the searchIdx parameter.
//
// Returns a ^uint(0) index on failure and the new searchIdx should be ignored.
//
// findSmallN assumes npages <= 64, where any such contiguous run of pages
// crosses at most one aligned 64-bit boundary in the bits.
func (b *pallocBits) findSmallN(npages uintptr, searchIdx uint) (uint, uint) {
end, newSearchIdx := uint(0), ^uint(0)
for i := searchIdx / 64; i < uint(len(b)); i++ {
bi := b[i]
if ^bi == 0 {
end = 0
continue
}
// First see if we can pack our allocation in the trailing
// zeros plus the end of the last 64 bits.
if newSearchIdx == ^uint(0) {
// The new searchIdx is going to be at these 64 bits after any
// 1s we find, so count trailing 1s.
newSearchIdx = i*64 + uint(sys.TrailingZeros64(^bi))
}
start := uint(sys.TrailingZeros64(bi))
if end+start >= uint(npages) {
return i*64 - end, newSearchIdx
}
// Next, check the interior of the 64-bit chunk.
j := findBitRange64(^bi, uint(npages))
if j < 64 {
return i*64 + j, newSearchIdx
}
end = uint(sys.LeadingZeros64(bi))
}
return ^uint(0), newSearchIdx
}
// findLargeN is a helper for find which searches for npages contiguous free pages
// in this pallocBits and returns the index where that run starts, as well as the
// index of the first free page it found in its search.
//
// See find for an explanation of the searchIdx parameter.
//
// Returns a ^uint(0) index on failure and the new searchIdx should be ignored.
//
// findLargeN assumes npages > 64, where any such run of free pages
// crosses at least one aligned 64-bit boundary in the bits.
func (b *pallocBits) findLargeN(npages uintptr, searchIdx uint) (uint, uint) {
start, size, newSearchIdx := ^uint(0), uint(0), ^uint(0)
for i := searchIdx / 64; i < uint(len(b)); i++ {
x := b[i]
if x == ^uint64(0) {
size = 0
continue
}
if newSearchIdx == ^uint(0) {
// The new searchIdx is going to be at these 64 bits after any
// 1s we find, so count trailing 1s.
newSearchIdx = i*64 + uint(sys.TrailingZeros64(^x))
}
if size == 0 {
size = uint(sys.LeadingZeros64(x))
start = i*64 + 64 - size
continue
}
s := uint(sys.TrailingZeros64(x))
if s+size >= uint(npages) {
size += s
return start, newSearchIdx
}
if s < 64 {
size = uint(sys.LeadingZeros64(x))
start = i*64 + 64 - size
continue
}
size += 64
}
if size < uint(npages) {
return ^uint(0), newSearchIdx
}
return start, newSearchIdx
}
// allocRange allocates the range [i, i+n).
func (b *pallocBits) allocRange(i, n uint) {
(*pageBits)(b).setRange(i, n)
}
// allocAll allocates all the bits of b.
func (b *pallocBits) allocAll() {
(*pageBits)(b).setAll()
}
// free1 frees a single page in the pallocBits at i.
func (b *pallocBits) free1(i uint) {
(*pageBits)(b).clear(i)
}
// free frees the range [i, i+n) of pages in the pallocBits.
func (b *pallocBits) free(i, n uint) {
(*pageBits)(b).clearRange(i, n)
}
// freeAll frees all the bits of b.
func (b *pallocBits) freeAll() {
(*pageBits)(b).clearAll()
}
// pages64 returns a 64-bit bitmap representing a block of 64 pages aligned
// to 64 pages. The returned block of pages is the one containing the i'th
// page in this pallocBits. Each bit represents whether the page is in-use.
func (b *pallocBits) pages64(i uint) uint64 {
return (*pageBits)(b).block64(i)
}
// allocPages64 allocates a 64-bit block of 64 pages aligned to 64 pages according
// to the bits set in alloc. The block set is the one containing the i'th page.
func (b *pallocBits) allocPages64(i uint, alloc uint64) {
(*pageBits)(b).setBlock64(i, alloc)
}
// findBitRange64 returns the bit index of the first set of
// n consecutive 1 bits. If no consecutive set of 1 bits of
// size n may be found in c, then it returns an integer >= 64.
// n must be > 0.
func findBitRange64(c uint64, n uint) uint {
// This implementation is based on shrinking the length of
// runs of contiguous 1 bits. We remove the top n-1 1 bits
// from each run of 1s, then look for the first remaining 1 bit.
p := n - 1 // number of 1s we want to remove.
k := uint(1) // current minimum width of runs of 0 in c.
for p > 0 {
if p <= k {
// Shift p 0s down into the top of each run of 1s.
c &= c >> (p & 63)
break
}
// Shift k 0s down into the top of each run of 1s.
c &= c >> (k & 63)
if c == 0 {
return 64
}
p -= k
// We've just doubled the minimum length of 0-runs.
// This allows us to shift farther in the next iteration.
k *= 2
}
// Find first remaining 1.
// Since we shrunk from the top down, the first 1 is in
// its correct original position.
return uint(sys.TrailingZeros64(c))
}
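// A worked example of the run-shrinking trick (input arbitrary): for
// findBitRange64(0b0111_0110, 2), p starts at 1, so
//
//	c &= c >> 1 // 0b0111_0110 & 0b0011_1011 == 0b0011_0010
//
// and TrailingZeros64 returns 1: the first run of two 1 bits starts at
// bit index 1.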
// pallocData encapsulates pallocBits and a bitmap for
// whether or not a given page is scavenged in a single
// structure. It's effectively a pallocBits with
// additional functionality.
//
// Update the comment on (*pageAlloc).chunks should this
// structure change.
type pallocData struct {
pallocBits
scavenged pageBits
}
// allocRange sets bits [i, i+n) in the bitmap to 1 and
// updates the scavenged bits appropriately.
func (m *pallocData) allocRange(i, n uint) {
// Clear the scavenged bits when we alloc the range.
m.pallocBits.allocRange(i, n)
m.scavenged.clearRange(i, n)
}
// allocAll sets every bit in the bitmap to 1 and updates
// the scavenged bits appropriately.
func (m *pallocData) allocAll() {
// Clear the scavenged bits when we alloc the range.
m.pallocBits.allocAll()
m.scavenged.clearAll()
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Malloc profiling.
// Patterned after tcmalloc's algorithms; shorter code.
package runtime
import (
"internal/abi"
"runtime/internal/atomic"
"runtime/internal/sys"
"unsafe"
)
// NOTE(rsc): Everything here could use cas if contention became an issue.
var (
// profInsertLock protects changes to the start of all *bucket linked lists
profInsertLock mutex
// profBlockLock protects the contents of every blockRecord struct
profBlockLock mutex
// profMemActiveLock protects the active field of every memRecord struct
profMemActiveLock mutex
// profMemFutureLock is a set of locks that protect the respective elements
// of the future array of every memRecord struct
profMemFutureLock [len(memRecord{}.future)]mutex
)
// All memory allocations are local and do not escape outside of the profiler.
// The profiler is forbidden from referring to garbage-collected memory.
const (
// profile types
memProfile bucketType = 1 + iota
blockProfile
mutexProfile
// size of bucket hash table
buckHashSize = 179999
// max depth of stack to record in bucket
maxStack = 32
)
type bucketType int
// A bucket holds per-call-stack profiling information.
// The representation is a bit sleazy, inherited from C.
// This struct defines the bucket header. It is followed in
// memory by the stack words and then the actual record
// data, either a memRecord or a blockRecord.
//
// Per-call-stack profiling information.
// Lookup by hashing call stack into a linked-list hash table.
//
// None of the fields in this bucket header are modified after
// creation, including its next and allnext links.
//
// No heap pointers.
type bucket struct {
_ sys.NotInHeap
next *bucket
allnext *bucket
typ bucketType // memBucket or blockBucket (includes mutexProfile)
hash uintptr
size uintptr
nstk uintptr
}
// A memRecord is the bucket data for a bucket of type memProfile,
// part of the memory profile.
type memRecord struct {
// The following complex 3-stage scheme of stats accumulation
// is required to obtain a consistent picture of mallocs and frees
// for some point in time.
// The problem is that mallocs come in real time, while frees
// come only after a GC during concurrent sweeping. So if we would
// naively count them, we would get a skew toward mallocs.
//
// Hence, we delay information to get consistent snapshots as
// of mark termination. Allocations count toward the next mark
// termination's snapshot, while sweep frees count toward the
// previous mark termination's snapshot:
//
//              MT          MT          MT          MT
//             .·|         .·|         .·|         .·|
//          .·˙  |      .·˙  |      .·˙  |      .·˙  |
//       .·˙     |   .·˙     |   .·˙     |   .·˙     |
//    .·˙        |.·˙        |.·˙        |.·˙        |
//
//               alloc → ▲ ← free
//                       ┠┅┅┅┅┅┅┅┅┅┅┅P
//       C+2     →    C+1     →     C
//
//                   alloc → ▲ ← free
//                           ┠┅┅┅┅┅┅┅┅┅┅┅P
//       C+2     →    C+1     →     C
//
// Since we can't publish a consistent snapshot until all of
// the sweep frees are accounted for, we wait until the next
// mark termination ("MT" above) to publish the previous mark
// termination's snapshot ("P" above). To do this, allocation
// and free events are accounted to *future* heap profile
// cycles ("C+n" above) and we only publish a cycle once all
// of the events from that cycle are guaranteed to be done. Specifically:
//
// Mallocs are accounted to cycle C+2.
// Explicit frees are accounted to cycle C+2.
// GC frees (done during sweeping) are accounted to cycle C+1.
//
// After mark termination, we increment the global heap
// profile cycle counter and accumulate the stats from cycle C
// into the active profile.
// active is the currently published profile. A profiling
// cycle can be accumulated into active once it's complete.
active memRecordCycle
// future records the profile events we're counting for cycles
// that have not yet been published. This is a ring buffer
// indexed by the global heap profile cycle C and stores
// cycles C, C+1, and C+2. Unlike active, these counts are
// only for a single cycle; they are not cumulative across
// cycles.
//
// We store cycle C here because there's a window between when
// C becomes the active cycle and when we've flushed it to
// active.
future [3]memRecordCycle
}
// A memRecordCycle is the set of memory profile counters for a single
// heap profile cycle.
type memRecordCycle struct {
allocs, frees uintptr
alloc_bytes, free_bytes uintptr
}
// add accumulates b into a. It does not zero b.
func (a *memRecordCycle) add(b *memRecordCycle) {
a.allocs += b.allocs
a.frees += b.frees
a.alloc_bytes += b.alloc_bytes
a.free_bytes += b.free_bytes
}
// A blockRecord is the bucket data for a bucket of type blockProfile,
// which is used in blocking and mutex profiles.
type blockRecord struct {
count float64
cycles int64
}
var (
mbuckets atomic.UnsafePointer // *bucket, memory profile buckets
bbuckets atomic.UnsafePointer // *bucket, blocking profile buckets
xbuckets atomic.UnsafePointer // *bucket, mutex profile buckets
buckhash atomic.UnsafePointer // *buckhashArray
mProfCycle mProfCycleHolder
)
type buckhashArray [buckHashSize]atomic.UnsafePointer // *bucket
const mProfCycleWrap = uint32(len(memRecord{}.future)) * (2 << 24)
// mProfCycleHolder holds the global heap profile cycle number (wrapped at
// mProfCycleWrap, stored starting at bit 1), and a flag (stored at bit 0) to
// indicate whether future[cycle] in all buckets has been queued to flush into
// the active profile.
type mProfCycleHolder struct {
value atomic.Uint32
}
// read returns the current cycle count.
func (c *mProfCycleHolder) read() (cycle uint32) {
v := c.value.Load()
cycle = v >> 1
return cycle
}
// setFlushed sets the flushed flag. It returns the current cycle count and the
// previous value of the flushed flag.
func (c *mProfCycleHolder) setFlushed() (cycle uint32, alreadyFlushed bool) {
for {
prev := c.value.Load()
cycle = prev >> 1
alreadyFlushed = (prev & 0x1) != 0
next := prev | 0x1
if c.value.CompareAndSwap(prev, next) {
return cycle, alreadyFlushed
}
}
}
// increment increases the cycle count by one, wrapping the value at
// mProfCycleWrap. It clears the flushed flag.
func (c *mProfCycleHolder) increment() {
// We explicitly wrap mProfCycle rather than depending on
// uint wraparound because the memRecord.future ring does not
// itself wrap at a power of two.
for {
prev := c.value.Load()
cycle := prev >> 1
cycle = (cycle + 1) % mProfCycleWrap
next := cycle << 1
if c.value.CompareAndSwap(prev, next) {
break
}
}
}
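// The packed layout, for reference: value == cycle<<1 | flushedFlag. For
// example, value == 0b101 means cycle 2 with the flushed flag set;
// increment would produce 0b110 (cycle 3, flag cleared).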
// newBucket allocates a bucket with the given type and number of stack entries.
func newBucket(typ bucketType, nstk int) *bucket {
size := unsafe.Sizeof(bucket{}) + uintptr(nstk)*unsafe.Sizeof(uintptr(0))
switch typ {
default:
throw("invalid profile bucket type")
case memProfile:
size += unsafe.Sizeof(memRecord{})
case blockProfile, mutexProfile:
size += unsafe.Sizeof(blockRecord{})
}
b := (*bucket)(persistentalloc(size, 0, &memstats.buckhash_sys))
b.typ = typ
b.nstk = uintptr(nstk)
return b
}
// stk returns the slice in b holding the stack.
func (b *bucket) stk() []uintptr {
stk := (*[maxStack]uintptr)(add(unsafe.Pointer(b), unsafe.Sizeof(*b)))
return stk[:b.nstk:b.nstk]
}
// mp returns the memRecord associated with the memProfile bucket b.
func (b *bucket) mp() *memRecord {
if b.typ != memProfile {
throw("bad use of bucket.mp")
}
data := add(unsafe.Pointer(b), unsafe.Sizeof(*b)+b.nstk*unsafe.Sizeof(uintptr(0)))
return (*memRecord)(data)
}
// bp returns the blockRecord associated with the blockProfile bucket b.
func (b *bucket) bp() *blockRecord {
if b.typ != blockProfile && b.typ != mutexProfile {
throw("bad use of bucket.bp")
}
data := add(unsafe.Pointer(b), unsafe.Sizeof(*b)+b.nstk*unsafe.Sizeof(uintptr(0)))
return (*blockRecord)(data)
}
// Return the bucket for stk[0:nstk], allocating new bucket if needed.
func stkbucket(typ bucketType, size uintptr, stk []uintptr, alloc bool) *bucket {
bh := (*buckhashArray)(buckhash.Load())
if bh == nil {
lock(&profInsertLock)
// check again under the lock
bh = (*buckhashArray)(buckhash.Load())
if bh == nil {
bh = (*buckhashArray)(sysAlloc(unsafe.Sizeof(buckhashArray{}), &memstats.buckhash_sys))
if bh == nil {
throw("runtime: cannot allocate memory")
}
buckhash.StoreNoWB(unsafe.Pointer(bh))
}
unlock(&profInsertLock)
}
// Hash stack.
var h uintptr
for _, pc := range stk {
h += pc
h += h << 10
h ^= h >> 6
}
// hash in size
h += size
h += h << 10
h ^= h >> 6
// finalize
h += h << 3
h ^= h >> 11
i := int(h % buckHashSize)
// first check optimistically, without the lock
for b := (*bucket)(bh[i].Load()); b != nil; b = b.next {
if b.typ == typ && b.hash == h && b.size == size && eqslice(b.stk(), stk) {
return b
}
}
if !alloc {
return nil
}
lock(&profInsertLock)
// check again under the insertion lock
for b := (*bucket)(bh[i].Load()); b != nil; b = b.next {
if b.typ == typ && b.hash == h && b.size == size && eqslice(b.stk(), stk) {
unlock(&profInsertLock)
return b
}
}
// Create new bucket.
b := newBucket(typ, len(stk))
copy(b.stk(), stk)
b.hash = h
b.size = size
var allnext *atomic.UnsafePointer
if typ == memProfile {
allnext = &mbuckets
} else if typ == mutexProfile {
allnext = &xbuckets
} else {
allnext = &bbuckets
}
b.next = (*bucket)(bh[i].Load())
b.allnext = (*bucket)(allnext.Load())
bh[i].StoreNoWB(unsafe.Pointer(b))
allnext.StoreNoWB(unsafe.Pointer(b))
unlock(&profInsertLock)
return b
}
func eqslice(x, y []uintptr) bool {
if len(x) != len(y) {
return false
}
for i, xi := range x {
if xi != y[i] {
return false
}
}
return true
}
// mProf_NextCycle publishes the next heap profile cycle and creates a
// fresh heap profile cycle. This operation is fast and can be done
// during STW. The caller must call mProf_Flush before calling
// mProf_NextCycle again.
//
// This is called by mark termination during STW so allocations and
// frees after the world is started again count towards a new heap
// profiling cycle.
func mProf_NextCycle() {
mProfCycle.increment()
}
// mProf_Flush flushes the events from the current heap profiling
// cycle into the active profile. After this it is safe to start a new
// heap profiling cycle with mProf_NextCycle.
//
// This is called by GC after mark termination starts the world. In
// contrast with mProf_NextCycle, this is somewhat expensive, but safe
// to do concurrently.
func mProf_Flush() {
cycle, alreadyFlushed := mProfCycle.setFlushed()
if alreadyFlushed {
return
}
index := cycle % uint32(len(memRecord{}.future))
lock(&profMemActiveLock)
lock(&profMemFutureLock[index])
mProf_FlushLocked(index)
unlock(&profMemFutureLock[index])
unlock(&profMemActiveLock)
}
// mProf_FlushLocked flushes the events from the heap profiling cycle at index
// into the active profile. The caller must hold the lock for the active profile
// (profMemActiveLock) and for the profiling cycle at index
// (profMemFutureLock[index]).
func mProf_FlushLocked(index uint32) {
assertLockHeld(&profMemActiveLock)
assertLockHeld(&profMemFutureLock[index])
head := (*bucket)(mbuckets.Load())
for b := head; b != nil; b = b.allnext {
mp := b.mp()
// Flush cycle C into the published profile and clear
// it for reuse.
mpc := &mp.future[index]
mp.active.add(mpc)
*mpc = memRecordCycle{}
}
}
// mProf_PostSweep records that all sweep frees for this GC cycle have
// completed. This has the effect of publishing the heap profile
// snapshot as of the last mark termination without advancing the heap
// profile cycle.
func mProf_PostSweep() {
// Flush cycle C+1 to the active profile so everything as of
// the last mark termination becomes visible. *Don't* advance
// the cycle, since we're still accumulating allocs in cycle
// C+2, which have to become C+1 in the next mark termination
// and so on.
cycle := mProfCycle.read() + 1
index := cycle % uint32(len(memRecord{}.future))
lock(&profMemActiveLock)
lock(&profMemFutureLock[index])
mProf_FlushLocked(index)
unlock(&profMemFutureLock[index])
unlock(&profMemActiveLock)
}
// Called by malloc to record a profiled block.
func mProf_Malloc(p unsafe.Pointer, size uintptr) {
var stk [maxStack]uintptr
nstk := callers(4, stk[:])
index := (mProfCycle.read() + 2) % uint32(len(memRecord{}.future))
b := stkbucket(memProfile, size, stk[:nstk], true)
mp := b.mp()
mpc := &mp.future[index]
lock(&profMemFutureLock[index])
mpc.allocs++
mpc.alloc_bytes += size
unlock(&profMemFutureLock[index])
// Setprofilebucket locks a bunch of other mutexes, so we call it outside of
// the profiler locks. This reduces potential contention and chances of
// deadlocks. Since the object must be alive during the call to
// mProf_Malloc, it's fine to do this non-atomically.
systemstack(func() {
setprofilebucket(p, b)
})
}
// Called when freeing a profiled block.
func mProf_Free(b *bucket, size uintptr) {
index := (mProfCycle.read() + 1) % uint32(len(memRecord{}.future))
mp := b.mp()
mpc := &mp.future[index]
lock(&profMemFutureLock[index])
mpc.frees++
mpc.free_bytes += size
unlock(&profMemFutureLock[index])
}
var blockprofilerate uint64 // in CPU ticks
// SetBlockProfileRate controls the fraction of goroutine blocking events
// that are reported in the blocking profile. The profiler aims to sample
// an average of one blocking event per rate nanoseconds spent blocked.
//
// To include every blocking event in the profile, pass rate = 1.
// To turn off profiling entirely, pass rate <= 0.
func SetBlockProfileRate(rate int) {
var r int64
if rate <= 0 {
r = 0 // disable profiling
} else if rate == 1 {
r = 1 // profile everything
} else {
// convert ns to cycles, use float64 to prevent overflow during multiplication
r = int64(float64(rate) * float64(tickspersecond()) / (1000 * 1000 * 1000))
if r == 0 {
r = 1
}
}
atomic.Store64(&blockprofilerate, uint64(r))
}
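// A typical usage sketch from user code:
//
//	runtime.SetBlockProfileRate(1)       // record every blocking event
//	defer runtime.SetBlockProfileRate(0) // turn profiling back off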
func blockevent(cycles int64, skip int) {
if cycles <= 0 {
cycles = 1
}
rate := int64(atomic.Load64(&blockprofilerate))
if blocksampled(cycles, rate) {
saveblockevent(cycles, rate, skip+1, blockProfile)
}
}
// blocksampled returns true for all events where cycles >= rate. Shorter
// events have a cycles/rate random chance of returning true.
func blocksampled(cycles, rate int64) bool {
if rate <= 0 || (rate > cycles && int64(fastrand())%rate > cycles) {
return false
}
return true
}
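// In other words, with rate == 100 an event of 150 cycles is always
// recorded, while an event of 50 cycles is recorded with probability
// of roughly cycles/rate, i.e. about 50%.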
func saveblockevent(cycles, rate int64, skip int, which bucketType) {
gp := getg()
var nstk int
var stk [maxStack]uintptr
if gp.m.curg == nil || gp.m.curg == gp {
nstk = callers(skip, stk[:])
} else {
nstk = gcallers(gp.m.curg, skip, stk[:])
}
b := stkbucket(which, 0, stk[:nstk], true)
bp := b.bp()
lock(&profBlockLock)
// We want to up-scale the count and cycles according to the
// probability that the event was sampled. For block profile events,
// the sample probability is 1 if cycles >= rate, and cycles / rate
// otherwise. For mutex profile events, the sample probability is 1 / rate.
// We scale the events by 1 / (probability the event was sampled).
if which == blockProfile && cycles < rate {
// Remove sampling bias, see discussion on http://golang.org/cl/299991.
bp.count += float64(rate) / float64(cycles)
bp.cycles += rate
} else if which == mutexProfile {
bp.count += float64(rate)
bp.cycles += rate * cycles
} else {
bp.count++
bp.cycles += cycles
}
unlock(&profBlockLock)
}
var mutexprofilerate uint64 // fraction sampled
// SetMutexProfileFraction controls the fraction of mutex contention events
// that are reported in the mutex profile. On average 1/rate events are
// reported. The previous rate is returned.
//
// To turn off profiling entirely, pass rate 0.
// To just read the current rate, pass rate < 0.
// (For rate > 1, the details of sampling may change.)
func SetMutexProfileFraction(rate int) int {
if rate < 0 {
return int(mutexprofilerate)
}
old := mutexprofilerate
atomic.Store64(&mutexprofilerate, uint64(rate))
return int(old)
}
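// A typical usage sketch from user code:
//
//	old := runtime.SetMutexProfileFraction(5) // sample ~1 in 5 contention events
//	defer runtime.SetMutexProfileFraction(old)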
//go:linkname mutexevent sync.event
func mutexevent(cycles int64, skip int) {
if cycles < 0 {
cycles = 0
}
rate := int64(atomic.Load64(&mutexprofilerate))
// TODO(pjw): measure impact of always calling fastrand vs using something
// like malloc.go:nextSample()
if rate > 0 && int64(fastrand())%rate == 0 {
saveblockevent(cycles, rate, skip+1, mutexProfile)
}
}
// Go interface to profile data.
// A StackRecord describes a single execution stack.
type StackRecord struct {
Stack0 [32]uintptr // stack trace for this record; ends at first 0 entry
}
// Stack returns the stack trace associated with the record,
// a prefix of r.Stack0.
func (r *StackRecord) Stack() []uintptr {
for i, v := range r.Stack0 {
if v == 0 {
return r.Stack0[0:i]
}
}
return r.Stack0[0:]
}
// MemProfileRate controls the fraction of memory allocations
// that are recorded and reported in the memory profile.
// The profiler aims to sample an average of
// one allocation per MemProfileRate bytes allocated.
//
// To include every allocated block in the profile, set MemProfileRate to 1.
// To turn off profiling entirely, set MemProfileRate to 0.
//
// The tools that process the memory profiles assume that the
// profile rate is constant across the lifetime of the program
// and equal to the current value. Programs that change the
// memory profiling rate should do so just once, as early as
// possible in the execution of the program (for example,
// at the beginning of main).
var MemProfileRate int = 512 * 1024
// disableMemoryProfiling is set by the linker if runtime.MemProfile
// is not used and the link type guarantees nobody else could use it
// elsewhere.
var disableMemoryProfiling bool
// A MemProfileRecord describes the live objects allocated
// by a particular call sequence (stack trace).
type MemProfileRecord struct {
AllocBytes, FreeBytes int64 // number of bytes allocated, freed
AllocObjects, FreeObjects int64 // number of objects allocated, freed
Stack0 [32]uintptr // stack trace for this record; ends at first 0 entry
}
// InUseBytes returns the number of bytes in use (AllocBytes - FreeBytes).
func (r *MemProfileRecord) InUseBytes() int64 { return r.AllocBytes - r.FreeBytes }
// InUseObjects returns the number of objects in use (AllocObjects - FreeObjects).
func (r *MemProfileRecord) InUseObjects() int64 {
return r.AllocObjects - r.FreeObjects
}
// Stack returns the stack trace associated with the record,
// a prefix of r.Stack0.
func (r *MemProfileRecord) Stack() []uintptr {
for i, v := range r.Stack0 {
if v == 0 {
return r.Stack0[0:i]
}
}
return r.Stack0[0:]
}
// MemProfile returns a profile of memory allocated and freed per allocation
// site.
//
// MemProfile returns n, the number of records in the current memory profile.
// If len(p) >= n, MemProfile copies the profile into p and returns n, true.
// If len(p) < n, MemProfile does not change p and returns n, false.
//
// If inuseZero is true, the profile includes allocation records
// where r.AllocBytes > 0 but r.AllocBytes == r.FreeBytes.
// These are sites where memory was allocated, but it has all
// been released back to the runtime.
//
// The returned profile may be up to two garbage collection cycles old.
// This is to avoid skewing the profile toward allocations; because
// allocations happen in real time but frees are delayed until the garbage
// collector performs sweeping, the profile only accounts for allocations
// that have had a chance to be freed by the garbage collector.
//
// Most clients should use the runtime/pprof package or
// the testing package's -test.memprofile flag instead
// of calling MemProfile directly.
func MemProfile(p []MemProfileRecord, inuseZero bool) (n int, ok bool) {
cycle := mProfCycle.read()
// If we're between mProf_NextCycle and mProf_Flush, take care
// of flushing to the active profile so we only have to look
// at the active profile below.
index := cycle % uint32(len(memRecord{}.future))
lock(&profMemActiveLock)
lock(&profMemFutureLock[index])
mProf_FlushLocked(index)
unlock(&profMemFutureLock[index])
clear := true
head := (*bucket)(mbuckets.Load())
for b := head; b != nil; b = b.allnext {
mp := b.mp()
if inuseZero || mp.active.alloc_bytes != mp.active.free_bytes {
n++
}
if mp.active.allocs != 0 || mp.active.frees != 0 {
clear = false
}
}
if clear {
// Absolutely no data, suggesting that a garbage collection
// has not yet happened. In order to allow profiling when
// garbage collection is disabled from the beginning of execution,
// accumulate all of the cycles, and recount buckets.
n = 0
for b := head; b != nil; b = b.allnext {
mp := b.mp()
for c := range mp.future {
lock(&profMemFutureLock[c])
mp.active.add(&mp.future[c])
mp.future[c] = memRecordCycle{}
unlock(&profMemFutureLock[c])
}
if inuseZero || mp.active.alloc_bytes != mp.active.free_bytes {
n++
}
}
}
if n <= len(p) {
ok = true
idx := 0
for b := head; b != nil; b = b.allnext {
mp := b.mp()
if inuseZero || mp.active.alloc_bytes != mp.active.free_bytes {
record(&p[idx], b)
idx++
}
}
}
unlock(&profMemActiveLock)
return
}
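// A sketch of the documented two-call pattern from user code: size the
// slice with one call, then retry until the records fit.
//
//	n, _ := runtime.MemProfile(nil, true)
//	var p []runtime.MemProfileRecord
//	for {
//		p = make([]runtime.MemProfileRecord, n+50)
//		var ok bool
//		n, ok = runtime.MemProfile(p, true)
//		if ok {
//			p = p[:n]
//			break
//		}
//	}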
// Write b's data to r.
func record(r *MemProfileRecord, b *bucket) {
mp := b.mp()
r.AllocBytes = int64(mp.active.alloc_bytes)
r.FreeBytes = int64(mp.active.free_bytes)
r.AllocObjects = int64(mp.active.allocs)
r.FreeObjects = int64(mp.active.frees)
if raceenabled {
racewriterangepc(unsafe.Pointer(&r.Stack0[0]), unsafe.Sizeof(r.Stack0), getcallerpc(), abi.FuncPCABIInternal(MemProfile))
}
if msanenabled {
msanwrite(unsafe.Pointer(&r.Stack0[0]), unsafe.Sizeof(r.Stack0))
}
if asanenabled {
asanwrite(unsafe.Pointer(&r.Stack0[0]), unsafe.Sizeof(r.Stack0))
}
copy(r.Stack0[:], b.stk())
for i := int(b.nstk); i < len(r.Stack0); i++ {
r.Stack0[i] = 0
}
}
func iterate_memprof(fn func(*bucket, uintptr, *uintptr, uintptr, uintptr, uintptr)) {
lock(&profMemActiveLock)
head := (*bucket)(mbuckets.Load())
for b := head; b != nil; b = b.allnext {
mp := b.mp()
fn(b, b.nstk, &b.stk()[0], b.size, mp.active.allocs, mp.active.frees)
}
unlock(&profMemActiveLock)
}
// BlockProfileRecord describes blocking events originated
// at a particular call sequence (stack trace).
type BlockProfileRecord struct {
Count int64
Cycles int64
StackRecord
}
// BlockProfile returns n, the number of records in the current blocking profile.
// If len(p) >= n, BlockProfile copies the profile into p and returns n, true.
// If len(p) < n, BlockProfile does not change p and returns n, false.
//
// Most clients should use the runtime/pprof package or
// the testing package's -test.blockprofile flag instead
// of calling BlockProfile directly.
func BlockProfile(p []BlockProfileRecord) (n int, ok bool) {
lock(&profBlockLock)
head := (*bucket)(bbuckets.Load())
for b := head; b != nil; b = b.allnext {
n++
}
if n <= len(p) {
ok = true
for b := head; b != nil; b = b.allnext {
bp := b.bp()
r := &p[0]
r.Count = int64(bp.count)
// Prevent callers from having to worry about division by zero errors.
// See discussion on http://golang.org/cl/299991.
if r.Count == 0 {
r.Count = 1
}
r.Cycles = bp.cycles
if raceenabled {
racewriterangepc(unsafe.Pointer(&r.Stack0[0]), unsafe.Sizeof(r.Stack0), getcallerpc(), abi.FuncPCABIInternal(BlockProfile))
}
if msanenabled {
msanwrite(unsafe.Pointer(&r.Stack0[0]), unsafe.Sizeof(r.Stack0))
}
if asanenabled {
asanwrite(unsafe.Pointer(&r.Stack0[0]), unsafe.Sizeof(r.Stack0))
}
i := copy(r.Stack0[:], b.stk())
for ; i < len(r.Stack0); i++ {
r.Stack0[i] = 0
}
p = p[1:]
}
}
unlock(&profBlockLock)
return
}
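// Note that block events are only recorded while sampling is enabled via
// SetBlockProfileRate; at the default rate of 0 the profile stays empty.
// Sizing and retrying follow the same two-call pattern sketched above for
// MemProfile.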
// MutexProfile returns n, the number of records in the current mutex profile.
// If len(p) >= n, MutexProfile copies the profile into p and returns n, true.
// Otherwise, MutexProfile does not change p, and returns n, false.
//
// Most clients should use the runtime/pprof package
// instead of calling MutexProfile directly.
func MutexProfile(p []BlockProfileRecord) (n int, ok bool) {
lock(&profBlockLock)
head := (*bucket)(xbuckets.Load())
for b := head; b != nil; b = b.allnext {
n++
}
if n <= len(p) {
ok = true
for b := head; b != nil; b = b.allnext {
bp := b.bp()
r := &p[0]
r.Count = int64(bp.count)
r.Cycles = bp.cycles
i := copy(r.Stack0[:], b.stk())
for ; i < len(r.Stack0); i++ {
r.Stack0[i] = 0
}
p = p[1:]
}
}
unlock(&profBlockLock)
return
}
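// A sketch of direct use; contention events are only sampled once
// SetMutexProfileFraction is set to a nonzero rate, and the retry loop from
// the MemProfile sketch above is elided here:
//
//	SetMutexProfileFraction(5) // sample roughly 1 in 5 contention events
//	// ... run the workload ...
//	n, _ := MutexProfile(nil)
//	p := make([]BlockProfileRecord, n+50)
//	if n, ok := MutexProfile(p); ok {
//		p = p[:n]
//	}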
// ThreadCreateProfile returns n, the number of records in the thread creation profile.
// If len(p) >= n, ThreadCreateProfile copies the profile into p and returns n, true.
// If len(p) < n, ThreadCreateProfile does not change p and returns n, false.
//
// Most clients should use the runtime/pprof package instead
// of calling ThreadCreateProfile directly.
func ThreadCreateProfile(p []StackRecord) (n int, ok bool) {
first := (*m)(atomic.Loadp(unsafe.Pointer(&allm)))
for mp := first; mp != nil; mp = mp.alllink {
n++
}
if n <= len(p) {
ok = true
i := 0
for mp := first; mp != nil; mp = mp.alllink {
p[i].Stack0 = mp.createstack
i++
}
}
return
}
//go:linkname runtime_goroutineProfileWithLabels runtime/pprof.runtime_goroutineProfileWithLabels
func runtime_goroutineProfileWithLabels(p []StackRecord, labels []unsafe.Pointer) (n int, ok bool) {
return goroutineProfileWithLabels(p, labels)
}
// labels may be nil. If labels is non-nil, it must have the same length as p.
func goroutineProfileWithLabels(p []StackRecord, labels []unsafe.Pointer) (n int, ok bool) {
if labels != nil && len(labels) != len(p) {
labels = nil
}
return goroutineProfileWithLabelsConcurrent(p, labels)
}
var goroutineProfile = struct {
sema uint32
active bool
offset atomic.Int64
records []StackRecord
labels []unsafe.Pointer
}{
sema: 1,
}
// goroutineProfileState indicates the status of a goroutine's stack for the
// current in-progress goroutine profile. Goroutines' stacks are initially
// "Absent" from the profile, and end up "Satisfied" by the time the profile is
// complete. While a goroutine's stack is being captured, its
// goroutineProfileState will be "InProgress" and it will not be able to run
// until the capture completes and the state moves to "Satisfied".
//
// Some goroutines (the finalizer goroutine, which at various times can be
// either a "system" or a "user" goroutine; the goroutine coordinating the
// profile; and any goroutines created during the profile) move directly to
// the "Satisfied" state.
type goroutineProfileState uint32
const (
goroutineProfileAbsent goroutineProfileState = iota
goroutineProfileInProgress
goroutineProfileSatisfied
)
type goroutineProfileStateHolder atomic.Uint32
func (p *goroutineProfileStateHolder) Load() goroutineProfileState {
return goroutineProfileState((*atomic.Uint32)(p).Load())
}
func (p *goroutineProfileStateHolder) Store(value goroutineProfileState) {
(*atomic.Uint32)(p).Store(uint32(value))
}
func (p *goroutineProfileStateHolder) CompareAndSwap(old, new goroutineProfileState) bool {
return (*atomic.Uint32)(p).CompareAndSwap(uint32(old), uint32(new))
}
func goroutineProfileWithLabelsConcurrent(p []StackRecord, labels []unsafe.Pointer) (n int, ok bool) {
semacquire(&goroutineProfile.sema)
ourg := getg()
stopTheWorld("profile")
// Using gcount while the world is stopped should give us a consistent view
// of the number of live goroutines, minus the number of goroutines that are
// alive and permanently marked as "system". But to make this count agree
// with what we'd get from isSystemGoroutine, we need special handling for
// goroutines that can vary between user and system to ensure that the count
// doesn't change during the collection. So, check the finalizer goroutine
// in particular.
n = int(gcount())
if fingStatus.Load()&fingRunningFinalizer != 0 {
n++
}
if n > len(p) {
// There's not enough space in p to store the whole profile, so (per the
// contract of runtime.GoroutineProfile) we're not allowed to write to p
// at all and must return n, false.
startTheWorld()
semrelease(&goroutineProfile.sema)
return n, false
}
// Save current goroutine.
sp := getcallersp()
pc := getcallerpc()
systemstack(func() {
saveg(pc, sp, ourg, &p[0])
})
ourg.goroutineProfiled.Store(goroutineProfileSatisfied)
goroutineProfile.offset.Store(1)
// Prepare for all other goroutines to enter the profile. Aside from ourg,
// every goroutine struct in the allgs list has its goroutineProfiled field
// cleared. Any goroutine created from this point on (while
// goroutineProfile.active is set) will start with its goroutineProfiled
// field set to goroutineProfileSatisfied.
goroutineProfile.active = true
goroutineProfile.records = p
goroutineProfile.labels = labels
// The finalizer goroutine needs special handling because it can vary over
// time between being a user goroutine (eligible for this profile) and a
// system goroutine (to be excluded). Pick one before restarting the world.
if fing != nil {
fing.goroutineProfiled.Store(goroutineProfileSatisfied)
if readgstatus(fing) != _Gdead && !isSystemGoroutine(fing, false) {
doRecordGoroutineProfile(fing)
}
}
startTheWorld()
// Visit each goroutine that existed as of the startTheWorld call above.
//
// New goroutines may not be in this list, but we didn't want to know about
// them anyway. If they do appear in this list (via reusing a dead goroutine
// struct, or racing to launch between the world restarting and us getting
// the list), they will already have their goroutineProfiled field set to
// goroutineProfileSatisfied before their state transitions out of _Gdead.
//
// Any goroutine that the scheduler tries to execute concurrently with this
// call will start by adding itself to the profile (before the act of
// executing can cause any changes in its stack).
forEachGRace(func(gp1 *g) {
tryRecordGoroutineProfile(gp1, Gosched)
})
stopTheWorld("profile cleanup")
endOffset := goroutineProfile.offset.Swap(0)
goroutineProfile.active = false
goroutineProfile.records = nil
goroutineProfile.labels = nil
startTheWorld()
// Restore the invariant that every goroutine struct in allgs has its
// goroutineProfiled field cleared.
forEachGRace(func(gp1 *g) {
gp1.goroutineProfiled.Store(goroutineProfileAbsent)
})
if raceenabled {
raceacquire(unsafe.Pointer(&labelSync))
}
if n != int(endOffset) {
// It's a big surprise that the number of goroutines changed while we
// were collecting the profile. But probably better to return a
// truncated profile than to crash the whole process.
//
// For instance, needm moves a goroutine out of the _Gdead state and so
// might be able to change the goroutine count without interacting with
// the scheduler. For code like that, the race windows are small and the
// combination of features is uncommon, so it's hard to be (and remain)
// sure we've caught them all.
}
semrelease(&goroutineProfile.sema)
return n, true
}
// tryRecordGoroutineProfileWB asserts that write barriers are allowed and calls
// tryRecordGoroutineProfile.
//
//go:yeswritebarrierrec
func tryRecordGoroutineProfileWB(gp1 *g) {
if getg().m.p.ptr() == nil {
throw("no P available, write barriers are forbidden")
}
tryRecordGoroutineProfile(gp1, osyield)
}
// tryRecordGoroutineProfile ensures that gp1 has the appropriate representation
// in the current goroutine profile: either that it should not be profiled, or
// that a snapshot of its call stack and labels is now in the profile.
func tryRecordGoroutineProfile(gp1 *g, yield func()) {
if readgstatus(gp1) == _Gdead {
// Dead goroutines should not appear in the profile. Goroutines that
// start while profile collection is active will get goroutineProfiled
// set to goroutineProfileSatisfied before transitioning out of _Gdead,
// so here we check _Gdead first.
return
}
if isSystemGoroutine(gp1, true) {
// System goroutines should not appear in the profile. (The finalizer
// goroutine is marked as "already profiled".)
return
}
for {
prev := gp1.goroutineProfiled.Load()
if prev == goroutineProfileSatisfied {
// This goroutine is already in the profile (or is new since the
// start of collection, so shouldn't appear in the profile).
break
}
if prev == goroutineProfileInProgress {
// Something else is adding gp1 to the goroutine profile right now.
// Give that a moment to finish.
yield()
continue
}
// While we have gp1.goroutineProfiled set to
// goroutineProfileInProgress, gp1 may appear _Grunnable but will not
// actually be able to run. Disable preemption for ourselves, to make
// sure we finish profiling gp1 right away instead of leaving it stuck
// in this limbo.
mp := acquirem()
if gp1.goroutineProfiled.CompareAndSwap(goroutineProfileAbsent, goroutineProfileInProgress) {
doRecordGoroutineProfile(gp1)
gp1.goroutineProfiled.Store(goroutineProfileSatisfied)
}
releasem(mp)
}
}
// doRecordGoroutineProfile writes gp1's call stack and labels to an in-progress
// goroutine profile. Preemption is disabled.
//
// This may be called via tryRecordGoroutineProfile in two ways: by the
// goroutine that is coordinating the goroutine profile (running on its own
// stack), or from the scheduler in preparation to execute gp1 (running on the
// system stack).
func doRecordGoroutineProfile(gp1 *g) {
if readgstatus(gp1) == _Grunning {
print("doRecordGoroutineProfile gp1=", gp1.goid, "\n")
throw("cannot read stack of running goroutine")
}
offset := int(goroutineProfile.offset.Add(1)) - 1
if offset >= len(goroutineProfile.records) {
// Should be impossible, but better to return a truncated profile than
// to crash the entire process at this point. Instead, deal with it in
// goroutineProfileWithLabelsConcurrent where we have more context.
return
}
// saveg calls gentraceback, which may call cgo traceback functions. When
// called from the scheduler, this is on the system stack already so
// traceback.go:cgoContextPCs will avoid calling back into the scheduler.
//
// When called from the goroutine coordinating the profile, we still have
// set gp1.goroutineProfiled to goroutineProfileInProgress and so are still
// preventing it from being truly _Grunnable. So we'll use the system stack
// to avoid schedule delays.
systemstack(func() { saveg(^uintptr(0), ^uintptr(0), gp1, &goroutineProfile.records[offset]) })
if goroutineProfile.labels != nil {
goroutineProfile.labels[offset] = gp1.labels
}
}
func goroutineProfileWithLabelsSync(p []StackRecord, labels []unsafe.Pointer) (n int, ok bool) {
gp := getg()
isOK := func(gp1 *g) bool {
// Checking isSystemGoroutine here makes GoroutineProfile
// consistent with both NumGoroutine and Stack.
return gp1 != gp && readgstatus(gp1) != _Gdead && !isSystemGoroutine(gp1, false)
}
stopTheWorld("profile")
// World is stopped, no locking required.
n = 1
forEachGRace(func(gp1 *g) {
if isOK(gp1) {
n++
}
})
if n <= len(p) {
ok = true
r, lbl := p, labels
// Save current goroutine.
sp := getcallersp()
pc := getcallerpc()
systemstack(func() {
saveg(pc, sp, gp, &r[0])
})
r = r[1:]
// If we have a place to put our goroutine labelmap, insert it there.
if labels != nil {
lbl[0] = gp.labels
lbl = lbl[1:]
}
// Save other goroutines.
forEachGRace(func(gp1 *g) {
if !isOK(gp1) {
return
}
if len(r) == 0 {
// Should be impossible, but better to return a
// truncated profile than to crash the entire process.
return
}
// saveg calls gentraceback, which may call cgo traceback functions.
// The world is stopped, so it cannot use cgocall (which will be
// blocked at exitsyscall). Do it on the system stack so it won't
// call into the scheduler (see traceback.go:cgoContextPCs).
systemstack(func() { saveg(^uintptr(0), ^uintptr(0), gp1, &r[0]) })
if labels != nil {
lbl[0] = gp1.labels
lbl = lbl[1:]
}
r = r[1:]
})
}
if raceenabled {
raceacquire(unsafe.Pointer(&labelSync))
}
startTheWorld()
return n, ok
}
// GoroutineProfile returns n, the number of records in the active goroutine stack profile.
// If len(p) >= n, GoroutineProfile copies the profile into p and returns n, true.
// If len(p) < n, GoroutineProfile does not change p and returns n, false.
//
// Most clients should use the runtime/pprof package instead
// of calling GoroutineProfile directly.
func GoroutineProfile(p []StackRecord) (n int, ok bool) {
return goroutineProfileWithLabels(p, nil)
}
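// A sketch of direct use: size the record slice with NumGoroutine and pad
// it, since goroutines may start between the two calls:
//
//	p := make([]StackRecord, NumGoroutine()+10)
//	if n, ok := GoroutineProfile(p); ok {
//		p = p[:n]
//	}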
func saveg(pc, sp uintptr, gp *g, r *StackRecord) {
n := gentraceback(pc, sp, 0, gp, 0, &r.Stack0[0], len(r.Stack0), nil, nil, 0)
if n < len(r.Stack0) {
r.Stack0[n] = 0
}
}
// Stack formats a stack trace of the calling goroutine into buf
// and returns the number of bytes written to buf.
// If all is true, Stack formats stack traces of all other goroutines
// into buf after the trace for the current goroutine.
func Stack(buf []byte, all bool) int {
if all {
stopTheWorld("stack trace")
}
n := 0
if len(buf) > 0 {
gp := getg()
sp := getcallersp()
pc := getcallerpc()
systemstack(func() {
g0 := getg()
// Force traceback=1 to override GOTRACEBACK setting,
// so that Stack's results are consistent.
// GOTRACEBACK is only about crash dumps.
g0.m.traceback = 1
g0.writebuf = buf[0:0:len(buf)]
goroutineheader(gp)
traceback(pc, sp, 0, gp)
if all {
tracebackothers(gp)
}
g0.m.traceback = 0
n = len(g0.writebuf)
g0.writebuf = nil
})
}
if all {
startTheWorld()
}
return n
}
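// For example, a sketch of dumping all goroutine stacks, retrying with a
// larger buffer whenever the result may have been truncated:
//
//	buf := make([]byte, 1<<16)
//	for {
//		n := Stack(buf, true)
//		if n < len(buf) {
//			buf = buf[:n]
//			break
//		}
//		buf = make([]byte, 2*len(buf))
//	}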
// Tracing of alloc/free/gc.
var tracelock mutex
func tracealloc(p unsafe.Pointer, size uintptr, typ *_type) {
lock(&tracelock)
gp := getg()
gp.m.traceback = 2
if typ == nil {
print("tracealloc(", p, ", ", hex(size), ")\n")
} else {
print("tracealloc(", p, ", ", hex(size), ", ", typ.string(), ")\n")
}
if gp.m.curg == nil || gp == gp.m.curg {
goroutineheader(gp)
pc := getcallerpc()
sp := getcallersp()
systemstack(func() {
traceback(pc, sp, 0, gp)
})
} else {
goroutineheader(gp.m.curg)
traceback(^uintptr(0), ^uintptr(0), 0, gp.m.curg)
}
print("\n")
gp.m.traceback = 0
unlock(&tracelock)
}
func tracefree(p unsafe.Pointer, size uintptr) {
lock(&tracelock)
gp := getg()
gp.m.traceback = 2
print("tracefree(", p, ", ", hex(size), ")\n")
goroutineheader(gp)
pc := getcallerpc()
sp := getcallersp()
systemstack(func() {
traceback(pc, sp, 0, gp)
})
print("\n")
gp.m.traceback = 0
unlock(&tracelock)
}
func tracegc() {
lock(&tracelock)
gp := getg()
gp.m.traceback = 2
print("tracegc()\n")
// running on m->g0 stack; show all non-g0 goroutines
tracebackothers(gp)
print("end tracegc\n")
print("\n")
gp.m.traceback = 0
unlock(&tracelock)
}
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Address range data structure.
//
// This file contains an implementation of a data structure which
// manages ordered address ranges.
package runtime
import (
"internal/goarch"
"runtime/internal/atomic"
"unsafe"
)
// addrRange represents a region of address space.
//
// An addrRange must never span a gap in the address space.
type addrRange struct {
// base and limit together represent the region of address space
// [base, limit). That is, base is inclusive, limit is exclusive.
// These are addresses in an offset view of the address space on
// platforms with a segmented address space, that is, on platforms
// where arenaBaseOffset != 0.
base, limit offAddr
}
// makeAddrRange creates a new address range from two virtual addresses.
//
// Throws if the base and limit are not in the same memory segment.
func makeAddrRange(base, limit uintptr) addrRange {
r := addrRange{offAddr{base}, offAddr{limit}}
if (base-arenaBaseOffset >= base) != (limit-arenaBaseOffset >= limit) {
throw("addr range base and limit are not in the same memory segment")
}
return r
}
// size returns the size of the range represented in bytes.
func (a addrRange) size() uintptr {
if !a.base.lessThan(a.limit) {
return 0
}
// Subtraction is safe because limit and base must be in the same
// segment of the address space.
return a.limit.diff(a.base)
}
// contains returns whether or not the range contains a given address.
func (a addrRange) contains(addr uintptr) bool {
return a.base.lessEqual(offAddr{addr}) && (offAddr{addr}).lessThan(a.limit)
}
// subtract removes from a any overlap with b and returns the new range.
// subtract assumes that a and b either don't overlap at all, only
// overlap on one side, or are equal. If b is strictly contained in a,
// thus forcing a split, it will throw.
func (a addrRange) subtract(b addrRange) addrRange {
if b.base.lessEqual(a.base) && a.limit.lessEqual(b.limit) {
return addrRange{}
} else if a.base.lessThan(b.base) && b.limit.lessThan(a.limit) {
throw("bad prune")
} else if b.limit.lessThan(a.limit) && a.base.lessThan(b.limit) {
a.base = b.limit
} else if a.base.lessThan(b.base) && b.base.lessThan(a.limit) {
a.limit = b.base
}
return a
}
// takeFromFront takes len bytes from the front of the address range, aligning
// the base to align first. On success, returns the aligned start of the region
// taken and true.
func (a *addrRange) takeFromFront(len uintptr, align uint8) (uintptr, bool) {
base := alignUp(a.base.addr(), uintptr(align)) + len
if base > a.limit.addr() {
return 0, false
}
a.base = offAddr{base}
return base - len, true
}
// takeFromBack takes len bytes from the end of the address range, aligning
// the limit to align after subtracting len. On success, returns the aligned
// start of the region taken and true.
func (a *addrRange) takeFromBack(len uintptr, align uint8) (uintptr, bool) {
limit := alignDown(a.limit.addr()-len, uintptr(align))
if a.base.addr() > limit {
return 0, false
}
a.limit = offAddr{limit}
return limit, true
}
// removeGreaterEqual removes all addresses in a greater than or equal
// to addr and returns the new range.
func (a addrRange) removeGreaterEqual(addr uintptr) addrRange {
if (offAddr{addr}).lessEqual(a.base) {
return addrRange{}
}
if a.limit.lessEqual(offAddr{addr}) {
return a
}
return makeAddrRange(a.base.addr(), addr)
}
var (
// minOffAddr is the minimum address in the offset space, and
// it corresponds to the virtual address arenaBaseOffset.
minOffAddr = offAddr{arenaBaseOffset}
// maxOffAddr is the maximum address in the offset address
// space. It corresponds to the highest virtual address representable
// by the page alloc chunk and heap arena maps.
maxOffAddr = offAddr{(((1 << heapAddrBits) - 1) + arenaBaseOffset) & uintptrMask}
)
// offAddr represents an address in a contiguous view
// of the address space on systems where the address space is
// segmented. On other systems, it's just a normal address.
type offAddr struct {
// a is just the virtual address, but should never be used
// directly. Call addr() to get this value instead.
a uintptr
}
// add adds a uintptr offset to the offAddr.
func (l offAddr) add(bytes uintptr) offAddr {
return offAddr{a: l.a + bytes}
}
// sub subtracts a uintptr offset from the offAddr.
func (l offAddr) sub(bytes uintptr) offAddr {
return offAddr{a: l.a - bytes}
}
// diff returns the number of bytes between the
// two offAddrs.
func (l1 offAddr) diff(l2 offAddr) uintptr {
return l1.a - l2.a
}
// lessThan returns true if l1 is less than l2 in the offset
// address space.
func (l1 offAddr) lessThan(l2 offAddr) bool {
return (l1.a - arenaBaseOffset) < (l2.a - arenaBaseOffset)
}
// lessEqual returns true if l1 is less than or equal to l2 in
// the offset address space.
func (l1 offAddr) lessEqual(l2 offAddr) bool {
return (l1.a - arenaBaseOffset) <= (l2.a - arenaBaseOffset)
}
// equal returns true if the two offAddr values are equal.
func (l1 offAddr) equal(l2 offAddr) bool {
// No need to compare in the offset space, it
// means the same thing.
return l1 == l2
}
// addr returns the virtual address for this offset address.
func (l offAddr) addr() uintptr {
return l.a
}
// atomicOffAddr is like offAddr, but operations on it are atomic.
// It also contains operations to be able to store marked addresses
// to ensure that they're not overridden until they've been seen.
type atomicOffAddr struct {
// a contains the offset address, unlike offAddr.
a atomic.Int64
}
// Clear attempts to store minOffAddr in atomicOffAddr. It may fail
// if a marked value is placed in the box in the meanwhile.
func (b *atomicOffAddr) Clear() {
for {
old := b.a.Load()
if old < 0 {
return
}
if b.a.CompareAndSwap(old, int64(minOffAddr.addr()-arenaBaseOffset)) {
return
}
}
}
// StoreMin stores addr if it's less than the current value in the
// offset address space. It has no effect if the current value is marked.
func (b *atomicOffAddr) StoreMin(addr uintptr) {
new := int64(addr - arenaBaseOffset)
for {
old := b.a.Load()
if old < new {
return
}
if b.a.CompareAndSwap(old, new) {
return
}
}
}
// StoreUnmark attempts to unmark the value in atomicOffAddr and
// replace it with newAddr. markedAddr must be a marked address
// returned by Load. This function will not store newAddr if the
// box no longer contains markedAddr.
func (b *atomicOffAddr) StoreUnmark(markedAddr, newAddr uintptr) {
b.a.CompareAndSwap(-int64(markedAddr-arenaBaseOffset), int64(newAddr-arenaBaseOffset))
}
// StoreMarked stores addr, first converting it to the offset address
// space and then negating it.
func (b *atomicOffAddr) StoreMarked(addr uintptr) {
b.a.Store(-int64(addr - arenaBaseOffset))
}
// Load returns the address in the box as a virtual address. It also
// returns if the value was marked or not.
func (b *atomicOffAddr) Load() (uintptr, bool) {
v := b.a.Load()
wasMarked := false
if v < 0 {
wasMarked = true
v = -v
}
return uintptr(v) + arenaBaseOffset, wasMarked
}
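// A sketch of the mark encoding: the box holds int64(addr-arenaBaseOffset),
// negated while marked, so Load recovers both the address and the mark from
// the sign (p here is a hypothetical replacement address):
//
//	var b atomicOffAddr
//	b.StoreMarked(addr)    // box holds -(addr - arenaBaseOffset)
//	v, marked := b.Load()  // v == addr, marked == true
//	b.StoreUnmark(addr, p) // box holds p - arenaBaseOffset, unmarked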
// addrRanges is a data structure holding a collection of ranges of
// address space.
//
// The ranges are coalesced eagerly to reduce the
// number of ranges it holds.
//
// The slice backing store for the ranges field is persistentalloc'd
// and thus there is no way to free it.
//
// addrRanges is not thread-safe.
type addrRanges struct {
// ranges is a slice of ranges sorted by base.
ranges []addrRange
// totalBytes is the total amount of address space in bytes counted by
// this addrRanges.
totalBytes uintptr
// sysStat is the stat to track allocations by this type
sysStat *sysMemStat
}
func (a *addrRanges) init(sysStat *sysMemStat) {
ranges := (*notInHeapSlice)(unsafe.Pointer(&a.ranges))
ranges.len = 0
ranges.cap = 16
ranges.array = (*notInHeap)(persistentalloc(unsafe.Sizeof(addrRange{})*uintptr(ranges.cap), goarch.PtrSize, sysStat))
a.sysStat = sysStat
a.totalBytes = 0
}
// findSucc returns the first index in a such that addr is
// less than the base of the addrRange at that index.
func (a *addrRanges) findSucc(addr uintptr) int {
base := offAddr{addr}
// Narrow down the search space via a binary search
// for large addrRanges until we have at most iterMax
// candidates left.
const iterMax = 8
bot, top := 0, len(a.ranges)
for top-bot > iterMax {
i := ((top - bot) / 2) + bot
if a.ranges[i].contains(base.addr()) {
// a.ranges[i] contains base, so
// its successor is the next index.
return i + 1
}
if base.lessThan(a.ranges[i].base) {
// In this case i might actually be
// the successor, but we can't be sure
// until we check the ones before it.
top = i
} else {
// In this case we know base is
// greater than or equal to a.ranges[i].limit-1,
// so i is definitely not the successor.
// We already checked i, so pick the next
// one.
bot = i + 1
}
}
// There are top-bot candidates left, so
// iterate over them and find the first that
// base is strictly less than.
for i := bot; i < top; i++ {
if base.lessThan(a.ranges[i].base) {
return i
}
}
return top
}
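// For example (hypothetical addresses), if a.ranges covers [0x1000, 0x2000)
// and [0x3000, 0x4000):
//
//	a.findSucc(0x0500) == 0 // before all ranges
//	a.findSucc(0x1800) == 1 // inside ranges[0]; ranges[1] is its successor
//	a.findSucc(0x2800) == 1 // in the gap; ranges[1] has the next larger base
//	a.findSucc(0x5000) == 2 // after all ranges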
// findAddrGreaterEqual returns the smallest address represented by a
// that is >= addr. Thus, if addr itself is represented by a, it
// returns addr. The second return value indicates whether such an
// address exists for addr in a. That is, if addr is larger than any
// address known to a, the second return value will be false.
func (a *addrRanges) findAddrGreaterEqual(addr uintptr) (uintptr, bool) {
i := a.findSucc(addr)
if i == 0 {
return a.ranges[0].base.addr(), true
}
if a.ranges[i-1].contains(addr) {
return addr, true
}
if i < len(a.ranges) {
return a.ranges[i].base.addr(), true
}
return 0, false
}
// contains returns true if a covers the address addr.
func (a *addrRanges) contains(addr uintptr) bool {
i := a.findSucc(addr)
if i == 0 {
return false
}
return a.ranges[i-1].contains(addr)
}
// add inserts a new address range to a.
//
// r must not overlap with any address range in a and r.size() must be > 0.
func (a *addrRanges) add(r addrRange) {
// The copies in this function are potentially expensive, but this data
// structure is meant to represent the Go heap. At worst, copying this
// would take ~160µs assuming a conservative copying rate of 25 GiB/s (the
// copy will almost never trigger a page fault) for a 1 TiB heap with 4 MiB
// arenas which is completely discontiguous. ~160µs is still a lot, but in
// practice most platforms have 64 MiB arenas (which cuts this by a factor
// of 16) and Go heaps are usually mostly contiguous, so the chance that
// an addrRanges even grows to that size is extremely low.
// An empty range has no effect on the set of addresses represented
// by a, but passing a zero-sized range is almost always a bug.
if r.size() == 0 {
print("runtime: range = {", hex(r.base.addr()), ", ", hex(r.limit.addr()), "}\n")
throw("attempted to add zero-sized address range")
}
// Because we assume r is not currently represented in a,
// findSucc gives us our insertion index.
i := a.findSucc(r.base.addr())
coalescesDown := i > 0 && a.ranges[i-1].limit.equal(r.base)
coalescesUp := i < len(a.ranges) && r.limit.equal(a.ranges[i].base)
if coalescesUp && coalescesDown {
// We have neighbors and they both border us.
// Merge a.ranges[i-1], r, and a.ranges[i] together into a.ranges[i-1].
a.ranges[i-1].limit = a.ranges[i].limit
// Delete a.ranges[i].
copy(a.ranges[i:], a.ranges[i+1:])
a.ranges = a.ranges[:len(a.ranges)-1]
} else if coalescesDown {
// We have a neighbor at a lower address only and it borders us.
// Merge the new space into a.ranges[i-1].
a.ranges[i-1].limit = r.limit
} else if coalescesUp {
// We have a neighbor at a higher address only and it borders us.
// Merge the new space into a.ranges[i].
a.ranges[i].base = r.base
} else {
// We may or may not have neighbors which don't border us.
// Add the new range.
if len(a.ranges)+1 > cap(a.ranges) {
// Grow the array. Note that this leaks the old array, but since
// we're doubling we have at most 2x waste. For a 1 TiB heap and
// 4 MiB arenas which are all discontiguous (both very conservative
// assumptions), this would waste at most 4 MiB of memory.
oldRanges := a.ranges
ranges := (*notInHeapSlice)(unsafe.Pointer(&a.ranges))
ranges.len = len(oldRanges) + 1
ranges.cap = cap(oldRanges) * 2
ranges.array = (*notInHeap)(persistentalloc(unsafe.Sizeof(addrRange{})*uintptr(ranges.cap), goarch.PtrSize, a.sysStat))
// Copy in the old array, but make space for the new range.
copy(a.ranges[:i], oldRanges[:i])
copy(a.ranges[i+1:], oldRanges[i:])
} else {
a.ranges = a.ranges[:len(a.ranges)+1]
copy(a.ranges[i+1:], a.ranges[i:])
}
a.ranges[i] = r
}
a.totalBytes += r.size()
}
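// For example (hypothetical addresses), if a holds [0x1000, 0x2000) and
// [0x3000, 0x4000):
//
//	a.add(makeAddrRange(0x2000, 0x3000)) // borders both neighbors; a coalesces to [0x1000, 0x4000)
//	a.add(makeAddrRange(0x5000, 0x6000)) // borders nothing; appended as a new range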
// removeLast removes and returns the highest-addressed contiguous range
// of a, or the last nBytes of that range, whichever is smaller. If a is
// empty, it returns an empty range.
func (a *addrRanges) removeLast(nBytes uintptr) addrRange {
if len(a.ranges) == 0 {
return addrRange{}
}
r := a.ranges[len(a.ranges)-1]
size := r.size()
if size > nBytes {
newEnd := r.limit.sub(nBytes)
a.ranges[len(a.ranges)-1].limit = newEnd
a.totalBytes -= nBytes
return addrRange{newEnd, r.limit}
}
a.ranges = a.ranges[:len(a.ranges)-1]
a.totalBytes -= size
return r
}
// removeGreaterEqual removes the ranges of a which are above addr, and additionally
// splits any range containing addr.
func (a *addrRanges) removeGreaterEqual(addr uintptr) {
pivot := a.findSucc(addr)
if pivot == 0 {
// addr is before all ranges in a.
a.totalBytes = 0
a.ranges = a.ranges[:0]
return
}
removed := uintptr(0)
for _, r := range a.ranges[pivot:] {
removed += r.size()
}
if r := a.ranges[pivot-1]; r.contains(addr) {
removed += r.size()
r = r.removeGreaterEqual(addr)
if r.size() == 0 {
pivot--
} else {
removed -= r.size()
a.ranges[pivot-1] = r
}
}
a.ranges = a.ranges[:pivot]
a.totalBytes -= removed
}
// cloneInto makes a deep clone of a's state into b, re-using
// b's ranges if able.
func (a *addrRanges) cloneInto(b *addrRanges) {
if len(a.ranges) > cap(b.ranges) {
// Grow the array.
ranges := (*notInHeapSlice)(unsafe.Pointer(&b.ranges))
ranges.len = 0
ranges.cap = cap(a.ranges)
ranges.array = (*notInHeap)(persistentalloc(unsafe.Sizeof(addrRange{})*uintptr(ranges.cap), goarch.PtrSize, b.sysStat))
}
b.ranges = b.ranges[:len(a.ranges)]
b.totalBytes = a.totalBytes
copy(b.ranges, a.ranges)
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !msan
// Dummy MSan support API, used when not built with -msan.
package runtime
import (
"unsafe"
)
const msanenabled = false
// Because msanenabled is false, none of these functions should be called.
func msanread(addr unsafe.Pointer, sz uintptr) { throw("msan") }
func msanwrite(addr unsafe.Pointer, sz uintptr) { throw("msan") }
func msanmalloc(addr unsafe.Pointer, sz uintptr) { throw("msan") }
func msanfree(addr unsafe.Pointer, sz uintptr) { throw("msan") }
func msanmove(dst, src unsafe.Pointer, sz uintptr) { throw("msan") }
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Malloc small size classes.
//
// See malloc.go for overview.
// See also mksizeclasses.go for how we decide what size classes to use.
package runtime
// roundupsize returns the size of the memory block that mallocgc will allocate if you ask for size bytes.
func roundupsize(size uintptr) uintptr {
if size < _MaxSmallSize {
if size <= smallSizeMax-8 {
return uintptr(class_to_size[size_to_class8[divRoundUp(size, smallSizeDiv)]])
} else {
return uintptr(class_to_size[size_to_class128[divRoundUp(size-smallSizeMax, largeSizeDiv)]])
}
}
if size+_PageSize < size {
return size
}
return alignUp(size, _PageSize)
}
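// For example, with the default 64-bit size class tables generated by
// mksizeclasses.go:
//
//	roundupsize(33) == 48 // rounds up to the 48-byte class
//	roundupsize(64) == 64 // exact class sizes come back unchanged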
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"internal/cpu"
"internal/goarch"
"runtime/internal/atomic"
"unsafe"
)
// A spanSet is a set of *mspans.
//
// spanSet is safe for concurrent push and pop operations.
type spanSet struct {
// A spanSet is a two-level data structure consisting of a
// growable spine that points to fixed-sized blocks. The spine
// can be accessed without locks, but adding a block or
// growing it requires taking the spine lock.
//
// Because each mspan covers at least 8K of heap and takes at
// most 8 bytes in the spanSet, the growth of the spine is
// quite limited.
//
// The spine and all blocks are allocated off-heap, which
// allows this to be used in the memory manager and avoids the
// need for write barriers on all of these. spanSetBlocks are
// managed in a pool, though never freed back to the operating
// system. We never release spine memory because there could be
// concurrent lock-free access and we're likely to reuse it
// anyway. (In principle, we could do this during STW.)
spineLock mutex
spine atomicSpanSetSpinePointer // *[N]atomic.Pointer[spanSetBlock]
spineLen atomic.Uintptr // Spine array length
spineCap uintptr // Spine array cap, accessed under spineLock
// index is the head and tail of the spanSet in a single field.
// The head and the tail both represent an index into the logical
// concatenation of all blocks, with the head always behind or
// equal to the tail (head == tail indicates an empty set). This field is
// always accessed atomically.
//
// The head and the tail are only 32 bits wide, which means we
// can only support up to 2^32 pushes before a reset. If every
// span in the heap were stored in this set, and each span were
// the minimum size (1 runtime page, 8 KiB), then roughly the
// smallest heap which would be unrepresentable is 32 TiB in size.
index atomicHeadTailIndex
}
const (
spanSetBlockEntries = 512 // 4KB on 64-bit
spanSetInitSpineCap = 256 // Enough for 1GB heap on 64-bit
)
type spanSetBlock struct {
// Free spanSetBlocks are managed via a lock-free stack.
lfnode
// popped is the number of pop operations that have occurred on
// this block. This number is used to help determine when a block
// may be safely recycled.
popped atomic.Uint32
// spans is the set of spans in this block.
spans [spanSetBlockEntries]atomicMSpanPointer
}
// push adds span s to buffer b. push is safe to call concurrently
// with other push and pop operations.
func (b *spanSet) push(s *mspan) {
// Obtain our slot.
cursor := uintptr(b.index.incTail().tail() - 1)
top, bottom := cursor/spanSetBlockEntries, cursor%spanSetBlockEntries
// Do we need to add a block?
spineLen := b.spineLen.Load()
var block *spanSetBlock
retry:
if top < spineLen {
block = b.spine.Load().lookup(top).Load()
} else {
// Add a new block to the spine, potentially growing
// the spine.
lock(&b.spineLock)
// spineLen cannot change until we release the lock,
// but may have changed while we were waiting.
spineLen = b.spineLen.Load()
if top < spineLen {
unlock(&b.spineLock)
goto retry
}
spine := b.spine.Load()
if spineLen == b.spineCap {
// Grow the spine.
newCap := b.spineCap * 2
if newCap == 0 {
newCap = spanSetInitSpineCap
}
newSpine := persistentalloc(newCap*goarch.PtrSize, cpu.CacheLineSize, &memstats.gcMiscSys)
if b.spineCap != 0 {
// Blocks are allocated off-heap, so
// no write barriers.
memmove(newSpine, spine.p, b.spineCap*goarch.PtrSize)
}
spine = spanSetSpinePointer{newSpine}
// Spine is allocated off-heap, so no write barrier.
b.spine.StoreNoWB(spine)
b.spineCap = newCap
// We can't immediately free the old spine
// since a concurrent push with a lower index
// could still be reading from it. We let it
// leak because even a 1TB heap would waste
// less than 2MB of memory on old spines. If
// this is a problem, we could free old spines
// during STW.
}
// Allocate a new block from the pool.
block = spanSetBlockPool.alloc()
// Add it to the spine.
// Blocks are allocated off-heap, so no write barrier.
spine.lookup(top).StoreNoWB(block)
b.spineLen.Store(spineLen + 1)
unlock(&b.spineLock)
}
// We have a block. Insert the span atomically, since there may be
// concurrent readers via the block API.
block.spans[bottom].StoreNoWB(s)
}
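// For example, with spanSetBlockEntries = 512, the push that claims cursor
// 513 computes top = 513/512 = 1 and bottom = 513%512 = 1, i.e. the second
// slot of the second block.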
// pop removes and returns a span from buffer b, or nil if b is empty.
// pop is safe to call concurrently with other pop and push operations.
func (b *spanSet) pop() *mspan {
var head, tail uint32
claimLoop:
for {
headtail := b.index.load()
head, tail = headtail.split()
if head >= tail {
// The buf is empty, as far as we can tell.
return nil
}
// Check if the head position we want to claim is actually
// backed by a block.
spineLen := b.spineLen.Load()
if spineLen <= uintptr(head)/spanSetBlockEntries {
// We're racing with a spine growth and the allocation of
// a new block (and maybe a new spine!), and trying to grab
// the span at the index which is currently being pushed.
// Instead of spinning, let's just notify the caller that
// there's nothing currently here. Spinning on this is
// almost definitely not worth it.
return nil
}
// Try to claim the current head by CASing in an updated head.
// This may fail transiently due to a push which modifies the
// tail, so keep trying while the head isn't changing.
want := head
for want == head {
if b.index.cas(headtail, makeHeadTailIndex(want+1, tail)) {
break claimLoop
}
headtail = b.index.load()
head, tail = headtail.split()
}
// We failed to claim the spot we were after and the head changed,
// meaning a popper got ahead of us. Try again from the top because
// the buf may not be empty.
}
top, bottom := head/spanSetBlockEntries, head%spanSetBlockEntries
// We may be reading a stale spine pointer, but because the length
// grows monotonically and we've already verified it, we'll definitely
// be reading from a valid block.
blockp := b.spine.Load().lookup(uintptr(top))
// Given that the spine length is correct, we know we will never
// see a nil block here, since the length is always updated after
// the block is set.
block := blockp.Load()
s := block.spans[bottom].Load()
for s == nil {
// We raced with the span actually being set, but given that we
// know a block for this span exists, the race window here is
// extremely small. Try again.
s = block.spans[bottom].Load()
}
// Clear the pointer. This isn't strictly necessary, but defensively
// avoids accidentally re-using blocks which could lead to memory
// corruption. This way, we'll get a nil pointer access instead.
block.spans[bottom].StoreNoWB(nil)
// Increase the popped count. If we are the last possible popper
// in the block (note that bottom need not equal spanSetBlockEntries-1
// due to races) then it's our responsibility to free the block.
//
// If we increment popped to spanSetBlockEntries, we can be sure that
// we're the last popper for this block, and it's thus safe to free it.
// Every other popper must have crossed this barrier (and thus finished
// popping its corresponding mspan) by the time we get here. Because
// we're the last popper, we also don't have to worry about concurrent
// pushers (there can't be any). Note that we may not be the popper
// which claimed the last slot in the block, we're just the last one
// to finish popping.
if block.popped.Add(1) == spanSetBlockEntries {
// Clear the block's pointer.
blockp.StoreNoWB(nil)
// Return the block to the block pool.
spanSetBlockPool.free(block)
}
return s
}
// reset resets a spanSet, which must be empty. It also cleans up
// any leftover blocks.
//
// Throws if the buf is not empty.
//
// reset may not be called concurrently with any other operations
// on the span set.
func (b *spanSet) reset() {
head, tail := b.index.load().split()
if head < tail {
print("head = ", head, ", tail = ", tail, "\n")
throw("attempt to clear non-empty span set")
}
top := head / spanSetBlockEntries
if uintptr(top) < b.spineLen.Load() {
// If the head catches up to the tail and the set is empty,
// we may not clean up the block containing the head and tail
// since it may be pushed into again. In order to avoid leaking
// memory since we're going to reset the head and tail, clean
// up such a block now, if it exists.
blockp := b.spine.Load().lookup(uintptr(top))
block := blockp.Load()
if block != nil {
// Check the popped value.
if block.popped.Load() == 0 {
// popped should never be zero because that means we have
// pushed at least one value but not yet popped if this
// block pointer is not nil.
throw("span set block with unpopped elements found in reset")
}
if block.popped.Load() == spanSetBlockEntries {
// popped should also never be equal to spanSetBlockEntries
// because the last popper should have made the block pointer
// in this slot nil.
throw("fully empty unfreed span set block found in reset")
}
// Clear the pointer to the block.
blockp.StoreNoWB(nil)
// Return the block to the block pool.
spanSetBlockPool.free(block)
}
}
b.index.reset()
b.spineLen.Store(0)
}
// atomicSpanSetSpinePointer is an atomically-accessed spanSetSpinePointer.
//
// It has the same semantics as atomic.UnsafePointer.
type atomicSpanSetSpinePointer struct {
a atomic.UnsafePointer
}
// Loads the spanSetSpinePointer and returns it.
//
// It has the same semantics as atomic.UnsafePointer.
func (s *atomicSpanSetSpinePointer) Load() spanSetSpinePointer {
return spanSetSpinePointer{s.a.Load()}
}
// Stores the spanSetSpinePointer.
//
// It has the same semantics as atomic.UnsafePointer.
func (s *atomicSpanSetSpinePointer) StoreNoWB(p spanSetSpinePointer) {
s.a.StoreNoWB(p.p)
}
// spanSetSpinePointer represents a pointer to a contiguous block of atomic.Pointer[spanSetBlock].
type spanSetSpinePointer struct {
p unsafe.Pointer
}
// lookup returns &s[idx].
func (s spanSetSpinePointer) lookup(idx uintptr) *atomic.Pointer[spanSetBlock] {
return (*atomic.Pointer[spanSetBlock])(add(unsafe.Pointer(s.p), goarch.PtrSize*idx))
}
// spanSetBlockPool is a global pool of spanSetBlocks.
var spanSetBlockPool spanSetBlockAlloc
// spanSetBlockAlloc represents a concurrent pool of spanSetBlocks.
type spanSetBlockAlloc struct {
stack lfstack
}
// alloc tries to grab a spanSetBlock out of the pool, and if it fails
// persistentallocs a new one and returns it.
func (p *spanSetBlockAlloc) alloc() *spanSetBlock {
if s := (*spanSetBlock)(p.stack.pop()); s != nil {
return s
}
return (*spanSetBlock)(persistentalloc(unsafe.Sizeof(spanSetBlock{}), cpu.CacheLineSize, &memstats.gcMiscSys))
}
// free returns a spanSetBlock back to the pool.
func (p *spanSetBlockAlloc) free(block *spanSetBlock) {
block.popped.Store(0)
p.stack.push(&block.lfnode)
}
// headTailIndex packs a combined 32-bit head and 32-bit tail
// of a queue into a single 64-bit value.
type headTailIndex uint64
// makeHeadTailIndex creates a headTailIndex value from a separate
// head and tail.
func makeHeadTailIndex(head, tail uint32) headTailIndex {
return headTailIndex(uint64(head)<<32 | uint64(tail))
}
// head returns the head of a headTailIndex value.
func (h headTailIndex) head() uint32 {
return uint32(h >> 32)
}
// tail returns the tail of a headTailIndex value.
func (h headTailIndex) tail() uint32 {
return uint32(h)
}
// split splits the headTailIndex value into its parts.
func (h headTailIndex) split() (head uint32, tail uint32) {
return h.head(), h.tail()
}
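// For example, the head occupies the high 32 bits and the tail the low 32:
//
//	ht := makeHeadTailIndex(1, 2) // uint64(ht) == 0x0000000100000002
//	ht.head()                     // 1
//	ht.tail()                     // 2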
// atomicHeadTailIndex is an atomically-accessed headTailIndex.
type atomicHeadTailIndex struct {
u atomic.Uint64
}
// load atomically reads a headTailIndex value.
func (h *atomicHeadTailIndex) load() headTailIndex {
return headTailIndex(h.u.Load())
}
// cas atomically compares-and-swaps a headTailIndex value.
func (h *atomicHeadTailIndex) cas(old, new headTailIndex) bool {
return h.u.CompareAndSwap(uint64(old), uint64(new))
}
// incHead atomically increments the head of a headTailIndex.
func (h *atomicHeadTailIndex) incHead() headTailIndex {
return headTailIndex(h.u.Add(1 << 32))
}
// decHead atomically decrements the head of a headTailIndex.
func (h *atomicHeadTailIndex) decHead() headTailIndex {
return headTailIndex(h.u.Add(-(1 << 32)))
}
// incTail atomically increments the tail of a headTailIndex.
func (h *atomicHeadTailIndex) incTail() headTailIndex {
ht := headTailIndex(h.u.Add(1))
// Check for overflow.
if ht.tail() == 0 {
print("runtime: head = ", ht.head(), ", tail = ", ht.tail(), "\n")
throw("headTailIndex overflow")
}
return ht
}
// reset clears the headTailIndex to (0, 0).
func (h *atomicHeadTailIndex) reset() {
h.u.Store(0)
}
// atomicMSpanPointer is an atomic.Pointer[mspan]. Can't use generics because it's NotInHeap.
type atomicMSpanPointer struct {
p atomic.UnsafePointer
}
// Load returns the *mspan.
func (p *atomicMSpanPointer) Load() *mspan {
return (*mspan)(p.p.Load())
}
// StoreNoWB stores an *mspan without a write barrier.
func (p *atomicMSpanPointer) StoreNoWB(s *mspan) {
p.p.StoreNoWB(unsafe.Pointer(s))
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Memory statistics
package runtime
import (
"runtime/internal/atomic"
"unsafe"
)
type mstats struct {
// Statistics about malloc heap.
heapStats consistentHeapStats
// Statistics about stacks.
stacks_sys sysMemStat // only counts newosproc0 stack in mstats; differs from MemStats.StackSys
// Statistics about allocation of low-level fixed-size structures.
mspan_sys sysMemStat
mcache_sys sysMemStat
buckhash_sys sysMemStat // profiling bucket hash table
// Statistics about GC overhead.
gcMiscSys sysMemStat // updated atomically or during STW
// Miscellaneous statistics.
other_sys sysMemStat // updated atomically or during STW
// Statistics about the garbage collector.
// Protected by mheap or stopping the world during GC.
last_gc_unix uint64 // last gc (in unix time)
pause_total_ns uint64
pause_ns [256]uint64 // circular buffer of recent gc pause lengths
pause_end [256]uint64 // circular buffer of recent gc end times (nanoseconds since 1970)
numgc uint32
numforcedgc uint32 // number of user-forced GCs
gc_cpu_fraction float64 // fraction of CPU time used by GC
last_gc_nanotime uint64 // last gc (monotonic time)
lastHeapInUse uint64 // heapInUse at mark termination of the previous GC
enablegc bool
// gcPauseDist represents the distribution of all GC-related
// application pauses in the runtime.
//
// Each individual pause is counted separately, unlike pause_ns.
gcPauseDist timeHistogram
}
var memstats mstats
// A MemStats records statistics about the memory allocator.
type MemStats struct {
// General statistics.
// Alloc is bytes of allocated heap objects.
//
// This is the same as HeapAlloc (see below).
Alloc uint64
// TotalAlloc is cumulative bytes allocated for heap objects.
//
// TotalAlloc increases as heap objects are allocated, but
// unlike Alloc and HeapAlloc, it does not decrease when
// objects are freed.
TotalAlloc uint64
// Sys is the total bytes of memory obtained from the OS.
//
// Sys is the sum of the XSys fields below. Sys measures the
// virtual address space reserved by the Go runtime for the
// heap, stacks, and other internal data structures. It's
// likely that not all of the virtual address space is backed
// by physical memory at any given moment, though in general
// it all was at some point.
Sys uint64
// Lookups is the number of pointer lookups performed by the
// runtime.
//
// This is primarily useful for debugging runtime internals.
Lookups uint64
// Mallocs is the cumulative count of heap objects allocated.
// The number of live objects is Mallocs - Frees.
Mallocs uint64
// Frees is the cumulative count of heap objects freed.
Frees uint64
// Heap memory statistics.
//
// Interpreting the heap statistics requires some knowledge of
// how Go organizes memory. Go divides the virtual address
// space of the heap into "spans", which are contiguous
// regions of memory 8K or larger. A span may be in one of
// three states:
//
// An "idle" span contains no objects or other data. The
// physical memory backing an idle span can be released back
// to the OS (but the virtual address space never is), or it
// can be converted into an "in use" or "stack" span.
//
// An "in use" span contains at least one heap object and may
// have free space available to allocate more heap objects.
//
// A "stack" span is used for goroutine stacks. Stack spans
// are not considered part of the heap. A span can change
// between heap and stack memory; it is never used for both
// simultaneously.
// HeapAlloc is bytes of allocated heap objects.
//
// "Allocated" heap objects include all reachable objects, as
// well as unreachable objects that the garbage collector has
// not yet freed. Specifically, HeapAlloc increases as heap
// objects are allocated and decreases as the heap is swept
// and unreachable objects are freed. Sweeping occurs
// incrementally between GC cycles, so these two processes
// occur simultaneously, and as a result HeapAlloc tends to
// change smoothly (in contrast with the sawtooth that is
// typical of stop-the-world garbage collectors).
HeapAlloc uint64
// HeapSys is bytes of heap memory obtained from the OS.
//
// HeapSys measures the amount of virtual address space
// reserved for the heap. This includes virtual address space
// that has been reserved but not yet used, which consumes no
// physical memory, but tends to be small, as well as virtual
// address space for which the physical memory has been
// returned to the OS after it became unused (see HeapReleased
// for a measure of the latter).
//
// HeapSys estimates the largest size the heap has had.
HeapSys uint64
// HeapIdle is bytes in idle (unused) spans.
//
// Idle spans have no objects in them. These spans could be
// (and may already have been) returned to the OS, or they can
// be reused for heap allocations, or they can be reused as
// stack memory.
//
// HeapIdle minus HeapReleased estimates the amount of memory
// that could be returned to the OS, but is being retained by
// the runtime so it can grow the heap without requesting more
// memory from the OS. If this difference is significantly
// larger than the heap size, it indicates there was a recent
// transient spike in live heap size.
HeapIdle uint64
// HeapInuse is bytes in in-use spans.
//
// In-use spans have at least one object in them. These spans
// can only be used for other objects of roughly the same
// size.
//
// HeapInuse minus HeapAlloc estimates the amount of memory
// that has been dedicated to particular size classes, but is
// not currently being used. This is an upper bound on
// fragmentation, but in general this memory can be reused
// efficiently.
HeapInuse uint64
// HeapReleased is bytes of physical memory returned to the OS.
//
// This counts heap memory from idle spans that was returned
// to the OS and has not yet been reacquired for the heap.
HeapReleased uint64
// HeapObjects is the number of allocated heap objects.
//
// Like HeapAlloc, this increases as objects are allocated and
// decreases as the heap is swept and unreachable objects are
// freed.
HeapObjects uint64
// Stack memory statistics.
//
// Stacks are not considered part of the heap, but the runtime
// can reuse a span of heap memory for stack memory, and
// vice-versa.
// StackInuse is bytes in stack spans.
//
// In-use stack spans have at least one stack in them. These
// spans can only be used for other stacks of the same size.
//
// There is no StackIdle because unused stack spans are
// returned to the heap (and hence counted toward HeapIdle).
StackInuse uint64
// StackSys is bytes of stack memory obtained from the OS.
//
// StackSys is StackInuse, plus any memory obtained directly
// from the OS for OS thread stacks (which should be minimal).
StackSys uint64
// Off-heap memory statistics.
//
// The following statistics measure runtime-internal
// structures that are not allocated from heap memory (usually
// because they are part of implementing the heap). Unlike
// heap or stack memory, any memory allocated to these
// structures is dedicated to these structures.
//
// These are primarily useful for debugging runtime memory
// overheads.
// MSpanInuse is bytes of allocated mspan structures.
MSpanInuse uint64
// MSpanSys is bytes of memory obtained from the OS for mspan
// structures.
MSpanSys uint64
// MCacheInuse is bytes of allocated mcache structures.
MCacheInuse uint64
// MCacheSys is bytes of memory obtained from the OS for
// mcache structures.
MCacheSys uint64
// BuckHashSys is bytes of memory in profiling bucket hash tables.
BuckHashSys uint64
// GCSys is bytes of memory in garbage collection metadata.
GCSys uint64
// OtherSys is bytes of memory in miscellaneous off-heap
// runtime allocations.
OtherSys uint64
// Garbage collector statistics.
// NextGC is the target heap size of the next GC cycle.
//
// The garbage collector's goal is to keep HeapAlloc ≤ NextGC.
// At the end of each GC cycle, the target for the next cycle
// is computed based on the amount of reachable data and the
// value of GOGC.
NextGC uint64
// LastGC is the time the last garbage collection finished, as
// nanoseconds since 1970 (the UNIX epoch).
LastGC uint64
// PauseTotalNs is the cumulative nanoseconds in GC
// stop-the-world pauses since the program started.
//
// During a stop-the-world pause, all goroutines are paused
// and only the garbage collector can run.
PauseTotalNs uint64
// PauseNs is a circular buffer of recent GC stop-the-world
// pause times in nanoseconds.
//
// The most recent pause is at PauseNs[(NumGC+255)%256]. In
// general, PauseNs[N%256] records the time paused in the most
// recent N%256th GC cycle. There may be multiple pauses per
// GC cycle; this is the sum of all pauses during a cycle.
PauseNs [256]uint64
// PauseEnd is a circular buffer of recent GC pause end times,
// as nanoseconds since 1970 (the UNIX epoch).
//
// This buffer is filled the same way as PauseNs. There may be
// multiple pauses per GC cycle; this records the end of the
// last pause in a cycle.
PauseEnd [256]uint64
// NumGC is the number of completed GC cycles.
NumGC uint32
// NumForcedGC is the number of GC cycles that were forced by
// the application calling the GC function.
NumForcedGC uint32
// GCCPUFraction is the fraction of this program's available
// CPU time used by the GC since the program started.
//
// GCCPUFraction is expressed as a number between 0 and 1,
// where 0 means GC has consumed none of this program's CPU. A
// program's available CPU time is defined as the integral of
// GOMAXPROCS since the program started. That is, if
// GOMAXPROCS is 2 and a program has been running for 10
// seconds, its "available CPU" is 20 seconds. GCCPUFraction
// does not include CPU time used for write barrier activity.
//
// This is the same as the fraction of CPU reported by
// GODEBUG=gctrace=1.
GCCPUFraction float64
// EnableGC indicates that GC is enabled. It is always true,
// even if GOGC=off.
EnableGC bool
// DebugGC is currently unused.
DebugGC bool
// BySize reports per-size class allocation statistics.
//
// BySize[N] gives statistics for allocations of size S where
// BySize[N-1].Size < S ≤ BySize[N].Size.
//
// This does not report allocations larger than BySize[60].Size.
BySize [61]struct {
// Size is the maximum byte size of an object in this
// size class.
Size uint32
// Mallocs is the cumulative count of heap objects
// allocated in this size class. The cumulative bytes
// of allocation is Size*Mallocs. The number of live
// objects in this size class is Mallocs - Frees.
Mallocs uint64
// Frees is the cumulative count of heap objects freed
// in this size class.
Frees uint64
}
}
func init() {
if offset := unsafe.Offsetof(memstats.heapStats); offset%8 != 0 {
println(offset)
throw("memstats.heapStats not aligned to 8 bytes")
}
// Ensure the size of heapStatsDelta causes adjacent fields/slots (e.g.
// [3]heapStatsDelta) to be 8-byte aligned.
if size := unsafe.Sizeof(heapStatsDelta{}); size%8 != 0 {
println(size)
throw("heapStatsDelta not a multiple of 8 bytes in size")
}
}
// ReadMemStats populates m with memory allocator statistics.
//
// The returned memory allocator statistics are up to date as of the
// call to ReadMemStats. This is in contrast with a heap profile,
// which is a snapshot as of the most recently completed garbage
// collection cycle.
func ReadMemStats(m *MemStats) {
stopTheWorld("read mem stats")
systemstack(func() {
readmemstats_m(m)
})
startTheWorld()
}
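// A sketch of typical use; because ReadMemStats stops the world, it should
// not be called at high frequency in latency-sensitive programs:
//
//	var m MemStats
//	ReadMemStats(&m)
//	println("heap alloc bytes:", m.HeapAlloc)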
// readmemstats_m populates stats for internal runtime values.
//
// The world must be stopped.
func readmemstats_m(stats *MemStats) {
assertWorldStopped()
// Flush mcaches to mcentral before doing anything else.
//
// Flushing to the mcentral may in general cause stats to
// change as mcentral data structures are manipulated.
systemstack(flushallmcaches)
// Calculate memory allocator stats.
// During program execution we only count number of frees and amount of freed memory.
// Current number of alive objects in the heap and amount of alive heap memory
// are calculated by scanning all spans.
// Total number of mallocs is calculated as number of frees plus number of alive objects.
// Similarly, total amount of allocated memory is calculated as amount of freed memory
// plus amount of alive heap memory.
// Collect consistent stats, which are the source-of-truth in some cases.
var consStats heapStatsDelta
memstats.heapStats.unsafeRead(&consStats)
// Collect large allocation stats.
totalAlloc := consStats.largeAlloc
nMalloc := consStats.largeAllocCount
totalFree := consStats.largeFree
nFree := consStats.largeFreeCount
// Collect per-sizeclass stats.
var bySize [_NumSizeClasses]struct {
Size uint32
Mallocs uint64
Frees uint64
}
for i := range bySize {
bySize[i].Size = uint32(class_to_size[i])
// Malloc stats.
a := consStats.smallAllocCount[i]
totalAlloc += a * uint64(class_to_size[i])
nMalloc += a
bySize[i].Mallocs = a
// Free stats.
f := consStats.smallFreeCount[i]
totalFree += f * uint64(class_to_size[i])
nFree += f
bySize[i].Frees = f
}
// Account for tiny allocations.
// For historical reasons, MemStats includes tiny allocations
// in both the total free and total alloc count. This double-counts
// memory in some sense because their tiny allocation block is also
// counted. Tracking the lifetime of individual tiny allocations is
// currently not done because it would be too expensive.
nFree += consStats.tinyAllocCount
nMalloc += consStats.tinyAllocCount
// Calculate derived stats.
stackInUse := uint64(consStats.inStacks)
gcWorkBufInUse := uint64(consStats.inWorkBufs)
gcProgPtrScalarBitsInUse := uint64(consStats.inPtrScalarBits)
totalMapped := gcController.heapInUse.load() + gcController.heapFree.load() + gcController.heapReleased.load() +
memstats.stacks_sys.load() + memstats.mspan_sys.load() + memstats.mcache_sys.load() +
memstats.buckhash_sys.load() + memstats.gcMiscSys.load() + memstats.other_sys.load() +
stackInUse + gcWorkBufInUse + gcProgPtrScalarBitsInUse
heapGoal := gcController.heapGoal()
// The world is stopped, so the consistent stats (after aggregation)
// should be identical to some combination of memstats. In particular:
//
// * memstats.heapInUse == inHeap
// * memstats.heapReleased == released
// * memstats.heapInUse + memstats.heapFree == committed - inStacks - inWorkBufs - inPtrScalarBits
// * memstats.totalAlloc == totalAlloc
// * memstats.totalFree == totalFree
//
// Check if that's actually true.
//
// TODO(mknyszek): Maybe don't throw here. It would be bad if a
// bug in otherwise benign accounting caused the whole application
// to crash.
if gcController.heapInUse.load() != uint64(consStats.inHeap) {
print("runtime: heapInUse=", gcController.heapInUse.load(), "\n")
print("runtime: consistent value=", consStats.inHeap, "\n")
throw("heapInUse and consistent stats are not equal")
}
if gcController.heapReleased.load() != uint64(consStats.released) {
print("runtime: heapReleased=", gcController.heapReleased.load(), "\n")
print("runtime: consistent value=", consStats.released, "\n")
throw("heapReleased and consistent stats are not equal")
}
heapRetained := gcController.heapInUse.load() + gcController.heapFree.load()
consRetained := uint64(consStats.committed - consStats.inStacks - consStats.inWorkBufs - consStats.inPtrScalarBits)
if heapRetained != consRetained {
print("runtime: global value=", heapRetained, "\n")
print("runtime: consistent value=", consRetained, "\n")
throw("measures of the retained heap are not equal")
}
if gcController.totalAlloc.Load() != totalAlloc {
print("runtime: totalAlloc=", gcController.totalAlloc.Load(), "\n")
print("runtime: consistent value=", totalAlloc, "\n")
throw("totalAlloc and consistent stats are not equal")
}
if gcController.totalFree.Load() != totalFree {
print("runtime: totalFree=", gcController.totalFree.Load(), "\n")
print("runtime: consistent value=", totalFree, "\n")
throw("totalFree and consistent stats are not equal")
}
// Also check that mappedReady lines up with totalMapped - released.
// This isn't really the same type of "make sure consistent stats line up" situation,
// but this is an opportune time to check.
if gcController.mappedReady.Load() != totalMapped-uint64(consStats.released) {
print("runtime: mappedReady=", gcController.mappedReady.Load(), "\n")
print("runtime: totalMapped=", totalMapped, "\n")
print("runtime: released=", uint64(consStats.released), "\n")
print("runtime: totalMapped-released=", totalMapped-uint64(consStats.released), "\n")
throw("mappedReady and other memstats are not equal")
}
// We've calculated all the values we need. Now, populate stats.
stats.Alloc = totalAlloc - totalFree
stats.TotalAlloc = totalAlloc
stats.Sys = totalMapped
stats.Mallocs = nMalloc
stats.Frees = nFree
stats.HeapAlloc = totalAlloc - totalFree
stats.HeapSys = gcController.heapInUse.load() + gcController.heapFree.load() + gcController.heapReleased.load()
// By definition, HeapIdle is memory that was mapped
// for the heap but is not currently used to hold heap
// objects. It also specifically is memory that can be
// used for other purposes, like stacks, but this memory
// is subtracted out of HeapSys before it makes that
// transition. Put another way:
//
// HeapSys = bytes allocated from the OS for the heap - bytes ultimately used for non-heap purposes
// HeapIdle = bytes allocated from the OS for the heap - bytes ultimately used for any purpose
//
// or
//
// HeapSys = sys - stacks_inuse - gcWorkBufInUse - gcProgPtrScalarBitsInUse
// HeapIdle = sys - stacks_inuse - gcWorkBufInUse - gcProgPtrScalarBitsInUse - heapInUse
//
// => HeapIdle = HeapSys - heapInUse = heapFree + heapReleased
stats.HeapIdle = gcController.heapFree.load() + gcController.heapReleased.load()
stats.HeapInuse = gcController.heapInUse.load()
stats.HeapReleased = gcController.heapReleased.load()
stats.HeapObjects = nMalloc - nFree
stats.StackInuse = stackInUse
// memstats.stacks_sys is only memory mapped directly for OS stacks.
// Add in heap-allocated stack memory for user consumption.
stats.StackSys = stackInUse + memstats.stacks_sys.load()
stats.MSpanInuse = uint64(mheap_.spanalloc.inuse)
stats.MSpanSys = memstats.mspan_sys.load()
stats.MCacheInuse = uint64(mheap_.cachealloc.inuse)
stats.MCacheSys = memstats.mcache_sys.load()
stats.BuckHashSys = memstats.buckhash_sys.load()
// MemStats defines GCSys as an aggregate of all memory related
// to the memory management system, but we track this memory
// at a more granular level in the runtime.
stats.GCSys = memstats.gcMiscSys.load() + gcWorkBufInUse + gcProgPtrScalarBitsInUse
stats.OtherSys = memstats.other_sys.load()
stats.NextGC = heapGoal
stats.LastGC = memstats.last_gc_unix
stats.PauseTotalNs = memstats.pause_total_ns
stats.PauseNs = memstats.pause_ns
stats.PauseEnd = memstats.pause_end
stats.NumGC = memstats.numgc
stats.NumForcedGC = memstats.numforcedgc
stats.GCCPUFraction = memstats.gc_cpu_fraction
stats.EnableGC = true
// stats.BySize and bySize might not match in length.
// That's OK; stats.BySize cannot change for backwards
// compatibility reasons. copy copies the minimum number
// of values between the two of them.
copy(stats.BySize[:], bySize[:])
}
//go:linkname readGCStats runtime/debug.readGCStats
func readGCStats(pauses *[]uint64) {
systemstack(func() {
readGCStats_m(pauses)
})
}
// readGCStats_m must be called on the system stack because it acquires the heap
// lock. See mheap for details.
//
//go:systemstack
func readGCStats_m(pauses *[]uint64) {
p := *pauses
// Calling code in runtime/debug should make the slice large enough.
if cap(p) < len(memstats.pause_ns)+3 {
throw("short slice passed to readGCStats")
}
// Pass back: pauses, pause ends, last gc (absolute time), number of gc, total pause ns.
lock(&mheap_.lock)
n := memstats.numgc
if n > uint32(len(memstats.pause_ns)) {
n = uint32(len(memstats.pause_ns))
}
// The pause buffer is circular. The most recent pause is at
// pause_ns[(numgc-1)%len(pause_ns)], and then backward
// from there to go back farther in time. We deliver the times
// most recent first (in p[0]).
p = p[:cap(p)]
for i := uint32(0); i < n; i++ {
j := (memstats.numgc - 1 - i) % uint32(len(memstats.pause_ns))
p[i] = memstats.pause_ns[j]
p[n+i] = memstats.pause_end[j]
}
p[n+n] = memstats.last_gc_unix
p[n+n+1] = uint64(memstats.numgc)
p[n+n+2] = memstats.pause_total_ns
unlock(&mheap_.lock)
*pauses = p[:n+n+3]
}
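// Editorial note: the slice handed back to runtime/debug by
// readGCStats_m has the following layout, with n = min(numgc,
// len(pause_ns)) and the most recent entries first:
//
//	p[0:n]   pause durations
//	p[n:2n]  pause end times
//	p[2n]    last GC (absolute time)
//	p[2n+1]  number of GC cycles
//	p[2n+2]  total pause time in nanoseconds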
// flushmcache flushes the mcache of allp[i].
//
// The world must be stopped.
//
//go:nowritebarrier
func flushmcache(i int) {
assertWorldStopped()
p := allp[i]
c := p.mcache
if c == nil {
return
}
c.releaseAll()
stackcache_clear(c)
}
// flushallmcaches flushes the mcaches of all Ps.
//
// The world must be stopped.
//
//go:nowritebarrier
func flushallmcaches() {
assertWorldStopped()
for i := 0; i < int(gomaxprocs); i++ {
flushmcache(i)
}
}
// sysMemStat represents a global system statistic that is managed atomically.
//
// This type must structurally be a uint64 so that mstats aligns with MemStats.
type sysMemStat uint64
// load atomically reads the value of the stat.
//
// Must be nosplit as it is called in runtime initialization, e.g. newosproc0.
//
//go:nosplit
func (s *sysMemStat) load() uint64 {
return atomic.Load64((*uint64)(s))
}
// add atomically adds n to the sysMemStat.
//
// Must be nosplit as it is called in runtime initialization, e.g. newosproc0.
//
//go:nosplit
func (s *sysMemStat) add(n int64) {
val := atomic.Xadd64((*uint64)(s), n)
if (n > 0 && int64(val) < n) || (n < 0 && int64(val)+n < n) {
print("runtime: val=", val, " n=", n, "\n")
throw("sysMemStat overflow")
}
}
// heapStatsDelta contains deltas of various runtime memory statistics
// that need to be updated together in order for them to be kept
// consistent with one another.
type heapStatsDelta struct {
// Memory stats.
committed int64 // byte delta of memory committed
released int64 // byte delta of released memory generated
inHeap int64 // byte delta of memory placed in the heap
inStacks int64 // byte delta of memory reserved for stacks
inWorkBufs int64 // byte delta of memory reserved for work bufs
inPtrScalarBits int64 // byte delta of memory reserved for unrolled GC prog bits
// Allocator stats.
//
// These are all uint64 because they're cumulative, and could quickly wrap
// around otherwise.
tinyAllocCount uint64 // number of tiny allocations
largeAlloc uint64 // bytes allocated for large objects
largeAllocCount uint64 // number of large object allocations
smallAllocCount [_NumSizeClasses]uint64 // number of allocs for small objects
largeFree uint64 // bytes freed for large objects (>maxSmallSize)
largeFreeCount uint64 // number of frees for large objects (>maxSmallSize)
smallFreeCount [_NumSizeClasses]uint64 // number of frees for small objects (<=maxSmallSize)
// NOTE: This struct must be a multiple of 8 bytes in size because it
// is stored in an array. If it's not, atomic accesses to the above
// fields may be unaligned and fail on 32-bit platforms.
}
// merge adds in the deltas from b into a.
func (a *heapStatsDelta) merge(b *heapStatsDelta) {
a.committed += b.committed
a.released += b.released
a.inHeap += b.inHeap
a.inStacks += b.inStacks
a.inWorkBufs += b.inWorkBufs
a.inPtrScalarBits += b.inPtrScalarBits
a.tinyAllocCount += b.tinyAllocCount
a.largeAlloc += b.largeAlloc
a.largeAllocCount += b.largeAllocCount
for i := range b.smallAllocCount {
a.smallAllocCount[i] += b.smallAllocCount[i]
}
a.largeFree += b.largeFree
a.largeFreeCount += b.largeFreeCount
for i := range b.smallFreeCount {
a.smallFreeCount[i] += b.smallFreeCount[i]
}
}
// consistentHeapStats represents a set of various memory statistics
// whose updates must be viewed completely to get a consistent
// state of the world.
//
// To write updates to memory stats use the acquire and release
// methods. To obtain a consistent global snapshot of these statistics,
// use read.
type consistentHeapStats struct {
// stats is a ring buffer of heapStatsDelta values.
// Writers always atomically update the delta at index gen.
//
// Readers operate by rotating gen (0 -> 1 -> 2 -> 0 -> ...)
// and synchronizing with writers by observing each P's
// statsSeq field. If the reader observes a P not writing,
// it can be sure that it will pick up the new gen value the
// next time it writes.
//
// The reader then takes responsibility by clearing space
// in the ring buffer for the next reader to rotate gen to
// that space (i.e. it merges in values from index (gen-2) mod 3
// to index (gen-1) mod 3, then clears the former).
//
// Note that this means only one reader can be reading at a time.
// There is no way for readers to synchronize.
//
// This process is why we need a ring buffer of size 3 instead
// of 2: one is for the writers, one contains the most recent
// data, and the last one is clear so writers can begin writing
// to it the moment gen is updated.
stats [3]heapStatsDelta
// gen represents the current index into which writers
// are writing, and can take on the value of 0, 1, or 2.
gen atomic.Uint32
// noPLock is intended to provide mutual exclusion for updating
// stats when no P is available. It does not block other writers
// with a P, only other writers without a P and the reader. Because
// stats are usually updated when a P is available, contention on
// this lock should be minimal.
noPLock mutex
}
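// Editorial sketch (not part of the runtime) of one rotation. Suppose
// gen == 0, so writers are updating stats[0], stats[2] holds the
// previous aggregate, and stats[1] is clear. A read then does:
//
//	m.gen.Swap(1)                 // writers move on to stats[1]
//	// ... spin until every P's statsSeq is even ...
//	m.stats[0].merge(&m.stats[2]) // fold the old aggregate into gen 0
//	m.stats[2] = heapStatsDelta{} // clear, ready for the gen=2 rotation
//	*out = m.stats[0]             // stats[0] is now the full aggregate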
// acquire returns a heapStatsDelta to be updated. In effect,
// it acquires the shard for writing. release must be called
// as soon as the relevant deltas are updated.
//
// The returned heapStatsDelta must be updated atomically.
//
// The caller's P must not change between acquire and
// release. This also means that the caller should not
// acquire a P or release its P in between. A P also must
// not acquire a given consistentHeapStats if it hasn't
// yet released it.
//
// nosplit because a stack growth in this function could
// lead to a stack allocation that could reenter the
// function.
//
//go:nosplit
func (m *consistentHeapStats) acquire() *heapStatsDelta {
if pp := getg().m.p.ptr(); pp != nil {
seq := pp.statsSeq.Add(1)
if seq%2 == 0 {
// Should have been incremented to odd.
print("runtime: seq=", seq, "\n")
throw("bad sequence number")
}
} else {
lock(&m.noPLock)
}
gen := m.gen.Load() % 3
return &m.stats[gen]
}
// release indicates that the writer is done modifying
// the delta. The value returned by the corresponding
// acquire must no longer be accessed or modified after
// release is called.
//
// The caller's P must not change between acquire and
// release. This also means that the caller should not
// acquire a P or release its P in between.
//
// nosplit because a stack growth in this function could
// lead to a stack allocation that causes another acquire
// before this operation has completed.
//
//go:nosplit
func (m *consistentHeapStats) release() {
if pp := getg().m.p.ptr(); pp != nil {
seq := pp.statsSeq.Add(1)
if seq%2 != 0 {
// Should have been incremented to even.
print("runtime: seq=", seq, "\n")
throw("bad sequence number")
}
} else {
unlock(&m.noPLock)
}
}
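// Editorial example of the write-side protocol, a sketch of how
// allocator code updates these stats (the field chosen here is
// arbitrary):
//
//	stats := memstats.heapStats.acquire()
//	atomic.Xadd64(&stats.tinyAllocCount, 1)
//	memstats.heapStats.release()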
// unsafeRead aggregates the delta for this shard into out.
//
// Unsafe because it does so without any synchronization. The
// world must be stopped.
func (m *consistentHeapStats) unsafeRead(out *heapStatsDelta) {
assertWorldStopped()
for i := range m.stats {
out.merge(&m.stats[i])
}
}
// unsafeClear clears the shard.
//
// Unsafe because the world must be stopped and values should
// be donated elsewhere before clearing.
func (m *consistentHeapStats) unsafeClear() {
assertWorldStopped()
for i := range m.stats {
m.stats[i] = heapStatsDelta{}
}
}
// read takes a globally consistent snapshot of m
// and puts the aggregated value in out. Even though out is a
// heapStatsDelta, the resulting values should be complete and
// valid statistic values.
//
// Not safe to call concurrently. The world must be stopped
// or metricsSema must be held.
func (m *consistentHeapStats) read(out *heapStatsDelta) {
// Getting preempted after this point is not safe because
// we read allp. We need to make sure a STW can't happen
// so it doesn't change out from under us.
mp := acquirem()
// Get the current generation. We can be confident that this
// will not change since read is serialized and is the only
// operation that modifies m.gen.
currGen := m.gen.Load()
prevGen := currGen - 1
if currGen == 0 {
prevGen = 2
}
// Prevent writers without a P from writing while we update gen.
lock(&m.noPLock)
// Rotate gen, effectively taking a snapshot of the state of
// these statistics at the point of the exchange by moving
// writers to the next set of deltas.
//
// This exchange is safe to do because we won't race
// with anyone else trying to update this value.
m.gen.Swap((currGen + 1) % 3)
// Allow P-less writers to continue. They'll be writing to the
// next generation now.
unlock(&m.noPLock)
for _, p := range allp {
// Spin until there are no more writers.
for p.statsSeq.Load()%2 != 0 {
}
}
// At this point we've observed that each sequence
// number is even, so any future writers will observe
// the new gen value. That means it's safe to read from
// the other deltas in the stats buffer.
// Perform our responsibilities and free up
// stats[prevGen] for the next time we want to take
// a snapshot.
m.stats[currGen].merge(&m.stats[prevGen])
m.stats[prevGen] = heapStatsDelta{}
// Finally, copy out the complete delta.
*out = m.stats[currGen]
releasem(mp)
}
type cpuStats struct {
// All fields are CPU time in nanoseconds computed by comparing
// calls of nanotime. This means they're all overestimates, because
// they don't accurately compute on-CPU time (so some of the time
// could be spent scheduled away by the OS).
gcAssistTime int64 // GC assists
gcDedicatedTime int64 // GC dedicated mark workers + pauses
gcIdleTime int64 // GC idle mark workers
gcPauseTime int64 // GC pauses (all GOMAXPROCS, even if just 1 is running)
gcTotalTime int64
scavengeAssistTime int64 // scavenge assists
scavengeBgTime int64 // background scavenger
scavengeTotalTime int64
idleTime int64 // Time Ps spent in _Pidle.
userTime int64 // Time Ps spent in _Prunning or _Psyscall that's not any of the above.
totalTime int64 // GOMAXPROCS * (monotonic wall clock time elapsed)
}
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This implements the write barrier buffer. The write barrier itself
// is gcWriteBarrier and is implemented in assembly.
//
// See mbarrier.go for algorithmic details on the write barrier. This
// file deals only with the buffer.
//
// The write barrier has a fast path and a slow path. The fast path
// simply enqueues to a per-P write barrier buffer. It's written in
// assembly and doesn't clobber any general purpose registers, so it
// doesn't have the usual overheads of a Go call.
//
// When the buffer fills up, the write barrier invokes the slow path
// (wbBufFlush) to flush the buffer to the GC work queues. In this
// path, since the compiler didn't spill registers, we spill *all*
// registers and disallow any GC safe points that could observe the
// stack frame (since we don't know the types of the spilled
// registers).
package runtime
import (
"internal/goarch"
"runtime/internal/atomic"
"unsafe"
)
// testSmallBuf forces a small write barrier buffer to stress write
// barrier flushing.
const testSmallBuf = false
// wbBuf is a per-P buffer of pointers queued by the write barrier.
// This buffer is flushed to the GC workbufs when it fills up and on
// various GC transitions.
//
// This is closely related to a "sequential store buffer" (SSB),
// except that SSBs are usually used for maintaining remembered sets,
// while this is used for marking.
type wbBuf struct {
// next points to the next slot in buf. It must not be a
// pointer type because it can point past the end of buf and
// must be updated without write barriers.
//
// This is a pointer rather than an index to optimize the
// write barrier assembly.
next uintptr
// end points to just past the end of buf. It must not be a
// pointer type because it points past the end of buf and must
// be updated without write barriers.
end uintptr
// buf stores a series of pointers to execute write barriers on.
buf [wbBufEntries]uintptr
}
const (
// wbBufEntries is the maximum number of pointers that can be
// stored in the write barrier buffer.
//
// This trades latency for throughput amortization. Higher
// values amortize flushing overhead more, but increase the
// latency of flushing. Higher values also increase the cache
// footprint of the buffer.
//
// TODO: What is the latency cost of this? Tune this value.
wbBufEntries = 512
// Maximum number of entries that we need to ask from the
// buffer in a single call.
wbMaxEntriesPerCall = 8
)
// reset empties b by resetting its next and end pointers.
func (b *wbBuf) reset() {
start := uintptr(unsafe.Pointer(&b.buf[0]))
b.next = start
if testSmallBuf {
// For testing, make the buffer smaller but more than
// 1 write barrier's worth, so it tests both the
// immediate flush and delayed flush cases.
b.end = uintptr(unsafe.Pointer(&b.buf[wbMaxEntriesPerCall+1]))
} else {
b.end = start + uintptr(len(b.buf))*unsafe.Sizeof(b.buf[0])
}
if (b.end-b.next)%unsafe.Sizeof(b.buf[0]) != 0 {
throw("bad write barrier buffer bounds")
}
}
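// Editorial note: with wbBufEntries = 512 and 8-byte pointers (64-bit),
// a full reset leaves end-next == 512*8 = 4096 bytes, i.e. room for 512
// queued pointers before the next flush.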
// discard resets b's next pointer, but not its end pointer.
//
// This must be nosplit because it's called by wbBufFlush.
//
//go:nosplit
func (b *wbBuf) discard() {
b.next = uintptr(unsafe.Pointer(&b.buf[0]))
}
// empty reports whether b contains no pointers.
func (b *wbBuf) empty() bool {
return b.next == uintptr(unsafe.Pointer(&b.buf[0]))
}
// getX returns space in the write barrier buffer to store X pointers.
// getX will flush the buffer if necessary. Callers should use this as:
//
// buf := &getg().m.p.ptr().wbBuf
// p := buf.get2()
// p[0], p[1] = old, new
// ... actual memory write ...
//
// The caller must ensure there are no preemption points during the
// above sequence. There must be no preemption points while buf is in
// use because it is a per-P resource. There must be no preemption
// points between the buffer put and the write to memory because this
// could allow a GC phase change, which could result in missed write
// barriers.
//
// getX must be nowritebarrierrec because write barriers here would
// corrupt the write barrier buffer. It (and everything it calls, if
// it called anything) has to be nosplit to avoid scheduling on to a
// different P and a different buffer.
//
//go:nowritebarrierrec
//go:nosplit
func (b *wbBuf) get1() *[1]uintptr {
if b.next+goarch.PtrSize > b.end {
wbBufFlush()
}
p := (*[1]uintptr)(unsafe.Pointer(b.next))
b.next += goarch.PtrSize
return p
}
//go:nowritebarrierrec
//go:nosplit
func (b *wbBuf) get2() *[2]uintptr {
if b.next+2*goarch.PtrSize > b.end {
wbBufFlush()
}
p := (*[2]uintptr)(unsafe.Pointer(b.next))
b.next += 2 * goarch.PtrSize
return p
}
// wbBufFlush flushes the current P's write barrier buffer to the GC
// workbufs.
//
// This must not have write barriers because it is part of the write
// barrier implementation.
//
// This and everything it calls must be nosplit because 1) the stack
// contains untyped slots from gcWriteBarrier and 2) there must not be
// a GC safe point between the write barrier test in the caller and
// flushing the buffer.
//
// TODO: A "go:nosplitrec" annotation would be perfect for this.
//
//go:nowritebarrierrec
//go:nosplit
func wbBufFlush() {
// Note: Every possible return from this function must reset
// the buffer's next pointer to prevent buffer overflow.
if getg().m.dying > 0 {
// We're going down. Not much point in write barriers
// and this way we can allow write barriers in the
// panic path.
getg().m.p.ptr().wbBuf.discard()
return
}
// Switch to the system stack so we don't have to worry about
// safe points.
systemstack(func() {
wbBufFlush1(getg().m.p.ptr())
})
}
// wbBufFlush1 flushes p's write barrier buffer to the GC work queue.
//
// This must not have write barriers because it is part of the write
// barrier implementation; a write barrier here could lead to infinite
// loops or buffer corruption.
//
// This must be non-preemptible because it uses the P's workbuf.
//
//go:nowritebarrierrec
//go:systemstack
func wbBufFlush1(pp *p) {
// Get the buffered pointers.
start := uintptr(unsafe.Pointer(&pp.wbBuf.buf[0]))
n := (pp.wbBuf.next - start) / unsafe.Sizeof(pp.wbBuf.buf[0])
ptrs := pp.wbBuf.buf[:n]
// Poison the buffer to make extra sure nothing is enqueued
// while we're processing the buffer.
pp.wbBuf.next = 0
if useCheckmark {
// Slow path for checkmark mode.
for _, ptr := range ptrs {
shade(ptr)
}
pp.wbBuf.reset()
return
}
// Mark all of the pointers in the buffer and record only the
// pointers we greyed. We use the buffer itself to temporarily
// record greyed pointers.
//
// TODO: Should scanobject/scanblock just stuff pointers into
// the wbBuf? Then this would become the sole greying path.
//
// TODO: We could avoid shading any of the "new" pointers in
// the buffer if the stack has been shaded, or even avoid
// putting them in the buffer at all (which would double its
// capacity). This is slightly complicated with the buffer; we
// could track whether any un-shaded goroutine has used the
// buffer, or just track globally whether there are any
// un-shaded stacks and flush after each stack scan.
gcw := &pp.gcw
pos := 0
for _, ptr := range ptrs {
if ptr < minLegalPointer {
// nil pointers are very common, especially
// for the "old" values. Filter out these and
// other "obvious" non-heap pointers ASAP.
//
// TODO: Should we filter out nils in the fast
// path to reduce the rate of flushes?
continue
}
obj, span, objIndex := findObject(ptr, 0, 0)
if obj == 0 {
continue
}
// TODO: Consider making two passes where the first
// just prefetches the mark bits.
mbits := span.markBitsForIndex(objIndex)
if mbits.isMarked() {
continue
}
mbits.setMarked()
// Mark span.
arena, pageIdx, pageMask := pageIndexOf(span.base())
if arena.pageMarks[pageIdx]&pageMask == 0 {
atomic.Or8(&arena.pageMarks[pageIdx], pageMask)
}
if span.spanclass.noscan() {
gcw.bytesMarked += uint64(span.elemsize)
continue
}
ptrs[pos] = obj
pos++
}
// Enqueue the greyed objects.
gcw.putBatch(ptrs[:pos])
pp.wbBuf.reset()
}
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build dragonfly || freebsd || linux || netbsd || openbsd || solaris
package runtime
func nonblockingPipe() (r, w int32, errno int32) {
return pipe2(_O_NONBLOCK | _O_CLOEXEC)
}
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || (js && wasm) || windows
package runtime
import (
"runtime/internal/atomic"
"runtime/internal/sys"
"unsafe"
)
// Integrated network poller (platform-independent part).
// A particular implementation (epoll/kqueue/port/AIX/Windows)
// must define the following functions:
//
// func netpollinit()
// Initialize the poller. Only called once.
//
// func netpollopen(fd uintptr, pd *pollDesc) int32
// Arm edge-triggered notifications for fd. The pd argument is to pass
// back to netpollready when fd is ready. Return an errno value.
//
// func netpollclose(fd uintptr) int32
// Disable notifications for fd. Return an errno value.
//
// func netpoll(delta int64) gList
// Poll the network. If delta < 0, block indefinitely. If delta == 0,
// poll without blocking. If delta > 0, block for up to delta nanoseconds.
// Return a list of goroutines built by calling netpollready.
//
// func netpollBreak()
// Wake up the network poller, assumed to be blocked in netpoll.
//
// func netpollIsPollDescriptor(fd uintptr) bool
// Reports whether fd is a file descriptor used by the poller.
// Error codes returned by runtime_pollReset and runtime_pollWait.
// These must match the values in internal/poll/fd_poll_runtime.go.
const (
pollNoError = 0 // no error
pollErrClosing = 1 // descriptor is closed
pollErrTimeout = 2 // I/O timeout
pollErrNotPollable = 3 // general error polling descriptor
)
// pollDesc contains 2 binary semaphores, rg and wg, to park reader and writer
// goroutines respectively. The semaphore can be in the following states:
//
// pdReady - io readiness notification is pending;
// a goroutine consumes the notification by changing the state to pdNil.
// pdWait - a goroutine prepares to park on the semaphore, but not yet parked;
// the goroutine commits to park by changing the state to G pointer,
// or, alternatively, concurrent io notification changes the state to pdReady,
// or, alternatively, concurrent timeout/close changes the state to pdNil.
// G pointer - the goroutine is blocked on the semaphore;
// io notification or timeout/close changes the state to pdReady or pdNil respectively
// and unparks the goroutine.
// pdNil - none of the above.
const (
pdNil uintptr = 0
pdReady uintptr = 1
pdWait uintptr = 2
)
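// Editorial sketch of the transitions described above, in the order
// netpollblock and netpollunblock drive them:
//
//	pdNil    --CAS-->  pdWait           (goroutine prepares to park)
//	pdWait   --CAS-->  G pointer        (netpollblockcommit: parked)
//	G/pdWait --swap--> pdReady or pdNil (io ready, timeout, or close)
//	pdReady  --CAS-->  pdNil            (waiter consumes the notification)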
const pollBlockSize = 4 * 1024
// Network poller descriptor.
//
// No heap pointers.
type pollDesc struct {
_ sys.NotInHeap
link *pollDesc // in pollcache, protected by pollcache.lock
fd uintptr // constant for pollDesc usage lifetime
// atomicInfo holds bits from closing, rd, and wd,
// which are only ever written while holding the lock,
// summarized for use by netpollcheckerr,
// which cannot acquire the lock.
// After writing these fields under lock in a way that
// might change the summary, code must call publishInfo
// before releasing the lock.
// Code that changes fields and then calls netpollunblock
// (while still holding the lock) must call publishInfo
// before calling netpollunblock, because publishInfo is what
// stops netpollblock from blocking anew
// (by changing the result of netpollcheckerr).
// atomicInfo also holds the eventErr bit,
// recording whether a poll event on the fd got an error;
// atomicInfo is the only source of truth for that bit.
atomicInfo atomic.Uint32 // atomic pollInfo
// rg, wg are accessed atomically and hold g pointers.
// (Using atomic.Uintptr here is similar to using guintptr elsewhere.)
rg atomic.Uintptr // pdReady, pdWait, G waiting for read or pdNil
wg atomic.Uintptr // pdReady, pdWait, G waiting for write or pdNil
lock mutex // protects the following fields
closing bool
user uint32 // user settable cookie
rseq uintptr // protects from stale read timers
rt timer // read deadline timer (set if rt.f != nil)
rd int64 // read deadline (a nanotime in the future, -1 when expired)
wseq uintptr // protects from stale write timers
wt timer // write deadline timer
wd int64 // write deadline (a nanotime in the future, -1 when expired)
self *pollDesc // storage for indirect interface. See (*pollDesc).makeArg.
}
// pollInfo is the bits needed by netpollcheckerr, stored atomically,
// mostly duplicating state that is manipulated under lock in pollDesc.
// The one exception is the pollEventErr bit, which is maintained only
// in the pollInfo.
type pollInfo uint32
const (
pollClosing = 1 << iota
pollEventErr
pollExpiredReadDeadline
pollExpiredWriteDeadline
)
func (i pollInfo) closing() bool { return i&pollClosing != 0 }
func (i pollInfo) eventErr() bool { return i&pollEventErr != 0 }
func (i pollInfo) expiredReadDeadline() bool { return i&pollExpiredReadDeadline != 0 }
func (i pollInfo) expiredWriteDeadline() bool { return i&pollExpiredWriteDeadline != 0 }
// info returns the pollInfo corresponding to pd.
func (pd *pollDesc) info() pollInfo {
return pollInfo(pd.atomicInfo.Load())
}
// publishInfo updates pd.atomicInfo (returned by pd.info)
// using the other values in pd.
// It must be called while holding pd.lock,
// and it must be called after changing anything
// that might affect the info bits.
// In practice this means after changing closing
// or changing rd or wd from < 0 to >= 0.
func (pd *pollDesc) publishInfo() {
var info uint32
if pd.closing {
info |= pollClosing
}
if pd.rd < 0 {
info |= pollExpiredReadDeadline
}
if pd.wd < 0 {
info |= pollExpiredWriteDeadline
}
// Set all of x except the pollEventErr bit.
x := pd.atomicInfo.Load()
for !pd.atomicInfo.CompareAndSwap(x, (x&pollEventErr)|info) {
x = pd.atomicInfo.Load()
}
}
// setEventErr sets the result of pd.info().eventErr() to b.
func (pd *pollDesc) setEventErr(b bool) {
x := pd.atomicInfo.Load()
for (x&pollEventErr != 0) != b && !pd.atomicInfo.CompareAndSwap(x, x^pollEventErr) {
x = pd.atomicInfo.Load()
}
}
type pollCache struct {
lock mutex
first *pollDesc
// PollDesc objects must be type-stable,
// because we can get ready notification from epoll/kqueue
// after the descriptor is closed/reused.
// Stale notifications are detected using seq variable,
// seq is incremented when deadlines are changed or descriptor is reused.
}
var (
netpollInitLock mutex
netpollInited atomic.Uint32
pollcache pollCache
netpollWaiters atomic.Uint32
)
//go:linkname poll_runtime_pollServerInit internal/poll.runtime_pollServerInit
func poll_runtime_pollServerInit() {
netpollGenericInit()
}
func netpollGenericInit() {
if netpollInited.Load() == 0 {
lockInit(&netpollInitLock, lockRankNetpollInit)
lock(&netpollInitLock)
if netpollInited.Load() == 0 {
netpollinit()
netpollInited.Store(1)
}
unlock(&netpollInitLock)
}
}
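// Editorial note: netpollGenericInit uses double-checked initialization.
// The first unsynchronized Load skips the lock on the common
// already-initialized path; the second Load under netpollInitLock
// guarantees netpollinit runs exactly once.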
func netpollinited() bool {
return netpollInited.Load() != 0
}
//go:linkname poll_runtime_isPollServerDescriptor internal/poll.runtime_isPollServerDescriptor
// poll_runtime_isPollServerDescriptor reports whether fd is a
// descriptor being used by netpoll.
func poll_runtime_isPollServerDescriptor(fd uintptr) bool {
return netpollIsPollDescriptor(fd)
}
//go:linkname poll_runtime_pollOpen internal/poll.runtime_pollOpen
func poll_runtime_pollOpen(fd uintptr) (*pollDesc, int) {
pd := pollcache.alloc()
lock(&pd.lock)
wg := pd.wg.Load()
if wg != pdNil && wg != pdReady {
throw("runtime: blocked write on free polldesc")
}
rg := pd.rg.Load()
if rg != pdNil && rg != pdReady {
throw("runtime: blocked read on free polldesc")
}
pd.fd = fd
pd.closing = false
pd.setEventErr(false)
pd.rseq++
pd.rg.Store(pdNil)
pd.rd = 0
pd.wseq++
pd.wg.Store(pdNil)
pd.wd = 0
pd.self = pd
pd.publishInfo()
unlock(&pd.lock)
errno := netpollopen(fd, pd)
if errno != 0 {
pollcache.free(pd)
return nil, int(errno)
}
return pd, 0
}
//go:linkname poll_runtime_pollClose internal/poll.runtime_pollClose
func poll_runtime_pollClose(pd *pollDesc) {
if !pd.closing {
throw("runtime: close polldesc w/o unblock")
}
wg := pd.wg.Load()
if wg != pdNil && wg != pdReady {
throw("runtime: blocked write on closing polldesc")
}
rg := pd.rg.Load()
if rg != pdNil && rg != pdReady {
throw("runtime: blocked read on closing polldesc")
}
netpollclose(pd.fd)
pollcache.free(pd)
}
func (c *pollCache) free(pd *pollDesc) {
lock(&c.lock)
pd.link = c.first
c.first = pd
unlock(&c.lock)
}
// poll_runtime_pollReset, which is internal/poll.runtime_pollReset,
// prepares a descriptor for polling in mode, which is 'r' or 'w'.
// This returns an error code; the codes are defined above.
//
//go:linkname poll_runtime_pollReset internal/poll.runtime_pollReset
func poll_runtime_pollReset(pd *pollDesc, mode int) int {
errcode := netpollcheckerr(pd, int32(mode))
if errcode != pollNoError {
return errcode
}
if mode == 'r' {
pd.rg.Store(pdNil)
} else if mode == 'w' {
pd.wg.Store(pdNil)
}
return pollNoError
}
// poll_runtime_pollWait, which is internal/poll.runtime_pollWait,
// waits for a descriptor to be ready for reading or writing,
// according to mode, which is 'r' or 'w'.
// This returns an error code; the codes are defined above.
//
//go:linkname poll_runtime_pollWait internal/poll.runtime_pollWait
func poll_runtime_pollWait(pd *pollDesc, mode int) int {
errcode := netpollcheckerr(pd, int32(mode))
if errcode != pollNoError {
return errcode
}
// For now, only Solaris, illumos, and AIX use level-triggered IO.
if GOOS == "solaris" || GOOS == "illumos" || GOOS == "aix" {
netpollarm(pd, mode)
}
for !netpollblock(pd, int32(mode), false) {
errcode = netpollcheckerr(pd, int32(mode))
if errcode != pollNoError {
return errcode
}
// This can happen if the timeout fired and unblocked us,
// but the timeout was reset before we had a chance to run.
// Pretend it did not happen and retry.
}
return pollNoError
}
//go:linkname poll_runtime_pollWaitCanceled internal/poll.runtime_pollWaitCanceled
func poll_runtime_pollWaitCanceled(pd *pollDesc, mode int) {
// This function is used only on windows after a failed attempt to cancel
// a pending async IO operation. Wait for ioready, ignore closing or timeouts.
for !netpollblock(pd, int32(mode), true) {
}
}
//go:linkname poll_runtime_pollSetDeadline internal/poll.runtime_pollSetDeadline
func poll_runtime_pollSetDeadline(pd *pollDesc, d int64, mode int) {
lock(&pd.lock)
if pd.closing {
unlock(&pd.lock)
return
}
rd0, wd0 := pd.rd, pd.wd
combo0 := rd0 > 0 && rd0 == wd0
if d > 0 {
d += nanotime()
if d <= 0 {
// If the user has a deadline in the future, but the delay calculation
// overflows, then set the deadline to the maximum possible value.
d = 1<<63 - 1
}
}
if mode == 'r' || mode == 'r'+'w' {
pd.rd = d
}
if mode == 'w' || mode == 'r'+'w' {
pd.wd = d
}
pd.publishInfo()
combo := pd.rd > 0 && pd.rd == pd.wd
rtf := netpollReadDeadline
if combo {
rtf = netpollDeadline
}
if pd.rt.f == nil {
if pd.rd > 0 {
pd.rt.f = rtf
// Copy current seq into the timer arg.
// Timer func will check the seq against current descriptor seq,
// if they differ the descriptor was reused or timers were reset.
pd.rt.arg = pd.makeArg()
pd.rt.seq = pd.rseq
resettimer(&pd.rt, pd.rd)
}
} else if pd.rd != rd0 || combo != combo0 {
pd.rseq++ // invalidate current timers
if pd.rd > 0 {
modtimer(&pd.rt, pd.rd, 0, rtf, pd.makeArg(), pd.rseq)
} else {
deltimer(&pd.rt)
pd.rt.f = nil
}
}
if pd.wt.f == nil {
if pd.wd > 0 && !combo {
pd.wt.f = netpollWriteDeadline
pd.wt.arg = pd.makeArg()
pd.wt.seq = pd.wseq
resettimer(&pd.wt, pd.wd)
}
} else if pd.wd != wd0 || combo != combo0 {
pd.wseq++ // invalidate current timers
if pd.wd > 0 && !combo {
modtimer(&pd.wt, pd.wd, 0, netpollWriteDeadline, pd.makeArg(), pd.wseq)
} else {
deltimer(&pd.wt)
pd.wt.f = nil
}
}
// If we set the new deadline in the past, unblock currently pending IO if any.
// Note that pd.publishInfo has already been called, above, immediately after modifying rd and wd.
var rg, wg *g
if pd.rd < 0 {
rg = netpollunblock(pd, 'r', false)
}
if pd.wd < 0 {
wg = netpollunblock(pd, 'w', false)
}
unlock(&pd.lock)
if rg != nil {
netpollgoready(rg, 3)
}
if wg != nil {
netpollgoready(wg, 3)
}
}
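// Editorial note: when both deadlines are set and equal ("combo"), a
// single timer running netpollDeadline covers reads and writes;
// otherwise the read timer runs netpollReadDeadline and a separate
// write timer runs netpollWriteDeadline.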
//go:linkname poll_runtime_pollUnblock internal/poll.runtime_pollUnblock
func poll_runtime_pollUnblock(pd *pollDesc) {
lock(&pd.lock)
if pd.closing {
throw("runtime: unblock on closing polldesc")
}
pd.closing = true
pd.rseq++
pd.wseq++
var rg, wg *g
pd.publishInfo()
rg = netpollunblock(pd, 'r', false)
wg = netpollunblock(pd, 'w', false)
if pd.rt.f != nil {
deltimer(&pd.rt)
pd.rt.f = nil
}
if pd.wt.f != nil {
deltimer(&pd.wt)
pd.wt.f = nil
}
unlock(&pd.lock)
if rg != nil {
netpollgoready(rg, 3)
}
if wg != nil {
netpollgoready(wg, 3)
}
}
// netpollready is called by the platform-specific netpoll function.
// It declares that the fd associated with pd is ready for I/O.
// The toRun argument is used to build a list of goroutines to return
// from netpoll. The mode argument is 'r', 'w', or 'r'+'w' to indicate
// whether the fd is ready for reading or writing or both.
//
// This may run while the world is stopped, so write barriers are not allowed.
//
//go:nowritebarrier
func netpollready(toRun *gList, pd *pollDesc, mode int32) {
var rg, wg *g
if mode == 'r' || mode == 'r'+'w' {
rg = netpollunblock(pd, 'r', true)
}
if mode == 'w' || mode == 'r'+'w' {
wg = netpollunblock(pd, 'w', true)
}
if rg != nil {
toRun.push(rg)
}
if wg != nil {
toRun.push(wg)
}
}
func netpollcheckerr(pd *pollDesc, mode int32) int {
info := pd.info()
if info.closing() {
return pollErrClosing
}
if (mode == 'r' && info.expiredReadDeadline()) || (mode == 'w' && info.expiredWriteDeadline()) {
return pollErrTimeout
}
// Report an event scanning error only on a read event.
// An error on a write event will be captured in a subsequent
// write call that is able to report a more specific error.
if mode == 'r' && info.eventErr() {
return pollErrNotPollable
}
return pollNoError
}
func netpollblockcommit(gp *g, gpp unsafe.Pointer) bool {
r := atomic.Casuintptr((*uintptr)(gpp), pdWait, uintptr(unsafe.Pointer(gp)))
if r {
// Bump the count of goroutines waiting for the poller.
// The scheduler uses this to decide whether to block
// waiting for the poller if there is nothing else to do.
netpollWaiters.Add(1)
}
return r
}
func netpollgoready(gp *g, traceskip int) {
netpollWaiters.Add(-1)
goready(gp, traceskip+1)
}
// netpollblock returns true if IO is ready, or false if it timed out
// or the descriptor was closed. If waitio is true, it waits only for
// completed IO and ignores error states.
// Concurrent calls to netpollblock in the same mode are forbidden, as pollDesc
// can hold only a single waiting goroutine for each mode.
func netpollblock(pd *pollDesc, mode int32, waitio bool) bool {
gpp := &pd.rg
if mode == 'w' {
gpp = &pd.wg
}
// set the gpp semaphore to pdWait
for {
// Consume notification if already ready.
if gpp.CompareAndSwap(pdReady, pdNil) {
return true
}
if gpp.CompareAndSwap(pdNil, pdWait) {
break
}
// Double check that this isn't corrupt; otherwise we'd loop
// forever.
if v := gpp.Load(); v != pdReady && v != pdNil {
throw("runtime: double wait")
}
}
// We need to recheck error states after setting gpp to pdWait.
// This is necessary because runtime_pollUnblock/runtime_pollSetDeadline/deadlineimpl
// do the opposite: store to closing/rd/wd, publishInfo, then load rg/wg.
if waitio || netpollcheckerr(pd, mode) == pollNoError {
gopark(netpollblockcommit, unsafe.Pointer(gpp), waitReasonIOWait, traceEvGoBlockNet, 5)
}
// Be careful not to lose a concurrent pdReady notification.
old := gpp.Swap(pdNil)
if old > pdWait {
throw("runtime: corrupted polldesc")
}
return old == pdReady
}
func netpollunblock(pd *pollDesc, mode int32, ioready bool) *g {
gpp := &pd.rg
if mode == 'w' {
gpp = &pd.wg
}
for {
old := gpp.Load()
if old == pdReady {
return nil
}
if old == pdNil && !ioready {
// Only set pdReady for ioready. runtime_pollWait
// will check for timeout/cancel before waiting.
return nil
}
var new uintptr
if ioready {
new = pdReady
}
if gpp.CompareAndSwap(old, new) {
if old == pdWait {
old = pdNil
}
return (*g)(unsafe.Pointer(old))
}
}
}
func netpolldeadlineimpl(pd *pollDesc, seq uintptr, read, write bool) {
lock(&pd.lock)
// The seq argument is the sequence number recorded when the timer was set.
// If it is stale, ignore the timer event.
currentSeq := pd.rseq
if !read {
currentSeq = pd.wseq
}
if seq != currentSeq {
// The descriptor was reused or timers were reset.
unlock(&pd.lock)
return
}
var rg *g
if read {
if pd.rd <= 0 || pd.rt.f == nil {
throw("runtime: inconsistent read deadline")
}
pd.rd = -1
pd.publishInfo()
rg = netpollunblock(pd, 'r', false)
}
var wg *g
if write {
if pd.wd <= 0 || pd.wt.f == nil && !read {
throw("runtime: inconsistent write deadline")
}
pd.wd = -1
pd.publishInfo()
wg = netpollunblock(pd, 'w', false)
}
unlock(&pd.lock)
if rg != nil {
netpollgoready(rg, 0)
}
if wg != nil {
netpollgoready(wg, 0)
}
}
func netpollDeadline(arg any, seq uintptr) {
netpolldeadlineimpl(arg.(*pollDesc), seq, true, true)
}
func netpollReadDeadline(arg any, seq uintptr) {
netpolldeadlineimpl(arg.(*pollDesc), seq, true, false)
}
func netpollWriteDeadline(arg any, seq uintptr) {
netpolldeadlineimpl(arg.(*pollDesc), seq, false, true)
}
func (c *pollCache) alloc() *pollDesc {
lock(&c.lock)
if c.first == nil {
const pdSize = unsafe.Sizeof(pollDesc{})
n := pollBlockSize / pdSize
if n == 0 {
n = 1
}
// Must be in non-GC memory because it can be referenced
// only from epoll/kqueue internals.
mem := persistentalloc(n*pdSize, 0, &memstats.other_sys)
for i := uintptr(0); i < n; i++ {
pd := (*pollDesc)(add(mem, i*pdSize))
pd.link = c.first
c.first = pd
}
}
pd := c.first
c.first = pd.link
lockInit(&pd.lock, lockRankPollDesc)
unlock(&c.lock)
return pd
}
// makeArg converts pd to an interface{}.
// makeArg does not do any allocation. Normally, such
// a conversion requires an allocation because pointers to
// types which embed runtime/internal/sys.NotInHeap (which pollDesc is)
// must be stored in interfaces indirectly. See issue 42076.
func (pd *pollDesc) makeArg() (i any) {
x := (*eface)(unsafe.Pointer(&i))
x._type = pdType
x.data = unsafe.Pointer(&pd.self)
return
}
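// Editorial sketch: the plain conversion would be
//
//	var i any = pd // allocates: *pollDesc must be stored indirectly
//
// because pollDesc embeds NotInHeap. makeArg builds the same eface by
// hand, pointing the data word at pd.self (a *pollDesc field inside pd
// itself), so no allocation occurs.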
var (
pdEface any = (*pollDesc)(nil)
pdType *_type = efaceOf(&pdEface)._type
)
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build linux
package runtime
import (
"runtime/internal/atomic"
"runtime/internal/syscall"
"unsafe"
)
var (
epfd int32 = -1 // epoll descriptor
netpollBreakRd, netpollBreakWr uintptr // for netpollBreak
netpollWakeSig atomic.Uint32 // used to avoid duplicate calls of netpollBreak
)
func netpollinit() {
var errno uintptr
epfd, errno = syscall.EpollCreate1(syscall.EPOLL_CLOEXEC)
if errno != 0 {
println("runtime: epollcreate failed with", errno)
throw("runtime: netpollinit failed")
}
r, w, errpipe := nonblockingPipe()
if errpipe != 0 {
println("runtime: pipe failed with", -errpipe)
throw("runtime: pipe failed")
}
ev := syscall.EpollEvent{
Events: syscall.EPOLLIN,
}
*(**uintptr)(unsafe.Pointer(&ev.Data)) = &netpollBreakRd
errno = syscall.EpollCtl(epfd, syscall.EPOLL_CTL_ADD, r, &ev)
if errno != 0 {
println("runtime: epollctl failed with", errno)
throw("runtime: epollctl failed")
}
netpollBreakRd = uintptr(r)
netpollBreakWr = uintptr(w)
}
func netpollIsPollDescriptor(fd uintptr) bool {
return fd == uintptr(epfd) || fd == netpollBreakRd || fd == netpollBreakWr
}
func netpollopen(fd uintptr, pd *pollDesc) uintptr {
var ev syscall.EpollEvent
ev.Events = syscall.EPOLLIN | syscall.EPOLLOUT | syscall.EPOLLRDHUP | syscall.EPOLLET
*(**pollDesc)(unsafe.Pointer(&ev.Data)) = pd
return syscall.EpollCtl(epfd, syscall.EPOLL_CTL_ADD, int32(fd), &ev)
}
func netpollclose(fd uintptr) uintptr {
var ev syscall.EpollEvent
return syscall.EpollCtl(epfd, syscall.EPOLL_CTL_DEL, int32(fd), &ev)
}
func netpollarm(pd *pollDesc, mode int) {
throw("runtime: unused")
}
// netpollBreak interrupts an epollwait.
func netpollBreak() {
// Failing to cas indicates there is an in-flight wakeup, so we're done here.
if !netpollWakeSig.CompareAndSwap(0, 1) {
return
}
for {
var b byte
n := write(netpollBreakWr, unsafe.Pointer(&b), 1)
if n == 1 {
break
}
if n == -_EINTR {
continue
}
if n == -_EAGAIN {
return
}
println("runtime: netpollBreak write failed with", -n)
throw("runtime: netpollBreak write failed")
}
}
// netpoll checks for ready network connections.
// Returns list of goroutines that become runnable.
// delay < 0: blocks indefinitely
// delay == 0: does not block, just polls
// delay > 0: block for up to that many nanoseconds
func netpoll(delay int64) gList {
if epfd == -1 {
return gList{}
}
var waitms int32
if delay < 0 {
waitms = -1
} else if delay == 0 {
waitms = 0
} else if delay < 1e6 {
waitms = 1
} else if delay < 1e15 {
waitms = int32(delay / 1e6)
} else {
// An arbitrary cap on how long to wait for a timer.
// 1e9 ms == ~11.5 days.
waitms = 1e9
}
var events [128]syscall.EpollEvent
retry:
n, errno := syscall.EpollWait(epfd, events[:], int32(len(events)), waitms)
if errno != 0 {
if errno != _EINTR {
println("runtime: epollwait on fd", epfd, "failed with", errno)
throw("runtime: netpoll failed")
}
// If a timed sleep was interrupted, just return to
// recalculate how long we should sleep now.
if waitms > 0 {
return gList{}
}
goto retry
}
var toRun gList
for i := int32(0); i < n; i++ {
ev := events[i]
if ev.Events == 0 {
continue
}
if *(**uintptr)(unsafe.Pointer(&ev.Data)) == &netpollBreakRd {
if ev.Events != syscall.EPOLLIN {
println("runtime: netpoll: break fd ready for", ev.Events)
throw("runtime: netpoll: break fd ready for something unexpected")
}
if delay != 0 {
// netpollBreak could be picked up by a
// nonblocking poll. Only read the byte
// if blocking.
var tmp [16]byte
read(int32(netpollBreakRd), noescape(unsafe.Pointer(&tmp[0])), int32(len(tmp)))
netpollWakeSig.Store(0)
}
continue
}
var mode int32
if ev.Events&(syscall.EPOLLIN|syscall.EPOLLRDHUP|syscall.EPOLLHUP|syscall.EPOLLERR) != 0 {
mode += 'r'
}
if ev.Events&(syscall.EPOLLOUT|syscall.EPOLLHUP|syscall.EPOLLERR) != 0 {
mode += 'w'
}
if mode != 0 {
pd := *(**pollDesc)(unsafe.Pointer(&ev.Data))
pd.setEventErr(ev.Events == syscall.EPOLLERR)
netpollready(&toRun, pd, mode)
}
}
return toRun
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"internal/abi"
"internal/goarch"
"runtime/internal/atomic"
"runtime/internal/syscall"
"unsafe"
)
// sigPerThreadSyscall is the same signal (SIGSETXID) used by glibc for
// per-thread syscalls on Linux. We use it for the same purpose in non-cgo
// binaries.
const sigPerThreadSyscall = _SIGRTMIN + 1
type mOS struct {
// profileTimer holds the ID of the POSIX interval timer for profiling CPU
// usage on this thread.
//
// It is valid when the profileTimerValid field is true. A thread
// creates and manages its own timer, and these fields are read and written
// only by this thread. But because some of the reads on profileTimerValid
// are in signal handling code, this field should be atomic type.
profileTimer int32
profileTimerValid atomic.Bool
// needPerThreadSyscall indicates that a per-thread syscall is required
// for doAllThreadsSyscall.
needPerThreadSyscall atomic.Uint8
}
//go:noescape
func futex(addr unsafe.Pointer, op int32, val uint32, ts, addr2 unsafe.Pointer, val3 uint32) int32
// Linux futex.
//
// futexsleep(uint32 *addr, uint32 val)
// futexwakeup(uint32 *addr)
//
// Futexsleep atomically checks if *addr == val and if so, sleeps on addr.
// Futexwakeup wakes up threads sleeping on addr.
// Futexsleep is allowed to wake up spuriously.
const (
_FUTEX_PRIVATE_FLAG = 128
_FUTEX_WAIT_PRIVATE = 0 | _FUTEX_PRIVATE_FLAG
_FUTEX_WAKE_PRIVATE = 1 | _FUTEX_PRIVATE_FLAG
)
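// Editorial sketch (not part of the runtime): the classic pattern these
// wrappers support, with key a uint32 shared between threads:
//
//	// sleeper: wait until key becomes nonzero
//	for atomic.Load(&key) == 0 {
//		futexsleep(&key, 0, -1) // sleep only while *key is still 0
//	}
//
//	// waker
//	atomic.Store(&key, 1)
//	futexwakeup(&key, 1) // wake at most one sleeper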
// Atomically,
//
// if(*addr == val) sleep
//
// Might be woken up spuriously; that's allowed.
// Don't sleep longer than ns; ns < 0 means forever.
//
//go:nosplit
func futexsleep(addr *uint32, val uint32, ns int64) {
// Some Linux kernels have a bug where futex of
// FUTEX_WAIT returns an internal error code
// as an errno. Libpthread ignores the return value
// here, and so can we: as it says a few lines up,
// spurious wakeups are allowed.
if ns < 0 {
futex(unsafe.Pointer(addr), _FUTEX_WAIT_PRIVATE, val, nil, nil, 0)
return
}
var ts timespec
ts.setNsec(ns)
futex(unsafe.Pointer(addr), _FUTEX_WAIT_PRIVATE, val, unsafe.Pointer(&ts), nil, 0)
}
// If any procs are sleeping on addr, wake up at most cnt.
//
//go:nosplit
func futexwakeup(addr *uint32, cnt uint32) {
ret := futex(unsafe.Pointer(addr), _FUTEX_WAKE_PRIVATE, cnt, nil, nil, 0)
if ret >= 0 {
return
}
// I don't know that futex wakeup can return
// EAGAIN or EINTR, but if it does, it would be
// safe to loop and call futex again.
systemstack(func() {
print("futexwakeup addr=", addr, " returned ", ret, "\n")
})
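// Deliberately crash on a recognizable bad address (0x1006) so the
// fault is attributable to futexwakeup.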
*(*int32)(unsafe.Pointer(uintptr(0x1006))) = 0x1006
}
func getproccount() int32 {
// This buffer is huge (8 kB) but we are on the system stack
// and there should be plenty of space (64 kB).
// Also this is a leaf, so we're not holding up the memory for long.
// See golang.org/issue/11823.
// The suggested behavior here is to keep trying with ever-larger
// buffers, but we don't have a dynamic memory allocator at the
// moment, so that's a bit tricky and seems like overkill.
const maxCPUs = 64 * 1024
var buf [maxCPUs / 8]byte
r := sched_getaffinity(0, unsafe.Sizeof(buf), &buf[0])
if r < 0 {
return 1
}
n := int32(0)
for _, v := range buf[:r] {
for v != 0 {
n += int32(v & 1)
v >>= 1
}
}
if n == 0 {
n = 1
}
return n
}
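// Editorial note: the nested loop above is a per-byte popcount of the
// affinity mask; for example, a mask byte of 0b00101101 contributes 4
// CPUs. Only the r bytes reported by sched_getaffinity are scanned.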
// Clone, the Linux rfork.
const (
_CLONE_VM = 0x100
_CLONE_FS = 0x200
_CLONE_FILES = 0x400
_CLONE_SIGHAND = 0x800
_CLONE_PTRACE = 0x2000
_CLONE_VFORK = 0x4000
_CLONE_PARENT = 0x8000
_CLONE_THREAD = 0x10000
_CLONE_NEWNS = 0x20000
_CLONE_SYSVSEM = 0x40000
_CLONE_SETTLS = 0x80000
_CLONE_PARENT_SETTID = 0x100000
_CLONE_CHILD_CLEARTID = 0x200000
_CLONE_UNTRACED = 0x800000
_CLONE_CHILD_SETTID = 0x1000000
_CLONE_STOPPED = 0x2000000
_CLONE_NEWUTS = 0x4000000
_CLONE_NEWIPC = 0x8000000
// As of QEMU 2.8.0 (5ea2fc84d), user emulation requires all six of these
// flags to be set when creating a thread; attempts to share the other
// five but leave SYSVSEM unshared will fail with -EINVAL.
//
// In non-QEMU environments CLONE_SYSVSEM is inconsequential as we do not
// use System V semaphores.
cloneFlags = _CLONE_VM | /* share memory */
_CLONE_FS | /* share cwd, etc */
_CLONE_FILES | /* share fd table */
_CLONE_SIGHAND | /* share sig handler table */
_CLONE_SYSVSEM | /* share SysV semaphore undo lists (see issue #20763) */
_CLONE_THREAD /* revisit - okay for now */
)
//go:noescape
func clone(flags int32, stk, mp, gp, fn unsafe.Pointer) int32
// May run with m.p==nil, so write barriers are not allowed.
//
//go:nowritebarrier
func newosproc(mp *m) {
stk := unsafe.Pointer(mp.g0.stack.hi)
/*
* note: strace gets confused if we use CLONE_PTRACE here.
*/
if false {
print("newosproc stk=", stk, " m=", mp, " g=", mp.g0, " clone=", abi.FuncPCABI0(clone), " id=", mp.id, " ostk=", &mp, "\n")
}
// Disable signals during clone, so that the new thread starts
// with signals disabled. It will enable them in minit.
var oset sigset
sigprocmask(_SIG_SETMASK, &sigset_all, &oset)
ret := retryOnEAGAIN(func() int32 {
r := clone(cloneFlags, stk, unsafe.Pointer(mp), unsafe.Pointer(mp.g0), unsafe.Pointer(abi.FuncPCABI0(mstart)))
// clone returns positive TID, negative errno.
// We don't care about the TID.
if r >= 0 {
return 0
}
return -r
})
sigprocmask(_SIG_SETMASK, &oset, nil)
if ret != 0 {
print("runtime: failed to create new OS thread (have ", mcount(), " already; errno=", ret, ")\n")
if ret == _EAGAIN {
println("runtime: may need to increase max user processes (ulimit -u)")
}
throw("newosproc")
}
}
// Version of newosproc that doesn't require a valid G.
//
//go:nosplit
func newosproc0(stacksize uintptr, fn unsafe.Pointer) {
stack := sysAlloc(stacksize, &memstats.stacks_sys)
if stack == nil {
writeErrStr(failallocatestack)
exit(1)
}
ret := clone(cloneFlags, unsafe.Pointer(uintptr(stack)+stacksize), nil, nil, fn)
if ret < 0 {
writeErrStr(failthreadcreate)
exit(1)
}
}
const (
_AT_NULL = 0 // End of vector
_AT_PAGESZ = 6 // System physical page size
_AT_HWCAP = 16 // hardware capability bit vector
_AT_RANDOM = 25 // introduced in 2.6.29
_AT_HWCAP2 = 26 // hardware capability bit vector 2
)
var procAuxv = []byte("/proc/self/auxv\x00")
var addrspace_vec [1]byte
func mincore(addr unsafe.Pointer, n uintptr, dst *byte) int32
var auxvreadbuf [128]uintptr
func sysargs(argc int32, argv **byte) {
n := argc + 1
// skip over argv, envp to get to auxv
for argv_index(argv, n) != nil {
n++
}
// skip NULL separator
n++
// now argv+n is auxv
auxvp := (*[1 << 28]uintptr)(add(unsafe.Pointer(argv), uintptr(n)*goarch.PtrSize))
if pairs := sysauxv(auxvp[:]); pairs != 0 {
auxv = auxvp[: pairs*2 : pairs*2]
return
}
// In some situations we don't get a loader-provided
// auxv, such as when loaded as a library on Android.
// Fall back to /proc/self/auxv.
fd := open(&procAuxv[0], 0 /* O_RDONLY */, 0)
if fd < 0 {
// On Android, /proc/self/auxv might be unreadable (issue 9229), so we fall back
// to using mincore to detect the physical page size.
// mincore should return EINVAL when the address is not a multiple of the system page size.
const size = 256 << 10 // size of memory region to allocate
p, err := mmap(nil, size, _PROT_READ|_PROT_WRITE, _MAP_ANON|_MAP_PRIVATE, -1, 0)
if err != 0 {
return
}
var n uintptr
for n = 4 << 10; n < size; n <<= 1 {
err := mincore(unsafe.Pointer(uintptr(p)+n), 1, &addrspace_vec[0])
if err == 0 {
physPageSize = n
break
}
}
if physPageSize == 0 {
physPageSize = size
}
munmap(p, size)
return
}
n = read(fd, noescape(unsafe.Pointer(&auxvreadbuf[0])), int32(unsafe.Sizeof(auxvreadbuf)))
closefd(fd)
if n < 0 {
return
}
// Make sure buf is terminated, even if we didn't read
// the whole file.
auxvreadbuf[len(auxvreadbuf)-2] = _AT_NULL
pairs := sysauxv(auxvreadbuf[:])
auxv = auxvreadbuf[: pairs*2 : pairs*2]
}
// startupRandomData holds random bytes initialized at startup. These come from
// the ELF AT_RANDOM auxiliary vector.
var startupRandomData []byte
func sysauxv(auxv []uintptr) (pairs int) {
var i int
for ; auxv[i] != _AT_NULL; i += 2 {
tag, val := auxv[i], auxv[i+1]
switch tag {
case _AT_RANDOM:
// The kernel provides a pointer to 16 bytes
// of random data.
startupRandomData = (*[16]byte)(unsafe.Pointer(val))[:]
case _AT_PAGESZ:
physPageSize = val
}
archauxv(tag, val)
vdsoauxv(tag, val)
}
return i / 2
}
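// Editorial note: the auxiliary vector parsed above is a flat array of
// (tag, value) pairs terminated by an _AT_NULL tag, e.g.
//
//	[_AT_PAGESZ, 4096, _AT_RANDOM, ptr, ..., _AT_NULL, 0]
//
// which is why sysauxv returns i/2, the number of pairs consumed.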
var sysTHPSizePath = []byte("/sys/kernel/mm/transparent_hugepage/hpage_pmd_size\x00")
func getHugePageSize() uintptr {
var numbuf [20]byte
fd := open(&sysTHPSizePath[0], 0 /* O_RDONLY */, 0)
if fd < 0 {
return 0
}
ptr := noescape(unsafe.Pointer(&numbuf[0]))
n := read(fd, ptr, int32(len(numbuf)))
closefd(fd)
if n <= 0 {
return 0
}
n-- // remove trailing newline
v, ok := atoi(slicebytetostringtmp((*byte)(ptr), int(n)))
if !ok || v < 0 {
v = 0
}
if v&(v-1) != 0 {
// v is not a power of 2
return 0
}
return uintptr(v)
}
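// For example, the typical x86-64 value read from hpage_pmd_size is
// 2097152 (2 MiB, 1<<21): 0x200000&0x1fffff == 0, so it is accepted.
// A non-power-of-two value such as 3145728 (3 MiB) gives
// 0x300000&0x2fffff == 0x200000 != 0 and is rejected.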
func osinit() {
ncpu = getproccount()
physHugePageSize = getHugePageSize()
if iscgo {
// #42494 glibc and musl reserve some signals for
// internal use and require that they not be blocked by
// the rest of a normal C runtime. When the Go runtime
// temporarily blocks and then unblocks signals, the
// blocked interval of time is generally very short. As
// such, these expectations of *libc code are mostly met
// by the combined go+cgo system of threads. However,
// when Go causes a thread to exit, via a return from
// mstart(), the combined runtime can deadlock if these
// signals are blocked. Thus, don't block these signals
// when exiting threads.
// - glibc: SIGCANCEL (32), SIGSETXID (33)
// - musl: SIGTIMER (32), SIGCANCEL (33), SIGSYNCCALL (34)
sigdelset(&sigsetAllExiting, 32)
sigdelset(&sigsetAllExiting, 33)
sigdelset(&sigsetAllExiting, 34)
}
osArchInit()
}
var urandom_dev = []byte("/dev/urandom\x00")
func getRandomData(r []byte) {
if startupRandomData != nil {
n := copy(r, startupRandomData)
extendRandom(r, n)
return
}
fd := open(&urandom_dev[0], 0 /* O_RDONLY */, 0)
n := read(fd, unsafe.Pointer(&r[0]), int32(len(r)))
closefd(fd)
extendRandom(r, int(n))
}
func goenvs() {
goenvs_unix()
}
// Called to do synchronous initialization of Go code built with
// -buildmode=c-archive or -buildmode=c-shared.
// None of the Go runtime is initialized.
//
//go:nosplit
//go:nowritebarrierrec
func libpreinit() {
initsig(true)
}
// Called to initialize a new m (including the bootstrap m).
// Called on the parent thread (main thread in case of bootstrap), can allocate memory.
func mpreinit(mp *m) {
mp.gsignal = malg(32 * 1024) // Linux wants >= 2K
mp.gsignal.m = mp
}
func gettid() uint32
// Called to initialize a new m (including the bootstrap m).
// Called on the new thread, cannot allocate memory.
func minit() {
minitSignals()
// Cgo-created threads and the bootstrap m are missing a
// procid. We need this for asynchronous preemption and it's
// useful in debuggers.
getg().m.procid = uint64(gettid())
}
// Called from dropm to undo the effect of an minit.
//
//go:nosplit
func unminit() {
unminitSignals()
}
// Called from exitm, but not from drop, to undo the effect of thread-owned
// resources in minit, semacreate, or elsewhere. Do not take locks after calling this.
func mdestroy(mp *m) {
}
//#ifdef GOARCH_386
//#define sa_handler k_sa_handler
//#endif
func sigreturn()
func sigtramp() // Called via C ABI
func cgoSigtramp()
//go:noescape
func sigaltstack(new, old *stackt)
//go:noescape
func setitimer(mode int32, new, old *itimerval)
//go:noescape
func timer_create(clockid int32, sevp *sigevent, timerid *int32) int32
//go:noescape
func timer_settime(timerid int32, flags int32, new, old *itimerspec) int32
//go:noescape
func timer_delete(timerid int32) int32
//go:noescape
func rtsigprocmask(how int32, new, old *sigset, size int32)
//go:nosplit
//go:nowritebarrierrec
func sigprocmask(how int32, new, old *sigset) {
rtsigprocmask(how, new, old, int32(unsafe.Sizeof(*new)))
}
func raise(sig uint32)
func raiseproc(sig uint32)
//go:noescape
func sched_getaffinity(pid, len uintptr, buf *byte) int32
func osyield()
//go:nosplit
func osyield_no_g() {
osyield()
}
func pipe2(flags int32) (r, w int32, errno int32)
const (
_si_max_size = 128
_sigev_max_size = 64
)
//go:nosplit
//go:nowritebarrierrec
func setsig(i uint32, fn uintptr) {
var sa sigactiont
sa.sa_flags = _SA_SIGINFO | _SA_ONSTACK | _SA_RESTORER | _SA_RESTART
sigfillset(&sa.sa_mask)
// Although the Linux manpage says the "sa_restorer element is
// obsolete and should not be used", the x86_64 kernel requires it.
// Only use it on x86.
if GOARCH == "386" || GOARCH == "amd64" {
sa.sa_restorer = abi.FuncPCABI0(sigreturn)
}
if fn == abi.FuncPCABIInternal(sighandler) { // abi.FuncPCABIInternal(sighandler) matches the callers in signal_unix.go
if iscgo {
fn = abi.FuncPCABI0(cgoSigtramp)
} else {
fn = abi.FuncPCABI0(sigtramp)
}
}
sa.sa_handler = fn
sigaction(i, &sa, nil)
}
//go:nosplit
//go:nowritebarrierrec
func setsigstack(i uint32) {
var sa sigactiont
sigaction(i, nil, &sa)
if sa.sa_flags&_SA_ONSTACK != 0 {
return
}
sa.sa_flags |= _SA_ONSTACK
sigaction(i, &sa, nil)
}
//go:nosplit
//go:nowritebarrierrec
func getsig(i uint32) uintptr {
var sa sigactiont
sigaction(i, nil, &sa)
return sa.sa_handler
}
// setSignalstackSP sets the ss_sp field of a stackt.
//
//go:nosplit
func setSignalstackSP(s *stackt, sp uintptr) {
*(*uintptr)(unsafe.Pointer(&s.ss_sp)) = sp
}
//go:nosplit
func (c *sigctxt) fixsigcode(sig uint32) {
}
// sysSigaction calls the rt_sigaction system call.
//
//go:nosplit
func sysSigaction(sig uint32, new, old *sigactiont) {
if rt_sigaction(uintptr(sig), new, old, unsafe.Sizeof(sigactiont{}.sa_mask)) != 0 {
// Workaround for bugs in QEMU user mode emulation.
//
// QEMU turns calls to the sigaction system call into
// calls to the C library sigaction call; the C
// library call rejects attempts to call sigaction for
// SIGCANCEL (32) or SIGSETXID (33).
//
// QEMU rejects calling sigaction on SIGRTMAX (64).
//
// Just ignore the error in these cases. There isn't
// anything we can do about it anyhow.
if sig != 32 && sig != 33 && sig != 64 {
// Use system stack to avoid split stack overflow on ppc64/ppc64le.
systemstack(func() {
throw("sigaction failed")
})
}
}
}
// rt_sigaction is implemented in assembly.
//
//go:noescape
func rt_sigaction(sig uintptr, new, old *sigactiont, size uintptr) int32
func getpid() int
func tgkill(tgid, tid, sig int)
// signalM sends a signal to mp.
func signalM(mp *m, sig int) {
tgkill(getpid(), int(mp.procid), sig)
}
// validSIGPROF compares this signal delivery's code against the signal sources
// that the profiler uses, returning whether the delivery should be processed.
// To be processed, a signal delivery from a known profiling mechanism should
// correspond to the best profiling mechanism available to this thread. Signals
// from other sources are always considered valid.
//
//go:nosplit
func validSIGPROF(mp *m, c *sigctxt) bool {
code := int32(c.sigcode())
setitimer := code == _SI_KERNEL
timer_create := code == _SI_TIMER
if !(setitimer || timer_create) {
// The signal doesn't correspond to a profiling mechanism that the
// runtime enables itself. There's no reason to process it, but there's
// no reason to ignore it either.
return true
}
if mp == nil {
// Since we don't have an M, we can't check if there's an active
// per-thread timer for this thread. We don't know how long this thread
// has been around, and if it happened to interact with the Go scheduler
// at a time when profiling was active (causing it to have a per-thread
// timer). But it may have never interacted with the Go scheduler, or
// never while profiling was active. To avoid double-counting, process
// only signals from setitimer.
//
// When a custom cgo traceback function has been registered (on
// platforms that support runtime.SetCgoTraceback), SIGPROF signals
// delivered to a thread that cannot find a matching M do this check in
// the assembly implementations of runtime.cgoSigtramp.
return setitimer
}
// Having an M means the thread interacts with the Go scheduler, and we can
// check whether there's an active per-thread timer for this thread.
if mp.profileTimerValid.Load() {
// If this M has its own per-thread CPU profiling interval timer, we
// should track the SIGPROF signals that come from that timer (for
// accurate reporting of its CPU usage; see issue 35057) and ignore any
// that it gets from the process-wide setitimer (to not over-count its
// CPU consumption).
return timer_create
}
// No active per-thread timer means the only valid profiler is setitimer.
return setitimer
}
func setProcessCPUProfiler(hz int32) {
setProcessCPUProfilerTimer(hz)
}
func setThreadCPUProfiler(hz int32) {
mp := getg().m
mp.profilehz = hz
// destroy any active timer
if mp.profileTimerValid.Load() {
timerid := mp.profileTimer
mp.profileTimerValid.Store(false)
mp.profileTimer = 0
ret := timer_delete(timerid)
if ret != 0 {
print("runtime: failed to disable profiling timer; timer_delete(", timerid, ") errno=", -ret, "\n")
throw("timer_delete")
}
}
if hz == 0 {
// If the goal was to disable profiling for this thread, then the job's done.
return
}
// The period of the timer should be 1/Hz. For every "1/Hz" of additional
// work, the user should expect one additional sample in the profile.
//
// But to scale down to very small amounts of application work, to observe
// even CPU usage of "one tenth" of the requested period, set the initial
// timing delay in a different way: So that "one tenth" of a period of CPU
// spend shows up as a 10% chance of one sample (for an expected value of
// 0.1 samples), and so that "two and six tenths" periods of CPU spend show
// up as a 60% chance of 3 samples and a 40% chance of 2 samples (for an
// expected value of 2.6). Set the initial delay to a value in the uniform
// random distribution between 0 and the desired period. And because "0"
// means "disable timer", add 1 so the half-open interval [0,period) turns
// into (0,period].
//
// Otherwise, this would show up as a bias away from short-lived threads and
// from threads that are only occasionally active: for example, when the
// garbage collector runs on a mostly-idle system, the additional threads it
// activates may do a couple milliseconds of GC-related work and nothing
// else in the few seconds that the profiler observes.
spec := new(itimerspec)
spec.it_value.setNsec(1 + int64(fastrandn(uint32(1e9/hz))))
spec.it_interval.setNsec(1e9 / int64(hz))
var timerid int32
var sevp sigevent
sevp.notify = _SIGEV_THREAD_ID
sevp.signo = _SIGPROF
sevp.sigev_notify_thread_id = int32(mp.procid)
ret := timer_create(_CLOCK_THREAD_CPUTIME_ID, &sevp, &timerid)
if ret != 0 {
// If we cannot create a timer for this M, leave profileTimerValid false
// to fall back to the process-wide setitimer profiler.
return
}
ret = timer_settime(timerid, 0, spec, nil)
if ret != 0 {
print("runtime: failed to configure profiling timer; timer_settime(", timerid,
", 0, {interval: {",
spec.it_interval.tv_sec, "s + ", spec.it_interval.tv_nsec, "ns} value: {",
spec.it_value.tv_sec, "s + ", spec.it_value.tv_nsec, "ns}}, nil) errno=", -ret, "\n")
throw("timer_settime")
}
mp.profileTimer = timerid
mp.profileTimerValid.Store(true)
}
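// Worked example of the timing math above: with hz = 100 the period is
// 1e9/100 = 10e6ns (10ms), so it_interval is 10ms and it_value is drawn
// uniformly from (0ns, 10ms]. A thread that then accumulates only 1ms of
// CPU time has a 10% chance of one sample, for an expected value of 0.1
// samples, matching the 1/Hz-per-sample contract.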
// perThreadSyscallArgs contains the system call number, arguments, and
// expected return values for a system call to be executed on all threads.
type perThreadSyscallArgs struct {
trap uintptr
a1 uintptr
a2 uintptr
a3 uintptr
a4 uintptr
a5 uintptr
a6 uintptr
r1 uintptr
r2 uintptr
}
// perThreadSyscall is the system call to execute for the ongoing
// doAllThreadsSyscall.
//
// perThreadSyscall may only be written while mp.needPerThreadSyscall == 0 on
// all Ms.
var perThreadSyscall perThreadSyscallArgs
// syscall_runtime_doAllThreadsSyscall executes a specified system call on
// all Ms.
//
// The system call is expected to succeed and return the same value on every
// thread. If any threads do not match, the runtime throws.
//
//go:linkname syscall_runtime_doAllThreadsSyscall syscall.runtime_doAllThreadsSyscall
//go:uintptrescapes
func syscall_runtime_doAllThreadsSyscall(trap, a1, a2, a3, a4, a5, a6 uintptr) (r1, r2, err uintptr) {
if iscgo {
// In cgo, we are not aware of threads created in C, so this approach will not work.
panic("doAllThreadsSyscall not supported with cgo enabled")
}
// STW to guarantee that user goroutines see an atomic change to thread
// state. Without STW, goroutines could migrate Ms while change is in
// progress and e.g., see state old -> new -> old -> new.
//
// N.B. Internally, this function does not depend on STW to
// successfully change every thread. It is only needed for user
// expectations, per above.
stopTheWorld("doAllThreadsSyscall")
// This function depends on several properties:
//
// 1. All OS threads that already exist are associated with an M in
// allm. i.e., we won't miss any pre-existing threads.
// 2. All Ms listed in allm will eventually have an OS thread exist.
// i.e., they will set procid and be able to receive signals.
// 3. OS threads created after we read allm will clone from a thread
// that has executed the system call. i.e., they inherit the
// modified state.
//
// We achieve these through different mechanisms:
//
// 1. Addition of new Ms to allm in allocm happens before clone of its
// OS thread later in newm.
// 2. newm does acquirem to avoid being preempted, ensuring that new Ms
// created in allocm will eventually reach OS thread clone later in
// newm.
// 3. We take allocmLock for write here to prevent allocation of new Ms
// while this function runs. Per (1), this prevents clone of OS
// threads that are not yet in allm.
allocmLock.lock()
// Disable preemption, preventing us from changing Ms, as we handle
// this M specially.
//
// N.B. STW and lock() above do this as well, this is added for extra
// clarity.
acquirem()
// N.B. allocmLock also prevents concurrent execution of this function,
// serializing use of perThreadSyscall, mp.needPerThreadSyscall, and
// ensuring all threads execute system calls from multiple calls in the
// same order.
r1, r2, errno := syscall.Syscall6(trap, a1, a2, a3, a4, a5, a6)
if GOARCH == "ppc64" || GOARCH == "ppc64le" {
// TODO(https://go.dev/issue/51192): ppc64 doesn't use r2.
r2 = 0
}
if errno != 0 {
releasem(getg().m)
allocmLock.unlock()
startTheWorld()
return r1, r2, errno
}
perThreadSyscall = perThreadSyscallArgs{
trap: trap,
a1: a1,
a2: a2,
a3: a3,
a4: a4,
a5: a5,
a6: a6,
r1: r1,
r2: r2,
}
// Wait for all threads to start.
//
// As described above, some Ms have been added to allm prior to
// allocmLock, but not yet completed OS clone and set procid.
//
// At minimum we must wait for a thread to set procid before we can
// send it a signal.
//
// We take this one step further and wait for all threads to start
// before sending any signals. This prevents system calls from getting
// applied twice: once in the parent and once in the child, like so:
//
// A B C
// add C to allm
// doAllThreadsSyscall
// allocmLock.lock()
// signal B
// <receive signal>
// execute syscall
// <signal return>
// clone C
// <thread start>
// set procid
// signal C
// <receive signal>
// execute syscall
// <signal return>
//
// In this case, thread C inherited the syscall-modified state from
// thread B and did not need to execute the syscall, but did anyway
// because doAllThreadsSyscall could not be sure whether it was
// required.
//
// Some system calls may not be idempotent, so we ensure each thread
// executes the system call exactly once.
for mp := allm; mp != nil; mp = mp.alllink {
for atomic.Load64(&mp.procid) == 0 {
// Thread is starting.
osyield()
}
}
// Signal every other thread, where they will execute perThreadSyscall
// from the signal handler.
gp := getg()
tid := gp.m.procid
for mp := allm; mp != nil; mp = mp.alllink {
if atomic.Load64(&mp.procid) == tid {
// Our thread already performed the syscall.
continue
}
mp.needPerThreadSyscall.Store(1)
signalM(mp, sigPerThreadSyscall)
}
// Wait for all threads to complete.
for mp := allm; mp != nil; mp = mp.alllink {
if mp.procid == tid {
continue
}
for mp.needPerThreadSyscall.Load() != 0 {
osyield()
}
}
perThreadSyscall = perThreadSyscallArgs{}
releasem(getg().m)
allocmLock.unlock()
startTheWorld()
return r1, r2, errno
}
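// Illustrative sketch (not part of the runtime): on Linux this function
// backs syscall.AllThreadsSyscall and syscall.AllThreadsSyscall6. A
// hypothetical caller applying a prctl to every runtime thread might
// look like the following (PR_SET_NO_NEW_PRIVS = 38 is an assumed
// constant; with cgo enabled the call is not supported):
//
//	package main
//
//	import (
//		"fmt"
//		"syscall"
//	)
//
//	func main() {
//		const prSetNoNewPrivs = 38 // PR_SET_NO_NEW_PRIVS (assumed)
//		_, _, errno := syscall.AllThreadsSyscall(syscall.SYS_PRCTL, prSetNoNewPrivs, 1, 0)
//		if errno != 0 {
//			fmt.Println("prctl failed:", errno)
//		}
//	}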
// runPerThreadSyscall runs perThreadSyscall for this M if required.
//
// This function throws if the system call returns with anything other than the
// expected values.
//
//go:nosplit
func runPerThreadSyscall() {
gp := getg()
if gp.m.needPerThreadSyscall.Load() == 0 {
return
}
args := perThreadSyscall
r1, r2, errno := syscall.Syscall6(args.trap, args.a1, args.a2, args.a3, args.a4, args.a5, args.a6)
if GOARCH == "ppc64" || GOARCH == "ppc64le" {
// TODO(https://go.dev/issue/51192): ppc64 doesn't use r2.
r2 = 0
}
if errno != 0 || r1 != args.r1 || r2 != args.r2 {
print("trap:", args.trap, ", a123456=[", args.a1, ",", args.a2, ",", args.a3, ",", args.a4, ",", args.a5, ",", args.a6, "]\n")
print("results: got {r1=", r1, ",r2=", r2, ",errno=", errno, "}, want {r1=", args.r1, ",r2=", args.r2, ",errno=0}\n")
fatal("AllThreadsSyscall6 results differ between threads; runtime corrupted")
}
gp.m.needPerThreadSyscall.Store(0)
}
const (
_SI_USER = 0
_SI_TKILL = -6
)
// sigFromUser reports whether the signal was sent because of a call
// to kill or tgkill.
//
//go:nosplit
func (c *sigctxt) sigFromUser() bool {
code := int32(c.sigcode())
return code == _SI_USER || code == _SI_TKILL
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !mips && !mipsle && !mips64 && !mips64le && !s390x && !ppc64 && linux
package runtime
const (
_SS_DISABLE = 2
_NSIG = 65
_SIG_BLOCK = 0
_SIG_UNBLOCK = 1
_SIG_SETMASK = 2
)
// It's hard to tease out exactly how big a Sigset is, but
// rt_sigprocmask crashes if we get it wrong, so if binaries
// are running, this is right.
type sigset [2]uint32
var sigset_all = sigset{^uint32(0), ^uint32(0)}
//go:nosplit
//go:nowritebarrierrec
func sigaddset(mask *sigset, i int) {
(*mask)[(i-1)/32] |= 1 << ((uint32(i) - 1) & 31)
}
func sigdelset(mask *sigset, i int) {
(*mask)[(i-1)/32] &^= 1 << ((uint32(i) - 1) & 31)
}
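// For example, for signal 34 (musl's SIGSYNCCALL), (i-1)/32 == 1 and
// (i-1)&31 == 1, so sigaddset sets bit 1 of mask[1] and sigdelset
// clears it again, as osinit does above for the libc-internal signals.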
//go:nosplit
func sigfillset(mask *uint64) {
*mask = ^uint64(0)
}
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build linux && !arm && !arm64 && !loong64 && !mips && !mipsle && !mips64 && !mips64le && !s390x && !ppc64 && !ppc64le
package runtime
func archauxv(tag, val uintptr) {
}
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build linux && (386 || amd64)
package runtime
func osArchInit() {}
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !openbsd
package runtime
// osStackAlloc performs OS-specific initialization before s is used
// as stack memory.
func osStackAlloc(s *mspan) {
}
// osStackFree undoes the effect of osStackAlloc before s is returned
// to the heap.
func osStackFree(s *mspan) {
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !goexperiment.pagetrace
package runtime
//go:systemstack
func pageTraceAlloc(pp *p, now int64, base, npages uintptr) {
}
//go:systemstack
func pageTraceFree(pp *p, now int64, base, npages uintptr) {
}
//go:systemstack
func pageTraceScav(pp *p, now int64, base, npages uintptr) {
}
type pageTraceBuf struct {
}
func initPageTrace(env string) {
}
func finishPageTrace() {
}
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"internal/goarch"
"runtime/internal/atomic"
"runtime/internal/sys"
"unsafe"
)
// throwType indicates the current type of ongoing throw, which affects the
// amount of detail printed to stderr. Higher values include more detail.
type throwType uint32
const (
// throwTypeNone means that we are not throwing.
throwTypeNone throwType = iota
// throwTypeUser is a throw due to a problem with the application.
//
// These throws do not include runtime frames, system goroutines, or
// frame metadata.
throwTypeUser
// throwTypeRuntime is a throw due to a problem with Go itself.
//
// These throws include as much information as possible to aid in
// debugging the runtime, including runtime frames, system goroutines,
// and frame metadata.
throwTypeRuntime
)
// We have two different ways of doing defers. The older way involves creating a
// defer record at the time that a defer statement is executing and adding it to a
// defer chain. This chain is inspected by the deferreturn call at all function
// exits in order to run the appropriate defer calls. A cheaper way (which we call
// open-coded defers) is used for functions in which no defer statements occur in
// loops. In that case, we simply store the defer function/arg information into
// specific stack slots at the point of each defer statement, as well as setting a
// bit in a bitmask. At each function exit, we add inline code to directly make
// the appropriate defer calls based on the bitmask and fn/arg information stored
// on the stack. During panic/Goexit processing, the appropriate defer calls are
// made using extra funcdata info that indicates the exact stack slots that
// contain the bitmask and defer fn/args.
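// Illustrative sketch (not part of the runtime): the open-coded defer
// bookkeeping can be simulated in ordinary Go. Here deferBits stands in
// for the per-frame bitmask and slots for the fixed stack slots; in the
// real scheme the compiler emits this bookkeeping, it is not built at
// run time.
//
//	package main
//
//	import "fmt"
//
//	func main() {
//		var deferBits uint8 // one bit per defer site in the frame
//		var slots [2]func() // closure stored per defer site
//
//		// defer site 0 executes: record the closure, set bit 0.
//		slots[0] = func() { fmt.Println("defer 0") }
//		deferBits |= 1 << 0
//
//		// defer site 1 is conditional: set bit 1 only if reached.
//		if true {
//			slots[1] = func() { fmt.Println("defer 1") }
//			deferBits |= 1 << 1
//		}
//
//		// Function exit: run recorded defers in reverse order,
//		// clearing each bit first, as runOpenDeferFrame does below.
//		for i := len(slots) - 1; i >= 0; i-- {
//			if deferBits&(1<<i) != 0 {
//				deferBits &^= 1 << i
//				slots[i]()
//			}
//		}
//	}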
// Check to make sure we can really generate a panic. If the panic
// was generated from the runtime, or from inside malloc, then convert
// to a throw of msg.
// pc should be the program counter of the compiler-generated code that
// triggered this panic.
func panicCheck1(pc uintptr, msg string) {
if goarch.IsWasm == 0 && hasPrefix(funcname(findfunc(pc)), "runtime.") {
// Note: wasm can't tail call, so we can't get the original caller's pc.
throw(msg)
}
// TODO: is this redundant? How could we be in malloc
// but not in the runtime? runtime/internal/*, maybe?
gp := getg()
if gp != nil && gp.m != nil && gp.m.mallocing != 0 {
throw(msg)
}
}
// Same as above, but calling from the runtime is allowed.
//
// Using this function is necessary for any panic that may be
// generated by runtime.sigpanic, since those are always called by the
// runtime.
func panicCheck2(err string) {
// panic allocates, so to avoid recursive malloc, turn panics
// during malloc into throws.
gp := getg()
if gp != nil && gp.m != nil && gp.m.mallocing != 0 {
throw(err)
}
}
// Many of the following panic entry-points turn into throws when they
// happen in various runtime contexts. These should never happen in
// the runtime, and if they do, they indicate a serious issue and
// should not be caught by user code.
//
// The panic{Index,Slice,divide,shift} functions are called by
// code generated by the compiler for out of bounds index expressions,
// out of bounds slice expressions, division by zero, and shift by negative.
// The panicdivide (again), panicoverflow, panicfloat, and panicmem
// functions are called by the signal handler when a signal occurs
// indicating the respective problem.
//
// Since panic{Index,Slice,shift} are never called directly, and
// since the runtime package should never have an out of bounds slice
// or array reference or negative shift, if we see those functions called from the
// runtime package we turn the panic into a throw. That will dump the
// entire runtime stack for easier debugging.
//
// The entry points called by the signal handler will be called from
// runtime.sigpanic, so we can't disallow calls from the runtime to
// these (they always look like they're called from the runtime).
// Hence, for these, we just check for clearly bad runtime conditions.
//
// The panic{Index,Slice} functions are implemented in assembly and tail call
// to the goPanic{Index,Slice} functions below. This is done so we can use
// a space-minimal register calling convention.
// failures in the comparisons for s[x], 0 <= x < y (y == len(s))
//
//go:yeswritebarrierrec
func goPanicIndex(x int, y int) {
panicCheck1(getcallerpc(), "index out of range")
panic(boundsError{x: int64(x), signed: true, y: y, code: boundsIndex})
}
//go:yeswritebarrierrec
func goPanicIndexU(x uint, y int) {
panicCheck1(getcallerpc(), "index out of range")
panic(boundsError{x: int64(x), signed: false, y: y, code: boundsIndex})
}
// failures in the comparisons for s[:x], 0 <= x <= y (y == len(s) or cap(s))
//
//go:yeswritebarrierrec
func goPanicSliceAlen(x int, y int) {
panicCheck1(getcallerpc(), "slice bounds out of range")
panic(boundsError{x: int64(x), signed: true, y: y, code: boundsSliceAlen})
}
//go:yeswritebarrierrec
func goPanicSliceAlenU(x uint, y int) {
panicCheck1(getcallerpc(), "slice bounds out of range")
panic(boundsError{x: int64(x), signed: false, y: y, code: boundsSliceAlen})
}
//go:yeswritebarrierrec
func goPanicSliceAcap(x int, y int) {
panicCheck1(getcallerpc(), "slice bounds out of range")
panic(boundsError{x: int64(x), signed: true, y: y, code: boundsSliceAcap})
}
//go:yeswritebarrierrec
func goPanicSliceAcapU(x uint, y int) {
panicCheck1(getcallerpc(), "slice bounds out of range")
panic(boundsError{x: int64(x), signed: false, y: y, code: boundsSliceAcap})
}
// failures in the comparisons for s[x:y], 0 <= x <= y
//
//go:yeswritebarrierrec
func goPanicSliceB(x int, y int) {
panicCheck1(getcallerpc(), "slice bounds out of range")
panic(boundsError{x: int64(x), signed: true, y: y, code: boundsSliceB})
}
//go:yeswritebarrierrec
func goPanicSliceBU(x uint, y int) {
panicCheck1(getcallerpc(), "slice bounds out of range")
panic(boundsError{x: int64(x), signed: false, y: y, code: boundsSliceB})
}
// failures in the comparisons for s[::x], 0 <= x <= y (y == len(s) or cap(s))
func goPanicSlice3Alen(x int, y int) {
panicCheck1(getcallerpc(), "slice bounds out of range")
panic(boundsError{x: int64(x), signed: true, y: y, code: boundsSlice3Alen})
}
func goPanicSlice3AlenU(x uint, y int) {
panicCheck1(getcallerpc(), "slice bounds out of range")
panic(boundsError{x: int64(x), signed: false, y: y, code: boundsSlice3Alen})
}
func goPanicSlice3Acap(x int, y int) {
panicCheck1(getcallerpc(), "slice bounds out of range")
panic(boundsError{x: int64(x), signed: true, y: y, code: boundsSlice3Acap})
}
func goPanicSlice3AcapU(x uint, y int) {
panicCheck1(getcallerpc(), "slice bounds out of range")
panic(boundsError{x: int64(x), signed: false, y: y, code: boundsSlice3Acap})
}
// failures in the comparisons for s[:x:y], 0 <= x <= y
func goPanicSlice3B(x int, y int) {
panicCheck1(getcallerpc(), "slice bounds out of range")
panic(boundsError{x: int64(x), signed: true, y: y, code: boundsSlice3B})
}
func goPanicSlice3BU(x uint, y int) {
panicCheck1(getcallerpc(), "slice bounds out of range")
panic(boundsError{x: int64(x), signed: false, y: y, code: boundsSlice3B})
}
// failures in the comparisons for s[x:y:], 0 <= x <= y
func goPanicSlice3C(x int, y int) {
panicCheck1(getcallerpc(), "slice bounds out of range")
panic(boundsError{x: int64(x), signed: true, y: y, code: boundsSlice3C})
}
func goPanicSlice3CU(x uint, y int) {
panicCheck1(getcallerpc(), "slice bounds out of range")
panic(boundsError{x: int64(x), signed: false, y: y, code: boundsSlice3C})
}
// failures in the conversion ([x]T)(s) or (*[x]T)(s), 0 <= x <= y, y == len(s)
func goPanicSliceConvert(x int, y int) {
panicCheck1(getcallerpc(), "slice length too short to convert to array or pointer to array")
panic(boundsError{x: int64(x), signed: true, y: y, code: boundsConvert})
}
// Implemented in assembly, as they take arguments in registers.
// Declared here to mark them as ABIInternal.
func panicIndex(x int, y int)
func panicIndexU(x uint, y int)
func panicSliceAlen(x int, y int)
func panicSliceAlenU(x uint, y int)
func panicSliceAcap(x int, y int)
func panicSliceAcapU(x uint, y int)
func panicSliceB(x int, y int)
func panicSliceBU(x uint, y int)
func panicSlice3Alen(x int, y int)
func panicSlice3AlenU(x uint, y int)
func panicSlice3Acap(x int, y int)
func panicSlice3AcapU(x uint, y int)
func panicSlice3B(x int, y int)
func panicSlice3BU(x uint, y int)
func panicSlice3C(x int, y int)
func panicSlice3CU(x uint, y int)
func panicSliceConvert(x int, y int)
var shiftError = error(errorString("negative shift amount"))
//go:yeswritebarrierrec
func panicshift() {
panicCheck1(getcallerpc(), "negative shift amount")
panic(shiftError)
}
var divideError = error(errorString("integer divide by zero"))
//go:yeswritebarrierrec
func panicdivide() {
panicCheck2("integer divide by zero")
panic(divideError)
}
var overflowError = error(errorString("integer overflow"))
func panicoverflow() {
panicCheck2("integer overflow")
panic(overflowError)
}
var floatError = error(errorString("floating point error"))
func panicfloat() {
panicCheck2("floating point error")
panic(floatError)
}
var memoryError = error(errorString("invalid memory address or nil pointer dereference"))
func panicmem() {
panicCheck2("invalid memory address or nil pointer dereference")
panic(memoryError)
}
func panicmemAddr(addr uintptr) {
panicCheck2("invalid memory address or nil pointer dereference")
panic(errorAddressString{msg: "invalid memory address or nil pointer dereference", addr: addr})
}
// Create a new deferred function fn, which has no arguments and results.
// The compiler turns a defer statement into a call to this.
func deferproc(fn func()) {
gp := getg()
if gp.m.curg != gp {
// go code on the system stack can't defer
throw("defer on system stack")
}
d := newdefer()
if d._panic != nil {
throw("deferproc: d.panic != nil after newdefer")
}
d.link = gp._defer
gp._defer = d
d.fn = fn
d.pc = getcallerpc()
// We must not be preempted between calling getcallersp and
// storing it to d.sp because getcallersp's result is a
// uintptr stack pointer.
d.sp = getcallersp()
// deferproc returns 0 normally.
// A deferred func that stops a panic
// makes the deferproc return 1.
// The code the compiler generates always
// checks the return value and jumps to the
// end of the function if deferproc returns != 0.
return0()
// No code can go here - the C return register has
// been set and must not be clobbered.
}
// deferprocStack queues a new deferred function with a defer record on the stack.
// The defer record must have its fn field initialized.
// All other fields can contain junk.
// Nosplit because of the uninitialized pointer fields on the stack.
//
//go:nosplit
func deferprocStack(d *_defer) {
gp := getg()
if gp.m.curg != gp {
// go code on the system stack can't defer
throw("defer on system stack")
}
// fn is already set.
// The other fields are junk on entry to deferprocStack and
// are initialized here.
d.started = false
d.heap = false
d.openDefer = false
d.sp = getcallersp()
d.pc = getcallerpc()
d.framepc = 0
d.varp = 0
// The lines below implement:
// d._panic = nil
// d.fd = nil
// d.link = gp._defer
// gp._defer = d
// But without write barriers. The first three are writes to
// the stack so they don't need a write barrier, and furthermore
// are to uninitialized memory, so they must not use a write barrier.
// The fourth write does not require a write barrier because we
// explicitly mark all the defer structures, so we don't need to
// keep track of pointers to them with a write barrier.
*(*uintptr)(unsafe.Pointer(&d._panic)) = 0
*(*uintptr)(unsafe.Pointer(&d.fd)) = 0
*(*uintptr)(unsafe.Pointer(&d.link)) = uintptr(unsafe.Pointer(gp._defer))
*(*uintptr)(unsafe.Pointer(&gp._defer)) = uintptr(unsafe.Pointer(d))
return0()
// No code can go here - the C return register has
// been set and must not be clobbered.
}
// Each P holds a pool for defers.
// Allocate a Defer, usually using per-P pool.
// Each defer must be released with freedefer. The defer is not
// added to any defer chain yet.
func newdefer() *_defer {
var d *_defer
mp := acquirem()
pp := mp.p.ptr()
if len(pp.deferpool) == 0 && sched.deferpool != nil {
lock(&sched.deferlock)
for len(pp.deferpool) < cap(pp.deferpool)/2 && sched.deferpool != nil {
d := sched.deferpool
sched.deferpool = d.link
d.link = nil
pp.deferpool = append(pp.deferpool, d)
}
unlock(&sched.deferlock)
}
if n := len(pp.deferpool); n > 0 {
d = pp.deferpool[n-1]
pp.deferpool[n-1] = nil
pp.deferpool = pp.deferpool[:n-1]
}
releasem(mp)
mp, pp = nil, nil
if d == nil {
// Allocate new defer.
d = new(_defer)
}
d.heap = true
return d
}
// Free the given defer.
// The defer cannot be used after this call.
//
// This is nosplit because the incoming defer is in a perilous state.
// It's not on any defer list, so stack copying won't adjust stack
// pointers in it (namely, d.link). Hence, if we were to copy the
// stack, d could then contain a stale pointer.
//
//go:nosplit
func freedefer(d *_defer) {
d.link = nil
// After this point we can copy the stack.
if d._panic != nil {
freedeferpanic()
}
if d.fn != nil {
freedeferfn()
}
if !d.heap {
return
}
mp := acquirem()
pp := mp.p.ptr()
if len(pp.deferpool) == cap(pp.deferpool) {
// Transfer half of local cache to the central cache.
var first, last *_defer
for len(pp.deferpool) > cap(pp.deferpool)/2 {
n := len(pp.deferpool)
d := pp.deferpool[n-1]
pp.deferpool[n-1] = nil
pp.deferpool = pp.deferpool[:n-1]
if first == nil {
first = d
} else {
last.link = d
}
last = d
}
lock(&sched.deferlock)
last.link = sched.deferpool
sched.deferpool = first
unlock(&sched.deferlock)
}
*d = _defer{}
pp.deferpool = append(pp.deferpool, d)
releasem(mp)
mp, pp = nil, nil
}
// Separate function so that it can split stack.
// Windows otherwise runs out of stack space.
func freedeferpanic() {
// _panic must be cleared before d is unlinked from gp.
throw("freedefer with d._panic != nil")
}
func freedeferfn() {
// fn must be cleared before d is unlinked from gp.
throw("freedefer with d.fn != nil")
}
// deferreturn runs deferred functions for the caller's frame.
// The compiler inserts a call to this at the end of any
// function which calls defer.
func deferreturn() {
gp := getg()
for {
d := gp._defer
if d == nil {
return
}
sp := getcallersp()
if d.sp != sp {
return
}
if d.openDefer {
done := runOpenDeferFrame(d)
if !done {
throw("unfinished open-coded defers in deferreturn")
}
gp._defer = d.link
freedefer(d)
// If this frame uses open defers, then this
// must be the only defer record for the
// frame, so we can just return.
return
}
fn := d.fn
d.fn = nil
gp._defer = d.link
freedefer(d)
fn()
}
}
// Goexit terminates the goroutine that calls it. No other goroutine is affected.
// Goexit runs all deferred calls before terminating the goroutine. Because Goexit
// is not a panic, any recover calls in those deferred functions will return nil.
//
// Calling Goexit from the main goroutine terminates that goroutine
// without func main returning. Since func main has not returned,
// the program continues execution of other goroutines.
// If all other goroutines exit, the program crashes.
func Goexit() {
// Run all deferred functions for the current goroutine.
// This code is similar to gopanic, see that implementation
// for detailed comments.
gp := getg()
// Create a panic object for Goexit, so we can recognize when it might be
// bypassed by a recover().
var p _panic
p.goexit = true
p.link = gp._panic
gp._panic = (*_panic)(noescape(unsafe.Pointer(&p)))
addOneOpenDeferFrame(gp, getcallerpc(), unsafe.Pointer(getcallersp()))
for {
d := gp._defer
if d == nil {
break
}
if d.started {
if d._panic != nil {
d._panic.aborted = true
d._panic = nil
}
if !d.openDefer {
d.fn = nil
gp._defer = d.link
freedefer(d)
continue
}
}
d.started = true
d._panic = (*_panic)(noescape(unsafe.Pointer(&p)))
if d.openDefer {
done := runOpenDeferFrame(d)
if !done {
// We should always run all defers in the frame,
// since there is no panic associated with this
// defer that can be recovered.
throw("unfinished open-coded defers in Goexit")
}
if p.aborted {
// Since our current defer caused a panic and may
// have been already freed, just restart scanning
// for open-coded defers from this frame again.
addOneOpenDeferFrame(gp, getcallerpc(), unsafe.Pointer(getcallersp()))
} else {
addOneOpenDeferFrame(gp, 0, nil)
}
} else {
// Save the pc/sp in deferCallSave(), so we can "recover" back to this
// loop if necessary.
deferCallSave(&p, d.fn)
}
if p.aborted {
// We had a recursive panic in the defer d we started, and
// then did a recover in a defer that was further down the
// defer chain than d. In the case of an outstanding Goexit,
// we force the recover to return back to this loop. d will
// have already been freed if completed, so just continue
// immediately to the next defer on the chain.
p.aborted = false
continue
}
if gp._defer != d {
throw("bad defer entry in Goexit")
}
d._panic = nil
d.fn = nil
gp._defer = d.link
freedefer(d)
// Note: we ignore recovers here because Goexit isn't a panic
}
goexit1()
}
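// Illustrative sketch (not part of the runtime): because Goexit is not a
// panic, recover in the deferred calls it runs returns nil.
//
//	package main
//
//	import (
//		"fmt"
//		"runtime"
//	)
//
//	func main() {
//		done := make(chan struct{})
//		go func() {
//			defer close(done)
//			defer func() {
//				fmt.Println("recover() ->", recover()) // prints <nil>
//			}()
//			runtime.Goexit()
//			fmt.Println("unreachable") // never runs
//		}()
//		<-done
//	}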
// Call all Error and String methods before freezing the world.
// Used when crashing with panicking.
func preprintpanics(p *_panic) {
defer func() {
text := "panic while printing panic value"
switch r := recover().(type) {
case nil:
// nothing to do
case string:
throw(text + ": " + r)
default:
throw(text + ": type " + efaceOf(&r)._type.string())
}
}()
for p != nil {
switch v := p.arg.(type) {
case error:
p.arg = v.Error()
case stringer:
p.arg = v.String()
}
p = p.link
}
}
// Print all currently active panics. Used when crashing.
// Should only be called after preprintpanics.
func printpanics(p *_panic) {
if p.link != nil {
printpanics(p.link)
if !p.link.goexit {
print("\t")
}
}
if p.goexit {
return
}
print("panic: ")
printany(p.arg)
if p.recovered {
print(" [recovered]")
}
print("\n")
}
// addOneOpenDeferFrame scans the stack (in gentraceback order, from inner frames to
// outer frames) for the first frame (if any) with open-coded defers. If it finds
// one, it adds a single entry to the defer chain for that frame. The entry added
// represents all the defers in the associated open defer frame, and is sorted in
// order with respect to any non-open-coded defers.
//
// addOneOpenDeferFrame stops (possibly without adding a new entry) if it encounters
// an in-progress open defer entry. An in-progress open defer entry means there has
// been a new panic because of a defer in the associated frame. addOneOpenDeferFrame
// does not add an open defer entry past a started entry, because that started entry
// still needs to be finished, and addOneOpenDeferFrame will be called when that started
// entry is completed. The defer removal loop in gopanic() similarly stops at an
// in-progress defer entry. Together, addOneOpenDeferFrame and the defer removal loop
// ensure the invariant that there is no open defer entry further up the stack than
// an in-progress defer, and also that the defer removal loop is guaranteed to remove
// all not-in-progress open defer entries from the defer chain.
//
// If sp is non-nil, addOneOpenDeferFrame starts the stack scan from the frame
// specified by sp. If sp is nil, it uses the sp from the current defer record (which
// has just been finished). Hence, it continues the stack scan from the frame of the
// defer that just finished. It skips any frame that already has a (not-in-progress)
// open-coded _defer record in the defer chain.
//
// Note: All entries of the defer chain (including this new open-coded entry) have
// their pointers (including sp) adjusted properly if the stack moves while
// running deferred functions. Also, it is safe to pass in the sp arg (which is
// the direct result of calling getcallersp()), because all pointer variables
// (including arguments) are adjusted as needed during stack copies.
func addOneOpenDeferFrame(gp *g, pc uintptr, sp unsafe.Pointer) {
var prevDefer *_defer
if sp == nil {
prevDefer = gp._defer
pc = prevDefer.framepc
sp = unsafe.Pointer(prevDefer.sp)
}
systemstack(func() {
gentraceback(pc, uintptr(sp), 0, gp, 0, nil, 0x7fffffff,
func(frame *stkframe, unused unsafe.Pointer) bool {
if prevDefer != nil && prevDefer.sp == frame.sp {
// Skip the frame for the previous defer that
// we just finished (and was used to set
// where we restarted the stack scan)
return true
}
f := frame.fn
fd := funcdata(f, _FUNCDATA_OpenCodedDeferInfo)
if fd == nil {
return true
}
// Insert the open defer record in the
// chain, in order sorted by sp.
d := gp._defer
var prev *_defer
for d != nil {
dsp := d.sp
if frame.sp < dsp {
break
}
if frame.sp == dsp {
if !d.openDefer {
throw("duplicated defer entry")
}
// Don't add any record past an
// in-progress defer entry. We don't
// need it, and more importantly, we
// want to keep the invariant that
// there is no open defer entry
// past an in-progress entry (see
// header comment).
if d.started {
return false
}
return true
}
prev = d
d = d.link
}
if frame.fn.deferreturn == 0 {
throw("missing deferreturn")
}
d1 := newdefer()
d1.openDefer = true
d1._panic = nil
// These are the pc/sp to set after we've
// run a defer in this frame that did a
// recover. We return to a special
// deferreturn that runs any remaining
// defers and then returns from the
// function.
d1.pc = frame.fn.entry() + uintptr(frame.fn.deferreturn)
d1.varp = frame.varp
d1.fd = fd
// Save the SP/PC associated with current frame,
// so we can continue stack trace later if needed.
d1.framepc = frame.pc
d1.sp = frame.sp
d1.link = d
if prev == nil {
gp._defer = d1
} else {
prev.link = d1
}
// Stop stack scanning after adding one open defer record
return false
},
nil, 0)
})
}
// readvarintUnsafe reads the uint32 in varint format starting at fd, and returns the
// uint32 and a pointer to the byte following the varint.
//
// There is a similar function runtime.readvarint, which takes a slice of bytes,
// rather than an unsafe pointer. These functions are duplicated, because one of
// the two use cases for the functions would get slower if the functions were
// combined.
func readvarintUnsafe(fd unsafe.Pointer) (uint32, unsafe.Pointer) {
var r uint32
var shift int
for {
b := *(*uint8)((unsafe.Pointer(fd)))
fd = add(fd, unsafe.Sizeof(b))
if b < 128 {
return r + uint32(b)<<shift, fd
}
r += ((uint32(b) &^ 128) << shift)
shift += 7
if shift > 28 {
panic("Bad varint")
}
}
}
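// Illustrative sketch (not part of the runtime): the same base-128
// varint decoding over a byte slice, in the spirit of runtime.readvarint
// (decodeVarint is a hypothetical name):
//
//	func decodeVarint(p []byte) (v uint32, n int) {
//		var shift uint
//		for n < len(p) {
//			b := p[n]
//			n++
//			if b < 128 {
//				return v + uint32(b)<<shift, n
//			}
//			v += (uint32(b) &^ 128) << shift
//			shift += 7
//		}
//		panic("Bad varint")
//	}
//
// For example, decodeVarint([]byte{0x96, 0x01}) returns (150, 2):
// the first byte contributes 0x96&^0x80 == 22, the second 1<<7 == 128,
// and 22+128 == 150.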
// runOpenDeferFrame runs the active open-coded defers in the frame specified by
// d. It normally processes all active defers in the frame, but stops immediately
// if a defer does a successful recover. It returns true if there are no
// remaining defers to run in the frame.
func runOpenDeferFrame(d *_defer) bool {
done := true
fd := d.fd
deferBitsOffset, fd := readvarintUnsafe(fd)
nDefers, fd := readvarintUnsafe(fd)
deferBits := *(*uint8)(unsafe.Pointer(d.varp - uintptr(deferBitsOffset)))
for i := int(nDefers) - 1; i >= 0; i-- {
// read the funcdata info for this defer
var closureOffset uint32
closureOffset, fd = readvarintUnsafe(fd)
if deferBits&(1<<i) == 0 {
continue
}
closure := *(*func())(unsafe.Pointer(d.varp - uintptr(closureOffset)))
d.fn = closure
deferBits = deferBits &^ (1 << i)
*(*uint8)(unsafe.Pointer(d.varp - uintptr(deferBitsOffset))) = deferBits
p := d._panic
// Call the defer. Note that this can change d.varp if
// the stack moves.
deferCallSave(p, d.fn)
if p != nil && p.aborted {
break
}
d.fn = nil
if d._panic != nil && d._panic.recovered {
done = deferBits == 0
break
}
}
return done
}
// deferCallSave calls fn() after saving the caller's pc and sp in the
// panic record. This allows the runtime to return to the Goexit defer
// processing loop, in the unusual case where the Goexit may be
// bypassed by a successful recover.
//
// This is marked as a wrapper by the compiler so it doesn't appear in
// tracebacks.
func deferCallSave(p *_panic, fn func()) {
if p != nil {
p.argp = unsafe.Pointer(getargp())
p.pc = getcallerpc()
p.sp = unsafe.Pointer(getcallersp())
}
fn()
if p != nil {
p.pc = 0
p.sp = unsafe.Pointer(nil)
}
}
// A PanicNilError happens when code calls panic(nil).
//
// Before Go 1.21, programs that called panic(nil) observed recover returning nil.
// Starting in Go 1.21, programs that call panic(nil) observe recover returning a *PanicNilError.
// Programs can change back to the old behavior by setting GODEBUG=panicnil=1.
type PanicNilError struct {
// This field makes PanicNilError structurally different from
// any other struct in this package, and the _ makes it different
// from any struct in other packages too.
// This avoids any accidental conversions being possible
// between this struct and some other struct sharing the same fields,
// like happened in go.dev/issue/56603.
_ [0]*PanicNilError
}
func (*PanicNilError) Error() string { return "panic called with nil argument" }
func (*PanicNilError) RuntimeError() {}
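// Illustrative sketch (not part of the runtime): observing the Go 1.21
// panic(nil) behavior from user code.
//
//	package main
//
//	import (
//		"fmt"
//		"runtime"
//	)
//
//	func main() {
//		defer func() {
//			r := recover()
//			_, isNilErr := r.(*runtime.PanicNilError)
//			// isNilErr is true by default; with GODEBUG=panicnil=1,
//			// r is nil instead.
//			fmt.Println(r, isNilErr)
//		}()
//		panic(nil)
//	}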
var panicnil = &godebugInc{name: "panicnil"}
// The implementation of the predeclared function panic.
func gopanic(e any) {
if e == nil {
if debug.panicnil.Load() != 1 {
e = new(PanicNilError)
} else {
panicnil.IncNonDefault()
}
}
gp := getg()
if gp.m.curg != gp {
print("panic: ")
printany(e)
print("\n")
throw("panic on system stack")
}
if gp.m.mallocing != 0 {
print("panic: ")
printany(e)
print("\n")
throw("panic during malloc")
}
if gp.m.preemptoff != "" {
print("panic: ")
printany(e)
print("\n")
print("preempt off reason: ")
print(gp.m.preemptoff)
print("\n")
throw("panic during preemptoff")
}
if gp.m.locks != 0 {
print("panic: ")
printany(e)
print("\n")
throw("panic holding locks")
}
var p _panic
p.arg = e
p.link = gp._panic
gp._panic = (*_panic)(noescape(unsafe.Pointer(&p)))
runningPanicDefers.Add(1)
// By calculating getcallerpc/getcallersp here, we avoid scanning the
// gopanic frame (stack scanning is slow...)
addOneOpenDeferFrame(gp, getcallerpc(), unsafe.Pointer(getcallersp()))
for {
d := gp._defer
if d == nil {
break
}
// If defer was started by earlier panic or Goexit (and, since we're back here, that triggered a new panic),
// take defer off list. An earlier panic will not continue running, but we will make sure below that an
// earlier Goexit does continue running.
if d.started {
if d._panic != nil {
d._panic.aborted = true
}
d._panic = nil
if !d.openDefer {
// For open-coded defers, we need to process the
// defer again, in case there are any other defers
// to call in the frame (not including the defer
// call that caused the panic).
d.fn = nil
gp._defer = d.link
freedefer(d)
continue
}
}
// Mark defer as started, but keep on list, so that traceback
// can find and update the defer's argument frame if stack growth
// or a garbage collection happens before executing d.fn.
d.started = true
// Record the panic that is running the defer.
// If there is a new panic during the deferred call, that panic
// will find d in the list and will mark d._panic (this panic) aborted.
d._panic = (*_panic)(noescape(unsafe.Pointer(&p)))
done := true
if d.openDefer {
done = runOpenDeferFrame(d)
if done && !d._panic.recovered {
addOneOpenDeferFrame(gp, 0, nil)
}
} else {
p.argp = unsafe.Pointer(getargp())
d.fn()
}
p.argp = nil
// Deferred function did not panic. Remove d.
if gp._defer != d {
throw("bad defer entry in panic")
}
d._panic = nil
// trigger shrinkage to test stack copy. See stack_test.go:TestStackPanic
//GC()
pc := d.pc
sp := unsafe.Pointer(d.sp) // must be pointer so it gets adjusted during stack copy
if done {
d.fn = nil
gp._defer = d.link
freedefer(d)
}
if p.recovered {
gp._panic = p.link
if gp._panic != nil && gp._panic.goexit && gp._panic.aborted {
// A normal recover would bypass/abort the Goexit. Instead,
// we return to the processing loop of the Goexit.
gp.sigcode0 = uintptr(gp._panic.sp)
gp.sigcode1 = uintptr(gp._panic.pc)
mcall(recovery)
throw("bypassed recovery failed") // mcall should not return
}
runningPanicDefers.Add(-1)
// After a recover, remove any remaining non-started,
// open-coded defer entries, since the corresponding defers
// will be executed normally (inline). Any such entry will
// become stale once we run the corresponding defers inline
// and exit the associated stack frame. We only remove up to
// the first started (in-progress) open defer entry, not
// including the current frame, since any higher entries will
// be from a higher panic in progress, and will still be
// needed.
d := gp._defer
var prev *_defer
if !done {
// Skip our current frame, if not done. It is
// needed to complete any remaining defers in
// deferreturn()
prev = d
d = d.link
}
for d != nil {
if d.started {
// This defer is started but we
// are in the middle of a
// defer-panic-recover inside of
// it, so don't remove it or any
// further defer entries
break
}
if d.openDefer {
if prev == nil {
gp._defer = d.link
} else {
prev.link = d.link
}
newd := d.link
freedefer(d)
d = newd
} else {
prev = d
d = d.link
}
}
gp._panic = p.link
// Aborted panics are marked but remain on the g.panic list.
// Remove them from the list.
for gp._panic != nil && gp._panic.aborted {
gp._panic = gp._panic.link
}
if gp._panic == nil { // must be done with signal
gp.sig = 0
}
// Pass information about recovering frame to recovery.
gp.sigcode0 = uintptr(sp)
gp.sigcode1 = pc
mcall(recovery)
throw("recovery failed") // mcall should not return
}
}
// ran out of deferred calls - old-school panic now
// Because it is unsafe to call arbitrary user code after freezing
// the world, we call preprintpanics to invoke all necessary Error
// and String methods to prepare the panic strings before startpanic.
preprintpanics(gp._panic)
fatalpanic(gp._panic) // should not return
*(*int)(nil) = 0 // not reached
}
// getargp returns the location where the caller
// writes outgoing function call arguments.
//
//go:nosplit
//go:noinline
func getargp() uintptr {
return getcallersp() + sys.MinFrameSize
}
// The implementation of the predeclared function recover.
// Cannot split the stack because it needs to reliably
// find the stack segment of its caller.
//
// TODO(rsc): Once we commit to CopyStackAlways,
// this doesn't need to be nosplit.
//
//go:nosplit
func gorecover(argp uintptr) any {
// Must be in a function running as part of a deferred call during the panic.
// Must be called from the topmost function of the call
// (the function used in the defer statement).
// p.argp is the argument pointer of that topmost deferred function call.
// Compare against argp reported by caller.
// If they match, the caller is the one who can recover.
gp := getg()
p := gp._panic
if p != nil && !p.goexit && !p.recovered && argp == uintptr(p.argp) {
p.recovered = true
return p.arg
}
return nil
}
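// Illustrative sketch (not part of the runtime): the argp comparison
// above is why recover only takes effect when called directly by a
// deferred function.
//
//	package main
//
//	import "fmt"
//
//	func indirectRecover() {
//		// Called one frame below the deferred function, so its caller's
//		// argp does not match p.argp and this recover returns nil.
//		if r := recover(); r != nil {
//			fmt.Println("never reached:", r)
//		}
//	}
//
//	func main() {
//		defer func() {
//			fmt.Println("recovered:", recover()) // direct call: succeeds
//		}()
//		defer func() {
//			indirectRecover() // runs first, recovers nothing
//		}()
//		panic("boom")
//	}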
//go:linkname sync_throw sync.throw
func sync_throw(s string) {
throw(s)
}
//go:linkname sync_fatal sync.fatal
func sync_fatal(s string) {
fatal(s)
}
// throw triggers a fatal error that dumps a stack trace and exits.
//
// throw should be used for runtime-internal fatal errors where Go itself,
// rather than user code, may be at fault for the failure.
//
//go:nosplit
func throw(s string) {
// Everything throw does should be recursively nosplit so it
// can be called even when it's unsafe to grow the stack.
systemstack(func() {
print("fatal error: ", s, "\n")
})
fatalthrow(throwTypeRuntime)
}
// fatal triggers a fatal error that dumps a stack trace and exits.
//
// fatal is equivalent to throw, but is used when user code is expected to be
// at fault for the failure, such as racing map writes.
//
// fatal does not include runtime frames, system goroutines, or frame metadata
// (fp, sp, pc) in the stack trace unless GOTRACEBACK=system or higher.
//
//go:nosplit
func fatal(s string) {
// Everything fatal does should be recursively nosplit so it
// can be called even when it's unsafe to grow the stack.
systemstack(func() {
print("fatal error: ", s, "\n")
})
fatalthrow(throwTypeUser)
}
// runningPanicDefers is non-zero while running deferred functions for panic.
// This is used to try hard to get a panic stack trace out when exiting.
var runningPanicDefers atomic.Uint32
// panicking is non-zero when crashing the program for an unrecovered panic.
var panicking atomic.Uint32
// paniclk is held while printing the panic information and stack trace,
// so that two concurrent panics don't overlap their output.
var paniclk mutex
// Unwind the stack after a deferred function calls recover
// after a panic. Then arrange to continue running as though
// the caller of the deferred function returned normally.
func recovery(gp *g) {
// Info about defer passed in G struct.
sp := gp.sigcode0
pc := gp.sigcode1
// d's arguments need to be in the stack.
if sp != 0 && (sp < gp.stack.lo || gp.stack.hi < sp) {
print("recover: ", hex(sp), " not in [", hex(gp.stack.lo), ", ", hex(gp.stack.hi), "]\n")
throw("bad recovery")
}
// Make the deferproc for this d return again,
// this time returning 1. The calling function will
// jump to the standard return epilogue.
gp.sched.sp = sp
gp.sched.pc = pc
gp.sched.lr = 0
gp.sched.ret = 1
gogo(&gp.sched)
}
// fatalthrow implements an unrecoverable runtime throw. It freezes the
// system, prints stack traces starting from its caller, and terminates the
// process.
//
//go:nosplit
func fatalthrow(t throwType) {
pc := getcallerpc()
sp := getcallersp()
gp := getg()
if gp.m.throwing == throwTypeNone {
gp.m.throwing = t
}
// Switch to the system stack to avoid any stack growth, which may make
// things worse if the runtime is in a bad state.
systemstack(func() {
startpanic_m()
if dopanic_m(gp, pc, sp) {
// crash uses a decent amount of nosplit stack and we're already
// low on stack in throw, so crash on the system stack (unlike
// fatalpanic).
crash()
}
exit(2)
})
*(*int)(nil) = 0 // not reached
}
// fatalpanic implements an unrecoverable panic. It is like fatalthrow, except
// that if msgs != nil, fatalpanic also prints panic messages and decrements
// runningPanicDefers once main is blocked from exiting.
//
//go:nosplit
func fatalpanic(msgs *_panic) {
pc := getcallerpc()
sp := getcallersp()
gp := getg()
var docrash bool
// Switch to the system stack to avoid any stack growth, which
// may make things worse if the runtime is in a bad state.
systemstack(func() {
if startpanic_m() && msgs != nil {
// There were panic messages and startpanic_m
// says it's okay to try to print them.
// startpanic_m set panicking, which will
// block main from exiting, so now OK to
// decrement runningPanicDefers.
runningPanicDefers.Add(-1)
printpanics(msgs)
}
docrash = dopanic_m(gp, pc, sp)
})
if docrash {
// By crashing outside the above systemstack call, debuggers
// will not be confused when generating a backtrace.
// Function crash is marked nosplit to avoid stack growth.
crash()
}
systemstack(func() {
exit(2)
})
*(*int)(nil) = 0 // not reached
}
// startpanic_m prepares for an unrecoverable panic.
//
// It returns true if panic messages should be printed, or false if
// the runtime is in bad shape and should just print stacks.
//
// It must not have write barriers even though the write barrier
// explicitly ignores writes once dying > 0. Write barriers still
// assume that g.m.p != nil, and this function may not have P
// in some contexts (e.g. a panic in a signal handler for a signal
// sent to an M with no P).
//
//go:nowritebarrierrec
func startpanic_m() bool {
gp := getg()
if mheap_.cachealloc.size == 0 { // very early
print("runtime: panic before malloc heap initialized\n")
}
// Disallow malloc during an unrecoverable panic. A panic
// could happen in a signal handler, or in a throw, or inside
// malloc itself. We want to catch if an allocation ever does
// happen (even if we're not in one of these situations).
gp.m.mallocing++
// If we're dying because of a bad lock count, set it to a
// good lock count so we don't recursively panic below.
if gp.m.locks < 0 {
gp.m.locks = 1
}
switch gp.m.dying {
case 0:
// Setting dying >0 has the side-effect of disabling this G's writebuf.
gp.m.dying = 1
panicking.Add(1)
lock(&paniclk)
if debug.schedtrace > 0 || debug.scheddetail > 0 {
schedtrace(true)
}
freezetheworld()
return true
case 1:
// Something failed while panicking.
// Just print a stack trace and exit.
gp.m.dying = 2
print("panic during panic\n")
return false
case 2:
// This is a genuine bug in the runtime, we couldn't even
// print the stack trace successfully.
gp.m.dying = 3
print("stack trace unavailable\n")
exit(4)
fallthrough
default:
// Can't even print! Just exit.
exit(5)
return false // Need to return something.
}
}
var didothers bool
var deadlock mutex
// gp is the crashing g running on this M, but may be a user G, while getg() is
// always g0.
func dopanic_m(gp *g, pc, sp uintptr) bool {
if gp.sig != 0 {
signame := signame(gp.sig)
if signame != "" {
print("[signal ", signame)
} else {
print("[signal ", hex(gp.sig))
}
print(" code=", hex(gp.sigcode0), " addr=", hex(gp.sigcode1), " pc=", hex(gp.sigpc), "]\n")
}
level, all, docrash := gotraceback()
if level > 0 {
if gp != gp.m.curg {
all = true
}
if gp != gp.m.g0 {
print("\n")
goroutineheader(gp)
traceback(pc, sp, 0, gp)
} else if level >= 2 || gp.m.throwing >= throwTypeRuntime {
print("\nruntime stack:\n")
traceback(pc, sp, 0, gp)
}
if !didothers && all {
didothers = true
tracebackothers(gp)
}
}
unlock(&paniclk)
if panicking.Add(-1) != 0 {
// Some other m is panicking too.
// Let it print what it needs to print.
// Wait forever without chewing up cpu.
// It will exit when it's done.
lock(&deadlock)
lock(&deadlock)
}
printDebugLog()
return docrash
}
// canpanic returns false if a signal should throw instead of
// panicking.
//
//go:nosplit
func canpanic() bool {
gp := getg()
mp := acquirem()
// Is it okay for gp to panic instead of crashing the program?
// Yes, as long as it is running Go code, not runtime code,
// and not stuck in a system call.
if gp != mp.curg {
releasem(mp)
return false
}
// N.B. mp.locks != 1 instead of 0 to account for acquirem.
if mp.locks != 1 || mp.mallocing != 0 || mp.throwing != throwTypeNone || mp.preemptoff != "" || mp.dying != 0 {
releasem(mp)
return false
}
status := readgstatus(gp)
if status&^_Gscan != _Grunning || gp.syscallsp != 0 {
releasem(mp)
return false
}
if GOOS == "windows" && mp.libcallsp != 0 {
releasem(mp)
return false
}
releasem(mp)
return true
}
// shouldPushSigpanic reports whether pc should be used as sigpanic's
// return PC (pushing a frame for the call). Otherwise, it should be
// left alone so that LR is used as sigpanic's return PC, effectively
// replacing the top-most frame with sigpanic. This is used by
// preparePanic.
func shouldPushSigpanic(gp *g, pc, lr uintptr) bool {
if pc == 0 {
// Probably a call to a nil func. The old LR is more
// useful in the stack trace. Not pushing the frame
// will make the trace look like a call to sigpanic
// instead. (Otherwise the trace will end at sigpanic
// and we won't get to see who faulted.)
return false
}
// If we don't recognize the PC as code, but we do recognize
// the link register as code, then this assumes the panic was
// caused by a call to non-code. In this case, we want to
// ignore this call to make unwinding show the context.
//
// If we're running C code, we're not going to recognize pc as a
// Go function, so just assume it's good. Otherwise, traceback
// may try to read a stale LR that looks like a Go code
// pointer and wander into the woods.
if gp.m.incgo || findfunc(pc).valid() {
// This wasn't a bad call, so use PC as sigpanic's
// return PC.
return true
}
if findfunc(lr).valid() {
// This was a bad call, but the LR is good, so use the
// LR as sigpanic's return PC.
return false
}
// Neither the PC nor the LR is good. Hopefully pushing a frame
// will work.
return true
}
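// Summary of the decision above, with pc the faulting PC and lr the link
// register (derived from the cases in this function):
//
//	pc == 0                        -> false: keep LR, so the trace shows
//	                                  who called the nil function
//	pc is Go code (or running cgo) -> true: pc is a plausible call site,
//	                                  push a sigpanic frame
//	pc bad, lr is Go code          -> false: a call to non-code; LR is the
//	                                  real return PC
//	neither recognized             -> true: push a frame and hope for the best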
// isAbortPC reports whether pc is the program counter at which
// runtime.abort raises a signal.
//
// It is nosplit because it's part of the isgoexception
// implementation.
//
//go:nosplit
func isAbortPC(pc uintptr) bool {
f := findfunc(pc)
if !f.valid() {
return false
}
return f.funcID == funcID_abort
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import "unsafe"
//go:linkname plugin_lastmoduleinit plugin.lastmoduleinit
func plugin_lastmoduleinit() (path string, syms map[string]any, initTasks []*initTask, errstr string) {
var md *moduledata
for pmd := firstmoduledata.next; pmd != nil; pmd = pmd.next {
if pmd.bad {
md = nil // we only want the last module
continue
}
md = pmd
}
if md == nil {
throw("runtime: no plugin module data")
}
if md.pluginpath == "" {
throw("runtime: plugin has empty pluginpath")
}
if md.typemap != nil {
return "", nil, nil, "plugin already loaded"
}
for _, pmd := range activeModules() {
if pmd.pluginpath == md.pluginpath {
md.bad = true
return "", nil, nil, "plugin already loaded"
}
if inRange(pmd.text, pmd.etext, md.text, md.etext) ||
inRange(pmd.bss, pmd.ebss, md.bss, md.ebss) ||
inRange(pmd.data, pmd.edata, md.data, md.edata) ||
inRange(pmd.types, pmd.etypes, md.types, md.etypes) {
println("plugin: new module data overlaps with previous moduledata")
println("\tpmd.text-etext=", hex(pmd.text), "-", hex(pmd.etext))
println("\tpmd.bss-ebss=", hex(pmd.bss), "-", hex(pmd.ebss))
println("\tpmd.data-edata=", hex(pmd.data), "-", hex(pmd.edata))
println("\tpmd.types-etypes=", hex(pmd.types), "-", hex(pmd.etypes))
println("\tmd.text-etext=", hex(md.text), "-", hex(md.etext))
println("\tmd.bss-ebss=", hex(md.bss), "-", hex(md.ebss))
println("\tmd.data-edata=", hex(md.data), "-", hex(md.edata))
println("\tmd.types-etypes=", hex(md.types), "-", hex(md.etypes))
throw("plugin: new module data overlaps with previous moduledata")
}
}
for _, pkghash := range md.pkghashes {
if pkghash.linktimehash != *pkghash.runtimehash {
md.bad = true
return "", nil, nil, "plugin was built with a different version of package " + pkghash.modulename
}
}
// Initialize the freshly loaded module.
modulesinit()
typelinksinit()
pluginftabverify(md)
moduledataverify1(md)
lock(&itabLock)
for _, i := range md.itablinks {
itabAdd(i)
}
unlock(&itabLock)
// Build a map of symbol names to symbols. Here in the runtime
// we fill out the first word of the interface, the type. We
// pass these zero value interfaces to the plugin package,
// where the symbol value is filled in (usually via cgo).
//
// Because functions are handled specially in the plugin package,
// function symbol names are prefixed here with '.' to avoid
// a dependency on the reflect package.
syms = make(map[string]any, len(md.ptab))
for _, ptab := range md.ptab {
symName := resolveNameOff(unsafe.Pointer(md.types), ptab.name)
t := (*_type)(unsafe.Pointer(md.types)).typeOff(ptab.typ)
var val any
valp := (*[2]unsafe.Pointer)(unsafe.Pointer(&val))
(*valp)[0] = unsafe.Pointer(t)
name := symName.name()
if t.kind&kindMask == kindFunc {
name = "." + name
}
syms[name] = val
}
return md.pluginpath, syms, md.inittasks, ""
}
func pluginftabverify(md *moduledata) {
badtable := false
for i := 0; i < len(md.ftab); i++ {
entry := md.textAddr(md.ftab[i].entryoff)
if md.minpc <= entry && entry <= md.maxpc {
continue
}
f := funcInfo{(*_func)(unsafe.Pointer(&md.pclntable[md.ftab[i].funcoff])), md}
name := funcname(f)
// A common bug is that f.entry has a relocation to a duplicate
// function symbol, meaning that if we search for its PC we get
// a valid entry with a name that is useful for debugging.
name2 := "none"
entry2 := uintptr(0)
f2 := findfunc(entry)
if f2.valid() {
name2 = funcname(f2)
entry2 = f2.entry()
}
badtable = true
println("ftab entry", hex(entry), "/", hex(entry2), ": ",
name, "/", name2, "outside pc range:[", hex(md.minpc), ",", hex(md.maxpc), "], modulename=", md.modulename, ", pluginpath=", md.pluginpath)
}
if badtable {
throw("runtime: plugin has bad symbol table")
}
}
// inRange reports whether v0 or v1 is in the range [r0, r1].
func inRange(r0, r1, v0, v1 uintptr) bool {
return (v0 >= r0 && v0 <= r1) || (v1 >= r0 && v1 <= r1)
}
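// For example, inRange(0x1000, 0x2000, 0x1800, 0x2800) reports true: v0
// (0x1800) lies inside [0x1000, 0x2000] even though v1 does not.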
// A ptabEntry is generated by the compiler for each exported function
// and global variable in the main package of a plugin. It is used to
// initialize the plugin module's symbol map.
type ptabEntry struct {
name nameOff
typ typeOff
}
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pprof
import (
"encoding/binary"
"errors"
"fmt"
"os"
)
var (
errBadELF = errors.New("malformed ELF binary")
errNoBuildID = errors.New("no NT_GNU_BUILD_ID found in ELF binary")
)
// elfBuildID returns the GNU build ID of the named ELF binary,
// without introducing a dependency on debug/elf and its dependencies.
func elfBuildID(file string) (string, error) {
buf := make([]byte, 256)
f, err := os.Open(file)
if err != nil {
return "", err
}
defer f.Close()
if _, err := f.ReadAt(buf[:64], 0); err != nil {
return "", err
}
// ELF file begins with \x7F E L F.
if buf[0] != 0x7F || buf[1] != 'E' || buf[2] != 'L' || buf[3] != 'F' {
return "", errBadELF
}
var byteOrder binary.ByteOrder
switch buf[5] {
default:
return "", errBadELF
case 1: // little-endian
byteOrder = binary.LittleEndian
case 2: // big-endian
byteOrder = binary.BigEndian
}
var shnum int
var shoff, shentsize int64
switch buf[4] {
default:
return "", errBadELF
case 1: // 32-bit file header
shoff = int64(byteOrder.Uint32(buf[32:]))
shentsize = int64(byteOrder.Uint16(buf[46:]))
if shentsize != 40 {
return "", errBadELF
}
shnum = int(byteOrder.Uint16(buf[48:]))
case 2: // 64-bit file header
shoff = int64(byteOrder.Uint64(buf[40:]))
shentsize = int64(byteOrder.Uint16(buf[58:]))
if shentsize != 64 {
return "", errBadELF
}
shnum = int(byteOrder.Uint16(buf[60:]))
}
for i := 0; i < shnum; i++ {
if _, err := f.ReadAt(buf[:shentsize], shoff+int64(i)*shentsize); err != nil {
return "", err
}
if typ := byteOrder.Uint32(buf[4:]); typ != 7 { // SHT_NOTE
continue
}
var off, size int64
if shentsize == 40 {
// 32-bit section header
off = int64(byteOrder.Uint32(buf[16:]))
size = int64(byteOrder.Uint32(buf[20:]))
} else {
// 64-bit section header
off = int64(byteOrder.Uint64(buf[24:]))
size = int64(byteOrder.Uint64(buf[32:]))
}
size += off
for off < size {
if _, err := f.ReadAt(buf[:16], off); err != nil { // room for header + name GNU\x00
return "", err
}
nameSize := int(byteOrder.Uint32(buf[0:]))
descSize := int(byteOrder.Uint32(buf[4:]))
noteType := int(byteOrder.Uint32(buf[8:]))
descOff := off + int64(12+(nameSize+3)&^3)
off = descOff + int64((descSize+3)&^3)
if nameSize != 4 || noteType != 3 || buf[12] != 'G' || buf[13] != 'N' || buf[14] != 'U' || buf[15] != '\x00' { // want name GNU\x00 type 3 (NT_GNU_BUILD_ID)
continue
}
if descSize > len(buf) {
return "", errBadELF
}
if _, err := f.ReadAt(buf[:descSize], descOff); err != nil {
return "", err
}
return fmt.Sprintf("%x", buf[:descSize]), nil
}
}
return "", errNoBuildID
}
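// A hedged sketch of the note record layout the loop above walks. Each
// record in an SHT_NOTE section is (all fields 4-byte aligned):
//
//	namesz uint32 // here 4, for "GNU\x00"
//	descsz uint32 // build ID length, e.g. 20 for SHA-1
//	type   uint32 // 3 = NT_GNU_BUILD_ID
//	name   [namesz]byte, padded to a 4-byte boundary
//	desc   [descsz]byte, padded to a 4-byte boundary
//
// So for a 20-byte build ID, descOff = off + 12 + (4+3)&^3 = off + 16, and
// the next record starts at descOff + (20+3)&^3 = descOff + 20.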
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pprof
import (
"context"
"fmt"
"sort"
"strings"
)
type label struct {
key string
value string
}
// LabelSet is a set of labels.
type LabelSet struct {
list []label
}
// labelContextKey is the type of contextKeys used for profiler labels.
type labelContextKey struct{}
func labelValue(ctx context.Context) labelMap {
labels, _ := ctx.Value(labelContextKey{}).(*labelMap)
if labels == nil {
return labelMap(nil)
}
return *labels
}
// labelMap is the representation of the label set held in the context type.
// This is an initial implementation, but it will be replaced with something
// that admits incremental immutable modification more efficiently.
type labelMap map[string]string
// String satisfies Stringer and returns key, value pairs in a consistent
// order.
func (l *labelMap) String() string {
if l == nil {
return ""
}
keyVals := make([]string, 0, len(*l))
for k, v := range *l {
keyVals = append(keyVals, fmt.Sprintf("%q:%q", k, v))
}
sort.Strings(keyVals)
return "{" + strings.Join(keyVals, ", ") + "}"
}
// WithLabels returns a new context.Context with the given labels added.
// A label overwrites a prior label with the same key.
func WithLabels(ctx context.Context, labels LabelSet) context.Context {
parentLabels := labelValue(ctx)
childLabels := make(labelMap, len(parentLabels))
// TODO(matloob): replace the map implementation with something
// more efficient so creating a child context WithLabels doesn't need
// to clone the map.
for k, v := range parentLabels {
childLabels[k] = v
}
for _, label := range labels.list {
childLabels[label.key] = label.value
}
return context.WithValue(ctx, labelContextKey{}, &childLabels)
}
// Labels takes an even number of strings representing key-value pairs
// and makes a LabelSet containing them.
// A label overwrites a prior label with the same key.
// Currently only the CPU and goroutine profiles utilize any labels
// information.
// See https://golang.org/issue/23458 for details.
func Labels(args ...string) LabelSet {
if len(args)%2 != 0 {
panic("uneven number of arguments to pprof.Labels")
}
list := make([]label, 0, len(args)/2)
for i := 0; i+1 < len(args); i += 2 {
list = append(list, label{key: args[i], value: args[i+1]})
}
return LabelSet{list: list}
}
// Label returns the value of the label with the given key on ctx, and a boolean indicating
// whether that label exists.
func Label(ctx context.Context, key string) (string, bool) {
ctxLabels := labelValue(ctx)
v, ok := ctxLabels[key]
return v, ok
}
// ForLabels invokes f with each label set on the context.
// The function f should return true to continue iteration or false to stop iteration early.
func ForLabels(ctx context.Context, f func(key, value string) bool) {
ctxLabels := labelValue(ctx)
for k, v := range ctxLabels {
if !f(k, v) {
break
}
}
}
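// A minimal usage sketch tying the label API together (the worker name is
// illustrative only):
//
//	ctx := context.Background()
//	pprof.Do(ctx, pprof.Labels("worker", "indexer"), func(ctx context.Context) {
//		// Samples taken while this function runs carry worker=indexer.
//		if v, ok := pprof.Label(ctx, "worker"); ok {
//			fmt.Println("running as", v)
//		}
//	})
//
// Do, defined elsewhere in this package, applies the labels via
// SetGoroutineLabels for the duration of the callback.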
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pprof
import "unsafe"
// A profMap is a map from (stack, tag) to mapEntry.
// It grows without bound, but that's assumed to be OK.
type profMap struct {
hash map[uintptr]*profMapEntry
all *profMapEntry
last *profMapEntry
free []profMapEntry
freeStk []uintptr
}
// A profMapEntry is a single entry in the profMap.
type profMapEntry struct {
nextHash *profMapEntry // next in hash list
nextAll *profMapEntry // next in list of all entries
stk []uintptr
tag unsafe.Pointer
count int64
}
func (m *profMap) lookup(stk []uint64, tag unsafe.Pointer) *profMapEntry {
// Compute hash of (stk, tag).
h := uintptr(0)
for _, x := range stk {
h = h<<8 | (h >> (8 * (unsafe.Sizeof(h) - 1)))
h += uintptr(x) * 41
}
h = h<<8 | (h >> (8 * (unsafe.Sizeof(h) - 1)))
h += uintptr(tag) * 41
// Find entry if present.
var last *profMapEntry
Search:
for e := m.hash[h]; e != nil; last, e = e, e.nextHash {
if len(e.stk) != len(stk) || e.tag != tag {
continue
}
for j := range stk {
if e.stk[j] != uintptr(stk[j]) {
continue Search
}
}
// Move to front.
if last != nil {
last.nextHash = e.nextHash
e.nextHash = m.hash[h]
m.hash[h] = e
}
return e
}
// Add new entry.
if len(m.free) < 1 {
m.free = make([]profMapEntry, 128)
}
e := &m.free[0]
m.free = m.free[1:]
e.nextHash = m.hash[h]
e.tag = tag
if len(m.freeStk) < len(stk) {
m.freeStk = make([]uintptr, 1024)
}
// Limit cap to prevent append from clobbering freeStk.
e.stk = m.freeStk[:len(stk):len(stk)]
m.freeStk = m.freeStk[len(stk):]
for j := range stk {
e.stk[j] = uintptr(stk[j])
}
if m.hash == nil {
m.hash = make(map[uintptr]*profMapEntry)
}
m.hash[h] = e
if m.all == nil {
m.all = e
m.last = e
} else {
m.last.nextAll = e
m.last = e
}
return e
}
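// Usage sketch: addCPUData, later in this package, deduplicates incoming
// samples with
//
//	b.m.lookup(stk, tag).count += int64(count)
//
// relying on the move-to-front hash chains above to keep hot stacks cheap.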
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pprof
import "os"
// peBuildID returns a best effort unique ID for the named executable.
//
// It would be wasteful to calculate the hash of the whole file;
// instead, use the binary name and the last modified time as the build ID.
func peBuildID(file string) string {
s, err := os.Stat(file)
if err != nil {
return file
}
return file + s.ModTime().String()
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package pprof writes runtime profiling data in the format expected
// by the pprof visualization tool.
//
// # Profiling a Go program
//
// The first step to profiling a Go program is to enable profiling.
// Support for profiling benchmarks built with the standard testing
// package is built into go test. For example, the following command
// runs benchmarks in the current directory and writes the CPU and
// memory profiles to cpu.prof and mem.prof:
//
// go test -cpuprofile cpu.prof -memprofile mem.prof -bench .
//
// To add equivalent profiling support to a standalone program, add
// code like the following to your main function:
//
// var cpuprofile = flag.String("cpuprofile", "", "write cpu profile to `file`")
// var memprofile = flag.String("memprofile", "", "write memory profile to `file`")
//
// func main() {
// flag.Parse()
// if *cpuprofile != "" {
// f, err := os.Create(*cpuprofile)
// if err != nil {
// log.Fatal("could not create CPU profile: ", err)
// }
// defer f.Close() // error handling omitted for example
// if err := pprof.StartCPUProfile(f); err != nil {
// log.Fatal("could not start CPU profile: ", err)
// }
// defer pprof.StopCPUProfile()
// }
//
// // ... rest of the program ...
//
// if *memprofile != "" {
// f, err := os.Create(*memprofile)
// if err != nil {
// log.Fatal("could not create memory profile: ", err)
// }
// defer f.Close() // error handling omitted for example
// runtime.GC() // get up-to-date statistics
// if err := pprof.WriteHeapProfile(f); err != nil {
// log.Fatal("could not write memory profile: ", err)
// }
// }
// }
//
// There is also a standard HTTP interface to profiling data. Adding
// the following line will install handlers under the /debug/pprof/
// URL to download live profiles:
//
// import _ "net/http/pprof"
//
// See the net/http/pprof package for more details.
//
// Profiles can then be visualized with the pprof tool:
//
// go tool pprof cpu.prof
//
// There are many commands available from the pprof command line.
// Commonly used commands include "top", which prints a summary of the
// top program hot-spots, and "web", which opens an interactive graph
// of hot-spots and their call graphs. Use "help" for information on
// all pprof commands.
//
// For more information about pprof, see
// https://github.com/google/pprof/blob/master/doc/README.md.
package pprof
import (
"bufio"
"fmt"
"internal/abi"
"io"
"runtime"
"sort"
"strings"
"sync"
"text/tabwriter"
"time"
"unsafe"
)
// BUG(rsc): Profiles are only as good as the kernel support used to generate them.
// See https://golang.org/issue/13841 for details about known problems.
// A Profile is a collection of stack traces showing the call sequences
// that led to instances of a particular event, such as allocation.
// Packages can create and maintain their own profiles; the most common
// use is for tracking resources that must be explicitly closed, such as files
// or network connections.
//
// A Profile's methods can be called from multiple goroutines simultaneously.
//
// Each Profile has a unique name. A few profiles are predefined:
//
// goroutine - stack traces of all current goroutines
// heap - a sampling of memory allocations of live objects
// allocs - a sampling of all past memory allocations
// threadcreate - stack traces that led to the creation of new OS threads
// block - stack traces that led to blocking on synchronization primitives
// mutex - stack traces of holders of contended mutexes
//
// These predefined profiles maintain themselves and panic on an explicit
// Add or Remove method call.
//
// The heap profile reports statistics as of the most recently completed
// garbage collection; it elides more recent allocation to avoid skewing
// the profile away from live data and toward garbage.
// If there has been no garbage collection at all, the heap profile reports
// all known allocations. This exception helps mainly in programs running
// without garbage collection enabled, usually for debugging purposes.
//
// The heap profile tracks both the allocation sites for all live objects in
// the application memory and for all objects allocated since the program start.
// Pprof's -inuse_space, -inuse_objects, -alloc_space, and -alloc_objects
// flags select which to display, defaulting to -inuse_space (live objects,
// scaled by size).
//
// The allocs profile is the same as the heap profile but changes the default
// pprof display to -alloc_space, the total number of bytes allocated since
// the program began (including garbage-collected bytes).
//
// The CPU profile is not available as a Profile. It has a special API,
// the StartCPUProfile and StopCPUProfile functions, because it streams
// output to a writer during profiling.
type Profile struct {
name string
mu sync.Mutex
m map[any][]uintptr
count func() int
write func(io.Writer, int) error
}
// profiles records all registered profiles.
var profiles struct {
mu sync.Mutex
m map[string]*Profile
}
var goroutineProfile = &Profile{
name: "goroutine",
count: countGoroutine,
write: writeGoroutine,
}
var threadcreateProfile = &Profile{
name: "threadcreate",
count: countThreadCreate,
write: writeThreadCreate,
}
var heapProfile = &Profile{
name: "heap",
count: countHeap,
write: writeHeap,
}
var allocsProfile = &Profile{
name: "allocs",
count: countHeap, // identical to heap profile
write: writeAlloc,
}
var blockProfile = &Profile{
name: "block",
count: countBlock,
write: writeBlock,
}
var mutexProfile = &Profile{
name: "mutex",
count: countMutex,
write: writeMutex,
}
func lockProfiles() {
profiles.mu.Lock()
if profiles.m == nil {
// Initial built-in profiles.
profiles.m = map[string]*Profile{
"goroutine": goroutineProfile,
"threadcreate": threadcreateProfile,
"heap": heapProfile,
"allocs": allocsProfile,
"block": blockProfile,
"mutex": mutexProfile,
}
}
}
func unlockProfiles() {
profiles.mu.Unlock()
}
// NewProfile creates a new profile with the given name.
// If a profile with that name already exists, NewProfile panics.
// The convention is to use an 'import/path.' prefix to create
// separate namespaces for each package.
// For compatibility with various tools that read pprof data,
// profile names should not contain spaces.
func NewProfile(name string) *Profile {
lockProfiles()
defer unlockProfiles()
if name == "" {
panic("pprof: NewProfile with empty name")
}
if profiles.m[name] != nil {
panic("pprof: NewProfile name already in use: " + name)
}
p := &Profile{
name: name,
m: map[any][]uintptr{},
}
profiles.m[name] = p
return p
}
// Lookup returns the profile with the given name, or nil if no such profile exists.
func Lookup(name string) *Profile {
lockProfiles()
defer unlockProfiles()
return profiles.m[name]
}
// Profiles returns a slice of all the known profiles, sorted by name.
func Profiles() []*Profile {
lockProfiles()
defer unlockProfiles()
all := make([]*Profile, 0, len(profiles.m))
for _, p := range profiles.m {
all = append(all, p)
}
sort.Slice(all, func(i, j int) bool { return all[i].name < all[j].name })
return all
}
// Name returns this profile's name, which can be passed to Lookup to reobtain the profile.
func (p *Profile) Name() string {
return p.name
}
// Count returns the number of execution stacks currently in the profile.
func (p *Profile) Count() int {
p.mu.Lock()
defer p.mu.Unlock()
if p.count != nil {
return p.count()
}
return len(p.m)
}
// Add adds the current execution stack to the profile, associated with value.
// Add stores value in an internal map, so value must be suitable for use as
// a map key and will not be garbage collected until the corresponding
// call to Remove. Add panics if the profile already contains a stack for value.
//
// The skip parameter has the same meaning as runtime.Caller's skip
// and controls where the stack trace begins. Passing skip=0 begins the
// trace in the function calling Add. For example, given this
// execution stack:
//
// Add
// called from rpc.NewClient
// called from mypkg.Run
// called from main.main
//
// Passing skip=0 begins the stack trace at the call to Add inside rpc.NewClient.
// Passing skip=1 begins the stack trace at the call to NewClient inside mypkg.Run.
func (p *Profile) Add(value any, skip int) {
if p.name == "" {
panic("pprof: use of uninitialized Profile")
}
if p.write != nil {
panic("pprof: Add called on built-in Profile " + p.name)
}
stk := make([]uintptr, 32)
n := runtime.Callers(skip+1, stk[:])
stk = stk[:n]
if len(stk) == 0 {
// The value for skip is too large, and there's no stack trace to record.
stk = []uintptr{abi.FuncPCABIInternal(lostProfileEvent)}
}
p.mu.Lock()
defer p.mu.Unlock()
if p.m[value] != nil {
panic("pprof: Profile.Add of duplicate value")
}
p.m[value] = stk
}
// Remove removes the execution stack associated with value from the profile.
// It is a no-op if the value is not in the profile.
func (p *Profile) Remove(value any) {
p.mu.Lock()
defer p.mu.Unlock()
delete(p.m, value)
}
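// A minimal sketch of a custom profile for tracking connections (the profile
// name, track, and untrack are illustrative, not part of this package):
//
//	var connProfile = pprof.NewProfile("example.com/mypkg.conns")
//
//	func track(c net.Conn) net.Conn {
//		connProfile.Add(c, 1) // skip=1: begin the stack at track's caller
//		return c
//	}
//
//	func untrack(c net.Conn) {
//		connProfile.Remove(c)
//	}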
// WriteTo writes a pprof-formatted snapshot of the profile to w.
// If a write to w returns an error, WriteTo returns that error.
// Otherwise, WriteTo returns nil.
//
// The debug parameter enables additional output.
// Passing debug=0 writes the gzip-compressed protocol buffer described
// in https://github.com/google/pprof/tree/master/proto#overview.
// Passing debug=1 writes the legacy text format with comments
// translating addresses to function names and line numbers, so that a
// programmer can read the profile without tools.
//
// The predefined profiles may assign meaning to other debug values;
// for example, when printing the "goroutine" profile, debug=2 means to
// print the goroutine stacks in the same form that a Go program uses
// when dying due to an unrecovered panic.
func (p *Profile) WriteTo(w io.Writer, debug int) error {
if p.name == "" {
panic("pprof: use of zero Profile")
}
if p.write != nil {
return p.write(w, debug)
}
// Obtain consistent snapshot under lock; then process without lock.
p.mu.Lock()
all := make([][]uintptr, 0, len(p.m))
for _, stk := range p.m {
all = append(all, stk)
}
p.mu.Unlock()
// Map order is non-deterministic; make output deterministic.
sort.Slice(all, func(i, j int) bool {
t, u := all[i], all[j]
for k := 0; k < len(t) && k < len(u); k++ {
if t[k] != u[k] {
return t[k] < u[k]
}
}
return len(t) < len(u)
})
return printCountProfile(w, debug, p.name, stackProfile(all))
}
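// Usage sketch: dump the goroutine profile in the legacy human-readable
// format (debug=1); pass debug=0 instead to get the gzipped proto that
// go tool pprof consumes:
//
//	if p := pprof.Lookup("goroutine"); p != nil {
//		p.WriteTo(os.Stderr, 1)
//	}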
type stackProfile [][]uintptr
func (x stackProfile) Len() int { return len(x) }
func (x stackProfile) Stack(i int) []uintptr { return x[i] }
func (x stackProfile) Label(i int) *labelMap { return nil }
// A countProfile is a set of stack traces to be printed as counts
// grouped by stack trace. There are multiple implementations:
// all that matters is that we can find out how many traces there are
// and obtain each trace in turn.
type countProfile interface {
Len() int
Stack(i int) []uintptr
Label(i int) *labelMap
}
// printCountCycleProfile outputs block profile records (for block or mutex profiles)
// in the pprof-proto format. Cycle counts are translated to time durations
// because the proto expects count and time (nanoseconds) rather than count
// and number of cycles for block and mutex profiles.
func printCountCycleProfile(w io.Writer, countName, cycleName string, records []runtime.BlockProfileRecord) error {
// Output profile in protobuf form.
b := newProfileBuilder(w)
b.pbValueType(tagProfile_PeriodType, countName, "count")
b.pb.int64Opt(tagProfile_Period, 1)
b.pbValueType(tagProfile_SampleType, countName, "count")
b.pbValueType(tagProfile_SampleType, cycleName, "nanoseconds")
cpuGHz := float64(runtime_cyclesPerSecond()) / 1e9
values := []int64{0, 0}
var locs []uint64
for _, r := range records {
values[0] = r.Count
values[1] = int64(float64(r.Cycles) / cpuGHz)
// For count profiles, all stack addresses are
// return PCs, which is what appendLocsForStack expects.
locs = b.appendLocsForStack(locs[:0], r.Stack())
b.pbSample(values, locs, nil)
}
b.build()
return nil
}
// printCountProfile prints a countProfile at the specified debug level.
// The profile will be in compressed proto format unless debug is nonzero.
func printCountProfile(w io.Writer, debug int, name string, p countProfile) error {
// Build count of each stack.
var buf strings.Builder
key := func(stk []uintptr, lbls *labelMap) string {
buf.Reset()
fmt.Fprintf(&buf, "@")
for _, pc := range stk {
fmt.Fprintf(&buf, " %#x", pc)
}
if lbls != nil {
buf.WriteString("\n# labels: ")
buf.WriteString(lbls.String())
}
return buf.String()
}
count := map[string]int{}
index := map[string]int{}
var keys []string
n := p.Len()
for i := 0; i < n; i++ {
k := key(p.Stack(i), p.Label(i))
if count[k] == 0 {
index[k] = i
keys = append(keys, k)
}
count[k]++
}
sort.Sort(&keysByCount{keys, count})
if debug > 0 {
// Print debug profile in legacy format
tw := tabwriter.NewWriter(w, 1, 8, 1, '\t', 0)
fmt.Fprintf(tw, "%s profile: total %d\n", name, p.Len())
for _, k := range keys {
fmt.Fprintf(tw, "%d %s\n", count[k], k)
printStackRecord(tw, p.Stack(index[k]), false)
}
return tw.Flush()
}
// Output profile in protobuf form.
b := newProfileBuilder(w)
b.pbValueType(tagProfile_PeriodType, name, "count")
b.pb.int64Opt(tagProfile_Period, 1)
b.pbValueType(tagProfile_SampleType, name, "count")
values := []int64{0}
var locs []uint64
for _, k := range keys {
values[0] = int64(count[k])
// For count profiles, all stack addresses are
// return PCs, which is what appendLocsForStack expects.
locs = b.appendLocsForStack(locs[:0], p.Stack(index[k]))
idx := index[k]
var labels func()
if p.Label(idx) != nil {
labels = func() {
for k, v := range *p.Label(idx) {
b.pbLabel(tagSample_Label, k, v, 0)
}
}
}
b.pbSample(values, locs, labels)
}
b.build()
return nil
}
// keysByCount sorts keys with higher counts first, breaking ties by key string order.
type keysByCount struct {
keys []string
count map[string]int
}
func (x *keysByCount) Len() int { return len(x.keys) }
func (x *keysByCount) Swap(i, j int) { x.keys[i], x.keys[j] = x.keys[j], x.keys[i] }
func (x *keysByCount) Less(i, j int) bool {
ki, kj := x.keys[i], x.keys[j]
ci, cj := x.count[ki], x.count[kj]
if ci != cj {
return ci > cj
}
return ki < kj
}
// printStackRecord prints the function + source line information
// for a single stack trace.
func printStackRecord(w io.Writer, stk []uintptr, allFrames bool) {
show := allFrames
frames := runtime.CallersFrames(stk)
for {
frame, more := frames.Next()
name := frame.Function
if name == "" {
show = true
fmt.Fprintf(w, "#\t%#x\n", frame.PC)
} else if name != "runtime.goexit" && (show || !strings.HasPrefix(name, "runtime.")) {
// Hide runtime.goexit and any runtime functions at the beginning.
// This is useful mainly for allocation traces.
show = true
fmt.Fprintf(w, "#\t%#x\t%s+%#x\t%s:%d\n", frame.PC, name, frame.PC-frame.Entry, frame.File, frame.Line)
}
if !more {
break
}
}
if !show {
// We didn't print anything; do it again,
// and this time include runtime functions.
printStackRecord(w, stk, true)
return
}
fmt.Fprintf(w, "\n")
}
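// The output produced above looks like the following (addresses and paths
// are illustrative):
//
//	#	0x4a2f45	main.work+0x25	/src/app/main.go:42
//	#	0x4a2e12	main.main+0x32	/src/app/main.go:17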
// Interface to system profiles.
// WriteHeapProfile is shorthand for Lookup("heap").WriteTo(w, 0).
// It is preserved for backwards compatibility.
func WriteHeapProfile(w io.Writer) error {
return writeHeap(w, 0)
}
// countHeap returns the number of records in the heap profile.
func countHeap() int {
n, _ := runtime.MemProfile(nil, true)
return n
}
// writeHeap writes the current runtime heap profile to w.
func writeHeap(w io.Writer, debug int) error {
return writeHeapInternal(w, debug, "")
}
// writeAlloc writes the current runtime heap profile to w
// with the total allocation space as the default sample type.
func writeAlloc(w io.Writer, debug int) error {
return writeHeapInternal(w, debug, "alloc_space")
}
func writeHeapInternal(w io.Writer, debug int, defaultSampleType string) error {
var memStats *runtime.MemStats
if debug != 0 {
// Read mem stats first, so that our other allocations
// do not appear in the statistics.
memStats = new(runtime.MemStats)
runtime.ReadMemStats(memStats)
}
// Find out how many records there are (MemProfile(nil, true)),
// allocate that many records, and get the data.
// There's a race—more records might be added between
// the two calls—so allocate a few extra records for safety
// and also try again if we're very unlucky.
// The loop should only execute one iteration in the common case.
var p []runtime.MemProfileRecord
n, ok := runtime.MemProfile(nil, true)
for {
// Allocate room for a slightly bigger profile,
// in case a few more entries have been added
// since the call to MemProfile.
p = make([]runtime.MemProfileRecord, n+50)
n, ok = runtime.MemProfile(p, true)
if ok {
p = p[0:n]
break
}
// Profile grew; try again.
}
if debug == 0 {
return writeHeapProto(w, p, int64(runtime.MemProfileRate), defaultSampleType)
}
sort.Slice(p, func(i, j int) bool { return p[i].InUseBytes() > p[j].InUseBytes() })
b := bufio.NewWriter(w)
tw := tabwriter.NewWriter(b, 1, 8, 1, '\t', 0)
w = tw
var total runtime.MemProfileRecord
for i := range p {
r := &p[i]
total.AllocBytes += r.AllocBytes
total.AllocObjects += r.AllocObjects
total.FreeBytes += r.FreeBytes
total.FreeObjects += r.FreeObjects
}
// Technically the rate is MemProfileRate not 2*MemProfileRate,
// but early versions of the C++ heap profiler reported 2*MemProfileRate,
// so that's what pprof has come to expect.
rate := 2 * runtime.MemProfileRate
// pprof reads a profile with alloc == inuse as being a "2-column" profile
// (objects and bytes, not distinguishing alloc from inuse),
// but then such a profile can't be merged using pprof *.prof with
// other 4-column profiles where alloc != inuse.
// The easiest way to avoid this bug is to adjust allocBytes so it's never == inuseBytes.
// pprof doesn't use these header values anymore except for checking equality.
inUseBytes := total.InUseBytes()
allocBytes := total.AllocBytes
if inUseBytes == allocBytes {
allocBytes++
}
fmt.Fprintf(w, "heap profile: %d: %d [%d: %d] @ heap/%d\n",
total.InUseObjects(), inUseBytes,
total.AllocObjects, allocBytes,
rate)
for i := range p {
r := &p[i]
fmt.Fprintf(w, "%d: %d [%d: %d] @",
r.InUseObjects(), r.InUseBytes(),
r.AllocObjects, r.AllocBytes)
for _, pc := range r.Stack() {
fmt.Fprintf(w, " %#x", pc)
}
fmt.Fprintf(w, "\n")
printStackRecord(w, r.Stack(), false)
}
// Print memstats information too.
// pprof will ignore it, but it's useful for people.
s := memStats
fmt.Fprintf(w, "\n# runtime.MemStats\n")
fmt.Fprintf(w, "# Alloc = %d\n", s.Alloc)
fmt.Fprintf(w, "# TotalAlloc = %d\n", s.TotalAlloc)
fmt.Fprintf(w, "# Sys = %d\n", s.Sys)
fmt.Fprintf(w, "# Lookups = %d\n", s.Lookups)
fmt.Fprintf(w, "# Mallocs = %d\n", s.Mallocs)
fmt.Fprintf(w, "# Frees = %d\n", s.Frees)
fmt.Fprintf(w, "# HeapAlloc = %d\n", s.HeapAlloc)
fmt.Fprintf(w, "# HeapSys = %d\n", s.HeapSys)
fmt.Fprintf(w, "# HeapIdle = %d\n", s.HeapIdle)
fmt.Fprintf(w, "# HeapInuse = %d\n", s.HeapInuse)
fmt.Fprintf(w, "# HeapReleased = %d\n", s.HeapReleased)
fmt.Fprintf(w, "# HeapObjects = %d\n", s.HeapObjects)
fmt.Fprintf(w, "# Stack = %d / %d\n", s.StackInuse, s.StackSys)
fmt.Fprintf(w, "# MSpan = %d / %d\n", s.MSpanInuse, s.MSpanSys)
fmt.Fprintf(w, "# MCache = %d / %d\n", s.MCacheInuse, s.MCacheSys)
fmt.Fprintf(w, "# BuckHashSys = %d\n", s.BuckHashSys)
fmt.Fprintf(w, "# GCSys = %d\n", s.GCSys)
fmt.Fprintf(w, "# OtherSys = %d\n", s.OtherSys)
fmt.Fprintf(w, "# NextGC = %d\n", s.NextGC)
fmt.Fprintf(w, "# LastGC = %d\n", s.LastGC)
fmt.Fprintf(w, "# PauseNs = %d\n", s.PauseNs)
fmt.Fprintf(w, "# PauseEnd = %d\n", s.PauseEnd)
fmt.Fprintf(w, "# NumGC = %d\n", s.NumGC)
fmt.Fprintf(w, "# NumForcedGC = %d\n", s.NumForcedGC)
fmt.Fprintf(w, "# GCCPUFraction = %v\n", s.GCCPUFraction)
fmt.Fprintf(w, "# DebugGC = %v\n", s.DebugGC)
// Also flush out MaxRSS on supported platforms.
addMaxRSS(w)
tw.Flush()
return b.Flush()
}
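// The measure/allocate/retry idiom used above also appears in
// writeRuntimeProfile and writeProfileInternal below. In isolation, as a
// hedged sketch (T, pad, and fetch are placeholders):
//
//	var buf []T
//	n, _ := fetch(nil) // current record count
//	for {
//		buf = make([]T, n+pad) // headroom for concurrent growth
//		var ok bool
//		if n, ok = fetch(buf); ok {
//			buf = buf[:n]
//			break
//		}
//		// The profile grew past the headroom; retry with the new n.
//	}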
// countThreadCreate returns the size of the current ThreadCreateProfile.
func countThreadCreate() int {
n, _ := runtime.ThreadCreateProfile(nil)
return n
}
// writeThreadCreate writes the current runtime ThreadCreateProfile to w.
func writeThreadCreate(w io.Writer, debug int) error {
// Until https://golang.org/issues/6104 is addressed, wrap
// ThreadCreateProfile because there's no point in tracking labels when we
// don't get any stack-traces.
return writeRuntimeProfile(w, debug, "threadcreate", func(p []runtime.StackRecord, _ []unsafe.Pointer) (n int, ok bool) {
return runtime.ThreadCreateProfile(p)
})
}
// countGoroutine returns the number of goroutines.
func countGoroutine() int {
return runtime.NumGoroutine()
}
// runtime_goroutineProfileWithLabels is defined in runtime/mprof.go
func runtime_goroutineProfileWithLabels(p []runtime.StackRecord, labels []unsafe.Pointer) (n int, ok bool)
// writeGoroutine writes the current runtime GoroutineProfile to w.
func writeGoroutine(w io.Writer, debug int) error {
if debug >= 2 {
return writeGoroutineStacks(w)
}
return writeRuntimeProfile(w, debug, "goroutine", runtime_goroutineProfileWithLabels)
}
func writeGoroutineStacks(w io.Writer) error {
// We don't know how big the buffer needs to be to collect
// all the goroutines. Start with 1 MB and try a few times, doubling each time.
// Give up and use a truncated trace if 64 MB is not enough.
buf := make([]byte, 1<<20)
for i := 0; ; i++ {
n := runtime.Stack(buf, true)
if n < len(buf) {
buf = buf[:n]
break
}
if len(buf) >= 64<<20 {
// Filled 64 MB - stop there.
break
}
buf = make([]byte, 2*len(buf))
}
_, err := w.Write(buf)
return err
}
func writeRuntimeProfile(w io.Writer, debug int, name string, fetch func([]runtime.StackRecord, []unsafe.Pointer) (int, bool)) error {
// Find out how many records there are (fetch(nil)),
// allocate that many records, and get the data.
// There's a race—more records might be added between
// the two calls—so allocate a few extra records for safety
// and also try again if we're very unlucky.
// The loop should only execute one iteration in the common case.
var p []runtime.StackRecord
var labels []unsafe.Pointer
n, ok := fetch(nil, nil)
for {
// Allocate room for a slightly bigger profile,
// in case a few more entries have been added
// since the initial call to fetch.
p = make([]runtime.StackRecord, n+10)
labels = make([]unsafe.Pointer, n+10)
n, ok = fetch(p, labels)
if ok {
p = p[0:n]
break
}
// Profile grew; try again.
}
return printCountProfile(w, debug, name, &runtimeProfile{p, labels})
}
type runtimeProfile struct {
stk []runtime.StackRecord
labels []unsafe.Pointer
}
func (p *runtimeProfile) Len() int { return len(p.stk) }
func (p *runtimeProfile) Stack(i int) []uintptr { return p.stk[i].Stack() }
func (p *runtimeProfile) Label(i int) *labelMap { return (*labelMap)(p.labels[i]) }
var cpu struct {
sync.Mutex
profiling bool
done chan bool
}
// StartCPUProfile enables CPU profiling for the current process.
// While profiling, the profile will be buffered and written to w.
// StartCPUProfile returns an error if profiling is already enabled.
//
// On Unix-like systems, StartCPUProfile does not work by default for
// Go code built with -buildmode=c-archive or -buildmode=c-shared.
// StartCPUProfile relies on the SIGPROF signal, but that signal will
// be delivered to the main program's SIGPROF signal handler (if any)
// not to the one used by Go. To make it work, call os/signal.Notify
// for syscall.SIGPROF, but note that doing so may break any profiling
// being done by the main program.
func StartCPUProfile(w io.Writer) error {
// The runtime routines allow a variable profiling rate,
// but in practice operating systems cannot trigger signals
// at more than about 500 Hz, and our processing of the
// signal is not cheap (mostly getting the stack trace).
// 100 Hz is a reasonable choice: it is frequent enough to
// produce useful data, rare enough not to bog down the
// system, and a nice round number to make it easy to
// convert sample counts to seconds. Instead of requiring
// each client to specify the frequency, we hard code it.
const hz = 100
cpu.Lock()
defer cpu.Unlock()
if cpu.done == nil {
cpu.done = make(chan bool)
}
// Double-check.
if cpu.profiling {
return fmt.Errorf("cpu profiling already in use")
}
cpu.profiling = true
runtime.SetCPUProfileRate(hz)
go profileWriter(w)
return nil
}
// readProfile, provided by the runtime, returns the next chunk of
// binary CPU profiling stack trace data, blocking until data is available.
// If profiling is turned off and all the profile data accumulated while it was
// on has been returned, readProfile returns eof=true.
// The caller must save the returned data and tags before calling readProfile again.
func readProfile() (data []uint64, tags []unsafe.Pointer, eof bool)
func profileWriter(w io.Writer) {
b := newProfileBuilder(w)
var err error
for {
time.Sleep(100 * time.Millisecond)
data, tags, eof := readProfile()
if e := b.addCPUData(data, tags); e != nil && err == nil {
err = e
}
if eof {
break
}
}
if err != nil {
// The runtime should never produce an invalid or truncated profile.
// It drops records that can't fit into its log buffers.
panic("runtime/pprof: converting profile: " + err.Error())
}
b.build()
cpu.done <- true
}
// StopCPUProfile stops the current CPU profile, if any.
// StopCPUProfile only returns after all the writes for the
// profile have completed.
func StopCPUProfile() {
cpu.Lock()
defer cpu.Unlock()
if !cpu.profiling {
return
}
cpu.profiling = false
runtime.SetCPUProfileRate(0)
<-cpu.done
}
// countBlock returns the number of records in the blocking profile.
func countBlock() int {
n, _ := runtime.BlockProfile(nil)
return n
}
// countMutex returns the number of records in the mutex profile.
func countMutex() int {
n, _ := runtime.MutexProfile(nil)
return n
}
// writeBlock writes the current blocking profile to w.
func writeBlock(w io.Writer, debug int) error {
return writeProfileInternal(w, debug, "contention", runtime.BlockProfile)
}
// writeMutex writes the current mutex profile to w.
func writeMutex(w io.Writer, debug int) error {
return writeProfileInternal(w, debug, "mutex", runtime.MutexProfile)
}
// writeProfileInternal writes the current blocking or mutex profile depending on the passed parameters.
func writeProfileInternal(w io.Writer, debug int, name string, runtimeProfile func([]runtime.BlockProfileRecord) (int, bool)) error {
var p []runtime.BlockProfileRecord
n, ok := runtimeProfile(nil)
for {
p = make([]runtime.BlockProfileRecord, n+50)
n, ok = runtimeProfile(p)
if ok {
p = p[:n]
break
}
}
sort.Slice(p, func(i, j int) bool { return p[i].Cycles > p[j].Cycles })
if debug <= 0 {
return printCountCycleProfile(w, "contentions", "delay", p)
}
b := bufio.NewWriter(w)
tw := tabwriter.NewWriter(b, 1, 8, 1, '\t', 0) // write through b so b.Flush sees the output
w = tw
fmt.Fprintf(w, "--- %v:\n", name)
fmt.Fprintf(w, "cycles/second=%v\n", runtime_cyclesPerSecond())
if name == "mutex" {
fmt.Fprintf(w, "sampling period=%d\n", runtime.SetMutexProfileFraction(-1))
}
for i := range p {
r := &p[i]
fmt.Fprintf(w, "%v %v @", r.Cycles, r.Count)
for _, pc := range r.Stack() {
fmt.Fprintf(w, " %#x", pc)
}
fmt.Fprint(w, "\n")
if debug > 0 {
printStackRecord(w, r.Stack(), true)
}
}
tw.Flush()
return b.Flush()
}
func runtime_cyclesPerSecond() int64
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix
package pprof
import (
"fmt"
"io"
"runtime"
"syscall"
)
// addMaxRSS writes the MaxRSS of the current process to w on platforms that support it.
func addMaxRSS(w io.Writer) {
var rssToBytes uintptr
switch runtime.GOOS {
case "aix", "android", "dragonfly", "freebsd", "linux", "netbsd", "openbsd":
rssToBytes = 1024
case "darwin", "ios":
rssToBytes = 1
case "illumos", "solaris":
rssToBytes = uintptr(syscall.Getpagesize())
default:
panic("unsupported OS")
}
var rusage syscall.Rusage
err := syscall.Getrusage(syscall.RUSAGE_SELF, &rusage)
if err == nil {
fmt.Fprintf(w, "# MaxRSS = %d\n", uintptr(rusage.Maxrss)*rssToBytes)
}
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pprof
import (
"bytes"
"compress/gzip"
"fmt"
"internal/abi"
"io"
"runtime"
"strconv"
"strings"
"time"
"unsafe"
)
// lostProfileEvent is the function to which lost profiling
// events are attributed.
// (The name shows up in the pprof graphs.)
func lostProfileEvent() { lostProfileEvent() }
// A profileBuilder writes a profile incrementally from a
// stream of profile samples delivered by the runtime.
type profileBuilder struct {
start time.Time
end time.Time
havePeriod bool
period int64
m profMap
// encoding state
w io.Writer
zw *gzip.Writer
pb protobuf
strings []string
stringMap map[string]int
locs map[uintptr]locInfo // list of locInfo starting with the given PC.
funcs map[string]int // Package path-qualified function name to Function.ID
mem []memMap
deck pcDeck
}
type memMap struct {
// initialized as reading mapping
start uintptr // Address at which the binary (or DLL) is loaded into memory.
end uintptr // The limit of the address range occupied by this mapping.
offset uint64 // Offset in the binary that corresponds to the first mapped address.
file string // The object this entry is loaded from.
buildID string // A string that uniquely identifies a particular program version with high probability.
funcs symbolizeFlag
fake bool // map entry was faked; /proc/self/maps wasn't available
}
// symbolizeFlag keeps track of symbolization result.
//
// 0 : no symbol lookup was performed
// 1<<0 (lookupTried) : symbol lookup was performed
// 1<<1 (lookupFailed): symbol lookup was performed but failed
type symbolizeFlag uint8
const (
lookupTried symbolizeFlag = 1 << iota
lookupFailed symbolizeFlag = 1 << iota
)
const (
// message Profile
tagProfile_SampleType = 1 // repeated ValueType
tagProfile_Sample = 2 // repeated Sample
tagProfile_Mapping = 3 // repeated Mapping
tagProfile_Location = 4 // repeated Location
tagProfile_Function = 5 // repeated Function
tagProfile_StringTable = 6 // repeated string
tagProfile_DropFrames = 7 // int64 (string table index)
tagProfile_KeepFrames = 8 // int64 (string table index)
tagProfile_TimeNanos = 9 // int64
tagProfile_DurationNanos = 10 // int64
tagProfile_PeriodType = 11 // ValueType (really optional string???)
tagProfile_Period = 12 // int64
tagProfile_Comment = 13 // repeated int64
tagProfile_DefaultSampleType = 14 // int64
// message ValueType
tagValueType_Type = 1 // int64 (string table index)
tagValueType_Unit = 2 // int64 (string table index)
// message Sample
tagSample_Location = 1 // repeated uint64
tagSample_Value = 2 // repeated int64
tagSample_Label = 3 // repeated Label
// message Label
tagLabel_Key = 1 // int64 (string table index)
tagLabel_Str = 2 // int64 (string table index)
tagLabel_Num = 3 // int64
// message Mapping
tagMapping_ID = 1 // uint64
tagMapping_Start = 2 // uint64
tagMapping_Limit = 3 // uint64
tagMapping_Offset = 4 // uint64
tagMapping_Filename = 5 // int64 (string table index)
tagMapping_BuildID = 6 // int64 (string table index)
tagMapping_HasFunctions = 7 // bool
tagMapping_HasFilenames = 8 // bool
tagMapping_HasLineNumbers = 9 // bool
tagMapping_HasInlineFrames = 10 // bool
// message Location
tagLocation_ID = 1 // uint64
tagLocation_MappingID = 2 // uint64
tagLocation_Address = 3 // uint64
tagLocation_Line = 4 // repeated Line
// message Line
tagLine_FunctionID = 1 // uint64
tagLine_Line = 2 // int64
// message Function
tagFunction_ID = 1 // uint64
tagFunction_Name = 2 // int64 (string table index)
tagFunction_SystemName = 3 // int64 (string table index)
tagFunction_Filename = 4 // int64 (string table index)
tagFunction_StartLine = 5 // int64
)
// stringIndex adds s to the string table if not already present
// and returns the index of s in the string table.
func (b *profileBuilder) stringIndex(s string) int64 {
id, ok := b.stringMap[s]
if !ok {
id = len(b.strings)
b.strings = append(b.strings, s)
b.stringMap[s] = id
}
return int64(id)
}
func (b *profileBuilder) flush() {
const dataFlush = 4096
if b.pb.nest == 0 && len(b.pb.data) > dataFlush {
b.zw.Write(b.pb.data)
b.pb.data = b.pb.data[:0]
}
}
// pbValueType encodes a ValueType message to b.pb.
func (b *profileBuilder) pbValueType(tag int, typ, unit string) {
start := b.pb.startMessage()
b.pb.int64(tagValueType_Type, b.stringIndex(typ))
b.pb.int64(tagValueType_Unit, b.stringIndex(unit))
b.pb.endMessage(tag, start)
}
// pbSample encodes a Sample message to b.pb.
func (b *profileBuilder) pbSample(values []int64, locs []uint64, labels func()) {
start := b.pb.startMessage()
b.pb.int64s(tagSample_Value, values)
b.pb.uint64s(tagSample_Location, locs)
if labels != nil {
labels()
}
b.pb.endMessage(tagProfile_Sample, start)
b.flush()
}
// pbLabel encodes a Label message to b.pb.
func (b *profileBuilder) pbLabel(tag int, key, str string, num int64) {
start := b.pb.startMessage()
b.pb.int64Opt(tagLabel_Key, b.stringIndex(key))
b.pb.int64Opt(tagLabel_Str, b.stringIndex(str))
b.pb.int64Opt(tagLabel_Num, num)
b.pb.endMessage(tag, start)
}
// pbLine encodes a Line message to b.pb.
func (b *profileBuilder) pbLine(tag int, funcID uint64, line int64) {
start := b.pb.startMessage()
b.pb.uint64Opt(tagLine_FunctionID, funcID)
b.pb.int64Opt(tagLine_Line, line)
b.pb.endMessage(tag, start)
}
// pbMapping encodes a Mapping message to b.pb.
func (b *profileBuilder) pbMapping(tag int, id, base, limit, offset uint64, file, buildID string, hasFuncs bool) {
start := b.pb.startMessage()
b.pb.uint64Opt(tagMapping_ID, id)
b.pb.uint64Opt(tagMapping_Start, base)
b.pb.uint64Opt(tagMapping_Limit, limit)
b.pb.uint64Opt(tagMapping_Offset, offset)
b.pb.int64Opt(tagMapping_Filename, b.stringIndex(file))
b.pb.int64Opt(tagMapping_BuildID, b.stringIndex(buildID))
// TODO: we set HasFunctions if all symbols from samples were symbolized (hasFuncs).
// Decide what to do about HasInlineFrames and HasLineNumbers.
// Also, another approach to handle the mapping entry with
// incomplete symbolization results is to duplicate the mapping
// entry (but with different Has* fields values) and use
// different entries for symbolized locations and unsymbolized locations.
if hasFuncs {
b.pb.bool(tagMapping_HasFunctions, true)
}
b.pb.endMessage(tag, start)
}
func allFrames(addr uintptr) ([]runtime.Frame, symbolizeFlag) {
// Expand this one address using CallersFrames so we can cache
// each expansion. In general, CallersFrames takes a whole
// stack, but in this case we know there will be no skips in
// the stack and we have return PCs anyway.
frames := runtime.CallersFrames([]uintptr{addr})
frame, more := frames.Next()
if frame.Function == "runtime.goexit" {
// Short-circuit if we see runtime.goexit so the loop
// below doesn't allocate a useless empty location.
return nil, 0
}
symbolizeResult := lookupTried
if frame.PC == 0 || frame.Function == "" || frame.File == "" || frame.Line == 0 {
symbolizeResult |= lookupFailed
}
if frame.PC == 0 {
// If we failed to resolve the frame, at least make up
// a reasonable call PC. This mostly happens in tests.
frame.PC = addr - 1
}
ret := []runtime.Frame{frame}
for frame.Function != "runtime.goexit" && more {
frame, more = frames.Next()
ret = append(ret, frame)
}
return ret, symbolizeResult
}
type locInfo struct {
// location id assigned by the profileBuilder
id uint64
// sequence of PCs, including the fake PCs returned by the traceback
// to represent inlined functions
// https://github.com/golang/go/blob/d6f2f833c93a41ec1c68e49804b8387a06b131c5/src/runtime/traceback.go#L347-L368
pcs []uintptr
// firstPCFrames and firstPCSymbolizeResult hold the results of the
// allFrames call for the first (leaf-most) PC this locInfo represents
firstPCFrames []runtime.Frame
firstPCSymbolizeResult symbolizeFlag
}
// newProfileBuilder returns a new profileBuilder.
// CPU profiling data obtained from the runtime can be added
// by calling b.addCPUData, and then the eventual profile
// can be obtained by calling b.finish.
func newProfileBuilder(w io.Writer) *profileBuilder {
zw, _ := gzip.NewWriterLevel(w, gzip.BestSpeed)
b := &profileBuilder{
w: w,
zw: zw,
start: time.Now(),
strings: []string{""},
stringMap: map[string]int{"": 0},
locs: map[uintptr]locInfo{},
funcs: map[string]int{},
}
b.readMapping()
return b
}
// addCPUData adds the CPU profiling data to the profile.
//
// The data must be a whole number of records, as delivered by the runtime.
// len(tags) must be equal to the number of records in data.
func (b *profileBuilder) addCPUData(data []uint64, tags []unsafe.Pointer) error {
if !b.havePeriod {
// first record is period
if len(data) < 3 {
return fmt.Errorf("truncated profile")
}
if data[0] != 3 || data[2] == 0 {
return fmt.Errorf("malformed profile")
}
// data[2] is sampling rate in Hz. Convert to sampling
// period in nanoseconds.
b.period = 1e9 / int64(data[2])
b.havePeriod = true
data = data[3:]
// Consume tag slot. Note that there isn't a meaningful tag
// value for this record.
tags = tags[1:]
}
// Parse CPU samples from the profile.
// Each sample is 3+n uint64s:
// data[0] = 3+n
// data[1] = time stamp (ignored)
// data[2] = count
// data[3:3+n] = stack
// If the count is 0 and the stack has length 1,
// that's an overflow record inserted by the runtime
// to indicate that stack[0] samples were lost.
// Otherwise the count is usually 1,
// but in a few special cases like lost non-Go samples
// there can be larger counts.
// Because many samples with the same stack arrive,
// we want to deduplicate immediately, which we do
// using the b.m profMap.
for len(data) > 0 {
if len(data) < 3 || data[0] > uint64(len(data)) {
return fmt.Errorf("truncated profile")
}
if data[0] < 3 || tags != nil && len(tags) < 1 {
return fmt.Errorf("malformed profile")
}
if len(tags) < 1 {
return fmt.Errorf("mismatched profile records and tags")
}
count := data[2]
stk := data[3:data[0]]
data = data[data[0]:]
tag := tags[0]
tags = tags[1:]
if count == 0 && len(stk) == 1 {
// overflow record
count = uint64(stk[0])
stk = []uint64{
// gentraceback guarantees that PCs in the
// stack can be unconditionally decremented and
// still be valid, so we must do the same.
uint64(abi.FuncPCABIInternal(lostProfileEvent) + 1),
}
}
b.m.lookup(stk, tag).count += int64(count)
}
if len(tags) != 0 {
return fmt.Errorf("mismatched profile records and tags")
}
return nil
}
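// Worked example of the record format parsed above: the words
//
//	5, 0x1234, 1, pc1, pc2
//
// form one record (data[0] = 5 = 3+2) with an ignored timestamp, count 1,
// and stack [pc1, pc2]. By contrast,
//
//	4, 0, 0, 42
//
// has count 0 and a one-word stack, so it is an overflow record meaning 42
// samples were dropped; they are attributed to lostProfileEvent.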
// build completes and returns the constructed profile.
func (b *profileBuilder) build() {
b.end = time.Now()
b.pb.int64Opt(tagProfile_TimeNanos, b.start.UnixNano())
if b.havePeriod { // must be CPU profile
b.pbValueType(tagProfile_SampleType, "samples", "count")
b.pbValueType(tagProfile_SampleType, "cpu", "nanoseconds")
b.pb.int64Opt(tagProfile_DurationNanos, b.end.Sub(b.start).Nanoseconds())
b.pbValueType(tagProfile_PeriodType, "cpu", "nanoseconds")
b.pb.int64Opt(tagProfile_Period, b.period)
}
values := []int64{0, 0}
var locs []uint64
for e := b.m.all; e != nil; e = e.nextAll {
values[0] = e.count
values[1] = e.count * b.period
var labels func()
if e.tag != nil {
labels = func() {
for k, v := range *(*labelMap)(e.tag) {
b.pbLabel(tagSample_Label, k, v, 0)
}
}
}
locs = b.appendLocsForStack(locs[:0], e.stk)
b.pbSample(values, locs, labels)
}
for i, m := range b.mem {
hasFunctions := m.funcs == lookupTried // lookupTried but not lookupFailed
b.pbMapping(tagProfile_Mapping, uint64(i+1), uint64(m.start), uint64(m.end), m.offset, m.file, m.buildID, hasFunctions)
}
// TODO: Anything for tagProfile_DropFrames?
// TODO: Anything for tagProfile_KeepFrames?
b.pb.strings(tagProfile_StringTable, b.strings)
b.zw.Write(b.pb.data)
b.zw.Close()
}
// appendLocsForStack appends the location IDs for the given stack trace to the given
// location ID slice, locs. The addresses in the stack are return PCs or 1 + the PC of
// an inline marker as the runtime traceback function returns.
//
// It may return an empty slice even if locs is non-empty, for example if stk consists
// solely of runtime.goexit. We still count these empty stacks in profiles in order to
// get the right cumulative sample count.
//
// It may emit to b.pb, so there must be no message encoding in progress.
func (b *profileBuilder) appendLocsForStack(locs []uint64, stk []uintptr) (newLocs []uint64) {
b.deck.reset()
// The last frame might be truncated. Recover lost inline frames.
stk = runtime_expandFinalInlineFrame(stk)
for len(stk) > 0 {
addr := stk[0]
if l, ok := b.locs[addr]; ok {
// When generating code for an inlined function, the compiler adds
// NOP instructions to the outermost function as a placeholder for
// each layer of inlining. When the runtime generates tracebacks for
// stacks that include inlined functions, it uses the addresses of
// those NOPs as "fake" PCs on the stack as if they were regular
// function call sites. But if a profiling signal arrives while the
// CPU is executing one of those NOPs, its PC will show up as a leaf
// in the profile with its own Location entry. So, always check
// whether addr is a "fake" PC in the context of the current call
// stack by trying to add it to the inlining deck before assuming
// that the deck is complete.
if len(b.deck.pcs) > 0 {
if added := b.deck.tryAdd(addr, l.firstPCFrames, l.firstPCSymbolizeResult); added {
stk = stk[1:]
continue
}
}
// first record the location if there is any pending accumulated info.
if id := b.emitLocation(); id > 0 {
locs = append(locs, id)
}
// then, record the cached location.
locs = append(locs, l.id)
// Skip the matching pcs.
//
// Even if stk was truncated due to the stack depth
// limit, expandFinalInlineFrame above has already
// fixed the truncation, ensuring it is long enough.
stk = stk[len(l.pcs):]
continue
}
frames, symbolizeResult := allFrames(addr)
if len(frames) == 0 { // runtime.goexit.
if id := b.emitLocation(); id > 0 {
locs = append(locs, id)
}
stk = stk[1:]
continue
}
if added := b.deck.tryAdd(addr, frames, symbolizeResult); added {
stk = stk[1:]
continue
}
// add failed because this addr is not inlined with the
// existing PCs in the deck. Flush the deck and retry handling
// this pc.
if id := b.emitLocation(); id > 0 {
locs = append(locs, id)
}
// check cache again - previous emitLocation added a new entry
if l, ok := b.locs[addr]; ok {
locs = append(locs, l.id)
stk = stk[len(l.pcs):] // skip the matching pcs.
} else {
b.deck.tryAdd(addr, frames, symbolizeResult) // must succeed.
stk = stk[1:]
}
}
if id := b.emitLocation(); id > 0 { // emit remaining location.
locs = append(locs, id)
}
return locs
}
// Here's an example of how Go 1.17 writes out inlined functions, compiled for
// linux/amd64. The disassembly of main.main shows two levels of inlining: main
// calls b, b calls a, a does some work.
//
// inline.go:9 0x4553ec 90 NOPL // func main() { b(v) }
// inline.go:6 0x4553ed 90 NOPL // func b(v *int) { a(v) }
// inline.go:5 0x4553ee 48c7002a000000 MOVQ $0x2a, 0(AX) // func a(v *int) { *v = 42 }
//
// If a profiling signal arrives while executing the MOVQ at 0x4553ee (for line
// 5), the runtime will report the stack as the MOVQ frame being called by the
// NOPL at 0x4553ed (for line 6) being called by the NOPL at 0x4553ec (for line
// 9).
//
// The role of pcDeck is to collapse those three frames back into a single
// location at 0x4553ee, with file/line/function symbolization info representing
// the three layers of calls. It does that via sequential calls to pcDeck.tryAdd
// starting with the leaf-most address. The fourth call to pcDeck.tryAdd will be
// for the caller of main.main. Because main.main was not inlined in its caller,
// the deck will reject the addition, and the fourth PC on the stack will get
// its own location.
// pcDeck is a helper to detect a sequence of inlined functions from
// a stack trace returned by the runtime.
//
// The stack traces returned by runtime's traceback functions are fully
// expanded (at least for Go functions) and include the fake pcs representing
// inlined functions. The profile proto expects the inlined functions to be
// encoded in one Location message.
// https://github.com/google/pprof/blob/5e965273ee43930341d897407202dd5e10e952cb/proto/profile.proto#L177-L184
//
// The runtime does not directly expose whether a frame is for an inlined function,
// and looking up debug info is not ideal, so we use a heuristic to filter
// the fake pcs and restore the inlined and entry functions. Inlined functions
// have the following properties:
//
// Frame's Func is nil (note: also true for non-Go functions), and
// Frame's Entry matches its entry function frame's Entry (note: could also be true for recursive calls and non-Go functions), and
// Frame's Name does not match its entry function frame's name (note: inlined functions cannot be directly recursive).
//
// As we read and process the pcs in a stack trace one by one (from leaf to root),
// we use pcDeck to temporarily hold the observed pcs and their expanded frames
// until we observe the entry function frame.
type pcDeck struct {
pcs []uintptr
frames []runtime.Frame
symbolizeResult symbolizeFlag
// firstPCFrames indicates the number of frames associated with the first
// (leaf-most) PC in the deck
firstPCFrames int
// firstPCSymbolizeResult holds the results of the allFrames call for the
// first (leaf-most) PC in the deck
firstPCSymbolizeResult symbolizeFlag
}
func (d *pcDeck) reset() {
d.pcs = d.pcs[:0]
d.frames = d.frames[:0]
d.symbolizeResult = 0
d.firstPCFrames = 0
d.firstPCSymbolizeResult = 0
}
// tryAdd tries to add the pc and the Frames expanded from it (most likely one,
// since the stack trace is already fully expanded) and the symbolizeResult
// to the deck. If it fails, the caller needs to flush the deck and retry.
func (d *pcDeck) tryAdd(pc uintptr, frames []runtime.Frame, symbolizeResult symbolizeFlag) (success bool) {
if existing := len(d.frames); existing > 0 {
// 'd.frames' are all expanded from one 'pc' and represent all
// inlined functions so we check only the last one.
newFrame := frames[0]
last := d.frames[existing-1]
if last.Func != nil { // the last frame can't be inlined. Flush.
return false
}
if last.Entry == 0 || newFrame.Entry == 0 { // Possibly not a Go function. Don't try to merge.
return false
}
if last.Entry != newFrame.Entry { // newFrame is for a different function.
return false
}
if last.Function == newFrame.Function { // maybe recursion.
return false
}
}
d.pcs = append(d.pcs, pc)
d.frames = append(d.frames, frames...)
d.symbolizeResult |= symbolizeResult
if len(d.pcs) == 1 {
d.firstPCFrames = len(d.frames)
d.firstPCSymbolizeResult = symbolizeResult
}
return true
}
// emitLocation emits the new location and function information recorded in the deck
// and returns the location ID encoded in the profile protobuf.
// It emits to b.pb, so there must be no message encoding in progress.
// It resets the deck.
func (b *profileBuilder) emitLocation() uint64 {
if len(b.deck.pcs) == 0 {
return 0
}
defer b.deck.reset()
addr := b.deck.pcs[0]
firstFrame := b.deck.frames[0]
// We can't write out functions while in the middle of the
// Location message, so record new functions we encounter and
// write them out after the Location.
type newFunc struct {
id uint64
name, file string
startLine int64
}
newFuncs := make([]newFunc, 0, 8)
id := uint64(len(b.locs)) + 1
b.locs[addr] = locInfo{
id: id,
pcs: append([]uintptr{}, b.deck.pcs...),
firstPCSymbolizeResult: b.deck.firstPCSymbolizeResult,
firstPCFrames: append([]runtime.Frame{}, b.deck.frames[:b.deck.firstPCFrames]...),
}
start := b.pb.startMessage()
b.pb.uint64Opt(tagLocation_ID, id)
b.pb.uint64Opt(tagLocation_Address, uint64(firstFrame.PC))
for _, frame := range b.deck.frames {
// Write out each line in frame expansion.
funcID := uint64(b.funcs[frame.Function])
if funcID == 0 {
funcID = uint64(len(b.funcs)) + 1
b.funcs[frame.Function] = int(funcID)
newFuncs = append(newFuncs, newFunc{
id: funcID,
name: frame.Function,
file: frame.File,
startLine: int64(runtime_FrameStartLine(&frame)),
})
}
b.pbLine(tagLocation_Line, funcID, int64(frame.Line))
}
for i := range b.mem {
if b.mem[i].start <= addr && addr < b.mem[i].end || b.mem[i].fake {
b.pb.uint64Opt(tagLocation_MappingID, uint64(i+1))
m := b.mem[i]
m.funcs |= b.deck.symbolizeResult
b.mem[i] = m
break
}
}
b.pb.endMessage(tagProfile_Location, start)
// Write out functions we found during frame expansion.
for _, fn := range newFuncs {
start := b.pb.startMessage()
b.pb.uint64Opt(tagFunction_ID, fn.id)
b.pb.int64Opt(tagFunction_Name, b.stringIndex(fn.name))
b.pb.int64Opt(tagFunction_SystemName, b.stringIndex(fn.name))
b.pb.int64Opt(tagFunction_Filename, b.stringIndex(fn.file))
b.pb.int64Opt(tagFunction_StartLine, fn.startLine)
b.pb.endMessage(tagProfile_Function, start)
}
b.flush()
return id
}
var space = []byte(" ")
var newline = []byte("\n")
func parseProcSelfMaps(data []byte, addMapping func(lo, hi, offset uint64, file, buildID string)) {
// $ cat /proc/self/maps
// 00400000-0040b000 r-xp 00000000 fc:01 787766 /bin/cat
// 0060a000-0060b000 r--p 0000a000 fc:01 787766 /bin/cat
// 0060b000-0060c000 rw-p 0000b000 fc:01 787766 /bin/cat
// 014ab000-014cc000 rw-p 00000000 00:00 0 [heap]
// 7f7d76af8000-7f7d7797c000 r--p 00000000 fc:01 1318064 /usr/lib/locale/locale-archive
// 7f7d7797c000-7f7d77b36000 r-xp 00000000 fc:01 1180226 /lib/x86_64-linux-gnu/libc-2.19.so
// 7f7d77b36000-7f7d77d36000 ---p 001ba000 fc:01 1180226 /lib/x86_64-linux-gnu/libc-2.19.so
// 7f7d77d36000-7f7d77d3a000 r--p 001ba000 fc:01 1180226 /lib/x86_64-linux-gnu/libc-2.19.so
// 7f7d77d3a000-7f7d77d3c000 rw-p 001be000 fc:01 1180226 /lib/x86_64-linux-gnu/libc-2.19.so
// 7f7d77d3c000-7f7d77d41000 rw-p 00000000 00:00 0
// 7f7d77d41000-7f7d77d64000 r-xp 00000000 fc:01 1180217 /lib/x86_64-linux-gnu/ld-2.19.so
// 7f7d77f3f000-7f7d77f42000 rw-p 00000000 00:00 0
// 7f7d77f61000-7f7d77f63000 rw-p 00000000 00:00 0
// 7f7d77f63000-7f7d77f64000 r--p 00022000 fc:01 1180217 /lib/x86_64-linux-gnu/ld-2.19.so
// 7f7d77f64000-7f7d77f65000 rw-p 00023000 fc:01 1180217 /lib/x86_64-linux-gnu/ld-2.19.so
// 7f7d77f65000-7f7d77f66000 rw-p 00000000 00:00 0
// 7ffc342a2000-7ffc342c3000 rw-p 00000000 00:00 0 [stack]
// 7ffc34343000-7ffc34345000 r-xp 00000000 00:00 0 [vdso]
// ffffffffff600000-ffffffffff601000 r-xp 00000000 00:00 0 [vsyscall]
var line []byte
// next removes and returns the next field in the line.
// It also removes from line any spaces following the field.
next := func() []byte {
var f []byte
f, line, _ = bytes.Cut(line, space)
line = bytes.TrimLeft(line, " ")
return f
}
for len(data) > 0 {
line, data, _ = bytes.Cut(data, newline)
addr := next()
loStr, hiStr, ok := strings.Cut(string(addr), "-")
if !ok {
continue
}
lo, err := strconv.ParseUint(loStr, 16, 64)
if err != nil {
continue
}
hi, err := strconv.ParseUint(hiStr, 16, 64)
if err != nil {
continue
}
perm := next()
if len(perm) < 4 || perm[2] != 'x' {
// Only interested in executable mappings.
continue
}
offset, err := strconv.ParseUint(string(next()), 16, 64)
if err != nil {
continue
}
next() // dev
inode := next() // inode
if line == nil {
continue
}
file := string(line)
// Trim deleted file marker.
deletedStr := " (deleted)"
deletedLen := len(deletedStr)
if len(file) >= deletedLen && file[len(file)-deletedLen:] == deletedStr {
file = file[:len(file)-deletedLen]
}
if len(inode) == 1 && inode[0] == '0' && file == "" {
// Huge-page text mappings list the initial fragment of
// mapped but unpopulated memory as being inode 0.
// Don't report that part.
// But [vdso] and [vsyscall] are inode 0, so let non-empty file names through.
continue
}
// TODO: pprof's remapMappingIDs makes one adjustment:
// 1. If there is an /anon_hugepage mapping first and it is
// consecutive to a next mapping, drop the /anon_hugepage.
// There's no indication why this is needed.
// Let's try not doing this and see what breaks.
// If we do need it, it would go here, before we
// enter the mappings into b.mem in the first place.
buildID, _ := elfBuildID(file)
addMapping(lo, hi, offset, file, buildID)
}
}
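// As an illustration (using the sample dump above): the line
//
//	7f7d7797c000-7f7d77b36000 r-xp 00000000 fc:01 1180226 /lib/x86_64-linux-gnu/libc-2.19.so
//
// produces addMapping(0x7f7d7797c000, 0x7f7d77b36000, 0,
// "/lib/x86_64-linux-gnu/libc-2.19.so", buildID), while the r--p, rw-p,
// and ---p mappings of the same file are skipped as non-executable.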
func (b *profileBuilder) addMapping(lo, hi, offset uint64, file, buildID string) {
b.addMappingEntry(lo, hi, offset, file, buildID, false)
}
func (b *profileBuilder) addMappingEntry(lo, hi, offset uint64, file, buildID string, fake bool) {
b.mem = append(b.mem, memMap{
start: uintptr(lo),
end: uintptr(hi),
offset: offset,
file: file,
buildID: buildID,
fake: fake,
})
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !windows
package pprof
import (
"errors"
"os"
)
// readMapping reads /proc/self/maps and writes mappings to b.pb.
// It saves the address ranges of the mappings in b.mem for use
// when emitting locations.
func (b *profileBuilder) readMapping() {
data, _ := os.ReadFile("/proc/self/maps")
parseProcSelfMaps(data, b.addMapping)
if len(b.mem) == 0 { // pprof expects a map entry, so fake one.
b.addMappingEntry(0, 0, 0, "", "", true)
// TODO(hyangah): make addMapping return *memMap or
// take a memMap struct, and get rid of addMappingEntry
// that takes a bunch of positional arguments.
}
}
func readMainModuleMapping() (start, end uint64, err error) {
return 0, 0, errors.New("not implemented")
}
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pprof
// A protobuf is a simple protocol buffer encoder.
type protobuf struct {
data []byte
tmp [16]byte
nest int
}
func (b *protobuf) varint(x uint64) {
for x >= 128 {
b.data = append(b.data, byte(x)|0x80)
x >>= 7
}
b.data = append(b.data, byte(x))
}
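// For example (illustrative), varint(300) appends the two bytes 0xac 0x02:
// the low seven bits of 300 (0b0101100 = 0x2c) are emitted first with the
// continuation bit set (0xac), then the remaining high bits (0b10 = 0x02).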
func (b *protobuf) length(tag int, len int) {
	// Field key: tag with wire type 2 (length-delimited), then the byte length.
	b.varint(uint64(tag)<<3 | 2)
	b.varint(uint64(len))
}
func (b *protobuf) uint64(tag int, x uint64) {
	// Field key: tag with wire type 0 (varint), then the value itself.
	b.varint(uint64(tag)<<3 | 0)
	b.varint(x)
}
func (b *protobuf) uint64s(tag int, x []uint64) {
if len(x) > 2 {
// Use packed encoding
n1 := len(b.data)
for _, u := range x {
b.varint(u)
}
n2 := len(b.data)
b.length(tag, n2-n1)
n3 := len(b.data)
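	// b.data now holds the packed values in [n1:n2] and the key/length
	// header in [n2:n3]. Rotate the header in front of the values:
	// stash it in tmp, slide the values right, and copy it back.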
copy(b.tmp[:], b.data[n2:n3])
copy(b.data[n1+(n3-n2):], b.data[n1:n2])
copy(b.data[n1:], b.tmp[:n3-n2])
return
}
for _, u := range x {
b.uint64(tag, u)
}
}
func (b *protobuf) uint64Opt(tag int, x uint64) {
if x == 0 {
return
}
b.uint64(tag, x)
}
func (b *protobuf) int64(tag int, x int64) {
u := uint64(x)
b.uint64(tag, u)
}
func (b *protobuf) int64Opt(tag int, x int64) {
if x == 0 {
return
}
b.int64(tag, x)
}
func (b *protobuf) int64s(tag int, x []int64) {
if len(x) > 2 {
// Use packed encoding
n1 := len(b.data)
for _, u := range x {
b.varint(uint64(u))
}
n2 := len(b.data)
b.length(tag, n2-n1)
n3 := len(b.data)
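	// Rotate the key/length header in front of the packed values,
	// as in uint64s above.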
copy(b.tmp[:], b.data[n2:n3])
copy(b.data[n1+(n3-n2):], b.data[n1:n2])
copy(b.data[n1:], b.tmp[:n3-n2])
return
}
for _, u := range x {
b.int64(tag, u)
}
}
func (b *protobuf) string(tag int, x string) {
b.length(tag, len(x))
b.data = append(b.data, x...)
}
func (b *protobuf) strings(tag int, x []string) {
for _, s := range x {
b.string(tag, s)
}
}
func (b *protobuf) stringOpt(tag int, x string) {
if x == "" {
return
}
b.string(tag, x)
}
func (b *protobuf) bool(tag int, x bool) {
if x {
b.uint64(tag, 1)
} else {
b.uint64(tag, 0)
}
}
func (b *protobuf) boolOpt(tag int, x bool) {
if !x {
return
}
b.bool(tag, x)
}
type msgOffset int
func (b *protobuf) startMessage() msgOffset {
b.nest++
return msgOffset(len(b.data))
}
func (b *protobuf) endMessage(tag int, start msgOffset) {
n1 := int(start)
n2 := len(b.data)
b.length(tag, n2-n1)
n3 := len(b.data)
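	// The message body occupies [n1:n2] and the just-written key/length
	// header occupies [n2:n3]. Rotate the header in front of the body,
	// same as the packed encodings above.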
copy(b.tmp[:], b.data[n2:n3])
copy(b.data[n1+(n3-n2):], b.data[n1:n2])
copy(b.data[n1:], b.tmp[:n3-n2])
b.nest--
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pprof
import (
"io"
"math"
"runtime"
"strings"
)
// writeHeapProto writes the current heap profile in protobuf format to w.
func writeHeapProto(w io.Writer, p []runtime.MemProfileRecord, rate int64, defaultSampleType string) error {
b := newProfileBuilder(w)
b.pbValueType(tagProfile_PeriodType, "space", "bytes")
b.pb.int64Opt(tagProfile_Period, rate)
b.pbValueType(tagProfile_SampleType, "alloc_objects", "count")
b.pbValueType(tagProfile_SampleType, "alloc_space", "bytes")
b.pbValueType(tagProfile_SampleType, "inuse_objects", "count")
b.pbValueType(tagProfile_SampleType, "inuse_space", "bytes")
if defaultSampleType != "" {
b.pb.int64Opt(tagProfile_DefaultSampleType, b.stringIndex(defaultSampleType))
}
values := []int64{0, 0, 0, 0}
var locs []uint64
for _, r := range p {
hideRuntime := true
for tries := 0; tries < 2; tries++ {
stk := r.Stack()
// For heap profiles, all stack
// addresses are return PCs, which is
// what appendLocsForStack expects.
if hideRuntime {
for i, addr := range stk {
if f := runtime.FuncForPC(addr); f != nil && strings.HasPrefix(f.Name(), "runtime.") {
continue
}
// Found non-runtime. Show any runtime uses above it.
stk = stk[i:]
break
}
}
locs = b.appendLocsForStack(locs[:0], stk)
if len(locs) > 0 {
break
}
hideRuntime = false // try again, and show all frames next time.
}
values[0], values[1] = scaleHeapSample(r.AllocObjects, r.AllocBytes, rate)
values[2], values[3] = scaleHeapSample(r.InUseObjects(), r.InUseBytes(), rate)
var blockSize int64
if r.AllocObjects > 0 {
blockSize = r.AllocBytes / r.AllocObjects
}
b.pbSample(values, locs, func() {
if blockSize != 0 {
b.pbLabel(tagSample_Label, "bytes", "", blockSize)
}
})
}
b.build()
return nil
}
// scaleHeapSample adjusts the data from a heap Sample to
// account for its probability of appearing in the collected
// data. Heap profiles are a sampling of the memory allocation
// requests in a program. We estimate the unsampled value by dividing
// each collected sample by its probability of appearing in the
// profile. Heap profiles rely on a Poisson process to determine
// which samples to collect, based on the desired average collection
// rate R. The probability of a sample of size S to appear in that
// profile is 1-exp(-S/R).
func scaleHeapSample(count, size, rate int64) (int64, int64) {
if count == 0 || size == 0 {
return 0, 0
}
if rate <= 1 {
// if rate==1 all samples were collected so no adjustment is needed.
// if rate<1 treat as unknown and skip scaling.
return count, size
}
avgSize := float64(size) / float64(count)
scale := 1 / (1 - math.Exp(-avgSize/float64(rate)))
return int64(float64(count) * scale), int64(float64(size) * scale)
}
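// For example (illustrative numbers): with rate = 512*1024 and a record of
// a single 512 KiB allocation (count = 1, size = 524288), avgSize/rate is 1,
// so scale = 1/(1-exp(-1)), about 1.58, and both count and size are
// multiplied by roughly 1.58 before truncation to int64.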
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pprof
import (
"context"
"runtime"
"unsafe"
)
// runtime_FrameStartLine is defined in runtime/symtab.go.
func runtime_FrameStartLine(f *runtime.Frame) int
// runtime_expandFinalInlineFrame is defined in runtime/symtab.go.
func runtime_expandFinalInlineFrame(stk []uintptr) []uintptr
// runtime_setProfLabel is defined in runtime/proflabel.go.
func runtime_setProfLabel(labels unsafe.Pointer)
// runtime_getProfLabel is defined in runtime/proflabel.go.
func runtime_getProfLabel() unsafe.Pointer
// SetGoroutineLabels sets the current goroutine's labels to match ctx.
// A new goroutine inherits the labels of the goroutine that created it.
// This is a lower-level API than Do, which should be used instead when possible.
func SetGoroutineLabels(ctx context.Context) {
ctxLabels, _ := ctx.Value(labelContextKey{}).(*labelMap)
runtime_setProfLabel(unsafe.Pointer(ctxLabels))
}
// Do calls f with a copy of the parent context with the
// given labels added to the parent's label map.
// Goroutines spawned while executing f will inherit the augmented label-set.
// Each key/value pair in labels is inserted into the label map in the
// order provided, overriding any previous value for the same key.
// The augmented label map will be set for the duration of the call to f
// and restored once f returns.
func Do(ctx context.Context, labels LabelSet, f func(context.Context)) {
defer SetGoroutineLabels(ctx)
ctx = WithLabels(ctx, labels)
SetGoroutineLabels(ctx)
f(ctx)
}
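// A minimal usage sketch (illustrative; Labels is this package's exported
// helper for constructing a LabelSet):
//
//	ctx := context.Background()
//	Do(ctx, Labels("worker", "indexer"), func(ctx context.Context) {
//		// Profiler samples taken here carry worker=indexer, as do
//		// goroutines spawned from within this function.
//	})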
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Goroutine preemption
//
// A goroutine can be preempted at any safe-point. Currently, there
// are a few categories of safe-points:
//
// 1. A blocked safe-point occurs for the duration that a goroutine is
// descheduled, blocked on synchronization, or in a system call.
//
// 2. Synchronous safe-points occur when a running goroutine checks
// for a preemption request.
//
// 3. Asynchronous safe-points occur at any instruction in user code
// where the goroutine can be safely paused and a conservative
// stack and register scan can find stack roots. The runtime can
// stop a goroutine at an async safe-point using a signal.
//
// At both blocked and synchronous safe-points, a goroutine's CPU
// state is minimal and the garbage collector has complete information
// about its entire stack. This makes it possible to deschedule a
// goroutine with minimal space, and to precisely scan a goroutine's
// stack.
//
// Synchronous safe-points are implemented by overloading the stack
// bound check in function prologues. To preempt a goroutine at the
// next synchronous safe-point, the runtime poisons the goroutine's
// stack bound to a value that will cause the next stack bound check
// to fail and enter the stack growth implementation, which will
// detect that it was actually a preemption and redirect to preemption
// handling.
//
// Preemption at asynchronous safe-points is implemented by suspending
// the thread using an OS mechanism (e.g., signals) and inspecting its
// state to determine if the goroutine was at an asynchronous
// safe-point. Since the thread suspension itself is generally
// asynchronous, it also checks if the running goroutine wants to be
// preempted, since this could have changed. If all conditions are
// satisfied, it adjusts the signal context to make it look like the
// signaled thread just called asyncPreempt and resumes the thread.
// asyncPreempt spills all registers and enters the scheduler.
//
// (An alternative would be to preempt in the signal handler itself.
// This would let the OS save and restore the register state and the
// runtime would only need to know how to extract potentially
// pointer-containing registers from the signal context. However, this
// would consume an M for every preempted G, and the scheduler itself
// is not designed to run from a signal handler, as it tends to
// allocate memory and start threads in the preemption path.)
package runtime
import (
"internal/abi"
"internal/goarch"
)
type suspendGState struct {
g *g
// dead indicates the goroutine was not suspended because it
// is dead. This goroutine could be reused after the dead
// state was observed, so the caller must not assume that it
// remains dead.
dead bool
// stopped indicates that this suspendG transitioned the G to
// _Gwaiting via g.preemptStop and thus is responsible for
// readying it when done.
stopped bool
}
// suspendG suspends goroutine gp at a safe-point and returns the
// state of the suspended goroutine. The caller gets read access to
// the goroutine until it calls resumeG.
//
// It is safe for multiple callers to attempt to suspend the same
// goroutine at the same time. The goroutine may execute between
// subsequent successful suspend operations. The current
// implementation grants exclusive access to the goroutine, and hence
// multiple callers will serialize. However, the intent is to grant
// shared read access, so please don't depend on exclusive access.
//
// This must be called from the system stack and the user goroutine on
// the current M (if any) must be in a preemptible state. This
// prevents deadlocks where two goroutines attempt to suspend each
// other and both are in non-preemptible states. There are other ways
// to resolve this deadlock, but this seems simplest.
//
// TODO(austin): What if we instead required this to be called from a
// user goroutine? Then we could deschedule the goroutine while
// waiting instead of blocking the thread. If two goroutines tried to
// suspend each other, one of them would win and the other wouldn't
// complete the suspend until it was resumed. We would have to be
// careful that they couldn't actually queue up suspend for each other
// and then both be suspended. This would also avoid the need for a
// kernel context switch in the synchronous case because we could just
// directly schedule the waiter. The context switch is unavoidable in
// the signal case.
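//
// A typical call sequence looks like this (illustrative sketch; the
// caller must already be on the system stack):
//
//	s := suspendG(gp)
//	if !s.dead {
//		// gp is parked at a safe-point; its stack may be read here.
//	}
//	resumeG(s)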
//
//go:systemstack
func suspendG(gp *g) suspendGState {
if mp := getg().m; mp.curg != nil && readgstatus(mp.curg) == _Grunning {
// Since we're on the system stack of this M, the user
// G is stuck at an unsafe point. If another goroutine
// were to try to preempt m.curg, it could deadlock.
throw("suspendG from non-preemptible goroutine")
}
// See https://golang.org/cl/21503 for justification of the yield delay.
const yieldDelay = 10 * 1000
var nextYield int64
// Drive the goroutine to a preemption point.
stopped := false
var asyncM *m
var asyncGen uint32
var nextPreemptM int64
for i := 0; ; i++ {
switch s := readgstatus(gp); s {
default:
if s&_Gscan != 0 {
// Someone else is suspending it. Wait
// for them to finish.
//
// TODO: It would be nicer if we could
// coalesce suspends.
break
}
dumpgstatus(gp)
throw("invalid g status")
case _Gdead:
// Nothing to suspend.
//
// preemptStop may need to be cleared, but
// doing that here could race with goroutine
// reuse. Instead, goexit0 clears it.
return suspendGState{dead: true}
case _Gcopystack:
// The stack is being copied. We need to wait
// until this is done.
case _Gpreempted:
// We (or someone else) suspended the G. Claim
// ownership of it by transitioning it to
// _Gwaiting.
if !casGFromPreempted(gp, _Gpreempted, _Gwaiting) {
break
}
// We stopped the G, so we have to ready it later.
stopped = true
s = _Gwaiting
fallthrough
case _Grunnable, _Gsyscall, _Gwaiting:
// Claim goroutine by setting scan bit.
// This may race with execution or readying of gp.
// The scan bit keeps it from transitioning state.
if !castogscanstatus(gp, s, s|_Gscan) {
break
}
// Clear the preemption request. It's safe to
// reset the stack guard because we hold the
// _Gscan bit and thus own the stack.
gp.preemptStop = false
gp.preempt = false
gp.stackguard0 = gp.stack.lo + _StackGuard
// The goroutine was already at a safe-point
// and we've now locked that in.
//
// TODO: It would be much better if we didn't
// leave it in _Gscan, but instead gently
// prevented its scheduling until resumption.
// Maybe we only use this to bump a suspended
// count and the scheduler skips suspended
// goroutines? That wouldn't be enough for
// {_Gsyscall,_Gwaiting} -> _Grunning. Maybe
// for all those transitions we need to check
// suspended and deschedule?
return suspendGState{g: gp, stopped: stopped}
case _Grunning:
// Optimization: if there is already a pending preemption request
// (from the previous loop iteration), don't bother with the atomics.
if gp.preemptStop && gp.preempt && gp.stackguard0 == stackPreempt && asyncM == gp.m && asyncM.preemptGen.Load() == asyncGen {
break
}
// Temporarily block state transitions.
if !castogscanstatus(gp, _Grunning, _Gscanrunning) {
break
}
// Request synchronous preemption.
gp.preemptStop = true
gp.preempt = true
gp.stackguard0 = stackPreempt
// Prepare for asynchronous preemption.
asyncM2 := gp.m
asyncGen2 := asyncM2.preemptGen.Load()
needAsync := asyncM != asyncM2 || asyncGen != asyncGen2
asyncM = asyncM2
asyncGen = asyncGen2
casfrom_Gscanstatus(gp, _Gscanrunning, _Grunning)
// Send asynchronous preemption. We do this
// after CASing the G back to _Grunning
// because preemptM may be synchronous and we
// don't want to catch the G just spinning on
// its status.
if preemptMSupported && debug.asyncpreemptoff == 0 && needAsync {
// Rate limit preemptM calls. This is
// particularly important on Windows
// where preemptM is actually
// synchronous and the spin loop here
// can lead to live-lock.
now := nanotime()
if now >= nextPreemptM {
nextPreemptM = now + yieldDelay/2
preemptM(asyncM)
}
}
}
// TODO: Don't busy wait. This loop should really only
// be a simple read/decide/CAS loop that only fails if
// there's an active race. Once the CAS succeeds, we
// should queue up the preemption (which will require
// it to be reliable in the _Grunning case, not
// best-effort) and then sleep until we're notified
// that the goroutine is suspended.
if i == 0 {
nextYield = nanotime() + yieldDelay
}
if nanotime() < nextYield {
procyield(10)
} else {
osyield()
nextYield = nanotime() + yieldDelay/2
}
}
}
// resumeG undoes the effects of suspendG, allowing the suspended
// goroutine to continue from its current safe-point.
func resumeG(state suspendGState) {
if state.dead {
// We didn't actually stop anything.
return
}
gp := state.g
switch s := readgstatus(gp); s {
default:
dumpgstatus(gp)
throw("unexpected g status")
case _Grunnable | _Gscan,
_Gwaiting | _Gscan,
_Gsyscall | _Gscan:
casfrom_Gscanstatus(gp, s, s&^_Gscan)
}
if state.stopped {
// We stopped it, so we need to re-schedule it.
ready(gp, 0, true)
}
}
// canPreemptM reports whether mp is in a state that is safe to preempt.
//
// It is nosplit because it has nosplit callers.
//
//go:nosplit
func canPreemptM(mp *m) bool {
return mp.locks == 0 && mp.mallocing == 0 && mp.preemptoff == "" && mp.p.ptr().status == _Prunning
}
//go:generate go run mkpreempt.go
// asyncPreempt saves all user registers and calls asyncPreempt2.
//
// When stack scanning encounters an asyncPreempt frame, it scans that
// frame and its parent frame conservatively.
//
// asyncPreempt is implemented in assembly.
func asyncPreempt()
//go:nosplit
func asyncPreempt2() {
gp := getg()
gp.asyncSafePoint = true
if gp.preemptStop {
mcall(preemptPark)
} else {
mcall(gopreempt_m)
}
gp.asyncSafePoint = false
}
// asyncPreemptStack is the bytes of stack space required to inject an
// asyncPreempt call.
var asyncPreemptStack = ^uintptr(0)
func init() {
f := findfunc(abi.FuncPCABI0(asyncPreempt))
total := funcMaxSPDelta(f)
f = findfunc(abi.FuncPCABIInternal(asyncPreempt2))
total += funcMaxSPDelta(f)
// Add some overhead for return PCs, etc.
asyncPreemptStack = uintptr(total) + 8*goarch.PtrSize
if asyncPreemptStack > _StackLimit {
// We need more than the nosplit limit. This isn't
// unsafe, but it may limit asynchronous preemption.
//
// This may be a problem if we start using more
// registers. In that case, we should store registers
// in a context object. If we pre-allocate one per P,
// asyncPreempt can spill just a few registers to the
// stack, then grab its context object and spill into
// it. When it enters the runtime, it would allocate a
// new context for the P.
print("runtime: asyncPreemptStack=", asyncPreemptStack, "\n")
throw("async stack too large")
}
}
// wantAsyncPreempt returns whether an asynchronous preemption is
// queued for gp.
func wantAsyncPreempt(gp *g) bool {
// Check both the G and the P.
return (gp.preempt || gp.m.p != 0 && gp.m.p.ptr().preempt) && readgstatus(gp)&^_Gscan == _Grunning
}
// isAsyncSafePoint reports whether gp at instruction PC is an
// asynchronous safe point. This indicates that:
//
// 1. It's safe to suspend gp and conservatively scan its stack and
// registers. There are no potentially hidden pointer values and it's
// not in the middle of an atomic sequence like a write barrier.
//
// 2. gp has enough stack space to inject the asyncPreempt call.
//
// 3. It's generally safe to interact with the runtime, even if we're
// in a signal handler stopped here. For example, there are no runtime
// locks held, so acquiring a runtime lock won't self-deadlock.
//
// In some cases the PC is safe for asynchronous preemption but it
// also needs to adjust the resumption PC. The new PC is returned in
// the second result.
func isAsyncSafePoint(gp *g, pc, sp, lr uintptr) (bool, uintptr) {
mp := gp.m
// Only user Gs can have safe-points. We check this first
// because it's extremely common that we'll catch mp in the
// scheduler processing this G preemption.
if mp.curg != gp {
return false, 0
}
// Check M state.
if mp.p == 0 || !canPreemptM(mp) {
return false, 0
}
// Check stack space.
if sp < gp.stack.lo || sp-gp.stack.lo < asyncPreemptStack {
return false, 0
}
// Check if PC is an unsafe-point.
f := findfunc(pc)
if !f.valid() {
// Not Go code.
return false, 0
}
if (GOARCH == "mips" || GOARCH == "mipsle" || GOARCH == "mips64" || GOARCH == "mips64le") && lr == pc+8 && funcspdelta(f, pc, nil) == 0 {
// We probably stopped at a half-executed CALL instruction,
// where the LR has been updated but the PC has not. If we
// preempt here we'll see a seemingly self-recursive call,
// which is in fact not recursive.
// This is normally ok, as we use the return address saved on
// stack for unwinding, not the LR value. But if this is a
// call to morestack, we haven't created the frame, and we'll
// use the LR for unwinding, which will be bad.
return false, 0
}
up, startpc := pcdatavalue2(f, _PCDATA_UnsafePoint, pc)
if up == _PCDATA_UnsafePointUnsafe {
// Unsafe-point marked by compiler. This includes
// atomic sequences (e.g., write barrier) and nosplit
// functions (except at calls).
return false, 0
}
if fd := funcdata(f, _FUNCDATA_LocalsPointerMaps); fd == nil || f.flag&funcFlag_ASM != 0 {
// This is assembly code. Don't assume it's well-formed.
// TODO: Empirically we still need the fd == nil check. Why?
//
// TODO: Are there cases that are safe but don't have a
// locals pointer map, like empty frame functions?
// It might be possible to preempt any assembly functions
// except the ones that have funcFlag_SPWRITE set in f.flag.
return false, 0
}
name := funcname(f)
if inldata := funcdata(f, _FUNCDATA_InlTree); inldata != nil {
inltree := (*[1 << 20]inlinedCall)(inldata)
ix := pcdatavalue(f, _PCDATA_InlTreeIndex, pc, nil)
if ix >= 0 {
name = funcnameFromNameOff(f, inltree[ix].nameOff)
}
}
if hasPrefix(name, "runtime.") ||
hasPrefix(name, "runtime/internal/") ||
hasPrefix(name, "reflect.") {
// For now we never async preempt the runtime or
// anything closely tied to the runtime. Known issues
// include: various points in the scheduler ("don't
// preempt between here and here"), much of the defer
// implementation (untyped info on stack), bulk write
// barriers (write barrier check),
// reflect.{makeFuncStub,methodValueCall}.
//
// TODO(austin): We should improve this, or opt things
// in incrementally.
return false, 0
}
switch up {
case _PCDATA_Restart1, _PCDATA_Restart2:
// Restartable instruction sequence. Back off PC to
// the start PC.
if startpc == 0 || startpc > pc || pc-startpc > 20 {
throw("bad restart PC")
}
return true, startpc
case _PCDATA_RestartAtEntry:
// Restart from the function entry at resumption.
return true, f.entry()
}
return true, pc
}
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !windows
package runtime
//go:nosplit
func osPreemptExtEnter(mp *m) {}
//go:nosplit
func osPreemptExtExit(mp *m) {}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"internal/goarch"
"unsafe"
)
// The compiler knows that a print of a value of this type
// should use printhex instead of printuint (decimal).
type hex uint64
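// bytes returns a []byte aliasing the bytes of s, without copying.
// The result must not be mutated.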
func bytes(s string) (ret []byte) {
rp := (*slice)(unsafe.Pointer(&ret))
sp := stringStructOf(&s)
rp.array = sp.str
rp.len = sp.len
rp.cap = sp.len
return
}
var (
// printBacklog is a circular buffer of messages written with the builtin
// print* functions, for use in postmortem analysis of core dumps.
printBacklog [512]byte
printBacklogIndex int
)
// recordForPanic maintains a circular buffer of messages written by the
// runtime leading up to a process crash, allowing the messages to be
// extracted from a core dump.
//
// The text written during a process crash (following "panic" or "fatal
// error") is not saved, since the goroutine stacks will generally be readable
// from the runtime data structures in the core file.
func recordForPanic(b []byte) {
printlock()
if panicking.Load() == 0 {
// Not actively crashing: maintain circular buffer of print output.
for i := 0; i < len(b); {
n := copy(printBacklog[printBacklogIndex:], b[i:])
i += n
printBacklogIndex += n
printBacklogIndex %= len(printBacklog)
}
}
printunlock()
}
var debuglock mutex
// The compiler emits calls to printlock and printunlock around
// the multiple calls that implement a single Go print or println
// statement. Some of the print helpers (printslice, for example)
// call print recursively. There is also the problem of a crash
// happening during the print routines and needing to acquire
// the print lock to print information about the crash.
// For both these reasons, let a thread acquire the printlock 'recursively'.
func printlock() {
mp := getg().m
mp.locks++ // do not reschedule between printlock++ and lock(&debuglock).
mp.printlock++
if mp.printlock == 1 {
lock(&debuglock)
}
mp.locks-- // now we know debuglock is held and holding up mp.locks for us.
}
func printunlock() {
mp := getg().m
mp.printlock--
if mp.printlock == 0 {
unlock(&debuglock)
}
}
// write to goroutine-local buffer if diverting output,
// or else standard error.
func gwrite(b []byte) {
if len(b) == 0 {
return
}
recordForPanic(b)
gp := getg()
// Don't use the writebuf if gp.m is dying. If we're in a panicking
// state, we want anything written through gwrite to appear on the
// terminal rather than be written to some buffer.
// Note that we can't just clear writebuf in the gp.m.dying case
// because a panic isn't allowed to have any write barriers.
if gp == nil || gp.writebuf == nil || gp.m.dying > 0 {
writeErr(b)
return
}
n := copy(gp.writebuf[len(gp.writebuf):cap(gp.writebuf)], b)
gp.writebuf = gp.writebuf[:len(gp.writebuf)+n]
}
func printsp() {
printstring(" ")
}
func printnl() {
printstring("\n")
}
func printbool(v bool) {
if v {
printstring("true")
} else {
printstring("false")
}
}
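// printfloat prints v in a fixed scientific format, +d.dddddde+ddd.
// For example (illustrative), printfloat(12345.678) writes "+1.234568e+004".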
func printfloat(v float64) {
switch {
case v != v:
printstring("NaN")
return
case v+v == v && v > 0:
printstring("+Inf")
return
case v+v == v && v < 0:
printstring("-Inf")
return
}
const n = 7 // digits printed
var buf [n + 7]byte
buf[0] = '+'
e := 0 // exp
if v == 0 {
if 1/v < 0 {
buf[0] = '-'
}
} else {
if v < 0 {
v = -v
buf[0] = '-'
}
// normalize
for v >= 10 {
e++
v /= 10
}
for v < 1 {
e--
v *= 10
}
// round
h := 5.0
for i := 0; i < n; i++ {
h /= 10
}
v += h
if v >= 10 {
e++
v /= 10
}
}
	// format: +d.dddddde+ddd
for i := 0; i < n; i++ {
s := int(v)
buf[i+2] = byte(s + '0')
v -= float64(s)
v *= 10
}
buf[1] = buf[2]
buf[2] = '.'
buf[n+2] = 'e'
buf[n+3] = '+'
if e < 0 {
e = -e
buf[n+3] = '-'
}
buf[n+4] = byte(e/100) + '0'
buf[n+5] = byte(e/10)%10 + '0'
buf[n+6] = byte(e%10) + '0'
gwrite(buf[:])
}
func printcomplex(c complex128) {
print("(", real(c), imag(c), "i)")
}
func printuint(v uint64) {
var buf [100]byte
i := len(buf)
for i--; i > 0; i-- {
buf[i] = byte(v%10 + '0')
if v < 10 {
break
}
v /= 10
}
gwrite(buf[i:])
}
func printint(v int64) {
if v < 0 {
printstring("-")
v = -v
}
printuint(uint64(v))
}
var minhexdigits = 0 // protected by printlock
func printhex(v uint64) {
const dig = "0123456789abcdef"
var buf [100]byte
i := len(buf)
for i--; i > 0; i-- {
buf[i] = dig[v%16]
if v < 16 && len(buf)-i >= minhexdigits {
break
}
v /= 16
}
i--
buf[i] = 'x'
i--
buf[i] = '0'
gwrite(buf[i:])
}
func printpointer(p unsafe.Pointer) {
printhex(uint64(uintptr(p)))
}
func printuintptr(p uintptr) {
printhex(uint64(p))
}
func printstring(s string) {
gwrite(bytes(s))
}
func printslice(s []byte) {
sp := (*slice)(unsafe.Pointer(&s))
print("[", len(s), "/", cap(s), "]")
printpointer(sp.array)
}
func printeface(e eface) {
print("(", e._type, ",", e.data, ")")
}
func printiface(i iface) {
print("(", i.tab, ",", i.data, ")")
}
// hexdumpWords prints a word-oriented hex dump of [p, end).
//
// If mark != nil, it will be called with each printed word's address
// and should return a character mark to appear just before that
// word's value. It can return 0 to indicate no mark.
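//
// For instance (illustrative sketch, markAddr being a hypothetical
// address to highlight):
//
//	hexdumpWords(p, end, func(addr uintptr) byte {
//		if addr == markAddr {
//			return '>'
//		}
//		return 0
//	})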
func hexdumpWords(p, end uintptr, mark func(uintptr) byte) {
printlock()
var markbuf [1]byte
markbuf[0] = ' '
minhexdigits = int(unsafe.Sizeof(uintptr(0)) * 2)
for i := uintptr(0); p+i < end; i += goarch.PtrSize {
if i%16 == 0 {
if i != 0 {
println()
}
print(hex(p+i), ": ")
}
if mark != nil {
markbuf[0] = mark(p + i)
if markbuf[0] == 0 {
markbuf[0] = ' '
}
}
gwrite(markbuf[:])
val := *(*uintptr)(unsafe.Pointer(p + i))
print(hex(val))
print(" ")
// Can we symbolize val?
fn := findfunc(val)
if fn.valid() {
print("<", funcname(fn), "+", hex(val-fn.entry()), "> ")
}
}
minhexdigits = 0
println()
printunlock()
}
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"internal/abi"
"internal/cpu"
"internal/goarch"
"runtime/internal/atomic"
"runtime/internal/sys"
"unsafe"
)
// set using cmd/go/internal/modload.ModInfoProg
var modinfo string
// Goroutine scheduler
// The scheduler's job is to distribute ready-to-run goroutines over worker threads.
//
// The main concepts are:
// G - goroutine.
// M - worker thread, or machine.
// P - processor, a resource that is required to execute Go code.
// M must have an associated P to execute Go code; however, it can be
// blocked or in a syscall without an associated P.
//
// Design doc at https://golang.org/s/go11sched.
// Worker thread parking/unparking.
// We need to balance between keeping enough running worker threads to utilize
// available hardware parallelism and parking excessive running worker threads
// to conserve CPU resources and power. This is not simple for two reasons:
// (1) scheduler state is intentionally distributed (in particular, per-P work
// queues), so it is not possible to compute global predicates on fast paths;
// (2) for optimal thread management we would need to know the future (don't park
// a worker thread when a new goroutine will be readied in near future).
//
// Three rejected approaches that would work badly:
// 1. Centralize all scheduler state (would inhibit scalability).
// 2. Direct goroutine handoff. That is, when we ready a new goroutine and there
// is a spare P, unpark a thread and hand it the P and the goroutine.
// This would lead to thread state thrashing: the thread that readied the
// goroutine can be out of work the very next moment, and we would need to park it.
// Also, it would destroy locality of computation as we want to preserve
// dependent goroutines on the same thread; and introduce additional latency.
// 3. Unpark an additional thread whenever we ready a goroutine and there is an
// idle P, but don't do handoff. This would lead to excessive thread parking/
// unparking as the additional threads will instantly park without discovering
// any work to do.
//
// The current approach:
//
// This approach applies to three primary sources of potential work: readying a
// goroutine, new/modified-earlier timers, and idle-priority GC. See below for
// additional details.
//
// We unpark an additional thread when we submit work if (this is wakep()):
// 1. There is an idle P, and
// 2. There are no "spinning" worker threads.
//
// A worker thread is considered spinning if it is out of local work and did
// not find work in the global run queue or netpoller; the spinning state is
// denoted in m.spinning and in sched.nmspinning. Threads unparked this way are
// also considered spinning; we don't do goroutine handoff so such threads are
// out of work initially. Spinning threads spin on looking for work in per-P
// run queues and timer heaps or from the GC before parking. If a spinning
// thread finds work it takes itself out of the spinning state and proceeds to
// execution. If it does not find work it takes itself out of the spinning
// state and then parks.
//
// If there is at least one spinning thread (sched.nmspinning>0), we don't
// unpark new threads when submitting work. To compensate for that, if the last
// spinning thread finds work and stops spinning, it must unpark a new spinning
// thread. This approach smooths out unjustified spikes of thread unparking,
// but at the same time guarantees eventual maximal CPU parallelism
// utilization.
//
// The main implementation complication is that we need to be very careful
// during spinning->non-spinning thread transition. This transition can race
// with submission of new work, and either one part or another needs to unpark
// another worker thread. If they both fail to do that, we can end up with
// semi-persistent CPU underutilization.
//
// The general pattern for submission is:
// 1. Submit work to the local run queue, timer heap, or GC state.
// 2. #StoreLoad-style memory barrier.
// 3. Check sched.nmspinning.
//
// The general pattern for spinning->non-spinning transition is:
// 1. Decrement nmspinning.
// 2. #StoreLoad-style memory barrier.
// 3. Check all per-P work queues and GC for new work.
//
// Note that all this complexity does not apply to global run queue as we are
// not sloppy about thread unparking when submitting to global queue. Also see
// comments for nmspinning manipulation.
//
// How these different sources of work behave varies, though it doesn't affect
// the synchronization approach:
// * Ready goroutine: this is an obvious source of work; the goroutine is
// immediately ready and must run on some thread eventually.
// * New/modified-earlier timer: The current timer implementation (see time.go)
// uses netpoll in a thread with no work available to wait for the soonest
// timer. If there is no thread waiting, we want a new spinning thread to go
// wait.
// * Idle-priority GC: The GC wakes a stopped idle thread to contribute to
// background GC work (note: currently disabled per golang.org/issue/19112).
// Also see golang.org/issue/44313, as this should be extended to all GC
// workers.
var (
m0 m
g0 g
mcache0 *mcache
raceprocctx0 uintptr
)
// This slice records the initializing tasks that need to be
// done to start up the runtime. It is built by the linker.
var runtime_inittasks []*initTask
// main_init_done is a signal used by cgocallbackg that initialization
// has been completed. It is made before _cgo_notify_runtime_init_done,
// so all cgo calls can rely on it existing. When main_init is complete,
// it is closed, meaning cgocallbackg can reliably receive from it.
var main_init_done chan bool
//go:linkname main_main main.main
func main_main()
// mainStarted indicates that the main M has started.
var mainStarted bool
// runtimeInitTime is the nanotime() at which the runtime started.
var runtimeInitTime int64
// Value to use for signal mask for newly created M's.
var initSigmask sigset
// The main goroutine.
func main() {
mp := getg().m
// Racectx of m0->g0 is used only as the parent of the main goroutine.
// It must not be used for anything else.
mp.g0.racectx = 0
// Max stack size is 1 GB on 64-bit, 250 MB on 32-bit.
// Using decimal instead of binary GB and MB because
// they look nicer in the stack overflow failure message.
if goarch.PtrSize == 8 {
maxstacksize = 1000000000
} else {
maxstacksize = 250000000
}
// An upper limit for max stack size. Used to avoid random crashes
// after calling SetMaxStack and trying to allocate a stack that is too big,
// since stackalloc works with 32-bit sizes.
maxstackceiling = 2 * maxstacksize
// Allow newproc to start new Ms.
mainStarted = true
if GOARCH != "wasm" { // no threads on wasm yet, so no sysmon
systemstack(func() {
newm(sysmon, nil, -1)
})
}
// Lock the main goroutine onto this, the main OS thread,
// during initialization. Most programs won't care, but a few
// do require certain calls to be made by the main thread.
// Those can arrange for main.main to run in the main thread
// by calling runtime.LockOSThread during initialization
// to preserve the lock.
lockOSThread()
if mp != &m0 {
throw("runtime.main not on m0")
}
// Record when the world started.
// Must be before doInit for tracing init.
runtimeInitTime = nanotime()
if runtimeInitTime == 0 {
throw("nanotime returning zero")
}
if debug.inittrace != 0 {
inittrace.id = getg().goid
inittrace.active = true
}
doInit(runtime_inittasks) // Must be before defer.
// Defer unlock so that runtime.Goexit during init does the unlock too.
needUnlock := true
defer func() {
if needUnlock {
unlockOSThread()
}
}()
gcenable()
main_init_done = make(chan bool)
if iscgo {
if _cgo_thread_start == nil {
throw("_cgo_thread_start missing")
}
if GOOS != "windows" {
if _cgo_setenv == nil {
throw("_cgo_setenv missing")
}
if _cgo_unsetenv == nil {
throw("_cgo_unsetenv missing")
}
}
if _cgo_notify_runtime_init_done == nil {
throw("_cgo_notify_runtime_init_done missing")
}
// Start the template thread in case we enter Go from
// a C-created thread and need to create a new thread.
startTemplateThread()
cgocall(_cgo_notify_runtime_init_done, nil)
}
// Run the initializing tasks. Depending on build mode this
// list can arrive a few different ways, but it will always
// contain the init tasks computed by the linker for all the
// packages in the program (excluding those added at runtime
// by package plugin).
for _, m := range activeModules() {
doInit(m.inittasks)
}
// Disable init tracing after main init done to avoid overhead
// of collecting statistics in malloc and newproc
inittrace.active = false
close(main_init_done)
needUnlock = false
unlockOSThread()
if isarchive || islibrary {
// A program compiled with -buildmode=c-archive or c-shared
// has a main, but it is not executed.
return
}
fn := main_main // make an indirect call, as the linker doesn't know the address of the main package when laying down the runtime
fn()
if raceenabled {
runExitHooks(0) // run hooks now, since racefini does not return
racefini()
}
// Make racy client program work: if panicking on
// another goroutine at the same time as main returns,
// let the other goroutine finish printing the panic trace.
// Once it does, it will exit. See issues 3934 and 20018.
if runningPanicDefers.Load() != 0 {
// Running deferred functions should not take long.
for c := 0; c < 1000; c++ {
if runningPanicDefers.Load() == 0 {
break
}
Gosched()
}
}
if panicking.Load() != 0 {
gopark(nil, nil, waitReasonPanicWait, traceEvGoStop, 1)
}
runExitHooks(0)
exit(0)
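	// Unreachable: exit does not return. Crash hard if it somehow does.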
for {
var x *int32
*x = 0
}
}
// os_beforeExit is called from os.Exit(0).
//
//go:linkname os_beforeExit os.runtime_beforeExit
func os_beforeExit(exitCode int) {
runExitHooks(exitCode)
if exitCode == 0 && raceenabled {
racefini()
}
}
// start forcegc helper goroutine
func init() {
go forcegchelper()
}
func forcegchelper() {
forcegc.g = getg()
lockInit(&forcegc.lock, lockRankForcegc)
for {
lock(&forcegc.lock)
if forcegc.idle.Load() {
throw("forcegc: phase error")
}
forcegc.idle.Store(true)
goparkunlock(&forcegc.lock, waitReasonForceGCIdle, traceEvGoBlock, 1)
// this goroutine is explicitly resumed by sysmon
if debug.gctrace > 0 {
println("GC forced")
}
// Time-triggered, fully concurrent.
gcStart(gcTrigger{kind: gcTriggerTime, now: nanotime()})
}
}
// Gosched yields the processor, allowing other goroutines to run. It does not
// suspend the current goroutine, so execution resumes automatically.
//
//go:nosplit
func Gosched() {
checkTimeouts()
mcall(gosched_m)
}
// goschedguarded yields the processor like gosched, but also checks
// for forbidden states and opts out of the yield in those cases.
//
//go:nosplit
func goschedguarded() {
mcall(goschedguarded_m)
}
// goschedIfBusy yields the processor like gosched, but skips the yield
// when there are idle Ps and no preemption is pending: in that case there
// is already freely available idle time, so yielding buys nothing.
//
//go:nosplit
func goschedIfBusy() {
gp := getg()
// Call gosched if gp.preempt is set; we may be in a tight loop that
// doesn't otherwise yield.
if !gp.preempt && sched.npidle.Load() > 0 {
return
}
mcall(gosched_m)
}
// Puts the current goroutine into a waiting state and calls unlockf on the
// system stack.
//
// If unlockf returns false, the goroutine is resumed.
//
// unlockf must not access this G's stack, as it may be moved between
// the call to gopark and the call to unlockf.
//
// Note that because unlockf is called after putting the G into a waiting
// state, the G may have already been readied by the time unlockf is called
// unless there is external synchronization preventing the G from being
// readied. If unlockf returns false, it must guarantee that the G cannot be
// externally readied.
//
// Reason explains why the goroutine has been parked. It is displayed in stack
// traces and heap dumps. Reasons should be unique and descriptive. Do not
// re-use reasons; add new ones.
func gopark(unlockf func(*g, unsafe.Pointer) bool, lock unsafe.Pointer, reason waitReason, traceEv byte, traceskip int) {
if reason != waitReasonSleep {
checkTimeouts() // timeouts may expire while two goroutines keep the scheduler busy
}
mp := acquirem()
gp := mp.curg
status := readgstatus(gp)
if status != _Grunning && status != _Gscanrunning {
throw("gopark: bad g status")
}
mp.waitlock = lock
mp.waitunlockf = unlockf
gp.waitreason = reason
mp.waittraceev = traceEv
mp.waittraceskip = traceskip
releasem(mp)
// can't do anything that might move the G between Ms here.
mcall(park_m)
}
// Puts the current goroutine into a waiting state and unlocks the lock.
// The goroutine can be made runnable again by calling goready(gp).
func goparkunlock(lock *mutex, reason waitReason, traceEv byte, traceskip int) {
gopark(parkunlock_c, unsafe.Pointer(lock), reason, traceEv, traceskip)
}
func goready(gp *g, traceskip int) {
systemstack(func() {
ready(gp, traceskip, true)
})
}
//go:nosplit
func acquireSudog() *sudog {
// Delicate dance: the semaphore implementation calls
// acquireSudog, acquireSudog calls new(sudog),
// new calls malloc, malloc can call the garbage collector,
// and the garbage collector calls the semaphore implementation
// in stopTheWorld.
// Break the cycle by doing acquirem/releasem around new(sudog).
// The acquirem/releasem increments m.locks during new(sudog),
// which keeps the garbage collector from being invoked.
mp := acquirem()
pp := mp.p.ptr()
if len(pp.sudogcache) == 0 {
lock(&sched.sudoglock)
// First, try to grab a batch from central cache.
for len(pp.sudogcache) < cap(pp.sudogcache)/2 && sched.sudogcache != nil {
s := sched.sudogcache
sched.sudogcache = s.next
s.next = nil
pp.sudogcache = append(pp.sudogcache, s)
}
unlock(&sched.sudoglock)
// If the central cache is empty, allocate a new one.
if len(pp.sudogcache) == 0 {
pp.sudogcache = append(pp.sudogcache, new(sudog))
}
}
n := len(pp.sudogcache)
s := pp.sudogcache[n-1]
pp.sudogcache[n-1] = nil
pp.sudogcache = pp.sudogcache[:n-1]
if s.elem != nil {
throw("acquireSudog: found s.elem != nil in cache")
}
releasem(mp)
return s
}
//go:nosplit
func releaseSudog(s *sudog) {
if s.elem != nil {
throw("runtime: sudog with non-nil elem")
}
if s.isSelect {
throw("runtime: sudog with non-false isSelect")
}
if s.next != nil {
throw("runtime: sudog with non-nil next")
}
if s.prev != nil {
throw("runtime: sudog with non-nil prev")
}
if s.waitlink != nil {
throw("runtime: sudog with non-nil waitlink")
}
if s.c != nil {
throw("runtime: sudog with non-nil c")
}
gp := getg()
if gp.param != nil {
throw("runtime: releaseSudog with non-nil gp.param")
}
mp := acquirem() // avoid rescheduling to another P
pp := mp.p.ptr()
if len(pp.sudogcache) == cap(pp.sudogcache) {
// Transfer half of local cache to the central cache.
var first, last *sudog
for len(pp.sudogcache) > cap(pp.sudogcache)/2 {
n := len(pp.sudogcache)
p := pp.sudogcache[n-1]
pp.sudogcache[n-1] = nil
pp.sudogcache = pp.sudogcache[:n-1]
if first == nil {
first = p
} else {
last.next = p
}
last = p
}
lock(&sched.sudoglock)
last.next = sched.sudogcache
sched.sudogcache = first
unlock(&sched.sudoglock)
}
pp.sudogcache = append(pp.sudogcache, s)
releasem(mp)
}
// called from assembly.
func badmcall(fn func(*g)) {
throw("runtime: mcall called on m->g0 stack")
}
func badmcall2(fn func(*g)) {
throw("runtime: mcall function returned")
}
func badreflectcall() {
panic(plainError("arg size to reflect.call more than 1GB"))
}
//go:nosplit
//go:nowritebarrierrec
func badmorestackg0() {
writeErrStr("fatal: morestack on g0\n")
}
//go:nosplit
//go:nowritebarrierrec
func badmorestackgsignal() {
writeErrStr("fatal: morestack on gsignal\n")
}
//go:nosplit
func badctxt() {
throw("ctxt != 0")
}
func lockedOSThread() bool {
gp := getg()
return gp.lockedm != 0 && gp.m.lockedg != 0
}
var (
// allgs contains all Gs ever created (including dead Gs), and thus
// never shrinks.
//
// Access via the slice is protected by allglock or stop-the-world.
// Readers that cannot take the lock may (carefully!) use the atomic
// variables below.
allglock mutex
allgs []*g
// allglen and allgptr are atomic variables that contain len(allgs) and
// &allgs[0] respectively. Proper ordering depends on totally-ordered
// loads and stores. Writes are protected by allglock.
//
// allgptr is updated before allglen. Readers should read allglen
// before allgptr to ensure that allglen never exceeds the length of
// the array allgptr points to. New
// Gs appended during the race can be missed. For a consistent view of
// all Gs, allglock must be held.
//
// allgptr copies should always be stored as a concrete type or
// unsafe.Pointer, not uintptr, to ensure that GC can still reach it
// even if it points to a stale array.
allglen uintptr
allgptr **g
)
func allgadd(gp *g) {
if readgstatus(gp) == _Gidle {
throw("allgadd: bad status Gidle")
}
lock(&allglock)
allgs = append(allgs, gp)
if &allgs[0] != allgptr {
atomicstorep(unsafe.Pointer(&allgptr), unsafe.Pointer(&allgs[0]))
}
atomic.Storeuintptr(&allglen, uintptr(len(allgs)))
unlock(&allglock)
}
// allGsSnapshot returns a snapshot of the slice of all Gs.
//
// The world must be stopped or allglock must be held.
func allGsSnapshot() []*g {
assertWorldStoppedOrLockHeld(&allglock)
// Because the world is stopped or allglock is held, allgadd
// cannot happen concurrently with this. allgs grows
// monotonically and existing entries never change, so we can
// simply return a copy of the slice header. For added safety,
// we trim everything past len because that can still change.
return allgs[:len(allgs):len(allgs)]
}
// atomicAllG returns &allgs[0] and len(allgs) for use with atomicAllGIndex.
func atomicAllG() (**g, uintptr) {
length := atomic.Loaduintptr(&allglen)
ptr := (**g)(atomic.Loadp(unsafe.Pointer(&allgptr)))
return ptr, length
}
// atomicAllGIndex returns ptr[i] with the allgptr returned from atomicAllG.
func atomicAllGIndex(ptr **g, i uintptr) *g {
return *(**g)(add(unsafe.Pointer(ptr), i*goarch.PtrSize))
}
// forEachG calls fn on every G from allgs.
//
// forEachG takes a lock to exclude concurrent addition of new Gs.
func forEachG(fn func(gp *g)) {
lock(&allglock)
for _, gp := range allgs {
fn(gp)
}
unlock(&allglock)
}
// forEachGRace calls fn on every G from allgs.
//
// forEachGRace avoids locking, but does not exclude addition of new Gs during
// execution, which may be missed.
func forEachGRace(fn func(gp *g)) {
ptr, length := atomicAllG()
for i := uintptr(0); i < length; i++ {
gp := atomicAllGIndex(ptr, i)
fn(gp)
}
return
}
const (
// Number of goroutine ids to grab from sched.goidgen to the local per-P cache at once.
// 16 seems to provide enough amortization, but other than that it's a mostly arbitrary number.
_GoidCacheBatch = 16
)
// cpuinit sets up CPU feature flags and calls internal/cpu.Initialize. env should be the complete
// value of the GODEBUG environment variable.
func cpuinit(env string) {
switch GOOS {
case "aix", "darwin", "ios", "dragonfly", "freebsd", "netbsd", "openbsd", "illumos", "solaris", "linux":
cpu.DebugOptions = true
}
cpu.Initialize(env)
// The CPU feature support variables below are used in code generated by the
// compiler to guard execution of instructions that cannot be assumed to
// always be supported.
switch GOARCH {
case "386", "amd64":
x86HasPOPCNT = cpu.X86.HasPOPCNT
x86HasSSE41 = cpu.X86.HasSSE41
x86HasFMA = cpu.X86.HasFMA
case "arm":
armHasVFPv4 = cpu.ARM.HasVFPv4
case "arm64":
arm64HasATOMICS = cpu.ARM64.HasATOMICS
}
}
// getGodebugEarly extracts the environment variable GODEBUG from the environment on
// Unix-like operating systems and returns it. This function exists to extract GODEBUG
// early before much of the runtime is initialized.
func getGodebugEarly() string {
const prefix = "GODEBUG="
var env string
switch GOOS {
case "aix", "darwin", "ios", "dragonfly", "freebsd", "netbsd", "openbsd", "illumos", "solaris", "linux":
// Similar to goenv_unix but extracts the environment value for
// GODEBUG directly.
// TODO(moehrmann): remove when general goenvs() can be called before cpuinit()
n := int32(0)
for argv_index(argv, argc+1+n) != nil {
n++
}
for i := int32(0); i < n; i++ {
p := argv_index(argv, argc+1+i)
s := unsafe.String(p, findnull(p))
if hasPrefix(s, prefix) {
env = gostring(p)[len(prefix):]
break
}
}
}
return env
}
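// The scan above relies on the Unix process image layout, in which the
// environment block immediately follows argv and its terminating nil:
//
//	argv[0], ..., argv[argc-1], nil, env[0], env[1], ..., nil
//
// so argv_index(argv, argc+1+i) addresses env[i].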
// The bootstrap sequence is:
//
// call osinit
// call schedinit
// make & queue new G
// call runtime·mstart
//
// The new G calls runtime·main.
func schedinit() {
lockInit(&sched.lock, lockRankSched)
lockInit(&sched.sysmonlock, lockRankSysmon)
lockInit(&sched.deferlock, lockRankDefer)
lockInit(&sched.sudoglock, lockRankSudog)
lockInit(&deadlock, lockRankDeadlock)
lockInit(&paniclk, lockRankPanic)
lockInit(&allglock, lockRankAllg)
lockInit(&allpLock, lockRankAllp)
lockInit(&reflectOffs.lock, lockRankReflectOffs)
lockInit(&finlock, lockRankFin)
lockInit(&trace.bufLock, lockRankTraceBuf)
lockInit(&trace.stringsLock, lockRankTraceStrings)
lockInit(&trace.lock, lockRankTrace)
lockInit(&cpuprof.lock, lockRankCpuprof)
lockInit(&trace.stackTab.lock, lockRankTraceStackTab)
// Enforce that this lock is always a leaf lock.
// All of this lock's critical sections should be
// extremely short.
lockInit(&memstats.heapStats.noPLock, lockRankLeafRank)
// raceinit must be the first call to race detector.
// In particular, it must be done before mallocinit below calls racemapshadow.
gp := getg()
if raceenabled {
gp.racectx, raceprocctx0 = raceinit()
}
sched.maxmcount = 10000
// The world starts stopped.
worldStopped()
moduledataverify()
stackinit()
mallocinit()
godebug := getGodebugEarly()
initPageTrace(godebug) // must run after mallocinit but before anything allocates
cpuinit(godebug) // must run before alginit
alginit() // maps, hash, fastrand must not be used before this call
fastrandinit() // must run before mcommoninit
mcommoninit(gp.m, -1)
modulesinit() // provides activeModules
typelinksinit() // uses maps, activeModules
itabsinit() // uses activeModules
stkobjinit() // must run before GC starts
sigsave(&gp.m.sigmask)
initSigmask = gp.m.sigmask
goargs()
goenvs()
parsedebugvars()
gcinit()
// If disableMemoryProfiling is set, update MemProfileRate to 0 to turn off memprofile.
// Note: parsedebugvars may update MemProfileRate, but when disableMemoryProfiling is
// set to true by the linker, nothing is consuming the profile, so it is
// safe to set MemProfileRate to 0.
if disableMemoryProfiling {
MemProfileRate = 0
}
lock(&sched.lock)
sched.lastpoll.Store(nanotime())
procs := ncpu
if n, ok := atoi32(gogetenv("GOMAXPROCS")); ok && n > 0 {
procs = n
}
if procresize(procs) != nil {
throw("unknown runnable goroutine during bootstrap")
}
unlock(&sched.lock)
// World is effectively started now, as P's can run.
worldStarted()
if buildVersion == "" {
// Condition should never trigger. This code just serves
// to ensure runtime·buildVersion is kept in the resulting binary.
buildVersion = "unknown"
}
if len(modinfo) == 1 {
// Condition should never trigger. This code just serves
// to ensure runtime·modinfo is kept in the resulting binary.
modinfo = ""
}
}
func dumpgstatus(gp *g) {
thisg := getg()
print("runtime: gp: gp=", gp, ", goid=", gp.goid, ", gp->atomicstatus=", readgstatus(gp), "\n")
print("runtime: getg: g=", thisg, ", goid=", thisg.goid, ", g->atomicstatus=", readgstatus(thisg), "\n")
}
// sched.lock must be held.
func checkmcount() {
assertLockHeld(&sched.lock)
if mcount() > sched.maxmcount {
print("runtime: program exceeds ", sched.maxmcount, "-thread limit\n")
throw("thread exhaustion")
}
}
// mReserveID returns the next ID to use for a new m. This new m is immediately
// considered 'running' by checkdead.
//
// sched.lock must be held.
func mReserveID() int64 {
assertLockHeld(&sched.lock)
if sched.mnext+1 < sched.mnext {
throw("runtime: thread ID overflow")
}
id := sched.mnext
sched.mnext++
checkmcount()
return id
}
// Pre-allocated ID may be passed as 'id', or omitted by passing -1.
func mcommoninit(mp *m, id int64) {
gp := getg()
// g0 stack won't make sense for user (and is not necessarily unwindable).
if gp != gp.m.g0 {
callers(1, mp.createstack[:])
}
lock(&sched.lock)
if id >= 0 {
mp.id = id
} else {
mp.id = mReserveID()
}
lo := uint32(int64Hash(uint64(mp.id), fastrandseed))
hi := uint32(int64Hash(uint64(cputicks()), ^fastrandseed))
if lo|hi == 0 {
hi = 1
}
// Same behavior as for 1.17.
// TODO: Simplify this.
if goarch.BigEndian {
mp.fastrand = uint64(lo)<<32 | uint64(hi)
} else {
mp.fastrand = uint64(hi)<<32 | uint64(lo)
}
mpreinit(mp)
if mp.gsignal != nil {
mp.gsignal.stackguard1 = mp.gsignal.stack.lo + _StackGuard
}
// Add to allm so garbage collector doesn't free g->m
// when it is just in a register or thread-local storage.
mp.alllink = allm
// NumCgoCall() iterates over allm w/o schedlock,
// so we need to publish it safely.
atomicstorep(unsafe.Pointer(&allm), unsafe.Pointer(mp))
unlock(&sched.lock)
// Allocate memory to hold a cgo traceback if the cgo call crashes.
if iscgo || GOOS == "solaris" || GOOS == "illumos" || GOOS == "windows" {
mp.cgoCallers = new(cgoCallers)
}
}
func (mp *m) becomeSpinning() {
mp.spinning = true
sched.nmspinning.Add(1)
sched.needspinning.Store(0)
}
var fastrandseed uintptr
func fastrandinit() {
s := (*[unsafe.Sizeof(fastrandseed)]byte)(unsafe.Pointer(&fastrandseed))[:]
getRandomData(s)
}
// Mark gp ready to run.
func ready(gp *g, traceskip int, next bool) {
if trace.enabled {
traceGoUnpark(gp, traceskip)
}
status := readgstatus(gp)
// Mark runnable.
mp := acquirem() // disable preemption because it can be holding p in a local var
if status&^_Gscan != _Gwaiting {
dumpgstatus(gp)
throw("bad g->status in ready")
}
// status is Gwaiting or Gscanwaiting, make Grunnable and put on runq
casgstatus(gp, _Gwaiting, _Grunnable)
runqput(mp.p.ptr(), gp, next)
wakep()
releasem(mp)
}
// freezeStopWait is a large value that freezetheworld sets
// sched.stopwait to in order to request that all Gs permanently stop.
const freezeStopWait = 0x7fffffff
// freezing is set to non-zero if the runtime is trying to freeze the
// world.
var freezing atomic.Bool
// Similar to stopTheWorld but best-effort and can be called several times.
// There is no reverse operation; it is used during crashing.
// This function must not lock any mutexes.
func freezetheworld() {
freezing.Store(true)
// stopwait and preemption requests can be lost
// due to races with concurrently executing threads,
// so try several times
for i := 0; i < 5; i++ {
// this should tell the scheduler to not start any new goroutines
sched.stopwait = freezeStopWait
sched.gcwaiting.Store(true)
// this should stop running goroutines
if !preemptall() {
break // no running goroutines
}
usleep(1000)
}
// to be sure
usleep(1000)
preemptall()
usleep(1000)
}
// All reads and writes of g's status go through readgstatus, casgstatus,
// castogscanstatus, and casfrom_Gscanstatus.
//
//go:nosplit
func readgstatus(gp *g) uint32 {
return gp.atomicstatus.Load()
}
// The Gscanstatuses are acting like locks and this releases them.
// If it proves to be a performance hit we should be able to make these
// simple atomic stores, but for now we are going to throw if
// we see an inconsistent state.
func casfrom_Gscanstatus(gp *g, oldval, newval uint32) {
success := false
// Check that transition is valid.
switch oldval {
default:
print("runtime: casfrom_Gscanstatus bad oldval gp=", gp, ", oldval=", hex(oldval), ", newval=", hex(newval), "\n")
dumpgstatus(gp)
throw("casfrom_Gscanstatus:top gp->status is not in scan state")
case _Gscanrunnable,
_Gscanwaiting,
_Gscanrunning,
_Gscansyscall,
_Gscanpreempted:
if newval == oldval&^_Gscan {
success = gp.atomicstatus.CompareAndSwap(oldval, newval)
}
}
if !success {
print("runtime: casfrom_Gscanstatus failed gp=", gp, ", oldval=", hex(oldval), ", newval=", hex(newval), "\n")
dumpgstatus(gp)
throw("casfrom_Gscanstatus: gp->status is not in scan state")
}
releaseLockRank(lockRankGscan)
}
// This will return false if the gp is not in the expected status and the cas fails.
// This acts like a lock acquire while casfrom_Gscanstatus acts like a lock release.
func castogscanstatus(gp *g, oldval, newval uint32) bool {
switch oldval {
case _Grunnable,
_Grunning,
_Gwaiting,
_Gsyscall:
if newval == oldval|_Gscan {
r := gp.atomicstatus.CompareAndSwap(oldval, newval)
if r {
acquireLockRank(lockRankGscan)
}
return r
}
}
print("runtime: castogscanstatus oldval=", hex(oldval), " newval=", hex(newval), "\n")
throw("castogscanstatus")
panic("not reached")
}
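// For illustration, the lock-like protocol these two functions form,
// as a scanner of another G's stack might use it (scan is a
// hypothetical placeholder):
//
//	if castogscanstatus(gp, _Grunnable, _Grunnable|_Gscan) {
//		// gp cannot start running while the _Gscan bit is held.
//		scan(gp)
//		casfrom_Gscanstatus(gp, _Grunnable|_Gscan, _Grunnable)
//	}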
// casgstatusAlwaysTrack is a debug flag that causes casgstatus to always track
// various latencies on every transition instead of sampling them.
var casgstatusAlwaysTrack = false
// If asked to move to or from a Gscan status, this will throw. Use castogscanstatus
// and casfrom_Gscanstatus instead.
// casgstatus will loop if the g->atomicstatus is in a Gscan status until the routine that
// put it in the Gscan state is finished.
//
//go:nosplit
func casgstatus(gp *g, oldval, newval uint32) {
if (oldval&_Gscan != 0) || (newval&_Gscan != 0) || oldval == newval {
systemstack(func() {
print("runtime: casgstatus: oldval=", hex(oldval), " newval=", hex(newval), "\n")
throw("casgstatus: bad incoming values")
})
}
acquireLockRank(lockRankGscan)
releaseLockRank(lockRankGscan)
// See https://golang.org/cl/21503 for justification of the yield delay.
const yieldDelay = 5 * 1000
var nextYield int64
// loop if gp->atomicstatus is in a scan state giving
// GC time to finish and change the state to oldval.
for i := 0; !gp.atomicstatus.CompareAndSwap(oldval, newval); i++ {
if oldval == _Gwaiting && gp.atomicstatus.Load() == _Grunnable {
throw("casgstatus: waiting for Gwaiting but is Grunnable")
}
if i == 0 {
nextYield = nanotime() + yieldDelay
}
if nanotime() < nextYield {
for x := 0; x < 10 && gp.atomicstatus.Load() != oldval; x++ {
procyield(1)
}
} else {
osyield()
nextYield = nanotime() + yieldDelay/2
}
}
if oldval == _Grunning {
// Track roughly one in every gTrackingPeriod transitions out of running.
if casgstatusAlwaysTrack || gp.trackingSeq%gTrackingPeriod == 0 {
gp.tracking = true
}
gp.trackingSeq++
}
if !gp.tracking {
return
}
// Handle various kinds of tracking.
//
// Currently:
// - Time spent in runnable.
// - Time spent blocked on a sync.Mutex or sync.RWMutex.
switch oldval {
case _Grunnable:
// We transitioned out of runnable, so measure how much
// time we spent in this state and add it to
// runnableTime.
now := nanotime()
gp.runnableTime += now - gp.trackingStamp
gp.trackingStamp = 0
case _Gwaiting:
if !gp.waitreason.isMutexWait() {
// Not blocking on a lock.
break
}
// Blocking on a lock, measure it. Note that because we're
// sampling, we have to multiply by our sampling period to get
// a more representative estimate of the absolute value.
// gTrackingPeriod also represents an accurate sampling period
// because we can only enter this state from _Grunning.
now := nanotime()
sched.totalMutexWaitTime.Add((now - gp.trackingStamp) * gTrackingPeriod)
gp.trackingStamp = 0
}
switch newval {
case _Gwaiting:
if !gp.waitreason.isMutexWait() {
// Not blocking on a lock.
break
}
// Blocking on a lock. Write down the timestamp.
now := nanotime()
gp.trackingStamp = now
case _Grunnable:
// We just transitioned into runnable, so record what
// time that happened.
now := nanotime()
gp.trackingStamp = now
case _Grunning:
// We're transitioning into running, so turn off
// tracking and record how much time we spent in
// runnable.
gp.tracking = false
sched.timeToRun.record(gp.runnableTime)
gp.runnableTime = 0
}
}
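// For illustration of the sampling arithmetic above: only about one in
// gTrackingPeriod transitions out of _Grunning has tracking enabled,
// so each measured mutex wait is scaled back up by the same factor,
//
//	sched.totalMutexWaitTime += measured * gTrackingPeriod
//
// which keeps the total an unbiased estimate in expectation.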
// casGToWaiting transitions gp from old to _Gwaiting, and sets the wait reason.
//
// Use this over casgstatus when possible to ensure that a waitreason is set.
func casGToWaiting(gp *g, old uint32, reason waitReason) {
// Set the wait reason before calling casgstatus, because casgstatus will use it.
gp.waitreason = reason
casgstatus(gp, old, _Gwaiting)
}
// casgstatus(gp, oldstatus, Gcopystack), assuming oldstatus is Gwaiting or Grunnable.
// Returns old status. Cannot call casgstatus directly, because we are racing with an
// async wakeup that might come in from netpoll. If we see Gwaiting from the readgstatus,
// it might have become Grunnable by the time we get to the cas. If we called casgstatus,
// it would loop waiting for the status to go back to Gwaiting, which it never will.
//
//go:nosplit
func casgcopystack(gp *g) uint32 {
for {
oldstatus := readgstatus(gp) &^ _Gscan
if oldstatus != _Gwaiting && oldstatus != _Grunnable {
throw("copystack: bad status, not Gwaiting or Grunnable")
}
if gp.atomicstatus.CompareAndSwap(oldstatus, _Gcopystack) {
return oldstatus
}
}
}
// casGToPreemptScan transitions gp from _Grunning to _Gscan|_Gpreempted.
//
// TODO(austin): This is the only status operation that both changes
// the status and locks the _Gscan bit. Rethink this.
func casGToPreemptScan(gp *g, old, new uint32) {
if old != _Grunning || new != _Gscan|_Gpreempted {
throw("bad g transition")
}
acquireLockRank(lockRankGscan)
for !gp.atomicstatus.CompareAndSwap(_Grunning, _Gscan|_Gpreempted) {
}
}
// casGFromPreempted attempts to transition gp from _Gpreempted to
// _Gwaiting. If successful, the caller is responsible for
// re-scheduling gp.
func casGFromPreempted(gp *g, old, new uint32) bool {
if old != _Gpreempted || new != _Gwaiting {
throw("bad g transition")
}
gp.waitreason = waitReasonPreempted
return gp.atomicstatus.CompareAndSwap(_Gpreempted, _Gwaiting)
}
// stopTheWorld stops all P's from executing goroutines, interrupting
// all goroutines at GC safe points and recording reason as the reason
// for the stop. On return, only the current goroutine's P is running.
// stopTheWorld must not be called from a system stack and the caller
// must not hold worldsema. The caller must call startTheWorld when
// other P's should resume execution.
//
// stopTheWorld is safe for multiple goroutines to call at the
// same time. Each will execute its own stop, and the stops will
// be serialized.
//
// This is also used by routines that do stack dumps. If the system is
// in panic or being exited, this may not reliably stop all
// goroutines.
func stopTheWorld(reason string) {
semacquire(&worldsema)
gp := getg()
gp.m.preemptoff = reason
systemstack(func() {
// Mark the goroutine which called stopTheWorld preemptible so its
// stack may be scanned.
// This lets a mark worker scan us while we try to stop the world
// since otherwise we could get in a mutual preemption deadlock.
// We must not modify anything on the G stack because a stack shrink
// may occur. A stack shrink is otherwise OK though because in order
// to return from this function (and to leave the system stack) we
// must have preempted all goroutines, including any attempting
// to scan our stack, in which case, any stack shrinking will
// have already completed by the time we exit.
// Don't provide a wait reason because we're still executing.
casGToWaiting(gp, _Grunning, waitReasonStoppingTheWorld)
stopTheWorldWithSema()
casgstatus(gp, _Gwaiting, _Grunning)
})
}
// startTheWorld undoes the effects of stopTheWorld.
func startTheWorld() {
systemstack(func() { startTheWorldWithSema(false) })
// worldsema must be held over startTheWorldWithSema to ensure
// gomaxprocs cannot change while worldsema is held.
//
// Release worldsema with direct handoff to the next waiter, but
// acquirem so that semrelease1 doesn't try to yield our time.
//
// Otherwise if e.g. ReadMemStats is being called in a loop,
// it might stomp on other attempts to stop the world, such as
// for starting or ending GC. The operation this blocks is
// so heavy-weight that we should just try to be as fair as
// possible here.
//
// We don't want to just allow ourselves to get preempted between now
// and releasing the semaphore because then we keep everyone
// (including, for example, GCs) waiting longer.
mp := acquirem()
mp.preemptoff = ""
semrelease1(&worldsema, true, 0)
releasem(mp)
}
// stopTheWorldGC has the same effect as stopTheWorld, but blocks
// until the GC is not running. It also blocks a GC from starting
// until startTheWorldGC is called.
func stopTheWorldGC(reason string) {
semacquire(&gcsema)
stopTheWorld(reason)
}
// startTheWorldGC undoes the effects of stopTheWorldGC.
func startTheWorldGC() {
startTheWorld()
semrelease(&gcsema)
}
// Holding worldsema grants an M the right to try to stop the world.
var worldsema uint32 = 1
// Holding gcsema grants the M the right to block a GC, and blocks
// until the current GC is done. In particular, it prevents gomaxprocs
// from changing concurrently.
//
// TODO(mknyszek): Once gomaxprocs and the execution tracer can handle
// being changed/enabled during a GC, remove this.
var gcsema uint32 = 1
// stopTheWorldWithSema is the core implementation of stopTheWorld.
// The caller is responsible for acquiring worldsema and disabling
// preemption first, and then should call stopTheWorldWithSema on the system
// stack:
//
// semacquire(&worldsema, 0)
// m.preemptoff = "reason"
// systemstack(stopTheWorldWithSema)
//
// When finished, the caller must either call startTheWorld or undo
// these three operations separately:
//
// m.preemptoff = ""
// systemstack(startTheWorldWithSema)
// semrelease(&worldsema)
//
// It is allowed to acquire worldsema once and then execute multiple
// startTheWorldWithSema/stopTheWorldWithSema pairs.
// Other P's are able to execute between successive calls to
// startTheWorldWithSema and stopTheWorldWithSema.
// Holding worldsema causes any other goroutines invoking
// stopTheWorld to block.
func stopTheWorldWithSema() {
gp := getg()
// If we hold a lock, then we won't be able to stop another M
// that is blocked trying to acquire the lock.
if gp.m.locks > 0 {
throw("stopTheWorld: holding locks")
}
lock(&sched.lock)
sched.stopwait = gomaxprocs
sched.gcwaiting.Store(true)
preemptall()
// stop current P
gp.m.p.ptr().status = _Pgcstop // Pgcstop is only diagnostic.
sched.stopwait--
// try to retake all P's in Psyscall status
for _, pp := range allp {
s := pp.status
if s == _Psyscall && atomic.Cas(&pp.status, s, _Pgcstop) {
if trace.enabled {
traceGoSysBlock(pp)
traceProcStop(pp)
}
pp.syscalltick++
sched.stopwait--
}
}
// stop idle P's
now := nanotime()
for {
pp, _ := pidleget(now)
if pp == nil {
break
}
pp.status = _Pgcstop
sched.stopwait--
}
wait := sched.stopwait > 0
unlock(&sched.lock)
// wait for remaining P's to stop voluntarily
if wait {
for {
// wait for 100us, then try to re-preempt in case of any races
if notetsleep(&sched.stopnote, 100*1000) {
noteclear(&sched.stopnote)
break
}
preemptall()
}
}
// sanity checks
bad := ""
if sched.stopwait != 0 {
bad = "stopTheWorld: not stopped (stopwait != 0)"
} else {
for _, pp := range allp {
if pp.status != _Pgcstop {
bad = "stopTheWorld: not stopped (status != _Pgcstop)"
}
}
}
if freezing.Load() {
// Some other thread is panicking. This can cause the
// sanity checks above to fail if the panic happens in
// the signal handler on a stopped thread. Either way,
// we should halt this thread.
lock(&deadlock)
lock(&deadlock)
}
if bad != "" {
throw(bad)
}
worldStopped()
}
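// For illustration of the stopwait accounting above, assuming
// gomaxprocs = 4: stopwait starts at 4; the current P contributes one
// decrement, every P retaken from _Psyscall or the idle list one more,
// and each remaining running P decrements it when it stops voluntarily,
// the last one waking sched.stopnote:
//
//	4 = 1 (our P) + #syscall + #idle + #running (stopped via preemption)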
func startTheWorldWithSema(emitTraceEvent bool) int64 {
assertWorldStopped()
mp := acquirem() // disable preemption because it can be holding p in a local var
if netpollinited() {
list := netpoll(0) // non-blocking
injectglist(&list)
}
lock(&sched.lock)
procs := gomaxprocs
if newprocs != 0 {
procs = newprocs
newprocs = 0
}
p1 := procresize(procs)
sched.gcwaiting.Store(false)
if sched.sysmonwait.Load() {
sched.sysmonwait.Store(false)
notewakeup(&sched.sysmonnote)
}
unlock(&sched.lock)
worldStarted()
for p1 != nil {
p := p1
p1 = p1.link.ptr()
if p.m != 0 {
mp := p.m.ptr()
p.m = 0
if mp.nextp != 0 {
throw("startTheWorld: inconsistent mp->nextp")
}
mp.nextp.set(p)
notewakeup(&mp.park)
} else {
// Start M to run P. Do not start another M below.
newm(nil, p, -1)
}
}
// Capture start-the-world time before doing clean-up tasks.
startTime := nanotime()
if emitTraceEvent {
traceGCSTWDone()
}
// Wake up an additional proc in case we have excessive runnable goroutines
// in local queues or in the global queue. If we don't, the proc will park itself.
// If we have lots of excessive work, resetspinning will unpark additional procs as necessary.
wakep()
releasem(mp)
return startTime
}
// usesLibcall indicates whether this runtime performs system calls
// via libcall.
func usesLibcall() bool {
switch GOOS {
case "aix", "darwin", "illumos", "ios", "solaris", "windows":
return true
case "openbsd":
return GOARCH == "386" || GOARCH == "amd64" || GOARCH == "arm" || GOARCH == "arm64"
}
return false
}
// mStackIsSystemAllocated indicates whether this runtime starts on a
// system-allocated stack.
func mStackIsSystemAllocated() bool {
switch GOOS {
case "aix", "darwin", "plan9", "illumos", "ios", "solaris", "windows":
return true
case "openbsd":
switch GOARCH {
case "386", "amd64", "arm", "arm64":
return true
}
}
return false
}
// mstart is the entry-point for new Ms.
// It is written in assembly, uses ABI0, is marked TOPFRAME, and calls mstart0.
func mstart()
// mstart0 is the Go entry-point for new Ms.
// This must not split the stack because we may not even have stack
// bounds set up yet.
//
// May run during STW (because it doesn't have a P yet), so write
// barriers are not allowed.
//
//go:nosplit
//go:nowritebarrierrec
func mstart0() {
gp := getg()
osStack := gp.stack.lo == 0
if osStack {
// Initialize stack bounds from system stack.
// Cgo may have left stack size in stack.hi.
// minit may update the stack bounds.
//
// Note: these bounds may not be very accurate.
// We set hi to &size, but there are things above
// it. The 1024 is supposed to compensate for this,
// but is somewhat arbitrary.
size := gp.stack.hi
if size == 0 {
size = 8192 * sys.StackGuardMultiplier
}
gp.stack.hi = uintptr(noescape(unsafe.Pointer(&size)))
gp.stack.lo = gp.stack.hi - size + 1024
}
// Initialize stack guard so that we can start calling regular
// Go code.
gp.stackguard0 = gp.stack.lo + _StackGuard
// This is the g0, so we can also call go:systemstack
// functions, which check stackguard1.
gp.stackguard1 = gp.stackguard0
mstart1()
// Exit this thread.
if mStackIsSystemAllocated() {
// Windows, Solaris, illumos, Darwin, AIX and Plan 9 always system-allocate
// the stack, but put it in gp.stack before mstart,
// so the logic above hasn't set osStack yet.
osStack = true
}
mexit(osStack)
}
// The go:noinline is to guarantee the getcallerpc/getcallersp below are safe,
// so that we can set up g0.sched to return to the call of mstart1 above.
//
//go:noinline
func mstart1() {
gp := getg()
if gp != gp.m.g0 {
throw("bad runtime·mstart")
}
// Set up m.g0.sched as a label returning to just
// after the mstart1 call in mstart0 above, for use by goexit0 and mcall.
// We're never coming back to mstart1 after we call schedule,
// so other calls can reuse the current frame.
// And goexit0 does a gogo that needs to return from mstart1
// and let mstart0 exit the thread.
gp.sched.g = guintptr(unsafe.Pointer(gp))
gp.sched.pc = getcallerpc()
gp.sched.sp = getcallersp()
asminit()
minit()
// Install signal handlers; after minit so that minit can
// prepare the thread to be able to handle the signals.
if gp.m == &m0 {
mstartm0()
}
if fn := gp.m.mstartfn; fn != nil {
fn()
}
if gp.m != &m0 {
acquirep(gp.m.nextp.ptr())
gp.m.nextp = 0
}
schedule()
}
// mstartm0 implements part of mstart1 that only runs on the m0.
//
// Write barriers are allowed here because we know the GC can't be
// running yet, so they'll be no-ops.
//
//go:yeswritebarrierrec
func mstartm0() {
// Create an extra M for callbacks on threads not created by Go.
// An extra M is also needed on Windows for callbacks created by
// syscall.NewCallback. See issue #6751 for details.
if (iscgo || GOOS == "windows") && !cgoHasExtraM {
cgoHasExtraM = true
newextram()
}
initsig(false)
}
// mPark causes a thread to park itself, returning once woken.
//
//go:nosplit
func mPark() {
gp := getg()
notesleep(&gp.m.park)
noteclear(&gp.m.park)
}
// mexit tears down and exits the current thread.
//
// Don't call this directly to exit the thread, since it must run at
// the top of the thread stack. Instead, use gogo(&gp.m.g0.sched) to
// unwind the stack to the point that exits the thread.
//
// It is entered with m.p != nil, so write barriers are allowed. It
// will release the P before exiting.
//
//go:yeswritebarrierrec
func mexit(osStack bool) {
mp := getg().m
if mp == &m0 {
// This is the main thread. Just wedge it.
//
// On Linux, exiting the main thread puts the process
// into a non-waitable zombie state. On Plan 9,
// exiting the main thread unblocks wait even though
// other threads are still running. On Solaris we can
// neither exitThread nor return from mstart. Other
// bad things probably happen on other platforms.
//
// We could try to clean up this M more before wedging
// it, but that complicates signal handling.
handoffp(releasep())
lock(&sched.lock)
sched.nmfreed++
checkdead()
unlock(&sched.lock)
mPark()
throw("locked m0 woke up")
}
sigblock(true)
unminit()
// Free the gsignal stack.
if mp.gsignal != nil {
stackfree(mp.gsignal.stack)
// On some platforms, when calling into VDSO (e.g. nanotime)
// we store our g on the gsignal stack, if there is one.
// Now that the stack is freed, unlink it from the m, so we
// won't write to it when calling VDSO code.
mp.gsignal = nil
}
// Remove m from allm.
lock(&sched.lock)
for pprev := &allm; *pprev != nil; pprev = &(*pprev).alllink {
if *pprev == mp {
*pprev = mp.alllink
goto found
}
}
throw("m not found in allm")
found:
// Delay reaping m until it's done with the stack.
//
// Put mp on the free list, though it will not be reaped while freeWait
// is freeMWait. mp is no longer reachable via allm, so even if it is
// on an OS stack, we must keep a reference to mp alive so that the GC
// doesn't free mp while we are still using it.
//
// Note that the free list must not be linked through alllink because
// some functions walk allm without locking, so may be using alllink.
mp.freeWait.Store(freeMWait)
mp.freelink = sched.freem
sched.freem = mp
unlock(&sched.lock)
atomic.Xadd64(&ncgocall, int64(mp.ncgocall))
// Release the P.
handoffp(releasep())
// After this point we must not have write barriers.
// Invoke the deadlock detector. This must happen after
// handoffp because it may have started a new M to take our
// P's work.
lock(&sched.lock)
sched.nmfreed++
checkdead()
unlock(&sched.lock)
if GOOS == "darwin" || GOOS == "ios" {
// Make sure pendingPreemptSignals is correct when an M exits.
// For #41702.
if mp.signalPending.Load() != 0 {
pendingPreemptSignals.Add(-1)
}
}
// Destroy all allocated resources. After this is called, we may no
// longer take any locks.
mdestroy(mp)
if osStack {
// No more uses of mp, so it is safe to drop the reference.
mp.freeWait.Store(freeMRef)
// Return from mstart and let the system thread
// library free the g0 stack and terminate the thread.
return
}
// mstart is the thread's entry point, so there's nothing to
// return to. Exit the thread directly. exitThread will clear
// m.freeWait when it's done with the stack and the m can be
// reaped.
exitThread(&mp.freeWait)
}
// forEachP calls fn(p) for every P p when p reaches a GC safe point.
// If a P is currently executing code, this will bring the P to a GC
// safe point and execute fn on that P. If the P is not executing code
// (it is idle or in a syscall), this will call fn(p) directly while
// preventing the P from exiting its state. This does not ensure that
// fn will run on every CPU executing Go code, but it acts as a global
// memory barrier. GC uses this as a "ragged barrier."
//
// The caller must hold worldsema.
//
//go:systemstack
func forEachP(fn func(*p)) {
mp := acquirem()
pp := getg().m.p.ptr()
lock(&sched.lock)
if sched.safePointWait != 0 {
throw("forEachP: sched.safePointWait != 0")
}
sched.safePointWait = gomaxprocs - 1
sched.safePointFn = fn
// Ask all Ps to run the safe point function.
for _, p2 := range allp {
if p2 != pp {
atomic.Store(&p2.runSafePointFn, 1)
}
}
preemptall()
// Any P entering _Pidle or _Psyscall from now on will observe
// p.runSafePointFn == 1 and will call runSafePointFn when
// changing its status to _Pidle/_Psyscall.
// Run safe point function for all idle Ps. sched.pidle will
// not change because we hold sched.lock.
for p := sched.pidle.ptr(); p != nil; p = p.link.ptr() {
if atomic.Cas(&p.runSafePointFn, 1, 0) {
fn(p)
sched.safePointWait--
}
}
wait := sched.safePointWait > 0
unlock(&sched.lock)
// Run fn for the current P.
fn(pp)
// Force Ps currently in _Psyscall into _Pidle and hand them
// off to induce safe point function execution.
for _, p2 := range allp {
s := p2.status
if s == _Psyscall && p2.runSafePointFn == 1 && atomic.Cas(&p2.status, s, _Pidle) {
if trace.enabled {
traceGoSysBlock(p2)
traceProcStop(p2)
}
p2.syscalltick++
handoffp(p2)
}
}
// Wait for remaining Ps to run fn.
if wait {
for {
// Wait for 100us, then try to re-preempt in
// case of any races.
//
// Requires system stack.
if notetsleep(&sched.safePointNote, 100*1000) {
noteclear(&sched.safePointNote)
break
}
preemptall()
}
}
if sched.safePointWait != 0 {
throw("forEachP: not done")
}
for _, p2 := range allp {
if p2.runSafePointFn != 0 {
throw("forEachP: P did not run fn")
}
}
lock(&sched.lock)
sched.safePointFn = nil
unlock(&sched.lock)
releasem(mp)
}
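// For illustration, the phases of the ragged barrier above:
//
//	1. Set sched.safePointFn and mark runSafePointFn on every other P.
//	2. preemptall: running Ps run fn at their next safe point.
//	3. Run fn directly for idle Ps, then for our own P.
//	4. Force Ps stuck in _Psyscall to _Pidle and hand them off.
//	5. Sleep and re-preempt until sched.safePointWait drops to zero.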
// runSafePointFn runs the safe point function, if any, for this P.
// This should be called like
//
// if getg().m.p.runSafePointFn != 0 {
// runSafePointFn()
// }
//
// runSafePointFn must be checked on any transition in to _Pidle or
// _Psyscall to avoid a race where forEachP sees that the P is running
// just before the P goes into _Pidle/_Psyscall and neither forEachP
// nor the P run the safe-point function.
func runSafePointFn() {
p := getg().m.p.ptr()
// Resolve the race between forEachP running the safe-point
// function on this P's behalf and this P running the
// safe-point function directly.
if !atomic.Cas(&p.runSafePointFn, 1, 0) {
return
}
sched.safePointFn(p)
lock(&sched.lock)
sched.safePointWait--
if sched.safePointWait == 0 {
notewakeup(&sched.safePointNote)
}
unlock(&sched.lock)
}
// When running with cgo, we call _cgo_thread_start
// to start threads for us so that we can play nicely with
// foreign code.
var cgoThreadStart unsafe.Pointer
type cgothreadstart struct {
g guintptr
tls *uint64
fn unsafe.Pointer
}
// Allocate a new m unassociated with any thread.
// Can use p for allocation context if needed.
// fn is recorded as the new m's m.mstartfn.
// id is optional pre-allocated m ID. Omit by passing -1.
//
// This function is allowed to have write barriers even if the caller
// isn't because it borrows pp.
//
//go:yeswritebarrierrec
func allocm(pp *p, fn func(), id int64) *m {
allocmLock.rlock()
// The caller owns pp, but we may borrow (i.e., acquirep) it. We must
// disable preemption to ensure it is not stolen, which would make the
// caller lose ownership.
acquirem()
gp := getg()
if gp.m.p == 0 {
acquirep(pp) // temporarily borrow p for mallocs in this function
}
// Release the free M list. We need to do this somewhere and
// this may free up a stack we can use.
if sched.freem != nil {
lock(&sched.lock)
var newList *m
for freem := sched.freem; freem != nil; {
wait := freem.freeWait.Load()
if wait == freeMWait {
next := freem.freelink
freem.freelink = newList
newList = freem
freem = next
continue
}
// Free the stack if needed. For freeMRef, there is
// nothing to do except drop freem from the sched.freem
// list.
if wait == freeMStack {
// stackfree must be on the system stack, but allocm is
// reachable off the system stack transitively from
// startm.
systemstack(func() {
stackfree(freem.g0.stack)
})
}
freem = freem.freelink
}
sched.freem = newList
unlock(&sched.lock)
}
mp := new(m)
mp.mstartfn = fn
mcommoninit(mp, id)
// With cgo, or on Solaris, illumos, or Darwin, pthread_create will make us a stack.
// Windows and Plan 9 will lay out the sched stack on the OS stack.
if iscgo || mStackIsSystemAllocated() {
mp.g0 = malg(-1)
} else {
mp.g0 = malg(8192 * sys.StackGuardMultiplier)
}
mp.g0.m = mp
if pp == gp.m.p.ptr() {
releasep()
}
releasem(gp.m)
allocmLock.runlock()
return mp
}
// needm is called when a cgo callback happens on a
// thread without an m (a thread not created by Go).
// In this case, needm is expected to find an m to use
// and return with m, g initialized correctly.
// Since m and g are not set now (likely nil, but see below)
// needm is limited in what routines it can call. In particular
// it can only call nosplit functions (textflag 7) and cannot
// do any scheduling that requires an m.
//
// In order to avoid needing heavy lifting here, we adopt
// the following strategy: there is a stack of available m's
// that can be stolen. Using compare-and-swap
// to pop from the stack has ABA races, so we simulate
// a lock by doing an exchange (via Casuintptr) to steal the stack
// head and replace the top pointer with MLOCKED (1).
// This serves as a simple spin lock that we can use even
// without an m. The thread that locks the stack in this way
// unlocks the stack by storing a valid stack head pointer.
//
// In order to make sure that there is always an m structure
// available to be stolen, we maintain the invariant that there
// is always one more than needed. At the beginning of the
// program (if cgo is in use) the list is seeded with a single m.
// If needm finds that it has taken the last m off the list, its job
// is - once it has installed its own m so that it can do things like
// allocate memory - to create a spare m and put it on the list.
//
// Each of these extra m's also has a g0 and a curg that are
// pressed into service as the scheduling stack and current
// goroutine for the duration of the cgo callback.
//
// When the callback is done with the m, it calls dropm to
// put the m back on the list.
//
//go:nosplit
func needm() {
if (iscgo || GOOS == "windows") && !cgoHasExtraM {
// Can happen if C/C++ code calls Go from a global ctor.
// Can also happen on Windows if a global ctor uses a
// callback created by syscall.NewCallback. See issue #6751
// for details.
//
// Can not throw, because scheduler is not initialized yet.
writeErrStr("fatal error: cgo callback before cgo call\n")
exit(1)
}
// Save and block signals before getting an M.
// The signal handler may call needm itself,
// and we must avoid a deadlock. Also, once g is installed,
// any incoming signals will try to execute,
// but we won't have the sigaltstack settings and other data
// set up appropriately until the end of minit, which will
// unblock the signals. This is the same dance as when
// starting a new m to run Go code via newosproc.
var sigmask sigset
sigsave(&sigmask)
sigblock(false)
// Lock extra list, take head, unlock popped list.
// nilokay=false is safe here because of the invariant above,
// that the extra list always contains or will soon contain
// at least one m.
mp := lockextra(false)
// Set needextram when we've just emptied the list,
// so that the eventual call into cgocallbackg will
// allocate a new m for the extra list. We delay the
// allocation until then so that it can be done
// after exitsyscall makes sure it is okay to be
// running at all (that is, there's no garbage collection
// running right now).
mp.needextram = mp.schedlink == 0
extraMCount--
unlockextra(mp.schedlink.ptr())
// Store the original signal mask for use by minit.
mp.sigmask = sigmask
// Install TLS on some platforms (previously setg
// would do this if necessary).
osSetupTLS(mp)
// Install g (= m->g0) and set the stack bounds
// to match the current stack. We don't actually know
// how big the stack is, like we don't know how big any
// scheduling stack is, but we assume there's at least 32 kB,
// which is more than enough for us.
setg(mp.g0)
gp := getg()
gp.stack.hi = getcallersp() + 1024
gp.stack.lo = getcallersp() - 32*1024
gp.stackguard0 = gp.stack.lo + _StackGuard
// Initialize this thread to use the m.
asminit()
minit()
// mp.curg is now a real goroutine.
casgstatus(mp.curg, _Gdead, _Gsyscall)
sched.ngsys.Add(-1)
}
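// For illustration, the life of an extra M around a cgo callback
// (cgocallbackg is the runtime's Go-side callback entry point):
//
//	needm()        // steal an M from the extra list; set up g0, TLS, signals
//	cgocallbackg() // run the Go code on mp.curg
//	dropm()        // return the M to the extra list; restore the signal mask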
// newextram allocates m's and puts them on the extra list.
// It is called with a working local m, so that it can do things
// like call schedlock and allocate.
func newextram() {
c := extraMWaiters.Swap(0)
if c > 0 {
for i := uint32(0); i < c; i++ {
oneNewExtraM()
}
} else {
// Make sure there is at least one extra M.
mp := lockextra(true)
unlockextra(mp)
if mp == nil {
oneNewExtraM()
}
}
}
// oneNewExtraM allocates an m and puts it on the extra list.
func oneNewExtraM() {
// Create extra goroutine locked to extra m.
// The goroutine is the context in which the cgo callback will run.
// The sched.pc will never be returned to, but setting it to
// goexit makes clear to the traceback routines where
// the goroutine stack ends.
mp := allocm(nil, nil, -1)
gp := malg(4096)
gp.sched.pc = abi.FuncPCABI0(goexit) + sys.PCQuantum
gp.sched.sp = gp.stack.hi
gp.sched.sp -= 4 * goarch.PtrSize // extra space in case of reads slightly beyond frame
gp.sched.lr = 0
gp.sched.g = guintptr(unsafe.Pointer(gp))
gp.syscallpc = gp.sched.pc
gp.syscallsp = gp.sched.sp
gp.stktopsp = gp.sched.sp
// malg returns status as _Gidle. Change to _Gdead before
// adding to allg where GC can see it. We use _Gdead to hide
// this from tracebacks and stack scans since it isn't a
// "real" goroutine until needm grabs it.
casgstatus(gp, _Gidle, _Gdead)
gp.m = mp
mp.curg = gp
mp.isextra = true
mp.lockedInt++
mp.lockedg.set(gp)
gp.lockedm.set(mp)
gp.goid = sched.goidgen.Add(1)
gp.sysblocktraced = true
if raceenabled {
gp.racectx = racegostart(abi.FuncPCABIInternal(newextram) + sys.PCQuantum)
}
if trace.enabled {
// Trigger two trace events for the locked g in the extra m,
// since the next event of the g will be traceEvGoSysExit in exitsyscall,
// while calling from C thread to Go.
traceGoCreate(gp, 0) // no start pc
gp.traceseq++
traceEvent(traceEvGoInSyscall, -1, gp.goid)
}
// put on allg for garbage collector
allgadd(gp)
// gp is now on the allg list, but we don't want it to be
// counted by gcount. It would be more "proper" to increment
// sched.ngfree, but that requires locking. Incrementing ngsys
// has the same effect.
sched.ngsys.Add(1)
// Add m to the extra list.
mnext := lockextra(true)
mp.schedlink.set(mnext)
extraMCount++
unlockextra(mp)
}
// dropm is called when a cgo callback has called needm but is now
// done with the callback and returning to the non-Go thread.
// It puts the current m back onto the extra list.
//
// The main expense here is the call to signalstack to release the
// m's signal stack, and then the call to needm on the next callback
// from this thread. It is tempting to try to save the m for next time,
// which would eliminate both these costs, but there might not be
// a next time: the current thread (which Go does not control) might exit.
// If we saved the m for that thread, there would be an m leak each time
// such a thread exited. Instead, we acquire and release an m on each
// call. These should typically not be scheduling operations, just a few
// atomics, so the cost should be small.
//
// TODO(rsc): An alternative would be to allocate a dummy pthread per-thread
// variable using pthread_key_create. Unlike the pthread keys we already use
// on OS X, this dummy key would never be read by Go code. It would exist
// only so that we could register a thread-exit-time destructor.
// That destructor would put the m back onto the extra list.
// This is purely a performance optimization. The current version,
// in which dropm happens on each cgo call, is still correct too.
// We may have to keep the current version on systems with cgo
// but without pthreads, like Windows.
func dropm() {
// Clear m and g, and return m to the extra list.
// After the call to setg we can only call nosplit functions
// with no pointer manipulation.
mp := getg().m
// Return mp.curg to dead state.
casgstatus(mp.curg, _Gsyscall, _Gdead)
mp.curg.preemptStop = false
sched.ngsys.Add(1)
// Block signals before unminit.
// Unminit unregisters the signal handling stack (but needs g on some systems).
// Setg(nil) clears g, which is the signal handler's cue not to run Go handlers.
// It's important not to try to handle a signal between those two steps.
sigmask := mp.sigmask
sigblock(false)
unminit()
mnext := lockextra(true)
extraMCount++
mp.schedlink.set(mnext)
setg(nil)
// Commit the release of mp.
unlockextra(mp)
msigrestore(sigmask)
}
// A helper function for EnsureDropM.
func getm() uintptr {
return uintptr(unsafe.Pointer(getg().m))
}
var extram atomic.Uintptr
var extraMCount uint32 // Protected by lockextra
var extraMWaiters atomic.Uint32
// lockextra locks the extra list and returns the list head.
// The caller must unlock the list by storing a new list head
// to extram. If nilokay is true, then lockextra will
// return a nil list head if that's what it finds. If nilokay is false,
// lockextra will keep waiting until the list head is no longer nil.
//
//go:nosplit
func lockextra(nilokay bool) *m {
const locked = 1
incr := false
for {
old := extram.Load()
if old == locked {
osyield_no_g()
continue
}
if old == 0 && !nilokay {
if !incr {
// Add 1 to the number of threads
// waiting for an M.
// This is cleared by newextram.
extraMWaiters.Add(1)
incr = true
}
usleep_no_g(1)
continue
}
if extram.CompareAndSwap(old, locked) {
return (*m)(unsafe.Pointer(old))
}
osyield_no_g()
continue
}
}
//go:nosplit
func unlockextra(mp *m) {
extram.Store(uintptr(unsafe.Pointer(mp)))
}
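// Together, lockextra and unlockextra form a spin lock over a single
// word: the sentinel value 1 temporarily replaces the list head, and
// storing a real head releases it. The same pattern in miniature, with
// a hypothetical atomic head word:
//
//	for {
//		old := head.Load()
//		if old == locked {
//			osyield_no_g()
//			continue
//		}
//		if head.CompareAndSwap(old, locked) {
//			break // we own the list; head.Store(newHead) unlocks it
//		}
//	}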
var (
// allocmLock is locked for read when creating new Ms in allocm and their
// addition to allm. Thus acquiring this lock for write blocks the
// creation of new Ms.
allocmLock rwmutex
// execLock serializes exec and clone to avoid bugs or unspecified
// behaviour around exec'ing while creating/destroying threads. See
// issue #19546.
execLock rwmutex
)
// These errors are reported (via writeErrStr) by some OS-specific
// versions of newosproc and newosproc0.
const (
failthreadcreate = "runtime: failed to create new OS thread\n"
failallocatestack = "runtime: failed to allocate stack for the new OS thread\n"
)
// newmHandoff contains a list of m structures that need new OS threads.
// This is used by newm in situations where newm itself can't safely
// start an OS thread.
var newmHandoff struct {
lock mutex
// newm points to a list of M structures that need new OS
// threads. The list is linked through m.schedlink.
newm muintptr
// waiting indicates that wake needs to be notified when an m
// is put on the list.
waiting bool
wake note
// haveTemplateThread indicates that the templateThread has
// been started. This is not protected by lock. Use cas to set
// to 1.
haveTemplateThread uint32
}
// Create a new m. It will start off with a call to fn, or else the scheduler.
// fn needs to be static and not a heap-allocated closure.
// May run with m.p==nil, so write barriers are not allowed.
//
// id is optional pre-allocated m ID. Omit by passing -1.
//
//go:nowritebarrierrec
func newm(fn func(), pp *p, id int64) {
// allocm adds a new M to allm, but they do not start until created by
// the OS in newm1 or the template thread.
//
// doAllThreadsSyscall requires that every M in allm will eventually
// start and be signal-able, even with a STW.
//
// Disable preemption here until we start the thread to ensure that
// newm is not preempted between allocm and starting the new thread,
// ensuring that anything added to allm is guaranteed to eventually
// start.
acquirem()
mp := allocm(pp, fn, id)
mp.nextp.set(pp)
mp.sigmask = initSigmask
if gp := getg(); gp != nil && gp.m != nil && (gp.m.lockedExt != 0 || gp.m.incgo) && GOOS != "plan9" {
// We're on a locked M or a thread that may have been
// started by C. The kernel state of this thread may
// be strange (the user may have locked it for that
// purpose). We don't want to clone that into another
// thread. Instead, ask a known-good thread to create
// the thread for us.
//
// This is disabled on Plan 9. See golang.org/issue/22227.
//
// TODO: This may be unnecessary on Windows, which
// doesn't model thread creation off fork.
lock(&newmHandoff.lock)
if newmHandoff.haveTemplateThread == 0 {
throw("on a locked thread with no template thread")
}
mp.schedlink = newmHandoff.newm
newmHandoff.newm.set(mp)
if newmHandoff.waiting {
newmHandoff.waiting = false
notewakeup(&newmHandoff.wake)
}
unlock(&newmHandoff.lock)
// The M has not started yet, but the template thread does not
// participate in STW, so it will always process queued Ms and
// it is safe to releasem.
releasem(getg().m)
return
}
newm1(mp)
releasem(getg().m)
}
func newm1(mp *m) {
if iscgo {
var ts cgothreadstart
if _cgo_thread_start == nil {
throw("_cgo_thread_start missing")
}
ts.g.set(mp.g0)
ts.tls = (*uint64)(unsafe.Pointer(&mp.tls[0]))
ts.fn = unsafe.Pointer(abi.FuncPCABI0(mstart))
if msanenabled {
msanwrite(unsafe.Pointer(&ts), unsafe.Sizeof(ts))
}
if asanenabled {
asanwrite(unsafe.Pointer(&ts), unsafe.Sizeof(ts))
}
execLock.rlock() // Prevent process clone.
asmcgocall(_cgo_thread_start, unsafe.Pointer(&ts))
execLock.runlock()
return
}
execLock.rlock() // Prevent process clone.
newosproc(mp)
execLock.runlock()
}
// startTemplateThread starts the template thread if it is not already
// running.
//
// The calling thread must itself be in a known-good state.
func startTemplateThread() {
if GOARCH == "wasm" { // no threads on wasm yet
return
}
// Disable preemption to guarantee that the template thread will be
// created before a park once haveTemplateThread is set.
mp := acquirem()
if !atomic.Cas(&newmHandoff.haveTemplateThread, 0, 1) {
releasem(mp)
return
}
newm(templateThread, nil, -1)
releasem(mp)
}
// templateThread is a thread in a known-good state that exists solely
// to start new threads in known-good states when the calling thread
// may not be in a good state.
//
// Many programs never need this, so templateThread is started lazily
// when we first enter a state that might lead to running on a thread
// in an unknown state.
//
// templateThread runs on an M without a P, so it must not have write
// barriers.
//
//go:nowritebarrierrec
func templateThread() {
lock(&sched.lock)
sched.nmsys++
checkdead()
unlock(&sched.lock)
for {
lock(&newmHandoff.lock)
for newmHandoff.newm != 0 {
newm := newmHandoff.newm.ptr()
newmHandoff.newm = 0
unlock(&newmHandoff.lock)
for newm != nil {
next := newm.schedlink.ptr()
newm.schedlink = 0
newm1(newm)
newm = next
}
lock(&newmHandoff.lock)
}
newmHandoff.waiting = true
noteclear(&newmHandoff.wake)
unlock(&newmHandoff.lock)
notesleep(&newmHandoff.wake)
}
}
// Stops execution of the current m until new work is available.
// Returns with acquired P.
func stopm() {
gp := getg()
if gp.m.locks != 0 {
throw("stopm holding locks")
}
if gp.m.p != 0 {
throw("stopm holding p")
}
if gp.m.spinning {
throw("stopm spinning")
}
lock(&sched.lock)
mput(gp.m)
unlock(&sched.lock)
mPark()
acquirep(gp.m.nextp.ptr())
gp.m.nextp = 0
}
func mspinning() {
// startm's caller incremented nmspinning. Set the new M's spinning.
getg().m.spinning = true
}
// Schedules some M to run the p (creates an M if necessary).
// If p==nil, tries to get an idle P; if there are no idle P's, it does nothing.
// May run with m.p==nil, so write barriers are not allowed.
// If spinning is set, the caller has incremented nmspinning and must provide a
// P. startm will set m.spinning in the newly started M.
//
// Callers passing a non-nil P must call from a non-preemptible context. See
// comment on acquirem below.
//
// Must not have write barriers because this may be called without a P.
//
//go:nowritebarrierrec
func startm(pp *p, spinning bool) {
// Disable preemption.
//
// Every owned P must have an owner that will eventually stop it in the
// event of a GC stop request. startm takes transient ownership of a P
// (either from argument or pidleget below) and transfers ownership to
// a started M, which will be responsible for performing the stop.
//
// Preemption must be disabled during this transient ownership,
// otherwise the P this is running on may enter GC stop while still
// holding the transient P, leaving that P in limbo and deadlocking the
// STW.
//
// Callers passing a non-nil P must already be in non-preemptible
// context, otherwise such preemption could occur on function entry to
// startm. Callers passing a nil P may be preemptible, so we must
// disable preemption before acquiring a P from pidleget below.
mp := acquirem()
lock(&sched.lock)
if pp == nil {
if spinning {
// TODO(prattmic): All remaining calls to this function
// with pp == nil could be cleaned up to find a P
// before calling startm.
throw("startm: P required for spinning=true")
}
pp, _ = pidleget(0)
if pp == nil {
unlock(&sched.lock)
releasem(mp)
return
}
}
nmp := mget()
if nmp == nil {
// No M is available, we must drop sched.lock and call newm.
// However, we already own a P to assign to the M.
//
// Once sched.lock is released, another G (e.g., in a syscall),
// could find no idle P while checkdead finds a runnable G but
// no running M's because this new M hasn't started yet, thus
// throwing in an apparent deadlock.
//
// Avoid this situation by pre-allocating the ID for the new M,
// thus marking it as 'running' before we drop sched.lock. This
// new M will eventually run the scheduler to execute any
// queued G's.
id := mReserveID()
unlock(&sched.lock)
var fn func()
if spinning {
// The caller incremented nmspinning, so set m.spinning in the new M.
fn = mspinning
}
newm(fn, pp, id)
// Ownership transfer of pp committed by start in newm.
// Preemption is now safe.
releasem(mp)
return
}
unlock(&sched.lock)
if nmp.spinning {
throw("startm: m is spinning")
}
if nmp.nextp != 0 {
throw("startm: m has p")
}
if spinning && !runqempty(pp) {
throw("startm: p has runnable gs")
}
// The caller incremented nmspinning, so set m.spinning in the new M.
nmp.spinning = spinning
nmp.nextp.set(pp)
notewakeup(&nmp.park)
// Ownership transfer of pp committed by wakeup. Preemption is now
// safe.
releasem(mp)
}
// Hands off P from syscall or locked M.
// Always runs without a P, so write barriers are not allowed.
//
//go:nowritebarrierrec
func handoffp(pp *p) {
// handoffp must start an M in any situation where
// findRunnable would return a G to run on pp.
// if it has local work, start it straight away
if !runqempty(pp) || sched.runqsize != 0 {
startm(pp, false)
return
}
// if there's trace work to do, start it straight away
if (trace.enabled || trace.shutdown) && traceReaderAvailable() != nil {
startm(pp, false)
return
}
// if it has GC work, start it straight away
if gcBlackenEnabled != 0 && gcMarkWorkAvailable(pp) {
startm(pp, false)
return
}
// no local work, check that there are no spinning/idle M's,
// otherwise our help is not required
if sched.nmspinning.Load()+sched.npidle.Load() == 0 && sched.nmspinning.CompareAndSwap(0, 1) { // TODO: fast atomic
sched.needspinning.Store(0)
startm(pp, true)
return
}
lock(&sched.lock)
if sched.gcwaiting.Load() {
pp.status = _Pgcstop
sched.stopwait--
if sched.stopwait == 0 {
notewakeup(&sched.stopnote)
}
unlock(&sched.lock)
return
}
if pp.runSafePointFn != 0 && atomic.Cas(&pp.runSafePointFn, 1, 0) {
sched.safePointFn(pp)
sched.safePointWait--
if sched.safePointWait == 0 {
notewakeup(&sched.safePointNote)
}
}
if sched.runqsize != 0 {
unlock(&sched.lock)
startm(pp, false)
return
}
// If this is the last running P and nobody is polling network,
// we need to wake up another M to poll the network.
if sched.npidle.Load() == gomaxprocs-1 && sched.lastpoll.Load() != 0 {
unlock(&sched.lock)
startm(pp, false)
return
}
// The scheduler lock cannot be held when calling wakeNetPoller below
// because wakeNetPoller may call wakep which may call startm.
when := nobarrierWakeTime(pp)
pidleput(pp, 0)
unlock(&sched.lock)
if when != 0 {
wakeNetPoller(when)
}
}
// Tries to add one more P to execute G's.
// Called when a G is made runnable (newproc, ready).
// Must be called with a P.
func wakep() {
// Be conservative about spinning threads, only start one if none exist
// already.
if sched.nmspinning.Load() != 0 || !sched.nmspinning.CompareAndSwap(0, 1) {
return
}
// Disable preemption until ownership of pp transfers to the next M in
// startm. Otherwise preemption here would leave pp stuck waiting to
// enter _Pgcstop.
//
// See preemption comment on acquirem in startm for more details.
mp := acquirem()
var pp *p
lock(&sched.lock)
pp, _ = pidlegetSpinning(0)
if pp == nil {
if sched.nmspinning.Add(-1) < 0 {
throw("wakep: negative nmspinning")
}
unlock(&sched.lock)
releasem(mp)
return
}
// Since we always have a P, the race in the "No M is available"
// comment in startm doesn't apply during the small window between the
// unlock here and lock in startm. A checkdead in between will always
// see at least one running M (ours).
unlock(&sched.lock)
startm(pp, true)
releasem(mp)
}
// Stops execution of the current m that is locked to a g until the g is runnable again.
// Returns with acquired P.
func stoplockedm() {
gp := getg()
if gp.m.lockedg == 0 || gp.m.lockedg.ptr().lockedm.ptr() != gp.m {
throw("stoplockedm: inconsistent locking")
}
if gp.m.p != 0 {
// Schedule another M to run this p.
pp := releasep()
handoffp(pp)
}
incidlelocked(1)
// Wait until another thread schedules lockedg again.
mPark()
status := readgstatus(gp.m.lockedg.ptr())
if status&^_Gscan != _Grunnable {
print("runtime:stoplockedm: lockedg (atomicstatus=", status, ") is not Grunnable or Gscanrunnable\n")
dumpgstatus(gp.m.lockedg.ptr())
throw("stoplockedm: not runnable")
}
acquirep(gp.m.nextp.ptr())
gp.m.nextp = 0
}
// Schedules the locked m to run the locked gp.
// May run during STW, so write barriers are not allowed.
//
//go:nowritebarrierrec
func startlockedm(gp *g) {
mp := gp.lockedm.ptr()
if mp == getg().m {
throw("startlockedm: locked to me")
}
if mp.nextp != 0 {
throw("startlockedm: m has p")
}
// directly hand off the current P to the locked m
incidlelocked(-1)
pp := releasep()
mp.nextp.set(pp)
notewakeup(&mp.park)
stopm()
}
// Stops the current m for stopTheWorld.
// Returns when the world is restarted.
func gcstopm() {
gp := getg()
if !sched.gcwaiting.Load() {
throw("gcstopm: not waiting for gc")
}
if gp.m.spinning {
gp.m.spinning = false
// OK to just drop nmspinning here,
// startTheWorld will unpark threads as necessary.
if sched.nmspinning.Add(-1) < 0 {
throw("gcstopm: negative nmspinning")
}
}
pp := releasep()
lock(&sched.lock)
pp.status = _Pgcstop
sched.stopwait--
if sched.stopwait == 0 {
notewakeup(&sched.stopnote)
}
unlock(&sched.lock)
stopm()
}
// Schedules gp to run on the current M.
// If inheritTime is true, gp inherits the remaining time in the
// current time slice. Otherwise, it starts a new time slice.
// Never returns.
//
// Write barriers are allowed because this is called immediately after
// acquiring a P in several places.
//
//go:yeswritebarrierrec
func execute(gp *g, inheritTime bool) {
mp := getg().m
if goroutineProfile.active {
// Make sure that gp has had its stack written out to the goroutine
// profile, exactly as it was when the goroutine profiler first stopped
// the world.
tryRecordGoroutineProfile(gp, osyield)
}
// Assign gp.m before entering _Grunning so running Gs have an
// M.
mp.curg = gp
gp.m = mp
casgstatus(gp, _Grunnable, _Grunning)
gp.waitsince = 0
gp.preempt = false
gp.stackguard0 = gp.stack.lo + _StackGuard
if !inheritTime {
mp.p.ptr().schedtick++
}
// Check whether the profiler needs to be turned on or off.
hz := sched.profilehz
if mp.profilehz != hz {
setThreadCPUProfiler(hz)
}
if trace.enabled {
// GoSysExit has to happen when we have a P, but before GoStart.
// So we emit it here.
if gp.syscallsp != 0 && gp.sysblocktraced {
traceGoSysExit(gp.sysexitticks)
}
traceGoStart()
}
gogo(&gp.sched)
}
// Finds a runnable goroutine to execute.
// Tries to steal from other P's, get g from local or global queue, poll network.
// tryWakeP indicates that the returned goroutine is not normal (GC worker, trace
// reader) so the caller should try to wake a P.
func findRunnable() (gp *g, inheritTime, tryWakeP bool) {
mp := getg().m
// The conditions here and in handoffp must agree: if
// findrunnable would return a G to run, handoffp must start
// an M.
top:
pp := mp.p.ptr()
if sched.gcwaiting.Load() {
gcstopm()
goto top
}
if pp.runSafePointFn != 0 {
runSafePointFn()
}
// now and pollUntil are saved for work stealing later,
// which may steal timers. It's important that between now
// and then, nothing blocks, so these numbers remain mostly
// relevant.
now, pollUntil, _ := checkTimers(pp, 0)
// Try to schedule the trace reader.
if trace.enabled || trace.shutdown {
gp := traceReader()
if gp != nil {
casgstatus(gp, _Gwaiting, _Grunnable)
traceGoUnpark(gp, 0)
return gp, false, true
}
}
// Try to schedule a GC worker.
if gcBlackenEnabled != 0 {
gp, tnow := gcController.findRunnableGCWorker(pp, now)
if gp != nil {
return gp, false, true
}
now = tnow
}
// Check the global runnable queue once in a while to ensure fairness.
// Otherwise two goroutines can completely occupy the local runqueue
// by constantly respawning each other.
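// 61 is prime, so this check is unlikely to fall into lockstep with
// power-of-two-sized patterns in application or runtime behavior.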
if pp.schedtick%61 == 0 && sched.runqsize > 0 {
lock(&sched.lock)
gp := globrunqget(pp, 1)
unlock(&sched.lock)
if gp != nil {
return gp, false, false
}
}
// Wake up the finalizer G.
if fingStatus.Load()&(fingWait|fingWake) == fingWait|fingWake {
if gp := wakefing(); gp != nil {
ready(gp, 0, true)
}
}
if *cgo_yield != nil {
asmcgocall(*cgo_yield, nil)
}
// local runq
if gp, inheritTime := runqget(pp); gp != nil {
return gp, inheritTime, false
}
// global runq
if sched.runqsize != 0 {
lock(&sched.lock)
gp := globrunqget(pp, 0)
unlock(&sched.lock)
if gp != nil {
return gp, false, false
}
}
// Poll network.
// This netpoll is only an optimization before we resort to stealing.
// We can safely skip it if there are no waiters or a thread is blocked
// in netpoll already. If there is any kind of logical race with that
// blocked thread (e.g. it has already returned from netpoll, but does
// not set lastpoll yet), this thread will do blocking netpoll below
// anyway.
if netpollinited() && netpollWaiters.Load() > 0 && sched.lastpoll.Load() != 0 {
if list := netpoll(0); !list.empty() { // non-blocking
gp := list.pop()
injectglist(&list)
casgstatus(gp, _Gwaiting, _Grunnable)
if trace.enabled {
traceGoUnpark(gp, 0)
}
return gp, false, false
}
}
// Spinning Ms: steal work from other Ps.
//
// Limit the number of spinning Ms to half the number of busy Ps.
// This is necessary to prevent excessive CPU consumption when
// GOMAXPROCS>>1 but the program parallelism is low.
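// For example, with GOMAXPROCS=8 and 2 idle Ps there are 6 busy Ps,
// so at most 3 Ms spin at any one time.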
if mp.spinning || 2*sched.nmspinning.Load() < gomaxprocs-sched.npidle.Load() {
if !mp.spinning {
mp.becomeSpinning()
}
gp, inheritTime, tnow, w, newWork := stealWork(now)
if gp != nil {
// Successfully stole.
return gp, inheritTime, false
}
if newWork {
// There may be new timer or GC work; restart to
// discover.
goto top
}
now = tnow
if w != 0 && (pollUntil == 0 || w < pollUntil) {
// Earlier timer to wait for.
pollUntil = w
}
}
// We have nothing to do.
//
// If we're in the GC mark phase, can safely scan and blacken objects,
// and have work to do, run idle-time marking rather than give up the P.
if gcBlackenEnabled != 0 && gcMarkWorkAvailable(pp) && gcController.addIdleMarkWorker() {
node := (*gcBgMarkWorkerNode)(gcBgMarkWorkerPool.pop())
if node != nil {
pp.gcMarkWorkerMode = gcMarkWorkerIdleMode
gp := node.gp.ptr()
casgstatus(gp, _Gwaiting, _Grunnable)
if trace.enabled {
traceGoUnpark(gp, 0)
}
return gp, false, false
}
gcController.removeIdleMarkWorker()
}
// wasm only:
// If a callback returned and no other goroutine is awake,
// then wake the event handler goroutine, which pauses execution
// until a callback is triggered.
gp, otherReady := beforeIdle(now, pollUntil)
if gp != nil {
casgstatus(gp, _Gwaiting, _Grunnable)
if trace.enabled {
traceGoUnpark(gp, 0)
}
return gp, false, false
}
if otherReady {
goto top
}
// Before we drop our P, make a snapshot of the allp slice,
// which can change underfoot once we no longer block
// safe-points. We don't need to snapshot the contents because
// everything up to cap(allp) is immutable.
allpSnapshot := allp
// Also snapshot masks. Value changes are OK, but we can't allow
// len to change out from under us.
idlepMaskSnapshot := idlepMask
timerpMaskSnapshot := timerpMask
// return P and block
lock(&sched.lock)
if sched.gcwaiting.Load() || pp.runSafePointFn != 0 {
unlock(&sched.lock)
goto top
}
if sched.runqsize != 0 {
gp := globrunqget(pp, 0)
unlock(&sched.lock)
return gp, false, false
}
if !mp.spinning && sched.needspinning.Load() == 1 {
// See "Delicate dance" comment below.
mp.becomeSpinning()
unlock(&sched.lock)
goto top
}
if releasep() != pp {
throw("findrunnable: wrong p")
}
now = pidleput(pp, now)
unlock(&sched.lock)
// Delicate dance: thread transitions from spinning to non-spinning
// state, potentially concurrently with submission of new work. We must
// drop nmspinning first and then check all sources again (with
// #StoreLoad memory barrier in between). If we do it the other way
// around, another thread can submit work after we've checked all
// sources but before we drop nmspinning; as a result nobody will
// unpark a thread to run the work.
//
// This applies to the following sources of work:
//
// * Goroutines added to a per-P run queue.
// * New/modified-earlier timers on a per-P timer heap.
// * Idle-priority GC work (barring golang.org/issue/19112).
//
// If we discover new work below, we need to restore m.spinning as a
// signal for resetspinning to unpark a new worker thread (because
// there can be more than one starving goroutine).
//
// However, if after discovering new work we also observe no idle Ps
// (either here or in resetspinning), we have a problem. We may be
// racing with a non-spinning M in the block above, having found no
// work and preparing to release its P and park. Allowing that P to go
// idle will result in loss of work conservation (idle P while there is
// runnable work). This could result in complete deadlock in the
// unlikely event that we discover new work (from netpoll) right as we
// are racing with _all_ other Ps going idle.
//
// We use sched.needspinning to synchronize with non-spinning Ms going
// idle. If needspinning is set when they are about to drop their P,
// they abort the drop and instead become a new spinning M on our
// behalf. If we are not racing and the system is truly fully loaded
// then no spinning threads are required, and the next thread to
// naturally become spinning will clear the flag.
//
// Also see "Worker thread parking/unparking" comment at the top of the
// file.
wasSpinning := mp.spinning
if mp.spinning {
mp.spinning = false
if sched.nmspinning.Add(-1) < 0 {
throw("findrunnable: negative nmspinning")
}
// Note that for correctness, only the last M transitioning from
// spinning to non-spinning must perform these rechecks to
// ensure no missed work. However, the runtime has some cases
// of transient increments of nmspinning that are decremented
// without going through this path, so we must be conservative
// and perform the check on all spinning Ms.
//
// See https://go.dev/issue/43997.
// Check all runqueues once again.
pp := checkRunqsNoP(allpSnapshot, idlepMaskSnapshot)
if pp != nil {
acquirep(pp)
mp.becomeSpinning()
goto top
}
// Check for idle-priority GC work again.
pp, gp := checkIdleGCNoP()
if pp != nil {
acquirep(pp)
mp.becomeSpinning()
// Run the idle worker.
pp.gcMarkWorkerMode = gcMarkWorkerIdleMode
casgstatus(gp, _Gwaiting, _Grunnable)
if trace.enabled {
traceGoUnpark(gp, 0)
}
return gp, false, false
}
// Finally, check for timer creation or expiry concurrently with
// transitioning from spinning to non-spinning.
//
// Note that we cannot use checkTimers here because it calls
// adjusttimers which may need to allocate memory, and that isn't
// allowed when we don't have an active P.
pollUntil = checkTimersNoP(allpSnapshot, timerpMaskSnapshot, pollUntil)
}
// Poll network until next timer.
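// sched.lastpoll == 0 means some thread is already blocked in netpoll;
// the Swap both checks that and claims the poller slot atomically, so
// only one thread does the blocking poll at a time.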
if netpollinited() && (netpollWaiters.Load() > 0 || pollUntil != 0) && sched.lastpoll.Swap(0) != 0 {
sched.pollUntil.Store(pollUntil)
if mp.p != 0 {
throw("findrunnable: netpoll with p")
}
if mp.spinning {
throw("findrunnable: netpoll with spinning")
}
// Refresh now.
now = nanotime()
delay := int64(-1)
if pollUntil != 0 {
delay = pollUntil - now
if delay < 0 {
delay = 0
}
}
if faketime != 0 {
// When using fake time, just poll.
delay = 0
}
list := netpoll(delay) // block until new work is available
sched.pollUntil.Store(0)
sched.lastpoll.Store(now)
if faketime != 0 && list.empty() {
// Using fake time and nothing is ready; stop M.
// When all M's stop, checkdead will call timejump.
stopm()
goto top
}
lock(&sched.lock)
pp, _ := pidleget(now)
unlock(&sched.lock)
if pp == nil {
injectglist(&list)
} else {
acquirep(pp)
if !list.empty() {
gp := list.pop()
injectglist(&list)
casgstatus(gp, _Gwaiting, _Grunnable)
if trace.enabled {
traceGoUnpark(gp, 0)
}
return gp, false, false
}
if wasSpinning {
mp.becomeSpinning()
}
goto top
}
} else if pollUntil != 0 && netpollinited() {
pollerPollUntil := sched.pollUntil.Load()
if pollerPollUntil == 0 || pollerPollUntil > pollUntil {
netpollBreak()
}
}
stopm()
goto top
}
// pollWork reports whether there is non-background work this P could
// be doing. This is a fairly lightweight check to be used for
// background work loops, like idle GC. It checks a subset of the
// conditions checked by the actual scheduler.
func pollWork() bool {
if sched.runqsize != 0 {
return true
}
p := getg().m.p.ptr()
if !runqempty(p) {
return true
}
if netpollinited() && netpollWaiters.Load() > 0 && sched.lastpoll.Load() != 0 {
if list := netpoll(0); !list.empty() {
injectglist(&list)
return true
}
}
return false
}
// stealWork attempts to steal a runnable goroutine or timer from any P.
//
// If newWork is true, new work may have been readied.
//
// If now is not 0 it is the current time. stealWork returns the passed time or
// the current time if now was passed as 0.
func stealWork(now int64) (gp *g, inheritTime bool, rnow, pollUntil int64, newWork bool) {
pp := getg().m.p.ptr()
ranTimer := false
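// Make up to stealTries passes over all Ps. Only the final pass also
// steals timers and runnext entries (stealTimersOrRunNextG below),
// since those are the most intrusive things to steal; see the
// comments further down.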
const stealTries = 4
for i := 0; i < stealTries; i++ {
stealTimersOrRunNextG := i == stealTries-1
for enum := stealOrder.start(fastrand()); !enum.done(); enum.next() {
if sched.gcwaiting.Load() {
// GC work may be available.
return nil, false, now, pollUntil, true
}
p2 := allp[enum.position()]
if pp == p2 {
continue
}
// Steal timers from p2. This call to checkTimers is the only place
// where we might hold a lock on a different P's timers. We do this
// once on the last pass before checking runnext because stealing
// from the other P's runnext should be the last resort, so if there
// are timers to steal do that first.
//
// We only check timers on one of the stealing iterations because
// the time stored in now doesn't change in this loop and checking
// the timers for each P more than once with the same value of now
// is probably a waste of time.
//
// timerpMask tells us whether the P may have timers at all. If it
// can't, no need to check at all.
if stealTimersOrRunNextG && timerpMask.read(enum.position()) {
tnow, w, ran := checkTimers(p2, now)
now = tnow
if w != 0 && (pollUntil == 0 || w < pollUntil) {
pollUntil = w
}
if ran {
// Running the timers may have
// made an arbitrary number of G's
// ready and added them to this P's
// local run queue. That invalidates
// the assumption of runqsteal
// that it always has room to add
// stolen G's. So check now if there
// is a local G to run.
if gp, inheritTime := runqget(pp); gp != nil {
return gp, inheritTime, now, pollUntil, ranTimer
}
ranTimer = true
}
}
// Don't bother to attempt to steal if p2 is idle.
if !idlepMask.read(enum.position()) {
if gp := runqsteal(pp, p2, stealTimersOrRunNextG); gp != nil {
return gp, false, now, pollUntil, ranTimer
}
}
}
}
// No goroutines found to steal. Regardless, running a timer may have
// made some goroutine ready that we missed. Indicate the next timer to
// wait for.
return nil, false, now, pollUntil, ranTimer
}
// Check all Ps for a runnable G to steal.
//
// On entry we have no P. If a G is available to steal and a P is available,
// the P is returned; the caller should acquire it and attempt to steal
// the work.
func checkRunqsNoP(allpSnapshot []*p, idlepMaskSnapshot pMask) *p {
for id, p2 := range allpSnapshot {
if !idlepMaskSnapshot.read(uint32(id)) && !runqempty(p2) {
lock(&sched.lock)
pp, _ := pidlegetSpinning(0)
if pp == nil {
// Can't get a P, don't bother checking remaining Ps.
unlock(&sched.lock)
return nil
}
unlock(&sched.lock)
return pp
}
}
// No work available.
return nil
}
// Check all Ps for a timer expiring sooner than pollUntil.
//
// Returns updated pollUntil value.
func checkTimersNoP(allpSnapshot []*p, timerpMaskSnapshot pMask, pollUntil int64) int64 {
for id, p2 := range allpSnapshot {
if timerpMaskSnapshot.read(uint32(id)) {
w := nobarrierWakeTime(p2)
if w != 0 && (pollUntil == 0 || w < pollUntil) {
pollUntil = w
}
}
}
return pollUntil
}
// Check for idle-priority GC, without a P on entry.
//
// If some GC work, a P, and a worker G are all available, the P and G will be
// returned. The returned P has not been wired yet.
func checkIdleGCNoP() (*p, *g) {
// N.B. Since we have no P, gcBlackenEnabled may change at any time; we
// must check again after acquiring a P. As an optimization, we also check
// if an idle mark worker is needed at all. This is OK here, because if we
// observe that one isn't needed, at least one is currently running. Even if
// it stops running, its own journey into the scheduler should schedule it
// again, if need be (at which point, this check will pass, if relevant).
if atomic.Load(&gcBlackenEnabled) == 0 || !gcController.needIdleMarkWorker() {
return nil, nil
}
if !gcMarkWorkAvailable(nil) {
return nil, nil
}
// Work is available; we can start an idle GC worker only if there is
// an available P and available worker G.
//
// We can attempt to acquire these in either order, though both have
// synchronization concerns (see below). Workers are almost always
// available (see comment in findRunnableGCWorker for the one case
// there may be none). Since we're slightly less likely to find a P,
// check for that first.
//
// Synchronization: note that we must hold sched.lock until we are
// committed to keeping it. Otherwise we cannot put the unnecessary P
// back in sched.pidle without performing the full set of idle
// transition checks.
//
// If we were to check gcBgMarkWorkerPool first, we must somehow handle
// the assumption in gcControllerState.findRunnableGCWorker that an
// empty gcBgMarkWorkerPool is only possible if gcMarkDone is running.
lock(&sched.lock)
pp, now := pidlegetSpinning(0)
if pp == nil {
unlock(&sched.lock)
return nil, nil
}
// Now that we own a P, gcBlackenEnabled can't change (as it requires STW).
if gcBlackenEnabled == 0 || !gcController.addIdleMarkWorker() {
pidleput(pp, now)
unlock(&sched.lock)
return nil, nil
}
node := (*gcBgMarkWorkerNode)(gcBgMarkWorkerPool.pop())
if node == nil {
pidleput(pp, now)
unlock(&sched.lock)
gcController.removeIdleMarkWorker()
return nil, nil
}
unlock(&sched.lock)
return pp, node.gp.ptr()
}
// wakeNetPoller wakes up the thread sleeping in the network poller if it isn't
// going to wake up before the when argument; or it wakes an idle P to service
// timers and the network poller if there isn't one already.
func wakeNetPoller(when int64) {
if sched.lastpoll.Load() == 0 {
// In findRunnable we ensure that when polling, the pollUntil
// field is either zero or the time to which the current
// poll is expected to run. This can have a spurious wakeup
// but should never miss a wakeup.
pollerPollUntil := sched.pollUntil.Load()
if pollerPollUntil == 0 || pollerPollUntil > when {
netpollBreak()
}
} else {
// There are no threads in the network poller, try to get
// one there so it can handle new timers.
if GOOS != "plan9" { // Temporary workaround - see issue #42303.
wakep()
}
}
}
func resetspinning() {
gp := getg()
if !gp.m.spinning {
throw("resetspinning: not a spinning m")
}
gp.m.spinning = false
nmspinning := sched.nmspinning.Add(-1)
if nmspinning < 0 {
throw("findrunnable: negative nmspinning")
}
// M wakeup policy is deliberately somewhat conservative, so check if we
// need to wakeup another P here. See "Worker thread parking/unparking"
// comment at the top of the file for details.
wakep()
}
// injectglist adds each runnable G on the list to some run queue,
// and clears glist. If there is no current P, they are added to the
// global queue, and up to npidle M's are started to run them.
// Otherwise, for each idle P, this adds a G to the global queue
// and starts an M. Any remaining G's are added to the current P's
// local run queue.
// This may temporarily acquire sched.lock.
// Can run concurrently with GC.
func injectglist(glist *gList) {
if glist.empty() {
return
}
if trace.enabled {
for gp := glist.head.ptr(); gp != nil; gp = gp.schedlink.ptr() {
traceGoUnpark(gp, 0)
}
}
// Mark all the goroutines as runnable before we put them
// on the run queues.
head := glist.head.ptr()
var tail *g
qsize := 0
for gp := head; gp != nil; gp = gp.schedlink.ptr() {
tail = gp
qsize++
casgstatus(gp, _Gwaiting, _Grunnable)
}
// Turn the gList into a gQueue.
var q gQueue
q.head.set(head)
q.tail.set(tail)
*glist = gList{}
startIdle := func(n int) {
for i := 0; i < n; i++ {
mp := acquirem() // See comment in startm.
lock(&sched.lock)
pp, _ := pidlegetSpinning(0)
if pp == nil {
unlock(&sched.lock)
releasem(mp)
break
}
unlock(&sched.lock)
startm(pp, false)
releasem(mp)
}
}
pp := getg().m.p.ptr()
if pp == nil {
lock(&sched.lock)
globrunqputbatch(&q, int32(qsize))
unlock(&sched.lock)
startIdle(qsize)
return
}
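// With a local P: hand one G to the global queue per idle P (waking
// an M for each), and keep whatever remains on our own run queue.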
npidle := int(sched.npidle.Load())
var globq gQueue
var n int
for n = 0; n < npidle && !q.empty(); n++ {
g := q.pop()
globq.pushBack(g)
}
if n > 0 {
lock(&sched.lock)
globrunqputbatch(&globq, int32(n))
unlock(&sched.lock)
startIdle(n)
qsize -= n
}
if !q.empty() {
runqputbatch(pp, &q, qsize)
}
}
// One round of scheduler: find a runnable goroutine and execute it.
// Never returns.
func schedule() {
mp := getg().m
if mp.locks != 0 {
throw("schedule: holding locks")
}
if mp.lockedg != 0 {
stoplockedm()
execute(mp.lockedg.ptr(), false) // Never returns.
}
// We should not schedule away from a g that is executing a cgo call,
// since the cgo call is using the m's g0 stack.
if mp.incgo {
throw("schedule: in cgo")
}
top:
pp := mp.p.ptr()
pp.preempt = false
// Safety check: if we are spinning, the run queue should be empty.
// Check this before calling checkTimers, as that might call
// goready to put a ready goroutine on the local run queue.
if mp.spinning && (pp.runnext != 0 || pp.runqhead != pp.runqtail) {
throw("schedule: spinning with local work")
}
gp, inheritTime, tryWakeP := findRunnable() // blocks until work is available
// This thread is going to run a goroutine and is not spinning anymore,
// so if it was marked as spinning we need to reset it now and potentially
// start a new spinning M.
if mp.spinning {
resetspinning()
}
if sched.disable.user && !schedEnabled(gp) {
// Scheduling of this goroutine is disabled. Put it on
// the list of pending runnable goroutines for when we
// re-enable user scheduling and look again.
lock(&sched.lock)
if schedEnabled(gp) {
// Something re-enabled scheduling while we
// were acquiring the lock.
unlock(&sched.lock)
} else {
sched.disable.runnable.pushBack(gp)
sched.disable.n++
unlock(&sched.lock)
goto top
}
}
// If about to schedule a not-normal goroutine (a GCworker or tracereader),
// wake a P if there is one.
if tryWakeP {
wakep()
}
if gp.lockedm != 0 {
// Hands off own p to the locked m,
// then blocks waiting for a new p.
startlockedm(gp)
goto top
}
execute(gp, inheritTime)
}
// dropg removes the association between m and the current goroutine m->curg (gp for short).
// Typically a caller sets gp's status away from Grunning and then
// immediately calls dropg to finish the job. The caller is also responsible
// for arranging that gp will be restarted using ready at an
// appropriate time. After calling dropg and arranging for gp to be
// readied later, the caller can do other work but eventually should
// call schedule to restart the scheduling of goroutines on this m.
func dropg() {
gp := getg()
setMNoWB(&gp.m.curg.m, nil)
setGNoWB(&gp.m.curg, nil)
}
// checkTimers runs any timers for the P that are ready.
// If now is not 0 it is the current time.
// It returns the passed time or the current time if now was passed as 0,
// the time when the next timer should run or 0 if there is no next timer,
// and reports whether it ran any timers.
// If the time when the next timer should run is not 0,
// it is always larger than the returned time.
// We pass now in and out to avoid extra calls of nanotime.
//
//go:yeswritebarrierrec
func checkTimers(pp *p, now int64) (rnow, pollUntil int64, ran bool) {
// If it's not yet time for the first timer, or the first adjusted
// timer, then there is nothing to do.
next := pp.timer0When.Load()
nextAdj := pp.timerModifiedEarliest.Load()
if next == 0 || (nextAdj != 0 && nextAdj < next) {
next = nextAdj
}
if next == 0 {
// No timers to run or adjust.
return now, 0, false
}
if now == 0 {
now = nanotime()
}
if now < next {
// Next timer is not ready to run, but keep going
// if we would clear deleted timers.
// This corresponds to the condition below where
// we decide whether to call clearDeletedTimers.
if pp != getg().m.p.ptr() || int(pp.deletedTimers.Load()) <= int(pp.numTimers.Load()/4) {
return now, next, false
}
}
lock(&pp.timersLock)
if len(pp.timers) > 0 {
adjusttimers(pp, now)
for len(pp.timers) > 0 {
// Note that runtimer may temporarily unlock
// pp.timersLock.
if tw := runtimer(pp, now); tw != 0 {
if tw > 0 {
pollUntil = tw
}
break
}
ran = true
}
}
// If this is the local P, and there are a lot of deleted timers,
// clear them out. We only do this for the local P to reduce
// lock contention on timersLock.
if pp == getg().m.p.ptr() && int(pp.deletedTimers.Load()) > len(pp.timers)/4 {
clearDeletedTimers(pp)
}
unlock(&pp.timersLock)
return now, pollUntil, ran
}
func parkunlock_c(gp *g, lock unsafe.Pointer) bool {
unlock((*mutex)(lock))
return true
}
// park continuation on g0.
func park_m(gp *g) {
mp := getg().m
if trace.enabled {
traceGoPark(mp.waittraceev, mp.waittraceskip)
}
// N.B. Not using casGToWaiting here because the waitreason is
// set by park_m's caller.
casgstatus(gp, _Grunning, _Gwaiting)
dropg()
if fn := mp.waitunlockf; fn != nil {
ok := fn(gp, mp.waitlock)
mp.waitunlockf = nil
mp.waitlock = nil
if !ok {
if trace.enabled {
traceGoUnpark(gp, 2)
}
casgstatus(gp, _Gwaiting, _Grunnable)
execute(gp, true) // Schedule it back, never returns.
}
}
schedule()
}
func goschedImpl(gp *g) {
status := readgstatus(gp)
if status&^_Gscan != _Grunning {
dumpgstatus(gp)
throw("bad g status")
}
casgstatus(gp, _Grunning, _Grunnable)
dropg()
lock(&sched.lock)
globrunqput(gp)
unlock(&sched.lock)
schedule()
}
// Gosched continuation on g0.
func gosched_m(gp *g) {
if trace.enabled {
traceGoSched()
}
goschedImpl(gp)
}
// goschedguarded is a forbidden-states-avoided version of gosched_m.
func goschedguarded_m(gp *g) {
if !canPreemptM(gp.m) {
gogo(&gp.sched) // never return
}
if trace.enabled {
traceGoSched()
}
goschedImpl(gp)
}
func gopreempt_m(gp *g) {
if trace.enabled {
traceGoPreempt()
}
goschedImpl(gp)
}
// preemptPark parks gp and puts it in _Gpreempted.
//
//go:systemstack
func preemptPark(gp *g) {
if trace.enabled {
traceGoPark(traceEvGoBlock, 0)
}
status := readgstatus(gp)
if status&^_Gscan != _Grunning {
dumpgstatus(gp)
throw("bad g status")
}
if gp.asyncSafePoint {
// Double-check that async preemption does not
// happen in SPWRITE assembly functions.
// isAsyncSafePoint must exclude this case.
f := findfunc(gp.sched.pc)
if !f.valid() {
throw("preempt at unknown pc")
}
if f.flag&funcFlag_SPWRITE != 0 {
println("runtime: unexpected SPWRITE function", funcname(f), "in async preempt")
throw("preempt SPWRITE")
}
}
// Transition from _Grunning to _Gscan|_Gpreempted. We can't
// be in _Grunning when we dropg because then we'd be running
// without an M, but the moment we're in _Gpreempted,
// something could claim this G before we've fully cleaned it
// up. Hence, we set the scan bit to lock down further
// transitions until we can dropg.
casGToPreemptScan(gp, _Grunning, _Gscan|_Gpreempted)
dropg()
casfrom_Gscanstatus(gp, _Gscan|_Gpreempted, _Gpreempted)
schedule()
}
// goyield is like Gosched, but it:
// - emits a GoPreempt trace event instead of a GoSched trace event
// - puts the current G on the runq of the current P instead of the globrunq
func goyield() {
checkTimeouts()
mcall(goyield_m)
}
func goyield_m(gp *g) {
if trace.enabled {
traceGoPreempt()
}
pp := gp.m.p.ptr()
casgstatus(gp, _Grunning, _Grunnable)
dropg()
runqput(pp, gp, false)
schedule()
}
// Finishes execution of the current goroutine.
func goexit1() {
if raceenabled {
racegoend()
}
if trace.enabled {
traceGoEnd()
}
mcall(goexit0)
}
// goexit continuation on g0.
func goexit0(gp *g) {
mp := getg().m
pp := mp.p.ptr()
casgstatus(gp, _Grunning, _Gdead)
gcController.addScannableStack(pp, -int64(gp.stack.hi-gp.stack.lo))
if isSystemGoroutine(gp, false) {
sched.ngsys.Add(-1)
}
gp.m = nil
locked := gp.lockedm != 0
gp.lockedm = 0
mp.lockedg = 0
gp.preemptStop = false
gp.paniconfault = false
gp._defer = nil // should be nil already, but just in case.
gp._panic = nil // non-nil for Goexit during panic. points at stack-allocated data.
gp.writebuf = nil
gp.waitreason = waitReasonZero
gp.param = nil
gp.labels = nil
gp.timer = nil
if gcBlackenEnabled != 0 && gp.gcAssistBytes > 0 {
// Flush assist credit to the global pool. This gives
// better information to pacing if the application is
// rapidly creating and exiting goroutines.
assistWorkPerByte := gcController.assistWorkPerByte.Load()
scanCredit := int64(assistWorkPerByte * float64(gp.gcAssistBytes))
gcController.bgScanCredit.Add(scanCredit)
gp.gcAssistBytes = 0
}
dropg()
if GOARCH == "wasm" { // no threads yet on wasm
gfput(pp, gp)
schedule() // never returns
}
if mp.lockedInt != 0 {
print("invalid m->lockedInt = ", mp.lockedInt, "\n")
throw("internal lockOSThread error")
}
gfput(pp, gp)
if locked {
// The goroutine may have locked this thread because
// it put it in an unusual kernel state. Kill it
// rather than returning it to the thread pool.
// Return to mstart, which will release the P and exit
// the thread.
if GOOS != "plan9" { // See golang.org/issue/22227.
gogo(&mp.g0.sched)
} else {
// Clear lockedExt on plan9 since we may end up re-using
// this thread.
mp.lockedExt = 0
}
}
schedule()
}
// save updates getg().sched to refer to pc and sp so that a following
// gogo will restore pc and sp.
//
// save must not have write barriers because invoking a write barrier
// can clobber getg().sched.
//
//go:nosplit
//go:nowritebarrierrec
func save(pc, sp uintptr) {
gp := getg()
if gp == gp.m.g0 || gp == gp.m.gsignal {
// m.g0.sched is special and must describe the context
// for exiting the thread. mstart1 writes to it directly.
// m.gsignal.sched should not be used at all.
// This check makes sure save calls do not accidentally
// run in contexts where they'd write to system g's.
throw("save on system g not allowed")
}
gp.sched.pc = pc
gp.sched.sp = sp
gp.sched.lr = 0
gp.sched.ret = 0
// We need to ensure ctxt is zero, but can't have a write
// barrier here. However, it should always already be zero.
// Assert that.
if gp.sched.ctxt != nil {
badctxt()
}
}
// The goroutine g is about to enter a system call.
// Record that it's not using the cpu anymore.
// This is called only from the go syscall library and cgocall,
// not from the low-level system calls used by the runtime.
//
// Entersyscall cannot split the stack: the save must
// make g->sched refer to the caller's stack segment, because
// entersyscall is going to return immediately after.
//
// Nothing entersyscall calls can split the stack either.
// We cannot safely move the stack during an active call to syscall,
// because we do not know which of the uintptr arguments are
// really pointers (back into the stack).
// In practice, this means that we make the fast path run through
// entersyscall doing no-split things, and the slow path has to use systemstack
// to run bigger things on the system stack.
//
// reentersyscall is the entry point used by cgo callbacks, where explicitly
// saved SP and PC are restored. This is needed when exitsyscall will be called
// from a function further up in the call stack than the parent, as g->syscallsp
// must always point to a valid stack frame. entersyscall below is the normal
// entry point for syscalls, which obtains the SP and PC from the caller.
//
// Syscall tracing:
// At the start of a syscall we emit traceGoSysCall to capture the stack trace.
// If the syscall does not block, that is all; we do not emit any other events.
// If the syscall blocks (that is, P is retaken), retaker emits traceGoSysBlock;
// when syscall returns we emit traceGoSysExit and when the goroutine starts running
// (potentially instantly, if exitsyscallfast returns true) we emit traceGoStart.
// To ensure that traceGoSysExit is emitted strictly after traceGoSysBlock,
// we remember the current value of syscalltick in m (gp.m.syscalltick = gp.m.p.ptr().syscalltick);
// whoever emits traceGoSysBlock increments p.syscalltick afterwards;
// and we wait for the increment before emitting traceGoSysExit.
// Note that the increment is done even if tracing is not enabled,
// because tracing can be enabled in the middle of a syscall. We don't want the wait to hang.
//
//go:nosplit
func reentersyscall(pc, sp uintptr) {
gp := getg()
// Disable preemption because during this function g is in Gsyscall status,
// but can have an inconsistent g->sched; do not let the GC observe it.
gp.m.locks++
// Entersyscall must not call any function that might split/grow the stack.
// (See details in comment above.)
// Catch calls that might, by replacing the stack guard with something that
// will trip any stack check and leaving a flag to tell newstack to die.
gp.stackguard0 = stackPreempt
gp.throwsplit = true
// Leave SP around for GC and traceback.
save(pc, sp)
gp.syscallsp = sp
gp.syscallpc = pc
casgstatus(gp, _Grunning, _Gsyscall)
if gp.syscallsp < gp.stack.lo || gp.stack.hi < gp.syscallsp {
systemstack(func() {
print("entersyscall inconsistent ", hex(gp.syscallsp), " [", hex(gp.stack.lo), ",", hex(gp.stack.hi), "]\n")
throw("entersyscall")
})
}
if trace.enabled {
systemstack(traceGoSysCall)
// systemstack itself clobbers g.sched.{pc,sp} and we might
// need them later when the G is genuinely blocked in a
// syscall
save(pc, sp)
}
if sched.sysmonwait.Load() {
systemstack(entersyscall_sysmon)
save(pc, sp)
}
if gp.m.p.ptr().runSafePointFn != 0 {
// runSafePointFn may stack split if run on this stack
systemstack(runSafePointFn)
save(pc, sp)
}
gp.m.syscalltick = gp.m.p.ptr().syscalltick
gp.sysblocktraced = true
pp := gp.m.p.ptr()
pp.m = 0
gp.m.oldp.set(pp)
gp.m.p = 0
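// Set the status to _Psyscall only after the P is fully detached from
// this M, so retake and entersyscall_gcwait observe a consistent P
// when they consider taking it.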
atomic.Store(&pp.status, _Psyscall)
if sched.gcwaiting.Load() {
systemstack(entersyscall_gcwait)
save(pc, sp)
}
gp.m.locks--
}
// Standard syscall entry used by the go syscall library and normal cgo calls.
//
// This is exported via linkname to assembly in the syscall package and x/sys.
//
//go:nosplit
//go:linkname entersyscall
func entersyscall() {
reentersyscall(getcallerpc(), getcallersp())
}
func entersyscall_sysmon() {
lock(&sched.lock)
if sched.sysmonwait.Load() {
sched.sysmonwait.Store(false)
notewakeup(&sched.sysmonnote)
}
unlock(&sched.lock)
}
func entersyscall_gcwait() {
gp := getg()
pp := gp.m.oldp.ptr()
lock(&sched.lock)
if sched.stopwait > 0 && atomic.Cas(&pp.status, _Psyscall, _Pgcstop) {
if trace.enabled {
traceGoSysBlock(pp)
traceProcStop(pp)
}
pp.syscalltick++
if sched.stopwait--; sched.stopwait == 0 {
notewakeup(&sched.stopnote)
}
}
unlock(&sched.lock)
}
// The same as entersyscall(), but with a hint that the syscall is blocking.
//
//go:nosplit
func entersyscallblock() {
gp := getg()
gp.m.locks++ // see comment in entersyscall
gp.throwsplit = true
gp.stackguard0 = stackPreempt // see comment in entersyscall
gp.m.syscalltick = gp.m.p.ptr().syscalltick
gp.sysblocktraced = true
gp.m.p.ptr().syscalltick++
// Leave SP around for GC and traceback.
pc := getcallerpc()
sp := getcallersp()
save(pc, sp)
gp.syscallsp = gp.sched.sp
gp.syscallpc = gp.sched.pc
if gp.syscallsp < gp.stack.lo || gp.stack.hi < gp.syscallsp {
sp1 := sp
sp2 := gp.sched.sp
sp3 := gp.syscallsp
systemstack(func() {
print("entersyscallblock inconsistent ", hex(sp1), " ", hex(sp2), " ", hex(sp3), " [", hex(gp.stack.lo), ",", hex(gp.stack.hi), "]\n")
throw("entersyscallblock")
})
}
casgstatus(gp, _Grunning, _Gsyscall)
if gp.syscallsp < gp.stack.lo || gp.stack.hi < gp.syscallsp {
systemstack(func() {
print("entersyscallblock inconsistent ", hex(sp), " ", hex(gp.sched.sp), " ", hex(gp.syscallsp), " [", hex(gp.stack.lo), ",", hex(gp.stack.hi), "]\n")
throw("entersyscallblock")
})
}
systemstack(entersyscallblock_handoff)
// Resave for traceback during blocked call.
save(getcallerpc(), getcallersp())
gp.m.locks--
}
func entersyscallblock_handoff() {
if trace.enabled {
traceGoSysCall()
traceGoSysBlock(getg().m.p.ptr())
}
handoffp(releasep())
}
// The goroutine g exited its system call.
// Arrange for it to run on a cpu again.
// This is called only from the go syscall library, not
// from the low-level system calls used by the runtime.
//
// Write barriers are not allowed because our P may have been stolen.
//
// This is exported via linkname to assembly in the syscall package.
//
//go:nosplit
//go:nowritebarrierrec
//go:linkname exitsyscall
func exitsyscall() {
gp := getg()
gp.m.locks++ // see comment in entersyscall
if getcallersp() > gp.syscallsp {
throw("exitsyscall: syscall frame is no longer valid")
}
gp.waitsince = 0
oldp := gp.m.oldp.ptr()
gp.m.oldp = 0
if exitsyscallfast(oldp) {
// When exitsyscallfast returns success, we have a P so can now use
// write barriers
if goroutineProfile.active {
// Make sure that gp has had its stack written out to the goroutine
// profile, exactly as it was when the goroutine profiler first
// stopped the world.
systemstack(func() {
tryRecordGoroutineProfileWB(gp)
})
}
if trace.enabled {
if oldp != gp.m.p.ptr() || gp.m.syscalltick != gp.m.p.ptr().syscalltick {
systemstack(traceGoStart)
}
}
// There's a cpu for us, so we can run.
gp.m.p.ptr().syscalltick++
// We need to cas the status and scan before resuming...
casgstatus(gp, _Gsyscall, _Grunning)
// Garbage collector isn't running (since we are),
// so okay to clear syscallsp.
gp.syscallsp = 0
gp.m.locks--
if gp.preempt {
// restore the preemption request in case we've cleared it in newstack
gp.stackguard0 = stackPreempt
} else {
// otherwise restore the real _StackGuard; we spoiled it in entersyscall/entersyscallblock
gp.stackguard0 = gp.stack.lo + _StackGuard
}
gp.throwsplit = false
if sched.disable.user && !schedEnabled(gp) {
// Scheduling of this goroutine is disabled.
Gosched()
}
return
}
gp.sysexitticks = 0
if trace.enabled {
// Wait till traceGoSysBlock event is emitted.
// This ensures consistency of the trace (the goroutine is started after it is blocked).
for oldp != nil && oldp.syscalltick == gp.m.syscalltick {
osyield()
}
// We can't trace syscall exit right now because we don't have a P.
// Tracing code can invoke write barriers that cannot run without a P.
// So instead we remember the syscall exit time and emit the event
// in execute when we have a P.
gp.sysexitticks = cputicks()
}
gp.m.locks--
// Call the scheduler.
mcall(exitsyscall0)
// Scheduler returned, so we're allowed to run now.
// Delete the syscallsp information that we left for
// the garbage collector during the system call.
// Must wait until now because until gosched returns
// we don't know for sure that the garbage collector
// is not running.
gp.syscallsp = 0
gp.m.p.ptr().syscalltick++
gp.throwsplit = false
}
//go:nosplit
func exitsyscallfast(oldp *p) bool {
gp := getg()
// Freezetheworld sets stopwait but does not retake P's.
if sched.stopwait == freezeStopWait {
return false
}
// Try to re-acquire the last P.
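// This CAS races with retake in sysmon, which performs the same
// _Psyscall -> _Pidle transition; exactly one side wins ownership
// of the P.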
if oldp != nil && oldp.status == _Psyscall && atomic.Cas(&oldp.status, _Psyscall, _Pidle) {
// There's a cpu for us, so we can run.
wirep(oldp)
exitsyscallfast_reacquired()
return true
}
// Try to get any other idle P.
if sched.pidle != 0 {
var ok bool
systemstack(func() {
ok = exitsyscallfast_pidle()
if ok && trace.enabled {
if oldp != nil {
// Wait till traceGoSysBlock event is emitted.
// This ensures consistency of the trace (the goroutine is started after it is blocked).
for oldp.syscalltick == gp.m.syscalltick {
osyield()
}
}
traceGoSysExit(0)
}
})
if ok {
return true
}
}
return false
}
// exitsyscallfast_reacquired is the exitsyscall path on which this G
// has successfully reacquired the P it was running on before the
// syscall.
//
//go:nosplit
func exitsyscallfast_reacquired() {
gp := getg()
if gp.m.syscalltick != gp.m.p.ptr().syscalltick {
if trace.enabled {
// The p was retaken and then entered a syscall again (since gp.m.syscalltick has changed).
// traceGoSysBlock for this syscall was already emitted,
// but here we effectively retake the p from the new syscall running on the same p.
systemstack(func() {
// Denote blocking of the new syscall.
traceGoSysBlock(gp.m.p.ptr())
// Denote completion of the current syscall.
traceGoSysExit(0)
})
}
gp.m.p.ptr().syscalltick++
}
}
func exitsyscallfast_pidle() bool {
lock(&sched.lock)
pp, _ := pidleget(0)
if pp != nil && sched.sysmonwait.Load() {
sched.sysmonwait.Store(false)
notewakeup(&sched.sysmonnote)
}
unlock(&sched.lock)
if pp != nil {
acquirep(pp)
return true
}
return false
}
// exitsyscall slow path on g0.
// Failed to acquire P, enqueue gp as runnable.
//
// Called via mcall, so gp is the calling g from this M.
//
//go:nowritebarrierrec
func exitsyscall0(gp *g) {
casgstatus(gp, _Gsyscall, _Grunnable)
dropg()
lock(&sched.lock)
var pp *p
if schedEnabled(gp) {
pp, _ = pidleget(0)
}
var locked bool
if pp == nil {
globrunqput(gp)
// Below, we stoplockedm if gp is locked. globrunqput releases
// ownership of gp, so we must check if gp is locked prior to
// committing the release by unlocking sched.lock, otherwise we
// could race with another M transitioning gp from unlocked to
// locked.
locked = gp.lockedm != 0
} else if sched.sysmonwait.Load() {
sched.sysmonwait.Store(false)
notewakeup(&sched.sysmonnote)
}
unlock(&sched.lock)
if pp != nil {
acquirep(pp)
execute(gp, false) // Never returns.
}
if locked {
// Wait until another thread schedules gp and so m again.
//
// N.B. lockedm must be this M, as this g was running on this M
// before entersyscall.
stoplockedm()
execute(gp, false) // Never returns.
}
stopm()
schedule() // Never returns.
}
// Called from syscall package before fork.
//
//go:linkname syscall_runtime_BeforeFork syscall.runtime_BeforeFork
//go:nosplit
func syscall_runtime_BeforeFork() {
gp := getg().m.curg
// Block signals during a fork, so that the child does not run
// a signal handler before exec if a signal is sent to the process
// group. See issue #18600.
gp.m.locks++
sigsave(&gp.m.sigmask)
sigblock(false)
// This function is called before fork in syscall package.
// Code between fork and exec must not allocate memory, nor even try to grow the stack.
// Here we spoil g->_StackGuard to reliably detect any attempts to grow the stack.
// runtime_AfterFork will undo this in the parent process, but not in the child.
gp.stackguard0 = stackFork
}
// Called from syscall package after fork in parent.
//
//go:linkname syscall_runtime_AfterFork syscall.runtime_AfterFork
//go:nosplit
func syscall_runtime_AfterFork() {
gp := getg().m.curg
// See the comments in beforefork.
gp.stackguard0 = gp.stack.lo + _StackGuard
msigrestore(gp.m.sigmask)
gp.m.locks--
}
// inForkedChild is true while manipulating signals in the child process.
// This is used to avoid calling libc functions in case we are using vfork.
var inForkedChild bool
// Called from syscall package after fork in child.
// It resets non-sigignored signals to the default handler, and
// restores the signal mask in preparation for the exec.
//
// Because this might be called during a vfork, and therefore may be
// temporarily sharing address space with the parent process, this must
// not change any global variables or call into C code that may do so.
//
//go:linkname syscall_runtime_AfterForkInChild syscall.runtime_AfterForkInChild
//go:nosplit
//go:nowritebarrierrec
func syscall_runtime_AfterForkInChild() {
// It's OK to change the global variable inForkedChild here
// because we are going to change it back. There is no race here,
// because if we are sharing address space with the parent process,
// then the parent process can not be running concurrently.
inForkedChild = true
clearSignalHandlers()
// When we are the child we are the only thread running,
// so we know that nothing else has changed gp.m.sigmask.
msigrestore(getg().m.sigmask)
inForkedChild = false
}
// pendingPreemptSignals is the number of preemption signals
// that have been sent but not received. This is only used on Darwin.
// For #41702.
var pendingPreemptSignals atomic.Int32
// Called from syscall package before Exec.
//
//go:linkname syscall_runtime_BeforeExec syscall.runtime_BeforeExec
func syscall_runtime_BeforeExec() {
// Prevent thread creation during exec.
execLock.lock()
// On Darwin, wait for all pending preemption signals to
// be received. See issue #41702.
if GOOS == "darwin" || GOOS == "ios" {
for pendingPreemptSignals.Load() > 0 {
osyield()
}
}
}
// Called from syscall package after Exec.
//
//go:linkname syscall_runtime_AfterExec syscall.runtime_AfterExec
func syscall_runtime_AfterExec() {
execLock.unlock()
}
// Allocate a new g, with a stack big enough for stacksize bytes.
func malg(stacksize int32) *g {
newg := new(g)
if stacksize >= 0 {
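// Round up to a power of two, since the stack allocator hands out
// stacks only in power-of-two size classes.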
stacksize = round2(_StackSystem + stacksize)
systemstack(func() {
newg.stack = stackalloc(uint32(stacksize))
})
newg.stackguard0 = newg.stack.lo + _StackGuard
newg.stackguard1 = ^uintptr(0)
// Clear the bottom word of the stack. We record g
// there on the gsignal stack during VDSO calls on ARM and ARM64.
*(*uintptr)(unsafe.Pointer(newg.stack.lo)) = 0
}
return newg
}
// Create a new g running fn.
// Put it on the queue of g's waiting to run.
// The compiler turns a go statement into a call to this.
func newproc(fn *funcval) {
gp := getg()
pc := getcallerpc()
systemstack(func() {
newg := newproc1(fn, gp, pc)
pp := getg().m.p.ptr()
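// next=true places newg in the P's runnext slot, so it typically runs
// as soon as the current goroutine yields the P.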
runqput(pp, newg, true)
if mainStarted {
wakep()
}
})
}
// Create a new g in state _Grunnable, starting at fn. callerpc is the
// address of the go statement that created this. The caller is responsible
// for adding the new g to the scheduler.
func newproc1(fn *funcval, callergp *g, callerpc uintptr) *g {
if fn == nil {
fatal("go of nil func value")
}
mp := acquirem() // disable preemption because we hold M and P in local vars.
pp := mp.p.ptr()
newg := gfget(pp)
if newg == nil {
newg = malg(_StackMin)
casgstatus(newg, _Gidle, _Gdead)
allgadd(newg) // publishes with a g->status of Gdead so GC scanner doesn't look at uninitialized stack.
}
if newg.stack.hi == 0 {
throw("newproc1: newg missing stack")
}
if readgstatus(newg) != _Gdead {
throw("newproc1: new g is not Gdead")
}
totalSize := uintptr(4*goarch.PtrSize + sys.MinFrameSize) // extra space in case of reads slightly beyond frame
totalSize = alignUp(totalSize, sys.StackAlign)
sp := newg.stack.hi - totalSize
spArg := sp
if usesLR {
// caller's LR
*(*uintptr)(unsafe.Pointer(sp)) = 0
prepGoExitFrame(sp)
spArg += sys.MinFrameSize
}
memclrNoHeapPointers(unsafe.Pointer(&newg.sched), unsafe.Sizeof(newg.sched))
newg.sched.sp = sp
newg.stktopsp = sp
newg.sched.pc = abi.FuncPCABI0(goexit) + sys.PCQuantum // +PCQuantum so that previous instruction is in same function
newg.sched.g = guintptr(unsafe.Pointer(newg))
gostartcallfn(&newg.sched, fn)
newg.parentGoid = callergp.goid
newg.gopc = callerpc
newg.ancestors = saveAncestors(callergp)
newg.startpc = fn.fn
if isSystemGoroutine(newg, false) {
sched.ngsys.Add(1)
} else {
// Only user goroutines inherit pprof labels.
if mp.curg != nil {
newg.labels = mp.curg.labels
}
if goroutineProfile.active {
// A concurrent goroutine profile is running. It should include
// exactly the set of goroutines that were alive when the goroutine
// profiler first stopped the world. That does not include newg, so
// mark it as not needing a profile before transitioning it from
// _Gdead.
newg.goroutineProfiled.Store(goroutineProfileSatisfied)
}
}
// Track initial transition?
newg.trackingSeq = uint8(fastrand())
if newg.trackingSeq%gTrackingPeriod == 0 {
newg.tracking = true
}
casgstatus(newg, _Gdead, _Grunnable)
gcController.addScannableStack(pp, int64(newg.stack.hi-newg.stack.lo))
if pp.goidcache == pp.goidcacheend {
// Sched.goidgen is the last allocated id, so this batch must be
// [sched.goidgen+1, sched.goidgen+GoidCacheBatch].
// At startup sched.goidgen=0, so the main goroutine receives goid=1.
pp.goidcache = sched.goidgen.Add(_GoidCacheBatch)
pp.goidcache -= _GoidCacheBatch - 1
pp.goidcacheend = pp.goidcache + _GoidCacheBatch
}
newg.goid = pp.goidcache
pp.goidcache++
if raceenabled {
newg.racectx = racegostart(callerpc)
if newg.labels != nil {
// See note in proflabel.go on labelSync's role in synchronizing
// with the reads in the signal handler.
racereleasemergeg(newg, unsafe.Pointer(&labelSync))
}
}
if trace.enabled {
traceGoCreate(newg, newg.startpc)
}
releasem(mp)
return newg
}
// saveAncestors copies previous ancestors of the given caller g and
// includes info for the current caller into a new set of tracebacks for
// a g being created.
func saveAncestors(callergp *g) *[]ancestorInfo {
// Copy all prior info, except for the root goroutine (goid 0).
if debug.tracebackancestors <= 0 || callergp.goid == 0 {
return nil
}
var callerAncestors []ancestorInfo
if callergp.ancestors != nil {
callerAncestors = *callergp.ancestors
}
n := int32(len(callerAncestors)) + 1
if n > debug.tracebackancestors {
n = debug.tracebackancestors
}
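// ancestors[0] is reserved for the current caller; copying into [1:]
// keeps the most recent prior ancestors and drops the oldest ones
// once the tracebackancestors limit is reached.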
ancestors := make([]ancestorInfo, n)
copy(ancestors[1:], callerAncestors)
var pcs [_TracebackMaxFrames]uintptr
npcs := gcallers(callergp, 0, pcs[:])
ipcs := make([]uintptr, npcs)
copy(ipcs, pcs[:])
ancestors[0] = ancestorInfo{
pcs: ipcs,
goid: callergp.goid,
gopc: callergp.gopc,
}
ancestorsp := new([]ancestorInfo)
*ancestorsp = ancestors
return ancestorsp
}
// Put on gfree list.
// If local list is too long, transfer a batch to the global list.
func gfput(pp *p, gp *g) {
if readgstatus(gp) != _Gdead {
throw("gfput: bad status (not Gdead)")
}
stksize := gp.stack.hi - gp.stack.lo
if stksize != uintptr(startingStackSize) {
// non-standard stack size - free it.
stackfree(gp.stack)
gp.stack.lo = 0
gp.stack.hi = 0
gp.stackguard0 = 0
}
pp.gFree.push(gp)
pp.gFree.n++
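// Cache up to 64 Gs locally; beyond that, transfer Gs to the global
// list until only 32 remain, taking sched.gFree.lock once per batch
// rather than once per G.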
if pp.gFree.n >= 64 {
var (
inc int32
stackQ gQueue
noStackQ gQueue
)
for pp.gFree.n >= 32 {
gp := pp.gFree.pop()
pp.gFree.n--
if gp.stack.lo == 0 {
noStackQ.push(gp)
} else {
stackQ.push(gp)
}
inc++
}
lock(&sched.gFree.lock)
sched.gFree.noStack.pushAll(noStackQ)
sched.gFree.stack.pushAll(stackQ)
sched.gFree.n += inc
unlock(&sched.gFree.lock)
}
}
// Get from gfree list.
// If local list is empty, grab a batch from global list.
func gfget(pp *p) *g {
retry:
if pp.gFree.empty() && (!sched.gFree.stack.empty() || !sched.gFree.noStack.empty()) {
lock(&sched.gFree.lock)
// Move a batch of free Gs to the P.
for pp.gFree.n < 32 {
// Prefer Gs with stacks.
gp := sched.gFree.stack.pop()
if gp == nil {
gp = sched.gFree.noStack.pop()
if gp == nil {
break
}
}
sched.gFree.n--
pp.gFree.push(gp)
pp.gFree.n++
}
unlock(&sched.gFree.lock)
goto retry
}
gp := pp.gFree.pop()
if gp == nil {
return nil
}
pp.gFree.n--
if gp.stack.lo != 0 && gp.stack.hi-gp.stack.lo != uintptr(startingStackSize) {
// Deallocate old stack. We kept it in gfput because it was the
// right size when the goroutine was put on the free list, but
// the right size has changed since then.
systemstack(func() {
stackfree(gp.stack)
gp.stack.lo = 0
gp.stack.hi = 0
gp.stackguard0 = 0
})
}
if gp.stack.lo == 0 {
// Stack was deallocated in gfput or just above. Allocate a new one.
systemstack(func() {
gp.stack = stackalloc(startingStackSize)
})
gp.stackguard0 = gp.stack.lo + _StackGuard
} else {
if raceenabled {
racemalloc(unsafe.Pointer(gp.stack.lo), gp.stack.hi-gp.stack.lo)
}
if msanenabled {
msanmalloc(unsafe.Pointer(gp.stack.lo), gp.stack.hi-gp.stack.lo)
}
if asanenabled {
asanunpoison(unsafe.Pointer(gp.stack.lo), gp.stack.hi-gp.stack.lo)
}
}
return gp
}
// Purge all cached G's from gfree list to the global list.
func gfpurge(pp *p) {
var (
inc int32
stackQ gQueue
noStackQ gQueue
)
for !pp.gFree.empty() {
gp := pp.gFree.pop()
pp.gFree.n--
if gp.stack.lo == 0 {
noStackQ.push(gp)
} else {
stackQ.push(gp)
}
inc++
}
lock(&sched.gFree.lock)
sched.gFree.noStack.pushAll(noStackQ)
sched.gFree.stack.pushAll(stackQ)
sched.gFree.n += inc
unlock(&sched.gFree.lock)
}
// Breakpoint executes a breakpoint trap.
func Breakpoint() {
breakpoint()
}
// dolockOSThread is called by LockOSThread and lockOSThread below
// after they modify m.locked. Do not allow preemption during this call,
// or else the m might be different in this function than in the caller.
//
//go:nosplit
func dolockOSThread() {
if GOARCH == "wasm" {
return // no threads on wasm yet
}
gp := getg()
gp.m.lockedg.set(gp)
gp.lockedm.set(gp.m)
}
// LockOSThread wires the calling goroutine to its current operating system thread.
// The calling goroutine will always execute in that thread,
// and no other goroutine will execute in it,
// until the calling goroutine has made as many calls to
// UnlockOSThread as to LockOSThread.
// If the calling goroutine exits without unlocking the thread,
// the thread will be terminated.
//
// All init functions are run on the startup thread. Calling LockOSThread
// from an init function will cause the main function to be invoked on
// that thread.
//
// A goroutine should call LockOSThread before calling OS services or
// non-Go library functions that depend on per-thread state.
//
//go:nosplit
func LockOSThread() {
if atomic.Load(&newmHandoff.haveTemplateThread) == 0 && GOOS != "plan9" {
// If we need to start a new thread from the locked
// thread, we need the template thread. Start it now
// while we're in a known-good state.
startTemplateThread()
}
gp := getg()
gp.m.lockedExt++
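// lockedExt is a counter; if the increment wrapped around to zero,
// undo it and report the overflow instead of losing the nesting count.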
if gp.m.lockedExt == 0 {
gp.m.lockedExt--
panic("LockOSThread nesting overflow")
}
dolockOSThread()
}
//go:nosplit
func lockOSThread() {
getg().m.lockedInt++
dolockOSThread()
}
// dounlockOSThread is called by UnlockOSThread and unlockOSThread below
// after they update m->locked. Do not allow preemption during this call,
// or else the m might be different in this function than in the caller.
//
//go:nosplit
func dounlockOSThread() {
if GOARCH == "wasm" {
return // no threads on wasm yet
}
gp := getg()
if gp.m.lockedInt != 0 || gp.m.lockedExt != 0 {
return
}
gp.m.lockedg = 0
gp.lockedm = 0
}
// UnlockOSThread undoes an earlier call to LockOSThread.
// If this drops the number of active LockOSThread calls on the
// calling goroutine to zero, it unwires the calling goroutine from
// its fixed operating system thread.
// If there are no active LockOSThread calls, this is a no-op.
//
// Before calling UnlockOSThread, the caller must ensure that the OS
// thread is suitable for running other goroutines. If the caller made
// any permanent changes to the state of the thread that would affect
// other goroutines, it should not call this function and thus leave
// the goroutine locked to the OS thread until the goroutine (and
// hence the thread) exits.
//
//go:nosplit
func UnlockOSThread() {
gp := getg()
if gp.m.lockedExt == 0 {
return
}
gp.m.lockedExt--
dounlockOSThread()
}
//go:nosplit
func unlockOSThread() {
gp := getg()
if gp.m.lockedInt == 0 {
systemstack(badunlockosthread)
}
gp.m.lockedInt--
dounlockOSThread()
}
func badunlockosthread() {
throw("runtime: internal error: misuse of lockOSThread/unlockOSThread")
}
func gcount() int32 {
n := int32(atomic.Loaduintptr(&allglen)) - sched.gFree.n - sched.ngsys.Load()
for _, pp := range allp {
n -= pp.gFree.n
}
// All these variables can be changed concurrently, so the result can be inconsistent.
// But at least the current goroutine is running.
if n < 1 {
n = 1
}
return n
}
func mcount() int32 {
return int32(sched.mnext - sched.nmfreed)
}
var prof struct {
signalLock atomic.Uint32
// Must hold signalLock to write. Reads may be lock-free, but
// signalLock should be taken to synchronize with changes.
hz atomic.Int32
}
func _System() { _System() }
func _ExternalCode() { _ExternalCode() }
func _LostExternalCode() { _LostExternalCode() }
func _GC() { _GC() }
func _LostSIGPROFDuringAtomic64() { _LostSIGPROFDuringAtomic64() }
func _VDSO() { _VDSO() }
// Called if we receive a SIGPROF signal.
// Called by the signal handler, may run during STW.
//
//go:nowritebarrierrec
func sigprof(pc, sp, lr uintptr, gp *g, mp *m) {
if prof.hz.Load() == 0 {
return
}
// If mp.profilehz is 0, then profiling is not enabled for this thread.
// We must check this to avoid a deadlock between setcpuprofilerate
// and the call to cpuprof.add, below.
if mp != nil && mp.profilehz == 0 {
return
}
// On mips{,le}/arm, 64-bit atomics are emulated with spinlocks in
// runtime/internal/atomic. If SIGPROF arrives while the program is inside
// the critical section, it creates a deadlock (when writing the sample).
// As a workaround, count such SIGPROFs in cpuprof.lostAtomic and report
// them later, when a SIGPROF is received from somewhere else (with
// _LostSIGPROFDuringAtomic64 as pc).
if GOARCH == "mips" || GOARCH == "mipsle" || GOARCH == "arm" {
if f := findfunc(pc); f.valid() {
if hasPrefix(funcname(f), "runtime/internal/atomic") {
cpuprof.lostAtomic++
return
}
}
if GOARCH == "arm" && goarm < 7 && GOOS == "linux" && pc&0xffff0000 == 0xffff0000 {
// runtime/internal/atomic functions call into kernel
// helpers on arm < 7. See
// runtime/internal/atomic/sys_linux_arm.s.
cpuprof.lostAtomic++
return
}
}
// Profiling runs concurrently with GC, so it must not allocate.
// Set a trap in case the code does allocate.
// Note that on windows, one thread takes profiles of all the
// other threads, so mp is usually not getg().m.
// In fact mp may not even be stopped.
// See golang.org/issue/17165.
getg().m.mallocing++
var stk [maxCPUProfStack]uintptr
n := 0
if mp.ncgo > 0 && mp.curg != nil && mp.curg.syscallpc != 0 && mp.curg.syscallsp != 0 {
cgoOff := 0
// Check cgoCallersUse to make sure that we are not
// interrupting other code that is fiddling with
// cgoCallers. We are running in a signal handler
// with all signals blocked, so we don't have to worry
// about any other code interrupting us.
if mp.cgoCallersUse.Load() == 0 && mp.cgoCallers != nil && mp.cgoCallers[0] != 0 {
for cgoOff < len(mp.cgoCallers) && mp.cgoCallers[cgoOff] != 0 {
cgoOff++
}
copy(stk[:], mp.cgoCallers[:cgoOff])
mp.cgoCallers[0] = 0
}
// Collect Go stack that leads to the cgo call.
n = gentraceback(mp.curg.syscallpc, mp.curg.syscallsp, 0, mp.curg, 0, &stk[cgoOff], len(stk)-cgoOff, nil, nil, 0)
if n > 0 {
n += cgoOff
}
} else if usesLibcall() && mp.libcallg != 0 && mp.libcallpc != 0 && mp.libcallsp != 0 {
// Libcall, i.e. runtime syscall on windows.
// Collect Go stack that leads to the call.
n = gentraceback(mp.libcallpc, mp.libcallsp, 0, mp.libcallg.ptr(), 0, &stk[n], len(stk[n:]), nil, nil, 0)
} else if mp != nil && mp.vdsoSP != 0 {
// VDSO call, e.g. nanotime1 on Linux.
// Collect Go stack that leads to the call.
n = gentraceback(mp.vdsoPC, mp.vdsoSP, 0, gp, 0, &stk[n], len(stk[n:]), nil, nil, _TraceJumpStack)
} else {
n = gentraceback(pc, sp, lr, gp, 0, &stk[0], len(stk), nil, nil, _TraceTrap|_TraceJumpStack)
}
if n <= 0 {
// Normal traceback is impossible or has failed.
// Account it against abstract "System" or "GC".
n = 2
if inVDSOPage(pc) {
pc = abi.FuncPCABIInternal(_VDSO) + sys.PCQuantum
} else if pc > firstmoduledata.etext {
// "ExternalCode" is better than "etext".
pc = abi.FuncPCABIInternal(_ExternalCode) + sys.PCQuantum
}
stk[0] = pc
if mp.preemptoff != "" {
stk[1] = abi.FuncPCABIInternal(_GC) + sys.PCQuantum
} else {
stk[1] = abi.FuncPCABIInternal(_System) + sys.PCQuantum
}
}
if prof.hz.Load() != 0 {
// Note: it can happen on Windows that we interrupted a system thread
// with no g, so gp could be nil. The other nil checks are done out of
// caution; those values are not expected to be nil in practice.
var tagPtr *unsafe.Pointer
if gp != nil && gp.m != nil && gp.m.curg != nil {
tagPtr = &gp.m.curg.labels
}
cpuprof.add(tagPtr, stk[:n])
gprof := gp
var pp *p
if gp != nil && gp.m != nil {
if gp.m.curg != nil {
gprof = gp.m.curg
}
pp = gp.m.p.ptr()
}
traceCPUSample(gprof, pp, stk[:n])
}
getg().m.mallocing--
}
// setcpuprofilerate sets the CPU profiling rate to hz times per second.
// If hz <= 0, setcpuprofilerate turns off CPU profiling.
func setcpuprofilerate(hz int32) {
// Force sane arguments.
if hz < 0 {
hz = 0
}
// Disable preemption, otherwise we can be rescheduled to another thread
// that has profiling enabled.
gp := getg()
gp.m.locks++
// Stop profiler on this thread so that it is safe to lock prof.
// If a profiling signal came in while we had prof locked,
// it would deadlock.
setThreadCPUProfiler(0)
for !prof.signalLock.CompareAndSwap(0, 1) {
osyield()
}
if prof.hz.Load() != hz {
setProcessCPUProfiler(hz)
prof.hz.Store(hz)
}
prof.signalLock.Store(0)
lock(&sched.lock)
sched.profilehz = hz
unlock(&sched.lock)
if hz != 0 {
setThreadCPUProfiler(hz)
}
gp.m.locks--
}
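// For reference, a minimal sketch of the user-level path that reaches
// setcpuprofilerate (illustrative only; error handling elided).
// runtime/pprof uses a fixed rate of 100 Hz:
//
//	f, _ := os.Create("cpu.prof")
//	pprof.StartCPUProfile(f) // eventually calls setcpuprofilerate(100)
//	...
//	pprof.StopCPUProfile() // eventually calls setcpuprofilerate(0)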
// init initializes pp, which may be a freshly allocated p or a
// previously destroyed p, and transitions it to status _Pgcstop.
func (pp *p) init(id int32) {
pp.id = id
pp.status = _Pgcstop
pp.sudogcache = pp.sudogbuf[:0]
pp.deferpool = pp.deferpoolbuf[:0]
pp.wbBuf.reset()
if pp.mcache == nil {
if id == 0 {
if mcache0 == nil {
throw("missing mcache?")
}
// Use the bootstrap mcache0. Only one P will get
// mcache0: the one with ID 0.
pp.mcache = mcache0
} else {
pp.mcache = allocmcache()
}
}
if raceenabled && pp.raceprocctx == 0 {
if id == 0 {
pp.raceprocctx = raceprocctx0
raceprocctx0 = 0 // bootstrap
} else {
pp.raceprocctx = raceproccreate()
}
}
lockInit(&pp.timersLock, lockRankTimers)
// This P may get timers when it starts running. Set the mask here
// since the P may not go through pidleget (notably P 0 on startup).
timerpMask.set(id)
// Similarly, we may not go through pidleget before this P starts
// running if it is P 0 on startup.
idlepMask.clear(id)
}
// destroy releases all of the resources associated with pp and
// transitions it to status _Pdead.
//
// sched.lock must be held and the world must be stopped.
func (pp *p) destroy() {
assertLockHeld(&sched.lock)
assertWorldStopped()
// Move all runnable goroutines to the global queue
for pp.runqhead != pp.runqtail {
// Pop from tail of local queue
pp.runqtail--
gp := pp.runq[pp.runqtail%uint32(len(pp.runq))].ptr()
// Push onto head of global queue
globrunqputhead(gp)
}
if pp.runnext != 0 {
globrunqputhead(pp.runnext.ptr())
pp.runnext = 0
}
if len(pp.timers) > 0 {
plocal := getg().m.p.ptr()
// The world is stopped, but we acquire timersLock to
// protect against sysmon calling timeSleepUntil.
// This is the only case where we hold the timersLock of
// more than one P, so there are no deadlock concerns.
lock(&plocal.timersLock)
lock(&pp.timersLock)
moveTimers(plocal, pp.timers)
pp.timers = nil
pp.numTimers.Store(0)
pp.deletedTimers.Store(0)
pp.timer0When.Store(0)
unlock(&pp.timersLock)
unlock(&plocal.timersLock)
}
// Flush p's write barrier buffer.
if gcphase != _GCoff {
wbBufFlush1(pp)
pp.gcw.dispose()
}
for i := range pp.sudogbuf {
pp.sudogbuf[i] = nil
}
pp.sudogcache = pp.sudogbuf[:0]
for j := range pp.deferpoolbuf {
pp.deferpoolbuf[j] = nil
}
pp.deferpool = pp.deferpoolbuf[:0]
systemstack(func() {
for i := 0; i < pp.mspancache.len; i++ {
// Safe to call since the world is stopped.
mheap_.spanalloc.free(unsafe.Pointer(pp.mspancache.buf[i]))
}
pp.mspancache.len = 0
lock(&mheap_.lock)
pp.pcache.flush(&mheap_.pages)
unlock(&mheap_.lock)
})
freemcache(pp.mcache)
pp.mcache = nil
gfpurge(pp)
traceProcFree(pp)
if raceenabled {
if pp.timerRaceCtx != 0 {
// The race detector code uses a callback to fetch
// the proc context, so arrange for that callback
// to see the right thing.
// This hack only works because we are the only
// thread running.
mp := getg().m
phold := mp.p.ptr()
mp.p.set(pp)
racectxend(pp.timerRaceCtx)
pp.timerRaceCtx = 0
mp.p.set(phold)
}
raceprocdestroy(pp.raceprocctx)
pp.raceprocctx = 0
}
pp.gcAssistTime = 0
pp.status = _Pdead
}
// Change number of processors.
//
// sched.lock must be held, and the world must be stopped.
//
// gcworkbufs must not be being modified by either the GC or the write barrier
// code, so the GC must not be running if the number of Ps actually changes.
//
// Returns list of Ps with local work, they need to be scheduled by the caller.
func procresize(nprocs int32) *p {
assertLockHeld(&sched.lock)
assertWorldStopped()
old := gomaxprocs
if old < 0 || nprocs <= 0 {
throw("procresize: invalid arg")
}
if trace.enabled {
traceGomaxprocs(nprocs)
}
// update statistics
now := nanotime()
if sched.procresizetime != 0 {
sched.totaltime += int64(old) * (now - sched.procresizetime)
}
sched.procresizetime = now
maskWords := (nprocs + 31) / 32
// Grow allp if necessary.
if nprocs > int32(len(allp)) {
// Synchronize with retake, which could be running
// concurrently since it doesn't run on a P.
lock(&allpLock)
if nprocs <= int32(cap(allp)) {
allp = allp[:nprocs]
} else {
nallp := make([]*p, nprocs)
// Copy everything up to allp's cap so we
// never lose old allocated Ps.
copy(nallp, allp[:cap(allp)])
allp = nallp
}
if maskWords <= int32(cap(idlepMask)) {
idlepMask = idlepMask[:maskWords]
timerpMask = timerpMask[:maskWords]
} else {
nidlepMask := make([]uint32, maskWords)
// No need to copy beyond len, old Ps are irrelevant.
copy(nidlepMask, idlepMask)
idlepMask = nidlepMask
ntimerpMask := make([]uint32, maskWords)
copy(ntimerpMask, timerpMask)
timerpMask = ntimerpMask
}
unlock(&allpLock)
}
// initialize new P's
for i := old; i < nprocs; i++ {
pp := allp[i]
if pp == nil {
pp = new(p)
}
pp.init(i)
atomicstorep(unsafe.Pointer(&allp[i]), unsafe.Pointer(pp))
}
gp := getg()
if gp.m.p != 0 && gp.m.p.ptr().id < nprocs {
// continue to use the current P
gp.m.p.ptr().status = _Prunning
gp.m.p.ptr().mcache.prepareForSweep()
} else {
// release the current P and acquire allp[0].
//
// We must do this before destroying our current P
// because p.destroy itself has write barriers, so we
// need to do that from a valid P.
if gp.m.p != 0 {
if trace.enabled {
// Pretend that we were descheduled
// and then scheduled again to keep
// the trace sane.
traceGoSched()
traceProcStop(gp.m.p.ptr())
}
gp.m.p.ptr().m = 0
}
gp.m.p = 0
pp := allp[0]
pp.m = 0
pp.status = _Pidle
acquirep(pp)
if trace.enabled {
traceGoStart()
}
}
// g.m.p is now set, so we no longer need mcache0 for bootstrapping.
mcache0 = nil
// release resources from unused P's
for i := nprocs; i < old; i++ {
pp := allp[i]
pp.destroy()
// can't free P itself because it can be referenced by an M in syscall
}
// Trim allp.
if int32(len(allp)) != nprocs {
lock(&allpLock)
allp = allp[:nprocs]
idlepMask = idlepMask[:maskWords]
timerpMask = timerpMask[:maskWords]
unlock(&allpLock)
}
var runnablePs *p
for i := nprocs - 1; i >= 0; i-- {
pp := allp[i]
if gp.m.p.ptr() == pp {
continue
}
pp.status = _Pidle
if runqempty(pp) {
pidleput(pp, now)
} else {
pp.m.set(mget())
pp.link.set(runnablePs)
runnablePs = pp
}
}
stealOrder.reset(uint32(nprocs))
var int32p *int32 = &gomaxprocs // make compiler check that gomaxprocs is an int32
atomic.Store((*uint32)(unsafe.Pointer(int32p)), uint32(nprocs))
if old != nprocs {
// Notify the limiter that the amount of procs has changed.
gcCPULimiter.resetCapacity(now, nprocs)
}
return runnablePs
}
// Associate p and the current m.
//
// This function is allowed to have write barriers even if the caller
// isn't because it immediately acquires pp.
//
//go:yeswritebarrierrec
func acquirep(pp *p) {
// Do the part that isn't allowed to have write barriers.
wirep(pp)
// Have p; write barriers now allowed.
// Perform deferred mcache flush before this P can allocate
// from a potentially stale mcache.
pp.mcache.prepareForSweep()
if trace.enabled {
traceProcStart()
}
}
// wirep is the first step of acquirep, which actually associates the
// current M to pp. This is broken out so we can disallow write
// barriers for this part, since we don't yet have a P.
//
//go:nowritebarrierrec
//go:nosplit
func wirep(pp *p) {
gp := getg()
if gp.m.p != 0 {
throw("wirep: already in go")
}
if pp.m != 0 || pp.status != _Pidle {
id := int64(0)
if pp.m != 0 {
id = pp.m.ptr().id
}
print("wirep: p->m=", pp.m, "(", id, ") p->status=", pp.status, "\n")
throw("wirep: invalid p state")
}
gp.m.p.set(pp)
pp.m.set(gp.m)
pp.status = _Prunning
}
// Disassociate p and the current m.
func releasep() *p {
gp := getg()
if gp.m.p == 0 {
throw("releasep: invalid arg")
}
pp := gp.m.p.ptr()
if pp.m.ptr() != gp.m || pp.status != _Prunning {
print("releasep: m=", gp.m, " m->p=", gp.m.p.ptr(), " p->m=", hex(pp.m), " p->status=", pp.status, "\n")
throw("releasep: invalid p state")
}
if trace.enabled {
traceProcStop(gp.m.p.ptr())
}
gp.m.p = 0
pp.m = 0
pp.status = _Pidle
return pp
}
func incidlelocked(v int32) {
lock(&sched.lock)
sched.nmidlelocked += v
if v > 0 {
checkdead()
}
unlock(&sched.lock)
}
// Check for a deadlock situation.
// The check is based on the number of running M's; if that number is 0,
// the program is deadlocked.
// sched.lock must be held.
func checkdead() {
assertLockHeld(&sched.lock)
// For -buildmode=c-shared or -buildmode=c-archive it's OK if
// there are no running goroutines. The calling program is
// assumed to be running.
if islibrary || isarchive {
return
}
// If we are dying because of a signal caught on an already idle thread,
// freezetheworld will cause all running threads to block.
// And runtime will essentially enter into deadlock state,
// except that there is a thread that will call exit soon.
if panicking.Load() > 0 {
return
}
// If we are not running under cgo, but we have an extra M then account
// for it. (It is possible to have an extra M on Windows without cgo to
// accommodate callbacks created by syscall.NewCallback. See issue #6751
// for details.)
var run0 int32
if !iscgo && cgoHasExtraM {
mp := lockextra(true)
haveExtraM := extraMCount > 0
unlockextra(mp)
if haveExtraM {
run0 = 1
}
}
run := mcount() - sched.nmidle - sched.nmidlelocked - sched.nmsys
if run > run0 {
return
}
if run < 0 {
print("runtime: checkdead: nmidle=", sched.nmidle, " nmidlelocked=", sched.nmidlelocked, " mcount=", mcount(), " nmsys=", sched.nmsys, "\n")
throw("checkdead: inconsistent counts")
}
grunning := 0
forEachG(func(gp *g) {
if isSystemGoroutine(gp, false) {
return
}
s := readgstatus(gp)
switch s &^ _Gscan {
case _Gwaiting,
_Gpreempted:
grunning++
case _Grunnable,
_Grunning,
_Gsyscall:
print("runtime: checkdead: find g ", gp.goid, " in status ", s, "\n")
throw("checkdead: runnable g")
}
})
if grunning == 0 { // possible if main goroutine calls runtime·Goexit()
unlock(&sched.lock) // unlock so that GODEBUG=scheddetail=1 doesn't hang
fatal("no goroutines (main called runtime.Goexit) - deadlock!")
}
// Maybe jump time forward for playground.
if faketime != 0 {
if when := timeSleepUntil(); when < maxWhen {
faketime = when
// Start an M to steal the timer.
pp, _ := pidleget(faketime)
if pp == nil {
// There should always be a free P since
// nothing is running.
throw("checkdead: no p for timer")
}
mp := mget()
if mp == nil {
// There should always be a free M since
// nothing is running.
throw("checkdead: no m for timer")
}
// M must be spinning to steal. We set spinning explicitly here,
// but since this is the only M it would become spinning
// on its own anyway.
sched.nmspinning.Add(1)
mp.spinning = true
mp.nextp.set(pp)
notewakeup(&mp.park)
return
}
}
// There are no goroutines running, so we can look at the P's.
for _, pp := range allp {
if len(pp.timers) > 0 {
return
}
}
unlock(&sched.lock) // unlock so that GODEBUG=scheddetail=1 doesn't hang
fatal("all goroutines are asleep - deadlock!")
}
// forcegcperiod is the maximum time in nanoseconds between garbage
// collections. If we go this long without a garbage collection, one
// is forced to run.
//
// This is a variable for testing purposes. It normally doesn't change.
var forcegcperiod int64 = 2 * 60 * 1e9
// needSysmonWorkaround is true if the workaround for
// golang.org/issue/42515 is needed on NetBSD.
var needSysmonWorkaround bool = false
// Always runs without a P, so write barriers are not allowed.
//
//go:nowritebarrierrec
func sysmon() {
lock(&sched.lock)
sched.nmsys++
checkdead()
unlock(&sched.lock)
lasttrace := int64(0)
idle := 0 // how many cycles in succession we have not woken anybody up
delay := uint32(0)
for {
if idle == 0 { // start with 20us sleep...
delay = 20
} else if idle > 50 { // start doubling the sleep after 1ms...
delay *= 2
}
if delay > 10*1000 { // up to 10ms
delay = 10 * 1000
}
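// With this backoff, sysmon sleeps 20us per cycle for the first ~50 idle
// cycles (about 1ms total), then doubles the sleep each further idle cycle
// until it is capped at 10ms.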
usleep(delay)
// sysmon should not enter deep sleep if schedtrace is enabled so that
// it can print that information at the right time.
//
// It should also not enter deep sleep if there are any active P's so
// that it can retake P's from syscalls, preempt long running G's, and
// poll the network if all P's are busy for long stretches.
//
// It should wake up from deep sleep if any P's become active either due
// to exiting a syscall or waking up due to a timer expiring so that it
// can resume performing those duties. If it wakes from a syscall it
// resets idle and delay as a bet that since it had retaken a P from a
// syscall before, it may need to do it again shortly after the
// application starts work again. It does not reset idle when waking
// from a timer to avoid adding system load to applications that spend
// most of their time sleeping.
now := nanotime()
if debug.schedtrace <= 0 && (sched.gcwaiting.Load() || sched.npidle.Load() == gomaxprocs) {
lock(&sched.lock)
if sched.gcwaiting.Load() || sched.npidle.Load() == gomaxprocs {
syscallWake := false
next := timeSleepUntil()
if next > now {
sched.sysmonwait.Store(true)
unlock(&sched.lock)
// Make wake-up period small enough
// for the sampling to be correct.
sleep := forcegcperiod / 2
if next-now < sleep {
sleep = next - now
}
shouldRelax := sleep >= osRelaxMinNS
if shouldRelax {
osRelax(true)
}
syscallWake = notetsleep(&sched.sysmonnote, sleep)
if shouldRelax {
osRelax(false)
}
lock(&sched.lock)
sched.sysmonwait.Store(false)
noteclear(&sched.sysmonnote)
}
if syscallWake {
idle = 0
delay = 20
}
}
unlock(&sched.lock)
}
lock(&sched.sysmonlock)
// Update now in case we blocked on sysmonnote or spent a long time
// blocked on schedlock or sysmonlock above.
now = nanotime()
// trigger libc interceptors if needed
if *cgo_yield != nil {
asmcgocall(*cgo_yield, nil)
}
// poll network if not polled for more than 10ms
lastpoll := sched.lastpoll.Load()
if netpollinited() && lastpoll != 0 && lastpoll+10*1000*1000 < now {
sched.lastpoll.CompareAndSwap(lastpoll, now)
list := netpoll(0) // non-blocking - returns list of goroutines
if !list.empty() {
// Need to decrement number of idle locked M's
// (pretending that one more is running) before injectglist.
// Otherwise it can lead to the following situation:
// injectglist grabs all P's but before it starts M's to run the P's,
// another M returns from syscall, finishes running its G,
// observes that there is no work to do and no other running M's
// and reports deadlock.
incidlelocked(-1)
injectglist(&list)
incidlelocked(1)
}
}
if GOOS == "netbsd" && needSysmonWorkaround {
// netpoll is responsible for waiting for timer
// expiration, so we typically don't have to worry
// about starting an M to service timers. (Note that
// sleep for timeSleepUntil above simply ensures sysmon
// starts running again when that timer expiration may
// cause Go code to run again).
//
// However, netbsd has a kernel bug that sometimes
// misses netpollBreak wake-ups, which can lead to
// unbounded delays servicing timers. If we detect this
// overrun, then startm to get something to handle the
// timer.
//
// See issue 42515 and
// https://gnats.netbsd.org/cgi-bin/query-pr-single.pl?number=50094.
if next := timeSleepUntil(); next < now {
startm(nil, false)
}
}
if scavenger.sysmonWake.Load() != 0 {
// Kick the scavenger awake if someone requested it.
scavenger.wake()
}
// retake P's blocked in syscalls
// and preempt long running G's
if retake(now) != 0 {
idle = 0
} else {
idle++
}
// check if we need to force a GC
if t := (gcTrigger{kind: gcTriggerTime, now: now}); t.test() && forcegc.idle.Load() {
lock(&forcegc.lock)
forcegc.idle.Store(false)
var list gList
list.push(forcegc.g)
injectglist(&list)
unlock(&forcegc.lock)
}
if debug.schedtrace > 0 && lasttrace+int64(debug.schedtrace)*1000000 <= now {
lasttrace = now
schedtrace(debug.scheddetail > 0)
}
unlock(&sched.sysmonlock)
}
}
type sysmontick struct {
schedtick uint32
schedwhen int64
syscalltick uint32
syscallwhen int64
}
// forcePreemptNS is the time slice given to a G before it is
// preempted.
const forcePreemptNS = 10 * 1000 * 1000 // 10ms
func retake(now int64) uint32 {
n := 0
// Prevent allp slice changes. This lock will be completely
// uncontended unless we're already stopping the world.
lock(&allpLock)
// We can't use a range loop over allp because we may
// temporarily drop the allpLock. Hence, we need to re-fetch
// allp each time around the loop.
for i := 0; i < len(allp); i++ {
pp := allp[i]
if pp == nil {
// This can happen if procresize has grown
// allp but not yet created new Ps.
continue
}
pd := &pp.sysmontick
s := pp.status
sysretake := false
if s == _Prunning || s == _Psyscall {
// Preempt G if it's running for too long.
t := int64(pp.schedtick)
if int64(pd.schedtick) != t {
pd.schedtick = uint32(t)
pd.schedwhen = now
} else if pd.schedwhen+forcePreemptNS <= now {
preemptone(pp)
// In case of syscall, preemptone() doesn't
// work, because there is no M wired to P.
sysretake = true
}
}
if s == _Psyscall {
// Retake P from syscall if it's there for more than 1 sysmon tick (at least 20us).
t := int64(pp.syscalltick)
if !sysretake && int64(pd.syscalltick) != t {
pd.syscalltick = uint32(t)
pd.syscallwhen = now
continue
}
// On the one hand we don't want to retake Ps if there is no other work to do,
// but on the other hand we want to retake them eventually
// because they can prevent the sysmon thread from deep sleep.
if runqempty(pp) && sched.nmspinning.Load()+sched.npidle.Load() > 0 && pd.syscallwhen+10*1000*1000 > now {
continue
}
// Drop allpLock so we can take sched.lock.
unlock(&allpLock)
// Need to decrement number of idle locked M's
// (pretending that one more is running) before the CAS.
// Otherwise the M from which we retake can exit the syscall,
// increment nmidle and report deadlock.
incidlelocked(-1)
if atomic.Cas(&pp.status, s, _Pidle) {
if trace.enabled {
traceGoSysBlock(pp)
traceProcStop(pp)
}
n++
pp.syscalltick++
handoffp(pp)
}
incidlelocked(1)
lock(&allpLock)
}
}
unlock(&allpLock)
return uint32(n)
}
// Tell all goroutines that they have been preempted and they should stop.
// This function is purely best-effort. It can fail to inform a goroutine if a
// processor just started running it.
// No locks need to be held.
// Returns true if preemption request was issued to at least one goroutine.
func preemptall() bool {
res := false
for _, pp := range allp {
if pp.status != _Prunning {
continue
}
if preemptone(pp) {
res = true
}
}
return res
}
// Tell the goroutine running on processor P to stop.
// This function is purely best-effort. It can incorrectly fail to inform the
// goroutine. It can inform the wrong goroutine. Even if it informs the
// correct goroutine, that goroutine might ignore the request if it is
// simultaneously executing newstack.
// No lock needs to be held.
// Returns true if preemption request was issued.
// The actual preemption will happen at some point in the future
// and will be indicated by the gp->status no longer being
// Grunning.
func preemptone(pp *p) bool {
mp := pp.m.ptr()
if mp == nil || mp == getg().m {
return false
}
gp := mp.curg
if gp == nil || gp == mp.g0 {
return false
}
gp.preempt = true
// Every call in a goroutine checks for stack overflow by
// comparing the current stack pointer to gp->stackguard0.
// Setting gp->stackguard0 to StackPreempt folds
// preemption into the normal stack overflow check.
gp.stackguard0 = stackPreempt
// Request an async preemption of this P.
if preemptMSupported && debug.asyncpreemptoff == 0 {
pp.preempt = true
preemptM(mp)
}
return true
}
var starttime int64
func schedtrace(detailed bool) {
now := nanotime()
if starttime == 0 {
starttime = now
}
lock(&sched.lock)
print("SCHED ", (now-starttime)/1e6, "ms: gomaxprocs=", gomaxprocs, " idleprocs=", sched.npidle.Load(), " threads=", mcount(), " spinningthreads=", sched.nmspinning.Load(), " needspinning=", sched.needspinning.Load(), " idlethreads=", sched.nmidle, " runqueue=", sched.runqsize)
if detailed {
print(" gcwaiting=", sched.gcwaiting.Load(), " nmidlelocked=", sched.nmidlelocked, " stopwait=", sched.stopwait, " sysmonwait=", sched.sysmonwait.Load(), "\n")
}
// We must be careful while reading data from P's, M's and G's.
// Even if we hold schedlock, most data can be changed concurrently.
// E.g. (p->m ? p->m->id : -1) can crash if p->m changes from non-nil to nil.
for i, pp := range allp {
mp := pp.m.ptr()
h := atomic.Load(&pp.runqhead)
t := atomic.Load(&pp.runqtail)
if detailed {
print(" P", i, ": status=", pp.status, " schedtick=", pp.schedtick, " syscalltick=", pp.syscalltick, " m=")
if mp != nil {
print(mp.id)
} else {
print("nil")
}
print(" runqsize=", t-h, " gfreecnt=", pp.gFree.n, " timerslen=", len(pp.timers), "\n")
} else {
// In non-detailed mode, format the lengths of per-P run queues as:
// [len1 len2 len3 len4]
print(" ")
if i == 0 {
print("[")
}
print(t - h)
if i == len(allp)-1 {
print("]\n")
}
}
}
if !detailed {
unlock(&sched.lock)
return
}
for mp := allm; mp != nil; mp = mp.alllink {
pp := mp.p.ptr()
print(" M", mp.id, ": p=")
if pp != nil {
print(pp.id)
} else {
print("nil")
}
print(" curg=")
if mp.curg != nil {
print(mp.curg.goid)
} else {
print("nil")
}
print(" mallocing=", mp.mallocing, " throwing=", mp.throwing, " preemptoff=", mp.preemptoff, " locks=", mp.locks, " dying=", mp.dying, " spinning=", mp.spinning, " blocked=", mp.blocked, " lockedg=")
if lockedg := mp.lockedg.ptr(); lockedg != nil {
print(lockedg.goid)
} else {
print("nil")
}
print("\n")
}
forEachG(func(gp *g) {
print(" G", gp.goid, ": status=", readgstatus(gp), "(", gp.waitreason.String(), ") m=")
if gp.m != nil {
print(gp.m.id)
} else {
print("nil")
}
print(" lockedm=")
if lockedm := gp.lockedm.ptr(); lockedm != nil {
print(lockedm.id)
} else {
print("nil")
}
print("\n")
})
unlock(&sched.lock)
}
// schedEnableUser enables or disables the scheduling of user
// goroutines.
//
// This does not stop already running user goroutines, so the caller
// should first stop the world when disabling user goroutines.
func schedEnableUser(enable bool) {
lock(&sched.lock)
if sched.disable.user == !enable {
unlock(&sched.lock)
return
}
sched.disable.user = !enable
if enable {
n := sched.disable.n
sched.disable.n = 0
globrunqputbatch(&sched.disable.runnable, n)
unlock(&sched.lock)
for ; n != 0 && sched.npidle.Load() != 0; n-- {
startm(nil, false)
}
} else {
unlock(&sched.lock)
}
}
// schedEnabled reports whether gp should be scheduled. It returns
// false if scheduling of gp is disabled.
//
// sched.lock must be held.
func schedEnabled(gp *g) bool {
assertLockHeld(&sched.lock)
if sched.disable.user {
return isSystemGoroutine(gp, true)
}
return true
}
// Put mp on midle list.
// sched.lock must be held.
// May run during STW, so write barriers are not allowed.
//
//go:nowritebarrierrec
func mput(mp *m) {
assertLockHeld(&sched.lock)
mp.schedlink = sched.midle
sched.midle.set(mp)
sched.nmidle++
checkdead()
}
// Try to get an m from midle list.
// sched.lock must be held.
// May run during STW, so write barriers are not allowed.
//
//go:nowritebarrierrec
func mget() *m {
assertLockHeld(&sched.lock)
mp := sched.midle.ptr()
if mp != nil {
sched.midle = mp.schedlink
sched.nmidle--
}
return mp
}
// Put gp on the global runnable queue.
// sched.lock must be held.
// May run during STW, so write barriers are not allowed.
//
//go:nowritebarrierrec
func globrunqput(gp *g) {
assertLockHeld(&sched.lock)
sched.runq.pushBack(gp)
sched.runqsize++
}
// Put gp at the head of the global runnable queue.
// sched.lock must be held.
// May run during STW, so write barriers are not allowed.
//
//go:nowritebarrierrec
func globrunqputhead(gp *g) {
assertLockHeld(&sched.lock)
sched.runq.push(gp)
sched.runqsize++
}
// Put a batch of runnable goroutines on the global runnable queue.
// This clears *batch.
// sched.lock must be held.
// May run during STW, so write barriers are not allowed.
//
//go:nowritebarrierrec
func globrunqputbatch(batch *gQueue, n int32) {
assertLockHeld(&sched.lock)
sched.runq.pushBackAll(*batch)
sched.runqsize += n
*batch = gQueue{}
}
// Try to get a batch of G's from the global runnable queue.
// sched.lock must be held.
func globrunqget(pp *p, max int32) *g {
assertLockHeld(&sched.lock)
if sched.runqsize == 0 {
return nil
}
n := sched.runqsize/gomaxprocs + 1
if n > sched.runqsize {
n = sched.runqsize
}
if max > 0 && n > max {
n = max
}
if n > int32(len(pp.runq))/2 {
n = int32(len(pp.runq)) / 2
}
sched.runqsize -= n
gp := sched.runq.pop()
n--
for ; n > 0; n-- {
gp1 := sched.runq.pop()
runqput(pp, gp1, false)
}
return gp
}
// pMask is an atomic bitstring with one bit per P.
type pMask []uint32
// read returns true if P id's bit is set.
func (p pMask) read(id uint32) bool {
word := id / 32
mask := uint32(1) << (id % 32)
return (atomic.Load(&p[word]) & mask) != 0
}
// set sets P id's bit.
func (p pMask) set(id int32) {
word := id / 32
mask := uint32(1) << (id % 32)
atomic.Or(&p[word], mask)
}
// clear clears P id's bit.
func (p pMask) clear(id int32) {
word := id / 32
mask := uint32(1) << (id % 32)
atomic.And(&p[word], ^mask)
}
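// A sketch of the intended usage, assuming the mask was sized with
// (nprocs+31)/32 words as in procresize. With nprocs = 40, P 37's bit
// lives in word 37/32 = 1 at bit 37%32 = 5:
//
//	mask := make(pMask, (40+31)/32)
//	mask.set(37)   // atomic.Or(&mask[1], 1<<5)
//	ok := mask.read(37) // true
//	mask.clear(37) // atomic.And(&mask[1], ^uint32(1<<5))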
// updateTimerPMask clears pp's timer mask if it has no timers on its heap.
//
// Ideally, the timer mask would be kept immediately consistent on any timer
// operations. Unfortunately, updating a shared global data structure in the
// timer hot path adds too much overhead in applications frequently switching
// between no timers and some timers.
//
// As a compromise, the timer mask is updated only on pidleget / pidleput. A
// running P (returned by pidleget) may add a timer at any time, so its mask
// must be set. An idle P (passed to pidleput) cannot add new timers while
// idle, so if it has no timers at that time, its mask may be cleared.
//
// Thus, we get the following effects on timer-stealing in findrunnable:
//
// - Idle Ps with no timers when they go idle are never checked in findrunnable
// (for work- or timer-stealing; this is the ideal case).
// - Running Ps must always be checked.
// - Idle Ps whose timers are stolen must continue to be checked until they run
// again, even after timer expiration.
//
// When the P starts running again, the mask should be set, as a timer may be
// added at any time.
//
// TODO(prattmic): Additional targeted updates may improve the above cases.
// e.g., updating the mask when stealing a timer.
func updateTimerPMask(pp *p) {
if pp.numTimers.Load() > 0 {
return
}
// Looks like there are no timers, however another P may transiently
// decrement numTimers when handling a timerModified timer in
// checkTimers. We must take timersLock to serialize with these changes.
lock(&pp.timersLock)
if pp.numTimers.Load() == 0 {
timerpMask.clear(pp.id)
}
unlock(&pp.timersLock)
}
// pidleput puts p on the _Pidle list. now must be a relatively recent call
// to nanotime or zero. Returns now or the current time if now was zero.
//
// This releases ownership of p. Once sched.lock is released it is no longer
// safe to use p.
//
// sched.lock must be held.
//
// May run during STW, so write barriers are not allowed.
//
//go:nowritebarrierrec
func pidleput(pp *p, now int64) int64 {
assertLockHeld(&sched.lock)
if !runqempty(pp) {
throw("pidleput: P has non-empty run queue")
}
if now == 0 {
now = nanotime()
}
updateTimerPMask(pp) // clear if there are no timers.
idlepMask.set(pp.id)
pp.link = sched.pidle
sched.pidle.set(pp)
sched.npidle.Add(1)
if !pp.limiterEvent.start(limiterEventIdle, now) {
throw("must be able to track idle limiter event")
}
return now
}
// pidleget tries to get a p from the _Pidle list, acquiring ownership.
//
// sched.lock must be held.
//
// May run during STW, so write barriers are not allowed.
//
//go:nowritebarrierrec
func pidleget(now int64) (*p, int64) {
assertLockHeld(&sched.lock)
pp := sched.pidle.ptr()
if pp != nil {
// Timer may get added at any time now.
if now == 0 {
now = nanotime()
}
timerpMask.set(pp.id)
idlepMask.clear(pp.id)
sched.pidle = pp.link
sched.npidle.Add(-1)
pp.limiterEvent.stop(limiterEventIdle, now)
}
return pp, now
}
// pidlegetSpinning tries to get a p from the _Pidle list, acquiring ownership.
// This is called by spinning Ms (or callers that need a spinning M) that have
// found work. If no P is available, this must be synchronized with non-spinning
// Ms that may be preparing to drop their P without discovering this work.
//
// sched.lock must be held.
//
// May run during STW, so write barriers are not allowed.
//
//go:nowritebarrierrec
func pidlegetSpinning(now int64) (*p, int64) {
assertLockHeld(&sched.lock)
pp, now := pidleget(now)
if pp == nil {
// See "Delicate dance" comment in findrunnable. We found work
// that we cannot take, we must synchronize with non-spinning
// Ms that may be preparing to drop their P.
sched.needspinning.Store(1)
return nil, now
}
return pp, now
}
// runqempty reports whether pp has no Gs on its local run queue.
// It never returns true spuriously.
func runqempty(pp *p) bool {
// Defend against a race where 1) pp has G1 in runqnext but runqhead == runqtail,
// 2) runqput on pp kicks G1 to the runq, 3) runqget on pp empties runqnext.
// Simply observing that runqhead == runqtail and then observing that runqnext == nil
// does not mean the queue is empty.
for {
head := atomic.Load(&pp.runqhead)
tail := atomic.Load(&pp.runqtail)
runnext := atomic.Loaduintptr((*uintptr)(unsafe.Pointer(&pp.runnext)))
if tail == atomic.Load(&pp.runqtail) {
return head == tail && runnext == 0
}
}
}
// To shake out latent assumptions about scheduling order,
// we introduce some randomness into scheduling decisions
// when running with the race detector.
// The need for this was made obvious by changing the
// (deterministic) scheduling order in Go 1.5 and breaking
// many poorly-written tests.
// With the randomness here, as long as the tests pass
// consistently with -race, they shouldn't have latent scheduling
// assumptions.
const randomizeScheduler = raceenabled
// runqput tries to put g on the local runnable queue.
// If next is false, runqput adds g to the tail of the runnable queue.
// If next is true, runqput puts g in the pp.runnext slot.
// If the run queue is full, runqput puts g on the global queue.
// Executed only by the owner P.
func runqput(pp *p, gp *g, next bool) {
if randomizeScheduler && next && fastrandn(2) == 0 {
next = false
}
if next {
retryNext:
oldnext := pp.runnext
if !pp.runnext.cas(oldnext, guintptr(unsafe.Pointer(gp))) {
goto retryNext
}
if oldnext == 0 {
return
}
// Kick the old runnext out to the regular run queue.
gp = oldnext.ptr()
}
retry:
h := atomic.LoadAcq(&pp.runqhead) // load-acquire, synchronize with consumers
t := pp.runqtail
if t-h < uint32(len(pp.runq)) {
pp.runq[t%uint32(len(pp.runq))].set(gp)
atomic.StoreRel(&pp.runqtail, t+1) // store-release, makes the item available for consumption
return
}
if runqputslow(pp, gp, h, t) {
return
}
// the queue is not full, now the put above must succeed
goto retry
}
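// The single-producer/multiple-consumer ring protocol above, in miniature
// (illustrative only; the real ring is pp.runq):
//
//	owner P (produce):
//		h := LoadAcq(&head)
//		if tail-h < len(q) { q[tail%len(q)].set(gp); StoreRel(&tail, tail+1) }
//
//	stealing P (consume):
//		h := LoadAcq(&head); t := LoadAcq(&tail)
//		gp := q[h%len(q)].ptr(); if CasRel(&head, h, h+1) { /* own gp */ }
//
// Only the owner writes tail, so runqput may read pp.runqtail without
// atomics; consumers race with each other on head via the cas-release.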
// Put g and a batch of work from local runnable queue on global queue.
// Executed only by the owner P.
func runqputslow(pp *p, gp *g, h, t uint32) bool {
var batch [len(pp.runq)/2 + 1]*g
// First, grab a batch from local queue.
n := t - h
n = n / 2
if n != uint32(len(pp.runq)/2) {
throw("runqputslow: queue is not full")
}
for i := uint32(0); i < n; i++ {
batch[i] = pp.runq[(h+i)%uint32(len(pp.runq))].ptr()
}
if !atomic.CasRel(&pp.runqhead, h, h+n) { // cas-release, commits consume
return false
}
batch[n] = gp
if randomizeScheduler {
for i := uint32(1); i <= n; i++ {
j := fastrandn(i + 1)
batch[i], batch[j] = batch[j], batch[i]
}
}
// Link the goroutines.
for i := uint32(0); i < n; i++ {
batch[i].schedlink.set(batch[i+1])
}
var q gQueue
q.head.set(batch[0])
q.tail.set(batch[n])
// Now put the batch on global queue.
lock(&sched.lock)
globrunqputbatch(&q, int32(n+1))
unlock(&sched.lock)
return true
}
// runqputbatch tries to put all the G's on q on the local runnable queue.
// If the queue is full, they are put on the global queue; in that case
// this will temporarily acquire the scheduler lock.
// Executed only by the owner P.
func runqputbatch(pp *p, q *gQueue, qsize int) {
h := atomic.LoadAcq(&pp.runqhead)
t := pp.runqtail
n := uint32(0)
for !q.empty() && t-h < uint32(len(pp.runq)) {
gp := q.pop()
pp.runq[t%uint32(len(pp.runq))].set(gp)
t++
n++
}
qsize -= int(n)
if randomizeScheduler {
off := func(o uint32) uint32 {
return (pp.runqtail + o) % uint32(len(pp.runq))
}
for i := uint32(1); i < n; i++ {
j := fastrandn(i + 1)
pp.runq[off(i)], pp.runq[off(j)] = pp.runq[off(j)], pp.runq[off(i)]
}
}
atomic.StoreRel(&pp.runqtail, t)
if !q.empty() {
lock(&sched.lock)
globrunqputbatch(q, int32(qsize))
unlock(&sched.lock)
}
}
// Get g from local runnable queue.
// If inheritTime is true, gp should inherit the remaining time in the
// current time slice. Otherwise, it should start a new time slice.
// Executed only by the owner P.
func runqget(pp *p) (gp *g, inheritTime bool) {
// If there's a runnext, it's the next G to run.
next := pp.runnext
// If the runnext is non-0 and the CAS fails, it could only have been stolen by another P,
// because other Ps can race to set runnext to 0, but only the current P can set it to non-0.
// Hence, there's no need to retry this CAS if it fails.
if next != 0 && pp.runnext.cas(next, 0) {
return next.ptr(), true
}
for {
h := atomic.LoadAcq(&pp.runqhead) // load-acquire, synchronize with other consumers
t := pp.runqtail
if t == h {
return nil, false
}
gp := pp.runq[h%uint32(len(pp.runq))].ptr()
if atomic.CasRel(&pp.runqhead, h, h+1) { // cas-release, commits consume
return gp, false
}
}
}
// runqdrain drains the local runnable queue of pp and returns all goroutines in it.
// Executed only by the owner P.
func runqdrain(pp *p) (drainQ gQueue, n uint32) {
oldNext := pp.runnext
if oldNext != 0 && pp.runnext.cas(oldNext, 0) {
drainQ.pushBack(oldNext.ptr())
n++
}
retry:
h := atomic.LoadAcq(&pp.runqhead) // load-acquire, synchronize with other consumers
t := pp.runqtail
qn := t - h
if qn == 0 {
return
}
if qn > uint32(len(pp.runq)) { // read inconsistent h and t
goto retry
}
if !atomic.CasRel(&pp.runqhead, h, h+qn) { // cas-release, commits consume
goto retry
}
// We've inverted the usual order of operations here: we advance the head
// pointer before copying the G's out of the local queue, rather than after,
// because we don't want to mess up the statuses of G's while runqdrain()
// and runqsteal() are running in parallel.
// By advancing the head pointer first, we update any gp.schedlink only after
// we have taken full ownership of the G; other P's can no longer reach those
// G's in the local run queue to steal them.
// See https://groups.google.com/g/golang-dev/c/0pTKxEKhHSc/m/6Q85QjdVBQAJ for more details.
for i := uint32(0); i < qn; i++ {
gp := pp.runq[(h+i)%uint32(len(pp.runq))].ptr()
drainQ.pushBack(gp)
n++
}
return
}
// Grabs a batch of goroutines from pp's runnable queue into batch.
// Batch is a ring buffer starting at batchHead.
// Returns number of grabbed goroutines.
// Can be executed by any P.
func runqgrab(pp *p, batch *[256]guintptr, batchHead uint32, stealRunNextG bool) uint32 {
for {
h := atomic.LoadAcq(&pp.runqhead) // load-acquire, synchronize with other consumers
t := atomic.LoadAcq(&pp.runqtail) // load-acquire, synchronize with the producer
n := t - h
n = n - n/2
if n == 0 {
if stealRunNextG {
// Try to steal from pp.runnext.
if next := pp.runnext; next != 0 {
if pp.status == _Prunning {
// Sleep to ensure that pp isn't about to run the g
// we are about to steal.
// The important use case here is when the g running
// on pp ready()s another g and then almost
// immediately blocks. Instead of stealing runnext
// in this window, back off to give pp a chance to
// schedule runnext. This will avoid thrashing gs
// between different Ps.
// A sync chan send/recv takes ~50ns as of time of
// writing, so 3us gives ~50x overshoot.
if GOOS != "windows" && GOOS != "openbsd" && GOOS != "netbsd" {
usleep(3)
} else {
// On some platforms system timer granularity is
// 1-15ms, which is way too much for this
// optimization. So just yield.
osyield()
}
}
if !pp.runnext.cas(next, 0) {
continue
}
batch[batchHead%uint32(len(batch))] = next
return 1
}
}
return 0
}
if n > uint32(len(pp.runq)/2) { // read inconsistent h and t
continue
}
for i := uint32(0); i < n; i++ {
g := pp.runq[(h+i)%uint32(len(pp.runq))]
batch[(batchHead+i)%uint32(len(batch))] = g
}
if atomic.CasRel(&pp.runqhead, h, h+n) { // cas-release, commits consume
return n
}
}
}
// Steal half of elements from local runnable queue of p2
// and put onto local runnable queue of p.
// Returns one of the stolen elements (or nil if failed).
func runqsteal(pp, p2 *p, stealRunNextG bool) *g {
t := pp.runqtail
n := runqgrab(p2, &pp.runq, t, stealRunNextG)
if n == 0 {
return nil
}
n--
gp := pp.runq[(t+n)%uint32(len(pp.runq))].ptr()
if n == 0 {
return gp
}
h := atomic.LoadAcq(&pp.runqhead) // load-acquire, synchronize with consumers
if t-h+n >= uint32(len(pp.runq)) {
throw("runqsteal: runq overflow")
}
atomic.StoreRel(&pp.runqtail, t+n) // store-release, makes the item available for consumption
return gp
}
// A gQueue is a deque of Gs linked through g.schedlink. A G can only
// be on one gQueue or gList at a time.
type gQueue struct {
head guintptr
tail guintptr
}
// empty reports whether q is empty.
func (q *gQueue) empty() bool {
return q.head == 0
}
// push adds gp to the head of q.
func (q *gQueue) push(gp *g) {
gp.schedlink = q.head
q.head.set(gp)
if q.tail == 0 {
q.tail.set(gp)
}
}
// pushBack adds gp to the tail of q.
func (q *gQueue) pushBack(gp *g) {
gp.schedlink = 0
if q.tail != 0 {
q.tail.ptr().schedlink.set(gp)
} else {
q.head.set(gp)
}
q.tail.set(gp)
}
// pushBackAll adds all Gs in q2 to the tail of q. After this q2 must
// not be used.
func (q *gQueue) pushBackAll(q2 gQueue) {
if q2.tail == 0 {
return
}
q2.tail.ptr().schedlink = 0
if q.tail != 0 {
q.tail.ptr().schedlink = q2.head
} else {
q.head = q2.head
}
q.tail = q2.tail
}
// pop removes and returns the head of queue q. It returns nil if
// q is empty.
func (q *gQueue) pop() *g {
gp := q.head.ptr()
if gp != nil {
q.head = gp.schedlink
if q.head == 0 {
q.tail = 0
}
}
return gp
}
// popList takes all Gs in q and returns them as a gList.
func (q *gQueue) popList() gList {
stack := gList{q.head}
*q = gQueue{}
return stack
}
// A gList is a list of Gs linked through g.schedlink. A G can only be
// on one gQueue or gList at a time.
type gList struct {
head guintptr
}
// empty reports whether l is empty.
func (l *gList) empty() bool {
return l.head == 0
}
// push adds gp to the head of l.
func (l *gList) push(gp *g) {
gp.schedlink = l.head
l.head.set(gp)
}
// pushAll prepends all Gs in q to l.
func (l *gList) pushAll(q gQueue) {
if !q.empty() {
q.tail.ptr().schedlink = l.head
l.head = q.head
}
}
// pop removes and returns the head of l. If l is empty, it returns nil.
func (l *gList) pop() *g {
gp := l.head.ptr()
if gp != nil {
l.head = gp.schedlink
}
return gp
}
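// An illustrative sketch of how gQueue and gList compose (g1, g2 are
// hypothetical *g values):
//
//	var q gQueue
//	q.pushBack(g1) // q: g1
//	q.pushBack(g2) // q: g1 -> g2
//	var l gList
//	l.pushAll(q)   // l: g1 -> g2; q must not be used again
//	gp := l.pop()  // gp == g1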
//go:linkname setMaxThreads runtime/debug.setMaxThreads
func setMaxThreads(in int) (out int) {
lock(&sched.lock)
out = int(sched.maxmcount)
if in > 0x7fffffff { // MaxInt32
sched.maxmcount = 0x7fffffff
} else {
sched.maxmcount = int32(in)
}
checkmcount()
unlock(&sched.lock)
return
}
//go:nosplit
func procPin() int {
gp := getg()
mp := gp.m
mp.locks++
return int(mp.p.ptr().id)
}
//go:nosplit
func procUnpin() {
gp := getg()
gp.m.locks--
}
//go:linkname sync_runtime_procPin sync.runtime_procPin
//go:nosplit
func sync_runtime_procPin() int {
return procPin()
}
//go:linkname sync_runtime_procUnpin sync.runtime_procUnpin
//go:nosplit
func sync_runtime_procUnpin() {
procUnpin()
}
//go:linkname sync_atomic_runtime_procPin sync/atomic.runtime_procPin
//go:nosplit
func sync_atomic_runtime_procPin() int {
return procPin()
}
//go:linkname sync_atomic_runtime_procUnpin sync/atomic.runtime_procUnpin
//go:nosplit
func sync_atomic_runtime_procUnpin() {
procUnpin()
}
// Active spinning for sync.Mutex.
//
//go:linkname sync_runtime_canSpin sync.runtime_canSpin
//go:nosplit
func sync_runtime_canSpin(i int) bool {
// sync.Mutex is cooperative, so we are conservative with spinning.
// Spin only few times and only if running on a multicore machine and
// GOMAXPROCS>1 and there is at least one other running P and local runq is empty.
// As opposed to runtime mutex we don't do passive spinning here,
// because there can be work on global runq or on other Ps.
if i >= active_spin || ncpu <= 1 || gomaxprocs <= sched.npidle.Load()+sched.nmspinning.Load()+1 {
return false
}
if p := getg().m.p.ptr(); !runqempty(p) {
return false
}
return true
}
//go:linkname sync_runtime_doSpin sync.runtime_doSpin
//go:nosplit
func sync_runtime_doSpin() {
procyield(active_spin_cnt)
}
var stealOrder randomOrder
// randomOrder/randomEnum are helper types for randomized work stealing.
// They allow enumerating all Ps in different pseudo-random orders without
// repetitions. The algorithm is based on the fact that if X and GOMAXPROCS
// are coprime, then the sequence (i + X) % GOMAXPROCS visits every value in
// [0, GOMAXPROCS) exactly once.
type randomOrder struct {
count uint32
coprimes []uint32
}
type randomEnum struct {
i uint32
count uint32
pos uint32
inc uint32
}
func (ord *randomOrder) reset(count uint32) {
ord.count = count
ord.coprimes = ord.coprimes[:0]
for i := uint32(1); i <= count; i++ {
if gcd(i, count) == 1 {
ord.coprimes = append(ord.coprimes, i)
}
}
}
func (ord *randomOrder) start(i uint32) randomEnum {
return randomEnum{
count: ord.count,
pos: i % ord.count,
inc: ord.coprimes[i/ord.count%uint32(len(ord.coprimes))],
}
}
func (enum *randomEnum) done() bool {
return enum.i == enum.count
}
func (enum *randomEnum) next() {
enum.i++
enum.pos = (enum.pos + enum.inc) % enum.count
}
func (enum *randomEnum) position() uint32 {
return enum.pos
}
func gcd(a, b uint32) uint32 {
for b != 0 {
a, b = b, a%b
}
return a
}
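// A worked example: reset(6) computes the values in [1, 6] coprime with 6,
// namely {1, 5}. An enumeration with pos = 2 and inc = 5 then visits
// 2, 1, 0, 5, 4, 3 - each of the 6 Ps exactly once, because gcd(5, 6) == 1.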
// An initTask represents the set of initializations that need to be done for a package.
// Keep in sync with ../../test/noinit.go:initTask
type initTask struct {
state uint32 // 0 = uninitialized, 1 = in progress, 2 = done
nfns uint32
// followed by nfns pcs, uintptr sized, one per init function to run
}
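// For example, an initTask with nfns = 2 occupies 8 bytes of header (the
// two uint32 fields) followed by two uintptr-sized pcs:
//
//	offset 0:         state | nfns
//	offset 8:         pc of init function 0
//	offset 8+PtrSize: pc of init function 1
//
// doInit1 below depends on this layout: add(unsafe.Pointer(t), 8) is the
// address of the first pc.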
// inittrace stores statistics for init functions which are
// updated by malloc and newproc when active is true.
var inittrace tracestat
type tracestat struct {
active bool // init tracing activation status
id uint64 // init goroutine id
allocs uint64 // heap allocations
bytes uint64 // heap allocated bytes
}
func doInit(ts []*initTask) {
for _, t := range ts {
doInit1(t)
}
}
func doInit1(t *initTask) {
switch t.state {
case 2: // fully initialized
return
case 1: // initialization in progress
throw("recursive call during initialization - linker skew")
default: // not initialized yet
t.state = 1 // initialization in progress
var (
start int64
before tracestat
)
if inittrace.active {
start = nanotime()
// Load stats non-atomically since inittrace is updated only by this init goroutine.
before = inittrace
}
if t.nfns == 0 {
// We should have pruned all of these in the linker.
throw("inittask with no functions")
}
firstFunc := add(unsafe.Pointer(t), 8)
for i := uint32(0); i < t.nfns; i++ {
p := add(firstFunc, uintptr(i)*goarch.PtrSize)
f := *(*func())(unsafe.Pointer(&p))
f()
}
if inittrace.active {
end := nanotime()
// Load stats non-atomically since inittrace is updated only by this init goroutine.
after := inittrace
f := *(*func())(unsafe.Pointer(&firstFunc))
pkg := funcpkgpath(findfunc(abi.FuncPCABIInternal(f)))
var sbuf [24]byte
print("init ", pkg, " @")
print(string(fmtNSAsMS(sbuf[:], uint64(start-runtimeInitTime))), " ms, ")
print(string(fmtNSAsMS(sbuf[:], uint64(end-start))), " ms clock, ")
print(string(itoa(sbuf[:], after.bytes-before.bytes)), " bytes, ")
print(string(itoa(sbuf[:], after.allocs-before.allocs)), " allocs")
print("\n")
}
t.state = 2 // initialization done
}
}
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"runtime/internal/atomic"
"unsafe"
)
// A profBuf is a lock-free buffer for profiling events,
// safe for concurrent use by one reader and one writer.
// The writer may be a signal handler running without a user g.
// The reader is assumed to be a user g.
//
// Each logged event corresponds to a fixed size header, a list of
// uintptrs (typically a stack), and exactly one unsafe.Pointer tag.
// The header and uintptrs are stored in the circular buffer data and the
// tag is stored in a circular buffer tags, running in parallel.
// In the circular buffer data, each event takes 2+hdrsize+len(stk)
// words: the value 2+hdrsize+len(stk), then the time of the event, then
// hdrsize words giving the fixed-size header, and then len(stk) words
// for the stack.
//
// The current effective offsets into the tags and data circular buffers
// for reading and writing are stored in the high 30 and low 32 bits of r and w.
// The bottom bits of the high 32 are additional flag bits in w, unused in r.
// "Effective" offsets means the total number of reads or writes, mod 2^length.
// The offset in the buffer is the effective offset mod the length of the buffer.
// To make wraparound mod 2^length match wraparound mod length of the buffer,
// the length of the buffer must be a power of two.
//
// If the reader catches up to the writer, a flag passed to read controls
// whether the read blocks until more data is available. A read returns a
// pointer to the buffer data itself; the caller is assumed to be done with
// that data at the next read. The read offset rNext tracks the next offset to
// be returned by read. By definition, r ≤ rNext ≤ w (before wraparound),
// and rNext is only used by the reader, so it can be accessed without atomics.
//
// If the writer gets ahead of the reader, so that the buffer fills,
// future writes are discarded and replaced in the output stream by an
// overflow entry, which has size 2+hdrsize+1, time set to the time of
// the first discarded write, a header of all zeroed words, and a "stack"
// containing one word, the number of discarded writes.
//
// Between the time the buffer fills and the buffer becomes empty enough
// to hold more data, the overflow entry is stored as a pending overflow
// entry in the fields overflow and overflowTime. The pending overflow
// entry can be turned into a real record by either the writer or the
// reader. If the writer is called to write a new record and finds that
// the output buffer has room for both the pending overflow entry and the
// new record, the writer emits the pending overflow entry and the new
// record into the buffer. If the reader is called to read data and finds
// that the output buffer is empty but that there is a pending overflow
// entry, the reader will return a synthesized record for the pending
// overflow entry.
//
// Only the writer can create or add to a pending overflow entry, but
// either the reader or the writer can clear the pending overflow entry.
// A pending overflow entry is indicated by the low 32 bits of 'overflow'
// holding the number of discarded writes, and overflowTime holding the
// time of the first discarded write. The high 32 bits of 'overflow'
// increment each time the low 32 bits transition from zero to non-zero
// or vice versa. This sequence number avoids ABA problems in the use of
// compare-and-swap to coordinate between reader and writer.
// The overflowTime is only written when the low 32 bits of overflow are
// zero, that is, only when there is no pending overflow entry, in
// preparation for creating a new one. The reader can therefore fetch and
// clear the entry atomically using
//
// for {
// overflow = load(&b.overflow)
// if uint32(overflow) == 0 {
// // no pending entry
// break
// }
// time = load(&b.overflowTime)
// if cas(&b.overflow, overflow, ((overflow>>32)+1)<<32) {
// // pending entry cleared
// break
// }
// }
// if uint32(overflow) > 0 {
// emit entry for uint32(overflow), time
// }
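//
// As a concrete example, with hdrsize = 3 a record for a two-frame stack
// occupies 2+3+2 = 7 words of data, plus exactly one tag entry:
//
//	data[0]    = 7    // total length of this record
//	data[1]    = now  // timestamp
//	data[2..4] = hdr  // fixed-size header, zero-padded
//	data[5]    = pc0  // stack
//	data[6]    = pc1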
type profBuf struct {
// accessed atomically
r, w profAtomic
overflow atomic.Uint64
overflowTime atomic.Uint64
eof atomic.Uint32
// immutable (excluding slice content)
hdrsize uintptr
data []uint64
tags []unsafe.Pointer
// owned by reader
rNext profIndex
overflowBuf []uint64 // for use by reader to return overflow record
wait note
}
// A profAtomic is the atomically-accessed word holding a profIndex.
type profAtomic uint64
// A profIndex is the packed tag and data counts and flag bits, described above.
type profIndex uint64
const (
profReaderSleeping profIndex = 1 << 32 // reader is sleeping and must be woken up
profWriteExtra profIndex = 1 << 33 // overflow or eof waiting
)
func (x *profAtomic) load() profIndex {
return profIndex(atomic.Load64((*uint64)(x)))
}
func (x *profAtomic) store(new profIndex) {
atomic.Store64((*uint64)(x), uint64(new))
}
func (x *profAtomic) cas(old, new profIndex) bool {
return atomic.Cas64((*uint64)(x), uint64(old), uint64(new))
}
func (x profIndex) dataCount() uint32 {
return uint32(x)
}
func (x profIndex) tagCount() uint32 {
return uint32(x >> 34)
}
// countSub subtracts two counts obtained from profIndex.dataCount or profIndex.tagCount,
// assuming that they are no more than 2^29 apart (guaranteed since they are never more than
// len(data) or len(tags) apart, respectively).
// tagCount wraps at 2^30, while dataCount wraps at 2^32.
// This function works for both.
func countSub(x, y uint32) int {
// x-y is 32-bit signed or 30-bit signed; sign-extend to 32 bits and convert to int.
return int(int32(x-y) << 2 >> 2)
}
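// A worked example of the wraparound case: suppose tagCount (which wraps
// at 2^30) has x = 2 and y = 1<<30 - 2, so x is logically 4 ahead of y.
// Then x-y = 0xC0000004 as a uint32, and int32(x-y)<<2>>2 shifts away the
// two garbage high bits and sign-extends, so countSub(x, y) == 4.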
// addCountsAndClearFlags returns the packed form of "x + (data, tag) - all flags".
func (x profIndex) addCountsAndClearFlags(data, tag int) profIndex {
return profIndex((uint64(x)>>34+uint64(uint32(tag)<<2>>2))<<34 | uint64(uint32(x)+uint32(data)))
}
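// In other words, given the profIndex layout (tag count in bits 34-63,
// flag bits 32-33, data count in bits 0-31), the expression above rebuilds
// the word as
//
//	((oldTag + tag) << 34) | uint32(oldData + data)
//
// leaving the two flag bits zero. The uint32(tag)<<2>>2 truncates the tag
// delta to 30 bits so the addition wraps the same way tagCount does.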
// hasOverflow reports whether b has any overflow records pending.
func (b *profBuf) hasOverflow() bool {
return uint32(b.overflow.Load()) > 0
}
// takeOverflow consumes the pending overflow records, returning the overflow count
// and the time of the first overflow.
// When called by the reader, it is racing against incrementOverflow.
func (b *profBuf) takeOverflow() (count uint32, time uint64) {
overflow := b.overflow.Load()
time = b.overflowTime.Load()
for {
count = uint32(overflow)
if count == 0 {
time = 0
break
}
// Increment generation, clear overflow count in low bits.
if b.overflow.CompareAndSwap(overflow, ((overflow>>32)+1)<<32) {
break
}
overflow = b.overflow.Load()
time = b.overflowTime.Load()
}
return uint32(overflow), time
}
// incrementOverflow records a single overflow at time now.
// It is racing against a possible takeOverflow in the reader.
func (b *profBuf) incrementOverflow(now int64) {
for {
overflow := b.overflow.Load()
// Once we see b.overflow reach 0, it's stable: no one else is changing it underfoot.
// We need to set overflowTime if we're incrementing b.overflow from 0.
if uint32(overflow) == 0 {
// Store overflowTime first so it's always available when overflow != 0.
b.overflowTime.Store(uint64(now))
b.overflow.Store((((overflow >> 32) + 1) << 32) + 1)
break
}
// Otherwise we're racing to increment against reader
// who wants to set b.overflow to 0.
// Out of paranoia, leave 2³²-1 a sticky overflow value,
// to avoid wrapping around. Extremely unlikely.
if int32(overflow) == -1 {
break
}
if b.overflow.CompareAndSwap(overflow, overflow+1) {
break
}
}
}
// newProfBuf returns a new profiling buffer with room for
// a header of hdrsize words and a buffer of at least bufwords words.
func newProfBuf(hdrsize, bufwords, tags int) *profBuf {
if min := 2 + hdrsize + 1; bufwords < min {
bufwords = min
}
// Buffer sizes must be powers of two, so that we don't have to
// worry about uint32 wraparound changing the effective position
// within the buffers. We store 30 bits of count; limiting to 28
// gives us some room for intermediate calculations.
if bufwords >= 1<<28 || tags >= 1<<28 {
throw("newProfBuf: buffer too large")
}
var i int
for i = 1; i < bufwords; i <<= 1 {
}
bufwords = i
for i = 1; i < tags; i <<= 1 {
}
tags = i
b := new(profBuf)
b.hdrsize = uintptr(hdrsize)
b.data = make([]uint64, bufwords)
b.tags = make([]unsafe.Pointer, tags)
b.overflowBuf = make([]uint64, 2+b.hdrsize+1)
return b
}
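// For example, newProfBuf(3, 100, 10) rounds bufwords up to 128 and tags
// up to 16 (the next powers of two), and allocates a 2+3+1 = 6-word
// overflowBuf for synthesized overflow records.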
// canWriteRecord reports whether the buffer has room
// for a single contiguous record with a stack of length nstk.
func (b *profBuf) canWriteRecord(nstk int) bool {
br := b.r.load()
bw := b.w.load()
// room for tag?
if countSub(br.tagCount(), bw.tagCount())+len(b.tags) < 1 {
return false
}
// room for data?
nd := countSub(br.dataCount(), bw.dataCount()) + len(b.data)
want := 2 + int(b.hdrsize) + nstk
i := int(bw.dataCount() % uint32(len(b.data)))
if i+want > len(b.data) {
// Can't fit in trailing fragment of slice.
// Skip over that and start over at beginning of slice.
nd -= len(b.data) - i
}
return nd >= want
}
// canWriteTwoRecords reports whether the buffer has room
// for two records with stack lengths nstk1, nstk2, in that order.
// Each record must be contiguous on its own, but the two
// records need not be contiguous (one can be at the end of the buffer
// and the other can wrap around and start at the beginning of the buffer).
func (b *profBuf) canWriteTwoRecords(nstk1, nstk2 int) bool {
br := b.r.load()
bw := b.w.load()
// room for tag?
if countSub(br.tagCount(), bw.tagCount())+len(b.tags) < 2 {
return false
}
// room for data?
nd := countSub(br.dataCount(), bw.dataCount()) + len(b.data)
// first record
want := 2 + int(b.hdrsize) + nstk1
i := int(bw.dataCount() % uint32(len(b.data)))
if i+want > len(b.data) {
// Can't fit in trailing fragment of slice.
// Skip over that and start over at beginning of slice.
nd -= len(b.data) - i
i = 0
}
i += want
nd -= want
// second record
want = 2 + int(b.hdrsize) + nstk2
if i+want > len(b.data) {
// Can't fit in trailing fragment of slice.
// Skip over that and start over at beginning of slice.
nd -= len(b.data) - i
i = 0
}
return nd >= want
}
// write writes an entry to the profiling buffer b.
// The entry begins with a fixed hdr, which must have
// length b.hdrsize, followed by a variable-sized stack
// and a single tag pointer *tagPtr (or nil if tagPtr is nil).
// No write barriers allowed because this might be called from a signal handler.
func (b *profBuf) write(tagPtr *unsafe.Pointer, now int64, hdr []uint64, stk []uintptr) {
if b == nil {
return
}
if len(hdr) > int(b.hdrsize) {
throw("misuse of profBuf.write")
}
if hasOverflow := b.hasOverflow(); hasOverflow && b.canWriteTwoRecords(1, len(stk)) {
// Room for both an overflow record and the one being written.
// Write the overflow record if the reader hasn't gotten to it yet.
// Only racing against reader, not other writers.
count, time := b.takeOverflow()
if count > 0 {
var stk [1]uintptr
stk[0] = uintptr(count)
b.write(nil, int64(time), nil, stk[:])
}
} else if hasOverflow || !b.canWriteRecord(len(stk)) {
// Pending overflow without room to write overflow and new records
// or no overflow but also no room for new record.
b.incrementOverflow(now)
b.wakeupExtra()
return
}
// There's room: write the record.
br := b.r.load()
bw := b.w.load()
// Profiling tag
//
// The tag is a pointer, but we can't run a write barrier here.
// We have interrupted the OS-level execution of gp, but the
// runtime still sees gp as executing. In effect, we are running
// in place of the real gp. Since gp is the only goroutine that
// can overwrite gp.labels, the value of gp.labels is stable during
// this signal handler: it will still be reachable from gp when
// we finish executing. If a GC is in progress right now, it must
// keep gp.labels alive, because gp.labels is reachable from gp.
// If gp were to overwrite gp.labels, the deletion barrier would
// still shade that pointer, which would preserve it for the
// in-progress GC, so all is well. Any future GC will see the
// value we copied when scanning b.tags (heap-allocated).
// We arrange that the store here is always overwriting a nil,
// so there is no need for a deletion barrier on b.tags[wt].
wt := int(bw.tagCount() % uint32(len(b.tags)))
if tagPtr != nil {
*(*uintptr)(unsafe.Pointer(&b.tags[wt])) = uintptr(unsafe.Pointer(*tagPtr))
}
// Main record.
// It has to fit in a contiguous section of the slice, so if it doesn't fit at the end,
// leave a rewind marker (0) and start over at the beginning of the slice.
wd := int(bw.dataCount() % uint32(len(b.data)))
nd := countSub(br.dataCount(), bw.dataCount()) + len(b.data)
skip := 0
if wd+2+int(b.hdrsize)+len(stk) > len(b.data) {
b.data[wd] = 0
skip = len(b.data) - wd
nd -= skip
wd = 0
}
data := b.data[wd:]
data[0] = uint64(2 + b.hdrsize + uintptr(len(stk))) // length
data[1] = uint64(now) // time stamp
// header, zero-padded
i := uintptr(copy(data[2:2+b.hdrsize], hdr))
for ; i < b.hdrsize; i++ {
data[2+i] = 0
}
for i, pc := range stk {
data[2+b.hdrsize+uintptr(i)] = uint64(pc)
}
for {
// Commit write.
// Racing with reader setting flag bits in b.w, to avoid lost wakeups.
old := b.w.load()
new := old.addCountsAndClearFlags(skip+2+len(stk)+int(b.hdrsize), 1)
if !b.w.cas(old, new) {
continue
}
// If there was a reader, wake it up.
if old&profReaderSleeping != 0 {
notewakeup(&b.wait)
}
break
}
}
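// For illustration, a sketch of walking records in the layout written
// above: each record is [length, timestamp, hdr[0:hdrsize], stack...],
// where length counts every word including the two-word prefix. It assumes
// input as returned by profBuf.read, which never includes the 0-length
// rewind markers stored in the ring itself (walkRecords is hypothetical):
//
//	func walkRecords(data []uint64, hdrsize int) {
//		for i := 0; i < len(data); {
//			n := int(data[i])              // total words in this record
//			ts := data[i+1]                // timestamp
//			hdr := data[i+2 : i+2+hdrsize] // fixed header, zero-padded
//			stk := data[i+2+hdrsize : i+n] // call stack PCs
//			_, _, _ = ts, hdr, stk         // ... process the record ...
//			i += n
//		}
//	}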
// close signals that there will be no more writes on the buffer.
// Once all the data has been read from the buffer, reads will return eof=true.
func (b *profBuf) close() {
if b.eof.Load() > 0 {
throw("runtime: profBuf already closed")
}
b.eof.Store(1)
b.wakeupExtra()
}
// wakeupExtra must be called after setting one of the "extra"
// atomic fields b.overflow or b.eof.
// It records the change in b.w and wakes up the reader if needed.
func (b *profBuf) wakeupExtra() {
for {
old := b.w.load()
new := old | profWriteExtra
if !b.w.cas(old, new) {
continue
}
if old&profReaderSleeping != 0 {
notewakeup(&b.wait)
}
break
}
}
// profBufReadMode specifies whether to block when no data is available to read.
type profBufReadMode int
const (
profBufBlocking profBufReadMode = iota
profBufNonBlocking
)
var overflowTag [1]unsafe.Pointer // always nil
func (b *profBuf) read(mode profBufReadMode) (data []uint64, tags []unsafe.Pointer, eof bool) {
if b == nil {
return nil, nil, true
}
br := b.rNext
// Commit previous read, returning that part of the ring to the writer.
// First clear tags that have now been read, both to avoid holding
// up the memory they point at for longer than necessary
// and so that b.write can assume it is always overwriting
// nil tag entries (see comment in b.write).
rPrev := b.r.load()
if rPrev != br {
ntag := countSub(br.tagCount(), rPrev.tagCount())
ti := int(rPrev.tagCount() % uint32(len(b.tags)))
for i := 0; i < ntag; i++ {
b.tags[ti] = nil
if ti++; ti == len(b.tags) {
ti = 0
}
}
b.r.store(br)
}
Read:
bw := b.w.load()
numData := countSub(bw.dataCount(), br.dataCount())
if numData == 0 {
if b.hasOverflow() {
// No data to read, but there is overflow to report.
// Racing with writer flushing b.overflow into a real record.
count, time := b.takeOverflow()
if count == 0 {
// Lost the race, go around again.
goto Read
}
// Won the race, report overflow.
dst := b.overflowBuf
dst[0] = uint64(2 + b.hdrsize + 1)
dst[1] = uint64(time)
for i := uintptr(0); i < b.hdrsize; i++ {
dst[2+i] = 0
}
dst[2+b.hdrsize] = uint64(count)
return dst[:2+b.hdrsize+1], overflowTag[:1], false
}
if b.eof.Load() > 0 {
// No data, no overflow, EOF set: done.
return nil, nil, true
}
if bw&profWriteExtra != 0 {
// Writer claims to have published extra information (overflow or eof).
// Attempt to clear notification and then check again.
// If we fail to clear the notification it means b.w changed,
// so we still need to check again.
b.w.cas(bw, bw&^profWriteExtra)
goto Read
}
// Nothing to read right now.
// Return or sleep according to mode.
if mode == profBufNonBlocking {
return nil, nil, false
}
if !b.w.cas(bw, bw|profReaderSleeping) {
goto Read
}
// Committed to sleeping.
notetsleepg(&b.wait, -1)
noteclear(&b.wait)
goto Read
}
data = b.data[br.dataCount()%uint32(len(b.data)):]
if len(data) > numData {
data = data[:numData]
} else {
numData -= len(data) // available in case of wraparound
}
skip := 0
if data[0] == 0 {
// Wraparound record. Go back to the beginning of the ring.
skip = len(data)
data = b.data
if len(data) > numData {
data = data[:numData]
}
}
ntag := countSub(bw.tagCount(), br.tagCount())
if ntag == 0 {
throw("runtime: malformed profBuf buffer - tag and data out of sync")
}
tags = b.tags[br.tagCount()%uint32(len(b.tags)):]
if len(tags) > ntag {
tags = tags[:ntag]
}
// Count out whole data records until either data or tags is done.
// They are always in sync in the buffer, but due to an end-of-slice
// wraparound we might need to stop early and return the rest
// in the next call.
di := 0
ti := 0
for di < len(data) && data[di] != 0 && ti < len(tags) {
if uintptr(di)+uintptr(data[di]) > uintptr(len(data)) {
throw("runtime: malformed profBuf buffer - invalid size")
}
di += int(data[di])
ti++
}
// Remember how much we returned, to commit read on next call.
b.rNext = br.addCountsAndClearFlags(skip+di, ti)
if raceenabled {
// Match racereleasemerge in runtime_setProfLabel,
// so that the setting of the labels in runtime_setProfLabel
// is treated as happening before any use of the labels
// by our caller. The synchronization on labelSync itself is a fiction
// for the race detector. The actual synchronization is handled
// by the fact that the signal handler only reads from the current
// goroutine and uses atomics to write the updated queue indices,
// and then the read-out from the signal handler buffer uses
// atomics to read those queue indices.
raceacquire(unsafe.Pointer(&labelSync))
}
return data[:di], tags[:ti], false
}
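// For illustration, a consumer-side sketch of the read loop, shaped like
// the profile reader in runtime/pprof (the surrounding plumbing is
// omitted):
//
//	for {
//		data, tags, eof := b.read(profBufBlocking)
//		if len(data) > 0 {
//			// data holds whole records (see walkRecords above);
//			// tags[i] is the label pointer for the i'th record.
//		}
//		if eof {
//			break
//		}
//	}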
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import "unsafe"
var labelSync uintptr
//go:linkname runtime_setProfLabel runtime/pprof.runtime_setProfLabel
func runtime_setProfLabel(labels unsafe.Pointer) {
// Introduce race edge for read-back via profile.
// This would more properly use &getg().labels as the sync address,
// but we do the read in a signal handler and can't call the race runtime then.
//
// This uses racereleasemerge rather than just racerelease so
// the acquire in profBuf.read synchronizes with *all* prior
// setProfLabel operations, not just the most recent one. This
// is important because profBuf.read will observe different
// labels set by different setProfLabel operations on
// different goroutines, so it needs to synchronize with all
// of them (this wouldn't be an issue if we could synchronize
// on &getg().labels since we would synchronize with each
// most-recent labels write separately.)
//
// racereleasemerge is like a full read-modify-write on
// labelSync, rather than just a store-release, so it carries
// a dependency on the previous racereleasemerge, which
// ultimately carries forward to the acquire in profBuf.read.
if raceenabled {
racereleasemerge(unsafe.Pointer(&labelSync))
}
getg().labels = labels
}
//go:linkname runtime_getProfLabel runtime/pprof.runtime_getProfLabel
func runtime_getProfLabel() unsafe.Pointer {
return getg().labels
}
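// These linknamed hooks back the public label API in runtime/pprof. A
// typical user-level sketch:
//
//	import (
//		"context"
//		"runtime/pprof"
//	)
//
//	func work(ctx context.Context) {
//		pprof.Do(ctx, pprof.Labels("worker", "indexer"), func(ctx context.Context) {
//			// CPU samples taken here carry worker=indexer, delivered to
//			// the profile reader through the profBuf tag slots above.
//		})
//	}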
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !race
// Dummy race detection API, used when not built with -race.
package runtime
import (
"unsafe"
)
const raceenabled = false
// Because raceenabled is false, none of these functions should be called.
func raceReadObjectPC(t *_type, addr unsafe.Pointer, callerpc, pc uintptr) { throw("race") }
func raceWriteObjectPC(t *_type, addr unsafe.Pointer, callerpc, pc uintptr) { throw("race") }
func raceinit() (uintptr, uintptr) { throw("race"); return 0, 0 }
func racefini() { throw("race") }
func raceproccreate() uintptr { throw("race"); return 0 }
func raceprocdestroy(ctx uintptr) { throw("race") }
func racemapshadow(addr unsafe.Pointer, size uintptr) { throw("race") }
func racewritepc(addr unsafe.Pointer, callerpc, pc uintptr) { throw("race") }
func racereadpc(addr unsafe.Pointer, callerpc, pc uintptr) { throw("race") }
func racereadrangepc(addr unsafe.Pointer, sz, callerpc, pc uintptr) { throw("race") }
func racewriterangepc(addr unsafe.Pointer, sz, callerpc, pc uintptr) { throw("race") }
func raceacquire(addr unsafe.Pointer) { throw("race") }
func raceacquireg(gp *g, addr unsafe.Pointer) { throw("race") }
func raceacquirectx(racectx uintptr, addr unsafe.Pointer) { throw("race") }
func racerelease(addr unsafe.Pointer) { throw("race") }
func racereleaseg(gp *g, addr unsafe.Pointer) { throw("race") }
func racereleaseacquire(addr unsafe.Pointer) { throw("race") }
func racereleaseacquireg(gp *g, addr unsafe.Pointer) { throw("race") }
func racereleasemerge(addr unsafe.Pointer) { throw("race") }
func racereleasemergeg(gp *g, addr unsafe.Pointer) { throw("race") }
func racefingo() { throw("race") }
func racemalloc(p unsafe.Pointer, sz uintptr) { throw("race") }
func racefree(p unsafe.Pointer, sz uintptr) { throw("race") }
func racegostart(pc uintptr) uintptr { throw("race"); return 0 }
func racegoend() { throw("race") }
func racectxend(racectx uintptr) { throw("race") }
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import _ "unsafe" // for go:linkname
//go:linkname setMaxStack runtime/debug.setMaxStack
func setMaxStack(in int) (out int) {
out = int(maxstacksize)
maxstacksize = uintptr(in)
return out
}
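// setMaxStack is surfaced as runtime/debug.SetMaxStack. A usage sketch
// (the limit value is illustrative):
//
//	import "runtime/debug"
//
//	func init() {
//		old := debug.SetMaxStack(64 << 20) // cap goroutine stacks at 64 MiB
//		_ = old                            // previous limit in bytes
//	}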
//go:linkname setPanicOnFault runtime/debug.setPanicOnFault
func setPanicOnFault(new bool) (old bool) {
gp := getg()
old = gp.paniconfault
gp.paniconfault = new
return old
}
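// setPanicOnFault is surfaced as runtime/debug.SetPanicOnFault; it turns an
// unexpected fault at a non-nil address into a recoverable panic on the
// current goroutine. A usage sketch (probe is hypothetical):
//
//	import "runtime/debug"
//
//	func probe(p *byte) (ok bool) {
//		old := debug.SetPanicOnFault(true)
//		defer debug.SetPanicOnFault(old)
//		defer func() { ok = recover() == nil }()
//		_ = *p // a fault here panics instead of crashing the process
//		return true
//	}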
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !windows
package runtime
// osRelaxMinNS is the number of nanoseconds of idleness to tolerate
// without performing an osRelax. Since osRelax may reduce the
// precision of timers, this should be enough larger than the relaxed
// timer precision to keep the timer error acceptable.
const osRelaxMinNS = 0
// osRelax is called by the scheduler when transitioning to and from
// all Ps being idle.
func osRelax(relax bool) {}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix
package runtime
// retryOnEAGAIN retries a function until it does not return EAGAIN.
// It will use an increasing delay between calls, and retry up to 20 times.
// The function argument is expected to return an errno value,
// and retryOnEAGAIN will return any errno value other than EAGAIN.
// If all retries return EAGAIN, then retryOnEAGAIN will return EAGAIN.
func retryOnEAGAIN(fn func() int32) int32 {
for tries := 0; tries < 20; tries++ {
errno := fn()
if errno != _EAGAIN {
return errno
}
usleep_no_g(uint32(tries+1) * 1000) // sleep (tries+1) milliseconds (usleep_no_g takes microseconds)
}
return _EAGAIN
}
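// For comparison, a portable user-space sketch of the same backoff pattern,
// using the syscall and time packages in place of the runtime-internal
// usleep_no_g:
//
//	import (
//		"syscall"
//		"time"
//	)
//
//	func retryEAGAIN(fn func() syscall.Errno) syscall.Errno {
//		for tries := 0; tries < 20; tries++ {
//			if errno := fn(); errno != syscall.EAGAIN {
//				return errno
//			}
//			time.Sleep(time.Duration(tries+1) * time.Millisecond)
//		}
//		return syscall.EAGAIN
//	}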
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"runtime/internal/atomic"
"unsafe"
)
//go:generate go run wincallback.go
//go:generate go run mkduff.go
//go:generate go run mkfastlog2table.go
//go:generate go run mklockrank.go -o lockrank.go
var ticks ticksType
type ticksType struct {
lock mutex
val atomic.Int64
}
// Note: Called by runtime/pprof in addition to runtime code.
func tickspersecond() int64 {
r := ticks.val.Load()
if r != 0 {
return r
}
lock(&ticks.lock)
r = ticks.val.Load()
if r == 0 {
t0 := nanotime()
c0 := cputicks()
usleep(100 * 1000)
t1 := nanotime()
c1 := cputicks()
if t1 == t0 {
t1++
}
r = (c1 - c0) * 1000 * 1000 * 1000 / (t1 - t0)
if r == 0 {
r++
}
ticks.val.Store(r)
}
unlock(&ticks.lock)
return r
}
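// tickspersecond is classic double-checked locking: an atomic fast path,
// then a re-check under the lock so only one caller pays for the 100ms
// calibration sleep. A generic user-space sketch of the same shape (zero is
// the "not yet computed" sentinel, as it is above):
//
//	import (
//		"sync"
//		"sync/atomic"
//	)
//
//	var (
//		mu     sync.Mutex
//		cached atomic.Int64
//	)
//
//	func value(compute func() int64) int64 {
//		if v := cached.Load(); v != 0 {
//			return v // fast path: no lock
//		}
//		mu.Lock()
//		defer mu.Unlock()
//		if v := cached.Load(); v != 0 {
//			return v // computed while we waited for the lock
//		}
//		v := compute() // must return non-zero
//		cached.Store(v)
//		return v
//	}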
var envs []string
var argslice []string
//go:linkname syscall_runtime_envs syscall.runtime_envs
func syscall_runtime_envs() []string { return append([]string{}, envs...) }
//go:linkname syscall_Getpagesize syscall.Getpagesize
func syscall_Getpagesize() int { return int(physPageSize) }
//go:linkname os_runtime_args os.runtime_args
func os_runtime_args() []string { return append([]string{}, argslice...) }
//go:linkname syscall_Exit syscall.Exit
//go:nosplit
func syscall_Exit(code int) {
exit(int32(code))
}
var godebugDefault string
var godebugUpdate atomic.Pointer[func(string, string)]
var godebugEnv atomic.Pointer[string] // set by parsedebugvars
var godebugNewIncNonDefault atomic.Pointer[func(string) func()]
//go:linkname godebug_setUpdate internal/godebug.setUpdate
func godebug_setUpdate(update func(string, string)) {
p := new(func(string, string))
*p = update
godebugUpdate.Store(p)
godebugNotify(false)
}
//go:linkname godebug_setNewIncNonDefault internal/godebug.setNewIncNonDefault
func godebug_setNewIncNonDefault(newIncNonDefault func(string) func()) {
p := new(func(string) func())
*p = newIncNonDefault
godebugNewIncNonDefault.Store(p)
}
// A godebugInc provides access to internal/godebug's IncNonDefault function
// for a given GODEBUG setting.
// Calls before internal/godebug registers itself are dropped on the floor.
type godebugInc struct {
name string
inc atomic.Pointer[func()]
}
func (g *godebugInc) IncNonDefault() {
inc := g.inc.Load()
if inc == nil {
newInc := godebugNewIncNonDefault.Load()
if newInc == nil {
return
}
// If other goroutines are racing here, no big deal. One will win,
// and all the inc functions will be using the same underlying
// *godebug.Setting.
inc = new(func())
*inc = (*newInc)(g.name)
g.inc.Store(inc)
}
(*inc)()
}
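// For illustration, a stripped-down user-space sketch of the lazy
// atomic.Pointer[func()] caching that IncNonDefault uses (lazyFn and mk are
// hypothetical):
//
//	import "sync/atomic"
//
//	type lazyFn struct {
//		fn atomic.Pointer[func()]
//	}
//
//	func (l *lazyFn) call(mk func() func()) {
//		f := l.fn.Load()
//		if f == nil {
//			nf := mk() // racing initializers are benign: results are equivalent
//			f = &nf
//			l.fn.Store(f)
//		}
//		(*f)()
//	}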
func godebugNotify(envChanged bool) {
update := godebugUpdate.Load()
var env string
if p := godebugEnv.Load(); p != nil {
env = *p
}
if envChanged {
reparsedebugvars(env)
}
if update != nil {
(*update)(godebugDefault, env)
}
}
//go:linkname syscall_runtimeSetenv syscall.runtimeSetenv
func syscall_runtimeSetenv(key, value string) {
setenv_c(key, value)
if key == "GODEBUG" {
p := new(string)
*p = value
godebugEnv.Store(p)
godebugNotify(true)
}
}
//go:linkname syscall_runtimeUnsetenv syscall.runtimeUnsetenv
func syscall_runtimeUnsetenv(key string) {
unsetenv_c(key)
if key == "GODEBUG" {
godebugEnv.Store(nil)
godebugNotify(true)
}
}
// writeErrStr writes a string to descriptor 2.
//
//go:nosplit
func writeErrStr(s string) {
write(2, unsafe.Pointer(unsafe.StringData(s)), int32(len(s)))
}
// auxv is populated on relevant platforms but defined here for all platforms
// so x/sys/cpu can assume the getAuxv symbol exists without keeping its list
// of auxv-using GOOS build tags in sync.
//
// It contains an even number of elements, (tag, value) pairs.
var auxv []uintptr
func getAuxv() []uintptr { return auxv } // accessed from x/sys/cpu; see issue 57336
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"internal/bytealg"
"internal/goarch"
"runtime/internal/atomic"
"unsafe"
)
// Keep a cached value to make gotraceback fast,
// since we call it on every call to gentraceback.
// The cached value is a uint32 in which the low bits
// are the "crash" and "all" settings and the remaining
// bits are the traceback value (0 off, 1 on, 2 include system).
const (
tracebackCrash = 1 << iota
tracebackAll
tracebackShift = iota
)
var traceback_cache uint32 = 2 << tracebackShift
var traceback_env uint32
// gotraceback returns the current traceback settings.
//
// If level is 0, suppress all tracebacks.
// If level is 1, show tracebacks, but exclude runtime frames.
// If level is 2, show tracebacks including runtime frames.
// If all is set, print all goroutine stacks. Otherwise, print just the current goroutine.
// If crash is set, crash (core dump, etc) after tracebacking.
//
//go:nosplit
func gotraceback() (level int32, all, crash bool) {
gp := getg()
t := atomic.Load(&traceback_cache)
crash = t&tracebackCrash != 0
all = gp.m.throwing >= throwTypeUser || t&tracebackAll != 0
if gp.m.traceback != 0 {
level = int32(gp.m.traceback)
} else if gp.m.throwing >= throwTypeRuntime {
// Always include runtime frames in runtime throws unless
// otherwise overridden by m.traceback.
level = 2
} else {
level = int32(t >> tracebackShift)
}
return
}
var (
argc int32
argv **byte
)
// nosplit for use in linux startup sysargs.
//
//go:nosplit
func argv_index(argv **byte, i int32) *byte {
return *(**byte)(add(unsafe.Pointer(argv), uintptr(i)*goarch.PtrSize))
}
func args(c int32, v **byte) {
argc = c
argv = v
sysargs(c, v)
}
func goargs() {
if GOOS == "windows" {
return
}
argslice = make([]string, argc)
for i := int32(0); i < argc; i++ {
argslice[i] = gostringnocopy(argv_index(argv, i))
}
}
func goenvs_unix() {
// TODO(austin): ppc64 in dynamic linking mode doesn't
// guarantee env[] will immediately follow argv. Might cause
// problems.
n := int32(0)
for argv_index(argv, argc+1+n) != nil {
n++
}
envs = make([]string, n)
for i := int32(0); i < n; i++ {
envs[i] = gostring(argv_index(argv, argc+1+i))
}
}
func environ() []string {
return envs
}
// TODO: These should be locals in testAtomic64, but we don't 8-byte
// align stack variables on 386.
var test_z64, test_x64 uint64
func testAtomic64() {
test_z64 = 42
test_x64 = 0
if atomic.Cas64(&test_z64, test_x64, 1) {
throw("cas64 failed")
}
if test_x64 != 0 {
throw("cas64 failed")
}
test_x64 = 42
if !atomic.Cas64(&test_z64, test_x64, 1) {
throw("cas64 failed")
}
if test_x64 != 42 || test_z64 != 1 {
throw("cas64 failed")
}
if atomic.Load64(&test_z64) != 1 {
throw("load64 failed")
}
atomic.Store64(&test_z64, (1<<40)+1)
if atomic.Load64(&test_z64) != (1<<40)+1 {
throw("store64 failed")
}
if atomic.Xadd64(&test_z64, (1<<40)+1) != (2<<40)+2 {
throw("xadd64 failed")
}
if atomic.Load64(&test_z64) != (2<<40)+2 {
throw("xadd64 failed")
}
if atomic.Xchg64(&test_z64, (3<<40)+3) != (2<<40)+2 {
throw("xchg64 failed")
}
if atomic.Load64(&test_z64) != (3<<40)+3 {
throw("xchg64 failed")
}
}
func check() {
var (
a int8
b uint8
c int16
d uint16
e int32
f uint32
g int64
h uint64
i, i1 float32
j, j1 float64
k unsafe.Pointer
l *uint16
m [4]byte
)
type x1t struct {
x uint8
}
type y1t struct {
x1 x1t
y uint8
}
var x1 x1t
var y1 y1t
if unsafe.Sizeof(a) != 1 {
throw("bad a")
}
if unsafe.Sizeof(b) != 1 {
throw("bad b")
}
if unsafe.Sizeof(c) != 2 {
throw("bad c")
}
if unsafe.Sizeof(d) != 2 {
throw("bad d")
}
if unsafe.Sizeof(e) != 4 {
throw("bad e")
}
if unsafe.Sizeof(f) != 4 {
throw("bad f")
}
if unsafe.Sizeof(g) != 8 {
throw("bad g")
}
if unsafe.Sizeof(h) != 8 {
throw("bad h")
}
if unsafe.Sizeof(i) != 4 {
throw("bad i")
}
if unsafe.Sizeof(j) != 8 {
throw("bad j")
}
if unsafe.Sizeof(k) != goarch.PtrSize {
throw("bad k")
}
if unsafe.Sizeof(l) != goarch.PtrSize {
throw("bad l")
}
if unsafe.Sizeof(x1) != 1 {
throw("bad unsafe.Sizeof x1")
}
if unsafe.Offsetof(y1.y) != 1 {
throw("bad offsetof y1.y")
}
if unsafe.Sizeof(y1) != 2 {
throw("bad unsafe.Sizeof y1")
}
if timediv(12345*1000000000+54321, 1000000000, &e) != 12345 || e != 54321 {
throw("bad timediv")
}
var z uint32
z = 1
if !atomic.Cas(&z, 1, 2) {
throw("cas1")
}
if z != 2 {
throw("cas2")
}
z = 4
if atomic.Cas(&z, 5, 6) {
throw("cas3")
}
if z != 4 {
throw("cas4")
}
z = 0xffffffff
if !atomic.Cas(&z, 0xffffffff, 0xfffffffe) {
throw("cas5")
}
if z != 0xfffffffe {
throw("cas6")
}
m = [4]byte{1, 1, 1, 1}
atomic.Or8(&m[1], 0xf0)
if m[0] != 1 || m[1] != 0xf1 || m[2] != 1 || m[3] != 1 {
throw("atomicor8")
}
m = [4]byte{0xff, 0xff, 0xff, 0xff}
atomic.And8(&m[1], 0x1)
if m[0] != 0xff || m[1] != 0x1 || m[2] != 0xff || m[3] != 0xff {
throw("atomicand8")
}
*(*uint64)(unsafe.Pointer(&j)) = ^uint64(0)
if j == j {
throw("float64nan")
}
if !(j != j) {
throw("float64nan1")
}
*(*uint64)(unsafe.Pointer(&j1)) = ^uint64(1)
if j == j1 {
throw("float64nan2")
}
if !(j != j1) {
throw("float64nan3")
}
*(*uint32)(unsafe.Pointer(&i)) = ^uint32(0)
if i == i {
throw("float32nan")
}
if !(i != i) {
throw("float32nan1")
}
*(*uint32)(unsafe.Pointer(&i1)) = ^uint32(1)
if i == i1 {
throw("float32nan2")
}
if !(i != i1) {
throw("float32nan3")
}
testAtomic64()
if _FixedStack != round2(_FixedStack) {
throw("FixedStack is not power-of-2")
}
if !checkASM() {
throw("assembly checks failed")
}
}
type dbgVar struct {
name string
value *int32 // for variables that can only be set at startup
atomic *atomic.Int32 // for variables that can be changed during execution
def int32 // default value (ideally zero)
}
// Holds variables parsed from GODEBUG env var,
// except for "memprofilerate" since there is an
// existing int var for that value, which may
// already have an initial value.
var debug struct {
cgocheck int32
clobberfree int32
efence int32
gccheckmark int32
gcpacertrace int32
gcshrinkstackoff int32
gcstoptheworld int32
gctrace int32
invalidptr int32
madvdontneed int32 // for Linux; issue 28466
scavtrace int32
scheddetail int32
schedtrace int32
tracebackancestors int32
asyncpreemptoff int32
harddecommit int32
adaptivestackstart int32
// debug.malloc is used as a combined debug check
// in the malloc function and should be set
// if any of the below debug options is != 0.
malloc bool
allocfreetrace int32
inittrace int32
sbrk int32
panicnil atomic.Int32
}
var dbgvars = []*dbgVar{
{name: "allocfreetrace", value: &debug.allocfreetrace},
{name: "clobberfree", value: &debug.clobberfree},
{name: "cgocheck", value: &debug.cgocheck},
{name: "efence", value: &debug.efence},
{name: "gccheckmark", value: &debug.gccheckmark},
{name: "gcpacertrace", value: &debug.gcpacertrace},
{name: "gcshrinkstackoff", value: &debug.gcshrinkstackoff},
{name: "gcstoptheworld", value: &debug.gcstoptheworld},
{name: "gctrace", value: &debug.gctrace},
{name: "invalidptr", value: &debug.invalidptr},
{name: "madvdontneed", value: &debug.madvdontneed},
{name: "sbrk", value: &debug.sbrk},
{name: "scavtrace", value: &debug.scavtrace},
{name: "scheddetail", value: &debug.scheddetail},
{name: "schedtrace", value: &debug.schedtrace},
{name: "tracebackancestors", value: &debug.tracebackancestors},
{name: "asyncpreemptoff", value: &debug.asyncpreemptoff},
{name: "inittrace", value: &debug.inittrace},
{name: "harddecommit", value: &debug.harddecommit},
{name: "adaptivestackstart", value: &debug.adaptivestackstart},
{name: "panicnil", atomic: &debug.panicnil},
}
func parsedebugvars() {
// defaults
debug.cgocheck = 1
debug.invalidptr = 1
debug.adaptivestackstart = 1 // set this to 0 to turn larger initial goroutine stacks off
if GOOS == "linux" {
// On Linux, MADV_FREE is faster than MADV_DONTNEED,
// but doesn't affect many of the statistics that
// MADV_DONTNEED does until the memory is actually
// reclaimed. This generally leads to poor user
// experience, like confusing stats in top and other
// monitoring tools; and bad integration with
// management systems that respond to memory usage.
// Hence, default to MADV_DONTNEED.
debug.madvdontneed = 1
}
godebug := gogetenv("GODEBUG")
p := new(string)
*p = godebug
godebugEnv.Store(p)
// apply runtime defaults, if any
for _, v := range dbgvars {
if v.def != 0 {
// Every var should have either v.value or v.atomic set.
if v.value != nil {
*v.value = v.def
} else if v.atomic != nil {
v.atomic.Store(v.def)
}
}
}
// apply compile-time GODEBUG settings
parsegodebug(godebugDefault, nil)
// apply environment settings
parsegodebug(godebug, nil)
debug.malloc = (debug.allocfreetrace | debug.inittrace | debug.sbrk) != 0
setTraceback(gogetenv("GOTRACEBACK"))
traceback_env = traceback_cache
}
// reparsedebugvars reparses the runtime's debug variables
// because the environment variable has been changed to env.
func reparsedebugvars(env string) {
seen := make(map[string]bool)
// apply environment settings
parsegodebug(env, seen)
// apply compile-time GODEBUG settings for as-yet-unseen variables
parsegodebug(godebugDefault, seen)
// apply defaults for as-yet-unseen variables
for _, v := range dbgvars {
if v.atomic != nil && !seen[v.name] {
v.atomic.Store(0)
}
}
}
// parsegodebug parses the godebug string, updating variables listed in dbgvars.
// If seen == nil, this is startup time and we process the string left to right
// overwriting older settings with newer ones.
// If seen != nil, $GODEBUG has changed and we are doing an
// incremental update. To avoid flapping in the case where a value is
// set multiple times (perhaps in the default and the environment,
// or perhaps twice in the environment), we process the string right-to-left
// and only change values not already seen. After doing this for both
// the environment and the default settings, the caller must also call
// cleargodebug(seen) to reset any now-unset values back to their defaults.
func parsegodebug(godebug string, seen map[string]bool) {
for p := godebug; p != ""; {
var field string
if seen == nil {
// startup: process left to right, overwriting older settings with newer
i := bytealg.IndexByteString(p, ',')
if i < 0 {
field, p = p, ""
} else {
field, p = p[:i], p[i+1:]
}
} else {
// incremental update: process right to left, updating and skipping seen
i := len(p) - 1
for i >= 0 && p[i] != ',' {
i--
}
if i < 0 {
p, field = "", p
} else {
p, field = p[:i], p[i+1:]
}
}
i := bytealg.IndexByteString(field, '=')
if i < 0 {
continue
}
key, value := field[:i], field[i+1:]
if seen[key] {
continue
}
if seen != nil {
seen[key] = true
}
// Update MemProfileRate directly here since it
// is int, not int32, and should only be updated
// if specified in GODEBUG.
if seen == nil && key == "memprofilerate" {
if n, ok := atoi(value); ok {
MemProfileRate = n
}
} else {
for _, v := range dbgvars {
if v.name == key {
if n, ok := atoi32(value); ok {
if seen == nil && v.value != nil {
*v.value = n
} else if v.atomic != nil {
v.atomic.Store(n)
}
}
}
}
}
}
if debug.cgocheck > 1 {
throw("cgocheck > 1 mode is no longer supported at runtime. Use GOEXPERIMENT=cgocheck2 at build time instead.")
}
}
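// Worked example of the two scan directions: with
//
//	GODEBUG=gctrace=1,gctrace=0
//
// the startup pass (seen == nil) walks left to right and overwrites
// gctrace=1 with gctrace=0, while the incremental pass (seen != nil) walks
// right to left, applies gctrace=0 first, and skips gctrace=1 because the
// key is already in seen. Either way the later setting wins and gctrace
// ends up 0.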
//go:linkname setTraceback runtime/debug.SetTraceback
func setTraceback(level string) {
var t uint32
switch level {
case "none":
t = 0
case "single", "":
t = 1 << tracebackShift
case "all":
t = 1<<tracebackShift | tracebackAll
case "system":
t = 2<<tracebackShift | tracebackAll
case "crash":
t = 2<<tracebackShift | tracebackAll | tracebackCrash
default:
t = tracebackAll
if n, ok := atoi(level); ok && n == int(uint32(n)) {
t |= uint32(n) << tracebackShift
}
}
// When C owns the process, simply exiting the process on fatal errors
// and panics is surprising. Be louder and abort instead.
if islibrary || isarchive {
t |= tracebackCrash
}
t |= traceback_env
atomic.Store(&traceback_cache, t)
}
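// setTraceback is surfaced as runtime/debug.SetTraceback and accepts the
// same values as the GOTRACEBACK environment variable. A usage sketch:
//
//	import "runtime/debug"
//
//	func init() {
//		debug.SetTraceback("system") // include runtime frames, as GOTRACEBACK=system would
//	}
//
// Because traceback_env is OR'd back in above, SetTraceback cannot lower a
// level that the environment has already raised.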
// Poor man's 64-bit division.
// This is a very special function, do not use it if you are not sure what you are doing.
// int64 division is lowered into _divv() call on 386, which does not fit into nosplit functions.
// Handles overflow in a time-specific manner.
// This keeps us within no-split stack limits on 32-bit processors.
//
//go:nosplit
func timediv(v int64, div int32, rem *int32) int32 {
res := int32(0)
for bit := 30; bit >= 0; bit-- {
if v >= int64(div)<<uint(bit) {
v = v - (int64(div) << uint(bit))
// Before this for loop, res was 0, thus all these
// power of 2 increments are now just bitsets.
res |= 1 << uint(bit)
}
}
if v >= int64(div) {
if rem != nil {
*rem = 0
}
return 0x7fffffff
}
if rem != nil {
*rem = int32(v)
}
return res
}
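// Usage sketch, matching the self-check in check():
//
//	var rem int32
//	sec := timediv(12345*1000000000+54321, 1000000000, &rem)
//	// sec == 12345, rem == 54321
//
// If the quotient does not fit in 31 bits, timediv returns 0x7fffffff and
// sets rem to 0.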
// Helpers for Go. Must be NOSPLIT, must only call NOSPLIT functions, and must not block.
//go:nosplit
func acquirem() *m {
gp := getg()
gp.m.locks++
return gp.m
}
//go:nosplit
func releasem(mp *m) {
gp := getg()
mp.locks--
if mp.locks == 0 && gp.preempt {
// restore the preemption request in case we've cleared it in newstack
gp.stackguard0 = stackPreempt
}
}
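// Typical pairing inside the runtime (fragment): holding the m keeps the
// goroutine on its thread and defers preemption between the two calls.
//
//	mp := acquirem()
//	// ... work that must not be preempted or migrated ...
//	releasem(mp)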
//go:linkname reflect_typelinks reflect.typelinks
func reflect_typelinks() ([]unsafe.Pointer, [][]int32) {
modules := activeModules()
sections := []unsafe.Pointer{unsafe.Pointer(modules[0].types)}
ret := [][]int32{modules[0].typelinks}
for _, md := range modules[1:] {
sections = append(sections, unsafe.Pointer(md.types))
ret = append(ret, md.typelinks)
}
return sections, ret
}
// reflect_resolveNameOff resolves a name offset from a base pointer.
//
//go:linkname reflect_resolveNameOff reflect.resolveNameOff
func reflect_resolveNameOff(ptrInModule unsafe.Pointer, off int32) unsafe.Pointer {
return unsafe.Pointer(resolveNameOff(ptrInModule, nameOff(off)).bytes)
}
// reflect_resolveTypeOff resolves an *rtype offset from a base type.
//
//go:linkname reflect_resolveTypeOff reflect.resolveTypeOff
func reflect_resolveTypeOff(rtype unsafe.Pointer, off int32) unsafe.Pointer {
return unsafe.Pointer((*_type)(rtype).typeOff(typeOff(off)))
}
// reflect_resolveTextOff resolves a function pointer offset from a base type.
//
//go:linkname reflect_resolveTextOff reflect.resolveTextOff
func reflect_resolveTextOff(rtype unsafe.Pointer, off int32) unsafe.Pointer {
return (*_type)(rtype).textOff(textOff(off))
}
// reflectlite_resolveNameOff resolves a name offset from a base pointer.
//
//go:linkname reflectlite_resolveNameOff internal/reflectlite.resolveNameOff
func reflectlite_resolveNameOff(ptrInModule unsafe.Pointer, off int32) unsafe.Pointer {
return unsafe.Pointer(resolveNameOff(ptrInModule, nameOff(off)).bytes)
}
// reflectlite_resolveTypeOff resolves an *rtype offset from a base type.
//
//go:linkname reflectlite_resolveTypeOff internal/reflectlite.resolveTypeOff
func reflectlite_resolveTypeOff(rtype unsafe.Pointer, off int32) unsafe.Pointer {
return unsafe.Pointer((*_type)(rtype).typeOff(typeOff(off)))
}
// reflect_addReflectOff adds a pointer to the reflection offset lookup map.
//
//go:linkname reflect_addReflectOff reflect.addReflectOff
func reflect_addReflectOff(ptr unsafe.Pointer) int32 {
reflectOffsLock()
if reflectOffs.m == nil {
reflectOffs.m = make(map[int32]unsafe.Pointer)
reflectOffs.minv = make(map[unsafe.Pointer]int32)
reflectOffs.next = -1
}
id, found := reflectOffs.minv[ptr]
if !found {
id = reflectOffs.next
reflectOffs.next-- // use negative offsets as IDs to aid debugging
reflectOffs.m[id] = ptr
reflectOffs.minv[ptr] = id
}
reflectOffsUnlock()
return id
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"internal/goarch"
"runtime/internal/atomic"
"unsafe"
)
// defined constants
const (
// G status
//
// Beyond indicating the general state of a G, the G status
// acts like a lock on the goroutine's stack (and hence its
// ability to execute user code).
//
// If you add to this list, add to the list
// of "okay during garbage collection" status
// in mgcmark.go too.
//
// TODO(austin): The _Gscan bit could be much lighter-weight.
// For example, we could choose not to run _Gscanrunnable
// goroutines found in the run queue, rather than CAS-looping
// until they become _Grunnable. And transitions like
// _Gscanwaiting -> _Gscanrunnable are actually okay because
// they don't affect stack ownership.
// _Gidle means this goroutine was just allocated and has not
// yet been initialized.
_Gidle = iota // 0
// _Grunnable means this goroutine is on a run queue. It is
// not currently executing user code. The stack is not owned.
_Grunnable // 1
// _Grunning means this goroutine may execute user code. The
// stack is owned by this goroutine. It is not on a run queue.
// It is assigned an M and a P (g.m and g.m.p are valid).
_Grunning // 2
// _Gsyscall means this goroutine is executing a system call.
// It is not executing user code. The stack is owned by this
// goroutine. It is not on a run queue. It is assigned an M.
_Gsyscall // 3
// _Gwaiting means this goroutine is blocked in the runtime.
// It is not executing user code. It is not on a run queue,
// but should be recorded somewhere (e.g., a channel wait
// queue) so it can be ready()d when necessary. The stack is
// not owned *except* that a channel operation may read or
// write parts of the stack under the appropriate channel
// lock. Otherwise, it is not safe to access the stack after a
// goroutine enters _Gwaiting (e.g., it may get moved).
_Gwaiting // 4
// _Gmoribund_unused is currently unused, but hardcoded in gdb
// scripts.
_Gmoribund_unused // 5
// _Gdead means this goroutine is currently unused. It may be
// just exited, on a free list, or just being initialized. It
// is not executing user code. It may or may not have a stack
// allocated. The G and its stack (if any) are owned by the M
// that is exiting the G or that obtained the G from the free
// list.
_Gdead // 6
// _Genqueue_unused is currently unused.
_Genqueue_unused // 7
// _Gcopystack means this goroutine's stack is being moved. It
// is not executing user code and is not on a run queue. The
// stack is owned by the goroutine that put it in _Gcopystack.
_Gcopystack // 8
// _Gpreempted means this goroutine stopped itself for a
// suspendG preemption. It is like _Gwaiting, but nothing is
// yet responsible for ready()ing it. Some suspendG must CAS
// the status to _Gwaiting to take responsibility for
// ready()ing this G.
_Gpreempted // 9
// _Gscan combined with one of the above states other than
// _Grunning indicates that GC is scanning the stack. The
// goroutine is not executing user code and the stack is owned
// by the goroutine that set the _Gscan bit.
//
// _Gscanrunning is different: it is used to briefly block
// state transitions while GC signals the G to scan its own
// stack. This is otherwise like _Grunning.
//
// atomicstatus&~Gscan gives the state the goroutine will
// return to when the scan completes.
_Gscan = 0x1000
_Gscanrunnable = _Gscan + _Grunnable // 0x1001
_Gscanrunning = _Gscan + _Grunning // 0x1002
_Gscansyscall = _Gscan + _Gsyscall // 0x1003
_Gscanwaiting = _Gscan + _Gwaiting // 0x1004
_Gscanpreempted = _Gscan + _Gpreempted // 0x1009
)
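// Because _Gscan composes with the base states by simple addition, masking
// it off recovers the state the goroutine will return to. For example:
//
//	status := uint32(_Gscanrunnable) // 0x1001
//	base := status &^ _Gscan         // _Grunnable (1)
//
// This is the atomicstatus&~Gscan relationship described above.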
const (
// P status
// _Pidle means a P is not being used to run user code or the
// scheduler. Typically, it's on the idle P list and available
// to the scheduler, but it may just be transitioning between
// other states.
//
// The P is owned by the idle list or by whatever is
// transitioning its state. Its run queue is empty.
_Pidle = iota
// _Prunning means a P is owned by an M and is being used to
// run user code or the scheduler. Only the M that owns this P
// is allowed to change the P's status from _Prunning. The M
// may transition the P to _Pidle (if it has no more work to
// do), _Psyscall (when entering a syscall), or _Pgcstop (to
// halt for the GC). The M may also hand ownership of the P
// off directly to another M (e.g., to schedule a locked G).
_Prunning
// _Psyscall means a P is not running user code. It has
// affinity to an M in a syscall but is not owned by it and
// may be stolen by another M. This is similar to _Pidle but
// uses lightweight transitions and maintains M affinity.
//
// Leaving _Psyscall must be done with a CAS, either to steal
// or retake the P. Note that there's an ABA hazard: even if
// an M successfully CASes its original P back to _Prunning
// after a syscall, it must understand the P may have been
// used by another M in the interim.
_Psyscall
// _Pgcstop means a P is halted for STW and owned by the M
// that stopped the world. The M that stopped the world
// continues to use its P, even in _Pgcstop. Transitioning
// from _Prunning to _Pgcstop causes an M to release its P and
// park.
//
// The P retains its run queue and startTheWorld will restart
// the scheduler on Ps with non-empty run queues.
_Pgcstop
// _Pdead means a P is no longer used (GOMAXPROCS shrank). We
// reuse Ps if GOMAXPROCS increases. A dead P is mostly
// stripped of its resources, though a few things remain
// (e.g., trace buffers).
_Pdead
)
// Mutual exclusion locks. In the uncontended case,
// as fast as spin locks (just a few user-level instructions),
// but on the contention path they sleep in the kernel.
// A zeroed Mutex is unlocked (no need to initialize each lock).
// Initialization is helpful for static lock ranking, but not required.
type mutex struct {
// Empty struct if lock ranking is disabled, otherwise includes the lock rank
lockRankStruct
// Futex-based impl treats it as uint32 key,
// while sema-based impl as M* waitm.
// Used to be a union, but unions break precise GC.
key uintptr
}
// sleep and wakeup on one-time events.
// before any calls to notesleep or notewakeup,
// must call noteclear to initialize the Note.
// then, exactly one thread can call notesleep
// and exactly one thread can call notewakeup (once).
// once notewakeup has been called, the notesleep
// will return. future notesleep will return immediately.
// subsequent noteclear must be called only after
// previous notesleep has returned, e.g. it's disallowed
// to call noteclear straight after notewakeup.
//
// notetsleep is like notesleep but wakes up after
// a given number of nanoseconds even if the event
// has not yet happened. if a goroutine uses notetsleep to
// wake up early, it must wait to call noteclear until it
// can be sure that no other goroutine is calling
// notewakeup.
//
// notesleep/notetsleep are generally called on g0,
// notetsleepg is similar to notetsleep but is called on user g.
type note struct {
// Futex-based impl treats it as uint32 key,
// while sema-based impl as M* waitm.
// Used to be a union, but unions break precise GC.
key uintptr
}
type funcval struct {
fn uintptr
// variable-size, fn-specific data here
}
type iface struct {
tab *itab
data unsafe.Pointer
}
type eface struct {
_type *_type
data unsafe.Pointer
}
func efaceOf(ep *any) *eface {
return (*eface)(unsafe.Pointer(ep))
}
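// For illustration, a user-space sketch of the same reinterpretation. It
// relies on the runtime's interface representation, so it is illustrative
// only (dataOf is hypothetical):
//
//	import "unsafe"
//
//	type eface struct {
//		typ  unsafe.Pointer
//		data unsafe.Pointer
//	}
//
//	func dataOf(ep *any) unsafe.Pointer {
//		return (*eface)(unsafe.Pointer(ep)).data
//	}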
// The guintptr, muintptr, and puintptr are all used to bypass write barriers.
// It is particularly important to avoid write barriers when the current P has
// been released, because the GC thinks the world is stopped, and an
// unexpected write barrier would not be synchronized with the GC,
// which can lead to a half-executed write barrier that has marked the object
// but not queued it. If the GC skips the object and completes before the
// queuing can occur, it will incorrectly free the object.
//
// We tried using special assignment functions invoked only when not
// holding a running P, but then some updates to a particular memory
// word went through write barriers and some did not. This breaks the
// write barrier shadow checking mode, and it is also scary: better to have
// a word that is completely ignored by the GC than to have one for which
// only a few updates are ignored.
//
// Gs and Ps are always reachable via true pointers in the
// allgs and allp lists or (during allocation before they reach those lists)
// from stack variables.
//
// Ms are always reachable via true pointers either from allm or
// freem. Unlike Gs and Ps we do free Ms, so it's important that
// nothing ever hold an muintptr across a safe point.
// A guintptr holds a goroutine pointer, but typed as a uintptr
// to bypass write barriers. It is used in the Gobuf goroutine state
// and in scheduling lists that are manipulated without a P.
//
// The Gobuf.g goroutine pointer is almost always updated by assembly code.
// In one of the few places it is updated by Go code - func save - it must be
// treated as a uintptr to avoid a write barrier being emitted at a bad time.
// Instead of figuring out how to emit the write barriers missing in the
// assembly manipulation, we change the type of the field to uintptr,
// so that it does not require write barriers at all.
//
// Goroutine structs are published in the allg list and never freed.
// That will keep the goroutine structs from being collected.
// There is never a time that Gobuf.g's contain the only references
// to a goroutine: the publishing of the goroutine in allg comes first.
// Goroutine pointers are also kept in non-GC-visible places like TLS,
// so I can't see them ever moving. If we did want to start moving data
// in the GC, we'd need to allocate the goroutine structs from an
// alternate arena. Using guintptr doesn't make that problem any worse.
// Note that pollDesc.rg, pollDesc.wg also store g in uintptr form,
// so they would need to be updated too if g's start moving.
type guintptr uintptr
//go:nosplit
func (gp guintptr) ptr() *g { return (*g)(unsafe.Pointer(gp)) }
//go:nosplit
func (gp *guintptr) set(g *g) { *gp = guintptr(unsafe.Pointer(g)) }
//go:nosplit
func (gp *guintptr) cas(old, new guintptr) bool {
return atomic.Casuintptr((*uintptr)(unsafe.Pointer(gp)), uintptr(old), uintptr(new))
}
// setGNoWB performs *gp = new without a write barrier.
// For times when it's impractical to use a guintptr.
//
//go:nosplit
//go:nowritebarrier
func setGNoWB(gp **g, new *g) {
(*guintptr)(unsafe.Pointer(gp)).set(new)
}
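// Hypothetical fragment showing the intended use: assigning a *g into a
// structure while write barriers are forbidden (the field name is
// illustrative):
//
//	setGNoWB(&someStruct.gp, gp) // like someStruct.gp = gp, minus the barrier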
type puintptr uintptr
//go:nosplit
func (pp puintptr) ptr() *p { return (*p)(unsafe.Pointer(pp)) }
//go:nosplit
func (pp *puintptr) set(p *p) { *pp = puintptr(unsafe.Pointer(p)) }
// muintptr is a *m that is not tracked by the garbage collector.
//
// Because we do free Ms, there are some additional constraints on
// muintptrs:
//
// 1. Never hold an muintptr locally across a safe point.
//
// 2. Any muintptr in the heap must be owned by the M itself so it can
// ensure it is not in use when the last true *m is released.
type muintptr uintptr
//go:nosplit
func (mp muintptr) ptr() *m { return (*m)(unsafe.Pointer(mp)) }
//go:nosplit
func (mp *muintptr) set(m *m) { *mp = muintptr(unsafe.Pointer(m)) }
// setMNoWB performs *mp = new without a write barrier.
// For times when it's impractical to use an muintptr.
//
//go:nosplit
//go:nowritebarrier
func setMNoWB(mp **m, new *m) {
(*muintptr)(unsafe.Pointer(mp)).set(new)
}
type gobuf struct {
// The offsets of sp, pc, and g are known to (hard-coded in) libmach.
//
// ctxt is unusual with respect to GC: it may be a
// heap-allocated funcval, so GC needs to track it, but it
// needs to be set and cleared from assembly, where it's
// difficult to have write barriers. However, ctxt is really a
// saved, live register, and we only ever exchange it between
// the real register and the gobuf. Hence, we treat it as a
// root during stack scanning, which means assembly that saves
// and restores it doesn't need write barriers. It's still
// typed as a pointer so that any other writes from Go get
// write barriers.
sp uintptr
pc uintptr
g guintptr
ctxt unsafe.Pointer
ret uintptr
lr uintptr
bp uintptr // for framepointer-enabled architectures
}
// sudog represents a g in a wait list, such as for sending/receiving
// on a channel.
//
// sudog is necessary because the g ↔ synchronization object relation
// is many-to-many. A g can be on many wait lists, so there may be
// many sudogs for one g; and many gs may be waiting on the same
// synchronization object, so there may be many sudogs for one object.
//
// sudogs are allocated from a special pool. Use acquireSudog and
// releaseSudog to allocate and free them.
type sudog struct {
// The following fields are protected by the hchan.lock of the
// channel this sudog is blocking on. shrinkstack depends on
// this for sudogs involved in channel ops.
g *g
next *sudog
prev *sudog
elem unsafe.Pointer // data element (may point to stack)
// The following fields are never accessed concurrently.
// For channels, waitlink is only accessed by g.
// For semaphores, all fields (including the ones above)
// are only accessed when holding a semaRoot lock.
acquiretime int64
releasetime int64
ticket uint32
// isSelect indicates g is participating in a select, so
// g.selectDone must be CAS'd to win the wake-up race.
isSelect bool
// success indicates whether communication over channel c
// succeeded. It is true if the goroutine was awoken because a
// value was delivered over channel c, and false if awoken
// because c was closed.
success bool
parent *sudog // semaRoot binary tree
waitlink *sudog // g.waiting list or semaRoot
waittail *sudog // semaRoot
c *hchan // channel
}
type libcall struct {
fn uintptr
n uintptr // number of parameters
args uintptr // parameters
r1 uintptr // return values
r2 uintptr
err uintptr // error number
}
// Stack describes a Go execution stack.
// The bounds of the stack are exactly [lo, hi),
// with no implicit data structures on either side.
type stack struct {
lo uintptr
hi uintptr
}
// heldLockInfo gives info on a held lock and the rank of that lock
type heldLockInfo struct {
lockAddr uintptr
rank lockRank
}
type g struct {
// Stack parameters.
// stack describes the actual stack memory: [stack.lo, stack.hi).
// stackguard0 is the stack pointer compared in the Go stack growth prologue.
// It is stack.lo+StackGuard normally, but can be StackPreempt to trigger a preemption.
// stackguard1 is the stack pointer compared in the C stack growth prologue.
// It is stack.lo+StackGuard on g0 and gsignal stacks.
// It is ~0 on other goroutine stacks, to trigger a call to morestackc (and crash).
stack stack // offset known to runtime/cgo
stackguard0 uintptr // offset known to liblink
stackguard1 uintptr // offset known to liblink
_panic *_panic // innermost panic - offset known to liblink
_defer *_defer // innermost defer
m *m // current m; offset known to arm liblink
sched gobuf
syscallsp uintptr // if status==Gsyscall, syscallsp = sched.sp to use during gc
syscallpc uintptr // if status==Gsyscall, syscallpc = sched.pc to use during gc
stktopsp uintptr // expected sp at top of stack, to check in traceback
// param is a generic pointer parameter field used to pass
// values in particular contexts where other storage for the
// parameter would be difficult to find. It is currently used
// in three ways:
// 1. When a channel operation wakes up a blocked goroutine, it sets param to
// point to the sudog of the completed blocking operation.
// 2. By gcAssistAlloc1 to signal back to its caller that the goroutine completed
// the GC cycle. It is unsafe to do so in any other way, because the goroutine's
// stack may have moved in the meantime.
// 3. By debugCallWrap to pass parameters to a new goroutine because allocating a
// closure in the runtime is forbidden.
param unsafe.Pointer
atomicstatus atomic.Uint32
stackLock uint32 // sigprof/scang lock; TODO: fold in to atomicstatus
goid uint64
schedlink guintptr
waitsince int64 // approx time when the g became blocked
waitreason waitReason // if status==Gwaiting
preempt bool // preemption signal, duplicates stackguard0 = stackpreempt
preemptStop bool // transition to _Gpreempted on preemption; otherwise, just deschedule
preemptShrink bool // shrink stack at synchronous safe point
// asyncSafePoint is set if g is stopped at an asynchronous
// safe point. This means there are frames on the stack
// without precise pointer information.
asyncSafePoint bool
paniconfault bool // panic (instead of crash) on unexpected fault address
gcscandone bool // g has scanned stack; protected by _Gscan bit in status
throwsplit bool // must not split stack
// activeStackChans indicates that there are unlocked channels
// pointing into this goroutine's stack. If true, stack
// copying needs to acquire channel locks to protect these
// areas of the stack.
activeStackChans bool
// parkingOnChan indicates that the goroutine is about to
// park on a chansend or chanrecv. Used to signal an unsafe point
// for stack shrinking.
parkingOnChan atomic.Bool
raceignore int8 // ignore race detection events
sysblocktraced bool // StartTrace has emitted EvGoInSyscall about this goroutine
tracking bool // whether we're tracking this G for sched latency statistics
trackingSeq uint8 // used to decide whether to track this G
trackingStamp int64 // timestamp of when the G last started being tracked
runnableTime int64 // the amount of time spent runnable, cleared when running, only used when tracking
sysexitticks int64 // cputicks when syscall has returned (for tracing)
traceseq uint64 // trace event sequencer
tracelastp puintptr // last P emitted an event for this goroutine
lockedm muintptr
sig uint32
writebuf []byte
sigcode0 uintptr
sigcode1 uintptr
sigpc uintptr
parentGoid uint64 // goid of goroutine that created this goroutine
gopc uintptr // pc of go statement that created this goroutine
ancestors *[]ancestorInfo // ancestor information of the goroutine(s) that created this goroutine (only used if debug.tracebackancestors)
startpc uintptr // pc of goroutine function
racectx uintptr
waiting *sudog // sudog structures this g is waiting on (that have a valid elem ptr); in lock order
cgoCtxt []uintptr // cgo traceback context
labels unsafe.Pointer // profiler labels
timer *timer // cached timer for time.Sleep
selectDone atomic.Uint32 // are we participating in a select and did someone win the race?
// goroutineProfiled indicates the status of this goroutine's stack for the
// current in-progress goroutine profile
goroutineProfiled goroutineProfileStateHolder
// Per-G GC state
// gcAssistBytes is this G's GC assist credit in terms of
// bytes allocated. If this is positive, then the G has credit
// to allocate gcAssistBytes bytes without assisting. If this
// is negative, then the G must correct this by performing
// scan work. We track this in bytes to make it fast to update
// and check for debt in the malloc hot path. The assist ratio
// determines how this corresponds to scan work debt.
gcAssistBytes int64
}
// gTrackingPeriod is the number of transitions out of _Grunning between
// latency tracking runs.
const gTrackingPeriod = 8
const (
// tlsSlots is the number of pointer-sized slots reserved for TLS on some platforms,
// like Windows.
tlsSlots = 6
tlsSize = tlsSlots * goarch.PtrSize
)
// Values for m.freeWait.
const (
freeMStack = 0 // M done, free stack and reference.
freeMRef = 1 // M done, free reference.
freeMWait = 2 // M still in use.
)
type m struct {
g0 *g // goroutine with scheduling stack
morebuf gobuf // gobuf arg to morestack
divmod uint32 // div/mod denominator for arm - known to liblink
_ uint32 // align next field to 8 bytes
// Fields not known to debuggers.
procid uint64 // for debuggers, but offset not hard-coded
gsignal *g // signal-handling g
goSigStack gsignalStack // Go-allocated signal handling stack
sigmask sigset // storage for saved signal mask
tls [tlsSlots]uintptr // thread-local storage (for x86 extern register)
mstartfn func()
curg *g // current running goroutine
caughtsig guintptr // goroutine running during fatal signal
p puintptr // attached p for executing go code (nil if not executing go code)
nextp puintptr
oldp puintptr // the p that was attached before executing a syscall
id int64
mallocing int32
throwing throwType
preemptoff string // if != "", keep curg running on this m
locks int32
dying int32
profilehz int32
spinning bool // m is out of work and is actively looking for work
blocked bool // m is blocked on a note
newSigstack bool // minit on C thread called sigaltstack
printlock int8
incgo bool // m is executing a cgo call
isextra bool // m is an extra m
freeWait atomic.Uint32 // Whether it is safe to free g0 and delete m (one of freeMRef, freeMStack, freeMWait)
fastrand uint64
needextram bool
traceback uint8
ncgocall uint64 // number of cgo calls in total
ncgo int32 // number of cgo calls currently in progress
cgoCallersUse atomic.Uint32 // if non-zero, cgoCallers in use temporarily
cgoCallers *cgoCallers // cgo traceback if crashing in cgo call
park note
alllink *m // on allm
schedlink muintptr
lockedg guintptr
createstack [32]uintptr // stack that created this thread.
lockedExt uint32 // tracking for external LockOSThread
lockedInt uint32 // tracking for internal lockOSThread
nextwaitm muintptr // next m waiting for lock
waitunlockf func(*g, unsafe.Pointer) bool
waitlock unsafe.Pointer
waittraceev byte
waittraceskip int
startingtrace bool
syscalltick uint32
freelink *m // on sched.freem
// these are here because they are too large to be on the stack
// of low-level NOSPLIT functions.
libcall libcall
libcallpc uintptr // for cpu profiler
libcallsp uintptr
libcallg guintptr
syscall libcall // stores syscall parameters on windows
vdsoSP uintptr // SP for traceback while in VDSO call (0 if not in call)
vdsoPC uintptr // PC for traceback while in VDSO call
// preemptGen counts the number of completed preemption
// signals. This is used to detect when a preemption is
// requested, but fails.
preemptGen atomic.Uint32
// Whether this is a pending preemption signal on this M.
signalPending atomic.Uint32
dlogPerM
mOS
// Up to 10 locks held by this m, maintained by the lock ranking code.
locksHeldLen int
locksHeld [10]heldLockInfo
}
type p struct {
id int32
status uint32 // one of pidle/prunning/...
link puintptr
schedtick uint32 // incremented on every scheduler call
syscalltick uint32 // incremented on every system call
sysmontick sysmontick // last tick observed by sysmon
m muintptr // back-link to associated m (nil if idle)
mcache *mcache
pcache pageCache
raceprocctx uintptr
deferpool []*_defer // pool of available defer structs (see panic.go)
deferpoolbuf [32]*_defer
// Cache of goroutine ids, amortizes accesses to runtime·sched.goidgen.
goidcache uint64
goidcacheend uint64
// Queue of runnable goroutines. Accessed without lock.
runqhead uint32
runqtail uint32
runq [256]guintptr
// runnext, if non-nil, is a runnable G that was ready'd by
// the current G and should be run next instead of what's in
// runq if there's time remaining in the running G's time
// slice. It will inherit the time left in the current time
// slice. If a set of goroutines is locked in a
// communicate-and-wait pattern, this schedules that set as a
// unit and eliminates the (potentially large) scheduling
// latency that otherwise arises from adding the ready'd
// goroutines to the end of the run queue.
//
// Note that while other P's may atomically CAS this to zero,
// only the owner P can CAS it to a valid G.
runnext guintptr
// Available G's (status == Gdead)
gFree struct {
gList
n int32
}
sudogcache []*sudog
sudogbuf [128]*sudog
// Cache of mspan objects from the heap.
mspancache struct {
// We need an explicit length here because this field is used
// in allocation codepaths where write barriers are not allowed,
// and eliminating the write barrier/keeping it eliminated from
// slice updates is tricky, moreso than just managing the length
// ourselves.
len int
buf [128]*mspan
}
tracebuf traceBufPtr
// traceSweep indicates the sweep events should be traced.
// This is used to defer the sweep start event until a span
// has actually been swept.
traceSweep bool
// traceSwept and traceReclaimed track the number of bytes
// swept and reclaimed by sweeping in the current sweep loop.
traceSwept, traceReclaimed uintptr
palloc persistentAlloc // per-P to avoid mutex
// The when field of the first entry on the timer heap.
// This is 0 if the timer heap is empty.
timer0When atomic.Int64
// The earliest known nextwhen field of a timer with
// timerModifiedEarlier status. Because the timer may have been
// modified again, there need not be any timer with this value.
// This is 0 if there are no timerModifiedEarlier timers.
timerModifiedEarliest atomic.Int64
// Per-P GC state
gcAssistTime int64 // Nanoseconds in assistAlloc
gcFractionalMarkTime int64 // Nanoseconds in fractional mark worker (atomic)
// limiterEvent tracks events for the GC CPU limiter.
limiterEvent limiterEvent
// gcMarkWorkerMode is the mode for the next mark worker to run in.
// That is, this is used to communicate with the worker goroutine
// selected for immediate execution by
// gcController.findRunnableGCWorker. When scheduling other goroutines,
// this field must be set to gcMarkWorkerNotWorker.
gcMarkWorkerMode gcMarkWorkerMode
// gcMarkWorkerStartTime is the nanotime() at which the most recent
// mark worker started.
gcMarkWorkerStartTime int64
// gcw is this P's GC work buffer cache. The work buffer is
// filled by write barriers, drained by mutator assists, and
// disposed on certain GC state transitions.
gcw gcWork
// wbBuf is this P's GC write barrier buffer.
//
// TODO: Consider caching this in the running G.
wbBuf wbBuf
runSafePointFn uint32 // if 1, run sched.safePointFn at next safe point
// statsSeq is a counter indicating whether this P is currently
// writing any stats. Its value is even when it is not writing, odd when it is.
statsSeq atomic.Uint32
// Lock for timers. We normally access the timers while running
// on this P, but the scheduler can also do it from a different P.
timersLock mutex
// Actions to take at some time. This is used to implement the
// standard library's time package.
// Must hold timersLock to access.
timers []*timer
// Number of timers in P's heap.
numTimers atomic.Uint32
// Number of timerDeleted timers in P's heap.
deletedTimers atomic.Uint32
// Race context used while executing timer functions.
timerRaceCtx uintptr
// maxStackScanDelta accumulates the amount of stack space held by
// live goroutines (i.e. those eligible for stack scanning).
// Flushed to gcController.maxStackScan once maxStackScanSlack
// or -maxStackScanSlack is reached.
maxStackScanDelta int64
// gc-time statistics about current goroutines
// Note that this differs from maxStackScan in that this
// accumulates the actual stack observed to be used at GC time (hi - sp),
// not an instantaneous measure of the total stack size that might need
// to be scanned (hi - lo).
scannedStackSize uint64 // stack size of goroutines scanned by this P
scannedStacks uint64 // number of goroutines scanned by this P
// preempt is set to indicate that this P should enter the
// scheduler ASAP (regardless of what G is running on it).
preempt bool
// pageTraceBuf is a buffer for writing out page allocation/free/scavenge traces.
//
// Used only if GOEXPERIMENT=pagetrace.
pageTraceBuf pageTraceBuf
// Padding is no longer needed. False sharing is now not a worry because p is large enough
// that its size class is an integer multiple of the cache line size (for any of our architectures).
}
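// The runq/runqhead/runqtail fields above form a fixed-size ring buffer of
// runnable goroutines. The following standalone sketch (illustrative only,
// not runtime code; single-owner, so the atomics used by the real run queue
// are omitted) shows the same free-running head/tail discipline.
type ringQueue struct {
head, tail uint32 // free-running counters; indices are taken modulo len(buf)
buf [256]int // power-of-two size, so modulo indexing survives counter wraparound
}
func (q *ringQueue) put(v int) bool {
if q.tail-q.head == uint32(len(q.buf)) {
return false // full
}
q.buf[q.tail%uint32(len(q.buf))] = v
q.tail++
return true
}
func (q *ringQueue) get() (int, bool) {
if q.tail == q.head {
return 0, false // empty
}
v := q.buf[q.head%uint32(len(q.buf))]
q.head++
return v, true
}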
type schedt struct {
goidgen atomic.Uint64
lastpoll atomic.Int64 // time of last network poll, 0 if currently polling
pollUntil atomic.Int64 // time to which current poll is sleeping
lock mutex
// When increasing nmidle, nmidlelocked, nmsys, or nmfreed, be
// sure to call checkdead().
midle muintptr // idle m's waiting for work
nmidle int32 // number of idle m's waiting for work
nmidlelocked int32 // number of locked m's waiting for work
mnext int64 // number of m's that have been created and next M ID
maxmcount int32 // maximum number of m's allowed (or die)
nmsys int32 // number of system m's not counted for deadlock
nmfreed int64 // cumulative number of freed m's
ngsys atomic.Int32 // number of system goroutines
pidle puintptr // idle p's
npidle atomic.Int32
nmspinning atomic.Int32 // See "Worker thread parking/unparking" comment in proc.go.
needspinning atomic.Uint32 // See "Delicate dance" comment in proc.go. Boolean. Must hold sched.lock to set to 1.
// Global runnable queue.
runq gQueue
runqsize int32
// disable controls selective disabling of the scheduler.
//
// Use schedEnableUser to control this.
//
// disable is protected by sched.lock.
disable struct {
// user disables scheduling of user goroutines.
user bool
runnable gQueue // pending runnable Gs
n int32 // length of runnable
}
// Global cache of dead G's.
gFree struct {
lock mutex
stack gList // Gs with stacks
noStack gList // Gs without stacks
n int32
}
// Central cache of sudog structs.
sudoglock mutex
sudogcache *sudog
// Central pool of available defer structs.
deferlock mutex
deferpool *_defer
// freem is the list of m's waiting to be freed when their
// m.exited is set. Linked through m.freelink.
freem *m
gcwaiting atomic.Bool // gc is waiting to run
stopwait int32
stopnote note
sysmonwait atomic.Bool
sysmonnote note
// safePointFn should be called on each P at the next GC
// safepoint if p.runSafePointFn is set.
safePointFn func(*p)
safePointWait int32
safePointNote note
profilehz int32 // cpu profiling rate
procresizetime int64 // nanotime() of last change to gomaxprocs
totaltime int64 // ∫gomaxprocs dt up to procresizetime
// sysmonlock protects sysmon's actions on the runtime.
//
// Acquire and hold this mutex to block sysmon from interacting
// with the rest of the runtime.
sysmonlock mutex
// timeToRun is a distribution of scheduling latencies, defined
// as the sum of time a G spends in the _Grunnable state before
// it transitions to _Grunning.
timeToRun timeHistogram
// idleTime is the total CPU time Ps have "spent" idle.
//
// Reset on each GC cycle.
idleTime atomic.Int64
// totalMutexWaitTime is the sum of time goroutines have spent in _Gwaiting
// with a waitreason of the form waitReasonSync{RW,}Mutex{R,}Lock.
totalMutexWaitTime atomic.Int64
}
// Values for the flags field of a sigTabT.
const (
_SigNotify = 1 << iota // let signal.Notify have signal, even if from kernel
_SigKill // if signal.Notify doesn't take it, exit quietly
_SigThrow // if signal.Notify doesn't take it, exit loudly
_SigPanic // if the signal is from the kernel, panic
_SigDefault // if the signal isn't explicitly requested, don't monitor it
_SigGoExit // cause all runtime procs to exit (only used on Plan 9).
_SigSetStack // Don't explicitly install handler, but add SA_ONSTACK to existing libc handler
_SigUnblock // always unblock; see blockableSig
_SigIgn // _SIG_DFL action is to ignore the signal
)
// Layout of in-memory per-function information prepared by linker
// See https://golang.org/s/go12symtab.
// Keep in sync with linker (../cmd/link/internal/ld/pcln.go:/pclntab)
// and with package debug/gosym and with symtab.go in package runtime.
type _func struct {
entryOff uint32 // start pc, as offset from moduledata.text/pcHeader.textStart
nameOff int32 // function name, as index into moduledata.funcnametab.
args int32 // in/out args size
deferreturn uint32 // offset of start of a deferreturn call instruction from entry, if any.
pcsp uint32
pcfile uint32
pcln uint32
npcdata uint32
cuOffset uint32 // runtime.cutab offset of this function's CU
startLine int32 // line number of start of function (func keyword/TEXT directive)
funcID funcID // set for certain special runtime functions
flag funcFlag
_ [1]byte // pad
nfuncdata uint8 // must be last, must end on a uint32-aligned boundary
// The end of the struct is followed immediately by two variable-length
// arrays that reference the pcdata and funcdata locations for this
// function.
// pcdata contains the offset into moduledata.pctab for the start of
// that index's table. e.g.,
// &moduledata.pctab[_func.pcdata[_PCDATA_UnsafePoint]] is the start of
// the unsafe point table.
//
// An offset of 0 indicates that there is no table.
//
// pcdata [npcdata]uint32
// funcdata contains the offset past moduledata.gofunc which contains a
// pointer to that index's funcdata. e.g.,
// *(moduledata.gofunc + _func.funcdata[_FUNCDATA_ArgsPointerMaps]) is
// the argument pointer map.
//
// An offset of ^uint32(0) indicates that there is no entry.
//
// funcdata [nfuncdata]uint32
}
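// The pcdata/funcdata comments above describe variable-length arrays stored
// inline after the struct. The following standalone sketch (hypothetical
// names, illustrative only; it reuses this file's unsafe import) shows the
// general "header followed by inline array" access pattern being described.
type varHeader struct {
n uint32
// n uint32 entries follow this struct directly in memory.
}
func entryAt(h *varHeader, i int) uint32 {
base := unsafe.Add(unsafe.Pointer(h), unsafe.Sizeof(*h))
return *(*uint32)(unsafe.Add(base, uintptr(i)*unsafe.Sizeof(uint32(0))))
}
// Example (hypothetical): view a packed buffer through varHeader.
// buf := []uint32{3, 10, 20, 30} // n=3, followed by three entries
// h := (*varHeader)(unsafe.Pointer(&buf[0]))
// entryAt(h, 1) // == 20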
// Pseudo-Func that is returned for PCs that occur in inlined code.
// A *Func can be either a *_func or a *funcinl, and they are distinguished
// by the first uintptr.
type funcinl struct {
ones uint32 // set to ^0 to distinguish from _func
entry uintptr // entry of the real (the "outermost") frame
name string
file string
line int32
startLine int32
}
// layout of Itab known to compilers
// allocated in non-garbage-collected memory
// Needs to be in sync with
// ../cmd/compile/internal/reflectdata/reflect.go:/^func.WriteTabs.
type itab struct {
inter *interfacetype
_type *_type
hash uint32 // copy of _type.hash. Used for type switches.
_ [4]byte
fun [1]uintptr // variable sized. fun[0]==0 means _type does not implement inter.
}
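// Per the comment above, a non-empty interface value is a two-word
// (tab, data) pair on the gc toolchain. The sketch below (hypothetical
// helper; this layout is an implementation detail, not a guaranteed API)
// mirrors that pair to peek at a live interface header.
type ifaceWords struct {
tab unsafe.Pointer // the *itab
data unsafe.Pointer // pointer to the concrete value
}
func inspectIface(err error) (tab, data unsafe.Pointer) {
w := (*ifaceWords)(unsafe.Pointer(&err))
return w.tab, w.data
}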
// Lock-free stack node.
// Also known to export_test.go.
type lfnode struct {
next uint64
pushcnt uintptr
}
type forcegcstate struct {
lock mutex
g *g
idle atomic.Bool
}
// extendRandom extends the random numbers in r[:n] to the whole slice r.
// Treats n<0 as n==0.
func extendRandom(r []byte, n int) {
if n < 0 {
n = 0
}
for n < len(r) {
// Extend random bits using hash function & time seed
w := n
if w > 16 {
w = 16
}
h := memhash(unsafe.Pointer(&r[n-w]), uintptr(nanotime()), uintptr(w))
for i := 0; i < goarch.PtrSize && n < len(r); i++ {
r[n] = byte(h)
n++
h >>= 8
}
}
}
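// A userland analogue of extendRandom, assuming a hash/maphash import
// (which is not available inside the runtime itself; maphash's random
// per-Hash seed stands in for the nanotime() seed used above): it hashes
// the tail of the filled prefix and appends the hash bytes until full.
func extendRandomPortable(r []byte, n int) {
if n < 0 {
n = 0
}
var h maphash.Hash // zero value is ready to use, randomly seeded
for n < len(r) {
w := n
if w > 16 {
w = 16
}
h.Reset()
h.Write(r[n-w : n])
v := h.Sum64()
for i := 0; i < 8 && n < len(r); i++ {
r[n] = byte(v)
n++
v >>= 8
}
}
}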
// A _defer holds an entry on the list of deferred calls.
// If you add a field here, add code to clear it in deferProcStack.
// This struct must match the code in cmd/compile/internal/ssagen/ssa.go:deferstruct
// and cmd/compile/internal/ssagen/ssa.go:(*state).call.
// Some defers will be allocated on the stack and some on the heap.
// All defers are logically part of the stack, so write barriers to
// initialize them are not required. All defers must be manually scanned,
// and for heap defers, marked.
type _defer struct {
started bool
heap bool
// openDefer indicates that this _defer is for a frame with open-coded
// defers. We have only one defer record for the entire frame (which may
// currently have 0, 1, or more defers active).
openDefer bool
sp uintptr // sp at time of defer
pc uintptr // pc at time of defer
fn func() // can be nil for open-coded defers
_panic *_panic // panic that is running defer
link *_defer // next defer on G; can point to either heap or stack!
// If openDefer is true, the fields below record values about the stack
// frame and associated function that has the open-coded defer(s). sp
// above will be the sp for the frame, and pc will be address of the
// deferreturn call in the function.
fd unsafe.Pointer // funcdata for the function associated with the frame
varp uintptr // value of varp for the stack frame
// framepc is the current pc associated with the stack frame. Together,
// with sp above (which is the sp associated with the stack frame),
// framepc/sp can be used as pc/sp pair to continue a stack trace via
// gentraceback().
framepc uintptr
}
// A _panic holds information about an active panic.
//
// A _panic value must only ever live on the stack.
//
// The argp and link fields are stack pointers, but don't need special
// handling during stack growth: because they are pointer-typed and
// _panic values only live on the stack, regular stack pointer
// adjustment takes care of them.
type _panic struct {
argp unsafe.Pointer // pointer to arguments of deferred call run during panic; cannot move - known to liblink
arg any // argument to panic
link *_panic // link to earlier panic
pc uintptr // where to return to in runtime if this panic is bypassed
sp unsafe.Pointer // where to return to in runtime if this panic is bypassed
recovered bool // whether this panic is over
aborted bool // the panic was aborted
goexit bool
}
// ancestorInfo records details of where a goroutine was started.
type ancestorInfo struct {
pcs []uintptr // pcs from the stack of this goroutine
goid uint64 // goroutine id of this goroutine; original goroutine possibly dead
gopc uintptr // pc of go statement that created this goroutine
}
const (
_TraceRuntimeFrames = 1 << iota // include frames for internal runtime functions.
_TraceTrap // the initial PC, SP are from a trap, not a return PC from a call
_TraceJumpStack // if traceback is on a systemstack, resume trace at g that called into it
)
// The maximum number of frames we print for a traceback
const _TracebackMaxFrames = 100
// A waitReason explains why a goroutine has been stopped.
// See gopark. Do not re-use waitReasons, add new ones.
type waitReason uint8
const (
waitReasonZero waitReason = iota // ""
waitReasonGCAssistMarking // "GC assist marking"
waitReasonIOWait // "IO wait"
waitReasonChanReceiveNilChan // "chan receive (nil chan)"
waitReasonChanSendNilChan // "chan send (nil chan)"
waitReasonDumpingHeap // "dumping heap"
waitReasonGarbageCollection // "garbage collection"
waitReasonGarbageCollectionScan // "garbage collection scan"
waitReasonPanicWait // "panicwait"
waitReasonSelect // "select"
waitReasonSelectNoCases // "select (no cases)"
waitReasonGCAssistWait // "GC assist wait"
waitReasonGCSweepWait // "GC sweep wait"
waitReasonGCScavengeWait // "GC scavenge wait"
waitReasonChanReceive // "chan receive"
waitReasonChanSend // "chan send"
waitReasonFinalizerWait // "finalizer wait"
waitReasonForceGCIdle // "force gc (idle)"
waitReasonSemacquire // "semacquire"
waitReasonSleep // "sleep"
waitReasonSyncCondWait // "sync.Cond.Wait"
waitReasonSyncMutexLock // "sync.Mutex.Lock"
waitReasonSyncRWMutexRLock // "sync.RWMutex.RLock"
waitReasonSyncRWMutexLock // "sync.RWMutex.Lock"
waitReasonTraceReaderBlocked // "trace reader (blocked)"
waitReasonWaitForGCCycle // "wait for GC cycle"
waitReasonGCWorkerIdle // "GC worker (idle)"
waitReasonGCWorkerActive // "GC worker (active)"
waitReasonPreempted // "preempted"
waitReasonDebugCall // "debug call"
waitReasonGCMarkTermination // "GC mark termination"
waitReasonStoppingTheWorld // "stopping the world"
)
var waitReasonStrings = [...]string{
waitReasonZero: "",
waitReasonGCAssistMarking: "GC assist marking",
waitReasonIOWait: "IO wait",
waitReasonChanReceiveNilChan: "chan receive (nil chan)",
waitReasonChanSendNilChan: "chan send (nil chan)",
waitReasonDumpingHeap: "dumping heap",
waitReasonGarbageCollection: "garbage collection",
waitReasonGarbageCollectionScan: "garbage collection scan",
waitReasonPanicWait: "panicwait",
waitReasonSelect: "select",
waitReasonSelectNoCases: "select (no cases)",
waitReasonGCAssistWait: "GC assist wait",
waitReasonGCSweepWait: "GC sweep wait",
waitReasonGCScavengeWait: "GC scavenge wait",
waitReasonChanReceive: "chan receive",
waitReasonChanSend: "chan send",
waitReasonFinalizerWait: "finalizer wait",
waitReasonForceGCIdle: "force gc (idle)",
waitReasonSemacquire: "semacquire",
waitReasonSleep: "sleep",
waitReasonSyncCondWait: "sync.Cond.Wait",
waitReasonSyncMutexLock: "sync.Mutex.Lock",
waitReasonSyncRWMutexRLock: "sync.RWMutex.RLock",
waitReasonSyncRWMutexLock: "sync.RWMutex.Lock",
waitReasonTraceReaderBlocked: "trace reader (blocked)",
waitReasonWaitForGCCycle: "wait for GC cycle",
waitReasonGCWorkerIdle: "GC worker (idle)",
waitReasonGCWorkerActive: "GC worker (active)",
waitReasonPreempted: "preempted",
waitReasonDebugCall: "debug call",
waitReasonGCMarkTermination: "GC mark termination",
waitReasonStoppingTheWorld: "stopping the world",
}
func (w waitReason) String() string {
if w >= waitReason(len(waitReasonStrings)) {
return "unknown wait reason"
}
return waitReasonStrings[w]
}
func (w waitReason) isMutexWait() bool {
return w == waitReasonSyncMutexLock ||
w == waitReasonSyncRWMutexRLock ||
w == waitReasonSyncRWMutexLock
}
var (
allm *m
gomaxprocs int32
ncpu int32
forcegc forcegcstate
sched schedt
newprocs int32
// allpLock protects P-less reads and size changes of allp, idlepMask,
// and timerpMask, and all writes to allp.
allpLock mutex
// len(allp) == gomaxprocs; may change at safe points, otherwise
// immutable.
allp []*p
// Bitmask of Ps in _Pidle list, one bit per P. Reads and writes must
// be atomic. Length may change at safe points.
//
// Each P must update only its own bit. In order to maintain
// consistency, a P going idle must update the idle mask simultaneously with
// updates to the idle P list under the sched.lock, otherwise a racing
// pidleget may clear the mask before pidleput sets the mask,
// corrupting the bitmap.
//
// N.B., procresize takes ownership of all Ps in stopTheWorldWithSema.
idlepMask pMask
// Bitmask of Ps that may have a timer, one bit per P. Reads and writes
// must be atomic. Length may change at safe points.
timerpMask pMask
// Pool of GC parked background workers. Entries are type
// *gcBgMarkWorkerNode.
gcBgMarkWorkerPool lfstack
// Total number of gcBgMarkWorker goroutines. Protected by worldsema.
gcBgMarkWorkerCount int32
// Information about what cpu features are available.
// Packages outside the runtime should not use these
// as they are not an external api.
// Set on startup in asm_{386,amd64}.s
processorVersionInfo uint32
isIntel bool
goarm uint8 // set by cmd/link on arm systems
)
// Set by the linker so the runtime can determine the buildmode.
var (
islibrary bool // -buildmode=c-shared
isarchive bool // -buildmode=c-archive
)
// Must agree with internal/buildcfg.FramePointerEnabled.
const framepointer_enabled = GOARCH == "amd64" || GOARCH == "arm64"
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import _ "unsafe" // for go:linkname
//go:linkname boring_runtime_arg0 crypto/internal/boring.runtime_arg0
func boring_runtime_arg0() string {
// On Windows, argslice is not set, and it's too much work to find argv0.
if len(argslice) == 0 {
return ""
}
return argslice[0]
}
//go:linkname fipstls_runtime_arg0 crypto/internal/boring/fipstls.runtime_arg0
func fipstls_runtime_arg0() string { return boring_runtime_arg0() }
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"runtime/internal/atomic"
)
// This is a copy of sync/rwmutex.go rewritten to work in the runtime.
// A rwmutex is a reader/writer mutual exclusion lock.
// The lock can be held by an arbitrary number of readers or a single writer.
// This is a variant of sync.RWMutex, for the runtime package.
// Like mutex, rwmutex blocks the calling M.
// It does not interact with the goroutine scheduler.
type rwmutex struct {
rLock mutex // protects readers, readerPass, writer
readers muintptr // list of pending readers
readerPass uint32 // number of pending readers to skip readers list
wLock mutex // serializes writers
writer muintptr // pending writer waiting for completing readers
readerCount atomic.Int32 // number of pending readers
readerWait atomic.Int32 // number of departing readers
}
const rwmutexMaxReaders = 1 << 30
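// rlock and lock below coordinate through readerCount: a pending writer
// subtracts rwmutexMaxReaders, driving the count negative, which readers
// detect. A hypothetical demo of just that encoding (using this file's
// atomic import):
func readerCountDemo() {
var readerCount atomic.Int32
readerCount.Add(3) // three active readers
// The writer announces itself; r recovers how many readers are still active.
r := readerCount.Add(-rwmutexMaxReaders) + rwmutexMaxReaders // r == 3
// A late reader sees a negative count and knows a writer is pending.
pending := readerCount.Add(1) < 0 // true
_, _ = r, pending
}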
// rlock locks rw for reading.
func (rw *rwmutex) rlock() {
// The reader must not be allowed to lose its P or else other
// things blocking on the lock may consume all of the Ps and
// deadlock (issue #20903). Alternatively, we could drop the P
// while sleeping.
acquirem()
if rw.readerCount.Add(1) < 0 {
// A writer is pending. Park on the reader queue.
systemstack(func() {
lockWithRank(&rw.rLock, lockRankRwmutexR)
if rw.readerPass > 0 {
// Writer finished.
rw.readerPass -= 1
unlock(&rw.rLock)
} else {
// Queue this reader to be woken by
// the writer.
m := getg().m
m.schedlink = rw.readers
rw.readers.set(m)
unlock(&rw.rLock)
notesleep(&m.park)
noteclear(&m.park)
}
})
}
}
// runlock undoes a single rlock call on rw.
func (rw *rwmutex) runlock() {
if r := rw.readerCount.Add(-1); r < 0 {
if r+1 == 0 || r+1 == -rwmutexMaxReaders {
throw("runlock of unlocked rwmutex")
}
// A writer is pending.
if rw.readerWait.Add(-1) == 0 {
// The last reader unblocks the writer.
lockWithRank(&rw.rLock, lockRankRwmutexR)
w := rw.writer.ptr()
if w != nil {
notewakeup(&w.park)
}
unlock(&rw.rLock)
}
}
releasem(getg().m)
}
// lock locks rw for writing.
func (rw *rwmutex) lock() {
// Resolve competition with other writers and stick to our P.
lockWithRank(&rw.wLock, lockRankRwmutexW)
m := getg().m
// Announce that there is a pending writer.
r := rw.readerCount.Add(-rwmutexMaxReaders) + rwmutexMaxReaders
// Wait for any active readers to complete.
lockWithRank(&rw.rLock, lockRankRwmutexR)
if r != 0 && rw.readerWait.Add(r) != 0 {
// Wait for reader to wake us up.
systemstack(func() {
rw.writer.set(m)
unlock(&rw.rLock)
notesleep(&m.park)
noteclear(&m.park)
})
} else {
unlock(&rw.rLock)
}
}
// unlock unlocks rw for writing.
func (rw *rwmutex) unlock() {
// Announce to readers that there is no active writer.
r := rw.readerCount.Add(rwmutexMaxReaders)
if r >= rwmutexMaxReaders {
throw("unlock of unlocked rwmutex")
}
// Unblock blocked readers.
lockWithRank(&rw.rLock, lockRankRwmutexR)
for rw.readers.ptr() != nil {
reader := rw.readers.ptr()
rw.readers = reader.schedlink
reader.schedlink.set(nil)
notewakeup(&reader.park)
r -= 1
}
// If r > 0, there are pending readers that aren't on the
// queue. Tell them to skip waiting.
rw.readerPass += uint32(r)
unlock(&rw.rLock)
// Allow other writers to proceed.
unlock(&rw.wLock)
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
// This file contains the implementation of Go select statements.
import (
"internal/abi"
"unsafe"
)
const debugSelect = false
// Select case descriptor.
// Known to compiler.
// Changes here must also be made in src/cmd/compile/internal/walk/select.go's scasetype.
type scase struct {
c *hchan // chan
elem unsafe.Pointer // data element
}
var (
chansendpc = abi.FuncPCABIInternal(chansend)
chanrecvpc = abi.FuncPCABIInternal(chanrecv)
)
func selectsetpc(pc *uintptr) {
*pc = getcallerpc()
}
func sellock(scases []scase, lockorder []uint16) {
var c *hchan
for _, o := range lockorder {
c0 := scases[o].c
if c0 != c {
c = c0
lock(&c.lock)
}
}
}
func selunlock(scases []scase, lockorder []uint16) {
// We must be very careful here to not touch sel after we have unlocked
// the last lock, because sel can be freed right after the last unlock.
// Consider the following situation.
// First M calls runtime·park() in runtime·selectgo() passing the sel.
// Once runtime·park() has unlocked the last lock, another M makes
// the G that calls select runnable again and schedules it for execution.
// When the G runs on another M, it locks all the locks and frees sel.
// Now if the first M touches sel, it will access freed memory.
for i := len(lockorder) - 1; i >= 0; i-- {
c := scases[lockorder[i]].c
if i > 0 && c == scases[lockorder[i-1]].c {
continue // will unlock it on the next iteration
}
unlock(&c.lock)
}
}
func selparkcommit(gp *g, _ unsafe.Pointer) bool {
// There are unlocked sudogs that point into gp's stack. Stack
// copying must lock the channels of those sudogs.
// Set activeStackChans here instead of before we try parking
// because we could self-deadlock in stack growth on a
// channel lock.
gp.activeStackChans = true
// Mark that it's safe for stack shrinking to occur now,
// because any thread acquiring this G's stack for shrinking
// is guaranteed to observe activeStackChans after this store.
gp.parkingOnChan.Store(false)
// Make sure we unlock after setting activeStackChans and
// unsetting parkingOnChan. The moment we unlock any of the
// channel locks we risk gp getting readied by a channel operation
// and so gp could continue running before everything before the
// unlock is visible (even to gp itself).
// This must not access gp's stack (see gopark). In
// particular, it must not access the *hselect. That's okay,
// because by the time this is called, gp.waiting has all
// channels in lock order.
var lastc *hchan
for sg := gp.waiting; sg != nil; sg = sg.waitlink {
if sg.c != lastc && lastc != nil {
// As soon as we unlock the channel, fields in
// any sudog with that channel may change,
// including c and waitlink. Since multiple
// sudogs may have the same channel, we unlock
// only after we've passed the last instance
// of a channel.
unlock(&lastc.lock)
}
lastc = sg.c
}
if lastc != nil {
unlock(&lastc.lock)
}
return true
}
func block() {
gopark(nil, nil, waitReasonSelectNoCases, traceEvGoStop, 1) // forever
}
// selectgo implements the select statement.
//
// cas0 points to an array of type [ncases]scase, and order0 points to
// an array of type [2*ncases]uint16 where ncases must be <= 65536.
// Both reside on the goroutine's stack (regardless of any escaping in
// selectgo).
//
// For race detector builds, pc0 points to an array of type
// [ncases]uintptr (also on the stack); for other builds, it's set to
// nil.
//
// selectgo returns the index of the chosen scase, which matches the
// ordinal position of its respective select{recv,send,default} call.
// Also, if the chosen scase was a receive operation, it reports whether
// a value was received.
func selectgo(cas0 *scase, order0 *uint16, pc0 *uintptr, nsends, nrecvs int, block bool) (int, bool) {
if debugSelect {
print("select: cas0=", cas0, "\n")
}
// NOTE: In order to maintain a lean stack size, the number of scases
// is capped at 65536.
cas1 := (*[1 << 16]scase)(unsafe.Pointer(cas0))
order1 := (*[1 << 17]uint16)(unsafe.Pointer(order0))
ncases := nsends + nrecvs
scases := cas1[:ncases:ncases]
pollorder := order1[:ncases:ncases]
lockorder := order1[ncases:][:ncases:ncases]
// NOTE: pollorder/lockorder's underlying array was not zero-initialized by compiler.
// Even when raceenabled is true, there might be select
// statements in packages compiled without -race (e.g.,
// ensureSigM in runtime/signal_unix.go).
var pcs []uintptr
if raceenabled && pc0 != nil {
pc1 := (*[1 << 16]uintptr)(unsafe.Pointer(pc0))
pcs = pc1[:ncases:ncases]
}
casePC := func(casi int) uintptr {
if pcs == nil {
return 0
}
return pcs[casi]
}
var t0 int64
if blockprofilerate > 0 {
t0 = cputicks()
}
// The compiler rewrites selects that statically have
// only 0 or 1 cases plus default into simpler constructs.
// The only way we can end up with such small sel.ncase
// values here is for a larger select in which most channels
// have been nilled out. The general code handles those
// cases correctly, and they are rare enough not to bother
// optimizing (and needing to test).
// generate permuted order
norder := 0
for i := range scases {
cas := &scases[i]
// Omit cases without channels from the poll and lock orders.
if cas.c == nil {
cas.elem = nil // allow GC
continue
}
j := fastrandn(uint32(norder + 1))
pollorder[norder] = pollorder[j]
pollorder[j] = uint16(i)
norder++
}
pollorder = pollorder[:norder]
lockorder = lockorder[:norder]
// sort the cases by Hchan address to get the locking order.
// simple heap sort, to guarantee n log n time and constant stack footprint.
for i := range lockorder {
j := i
// Start with the pollorder to permute cases on the same channel.
c := scases[pollorder[i]].c
for j > 0 && scases[lockorder[(j-1)/2]].c.sortkey() < c.sortkey() {
k := (j - 1) / 2
lockorder[j] = lockorder[k]
j = k
}
lockorder[j] = pollorder[i]
}
for i := len(lockorder) - 1; i >= 0; i-- {
o := lockorder[i]
c := scases[o].c
lockorder[i] = lockorder[0]
j := 0
for {
k := j*2 + 1
if k >= i {
break
}
if k+1 < i && scases[lockorder[k]].c.sortkey() < scases[lockorder[k+1]].c.sortkey() {
k++
}
if c.sortkey() < scases[lockorder[k]].c.sortkey() {
lockorder[j] = lockorder[k]
j = k
continue
}
break
}
lockorder[j] = o
}
if debugSelect {
for i := 0; i+1 < len(lockorder); i++ {
if scases[lockorder[i]].c.sortkey() > scases[lockorder[i+1]].c.sortkey() {
print("i=", i, " x=", lockorder[i], " y=", lockorder[i+1], "\n")
throw("select: broken sort")
}
}
}
// lock all the channels involved in the select
sellock(scases, lockorder)
var (
gp *g
sg *sudog
c *hchan
k *scase
sglist *sudog
sgnext *sudog
qp unsafe.Pointer
nextp **sudog
)
// pass 1 - look for something already waiting
var casi int
var cas *scase
var caseSuccess bool
var caseReleaseTime int64 = -1
var recvOK bool
for _, casei := range pollorder {
casi = int(casei)
cas = &scases[casi]
c = cas.c
if casi >= nsends {
sg = c.sendq.dequeue()
if sg != nil {
goto recv
}
if c.qcount > 0 {
goto bufrecv
}
if c.closed != 0 {
goto rclose
}
} else {
if raceenabled {
racereadpc(c.raceaddr(), casePC(casi), chansendpc)
}
if c.closed != 0 {
goto sclose
}
sg = c.recvq.dequeue()
if sg != nil {
goto send
}
if c.qcount < c.dataqsiz {
goto bufsend
}
}
}
if !block {
selunlock(scases, lockorder)
casi = -1
goto retc
}
// pass 2 - enqueue on all chans
gp = getg()
if gp.waiting != nil {
throw("gp.waiting != nil")
}
nextp = &gp.waiting
for _, casei := range lockorder {
casi = int(casei)
cas = &scases[casi]
c = cas.c
sg := acquireSudog()
sg.g = gp
sg.isSelect = true
// No stack splits between assigning elem and enqueuing
// sg on gp.waiting where copystack can find it.
sg.elem = cas.elem
sg.releasetime = 0
if t0 != 0 {
sg.releasetime = -1
}
sg.c = c
// Construct waiting list in lock order.
*nextp = sg
nextp = &sg.waitlink
if casi < nsends {
c.sendq.enqueue(sg)
} else {
c.recvq.enqueue(sg)
}
}
// wait for someone to wake us up
gp.param = nil
// Signal to anyone trying to shrink our stack that we're about
// to park on a channel. The window between when this G's status
// changes and when we set gp.activeStackChans is not safe for
// stack shrinking.
gp.parkingOnChan.Store(true)
gopark(selparkcommit, nil, waitReasonSelect, traceEvGoBlockSelect, 1)
gp.activeStackChans = false
sellock(scases, lockorder)
gp.selectDone.Store(0)
sg = (*sudog)(gp.param)
gp.param = nil
// pass 3 - dequeue from unsuccessful chans
// otherwise they stack up on quiet channels
// record the successful case, if any.
// We singly-linked up the SudoGs in lock order.
casi = -1
cas = nil
caseSuccess = false
sglist = gp.waiting
// Clear all elem before unlinking from gp.waiting.
for sg1 := gp.waiting; sg1 != nil; sg1 = sg1.waitlink {
sg1.isSelect = false
sg1.elem = nil
sg1.c = nil
}
gp.waiting = nil
for _, casei := range lockorder {
k = &scases[casei]
if sg == sglist {
// sg has already been dequeued by the G that woke us up.
casi = int(casei)
cas = k
caseSuccess = sglist.success
if sglist.releasetime > 0 {
caseReleaseTime = sglist.releasetime
}
} else {
c = k.c
if int(casei) < nsends {
c.sendq.dequeueSudoG(sglist)
} else {
c.recvq.dequeueSudoG(sglist)
}
}
sgnext = sglist.waitlink
sglist.waitlink = nil
releaseSudog(sglist)
sglist = sgnext
}
if cas == nil {
throw("selectgo: bad wakeup")
}
c = cas.c
if debugSelect {
print("wait-return: cas0=", cas0, " c=", c, " cas=", cas, " send=", casi < nsends, "\n")
}
if casi < nsends {
if !caseSuccess {
goto sclose
}
} else {
recvOK = caseSuccess
}
if raceenabled {
if casi < nsends {
raceReadObjectPC(c.elemtype, cas.elem, casePC(casi), chansendpc)
} else if cas.elem != nil {
raceWriteObjectPC(c.elemtype, cas.elem, casePC(casi), chanrecvpc)
}
}
if msanenabled {
if casi < nsends {
msanread(cas.elem, c.elemtype.size)
} else if cas.elem != nil {
msanwrite(cas.elem, c.elemtype.size)
}
}
if asanenabled {
if casi < nsends {
asanread(cas.elem, c.elemtype.size)
} else if cas.elem != nil {
asanwrite(cas.elem, c.elemtype.size)
}
}
selunlock(scases, lockorder)
goto retc
bufrecv:
// can receive from buffer
if raceenabled {
if cas.elem != nil {
raceWriteObjectPC(c.elemtype, cas.elem, casePC(casi), chanrecvpc)
}
racenotify(c, c.recvx, nil)
}
if msanenabled && cas.elem != nil {
msanwrite(cas.elem, c.elemtype.size)
}
if asanenabled && cas.elem != nil {
asanwrite(cas.elem, c.elemtype.size)
}
recvOK = true
qp = chanbuf(c, c.recvx)
if cas.elem != nil {
typedmemmove(c.elemtype, cas.elem, qp)
}
typedmemclr(c.elemtype, qp)
c.recvx++
if c.recvx == c.dataqsiz {
c.recvx = 0
}
c.qcount--
selunlock(scases, lockorder)
goto retc
bufsend:
// can send to buffer
if raceenabled {
racenotify(c, c.sendx, nil)
raceReadObjectPC(c.elemtype, cas.elem, casePC(casi), chansendpc)
}
if msanenabled {
msanread(cas.elem, c.elemtype.size)
}
if asanenabled {
asanread(cas.elem, c.elemtype.size)
}
typedmemmove(c.elemtype, chanbuf(c, c.sendx), cas.elem)
c.sendx++
if c.sendx == c.dataqsiz {
c.sendx = 0
}
c.qcount++
selunlock(scases, lockorder)
goto retc
recv:
// can receive from sleeping sender (sg)
recv(c, sg, cas.elem, func() { selunlock(scases, lockorder) }, 2)
if debugSelect {
print("syncrecv: cas0=", cas0, " c=", c, "\n")
}
recvOK = true
goto retc
rclose:
// read at end of closed channel
selunlock(scases, lockorder)
recvOK = false
if cas.elem != nil {
typedmemclr(c.elemtype, cas.elem)
}
if raceenabled {
raceacquire(c.raceaddr())
}
goto retc
send:
// can send to a sleeping receiver (sg)
if raceenabled {
raceReadObjectPC(c.elemtype, cas.elem, casePC(casi), chansendpc)
}
if msanenabled {
msanread(cas.elem, c.elemtype.size)
}
if asanenabled {
asanread(cas.elem, c.elemtype.size)
}
send(c, sg, cas.elem, func() { selunlock(scases, lockorder) }, 2)
if debugSelect {
print("syncsend: cas0=", cas0, " c=", c, "\n")
}
goto retc
retc:
if caseReleaseTime > 0 {
blockevent(caseReleaseTime-t0, 1)
}
return casi, recvOK
sclose:
// send on closed channel
selunlock(scases, lockorder)
panic(plainError("send on closed channel"))
}
func (c *hchan) sortkey() uintptr {
return uintptr(unsafe.Pointer(c))
}
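// selectgo builds pollorder with an incremental Fisher-Yates shuffle and
// lockorder with an in-place heapsort over the channel sort keys. A
// standalone sketch of those two steps over plain keys (hypothetical
// helper; assumes a math/rand import, standing in for fastrandn):
func orderCases(keys []uintptr) (pollorder, lockorder []uint16) {
n := len(keys)
pollorder = make([]uint16, n)
for i := 0; i < n; i++ {
j := rand.Intn(i + 1) // fastrandn(uint32(i+1)) in selectgo
pollorder[i] = pollorder[j]
pollorder[j] = uint16(i)
}
lockorder = make([]uint16, n)
// Sift up: build a max-heap of indices ordered by key.
for i := 0; i < n; i++ {
j := i
c := keys[pollorder[i]]
for j > 0 && keys[lockorder[(j-1)/2]] < c {
k := (j - 1) / 2
lockorder[j] = lockorder[k]
j = k
}
lockorder[j] = pollorder[i]
}
// Repeatedly pop the maximum to the end, yielding ascending order.
for i := n - 1; i >= 0; i-- {
o := lockorder[i]
c := keys[o]
lockorder[i] = lockorder[0]
j := 0
for {
k := j*2 + 1
if k >= i {
break
}
if k+1 < i && keys[lockorder[k]] < keys[lockorder[k+1]] {
k++
}
if c < keys[lockorder[k]] {
lockorder[j] = lockorder[k]
j = k
continue
}
break
}
lockorder[j] = o
}
return pollorder, lockorder
}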
// A runtimeSelect is a single case passed to rselect.
// This must match ../reflect/value.go:/runtimeSelect
type runtimeSelect struct {
dir selectDir
typ unsafe.Pointer // channel type (not used here)
ch *hchan // channel
val unsafe.Pointer // ptr to data (SendDir) or ptr to receive buffer (RecvDir)
}
// These values must match ../reflect/value.go:/SelectDir.
type selectDir int
const (
_ selectDir = iota
selectSend // case Chan <- Send
selectRecv // case <-Chan:
selectDefault // default
)
//go:linkname reflect_rselect reflect.rselect
func reflect_rselect(cases []runtimeSelect) (int, bool) {
if len(cases) == 0 {
block()
}
sel := make([]scase, len(cases))
orig := make([]int, len(cases))
nsends, nrecvs := 0, 0
dflt := -1
for i, rc := range cases {
var j int
switch rc.dir {
case selectDefault:
dflt = i
continue
case selectSend:
j = nsends
nsends++
case selectRecv:
nrecvs++
j = len(cases) - nrecvs
}
sel[j] = scase{c: rc.ch, elem: rc.val}
orig[j] = i
}
// Only a default case.
if nsends+nrecvs == 0 {
return dflt, false
}
// Compact sel and orig if necessary.
if nsends+nrecvs < len(cases) {
copy(sel[nsends:], sel[len(cases)-nrecvs:])
copy(orig[nsends:], orig[len(cases)-nrecvs:])
}
order := make([]uint16, 2*(nsends+nrecvs))
var pc0 *uintptr
if raceenabled {
pcs := make([]uintptr, nsends+nrecvs)
for i := range pcs {
selectsetpc(&pcs[i])
}
pc0 = &pcs[0]
}
chosen, recvOK := selectgo(&sel[0], &order[0], pc0, nsends, nrecvs, dflt == -1)
// Translate chosen back to caller's ordering.
if chosen < 0 {
chosen = dflt
} else {
chosen = orig[chosen]
}
return chosen, recvOK
}
func (q *waitq) dequeueSudoG(sgp *sudog) {
x := sgp.prev
y := sgp.next
if x != nil {
if y != nil {
// middle of queue
x.next = y
y.prev = x
sgp.next = nil
sgp.prev = nil
return
}
// end of queue
x.next = nil
q.last = x
sgp.prev = nil
return
}
if y != nil {
// start of queue
y.prev = nil
q.first = y
sgp.next = nil
return
}
// x==y==nil. Either sgp is the only element in the queue,
// or it has already been removed. Use q.first to disambiguate.
if q.first == sgp {
q.first = nil
q.last = nil
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Semaphore implementation exposed to Go.
// Intended use is to provide a sleep and wakeup
// primitive that can be used in the contended case
// of other synchronization primitives.
// Thus it targets the same goal as Linux's futex,
// but it has much simpler semantics.
//
// That is, don't think of these as semaphores.
// Think of them as a way to implement sleep and wakeup
// such that every sleep is paired with a single wakeup,
// even if, due to races, the wakeup happens before the sleep.
//
// See Mullender and Cox, ``Semaphores in Plan 9,''
// https://swtch.com/semaphore.pdf
package runtime
import (
"internal/cpu"
"runtime/internal/atomic"
"unsafe"
)
// Asynchronous semaphore for sync.Mutex.
// A semaRoot holds a balanced tree of sudog with distinct addresses (s.elem).
// Each of those sudog may in turn point (through s.waitlink) to a list
// of other sudogs waiting on the same address.
// The operations on the inner lists of sudogs with the same address
// are all O(1). The scanning of the top-level semaRoot list is O(log n),
// where n is the number of distinct addresses with goroutines blocked
// on them that hash to the given semaRoot.
// See golang.org/issue/17953 for a program that worked badly
// before we introduced the second level of list, and
// BenchmarkSemTable/OneAddrCollision/* for a benchmark that exercises this.
type semaRoot struct {
lock mutex
treap *sudog // root of balanced tree of unique waiters.
nwait atomic.Uint32 // Number of waiters. Read w/o the lock.
}
var semtable semTable
// Prime to not correlate with any user patterns.
const semTabSize = 251
type semTable [semTabSize]struct {
root semaRoot
pad [cpu.CacheLinePadSize - unsafe.Sizeof(semaRoot{})]byte
}
func (t *semTable) rootFor(addr *uint32) *semaRoot {
return &t[(uintptr(unsafe.Pointer(addr))>>3)%semTabSize].root
}
//go:linkname sync_runtime_Semacquire sync.runtime_Semacquire
func sync_runtime_Semacquire(addr *uint32) {
semacquire1(addr, false, semaBlockProfile, 0, waitReasonSemacquire)
}
//go:linkname poll_runtime_Semacquire internal/poll.runtime_Semacquire
func poll_runtime_Semacquire(addr *uint32) {
semacquire1(addr, false, semaBlockProfile, 0, waitReasonSemacquire)
}
//go:linkname sync_runtime_Semrelease sync.runtime_Semrelease
func sync_runtime_Semrelease(addr *uint32, handoff bool, skipframes int) {
semrelease1(addr, handoff, skipframes)
}
//go:linkname sync_runtime_SemacquireMutex sync.runtime_SemacquireMutex
func sync_runtime_SemacquireMutex(addr *uint32, lifo bool, skipframes int) {
semacquire1(addr, lifo, semaBlockProfile|semaMutexProfile, skipframes, waitReasonSyncMutexLock)
}
//go:linkname sync_runtime_SemacquireRWMutexR sync.runtime_SemacquireRWMutexR
func sync_runtime_SemacquireRWMutexR(addr *uint32, lifo bool, skipframes int) {
semacquire1(addr, lifo, semaBlockProfile|semaMutexProfile, skipframes, waitReasonSyncRWMutexRLock)
}
//go:linkname sync_runtime_SemacquireRWMutex sync.runtime_SemacquireRWMutex
func sync_runtime_SemacquireRWMutex(addr *uint32, lifo bool, skipframes int) {
semacquire1(addr, lifo, semaBlockProfile|semaMutexProfile, skipframes, waitReasonSyncRWMutexLock)
}
//go:linkname poll_runtime_Semrelease internal/poll.runtime_Semrelease
func poll_runtime_Semrelease(addr *uint32) {
semrelease(addr)
}
func readyWithTime(s *sudog, traceskip int) {
if s.releasetime != 0 {
s.releasetime = cputicks()
}
goready(s.g, traceskip)
}
type semaProfileFlags int
const (
semaBlockProfile semaProfileFlags = 1 << iota
semaMutexProfile
)
// Called from runtime.
func semacquire(addr *uint32) {
semacquire1(addr, false, 0, 0, waitReasonSemacquire)
}
func semacquire1(addr *uint32, lifo bool, profile semaProfileFlags, skipframes int, reason waitReason) {
gp := getg()
if gp != gp.m.curg {
throw("semacquire not on the G stack")
}
// Easy case.
if cansemacquire(addr) {
return
}
// Harder case:
// increment waiter count
// try cansemacquire one more time, return if succeeded
// enqueue itself as a waiter
// sleep
// (waiter descriptor is dequeued by signaler)
s := acquireSudog()
root := semtable.rootFor(addr)
t0 := int64(0)
s.releasetime = 0
s.acquiretime = 0
s.ticket = 0
if profile&semaBlockProfile != 0 && blockprofilerate > 0 {
t0 = cputicks()
s.releasetime = -1
}
if profile&semaMutexProfile != 0 && mutexprofilerate > 0 {
if t0 == 0 {
t0 = cputicks()
}
s.acquiretime = t0
}
for {
lockWithRank(&root.lock, lockRankRoot)
// Add ourselves to nwait to disable "easy case" in semrelease.
root.nwait.Add(1)
// Check cansemacquire to avoid missed wakeup.
if cansemacquire(addr) {
root.nwait.Add(-1)
unlock(&root.lock)
break
}
// Any semrelease after the cansemacquire knows we're waiting
// (we set nwait above), so go to sleep.
root.queue(addr, s, lifo)
goparkunlock(&root.lock, reason, traceEvGoBlockSync, 4+skipframes)
if s.ticket != 0 || cansemacquire(addr) {
break
}
}
if s.releasetime > 0 {
blockevent(s.releasetime-t0, 3+skipframes)
}
releaseSudog(s)
}
func semrelease(addr *uint32) {
semrelease1(addr, false, 0)
}
func semrelease1(addr *uint32, handoff bool, skipframes int) {
root := semtable.rootFor(addr)
atomic.Xadd(addr, 1)
// Easy case: no waiters?
// This check must happen after the xadd, to avoid a missed wakeup
// (see loop in semacquire).
if root.nwait.Load() == 0 {
return
}
// Harder case: search for a waiter and wake it.
lockWithRank(&root.lock, lockRankRoot)
if root.nwait.Load() == 0 {
// The count is already consumed by another goroutine,
// so no need to wake up another goroutine.
unlock(&root.lock)
return
}
s, t0 := root.dequeue(addr)
if s != nil {
root.nwait.Add(-1)
}
unlock(&root.lock)
if s != nil { // May be slow or even yield, so unlock first
acquiretime := s.acquiretime
if acquiretime != 0 {
mutexevent(t0-acquiretime, 3+skipframes)
}
if s.ticket != 0 {
throw("corrupted semaphore ticket")
}
if handoff && cansemacquire(addr) {
s.ticket = 1
}
readyWithTime(s, 5+skipframes)
if s.ticket == 1 && getg().m.locks == 0 {
// Direct G handoff
// readyWithTime has added the waiter G as runnext in the
// current P; we now call the scheduler so that we start running
// the waiter G immediately.
// Note that waiter inherits our time slice: this is desirable
// to avoid having a highly contended semaphore hog the P
// indefinitely. goyield is like Gosched, but it emits a
// "preempted" trace event instead and, more importantly, puts
// the current G on the local runq instead of the global one.
// We only do this in the starving regime (handoff=true), as in
// the non-starving case it is possible for a different waiter
// to acquire the semaphore while we are yielding/scheduling,
// and this would be wasteful. We wait instead to enter starving
// regime, and then we start to do direct handoffs of ticket and
// P.
// See issue 33747 for discussion.
goyield()
}
}
}
func cansemacquire(addr *uint32) bool {
for {
v := atomic.Load(addr)
if v == 0 {
return false
}
if atomic.Cas(addr, v, v-1) {
return true
}
}
}
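// cansemacquire is a decrement-if-positive CAS loop. The same fast path,
// written against this file's atomic.Uint32 type as a hypothetical
// standalone helper (the runtime pairs it with the nwait handshake in
// semacquire1/semrelease1 above rather than spinning on failure):
func tryDecIfPositive(sema *atomic.Uint32) bool {
for {
v := sema.Load()
if v == 0 {
return false // nothing to take
}
if sema.CompareAndSwap(v, v-1) {
return true // took one unit
}
}
}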
// queue adds s to the blocked goroutines in semaRoot.
func (root *semaRoot) queue(addr *uint32, s *sudog, lifo bool) {
s.g = getg()
s.elem = unsafe.Pointer(addr)
s.next = nil
s.prev = nil
var last *sudog
pt := &root.treap
for t := *pt; t != nil; t = *pt {
if t.elem == unsafe.Pointer(addr) {
// Already have addr in list.
if lifo {
// Substitute s in t's place in treap.
*pt = s
s.ticket = t.ticket
s.acquiretime = t.acquiretime
s.parent = t.parent
s.prev = t.prev
s.next = t.next
if s.prev != nil {
s.prev.parent = s
}
if s.next != nil {
s.next.parent = s
}
// Add t first in s's wait list.
s.waitlink = t
s.waittail = t.waittail
if s.waittail == nil {
s.waittail = t
}
t.parent = nil
t.prev = nil
t.next = nil
t.waittail = nil
} else {
// Add s to end of t's wait list.
if t.waittail == nil {
t.waitlink = s
} else {
t.waittail.waitlink = s
}
t.waittail = s
s.waitlink = nil
}
return
}
last = t
if uintptr(unsafe.Pointer(addr)) < uintptr(t.elem) {
pt = &t.prev
} else {
pt = &t.next
}
}
// Add s as new leaf in tree of unique addrs.
// The balanced tree is a treap using ticket as the random heap priority.
// That is, it is a binary tree ordered according to the elem addresses,
// but then among the space of possible binary trees respecting those
// addresses, it is kept balanced on average by maintaining a heap ordering
// on the ticket: s.ticket <= both s.prev.ticket and s.next.ticket.
// https://en.wikipedia.org/wiki/Treap
// https://faculty.washington.edu/aragon/pubs/rst89.pdf
//
// s.ticket is compared with zero in a couple of places, so set the lowest bit.
// It will not affect treap's quality noticeably.
s.ticket = fastrand() | 1
s.parent = last
*pt = s
// Rotate up into tree according to ticket (priority).
for s.parent != nil && s.parent.ticket > s.ticket {
if s.parent.prev == s {
root.rotateRight(s.parent)
} else {
if s.parent.next != s {
panic("semaRoot queue")
}
root.rotateLeft(s.parent)
}
}
}
// dequeue searches for and finds the first goroutine
// in semaRoot blocked on addr.
// If the sudog was being profiled, dequeue returns the time
// at which it was woken up as now. Otherwise now is 0.
func (root *semaRoot) dequeue(addr *uint32) (found *sudog, now int64) {
ps := &root.treap
s := *ps
for ; s != nil; s = *ps {
if s.elem == unsafe.Pointer(addr) {
goto Found
}
if uintptr(unsafe.Pointer(addr)) < uintptr(s.elem) {
ps = &s.prev
} else {
ps = &s.next
}
}
return nil, 0
Found:
now = int64(0)
if s.acquiretime != 0 {
now = cputicks()
}
if t := s.waitlink; t != nil {
// Substitute t, also waiting on addr, for s in root tree of unique addrs.
*ps = t
t.ticket = s.ticket
t.parent = s.parent
t.prev = s.prev
if t.prev != nil {
t.prev.parent = t
}
t.next = s.next
if t.next != nil {
t.next.parent = t
}
if t.waitlink != nil {
t.waittail = s.waittail
} else {
t.waittail = nil
}
t.acquiretime = now
s.waitlink = nil
s.waittail = nil
} else {
// Rotate s down to be leaf of tree for removal, respecting priorities.
for s.next != nil || s.prev != nil {
if s.next == nil || s.prev != nil && s.prev.ticket < s.next.ticket {
root.rotateRight(s)
} else {
root.rotateLeft(s)
}
}
// Remove s, now a leaf.
if s.parent != nil {
if s.parent.prev == s {
s.parent.prev = nil
} else {
s.parent.next = nil
}
} else {
root.treap = nil
}
}
s.parent = nil
s.elem = nil
s.next = nil
s.prev = nil
s.ticket = 0
return s, now
}
// rotateLeft rotates the tree rooted at node x,
// turning (x a (y b c)) into (y (x a b) c).
func (root *semaRoot) rotateLeft(x *sudog) {
// p -> (x a (y b c))
p := x.parent
y := x.next
b := y.prev
y.prev = x
x.parent = y
x.next = b
if b != nil {
b.parent = x
}
y.parent = p
if p == nil {
root.treap = y
} else if p.prev == x {
p.prev = y
} else {
if p.next != x {
throw("semaRoot rotateLeft")
}
p.next = y
}
}
// rotateRight rotates the tree rooted at node y,
// turning (y (x a b) c) into (x a (y b c)).
func (root *semaRoot) rotateRight(y *sudog) {
// p -> (y (x a b) c)
p := y.parent
x := y.prev
b := x.next
x.next = y
y.parent = x
y.prev = b
if b != nil {
b.parent = y
}
x.parent = p
if p == nil {
root.treap = x
} else if p.prev == y {
p.prev = x
} else {
if p.next != y {
throw("semaRoot rotateRight")
}
p.next = x
}
}
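// queue, rotateLeft, and rotateRight above maintain a treap: a BST on elem
// addresses that is min-heap-ordered on random tickets. A compact
// standalone sketch of treap insertion with the same rotate-up rule
// (hypothetical types; assumes a math/rand import for tickets):
type treapNode struct {
key uintptr
prio uint32
left, right *treapNode
}
func treapInsert(t *treapNode, key uintptr) *treapNode {
if t == nil {
return &treapNode{key: key, prio: rand.Uint32() | 1} // low bit set, as in queue
}
if key < t.key {
t.left = treapInsert(t.left, key)
if t.left.prio < t.prio { // min-heap on prio: child must not beat parent
t = rotateRightNode(t)
}
} else {
t.right = treapInsert(t.right, key)
if t.right.prio < t.prio {
t = rotateLeftNode(t)
}
}
return t
}
// rotateLeftNode turns (t a (y b c)) into (y (t a b) c).
func rotateLeftNode(t *treapNode) *treapNode {
y := t.right
t.right = y.left
y.left = t
return y
}
// rotateRightNode turns (t (x a b) c) into (x a (t b c)).
func rotateRightNode(t *treapNode) *treapNode {
x := t.left
t.left = x.right
x.right = t
return x
}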
// notifyList is a ticket-based notification list used to implement sync.Cond.
//
// It must be kept in sync with the sync package.
type notifyList struct {
// wait is the ticket number of the next waiter. It is atomically
// incremented outside the lock.
wait atomic.Uint32
// notify is the ticket number of the next waiter to be notified. It can
// be read outside the lock, but is only written to with lock held.
//
// Both wait & notify can wrap around, and such cases will be correctly
// handled as long as their "unwrapped" difference is bounded by 2^31.
// For this not to be the case, we'd need to have 2^31+ goroutines
// blocked on the same condvar, which is currently not possible.
notify uint32
// List of parked waiters.
lock mutex
head *sudog
tail *sudog
}
// less checks if a < b, treating a and b as running counts that may overflow
// the 32-bit range, and whose "unwrapped" difference is always less than 2^31.
func less(a, b uint32) bool {
return int32(a-b) < 0
}
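// A few worked cases (hypothetical helper) showing why the two's-complement
// subtraction in less stays correct across uint32 wraparound:
func lessWrapDemo() {
_ = less(1, 2) // true: int32(1-2) == -1
_ = less(^uint32(0), 0) // true: 0xFFFFFFFF is treated as just before 0
_ = less(0, ^uint32(0)) // false: int32(0-0xFFFFFFFF) == 1
}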
// notifyListAdd adds the caller to a notify list such that it can receive
// notifications. The caller must eventually call notifyListWait to wait for
// such a notification, passing the returned ticket number.
//
//go:linkname notifyListAdd sync.runtime_notifyListAdd
func notifyListAdd(l *notifyList) uint32 {
// This may be called concurrently, for example, when called from
// sync.Cond.Wait while holding a RWMutex in read mode.
return l.wait.Add(1) - 1
}
// notifyListWait waits for a notification. If one has been sent since
// notifyListAdd was called, it returns immediately. Otherwise, it blocks.
//
//go:linkname notifyListWait sync.runtime_notifyListWait
func notifyListWait(l *notifyList, t uint32) {
lockWithRank(&l.lock, lockRankNotifyList)
// Return right away if this ticket has already been notified.
if less(t, l.notify) {
unlock(&l.lock)
return
}
// Enqueue itself.
s := acquireSudog()
s.g = getg()
s.ticket = t
s.releasetime = 0
t0 := int64(0)
if blockprofilerate > 0 {
t0 = cputicks()
s.releasetime = -1
}
if l.tail == nil {
l.head = s
} else {
l.tail.next = s
}
l.tail = s
goparkunlock(&l.lock, waitReasonSyncCondWait, traceEvGoBlockCond, 3)
if t0 != 0 {
blockevent(s.releasetime-t0, 2)
}
releaseSudog(s)
}
// notifyListNotifyAll notifies all entries in the list.
//
//go:linkname notifyListNotifyAll sync.runtime_notifyListNotifyAll
func notifyListNotifyAll(l *notifyList) {
// Fast-path: if there are no new waiters since the last notification
// we don't need to acquire the lock.
if l.wait.Load() == atomic.Load(&l.notify) {
return
}
// Pull the list out into a local variable, waiters will be readied
// outside the lock.
lockWithRank(&l.lock, lockRankNotifyList)
s := l.head
l.head = nil
l.tail = nil
// Update the next ticket to be notified. We can set it to the current
// value of wait because any previous waiters are already in the list
// or will notice that they have already been notified when trying to
// add themselves to the list.
atomic.Store(&l.notify, l.wait.Load())
unlock(&l.lock)
// Go through the local list and ready all waiters.
for s != nil {
next := s.next
s.next = nil
readyWithTime(s, 4)
s = next
}
}
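// notifyListNotifyAll detaches the whole waiter list under the lock and
// readies everyone only after unlocking, keeping the critical section
// short. A userland sketch of that detach-then-wake pattern (hypothetical
// types; assumes a sync import, with channel close standing in for
// readyWithTime):
type demoWaiter struct {
next *demoWaiter
ready chan struct{}
}
type demoWaiterList struct {
mu sync.Mutex
head, tail *demoWaiter
}
func (l *demoWaiterList) drainAndWake() {
l.mu.Lock()
s := l.head
l.head, l.tail = nil, nil
l.mu.Unlock()
// Wakeups happen outside the lock, exactly as in notifyListNotifyAll.
for s != nil {
next := s.next
s.next = nil
close(s.ready)
s = next
}
}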
// notifyListNotifyOne notifies one entry in the list.
//
//go:linkname notifyListNotifyOne sync.runtime_notifyListNotifyOne
func notifyListNotifyOne(l *notifyList) {
// Fast-path: if there are no new waiters since the last notification
// we don't need to acquire the lock at all.
if l.wait.Load() == atomic.Load(&l.notify) {
return
}
lockWithRank(&l.lock, lockRankNotifyList)
// Re-check under the lock if we need to do anything.
t := l.notify
if t == l.wait.Load() {
unlock(&l.lock)
return
}
// Update the next notify ticket number.
atomic.Store(&l.notify, t+1)
// Try to find the g that needs to be notified.
// If it hasn't made it to the list yet we won't find it,
// but it won't park itself once it sees the new notify number.
//
// This scan looks linear but essentially always stops quickly.
// Because goroutines queue separately from taking ticket numbers,
// there may be minor reorderings in the list, but we
// expect the g we're looking for to be near the front.
// The g has others in front of it on the list only to the
// extent that it lost the race, so the iteration will not
// be too long. This applies even when the g is missing:
// it hasn't yet gotten to sleep and has lost the race to
// the (few) other g's that we find on the list.
for p, s := (*sudog)(nil), l.head; s != nil; p, s = s, s.next {
if s.ticket == t {
n := s.next
if p != nil {
p.next = n
} else {
l.head = n
}
if n == nil {
l.tail = p
}
unlock(&l.lock)
s.next = nil
readyWithTime(s, 4)
return
}
}
unlock(&l.lock)
}
//go:linkname notifyListCheck sync.runtime_notifyListCheck
func notifyListCheck(sz uintptr) {
if sz != unsafe.Sizeof(notifyList{}) {
print("runtime: bad notifyList size - sync=", sz, " runtime=", unsafe.Sizeof(notifyList{}), "\n")
throw("bad notifyList size")
}
}
//go:linkname sync_nanotime sync.runtime_nanotime
func sync_nanotime() int64 {
return nanotime()
}
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build amd64 && (darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris)
package runtime
import (
"internal/abi"
"internal/goarch"
"unsafe"
)
func dumpregs(c *sigctxt) {
print("rax ", hex(c.rax()), "\n")
print("rbx ", hex(c.rbx()), "\n")
print("rcx ", hex(c.rcx()), "\n")
print("rdx ", hex(c.rdx()), "\n")
print("rdi ", hex(c.rdi()), "\n")
print("rsi ", hex(c.rsi()), "\n")
print("rbp ", hex(c.rbp()), "\n")
print("rsp ", hex(c.rsp()), "\n")
print("r8 ", hex(c.r8()), "\n")
print("r9 ", hex(c.r9()), "\n")
print("r10 ", hex(c.r10()), "\n")
print("r11 ", hex(c.r11()), "\n")
print("r12 ", hex(c.r12()), "\n")
print("r13 ", hex(c.r13()), "\n")
print("r14 ", hex(c.r14()), "\n")
print("r15 ", hex(c.r15()), "\n")
print("rip ", hex(c.rip()), "\n")
print("rflags ", hex(c.rflags()), "\n")
print("cs ", hex(c.cs()), "\n")
print("fs ", hex(c.fs()), "\n")
print("gs ", hex(c.gs()), "\n")
}
//go:nosplit
//go:nowritebarrierrec
func (c *sigctxt) sigpc() uintptr { return uintptr(c.rip()) }
func (c *sigctxt) setsigpc(x uint64) { c.set_rip(x) }
func (c *sigctxt) sigsp() uintptr { return uintptr(c.rsp()) }
func (c *sigctxt) siglr() uintptr { return 0 }
func (c *sigctxt) fault() uintptr { return uintptr(c.sigaddr()) }
// preparePanic sets up the stack to look like a call to sigpanic.
func (c *sigctxt) preparePanic(sig uint32, gp *g) {
// Work around Leopard bug that doesn't set FPE_INTDIV.
// Look at instruction to see if it is a divide.
// Not necessary in Snow Leopard (si_code will be != 0).
if GOOS == "darwin" && sig == _SIGFPE && gp.sigcode0 == 0 {
pc := (*[4]byte)(unsafe.Pointer(gp.sigpc))
i := 0
if pc[i]&0xF0 == 0x40 { // 64-bit REX prefix
i++
} else if pc[i] == 0x66 { // 16-bit instruction prefix
i++
}
if pc[i] == 0xF6 || pc[i] == 0xF7 {
gp.sigcode0 = _FPE_INTDIV
}
}
pc := uintptr(c.rip())
sp := uintptr(c.rsp())
// In case we are panicking from external code, we need to initialize
// Go special registers. We inject sigpanic0 (instead of sigpanic),
// which takes care of that.
if shouldPushSigpanic(gp, pc, *(*uintptr)(unsafe.Pointer(sp))) {
c.pushCall(abi.FuncPCABI0(sigpanic0), pc)
} else {
// Not safe to push the call. Just clobber the frame.
c.set_rip(uint64(abi.FuncPCABI0(sigpanic0)))
}
}
func (c *sigctxt) pushCall(targetPC, resumePC uintptr) {
// Make it look like we called target at resumePC.
sp := uintptr(c.rsp())
sp -= goarch.PtrSize
*(*uintptr)(unsafe.Pointer(sp)) = resumePC
c.set_rsp(uint64(sp))
c.set_rip(uint64(targetPC))
}
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"internal/goarch"
"unsafe"
)
type sigctxt struct {
info *siginfo
ctxt unsafe.Pointer
}
//go:nosplit
//go:nowritebarrierrec
func (c *sigctxt) regs() *sigcontext {
return (*sigcontext)(unsafe.Pointer(&(*ucontext)(c.ctxt).uc_mcontext))
}
func (c *sigctxt) rax() uint64 { return c.regs().rax }
func (c *sigctxt) rbx() uint64 { return c.regs().rbx }
func (c *sigctxt) rcx() uint64 { return c.regs().rcx }
func (c *sigctxt) rdx() uint64 { return c.regs().rdx }
func (c *sigctxt) rdi() uint64 { return c.regs().rdi }
func (c *sigctxt) rsi() uint64 { return c.regs().rsi }
func (c *sigctxt) rbp() uint64 { return c.regs().rbp }
func (c *sigctxt) rsp() uint64 { return c.regs().rsp }
func (c *sigctxt) r8() uint64 { return c.regs().r8 }
func (c *sigctxt) r9() uint64 { return c.regs().r9 }
func (c *sigctxt) r10() uint64 { return c.regs().r10 }
func (c *sigctxt) r11() uint64 { return c.regs().r11 }
func (c *sigctxt) r12() uint64 { return c.regs().r12 }
func (c *sigctxt) r13() uint64 { return c.regs().r13 }
func (c *sigctxt) r14() uint64 { return c.regs().r14 }
func (c *sigctxt) r15() uint64 { return c.regs().r15 }
//go:nosplit
//go:nowritebarrierrec
func (c *sigctxt) rip() uint64 { return c.regs().rip }
func (c *sigctxt) rflags() uint64 { return c.regs().eflags }
func (c *sigctxt) cs() uint64 { return uint64(c.regs().cs) }
func (c *sigctxt) fs() uint64 { return uint64(c.regs().fs) }
func (c *sigctxt) gs() uint64 { return uint64(c.regs().gs) }
func (c *sigctxt) sigcode() uint64 { return uint64(c.info.si_code) }
func (c *sigctxt) sigaddr() uint64 { return c.info.si_addr }
func (c *sigctxt) set_rip(x uint64) { c.regs().rip = x }
func (c *sigctxt) set_rsp(x uint64) { c.regs().rsp = x }
func (c *sigctxt) set_sigcode(x uint64) { c.info.si_code = int32(x) }
func (c *sigctxt) set_sigaddr(x uint64) {
*(*uintptr)(add(unsafe.Pointer(c.info), 2*goarch.PtrSize)) = uintptr(x)
}
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix
package runtime
import (
"internal/abi"
"runtime/internal/atomic"
"runtime/internal/sys"
"unsafe"
)
// sigTabT is the type of an entry in the global sigtable array.
// sigtable is inherently system dependent, and appears in OS-specific files,
// but sigTabT is the same for all Unixy systems.
// The sigtable array is indexed by a system signal number to get the flags
// and printable name of each signal.
type sigTabT struct {
flags int32
name string
}
//go:linkname os_sigpipe os.sigpipe
func os_sigpipe() {
systemstack(sigpipe)
}
func signame(sig uint32) string {
if sig >= uint32(len(sigtable)) {
return ""
}
return sigtable[sig].name
}
const (
_SIG_DFL uintptr = 0
_SIG_IGN uintptr = 1
)
// sigPreempt is the signal used for non-cooperative preemption.
//
// There's no good way to choose this signal, but there are some
// heuristics:
//
// 1. It should be a signal that's passed-through by debuggers by
// default. On Linux, this is SIGALRM, SIGURG, SIGCHLD, SIGIO,
// SIGVTALRM, SIGPROF, and SIGWINCH, plus some glibc-internal signals.
//
// 2. It shouldn't be used internally by libc in mixed Go/C binaries
// because libc may assume it's the only thing that can handle these
// signals. For example SIGCANCEL or SIGSETXID.
//
// 3. It should be a signal that can happen spuriously without
// consequences. For example, SIGALRM is a bad choice because the
// signal handler can't tell if it was caused by the real process
// alarm or not (arguably this means the signal is broken, but I
// digress). SIGUSR1 and SIGUSR2 are also bad because those are often
// used in meaningful ways by applications.
//
// 4. We need to deal with platforms without real-time signals (like
// macOS), so those are out.
//
// We use SIGURG because it meets all of these criteria, is extremely
// unlikely to be used by an application for its "real" meaning (both
// because out-of-band data is basically unused and because SIGURG
// doesn't report which socket has the condition, making it pretty
// useless), and even if it is, the application has to be ready for
// spurious SIGURG. SIGIO wouldn't be a bad choice either, but is more
// likely to be used for real.
const sigPreempt = _SIGURG
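// Because sigPreempt is SIGURG, user code that subscribes to SIGURG via
// os/signal must already tolerate spurious deliveries caused by the
// runtime's own preemption requests. A minimal illustrative sketch (not
// part of the runtime; assumes a Unix GOOS):
//
//	c := make(chan os.Signal, 1)
//	signal.Notify(c, syscall.SIGURG)
//	go func() {
//		for range c {
//			// May fire spuriously due to runtime preemption.
//		}
//	}()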
// Stores the signal handlers registered before Go installed its own.
// These signal handlers will be invoked in cases where Go doesn't want to
// handle a particular signal (e.g., signal occurred on a non-Go thread).
// See sigfwdgo for more information on when the signals are forwarded.
//
// This is read by the signal handler; accesses should use
// atomic.Loaduintptr and atomic.Storeuintptr.
var fwdSig [_NSIG]uintptr
// handlingSig is indexed by signal number and is non-zero if we are
// currently handling the signal. Or, to put it another way, whether
// the signal handler is currently set to the Go signal handler or not.
// This is uint32 rather than bool so that we can use atomic instructions.
var handlingSig [_NSIG]uint32
// channels for synchronizing signal mask updates with the signal mask
// thread
var (
disableSigChan chan uint32
enableSigChan chan uint32
maskUpdatedChan chan struct{}
)
func init() {
// _NSIG is the number of signals on this operating system.
// sigtable should describe what to do for all the possible signals.
if len(sigtable) != _NSIG {
print("runtime: len(sigtable)=", len(sigtable), " _NSIG=", _NSIG, "\n")
throw("bad sigtable len")
}
}
var signalsOK bool
// Initialize signals.
// Called by libpreinit so runtime may not be initialized.
//
//go:nosplit
//go:nowritebarrierrec
func initsig(preinit bool) {
if !preinit {
// It's now OK for signal handlers to run.
signalsOK = true
}
// For c-archive/c-shared this is called by libpreinit with
// preinit == true.
if (isarchive || islibrary) && !preinit {
return
}
for i := uint32(0); i < _NSIG; i++ {
t := &sigtable[i]
if t.flags == 0 || t.flags&_SigDefault != 0 {
continue
}
// We don't need to use atomic operations here because
// there shouldn't be any other goroutines running yet.
fwdSig[i] = getsig(i)
if !sigInstallGoHandler(i) {
// Even if we are not installing a signal handler,
// set SA_ONSTACK if necessary.
if fwdSig[i] != _SIG_DFL && fwdSig[i] != _SIG_IGN {
setsigstack(i)
} else if fwdSig[i] == _SIG_IGN {
sigInitIgnored(i)
}
continue
}
handlingSig[i] = 1
setsig(i, abi.FuncPCABIInternal(sighandler))
}
}
//go:nosplit
//go:nowritebarrierrec
func sigInstallGoHandler(sig uint32) bool {
// For some signals, we respect an inherited SIG_IGN handler
// rather than insist on installing our own default handler.
// Even these signals can be fetched using the os/signal package.
switch sig {
case _SIGHUP, _SIGINT:
if atomic.Loaduintptr(&fwdSig[sig]) == _SIG_IGN {
return false
}
}
if (GOOS == "linux" || GOOS == "android") && !iscgo && sig == sigPerThreadSyscall {
// sigPerThreadSyscall is the same signal used by glibc for
// per-thread syscalls on Linux. We use it for the same purpose
// in non-cgo binaries.
return true
}
t := &sigtable[sig]
if t.flags&_SigSetStack != 0 {
return false
}
// When built using c-archive or c-shared, only install signal
// handlers for synchronous signals and SIGPIPE and sigPreempt.
if (isarchive || islibrary) && t.flags&_SigPanic == 0 && sig != _SIGPIPE && sig != sigPreempt {
return false
}
return true
}
// sigenable enables the Go signal handler to catch the signal sig.
// It is only called while holding the os/signal.handlers lock,
// via os/signal.enableSignal and signal_enable.
func sigenable(sig uint32) {
if sig >= uint32(len(sigtable)) {
return
}
// SIGPROF is handled specially for profiling.
if sig == _SIGPROF {
return
}
t := &sigtable[sig]
if t.flags&_SigNotify != 0 {
ensureSigM()
enableSigChan <- sig
<-maskUpdatedChan
if atomic.Cas(&handlingSig[sig], 0, 1) {
atomic.Storeuintptr(&fwdSig[sig], getsig(sig))
setsig(sig, abi.FuncPCABIInternal(sighandler))
}
}
}
// sigdisable disables the Go signal handler for the signal sig.
// It is only called while holding the os/signal.handlers lock,
// via os/signal.disableSignal and signal_disable.
func sigdisable(sig uint32) {
if sig >= uint32(len(sigtable)) {
return
}
// SIGPROF is handled specially for profiling.
if sig == _SIGPROF {
return
}
t := &sigtable[sig]
if t.flags&_SigNotify != 0 {
ensureSigM()
disableSigChan <- sig
<-maskUpdatedChan
// If initsig does not install a signal handler for a
// signal, then to go back to the state before Notify
// we should remove the one we installed.
if !sigInstallGoHandler(sig) {
atomic.Store(&handlingSig[sig], 0)
setsig(sig, atomic.Loaduintptr(&fwdSig[sig]))
}
}
}
// sigignore ignores the signal sig.
// It is only called while holding the os/signal.handlers lock,
// via os/signal.ignoreSignal and signal_ignore.
func sigignore(sig uint32) {
if sig >= uint32(len(sigtable)) {
return
}
// SIGPROF is handled specially for profiling.
if sig == _SIGPROF {
return
}
t := &sigtable[sig]
if t.flags&_SigNotify != 0 {
atomic.Store(&handlingSig[sig], 0)
setsig(sig, _SIG_IGN)
}
}
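// sigenable, sigdisable, and sigignore are driven by the public os/signal
// API. A minimal illustrative sketch of the user-level calls that reach
// them (the exact plumbing runs through os/signal's linknamed helpers):
//
//	c := make(chan os.Signal, 1)
//	signal.Notify(c, syscall.SIGHUP) // signal_enable  -> sigenable
//	signal.Stop(c)                   // signal_disable -> sigdisable
//	signal.Ignore(syscall.SIGHUP)    // signal_ignore  -> sigignore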
// clearSignalHandlers clears all signal handlers that are not ignored
// back to the default. This is called by the child after a fork, so that
// we can enable the signal mask for the exec without worrying about
// running a signal handler in the child.
//
//go:nosplit
//go:nowritebarrierrec
func clearSignalHandlers() {
for i := uint32(0); i < _NSIG; i++ {
if atomic.Load(&handlingSig[i]) != 0 {
setsig(i, _SIG_DFL)
}
}
}
// setProcessCPUProfilerTimer is called when the profiling timer changes.
// It is called with prof.signalLock held. hz is the new timer rate, and is 0 if
// profiling is being disabled. Enable or disable the signal as
// required for -buildmode=c-archive.
func setProcessCPUProfilerTimer(hz int32) {
if hz != 0 {
// Enable the Go signal handler if not enabled.
if atomic.Cas(&handlingSig[_SIGPROF], 0, 1) {
h := getsig(_SIGPROF)
// If no signal handler was installed before, then we record
// _SIG_IGN here. When we turn off profiling (below) we'll start
// ignoring SIGPROF signals. We do this, rather than change
// to SIG_DFL, because there may be a pending SIGPROF
// signal that has not yet been delivered to some other thread.
// If we change to SIG_DFL when turning off profiling, the
// program will crash when that SIGPROF is delivered. We assume
// that programs that use profiling don't want to crash on a
// stray SIGPROF. See issue 19320.
// We do the change here instead of when turning off profiling,
// because there we may race with a signal handler running
// concurrently, in particular, sigfwdgo may observe _SIG_DFL and
// die. See issue 43828.
if h == _SIG_DFL {
h = _SIG_IGN
}
atomic.Storeuintptr(&fwdSig[_SIGPROF], h)
setsig(_SIGPROF, abi.FuncPCABIInternal(sighandler))
}
var it itimerval
it.it_interval.tv_sec = 0
it.it_interval.set_usec(1000000 / hz)
it.it_value = it.it_interval
setitimer(_ITIMER_PROF, &it, nil)
} else {
setitimer(_ITIMER_PROF, &itimerval{}, nil)
// If the Go signal handler should be disabled by default,
// switch back to the signal handler that was installed
// when we enabled profiling. We don't try to handle the case
// of a program that changes the SIGPROF handler while Go
// profiling is enabled.
if !sigInstallGoHandler(_SIGPROF) {
if atomic.Cas(&handlingSig[_SIGPROF], 1, 0) {
h := atomic.Loaduintptr(&fwdSig[_SIGPROF])
setsig(_SIGPROF, h)
}
}
}
}
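// As an illustrative note (not an exhaustive description of the pprof call
// chain), enabling CPU profiling from user code ultimately supplies hz here,
// so the interval timer fires roughly every 1000000/hz microseconds:
//
//	runtime.SetCPUProfileRate(100) // SIGPROF approximately every 10ms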
// setThreadCPUProfilerHz makes any thread-specific changes required to
// implement profiling at a rate of hz.
// No changes required on Unix systems when using setitimer.
func setThreadCPUProfilerHz(hz int32) {
getg().m.profilehz = hz
}
func sigpipe() {
if signal_ignored(_SIGPIPE) || sigsend(_SIGPIPE) {
return
}
dieFromSignal(_SIGPIPE)
}
// doSigPreempt handles a preemption signal on gp.
func doSigPreempt(gp *g, ctxt *sigctxt) {
// Check if this G wants to be preempted and is safe to
// preempt.
if wantAsyncPreempt(gp) {
if ok, newpc := isAsyncSafePoint(gp, ctxt.sigpc(), ctxt.sigsp(), ctxt.siglr()); ok {
// Adjust the PC and inject a call to asyncPreempt.
ctxt.pushCall(abi.FuncPCABI0(asyncPreempt), newpc)
}
}
// Acknowledge the preemption.
gp.m.preemptGen.Add(1)
gp.m.signalPending.Store(0)
if GOOS == "darwin" || GOOS == "ios" {
pendingPreemptSignals.Add(-1)
}
}
const preemptMSupported = true
// preemptM sends a preemption request to mp. This request may be
// handled asynchronously and may be coalesced with other requests to
// the M. When the request is received, if the running G or P are
// marked for preemption and the goroutine is at an asynchronous
// safe-point, it will preempt the goroutine. It always atomically
// increments mp.preemptGen after handling a preemption request.
func preemptM(mp *m) {
// On Darwin, don't try to preempt threads during exec.
// Issue #41702.
if GOOS == "darwin" || GOOS == "ios" {
execLock.rlock()
}
if mp.signalPending.CompareAndSwap(0, 1) {
if GOOS == "darwin" || GOOS == "ios" {
pendingPreemptSignals.Add(1)
}
// If multiple threads are preempting the same M, it may send many
// signals to the same M such that it hardly makes progress, causing
// a live-lock problem. Apparently this could happen on darwin. See
// issue #37741.
// Only send a signal if there isn't already one pending.
signalM(mp, sigPreempt)
}
if GOOS == "darwin" || GOOS == "ios" {
execLock.runlock()
}
}
// sigFetchG fetches the value of G safely when running in a signal handler.
// On some architectures, the g value may be clobbered when running in a VDSO.
// See issue #32912.
//
//go:nosplit
func sigFetchG(c *sigctxt) *g {
switch GOARCH {
case "arm", "arm64", "ppc64", "ppc64le", "riscv64", "s390x":
if !iscgo && inVDSOPage(c.sigpc()) {
// When using cgo, we save the g on TLS and load it from there
// in sigtramp. Just use that.
// Otherwise, before making a VDSO call we save the g to the
// bottom of the signal stack. Fetch from there.
// TODO: in efence mode, stack is sysAlloc'd, so this wouldn't
// work.
sp := getcallersp()
s := spanOf(sp)
if s != nil && s.state.get() == mSpanManual && s.base() < sp && sp < s.limit {
gp := *(**g)(unsafe.Pointer(s.base()))
return gp
}
return nil
}
}
return getg()
}
// sigtrampgo is called from the signal handler function, sigtramp,
// written in assembly code.
// This is called by the signal handler, and the world may be stopped.
//
// It must be nosplit because getg() is still the G that was running
// (if any) when the signal was delivered, but it's (usually) called
// on the gsignal stack. Until this switches the G to gsignal, the
// stack bounds check won't work.
//
//go:nosplit
//go:nowritebarrierrec
func sigtrampgo(sig uint32, info *siginfo, ctx unsafe.Pointer) {
if sigfwdgo(sig, info, ctx) {
return
}
c := &sigctxt{info, ctx}
gp := sigFetchG(c)
setg(gp)
if gp == nil {
if sig == _SIGPROF {
// Some platforms (Linux) have per-thread timers, which we use in
// combination with the process-wide timer. Avoid double-counting.
if validSIGPROF(nil, c) {
sigprofNonGoPC(c.sigpc())
}
return
}
if sig == sigPreempt && preemptMSupported && debug.asyncpreemptoff == 0 {
// This is probably a signal from preemptM sent
// while executing Go code but received while
// executing non-Go code.
// We got past sigfwdgo, so we know that there is
// no non-Go signal handler for sigPreempt.
// The default behavior for sigPreempt is to ignore
// the signal, so badsignal will be a no-op anyway.
if GOOS == "darwin" || GOOS == "ios" {
pendingPreemptSignals.Add(-1)
}
return
}
c.fixsigcode(sig)
badsignal(uintptr(sig), c)
return
}
setg(gp.m.gsignal)
// If some non-Go code called sigaltstack, adjust.
var gsignalStack gsignalStack
setStack := adjustSignalStack(sig, gp.m, &gsignalStack)
if setStack {
gp.m.gsignal.stktopsp = getcallersp()
}
if gp.stackguard0 == stackFork {
signalDuringFork(sig)
}
c.fixsigcode(sig)
sighandler(sig, info, ctx, gp)
setg(gp)
if setStack {
restoreGsignalStack(&gsignalStack)
}
}
// If the signal handler receives a SIGPROF signal on a non-Go thread,
// it tries to collect a traceback into sigprofCallers.
// sigprofCallersUse is set to non-zero while sigprofCallers holds a traceback.
var sigprofCallers cgoCallers
var sigprofCallersUse uint32
// sigprofNonGo is called if we receive a SIGPROF signal on a non-Go thread,
// and the signal handler collected a stack trace in sigprofCallers.
// When this is called, sigprofCallersUse will be non-zero.
// g is nil, and what we can do is very limited.
//
// It is called from the signal handling functions written in assembly code that
// are active for cgo programs, cgoSigtramp and sigprofNonGoWrapper, which have
// not verified that the SIGPROF delivery corresponds to the best available
// profiling source for this thread.
//
//go:nosplit
//go:nowritebarrierrec
func sigprofNonGo(sig uint32, info *siginfo, ctx unsafe.Pointer) {
if prof.hz.Load() != 0 {
c := &sigctxt{info, ctx}
// Some platforms (Linux) have per-thread timers, which we use in
// combination with the process-wide timer. Avoid double-counting.
if validSIGPROF(nil, c) {
n := 0
for n < len(sigprofCallers) && sigprofCallers[n] != 0 {
n++
}
cpuprof.addNonGo(sigprofCallers[:n])
}
}
atomic.Store(&sigprofCallersUse, 0)
}
// sigprofNonGoPC is called when a profiling signal arrived on a
// non-Go thread and we have a single PC value, not a stack trace.
// g is nil, and what we can do is very limited.
//
//go:nosplit
//go:nowritebarrierrec
func sigprofNonGoPC(pc uintptr) {
if prof.hz.Load() != 0 {
stk := []uintptr{
pc,
abi.FuncPCABIInternal(_ExternalCode) + sys.PCQuantum,
}
cpuprof.addNonGo(stk)
}
}
// adjustSignalStack adjusts the current stack guard based on the
// stack pointer that is actually in use while handling a signal.
// We do this in case some non-Go code called sigaltstack.
// This reports whether the stack was adjusted, and if so stores the old
// signal stack in *gsigstack.
//
//go:nosplit
func adjustSignalStack(sig uint32, mp *m, gsigStack *gsignalStack) bool {
sp := uintptr(unsafe.Pointer(&sig))
if sp >= mp.gsignal.stack.lo && sp < mp.gsignal.stack.hi {
return false
}
var st stackt
sigaltstack(nil, &st)
stsp := uintptr(unsafe.Pointer(st.ss_sp))
if st.ss_flags&_SS_DISABLE == 0 && sp >= stsp && sp < stsp+st.ss_size {
setGsignalStack(&st, gsigStack)
return true
}
if sp >= mp.g0.stack.lo && sp < mp.g0.stack.hi {
// The signal was delivered on the g0 stack.
// This can happen when linked with C code
// using the thread sanitizer, which collects
// signals then delivers them itself by calling
// the signal handler directly when C code,
// including C code called via cgo, calls a
// TSAN-intercepted function such as malloc.
//
// We check this condition last as g0.stack.lo
// may be not very accurate (see mstart).
st := stackt{ss_size: mp.g0.stack.hi - mp.g0.stack.lo}
setSignalstackSP(&st, mp.g0.stack.lo)
setGsignalStack(&st, gsigStack)
return true
}
// sp is not within gsignal stack, g0 stack, or sigaltstack. Bad.
setg(nil)
needm()
if st.ss_flags&_SS_DISABLE != 0 {
noSignalStack(sig)
} else {
sigNotOnStack(sig)
}
dropm()
return false
}
// crashing is the number of m's we have waited for when implementing
// GOTRACEBACK=crash when a signal is received.
var crashing int32
// testSigtrap and testSigusr1 are used by the runtime tests. If
// non-nil, they are called on SIGTRAP/SIGUSR1 respectively. If one returns true, the
// normal behavior on this signal is suppressed.
var testSigtrap func(info *siginfo, ctxt *sigctxt, gp *g) bool
var testSigusr1 func(gp *g) bool
// sighandler is invoked when a signal occurs. The global g will be
// set to a gsignal goroutine and we will be running on the alternate
// signal stack. The parameter gp will be the value of the global g
// when the signal occurred. The sig, info, and ctxt parameters are
// from the system signal handler: they are the parameters the kernel
// passes to the handler registered via the sigaction system call.
//
// The garbage collector may have stopped the world, so write barriers
// are not allowed.
//
//go:nowritebarrierrec
func sighandler(sig uint32, info *siginfo, ctxt unsafe.Pointer, gp *g) {
// The g executing the signal handler. This is almost always
// mp.gsignal. See delayedSignal for an exception.
gsignal := getg()
mp := gsignal.m
c := &sigctxt{info, ctxt}
// Cgo TSAN (not the Go race detector) intercepts signals and calls the
// signal handler at a later time. When the signal handler is called, the
// memory may have changed, but the signal context remains old. The
// unmatched signal context and memory makes it unsafe to unwind or inspect
// the stack. So we ignore delayed non-fatal signals that will cause a stack
// inspection (profiling signal and preemption signal).
// cgo_yield is only non-nil for TSAN, and is specifically used to trigger
// signal delivery. We use that as an indicator of delayed signals.
// For delayed signals, the handler is called on the g0 stack (see
// adjustSignalStack).
delayedSignal := *cgo_yield != nil && mp != nil && gsignal.stack == mp.g0.stack
if sig == _SIGPROF {
// Some platforms (Linux) have per-thread timers, which we use in
// combination with the process-wide timer. Avoid double-counting.
if !delayedSignal && validSIGPROF(mp, c) {
sigprof(c.sigpc(), c.sigsp(), c.siglr(), gp, mp)
}
return
}
if sig == _SIGTRAP && testSigtrap != nil && testSigtrap(info, (*sigctxt)(noescape(unsafe.Pointer(c))), gp) {
return
}
if sig == _SIGUSR1 && testSigusr1 != nil && testSigusr1(gp) {
return
}
if (GOOS == "linux" || GOOS == "android") && sig == sigPerThreadSyscall {
// sigPerThreadSyscall is the same signal used by glibc for
// per-thread syscalls on Linux. We use it for the same purpose
// in non-cgo binaries. Since this signal is not _SigNotify,
// there is nothing more to do once we run the syscall.
runPerThreadSyscall()
return
}
if sig == sigPreempt && debug.asyncpreemptoff == 0 && !delayedSignal {
// Might be a preemption signal.
doSigPreempt(gp, c)
// Even if this was definitely a preemption signal, it
// may have been coalesced with another signal, so we
// still let it through to the application.
}
flags := int32(_SigThrow)
if sig < uint32(len(sigtable)) {
flags = sigtable[sig].flags
}
if !c.sigFromUser() && flags&_SigPanic != 0 && gp.throwsplit {
// We can't safely sigpanic because it may grow the
// stack. Abort in the signal handler instead.
flags = _SigThrow
}
if isAbortPC(c.sigpc()) {
// On many architectures, the abort function just
// causes a memory fault. Don't turn that into a panic.
flags = _SigThrow
}
if !c.sigFromUser() && flags&_SigPanic != 0 {
// The signal is going to cause a panic.
// Arrange the stack so that it looks like the point
// where the signal occurred made a call to the
// function sigpanic. Then set the PC to sigpanic.
// Have to pass arguments out of band since
// augmenting the stack frame would break
// the unwinding code.
gp.sig = sig
gp.sigcode0 = uintptr(c.sigcode())
gp.sigcode1 = uintptr(c.fault())
gp.sigpc = c.sigpc()
c.preparePanic(sig, gp)
return
}
if c.sigFromUser() || flags&_SigNotify != 0 {
if sigsend(sig) {
return
}
}
if c.sigFromUser() && signal_ignored(sig) {
return
}
if flags&_SigKill != 0 {
dieFromSignal(sig)
}
// _SigThrow means that we should exit now.
// If we get here with _SigPanic, it means that the signal
// was sent to us by a program (c.sigFromUser() is true);
// in that case, if we didn't handle it in sigsend, we exit now.
if flags&(_SigThrow|_SigPanic) == 0 {
return
}
mp.throwing = throwTypeRuntime
mp.caughtsig.set(gp)
if crashing == 0 {
startpanic_m()
}
if sig < uint32(len(sigtable)) {
print(sigtable[sig].name, "\n")
} else {
print("Signal ", sig, "\n")
}
print("PC=", hex(c.sigpc()), " m=", mp.id, " sigcode=", c.sigcode(), "\n")
if mp.incgo && gp == mp.g0 && mp.curg != nil {
print("signal arrived during cgo execution\n")
// Switch to curg so that we get a traceback of the Go code
// leading up to the cgocall, which switched from curg to g0.
gp = mp.curg
}
if sig == _SIGILL || sig == _SIGFPE {
// It would be nice to know how long the instruction is.
// Unfortunately, that's complicated to do in general (mostly for x86
// and s390x, but other archs have non-standard instruction lengths also).
// Opt to print 16 bytes, which covers most instructions.
const maxN = 16
n := uintptr(maxN)
// We have to be careful, though. If we're near the end of
// a page and the following page isn't mapped, we could
// segfault. So make sure we don't straddle a page (even though
// that could lead to printing an incomplete instruction).
// We're assuming here we can read at least the page containing the PC.
// I suppose it is possible that the page is mapped executable but not readable?
pc := c.sigpc()
if n > physPageSize-pc%physPageSize {
n = physPageSize - pc%physPageSize
}
print("instruction bytes:")
b := (*[maxN]byte)(unsafe.Pointer(pc))
for i := uintptr(0); i < n; i++ {
print(" ", hex(b[i]))
}
println()
}
print("\n")
level, _, docrash := gotraceback()
if level > 0 {
goroutineheader(gp)
tracebacktrap(c.sigpc(), c.sigsp(), c.siglr(), gp)
if crashing > 0 && gp != mp.curg && mp.curg != nil && readgstatus(mp.curg)&^_Gscan == _Grunning {
// tracebackothers on original m skipped this one; trace it now.
goroutineheader(mp.curg)
traceback(^uintptr(0), ^uintptr(0), 0, mp.curg)
} else if crashing == 0 {
tracebackothers(gp)
print("\n")
}
dumpregs(c)
}
if docrash {
crashing++
if crashing < mcount()-int32(extraMCount) {
// There are other m's that need to dump their stacks.
// Relay SIGQUIT to the next m by sending it to the current process.
// All m's that have already received SIGQUIT have signal masks blocking
// receipt of any signals, so the SIGQUIT will go to an m that hasn't seen it yet.
// When the last m receives the SIGQUIT, it will fall through to the call to
// crash below. Just in case the relaying gets botched, each m involved in
// the relay sleeps for 5 seconds and then does the crash/exit itself.
// In expected operation, the last m has received the SIGQUIT and run
// crash/exit and the process is gone, all long before any of the
// 5-second sleeps have finished.
print("\n-----\n\n")
raiseproc(_SIGQUIT)
usleep(5 * 1000 * 1000)
}
crash()
}
printDebugLog()
exit(2)
}
// sigpanic turns a synchronous signal into a run-time panic.
// If the signal handler sees a synchronous panic, it arranges the
// stack to look like the function where the signal occurred called
// sigpanic, sets the signal's PC value to sigpanic, and returns from
// the signal handler. The effect is that the program will act as
// though the function that got the signal simply called sigpanic
// instead.
//
// This must NOT be nosplit because the linker doesn't know where
// sigpanic calls can be injected.
//
// The signal handler must not inject a call to sigpanic if
// getg().throwsplit, since sigpanic may need to grow the stack.
//
// This is exported via linkname to assembly in runtime/cgo.
//
//go:linkname sigpanic
func sigpanic() {
gp := getg()
if !canpanic() {
throw("unexpected signal during runtime execution")
}
switch gp.sig {
case _SIGBUS:
if gp.sigcode0 == _BUS_ADRERR && gp.sigcode1 < 0x1000 {
panicmem()
}
// Support runtime/debug.SetPanicOnFault.
if gp.paniconfault {
panicmemAddr(gp.sigcode1)
}
print("unexpected fault address ", hex(gp.sigcode1), "\n")
throw("fault")
case _SIGSEGV:
if (gp.sigcode0 == 0 || gp.sigcode0 == _SEGV_MAPERR || gp.sigcode0 == _SEGV_ACCERR) && gp.sigcode1 < 0x1000 {
panicmem()
}
// Support runtime/debug.SetPanicOnFault.
if gp.paniconfault {
panicmemAddr(gp.sigcode1)
}
if inUserArenaChunk(gp.sigcode1) {
// We could check that the arena chunk is explicitly set to fault,
// but the fact that we faulted on accessing it is enough to prove
// that it is.
print("accessed data from freed user arena ", hex(gp.sigcode1), "\n")
} else {
print("unexpected fault address ", hex(gp.sigcode1), "\n")
}
throw("fault")
case _SIGFPE:
switch gp.sigcode0 {
case _FPE_INTDIV:
panicdivide()
case _FPE_INTOVF:
panicoverflow()
}
panicfloat()
}
if gp.sig >= uint32(len(sigtable)) {
// can't happen: we looked up gp.sig in sigtable to decide to call sigpanic
throw("unexpected signal value")
}
panic(errorString(sigtable[gp.sig].name))
}
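// The effect of sigpanic is visible to user code as an ordinary panic that
// can be recovered. An illustrative sketch: a nil dereference faults at an
// address below 0x1000, so the _SIGSEGV case above calls panicmem:
//
//	defer func() {
//		if r := recover(); r != nil {
//			fmt.Println(r) // runtime error: invalid memory address or nil pointer dereference
//		}
//	}()
//	var p *int
//	_ = *p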
// dieFromSignal kills the program with a signal.
// This provides the expected exit status for the shell.
// This is only called with fatal signals expected to kill the process.
//
//go:nosplit
//go:nowritebarrierrec
func dieFromSignal(sig uint32) {
unblocksig(sig)
// Mark the signal as unhandled to ensure it is forwarded.
atomic.Store(&handlingSig[sig], 0)
raise(sig)
// That should have killed us. On some systems, though, raise
// sends the signal to the whole process rather than to just
// the current thread, which means that the signal may not yet
// have been delivered. Give other threads a chance to run and
// pick up the signal.
osyield()
osyield()
osyield()
// If that didn't work, try _SIG_DFL.
setsig(sig, _SIG_DFL)
raise(sig)
osyield()
osyield()
osyield()
// If we are still somehow running, just exit with the wrong status.
exit(2)
}
// raisebadsignal is called when a signal is received on a non-Go
// thread, and the Go program does not want to handle it (that is, the
// program has not called os/signal.Notify for the signal).
func raisebadsignal(sig uint32, c *sigctxt) {
if sig == _SIGPROF {
// Ignore profiling signals that arrive on non-Go threads.
return
}
var handler uintptr
if sig >= _NSIG {
handler = _SIG_DFL
} else {
handler = atomic.Loaduintptr(&fwdSig[sig])
}
// Reset the signal handler and raise the signal.
// We are currently running inside a signal handler, so the
// signal is blocked. We need to unblock it before raising the
// signal, or the signal we raise will be ignored until we return
// from the signal handler. We know that the signal was unblocked
// before entering the handler, or else we would not have received
// it. That means that we don't have to worry about blocking it
// again.
unblocksig(sig)
setsig(sig, handler)
// If we're linked into a non-Go program we want to try to
// avoid modifying the original context in which the signal
// was raised. If the handler is the default, we know it
// is non-recoverable, so we don't have to worry about
// re-installing sighandler. At this point we can just
// return and the signal will be re-raised and caught by
// the default handler with the correct context.
//
// On FreeBSD, the libthr sigaction code prevents
// this from working so we fall through to raise.
if GOOS != "freebsd" && (isarchive || islibrary) && handler == _SIG_DFL && !c.sigFromUser() {
return
}
raise(sig)
// Give the signal a chance to be delivered.
// In almost all real cases the program is about to crash,
// so sleeping here is not a waste of time.
usleep(1000)
// If the signal didn't cause the program to exit, restore the
// Go signal handler and carry on.
//
// We may receive another instance of the signal before we
// restore the Go handler, but that is not so bad: we know
// that the Go program has been ignoring the signal.
setsig(sig, abi.FuncPCABIInternal(sighandler))
}
//go:nosplit
func crash() {
// OS X core dumps are linear dumps of the mapped memory,
// from the first virtual byte to the last, with zeros in the gaps.
// Because of the way we arrange the address space on 64-bit systems,
// this means the OS X core file will be >128 GB and even on a zippy
// workstation can take OS X well over an hour to write (uninterruptible).
// Save users from making that mistake.
if GOOS == "darwin" && GOARCH == "amd64" {
return
}
dieFromSignal(_SIGABRT)
}
// ensureSigM starts one global, sleeping thread to make sure at least one thread
// is available to catch signals enabled for os/signal.
func ensureSigM() {
if maskUpdatedChan != nil {
return
}
maskUpdatedChan = make(chan struct{})
disableSigChan = make(chan uint32)
enableSigChan = make(chan uint32)
go func() {
// Signal masks are per-thread, so make sure this goroutine stays on one
// thread.
LockOSThread()
defer UnlockOSThread()
// The sigBlocked mask contains the signals not active for os/signal,
// initially all signals except the essential ones. When signal.Notify/Stop is called,
// sigenable/sigdisable in turn notify this thread to update its signal
// mask accordingly.
sigBlocked := sigset_all
for i := range sigtable {
if !blockableSig(uint32(i)) {
sigdelset(&sigBlocked, i)
}
}
sigprocmask(_SIG_SETMASK, &sigBlocked, nil)
for {
select {
case sig := <-enableSigChan:
if sig > 0 {
sigdelset(&sigBlocked, int(sig))
}
case sig := <-disableSigChan:
if sig > 0 && blockableSig(sig) {
sigaddset(&sigBlocked, int(sig))
}
}
sigprocmask(_SIG_SETMASK, &sigBlocked, nil)
maskUpdatedChan <- struct{}{}
}
}()
}
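// The pattern used above, a dedicated goroutine locked to a single OS
// thread so that per-thread state such as the signal mask stays put, is
// also available to user code. A minimal illustrative sketch:
//
//	go func() {
//		runtime.LockOSThread()
//		defer runtime.UnlockOSThread()
//		// Per-thread state set here stays with this goroutine
//		// for as long as it remains locked to the thread.
//	}()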
// This is called when we receive a signal when there is no signal stack.
// This can only happen if non-Go code calls sigaltstack to disable the
// signal stack.
func noSignalStack(sig uint32) {
println("signal", sig, "received on thread with no signal stack")
throw("non-Go code disabled sigaltstack")
}
// This is called if we receive a signal when there is a signal stack
// but we are not on it. This can only happen if non-Go code called
// sigaction without setting the SS_ONSTACK flag.
func sigNotOnStack(sig uint32) {
println("signal", sig, "received but handler not on signal stack")
throw("non-Go code set up signal handler without SA_ONSTACK flag")
}
// signalDuringFork is called if we receive a signal while doing a fork.
// We do not want signals at that time, as a signal sent to the process
// group may be delivered to the child process, causing confusion.
// This should never be called, because we block signals across the fork;
// this function is just a safety check. See issue 18600 for background.
func signalDuringFork(sig uint32) {
println("signal", sig, "received during fork")
throw("signal received during fork")
}
// This runs on a foreign stack, without an m or a g. No stack split.
//
//go:nosplit
//go:norace
//go:nowritebarrierrec
func badsignal(sig uintptr, c *sigctxt) {
if !iscgo && !cgoHasExtraM {
// There is no extra M. needm will not be able to grab
// an M. Instead of hanging, just crash.
// Cannot call split-stack function as there is no G.
writeErrStr("fatal: bad g in signal handler\n")
exit(2)
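// If exit somehow returns, force a crash by writing to an unmapped low address.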
*(*uintptr)(unsafe.Pointer(uintptr(123))) = 2
}
needm()
if !sigsend(uint32(sig)) {
// A foreign thread received the signal sig, and the
// Go code does not want to handle it.
raisebadsignal(uint32(sig), c)
}
dropm()
}
//go:noescape
func sigfwd(fn uintptr, sig uint32, info *siginfo, ctx unsafe.Pointer)
// Determines if the signal should be handled by Go and if not, forwards the
// signal to the handler that was installed before Go's. Returns whether the
// signal was forwarded.
// This is called by the signal handler, and the world may be stopped.
//
//go:nosplit
//go:nowritebarrierrec
func sigfwdgo(sig uint32, info *siginfo, ctx unsafe.Pointer) bool {
if sig >= uint32(len(sigtable)) {
return false
}
fwdFn := atomic.Loaduintptr(&fwdSig[sig])
flags := sigtable[sig].flags
// If we aren't handling the signal, forward it.
if atomic.Load(&handlingSig[sig]) == 0 || !signalsOK {
// If the signal is ignored, doing nothing is the same as forwarding.
if fwdFn == _SIG_IGN || (fwdFn == _SIG_DFL && flags&_SigIgn != 0) {
return true
}
// We are not handling the signal and there is no other handler to forward to.
// Crash with the default behavior.
if fwdFn == _SIG_DFL {
setsig(sig, _SIG_DFL)
dieFromSignal(sig)
return false
}
sigfwd(fwdFn, sig, info, ctx)
return true
}
// This function and its caller sigtrampgo assume SIGPIPE is delivered on the
// originating thread. This property does not hold on macOS (golang.org/issue/33384),
// so we have no choice but to ignore SIGPIPE.
if (GOOS == "darwin" || GOOS == "ios") && sig == _SIGPIPE {
return true
}
// If there is no handler to forward to, no need to forward.
if fwdFn == _SIG_DFL {
return false
}
c := &sigctxt{info, ctx}
// Only forward synchronous signals and SIGPIPE.
// Unfortunately, user-generated SIGPIPEs will also be forwarded, because si_code
// is set to _SI_USER even for a SIGPIPE raised from a write to a closed socket
// or pipe.
if (c.sigFromUser() || flags&_SigPanic == 0) && sig != _SIGPIPE {
return false
}
// Determine if the signal occurred inside Go code. We test that:
// (1) we weren't in VDSO page,
// (2) we were in a goroutine (i.e., m.curg != nil), and
// (3) we weren't in CGO.
gp := sigFetchG(c)
if gp != nil && gp.m != nil && gp.m.curg != nil && !gp.m.incgo {
return false
}
// Signal not handled by Go, forward it.
if fwdFn != _SIG_IGN {
sigfwd(fwdFn, sig, info, ctx)
}
return true
}
// sigsave saves the current thread's signal mask into *p.
// This is used to preserve the non-Go signal mask when a non-Go
// thread calls a Go function.
// This is nosplit and nowritebarrierrec because it is called by needm
// which may be called on a non-Go thread with no g available.
//
//go:nosplit
//go:nowritebarrierrec
func sigsave(p *sigset) {
sigprocmask(_SIG_SETMASK, nil, p)
}
// msigrestore sets the current thread's signal mask to sigmask.
// This is used to restore the non-Go signal mask when a non-Go thread
// calls a Go function.
// This is nosplit and nowritebarrierrec because it is called by dropm
// after g has been cleared.
//
//go:nosplit
//go:nowritebarrierrec
func msigrestore(sigmask sigset) {
sigprocmask(_SIG_SETMASK, &sigmask, nil)
}
// sigsetAllExiting is used by sigblock(true) when a thread is
// exiting. sigset_all is defined in OS-specific code, and OS-specific
// behavior may override this default for sigsetAllExiting: see
// osinit().
var sigsetAllExiting = sigset_all
// sigblock blocks signals in the current thread's signal mask.
// This is used to block signals while setting up and tearing down g
// when a non-Go thread calls a Go function. When a thread is exiting
// we use the sigsetAllExiting value, otherwise the OS specific
// definition of sigset_all is used.
// This is nosplit and nowritebarrierrec because it is called by needm
// which may be called on a non-Go thread with no g available.
//
//go:nosplit
//go:nowritebarrierrec
func sigblock(exiting bool) {
if exiting {
sigprocmask(_SIG_SETMASK, &sigsetAllExiting, nil)
return
}
sigprocmask(_SIG_SETMASK, &sigset_all, nil)
}
// unblocksig removes sig from the current thread's signal mask.
// This is nosplit and nowritebarrierrec because it is called from
// dieFromSignal, which can be called by sigfwdgo while running in the
// signal handler, on the signal stack, with no g available.
//
//go:nosplit
//go:nowritebarrierrec
func unblocksig(sig uint32) {
var set sigset
sigaddset(&set, int(sig))
sigprocmask(_SIG_UNBLOCK, &set, nil)
}
// minitSignals is called when initializing a new m to set the
// thread's alternate signal stack and signal mask.
func minitSignals() {
minitSignalStack()
minitSignalMask()
}
// minitSignalStack is called when initializing a new m to set the
// alternate signal stack. If the alternate signal stack is not set
// for the thread (the normal case) then set the alternate signal
// stack to the gsignal stack. If the alternate signal stack is set
// for the thread (the case when a non-Go thread sets the alternate
// signal stack and then calls a Go function) then set the gsignal
// stack to the alternate signal stack. We also set the alternate
// signal stack to the gsignal stack if cgo is not used (regardless
// of whether it is already set). Record which choice was made in
// newSigstack, so that it can be undone in unminit.
func minitSignalStack() {
mp := getg().m
var st stackt
sigaltstack(nil, &st)
if st.ss_flags&_SS_DISABLE != 0 || !iscgo {
signalstack(&mp.gsignal.stack)
mp.newSigstack = true
} else {
setGsignalStack(&st, &mp.goSigStack)
mp.newSigstack = false
}
}
// minitSignalMask is called when initializing a new m to set the
// thread's signal mask. When this is called all signals have been
// blocked for the thread. This starts with m.sigmask, which was set
// either from initSigmask for a newly created thread or by calling
// sigsave if this is a non-Go thread calling a Go function. It
// removes all essential signals from the mask, thus causing those
// signals to not be blocked. Then it sets the thread's signal mask.
// After this is called the thread can receive signals.
func minitSignalMask() {
nmask := getg().m.sigmask
for i := range sigtable {
if !blockableSig(uint32(i)) {
sigdelset(&nmask, i)
}
}
sigprocmask(_SIG_SETMASK, &nmask, nil)
}
// unminitSignals is called from dropm, via unminit, to undo the
// effect of calling minit on a non-Go thread.
//
//go:nosplit
func unminitSignals() {
if getg().m.newSigstack {
st := stackt{ss_flags: _SS_DISABLE}
sigaltstack(&st, nil)
} else {
// We got the signal stack from someone else. Restore
// the Go-allocated stack in case this M gets reused
// for another thread (e.g., it's an extram). Also, on
// Android, libc allocates a signal stack for all
// threads, so it's important to restore the Go stack
// even on Go-created threads so we can free it.
restoreGsignalStack(&getg().m.goSigStack)
}
}
// blockableSig reports whether sig may be blocked by the signal mask.
// We never want to block the signals marked _SigUnblock;
// these are the synchronous signals that turn into a Go panic.
// We never want to block the preemption signal if it is being used.
// In a Go program--not a c-archive/c-shared--we never want to block
// the signals marked _SigKill or _SigThrow, as otherwise it's possible
// for all running threads to block them and delay their delivery until
// we start a new thread. When linked into a C program we let the C code
// decide on the disposition of those signals.
func blockableSig(sig uint32) bool {
flags := sigtable[sig].flags
if flags&_SigUnblock != 0 {
return false
}
if sig == sigPreempt && preemptMSupported && debug.asyncpreemptoff == 0 {
return false
}
if isarchive || islibrary {
return true
}
return flags&(_SigKill|_SigThrow) == 0
}
// gsignalStack saves the fields of the gsignal stack changed by
// setGsignalStack.
type gsignalStack struct {
stack stack
stackguard0 uintptr
stackguard1 uintptr
stktopsp uintptr
}
// setGsignalStack sets the gsignal stack of the current m to an
// alternate signal stack returned from the sigaltstack system call.
// It saves the old values in *old for use by restoreGsignalStack.
// This is used when handling a signal if non-Go code has set the
// alternate signal stack.
//
//go:nosplit
//go:nowritebarrierrec
func setGsignalStack(st *stackt, old *gsignalStack) {
gp := getg()
if old != nil {
old.stack = gp.m.gsignal.stack
old.stackguard0 = gp.m.gsignal.stackguard0
old.stackguard1 = gp.m.gsignal.stackguard1
old.stktopsp = gp.m.gsignal.stktopsp
}
stsp := uintptr(unsafe.Pointer(st.ss_sp))
gp.m.gsignal.stack.lo = stsp
gp.m.gsignal.stack.hi = stsp + st.ss_size
gp.m.gsignal.stackguard0 = stsp + _StackGuard
gp.m.gsignal.stackguard1 = stsp + _StackGuard
}
// restoreGsignalStack restores the gsignal stack to the value it had
// before entering the signal handler.
//
//go:nosplit
//go:nowritebarrierrec
func restoreGsignalStack(st *gsignalStack) {
gp := getg().m.gsignal
gp.stack = st.stack
gp.stackguard0 = st.stackguard0
gp.stackguard1 = st.stackguard1
gp.stktopsp = st.stktopsp
}
// signalstack sets the current thread's alternate signal stack to s.
//
//go:nosplit
func signalstack(s *stack) {
st := stackt{ss_size: s.hi - s.lo}
setSignalstackSP(&st, s.lo)
sigaltstack(&st, nil)
}
// setsigsegv is used on darwin/arm64 to fake a segmentation fault.
//
// This is exported via linkname to assembly in runtime/cgo.
//
//go:nosplit
//go:linkname setsigsegv
func setsigsegv(pc uintptr) {
gp := getg()
gp.sig = _SIGSEGV
gp.sigpc = pc
gp.sigcode0 = _SEGV_MAPERR
gp.sigcode1 = 0 // TODO: emulate si_addr
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements runtime support for signal handling.
//
// Most synchronization primitives are not available from
// the signal handler (it cannot block, allocate memory, or use locks)
// so the handler communicates with a processing goroutine
// via struct sig, below.
//
// sigsend is called by the signal handler to queue a new signal.
// signal_recv is called by the Go program to receive a newly queued signal.
//
// Synchronization between sigsend and signal_recv is based on the sig.state
// variable. It can be in three states:
// * sigReceiving means that signal_recv is blocked on sig.Note and there are
// no new pending signals.
// * sigSending means that sig.mask *may* contain new pending signals;
// signal_recv can't be blocked in this state.
// * sigIdle means that there are no new pending signals and signal_recv is not
// blocked.
//
// Transitions between states are done atomically with CAS.
//
// When signal_recv is unblocked, it resets sig.Note and rechecks sig.mask.
// If several sigsends and signal_recv execute concurrently, it can lead to
// unnecessary rechecks of sig.mask, but it cannot lead to missed signals
// nor deadlocks.
//go:build !plan9
package runtime
import (
"runtime/internal/atomic"
_ "unsafe" // for go:linkname
)
// sig handles communication between the signal handler and os/signal.
// Other than the inuse and recv fields, the fields are accessed atomically.
//
// The wanted and ignored fields are only written by one goroutine at
// a time; access is controlled by the handlers Mutex in os/signal.
// The fields are only read by that one goroutine and by the signal handler.
// We access them atomically to minimize the race between setting them
// in the goroutine calling os/signal and the signal handler,
// which may be running in a different thread. That race is unavoidable,
// as there is no connection between handling a signal and receiving one,
// but atomic instructions should minimize it.
var sig struct {
note note
mask [(_NSIG + 31) / 32]uint32
wanted [(_NSIG + 31) / 32]uint32
ignored [(_NSIG + 31) / 32]uint32
recv [(_NSIG + 31) / 32]uint32
state atomic.Uint32
delivering atomic.Uint32
inuse bool
}
const (
sigIdle = iota
sigReceiving
sigSending
)
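// A minimal user-level sketch of the same CAS-based handoff (illustrative
// only, using sync/atomic and a buffered channel in place of the runtime
// note; the names mirror the states above but nothing here is runtime API):
//
//	var state atomic.Uint32 // sigIdle, sigReceiving, or sigSending
//	wake := make(chan struct{}, 1)
//
//	// Sender: if the receiver is parked, unpark it.
//	if state.CompareAndSwap(sigReceiving, sigIdle) {
//		wake <- struct{}{}
//	}
//
//	// Receiver: park until a sender wakes us.
//	if state.CompareAndSwap(sigIdle, sigReceiving) {
//		<-wake
//	}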
// sigsend delivers a signal from sighandler to the internal signal delivery queue.
// It reports whether the signal was sent. If not, the caller typically crashes the program.
// It runs from the signal handler, so it's limited in what it can do.
func sigsend(s uint32) bool {
bit := uint32(1) << uint(s&31)
if s >= uint32(32*len(sig.wanted)) {
return false
}
sig.delivering.Add(1)
// We are running in the signal handler; defer is not available.
if w := atomic.Load(&sig.wanted[s/32]); w&bit == 0 {
sig.delivering.Add(-1)
return false
}
// Add signal to outgoing queue.
for {
mask := sig.mask[s/32]
if mask&bit != 0 {
sig.delivering.Add(-1)
return true // signal already in queue
}
if atomic.Cas(&sig.mask[s/32], mask, mask|bit) {
break
}
}
// Notify receiver that queue has new bit.
Send:
for {
switch sig.state.Load() {
default:
throw("sigsend: inconsistent state")
case sigIdle:
if sig.state.CompareAndSwap(sigIdle, sigSending) {
break Send
}
case sigSending:
// notification already pending
break Send
case sigReceiving:
if sig.state.CompareAndSwap(sigReceiving, sigIdle) {
if GOOS == "darwin" || GOOS == "ios" {
sigNoteWakeup(&sig.note)
break Send
}
notewakeup(&sig.note)
break Send
}
}
}
sig.delivering.Add(-1)
return true
}
// Called to receive the next queued signal.
// Must only be called from a single goroutine at a time.
//
//go:linkname signal_recv os/signal.signal_recv
func signal_recv() uint32 {
for {
// Serve any signals from local copy.
for i := uint32(0); i < _NSIG; i++ {
if sig.recv[i/32]&(1<<(i&31)) != 0 {
sig.recv[i/32] &^= 1 << (i & 31)
return i
}
}
// Wait for updates to be available from signal sender.
Receive:
for {
switch sig.state.Load() {
default:
throw("signal_recv: inconsistent state")
case sigIdle:
if sig.state.CompareAndSwap(sigIdle, sigReceiving) {
if GOOS == "darwin" || GOOS == "ios" {
sigNoteSleep(&sig.note)
break Receive
}
notetsleepg(&sig.note, -1)
noteclear(&sig.note)
break Receive
}
case sigSending:
if sig.state.CompareAndSwap(sigSending, sigIdle) {
break Receive
}
}
}
// Incorporate updates from sender into local copy.
for i := range sig.mask {
sig.recv[i] = atomic.Xchg(&sig.mask[i], 0)
}
}
}
// signalWaitUntilIdle waits until the signal delivery mechanism is idle.
// This is used to ensure that we do not drop a signal notification due
// to a race between disabling a signal and receiving a signal.
// This assumes that signal delivery has already been disabled for
// the signal(s) in question, and here we are just waiting to make sure
// that all the signals have been delivered to the user channels
// by the os/signal package.
//
//go:linkname signalWaitUntilIdle os/signal.signalWaitUntilIdle
func signalWaitUntilIdle() {
// Although the signals we care about have been removed from
// sig.wanted, it is possible that another thread has received
// a signal, has read from sig.wanted, is now updating sig.mask,
// and has not yet woken up the processor thread. We need to wait
// until all current signal deliveries have completed.
for sig.delivering.Load() != 0 {
Gosched()
}
// Although WaitUntilIdle seems like the right name for this
// function, the state we are looking for is sigReceiving, not
// sigIdle. The sigIdle state is really more like sigProcessing.
for sig.state.Load() != sigReceiving {
Gosched()
}
}
// Must only be called from a single goroutine at a time.
//
//go:linkname signal_enable os/signal.signal_enable
func signal_enable(s uint32) {
if !sig.inuse {
// This is the first call to signal_enable. Initialize.
sig.inuse = true // enable reception of signals; cannot disable
if GOOS == "darwin" || GOOS == "ios" {
sigNoteSetup(&sig.note)
} else {
noteclear(&sig.note)
}
}
if s >= uint32(len(sig.wanted)*32) {
return
}
w := sig.wanted[s/32]
w |= 1 << (s & 31)
atomic.Store(&sig.wanted[s/32], w)
i := sig.ignored[s/32]
i &^= 1 << (s & 31)
atomic.Store(&sig.ignored[s/32], i)
sigenable(s)
}
// Must only be called from a single goroutine at a time.
//
//go:linkname signal_disable os/signal.signal_disable
func signal_disable(s uint32) {
if s >= uint32(len(sig.wanted)*32) {
return
}
sigdisable(s)
w := sig.wanted[s/32]
w &^= 1 << (s & 31)
atomic.Store(&sig.wanted[s/32], w)
}
// Must only be called from a single goroutine at a time.
//
//go:linkname signal_ignore os/signal.signal_ignore
func signal_ignore(s uint32) {
if s >= uint32(len(sig.wanted)*32) {
return
}
sigignore(s)
w := sig.wanted[s/32]
w &^= 1 << (s & 31)
atomic.Store(&sig.wanted[s/32], w)
i := sig.ignored[s/32]
i |= 1 << (s & 31)
atomic.Store(&sig.ignored[s/32], i)
}
// sigInitIgnored marks the signal as already ignored. This is called at
// program start by initsig. In a shared library initsig is called by
// libpreinit, so the runtime may not be initialized yet.
//
//go:nosplit
func sigInitIgnored(s uint32) {
i := sig.ignored[s/32]
i |= 1 << (s & 31)
atomic.Store(&sig.ignored[s/32], i)
}
// Checked by signal handlers.
//
//go:linkname signal_ignored os/signal.signal_ignored
func signal_ignored(s uint32) bool {
i := atomic.Load(&sig.ignored[s/32])
return i&(1<<(s&31)) != 0
}
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// The current implementation of notes on Darwin is not async-signal-safe,
// so on Darwin the sigqueue code uses different functions to wake up the
// signal_recv thread. This file holds the non-Darwin implementations of
// those functions. These functions will never be called.
//go:build !darwin && !plan9
package runtime
func sigNoteSetup(*note) {
throw("sigNoteSetup")
}
func sigNoteSleep(*note) {
throw("sigNoteSleep")
}
func sigNoteWakeup(*note) {
throw("sigNoteWakeup")
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"internal/abi"
"internal/goarch"
"runtime/internal/math"
"runtime/internal/sys"
"unsafe"
)
type slice struct {
array unsafe.Pointer
len int
cap int
}
// A notInHeapSlice is a slice backed by runtime/internal/sys.NotInHeap memory.
type notInHeapSlice struct {
array *notInHeap
len int
cap int
}
func panicmakeslicelen() {
panic(errorString("makeslice: len out of range"))
}
func panicmakeslicecap() {
panic(errorString("makeslice: cap out of range"))
}
// makeslicecopy allocates a slice of "tolen" elements of type "et",
// then copies "fromlen" elements of type "et" into that new allocation from "from".
func makeslicecopy(et *_type, tolen int, fromlen int, from unsafe.Pointer) unsafe.Pointer {
var tomem, copymem uintptr
if uintptr(tolen) > uintptr(fromlen) {
var overflow bool
tomem, overflow = math.MulUintptr(et.size, uintptr(tolen))
if overflow || tomem > maxAlloc || tolen < 0 {
panicmakeslicelen()
}
copymem = et.size * uintptr(fromlen)
} else {
// fromlen is a known good length that is equal to or greater than tolen,
// thereby making tolen a good slice length too as from and to slices have the
// same element width.
tomem = et.size * uintptr(tolen)
copymem = tomem
}
var to unsafe.Pointer
if et.ptrdata == 0 {
to = mallocgc(tomem, nil, false)
if copymem < tomem {
memclrNoHeapPointers(add(to, copymem), tomem-copymem)
}
} else {
// Note: can't use rawmem (which avoids zeroing of memory), because then GC can scan uninitialized memory.
to = mallocgc(tomem, et, true)
if copymem > 0 && writeBarrier.enabled {
// Only shade the pointers in from since we know the destination slice
// contains only nil pointers because it has been cleared during alloc.
bulkBarrierPreWriteSrcOnly(uintptr(to), uintptr(from), copymem)
}
}
if raceenabled {
callerpc := getcallerpc()
pc := abi.FuncPCABIInternal(makeslicecopy)
racereadrangepc(from, copymem, callerpc, pc)
}
if msanenabled {
msanread(from, copymem)
}
if asanenabled {
asanread(from, copymem)
}
memmove(to, from, copymem)
return to
}
func makeslice(et *_type, len, cap int) unsafe.Pointer {
mem, overflow := math.MulUintptr(et.size, uintptr(cap))
if overflow || mem > maxAlloc || len < 0 || len > cap {
// NOTE: Produce a 'len out of range' error instead of a
// 'cap out of range' error when someone does make([]T, bignumber).
// 'cap out of range' is true too, but since the cap is only being
// supplied implicitly, saying len is clearer.
// See golang.org/issue/4085.
mem, overflow := math.MulUintptr(et.size, uintptr(len))
if overflow || mem > maxAlloc || len < 0 {
panicmakeslicelen()
}
panicmakeslicecap()
}
return mallocgc(mem, et, true)
}
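// From user code the two panics above surface as runtime errors; an
// illustrative sketch (values chosen to exceed maxAlloc on 64-bit):
//
//	n := 1 << 62
//	_ = make([]byte, n)    // panics: makeslice: len out of range
//	_ = make([]byte, 0, n) // panics: makeslice: cap out of range (if reached)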
func makeslice64(et *_type, len64, cap64 int64) unsafe.Pointer {
len := int(len64)
if int64(len) != len64 {
panicmakeslicelen()
}
cap := int(cap64)
if int64(cap) != cap64 {
panicmakeslicecap()
}
return makeslice(et, len, cap)
}
// This is a wrapper over runtime/internal/math.MulUintptr,
// so the compiler can recognize and treat it as an intrinsic.
func mulUintptr(a, b uintptr) (uintptr, bool) {
return math.MulUintptr(a, b)
}
// growslice allocates new backing store for a slice.
//
// arguments:
//
// oldPtr = pointer to the slice's backing array
// newLen = new length (= oldLen + num)
// oldCap = original slice's capacity.
// num = number of elements being added
// et = element type
//
// return values:
//
// newPtr = pointer to the new backing store
// newLen = same value as the argument
// newCap = capacity of the new backing store
//
// Requires that uint(newLen) > uint(oldCap).
// Assumes the original slice length is newLen - num
//
// A new backing store is allocated with space for at least newLen elements.
// Existing entries [0, oldLen) are copied over to the new backing store.
// Added entries [oldLen, newLen) are not initialized by growslice
// (although for pointer-containing element types, they are zeroed). They
// must be initialized by the caller.
// Trailing entries [newLen, newCap) are zeroed.
//
// growslice's odd calling convention makes the generated code that calls
// this function simpler. In particular, it accepts and returns the
// new length so that the old length is not live (does not need to be
// spilled/restored) and the new length is returned (also does not need
// to be spilled/restored).
func growslice(oldPtr unsafe.Pointer, newLen, oldCap, num int, et *_type) slice {
oldLen := newLen - num
if raceenabled {
callerpc := getcallerpc()
racereadrangepc(oldPtr, uintptr(oldLen*int(et.size)), callerpc, abi.FuncPCABIInternal(growslice))
}
if msanenabled {
msanread(oldPtr, uintptr(oldLen*int(et.size)))
}
if asanenabled {
asanread(oldPtr, uintptr(oldLen*int(et.size)))
}
if newLen < 0 {
panic(errorString("growslice: len out of range"))
}
if et.size == 0 {
// append should not create a slice with nil pointer but non-zero len.
// We assume that append doesn't need to preserve oldPtr in this case.
return slice{unsafe.Pointer(&zerobase), newLen, newLen}
}
newcap := oldCap
doublecap := newcap + newcap
if newLen > doublecap {
newcap = newLen
} else {
const threshold = 256
if oldCap < threshold {
newcap = doublecap
} else {
// Check 0 < newcap to detect overflow
// and prevent an infinite loop.
for 0 < newcap && newcap < newLen {
// Transition from growing 2x for small slices
// to growing 1.25x for large slices. This formula
// gives a smooth-ish transition between the two.
newcap += (newcap + 3*threshold) / 4
}
// Set newcap to the requested cap when
// the newcap calculation overflowed.
if newcap <= 0 {
newcap = newLen
}
}
}
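// Worked example of the growth formula above (illustrative numbers):
// at newcap = 256 one step adds (256+3*256)/4 = 256, i.e. exactly 2x at
// the threshold; at newcap = 4096 one step adds (4096+768)/4 = 1216,
// about 1.3x, approaching 1.25x as newcap grows.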
var overflow bool
var lenmem, newlenmem, capmem uintptr
// Specialize for common values of et.size.
// For 1 we don't need any division/multiplication.
// For goarch.PtrSize, the compiler will optimize division/multiplication into a shift by a constant.
// For powers of 2, use a variable shift.
switch {
case et.size == 1:
lenmem = uintptr(oldLen)
newlenmem = uintptr(newLen)
capmem = roundupsize(uintptr(newcap))
overflow = uintptr(newcap) > maxAlloc
newcap = int(capmem)
case et.size == goarch.PtrSize:
lenmem = uintptr(oldLen) * goarch.PtrSize
newlenmem = uintptr(newLen) * goarch.PtrSize
capmem = roundupsize(uintptr(newcap) * goarch.PtrSize)
overflow = uintptr(newcap) > maxAlloc/goarch.PtrSize
newcap = int(capmem / goarch.PtrSize)
case isPowerOfTwo(et.size):
var shift uintptr
if goarch.PtrSize == 8 {
// Mask shift for better code generation.
shift = uintptr(sys.TrailingZeros64(uint64(et.size))) & 63
} else {
shift = uintptr(sys.TrailingZeros32(uint32(et.size))) & 31
}
lenmem = uintptr(oldLen) << shift
newlenmem = uintptr(newLen) << shift
capmem = roundupsize(uintptr(newcap) << shift)
overflow = uintptr(newcap) > (maxAlloc >> shift)
newcap = int(capmem >> shift)
capmem = uintptr(newcap) << shift
default:
lenmem = uintptr(oldLen) * et.size
newlenmem = uintptr(newLen) * et.size
capmem, overflow = math.MulUintptr(et.size, uintptr(newcap))
capmem = roundupsize(capmem)
newcap = int(capmem / et.size)
capmem = uintptr(newcap) * et.size
}
// The check of overflow in addition to capmem > maxAlloc is needed
// to prevent an overflow which can be used to trigger a segfault
// on 32-bit architectures with this example program:
//
// type T [1<<27 + 1]int64
//
// var d T
// var s []T
//
// func main() {
// s = append(s, d, d, d, d)
// print(len(s), "\n")
// }
if overflow || capmem > maxAlloc {
panic(errorString("growslice: len out of range"))
}
var p unsafe.Pointer
if et.ptrdata == 0 {
p = mallocgc(capmem, nil, false)
// The append() that calls growslice is going to overwrite from oldLen to newLen.
// Only clear the part that will not be overwritten.
// The reflect_growslice() that calls growslice will manually clear
// the region not cleared here.
memclrNoHeapPointers(add(p, newlenmem), capmem-newlenmem)
} else {
// Note: can't use rawmem (which avoids zeroing of memory), because then GC can scan uninitialized memory.
p = mallocgc(capmem, et, true)
if lenmem > 0 && writeBarrier.enabled {
// Only shade the pointers in oldPtr since we know the destination slice p
// only contains nil pointers because it has been cleared during alloc.
bulkBarrierPreWriteSrcOnly(uintptr(p), uintptr(oldPtr), lenmem-et.size+et.ptrdata)
}
}
memmove(p, oldPtr, lenmem)
return slice{p, newLen, newcap}
}
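// The growth policy is observable from ordinary append calls; an
// illustrative sketch (exact capacities can vary with size-class rounding):
//
//	s := make([]int, 0, 1)
//	for i := 0; i < 5; i++ {
//		s = append(s, i)
//		println(cap(s)) // typically 1, 2, 4, 4, 8
//	}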
//go:linkname reflect_growslice reflect.growslice
func reflect_growslice(et *_type, old slice, num int) slice {
// Semantically equivalent to slices.Grow, except that the caller
// is responsible for ensuring that old.len+num > old.cap.
num -= old.cap - old.len // preserve memory of old[old.len:old.cap]
new := growslice(old.array, old.cap+num, old.cap, num, et)
// growslice does not zero out new[old.cap:new.len] since it assumes that
// the memory will be overwritten by an append() that called growslice.
// Since the caller of reflect_growslice is not append(),
// zero out this region before returning the slice to the reflect package.
if et.ptrdata == 0 {
oldcapmem := uintptr(old.cap) * et.size
newlenmem := uintptr(new.len) * et.size
memclrNoHeapPointers(add(new.array, oldcapmem), newlenmem-oldcapmem)
}
new.len = old.len // preserve the old length
return new
}
func isPowerOfTwo(x uintptr) bool {
return x&(x-1) == 0
}
// slicecopy is used to copy from a string or slice of pointerless elements into a slice.
func slicecopy(toPtr unsafe.Pointer, toLen int, fromPtr unsafe.Pointer, fromLen int, width uintptr) int {
if fromLen == 0 || toLen == 0 {
return 0
}
n := fromLen
if toLen < n {
n = toLen
}
if width == 0 {
return n
}
size := uintptr(n) * width
if raceenabled {
callerpc := getcallerpc()
pc := abi.FuncPCABIInternal(slicecopy)
racereadrangepc(fromPtr, size, callerpc, pc)
racewriterangepc(toPtr, size, callerpc, pc)
}
if msanenabled {
msanread(fromPtr, size)
msanwrite(toPtr, size)
}
if asanenabled {
asanread(fromPtr, size)
asanwrite(toPtr, size)
}
if size == 1 { // common case worth about 2x to do here
// TODO: is this still worth it with new memmove impl?
*(*byte)(toPtr) = *(*byte)(fromPtr) // known to be a byte pointer
} else {
memmove(toPtr, fromPtr, size)
}
return n
}
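// The built-in copy that lowers to slicecopy keeps the min(toLen, fromLen)
// semantics above. Minimal sketch (standalone program):
//
//	package main
//
//	import "fmt"
//
//	func main() {
//		dst := make([]byte, 3)
//		n := copy(dst, "hello") // a string source is allowed for []byte
//		fmt.Println(n, dst)     // 3 [104 101 108]
//	}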
//go:linkname bytealg_MakeNoZero internal/bytealg.MakeNoZero
func bytealg_MakeNoZero(len int) []byte {
if uintptr(len) > maxAlloc {
panicmakeslicelen()
}
return unsafe.Slice((*byte)(mallocgc(uintptr(len), nil, false)), len)
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Software IEEE754 64-bit floating point.
// Only referred to (and thus linked in) by softfloat targets
// and by tests in this directory.
package runtime
const (
mantbits64 uint = 52
expbits64 uint = 11
bias64 = -1<<(expbits64-1) + 1
nan64 uint64 = (1<<expbits64-1)<<mantbits64 + 1<<(mantbits64-1) // quiet NaN, 0 payload
inf64 uint64 = (1<<expbits64 - 1) << mantbits64
neg64 uint64 = 1 << (expbits64 + mantbits64)
mantbits32 uint = 23
expbits32 uint = 8
bias32 = -1<<(expbits32-1) + 1
nan32 uint32 = (1<<expbits32-1)<<mantbits32 + 1<<(mantbits32-1) // quiet NaN, 0 payload
inf32 uint32 = (1<<expbits32 - 1) << mantbits32
neg32 uint32 = 1 << (expbits32 + mantbits32)
)
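// The constants above can be cross-checked against the hardware
// representation with math.Float64bits and math.Float32bits. Illustrative
// standalone program (not part of this package):
//
//	package main
//
//	import (
//		"fmt"
//		"math"
//	)
//
//	func main() {
//		fmt.Printf("%#x\n", math.Float64bits(math.Inf(1)))          // 0x7ff0000000000000 == inf64
//		fmt.Printf("%#x\n", math.Float64bits(math.NaN()))           // 0x7ff8000000000001 == nan64
//		fmt.Printf("%#x\n", math.Float64bits(math.Copysign(0, -1))) // 0x8000000000000000 == neg64
//		fmt.Printf("%#x\n", math.Float32bits(float32(math.Inf(1)))) // 0x7f800000 == inf32
//	}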
func funpack64(f uint64) (sign, mant uint64, exp int, inf, nan bool) {
sign = f & (1 << (mantbits64 + expbits64))
mant = f & (1<<mantbits64 - 1)
exp = int(f>>mantbits64) & (1<<expbits64 - 1)
switch exp {
case 1<<expbits64 - 1:
if mant != 0 {
nan = true
return
}
inf = true
return
case 0:
// denormalized
if mant != 0 {
exp += bias64 + 1
for mant < 1<<mantbits64 {
mant <<= 1
exp--
}
}
default:
// add implicit top bit
mant |= 1 << mantbits64
exp += bias64
}
return
}
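// Worked example: 1.0 is 0x3ff0000000000000, so the mantissa field is 0 and
// the exponent field is 0x3ff = 1023. The default case adds the implicit top
// bit and rebiases, giving mant = 1<<52 and exp = 1023 + bias64 = 0; the
// unpacked value is mant * 2^(exp-mantbits64) = 2^52 * 2^-52 = 1.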
func funpack32(f uint32) (sign, mant uint32, exp int, inf, nan bool) {
sign = f & (1 << (mantbits32 + expbits32))
mant = f & (1<<mantbits32 - 1)
exp = int(f>>mantbits32) & (1<<expbits32 - 1)
switch exp {
case 1<<expbits32 - 1:
if mant != 0 {
nan = true
return
}
inf = true
return
case 0:
// denormalized
if mant != 0 {
exp += bias32 + 1
for mant < 1<<mantbits32 {
mant <<= 1
exp--
}
}
default:
// add implicit top bit
mant |= 1 << mantbits32
exp += bias32
}
return
}
func fpack64(sign, mant uint64, exp int, trunc uint64) uint64 {
mant0, exp0, trunc0 := mant, exp, trunc
if mant == 0 {
return sign
}
for mant < 1<<mantbits64 {
mant <<= 1
exp--
}
for mant >= 4<<mantbits64 {
trunc |= mant & 1
mant >>= 1
exp++
}
if mant >= 2<<mantbits64 {
if mant&1 != 0 && (trunc != 0 || mant&2 != 0) {
mant++
if mant >= 4<<mantbits64 {
mant >>= 1
exp++
}
}
mant >>= 1
exp++
}
if exp >= 1<<expbits64-1+bias64 {
return sign ^ inf64
}
if exp < bias64+1 {
if exp < bias64-int(mantbits64) {
return sign | 0
}
// repeat expecting denormal
mant, exp, trunc = mant0, exp0, trunc0
for exp < bias64 {
trunc |= mant & 1
mant >>= 1
exp++
}
if mant&1 != 0 && (trunc != 0 || mant&2 != 0) {
mant++
}
mant >>= 1
exp++
if mant < 1<<mantbits64 {
return sign | mant
}
}
return sign | uint64(exp-bias64)<<mantbits64 | mant&(1<<mantbits64-1)
}
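// The rounding above is round-to-nearest-even: the dropped bit (mant&1)
// causes an increment only if the sticky bits (trunc) or the next kept bit
// (mant&2) are nonzero. Worked example: mant = 1<<53 + 1 with trunc = 0 is
// an exact tie and is dropped to the even 1<<52, while mant = 1<<53 + 3
// rounds up to 1<<52 + 2 (with exp incremented in both cases).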
func fpack32(sign, mant uint32, exp int, trunc uint32) uint32 {
mant0, exp0, trunc0 := mant, exp, trunc
if mant == 0 {
return sign
}
for mant < 1<<mantbits32 {
mant <<= 1
exp--
}
for mant >= 4<<mantbits32 {
trunc |= mant & 1
mant >>= 1
exp++
}
if mant >= 2<<mantbits32 {
if mant&1 != 0 && (trunc != 0 || mant&2 != 0) {
mant++
if mant >= 4<<mantbits32 {
mant >>= 1
exp++
}
}
mant >>= 1
exp++
}
if exp >= 1<<expbits32-1+bias32 {
return sign ^ inf32
}
if exp < bias32+1 {
if exp < bias32-int(mantbits32) {
return sign | 0
}
// repeat expecting denormal
mant, exp, trunc = mant0, exp0, trunc0
for exp < bias32 {
trunc |= mant & 1
mant >>= 1
exp++
}
if mant&1 != 0 && (trunc != 0 || mant&2 != 0) {
mant++
}
mant >>= 1
exp++
if mant < 1<<mantbits32 {
return sign | mant
}
}
return sign | uint32(exp-bias32)<<mantbits32 | mant&(1<<mantbits32-1)
}
func fadd64(f, g uint64) uint64 {
fs, fm, fe, fi, fn := funpack64(f)
gs, gm, ge, gi, gn := funpack64(g)
// Special cases.
switch {
case fn || gn: // NaN + x or x + NaN = NaN
return nan64
case fi && gi && fs != gs: // +Inf + -Inf or -Inf + +Inf = NaN
return nan64
case fi: // ±Inf + g = ±Inf
return f
case gi: // f + ±Inf = ±Inf
return g
case fm == 0 && gm == 0 && fs != 0 && gs != 0: // -0 + -0 = -0
return f
case fm == 0: // 0 + g = g but 0 + -0 = +0
if gm == 0 {
g ^= gs
}
return g
case gm == 0: // f + 0 = f
return f
}
if fe < ge || fe == ge && fm < gm {
f, g, fs, fm, fe, gs, gm, ge = g, f, gs, gm, ge, fs, fm, fe
}
shift := uint(fe - ge)
fm <<= 2
gm <<= 2
trunc := gm & (1<<shift - 1)
gm >>= shift
if fs == gs {
fm += gm
} else {
fm -= gm
if trunc != 0 {
fm--
}
}
if fm == 0 {
fs = 0
}
return fpack64(fs, fm, fe-2, trunc)
}
func fsub64(f, g uint64) uint64 {
return fadd64(f, fneg64(g))
}
func fneg64(f uint64) uint64 {
return f ^ (1 << (mantbits64 + expbits64))
}
func fmul64(f, g uint64) uint64 {
fs, fm, fe, fi, fn := funpack64(f)
gs, gm, ge, gi, gn := funpack64(g)
// Special cases.
switch {
case fn || gn: // NaN * g or f * NaN = NaN
return nan64
case fi && gi: // Inf * Inf = Inf (with sign adjusted)
return f ^ gs
case fi && gm == 0, fm == 0 && gi: // 0 * Inf = Inf * 0 = NaN
return nan64
case fm == 0: // 0 * x = 0 (with sign adjusted)
return f ^ gs
case gm == 0: // x * 0 = 0 (with sign adjusted)
return g ^ fs
}
// 53-bit * 53-bit = 107- or 108-bit
lo, hi := mullu(fm, gm)
shift := mantbits64 - 1
trunc := lo & (1<<shift - 1)
mant := hi<<(64-shift) | lo>>shift
return fpack64(fs^gs, mant, fe+ge-1, trunc)
}
func fdiv64(f, g uint64) uint64 {
fs, fm, fe, fi, fn := funpack64(f)
gs, gm, ge, gi, gn := funpack64(g)
// Special cases.
switch {
case fn || gn: // NaN / g = f / NaN = NaN
return nan64
case fi && gi: // ±Inf / ±Inf = NaN
return nan64
case !fi && !gi && fm == 0 && gm == 0: // 0 / 0 = NaN
return nan64
case fi, !gi && gm == 0: // Inf / g = f / 0 = Inf
return fs ^ gs ^ inf64
case gi, fm == 0: // f / Inf = 0 / g = Inf
return fs ^ gs ^ 0
}
_, _, _, _ = fi, fn, gi, gn
// 53-bit<<54 / 53-bit = 53- or 54-bit.
shift := mantbits64 + 2
q, r := divlu(fm>>(64-shift), fm<<shift, gm)
return fpack64(fs^gs, q, fe-ge-2, r)
}
func f64to32(f uint64) uint32 {
fs, fm, fe, fi, fn := funpack64(f)
if fn {
return nan32
}
fs32 := uint32(fs >> 32)
if fi {
return fs32 ^ inf32
}
const d = mantbits64 - mantbits32 - 1
return fpack32(fs32, uint32(fm>>d), fe-1, uint32(fm&(1<<d-1)))
}
func f32to64(f uint32) uint64 {
const d = mantbits64 - mantbits32
fs, fm, fe, fi, fn := funpack32(f)
if fn {
return nan64
}
fs64 := uint64(fs) << 32
if fi {
return fs64 ^ inf64
}
return fpack64(fs64, uint64(fm)<<d, fe, 0)
}
func fcmp64(f, g uint64) (cmp int32, isnan bool) {
fs, fm, _, fi, fn := funpack64(f)
gs, gm, _, gi, gn := funpack64(g)
switch {
case fn, gn: // flag NaN
return 0, true
case !fi && !gi && fm == 0 && gm == 0: // ±0 == ±0
return 0, false
case fs > gs: // f < 0, g > 0
return -1, false
case fs < gs: // f > 0, g < 0
return +1, false
// Same sign, not NaN.
// Can compare encodings directly now.
// Reverse for sign.
case fs == 0 && f < g, fs != 0 && f > g:
return -1, false
case fs == 0 && f > g, fs != 0 && f < g:
return +1, false
}
// f == g
return 0, false
}
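// These comparison semantics match Go's float64 operators. Illustrative
// standalone check using ordinary hardware floats (not these helpers):
//
//	package main
//
//	import (
//		"fmt"
//		"math"
//	)
//
//	func main() {
//		nan := math.NaN()
//		fmt.Println(nan == nan, nan < 1, nan > 1) // false false false: NaN is unordered
//		fmt.Println(math.Copysign(0, -1) == 0)    // true: -0 == +0
//	}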
func f64toint(f uint64) (val int64, ok bool) {
fs, fm, fe, fi, fn := funpack64(f)
switch {
case fi, fn: // NaN
return 0, false
case fe < -1: // f < 0.5
return 0, false
case fe > 63: // f >= 2^63
if fs != 0 && fm == 0 { // f == -2^63
return -1 << 63, true
}
if fs != 0 {
return 0, false
}
return 0, false
}
for fe > int(mantbits64) {
fe--
fm <<= 1
}
for fe < int(mantbits64) {
fe++
fm >>= 1
}
val = int64(fm)
if fs != 0 {
val = -val
}
return val, true
}
func fintto64(val int64) (f uint64) {
fs := uint64(val) & (1 << 63)
mant := uint64(val)
if fs != 0 {
mant = -mant
}
return fpack64(fs, mant, int(mantbits64), 0)
}
func fintto32(val int64) (f uint32) {
fs := uint64(val) & (1 << 63)
mant := uint64(val)
if fs != 0 {
mant = -mant
}
// Reduce mantissa size until it fits into a uint32.
// Keep track of the bits we throw away, and if any are
// nonzero, OR them into the lowest bit (a sticky bit for rounding).
exp := int(mantbits32)
var trunc uint32
for mant >= 1<<32 {
trunc |= uint32(mant) & 1
mant >>= 1
exp++
}
return fpack32(uint32(fs>>32), uint32(mant), exp, trunc)
}
// 64x64 -> 128 multiply.
// adapted from hacker's delight.
func mullu(u, v uint64) (lo, hi uint64) {
const (
s = 32
mask = 1<<s - 1
)
u0 := u & mask
u1 := u >> s
v0 := v & mask
v1 := v >> s
w0 := u0 * v0
t := u1*v0 + w0>>s
w1 := t & mask
w2 := t >> s
w1 += u0 * v1
return u * v, u1*v1 + w2 + w1>>s
}
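// Outside the runtime, math/bits.Mul64 computes the same full 128-bit
// product (note that it returns (hi, lo), the reverse of mullu's order).
// Minimal sketch:
//
//	package main
//
//	import (
//		"fmt"
//		"math/bits"
//	)
//
//	func main() {
//		hi, lo := bits.Mul64(0xdeadbeefcafebabe, 0x0123456789abcdef)
//		fmt.Printf("hi=%#x lo=%#x\n", hi, lo)
//	}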
// 128/64 -> 64 quotient, 64 remainder.
// adapted from hacker's delight
func divlu(u1, u0, v uint64) (q, r uint64) {
const b = 1 << 32
if u1 >= v {
return 1<<64 - 1, 1<<64 - 1
}
// s = nlz(v); v <<= s
s := uint(0)
for v&(1<<63) == 0 {
s++
v <<= 1
}
vn1 := v >> 32
vn0 := v & (1<<32 - 1)
un32 := u1<<s | u0>>(64-s)
un10 := u0 << s
un1 := un10 >> 32
un0 := un10 & (1<<32 - 1)
q1 := un32 / vn1
rhat := un32 - q1*vn1
again1:
if q1 >= b || q1*vn0 > b*rhat+un1 {
q1--
rhat += vn1
if rhat < b {
goto again1
}
}
un21 := un32*b + un1 - q1*v
q0 := un21 / vn1
rhat = un21 - q0*vn1
again2:
if q0 >= b || q0*vn0 > b*rhat+un0 {
q0--
rhat += vn1
if rhat < b {
goto again2
}
}
return q1*b + q0, (un21*b + un0 - q0*v) >> s
}
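// The user-code equivalent is math/bits.Div64, which divides the 128-bit
// dividend (hi, lo) by a 64-bit divisor. One difference: bits.Div64 panics
// when the divisor is zero or the quotient overflows, where divlu returns
// all-ones. Minimal sketch:
//
//	package main
//
//	import (
//		"fmt"
//		"math/bits"
//	)
//
//	func main() {
//		q, r := bits.Div64(3, 0x123456789abcdef0, 7) // requires hi < divisor
//		fmt.Println(q, r)
//	}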
func fadd32(x, y uint32) uint32 {
return f64to32(fadd64(f32to64(x), f32to64(y)))
}
func fmul32(x, y uint32) uint32 {
return f64to32(fmul64(f32to64(x), f32to64(y)))
}
func fdiv32(x, y uint32) uint32 {
// TODO: are there double-rounding problems here? See issue 48807.
return f64to32(fdiv64(f32to64(x), f32to64(y)))
}
func feq32(x, y uint32) bool {
cmp, nan := fcmp64(f32to64(x), f32to64(y))
return cmp == 0 && !nan
}
func fgt32(x, y uint32) bool {
cmp, nan := fcmp64(f32to64(x), f32to64(y))
return cmp >= 1 && !nan
}
func fge32(x, y uint32) bool {
cmp, nan := fcmp64(f32to64(x), f32to64(y))
return cmp >= 0 && !nan
}
func feq64(x, y uint64) bool {
cmp, nan := fcmp64(x, y)
return cmp == 0 && !nan
}
func fgt64(x, y uint64) bool {
cmp, nan := fcmp64(x, y)
return cmp >= 1 && !nan
}
func fge64(x, y uint64) bool {
cmp, nan := fcmp64(x, y)
return cmp >= 0 && !nan
}
func fint32to32(x int32) uint32 {
return fintto32(int64(x))
}
func fint32to64(x int32) uint64 {
return fintto64(int64(x))
}
func fint64to32(x int64) uint32 {
return fintto32(x)
}
func fint64to64(x int64) uint64 {
return fintto64(x)
}
func f32toint32(x uint32) int32 {
val, _ := f64toint(f32to64(x))
return int32(val)
}
func f32toint64(x uint32) int64 {
val, _ := f64toint(f32to64(x))
return val
}
func f64toint32(x uint64) int32 {
val, _ := f64toint(x)
return int32(val)
}
func f64toint64(x uint64) int64 {
val, _ := f64toint(x)
return val
}
func f64touint64(x uint64) uint64 {
var m uint64 = 0x43e0000000000000 // float64 1<<63
if fgt64(m, x) {
return uint64(f64toint64(x))
}
y := fadd64(x, fneg64(m)) // x - m; integer negation of the bit pattern would not negate the float64
z := uint64(f64toint64(y))
return z | (1 << 63)
}
func f32touint64(x uint32) uint64 {
var m uint32 = 0x5f000000 // float32 1<<63
if fgt32(m, x) {
return uint64(f32toint64(x))
}
y := fadd32(x, m^(1<<31)) // x - m; flip the sign bit to negate the float32 bit pattern
z := uint64(f32toint64(y))
return z | (1 << 63)
}
func fuint64to64(x uint64) uint64 {
if int64(x) >= 0 {
return fint64to64(int64(x))
}
// See ../cmd/compile/internal/ssagen/ssa.go:uint64Tofloat
y := x & 1
z := x >> 1
z = z | y
r := fint64to64(int64(z))
return fadd64(r, r)
}
func fuint64to32(x uint64) uint32 {
if int64(x) >= 0 {
return fint64to32(int64(x))
}
// See ../cmd/compile/internal/ssagen/ssa.go:uint64Tofloat
y := x & 1
z := x >> 1
z = z | y
r := fint64to32(int64(z))
return fadd32(r, r)
}
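// The halve-convert-double trick above preserves correct rounding by ORing
// the dropped low bit back in as a sticky bit. The same trick works with
// hardware floats (standalone sketch):
//
//	package main
//
//	import "fmt"
//
//	func main() {
//		x := uint64(1<<63 + 4097) // too big for a direct int64 conversion
//		z := x>>1 | x&1           // halve, keeping the dropped bit sticky
//		r := float64(int64(z))
//		fmt.Println(r+r == float64(x)) // true
//	}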
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"internal/abi"
"internal/cpu"
"internal/goarch"
"internal/goos"
"runtime/internal/atomic"
"runtime/internal/sys"
"unsafe"
)
/*
Stack layout parameters.
Included both by runtime (compiled via 6c) and linkers (compiled via gcc).
The per-goroutine g->stackguard is set to point StackGuard bytes
above the bottom of the stack. Each function compares its stack
pointer against g->stackguard to check for overflow. To cut one
instruction from the check sequence for functions with tiny frames,
the stack is allowed to protrude StackSmall bytes below the stack
guard. Functions with large frames don't bother with the check and
always call morestack. The sequences are (for amd64, others are
similar):
guard = g->stackguard
frame = function's stack frame size
argsize = size of function arguments (call + return)
stack frame size <= StackSmall:
CMPQ guard, SP
JHI 3(PC)
MOVQ m->morearg, $(argsize << 32)
CALL morestack(SB)
stack frame size > StackSmall but < StackBig:
LEAQ (frame-StackSmall)(SP), R0
CMPQ guard, R0
JHI 3(PC)
MOVQ m->morearg, $(argsize << 32)
CALL morestack(SB)
stack frame size >= StackBig:
MOVQ m->morearg, $((argsize << 32) | frame)
CALL morestack(SB)
The bottom StackGuard - StackSmall bytes are important: there has
to be enough room to execute functions that refuse to check for
stack overflow, either because they need to be adjacent to the
actual caller's frame (deferproc) or because they handle the imminent
stack overflow (morestack).
For example, deferproc might call malloc, which does one of the
above checks (without allocating a full frame), which might trigger
a call to morestack. This sequence needs to fit in the bottom
section of the stack. On amd64, morestack's frame is 40 bytes, and
deferproc's frame is 56 bytes. That fits well within the
StackGuard - StackSmall bytes at the bottom.
The linkers explore all possible call traces involving non-splitting
functions to make sure that this limit cannot be violated.
*/
const (
// StackSystem is a number of additional bytes to add
// to each stack below the usual guard area for OS-specific
// purposes like signal handling. Used on Windows, Plan 9,
// and iOS because they do not use a separate stack.
_StackSystem = goos.IsWindows*512*goarch.PtrSize + goos.IsPlan9*512 + goos.IsIos*goarch.IsArm64*1024
// The minimum size of stack used by Go code
_StackMin = 2048
// The minimum stack size to allocate.
// The hackery here rounds FixedStack0 up to a power of 2
// (see the sketch of this trick just after the const block).
_FixedStack0 = _StackMin + _StackSystem
_FixedStack1 = _FixedStack0 - 1
_FixedStack2 = _FixedStack1 | (_FixedStack1 >> 1)
_FixedStack3 = _FixedStack2 | (_FixedStack2 >> 2)
_FixedStack4 = _FixedStack3 | (_FixedStack3 >> 4)
_FixedStack5 = _FixedStack4 | (_FixedStack4 >> 8)
_FixedStack6 = _FixedStack5 | (_FixedStack5 >> 16)
_FixedStack = _FixedStack6 + 1
// Functions that need frames bigger than this use an extra
// instruction to do the stack split check, to avoid overflow
// in case SP - framesize wraps below zero.
// This value can be no bigger than the size of the unmapped
// space at zero.
_StackBig = 4096
// The stack guard is a pointer this many bytes above the
// bottom of the stack.
//
// The guard leaves enough room for one _StackSmall frame plus
// a _StackLimit chain of NOSPLIT calls plus _StackSystem
// bytes for the OS.
// This arithmetic must match that in cmd/internal/objabi/stack.go:StackLimit.
_StackGuard = 928*sys.StackGuardMultiplier + _StackSystem
// After a stack split check the SP is allowed to be this
// many bytes below the stack guard. This saves an instruction
// in the checking sequence for tiny frames.
_StackSmall = 128
// The maximum number of bytes that a chain of NOSPLIT
// functions can use.
// This arithmetic must match that in cmd/internal/objabi/stack.go:StackLimit.
_StackLimit = _StackGuard - _StackSystem - _StackSmall
)
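// The _FixedStack0.._FixedStack6 chain above is the branch-free "round up to
// a power of two" idiom: subtract one, smear the top set bit rightward, then
// add one. A minimal sketch of the same computation (assuming 32-bit inputs):
//
//	func roundUpPow2(x uint32) uint32 {
//		x--
//		x |= x >> 1
//		x |= x >> 2
//		x |= x >> 4
//		x |= x >> 8
//		x |= x >> 16
//		return x + 1
//	}
//
// For example, with _StackSystem = 512 (Plan 9), roundUpPow2(2048+512) = 4096.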
const (
// stackDebug == 0: no logging
// == 1: logging of per-stack operations
// == 2: logging of per-frame operations
// == 3: logging of per-word updates
// == 4: logging of per-word reads
stackDebug = 0
stackFromSystem = 0 // allocate stacks from system memory instead of the heap
stackFaultOnFree = 0 // old stacks are mapped noaccess to detect use after free
stackPoisonCopy = 0 // fill stack that should not be accessed with garbage, to detect bad dereferences during copy
stackNoCache = 0 // disable per-P small stack caches
// check the BP links during traceback.
debugCheckBP = false
)
const (
uintptrMask = 1<<(8*goarch.PtrSize) - 1
// The values below can be stored to g.stackguard0 to force
// the next stack check to fail.
// These are all larger than any real SP.
// Goroutine preemption request.
// 0xfffffade in hex.
stackPreempt = uintptrMask & -1314
// Thread is forking. Causes a split stack check failure.
// 0xfffffb2e in hex.
stackFork = uintptrMask & -1234
// Force a stack movement. Used for debugging.
// 0xfffffeed in hex.
stackForceMove = uintptrMask & -275
// stackPoisonMin is the lowest allowed stack poison value.
stackPoisonMin = uintptrMask & -4096
)
// Global pool of spans that have free stacks.
// Stacks are assigned an order according to size.
//
// order = log_2(size/FixedStack)
//
// There is a free list for each order.
var stackpool [_NumStackOrders]struct {
item stackpoolItem
_ [(cpu.CacheLinePadSize - unsafe.Sizeof(stackpoolItem{})%cpu.CacheLinePadSize) % cpu.CacheLinePadSize]byte
}
type stackpoolItem struct {
_ sys.NotInHeap
mu mutex
span mSpanList
}
// Global pool of large stack spans.
var stackLarge struct {
lock mutex
free [heapAddrBits - pageShift]mSpanList // free lists by log_2(s.npages)
}
func stackinit() {
if _StackCacheSize&_PageMask != 0 {
throw("cache size must be a multiple of page size")
}
for i := range stackpool {
stackpool[i].item.span.init()
lockInit(&stackpool[i].item.mu, lockRankStackpool)
}
for i := range stackLarge.free {
stackLarge.free[i].init()
lockInit(&stackLarge.lock, lockRankStackLarge)
}
}
// stacklog2 returns ⌊log_2(n)⌋.
func stacklog2(n uintptr) int {
log2 := 0
for n > 1 {
n >>= 1
log2++
}
return log2
}
// Allocates a stack from the free pool. Must be called with
// stackpool[order].item.mu held.
func stackpoolalloc(order uint8) gclinkptr {
list := &stackpool[order].item.span
s := list.first
lockWithRankMayAcquire(&mheap_.lock, lockRankMheap)
if s == nil {
// no free stacks. Allocate another span worth.
s = mheap_.allocManual(_StackCacheSize>>_PageShift, spanAllocStack)
if s == nil {
throw("out of memory")
}
if s.allocCount != 0 {
throw("bad allocCount")
}
if s.manualFreeList.ptr() != nil {
throw("bad manualFreeList")
}
osStackAlloc(s)
s.elemsize = _FixedStack << order
for i := uintptr(0); i < _StackCacheSize; i += s.elemsize {
x := gclinkptr(s.base() + i)
x.ptr().next = s.manualFreeList
s.manualFreeList = x
}
list.insert(s)
}
x := s.manualFreeList
if x.ptr() == nil {
throw("span has no free stacks")
}
s.manualFreeList = x.ptr().next
s.allocCount++
if s.manualFreeList.ptr() == nil {
// all stacks in s are allocated.
list.remove(s)
}
return x
}
// Adds stack x to the free pool. Must be called with stackpool[order].item.mu held.
func stackpoolfree(x gclinkptr, order uint8) {
s := spanOfUnchecked(uintptr(x))
if s.state.get() != mSpanManual {
throw("freeing stack not in a stack span")
}
if s.manualFreeList.ptr() == nil {
// s will now have a free stack
stackpool[order].item.span.insert(s)
}
x.ptr().next = s.manualFreeList
s.manualFreeList = x
s.allocCount--
if gcphase == _GCoff && s.allocCount == 0 {
// Span is completely free. Return it to the heap
// immediately if we're sweeping.
//
// If GC is active, we delay the free until the end of
// GC to avoid the following type of situation:
//
// 1) GC starts, scans a SudoG but does not yet mark the SudoG.elem pointer
// 2) The stack that pointer points to is copied
// 3) The old stack is freed
// 4) The containing span is marked free
// 5) GC attempts to mark the SudoG.elem pointer. The
// marking fails because the pointer looks like a
// pointer into a free span.
//
// By not freeing, we prevent step #4 until GC is done.
stackpool[order].item.span.remove(s)
s.manualFreeList = 0
osStackFree(s)
mheap_.freeManual(s, spanAllocStack)
}
}
// stackcacherefill/stackcacherelease implement a global pool of stack segments.
// The pool is required to prevent unlimited growth of per-thread caches.
//
//go:systemstack
func stackcacherefill(c *mcache, order uint8) {
if stackDebug >= 1 {
print("stackcacherefill order=", order, "\n")
}
// Grab some stacks from the global cache.
// Grab half of the allowed capacity (to prevent thrashing).
var list gclinkptr
var size uintptr
lock(&stackpool[order].item.mu)
for size < _StackCacheSize/2 {
x := stackpoolalloc(order)
x.ptr().next = list
list = x
size += _FixedStack << order
}
unlock(&stackpool[order].item.mu)
c.stackcache[order].list = list
c.stackcache[order].size = size
}
//go:systemstack
func stackcacherelease(c *mcache, order uint8) {
if stackDebug >= 1 {
print("stackcacherelease order=", order, "\n")
}
x := c.stackcache[order].list
size := c.stackcache[order].size
lock(&stackpool[order].item.mu)
for size > _StackCacheSize/2 {
y := x.ptr().next
stackpoolfree(x, order)
x = y
size -= _FixedStack << order
}
unlock(&stackpool[order].item.mu)
c.stackcache[order].list = x
c.stackcache[order].size = size
}
//go:systemstack
func stackcache_clear(c *mcache) {
if stackDebug >= 1 {
print("stackcache clear\n")
}
for order := uint8(0); order < _NumStackOrders; order++ {
lock(&stackpool[order].item.mu)
x := c.stackcache[order].list
for x.ptr() != nil {
y := x.ptr().next
stackpoolfree(x, order)
x = y
}
c.stackcache[order].list = 0
c.stackcache[order].size = 0
unlock(&stackpool[order].item.mu)
}
}
// stackalloc allocates an n byte stack.
//
// stackalloc must run on the system stack because it uses per-P
// resources and must not split the stack.
//
//go:systemstack
func stackalloc(n uint32) stack {
// Stackalloc must be called on the scheduler stack, so that we
// never try to grow the stack while stackalloc itself is running.
// Doing so would cause a deadlock (issue 1547).
thisg := getg()
if thisg != thisg.m.g0 {
throw("stackalloc not on scheduler stack")
}
if n&(n-1) != 0 {
throw("stack size not a power of 2")
}
if stackDebug >= 1 {
print("stackalloc ", n, "\n")
}
if debug.efence != 0 || stackFromSystem != 0 {
n = uint32(alignUp(uintptr(n), physPageSize))
v := sysAlloc(uintptr(n), &memstats.stacks_sys)
if v == nil {
throw("out of memory (stackalloc)")
}
return stack{uintptr(v), uintptr(v) + uintptr(n)}
}
// Small stacks are allocated with a fixed-size free-list allocator.
// If we need a stack of a bigger size, we fall back on allocating
// a dedicated span.
var v unsafe.Pointer
if n < _FixedStack<<_NumStackOrders && n < _StackCacheSize {
order := uint8(0)
n2 := n
for n2 > _FixedStack {
order++
n2 >>= 1
}
var x gclinkptr
if stackNoCache != 0 || thisg.m.p == 0 || thisg.m.preemptoff != "" {
// thisg.m.p == 0 can happen in the guts of exitsyscall
// or procresize. Just get a stack from the global pool.
// Also don't touch stackcache during gc
// as it's flushed concurrently.
lock(&stackpool[order].item.mu)
x = stackpoolalloc(order)
unlock(&stackpool[order].item.mu)
} else {
c := thisg.m.p.ptr().mcache
x = c.stackcache[order].list
if x.ptr() == nil {
stackcacherefill(c, order)
x = c.stackcache[order].list
}
c.stackcache[order].list = x.ptr().next
c.stackcache[order].size -= uintptr(n)
}
v = unsafe.Pointer(x)
} else {
var s *mspan
npage := uintptr(n) >> _PageShift
log2npage := stacklog2(npage)
// Try to get a stack from the large stack cache.
lock(&stackLarge.lock)
if !stackLarge.free[log2npage].isEmpty() {
s = stackLarge.free[log2npage].first
stackLarge.free[log2npage].remove(s)
}
unlock(&stackLarge.lock)
lockWithRankMayAcquire(&mheap_.lock, lockRankMheap)
if s == nil {
// Allocate a new stack from the heap.
s = mheap_.allocManual(npage, spanAllocStack)
if s == nil {
throw("out of memory")
}
osStackAlloc(s)
s.elemsize = uintptr(n)
}
v = unsafe.Pointer(s.base())
}
if raceenabled {
racemalloc(v, uintptr(n))
}
if msanenabled {
msanmalloc(v, uintptr(n))
}
if asanenabled {
asanunpoison(v, uintptr(n))
}
if stackDebug >= 1 {
print(" allocated ", v, "\n")
}
return stack{uintptr(v), uintptr(v) + uintptr(n)}
}
// stackfree frees an n byte stack allocation at stk.
//
// stackfree must run on the system stack because it uses per-P
// resources and must not split the stack.
//
//go:systemstack
func stackfree(stk stack) {
gp := getg()
v := unsafe.Pointer(stk.lo)
n := stk.hi - stk.lo
if n&(n-1) != 0 {
throw("stack not a power of 2")
}
if stk.lo+n < stk.hi {
throw("bad stack size")
}
if stackDebug >= 1 {
println("stackfree", v, n)
memclrNoHeapPointers(v, n) // for testing, clobber stack data
}
if debug.efence != 0 || stackFromSystem != 0 {
if debug.efence != 0 || stackFaultOnFree != 0 {
sysFault(v, n)
} else {
sysFree(v, n, &memstats.stacks_sys)
}
return
}
if msanenabled {
msanfree(v, n)
}
if asanenabled {
asanpoison(v, n)
}
if n < _FixedStack<<_NumStackOrders && n < _StackCacheSize {
order := uint8(0)
n2 := n
for n2 > _FixedStack {
order++
n2 >>= 1
}
x := gclinkptr(v)
if stackNoCache != 0 || gp.m.p == 0 || gp.m.preemptoff != "" {
lock(&stackpool[order].item.mu)
stackpoolfree(x, order)
unlock(&stackpool[order].item.mu)
} else {
c := gp.m.p.ptr().mcache
if c.stackcache[order].size >= _StackCacheSize {
stackcacherelease(c, order)
}
x.ptr().next = c.stackcache[order].list
c.stackcache[order].list = x
c.stackcache[order].size += n
}
} else {
s := spanOfUnchecked(uintptr(v))
if s.state.get() != mSpanManual {
println(hex(s.base()), v)
throw("bad span state")
}
if gcphase == _GCoff {
// Free the stack immediately if we're
// sweeping.
osStackFree(s)
mheap_.freeManual(s, spanAllocStack)
} else {
// If the GC is running, we can't return a
// stack span to the heap because it could be
// reused as a heap span, and this state
// change would race with GC. Add it to the
// large stack cache instead.
log2npage := stacklog2(s.npages)
lock(&stackLarge.lock)
stackLarge.free[log2npage].insert(s)
unlock(&stackLarge.lock)
}
}
}
var maxstacksize uintptr = 1 << 20 // enough until runtime.main sets it for real
var maxstackceiling = maxstacksize
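// The limit is adjustable from user code via runtime/debug.SetMaxStack.
// Illustrative sketch (standalone program; the default limit is
// platform-dependent):
//
//	package main
//
//	import "runtime/debug"
//
//	func eat() { var buf [1 << 10]byte; _ = buf; eat() }
//
//	func main() {
//		debug.SetMaxStack(1 << 22) // lower the limit to 4 MiB
//		eat() // eventually: "goroutine stack exceeds 4194304-byte limit"
//	}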
var ptrnames = []string{
0: "scalar",
1: "ptr",
}
// Stack frame layout
//
// (x86)
// +------------------+
// | args from caller |
// +------------------+ <- frame->argp
// | return address |
// +------------------+
// | caller's BP (*) | (*) if framepointer_enabled && varp < sp
// +------------------+ <- frame->varp
// | locals |
// +------------------+
// | args to callee |
// +------------------+ <- frame->sp
//
// (arm)
// +------------------+
// | args from caller |
// +------------------+ <- frame->argp
// | caller's retaddr |
// +------------------+ <- frame->varp
// | locals |
// +------------------+
// | args to callee |
// +------------------+
// | return address |
// +------------------+ <- frame->sp
type adjustinfo struct {
old stack
delta uintptr // ptr distance from old to new stack (newbase - oldbase)
cache pcvalueCache
// sghi is the highest sudog.elem on the stack.
sghi uintptr
}
// adjustpointer checks whether *vpp is in the old stack described by adjinfo.
// If so, it rewrites *vpp to point into the new stack.
func adjustpointer(adjinfo *adjustinfo, vpp unsafe.Pointer) {
pp := (*uintptr)(vpp)
p := *pp
if stackDebug >= 4 {
print(" ", pp, ":", hex(p), "\n")
}
if adjinfo.old.lo <= p && p < adjinfo.old.hi {
*pp = p + adjinfo.delta
if stackDebug >= 3 {
print(" adjust ptr ", pp, ":", hex(p), " -> ", hex(*pp), "\n")
}
}
}
// Information from the compiler about the layout of stack frames.
// Note: this type must agree with reflect.bitVector.
type bitvector struct {
n int32 // # of bits
bytedata *uint8
}
// ptrbit returns the i'th bit in bv.
// ptrbit is less efficient than iterating directly over bitvector bits,
// and should only be used in non-performance-critical code.
// See adjustpointers for an example of a high-efficiency walk of a bitvector.
func (bv *bitvector) ptrbit(i uintptr) uint8 {
b := *(addb(bv.bytedata, i/8))
return (b >> (i % 8)) & 1
}
// bv describes the memory starting at address scanp.
// Adjust any pointers contained therein.
func adjustpointers(scanp unsafe.Pointer, bv *bitvector, adjinfo *adjustinfo, f funcInfo) {
minp := adjinfo.old.lo
maxp := adjinfo.old.hi
delta := adjinfo.delta
num := uintptr(bv.n)
// If this frame might contain channel receive slots, use CAS
// to adjust pointers. If the slot hasn't been received into
// yet, it may contain stack pointers and a concurrent send
// could race with adjusting those pointers. (The sent value
// itself can never contain stack pointers.)
useCAS := uintptr(scanp) < adjinfo.sghi
for i := uintptr(0); i < num; i += 8 {
if stackDebug >= 4 {
for j := uintptr(0); j < 8; j++ {
print(" ", add(scanp, (i+j)*goarch.PtrSize), ":", ptrnames[bv.ptrbit(i+j)], ":", hex(*(*uintptr)(add(scanp, (i+j)*goarch.PtrSize))), " # ", i, " ", *addb(bv.bytedata, i/8), "\n")
}
}
b := *(addb(bv.bytedata, i/8))
for b != 0 {
j := uintptr(sys.TrailingZeros8(b))
b &= b - 1
pp := (*uintptr)(add(scanp, (i+j)*goarch.PtrSize))
retry:
p := *pp
if f.valid() && 0 < p && p < minLegalPointer && debug.invalidptr != 0 {
// Looks like a junk value in a pointer slot.
// Live analysis wrong?
getg().m.traceback = 2
print("runtime: bad pointer in frame ", funcname(f), " at ", pp, ": ", hex(p), "\n")
throw("invalid pointer found on stack")
}
if minp <= p && p < maxp {
if stackDebug >= 3 {
print("adjust ptr ", hex(p), " ", funcname(f), "\n")
}
if useCAS {
ppu := (*unsafe.Pointer)(unsafe.Pointer(pp))
if !atomic.Casp1(ppu, unsafe.Pointer(p), unsafe.Pointer(p+delta)) {
goto retry
}
} else {
*pp = p + delta
}
}
}
}
}
// Note: the argument/return area is adjusted by the callee.
func adjustframe(frame *stkframe, arg unsafe.Pointer) bool {
adjinfo := (*adjustinfo)(arg)
if frame.continpc == 0 {
// Frame is dead.
return true
}
f := frame.fn
if stackDebug >= 2 {
print(" adjusting ", funcname(f), " frame=[", hex(frame.sp), ",", hex(frame.fp), "] pc=", hex(frame.pc), " continpc=", hex(frame.continpc), "\n")
}
if f.funcID == funcID_systemstack_switch {
// A special routine at the bottom of the stack of a goroutine that does a systemstack call.
// We will allow it to be copied even though we don't
// have full GC info for it (because it is written in asm).
return true
}
locals, args, objs := frame.getStackMap(&adjinfo.cache, true)
// Adjust local variables if stack frame has been allocated.
if locals.n > 0 {
size := uintptr(locals.n) * goarch.PtrSize
adjustpointers(unsafe.Pointer(frame.varp-size), &locals, adjinfo, f)
}
// Adjust saved base pointer if there is one.
// TODO what about arm64 frame pointer adjustment?
if goarch.ArchFamily == goarch.AMD64 && frame.argp-frame.varp == 2*goarch.PtrSize {
if stackDebug >= 3 {
print(" saved bp\n")
}
if debugCheckBP {
// Frame pointers should always point to the next higher frame on
// the Go stack (or be nil, for the top frame on the stack).
bp := *(*uintptr)(unsafe.Pointer(frame.varp))
if bp != 0 && (bp < adjinfo.old.lo || bp >= adjinfo.old.hi) {
println("runtime: found invalid frame pointer")
print("bp=", hex(bp), " min=", hex(adjinfo.old.lo), " max=", hex(adjinfo.old.hi), "\n")
throw("bad frame pointer")
}
}
adjustpointer(adjinfo, unsafe.Pointer(frame.varp))
}
// Adjust arguments.
if args.n > 0 {
if stackDebug >= 3 {
print(" args\n")
}
adjustpointers(unsafe.Pointer(frame.argp), &args, adjinfo, funcInfo{})
}
// Adjust pointers in all stack objects (whether they are live or not).
// See comments in mgcmark.go:scanframeworker.
if frame.varp != 0 {
for i := range objs {
obj := &objs[i]
off := obj.off
base := frame.varp // locals base pointer
if off >= 0 {
base = frame.argp // arguments and return values base pointer
}
p := base + uintptr(off)
if p < frame.sp {
// Object hasn't been allocated in the frame yet.
// (Happens when the stack bounds check fails and
// we call into morestack.)
continue
}
ptrdata := obj.ptrdata()
gcdata := obj.gcdata()
var s *mspan
if obj.useGCProg() {
// See comments in mgcmark.go:scanstack
s = materializeGCProg(ptrdata, gcdata)
gcdata = (*byte)(unsafe.Pointer(s.startAddr))
}
for i := uintptr(0); i < ptrdata; i += goarch.PtrSize {
if *addb(gcdata, i/(8*goarch.PtrSize))>>(i/goarch.PtrSize&7)&1 != 0 {
adjustpointer(adjinfo, unsafe.Pointer(p+i))
}
}
if s != nil {
dematerializeGCProg(s)
}
}
}
return true
}
func adjustctxt(gp *g, adjinfo *adjustinfo) {
adjustpointer(adjinfo, unsafe.Pointer(&gp.sched.ctxt))
if !framepointer_enabled {
return
}
if debugCheckBP {
bp := gp.sched.bp
if bp != 0 && (bp < adjinfo.old.lo || bp >= adjinfo.old.hi) {
println("runtime: found invalid top frame pointer")
print("bp=", hex(bp), " min=", hex(adjinfo.old.lo), " max=", hex(adjinfo.old.hi), "\n")
throw("bad top frame pointer")
}
}
adjustpointer(adjinfo, unsafe.Pointer(&gp.sched.bp))
}
func adjustdefers(gp *g, adjinfo *adjustinfo) {
// Adjust pointers in the Defer structs.
// We need to do this first because we need to adjust the
// defer.link fields so we always work on the new stack.
adjustpointer(adjinfo, unsafe.Pointer(&gp._defer))
for d := gp._defer; d != nil; d = d.link {
adjustpointer(adjinfo, unsafe.Pointer(&d.fn))
adjustpointer(adjinfo, unsafe.Pointer(&d.sp))
adjustpointer(adjinfo, unsafe.Pointer(&d._panic))
adjustpointer(adjinfo, unsafe.Pointer(&d.link))
adjustpointer(adjinfo, unsafe.Pointer(&d.varp))
adjustpointer(adjinfo, unsafe.Pointer(&d.fd))
}
}
func adjustpanics(gp *g, adjinfo *adjustinfo) {
// Panics are on stack and already adjusted.
// Update pointer to head of list in G.
adjustpointer(adjinfo, unsafe.Pointer(&gp._panic))
}
func adjustsudogs(gp *g, adjinfo *adjustinfo) {
// the data elements pointed to by a SudoG structure
// might be in the stack.
for s := gp.waiting; s != nil; s = s.waitlink {
adjustpointer(adjinfo, unsafe.Pointer(&s.elem))
}
}
func fillstack(stk stack, b byte) {
for p := stk.lo; p < stk.hi; p++ {
*(*byte)(unsafe.Pointer(p)) = b
}
}
func findsghi(gp *g, stk stack) uintptr {
var sghi uintptr
for sg := gp.waiting; sg != nil; sg = sg.waitlink {
p := uintptr(sg.elem) + uintptr(sg.c.elemsize)
if stk.lo <= p && p < stk.hi && p > sghi {
sghi = p
}
}
return sghi
}
// syncadjustsudogs adjusts gp's sudogs and copies the part of gp's
// stack they refer to while synchronizing with concurrent channel
// operations. It returns the number of bytes of stack copied.
func syncadjustsudogs(gp *g, used uintptr, adjinfo *adjustinfo) uintptr {
if gp.waiting == nil {
return 0
}
// Lock channels to prevent concurrent send/receive.
var lastc *hchan
for sg := gp.waiting; sg != nil; sg = sg.waitlink {
if sg.c != lastc {
// There is a ranking cycle here between gscan bit and
// hchan locks. Normally, we only allow acquiring hchan
// locks and then getting a gscan bit. In this case, we
// already have the gscan bit. We allow acquiring hchan
// locks here as a special case, since a deadlock can't
// happen because the G involved must already be
// suspended. So, we get a special hchan lock rank here
// that is lower than gscan, but doesn't allow acquiring
// any other locks other than hchan.
lockWithRank(&sg.c.lock, lockRankHchanLeaf)
}
lastc = sg.c
}
// Adjust sudogs.
adjustsudogs(gp, adjinfo)
// Copy the part of the stack the sudogs point into
// while holding the lock to prevent races on
// send/receive slots.
var sgsize uintptr
if adjinfo.sghi != 0 {
oldBot := adjinfo.old.hi - used
newBot := oldBot + adjinfo.delta
sgsize = adjinfo.sghi - oldBot
memmove(unsafe.Pointer(newBot), unsafe.Pointer(oldBot), sgsize)
}
// Unlock channels.
lastc = nil
for sg := gp.waiting; sg != nil; sg = sg.waitlink {
if sg.c != lastc {
unlock(&sg.c.lock)
}
lastc = sg.c
}
return sgsize
}
// Copies gp's stack to a new stack of a different size.
// Caller must have changed gp status to Gcopystack.
func copystack(gp *g, newsize uintptr) {
if gp.syscallsp != 0 {
throw("stack growth not allowed in system call")
}
old := gp.stack
if old.lo == 0 {
throw("nil stackbase")
}
used := old.hi - gp.sched.sp
// Add just the difference to gcController.addScannableStack.
// g0 stacks never move, so this will never account for them.
// It's also fine if we have no P, addScannableStack can deal with
// that case.
gcController.addScannableStack(getg().m.p.ptr(), int64(newsize)-int64(old.hi-old.lo))
// allocate new stack
new := stackalloc(uint32(newsize))
if stackPoisonCopy != 0 {
fillstack(new, 0xfd)
}
if stackDebug >= 1 {
print("copystack gp=", gp, " [", hex(old.lo), " ", hex(old.hi-used), " ", hex(old.hi), "]", " -> [", hex(new.lo), " ", hex(new.hi-used), " ", hex(new.hi), "]/", newsize, "\n")
}
// Compute adjustment.
var adjinfo adjustinfo
adjinfo.old = old
adjinfo.delta = new.hi - old.hi
// Adjust sudogs, synchronizing with channel ops if necessary.
ncopy := used
if !gp.activeStackChans {
if newsize < old.hi-old.lo && gp.parkingOnChan.Load() {
// It's not safe for someone to shrink this stack while we're actively
// parking on a channel, but it is safe to grow since we do that
// ourselves and explicitly don't want to synchronize with channels
// since we could self-deadlock.
throw("racy sudog adjustment due to parking on channel")
}
adjustsudogs(gp, &adjinfo)
} else {
// sudogs may be pointing in to the stack and gp has
// released channel locks, so other goroutines could
// be writing to gp's stack. Find the highest such
// pointer so we can handle everything there and below
// carefully. (This shouldn't be far from the bottom
// of the stack, so there's little cost in handling
// everything below it carefully.)
adjinfo.sghi = findsghi(gp, old)
// Synchronize with channel ops and copy the part of
// the stack they may interact with.
ncopy -= syncadjustsudogs(gp, used, &adjinfo)
}
// Copy the stack (or the rest of it) to the new location
memmove(unsafe.Pointer(new.hi-ncopy), unsafe.Pointer(old.hi-ncopy), ncopy)
// Adjust remaining structures that have pointers into stacks.
// We have to do most of these before we traceback the new
// stack because gentraceback uses them.
adjustctxt(gp, &adjinfo)
adjustdefers(gp, &adjinfo)
adjustpanics(gp, &adjinfo)
if adjinfo.sghi != 0 {
adjinfo.sghi += adjinfo.delta
}
// Swap out old stack for new one
gp.stack = new
gp.stackguard0 = new.lo + _StackGuard // NOTE: might clobber a preempt request
gp.sched.sp = new.hi - used
gp.stktopsp += adjinfo.delta
// Adjust pointers in the new stack.
gentraceback(^uintptr(0), ^uintptr(0), 0, gp, 0, nil, 0x7fffffff, adjustframe, noescape(unsafe.Pointer(&adjinfo)), 0)
// free old stack
if stackPoisonCopy != 0 {
fillstack(old, 0xfc)
}
stackfree(old)
}
// round x up to a power of 2.
func round2(x int32) int32 {
s := uint(0)
for 1<<s < x {
s++
}
return 1 << s
}
// Called from runtime·morestack when more stack is needed.
// Allocate larger stack and relocate to new stack.
// Stack growth is multiplicative, for constant amortized cost.
//
// g->atomicstatus will be Grunning or Gscanrunning upon entry.
// If the scheduler is trying to stop this g, then it will set preemptStop.
//
// This must be nowritebarrierrec because it can be called as part of
// stack growth from other nowritebarrierrec functions, but the
// compiler doesn't check this.
//
//go:nowritebarrierrec
func newstack() {
thisg := getg()
// TODO: double check all gp. shouldn't be getg().
if thisg.m.morebuf.g.ptr().stackguard0 == stackFork {
throw("stack growth after fork")
}
if thisg.m.morebuf.g.ptr() != thisg.m.curg {
print("runtime: newstack called from g=", hex(thisg.m.morebuf.g), "\n"+"\tm=", thisg.m, " m->curg=", thisg.m.curg, " m->g0=", thisg.m.g0, " m->gsignal=", thisg.m.gsignal, "\n")
morebuf := thisg.m.morebuf
traceback(morebuf.pc, morebuf.sp, morebuf.lr, morebuf.g.ptr())
throw("runtime: wrong goroutine in newstack")
}
gp := thisg.m.curg
if thisg.m.curg.throwsplit {
// Update syscallsp, syscallpc in case traceback uses them.
morebuf := thisg.m.morebuf
gp.syscallsp = morebuf.sp
gp.syscallpc = morebuf.pc
pcname, pcoff := "(unknown)", uintptr(0)
f := findfunc(gp.sched.pc)
if f.valid() {
pcname = funcname(f)
pcoff = gp.sched.pc - f.entry()
}
print("runtime: newstack at ", pcname, "+", hex(pcoff),
" sp=", hex(gp.sched.sp), " stack=[", hex(gp.stack.lo), ", ", hex(gp.stack.hi), "]\n",
"\tmorebuf={pc:", hex(morebuf.pc), " sp:", hex(morebuf.sp), " lr:", hex(morebuf.lr), "}\n",
"\tsched={pc:", hex(gp.sched.pc), " sp:", hex(gp.sched.sp), " lr:", hex(gp.sched.lr), " ctxt:", gp.sched.ctxt, "}\n")
thisg.m.traceback = 2 // Include runtime frames
traceback(morebuf.pc, morebuf.sp, morebuf.lr, gp)
throw("runtime: stack split at bad time")
}
morebuf := thisg.m.morebuf
thisg.m.morebuf.pc = 0
thisg.m.morebuf.lr = 0
thisg.m.morebuf.sp = 0
thisg.m.morebuf.g = 0
// NOTE: stackguard0 may change underfoot, if another thread
// is about to try to preempt gp. Read it just once and use that same
// value now and below.
stackguard0 := atomic.Loaduintptr(&gp.stackguard0)
// Be conservative about where we preempt.
// We are interested in preempting user Go code, not runtime code.
// If we're holding locks, mallocing, or preemption is disabled, don't
// preempt.
// This check is very early in newstack so that even the status change
// from Grunning to Gwaiting and back doesn't happen in this case.
// That status change by itself can be viewed as a small preemption,
// because the GC might change Gwaiting to Gscanwaiting, and then
// this goroutine has to wait for the GC to finish before continuing.
// If the GC is in some way dependent on this goroutine (for example,
// it needs a lock held by the goroutine), that small preemption turns
// into a real deadlock.
preempt := stackguard0 == stackPreempt
if preempt {
if !canPreemptM(thisg.m) {
// Let the goroutine keep running for now.
// gp->preempt is set, so it will be preempted next time.
gp.stackguard0 = gp.stack.lo + _StackGuard
gogo(&gp.sched) // never return
}
}
if gp.stack.lo == 0 {
throw("missing stack in newstack")
}
sp := gp.sched.sp
if goarch.ArchFamily == goarch.AMD64 || goarch.ArchFamily == goarch.I386 || goarch.ArchFamily == goarch.WASM {
// The call to morestack cost a word.
sp -= goarch.PtrSize
}
if stackDebug >= 1 || sp < gp.stack.lo {
print("runtime: newstack sp=", hex(sp), " stack=[", hex(gp.stack.lo), ", ", hex(gp.stack.hi), "]\n",
"\tmorebuf={pc:", hex(morebuf.pc), " sp:", hex(morebuf.sp), " lr:", hex(morebuf.lr), "}\n",
"\tsched={pc:", hex(gp.sched.pc), " sp:", hex(gp.sched.sp), " lr:", hex(gp.sched.lr), " ctxt:", gp.sched.ctxt, "}\n")
}
if sp < gp.stack.lo {
print("runtime: gp=", gp, ", goid=", gp.goid, ", gp->status=", hex(readgstatus(gp)), "\n ")
print("runtime: split stack overflow: ", hex(sp), " < ", hex(gp.stack.lo), "\n")
throw("runtime: split stack overflow")
}
if preempt {
if gp == thisg.m.g0 {
throw("runtime: preempt g0")
}
if thisg.m.p == 0 && thisg.m.locks == 0 {
throw("runtime: g is running but p is not")
}
if gp.preemptShrink {
// We're at a synchronous safe point now, so
// do the pending stack shrink.
gp.preemptShrink = false
shrinkstack(gp)
}
if gp.preemptStop {
preemptPark(gp) // never returns
}
// Act like goroutine called runtime.Gosched.
gopreempt_m(gp) // never return
}
// Allocate a bigger segment and move the stack.
oldsize := gp.stack.hi - gp.stack.lo
newsize := oldsize * 2
// Make sure we grow at least as much as needed to fit the new frame.
// (This is just an optimization - the caller of morestack will
// recheck the bounds on return.)
if f := findfunc(gp.sched.pc); f.valid() {
max := uintptr(funcMaxSPDelta(f))
needed := max + _StackGuard
used := gp.stack.hi - gp.sched.sp
for newsize-used < needed {
newsize *= 2
}
}
if stackguard0 == stackForceMove {
// Forced stack movement used for debugging.
// Don't double the stack (or we may quickly run out
// if this is done repeatedly).
newsize = oldsize
}
if newsize > maxstacksize || newsize > maxstackceiling {
if maxstacksize < maxstackceiling {
print("runtime: goroutine stack exceeds ", maxstacksize, "-byte limit\n")
} else {
print("runtime: goroutine stack exceeds ", maxstackceiling, "-byte limit\n")
}
print("runtime: sp=", hex(sp), " stack=[", hex(gp.stack.lo), ", ", hex(gp.stack.hi), "]\n")
throw("stack overflow")
}
// The goroutine must be executing in order to call newstack,
// so it must be Grunning (or Gscanrunning).
casgstatus(gp, _Grunning, _Gcopystack)
// The concurrent GC will not scan the stack while we are doing the copy since
// the gp is in a Gcopystack status.
copystack(gp, newsize)
if stackDebug >= 1 {
print("stack grow done\n")
}
casgstatus(gp, _Gcopystack, _Grunning)
gogo(&gp.sched)
}
//go:nosplit
func nilfunc() {
*(*uint8)(nil) = 0
}
// adjust Gobuf as if it executed a call to fn
// and then stopped before the first instruction in fn.
func gostartcallfn(gobuf *gobuf, fv *funcval) {
var fn unsafe.Pointer
if fv != nil {
fn = unsafe.Pointer(fv.fn)
} else {
fn = unsafe.Pointer(abi.FuncPCABIInternal(nilfunc))
}
gostartcall(gobuf, fn, unsafe.Pointer(fv))
}
// isShrinkStackSafe returns whether it's safe to attempt to shrink
// gp's stack. Shrinking the stack is only safe when we have precise
// pointer maps for all frames on the stack.
func isShrinkStackSafe(gp *g) bool {
// We can't copy the stack if we're in a syscall.
// The syscall might have pointers into the stack and
// often we don't have precise pointer maps for the innermost
// frames.
//
// We also can't copy the stack if we're at an asynchronous
// safe-point because we don't have precise pointer maps for
// all frames.
//
// We also can't *shrink* the stack in the window between the
// goroutine calling gopark to park on a channel and
// gp.activeStackChans being set.
return gp.syscallsp == 0 && !gp.asyncSafePoint && !gp.parkingOnChan.Load()
}
// Maybe shrink the stack being used by gp.
//
// gp must be stopped and we must own its stack. It may be in
// _Grunning, but only if this is our own user G.
func shrinkstack(gp *g) {
if gp.stack.lo == 0 {
throw("missing stack in shrinkstack")
}
if s := readgstatus(gp); s&_Gscan == 0 {
// We don't own the stack via _Gscan. We could still
// own it if this is our own user G and we're on the
// system stack.
if !(gp == getg().m.curg && getg() != getg().m.curg && s == _Grunning) {
// We don't own the stack.
throw("bad status in shrinkstack")
}
}
if !isShrinkStackSafe(gp) {
throw("shrinkstack at bad time")
}
// Check for self-shrinks while in a libcall. These may have
// pointers into the stack disguised as uintptrs, but these
// code paths should all be nosplit.
if gp == getg().m.curg && gp.m.libcallsp != 0 {
throw("shrinking stack in libcall")
}
if debug.gcshrinkstackoff > 0 {
return
}
f := findfunc(gp.startpc)
if f.valid() && f.funcID == funcID_gcBgMarkWorker {
// We're not allowed to shrink the gcBgMarkWorker
// stack (see gcBgMarkWorker for explanation).
return
}
oldsize := gp.stack.hi - gp.stack.lo
newsize := oldsize / 2
// Don't shrink the allocation below the minimum-sized stack
// allocation.
if newsize < _FixedStack {
return
}
// Compute how much of the stack is currently in use and only
// shrink the stack if gp is using less than a quarter of its
// current stack. The currently used stack includes everything
// down to the SP plus the stack guard space that ensures
// there's room for nosplit functions.
avail := gp.stack.hi - gp.stack.lo
if used := gp.stack.hi - gp.sched.sp + _StackLimit; used >= avail/4 {
return
}
if stackDebug > 0 {
print("shrinking stack ", oldsize, "->", newsize, "\n")
}
copystack(gp, newsize)
}
// freeStackSpans frees unused stack spans at the end of GC.
func freeStackSpans() {
// Scan stack pools for empty stack spans.
for order := range stackpool {
lock(&stackpool[order].item.mu)
list := &stackpool[order].item.span
for s := list.first; s != nil; {
next := s.next
if s.allocCount == 0 {
list.remove(s)
s.manualFreeList = 0
osStackFree(s)
mheap_.freeManual(s, spanAllocStack)
}
s = next
}
unlock(&stackpool[order].item.mu)
}
// Free large stack spans.
lock(&stackLarge.lock)
for i := range stackLarge.free {
for s := stackLarge.free[i].first; s != nil; {
next := s.next
stackLarge.free[i].remove(s)
osStackFree(s)
mheap_.freeManual(s, spanAllocStack)
s = next
}
}
unlock(&stackLarge.lock)
}
// A stackObjectRecord is generated by the compiler for each stack object in a stack frame.
// This record must match the generator code in cmd/compile/internal/liveness/plive.go:emitStackObjects.
type stackObjectRecord struct {
// offset in frame
// if negative, offset from varp
// if non-negative, offset from argp
off int32
size int32
_ptrdata int32 // ptrdata, or -ptrdata if a GC prog is used
gcdataoff uint32 // offset to gcdata from moduledata.rodata
}
func (r *stackObjectRecord) useGCProg() bool {
return r._ptrdata < 0
}
func (r *stackObjectRecord) ptrdata() uintptr {
x := r._ptrdata
if x < 0 {
return uintptr(-x)
}
return uintptr(x)
}
// gcdata returns pointer map or GC prog of the type.
func (r *stackObjectRecord) gcdata() *byte {
ptr := uintptr(unsafe.Pointer(r))
var mod *moduledata
for datap := &firstmoduledata; datap != nil; datap = datap.next {
if datap.gofunc <= ptr && ptr < datap.end {
mod = datap
break
}
}
// If you get a panic here due to a nil mod,
// you may have made a copy of a stackObjectRecord.
// You must use the original pointer.
res := mod.rodata + uintptr(r.gcdataoff)
return (*byte)(unsafe.Pointer(res))
}
// This is exported as ABI0 via linkname so obj can call it.
//
//go:nosplit
//go:linkname morestackc
func morestackc() {
throw("attempt to execute system stack code on user stack")
}
// startingStackSize is the amount of stack that new goroutines start with.
// It is a power of 2, and between _FixedStack and maxstacksize, inclusive.
// startingStackSize is updated every GC by tracking the average size of
// stacks scanned during the GC.
var startingStackSize uint32 = _FixedStack
func gcComputeStartingStackSize() {
if debug.adaptivestackstart == 0 {
return
}
// For details, see the design doc at
// https://docs.google.com/document/d/1YDlGIdVTPnmUiTAavlZxBI1d9pwGQgZT7IKFKlIXohQ/edit?usp=sharing
// The basic algorithm is to track the average size of stacks
// and start goroutines with stack equal to that average size.
// Starting at the average size uses at most 2x the space that
// an ideal algorithm would have used.
// This is just a heuristic to avoid excessive stack growth work
// early in a goroutine's lifetime. See issue 18138. Stacks that
// are allocated too small can still grow, and stacks allocated
// too large can still shrink.
var scannedStackSize uint64
var scannedStacks uint64
for _, p := range allp {
scannedStackSize += p.scannedStackSize
scannedStacks += p.scannedStacks
// Reset for next time
p.scannedStackSize = 0
p.scannedStacks = 0
}
if scannedStacks == 0 {
startingStackSize = _FixedStack
return
}
avg := scannedStackSize/scannedStacks + _StackGuard
// Note: we add _StackGuard to ensure that a goroutine that
// uses the average space will not trigger a growth.
if avg > uint64(maxstacksize) {
avg = uint64(maxstacksize)
}
if avg < _FixedStack {
avg = _FixedStack
}
// Note: maxstacksize fits in 30 bits, so avg also does.
startingStackSize = uint32(round2(int32(avg)))
}
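// Worked example (assuming sys.StackGuardMultiplier == 1, so _StackGuard is
// 928): if the last cycle scanned 64 stacks totalling 320 KiB, then
// avg = 327680/64 + 928 = 6048, which is within [_FixedStack, maxstacksize],
// and round2(6048) = 8192, so new goroutines start with 8 KiB stacks.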
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"internal/abi"
"internal/goarch"
"runtime/internal/sys"
"unsafe"
)
// A stkframe holds information about a single physical stack frame.
type stkframe struct {
// fn is the function being run in this frame. If there is
// inlining, this is the outermost function.
fn funcInfo
// pc is the program counter within fn.
//
// The meaning of this is subtle:
//
// - Typically, this frame performed a regular function call
// and this is the return PC (just after the CALL
// instruction). In this case, pc-1 reflects the CALL
// instruction itself and is the correct source of symbolic
// information.
//
// - If this frame "called" sigpanic, then pc is the
// instruction that panicked, and pc is the correct address
// to use for symbolic information.
//
// - If this is the innermost frame, then PC is where
// execution will continue, but it may not be the
// instruction following a CALL. This may be from
// cooperative preemption, in which case this is the
// instruction after the call to morestack. Or this may be
// from a signal or an un-started goroutine, in which case
// PC could be any instruction, including the first
// instruction in a function. Conventionally, we use pc-1
// for symbolic information, unless pc == fn.entry(), in
// which case we use pc.
pc uintptr
// continpc is the PC where execution will continue in fn, or
// 0 if execution will not continue in this frame.
//
// This is usually the same as pc, unless this frame "called"
// sigpanic, in which case it's either the address of
// deferreturn or 0 if this frame will never execute again.
//
// This is the PC to use to look up GC liveness for this frame.
continpc uintptr
lr uintptr // program counter at caller aka link register
sp uintptr // stack pointer at pc
fp uintptr // stack pointer at caller aka frame pointer
varp uintptr // top of local variables
argp uintptr // pointer to function arguments
}
// reflectMethodValue is a partial duplicate of reflect.makeFuncImpl
// and reflect.methodValue.
type reflectMethodValue struct {
fn uintptr
stack *bitvector // ptrmap for both args and results
argLen uintptr // just args
}
// argBytes returns the argument frame size for a call to frame.fn.
func (frame *stkframe) argBytes() uintptr {
if frame.fn.args != _ArgsSizeUnknown {
return uintptr(frame.fn.args)
}
// This is an uncommon and complicated case. Fall back to fully
// fetching the argument map to compute its size.
argMap, _ := frame.argMapInternal()
return uintptr(argMap.n) * goarch.PtrSize
}
// argMapInternal is used internally by stkframe to fetch special
// argument maps.
//
// argMap.n is always populated with the size of the argument map.
//
// argMap.bytedata is only populated for dynamic argument maps (used
// by reflect). If the caller requires the argument map, it should use
// this if non-nil, and otherwise fetch the argument map using the
// current PC.
//
// hasReflectStackObj indicates that this frame also has a reflect
// function stack object, which the caller must synthesize.
func (frame *stkframe) argMapInternal() (argMap bitvector, hasReflectStackObj bool) {
f := frame.fn
if f.args != _ArgsSizeUnknown {
argMap.n = f.args / goarch.PtrSize
return
}
// Extract argument bitmaps for reflect stubs from the calls they made to reflect.
switch funcname(f) {
case "reflect.makeFuncStub", "reflect.methodValueCall":
// These take a *reflect.methodValue as their
// context register and immediately save it to 0(SP).
// Get the methodValue from 0(SP).
arg0 := frame.sp + sys.MinFrameSize
minSP := frame.fp
if !usesLR {
// The CALL itself pushes a word.
// Undo that adjustment.
minSP -= goarch.PtrSize
}
if arg0 >= minSP {
// The function hasn't started yet.
// This only happens if f was the
// start function of a new goroutine
// that hasn't run yet *and* f takes
// no arguments and has no results
// (otherwise it will get wrapped in a
// closure). In this case, we can't
// reach into its locals because it
// doesn't have locals yet, but we
// also know its argument map is
// empty.
if frame.pc != f.entry() {
print("runtime: confused by ", funcname(f), ": no frame (sp=", hex(frame.sp), " fp=", hex(frame.fp), ") at entry+", hex(frame.pc-f.entry()), "\n")
throw("reflect mismatch")
}
return bitvector{}, false // No locals, so also no stack objects
}
hasReflectStackObj = true
mv := *(**reflectMethodValue)(unsafe.Pointer(arg0))
// Figure out whether the return values are valid.
// Reflect will update this value after it copies
// in the return values.
retValid := *(*bool)(unsafe.Pointer(arg0 + 4*goarch.PtrSize))
if mv.fn != f.entry() {
print("runtime: confused by ", funcname(f), "\n")
throw("reflect mismatch")
}
argMap = *mv.stack
if !retValid {
// argMap.n includes the results, but
// those aren't valid, so drop them.
n := int32((uintptr(mv.argLen) &^ (goarch.PtrSize - 1)) / goarch.PtrSize)
if n < argMap.n {
argMap.n = n
}
}
}
return
}
// getStackMap returns the locals and arguments live pointer maps, and
// stack object list for frame.
func (frame *stkframe) getStackMap(cache *pcvalueCache, debug bool) (locals, args bitvector, objs []stackObjectRecord) {
targetpc := frame.continpc
if targetpc == 0 {
// Frame is dead. Return empty bitvectors.
return
}
f := frame.fn
pcdata := int32(-1)
if targetpc != f.entry() {
// Back up to the CALL. If we're at the function entry
// point, we want to use the entry map (-1), even if
// the first instruction of the function changes the
// stack map.
targetpc--
pcdata = pcdatavalue(f, _PCDATA_StackMapIndex, targetpc, cache)
}
if pcdata == -1 {
// We do not have a valid pcdata value but there might be a
// stackmap for this function. It is likely that we are looking
// at the function prologue, assume so and hope for the best.
pcdata = 0
}
// Local variables.
size := frame.varp - frame.sp
var minsize uintptr
switch goarch.ArchFamily {
case goarch.ARM64:
minsize = sys.StackAlign
default:
minsize = sys.MinFrameSize
}
if size > minsize {
stackid := pcdata
stkmap := (*stackmap)(funcdata(f, _FUNCDATA_LocalsPointerMaps))
if stkmap == nil || stkmap.n <= 0 {
print("runtime: frame ", funcname(f), " untyped locals ", hex(frame.varp-size), "+", hex(size), "\n")
throw("missing stackmap")
}
// If nbit == 0, there's no work to do.
if stkmap.nbit > 0 {
if stackid < 0 || stackid >= stkmap.n {
// don't know where we are
print("runtime: pcdata is ", stackid, " and ", stkmap.n, " locals stack map entries for ", funcname(f), " (targetpc=", hex(targetpc), ")\n")
throw("bad symbol table")
}
locals = stackmapdata(stkmap, stackid)
if stackDebug >= 3 && debug {
print(" locals ", stackid, "/", stkmap.n, " ", locals.n, " words ", locals.bytedata, "\n")
}
} else if stackDebug >= 3 && debug {
print(" no locals to adjust\n")
}
}
// Arguments. First fetch frame size and special-case argument maps.
var isReflect bool
args, isReflect = frame.argMapInternal()
if args.n > 0 && args.bytedata == nil {
// Non-empty argument frame, but not a special map.
// Fetch the argument map at pcdata.
stackmap := (*stackmap)(funcdata(f, _FUNCDATA_ArgsPointerMaps))
if stackmap == nil || stackmap.n <= 0 {
print("runtime: frame ", funcname(f), " untyped args ", hex(frame.argp), "+", hex(args.n*goarch.PtrSize), "\n")
throw("missing stackmap")
}
if pcdata < 0 || pcdata >= stackmap.n {
// don't know where we are
print("runtime: pcdata is ", pcdata, " and ", stackmap.n, " args stack map entries for ", funcname(f), " (targetpc=", hex(targetpc), ")\n")
throw("bad symbol table")
}
if stackmap.nbit == 0 {
args.n = 0
} else {
args = stackmapdata(stackmap, pcdata)
}
}
// stack objects.
if (GOARCH == "amd64" || GOARCH == "arm64" || GOARCH == "ppc64" || GOARCH == "ppc64le" || GOARCH == "riscv64") &&
unsafe.Sizeof(abi.RegArgs{}) > 0 && isReflect {
// For reflect.makeFuncStub and reflect.methodValueCall,
// we need to fake the stack object record.
// These frames contain an internal/abi.RegArgs at a hard-coded offset.
// This offset matches the assembly code on amd64 and arm64.
objs = methodValueCallFrameObjs[:]
} else {
p := funcdata(f, _FUNCDATA_StackObjects)
if p != nil {
n := *(*uintptr)(p)
p = add(p, goarch.PtrSize)
r0 := (*stackObjectRecord)(noescape(p))
objs = unsafe.Slice(r0, int(n))
// Note: the noescape above is needed to keep
// getStackMap from "leaking param content:
// frame". That leak propagates up to getgcmask, then
// GCMask, then verifyGCInfo, which converts the stack
// gcinfo tests into heap gcinfo tests :(
}
}
return
}
var methodValueCallFrameObjs [1]stackObjectRecord // initialized in stkobjinit
func stkobjinit() {
var abiRegArgsEface any = abi.RegArgs{}
abiRegArgsType := efaceOf(&abiRegArgsEface)._type
if abiRegArgsType.kind&kindGCProg != 0 {
throw("abiRegArgsType needs GC Prog, update methodValueCallFrameObjs")
}
// Set methodValueCallFrameObjs[0].gcdataoff so that
// stackObjectRecord.gcdata() will work correctly with it.
ptr := uintptr(unsafe.Pointer(&methodValueCallFrameObjs[0]))
var mod *moduledata
for datap := &firstmoduledata; datap != nil; datap = datap.next {
if datap.gofunc <= ptr && ptr < datap.end {
mod = datap
break
}
}
if mod == nil {
throw("methodValueCallFrameObjs is not in a module")
}
methodValueCallFrameObjs[0] = stackObjectRecord{
off: -int32(alignUp(abiRegArgsType.size, 8)), // It's always the highest address local.
size: int32(abiRegArgsType.size),
_ptrdata: int32(abiRegArgsType.ptrdata),
gcdataoff: uint32(uintptr(unsafe.Pointer(abiRegArgsType.gcdata)) - mod.rodata),
}
}
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"internal/abi"
"internal/bytealg"
"internal/goarch"
"unsafe"
)
// The constant is known to the compiler.
// There is no fundamental theory behind this number.
const tmpStringBufSize = 32
type tmpBuf [tmpStringBufSize]byte
// concatstrings implements a Go string concatenation x+y+z+...
// The operands are passed in the slice a.
// If buf != nil, the compiler has determined that the result does not
// escape the calling function, so the string data can be stored in buf
// if small enough.
func concatstrings(buf *tmpBuf, a []string) string {
idx := 0
l := 0
count := 0
for i, x := range a {
n := len(x)
if n == 0 {
continue
}
if l+n < l {
throw("string concatenation too long")
}
l += n
count++
idx = i
}
if count == 0 {
return ""
}
// If there is just one string and either it is not on the stack
// or our result does not escape the calling frame (buf != nil),
// then we can return that string directly.
if count == 1 && (buf != nil || !stringDataOnStack(a[idx])) {
return a[idx]
}
s, b := rawstringtmp(buf, l)
for _, x := range a {
copy(b, x)
b = b[len(x):]
}
return s
}
func concatstring2(buf *tmpBuf, a0, a1 string) string {
return concatstrings(buf, []string{a0, a1})
}
func concatstring3(buf *tmpBuf, a0, a1, a2 string) string {
return concatstrings(buf, []string{a0, a1, a2})
}
func concatstring4(buf *tmpBuf, a0, a1, a2, a3 string) string {
return concatstrings(buf, []string{a0, a1, a2, a3})
}
func concatstring5(buf *tmpBuf, a0, a1, a2, a3, a4 string) string {
return concatstrings(buf, []string{a0, a1, a2, a3, a4})
}
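// As an illustrative sketch (not part of the runtime), a concatenation such
// as x + y + z whose result the compiler proves non-escaping is lowered to a
// call through a stack-allocated tmpBuf, roughly:
//
//	var buf tmpBuf
//	s := concatstring3(&buf, x, y, z) // result may alias buf if it fits
//
// If the result escapes, the compiler passes buf == nil instead and
// rawstringtmp heap-allocates the storage.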
// slicebytetostring converts a byte slice to a string.
// It is inserted by the compiler into generated code.
// ptr is a pointer to the first element of the slice;
// n is the length of the slice.
// buf is a fixed-size buffer for the result;
// it is non-nil if the result does not escape the calling function.
func slicebytetostring(buf *tmpBuf, ptr *byte, n int) string {
if n == 0 {
// Turns out to be a relatively common case.
// Consider that you want to parse out data between parens in "foo()bar",
// you find the indices and convert the subslice to string.
return ""
}
if raceenabled {
racereadrangepc(unsafe.Pointer(ptr),
uintptr(n),
getcallerpc(),
abi.FuncPCABIInternal(slicebytetostring))
}
if msanenabled {
msanread(unsafe.Pointer(ptr), uintptr(n))
}
if asanenabled {
asanread(unsafe.Pointer(ptr), uintptr(n))
}
if n == 1 {
p := unsafe.Pointer(&staticuint64s[*ptr])
if goarch.BigEndian {
p = add(p, 7)
}
return unsafe.String((*byte)(p), 1)
}
var p unsafe.Pointer
if buf != nil && n <= len(buf) {
p = unsafe.Pointer(buf)
} else {
p = mallocgc(uintptr(n), nil, false)
}
memmove(p, unsafe.Pointer(ptr), uintptr(n))
return unsafe.String((*byte)(p), n)
}
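// A hedged summary of the paths above, with illustrative conversions:
//
//	string([]byte{})        // n == 0: the empty string, no allocation
//	string([]byte{'A'})     // n == 1: points into staticuint64s, no allocation
//	string([]byte("hello")) // small and non-escaping: stored in buf;
//	                        // otherwise heap-allocated via mallocgc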
// stringDataOnStack reports whether the string's data is
// stored on the current goroutine's stack.
func stringDataOnStack(s string) bool {
ptr := uintptr(unsafe.Pointer(unsafe.StringData(s)))
stk := getg().stack
return stk.lo <= ptr && ptr < stk.hi
}
func rawstringtmp(buf *tmpBuf, l int) (s string, b []byte) {
if buf != nil && l <= len(buf) {
b = buf[:l]
s = slicebytetostringtmp(&b[0], len(b))
} else {
s, b = rawstring(l)
}
return
}
// slicebytetostringtmp returns a "string" referring to the actual []byte bytes.
//
// Callers need to ensure that the returned string will not be used after
// the calling goroutine modifies the original slice or synchronizes with
// another goroutine.
//
// This function is only called when instrumenting;
// otherwise it is intrinsified by the compiler.
//
// Some internal compiler optimizations use this function.
// - Used for m[T1{... Tn{..., string(k), ...} ...}] and m[string(k)]
// where k is []byte, T1 to Tn is a nesting of struct and array literals.
// - Used for "<"+string(b)+">" concatenation where b is []byte.
// - Used for string(b)=="foo" comparison where b is []byte.
func slicebytetostringtmp(ptr *byte, n int) string {
if raceenabled && n > 0 {
racereadrangepc(unsafe.Pointer(ptr),
uintptr(n),
getcallerpc(),
abi.FuncPCABIInternal(slicebytetostringtmp))
}
if msanenabled && n > 0 {
msanread(unsafe.Pointer(ptr), uintptr(n))
}
if asanenabled && n > 0 {
asanread(unsafe.Pointer(ptr), uintptr(n))
}
return unsafe.String(ptr, n)
}
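// Illustrative (assumed) source patterns that the compiler rewrites to use
// slicebytetostringtmp, per the list above; in each one the temporary string
// never outlives the statement, so aliasing b's storage is safe:
//
//	v, ok := m[string(b)]         // map index: the lookup does not retain the key
//	s := "<" + string(b) + ">"    // concatenation copies the bytes immediately
//	if string(b) == "foo" { ... } // comparison reads the bytes, then discards them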
func stringtoslicebyte(buf *tmpBuf, s string) []byte {
var b []byte
if buf != nil && len(s) <= len(buf) {
*buf = tmpBuf{}
b = buf[:len(s)]
} else {
b = rawbyteslice(len(s))
}
copy(b, s)
return b
}
func stringtoslicerune(buf *[tmpStringBufSize]rune, s string) []rune {
// two passes.
// unlike slicerunetostring, no race because strings are immutable.
n := 0
for range s {
n++
}
var a []rune
if buf != nil && n <= len(buf) {
*buf = [tmpStringBufSize]rune{}
a = buf[:n]
} else {
a = rawruneslice(n)
}
n = 0
for _, r := range s {
a[n] = r
n++
}
return a
}
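// A small worked example of the two passes (a sketch): "héllo" is 6 bytes
// but ranges as 5 runes, so the first pass counts
//
//	n := 0
//	for range "héllo" { n++ } // n == 5, though len("héllo") == 6
//
// and the second pass stores the 5 decoded runes into a buffer of exactly
// that length.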
func slicerunetostring(buf *tmpBuf, a []rune) string {
if raceenabled && len(a) > 0 {
racereadrangepc(unsafe.Pointer(&a[0]),
uintptr(len(a))*unsafe.Sizeof(a[0]),
getcallerpc(),
abi.FuncPCABIInternal(slicerunetostring))
}
if msanenabled && len(a) > 0 {
msanread(unsafe.Pointer(&a[0]), uintptr(len(a))*unsafe.Sizeof(a[0]))
}
if asanenabled && len(a) > 0 {
asanread(unsafe.Pointer(&a[0]), uintptr(len(a))*unsafe.Sizeof(a[0]))
}
var dum [4]byte
size1 := 0
for _, r := range a {
size1 += encoderune(dum[:], r)
}
s, b := rawstringtmp(buf, size1+3)
size2 := 0
for _, r := range a {
// check for race
if size2 >= size1 {
break
}
size2 += encoderune(b[size2:], r)
}
return s[:size2]
}
type stringStruct struct {
str unsafe.Pointer
len int
}
// Variant with *byte pointer type for DWARF debugging.
type stringStructDWARF struct {
str *byte
len int
}
func stringStructOf(sp *string) *stringStruct {
return (*stringStruct)(unsafe.Pointer(sp))
}
func intstring(buf *[4]byte, v int64) (s string) {
var b []byte
if buf != nil {
b = buf[:]
s = slicebytetostringtmp(&b[0], len(b))
} else {
s, b = rawstring(4)
}
if int64(rune(v)) != v {
v = runeError
}
n := encoderune(b, rune(v))
return s[:n]
}
// rawstring allocates storage for a new string. The returned
// string and byte slice both refer to the same storage.
// The storage is not zeroed. Callers should use
// b to set the string contents and then drop b.
func rawstring(size int) (s string, b []byte) {
p := mallocgc(uintptr(size), nil, false)
return unsafe.String((*byte)(p), size), unsafe.Slice((*byte)(p), size)
}
// rawbyteslice allocates a new byte slice. The byte slice is not zeroed.
func rawbyteslice(size int) (b []byte) {
cap := roundupsize(uintptr(size))
p := mallocgc(cap, nil, false)
if cap != uintptr(size) {
memclrNoHeapPointers(add(p, uintptr(size)), cap-uintptr(size))
}
*(*slice)(unsafe.Pointer(&b)) = slice{p, size, int(cap)}
return
}
// rawruneslice allocates a new rune slice. The rune slice is not zeroed.
func rawruneslice(size int) (b []rune) {
if uintptr(size) > maxAlloc/4 {
throw("out of memory")
}
mem := roundupsize(uintptr(size) * 4)
p := mallocgc(mem, nil, false)
if mem != uintptr(size)*4 {
memclrNoHeapPointers(add(p, uintptr(size)*4), mem-uintptr(size)*4)
}
*(*slice)(unsafe.Pointer(&b)) = slice{p, size, int(mem / 4)}
return
}
// used by cmd/cgo
func gobytes(p *byte, n int) (b []byte) {
if n == 0 {
return make([]byte, 0)
}
if n < 0 || uintptr(n) > maxAlloc {
panic(errorString("gobytes: length out of range"))
}
bp := mallocgc(uintptr(n), nil, false)
memmove(bp, unsafe.Pointer(p), uintptr(n))
*(*slice)(unsafe.Pointer(&b)) = slice{bp, n, n}
return
}
// This is exported via linkname to assembly in syscall (for Plan9).
//
//go:linkname gostring
func gostring(p *byte) string {
l := findnull(p)
if l == 0 {
return ""
}
s, b := rawstring(l)
memmove(unsafe.Pointer(&b[0]), unsafe.Pointer(p), uintptr(l))
return s
}
// internal_syscall_gostring is a version of gostring for internal/syscall/unix.
//
//go:linkname internal_syscall_gostring internal/syscall/unix.gostring
func internal_syscall_gostring(p *byte) string {
return gostring(p)
}
func gostringn(p *byte, l int) string {
if l == 0 {
return ""
}
s, b := rawstring(l)
memmove(unsafe.Pointer(&b[0]), unsafe.Pointer(p), uintptr(l))
return s
}
func hasPrefix(s, prefix string) bool {
return len(s) >= len(prefix) && s[:len(prefix)] == prefix
}
const (
maxUint64 = ^uint64(0)
maxInt64 = int64(maxUint64 >> 1)
)
// atoi64 parses an int64 from a string s.
// The bool result reports whether s is a number
// representable by a value of type int64.
func atoi64(s string) (int64, bool) {
if s == "" {
return 0, false
}
neg := false
if s[0] == '-' {
neg = true
s = s[1:]
}
un := uint64(0)
for i := 0; i < len(s); i++ {
c := s[i]
if c < '0' || c > '9' {
return 0, false
}
if un > maxUint64/10 {
// overflow
return 0, false
}
un *= 10
un1 := un + uint64(c) - '0'
if un1 < un {
// overflow
return 0, false
}
un = un1
}
if !neg && un > uint64(maxInt64) {
return 0, false
}
if neg && un > uint64(maxInt64)+1 {
return 0, false
}
n := int64(un)
if neg {
n = -n
}
return n, true
}
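// Illustrative behavior (a sketch, not exhaustive tests):
//
//	atoi64("9223372036854775807")  // maxInt64:  (9223372036854775807, true)
//	atoi64("-9223372036854775808") // minInt64:  (-9223372036854775808, true)
//	atoi64("9223372036854775808")  // overflow:  (0, false)
//	atoi64("12a")                  // non-digit: (0, false)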
// atoi is like atoi64 but for integers
// that fit into an int.
func atoi(s string) (int, bool) {
if n, ok := atoi64(s); n == int64(int(n)) {
return int(n), ok
}
return 0, false
}
// atoi32 is like atoi but for integers
// that fit into an int32.
func atoi32(s string) (int32, bool) {
if n, ok := atoi64(s); n == int64(int32(n)) {
return int32(n), ok
}
return 0, false
}
// parseByteCount parses a string that represents a count of bytes.
//
// s must match the following regular expression:
//
// ^[0-9]+(([KMGT]i)?B)?$
//
// In other words, an integer byte count with an optional unit
// suffix. Acceptable suffixes include one of
// - KiB, MiB, GiB, TiB which represent binary IEC/ISO 80000 units, or
// - B, which just represents bytes.
//
// Returns an int64 because that's what its callers want and receive,
// but the result is always non-negative.
func parseByteCount(s string) (int64, bool) {
// The empty string is not valid.
if s == "" {
return 0, false
}
// Handle the easy non-suffix case.
last := s[len(s)-1]
if last >= '0' && last <= '9' {
n, ok := atoi64(s)
if !ok || n < 0 {
return 0, false
}
return n, ok
}
// Failing a trailing digit, this must always end in 'B'.
// Also at this point there must be at least one digit before
// that B.
if last != 'B' || len(s) < 2 {
return 0, false
}
// The one before that must always be a digit or 'i'.
if c := s[len(s)-2]; c >= '0' && c <= '9' {
// Trivial 'B' suffix.
n, ok := atoi64(s[:len(s)-1])
if !ok || n < 0 {
return 0, false
}
return n, ok
} else if c != 'i' {
return 0, false
}
// Finally, we need at least 4 characters now, for the unit
// prefix and at least one digit.
if len(s) < 4 {
return 0, false
}
power := 0
switch s[len(s)-3] {
case 'K':
power = 1
case 'M':
power = 2
case 'G':
power = 3
case 'T':
power = 4
default:
// Invalid suffix.
return 0, false
}
m := uint64(1)
for i := 0; i < power; i++ {
m *= 1024
}
n, ok := atoi64(s[:len(s)-3])
if !ok || n < 0 {
return 0, false
}
un := uint64(n)
if un > maxUint64/m {
// Overflow.
return 0, false
}
un *= m
if un > uint64(maxInt64) {
// Overflow.
return 0, false
}
return int64(un), true
}
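// Illustrative inputs and results (a sketch):
//
//	parseByteCount("1024")  // (1024, true)
//	parseByteCount("1024B") // (1024, true)
//	parseByteCount("4KiB")  // (4096, true)
//	parseByteCount("1MiB")  // (1048576, true)
//	parseByteCount("1KB")   // (0, false): decimal SI suffixes are not accepted
//	parseByteCount("KiB")   // (0, false): no digits before the suffix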
//go:nosplit
func findnull(s *byte) int {
if s == nil {
return 0
}
// Avoid IndexByteString on Plan 9 because it uses SSE instructions
// on x86 machines, and those are classified as floating point instructions,
// which are illegal in a note handler.
if GOOS == "plan9" {
p := (*[maxAlloc/2 - 1]byte)(unsafe.Pointer(s))
l := 0
for p[l] != 0 {
l++
}
return l
}
// pageSize is the unit we scan at a time looking for NULL.
// It must be the minimum page size for any architecture Go
// runs on. It's okay (just a minor performance loss) if the
// actual system page size is larger than this value.
const pageSize = 4096
offset := 0
ptr := unsafe.Pointer(s)
// IndexByteString uses wide reads, so we need to be careful
// with page boundaries. Call IndexByteString on
// the [ptr, endOfPage) interval.
safeLen := int(pageSize - uintptr(ptr)%pageSize)
for {
t := *(*string)(unsafe.Pointer(&stringStruct{ptr, safeLen}))
// Check one page at a time.
if i := bytealg.IndexByteString(t, 0); i != -1 {
return offset + i
}
// Move to next page
ptr = unsafe.Pointer(uintptr(ptr) + uintptr(safeLen))
offset += safeLen
safeLen = pageSize
}
}
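// A worked sketch of the page-boundary logic above: if s points at an
// address whose low bits are 0xff8 (8 bytes before a page boundary), then
//
//	safeLen = 4096 - 0xff8%4096 // == 8
//
// so the first IndexByteString call reads only the 8 bytes up to the
// boundary; each later iteration scans one full page.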
func findnullw(s *uint16) int {
if s == nil {
return 0
}
p := (*[maxAlloc/2/2 - 1]uint16)(unsafe.Pointer(s))
l := 0
for p[l] != 0 {
l++
}
return l
}
//go:nosplit
func gostringnocopy(str *byte) string {
ss := stringStruct{str: unsafe.Pointer(str), len: findnull(str)}
s := *(*string)(unsafe.Pointer(&ss))
return s
}
func gostringw(strw *uint16) string {
var buf [8]byte
str := (*[maxAlloc/2/2 - 1]uint16)(unsafe.Pointer(strw))
n1 := 0
for i := 0; str[i] != 0; i++ {
n1 += encoderune(buf[:], rune(str[i]))
}
s, b := rawstring(n1 + 4)
n2 := 0
for i := 0; str[i] != 0; i++ {
// check for race
if n2 >= n1 {
break
}
n2 += encoderune(b[n2:], rune(str[i]))
}
b[n2] = 0 // for luck
return s[:n2]
}
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"internal/abi"
"internal/goarch"
"runtime/internal/math"
"unsafe"
)
// Should be a built-in for unsafe.Pointer?
//
//go:nosplit
func add(p unsafe.Pointer, x uintptr) unsafe.Pointer {
return unsafe.Pointer(uintptr(p) + x)
}
// getg returns the pointer to the current g.
// The compiler rewrites calls to this function into instructions
// that fetch the g directly (from TLS or from the dedicated register).
func getg() *g
// mcall switches from the g to the g0 stack and invokes fn(g),
// where g is the goroutine that made the call.
// mcall saves g's current PC/SP in g->sched so that it can be restored later.
// It is up to fn to arrange for that later execution, typically by recording
// g in a data structure, causing something to call ready(g) later.
// mcall returns to the original goroutine g later, when g has been rescheduled.
// fn must not return at all; typically it ends by calling schedule, to let the m
// run other goroutines.
//
// mcall can only be called from g stacks (not g0, not gsignal).
//
// This must NOT be go:noescape: if fn is a stack-allocated closure,
// fn puts g on a run queue, and g executes before fn returns, the
// closure will be invalidated while it is still executing.
func mcall(fn func(*g))
// systemstack runs fn on a system stack.
// If systemstack is called from the per-OS-thread (g0) stack, or
// if systemstack is called from the signal handling (gsignal) stack,
// systemstack calls fn directly and returns.
// Otherwise, systemstack is being called from the limited stack
// of an ordinary goroutine. In this case, systemstack switches
// to the per-OS-thread stack, calls fn, and switches back.
// It is common to use a func literal as the argument, in order
// to share inputs and outputs with the code around the call
// to system stack:
//
// ... set up y ...
// systemstack(func() {
// x = bigcall(y)
// })
// ... use x ...
//
//go:noescape
func systemstack(fn func())
//go:nosplit
//go:nowritebarrierrec
func badsystemstack() {
writeErrStr("fatal: systemstack called from unexpected goroutine")
}
// memclrNoHeapPointers clears n bytes starting at ptr.
//
// Usually you should use typedmemclr. memclrNoHeapPointers should be
// used only when the caller knows that *ptr contains no heap pointers
// because either:
//
// *ptr is initialized memory and its type is pointer-free, or
//
// *ptr is uninitialized memory (e.g., memory that's being reused
// for a new allocation) and hence contains only "junk".
//
// memclrNoHeapPointers ensures that if ptr is pointer-aligned, and n
// is a multiple of the pointer size, then any pointer-aligned,
// pointer-sized portion is cleared atomically. Despite the function
// name, this is necessary because this function is the underlying
// implementation of typedmemclr and memclrHasPointers. See the doc of
// memmove for more details.
//
// The (CPU-specific) implementations of this function are in memclr_*.s.
//
//go:noescape
func memclrNoHeapPointers(ptr unsafe.Pointer, n uintptr)
//go:linkname reflect_memclrNoHeapPointers reflect.memclrNoHeapPointers
func reflect_memclrNoHeapPointers(ptr unsafe.Pointer, n uintptr) {
memclrNoHeapPointers(ptr, n)
}
// memmove copies n bytes from "from" to "to".
//
// memmove ensures that any pointer in "from" is written to "to" with
// an indivisible write, so that racy reads cannot observe a
// half-written pointer. This is necessary to prevent the garbage
// collector from observing invalid pointers, and differs from memmove
// in unmanaged languages. However, memmove is only required to do
// this if "from" and "to" may contain pointers, which can only be the
// case if "from", "to", and "n" are all be word-aligned.
//
// Implementations are in memmove_*.s.
//
//go:noescape
func memmove(to, from unsafe.Pointer, n uintptr)
// Outside assembly calls memmove. Make sure it has ABI wrappers.
//
//go:linkname memmove
//go:linkname reflect_memmove reflect.memmove
func reflect_memmove(to, from unsafe.Pointer, n uintptr) {
memmove(to, from, n)
}
// exported value for testing
const hashLoad = float32(loadFactorNum) / float32(loadFactorDen)
//go:nosplit
func fastrand() uint32 {
mp := getg().m
// Implement wyrand: https://github.com/wangyi-fudan/wyhash
// Only platforms on which math.Mul64 can be lowered
// by the compiler should be in this list.
if goarch.IsAmd64|goarch.IsArm64|goarch.IsPpc64|
goarch.IsPpc64le|goarch.IsMips64|goarch.IsMips64le|
goarch.IsS390x|goarch.IsRiscv64|goarch.IsLoong64 == 1 {
mp.fastrand += 0xa0761d6478bd642f
hi, lo := math.Mul64(mp.fastrand, mp.fastrand^0xe7037ed1a0b428db)
return uint32(hi ^ lo)
}
// Implement xorshift64+: 2 32-bit xorshift sequences added together.
// Shift triplet [17,7,16] was calculated as indicated in Marsaglia's
// Xorshift paper: https://www.jstatsoft.org/article/view/v008i14/xorshift.pdf
// This generator passes the SmallCrush suite, part of TestU01 framework:
// http://simul.iro.umontreal.ca/testu01/tu01.html
t := (*[2]uint32)(unsafe.Pointer(&mp.fastrand))
s1, s0 := t[0], t[1]
s1 ^= s1 << 17
s1 = s1 ^ s0 ^ s1>>7 ^ s0>>16
t[0], t[1] = s0, s1
return s0 + s1
}
//go:nosplit
func fastrandn(n uint32) uint32 {
// This is similar to fastrand() % n, but faster.
// See https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
return uint32(uint64(fastrand()) * uint64(n) >> 32)
}
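// A brief sketch of the multiply-shift reduction above: for a 32-bit x,
// uint32(uint64(x)*uint64(n) >> 32) scales x from [0, 2^32) down to [0, n).
// For example, x == 1<<31 and n == 10 gives (2^31*10)>>32 == 5. It has the
// same small bias as x % n but replaces the division with a multiply and a
// shift.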
func fastrand64() uint64 {
mp := getg().m
// Implement wyrand: https://github.com/wangyi-fudan/wyhash
// Only platforms on which math.Mul64 can be lowered
// by the compiler should be in this list.
if goarch.IsAmd64|goarch.IsArm64|goarch.IsPpc64|
goarch.IsPpc64le|goarch.IsMips64|goarch.IsMips64le|
goarch.IsS390x|goarch.IsRiscv64 == 1 {
mp.fastrand += 0xa0761d6478bd642f
hi, lo := math.Mul64(mp.fastrand, mp.fastrand^0xe7037ed1a0b428db)
return hi ^ lo
}
// Implement xorshift64+: 2 32-bit xorshift sequences added together.
// Xorshift paper: https://www.jstatsoft.org/article/view/v008i14/xorshift.pdf
// This generator passes the SmallCrush suite, part of TestU01 framework:
// http://simul.iro.umontreal.ca/testu01/tu01.html
t := (*[2]uint32)(unsafe.Pointer(&mp.fastrand))
s1, s0 := t[0], t[1]
s1 ^= s1 << 17
s1 = s1 ^ s0 ^ s1>>7 ^ s0>>16
r := uint64(s0 + s1)
s0, s1 = s1, s0
s1 ^= s1 << 17
s1 = s1 ^ s0 ^ s1>>7 ^ s0>>16
r += uint64(s0+s1) << 32
t[0], t[1] = s0, s1
return r
}
func fastrandu() uint {
if goarch.PtrSize == 4 {
return uint(fastrand())
}
return uint(fastrand64())
}
//go:linkname rand_fastrand64 math/rand.fastrand64
func rand_fastrand64() uint64 { return fastrand64() }
//go:linkname sync_fastrandn sync.fastrandn
func sync_fastrandn(n uint32) uint32 { return fastrandn(n) }
//go:linkname net_fastrandu net.fastrandu
func net_fastrandu() uint { return fastrandu() }
//go:linkname os_fastrand os.fastrand
func os_fastrand() uint32 { return fastrand() }
// in internal/bytealg/equal_*.s
//
//go:noescape
func memequal(a, b unsafe.Pointer, size uintptr) bool
// noescape hides a pointer from escape analysis. noescape is
// the identity function but escape analysis doesn't think the
// output depends on the input. noescape is inlined and currently
// compiles down to zero instructions.
// USE CAREFULLY!
//
//go:nosplit
func noescape(p unsafe.Pointer) unsafe.Pointer {
x := uintptr(p)
return unsafe.Pointer(x ^ 0)
}
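// Illustrative (assumed) usage: hide a stack pointer from a callee whose
// parameter would otherwise force the variable to escape, when the caller
// knows the pointer is not retained:
//
//	var x int64
//	p := noescape(unsafe.Pointer(&x)) // x can stay stack-allocated
//
// The xor with 0 is an identity operation that escape analysis cannot see
// through, so the result appears unrelated to the input.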
// Not all cgocallback frames are actually cgocallback,
// so not all have these arguments. Mark them uintptr so that the GC
// does not misinterpret memory when the arguments are not present.
// cgocallback is not called from Go, only from crosscall2.
// This in turn calls cgocallbackg, which is where we'll find
// pointer-declared arguments.
func cgocallback(fn, frame, ctxt uintptr)
func gogo(buf *gobuf)
func asminit()
func setg(gg *g)
func breakpoint()
// reflectcall calls fn with arguments described by stackArgs, stackArgsSize,
// frameSize, and regArgs.
//
// Arguments passed on the stack and space for return values passed on the stack
// must be laid out at the space pointed to by stackArgs (with total length
// stackArgsSize) according to the ABI.
//
// stackRetOffset must be some value <= stackArgsSize that indicates the
// offset within stackArgs where the return value space begins.
//
// frameSize is the total size of the argument frame at stackArgs and must
// therefore be >= stackArgsSize. It must include additional space for spilling
// register arguments for stack growth and preemption.
//
// TODO(mknyszek): Once we don't need the additional spill space, remove frameSize,
// since frameSize will be redundant with stackArgsSize.
//
// Arguments passed in registers must be laid out in regArgs according to the ABI.
// regArgs will hold any return values passed in registers after the call.
//
// reflectcall copies stack arguments from stackArgs to the goroutine stack, and
// then copies back stackArgsSize-stackRetOffset bytes back to the return space
// in stackArgs once fn has completed. It also "unspills" argument registers from
// regArgs before calling fn, and spills them back into regArgs immediately
// following the call to fn. If there are results being returned on the stack,
// the caller should pass the argument frame type as stackArgsType so that
// reflectcall can execute appropriate write barriers during the copy.
//
// reflectcall expects regArgs.ReturnIsPtr to be populated indicating which
// registers on the return path will contain Go pointers. It will then store
// these pointers in regArgs.Ptrs such that they are visible to the GC.
//
// Package reflect passes a frame type. In package runtime, there is only
// one call that copies results back, in callbackWrap in syscall_windows.go, and it
// does NOT pass a frame type, meaning there are no write barriers invoked. See that
// call site for justification.
//
// Package reflect accesses this symbol through a linkname.
//
// Arguments passed through to reflectcall do not escape. The type is used
// only in a very limited callee of reflectcall, the stackArgs are copied, and
// regArgs is only used in the reflectcall frame.
//
//go:noescape
func reflectcall(stackArgsType *_type, fn, stackArgs unsafe.Pointer, stackArgsSize, stackRetOffset, frameSize uint32, regArgs *abi.RegArgs)
func procyield(cycles uint32)
type neverCallThisFunction struct{}
// goexit is the return stub at the top of every goroutine call stack.
// Each goroutine stack is constructed as if goexit called the
// goroutine's entry point function, so that when the entry point
// function returns, it will return to goexit, which will call goexit1
// to perform the actual exit.
//
// This function must never be called directly. Call goexit1 instead.
// gentraceback assumes that goexit terminates the stack. A direct
// call on the stack will cause gentraceback to stop walking the stack
// prematurely and if there is leftover state it may panic.
func goexit(neverCallThisFunction)
// publicationBarrier performs a store/store barrier (a "publication"
// or "export" barrier). Some form of synchronization is required
// between initializing an object and making that object accessible to
// another processor. Without synchronization, the initialization
// writes and the "publication" write may be reordered, allowing the
// other processor to follow the pointer and observe an uninitialized
// object. In general, higher-level synchronization should be used,
// such as locking or an atomic pointer write. publicationBarrier is
// for when those aren't an option, such as in the implementation of
// the memory manager.
//
// There's no corresponding barrier for the read side because the read
// side naturally has a data dependency order. All architectures that
// Go supports or seems likely to ever support automatically enforce
// data dependency ordering.
func publicationBarrier()
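// An illustrative (assumed) publication pattern the barrier supports, where
// shared is later read by another processor without other synchronization:
//
//	obj := newObject()   // hypothetical allocator
//	obj.field = 42       // initialization writes
//	publicationBarrier() // order the writes above before the publish below
//	shared = obj         // publish the pointer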
// getcallerpc returns the program counter (PC) of its caller's caller.
// getcallersp returns the stack pointer (SP) of its caller's caller.
// The implementation may be a compiler intrinsic; there is not
// necessarily code implementing this on every platform.
//
// For example:
//
// func f(arg1, arg2, arg3 int) {
// pc := getcallerpc()
// sp := getcallersp()
// }
//
// These two lines find the PC and SP immediately following
// the call to f (where f will return).
//
// The call to getcallerpc and getcallersp must be done in the
// frame being asked about.
//
// The result of getcallersp is correct at the time of the return,
// but it may be invalidated by any subsequent call to a function
// that might relocate the stack in order to grow or shrink it.
// A general rule is that the result of getcallersp should be used
// immediately and can only be passed to nosplit functions.
//go:noescape
func getcallerpc() uintptr
//go:noescape
func getcallersp() uintptr // implemented as an intrinsic on all platforms
// getclosureptr returns the pointer to the current closure.
// getclosureptr can only be used in an assignment statement
// at the entry of a function. Moreover, go:nosplit directive
// must be specified at the declaration of caller function,
// so that the function prolog does not clobber the closure register.
// for example:
//
// //go:nosplit
// func f(arg1, arg2, arg3 int) {
// dx := getclosureptr()
// }
//
// The compiler rewrites calls to this function into instructions that fetch the
// pointer from a well-known register (DX on x86 architecture, etc.) directly.
func getclosureptr() uintptr
//go:noescape
func asmcgocall(fn, arg unsafe.Pointer) int32
func morestack()
func morestack_noctxt()
func rt0_go()
// return0 is a stub used to return 0 from deferproc.
// It is called at the very end of deferproc to signal
// the calling Go function that it should not jump
// to deferreturn.
// in asm_*.s
func return0()
// in asm_*.s
// not called directly; definitions here supply type information for traceback.
// These must have the same signature (arg pointer map) as reflectcall.
func call16(typ, fn, stackArgs unsafe.Pointer, stackArgsSize, stackRetOffset, frameSize uint32, regArgs *abi.RegArgs)
func call32(typ, fn, stackArgs unsafe.Pointer, stackArgsSize, stackRetOffset, frameSize uint32, regArgs *abi.RegArgs)
func call64(typ, fn, stackArgs unsafe.Pointer, stackArgsSize, stackRetOffset, frameSize uint32, regArgs *abi.RegArgs)
func call128(typ, fn, stackArgs unsafe.Pointer, stackArgsSize, stackRetOffset, frameSize uint32, regArgs *abi.RegArgs)
func call256(typ, fn, stackArgs unsafe.Pointer, stackArgsSize, stackRetOffset, frameSize uint32, regArgs *abi.RegArgs)
func call512(typ, fn, stackArgs unsafe.Pointer, stackArgsSize, stackRetOffset, frameSize uint32, regArgs *abi.RegArgs)
func call1024(typ, fn, stackArgs unsafe.Pointer, stackArgsSize, stackRetOffset, frameSize uint32, regArgs *abi.RegArgs)
func call2048(typ, fn, stackArgs unsafe.Pointer, stackArgsSize, stackRetOffset, frameSize uint32, regArgs *abi.RegArgs)
func call4096(typ, fn, stackArgs unsafe.Pointer, stackArgsSize, stackRetOffset, frameSize uint32, regArgs *abi.RegArgs)
func call8192(typ, fn, stackArgs unsafe.Pointer, stackArgsSize, stackRetOffset, frameSize uint32, regArgs *abi.RegArgs)
func call16384(typ, fn, stackArgs unsafe.Pointer, stackArgsSize, stackRetOffset, frameSize uint32, regArgs *abi.RegArgs)
func call32768(typ, fn, stackArgs unsafe.Pointer, stackArgsSize, stackRetOffset, frameSize uint32, regArgs *abi.RegArgs)
func call65536(typ, fn, stackArgs unsafe.Pointer, stackArgsSize, stackRetOffset, frameSize uint32, regArgs *abi.RegArgs)
func call131072(typ, fn, stackArgs unsafe.Pointer, stackArgsSize, stackRetOffset, frameSize uint32, regArgs *abi.RegArgs)
func call262144(typ, fn, stackArgs unsafe.Pointer, stackArgsSize, stackRetOffset, frameSize uint32, regArgs *abi.RegArgs)
func call524288(typ, fn, stackArgs unsafe.Pointer, stackArgsSize, stackRetOffset, frameSize uint32, regArgs *abi.RegArgs)
func call1048576(typ, fn, stackArgs unsafe.Pointer, stackArgsSize, stackRetOffset, frameSize uint32, regArgs *abi.RegArgs)
func call2097152(typ, fn, stackArgs unsafe.Pointer, stackArgsSize, stackRetOffset, frameSize uint32, regArgs *abi.RegArgs)
func call4194304(typ, fn, stackArgs unsafe.Pointer, stackArgsSize, stackRetOffset, frameSize uint32, regArgs *abi.RegArgs)
func call8388608(typ, fn, stackArgs unsafe.Pointer, stackArgsSize, stackRetOffset, frameSize uint32, regArgs *abi.RegArgs)
func call16777216(typ, fn, stackArgs unsafe.Pointer, stackArgsSize, stackRetOffset, frameSize uint32, regArgs *abi.RegArgs)
func call33554432(typ, fn, stackArgs unsafe.Pointer, stackArgsSize, stackRetOffset, frameSize uint32, regArgs *abi.RegArgs)
func call67108864(typ, fn, stackArgs unsafe.Pointer, stackArgsSize, stackRetOffset, frameSize uint32, regArgs *abi.RegArgs)
func call134217728(typ, fn, stackArgs unsafe.Pointer, stackArgsSize, stackRetOffset, frameSize uint32, regArgs *abi.RegArgs)
func call268435456(typ, fn, stackArgs unsafe.Pointer, stackArgsSize, stackRetOffset, frameSize uint32, regArgs *abi.RegArgs)
func call536870912(typ, fn, stackArgs unsafe.Pointer, stackArgsSize, stackRetOffset, frameSize uint32, regArgs *abi.RegArgs)
func call1073741824(typ, fn, stackArgs unsafe.Pointer, stackArgsSize, stackRetOffset, frameSize uint32, regArgs *abi.RegArgs)
func systemstack_switch()
// alignUp rounds n up to a multiple of a. a must be a power of 2.
func alignUp(n, a uintptr) uintptr {
return (n + a - 1) &^ (a - 1)
}
// alignDown rounds n down to a multiple of a. a must be a power of 2.
func alignDown(n, a uintptr) uintptr {
return n &^ (a - 1)
}
// divRoundUp returns ceil(n / a).
func divRoundUp(n, a uintptr) uintptr {
// a is generally a power of two. This will get inlined and
// the compiler will optimize the division.
return (n + a - 1) / a
}
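// Worked examples of the helpers above (a sketch):
//
//	alignUp(10, 8)    // == 16: (10+7) &^ 7 == 17 &^ 7 == 16
//	alignDown(10, 8)  // == 8:  10 &^ 7
//	divRoundUp(10, 8) // == 2:  (10+7)/8
//
// Both align helpers rely on a being a power of 2 so that a-1 is a
// low-bit mask.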
// checkASM reports whether assembly runtime checks have passed.
func checkASM() bool
func memequal_varlen(a, b unsafe.Pointer) bool
// bool2int returns 0 if x is false or 1 if x is true.
func bool2int(x bool) int {
// Avoid branches. In the SSA compiler, this compiles to
// exactly what you would want it to.
return int(uint8(*(*uint8)(unsafe.Pointer(&x))))
}
// abort crashes the runtime in situations where even throw might not
// work. In general it should do something a debugger will recognize
// (e.g., an INT3 on x86). A crash in abort is recognized by the
// signal handler, which will attempt to tear down the runtime
// immediately.
func abort()
// Called from compiled code; declared for vet; do NOT call from Go.
func gcWriteBarrier1()
func gcWriteBarrier2()
func gcWriteBarrier3()
func gcWriteBarrier4()
func gcWriteBarrier5()
func gcWriteBarrier6()
func gcWriteBarrier7()
func gcWriteBarrier8()
func duffzero()
func duffcopy()
// Called from linker-generated .initarray; declared for go vet; do NOT call from Go.
func addmoduledata()
// Injected by the signal handler for panicking signals.
// Initializes any registers that have fixed meaning at calls but
// are scratch in bodies and calls sigpanic.
// On many platforms it just jumps to sigpanic.
func sigpanic0()
// intArgRegs is used by the various register assignment
// algorithm implementations in the runtime. These include:
// - Finalizers (mfinal.go)
// - Windows callbacks (syscall_windows.go)
//
// Both are stripped-down versions of the algorithm since they
// only have to deal with a subset of cases (finalizers only
// take a pointer or interface argument, Go Windows callbacks
// don't support floating point).
//
// It should be modified with care and is generally only
// modified when testing this package.
//
// It should never be set higher than its internal/abi
// constant counterparts, because the system relies on a
// structure that is at least large enough to hold the
// registers the system supports.
//
// Protected by finlock.
var intArgRegs = abi.IntArgRegs
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !aix && !darwin && !js && !openbsd && !plan9 && !solaris && !windows
package runtime
import (
"runtime/internal/atomic"
"unsafe"
)
// read calls the read system call.
// It returns a non-negative number of bytes read or a negative errno value.
func read(fd int32, p unsafe.Pointer, n int32) int32
func closefd(fd int32) int32
func exit(code int32)
func usleep(usec uint32)
//go:nosplit
func usleep_no_g(usec uint32) {
usleep(usec)
}
// write1 calls the write system call.
// It returns a non-negative number of bytes written or a negative errno value.
//
//go:noescape
func write1(fd uintptr, p unsafe.Pointer, n int32) int32
//go:noescape
func open(name *byte, mode, perm int32) int32
// The return value is only set on linux, to be used in osinit().
func madvise(addr unsafe.Pointer, n uintptr, flags int32) int32
// exitThread terminates the current thread, writing *wait = freeMStack when
// the stack is safe to reclaim.
//
//go:noescape
func exitThread(wait *atomic.Uint32)
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"internal/goarch"
"runtime/internal/atomic"
"runtime/internal/sys"
"unsafe"
)
// Frames may be used to get function/file/line information for a
// slice of PC values returned by Callers.
type Frames struct {
// callers is a slice of PCs that have not yet been expanded to frames.
callers []uintptr
// frames is a slice of Frames that have yet to be returned.
frames []Frame
frameStore [2]Frame
}
// Frame is the information returned by Frames for each call frame.
type Frame struct {
// PC is the program counter for the location in this frame.
// For a frame that calls another frame, this will be the
// program counter of a call instruction. Because of inlining,
// multiple frames may have the same PC value, but different
// symbolic information.
PC uintptr
// Func is the Func value of this call frame. This may be nil
// for non-Go code or fully inlined functions.
Func *Func
// Function is the package path-qualified function name of
// this call frame. If non-empty, this string uniquely
// identifies a single function in the program.
// This may be the empty string if not known.
// If Func is not nil then Function == Func.Name().
Function string
// File and Line are the file name and line number of the
// location in this frame. For non-leaf frames, this will be
// the location of a call. These may be the empty string and
// zero, respectively, if not known.
File string
Line int
// startLine is the line number of the beginning of the function in
// this frame. Specifically, it is the line number of the func keyword
// for Go functions. Note that //line directives can change the
// filename and/or line number arbitrarily within a function, meaning
// that the Line - startLine offset is not always meaningful.
//
// This may be zero if not known.
startLine int
// Entry point program counter for the function; may be zero
// if not known. If Func is not nil then Entry ==
// Func.Entry().
Entry uintptr
// The runtime's internal view of the function. This field
// is set (funcInfo.valid() returns true) only for Go functions,
// not for C functions.
funcInfo funcInfo
}
// CallersFrames takes a slice of PC values returned by Callers and
// prepares to return function/file/line information.
// Do not change the slice until you are done with the Frames.
func CallersFrames(callers []uintptr) *Frames {
f := &Frames{callers: callers}
f.frames = f.frameStore[:0]
return f
}
// Next returns a Frame representing the next call frame in the slice
// of PC values. If it has already returned all call frames, Next
// returns a zero Frame.
//
// The more result indicates whether the next call to Next will return
// a valid Frame. It does not necessarily indicate whether this call
// returned one.
//
// See the Frames example for idiomatic usage.
func (ci *Frames) Next() (frame Frame, more bool) {
for len(ci.frames) < 2 {
// Find the next frame.
// We need to look for 2 frames so we know what
// to return for the "more" result.
if len(ci.callers) == 0 {
break
}
pc := ci.callers[0]
ci.callers = ci.callers[1:]
funcInfo := findfunc(pc)
if !funcInfo.valid() {
if cgoSymbolizer != nil {
// Pre-expand cgo frames. We could do this
// incrementally, too, but there's no way to
// avoid allocation in this case anyway.
ci.frames = append(ci.frames, expandCgoFrames(pc)...)
}
continue
}
f := funcInfo._Func()
entry := f.Entry()
if pc > entry {
// We store the pc of the start of the instruction following
// the instruction in question (the call or the inline mark).
// This is done for historical reasons, and to make FuncForPC
// work correctly for entries in the result of runtime.Callers.
pc--
}
name := funcname(funcInfo)
startLine := f.startLine()
if inldata := funcdata(funcInfo, _FUNCDATA_InlTree); inldata != nil {
inltree := (*[1 << 20]inlinedCall)(inldata)
// Non-strict as cgoTraceback may have added bogus PCs
// with a valid funcInfo but invalid PCDATA.
ix := pcdatavalue1(funcInfo, _PCDATA_InlTreeIndex, pc, nil, false)
if ix >= 0 {
// Note: entry is not modified. It always refers to a real frame, not an inlined one.
f = nil
ic := inltree[ix]
name = funcnameFromNameOff(funcInfo, ic.nameOff)
startLine = ic.startLine
// File/line from funcline1 below are already correct.
}
}
ci.frames = append(ci.frames, Frame{
PC: pc,
Func: f,
Function: name,
Entry: entry,
startLine: int(startLine),
funcInfo: funcInfo,
// Note: File,Line set below
})
}
// Pop one frame from the frame list. Keep the rest.
// Avoid allocation in the common case, which is 1 or 2 frames.
switch len(ci.frames) {
case 0: // In the rare case when there are no frames at all, we return Frame{}.
return
case 1:
frame = ci.frames[0]
ci.frames = ci.frameStore[:0]
case 2:
frame = ci.frames[0]
ci.frameStore[0] = ci.frames[1]
ci.frames = ci.frameStore[:1]
default:
frame = ci.frames[0]
ci.frames = ci.frames[1:]
}
more = len(ci.frames) > 0
if frame.funcInfo.valid() {
// Compute file/line just before we need to return it,
// as it can be expensive. This avoids computing file/line
// for the Frame we find but don't return. See issue 32093.
file, line := funcline1(frame.funcInfo, frame.PC, false)
frame.File, frame.Line = file, int(line)
}
return
}
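// Typical use, condensed from the documented Frames example:
//
//	pc := make([]uintptr, 16)
//	n := runtime.Callers(0, pc)
//	frames := runtime.CallersFrames(pc[:n])
//	for {
//		frame, more := frames.Next()
//		// ... use frame.Function, frame.File, frame.Line ...
//		if !more {
//			break
//		}
//	}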
// runtime_FrameStartLine returns the start line of the function in a Frame.
//
//go:linkname runtime_FrameStartLine runtime/pprof.runtime_FrameStartLine
func runtime_FrameStartLine(f *Frame) int {
return f.startLine
}
// runtime_expandFinalInlineFrame expands the final pc in stk to include all
// "callers" if pc is inline.
//
//go:linkname runtime_expandFinalInlineFrame runtime/pprof.runtime_expandFinalInlineFrame
func runtime_expandFinalInlineFrame(stk []uintptr) []uintptr {
if len(stk) == 0 {
return stk
}
pc := stk[len(stk)-1]
tracepc := pc - 1
f := findfunc(tracepc)
if !f.valid() {
// Not a Go function.
return stk
}
inldata := funcdata(f, _FUNCDATA_InlTree)
if inldata == nil {
// Nothing inline in f.
return stk
}
// Treat the previous func as normal. We haven't actually checked, but
// since this pc was included in the stack, we know it shouldn't be
// elided.
lastFuncID := funcID_normal
// Remove pc from stk; we'll re-add it below.
stk = stk[:len(stk)-1]
// See inline expansion in gentraceback.
var cache pcvalueCache
inltree := (*[1 << 20]inlinedCall)(inldata)
for {
// Non-strict as cgoTraceback may have added bogus PCs
// with a valid funcInfo but invalid PCDATA.
ix := pcdatavalue1(f, _PCDATA_InlTreeIndex, tracepc, &cache, false)
if ix < 0 {
break
}
if inltree[ix].funcID == funcID_wrapper && elideWrapperCalling(lastFuncID) {
// ignore wrappers
} else {
stk = append(stk, pc)
}
lastFuncID = inltree[ix].funcID
// Back up to an instruction in the "caller".
tracepc = f.entry() + uintptr(inltree[ix].parentPc)
pc = tracepc + 1
}
// N.B. we want to keep the last parentPC which is not inline.
if f.funcID == funcID_wrapper && elideWrapperCalling(lastFuncID) {
// Ignore wrapper functions (except when they trigger panics).
} else {
stk = append(stk, pc)
}
return stk
}
// expandCgoFrames expands frame information for pc, known to be
// a non-Go function, using the cgoSymbolizer hook. expandCgoFrames
// returns nil if pc could not be expanded.
func expandCgoFrames(pc uintptr) []Frame {
arg := cgoSymbolizerArg{pc: pc}
callCgoSymbolizer(&arg)
if arg.file == nil && arg.funcName == nil {
// No useful information from symbolizer.
return nil
}
var frames []Frame
for {
frames = append(frames, Frame{
PC: pc,
Func: nil,
Function: gostring(arg.funcName),
File: gostring(arg.file),
Line: int(arg.lineno),
Entry: arg.entry,
// funcInfo is zero, which implies !funcInfo.valid().
// That ensures that we use the File/Line info given here.
})
if arg.more == 0 {
break
}
callCgoSymbolizer(&arg)
}
// No more frames for this PC. Tell the symbolizer we are done.
// We don't try to maintain a single cgoSymbolizerArg for the
// whole use of Frames, because there would be no good way to tell
// the symbolizer when we are done.
arg.pc = 0
callCgoSymbolizer(&arg)
return frames
}
// NOTE: Func does not expose the actual unexported fields, because we return *Func
// values to users, and we want to keep them from being able to overwrite the data
// with (say) *f = Func{}.
// All code operating on a *Func must call raw() to get the *_func
// or funcInfo() to get the funcInfo instead.
// A Func represents a Go function in the running binary.
type Func struct {
opaque struct{} // unexported field to disallow conversions
}
func (f *Func) raw() *_func {
return (*_func)(unsafe.Pointer(f))
}
func (f *Func) funcInfo() funcInfo {
return f.raw().funcInfo()
}
func (f *_func) funcInfo() funcInfo {
// Find the module containing f. f is located in the pclntable.
// The unsafe.Pointer to uintptr conversions and arithmetic
// are safe because we are working with module addresses.
ptr := uintptr(unsafe.Pointer(f))
var mod *moduledata
for datap := &firstmoduledata; datap != nil; datap = datap.next {
if len(datap.pclntable) == 0 {
continue
}
base := uintptr(unsafe.Pointer(&datap.pclntable[0]))
if base <= ptr && ptr < base+uintptr(len(datap.pclntable)) {
mod = datap
break
}
}
return funcInfo{f, mod}
}
// PCDATA and FUNCDATA table indexes.
//
// See funcdata.h and ../cmd/internal/objabi/funcdata.go.
const (
_PCDATA_UnsafePoint = 0
_PCDATA_StackMapIndex = 1
_PCDATA_InlTreeIndex = 2
_PCDATA_ArgLiveIndex = 3
_FUNCDATA_ArgsPointerMaps = 0
_FUNCDATA_LocalsPointerMaps = 1
_FUNCDATA_StackObjects = 2
_FUNCDATA_InlTree = 3
_FUNCDATA_OpenCodedDeferInfo = 4
_FUNCDATA_ArgInfo = 5
_FUNCDATA_ArgLiveInfo = 6
_FUNCDATA_WrapInfo = 7
_ArgsSizeUnknown = -0x80000000
)
const (
// PCDATA_UnsafePoint values.
_PCDATA_UnsafePointSafe = -1 // Safe for async preemption
_PCDATA_UnsafePointUnsafe = -2 // Unsafe for async preemption
// _PCDATA_Restart1(2) apply to a sequence of instructions; if an async
// preemption happens within such a sequence, we should back off the PC
// to the start of the sequence when resuming.
// We need two values so we can distinguish the start/end of the sequence
// in case two sequences are next to each other.
_PCDATA_Restart1 = -3
_PCDATA_Restart2 = -4
// Like _PCDATA_Restart1, but back to function entry if async
// preempted.
_PCDATA_RestartAtEntry = -5
)
// A FuncID identifies particular functions that need to be treated
// specially by the runtime.
// Note that in some situations involving plugins, there may be multiple
// copies of a particular special runtime function.
// Note: this list must match the list in cmd/internal/objabi/funcid.go.
type funcID uint8
const (
funcID_normal funcID = iota // not a special function
funcID_abort
funcID_asmcgocall
funcID_asyncPreempt
funcID_cgocallback
funcID_debugCallV2
funcID_gcBgMarkWorker
funcID_goexit
funcID_gogo
funcID_gopanic
funcID_handleAsyncEvent
funcID_mcall
funcID_morestack
funcID_mstart
funcID_panicwrap
funcID_rt0_go
funcID_runfinq
funcID_runtime_main
funcID_sigpanic
funcID_systemstack
funcID_systemstack_switch
funcID_wrapper // any autogenerated code (hash/eq algorithms, method wrappers, etc.)
)
// A FuncFlag holds bits about a function.
// This list must match the list in cmd/internal/objabi/funcid.go.
type funcFlag uint8
const (
// TOPFRAME indicates a function that appears at the top of its stack.
// The traceback routine stops at such a function and considers that
// a successful, complete traversal of the stack.
// Examples of TOPFRAME functions include goexit, which appears
// at the top of a user goroutine stack, and mstart, which appears
// at the top of a system goroutine stack.
funcFlag_TOPFRAME funcFlag = 1 << iota
// SPWRITE indicates a function that writes an arbitrary value to SP
// (any write other than adding or subtracting a constant amount).
// The traceback routines cannot encode such changes into the
// pcsp tables, so the function traceback cannot safely unwind past
// SPWRITE functions. Stopping at an SPWRITE function is considered
// to be an incomplete unwinding of the stack. In certain contexts
// (in particular garbage collector stack scans) that is a fatal error.
funcFlag_SPWRITE
// ASM indicates that a function was implemented in assembly.
funcFlag_ASM
)
// pcHeader holds data used by the pclntab lookups.
type pcHeader struct {
magic uint32 // 0xFFFFFFF1
pad1, pad2 uint8 // 0,0
minLC uint8 // min instruction size
ptrSize uint8 // size of a ptr in bytes
nfunc int // number of functions in the module
nfiles uint // number of entries in the file tab
textStart uintptr // base for function entry PC offsets in this module, equal to moduledata.text
funcnameOffset uintptr // offset to the funcnametab variable from pcHeader
cuOffset uintptr // offset to the cutab variable from pcHeader
filetabOffset uintptr // offset to the filetab variable from pcHeader
pctabOffset uintptr // offset to the pctab variable from pcHeader
pclnOffset uintptr // offset to the pclntab variable from pcHeader
}
// moduledata records information about the layout of the executable
// image. It is written by the linker. Any changes here must be
// matched by changes to the code in cmd/link/internal/ld/symtab.go:symtab.
// moduledata is stored in statically allocated non-pointer memory;
// none of the pointers here are visible to the garbage collector.
type moduledata struct {
pcHeader *pcHeader
funcnametab []byte
cutab []uint32
filetab []byte
pctab []byte
pclntable []byte
ftab []functab
findfunctab uintptr
minpc, maxpc uintptr
text, etext uintptr
noptrdata, enoptrdata uintptr
data, edata uintptr
bss, ebss uintptr
noptrbss, enoptrbss uintptr
covctrs, ecovctrs uintptr
end, gcdata, gcbss uintptr
types, etypes uintptr
rodata uintptr
gofunc uintptr // go.func.*
textsectmap []textsect
typelinks []int32 // offsets from types
itablinks []*itab
ptab []ptabEntry
pluginpath string
pkghashes []modulehash
// This slice records the initializing tasks that need to be
// done to start up the program. It is built by the linker.
inittasks []*initTask
modulename string
modulehashes []modulehash
hasmain uint8 // 1 if module contains the main function, 0 otherwise
gcdatamask, gcbssmask bitvector
typemap map[typeOff]*_type // offset to *_rtype in previous module
bad bool // module failed to load and should be ignored
next *moduledata
}
// A modulehash is used to compare the ABI of a new module or a
// package in a new module with the loaded program.
//
// For each shared library a module links against, the linker creates an entry in the
// moduledata.modulehashes slice containing the name of the module, the abi hash seen
// at link time and a pointer to the runtime abi hash. These are checked in
// moduledataverify1 below.
//
// For each loaded plugin, the pkghashes slice has a modulehash of the
// newly loaded package that can be used to check the plugin's version of
// a package against any previously loaded version of the package.
// This is done in plugin.lastmoduleinit.
type modulehash struct {
modulename string
linktimehash string
runtimehash *string
}
// pinnedTypemaps are the map[typeOff]*_type from the moduledata objects.
//
// These typemap objects are allocated at run time on the heap, but the
// only direct reference to them is in the moduledata, created by the
// linker and marked SNOPTRDATA so it is ignored by the GC.
//
// To make sure the map isn't collected, we keep a second reference here.
var pinnedTypemaps []map[typeOff]*_type
var firstmoduledata moduledata // linker symbol
var lastmoduledatap *moduledata // linker symbol
var modulesSlice *[]*moduledata // see activeModules
// activeModules returns a slice of active modules.
//
// A module is active once its gcdatamask and gcbssmask have been
// assembled and it is usable by the GC.
//
// This is nosplit/nowritebarrier because it is called by the
// cgo pointer checking code.
//
//go:nosplit
//go:nowritebarrier
func activeModules() []*moduledata {
p := (*[]*moduledata)(atomic.Loadp(unsafe.Pointer(&modulesSlice)))
if p == nil {
return nil
}
return *p
}
// modulesinit creates the active modules slice out of all loaded modules.
//
// When a module is first loaded by the dynamic linker, an .init_array
// function (written by cmd/link) is invoked to call addmoduledata,
// appending the module to the linked list that starts with
// firstmoduledata.
//
// There are two times this can happen in the lifecycle of a Go
// program. First, if compiled with -linkshared, a number of modules
// built with -buildmode=shared can be loaded at program initialization.
// Second, a Go program can load a module while running that was built
// with -buildmode=plugin.
//
// After loading, this function is called which initializes the
// moduledata so it is usable by the GC and creates a new activeModules
// list.
//
// Only one goroutine may call modulesinit at a time.
func modulesinit() {
modules := new([]*moduledata)
for md := &firstmoduledata; md != nil; md = md.next {
if md.bad {
continue
}
*modules = append(*modules, md)
if md.gcdatamask == (bitvector{}) {
scanDataSize := md.edata - md.data
md.gcdatamask = progToPointerMask((*byte)(unsafe.Pointer(md.gcdata)), scanDataSize)
scanBSSSize := md.ebss - md.bss
md.gcbssmask = progToPointerMask((*byte)(unsafe.Pointer(md.gcbss)), scanBSSSize)
gcController.addGlobals(int64(scanDataSize + scanBSSSize))
}
}
// Modules appear in the moduledata linked list in the order they are
// loaded by the dynamic loader, with one exception: the
// firstmoduledata itself is the module that contains the runtime. This
// is not always the first module (when using -buildmode=shared, it
// is typically libstd.so, the second module). The order matters for
// typelinksinit, so we swap the first module with whatever module
// contains the main function.
//
// See Issue #18729.
for i, md := range *modules {
if md.hasmain != 0 {
(*modules)[0] = md
(*modules)[i] = &firstmoduledata
break
}
}
atomicstorep(unsafe.Pointer(&modulesSlice), unsafe.Pointer(modules))
}
type functab struct {
entryoff uint32 // relative to runtime.text
funcoff uint32
}
// Mapping information for secondary text sections
type textsect struct {
vaddr uintptr // prelinked section vaddr
end uintptr // vaddr + section length
baseaddr uintptr // relocated section address
}
const minfunc = 16 // minimum function size
const pcbucketsize = 256 * minfunc // size of bucket in the pc->func lookup table
// findfuncbucket is an array of these structures.
// Each bucket represents 4096 bytes of the text segment.
// Each subbucket represents 256 bytes of the text segment.
// To find a function given a pc, locate the bucket and subbucket for
// that pc. Add together the idx and subbucket value to obtain a
// function index. Then scan the functab array starting at that
// index to find the target function.
// This table uses 20 bytes for every 4096 bytes of code, or ~0.5% overhead.
type findfuncbucket struct {
idx uint32
subbuckets [16]byte
}
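// A hedged sketch of the index arithmetic, using made-up numbers; it
// mirrors the real lookup in findfunc below. For a pc offset x of
// 5000 bytes into the text segment:
//
//	b := x / pcbucketsize                       // 5000/4096 = bucket 1
//	i := x % pcbucketsize / (pcbucketsize / 16) // 904/256 = subbucket 3
//
// The starting functab index is then bucket[b].idx + bucket[b].subbuckets[i],
// and the functab is scanned forward from that index.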
func moduledataverify() {
for datap := &firstmoduledata; datap != nil; datap = datap.next {
moduledataverify1(datap)
}
}
const debugPcln = false
func moduledataverify1(datap *moduledata) {
// Check that the pclntab's format is valid.
hdr := datap.pcHeader
if hdr.magic != 0xfffffff1 || hdr.pad1 != 0 || hdr.pad2 != 0 ||
hdr.minLC != sys.PCQuantum || hdr.ptrSize != goarch.PtrSize || hdr.textStart != datap.text {
println("runtime: pcHeader: magic=", hex(hdr.magic), "pad1=", hdr.pad1, "pad2=", hdr.pad2,
"minLC=", hdr.minLC, "ptrSize=", hdr.ptrSize, "pcHeader.textStart=", hex(hdr.textStart),
"text=", hex(datap.text), "pluginpath=", datap.pluginpath)
throw("invalid function symbol table")
}
// ftab is a lookup table from program counter to function.
nftab := len(datap.ftab) - 1
for i := 0; i < nftab; i++ {
// NOTE: ftab[nftab].entryoff is legal; it is the offset just beyond the final function.
if datap.ftab[i].entryoff > datap.ftab[i+1].entryoff {
f1 := funcInfo{(*_func)(unsafe.Pointer(&datap.pclntable[datap.ftab[i].funcoff])), datap}
f2 := funcInfo{(*_func)(unsafe.Pointer(&datap.pclntable[datap.ftab[i+1].funcoff])), datap}
f2name := "end"
if i+1 < nftab {
f2name = funcname(f2)
}
println("function symbol table not sorted by PC offset:", hex(datap.ftab[i].entryoff), funcname(f1), ">", hex(datap.ftab[i+1].entryoff), f2name, ", plugin:", datap.pluginpath)
for j := 0; j <= i; j++ {
println("\t", hex(datap.ftab[j].entryoff), funcname(funcInfo{(*_func)(unsafe.Pointer(&datap.pclntable[datap.ftab[j].funcoff])), datap}))
}
if GOOS == "aix" && isarchive {
println("-Wl,-bnoobjreorder is mandatory on aix/ppc64 with c-archive")
}
throw("invalid runtime symbol table")
}
}
min := datap.textAddr(datap.ftab[0].entryoff)
max := datap.textAddr(datap.ftab[nftab].entryoff)
if datap.minpc != min || datap.maxpc != max {
println("minpc=", hex(datap.minpc), "min=", hex(min), "maxpc=", hex(datap.maxpc), "max=", hex(max))
throw("minpc or maxpc invalid")
}
for _, modulehash := range datap.modulehashes {
if modulehash.linktimehash != *modulehash.runtimehash {
println("abi mismatch detected between", datap.modulename, "and", modulehash.modulename)
throw("abi mismatch")
}
}
}
// textAddr returns md.text + off, with special handling for multiple text sections.
// off is a (virtual) offset computed at internal linking time,
// before the external linker adjusts the sections' base addresses.
//
// The text, or instruction stream, is generated as one large buffer.
// The off (offset) for a function is its offset within this buffer.
// If the total text size gets too large, there can be issues on platforms like ppc64
// where the targets of calls are too far for a single call instruction.
// To resolve the large text issue, the text is split into multiple text sections
// to allow the linker to generate long calls when necessary.
// When this happens, the vaddr for each text section is set to its offset within the text.
// Each function's offset is compared against the section vaddrs and ends to determine the containing section.
// Then the section relative offset is added to the section's
// relocated baseaddr to compute the function address.
//
// It is nosplit because it is part of the findfunc implementation.
//
//go:nosplit
func (md *moduledata) textAddr(off32 uint32) uintptr {
off := uintptr(off32)
res := md.text + off
if len(md.textsectmap) > 1 {
for i, sect := range md.textsectmap {
// For the last section, include the end address (etext), as it is included in the functab.
if off >= sect.vaddr && off < sect.end || (i == len(md.textsectmap)-1 && off == sect.end) {
res = sect.baseaddr + off - sect.vaddr
break
}
}
if res > md.etext && GOARCH != "wasm" { // on wasm, functions do not live in the same address space as the linear memory
println("runtime: textAddr", hex(res), "out of range", hex(md.text), "-", hex(md.etext))
throw("runtime: text offset out of range")
}
}
return res
}
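// As a hedged illustration with a hypothetical two-section layout,
// sections {vaddr: 0, end: 0x100000} and {vaddr: 0x100000, end: 0x180000}:
// an off of 0x120000 falls in the second section, so the result is
//
//	sect.baseaddr + 0x120000 - 0x100000
//
// If the external linker relocated the second section, its baseaddr
// differs from md.text+vaddr, and the naive md.text+off computed above
// would point at the wrong address.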
// textOff is the opposite of textAddr. It converts a PC to a (virtual) offset
// to md.text, and reports whether the PC is in any Go text section.
//
// It is nosplit because it is part of the findfunc implementation.
//
//go:nosplit
func (md *moduledata) textOff(pc uintptr) (uint32, bool) {
res := uint32(pc - md.text)
if len(md.textsectmap) > 1 {
for i, sect := range md.textsectmap {
if sect.baseaddr > pc {
// pc is not in any section.
return 0, false
}
end := sect.baseaddr + (sect.end - sect.vaddr)
// For the last section, include the end address (etext), as it is included in the functab.
if i == len(md.textsectmap)-1 {
end++
}
if pc < end {
res = uint32(pc - sect.baseaddr + sect.vaddr)
break
}
}
}
return res, true
}
// FuncForPC returns a *Func describing the function that contains the
// given program counter address, or else nil.
//
// If pc represents multiple functions because of inlining, it returns
// the *Func describing the innermost function, but with an entry of
// the outermost function.
func FuncForPC(pc uintptr) *Func {
f := findfunc(pc)
if !f.valid() {
return nil
}
if inldata := funcdata(f, _FUNCDATA_InlTree); inldata != nil {
// Note: strict=false so bad PCs (those between functions) don't crash the runtime.
// We just report the preceding function in that situation. See issue 29735.
// TODO: Perhaps we should report no function at all in that case.
// The runtime currently doesn't have function end info, alas.
if ix := pcdatavalue1(f, _PCDATA_InlTreeIndex, pc, nil, false); ix >= 0 {
inltree := (*[1 << 20]inlinedCall)(inldata)
ic := inltree[ix]
name := funcnameFromNameOff(f, ic.nameOff)
file, line := funcline(f, pc)
fi := &funcinl{
ones: ^uint32(0),
entry: f.entry(), // entry of the real (the outermost) function.
name: name,
file: file,
line: line,
startLine: ic.startLine,
}
return (*Func)(unsafe.Pointer(fi))
}
}
return f._Func()
}
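// A typical use from outside the runtime, via the exported API:
//
//	pc, _, _, ok := runtime.Caller(0)
//	if ok {
//		f := runtime.FuncForPC(pc)
//		file, line := f.FileLine(pc)
//		println(f.Name(), file, line)
//	}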
// Name returns the name of the function.
func (f *Func) Name() string {
if f == nil {
return ""
}
fn := f.raw()
if fn.isInlined() { // inlined version
fi := (*funcinl)(unsafe.Pointer(fn))
return fi.name
}
return funcname(f.funcInfo())
}
// Entry returns the entry address of the function.
func (f *Func) Entry() uintptr {
fn := f.raw()
if fn.isInlined() { // inlined version
fi := (*funcinl)(unsafe.Pointer(fn))
return fi.entry
}
return fn.funcInfo().entry()
}
// FileLine returns the file name and line number of the
// source code corresponding to the program counter pc.
// The result will not be accurate if pc is not a program
// counter within f.
func (f *Func) FileLine(pc uintptr) (file string, line int) {
fn := f.raw()
if fn.isInlined() { // inlined version
fi := (*funcinl)(unsafe.Pointer(fn))
return fi.file, int(fi.line)
}
// Pass strict=false here, because anyone can call this function,
// and they might just be wrong about targetpc belonging to f.
file, line32 := funcline1(f.funcInfo(), pc, false)
return file, int(line32)
}
// startLine returns the starting line number of the function, i.e., the
// line number of the func keyword.
func (f *Func) startLine() int32 {
fn := f.raw()
if fn.isInlined() { // inlined version
fi := (*funcinl)(unsafe.Pointer(fn))
return fi.startLine
}
return fn.funcInfo().startLine
}
// findmoduledatap looks up the moduledata for a PC.
//
// It is nosplit because it's part of the isgoexception
// implementation.
//
//go:nosplit
func findmoduledatap(pc uintptr) *moduledata {
for datap := &firstmoduledata; datap != nil; datap = datap.next {
if datap.minpc <= pc && pc < datap.maxpc {
return datap
}
}
return nil
}
type funcInfo struct {
*_func
datap *moduledata
}
func (f funcInfo) valid() bool {
return f._func != nil
}
func (f funcInfo) _Func() *Func {
return (*Func)(unsafe.Pointer(f._func))
}
// isInlined reports whether f should be re-interpreted as a *funcinl.
func (f *_func) isInlined() bool {
return f.entryOff == ^uint32(0) // see comment for funcinl.ones
}
// entry returns the entry PC for f.
func (f funcInfo) entry() uintptr {
return f.datap.textAddr(f.entryOff)
}
// findfunc looks up function metadata for a PC.
//
// It is nosplit because it's part of the isgoexception
// implementation.
//
//go:nosplit
func findfunc(pc uintptr) funcInfo {
datap := findmoduledatap(pc)
if datap == nil {
return funcInfo{}
}
const nsub = uintptr(len(findfuncbucket{}.subbuckets))
pcOff, ok := datap.textOff(pc)
if !ok {
return funcInfo{}
}
x := uintptr(pcOff) + datap.text - datap.minpc // TODO: are datap.text and datap.minpc always equal?
b := x / pcbucketsize
i := x % pcbucketsize / (pcbucketsize / nsub)
ffb := (*findfuncbucket)(add(unsafe.Pointer(datap.findfunctab), b*unsafe.Sizeof(findfuncbucket{})))
idx := ffb.idx + uint32(ffb.subbuckets[i])
// Find the ftab entry.
for datap.ftab[idx+1].entryoff <= pcOff {
idx++
}
funcoff := datap.ftab[idx].funcoff
return funcInfo{(*_func)(unsafe.Pointer(&datap.pclntable[funcoff])), datap}
}
type pcvalueCache struct {
entries [2][8]pcvalueCacheEnt
}
type pcvalueCacheEnt struct {
// targetpc and off together are the key of this cache entry.
targetpc uintptr
off uint32
// val is the value of this cached pcvalue entry.
val int32
}
// pcvalueCacheKey returns the outermost index in a pcvalueCache to use for targetpc.
// It must be very cheap to calculate.
// For now, align to goarch.PtrSize and reduce mod the number of entries.
// In practice, this appears to be fairly randomly and evenly distributed.
func pcvalueCacheKey(targetpc uintptr) uintptr {
return (targetpc / goarch.PtrSize) % uintptr(len(pcvalueCache{}.entries))
}
// Returns the PCData value, and the PC where this value starts.
// TODO: the start PC is returned only when cache is nil.
func pcvalue(f funcInfo, off uint32, targetpc uintptr, cache *pcvalueCache, strict bool) (int32, uintptr) {
if off == 0 {
return -1, 0
}
// Check the cache. This speeds up walks of deep stacks, which
// tend to have the same recursive functions over and over.
//
// This cache is small enough that full associativity is
// cheaper than doing the hashing for a less associative
// cache.
if cache != nil {
x := pcvalueCacheKey(targetpc)
for i := range cache.entries[x] {
// We check off first because we're more
// likely to have multiple entries with
// different offsets for the same targetpc
// than the other way around, so we'll usually
// fail in the first clause.
ent := &cache.entries[x][i]
if ent.off == off && ent.targetpc == targetpc {
return ent.val, 0
}
}
}
if !f.valid() {
if strict && panicking.Load() == 0 {
println("runtime: no module data for", hex(f.entry()))
throw("no module data")
}
return -1, 0
}
datap := f.datap
p := datap.pctab[off:]
pc := f.entry()
prevpc := pc
val := int32(-1)
for {
var ok bool
p, ok = step(p, &pc, &val, pc == f.entry())
if !ok {
break
}
if targetpc < pc {
// Replace a random entry in the cache. Random
// replacement prevents a performance cliff if
// a recursive stack's cycle is slightly
// larger than the cache.
// Put the new element at the beginning,
// since it is the most likely to be newly used.
if cache != nil {
x := pcvalueCacheKey(targetpc)
e := &cache.entries[x]
ci := fastrandn(uint32(len(cache.entries[x])))
e[ci] = e[0]
e[0] = pcvalueCacheEnt{
targetpc: targetpc,
off: off,
val: val,
}
}
return val, prevpc
}
prevpc = pc
}
// If there was a table, it should have covered all program counters.
// If not, something is wrong.
if panicking.Load() != 0 || !strict {
return -1, 0
}
print("runtime: invalid pc-encoded table f=", funcname(f), " pc=", hex(pc), " targetpc=", hex(targetpc), " tab=", p, "\n")
p = datap.pctab[off:]
pc = f.entry()
val = -1
for {
var ok bool
p, ok = step(p, &pc, &val, pc == f.entry())
if !ok {
break
}
print("\tvalue=", val, " until pc=", hex(pc), "\n")
}
throw("invalid runtime symbol table")
return -1, 0
}
func cfuncname(f funcInfo) *byte {
if !f.valid() || f.nameOff == 0 {
return nil
}
return &f.datap.funcnametab[f.nameOff]
}
func funcname(f funcInfo) string {
return gostringnocopy(cfuncname(f))
}
func funcpkgpath(f funcInfo) string {
name := funcname(f)
i := len(name) - 1
for ; i > 0; i-- {
if name[i] == '/' {
break
}
}
for ; i < len(name); i++ {
if name[i] == '.' {
break
}
}
return name[:i]
}
func cfuncnameFromNameOff(f funcInfo, nameOff int32) *byte {
if !f.valid() {
return nil
}
return &f.datap.funcnametab[nameOff]
}
func funcnameFromNameOff(f funcInfo, nameOff int32) string {
return gostringnocopy(cfuncnameFromNameOff(f, nameOff))
}
func funcfile(f funcInfo, fileno int32) string {
datap := f.datap
if !f.valid() {
return "?"
}
// Make sure the cu index and file offset are valid
if fileoff := datap.cutab[f.cuOffset+uint32(fileno)]; fileoff != ^uint32(0) {
return gostringnocopy(&datap.filetab[fileoff])
}
// pcln section is corrupt.
return "?"
}
func funcline1(f funcInfo, targetpc uintptr, strict bool) (file string, line int32) {
datap := f.datap
if !f.valid() {
return "?", 0
}
fileno, _ := pcvalue(f, f.pcfile, targetpc, nil, strict)
line, _ = pcvalue(f, f.pcln, targetpc, nil, strict)
if fileno == -1 || line == -1 || int(fileno) >= len(datap.filetab) {
// print("looking for ", hex(targetpc), " in ", funcname(f), " got file=", fileno, " line=", lineno, "\n")
return "?", 0
}
file = funcfile(f, fileno)
return
}
func funcline(f funcInfo, targetpc uintptr) (file string, line int32) {
return funcline1(f, targetpc, true)
}
func funcspdelta(f funcInfo, targetpc uintptr, cache *pcvalueCache) int32 {
x, _ := pcvalue(f, f.pcsp, targetpc, cache, true)
if debugPcln && x&(goarch.PtrSize-1) != 0 {
print("invalid spdelta ", funcname(f), " ", hex(f.entry()), " ", hex(targetpc), " ", hex(f.pcsp), " ", x, "\n")
throw("bad spdelta")
}
return x
}
// funcMaxSPDelta returns the maximum spdelta at any point in f.
func funcMaxSPDelta(f funcInfo) int32 {
datap := f.datap
p := datap.pctab[f.pcsp:]
pc := f.entry()
val := int32(-1)
max := int32(0)
for {
var ok bool
p, ok = step(p, &pc, &val, pc == f.entry())
if !ok {
return max
}
if val > max {
max = val
}
}
}
func pcdatastart(f funcInfo, table uint32) uint32 {
return *(*uint32)(add(unsafe.Pointer(&f.nfuncdata), unsafe.Sizeof(f.nfuncdata)+uintptr(table)*4))
}
func pcdatavalue(f funcInfo, table uint32, targetpc uintptr, cache *pcvalueCache) int32 {
if table >= f.npcdata {
return -1
}
r, _ := pcvalue(f, pcdatastart(f, table), targetpc, cache, true)
return r
}
func pcdatavalue1(f funcInfo, table uint32, targetpc uintptr, cache *pcvalueCache, strict bool) int32 {
if table >= f.npcdata {
return -1
}
r, _ := pcvalue(f, pcdatastart(f, table), targetpc, cache, strict)
return r
}
// Like pcdatavalue, but also returns the start PC of this PCData value.
// It doesn't take a cache.
func pcdatavalue2(f funcInfo, table uint32, targetpc uintptr) (int32, uintptr) {
if table >= f.npcdata {
return -1, 0
}
return pcvalue(f, pcdatastart(f, table), targetpc, nil, true)
}
// funcdata returns a pointer to the ith funcdata for f.
// funcdata should be kept in sync with cmd/link:writeFuncs.
func funcdata(f funcInfo, i uint8) unsafe.Pointer {
if i >= f.nfuncdata {
return nil
}
base := f.datap.gofunc // load gofunc address early so that we calculate during cache misses
p := uintptr(unsafe.Pointer(&f.nfuncdata)) + unsafe.Sizeof(f.nfuncdata) + uintptr(f.npcdata)*4 + uintptr(i)*4
off := *(*uint32)(unsafe.Pointer(p))
// Return off == ^uint32(0) ? 0 : f.datap.gofunc + uintptr(off), but without branches.
// The compiler calculates mask on most architectures using conditional assignment.
var mask uintptr
if off == ^uint32(0) {
mask = 1
}
mask--
raw := base + uintptr(off)
return unsafe.Pointer(raw & mask)
}
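// Spelling out the branchless select above: when off == ^uint32(0),
// mask is set to 1 and then decremented to 0, so raw&mask == 0 (nil).
// Otherwise mask is 0-1, i.e. all ones, and raw&mask == raw.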
// step advances to the next (pc, value) pair in the encoded table.
func step(p []byte, pc *uintptr, val *int32, first bool) (newp []byte, ok bool) {
// For both uvdelta and pcdelta, the common case (~70%)
// is that they are a single byte. If so, avoid calling readvarint.
uvdelta := uint32(p[0])
if uvdelta == 0 && !first {
return nil, false
}
n := uint32(1)
if uvdelta&0x80 != 0 {
n, uvdelta = readvarint(p)
}
*val += int32(-(uvdelta & 1) ^ (uvdelta >> 1))
p = p[n:]
pcdelta := uint32(p[0])
n = 1
if pcdelta&0x80 != 0 {
n, pcdelta = readvarint(p)
}
p = p[n:]
*pc += uintptr(pcdelta * sys.PCQuantum)
return p, true
}
// readvarint reads a varint from p.
func readvarint(p []byte) (read uint32, val uint32) {
var v, shift, n uint32
for {
b := p[n]
n++
v |= uint32(b&0x7F) << (shift & 31)
if b&0x80 == 0 {
break
}
shift += 7
}
return n, v
}
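// A worked decode of one (pc, value) pair in this encoding, using a
// made-up two-byte table fragment:
//
//	p := []byte{0x05, 0x02} // uvdelta = 5, pcdelta = 2
//
// The value delta is zigzag-encoded: -(5&1) ^ (5>>1) == -3, so val
// decreases by 3, and pc advances by 2*sys.PCQuantum. Values with the
// 0x80 bit set continue into the next byte and are decoded by
// readvarint, seven bits per byte, least significant group first.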
type stackmap struct {
n int32 // number of bitmaps
nbit int32 // number of bits in each bitmap
bytedata [1]byte // bitmaps, each starting on a byte boundary
}
//go:nowritebarrier
func stackmapdata(stkmap *stackmap, n int32) bitvector {
// Check this invariant only when stackDebug is on at all.
// The invariant is already checked by many of stackmapdata's callers,
// and disabling it by default allows stackmapdata to be inlined.
if stackDebug > 0 && (n < 0 || n >= stkmap.n) {
throw("stackmapdata: index out of range")
}
return bitvector{stkmap.nbit, addb(&stkmap.bytedata[0], uintptr(n*((stkmap.nbit+7)>>3)))}
}
// inlinedCall is the encoding of entries in the FUNCDATA_InlTree table.
type inlinedCall struct {
funcID funcID // type of the called function
_ [3]byte
nameOff int32 // offset into pclntab for name of called function
parentPc int32 // position of an instruction whose source position is the call site (offset from entry)
startLine int32 // line number of start of function (func keyword/TEXT directive)
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !ppc64 && !ppc64le
package runtime
func prepGoExitFrame(sp uintptr) {
}
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build amd64 || 386
package runtime
import (
"internal/goarch"
"unsafe"
)
// adjust Gobuf as if it executed a call to fn with context ctxt
// and then stopped before the first instruction in fn.
func gostartcall(buf *gobuf, fn, ctxt unsafe.Pointer) {
sp := buf.sp
sp -= goarch.PtrSize
*(*uintptr)(unsafe.Pointer(sp)) = buf.pc
buf.sp = sp
buf.pc = uintptr(fn)
buf.ctxt = ctxt
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Time-related runtime and pieces of package time.
package runtime
import (
"internal/abi"
"runtime/internal/atomic"
"runtime/internal/sys"
"unsafe"
)
// Package time knows the layout of this structure.
// If this struct changes, adjust ../time/sleep.go:/runtimeTimer.
type timer struct {
// If this timer is on a heap, which P's heap it is on.
// puintptr rather than *p to match uintptr in the versions
// of this struct defined in other packages.
pp puintptr
// Timer wakes up at when, and then at when+period, ... (period > 0 only)
// each time calling f(arg, now) in the timer goroutine, so f must be
// a well-behaved function and not block.
//
// when must be positive on an active timer.
when int64
period int64
f func(any, uintptr)
arg any
seq uintptr
// What to set the when field to in timerModifiedXX status.
nextwhen int64
// The status field holds one of the values below.
status atomic.Uint32
}
// Code outside this file has to be careful in using a timer value.
//
// The pp, status, and nextwhen fields may only be used by code in this file.
//
// Code that creates a new timer value can set the when, period, f,
// arg, and seq fields.
// A new timer value may be passed to addtimer (called by time.startTimer).
// After doing that no fields may be touched.
//
// An active timer (one that has been passed to addtimer) may be
// passed to deltimer (time.stopTimer), after which it is no longer an
// active timer. It is an inactive timer.
// In an inactive timer the period, f, arg, and seq fields may be modified,
// but not the when field.
// It's OK to just drop an inactive timer and let the GC collect it.
// It's not OK to pass an inactive timer to addtimer.
// Only newly allocated timer values may be passed to addtimer.
//
// An active timer may be passed to modtimer. No fields may be touched.
// It remains an active timer.
//
// An inactive timer may be passed to resettimer to turn into an
// active timer with an updated when field.
// It's OK to pass a newly allocated timer value to resettimer.
//
// Timer operations are addtimer, deltimer, modtimer, resettimer,
// cleantimers, adjusttimers, and runtimer.
//
// We don't permit calling addtimer/deltimer/modtimer/resettimer simultaneously,
// but adjusttimers and runtimer can be called at the same time as any of those.
//
// Active timers live in heaps attached to P, in the timers field.
// Inactive timers live there too temporarily, until they are removed.
//
// addtimer:
// timerNoStatus -> timerWaiting
// anything else -> panic: invalid value
// deltimer:
// timerWaiting -> timerModifying -> timerDeleted
// timerModifiedEarlier -> timerModifying -> timerDeleted
// timerModifiedLater -> timerModifying -> timerDeleted
// timerNoStatus -> do nothing
// timerDeleted -> do nothing
// timerRemoving -> do nothing
// timerRemoved -> do nothing
// timerRunning -> wait until status changes
// timerMoving -> wait until status changes
// timerModifying -> wait until status changes
// modtimer:
// timerWaiting -> timerModifying -> timerModifiedXX
// timerModifiedXX -> timerModifying -> timerModifiedYY
// timerNoStatus -> timerModifying -> timerWaiting
// timerRemoved -> timerModifying -> timerWaiting
// timerDeleted -> timerModifying -> timerModifiedXX
// timerRunning -> wait until status changes
// timerMoving -> wait until status changes
// timerRemoving -> wait until status changes
// timerModifying -> wait until status changes
// cleantimers (looks in P's timer heap):
// timerDeleted -> timerRemoving -> timerRemoved
// timerModifiedXX -> timerMoving -> timerWaiting
// adjusttimers (looks in P's timer heap):
// timerDeleted -> timerRemoving -> timerRemoved
// timerModifiedXX -> timerMoving -> timerWaiting
// runtimer (looks in P's timer heap):
// timerNoStatus -> panic: uninitialized timer
// timerWaiting -> timerWaiting or
// timerWaiting -> timerRunning -> timerNoStatus or
// timerWaiting -> timerRunning -> timerWaiting
// timerModifying -> wait until status changes
// timerModifiedXX -> timerMoving -> timerWaiting
// timerDeleted -> timerRemoving -> timerRemoved
// timerRunning -> panic: concurrent runtimer calls
// timerRemoved -> panic: inconsistent timer heap
// timerRemoving -> panic: inconsistent timer heap
// timerMoving -> panic: inconsistent timer heap
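// Every transition in the table above is implemented with the same
// load/compare-and-swap pattern. A minimal sketch (not a real
// transition from this file):
//
//	for {
//		switch s := t.status.Load(); s {
//		case timerWaiting:
//			if !t.status.CompareAndSwap(s, timerModifying) {
//				continue // lost a race; re-read the status
//			}
//			// ... mutate the timer, then publish the final state.
//			if !t.status.CompareAndSwap(timerModifying, timerDeleted) {
//				badTimer()
//			}
//			return
//		case timerModifying:
//			osyield() // another P owns the transition; wait
//		}
//	}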
// Values for the timer status field.
const (
// Timer has no status set yet.
timerNoStatus = iota
// Waiting for timer to fire.
// The timer is in some P's heap.
timerWaiting
// Running the timer function.
// A timer will only have this status briefly.
timerRunning
// The timer is deleted and should be removed.
// It should not be run, but it is still in some P's heap.
timerDeleted
// The timer is being removed.
// The timer will only have this status briefly.
timerRemoving
// The timer has been stopped.
// It is not in any P's heap.
timerRemoved
// The timer is being modified.
// The timer will only have this status briefly.
timerModifying
// The timer has been modified to an earlier time.
// The new when value is in the nextwhen field.
// The timer is in some P's heap, possibly in the wrong place.
timerModifiedEarlier
// The timer has been modified to the same or a later time.
// The new when value is in the nextwhen field.
// The timer is in some P's heap, possibly in the wrong place.
timerModifiedLater
// The timer has been modified and is being moved.
// The timer will only have this status briefly.
timerMoving
)
// maxWhen is the maximum value for a timer's when field.
const maxWhen = 1<<63 - 1
// verifyTimers can be set to true to add debugging checks that the
// timer heaps are valid.
const verifyTimers = false
// Package time APIs.
// Godoc uses the comments in package time, not these.
// time.now is implemented in assembly.
// timeSleep puts the current goroutine to sleep for at least ns nanoseconds.
//
//go:linkname timeSleep time.Sleep
func timeSleep(ns int64) {
if ns <= 0 {
return
}
gp := getg()
t := gp.timer
if t == nil {
t = new(timer)
gp.timer = t
}
t.f = goroutineReady
t.arg = gp
t.nextwhen = nanotime() + ns
if t.nextwhen < 0 { // check for overflow.
t.nextwhen = maxWhen
}
gopark(resetForSleep, unsafe.Pointer(t), waitReasonSleep, traceEvGoSleep, 1)
}
// resetForSleep is called after the goroutine is parked for timeSleep.
// We can't call resettimer in timeSleep itself because if this is a short
// sleep and there are many goroutines then the P can wind up running the
// timer function, goroutineReady, before the goroutine has been parked.
func resetForSleep(gp *g, ut unsafe.Pointer) bool {
t := (*timer)(ut)
resettimer(t, t.nextwhen)
return true
}
// startTimer adds t to the timer heap.
//
//go:linkname startTimer time.startTimer
func startTimer(t *timer) {
if raceenabled {
racerelease(unsafe.Pointer(t))
}
addtimer(t)
}
// stopTimer stops a timer.
// It reports whether t was stopped before being run.
//
//go:linkname stopTimer time.stopTimer
func stopTimer(t *timer) bool {
return deltimer(t)
}
// resetTimer resets an inactive timer, adding it to the heap.
//
// Reports whether the timer was modified before it was run.
//
//go:linkname resetTimer time.resetTimer
func resetTimer(t *timer, when int64) bool {
if raceenabled {
racerelease(unsafe.Pointer(t))
}
return resettimer(t, when)
}
// modTimer modifies an existing timer.
//
//go:linkname modTimer time.modTimer
func modTimer(t *timer, when, period int64, f func(any, uintptr), arg any, seq uintptr) {
modtimer(t, when, period, f, arg, seq)
}
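// These linknamed hooks back the public time package API. A typical
// sequence from user code (package time, not this file):
//
//	t := time.NewTimer(time.Second) // startTimer
//	if !t.Stop() {                  // stopTimer
//		<-t.C // drain the channel if the timer already fired
//	}
//	t.Reset(time.Second) // resetTimer
//
// time.Ticker.Reset reaches modTimer directly.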
// Go runtime.
// Ready the goroutine arg.
func goroutineReady(arg any, seq uintptr) {
goready(arg.(*g), 0)
}
// addtimer adds a timer to the current P.
// This should only be called with a newly created timer.
// That avoids the risk of changing the when field of a timer in some P's heap,
// which could cause the heap to become unsorted.
func addtimer(t *timer) {
// when must be positive. A negative value will cause runtimer to
// overflow during its delta calculation and never expire other runtime
// timers. Zero will cause checkTimers to fail to notice the timer.
if t.when <= 0 {
throw("timer when must be positive")
}
if t.period < 0 {
throw("timer period must be non-negative")
}
if t.status.Load() != timerNoStatus {
throw("addtimer called with initialized timer")
}
t.status.Store(timerWaiting)
when := t.when
// Disable preemption while using pp to avoid changing another P's heap.
mp := acquirem()
pp := getg().m.p.ptr()
lock(&pp.timersLock)
cleantimers(pp)
doaddtimer(pp, t)
unlock(&pp.timersLock)
wakeNetPoller(when)
releasem(mp)
}
// doaddtimer adds t to the current P's heap.
// The caller must have locked the timers for pp.
func doaddtimer(pp *p, t *timer) {
// Timers rely on the network poller, so make sure the poller
// has started.
if netpollInited.Load() == 0 {
netpollGenericInit()
}
if t.pp != 0 {
throw("doaddtimer: P already set in timer")
}
t.pp.set(pp)
i := len(pp.timers)
pp.timers = append(pp.timers, t)
siftupTimer(pp.timers, i)
if t == pp.timers[0] {
pp.timer0When.Store(t.when)
}
pp.numTimers.Add(1)
}
// deltimer deletes the timer t. It may be on some other P, so we can't
// actually remove it from the timers heap. We can only mark it as deleted.
// It will be removed in due course by the P whose heap it is on.
// Reports whether the timer was removed before it was run.
func deltimer(t *timer) bool {
for {
switch s := t.status.Load(); s {
case timerWaiting, timerModifiedLater:
// Prevent preemption while the timer is in timerModifying.
// This could lead to a self-deadlock. See #38070.
mp := acquirem()
if t.status.CompareAndSwap(s, timerModifying) {
// Must fetch t.pp before changing status,
// as cleantimers in another goroutine
// can clear t.pp of a timerDeleted timer.
tpp := t.pp.ptr()
if !t.status.CompareAndSwap(timerModifying, timerDeleted) {
badTimer()
}
releasem(mp)
tpp.deletedTimers.Add(1)
// Timer was not yet run.
return true
} else {
releasem(mp)
}
case timerModifiedEarlier:
// Prevent preemption while the timer is in timerModifying.
// This could lead to a self-deadlock. See #38070.
mp := acquirem()
if t.status.CompareAndSwap(s, timerModifying) {
// Must fetch t.pp before setting status
// to timerDeleted.
tpp := t.pp.ptr()
if !t.status.CompareAndSwap(timerModifying, timerDeleted) {
badTimer()
}
releasem(mp)
tpp.deletedTimers.Add(1)
// Timer was not yet run.
return true
} else {
releasem(mp)
}
case timerDeleted, timerRemoving, timerRemoved:
// Timer was already run.
return false
case timerRunning, timerMoving:
// The timer is being run or moved, by a different P.
// Wait for it to complete.
osyield()
case timerNoStatus:
// Removing timer that was never added or
// has already been run. Also see issue 21874.
return false
case timerModifying:
// Simultaneous calls to deltimer and modtimer.
// Wait for the other call to complete.
osyield()
default:
badTimer()
}
}
}
// dodeltimer removes timer i from the current P's heap.
// We are locked on the P when this is called.
// It returns the smallest changed index in pp.timers.
// The caller must have locked the timers for pp.
func dodeltimer(pp *p, i int) int {
if t := pp.timers[i]; t.pp.ptr() != pp {
throw("dodeltimer: wrong P")
} else {
t.pp = 0
}
last := len(pp.timers) - 1
if i != last {
pp.timers[i] = pp.timers[last]
}
pp.timers[last] = nil
pp.timers = pp.timers[:last]
smallestChanged := i
if i != last {
// Moving to i may have moved the last timer to a new parent,
// so sift up to preserve the heap guarantee.
smallestChanged = siftupTimer(pp.timers, i)
siftdownTimer(pp.timers, i)
}
if i == 0 {
updateTimer0When(pp)
}
n := pp.numTimers.Add(-1)
if n == 0 {
// If there are no timers, then clearly none are modified.
pp.timerModifiedEarliest.Store(0)
}
return smallestChanged
}
// dodeltimer0 removes timer 0 from the current P's heap.
// We are locked on the P when this is called.
// The caller must have locked the timers for pp.
func dodeltimer0(pp *p) {
if t := pp.timers[0]; t.pp.ptr() != pp {
throw("dodeltimer0: wrong P")
} else {
t.pp = 0
}
last := len(pp.timers) - 1
if last > 0 {
pp.timers[0] = pp.timers[last]
}
pp.timers[last] = nil
pp.timers = pp.timers[:last]
if last > 0 {
siftdownTimer(pp.timers, 0)
}
updateTimer0When(pp)
n := pp.numTimers.Add(-1)
if n == 0 {
// If there are no timers, then clearly none are modified.
pp.timerModifiedEarliest.Store(0)
}
}
// modtimer modifies an existing timer.
// This is called by the netpoll code or time.Ticker.Reset or time.Timer.Reset.
// Reports whether the timer was modified before it was run.
func modtimer(t *timer, when, period int64, f func(any, uintptr), arg any, seq uintptr) bool {
if when <= 0 {
throw("timer when must be positive")
}
if period < 0 {
throw("timer period must be non-negative")
}
status := uint32(timerNoStatus)
wasRemoved := false
var pending bool
var mp *m
loop:
for {
switch status = t.status.Load(); status {
case timerWaiting, timerModifiedEarlier, timerModifiedLater:
// Prevent preemption while the timer is in timerModifying.
// This could lead to a self-deadlock. See #38070.
mp = acquirem()
if t.status.CompareAndSwap(status, timerModifying) {
pending = true // timer not yet run
break loop
}
releasem(mp)
case timerNoStatus, timerRemoved:
// Prevent preemption while the timer is in timerModifying.
// This could lead to a self-deadlock. See #38070.
mp = acquirem()
// Timer was already run and t is no longer in a heap.
// Act like addtimer.
if t.status.CompareAndSwap(status, timerModifying) {
wasRemoved = true
pending = false // timer already run or stopped
break loop
}
releasem(mp)
case timerDeleted:
// Prevent preemption while the timer is in timerModifying.
// This could lead to a self-deadlock. See #38070.
mp = acquirem()
if t.status.CompareAndSwap(status, timerModifying) {
t.pp.ptr().deletedTimers.Add(-1)
pending = false // timer already stopped
break loop
}
releasem(mp)
case timerRunning, timerRemoving, timerMoving:
// The timer is being run or moved, by a different P.
// Wait for it to complete.
osyield()
case timerModifying:
// Multiple simultaneous calls to modtimer.
// Wait for the other call to complete.
osyield()
default:
badTimer()
}
}
t.period = period
t.f = f
t.arg = arg
t.seq = seq
if wasRemoved {
t.when = when
pp := getg().m.p.ptr()
lock(&pp.timersLock)
doaddtimer(pp, t)
unlock(&pp.timersLock)
if !t.status.CompareAndSwap(timerModifying, timerWaiting) {
badTimer()
}
releasem(mp)
wakeNetPoller(when)
} else {
// The timer is in some other P's heap, so we can't change
// the when field. If we did, the other P's heap would
// be out of order. So we put the new when value in the
// nextwhen field, and let the other P set the when field
// when it is prepared to resort the heap.
t.nextwhen = when
newStatus := uint32(timerModifiedLater)
if when < t.when {
newStatus = timerModifiedEarlier
}
tpp := t.pp.ptr()
if newStatus == timerModifiedEarlier {
updateTimerModifiedEarliest(tpp, when)
}
// Set the new status of the timer.
if !t.status.CompareAndSwap(timerModifying, newStatus) {
badTimer()
}
releasem(mp)
// If the new status is earlier, wake up the poller.
if newStatus == timerModifiedEarlier {
wakeNetPoller(when)
}
}
return pending
}
// resettimer resets the time when a timer should fire.
// If used for an inactive timer, the timer will become active.
// This should be called instead of addtimer if the timer value has been,
// or may have been, used previously.
// Reports whether the timer was modified before it was run.
func resettimer(t *timer, when int64) bool {
return modtimer(t, when, t.period, t.f, t.arg, t.seq)
}
// cleantimers cleans up the head of the timer queue. This speeds up
// programs that create and delete timers; leaving them in the heap
// slows down addtimer.
// The caller must have locked the timers for pp.
func cleantimers(pp *p) {
gp := getg()
for {
if len(pp.timers) == 0 {
return
}
// This loop can theoretically run for a while, and because
// it is holding timersLock it cannot be preempted.
// If someone is trying to preempt us, just return.
// We can clean the timers later.
if gp.preemptStop {
return
}
t := pp.timers[0]
if t.pp.ptr() != pp {
throw("cleantimers: bad p")
}
switch s := t.status.Load(); s {
case timerDeleted:
if !t.status.CompareAndSwap(s, timerRemoving) {
continue
}
dodeltimer0(pp)
if !t.status.CompareAndSwap(timerRemoving, timerRemoved) {
badTimer()
}
pp.deletedTimers.Add(-1)
case timerModifiedEarlier, timerModifiedLater:
if !t.status.CompareAndSwap(s, timerMoving) {
continue
}
// Now we can change the when field.
t.when = t.nextwhen
// Move t to the right position.
dodeltimer0(pp)
doaddtimer(pp, t)
if !t.status.CompareAndSwap(timerMoving, timerWaiting) {
badTimer()
}
default:
// Head of timers does not need adjustment.
return
}
}
}
// moveTimers moves a slice of timers to pp. The slice has been taken
// from a different P.
// This is currently called when the world is stopped, but the caller
// is expected to have locked the timers for pp.
func moveTimers(pp *p, timers []*timer) {
for _, t := range timers {
loop:
for {
switch s := t.status.Load(); s {
case timerWaiting:
if !t.status.CompareAndSwap(s, timerMoving) {
continue
}
t.pp = 0
doaddtimer(pp, t)
if !t.status.CompareAndSwap(timerMoving, timerWaiting) {
badTimer()
}
break loop
case timerModifiedEarlier, timerModifiedLater:
if !t.status.CompareAndSwap(s, timerMoving) {
continue
}
t.when = t.nextwhen
t.pp = 0
doaddtimer(pp, t)
if !t.status.CompareAndSwap(timerMoving, timerWaiting) {
badTimer()
}
break loop
case timerDeleted:
if !t.status.CompareAndSwap(s, timerRemoved) {
continue
}
t.pp = 0
// We no longer need this timer in the heap.
break loop
case timerModifying:
// Loop until the modification is complete.
osyield()
case timerNoStatus, timerRemoved:
// We should not see these status values in a timers heap.
badTimer()
case timerRunning, timerRemoving, timerMoving:
// Some other P thinks it owns this timer,
// which should not happen.
badTimer()
default:
badTimer()
}
}
}
}
// adjusttimers looks through the timers in the current P's heap for
// any timers that have been modified to run earlier, and puts them in
// the correct place in the heap. While looking for those timers,
// it also moves timers that have been modified to run later,
// and removes deleted timers. The caller must have locked the timers for pp.
func adjusttimers(pp *p, now int64) {
// If we haven't yet reached the time of the first timerModifiedEarlier
// timer, don't do anything. This speeds up programs that adjust
// a lot of timers back and forth if the timers rarely expire.
// We'll postpone looking through all the adjusted timers until
// one would actually expire.
first := pp.timerModifiedEarliest.Load()
if first == 0 || first > now {
if verifyTimers {
verifyTimerHeap(pp)
}
return
}
// We are going to clear all timerModifiedEarlier timers.
pp.timerModifiedEarliest.Store(0)
var moved []*timer
for i := 0; i < len(pp.timers); i++ {
t := pp.timers[i]
if t.pp.ptr() != pp {
throw("adjusttimers: bad p")
}
switch s := t.status.Load(); s {
case timerDeleted:
if t.status.CompareAndSwap(s, timerRemoving) {
changed := dodeltimer(pp, i)
if !t.status.CompareAndSwap(timerRemoving, timerRemoved) {
badTimer()
}
pp.deletedTimers.Add(-1)
// Go back to the earliest changed heap entry.
// "- 1" because the loop will add 1.
i = changed - 1
}
case timerModifiedEarlier, timerModifiedLater:
if t.status.CompareAndSwap(s, timerMoving) {
// Now we can change the when field.
t.when = t.nextwhen
// Take t off the heap, and hold onto it.
// We don't add it back yet because the
// heap manipulation could cause our
// loop to skip some other timer.
changed := dodeltimer(pp, i)
moved = append(moved, t)
// Go back to the earliest changed heap entry.
// "- 1" because the loop will add 1.
i = changed - 1
}
case timerNoStatus, timerRunning, timerRemoving, timerRemoved, timerMoving:
badTimer()
case timerWaiting:
// OK, nothing to do.
case timerModifying:
// Check again after modification is complete.
osyield()
i--
default:
badTimer()
}
}
if len(moved) > 0 {
addAdjustedTimers(pp, moved)
}
if verifyTimers {
verifyTimerHeap(pp)
}
}
// addAdjustedTimers adds any timers we adjusted in adjusttimers
// back to the timer heap.
func addAdjustedTimers(pp *p, moved []*timer) {
for _, t := range moved {
doaddtimer(pp, t)
if !t.status.CompareAndSwap(timerMoving, timerWaiting) {
badTimer()
}
}
}
// nobarrierWakeTime looks at P's timers and returns the time when we
// should wake up the netpoller. It returns 0 if there are no timers.
// This function is invoked when dropping a P, and must run without
// any write barriers.
//
//go:nowritebarrierrec
func nobarrierWakeTime(pp *p) int64 {
next := pp.timer0When.Load()
nextAdj := pp.timerModifiedEarliest.Load()
if next == 0 || (nextAdj != 0 && nextAdj < next) {
next = nextAdj
}
return next
}
// runtimer examines the first timer in timers. If it is ready based on now,
// it runs the timer and removes or updates it.
// Returns 0 if it ran a timer, -1 if there are no more timers, or the time
// when the first timer should run.
// The caller must have locked the timers for pp.
// If a timer is run, this will temporarily unlock the timers.
//
//go:systemstack
func runtimer(pp *p, now int64) int64 {
for {
t := pp.timers[0]
if t.pp.ptr() != pp {
throw("runtimer: bad p")
}
switch s := t.status.Load(); s {
case timerWaiting:
if t.when > now {
// Not ready to run.
return t.when
}
if !t.status.CompareAndSwap(s, timerRunning) {
continue
}
// Note that runOneTimer may temporarily unlock
// pp.timersLock.
runOneTimer(pp, t, now)
return 0
case timerDeleted:
if !t.status.CompareAndSwap(s, timerRemoving) {
continue
}
dodeltimer0(pp)
if !t.status.CompareAndSwap(timerRemoving, timerRemoved) {
badTimer()
}
pp.deletedTimers.Add(-1)
if len(pp.timers) == 0 {
return -1
}
case timerModifiedEarlier, timerModifiedLater:
if !t.status.CompareAndSwap(s, timerMoving) {
continue
}
t.when = t.nextwhen
dodeltimer0(pp)
doaddtimer(pp, t)
if !t.status.CompareAndSwap(timerMoving, timerWaiting) {
badTimer()
}
case timerModifying:
// Wait for modification to complete.
osyield()
case timerNoStatus, timerRemoved:
// Should not see a new or inactive timer on the heap.
badTimer()
case timerRunning, timerRemoving, timerMoving:
// These should only be set when timers are locked,
// and we didn't do it.
badTimer()
default:
badTimer()
}
}
}
// runOneTimer runs a single timer.
// The caller must have locked the timers for pp.
// This will temporarily unlock the timers while running the timer function.
//
//go:systemstack
func runOneTimer(pp *p, t *timer, now int64) {
if raceenabled {
ppcur := getg().m.p.ptr()
if ppcur.timerRaceCtx == 0 {
ppcur.timerRaceCtx = racegostart(abi.FuncPCABIInternal(runtimer) + sys.PCQuantum)
}
raceacquirectx(ppcur.timerRaceCtx, unsafe.Pointer(t))
}
f := t.f
arg := t.arg
seq := t.seq
if t.period > 0 {
// Leave in heap but adjust next time to fire.
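// For example, with hypothetical values when=100, now=350, period=100:
// delta is -250, -delta/period is 2, and when advances by period*(1+2)
// to 400, the next scheduled tick after now.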
delta := t.when - now
t.when += t.period * (1 + -delta/t.period)
if t.when < 0 { // check for overflow.
t.when = maxWhen
}
siftdownTimer(pp.timers, 0)
if !t.status.CompareAndSwap(timerRunning, timerWaiting) {
badTimer()
}
updateTimer0When(pp)
} else {
// Remove from heap.
dodeltimer0(pp)
if !t.status.CompareAndSwap(timerRunning, timerNoStatus) {
badTimer()
}
}
if raceenabled {
// Temporarily use the current P's racectx for g0.
gp := getg()
if gp.racectx != 0 {
throw("runOneTimer: unexpected racectx")
}
gp.racectx = gp.m.p.ptr().timerRaceCtx
}
unlock(&pp.timersLock)
f(arg, seq)
lock(&pp.timersLock)
if raceenabled {
gp := getg()
gp.racectx = 0
}
}
// clearDeletedTimers removes all deleted timers from the P's timer heap.
// This is used to avoid clogging up the heap if the program
// starts a lot of long-running timers and then stops them.
// For example, this can happen via context.WithTimeout.
//
// This is the only function that walks through the entire timer heap,
// other than moveTimers which only runs when the world is stopped.
//
// The caller must have locked the timers for pp.
func clearDeletedTimers(pp *p) {
// We are going to clear all timerModifiedEarlier timers.
// Do this now in case new ones show up while we are looping.
pp.timerModifiedEarliest.Store(0)
cdel := int32(0)
to := 0
changedHeap := false
timers := pp.timers
nextTimer:
for _, t := range timers {
for {
switch s := t.status.Load(); s {
case timerWaiting:
if changedHeap {
timers[to] = t
siftupTimer(timers, to)
}
to++
continue nextTimer
case timerModifiedEarlier, timerModifiedLater:
if t.status.CompareAndSwap(s, timerMoving) {
t.when = t.nextwhen
timers[to] = t
siftupTimer(timers, to)
to++
changedHeap = true
if !t.status.CompareAndSwap(timerMoving, timerWaiting) {
badTimer()
}
continue nextTimer
}
case timerDeleted:
if t.status.CompareAndSwap(s, timerRemoving) {
t.pp = 0
cdel++
if !t.status.CompareAndSwap(timerRemoving, timerRemoved) {
badTimer()
}
changedHeap = true
continue nextTimer
}
case timerModifying:
// Loop until modification complete.
osyield()
case timerNoStatus, timerRemoved:
// We should not see these status values in a timer heap.
badTimer()
case timerRunning, timerRemoving, timerMoving:
// Some other P thinks it owns this timer,
// which should not happen.
badTimer()
default:
badTimer()
}
}
}
// Set remaining slots in timers slice to nil,
// so that the timer values can be garbage collected.
for i := to; i < len(timers); i++ {
timers[i] = nil
}
pp.deletedTimers.Add(-cdel)
pp.numTimers.Add(-cdel)
timers = timers[:to]
pp.timers = timers
updateTimer0When(pp)
if verifyTimers {
verifyTimerHeap(pp)
}
}
// verifyTimerHeap verifies that the timer heap is in a valid state.
// This is only for debugging, and is only called if verifyTimers is true.
// The caller must have locked the timers.
func verifyTimerHeap(pp *p) {
for i, t := range pp.timers {
if i == 0 {
// First timer has no parent.
continue
}
// The heap is 4-ary. See siftupTimer and siftdownTimer.
p := (i - 1) / 4
if t.when < pp.timers[p].when {
print("bad timer heap at ", i, ": ", p, ": ", pp.timers[p].when, ", ", i, ": ", t.when, "\n")
throw("bad timer heap")
}
}
if numTimers := int(pp.numTimers.Load()); len(pp.timers) != numTimers {
println("timer heap len", len(pp.timers), "!= numTimers", numTimers)
throw("bad timer heap len")
}
}
// updateTimer0When sets the P's timer0When field.
// The caller must have locked the timers for pp.
func updateTimer0When(pp *p) {
if len(pp.timers) == 0 {
pp.timer0When.Store(0)
} else {
pp.timer0When.Store(pp.timers[0].when)
}
}
// updateTimerModifiedEarliest updates pp.timerModifiedEarliest, which
// records the earliest nextwhen of any timerModifiedEarlier timer on pp.
// The timers for pp will not be locked.
func updateTimerModifiedEarliest(pp *p, nextwhen int64) {
for {
old := pp.timerModifiedEarliest.Load()
if old != 0 && old < nextwhen {
return
}
if pp.timerModifiedEarliest.CompareAndSwap(old, nextwhen) {
return
}
}
}
// timeSleepUntil returns the time when the next timer should fire. Returns
// maxWhen if there are no timers.
// This is only called by sysmon and checkdead.
func timeSleepUntil() int64 {
next := int64(maxWhen)
// Prevent allp slice changes. This is like retake.
lock(&allpLock)
for _, pp := range allp {
if pp == nil {
// This can happen if procresize has grown
// allp but not yet created new Ps.
continue
}
w := pp.timer0When.Load()
if w != 0 && w < next {
next = w
}
w = pp.timerModifiedEarliest.Load()
if w != 0 && w < next {
next = w
}
}
unlock(&allpLock)
return next
}
// Heap maintenance algorithms.
// These algorithms check for slice index errors manually.
// A slice index error can happen if the program is using racy
// access to timers. We don't want to panic here, because
// it will cause the program to crash with a mysterious
// "panic holding locks" message. Instead, we panic while not
// holding a lock.
// siftupTimer puts the timer at position i in the right place
// in the heap by moving it up toward the top of the heap.
// It returns the smallest changed index.
func siftupTimer(t []*timer, i int) int {
if i >= len(t) {
badTimer()
}
when := t[i].when
if when <= 0 {
badTimer()
}
tmp := t[i]
for i > 0 {
p := (i - 1) / 4 // parent
if when >= t[p].when {
break
}
t[i] = t[p]
i = p
}
if tmp != t[i] {
t[i] = tmp
}
return i
}
// siftdownTimer puts the timer at position i in the right place
// in the heap by moving it down toward the bottom of the heap.
func siftdownTimer(t []*timer, i int) {
n := len(t)
if i >= n {
badTimer()
}
when := t[i].when
if when <= 0 {
badTimer()
}
tmp := t[i]
for {
c := i*4 + 1 // left child
c3 := c + 2 // mid child
if c >= n {
break
}
w := t[c].when
if c+1 < n && t[c+1].when < w {
w = t[c+1].when
c++
}
if c3 < n {
w3 := t[c3].when
if c3+1 < n && t[c3+1].when < w3 {
w3 = t[c3+1].when
c3++
}
if w3 < w {
w = w3
c = c3
}
}
if w >= when {
break
}
t[i] = t[c]
i = c
}
if tmp != t[i] {
t[i] = tmp
}
}
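// The index arithmetic in siftupTimer and siftdownTimer encodes a
// 4-ary heap: the parent of node i is (i-1)/4 and its children are
// 4*i+1 through 4*i+4. For example, node 7's parent is (7-1)/4 = 1,
// and node 1's children are 5, 6, 7, and 8. The wider fan-out makes
// the tree shallower than a binary heap, shortening sift paths at the
// cost of a few more comparisons per level.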
// badTimer is called if the timer data structures have been corrupted,
// presumably due to racy use by the program. We panic here rather than
// panicking due to an invalid slice access while holding locks.
// See issue #25686.
func badTimer() {
throw("timer data corruption")
}
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !faketime
package runtime
import "unsafe"
// faketime is the simulated time in nanoseconds since 1970 for the
// playground.
//
// Zero means not to use faketime.
var faketime int64
//go:nosplit
func nanotime() int64 {
return nanotime1()
}
var overrideWrite func(fd uintptr, p unsafe.Pointer, n int32) int32
// write must be nosplit on Windows (see write1)
//
//go:nosplit
func write(fd uintptr, p unsafe.Pointer, n int32) int32 {
if overrideWrite != nil {
return overrideWrite(fd, noescape(p), n)
}
return write1(fd, p, n)
}
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build (windows && !amd64) || !windows
package runtime
//go:nosplit
func osSetupTLS(mp *m) {}
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Go execution tracer.
// The tracer captures a wide range of execution events like goroutine
// creation/blocking/unblocking, syscall enter/exit/block, GC-related events,
// changes of heap size, processor start/stop, etc., and writes them to a buffer
// in a compact form. A precise nanosecond-precision timestamp and a stack
// trace are captured for most events.
// See https://golang.org/s/go15trace for more info.
package runtime
import (
"internal/goarch"
"runtime/internal/atomic"
"runtime/internal/sys"
"unsafe"
)
// Event types in the trace, args are given in square brackets.
const (
traceEvNone = 0 // unused
traceEvBatch = 1 // start of per-P batch of events [pid, timestamp]
traceEvFrequency = 2 // contains tracer timer frequency [frequency (ticks per second)]
traceEvStack = 3 // stack [stack id, number of PCs, array of {PC, func string ID, file string ID, line}]
traceEvGomaxprocs = 4 // current value of GOMAXPROCS [timestamp, GOMAXPROCS, stack id]
traceEvProcStart = 5 // start of P [timestamp, thread id]
traceEvProcStop = 6 // stop of P [timestamp]
traceEvGCStart = 7 // GC start [timestamp, seq, stack id]
traceEvGCDone = 8 // GC done [timestamp]
traceEvGCSTWStart = 9 // GC STW start [timestamp, kind]
traceEvGCSTWDone = 10 // GC STW done [timestamp]
traceEvGCSweepStart = 11 // GC sweep start [timestamp, stack id]
traceEvGCSweepDone = 12 // GC sweep done [timestamp, swept, reclaimed]
traceEvGoCreate = 13 // goroutine creation [timestamp, new goroutine id, new stack id, stack id]
traceEvGoStart = 14 // goroutine starts running [timestamp, goroutine id, seq]
traceEvGoEnd = 15 // goroutine ends [timestamp]
traceEvGoStop = 16 // goroutine stops (like in select{}) [timestamp, stack]
traceEvGoSched = 17 // goroutine calls Gosched [timestamp, stack]
traceEvGoPreempt = 18 // goroutine is preempted [timestamp, stack]
traceEvGoSleep = 19 // goroutine calls Sleep [timestamp, stack]
traceEvGoBlock = 20 // goroutine blocks [timestamp, stack]
traceEvGoUnblock = 21 // goroutine is unblocked [timestamp, goroutine id, seq, stack]
traceEvGoBlockSend = 22 // goroutine blocks on chan send [timestamp, stack]
traceEvGoBlockRecv = 23 // goroutine blocks on chan recv [timestamp, stack]
traceEvGoBlockSelect = 24 // goroutine blocks on select [timestamp, stack]
traceEvGoBlockSync = 25 // goroutine blocks on Mutex/RWMutex [timestamp, stack]
traceEvGoBlockCond = 26 // goroutine blocks on Cond [timestamp, stack]
traceEvGoBlockNet = 27 // goroutine blocks on network [timestamp, stack]
traceEvGoSysCall = 28 // syscall enter [timestamp, stack]
traceEvGoSysExit = 29 // syscall exit [timestamp, goroutine id, seq, real timestamp]
traceEvGoSysBlock = 30 // syscall blocks [timestamp]
traceEvGoWaiting = 31 // denotes that goroutine is blocked when tracing starts [timestamp, goroutine id]
traceEvGoInSyscall = 32 // denotes that goroutine is in syscall when tracing starts [timestamp, goroutine id]
traceEvHeapAlloc = 33 // gcController.heapLive change [timestamp, heap_alloc]
traceEvHeapGoal = 34 // gcController.heapGoal() (formerly next_gc) change [timestamp, heap goal in bytes]
traceEvTimerGoroutine = 35 // not currently used; previously denoted timer goroutine [timer goroutine id]
traceEvFutileWakeup = 36 // denotes that the previous wakeup of this goroutine was futile [timestamp]
traceEvString = 37 // string dictionary entry [ID, length, string]
traceEvGoStartLocal = 38 // goroutine starts running on the same P as the last event [timestamp, goroutine id]
traceEvGoUnblockLocal = 39 // goroutine is unblocked on the same P as the last event [timestamp, goroutine id, stack]
traceEvGoSysExitLocal = 40 // syscall exit on the same P as the last event [timestamp, goroutine id, real timestamp]
traceEvGoStartLabel = 41 // goroutine starts running with label [timestamp, goroutine id, seq, label string id]
traceEvGoBlockGC = 42 // goroutine blocks on GC assist [timestamp, stack]
traceEvGCMarkAssistStart = 43 // GC mark assist start [timestamp, stack]
traceEvGCMarkAssistDone = 44 // GC mark assist done [timestamp]
traceEvUserTaskCreate = 45 // trace.NewContext [timestamp, internal task id, internal parent task id, stack, name string]
traceEvUserTaskEnd = 46 // end of a task [timestamp, internal task id, stack]
traceEvUserRegion = 47 // trace.WithRegion [timestamp, internal task id, mode(0:start, 1:end), stack, name string]
traceEvUserLog = 48 // trace.Log [timestamp, internal task id, key string id, stack, value string]
traceEvCPUSample = 49 // CPU profiling sample [timestamp, real timestamp, real P id (-1 when absent), goroutine id, stack]
traceEvCount = 50
// A byte is used, but only 6 bits are available for the event type.
// The remaining 2 bits are used to specify the number of arguments.
// That means the max event type value is 63.
)
const (
// Timestamps in trace are cputicks/traceTickDiv.
// This makes absolute values of timestamp diffs smaller,
// and so they can be encoded in fewer bytes.
// 64 on x86 is somewhat arbitrary (one tick is ~20ns on a 3GHz machine).
// The suggested increment frequency for PowerPC's time base register is
// 512 MHz according to Power ISA v2.07 section 6.2, so we use 16 on ppc64
// and ppc64le.
// The expression below evaluates to 64 on 386/amd64 and to 16 everywhere else.
// Tracing won't work reliably for architectures where cputicks is emulated
// by nanotime, so the value doesn't matter for those architectures.
traceTickDiv = 16 + 48*(goarch.Is386|goarch.IsAmd64)
// Maximum number of PCs in a single stack trace.
// Since events contain only stack id rather than whole stack trace,
// we can allow quite large values here.
traceStackSize = 128
// Identifier of a fake P that is used when we trace without a real P.
traceGlobProc = -1
// Maximum number of bytes to encode uint64 in base-128.
traceBytesPerNumber = 10
// Shift of the number of arguments in the first event byte.
traceArgCountShift = 6
// Flag passed to traceGoPark to denote that the previous wakeup of this
// goroutine was futile. For example, a goroutine was unblocked on a mutex,
// but another goroutine got ahead and acquired the mutex before the first
// goroutine is scheduled, so the first goroutine has to block again.
// Such wakeups happen on buffered channels and sync.Mutex,
// but are generally not interesting for the end user.
traceFutileWakeup byte = 128
)
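// The first byte of an event packs the type into the low 6 bits and
// the inline argument count into the top 2. A hedged sketch of the
// packing (the real writer lives elsewhere in this file):
//
//	narg := byte(2) // number of inline arguments
//	b := byte(traceEvGoStart) | narg<<traceArgCountShift
//	typ := b & (1<<traceArgCountShift - 1) // recovers 14 (traceEvGoStart)
//	n := b >> traceArgCountShift           // recovers 2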
// trace is the global tracing context.
var trace struct {
// trace.lock must only be acquired on the system stack where
// stack splits cannot happen while it is held.
lock mutex // protects the following members
lockOwner *g // to avoid deadlocks during recursive lock acquisition
enabled bool // when set runtime traces events
shutdown bool // set when we are waiting for trace reader to finish after setting enabled to false
headerWritten bool // whether ReadTrace has emitted trace header
footerWritten bool // whether ReadTrace has emitted trace footer
shutdownSema uint32 // used to wait for ReadTrace completion
seqStart uint64 // sequence number when tracing was started
ticksStart int64 // cputicks when tracing was started
ticksEnd int64 // cputicks when tracing was stopped
timeStart int64 // nanotime when tracing was started
timeEnd int64 // nanotime when tracing was stopped
seqGC uint64 // GC start/done sequencer
reading traceBufPtr // buffer currently handed off to user
empty traceBufPtr // stack of empty buffers
fullHead traceBufPtr // queue of full buffers
fullTail traceBufPtr
stackTab traceStackTable // maps stack traces to unique ids
// cpuLogRead accepts CPU profile samples from the signal handler where
// they're generated. It uses a two-word header to hold the IDs of the P and
// G (respectively) that were active at the time of the sample. Because
// profBuf uses a record with all zeros in its header to indicate overflow,
// we make sure to make the P field always non-zero: The ID of a real P will
// start at bit 1, and bit 0 will be set. Samples that arrive while no P is
// running (such as near syscalls) will set the first header field to 0b10.
// This careful handling of the first header field allows us to store ID of
// the active G directly in the second field, even though that will be 0
// when sampling g0.
cpuLogRead *profBuf
// cpuLogBuf is a trace buffer to hold events corresponding to CPU profile
// samples, which arrive out of band and not directly connected to a
// specific P.
cpuLogBuf traceBufPtr
reader atomic.Pointer[g] // goroutine that called ReadTrace, or nil
signalLock atomic.Uint32 // protects use of the following member, only usable in signal handlers
cpuLogWrite *profBuf // copy of cpuLogRead for use in signal handlers, set without signalLock
// Dictionary for traceEvString.
//
// TODO: central lock to access the map is not ideal.
// option: pre-assign ids to all user annotation region names and tags
// option: per-P cache
// option: sync.Map like data structure
stringsLock mutex
strings map[string]uint64
stringSeq uint64
// markWorkerLabels maps gcMarkWorkerMode to string ID.
markWorkerLabels [len(gcMarkWorkerModeStrings)]uint64
bufLock mutex // protects buf
buf traceBufPtr // global trace buffer, used when running without a p
}
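// The pair of helpers below is an illustrative sketch (not part of the
// runtime; the names are hypothetical) of the cpuLogRead header convention
// described above: a real P's ID is shifted left by one with bit 0 set,
// samples with no running P use 0b10, and an all-zero header marks a
// profBuf overflow record. The encoding side appears in traceCPUSample,
// the decoding side in traceReadCPU.
func encodeCPUSampleP(pid int32, hasP bool) uint64 {
	if !hasP {
		return 0b10 // non-zero, but bit 0 clear: no P was running
	}
	return uint64(pid)<<1 | 0b1 // bit 0 set marks a real P ID
}

func decodeCPUSampleP(hdr uint64) (pid uint64, hasP bool) {
	if hdr&0b1 == 0 {
		return ^uint64(0), false // same sentinel traceReadCPU uses for "no P"
	}
	return hdr >> 1, true
}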
// traceBufHeader is the header of a per-P tracing buffer.
type traceBufHeader struct {
link traceBufPtr // in trace.empty/full
lastTicks uint64 // when we wrote the last event
pos int // next write offset in arr
stk [traceStackSize]uintptr // scratch buffer for traceback
}
// traceBuf is per-P tracing buffer.
type traceBuf struct {
_ sys.NotInHeap
traceBufHeader
arr [64<<10 - unsafe.Sizeof(traceBufHeader{})]byte // underlying buffer for trace event data
}
// traceBufPtr is a *traceBuf that is not traced by the garbage
// collector and doesn't have write barriers. traceBufs are not
// allocated from the GC'd heap, so this is safe, and are often
// manipulated in contexts where write barriers are not allowed, so
// this is necessary.
//
// TODO: Since traceBuf now embeds runtime/internal/sys.NotInHeap, this isn't necessary.
type traceBufPtr uintptr
func (tp traceBufPtr) ptr() *traceBuf { return (*traceBuf)(unsafe.Pointer(tp)) }
func (tp *traceBufPtr) set(b *traceBuf) { *tp = traceBufPtr(unsafe.Pointer(b)) }
func traceBufPtrOf(b *traceBuf) traceBufPtr {
return traceBufPtr(unsafe.Pointer(b))
}
// StartTrace enables tracing for the current process.
// While tracing, the data will be buffered and available via ReadTrace.
// StartTrace returns an error if tracing is already enabled.
// Most clients should use the runtime/trace package or the testing package's
// -test.trace flag instead of calling StartTrace directly.
func StartTrace() error {
// Stop the world so that we can take a consistent snapshot
// of all goroutines at the beginning of the trace.
// Do not stop the world during GC, so that we always see
// a consistent view of GC-related events (e.g. a start is always
// paired with an end).
stopTheWorldGC("start tracing")
// Prevent sysmon from running any code that could generate events.
lock(&sched.sysmonlock)
// We are in stop-the-world, but syscalls can finish and write to trace concurrently.
// Exitsyscall could check trace.enabled long before and then suddenly wake up
// and decide to write to trace at a random point in time.
// However, such a syscall will use the global trace.buf buffer, because we've
// acquired all p's by doing stop-the-world. So this protects us from such races.
lock(&trace.bufLock)
if trace.enabled || trace.shutdown {
unlock(&trace.bufLock)
unlock(&sched.sysmonlock)
startTheWorldGC()
return errorString("tracing is already enabled")
}
// Can't set trace.enabled yet. While the world is stopped, exitsyscall could
// already emit a delayed event (see exitTicks in exitsyscall) if we set trace.enabled here.
// That would lead to an inconsistent trace:
// - either GoSysExit appears before EvGoInSyscall,
// - or GoSysExit appears for a goroutine for which we don't emit EvGoInSyscall below.
// To instruct traceEvent that it must not ignore events below, we set startingtrace.
// trace.enabled is set afterwards once we have emitted all preliminary events.
mp := getg().m
mp.startingtrace = true
// Obtain current stack ID to use in all traceEvGoCreate events below.
stkBuf := make([]uintptr, traceStackSize)
stackID := traceStackID(mp, stkBuf, 2)
profBuf := newProfBuf(2, profBufWordCount, profBufTagCount) // after the timestamp, header is [pp.id, gp.goid]
trace.cpuLogRead = profBuf
// We must not acquire trace.signalLock outside of a signal handler: a
// profiling signal may arrive at any time and try to acquire it, leading to
// deadlock. Because we can't use that lock to protect updates to
// trace.cpuLogWrite (only use of the structure it references), reads and
// writes of the pointer must be atomic. (And although this field is never
// the sole pointer to the profBuf value, it's best to allow a write barrier
// here.)
atomicstorep(unsafe.Pointer(&trace.cpuLogWrite), unsafe.Pointer(profBuf))
// World is stopped, no need to lock.
forEachGRace(func(gp *g) {
status := readgstatus(gp)
if status != _Gdead {
gp.traceseq = 0
gp.tracelastp = getg().m.p
// +PCQuantum because traceFrameForPC expects return PCs and subtracts PCQuantum.
id := trace.stackTab.put([]uintptr{startPCforTrace(gp.startpc) + sys.PCQuantum})
traceEvent(traceEvGoCreate, -1, gp.goid, uint64(id), stackID)
}
if status == _Gwaiting {
// traceEvGoWaiting is implied to have seq=1.
gp.traceseq++
traceEvent(traceEvGoWaiting, -1, gp.goid)
}
if status == _Gsyscall {
gp.traceseq++
traceEvent(traceEvGoInSyscall, -1, gp.goid)
} else if status == _Gdead && gp.m != nil && gp.m.isextra {
// Trigger two trace events for the dead g in the extra m,
// since the next event of the g will be traceEvGoSysExit in exitsyscall,
// when calling from a C thread into Go.
gp.traceseq = 0
gp.tracelastp = getg().m.p
// +PCQuantum because traceFrameForPC expects return PCs and subtracts PCQuantum.
id := trace.stackTab.put([]uintptr{startPCforTrace(0) + sys.PCQuantum}) // no start pc
traceEvent(traceEvGoCreate, -1, gp.goid, uint64(id), stackID)
gp.traceseq++
traceEvent(traceEvGoInSyscall, -1, gp.goid)
} else {
gp.sysblocktraced = false
}
})
traceProcStart()
traceGoStart()
// Note: ticksStart needs to be set after we emit traceEvGoInSyscall events.
// If we do it the other way around, it is possible that exitsyscall will
// query sysexitticks after ticksStart but before traceEvGoInSyscall timestamp.
// That would lead to a false conclusion that cputicks is broken.
trace.ticksStart = cputicks()
trace.timeStart = nanotime()
trace.headerWritten = false
trace.footerWritten = false
// String-to-id mapping:
// 0: reserved for the empty string
// remaining ids: assigned by traceString
trace.stringSeq = 0
trace.strings = make(map[string]uint64)
trace.seqGC = 0
mp.startingtrace = false
trace.enabled = true
// Register runtime goroutine labels.
_, pid, bufp := traceAcquireBuffer()
for i, label := range gcMarkWorkerModeStrings[:] {
trace.markWorkerLabels[i], bufp = traceString(bufp, pid, label)
}
traceReleaseBuffer(pid)
unlock(&trace.bufLock)
unlock(&sched.sysmonlock)
startTheWorldGC()
return nil
}
// StopTrace stops tracing, if it was previously enabled.
// StopTrace only returns after all the reads for the trace have completed.
func StopTrace() {
// Stop the world so that we can collect the trace buffers from all p's below,
// and also to avoid races with traceEvent.
stopTheWorldGC("stop tracing")
// See the comment in StartTrace.
lock(&sched.sysmonlock)
// See the comment in StartTrace.
lock(&trace.bufLock)
if !trace.enabled {
unlock(&trace.bufLock)
unlock(&sched.sysmonlock)
startTheWorldGC()
return
}
traceGoSched()
atomicstorep(unsafe.Pointer(&trace.cpuLogWrite), nil)
trace.cpuLogRead.close()
traceReadCPU()
// Loop over all allocated Ps because dead Ps may still have
// trace buffers.
for _, p := range allp[:cap(allp)] {
buf := p.tracebuf
if buf != 0 {
traceFullQueue(buf)
p.tracebuf = 0
}
}
if trace.buf != 0 {
buf := trace.buf
trace.buf = 0
if buf.ptr().pos != 0 {
traceFullQueue(buf)
}
}
if trace.cpuLogBuf != 0 {
buf := trace.cpuLogBuf
trace.cpuLogBuf = 0
if buf.ptr().pos != 0 {
traceFullQueue(buf)
}
}
for {
trace.ticksEnd = cputicks()
trace.timeEnd = nanotime()
// Windows time can tick only every 15ms; wait for at least one tick.
if trace.timeEnd != trace.timeStart {
break
}
osyield()
}
trace.enabled = false
trace.shutdown = true
unlock(&trace.bufLock)
unlock(&sched.sysmonlock)
startTheWorldGC()
// The world is started but we've set trace.shutdown, so new tracing can't start.
// Wait for the trace reader to flush pending buffers and stop.
semacquire(&trace.shutdownSema)
if raceenabled {
raceacquire(unsafe.Pointer(&trace.shutdownSema))
}
systemstack(func() {
// The lock protects us from races with StartTrace/StopTrace because they do stop-the-world.
lock(&trace.lock)
for _, p := range allp[:cap(allp)] {
if p.tracebuf != 0 {
throw("trace: non-empty trace buffer in proc")
}
}
if trace.buf != 0 {
throw("trace: non-empty global trace buffer")
}
if trace.fullHead != 0 || trace.fullTail != 0 {
throw("trace: non-empty full trace buffer")
}
if trace.reading != 0 || trace.reader.Load() != nil {
throw("trace: reading after shutdown")
}
for trace.empty != 0 {
buf := trace.empty
trace.empty = buf.ptr().link
sysFree(unsafe.Pointer(buf), unsafe.Sizeof(*buf.ptr()), &memstats.other_sys)
}
trace.strings = nil
trace.shutdown = false
trace.cpuLogRead = nil
unlock(&trace.lock)
})
}
// ReadTrace returns the next chunk of binary tracing data, blocking until data
// is available. If tracing is turned off and all the data accumulated while it
// was on has been returned, ReadTrace returns nil. The caller must copy the
// returned data before calling ReadTrace again.
// ReadTrace must be called from one goroutine at a time.
func ReadTrace() []byte {
top:
var buf []byte
var park bool
systemstack(func() {
buf, park = readTrace0()
})
if park {
gopark(func(gp *g, _ unsafe.Pointer) bool {
if !trace.reader.CompareAndSwapNoWB(nil, gp) {
// We're racing with another reader.
// Wake up and handle this case.
return false
}
if g2 := traceReader(); gp == g2 {
// New data arrived between unlocking
// and the CAS and we won the wake-up
// race, so wake up directly.
return false
} else if g2 != nil {
printlock()
println("runtime: got trace reader", g2, g2.goid)
throw("unexpected trace reader")
}
return true
}, nil, waitReasonTraceReaderBlocked, traceEvGoBlock, 2)
goto top
}
return buf
}
// readTrace0 is ReadTrace's continuation on g0. This must run on the
// system stack because it acquires trace.lock.
//
//go:systemstack
func readTrace0() (buf []byte, park bool) {
if raceenabled {
// g0 doesn't have a race context. Borrow the user G's.
if getg().racectx != 0 {
throw("expected racectx == 0")
}
getg().racectx = getg().m.curg.racectx
// (This defer should get open-coded, which is safe on
// the system stack.)
defer func() { getg().racectx = 0 }()
}
// This function may need to lock trace.lock recursively
// (goparkunlock -> traceGoPark -> traceEvent -> traceFlush).
// To allow this we use trace.lockOwner.
// Also, this function must not allocate while holding trace.lock:
// allocation can call into the heap allocator, which will try to emit a
// trace event while holding the heap lock.
lock(&trace.lock)
trace.lockOwner = getg().m.curg
if trace.reader.Load() != nil {
// More than one goroutine is reading the trace. This is bad,
// but we would rather not crash the program because of tracing,
// since tracing can be enabled at runtime on production servers.
trace.lockOwner = nil
unlock(&trace.lock)
println("runtime: ReadTrace called from multiple goroutines simultaneously")
return nil, false
}
// Recycle the old buffer.
if buf := trace.reading; buf != 0 {
buf.ptr().link = trace.empty
trace.empty = buf
trace.reading = 0
}
// Write trace header.
if !trace.headerWritten {
trace.headerWritten = true
trace.lockOwner = nil
unlock(&trace.lock)
return []byte("go 1.19 trace\x00\x00\x00"), false
}
// Optimistically look for CPU profile samples. This may write new stack
// records, and may write new tracing buffers.
if !trace.footerWritten && !trace.shutdown {
traceReadCPU()
}
// Wait for new data.
if trace.fullHead == 0 && !trace.shutdown {
// We don't simply use a note because the scheduler
// executes this goroutine directly when it wakes up
// (also a note would consume an M).
trace.lockOwner = nil
unlock(&trace.lock)
return nil, true
}
newFull:
assertLockHeld(&trace.lock)
// Write a buffer.
if trace.fullHead != 0 {
buf := traceFullDequeue()
trace.reading = buf
trace.lockOwner = nil
unlock(&trace.lock)
return buf.ptr().arr[:buf.ptr().pos], false
}
// Write footer with timer frequency.
if !trace.footerWritten {
trace.footerWritten = true
// Use float64 because (trace.ticksEnd - trace.ticksStart) * 1e9 can overflow int64.
freq := float64(trace.ticksEnd-trace.ticksStart) * 1e9 / float64(trace.timeEnd-trace.timeStart) / traceTickDiv
if freq <= 0 {
throw("trace: ReadTrace got invalid frequency")
}
trace.lockOwner = nil
unlock(&trace.lock)
// Write frequency event.
bufp := traceFlush(0, 0)
buf := bufp.ptr()
buf.byte(traceEvFrequency | 0<<traceArgCountShift)
buf.varint(uint64(freq))
// Dump stack table.
// This will emit a bunch of full buffers; we will pick them up
// on the next iteration.
bufp = trace.stackTab.dump(bufp)
// Flush final buffer.
lock(&trace.lock)
traceFullQueue(bufp)
goto newFull // trace.lock should be held at newFull
}
// Done.
if trace.shutdown {
trace.lockOwner = nil
unlock(&trace.lock)
if raceenabled {
// Model synchronization on trace.shutdownSema, which the race
// detector does not see. This is required to avoid false
// race reports on the writer passed to trace.Start.
racerelease(unsafe.Pointer(&trace.shutdownSema))
}
// trace.enabled is already reset, so we can call traceable functions.
semrelease(&trace.shutdownSema)
return nil, false
}
// Also bad, but see the comment above.
trace.lockOwner = nil
unlock(&trace.lock)
println("runtime: spurious wakeup of trace reader")
return nil, false
}
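// An illustrative consumer-side check (not part of the runtime;
// isTraceHeaderSketch is a hypothetical name) for the 16-byte header
// chunk emitted by readTrace0 above:
func isTraceHeaderSketch(b []byte) bool {
	return len(b) == 16 && string(b) == "go 1.19 trace\x00\x00\x00"
}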
// traceReader returns the trace reader that should be woken up, if any.
// Callers should first check that trace.enabled or trace.shutdown is set.
//
// This must run on the system stack because it acquires trace.lock.
//
//go:systemstack
func traceReader() *g {
// Optimistic check first
if traceReaderAvailable() == nil {
return nil
}
lock(&trace.lock)
gp := traceReaderAvailable()
if gp == nil || !trace.reader.CompareAndSwapNoWB(gp, nil) {
unlock(&trace.lock)
return nil
}
unlock(&trace.lock)
return gp
}
// traceReaderAvailable returns the trace reader if it is not currently
// scheduled and should be. Callers should first check that trace.enabled
// or trace.shutdown is set.
func traceReaderAvailable() *g {
if trace.fullHead != 0 || trace.shutdown {
return trace.reader.Load()
}
return nil
}
// traceProcFree frees trace buffer associated with pp.
//
// This must run on the system stack because it acquires trace.lock.
//
//go:systemstack
func traceProcFree(pp *p) {
buf := pp.tracebuf
pp.tracebuf = 0
if buf == 0 {
return
}
lock(&trace.lock)
traceFullQueue(buf)
unlock(&trace.lock)
}
// traceFullQueue queues buf into queue of full buffers.
func traceFullQueue(buf traceBufPtr) {
buf.ptr().link = 0
if trace.fullHead == 0 {
trace.fullHead = buf
} else {
trace.fullTail.ptr().link = buf
}
trace.fullTail = buf
}
// traceFullDequeue dequeues from queue of full buffers.
func traceFullDequeue() traceBufPtr {
buf := trace.fullHead
if buf == 0 {
return 0
}
trace.fullHead = buf.ptr().link
if trace.fullHead == 0 {
trace.fullTail = 0
}
buf.ptr().link = 0
return buf
}
// traceEvent writes a single event to the trace buffer, flushing the buffer if necessary.
// ev is the event type.
// If skip > 0, write the current stack id as the last argument (skipping skip top frames).
// If skip = 0, this event type should contain a stack, but we don't want
// to collect and remember it for this particular call.
func traceEvent(ev byte, skip int, args ...uint64) {
mp, pid, bufp := traceAcquireBuffer()
// Double-check trace.enabled now that we've done m.locks++ and acquired bufLock.
// This protects from races between traceEvent and StartTrace/StopTrace.
// The caller checked that trace.enabled == true, but trace.enabled might have been
// turned off between the check and now. Check again. traceLockBuffer did mp.locks++,
// StopTrace does stopTheWorld, and stopTheWorld waits for mp.locks to go back to zero,
// so if we see trace.enabled == true now, we know it's true for the rest of the function.
// Exitsyscall can run even during stopTheWorld. The race with StartTrace/StopTrace
// during tracing in exitsyscall is resolved by locking trace.bufLock in traceLockBuffer.
//
// Note trace_userTaskCreate runs the same check.
if !trace.enabled && !mp.startingtrace {
traceReleaseBuffer(pid)
return
}
if skip > 0 {
if getg() == mp.curg {
skip++ // +1 because stack is captured in traceEventLocked.
}
}
traceEventLocked(0, mp, pid, bufp, ev, 0, skip, args...)
traceReleaseBuffer(pid)
}
// traceEventLocked writes a single event of type ev to the trace buffer bufp,
// flushing the buffer if necessary. pid is the id of the current P, or
// traceGlobProc if we're tracing without a real P.
//
// Preemption is disabled, and if running without a real P the global tracing
// buffer is locked.
//
// Event types that do not include a stack set skip to -1. Event types that
// include a stack may explicitly reference a stackID from the trace.stackTab
// (obtained by an earlier call to traceStackID). Without an explicit stackID,
// this function will automatically capture the stack of the goroutine currently
// running on mp, skipping skip top frames or, if skip is 0, writing out an
// empty stack record.
//
// It records the event's args to the traceBuf, and also makes an effort to
// reserve extraBytes bytes of additional space immediately following the event,
// in the same traceBuf.
func traceEventLocked(extraBytes int, mp *m, pid int32, bufp *traceBufPtr, ev byte, stackID uint32, skip int, args ...uint64) {
buf := bufp.ptr()
// TODO: test on non-zero extraBytes param.
maxSize := 2 + 5*traceBytesPerNumber + extraBytes // event type, length, sequence, timestamp, stack id and two additional params
if buf == nil || len(buf.arr)-buf.pos < maxSize {
systemstack(func() {
buf = traceFlush(traceBufPtrOf(buf), pid).ptr()
})
bufp.set(buf)
}
// NOTE: ticks might be the same after tick division, even though the real
// cputicks grows monotonically.
ticks := uint64(cputicks()) / traceTickDiv
tickDiff := ticks - buf.lastTicks
if tickDiff == 0 {
ticks = buf.lastTicks + 1
tickDiff = 1
}
buf.lastTicks = ticks
narg := byte(len(args))
if stackID != 0 || skip >= 0 {
narg++
}
// We have only 2 bits for the number of arguments.
// If the number is >= 3, then the event type is followed by the event length in bytes.
if narg > 3 {
narg = 3
}
startPos := buf.pos
buf.byte(ev | narg<<traceArgCountShift)
var lenp *byte
if narg == 3 {
// Reserve a byte for the length, assuming that length < 128.
buf.varint(0)
lenp = &buf.arr[buf.pos-1]
}
buf.varint(tickDiff)
for _, a := range args {
buf.varint(a)
}
if stackID != 0 {
buf.varint(uint64(stackID))
} else if skip == 0 {
buf.varint(0)
} else if skip > 0 {
buf.varint(traceStackID(mp, buf.stk[:], skip))
}
evSize := buf.pos - startPos
if evSize > maxSize {
throw("invalid length of trace event")
}
if lenp != nil {
// Fill in actual length.
*lenp = byte(evSize - 2)
}
}
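// The helper below is an illustrative sketch (not part of the runtime;
// encodeEventByteSketch is a hypothetical name) of the first-byte layout
// written by traceEventLocked above: the low 6 bits hold the event type
// and the top 2 bits hold the argument count, saturated at 3 to mean
// "an explicit length byte follows".
func encodeEventByteSketch(ev byte, nargs int) byte {
	narg := byte(nargs)
	if narg > 3 {
		narg = 3 // only 2 bits; >= 3 args switch to the length-prefixed form
	}
	return ev | narg<<traceArgCountShift
}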
// traceCPUSample writes a CPU profile sample stack to the execution tracer's
// profiling buffer. It is called from a signal handler, so is limited in what
// it can do.
func traceCPUSample(gp *g, pp *p, stk []uintptr) {
if !trace.enabled {
// Tracing is usually turned off; don't spend time acquiring the signal
// lock unless it's active.
return
}
// Match the clock used in traceEventLocked
now := cputicks()
// The "header" here is the ID of the P that was running the profiled code,
// followed by the ID of the goroutine. (For normal CPU profiling, it's
// usually the number of samples with the given stack.) Near syscalls, pp
// may be nil. Reporting goid of 0 is fine for either g0 or a nil gp.
var hdr [2]uint64
if pp != nil {
// Overflow records in profBuf have all header values set to zero. Make
// sure that real headers have at least one bit set.
hdr[0] = uint64(pp.id)<<1 | 0b1
} else {
hdr[0] = 0b10
}
if gp != nil {
hdr[1] = gp.goid
}
// Allow only one writer at a time
for !trace.signalLock.CompareAndSwap(0, 1) {
// TODO: Is it safe to osyield here? https://go.dev/issue/52672
osyield()
}
if log := (*profBuf)(atomic.Loadp(unsafe.Pointer(&trace.cpuLogWrite))); log != nil {
// Note: we don't pass a tag pointer here (how should profiling tags
// interact with the execution tracer?), but if we did we'd need to be
// careful about write barriers. See the long comment in profBuf.write.
log.write(nil, now, hdr[:], stk)
}
trace.signalLock.Store(0)
}
func traceReadCPU() {
bufp := &trace.cpuLogBuf
for {
data, tags, _ := trace.cpuLogRead.read(profBufNonBlocking)
if len(data) == 0 {
break
}
for len(data) > 0 {
if len(data) < 4 || data[0] > uint64(len(data)) {
break // truncated profile
}
if data[0] < 4 || tags != nil && len(tags) < 1 {
break // malformed profile
}
if len(tags) < 1 {
break // mismatched profile records and tags
}
timestamp := data[1]
ppid := data[2] >> 1
if hasP := (data[2] & 0b1) != 0; !hasP {
ppid = ^uint64(0)
}
goid := data[3]
stk := data[4:data[0]]
empty := len(stk) == 1 && data[2] == 0 && data[3] == 0
data = data[data[0]:]
// No support here for reporting goroutine tags at the moment; if
// that information is to be part of the execution trace, we'd
// probably want to see when the tags are applied and when they
// change, instead of only seeing them when we get a CPU sample.
tags = tags[1:]
if empty {
// Looks like an overflow record from the profBuf. Not much to
// do here, we only want to report full records.
//
// TODO: should we start a goroutine to drain the profBuf,
// rather than relying on a high-enough volume of tracing events
// to keep ReadTrace busy? https://go.dev/issue/52674
continue
}
buf := bufp.ptr()
if buf == nil {
systemstack(func() {
*bufp = traceFlush(*bufp, 0)
})
buf = bufp.ptr()
}
for i := range stk {
if i >= len(buf.stk) {
break
}
buf.stk[i] = uintptr(stk[i])
}
stackID := trace.stackTab.put(buf.stk[:len(stk)])
traceEventLocked(0, nil, 0, bufp, traceEvCPUSample, stackID, 1, timestamp/traceTickDiv, ppid, goid)
}
}
}
// traceStackID captures the stack of mp's current goroutine (skipping skip
// top frames) into buf, registers it in trace.stackTab, and returns its id.
func traceStackID(mp *m, buf []uintptr, skip int) uint64 {
gp := getg()
curgp := mp.curg
var nstk int
if curgp == gp {
nstk = callers(skip+1, buf)
} else if curgp != nil {
nstk = gcallers(curgp, skip, buf)
}
if nstk > 0 {
nstk-- // skip runtime.goexit
}
if nstk > 0 && curgp.goid == 1 {
nstk-- // skip runtime.main
}
id := trace.stackTab.put(buf[:nstk])
return uint64(id)
}
// traceAcquireBuffer returns the trace buffer to use and, if necessary, locks it.
func traceAcquireBuffer() (mp *m, pid int32, bufp *traceBufPtr) {
// Any time we acquire a buffer, we may end up flushing it,
// but flushes are rare. Record the lock edge even if it
// doesn't happen this time.
lockRankMayTraceFlush()
mp = acquirem()
if p := mp.p.ptr(); p != nil {
return mp, p.id, &p.tracebuf
}
lock(&trace.bufLock)
return mp, traceGlobProc, &trace.buf
}
// traceReleaseBuffer releases a buffer previously acquired with traceAcquireBuffer.
func traceReleaseBuffer(pid int32) {
if pid == traceGlobProc {
unlock(&trace.bufLock)
}
releasem(getg().m)
}
// lockRankMayTraceFlush records the lock ranking effects of a
// potential call to traceFlush.
func lockRankMayTraceFlush() {
owner := trace.lockOwner
dolock := owner == nil || owner != getg().m.curg
if dolock {
lockWithRankMayAcquire(&trace.lock, getLockRank(&trace.lock))
}
}
// traceFlush puts buf onto stack of full buffers and returns an empty buffer.
//
// This must run on the system stack because it acquires trace.lock.
//
//go:systemstack
func traceFlush(buf traceBufPtr, pid int32) traceBufPtr {
owner := trace.lockOwner
dolock := owner == nil || owner != getg().m.curg
if dolock {
lock(&trace.lock)
}
if buf != 0 {
traceFullQueue(buf)
}
if trace.empty != 0 {
buf = trace.empty
trace.empty = buf.ptr().link
} else {
buf = traceBufPtr(sysAlloc(unsafe.Sizeof(traceBuf{}), &memstats.other_sys))
if buf == 0 {
throw("trace: out of memory")
}
}
bufp := buf.ptr()
bufp.link.set(nil)
bufp.pos = 0
// initialize the buffer for a new batch
ticks := uint64(cputicks()) / traceTickDiv
if ticks == bufp.lastTicks {
ticks = bufp.lastTicks + 1
}
bufp.lastTicks = ticks
bufp.byte(traceEvBatch | 1<<traceArgCountShift)
bufp.varint(uint64(pid))
bufp.varint(ticks)
if dolock {
unlock(&trace.lock)
}
return buf
}
// traceString adds a string to trace.strings and returns its id.
func traceString(bufp *traceBufPtr, pid int32, s string) (uint64, *traceBufPtr) {
if s == "" {
return 0, bufp
}
lock(&trace.stringsLock)
if raceenabled {
// raceacquire is necessary because the map access
// below is race annotated.
raceacquire(unsafe.Pointer(&trace.stringsLock))
}
if id, ok := trace.strings[s]; ok {
if raceenabled {
racerelease(unsafe.Pointer(&trace.stringsLock))
}
unlock(&trace.stringsLock)
return id, bufp
}
trace.stringSeq++
id := trace.stringSeq
trace.strings[s] = id
if raceenabled {
racerelease(unsafe.Pointer(&trace.stringsLock))
}
unlock(&trace.stringsLock)
// The memory allocation above may trigger tracing and
// cause *bufp to change. The following code works with *bufp,
// so there must be no memory allocation or other activity
// that triggers tracing after this point.
buf := bufp.ptr()
size := 1 + 2*traceBytesPerNumber + len(s)
if buf == nil || len(buf.arr)-buf.pos < size {
systemstack(func() {
buf = traceFlush(traceBufPtrOf(buf), pid).ptr()
bufp.set(buf)
})
}
buf.byte(traceEvString)
buf.varint(id)
// Double-check that the string and its length can fit.
// Otherwise, truncate the string.
slen := len(s)
if room := len(buf.arr) - buf.pos; room < slen+traceBytesPerNumber {
slen = room
}
buf.varint(uint64(slen))
buf.pos += copy(buf.arr[buf.pos:], s[:slen])
bufp.set(buf)
return id, bufp
}
// varint appends v to buf in little-endian-base-128 encoding.
func (buf *traceBuf) varint(v uint64) {
pos := buf.pos
for ; v >= 0x80; v >>= 7 {
buf.arr[pos] = 0x80 | byte(v)
pos++
}
buf.arr[pos] = byte(v)
pos++
buf.pos = pos
}
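// A decoding counterpart for the encoding above (illustrative only; the
// runtime itself never decodes traces, and varintDecodeSketch is a
// hypothetical name). It returns the decoded value and the number of
// bytes consumed, or (0, 0) on truncated input.
func varintDecodeSketch(b []byte) (v uint64, n int) {
	for shift := uint(0); n < len(b); shift += 7 {
		x := b[n]
		n++
		v |= uint64(x&0x7f) << shift
		if x&0x80 == 0 {
			return v, n
		}
	}
	return 0, 0
}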
// varintAt writes varint v at byte position pos in buf. This always
// consumes traceBytesPerNumber bytes. This is intended for when the
// caller needs to reserve space for a varint but can't populate it
// until later.
func (buf *traceBuf) varintAt(pos int, v uint64) {
for i := 0; i < traceBytesPerNumber; i++ {
if i < traceBytesPerNumber-1 {
buf.arr[pos] = 0x80 | byte(v)
} else {
buf.arr[pos] = byte(v)
}
v >>= 7
pos++
}
}
// byte appends v to buf.
func (buf *traceBuf) byte(v byte) {
buf.arr[buf.pos] = v
buf.pos++
}
// traceStackTable maps stack traces (arrays of PCs) to unique uint32 ids.
// It is lock-free for reading.
type traceStackTable struct {
lock mutex // Must be acquired on the system stack
seq uint32
mem traceAlloc
tab [1 << 13]traceStackPtr
}
// traceStack is a single stack in traceStackTable.
type traceStack struct {
link traceStackPtr
hash uintptr
id uint32
n int
stk [0]uintptr // real type [n]uintptr
}
type traceStackPtr uintptr
func (tp traceStackPtr) ptr() *traceStack { return (*traceStack)(unsafe.Pointer(tp)) }
// stack returns slice of PCs.
func (ts *traceStack) stack() []uintptr {
return (*[traceStackSize]uintptr)(unsafe.Pointer(&ts.stk))[:ts.n]
}
// put returns a unique id for the stack trace pcs, caching it in the table
// if it sees the trace for the first time.
func (tab *traceStackTable) put(pcs []uintptr) uint32 {
if len(pcs) == 0 {
return 0
}
hash := memhash(unsafe.Pointer(&pcs[0]), 0, uintptr(len(pcs))*unsafe.Sizeof(pcs[0]))
// First, search the hashtable without the mutex.
if id := tab.find(pcs, hash); id != 0 {
return id
}
// Now, double-check under the mutex.
// Switch to the system stack so we can acquire tab.lock.
var id uint32
systemstack(func() {
lock(&tab.lock)
if id = tab.find(pcs, hash); id != 0 {
unlock(&tab.lock)
return
}
// Create new record.
tab.seq++
stk := tab.newStack(len(pcs))
stk.hash = hash
stk.id = tab.seq
id = stk.id
stk.n = len(pcs)
stkpc := stk.stack()
copy(stkpc, pcs)
part := int(hash % uintptr(len(tab.tab)))
stk.link = tab.tab[part]
atomicstorep(unsafe.Pointer(&tab.tab[part]), unsafe.Pointer(stk))
unlock(&tab.lock)
})
return id
}
// find checks if the stack trace pcs is already present in the table.
func (tab *traceStackTable) find(pcs []uintptr, hash uintptr) uint32 {
part := int(hash % uintptr(len(tab.tab)))
Search:
for stk := tab.tab[part].ptr(); stk != nil; stk = stk.link.ptr() {
if stk.hash == hash && stk.n == len(pcs) {
for i, stkpc := range stk.stack() {
if stkpc != pcs[i] {
continue Search
}
}
return stk.id
}
}
return 0
}
// newStack allocates a new stack of size n.
func (tab *traceStackTable) newStack(n int) *traceStack {
return (*traceStack)(tab.mem.alloc(unsafe.Sizeof(traceStack{}) + uintptr(n)*goarch.PtrSize))
}
// traceFrames returns the frames corresponding to pcs. It may
// allocate and may emit trace events.
func traceFrames(bufp traceBufPtr, pcs []uintptr) ([]traceFrame, traceBufPtr) {
frames := make([]traceFrame, 0, len(pcs))
ci := CallersFrames(pcs)
for {
var frame traceFrame
f, more := ci.Next()
frame, bufp = traceFrameForPC(bufp, 0, f)
frames = append(frames, frame)
if !more {
return frames, bufp
}
}
}
// dump writes all previously cached stacks to trace buffers,
// releases all memory and resets state.
//
// This must run on the system stack because it calls traceFlush.
//
//go:systemstack
func (tab *traceStackTable) dump(bufp traceBufPtr) traceBufPtr {
for i := range tab.tab {
stk := tab.tab[i].ptr()
for ; stk != nil; stk = stk.link.ptr() {
var frames []traceFrame
frames, bufp = traceFrames(bufp, stk.stack())
// Estimate the size of this record. This
// bound is pretty loose, but avoids counting
// lots of varint sizes.
maxSize := 1 + traceBytesPerNumber + (2+4*len(frames))*traceBytesPerNumber
// Make sure we have enough buffer space.
if buf := bufp.ptr(); len(buf.arr)-buf.pos < maxSize {
bufp = traceFlush(bufp, 0)
}
// Emit header, with space reserved for length.
buf := bufp.ptr()
buf.byte(traceEvStack | 3<<traceArgCountShift)
lenPos := buf.pos
buf.pos += traceBytesPerNumber
// Emit body.
recPos := buf.pos
buf.varint(uint64(stk.id))
buf.varint(uint64(len(frames)))
for _, frame := range frames {
buf.varint(uint64(frame.PC))
buf.varint(frame.funcID)
buf.varint(frame.fileID)
buf.varint(frame.line)
}
// Fill in size header.
buf.varintAt(lenPos, uint64(buf.pos-recPos))
}
}
tab.mem.drop()
*tab = traceStackTable{}
lockInit(&tab.lock, lockRankTraceStackTab)
return bufp
}
type traceFrame struct {
PC uintptr
funcID uint64
fileID uint64
line uint64
}
// traceFrameForPC records the frame information.
// It may allocate memory.
func traceFrameForPC(buf traceBufPtr, pid int32, f Frame) (traceFrame, traceBufPtr) {
bufp := &buf
var frame traceFrame
frame.PC = f.PC
fn := f.Function
const maxLen = 1 << 10
if len(fn) > maxLen {
fn = fn[len(fn)-maxLen:]
}
frame.funcID, bufp = traceString(bufp, pid, fn)
frame.line = uint64(f.Line)
file := f.File
if len(file) > maxLen {
file = file[len(file)-maxLen:]
}
frame.fileID, bufp = traceString(bufp, pid, file)
return frame, (*bufp)
}
// traceAlloc is a non-thread-safe region allocator.
// It holds a linked list of traceAllocBlock.
type traceAlloc struct {
head traceAllocBlockPtr
off uintptr
}
// traceAllocBlock is a block in traceAlloc.
//
// traceAllocBlock is allocated from non-GC'd memory, so it must not
// contain heap pointers. Writes to pointers to traceAllocBlocks do
// not need write barriers.
type traceAllocBlock struct {
_ sys.NotInHeap
next traceAllocBlockPtr
data [64<<10 - goarch.PtrSize]byte
}
// TODO: Since traceAllocBlock now embeds runtime/internal/sys.NotInHeap, this isn't necessary.
type traceAllocBlockPtr uintptr
func (p traceAllocBlockPtr) ptr() *traceAllocBlock { return (*traceAllocBlock)(unsafe.Pointer(p)) }
func (p *traceAllocBlockPtr) set(x *traceAllocBlock) { *p = traceAllocBlockPtr(unsafe.Pointer(x)) }
// alloc allocates an n-byte block.
func (a *traceAlloc) alloc(n uintptr) unsafe.Pointer {
n = alignUp(n, goarch.PtrSize)
if a.head == 0 || a.off+n > uintptr(len(a.head.ptr().data)) {
if n > uintptr(len(a.head.ptr().data)) {
throw("trace: alloc too large")
}
block := (*traceAllocBlock)(sysAlloc(unsafe.Sizeof(traceAllocBlock{}), &memstats.other_sys))
if block == nil {
throw("trace: out of memory")
}
block.next.set(a.head.ptr())
a.head.set(block)
a.off = 0
}
p := &a.head.ptr().data[a.off]
a.off += n
return unsafe.Pointer(p)
}
// drop frees all previously allocated memory and resets the allocator.
func (a *traceAlloc) drop() {
for a.head != 0 {
block := a.head.ptr()
a.head.set(block.next.ptr())
sysFree(unsafe.Pointer(block), unsafe.Sizeof(traceAllocBlock{}), &memstats.other_sys)
}
}
// The following functions write specific events to the trace.
func traceGomaxprocs(procs int32) {
traceEvent(traceEvGomaxprocs, 1, uint64(procs))
}
func traceProcStart() {
traceEvent(traceEvProcStart, -1, uint64(getg().m.id))
}
func traceProcStop(pp *p) {
// Sysmon and stopTheWorld can stop Ps blocked in syscalls;
// to handle this we temporarily employ the P.
mp := acquirem()
oldp := mp.p
mp.p.set(pp)
traceEvent(traceEvProcStop, -1)
mp.p = oldp
releasem(mp)
}
func traceGCStart() {
traceEvent(traceEvGCStart, 3, trace.seqGC)
trace.seqGC++
}
func traceGCDone() {
traceEvent(traceEvGCDone, -1)
}
func traceGCSTWStart(kind int) {
traceEvent(traceEvGCSTWStart, -1, uint64(kind))
}
func traceGCSTWDone() {
traceEvent(traceEvGCSTWDone, -1)
}
// traceGCSweepStart prepares to trace a sweep loop. This does not
// emit any events until traceGCSweepSpan is called.
//
// traceGCSweepStart must be paired with traceGCSweepDone and there
// must be no preemption points between these two calls.
func traceGCSweepStart() {
// Delay the actual GCSweepStart event until the first span
// sweep. If we don't sweep anything, don't emit any events.
pp := getg().m.p.ptr()
if pp.traceSweep {
throw("double traceGCSweepStart")
}
pp.traceSweep, pp.traceSwept, pp.traceReclaimed = true, 0, 0
}
// traceGCSweepSpan traces the sweep of a single page.
//
// This may be called outside a traceGCSweepStart/traceGCSweepDone
// pair; however, it will not emit any trace events in this case.
func traceGCSweepSpan(bytesSwept uintptr) {
pp := getg().m.p.ptr()
if pp.traceSweep {
if pp.traceSwept == 0 {
traceEvent(traceEvGCSweepStart, 1)
}
pp.traceSwept += bytesSwept
}
}
func traceGCSweepDone() {
pp := getg().m.p.ptr()
if !pp.traceSweep {
throw("missing traceGCSweepStart")
}
if pp.traceSwept != 0 {
traceEvent(traceEvGCSweepDone, -1, uint64(pp.traceSwept), uint64(pp.traceReclaimed))
}
pp.traceSweep = false
}
func traceGCMarkAssistStart() {
traceEvent(traceEvGCMarkAssistStart, 1)
}
func traceGCMarkAssistDone() {
traceEvent(traceEvGCMarkAssistDone, -1)
}
func traceGoCreate(newg *g, pc uintptr) {
newg.traceseq = 0
newg.tracelastp = getg().m.p
// +PCQuantum because traceFrameForPC expects return PCs and subtracts PCQuantum.
id := trace.stackTab.put([]uintptr{startPCforTrace(pc) + sys.PCQuantum})
traceEvent(traceEvGoCreate, 2, newg.goid, uint64(id))
}
func traceGoStart() {
gp := getg().m.curg
pp := gp.m.p
gp.traceseq++
if pp.ptr().gcMarkWorkerMode != gcMarkWorkerNotWorker {
traceEvent(traceEvGoStartLabel, -1, gp.goid, gp.traceseq, trace.markWorkerLabels[pp.ptr().gcMarkWorkerMode])
} else if gp.tracelastp == pp {
traceEvent(traceEvGoStartLocal, -1, gp.goid)
} else {
gp.tracelastp = pp
traceEvent(traceEvGoStart, -1, gp.goid, gp.traceseq)
}
}
func traceGoEnd() {
traceEvent(traceEvGoEnd, -1)
}
func traceGoSched() {
gp := getg()
gp.tracelastp = gp.m.p
traceEvent(traceEvGoSched, 1)
}
func traceGoPreempt() {
gp := getg()
gp.tracelastp = gp.m.p
traceEvent(traceEvGoPreempt, 1)
}
func traceGoPark(traceEv byte, skip int) {
if traceEv&traceFutileWakeup != 0 {
traceEvent(traceEvFutileWakeup, -1)
}
traceEvent(traceEv & ^traceFutileWakeup, skip)
}
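// Illustrative note: callers OR traceFutileWakeup into the event type they
// pass (e.g. a hypothetical traceGoPark(traceEvGoBlock|traceFutileWakeup, skip));
// traceGoPark above first emits traceEvFutileWakeup, then strips the flag
// before emitting the real park event.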
func traceGoUnpark(gp *g, skip int) {
pp := getg().m.p
gp.traceseq++
if gp.tracelastp == pp {
traceEvent(traceEvGoUnblockLocal, skip, gp.goid)
} else {
gp.tracelastp = pp
traceEvent(traceEvGoUnblock, skip, gp.goid, gp.traceseq)
}
}
func traceGoSysCall() {
traceEvent(traceEvGoSysCall, 1)
}
func traceGoSysExit(ts int64) {
if ts != 0 && ts < trace.ticksStart {
// There is a race between the code that initializes sysexitticks
// (in exitsyscall, which runs without a P, and therefore is not
// stopped with the rest of the world) and the code that initializes
// a new trace. The recorded sysexitticks must therefore be treated
// as "best effort". If they are valid for this trace, then great,
// use them for greater accuracy. But if they're not valid for this
// trace, assume that the trace was started after the actual syscall
// exit (but before we actually managed to start the goroutine,
// aka right now), and assign a fresh time stamp to keep the log consistent.
ts = 0
}
gp := getg().m.curg
gp.traceseq++
gp.tracelastp = gp.m.p
traceEvent(traceEvGoSysExit, -1, gp.goid, gp.traceseq, uint64(ts)/traceTickDiv)
}
func traceGoSysBlock(pp *p) {
// Sysmon and stopTheWorld can declare syscalls running on remote Ps as blocked;
// to handle this we temporarily employ the P.
mp := acquirem()
oldp := mp.p
mp.p.set(pp)
traceEvent(traceEvGoSysBlock, -1)
mp.p = oldp
releasem(mp)
}
func traceHeapAlloc(live uint64) {
traceEvent(traceEvHeapAlloc, -1, live)
}
func traceHeapGoal() {
heapGoal := gcController.heapGoal()
if heapGoal == ^uint64(0) {
// Heap-based triggering is disabled.
traceEvent(traceEvHeapGoal, -1, 0)
} else {
traceEvent(traceEvHeapGoal, -1, heapGoal)
}
}
// To access runtime functions from runtime/trace.
// See runtime/trace/annotation.go
//go:linkname trace_userTaskCreate runtime/trace.userTaskCreate
func trace_userTaskCreate(id, parentID uint64, taskType string) {
if !trace.enabled {
return
}
// Same as in traceEvent.
mp, pid, bufp := traceAcquireBuffer()
if !trace.enabled && !mp.startingtrace {
traceReleaseBuffer(pid)
return
}
typeStringID, bufp := traceString(bufp, pid, taskType)
traceEventLocked(0, mp, pid, bufp, traceEvUserTaskCreate, 0, 3, id, parentID, typeStringID)
traceReleaseBuffer(pid)
}
//go:linkname trace_userTaskEnd runtime/trace.userTaskEnd
func trace_userTaskEnd(id uint64) {
traceEvent(traceEvUserTaskEnd, 2, id)
}
//go:linkname trace_userRegion runtime/trace.userRegion
func trace_userRegion(id, mode uint64, name string) {
if !trace.enabled {
return
}
mp, pid, bufp := traceAcquireBuffer()
if !trace.enabled && !mp.startingtrace {
traceReleaseBuffer(pid)
return
}
nameStringID, bufp := traceString(bufp, pid, name)
traceEventLocked(0, mp, pid, bufp, traceEvUserRegion, 0, 3, id, mode, nameStringID)
traceReleaseBuffer(pid)
}
//go:linkname trace_userLog runtime/trace.userLog
func trace_userLog(id uint64, category, message string) {
if !trace.enabled {
return
}
mp, pid, bufp := traceAcquireBuffer()
if !trace.enabled && !mp.startingtrace {
traceReleaseBuffer(pid)
return
}
categoryID, bufp := traceString(bufp, pid, category)
extraSpace := traceBytesPerNumber + len(message) // extraSpace for the value string
traceEventLocked(extraSpace, mp, pid, bufp, traceEvUserLog, 0, 3, id, categoryID)
// traceEventLocked reserved extra space for val and len(val)
// in buf, so buf now has room for the following.
buf := bufp.ptr()
// double-check the message and its length can fit.
// Otherwise, truncate the message.
slen := len(message)
if room := len(buf.arr) - buf.pos; room < slen+traceBytesPerNumber {
slen = room
}
buf.varint(uint64(slen))
buf.pos += copy(buf.arr[buf.pos:], message[:slen])
traceReleaseBuffer(pid)
}
// startPCforTrace returns the start PC of a goroutine for tracing purposes.
// If pc is a wrapper, it returns the PC of the wrapped function. Otherwise it returns pc.
func startPCforTrace(pc uintptr) uintptr {
f := findfunc(pc)
if !f.valid() {
return pc // may happen for locked g in extra M since its pc is 0.
}
w := funcdata(f, _FUNCDATA_WrapInfo)
if w == nil {
return pc // not a wrapper
}
return f.datap.textAddr(*(*uint32)(w))
}
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package trace
import (
"context"
"fmt"
"sync/atomic"
_ "unsafe"
)
type traceContextKey struct{}
// NewTask creates a task instance with the type taskType and returns
// it along with a Context that carries the task.
// If the input context contains a task, the new task is its subtask.
//
// The taskType is used to classify task instances. Analysis tools
// like the Go execution tracer may assume there are only a bounded
// number of unique task types in the system.
//
// The returned Task's End method is used to mark the task's end.
// The trace tool measures task latency as the time between task creation
// and when the End method is called, and provides the latency
// distribution per task type.
// If the End method is called multiple times, only the first
// call is used in the latency measurement.
//
// ctx, task := trace.NewTask(ctx, "awesomeTask")
// trace.WithRegion(ctx, "preparation", prepWork)
// // preparation of the task
// go func() { // continue processing the task in a separate goroutine.
// defer task.End()
// trace.WithRegion(ctx, "remainingWork", remainingWork)
// }()
func NewTask(pctx context.Context, taskType string) (ctx context.Context, task *Task) {
pid := fromContext(pctx).id
id := newID()
userTaskCreate(id, pid, taskType)
s := &Task{id: id}
return context.WithValue(pctx, traceContextKey{}, s), s
// We allocate a new task even when
// tracing is disabled because the context and the task
// can be used across trace enable/disable boundaries,
// which complicates the problem.
//
// For example, consider the following scenario:
// - trace is enabled.
// - trace.WithRegion is called, so a new context ctx
// with a new region is created.
// - trace is disabled.
// - trace is enabled again.
// - trace APIs are called with the ctx. Is the ID in the task
// a valid one to use?
//
// TODO(hyangah): reduce the overhead at least when
// tracing is disabled. Maybe the id can embed a tracing
// round number and ignore ids generated from previous
// tracing round.
}
func fromContext(ctx context.Context) *Task {
if s, ok := ctx.Value(traceContextKey{}).(*Task); ok {
return s
}
return &bgTask
}
// Task is a data type for tracing a user-defined, logical operation.
type Task struct {
id uint64
// TODO(hyangah): record parent id?
}
// End marks the end of the operation represented by the Task.
func (t *Task) End() {
userTaskEnd(t.id)
}
var lastTaskID uint64 = 0 // task id issued last time
func newID() uint64 {
// TODO(hyangah): use per-P cache
return atomic.AddUint64(&lastTaskID, 1)
}
var bgTask = Task{id: uint64(0)}
// Log emits a one-off event with the given category and message.
// Category can be empty and the API assumes there are only a handful of
// unique categories in the system.
func Log(ctx context.Context, category, message string) {
id := fromContext(ctx).id
userLog(id, category, message)
}
// Logf is like Log, but the value is formatted using the specified format spec.
func Logf(ctx context.Context, category, format string, args ...any) {
if IsEnabled() {
// Ideally this should be just Log, but that will
// add one more frame in the stack trace.
id := fromContext(ctx).id
userLog(id, category, fmt.Sprintf(format, args...))
}
}
const (
regionStartCode = uint64(0)
regionEndCode = uint64(1)
)
// WithRegion starts a region associated with its calling goroutine, runs fn,
// and then ends the region. If the context carries a task, the region is
// associated with the task. Otherwise, the region is attached to the background
// task.
//
// The regionType is used to classify regions, so there should be only a
// handful of unique region types.
func WithRegion(ctx context.Context, regionType string, fn func()) {
// NOTE:
// WithRegion helps avoid misuse of the API, but in practice
// it is very restrictive:
// - Use of WithRegion means the stack traces captured at
// region start and end are identical.
// - Refactoring existing code to use WithRegion is sometimes
// hard and can make the code less readable.
// e.g. a code block nested deep in a loop with various
// exit points with return values
// - Refactoring the code to use this API with a closure can
// cause different GC behavior, such as retaining some parameters
// longer.
// This causes more churn in code than hoped, and sometimes
// makes the code less readable.
id := fromContext(ctx).id
userRegion(id, regionStartCode, regionType)
defer userRegion(id, regionEndCode, regionType)
fn()
}
// StartRegion starts a region and returns a function for marking the
// end of the region. The returned Region's End function must be called
// from the same goroutine where the region was started.
// Within each goroutine, regions must nest. That is, regions started
// after this region must be ended before this region can be ended.
// Recommended usage is
//
// defer trace.StartRegion(ctx, "myTracedRegion").End()
func StartRegion(ctx context.Context, regionType string) *Region {
if !IsEnabled() {
return noopRegion
}
id := fromContext(ctx).id
userRegion(id, regionStartCode, regionType)
return &Region{id, regionType}
}
// Region is a region of code whose execution time interval is traced.
type Region struct {
id uint64
regionType string
}
var noopRegion = &Region{}
// End marks the end of the traced code region.
func (r *Region) End() {
if r == noopRegion {
return
}
userRegion(r.id, regionEndCode, r.regionType)
}
// IsEnabled reports whether tracing is enabled.
// The information is advisory only. The tracing status
// may have changed by the time this function returns.
func IsEnabled() bool {
return tracing.enabled.Load()
}
//
// Function bodies are defined in runtime/trace.go
//
// emits UserTaskCreate event.
func userTaskCreate(id, parentID uint64, taskType string)
// emits UserTaskEnd event.
func userTaskEnd(id uint64)
// emits UserRegion event.
func userRegion(id, mode uint64, regionType string)
// emits UserLog event.
func userLog(id uint64, category, message string)
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package trace contains facilities for programs to generate traces
// for the Go execution tracer.
//
// # Tracing runtime activities
//
// The execution trace captures a wide range of execution events such as
// goroutine creation/blocking/unblocking, syscall enter/exit/block,
// GC-related events, changes of heap size, processor start/stop, etc.
// When CPU profiling is active, the execution tracer makes an effort to
// include those samples as well.
// A nanosecond-precision timestamp and a stack trace are
// captured for most events. The generated trace can be interpreted
// using `go tool trace`.
//
// Support for tracing tests and benchmarks built with the standard
// testing package is built into `go test`. For example, the following
// command runs the test in the current directory and writes the trace
// file (trace.out).
//
// go test -trace=trace.out
//
// This runtime/trace package provides APIs to add equivalent tracing
// support to a standalone program. See the Example that demonstrates
// how to use this API to enable tracing.
//
// There is also a standard HTTP interface to trace data. Adding the
// following line will install a handler under the /debug/pprof/trace URL
// to download a live trace:
//
// import _ "net/http/pprof"
//
// See the net/http/pprof package for more details about all of the
// debug endpoints installed by this import.
//
// # User annotation
//
// Package trace provides user annotation APIs that can be used to
// log interesting events during execution.
//
// There are three types of user annotations: log messages, regions,
// and tasks.
//
// Log emits a timestamped message to the execution trace along with
// additional information such as the category of the message and
// which goroutine called Log. The execution tracer provides UIs to filter
// and group goroutines using the log category and the message supplied
// in Log.
//
// A region is for logging a time interval during a goroutine's execution.
// By definition, a region starts and ends in the same goroutine.
// Regions can be nested to represent subintervals.
// For example, the following code records four regions in the execution
// trace to trace the durations of sequential steps in a cappuccino making
// operation.
//
// trace.WithRegion(ctx, "makeCappuccino", func() {
//
// // orderID allows identifying a specific order
// // among many cappuccino order region records.
// trace.Log(ctx, "orderID", orderID)
//
// trace.WithRegion(ctx, "steamMilk", steamMilk)
// trace.WithRegion(ctx, "extractCoffee", extractCoffee)
// trace.WithRegion(ctx, "mixMilkCoffee", mixMilkCoffee)
// })
//
// A task is a higher-level component that aids tracing of logical
// operations such as an RPC request, an HTTP request, or an
// interesting local operation which may require multiple goroutines
// working together. Since tasks can involve multiple goroutines,
// they are tracked via a context.Context object. NewTask creates
// a new task and embeds it in the returned context.Context object.
// Log messages and regions are attached to the task, if any, in the
// Context passed to Log and WithRegion.
//
// For example, assume that we decided to froth milk, extract coffee,
// and mix milk and coffee in separate goroutines. With a task,
// the trace tool can identify the goroutines involved in a specific
// cappuccino order.
//
// ctx, task := trace.NewTask(ctx, "makeCappuccino")
// trace.Log(ctx, "orderID", orderID)
//
// milk := make(chan bool)
// espresso := make(chan bool)
//
// go func() {
// trace.WithRegion(ctx, "steamMilk", steamMilk)
// milk <- true
// }()
// go func() {
// trace.WithRegion(ctx, "extractCoffee", extractCoffee)
// espresso <- true
// }()
// go func() {
// defer task.End() // When assemble is done, the order is complete.
// <-espresso
// <-milk
// trace.WithRegion(ctx, "mixMilkCoffee", mixMilkCoffee)
// }()
//
// The trace tool computes the latency of a task by measuring the
// time between the task creation and the task end and provides
// latency distributions for each task type found in the trace.
package trace
import (
"io"
"runtime"
"sync"
"sync/atomic"
)
// Start enables tracing for the current program.
// While tracing, the trace will be buffered and written to w.
// Start returns an error if tracing is already enabled.
func Start(w io.Writer) error {
tracing.Lock()
defer tracing.Unlock()
if err := runtime.StartTrace(); err != nil {
return err
}
go func() {
for {
data := runtime.ReadTrace()
if data == nil {
break
}
w.Write(data)
}
}()
tracing.enabled.Store(true)
return nil
}
// Stop stops the current tracing, if any.
// Stop only returns after all the writes for the trace have completed.
func Stop() {
tracing.Lock()
defer tracing.Unlock()
tracing.enabled.Store(false)
runtime.StopTrace()
}
var tracing struct {
sync.Mutex // gate mutators (Start, Stop)
enabled atomic.Bool
}
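// A typical use of Start/Stop (an illustrative sketch with minimal error
// handling), writing the trace to a file for later viewing with
// `go tool trace trace.out`:
//
//	f, err := os.Create("trace.out")
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer f.Close()
//	if err := trace.Start(f); err != nil {
//		log.Fatal(err)
//	}
//	defer trace.Stop()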
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"internal/bytealg"
"internal/goarch"
"runtime/internal/sys"
"unsafe"
)
// The code in this file implements stack trace walking for all architectures.
// The most important fact about a given architecture is whether it uses a link register.
// On systems with link registers, the prologue for a non-leaf function stores the
// incoming value of LR at the bottom of the newly allocated stack frame.
// On systems without link registers (x86), the architecture pushes a return PC during
// the call instruction, so the return PC ends up above the stack frame.
// In this file, the return PC is always called LR, no matter how it was found.
const usesLR = sys.MinFrameSize > 0
// Generic traceback. Handles runtime stack prints (pcbuf == nil),
// the runtime.Callers function (pcbuf != nil), as well as the garbage
// collector (callback != nil). A little clunky to merge these, but avoids
// duplicating the code and all its subtlety.
//
// The skip argument is only valid with pcbuf != nil and counts the number
// of logical frames to skip rather than physical frames (with inlining, a
// PC in pcbuf can represent multiple calls).
func gentraceback(pc0, sp0, lr0 uintptr, gp *g, skip int, pcbuf *uintptr, max int, callback func(*stkframe, unsafe.Pointer) bool, v unsafe.Pointer, flags uint) int {
if skip > 0 && callback != nil {
throw("gentraceback callback cannot be used with non-zero skip")
}
// Don't call this "g"; it's too easy get "g" and "gp" confused.
if ourg := getg(); ourg == gp && ourg == ourg.m.curg {
// The starting sp has been passed in as a uintptr, and the caller may
// have other uintptr-typed stack references as well.
// If during one of the calls that got us here or during one of the
// callbacks below the stack must be grown, all these uintptr references
// to the stack will not be updated, and gentraceback will continue
// to inspect the old stack memory, which may no longer be valid.
// Even if all the variables were updated correctly, it is not clear that
// we want to expose a traceback that begins on one stack and ends
// on another stack. That could confuse callers quite a bit.
// Instead, we require that gentraceback and any other function that
// accepts an sp for the current goroutine (typically obtained by
// calling getcallersp) must not run on that goroutine's stack but
// instead on the g0 stack.
throw("gentraceback cannot trace user goroutine on its own stack")
}
level, _, _ := gotraceback()
if pc0 == ^uintptr(0) && sp0 == ^uintptr(0) { // Signal to fetch saved values from gp.
if gp.syscallsp != 0 {
pc0 = gp.syscallpc
sp0 = gp.syscallsp
if usesLR {
lr0 = 0
}
} else {
pc0 = gp.sched.pc
sp0 = gp.sched.sp
if usesLR {
lr0 = gp.sched.lr
}
}
}
nprint := 0
var frame stkframe
frame.pc = pc0
frame.sp = sp0
if usesLR {
frame.lr = lr0
}
waspanic := false
cgoCtxt := gp.cgoCtxt
stack := gp.stack
printing := pcbuf == nil && callback == nil
// If the PC is zero, it's likely a nil function call.
// Start in the caller's frame.
if frame.pc == 0 {
if usesLR {
frame.pc = *(*uintptr)(unsafe.Pointer(frame.sp))
frame.lr = 0
} else {
frame.pc = uintptr(*(*uintptr)(unsafe.Pointer(frame.sp)))
frame.sp += goarch.PtrSize
}
}
// runtime/internal/atomic functions call into kernel helpers on
// arm < 7. See runtime/internal/atomic/sys_linux_arm.s.
//
// Start in the caller's frame.
if GOARCH == "arm" && goarm < 7 && GOOS == "linux" && frame.pc&0xffff0000 == 0xffff0000 {
// Note that the calls are simple BL without pushing the return
// address, so we use LR directly.
//
// The kernel helpers are frameless leaf functions, so SP and
// LR are not touched.
frame.pc = frame.lr
frame.lr = 0
}
f := findfunc(frame.pc)
if !f.valid() {
if callback != nil || printing {
print("runtime: g ", gp.goid, ": unknown pc ", hex(frame.pc), "\n")
tracebackHexdump(stack, &frame, 0)
}
if callback != nil {
throw("unknown pc")
}
return 0
}
frame.fn = f
var cache pcvalueCache
lastFuncID := funcID_normal
n := 0
for n < max {
// Typically:
// pc is the PC of the running function.
// sp is the stack pointer at that program counter.
// fp is the frame pointer (caller's stack pointer) at that program counter, or nil if unknown.
// stk is the stack containing sp.
// The caller's program counter is lr, unless lr is zero, in which case it is *(uintptr*)sp.
f = frame.fn
if f.pcsp == 0 {
// No frame information; must be an external function, like race support.
// See golang.org/issue/13568.
break
}
// Compute function info flags.
flag := f.flag
if f.funcID == funcID_cgocallback {
// cgocallback does write SP to switch from the g0 to the curg stack,
// but it carefully arranges that during the transition BOTH stacks
// have a cgocallback frame valid for unwinding through.
// So we don't need to exclude it with the other SP-writing functions.
flag &^= funcFlag_SPWRITE
}
if frame.pc == pc0 && frame.sp == sp0 && pc0 == gp.syscallpc && sp0 == gp.syscallsp {
// Some Syscall functions write to SP, but they do so only after
// saving the entry PC/SP using entersyscall.
// Since we are using the entry PC/SP, the later SP write doesn't matter.
flag &^= funcFlag_SPWRITE
}
// Found an actual function.
// Derive frame pointer and link register.
if frame.fp == 0 {
// Jump over system stack transitions. If we're on g0 and there's a user
// goroutine, try to jump. Otherwise this is a regular call.
// We also defensively check that this won't switch M's on us,
// which could happen at critical points in the scheduler.
// This ensures gp.m doesn't change from a stack jump.
if flags&_TraceJumpStack != 0 && gp == gp.m.g0 && gp.m.curg != nil && gp.m.curg.m == gp.m {
switch f.funcID {
case funcID_morestack:
// morestack does not return normally -- newstack()
// gogo's to curg.sched. Match that.
// This keeps morestack() from showing up in the backtrace,
// but that makes some sense since it'll never be returned
// to.
gp = gp.m.curg
frame.pc = gp.sched.pc
frame.fn = findfunc(frame.pc)
f = frame.fn
flag = f.flag
frame.lr = gp.sched.lr
frame.sp = gp.sched.sp
stack = gp.stack
cgoCtxt = gp.cgoCtxt
case funcID_systemstack:
// systemstack returns normally, so just follow the
// stack transition.
if usesLR && funcspdelta(f, frame.pc, &cache) == 0 {
// We're at the function prologue and the stack
// switch hasn't happened, or epilogue where we're
// about to return. Just unwind normally.
// Do this only on LR machines because on x86
// systemstack doesn't have an SP delta (the CALL
// instruction opens the frame), therefore no way
// to check.
flag &^= funcFlag_SPWRITE
break
}
gp = gp.m.curg
frame.sp = gp.sched.sp
stack = gp.stack
cgoCtxt = gp.cgoCtxt
flag &^= funcFlag_SPWRITE
}
}
frame.fp = frame.sp + uintptr(funcspdelta(f, frame.pc, &cache))
if !usesLR {
// On x86, call instruction pushes return PC before entering new function.
frame.fp += goarch.PtrSize
}
}
var flr funcInfo
if flag&funcFlag_TOPFRAME != 0 {
// This function marks the top of the stack. Stop the traceback.
frame.lr = 0
flr = funcInfo{}
} else if flag&funcFlag_SPWRITE != 0 && (callback == nil || n > 0) {
// The function we are in does a write to SP that we don't know
// how to encode in the spdelta table. Examples include context
// switch routines like runtime.gogo but also any code that switches
// to the g0 stack to run host C code. Since we can't reliably unwind
// the SP (we might not even be on the stack we think we are),
// we stop the traceback here.
// This only applies for profiling signals (callback == nil).
//
// For a GC stack traversal (callback != nil), we should only see
// a function when it has voluntarily preempted itself on entry
// during the stack growth check. In that case, the function has
// not yet had a chance to do any writes to SP and is safe to unwind.
// isAsyncSafePoint does not allow assembly functions to be async preempted,
// and preemptPark double-checks that SPWRITE functions are not async preempted.
// So for GC stack traversal we leave things alone (this if body does not execute for n == 0)
// at the bottom frame of the stack. But farther up the stack we'd better not
// find any.
if callback != nil {
println("traceback: unexpected SPWRITE function", funcname(f))
throw("traceback")
}
frame.lr = 0
flr = funcInfo{}
} else {
var lrPtr uintptr
if usesLR {
if n == 0 && frame.sp < frame.fp || frame.lr == 0 {
lrPtr = frame.sp
frame.lr = *(*uintptr)(unsafe.Pointer(lrPtr))
}
} else {
if frame.lr == 0 {
lrPtr = frame.fp - goarch.PtrSize
frame.lr = uintptr(*(*uintptr)(unsafe.Pointer(lrPtr)))
}
}
flr = findfunc(frame.lr)
if !flr.valid() {
// This happens if you get a profiling interrupt at just the wrong time.
// In that context it is okay to stop early.
// But if callback is set, we're doing a garbage collection and must
// get everything, so crash loudly.
doPrint := printing
if doPrint && gp.m.incgo && f.funcID == funcID_sigpanic {
// We can inject sigpanic calls directly into C code,
// in which case we'll see a C return PC. Don't complain.
doPrint = false
}
if callback != nil || doPrint {
print("runtime: g ", gp.goid, ": unexpected return pc for ", funcname(f), " called from ", hex(frame.lr), "\n")
tracebackHexdump(stack, &frame, lrPtr)
}
if callback != nil {
throw("unknown caller pc")
}
}
}
frame.varp = frame.fp
if !usesLR {
// On x86, call instruction pushes return PC before entering new function.
frame.varp -= goarch.PtrSize
}
// For architectures with frame pointers, if there's
// a frame, then there's a saved frame pointer here.
//
// NOTE: This code is not as general as it looks.
// On x86, the ABI is to save the frame pointer word at the
// top of the stack frame, so we have to back down over it.
// On arm64, the frame pointer should be at the bottom of
// the stack (with R29 (aka FP) = RSP), in which case we would
// not want to do the subtraction here. But we started out without
// any frame pointer, and when we wanted to add it, we didn't
// want to break all the assembly doing direct writes to 8(RSP)
// to set the first parameter to a called function.
// So we decided to write the FP link *below* the stack pointer
// (with R29 = RSP - 8 in Go functions).
// This is technically ABI-compatible but not standard.
// And it happens to end up mimicking the x86 layout.
// Other architectures may make different decisions.
if frame.varp > frame.sp && framepointer_enabled {
frame.varp -= goarch.PtrSize
}
frame.argp = frame.fp + sys.MinFrameSize
// Determine frame's 'continuation PC', where it can continue.
// Normally this is the return address on the stack, but if sigpanic
// is immediately below this function on the stack, then the frame
// stopped executing due to a trap, and frame.pc is probably not
// a safe point for looking up liveness information. In this panicking case,
// the function either doesn't return at all (if it has no defers or if the
// defers do not recover) or it returns from one of the calls to
// deferproc a second time (if the corresponding deferred func recovers).
// In the latter case, use a deferreturn call site as the continuation pc.
frame.continpc = frame.pc
if waspanic {
if frame.fn.deferreturn != 0 {
frame.continpc = frame.fn.entry() + uintptr(frame.fn.deferreturn) + 1
// Note: this may perhaps keep return variables alive longer than
// strictly necessary, as we are using "function has a defer statement"
// as a proxy for "function actually deferred something". It seems
// to be a minor drawback. (We used to actually look through the
// gp._defer for a defer corresponding to this function, but that
// is hard to do with defer records on the stack during a stack copy.)
// Note: the +1 is to offset the -1 that
// stack.go:getStackMap does to back up a return
// address to make sure the pc is in the CALL instruction.
} else {
frame.continpc = 0
}
}
if callback != nil {
if !callback((*stkframe)(noescape(unsafe.Pointer(&frame))), v) {
return n
}
}
if pcbuf != nil {
pc := frame.pc
// Back up to the CALL instruction to read inlining info (same logic as the printing case below).
tracepc := pc
// Normally, pc is a return address. In that case, we want to look up
// file/line information using pc-1, because that is the pc of the
// call instruction (more precisely, the last byte of the call instruction).
// Callers expect the pc buffer to contain return addresses and do the
// same -1 themselves, so we keep pc unchanged.
// When the pc is from a signal (e.g. profiler or segv) then we want
// to look up file/line information using pc, and we store pc+1 in the
// pc buffer so callers can unconditionally subtract 1 before looking up.
// See issue 34123.
// The pc can be at function entry when the frame is initialized without
// actually running code, like runtime.mstart.
if (n == 0 && flags&_TraceTrap != 0) || waspanic || pc == f.entry() {
pc++
} else {
tracepc--
}
// If there is inlining info, record the inner frames.
if inldata := funcdata(f, _FUNCDATA_InlTree); inldata != nil {
inltree := (*[1 << 20]inlinedCall)(inldata)
for {
ix := pcdatavalue(f, _PCDATA_InlTreeIndex, tracepc, &cache)
if ix < 0 {
break
}
if inltree[ix].funcID == funcID_wrapper && elideWrapperCalling(lastFuncID) {
// ignore wrappers
} else if skip > 0 {
skip--
} else if n < max {
(*[1 << 20]uintptr)(unsafe.Pointer(pcbuf))[n] = pc
n++
}
lastFuncID = inltree[ix].funcID
// Back up to an instruction in the "caller".
tracepc = frame.fn.entry() + uintptr(inltree[ix].parentPc)
pc = tracepc + 1
}
}
// Record the main frame.
if f.funcID == funcID_wrapper && elideWrapperCalling(lastFuncID) {
// Ignore wrapper functions (except when they trigger panics).
} else if skip > 0 {
skip--
} else if n < max {
(*[1 << 20]uintptr)(unsafe.Pointer(pcbuf))[n] = pc
n++
}
lastFuncID = f.funcID
n-- // offset n++ below
}
if printing {
// assume skip=0 for printing.
//
// Never elide wrappers if we haven't printed
// any frames. And don't elide wrappers that
// called panic rather than the wrapped
// function. Otherwise, leave them out.
// Back up to the CALL instruction to read inlining info (same logic as the pcbuf case above).
tracepc := frame.pc
if (n > 0 || flags&_TraceTrap == 0) && frame.pc > f.entry() && !waspanic {
tracepc--
}
// If there is inlining info, print the inner frames.
if inldata := funcdata(f, _FUNCDATA_InlTree); inldata != nil {
inltree := (*[1 << 20]inlinedCall)(inldata)
var inlFunc _func
inlFuncInfo := funcInfo{&inlFunc, f.datap}
for {
ix := pcdatavalue(f, _PCDATA_InlTreeIndex, tracepc, nil)
if ix < 0 {
break
}
// Create a fake _func for the
// inlined function.
inlFunc.nameOff = inltree[ix].nameOff
inlFunc.funcID = inltree[ix].funcID
inlFunc.startLine = inltree[ix].startLine
if (flags&_TraceRuntimeFrames) != 0 || showframe(inlFuncInfo, gp, nprint == 0, inlFuncInfo.funcID, lastFuncID) {
name := funcname(inlFuncInfo)
file, line := funcline(f, tracepc)
print(name, "(...)\n")
print("\t", file, ":", line, "\n")
nprint++
}
lastFuncID = inltree[ix].funcID
// Back up to an instruction in the "caller".
tracepc = frame.fn.entry() + uintptr(inltree[ix].parentPc)
}
}
if (flags&_TraceRuntimeFrames) != 0 || showframe(f, gp, nprint == 0, f.funcID, lastFuncID) {
// Print during crash.
// main(0x1, 0x2, 0x3)
// /home/rsc/go/src/runtime/x.go:23 +0xf
//
name := funcname(f)
file, line := funcline(f, tracepc)
if name == "runtime.gopanic" {
name = "panic"
}
print(name, "(")
argp := unsafe.Pointer(frame.argp)
printArgs(f, argp, tracepc)
print(")\n")
print("\t", file, ":", line)
if frame.pc > f.entry() {
print(" +", hex(frame.pc-f.entry()))
}
if gp.m != nil && gp.m.throwing >= throwTypeRuntime && gp == gp.m.curg || level >= 2 {
print(" fp=", hex(frame.fp), " sp=", hex(frame.sp), " pc=", hex(frame.pc))
}
print("\n")
nprint++
}
lastFuncID = f.funcID
}
n++
if f.funcID == funcID_cgocallback && len(cgoCtxt) > 0 {
ctxt := cgoCtxt[len(cgoCtxt)-1]
cgoCtxt = cgoCtxt[:len(cgoCtxt)-1]
// skip only applies to Go frames.
// callback != nil is used only when we care
// exclusively about Go frames.
if skip == 0 && callback == nil {
n = tracebackCgoContext(pcbuf, printing, ctxt, n, max)
}
}
waspanic = f.funcID == funcID_sigpanic
injectedCall := waspanic || f.funcID == funcID_asyncPreempt || f.funcID == funcID_debugCallV2
// Do not unwind past the bottom of the stack.
if !flr.valid() {
break
}
if frame.pc == frame.lr && frame.sp == frame.fp {
// If the next frame is identical to the current frame, we cannot make progress.
print("runtime: traceback stuck. pc=", hex(frame.pc), " sp=", hex(frame.sp), "\n")
tracebackHexdump(stack, &frame, frame.sp)
throw("traceback stuck")
}
// Unwind to next frame.
frame.fn = flr
frame.pc = frame.lr
frame.lr = 0
frame.sp = frame.fp
frame.fp = 0
// On link register architectures, sighandler saves the LR on stack
// before faking a call.
if usesLR && injectedCall {
x := *(*uintptr)(unsafe.Pointer(frame.sp))
frame.sp += alignUp(sys.MinFrameSize, sys.StackAlign)
f = findfunc(frame.pc)
frame.fn = f
if !f.valid() {
frame.pc = x
} else if funcspdelta(f, frame.pc, &cache) == 0 {
frame.lr = x
}
}
}
if printing {
n = nprint
}
// Note that panic != nil is okay here: there can be leftover panics,
// because the defers on the panic stack do not nest in frame order as
// they do on the defer stack. If you have:
//
// frame 1 defers d1
// frame 2 defers d2
// frame 3 defers d3
// frame 4 panics
// frame 4's panic starts running defers
// frame 5, running d3, defers d4
// frame 5 panics
// frame 5's panic starts running defers
// frame 6, running d4, garbage collects
// frame 6, running d2, garbage collects
//
// During the execution of d4, the panic stack is d4 -> d3, which
// is nested properly, and we'll treat frame 3 as resumable, because we
// can find d3. (And in fact frame 3 is resumable. If d4 recovers
// and frame 5 continues running, d3, d3 can recover and we'll
// resume execution in (returning from) frame 3.)
//
// During the execution of d2, however, the panic stack is d2 -> d3,
// which is inverted. The scan will match d2 to frame 2 but having
// d2 on the stack until then means it will not match d3 to frame 3.
// This is okay: if we're running d2, then all the defers after d2 have
// completed and their corresponding frames are dead. Not finding d3
// for frame 3 means we'll set frame 3's continpc == 0, which is correct
// (frame 3 is dead). At the end of the walk the panic stack can thus
// contain defers (d3 in this case) for dead frames. The inversion here
// always indicates a dead frame, and the effect of the inversion on the
// scan is to hide those dead frames, so the scan is still okay:
// what's left on the panic stack are exactly (and only) the dead frames.
//
// We require callback != nil here because only when callback != nil
// do we know that gentraceback is being called in a "must be correct"
// context as opposed to a "best effort" context. The tracebacks with
// callbacks only happen when everything is stopped nicely.
// At other times, such as when gathering a stack for a profiling signal
// or when printing a traceback during a crash, everything may not be
// stopped nicely, and the stack walk may not be able to complete.
if callback != nil && n < max && frame.sp != gp.stktopsp {
print("runtime: g", gp.goid, ": frame.sp=", hex(frame.sp), " top=", hex(gp.stktopsp), "\n")
print("\tstack=[", hex(gp.stack.lo), "-", hex(gp.stack.hi), "] n=", n, " max=", max, "\n")
throw("traceback did not unwind completely")
}
return n
}
// printArgs prints function arguments in a traceback.
func printArgs(f funcInfo, argp unsafe.Pointer, pc uintptr) {
// The "instruction" of argument printing is encoded in _FUNCDATA_ArgInfo.
// See cmd/compile/internal/ssagen.emitArgInfo for the description of the
// encoding.
// These constants need to be in sync with the compiler.
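// As an illustrative sketch (the offsets below assume one particular
// argument layout; they are not a real compiler dump), a function
//
//	func f(x int32, s struct{ a, b int16 })
//
// could be encoded as the offset/size pairs
//
//	0x00 0x04, _startAgg, 0x04 0x02, 0x06 0x02, _endAgg, _endSeq
//
// which printArgs renders as "0x1, {0x2, 0x3}" (the caller supplies the
// function name and parentheses).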
const (
_endSeq = 0xff
_startAgg = 0xfe
_endAgg = 0xfd
_dotdotdot = 0xfc
_offsetTooLarge = 0xfb
)
const (
limit = 10 // print no more than 10 args/components
maxDepth = 5 // no more than 5 layers of nesting
maxLen = (maxDepth*3+2)*limit + 1 // max length of _FUNCDATA_ArgInfo (see the compiler side for reasoning)
)
p := (*[maxLen]uint8)(funcdata(f, _FUNCDATA_ArgInfo))
if p == nil {
return
}
liveInfo := funcdata(f, _FUNCDATA_ArgLiveInfo)
liveIdx := pcdatavalue(f, _PCDATA_ArgLiveIndex, pc, nil)
startOffset := uint8(0xff) // smallest offset that needs liveness info (slots with a lower offset are always live)
if liveInfo != nil {
startOffset = *(*uint8)(liveInfo)
}
isLive := func(off, slotIdx uint8) bool {
if liveInfo == nil || liveIdx <= 0 {
return true // no liveness info, always live
}
if off < startOffset {
return true
}
bits := *(*uint8)(add(liveInfo, uintptr(liveIdx)+uintptr(slotIdx/8)))
return bits&(1<<(slotIdx%8)) != 0
}
print1 := func(off, sz, slotIdx uint8) {
x := readUnaligned64(add(argp, uintptr(off)))
// mask out irrelevant bits
if sz < 8 {
shift := 64 - sz*8
if goarch.BigEndian {
x = x >> shift
} else {
x = x << shift >> shift
}
}
print(hex(x))
if !isLive(off, slotIdx) {
print("?")
}
}
start := true
printcomma := func() {
if !start {
print(", ")
}
}
pi := 0
slotIdx := uint8(0) // register arg spill slot index
printloop:
for {
o := p[pi]
pi++
switch o {
case _endSeq:
break printloop
case _startAgg:
printcomma()
print("{")
start = true
continue
case _endAgg:
print("}")
case _dotdotdot:
printcomma()
print("...")
case _offsetTooLarge:
printcomma()
print("_")
default:
printcomma()
sz := p[pi]
pi++
print1(o, sz, slotIdx)
if o >= startOffset {
slotIdx++
}
}
start = false
}
}
// tracebackCgoContext handles tracing back a cgo context value, from
// the context argument to setCgoTraceback, for the gentraceback
// function. It returns the new value of n.
func tracebackCgoContext(pcbuf *uintptr, printing bool, ctxt uintptr, n, max int) int {
var cgoPCs [32]uintptr
cgoContextPCs(ctxt, cgoPCs[:])
var arg cgoSymbolizerArg
anySymbolized := false
for _, pc := range cgoPCs {
if pc == 0 || n >= max {
break
}
if pcbuf != nil {
(*[1 << 20]uintptr)(unsafe.Pointer(pcbuf))[n] = pc
}
if printing {
if cgoSymbolizer == nil {
print("non-Go function at pc=", hex(pc), "\n")
} else {
c := printOneCgoTraceback(pc, max-n, &arg)
n += c - 1 // +1 a few lines down
anySymbolized = true
}
}
n++
}
if anySymbolized {
arg.pc = 0
callCgoSymbolizer(&arg)
}
return n
}
func printcreatedby(gp *g) {
// Show what created the goroutine, except for the main goroutine (goid 1).
pc := gp.gopc
f := findfunc(pc)
if f.valid() && showframe(f, gp, false, funcID_normal, funcID_normal) && gp.goid != 1 {
printcreatedby1(f, pc, gp.parentGoid)
}
}
func printcreatedby1(f funcInfo, pc uintptr, goid uint64) {
print("created by ", funcname(f))
if goid != 0 {
print(" in goroutine ", goid)
}
print("\n")
tracepc := pc // back up to CALL instruction for funcline.
if pc > f.entry() {
tracepc -= sys.PCQuantum
}
file, line := funcline(f, tracepc)
print("\t", file, ":", line)
if pc > f.entry() {
print(" +", hex(pc-f.entry()))
}
print("\n")
}
func traceback(pc, sp, lr uintptr, gp *g) {
traceback1(pc, sp, lr, gp, 0)
}
// tracebacktrap is like traceback but expects that the PC and SP were obtained
// from a trap, not from gp->sched or gp->syscallpc/gp->syscallsp or getcallerpc/getcallersp.
// Because they are from a trap instead of from a saved pair,
// the initial PC must not be rewound to the previous instruction.
// (All the saved pairs record a PC that is a return address, so we
// rewind it into the CALL instruction.)
// If gp.m.libcall{g,pc,sp} information is available, it uses that information in preference to
// the pc/sp/lr passed in.
func tracebacktrap(pc, sp, lr uintptr, gp *g) {
if gp.m.libcallsp != 0 {
// We're in C code somewhere, traceback from the saved position.
traceback1(gp.m.libcallpc, gp.m.libcallsp, 0, gp.m.libcallg.ptr(), 0)
return
}
traceback1(pc, sp, lr, gp, _TraceTrap)
}
func traceback1(pc, sp, lr uintptr, gp *g, flags uint) {
// If the goroutine is in cgo, and we have a cgo traceback, print that.
if iscgo && gp.m != nil && gp.m.ncgo > 0 && gp.syscallsp != 0 && gp.m.cgoCallers != nil && gp.m.cgoCallers[0] != 0 {
// Lock cgoCallers so that a signal handler won't
// change it, copy the array, reset it, unlock it.
// We are locked to the thread and are not running
// concurrently with a signal handler.
// We just have to stop a signal handler from interrupting
// in the middle of our copy.
gp.m.cgoCallersUse.Store(1)
cgoCallers := *gp.m.cgoCallers
gp.m.cgoCallers[0] = 0
gp.m.cgoCallersUse.Store(0)
printCgoTraceback(&cgoCallers)
}
if readgstatus(gp)&^_Gscan == _Gsyscall {
// Override registers if blocked in system call.
pc = gp.syscallpc
sp = gp.syscallsp
flags &^= _TraceTrap
}
if gp.m != nil && gp.m.vdsoSP != 0 {
// Override registers if running in VDSO. This comes after the
// _Gsyscall check to cover VDSO calls after entersyscall.
pc = gp.m.vdsoPC
sp = gp.m.vdsoSP
flags &^= _TraceTrap
}
// Print traceback. By default, omits runtime frames.
// If that means we print nothing at all, repeat forcing all frames printed.
n := gentraceback(pc, sp, lr, gp, 0, nil, _TracebackMaxFrames, nil, nil, flags)
if n == 0 && (flags&_TraceRuntimeFrames) == 0 {
n = gentraceback(pc, sp, lr, gp, 0, nil, _TracebackMaxFrames, nil, nil, flags|_TraceRuntimeFrames)
}
if n == _TracebackMaxFrames {
print("...additional frames elided...\n")
}
printcreatedby(gp)
if gp.ancestors == nil {
return
}
for _, ancestor := range *gp.ancestors {
printAncestorTraceback(ancestor)
}
}
// printAncestorTraceback prints the traceback of the given ancestor.
// TODO: Unify this with gentraceback and CallersFrames.
func printAncestorTraceback(ancestor ancestorInfo) {
print("[originating from goroutine ", ancestor.goid, "]:\n")
for fidx, pc := range ancestor.pcs {
f := findfunc(pc) // f previously validated
if showfuncinfo(f, fidx == 0, funcID_normal, funcID_normal) {
printAncestorTracebackFuncInfo(f, pc)
}
}
if len(ancestor.pcs) == _TracebackMaxFrames {
print("...additional frames elided...\n")
}
// Show what created the goroutine, except for the main goroutine (goid 1).
f := findfunc(ancestor.gopc)
if f.valid() && showfuncinfo(f, false, funcID_normal, funcID_normal) && ancestor.goid != 1 {
// In ancestor mode, we'll already print the goroutine ancestor.
// Pass 0 for the goid parameter so we don't print it again.
printcreatedby1(f, ancestor.gopc, 0)
}
}
// printAncestorTracebackFuncInfo prints the given function info at a given pc
// within an ancestor traceback. The precision of this info is reduced
// because we only have access to the pcs captured at the time the
// caller goroutine was created.
func printAncestorTracebackFuncInfo(f funcInfo, pc uintptr) {
name := funcname(f)
if inldata := funcdata(f, _FUNCDATA_InlTree); inldata != nil {
inltree := (*[1 << 20]inlinedCall)(inldata)
ix := pcdatavalue(f, _PCDATA_InlTreeIndex, pc, nil)
if ix >= 0 {
name = funcnameFromNameOff(f, inltree[ix].nameOff)
}
}
file, line := funcline(f, pc)
if name == "runtime.gopanic" {
name = "panic"
}
print(name, "(...)\n")
print("\t", file, ":", line)
if pc > f.entry() {
print(" +", hex(pc-f.entry()))
}
print("\n")
}
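// callers fills pcbuf with up to len(pcbuf) return PCs from the calling
// goroutine's stack, skipping the first skip frames, and returns the
// number of entries written.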
func callers(skip int, pcbuf []uintptr) int {
sp := getcallersp()
pc := getcallerpc()
gp := getg()
var n int
systemstack(func() {
n = gentraceback(pc, sp, 0, gp, skip, &pcbuf[0], len(pcbuf), nil, nil, 0)
})
return n
}
func gcallers(gp *g, skip int, pcbuf []uintptr) int {
return gentraceback(^uintptr(0), ^uintptr(0), 0, gp, skip, &pcbuf[0], len(pcbuf), nil, nil, 0)
}
// showframe reports whether the frame with the given characteristics should
// be printed during a traceback.
func showframe(f funcInfo, gp *g, firstFrame bool, funcID, childID funcID) bool {
mp := getg().m
if mp.throwing >= throwTypeRuntime && gp != nil && (gp == mp.curg || gp == mp.caughtsig.ptr()) {
return true
}
return showfuncinfo(f, firstFrame, funcID, childID)
}
// showfuncinfo reports whether a function with the given characteristics should
// be printed during a traceback.
func showfuncinfo(f funcInfo, firstFrame bool, funcID, childID funcID) bool {
// Note that f may be a synthesized funcInfo for an inlined
// function, in which case only nameOff and funcID are set.
level, _, _ := gotraceback()
if level > 1 {
// Show all frames.
return true
}
if !f.valid() {
return false
}
if funcID == funcID_wrapper && elideWrapperCalling(childID) {
return false
}
name := funcname(f)
// Special case: always show runtime.gopanic frame
// in the middle of a stack trace, so that we can
// see the boundary between ordinary code and
// panic-induced deferred code.
// See golang.org/issue/5832.
if name == "runtime.gopanic" && !firstFrame {
return true
}
return bytealg.IndexByteString(name, '.') >= 0 && (!hasPrefix(name, "runtime.") || isExportedRuntime(name))
}
// isExportedRuntime reports whether name is an exported runtime function.
// It is only for runtime functions, so ASCII A-Z is fine.
func isExportedRuntime(name string) bool {
const n = len("runtime.")
return len(name) > n && name[:n] == "runtime." && 'A' <= name[n] && name[n] <= 'Z'
}
// elideWrapperCalling reports whether a wrapper function that called
// function id should be elided from stack traces.
func elideWrapperCalling(id funcID) bool {
// If the wrapper called a panic function instead of the
// wrapped function, we want to include it in stacks.
return !(id == funcID_gopanic || id == funcID_sigpanic || id == funcID_panicwrap)
}
var gStatusStrings = [...]string{
_Gidle: "idle",
_Grunnable: "runnable",
_Grunning: "running",
_Gsyscall: "syscall",
_Gwaiting: "waiting",
_Gdead: "dead",
_Gcopystack: "copystack",
_Gpreempted: "preempted",
}
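// goroutineheader prints a one-line description of gp's state, producing
// output of the form (here with an assumed goid and wait reason):
//
//	goroutine 12 [chan receive, 3 minutes, locked to thread]: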
func goroutineheader(gp *g) {
gpstatus := readgstatus(gp)
isScan := gpstatus&_Gscan != 0
gpstatus &^= _Gscan // drop the scan bit
// Basic string status
var status string
if 0 <= gpstatus && gpstatus < uint32(len(gStatusStrings)) {
status = gStatusStrings[gpstatus]
} else {
status = "???"
}
// Override.
if gpstatus == _Gwaiting && gp.waitreason != waitReasonZero {
status = gp.waitreason.String()
}
// approx time the G is blocked, in minutes
var waitfor int64
if (gpstatus == _Gwaiting || gpstatus == _Gsyscall) && gp.waitsince != 0 {
waitfor = (nanotime() - gp.waitsince) / 60e9
}
print("goroutine ", gp.goid, " [", status)
if isScan {
print(" (scan)")
}
if waitfor >= 1 {
print(", ", waitfor, " minutes")
}
if gp.lockedm != 0 {
print(", locked to thread")
}
print("]:\n")
}
func tracebackothers(me *g) {
level, _, _ := gotraceback()
// Show the current goroutine first, if we haven't already.
curgp := getg().m.curg
if curgp != nil && curgp != me {
print("\n")
goroutineheader(curgp)
traceback(^uintptr(0), ^uintptr(0), 0, curgp)
}
// We can't call locking forEachG here because this may be during fatal
// throw/panic, where locking could be out-of-order or a direct
// deadlock.
//
// Instead, use forEachGRace, which requires no locking. We don't lock
// against concurrent creation of new Gs, but even with allglock we may
// miss Gs created after this loop.
forEachGRace(func(gp *g) {
if gp == me || gp == curgp || readgstatus(gp) == _Gdead || isSystemGoroutine(gp, false) && level < 2 {
return
}
print("\n")
goroutineheader(gp)
// Note: gp.m == getg().m occurs when tracebackothers is called
// from a signal handler initiated during a systemstack call.
// The original G is still in the running state, and we want to
// print its stack.
if gp.m != getg().m && readgstatus(gp)&^_Gscan == _Grunning {
print("\tgoroutine running on other thread; stack unavailable\n")
printcreatedby(gp)
} else {
traceback(^uintptr(0), ^uintptr(0), 0, gp)
}
})
}
// tracebackHexdump hexdumps part of stk around frame.sp and frame.fp
// for debugging purposes. If the address bad is included in the
// hexdumped range, the dump will mark it as well.
func tracebackHexdump(stk stack, frame *stkframe, bad uintptr) {
const expand = 32 * goarch.PtrSize
const maxExpand = 256 * goarch.PtrSize
// Start around frame.sp.
lo, hi := frame.sp, frame.sp
// Expand to include frame.fp.
if frame.fp != 0 && frame.fp < lo {
lo = frame.fp
}
if frame.fp != 0 && frame.fp > hi {
hi = frame.fp
}
// Expand a bit more.
lo, hi = lo-expand, hi+expand
// But don't go too far from frame.sp.
if lo < frame.sp-maxExpand {
lo = frame.sp - maxExpand
}
if hi > frame.sp+maxExpand {
hi = frame.sp + maxExpand
}
// And don't go outside the stack bounds.
if lo < stk.lo {
lo = stk.lo
}
if hi > stk.hi {
hi = stk.hi
}
// Print the hex dump.
print("stack: frame={sp:", hex(frame.sp), ", fp:", hex(frame.fp), "} stack=[", hex(stk.lo), ",", hex(stk.hi), ")\n")
hexdumpWords(lo, hi, func(p uintptr) byte {
switch p {
case frame.fp:
return '>'
case frame.sp:
return '<'
case bad:
return '!'
}
return 0
})
}
// isSystemGoroutine reports whether the goroutine g must be omitted
// from stack dumps and by the deadlock detector. This is any goroutine that
// starts at a runtime.* entry point, except for runtime.main,
// runtime.handleAsyncEvent (wasm only) and sometimes runtime.runfinq.
//
// If fixed is true, any goroutine that can vary between user and
// system (that is, the finalizer goroutine) is considered a user
// goroutine.
func isSystemGoroutine(gp *g, fixed bool) bool {
// Keep this in sync with internal/trace.IsSystemGoroutine.
f := findfunc(gp.startpc)
if !f.valid() {
return false
}
if f.funcID == funcID_runtime_main || f.funcID == funcID_handleAsyncEvent {
return false
}
if f.funcID == funcID_runfinq {
// We include the finalizer goroutine if it's calling
// back into user code.
if fixed {
// This goroutine can vary. In fixed mode,
// always consider it a user goroutine.
return false
}
return fingStatus.Load()&fingRunningFinalizer == 0
}
return hasPrefix(funcname(f), "runtime.")
}
// SetCgoTraceback records three C functions to use to gather
// traceback information from C code and to convert that traceback
// information into symbolic information. These are used when printing
// stack traces for a program that uses cgo.
//
// The traceback and context functions may be called from a signal
// handler, and must therefore use only async-signal safe functions.
// The symbolizer function may be called while the program is
// crashing, and so must be cautious about using memory. None of the
// functions may call back into Go.
//
// The context function will be called with a single argument, a
// pointer to a struct:
//
// struct {
// Context uintptr
// }
//
// In C syntax, this struct will be
//
// struct {
// uintptr_t Context;
// };
//
// If the Context field is 0, the context function is being called to
// record the current traceback context. It should record in the
// Context field whatever information is needed about the current
// point of execution to later produce a stack trace, probably the
// stack pointer and PC. In this case the context function will be
// called from C code.
//
// If the Context field is not 0, then it is a value returned by a
// previous call to the context function. This case is called when the
// context is no longer needed; that is, when the Go code is returning
// to its C code caller. This permits the context function to release
// any associated resources.
//
// While it would be correct for the context function to record a
// complete stack trace whenever it is called, and simply copy that
// out in the traceback function, in a typical program the context
// function will be called many times without ever recording a
// traceback for that context. Recording a complete stack trace in a
// call to the context function is likely to be inefficient.
//
// The traceback function will be called with a single argument, a
// pointer to a struct:
//
// struct {
// Context uintptr
// SigContext uintptr
// Buf *uintptr
// Max uintptr
// }
//
// In C syntax, this struct will be
//
// struct {
// uintptr_t Context;
// uintptr_t SigContext;
// uintptr_t* Buf;
// uintptr_t Max;
// };
//
// The Context field will be zero to gather a traceback from the
// current program execution point. In this case, the traceback
// function will be called from C code.
//
// Otherwise Context will be a value previously returned by a call to
// the context function. The traceback function should gather a stack
// trace from that saved point in the program execution. The traceback
// function may be called from an execution thread other than the one
// that recorded the context, but only when the context is known to be
// valid and unchanging. The traceback function may also be called
// deeper in the call stack on the same thread that recorded the
// context. The traceback function may be called multiple times with
// the same Context value; it will usually be appropriate to cache the
// result, if possible, the first time this is called for a specific
// context value.
//
// If the traceback function is called from a signal handler on a Unix
// system, SigContext will be the signal context argument passed to
// the signal handler (a C ucontext_t* cast to uintptr_t). This may be
// used to start tracing at the point where the signal occurred. If
// the traceback function is not called from a signal handler,
// SigContext will be zero.
//
// Buf is where the traceback information should be stored. It should
// be PC values, such that Buf[0] is the PC of the caller, Buf[1] is
// the PC of that function's caller, and so on. Max is the maximum
// number of entries to store. The function should store a zero to
// indicate the top of the stack, or that the caller is on a different
// stack, presumably a Go stack.
//
// Unlike runtime.Callers, the PC values returned should, when passed
// to the symbolizer function, return the file/line of the call
// instruction. No additional subtraction is required or appropriate.
//
// On all platforms, the traceback function is invoked when a call from
// Go to C to Go requests a stack trace. On linux/amd64, linux/ppc64le,
// linux/arm64, and freebsd/amd64, the traceback function is also invoked
// when a signal is received by a thread that is executing a cgo call.
// The traceback function should not make assumptions about when it is
// called, as future versions of Go may make additional calls.
//
// The symbolizer function will be called with a single argument, a
// pointer to a struct:
//
// struct {
// PC uintptr // program counter to fetch information for
// File *byte // file name (NUL terminated)
// Lineno uintptr // line number
// Func *byte // function name (NUL terminated)
// Entry uintptr // function entry point
// More uintptr // set non-zero if more info for this PC
// Data uintptr // unused by runtime, available for function
// }
//
// In C syntax, this struct will be
//
// struct {
// uintptr_t PC;
// char* File;
// uintptr_t Lineno;
// char* Func;
// uintptr_t Entry;
// uintptr_t More;
// uintptr_t Data;
// };
//
// The PC field will be a value returned by a call to the traceback
// function.
//
// The first time the function is called for a particular traceback,
// all the fields except PC will be 0. The function should fill in the
// other fields if possible, setting them to 0/nil if the information
// is not available. The Data field may be used to store any useful
// information across calls. The More field should be set to non-zero
// if there is more information for this PC, zero otherwise. If More
// is set non-zero, the function will be called again with the same
// PC, and may return different information (this is intended for use
// with inlined functions). If More is zero, the function will be
// called with the next PC value in the traceback. When the traceback
// is complete, the function will be called once more with PC set to
// zero; this may be used to free any information. Each call will
// leave the fields of the struct set to the same values they had upon
// return, except for the PC field when the More field is zero. The
// function must not keep a copy of the struct pointer between calls.
//
// When calling SetCgoTraceback, the version argument is the version
// number of the structs that the functions expect to receive.
// Currently this must be zero.
//
// The symbolizer function may be nil, in which case the results of
// the traceback function will be displayed as numbers. If the
// traceback function is nil, the symbolizer function will never be
// called. The context function may be nil, in which case the
// traceback function will only be called with the context field set
// to zero. If the context function is nil, then calls from Go to C
// to Go will not show a traceback for the C portion of the call stack.
//
// SetCgoTraceback should be called only once, ideally from an init function.
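//
// As an illustrative sketch only (tracebackFn, contextFn, and
// symbolizerFn are hypothetical C functions with the shapes described
// above), a cgo program could register them like:
//
//	func init() {
//		runtime.SetCgoTraceback(0,
//			unsafe.Pointer(C.tracebackFn),
//			unsafe.Pointer(C.contextFn),
//			unsafe.Pointer(C.symbolizerFn))
//	}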
func SetCgoTraceback(version int, traceback, context, symbolizer unsafe.Pointer) {
if version != 0 {
panic("unsupported version")
}
if cgoTraceback != nil && cgoTraceback != traceback ||
cgoContext != nil && cgoContext != context ||
cgoSymbolizer != nil && cgoSymbolizer != symbolizer {
panic("call SetCgoTraceback only once")
}
cgoTraceback = traceback
cgoContext = context
cgoSymbolizer = symbolizer
// The context function is called when a C function calls a Go
// function. As such it is only called by C code in runtime/cgo.
if _cgo_set_context_function != nil {
cgocall(_cgo_set_context_function, context)
}
}
var cgoTraceback unsafe.Pointer
var cgoContext unsafe.Pointer
var cgoSymbolizer unsafe.Pointer
// cgoTracebackArg is the type passed to cgoTraceback.
type cgoTracebackArg struct {
context uintptr
sigContext uintptr
buf *uintptr
max uintptr
}
// cgoContextArg is the type passed to the context function.
type cgoContextArg struct {
context uintptr
}
// cgoSymbolizerArg is the type passed to cgoSymbolizer.
type cgoSymbolizerArg struct {
pc uintptr
file *byte
lineno uintptr
funcName *byte
entry uintptr
more uintptr
data uintptr
}
// printCgoTraceback prints a traceback of callers.
func printCgoTraceback(callers *cgoCallers) {
if cgoSymbolizer == nil {
for _, c := range callers {
if c == 0 {
break
}
print("non-Go function at pc=", hex(c), "\n")
}
return
}
var arg cgoSymbolizerArg
for _, c := range callers {
if c == 0 {
break
}
printOneCgoTraceback(c, 0x7fffffff, &arg)
}
arg.pc = 0
callCgoSymbolizer(&arg)
}
// printOneCgoTraceback prints the traceback of a single cgo caller.
// This can print more than one line because of inlining.
// It returns the number of frames printed.
func printOneCgoTraceback(pc uintptr, max int, arg *cgoSymbolizerArg) int {
c := 0
arg.pc = pc
for c <= max {
callCgoSymbolizer(arg)
if arg.funcName != nil {
// Note that we don't print any argument
// information here, not even parentheses.
// The symbolizer must add that if appropriate.
println(gostringnocopy(arg.funcName))
} else {
println("non-Go function")
}
print("\t")
if arg.file != nil {
print(gostringnocopy(arg.file), ":", arg.lineno, " ")
}
print("pc=", hex(pc), "\n")
c++
if arg.more == 0 {
break
}
}
return c
}
// callCgoSymbolizer calls the cgoSymbolizer function.
func callCgoSymbolizer(arg *cgoSymbolizerArg) {
call := cgocall
if panicking.Load() > 0 || getg().m.curg != getg() {
// We do not want to call into the scheduler when panicking
// or when on the system stack.
call = asmcgocall
}
if msanenabled {
msanwrite(unsafe.Pointer(arg), unsafe.Sizeof(cgoSymbolizerArg{}))
}
if asanenabled {
asanwrite(unsafe.Pointer(arg), unsafe.Sizeof(cgoSymbolizerArg{}))
}
call(cgoSymbolizer, noescape(unsafe.Pointer(arg)))
}
// cgoContextPCs gets the PC values from a cgo traceback.
func cgoContextPCs(ctxt uintptr, buf []uintptr) {
if cgoTraceback == nil {
return
}
call := cgocall
if panicking.Load() > 0 || getg().m.curg != getg() {
// We do not want to call into the scheduler when panicking
// or when on the system stack.
call = asmcgocall
}
arg := cgoTracebackArg{
context: ctxt,
buf: (*uintptr)(noescape(unsafe.Pointer(&buf[0]))),
max: uintptr(len(buf)),
}
if msanenabled {
msanwrite(unsafe.Pointer(&arg), unsafe.Sizeof(arg))
}
if asanenabled {
asanwrite(unsafe.Pointer(&arg), unsafe.Sizeof(arg))
}
call(cgoTraceback, noescape(unsafe.Pointer(&arg)))
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Runtime type representation.
package runtime
import (
"internal/abi"
"unsafe"
)
// tflag is documented in reflect/type.go.
//
// tflag values must be kept in sync with copies in:
//
// cmd/compile/internal/reflectdata/reflect.go
// cmd/link/internal/ld/decodesym.go
// reflect/type.go
// internal/reflectlite/type.go
type tflag uint8
const (
tflagUncommon tflag = 1 << 0
tflagExtraStar tflag = 1 << 1
tflagNamed tflag = 1 << 2
tflagRegularMemory tflag = 1 << 3 // equal and hash can treat values of this type as a single region of t.size bytes
)
// Needs to be in sync with ../cmd/link/internal/ld/decodesym.go:/^func.commonsize,
// ../cmd/compile/internal/reflectdata/reflect.go:/^func.dcommontype and
// ../reflect/type.go:/^type.rtype.
// ../internal/reflectlite/type.go:/^type.rtype.
type _type struct {
size uintptr
ptrdata uintptr // size of memory prefix holding all pointers
hash uint32
tflag tflag
align uint8
fieldAlign uint8
kind uint8
// function for comparing objects of this type
// (ptr to object A, ptr to object B) -> ==?
equal func(unsafe.Pointer, unsafe.Pointer) bool
// gcdata stores the GC type data for the garbage collector.
// If the KindGCProg bit is set in kind, gcdata is a GC program.
// Otherwise it is a ptrmask bitmap. See mbitmap.go for details.
gcdata *byte
str nameOff
ptrToThis typeOff
}
func (t *_type) string() string {
s := t.nameOff(t.str).name()
if t.tflag&tflagExtraStar != 0 {
return s[1:]
}
return s
}
func (t *_type) uncommon() *uncommontype {
if t.tflag&tflagUncommon == 0 {
return nil
}
switch t.kind & kindMask {
case kindStruct:
type u struct {
structtype
u uncommontype
}
return &(*u)(unsafe.Pointer(t)).u
case kindPtr:
type u struct {
ptrtype
u uncommontype
}
return &(*u)(unsafe.Pointer(t)).u
case kindFunc:
type u struct {
functype
u uncommontype
}
return &(*u)(unsafe.Pointer(t)).u
case kindSlice:
type u struct {
slicetype
u uncommontype
}
return &(*u)(unsafe.Pointer(t)).u
case kindArray:
type u struct {
arraytype
u uncommontype
}
return &(*u)(unsafe.Pointer(t)).u
case kindChan:
type u struct {
chantype
u uncommontype
}
return &(*u)(unsafe.Pointer(t)).u
case kindMap:
type u struct {
maptype
u uncommontype
}
return &(*u)(unsafe.Pointer(t)).u
case kindInterface:
type u struct {
interfacetype
u uncommontype
}
return &(*u)(unsafe.Pointer(t)).u
default:
type u struct {
_type
u uncommontype
}
return &(*u)(unsafe.Pointer(t)).u
}
}
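// name returns the type's name without its package path: for example,
// "Reader" for a type whose string is "bufio.Reader", or "" if the type
// is unnamed. Square-bracketed type arguments are skipped when scanning
// for the last dot, so "pkg.Map[sync.Pool]" yields "Map[sync.Pool]".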
func (t *_type) name() string {
if t.tflag&tflagNamed == 0 {
return ""
}
s := t.string()
i := len(s) - 1
sqBrackets := 0
for i >= 0 && (s[i] != '.' || sqBrackets != 0) {
switch s[i] {
case ']':
sqBrackets++
case '[':
sqBrackets--
}
i--
}
return s[i+1:]
}
// pkgpath returns the path of the package where t was defined, if
// available. This is not the same as the reflect package's PkgPath
// method, in that it returns the package path for struct and interface
// types, not just named types.
func (t *_type) pkgpath() string {
if u := t.uncommon(); u != nil {
return t.nameOff(u.pkgpath).name()
}
switch t.kind & kindMask {
case kindStruct:
st := (*structtype)(unsafe.Pointer(t))
return st.pkgPath.name()
case kindInterface:
it := (*interfacetype)(unsafe.Pointer(t))
return it.pkgpath.name()
}
return ""
}
// reflectOffs holds type offsets defined at run time by the reflect package.
//
// When a type is defined at run time, its *rtype data lives on the heap.
// The heap can use a wide range of addresses, which may not be
// representable as a 32-bit offset. Moreover, the GC may
// one day start moving heap memory, in which case there is no stable
// offset that can be defined.
//
// To provide stable offsets, we pin *rtype objects in a global map
// and treat the offset as an identifier. We use negative offsets that
// do not overlap with any compile-time module offsets.
//
// Entries are created by reflect.addReflectOff.
var reflectOffs struct {
lock mutex
next int32
m map[int32]unsafe.Pointer
minv map[unsafe.Pointer]int32
}
func reflectOffsLock() {
lock(&reflectOffs.lock)
if raceenabled {
raceacquire(unsafe.Pointer(&reflectOffs.lock))
}
}
func reflectOffsUnlock() {
if raceenabled {
racerelease(unsafe.Pointer(&reflectOffs.lock))
}
unlock(&reflectOffs.lock)
}
func resolveNameOff(ptrInModule unsafe.Pointer, off nameOff) name {
if off == 0 {
return name{}
}
base := uintptr(ptrInModule)
for md := &firstmoduledata; md != nil; md = md.next {
if base >= md.types && base < md.etypes {
res := md.types + uintptr(off)
if res > md.etypes {
println("runtime: nameOff", hex(off), "out of range", hex(md.types), "-", hex(md.etypes))
throw("runtime: name offset out of range")
}
return name{(*byte)(unsafe.Pointer(res))}
}
}
// No module found. See if it is a run-time name.
reflectOffsLock()
res, found := reflectOffs.m[int32(off)]
reflectOffsUnlock()
if !found {
println("runtime: nameOff", hex(off), "base", hex(base), "not in ranges:")
for next := &firstmoduledata; next != nil; next = next.next {
println("\ttypes", hex(next.types), "etypes", hex(next.etypes))
}
throw("runtime: name offset base pointer out of range")
}
return name{(*byte)(res)}
}
func (t *_type) nameOff(off nameOff) name {
return resolveNameOff(unsafe.Pointer(t), off)
}
func resolveTypeOff(ptrInModule unsafe.Pointer, off typeOff) *_type {
if off == 0 || off == -1 {
// -1 is the sentinel value for unreachable code.
// See cmd/link/internal/ld/data.go:relocsym.
return nil
}
base := uintptr(ptrInModule)
var md *moduledata
for next := &firstmoduledata; next != nil; next = next.next {
if base >= next.types && base < next.etypes {
md = next
break
}
}
if md == nil {
reflectOffsLock()
res := reflectOffs.m[int32(off)]
reflectOffsUnlock()
if res == nil {
println("runtime: typeOff", hex(off), "base", hex(base), "not in ranges:")
for next := &firstmoduledata; next != nil; next = next.next {
println("\ttypes", hex(next.types), "etypes", hex(next.etypes))
}
throw("runtime: type offset base pointer out of range")
}
return (*_type)(res)
}
if t := md.typemap[off]; t != nil {
return t
}
res := md.types + uintptr(off)
if res > md.etypes {
println("runtime: typeOff", hex(off), "out of range", hex(md.types), "-", hex(md.etypes))
throw("runtime: type offset out of range")
}
return (*_type)(unsafe.Pointer(res))
}
func (t *_type) typeOff(off typeOff) *_type {
return resolveTypeOff(unsafe.Pointer(t), off)
}
func (t *_type) textOff(off textOff) unsafe.Pointer {
if off == -1 {
// -1 is the sentinel value for unreachable code.
// See cmd/link/internal/ld/data.go:relocsym.
return unsafe.Pointer(abi.FuncPCABIInternal(unreachableMethod))
}
base := uintptr(unsafe.Pointer(t))
var md *moduledata
for next := &firstmoduledata; next != nil; next = next.next {
if base >= next.types && base < next.etypes {
md = next
break
}
}
if md == nil {
reflectOffsLock()
res := reflectOffs.m[int32(off)]
reflectOffsUnlock()
if res == nil {
println("runtime: textOff", hex(off), "base", hex(base), "not in ranges:")
for next := &firstmoduledata; next != nil; next = next.next {
println("\ttypes", hex(next.types), "etypes", hex(next.etypes))
}
throw("runtime: text offset base pointer out of range")
}
return res
}
res := md.textAddr(uint32(off))
return unsafe.Pointer(res)
}
func (t *functype) in() []*_type {
// See funcType in reflect/type.go for details on data layout.
uadd := uintptr(unsafe.Sizeof(functype{}))
if t.typ.tflag&tflagUncommon != 0 {
uadd += unsafe.Sizeof(uncommontype{})
}
return (*[1 << 20]*_type)(add(unsafe.Pointer(t), uadd))[:t.inCount]
}
func (t *functype) out() []*_type {
// See funcType in reflect/type.go for details on data layout.
uadd := uintptr(unsafe.Sizeof(functype{}))
if t.typ.tflag&tflagUncommon != 0 {
uadd += unsafe.Sizeof(uncommontype{})
}
outCount := t.outCount & (1<<15 - 1)
return (*[1 << 20]*_type)(add(unsafe.Pointer(t), uadd))[t.inCount : t.inCount+outCount]
}
func (t *functype) dotdotdot() bool {
return t.outCount&(1<<15) != 0
}
type nameOff int32
type typeOff int32
type textOff int32
type method struct {
name nameOff
mtyp typeOff
ifn textOff
tfn textOff
}
type uncommontype struct {
pkgpath nameOff
mcount uint16 // number of methods
xcount uint16 // number of exported methods
moff uint32 // offset from this uncommontype to [mcount]method
_ uint32 // unused
}
type imethod struct {
name nameOff
ityp typeOff
}
type interfacetype struct {
typ _type
pkgpath name
mhdr []imethod
}
type maptype struct {
typ _type
key *_type
elem *_type
bucket *_type // internal type representing a hash bucket
// function for hashing keys (ptr to key, seed) -> hash
hasher func(unsafe.Pointer, uintptr) uintptr
keysize uint8 // size of key slot
elemsize uint8 // size of elem slot
bucketsize uint16 // size of bucket
flags uint32
}
// Note: flag values must match those used in the TMAP case
// in ../cmd/compile/internal/reflectdata/reflect.go:writeType.
func (mt *maptype) indirectkey() bool { // store ptr to key instead of key itself
return mt.flags&1 != 0
}
func (mt *maptype) indirectelem() bool { // store ptr to elem instead of elem itself
return mt.flags&2 != 0
}
func (mt *maptype) reflexivekey() bool { // true if k==k for all keys
return mt.flags&4 != 0
}
func (mt *maptype) needkeyupdate() bool { // true if we need to update key on an overwrite
return mt.flags&8 != 0
}
func (mt *maptype) hashMightPanic() bool { // true if hash function might panic
return mt.flags&16 != 0
}
type arraytype struct {
typ _type
elem *_type
slice *_type
len uintptr
}
type chantype struct {
typ _type
elem *_type
dir uintptr
}
type slicetype struct {
typ _type
elem *_type
}
type functype struct {
typ _type
inCount uint16
outCount uint16
}
type ptrtype struct {
typ _type
elem *_type
}
type structfield struct {
name name
typ *_type
offset uintptr
}
type structtype struct {
typ _type
pkgPath name
fields []structfield
}
// name is an encoded type name with optional extra data.
// See reflect/type.go for details.
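// As the accessor methods below assume, the first byte is a flag byte:
// 1<<0 set means the name is exported, 1<<1 means a tag follows the
// name, 1<<2 means a pkgPath nameOff follows the name (and tag, if
// any), and 1<<3 means the name is of an embedded field. The name
// itself follows as a varint length and that many bytes of UTF-8.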
type name struct {
bytes *byte
}
func (n name) data(off int) *byte {
return (*byte)(add(unsafe.Pointer(n.bytes), uintptr(off)))
}
func (n name) isExported() bool {
return (*n.bytes)&(1<<0) != 0
}
func (n name) isEmbedded() bool {
return (*n.bytes)&(1<<3) != 0
}
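// readvarint decodes a little-endian base-128 varint starting at off
// within the name data. It returns the number of bytes read and the
// decoded value. For example, the bytes 0x96 0x01 decode to (2, 150),
// since 0x16 + 0x01<<7 == 150.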
func (n name) readvarint(off int) (int, int) {
v := 0
for i := 0; ; i++ {
x := *n.data(off + i)
v += int(x&0x7f) << (7 * i)
if x&0x80 == 0 {
return i + 1, v
}
}
}
func (n name) name() string {
if n.bytes == nil {
return ""
}
i, l := n.readvarint(1)
if l == 0 {
return ""
}
return unsafe.String(n.data(1+i), l)
}
func (n name) tag() string {
if *n.data(0)&(1<<1) == 0 {
return ""
}
i, l := n.readvarint(1)
i2, l2 := n.readvarint(1 + i + l)
return unsafe.String(n.data(1+i+l+i2), l2)
}
func (n name) pkgPath() string {
if n.bytes == nil || *n.data(0)&(1<<2) == 0 {
return ""
}
i, l := n.readvarint(1)
off := 1 + i + l
if *n.data(0)&(1<<1) != 0 {
i2, l2 := n.readvarint(off)
off += i2 + l2
}
var nameOff nameOff
copy((*[4]byte)(unsafe.Pointer(&nameOff))[:], (*[4]byte)(unsafe.Pointer(n.data(off)))[:])
pkgPathName := resolveNameOff(unsafe.Pointer(n.bytes), nameOff)
return pkgPathName.name()
}
func (n name) isBlank() bool {
if n.bytes == nil {
return false
}
_, l := n.readvarint(1)
return l == 1 && *n.data(2) == '_'
}
// typelinksinit scans the types from extra modules and builds the
// moduledata typemap used to de-duplicate type pointers.
func typelinksinit() {
if firstmoduledata.next == nil {
return
}
typehash := make(map[uint32][]*_type, len(firstmoduledata.typelinks))
modules := activeModules()
prev := modules[0]
for _, md := range modules[1:] {
// Collect types from the previous module into typehash.
collect:
for _, tl := range prev.typelinks {
var t *_type
if prev.typemap == nil {
t = (*_type)(unsafe.Pointer(prev.types + uintptr(tl)))
} else {
t = prev.typemap[typeOff(tl)]
}
// Add to typehash if not seen before.
tlist := typehash[t.hash]
for _, tcur := range tlist {
if tcur == t {
continue collect
}
}
typehash[t.hash] = append(tlist, t)
}
if md.typemap == nil {
// If any of this module's typelinks match a type from a
// prior module, prefer that prior type by adding the offset
// to this module's typemap.
tm := make(map[typeOff]*_type, len(md.typelinks))
pinnedTypemaps = append(pinnedTypemaps, tm)
md.typemap = tm
for _, tl := range md.typelinks {
t := (*_type)(unsafe.Pointer(md.types + uintptr(tl)))
for _, candidate := range typehash[t.hash] {
seen := map[_typePair]struct{}{}
if typesEqual(t, candidate, seen) {
t = candidate
break
}
}
md.typemap[typeOff(tl)] = t
}
}
prev = md
}
}
type _typePair struct {
t1 *_type
t2 *_type
}
// typesEqual reports whether two types are equal.
//
// Everywhere in the runtime and reflect packages, it is assumed that
// there is exactly one *_type per Go type, so that pointer equality
// can be used to test if types are equal. There is one place that
// breaks this assumption: buildmode=shared. In this case a type can
// appear as two different pieces of memory. This is hidden from the
// runtime and reflect package by the per-module typemap built in
// typelinksinit. It uses typesEqual to map types from later modules
// back into earlier ones.
//
// Only typelinksinit needs this function.
func typesEqual(t, v *_type, seen map[_typePair]struct{}) bool {
tp := _typePair{t, v}
if _, ok := seen[tp]; ok {
return true
}
// Mark these types as seen (and thus equivalent), which prevents an
// infinite loop if the two types are identical but recursively defined
// and loaded from different modules.
seen[tp] = struct{}{}
if t == v {
return true
}
kind := t.kind & kindMask
if kind != v.kind&kindMask {
return false
}
if t.string() != v.string() {
return false
}
ut := t.uncommon()
uv := v.uncommon()
if ut != nil || uv != nil {
if ut == nil || uv == nil {
return false
}
pkgpatht := t.nameOff(ut.pkgpath).name()
pkgpathv := v.nameOff(uv.pkgpath).name()
if pkgpatht != pkgpathv {
return false
}
}
if kindBool <= kind && kind <= kindComplex128 {
return true
}
switch kind {
case kindString, kindUnsafePointer:
return true
case kindArray:
at := (*arraytype)(unsafe.Pointer(t))
av := (*arraytype)(unsafe.Pointer(v))
return typesEqual(at.elem, av.elem, seen) && at.len == av.len
case kindChan:
ct := (*chantype)(unsafe.Pointer(t))
cv := (*chantype)(unsafe.Pointer(v))
return ct.dir == cv.dir && typesEqual(ct.elem, cv.elem, seen)
case kindFunc:
ft := (*functype)(unsafe.Pointer(t))
fv := (*functype)(unsafe.Pointer(v))
if ft.outCount != fv.outCount || ft.inCount != fv.inCount {
return false
}
tin, vin := ft.in(), fv.in()
for i := 0; i < len(tin); i++ {
if !typesEqual(tin[i], vin[i], seen) {
return false
}
}
tout, vout := ft.out(), fv.out()
for i := 0; i < len(tout); i++ {
if !typesEqual(tout[i], vout[i], seen) {
return false
}
}
return true
case kindInterface:
it := (*interfacetype)(unsafe.Pointer(t))
iv := (*interfacetype)(unsafe.Pointer(v))
if it.pkgpath.name() != iv.pkgpath.name() {
return false
}
if len(it.mhdr) != len(iv.mhdr) {
return false
}
for i := range it.mhdr {
tm := &it.mhdr[i]
vm := &iv.mhdr[i]
// Note the mhdr array can be relocated from
// another module. See #17724.
tname := resolveNameOff(unsafe.Pointer(tm), tm.name)
vname := resolveNameOff(unsafe.Pointer(vm), vm.name)
if tname.name() != vname.name() {
return false
}
if tname.pkgPath() != vname.pkgPath() {
return false
}
tityp := resolveTypeOff(unsafe.Pointer(tm), tm.ityp)
vityp := resolveTypeOff(unsafe.Pointer(vm), vm.ityp)
if !typesEqual(tityp, vityp, seen) {
return false
}
}
return true
case kindMap:
mt := (*maptype)(unsafe.Pointer(t))
mv := (*maptype)(unsafe.Pointer(v))
return typesEqual(mt.key, mv.key, seen) && typesEqual(mt.elem, mv.elem, seen)
case kindPtr:
pt := (*ptrtype)(unsafe.Pointer(t))
pv := (*ptrtype)(unsafe.Pointer(v))
return typesEqual(pt.elem, pv.elem, seen)
case kindSlice:
st := (*slicetype)(unsafe.Pointer(t))
sv := (*slicetype)(unsafe.Pointer(v))
return typesEqual(st.elem, sv.elem, seen)
case kindStruct:
st := (*structtype)(unsafe.Pointer(t))
sv := (*structtype)(unsafe.Pointer(v))
if len(st.fields) != len(sv.fields) {
return false
}
if st.pkgPath.name() != sv.pkgPath.name() {
return false
}
for i := range st.fields {
tf := &st.fields[i]
vf := &sv.fields[i]
if tf.name.name() != vf.name.name() {
return false
}
if !typesEqual(tf.typ, vf.typ, seen) {
return false
}
if tf.name.tag() != vf.name.tag() {
return false
}
if tf.offset != vf.offset {
return false
}
if tf.name.isEmbedded() != vf.name.isEmbedded() {
return false
}
}
return true
default:
println("runtime: impossible type kind", kind)
throw("runtime: impossible type kind")
return false
}
}
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
const (
kindBool = 1 + iota
kindInt
kindInt8
kindInt16
kindInt32
kindInt64
kindUint
kindUint8
kindUint16
kindUint32
kindUint64
kindUintptr
kindFloat32
kindFloat64
kindComplex64
kindComplex128
kindArray
kindChan
kindFunc
kindInterface
kindMap
kindPtr
kindSlice
kindString
kindStruct
kindUnsafePointer
kindDirectIface = 1 << 5
kindGCProg = 1 << 6
kindMask = (1 << 5) - 1
)
// isDirectIface reports whether t is stored directly in an interface value.
func isDirectIface(t *_type) bool {
return t.kind&kindDirectIface != 0
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
import (
"runtime/internal/math"
"unsafe"
)
func unsafestring(ptr unsafe.Pointer, len int) {
if len < 0 {
panicunsafestringlen()
}
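// Check that ptr+len does not wrap past the top of the address space:
// -uintptr(ptr) is the number of addressable bytes from ptr to the end
// of the address space.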
if uintptr(len) > -uintptr(ptr) {
if ptr == nil {
panicunsafestringnilptr()
}
panicunsafestringlen()
}
}
// Keep this code in sync with cmd/compile/internal/walk/builtin.go:walkUnsafeString
func unsafestring64(ptr unsafe.Pointer, len64 int64) {
len := int(len64)
if int64(len) != len64 {
panicunsafestringlen()
}
unsafestring(ptr, len)
}
func unsafestringcheckptr(ptr unsafe.Pointer, len64 int64) {
unsafestring64(ptr, len64)
// Check that underlying array doesn't straddle multiple heap objects.
// unsafestring64 has already checked for overflow.
if checkptrStraddles(ptr, uintptr(len64)) {
throw("checkptr: unsafe.String result straddles multiple allocations")
}
}
func panicunsafestringlen() {
panic(errorString("unsafe.String: len out of range"))
}
func panicunsafestringnilptr() {
panic(errorString("unsafe.String: ptr is nil and len is not zero"))
}
// Keep this code in sync with cmd/compile/internal/walk/builtin.go:walkUnsafeSlice
func unsafeslice(et *_type, ptr unsafe.Pointer, len int) {
if len < 0 {
panicunsafeslicelen()
}
if et.size == 0 {
if ptr == nil && len > 0 {
panicunsafeslicenilptr()
}
}
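// Check that the total byte size does not overflow and that ptr+mem
// does not wrap past the top of the address space (-uintptr(ptr) is the
// number of bytes from ptr to the end of the address space).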
mem, overflow := math.MulUintptr(et.size, uintptr(len))
if overflow || mem > -uintptr(ptr) {
if ptr == nil {
panicunsafeslicenilptr()
}
panicunsafeslicelen()
}
}
// Keep this code in sync with cmd/compile/internal/walk/builtin.go:walkUnsafeSlice
func unsafeslice64(et *_type, ptr unsafe.Pointer, len64 int64) {
len := int(len64)
if int64(len) != len64 {
panicunsafeslicelen()
}
unsafeslice(et, ptr, len)
}
func unsafeslicecheckptr(et *_type, ptr unsafe.Pointer, len64 int64) {
unsafeslice64(et, ptr, len64)
// Check that underlying array doesn't straddle multiple heap objects.
// unsafeslice64 has already checked for overflow.
if checkptrStraddles(ptr, uintptr(len64)*et.size) {
throw("checkptr: unsafe.Slice result straddles multiple allocations")
}
}
func panicunsafeslicelen() {
panic(errorString("unsafe.Slice: len out of range"))
}
func panicunsafeslicenilptr() {
panic(errorString("unsafe.Slice: ptr is nil and len is not zero"))
}
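// Illustrative sketch (editor's addition, not from the original source): the
// checks above back the unsafe.Slice and unsafe.String builtins. For example:
//
//	arr := [4]byte{1, 2, 3, 4}
//	s := unsafe.Slice(&arr[0], 4)      // ok: non-negative len, one allocation
//	_ = unsafe.Slice((*byte)(nil), 1)  // panics: ptr is nil and len is not zero
//	_ = unsafe.String(&arr[0], -1)     // panics: len out of range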
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package runtime
// Numbers fundamental to the encoding.
const (
runeError = '\uFFFD' // the "error" Rune or "Unicode replacement character"
runeSelf = 0x80 // characters below runeSelf are represented as themselves in a single byte.
maxRune = '\U0010FFFF' // Maximum valid Unicode code point.
)
// Code points in the surrogate range are not valid for UTF-8.
const (
surrogateMin = 0xD800
surrogateMax = 0xDFFF
)
const (
t1 = 0x00 // 0000 0000
tx = 0x80 // 1000 0000
t2 = 0xC0 // 1100 0000
t3 = 0xE0 // 1110 0000
t4 = 0xF0 // 1111 0000
t5 = 0xF8 // 1111 1000
maskx = 0x3F // 0011 1111
mask2 = 0x1F // 0001 1111
mask3 = 0x0F // 0000 1111
mask4 = 0x07 // 0000 0111
rune1Max = 1<<7 - 1
rune2Max = 1<<11 - 1
rune3Max = 1<<16 - 1
// The default lowest and highest continuation byte.
locb = 0x80 // 1000 0000
hicb = 0xBF // 1011 1111
)
// countrunes returns the number of runes in s.
func countrunes(s string) int {
n := 0
for range s {
n++
}
return n
}
// decoderune returns the non-ASCII rune at the start of
// s[k:] and the index after the rune in s.
//
// decoderune assumes that the caller has checked that
// the rune to be decoded is a non-ASCII rune.
//
// If the string appears to be incomplete or decoding problems
// are encountered, (runeError, k + 1) is returned to ensure
// progress when decoderune is used to iterate over a string.
func decoderune(s string, k int) (r rune, pos int) {
pos = k
if k >= len(s) {
return runeError, k + 1
}
s = s[k:]
switch {
case t2 <= s[0] && s[0] < t3:
// 0080-07FF two byte sequence
if len(s) > 1 && (locb <= s[1] && s[1] <= hicb) {
r = rune(s[0]&mask2)<<6 | rune(s[1]&maskx)
pos += 2
if rune1Max < r {
return
}
}
case t3 <= s[0] && s[0] < t4:
// 0800-FFFF three byte sequence
if len(s) > 2 && (locb <= s[1] && s[1] <= hicb) && (locb <= s[2] && s[2] <= hicb) {
r = rune(s[0]&mask3)<<12 | rune(s[1]&maskx)<<6 | rune(s[2]&maskx)
pos += 3
if rune2Max < r && !(surrogateMin <= r && r <= surrogateMax) {
return
}
}
case t4 <= s[0] && s[0] < t5:
// 10000-1FFFFF four byte sequence
if len(s) > 3 && (locb <= s[1] && s[1] <= hicb) && (locb <= s[2] && s[2] <= hicb) && (locb <= s[3] && s[3] <= hicb) {
r = rune(s[0]&mask4)<<18 | rune(s[1]&maskx)<<12 | rune(s[2]&maskx)<<6 | rune(s[3]&maskx)
pos += 4
if rune3Max < r && r <= maxRune {
return
}
}
}
return runeError, k + 1
}
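// Worked example (editor's addition): the two-byte sequence 0xC3 0xA9
// encodes U+00E9 ('é'). With s[0] = 0xC3 and s[1] = 0xA9:
//
//	r = rune(0xC3&mask2)<<6 | rune(0xA9&maskx)
//	  = rune(0x03)<<6 | rune(0x29)
//	  = 0xE9
//
// The rune1Max check then rejects over-long encodings such as 0xC0 0x80,
// whose decoded value would fit in a single byte.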
// encoderune writes into p (which must be large enough) the UTF-8 encoding of the rune.
// It returns the number of bytes written.
func encoderune(p []byte, r rune) int {
// Negative values are erroneous. Making it unsigned addresses the problem.
switch i := uint32(r); {
case i <= rune1Max:
p[0] = byte(r)
return 1
case i <= rune2Max:
_ = p[1] // eliminate bounds checks
p[0] = t2 | byte(r>>6)
p[1] = tx | byte(r)&maskx
return 2
case i > maxRune, surrogateMin <= i && i <= surrogateMax:
r = runeError
fallthrough
case i <= rune3Max:
_ = p[2] // eliminate bounds checks
p[0] = t3 | byte(r>>12)
p[1] = tx | byte(r>>6)&maskx
p[2] = tx | byte(r)&maskx
return 3
default:
_ = p[3] // eliminate bounds checks
p[0] = t4 | byte(r>>18)
p[1] = tx | byte(r>>12)&maskx
p[2] = tx | byte(r>>6)&maskx
p[3] = tx | byte(r)&maskx
return 4
}
}
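// Worked example (editor's addition): encoding U+20AC ('€') takes the
// three-byte branch above and produces 0xE2 0x82 0xAC:
//
//	p[0] = t3 | byte(0x20AC>>12)       = 0xE0 | 0x02 = 0xE2
//	p[1] = tx | byte(0x20AC>>6)&maskx  = 0x80 | 0x02 = 0x82
//	p[2] = tx | byte(0x20AC)&maskx     = 0x80 | 0x2C = 0xAC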
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build linux && (386 || amd64 || arm || arm64 || loong64 || mips64 || mips64le || ppc64 || ppc64le || riscv64 || s390x)
package runtime
import "unsafe"
// Look up symbols in the Linux vDSO.
// This code was originally based on the sample Linux vDSO parser at
// https://git.kernel.org/cgit/linux/kernel/git/torvalds/linux.git/tree/tools/testing/selftests/vDSO/parse_vdso.c
// This implements the ELF dynamic linking spec at
// http://sco.com/developers/gabi/latest/ch5.dynamic.html
// The version section is documented at
// https://refspecs.linuxfoundation.org/LSB_3.2.0/LSB-Core-generic/LSB-Core-generic/symversion.html
const (
_AT_SYSINFO_EHDR = 33
_PT_LOAD = 1 /* Loadable program segment */
_PT_DYNAMIC = 2 /* Dynamic linking information */
_DT_NULL = 0 /* Marks end of dynamic section */
_DT_HASH = 4 /* Dynamic symbol hash table */
_DT_STRTAB = 5 /* Address of string table */
_DT_SYMTAB = 6 /* Address of symbol table */
_DT_GNU_HASH = 0x6ffffef5 /* GNU-style dynamic symbol hash table */
_DT_VERSYM = 0x6ffffff0
_DT_VERDEF = 0x6ffffffc
_VER_FLG_BASE = 0x1 /* Version definition of file itself */
_SHN_UNDEF = 0 /* Undefined section */
_SHT_DYNSYM = 11 /* Dynamic linker symbol table */
_STT_FUNC = 2 /* Symbol is a code object */
_STT_NOTYPE = 0 /* Symbol type is not specified */
_STB_GLOBAL = 1 /* Global symbol */
_STB_WEAK = 2 /* Weak symbol */
_EI_NIDENT = 16
// Maximum indices for the array types used when traversing the vDSO ELF structures.
// Computed from architecture-specific max provided by vdso_linux_*.go
vdsoSymTabSize = vdsoArrayMax / unsafe.Sizeof(elfSym{})
vdsoDynSize = vdsoArrayMax / unsafe.Sizeof(elfDyn{})
vdsoSymStringsSize = vdsoArrayMax // byte
vdsoVerSymSize = vdsoArrayMax / 2 // uint16
vdsoHashSize = vdsoArrayMax / 4 // uint32
// vdsoBloomSizeScale is a scaling factor for gnuhash tables which are uint32 indexed,
// but contain uintptrs
vdsoBloomSizeScale = unsafe.Sizeof(uintptr(0)) / 4 // uint32
)
/* How to extract and insert information held in the st_info field. */
func _ELF_ST_BIND(val byte) byte { return val >> 4 }
func _ELF_ST_TYPE(val byte) byte { return val & 0xf }
type vdsoSymbolKey struct {
name string
symHash uint32
gnuHash uint32
ptr *uintptr
}
type vdsoVersionKey struct {
version string
verHash uint32
}
type vdsoInfo struct {
valid bool
/* Load information */
loadAddr uintptr
loadOffset uintptr /* loadAddr - recorded vaddr */
/* Symbol table */
symtab *[vdsoSymTabSize]elfSym
symstrings *[vdsoSymStringsSize]byte
chain []uint32
bucket []uint32
symOff uint32
isGNUHash bool
/* Version table */
versym *[vdsoVerSymSize]uint16
verdef *elfVerdef
}
// see vdso_linux_*.go for vdsoSymbolKeys[] and vdso*Sym vars
func vdsoInitFromSysinfoEhdr(info *vdsoInfo, hdr *elfEhdr) {
info.valid = false
info.loadAddr = uintptr(unsafe.Pointer(hdr))
pt := unsafe.Pointer(info.loadAddr + uintptr(hdr.e_phoff))
// We need two things from the segment table: the load offset
// and the dynamic table.
var foundVaddr bool
var dyn *[vdsoDynSize]elfDyn
for i := uint16(0); i < hdr.e_phnum; i++ {
pt := (*elfPhdr)(add(pt, uintptr(i)*unsafe.Sizeof(elfPhdr{})))
switch pt.p_type {
case _PT_LOAD:
if !foundVaddr {
foundVaddr = true
info.loadOffset = info.loadAddr + uintptr(pt.p_offset-pt.p_vaddr)
}
case _PT_DYNAMIC:
dyn = (*[vdsoDynSize]elfDyn)(unsafe.Pointer(info.loadAddr + uintptr(pt.p_offset)))
}
}
if !foundVaddr || dyn == nil {
return // Failed
}
// Fish out the useful bits of the dynamic table.
var hash, gnuhash *[vdsoHashSize]uint32
info.symstrings = nil
info.symtab = nil
info.versym = nil
info.verdef = nil
for i := 0; dyn[i].d_tag != _DT_NULL; i++ {
dt := &dyn[i]
p := info.loadOffset + uintptr(dt.d_val)
switch dt.d_tag {
case _DT_STRTAB:
info.symstrings = (*[vdsoSymStringsSize]byte)(unsafe.Pointer(p))
case _DT_SYMTAB:
info.symtab = (*[vdsoSymTabSize]elfSym)(unsafe.Pointer(p))
case _DT_HASH:
hash = (*[vdsoHashSize]uint32)(unsafe.Pointer(p))
case _DT_GNU_HASH:
gnuhash = (*[vdsoHashSize]uint32)(unsafe.Pointer(p))
case _DT_VERSYM:
info.versym = (*[vdsoVerSymSize]uint16)(unsafe.Pointer(p))
case _DT_VERDEF:
info.verdef = (*elfVerdef)(unsafe.Pointer(p))
}
}
if info.symstrings == nil || info.symtab == nil || (hash == nil && gnuhash == nil) {
return // Failed
}
if info.verdef == nil {
info.versym = nil
}
if gnuhash != nil {
// Parse the GNU hash table header.
nbucket := gnuhash[0]
info.symOff = gnuhash[1]
bloomSize := gnuhash[2]
info.bucket = gnuhash[4+bloomSize*uint32(vdsoBloomSizeScale):][:nbucket]
info.chain = gnuhash[4+bloomSize*uint32(vdsoBloomSizeScale)+nbucket:]
info.isGNUHash = true
} else {
// Parse the hash table header.
nbucket := hash[0]
nchain := hash[1]
info.bucket = hash[2 : 2+nbucket]
info.chain = hash[2+nbucket : 2+nbucket+nchain]
}
// That's all we need.
info.valid = true
}
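// Layout of the GNU hash section as consumed above (editor's note,
// reconstructed from the parsing code; not part of the original source):
//
//	gnuhash[0]  nbucket
//	gnuhash[1]  symOff: index of the first hashed symbol
//	gnuhash[2]  bloomSize, in uintptr words
//	gnuhash[3]  bloomShift (not used here)
//	then        bloomSize*vdsoBloomSizeScale uint32 words of Bloom filter
//	then        nbucket uint32 bucket entries
//	then        one uint32 hash per symbol; the low bit marks the end of a chain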
func vdsoFindVersion(info *vdsoInfo, ver *vdsoVersionKey) int32 {
if !info.valid {
return 0
}
def := info.verdef
for {
if def.vd_flags&_VER_FLG_BASE == 0 {
aux := (*elfVerdaux)(add(unsafe.Pointer(def), uintptr(def.vd_aux)))
if def.vd_hash == ver.verHash && ver.version == gostringnocopy(&info.symstrings[aux.vda_name]) {
return int32(def.vd_ndx & 0x7fff)
}
}
if def.vd_next == 0 {
break
}
def = (*elfVerdef)(add(unsafe.Pointer(def), uintptr(def.vd_next)))
}
return -1 // cannot match any version
}
func vdsoParseSymbols(info *vdsoInfo, version int32) {
if !info.valid {
return
}
apply := func(symIndex uint32, k vdsoSymbolKey) bool {
sym := &info.symtab[symIndex]
typ := _ELF_ST_TYPE(sym.st_info)
bind := _ELF_ST_BIND(sym.st_info)
// On ppc64x, VDSO functions are of type _STT_NOTYPE.
if typ != _STT_FUNC && typ != _STT_NOTYPE || bind != _STB_GLOBAL && bind != _STB_WEAK || sym.st_shndx == _SHN_UNDEF {
return false
}
if k.name != gostringnocopy(&info.symstrings[sym.st_name]) {
return false
}
// Check symbol version.
if info.versym != nil && version != 0 && int32(info.versym[symIndex]&0x7fff) != version {
return false
}
*k.ptr = info.loadOffset + uintptr(sym.st_value)
return true
}
if !info.isGNUHash {
// Old-style DT_HASH table.
for _, k := range vdsoSymbolKeys {
if len(info.bucket) > 0 {
for chain := info.bucket[k.symHash%uint32(len(info.bucket))]; chain != 0; chain = info.chain[chain] {
if apply(chain, k) {
break
}
}
}
}
return
}
// New-style DT_GNU_HASH table.
for _, k := range vdsoSymbolKeys {
symIndex := info.bucket[k.gnuHash%uint32(len(info.bucket))]
if symIndex < info.symOff {
continue
}
for ; ; symIndex++ {
hash := info.chain[symIndex-info.symOff]
if hash|1 == k.gnuHash|1 {
// Found a hash match.
if apply(symIndex, k) {
break
}
}
if hash&1 != 0 {
// End of chain.
break
}
}
}
}
func vdsoauxv(tag, val uintptr) {
switch tag {
case _AT_SYSINFO_EHDR:
if val == 0 {
// Something went wrong
return
}
var info vdsoInfo
// TODO(rsc): I don't understand why the compiler thinks info escapes
// when passed to the three functions below.
info1 := (*vdsoInfo)(noescape(unsafe.Pointer(&info)))
vdsoInitFromSysinfoEhdr(info1, (*elfEhdr)(unsafe.Pointer(val)))
vdsoParseSymbols(info1, vdsoFindVersion(info1, &vdsoLinuxVersion))
}
}
// inVDSOPage reports whether pc is on the VDSO page.
//
//go:nosplit
func inVDSOPage(pc uintptr) bool {
for _, k := range vdsoSymbolKeys {
if *k.ptr != 0 {
page := *k.ptr &^ (physPageSize - 1)
return pc >= page && pc < page+physPageSize
}
}
return false
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !android
package runtime
import "unsafe"
func writeErr(b []byte) {
write(2, unsafe.Pointer(&b[0]), int32(len(b)))
}
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package slices defines various functions useful with slices of any type.
package slices
// Equal reports whether two slices are equal: the same length and all
// elements equal. If the lengths are different, Equal returns false.
// Otherwise, the elements are compared in increasing index order, and the
// comparison stops at the first unequal pair.
// Floating point NaNs are not considered equal.
func Equal[E comparable](s1, s2 []E) bool {
if len(s1) != len(s2) {
return false
}
for i := range s1 {
if s1[i] != s2[i] {
return false
}
}
return true
}
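// Illustrative usage (editor's addition):
//
//	Equal([]int{1, 2, 3}, []int{1, 2, 3})                // true
//	Equal([]int{1, 2}, []int{1, 2, 3})                   // false: lengths differ
//	Equal([]float64{math.NaN()}, []float64{math.NaN()})  // false: NaN != NaN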
// EqualFunc reports whether two slices are equal using a comparison
// function on each pair of elements. If the lengths are different,
// EqualFunc returns false. Otherwise, the elements are compared in
// increasing index order, and the comparison stops at the first index
// for which eq returns false.
func EqualFunc[E1, E2 any](s1 []E1, s2 []E2, eq func(E1, E2) bool) bool {
if len(s1) != len(s2) {
return false
}
for i, v1 := range s1 {
v2 := s2[i]
if !eq(v1, v2) {
return false
}
}
return true
}
// Index returns the index of the first occurrence of v in s,
// or -1 if not present.
func Index[E comparable](s []E, v E) int {
for i, vs := range s {
if v == vs {
return i
}
}
return -1
}
// IndexFunc returns the first index i satisfying f(s[i]),
// or -1 if none do.
func IndexFunc[E any](s []E, f func(E) bool) int {
for i, v := range s {
if f(v) {
return i
}
}
return -1
}
// Contains reports whether v is present in s.
func Contains[E comparable](s []E, v E) bool {
return Index(s, v) >= 0
}
// ContainsFunc reports whether at least one
// element e of s satisfies f(e).
func ContainsFunc[E any](s []E, f func(E) bool) bool {
return IndexFunc(s, f) >= 0
}
// Insert inserts the values v... into s at index i,
// returning the modified slice.
// The elements at s[i:] are shifted up to make room.
// In the returned slice r, r[i] == v[0],
// and r[i+len(v)] == value originally at r[i].
// Insert panics if i is out of range.
// This function is O(len(s) + len(v)).
func Insert[S ~[]E, E any](s S, i int, v ...E) S {
tot := len(s) + len(v)
if tot <= cap(s) {
s2 := s[:tot]
copy(s2[i+len(v):], s[i:])
copy(s2[i:], v)
return s2
}
s2 := make(S, tot)
copy(s2, s[:i])
copy(s2[i:], v)
copy(s2[i+len(v):], s[i:])
return s2
}
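// Illustrative usage (editor's addition):
//
//	s := []int{1, 2, 6}
//	s = Insert(s, 2, 3, 4, 5) // s is now [1 2 3 4 5 6]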
// Delete removes the elements s[i:j] from s, returning the modified slice.
// Delete panics if s[i:j] is not a valid slice of s.
// Delete modifies the contents of the slice s; it does not create a new slice.
// Delete is O(len(s)-j), so if many items must be deleted, it is better to
// make a single call deleting them all together than to delete one at a time.
// Delete might not modify the elements s[len(s)-(j-i):len(s)]. If those
// elements contain pointers you might consider zeroing those elements so that
// objects they reference can be garbage collected.
func Delete[S ~[]E, E any](s S, i, j int) S {
_ = s[i:j] // bounds check
return append(s[:i], s[j:]...)
}
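// Illustrative usage (editor's addition): the result aliases the input, so
// the original slice variable should be replaced by the return value.
//
//	s := []int{1, 2, 3, 4, 5}
//	s = Delete(s, 1, 3) // s is now [1 4 5]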
// Replace replaces the elements s[i:j] by the given v, and returns the
// modified slice. Replace panics if s[i:j] is not a valid slice of s.
func Replace[S ~[]E, E any](s S, i, j int, v ...E) S {
_ = s[i:j] // verify that i:j is a valid subslice
tot := len(s[:i]) + len(v) + len(s[j:])
if tot <= cap(s) {
s2 := s[:tot]
copy(s2[i+len(v):], s[j:])
copy(s2[i:], v)
return s2
}
s2 := make(S, tot)
copy(s2, s[:i])
copy(s2[i:], v)
copy(s2[i+len(v):], s[j:])
return s2
}
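// Illustrative usage (editor's addition):
//
//	s := []string{"a", "b", "c", "d"}
//	s = Replace(s, 1, 3, "x") // s is now [a x d]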
// Clone returns a copy of the slice.
// The elements are copied using assignment, so this is a shallow clone.
func Clone[S ~[]E, E any](s S) S {
// Preserve nil in case it matters.
if s == nil {
return nil
}
return append(S([]E{}), s...)
}
// Compact replaces consecutive runs of equal elements with a single copy.
// This is like the uniq command found on Unix.
// Compact modifies the contents of the slice s; it does not create a new slice.
// When Compact discards m elements in total, it might not modify the elements
// s[len(s)-m:len(s)]. If those elements contain pointers you might consider
// zeroing those elements so that objects they reference can be garbage collected.
func Compact[S ~[]E, E comparable](s S) S {
if len(s) < 2 {
return s
}
i := 1
last := s[0]
for _, v := range s[1:] {
if v != last {
s[i] = v
i++
last = v
}
}
return s[:i]
}
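// Illustrative usage (editor's addition): Compact only collapses consecutive
// runs, so a slice is typically sorted before being compacted.
//
//	s := []int{1, 1, 2, 2, 2, 3}
//	s = Compact(s) // s is now [1 2 3]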
// CompactFunc is like Compact but uses a comparison function.
func CompactFunc[S ~[]E, E any](s S, eq func(E, E) bool) S {
if len(s) < 2 {
return s
}
i := 1
last := s[0]
for _, v := range s[1:] {
if !eq(v, last) {
s[i] = v
i++
last = v
}
}
return s[:i]
}
// Grow increases the slice's capacity, if necessary, to guarantee space for
// another n elements. After Grow(n), at least n elements can be appended
// to the slice without another allocation. If n is negative or too large to
// allocate the memory, Grow panics.
func Grow[S ~[]E, E any](s S, n int) S {
if n < 0 {
panic("cannot be negative")
}
if n -= cap(s) - len(s); n > 0 {
s = append(s[:cap(s)], make([]E, n)...)[:len(s)]
}
return s
}
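// Illustrative usage (editor's addition):
//
//	s := make([]int, 0, 2)
//	s = Grow(s, 10)
//	// len(s) is still 0 and cap(s) >= 10, so the next ten appends
//	// cannot trigger a reallocation.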
// Clip removes unused capacity from the slice, returning s[:len(s):len(s)].
func Clip[S ~[]E, E any](s S) S {
return s[:len(s):len(s)]
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements binary search.
package sort
// Search uses binary search to find and return the smallest index i
// in [0, n) at which f(i) is true, assuming that on the range [0, n),
// f(i) == true implies f(i+1) == true. That is, Search requires that
// f is false for some (possibly empty) prefix of the input range [0, n)
// and then true for the (possibly empty) remainder; Search returns
// the first true index. If there is no such index, Search returns n.
// (Note that the "not found" return value is not -1 as in, for instance,
// strings.Index.)
// Search calls f(i) only for i in the range [0, n).
//
// A common use of Search is to find the index i for a value x in
// a sorted, indexable data structure such as an array or slice.
// In this case, the argument f, typically a closure, captures the value
// to be searched for, and how the data structure is indexed and
// ordered.
//
// For instance, given a slice data sorted in ascending order,
// the call Search(len(data), func(i int) bool { return data[i] >= 23 })
// returns the smallest index i such that data[i] >= 23. If the caller
// wants to find whether 23 is in the slice, it must test data[i] == 23
// separately.
//
// Searching data sorted in descending order would use the <=
// operator instead of the >= operator.
//
// To complete the example above, the following code tries to find the value
// x in an integer slice data sorted in ascending order:
//
// x := 23
// i := sort.Search(len(data), func(i int) bool { return data[i] >= x })
// if i < len(data) && data[i] == x {
// // x is present at data[i]
// } else {
// // x is not present in data,
// // but i is the index where it would be inserted.
// }
//
// As a more whimsical example, this program guesses your number:
//
// func GuessingGame() {
// var s string
// fmt.Printf("Pick an integer from 0 to 100.\n")
// answer := sort.Search(100, func(i int) bool {
// fmt.Printf("Is your number <= %d? ", i)
// fmt.Scanf("%s", &s)
// return s != "" && s[0] == 'y'
// })
// fmt.Printf("Your number is %d.\n", answer)
// }
func Search(n int, f func(int) bool) int {
// Define f(-1) == false and f(n) == true.
// Invariant: f(i-1) == false, f(j) == true.
i, j := 0, n
for i < j {
h := int(uint(i+j) >> 1) // avoid overflow when computing h
// i ≤ h < j
if !f(h) {
i = h + 1 // preserves f(i-1) == false
} else {
j = h // preserves f(j) == true
}
}
// i == j, f(i-1) == false, and f(j) (= f(i)) == true => answer is i.
return i
}
// Find uses binary search to find and return the smallest index i in [0, n)
// at which cmp(i) <= 0. If there is no such index i, Find returns i = n.
// The found result is true if i < n and cmp(i) == 0.
// Find calls cmp(i) only for i in the range [0, n).
//
// To permit binary search, Find requires that cmp(i) > 0 for a leading
// prefix of the range, cmp(i) == 0 in the middle, and cmp(i) < 0 for
// the final suffix of the range. (Each subrange could be empty.)
// The usual way to establish this condition is to interpret cmp(i)
// as a comparison of a desired target value t against entry i in an
// underlying indexed data structure x, returning <0, 0, and >0
// when t < x[i], t == x[i], and t > x[i], respectively.
//
// For example, to look for a particular string in a sorted, random-access
// list of strings:
//
// i, found := sort.Find(x.Len(), func(i int) int {
// return strings.Compare(target, x.At(i))
// })
// if found {
// fmt.Printf("found %s at entry %d\n", target, i)
// } else {
// fmt.Printf("%s not found, would insert at %d", target, i)
// }
func Find(n int, cmp func(int) int) (i int, found bool) {
// The invariants here are similar to the ones in Search.
// Define cmp(-1) > 0 and cmp(n) <= 0
// Invariant: cmp(i-1) > 0, cmp(j) <= 0
i, j := 0, n
for i < j {
h := int(uint(i+j) >> 1) // avoid overflow when computing h
// i ≤ h < j
if cmp(h) > 0 {
i = h + 1 // preserves cmp(i-1) > 0
} else {
j = h // preserves cmp(j) <= 0
}
}
// i == j, cmp(i-1) > 0 and cmp(j) <= 0
return i, i < n && cmp(i) == 0
}
// Convenience wrappers for common cases.
// SearchInts searches for x in a sorted slice of ints and returns the index
// as specified by Search. The return value is the index to insert x if x is
// not present (it could be len(a)).
// The slice must be sorted in ascending order.
func SearchInts(a []int, x int) int {
return Search(len(a), func(i int) bool { return a[i] >= x })
}
// SearchFloat64s searches for x in a sorted slice of float64s and returns the index
// as specified by Search. The return value is the index to insert x if x is not
// present (it could be len(a)).
// The slice must be sorted in ascending order.
func SearchFloat64s(a []float64, x float64) int {
return Search(len(a), func(i int) bool { return a[i] >= x })
}
// SearchStrings searches for x in a sorted slice of strings and returns the index
// as specified by Search. The return value is the index to insert x if x is not
// present (it could be len(a)).
// The slice must be sorted in ascending order.
func SearchStrings(a []string, x string) int {
return Search(len(a), func(i int) bool { return a[i] >= x })
}
// Search returns the result of applying SearchInts to the receiver and x.
func (p IntSlice) Search(x int) int { return SearchInts(p, x) }
// Search returns the result of applying SearchFloat64s to the receiver and x.
func (p Float64Slice) Search(x float64) int { return SearchFloat64s(p, x) }
// Search returns the result of applying SearchStrings to the receiver and x.
func (p StringSlice) Search(x string) int { return SearchStrings(p, x) }
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sort
import (
"internal/reflectlite"
"math/bits"
)
// Slice sorts the slice x given the provided less function.
// It panics if x is not a slice.
//
// The sort is not guaranteed to be stable: equal elements
// may be reversed from their original order.
// For a stable sort, use SliceStable.
//
// The less function must satisfy the same requirements as
// the Interface type's Less method.
func Slice(x any, less func(i, j int) bool) {
rv := reflectlite.ValueOf(x)
swap := reflectlite.Swapper(x)
length := rv.Len()
limit := bits.Len(uint(length))
pdqsort_func(lessSwap{less, swap}, 0, length, limit)
}
// SliceStable sorts the slice x using the provided less
// function, keeping equal elements in their original order.
// It panics if x is not a slice.
//
// The less function must satisfy the same requirements as
// the Interface type's Less method.
func SliceStable(x any, less func(i, j int) bool) {
rv := reflectlite.ValueOf(x)
swap := reflectlite.Swapper(x)
stable_func(lessSwap{less, swap}, rv.Len())
}
// SliceIsSorted reports whether the slice x is sorted according to the provided less function.
// It panics if x is not a slice.
func SliceIsSorted(x any, less func(i, j int) bool) bool {
rv := reflectlite.ValueOf(x)
n := rv.Len()
for i := n - 1; i > 0; i-- {
if less(i, i-1) {
return false
}
}
return true
}
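// Illustrative usage (editor's addition), sorting a slice of structs by field:
//
//	people := []struct {
//		Name string
//		Age  int
//	}{{"Gopher", 7}, {"Alice", 55}, {"Bob", 24}}
//	sort.Slice(people, func(i, j int) bool { return people[i].Age < people[j].Age })
//	// people is now ordered Gopher, Bob, Alice.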
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:generate go run gen_sort_variants.go
// Package sort provides primitives for sorting slices and user-defined collections.
package sort
import "math/bits"
// An implementation of Interface can be sorted by the routines in this package.
// The methods refer to elements of the underlying collection by integer index.
type Interface interface {
// Len is the number of elements in the collection.
Len() int
// Less reports whether the element with index i
// must sort before the element with index j.
//
// If both Less(i, j) and Less(j, i) are false,
// then the elements at index i and j are considered equal.
// Sort may place equal elements in any order in the final result,
// while Stable preserves the original input order of equal elements.
//
// Less must describe a transitive ordering:
// - if both Less(i, j) and Less(j, k) are true, then Less(i, k) must be true as well.
// - if both Less(i, j) and Less(j, k) are false, then Less(i, k) must be false as well.
//
// Note that floating-point comparison (the < operator on float32 or float64 values)
// is not a transitive ordering when not-a-number (NaN) values are involved.
// See Float64Slice.Less for a correct implementation for floating-point values.
Less(i, j int) bool
// Swap swaps the elements with indexes i and j.
Swap(i, j int)
}
// Sort sorts data in ascending order as determined by the Less method.
// It makes one call to data.Len to determine n and O(n*log(n)) calls to
// data.Less and data.Swap. The sort is not guaranteed to be stable.
func Sort(data Interface) {
n := data.Len()
if n <= 1 {
return
}
limit := bits.Len(uint(n))
pdqsort(data, 0, n, limit)
}
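// Illustrative sketch (editor's addition): a user-defined type satisfying
// Interface; byLen is a hypothetical name, not part of this package.
//
//	type byLen []string
//
//	func (x byLen) Len() int           { return len(x) }
//	func (x byLen) Less(i, j int) bool { return len(x[i]) < len(x[j]) }
//	func (x byLen) Swap(i, j int)      { x[i], x[j] = x[j], x[i] }
//
//	sort.Sort(byLen{"ccc", "a", "bb"}) // yields [a bb ccc]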
type sortedHint int // hint for pdqsort when choosing the pivot
const (
unknownHint sortedHint = iota
increasingHint
decreasingHint
)
// xorshift paper: https://www.jstatsoft.org/article/view/v008i14/xorshift.pdf
type xorshift uint64
func (r *xorshift) Next() uint64 {
*r ^= *r << 13
*r ^= *r >> 17
*r ^= *r << 5
return uint64(*r)
}
func nextPowerOfTwo(length int) uint {
shift := uint(bits.Len(uint(length)))
return uint(1 << shift)
}
// lessSwap is a pair of Less and Swap functions for use with the
// auto-generated func-optimized variant of sort.go in
// zfuncversion.go.
type lessSwap struct {
Less func(i, j int) bool
Swap func(i, j int)
}
type reverse struct {
// This embedded Interface permits Reverse to use the methods of
// another Interface implementation.
Interface
}
// Less returns the opposite of the embedded implementation's Less method.
func (r reverse) Less(i, j int) bool {
return r.Interface.Less(j, i)
}
// Reverse returns the reverse order for data.
func Reverse(data Interface) Interface {
return &reverse{data}
}
// IsSorted reports whether data is sorted.
func IsSorted(data Interface) bool {
n := data.Len()
for i := n - 1; i > 0; i-- {
if data.Less(i, i-1) {
return false
}
}
return true
}
// Convenience types for common cases
// IntSlice attaches the methods of Interface to []int, sorting in increasing order.
type IntSlice []int
func (x IntSlice) Len() int { return len(x) }
func (x IntSlice) Less(i, j int) bool { return x[i] < x[j] }
func (x IntSlice) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
// Sort is a convenience method: x.Sort() calls Sort(x).
func (x IntSlice) Sort() { Sort(x) }
// Float64Slice implements Interface for a []float64, sorting in increasing order,
// with not-a-number (NaN) values ordered before other values.
type Float64Slice []float64
func (x Float64Slice) Len() int { return len(x) }
// Less reports whether x[i] should be ordered before x[j], as required by the sort Interface.
// Note that floating-point comparison by itself is not a transitive relation: it does not
// report a consistent ordering for not-a-number (NaN) values.
// This implementation of Less places NaN values before any others, by using:
//
// x[i] < x[j] || (math.IsNaN(x[i]) && !math.IsNaN(x[j]))
func (x Float64Slice) Less(i, j int) bool { return x[i] < x[j] || (isNaN(x[i]) && !isNaN(x[j])) }
func (x Float64Slice) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
// isNaN is a copy of math.IsNaN to avoid a dependency on the math package.
func isNaN(f float64) bool {
return f != f
}
// Sort is a convenience method: x.Sort() calls Sort(x).
func (x Float64Slice) Sort() { Sort(x) }
// StringSlice attaches the methods of Interface to []string, sorting in increasing order.
type StringSlice []string
func (x StringSlice) Len() int { return len(x) }
func (x StringSlice) Less(i, j int) bool { return x[i] < x[j] }
func (x StringSlice) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
// Sort is a convenience method: x.Sort() calls Sort(x).
func (x StringSlice) Sort() { Sort(x) }
// Convenience wrappers for common cases
// Ints sorts a slice of ints in increasing order.
func Ints(x []int) { Sort(IntSlice(x)) }
// Float64s sorts a slice of float64s in increasing order.
// Not-a-number (NaN) values are ordered before other values.
func Float64s(x []float64) { Sort(Float64Slice(x)) }
// Strings sorts a slice of strings in increasing order.
func Strings(x []string) { Sort(StringSlice(x)) }
// IntsAreSorted reports whether the slice x is sorted in increasing order.
func IntsAreSorted(x []int) bool { return IsSorted(IntSlice(x)) }
// Float64sAreSorted reports whether the slice x is sorted in increasing order,
// with not-a-number (NaN) values before any other values.
func Float64sAreSorted(x []float64) bool { return IsSorted(Float64Slice(x)) }
// StringsAreSorted reports whether the slice x is sorted in increasing order.
func StringsAreSorted(x []string) bool { return IsSorted(StringSlice(x)) }
// Notes on stable sorting:
// The algorithms used are simple and provably correct on all inputs and use
// only logarithmic additional stack space. They perform well when compared
// experimentally to other stable in-place sorting algorithms.
//
// Remarks on other algorithms evaluated:
// - GCC's 4.6.3 stable_sort with merge_without_buffer from libstdc++:
// Not faster.
// - GCC's __rotate for block rotations: Not faster.
// - "Practical in-place mergesort" from Jyrki Katajainen, Tomi A. Pasanen
// and Jukka Teuhola; Nordic Journal of Computing 3,1 (1996), 27-40:
// The given algorithms are in-place; the number of Swaps and assignments
// grows as n log n, but the algorithms are not stable.
// - "Fast Stable In-Place Sorting with O(n) Data Moves" J.I. Munro and
// V. Raman in Algorithmica (1996) 16, 115-160:
// This algorithm either needs an additional 2n bits or works only if there
// are enough different elements available to encode some permutations
// which have to be undone later (so not stable on any input).
// - All the optimal in-place sorting/merging algorithms I found are either
// unstable or rely on enough different elements in each step to encode the
// performed block rearrangements. See also "In-Place Merging Algorithms",
// Denham Coates-Evely, Department of Computer Science, Kings College,
// January 2004 and the references in there.
// - Often "optimal" algorithms are optimal in the number of assignments
// but Interface has only Swap as operation.
// Stable sorts data in ascending order as determined by the Less method,
// while keeping the original order of equal elements.
//
// It makes one call to data.Len to determine n, O(n*log(n)) calls to
// data.Less and O(n*log(n)*log(n)) calls to data.Swap.
func Stable(data Interface) {
stable(data, data.Len())
}
/*
Complexity of Stable Sorting
Complexity of block swapping rotation
Each Swap puts one new element into its correct, final position.
Elements which reach their final position are no longer moved.
Thus block swapping rotation needs |u|+|v| calls to Swap.
This is best possible as each element might need a move.
Pay attention when comparing to other optimal algorithms which
typically count the number of assignments instead of swaps:
E.g. the optimal algorithm of Dudzinski and Dydek for in-place
rotations uses O(u + v + gcd(u,v)) assignments which is
better than our O(3 * (u+v)) as gcd(u,v) <= u.
Stable sorting by SymMerge and BlockSwap rotations
SymMerge complexity for same size input M = N:
Calls to Less: O(M*log(N/M+1)) = O(N*log(2)) = O(N)
Calls to Swap: O((M+N)*log(M)) = O(2*N*log(N)) = O(N*log(N))
(The following argument does not fuss over a missing -1 or
other details which do not affect the final result.)
Let n = data.Len(). Assume n = 2^k.
Plain merge sort performs log(n) = k iterations.
On iteration i the algorithm merges 2^(k-i) blocks, each of size 2^i.
Thus iteration i of merge sort performs:
Calls to Less O(2^(k-i) * 2^i) = O(2^k) = O(2^log(n)) = O(n)
Calls to Swap O(2^(k-i) * 2^i * log(2^i)) = O(2^k * i) = O(n*i)
In total k = log(n) iterations are performed; so in total:
Calls to Less O(log(n) * n)
Calls to Swap O(n + 2*n + 3*n + ... + (k-1)*n + k*n)
= O((k/2) * k * n) = O(n * k^2) = O(n * log^2(n))
Above results should generalize to arbitrary n = 2^k + p
and should not be influenced by the initial insertion sort phase:
Insertion sort is O(n^2) on Swap and Less, thus O(bs^2) per block of
size bs at n/bs blocks: O(bs*n) Swaps and Less during insertion sort.
Merge sort iterations start at i = log(bs). With t = log(bs) constant:
Calls to Less O((log(n)-t) * n + bs*n) = O(log(n)*n + (bs-t)*n)
= O(n * log(n))
Calls to Swap O(n * log^2(n) - (t^2+t)/2*n) = O(n * log^2(n))
*/
// Code generated by gen_sort_variants.go; DO NOT EDIT.
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sort
// insertionSort_func sorts data[a:b] using insertion sort.
func insertionSort_func(data lessSwap, a, b int) {
for i := a + 1; i < b; i++ {
for j := i; j > a && data.Less(j, j-1); j-- {
data.Swap(j, j-1)
}
}
}
// siftDown_func implements the heap property on data[lo:hi].
// first is an offset into the array where the root of the heap lies.
func siftDown_func(data lessSwap, lo, hi, first int) {
root := lo
for {
child := 2*root + 1
if child >= hi {
break
}
if child+1 < hi && data.Less(first+child, first+child+1) {
child++
}
if !data.Less(first+root, first+child) {
return
}
data.Swap(first+root, first+child)
root = child
}
}
func heapSort_func(data lessSwap, a, b int) {
first := a
lo := 0
hi := b - a
// Build heap with greatest element at top.
for i := (hi - 1) / 2; i >= 0; i-- {
siftDown_func(data, i, hi, first)
}
// Pop elements, largest first, into end of data.
for i := hi - 1; i >= 0; i-- {
data.Swap(first, first+i)
siftDown_func(data, lo, i, first)
}
}
// pdqsort_func sorts data[a:b].
// The algorithm is based on pattern-defeating quicksort (pdqsort), but without the optimizations from BlockQuicksort.
// pdqsort paper: https://arxiv.org/pdf/2106.05123.pdf
// C++ implementation: https://github.com/orlp/pdqsort
// Rust implementation: https://docs.rs/pdqsort/latest/pdqsort/
// limit is the number of allowed bad (very unbalanced) pivots before falling back to heapsort.
func pdqsort_func(data lessSwap, a, b, limit int) {
const maxInsertion = 12
var (
wasBalanced = true // whether the last partitioning was reasonably balanced
wasPartitioned = true // whether the slice was already partitioned
)
for {
length := b - a
if length <= maxInsertion {
insertionSort_func(data, a, b)
return
}
// Fall back to heapsort if too many bad choices were made.
if limit == 0 {
heapSort_func(data, a, b)
return
}
// If the last partitioning was imbalanced, we need to break patterns.
if !wasBalanced {
breakPatterns_func(data, a, b)
limit--
}
pivot, hint := choosePivot_func(data, a, b)
if hint == decreasingHint {
reverseRange_func(data, a, b)
// The chosen pivot was pivot-a elements after the start of the array.
// After reversing it is pivot-a elements before the end of the array.
// The idea came from Rust's implementation.
pivot = (b - 1) - (pivot - a)
hint = increasingHint
}
// The slice is likely already sorted.
if wasBalanced && wasPartitioned && hint == increasingHint {
if partialInsertionSort_func(data, a, b) {
return
}
}
// Probably the slice contains many duplicate elements, partition the slice into
// elements equal to and elements greater than the pivot.
if a > 0 && !data.Less(a-1, pivot) {
mid := partitionEqual_func(data, a, b, pivot)
a = mid
continue
}
mid, alreadyPartitioned := partition_func(data, a, b, pivot)
wasPartitioned = alreadyPartitioned
leftLen, rightLen := mid-a, b-mid
balanceThreshold := length / 8
if leftLen < rightLen {
wasBalanced = leftLen >= balanceThreshold
pdqsort_func(data, a, mid, limit)
a = mid + 1
} else {
wasBalanced = rightLen >= balanceThreshold
pdqsort_func(data, mid+1, b, limit)
b = mid
}
}
}
// partition_func does one quicksort partition.
// Let p = data[pivot].
// It moves elements in data[a:b] around so that data[i] < p for i < newpivot
// and data[j] >= p for j > newpivot.
// On return, data[newpivot] = p.
func partition_func(data lessSwap, a, b, pivot int) (newpivot int, alreadyPartitioned bool) {
data.Swap(a, pivot)
i, j := a+1, b-1 // i and j are inclusive of the elements remaining to be partitioned
for i <= j && data.Less(i, a) {
i++
}
for i <= j && !data.Less(j, a) {
j--
}
if i > j {
data.Swap(j, a)
return j, true
}
data.Swap(i, j)
i++
j--
for {
for i <= j && data.Less(i, a) {
i++
}
for i <= j && !data.Less(j, a) {
j--
}
if i > j {
break
}
data.Swap(i, j)
i++
j--
}
data.Swap(j, a)
return j, false
}
// partitionEqual_func partitions data[a:b] into elements equal to data[pivot] followed by elements greater than data[pivot].
// It assumes that data[a:b] does not contain elements smaller than data[pivot].
func partitionEqual_func(data lessSwap, a, b, pivot int) (newpivot int) {
data.Swap(a, pivot)
i, j := a+1, b-1 // i and j are inclusive of the elements remaining to be partitioned
for {
for i <= j && !data.Less(a, i) {
i++
}
for i <= j && data.Less(a, j) {
j--
}
if i > j {
break
}
data.Swap(i, j)
i++
j--
}
return i
}
// partialInsertionSort_func partially sorts a slice, returns true if the slice is sorted at the end.
func partialInsertionSort_func(data lessSwap, a, b int) bool {
const (
maxSteps = 5 // maximum number of adjacent out-of-order pairs that will get shifted
shortestShifting = 50 // don't shift any elements on short arrays
)
i := a + 1
for j := 0; j < maxSteps; j++ {
for i < b && !data.Less(i, i-1) {
i++
}
if i == b {
return true
}
if b-a < shortestShifting {
return false
}
data.Swap(i, i-1)
// Shift the smaller one to the left.
if i-a >= 2 {
for j := i - 1; j >= 1; j-- {
if !data.Less(j, j-1) {
break
}
data.Swap(j, j-1)
}
}
// Shift the greater one to the right.
if b-i >= 2 {
for j := i + 1; j < b; j++ {
if !data.Less(j, j-1) {
break
}
data.Swap(j, j-1)
}
}
}
return false
}
// breakPatterns_func scatters some elements around in an attempt to break some patterns
// that might cause imbalanced partitions in quicksort.
func breakPatterns_func(data lessSwap, a, b int) {
length := b - a
if length >= 8 {
random := xorshift(length)
modulus := nextPowerOfTwo(length)
for idx := a + (length/4)*2 - 1; idx <= a+(length/4)*2+1; idx++ {
other := int(uint(random.Next()) & (modulus - 1))
if other >= length {
other -= length
}
data.Swap(idx, a+other)
}
}
}
// choosePivot_func chooses a pivot in data[a:b].
//
// [0,8): chooses a static pivot.
// [8,shortestNinther): uses the simple median-of-three method.
// [shortestNinther,∞): uses the Tukey ninther method.
func choosePivot_func(data lessSwap, a, b int) (pivot int, hint sortedHint) {
const (
shortestNinther = 50
maxSwaps = 4 * 3
)
l := b - a
var (
swaps int
i = a + l/4*1
j = a + l/4*2
k = a + l/4*3
)
if l >= 8 {
if l >= shortestNinther {
// Tukey ninther method, the idea came from Rust's implementation.
i = medianAdjacent_func(data, i, &swaps)
j = medianAdjacent_func(data, j, &swaps)
k = medianAdjacent_func(data, k, &swaps)
}
// Find the median among i, j, k and store it in j.
j = median_func(data, i, j, k, &swaps)
}
switch swaps {
case 0:
return j, increasingHint
case maxSwaps:
return j, decreasingHint
default:
return j, unknownHint
}
}
// order2_func returns x,y where data[x] <= data[y], where x,y=a,b or x,y=b,a.
func order2_func(data lessSwap, a, b int, swaps *int) (int, int) {
if data.Less(b, a) {
*swaps++
return b, a
}
return a, b
}
// median_func returns x where data[x] is the median of data[a],data[b],data[c], where x is a, b, or c.
func median_func(data lessSwap, a, b, c int, swaps *int) int {
a, b = order2_func(data, a, b, swaps)
b, c = order2_func(data, b, c, swaps)
a, b = order2_func(data, a, b, swaps)
return b
}
// medianAdjacent_func returns the index of the median of data[a-1], data[a], data[a+1].
func medianAdjacent_func(data lessSwap, a int, swaps *int) int {
return median_func(data, a-1, a, a+1, swaps)
}
func reverseRange_func(data lessSwap, a, b int) {
i := a
j := b - 1
for i < j {
data.Swap(i, j)
i++
j--
}
}
func swapRange_func(data lessSwap, a, b, n int) {
for i := 0; i < n; i++ {
data.Swap(a+i, b+i)
}
}
func stable_func(data lessSwap, n int) {
blockSize := 20 // must be > 0
a, b := 0, blockSize
for b <= n {
insertionSort_func(data, a, b)
a = b
b += blockSize
}
insertionSort_func(data, a, n)
for blockSize < n {
a, b = 0, 2*blockSize
for b <= n {
symMerge_func(data, a, a+blockSize, b)
a = b
b += 2 * blockSize
}
if m := a + blockSize; m < n {
symMerge_func(data, a, m, n)
}
blockSize *= 2
}
}
// symMerge_func merges the two sorted subsequences data[a:m] and data[m:b] using
// the SymMerge algorithm from Pok-Son Kim and Arne Kutzner, "Stable Minimum
// Storage Merging by Symmetric Comparisons", in Susanne Albers and Tomasz
// Radzik, editors, Algorithms - ESA 2004, volume 3221 of Lecture Notes in
// Computer Science, pages 714-723. Springer, 2004.
//
// Let M = m-a and N = b-m. Without loss of generality, assume M < N.
// The recursion depth is bounded by ceil(log(M+N)).
// The algorithm needs O(M*log(N/M + 1)) calls to data.Less.
// The algorithm needs O((M+N)*log(M)) calls to data.Swap.
//
// The paper gives O((M+N)*log(M)) as the number of assignments assuming a
// rotation algorithm which uses O(M+N+gcd(M,N)) assignments. The argumentation
// in the paper carries through for Swap operations, especially as the block
// swapping rotate uses only O(M+N) Swaps.
//
// symMerge assumes non-degenerate arguments: a < m && m < b.
// Having the caller check this condition eliminates many leaf recursion calls,
// which improves performance.
func symMerge_func(data lessSwap, a, m, b int) {
// Avoid unnecessary recursions of symMerge
// by direct insertion of data[a] into data[m:b]
// if data[a:m] only contains one element.
if m-a == 1 {
// Use binary search to find the lowest index i
// such that data[i] >= data[a] for m <= i < b.
// Exit the search loop with i == b in case no such index exists.
i := m
j := b
for i < j {
h := int(uint(i+j) >> 1)
if data.Less(h, a) {
i = h + 1
} else {
j = h
}
}
// Swap values until data[a] reaches the position before i.
for k := a; k < i-1; k++ {
data.Swap(k, k+1)
}
return
}
// Avoid unnecessary recursions of symMerge
// by direct insertion of data[m] into data[a:m]
// if data[m:b] only contains one element.
if b-m == 1 {
// Use binary search to find the lowest index i
// such that data[i] > data[m] for a <= i < m.
// Exit the search loop with i == m in case no such index exists.
i := a
j := m
for i < j {
h := int(uint(i+j) >> 1)
if !data.Less(m, h) {
i = h + 1
} else {
j = h
}
}
// Swap values until data[m] reaches the position i.
for k := m; k > i; k-- {
data.Swap(k, k-1)
}
return
}
mid := int(uint(a+b) >> 1)
n := mid + m
var start, r int
if m > mid {
start = n - b
r = mid
} else {
start = a
r = m
}
p := n - 1
for start < r {
c := int(uint(start+r) >> 1)
if !data.Less(p-c, c) {
start = c + 1
} else {
r = c
}
}
end := n - start
if start < m && m < end {
rotate_func(data, start, m, end)
}
if a < start && start < mid {
symMerge_func(data, a, start, mid)
}
if mid < end && end < b {
symMerge_func(data, mid, end, b)
}
}
// rotate_func rotates two consecutive blocks u = data[a:m] and v = data[m:b] in data:
// Data of the form 'x u v y' is changed to 'x v u y'.
// rotate performs at most b-a many calls to data.Swap,
// and it assumes non-degenerate arguments: a < m && m < b.
func rotate_func(data lessSwap, a, m, b int) {
i := m - a
j := b - m
for i != j {
if i > j {
swapRange_func(data, m-i, m, j)
i -= j
} else {
swapRange_func(data, m-i, m+j-i, i)
j -= i
}
}
// i == j
swapRange_func(data, m-i, m, i)
}
// Code generated by gen_sort_variants.go; DO NOT EDIT.
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sort
// insertionSort sorts data[a:b] using insertion sort.
func insertionSort(data Interface, a, b int) {
for i := a + 1; i < b; i++ {
for j := i; j > a && data.Less(j, j-1); j-- {
data.Swap(j, j-1)
}
}
}
// siftDown implements the heap property on data[lo:hi].
// first is an offset into the array where the root of the heap lies.
func siftDown(data Interface, lo, hi, first int) {
root := lo
for {
child := 2*root + 1
if child >= hi {
break
}
if child+1 < hi && data.Less(first+child, first+child+1) {
child++
}
if !data.Less(first+root, first+child) {
return
}
data.Swap(first+root, first+child)
root = child
}
}
func heapSort(data Interface, a, b int) {
first := a
lo := 0
hi := b - a
// Build heap with greatest element at top.
for i := (hi - 1) / 2; i >= 0; i-- {
siftDown(data, i, hi, first)
}
// Pop elements, largest first, into end of data.
for i := hi - 1; i >= 0; i-- {
data.Swap(first, first+i)
siftDown(data, lo, i, first)
}
}
// pdqsort sorts data[a:b].
// The algorithm is based on pattern-defeating quicksort (pdqsort), but without the optimizations from BlockQuicksort.
// pdqsort paper: https://arxiv.org/pdf/2106.05123.pdf
// C++ implementation: https://github.com/orlp/pdqsort
// Rust implementation: https://docs.rs/pdqsort/latest/pdqsort/
// limit is the number of allowed bad (very unbalanced) pivots before falling back to heapsort.
func pdqsort(data Interface, a, b, limit int) {
const maxInsertion = 12
var (
wasBalanced = true // whether the last partitioning was reasonably balanced
wasPartitioned = true // whether the slice was already partitioned
)
for {
length := b - a
if length <= maxInsertion {
insertionSort(data, a, b)
return
}
// Fall back to heapsort if too many bad choices were made.
if limit == 0 {
heapSort(data, a, b)
return
}
// If the last partitioning was imbalanced, we need to break patterns.
if !wasBalanced {
breakPatterns(data, a, b)
limit--
}
pivot, hint := choosePivot(data, a, b)
if hint == decreasingHint {
reverseRange(data, a, b)
// The chosen pivot was pivot-a elements after the start of the array.
// After reversing it is pivot-a elements before the end of the array.
// The idea came from Rust's implementation.
pivot = (b - 1) - (pivot - a)
hint = increasingHint
}
// The slice is likely already sorted.
if wasBalanced && wasPartitioned && hint == increasingHint {
if partialInsertionSort(data, a, b) {
return
}
}
// Probably the slice contains many duplicate elements, partition the slice into
// elements equal to and elements greater than the pivot.
if a > 0 && !data.Less(a-1, pivot) {
mid := partitionEqual(data, a, b, pivot)
a = mid
continue
}
mid, alreadyPartitioned := partition(data, a, b, pivot)
wasPartitioned = alreadyPartitioned
leftLen, rightLen := mid-a, b-mid
balanceThreshold := length / 8
if leftLen < rightLen {
wasBalanced = leftLen >= balanceThreshold
pdqsort(data, a, mid, limit)
a = mid + 1
} else {
wasBalanced = rightLen >= balanceThreshold
pdqsort(data, mid+1, b, limit)
b = mid
}
}
}
// partition does one quicksort partition.
// Let p = data[pivot].
// It moves elements in data[a:b] around so that data[i] < p for i < newpivot
// and data[j] >= p for j > newpivot.
// On return, data[newpivot] = p.
func partition(data Interface, a, b, pivot int) (newpivot int, alreadyPartitioned bool) {
data.Swap(a, pivot)
i, j := a+1, b-1 // i and j are inclusive of the elements remaining to be partitioned
for i <= j && data.Less(i, a) {
i++
}
for i <= j && !data.Less(j, a) {
j--
}
if i > j {
data.Swap(j, a)
return j, true
}
data.Swap(i, j)
i++
j--
for {
for i <= j && data.Less(i, a) {
i++
}
for i <= j && !data.Less(j, a) {
j--
}
if i > j {
break
}
data.Swap(i, j)
i++
j--
}
data.Swap(j, a)
return j, false
}
// partitionEqual partitions data[a:b] into elements equal to data[pivot] followed by elements greater than data[pivot].
// It assumes that data[a:b] does not contain elements smaller than data[pivot].
func partitionEqual(data Interface, a, b, pivot int) (newpivot int) {
data.Swap(a, pivot)
i, j := a+1, b-1 // i and j are inclusive of the elements remaining to be partitioned
for {
for i <= j && !data.Less(a, i) {
i++
}
for i <= j && data.Less(a, j) {
j--
}
if i > j {
break
}
data.Swap(i, j)
i++
j--
}
return i
}
// partialInsertionSort partially sorts a slice, returns true if the slice is sorted at the end.
func partialInsertionSort(data Interface, a, b int) bool {
const (
maxSteps = 5 // maximum number of adjacent out-of-order pairs that will get shifted
shortestShifting = 50 // don't shift any elements on short arrays
)
i := a + 1
for j := 0; j < maxSteps; j++ {
for i < b && !data.Less(i, i-1) {
i++
}
if i == b {
return true
}
if b-a < shortestShifting {
return false
}
data.Swap(i, i-1)
// Shift the smaller one to the left.
if i-a >= 2 {
for j := i - 1; j >= 1; j-- {
if !data.Less(j, j-1) {
break
}
data.Swap(j, j-1)
}
}
// Shift the greater one to the right.
if b-i >= 2 {
for j := i + 1; j < b; j++ {
if !data.Less(j, j-1) {
break
}
data.Swap(j, j-1)
}
}
}
return false
}
// breakPatterns scatters some elements around in an attempt to break some patterns
// that might cause imbalanced partitions in quicksort.
func breakPatterns(data Interface, a, b int) {
length := b - a
if length >= 8 {
random := xorshift(length)
modulus := nextPowerOfTwo(length)
for idx := a + (length/4)*2 - 1; idx <= a+(length/4)*2+1; idx++ {
other := int(uint(random.Next()) & (modulus - 1))
if other >= length {
other -= length
}
data.Swap(idx, a+other)
}
}
}
// choosePivot chooses a pivot in data[a:b].
//
// [0,8): chooses a static pivot.
// [8,shortestNinther): uses the simple median-of-three method.
// [shortestNinther,∞): uses the Tukey ninther method.
func choosePivot(data Interface, a, b int) (pivot int, hint sortedHint) {
const (
shortestNinther = 50
maxSwaps = 4 * 3
)
l := b - a
var (
swaps int
i = a + l/4*1
j = a + l/4*2
k = a + l/4*3
)
if l >= 8 {
if l >= shortestNinther {
// Tukey ninther method, the idea came from Rust's implementation.
i = medianAdjacent(data, i, &swaps)
j = medianAdjacent(data, j, &swaps)
k = medianAdjacent(data, k, &swaps)
}
// Find the median among i, j, k and store it in j.
j = median(data, i, j, k, &swaps)
}
switch swaps {
case 0:
return j, increasingHint
case maxSwaps:
return j, decreasingHint
default:
return j, unknownHint
}
}
// order2 returns x,y where data[x] <= data[y], where x,y=a,b or x,y=b,a.
func order2(data Interface, a, b int, swaps *int) (int, int) {
if data.Less(b, a) {
*swaps++
return b, a
}
return a, b
}
// median returns x where data[x] is the median of data[a],data[b],data[c], where x is a, b, or c.
func median(data Interface, a, b, c int, swaps *int) int {
a, b = order2(data, a, b, swaps)
b, c = order2(data, b, c, swaps)
a, b = order2(data, a, b, swaps)
return b
}
// medianAdjacent returns the index of the median of data[a-1], data[a], data[a+1].
func medianAdjacent(data Interface, a int, swaps *int) int {
return median(data, a-1, a, a+1, swaps)
}
func reverseRange(data Interface, a, b int) {
i := a
j := b - 1
for i < j {
data.Swap(i, j)
i++
j--
}
}
func swapRange(data Interface, a, b, n int) {
for i := 0; i < n; i++ {
data.Swap(a+i, b+i)
}
}
func stable(data Interface, n int) {
blockSize := 20 // must be > 0
a, b := 0, blockSize
for b <= n {
insertionSort(data, a, b)
a = b
b += blockSize
}
insertionSort(data, a, n)
for blockSize < n {
a, b = 0, 2*blockSize
for b <= n {
symMerge(data, a, a+blockSize, b)
a = b
b += 2 * blockSize
}
if m := a + blockSize; m < n {
symMerge(data, a, m, n)
}
blockSize *= 2
}
}
// symMerge merges the two sorted subsequences data[a:m] and data[m:b] using
// the SymMerge algorithm from Pok-Son Kim and Arne Kutzner, "Stable Minimum
// Storage Merging by Symmetric Comparisons", in Susanne Albers and Tomasz
// Radzik, editors, Algorithms - ESA 2004, volume 3221 of Lecture Notes in
// Computer Science, pages 714-723. Springer, 2004.
//
// Let M = m-a and N = b-m. Without loss of generality, assume M < N.
// The recursion depth is bounded by ceil(log(M+N)).
// The algorithm needs O(M*log(N/M + 1)) calls to data.Less.
// The algorithm needs O((M+N)*log(M)) calls to data.Swap.
//
// The paper gives O((M+N)*log(M)) as the number of assignments assuming a
// rotation algorithm which uses O(M+N+gcd(M,N)) assignments. The argumentation
// in the paper carries through for Swap operations, especially as the block
// swapping rotate uses only O(M+N) Swaps.
//
// symMerge assumes non-degenerate arguments: a < m && m < b.
// Having the caller check this condition eliminates many leaf recursion calls,
// which improves performance.
func symMerge(data Interface, a, m, b int) {
// Avoid unnecessary recursions of symMerge
// by direct insertion of data[a] into data[m:b]
// if data[a:m] only contains one element.
if m-a == 1 {
// Use binary search to find the lowest index i
// such that data[i] >= data[a] for m <= i < b.
// Exit the search loop with i == b in case no such index exists.
i := m
j := b
for i < j {
h := int(uint(i+j) >> 1)
if data.Less(h, a) {
i = h + 1
} else {
j = h
}
}
// Swap values until data[a] reaches the position before i.
for k := a; k < i-1; k++ {
data.Swap(k, k+1)
}
return
}
// Avoid unnecessary recursions of symMerge
// by direct insertion of data[m] into data[a:m]
// if data[m:b] only contains one element.
if b-m == 1 {
// Use binary search to find the lowest index i
// such that data[i] > data[m] for a <= i < m.
// Exit the search loop with i == m in case no such index exists.
i := a
j := m
for i < j {
h := int(uint(i+j) >> 1)
if !data.Less(m, h) {
i = h + 1
} else {
j = h
}
}
// Swap values until data[m] reaches the position i.
for k := m; k > i; k-- {
data.Swap(k, k-1)
}
return
}
mid := int(uint(a+b) >> 1)
n := mid + m
var start, r int
if m > mid {
start = n - b
r = mid
} else {
start = a
r = m
}
p := n - 1
for start < r {
c := int(uint(start+r) >> 1)
if !data.Less(p-c, c) {
start = c + 1
} else {
r = c
}
}
end := n - start
if start < m && m < end {
rotate(data, start, m, end)
}
if a < start && start < mid {
symMerge(data, a, start, mid)
}
if mid < end && end < b {
symMerge(data, mid, end, b)
}
}
// rotate rotates two consecutive blocks u = data[a:m] and v = data[m:b] in data:
// Data of the form 'x u v y' is changed to 'x v u y'.
// rotate performs at most b-a many calls to data.Swap,
// and it assumes non-degenerate arguments: a < m && m < b.
func rotate(data Interface, a, m, b int) {
i := m - a
j := b - m
for i != j {
if i > j {
swapRange(data, m-i, m, j)
i -= j
} else {
swapRange(data, m-i, m+j-i, i)
j -= i
}
}
// i == j
swapRange(data, m-i, m, i)
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package strconv
// ParseBool returns the boolean value represented by the string.
// It accepts 1, t, T, TRUE, true, True, 0, f, F, FALSE, false, False.
// Any other value returns an error.
func ParseBool(str string) (bool, error) {
switch str {
case "1", "t", "T", "true", "TRUE", "True":
return true, nil
case "0", "f", "F", "false", "FALSE", "False":
return false, nil
}
return false, syntaxError("ParseBool", str)
}
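// Illustrative usage (editor's addition):
//
//	v, err := strconv.ParseBool("TRUE") // v == true, err == nil
//	_, err = strconv.ParseBool("yes")   // err is a *NumError wrapping ErrSyntax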
// FormatBool returns "true" or "false" according to the value of b.
func FormatBool(b bool) string {
if b {
return "true"
}
return "false"
}
// AppendBool appends "true" or "false", according to the value of b,
// to dst and returns the extended buffer.
func AppendBool(dst []byte, b bool) []byte {
if b {
return append(dst, "true"...)
}
return append(dst, "false"...)
}
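// Illustrative usage (editor's addition):
//
//	buf := []byte("x=")
//	buf = strconv.AppendBool(buf, true) // buf is now []byte("x=true")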
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package strconv
const fnParseComplex = "ParseComplex"
// convErr splits an error returned by parseFloatPrefix
// into a syntax or range error for ParseComplex.
func convErr(err error, s string) (syntax, range_ error) {
if x, ok := err.(*NumError); ok {
x.Func = fnParseComplex
x.Num = cloneString(s)
if x.Err == ErrRange {
return nil, x
}
}
return err, nil
}
// ParseComplex converts the string s to a complex number
// with the precision specified by bitSize: 64 for complex64, or 128 for complex128.
// When bitSize=64, the result still has type complex128, but it will be
// convertible to complex64 without changing its value.
//
// The number represented by s must be of the form N, Ni, or N±Ni, where N stands
// for a floating-point number as recognized by ParseFloat, and i is the imaginary
// component. If the second N is unsigned, a + sign is required between the two components
// as indicated by the ±. If the second N is NaN, only a + sign is accepted.
// The form may be parenthesized and cannot contain any spaces.
// The resulting complex number consists of the two components converted by ParseFloat.
//
// The errors that ParseComplex returns have concrete type *NumError
// and include err.Num = s.
//
// If s is not syntactically well-formed, ParseComplex returns err.Err = ErrSyntax.
//
// If s is syntactically well-formed but either component is more than 1/2 ULP
// away from the largest floating point number of the given component's size,
// ParseComplex returns err.Err = ErrRange and c = ±Inf for the respective component.
func ParseComplex(s string, bitSize int) (complex128, error) {
size := 64
if bitSize == 64 {
size = 32 // complex64 uses float32 parts
}
orig := s
// Remove parentheses, if any.
if len(s) >= 2 && s[0] == '(' && s[len(s)-1] == ')' {
s = s[1 : len(s)-1]
}
var pending error // pending range error, or nil
// Read real part (possibly imaginary part if followed by 'i').
re, n, err := parseFloatPrefix(s, size)
if err != nil {
err, pending = convErr(err, orig)
if err != nil {
return 0, err
}
}
s = s[n:]
// If we have nothing left, we're done.
if len(s) == 0 {
return complex(re, 0), pending
}
// Otherwise, look at the next character.
switch s[0] {
case '+':
// Consume the '+' to avoid an error if we have "+NaNi", but
// do this only if we don't have a "++" (don't hide that error).
if len(s) > 1 && s[1] != '+' {
s = s[1:]
}
case '-':
// ok
case 'i':
// If 'i' is the last character, we only have an imaginary part.
if len(s) == 1 {
return complex(0, re), pending
}
fallthrough
default:
return 0, syntaxError(fnParseComplex, orig)
}
// Read imaginary part.
im, n, err := parseFloatPrefix(s, size)
if err != nil {
err, pending = convErr(err, orig)
if err != nil {
return 0, err
}
}
s = s[n:]
if s != "i" {
return 0, syntaxError(fnParseComplex, orig)
}
return complex(re, im), pending
}
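// exampleParseComplex is an illustrative sketch, not part of the original
// source: it exercises the N and parenthesized N±Ni forms described in the
// doc comment above.
func exampleParseComplex() (complex128, complex128, error) {
	c1, _ := ParseComplex("2.5", 128)      // (2.5+0i)
	c2, err := ParseComplex("(1-2i)", 128) // (1-2i)
	return c1, c2, err
}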
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package strconv
// decimal to binary floating point conversion.
// Algorithm:
// 1) Store input in multiprecision decimal.
// 2) Multiply/divide decimal by powers of two until in range [0.5, 1)
// 3) Multiply by 2^precision and round to get mantissa.
import "math"
var optimize = true // set to false to force slow-path conversions for testing
// commonPrefixLenIgnoreCase returns the length of the common
// prefix of s and prefix, with the character case of s ignored.
// The prefix argument must be all lower-case.
func commonPrefixLenIgnoreCase(s, prefix string) int {
n := len(prefix)
if n > len(s) {
n = len(s)
}
for i := 0; i < n; i++ {
c := s[i]
if 'A' <= c && c <= 'Z' {
c += 'a' - 'A'
}
if c != prefix[i] {
return i
}
}
return n
}
// special returns the floating-point value for the special,
// possibly signed floating-point representations inf, infinity,
// and NaN. The result is ok if a prefix of s contains one
// of these representations and n is the length of that prefix.
// The character case is ignored.
func special(s string) (f float64, n int, ok bool) {
if len(s) == 0 {
return 0, 0, false
}
sign := 1
nsign := 0
switch s[0] {
case '+', '-':
if s[0] == '-' {
sign = -1
}
nsign = 1
s = s[1:]
fallthrough
case 'i', 'I':
n := commonPrefixLenIgnoreCase(s, "infinity")
// Anything longer than "inf" is ok, but if we
// don't have "infinity", only consume "inf".
if 3 < n && n < 8 {
n = 3
}
if n == 3 || n == 8 {
return math.Inf(sign), nsign + n, true
}
case 'n', 'N':
if commonPrefixLenIgnoreCase(s, "nan") == 3 {
return math.NaN(), 3, true
}
}
return 0, 0, false
}
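// exampleSpecial is an illustrative sketch, not part of the original source:
// "infix" matches more than "inf" but less than "infinity", so only the
// 3-byte prefix "inf" is consumed; a signed "-Infinity" consumes all 9 bytes.
func exampleSpecial() (n1, n2 int) {
	_, n1, _ = special("infix")     // n1 == 3
	_, n2, _ = special("-Infinity") // n2 == 9
	return
}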
func (b *decimal) set(s string) (ok bool) {
i := 0
b.neg = false
b.trunc = false
// optional sign
if i >= len(s) {
return
}
switch {
case s[i] == '+':
i++
case s[i] == '-':
b.neg = true
i++
}
// digits
sawdot := false
sawdigits := false
for ; i < len(s); i++ {
switch {
case s[i] == '_':
// readFloat already checked underscores
continue
case s[i] == '.':
if sawdot {
return
}
sawdot = true
b.dp = b.nd
continue
case '0' <= s[i] && s[i] <= '9':
sawdigits = true
if s[i] == '0' && b.nd == 0 { // ignore leading zeros
b.dp--
continue
}
if b.nd < len(b.d) {
b.d[b.nd] = s[i]
b.nd++
} else if s[i] != '0' {
b.trunc = true
}
continue
}
break
}
if !sawdigits {
return
}
if !sawdot {
b.dp = b.nd
}
// optional exponent moves decimal point.
// if we read a very large, very long number,
// just be sure to move the decimal point by
// a lot (say, 100000). it doesn't matter if it's
// not the exact number.
if i < len(s) && lower(s[i]) == 'e' {
i++
if i >= len(s) {
return
}
esign := 1
if s[i] == '+' {
i++
} else if s[i] == '-' {
i++
esign = -1
}
if i >= len(s) || s[i] < '0' || s[i] > '9' {
return
}
e := 0
for ; i < len(s) && ('0' <= s[i] && s[i] <= '9' || s[i] == '_'); i++ {
if s[i] == '_' {
// readFloat already checked underscores
continue
}
if e < 10000 {
e = e*10 + int(s[i]) - '0'
}
}
b.dp += e * esign
}
if i != len(s) {
return
}
ok = true
return
}
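// exampleDecimalSet is an illustrative sketch, not part of the original
// source: set stores only the digits and tracks the decimal point
// separately, so "12.34e2" becomes digits "1234" with nd == 4 and dp == 4
// (the value 1234).
func exampleDecimalSet() (nd, dp int, ok bool) {
	var d decimal
	ok = d.set("12.34e2")
	return d.nd, d.dp, ok
}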
// readFloat reads a decimal or hexadecimal mantissa and exponent from a float
// string representation in s; the number may be followed by other characters.
// readFloat reports the number of bytes consumed (i), and whether the number
// is valid (ok).
func readFloat(s string) (mantissa uint64, exp int, neg, trunc, hex bool, i int, ok bool) {
underscores := false
// optional sign
if i >= len(s) {
return
}
switch {
case s[i] == '+':
i++
case s[i] == '-':
neg = true
i++
}
// digits
base := uint64(10)
maxMantDigits := 19 // 10^19 fits in uint64
expChar := byte('e')
if i+2 < len(s) && s[i] == '0' && lower(s[i+1]) == 'x' {
base = 16
maxMantDigits = 16 // 16^16 fits in uint64
i += 2
expChar = 'p'
hex = true
}
sawdot := false
sawdigits := false
nd := 0
ndMant := 0
dp := 0
loop:
for ; i < len(s); i++ {
switch c := s[i]; true {
case c == '_':
underscores = true
continue
case c == '.':
if sawdot {
break loop
}
sawdot = true
dp = nd
continue
case '0' <= c && c <= '9':
sawdigits = true
if c == '0' && nd == 0 { // ignore leading zeros
dp--
continue
}
nd++
if ndMant < maxMantDigits {
mantissa *= base
mantissa += uint64(c - '0')
ndMant++
} else if c != '0' {
trunc = true
}
continue
case base == 16 && 'a' <= lower(c) && lower(c) <= 'f':
sawdigits = true
nd++
if ndMant < maxMantDigits {
mantissa *= 16
mantissa += uint64(lower(c) - 'a' + 10)
ndMant++
} else {
trunc = true
}
continue
}
break
}
if !sawdigits {
return
}
if !sawdot {
dp = nd
}
if base == 16 {
dp *= 4
ndMant *= 4
}
// optional exponent moves decimal point.
// if we read a very large, very long number,
// just be sure to move the decimal point by
// a lot (say, 100000). it doesn't matter if it's
// not the exact number.
if i < len(s) && lower(s[i]) == expChar {
i++
if i >= len(s) {
return
}
esign := 1
if s[i] == '+' {
i++
} else if s[i] == '-' {
i++
esign = -1
}
if i >= len(s) || s[i] < '0' || s[i] > '9' {
return
}
e := 0
for ; i < len(s) && ('0' <= s[i] && s[i] <= '9' || s[i] == '_'); i++ {
if s[i] == '_' {
underscores = true
continue
}
if e < 10000 {
e = e*10 + int(s[i]) - '0'
}
}
dp += e * esign
} else if base == 16 {
// Must have exponent.
return
}
if mantissa != 0 {
exp = dp - ndMant
}
if underscores && !underscoreOK(s[:i]) {
return
}
ok = true
return
}
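// exampleReadFloat is an illustrative sketch, not part of the original
// source: for "1.5e2xyz" readFloat consumes the 5-byte prefix "1.5e2" and
// reports mantissa 15 with decimal exponent 1, i.e. 15 * 10^1 == 150.
func exampleReadFloat() (mantissa uint64, exp, n int, ok bool) {
	mantissa, exp, _, _, _, n, ok = readFloat("1.5e2xyz")
	return // mantissa == 15, exp == 1, n == 5, ok == true
}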
// decimal power of ten to binary power of two.
var powtab = []int{1, 3, 6, 9, 13, 16, 19, 23, 26}
func (d *decimal) floatBits(flt *floatInfo) (b uint64, overflow bool) {
var exp int
var mant uint64
// Zero is always a special case.
if d.nd == 0 {
mant = 0
exp = flt.bias
goto out
}
// Obvious overflow/underflow.
// These bounds are for 64-bit floats.
// Will have to change if we want to support 80-bit floats in the future.
if d.dp > 310 {
goto overflow
}
if d.dp < -330 {
// zero
mant = 0
exp = flt.bias
goto out
}
// Scale by powers of two until in range [0.5, 1.0)
exp = 0
for d.dp > 0 {
var n int
if d.dp >= len(powtab) {
n = 27
} else {
n = powtab[d.dp]
}
d.Shift(-n)
exp += n
}
for d.dp < 0 || d.dp == 0 && d.d[0] < '5' {
var n int
if -d.dp >= len(powtab) {
n = 27
} else {
n = powtab[-d.dp]
}
d.Shift(n)
exp -= n
}
// Our range is [0.5,1) but floating point range is [1,2).
exp--
// Minimum representable exponent is flt.bias+1.
// If the exponent is smaller, move it up and
// adjust d accordingly.
if exp < flt.bias+1 {
n := flt.bias + 1 - exp
d.Shift(-n)
exp += n
}
if exp-flt.bias >= 1<<flt.expbits-1 {
goto overflow
}
// Extract 1+flt.mantbits bits.
d.Shift(int(1 + flt.mantbits))
mant = d.RoundedInteger()
// Rounding might have added a bit; shift down.
if mant == 2<<flt.mantbits {
mant >>= 1
exp++
if exp-flt.bias >= 1<<flt.expbits-1 {
goto overflow
}
}
// Denormalized?
if mant&(1<<flt.mantbits) == 0 {
exp = flt.bias
}
goto out
overflow:
// ±Inf
mant = 0
exp = 1<<flt.expbits - 1 + flt.bias
overflow = true
out:
// Assemble bits.
bits := mant & (uint64(1)<<flt.mantbits - 1)
bits |= uint64((exp-flt.bias)&(1<<flt.expbits-1)) << flt.mantbits
if d.neg {
bits |= 1 << flt.mantbits << flt.expbits
}
return bits, overflow
}
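// exampleFloatBits is an illustrative sketch, not part of the original
// source: running the slow path on "0.5" yields the IEEE 754 bit pattern
// 0x3FE0000000000000 (exponent field 1022, zero mantissa).
func exampleFloatBits() uint64 {
	var d decimal
	d.set("0.5")
	b, _ := d.floatBits(&float64info)
	return b // 0x3FE0000000000000
}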
// Exact powers of 10.
var float64pow10 = []float64{
1e0, 1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8, 1e9,
1e10, 1e11, 1e12, 1e13, 1e14, 1e15, 1e16, 1e17, 1e18, 1e19,
1e20, 1e21, 1e22,
}
var float32pow10 = []float32{1e0, 1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8, 1e9, 1e10}
// If possible to convert decimal representation to 64-bit float f exactly,
// entirely in floating-point math, do so, avoiding the expense of decimalToFloatBits.
// Three common cases:
//
// value is exact integer
// value is exact integer * exact power of ten
// value is exact integer / exact power of ten
//
// These all produce potentially inexact but correctly rounded answers.
func atof64exact(mantissa uint64, exp int, neg bool) (f float64, ok bool) {
if mantissa>>float64info.mantbits != 0 {
return
}
f = float64(mantissa)
if neg {
f = -f
}
switch {
case exp == 0:
// an integer.
return f, true
// Exact integers are <= 10^15.
// Exact powers of ten are <= 10^22.
case exp > 0 && exp <= 15+22: // int * 10^k
// If exponent is big but number of digits is not,
// can move a few zeros into the integer part.
if exp > 22 {
f *= float64pow10[exp-22]
exp = 22
}
if f > 1e15 || f < -1e15 {
// the exponent was really too large.
return
}
return f * float64pow10[exp], true
case exp < 0 && exp >= -22: // int / 10^k
return f / float64pow10[-exp], true
}
return
}
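// exampleAtof64Exact is an illustrative sketch, not part of the original
// source: 15 and 10^1 are both exactly representable, so the fast path
// returns 150 without falling back to the multiprecision decimal.
func exampleAtof64Exact() (float64, bool) {
	return atof64exact(15, 1, false) // 150, true
}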
// If possible to compute mantissa*10^exp to 32-bit float f exactly,
// entirely in floating-point math, do so, avoiding the machinery above.
func atof32exact(mantissa uint64, exp int, neg bool) (f float32, ok bool) {
if mantissa>>float32info.mantbits != 0 {
return
}
f = float32(mantissa)
if neg {
f = -f
}
switch {
case exp == 0:
return f, true
// Exact integers are <= 10^7.
// Exact powers of ten are <= 10^10.
case exp > 0 && exp <= 7+10: // int * 10^k
// If exponent is big but number of digits is not,
// can move a few zeros into the integer part.
if exp > 10 {
f *= float32pow10[exp-10]
exp = 10
}
if f > 1e7 || f < -1e7 {
// the exponent was really too large.
return
}
return f * float32pow10[exp], true
case exp < 0 && exp >= -10: // int / 10^k
return f / float32pow10[-exp], true
}
return
}
// atofHex converts the hex floating-point string s
// to a rounded float32 or float64 value (depending on flt==&float32info or flt==&float64info)
// and returns it as a float64.
// The string s has already been parsed into a mantissa, exponent, and sign (neg==true for negative).
// If trunc is true, trailing non-zero bits have been omitted from the mantissa.
func atofHex(s string, flt *floatInfo, mantissa uint64, exp int, neg, trunc bool) (float64, error) {
maxExp := 1<<flt.expbits + flt.bias - 2
minExp := flt.bias + 1
exp += int(flt.mantbits) // mantissa now implicitly divided by 2^mantbits.
// Shift mantissa and exponent to bring representation into float range.
// Eventually we want a mantissa with a leading 1-bit followed by mantbits other bits.
// For rounding, we need two more, where the bottom bit represents
// whether that bit or any later bit was non-zero.
// (If the mantissa has already lost non-zero bits, trunc is true,
// and we OR in a 1 below after shifting left appropriately.)
for mantissa != 0 && mantissa>>(flt.mantbits+2) == 0 {
mantissa <<= 1
exp--
}
if trunc {
mantissa |= 1
}
for mantissa>>(1+flt.mantbits+2) != 0 {
mantissa = mantissa>>1 | mantissa&1
exp++
}
// If exponent is too negative,
// denormalize in hopes of making it representable.
// (The -2 is for the rounding bits.)
for mantissa > 1 && exp < minExp-2 {
mantissa = mantissa>>1 | mantissa&1
exp++
}
// Round using two bottom bits.
round := mantissa & 3
mantissa >>= 2
round |= mantissa & 1 // round to even (round up if mantissa is odd)
exp += 2
if round == 3 {
mantissa++
if mantissa == 1<<(1+flt.mantbits) {
mantissa >>= 1
exp++
}
}
if mantissa>>flt.mantbits == 0 { // Denormal or zero.
exp = flt.bias
}
var err error
if exp > maxExp { // infinity and range error
mantissa = 1 << flt.mantbits
exp = maxExp + 1
err = rangeError(fnParseFloat, s)
}
bits := mantissa & (1<<flt.mantbits - 1)
bits |= uint64((exp-flt.bias)&(1<<flt.expbits-1)) << flt.mantbits
if neg {
bits |= 1 << flt.mantbits << flt.expbits
}
if flt == &float32info {
return float64(math.Float32frombits(uint32(bits))), err
}
return math.Float64frombits(bits), err
}
const fnParseFloat = "ParseFloat"
func atof32(s string) (f float32, n int, err error) {
if val, n, ok := special(s); ok {
return float32(val), n, nil
}
mantissa, exp, neg, trunc, hex, n, ok := readFloat(s)
if !ok {
return 0, n, syntaxError(fnParseFloat, s)
}
if hex {
f, err := atofHex(s[:n], &float32info, mantissa, exp, neg, trunc)
return float32(f), n, err
}
if optimize {
// Try pure floating-point arithmetic conversion, and if that fails,
// the Eisel-Lemire algorithm.
if !trunc {
if f, ok := atof32exact(mantissa, exp, neg); ok {
return f, n, nil
}
}
f, ok := eiselLemire32(mantissa, exp, neg)
if ok {
if !trunc {
return f, n, nil
}
// Even if the mantissa was truncated, we may
// have found the correct result. Confirm by
// converting the upper mantissa bound.
fUp, ok := eiselLemire32(mantissa+1, exp, neg)
if ok && f == fUp {
return f, n, nil
}
}
}
// Slow fallback.
var d decimal
if !d.set(s[:n]) {
return 0, n, syntaxError(fnParseFloat, s)
}
b, ovf := d.floatBits(&float32info)
f = math.Float32frombits(uint32(b))
if ovf {
err = rangeError(fnParseFloat, s)
}
return f, n, err
}
func atof64(s string) (f float64, n int, err error) {
if val, n, ok := special(s); ok {
return val, n, nil
}
mantissa, exp, neg, trunc, hex, n, ok := readFloat(s)
if !ok {
return 0, n, syntaxError(fnParseFloat, s)
}
if hex {
f, err := atofHex(s[:n], &float64info, mantissa, exp, neg, trunc)
return f, n, err
}
if optimize {
// Try pure floating-point arithmetic conversion, and if that fails,
// the Eisel-Lemire algorithm.
if !trunc {
if f, ok := atof64exact(mantissa, exp, neg); ok {
return f, n, nil
}
}
f, ok := eiselLemire64(mantissa, exp, neg)
if ok {
if !trunc {
return f, n, nil
}
// Even if the mantissa was truncated, we may
// have found the correct result. Confirm by
// converting the upper mantissa bound.
fUp, ok := eiselLemire64(mantissa+1, exp, neg)
if ok && f == fUp {
return f, n, nil
}
}
}
// Slow fallback.
var d decimal
if !d.set(s[:n]) {
return 0, n, syntaxError(fnParseFloat, s)
}
b, ovf := d.floatBits(&float64info)
f = math.Float64frombits(b)
if ovf {
err = rangeError(fnParseFloat, s)
}
return f, n, err
}
// ParseFloat converts the string s to a floating-point number
// with the precision specified by bitSize: 32 for float32, or 64 for float64.
// When bitSize=32, the result still has type float64, but it will be
// convertible to float32 without changing its value.
//
// ParseFloat accepts decimal and hexadecimal floating-point numbers
// as defined by the Go syntax for [floating-point literals].
// If s is well-formed and near a valid floating-point number,
// ParseFloat returns the nearest floating-point number rounded
// using IEEE754 unbiased rounding.
// (Parsing a hexadecimal floating-point value only rounds when
// there are more bits in the hexadecimal representation than
// will fit in the mantissa.)
//
// The errors that ParseFloat returns have concrete type *NumError
// and include err.Num = s.
//
// If s is not syntactically well-formed, ParseFloat returns err.Err = ErrSyntax.
//
// If s is syntactically well-formed but is more than 1/2 ULP
// away from the largest floating point number of the given size,
// ParseFloat returns f = ±Inf, err.Err = ErrRange.
//
// ParseFloat recognizes the string "NaN", and the (possibly signed) strings "Inf" and "Infinity"
// as their respective special floating point values. It ignores case when matching.
//
// [floating-point literals]: https://go.dev/ref/spec#Floating-point_literals
func ParseFloat(s string, bitSize int) (float64, error) {
f, n, err := parseFloatPrefix(s, bitSize)
if n != len(s) && (err == nil || err.(*NumError).Err != ErrSyntax) {
return 0, syntaxError(fnParseFloat, s)
}
return f, err
}
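// exampleParseFloat is an illustrative sketch, not part of the original
// source: both decimal and hexadecimal literals are accepted; "0x1.8p1"
// is 1.5 scaled by 2^1, i.e. exactly 3.
func exampleParseFloat() (float64, float64, error) {
	f1, _ := ParseFloat("3.1415", 64)
	f2, err := ParseFloat("0x1.8p1", 64) // 3
	return f1, f2, err
}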
func parseFloatPrefix(s string, bitSize int) (float64, int, error) {
if bitSize == 32 {
f, n, err := atof32(s)
return float64(f), n, err
}
return atof64(s)
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package strconv
import "errors"
// lower(c) is a lower-case letter if and only if
// c is either that lower-case letter or the equivalent upper-case letter.
// Instead of writing c == 'x' || c == 'X' one can write lower(c) == 'x'.
// Note that lower of non-letters can produce other non-letters.
func lower(c byte) byte {
return c | ('x' - 'X')
}
// ErrRange indicates that a value is out of range for the target type.
var ErrRange = errors.New("value out of range")
// ErrSyntax indicates that a value does not have the right syntax for the target type.
var ErrSyntax = errors.New("invalid syntax")
// A NumError records a failed conversion.
type NumError struct {
Func string // the failing function (ParseBool, ParseInt, ParseUint, ParseFloat, ParseComplex)
Num string // the input
Err error // the reason the conversion failed (e.g. ErrRange, ErrSyntax, etc.)
}
func (e *NumError) Error() string {
return "strconv." + e.Func + ": " + "parsing " + Quote(e.Num) + ": " + e.Err.Error()
}
func (e *NumError) Unwrap() error { return e.Err }
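// exampleNumErrorUnwrap is an illustrative sketch, not part of the original
// source: because *NumError implements Unwrap, the underlying cause is
// reachable with errors.Is.
func exampleNumErrorUnwrap() bool {
	_, err := ParseInt("abc", 10, 64)
	return errors.Is(err, ErrSyntax) // true
}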
// cloneString returns a string copy of x.
//
// All ParseXXX functions allow the input string to escape to the error value.
// This hurts strconv.ParseXXX(string(b)) calls where b is []byte since
// the conversion from []byte must allocate a string on the heap.
// If we assume errors are infrequent, then we can avoid escaping the input
// back to the output by copying it first. This allows the compiler to call
// strconv.ParseXXX without a heap allocation for most []byte to string
// conversions, since it can now prove that the string cannot escape Parse.
//
// TODO: Use strings.Clone instead? However, we cannot depend on "strings"
// since it incurs a transitive dependency on "unicode".
// Either move strings.Clone to an internal/bytealg or make the
// "strings" to "unicode" dependency lighter (see https://go.dev/issue/54098).
func cloneString(x string) string { return string([]byte(x)) }
func syntaxError(fn, str string) *NumError {
return &NumError{fn, cloneString(str), ErrSyntax}
}
func rangeError(fn, str string) *NumError {
return &NumError{fn, cloneString(str), ErrRange}
}
func baseError(fn, str string, base int) *NumError {
return &NumError{fn, cloneString(str), errors.New("invalid base " + Itoa(base))}
}
func bitSizeError(fn, str string, bitSize int) *NumError {
return &NumError{fn, cloneString(str), errors.New("invalid bit size " + Itoa(bitSize))}
}
const intSize = 32 << (^uint(0) >> 63)
// IntSize is the size in bits of an int or uint value.
const IntSize = intSize
const maxUint64 = 1<<64 - 1
// ParseUint is like ParseInt but for unsigned numbers.
//
// A sign prefix is not permitted.
func ParseUint(s string, base int, bitSize int) (uint64, error) {
const fnParseUint = "ParseUint"
if s == "" {
return 0, syntaxError(fnParseUint, s)
}
base0 := base == 0
s0 := s
switch {
case 2 <= base && base <= 36:
// valid base; nothing to do
case base == 0:
// Look for octal, hex prefix.
base = 10
if s[0] == '0' {
switch {
case len(s) >= 3 && lower(s[1]) == 'b':
base = 2
s = s[2:]
case len(s) >= 3 && lower(s[1]) == 'o':
base = 8
s = s[2:]
case len(s) >= 3 && lower(s[1]) == 'x':
base = 16
s = s[2:]
default:
base = 8
s = s[1:]
}
}
default:
return 0, baseError(fnParseUint, s0, base)
}
if bitSize == 0 {
bitSize = IntSize
} else if bitSize < 0 || bitSize > 64 {
return 0, bitSizeError(fnParseUint, s0, bitSize)
}
// Cutoff is the smallest number such that cutoff*base > maxUint64.
// Use compile-time constants for common cases.
var cutoff uint64
switch base {
case 10:
cutoff = maxUint64/10 + 1
case 16:
cutoff = maxUint64/16 + 1
default:
cutoff = maxUint64/uint64(base) + 1
}
maxVal := uint64(1)<<uint(bitSize) - 1
underscores := false
var n uint64
for _, c := range []byte(s) {
var d byte
switch {
case c == '_' && base0:
underscores = true
continue
case '0' <= c && c <= '9':
d = c - '0'
case 'a' <= lower(c) && lower(c) <= 'z':
d = lower(c) - 'a' + 10
default:
return 0, syntaxError(fnParseUint, s0)
}
if d >= byte(base) {
return 0, syntaxError(fnParseUint, s0)
}
if n >= cutoff {
// n*base overflows
return maxVal, rangeError(fnParseUint, s0)
}
n *= uint64(base)
n1 := n + uint64(d)
if n1 < n || n1 > maxVal {
// n+d overflows
return maxVal, rangeError(fnParseUint, s0)
}
n = n1
}
if underscores && !underscoreOK(s0) {
return 0, syntaxError(fnParseUint, s0)
}
return n, nil
}
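// exampleParseUint is an illustrative sketch, not part of the original
// source: with base 0 the prefix selects the base, and underscores are
// accepted as digit separators.
func exampleParseUint() (uint64, uint64, error) {
	a, _ := ParseUint("0x1F", 0, 64)    // 31
	b, err := ParseUint("1_000", 0, 64) // 1000
	return a, b, err
}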
// ParseInt interprets a string s in the given base (0, 2 to 36) and
// bit size (0 to 64) and returns the corresponding value i.
//
// The string may begin with a leading sign: "+" or "-".
//
// If the base argument is 0, the true base is implied by the string's
// prefix following the sign (if present): 2 for "0b", 8 for "0" or "0o",
// 16 for "0x", and 10 otherwise. Also, for argument base 0 only,
// underscore characters are permitted as defined by the Go syntax for
// [integer literals].
//
// The bitSize argument specifies the integer type
// that the result must fit into. Bit sizes 0, 8, 16, 32, and 64
// correspond to int, int8, int16, int32, and int64.
// If bitSize is below 0 or above 64, an error is returned.
//
// The errors that ParseInt returns have concrete type *NumError
// and include err.Num = s. If s is empty or contains invalid
// digits, err.Err = ErrSyntax and the returned value is 0;
// if the value corresponding to s cannot be represented by a
// signed integer of the given size, err.Err = ErrRange and the
// returned value is the maximum magnitude integer of the
// appropriate bitSize and sign.
//
// [integer literals]: https://go.dev/ref/spec#Integer_literals
func ParseInt(s string, base int, bitSize int) (i int64, err error) {
const fnParseInt = "ParseInt"
if s == "" {
return 0, syntaxError(fnParseInt, s)
}
// Pick off leading sign.
s0 := s
neg := false
if s[0] == '+' {
s = s[1:]
} else if s[0] == '-' {
neg = true
s = s[1:]
}
// Convert unsigned and check range.
var un uint64
un, err = ParseUint(s, base, bitSize)
if err != nil && err.(*NumError).Err != ErrRange {
err.(*NumError).Func = fnParseInt
err.(*NumError).Num = cloneString(s0)
return 0, err
}
if bitSize == 0 {
bitSize = IntSize
}
cutoff := uint64(1 << uint(bitSize-1))
if !neg && un >= cutoff {
return int64(cutoff - 1), rangeError(fnParseInt, s0)
}
if neg && un > cutoff {
return -int64(cutoff), rangeError(fnParseInt, s0)
}
n := int64(un)
if neg {
n = -n
}
return n, nil
}
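// exampleParseIntRange is an illustrative sketch, not part of the original
// source: on overflow ParseInt returns the clamped extreme of the requested
// bit size together with an ErrRange error.
func exampleParseIntRange() (hi, lo int64) {
	hi, _ = ParseInt("128", 10, 8)  // 127, err.Err == ErrRange
	lo, _ = ParseInt("-129", 10, 8) // -128, err.Err == ErrRange
	return
}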
// Atoi is equivalent to ParseInt(s, 10, 0), converted to type int.
func Atoi(s string) (int, error) {
const fnAtoi = "Atoi"
sLen := len(s)
if intSize == 32 && (0 < sLen && sLen < 10) ||
intSize == 64 && (0 < sLen && sLen < 19) {
// Fast path for small integers that fit int type.
s0 := s
if s[0] == '-' || s[0] == '+' {
s = s[1:]
if len(s) < 1 {
return 0, syntaxError(fnAtoi, s0)
}
}
n := 0
for _, ch := range []byte(s) {
ch -= '0'
if ch > 9 {
return 0, syntaxError(fnAtoi, s0)
}
n = n*10 + int(ch)
}
if s0[0] == '-' {
n = -n
}
return n, nil
}
// Slow path for invalid, big, or underscored integers.
i64, err := ParseInt(s, 10, 0)
if nerr, ok := err.(*NumError); ok {
nerr.Func = fnAtoi
}
return int(i64), err
}
// underscoreOK reports whether the underscores in s are allowed.
// Checking them in this one function lets all the parsers skip over them simply.
// Underscore must appear only between digits or between a base prefix and a digit.
func underscoreOK(s string) bool {
// saw tracks the last character (class) we saw:
// ^ for beginning of number,
// 0 for a digit or base prefix,
// _ for an underscore,
// ! for none of the above.
saw := '^'
i := 0
// Optional sign.
if len(s) >= 1 && (s[0] == '-' || s[0] == '+') {
s = s[1:]
}
// Optional base prefix.
hex := false
if len(s) >= 2 && s[0] == '0' && (lower(s[1]) == 'b' || lower(s[1]) == 'o' || lower(s[1]) == 'x') {
i = 2
saw = '0' // base prefix counts as a digit for "underscore as digit separator"
hex = lower(s[1]) == 'x'
}
// Number proper.
for ; i < len(s); i++ {
// Digits are always okay.
if '0' <= s[i] && s[i] <= '9' || hex && 'a' <= lower(s[i]) && lower(s[i]) <= 'f' {
saw = '0'
continue
}
// Underscore must follow digit.
if s[i] == '_' {
if saw != '0' {
return false
}
saw = '_'
continue
}
// Underscore must also be followed by digit.
if saw == '_' {
return false
}
// Saw non-digit, non-underscore.
saw = '!'
}
return saw != '_'
}
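// exampleUnderscoreOK is an illustrative sketch, not part of the original
// source: an underscore must sit between digits (or between a base prefix
// and a digit), so doubled or trailing underscores are rejected.
func exampleUnderscoreOK() (bool, bool, bool) {
	return underscoreOK("0x_1F"), underscoreOK("1__2"), underscoreOK("12_") // true, false, false
}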
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !compiler_bootstrap
// +build !compiler_bootstrap
package strconv
import "internal/bytealg"
// index returns the index of the first instance of c in s, or -1 if missing.
func index(s string, c byte) int {
return bytealg.IndexByteString(s, c)
}
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package strconv
// FormatComplex converts the complex number c to a string of the
// form (a+bi) where a and b are the real and imaginary parts,
// formatted according to the format fmt and precision prec.
//
// The format fmt and precision prec have the same meaning as in FormatFloat.
// It rounds the result assuming that the original was obtained from a complex
// value of bitSize bits, which must be 64 for complex64 and 128 for complex128.
func FormatComplex(c complex128, fmt byte, prec, bitSize int) string {
if bitSize != 64 && bitSize != 128 {
panic("invalid bitSize")
}
bitSize >>= 1 // complex64 uses float32 internally
// Check if imaginary part has a sign. If not, add one.
im := FormatFloat(imag(c), fmt, prec, bitSize)
if im[0] != '+' && im[0] != '-' {
im = "+" + im
}
return "(" + FormatFloat(real(c), fmt, prec, bitSize) + im + "i)"
}
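// exampleFormatComplex is an illustrative sketch, not part of the original
// source: a '+' is inserted before the imaginary part only when FormatFloat
// does not already produce a sign.
func exampleFormatComplex() (string, string) {
	return FormatComplex(complex(1, 2), 'g', -1, 128), // "(1+2i)"
		FormatComplex(complex(1, -2), 'g', -1, 128) // "(1-2i)"
}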
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Multiprecision decimal numbers.
// For floating-point formatting only; not general purpose.
// Only operations are assign and (binary) left/right shift.
// Can do binary floating point in multiprecision decimal precisely
// because 2 divides 10; cannot do decimal floating point
// in multiprecision binary precisely.
package strconv
type decimal struct {
d [800]byte // digits, big-endian representation
nd int // number of digits used
dp int // decimal point
neg bool // negative flag
trunc bool // discarded nonzero digits beyond d[:nd]
}
func (a *decimal) String() string {
n := 10 + a.nd
if a.dp > 0 {
n += a.dp
}
if a.dp < 0 {
n += -a.dp
}
buf := make([]byte, n)
w := 0
switch {
case a.nd == 0:
return "0"
case a.dp <= 0:
// zeros fill space between decimal point and digits
buf[w] = '0'
w++
buf[w] = '.'
w++
w += digitZero(buf[w : w+-a.dp])
w += copy(buf[w:], a.d[0:a.nd])
case a.dp < a.nd:
// decimal point in middle of digits
w += copy(buf[w:], a.d[0:a.dp])
buf[w] = '.'
w++
w += copy(buf[w:], a.d[a.dp:a.nd])
default:
// zeros fill space between digits and decimal point
w += copy(buf[w:], a.d[0:a.nd])
w += digitZero(buf[w : w+a.dp-a.nd])
}
return string(buf[0:w])
}
func digitZero(dst []byte) int {
for i := range dst {
dst[i] = '0'
}
return len(dst)
}
// trim trailing zeros from number.
// (They are meaningless; the decimal point is tracked
// independent of the number of digits.)
func trim(a *decimal) {
for a.nd > 0 && a.d[a.nd-1] == '0' {
a.nd--
}
if a.nd == 0 {
a.dp = 0
}
}
// Assign v to a.
func (a *decimal) Assign(v uint64) {
var buf [24]byte
// Write reversed decimal in buf.
n := 0
for v > 0 {
v1 := v / 10
v -= 10 * v1
buf[n] = byte(v + '0')
n++
v = v1
}
// Reverse again to produce forward decimal in a.d.
a.nd = 0
for n--; n >= 0; n-- {
a.d[a.nd] = buf[n]
a.nd++
}
a.dp = a.nd
trim(a)
}
// Maximum shift that we can do in one pass without overflow.
// A uint has 32 or 64 bits, and we have to be able to accommodate 9<<k.
const uintSize = 32 << (^uint(0) >> 63)
const maxShift = uintSize - 4
// Binary shift right (/ 2) by k bits. k <= maxShift to avoid overflow.
func rightShift(a *decimal, k uint) {
r := 0 // read pointer
w := 0 // write pointer
// Pick up enough leading digits to cover first shift.
var n uint
for ; n>>k == 0; r++ {
if r >= a.nd {
if n == 0 {
// a == 0; shouldn't get here, but handle anyway.
a.nd = 0
return
}
for n>>k == 0 {
n = n * 10
r++
}
break
}
c := uint(a.d[r])
n = n*10 + c - '0'
}
a.dp -= r - 1
var mask uint = (1 << k) - 1
// Pick up a digit, put down a digit.
for ; r < a.nd; r++ {
c := uint(a.d[r])
dig := n >> k
n &= mask
a.d[w] = byte(dig + '0')
w++
n = n*10 + c - '0'
}
// Put down extra digits.
for n > 0 {
dig := n >> k
n &= mask
if w < len(a.d) {
a.d[w] = byte(dig + '0')
w++
} else if dig > 0 {
a.trunc = true
}
n = n * 10
}
a.nd = w
trim(a)
}
// Cheat sheet for left shift: table indexed by shift count giving
// number of new digits that will be introduced by that shift.
//
// For example, leftcheats[4] = {2, "625"}. That means that
// if we are shifting by 4 (multiplying by 16), it will add 2 digits
// when the string prefix is "625" through "999", and one fewer digit
// if the string prefix is "000" through "624".
//
// Credit for this trick goes to Ken.
type leftCheat struct {
delta int // number of new digits
cutoff string // delta is one smaller if the original digits are lexicographically less than this prefix
}
var leftcheats = []leftCheat{
// Leading digits of 1/2^i = 5^i.
// 5^23 is not an exact 64-bit floating point number,
// so have to use bc for the math.
// Go up to 60 to be large enough for 32bit and 64bit platforms.
/*
seq 60 | sed 's/^/5^/' | bc |
awk 'BEGIN{ print "\t{ 0, \"\" }," }
{
log2 = log(2)/log(10)
printf("\t{ %d, \"%s\" },\t// * %d\n",
int(log2*NR+1), $0, 2**NR)
}'
*/
{0, ""},
{1, "5"}, // * 2
{1, "25"}, // * 4
{1, "125"}, // * 8
{2, "625"}, // * 16
{2, "3125"}, // * 32
{2, "15625"}, // * 64
{3, "78125"}, // * 128
{3, "390625"}, // * 256
{3, "1953125"}, // * 512
{4, "9765625"}, // * 1024
{4, "48828125"}, // * 2048
{4, "244140625"}, // * 4096
{4, "1220703125"}, // * 8192
{5, "6103515625"}, // * 16384
{5, "30517578125"}, // * 32768
{5, "152587890625"}, // * 65536
{6, "762939453125"}, // * 131072
{6, "3814697265625"}, // * 262144
{6, "19073486328125"}, // * 524288
{7, "95367431640625"}, // * 1048576
{7, "476837158203125"}, // * 2097152
{7, "2384185791015625"}, // * 4194304
{7, "11920928955078125"}, // * 8388608
{8, "59604644775390625"}, // * 16777216
{8, "298023223876953125"}, // * 33554432
{8, "1490116119384765625"}, // * 67108864
{9, "7450580596923828125"}, // * 134217728
{9, "37252902984619140625"}, // * 268435456
{9, "186264514923095703125"}, // * 536870912
{10, "931322574615478515625"}, // * 1073741824
{10, "4656612873077392578125"}, // * 2147483648
{10, "23283064365386962890625"}, // * 4294967296
{10, "116415321826934814453125"}, // * 8589934592
{11, "582076609134674072265625"}, // * 17179869184
{11, "2910383045673370361328125"}, // * 34359738368
{11, "14551915228366851806640625"}, // * 68719476736
{12, "72759576141834259033203125"}, // * 137438953472
{12, "363797880709171295166015625"}, // * 274877906944
{12, "1818989403545856475830078125"}, // * 549755813888
{13, "9094947017729282379150390625"}, // * 1099511627776
{13, "45474735088646411895751953125"}, // * 2199023255552
{13, "227373675443232059478759765625"}, // * 4398046511104
{13, "1136868377216160297393798828125"}, // * 8796093022208
{14, "5684341886080801486968994140625"}, // * 17592186044416
{14, "28421709430404007434844970703125"}, // * 35184372088832
{14, "142108547152020037174224853515625"}, // * 70368744177664
{15, "710542735760100185871124267578125"}, // * 140737488355328
{15, "3552713678800500929355621337890625"}, // * 281474976710656
{15, "17763568394002504646778106689453125"}, // * 562949953421312
{16, "88817841970012523233890533447265625"}, // * 1125899906842624
{16, "444089209850062616169452667236328125"}, // * 2251799813685248
{16, "2220446049250313080847263336181640625"}, // * 4503599627370496
{16, "11102230246251565404236316680908203125"}, // * 9007199254740992
{17, "55511151231257827021181583404541015625"}, // * 18014398509481984
{17, "277555756156289135105907917022705078125"}, // * 36028797018963968
{17, "1387778780781445675529539585113525390625"}, // * 72057594037927936
{18, "6938893903907228377647697925567626953125"}, // * 144115188075855872
{18, "34694469519536141888238489627838134765625"}, // * 288230376151711744
{18, "173472347597680709441192448139190673828125"}, // * 576460752303423488
{19, "867361737988403547205962240695953369140625"}, // * 1152921504606846976
}
// Is the leading prefix of b lexicographically less than s?
func prefixIsLessThan(b []byte, s string) bool {
for i := 0; i < len(s); i++ {
if i >= len(b) {
return true
}
if b[i] != s[i] {
return b[i] < s[i]
}
}
return false
}
// Binary shift left (* 2) by k bits. k <= maxShift to avoid overflow.
func leftShift(a *decimal, k uint) {
delta := leftcheats[k].delta
if prefixIsLessThan(a.d[0:a.nd], leftcheats[k].cutoff) {
delta--
}
r := a.nd // read index
w := a.nd + delta // write index
// Pick up a digit, put down a digit.
var n uint
for r--; r >= 0; r-- {
n += (uint(a.d[r]) - '0') << k
quo := n / 10
rem := n - 10*quo
w--
if w < len(a.d) {
a.d[w] = byte(rem + '0')
} else if rem != 0 {
a.trunc = true
}
n = quo
}
// Put down extra digits.
for n > 0 {
quo := n / 10
rem := n - 10*quo
w--
if w < len(a.d) {
a.d[w] = byte(rem + '0')
} else if rem != 0 {
a.trunc = true
}
n = quo
}
a.nd += delta
if a.nd >= len(a.d) {
a.nd = len(a.d)
}
a.dp += delta
trim(a)
}
// Binary shift left (k > 0) or right (k < 0).
func (a *decimal) Shift(k int) {
switch {
case a.nd == 0:
// nothing to do: a == 0
case k > 0:
for k > maxShift {
leftShift(a, maxShift)
k -= maxShift
}
leftShift(a, uint(k))
case k < 0:
for k < -maxShift {
rightShift(a, maxShift)
k += maxShift
}
rightShift(a, uint(-k))
}
}
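// exampleShift is an illustrative sketch, not part of the original source:
// a left shift by 4 bits multiplies the stored decimal by 16, here turning
// 1 into "16" via the leftcheats table.
func exampleShift() string {
	var d decimal
	d.Assign(1)
	d.Shift(4)
	return d.String() // "16"
}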
// If we chop a at nd digits, should we round up?
func shouldRoundUp(a *decimal, nd int) bool {
if nd < 0 || nd >= a.nd {
return false
}
if a.d[nd] == '5' && nd+1 == a.nd { // exactly halfway - round to even
// if we truncated, a little higher than what's recorded - always round up
if a.trunc {
return true
}
return nd > 0 && (a.d[nd-1]-'0')%2 != 0
}
// not halfway - digit tells all
return a.d[nd] >= '5'
}
// Round a to nd digits (or fewer).
// If nd is zero, it means we're rounding
// just to the left of the digits, as in
// 0.09 -> 0.1.
func (a *decimal) Round(nd int) {
if nd < 0 || nd >= a.nd {
return
}
if shouldRoundUp(a, nd) {
a.RoundUp(nd)
} else {
a.RoundDown(nd)
}
}
// Round a down to nd digits (or fewer).
func (a *decimal) RoundDown(nd int) {
if nd < 0 || nd >= a.nd {
return
}
a.nd = nd
trim(a)
}
// Round a up to nd digits (or fewer).
func (a *decimal) RoundUp(nd int) {
if nd < 0 || nd >= a.nd {
return
}
// round up
for i := nd - 1; i >= 0; i-- {
c := a.d[i]
if c < '9' { // can stop after this digit
a.d[i]++
a.nd = i + 1
return
}
}
// Number is all 9s.
// Change to single 1 with adjusted decimal point.
a.d[0] = '1'
a.nd = 1
a.dp++
}
// Extract integer part, rounded appropriately.
// No guarantees about overflow.
func (a *decimal) RoundedInteger() uint64 {
if a.dp > 20 {
return 0xFFFFFFFFFFFFFFFF
}
var i int
n := uint64(0)
for i = 0; i < a.dp && i < a.nd; i++ {
n = n*10 + uint64(a.d[i]-'0')
}
for ; i < a.dp; i++ {
n *= 10
}
if shouldRoundUp(a, a.dp) {
n++
}
return n
}
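// exampleRoundedInteger is an illustrative sketch, not part of the original
// source: values exactly halfway round to even, so 2.5 becomes 2 while 3.5
// becomes 4.
func exampleRoundedInteger() (uint64, uint64) {
	var d1, d2 decimal
	d1.set("2.5")
	d2.set("3.5")
	return d1.RoundedInteger(), d2.RoundedInteger() // 2, 4
}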
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package strconv
// This file implements the Eisel-Lemire ParseFloat algorithm, published in
// 2020 and discussed extensively at
// https://nigeltao.github.io/blog/2020/eisel-lemire.html
//
// The original C++ implementation is at
// https://github.com/lemire/fast_double_parser/blob/644bef4306059d3be01a04e77d3cc84b379c596f/include/fast_double_parser.h#L840
//
// This Go re-implementation closely follows the C re-implementation at
// https://github.com/google/wuffs/blob/ba3818cb6b473a2ed0b38ecfc07dbbd3a97e8ae7/internal/cgen/base/floatconv-submodule-code.c#L990
//
// Additional testing (on over several million test strings) is done by
// https://github.com/nigeltao/parse-number-fxx-test-data/blob/5280dcfccf6d0b02a65ae282dad0b6d9de50e039/script/test-go-strconv.go
import (
"math"
"math/bits"
)
func eiselLemire64(man uint64, exp10 int, neg bool) (f float64, ok bool) {
// The terse comments in this function body refer to sections of the
// https://nigeltao.github.io/blog/2020/eisel-lemire.html blog post.
// Exp10 Range.
if man == 0 {
if neg {
f = math.Float64frombits(0x8000000000000000) // Negative zero.
}
return f, true
}
if exp10 < detailedPowersOfTenMinExp10 || detailedPowersOfTenMaxExp10 < exp10 {
return 0, false
}
// Normalization.
clz := bits.LeadingZeros64(man)
man <<= uint(clz)
const float64ExponentBias = 1023
retExp2 := uint64(217706*exp10>>16+64+float64ExponentBias) - uint64(clz)
// Multiplication.
xHi, xLo := bits.Mul64(man, detailedPowersOfTen[exp10-detailedPowersOfTenMinExp10][1])
// Wider Approximation.
if xHi&0x1FF == 0x1FF && xLo+man < man {
yHi, yLo := bits.Mul64(man, detailedPowersOfTen[exp10-detailedPowersOfTenMinExp10][0])
mergedHi, mergedLo := xHi, xLo+yHi
if mergedLo < xLo {
mergedHi++
}
if mergedHi&0x1FF == 0x1FF && mergedLo+1 == 0 && yLo+man < man {
return 0, false
}
xHi, xLo = mergedHi, mergedLo
}
// Shifting to 54 Bits.
msb := xHi >> 63
retMantissa := xHi >> (msb + 9)
retExp2 -= 1 ^ msb
// Half-way Ambiguity.
if xLo == 0 && xHi&0x1FF == 0 && retMantissa&3 == 1 {
return 0, false
}
// From 54 to 53 Bits.
retMantissa += retMantissa & 1
retMantissa >>= 1
if retMantissa>>53 > 0 {
retMantissa >>= 1
retExp2 += 1
}
// retExp2 is a uint64. Zero or underflow means that we're in subnormal
// float64 space. 0x7FF or above means that we're in Inf/NaN float64 space.
//
// The if block is equivalent to (but has fewer branches than):
// if retExp2 <= 0 || retExp2 >= 0x7FF { etc }
if retExp2-1 >= 0x7FF-1 {
return 0, false
}
retBits := retExp2<<52 | retMantissa&0x000FFFFFFFFFFFFF
if neg {
retBits |= 0x8000000000000000
}
return math.Float64frombits(retBits), true
}
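// exampleEiselLemire64 is an illustrative sketch, not part of the original
// source: for mantissa 15 and decimal exponent 1 the fast path resolves
// 15 * 10^1 exactly, so no fallback is needed.
func exampleEiselLemire64() (float64, bool) {
	return eiselLemire64(15, 1, false) // 150, true
}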
func eiselLemire32(man uint64, exp10 int, neg bool) (f float32, ok bool) {
// The terse comments in this function body refer to sections of the
// https://nigeltao.github.io/blog/2020/eisel-lemire.html blog post.
//
// That blog post discusses the float64 flavor (11 exponent bits with a
// -1023 bias, 52 mantissa bits) of the algorithm, but the same approach
// applies to the float32 flavor (8 exponent bits with a -127 bias, 23
// mantissa bits). The computation here happens with 64-bit values (e.g.
// man, xHi, retMantissa) before finally converting to a 32-bit float.
// Exp10 Range.
if man == 0 {
if neg {
f = math.Float32frombits(0x80000000) // Negative zero.
}
return f, true
}
if exp10 < detailedPowersOfTenMinExp10 || detailedPowersOfTenMaxExp10 < exp10 {
return 0, false
}
// Normalization.
clz := bits.LeadingZeros64(man)
man <<= uint(clz)
const float32ExponentBias = 127
retExp2 := uint64(217706*exp10>>16+64+float32ExponentBias) - uint64(clz)
// Multiplication.
xHi, xLo := bits.Mul64(man, detailedPowersOfTen[exp10-detailedPowersOfTenMinExp10][1])
// Wider Approximation.
if xHi&0x3FFFFFFFFF == 0x3FFFFFFFFF && xLo+man < man {
yHi, yLo := bits.Mul64(man, detailedPowersOfTen[exp10-detailedPowersOfTenMinExp10][0])
mergedHi, mergedLo := xHi, xLo+yHi
if mergedLo < xLo {
mergedHi++
}
if mergedHi&0x3FFFFFFFFF == 0x3FFFFFFFFF && mergedLo+1 == 0 && yLo+man < man {
return 0, false
}
xHi, xLo = mergedHi, mergedLo
}
// Shifting to 54 Bits (and for float32, it's shifting to 25 bits).
msb := xHi >> 63
retMantissa := xHi >> (msb + 38)
retExp2 -= 1 ^ msb
// Half-way Ambiguity.
if xLo == 0 && xHi&0x3FFFFFFFFF == 0 && retMantissa&3 == 1 {
return 0, false
}
// From 54 to 53 Bits (and for float32, it's from 25 to 24 bits).
retMantissa += retMantissa & 1
retMantissa >>= 1
if retMantissa>>24 > 0 {
retMantissa >>= 1
retExp2 += 1
}
// retExp2 is a uint64. Zero or underflow means that we're in subnormal
// float32 space. 0xFF or above means that we're in Inf/NaN float32 space.
//
// The if block is equivalent to (but has fewer branches than):
// if retExp2 <= 0 || retExp2 >= 0xFF { etc }
if retExp2-1 >= 0xFF-1 {
return 0, false
}
retBits := retExp2<<23 | retMantissa&0x007FFFFF
if neg {
retBits |= 0x80000000
}
return math.Float32frombits(uint32(retBits)), true
}
// detailedPowersOfTen{Min,Max}Exp10 is the power of 10 represented by the
// first and last rows of detailedPowersOfTen. Both bounds are inclusive.
const (
detailedPowersOfTenMinExp10 = -348
detailedPowersOfTenMaxExp10 = +347
)
// detailedPowersOfTen contains 128-bit mantissa approximations (rounded down)
// to the powers of 10. For example:
//
// - 1e43 ≈ (0xE596B7B0_C643C719 * (2 ** 79))
// - 1e43 = (0xE596B7B0_C643C719_6D9CCD05_D0000000 * (2 ** 15))
//
// The mantissas are explicitly listed. The exponents are implied by a linear
// expression with slope 217706.0/65536.0 ≈ log(10)/log(2).
//
// The table was generated by
// https://github.com/google/wuffs/blob/ba3818cb6b473a2ed0b38ecfc07dbbd3a97e8ae7/script/print-mpb-powers-of-10.go
var detailedPowersOfTen = [...][2]uint64{
{0x1732C869CD60E453, 0xFA8FD5A0081C0288}, // 1e-348
{0x0E7FBD42205C8EB4, 0x9C99E58405118195}, // 1e-347
{0x521FAC92A873B261, 0xC3C05EE50655E1FA}, // 1e-346
{0xE6A797B752909EF9, 0xF4B0769E47EB5A78}, // 1e-345
{0x9028BED2939A635C, 0x98EE4A22ECF3188B}, // 1e-344
{0x7432EE873880FC33, 0xBF29DCABA82FDEAE}, // 1e-343
{0x113FAA2906A13B3F, 0xEEF453D6923BD65A}, // 1e-342
{0x4AC7CA59A424C507, 0x9558B4661B6565F8}, // 1e-341
{0x5D79BCF00D2DF649, 0xBAAEE17FA23EBF76}, // 1e-340
{0xF4D82C2C107973DC, 0xE95A99DF8ACE6F53}, // 1e-339
{0x79071B9B8A4BE869, 0x91D8A02BB6C10594}, // 1e-338
{0x9748E2826CDEE284, 0xB64EC836A47146F9}, // 1e-337
{0xFD1B1B2308169B25, 0xE3E27A444D8D98B7}, // 1e-336
{0xFE30F0F5E50E20F7, 0x8E6D8C6AB0787F72}, // 1e-335
{0xBDBD2D335E51A935, 0xB208EF855C969F4F}, // 1e-334
{0xAD2C788035E61382, 0xDE8B2B66B3BC4723}, // 1e-333
{0x4C3BCB5021AFCC31, 0x8B16FB203055AC76}, // 1e-332
{0xDF4ABE242A1BBF3D, 0xADDCB9E83C6B1793}, // 1e-331
{0xD71D6DAD34A2AF0D, 0xD953E8624B85DD78}, // 1e-330
{0x8672648C40E5AD68, 0x87D4713D6F33AA6B}, // 1e-329
{0x680EFDAF511F18C2, 0xA9C98D8CCB009506}, // 1e-328
{0x0212BD1B2566DEF2, 0xD43BF0EFFDC0BA48}, // 1e-327
{0x014BB630F7604B57, 0x84A57695FE98746D}, // 1e-326
{0x419EA3BD35385E2D, 0xA5CED43B7E3E9188}, // 1e-325
{0x52064CAC828675B9, 0xCF42894A5DCE35EA}, // 1e-324
{0x7343EFEBD1940993, 0x818995CE7AA0E1B2}, // 1e-323
{0x1014EBE6C5F90BF8, 0xA1EBFB4219491A1F}, // 1e-322
{0xD41A26E077774EF6, 0xCA66FA129F9B60A6}, // 1e-321
{0x8920B098955522B4, 0xFD00B897478238D0}, // 1e-320
{0x55B46E5F5D5535B0, 0x9E20735E8CB16382}, // 1e-319
{0xEB2189F734AA831D, 0xC5A890362FDDBC62}, // 1e-318
{0xA5E9EC7501D523E4, 0xF712B443BBD52B7B}, // 1e-317
{0x47B233C92125366E, 0x9A6BB0AA55653B2D}, // 1e-316
{0x999EC0BB696E840A, 0xC1069CD4EABE89F8}, // 1e-315
{0xC00670EA43CA250D, 0xF148440A256E2C76}, // 1e-314
{0x380406926A5E5728, 0x96CD2A865764DBCA}, // 1e-313
{0xC605083704F5ECF2, 0xBC807527ED3E12BC}, // 1e-312
{0xF7864A44C633682E, 0xEBA09271E88D976B}, // 1e-311
{0x7AB3EE6AFBE0211D, 0x93445B8731587EA3}, // 1e-310
{0x5960EA05BAD82964, 0xB8157268FDAE9E4C}, // 1e-309
{0x6FB92487298E33BD, 0xE61ACF033D1A45DF}, // 1e-308
{0xA5D3B6D479F8E056, 0x8FD0C16206306BAB}, // 1e-307
{0x8F48A4899877186C, 0xB3C4F1BA87BC8696}, // 1e-306
{0x331ACDABFE94DE87, 0xE0B62E2929ABA83C}, // 1e-305
{0x9FF0C08B7F1D0B14, 0x8C71DCD9BA0B4925}, // 1e-304
{0x07ECF0AE5EE44DD9, 0xAF8E5410288E1B6F}, // 1e-303
{0xC9E82CD9F69D6150, 0xDB71E91432B1A24A}, // 1e-302
{0xBE311C083A225CD2, 0x892731AC9FAF056E}, // 1e-301
{0x6DBD630A48AAF406, 0xAB70FE17C79AC6CA}, // 1e-300
{0x092CBBCCDAD5B108, 0xD64D3D9DB981787D}, // 1e-299
{0x25BBF56008C58EA5, 0x85F0468293F0EB4E}, // 1e-298
{0xAF2AF2B80AF6F24E, 0xA76C582338ED2621}, // 1e-297
{0x1AF5AF660DB4AEE1, 0xD1476E2C07286FAA}, // 1e-296
{0x50D98D9FC890ED4D, 0x82CCA4DB847945CA}, // 1e-295
{0xE50FF107BAB528A0, 0xA37FCE126597973C}, // 1e-294
{0x1E53ED49A96272C8, 0xCC5FC196FEFD7D0C}, // 1e-293
{0x25E8E89C13BB0F7A, 0xFF77B1FCBEBCDC4F}, // 1e-292
{0x77B191618C54E9AC, 0x9FAACF3DF73609B1}, // 1e-291
{0xD59DF5B9EF6A2417, 0xC795830D75038C1D}, // 1e-290
{0x4B0573286B44AD1D, 0xF97AE3D0D2446F25}, // 1e-289
{0x4EE367F9430AEC32, 0x9BECCE62836AC577}, // 1e-288
{0x229C41F793CDA73F, 0xC2E801FB244576D5}, // 1e-287
{0x6B43527578C1110F, 0xF3A20279ED56D48A}, // 1e-286
{0x830A13896B78AAA9, 0x9845418C345644D6}, // 1e-285
{0x23CC986BC656D553, 0xBE5691EF416BD60C}, // 1e-284
{0x2CBFBE86B7EC8AA8, 0xEDEC366B11C6CB8F}, // 1e-283
{0x7BF7D71432F3D6A9, 0x94B3A202EB1C3F39}, // 1e-282
{0xDAF5CCD93FB0CC53, 0xB9E08A83A5E34F07}, // 1e-281
{0xD1B3400F8F9CFF68, 0xE858AD248F5C22C9}, // 1e-280
{0x23100809B9C21FA1, 0x91376C36D99995BE}, // 1e-279
{0xABD40A0C2832A78A, 0xB58547448FFFFB2D}, // 1e-278
{0x16C90C8F323F516C, 0xE2E69915B3FFF9F9}, // 1e-277
{0xAE3DA7D97F6792E3, 0x8DD01FAD907FFC3B}, // 1e-276
{0x99CD11CFDF41779C, 0xB1442798F49FFB4A}, // 1e-275
{0x40405643D711D583, 0xDD95317F31C7FA1D}, // 1e-274
{0x482835EA666B2572, 0x8A7D3EEF7F1CFC52}, // 1e-273
{0xDA3243650005EECF, 0xAD1C8EAB5EE43B66}, // 1e-272
{0x90BED43E40076A82, 0xD863B256369D4A40}, // 1e-271
{0x5A7744A6E804A291, 0x873E4F75E2224E68}, // 1e-270
{0x711515D0A205CB36, 0xA90DE3535AAAE202}, // 1e-269
{0x0D5A5B44CA873E03, 0xD3515C2831559A83}, // 1e-268
{0xE858790AFE9486C2, 0x8412D9991ED58091}, // 1e-267
{0x626E974DBE39A872, 0xA5178FFF668AE0B6}, // 1e-266
{0xFB0A3D212DC8128F, 0xCE5D73FF402D98E3}, // 1e-265
{0x7CE66634BC9D0B99, 0x80FA687F881C7F8E}, // 1e-264
{0x1C1FFFC1EBC44E80, 0xA139029F6A239F72}, // 1e-263
{0xA327FFB266B56220, 0xC987434744AC874E}, // 1e-262
{0x4BF1FF9F0062BAA8, 0xFBE9141915D7A922}, // 1e-261
{0x6F773FC3603DB4A9, 0x9D71AC8FADA6C9B5}, // 1e-260
{0xCB550FB4384D21D3, 0xC4CE17B399107C22}, // 1e-259
{0x7E2A53A146606A48, 0xF6019DA07F549B2B}, // 1e-258
{0x2EDA7444CBFC426D, 0x99C102844F94E0FB}, // 1e-257
{0xFA911155FEFB5308, 0xC0314325637A1939}, // 1e-256
{0x793555AB7EBA27CA, 0xF03D93EEBC589F88}, // 1e-255
{0x4BC1558B2F3458DE, 0x96267C7535B763B5}, // 1e-254
{0x9EB1AAEDFB016F16, 0xBBB01B9283253CA2}, // 1e-253
{0x465E15A979C1CADC, 0xEA9C227723EE8BCB}, // 1e-252
{0x0BFACD89EC191EC9, 0x92A1958A7675175F}, // 1e-251
{0xCEF980EC671F667B, 0xB749FAED14125D36}, // 1e-250
{0x82B7E12780E7401A, 0xE51C79A85916F484}, // 1e-249
{0xD1B2ECB8B0908810, 0x8F31CC0937AE58D2}, // 1e-248
{0x861FA7E6DCB4AA15, 0xB2FE3F0B8599EF07}, // 1e-247
{0x67A791E093E1D49A, 0xDFBDCECE67006AC9}, // 1e-246
{0xE0C8BB2C5C6D24E0, 0x8BD6A141006042BD}, // 1e-245
{0x58FAE9F773886E18, 0xAECC49914078536D}, // 1e-244
{0xAF39A475506A899E, 0xDA7F5BF590966848}, // 1e-243
{0x6D8406C952429603, 0x888F99797A5E012D}, // 1e-242
{0xC8E5087BA6D33B83, 0xAAB37FD7D8F58178}, // 1e-241
{0xFB1E4A9A90880A64, 0xD5605FCDCF32E1D6}, // 1e-240
{0x5CF2EEA09A55067F, 0x855C3BE0A17FCD26}, // 1e-239
{0xF42FAA48C0EA481E, 0xA6B34AD8C9DFC06F}, // 1e-238
{0xF13B94DAF124DA26, 0xD0601D8EFC57B08B}, // 1e-237
{0x76C53D08D6B70858, 0x823C12795DB6CE57}, // 1e-236
{0x54768C4B0C64CA6E, 0xA2CB1717B52481ED}, // 1e-235
{0xA9942F5DCF7DFD09, 0xCB7DDCDDA26DA268}, // 1e-234
{0xD3F93B35435D7C4C, 0xFE5D54150B090B02}, // 1e-233
{0xC47BC5014A1A6DAF, 0x9EFA548D26E5A6E1}, // 1e-232
{0x359AB6419CA1091B, 0xC6B8E9B0709F109A}, // 1e-231
{0xC30163D203C94B62, 0xF867241C8CC6D4C0}, // 1e-230
{0x79E0DE63425DCF1D, 0x9B407691D7FC44F8}, // 1e-229
{0x985915FC12F542E4, 0xC21094364DFB5636}, // 1e-228
{0x3E6F5B7B17B2939D, 0xF294B943E17A2BC4}, // 1e-227
{0xA705992CEECF9C42, 0x979CF3CA6CEC5B5A}, // 1e-226
{0x50C6FF782A838353, 0xBD8430BD08277231}, // 1e-225
{0xA4F8BF5635246428, 0xECE53CEC4A314EBD}, // 1e-224
{0x871B7795E136BE99, 0x940F4613AE5ED136}, // 1e-223
{0x28E2557B59846E3F, 0xB913179899F68584}, // 1e-222
{0x331AEADA2FE589CF, 0xE757DD7EC07426E5}, // 1e-221
{0x3FF0D2C85DEF7621, 0x9096EA6F3848984F}, // 1e-220
{0x0FED077A756B53A9, 0xB4BCA50B065ABE63}, // 1e-219
{0xD3E8495912C62894, 0xE1EBCE4DC7F16DFB}, // 1e-218
{0x64712DD7ABBBD95C, 0x8D3360F09CF6E4BD}, // 1e-217
{0xBD8D794D96AACFB3, 0xB080392CC4349DEC}, // 1e-216
{0xECF0D7A0FC5583A0, 0xDCA04777F541C567}, // 1e-215
{0xF41686C49DB57244, 0x89E42CAAF9491B60}, // 1e-214
{0x311C2875C522CED5, 0xAC5D37D5B79B6239}, // 1e-213
{0x7D633293366B828B, 0xD77485CB25823AC7}, // 1e-212
{0xAE5DFF9C02033197, 0x86A8D39EF77164BC}, // 1e-211
{0xD9F57F830283FDFC, 0xA8530886B54DBDEB}, // 1e-210
{0xD072DF63C324FD7B, 0xD267CAA862A12D66}, // 1e-209
{0x4247CB9E59F71E6D, 0x8380DEA93DA4BC60}, // 1e-208
{0x52D9BE85F074E608, 0xA46116538D0DEB78}, // 1e-207
{0x67902E276C921F8B, 0xCD795BE870516656}, // 1e-206
{0x00BA1CD8A3DB53B6, 0x806BD9714632DFF6}, // 1e-205
{0x80E8A40ECCD228A4, 0xA086CFCD97BF97F3}, // 1e-204
{0x6122CD128006B2CD, 0xC8A883C0FDAF7DF0}, // 1e-203
{0x796B805720085F81, 0xFAD2A4B13D1B5D6C}, // 1e-202
{0xCBE3303674053BB0, 0x9CC3A6EEC6311A63}, // 1e-201
{0xBEDBFC4411068A9C, 0xC3F490AA77BD60FC}, // 1e-200
{0xEE92FB5515482D44, 0xF4F1B4D515ACB93B}, // 1e-199
{0x751BDD152D4D1C4A, 0x991711052D8BF3C5}, // 1e-198
{0xD262D45A78A0635D, 0xBF5CD54678EEF0B6}, // 1e-197
{0x86FB897116C87C34, 0xEF340A98172AACE4}, // 1e-196
{0xD45D35E6AE3D4DA0, 0x9580869F0E7AAC0E}, // 1e-195
{0x8974836059CCA109, 0xBAE0A846D2195712}, // 1e-194
{0x2BD1A438703FC94B, 0xE998D258869FACD7}, // 1e-193
{0x7B6306A34627DDCF, 0x91FF83775423CC06}, // 1e-192
{0x1A3BC84C17B1D542, 0xB67F6455292CBF08}, // 1e-191
{0x20CABA5F1D9E4A93, 0xE41F3D6A7377EECA}, // 1e-190
{0x547EB47B7282EE9C, 0x8E938662882AF53E}, // 1e-189
{0xE99E619A4F23AA43, 0xB23867FB2A35B28D}, // 1e-188
{0x6405FA00E2EC94D4, 0xDEC681F9F4C31F31}, // 1e-187
{0xDE83BC408DD3DD04, 0x8B3C113C38F9F37E}, // 1e-186
{0x9624AB50B148D445, 0xAE0B158B4738705E}, // 1e-185
{0x3BADD624DD9B0957, 0xD98DDAEE19068C76}, // 1e-184
{0xE54CA5D70A80E5D6, 0x87F8A8D4CFA417C9}, // 1e-183
{0x5E9FCF4CCD211F4C, 0xA9F6D30A038D1DBC}, // 1e-182
{0x7647C3200069671F, 0xD47487CC8470652B}, // 1e-181
{0x29ECD9F40041E073, 0x84C8D4DFD2C63F3B}, // 1e-180
{0xF468107100525890, 0xA5FB0A17C777CF09}, // 1e-179
{0x7182148D4066EEB4, 0xCF79CC9DB955C2CC}, // 1e-178
{0xC6F14CD848405530, 0x81AC1FE293D599BF}, // 1e-177
{0xB8ADA00E5A506A7C, 0xA21727DB38CB002F}, // 1e-176
{0xA6D90811F0E4851C, 0xCA9CF1D206FDC03B}, // 1e-175
{0x908F4A166D1DA663, 0xFD442E4688BD304A}, // 1e-174
{0x9A598E4E043287FE, 0x9E4A9CEC15763E2E}, // 1e-173
{0x40EFF1E1853F29FD, 0xC5DD44271AD3CDBA}, // 1e-172
{0xD12BEE59E68EF47C, 0xF7549530E188C128}, // 1e-171
{0x82BB74F8301958CE, 0x9A94DD3E8CF578B9}, // 1e-170
{0xE36A52363C1FAF01, 0xC13A148E3032D6E7}, // 1e-169
{0xDC44E6C3CB279AC1, 0xF18899B1BC3F8CA1}, // 1e-168
{0x29AB103A5EF8C0B9, 0x96F5600F15A7B7E5}, // 1e-167
{0x7415D448F6B6F0E7, 0xBCB2B812DB11A5DE}, // 1e-166
{0x111B495B3464AD21, 0xEBDF661791D60F56}, // 1e-165
{0xCAB10DD900BEEC34, 0x936B9FCEBB25C995}, // 1e-164
{0x3D5D514F40EEA742, 0xB84687C269EF3BFB}, // 1e-163
{0x0CB4A5A3112A5112, 0xE65829B3046B0AFA}, // 1e-162
{0x47F0E785EABA72AB, 0x8FF71A0FE2C2E6DC}, // 1e-161
{0x59ED216765690F56, 0xB3F4E093DB73A093}, // 1e-160
{0x306869C13EC3532C, 0xE0F218B8D25088B8}, // 1e-159
{0x1E414218C73A13FB, 0x8C974F7383725573}, // 1e-158
{0xE5D1929EF90898FA, 0xAFBD2350644EEACF}, // 1e-157
{0xDF45F746B74ABF39, 0xDBAC6C247D62A583}, // 1e-156
{0x6B8BBA8C328EB783, 0x894BC396CE5DA772}, // 1e-155
{0x066EA92F3F326564, 0xAB9EB47C81F5114F}, // 1e-154
{0xC80A537B0EFEFEBD, 0xD686619BA27255A2}, // 1e-153
{0xBD06742CE95F5F36, 0x8613FD0145877585}, // 1e-152
{0x2C48113823B73704, 0xA798FC4196E952E7}, // 1e-151
{0xF75A15862CA504C5, 0xD17F3B51FCA3A7A0}, // 1e-150
{0x9A984D73DBE722FB, 0x82EF85133DE648C4}, // 1e-149
{0xC13E60D0D2E0EBBA, 0xA3AB66580D5FDAF5}, // 1e-148
{0x318DF905079926A8, 0xCC963FEE10B7D1B3}, // 1e-147
{0xFDF17746497F7052, 0xFFBBCFE994E5C61F}, // 1e-146
{0xFEB6EA8BEDEFA633, 0x9FD561F1FD0F9BD3}, // 1e-145
{0xFE64A52EE96B8FC0, 0xC7CABA6E7C5382C8}, // 1e-144
{0x3DFDCE7AA3C673B0, 0xF9BD690A1B68637B}, // 1e-143
{0x06BEA10CA65C084E, 0x9C1661A651213E2D}, // 1e-142
{0x486E494FCFF30A62, 0xC31BFA0FE5698DB8}, // 1e-141
{0x5A89DBA3C3EFCCFA, 0xF3E2F893DEC3F126}, // 1e-140
{0xF89629465A75E01C, 0x986DDB5C6B3A76B7}, // 1e-139
{0xF6BBB397F1135823, 0xBE89523386091465}, // 1e-138
{0x746AA07DED582E2C, 0xEE2BA6C0678B597F}, // 1e-137
{0xA8C2A44EB4571CDC, 0x94DB483840B717EF}, // 1e-136
{0x92F34D62616CE413, 0xBA121A4650E4DDEB}, // 1e-135
{0x77B020BAF9C81D17, 0xE896A0D7E51E1566}, // 1e-134
{0x0ACE1474DC1D122E, 0x915E2486EF32CD60}, // 1e-133
{0x0D819992132456BA, 0xB5B5ADA8AAFF80B8}, // 1e-132
{0x10E1FFF697ED6C69, 0xE3231912D5BF60E6}, // 1e-131
{0xCA8D3FFA1EF463C1, 0x8DF5EFABC5979C8F}, // 1e-130
{0xBD308FF8A6B17CB2, 0xB1736B96B6FD83B3}, // 1e-129
{0xAC7CB3F6D05DDBDE, 0xDDD0467C64BCE4A0}, // 1e-128
{0x6BCDF07A423AA96B, 0x8AA22C0DBEF60EE4}, // 1e-127
{0x86C16C98D2C953C6, 0xAD4AB7112EB3929D}, // 1e-126
{0xE871C7BF077BA8B7, 0xD89D64D57A607744}, // 1e-125
{0x11471CD764AD4972, 0x87625F056C7C4A8B}, // 1e-124
{0xD598E40D3DD89BCF, 0xA93AF6C6C79B5D2D}, // 1e-123
{0x4AFF1D108D4EC2C3, 0xD389B47879823479}, // 1e-122
{0xCEDF722A585139BA, 0x843610CB4BF160CB}, // 1e-121
{0xC2974EB4EE658828, 0xA54394FE1EEDB8FE}, // 1e-120
{0x733D226229FEEA32, 0xCE947A3DA6A9273E}, // 1e-119
{0x0806357D5A3F525F, 0x811CCC668829B887}, // 1e-118
{0xCA07C2DCB0CF26F7, 0xA163FF802A3426A8}, // 1e-117
{0xFC89B393DD02F0B5, 0xC9BCFF6034C13052}, // 1e-116
{0xBBAC2078D443ACE2, 0xFC2C3F3841F17C67}, // 1e-115
{0xD54B944B84AA4C0D, 0x9D9BA7832936EDC0}, // 1e-114
{0x0A9E795E65D4DF11, 0xC5029163F384A931}, // 1e-113
{0x4D4617B5FF4A16D5, 0xF64335BCF065D37D}, // 1e-112
{0x504BCED1BF8E4E45, 0x99EA0196163FA42E}, // 1e-111
{0xE45EC2862F71E1D6, 0xC06481FB9BCF8D39}, // 1e-110
{0x5D767327BB4E5A4C, 0xF07DA27A82C37088}, // 1e-109
{0x3A6A07F8D510F86F, 0x964E858C91BA2655}, // 1e-108
{0x890489F70A55368B, 0xBBE226EFB628AFEA}, // 1e-107
{0x2B45AC74CCEA842E, 0xEADAB0ABA3B2DBE5}, // 1e-106
{0x3B0B8BC90012929D, 0x92C8AE6B464FC96F}, // 1e-105
{0x09CE6EBB40173744, 0xB77ADA0617E3BBCB}, // 1e-104
{0xCC420A6A101D0515, 0xE55990879DDCAABD}, // 1e-103
{0x9FA946824A12232D, 0x8F57FA54C2A9EAB6}, // 1e-102
{0x47939822DC96ABF9, 0xB32DF8E9F3546564}, // 1e-101
{0x59787E2B93BC56F7, 0xDFF9772470297EBD}, // 1e-100
{0x57EB4EDB3C55B65A, 0x8BFBEA76C619EF36}, // 1e-99
{0xEDE622920B6B23F1, 0xAEFAE51477A06B03}, // 1e-98
{0xE95FAB368E45ECED, 0xDAB99E59958885C4}, // 1e-97
{0x11DBCB0218EBB414, 0x88B402F7FD75539B}, // 1e-96
{0xD652BDC29F26A119, 0xAAE103B5FCD2A881}, // 1e-95
{0x4BE76D3346F0495F, 0xD59944A37C0752A2}, // 1e-94
{0x6F70A4400C562DDB, 0x857FCAE62D8493A5}, // 1e-93
{0xCB4CCD500F6BB952, 0xA6DFBD9FB8E5B88E}, // 1e-92
{0x7E2000A41346A7A7, 0xD097AD07A71F26B2}, // 1e-91
{0x8ED400668C0C28C8, 0x825ECC24C873782F}, // 1e-90
{0x728900802F0F32FA, 0xA2F67F2DFA90563B}, // 1e-89
{0x4F2B40A03AD2FFB9, 0xCBB41EF979346BCA}, // 1e-88
{0xE2F610C84987BFA8, 0xFEA126B7D78186BC}, // 1e-87
{0x0DD9CA7D2DF4D7C9, 0x9F24B832E6B0F436}, // 1e-86
{0x91503D1C79720DBB, 0xC6EDE63FA05D3143}, // 1e-85
{0x75A44C6397CE912A, 0xF8A95FCF88747D94}, // 1e-84
{0xC986AFBE3EE11ABA, 0x9B69DBE1B548CE7C}, // 1e-83
{0xFBE85BADCE996168, 0xC24452DA229B021B}, // 1e-82
{0xFAE27299423FB9C3, 0xF2D56790AB41C2A2}, // 1e-81
{0xDCCD879FC967D41A, 0x97C560BA6B0919A5}, // 1e-80
{0x5400E987BBC1C920, 0xBDB6B8E905CB600F}, // 1e-79
{0x290123E9AAB23B68, 0xED246723473E3813}, // 1e-78
{0xF9A0B6720AAF6521, 0x9436C0760C86E30B}, // 1e-77
{0xF808E40E8D5B3E69, 0xB94470938FA89BCE}, // 1e-76
{0xB60B1D1230B20E04, 0xE7958CB87392C2C2}, // 1e-75
{0xB1C6F22B5E6F48C2, 0x90BD77F3483BB9B9}, // 1e-74
{0x1E38AEB6360B1AF3, 0xB4ECD5F01A4AA828}, // 1e-73
{0x25C6DA63C38DE1B0, 0xE2280B6C20DD5232}, // 1e-72
{0x579C487E5A38AD0E, 0x8D590723948A535F}, // 1e-71
{0x2D835A9DF0C6D851, 0xB0AF48EC79ACE837}, // 1e-70
{0xF8E431456CF88E65, 0xDCDB1B2798182244}, // 1e-69
{0x1B8E9ECB641B58FF, 0x8A08F0F8BF0F156B}, // 1e-68
{0xE272467E3D222F3F, 0xAC8B2D36EED2DAC5}, // 1e-67
{0x5B0ED81DCC6ABB0F, 0xD7ADF884AA879177}, // 1e-66
{0x98E947129FC2B4E9, 0x86CCBB52EA94BAEA}, // 1e-65
{0x3F2398D747B36224, 0xA87FEA27A539E9A5}, // 1e-64
{0x8EEC7F0D19A03AAD, 0xD29FE4B18E88640E}, // 1e-63
{0x1953CF68300424AC, 0x83A3EEEEF9153E89}, // 1e-62
{0x5FA8C3423C052DD7, 0xA48CEAAAB75A8E2B}, // 1e-61
{0x3792F412CB06794D, 0xCDB02555653131B6}, // 1e-60
{0xE2BBD88BBEE40BD0, 0x808E17555F3EBF11}, // 1e-59
{0x5B6ACEAEAE9D0EC4, 0xA0B19D2AB70E6ED6}, // 1e-58
{0xF245825A5A445275, 0xC8DE047564D20A8B}, // 1e-57
{0xEED6E2F0F0D56712, 0xFB158592BE068D2E}, // 1e-56
{0x55464DD69685606B, 0x9CED737BB6C4183D}, // 1e-55
{0xAA97E14C3C26B886, 0xC428D05AA4751E4C}, // 1e-54
{0xD53DD99F4B3066A8, 0xF53304714D9265DF}, // 1e-53
{0xE546A8038EFE4029, 0x993FE2C6D07B7FAB}, // 1e-52
{0xDE98520472BDD033, 0xBF8FDB78849A5F96}, // 1e-51
{0x963E66858F6D4440, 0xEF73D256A5C0F77C}, // 1e-50
{0xDDE7001379A44AA8, 0x95A8637627989AAD}, // 1e-49
{0x5560C018580D5D52, 0xBB127C53B17EC159}, // 1e-48
{0xAAB8F01E6E10B4A6, 0xE9D71B689DDE71AF}, // 1e-47
{0xCAB3961304CA70E8, 0x9226712162AB070D}, // 1e-46
{0x3D607B97C5FD0D22, 0xB6B00D69BB55C8D1}, // 1e-45
{0x8CB89A7DB77C506A, 0xE45C10C42A2B3B05}, // 1e-44
{0x77F3608E92ADB242, 0x8EB98A7A9A5B04E3}, // 1e-43
{0x55F038B237591ED3, 0xB267ED1940F1C61C}, // 1e-42
{0x6B6C46DEC52F6688, 0xDF01E85F912E37A3}, // 1e-41
{0x2323AC4B3B3DA015, 0x8B61313BBABCE2C6}, // 1e-40
{0xABEC975E0A0D081A, 0xAE397D8AA96C1B77}, // 1e-39
{0x96E7BD358C904A21, 0xD9C7DCED53C72255}, // 1e-38
{0x7E50D64177DA2E54, 0x881CEA14545C7575}, // 1e-37
{0xDDE50BD1D5D0B9E9, 0xAA242499697392D2}, // 1e-36
{0x955E4EC64B44E864, 0xD4AD2DBFC3D07787}, // 1e-35
{0xBD5AF13BEF0B113E, 0x84EC3C97DA624AB4}, // 1e-34
{0xECB1AD8AEACDD58E, 0xA6274BBDD0FADD61}, // 1e-33
{0x67DE18EDA5814AF2, 0xCFB11EAD453994BA}, // 1e-32
{0x80EACF948770CED7, 0x81CEB32C4B43FCF4}, // 1e-31
{0xA1258379A94D028D, 0xA2425FF75E14FC31}, // 1e-30
{0x096EE45813A04330, 0xCAD2F7F5359A3B3E}, // 1e-29
{0x8BCA9D6E188853FC, 0xFD87B5F28300CA0D}, // 1e-28
{0x775EA264CF55347D, 0x9E74D1B791E07E48}, // 1e-27
{0x95364AFE032A819D, 0xC612062576589DDA}, // 1e-26
{0x3A83DDBD83F52204, 0xF79687AED3EEC551}, // 1e-25
{0xC4926A9672793542, 0x9ABE14CD44753B52}, // 1e-24
{0x75B7053C0F178293, 0xC16D9A0095928A27}, // 1e-23
{0x5324C68B12DD6338, 0xF1C90080BAF72CB1}, // 1e-22
{0xD3F6FC16EBCA5E03, 0x971DA05074DA7BEE}, // 1e-21
{0x88F4BB1CA6BCF584, 0xBCE5086492111AEA}, // 1e-20
{0x2B31E9E3D06C32E5, 0xEC1E4A7DB69561A5}, // 1e-19
{0x3AFF322E62439FCF, 0x9392EE8E921D5D07}, // 1e-18
{0x09BEFEB9FAD487C2, 0xB877AA3236A4B449}, // 1e-17
{0x4C2EBE687989A9B3, 0xE69594BEC44DE15B}, // 1e-16
{0x0F9D37014BF60A10, 0x901D7CF73AB0ACD9}, // 1e-15
{0x538484C19EF38C94, 0xB424DC35095CD80F}, // 1e-14
{0x2865A5F206B06FB9, 0xE12E13424BB40E13}, // 1e-13
{0xF93F87B7442E45D3, 0x8CBCCC096F5088CB}, // 1e-12
{0xF78F69A51539D748, 0xAFEBFF0BCB24AAFE}, // 1e-11
{0xB573440E5A884D1B, 0xDBE6FECEBDEDD5BE}, // 1e-10
{0x31680A88F8953030, 0x89705F4136B4A597}, // 1e-9
{0xFDC20D2B36BA7C3D, 0xABCC77118461CEFC}, // 1e-8
{0x3D32907604691B4C, 0xD6BF94D5E57A42BC}, // 1e-7
{0xA63F9A49C2C1B10F, 0x8637BD05AF6C69B5}, // 1e-6
{0x0FCF80DC33721D53, 0xA7C5AC471B478423}, // 1e-5
{0xD3C36113404EA4A8, 0xD1B71758E219652B}, // 1e-4
{0x645A1CAC083126E9, 0x83126E978D4FDF3B}, // 1e-3
{0x3D70A3D70A3D70A3, 0xA3D70A3D70A3D70A}, // 1e-2
{0xCCCCCCCCCCCCCCCC, 0xCCCCCCCCCCCCCCCC}, // 1e-1
{0x0000000000000000, 0x8000000000000000}, // 1e0
{0x0000000000000000, 0xA000000000000000}, // 1e1
{0x0000000000000000, 0xC800000000000000}, // 1e2
{0x0000000000000000, 0xFA00000000000000}, // 1e3
{0x0000000000000000, 0x9C40000000000000}, // 1e4
{0x0000000000000000, 0xC350000000000000}, // 1e5
{0x0000000000000000, 0xF424000000000000}, // 1e6
{0x0000000000000000, 0x9896800000000000}, // 1e7
{0x0000000000000000, 0xBEBC200000000000}, // 1e8
{0x0000000000000000, 0xEE6B280000000000}, // 1e9
{0x0000000000000000, 0x9502F90000000000}, // 1e10
{0x0000000000000000, 0xBA43B74000000000}, // 1e11
{0x0000000000000000, 0xE8D4A51000000000}, // 1e12
{0x0000000000000000, 0x9184E72A00000000}, // 1e13
{0x0000000000000000, 0xB5E620F480000000}, // 1e14
{0x0000000000000000, 0xE35FA931A0000000}, // 1e15
{0x0000000000000000, 0x8E1BC9BF04000000}, // 1e16
{0x0000000000000000, 0xB1A2BC2EC5000000}, // 1e17
{0x0000000000000000, 0xDE0B6B3A76400000}, // 1e18
{0x0000000000000000, 0x8AC7230489E80000}, // 1e19
{0x0000000000000000, 0xAD78EBC5AC620000}, // 1e20
{0x0000000000000000, 0xD8D726B7177A8000}, // 1e21
{0x0000000000000000, 0x878678326EAC9000}, // 1e22
{0x0000000000000000, 0xA968163F0A57B400}, // 1e23
{0x0000000000000000, 0xD3C21BCECCEDA100}, // 1e24
{0x0000000000000000, 0x84595161401484A0}, // 1e25
{0x0000000000000000, 0xA56FA5B99019A5C8}, // 1e26
{0x0000000000000000, 0xCECB8F27F4200F3A}, // 1e27
{0x4000000000000000, 0x813F3978F8940984}, // 1e28
{0x5000000000000000, 0xA18F07D736B90BE5}, // 1e29
{0xA400000000000000, 0xC9F2C9CD04674EDE}, // 1e30
{0x4D00000000000000, 0xFC6F7C4045812296}, // 1e31
{0xF020000000000000, 0x9DC5ADA82B70B59D}, // 1e32
{0x6C28000000000000, 0xC5371912364CE305}, // 1e33
{0xC732000000000000, 0xF684DF56C3E01BC6}, // 1e34
{0x3C7F400000000000, 0x9A130B963A6C115C}, // 1e35
{0x4B9F100000000000, 0xC097CE7BC90715B3}, // 1e36
{0x1E86D40000000000, 0xF0BDC21ABB48DB20}, // 1e37
{0x1314448000000000, 0x96769950B50D88F4}, // 1e38
{0x17D955A000000000, 0xBC143FA4E250EB31}, // 1e39
{0x5DCFAB0800000000, 0xEB194F8E1AE525FD}, // 1e40
{0x5AA1CAE500000000, 0x92EFD1B8D0CF37BE}, // 1e41
{0xF14A3D9E40000000, 0xB7ABC627050305AD}, // 1e42
{0x6D9CCD05D0000000, 0xE596B7B0C643C719}, // 1e43
{0xE4820023A2000000, 0x8F7E32CE7BEA5C6F}, // 1e44
{0xDDA2802C8A800000, 0xB35DBF821AE4F38B}, // 1e45
{0xD50B2037AD200000, 0xE0352F62A19E306E}, // 1e46
{0x4526F422CC340000, 0x8C213D9DA502DE45}, // 1e47
{0x9670B12B7F410000, 0xAF298D050E4395D6}, // 1e48
{0x3C0CDD765F114000, 0xDAF3F04651D47B4C}, // 1e49
{0xA5880A69FB6AC800, 0x88D8762BF324CD0F}, // 1e50
{0x8EEA0D047A457A00, 0xAB0E93B6EFEE0053}, // 1e51
{0x72A4904598D6D880, 0xD5D238A4ABE98068}, // 1e52
{0x47A6DA2B7F864750, 0x85A36366EB71F041}, // 1e53
{0x999090B65F67D924, 0xA70C3C40A64E6C51}, // 1e54
{0xFFF4B4E3F741CF6D, 0xD0CF4B50CFE20765}, // 1e55
{0xBFF8F10E7A8921A4, 0x82818F1281ED449F}, // 1e56
{0xAFF72D52192B6A0D, 0xA321F2D7226895C7}, // 1e57
{0x9BF4F8A69F764490, 0xCBEA6F8CEB02BB39}, // 1e58
{0x02F236D04753D5B4, 0xFEE50B7025C36A08}, // 1e59
{0x01D762422C946590, 0x9F4F2726179A2245}, // 1e60
{0x424D3AD2B7B97EF5, 0xC722F0EF9D80AAD6}, // 1e61
{0xD2E0898765A7DEB2, 0xF8EBAD2B84E0D58B}, // 1e62
{0x63CC55F49F88EB2F, 0x9B934C3B330C8577}, // 1e63
{0x3CBF6B71C76B25FB, 0xC2781F49FFCFA6D5}, // 1e64
{0x8BEF464E3945EF7A, 0xF316271C7FC3908A}, // 1e65
{0x97758BF0E3CBB5AC, 0x97EDD871CFDA3A56}, // 1e66
{0x3D52EEED1CBEA317, 0xBDE94E8E43D0C8EC}, // 1e67
{0x4CA7AAA863EE4BDD, 0xED63A231D4C4FB27}, // 1e68
{0x8FE8CAA93E74EF6A, 0x945E455F24FB1CF8}, // 1e69
{0xB3E2FD538E122B44, 0xB975D6B6EE39E436}, // 1e70
{0x60DBBCA87196B616, 0xE7D34C64A9C85D44}, // 1e71
{0xBC8955E946FE31CD, 0x90E40FBEEA1D3A4A}, // 1e72
{0x6BABAB6398BDBE41, 0xB51D13AEA4A488DD}, // 1e73
{0xC696963C7EED2DD1, 0xE264589A4DCDAB14}, // 1e74
{0xFC1E1DE5CF543CA2, 0x8D7EB76070A08AEC}, // 1e75
{0x3B25A55F43294BCB, 0xB0DE65388CC8ADA8}, // 1e76
{0x49EF0EB713F39EBE, 0xDD15FE86AFFAD912}, // 1e77
{0x6E3569326C784337, 0x8A2DBF142DFCC7AB}, // 1e78
{0x49C2C37F07965404, 0xACB92ED9397BF996}, // 1e79
{0xDC33745EC97BE906, 0xD7E77A8F87DAF7FB}, // 1e80
{0x69A028BB3DED71A3, 0x86F0AC99B4E8DAFD}, // 1e81
{0xC40832EA0D68CE0C, 0xA8ACD7C0222311BC}, // 1e82
{0xF50A3FA490C30190, 0xD2D80DB02AABD62B}, // 1e83
{0x792667C6DA79E0FA, 0x83C7088E1AAB65DB}, // 1e84
{0x577001B891185938, 0xA4B8CAB1A1563F52}, // 1e85
{0xED4C0226B55E6F86, 0xCDE6FD5E09ABCF26}, // 1e86
{0x544F8158315B05B4, 0x80B05E5AC60B6178}, // 1e87
{0x696361AE3DB1C721, 0xA0DC75F1778E39D6}, // 1e88
{0x03BC3A19CD1E38E9, 0xC913936DD571C84C}, // 1e89
{0x04AB48A04065C723, 0xFB5878494ACE3A5F}, // 1e90
{0x62EB0D64283F9C76, 0x9D174B2DCEC0E47B}, // 1e91
{0x3BA5D0BD324F8394, 0xC45D1DF942711D9A}, // 1e92
{0xCA8F44EC7EE36479, 0xF5746577930D6500}, // 1e93
{0x7E998B13CF4E1ECB, 0x9968BF6ABBE85F20}, // 1e94
{0x9E3FEDD8C321A67E, 0xBFC2EF456AE276E8}, // 1e95
{0xC5CFE94EF3EA101E, 0xEFB3AB16C59B14A2}, // 1e96
{0xBBA1F1D158724A12, 0x95D04AEE3B80ECE5}, // 1e97
{0x2A8A6E45AE8EDC97, 0xBB445DA9CA61281F}, // 1e98
{0xF52D09D71A3293BD, 0xEA1575143CF97226}, // 1e99
{0x593C2626705F9C56, 0x924D692CA61BE758}, // 1e100
{0x6F8B2FB00C77836C, 0xB6E0C377CFA2E12E}, // 1e101
{0x0B6DFB9C0F956447, 0xE498F455C38B997A}, // 1e102
{0x4724BD4189BD5EAC, 0x8EDF98B59A373FEC}, // 1e103
{0x58EDEC91EC2CB657, 0xB2977EE300C50FE7}, // 1e104
{0x2F2967B66737E3ED, 0xDF3D5E9BC0F653E1}, // 1e105
{0xBD79E0D20082EE74, 0x8B865B215899F46C}, // 1e106
{0xECD8590680A3AA11, 0xAE67F1E9AEC07187}, // 1e107
{0xE80E6F4820CC9495, 0xDA01EE641A708DE9}, // 1e108
{0x3109058D147FDCDD, 0x884134FE908658B2}, // 1e109
{0xBD4B46F0599FD415, 0xAA51823E34A7EEDE}, // 1e110
{0x6C9E18AC7007C91A, 0xD4E5E2CDC1D1EA96}, // 1e111
{0x03E2CF6BC604DDB0, 0x850FADC09923329E}, // 1e112
{0x84DB8346B786151C, 0xA6539930BF6BFF45}, // 1e113
{0xE612641865679A63, 0xCFE87F7CEF46FF16}, // 1e114
{0x4FCB7E8F3F60C07E, 0x81F14FAE158C5F6E}, // 1e115
{0xE3BE5E330F38F09D, 0xA26DA3999AEF7749}, // 1e116
{0x5CADF5BFD3072CC5, 0xCB090C8001AB551C}, // 1e117
{0x73D9732FC7C8F7F6, 0xFDCB4FA002162A63}, // 1e118
{0x2867E7FDDCDD9AFA, 0x9E9F11C4014DDA7E}, // 1e119
{0xB281E1FD541501B8, 0xC646D63501A1511D}, // 1e120
{0x1F225A7CA91A4226, 0xF7D88BC24209A565}, // 1e121
{0x3375788DE9B06958, 0x9AE757596946075F}, // 1e122
{0x0052D6B1641C83AE, 0xC1A12D2FC3978937}, // 1e123
{0xC0678C5DBD23A49A, 0xF209787BB47D6B84}, // 1e124
{0xF840B7BA963646E0, 0x9745EB4D50CE6332}, // 1e125
{0xB650E5A93BC3D898, 0xBD176620A501FBFF}, // 1e126
{0xA3E51F138AB4CEBE, 0xEC5D3FA8CE427AFF}, // 1e127
{0xC66F336C36B10137, 0x93BA47C980E98CDF}, // 1e128
{0xB80B0047445D4184, 0xB8A8D9BBE123F017}, // 1e129
{0xA60DC059157491E5, 0xE6D3102AD96CEC1D}, // 1e130
{0x87C89837AD68DB2F, 0x9043EA1AC7E41392}, // 1e131
{0x29BABE4598C311FB, 0xB454E4A179DD1877}, // 1e132
{0xF4296DD6FEF3D67A, 0xE16A1DC9D8545E94}, // 1e133
{0x1899E4A65F58660C, 0x8CE2529E2734BB1D}, // 1e134
{0x5EC05DCFF72E7F8F, 0xB01AE745B101E9E4}, // 1e135
{0x76707543F4FA1F73, 0xDC21A1171D42645D}, // 1e136
{0x6A06494A791C53A8, 0x899504AE72497EBA}, // 1e137
{0x0487DB9D17636892, 0xABFA45DA0EDBDE69}, // 1e138
{0x45A9D2845D3C42B6, 0xD6F8D7509292D603}, // 1e139
{0x0B8A2392BA45A9B2, 0x865B86925B9BC5C2}, // 1e140
{0x8E6CAC7768D7141E, 0xA7F26836F282B732}, // 1e141
{0x3207D795430CD926, 0xD1EF0244AF2364FF}, // 1e142
{0x7F44E6BD49E807B8, 0x8335616AED761F1F}, // 1e143
{0x5F16206C9C6209A6, 0xA402B9C5A8D3A6E7}, // 1e144
{0x36DBA887C37A8C0F, 0xCD036837130890A1}, // 1e145
{0xC2494954DA2C9789, 0x802221226BE55A64}, // 1e146
{0xF2DB9BAA10B7BD6C, 0xA02AA96B06DEB0FD}, // 1e147
{0x6F92829494E5ACC7, 0xC83553C5C8965D3D}, // 1e148
{0xCB772339BA1F17F9, 0xFA42A8B73ABBF48C}, // 1e149
{0xFF2A760414536EFB, 0x9C69A97284B578D7}, // 1e150
{0xFEF5138519684ABA, 0xC38413CF25E2D70D}, // 1e151
{0x7EB258665FC25D69, 0xF46518C2EF5B8CD1}, // 1e152
{0xEF2F773FFBD97A61, 0x98BF2F79D5993802}, // 1e153
{0xAAFB550FFACFD8FA, 0xBEEEFB584AFF8603}, // 1e154
{0x95BA2A53F983CF38, 0xEEAABA2E5DBF6784}, // 1e155
{0xDD945A747BF26183, 0x952AB45CFA97A0B2}, // 1e156
{0x94F971119AEEF9E4, 0xBA756174393D88DF}, // 1e157
{0x7A37CD5601AAB85D, 0xE912B9D1478CEB17}, // 1e158
{0xAC62E055C10AB33A, 0x91ABB422CCB812EE}, // 1e159
{0x577B986B314D6009, 0xB616A12B7FE617AA}, // 1e160
{0xED5A7E85FDA0B80B, 0xE39C49765FDF9D94}, // 1e161
{0x14588F13BE847307, 0x8E41ADE9FBEBC27D}, // 1e162
{0x596EB2D8AE258FC8, 0xB1D219647AE6B31C}, // 1e163
{0x6FCA5F8ED9AEF3BB, 0xDE469FBD99A05FE3}, // 1e164
{0x25DE7BB9480D5854, 0x8AEC23D680043BEE}, // 1e165
{0xAF561AA79A10AE6A, 0xADA72CCC20054AE9}, // 1e166
{0x1B2BA1518094DA04, 0xD910F7FF28069DA4}, // 1e167
{0x90FB44D2F05D0842, 0x87AA9AFF79042286}, // 1e168
{0x353A1607AC744A53, 0xA99541BF57452B28}, // 1e169
{0x42889B8997915CE8, 0xD3FA922F2D1675F2}, // 1e170
{0x69956135FEBADA11, 0x847C9B5D7C2E09B7}, // 1e171
{0x43FAB9837E699095, 0xA59BC234DB398C25}, // 1e172
{0x94F967E45E03F4BB, 0xCF02B2C21207EF2E}, // 1e173
{0x1D1BE0EEBAC278F5, 0x8161AFB94B44F57D}, // 1e174
{0x6462D92A69731732, 0xA1BA1BA79E1632DC}, // 1e175
{0x7D7B8F7503CFDCFE, 0xCA28A291859BBF93}, // 1e176
{0x5CDA735244C3D43E, 0xFCB2CB35E702AF78}, // 1e177
{0x3A0888136AFA64A7, 0x9DEFBF01B061ADAB}, // 1e178
{0x088AAA1845B8FDD0, 0xC56BAEC21C7A1916}, // 1e179
{0x8AAD549E57273D45, 0xF6C69A72A3989F5B}, // 1e180
{0x36AC54E2F678864B, 0x9A3C2087A63F6399}, // 1e181
{0x84576A1BB416A7DD, 0xC0CB28A98FCF3C7F}, // 1e182
{0x656D44A2A11C51D5, 0xF0FDF2D3F3C30B9F}, // 1e183
{0x9F644AE5A4B1B325, 0x969EB7C47859E743}, // 1e184
{0x873D5D9F0DDE1FEE, 0xBC4665B596706114}, // 1e185
{0xA90CB506D155A7EA, 0xEB57FF22FC0C7959}, // 1e186
{0x09A7F12442D588F2, 0x9316FF75DD87CBD8}, // 1e187
{0x0C11ED6D538AEB2F, 0xB7DCBF5354E9BECE}, // 1e188
{0x8F1668C8A86DA5FA, 0xE5D3EF282A242E81}, // 1e189
{0xF96E017D694487BC, 0x8FA475791A569D10}, // 1e190
{0x37C981DCC395A9AC, 0xB38D92D760EC4455}, // 1e191
{0x85BBE253F47B1417, 0xE070F78D3927556A}, // 1e192
{0x93956D7478CCEC8E, 0x8C469AB843B89562}, // 1e193
{0x387AC8D1970027B2, 0xAF58416654A6BABB}, // 1e194
{0x06997B05FCC0319E, 0xDB2E51BFE9D0696A}, // 1e195
{0x441FECE3BDF81F03, 0x88FCF317F22241E2}, // 1e196
{0xD527E81CAD7626C3, 0xAB3C2FDDEEAAD25A}, // 1e197
{0x8A71E223D8D3B074, 0xD60B3BD56A5586F1}, // 1e198
{0xF6872D5667844E49, 0x85C7056562757456}, // 1e199
{0xB428F8AC016561DB, 0xA738C6BEBB12D16C}, // 1e200
{0xE13336D701BEBA52, 0xD106F86E69D785C7}, // 1e201
{0xECC0024661173473, 0x82A45B450226B39C}, // 1e202
{0x27F002D7F95D0190, 0xA34D721642B06084}, // 1e203
{0x31EC038DF7B441F4, 0xCC20CE9BD35C78A5}, // 1e204
{0x7E67047175A15271, 0xFF290242C83396CE}, // 1e205
{0x0F0062C6E984D386, 0x9F79A169BD203E41}, // 1e206
{0x52C07B78A3E60868, 0xC75809C42C684DD1}, // 1e207
{0xA7709A56CCDF8A82, 0xF92E0C3537826145}, // 1e208
{0x88A66076400BB691, 0x9BBCC7A142B17CCB}, // 1e209
{0x6ACFF893D00EA435, 0xC2ABF989935DDBFE}, // 1e210
{0x0583F6B8C4124D43, 0xF356F7EBF83552FE}, // 1e211
{0xC3727A337A8B704A, 0x98165AF37B2153DE}, // 1e212
{0x744F18C0592E4C5C, 0xBE1BF1B059E9A8D6}, // 1e213
{0x1162DEF06F79DF73, 0xEDA2EE1C7064130C}, // 1e214
{0x8ADDCB5645AC2BA8, 0x9485D4D1C63E8BE7}, // 1e215
{0x6D953E2BD7173692, 0xB9A74A0637CE2EE1}, // 1e216
{0xC8FA8DB6CCDD0437, 0xE8111C87C5C1BA99}, // 1e217
{0x1D9C9892400A22A2, 0x910AB1D4DB9914A0}, // 1e218
{0x2503BEB6D00CAB4B, 0xB54D5E4A127F59C8}, // 1e219
{0x2E44AE64840FD61D, 0xE2A0B5DC971F303A}, // 1e220
{0x5CEAECFED289E5D2, 0x8DA471A9DE737E24}, // 1e221
{0x7425A83E872C5F47, 0xB10D8E1456105DAD}, // 1e222
{0xD12F124E28F77719, 0xDD50F1996B947518}, // 1e223
{0x82BD6B70D99AAA6F, 0x8A5296FFE33CC92F}, // 1e224
{0x636CC64D1001550B, 0xACE73CBFDC0BFB7B}, // 1e225
{0x3C47F7E05401AA4E, 0xD8210BEFD30EFA5A}, // 1e226
{0x65ACFAEC34810A71, 0x8714A775E3E95C78}, // 1e227
{0x7F1839A741A14D0D, 0xA8D9D1535CE3B396}, // 1e228
{0x1EDE48111209A050, 0xD31045A8341CA07C}, // 1e229
{0x934AED0AAB460432, 0x83EA2B892091E44D}, // 1e230
{0xF81DA84D5617853F, 0xA4E4B66B68B65D60}, // 1e231
{0x36251260AB9D668E, 0xCE1DE40642E3F4B9}, // 1e232
{0xC1D72B7C6B426019, 0x80D2AE83E9CE78F3}, // 1e233
{0xB24CF65B8612F81F, 0xA1075A24E4421730}, // 1e234
{0xDEE033F26797B627, 0xC94930AE1D529CFC}, // 1e235
{0x169840EF017DA3B1, 0xFB9B7CD9A4A7443C}, // 1e236
{0x8E1F289560EE864E, 0x9D412E0806E88AA5}, // 1e237
{0xF1A6F2BAB92A27E2, 0xC491798A08A2AD4E}, // 1e238
{0xAE10AF696774B1DB, 0xF5B5D7EC8ACB58A2}, // 1e239
{0xACCA6DA1E0A8EF29, 0x9991A6F3D6BF1765}, // 1e240
{0x17FD090A58D32AF3, 0xBFF610B0CC6EDD3F}, // 1e241
{0xDDFC4B4CEF07F5B0, 0xEFF394DCFF8A948E}, // 1e242
{0x4ABDAF101564F98E, 0x95F83D0A1FB69CD9}, // 1e243
{0x9D6D1AD41ABE37F1, 0xBB764C4CA7A4440F}, // 1e244
{0x84C86189216DC5ED, 0xEA53DF5FD18D5513}, // 1e245
{0x32FD3CF5B4E49BB4, 0x92746B9BE2F8552C}, // 1e246
{0x3FBC8C33221DC2A1, 0xB7118682DBB66A77}, // 1e247
{0x0FABAF3FEAA5334A, 0xE4D5E82392A40515}, // 1e248
{0x29CB4D87F2A7400E, 0x8F05B1163BA6832D}, // 1e249
{0x743E20E9EF511012, 0xB2C71D5BCA9023F8}, // 1e250
{0x914DA9246B255416, 0xDF78E4B2BD342CF6}, // 1e251
{0x1AD089B6C2F7548E, 0x8BAB8EEFB6409C1A}, // 1e252
{0xA184AC2473B529B1, 0xAE9672ABA3D0C320}, // 1e253
{0xC9E5D72D90A2741E, 0xDA3C0F568CC4F3E8}, // 1e254
{0x7E2FA67C7A658892, 0x8865899617FB1871}, // 1e255
{0xDDBB901B98FEEAB7, 0xAA7EEBFB9DF9DE8D}, // 1e256
{0x552A74227F3EA565, 0xD51EA6FA85785631}, // 1e257
{0xD53A88958F87275F, 0x8533285C936B35DE}, // 1e258
{0x8A892ABAF368F137, 0xA67FF273B8460356}, // 1e259
{0x2D2B7569B0432D85, 0xD01FEF10A657842C}, // 1e260
{0x9C3B29620E29FC73, 0x8213F56A67F6B29B}, // 1e261
{0x8349F3BA91B47B8F, 0xA298F2C501F45F42}, // 1e262
{0x241C70A936219A73, 0xCB3F2F7642717713}, // 1e263
{0xED238CD383AA0110, 0xFE0EFB53D30DD4D7}, // 1e264
{0xF4363804324A40AA, 0x9EC95D1463E8A506}, // 1e265
{0xB143C6053EDCD0D5, 0xC67BB4597CE2CE48}, // 1e266
{0xDD94B7868E94050A, 0xF81AA16FDC1B81DA}, // 1e267
{0xCA7CF2B4191C8326, 0x9B10A4E5E9913128}, // 1e268
{0xFD1C2F611F63A3F0, 0xC1D4CE1F63F57D72}, // 1e269
{0xBC633B39673C8CEC, 0xF24A01A73CF2DCCF}, // 1e270
{0xD5BE0503E085D813, 0x976E41088617CA01}, // 1e271
{0x4B2D8644D8A74E18, 0xBD49D14AA79DBC82}, // 1e272
{0xDDF8E7D60ED1219E, 0xEC9C459D51852BA2}, // 1e273
{0xCABB90E5C942B503, 0x93E1AB8252F33B45}, // 1e274
{0x3D6A751F3B936243, 0xB8DA1662E7B00A17}, // 1e275
{0x0CC512670A783AD4, 0xE7109BFBA19C0C9D}, // 1e276
{0x27FB2B80668B24C5, 0x906A617D450187E2}, // 1e277
{0xB1F9F660802DEDF6, 0xB484F9DC9641E9DA}, // 1e278
{0x5E7873F8A0396973, 0xE1A63853BBD26451}, // 1e279
{0xDB0B487B6423E1E8, 0x8D07E33455637EB2}, // 1e280
{0x91CE1A9A3D2CDA62, 0xB049DC016ABC5E5F}, // 1e281
{0x7641A140CC7810FB, 0xDC5C5301C56B75F7}, // 1e282
{0xA9E904C87FCB0A9D, 0x89B9B3E11B6329BA}, // 1e283
{0x546345FA9FBDCD44, 0xAC2820D9623BF429}, // 1e284
{0xA97C177947AD4095, 0xD732290FBACAF133}, // 1e285
{0x49ED8EABCCCC485D, 0x867F59A9D4BED6C0}, // 1e286
{0x5C68F256BFFF5A74, 0xA81F301449EE8C70}, // 1e287
{0x73832EEC6FFF3111, 0xD226FC195C6A2F8C}, // 1e288
{0xC831FD53C5FF7EAB, 0x83585D8FD9C25DB7}, // 1e289
{0xBA3E7CA8B77F5E55, 0xA42E74F3D032F525}, // 1e290
{0x28CE1BD2E55F35EB, 0xCD3A1230C43FB26F}, // 1e291
{0x7980D163CF5B81B3, 0x80444B5E7AA7CF85}, // 1e292
{0xD7E105BCC332621F, 0xA0555E361951C366}, // 1e293
{0x8DD9472BF3FEFAA7, 0xC86AB5C39FA63440}, // 1e294
{0xB14F98F6F0FEB951, 0xFA856334878FC150}, // 1e295
{0x6ED1BF9A569F33D3, 0x9C935E00D4B9D8D2}, // 1e296
{0x0A862F80EC4700C8, 0xC3B8358109E84F07}, // 1e297
{0xCD27BB612758C0FA, 0xF4A642E14C6262C8}, // 1e298
{0x8038D51CB897789C, 0x98E7E9CCCFBD7DBD}, // 1e299
{0xE0470A63E6BD56C3, 0xBF21E44003ACDD2C}, // 1e300
{0x1858CCFCE06CAC74, 0xEEEA5D5004981478}, // 1e301
{0x0F37801E0C43EBC8, 0x95527A5202DF0CCB}, // 1e302
{0xD30560258F54E6BA, 0xBAA718E68396CFFD}, // 1e303
{0x47C6B82EF32A2069, 0xE950DF20247C83FD}, // 1e304
{0x4CDC331D57FA5441, 0x91D28B7416CDD27E}, // 1e305
{0xE0133FE4ADF8E952, 0xB6472E511C81471D}, // 1e306
{0x58180FDDD97723A6, 0xE3D8F9E563A198E5}, // 1e307
{0x570F09EAA7EA7648, 0x8E679C2F5E44FF8F}, // 1e308
{0x2CD2CC6551E513DA, 0xB201833B35D63F73}, // 1e309
{0xF8077F7EA65E58D1, 0xDE81E40A034BCF4F}, // 1e310
{0xFB04AFAF27FAF782, 0x8B112E86420F6191}, // 1e311
{0x79C5DB9AF1F9B563, 0xADD57A27D29339F6}, // 1e312
{0x18375281AE7822BC, 0xD94AD8B1C7380874}, // 1e313
{0x8F2293910D0B15B5, 0x87CEC76F1C830548}, // 1e314
{0xB2EB3875504DDB22, 0xA9C2794AE3A3C69A}, // 1e315
{0x5FA60692A46151EB, 0xD433179D9C8CB841}, // 1e316
{0xDBC7C41BA6BCD333, 0x849FEEC281D7F328}, // 1e317
{0x12B9B522906C0800, 0xA5C7EA73224DEFF3}, // 1e318
{0xD768226B34870A00, 0xCF39E50FEAE16BEF}, // 1e319
{0xE6A1158300D46640, 0x81842F29F2CCE375}, // 1e320
{0x60495AE3C1097FD0, 0xA1E53AF46F801C53}, // 1e321
{0x385BB19CB14BDFC4, 0xCA5E89B18B602368}, // 1e322
{0x46729E03DD9ED7B5, 0xFCF62C1DEE382C42}, // 1e323
{0x6C07A2C26A8346D1, 0x9E19DB92B4E31BA9}, // 1e324
{0xC7098B7305241885, 0xC5A05277621BE293}, // 1e325
{0xB8CBEE4FC66D1EA7, 0xF70867153AA2DB38}, // 1e326
{0x737F74F1DC043328, 0x9A65406D44A5C903}, // 1e327
{0x505F522E53053FF2, 0xC0FE908895CF3B44}, // 1e328
{0x647726B9E7C68FEF, 0xF13E34AABB430A15}, // 1e329
{0x5ECA783430DC19F5, 0x96C6E0EAB509E64D}, // 1e330
{0xB67D16413D132072, 0xBC789925624C5FE0}, // 1e331
{0xE41C5BD18C57E88F, 0xEB96BF6EBADF77D8}, // 1e332
{0x8E91B962F7B6F159, 0x933E37A534CBAAE7}, // 1e333
{0x723627BBB5A4ADB0, 0xB80DC58E81FE95A1}, // 1e334
{0xCEC3B1AAA30DD91C, 0xE61136F2227E3B09}, // 1e335
{0x213A4F0AA5E8A7B1, 0x8FCAC257558EE4E6}, // 1e336
{0xA988E2CD4F62D19D, 0xB3BD72ED2AF29E1F}, // 1e337
{0x93EB1B80A33B8605, 0xE0ACCFA875AF45A7}, // 1e338
{0xBC72F130660533C3, 0x8C6C01C9498D8B88}, // 1e339
{0xEB8FAD7C7F8680B4, 0xAF87023B9BF0EE6A}, // 1e340
{0xA67398DB9F6820E1, 0xDB68C2CA82ED2A05}, // 1e341
{0x88083F8943A1148C, 0x892179BE91D43A43}, // 1e342
{0x6A0A4F6B948959B0, 0xAB69D82E364948D4}, // 1e343
{0x848CE34679ABB01C, 0xD6444E39C3DB9B09}, // 1e344
{0xF2D80E0C0C0B4E11, 0x85EAB0E41A6940E5}, // 1e345
{0x6F8E118F0F0E2195, 0xA7655D1D2103911F}, // 1e346
{0x4B7195F2D2D1A9FB, 0xD13EB46469447567}, // 1e347
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Binary to decimal floating point conversion.
// Algorithm:
// 1) store mantissa in multiprecision decimal
// 2) shift decimal by exponent
// 3) read digits out & format
package strconv
import "math"
// TODO: move elsewhere?
type floatInfo struct {
mantbits uint
expbits uint
bias int
}
var float32info = floatInfo{23, 8, -127}
var float64info = floatInfo{52, 11, -1023}
// FormatFloat converts the floating-point number f to a string,
// according to the format fmt and precision prec. It rounds the
// result assuming that the original was obtained from a floating-point
// value of bitSize bits (32 for float32, 64 for float64).
//
// The format fmt is one of
// 'b' (-ddddp±ddd, a binary exponent),
// 'e' (-d.dddde±dd, a decimal exponent),
// 'E' (-d.ddddE±dd, a decimal exponent),
// 'f' (-ddd.dddd, no exponent),
// 'g' ('e' for large exponents, 'f' otherwise),
// 'G' ('E' for large exponents, 'f' otherwise),
// 'x' (-0xd.ddddp±ddd, a hexadecimal fraction and binary exponent), or
// 'X' (-0Xd.ddddP±ddd, a hexadecimal fraction and binary exponent).
//
// The precision prec controls the number of digits (excluding the exponent)
// printed by the 'e', 'E', 'f', 'g', 'G', 'x', and 'X' formats.
// For 'e', 'E', 'f', 'x', and 'X', it is the number of digits after the decimal point.
// For 'g' and 'G' it is the maximum number of significant digits (trailing
// zeros are removed).
// The special precision -1 uses the smallest number of digits
// necessary such that ParseFloat will return f exactly.
func FormatFloat(f float64, fmt byte, prec, bitSize int) string {
return string(genericFtoa(make([]byte, 0, max(prec+4, 24)), f, fmt, prec, bitSize))
}
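// For illustration (not part of the original source), how fmt and prec
// interact; the outputs follow from the rules documented above:
//
// FormatFloat(1234.5678, 'e', 3, 64) == "1.235e+03"
// FormatFloat(1234.5678, 'f', 2, 64) == "1234.57"
// FormatFloat(1234.5678, 'g', -1, 64) == "1234.5678" // shortest round-trip form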
// AppendFloat appends the string form of the floating-point number f,
// as generated by FormatFloat, to dst and returns the extended buffer.
func AppendFloat(dst []byte, f float64, fmt byte, prec, bitSize int) []byte {
return genericFtoa(dst, f, fmt, prec, bitSize)
}
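// A minimal usage sketch (illustrative): the append-style API lets one
// buffer be reused across calls without intermediate string allocations:
//
// buf := make([]byte, 0, 32)
// buf = AppendFloat(buf, 2.5, 'f', 1, 64) // buf == []byte("2.5")
// buf = append(buf, ',')
// buf = AppendFloat(buf, 3.75, 'f', 2, 64) // buf == []byte("2.5,3.75")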
func genericFtoa(dst []byte, val float64, fmt byte, prec, bitSize int) []byte {
var bits uint64
var flt *floatInfo
switch bitSize {
case 32:
bits = uint64(math.Float32bits(float32(val)))
flt = &float32info
case 64:
bits = math.Float64bits(val)
flt = &float64info
default:
panic("strconv: illegal AppendFloat/FormatFloat bitSize")
}
neg := bits>>(flt.expbits+flt.mantbits) != 0
exp := int(bits>>flt.mantbits) & (1<<flt.expbits - 1)
mant := bits & (uint64(1)<<flt.mantbits - 1)
switch exp {
case 1<<flt.expbits - 1:
// Inf, NaN
var s string
switch {
case mant != 0:
s = "NaN"
case neg:
s = "-Inf"
default:
s = "+Inf"
}
return append(dst, s...)
case 0:
// denormalized
exp++
default:
// add implicit top bit
mant |= uint64(1) << flt.mantbits
}
exp += flt.bias
// Pick off easy binary, hex formats.
if fmt == 'b' {
return fmtB(dst, neg, mant, exp, flt)
}
if fmt == 'x' || fmt == 'X' {
return fmtX(dst, prec, fmt, neg, mant, exp, flt)
}
if !optimize {
return bigFtoa(dst, prec, fmt, neg, mant, exp, flt)
}
var digs decimalSlice
ok := false
// Negative precision means "only as much as needed to be exact."
shortest := prec < 0
if shortest {
// Use Ryu algorithm.
var buf [32]byte
digs.d = buf[:]
ryuFtoaShortest(&digs, mant, exp-int(flt.mantbits), flt)
ok = true
// Precision for shortest representation mode.
switch fmt {
case 'e', 'E':
prec = max(digs.nd-1, 0)
case 'f':
prec = max(digs.nd-digs.dp, 0)
case 'g', 'G':
prec = digs.nd
}
} else if fmt != 'f' {
// Fixed number of digits.
digits := prec
switch fmt {
case 'e', 'E':
digits++
case 'g', 'G':
if prec == 0 {
prec = 1
}
digits = prec
default:
// Invalid mode.
digits = 1
}
var buf [24]byte
if bitSize == 32 && digits <= 9 {
digs.d = buf[:]
ryuFtoaFixed32(&digs, uint32(mant), exp-int(flt.mantbits), digits)
ok = true
} else if digits <= 18 {
digs.d = buf[:]
ryuFtoaFixed64(&digs, mant, exp-int(flt.mantbits), digits)
ok = true
}
}
if !ok {
return bigFtoa(dst, prec, fmt, neg, mant, exp, flt)
}
return formatDigits(dst, shortest, neg, digs, prec, fmt)
}
// bigFtoa uses multiprecision computations to format a float.
func bigFtoa(dst []byte, prec int, fmt byte, neg bool, mant uint64, exp int, flt *floatInfo) []byte {
d := new(decimal)
d.Assign(mant)
d.Shift(exp - int(flt.mantbits))
var digs decimalSlice
shortest := prec < 0
if shortest {
roundShortest(d, mant, exp, flt)
digs = decimalSlice{d: d.d[:], nd: d.nd, dp: d.dp}
// Precision for shortest representation mode.
switch fmt {
case 'e', 'E':
prec = digs.nd - 1
case 'f':
prec = max(digs.nd-digs.dp, 0)
case 'g', 'G':
prec = digs.nd
}
} else {
// Round appropriately.
switch fmt {
case 'e', 'E':
d.Round(prec + 1)
case 'f':
d.Round(d.dp + prec)
case 'g', 'G':
if prec == 0 {
prec = 1
}
d.Round(prec)
}
digs = decimalSlice{d: d.d[:], nd: d.nd, dp: d.dp}
}
return formatDigits(dst, shortest, neg, digs, prec, fmt)
}
func formatDigits(dst []byte, shortest bool, neg bool, digs decimalSlice, prec int, fmt byte) []byte {
switch fmt {
case 'e', 'E':
return fmtE(dst, neg, digs, prec, fmt)
case 'f':
return fmtF(dst, neg, digs, prec)
case 'g', 'G':
// trailing fractional zeros in 'e' form will be trimmed.
eprec := prec
if eprec > digs.nd && digs.nd >= digs.dp {
eprec = digs.nd
}
// %e is used if the exponent from the conversion
// is less than -4 or greater than or equal to the precision.
// if precision was the shortest possible, use precision 6 for this decision.
if shortest {
eprec = 6
}
exp := digs.dp - 1
if exp < -4 || exp >= eprec {
if prec > digs.nd {
prec = digs.nd
}
return fmtE(dst, neg, digs, prec-1, fmt+'e'-'g')
}
if prec > digs.dp {
prec = digs.nd
}
return fmtF(dst, neg, digs, max(prec-digs.dp, 0))
}
// unknown format
return append(dst, '%', fmt)
}
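// For illustration of the 'g'/'G' decision above: in shortest mode eprec is
// 6, and exp is the decimal exponent digs.dp - 1, so:
//
// FormatFloat(123456, 'g', -1, 64) == "123456" // exp = 5 < 6: %f form
// FormatFloat(1234567, 'g', -1, 64) == "1.234567e+06" // exp = 6 >= 6: %e form
// FormatFloat(0.00001234, 'g', -1, 64) == "1.234e-05" // exp = -5 < -4: %e form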
// roundShortest rounds d (= mant * 2^exp) to the shortest number of digits
// that will let the original floating point value be precisely reconstructed.
func roundShortest(d *decimal, mant uint64, exp int, flt *floatInfo) {
// If mantissa is zero, the number is zero; stop now.
if mant == 0 {
d.nd = 0
return
}
// Compute upper and lower such that any decimal number
// between upper and lower (possibly inclusive)
// will round to the original floating point number.
// We may see at once that the number is already shortest.
//
// Suppose d is not denormal, so that 2^exp <= d < 10^dp.
// The closest shorter number is at least 10^(dp-nd) away.
// The lower/upper bounds computed below are at distance
// at most 2^(exp-mantbits).
//
// So the number is already shortest if 10^(dp-nd) > 2^(exp-mantbits),
// or equivalently log2(10)*(dp-nd) > exp-mantbits.
// It is true if 332/100*(dp-nd) >= exp-mantbits (log2(10) > 3.32).
minexp := flt.bias + 1 // minimum possible exponent
if exp > minexp && 332*(d.dp-d.nd) >= 100*(exp-int(flt.mantbits)) {
// The number is already shortest.
return
}
// d = mant << (exp - mantbits)
// Next highest floating point number is mant+1 << exp-mantbits.
// Our upper bound is halfway between, mant*2+1 << exp-mantbits-1.
upper := new(decimal)
upper.Assign(mant*2 + 1)
upper.Shift(exp - int(flt.mantbits) - 1)
// d = mant << (exp - mantbits)
// Next lowest floating point number is mant-1 << exp-mantbits,
// unless mant-1 drops the significant bit and exp is not the minimum exp,
// in which case the next lowest is mant*2-1 << exp-mantbits-1.
// Either way, call it mantlo << explo-mantbits.
// Our lower bound is halfway between, mantlo*2+1 << explo-mantbits-1.
var mantlo uint64
var explo int
if mant > 1<<flt.mantbits || exp == minexp {
mantlo = mant - 1
explo = exp
} else {
mantlo = mant*2 - 1
explo = exp - 1
}
lower := new(decimal)
lower.Assign(mantlo*2 + 1)
lower.Shift(explo - int(flt.mantbits) - 1)
// The upper and lower bounds are possible outputs only if
// the original mantissa is even, so that IEEE round-to-even
// would round to the original mantissa and not the neighbors.
inclusive := mant%2 == 0
// As we walk the digits we want to know whether rounding up would fall
// within the upper bound. This is tracked by upperdelta:
//
// If upperdelta == 0, the digits of d and upper are the same so far.
//
// If upperdelta == 1, we saw a difference of 1 between d and upper on a
// previous digit and subsequently only 9s for d and 0s for upper.
// (Thus rounding up may fall outside the bound, if it is exclusive.)
//
// If upperdelta == 2, then the difference is greater than 1
// and we know that rounding up falls within the bound.
var upperdelta uint8
// Now we can figure out the minimum number of digits required.
// Walk along until d has distinguished itself from upper and lower.
for ui := 0; ; ui++ {
// lower, d, and upper may have the decimal points at different
// places. In this case upper is the longest, so we iterate from
// ui==0 and start li and mi at (possibly) -1.
mi := ui - upper.dp + d.dp
if mi >= d.nd {
break
}
li := ui - upper.dp + lower.dp
l := byte('0') // lower digit
if li >= 0 && li < lower.nd {
l = lower.d[li]
}
m := byte('0') // middle digit
if mi >= 0 {
m = d.d[mi]
}
u := byte('0') // upper digit
if ui < upper.nd {
u = upper.d[ui]
}
// Okay to round down (truncate) if lower has a different digit
// or if lower is inclusive and is exactly the result of rounding
// down (i.e., we have reached the final digit of lower).
okdown := l != m || inclusive && li+1 == lower.nd
switch {
case upperdelta == 0 && m+1 < u:
// Example:
// m = 12345xxx
// u = 12347xxx
upperdelta = 2
case upperdelta == 0 && m != u:
// Example:
// m = 12345xxx
// u = 12346xxx
upperdelta = 1
case upperdelta == 1 && (m != '9' || u != '0'):
// Example:
// m = 1234598x
// u = 1234600x
upperdelta = 2
}
// Okay to round up if upper has a different digit and either upper
// is inclusive or upper is bigger than the result of rounding up.
okup := upperdelta > 0 && (inclusive || upperdelta > 1 || ui+1 < upper.nd)
// If it's okay to do either, then round to the nearest one.
// If it's okay to do only one, do it.
switch {
case okdown && okup:
d.Round(mi + 1)
return
case okdown:
d.RoundDown(mi + 1)
return
case okup:
d.RoundUp(mi + 1)
return
}
}
}
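// Worked example (illustrative): the float64 nearest 0.1 is exactly
// 0.1000000000000000055511151231257827021181583404541015625. Every decimal
// strictly between its lower and upper midpoints reads back as this same
// float64, and the shortest such decimal is "0.1", so d is trimmed to the
// single digit 1 with d.dp == 0.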
type decimalSlice struct {
d []byte
nd, dp int
}
// %e: -d.ddddde±dd
func fmtE(dst []byte, neg bool, d decimalSlice, prec int, fmt byte) []byte {
// sign
if neg {
dst = append(dst, '-')
}
// first digit
ch := byte('0')
if d.nd != 0 {
ch = d.d[0]
}
dst = append(dst, ch)
// .moredigits
if prec > 0 {
dst = append(dst, '.')
i := 1
m := min(d.nd, prec+1)
if i < m {
dst = append(dst, d.d[i:m]...)
i = m
}
for ; i <= prec; i++ {
dst = append(dst, '0')
}
}
// e±
dst = append(dst, fmt)
exp := d.dp - 1
if d.nd == 0 { // special case: 0 has exponent 0
exp = 0
}
if exp < 0 {
ch = '-'
exp = -exp
} else {
ch = '+'
}
dst = append(dst, ch)
// dd or ddd
switch {
case exp < 10:
dst = append(dst, '0', byte(exp)+'0')
case exp < 100:
dst = append(dst, byte(exp/10)+'0', byte(exp%10)+'0')
default:
dst = append(dst, byte(exp/100)+'0', byte(exp/10)%10+'0', byte(exp%10)+'0')
}
return dst
}
// %f: -ddddddd.ddddd
func fmtF(dst []byte, neg bool, d decimalSlice, prec int) []byte {
// sign
if neg {
dst = append(dst, '-')
}
// integer, padded with zeros as needed.
if d.dp > 0 {
m := min(d.nd, d.dp)
dst = append(dst, d.d[:m]...)
for ; m < d.dp; m++ {
dst = append(dst, '0')
}
} else {
dst = append(dst, '0')
}
// fraction
if prec > 0 {
dst = append(dst, '.')
for i := 0; i < prec; i++ {
ch := byte('0')
if j := d.dp + i; 0 <= j && j < d.nd {
ch = d.d[j]
}
dst = append(dst, ch)
}
}
return dst
}
// %b: -ddddddddp±ddd
func fmtB(dst []byte, neg bool, mant uint64, exp int, flt *floatInfo) []byte {
// sign
if neg {
dst = append(dst, '-')
}
// mantissa
dst, _ = formatBits(dst, mant, 10, false, true)
// p
dst = append(dst, 'p')
// ±exponent
exp -= int(flt.mantbits)
if exp >= 0 {
dst = append(dst, '+')
}
dst, _ = formatBits(dst, uint64(exp), 10, exp < 0, true)
return dst
}
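// For illustration: 0.5 is stored as mant = 1<<52 = 4503599627370496 with
// exp - mantbits = -53, so FormatFloat(0.5, 'b', -1, 64) == "4503599627370496p-53".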
// %x: -0x1.yyyyyyyyp±ddd or -0x0p+0. (y is hex digit, d is decimal digit)
func fmtX(dst []byte, prec int, fmt byte, neg bool, mant uint64, exp int, flt *floatInfo) []byte {
if mant == 0 {
exp = 0
}
// Shift digits so leading 1 (if any) is at bit 1<<60.
mant <<= 60 - flt.mantbits
for mant != 0 && mant&(1<<60) == 0 {
mant <<= 1
exp--
}
// Round if requested.
if prec >= 0 && prec < 15 {
shift := uint(prec * 4)
extra := (mant << shift) & (1<<60 - 1)
mant >>= 60 - shift
if extra|(mant&1) > 1<<59 {
mant++
}
mant <<= 60 - shift
if mant&(1<<61) != 0 {
// Wrapped around.
mant >>= 1
exp++
}
}
hex := lowerhex
if fmt == 'X' {
hex = upperhex
}
// sign, 0x, leading digit
if neg {
dst = append(dst, '-')
}
dst = append(dst, '0', fmt, '0'+byte((mant>>60)&1))
// .fraction
mant <<= 4 // remove leading 0 or 1
if prec < 0 && mant != 0 {
dst = append(dst, '.')
for mant != 0 {
dst = append(dst, hex[(mant>>60)&15])
mant <<= 4
}
} else if prec > 0 {
dst = append(dst, '.')
for i := 0; i < prec; i++ {
dst = append(dst, hex[(mant>>60)&15])
mant <<= 4
}
}
// p±
ch := byte('P')
if fmt == lower(fmt) {
ch = 'p'
}
dst = append(dst, ch)
if exp < 0 {
ch = '-'
exp = -exp
} else {
ch = '+'
}
dst = append(dst, ch)
// dd or ddd or dddd
switch {
case exp < 100:
dst = append(dst, byte(exp/10)+'0', byte(exp%10)+'0')
case exp < 1000:
dst = append(dst, byte(exp/100)+'0', byte((exp/10)%10)+'0', byte(exp%10)+'0')
default:
dst = append(dst, byte(exp/1000)+'0', byte(exp/100)%10+'0', byte((exp/10)%10)+'0', byte(exp%10)+'0')
}
return dst
}
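// Example outputs (illustrative):
//
// FormatFloat(1.0, 'x', -1, 64) == "0x1p+00"
// FormatFloat(0.5, 'x', -1, 64) == "0x1p-01"
// FormatFloat(0.1, 'x', 3, 64) == "0x1.99ap-04" // 0x1.999999999999ap-04 rounded to 3 hex digits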
func min(a, b int) int {
if a < b {
return a
}
return b
}
func max(a, b int) int {
if a > b {
return a
}
return b
}
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package strconv
import (
"math/bits"
)
// binary to decimal conversion using the Ryū algorithm.
//
// See Ulf Adams, "Ryū: Fast Float-to-String Conversion" (doi:10.1145/3192366.3192369)
//
// Fixed precision formatting is a variant of the original paper's
// algorithm, where a single multiplication by 10^k is required,
// sharing the same rounding guarantees.
// ryuFtoaFixed32 formats mant*(2^exp) with prec decimal digits.
func ryuFtoaFixed32(d *decimalSlice, mant uint32, exp int, prec int) {
if prec < 0 {
panic("ryuFtoaFixed32 called with negative prec")
}
if prec > 9 {
panic("ryuFtoaFixed32 called with prec > 9")
}
// Zero input.
if mant == 0 {
d.nd, d.dp = 0, 0
return
}
// Renormalize to a 25-bit mantissa.
e2 := exp
if b := bits.Len32(mant); b < 25 {
mant <<= uint(25 - b)
e2 += b - 25
}
// Choose an exponent such that rounded mant*(2^e2)*(10^q) has
// at least prec decimal digits, i.e.
// mant*(2^e2)*(10^q) >= 10^(prec-1)
// Because mant >= 2^24, it is enough to choose:
// 2^(e2+24) >= 10^(-q+prec-1)
// or q = -mulByLog2Log10(e2+24) + prec - 1
q := -mulByLog2Log10(e2+24) + prec - 1
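// For example (illustrative): with e2 = -48 and prec = 7,
// mulByLog2Log10(-24) = -8, so q = 8 + 7 - 1 = 14, and indeed
// mant*(2^-48)*(10^14) >= 2^24 * 2^-48 * 10^14 > 10^6 = 10^(prec-1).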
// Now compute mant*(2^e2)*(10^q).
// Is it an exact computation?
// Only small positive powers of 10 are exact (5^28 has 66 bits).
exact := q <= 27 && q >= 0
di, dexp2, d0 := mult64bitPow10(mant, e2, q)
if dexp2 >= 0 {
panic("not enough significant bits after mult64bitPow10")
}
// As a special case, computation might still be exact, if exponent
// was negative and if it amounts to computing an exact division.
// In that case, we ignore all lower bits.
// Note that division by 10^11 cannot be exact as 5^11 has 26 bits.
if q < 0 && q >= -10 && divisibleByPower5(uint64(mant), -q) {
exact = true
d0 = true
}
// Remove extra lower bits and keep rounding info.
extra := uint(-dexp2)
extraMask := uint32(1<<extra - 1)
di, dfrac := di>>extra, di&extraMask
roundUp := false
if exact {
// If we computed an exact product, d + 1/2
// should round to d+1 if 'd' is odd.
roundUp = dfrac > 1<<(extra-1) ||
(dfrac == 1<<(extra-1) && !d0) ||
(dfrac == 1<<(extra-1) && d0 && di&1 == 1)
} else {
// otherwise, d+1/2 always rounds up because
// we truncated below.
roundUp = dfrac>>(extra-1) == 1
}
if dfrac != 0 {
d0 = false
}
// Proceed to the requested number of digits
formatDecimal(d, uint64(di), !d0, roundUp, prec)
// Adjust exponent
d.dp -= q
}
// ryuFtoaFixed64 formats mant*(2^exp) with prec decimal digits.
func ryuFtoaFixed64(d *decimalSlice, mant uint64, exp int, prec int) {
if prec > 18 {
panic("ryuFtoaFixed64 called with prec > 18")
}
// Zero input.
if mant == 0 {
d.nd, d.dp = 0, 0
return
}
// Renormalize to a 55-bit mantissa.
e2 := exp
if b := bits.Len64(mant); b < 55 {
mant = mant << uint(55-b)
e2 += b - 55
}
// Choose an exponent such that rounded mant*(2^e2)*(10^q) has
// at least prec decimal digits, i.e.
// mant*(2^e2)*(10^q) >= 10^(prec-1)
// Because mant >= 2^54, it is enough to choose:
// 2^(e2+54) >= 10^(-q+prec-1)
// or q = -mulByLog2Log10(e2+54) + prec - 1
//
// The minimal required exponent is -mulByLog2Log10(1025)+18 = -291
// The maximal required exponent is mulByLog2Log10(1074)+18 = 342
q := -mulByLog2Log10(e2+54) + prec - 1
// Now compute mant*(2^e2)*(10^q).
// Is it an exact computation?
// Only small positive powers of 10 are exact (5^55 has 128 bits).
exact := q <= 55 && q >= 0
di, dexp2, d0 := mult128bitPow10(mant, e2, q)
if dexp2 >= 0 {
panic("not enough significant bits after mult128bitPow10")
}
// As a special case, computation might still be exact, if exponent
// was negative and if it amounts to computing an exact division.
// In that case, we ignore all lower bits.
// Note that division by 10^23 cannot be exact as 5^23 has 54 bits.
if q < 0 && q >= -22 && divisibleByPower5(mant, -q) {
exact = true
d0 = true
}
// Remove extra lower bits and keep rounding info.
extra := uint(-dexp2)
extraMask := uint64(1<<extra - 1)
di, dfrac := di>>extra, di&extraMask
roundUp := false
if exact {
// If we computed an exact product, d + 1/2
// should round to d+1 if 'd' is odd.
roundUp = dfrac > 1<<(extra-1) ||
(dfrac == 1<<(extra-1) && !d0) ||
(dfrac == 1<<(extra-1) && d0 && di&1 == 1)
} else {
// otherwise, d+1/2 always rounds up because
// we truncated below.
roundUp = dfrac>>(extra-1) == 1
}
if dfrac != 0 {
d0 = false
}
// Proceed to the requested number of digits
formatDecimal(d, di, !d0, roundUp, prec)
// Adjust exponent
d.dp -= q
}
var uint64pow10 = [...]uint64{
1, 1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8, 1e9,
1e10, 1e11, 1e12, 1e13, 1e14, 1e15, 1e16, 1e17, 1e18, 1e19,
}
// formatDecimal fills d with at most prec decimal digits
// of mantissa m. The boolean trunc indicates whether m
// is truncated compared to the original number being formatted.
func formatDecimal(d *decimalSlice, m uint64, trunc bool, roundUp bool, prec int) {
max := uint64pow10[prec]
trimmed := 0
for m >= max {
a, b := m/10, m%10
m = a
trimmed++
if b > 5 {
roundUp = true
} else if b < 5 {
roundUp = false
} else { // b == 5
// round up if there are trailing digits,
// or if the new value of m is odd (round-to-even convention)
roundUp = trunc || m&1 == 1
}
if b != 0 {
trunc = true
}
}
if roundUp {
m++
}
if m >= max {
// Happens if di was originally 99999....xx
m /= 10
trimmed++
}
// render digits (similar to formatBits)
n := uint(prec)
d.nd = prec
v := m
for v >= 100 {
var v1, v2 uint64
if v>>32 == 0 {
v1, v2 = uint64(uint32(v)/100), uint64(uint32(v)%100)
} else {
v1, v2 = v/100, v%100
}
n -= 2
d.d[n+1] = smallsString[2*v2+1]
d.d[n+0] = smallsString[2*v2+0]
v = v1
}
if v > 0 {
n--
d.d[n] = smallsString[2*v+1]
}
if v >= 10 {
n--
d.d[n] = smallsString[2*v]
}
for d.d[d.nd-1] == '0' {
d.nd--
trimmed++
}
d.dp = d.nd + trimmed
}
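// Worked example (illustrative): formatDecimal(d, 12345, false, false, 3)
// trims two digits (first b = 5 with the new m = 1234 even, then b = 4),
// leaving m = 123 with roundUp = false; d then holds digits "123" with
// d.dp = 5, i.e. 1.23e4, the round-half-to-even 3-digit rounding of 12345.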
// ryuFtoaShortest formats mant*2^exp with prec decimal digits.
func ryuFtoaShortest(d *decimalSlice, mant uint64, exp int, flt *floatInfo) {
if mant == 0 {
d.nd, d.dp = 0, 0
return
}
// If input is an exact integer with fewer bits than the mantissa,
// the previous and next integer are not admissible representations.
if exp <= 0 && bits.TrailingZeros64(mant) >= -exp {
mant >>= uint(-exp)
ryuDigits(d, mant, mant, mant, true, false)
return
}
ml, mc, mu, e2 := computeBounds(mant, exp, flt)
if e2 == 0 {
ryuDigits(d, ml, mc, mu, true, false)
return
}
// Find 10^q *larger* than 2^-e2
q := mulByLog2Log10(-e2) + 1
// We are going to multiply by 10^q using 128-bit arithmetic.
// The exponent is the same for all 3 numbers.
var dl, dc, du uint64
var dl0, dc0, du0 bool
if flt == &float32info {
var dl32, dc32, du32 uint32
dl32, _, dl0 = mult64bitPow10(uint32(ml), e2, q)
dc32, _, dc0 = mult64bitPow10(uint32(mc), e2, q)
du32, e2, du0 = mult64bitPow10(uint32(mu), e2, q)
dl, dc, du = uint64(dl32), uint64(dc32), uint64(du32)
} else {
dl, _, dl0 = mult128bitPow10(ml, e2, q)
dc, _, dc0 = mult128bitPow10(mc, e2, q)
du, e2, du0 = mult128bitPow10(mu, e2, q)
}
if e2 >= 0 {
panic("not enough significant bits after mult128bitPow10")
}
// Is it an exact computation?
if q > 55 {
// Large positive powers of ten are not exact
dl0, dc0, du0 = false, false, false
}
if q < 0 && q >= -24 {
// Division by a power of ten may be exact.
// (note that 5^25 is a 59-bit number so division by 5^25 is never exact).
if divisibleByPower5(ml, -q) {
dl0 = true
}
if divisibleByPower5(mc, -q) {
dc0 = true
}
if divisibleByPower5(mu, -q) {
du0 = true
}
}
// Express the results (dl, dc, du)*2^e2 as integers.
// Extra bits must be removed and rounding hints computed.
extra := uint(-e2)
extraMask := uint64(1<<extra - 1)
// Now compute the floored, integral base 10 mantissas.
dl, fracl := dl>>extra, dl&extraMask
dc, fracc := dc>>extra, dc&extraMask
du, fracu := du>>extra, du&extraMask
// Is it allowed to use 'du' as a result?
// It is always allowed when it is truncated, but also
// if it is exact and the original binary mantissa is even.
// When disallowed, we can subtract 1.
uok := !du0 || fracu > 0
if du0 && fracu == 0 {
uok = mant&1 == 0
}
if !uok {
du--
}
// Is 'dc' the correctly rounded base 10 mantissa?
// The correct rounding might be dc+1
cup := false // don't round up.
if dc0 {
// If we computed an exact product, the half integer
// should round to next (even) integer if 'dc' is odd.
cup = fracc > 1<<(extra-1) ||
(fracc == 1<<(extra-1) && dc&1 == 1)
} else {
// otherwise, the result is a lower truncation of the ideal
// result.
cup = fracc>>(extra-1) == 1
}
// Is 'dl' an allowed representation?
// Only if it is an exact value, and if the original binary mantissa
// was even.
lok := dl0 && fracl == 0 && (mant&1 == 0)
if !lok {
dl++
}
// We need to remember whether the trimmed digits of 'dc' are zero.
c0 := dc0 && fracc == 0
// render digits
ryuDigits(d, dl, dc, du, c0, cup)
d.dp -= q
}
// mulByLog2Log10 returns math.Floor(x * log(2)/log(10)) for an integer x in
// the range -1600 <= x && x <= +1600.
//
// The range restriction lets us work in faster integer arithmetic instead of
// slower floating point arithmetic. Correctness is verified by unit tests.
func mulByLog2Log10(x int) int {
// log(2)/log(10) ≈ 0.30102999566 ≈ 78913 / 2^18
return (x * 78913) >> 18
}
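// Sanity check (illustrative): (100*78913)>>18 = 7891300>>18 = 30
// = floor(100*log10(2)), and for negative x the arithmetic shift still
// floors: (-100*78913)>>18 = -31 = floor(-30.103...).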
// mulByLog10Log2 returns math.Floor(x * log(10)/log(2)) for an integer x in
// the range -500 <= x && x <= +500.
//
// The range restriction lets us work in faster integer arithmetic instead of
// slower floating point arithmetic. Correctness is verified by unit tests.
func mulByLog10Log2(x int) int {
// log(10)/log(2) ≈ 3.32192809489 ≈ 108853 / 2^15
return (x * 108853) >> 15
}
// computeBounds returns a floating-point vector (l, c, u)×2^e2
// where the mantissas are 55-bit (or 26-bit) integers, describing the interval
// represented by the input float64 or float32.
func computeBounds(mant uint64, exp int, flt *floatInfo) (lower, central, upper uint64, e2 int) {
if mant != 1<<flt.mantbits || exp == flt.bias+1-int(flt.mantbits) {
// regular case (or denormals)
lower, central, upper = 2*mant-1, 2*mant, 2*mant+1
e2 = exp - 1
return
}
// border of an exponent
lower, central, upper = 4*mant-1, 4*mant, 4*mant+2
e2 = exp - 2
return
}
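// For example (illustrative): 1.0 arrives as mant = 1<<52 and exp = -52,
// hitting the exponent-border case, so (lower, central, upper) =
// (2^54-1, 2^54, 2^54+2) with e2 = -54. This reflects that the rounding
// interval below 1.0 (half-width 2^-54) is half as wide as the one above
// it (half-width 2^-53).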
func ryuDigits(d *decimalSlice, lower, central, upper uint64,
c0, cup bool) {
lhi, llo := divmod1e9(lower)
chi, clo := divmod1e9(central)
uhi, ulo := divmod1e9(upper)
if uhi == 0 {
// only low digits (for denormals)
ryuDigits32(d, llo, clo, ulo, c0, cup, 8)
} else if lhi < uhi {
// truncate 9 digits at once.
if llo != 0 {
lhi++
}
c0 = c0 && clo == 0
cup = (clo > 5e8) || (clo == 5e8 && cup)
ryuDigits32(d, lhi, chi, uhi, c0, cup, 8)
d.dp += 9
} else {
d.nd = 0
// emit high part
n := uint(9)
for v := chi; v > 0; {
v1, v2 := v/10, v%10
v = v1
n--
d.d[n] = byte(v2 + '0')
}
d.d = d.d[n:]
d.nd = int(9 - n)
// emit low part
ryuDigits32(d, llo, clo, ulo,
c0, cup, d.nd+8)
}
// trim trailing zeros
for d.nd > 0 && d.d[d.nd-1] == '0' {
d.nd--
}
// trim initial zeros
for d.nd > 0 && d.d[0] == '0' {
d.nd--
d.dp--
d.d = d.d[1:]
}
}
// ryuDigits32 emits decimal digits for a number less than 1e9.
func ryuDigits32(d *decimalSlice, lower, central, upper uint32,
c0, cup bool, endindex int) {
if upper == 0 {
d.dp = endindex + 1
return
}
trimmed := 0
// Remember last trimmed digit to check for round-up.
// c0 will be used to remember zeroness of following digits.
cNextDigit := 0
for upper > 0 {
// Repeatedly compute:
// l = Ceil(lower / 10^k)
// c = Round(central / 10^k)
// u = Floor(upper / 10^k)
// and stop when c goes out of the (l, u) interval.
l := (lower + 9) / 10
c, cdigit := central/10, central%10
u := upper / 10
if l > u {
// don't trim this digit: doing so would leave no integer in the
// (l, u) interval. Stop trimming and exit now.
break
}
// Check that we didn't cross the lower boundary.
// The case where l < u but c == l-1 is extremely rare,
// but may happen if:
// lower = ..11
// central = ..19
// upper = ..31
// and means that 'central' is very close but less than
// an integer ending with many zeros, and usually
// the "round-up" logic hides the problem.
if l == c+1 && c < u {
c++
cdigit = 0
cup = false
}
trimmed++
// Remember trimmed digits of c
c0 = c0 && cNextDigit == 0
cNextDigit = int(cdigit)
lower, central, upper = l, c, u
}
// should we round up?
if trimmed > 0 {
cup = cNextDigit > 5 ||
(cNextDigit == 5 && !c0) ||
(cNextDigit == 5 && c0 && central&1 == 1)
}
if central < upper && cup {
central++
}
// We know where the number ends, fill directly
endindex -= trimmed
v := central
n := endindex
for n > d.nd {
v1, v2 := v/100, v%100
d.d[n] = smallsString[2*v2+1]
d.d[n-1] = smallsString[2*v2+0]
n -= 2
v = v1
}
if n == d.nd {
d.d[n] = byte(v + '0')
}
d.nd = endindex + 1
d.dp = d.nd + trimmed
}
// mult64bitPow10 takes a floating-point input with a 25-bit
// mantissa and multiplies it with 10^q. The resulting mantissa
// is m*P >> 57 where P is a 64-bit element of the detailedPowersOfTen tables.
// It is typically 31 or 32 bits wide.
// The returned boolean is true if all trimmed bits were zero.
//
// That is:
//
// m*2^e2 * round(10^q) = resM * 2^resE + ε
// exact = ε == 0
func mult64bitPow10(m uint32, e2, q int) (resM uint32, resE int, exact bool) {
if q == 0 {
// P == 1<<63
return m << 6, e2 - 6, true
}
if q < detailedPowersOfTenMinExp10 || detailedPowersOfTenMaxExp10 < q {
// This never happens due to the range of float32/float64 exponent
panic("mult64bitPow10: power of 10 is out of range")
}
pow := detailedPowersOfTen[q-detailedPowersOfTenMinExp10][1]
if q < 0 {
// Inverse powers of ten must be rounded up.
pow += 1
}
hi, lo := bits.Mul64(uint64(m), pow)
e2 += mulByLog10Log2(q) - 63 + 57
return uint32(hi<<7 | lo>>57), e2, lo<<7 == 0
}
// mult128bitPow10 takes a floating-point input with a 55-bit
// mantissa and multiplies it with 10^q. The resulting mantissa
// is m*P >> 119 where P is a 128-bit element of the detailedPowersOfTen tables.
// It is typically 63 or 64 bits wide.
// The returned boolean is true if all trimmed bits were zero.
//
// That is:
//
// m*2^e2 * round(10^q) = resM * 2^resE + ε
// exact = ε == 0
func mult128bitPow10(m uint64, e2, q int) (resM uint64, resE int, exact bool) {
if q == 0 {
// P == 1<<127
return m << 8, e2 - 8, true
}
if q < detailedPowersOfTenMinExp10 || detailedPowersOfTenMaxExp10 < q {
// This never happens due to the range of float32/float64 exponent
panic("mult128bitPow10: power of 10 is out of range")
}
pow := detailedPowersOfTen[q-detailedPowersOfTenMinExp10]
if q < 0 {
// Inverse powers of ten must be rounded up.
pow[0] += 1
}
e2 += mulByLog10Log2(q) - 127 + 119
// long multiplication
l1, l0 := bits.Mul64(m, pow[0])
h1, h0 := bits.Mul64(m, pow[1])
mid, carry := bits.Add64(l1, h0, 0)
h1 += carry
return h1<<9 | mid>>55, e2, mid<<9 == 0 && l0 == 0
}
func divisibleByPower5(m uint64, k int) bool {
if m == 0 {
return true
}
for i := 0; i < k; i++ {
if m%5 != 0 {
return false
}
m /= 5
}
return true
}
// divmod1e9 computes quotient and remainder of division by 1e9,
// avoiding runtime uint64 division on 32-bit platforms.
func divmod1e9(x uint64) (uint32, uint32) {
if !host32bit {
return uint32(x / 1e9), uint32(x % 1e9)
}
// Use the same sequence of operations as the amd64 compiler.
hi, _ := bits.Mul64(x>>1, 0x89705f4136b4a598) // binary digits of 1e-9
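// (Illustrative note: 0x89705f4136b4a598 = ceil(2^93 / 1e9), so q below is
// floor((x>>1) * ceil(2^93/1e9) / 2^92), the usual magic-number division;
// as with the compiler's generated code, this equals x/1e9 for all uint64 x.)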
q := hi >> 28
return uint32(q), uint32(x - q*1e9)
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package strconv
import "math/bits"
const fastSmalls = true // enable fast path for small integers
// FormatUint returns the string representation of i in the given base,
// for 2 <= base <= 36. The result uses the lower-case letters 'a' to 'z'
// for digit values >= 10.
func FormatUint(i uint64, base int) string {
if fastSmalls && i < nSmalls && base == 10 {
return small(int(i))
}
_, s := formatBits(nil, i, base, false, false)
return s
}
// FormatInt returns the string representation of i in the given base,
// for 2 <= base <= 36. The result uses the lower-case letters 'a' to 'z'
// for digit values >= 10.
func FormatInt(i int64, base int) string {
if fastSmalls && 0 <= i && i < nSmalls && base == 10 {
return small(int(i))
}
_, s := formatBits(nil, uint64(i), base, i < 0, false)
return s
}
// Itoa is equivalent to FormatInt(int64(i), 10).
func Itoa(i int) string {
return FormatInt(int64(i), 10)
}
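// For illustration:
//
// Itoa(123) == "123"
// FormatInt(-42, 16) == "-2a"
// FormatUint(255, 2) == "11111111"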
// AppendInt appends the string form of the integer i,
// as generated by FormatInt, to dst and returns the extended buffer.
func AppendInt(dst []byte, i int64, base int) []byte {
if fastSmalls && 0 <= i && i < nSmalls && base == 10 {
return append(dst, small(int(i))...)
}
dst, _ = formatBits(dst, uint64(i), base, i < 0, true)
return dst
}
// AppendUint appends the string form of the unsigned integer i,
// as generated by FormatUint, to dst and returns the extended buffer.
func AppendUint(dst []byte, i uint64, base int) []byte {
if fastSmalls && i < nSmalls && base == 10 {
return append(dst, small(int(i))...)
}
dst, _ = formatBits(dst, i, base, false, true)
return dst
}
// small returns the string for an i with 0 <= i < nSmalls.
func small(i int) string {
if i < 10 {
return digits[i : i+1]
}
return smallsString[i*2 : i*2+2]
}
const nSmalls = 100
const smallsString = "00010203040506070809" +
"10111213141516171819" +
"20212223242526272829" +
"30313233343536373839" +
"40414243444546474849" +
"50515253545556575859" +
"60616263646566676869" +
"70717273747576777879" +
"80818283848586878889" +
"90919293949596979899"
const host32bit = ^uint(0)>>32 == 0
const digits = "0123456789abcdefghijklmnopqrstuvwxyz"
// formatBits computes the string representation of u in the given base.
// If neg is set, u is treated as a negative int64 value. If append_ is
// set, the string is appended to dst and the resulting byte slice is
// returned as the first result value; otherwise the string is returned
// as the second result value.
func formatBits(dst []byte, u uint64, base int, neg, append_ bool) (d []byte, s string) {
if base < 2 || base > len(digits) {
panic("strconv: illegal AppendInt/FormatInt base")
}
// 2 <= base && base <= len(digits)
var a [64 + 1]byte // +1 for sign of 64bit value in base 2
i := len(a)
if neg {
u = -u
}
// convert bits
// We use uint values where we can because those will
// fit into a single register even on a 32bit machine.
if base == 10 {
// common case: use constants for / because
// the compiler can optimize it into a multiply+shift
if host32bit {
// convert the lower digits using 32bit operations
for u >= 1e9 {
// Avoid using r = a%b in addition to q = a/b
// since 64bit division and modulo operations
// are calculated by runtime functions on 32bit machines.
q := u / 1e9
us := uint(u - q*1e9) // u % 1e9 fits into a uint
for j := 4; j > 0; j-- {
is := us % 100 * 2
us /= 100
i -= 2
a[i+1] = smallsString[is+1]
a[i+0] = smallsString[is+0]
}
// us < 10, since it contains the last digit
// from the initial 9-digit us.
i--
a[i] = smallsString[us*2+1]
u = q
}
// u < 1e9
}
// u guaranteed to fit into a uint
us := uint(u)
for us >= 100 {
is := us % 100 * 2
us /= 100
i -= 2
a[i+1] = smallsString[is+1]
a[i+0] = smallsString[is+0]
}
// us < 100
is := us * 2
i--
a[i] = smallsString[is+1]
if us >= 10 {
i--
a[i] = smallsString[is]
}
} else if isPowerOfTwo(base) {
// Use shifts and masks instead of / and %.
// Base is a power of 2 and 2 <= base <= len(digits) where len(digits) is 36.
// The largest power of 2 below or equal to 36 is 32, which is 1 << 5;
// i.e., the largest possible shift count is 5. By &-ing that value with
// the constant 7 we tell the compiler that the shift count is always
// less than 8 which is smaller than any register width. This allows
// the compiler to generate better code for the shift operation.
shift := uint(bits.TrailingZeros(uint(base))) & 7
b := uint64(base)
m := uint(base) - 1 // == 1<<shift - 1
for u >= b {
i--
a[i] = digits[uint(u)&m]
u >>= shift
}
// u < base
i--
a[i] = digits[uint(u)]
} else {
// general case
b := uint64(base)
for u >= b {
i--
// Avoid using r = a%b in addition to q = a/b
// since 64bit division and modulo operations
// are calculated by runtime functions on 32bit machines.
q := u / b
a[i] = digits[uint(u-q*b)]
u = q
}
// u < base
i--
a[i] = digits[uint(u)]
}
// add sign, if any
if neg {
i--
a[i] = '-'
}
if append_ {
d = append(dst, a[i:]...)
return
}
s = string(a[i:])
return
}
func isPowerOfTwo(x int) bool {
return x&(x-1) == 0
}
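// (Note: x&(x-1) == 0 also holds for x == 0, but formatBits has already
// ruled that out by panicking unless 2 <= base <= len(digits).)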
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:generate go run makeisprint.go -output isprint.go
package strconv
import (
"unicode/utf8"
)
const (
lowerhex = "0123456789abcdef"
upperhex = "0123456789ABCDEF"
)
// contains reports whether the string contains the byte c.
func contains(s string, c byte) bool {
return index(s, c) != -1
}
func quoteWith(s string, quote byte, ASCIIonly, graphicOnly bool) string {
return string(appendQuotedWith(make([]byte, 0, 3*len(s)/2), s, quote, ASCIIonly, graphicOnly))
}
func quoteRuneWith(r rune, quote byte, ASCIIonly, graphicOnly bool) string {
return string(appendQuotedRuneWith(nil, r, quote, ASCIIonly, graphicOnly))
}
func appendQuotedWith(buf []byte, s string, quote byte, ASCIIonly, graphicOnly bool) []byte {
// Often called with big strings, so preallocate. If there's quoting,
// this is conservative but still helps a lot.
if cap(buf)-len(buf) < len(s) {
nBuf := make([]byte, len(buf), len(buf)+1+len(s)+1)
copy(nBuf, buf)
buf = nBuf
}
buf = append(buf, quote)
for width := 0; len(s) > 0; s = s[width:] {
r := rune(s[0])
width = 1
if r >= utf8.RuneSelf {
r, width = utf8.DecodeRuneInString(s)
}
if width == 1 && r == utf8.RuneError {
buf = append(buf, `\x`...)
buf = append(buf, lowerhex[s[0]>>4])
buf = append(buf, lowerhex[s[0]&0xF])
continue
}
buf = appendEscapedRune(buf, r, quote, ASCIIonly, graphicOnly)
}
buf = append(buf, quote)
return buf
}
func appendQuotedRuneWith(buf []byte, r rune, quote byte, ASCIIonly, graphicOnly bool) []byte {
buf = append(buf, quote)
if !utf8.ValidRune(r) {
r = utf8.RuneError
}
buf = appendEscapedRune(buf, r, quote, ASCIIonly, graphicOnly)
buf = append(buf, quote)
return buf
}
func appendEscapedRune(buf []byte, r rune, quote byte, ASCIIonly, graphicOnly bool) []byte {
var runeTmp [utf8.UTFMax]byte
if r == rune(quote) || r == '\\' { // always backslashed
buf = append(buf, '\\')
buf = append(buf, byte(r))
return buf
}
if ASCIIonly {
if r < utf8.RuneSelf && IsPrint(r) {
buf = append(buf, byte(r))
return buf
}
} else if IsPrint(r) || graphicOnly && isInGraphicList(r) {
n := utf8.EncodeRune(runeTmp[:], r)
buf = append(buf, runeTmp[:n]...)
return buf
}
switch r {
case '\a':
buf = append(buf, `\a`...)
case '\b':
buf = append(buf, `\b`...)
case '\f':
buf = append(buf, `\f`...)
case '\n':
buf = append(buf, `\n`...)
case '\r':
buf = append(buf, `\r`...)
case '\t':
buf = append(buf, `\t`...)
case '\v':
buf = append(buf, `\v`...)
default:
switch {
case r < ' ' || r == 0x7f:
buf = append(buf, `\x`...)
buf = append(buf, lowerhex[byte(r)>>4])
buf = append(buf, lowerhex[byte(r)&0xF])
case !utf8.ValidRune(r):
r = 0xFFFD
fallthrough
case r < 0x10000:
buf = append(buf, `\u`...)
for s := 12; s >= 0; s -= 4 {
buf = append(buf, lowerhex[r>>uint(s)&0xF])
}
default:
buf = append(buf, `\U`...)
for s := 28; s >= 0; s -= 4 {
buf = append(buf, lowerhex[r>>uint(s)&0xF])
}
}
}
return buf
}
// Quote returns a double-quoted Go string literal representing s. The
// returned string uses Go escape sequences (\t, \n, \xFF, \u0100) for
// control characters and non-printable characters as defined by
// IsPrint.
func Quote(s string) string {
return quoteWith(s, '"', false, false)
}
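// A minimal usage sketch, assuming a client that imports this package
// as strconv:
//
//	s := strconv.Quote("Hello\tGo")
//	// s == `"Hello\tGo"`: the tab is escaped as \t and the result
//	// is wrapped in double quotes.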
// AppendQuote appends a double-quoted Go string literal representing s,
// as generated by Quote, to dst and returns the extended buffer.
func AppendQuote(dst []byte, s string) []byte {
return appendQuotedWith(dst, s, '"', false, false)
}
// QuoteToASCII returns a double-quoted Go string literal representing s.
// The returned string uses Go escape sequences (\t, \n, \xFF, \u0100) for
// non-ASCII characters and non-printable characters as defined by IsPrint.
func QuoteToASCII(s string) string {
return quoteWith(s, '"', true, false)
}
// AppendQuoteToASCII appends a double-quoted Go string literal representing s,
// as generated by QuoteToASCII, to dst and returns the extended buffer.
func AppendQuoteToASCII(dst []byte, s string) []byte {
return appendQuotedWith(dst, s, '"', true, false)
}
// QuoteToGraphic returns a double-quoted Go string literal representing s.
// The returned string leaves Unicode graphic characters, as defined by
// IsGraphic, unchanged and uses Go escape sequences (\t, \n, \xFF, \u0100)
// for non-graphic characters.
func QuoteToGraphic(s string) string {
return quoteWith(s, '"', false, true)
}
// AppendQuoteToGraphic appends a double-quoted Go string literal representing s,
// as generated by QuoteToGraphic, to dst and returns the extended buffer.
func AppendQuoteToGraphic(dst []byte, s string) []byte {
return appendQuotedWith(dst, s, '"', false, true)
}
// QuoteRune returns a single-quoted Go character literal representing the
// rune. The returned string uses Go escape sequences (\t, \n, \xFF, \u0100)
// for control characters and non-printable characters as defined by IsPrint.
// If r is not a valid Unicode code point, it is interpreted as the Unicode
// replacement character U+FFFD.
func QuoteRune(r rune) string {
return quoteRuneWith(r, '\'', false, false)
}
// AppendQuoteRune appends a single-quoted Go character literal representing the rune,
// as generated by QuoteRune, to dst and returns the extended buffer.
func AppendQuoteRune(dst []byte, r rune) []byte {
return appendQuotedRuneWith(dst, r, '\'', false, false)
}
// QuoteRuneToASCII returns a single-quoted Go character literal representing
// the rune. The returned string uses Go escape sequences (\t, \n, \xFF,
// \u0100) for non-ASCII characters and non-printable characters as defined
// by IsPrint.
// If r is not a valid Unicode code point, it is interpreted as the Unicode
// replacement character U+FFFD.
func QuoteRuneToASCII(r rune) string {
return quoteRuneWith(r, '\'', true, false)
}
// AppendQuoteRuneToASCII appends a single-quoted Go character literal representing the rune,
// as generated by QuoteRuneToASCII, to dst and returns the extended buffer.
func AppendQuoteRuneToASCII(dst []byte, r rune) []byte {
return appendQuotedRuneWith(dst, r, '\'', true, false)
}
// QuoteRuneToGraphic returns a single-quoted Go character literal representing
// the rune. If the rune is not a Unicode graphic character,
// as defined by IsGraphic, the returned string will use a Go escape sequence
// (\t, \n, \xFF, \u0100).
// If r is not a valid Unicode code point, it is interpreted as the Unicode
// replacement character U+FFFD.
func QuoteRuneToGraphic(r rune) string {
return quoteRuneWith(r, '\'', false, true)
}
// AppendQuoteRuneToGraphic appends a single-quoted Go character literal representing the rune,
// as generated by QuoteRuneToGraphic, to dst and returns the extended buffer.
func AppendQuoteRuneToGraphic(dst []byte, r rune) []byte {
return appendQuotedRuneWith(dst, r, '\'', false, true)
}
// CanBackquote reports whether the string s can be represented
// unchanged as a single-line backquoted string without control
// characters other than tab.
func CanBackquote(s string) bool {
for len(s) > 0 {
r, wid := utf8.DecodeRuneInString(s)
s = s[wid:]
if wid > 1 {
if r == '\ufeff' {
return false // BOMs are invisible and should not be quoted.
}
continue // All other multibyte runes are correctly encoded and assumed printable.
}
if r == utf8.RuneError {
return false
}
if (r < ' ' && r != '\t') || r == '`' || r == '\u007F' {
return false
}
}
return true
}
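// Illustrative calls, assuming a client importing strconv:
//
//	strconv.CanBackquote("path with spaces") // true
//	strconv.CanBackquote("two\nlines")       // false: '\n' is a control char other than tab
//	strconv.CanBackquote("a ` backquote")    // false: the backquote itself cannot appear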
func unhex(b byte) (v rune, ok bool) {
c := rune(b)
switch {
case '0' <= c && c <= '9':
return c - '0', true
case 'a' <= c && c <= 'f':
return c - 'a' + 10, true
case 'A' <= c && c <= 'F':
return c - 'A' + 10, true
}
return
}
// UnquoteChar decodes the first character or byte in the escaped string
// or character literal represented by the string s.
// It returns four values:
//
// 1. value, the decoded Unicode code point or byte value;
// 2. multibyte, a boolean indicating whether the decoded character requires a multibyte UTF-8 representation;
// 3. tail, the remainder of the string after the character; and
// 4. an error that will be nil if the character is syntactically valid.
//
// The second argument, quote, specifies the type of literal being parsed
// and therefore which escaped quote character is permitted.
// If set to a single quote, it permits the sequence \' and disallows unescaped '.
// If set to a double quote, it permits \" and disallows unescaped ".
// If set to zero, it does not permit either escape and allows both quote characters to appear unescaped.
func UnquoteChar(s string, quote byte) (value rune, multibyte bool, tail string, err error) {
// easy cases
if len(s) == 0 {
err = ErrSyntax
return
}
switch c := s[0]; {
case c == quote && (quote == '\'' || quote == '"'):
err = ErrSyntax
return
case c >= utf8.RuneSelf:
r, size := utf8.DecodeRuneInString(s)
return r, true, s[size:], nil
case c != '\\':
return rune(s[0]), false, s[1:], nil
}
// hard case: c is backslash
if len(s) <= 1 {
err = ErrSyntax
return
}
c := s[1]
s = s[2:]
switch c {
case 'a':
value = '\a'
case 'b':
value = '\b'
case 'f':
value = '\f'
case 'n':
value = '\n'
case 'r':
value = '\r'
case 't':
value = '\t'
case 'v':
value = '\v'
case 'x', 'u', 'U':
n := 0
switch c {
case 'x':
n = 2
case 'u':
n = 4
case 'U':
n = 8
}
var v rune
if len(s) < n {
err = ErrSyntax
return
}
for j := 0; j < n; j++ {
x, ok := unhex(s[j])
if !ok {
err = ErrSyntax
return
}
v = v<<4 | x
}
s = s[n:]
if c == 'x' {
// single-byte string, possibly not UTF-8
value = v
break
}
if !utf8.ValidRune(v) {
err = ErrSyntax
return
}
value = v
multibyte = true
case '0', '1', '2', '3', '4', '5', '6', '7':
v := rune(c) - '0'
if len(s) < 2 {
err = ErrSyntax
return
}
for j := 0; j < 2; j++ { // one digit already; two more
x := rune(s[j]) - '0'
if x < 0 || x > 7 {
err = ErrSyntax
return
}
v = (v << 3) | x
}
s = s[2:]
if v > 255 {
err = ErrSyntax
return
}
value = v
case '\\':
value = '\\'
case '\'', '"':
if c != quote {
err = ErrSyntax
return
}
value = rune(c)
default:
err = ErrSyntax
return
}
tail = s
return
}
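// A small sketch of the four return values, assuming a client
// importing strconv:
//
//	v, mb, tail, err := strconv.UnquoteChar(`\u263arest`, 0)
//	// v == '☺' (U+263A), mb == true, tail == "rest", err == nil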
// QuotedPrefix returns the quoted string (as understood by Unquote) at the prefix of s.
// If s does not start with a valid quoted string, QuotedPrefix returns an error.
func QuotedPrefix(s string) (string, error) {
out, _, err := unquote(s, false)
return out, err
}
// Unquote interprets s as a single-quoted, double-quoted,
// or backquoted Go string literal, returning the string value
// that s quotes. (If s is single-quoted, it would be a Go
// character literal; Unquote returns the corresponding
// one-character string.)
func Unquote(s string) (string, error) {
out, rem, err := unquote(s, true)
if len(rem) > 0 {
return "", ErrSyntax
}
return out, err
}
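// Usage sketch, assuming a client importing strconv:
//
//	s, err := strconv.Unquote(`"Hello\tGo"`)
//	// s == "Hello\tGo" (with a real tab), err == nil
//	_, err = strconv.Unquote(`"ok" trailing`)
//	// err == strconv.ErrSyntax: the input must be exactly one quoted string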
// unquote parses a quoted string at the start of the input,
// returning the parsed prefix, the remaining suffix, and any parse errors.
// If unescape is true, the parsed prefix is unescaped,
// otherwise the input prefix is provided verbatim.
func unquote(in string, unescape bool) (out, rem string, err error) {
// Determine the quote form and optimistically find the terminating quote.
if len(in) < 2 {
return "", in, ErrSyntax
}
quote := in[0]
end := index(in[1:], quote)
if end < 0 {
return "", in, ErrSyntax
}
end += 2 // position after terminating quote; may be wrong if escape sequences are present
switch quote {
case '`':
switch {
case !unescape:
out = in[:end] // include quotes
case !contains(in[:end], '\r'):
out = in[len("`") : end-len("`")] // exclude quotes
default:
// Carriage return characters ('\r') inside raw string literals
// are discarded from the raw string value.
buf := make([]byte, 0, end-len("`")-len("\r")-len("`"))
for i := len("`"); i < end-len("`"); i++ {
if in[i] != '\r' {
buf = append(buf, in[i])
}
}
out = string(buf)
}
// NOTE: Prior implementations did not verify that raw strings consist
// of valid UTF-8 characters, and we continue not to verify this.
// The Go specification does not explicitly require valid UTF-8,
// but only mentions that it is implicitly valid for Go source code
// (which must be valid UTF-8).
return out, in[end:], nil
case '"', '\'':
// Handle quoted strings without any escape sequences.
if !contains(in[:end], '\\') && !contains(in[:end], '\n') {
var valid bool
switch quote {
case '"':
valid = utf8.ValidString(in[len(`"`) : end-len(`"`)])
case '\'':
r, n := utf8.DecodeRuneInString(in[len("'") : end-len("'")])
valid = len("'")+n+len("'") == end && (r != utf8.RuneError || n != 1)
}
if valid {
out = in[:end]
if unescape {
out = out[1 : end-1] // exclude quotes
}
return out, in[end:], nil
}
}
// Handle quoted strings with escape sequences.
var buf []byte
in0 := in
in = in[1:] // skip starting quote
if unescape {
buf = make([]byte, 0, 3*end/2) // try to avoid more allocations
}
for len(in) > 0 && in[0] != quote {
// Process the next character,
// rejecting any unescaped newline characters which are invalid.
r, multibyte, rem, err := UnquoteChar(in, quote)
if in[0] == '\n' || err != nil {
return "", in0, ErrSyntax
}
in = rem
// Append the character if unescaping the input.
if unescape {
if r < utf8.RuneSelf || !multibyte {
buf = append(buf, byte(r))
} else {
var arr [utf8.UTFMax]byte
n := utf8.EncodeRune(arr[:], r)
buf = append(buf, arr[:n]...)
}
}
// A single-quoted literal must contain exactly one character.
if quote == '\'' {
break
}
}
// Verify that the string ends with a terminating quote.
if !(len(in) > 0 && in[0] == quote) {
return "", in0, ErrSyntax
}
in = in[1:] // skip terminating quote
if unescape {
return string(buf), in, nil
}
return in0[:len(in0)-len(in)], in, nil
default:
return "", in, ErrSyntax
}
}
// bsearch16 returns the smallest i such that a[i] >= x.
// If there is no such i, bsearch16 returns len(a).
func bsearch16(a []uint16, x uint16) int {
i, j := 0, len(a)
for i < j {
h := i + (j-i)>>1
if a[h] < x {
i = h + 1
} else {
j = h
}
}
return i
}
// bsearch32 returns the smallest i such that a[i] >= x.
// If there is no such i, bsearch32 returns len(a).
func bsearch32(a []uint32, x uint32) int {
i, j := 0, len(a)
for i < j {
h := i + (j-i)>>1
if a[h] < x {
i = h + 1
} else {
j = h
}
}
return i
}
// TODO: IsPrint is a local implementation of unicode.IsPrint, verified by the tests
// to give the same answer. It allows this package not to depend on unicode,
// and therefore not pull in all the Unicode tables. If the linker were better
// at tossing unused tables, we could get rid of this implementation.
// That would be nice.
// IsPrint reports whether the rune is defined as printable by Go, with
// the same definition as unicode.IsPrint: letters, numbers, punctuation,
// symbols and ASCII space.
func IsPrint(r rune) bool {
// Fast check for Latin-1
if r <= 0xFF {
if 0x20 <= r && r <= 0x7E {
// All ASCII from space through '~' (DEL-1) is printable.
return true
}
if 0xA1 <= r && r <= 0xFF {
// Similarly for ¡ through ÿ...
return r != 0xAD // ...except for the bizarre soft hyphen.
}
return false
}
// Same algorithm, either on uint16 or uint32 value.
// First, find first i such that isPrint[i] >= x.
// This is the index of either the start or end of a pair that might span x.
// The start is even (isPrint[i&^1]) and the end is odd (isPrint[i|1]).
// If we find x in a range, make sure x is not in isNotPrint list.
if 0 <= r && r < 1<<16 {
rr, isPrint, isNotPrint := uint16(r), isPrint16, isNotPrint16
i := bsearch16(isPrint, rr)
if i >= len(isPrint) || rr < isPrint[i&^1] || isPrint[i|1] < rr {
return false
}
j := bsearch16(isNotPrint, rr)
return j >= len(isNotPrint) || isNotPrint[j] != rr
}
rr, isPrint, isNotPrint := uint32(r), isPrint32, isNotPrint32
i := bsearch32(isPrint, rr)
if i >= len(isPrint) || rr < isPrint[i&^1] || isPrint[i|1] < rr {
return false
}
if r >= 0x20000 {
return true
}
r -= 0x10000
j := bsearch16(isNotPrint, uint16(r))
return j >= len(isNotPrint) || isNotPrint[j] != uint16(r)
}
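// Illustrative results of the classification above:
//
//	strconv.IsPrint('a')      // true
//	strconv.IsPrint('\n')     // false: control character
//	strconv.IsPrint('\u00ad') // false: the soft-hyphen special case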
// IsGraphic reports whether the rune is defined as a Graphic by Unicode. Such
// characters include letters, marks, numbers, punctuation, symbols, and
// spaces, from categories L, M, N, P, S, and Zs.
func IsGraphic(r rune) bool {
if IsPrint(r) {
return true
}
return isInGraphicList(r)
}
// isInGraphicList reports whether the rune is in the isGraphic list. This separation
// from IsGraphic allows quoteWith to avoid two calls to IsPrint.
// Should be called only if IsPrint fails.
func isInGraphicList(r rune) bool {
// We know r must fit in 16 bits - see makeisprint.go.
if r > 0xFFFF {
return false
}
rr := uint16(r)
i := bsearch16(isGraphic, rr)
return i < len(isGraphic) && rr == isGraphic[i]
}
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package strings
import (
"internal/bytealg"
"unicode/utf8"
"unsafe"
)
// A Builder is used to efficiently build a string using Write methods.
// It minimizes memory copying. The zero value is ready to use.
// Do not copy a non-zero Builder.
type Builder struct {
addr *Builder // of receiver, to detect copies by value
buf []byte
}
// noescape hides a pointer from escape analysis. It is the identity function
// but escape analysis doesn't think the output depends on the input.
// noescape is inlined and currently compiles down to zero instructions.
// USE CAREFULLY!
// This was copied from the runtime; see issues 23382 and 7921.
//
//go:nosplit
//go:nocheckptr
func noescape(p unsafe.Pointer) unsafe.Pointer {
x := uintptr(p)
return unsafe.Pointer(x ^ 0)
}
func (b *Builder) copyCheck() {
if b.addr == nil {
// This hack works around a failing of Go's escape analysis
// that was causing b to escape and be heap allocated.
// See issue 23382.
// TODO: once issue 7921 is fixed, this should be reverted to
// just "b.addr = b".
b.addr = (*Builder)(noescape(unsafe.Pointer(b)))
} else if b.addr != b {
panic("strings: illegal use of non-zero Builder copied by value")
}
}
// String returns the accumulated string.
func (b *Builder) String() string {
return unsafe.String(unsafe.SliceData(b.buf), len(b.buf))
}
// Len returns the number of accumulated bytes; b.Len() == len(b.String()).
func (b *Builder) Len() int { return len(b.buf) }
// Cap returns the capacity of the builder's underlying byte slice. It is the
// total space allocated for the string being built and includes any bytes
// already written.
func (b *Builder) Cap() int { return cap(b.buf) }
// Reset resets the Builder to be empty.
func (b *Builder) Reset() {
b.addr = nil
b.buf = nil
}
// grow copies the buffer to a new, larger buffer so that there are at least n
// bytes of capacity beyond len(b.buf).
func (b *Builder) grow(n int) {
buf := bytealg.MakeNoZero(2*cap(b.buf) + n)[:len(b.buf)]
copy(buf, b.buf)
b.buf = buf
}
// Grow grows b's capacity, if necessary, to guarantee space for
// another n bytes. After Grow(n), at least n bytes can be written to b
// without another allocation. If n is negative, Grow panics.
func (b *Builder) Grow(n int) {
b.copyCheck()
if n < 0 {
panic("strings.Builder.Grow: negative count")
}
if cap(b.buf)-len(b.buf) < n {
b.grow(n)
}
}
// Write appends the contents of p to b's buffer.
// Write always returns len(p), nil.
func (b *Builder) Write(p []byte) (int, error) {
b.copyCheck()
b.buf = append(b.buf, p...)
return len(p), nil
}
// WriteByte appends the byte c to b's buffer.
// The returned error is always nil.
func (b *Builder) WriteByte(c byte) error {
b.copyCheck()
b.buf = append(b.buf, c)
return nil
}
// WriteRune appends the UTF-8 encoding of Unicode code point r to b's buffer.
// It returns the length of r and a nil error.
func (b *Builder) WriteRune(r rune) (int, error) {
b.copyCheck()
n := len(b.buf)
b.buf = utf8.AppendRune(b.buf, r)
return len(b.buf) - n, nil
}
// WriteString appends the contents of s to b's buffer.
// It returns the length of s and a nil error.
func (b *Builder) WriteString(s string) (int, error) {
b.copyCheck()
b.buf = append(b.buf, s...)
return len(s), nil
}
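// A minimal Builder sketch, assuming a client importing strings:
//
//	var b strings.Builder
//	b.Grow(6) // optional: pre-size to avoid growing during the writes
//	for i := 0; i < 3; i++ {
//		b.WriteString("ab")
//	}
//	s := b.String() // "ababab", built with at most one allocation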
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package strings
import (
"unsafe"
)
// Clone returns a fresh copy of s.
// It guarantees to make a copy of s into a new allocation,
// which can be important when retaining only a small substring
// of a much larger string. Using Clone can help such programs
// use less memory. Of course, since using Clone makes a copy,
// overuse of Clone can make programs use more memory.
// Clone should typically be used only rarely, and only when
// profiling indicates that it is needed.
// For strings of length zero the string "" will be returned
// and no allocation is made.
func Clone(s string) string {
if len(s) == 0 {
return ""
}
b := make([]byte, len(s))
copy(b, s)
return unsafe.String(&b[0], len(b))
}
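// A sketch of the intended use, assuming a client importing strings:
//
//	large := strings.Repeat("x", 1<<20)
//	small := strings.Clone(large[:8])
//	// small is a fresh 8-byte allocation, so retaining it does not
//	// pin large's megabyte of backing memory.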
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package strings
// Compare returns an integer comparing two strings lexicographically.
// The result will be 0 if a == b, -1 if a < b, and +1 if a > b.
//
// Compare is included only for symmetry with package bytes.
// It is usually clearer and always faster to use the built-in
// string comparison operators ==, <, >, and so on.
func Compare(a, b string) int {
// NOTE(rsc): This function does NOT call the runtime cmpstring function,
// because we do not want to provide any performance justification for
// using strings.Compare. Basically no one should use strings.Compare.
// As the comment above says, it is here only for symmetry with package bytes.
// If performance is important, the compiler should be changed to recognize
// the pattern so that all code doing three-way comparisons, not just code
// using strings.Compare, can benefit.
if a == b {
return 0
}
if a < b {
return -1
}
return +1
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package strings
import (
"errors"
"io"
"unicode/utf8"
)
// A Reader implements the io.Reader, io.ReaderAt, io.ByteReader, io.ByteScanner,
// io.RuneReader, io.RuneScanner, io.Seeker, and io.WriterTo interfaces by reading
// from a string.
// The zero value for Reader operates like a Reader of an empty string.
type Reader struct {
s string
i int64 // current reading index
prevRune int // index of previous rune; or < 0
}
// Len returns the number of bytes of the unread portion of the
// string.
func (r *Reader) Len() int {
if r.i >= int64(len(r.s)) {
return 0
}
return int(int64(len(r.s)) - r.i)
}
// Size returns the original length of the underlying string.
// Size is the number of bytes available for reading via ReadAt.
// The returned value is always the same and is not affected by calls
// to any other method.
func (r *Reader) Size() int64 { return int64(len(r.s)) }
// Read implements the io.Reader interface.
func (r *Reader) Read(b []byte) (n int, err error) {
if r.i >= int64(len(r.s)) {
return 0, io.EOF
}
r.prevRune = -1
n = copy(b, r.s[r.i:])
r.i += int64(n)
return
}
// ReadAt implements the io.ReaderAt interface.
func (r *Reader) ReadAt(b []byte, off int64) (n int, err error) {
// cannot modify state - see io.ReaderAt
if off < 0 {
return 0, errors.New("strings.Reader.ReadAt: negative offset")
}
if off >= int64(len(r.s)) {
return 0, io.EOF
}
n = copy(b, r.s[off:])
if n < len(b) {
err = io.EOF
}
return
}
// ReadByte implements the io.ByteReader interface.
func (r *Reader) ReadByte() (byte, error) {
r.prevRune = -1
if r.i >= int64(len(r.s)) {
return 0, io.EOF
}
b := r.s[r.i]
r.i++
return b, nil
}
// UnreadByte implements the io.ByteScanner interface.
func (r *Reader) UnreadByte() error {
if r.i <= 0 {
return errors.New("strings.Reader.UnreadByte: at beginning of string")
}
r.prevRune = -1
r.i--
return nil
}
// ReadRune implements the io.RuneReader interface.
func (r *Reader) ReadRune() (ch rune, size int, err error) {
if r.i >= int64(len(r.s)) {
r.prevRune = -1
return 0, 0, io.EOF
}
r.prevRune = int(r.i)
if c := r.s[r.i]; c < utf8.RuneSelf {
r.i++
return rune(c), 1, nil
}
ch, size = utf8.DecodeRuneInString(r.s[r.i:])
r.i += int64(size)
return
}
// UnreadRune implements the io.RuneScanner interface.
func (r *Reader) UnreadRune() error {
if r.i <= 0 {
return errors.New("strings.Reader.UnreadRune: at beginning of string")
}
if r.prevRune < 0 {
return errors.New("strings.Reader.UnreadRune: previous operation was not ReadRune")
}
r.i = int64(r.prevRune)
r.prevRune = -1
return nil
}
// Seek implements the io.Seeker interface.
func (r *Reader) Seek(offset int64, whence int) (int64, error) {
r.prevRune = -1
var abs int64
switch whence {
case io.SeekStart:
abs = offset
case io.SeekCurrent:
abs = r.i + offset
case io.SeekEnd:
abs = int64(len(r.s)) + offset
default:
return 0, errors.New("strings.Reader.Seek: invalid whence")
}
if abs < 0 {
return 0, errors.New("strings.Reader.Seek: negative position")
}
r.i = abs
return abs, nil
}
// WriteTo implements the io.WriterTo interface.
func (r *Reader) WriteTo(w io.Writer) (n int64, err error) {
r.prevRune = -1
if r.i >= int64(len(r.s)) {
return 0, nil
}
s := r.s[r.i:]
m, err := io.WriteString(w, s)
if m > len(s) {
panic("strings.Reader.WriteTo: invalid WriteString count")
}
r.i += int64(m)
n = int64(m)
if m != len(s) && err == nil {
err = io.ErrShortWrite
}
return
}
// Reset resets the Reader to be reading from s.
func (r *Reader) Reset(s string) { *r = Reader{s, 0, -1} }
// NewReader returns a new Reader reading from s.
// It is similar to bytes.NewBufferString but more efficient and read-only.
func NewReader(s string) *Reader { return &Reader{s, 0, -1} }
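// A short Reader sketch, assuming a client importing strings:
//
//	r := strings.NewReader("héllo")
//	ch, size, _ := r.ReadRune() // 'h', 1
//	ch, size, _ = r.ReadRune()  // 'é', 2 (two UTF-8 bytes)
//	_ = r.UnreadRune()          // steps back before 'é'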
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package strings
import (
"io"
"sync"
)
// Replacer replaces a list of strings with replacements.
// It is safe for concurrent use by multiple goroutines.
type Replacer struct {
once sync.Once // guards buildOnce method
r replacer
oldnew []string
}
// replacer is the interface that a replacement algorithm needs to implement.
type replacer interface {
Replace(s string) string
WriteString(w io.Writer, s string) (n int, err error)
}
// NewReplacer returns a new Replacer from a list of old, new string
// pairs. Replacements are performed in the order they appear in the
// target string, without overlapping matches. The old string
// comparisons are done in argument order.
//
// NewReplacer panics if given an odd number of arguments.
func NewReplacer(oldnew ...string) *Replacer {
if len(oldnew)%2 == 1 {
panic("strings.NewReplacer: odd argument count")
}
return &Replacer{oldnew: append([]string(nil), oldnew...)}
}
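// Typical use, assuming a client importing strings:
//
//	r := strings.NewReplacer("<", "&lt;", ">", "&gt;")
//	s := r.Replace("<b>bold</b>") // "&lt;b&gt;bold&lt;/b&gt;"
//	// r is safe to reuse from multiple goroutines.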
func (r *Replacer) buildOnce() {
r.r = r.build()
r.oldnew = nil
}
func (b *Replacer) build() replacer {
oldnew := b.oldnew
if len(oldnew) == 2 && len(oldnew[0]) > 1 {
return makeSingleStringReplacer(oldnew[0], oldnew[1])
}
allNewBytes := true
for i := 0; i < len(oldnew); i += 2 {
if len(oldnew[i]) != 1 {
return makeGenericReplacer(oldnew)
}
if len(oldnew[i+1]) != 1 {
allNewBytes = false
}
}
if allNewBytes {
r := byteReplacer{}
for i := range r {
r[i] = byte(i)
}
// The first occurrence of old->new map takes precedence
// over the others with the same old string.
for i := len(oldnew) - 2; i >= 0; i -= 2 {
o := oldnew[i][0]
n := oldnew[i+1][0]
r[o] = n
}
return &r
}
r := byteStringReplacer{toReplace: make([]string, 0, len(oldnew)/2)}
// The first occurrence of old->new map takes precedence
// over the others with the same old string.
for i := len(oldnew) - 2; i >= 0; i -= 2 {
o := oldnew[i][0]
n := oldnew[i+1]
// Add each old byte to toReplace only once, so its replacements are not counted multiple times.
if r.replacements[o] == nil {
// We need to use string([]byte{o}) instead of string(o),
// to avoid UTF-8 encoding of o.
// E.g., byte(150) would produce a string of length 2.
r.toReplace = append(r.toReplace, string([]byte{o}))
}
r.replacements[o] = []byte(n)
}
return &r
}
// Replace returns a copy of s with all replacements performed.
func (r *Replacer) Replace(s string) string {
r.once.Do(r.buildOnce)
return r.r.Replace(s)
}
// WriteString writes s to w with all replacements performed.
func (r *Replacer) WriteString(w io.Writer, s string) (n int, err error) {
r.once.Do(r.buildOnce)
return r.r.WriteString(w, s)
}
// trieNode is a node in a lookup trie for prioritized key/value pairs. Keys
// and values may be empty. For example, the trie containing keys "ax", "ay",
// "bcbc", "x" and "xy" could have eight nodes:
//
// n0 -
// n1 a-
// n2 .x+
// n3 .y+
// n4 b-
// n5 .cbc+
// n6 x+
// n7 .y+
//
// n0 is the root node, and its children are n1, n4 and n6; n1's children are
// n2 and n3; n4's child is n5; n6's child is n7. Nodes n0, n1 and n4 (marked
// with a trailing "-") are partial keys, and nodes n2, n3, n5, n6 and n7
// (marked with a trailing "+") are complete keys.
type trieNode struct {
// value is the value of the trie node's key/value pair. It is empty if
// this node is not a complete key.
value string
// priority is the priority (higher is more important) of the trie node's
// key/value pair; keys are not necessarily matched shortest- or longest-
// first. Priority is positive if this node is a complete key, and zero
// otherwise. In the example above, positive/zero priorities are marked
// with a trailing "+" or "-".
priority int
// A trie node may have zero, one or more child nodes:
// * if the remaining fields are zero, there are no children.
// * if prefix and next are non-zero, there is one child in next.
// * if table is non-zero, it defines all the children.
//
// Prefixes are preferred over tables when there is one child, but the
// root node always uses a table for lookup efficiency.
// prefix is the difference in keys between this trie node and the next.
// In the example above, node n4 has prefix "cbc" and n4's next node is n5.
// Node n5 has no children and so has zero prefix, next and table fields.
prefix string
next *trieNode
// table is a lookup table indexed by the next byte in the key, after
// remapping that byte through genericReplacer.mapping to create a dense
// index. In the example above, the keys only use 'a', 'b', 'c', 'x' and
// 'y', which remap to 0, 1, 2, 3 and 4. All other bytes remap to 5, and
// genericReplacer.tableSize will be 5. Node n0's table will be
// []*trieNode{ 0:n1, 1:n4, 3:n6 }, where the 0, 1 and 3 are the remapped
// 'a', 'b' and 'x'.
table []*trieNode
}
func (t *trieNode) add(key, val string, priority int, r *genericReplacer) {
if key == "" {
if t.priority == 0 {
t.value = val
t.priority = priority
}
return
}
if t.prefix != "" {
// Need to split the prefix among multiple nodes.
var n int // length of the longest common prefix
for ; n < len(t.prefix) && n < len(key); n++ {
if t.prefix[n] != key[n] {
break
}
}
if n == len(t.prefix) {
t.next.add(key[n:], val, priority, r)
} else if n == 0 {
// First byte differs, start a new lookup table here. Looking up
// what is currently t.prefix[0] will lead to prefixNode, and
// looking up key[0] will lead to keyNode.
var prefixNode *trieNode
if len(t.prefix) == 1 {
prefixNode = t.next
} else {
prefixNode = &trieNode{
prefix: t.prefix[1:],
next: t.next,
}
}
keyNode := new(trieNode)
t.table = make([]*trieNode, r.tableSize)
t.table[r.mapping[t.prefix[0]]] = prefixNode
t.table[r.mapping[key[0]]] = keyNode
t.prefix = ""
t.next = nil
keyNode.add(key[1:], val, priority, r)
} else {
// Insert new node after the common section of the prefix.
next := &trieNode{
prefix: t.prefix[n:],
next: t.next,
}
t.prefix = t.prefix[:n]
t.next = next
next.add(key[n:], val, priority, r)
}
} else if t.table != nil {
// Insert into existing table.
m := r.mapping[key[0]]
if t.table[m] == nil {
t.table[m] = new(trieNode)
}
t.table[m].add(key[1:], val, priority, r)
} else {
t.prefix = key
t.next = new(trieNode)
t.next.add("", val, priority, r)
}
}
func (r *genericReplacer) lookup(s string, ignoreRoot bool) (val string, keylen int, found bool) {
// Iterate down the trie to the end, and grab the value and keylen with
// the highest priority.
bestPriority := 0
node := &r.root
n := 0
for node != nil {
if node.priority > bestPriority && !(ignoreRoot && node == &r.root) {
bestPriority = node.priority
val = node.value
keylen = n
found = true
}
if s == "" {
break
}
if node.table != nil {
index := r.mapping[s[0]]
if int(index) == r.tableSize {
break
}
node = node.table[index]
s = s[1:]
n++
} else if node.prefix != "" && HasPrefix(s, node.prefix) {
n += len(node.prefix)
s = s[len(node.prefix):]
node = node.next
} else {
break
}
}
return
}
// genericReplacer is the fully generic algorithm.
// It's used as a fallback when nothing faster can be used.
type genericReplacer struct {
root trieNode
// tableSize is the size of a trie node's lookup table. It is the number
// of unique key bytes.
tableSize int
// mapping maps from key bytes to a dense index for trieNode.table.
mapping [256]byte
}
func makeGenericReplacer(oldnew []string) *genericReplacer {
r := new(genericReplacer)
// Find each byte used, then assign them each an index.
for i := 0; i < len(oldnew); i += 2 {
key := oldnew[i]
for j := 0; j < len(key); j++ {
r.mapping[key[j]] = 1
}
}
for _, b := range r.mapping {
r.tableSize += int(b)
}
var index byte
for i, b := range r.mapping {
if b == 0 {
r.mapping[i] = byte(r.tableSize)
} else {
r.mapping[i] = index
index++
}
}
// Ensure root node uses a lookup table (for performance).
r.root.table = make([]*trieNode, r.tableSize)
for i := 0; i < len(oldnew); i += 2 {
r.root.add(oldnew[i], oldnew[i+1], len(oldnew)-i, r)
}
return r
}
type appendSliceWriter []byte
// Write writes to the buffer to satisfy io.Writer.
func (w *appendSliceWriter) Write(p []byte) (int, error) {
*w = append(*w, p...)
return len(p), nil
}
// WriteString writes to the buffer without string->[]byte->string allocations.
func (w *appendSliceWriter) WriteString(s string) (int, error) {
*w = append(*w, s...)
return len(s), nil
}
type stringWriter struct {
w io.Writer
}
func (w stringWriter) WriteString(s string) (int, error) {
return w.w.Write([]byte(s))
}
func getStringWriter(w io.Writer) io.StringWriter {
sw, ok := w.(io.StringWriter)
if !ok {
sw = stringWriter{w}
}
return sw
}
func (r *genericReplacer) Replace(s string) string {
buf := make(appendSliceWriter, 0, len(s))
r.WriteString(&buf, s)
return string(buf)
}
func (r *genericReplacer) WriteString(w io.Writer, s string) (n int, err error) {
sw := getStringWriter(w)
var last, wn int
var prevMatchEmpty bool
for i := 0; i <= len(s); {
// Fast path: s[i] is not a prefix of any pattern.
if i != len(s) && r.root.priority == 0 {
index := int(r.mapping[s[i]])
if index == r.tableSize || r.root.table[index] == nil {
i++
continue
}
}
// Ignore the empty match iff the previous loop found the empty match.
val, keylen, match := r.lookup(s[i:], prevMatchEmpty)
prevMatchEmpty = match && keylen == 0
if match {
wn, err = sw.WriteString(s[last:i])
n += wn
if err != nil {
return
}
wn, err = sw.WriteString(val)
n += wn
if err != nil {
return
}
i += keylen
last = i
continue
}
i++
}
if last != len(s) {
wn, err = sw.WriteString(s[last:])
n += wn
}
return
}
// singleStringReplacer is the implementation that's used when there is only
// one string to replace (and that string has more than one byte).
type singleStringReplacer struct {
finder *stringFinder
// value is the new string that replaces that pattern when it's found.
value string
}
func makeSingleStringReplacer(pattern string, value string) *singleStringReplacer {
return &singleStringReplacer{finder: makeStringFinder(pattern), value: value}
}
func (r *singleStringReplacer) Replace(s string) string {
var buf Builder
i, matched := 0, false
for {
match := r.finder.next(s[i:])
if match == -1 {
break
}
matched = true
buf.Grow(match + len(r.value))
buf.WriteString(s[i : i+match])
buf.WriteString(r.value)
i += match + len(r.finder.pattern)
}
if !matched {
return s
}
buf.WriteString(s[i:])
return buf.String()
}
func (r *singleStringReplacer) WriteString(w io.Writer, s string) (n int, err error) {
sw := getStringWriter(w)
var i, wn int
for {
match := r.finder.next(s[i:])
if match == -1 {
break
}
wn, err = sw.WriteString(s[i : i+match])
n += wn
if err != nil {
return
}
wn, err = sw.WriteString(r.value)
n += wn
if err != nil {
return
}
i += match + len(r.finder.pattern)
}
wn, err = sw.WriteString(s[i:])
n += wn
return
}
// byteReplacer is the implementation that's used when all the "old"
// and "new" values are single ASCII bytes.
// The array contains replacement bytes indexed by old byte.
type byteReplacer [256]byte
func (r *byteReplacer) Replace(s string) string {
var buf []byte // lazily allocated
for i := 0; i < len(s); i++ {
b := s[i]
if r[b] != b {
if buf == nil {
buf = []byte(s)
}
buf[i] = r[b]
}
}
if buf == nil {
return s
}
return string(buf)
}
func (r *byteReplacer) WriteString(w io.Writer, s string) (n int, err error) {
sw := getStringWriter(w)
last := 0
for i := 0; i < len(s); i++ {
b := s[i]
if r[b] == b {
continue
}
if last != i {
wn, err := sw.WriteString(s[last:i])
n += wn
if err != nil {
return n, err
}
}
last = i + 1
nw, err := w.Write(r[b : int(b)+1])
n += nw
if err != nil {
return n, err
}
}
if last != len(s) {
nw, err := sw.WriteString(s[last:])
n += nw
if err != nil {
return n, err
}
}
return n, nil
}
// byteStringReplacer is the implementation that's used when all the
// "old" values are single ASCII bytes but the "new" values vary in size.
type byteStringReplacer struct {
// replacements contains replacement byte slices indexed by old byte.
// A nil []byte means that the old byte should not be replaced.
replacements [256][]byte
// toReplace keeps a list of bytes to replace. Depending on length of toReplace
// and length of target string it may be faster to use Count, or a plain loop.
// We store single byte as a string, because Count takes a string.
toReplace []string
}
// countCutOff controls the ratio of a string length to a number of replacements
// at which (*byteStringReplacer).Replace switches algorithms.
// For strings with a higher ratio of length to replacements than that value,
// we call Count once for each replacement from toReplace.
// For strings with a lower ratio we use a simple loop, because of Count's overhead.
// countCutOff is an empirically determined overhead multiplier.
// TODO(tocarip) revisit once we have register-based abi/mid-stack inlining.
const countCutOff = 8
func (r *byteStringReplacer) Replace(s string) string {
newSize := len(s)
anyChanges := false
// Is it faster to use Count?
if len(r.toReplace)*countCutOff <= len(s) {
for _, x := range r.toReplace {
if c := Count(s, x); c != 0 {
// The -1 is because we are replacing 1 byte with len(replacements[b]) bytes.
newSize += c * (len(r.replacements[x[0]]) - 1)
anyChanges = true
}
}
} else {
for i := 0; i < len(s); i++ {
b := s[i]
if r.replacements[b] != nil {
// See above for explanation of -1
newSize += len(r.replacements[b]) - 1
anyChanges = true
}
}
}
if !anyChanges {
return s
}
buf := make([]byte, newSize)
j := 0
for i := 0; i < len(s); i++ {
b := s[i]
if r.replacements[b] != nil {
j += copy(buf[j:], r.replacements[b])
} else {
buf[j] = b
j++
}
}
return string(buf)
}
func (r *byteStringReplacer) WriteString(w io.Writer, s string) (n int, err error) {
sw := getStringWriter(w)
last := 0
for i := 0; i < len(s); i++ {
b := s[i]
if r.replacements[b] == nil {
continue
}
if last != i {
nw, err := sw.WriteString(s[last:i])
n += nw
if err != nil {
return n, err
}
}
last = i + 1
nw, err := w.Write(r.replacements[b])
n += nw
if err != nil {
return n, err
}
}
if last != len(s) {
var nw int
nw, err = sw.WriteString(s[last:])
n += nw
}
return
}
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package strings
// stringFinder efficiently finds strings in a source text. It's implemented
// using the Boyer-Moore string search algorithm:
// https://en.wikipedia.org/wiki/Boyer-Moore_string_search_algorithm
// https://www.cs.utexas.edu/~moore/publications/fstrpos.pdf (note: this aged
// document uses 1-based indexing)
type stringFinder struct {
// pattern is the string that we are searching for in the text.
pattern string
// badCharSkip[b] contains the distance between the last byte of pattern
// and the rightmost occurrence of b in pattern. If b is not in pattern,
// badCharSkip[b] is len(pattern).
//
// Whenever a mismatch is found with byte b in the text, we can safely
// shift the matching frame at least badCharSkip[b] until the next time
// the matching char could be in alignment.
badCharSkip [256]int
// goodSuffixSkip[i] defines how far we can shift the matching frame given
// that the suffix pattern[i+1:] matches, but the byte pattern[i] does
// not. There are two cases to consider:
//
// 1. The matched suffix occurs elsewhere in pattern (with a different
// byte preceding it that we might possibly match). In this case, we can
// shift the matching frame to align with the next suffix chunk. For
// example, the pattern "mississi" has the suffix "issi" next occurring
// (in right-to-left order) at index 1, so goodSuffixSkip[3] ==
// shift+len(suffix) == 3+4 == 7.
//
// 2. If the matched suffix does not occur elsewhere in pattern, then the
// matching frame may share part of its prefix with the end of the
// matching suffix. In this case, goodSuffixSkip[i] will contain how far
// to shift the frame to align this portion of the prefix to the
// suffix. For example, in the pattern "abcxxxabc", when the first
// mismatch from the back is found to be in position 3, the matching
// suffix "xxabc" is not found elsewhere in the pattern. However, its
// rightmost "abc" (at position 6) is a prefix of the whole pattern, so
// goodSuffixSkip[3] == shift+len(suffix) == 6+5 == 11.
goodSuffixSkip []int
}
func makeStringFinder(pattern string) *stringFinder {
f := &stringFinder{
pattern: pattern,
goodSuffixSkip: make([]int, len(pattern)),
}
// last is the index of the last character in the pattern.
last := len(pattern) - 1
// Build bad character table.
// Bytes not in the pattern can skip one pattern's length.
for i := range f.badCharSkip {
f.badCharSkip[i] = len(pattern)
}
// The loop condition is < instead of <= so that the last byte does not
// have a zero distance to itself. Finding this byte out of place implies
// that it is not in the last position.
for i := 0; i < last; i++ {
f.badCharSkip[pattern[i]] = last - i
}
// Build good suffix table.
// First pass: set each value to the next index which starts a prefix of
// pattern.
lastPrefix := last
for i := last; i >= 0; i-- {
if HasPrefix(pattern, pattern[i+1:]) {
lastPrefix = i + 1
}
// lastPrefix is the shift, and (last-i) is len(suffix).
f.goodSuffixSkip[i] = lastPrefix + last - i
}
// Second pass: find repeats of pattern's suffix starting from the front.
for i := 0; i < last; i++ {
lenSuffix := longestCommonSuffix(pattern, pattern[1:i+1])
if pattern[i-lenSuffix] != pattern[last-lenSuffix] {
// (last-i) is the shift, and lenSuffix is len(suffix).
f.goodSuffixSkip[last-lenSuffix] = lenSuffix + last - i
}
}
return f
}
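// For example (values derived from the loops above, illustrative only):
//
//	f := makeStringFinder("abc")
//	// f.badCharSkip['a'] == 2, f.badCharSkip['b'] == 1, and every
//	// other byte, including 'c', keeps the default skip of 3, so a
//	// mismatch on an absent byte shifts the frame a full pattern length.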
func longestCommonSuffix(a, b string) (i int) {
for ; i < len(a) && i < len(b); i++ {
if a[len(a)-1-i] != b[len(b)-1-i] {
break
}
}
return
}
// next returns the index in text of the first occurrence of the pattern. If
// the pattern is not found, it returns -1.
func (f *stringFinder) next(text string) int {
i := len(f.pattern) - 1
for i < len(text) {
// Compare backwards from the end until the first unmatching character.
j := len(f.pattern) - 1
for j >= 0 && text[i] == f.pattern[j] {
i--
j--
}
if j < 0 {
return i + 1 // match
}
i += max(f.badCharSkip[text[i]], f.goodSuffixSkip[j])
}
return -1
}
func max(a, b int) int {
if a > b {
return a
}
return b
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package strings implements simple functions to manipulate UTF-8 encoded strings.
//
// For information about UTF-8 strings in Go, see https://blog.golang.org/strings.
package strings
import (
"internal/bytealg"
"unicode"
"unicode/utf8"
)
const maxInt = int(^uint(0) >> 1)
// explode splits s into a slice of UTF-8 strings,
// one string per Unicode character up to a maximum of n (n < 0 means no limit).
// Invalid UTF-8 bytes are sliced individually.
func explode(s string, n int) []string {
l := utf8.RuneCountInString(s)
if n < 0 || n > l {
n = l
}
a := make([]string, n)
for i := 0; i < n-1; i++ {
_, size := utf8.DecodeRuneInString(s)
a[i] = s[:size]
s = s[size:]
}
if n > 0 {
a[n-1] = s
}
return a
}
// Count counts the number of non-overlapping instances of substr in s.
// If substr is an empty string, Count returns 1 + the number of Unicode code points in s.
func Count(s, substr string) int {
// special case
if len(substr) == 0 {
return utf8.RuneCountInString(s) + 1
}
if len(substr) == 1 {
return bytealg.CountString(s, substr[0])
}
n := 0
for {
i := Index(s, substr)
if i == -1 {
return n
}
n++
s = s[i+len(substr):]
}
}
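// Illustrative counts, assuming a client importing strings:
//
//	strings.Count("cheese", "e") // 3
//	strings.Count("five", "")    // 5: one more than the number of runes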
// Contains reports whether substr is within s.
func Contains(s, substr string) bool {
return Index(s, substr) >= 0
}
// ContainsAny reports whether any Unicode code points in chars are within s.
func ContainsAny(s, chars string) bool {
return IndexAny(s, chars) >= 0
}
// ContainsRune reports whether the Unicode code point r is within s.
func ContainsRune(s string, r rune) bool {
return IndexRune(s, r) >= 0
}
// ContainsFunc reports whether any Unicode code points r within s satisfy f(r).
func ContainsFunc(s string, f func(rune) bool) bool {
return IndexFunc(s, f) >= 0
}
// LastIndex returns the index of the last instance of substr in s, or -1 if substr is not present in s.
func LastIndex(s, substr string) int {
n := len(substr)
switch {
case n == 0:
return len(s)
case n == 1:
return LastIndexByte(s, substr[0])
case n == len(s):
if substr == s {
return 0
}
return -1
case n > len(s):
return -1
}
// Rabin-Karp search from the end of the string
hashss, pow := bytealg.HashStrRev(substr)
last := len(s) - n
var h uint32
for i := len(s) - 1; i >= last; i-- {
h = h*bytealg.PrimeRK + uint32(s[i])
}
if h == hashss && s[last:] == substr {
return last
}
for i := last - 1; i >= 0; i-- {
h *= bytealg.PrimeRK
h += uint32(s[i])
h -= pow * uint32(s[i+n])
if h == hashss && s[i:i+n] == substr {
return i
}
}
return -1
}
// IndexByte returns the index of the first instance of c in s, or -1 if c is not present in s.
func IndexByte(s string, c byte) int {
return bytealg.IndexByteString(s, c)
}
// IndexRune returns the index of the first instance of the Unicode code point
// r, or -1 if r is not present in s.
// If r is utf8.RuneError, it returns the first instance of any
// invalid UTF-8 byte sequence.
func IndexRune(s string, r rune) int {
switch {
case 0 <= r && r < utf8.RuneSelf:
return IndexByte(s, byte(r))
case r == utf8.RuneError:
for i, r := range s {
if r == utf8.RuneError {
return i
}
}
return -1
case !utf8.ValidRune(r):
return -1
default:
return Index(s, string(r))
}
}
// IndexAny returns the index of the first instance of any Unicode code point
// from chars in s, or -1 if no Unicode code point from chars is present in s.
func IndexAny(s, chars string) int {
if chars == "" {
// Avoid scanning all of s.
return -1
}
if len(chars) == 1 {
// Avoid scanning all of s.
r := rune(chars[0])
if r >= utf8.RuneSelf {
r = utf8.RuneError
}
return IndexRune(s, r)
}
if len(s) > 8 {
if as, isASCII := makeASCIISet(chars); isASCII {
for i := 0; i < len(s); i++ {
if as.contains(s[i]) {
return i
}
}
return -1
}
}
for i, c := range s {
if IndexRune(chars, c) >= 0 {
return i
}
}
return -1
}
// LastIndexAny returns the index of the last instance of any Unicode code
// point from chars in s, or -1 if no Unicode code point from chars is
// present in s.
func LastIndexAny(s, chars string) int {
if chars == "" {
// Avoid scanning all of s.
return -1
}
if len(s) == 1 {
rc := rune(s[0])
if rc >= utf8.RuneSelf {
rc = utf8.RuneError
}
if IndexRune(chars, rc) >= 0 {
return 0
}
return -1
}
if len(s) > 8 {
if as, isASCII := makeASCIISet(chars); isASCII {
for i := len(s) - 1; i >= 0; i-- {
if as.contains(s[i]) {
return i
}
}
return -1
}
}
if len(chars) == 1 {
rc := rune(chars[0])
if rc >= utf8.RuneSelf {
rc = utf8.RuneError
}
for i := len(s); i > 0; {
r, size := utf8.DecodeLastRuneInString(s[:i])
i -= size
if rc == r {
return i
}
}
return -1
}
for i := len(s); i > 0; {
r, size := utf8.DecodeLastRuneInString(s[:i])
i -= size
if IndexRune(chars, r) >= 0 {
return i
}
}
return -1
}
// LastIndexByte returns the index of the last instance of c in s, or -1 if c is not present in s.
func LastIndexByte(s string, c byte) int {
for i := len(s) - 1; i >= 0; i-- {
if s[i] == c {
return i
}
}
return -1
}
// Generic split: splits after each instance of sep,
// including sepSave bytes of sep in the subarrays.
func genSplit(s, sep string, sepSave, n int) []string {
if n == 0 {
return nil
}
if sep == "" {
return explode(s, n)
}
if n < 0 {
n = Count(s, sep) + 1
}
if n > len(s)+1 {
n = len(s) + 1
}
a := make([]string, n)
n--
i := 0
for i < n {
m := Index(s, sep)
if m < 0 {
break
}
a[i] = s[:m+sepSave]
s = s[m+len(sep):]
i++
}
a[i] = s
return a[:i+1]
}
// SplitN slices s into substrings separated by sep and returns a slice of
// the substrings between those separators.
//
// The count determines the number of substrings to return:
//
// n > 0: at most n substrings; the last substring will be the unsplit remainder.
// n == 0: the result is nil (zero substrings).
// n < 0: all substrings.
//
// Edge cases for s and sep (for example, empty strings) are handled
// as described in the documentation for Split.
//
// To split around the first instance of a separator, see Cut.
func SplitN(s, sep string, n int) []string { return genSplit(s, sep, 0, n) }
// SplitAfterN slices s into substrings after each instance of sep and
// returns a slice of those substrings.
//
// The count determines the number of substrings to return:
//
// n > 0: at most n substrings; the last substring will be the unsplit remainder.
// n == 0: the result is nil (zero substrings).
// n < 0: all substrings.
//
// Edge cases for s and sep (for example, empty strings) are handled
// as described in the documentation for SplitAfter.
func SplitAfterN(s, sep string, n int) []string {
return genSplit(s, sep, len(sep), n)
}
// Split slices s into all substrings separated by sep and returns a slice of
// the substrings between those separators.
//
// If s does not contain sep and sep is not empty, Split returns a
// slice of length 1 whose only element is s.
//
// If sep is empty, Split splits after each UTF-8 sequence. If both s
// and sep are empty, Split returns an empty slice.
//
// It is equivalent to SplitN with a count of -1.
//
// To split around the first instance of a separator, see Cut.
func Split(s, sep string) []string { return genSplit(s, sep, 0, -1) }
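// Split sketches, assuming a client importing strings:
//
//	strings.Split("a,b,c", ",")     // ["a" "b" "c"]
//	strings.Split("abc", "")        // ["a" "b" "c"]: splits per UTF-8 sequence
//	strings.SplitN("a,b,c", ",", 2) // ["a" "b,c"]: remainder stays unsplit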
// SplitAfter slices s into all substrings after each instance of sep and
// returns a slice of those substrings.
//
// If s does not contain sep and sep is not empty, SplitAfter returns
// a slice of length 1 whose only element is s.
//
// If sep is empty, SplitAfter splits after each UTF-8 sequence. If
// both s and sep are empty, SplitAfter returns an empty slice.
//
// It is equivalent to SplitAfterN with a count of -1.
func SplitAfter(s, sep string) []string {
return genSplit(s, sep, len(sep), -1)
}
var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}
// Fields splits the string s around each instance of one or more consecutive white space
// characters, as defined by unicode.IsSpace, returning a slice of substrings of s or an
// empty slice if s contains only white space.
func Fields(s string) []string {
// First count the fields.
// This is an exact count if s is ASCII, otherwise it is an approximation.
n := 0
wasSpace := 1
// setBits is used to track which bits are set in the bytes of s.
setBits := uint8(0)
for i := 0; i < len(s); i++ {
r := s[i]
setBits |= r
isSpace := int(asciiSpace[r])
n += wasSpace & ^isSpace
wasSpace = isSpace
}
if setBits >= utf8.RuneSelf {
// Some runes in the input string are not ASCII.
return FieldsFunc(s, unicode.IsSpace)
}
// ASCII fast path
a := make([]string, n)
na := 0
fieldStart := 0
i := 0
// Skip spaces in the front of the input.
for i < len(s) && asciiSpace[s[i]] != 0 {
i++
}
fieldStart = i
for i < len(s) {
if asciiSpace[s[i]] == 0 {
i++
continue
}
a[na] = s[fieldStart:i]
na++
i++
// Skip spaces in between fields.
for i < len(s) && asciiSpace[s[i]] != 0 {
i++
}
fieldStart = i
}
if fieldStart < len(s) { // Last field might end at EOF.
a[na] = s[fieldStart:]
}
return a
}
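// For example, assuming a client importing strings:
//
//	strings.Fields("  foo bar  baz   ") // ["foo" "bar" "baz"]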
// FieldsFunc splits the string s at each run of Unicode code points c satisfying f(c)
// and returns an array of slices of s. If all code points in s satisfy f(c) or the
// string is empty, an empty slice is returned.
//
// FieldsFunc makes no guarantees about the order in which it calls f(c)
// and assumes that f always returns the same value for a given c.
func FieldsFunc(s string, f func(rune) bool) []string {
// A span is used to record a slice of s of the form s[start:end].
// The start index is inclusive and the end index is exclusive.
type span struct {
start int
end int
}
spans := make([]span, 0, 32)
// Find the field start and end indices.
// Doing this in a separate pass (rather than slicing the string s
// and collecting the result substrings right away) is significantly
// more efficient, possibly due to cache effects.
start := -1 // valid span start if >= 0
for end, rune := range s {
if f(rune) {
if start >= 0 {
spans = append(spans, span{start, end})
// Set start to a negative value.
// Note: using -1 here consistently and reproducibly
// slows down this code by several percent on amd64.
start = ^start
}
} else {
if start < 0 {
start = end
}
}
}
// Last field might end at EOF.
if start >= 0 {
spans = append(spans, span{start, len(s)})
}
// Create strings from recorded field indices.
a := make([]string, len(spans))
for i, span := range spans {
a[i] = s[span.start:span.end]
}
return a
}
// Join concatenates the elements of its first argument to create a single string. The separator
// string sep is placed between elements in the resulting string.
func Join(elems []string, sep string) string {
switch len(elems) {
case 0:
return ""
case 1:
return elems[0]
}
var n int
if len(sep) > 0 {
if len(sep) >= maxInt/(len(elems)-1) {
panic("strings: Join output length overflow")
}
n += len(sep) * (len(elems) - 1)
}
for _, elem := range elems {
if len(elem) > maxInt-n {
panic("strings: Join output length overflow")
}
n += len(elem)
}
var b Builder
b.Grow(n)
b.WriteString(elems[0])
for _, s := range elems[1:] {
b.WriteString(sep)
b.WriteString(s)
}
return b.String()
}
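// For example, assuming a client importing strings:
//
//	strings.Join([]string{"foo", "bar", "baz"}, ", ") // "foo, bar, baz"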
// HasPrefix tests whether the string s begins with prefix.
func HasPrefix(s, prefix string) bool {
return len(s) >= len(prefix) && s[0:len(prefix)] == prefix
}
// HasSuffix tests whether the string s ends with suffix.
func HasSuffix(s, suffix string) bool {
return len(s) >= len(suffix) && s[len(s)-len(suffix):] == suffix
}
// Map returns a copy of the string s with all its characters modified
// according to the mapping function. If mapping returns a negative value, the character is
// dropped from the string with no replacement.
func Map(mapping func(rune) rune, s string) string {
// In the worst case, the string can grow when mapped, making
// things unpleasant. But it's so rare we barge in assuming it's
// fine. It could also shrink but that falls out naturally.
// The output buffer b is initialized on demand, the first
// time a character differs.
var b Builder
for i, c := range s {
r := mapping(c)
if r == c && c != utf8.RuneError {
continue
}
var width int
if c == utf8.RuneError {
c, width = utf8.DecodeRuneInString(s[i:])
if width != 1 && r == c {
continue
}
} else {
width = utf8.RuneLen(c)
}
b.Grow(len(s) + utf8.UTFMax)
b.WriteString(s[:i])
if r >= 0 {
b.WriteRune(r)
}
s = s[i+width:]
break
}
// Fast path for unchanged input
if b.Cap() == 0 { // didn't call b.Grow above
return s
}
for _, c := range s {
r := mapping(c)
if r >= 0 {
// common case
// Due to inlining, it is more performant to determine if WriteByte should be
// invoked rather than always call WriteRune
if r < utf8.RuneSelf {
b.WriteByte(byte(r))
} else {
// r is not an ASCII rune.
b.WriteRune(r)
}
}
}
return b.String()
}
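// A sketch of the drop-on-negative rule, assuming a client importing strings:
//
//	drop := func(r rune) rune {
//		if r == ' ' {
//			return -1 // negative result: the rune is dropped
//		}
//		return r
//	}
//	strings.Map(drop, "g o p h e r") // "gopher"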
// Repeat returns a new string consisting of count copies of the string s.
//
// It panics if count is negative or if the result of (len(s) * count)
// overflows.
func Repeat(s string, count int) string {
switch count {
case 0:
return ""
case 1:
return s
}
// Since we cannot return an error on overflow,
// we should panic if the repeat will generate an overflow.
// See golang.org/issue/16237.
if count < 0 {
panic("strings: negative Repeat count")
}
if len(s) >= maxInt/count {
panic("strings: Repeat output length overflow")
}
n := len(s) * count
if len(s) == 0 {
return ""
}
// Past a certain chunk size it is counterproductive to use
// larger chunks as the source of the write, as when the source
// is too large we are basically just thrashing the CPU D-cache.
// So if the result length is larger than an empirically-found
// limit (8KB), we stop growing the source string once the limit
// is reached and keep reusing the same source string - that
// should therefore be always resident in the L1 cache - until we
// have completed the construction of the result.
// This yields significant speedups (up to +100%) in cases where
// the result length is large (roughly, over L2 cache size).
const chunkLimit = 8 * 1024
chunkMax := n
if n > chunkLimit {
chunkMax = chunkLimit / len(s) * len(s)
if chunkMax == 0 {
chunkMax = len(s)
}
}
var b Builder
b.Grow(n)
b.WriteString(s)
for b.Len() < n {
chunk := n - b.Len()
if chunk > b.Len() {
chunk = b.Len()
}
if chunk > chunkMax {
chunk = chunkMax
}
b.WriteString(b.String()[:chunk])
}
return b.String()
}
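// For example, assuming a client importing strings:
//
//	strings.Repeat("na", 4) + " Batman!" // "nananana Batman!"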
// ToUpper returns s with all Unicode letters mapped to their upper case.
func ToUpper(s string) string {
isASCII, hasLower := true, false
for i := 0; i < len(s); i++ {
c := s[i]
if c >= utf8.RuneSelf {
isASCII = false
break
}
hasLower = hasLower || ('a' <= c && c <= 'z')
}
if isASCII { // optimize for ASCII-only strings.
if !hasLower {
return s
}
var (
b Builder
pos int
)
b.Grow(len(s))
for i := 0; i < len(s); i++ {
c := s[i]
if 'a' <= c && c <= 'z' {
c -= 'a' - 'A'
if pos < i {
b.WriteString(s[pos:i])
}
b.WriteByte(c)
pos = i + 1
}
}
if pos < len(s) {
b.WriteString(s[pos:])
}
return b.String()
}
return Map(unicode.ToUpper, s)
}
// ToLower returns s with all Unicode letters mapped to their lower case.
func ToLower(s string) string {
isASCII, hasUpper := true, false
for i := 0; i < len(s); i++ {
c := s[i]
if c >= utf8.RuneSelf {
isASCII = false
break
}
hasUpper = hasUpper || ('A' <= c && c <= 'Z')
}
if isASCII { // optimize for ASCII-only strings.
if !hasUpper {
return s
}
var (
b Builder
pos int
)
b.Grow(len(s))
for i := 0; i < len(s); i++ {
c := s[i]
if 'A' <= c && c <= 'Z' {
c += 'a' - 'A'
if pos < i {
b.WriteString(s[pos:i])
}
b.WriteByte(c)
pos = i + 1
}
}
if pos < len(s) {
b.WriteString(s[pos:])
}
return b.String()
}
return Map(unicode.ToLower, s)
}
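// Illustrative usage sketch for both case conversions:
//
//	fmt.Println(strings.ToUpper("Gopher"))     // "GOPHER" (ASCII fast path)
//	fmt.Println(strings.ToLower("ÖSTERREICH")) // "österreich" (falls back to Map)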
// ToTitle returns a copy of the string s with all Unicode letters mapped to
// their Unicode title case.
func ToTitle(s string) string { return Map(unicode.ToTitle, s) }
// ToUpperSpecial returns a copy of the string s with all Unicode letters mapped to their
// upper case using the case mapping specified by c.
func ToUpperSpecial(c unicode.SpecialCase, s string) string {
return Map(c.ToUpper, s)
}
// ToLowerSpecial returns a copy of the string s with all Unicode letters mapped to their
// lower case using the case mapping specified by c.
func ToLowerSpecial(c unicode.SpecialCase, s string) string {
return Map(c.ToLower, s)
}
// ToTitleSpecial returns a copy of the string s with all Unicode letters mapped to their
// Unicode title case, giving priority to the special casing rules.
func ToTitleSpecial(c unicode.SpecialCase, s string) string {
return Map(c.ToTitle, s)
}
// ToValidUTF8 returns a copy of the string s with each run of invalid UTF-8 byte sequences
// replaced by the replacement string, which may be empty.
func ToValidUTF8(s, replacement string) string {
var b Builder
for i, c := range s {
if c != utf8.RuneError {
continue
}
_, wid := utf8.DecodeRuneInString(s[i:])
if wid == 1 {
b.Grow(len(s) + len(replacement))
b.WriteString(s[:i])
s = s[i:]
break
}
}
// Fast path for unchanged input
if b.Cap() == 0 { // didn't call b.Grow above
return s
}
invalid := false // previous byte was from an invalid UTF-8 sequence
for i := 0; i < len(s); {
c := s[i]
if c < utf8.RuneSelf {
i++
invalid = false
b.WriteByte(c)
continue
}
_, wid := utf8.DecodeRuneInString(s[i:])
if wid == 1 {
i++
if !invalid {
invalid = true
b.WriteString(replacement)
}
continue
}
invalid = false
b.WriteString(s[i : i+wid])
i += wid
}
return b.String()
}
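// Illustrative usage sketch: each maximal run of invalid bytes collapses to
// a single copy of the replacement string.
//
//	fmt.Println(strings.ToValidUTF8("a\xffb\xff\xffc", "?")) // "a?b?c"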
// isSeparator reports whether the rune could mark a word boundary.
// TODO: update when package unicode captures more of the properties.
func isSeparator(r rune) bool {
// ASCII alphanumerics and underscore are not separators
if r <= 0x7F {
switch {
case '0' <= r && r <= '9':
return false
case 'a' <= r && r <= 'z':
return false
case 'A' <= r && r <= 'Z':
return false
case r == '_':
return false
}
return true
}
// Letters and digits are not separators
if unicode.IsLetter(r) || unicode.IsDigit(r) {
return false
}
// Otherwise, all we can do for now is treat spaces as separators.
return unicode.IsSpace(r)
}
// Title returns a copy of the string s with all Unicode letters that begin words
// mapped to their Unicode title case.
//
// Deprecated: The rule Title uses for word boundaries does not handle Unicode
// punctuation properly. Use golang.org/x/text/cases instead.
func Title(s string) string {
// Use a closure here to remember state.
// Hackish but effective. Depends on Map scanning in order and calling
// the closure once per rune.
prev := ' '
return Map(
func(r rune) rune {
if isSeparator(prev) {
prev = r
return unicode.ToTitle(r)
}
prev = r
return r
},
s)
}
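// Illustrative usage sketch (simple ASCII input, where the deprecated
// word-boundary rule behaves as expected):
//
//	fmt.Println(strings.Title("her royal highness")) // "Her Royal Highness"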
// TrimLeftFunc returns a slice of the string s with all leading
// Unicode code points c satisfying f(c) removed.
func TrimLeftFunc(s string, f func(rune) bool) string {
i := indexFunc(s, f, false)
if i == -1 {
return ""
}
return s[i:]
}
// TrimRightFunc returns a slice of the string s with all trailing
// Unicode code points c satisfying f(c) removed.
func TrimRightFunc(s string, f func(rune) bool) string {
i := lastIndexFunc(s, f, false)
if i >= 0 && s[i] >= utf8.RuneSelf {
_, wid := utf8.DecodeRuneInString(s[i:])
i += wid
} else {
i++
}
return s[0:i]
}
// TrimFunc returns a slice of the string s with all leading
// and trailing Unicode code points c satisfying f(c) removed.
func TrimFunc(s string, f func(rune) bool) string {
return TrimRightFunc(TrimLeftFunc(s, f), f)
}
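// Illustrative usage sketch:
//
//	fmt.Println(strings.TrimFunc("123gopher456", unicode.IsDigit)) // "gopher"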
// IndexFunc returns the index into s of the first Unicode
// code point satisfying f(c), or -1 if none do.
func IndexFunc(s string, f func(rune) bool) int {
return indexFunc(s, f, true)
}
// LastIndexFunc returns the index into s of the last
// Unicode code point satisfying f(c), or -1 if none do.
func LastIndexFunc(s string, f func(rune) bool) int {
return lastIndexFunc(s, f, true)
}
// indexFunc is the same as IndexFunc except that if
// truth==false, the sense of the predicate function is
// inverted.
func indexFunc(s string, f func(rune) bool, truth bool) int {
for i, r := range s {
if f(r) == truth {
return i
}
}
return -1
}
// lastIndexFunc is the same as LastIndexFunc except that if
// truth==false, the sense of the predicate function is
// inverted.
func lastIndexFunc(s string, f func(rune) bool, truth bool) int {
for i := len(s); i > 0; {
r, size := utf8.DecodeLastRuneInString(s[0:i])
i -= size
if f(r) == truth {
return i
}
}
return -1
}
// asciiSet is a 32-byte value, where each bit represents the presence of a
// given ASCII character in the set. The 128 bits of the lower 16 bytes,
// starting with the least-significant bit of the lowest word to the
// most-significant bit of the highest word, map to the full range of all
// 128 ASCII characters. The 128 bits of the upper 16 bytes will be zeroed,
// ensuring that any non-ASCII character will be reported as not in the set.
// This allocates a total of 32 bytes even though the upper half
// is unused to avoid bounds checks in asciiSet.contains.
type asciiSet [8]uint32
// makeASCIISet creates a set of ASCII characters and reports whether all
// characters in chars are ASCII.
func makeASCIISet(chars string) (as asciiSet, ok bool) {
for i := 0; i < len(chars); i++ {
c := chars[i]
if c >= utf8.RuneSelf {
return as, false
}
as[c/32] |= 1 << (c % 32)
}
return as, true
}
// contains reports whether c is inside the set.
func (as *asciiSet) contains(c byte) bool {
return (as[c/32] & (1 << (c % 32))) != 0
}
// Trim returns a slice of the string s with all leading and
// trailing Unicode code points contained in cutset removed.
func Trim(s, cutset string) string {
if s == "" || cutset == "" {
return s
}
if len(cutset) == 1 && cutset[0] < utf8.RuneSelf {
return trimLeftByte(trimRightByte(s, cutset[0]), cutset[0])
}
if as, ok := makeASCIISet(cutset); ok {
return trimLeftASCII(trimRightASCII(s, &as), &as)
}
return trimLeftUnicode(trimRightUnicode(s, cutset), cutset)
}
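// Illustrative usage sketch: cutset is a set of runes, not a prefix or suffix.
//
//	fmt.Println(strings.Trim("xxhixyx", "xy")) // "hi"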
// TrimLeft returns a slice of the string s with all leading
// Unicode code points contained in cutset removed.
//
// To remove a prefix, use TrimPrefix instead.
func TrimLeft(s, cutset string) string {
if s == "" || cutset == "" {
return s
}
if len(cutset) == 1 && cutset[0] < utf8.RuneSelf {
return trimLeftByte(s, cutset[0])
}
if as, ok := makeASCIISet(cutset); ok {
return trimLeftASCII(s, &as)
}
return trimLeftUnicode(s, cutset)
}
func trimLeftByte(s string, c byte) string {
for len(s) > 0 && s[0] == c {
s = s[1:]
}
return s
}
func trimLeftASCII(s string, as *asciiSet) string {
for len(s) > 0 {
if !as.contains(s[0]) {
break
}
s = s[1:]
}
return s
}
func trimLeftUnicode(s, cutset string) string {
for len(s) > 0 {
r, n := rune(s[0]), 1
if r >= utf8.RuneSelf {
r, n = utf8.DecodeRuneInString(s)
}
if !ContainsRune(cutset, r) {
break
}
s = s[n:]
}
return s
}
// TrimRight returns a slice of the string s, with all trailing
// Unicode code points contained in cutset removed.
//
// To remove a suffix, use TrimSuffix instead.
func TrimRight(s, cutset string) string {
if s == "" || cutset == "" {
return s
}
if len(cutset) == 1 && cutset[0] < utf8.RuneSelf {
return trimRightByte(s, cutset[0])
}
if as, ok := makeASCIISet(cutset); ok {
return trimRightASCII(s, &as)
}
return trimRightUnicode(s, cutset)
}
func trimRightByte(s string, c byte) string {
for len(s) > 0 && s[len(s)-1] == c {
s = s[:len(s)-1]
}
return s
}
func trimRightASCII(s string, as *asciiSet) string {
for len(s) > 0 {
if !as.contains(s[len(s)-1]) {
break
}
s = s[:len(s)-1]
}
return s
}
func trimRightUnicode(s, cutset string) string {
for len(s) > 0 {
r, n := rune(s[len(s)-1]), 1
if r >= utf8.RuneSelf {
r, n = utf8.DecodeLastRuneInString(s)
}
if !ContainsRune(cutset, r) {
break
}
s = s[:len(s)-n]
}
return s
}
// TrimSpace returns a slice of the string s, with all leading
// and trailing white space removed, as defined by Unicode.
func TrimSpace(s string) string {
// Fast path for ASCII: look for the first ASCII non-space byte
start := 0
for ; start < len(s); start++ {
c := s[start]
if c >= utf8.RuneSelf {
// If we run into a non-ASCII byte, fall back to the
// slower unicode-aware method on the remaining bytes
return TrimFunc(s[start:], unicode.IsSpace)
}
if asciiSpace[c] == 0 {
break
}
}
// Now look for the first ASCII non-space byte from the end
stop := len(s)
for ; stop > start; stop-- {
c := s[stop-1]
if c >= utf8.RuneSelf {
// start has already been trimmed above; trim the end only
return TrimRightFunc(s[start:stop], unicode.IsSpace)
}
if asciiSpace[c] == 0 {
break
}
}
// At this point s[start:stop] starts and ends with an ASCII
// non-space bytes, so we're done. Non-ASCII cases have already
// been handled above.
return s[start:stop]
}
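// Illustrative usage sketch:
//
//	fmt.Println(strings.TrimSpace(" \t\n Hello, Gophers \n\t\r\n")) // "Hello, Gophers"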
// TrimPrefix returns s without the provided leading prefix string.
// If s doesn't start with prefix, s is returned unchanged.
func TrimPrefix(s, prefix string) string {
if HasPrefix(s, prefix) {
return s[len(prefix):]
}
return s
}
// TrimSuffix returns s without the provided trailing suffix string.
// If s doesn't end with suffix, s is returned unchanged.
func TrimSuffix(s, suffix string) string {
if HasSuffix(s, suffix) {
return s[:len(s)-len(suffix)]
}
return s
}
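// Illustrative sketch contrasting the cutset-based and literal-suffix forms:
//
//	fmt.Println(strings.TrimRight("file.go.go", ".go"))  // "file" (cutset of runes)
//	fmt.Println(strings.TrimSuffix("file.go.go", ".go")) // "file.go" (one literal suffix)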
// Replace returns a copy of the string s with the first n
// non-overlapping instances of old replaced by new.
// If old is empty, it matches at the beginning of the string
// and after each UTF-8 sequence, yielding up to k+1 replacements
// for a k-rune string.
// If n < 0, there is no limit on the number of replacements.
func Replace(s, old, new string, n int) string {
if old == new || n == 0 {
return s // avoid allocation
}
// Compute number of replacements.
if m := Count(s, old); m == 0 {
return s // avoid allocation
} else if n < 0 || m < n {
n = m
}
// Apply replacements to buffer.
var b Builder
b.Grow(len(s) + n*(len(new)-len(old)))
start := 0
for i := 0; i < n; i++ {
j := start
if len(old) == 0 {
if i > 0 {
_, wid := utf8.DecodeRuneInString(s[start:])
j += wid
}
} else {
j += Index(s[start:], old)
}
b.WriteString(s[start:j])
b.WriteString(new)
start = j + len(old)
}
b.WriteString(s[start:])
return b.String()
}
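// Illustrative usage sketch, including the empty-old case described above:
//
//	fmt.Println(strings.Replace("oink oink oink", "k", "ky", 2)) // "oinky oinky oink"
//	fmt.Println(strings.Replace("abc", "", "-", -1))             // "-a-b-c-" (k+1 = 4 matches)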
// ReplaceAll returns a copy of the string s with all
// non-overlapping instances of old replaced by new.
// If old is empty, it matches at the beginning of the string
// and after each UTF-8 sequence, yielding up to k+1 replacements
// for a k-rune string.
func ReplaceAll(s, old, new string) string {
return Replace(s, old, new, -1)
}
// EqualFold reports whether s and t, interpreted as UTF-8 strings,
// are equal under simple Unicode case-folding, which is a more general
// form of case-insensitivity.
func EqualFold(s, t string) bool {
// ASCII fast path
i := 0
for ; i < len(s) && i < len(t); i++ {
sr := s[i]
tr := t[i]
if sr|tr >= utf8.RuneSelf {
goto hasUnicode
}
// Easy case.
if tr == sr {
continue
}
// Make sr < tr to simplify what follows.
if tr < sr {
tr, sr = sr, tr
}
// ASCII only, sr/tr must be upper/lower case
if 'A' <= sr && sr <= 'Z' && tr == sr+'a'-'A' {
continue
}
return false
}
// Check if we've exhausted both strings.
return len(s) == len(t)
hasUnicode:
s = s[i:]
t = t[i:]
for _, sr := range s {
// If t is exhausted the strings are not equal.
if len(t) == 0 {
return false
}
// Extract first rune from second string.
var tr rune
if t[0] < utf8.RuneSelf {
tr, t = rune(t[0]), t[1:]
} else {
r, size := utf8.DecodeRuneInString(t)
tr, t = r, t[size:]
}
// If they match, keep going; if not, return false.
// Easy case.
if tr == sr {
continue
}
// Make sr < tr to simplify what follows.
if tr < sr {
tr, sr = sr, tr
}
// Fast check for ASCII.
if tr < utf8.RuneSelf {
// ASCII only, sr/tr must be upper/lower case
if 'A' <= sr && sr <= 'Z' && tr == sr+'a'-'A' {
continue
}
return false
}
// General case. SimpleFold(x) returns the next equivalent rune > x
// or wraps around to smaller values.
r := unicode.SimpleFold(sr)
for r != sr && r < tr {
r = unicode.SimpleFold(r)
}
if r == tr {
continue
}
return false
}
// First string is empty, so check if the second one is also empty.
return len(t) == 0
}
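// Illustrative usage sketch:
//
//	fmt.Println(strings.EqualFold("Go", "GO"))    // true (ASCII fast path)
//	fmt.Println(strings.EqualFold("K", "\u212a")) // true ('K' folds with the Kelvin sign)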
// Index returns the index of the first instance of substr in s, or -1 if substr is not present in s.
func Index(s, substr string) int {
n := len(substr)
switch {
case n == 0:
return 0
case n == 1:
return IndexByte(s, substr[0])
case n == len(s):
if substr == s {
return 0
}
return -1
case n > len(s):
return -1
case n <= bytealg.MaxLen:
// Use brute force when s and substr are both small
if len(s) <= bytealg.MaxBruteForce {
return bytealg.IndexString(s, substr)
}
c0 := substr[0]
c1 := substr[1]
i := 0
t := len(s) - n + 1
fails := 0
for i < t {
if s[i] != c0 {
// IndexByte is faster than bytealg.IndexString, so use it as long as
// we're not getting lots of false positives.
o := IndexByte(s[i+1:t], c0)
if o < 0 {
return -1
}
i += o + 1
}
if s[i+1] == c1 && s[i:i+n] == substr {
return i
}
fails++
i++
// Switch to bytealg.IndexString when IndexByte produces too many false positives.
if fails > bytealg.Cutover(i) {
r := bytealg.IndexString(s[i:], substr)
if r >= 0 {
return r + i
}
return -1
}
}
return -1
}
c0 := substr[0]
c1 := substr[1]
i := 0
t := len(s) - n + 1
fails := 0
for i < t {
if s[i] != c0 {
o := IndexByte(s[i+1:t], c0)
if o < 0 {
return -1
}
i += o + 1
}
if s[i+1] == c1 && s[i:i+n] == substr {
return i
}
i++
fails++
if fails >= 4+i>>4 && i < t {
// See comment in ../bytes/bytes.go.
j := bytealg.IndexRabinKarp(s[i:], substr)
if j < 0 {
return -1
}
return i + j
}
}
return -1
}
// Cut slices s around the first instance of sep,
// returning the text before and after sep.
// The found result reports whether sep appears in s.
// If sep does not appear in s, Cut returns s, "", false.
func Cut(s, sep string) (before, after string, found bool) {
if i := Index(s, sep); i >= 0 {
return s[:i], s[i+len(sep):], true
}
return s, "", false
}
// CutPrefix returns s without the provided leading prefix string
// and reports whether it found the prefix.
// If s doesn't start with prefix, CutPrefix returns s, false.
// If prefix is the empty string, CutPrefix returns s, true.
func CutPrefix(s, prefix string) (after string, found bool) {
if !HasPrefix(s, prefix) {
return s, false
}
return s[len(prefix):], true
}
// CutSuffix returns s without the provided ending suffix string
// and reports whether it found the suffix.
// If s doesn't end with suffix, CutSuffix returns s, false.
// If suffix is the empty string, CutSuffix returns s, true.
func CutSuffix(s, suffix string) (before string, found bool) {
if !HasSuffix(s, suffix) {
return s, false
}
return s[:len(s)-len(suffix)], true
}
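// Illustrative usage sketch for the Cut family:
//
//	before, after, found := strings.Cut("user=gopher", "=")
//	fmt.Println(before, after, found) // user gopher true
//	rest, ok := strings.CutPrefix("v1.21.0", "v")
//	fmt.Println(rest, ok) // 1.21.0 true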
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package atomic
import "unsafe"
// A Bool is an atomic boolean value.
// The zero value is false.
type Bool struct {
_ noCopy
v uint32
}
// Load atomically loads and returns the value stored in x.
func (x *Bool) Load() bool { return LoadUint32(&x.v) != 0 }
// Store atomically stores val into x.
func (x *Bool) Store(val bool) { StoreUint32(&x.v, b32(val)) }
// Swap atomically stores new into x and returns the previous value.
func (x *Bool) Swap(new bool) (old bool) { return SwapUint32(&x.v, b32(new)) != 0 }
// CompareAndSwap executes the compare-and-swap operation for the boolean value x.
func (x *Bool) CompareAndSwap(old, new bool) (swapped bool) {
return CompareAndSwapUint32(&x.v, b32(old), b32(new))
}
// b32 returns a uint32 0 or 1 representing b.
func b32(b bool) uint32 {
if b {
return 1
}
return 0
}
// For testing that *Pointer[T]'s methods can be inlined.
// Keep in sync with cmd/compile/internal/test/inl_test.go:TestIntendedInlining.
var _ = &Pointer[int]{}
// A Pointer is an atomic pointer of type *T. The zero value is a nil *T.
type Pointer[T any] struct {
// Mention *T in a field to disallow conversion between Pointer types.
// See go.dev/issue/56603 for more details.
// Use *T, not T, to avoid spurious recursive type definition errors.
_ [0]*T
_ noCopy
v unsafe.Pointer
}
// Load atomically loads and returns the value stored in x.
func (x *Pointer[T]) Load() *T { return (*T)(LoadPointer(&x.v)) }
// Store atomically stores val into x.
func (x *Pointer[T]) Store(val *T) { StorePointer(&x.v, unsafe.Pointer(val)) }
// Swap atomically stores new into x and returns the previous value.
func (x *Pointer[T]) Swap(new *T) (old *T) { return (*T)(SwapPointer(&x.v, unsafe.Pointer(new))) }
// CompareAndSwap executes the compare-and-swap operation for x.
func (x *Pointer[T]) CompareAndSwap(old, new *T) (swapped bool) {
return CompareAndSwapPointer(&x.v, unsafe.Pointer(old), unsafe.Pointer(new))
}
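// Illustrative usage sketch (serverConfig is a hypothetical caller-side type):
//
//	type serverConfig struct{ addr string }
//	var cfg atomic.Pointer[serverConfig]
//	cfg.Store(&serverConfig{addr: "localhost:8080"})
//	c := cfg.Load() // c is *serverConfig; no type assertion needed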
// An Int32 is an atomic int32. The zero value is zero.
type Int32 struct {
_ noCopy
v int32
}
// Load atomically loads and returns the value stored in x.
func (x *Int32) Load() int32 { return LoadInt32(&x.v) }
// Store atomically stores val into x.
func (x *Int32) Store(val int32) { StoreInt32(&x.v, val) }
// Swap atomically stores new into x and returns the previous value.
func (x *Int32) Swap(new int32) (old int32) { return SwapInt32(&x.v, new) }
// CompareAndSwap executes the compare-and-swap operation for x.
func (x *Int32) CompareAndSwap(old, new int32) (swapped bool) {
return CompareAndSwapInt32(&x.v, old, new)
}
// Add atomically adds delta to x and returns the new value.
func (x *Int32) Add(delta int32) (new int32) { return AddInt32(&x.v, delta) }
// An Int64 is an atomic int64. The zero value is zero.
type Int64 struct {
_ noCopy
_ align64
v int64
}
// Load atomically loads and returns the value stored in x.
func (x *Int64) Load() int64 { return LoadInt64(&x.v) }
// Store atomically stores val into x.
func (x *Int64) Store(val int64) { StoreInt64(&x.v, val) }
// Swap atomically stores new into x and returns the previous value.
func (x *Int64) Swap(new int64) (old int64) { return SwapInt64(&x.v, new) }
// CompareAndSwap executes the compare-and-swap operation for x.
func (x *Int64) CompareAndSwap(old, new int64) (swapped bool) {
return CompareAndSwapInt64(&x.v, old, new)
}
// Add atomically adds delta to x and returns the new value.
func (x *Int64) Add(delta int64) (new int64) { return AddInt64(&x.v, delta) }
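// Illustrative usage sketch (the embedded align64 makes this safe even on
// 32-bit targets, unlike a bare int64 used with the package-level functions):
//
//	var ops atomic.Int64
//	ops.Add(1)
//	fmt.Println(ops.Load()) // 1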
// An Uint32 is an atomic uint32. The zero value is zero.
type Uint32 struct {
_ noCopy
v uint32
}
// Load atomically loads and returns the value stored in x.
func (x *Uint32) Load() uint32 { return LoadUint32(&x.v) }
// Store atomically stores val into x.
func (x *Uint32) Store(val uint32) { StoreUint32(&x.v, val) }
// Swap atomically stores new into x and returns the previous value.
func (x *Uint32) Swap(new uint32) (old uint32) { return SwapUint32(&x.v, new) }
// CompareAndSwap executes the compare-and-swap operation for x.
func (x *Uint32) CompareAndSwap(old, new uint32) (swapped bool) {
return CompareAndSwapUint32(&x.v, old, new)
}
// Add atomically adds delta to x and returns the new value.
func (x *Uint32) Add(delta uint32) (new uint32) { return AddUint32(&x.v, delta) }
// An Uint64 is an atomic uint64. The zero value is zero.
type Uint64 struct {
_ noCopy
_ align64
v uint64
}
// Load atomically loads and returns the value stored in x.
func (x *Uint64) Load() uint64 { return LoadUint64(&x.v) }
// Store atomically stores val into x.
func (x *Uint64) Store(val uint64) { StoreUint64(&x.v, val) }
// Swap atomically stores new into x and returns the previous value.
func (x *Uint64) Swap(new uint64) (old uint64) { return SwapUint64(&x.v, new) }
// CompareAndSwap executes the compare-and-swap operation for x.
func (x *Uint64) CompareAndSwap(old, new uint64) (swapped bool) {
return CompareAndSwapUint64(&x.v, old, new)
}
// Add atomically adds delta to x and returns the new value.
func (x *Uint64) Add(delta uint64) (new uint64) { return AddUint64(&x.v, delta) }
// An Uintptr is an atomic uintptr. The zero value is zero.
type Uintptr struct {
_ noCopy
v uintptr
}
// Load atomically loads and returns the value stored in x.
func (x *Uintptr) Load() uintptr { return LoadUintptr(&x.v) }
// Store atomically stores val into x.
func (x *Uintptr) Store(val uintptr) { StoreUintptr(&x.v, val) }
// Swap atomically stores new into x and returns the previous value.
func (x *Uintptr) Swap(new uintptr) (old uintptr) { return SwapUintptr(&x.v, new) }
// CompareAndSwap executes the compare-and-swap operation for x.
func (x *Uintptr) CompareAndSwap(old, new uintptr) (swapped bool) {
return CompareAndSwapUintptr(&x.v, old, new)
}
// Add atomically adds delta to x and returns the new value.
func (x *Uintptr) Add(delta uintptr) (new uintptr) { return AddUintptr(&x.v, delta) }
// noCopy may be added to structs which must not be copied
// after the first use.
//
// See https://golang.org/issues/8005#issuecomment-190753527
// for details.
//
// Note that it must not be embedded, due to the Lock and Unlock methods.
type noCopy struct{}
// Lock is a no-op used by -copylocks checker from `go vet`.
func (*noCopy) Lock() {}
func (*noCopy) Unlock() {}
// align64 may be added to structs that must be 64-bit aligned.
// This struct is recognized by a special case in the compiler
// and will not work if copied to any other package.
type align64 struct{}
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package atomic
import (
"unsafe"
)
// A Value provides an atomic load and store of a consistently typed value.
// The zero value for a Value returns nil from Load.
// Once Store has been called, a Value must not be copied.
//
// A Value must not be copied after first use.
type Value struct {
v any
}
// efaceWords is the internal representation of interface{}.
type efaceWords struct {
typ unsafe.Pointer
data unsafe.Pointer
}
// Load returns the value set by the most recent Store.
// It returns nil if there has been no call to Store for this Value.
func (v *Value) Load() (val any) {
vp := (*efaceWords)(unsafe.Pointer(v))
typ := LoadPointer(&vp.typ)
if typ == nil || typ == unsafe.Pointer(&firstStoreInProgress) {
// First store not yet completed.
return nil
}
data := LoadPointer(&vp.data)
vlp := (*efaceWords)(unsafe.Pointer(&val))
vlp.typ = typ
vlp.data = data
return
}
var firstStoreInProgress byte
// Store sets the value of the Value v to val.
// All calls to Store for a given Value must use values of the same concrete type.
// Store of an inconsistent type panics, as does Store(nil).
func (v *Value) Store(val any) {
if val == nil {
panic("sync/atomic: store of nil value into Value")
}
vp := (*efaceWords)(unsafe.Pointer(v))
vlp := (*efaceWords)(unsafe.Pointer(&val))
for {
typ := LoadPointer(&vp.typ)
if typ == nil {
// Attempt to start first store.
// Disable preemption so that other goroutines can use
// active spin wait to wait for completion.
runtime_procPin()
if !CompareAndSwapPointer(&vp.typ, nil, unsafe.Pointer(&firstStoreInProgress)) {
runtime_procUnpin()
continue
}
// Complete first store.
StorePointer(&vp.data, vlp.data)
StorePointer(&vp.typ, vlp.typ)
runtime_procUnpin()
return
}
if typ == unsafe.Pointer(&firstStoreInProgress) {
// First store in progress. Wait.
// Since we disable preemption around the first store,
// we can wait with active spinning.
continue
}
// First store completed. Check type and overwrite data.
if typ != vlp.typ {
panic("sync/atomic: store of inconsistently typed value into Value")
}
StorePointer(&vp.data, vlp.data)
return
}
}
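// Illustrative usage sketch:
//
//	var v atomic.Value
//	v.Store("hello")       // the first Store fixes the concrete type (string)
//	s := v.Load().(string) // callers assert back to the stored type
//	_ = s
//	// v.Store(42) would panic: store of inconsistently typed value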
// Swap stores new into Value and returns the previous value. It returns nil if
// the Value is empty.
//
// All calls to Swap for a given Value must use values of the same concrete
// type. Swap of an inconsistent type panics, as does Swap(nil).
func (v *Value) Swap(new any) (old any) {
if new == nil {
panic("sync/atomic: swap of nil value into Value")
}
vp := (*efaceWords)(unsafe.Pointer(v))
np := (*efaceWords)(unsafe.Pointer(&new))
for {
typ := LoadPointer(&vp.typ)
if typ == nil {
// Attempt to start first store.
// Disable preemption so that other goroutines can use
// active spin wait to wait for completion; and so that
// GC does not see the fake type accidentally.
runtime_procPin()
if !CompareAndSwapPointer(&vp.typ, nil, unsafe.Pointer(&firstStoreInProgress)) {
runtime_procUnpin()
continue
}
// Complete first store.
StorePointer(&vp.data, np.data)
StorePointer(&vp.typ, np.typ)
runtime_procUnpin()
return nil
}
if typ == unsafe.Pointer(&firstStoreInProgress) {
// First store in progress. Wait.
// Since we disable preemption around the first store,
// we can wait with active spinning.
continue
}
// First store completed. Check type and overwrite data.
if typ != np.typ {
panic("sync/atomic: swap of inconsistently typed value into Value")
}
op := (*efaceWords)(unsafe.Pointer(&old))
op.typ, op.data = np.typ, SwapPointer(&vp.data, np.data)
return old
}
}
// CompareAndSwap executes the compare-and-swap operation for the Value.
//
// All calls to CompareAndSwap for a given Value must use values of the same
// concrete type. CompareAndSwap of an inconsistent type panics, as does
// CompareAndSwap(old, nil).
func (v *Value) CompareAndSwap(old, new any) (swapped bool) {
if new == nil {
panic("sync/atomic: compare and swap of nil value into Value")
}
vp := (*efaceWords)(unsafe.Pointer(v))
np := (*efaceWords)(unsafe.Pointer(&new))
op := (*efaceWords)(unsafe.Pointer(&old))
if op.typ != nil && np.typ != op.typ {
panic("sync/atomic: compare and swap of inconsistently typed values")
}
for {
typ := LoadPointer(&vp.typ)
if typ == nil {
if old != nil {
return false
}
// Attempt to start first store.
// Disable preemption so that other goroutines can use
// active spin wait to wait for completion; and so that
// GC does not see the fake type accidentally.
runtime_procPin()
if !CompareAndSwapPointer(&vp.typ, nil, unsafe.Pointer(&firstStoreInProgress)) {
runtime_procUnpin()
continue
}
// Complete first store.
StorePointer(&vp.data, np.data)
StorePointer(&vp.typ, np.typ)
runtime_procUnpin()
return true
}
if typ == unsafe.Pointer(&firstStoreInProgress) {
// First store in progress. Wait.
// Since we disable preemption around the first store,
// we can wait with active spinning.
continue
}
// First store completed. Check type and overwrite data.
if typ != np.typ {
panic("sync/atomic: compare and swap of inconsistently typed value into Value")
}
// Compare old and current via runtime equality check.
// This allows value types to be compared, something
// not offered by the package functions.
// CompareAndSwapPointer below only ensures vp.data
// has not changed since LoadPointer.
data := LoadPointer(&vp.data)
var i any
(*efaceWords)(unsafe.Pointer(&i)).typ = typ
(*efaceWords)(unsafe.Pointer(&i)).data = data
if i != old {
return false
}
return CompareAndSwapPointer(&vp.data, data, np.data)
}
}
// Disable/enable preemption, implemented in runtime.
func runtime_procPin() int
func runtime_procUnpin()
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sync
import (
"sync/atomic"
"unsafe"
)
// Cond implements a condition variable, a rendezvous point
// for goroutines waiting for or announcing the occurrence
// of an event.
//
// Each Cond has an associated Locker L (often a *Mutex or *RWMutex),
// which must be held when changing the condition and
// when calling the Wait method.
//
// A Cond must not be copied after first use.
//
// In the terminology of the Go memory model, Cond arranges that
// a call to Broadcast or Signal “synchronizes before” any Wait call
// that it unblocks.
//
// For many simple use cases, users will be better off using channels than a
// Cond (Broadcast corresponds to closing a channel, and Signal corresponds to
// sending on a channel).
//
// For more on replacements for sync.Cond, see [Roberto Clapis's series on
// advanced concurrency patterns], as well as [Bryan Mills's talk on concurrency
// patterns].
//
// [Roberto Clapis's series on advanced concurrency patterns]: https://blogtitle.github.io/categories/concurrency/
// [Bryan Mills's talk on concurrency patterns]: https://drive.google.com/file/d/1nPdvhB0PutEJzdCq5ms6UI58dp50fcAN/view
type Cond struct {
noCopy noCopy
// L is held while observing or changing the condition
L Locker
notify notifyList
checker copyChecker
}
// NewCond returns a new Cond with Locker l.
func NewCond(l Locker) *Cond {
return &Cond{L: l}
}
// Wait atomically unlocks c.L and suspends execution
// of the calling goroutine. After later resuming execution,
// Wait locks c.L before returning. Unlike in other systems,
// Wait cannot return unless awoken by Broadcast or Signal.
//
// Because c.L is not locked while Wait is waiting, the caller
// typically cannot assume that the condition is true when
// Wait returns. Instead, the caller should Wait in a loop:
//
// c.L.Lock()
// for !condition() {
// c.Wait()
// }
// ... make use of condition ...
// c.L.Unlock()
func (c *Cond) Wait() {
c.checker.check()
t := runtime_notifyListAdd(&c.notify)
c.L.Unlock()
runtime_notifyListWait(&c.notify, t)
c.L.Lock()
}
// Signal wakes one goroutine waiting on c, if there is any.
//
// It is allowed but not required for the caller to hold c.L
// during the call.
//
// Signal() does not affect goroutine scheduling priority; if other goroutines
// are attempting to lock c.L, they may be awoken before a "waiting" goroutine.
func (c *Cond) Signal() {
c.checker.check()
runtime_notifyListNotifyOne(&c.notify)
}
// Broadcast wakes all goroutines waiting on c.
//
// It is allowed but not required for the caller to hold c.L
// during the call.
func (c *Cond) Broadcast() {
c.checker.check()
runtime_notifyListNotifyAll(&c.notify)
}
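// Illustrative usage sketch of the Wait-in-a-loop pattern described above:
//
//	var mu sync.Mutex
//	cond := sync.NewCond(&mu)
//	ready := false
//
//	go func() {
//		mu.Lock()
//		ready = true
//		mu.Unlock()
//		cond.Broadcast()
//	}()
//
//	mu.Lock()
//	for !ready {
//		cond.Wait()
//	}
//	mu.Unlock()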
// copyChecker holds a back pointer to itself to detect object copying.
type copyChecker uintptr
func (c *copyChecker) check() {
if uintptr(*c) != uintptr(unsafe.Pointer(c)) &&
!atomic.CompareAndSwapUintptr((*uintptr)(c), 0, uintptr(unsafe.Pointer(c))) &&
uintptr(*c) != uintptr(unsafe.Pointer(c)) {
panic("sync.Cond is copied")
}
}
// noCopy may be added to structs which must not be copied
// after the first use.
//
// See https://golang.org/issues/8005#issuecomment-190753527
// for details.
//
// Note that it must not be embedded, due to the Lock and Unlock methods.
type noCopy struct{}
// Lock is a no-op used by -copylocks checker from `go vet`.
func (*noCopy) Lock() {}
func (*noCopy) Unlock() {}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sync
import (
"sync/atomic"
)
// Map is like a Go map[interface{}]interface{} but is safe for concurrent use
// by multiple goroutines without additional locking or coordination.
// Loads, stores, and deletes run in amortized constant time.
//
// The Map type is specialized. Most code should use a plain Go map instead,
// with separate locking or coordination, for better type safety and to make it
// easier to maintain other invariants along with the map content.
//
// The Map type is optimized for two common use cases: (1) when the entry for a given
// key is only ever written once but read many times, as in caches that only grow,
// or (2) when multiple goroutines read, write, and overwrite entries for disjoint
// sets of keys. In these two cases, use of a Map may significantly reduce lock
// contention compared to a Go map paired with a separate Mutex or RWMutex.
//
// The zero Map is empty and ready for use. A Map must not be copied after first use.
//
// In the terminology of the Go memory model, Map arranges that a write operation
// “synchronizes before” any read operation that observes the effect of the write, where
// read and write operations are defined as follows.
// Load, LoadAndDelete, LoadOrStore, Swap, CompareAndSwap, and CompareAndDelete
// are read operations; Delete, LoadAndDelete, Store, and Swap are write operations;
// LoadOrStore is a write operation when it returns loaded set to false;
// CompareAndSwap is a write operation when it returns swapped set to true;
// and CompareAndDelete is a write operation when it returns deleted set to true.
type Map struct {
mu Mutex
// read contains the portion of the map's contents that are safe for
// concurrent access (with or without mu held).
//
// The read field itself is always safe to load, but must only be stored with
// mu held.
//
// Entries stored in read may be updated concurrently without mu, but updating
// a previously-expunged entry requires that the entry be copied to the dirty
// map and unexpunged with mu held.
read atomic.Pointer[readOnly]
// dirty contains the portion of the map's contents that require mu to be
// held. To ensure that the dirty map can be promoted to the read map quickly,
// it also includes all of the non-expunged entries in the read map.
//
// Expunged entries are not stored in the dirty map. An expunged entry in the
// clean map must be unexpunged and added to the dirty map before a new value
// can be stored to it.
//
// If the dirty map is nil, the next write to the map will initialize it by
// making a shallow copy of the clean map, omitting stale entries.
dirty map[any]*entry
// misses counts the number of loads since the read map was last updated that
// needed to lock mu to determine whether the key was present.
//
// Once enough misses have occurred to cover the cost of copying the dirty
// map, the dirty map will be promoted to the read map (in the unamended
// state) and the next store to the map will make a new dirty copy.
misses int
}
// readOnly is an immutable struct stored atomically in the Map.read field.
type readOnly struct {
m map[any]*entry
amended bool // true if the dirty map contains some key not in m.
}
// expunged is an arbitrary pointer that marks entries which have been deleted
// from the dirty map.
var expunged = new(any)
// An entry is a slot in the map corresponding to a particular key.
type entry struct {
// p points to the interface{} value stored for the entry.
//
// If p == nil, the entry has been deleted, and either m.dirty == nil or
// m.dirty[key] is e.
//
// If p == expunged, the entry has been deleted, m.dirty != nil, and the entry
// is missing from m.dirty.
//
// Otherwise, the entry is valid and recorded in m.read.m[key] and, if m.dirty
// != nil, in m.dirty[key].
//
// An entry can be deleted by atomic replacement with nil: when m.dirty is
// next created, it will atomically replace nil with expunged and leave
// m.dirty[key] unset.
//
// An entry's associated value can be updated by atomic replacement, provided
// p != expunged. If p == expunged, an entry's associated value can be updated
// only after first setting m.dirty[key] = e so that lookups using the dirty
// map find the entry.
p atomic.Pointer[any]
}
func newEntry(i any) *entry {
e := &entry{}
e.p.Store(&i)
return e
}
func (m *Map) loadReadOnly() readOnly {
if p := m.read.Load(); p != nil {
return *p
}
return readOnly{}
}
// Load returns the value stored in the map for a key, or nil if no
// value is present.
// The ok result indicates whether value was found in the map.
func (m *Map) Load(key any) (value any, ok bool) {
read := m.loadReadOnly()
e, ok := read.m[key]
if !ok && read.amended {
m.mu.Lock()
// Avoid reporting a spurious miss if m.dirty got promoted while we were
// blocked on m.mu. (If further loads of the same key will not miss, it's
// not worth copying the dirty map for this key.)
read = m.loadReadOnly()
e, ok = read.m[key]
if !ok && read.amended {
e, ok = m.dirty[key]
// Regardless of whether the entry was present, record a miss: this key
// will take the slow path until the dirty map is promoted to the read
// map.
m.missLocked()
}
m.mu.Unlock()
}
if !ok {
return nil, false
}
return e.load()
}
func (e *entry) load() (value any, ok bool) {
p := e.p.Load()
if p == nil || p == expunged {
return nil, false
}
return *p, true
}
// Store sets the value for a key.
func (m *Map) Store(key, value any) {
_, _ = m.Swap(key, value)
}
// tryCompareAndSwap compares the entry with the given old value and swaps
// it with the new value if the entry equals the old value and
// has not been expunged.
//
// If the entry is expunged, tryCompareAndSwap returns false and leaves
// the entry unchanged.
func (e *entry) tryCompareAndSwap(old, new any) bool {
p := e.p.Load()
if p == nil || p == expunged || *p != old {
return false
}
// Copy the interface after the first load to make this method more amenable
// to escape analysis: if the comparison fails from the start, we shouldn't
// bother heap-allocating an interface value to store.
nc := new
for {
if e.p.CompareAndSwap(p, &nc) {
return true
}
p = e.p.Load()
if p == nil || p == expunged || *p != old {
return false
}
}
}
// unexpungeLocked ensures that the entry is not marked as expunged.
//
// If the entry was previously expunged, it must be added to the dirty map
// before m.mu is unlocked.
func (e *entry) unexpungeLocked() (wasExpunged bool) {
return e.p.CompareAndSwap(expunged, nil)
}
// swapLocked unconditionally swaps a value into the entry.
//
// The entry must be known not to be expunged.
func (e *entry) swapLocked(i *any) *any {
return e.p.Swap(i)
}
// LoadOrStore returns the existing value for the key if present.
// Otherwise, it stores and returns the given value.
// The loaded result is true if the value was loaded, false if stored.
func (m *Map) LoadOrStore(key, value any) (actual any, loaded bool) {
// Avoid locking if it's a clean hit.
read := m.loadReadOnly()
if e, ok := read.m[key]; ok {
actual, loaded, ok := e.tryLoadOrStore(value)
if ok {
return actual, loaded
}
}
m.mu.Lock()
read = m.loadReadOnly()
if e, ok := read.m[key]; ok {
if e.unexpungeLocked() {
m.dirty[key] = e
}
actual, loaded, _ = e.tryLoadOrStore(value)
} else if e, ok := m.dirty[key]; ok {
actual, loaded, _ = e.tryLoadOrStore(value)
m.missLocked()
} else {
if !read.amended {
// We're adding the first new key to the dirty map.
// Make sure it is allocated and mark the read-only map as incomplete.
m.dirtyLocked()
m.read.Store(&readOnly{m: read.m, amended: true})
}
m.dirty[key] = newEntry(value)
actual, loaded = value, false
}
m.mu.Unlock()
return actual, loaded
}
// tryLoadOrStore atomically loads or stores a value if the entry is not
// expunged.
//
// If the entry is expunged, tryLoadOrStore leaves the entry unchanged and
// returns with ok==false.
func (e *entry) tryLoadOrStore(i any) (actual any, loaded, ok bool) {
p := e.p.Load()
if p == expunged {
return nil, false, false
}
if p != nil {
return *p, true, true
}
// Copy the interface after the first load to make this method more amenable
// to escape analysis: if we hit the "load" path or the entry is expunged, we
// shouldn't bother heap-allocating.
ic := i
for {
if e.p.CompareAndSwap(nil, &ic) {
return i, false, true
}
p = e.p.Load()
if p == expunged {
return nil, false, false
}
if p != nil {
return *p, true, true
}
}
}
// LoadAndDelete deletes the value for a key, returning the previous value if any.
// The loaded result reports whether the key was present.
func (m *Map) LoadAndDelete(key any) (value any, loaded bool) {
read := m.loadReadOnly()
e, ok := read.m[key]
if !ok && read.amended {
m.mu.Lock()
read = m.loadReadOnly()
e, ok = read.m[key]
if !ok && read.amended {
e, ok = m.dirty[key]
delete(m.dirty, key)
// Regardless of whether the entry was present, record a miss: this key
// will take the slow path until the dirty map is promoted to the read
// map.
m.missLocked()
}
m.mu.Unlock()
}
if ok {
return e.delete()
}
return nil, false
}
// Delete deletes the value for a key.
func (m *Map) Delete(key any) {
m.LoadAndDelete(key)
}
func (e *entry) delete() (value any, ok bool) {
for {
p := e.p.Load()
if p == nil || p == expunged {
return nil, false
}
if e.p.CompareAndSwap(p, nil) {
return *p, true
}
}
}
// trySwap swaps a value if the entry has not been expunged.
//
// If the entry is expunged, trySwap returns false and leaves the entry
// unchanged.
func (e *entry) trySwap(i *any) (*any, bool) {
for {
p := e.p.Load()
if p == expunged {
return nil, false
}
if e.p.CompareAndSwap(p, i) {
return p, true
}
}
}
// Swap swaps the value for a key and returns the previous value if any.
// The loaded result reports whether the key was present.
func (m *Map) Swap(key, value any) (previous any, loaded bool) {
read := m.loadReadOnly()
if e, ok := read.m[key]; ok {
if v, ok := e.trySwap(&value); ok {
if v == nil {
return nil, false
}
return *v, true
}
}
m.mu.Lock()
read = m.loadReadOnly()
if e, ok := read.m[key]; ok {
if e.unexpungeLocked() {
// The entry was previously expunged, which implies that there is a
// non-nil dirty map and this entry is not in it.
m.dirty[key] = e
}
if v := e.swapLocked(&value); v != nil {
loaded = true
previous = *v
}
} else if e, ok := m.dirty[key]; ok {
if v := e.swapLocked(&value); v != nil {
loaded = true
previous = *v
}
} else {
if !read.amended {
// We're adding the first new key to the dirty map.
// Make sure it is allocated and mark the read-only map as incomplete.
m.dirtyLocked()
m.read.Store(&readOnly{m: read.m, amended: true})
}
m.dirty[key] = newEntry(value)
}
m.mu.Unlock()
return previous, loaded
}
// CompareAndSwap swaps the old and new values for key
// if the value stored in the map is equal to old.
// The old value must be of a comparable type.
func (m *Map) CompareAndSwap(key, old, new any) bool {
read := m.loadReadOnly()
if e, ok := read.m[key]; ok {
return e.tryCompareAndSwap(old, new)
} else if !read.amended {
return false // No existing value for key.
}
m.mu.Lock()
defer m.mu.Unlock()
read = m.loadReadOnly()
swapped := false
if e, ok := read.m[key]; ok {
swapped = e.tryCompareAndSwap(old, new)
} else if e, ok := m.dirty[key]; ok {
swapped = e.tryCompareAndSwap(old, new)
// We needed to lock mu in order to load the entry for key,
// and the operation didn't change the set of keys in the map
// (so it would be made more efficient by promoting the dirty
// map to read-only).
// Count it as a miss so that we will eventually switch to the
// more efficient steady state.
m.missLocked()
}
return swapped
}
// CompareAndDelete deletes the entry for key if its value is equal to old.
// The old value must be of a comparable type.
//
// If there is no current value for key in the map, CompareAndDelete
// returns false (even if the old value is the nil interface value).
func (m *Map) CompareAndDelete(key, old any) (deleted bool) {
read := m.loadReadOnly()
e, ok := read.m[key]
if !ok && read.amended {
m.mu.Lock()
read = m.loadReadOnly()
e, ok = read.m[key]
if !ok && read.amended {
e, ok = m.dirty[key]
// Don't delete key from m.dirty: we still need to do the “compare” part
// of the operation. The entry will eventually be expunged when the
// dirty map is promoted to the read map.
//
// Regardless of whether the entry was present, record a miss: this key
// will take the slow path until the dirty map is promoted to the read
// map.
m.missLocked()
}
m.mu.Unlock()
}
for ok {
p := e.p.Load()
if p == nil || p == expunged || *p != old {
return false
}
if e.p.CompareAndSwap(p, nil) {
return true
}
}
return false
}
// Range calls f sequentially for each key and value present in the map.
// If f returns false, range stops the iteration.
//
// Range does not necessarily correspond to any consistent snapshot of the Map's
// contents: no key will be visited more than once, but if the value for any key
// is stored or deleted concurrently (including by f), Range may reflect any
// mapping for that key from any point during the Range call. Range does not
// block other methods on the receiver; even f itself may call any method on m.
//
// Range may be O(N) with the number of elements in the map even if f returns
// false after a constant number of calls.
func (m *Map) Range(f func(key, value any) bool) {
// We need to be able to iterate over all of the keys that were already
// present at the start of the call to Range.
// If read.amended is false, then read.m satisfies that property without
// requiring us to hold m.mu for a long time.
read := m.loadReadOnly()
if read.amended {
// m.dirty contains keys not in read.m. Fortunately, Range is already O(N)
// (assuming the caller does not break out early), so a call to Range
// amortizes an entire copy of the map: we can promote the dirty copy
// immediately!
m.mu.Lock()
read = m.loadReadOnly()
if read.amended {
read = readOnly{m: m.dirty}
m.read.Store(&read)
m.dirty = nil
m.misses = 0
}
m.mu.Unlock()
}
for k, e := range read.m {
v, ok := e.load()
if !ok {
continue
}
if !f(k, v) {
break
}
}
}
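// Illustrative usage sketch:
//
//	var m sync.Map
//	m.Store("alpha", 1)
//	m.Store("beta", 2)
//	m.Range(func(k, v any) bool {
//		fmt.Println(k, v)
//		return true // return false to stop early
//	})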
func (m *Map) missLocked() {
m.misses++
if m.misses < len(m.dirty) {
return
}
m.read.Store(&readOnly{m: m.dirty})
m.dirty = nil
m.misses = 0
}
func (m *Map) dirtyLocked() {
if m.dirty != nil {
return
}
read := m.loadReadOnly()
m.dirty = make(map[any]*entry, len(read.m))
for k, e := range read.m {
if !e.tryExpungeLocked() {
m.dirty[k] = e
}
}
}
func (e *entry) tryExpungeLocked() (isExpunged bool) {
p := e.p.Load()
for p == nil {
if e.p.CompareAndSwap(nil, expunged) {
return true
}
p = e.p.Load()
}
return p == expunged
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package sync provides basic synchronization primitives such as mutual
// exclusion locks. Other than the Once and WaitGroup types, most are intended
// for use by low-level library routines. Higher-level synchronization is
// better done via channels and communication.
//
// Values containing the types defined in this package should not be copied.
package sync
import (
"internal/race"
"sync/atomic"
"unsafe"
)
// Provided by runtime via linkname.
func throw(string)
func fatal(string)
// A Mutex is a mutual exclusion lock.
// The zero value for a Mutex is an unlocked mutex.
//
// A Mutex must not be copied after first use.
//
// In the terminology of the Go memory model,
// the n'th call to Unlock “synchronizes before” the m'th call to Lock
// for any n < m.
// A successful call to TryLock is equivalent to a call to Lock.
// A failed call to TryLock does not establish any “synchronizes before”
// relation at all.
type Mutex struct {
state int32
sema uint32
}
// A Locker represents an object that can be locked and unlocked.
type Locker interface {
Lock()
Unlock()
}
const (
mutexLocked = 1 << iota // mutex is locked
mutexWoken
mutexStarving
mutexWaiterShift = iota
// Mutex fairness.
//
// Mutex can be in 2 modes of operation: normal and starvation.
// In normal mode waiters are queued in FIFO order, but a woken-up waiter
// does not own the mutex and competes with newly arriving goroutines over
// ownership. Newly arriving goroutines have an advantage -- they are
// already running on a CPU and there can be lots of them, so a woken-up
// waiter has a good chance of losing. In that case it is queued at the front
// of the wait queue. If a waiter fails to acquire the mutex for more than 1ms,
// it switches the mutex to starvation mode.
//
// In starvation mode ownership of the mutex is directly handed off from
// the unlocking goroutine to the waiter at the front of the queue.
// Newly arriving goroutines don't try to acquire the mutex even if it appears
// to be unlocked, and don't try to spin. Instead they queue themselves at
// the tail of the wait queue.
//
// If a waiter receives ownership of the mutex and sees that either
// (1) it is the last waiter in the queue, or (2) it waited for less than 1 ms,
// it switches mutex back to normal operation mode.
//
// Normal mode has considerably better performance as a goroutine can acquire
// a mutex several times in a row even if there are blocked waiters.
// Starvation mode is important to prevent pathological cases of tail latency.
starvationThresholdNs = 1e6
)
// Lock locks m.
// If the lock is already in use, the calling goroutine
// blocks until the mutex is available.
func (m *Mutex) Lock() {
// Fast path: grab unlocked mutex.
if atomic.CompareAndSwapInt32(&m.state, 0, mutexLocked) {
if race.Enabled {
race.Acquire(unsafe.Pointer(m))
}
return
}
// Slow path (outlined so that the fast path can be inlined)
m.lockSlow()
}
// TryLock tries to lock m and reports whether it succeeded.
//
// Note that while correct uses of TryLock do exist, they are rare,
// and use of TryLock is often a sign of a deeper problem
// in a particular use of mutexes.
func (m *Mutex) TryLock() bool {
old := m.state
if old&(mutexLocked|mutexStarving) != 0 {
return false
}
// There may be a goroutine waiting for the mutex, but we are
// running now and can try to grab the mutex before that
// goroutine wakes up.
if !atomic.CompareAndSwapInt32(&m.state, old, old|mutexLocked) {
return false
}
if race.Enabled {
race.Acquire(unsafe.Pointer(m))
}
return true
}
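// Illustrative usage sketch (skip optional work instead of blocking):
//
//	if mu.TryLock() {
//		defer mu.Unlock()
//		// ... optional work ...
//	}
//	// do not spin retrying TryLock; use Lock if the work must happen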
func (m *Mutex) lockSlow() {
var waitStartTime int64
starving := false
awoke := false
iter := 0
old := m.state
for {
// Don't spin in starvation mode, ownership is handed off to waiters
// so we won't be able to acquire the mutex anyway.
if old&(mutexLocked|mutexStarving) == mutexLocked && runtime_canSpin(iter) {
// Active spinning makes sense.
// Try to set mutexWoken flag to inform Unlock
// to not wake other blocked goroutines.
if !awoke && old&mutexWoken == 0 && old>>mutexWaiterShift != 0 &&
atomic.CompareAndSwapInt32(&m.state, old, old|mutexWoken) {
awoke = true
}
runtime_doSpin()
iter++
old = m.state
continue
}
new := old
// Don't try to acquire starving mutex, new arriving goroutines must queue.
if old&mutexStarving == 0 {
new |= mutexLocked
}
if old&(mutexLocked|mutexStarving) != 0 {
new += 1 << mutexWaiterShift
}
// The current goroutine switches mutex to starvation mode.
// But if the mutex is currently unlocked, don't do the switch.
// Unlock expects that starving mutex has waiters, which will not
// be true in this case.
if starving && old&mutexLocked != 0 {
new |= mutexStarving
}
if awoke {
// The goroutine has been woken from sleep,
// so we need to reset the flag in either case.
if new&mutexWoken == 0 {
throw("sync: inconsistent mutex state")
}
new &^= mutexWoken
}
if atomic.CompareAndSwapInt32(&m.state, old, new) {
if old&(mutexLocked|mutexStarving) == 0 {
break // locked the mutex with CAS
}
// If we were already waiting before, queue at the front of the queue.
queueLifo := waitStartTime != 0
if waitStartTime == 0 {
waitStartTime = runtime_nanotime()
}
runtime_SemacquireMutex(&m.sema, queueLifo, 1)
starving = starving || runtime_nanotime()-waitStartTime > starvationThresholdNs
old = m.state
if old&mutexStarving != 0 {
// If this goroutine was woken and the mutex is in starvation mode,
// ownership was handed off to us but the mutex is in a somewhat
// inconsistent state: mutexLocked is not set and we are still
// accounted as a waiter. Fix that.
if old&(mutexLocked|mutexWoken) != 0 || old>>mutexWaiterShift == 0 {
throw("sync: inconsistent mutex state")
}
delta := int32(mutexLocked - 1<<mutexWaiterShift)
if !starving || old>>mutexWaiterShift == 1 {
// Exit starvation mode.
// Critical to do it here and consider wait time.
// Starvation mode is so inefficient that two goroutines
// can proceed in lock-step indefinitely once they switch the
// mutex to starvation mode.
delta -= mutexStarving
}
atomic.AddInt32(&m.state, delta)
break
}
awoke = true
iter = 0
} else {
old = m.state
}
}
if race.Enabled {
race.Acquire(unsafe.Pointer(m))
}
}
// Unlock unlocks m.
// It is a run-time error if m is not locked on entry to Unlock.
//
// A locked Mutex is not associated with a particular goroutine.
// It is allowed for one goroutine to lock a Mutex and then
// arrange for another goroutine to unlock it.
func (m *Mutex) Unlock() {
if race.Enabled {
_ = m.state
race.Release(unsafe.Pointer(m))
}
// Fast path: drop lock bit.
new := atomic.AddInt32(&m.state, -mutexLocked)
if new != 0 {
// Outlined slow path to allow inlining the fast path.
// To hide unlockSlow during tracing we skip one extra frame when tracing GoUnblock.
m.unlockSlow(new)
}
}
func (m *Mutex) unlockSlow(new int32) {
if (new+mutexLocked)&mutexLocked == 0 {
fatal("sync: unlock of unlocked mutex")
}
if new&mutexStarving == 0 {
old := new
for {
// If there are no waiters or a goroutine has already
// been woken or grabbed the lock, no need to wake anyone.
// In starvation mode ownership is directly handed off from unlocking
// goroutine to the next waiter. We are not part of this chain,
// since we did not observe mutexStarving when we unlocked the mutex above.
// So get out of the way.
if old>>mutexWaiterShift == 0 || old&(mutexLocked|mutexWoken|mutexStarving) != 0 {
return
}
// Grab the right to wake someone.
new = (old - 1<<mutexWaiterShift) | mutexWoken
if atomic.CompareAndSwapInt32(&m.state, old, new) {
runtime_Semrelease(&m.sema, false, 1)
return
}
old = m.state
}
} else {
// Starving mode: handoff mutex ownership to the next waiter, and yield
// our time slice so that the next waiter can start to run immediately.
// Note: mutexLocked is not set, the waiter will set it after wakeup.
// But mutex is still considered locked if mutexStarving is set,
// so new coming goroutines won't acquire it.
runtime_Semrelease(&m.sema, true, 1)
}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sync
import (
"sync/atomic"
)
// Once is an object that will perform exactly one action.
//
// A Once must not be copied after first use.
//
// In the terminology of the Go memory model,
// the return from f “synchronizes before”
// the return from any call of once.Do(f).
type Once struct {
// done indicates whether the action has been performed.
// It is first in the struct because it is used in the hot path.
// The hot path is inlined at every call site.
// Placing done first allows more compact instructions on some architectures (amd64/386),
// and fewer instructions (to calculate offset) on other architectures.
done uint32
m Mutex
}
// Do calls the function f if and only if Do is being called for the
// first time for this instance of Once. In other words, given
//
// var once Once
//
// if once.Do(f) is called multiple times, only the first call will invoke f,
// even if f has a different value in each invocation. A new instance of
// Once is required for each function to execute.
//
// Do is intended for initialization that must be run exactly once. Since f
// is niladic, it may be necessary to use a function literal to capture the
// arguments to a function to be invoked by Do:
//
// config.once.Do(func() { config.init(filename) })
//
// Because no call to Do returns until the one call to f returns, if f causes
// Do to be called, it will deadlock.
//
// If f panics, Do considers it to have returned; future calls of Do return
// without calling f.
func (o *Once) Do(f func()) {
// Note: Here is an incorrect implementation of Do:
//
// if atomic.CompareAndSwapUint32(&o.done, 0, 1) {
// f()
// }
//
// Do guarantees that when it returns, f has finished.
// This implementation would not implement that guarantee:
// given two simultaneous calls, the winner of the cas would
// call f, and the second would return immediately, without
// waiting for the first's call to f to complete.
// This is why the slow path falls back to a mutex, and why
// the atomic.StoreUint32 must be delayed until after f returns.
if atomic.LoadUint32(&o.done) == 0 {
// Outlined slow-path to allow inlining of the fast-path.
o.doSlow(f)
}
}
func (o *Once) doSlow(f func()) {
o.m.Lock()
defer o.m.Unlock()
if o.done == 0 {
defer atomic.StoreUint32(&o.done, 1)
f()
}
}
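// Illustrative usage sketch:
//
//	var once sync.Once
//	for i := 0; i < 3; i++ {
//		once.Do(func() { fmt.Println("initialized") }) // prints exactly once
//	}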
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sync
import (
"internal/race"
"runtime"
"sync/atomic"
"unsafe"
)
// A Pool is a set of temporary objects that may be individually saved and
// retrieved.
//
// Any item stored in the Pool may be removed automatically at any time without
// notification. If the Pool holds the only reference when this happens, the
// item might be deallocated.
//
// A Pool is safe for use by multiple goroutines simultaneously.
//
// Pool's purpose is to cache allocated but unused items for later reuse,
// relieving pressure on the garbage collector. That is, it makes it easy to
// build efficient, thread-safe free lists. However, it is not suitable for all
// free lists.
//
// An appropriate use of a Pool is to manage a group of temporary items
// silently shared among and potentially reused by concurrent independent
// clients of a package. Pool provides a way to amortize allocation overhead
// across many clients.
//
// An example of good use of a Pool is in the fmt package, which maintains a
// dynamically-sized store of temporary output buffers. The store scales under
// load (when many goroutines are actively printing) and shrinks when
// quiescent.
//
// On the other hand, a free list maintained as part of a short-lived object is
// not a suitable use for a Pool, since the overhead does not amortize well in
// that scenario. It is more efficient to have such objects implement their own
// free list.
//
// A Pool must not be copied after first use.
//
// In the terminology of the Go memory model, a call to Put(x) “synchronizes before”
// a call to Get returning that same value x.
// Similarly, a call to New returning x “synchronizes before”
// a call to Get returning that same value x.
type Pool struct {
noCopy noCopy
local unsafe.Pointer // local fixed-size per-P pool, actual type is [P]poolLocal
localSize uintptr // size of the local array
victim unsafe.Pointer // local from previous cycle
victimSize uintptr // size of victims array
// New optionally specifies a function to generate
// a value when Get would otherwise return nil.
// It may not be changed concurrently with calls to Get.
New func() any
}
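// Illustrative sketch (added for exposition; not part of the original
// source): a Pool of reusable buffers in the spirit of the fmt example
// above; bytes.Buffer is just one plausible element type.
//
//	var bufPool = Pool{
//		New: func() any { return new(bytes.Buffer) },
//	}
//
//	func handle() {
//		buf := bufPool.Get().(*bytes.Buffer)
//		buf.Reset() // the buffer may carry data from a previous user
//		// ... use buf ...
//		bufPool.Put(buf) // return it for reuse once done
//	}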
// Local per-P Pool appendix.
type poolLocalInternal struct {
private any // Can be used only by the respective P.
shared poolChain // Local P can pushHead/popHead; any P can popTail.
}
type poolLocal struct {
poolLocalInternal
// Prevents false sharing on widespread platforms with
// 128 mod (cache line size) = 0.
pad [128 - unsafe.Sizeof(poolLocalInternal{})%128]byte
}
// from runtime
func fastrandn(n uint32) uint32
var poolRaceHash [128]uint64
// poolRaceAddr returns an address to use as the synchronization point
// for race detector logic. We don't use the actual pointer stored in x
// directly, for fear of conflicting with other synchronization on that address.
// Instead, we hash the pointer to get an index into poolRaceHash.
// See discussion on golang.org/cl/31589.
func poolRaceAddr(x any) unsafe.Pointer {
ptr := uintptr((*[2]unsafe.Pointer)(unsafe.Pointer(&x))[1])
h := uint32((uint64(uint32(ptr)) * 0x85ebca6b) >> 16)
return unsafe.Pointer(&poolRaceHash[h%uint32(len(poolRaceHash))])
}
// Put adds x to the pool.
func (p *Pool) Put(x any) {
if x == nil {
return
}
if race.Enabled {
if fastrandn(4) == 0 {
// Randomly drop x on floor.
return
}
race.ReleaseMerge(poolRaceAddr(x))
race.Disable()
}
l, _ := p.pin()
if l.private == nil {
l.private = x
} else {
l.shared.pushHead(x)
}
runtime_procUnpin()
if race.Enabled {
race.Enable()
}
}
// Get selects an arbitrary item from the Pool, removes it from the
// Pool, and returns it to the caller.
// Get may choose to ignore the pool and treat it as empty.
// Callers should not assume any relation between values passed to Put and
// the values returned by Get.
//
// If Get would otherwise return nil and p.New is non-nil, Get returns
// the result of calling p.New.
func (p *Pool) Get() any {
if race.Enabled {
race.Disable()
}
l, pid := p.pin()
x := l.private
l.private = nil
if x == nil {
// Try to pop the head of the local shard. We prefer
// the head over the tail for temporal locality of
// reuse.
x, _ = l.shared.popHead()
if x == nil {
x = p.getSlow(pid)
}
}
runtime_procUnpin()
if race.Enabled {
race.Enable()
if x != nil {
race.Acquire(poolRaceAddr(x))
}
}
if x == nil && p.New != nil {
x = p.New()
}
return x
}
func (p *Pool) getSlow(pid int) any {
// See the comment in pin regarding ordering of the loads.
size := runtime_LoadAcquintptr(&p.localSize) // load-acquire
locals := p.local // load-consume
// Try to steal one element from other procs.
for i := 0; i < int(size); i++ {
l := indexLocal(locals, (pid+i+1)%int(size))
if x, _ := l.shared.popTail(); x != nil {
return x
}
}
// Try the victim cache. We do this after attempting to steal
// from all primary caches because we want objects in the
// victim cache to age out if at all possible.
size = atomic.LoadUintptr(&p.victimSize)
if uintptr(pid) >= size {
return nil
}
locals = p.victim
l := indexLocal(locals, pid)
if x := l.private; x != nil {
l.private = nil
return x
}
for i := 0; i < int(size); i++ {
l := indexLocal(locals, (pid+i)%int(size))
if x, _ := l.shared.popTail(); x != nil {
return x
}
}
// Mark the victim cache as empty so that future gets don't
// bother with it.
atomic.StoreUintptr(&p.victimSize, 0)
return nil
}
// pin pins the current goroutine to its P, disables preemption, and
// returns the poolLocal for that P along with the P's id.
// Caller must call runtime_procUnpin() when done with the pool.
func (p *Pool) pin() (*poolLocal, int) {
pid := runtime_procPin()
// In pinSlow we store to local and then to localSize; here we load in the opposite order.
// Since we've disabled preemption, GC cannot happen in between.
// Thus here we must observe a local array at least as large as localSize.
// We may observe a newer/larger local, which is fine (it is guaranteed to be zero-initialized).
s := runtime_LoadAcquintptr(&p.localSize) // load-acquire
l := p.local // load-consume
if uintptr(pid) < s {
return indexLocal(l, pid), pid
}
return p.pinSlow()
}
func (p *Pool) pinSlow() (*poolLocal, int) {
// Retry under the mutex.
// Cannot lock the mutex while pinned.
runtime_procUnpin()
allPoolsMu.Lock()
defer allPoolsMu.Unlock()
pid := runtime_procPin()
// poolCleanup won't be called while we are pinned.
s := p.localSize
l := p.local
if uintptr(pid) < s {
return indexLocal(l, pid), pid
}
if p.local == nil {
allPools = append(allPools, p)
}
// If GOMAXPROCS changes between GCs, we re-allocate the array and lose the old one.
size := runtime.GOMAXPROCS(0)
local := make([]poolLocal, size)
atomic.StorePointer(&p.local, unsafe.Pointer(&local[0])) // store-release
runtime_StoreReluintptr(&p.localSize, uintptr(size)) // store-release
return &local[pid], pid
}
func poolCleanup() {
// This function is called with the world stopped, at the beginning of a garbage collection.
// It must not allocate and probably should not call any runtime functions.
// Because the world is stopped, no pool user can be in a
// pinned section (in effect, this has all Ps pinned).
// Drop victim caches from all pools.
for _, p := range oldPools {
p.victim = nil
p.victimSize = 0
}
// Move primary cache to victim cache.
for _, p := range allPools {
p.victim = p.local
p.victimSize = p.localSize
p.local = nil
p.localSize = 0
}
// The pools with non-empty primary caches now have non-empty
// victim caches and no pools have primary caches.
oldPools, allPools = allPools, nil
}
var (
allPoolsMu Mutex
// allPools is the set of pools that have non-empty primary
// caches. Protected by either 1) allPoolsMu and pinning or 2)
// STW.
allPools []*Pool
// oldPools is the set of pools that may have non-empty victim
// caches. Protected by STW.
oldPools []*Pool
)
func init() {
runtime_registerPoolCleanup(poolCleanup)
}
func indexLocal(l unsafe.Pointer, i int) *poolLocal {
lp := unsafe.Pointer(uintptr(l) + uintptr(i)*unsafe.Sizeof(poolLocal{}))
return (*poolLocal)(lp)
}
// Implemented in runtime.
func runtime_registerPoolCleanup(cleanup func())
func runtime_procPin() int
func runtime_procUnpin()
// The below are implemented in runtime/internal/atomic and the
// compiler also knows to intrinsify the symbol we linkname into this
// package.
//go:linkname runtime_LoadAcquintptr runtime/internal/atomic.LoadAcquintptr
func runtime_LoadAcquintptr(ptr *uintptr) uintptr
//go:linkname runtime_StoreReluintptr runtime/internal/atomic.StoreReluintptr
func runtime_StoreReluintptr(ptr *uintptr, val uintptr) uintptr
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sync
import (
"sync/atomic"
"unsafe"
)
// poolDequeue is a lock-free fixed-size single-producer,
// multi-consumer queue. The single producer can both push and pop
// from the head, and consumers can pop from the tail.
//
// It has the added feature that it nils out unused slots to avoid
// unnecessary retention of objects. This is important for sync.Pool,
// but not typically a property considered in the literature.
type poolDequeue struct {
// headTail packs together a 32-bit head index and a 32-bit
// tail index. Both are indexes into vals modulo len(vals)
// (vals has a power-of-2 size, so the mask len(vals)-1 computes this).
//
// tail = index of oldest data in queue
// head = index of next slot to fill
//
// Slots in the range [tail, head) are owned by consumers.
// A consumer continues to own a slot outside this range until
// it nils the slot, at which point ownership passes to the
// producer.
//
// The head index is stored in the most-significant bits so
// that we can atomically add to it and the overflow is
// harmless.
headTail uint64
// vals is a ring buffer of interface{} values stored in this
// dequeue. The size of this must be a power of 2.
//
// vals[i].typ is nil if the slot is empty and non-nil
// otherwise. A slot is still in use until *both* the tail
// index has moved beyond it and typ has been set to nil. This
// is set to nil atomically by the consumer and read
// atomically by the producer.
vals []eface
}
type eface struct {
typ, val unsafe.Pointer
}
const dequeueBits = 32
// dequeueLimit is the maximum size of a poolDequeue.
//
// This must be at most (1<<dequeueBits)/2 because detecting fullness
// depends on wrapping around the ring buffer without wrapping around
// the index. We divide by 4 so this fits in an int on 32-bit.
const dequeueLimit = (1 << dequeueBits) / 4
// dequeueNil is used in poolDequeue to represent interface{}(nil).
// Since we use nil to represent empty slots, we need a sentinel value
// to represent nil.
type dequeueNil *struct{}
func (d *poolDequeue) unpack(ptrs uint64) (head, tail uint32) {
const mask = 1<<dequeueBits - 1
head = uint32((ptrs >> dequeueBits) & mask)
tail = uint32(ptrs & mask)
return
}
func (d *poolDequeue) pack(head, tail uint32) uint64 {
const mask = 1<<dequeueBits - 1
return (uint64(head) << dequeueBits) |
uint64(tail&mask)
}
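// Worked example (added for exposition; not part of the original source):
// with dequeueBits = 32, pack(3, 1) yields 0x0000000300000001 and unpack
// recovers head = 3, tail = 1. Because head occupies the high 32 bits,
// atomic.AddUint64(&d.headTail, 1<<dequeueBits) increments head alone,
// and any overflow wraps harmlessly within those bits.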
// pushHead adds val at the head of the queue. It returns false if the
// queue is full. It must only be called by a single producer.
func (d *poolDequeue) pushHead(val any) bool {
ptrs := atomic.LoadUint64(&d.headTail)
head, tail := d.unpack(ptrs)
if (tail+uint32(len(d.vals)))&(1<<dequeueBits-1) == head {
// Queue is full.
return false
}
slot := &d.vals[head&uint32(len(d.vals)-1)]
// Check if the head slot has been released by popTail.
typ := atomic.LoadPointer(&slot.typ)
if typ != nil {
// Another goroutine is still cleaning up the tail, so
// the queue is actually still full.
return false
}
// The head slot is free, so we own it.
if val == nil {
val = dequeueNil(nil)
}
*(*any)(unsafe.Pointer(slot)) = val
// Increment head. This passes ownership of slot to popTail
// and acts as a store barrier for writing the slot.
atomic.AddUint64(&d.headTail, 1<<dequeueBits)
return true
}
// popHead removes and returns the element at the head of the queue.
// It returns false if the queue is empty. It must only be called by a
// single producer.
func (d *poolDequeue) popHead() (any, bool) {
var slot *eface
for {
ptrs := atomic.LoadUint64(&d.headTail)
head, tail := d.unpack(ptrs)
if tail == head {
// Queue is empty.
return nil, false
}
// Confirm tail and decrement head. We do this before
// reading the value to take back ownership of this
// slot.
head--
ptrs2 := d.pack(head, tail)
if atomic.CompareAndSwapUint64(&d.headTail, ptrs, ptrs2) {
// We successfully took back slot.
slot = &d.vals[head&uint32(len(d.vals)-1)]
break
}
}
val := *(*any)(unsafe.Pointer(slot))
if val == dequeueNil(nil) {
val = nil
}
// Zero the slot. Unlike popTail, this isn't racing with
// pushHead, so we don't need to be careful here.
*slot = eface{}
return val, true
}
// popTail removes and returns the element at the tail of the queue.
// It returns false if the queue is empty. It may be called by any
// number of consumers.
func (d *poolDequeue) popTail() (any, bool) {
var slot *eface
for {
ptrs := atomic.LoadUint64(&d.headTail)
head, tail := d.unpack(ptrs)
if tail == head {
// Queue is empty.
return nil, false
}
// Confirm head and tail (for our speculative check
// above) and increment tail. If this succeeds, then
// we own the slot at tail.
ptrs2 := d.pack(head, tail+1)
if atomic.CompareAndSwapUint64(&d.headTail, ptrs, ptrs2) {
// Success.
slot = &d.vals[tail&uint32(len(d.vals)-1)]
break
}
}
// We now own slot.
val := *(*any)(unsafe.Pointer(slot))
if val == dequeueNil(nil) {
val = nil
}
// Tell pushHead that we're done with this slot. Zeroing the
// slot is also important so we don't leave behind references
// that could keep this object live longer than necessary.
//
// We write to val first and then publish that we're done with
// this slot by atomically writing to typ.
slot.val = nil
atomic.StorePointer(&slot.typ, nil)
// At this point pushHead owns the slot.
return val, true
}
// poolChain is a dynamically-sized version of poolDequeue.
//
// This is implemented as a doubly-linked list queue of poolDequeues
// where each dequeue is double the size of the previous one. Once a
// dequeue fills up, this allocates a new one and only ever pushes to
// the latest dequeue. Pops happen from the other end of the list and
// once a dequeue is exhausted, it gets removed from the list.
type poolChain struct {
// head is the poolDequeue to push to. This is only accessed
// by the producer, so doesn't need to be synchronized.
head *poolChainElt
// tail is the poolDequeue to popTail from. This is accessed
// by consumers, so reads and writes must be atomic.
tail *poolChainElt
}
type poolChainElt struct {
poolDequeue
// next and prev link to the adjacent poolChainElts in this
// poolChain.
//
// next is written atomically by the producer and read
// atomically by the consumer. It only transitions from nil to
// non-nil.
//
// prev is written atomically by the consumer and read
// atomically by the producer. It only transitions from
// non-nil to nil.
next, prev *poolChainElt
}
func storePoolChainElt(pp **poolChainElt, v *poolChainElt) {
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(pp)), unsafe.Pointer(v))
}
func loadPoolChainElt(pp **poolChainElt) *poolChainElt {
return (*poolChainElt)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(pp))))
}
func (c *poolChain) pushHead(val any) {
d := c.head
if d == nil {
// Initialize the chain.
const initSize = 8 // Must be a power of 2
d = new(poolChainElt)
d.vals = make([]eface, initSize)
c.head = d
storePoolChainElt(&c.tail, d)
}
if d.pushHead(val) {
return
}
// The current dequeue is full. Allocate a new one of twice
// the size.
newSize := len(d.vals) * 2
if newSize >= dequeueLimit {
// Can't make it any bigger.
newSize = dequeueLimit
}
d2 := &poolChainElt{prev: d}
d2.vals = make([]eface, newSize)
c.head = d2
storePoolChainElt(&d.next, d2)
d2.pushHead(val)
}
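// Worked example (added for exposition; not part of the original source):
// a chain's dequeues grow as 8, 16, 32, ... elements, doubling each time
// the head dequeue fills up, until newSize reaches dequeueLimit (1<<30);
// dequeues allocated after that stay at dequeueLimit.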
func (c *poolChain) popHead() (any, bool) {
d := c.head
for d != nil {
if val, ok := d.popHead(); ok {
return val, ok
}
// There may still be unconsumed elements in the
// previous dequeue, so try backing up.
d = loadPoolChainElt(&d.prev)
}
return nil, false
}
func (c *poolChain) popTail() (any, bool) {
d := loadPoolChainElt(&c.tail)
if d == nil {
return nil, false
}
for {
// It's important that we load the next pointer
// *before* popping the tail. In general, d may be
// transiently empty, but if next is non-nil before
// the pop and the pop fails, then d is permanently
// empty, which is the only condition under which it's
// safe to drop d from the chain.
d2 := loadPoolChainElt(&d.next)
if val, ok := d.popTail(); ok {
return val, ok
}
if d2 == nil {
// This is the only dequeue. It's empty right
// now, but could be pushed to in the future.
return nil, false
}
// The tail of the chain has been drained, so move on
// to the next dequeue. Try to drop it from the chain
// so the next pop doesn't have to look at the empty
// dequeue again.
if atomic.CompareAndSwapPointer((*unsafe.Pointer)(unsafe.Pointer(&c.tail)), unsafe.Pointer(d), unsafe.Pointer(d2)) {
// We won the race. Clear the prev pointer so
// the garbage collector can collect the empty
// dequeue and so popHead doesn't back up
// further than necessary.
storePoolChainElt(&d2.prev, nil)
}
d = d2
}
}
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sync
import "unsafe"
// defined in package runtime
// Semacquire waits until *s > 0 and then atomically decrements it.
// It is intended as a simple sleep primitive for use by the synchronization
// library and should not be used directly.
func runtime_Semacquire(s *uint32)
// Semacquire(RW)Mutex(R) is like Semacquire, but for profiling contended
// Mutexes and RWMutexes.
// If lifo is true, queue waiter at the head of wait queue.
// skipframes is the number of frames to omit during tracing, counting from
// runtime_SemacquireMutex's caller.
// The different forms of this function just tell the runtime how to present
// the reason for waiting in a backtrace, and are used to compute some metrics.
// Otherwise they're functionally identical.
func runtime_SemacquireMutex(s *uint32, lifo bool, skipframes int)
func runtime_SemacquireRWMutexR(s *uint32, lifo bool, skipframes int)
func runtime_SemacquireRWMutex(s *uint32, lifo bool, skipframes int)
// Semrelease atomically increments *s and notifies a waiting goroutine
// if one is blocked in Semacquire.
// It is intended as a simple wakeup primitive for use by the synchronization
// library and should not be used directly.
// If handoff is true, pass count directly to the first waiter.
// skipframes is the number of frames to omit during tracing, counting from
// runtime_Semrelease's caller.
func runtime_Semrelease(s *uint32, handoff bool, skipframes int)
// See runtime/sema.go for documentation.
func runtime_notifyListAdd(l *notifyList) uint32
// See runtime/sema.go for documentation.
func runtime_notifyListWait(l *notifyList, t uint32)
// See runtime/sema.go for documentation.
func runtime_notifyListNotifyAll(l *notifyList)
// See runtime/sema.go for documentation.
func runtime_notifyListNotifyOne(l *notifyList)
// Ensure that sync and runtime agree on size of notifyList.
func runtime_notifyListCheck(size uintptr)
func init() {
var n notifyList
runtime_notifyListCheck(unsafe.Sizeof(n))
}
// Active spinning runtime support.
// runtime_canSpin reports whether spinning makes sense at the moment.
func runtime_canSpin(i int) bool
// runtime_doSpin does active spinning.
func runtime_doSpin()
func runtime_nanotime() int64
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sync
import (
"internal/race"
"sync/atomic"
"unsafe"
)
// There is a modified copy of this file in runtime/rwmutex.go.
// If you make any changes here, see if you should make them there.
// A RWMutex is a reader/writer mutual exclusion lock.
// The lock can be held by an arbitrary number of readers or a single writer.
// The zero value for a RWMutex is an unlocked mutex.
//
// A RWMutex must not be copied after first use.
//
// If a goroutine holds a RWMutex for reading and another goroutine might
// call Lock, no goroutine should expect to be able to acquire a read lock
// until the initial read lock is released. In particular, this prohibits
// recursive read locking. This is to ensure that the lock eventually becomes
// available; a blocked Lock call excludes new readers from acquiring the
// lock.
//
// In the terminology of the Go memory model,
// the n'th call to Unlock “synchronizes before” the m'th call to Lock
// for any n < m, just as for Mutex.
// For any call to RLock, there exists an n such that
// the n'th call to Unlock “synchronizes before” that call to RLock,
// and the corresponding call to RUnlock “synchronizes before”
// the n+1'th call to Lock.
type RWMutex struct {
w Mutex // held if there are pending writers
writerSem uint32 // semaphore for writers to wait for completing readers
readerSem uint32 // semaphore for readers to wait for completing writers
readerCount atomic.Int32 // number of pending readers
readerWait atomic.Int32 // number of departing readers
}
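// Illustrative sketch (added for exposition; not part of the original
// source): the usual read-mostly pattern; the cache map is hypothetical.
//
//	var (
//		mu    RWMutex
//		cache = map[string]string{}
//	)
//
//	func lookup(k string) (string, bool) {
//		mu.RLock() // many lookups may proceed concurrently
//		v, ok := cache[k]
//		mu.RUnlock()
//		return v, ok
//	}
//
//	func store(k, v string) {
//		mu.Lock() // writers get exclusive access
//		cache[k] = v
//		mu.Unlock()
//	}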
const rwmutexMaxReaders = 1 << 30
// Happens-before relationships are indicated to the race detector via:
// - Unlock -> Lock: readerSem
// - Unlock -> RLock: readerSem
// - RUnlock -> Lock: writerSem
//
// The methods below temporarily disable handling of race synchronization
// events in order to provide the more precise model above to the race
// detector.
//
// For example, atomic.AddInt32 in RLock should not appear to provide
// acquire-release semantics, which would incorrectly synchronize racing
// readers, thus potentially missing races.
// RLock locks rw for reading.
//
// It should not be used for recursive read locking; a blocked Lock
// call excludes new readers from acquiring the lock. See the
// documentation on the RWMutex type.
func (rw *RWMutex) RLock() {
if race.Enabled {
_ = rw.w.state
race.Disable()
}
if rw.readerCount.Add(1) < 0 {
// A writer is pending, wait for it.
runtime_SemacquireRWMutexR(&rw.readerSem, false, 0)
}
if race.Enabled {
race.Enable()
race.Acquire(unsafe.Pointer(&rw.readerSem))
}
}
// TryRLock tries to lock rw for reading and reports whether it succeeded.
//
// Note that while correct uses of TryRLock do exist, they are rare,
// and use of TryRLock is often a sign of a deeper problem
// in a particular use of mutexes.
func (rw *RWMutex) TryRLock() bool {
if race.Enabled {
_ = rw.w.state
race.Disable()
}
for {
c := rw.readerCount.Load()
if c < 0 {
if race.Enabled {
race.Enable()
}
return false
}
if rw.readerCount.CompareAndSwap(c, c+1) {
if race.Enabled {
race.Enable()
race.Acquire(unsafe.Pointer(&rw.readerSem))
}
return true
}
}
}
// RUnlock undoes a single RLock call;
// it does not affect other simultaneous readers.
// It is a run-time error if rw is not locked for reading
// on entry to RUnlock.
func (rw *RWMutex) RUnlock() {
if race.Enabled {
_ = rw.w.state
race.ReleaseMerge(unsafe.Pointer(&rw.writerSem))
race.Disable()
}
if r := rw.readerCount.Add(-1); r < 0 {
// Outlined slow-path to allow the fast-path to be inlined
rw.rUnlockSlow(r)
}
if race.Enabled {
race.Enable()
}
}
func (rw *RWMutex) rUnlockSlow(r int32) {
if r+1 == 0 || r+1 == -rwmutexMaxReaders {
race.Enable()
fatal("sync: RUnlock of unlocked RWMutex")
}
// A writer is pending.
if rw.readerWait.Add(-1) == 0 {
// The last reader unblocks the writer.
runtime_Semrelease(&rw.writerSem, false, 1)
}
}
// Lock locks rw for writing.
// If the lock is already locked for reading or writing,
// Lock blocks until the lock is available.
func (rw *RWMutex) Lock() {
if race.Enabled {
_ = rw.w.state
race.Disable()
}
// First, resolve competition with other writers.
rw.w.Lock()
// Announce to readers there is a pending writer.
r := rw.readerCount.Add(-rwmutexMaxReaders) + rwmutexMaxReaders
// Wait for active readers.
if r != 0 && rw.readerWait.Add(r) != 0 {
runtime_SemacquireRWMutex(&rw.writerSem, false, 0)
}
if race.Enabled {
race.Enable()
race.Acquire(unsafe.Pointer(&rw.readerSem))
race.Acquire(unsafe.Pointer(&rw.writerSem))
}
}
// TryLock tries to lock rw for writing and reports whether it succeeded.
//
// Note that while correct uses of TryLock do exist, they are rare,
// and use of TryLock is often a sign of a deeper problem
// in a particular use of mutexes.
func (rw *RWMutex) TryLock() bool {
if race.Enabled {
_ = rw.w.state
race.Disable()
}
if !rw.w.TryLock() {
if race.Enabled {
race.Enable()
}
return false
}
if !rw.readerCount.CompareAndSwap(0, -rwmutexMaxReaders) {
rw.w.Unlock()
if race.Enabled {
race.Enable()
}
return false
}
if race.Enabled {
race.Enable()
race.Acquire(unsafe.Pointer(&rw.readerSem))
race.Acquire(unsafe.Pointer(&rw.writerSem))
}
return true
}
// Unlock unlocks rw for writing. It is a run-time error if rw is
// not locked for writing on entry to Unlock.
//
// As with Mutexes, a locked RWMutex is not associated with a particular
// goroutine. One goroutine may RLock (Lock) a RWMutex and then
// arrange for another goroutine to RUnlock (Unlock) it.
func (rw *RWMutex) Unlock() {
if race.Enabled {
_ = rw.w.state
race.Release(unsafe.Pointer(&rw.readerSem))
race.Disable()
}
// Announce to readers there is no active writer.
r := rw.readerCount.Add(rwmutexMaxReaders)
if r >= rwmutexMaxReaders {
race.Enable()
fatal("sync: Unlock of unlocked RWMutex")
}
// Unblock blocked readers, if any.
for i := 0; i < int(r); i++ {
runtime_Semrelease(&rw.readerSem, false, 0)
}
// Allow other writers to proceed.
rw.w.Unlock()
if race.Enabled {
race.Enable()
}
}
// RLocker returns a Locker interface that implements
// the Lock and Unlock methods by calling rw.RLock and rw.RUnlock.
func (rw *RWMutex) RLocker() Locker {
return (*rlocker)(rw)
}
type rlocker RWMutex
func (r *rlocker) Lock() { (*RWMutex)(r).RLock() }
func (r *rlocker) Unlock() { (*RWMutex)(r).RUnlock() }
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sync
import (
"internal/race"
"sync/atomic"
"unsafe"
)
// A WaitGroup waits for a collection of goroutines to finish.
// The main goroutine calls Add to set the number of
// goroutines to wait for. Then each of the goroutines
// runs and calls Done when finished. At the same time,
// Wait can be used to block until all goroutines have finished.
//
// A WaitGroup must not be copied after first use.
//
// In the terminology of the Go memory model, a call to Done
// “synchronizes before” the return of any Wait call that it unblocks.
type WaitGroup struct {
noCopy noCopy
state atomic.Uint64 // high 32 bits are counter, low 32 bits are waiter count.
sema uint32
}
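// Illustrative sketch (added for exposition; not part of the original
// source): the canonical fan-out pattern; Task and process are
// hypothetical. Note that Add is called before each goroutine starts,
// as required by the documentation on Add below.
//
//	var wg WaitGroup
//	for _, task := range tasks {
//		wg.Add(1)
//		go func(t Task) {
//			defer wg.Done() // decrement even if process panics
//			process(t)
//		}(task)
//	}
//	wg.Wait() // blocks until the counter drops to zero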
// Add adds delta, which may be negative, to the WaitGroup counter.
// If the counter becomes zero, all goroutines blocked on Wait are released.
// If the counter goes negative, Add panics.
//
// Note that calls with a positive delta that occur when the counter is zero
// must happen before a Wait. Calls with a negative delta, or calls with a
// positive delta that start when the counter is greater than zero, may happen
// at any time.
// Typically this means the calls to Add should execute before the statement
// creating the goroutine or other event to be waited for.
// If a WaitGroup is reused to wait for several independent sets of events,
// new Add calls must happen after all previous Wait calls have returned.
// See the WaitGroup example.
func (wg *WaitGroup) Add(delta int) {
if race.Enabled {
if delta < 0 {
// Synchronize decrements with Wait.
race.ReleaseMerge(unsafe.Pointer(wg))
}
race.Disable()
defer race.Enable()
}
state := wg.state.Add(uint64(delta) << 32)
v := int32(state >> 32)
w := uint32(state)
if race.Enabled && delta > 0 && v == int32(delta) {
// The first increment must be synchronized with Wait.
// Need to model this as a read, because there can be
// several concurrent wg.counter transitions from 0.
race.Read(unsafe.Pointer(&wg.sema))
}
if v < 0 {
panic("sync: negative WaitGroup counter")
}
if w != 0 && delta > 0 && v == int32(delta) {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
if v > 0 || w == 0 {
return
}
// This goroutine has set counter to 0 when waiters > 0.
// Now there can't be concurrent mutations of state:
// - Adds must not happen concurrently with Wait,
// - Wait does not increment waiters if it sees counter == 0.
// Still do a cheap sanity check to detect WaitGroup misuse.
if wg.state.Load() != state {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
// Reset waiters count to 0.
wg.state.Store(0)
for ; w != 0; w-- {
runtime_Semrelease(&wg.sema, false, 0)
}
}
// Done decrements the WaitGroup counter by one.
func (wg *WaitGroup) Done() {
wg.Add(-1)
}
// Wait blocks until the WaitGroup counter is zero.
func (wg *WaitGroup) Wait() {
if race.Enabled {
race.Disable()
}
for {
state := wg.state.Load()
v := int32(state >> 32)
w := uint32(state)
if v == 0 {
// Counter is 0, no need to wait.
if race.Enabled {
race.Enable()
race.Acquire(unsafe.Pointer(wg))
}
return
}
// Increment waiters count.
if wg.state.CompareAndSwap(state, state+1) {
if race.Enabled && w == 0 {
// Wait must be synchronized with the first Add.
// Need to model this as a write to race with the read in Add.
// As a consequence, can do the write only for the first waiter,
// otherwise concurrent Waits will race with each other.
race.Write(unsafe.Pointer(&wg.sema))
}
runtime_Semacquire(&wg.sema)
if wg.state.Load() != 0 {
panic("sync: WaitGroup is reused before previous Wait has returned")
}
if race.Enabled {
race.Enable()
race.Acquire(unsafe.Pointer(wg))
}
return
}
}
}
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !asan
package syscall
import (
"unsafe"
)
const asanenabled = false
func asanRead(addr unsafe.Pointer, len int) {
}
func asanWrite(addr unsafe.Pointer, len int) {
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || (js && wasm)
package syscall
import "unsafe"
// readInt returns the unsigned integer of the given size, in native byte
// order, starting at offset off in b; ok reports whether b is long enough.
func readInt(b []byte, off, size uintptr) (u uint64, ok bool) {
if len(b) < int(off+size) {
return 0, false
}
if isBigEndian {
return readIntBE(b[off:], size), true
}
return readIntLE(b[off:], size), true
}
func readIntBE(b []byte, size uintptr) uint64 {
switch size {
case 1:
return uint64(b[0])
case 2:
_ = b[1] // bounds check hint to compiler; see golang.org/issue/14808
return uint64(b[1]) | uint64(b[0])<<8
case 4:
_ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
return uint64(b[3]) | uint64(b[2])<<8 | uint64(b[1])<<16 | uint64(b[0])<<24
case 8:
_ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
default:
panic("syscall: readInt with unsupported size")
}
}
func readIntLE(b []byte, size uintptr) uint64 {
switch size {
case 1:
return uint64(b[0])
case 2:
_ = b[1] // bounds check hint to compiler; see golang.org/issue/14808
return uint64(b[0]) | uint64(b[1])<<8
case 4:
_ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24
case 8:
_ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 |
uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56
default:
panic("syscall: readInt with unsupported size")
}
}
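// Worked example (added for exposition; not part of the original source):
// for b = []byte{0x78, 0x56, 0x34, 0x12}, readIntLE(b, 4) returns
// 0x12345678 and readIntBE(b, 4) returns 0x78563412; readInt selects
// between the two based on isBigEndian.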
// ParseDirent parses up to max directory entries in buf,
// appending the names to names. It returns the number of
// bytes consumed from buf, the number of entries added
// to names, and the new names slice.
func ParseDirent(buf []byte, max int, names []string) (consumed int, count int, newnames []string) {
origlen := len(buf)
count = 0
for max != 0 && len(buf) > 0 {
reclen, ok := direntReclen(buf)
if !ok || reclen > uint64(len(buf)) {
return origlen, count, names
}
rec := buf[:reclen]
buf = buf[reclen:]
ino, ok := direntIno(rec)
if !ok {
break
}
if ino == 0 { // File absent in directory.
continue
}
const namoff = uint64(unsafe.Offsetof(Dirent{}.Name))
namlen, ok := direntNamlen(rec)
if !ok || namoff+namlen > uint64(len(rec)) {
break
}
name := rec[namoff : namoff+namlen]
for i, c := range name {
if c == 0 {
name = name[:i]
break
}
}
// Check for useless names before allocating a string.
if string(name) == "." || string(name) == ".." {
continue
}
max--
count++
names = append(names, string(name))
}
return origlen - len(buf), count, names
}
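// Illustrative sketch (added for exposition; not part of the original
// source): draining a directory by pairing ReadDirent with ParseDirent;
// fd is a hypothetical open directory descriptor.
//
//	var names []string
//	buf := make([]byte, 4096)
//	for {
//		n, err := ReadDirent(fd, buf)
//		if err != nil || n <= 0 {
//			break
//		}
//		_, _, names = ParseDirent(buf[:n], -1, names) // max = -1: no entry limit
//	}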
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || (js && wasm) || plan9
// Unix environment variables.
package syscall
import (
"runtime"
"sync"
)
var (
// envOnce guards initialization by copyenv, which populates env.
envOnce sync.Once
// envLock guards env and envs.
envLock sync.RWMutex
// env maps from an environment variable to its first occurrence in envs.
env map[string]int
// envs is provided by the runtime. Elements are expected to
// be of the form "key=value". An empty string means deleted
// (or a duplicate to be ignored).
envs []string = runtime_envs()
)
func runtime_envs() []string // in package runtime
func copyenv() {
env = make(map[string]int)
for i, s := range envs {
for j := 0; j < len(s); j++ {
if s[j] == '=' {
key := s[:j]
if _, ok := env[key]; !ok {
env[key] = i // first mention of key
} else {
// Clear duplicate keys. This permits Unsetenv to
// safely delete only the first item without
// worrying about unshadowing a later one,
// which might be a security problem.
envs[i] = ""
}
break
}
}
}
}
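// Worked example (added for exposition; not part of the original source):
// given envs = []string{"A=1", "B=2", "A=3"}, copyenv sets env["A"] = 0
// and env["B"] = 1 and blanks the duplicate by setting envs[2] = "", so
// a later Unsetenv("A") only needs to clear index 0.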
func Unsetenv(key string) error {
envOnce.Do(copyenv)
envLock.Lock()
defer envLock.Unlock()
if i, ok := env[key]; ok {
envs[i] = ""
delete(env, key)
}
runtimeUnsetenv(key)
return nil
}
func Getenv(key string) (value string, found bool) {
envOnce.Do(copyenv)
if len(key) == 0 {
return "", false
}
envLock.RLock()
defer envLock.RUnlock()
i, ok := env[key]
if !ok {
return "", false
}
s := envs[i]
for i := 0; i < len(s); i++ {
if s[i] == '=' {
return s[i+1:], true
}
}
return "", false
}
func Setenv(key, value string) error {
envOnce.Do(copyenv)
if len(key) == 0 {
return EINVAL
}
for i := 0; i < len(key); i++ {
if key[i] == '=' || key[i] == 0 {
return EINVAL
}
}
// On Plan 9, null is used as a separator, e.g. in $path.
if runtime.GOOS != "plan9" {
for i := 0; i < len(value); i++ {
if value[i] == 0 {
return EINVAL
}
}
}
envLock.Lock()
defer envLock.Unlock()
i, ok := env[key]
kv := key + "=" + value
if ok {
envs[i] = kv
} else {
i = len(envs)
envs = append(envs, kv)
}
env[key] = i
runtimeSetenv(key, value)
return nil
}
func Clearenv() {
envOnce.Do(copyenv) // prevent copyenv in Getenv/Setenv
envLock.Lock()
defer envLock.Unlock()
for k := range env {
runtimeUnsetenv(k)
}
env = make(map[string]int)
envs = []string{}
}
func Environ() []string {
envOnce.Do(copyenv)
envLock.RLock()
defer envLock.RUnlock()
a := make([]string, 0, len(envs))
for _, env := range envs {
if env != "" {
a = append(a, env)
}
}
return a
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build linux
package syscall
import (
"internal/itoa"
"runtime"
"unsafe"
)
// Linux unshare/clone/clone2/clone3 flags, architecture-independent,
// copied from linux/sched.h.
const (
CLONE_VM = 0x00000100 // set if VM shared between processes
CLONE_FS = 0x00000200 // set if fs info shared between processes
CLONE_FILES = 0x00000400 // set if open files shared between processes
CLONE_SIGHAND = 0x00000800 // set if signal handlers and blocked signals shared
CLONE_PIDFD = 0x00001000 // set if a pidfd should be placed in parent
CLONE_PTRACE = 0x00002000 // set if we want to let tracing continue on the child too
CLONE_VFORK = 0x00004000 // set if the parent wants the child to wake it up on mm_release
CLONE_PARENT = 0x00008000 // set if we want to have the same parent as the cloner
CLONE_THREAD = 0x00010000 // Same thread group?
CLONE_NEWNS = 0x00020000 // New mount namespace group
CLONE_SYSVSEM = 0x00040000 // share system V SEM_UNDO semantics
CLONE_SETTLS = 0x00080000 // create a new TLS for the child
CLONE_PARENT_SETTID = 0x00100000 // set the TID in the parent
CLONE_CHILD_CLEARTID = 0x00200000 // clear the TID in the child
CLONE_DETACHED = 0x00400000 // Unused, ignored
CLONE_UNTRACED = 0x00800000 // set if the tracing process can't force CLONE_PTRACE on this clone
CLONE_CHILD_SETTID = 0x01000000 // set the TID in the child
CLONE_NEWCGROUP = 0x02000000 // New cgroup namespace
CLONE_NEWUTS = 0x04000000 // New utsname namespace
CLONE_NEWIPC = 0x08000000 // New ipc namespace
CLONE_NEWUSER = 0x10000000 // New user namespace
CLONE_NEWPID = 0x20000000 // New pid namespace
CLONE_NEWNET = 0x40000000 // New network namespace
CLONE_IO = 0x80000000 // Clone io context
// Flags for the clone3() syscall.
CLONE_CLEAR_SIGHAND = 0x100000000 // Clear any signal handler and reset to SIG_DFL.
CLONE_INTO_CGROUP = 0x200000000 // Clone into a specific cgroup given the right permissions.
// Cloning flags intersect with CSIGNAL so can be used with unshare and clone3
// syscalls only:
CLONE_NEWTIME = 0x00000080 // New time namespace
)
// SysProcIDMap holds Container ID to Host ID mappings used for User Namespaces in Linux.
// See user_namespaces(7).
type SysProcIDMap struct {
ContainerID int // Container ID.
HostID int // Host ID.
Size int // Size of the mapped ID range.
}
type SysProcAttr struct {
Chroot string // Chroot.
Credential *Credential // Credential.
// Ptrace tells the child to call ptrace(PTRACE_TRACEME).
// Call runtime.LockOSThread before starting a process with this set,
// and don't call UnlockOSThread until done with PtraceSyscall calls.
Ptrace bool
Setsid bool // Create session.
// Setpgid sets the process group ID of the child to Pgid,
// or, if Pgid == 0, to the new child's process ID.
Setpgid bool
// Setctty sets the controlling terminal of the child to
// file descriptor Ctty. Ctty must be a descriptor number
// in the child process: an index into ProcAttr.Files.
// This is only meaningful if Setsid is true.
Setctty bool
Noctty bool // Detach fd 0 from controlling terminal
Ctty int // Controlling TTY fd
// Foreground places the child process group in the foreground.
// This implies Setpgid. The Ctty field must be set to
// the descriptor of the controlling TTY.
// Unlike Setctty, in this case Ctty must be a descriptor
// number in the parent process.
Foreground bool
Pgid int // Child's process group ID if Setpgid.
// Pdeathsig, if non-zero, is a signal that the kernel will send to
// the child process when the creating thread dies. Note that the signal
// is sent on thread termination, which may happen before process termination.
// There are more details at https://go.dev/issue/27505.
Pdeathsig Signal
Cloneflags uintptr // Flags for clone calls (Linux only)
Unshareflags uintptr // Flags for unshare calls (Linux only)
UidMappings []SysProcIDMap // User ID mappings for user namespaces.
GidMappings []SysProcIDMap // Group ID mappings for user namespaces.
// GidMappingsEnableSetgroups enables the setgroups syscall.
// If false, the setgroups syscall will be disabled for the child process.
// This parameter is a no-op if GidMappings == nil. Otherwise, for unprivileged
// users, this should be set to false for the mappings to work.
GidMappingsEnableSetgroups bool
AmbientCaps []uintptr // Ambient capabilities (Linux only)
UseCgroupFD bool // Whether to make use of the CgroupFD field.
CgroupFD int // File descriptor of a cgroup to put the new process into.
}
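// Illustrative sketch (added for exposition; not part of the original
// source): starting a child in new user and mount namespaces via os/exec,
// mapping the current user to root inside the namespace; the command is
// hypothetical.
//
//	cmd := exec.Command("/proc/self/exe", "child")
//	cmd.SysProcAttr = &syscall.SysProcAttr{
//		Cloneflags: syscall.CLONE_NEWUSER | syscall.CLONE_NEWNS,
//		UidMappings: []syscall.SysProcIDMap{
//			{ContainerID: 0, HostID: os.Getuid(), Size: 1},
//		},
//		GidMappings: []syscall.SysProcIDMap{
//			{ContainerID: 0, HostID: os.Getgid(), Size: 1},
//		},
//	}
//	err := cmd.Run()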
var (
none = [...]byte{'n', 'o', 'n', 'e', 0}
slash = [...]byte{'/', 0}
)
// Implemented in runtime package.
func runtime_BeforeFork()
func runtime_AfterFork()
func runtime_AfterForkInChild()
// Fork, dup fd onto 0..len(fd), and exec(argv0, argvv, envv) in child.
// If a dup or exec fails, write the errno error to pipe.
// (Pipe is close-on-exec so if exec succeeds, it will be closed.)
// In the child, this function must not acquire any locks, because
// they might have been locked at the time of the fork. This means
// no rescheduling, no malloc calls, and no new stack segments.
// For the same reason the compiler does not race-instrument it.
// The calls to RawSyscall are okay because they are assembly
// functions that do not grow the stack.
//
//go:norace
func forkAndExecInChild(argv0 *byte, argv, envv []*byte, chroot, dir *byte, attr *ProcAttr, sys *SysProcAttr, pipe int) (pid int, err Errno) {
// Set up and fork. This returns immediately in the parent or
// if there's an error.
upid, err, mapPipe, locked := forkAndExecInChild1(argv0, argv, envv, chroot, dir, attr, sys, pipe)
if locked {
runtime_AfterFork()
}
if err != 0 {
return 0, err
}
// parent; return PID
pid = int(upid)
if sys.UidMappings != nil || sys.GidMappings != nil {
Close(mapPipe[0])
var err2 Errno
// uid/gid mappings will be written after fork and unshare(2) for user
// namespaces.
if sys.Unshareflags&CLONE_NEWUSER == 0 {
if err := writeUidGidMappings(pid, sys); err != nil {
err2 = err.(Errno)
}
}
RawSyscall(SYS_WRITE, uintptr(mapPipe[1]), uintptr(unsafe.Pointer(&err2)), unsafe.Sizeof(err2))
Close(mapPipe[1])
}
return pid, 0
}
const _LINUX_CAPABILITY_VERSION_3 = 0x20080522
type capHeader struct {
version uint32
pid int32
}
type capData struct {
effective uint32
permitted uint32
inheritable uint32
}
type caps struct {
hdr capHeader
data [2]capData
}
// See CAP_TO_INDEX in linux/capability.h:
func capToIndex(cap uintptr) uintptr { return cap >> 5 }
// See CAP_TO_MASK in linux/capability.h:
func capToMask(cap uintptr) uint32 { return 1 << uint(cap&31) }
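// Worked example (added for exposition; not part of the original source):
// for cap = 38, capToIndex returns 38 >> 5 = 1 and capToMask returns
// 1 << (38 & 31) = 1 << 6, i.e. bit 6 of the second 32-bit word in
// caps.data.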
// cloneArgs holds arguments for clone3 Linux syscall.
type cloneArgs struct {
flags uint64 // Flags bit mask
pidFD uint64 // Where to store PID file descriptor (int *)
childTID uint64 // Where to store child TID, in child's memory (pid_t *)
parentTID uint64 // Where to store child TID, in parent's memory (pid_t *)
exitSignal uint64 // Signal to deliver to parent on child termination
stack uint64 // Pointer to lowest byte of stack
stackSize uint64 // Size of stack
tls uint64 // Location of new TLS
setTID uint64 // Pointer to a pid_t array (since Linux 5.5)
setTIDSize uint64 // Number of elements in set_tid (since Linux 5.5)
cgroup uint64 // File descriptor for target cgroup of child (since Linux 5.7)
}
// forkAndExecInChild1 implements the body of forkAndExecInChild up to
// the parent's post-fork path. This is a separate function so we can
// separate the child's and parent's stack frames if we're using
// vfork.
//
// This is go:noinline because the point is to keep the stack frames
// of this and forkAndExecInChild separate.
//
//go:noinline
//go:norace
func forkAndExecInChild1(argv0 *byte, argv, envv []*byte, chroot, dir *byte, attr *ProcAttr, sys *SysProcAttr, pipe int) (pid uintptr, err1 Errno, mapPipe [2]int, locked bool) {
// Defined in linux/prctl.h starting with Linux 4.3.
const (
PR_CAP_AMBIENT = 0x2f
PR_CAP_AMBIENT_RAISE = 0x2
)
// vfork requires that the child not touch any of the parent's
// active stack frames. Hence, the child does all post-fork
// processing in this stack frame and never returns, while the
// parent returns immediately from this frame and does all
// post-fork processing in the outer frame.
//
// Declare all variables at top in case any
// declarations require heap allocation (e.g., err2).
// ":=" should not be used to declare any variable after
// the call to runtime_BeforeFork.
//
// NOTE(bcmills): The allocation behavior described in the above comment
// seems to lack a corresponding test, and it may be rendered invalid
// by an otherwise-correct change in the compiler.
var (
err2 Errno
nextfd int
i int
caps caps
fd1, flags uintptr
puid, psetgroups, pgid []byte
uidmap, setgroups, gidmap []byte
clone3 *cloneArgs
pgrp int32
dirfd int
cred *Credential
ngroups, groups uintptr
c uintptr
)
if sys.UidMappings != nil {
puid = []byte("/proc/self/uid_map\000")
uidmap = formatIDMappings(sys.UidMappings)
}
if sys.GidMappings != nil {
psetgroups = []byte("/proc/self/setgroups\000")
pgid = []byte("/proc/self/gid_map\000")
if sys.GidMappingsEnableSetgroups {
setgroups = []byte("allow\000")
} else {
setgroups = []byte("deny\000")
}
gidmap = formatIDMappings(sys.GidMappings)
}
// Record parent PID so child can test if it has died.
ppid, _ := rawSyscallNoError(SYS_GETPID, 0, 0, 0)
// Guard against side effects of shuffling fds below.
// Make sure that nextfd is beyond any currently open files so
// that we can't run the risk of overwriting any of them.
fd := make([]int, len(attr.Files))
nextfd = len(attr.Files)
for i, ufd := range attr.Files {
if nextfd < int(ufd) {
nextfd = int(ufd)
}
fd[i] = int(ufd)
}
nextfd++
// Allocate another pipe for parent to child communication for
// synchronizing writing of User ID/Group ID mappings.
if sys.UidMappings != nil || sys.GidMappings != nil {
if err := forkExecPipe(mapPipe[:]); err != nil {
err1 = err.(Errno)
return
}
}
flags = sys.Cloneflags
if sys.Cloneflags&CLONE_NEWUSER == 0 && sys.Unshareflags&CLONE_NEWUSER == 0 {
flags |= CLONE_VFORK | CLONE_VM
}
// Whether to use clone3.
if sys.UseCgroupFD {
clone3 = &cloneArgs{
flags: uint64(flags) | CLONE_INTO_CGROUP,
exitSignal: uint64(SIGCHLD),
cgroup: uint64(sys.CgroupFD),
}
}
// About to call fork.
// No more allocation or calls of non-assembly functions.
runtime_BeforeFork()
locked = true
if clone3 != nil {
pid, err1 = rawVforkSyscall(_SYS_clone3, uintptr(unsafe.Pointer(clone3)), unsafe.Sizeof(*clone3))
} else {
flags |= uintptr(SIGCHLD)
if runtime.GOARCH == "s390x" {
// On Linux/s390, the first two arguments of clone(2) are swapped.
pid, err1 = rawVforkSyscall(SYS_CLONE, 0, flags)
} else {
pid, err1 = rawVforkSyscall(SYS_CLONE, flags, 0)
}
}
if err1 != 0 || pid != 0 {
// If we're in the parent, we must return immediately
// so we're not in the same stack frame as the child.
// This can at most use the return PC, which the child
// will not modify, and the results of
// rawVforkSyscall, which must have been written after
// the child was replaced.
return
}
// Fork succeeded, now in child.
// Enable the "keep capabilities" flag to set ambient capabilities later.
if len(sys.AmbientCaps) > 0 {
_, _, err1 = RawSyscall6(SYS_PRCTL, PR_SET_KEEPCAPS, 1, 0, 0, 0, 0)
if err1 != 0 {
goto childerror
}
}
// Wait for User ID/Group ID mappings to be written.
if sys.UidMappings != nil || sys.GidMappings != nil {
if _, _, err1 = RawSyscall(SYS_CLOSE, uintptr(mapPipe[1]), 0, 0); err1 != 0 {
goto childerror
}
pid, _, err1 = RawSyscall(SYS_READ, uintptr(mapPipe[0]), uintptr(unsafe.Pointer(&err2)), unsafe.Sizeof(err2))
if err1 != 0 {
goto childerror
}
if pid != unsafe.Sizeof(err2) {
err1 = EINVAL
goto childerror
}
if err2 != 0 {
err1 = err2
goto childerror
}
}
// Session ID
if sys.Setsid {
_, _, err1 = RawSyscall(SYS_SETSID, 0, 0, 0)
if err1 != 0 {
goto childerror
}
}
// Set process group
if sys.Setpgid || sys.Foreground {
// Place child in process group.
_, _, err1 = RawSyscall(SYS_SETPGID, 0, uintptr(sys.Pgid), 0)
if err1 != 0 {
goto childerror
}
}
if sys.Foreground {
pgrp = int32(sys.Pgid)
if pgrp == 0 {
pid, _ = rawSyscallNoError(SYS_GETPID, 0, 0, 0)
pgrp = int32(pid)
}
// Place process group in foreground.
_, _, err1 = RawSyscall(SYS_IOCTL, uintptr(sys.Ctty), uintptr(TIOCSPGRP), uintptr(unsafe.Pointer(&pgrp)))
if err1 != 0 {
goto childerror
}
}
// Restore the signal mask. We do this after TIOCSPGRP to avoid
// having the kernel send a SIGTTOU signal to the process group.
runtime_AfterForkInChild()
// Unshare
if sys.Unshareflags != 0 {
_, _, err1 = RawSyscall(SYS_UNSHARE, sys.Unshareflags, 0, 0)
if err1 != 0 {
goto childerror
}
if sys.Unshareflags&CLONE_NEWUSER != 0 && sys.GidMappings != nil {
dirfd = int(_AT_FDCWD)
if fd1, _, err1 = RawSyscall6(SYS_OPENAT, uintptr(dirfd), uintptr(unsafe.Pointer(&psetgroups[0])), uintptr(O_WRONLY), 0, 0, 0); err1 != 0 {
goto childerror
}
pid, _, err1 = RawSyscall(SYS_WRITE, uintptr(fd1), uintptr(unsafe.Pointer(&setgroups[0])), uintptr(len(setgroups)))
if err1 != 0 {
goto childerror
}
if _, _, err1 = RawSyscall(SYS_CLOSE, uintptr(fd1), 0, 0); err1 != 0 {
goto childerror
}
if fd1, _, err1 = RawSyscall6(SYS_OPENAT, uintptr(dirfd), uintptr(unsafe.Pointer(&pgid[0])), uintptr(O_WRONLY), 0, 0, 0); err1 != 0 {
goto childerror
}
pid, _, err1 = RawSyscall(SYS_WRITE, uintptr(fd1), uintptr(unsafe.Pointer(&gidmap[0])), uintptr(len(gidmap)))
if err1 != 0 {
goto childerror
}
if _, _, err1 = RawSyscall(SYS_CLOSE, uintptr(fd1), 0, 0); err1 != 0 {
goto childerror
}
}
if sys.Unshareflags&CLONE_NEWUSER != 0 && sys.UidMappings != nil {
dirfd = int(_AT_FDCWD)
if fd1, _, err1 = RawSyscall6(SYS_OPENAT, uintptr(dirfd), uintptr(unsafe.Pointer(&puid[0])), uintptr(O_WRONLY), 0, 0, 0); err1 != 0 {
goto childerror
}
pid, _, err1 = RawSyscall(SYS_WRITE, uintptr(fd1), uintptr(unsafe.Pointer(&uidmap[0])), uintptr(len(uidmap)))
if err1 != 0 {
goto childerror
}
if _, _, err1 = RawSyscall(SYS_CLOSE, uintptr(fd1), 0, 0); err1 != 0 {
goto childerror
}
}
// The unshare system call in Linux doesn't unshare mount points
// mounted with --shared. Systemd mounts / with --shared. For a
// long discussion of the pros and cons of this see debian bug 739593.
// The Go model of unsharing is more like Plan 9, where you ask
// to unshare and the namespaces are unconditionally unshared.
// To make this model work we must further mark / as MS_PRIVATE.
// This is what the standard unshare command does.
if sys.Unshareflags&CLONE_NEWNS == CLONE_NEWNS {
_, _, err1 = RawSyscall6(SYS_MOUNT, uintptr(unsafe.Pointer(&none[0])), uintptr(unsafe.Pointer(&slash[0])), 0, MS_REC|MS_PRIVATE, 0, 0)
if err1 != 0 {
goto childerror
}
}
}
// Chroot
if chroot != nil {
_, _, err1 = RawSyscall(SYS_CHROOT, uintptr(unsafe.Pointer(chroot)), 0, 0)
if err1 != 0 {
goto childerror
}
}
// User and groups
if cred = sys.Credential; cred != nil {
ngroups = uintptr(len(cred.Groups))
groups = uintptr(0)
if ngroups > 0 {
groups = uintptr(unsafe.Pointer(&cred.Groups[0]))
}
if !(sys.GidMappings != nil && !sys.GidMappingsEnableSetgroups && ngroups == 0) && !cred.NoSetGroups {
_, _, err1 = RawSyscall(_SYS_setgroups, ngroups, groups, 0)
if err1 != 0 {
goto childerror
}
}
_, _, err1 = RawSyscall(sys_SETGID, uintptr(cred.Gid), 0, 0)
if err1 != 0 {
goto childerror
}
_, _, err1 = RawSyscall(sys_SETUID, uintptr(cred.Uid), 0, 0)
if err1 != 0 {
goto childerror
}
}
if len(sys.AmbientCaps) != 0 {
// Ambient capabilities were added in the 4.3 kernel,
// so it is safe to always use _LINUX_CAPABILITY_VERSION_3.
caps.hdr.version = _LINUX_CAPABILITY_VERSION_3
if _, _, err1 = RawSyscall(SYS_CAPGET, uintptr(unsafe.Pointer(&caps.hdr)), uintptr(unsafe.Pointer(&caps.data[0])), 0); err1 != 0 {
goto childerror
}
for _, c = range sys.AmbientCaps {
// Add the c capability to the permitted and inheritable capability mask,
// otherwise we will not be able to add it to the ambient capability mask.
caps.data[capToIndex(c)].permitted |= capToMask(c)
caps.data[capToIndex(c)].inheritable |= capToMask(c)
}
if _, _, err1 = RawSyscall(SYS_CAPSET, uintptr(unsafe.Pointer(&caps.hdr)), uintptr(unsafe.Pointer(&caps.data[0])), 0); err1 != 0 {
goto childerror
}
for _, c = range sys.AmbientCaps {
_, _, err1 = RawSyscall6(SYS_PRCTL, PR_CAP_AMBIENT, uintptr(PR_CAP_AMBIENT_RAISE), c, 0, 0, 0)
if err1 != 0 {
goto childerror
}
}
}
// Chdir
if dir != nil {
_, _, err1 = RawSyscall(SYS_CHDIR, uintptr(unsafe.Pointer(dir)), 0, 0)
if err1 != 0 {
goto childerror
}
}
// Parent death signal
if sys.Pdeathsig != 0 {
_, _, err1 = RawSyscall6(SYS_PRCTL, PR_SET_PDEATHSIG, uintptr(sys.Pdeathsig), 0, 0, 0, 0)
if err1 != 0 {
goto childerror
}
// Signal self if parent is already dead. This might cause a
// duplicate signal in rare cases, but it won't matter when
// using SIGKILL.
pid, _ = rawSyscallNoError(SYS_GETPPID, 0, 0, 0)
if pid != ppid {
pid, _ = rawSyscallNoError(SYS_GETPID, 0, 0, 0)
_, _, err1 = RawSyscall(SYS_KILL, pid, uintptr(sys.Pdeathsig), 0)
if err1 != 0 {
goto childerror
}
}
}
// Pass 1: look for fd[i] < i and move those up above len(fd)
// so that pass 2 won't stomp on an fd it needs later.
if pipe < nextfd {
_, _, err1 = RawSyscall(SYS_DUP3, uintptr(pipe), uintptr(nextfd), O_CLOEXEC)
if err1 != 0 {
goto childerror
}
pipe = nextfd
nextfd++
}
for i = 0; i < len(fd); i++ {
if fd[i] >= 0 && fd[i] < i {
if nextfd == pipe { // don't stomp on pipe
nextfd++
}
_, _, err1 = RawSyscall(SYS_DUP3, uintptr(fd[i]), uintptr(nextfd), O_CLOEXEC)
if err1 != 0 {
goto childerror
}
fd[i] = nextfd
nextfd++
}
}
// Pass 2: dup fd[i] down onto i.
for i = 0; i < len(fd); i++ {
if fd[i] == -1 {
RawSyscall(SYS_CLOSE, uintptr(i), 0, 0)
continue
}
if fd[i] == i {
// dup2(i, i) won't clear close-on-exec flag on Linux,
// probably not elsewhere either.
_, _, err1 = RawSyscall(fcntl64Syscall, uintptr(fd[i]), F_SETFD, 0)
if err1 != 0 {
goto childerror
}
continue
}
// The new fd is created NOT close-on-exec,
// which is exactly what we want.
_, _, err1 = RawSyscall(SYS_DUP3, uintptr(fd[i]), uintptr(i), 0)
if err1 != 0 {
goto childerror
}
}
// By convention, we don't close-on-exec the fds we are
// started with, so if len(fd) < 3, close 0, 1, 2 as needed.
// Programs that know they inherit fds >= 3 will need
// to set them close-on-exec.
for i = len(fd); i < 3; i++ {
RawSyscall(SYS_CLOSE, uintptr(i), 0, 0)
}
// Detach fd 0 from tty
if sys.Noctty {
_, _, err1 = RawSyscall(SYS_IOCTL, 0, uintptr(TIOCNOTTY), 0)
if err1 != 0 {
goto childerror
}
}
// Set the controlling TTY to Ctty
if sys.Setctty {
_, _, err1 = RawSyscall(SYS_IOCTL, uintptr(sys.Ctty), uintptr(TIOCSCTTY), 1)
if err1 != 0 {
goto childerror
}
}
// Enable tracing if requested.
// Do this right before exec so that we don't unnecessarily trace the runtime
// setting up after the fork. See issue #21428.
if sys.Ptrace {
_, _, err1 = RawSyscall(SYS_PTRACE, uintptr(PTRACE_TRACEME), 0, 0)
if err1 != 0 {
goto childerror
}
}
// Time to exec.
_, _, err1 = RawSyscall(SYS_EXECVE,
uintptr(unsafe.Pointer(argv0)),
uintptr(unsafe.Pointer(&argv[0])),
uintptr(unsafe.Pointer(&envv[0])))
childerror:
// send error code on pipe
RawSyscall(SYS_WRITE, uintptr(pipe), uintptr(unsafe.Pointer(&err1)), unsafe.Sizeof(err1))
for {
RawSyscall(SYS_EXIT, 253, 0, 0)
}
}
// Try to open a pipe with O_CLOEXEC set on both file descriptors.
func forkExecPipe(p []int) (err error) {
return Pipe2(p, O_CLOEXEC)
}
func formatIDMappings(idMap []SysProcIDMap) []byte {
var data []byte
for _, im := range idMap {
data = append(data, itoa.Itoa(im.ContainerID)+" "+itoa.Itoa(im.HostID)+" "+itoa.Itoa(im.Size)+"\n"...)
}
return data
}
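// Worked example (added for exposition; not part of the original source):
// formatIDMappings([]SysProcIDMap{{ContainerID: 0, HostID: 1000, Size: 1}})
// produces the bytes "0 1000 1\n", the line format the kernel expects in
// /proc/PID/uid_map and /proc/PID/gid_map.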
// writeIDMappings writes the user namespace User ID or Group ID mappings to the specified path.
func writeIDMappings(path string, idMap []SysProcIDMap) error {
fd, err := Open(path, O_RDWR, 0)
if err != nil {
return err
}
if _, err := Write(fd, formatIDMappings(idMap)); err != nil {
Close(fd)
return err
}
if err := Close(fd); err != nil {
return err
}
return nil
}
// writeSetgroups writes to /proc/PID/setgroups "deny" if enable is false
// and "allow" if enable is true.
// This is needed since kernel 3.19, because you can't write gid_map without
// disabling the setgroups() system call.
func writeSetgroups(pid int, enable bool) error {
sgf := "/proc/" + itoa.Itoa(pid) + "/setgroups"
fd, err := Open(sgf, O_RDWR, 0)
if err != nil {
return err
}
var data []byte
if enable {
data = []byte("allow")
} else {
data = []byte("deny")
}
if _, err := Write(fd, data); err != nil {
Close(fd)
return err
}
return Close(fd)
}
// writeUidGidMappings writes the User ID and Group ID mappings for user namespaces
// for a process; it is called from the parent process.
func writeUidGidMappings(pid int, sys *SysProcAttr) error {
if sys.UidMappings != nil {
uidf := "/proc/" + itoa.Itoa(pid) + "/uid_map"
if err := writeIDMappings(uidf, sys.UidMappings); err != nil {
return err
}
}
if sys.GidMappings != nil {
// If the kernel is too old to support /proc/PID/setgroups, writeSetGroups will return ENOENT; this is OK.
if err := writeSetgroups(pid, sys.GidMappingsEnableSetgroups); err != nil && err != ENOENT {
return err
}
gidf := "/proc/" + itoa.Itoa(pid) + "/gid_map"
if err := writeIDMappings(gidf, sys.GidMappings); err != nil {
return err
}
}
return nil
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix
// Fork, exec, wait, etc.
package syscall
import (
errorspkg "errors"
"internal/bytealg"
"runtime"
"sync"
"unsafe"
)
// Lock synchronizing creation of new file descriptors with fork.
//
// We want the child in a fork/exec sequence to inherit only the
// file descriptors we intend. To do that, we mark all file
// descriptors close-on-exec and then, in the child, explicitly
// unmark the ones we want the exec'ed program to keep.
// Unix doesn't make this easy: there is, in general, no way to
// allocate a new file descriptor close-on-exec. Instead you
// have to allocate the descriptor and then mark it close-on-exec.
// If a fork happens between those two events, the child's exec
// will inherit an unwanted file descriptor.
//
// This lock solves that race: the create-new-fd/mark-close-on-exec
// operation is done holding ForkLock for reading, and the fork itself
// is done holding ForkLock for writing. At least, that's the idea.
// There are some complications.
//
// Some system calls that create new file descriptors can block
// for arbitrarily long times: open on a hung NFS server or named
// pipe, accept on a socket, and so on. We can't reasonably grab
// the lock across those operations.
//
// It is worse to inherit some file descriptors than others.
// If a non-malicious child accidentally inherits an open ordinary file,
// that's not a big deal. On the other hand, if a long-lived child
// accidentally inherits the write end of a pipe, then the reader
// of that pipe will not see EOF until that child exits, potentially
// causing the parent program to hang. This is a common problem
// in threaded C programs that use popen.
//
// Luckily, the file descriptors that are most important not to
// inherit are not the ones that can take an arbitrarily long time
// to create: pipe returns instantly, and the net package uses
// non-blocking I/O to accept on a listening socket.
// The rules for which file descriptor-creating operations use the
// ForkLock are as follows:
//
// 1) Pipe. Does not block. Use the ForkLock.
// 2) Socket. Does not block. Use the ForkLock.
// 3) Accept. If using non-blocking mode, use the ForkLock.
// Otherwise, live with the race.
// 4) Open. Can block. Use O_CLOEXEC if available (Linux).
// Otherwise, live with the race.
// 5) Dup. Does not block. Use the ForkLock.
// On Linux, could use fcntl F_DUPFD_CLOEXEC
// instead of the ForkLock, but only for dup(fd, -1).
var ForkLock sync.RWMutex
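// As a minimal sketch, rule 5 above looks like this in practice: the
// descriptor is duplicated and marked close-on-exec while ForkLock is held
// for reading, so no fork can land between the two steps:
//
//	ForkLock.RLock()
//	nfd, err := Dup(fd)
//	if err == nil {
//		CloseOnExec(nfd)
//	}
//	ForkLock.RUnlock()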
// StringSlicePtr converts a slice of strings to a slice of pointers
// to NUL-terminated byte arrays. If any string contains a NUL byte,
// this function panics instead of returning an error.
//
// Deprecated: Use SlicePtrFromStrings instead.
func StringSlicePtr(ss []string) []*byte {
bb := make([]*byte, len(ss)+1)
for i := 0; i < len(ss); i++ {
bb[i] = StringBytePtr(ss[i])
}
bb[len(ss)] = nil
return bb
}
// SlicePtrFromStrings converts a slice of strings to a slice of
// pointers to NUL-terminated byte arrays. If any string contains
// a NUL byte, it returns (nil, EINVAL).
func SlicePtrFromStrings(ss []string) ([]*byte, error) {
n := 0
for _, s := range ss {
if bytealg.IndexByteString(s, 0) != -1 {
return nil, EINVAL
}
n += len(s) + 1 // +1 for NUL
}
bb := make([]*byte, len(ss)+1)
b := make([]byte, n)
n = 0
for i, s := range ss {
bb[i] = &b[n]
copy(b[n:], s)
n += len(s) + 1
}
return bb, nil
}
func CloseOnExec(fd int) { fcntl(fd, F_SETFD, FD_CLOEXEC) }
func SetNonblock(fd int, nonblocking bool) (err error) {
flag, err := fcntl(fd, F_GETFL, 0)
if err != nil {
return err
}
if nonblocking {
flag |= O_NONBLOCK
} else {
flag &^= O_NONBLOCK
}
_, err = fcntl(fd, F_SETFL, flag)
return err
}
// Credential holds user and group identities to be assumed
// by a child process started by StartProcess.
type Credential struct {
Uid uint32 // User ID.
Gid uint32 // Group ID.
Groups []uint32 // Supplementary group IDs.
NoSetGroups bool // If true, don't set supplementary groups
}
// ProcAttr holds attributes that will be applied to a new process started
// by StartProcess.
type ProcAttr struct {
Dir string // Current working directory.
Env []string // Environment.
Files []uintptr // File descriptors.
Sys *SysProcAttr
}
var zeroProcAttr ProcAttr
var zeroSysProcAttr SysProcAttr
func forkExec(argv0 string, argv []string, attr *ProcAttr) (pid int, err error) {
var p [2]int
var n int
var err1 Errno
var wstatus WaitStatus
if attr == nil {
attr = &zeroProcAttr
}
sys := attr.Sys
if sys == nil {
sys = &zeroSysProcAttr
}
// Convert args to C form.
argv0p, err := BytePtrFromString(argv0)
if err != nil {
return 0, err
}
argvp, err := SlicePtrFromStrings(argv)
if err != nil {
return 0, err
}
envvp, err := SlicePtrFromStrings(attr.Env)
if err != nil {
return 0, err
}
if (runtime.GOOS == "freebsd" || runtime.GOOS == "dragonfly") && len(argv[0]) > len(argv0) {
argvp[0] = argv0p
}
var chroot *byte
if sys.Chroot != "" {
chroot, err = BytePtrFromString(sys.Chroot)
if err != nil {
return 0, err
}
}
var dir *byte
if attr.Dir != "" {
dir, err = BytePtrFromString(attr.Dir)
if err != nil {
return 0, err
}
}
// Both Setctty and Foreground use the Ctty field,
// but they give it slightly different meanings.
if sys.Setctty && sys.Foreground {
return 0, errorspkg.New("both Setctty and Foreground set in SysProcAttr")
}
if sys.Setctty && sys.Ctty >= len(attr.Files) {
return 0, errorspkg.New("Setctty set but Ctty not valid in child")
}
// Acquire the fork lock so that no other threads
// create new fds that are not yet close-on-exec
// before we fork.
ForkLock.Lock()
// Allocate child status pipe close on exec.
if err = forkExecPipe(p[:]); err != nil {
ForkLock.Unlock()
return 0, err
}
// Kick off child.
pid, err1 = forkAndExecInChild(argv0p, argvp, envvp, chroot, dir, attr, sys, p[1])
if err1 != 0 {
Close(p[0])
Close(p[1])
ForkLock.Unlock()
return 0, Errno(err1)
}
ForkLock.Unlock()
// Read child error status from pipe.
Close(p[1])
for {
n, err = readlen(p[0], (*byte)(unsafe.Pointer(&err1)), int(unsafe.Sizeof(err1)))
if err != EINTR {
break
}
}
Close(p[0])
if err != nil || n != 0 {
if n == int(unsafe.Sizeof(err1)) {
err = Errno(err1)
}
if err == nil {
err = EPIPE
}
// Child failed; wait for it to exit, to make sure
// the zombies don't accumulate.
_, err1 := Wait4(pid, &wstatus, 0, nil)
for err1 == EINTR {
_, err1 = Wait4(pid, &wstatus, 0, nil)
}
return 0, err
}
// Read got EOF, so pipe closed on exec, so exec succeeded.
return pid, nil
}
// Combination of fork and exec, careful to be thread safe.
func ForkExec(argv0 string, argv []string, attr *ProcAttr) (pid int, err error) {
return forkExec(argv0, argv, attr)
}
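// A minimal usage sketch, starting /bin/true with stdin, stdout, and stderr
// inherited (error handling elided):
//
//	pid, err := ForkExec("/bin/true", []string{"true"}, &ProcAttr{
//		Files: []uintptr{0, 1, 2},
//	})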
// StartProcess wraps ForkExec for package os.
func StartProcess(argv0 string, argv []string, attr *ProcAttr) (pid int, handle uintptr, err error) {
pid, err = forkExec(argv0, argv, attr)
return pid, 0, err
}
// Implemented in runtime package.
func runtime_BeforeExec()
func runtime_AfterExec()
// execveLibc is non-nil on OSes that use libc-based syscalls; it is set to
// execve in exec_libc.go. This avoids a build dependency for other platforms.
var execveLibc func(path uintptr, argv uintptr, envp uintptr) Errno
var execveDarwin func(path *byte, argv **byte, envp **byte) error
var execveOpenBSD func(path *byte, argv **byte, envp **byte) error
// Exec invokes the execve(2) system call.
func Exec(argv0 string, argv []string, envv []string) (err error) {
argv0p, err := BytePtrFromString(argv0)
if err != nil {
return err
}
argvp, err := SlicePtrFromStrings(argv)
if err != nil {
return err
}
envvp, err := SlicePtrFromStrings(envv)
if err != nil {
return err
}
runtime_BeforeExec()
var err1 error
if runtime.GOOS == "solaris" || runtime.GOOS == "illumos" || runtime.GOOS == "aix" {
// RawSyscall should never be used on Solaris, illumos, or AIX.
err1 = execveLibc(
uintptr(unsafe.Pointer(argv0p)),
uintptr(unsafe.Pointer(&argvp[0])),
uintptr(unsafe.Pointer(&envvp[0])))
} else if runtime.GOOS == "darwin" || runtime.GOOS == "ios" {
// Similarly on Darwin.
err1 = execveDarwin(argv0p, &argvp[0], &envvp[0])
} else if runtime.GOOS == "openbsd" && (runtime.GOARCH == "386" || runtime.GOARCH == "amd64" || runtime.GOARCH == "arm" || runtime.GOARCH == "arm64") {
// Similarly on OpenBSD.
err1 = execveOpenBSD(argv0p, &argvp[0], &envvp[0])
} else {
_, _, err1 = RawSyscall(SYS_EXECVE,
uintptr(unsafe.Pointer(argv0p)),
uintptr(unsafe.Pointer(&argvp[0])),
uintptr(unsafe.Pointer(&envvp[0])))
}
runtime_AfterExec()
return err1
}
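// On success Exec does not return: the process image is replaced. A minimal
// sketch, assuming /bin/ls exists:
//
//	err := Exec("/bin/ls", []string{"ls", "-l"}, Environ())
//	// Reached only if the execve failed.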
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build linux || freebsd || openbsd || netbsd || dragonfly
package syscall
import "unsafe"
// fcntl64Syscall is usually SYS_FCNTL, but is overridden on 32-bit Linux
// systems by flock_linux_32bit.go to be SYS_FCNTL64.
var fcntl64Syscall uintptr = SYS_FCNTL
// FcntlFlock performs a fcntl syscall for the F_GETLK, F_SETLK or F_SETLKW command.
func FcntlFlock(fd uintptr, cmd int, lk *Flock_t) error {
_, _, errno := Syscall(fcntl64Syscall, fd, uintptr(cmd), uintptr(unsafe.Pointer(lk)))
if errno == 0 {
return nil
}
return errno
}
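// A minimal usage sketch, placing a non-blocking exclusive lock on the whole
// file referred to by fd (Start and Len of zero mean the entire file):
//
//	lk := Flock_t{Type: F_WRLCK, Whence: 0, Start: 0, Len: 0}
//	err := FcntlFlock(uintptr(fd), F_SETLK, &lk)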
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Linux socket filter
package syscall
import (
"unsafe"
)
// Deprecated: Use golang.org/x/net/bpf instead.
func LsfStmt(code, k int) *SockFilter {
return &SockFilter{Code: uint16(code), K: uint32(k)}
}
// Deprecated: Use golang.org/x/net/bpf instead.
func LsfJump(code, k, jt, jf int) *SockFilter {
return &SockFilter{Code: uint16(code), Jt: uint8(jt), Jf: uint8(jf), K: uint32(k)}
}
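// A minimal sketch of a filter built with these helpers: a single BPF_RET
// statement that accepts up to 0xffff bytes of every packet, attached with
// AttachLsf below (fd is assumed to be a packet socket):
//
//	prog := []SockFilter{
//		*LsfStmt(BPF_RET|BPF_K, 0xffff),
//	}
//	err := AttachLsf(fd, prog)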
// Deprecated: Use golang.org/x/net/bpf instead.
func LsfSocket(ifindex, proto int) (int, error) {
var lsall SockaddrLinklayer
// This is missing SOCK_CLOEXEC, but adding the flag
// could break callers.
s, e := Socket(AF_PACKET, SOCK_RAW, proto)
if e != nil {
return 0, e
}
p := (*[2]byte)(unsafe.Pointer(&lsall.Protocol))
p[0] = byte(proto >> 8)
p[1] = byte(proto)
lsall.Ifindex = ifindex
e = Bind(s, &lsall)
if e != nil {
Close(s)
return 0, e
}
return s, nil
}
type iflags struct {
name [IFNAMSIZ]byte
flags uint16
}
// Deprecated: Use golang.org/x/net/bpf instead.
func SetLsfPromisc(name string, m bool) error {
s, e := Socket(AF_INET, SOCK_DGRAM|SOCK_CLOEXEC, 0)
if e != nil {
return e
}
defer Close(s)
var ifl iflags
copy(ifl.name[:], []byte(name))
_, _, ep := Syscall(SYS_IOCTL, uintptr(s), SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifl)))
if ep != 0 {
return Errno(ep)
}
if m {
ifl.flags |= uint16(IFF_PROMISC)
} else {
ifl.flags &^= uint16(IFF_PROMISC)
}
_, _, ep = Syscall(SYS_IOCTL, uintptr(s), SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifl)))
if ep != 0 {
return Errno(ep)
}
return nil
}
// Deprecated: Use golang.org/x/net/bpf instead.
func AttachLsf(fd int, i []SockFilter) error {
var p SockFprog
p.Len = uint16(len(i))
p.Filter = (*SockFilter)(unsafe.Pointer(&i[0]))
return setsockopt(fd, SOL_SOCKET, SO_ATTACH_FILTER, unsafe.Pointer(&p), unsafe.Sizeof(p))
}
// Deprecated: Use golang.org/x/net/bpf instead.
func DetachLsf(fd int) error {
var dummy int
return setsockopt(fd, SOL_SOCKET, SO_DETACH_FILTER, unsafe.Pointer(&dummy), unsafe.Sizeof(dummy))
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !msan
package syscall
import (
"unsafe"
)
const msanenabled = false
func msanRead(addr unsafe.Pointer, len int) {
}
func msanWrite(addr unsafe.Pointer, len int) {
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Netlink sockets and messages
package syscall
import "unsafe"
// Round the length of a netlink message up to align it properly.
func nlmAlignOf(msglen int) int {
return (msglen + NLMSG_ALIGNTO - 1) & ^(NLMSG_ALIGNTO - 1)
}
// Round the length of a netlink route attribute up to align it
// properly.
func rtaAlignOf(attrlen int) int {
return (attrlen + RTA_ALIGNTO - 1) & ^(RTA_ALIGNTO - 1)
}
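// With NLMSG_ALIGNTO == 4, for example, nlmAlignOf(17) == 20 while
// nlmAlignOf(20) == 20: lengths are rounded up to the next multiple of the
// alignment, and already-aligned values are unchanged.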
// NetlinkRouteRequest represents a request message to receive routing
// and link states from the kernel.
type NetlinkRouteRequest struct {
Header NlMsghdr
Data RtGenmsg
}
func (rr *NetlinkRouteRequest) toWireFormat() []byte {
b := make([]byte, rr.Header.Len)
*(*uint32)(unsafe.Pointer(&b[0:4][0])) = rr.Header.Len
*(*uint16)(unsafe.Pointer(&b[4:6][0])) = rr.Header.Type
*(*uint16)(unsafe.Pointer(&b[6:8][0])) = rr.Header.Flags
*(*uint32)(unsafe.Pointer(&b[8:12][0])) = rr.Header.Seq
*(*uint32)(unsafe.Pointer(&b[12:16][0])) = rr.Header.Pid
b[16] = byte(rr.Data.Family)
return b
}
func newNetlinkRouteRequest(proto, seq, family int) []byte {
rr := &NetlinkRouteRequest{}
rr.Header.Len = uint32(NLMSG_HDRLEN + SizeofRtGenmsg)
rr.Header.Type = uint16(proto)
rr.Header.Flags = NLM_F_DUMP | NLM_F_REQUEST
rr.Header.Seq = uint32(seq)
rr.Data.Family = uint8(family)
return rr.toWireFormat()
}
// NetlinkRIB returns the routing information base, also known as the RIB,
// which consists of network facility information, states, and parameters.
func NetlinkRIB(proto, family int) ([]byte, error) {
s, err := Socket(AF_NETLINK, SOCK_RAW|SOCK_CLOEXEC, NETLINK_ROUTE)
if err != nil {
return nil, err
}
defer Close(s)
sa := &SockaddrNetlink{Family: AF_NETLINK}
if err := Bind(s, sa); err != nil {
return nil, err
}
wb := newNetlinkRouteRequest(proto, 1, family)
if err := Sendto(s, wb, 0, sa); err != nil {
return nil, err
}
lsa, err := Getsockname(s)
if err != nil {
return nil, err
}
lsanl, ok := lsa.(*SockaddrNetlink)
if !ok {
return nil, EINVAL
}
var tab []byte
rbNew := make([]byte, Getpagesize())
done:
for {
rb := rbNew
nr, _, err := Recvfrom(s, rb, 0)
if err != nil {
return nil, err
}
if nr < NLMSG_HDRLEN {
return nil, EINVAL
}
rb = rb[:nr]
tab = append(tab, rb...)
msgs, err := ParseNetlinkMessage(rb)
if err != nil {
return nil, err
}
for _, m := range msgs {
if m.Header.Seq != 1 || m.Header.Pid != lsanl.Pid {
return nil, EINVAL
}
if m.Header.Type == NLMSG_DONE {
break done
}
if m.Header.Type == NLMSG_ERROR {
return nil, EINVAL
}
}
}
return tab, nil
}
// NetlinkMessage represents a netlink message.
type NetlinkMessage struct {
Header NlMsghdr
Data []byte
}
// ParseNetlinkMessage parses b as an array of netlink messages and
// returns the slice containing the NetlinkMessage structures.
func ParseNetlinkMessage(b []byte) ([]NetlinkMessage, error) {
var msgs []NetlinkMessage
for len(b) >= NLMSG_HDRLEN {
h, dbuf, dlen, err := netlinkMessageHeaderAndData(b)
if err != nil {
return nil, err
}
m := NetlinkMessage{Header: *h, Data: dbuf[:int(h.Len)-NLMSG_HDRLEN]}
msgs = append(msgs, m)
b = b[dlen:]
}
return msgs, nil
}
func netlinkMessageHeaderAndData(b []byte) (*NlMsghdr, []byte, int, error) {
h := (*NlMsghdr)(unsafe.Pointer(&b[0]))
l := nlmAlignOf(int(h.Len))
if int(h.Len) < NLMSG_HDRLEN || l > len(b) {
return nil, nil, 0, EINVAL
}
return h, b[NLMSG_HDRLEN:], l, nil
}
// NetlinkRouteAttr represents a netlink route attribute.
type NetlinkRouteAttr struct {
Attr RtAttr
Value []byte
}
// ParseNetlinkRouteAttr parses m's payload as an array of netlink
// route attributes and returns the slice containing the
// NetlinkRouteAttr structures.
func ParseNetlinkRouteAttr(m *NetlinkMessage) ([]NetlinkRouteAttr, error) {
var b []byte
switch m.Header.Type {
case RTM_NEWLINK, RTM_DELLINK:
b = m.Data[SizeofIfInfomsg:]
case RTM_NEWADDR, RTM_DELADDR:
b = m.Data[SizeofIfAddrmsg:]
case RTM_NEWROUTE, RTM_DELROUTE:
b = m.Data[SizeofRtMsg:]
default:
return nil, EINVAL
}
var attrs []NetlinkRouteAttr
for len(b) >= SizeofRtAttr {
a, vbuf, alen, err := netlinkRouteAttrAndValue(b)
if err != nil {
return nil, err
}
ra := NetlinkRouteAttr{Attr: *a, Value: vbuf[:int(a.Len)-SizeofRtAttr]}
attrs = append(attrs, ra)
b = b[alen:]
}
return attrs, nil
}
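// A minimal sketch tying these together: dump all links and read the
// interface name attribute of each RTM_NEWLINK message (error handling
// elided):
//
//	tab, _ := NetlinkRIB(RTM_GETLINK, AF_UNSPEC)
//	msgs, _ := ParseNetlinkMessage(tab)
//	for _, m := range msgs {
//		if m.Header.Type != RTM_NEWLINK {
//			continue
//		}
//		attrs, _ := ParseNetlinkRouteAttr(&m)
//		for _, a := range attrs {
//			if a.Attr.Type == IFLA_IFNAME {
//				_ = string(a.Value) // NUL-terminated interface name
//			}
//		}
//	}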
func netlinkRouteAttrAndValue(b []byte) (*RtAttr, []byte, int, error) {
a := (*RtAttr)(unsafe.Pointer(&b[0]))
if int(a.Len) < SizeofRtAttr || int(a.Len) > len(b) {
return nil, nil, 0, EINVAL
}
return a, b[SizeofRtAttr:], rtaAlignOf(int(a.Len)), nil
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Socket control messages
package syscall
import "unsafe"
// UnixCredentials encodes credentials into a socket control message
// for sending to another process. This can be used for
// authentication.
func UnixCredentials(ucred *Ucred) []byte {
b := make([]byte, CmsgSpace(SizeofUcred))
h := (*Cmsghdr)(unsafe.Pointer(&b[0]))
h.Level = SOL_SOCKET
h.Type = SCM_CREDENTIALS
h.SetLen(CmsgLen(SizeofUcred))
*(*Ucred)(h.data(0)) = *ucred
return b
}
// ParseUnixCredentials decodes a socket control message that contains
// credentials in a Ucred structure. To receive such a message, the
// SO_PASSCRED option must be enabled on the socket.
func ParseUnixCredentials(m *SocketControlMessage) (*Ucred, error) {
if m.Header.Level != SOL_SOCKET {
return nil, EINVAL
}
if m.Header.Type != SCM_CREDENTIALS {
return nil, EINVAL
}
if uintptr(len(m.Data)) < unsafe.Sizeof(Ucred{}) {
return nil, EINVAL
}
ucred := *(*Ucred)(unsafe.Pointer(&m.Data[0]))
return &ucred, nil
}
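// A minimal sketch of the sending side, assuming fd is a connected Unix
// domain socket and the receiver has enabled SO_PASSCRED (error handling
// elided):
//
//	ucred := &Ucred{Pid: int32(Getpid()), Uid: uint32(Getuid()), Gid: uint32(Getgid())}
//	oob := UnixCredentials(ucred)
//	err := Sendmsg(fd, []byte{0}, oob, nil, 0)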
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix
// Socket control messages
package syscall
import (
"unsafe"
)
// CmsgLen returns the value to store in the Len field of the Cmsghdr
// structure, taking into account any necessary alignment.
func CmsgLen(datalen int) int {
return cmsgAlignOf(SizeofCmsghdr) + datalen
}
// CmsgSpace returns the number of bytes an ancillary element with
// payload of the passed data length occupies.
func CmsgSpace(datalen int) int {
return cmsgAlignOf(SizeofCmsghdr) + cmsgAlignOf(datalen)
}
func (h *Cmsghdr) data(offset uintptr) unsafe.Pointer {
return unsafe.Pointer(uintptr(unsafe.Pointer(h)) + uintptr(cmsgAlignOf(SizeofCmsghdr)) + offset)
}
// SocketControlMessage represents a socket control message.
type SocketControlMessage struct {
Header Cmsghdr
Data []byte
}
// ParseSocketControlMessage parses b as an array of socket control
// messages.
func ParseSocketControlMessage(b []byte) ([]SocketControlMessage, error) {
var msgs []SocketControlMessage
i := 0
for i+CmsgLen(0) <= len(b) {
h, dbuf, err := socketControlMessageHeaderAndData(b[i:])
if err != nil {
return nil, err
}
m := SocketControlMessage{Header: *h, Data: dbuf}
msgs = append(msgs, m)
i += cmsgAlignOf(int(h.Len))
}
return msgs, nil
}
func socketControlMessageHeaderAndData(b []byte) (*Cmsghdr, []byte, error) {
h := (*Cmsghdr)(unsafe.Pointer(&b[0]))
if h.Len < SizeofCmsghdr || uint64(h.Len) > uint64(len(b)) {
return nil, nil, EINVAL
}
return h, b[cmsgAlignOf(SizeofCmsghdr):h.Len], nil
}
// UnixRights encodes a set of open file descriptors into a socket
// control message for sending to another process.
func UnixRights(fds ...int) []byte {
datalen := len(fds) * 4
b := make([]byte, CmsgSpace(datalen))
h := (*Cmsghdr)(unsafe.Pointer(&b[0]))
h.Level = SOL_SOCKET
h.Type = SCM_RIGHTS
h.SetLen(CmsgLen(datalen))
for i, fd := range fds {
*(*int32)(h.data(4 * uintptr(i))) = int32(fd)
}
return b
}
// ParseUnixRights decodes a socket control message that contains an
// integer array of open file descriptors from another process.
func ParseUnixRights(m *SocketControlMessage) ([]int, error) {
if m.Header.Level != SOL_SOCKET {
return nil, EINVAL
}
if m.Header.Type != SCM_RIGHTS {
return nil, EINVAL
}
fds := make([]int, len(m.Data)>>2)
for i, j := 0, 0; i < len(m.Data); i += 4 {
fds[j] = int(*(*int32)(unsafe.Pointer(&m.Data[i])))
j++
}
return fds, nil
}
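// A minimal sketch of the receiving side for a descriptor sent with
// UnixRights, assuming fd is a connected Unix domain socket (error handling
// elided):
//
//	buf := make([]byte, 1)
//	oob := make([]byte, CmsgSpace(4)) // room for one int32 descriptor
//	_, oobn, _, _, _ := Recvmsg(fd, buf, oob, 0)
//	msgs, _ := ParseSocketControlMessage(oob[:oobn])
//	fds, _ := ParseUnixRights(&msgs[0])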
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix || darwin || freebsd || linux || netbsd || openbsd || solaris
package syscall
import (
"runtime"
)
// Round the length of a raw sockaddr up to align it properly.
func cmsgAlignOf(salen int) int {
salign := sizeofPtr
// dragonfly needs to check ABI version at runtime, see cmsgAlignOf in
// sockcmsg_dragonfly.go
switch runtime.GOOS {
case "aix":
// There is no alignment on AIX.
salign = 1
case "darwin", "ios", "illumos", "solaris":
// NOTE: It seems like 64-bit Darwin, Illumos, and Solaris
// kernels still require 32-bit aligned access to the network
// subsystem.
if sizeofPtr == 8 {
salign = 4
}
case "netbsd", "openbsd":
// NetBSD and OpenBSD armv7 require 64-bit alignment.
if runtime.GOARCH == "arm" {
salign = 8
}
// NetBSD aarch64 requires 128-bit alignment.
if runtime.GOOS == "netbsd" && runtime.GOARCH == "arm64" {
salign = 16
}
}
return (salen + salign - 1) & ^(salign - 1)
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package syscall contains an interface to the low-level operating system
// primitives. The details vary depending on the underlying system, and
// by default, godoc will display the syscall documentation for the current
// system. If you want godoc to display syscall documentation for another
// system, set $GOOS and $GOARCH to the desired system. For example, if
// you want to view documentation for freebsd/arm on linux/amd64, set $GOOS
// to freebsd and $GOARCH to arm.
// The primary use of syscall is inside other packages that provide a more
// portable interface to the system, such as "os", "time" and "net". Use
// those packages rather than this one if you can.
// For details of the functions and data types in this package consult
// the manuals for the appropriate operating system.
// These calls return err == nil to indicate success; otherwise
// err is an operating system error describing the failure.
// On most systems, that error has type syscall.Errno.
//
// Deprecated: this package is locked down. Callers should use the
// corresponding package in the golang.org/x/sys repository instead.
// That is also where updates required by new systems or versions
// should be applied. See https://golang.org/s/go1.4-syscall for more
// information.
package syscall
import "internal/bytealg"
//go:generate go run ./mksyscall_windows.go -systemdll -output zsyscall_windows.go syscall_windows.go security_windows.go
// StringByteSlice converts a string to a NUL-terminated []byte.
// If s contains a NUL byte, this function panics instead of
// returning an error.
//
// Deprecated: Use ByteSliceFromString instead.
func StringByteSlice(s string) []byte {
a, err := ByteSliceFromString(s)
if err != nil {
panic("syscall: string with NUL passed to StringByteSlice")
}
return a
}
// ByteSliceFromString returns a NUL-terminated slice of bytes
// containing the text of s. If s contains a NUL byte at any
// location, it returns (nil, EINVAL).
func ByteSliceFromString(s string) ([]byte, error) {
if bytealg.IndexByteString(s, 0) != -1 {
return nil, EINVAL
}
a := make([]byte, len(s)+1)
copy(a, s)
return a, nil
}
// StringBytePtr returns a pointer to a NUL-terminated array of bytes.
// If s contains a NUL byte this function panics instead of returning
// an error.
//
// Deprecated: Use BytePtrFromString instead.
func StringBytePtr(s string) *byte { return &StringByteSlice(s)[0] }
// BytePtrFromString returns a pointer to a NUL-terminated array of
// bytes containing the text of s. If s contains a NUL byte at any
// location, it returns (nil, EINVAL).
func BytePtrFromString(s string) (*byte, error) {
a, err := ByteSliceFromString(s)
if err != nil {
return nil, err
}
return &a[0], nil
}
// Single-word zero for use when we need a valid pointer to 0 bytes.
// See mksyscall.pl.
var _zero uintptr
// Unix returns the time stored in ts as seconds plus nanoseconds.
func (ts *Timespec) Unix() (sec int64, nsec int64) {
return int64(ts.Sec), int64(ts.Nsec)
}
// Unix returns the time stored in tv as seconds plus nanoseconds.
func (tv *Timeval) Unix() (sec int64, nsec int64) {
return int64(tv.Sec), int64(tv.Usec) * 1000
}
// Nano returns the time stored in ts as nanoseconds.
func (ts *Timespec) Nano() int64 {
return int64(ts.Sec)*1e9 + int64(ts.Nsec)
}
// Nano returns the time stored in tv as nanoseconds.
func (tv *Timeval) Nano() int64 {
return int64(tv.Sec)*1e9 + int64(tv.Usec)*1000
}
// Getpagesize and Exit are provided by the runtime.
func Getpagesize() int
func Exit(code int)
// runtimeSetenv and runtimeUnsetenv are provided by the runtime.
func runtimeSetenv(k, v string)
func runtimeUnsetenv(k string)
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Linux system calls.
// This file is compiled as ordinary Go code,
// but it is also input to mksyscall,
// which parses the //sys lines and generates system call stubs.
// Note that sometimes we use a lowercase //sys name and
// wrap it in our own nicer implementation.
package syscall
import (
"internal/itoa"
"runtime"
"unsafe"
)
// N.B. RawSyscall6 is provided via linkname by runtime/internal/syscall.
//
// Errno is uintptr and thus compatible with the runtime/internal/syscall
// definition.
func RawSyscall6(trap, a1, a2, a3, a4, a5, a6 uintptr) (r1, r2 uintptr, err Errno)
// Pull in entersyscall/exitsyscall for Syscall/Syscall6.
//
// Note that this can't be a push linkname because the runtime already has a
// nameless linkname to export to assembly here and in x/sys. Additionally,
// entersyscall fetches the caller PC and SP and thus can't have a wrapper
// in between.
//go:linkname runtime_entersyscall runtime.entersyscall
func runtime_entersyscall()
//go:linkname runtime_exitsyscall runtime.exitsyscall
func runtime_exitsyscall()
// N.B. For the Syscall functions below:
//
// //go:uintptrkeepalive because the uintptr argument may be converted pointers
// that need to be kept alive in the caller (this is implied for RawSyscall6
// since it has no body).
//
// //go:nosplit because stack copying does not account for uintptrkeepalive, so
// the stack must not grow. Stack copying cannot blindly assume that all
// uintptr arguments are pointers, because some values may look like pointers,
// but not really be pointers, and adjusting their value would break the call.
//
// //go:norace, on RawSyscall, to avoid race instrumentation if RawSyscall is
// called after fork, or from a signal handler.
//
// //go:linkname to ensure ABI wrappers are generated for external callers
// (notably x/sys/unix assembly).
//go:uintptrkeepalive
//go:nosplit
//go:norace
//go:linkname RawSyscall
func RawSyscall(trap, a1, a2, a3 uintptr) (r1, r2 uintptr, err Errno) {
return RawSyscall6(trap, a1, a2, a3, 0, 0, 0)
}
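// For example, a raw getpid that avoids entersyscall/exitsyscall entirely
// (a minimal sketch; ordinary code should prefer the Getpid wrapper):
//
//	pid, _, _ := RawSyscall(SYS_GETPID, 0, 0, 0)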
//go:uintptrkeepalive
//go:nosplit
//go:linkname Syscall
func Syscall(trap, a1, a2, a3 uintptr) (r1, r2 uintptr, err Errno) {
runtime_entersyscall()
// N.B. Calling RawSyscall here is unsafe with atomic coverage
// instrumentation and race mode.
//
// Coverage instrumentation will add a sync/atomic call to RawSyscall.
// Race mode will add race instrumentation to sync/atomic. Race
// instrumentation requires a P, which we no longer have.
//
// RawSyscall6 is fine because it is implemented in assembly and thus
// has no coverage instrumentation.
//
// This is typically not a problem in the runtime because cmd/go avoids
// adding coverage instrumentation to the runtime in race mode.
r1, r2, err = RawSyscall6(trap, a1, a2, a3, 0, 0, 0)
runtime_exitsyscall()
return
}
//go:uintptrkeepalive
//go:nosplit
//go:linkname Syscall6
func Syscall6(trap, a1, a2, a3, a4, a5, a6 uintptr) (r1, r2 uintptr, err Errno) {
runtime_entersyscall()
r1, r2, err = RawSyscall6(trap, a1, a2, a3, a4, a5, a6)
runtime_exitsyscall()
return
}
func rawSyscallNoError(trap, a1, a2, a3 uintptr) (r1, r2 uintptr)
func rawVforkSyscall(trap, a1, a2 uintptr) (r1 uintptr, err Errno)
/*
* Wrapped
*/
func Access(path string, mode uint32) (err error) {
return Faccessat(_AT_FDCWD, path, mode, 0)
}
func Chmod(path string, mode uint32) (err error) {
return Fchmodat(_AT_FDCWD, path, mode, 0)
}
func Chown(path string, uid int, gid int) (err error) {
return Fchownat(_AT_FDCWD, path, uid, gid, 0)
}
func Creat(path string, mode uint32) (fd int, err error) {
return Open(path, O_CREAT|O_WRONLY|O_TRUNC, mode)
}
func EpollCreate(size int) (fd int, err error) {
if size <= 0 {
return -1, EINVAL
}
return EpollCreate1(0)
}
func isGroupMember(gid int) bool {
groups, err := Getgroups()
if err != nil {
return false
}
for _, g := range groups {
if g == gid {
return true
}
}
return false
}
func isCapDacOverrideSet() bool {
const _CAP_DAC_OVERRIDE = 1
var c caps
c.hdr.version = _LINUX_CAPABILITY_VERSION_3
_, _, err := RawSyscall(SYS_CAPGET, uintptr(unsafe.Pointer(&c.hdr)), uintptr(unsafe.Pointer(&c.data[0])), 0)
return err == 0 && c.data[0].effective&capToMask(_CAP_DAC_OVERRIDE) != 0
}
//sys faccessat(dirfd int, path string, mode uint32) (err error)
//sys faccessat2(dirfd int, path string, mode uint32, flags int) (err error) = _SYS_faccessat2
func Faccessat(dirfd int, path string, mode uint32, flags int) (err error) {
if flags == 0 {
return faccessat(dirfd, path, mode)
}
// Attempt to use the newer faccessat2, which supports flags directly,
// falling back if it doesn't exist.
//
// Don't attempt on Android, which does not allow faccessat2 through
// its seccomp policy [1] on any version of Android as of 2022-12-20.
//
// [1] https://cs.android.com/android/platform/superproject/+/master:bionic/libc/SECCOMP_BLOCKLIST_APP.TXT;l=4;drc=dbb8670dfdcc677f7e3b9262e93800fa14c4e417
if runtime.GOOS != "android" {
if err := faccessat2(dirfd, path, mode, flags); err != ENOSYS && err != EPERM {
return err
}
}
// The Linux kernel faccessat system call does not take any flags.
// The glibc faccessat implements the flags itself; see
// https://sourceware.org/git/?p=glibc.git;a=blob;f=sysdeps/unix/sysv/linux/faccessat.c;hb=HEAD
// Because people naturally expect syscall.Faccessat to act
// like C faccessat, we do the same.
if flags & ^(_AT_SYMLINK_NOFOLLOW|_AT_EACCESS) != 0 {
return EINVAL
}
var st Stat_t
if err := fstatat(dirfd, path, &st, flags&_AT_SYMLINK_NOFOLLOW); err != nil {
return err
}
mode &= 7
if mode == 0 {
return nil
}
// Fallback to checking permission bits.
var uid int
if flags&_AT_EACCESS != 0 {
uid = Geteuid()
if uid != 0 && isCapDacOverrideSet() {
// If CAP_DAC_OVERRIDE is set, file access check is
// done by the kernel in the same way as for root
// (see generic_permission() in the Linux sources).
uid = 0
}
} else {
uid = Getuid()
}
if uid == 0 {
if mode&1 == 0 {
// Root can read and write any file.
return nil
}
if st.Mode&0111 != 0 {
// Root can execute any file that anybody can execute.
return nil
}
return EACCES
}
var fmode uint32
if uint32(uid) == st.Uid {
fmode = (st.Mode >> 6) & 7
} else {
var gid int
if flags&_AT_EACCESS != 0 {
gid = Getegid()
} else {
gid = Getgid()
}
if uint32(gid) == st.Gid || isGroupMember(int(st.Gid)) {
fmode = (st.Mode >> 3) & 7
} else {
fmode = st.Mode & 7
}
}
if fmode&mode == mode {
return nil
}
return EACCES
}
//sys fchmodat(dirfd int, path string, mode uint32) (err error)
func Fchmodat(dirfd int, path string, mode uint32, flags int) (err error) {
// Linux fchmodat doesn't support the flags parameter. Mimic glibc's behavior
// and check the flags. Otherwise the mode would be applied to the symlink
// destination, which is not what the user expects.
if flags&^_AT_SYMLINK_NOFOLLOW != 0 {
return EINVAL
} else if flags&_AT_SYMLINK_NOFOLLOW != 0 {
return EOPNOTSUPP
}
return fchmodat(dirfd, path, mode)
}
//sys linkat(olddirfd int, oldpath string, newdirfd int, newpath string, flags int) (err error)
func Link(oldpath string, newpath string) (err error) {
return linkat(_AT_FDCWD, oldpath, _AT_FDCWD, newpath, 0)
}
func Mkdir(path string, mode uint32) (err error) {
return Mkdirat(_AT_FDCWD, path, mode)
}
func Mknod(path string, mode uint32, dev int) (err error) {
return Mknodat(_AT_FDCWD, path, mode, dev)
}
func Open(path string, mode int, perm uint32) (fd int, err error) {
return openat(_AT_FDCWD, path, mode|O_LARGEFILE, perm)
}
//sys openat(dirfd int, path string, flags int, mode uint32) (fd int, err error)
func Openat(dirfd int, path string, flags int, mode uint32) (fd int, err error) {
return openat(dirfd, path, flags|O_LARGEFILE, mode)
}
func Pipe(p []int) error {
return Pipe2(p, 0)
}
//sysnb pipe2(p *[2]_C_int, flags int) (err error)
func Pipe2(p []int, flags int) error {
if len(p) != 2 {
return EINVAL
}
var pp [2]_C_int
err := pipe2(&pp, flags)
if err == nil {
p[0] = int(pp[0])
p[1] = int(pp[1])
}
return err
}
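// A minimal usage sketch, creating a close-on-exec, non-blocking pipe:
//
//	var p [2]int
//	err := Pipe2(p[:], O_CLOEXEC|O_NONBLOCK)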
//sys readlinkat(dirfd int, path string, buf []byte) (n int, err error)
func Readlink(path string, buf []byte) (n int, err error) {
return readlinkat(_AT_FDCWD, path, buf)
}
func Rename(oldpath string, newpath string) (err error) {
return Renameat(_AT_FDCWD, oldpath, _AT_FDCWD, newpath)
}
func Rmdir(path string) error {
return unlinkat(_AT_FDCWD, path, _AT_REMOVEDIR)
}
//sys symlinkat(oldpath string, newdirfd int, newpath string) (err error)
func Symlink(oldpath string, newpath string) (err error) {
return symlinkat(oldpath, _AT_FDCWD, newpath)
}
func Unlink(path string) error {
return unlinkat(_AT_FDCWD, path, 0)
}
//sys unlinkat(dirfd int, path string, flags int) (err error)
func Unlinkat(dirfd int, path string) error {
return unlinkat(dirfd, path, 0)
}
func Utimes(path string, tv []Timeval) (err error) {
if len(tv) != 2 {
return EINVAL
}
return utimes(path, (*[2]Timeval)(unsafe.Pointer(&tv[0])))
}
//sys utimensat(dirfd int, path string, times *[2]Timespec, flag int) (err error)
func UtimesNano(path string, ts []Timespec) (err error) {
if len(ts) != 2 {
return EINVAL
}
return utimensat(_AT_FDCWD, path, (*[2]Timespec)(unsafe.Pointer(&ts[0])), 0)
}
func Futimesat(dirfd int, path string, tv []Timeval) (err error) {
if len(tv) != 2 {
return EINVAL
}
return futimesat(dirfd, path, (*[2]Timeval)(unsafe.Pointer(&tv[0])))
}
func Futimes(fd int, tv []Timeval) (err error) {
// Believe it or not, this is the best we can do on Linux
// (and is what glibc does).
return Utimes("/proc/self/fd/"+itoa.Itoa(fd), tv)
}
const ImplementsGetwd = true
//sys Getcwd(buf []byte) (n int, err error)
func Getwd() (wd string, err error) {
var buf [PathMax]byte
n, err := Getcwd(buf[0:])
if err != nil {
return "", err
}
// Getcwd returns the number of bytes written to buf, including the NUL.
if n < 1 || n > len(buf) || buf[n-1] != 0 {
return "", EINVAL
}
// In some cases, Linux can return a path that starts with the
// "(unreachable)" prefix, which can potentially be a valid relative
// path. To work around that, return ENOENT if path is not absolute.
if buf[0] != '/' {
return "", ENOENT
}
return string(buf[0 : n-1]), nil
}
func Getgroups() (gids []int, err error) {
n, err := getgroups(0, nil)
if err != nil {
return nil, err
}
if n == 0 {
return nil, nil
}
// Sanity check group count. Max is 1<<16 on Linux.
if n < 0 || n > 1<<20 {
return nil, EINVAL
}
a := make([]_Gid_t, n)
n, err = getgroups(n, &a[0])
if err != nil {
return nil, err
}
gids = make([]int, n)
for i, v := range a[0:n] {
gids[i] = int(v)
}
return
}
var cgo_libc_setgroups unsafe.Pointer // non-nil if cgo linked.
func Setgroups(gids []int) (err error) {
n := uintptr(len(gids))
if n == 0 {
if cgo_libc_setgroups == nil {
if _, _, e1 := AllThreadsSyscall(_SYS_setgroups, 0, 0, 0); e1 != 0 {
err = errnoErr(e1)
}
return
}
if ret := cgocaller(cgo_libc_setgroups, 0, 0); ret != 0 {
err = errnoErr(Errno(ret))
}
return
}
a := make([]_Gid_t, len(gids))
for i, v := range gids {
a[i] = _Gid_t(v)
}
if cgo_libc_setgroups == nil {
if _, _, e1 := AllThreadsSyscall(_SYS_setgroups, n, uintptr(unsafe.Pointer(&a[0])), 0); e1 != 0 {
err = errnoErr(e1)
}
return
}
if ret := cgocaller(cgo_libc_setgroups, n, uintptr(unsafe.Pointer(&a[0]))); ret != 0 {
err = errnoErr(Errno(ret))
}
return
}
type WaitStatus uint32
// Wait status is 7 bits at bottom, either 0 (exited),
// 0x7F (stopped), or a signal number that caused an exit.
// The 0x80 bit is whether there was a core dump.
// An extra number (exit code, signal causing a stop)
// is in the high bits. At least that's the idea.
// There are various irregularities. For example, the
// "continued" status is 0xFFFF, distinguishing itself
// from stopped via the core dump bit.
const (
mask = 0x7F
core = 0x80
exited = 0x00
stopped = 0x7F
shift = 8
)
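// For example, a raw status of 0x0100 decodes as a normal exit with code 1
// (low 7 bits 0x00, exit code 0x01 in the high byte), while 0x008B decodes
// as termination by SIGSEGV (signal 0x0B) with a core dump (bit 0x80 set).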
func (w WaitStatus) Exited() bool { return w&mask == exited }
func (w WaitStatus) Signaled() bool { return w&mask != stopped && w&mask != exited }
func (w WaitStatus) Stopped() bool { return w&0xFF == stopped }
func (w WaitStatus) Continued() bool { return w == 0xFFFF }
func (w WaitStatus) CoreDump() bool { return w.Signaled() && w&core != 0 }
func (w WaitStatus) ExitStatus() int {
if !w.Exited() {
return -1
}
return int(w>>shift) & 0xFF
}
func (w WaitStatus) Signal() Signal {
if !w.Signaled() {
return -1
}
return Signal(w & mask)
}
func (w WaitStatus) StopSignal() Signal {
if !w.Stopped() {
return -1
}
return Signal(w>>shift) & 0xFF
}
func (w WaitStatus) TrapCause() int {
if w.StopSignal() != SIGTRAP {
return -1
}
return int(w>>shift) >> 8
}
//sys wait4(pid int, wstatus *_C_int, options int, rusage *Rusage) (wpid int, err error)
func Wait4(pid int, wstatus *WaitStatus, options int, rusage *Rusage) (wpid int, err error) {
var status _C_int
wpid, err = wait4(pid, &status, options, rusage)
if wstatus != nil {
*wstatus = WaitStatus(status)
}
return
}
func Mkfifo(path string, mode uint32) (err error) {
return Mknod(path, mode|S_IFIFO, 0)
}
func (sa *SockaddrInet4) sockaddr() (unsafe.Pointer, _Socklen, error) {
if sa.Port < 0 || sa.Port > 0xFFFF {
return nil, 0, EINVAL
}
sa.raw.Family = AF_INET
p := (*[2]byte)(unsafe.Pointer(&sa.raw.Port))
p[0] = byte(sa.Port >> 8)
p[1] = byte(sa.Port)
sa.raw.Addr = sa.Addr
return unsafe.Pointer(&sa.raw), SizeofSockaddrInet4, nil
}
func (sa *SockaddrInet6) sockaddr() (unsafe.Pointer, _Socklen, error) {
if sa.Port < 0 || sa.Port > 0xFFFF {
return nil, 0, EINVAL
}
sa.raw.Family = AF_INET6
p := (*[2]byte)(unsafe.Pointer(&sa.raw.Port))
p[0] = byte(sa.Port >> 8)
p[1] = byte(sa.Port)
sa.raw.Scope_id = sa.ZoneId
sa.raw.Addr = sa.Addr
return unsafe.Pointer(&sa.raw), SizeofSockaddrInet6, nil
}
func (sa *SockaddrUnix) sockaddr() (unsafe.Pointer, _Socklen, error) {
name := sa.Name
n := len(name)
if n > len(sa.raw.Path) {
return nil, 0, EINVAL
}
if n == len(sa.raw.Path) && name[0] != '@' {
return nil, 0, EINVAL
}
sa.raw.Family = AF_UNIX
for i := 0; i < n; i++ {
sa.raw.Path[i] = int8(name[i])
}
// length is family (uint16), name, NUL.
sl := _Socklen(2)
if n > 0 {
sl += _Socklen(n) + 1
}
if sa.raw.Path[0] == '@' {
sa.raw.Path[0] = 0
// Don't count trailing NUL for abstract address.
sl--
}
return unsafe.Pointer(&sa.raw), sl, nil
}
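// A minimal sketch binding a socket to an abstract-namespace address; the
// leading '@' is rewritten to the NUL byte Linux expects (error handling
// elided):
//
//	s, _ := Socket(AF_UNIX, SOCK_STREAM|SOCK_CLOEXEC, 0)
//	err := Bind(s, &SockaddrUnix{Name: "@example"})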
type SockaddrLinklayer struct {
Protocol uint16
Ifindex int
Hatype uint16
Pkttype uint8
Halen uint8
Addr [8]byte
raw RawSockaddrLinklayer
}
func (sa *SockaddrLinklayer) sockaddr() (unsafe.Pointer, _Socklen, error) {
if sa.Ifindex < 0 || sa.Ifindex > 0x7fffffff {
return nil, 0, EINVAL
}
sa.raw.Family = AF_PACKET
sa.raw.Protocol = sa.Protocol
sa.raw.Ifindex = int32(sa.Ifindex)
sa.raw.Hatype = sa.Hatype
sa.raw.Pkttype = sa.Pkttype
sa.raw.Halen = sa.Halen
sa.raw.Addr = sa.Addr
return unsafe.Pointer(&sa.raw), SizeofSockaddrLinklayer, nil
}
type SockaddrNetlink struct {
Family uint16
Pad uint16
Pid uint32
Groups uint32
raw RawSockaddrNetlink
}
func (sa *SockaddrNetlink) sockaddr() (unsafe.Pointer, _Socklen, error) {
sa.raw.Family = AF_NETLINK
sa.raw.Pad = sa.Pad
sa.raw.Pid = sa.Pid
sa.raw.Groups = sa.Groups
return unsafe.Pointer(&sa.raw), SizeofSockaddrNetlink, nil
}
func anyToSockaddr(rsa *RawSockaddrAny) (Sockaddr, error) {
switch rsa.Addr.Family {
case AF_NETLINK:
pp := (*RawSockaddrNetlink)(unsafe.Pointer(rsa))
sa := new(SockaddrNetlink)
sa.Family = pp.Family
sa.Pad = pp.Pad
sa.Pid = pp.Pid
sa.Groups = pp.Groups
return sa, nil
case AF_PACKET:
pp := (*RawSockaddrLinklayer)(unsafe.Pointer(rsa))
sa := new(SockaddrLinklayer)
sa.Protocol = pp.Protocol
sa.Ifindex = int(pp.Ifindex)
sa.Hatype = pp.Hatype
sa.Pkttype = pp.Pkttype
sa.Halen = pp.Halen
sa.Addr = pp.Addr
return sa, nil
case AF_UNIX:
pp := (*RawSockaddrUnix)(unsafe.Pointer(rsa))
sa := new(SockaddrUnix)
if pp.Path[0] == 0 {
// "Abstract" Unix domain socket.
// Rewrite leading NUL as @ for textual display.
// (This is the standard convention.)
// Not friendly to overwrite in place,
// but the callers below don't care.
pp.Path[0] = '@'
}
// Assume path ends at NUL.
// This is not technically the Linux semantics for
// abstract Unix domain sockets--they are supposed
// to be uninterpreted fixed-size binary blobs--but
// everyone uses this convention.
n := 0
for n < len(pp.Path) && pp.Path[n] != 0 {
n++
}
sa.Name = string(unsafe.Slice((*byte)(unsafe.Pointer(&pp.Path[0])), n))
return sa, nil
case AF_INET:
pp := (*RawSockaddrInet4)(unsafe.Pointer(rsa))
sa := new(SockaddrInet4)
p := (*[2]byte)(unsafe.Pointer(&pp.Port))
sa.Port = int(p[0])<<8 + int(p[1])
sa.Addr = pp.Addr
return sa, nil
case AF_INET6:
pp := (*RawSockaddrInet6)(unsafe.Pointer(rsa))
sa := new(SockaddrInet6)
p := (*[2]byte)(unsafe.Pointer(&pp.Port))
sa.Port = int(p[0])<<8 + int(p[1])
sa.ZoneId = pp.Scope_id
sa.Addr = pp.Addr
return sa, nil
}
return nil, EAFNOSUPPORT
}
func Accept4(fd int, flags int) (nfd int, sa Sockaddr, err error) {
var rsa RawSockaddrAny
var len _Socklen = SizeofSockaddrAny
nfd, err = accept4(fd, &rsa, &len, flags)
if err != nil {
return
}
if len > SizeofSockaddrAny {
panic("RawSockaddrAny too small")
}
sa, err = anyToSockaddr(&rsa)
if err != nil {
Close(nfd)
nfd = 0
}
return
}
func Getsockname(fd int) (sa Sockaddr, err error) {
var rsa RawSockaddrAny
var len _Socklen = SizeofSockaddrAny
if err = getsockname(fd, &rsa, &len); err != nil {
return
}
return anyToSockaddr(&rsa)
}
func GetsockoptInet4Addr(fd, level, opt int) (value [4]byte, err error) {
vallen := _Socklen(4)
err = getsockopt(fd, level, opt, unsafe.Pointer(&value[0]), &vallen)
return value, err
}
func GetsockoptIPMreq(fd, level, opt int) (*IPMreq, error) {
var value IPMreq
vallen := _Socklen(SizeofIPMreq)
err := getsockopt(fd, level, opt, unsafe.Pointer(&value), &vallen)
return &value, err
}
func GetsockoptIPMreqn(fd, level, opt int) (*IPMreqn, error) {
var value IPMreqn
vallen := _Socklen(SizeofIPMreqn)
err := getsockopt(fd, level, opt, unsafe.Pointer(&value), &vallen)
return &value, err
}
func GetsockoptIPv6Mreq(fd, level, opt int) (*IPv6Mreq, error) {
var value IPv6Mreq
vallen := _Socklen(SizeofIPv6Mreq)
err := getsockopt(fd, level, opt, unsafe.Pointer(&value), &vallen)
return &value, err
}
func GetsockoptIPv6MTUInfo(fd, level, opt int) (*IPv6MTUInfo, error) {
var value IPv6MTUInfo
vallen := _Socklen(SizeofIPv6MTUInfo)
err := getsockopt(fd, level, opt, unsafe.Pointer(&value), &vallen)
return &value, err
}
func GetsockoptICMPv6Filter(fd, level, opt int) (*ICMPv6Filter, error) {
var value ICMPv6Filter
vallen := _Socklen(SizeofICMPv6Filter)
err := getsockopt(fd, level, opt, unsafe.Pointer(&value), &vallen)
return &value, err
}
func GetsockoptUcred(fd, level, opt int) (*Ucred, error) {
var value Ucred
vallen := _Socklen(SizeofUcred)
err := getsockopt(fd, level, opt, unsafe.Pointer(&value), &vallen)
return &value, err
}
func SetsockoptIPMreqn(fd, level, opt int, mreq *IPMreqn) (err error) {
return setsockopt(fd, level, opt, unsafe.Pointer(mreq), unsafe.Sizeof(*mreq))
}
func recvmsgRaw(fd int, p, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn int, recvflags int, err error) {
var msg Msghdr
msg.Name = (*byte)(unsafe.Pointer(rsa))
msg.Namelen = uint32(SizeofSockaddrAny)
var iov Iovec
if len(p) > 0 {
iov.Base = &p[0]
iov.SetLen(len(p))
}
var dummy byte
if len(oob) > 0 {
if len(p) == 0 {
var sockType int
sockType, err = GetsockoptInt(fd, SOL_SOCKET, SO_TYPE)
if err != nil {
return
}
// receive at least one normal byte
if sockType != SOCK_DGRAM {
iov.Base = &dummy
iov.SetLen(1)
}
}
msg.Control = &oob[0]
msg.SetControllen(len(oob))
}
msg.Iov = &iov
msg.Iovlen = 1
if n, err = recvmsg(fd, &msg, flags); err != nil {
return
}
oobn = int(msg.Controllen)
recvflags = int(msg.Flags)
return
}
func sendmsgN(fd int, p, oob []byte, ptr unsafe.Pointer, salen _Socklen, flags int) (n int, err error) {
var msg Msghdr
msg.Name = (*byte)(ptr)
msg.Namelen = uint32(salen)
var iov Iovec
if len(p) > 0 {
iov.Base = &p[0]
iov.SetLen(len(p))
}
var dummy byte
if len(oob) > 0 {
if len(p) == 0 {
var sockType int
sockType, err = GetsockoptInt(fd, SOL_SOCKET, SO_TYPE)
if err != nil {
return 0, err
}
// send at least one normal byte
if sockType != SOCK_DGRAM {
iov.Base = &dummy
iov.SetLen(1)
}
}
msg.Control = &oob[0]
msg.SetControllen(len(oob))
}
msg.Iov = &iov
msg.Iovlen = 1
if n, err = sendmsg(fd, &msg, flags); err != nil {
return 0, err
}
if len(oob) > 0 && len(p) == 0 {
n = 0
}
return n, nil
}
// BindToDevice binds the socket associated with fd to device.
func BindToDevice(fd int, device string) (err error) {
return SetsockoptString(fd, SOL_SOCKET, SO_BINDTODEVICE, device)
}
//sys ptrace(request int, pid int, addr uintptr, data uintptr) (err error)
//sys ptracePtr(request int, pid int, addr uintptr, data unsafe.Pointer) (err error) = SYS_PTRACE
func ptracePeek(req int, pid int, addr uintptr, out []byte) (count int, err error) {
// The peek requests are machine-word oriented, so we wrap them
// to retrieve arbitrary-length data.
// The raw ptrace syscall differs from glibc's ptrace:
// peeks return the word in *data, not as the return value.
var buf [sizeofPtr]byte
// Leading edge. PEEKTEXT/PEEKDATA don't require aligned
// access (PEEKUSER warns that it might), but if we don't
// align our reads, we might straddle an unmapped page
// boundary and not get the bytes leading up to the page
// boundary.
n := 0
if addr%sizeofPtr != 0 {
err = ptracePtr(req, pid, addr-addr%sizeofPtr, unsafe.Pointer(&buf[0]))
if err != nil {
return 0, err
}
n += copy(out, buf[addr%sizeofPtr:])
out = out[n:]
}
// Remainder.
for len(out) > 0 {
// We use an internal buffer to guarantee alignment.
// It's not documented whether this is necessary, but we're paranoid.
err = ptracePtr(req, pid, addr+uintptr(n), unsafe.Pointer(&buf[0]))
if err != nil {
return n, err
}
copied := copy(out, buf[0:])
n += copied
out = out[copied:]
}
return n, nil
}
func PtracePeekText(pid int, addr uintptr, out []byte) (count int, err error) {
return ptracePeek(PTRACE_PEEKTEXT, pid, addr, out)
}
func PtracePeekData(pid int, addr uintptr, out []byte) (count int, err error) {
return ptracePeek(PTRACE_PEEKDATA, pid, addr, out)
}
func ptracePoke(pokeReq int, peekReq int, pid int, addr uintptr, data []byte) (count int, err error) {
// As for ptracePeek, we need to align our accesses to deal
// with the possibility of straddling an invalid page.
// Leading edge.
n := 0
if addr%sizeofPtr != 0 {
var buf [sizeofPtr]byte
err = ptracePtr(peekReq, pid, addr-addr%sizeofPtr, unsafe.Pointer(&buf[0]))
if err != nil {
return 0, err
}
n += copy(buf[addr%sizeofPtr:], data)
word := *((*uintptr)(unsafe.Pointer(&buf[0])))
err = ptrace(pokeReq, pid, addr-addr%sizeofPtr, word)
if err != nil {
return 0, err
}
data = data[n:]
}
// Interior.
for len(data) > sizeofPtr {
word := *((*uintptr)(unsafe.Pointer(&data[0])))
err = ptrace(pokeReq, pid, addr+uintptr(n), word)
if err != nil {
return n, err
}
n += sizeofPtr
data = data[sizeofPtr:]
}
// Trailing edge.
if len(data) > 0 {
var buf [sizeofPtr]byte
err = ptracePtr(peekReq, pid, addr+uintptr(n), unsafe.Pointer(&buf[0]))
if err != nil {
return n, err
}
copy(buf[0:], data)
word := *((*uintptr)(unsafe.Pointer(&buf[0])))
err = ptrace(pokeReq, pid, addr+uintptr(n), word)
if err != nil {
return n, err
}
n += len(data)
}
return n, nil
}
func PtracePokeText(pid int, addr uintptr, data []byte) (count int, err error) {
return ptracePoke(PTRACE_POKETEXT, PTRACE_PEEKTEXT, pid, addr, data)
}
func PtracePokeData(pid int, addr uintptr, data []byte) (count int, err error) {
return ptracePoke(PTRACE_POKEDATA, PTRACE_PEEKDATA, pid, addr, data)
}
func PtraceGetRegs(pid int, regsout *PtraceRegs) (err error) {
return ptracePtr(PTRACE_GETREGS, pid, 0, unsafe.Pointer(regsout))
}
func PtraceSetRegs(pid int, regs *PtraceRegs) (err error) {
return ptracePtr(PTRACE_SETREGS, pid, 0, unsafe.Pointer(regs))
}
func PtraceSetOptions(pid int, options int) (err error) {
return ptrace(PTRACE_SETOPTIONS, pid, 0, uintptr(options))
}
func PtraceGetEventMsg(pid int) (msg uint, err error) {
var data _C_long
err = ptracePtr(PTRACE_GETEVENTMSG, pid, 0, unsafe.Pointer(&data))
msg = uint(data)
return
}
func PtraceCont(pid int, signal int) (err error) {
return ptrace(PTRACE_CONT, pid, 0, uintptr(signal))
}
func PtraceSyscall(pid int, signal int) (err error) {
return ptrace(PTRACE_SYSCALL, pid, 0, uintptr(signal))
}
func PtraceSingleStep(pid int) (err error) { return ptrace(PTRACE_SINGLESTEP, pid, 0, 0) }
func PtraceAttach(pid int) (err error) { return ptrace(PTRACE_ATTACH, pid, 0, 0) }
func PtraceDetach(pid int) (err error) { return ptrace(PTRACE_DETACH, pid, 0, 0) }
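// A minimal sketch of a tracing session, assuming pid names a running
// process and addr is a valid address in its memory (error handling elided):
//
//	PtraceAttach(pid)
//	var ws WaitStatus
//	Wait4(pid, &ws, 0, nil) // wait for the attach stop
//	buf := make([]byte, 8)
//	PtracePeekData(pid, addr, buf)
//	PtraceDetach(pid)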
//sys reboot(magic1 uint, magic2 uint, cmd int, arg string) (err error)
func Reboot(cmd int) (err error) {
return reboot(LINUX_REBOOT_MAGIC1, LINUX_REBOOT_MAGIC2, cmd, "")
}
func ReadDirent(fd int, buf []byte) (n int, err error) {
return Getdents(fd, buf)
}
func direntIno(buf []byte) (uint64, bool) {
return readInt(buf, unsafe.Offsetof(Dirent{}.Ino), unsafe.Sizeof(Dirent{}.Ino))
}
func direntReclen(buf []byte) (uint64, bool) {
return readInt(buf, unsafe.Offsetof(Dirent{}.Reclen), unsafe.Sizeof(Dirent{}.Reclen))
}
func direntNamlen(buf []byte) (uint64, bool) {
reclen, ok := direntReclen(buf)
if !ok {
return 0, false
}
return reclen - uint64(unsafe.Offsetof(Dirent{}.Name)), true
}
//sys mount(source string, target string, fstype string, flags uintptr, data *byte) (err error)
func Mount(source string, target string, fstype string, flags uintptr, data string) (err error) {
// Certain file systems get rather angry and return EINVAL if you give
// them an empty string of data, rather than NULL.
if data == "" {
return mount(source, target, fstype, flags, nil)
}
datap, err := BytePtrFromString(data)
if err != nil {
return err
}
return mount(source, target, fstype, flags, datap)
}
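// A minimal usage sketch, mounting a tmpfs with options passed via data
// (requires CAP_SYS_ADMIN; the hypothetical target directory must exist):
//
//	err := Mount("tmpfs", "/mnt/scratch", "tmpfs", 0, "size=64m")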
// Sendto
// Recvfrom
// Socketpair
/*
* Direct access
*/
//sys Acct(path string) (err error)
//sys Adjtimex(buf *Timex) (state int, err error)
//sys Chdir(path string) (err error)
//sys Chroot(path string) (err error)
//sys Close(fd int) (err error)
//sys Dup(oldfd int) (fd int, err error)
//sys Dup3(oldfd int, newfd int, flags int) (err error)
//sysnb EpollCreate1(flag int) (fd int, err error)
//sysnb EpollCtl(epfd int, op int, fd int, event *EpollEvent) (err error)
//sys Fallocate(fd int, mode uint32, off int64, len int64) (err error)
//sys Fchdir(fd int) (err error)
//sys Fchmod(fd int, mode uint32) (err error)
//sys Fchownat(dirfd int, path string, uid int, gid int, flags int) (err error)
//sys fcntl(fd int, cmd int, arg int) (val int, err error)
//sys Fdatasync(fd int) (err error)
//sys Flock(fd int, how int) (err error)
//sys Fsync(fd int) (err error)
//sys Getdents(fd int, buf []byte) (n int, err error) = SYS_GETDENTS64
//sysnb Getpgid(pid int) (pgid int, err error)
func Getpgrp() (pid int) {
pid, _ = Getpgid(0)
return
}
//sysnb Getpid() (pid int)
//sysnb Getppid() (ppid int)
//sys Getpriority(which int, who int) (prio int, err error)
//sysnb Getrusage(who int, rusage *Rusage) (err error)
//sysnb Gettid() (tid int)
//sys Getxattr(path string, attr string, dest []byte) (sz int, err error)
//sys InotifyAddWatch(fd int, pathname string, mask uint32) (watchdesc int, err error)
//sysnb InotifyInit1(flags int) (fd int, err error)
//sysnb InotifyRmWatch(fd int, watchdesc uint32) (success int, err error)
//sysnb Kill(pid int, sig Signal) (err error)
//sys Klogctl(typ int, buf []byte) (n int, err error) = SYS_SYSLOG
//sys Listxattr(path string, dest []byte) (sz int, err error)
//sys Mkdirat(dirfd int, path string, mode uint32) (err error)
//sys Mknodat(dirfd int, path string, mode uint32, dev int) (err error)
//sys Nanosleep(time *Timespec, leftover *Timespec) (err error)
//sys PivotRoot(newroot string, putold string) (err error) = SYS_PIVOT_ROOT
//sysnb prlimit(pid int, resource int, newlimit *Rlimit, old *Rlimit) (err error) = SYS_PRLIMIT64
//sys read(fd int, p []byte) (n int, err error)
//sys Removexattr(path string, attr string) (err error)
//sys Setdomainname(p []byte) (err error)
//sys Sethostname(p []byte) (err error)
//sysnb Setpgid(pid int, pgid int) (err error)
//sysnb Setsid() (pid int, err error)
//sysnb Settimeofday(tv *Timeval) (err error)
// Provided by runtime.syscall_runtime_doAllThreadsSyscall, which stops the
// world and invokes the syscall on each OS thread. Once this function returns,
// all threads are in sync.
//
//go:uintptrescapes
func runtime_doAllThreadsSyscall(trap, a1, a2, a3, a4, a5, a6 uintptr) (r1, r2, err uintptr)
// AllThreadsSyscall performs a syscall on each OS thread of the Go
// runtime. It first invokes the syscall on one thread. Should that
// invocation fail, it returns immediately with the error status.
// Otherwise, it invokes the syscall on all of the remaining threads
// in parallel. It will terminate the program if it observes any
// invoked syscall's return value differs from that of the first
// invocation.
//
// AllThreadsSyscall is intended for emulating simultaneous
// process-wide state changes that require consistently modifying
// per-thread state of the Go runtime.
//
// AllThreadsSyscall is unaware of any threads that are launched
// explicitly by cgo linked code, so the function always returns
// ENOTSUP in binaries that use cgo.
//
//go:uintptrescapes
func AllThreadsSyscall(trap, a1, a2, a3 uintptr) (r1, r2 uintptr, err Errno) {
if cgo_libc_setegid != nil {
return minus1, minus1, ENOTSUP
}
r1, r2, errno := runtime_doAllThreadsSyscall(trap, a1, a2, a3, 0, 0, 0)
return r1, r2, Errno(errno)
}
// AllThreadsSyscall6 is like AllThreadsSyscall, but extended to six
// arguments.
//
//go:uintptrescapes
func AllThreadsSyscall6(trap, a1, a2, a3, a4, a5, a6 uintptr) (r1, r2 uintptr, err Errno) {
if cgo_libc_setegid != nil {
return minus1, minus1, ENOTSUP
}
r1, r2, errno := runtime_doAllThreadsSyscall(trap, a1, a2, a3, a4, a5, a6)
return r1, r2, Errno(errno)
}
// linked by runtime.cgocall.go
//
//go:uintptrescapes
func cgocaller(unsafe.Pointer, ...uintptr) uintptr
var cgo_libc_setegid unsafe.Pointer // non-nil if cgo linked.
const minus1 = ^uintptr(0)
func Setegid(egid int) (err error) {
if cgo_libc_setegid == nil {
if _, _, e1 := AllThreadsSyscall(SYS_SETRESGID, minus1, uintptr(egid), minus1); e1 != 0 {
err = errnoErr(e1)
}
} else if ret := cgocaller(cgo_libc_setegid, uintptr(egid)); ret != 0 {
err = errnoErr(Errno(ret))
}
return
}
var cgo_libc_seteuid unsafe.Pointer // non-nil if cgo linked.
func Seteuid(euid int) (err error) {
if cgo_libc_seteuid == nil {
if _, _, e1 := AllThreadsSyscall(SYS_SETRESUID, minus1, uintptr(euid), minus1); e1 != 0 {
err = errnoErr(e1)
}
} else if ret := cgocaller(cgo_libc_seteuid, uintptr(euid)); ret != 0 {
err = errnoErr(Errno(ret))
}
return
}
var cgo_libc_setgid unsafe.Pointer // non-nil if cgo linked.
func Setgid(gid int) (err error) {
if cgo_libc_setgid == nil {
if _, _, e1 := AllThreadsSyscall(sys_SETGID, uintptr(gid), 0, 0); e1 != 0 {
err = errnoErr(e1)
}
} else if ret := cgocaller(cgo_libc_setgid, uintptr(gid)); ret != 0 {
err = errnoErr(Errno(ret))
}
return
}
var cgo_libc_setregid unsafe.Pointer // non-nil if cgo linked.
func Setregid(rgid, egid int) (err error) {
if cgo_libc_setregid == nil {
if _, _, e1 := AllThreadsSyscall(sys_SETREGID, uintptr(rgid), uintptr(egid), 0); e1 != 0 {
err = errnoErr(e1)
}
} else if ret := cgocaller(cgo_libc_setregid, uintptr(rgid), uintptr(egid)); ret != 0 {
err = errnoErr(Errno(ret))
}
return
}
var cgo_libc_setresgid unsafe.Pointer // non-nil if cgo linked.
func Setresgid(rgid, egid, sgid int) (err error) {
if cgo_libc_setresgid == nil {
if _, _, e1 := AllThreadsSyscall(sys_SETRESGID, uintptr(rgid), uintptr(egid), uintptr(sgid)); e1 != 0 {
err = errnoErr(e1)
}
} else if ret := cgocaller(cgo_libc_setresgid, uintptr(rgid), uintptr(egid), uintptr(sgid)); ret != 0 {
err = errnoErr(Errno(ret))
}
return
}
var cgo_libc_setresuid unsafe.Pointer // non-nil if cgo linked.
func Setresuid(ruid, euid, suid int) (err error) {
if cgo_libc_setresuid == nil {
if _, _, e1 := AllThreadsSyscall(sys_SETRESUID, uintptr(ruid), uintptr(euid), uintptr(suid)); e1 != 0 {
err = errnoErr(e1)
}
} else if ret := cgocaller(cgo_libc_setresuid, uintptr(ruid), uintptr(euid), uintptr(suid)); ret != 0 {
err = errnoErr(Errno(ret))
}
return
}
var cgo_libc_setreuid unsafe.Pointer // non-nil if cgo linked.
func Setreuid(ruid, euid int) (err error) {
if cgo_libc_setreuid == nil {
if _, _, e1 := AllThreadsSyscall(sys_SETREUID, uintptr(ruid), uintptr(euid), 0); e1 != 0 {
err = errnoErr(e1)
}
} else if ret := cgocaller(cgo_libc_setreuid, uintptr(ruid), uintptr(euid)); ret != 0 {
err = errnoErr(Errno(ret))
}
return
}
var cgo_libc_setuid unsafe.Pointer // non-nil if cgo linked.
func Setuid(uid int) (err error) {
if cgo_libc_setuid == nil {
if _, _, e1 := AllThreadsSyscall(sys_SETUID, uintptr(uid), 0, 0); e1 != 0 {
err = errnoErr(e1)
}
} else if ret := cgocaller(cgo_libc_setuid, uintptr(uid)); ret != 0 {
err = errnoErr(Errno(ret))
}
return
}
//sys Setpriority(which int, who int, prio int) (err error)
//sys Setxattr(path string, attr string, data []byte, flags int) (err error)
//sys Sync()
//sysnb Sysinfo(info *Sysinfo_t) (err error)
//sys Tee(rfd int, wfd int, len int, flags int) (n int64, err error)
//sysnb Tgkill(tgid int, tid int, sig Signal) (err error)
//sysnb Times(tms *Tms) (ticks uintptr, err error)
//sysnb Umask(mask int) (oldmask int)
//sysnb Uname(buf *Utsname) (err error)
//sys Unmount(target string, flags int) (err error) = SYS_UMOUNT2
//sys Unshare(flags int) (err error)
//sys write(fd int, p []byte) (n int, err error)
//sys exitThread(code int) (err error) = SYS_EXIT
//sys readlen(fd int, p *byte, np int) (n int, err error) = SYS_READ
//sys writelen(fd int, p *byte, np int) (n int, err error) = SYS_WRITE
// mmap varies by architecture; see syscall_linux_*.go.
//sys munmap(addr uintptr, length uintptr) (err error)
var mapper = &mmapper{
active: make(map[*byte][]byte),
mmap: mmap,
munmap: munmap,
}
func Mmap(fd int, offset int64, length int, prot int, flags int) (data []byte, err error) {
return mapper.Mmap(fd, offset, length, prot, flags)
}
func Munmap(b []byte) (err error) {
return mapper.Munmap(b)
}
//sys Madvise(b []byte, advice int) (err error)
//sys Mprotect(b []byte, prot int) (err error)
//sys Mlock(b []byte) (err error)
//sys Munlock(b []byte) (err error)
//sys Mlockall(flags int) (err error)
//sys Munlockall() (err error)
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file provides the Accept function used on all systems
// other than arm. See syscall_linux_accept.go for why.
//go:build linux && !arm
package syscall
func Accept(fd int) (nfd int, sa Sockaddr, err error) {
var rsa RawSockaddrAny
var len _Socklen = SizeofSockaddrAny
nfd, err = accept4(fd, &rsa, &len, 0)
if err != nil {
return
}
sa, err = anyToSockaddr(&rsa)
if err != nil {
Close(nfd)
nfd = 0
}
return
}
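// Illustrative accept-loop sketch (lfd is assumed to be a listening
// socket created with Socket/Bind/Listen; handle is a hypothetical
// per-connection function):
//
//	for {
//		nfd, sa, err := syscall.Accept(lfd)
//		if err != nil {
//			break
//		}
//		_ = sa // e.g. *syscall.SockaddrInet4 for an IPv4 listener
//		go handle(nfd) // handle must eventually call syscall.Close(nfd)
//	}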
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package syscall
const (
_SYS_setgroups = SYS_SETGROUPS
_SYS_clone3 = 435
_SYS_faccessat2 = 439
)
//sys Dup2(oldfd int, newfd int) (err error)
//sys Fchown(fd int, uid int, gid int) (err error)
//sys Fstat(fd int, stat *Stat_t) (err error)
//sys Fstatfs(fd int, buf *Statfs_t) (err error)
//sys Ftruncate(fd int, length int64) (err error)
//sysnb Getegid() (egid int)
//sysnb Geteuid() (euid int)
//sysnb Getgid() (gid int)
//sysnb Getrlimit(resource int, rlim *Rlimit) (err error)
//sysnb Getuid() (uid int)
//sysnb InotifyInit() (fd int, err error)
//sys Ioperm(from int, num int, on int) (err error)
//sys Iopl(level int) (err error)
//sys Listen(s int, n int) (err error)
//sys Pause() (err error)
//sys pread(fd int, p []byte, offset int64) (n int, err error) = SYS_PREAD64
//sys pwrite(fd int, p []byte, offset int64) (n int, err error) = SYS_PWRITE64
//sys Renameat(olddirfd int, oldpath string, newdirfd int, newpath string) (err error)
//sys Seek(fd int, offset int64, whence int) (off int64, err error) = SYS_LSEEK
//sys Select(nfd int, r *FdSet, w *FdSet, e *FdSet, timeout *Timeval) (n int, err error)
//sys sendfile(outfd int, infd int, offset *int64, count int) (written int, err error)
//sys Setfsgid(gid int) (err error)
//sys Setfsuid(uid int) (err error)
//sysnb Setrlimit(resource int, rlim *Rlimit) (err error)
//sys Shutdown(fd int, how int) (err error)
//sys Splice(rfd int, roff *int64, wfd int, woff *int64, len int, flags int) (n int64, err error)
//sys Statfs(path string, buf *Statfs_t) (err error)
//sys SyncFileRange(fd int, off int64, n int64, flags int) (err error)
//sys Truncate(path string, length int64) (err error)
//sys Ustat(dev int, ubuf *Ustat_t) (err error)
//sys accept4(s int, rsa *RawSockaddrAny, addrlen *_Socklen, flags int) (fd int, err error)
//sys bind(s int, addr unsafe.Pointer, addrlen _Socklen) (err error)
//sys connect(s int, addr unsafe.Pointer, addrlen _Socklen) (err error)
//sys fstatat(fd int, path string, stat *Stat_t, flags int) (err error) = SYS_NEWFSTATAT
//sysnb getgroups(n int, list *_Gid_t) (nn int, err error)
//sys getsockopt(s int, level int, name int, val unsafe.Pointer, vallen *_Socklen) (err error)
//sys setsockopt(s int, level int, name int, val unsafe.Pointer, vallen uintptr) (err error)
//sysnb socket(domain int, typ int, proto int) (fd int, err error)
//sysnb socketpair(domain int, typ int, proto int, fd *[2]int32) (err error)
//sysnb getpeername(fd int, rsa *RawSockaddrAny, addrlen *_Socklen) (err error)
//sysnb getsockname(fd int, rsa *RawSockaddrAny, addrlen *_Socklen) (err error)
//sys recvfrom(fd int, p []byte, flags int, from *RawSockaddrAny, fromlen *_Socklen) (n int, err error)
//sys sendto(s int, buf []byte, flags int, to unsafe.Pointer, addrlen _Socklen) (err error)
//sys recvmsg(s int, msg *Msghdr, flags int) (n int, err error)
//sys sendmsg(s int, msg *Msghdr, flags int) (n int, err error)
//sys mmap(addr uintptr, length uintptr, prot int, flags int, fd int, offset int64) (xaddr uintptr, err error)
//sys EpollWait(epfd int, events []EpollEvent, msec int) (n int, err error)
func Stat(path string, stat *Stat_t) (err error) {
return fstatat(_AT_FDCWD, path, stat, 0)
}
func Lchown(path string, uid int, gid int) (err error) {
return Fchownat(_AT_FDCWD, path, uid, gid, _AT_SYMLINK_NOFOLLOW)
}
func Lstat(path string, stat *Stat_t) (err error) {
return fstatat(_AT_FDCWD, path, stat, _AT_SYMLINK_NOFOLLOW)
}
//sys futimesat(dirfd int, path string, times *[2]Timeval) (err error)
//go:noescape
func gettimeofday(tv *Timeval) (err Errno)
func Gettimeofday(tv *Timeval) (err error) {
errno := gettimeofday(tv)
if errno != 0 {
return errno
}
return nil
}
func Time(t *Time_t) (tt Time_t, err error) {
var tv Timeval
errno := gettimeofday(&tv)
if errno != 0 {
return 0, errno
}
if t != nil {
*t = Time_t(tv.Sec)
}
return Time_t(tv.Sec), nil
}
//sys Utime(path string, buf *Utimbuf) (err error)
//sys utimes(path string, times *[2]Timeval) (err error)
func setTimespec(sec, nsec int64) Timespec {
return Timespec{Sec: sec, Nsec: nsec}
}
func setTimeval(sec, usec int64) Timeval {
return Timeval{Sec: sec, Usec: usec}
}
func (r *PtraceRegs) PC() uint64 { return r.Rip }
func (r *PtraceRegs) SetPC(pc uint64) { r.Rip = pc }
func (iov *Iovec) SetLen(length int) {
iov.Len = uint64(length)
}
func (msghdr *Msghdr) SetControllen(length int) {
msghdr.Controllen = uint64(length)
}
func (cmsg *Cmsghdr) SetLen(length int) {
cmsg.Len = uint64(length)
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix
package syscall
import (
"internal/bytealg"
"internal/itoa"
"internal/oserror"
"internal/race"
"runtime"
"sync"
"unsafe"
)
var (
Stdin = 0
Stdout = 1
Stderr = 2
)
const (
darwin64Bit = (runtime.GOOS == "darwin" || runtime.GOOS == "ios") && sizeofPtr == 8
netbsd32Bit = runtime.GOOS == "netbsd" && sizeofPtr == 4
)
// clen returns the index of the first NULL byte in n or len(n) if n contains no NULL byte.
func clen(n []byte) int {
if i := bytealg.IndexByte(n, 0); i != -1 {
return i
}
return len(n)
}
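// Worked example: clen([]byte{'a', 'b', 0, 'c'}) returns 2, while
// clen([]byte("abc")) returns 3 because no NUL byte is present.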
// Mmap manager, for use by operating system-specific implementations.
type mmapper struct {
sync.Mutex
active map[*byte][]byte // active mappings; key is last byte in mapping
mmap func(addr, length uintptr, prot, flags, fd int, offset int64) (uintptr, error)
munmap func(addr uintptr, length uintptr) error
}
func (m *mmapper) Mmap(fd int, offset int64, length int, prot int, flags int) (data []byte, err error) {
if length <= 0 {
return nil, EINVAL
}
// Map the requested memory.
addr, errno := m.mmap(0, uintptr(length), prot, flags, fd, offset)
if errno != nil {
return nil, errno
}
// Use unsafe to turn addr into a []byte.
b := unsafe.Slice((*byte)(unsafe.Pointer(addr)), length)
// Register mapping in m and return it.
p := &b[cap(b)-1]
m.Lock()
defer m.Unlock()
m.active[p] = b
return b, nil
}
func (m *mmapper) Munmap(data []byte) (err error) {
if len(data) == 0 || len(data) != cap(data) {
return EINVAL
}
// Find the base of the mapping.
p := &data[cap(data)-1]
m.Lock()
defer m.Unlock()
b := m.active[p]
if b == nil || &b[0] != &data[0] {
return EINVAL
}
// Unmap the memory and update m.
if errno := m.munmap(uintptr(unsafe.Pointer(&b[0])), uintptr(len(b))); errno != nil {
return errno
}
delete(m.active, p)
return nil
}
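// Usage sketch for the exported Mmap/Munmap wrappers built on this
// manager (illustrative; fd is assumed to be an open file descriptor
// for a file of at least length bytes):
//
//	data, err := syscall.Mmap(fd, 0, length, syscall.PROT_READ, syscall.MAP_SHARED)
//	if err != nil {
//		return err
//	}
//	defer syscall.Munmap(data)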
// An Errno is an unsigned number describing an error condition.
// It implements the error interface. The zero Errno is by convention
// a non-error, so code to convert from Errno to error should use:
//
// err = nil
// if errno != 0 {
// err = errno
// }
//
// Errno values can be tested against error values from the os package
// using errors.Is. For example:
//
// _, _, err := syscall.Syscall(...)
// if errors.Is(err, fs.ErrNotExist) ...
type Errno uintptr
func (e Errno) Error() string {
if 0 <= int(e) && int(e) < len(errors) {
s := errors[e]
if s != "" {
return s
}
}
return "errno " + itoa.Itoa(int(e))
}
func (e Errno) Is(target error) bool {
switch target {
case oserror.ErrPermission:
return e == EACCES || e == EPERM
case oserror.ErrExist:
return e == EEXIST || e == ENOTEMPTY
case oserror.ErrNotExist:
return e == ENOENT
}
return false
}
func (e Errno) Temporary() bool {
return e == EINTR || e == EMFILE || e == ENFILE || e.Timeout()
}
func (e Errno) Timeout() bool {
return e == EAGAIN || e == EWOULDBLOCK || e == ETIMEDOUT
}
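// Illustrative sketch: callers that want to retry transient failures
// can test for them with a type assertion, e.g.
//
//	if errno, ok := err.(syscall.Errno); ok && errno.Temporary() {
//		// back off and retry the operation
//	}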
// Do the interface allocations only once for common
// Errno values.
var (
errEAGAIN error = EAGAIN
errEINVAL error = EINVAL
errENOENT error = ENOENT
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e Errno) error {
switch e {
case 0:
return nil
case EAGAIN:
return errEAGAIN
case EINVAL:
return errEINVAL
case ENOENT:
return errENOENT
}
return e
}
// A Signal is a number describing a process signal.
// It implements the os.Signal interface.
type Signal int
func (s Signal) Signal() {}
func (s Signal) String() string {
if 0 <= s && int(s) < len(signals) {
str := signals[s]
if str != "" {
return str
}
}
return "signal " + itoa.Itoa(int(s))
}
func Read(fd int, p []byte) (n int, err error) {
n, err = read(fd, p)
if race.Enabled {
if n > 0 {
race.WriteRange(unsafe.Pointer(&p[0]), n)
}
if err == nil {
race.Acquire(unsafe.Pointer(&ioSync))
}
}
if msanenabled && n > 0 {
msanWrite(unsafe.Pointer(&p[0]), n)
}
if asanenabled && n > 0 {
asanWrite(unsafe.Pointer(&p[0]), n)
}
return
}
func Write(fd int, p []byte) (n int, err error) {
if race.Enabled {
race.ReleaseMerge(unsafe.Pointer(&ioSync))
}
if faketime && (fd == 1 || fd == 2) {
n = faketimeWrite(fd, p)
if n < 0 {
n, err = 0, errnoErr(Errno(-n))
}
} else {
n, err = write(fd, p)
}
if race.Enabled && n > 0 {
race.ReadRange(unsafe.Pointer(&p[0]), n)
}
if msanenabled && n > 0 {
msanRead(unsafe.Pointer(&p[0]), n)
}
if asanenabled && n > 0 {
asanRead(unsafe.Pointer(&p[0]), n)
}
return
}
func Pread(fd int, p []byte, offset int64) (n int, err error) {
n, err = pread(fd, p, offset)
if race.Enabled {
if n > 0 {
race.WriteRange(unsafe.Pointer(&p[0]), n)
}
if err == nil {
race.Acquire(unsafe.Pointer(&ioSync))
}
}
if msanenabled && n > 0 {
msanWrite(unsafe.Pointer(&p[0]), n)
}
if asanenabled && n > 0 {
asanWrite(unsafe.Pointer(&p[0]), n)
}
return
}
func Pwrite(fd int, p []byte, offset int64) (n int, err error) {
if race.Enabled {
race.ReleaseMerge(unsafe.Pointer(&ioSync))
}
n, err = pwrite(fd, p, offset)
if race.Enabled && n > 0 {
race.ReadRange(unsafe.Pointer(&p[0]), n)
}
if msanenabled && n > 0 {
msanRead(unsafe.Pointer(&p[0]), n)
}
if asanenabled && n > 0 {
asanRead(unsafe.Pointer(&p[0]), n)
}
return
}
// For testing: clients can set this flag to force
// creation of IPv6 sockets to return EAFNOSUPPORT.
var SocketDisableIPv6 bool
type Sockaddr interface {
sockaddr() (ptr unsafe.Pointer, len _Socklen, err error) // lowercase; only we can define Sockaddrs
}
type SockaddrInet4 struct {
Port int
Addr [4]byte
raw RawSockaddrInet4
}
type SockaddrInet6 struct {
Port int
ZoneId uint32
Addr [16]byte
raw RawSockaddrInet6
}
type SockaddrUnix struct {
Name string
raw RawSockaddrUnix
}
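// Illustrative sketch: a concrete Sockaddr carries the address for
// Bind, Connect, and Sendto below (fd is assumed to be an AF_INET
// stream socket):
//
//	sa := &syscall.SockaddrInet4{Port: 8080, Addr: [4]byte{127, 0, 0, 1}}
//	if err := syscall.Connect(fd, sa); err != nil {
//		return err
//	}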
func Bind(fd int, sa Sockaddr) (err error) {
ptr, n, err := sa.sockaddr()
if err != nil {
return err
}
return bind(fd, ptr, n)
}
func Connect(fd int, sa Sockaddr) (err error) {
ptr, n, err := sa.sockaddr()
if err != nil {
return err
}
return connect(fd, ptr, n)
}
func Getpeername(fd int) (sa Sockaddr, err error) {
var rsa RawSockaddrAny
var len _Socklen = SizeofSockaddrAny
if err = getpeername(fd, &rsa, &len); err != nil {
return
}
return anyToSockaddr(&rsa)
}
func GetsockoptInt(fd, level, opt int) (value int, err error) {
var n int32
vallen := _Socklen(4)
err = getsockopt(fd, level, opt, unsafe.Pointer(&n), &vallen)
return int(n), err
}
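// Illustrative sketch (fd is an assumed socket descriptor): reading a
// socket's pending error status, e.g. after a non-blocking connect:
//
//	v, err := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_ERROR)
//	if err == nil && v != 0 {
//		err = syscall.Errno(v)
//	}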
func Recvfrom(fd int, p []byte, flags int) (n int, from Sockaddr, err error) {
var rsa RawSockaddrAny
var len _Socklen = SizeofSockaddrAny
if n, err = recvfrom(fd, p, flags, &rsa, &len); err != nil {
return
}
if rsa.Addr.Family != AF_UNSPEC {
from, err = anyToSockaddr(&rsa)
}
return
}
func recvfromInet4(fd int, p []byte, flags int, from *SockaddrInet4) (n int, err error) {
var rsa RawSockaddrAny
var socklen _Socklen = SizeofSockaddrAny
if n, err = recvfrom(fd, p, flags, &rsa, &socklen); err != nil {
return
}
pp := (*RawSockaddrInet4)(unsafe.Pointer(&rsa))
port := (*[2]byte)(unsafe.Pointer(&pp.Port))
from.Port = int(port[0])<<8 + int(port[1])
from.Addr = pp.Addr
return
}
func recvfromInet6(fd int, p []byte, flags int, from *SockaddrInet6) (n int, err error) {
var rsa RawSockaddrAny
var socklen _Socklen = SizeofSockaddrAny
if n, err = recvfrom(fd, p, flags, &rsa, &socklen); err != nil {
return
}
pp := (*RawSockaddrInet6)(unsafe.Pointer(&rsa))
port := (*[2]byte)(unsafe.Pointer(&pp.Port))
from.Port = int(port[0])<<8 + int(port[1])
from.ZoneId = pp.Scope_id
from.Addr = pp.Addr
return
}
func recvmsgInet4(fd int, p, oob []byte, flags int, from *SockaddrInet4) (n, oobn int, recvflags int, err error) {
var rsa RawSockaddrAny
n, oobn, recvflags, err = recvmsgRaw(fd, p, oob, flags, &rsa)
if err != nil {
return
}
pp := (*RawSockaddrInet4)(unsafe.Pointer(&rsa))
port := (*[2]byte)(unsafe.Pointer(&pp.Port))
from.Port = int(port[0])<<8 + int(port[1])
from.Addr = pp.Addr
return
}
func recvmsgInet6(fd int, p, oob []byte, flags int, from *SockaddrInet6) (n, oobn int, recvflags int, err error) {
var rsa RawSockaddrAny
n, oobn, recvflags, err = recvmsgRaw(fd, p, oob, flags, &rsa)
if err != nil {
return
}
pp := (*RawSockaddrInet6)(unsafe.Pointer(&rsa))
port := (*[2]byte)(unsafe.Pointer(&pp.Port))
from.Port = int(port[0])<<8 + int(port[1])
from.ZoneId = pp.Scope_id
from.Addr = pp.Addr
return
}
func Recvmsg(fd int, p, oob []byte, flags int) (n, oobn int, recvflags int, from Sockaddr, err error) {
var rsa RawSockaddrAny
n, oobn, recvflags, err = recvmsgRaw(fd, p, oob, flags, &rsa)
// source address is only specified if the socket is unconnected
if rsa.Addr.Family != AF_UNSPEC {
from, err = anyToSockaddr(&rsa)
}
return
}
func Sendmsg(fd int, p, oob []byte, to Sockaddr, flags int) (err error) {
_, err = SendmsgN(fd, p, oob, to, flags)
return
}
func SendmsgN(fd int, p, oob []byte, to Sockaddr, flags int) (n int, err error) {
var ptr unsafe.Pointer
var salen _Socklen
if to != nil {
ptr, salen, err = to.sockaddr()
if err != nil {
return 0, err
}
}
return sendmsgN(fd, p, oob, ptr, salen, flags)
}
func sendmsgNInet4(fd int, p, oob []byte, to *SockaddrInet4, flags int) (n int, err error) {
ptr, salen, err := to.sockaddr()
if err != nil {
return 0, err
}
return sendmsgN(fd, p, oob, ptr, salen, flags)
}
func sendmsgNInet6(fd int, p, oob []byte, to *SockaddrInet6, flags int) (n int, err error) {
ptr, salen, err := to.sockaddr()
if err != nil {
return 0, err
}
return sendmsgN(fd, p, oob, ptr, salen, flags)
}
func sendtoInet4(fd int, p []byte, flags int, to *SockaddrInet4) (err error) {
ptr, n, err := to.sockaddr()
if err != nil {
return err
}
return sendto(fd, p, flags, ptr, n)
}
func sendtoInet6(fd int, p []byte, flags int, to *SockaddrInet6) (err error) {
ptr, n, err := to.sockaddr()
if err != nil {
return err
}
return sendto(fd, p, flags, ptr, n)
}
func Sendto(fd int, p []byte, flags int, to Sockaddr) (err error) {
var (
ptr unsafe.Pointer
salen _Socklen
)
if to != nil {
ptr, salen, err = to.sockaddr()
if err != nil {
return err
}
}
return sendto(fd, p, flags, ptr, salen)
}
func SetsockoptByte(fd, level, opt int, value byte) (err error) {
return setsockopt(fd, level, opt, unsafe.Pointer(&value), 1)
}
func SetsockoptInt(fd, level, opt int, value int) (err error) {
var n = int32(value)
return setsockopt(fd, level, opt, unsafe.Pointer(&n), 4)
}
func SetsockoptInet4Addr(fd, level, opt int, value [4]byte) (err error) {
return setsockopt(fd, level, opt, unsafe.Pointer(&value[0]), 4)
}
func SetsockoptIPMreq(fd, level, opt int, mreq *IPMreq) (err error) {
return setsockopt(fd, level, opt, unsafe.Pointer(mreq), SizeofIPMreq)
}
func SetsockoptIPv6Mreq(fd, level, opt int, mreq *IPv6Mreq) (err error) {
return setsockopt(fd, level, opt, unsafe.Pointer(mreq), SizeofIPv6Mreq)
}
func SetsockoptICMPv6Filter(fd, level, opt int, filter *ICMPv6Filter) error {
return setsockopt(fd, level, opt, unsafe.Pointer(filter), SizeofICMPv6Filter)
}
func SetsockoptLinger(fd, level, opt int, l *Linger) (err error) {
return setsockopt(fd, level, opt, unsafe.Pointer(l), SizeofLinger)
}
func SetsockoptString(fd, level, opt int, s string) (err error) {
var p unsafe.Pointer
if len(s) > 0 {
p = unsafe.Pointer(&[]byte(s)[0])
}
return setsockopt(fd, level, opt, p, uintptr(len(s)))
}
func SetsockoptTimeval(fd, level, opt int, tv *Timeval) (err error) {
return setsockopt(fd, level, opt, unsafe.Pointer(tv), unsafe.Sizeof(*tv))
}
func Socket(domain, typ, proto int) (fd int, err error) {
if domain == AF_INET6 && SocketDisableIPv6 {
return -1, EAFNOSUPPORT
}
fd, err = socket(domain, typ, proto)
return
}
func Socketpair(domain, typ, proto int) (fd [2]int, err error) {
var fdx [2]int32
err = socketpair(domain, typ, proto, &fdx)
if err == nil {
fd[0] = int(fdx[0])
fd[1] = int(fdx[1])
}
return
}
func Sendfile(outfd int, infd int, offset *int64, count int) (written int, err error) {
if race.Enabled {
race.ReleaseMerge(unsafe.Pointer(&ioSync))
}
return sendfile(outfd, infd, offset, count)
}
var ioSync int64
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !faketime
package syscall
const faketime = false
func faketimeWrite(fd int, p []byte) int {
// This should never be called since faketime is false.
panic("not implemented")
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || (js && wasm)
package syscall
// TimespecToNsec returns the time stored in ts as nanoseconds.
func TimespecToNsec(ts Timespec) int64 { return ts.Nano() }
// NsecToTimespec converts a number of nanoseconds into a Timespec.
func NsecToTimespec(nsec int64) Timespec {
sec := nsec / 1e9
nsec = nsec % 1e9
if nsec < 0 {
nsec += 1e9
sec--
}
return setTimespec(sec, nsec)
}
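// Worked example: NsecToTimespec(-1) normalizes the negative remainder
// to Timespec{Sec: -1, Nsec: 999999999}, i.e. one nanosecond before the
// epoch.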
// TimevalToNsec returns the time stored in tv as nanoseconds.
func TimevalToNsec(tv Timeval) int64 { return tv.Nano() }
// NsecToTimeval converts a number of nanoseconds into a Timeval.
func NsecToTimeval(nsec int64) Timeval {
nsec += 999 // round up to microsecond
usec := nsec % 1e9 / 1e3
sec := nsec / 1e9
if usec < 0 {
usec += 1e6
sec--
}
return setTimeval(sec, usec)
}
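// Worked example: NsecToTimeval(1500) rounds up to Timeval{Sec: 0, Usec: 2},
// since 1500ns + 999ns = 2499ns, which truncates to 2us.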
// mksyscall.pl -tags linux,amd64 syscall_linux.go syscall_linux_amd64.go
// Code generated by the command above; DO NOT EDIT.
//go:build linux && amd64
package syscall
import "unsafe"
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func faccessat(dirfd int, path string, mode uint32) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
_, _, e1 := Syscall(SYS_FACCESSAT, uintptr(dirfd), uintptr(unsafe.Pointer(_p0)), uintptr(mode))
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func faccessat2(dirfd int, path string, mode uint32, flags int) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
_, _, e1 := Syscall6(_SYS_faccessat2, uintptr(dirfd), uintptr(unsafe.Pointer(_p0)), uintptr(mode), uintptr(flags), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func fchmodat(dirfd int, path string, mode uint32) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
_, _, e1 := Syscall(SYS_FCHMODAT, uintptr(dirfd), uintptr(unsafe.Pointer(_p0)), uintptr(mode))
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func linkat(olddirfd int, oldpath string, newdirfd int, newpath string, flags int) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(oldpath)
if err != nil {
return
}
var _p1 *byte
_p1, err = BytePtrFromString(newpath)
if err != nil {
return
}
_, _, e1 := Syscall6(SYS_LINKAT, uintptr(olddirfd), uintptr(unsafe.Pointer(_p0)), uintptr(newdirfd), uintptr(unsafe.Pointer(_p1)), uintptr(flags), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func openat(dirfd int, path string, flags int, mode uint32) (fd int, err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
r0, _, e1 := Syscall6(SYS_OPENAT, uintptr(dirfd), uintptr(unsafe.Pointer(_p0)), uintptr(flags), uintptr(mode), 0, 0)
fd = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func pipe2(p *[2]_C_int, flags int) (err error) {
_, _, e1 := RawSyscall(SYS_PIPE2, uintptr(unsafe.Pointer(p)), uintptr(flags), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func readlinkat(dirfd int, path string, buf []byte) (n int, err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
var _p1 unsafe.Pointer
if len(buf) > 0 {
_p1 = unsafe.Pointer(&buf[0])
} else {
_p1 = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall6(SYS_READLINKAT, uintptr(dirfd), uintptr(unsafe.Pointer(_p0)), uintptr(_p1), uintptr(len(buf)), 0, 0)
n = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func symlinkat(oldpath string, newdirfd int, newpath string) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(oldpath)
if err != nil {
return
}
var _p1 *byte
_p1, err = BytePtrFromString(newpath)
if err != nil {
return
}
_, _, e1 := Syscall(SYS_SYMLINKAT, uintptr(unsafe.Pointer(_p0)), uintptr(newdirfd), uintptr(unsafe.Pointer(_p1)))
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func unlinkat(dirfd int, path string, flags int) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
_, _, e1 := Syscall(SYS_UNLINKAT, uintptr(dirfd), uintptr(unsafe.Pointer(_p0)), uintptr(flags))
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func utimensat(dirfd int, path string, times *[2]Timespec, flag int) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
_, _, e1 := Syscall6(SYS_UTIMENSAT, uintptr(dirfd), uintptr(unsafe.Pointer(_p0)), uintptr(unsafe.Pointer(times)), uintptr(flag), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Getcwd(buf []byte) (n int, err error) {
var _p0 unsafe.Pointer
if len(buf) > 0 {
_p0 = unsafe.Pointer(&buf[0])
} else {
_p0 = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall(SYS_GETCWD, uintptr(_p0), uintptr(len(buf)), 0)
n = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func wait4(pid int, wstatus *_C_int, options int, rusage *Rusage) (wpid int, err error) {
r0, _, e1 := Syscall6(SYS_WAIT4, uintptr(pid), uintptr(unsafe.Pointer(wstatus)), uintptr(options), uintptr(unsafe.Pointer(rusage)), 0, 0)
wpid = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func ptrace(request int, pid int, addr uintptr, data uintptr) (err error) {
_, _, e1 := Syscall6(SYS_PTRACE, uintptr(request), uintptr(pid), uintptr(addr), uintptr(data), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func ptracePtr(request int, pid int, addr uintptr, data unsafe.Pointer) (err error) {
_, _, e1 := Syscall6(SYS_PTRACE, uintptr(request), uintptr(pid), uintptr(addr), uintptr(data), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func reboot(magic1 uint, magic2 uint, cmd int, arg string) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(arg)
if err != nil {
return
}
_, _, e1 := Syscall6(SYS_REBOOT, uintptr(magic1), uintptr(magic2), uintptr(cmd), uintptr(unsafe.Pointer(_p0)), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func mount(source string, target string, fstype string, flags uintptr, data *byte) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(source)
if err != nil {
return
}
var _p1 *byte
_p1, err = BytePtrFromString(target)
if err != nil {
return
}
var _p2 *byte
_p2, err = BytePtrFromString(fstype)
if err != nil {
return
}
_, _, e1 := Syscall6(SYS_MOUNT, uintptr(unsafe.Pointer(_p0)), uintptr(unsafe.Pointer(_p1)), uintptr(unsafe.Pointer(_p2)), uintptr(flags), uintptr(unsafe.Pointer(data)), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Acct(path string) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
_, _, e1 := Syscall(SYS_ACCT, uintptr(unsafe.Pointer(_p0)), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Adjtimex(buf *Timex) (state int, err error) {
r0, _, e1 := Syscall(SYS_ADJTIMEX, uintptr(unsafe.Pointer(buf)), 0, 0)
state = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Chdir(path string) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
_, _, e1 := Syscall(SYS_CHDIR, uintptr(unsafe.Pointer(_p0)), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Chroot(path string) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
_, _, e1 := Syscall(SYS_CHROOT, uintptr(unsafe.Pointer(_p0)), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Close(fd int) (err error) {
_, _, e1 := Syscall(SYS_CLOSE, uintptr(fd), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Dup(oldfd int) (fd int, err error) {
r0, _, e1 := Syscall(SYS_DUP, uintptr(oldfd), 0, 0)
fd = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Dup3(oldfd int, newfd int, flags int) (err error) {
_, _, e1 := Syscall(SYS_DUP3, uintptr(oldfd), uintptr(newfd), uintptr(flags))
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func EpollCreate1(flag int) (fd int, err error) {
r0, _, e1 := RawSyscall(SYS_EPOLL_CREATE1, uintptr(flag), 0, 0)
fd = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func EpollCtl(epfd int, op int, fd int, event *EpollEvent) (err error) {
_, _, e1 := RawSyscall6(SYS_EPOLL_CTL, uintptr(epfd), uintptr(op), uintptr(fd), uintptr(unsafe.Pointer(event)), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Fallocate(fd int, mode uint32, off int64, len int64) (err error) {
_, _, e1 := Syscall6(SYS_FALLOCATE, uintptr(fd), uintptr(mode), uintptr(off), uintptr(len), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Fchdir(fd int) (err error) {
_, _, e1 := Syscall(SYS_FCHDIR, uintptr(fd), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Fchmod(fd int, mode uint32) (err error) {
_, _, e1 := Syscall(SYS_FCHMOD, uintptr(fd), uintptr(mode), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Fchownat(dirfd int, path string, uid int, gid int, flags int) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
_, _, e1 := Syscall6(SYS_FCHOWNAT, uintptr(dirfd), uintptr(unsafe.Pointer(_p0)), uintptr(uid), uintptr(gid), uintptr(flags), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func fcntl(fd int, cmd int, arg int) (val int, err error) {
r0, _, e1 := Syscall(SYS_FCNTL, uintptr(fd), uintptr(cmd), uintptr(arg))
val = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Fdatasync(fd int) (err error) {
_, _, e1 := Syscall(SYS_FDATASYNC, uintptr(fd), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Flock(fd int, how int) (err error) {
_, _, e1 := Syscall(SYS_FLOCK, uintptr(fd), uintptr(how), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Fsync(fd int) (err error) {
_, _, e1 := Syscall(SYS_FSYNC, uintptr(fd), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Getdents(fd int, buf []byte) (n int, err error) {
var _p0 unsafe.Pointer
if len(buf) > 0 {
_p0 = unsafe.Pointer(&buf[0])
} else {
_p0 = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall(SYS_GETDENTS64, uintptr(fd), uintptr(_p0), uintptr(len(buf)))
n = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Getpgid(pid int) (pgid int, err error) {
r0, _, e1 := RawSyscall(SYS_GETPGID, uintptr(pid), 0, 0)
pgid = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Getpid() (pid int) {
r0, _ := rawSyscallNoError(SYS_GETPID, 0, 0, 0)
pid = int(r0)
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Getppid() (ppid int) {
r0, _ := rawSyscallNoError(SYS_GETPPID, 0, 0, 0)
ppid = int(r0)
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Getpriority(which int, who int) (prio int, err error) {
r0, _, e1 := Syscall(SYS_GETPRIORITY, uintptr(which), uintptr(who), 0)
prio = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Getrusage(who int, rusage *Rusage) (err error) {
_, _, e1 := RawSyscall(SYS_GETRUSAGE, uintptr(who), uintptr(unsafe.Pointer(rusage)), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Gettid() (tid int) {
r0, _ := rawSyscallNoError(SYS_GETTID, 0, 0, 0)
tid = int(r0)
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Getxattr(path string, attr string, dest []byte) (sz int, err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
var _p1 *byte
_p1, err = BytePtrFromString(attr)
if err != nil {
return
}
var _p2 unsafe.Pointer
if len(dest) > 0 {
_p2 = unsafe.Pointer(&dest[0])
} else {
_p2 = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall6(SYS_GETXATTR, uintptr(unsafe.Pointer(_p0)), uintptr(unsafe.Pointer(_p1)), uintptr(_p2), uintptr(len(dest)), 0, 0)
sz = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func InotifyAddWatch(fd int, pathname string, mask uint32) (watchdesc int, err error) {
var _p0 *byte
_p0, err = BytePtrFromString(pathname)
if err != nil {
return
}
r0, _, e1 := Syscall(SYS_INOTIFY_ADD_WATCH, uintptr(fd), uintptr(unsafe.Pointer(_p0)), uintptr(mask))
watchdesc = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func InotifyInit1(flags int) (fd int, err error) {
r0, _, e1 := RawSyscall(SYS_INOTIFY_INIT1, uintptr(flags), 0, 0)
fd = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func InotifyRmWatch(fd int, watchdesc uint32) (success int, err error) {
r0, _, e1 := RawSyscall(SYS_INOTIFY_RM_WATCH, uintptr(fd), uintptr(watchdesc), 0)
success = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Kill(pid int, sig Signal) (err error) {
_, _, e1 := RawSyscall(SYS_KILL, uintptr(pid), uintptr(sig), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Klogctl(typ int, buf []byte) (n int, err error) {
var _p0 unsafe.Pointer
if len(buf) > 0 {
_p0 = unsafe.Pointer(&buf[0])
} else {
_p0 = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall(SYS_SYSLOG, uintptr(typ), uintptr(_p0), uintptr(len(buf)))
n = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Listxattr(path string, dest []byte) (sz int, err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
var _p1 unsafe.Pointer
if len(dest) > 0 {
_p1 = unsafe.Pointer(&dest[0])
} else {
_p1 = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall(SYS_LISTXATTR, uintptr(unsafe.Pointer(_p0)), uintptr(_p1), uintptr(len(dest)))
sz = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Mkdirat(dirfd int, path string, mode uint32) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
_, _, e1 := Syscall(SYS_MKDIRAT, uintptr(dirfd), uintptr(unsafe.Pointer(_p0)), uintptr(mode))
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Mknodat(dirfd int, path string, mode uint32, dev int) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
_, _, e1 := Syscall6(SYS_MKNODAT, uintptr(dirfd), uintptr(unsafe.Pointer(_p0)), uintptr(mode), uintptr(dev), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Nanosleep(time *Timespec, leftover *Timespec) (err error) {
_, _, e1 := Syscall(SYS_NANOSLEEP, uintptr(unsafe.Pointer(time)), uintptr(unsafe.Pointer(leftover)), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func PivotRoot(newroot string, putold string) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(newroot)
if err != nil {
return
}
var _p1 *byte
_p1, err = BytePtrFromString(putold)
if err != nil {
return
}
_, _, e1 := Syscall(SYS_PIVOT_ROOT, uintptr(unsafe.Pointer(_p0)), uintptr(unsafe.Pointer(_p1)), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func prlimit(pid int, resource int, newlimit *Rlimit, old *Rlimit) (err error) {
_, _, e1 := RawSyscall6(SYS_PRLIMIT64, uintptr(pid), uintptr(resource), uintptr(unsafe.Pointer(newlimit)), uintptr(unsafe.Pointer(old)), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func read(fd int, p []byte) (n int, err error) {
var _p0 unsafe.Pointer
if len(p) > 0 {
_p0 = unsafe.Pointer(&p[0])
} else {
_p0 = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall(SYS_READ, uintptr(fd), uintptr(_p0), uintptr(len(p)))
n = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Removexattr(path string, attr string) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
var _p1 *byte
_p1, err = BytePtrFromString(attr)
if err != nil {
return
}
_, _, e1 := Syscall(SYS_REMOVEXATTR, uintptr(unsafe.Pointer(_p0)), uintptr(unsafe.Pointer(_p1)), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Setdomainname(p []byte) (err error) {
var _p0 unsafe.Pointer
if len(p) > 0 {
_p0 = unsafe.Pointer(&p[0])
} else {
_p0 = unsafe.Pointer(&_zero)
}
_, _, e1 := Syscall(SYS_SETDOMAINNAME, uintptr(_p0), uintptr(len(p)), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Sethostname(p []byte) (err error) {
var _p0 unsafe.Pointer
if len(p) > 0 {
_p0 = unsafe.Pointer(&p[0])
} else {
_p0 = unsafe.Pointer(&_zero)
}
_, _, e1 := Syscall(SYS_SETHOSTNAME, uintptr(_p0), uintptr(len(p)), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Setpgid(pid int, pgid int) (err error) {
_, _, e1 := RawSyscall(SYS_SETPGID, uintptr(pid), uintptr(pgid), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Setsid() (pid int, err error) {
r0, _, e1 := RawSyscall(SYS_SETSID, 0, 0, 0)
pid = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Settimeofday(tv *Timeval) (err error) {
_, _, e1 := RawSyscall(SYS_SETTIMEOFDAY, uintptr(unsafe.Pointer(tv)), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Setpriority(which int, who int, prio int) (err error) {
_, _, e1 := Syscall(SYS_SETPRIORITY, uintptr(which), uintptr(who), uintptr(prio))
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Setxattr(path string, attr string, data []byte, flags int) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
var _p1 *byte
_p1, err = BytePtrFromString(attr)
if err != nil {
return
}
var _p2 unsafe.Pointer
if len(data) > 0 {
_p2 = unsafe.Pointer(&data[0])
} else {
_p2 = unsafe.Pointer(&_zero)
}
_, _, e1 := Syscall6(SYS_SETXATTR, uintptr(unsafe.Pointer(_p0)), uintptr(unsafe.Pointer(_p1)), uintptr(_p2), uintptr(len(data)), uintptr(flags), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Sync() {
Syscall(SYS_SYNC, 0, 0, 0)
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Sysinfo(info *Sysinfo_t) (err error) {
_, _, e1 := RawSyscall(SYS_SYSINFO, uintptr(unsafe.Pointer(info)), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Tee(rfd int, wfd int, len int, flags int) (n int64, err error) {
r0, _, e1 := Syscall6(SYS_TEE, uintptr(rfd), uintptr(wfd), uintptr(len), uintptr(flags), 0, 0)
n = int64(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Tgkill(tgid int, tid int, sig Signal) (err error) {
_, _, e1 := RawSyscall(SYS_TGKILL, uintptr(tgid), uintptr(tid), uintptr(sig))
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Times(tms *Tms) (ticks uintptr, err error) {
r0, _, e1 := RawSyscall(SYS_TIMES, uintptr(unsafe.Pointer(tms)), 0, 0)
ticks = uintptr(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Umask(mask int) (oldmask int) {
r0, _ := rawSyscallNoError(SYS_UMASK, uintptr(mask), 0, 0)
oldmask = int(r0)
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Uname(buf *Utsname) (err error) {
_, _, e1 := RawSyscall(SYS_UNAME, uintptr(unsafe.Pointer(buf)), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Unmount(target string, flags int) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(target)
if err != nil {
return
}
_, _, e1 := Syscall(SYS_UMOUNT2, uintptr(unsafe.Pointer(_p0)), uintptr(flags), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Unshare(flags int) (err error) {
_, _, e1 := Syscall(SYS_UNSHARE, uintptr(flags), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func write(fd int, p []byte) (n int, err error) {
var _p0 unsafe.Pointer
if len(p) > 0 {
_p0 = unsafe.Pointer(&p[0])
} else {
_p0 = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall(SYS_WRITE, uintptr(fd), uintptr(_p0), uintptr(len(p)))
n = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func exitThread(code int) (err error) {
_, _, e1 := Syscall(SYS_EXIT, uintptr(code), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func readlen(fd int, p *byte, np int) (n int, err error) {
r0, _, e1 := Syscall(SYS_READ, uintptr(fd), uintptr(unsafe.Pointer(p)), uintptr(np))
n = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func writelen(fd int, p *byte, np int) (n int, err error) {
r0, _, e1 := Syscall(SYS_WRITE, uintptr(fd), uintptr(unsafe.Pointer(p)), uintptr(np))
n = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func munmap(addr uintptr, length uintptr) (err error) {
_, _, e1 := Syscall(SYS_MUNMAP, uintptr(addr), uintptr(length), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Madvise(b []byte, advice int) (err error) {
var _p0 unsafe.Pointer
if len(b) > 0 {
_p0 = unsafe.Pointer(&b[0])
} else {
_p0 = unsafe.Pointer(&_zero)
}
_, _, e1 := Syscall(SYS_MADVISE, uintptr(_p0), uintptr(len(b)), uintptr(advice))
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Mprotect(b []byte, prot int) (err error) {
var _p0 unsafe.Pointer
if len(b) > 0 {
_p0 = unsafe.Pointer(&b[0])
} else {
_p0 = unsafe.Pointer(&_zero)
}
_, _, e1 := Syscall(SYS_MPROTECT, uintptr(_p0), uintptr(len(b)), uintptr(prot))
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Mlock(b []byte) (err error) {
var _p0 unsafe.Pointer
if len(b) > 0 {
_p0 = unsafe.Pointer(&b[0])
} else {
_p0 = unsafe.Pointer(&_zero)
}
_, _, e1 := Syscall(SYS_MLOCK, uintptr(_p0), uintptr(len(b)), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Munlock(b []byte) (err error) {
var _p0 unsafe.Pointer
if len(b) > 0 {
_p0 = unsafe.Pointer(&b[0])
} else {
_p0 = unsafe.Pointer(&_zero)
}
_, _, e1 := Syscall(SYS_MUNLOCK, uintptr(_p0), uintptr(len(b)), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Mlockall(flags int) (err error) {
_, _, e1 := Syscall(SYS_MLOCKALL, uintptr(flags), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Munlockall() (err error) {
_, _, e1 := Syscall(SYS_MUNLOCKALL, 0, 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Dup2(oldfd int, newfd int) (err error) {
_, _, e1 := Syscall(SYS_DUP2, uintptr(oldfd), uintptr(newfd), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Fchown(fd int, uid int, gid int) (err error) {
_, _, e1 := Syscall(SYS_FCHOWN, uintptr(fd), uintptr(uid), uintptr(gid))
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Fstat(fd int, stat *Stat_t) (err error) {
_, _, e1 := Syscall(SYS_FSTAT, uintptr(fd), uintptr(unsafe.Pointer(stat)), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Fstatfs(fd int, buf *Statfs_t) (err error) {
_, _, e1 := Syscall(SYS_FSTATFS, uintptr(fd), uintptr(unsafe.Pointer(buf)), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Ftruncate(fd int, length int64) (err error) {
_, _, e1 := Syscall(SYS_FTRUNCATE, uintptr(fd), uintptr(length), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Getegid() (egid int) {
r0, _ := rawSyscallNoError(SYS_GETEGID, 0, 0, 0)
egid = int(r0)
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Geteuid() (euid int) {
r0, _ := rawSyscallNoError(SYS_GETEUID, 0, 0, 0)
euid = int(r0)
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Getgid() (gid int) {
r0, _ := rawSyscallNoError(SYS_GETGID, 0, 0, 0)
gid = int(r0)
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Getrlimit(resource int, rlim *Rlimit) (err error) {
_, _, e1 := RawSyscall(SYS_GETRLIMIT, uintptr(resource), uintptr(unsafe.Pointer(rlim)), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Getuid() (uid int) {
r0, _ := rawSyscallNoError(SYS_GETUID, 0, 0, 0)
uid = int(r0)
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func InotifyInit() (fd int, err error) {
r0, _, e1 := RawSyscall(SYS_INOTIFY_INIT, 0, 0, 0)
fd = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Ioperm(from int, num int, on int) (err error) {
_, _, e1 := Syscall(SYS_IOPERM, uintptr(from), uintptr(num), uintptr(on))
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Iopl(level int) (err error) {
_, _, e1 := Syscall(SYS_IOPL, uintptr(level), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Listen(s int, n int) (err error) {
_, _, e1 := Syscall(SYS_LISTEN, uintptr(s), uintptr(n), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Pause() (err error) {
_, _, e1 := Syscall(SYS_PAUSE, 0, 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func pread(fd int, p []byte, offset int64) (n int, err error) {
var _p0 unsafe.Pointer
if len(p) > 0 {
_p0 = unsafe.Pointer(&p[0])
} else {
_p0 = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall6(SYS_PREAD64, uintptr(fd), uintptr(_p0), uintptr(len(p)), uintptr(offset), 0, 0)
n = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func pwrite(fd int, p []byte, offset int64) (n int, err error) {
var _p0 unsafe.Pointer
if len(p) > 0 {
_p0 = unsafe.Pointer(&p[0])
} else {
_p0 = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall6(SYS_PWRITE64, uintptr(fd), uintptr(_p0), uintptr(len(p)), uintptr(offset), 0, 0)
n = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Renameat(olddirfd int, oldpath string, newdirfd int, newpath string) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(oldpath)
if err != nil {
return
}
var _p1 *byte
_p1, err = BytePtrFromString(newpath)
if err != nil {
return
}
_, _, e1 := Syscall6(SYS_RENAMEAT, uintptr(olddirfd), uintptr(unsafe.Pointer(_p0)), uintptr(newdirfd), uintptr(unsafe.Pointer(_p1)), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Seek(fd int, offset int64, whence int) (off int64, err error) {
r0, _, e1 := Syscall(SYS_LSEEK, uintptr(fd), uintptr(offset), uintptr(whence))
off = int64(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Select(nfd int, r *FdSet, w *FdSet, e *FdSet, timeout *Timeval) (n int, err error) {
r0, _, e1 := Syscall6(SYS_SELECT, uintptr(nfd), uintptr(unsafe.Pointer(r)), uintptr(unsafe.Pointer(w)), uintptr(unsafe.Pointer(e)), uintptr(unsafe.Pointer(timeout)), 0)
n = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func sendfile(outfd int, infd int, offset *int64, count int) (written int, err error) {
r0, _, e1 := Syscall6(SYS_SENDFILE, uintptr(outfd), uintptr(infd), uintptr(unsafe.Pointer(offset)), uintptr(count), 0, 0)
written = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Setfsgid(gid int) (err error) {
_, _, e1 := Syscall(SYS_SETFSGID, uintptr(gid), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Setfsuid(uid int) (err error) {
_, _, e1 := Syscall(SYS_SETFSUID, uintptr(uid), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Setrlimit(resource int, rlim *Rlimit) (err error) {
_, _, e1 := RawSyscall(SYS_SETRLIMIT, uintptr(resource), uintptr(unsafe.Pointer(rlim)), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Shutdown(fd int, how int) (err error) {
_, _, e1 := Syscall(SYS_SHUTDOWN, uintptr(fd), uintptr(how), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Splice(rfd int, roff *int64, wfd int, woff *int64, len int, flags int) (n int64, err error) {
r0, _, e1 := Syscall6(SYS_SPLICE, uintptr(rfd), uintptr(unsafe.Pointer(roff)), uintptr(wfd), uintptr(unsafe.Pointer(woff)), uintptr(len), uintptr(flags))
n = int64(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Statfs(path string, buf *Statfs_t) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
_, _, e1 := Syscall(SYS_STATFS, uintptr(unsafe.Pointer(_p0)), uintptr(unsafe.Pointer(buf)), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func SyncFileRange(fd int, off int64, n int64, flags int) (err error) {
_, _, e1 := Syscall6(SYS_SYNC_FILE_RANGE, uintptr(fd), uintptr(off), uintptr(n), uintptr(flags), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Truncate(path string, length int64) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
_, _, e1 := Syscall(SYS_TRUNCATE, uintptr(unsafe.Pointer(_p0)), uintptr(length), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Ustat(dev int, ubuf *Ustat_t) (err error) {
_, _, e1 := Syscall(SYS_USTAT, uintptr(dev), uintptr(unsafe.Pointer(ubuf)), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func accept4(s int, rsa *RawSockaddrAny, addrlen *_Socklen, flags int) (fd int, err error) {
r0, _, e1 := Syscall6(SYS_ACCEPT4, uintptr(s), uintptr(unsafe.Pointer(rsa)), uintptr(unsafe.Pointer(addrlen)), uintptr(flags), 0, 0)
fd = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func bind(s int, addr unsafe.Pointer, addrlen _Socklen) (err error) {
_, _, e1 := Syscall(SYS_BIND, uintptr(s), uintptr(addr), uintptr(addrlen))
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func connect(s int, addr unsafe.Pointer, addrlen _Socklen) (err error) {
_, _, e1 := Syscall(SYS_CONNECT, uintptr(s), uintptr(addr), uintptr(addrlen))
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func fstatat(fd int, path string, stat *Stat_t, flags int) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
_, _, e1 := Syscall6(SYS_NEWFSTATAT, uintptr(fd), uintptr(unsafe.Pointer(_p0)), uintptr(unsafe.Pointer(stat)), uintptr(flags), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func getgroups(n int, list *_Gid_t) (nn int, err error) {
r0, _, e1 := RawSyscall(SYS_GETGROUPS, uintptr(n), uintptr(unsafe.Pointer(list)), 0)
nn = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func getsockopt(s int, level int, name int, val unsafe.Pointer, vallen *_Socklen) (err error) {
_, _, e1 := Syscall6(SYS_GETSOCKOPT, uintptr(s), uintptr(level), uintptr(name), uintptr(val), uintptr(unsafe.Pointer(vallen)), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func setsockopt(s int, level int, name int, val unsafe.Pointer, vallen uintptr) (err error) {
_, _, e1 := Syscall6(SYS_SETSOCKOPT, uintptr(s), uintptr(level), uintptr(name), uintptr(val), uintptr(vallen), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func socket(domain int, typ int, proto int) (fd int, err error) {
r0, _, e1 := RawSyscall(SYS_SOCKET, uintptr(domain), uintptr(typ), uintptr(proto))
fd = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func socketpair(domain int, typ int, proto int, fd *[2]int32) (err error) {
_, _, e1 := RawSyscall6(SYS_SOCKETPAIR, uintptr(domain), uintptr(typ), uintptr(proto), uintptr(unsafe.Pointer(fd)), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func getpeername(fd int, rsa *RawSockaddrAny, addrlen *_Socklen) (err error) {
_, _, e1 := RawSyscall(SYS_GETPEERNAME, uintptr(fd), uintptr(unsafe.Pointer(rsa)), uintptr(unsafe.Pointer(addrlen)))
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func getsockname(fd int, rsa *RawSockaddrAny, addrlen *_Socklen) (err error) {
_, _, e1 := RawSyscall(SYS_GETSOCKNAME, uintptr(fd), uintptr(unsafe.Pointer(rsa)), uintptr(unsafe.Pointer(addrlen)))
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func recvfrom(fd int, p []byte, flags int, from *RawSockaddrAny, fromlen *_Socklen) (n int, err error) {
var _p0 unsafe.Pointer
if len(p) > 0 {
_p0 = unsafe.Pointer(&p[0])
} else {
_p0 = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall6(SYS_RECVFROM, uintptr(fd), uintptr(_p0), uintptr(len(p)), uintptr(flags), uintptr(unsafe.Pointer(from)), uintptr(unsafe.Pointer(fromlen)))
n = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func sendto(s int, buf []byte, flags int, to unsafe.Pointer, addrlen _Socklen) (err error) {
var _p0 unsafe.Pointer
if len(buf) > 0 {
_p0 = unsafe.Pointer(&buf[0])
} else {
_p0 = unsafe.Pointer(&_zero)
}
_, _, e1 := Syscall6(SYS_SENDTO, uintptr(s), uintptr(_p0), uintptr(len(buf)), uintptr(flags), uintptr(to), uintptr(addrlen))
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func recvmsg(s int, msg *Msghdr, flags int) (n int, err error) {
r0, _, e1 := Syscall(SYS_RECVMSG, uintptr(s), uintptr(unsafe.Pointer(msg)), uintptr(flags))
n = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func sendmsg(s int, msg *Msghdr, flags int) (n int, err error) {
r0, _, e1 := Syscall(SYS_SENDMSG, uintptr(s), uintptr(unsafe.Pointer(msg)), uintptr(flags))
n = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func mmap(addr uintptr, length uintptr, prot int, flags int, fd int, offset int64) (xaddr uintptr, err error) {
r0, _, e1 := Syscall6(SYS_MMAP, uintptr(addr), uintptr(length), uintptr(prot), uintptr(flags), uintptr(fd), uintptr(offset))
xaddr = uintptr(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func EpollWait(epfd int, events []EpollEvent, msec int) (n int, err error) {
var _p0 unsafe.Pointer
if len(events) > 0 {
_p0 = unsafe.Pointer(&events[0])
} else {
_p0 = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall6(SYS_EPOLL_WAIT, uintptr(epfd), uintptr(_p0), uintptr(len(events)), uintptr(msec), 0, 0)
n = int(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func futimesat(dirfd int, path string, times *[2]Timeval) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
_, _, e1 := Syscall(SYS_FUTIMESAT, uintptr(dirfd), uintptr(unsafe.Pointer(_p0)), uintptr(unsafe.Pointer(times)))
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Utime(path string, buf *Utimbuf) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
_, _, e1 := Syscall(SYS_UTIME, uintptr(unsafe.Pointer(_p0)), uintptr(unsafe.Pointer(buf)), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func utimes(path string, times *[2]Timeval) (err error) {
var _p0 *byte
_p0, err = BytePtrFromString(path)
if err != nil {
return
}
_, _, e1 := Syscall(SYS_UTIMES, uintptr(unsafe.Pointer(_p0)), uintptr(unsafe.Pointer(times)), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package testing
import (
"runtime"
)
// AllocsPerRun returns the average number of allocations during calls to f.
// Although the return value has type float64, it will always be an integral value.
//
// To compute the number of allocations, the function will first be run once as
// a warm-up. The average number of allocations over the specified number of
// runs will then be measured and returned.
//
// AllocsPerRun sets GOMAXPROCS to 1 during its measurement and will restore
// it before returning.
func AllocsPerRun(runs int, f func()) (avg float64) {
defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1))
// Warm up the function
f()
// Measure the starting statistics
var memstats runtime.MemStats
runtime.ReadMemStats(&memstats)
mallocs := 0 - memstats.Mallocs
// Run the function the specified number of times
for i := 0; i < runs; i++ {
f()
}
// Read the final statistics
runtime.ReadMemStats(&memstats)
mallocs += memstats.Mallocs
// Average the mallocs over the runs (not counting the warm-up).
// We are forced to return a float64 because the API is silly, but do
// the division as integers so we can ask if AllocsPerRun()==1
// instead of AllocsPerRun()<2.
return float64(mallocs / uint64(runs))
}
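// Illustrative use of AllocsPerRun inside a test (a sketch; the allocation
// being measured is a stand-in):
//
//	func TestBufferAllocs(t *testing.T) {
//		allocs := testing.AllocsPerRun(100, func() {
//			_ = make([]byte, 1024)
//		})
//		if allocs > 1 {
//			t.Errorf("got %v allocs per run, want at most 1", allocs)
//		}
//	}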
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package testing
import (
"flag"
"fmt"
"internal/race"
"internal/sysinfo"
"io"
"math"
"os"
"runtime"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"unicode"
)
func initBenchmarkFlags() {
matchBenchmarks = flag.String("test.bench", "", "run only benchmarks matching `regexp`")
benchmarkMemory = flag.Bool("test.benchmem", false, "print memory allocations for benchmarks")
flag.Var(&benchTime, "test.benchtime", "run each benchmark for duration `d`")
}
var (
matchBenchmarks *string
benchmarkMemory *bool
benchTime = durationOrCountFlag{d: 1 * time.Second} // changed during test of testing package
)
type durationOrCountFlag struct {
d time.Duration
n int
allowZero bool
}
func (f *durationOrCountFlag) String() string {
if f.n > 0 {
return fmt.Sprintf("%dx", f.n)
}
return f.d.String()
}
func (f *durationOrCountFlag) Set(s string) error {
if strings.HasSuffix(s, "x") {
n, err := strconv.ParseInt(s[:len(s)-1], 10, 0)
if err != nil || n < 0 || (!f.allowZero && n == 0) {
return fmt.Errorf("invalid count")
}
*f = durationOrCountFlag{n: int(n)}
return nil
}
d, err := time.ParseDuration(s)
if err != nil || d < 0 || (!f.allowZero && d == 0) {
return fmt.Errorf("invalid duration")
}
*f = durationOrCountFlag{d: d}
return nil
}
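// Illustrative behavior of Set (hypothetical direct calls; normally the
// flag package invokes Set while parsing -test.benchtime):
//
//	var f durationOrCountFlag
//	_ = f.Set("100x") // iteration count: f.n == 100
//	_ = f.Set("2s")   // duration: f.d == 2 * time.Second (and f.n reset to 0)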
// Global lock to ensure only one benchmark runs at a time.
var benchmarkLock sync.Mutex
// Used for every benchmark for measuring memory.
var memStats runtime.MemStats
// InternalBenchmark is an internal type but exported because it is cross-package;
// it is part of the implementation of the "go test" command.
type InternalBenchmark struct {
Name string
F func(b *B)
}
// B is a type passed to Benchmark functions to manage benchmark
// timing and to specify the number of iterations to run.
//
// A benchmark ends when its Benchmark function returns or calls any of the methods
// FailNow, Fatal, Fatalf, SkipNow, Skip, or Skipf. Those methods must be called
// only from the goroutine running the Benchmark function.
// The other reporting methods, such as the variations of Log and Error,
// may be called simultaneously from multiple goroutines.
//
// Like in tests, benchmark logs are accumulated during execution
// and dumped to standard output when done. Unlike in tests, benchmark logs
// are always printed, so as not to hide output whose existence may be
// affecting benchmark results.
type B struct {
common
importPath string // import path of the package containing the benchmark
context *benchContext
N int
previousN int // number of iterations in the previous run
previousDuration time.Duration // total duration of the previous run
benchFunc func(b *B)
benchTime durationOrCountFlag
bytes int64
missingBytes bool // one of the subbenchmarks does not have bytes set.
timerOn bool
showAllocResult bool
result BenchmarkResult
parallelism int // RunParallel creates parallelism*GOMAXPROCS goroutines
// The initial states of memStats.Mallocs and memStats.TotalAlloc.
startAllocs uint64
startBytes uint64
// The net total of this test after being run.
netAllocs uint64
netBytes uint64
// Extra metrics collected by ReportMetric.
extra map[string]float64
}
// StartTimer starts timing a benchmark. This function is called automatically
// before a benchmark starts, but it can also be used to resume timing after
// a call to StopTimer.
func (b *B) StartTimer() {
if !b.timerOn {
runtime.ReadMemStats(&memStats)
b.startAllocs = memStats.Mallocs
b.startBytes = memStats.TotalAlloc
b.start = time.Now()
b.timerOn = true
}
}
// StopTimer stops timing a benchmark. This can be used to pause the timer
// while performing complex initialization that you don't
// want to measure.
func (b *B) StopTimer() {
if b.timerOn {
b.duration += time.Since(b.start)
runtime.ReadMemStats(&memStats)
b.netAllocs += memStats.Mallocs - b.startAllocs
b.netBytes += memStats.TotalAlloc - b.startBytes
b.timerOn = false
}
}
// ResetTimer zeroes the elapsed benchmark time and memory allocation counters
// and deletes user-reported metrics.
// It does not affect whether the timer is running.
func (b *B) ResetTimer() {
if b.extra == nil {
// Allocate the extra map before reading memory stats.
// Pre-size it to make more allocation unlikely.
b.extra = make(map[string]float64, 16)
} else {
for k := range b.extra {
delete(b.extra, k)
}
}
if b.timerOn {
runtime.ReadMemStats(&memStats)
b.startAllocs = memStats.Mallocs
b.startBytes = memStats.TotalAlloc
b.start = time.Now()
}
b.duration = 0
b.netAllocs = 0
b.netBytes = 0
}
// SetBytes records the number of bytes processed in a single operation.
// If this is called, the benchmark will report MB/s in addition to ns/op.
func (b *B) SetBytes(n int64) { b.bytes = n }
// ReportAllocs enables malloc statistics for this benchmark.
// It is equivalent to setting -test.benchmem, but it only affects the
// benchmark function that calls ReportAllocs.
func (b *B) ReportAllocs() {
b.showAllocResult = true
}
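// Illustrative interplay of the methods above in a benchmark (a sketch;
// buildInput and process are hypothetical helpers):
//
//	func BenchmarkProcess(b *testing.B) {
//		b.ReportAllocs()
//		data := buildInput() // expensive setup ...
//		b.SetBytes(int64(len(data)))
//		b.ResetTimer() // ... excluded from the measurement by resetting here
//		for i := 0; i < b.N; i++ {
//			process(data)
//		}
//	}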
// runN runs a single benchmark for the specified number of iterations.
func (b *B) runN(n int) {
benchmarkLock.Lock()
defer benchmarkLock.Unlock()
defer b.runCleanup(normalPanic)
// Try to get a comparable environment for each run
// by clearing garbage from previous runs.
runtime.GC()
b.raceErrors = -race.Errors()
b.N = n
b.parallelism = 1
b.ResetTimer()
b.StartTimer()
b.benchFunc(b)
b.StopTimer()
b.previousN = n
b.previousDuration = b.duration
b.raceErrors += race.Errors()
if b.raceErrors > 0 {
b.Errorf("race detected during execution of benchmark")
}
}
func min(x, y int64) int64 {
if x > y {
return y
}
return x
}
func max(x, y int64) int64 {
if x < y {
return y
}
return x
}
// run1 runs the first iteration of benchFunc. It reports whether more
// iterations of this benchmark should be run.
func (b *B) run1() bool {
if ctx := b.context; ctx != nil {
// Extend maxLen, if needed.
if n := len(b.name) + ctx.extLen + 1; n > ctx.maxLen {
ctx.maxLen = n + 8 // Add additional slack to avoid too many jumps in size.
}
}
go func() {
// Signal that we're done whether we return normally
// or by FailNow's runtime.Goexit.
defer func() {
b.signal <- true
}()
b.runN(1)
}()
<-b.signal
if b.failed {
fmt.Fprintf(b.w, "%s--- FAIL: %s\n%s", b.chatty.prefix(), b.name, b.output)
return false
}
// Only print the output if we know we are not going to proceed.
// Otherwise it is printed in processBench.
b.mu.RLock()
finished := b.finished
b.mu.RUnlock()
if b.hasSub.Load() || finished {
tag := "BENCH"
if b.skipped {
tag = "SKIP"
}
if b.chatty != nil && (len(b.output) > 0 || finished) {
b.trimOutput()
fmt.Fprintf(b.w, "%s--- %s: %s\n%s", b.chatty.prefix(), tag, b.name, b.output)
}
return false
}
return true
}
var labelsOnce sync.Once
// run executes the benchmark in a separate goroutine, including all of its
// subbenchmarks. b must not have subbenchmarks.
func (b *B) run() {
labelsOnce.Do(func() {
fmt.Fprintf(b.w, "goos: %s\n", runtime.GOOS)
fmt.Fprintf(b.w, "goarch: %s\n", runtime.GOARCH)
if b.importPath != "" {
fmt.Fprintf(b.w, "pkg: %s\n", b.importPath)
}
if cpu := sysinfo.CPU.Name(); cpu != "" {
fmt.Fprintf(b.w, "cpu: %s\n", cpu)
}
})
if b.context != nil {
// Running go test --test.bench
b.context.processBench(b) // Must call doBench.
} else {
// Running func Benchmark.
b.doBench()
}
}
func (b *B) doBench() BenchmarkResult {
go b.launch()
<-b.signal
return b.result
}
// launch launches the benchmark function. It gradually increases the number
// of benchmark iterations until the benchmark runs for the requested benchtime.
// launch is run by the doBench function as a separate goroutine.
// run1 must have been called on b.
func (b *B) launch() {
// Signal that we're done whether we return normally
// or by FailNow's runtime.Goexit.
defer func() {
b.signal <- true
}()
// Run the benchmark for at least the specified amount of time.
if b.benchTime.n > 0 {
// We already ran a single iteration in run1.
// If -benchtime=1x was requested, use that result.
// See https://golang.org/issue/32051.
if b.benchTime.n > 1 {
b.runN(b.benchTime.n)
}
} else {
d := b.benchTime.d
for n := int64(1); !b.failed && b.duration < d && n < 1e9; {
last := n
// Predict required iterations.
goalns := d.Nanoseconds()
prevIters := int64(b.N)
prevns := b.duration.Nanoseconds()
if prevns <= 0 {
// Round up, to avoid div by zero.
prevns = 1
}
// Order of operations matters.
// For very fast benchmarks, prevIters ~= prevns.
// If you divide first, you get 0 or 1,
// which can hide an order of magnitude in execution time.
// So multiply first, then divide.
n = goalns * prevIters / prevns
// Run more iterations than we think we'll need (1.2x).
n += n / 5
// Don't grow too fast in case we had timing errors previously.
n = min(n, 100*last)
// Be sure to run at least one more than last time.
n = max(n, last+1)
// Don't run more than 1e9 times. (This also keeps n in int range on 32 bit platforms.)
n = min(n, 1e9)
b.runN(int(n))
}
}
b.result = BenchmarkResult{b.N, b.duration, b.bytes, b.netAllocs, b.netBytes, b.extra}
}
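// Worked example of the prediction above (illustrative numbers): with
// -benchtime=1s, if the previous run took 1ms (prevns = 1e6) for 100
// iterations, then n = goalns*prevIters/prevns = 1e9*100/1e6 = 100000;
// the 20% slack raises that to 120000, which is finally clamped to the
// interval [last+1, min(100*last, 1e9)].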
// Elapsed returns the measured elapsed time of the benchmark.
// The duration reported by Elapsed matches the one measured by
// StartTimer, StopTimer, and ResetTimer.
func (b *B) Elapsed() time.Duration {
d := b.duration
if b.timerOn {
d += time.Since(b.start)
}
return d
}
// ReportMetric adds "n unit" to the reported benchmark results.
// If the metric is per-iteration, the caller should divide by b.N,
// and by convention units should end in "/op".
// ReportMetric overrides any previously reported value for the same unit.
// ReportMetric panics if unit is the empty string or if unit contains
// any whitespace.
// If unit is a unit normally reported by the benchmark framework itself
// (such as "allocs/op"), ReportMetric will override that metric.
// Setting "ns/op" to 0 will suppress that built-in metric.
func (b *B) ReportMetric(n float64, unit string) {
if unit == "" {
panic("metric unit must not be empty")
}
if strings.IndexFunc(unit, unicode.IsSpace) >= 0 {
panic("metric unit must not contain whitespace")
}
b.extra[unit] = n
}
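// Illustrative use of ReportMetric (a sketch; doRequest is a hypothetical
// function returning the number of cache misses it observed):
//
//	func BenchmarkCache(b *testing.B) {
//		var misses int
//		for i := 0; i < b.N; i++ {
//			misses += doRequest()
//		}
//		// Per-iteration metric: divide by b.N; by convention the unit ends in "/op".
//		b.ReportMetric(float64(misses)/float64(b.N), "misses/op")
//	}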
// BenchmarkResult contains the results of a benchmark run.
type BenchmarkResult struct {
N int // The number of iterations.
T time.Duration // The total time taken.
Bytes int64 // Bytes processed in one iteration.
MemAllocs uint64 // The total number of memory allocations.
MemBytes uint64 // The total number of bytes allocated.
// Extra records additional metrics reported by ReportMetric.
Extra map[string]float64
}
// NsPerOp returns the "ns/op" metric.
func (r BenchmarkResult) NsPerOp() int64 {
if v, ok := r.Extra["ns/op"]; ok {
return int64(v)
}
if r.N <= 0 {
return 0
}
return r.T.Nanoseconds() / int64(r.N)
}
// mbPerSec returns the "MB/s" metric.
func (r BenchmarkResult) mbPerSec() float64 {
if v, ok := r.Extra["MB/s"]; ok {
return v
}
if r.Bytes <= 0 || r.T <= 0 || r.N <= 0 {
return 0
}
return (float64(r.Bytes) * float64(r.N) / 1e6) / r.T.Seconds()
}
// AllocsPerOp returns the "allocs/op" metric,
// which is calculated as r.MemAllocs / r.N.
func (r BenchmarkResult) AllocsPerOp() int64 {
if v, ok := r.Extra["allocs/op"]; ok {
return int64(v)
}
if r.N <= 0 {
return 0
}
return int64(r.MemAllocs) / int64(r.N)
}
// AllocedBytesPerOp returns the "B/op" metric,
// which is calculated as r.MemBytes / r.N.
func (r BenchmarkResult) AllocedBytesPerOp() int64 {
if v, ok := r.Extra["B/op"]; ok {
return int64(v)
}
if r.N <= 0 {
return 0
}
return int64(r.MemBytes) / int64(r.N)
}
// String returns a summary of the benchmark results.
// It follows the benchmark result line format from
// https://golang.org/design/14313-benchmark-format, not including the
// benchmark name.
// Extra metrics override built-in metrics of the same name.
// String does not include allocs/op or B/op, since those are reported
// by MemString.
func (r BenchmarkResult) String() string {
buf := new(strings.Builder)
fmt.Fprintf(buf, "%8d", r.N)
// Get ns/op as a float.
ns, ok := r.Extra["ns/op"]
if !ok {
ns = float64(r.T.Nanoseconds()) / float64(r.N)
}
if ns != 0 {
buf.WriteByte('\t')
prettyPrint(buf, ns, "ns/op")
}
if mbs := r.mbPerSec(); mbs != 0 {
fmt.Fprintf(buf, "\t%7.2f MB/s", mbs)
}
// Print extra metrics that aren't represented in the standard
// metrics.
var extraKeys []string
for k := range r.Extra {
switch k {
case "ns/op", "MB/s", "B/op", "allocs/op":
// Built-in metrics reported elsewhere.
continue
}
extraKeys = append(extraKeys, k)
}
sort.Strings(extraKeys)
for _, k := range extraKeys {
buf.WriteByte('\t')
prettyPrint(buf, r.Extra[k], k)
}
return buf.String()
}
func prettyPrint(w io.Writer, x float64, unit string) {
// Print all numbers with 10 places before the decimal point
// and small numbers with four sig figs. Field widths are
// chosen to fit the whole part in 10 places while aligning
// the decimal point of all fractional formats.
var format string
switch y := math.Abs(x); {
case y == 0 || y >= 999.95:
format = "%10.0f %s"
case y >= 99.995:
format = "%12.1f %s"
case y >= 9.9995:
format = "%13.2f %s"
case y >= 0.99995:
format = "%14.3f %s"
case y >= 0.099995:
format = "%15.4f %s"
case y >= 0.0099995:
format = "%16.5f %s"
case y >= 0.00099995:
format = "%17.6f %s"
default:
format = "%18.7f %s"
}
fmt.Fprintf(w, format, x, unit)
}
// MemString returns r.AllocedBytesPerOp and r.AllocsPerOp in the same format as 'go test'.
func (r BenchmarkResult) MemString() string {
return fmt.Sprintf("%8d B/op\t%8d allocs/op",
r.AllocedBytesPerOp(), r.AllocsPerOp())
}
// benchmarkName returns the full benchmark name, including the procs suffix.
func benchmarkName(name string, n int) string {
if n != 1 {
return fmt.Sprintf("%s-%d", name, n)
}
return name
}
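// For example, benchmarkName("BenchmarkSort", 4) returns "BenchmarkSort-4",
// while benchmarkName("BenchmarkSort", 1) returns "BenchmarkSort" unchanged.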
type benchContext struct {
match *matcher
maxLen int // The largest recorded benchmark name.
extLen int // Maximum extension length.
}
// RunBenchmarks is an internal function but exported because it is cross-package;
// it is part of the implementation of the "go test" command.
func RunBenchmarks(matchString func(pat, str string) (bool, error), benchmarks []InternalBenchmark) {
runBenchmarks("", matchString, benchmarks)
}
func runBenchmarks(importPath string, matchString func(pat, str string) (bool, error), benchmarks []InternalBenchmark) bool {
// If no flag was specified, don't run benchmarks.
if len(*matchBenchmarks) == 0 {
return true
}
// Collect matching benchmarks and determine longest name.
maxprocs := 1
for _, procs := range cpuList {
if procs > maxprocs {
maxprocs = procs
}
}
ctx := &benchContext{
match: newMatcher(matchString, *matchBenchmarks, "-test.bench", *skip),
extLen: len(benchmarkName("", maxprocs)),
}
var bs []InternalBenchmark
for _, Benchmark := range benchmarks {
if _, matched, _ := ctx.match.fullName(nil, Benchmark.Name); matched {
bs = append(bs, Benchmark)
benchName := benchmarkName(Benchmark.Name, maxprocs)
if l := len(benchName) + ctx.extLen + 1; l > ctx.maxLen {
ctx.maxLen = l
}
}
}
main := &B{
common: common{
name: "Main",
w: os.Stdout,
bench: true,
},
importPath: importPath,
benchFunc: func(b *B) {
for _, Benchmark := range bs {
b.Run(Benchmark.Name, Benchmark.F)
}
},
benchTime: benchTime,
context: ctx,
}
if Verbose() {
main.chatty = newChattyPrinter(main.w)
}
main.runN(1)
return !main.failed
}
// processBench runs bench b for the configured CPU counts and prints the results.
func (ctx *benchContext) processBench(b *B) {
for i, procs := range cpuList {
for j := uint(0); j < *count; j++ {
runtime.GOMAXPROCS(procs)
benchName := benchmarkName(b.name, procs)
// If it's chatty, we've already printed this information.
if b.chatty == nil {
fmt.Fprintf(b.w, "%-*s\t", ctx.maxLen, benchName)
}
// Recompute the running time for all but the first iteration.
if i > 0 || j > 0 {
b = &B{
common: common{
signal: make(chan bool),
name: b.name,
w: b.w,
chatty: b.chatty,
bench: true,
},
benchFunc: b.benchFunc,
benchTime: b.benchTime,
}
b.run1()
}
r := b.doBench()
if b.failed {
// The output could be very long here, but probably isn't.
// We print it all, regardless, because we don't want to trim the reason
// the benchmark failed.
fmt.Fprintf(b.w, "%s--- FAIL: %s\n%s", b.chatty.prefix(), benchName, b.output)
continue
}
results := r.String()
if b.chatty != nil {
fmt.Fprintf(b.w, "%-*s\t", ctx.maxLen, benchName)
}
if *benchmarkMemory || b.showAllocResult {
results += "\t" + r.MemString()
}
fmt.Fprintln(b.w, results)
// Unlike with tests, we ignore the -chatty flag and always print output for
// benchmarks since the output generation time will skew the results.
if len(b.output) > 0 {
b.trimOutput()
fmt.Fprintf(b.w, "%s--- BENCH: %s\n%s", b.chatty.prefix(), benchName, b.output)
}
if p := runtime.GOMAXPROCS(-1); p != procs {
fmt.Fprintf(os.Stderr, "testing: %s left GOMAXPROCS set to %d\n", benchName, p)
}
if b.chatty != nil && b.chatty.json {
b.chatty.Updatef("", "=== NAME %s\n", "")
}
}
}
}
// If hideStdoutForTesting is true, Run does not print the benchName.
// This avoids a spurious print during 'go test' on package testing itself,
// which invokes b.Run in its own tests (see sub_test.go).
var hideStdoutForTesting = false
// Run benchmarks f as a subbenchmark with the given name. It reports
// whether there were any failures.
//
// A subbenchmark is like any other benchmark. A benchmark that calls Run at
// least once will not be measured itself and will be called once with N=1.
func (b *B) Run(name string, f func(b *B)) bool {
// Since b has subbenchmarks, we will no longer run it as a benchmark itself.
// Release the lock and acquire it on exit to ensure locks stay paired.
b.hasSub.Store(true)
benchmarkLock.Unlock()
defer benchmarkLock.Lock()
benchName, ok, partial := b.name, true, false
if b.context != nil {
benchName, ok, partial = b.context.match.fullName(&b.common, name)
}
if !ok {
return true
}
var pc [maxStackLen]uintptr
n := runtime.Callers(2, pc[:])
sub := &B{
common: common{
signal: make(chan bool),
name: benchName,
parent: &b.common,
level: b.level + 1,
creator: pc[:n],
w: b.w,
chatty: b.chatty,
bench: true,
},
importPath: b.importPath,
benchFunc: f,
benchTime: b.benchTime,
context: b.context,
}
if partial {
// Partial name match, like -bench=X/Y matching BenchmarkX.
// Only process sub-benchmarks, if any.
sub.hasSub.Store(true)
}
if b.chatty != nil {
labelsOnce.Do(func() {
fmt.Printf("goos: %s\n", runtime.GOOS)
fmt.Printf("goarch: %s\n", runtime.GOARCH)
if b.importPath != "" {
fmt.Printf("pkg: %s\n", b.importPath)
}
if cpu := sysinfo.CPU.Name(); cpu != "" {
fmt.Printf("cpu: %s\n", cpu)
}
})
if !hideStdoutForTesting {
if b.chatty.json {
b.chatty.Updatef(benchName, "=== RUN %s\n", benchName)
}
fmt.Println(benchName)
}
}
if sub.run1() {
sub.run()
}
b.add(sub.result)
return !sub.failed
}
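// Illustrative table-driven subbenchmarks built with Run (a sketch; encode
// is a hypothetical function under test):
//
//	func BenchmarkEncode(b *testing.B) {
//		for _, size := range []int{16, 256, 4096} {
//			b.Run(fmt.Sprintf("size=%d", size), func(b *testing.B) {
//				buf := make([]byte, size)
//				for i := 0; i < b.N; i++ {
//					encode(buf)
//				}
//			})
//		}
//	}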
// add simulates running benchmarks in sequence in a single iteration. It is
// used to give some meaningful results in case func Benchmark is used in
// combination with Run.
func (b *B) add(other BenchmarkResult) {
r := &b.result
// The aggregated BenchmarkResults resemble running all subbenchmarks as
// in sequence in a single benchmark.
r.N = 1
r.T += time.Duration(other.NsPerOp())
if other.Bytes == 0 {
// Summing Bytes is meaningless in aggregate if not all subbenchmarks
// set it.
b.missingBytes = true
r.Bytes = 0
}
if !b.missingBytes {
r.Bytes += other.Bytes
}
r.MemAllocs += uint64(other.AllocsPerOp())
r.MemBytes += uint64(other.AllocedBytesPerOp())
}
// trimOutput shortens the output from a benchmark, which can be very long.
func (b *B) trimOutput() {
// The output is likely to appear multiple times because the benchmark
// is run multiple times, but at least it will be seen. This is not a big deal
// because benchmarks rarely print, but just in case, we trim it if it's too long.
const maxNewlines = 10
for nlCount, j := 0, 0; j < len(b.output); j++ {
if b.output[j] == '\n' {
nlCount++
if nlCount >= maxNewlines {
b.output = append(b.output[:j], "\n\t... [output truncated]\n"...)
break
}
}
}
}
// A PB is used by RunParallel for running parallel benchmarks.
type PB struct {
globalN *uint64 // iteration counter shared between all worker goroutines
grain uint64 // acquire that many iterations from globalN at once
cache uint64 // local cache of acquired iterations
bN uint64 // total number of iterations to execute (b.N)
}
// Next reports whether there are more iterations to execute.
func (pb *PB) Next() bool {
if pb.cache == 0 {
n := atomic.AddUint64(pb.globalN, pb.grain)
if n <= pb.bN {
pb.cache = pb.grain
} else if n < pb.bN+pb.grain {
pb.cache = pb.bN + pb.grain - n
} else {
return false
}
}
pb.cache--
return true
}
// RunParallel runs a benchmark in parallel.
// It creates multiple goroutines and distributes b.N iterations among them.
// The number of goroutines defaults to GOMAXPROCS. To increase parallelism for
// non-CPU-bound benchmarks, call SetParallelism before RunParallel.
// RunParallel is usually used with the go test -cpu flag.
//
// The body function will be run in each goroutine. It should set up any
// goroutine-local state and then iterate until pb.Next returns false.
// It should not use the StartTimer, StopTimer, or ResetTimer functions,
// because they have global effect. It should also not call Run.
//
// RunParallel reports ns/op values as wall time for the benchmark as a whole,
// not the sum of wall time or CPU time over each parallel goroutine.
func (b *B) RunParallel(body func(*PB)) {
if b.N == 0 {
return // Nothing to do when probing.
}
// Calculate grain size as number of iterations that take ~100µs.
// 100µs is enough to amortize the overhead and provide sufficient
// dynamic load balancing.
grain := uint64(0)
if b.previousN > 0 && b.previousDuration > 0 {
grain = 1e5 * uint64(b.previousN) / uint64(b.previousDuration)
}
if grain < 1 {
grain = 1
}
// We expect the inner loop and function call to take at least 10ns,
// so do not do more than 100µs/10ns=1e4 iterations.
if grain > 1e4 {
grain = 1e4
}
n := uint64(0)
numProcs := b.parallelism * runtime.GOMAXPROCS(0)
var wg sync.WaitGroup
wg.Add(numProcs)
for p := 0; p < numProcs; p++ {
go func() {
defer wg.Done()
pb := &PB{
globalN: &n,
grain: grain,
bN: uint64(b.N),
}
body(pb)
}()
}
wg.Wait()
if n <= uint64(b.N) && !b.Failed() {
b.Fatal("RunParallel: body exited without pb.Next() == false")
}
}
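// The canonical parallel benchmark from the package documentation, shown
// here for reference (assumes the bytes and text/template imports):
//
//	func BenchmarkTemplateParallel(b *testing.B) {
//		templ := template.Must(template.New("test").Parse("Hello, {{.}}!"))
//		b.RunParallel(func(pb *testing.PB) {
//			// Each goroutine has its own bytes.Buffer.
//			var buf bytes.Buffer
//			for pb.Next() {
//				buf.Reset()
//				templ.Execute(&buf, "World")
//			}
//		})
//	}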
// SetParallelism sets the number of goroutines used by RunParallel to p*GOMAXPROCS.
// There is usually no need to call SetParallelism for CPU-bound benchmarks.
// If p is less than 1, this call will have no effect.
func (b *B) SetParallelism(p int) {
if p >= 1 {
b.parallelism = p
}
}
// Benchmark benchmarks a single function. It is useful for creating
// custom benchmarks that do not use the "go test" command.
//
// If f depends on testing flags, then Init must be used to register
// those flags before calling Benchmark and before calling flag.Parse.
//
// If f calls Run, the result will be an estimate of running all its
// subbenchmarks that don't call Run in sequence in a single benchmark.
func Benchmark(f func(b *B)) BenchmarkResult {
b := &B{
common: common{
signal: make(chan bool),
w: discard{},
},
benchFunc: f,
benchTime: benchTime,
}
if b.run1() {
b.run()
}
return b.result
}
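// Illustrative standalone use of Benchmark outside "go test" (a sketch):
//
//	result := testing.Benchmark(func(b *testing.B) {
//		for i := 0; i < b.N; i++ {
//			fmt.Sprintf("%d", i)
//		}
//	})
//	fmt.Println(result.String())
//	fmt.Println(result.MemString())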
type discard struct{}
func (discard) Write(b []byte) (n int, err error) { return len(b), nil }
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Support for test coverage.
package testing
import (
"fmt"
"internal/goexperiment"
"os"
"sync/atomic"
)
// CoverBlock records the coverage data for a single basic block.
// The fields are 1-indexed, as in an editor: The opening line of
// the file is number 1, for example. Columns are measured
// in bytes.
// NOTE: This struct is internal to the testing infrastructure and may change.
// It is not covered (yet) by the Go 1 compatibility guidelines.
type CoverBlock struct {
Line0 uint32 // Line number for block start.
Col0 uint16 // Column number for block start.
Line1 uint32 // Line number for block end.
Col1 uint16 // Column number for block end.
Stmts uint16 // Number of statements included in this block.
}
var cover Cover
// Cover records information about test coverage checking.
// NOTE: This struct is internal to the testing infrastructure and may change.
// It is not covered (yet) by the Go 1 compatibility guidelines.
type Cover struct {
Mode string
Counters map[string][]uint32
Blocks map[string][]CoverBlock
CoveredPackages string
}
// Coverage reports the current code coverage as a fraction in the range [0, 1].
// If coverage is not enabled, Coverage returns 0.
//
// When running a large set of sequential test cases, checking Coverage after each one
// can be useful for identifying which test cases exercise new code paths.
// It is not a replacement for the reports generated by 'go test -cover' and
// 'go tool cover'.
func Coverage() float64 {
var n, d int64
for _, counters := range cover.Counters {
for i := range counters {
if atomic.LoadUint32(&counters[i]) > 0 {
n++
}
d++
}
}
if d == 0 {
return 0
}
return float64(n) / float64(d)
}
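// Illustrative use of Coverage for the sequential-case workflow described
// above (a sketch; runNextCase is hypothetical, and the binary must be
// built with coverage enabled for Coverage to return nonzero):
//
//	before := testing.Coverage()
//	runNextCase()
//	if testing.Coverage() > before {
//		fmt.Println("this case exercised new code paths")
//	}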
// RegisterCover records the coverage data accumulators for the tests.
// NOTE: This function is internal to the testing infrastructure and may change.
// It is not covered (yet) by the Go 1 compatibility guidelines.
func RegisterCover(c Cover) {
cover = c
}
// mustBeNil checks the error and, if present, reports it and exits.
func mustBeNil(err error) {
if err != nil {
fmt.Fprintf(os.Stderr, "testing: %s\n", err)
os.Exit(2)
}
}
// coverReport reports the coverage percentage and writes a coverage profile if requested.
func coverReport() {
if goexperiment.CoverageRedesign {
coverReport2()
return
}
var f *os.File
var err error
if *coverProfile != "" {
f, err = os.Create(toOutputDir(*coverProfile))
mustBeNil(err)
fmt.Fprintf(f, "mode: %s\n", cover.Mode)
defer func() { mustBeNil(f.Close()) }()
}
var active, total int64
var count uint32
for name, counts := range cover.Counters {
blocks := cover.Blocks[name]
for i := range counts {
stmts := int64(blocks[i].Stmts)
total += stmts
count = atomic.LoadUint32(&counts[i]) // For -mode=atomic.
if count > 0 {
active += stmts
}
if f != nil {
_, err := fmt.Fprintf(f, "%s:%d.%d,%d.%d %d %d\n", name,
blocks[i].Line0, blocks[i].Col0,
blocks[i].Line1, blocks[i].Col1,
stmts,
count)
mustBeNil(err)
}
}
}
if total == 0 {
fmt.Println("coverage: [no statements]")
return
}
fmt.Printf("coverage: %.1f%% of statements%s\n", 100*float64(active)/float64(total), cover.CoveredPackages)
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package testing
import (
"fmt"
"os"
"sort"
"strings"
"time"
)
type InternalExample struct {
Name string
F func()
Output string
Unordered bool
}
// RunExamples is an internal function but exported because it is cross-package;
// it is part of the implementation of the "go test" command.
func RunExamples(matchString func(pat, str string) (bool, error), examples []InternalExample) (ok bool) {
_, ok = runExamples(matchString, examples)
return ok
}
func runExamples(matchString func(pat, str string) (bool, error), examples []InternalExample) (ran, ok bool) {
ok = true
var eg InternalExample
for _, eg = range examples {
matched, err := matchString(*match, eg.Name)
if err != nil {
fmt.Fprintf(os.Stderr, "testing: invalid regexp for -test.run: %s\n", err)
os.Exit(1)
}
if !matched {
continue
}
ran = true
if !runExample(eg) {
ok = false
}
}
return ran, ok
}
func sortLines(output string) string {
lines := strings.Split(output, "\n")
sort.Strings(lines)
return strings.Join(lines, "\n")
}
// processRunResult computes a summary and status of the result of running an example test.
// stdout is the captured output from stdout of the test.
// recovered is the result of invoking recover after running the test, in case it panicked.
//
// If stdout doesn't match the expected output or if recovered is non-nil, it'll print the cause of failure to stdout.
// If the test is chatty/verbose, it'll print a success message to stdout.
// If recovered is non-nil, it'll panic with that value.
// If the test panicked with nil, or invoked runtime.Goexit, it'll be
// made to fail and panic with errNilPanicOrGoexit.
func (eg *InternalExample) processRunResult(stdout string, timeSpent time.Duration, finished bool, recovered any) (passed bool) {
passed = true
dstr := fmtDuration(timeSpent)
var fail string
got := strings.TrimSpace(stdout)
want := strings.TrimSpace(eg.Output)
if eg.Unordered {
if sortLines(got) != sortLines(want) && recovered == nil {
fail = fmt.Sprintf("got:\n%s\nwant (unordered):\n%s\n", stdout, eg.Output)
}
} else {
if got != want && recovered == nil {
fail = fmt.Sprintf("got:\n%s\nwant:\n%s\n", got, want)
}
}
if fail != "" || !finished || recovered != nil {
fmt.Printf("%s--- FAIL: %s (%s)\n%s", chatty.prefix(), eg.Name, dstr, fail)
passed = false
} else if chatty.on {
fmt.Printf("%s--- PASS: %s (%s)\n", chatty.prefix(), eg.Name, dstr)
}
if chatty.on && chatty.json {
fmt.Printf("%s=== NAME %s\n", chatty.prefix(), "")
}
if recovered != nil {
// Propagate the previously recovered result, by panicking.
panic(recovered)
}
if !finished && recovered == nil {
panic(errNilPanicOrGoexit)
}
return
}
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fstest
import (
"io"
"io/fs"
"path"
"sort"
"strings"
"time"
)
// A MapFS is a simple in-memory file system for use in tests,
// represented as a map from path names (arguments to Open)
// to information about the files or directories they represent.
//
// The map need not include parent directories for files contained
// in the map; those will be synthesized if needed.
// But a directory can still be included by setting the MapFile.Mode's ModeDir bit;
// this may be necessary for detailed control over the directory's FileInfo
// or to create an empty directory.
//
// File system operations read directly from the map,
// so that the file system can be changed by editing the map as needed.
// An implication is that file system operations must not run concurrently
// with changes to the map, which would be a race.
// Another implication is that opening or reading a directory requires
// iterating over the entire map, so a MapFS should typically be used with not more
// than a few hundred entries or directory reads.
type MapFS map[string]*MapFile
// A MapFile describes a single file in a MapFS.
type MapFile struct {
Data []byte // file content
Mode fs.FileMode // FileInfo.Mode
ModTime time.Time // FileInfo.ModTime
Sys any // FileInfo.Sys
}
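// Illustrative construction of a MapFS (a sketch): the parent directory
// "dir" is not in the map and is synthesized on demand.
//
//	fsys := fstest.MapFS{
//		"hello.txt":     {Data: []byte("hello, world\n")},
//		"dir/notes.txt": {Data: []byte("notes\n")},
//	}
//	entries, err := fs.ReadDir(fsys, "dir") // one entry: notes.txt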
var _ fs.FS = MapFS(nil)
var _ fs.File = (*openMapFile)(nil)
// Open opens the named file.
func (fsys MapFS) Open(name string) (fs.File, error) {
if !fs.ValidPath(name) {
return nil, &fs.PathError{Op: "open", Path: name, Err: fs.ErrNotExist}
}
file := fsys[name]
if file != nil && file.Mode&fs.ModeDir == 0 {
// Ordinary file
return &openMapFile{name, mapFileInfo{path.Base(name), file}, 0}, nil
}
// Directory, possibly synthesized.
// Note that file can be nil here: the map need not contain explicit parent directories for all its files.
// But file can also be non-nil, in case the user wants to set metadata for the directory explicitly.
// Either way, we need to construct the list of children of this directory.
var list []mapFileInfo
var elem string
var need = make(map[string]bool)
if name == "." {
elem = "."
for fname, f := range fsys {
i := strings.Index(fname, "/")
if i < 0 {
if fname != "." {
list = append(list, mapFileInfo{fname, f})
}
} else {
need[fname[:i]] = true
}
}
} else {
elem = name[strings.LastIndex(name, "/")+1:]
prefix := name + "/"
for fname, f := range fsys {
if strings.HasPrefix(fname, prefix) {
felem := fname[len(prefix):]
i := strings.Index(felem, "/")
if i < 0 {
list = append(list, mapFileInfo{felem, f})
} else {
need[fname[len(prefix):len(prefix)+i]] = true
}
}
}
// If the directory name is not in the map,
// and there are no children of the name in the map,
// then the directory is treated as not existing.
if file == nil && list == nil && len(need) == 0 {
return nil, &fs.PathError{Op: "open", Path: name, Err: fs.ErrNotExist}
}
}
for _, fi := range list {
delete(need, fi.name)
}
for name := range need {
list = append(list, mapFileInfo{name, &MapFile{Mode: fs.ModeDir}})
}
sort.Slice(list, func(i, j int) bool {
return list[i].name < list[j].name
})
if file == nil {
file = &MapFile{Mode: fs.ModeDir}
}
return &mapDir{name, mapFileInfo{elem, file}, list, 0}, nil
}
// fsOnly is a wrapper that hides all but the fs.FS methods,
// to avoid an infinite recursion when implementing special
// methods in terms of helpers that would use them.
// (In general, implementing these methods using the package fs helpers
// is redundant and unnecessary, but having the methods may make
// MapFS exercise more code paths when used in tests.)
type fsOnly struct{ fs.FS }
func (fsys MapFS) ReadFile(name string) ([]byte, error) {
return fs.ReadFile(fsOnly{fsys}, name)
}
func (fsys MapFS) Stat(name string) (fs.FileInfo, error) {
return fs.Stat(fsOnly{fsys}, name)
}
func (fsys MapFS) ReadDir(name string) ([]fs.DirEntry, error) {
return fs.ReadDir(fsOnly{fsys}, name)
}
func (fsys MapFS) Glob(pattern string) ([]string, error) {
return fs.Glob(fsOnly{fsys}, pattern)
}
type noSub struct {
MapFS
}
func (noSub) Sub() {} // not the fs.SubFS signature
func (fsys MapFS) Sub(dir string) (fs.FS, error) {
return fs.Sub(noSub{fsys}, dir)
}
// A mapFileInfo implements fs.FileInfo and fs.DirEntry for a given map file.
type mapFileInfo struct {
name string
f *MapFile
}
func (i *mapFileInfo) Name() string { return i.name }
func (i *mapFileInfo) Size() int64 { return int64(len(i.f.Data)) }
func (i *mapFileInfo) Mode() fs.FileMode { return i.f.Mode }
func (i *mapFileInfo) Type() fs.FileMode { return i.f.Mode.Type() }
func (i *mapFileInfo) ModTime() time.Time { return i.f.ModTime }
func (i *mapFileInfo) IsDir() bool { return i.f.Mode&fs.ModeDir != 0 }
func (i *mapFileInfo) Sys() any { return i.f.Sys }
func (i *mapFileInfo) Info() (fs.FileInfo, error) { return i, nil }
// An openMapFile is a regular (non-directory) fs.File open for reading.
type openMapFile struct {
path string
mapFileInfo
offset int64
}
func (f *openMapFile) Stat() (fs.FileInfo, error) { return &f.mapFileInfo, nil }
func (f *openMapFile) Close() error { return nil }
func (f *openMapFile) Read(b []byte) (int, error) {
if f.offset >= int64(len(f.f.Data)) {
return 0, io.EOF
}
if f.offset < 0 {
return 0, &fs.PathError{Op: "read", Path: f.path, Err: fs.ErrInvalid}
}
n := copy(b, f.f.Data[f.offset:])
f.offset += int64(n)
return n, nil
}
func (f *openMapFile) Seek(offset int64, whence int) (int64, error) {
switch whence {
case io.SeekStart:
// offset is already absolute
case io.SeekCurrent:
offset += f.offset
case io.SeekEnd:
offset += int64(len(f.f.Data))
}
if offset < 0 || offset > int64(len(f.f.Data)) {
return 0, &fs.PathError{Op: "seek", Path: f.path, Err: fs.ErrInvalid}
}
f.offset = offset
return offset, nil
}
func (f *openMapFile) ReadAt(b []byte, offset int64) (int, error) {
if offset < 0 || offset > int64(len(f.f.Data)) {
return 0, &fs.PathError{Op: "read", Path: f.path, Err: fs.ErrInvalid}
}
n := copy(b, f.f.Data[offset:])
if n < len(b) {
return n, io.EOF
}
return n, nil
}
// A mapDir is a directory fs.File (so also an fs.ReadDirFile) open for reading.
type mapDir struct {
path string
mapFileInfo
entry []mapFileInfo
offset int
}
func (d *mapDir) Stat() (fs.FileInfo, error) { return &d.mapFileInfo, nil }
func (d *mapDir) Close() error { return nil }
func (d *mapDir) Read(b []byte) (int, error) {
return 0, &fs.PathError{Op: "read", Path: d.path, Err: fs.ErrInvalid}
}
func (d *mapDir) ReadDir(count int) ([]fs.DirEntry, error) {
n := len(d.entry) - d.offset
if n == 0 && count > 0 {
return nil, io.EOF
}
if count > 0 && n > count {
n = count
}
list := make([]fs.DirEntry, n)
for i := range list {
list[i] = &d.entry[d.offset+i]
}
d.offset += n
return list, nil
}
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package fstest implements support for testing implementations and users of file systems.
package fstest
import (
"errors"
"fmt"
"io"
"io/fs"
"path"
"reflect"
"sort"
"strings"
"testing/iotest"
)
// TestFS tests a file system implementation.
// It walks the entire tree of files in fsys,
// opening and checking that each file behaves correctly.
// It also checks that the file system contains at least the expected files.
// As a special case, if no expected files are listed, fsys must be empty.
// Otherwise, fsys must contain at least the listed files; it can also contain others.
// The contents of fsys must not change concurrently with TestFS.
//
// If TestFS finds any misbehaviors, it returns an error reporting all of them.
// The error text spans multiple lines, one per detected misbehavior.
//
// Typical usage inside a test is:
//
// if err := fstest.TestFS(myFS, "file/that/should/be/present"); err != nil {
// t.Fatal(err)
// }
func TestFS(fsys fs.FS, expected ...string) error {
if err := testFS(fsys, expected...); err != nil {
return err
}
for _, name := range expected {
if i := strings.Index(name, "/"); i >= 0 {
dir, dirSlash := name[:i], name[:i+1]
var subExpected []string
for _, name := range expected {
if strings.HasPrefix(name, dirSlash) {
subExpected = append(subExpected, name[len(dirSlash):])
}
}
sub, err := fs.Sub(fsys, dir)
if err != nil {
return err
}
if err := testFS(sub, subExpected...); err != nil {
return fmt.Errorf("testing fs.Sub(fsys, %s): %v", dir, err)
}
break // one sub-test is enough
}
}
return nil
}
func testFS(fsys fs.FS, expected ...string) error {
t := fsTester{fsys: fsys}
t.checkDir(".")
t.checkOpen(".")
found := make(map[string]bool)
for _, dir := range t.dirs {
found[dir] = true
}
for _, file := range t.files {
found[file] = true
}
delete(found, ".")
if len(expected) == 0 && len(found) > 0 {
var list []string
for k := range found {
if k != "." {
list = append(list, k)
}
}
sort.Strings(list)
if len(list) > 15 {
list = append(list[:10], "...")
}
t.errorf("expected empty file system but found files:\n%s", strings.Join(list, "\n"))
}
for _, name := range expected {
if !found[name] {
t.errorf("expected but not found: %s", name)
}
}
if len(t.errText) == 0 {
return nil
}
return errors.New("TestFS found errors:\n" + string(t.errText))
}
// An fsTester holds state for running the test.
type fsTester struct {
fsys fs.FS
errText []byte
dirs []string
files []string
}
// errorf adds an error line to errText.
func (t *fsTester) errorf(format string, args ...any) {
if len(t.errText) > 0 {
t.errText = append(t.errText, '\n')
}
t.errText = append(t.errText, fmt.Sprintf(format, args...)...)
}
func (t *fsTester) openDir(dir string) fs.ReadDirFile {
f, err := t.fsys.Open(dir)
if err != nil {
t.errorf("%s: Open: %v", dir, err)
return nil
}
d, ok := f.(fs.ReadDirFile)
if !ok {
f.Close()
t.errorf("%s: Open returned File type %T, not a fs.ReadDirFile", dir, f)
return nil
}
return d
}
// checkDir checks the directory dir, which is expected to exist
// (it is either the root or was found in a directory listing with IsDir true).
func (t *fsTester) checkDir(dir string) {
// Read entire directory.
t.dirs = append(t.dirs, dir)
d := t.openDir(dir)
if d == nil {
return
}
list, err := d.ReadDir(-1)
if err != nil {
d.Close()
t.errorf("%s: ReadDir(-1): %v", dir, err)
return
}
// Check all children.
var prefix string
if dir == "." {
prefix = ""
} else {
prefix = dir + "/"
}
for _, info := range list {
name := info.Name()
switch {
case name == ".", name == "..", name == "":
t.errorf("%s: ReadDir: child has invalid name: %#q", dir, name)
continue
case strings.Contains(name, "/"):
t.errorf("%s: ReadDir: child name contains slash: %#q", dir, name)
continue
case strings.Contains(name, `\`):
t.errorf("%s: ReadDir: child name contains backslash: %#q", dir, name)
continue
}
path := prefix + name
t.checkStat(path, info)
t.checkOpen(path)
if info.IsDir() {
t.checkDir(path)
} else {
t.checkFile(path)
}
}
// Check ReadDir(-1) at EOF.
list2, err := d.ReadDir(-1)
if len(list2) > 0 || err != nil {
d.Close()
t.errorf("%s: ReadDir(-1) at EOF = %d entries, %v, wanted 0 entries, nil", dir, len(list2), err)
return
}
// Check ReadDir(1) at EOF (different results).
list2, err = d.ReadDir(1)
if len(list2) > 0 || err != io.EOF {
d.Close()
t.errorf("%s: ReadDir(1) at EOF = %d entries, %v, wanted 0 entries, EOF", dir, len(list2), err)
return
}
// Check that close does not report an error.
if err := d.Close(); err != nil {
t.errorf("%s: Close: %v", dir, err)
}
// Check that closing twice doesn't crash.
// The return value doesn't matter.
d.Close()
// Reopen directory, read a second time, make sure contents match.
if d = t.openDir(dir); d == nil {
return
}
defer d.Close()
list2, err = d.ReadDir(-1)
if err != nil {
t.errorf("%s: second Open+ReadDir(-1): %v", dir, err)
return
}
t.checkDirList(dir, "first Open+ReadDir(-1) vs second Open+ReadDir(-1)", list, list2)
// Reopen directory, read a third time in pieces, make sure contents match.
if d = t.openDir(dir); d == nil {
return
}
defer d.Close()
list2 = nil
for {
n := 1
if len(list2) > 0 {
n = 2
}
frag, err := d.ReadDir(n)
if len(frag) > n {
t.errorf("%s: third Open: ReadDir(%d) after %d: %d entries (too many)", dir, n, len(list2), len(frag))
return
}
list2 = append(list2, frag...)
if err == io.EOF {
break
}
if err != nil {
t.errorf("%s: third Open: ReadDir(%d) after %d: %v", dir, n, len(list2), err)
return
}
if len(frag) == 0 {
t.errorf("%s: third Open: ReadDir(%d) after %d: 0 entries but nil error", dir, n, len(list2))
return
}
}
t.checkDirList(dir, "first Open+ReadDir(-1) vs third Open+ReadDir(1,2) loop", list, list2)
// If fsys has ReadDir, check that it matches and is sorted.
if fsys, ok := t.fsys.(fs.ReadDirFS); ok {
list2, err := fsys.ReadDir(dir)
if err != nil {
t.errorf("%s: fsys.ReadDir: %v", dir, err)
return
}
t.checkDirList(dir, "first Open+ReadDir(-1) vs fsys.ReadDir", list, list2)
for i := 0; i+1 < len(list2); i++ {
if list2[i].Name() >= list2[i+1].Name() {
t.errorf("%s: fsys.ReadDir: list not sorted: %s before %s", dir, list2[i].Name(), list2[i+1].Name())
}
}
}
// Check fs.ReadDir as well.
list2, err = fs.ReadDir(t.fsys, dir)
if err != nil {
t.errorf("%s: fs.ReadDir: %v", dir, err)
return
}
t.checkDirList(dir, "first Open+ReadDir(-1) vs fs.ReadDir", list, list2)
for i := 0; i+1 < len(list2); i++ {
if list2[i].Name() >= list2[i+1].Name() {
t.errorf("%s: fs.ReadDir: list not sorted: %s before %s", dir, list2[i].Name(), list2[i+1].Name())
}
}
t.checkGlob(dir, list)
}
// formatEntry formats an fs.DirEntry into a string for error messages and comparison.
func formatEntry(entry fs.DirEntry) string {
return fmt.Sprintf("%s IsDir=%v Type=%v", entry.Name(), entry.IsDir(), entry.Type())
}
// formatInfoEntry formats an fs.FileInfo into a string like the result of formatEntry, for error messages and comparison.
func formatInfoEntry(info fs.FileInfo) string {
return fmt.Sprintf("%s IsDir=%v Type=%v", info.Name(), info.IsDir(), info.Mode().Type())
}
// formatInfo formats an fs.FileInfo into a string for error messages and comparison.
func formatInfo(info fs.FileInfo) string {
return fmt.Sprintf("%s IsDir=%v Mode=%v Size=%d ModTime=%v", info.Name(), info.IsDir(), info.Mode(), info.Size(), info.ModTime())
}
// checkGlob checks that various glob patterns work if the file system implements GlobFS.
func (t *fsTester) checkGlob(dir string, list []fs.DirEntry) {
if _, ok := t.fsys.(fs.GlobFS); !ok {
return
}
// Make a complex glob pattern prefix that only matches dir.
var glob string
if dir != "." {
elem := strings.Split(dir, "/")
for i, e := range elem {
var pattern []rune
for j, r := range e {
if r == '*' || r == '?' || r == '\\' || r == '[' || r == '-' {
pattern = append(pattern, '\\', r)
continue
}
switch (i + j) % 5 {
case 0:
pattern = append(pattern, r)
case 1:
pattern = append(pattern, '[', r, ']')
case 2:
pattern = append(pattern, '[', r, '-', r, ']')
case 3:
pattern = append(pattern, '[', '\\', r, ']')
case 4:
pattern = append(pattern, '[', '\\', r, '-', '\\', r, ']')
}
}
elem[i] = string(pattern)
}
glob = strings.Join(elem, "/") + "/"
}
// Test that malformed patterns are detected.
// The error is likely path.ErrBadPattern but need not be.
if _, err := t.fsys.(fs.GlobFS).Glob(glob + "nonexist/[]"); err == nil {
t.errorf("%s: Glob(%#q): bad pattern not detected", dir, glob+"nonexist/[]")
}
// Try to find a letter that appears in only some of the final names.
c := rune('a')
for ; c <= 'z'; c++ {
have, haveNot := false, false
for _, d := range list {
if strings.ContainsRune(d.Name(), c) {
have = true
} else {
haveNot = true
}
}
if have && haveNot {
break
}
}
if c > 'z' {
c = 'a'
}
glob += "*" + string(c) + "*"
var want []string
for _, d := range list {
if strings.ContainsRune(d.Name(), c) {
want = append(want, path.Join(dir, d.Name()))
}
}
names, err := t.fsys.(fs.GlobFS).Glob(glob)
if err != nil {
t.errorf("%s: Glob(%#q): %v", dir, glob, err)
return
}
if reflect.DeepEqual(want, names) {
return
}
if !sort.StringsAreSorted(names) {
t.errorf("%s: Glob(%#q): unsorted output:\n%s", dir, glob, strings.Join(names, "\n"))
sort.Strings(names)
}
var problems []string
for len(want) > 0 || len(names) > 0 {
switch {
case len(want) > 0 && len(names) > 0 && want[0] == names[0]:
want, names = want[1:], names[1:]
case len(want) > 0 && (len(names) == 0 || want[0] < names[0]):
problems = append(problems, "missing: "+want[0])
want = want[1:]
default:
problems = append(problems, "extra: "+names[0])
names = names[1:]
}
}
t.errorf("%s: Glob(%#q): wrong output:\n%s", dir, glob, strings.Join(problems, "\n"))
}
// checkStat checks that a direct stat of path matches entry,
// which was found in the parent's directory listing.
func (t *fsTester) checkStat(path string, entry fs.DirEntry) {
file, err := t.fsys.Open(path)
if err != nil {
t.errorf("%s: Open: %v", path, err)
return
}
info, err := file.Stat()
file.Close()
if err != nil {
t.errorf("%s: Stat: %v", path, err)
return
}
fentry := formatEntry(entry)
fientry := formatInfoEntry(info)
// Note: mismatch here is OK for symlink, because Open dereferences symlink.
if fentry != fientry && entry.Type()&fs.ModeSymlink == 0 {
t.errorf("%s: mismatch:\n\tentry = %s\n\tfile.Stat() = %s", path, fentry, fientry)
}
einfo, err := entry.Info()
if err != nil {
t.errorf("%s: entry.Info: %v", path, err)
return
}
finfo := formatInfo(info)
if entry.Type()&fs.ModeSymlink != 0 {
// For symlink, just check that entry.Info matches entry on common fields.
// Open dereferences the symlink, so info itself may differ.
feentry := formatInfoEntry(einfo)
if fentry != feentry {
t.errorf("%s: mismatch\n\tentry = %s\n\tentry.Info() = %s\n", path, fentry, feentry)
}
} else {
feinfo := formatInfo(einfo)
if feinfo != finfo {
t.errorf("%s: mismatch:\n\tentry.Info() = %s\n\tfile.Stat() = %s\n", path, feinfo, finfo)
}
}
// Stat should be the same as Open+Stat, even for symlinks.
info2, err := fs.Stat(t.fsys, path)
if err != nil {
t.errorf("%s: fs.Stat: %v", path, err)
return
}
finfo2 := formatInfo(info2)
if finfo2 != finfo {
t.errorf("%s: fs.Stat(...) = %s\n\twant %s", path, finfo2, finfo)
}
if fsys, ok := t.fsys.(fs.StatFS); ok {
info2, err := fsys.Stat(path)
if err != nil {
t.errorf("%s: fsys.Stat: %v", path, err)
return
}
finfo2 := formatInfo(info2)
if finfo2 != finfo {
t.errorf("%s: fsys.Stat(...) = %s\n\twant %s", path, finfo2, finfo)
}
}
}
// checkDirList checks that two directory lists contain the same files and file info.
// The order of the lists need not match.
func (t *fsTester) checkDirList(dir, desc string, list1, list2 []fs.DirEntry) {
old := make(map[string]fs.DirEntry)
checkMode := func(entry fs.DirEntry) {
if entry.IsDir() != (entry.Type()&fs.ModeDir != 0) {
if entry.IsDir() {
t.errorf("%s: ReadDir returned %s with IsDir() = true, Type() & ModeDir = 0", dir, entry.Name())
} else {
t.errorf("%s: ReadDir returned %s with IsDir() = false, Type() & ModeDir = ModeDir", dir, entry.Name())
}
}
}
for _, entry1 := range list1 {
old[entry1.Name()] = entry1
checkMode(entry1)
}
var diffs []string
for _, entry2 := range list2 {
entry1 := old[entry2.Name()]
if entry1 == nil {
checkMode(entry2)
diffs = append(diffs, "+ "+formatEntry(entry2))
continue
}
if formatEntry(entry1) != formatEntry(entry2) {
diffs = append(diffs, "- "+formatEntry(entry1), "+ "+formatEntry(entry2))
}
delete(old, entry2.Name())
}
for _, entry1 := range old {
diffs = append(diffs, "- "+formatEntry(entry1))
}
if len(diffs) == 0 {
return
}
sort.Slice(diffs, func(i, j int) bool {
fi := strings.Fields(diffs[i])
fj := strings.Fields(diffs[j])
// sort by name (i < j) and then +/- (j < i, because + < -)
return fi[1]+" "+fj[0] < fj[1]+" "+fi[0]
})
t.errorf("%s: diff %s:\n\t%s", dir, desc, strings.Join(diffs, "\n\t"))
}
// checkFile checks that basic file reading works correctly.
func (t *fsTester) checkFile(file string) {
t.files = append(t.files, file)
// Read entire file.
f, err := t.fsys.Open(file)
if err != nil {
t.errorf("%s: Open: %v", file, err)
return
}
data, err := io.ReadAll(f)
if err != nil {
f.Close()
t.errorf("%s: Open+ReadAll: %v", file, err)
return
}
if err := f.Close(); err != nil {
t.errorf("%s: Close: %v", file, err)
}
// Check that closing twice doesn't crash.
// The return value doesn't matter.
f.Close()
// Check that ReadFile works if present.
if fsys, ok := t.fsys.(fs.ReadFileFS); ok {
data2, err := fsys.ReadFile(file)
if err != nil {
t.errorf("%s: fsys.ReadFile: %v", file, err)
return
}
t.checkFileRead(file, "ReadAll vs fsys.ReadFile", data, data2)
// Modify the data and check it again. Modifying the
// returned byte slice should not affect the next call.
for i := range data2 {
data2[i]++
}
data2, err = fsys.ReadFile(file)
if err != nil {
t.errorf("%s: second call to fsys.ReadFile: %v", file, err)
return
}
t.checkFileRead(file, "Readall vs second fsys.ReadFile", data, data2)
t.checkBadPath(file, "ReadFile",
func(name string) error { _, err := fsys.ReadFile(name); return err })
}
// Check that fs.ReadFile works with t.fsys.
data2, err := fs.ReadFile(t.fsys, file)
if err != nil {
t.errorf("%s: fs.ReadFile: %v", file, err)
return
}
t.checkFileRead(file, "ReadAll vs fs.ReadFile", data, data2)
// Use iotest.TestReader to check small reads, Seek, ReadAt.
f, err = t.fsys.Open(file)
if err != nil {
t.errorf("%s: second Open: %v", file, err)
return
}
defer f.Close()
if err := iotest.TestReader(f, data); err != nil {
t.errorf("%s: failed TestReader:\n\t%s", file, strings.ReplaceAll(err.Error(), "\n", "\n\t"))
}
}
func (t *fsTester) checkFileRead(file, desc string, data1, data2 []byte) {
if string(data1) != string(data2) {
t.errorf("%s: %s: different data returned\n\t%q\n\t%q", file, desc, data1, data2)
return
}
}
// checkOpen checks that various invalid forms of file's name cannot be opened using t.fsys.Open.
func (t *fsTester) checkOpen(file string) {
t.checkBadPath(file, "Open", func(file string) error {
f, err := t.fsys.Open(file)
if err == nil {
f.Close()
}
return err
})
}
// checkBadPath checks that various invalid forms of file's name cannot be opened using open.
func (t *fsTester) checkBadPath(file string, desc string, open func(string) error) {
bad := []string{
"/" + file,
file + "/.",
}
if file == "." {
bad = append(bad, "/")
}
if i := strings.Index(file, "/"); i >= 0 {
bad = append(bad,
file[:i]+"//"+file[i+1:],
file[:i]+"/./"+file[i+1:],
file[:i]+`\`+file[i+1:],
file[:i]+"/../"+file,
)
}
if i := strings.LastIndex(file, "/"); i >= 0 {
bad = append(bad,
file[:i]+"//"+file[i+1:],
file[:i]+"/./"+file[i+1:],
file[:i]+`\`+file[i+1:],
file+"/../"+file[i+1:],
)
}
for _, b := range bad {
if err := open(b); err == nil {
t.errorf("%s: %s(%s) succeeded, want error", file, desc, b)
}
}
}
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package testing
import (
"errors"
"flag"
"fmt"
"io"
"os"
"path/filepath"
"reflect"
"runtime"
"strings"
"time"
)
func initFuzzFlags() {
matchFuzz = flag.String("test.fuzz", "", "run the fuzz test matching `regexp`")
flag.Var(&fuzzDuration, "test.fuzztime", "time to spend fuzzing; default is to run indefinitely")
flag.Var(&minimizeDuration, "test.fuzzminimizetime", "time to spend minimizing a value after finding a failing input")
fuzzCacheDir = flag.String("test.fuzzcachedir", "", "directory where interesting fuzzing inputs are stored (for use only by cmd/go)")
isFuzzWorker = flag.Bool("test.fuzzworker", false, "coordinate with the parent process to fuzz random values (for use only by cmd/go)")
}
var (
matchFuzz *string
fuzzDuration durationOrCountFlag
minimizeDuration = durationOrCountFlag{d: 60 * time.Second, allowZero: true}
fuzzCacheDir *string
isFuzzWorker *bool
// corpusDir is the parent directory of the fuzz test's seed corpus within
// the package.
corpusDir = "testdata/fuzz"
)
// fuzzWorkerExitCode is used as an exit code by fuzz worker processes after an
// internal error. This distinguishes internal errors from uncontrolled panics
// and other failures. Keep in sync with internal/fuzz.workerExitCode.
const fuzzWorkerExitCode = 70
// InternalFuzzTarget is an internal type but exported because it is
// cross-package; it is part of the implementation of the "go test" command.
type InternalFuzzTarget struct {
Name string
Fn func(f *F)
}
// F is a type passed to fuzz tests.
//
// Fuzz tests run generated inputs against a provided fuzz target, which can
// find and report potential bugs in the code being tested.
//
// A fuzz test runs the seed corpus by default, which includes entries provided
// by (*F).Add and entries in the testdata/fuzz/<FuzzTestName> directory. After
// any necessary setup and calls to (*F).Add, the fuzz test must then call
// (*F).Fuzz to provide the fuzz target. See the testing package documentation
// for an example, and see the F.Fuzz and F.Add method documentation for
// details.
//
// *F methods can only be called before (*F).Fuzz. Once the test is
// executing the fuzz target, only (*T) methods can be used. The only *F methods
// that are allowed in the (*F).Fuzz function are (*F).Failed and (*F).Name.
type F struct {
common
fuzzContext *fuzzContext
testContext *testContext
// inFuzzFn is true when the fuzz function is running. Most F methods cannot
// be called when inFuzzFn is true.
inFuzzFn bool
// corpus is a set of seed corpus entries, added with F.Add and loaded
// from testdata.
corpus []corpusEntry
result fuzzResult
fuzzCalled bool
}
var _ TB = (*F)(nil)
// corpusEntry is an alias to the same type as internal/fuzz.CorpusEntry.
// We use a type alias because we don't want to export this type, and we can't
// import internal/fuzz from testing.
type corpusEntry = struct {
Parent string
Path string
Data []byte
Values []any
Generation int
IsSeed bool
}
// Helper marks the calling function as a test helper function.
// When printing file and line information, that function will be skipped.
// Helper may be called simultaneously from multiple goroutines.
func (f *F) Helper() {
if f.inFuzzFn {
panic("testing: f.Helper was called inside the fuzz target, use t.Helper instead")
}
// common.Helper is inlined here.
// If we called it, it would mark F.Helper as the helper
// instead of the caller.
f.mu.Lock()
defer f.mu.Unlock()
if f.helperPCs == nil {
f.helperPCs = make(map[uintptr]struct{})
}
// repeating code from callerName here to save walking a stack frame
var pc [1]uintptr
n := runtime.Callers(2, pc[:]) // skip runtime.Callers + Helper
if n == 0 {
panic("testing: zero callers found")
}
if _, found := f.helperPCs[pc[0]]; !found {
f.helperPCs[pc[0]] = struct{}{}
f.helperNames = nil // map will be recreated next time it is needed
}
}
// Fail marks the function as having failed but continues execution.
func (f *F) Fail() {
// (*F).Fail may be called by (*T).Fail, which we should allow. However, we
// shouldn't allow direct (*F).Fail calls from inside the (*F).Fuzz function.
if f.inFuzzFn {
panic("testing: f.Fail was called inside the fuzz target, use t.Fail instead")
}
f.common.Helper()
f.common.Fail()
}
// Skipped reports whether the test was skipped.
func (f *F) Skipped() bool {
// (*F).Skipped may be called by tRunner, which we should allow. However, we
// shouldn't allow direct (*F).Skipped calls from inside the (*F).Fuzz function.
if f.inFuzzFn {
panic("testing: f.Skipped was called inside the fuzz target, use t.Skipped instead")
}
f.common.Helper()
return f.common.Skipped()
}
// Add will add the arguments to the seed corpus for the fuzz test. This will be
// a no-op if called after or within the fuzz target, and args must match the
// arguments for the fuzz target.
func (f *F) Add(args ...any) {
var values []any
for i := range args {
if t := reflect.TypeOf(args[i]); !supportedTypes[t] {
panic(fmt.Sprintf("testing: unsupported type to Add %v", t))
}
values = append(values, args[i])
}
f.corpus = append(f.corpus, corpusEntry{Values: values, IsSeed: true, Path: fmt.Sprintf("seed#%d", len(f.corpus))})
}
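// A minimal usage sketch (FuzzDecode and its argument types are hypothetical;
// the values passed to Add must match the fuzz target's parameters in order
// and type):
//
//	func FuzzDecode(f *testing.F) {
//		f.Add([]byte("seed"), int64(42))
//		f.Fuzz(func(t *testing.T, b []byte, n int64) {
//			// ...
//		})
//	}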
// supportedTypes represents all of the supported types which can be fuzzed.
var supportedTypes = map[reflect.Type]bool{
reflect.TypeOf(([]byte)("")): true,
reflect.TypeOf((string)("")): true,
reflect.TypeOf((bool)(false)): true,
reflect.TypeOf((byte)(0)): true,
reflect.TypeOf((rune)(0)): true,
reflect.TypeOf((float32)(0)): true,
reflect.TypeOf((float64)(0)): true,
reflect.TypeOf((int)(0)): true,
reflect.TypeOf((int8)(0)): true,
reflect.TypeOf((int16)(0)): true,
reflect.TypeOf((int32)(0)): true,
reflect.TypeOf((int64)(0)): true,
reflect.TypeOf((uint)(0)): true,
reflect.TypeOf((uint8)(0)): true,
reflect.TypeOf((uint16)(0)): true,
reflect.TypeOf((uint32)(0)): true,
reflect.TypeOf((uint64)(0)): true,
}
// Fuzz runs the fuzz function, ff, for fuzz testing. If ff fails for a set of
// arguments, those arguments will be added to the seed corpus.
//
// ff must be a function with no return value whose first argument is *T and
// whose remaining arguments are the types to be fuzzed.
// For example:
//
// f.Fuzz(func(t *testing.T, b []byte, i int) { ... })
//
// The following types are allowed: []byte, string, bool, byte, rune, float32,
// float64, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64.
// More types may be supported in the future.
//
// ff must not call any *F methods, e.g. (*F).Log, (*F).Error, (*F).Skip. Use
// the corresponding *T method instead. The only *F methods that are allowed in
// the (*F).Fuzz function are (*F).Failed and (*F).Name.
//
// This function should be fast and deterministic, and its behavior should not
// depend on shared state. No mutable input arguments, or pointers to them,
// should be retained between executions of the fuzz function, as the memory
// backing them may be mutated during a subsequent invocation. ff must not
// modify the underlying data of the arguments provided by the fuzzing engine.
//
// When fuzzing, F.Fuzz does not return until a problem is found, time runs out
// (set with -fuzztime), or the test process is interrupted by a signal. F.Fuzz
// should be called exactly once, unless F.Skip or F.Fail is called beforehand.
func (f *F) Fuzz(ff any) {
if f.fuzzCalled {
panic("testing: F.Fuzz called more than once")
}
f.fuzzCalled = true
if f.failed {
return
}
f.Helper()
// ff should be in the form func(*testing.T, ...interface{})
fn := reflect.ValueOf(ff)
fnType := fn.Type()
if fnType.Kind() != reflect.Func {
panic("testing: F.Fuzz must receive a function")
}
if fnType.NumIn() < 2 || fnType.In(0) != reflect.TypeOf((*T)(nil)) {
panic("testing: fuzz target must receive at least two arguments, where the first argument is a *T")
}
if fnType.NumOut() != 0 {
panic("testing: fuzz target must not return a value")
}
// Save the types of the function to compare against the corpus.
var types []reflect.Type
for i := 1; i < fnType.NumIn(); i++ {
t := fnType.In(i)
if !supportedTypes[t] {
panic(fmt.Sprintf("testing: unsupported type for fuzzing %v", t))
}
types = append(types, t)
}
// Load the testdata seed corpus. Check types of entries in the testdata
// corpus and entries declared with F.Add.
//
// Don't load the seed corpus if this is a worker process; we won't use it.
if f.fuzzContext.mode != fuzzWorker {
for _, c := range f.corpus {
if err := f.fuzzContext.deps.CheckCorpus(c.Values, types); err != nil {
// TODO(#48302): Report the source location of the F.Add call.
f.Fatal(err)
}
}
// Load seed corpus
c, err := f.fuzzContext.deps.ReadCorpus(filepath.Join(corpusDir, f.name), types)
if err != nil {
f.Fatal(err)
}
for i := range c {
c[i].IsSeed = true // these are all seed corpus values
if f.fuzzContext.mode == fuzzCoordinator {
// If this is the coordinator process, zero the values, since we don't need
// to hold onto them.
c[i].Values = nil
}
}
f.corpus = append(f.corpus, c...)
}
// run calls fn on a given input, as a subtest with its own T.
// run is analogous to T.Run. Test filtering and cleanup work similarly.
// fn is called in its own goroutine.
run := func(captureOut io.Writer, e corpusEntry) (ok bool) {
if e.Values == nil {
// The corpusEntry must have non-nil Values in order to run the
// test. If Values is nil, it is a bug in our code.
panic(fmt.Sprintf("corpus file %q was not unmarshaled", e.Path))
}
if shouldFailFast() {
return true
}
testName := f.name
if e.Path != "" {
testName = fmt.Sprintf("%s/%s", testName, filepath.Base(e.Path))
}
if f.testContext.isFuzzing {
// Don't preserve subtest names while fuzzing. If fn calls T.Run,
// there will be a very large number of subtests with duplicate names,
// which will use a large amount of memory. The subtest names aren't
// useful since there's no way to re-run them deterministically.
f.testContext.match.clearSubNames()
}
// Record the stack trace at the point of this call so that if the subtest
// function - which runs in a separate stack - is marked as a helper, we can
// continue walking the stack into the parent test.
var pc [maxStackLen]uintptr
n := runtime.Callers(2, pc[:])
t := &T{
common: common{
barrier: make(chan bool),
signal: make(chan bool),
name: testName,
parent: &f.common,
level: f.level + 1,
creator: pc[:n],
chatty: f.chatty,
},
context: f.testContext,
}
if captureOut != nil {
// t.parent aliases f.common.
t.parent.w = captureOut
}
t.w = indenter{&t.common}
if t.chatty != nil {
t.chatty.Updatef(t.name, "=== RUN %s\n", t.name)
}
f.common.inFuzzFn, f.inFuzzFn = true, true
go tRunner(t, func(t *T) {
args := []reflect.Value{reflect.ValueOf(t)}
for _, v := range e.Values {
args = append(args, reflect.ValueOf(v))
}
// Before resetting the current coverage, defer the snapshot so that
// we make sure it is called right before the tRunner function
// exits, regardless of whether it was executed cleanly, panicked,
// or if the fuzzFn called t.Fatal.
if f.testContext.isFuzzing {
defer f.fuzzContext.deps.SnapshotCoverage()
f.fuzzContext.deps.ResetCoverage()
}
fn.Call(args)
})
<-t.signal
if t.chatty != nil && t.chatty.json {
t.chatty.Updatef(t.parent.name, "=== NAME %s\n", t.parent.name)
}
f.common.inFuzzFn, f.inFuzzFn = false, false
return !t.Failed()
}
switch f.fuzzContext.mode {
case fuzzCoordinator:
// Fuzzing is enabled, and this is the test process started by 'go test'.
// Act as the coordinator process, and coordinate workers to perform the
// actual fuzzing.
corpusTargetDir := filepath.Join(corpusDir, f.name)
cacheTargetDir := filepath.Join(*fuzzCacheDir, f.name)
err := f.fuzzContext.deps.CoordinateFuzzing(
fuzzDuration.d,
int64(fuzzDuration.n),
minimizeDuration.d,
int64(minimizeDuration.n),
*parallel,
f.corpus,
types,
corpusTargetDir,
cacheTargetDir)
if err != nil {
f.result = fuzzResult{Error: err}
f.Fail()
fmt.Fprintf(f.w, "%v\n", err)
if crashErr, ok := err.(fuzzCrashError); ok {
crashPath := crashErr.CrashPath()
fmt.Fprintf(f.w, "Failing input written to %s\n", crashPath)
testName := filepath.Base(crashPath)
fmt.Fprintf(f.w, "To re-run:\ngo test -run=%s/%s\n", f.name, testName)
}
}
// TODO(jayconrod,katiehockman): Aggregate statistics across workers
// and add to FuzzResult (i.e. time taken, number of iterations)
case fuzzWorker:
// Fuzzing is enabled, and this is a worker process. Follow instructions
// from the coordinator.
if err := f.fuzzContext.deps.RunFuzzWorker(func(e corpusEntry) error {
// Don't write to f.w (which points to Stdout) if running from a
// fuzz worker. This would become very verbose, particularly during
// minimization. Return the error instead, and let the caller deal
// with the output.
var buf strings.Builder
if ok := run(&buf, e); !ok {
return errors.New(buf.String())
}
return nil
}); err != nil {
// Internal errors are marked with f.Fail; user code may call this too, before F.Fuzz.
// The worker will exit with fuzzWorkerExitCode, indicating this is a failure
// (and 'go test' should exit non-zero) but a failing input should not be recorded.
f.Errorf("communicating with fuzzing coordinator: %v", err)
}
default:
// Fuzzing is not enabled, or will be done later. Only run the seed
// corpus now.
for _, e := range f.corpus {
name := fmt.Sprintf("%s/%s", f.name, filepath.Base(e.Path))
if _, ok, _ := f.testContext.match.fullName(nil, name); ok {
run(f.w, e)
}
}
}
}
func (f *F) report() {
if *isFuzzWorker || f.parent == nil {
return
}
dstr := fmtDuration(f.duration)
format := "--- %s: %s (%s)\n"
if f.Failed() {
f.flushToParent(f.name, format, "FAIL", f.name, dstr)
} else if f.chatty != nil {
if f.Skipped() {
f.flushToParent(f.name, format, "SKIP", f.name, dstr)
} else {
f.flushToParent(f.name, format, "PASS", f.name, dstr)
}
}
}
// fuzzResult contains the results of a fuzz run.
type fuzzResult struct {
N int // The number of iterations.
T time.Duration // The total time taken.
Error error // Error is the error from the failing input
}
func (r fuzzResult) String() string {
if r.Error == nil {
return ""
}
return r.Error.Error()
}
// fuzzCrashError is satisfied by a failing input detected while fuzzing.
// These errors are written to the seed corpus and can be re-run with 'go test'.
// Errors within the fuzzing framework (like I/O errors between coordinator
// and worker processes) don't satisfy this interface.
type fuzzCrashError interface {
error
Unwrap() error
// CrashPath returns the path of the subtest that corresponds to the saved
// crash input file in the seed corpus. The test can be re-run with go test
// -run=$test/$name, where $test is the fuzz test name, and $name is the
// filepath.Base of the string returned here.
CrashPath() string
}
// fuzzContext holds fields common to all fuzz tests.
type fuzzContext struct {
deps testDeps
mode fuzzMode
}
type fuzzMode uint8
const (
seedCorpusOnly fuzzMode = iota
fuzzCoordinator
fuzzWorker
)
// runFuzzTests runs the fuzz tests matching the pattern for -run. This will
// only run the (*F).Fuzz function against the seed corpus, without using the
// fuzzing engine to generate or mutate inputs.
func runFuzzTests(deps testDeps, fuzzTests []InternalFuzzTarget, deadline time.Time) (ran, ok bool) {
ok = true
if len(fuzzTests) == 0 || *isFuzzWorker {
return ran, ok
}
m := newMatcher(deps.MatchString, *match, "-test.run", *skip)
var mFuzz *matcher
if *matchFuzz != "" {
mFuzz = newMatcher(deps.MatchString, *matchFuzz, "-test.fuzz", *skip)
}
for _, procs := range cpuList {
runtime.GOMAXPROCS(procs)
for i := uint(0); i < *count; i++ {
if shouldFailFast() {
break
}
tctx := newTestContext(*parallel, m)
tctx.deadline = deadline
fctx := &fuzzContext{deps: deps, mode: seedCorpusOnly}
root := common{w: os.Stdout} // gather output in one place
if Verbose() {
root.chatty = newChattyPrinter(root.w)
}
for _, ft := range fuzzTests {
if shouldFailFast() {
break
}
testName, matched, _ := tctx.match.fullName(nil, ft.Name)
if !matched {
continue
}
if mFuzz != nil {
if _, fuzzMatched, _ := mFuzz.fullName(nil, ft.Name); fuzzMatched {
// If this will be fuzzed, then don't run the seed corpus
// right now. That will happen later.
continue
}
}
f := &F{
common: common{
signal: make(chan bool),
barrier: make(chan bool),
name: testName,
parent: &root,
level: root.level + 1,
chatty: root.chatty,
},
testContext: tctx,
fuzzContext: fctx,
}
f.w = indenter{&f.common}
if f.chatty != nil {
f.chatty.Updatef(f.name, "=== RUN %s\n", f.name)
}
go fRunner(f, ft.Fn)
<-f.signal
if f.chatty != nil && f.chatty.json {
f.chatty.Updatef(f.parent.name, "=== NAME %s\n", f.parent.name)
}
ok = ok && !f.Failed()
ran = ran || f.ran
}
if !ran {
// There were no tests to run on this iteration.
// This won't change, so no reason to keep trying.
break
}
}
}
return ran, ok
}
// runFuzzing runs the fuzz test matching the pattern for -fuzz. Exactly one
// fuzz test must match. This will run the fuzzing engine to generate and
// mutate new inputs against the fuzz target.
//
// If fuzzing is disabled (-test.fuzz is not set), runFuzzing
// returns immediately.
func runFuzzing(deps testDeps, fuzzTests []InternalFuzzTarget) (ok bool) {
if len(fuzzTests) == 0 || *matchFuzz == "" {
return true
}
m := newMatcher(deps.MatchString, *matchFuzz, "-test.fuzz", *skip)
tctx := newTestContext(1, m)
tctx.isFuzzing = true
fctx := &fuzzContext{
deps: deps,
}
root := common{w: os.Stdout}
if *isFuzzWorker {
root.w = io.Discard
fctx.mode = fuzzWorker
} else {
fctx.mode = fuzzCoordinator
}
if Verbose() && !*isFuzzWorker {
root.chatty = newChattyPrinter(root.w)
}
var fuzzTest *InternalFuzzTarget
var testName string
var matched []string
for i := range fuzzTests {
name, ok, _ := tctx.match.fullName(nil, fuzzTests[i].Name)
if !ok {
continue
}
matched = append(matched, name)
fuzzTest = &fuzzTests[i]
testName = name
}
if len(matched) == 0 {
fmt.Fprintln(os.Stderr, "testing: warning: no fuzz tests to fuzz")
return true
}
if len(matched) > 1 {
fmt.Fprintf(os.Stderr, "testing: will not fuzz, -fuzz matches more than one fuzz test: %v\n", matched)
return false
}
f := &F{
common: common{
signal: make(chan bool),
barrier: nil, // T.Parallel has no effect when fuzzing.
name: testName,
parent: &root,
level: root.level + 1,
chatty: root.chatty,
},
fuzzContext: fctx,
testContext: tctx,
}
f.w = indenter{&f.common}
if f.chatty != nil {
f.chatty.Updatef(f.name, "=== RUN %s\n", f.name)
}
go fRunner(f, fuzzTest.Fn)
<-f.signal
if f.chatty != nil {
f.chatty.Updatef(f.parent.name, "=== NAME %s\n", f.parent.name)
}
return !f.failed
}
// fRunner wraps a call to a fuzz test and ensures that cleanup functions are
// called and status flags are set. fRunner should be called in its own
// goroutine. To wait for its completion, receive from f.signal.
//
// fRunner is analogous to tRunner, which wraps subtests started with T.Run.
// Unit tests and fuzz tests work a little differently, so for now, these
// functions aren't consolidated. In particular, because there are no F.Run and
// F.Parallel methods, i.e., no fuzz sub-tests or parallel fuzz tests, a few
// simplifications are made. We also require that F.Fuzz, F.Skip, or F.Fail is
// called.
func fRunner(f *F, fn func(*F)) {
// When this goroutine is done, either because runtime.Goexit was called, a
// panic started, or fn returned normally, record the duration and send
// f.signal, indicating the fuzz test is done.
defer func() {
// Detect whether the fuzz test panicked or called runtime.Goexit
// without calling F.Fuzz, F.Fail, or F.Skip. If it did, panic (possibly
// replacing a nil panic value). Nothing should recover after fRunner
// unwinds, so this should crash the process and print a stack trace.
// Unfortunately, recovering here adds stack frames, but the location of
// the original panic should still be clear.
if f.Failed() {
numFailed.Add(1)
}
err := recover()
if err == nil {
f.mu.RLock()
fuzzNotCalled := !f.fuzzCalled && !f.skipped && !f.failed
if !f.finished && !f.skipped && !f.failed {
err = errNilPanicOrGoexit
}
f.mu.RUnlock()
if fuzzNotCalled && err == nil {
f.Error("returned without calling F.Fuzz, F.Fail, or F.Skip")
}
}
// Use a deferred call to ensure that we report that the test is
// complete even if a cleanup function calls F.FailNow. See issue 41355.
didPanic := false
defer func() {
if !didPanic {
// Only report that the test is complete if it doesn't panic,
// as otherwise the test binary can exit before the panic is
// reported to the user. See issue 41479.
f.signal <- true
}
}()
// If we recovered a panic or inappropriate runtime.Goexit, fail the test,
// flush the output log up to the root, then panic.
doPanic := func(err any) {
f.Fail()
if r := f.runCleanup(recoverAndReturnPanic); r != nil {
f.Logf("cleanup panicked with %v", r)
}
for root := &f.common; root.parent != nil; root = root.parent {
root.mu.Lock()
root.duration += time.Since(root.start)
d := root.duration
root.mu.Unlock()
root.flushToParent(root.name, "--- FAIL: %s (%s)\n", root.name, fmtDuration(d))
}
didPanic = true
panic(err)
}
if err != nil {
doPanic(err)
}
// No panic or inappropriate Goexit.
f.duration += time.Since(f.start)
if len(f.sub) > 0 {
// Unblock inputs that called T.Parallel while running the seed corpus.
// This only affects fuzz tests run as normal tests.
// While fuzzing, T.Parallel has no effect, so f.sub is empty, and this
// branch is not taken. f.barrier is nil in that case.
f.testContext.release()
close(f.barrier)
// Wait for the subtests to complete.
for _, sub := range f.sub {
<-sub.signal
}
cleanupStart := time.Now()
err := f.runCleanup(recoverAndReturnPanic)
f.duration += time.Since(cleanupStart)
if err != nil {
doPanic(err)
}
}
// Report after all subtests have finished.
f.report()
f.done = true
f.setRan()
}()
defer func() {
if len(f.sub) == 0 {
f.runCleanup(normalPanic)
}
}()
f.start = time.Now()
fn(f)
// Code beyond this point will not be executed when FailNow or SkipNow
// is invoked.
f.mu.Lock()
f.finished = true
f.mu.Unlock()
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package iotest
import (
"io"
"log"
)
type writeLogger struct {
prefix string
w io.Writer
}
func (l *writeLogger) Write(p []byte) (n int, err error) {
n, err = l.w.Write(p)
if err != nil {
log.Printf("%s %x: %v", l.prefix, p[0:n], err)
} else {
log.Printf("%s %x", l.prefix, p[0:n])
}
return
}
// NewWriteLogger returns a writer that behaves like w except
// that it logs (using log.Printf) each write to standard error,
// printing the prefix and the hexadecimal data written.
func NewWriteLogger(prefix string, w io.Writer) io.Writer {
return &writeLogger{prefix, w}
}
type readLogger struct {
prefix string
r io.Reader
}
func (l *readLogger) Read(p []byte) (n int, err error) {
n, err = l.r.Read(p)
if err != nil {
log.Printf("%s %x: %v", l.prefix, p[0:n], err)
} else {
log.Printf("%s %x", l.prefix, p[0:n])
}
return
}
// NewReadLogger returns a reader that behaves like r except
// that it logs (using log.Printf) each read to standard error,
// printing the prefix and the hexadecimal data read.
func NewReadLogger(prefix string, r io.Reader) io.Reader {
return &readLogger{prefix, r}
}
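// A minimal usage sketch (the prefix and reader are arbitrary):
//
//	r := iotest.NewReadLogger("read:", strings.NewReader("hi"))
//	buf := make([]byte, 2)
//	r.Read(buf) // logs a line like: read: 6869
//
// NewWriteLogger behaves symmetrically for writes.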
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package iotest implements Readers and Writers useful mainly for testing.
package iotest
import (
"bytes"
"errors"
"fmt"
"io"
)
// OneByteReader returns a Reader that implements
// each non-empty Read by reading one byte from r.
func OneByteReader(r io.Reader) io.Reader { return &oneByteReader{r} }
type oneByteReader struct {
r io.Reader
}
func (r *oneByteReader) Read(p []byte) (int, error) {
if len(p) == 0 {
return 0, nil
}
return r.r.Read(p[0:1])
}
// HalfReader returns a Reader that implements Read
// by reading half as many requested bytes from r.
func HalfReader(r io.Reader) io.Reader { return &halfReader{r} }
type halfReader struct {
r io.Reader
}
func (r *halfReader) Read(p []byte) (int, error) {
return r.r.Read(p[0 : (len(p)+1)/2])
}
// DataErrReader changes the way errors are handled by a Reader. Normally, a
// Reader returns an error (typically EOF) from the first Read call after the
// last piece of data is read. DataErrReader wraps a Reader and changes its
// behavior so the final error is returned along with the final data, instead
// of in the first call after the final data.
func DataErrReader(r io.Reader) io.Reader { return &dataErrReader{r, nil, make([]byte, 1024)} }
type dataErrReader struct {
r io.Reader
unread []byte
data []byte
}
func (r *dataErrReader) Read(p []byte) (n int, err error) {
// loop because first call needs two reads:
// one to get data and a second to look for an error.
for {
if len(r.unread) == 0 {
n1, err1 := r.r.Read(r.data)
r.unread = r.data[0:n1]
err = err1
}
if n > 0 || err != nil {
break
}
n = copy(p, r.unread)
r.unread = r.unread[n:]
}
return
}
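// For example, if the wrapped reader yields ("abc", nil) and then (0, io.EOF),
// a plain reader reports (3, nil) followed by (0, io.EOF), whereas a
// DataErrReader reports (3, io.EOF) on the first call.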
// ErrTimeout is a fake timeout error.
var ErrTimeout = errors.New("timeout")
// TimeoutReader returns ErrTimeout on the second read
// with no data. Subsequent calls to Read succeed.
func TimeoutReader(r io.Reader) io.Reader { return &timeoutReader{r, 0} }
type timeoutReader struct {
r io.Reader
count int
}
func (r *timeoutReader) Read(p []byte) (int, error) {
r.count++
if r.count == 2 {
return 0, ErrTimeout
}
return r.r.Read(p)
}
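// For example: the first Read is passed through to r, the second returns
// (0, ErrTimeout) without reading, and the third and later Reads are passed
// through to r again.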
// ErrReader returns an io.Reader that returns 0, err from all Read calls.
func ErrReader(err error) io.Reader {
return &errReader{err: err}
}
type errReader struct {
err error
}
func (r *errReader) Read(p []byte) (int, error) {
return 0, r.err
}
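// smallByteReader reads from r in small chunks, cycling the requested size
// through 1, 2, and 3 bytes, and annotates any non-EOF error with the offset
// at which it occurred. TestReader uses it to exercise readers with varying
// small read sizes.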
type smallByteReader struct {
r io.Reader
off int
n int
}
func (r *smallByteReader) Read(p []byte) (int, error) {
if len(p) == 0 {
return 0, nil
}
r.n = r.n%3 + 1
n := r.n
if n > len(p) {
n = len(p)
}
n, err := r.r.Read(p[0:n])
if err != nil && err != io.EOF {
err = fmt.Errorf("Read(%d bytes at offset %d): %v", n, r.off, err)
}
r.off += n
return n, err
}
// TestReader tests that reading from r returns the expected file content.
// It does reads of different sizes, until EOF.
// If r implements io.ReaderAt or io.Seeker, TestReader also checks
// that those operations behave as they should.
//
// If TestReader finds any misbehaviors, it returns an error reporting them.
// The error text may span multiple lines.
func TestReader(r io.Reader, content []byte) error {
if len(content) > 0 {
n, err := r.Read(nil)
if n != 0 || err != nil {
return fmt.Errorf("Read(0) = %d, %v, want 0, nil", n, err)
}
}
data, err := io.ReadAll(&smallByteReader{r: r})
if err != nil {
return err
}
if !bytes.Equal(data, content) {
return fmt.Errorf("ReadAll(small amounts) = %q\n\twant %q", data, content)
}
n, err := r.Read(make([]byte, 10))
if n != 0 || err != io.EOF {
return fmt.Errorf("Read(10) at EOF = %v, %v, want 0, EOF", n, err)
}
if r, ok := r.(io.ReadSeeker); ok {
// Seek(0, 1) should report the current file position (EOF).
if off, err := r.Seek(0, 1); off != int64(len(content)) || err != nil {
return fmt.Errorf("Seek(0, 1) from EOF = %d, %v, want %d, nil", off, err, len(content))
}
// Seek backward partway through file, in two steps.
// If middle == 0, then len(content) == 0, so we can't use the -1 and +1 seeks.
middle := len(content) - len(content)/3
if middle > 0 {
if off, err := r.Seek(-1, 1); off != int64(len(content)-1) || err != nil {
return fmt.Errorf("Seek(-1, 1) from EOF = %d, %v, want %d, nil", -off, err, len(content)-1)
}
if off, err := r.Seek(int64(-len(content)/3), 1); off != int64(middle-1) || err != nil {
return fmt.Errorf("Seek(%d, 1) from %d = %d, %v, want %d, nil", -len(content)/3, len(content)-1, off, err, middle-1)
}
if off, err := r.Seek(+1, 1); off != int64(middle) || err != nil {
return fmt.Errorf("Seek(+1, 1) from %d = %d, %v, want %d, nil", middle-1, off, err, middle)
}
}
// Seek(0, 1) should report the current file position (middle).
if off, err := r.Seek(0, 1); off != int64(middle) || err != nil {
return fmt.Errorf("Seek(0, 1) from %d = %d, %v, want %d, nil", middle, off, err, middle)
}
// Reading forward should return the last part of the file.
data, err := io.ReadAll(&smallByteReader{r: r})
if err != nil {
return fmt.Errorf("ReadAll from offset %d: %v", middle, err)
}
if !bytes.Equal(data, content[middle:]) {
return fmt.Errorf("ReadAll from offset %d = %q\n\twant %q", middle, data, content[middle:])
}
// Seek relative to end of file, but start elsewhere.
if off, err := r.Seek(int64(middle/2), 0); off != int64(middle/2) || err != nil {
return fmt.Errorf("Seek(%d, 0) from EOF = %d, %v, want %d, nil", middle/2, off, err, middle/2)
}
if off, err := r.Seek(int64(-len(content)/3), 2); off != int64(middle) || err != nil {
return fmt.Errorf("Seek(%d, 2) from %d = %d, %v, want %d, nil", -len(content)/3, middle/2, off, err, middle)
}
// Reading forward should return the last part of the file (again).
data, err = io.ReadAll(&smallByteReader{r: r})
if err != nil {
return fmt.Errorf("ReadAll from offset %d: %v", middle, err)
}
if !bytes.Equal(data, content[middle:]) {
return fmt.Errorf("ReadAll from offset %d = %q\n\twant %q", middle, data, content[middle:])
}
// Absolute seek & read forward.
if off, err := r.Seek(int64(middle/2), 0); off != int64(middle/2) || err != nil {
return fmt.Errorf("Seek(%d, 0) from EOF = %d, %v, want %d, nil", middle/2, off, err, middle/2)
}
data, err = io.ReadAll(r)
if err != nil {
return fmt.Errorf("ReadAll from offset %d: %v", middle/2, err)
}
if !bytes.Equal(data, content[middle/2:]) {
return fmt.Errorf("ReadAll from offset %d = %q\n\twant %q", middle/2, data, content[middle/2:])
}
}
if r, ok := r.(io.ReaderAt); ok {
data := make([]byte, len(content), len(content)+1)
for i := range data {
data[i] = 0xfe
}
n, err := r.ReadAt(data, 0)
if n != len(data) || err != nil && err != io.EOF {
return fmt.Errorf("ReadAt(%d, 0) = %v, %v, want %d, nil or EOF", len(data), n, err, len(data))
}
if !bytes.Equal(data, content) {
return fmt.Errorf("ReadAt(%d, 0) = %q\n\twant %q", len(data), data, content)
}
n, err = r.ReadAt(data[:1], int64(len(data)))
if n != 0 || err != io.EOF {
return fmt.Errorf("ReadAt(1, %d) = %v, %v, want 0, EOF", len(data), n, err)
}
for i := range data {
data[i] = 0xfe
}
n, err = r.ReadAt(data[:cap(data)], 0)
if n != len(data) || err != io.EOF {
return fmt.Errorf("ReadAt(%d, 0) = %v, %v, want %d, EOF", cap(data), n, err, len(data))
}
if !bytes.Equal(data, content) {
return fmt.Errorf("ReadAt(%d, 0) = %q\n\twant %q", len(data), data, content)
}
for i := range data {
data[i] = 0xfe
}
for i := range data {
n, err = r.ReadAt(data[i:i+1], int64(i))
if n != 1 || err != nil && (i != len(data)-1 || err != io.EOF) {
want := "nil"
if i == len(data)-1 {
want = "nil or EOF"
}
return fmt.Errorf("ReadAt(1, %d) = %v, %v, want 1, %s", i, n, err, want)
}
if data[i] != content[i] {
return fmt.Errorf("ReadAt(1, %d) = %q want %q", i, data[i:i+1], content[i:i+1])
}
}
}
return nil
}
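// A minimal usage sketch inside a test function, assuming a *bytes.Reader
// (which implements both io.ReaderAt and io.Seeker, so every branch above
// is exercised):
//
//	data := []byte("hello, world")
//	if err := iotest.TestReader(bytes.NewReader(data), data); err != nil {
//		t.Fatal(err)
//	}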
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package iotest
import "io"
// TruncateWriter returns a Writer that writes to w
// but stops silently after n bytes.
func TruncateWriter(w io.Writer, n int64) io.Writer {
return &truncateWriter{w, n}
}
type truncateWriter struct {
w io.Writer
n int64
}
func (t *truncateWriter) Write(p []byte) (n int, err error) {
if t.n <= 0 {
return len(p), nil
}
// real write
n = len(p)
if int64(n) > t.n {
n = int(t.n)
}
n, err = t.w.Write(p[0:n])
t.n -= int64(n)
if err == nil {
n = len(p)
}
return
}
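// For example, with n = 3 and a hypothetical bytes.Buffer buf:
//
//	w := iotest.TruncateWriter(&buf, 3)
//	w.Write([]byte("abcdef")) // returns (6, nil); buf holds only "abc"
//	w.Write([]byte("ghi"))    // returns (3, nil); nothing is written
//
// On success Write always reports the full length of p, so callers cannot
// tell that output was dropped.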
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package testing
import (
"fmt"
"os"
"strconv"
"strings"
"sync"
)
// matcher sanitizes, deduplicates, and filters names of subtests and subbenchmarks.
type matcher struct {
filter filterMatch
skip filterMatch
matchFunc func(pat, str string) (bool, error)
mu sync.Mutex
// subNames is used to deduplicate subtest names.
// Each key is the subtest name joined to the deduplicated name of the parent test.
// Each value is the number of occurrences of the given subtest name
// already seen.
subNames map[string]int32
}
type filterMatch interface {
// matches checks the name against the receiver's pattern strings using the
// given match function.
matches(name []string, matchString func(pat, str string) (bool, error)) (ok, partial bool)
// verify checks that the receiver's pattern strings are valid filters by
// calling the given match function.
verify(name string, matchString func(pat, str string) (bool, error)) error
}
// simpleMatch matches a test name if all of the pattern strings match in
// sequence.
type simpleMatch []string
// alternationMatch matches a test name if one of the alternations match.
type alternationMatch []filterMatch
// TODO: fix test_main to avoid race and improve caching, also allowing
// this Mutex to be eliminated.
var matchMutex sync.Mutex
func allMatcher() *matcher {
return newMatcher(nil, "", "", "")
}
func newMatcher(matchString func(pat, str string) (bool, error), patterns, name, skips string) *matcher {
var filter, skip filterMatch
if patterns == "" {
filter = simpleMatch{} // always partial true
} else {
filter = splitRegexp(patterns)
if err := filter.verify(name, matchString); err != nil {
fmt.Fprintf(os.Stderr, "testing: invalid regexp for %s\n", err)
os.Exit(1)
}
}
if skips == "" {
skip = alternationMatch{} // always false
} else {
skip = splitRegexp(skips)
if err := skip.verify("-test.skip", matchString); err != nil {
fmt.Fprintf(os.Stderr, "testing: invalid regexp for %v\n", err)
os.Exit(1)
}
}
return &matcher{
filter: filter,
skip: skip,
matchFunc: matchString,
subNames: map[string]int32{},
}
}
func (m *matcher) fullName(c *common, subname string) (name string, ok, partial bool) {
name = subname
m.mu.Lock()
defer m.mu.Unlock()
if c != nil && c.level > 0 {
name = m.unique(c.name, rewrite(subname))
}
matchMutex.Lock()
defer matchMutex.Unlock()
// We check the full array of paths each time to allow for the case that a pattern contains a '/'.
elem := strings.Split(name, "/")
// filter must match.
// accept partial match that may produce full match later.
ok, partial = m.filter.matches(elem, m.matchFunc)
if !ok {
return name, false, false
}
// skip must not match.
// ignore partial match so we can get to more precise match later.
skip, partialSkip := m.skip.matches(elem, m.matchFunc)
if skip && !partialSkip {
return name, false, false
}
return name, ok, partial
}
// clearSubNames clears the matcher's internal state, potentially freeing
// memory. After this is called, T.Name may return the same strings as it did
// for earlier subtests.
func (m *matcher) clearSubNames() {
m.mu.Lock()
defer m.mu.Unlock()
for key := range m.subNames {
delete(m.subNames, key)
}
}
func (m simpleMatch) matches(name []string, matchString func(pat, str string) (bool, error)) (ok, partial bool) {
for i, s := range name {
if i >= len(m) {
break
}
if ok, _ := matchString(m[i], s); !ok {
return false, false
}
}
return true, len(name) < len(m)
}
func (m simpleMatch) verify(name string, matchString func(pat, str string) (bool, error)) error {
for i, s := range m {
m[i] = rewrite(s)
}
// Verify filters before doing any processing.
for i, s := range m {
if _, err := matchString(s, "non-empty"); err != nil {
return fmt.Errorf("element %d of %s (%q): %s", i, name, s, err)
}
}
return nil
}
func (m alternationMatch) matches(name []string, matchString func(pat, str string) (bool, error)) (ok, partial bool) {
for _, m := range m {
if ok, partial = m.matches(name, matchString); ok {
return ok, partial
}
}
return false, false
}
func (m alternationMatch) verify(name string, matchString func(pat, str string) (bool, error)) error {
for i, m := range m {
if err := m.verify(name, matchString); err != nil {
return fmt.Errorf("alternation %d of %s", i, err)
}
}
return nil
}
func splitRegexp(s string) filterMatch {
a := make(simpleMatch, 0, strings.Count(s, "/"))
b := make(alternationMatch, 0, strings.Count(s, "|"))
cs := 0
cp := 0
for i := 0; i < len(s); {
switch s[i] {
case '[':
cs++
case ']':
if cs--; cs < 0 { // An unmatched ']' is legal.
cs = 0
}
case '(':
if cs == 0 {
cp++
}
case ')':
if cs == 0 {
cp--
}
case '\\':
i++
case '/':
if cs == 0 && cp == 0 {
a = append(a, s[:i])
s = s[i+1:]
i = 0
continue
}
case '|':
if cs == 0 && cp == 0 {
a = append(a, s[:i])
s = s[i+1:]
i = 0
b = append(b, a)
a = make(simpleMatch, 0, len(a))
continue
}
}
i++
}
a = append(a, s)
if len(b) == 0 {
return a
}
return append(b, a)
}
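// For example:
//
//	splitRegexp("a/b")    // simpleMatch{"a", "b"}
//	splitRegexp("a|b/c")  // alternationMatch{simpleMatch{"a"}, simpleMatch{"b", "c"}}
//	splitRegexp("[a/]/b") // simpleMatch{"[a/]", "b"}: a '/' inside a bracket class does not split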
// unique creates a unique name for the given parent and subname by affixing it
// with one or more counts, if necessary.
func (m *matcher) unique(parent, subname string) string {
base := parent + "/" + subname
for {
n := m.subNames[base]
if n < 0 {
panic("subtest count overflow")
}
m.subNames[base] = n + 1
if n == 0 && subname != "" {
prefix, nn := parseSubtestNumber(base)
if len(prefix) < len(base) && nn < m.subNames[prefix] {
// This test is explicitly named like "parent/subname#NN",
// and #NN was already used for the NNth occurrence of "parent/subname".
// Loop to add a disambiguating suffix.
continue
}
return base
}
name := fmt.Sprintf("%s#%02d", base, n)
if m.subNames[name] != 0 {
// This is the nth occurrence of base, but the name "parent/subname#NN"
// collides with the first occurrence of a subtest *explicitly* named
// "parent/subname#NN". Try the next number.
continue
}
return name
}
}
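// For example, three successive calls with parent "TestFoo" and subname "bar"
// return "TestFoo/bar", "TestFoo/bar#01", and "TestFoo/bar#02". If a subtest
// was already explicitly named "TestFoo/bar#01", the loop skips that name and
// picks the next free number instead.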
// parseSubtestNumber splits a subtest name into a "#%02d"-formatted int32
// suffix (if present), and a prefix preceding that suffix (always).
func parseSubtestNumber(s string) (prefix string, nn int32) {
i := strings.LastIndex(s, "#")
if i < 0 {
return s, 0
}
prefix, suffix := s[:i], s[i+1:]
if len(suffix) < 2 || (len(suffix) > 2 && suffix[0] == '0') {
// Even if suffix is numeric, it is not a possible output of a "%02d" format
// string: it has either too few digits or too many leading zeroes.
return s, 0
}
if suffix == "00" {
if !strings.HasSuffix(prefix, "/") {
// We only use "#00" as a suffix for subtests named with the empty
// string — it isn't a valid suffix if the subtest name is non-empty.
return s, 0
}
}
n, err := strconv.ParseInt(suffix, 10, 32)
if err != nil || n < 0 {
return s, 0
}
return prefix, int32(n)
}
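// For example:
//
//	parseSubtestNumber("a/b#01")  // ("a/b", 1)
//	parseSubtestNumber("a/b#1")   // ("a/b#1", 0): too few digits for %02d
//	parseSubtestNumber("a/b#001") // ("a/b#001", 0): too many leading zeroes
//	parseSubtestNumber("a/b")     // ("a/b", 0): no suffix at all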
// rewrite rewrites a subname so that it contains only printable characters
// and no white space.
func rewrite(s string) string {
b := []byte{}
for _, r := range s {
switch {
case isSpace(r):
b = append(b, '_')
case !strconv.IsPrint(r):
s := strconv.QuoteRune(r)
b = append(b, s[1:len(s)-1]...)
default:
b = append(b, string(r)...)
}
}
return string(b)
}
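// For example:
//
//	rewrite("a b")  // "a_b": the space becomes '_'
//	rewrite("x\ty") // "x_y": all white space is treated the same way
//	rewrite("a\ab") // `a\ab`: the unprintable BEL rune is quoted via strconv.QuoteRune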
func isSpace(r rune) bool {
if r < 0x2000 {
switch r {
// Note: not the same as Unicode Z class.
case '\t', '\n', '\v', '\f', '\r', ' ', 0x85, 0xA0, 0x1680:
return true
}
} else {
if r <= 0x200a {
return true
}
switch r {
case 0x2028, 0x2029, 0x202f, 0x205f, 0x3000:
return true
}
}
return false
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Support for test coverage with redesigned coverage implementation.
package testing
import (
"fmt"
"internal/goexperiment"
"os"
)
// cover2 stores the current coverage mode and a tear-down function
// to be called at the end of the testing run.
var cover2 struct {
mode string
tearDown func(coverprofile string, gocoverdir string) (string, error)
}
// registerCover2 is invoked during "go test -cover" runs by the test harness
// code in _testmain.go; it is used to record a 'tear down' function
// (to be called when the test is complete) and the coverage mode.
func registerCover2(mode string, tearDown func(coverprofile string, gocoverdir string) (string, error)) {
cover2.mode = mode
cover2.tearDown = tearDown
}
// coverReport2 invokes a callback in _testmain.go that will
// emit coverage data at the point where test execution is complete,
// for "go test -cover" runs.
func coverReport2() {
if !goexperiment.CoverageRedesign {
panic("unexpected")
}
if errmsg, err := cover2.tearDown(*coverProfile, *gocoverdir); err != nil {
fmt.Fprintf(os.Stderr, "%s: %v\n", errmsg, err)
os.Exit(2)
}
}
// testGoCoverDir returns the value passed to the -test.gocoverdir
// flag by the Go command, if goexperiment.CoverageRedesign is
// in effect.
func testGoCoverDir() string {
return *gocoverdir
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package quick implements utility functions to help with black box testing.
//
// The testing/quick package is frozen and is not accepting new features.
package quick
import (
"flag"
"fmt"
"math"
"math/rand"
"reflect"
"strings"
"time"
)
var defaultMaxCount *int = flag.Int("quickchecks", 100, "The default number of iterations for each check")
// A Generator can generate random values of its own type.
type Generator interface {
// Generate returns a random instance of the type on which it is a
// method using the size as a size hint.
Generate(rand *rand.Rand, size int) reflect.Value
}
// randFloat32 generates a random float taking the full range of a float32.
func randFloat32(rand *rand.Rand) float32 {
f := rand.Float64() * math.MaxFloat32
if rand.Int()&1 == 1 {
f = -f
}
return float32(f)
}
// randFloat64 generates a random float taking the full range of a float64.
func randFloat64(rand *rand.Rand) float64 {
f := rand.Float64() * math.MaxFloat64
if rand.Int()&1 == 1 {
f = -f
}
return f
}
// randInt64 returns a random int64.
func randInt64(rand *rand.Rand) int64 {
return int64(rand.Uint64())
}
// complexSize is the maximum length of arbitrary values that contain other
// values.
const complexSize = 50
// Value returns an arbitrary value of the given type.
// If the type implements the Generator interface, that will be used.
// Note: To create arbitrary values for structs, all the fields must be exported.
func Value(t reflect.Type, rand *rand.Rand) (value reflect.Value, ok bool) {
return sizedValue(t, rand, complexSize)
}
// sizedValue returns an arbitrary value of the given type. The size
// hint is used for shrinking as a function of indirection level so
// that recursive data structures will terminate.
func sizedValue(t reflect.Type, rand *rand.Rand, size int) (value reflect.Value, ok bool) {
if m, ok := reflect.Zero(t).Interface().(Generator); ok {
return m.Generate(rand, size), true
}
v := reflect.New(t).Elem()
switch concrete := t; concrete.Kind() {
case reflect.Bool:
v.SetBool(rand.Int()&1 == 0)
case reflect.Float32:
v.SetFloat(float64(randFloat32(rand)))
case reflect.Float64:
v.SetFloat(randFloat64(rand))
case reflect.Complex64:
v.SetComplex(complex(float64(randFloat32(rand)), float64(randFloat32(rand))))
case reflect.Complex128:
v.SetComplex(complex(randFloat64(rand), randFloat64(rand)))
case reflect.Int16:
v.SetInt(randInt64(rand))
case reflect.Int32:
v.SetInt(randInt64(rand))
case reflect.Int64:
v.SetInt(randInt64(rand))
case reflect.Int8:
v.SetInt(randInt64(rand))
case reflect.Int:
v.SetInt(randInt64(rand))
case reflect.Uint16:
v.SetUint(uint64(randInt64(rand)))
case reflect.Uint32:
v.SetUint(uint64(randInt64(rand)))
case reflect.Uint64:
v.SetUint(uint64(randInt64(rand)))
case reflect.Uint8:
v.SetUint(uint64(randInt64(rand)))
case reflect.Uint:
v.SetUint(uint64(randInt64(rand)))
case reflect.Uintptr:
v.SetUint(uint64(randInt64(rand)))
case reflect.Map:
numElems := rand.Intn(size)
v.Set(reflect.MakeMap(concrete))
for i := 0; i < numElems; i++ {
key, ok1 := sizedValue(concrete.Key(), rand, size)
value, ok2 := sizedValue(concrete.Elem(), rand, size)
if !ok1 || !ok2 {
return reflect.Value{}, false
}
v.SetMapIndex(key, value)
}
case reflect.Pointer:
if rand.Intn(size) == 0 {
v.Set(reflect.Zero(concrete)) // Generate nil pointer.
} else {
elem, ok := sizedValue(concrete.Elem(), rand, size)
if !ok {
return reflect.Value{}, false
}
v.Set(reflect.New(concrete.Elem()))
v.Elem().Set(elem)
}
case reflect.Slice:
numElems := rand.Intn(size)
sizeLeft := size - numElems
v.Set(reflect.MakeSlice(concrete, numElems, numElems))
for i := 0; i < numElems; i++ {
elem, ok := sizedValue(concrete.Elem(), rand, sizeLeft)
if !ok {
return reflect.Value{}, false
}
v.Index(i).Set(elem)
}
case reflect.Array:
for i := 0; i < v.Len(); i++ {
elem, ok := sizedValue(concrete.Elem(), rand, size)
if !ok {
return reflect.Value{}, false
}
v.Index(i).Set(elem)
}
case reflect.String:
numChars := rand.Intn(complexSize)
codePoints := make([]rune, numChars)
for i := 0; i < numChars; i++ {
codePoints[i] = rune(rand.Intn(0x10ffff))
}
v.SetString(string(codePoints))
case reflect.Struct:
n := v.NumField()
// Divide sizeLeft evenly among the struct fields.
sizeLeft := size
if n > sizeLeft {
sizeLeft = 1
} else if n > 0 {
sizeLeft /= n
}
for i := 0; i < n; i++ {
elem, ok := sizedValue(concrete.Field(i).Type, rand, sizeLeft)
if !ok {
return reflect.Value{}, false
}
v.Field(i).Set(elem)
}
default:
return reflect.Value{}, false
}
return v, true
}
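// A minimal usage sketch:
//
//	r := rand.New(rand.NewSource(0))
//	if v, ok := quick.Value(reflect.TypeOf(map[string]int{}), r); ok {
//		fmt.Println(v.Interface()) // a small map with arbitrary keys and values
//	}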
// A Config structure contains options for running a test.
type Config struct {
// MaxCount sets the maximum number of iterations.
// If zero, MaxCountScale is used.
MaxCount int
// MaxCountScale is a non-negative scale factor applied to the
// default maximum.
// A count of zero implies the default, which is usually 100
// but can be set by the -quickchecks flag.
MaxCountScale float64
// Rand specifies a source of random numbers.
// If nil, a default pseudo-random source will be used.
Rand *rand.Rand
// Values specifies a function to generate a slice of
// arbitrary reflect.Values that are congruent with the
// arguments to the function being tested.
// If nil, the top-level Value function is used to generate them.
Values func([]reflect.Value, *rand.Rand)
}
var defaultConfig Config
// getRand returns the *rand.Rand to use for a given Config.
func (c *Config) getRand() *rand.Rand {
if c.Rand == nil {
return rand.New(rand.NewSource(time.Now().UnixNano()))
}
return c.Rand
}
// getMaxCount returns the maximum number of iterations to run for a given
// Config.
func (c *Config) getMaxCount() (maxCount int) {
maxCount = c.MaxCount
if maxCount == 0 {
if c.MaxCountScale != 0 {
maxCount = int(c.MaxCountScale * float64(*defaultMaxCount))
} else {
maxCount = *defaultMaxCount
}
}
return
}
// A SetupError is the result of an error in the way that check is being
// used, independent of the functions being tested.
type SetupError string
func (s SetupError) Error() string { return string(s) }
// A CheckError is the result of Check finding an error.
type CheckError struct {
Count int
In []any
}
func (s *CheckError) Error() string {
return fmt.Sprintf("#%d: failed on input %s", s.Count, toString(s.In))
}
// A CheckEqualError is the result of CheckEqual finding an error.
type CheckEqualError struct {
CheckError
Out1 []any
Out2 []any
}
func (s *CheckEqualError) Error() string {
return fmt.Sprintf("#%d: failed on input %s. Output 1: %s. Output 2: %s", s.Count, toString(s.In), toString(s.Out1), toString(s.Out2))
}
// Check looks for an input to f, any function that returns bool,
// such that f returns false. It calls f repeatedly, with arbitrary
// values for each argument. If f returns false on a given input,
// Check returns that input as a *CheckError.
// For example:
//
// func TestOddMultipleOfThree(t *testing.T) {
// f := func(x int) bool {
// y := OddMultipleOfThree(x)
// return y%2 == 1 && y%3 == 0
// }
// if err := quick.Check(f, nil); err != nil {
// t.Error(err)
// }
// }
func Check(f any, config *Config) error {
if config == nil {
config = &defaultConfig
}
fVal, fType, ok := functionAndType(f)
if !ok {
return SetupError("argument is not a function")
}
if fType.NumOut() != 1 {
return SetupError("function does not return one value")
}
if fType.Out(0).Kind() != reflect.Bool {
return SetupError("function does not return a bool")
}
arguments := make([]reflect.Value, fType.NumIn())
rand := config.getRand()
maxCount := config.getMaxCount()
for i := 0; i < maxCount; i++ {
err := arbitraryValues(arguments, fType, config, rand)
if err != nil {
return err
}
if !fVal.Call(arguments)[0].Bool() {
return &CheckError{i + 1, toInterfaces(arguments)}
}
}
return nil
}
// CheckEqual looks for an input on which f and g return different results.
// It calls f and g repeatedly with arbitrary values for each argument.
// If f and g return different answers, CheckEqual returns a *CheckEqualError
// describing the input and the outputs.
func CheckEqual(f, g any, config *Config) error {
if config == nil {
config = &defaultConfig
}
x, xType, ok := functionAndType(f)
if !ok {
return SetupError("f is not a function")
}
y, yType, ok := functionAndType(g)
if !ok {
return SetupError("g is not a function")
}
if xType != yType {
return SetupError("functions have different types")
}
arguments := make([]reflect.Value, xType.NumIn())
rand := config.getRand()
maxCount := config.getMaxCount()
for i := 0; i < maxCount; i++ {
err := arbitraryValues(arguments, xType, config, rand)
if err != nil {
return err
}
xOut := toInterfaces(x.Call(arguments))
yOut := toInterfaces(y.Call(arguments))
if !reflect.DeepEqual(xOut, yOut) {
return &CheckEqualError{CheckError{i + 1, toInterfaces(arguments)}, xOut, yOut}
}
}
return nil
}
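// A minimal usage sketch inside a test, where Fast and Slow are hypothetical
// functions with identical signatures:
//
//	if err := quick.CheckEqual(Fast, Slow, nil); err != nil {
//		t.Error(err) // the error describes the input and both differing outputs
//	}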
// arbitraryValues writes Values to args such that args contains Values
// suitable for calling f.
func arbitraryValues(args []reflect.Value, f reflect.Type, config *Config, rand *rand.Rand) (err error) {
if config.Values != nil {
config.Values(args, rand)
return
}
for j := 0; j < len(args); j++ {
var ok bool
args[j], ok = Value(f.In(j), rand)
if !ok {
err = SetupError(fmt.Sprintf("cannot create arbitrary value of type %s for argument %d", f.In(j), j))
return
}
}
return
}
func functionAndType(f any) (v reflect.Value, t reflect.Type, ok bool) {
v = reflect.ValueOf(f)
ok = v.Kind() == reflect.Func
if !ok {
return
}
t = v.Type()
return
}
func toInterfaces(values []reflect.Value) []any {
ret := make([]any, len(values))
for i, v := range values {
ret[i] = v.Interface()
}
return ret
}
func toString(interfaces []any) string {
s := make([]string, len(interfaces))
for i, v := range interfaces {
s[i] = fmt.Sprintf("%#v", v)
}
return strings.Join(s, ", ")
}
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !js
// TODO(@musiol, @odeke-em): re-unify this entire file back into
// example.go when js/wasm gets an os.Pipe implementation
// and no longer needs this separation.
package testing
import (
"fmt"
"io"
"os"
"strings"
"time"
)
func runExample(eg InternalExample) (ok bool) {
if chatty.on {
fmt.Printf("%s=== RUN %s\n", chatty.prefix(), eg.Name)
}
// Capture stdout.
stdout := os.Stdout
r, w, err := os.Pipe()
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
os.Stdout = w
outC := make(chan string)
go func() {
var buf strings.Builder
_, err := io.Copy(&buf, r)
r.Close()
if err != nil {
fmt.Fprintf(os.Stderr, "testing: copying pipe: %v\n", err)
os.Exit(1)
}
outC <- buf.String()
}()
finished := false
start := time.Now()
// Clean up in a deferred call so we can recover if the example panics.
defer func() {
timeSpent := time.Since(start)
// Close pipe, restore stdout, get output.
w.Close()
os.Stdout = stdout
out := <-outC
err := recover()
ok = eg.processRunResult(out, timeSpent, finished, err)
}()
// Run example.
eg.F()
finished = true
return
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package testing provides support for automated testing of Go packages.
// It is intended to be used in concert with the "go test" command, which automates
// execution of any function of the form
//
// func TestXxx(*testing.T)
//
// where Xxx does not start with a lowercase letter. The function name
// serves to identify the test routine.
//
// Within these functions, use the Error, Fail or related methods to signal failure.
//
// To write a new test suite, create a file that
// contains the TestXxx functions as described here,
// and give that file a name ending in "_test.go".
// The file will be excluded from regular
// package builds but will be included when the "go test" command is run.
//
// The test file can be in the same package as the one being tested,
// or in a corresponding package with the suffix "_test".
//
// If the test file is in the same package, it may refer to unexported
// identifiers within the package, as in this example:
//
// package abs
//
// import "testing"
//
// func TestAbs(t *testing.T) {
// got := Abs(-1)
// if got != 1 {
// t.Errorf("Abs(-1) = %d; want 1", got)
// }
// }
//
// If the file is in a separate "_test" package, the package being tested
// must be imported explicitly and only its exported identifiers may be used.
// This is known as "black box" testing.
//
// package abs_test
//
// import (
// "testing"
//
// "path_to_pkg/abs"
// )
//
// func TestAbs(t *testing.T) {
// got := abs.Abs(-1)
// if got != 1 {
// t.Errorf("Abs(-1) = %d; want 1", got)
// }
// }
//
// For more detail, run "go help test" and "go help testflag".
//
// # Benchmarks
//
// Functions of the form
//
// func BenchmarkXxx(*testing.B)
//
// are considered benchmarks, and are executed by the "go test" command when
// its -bench flag is provided. Benchmarks are run sequentially.
//
// For a description of the testing flags, see
// https://golang.org/cmd/go/#hdr-Testing_flags.
//
// A sample benchmark function looks like this:
//
// func BenchmarkRandInt(b *testing.B) {
// for i := 0; i < b.N; i++ {
// rand.Int()
// }
// }
//
// The benchmark function must run the target code b.N times.
// During benchmark execution, b.N is adjusted until the benchmark function lasts
// long enough to be timed reliably. The output
//
// BenchmarkRandInt-8 68453040 17.8 ns/op
//
// means that the loop ran 68453040 times at a speed of 17.8 ns per loop.
//
// If a benchmark needs some expensive setup before running, the timer
// may be reset:
//
// func BenchmarkBigLen(b *testing.B) {
// big := NewBig()
// b.ResetTimer()
// for i := 0; i < b.N; i++ {
// big.Len()
// }
// }
//
// If a benchmark needs to test performance in a parallel setting, it may use
// the RunParallel helper function; such benchmarks are intended to be used with
// the go test -cpu flag:
//
// func BenchmarkTemplateParallel(b *testing.B) {
// templ := template.Must(template.New("test").Parse("Hello, {{.}}!"))
// b.RunParallel(func(pb *testing.PB) {
// var buf bytes.Buffer
// for pb.Next() {
// buf.Reset()
// templ.Execute(&buf, "World")
// }
// })
// }
//
// A detailed specification of the benchmark results format is given
// in https://golang.org/design/14313-benchmark-format.
//
// There are standard tools for working with benchmark results at
// https://golang.org/x/perf/cmd.
// In particular, https://golang.org/x/perf/cmd/benchstat performs
// statistically robust A/B comparisons.
//
// # Examples
//
// The package also runs and verifies example code. Example functions may
// include a concluding line comment that begins with "Output:" and is compared with
// the standard output of the function when the tests are run. (The comparison
// ignores leading and trailing space.) These are examples of an example:
//
// func ExampleHello() {
// fmt.Println("hello")
// // Output: hello
// }
//
// func ExampleSalutations() {
// fmt.Println("hello, and")
// fmt.Println("goodbye")
// // Output:
// // hello, and
// // goodbye
// }
//
// The comment prefix "Unordered output:" is like "Output:", but matches any
// line order:
//
// func ExamplePerm() {
// for _, value := range Perm(5) {
// fmt.Println(value)
// }
// // Unordered output: 4
// // 2
// // 1
// // 3
// // 0
// }
//
// Example functions without output comments are compiled but not executed.
//
// The naming convention to declare examples for the package, a function F, a type T and
// method M on type T are:
//
// func Example() { ... }
// func ExampleF() { ... }
// func ExampleT() { ... }
// func ExampleT_M() { ... }
//
// Multiple example functions for a package/type/function/method may be provided by
// appending a distinct suffix to the name. The suffix must start with a
// lower-case letter.
//
// func Example_suffix() { ... }
// func ExampleF_suffix() { ... }
// func ExampleT_suffix() { ... }
// func ExampleT_M_suffix() { ... }
//
// The entire test file is presented as the example when it contains a single
// example function, at least one other function, type, variable, or constant
// declaration, and no test or benchmark functions.
//
// # Fuzzing
//
// 'go test' and the testing package support fuzzing, a testing technique where
// a function is called with randomly generated inputs to find bugs not
// anticipated by unit tests.
//
// Functions of the form
//
// func FuzzXxx(*testing.F)
//
// are considered fuzz tests.
//
// For example:
//
// func FuzzHex(f *testing.F) {
// for _, seed := range [][]byte{{}, {0}, {9}, {0xa}, {0xf}, {1, 2, 3, 4}} {
// f.Add(seed)
// }
// f.Fuzz(func(t *testing.T, in []byte) {
// enc := hex.EncodeToString(in)
// out, err := hex.DecodeString(enc)
// if err != nil {
// t.Fatalf("%v: decode: %v", in, err)
// }
// if !bytes.Equal(in, out) {
// t.Fatalf("%v: not equal after round trip: %v", in, out)
// }
// })
// }
//
// A fuzz test maintains a seed corpus, or a set of inputs which are run by
// default, and can seed input generation. Seed inputs may be registered by
// calling (*F).Add or by storing files in the directory testdata/fuzz/<Name>
// (where <Name> is the name of the fuzz test) within the package containing
// the fuzz test. Seed inputs are optional, but the fuzzing engine may find
// bugs more efficiently when provided with a set of small seed inputs with good
// code coverage. These seed inputs can also serve as regression tests for bugs
// identified through fuzzing.
//
// The function passed to (*F).Fuzz within the fuzz test is considered the fuzz
// target. A fuzz target must accept a *T parameter, followed by one or more
// parameters for random inputs. The types of arguments passed to (*F).Add must
// be identical to the types of these parameters. The fuzz target may signal
// that it's found a problem the same way tests do: by calling T.Fail (or any
// method that calls it like T.Error or T.Fatal) or by panicking.
//
// When fuzzing is enabled (by setting the -fuzz flag to a regular expression
// that matches a specific fuzz test), the fuzz target is called with arguments
// generated by repeatedly making random changes to the seed inputs. On
// supported platforms, 'go test' compiles the test executable with fuzzing
// coverage instrumentation. The fuzzing engine uses that instrumentation to
// find and cache inputs that expand coverage, increasing the likelihood of
// finding bugs. If the fuzz target fails for a given input, the fuzzing engine
// writes the inputs that caused the failure to a file in the directory
// testdata/fuzz/<Name> within the package directory. This file later serves as
// a seed input. If the file can't be written at that location (for example,
// because the directory is read-only), the fuzzing engine writes the file to
// the fuzz cache directory within the build cache instead.
//
// When fuzzing is disabled, the fuzz target is called with the seed inputs
// registered with F.Add and seed inputs from testdata/fuzz/<Name>. In this
// mode, the fuzz test acts much like a regular test, with subtests started
// with F.Fuzz instead of T.Run.
//
// See https://go.dev/doc/fuzz for documentation about fuzzing.
//
// # Skipping
//
// Tests or benchmarks may be skipped at run time with a call to
// the Skip method of *T or *B:
//
// func TestTimeConsuming(t *testing.T) {
// if testing.Short() {
// t.Skip("skipping test in short mode.")
// }
// ...
// }
//
// The Skip method of *T can be used in a fuzz target if the input is invalid,
// but should not be considered a failing input. For example:
//
// func FuzzJSONMarshaling(f *testing.F) {
// f.Fuzz(func(t *testing.T, b []byte) {
// var v interface{}
// if err := json.Unmarshal(b, &v); err != nil {
// t.Skip()
// }
// if _, err := json.Marshal(v); err != nil {
// t.Errorf("Marshal: %v", err)
// }
// })
// }
//
// # Subtests and Sub-benchmarks
//
// The Run methods of T and B allow defining subtests and sub-benchmarks,
// without having to define separate functions for each. This enables uses
// like table-driven benchmarks and creating hierarchical tests.
// It also provides a way to share common setup and tear-down code:
//
// func TestFoo(t *testing.T) {
// // <setup code>
// t.Run("A=1", func(t *testing.T) { ... })
// t.Run("A=2", func(t *testing.T) { ... })
// t.Run("B=1", func(t *testing.T) { ... })
// // <tear-down code>
// }
//
// Each subtest and sub-benchmark has a unique name: the combination of the name
// of the top-level test and the sequence of names passed to Run, separated by
// slashes, with an optional trailing sequence number for disambiguation.
//
// The argument to the -run, -bench, and -fuzz command-line flags is an unanchored regular
// expression that matches the test's name. For tests with multiple slash-separated
// elements, such as subtests, the argument is itself slash-separated, with
// expressions matching each name element in turn. Because it is unanchored, an
// empty expression matches any string.
// For example, using "matching" to mean "whose name contains":
//
// go test -run '' # Run all tests.
// go test -run Foo # Run top-level tests matching "Foo", such as "TestFooBar".
// go test -run Foo/A= # For top-level tests matching "Foo", run subtests matching "A=".
// go test -run /A=1 # For all top-level tests, run subtests matching "A=1".
// go test -fuzz FuzzFoo # Fuzz the target matching "FuzzFoo"
//
// The -run argument can also be used to run a specific value in the seed
// corpus, for debugging. For example:
//
// go test -run=FuzzFoo/9ddb952d9814
//
// The -fuzz and -run flags can both be set, in order to fuzz a target but
// skip the execution of all other tests.
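//
// For example:
//
// go test -run FuzzFoo -fuzz FuzzFoo # skip other tests, then fuzz FuzzFoo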
//
// Subtests can also be used to control parallelism. A parent test will only
// complete once all of its subtests complete. In this example, all tests are
// run in parallel with each other, and only with each other, regardless of
// other top-level tests that may be defined:
//
// func TestGroupedParallel(t *testing.T) {
// for _, tc := range tests {
// tc := tc // capture range variable
// t.Run(tc.Name, func(t *testing.T) {
// t.Parallel()
// ...
// })
// }
// }
//
// Run does not return until parallel subtests have completed, providing a way
// to clean up after a group of parallel tests:
//
// func TestTeardownParallel(t *testing.T) {
// // This Run will not return until the parallel tests finish.
// t.Run("group", func(t *testing.T) {
// t.Run("Test1", parallelTest1)
// t.Run("Test2", parallelTest2)
// t.Run("Test3", parallelTest3)
// })
// // <tear-down code>
// }
//
// # Main
//
// It is sometimes necessary for a test or benchmark program to do extra setup or teardown
// before or after it executes. It is also sometimes necessary to control
// which code runs on the main thread. To support these and other cases,
// if a test file contains a function:
//
// func TestMain(m *testing.M)
//
// then the generated test will call TestMain(m) instead of running the tests or benchmarks
// directly. TestMain runs in the main goroutine and can do whatever setup
// and teardown is necessary around a call to m.Run. m.Run will return an exit
// code that may be passed to os.Exit. If TestMain returns, the test wrapper
// will pass the result of m.Run to os.Exit itself.
//
// When TestMain is called, flag.Parse has not been run. If TestMain depends on
// command-line flags, including those of the testing package, it should call
// flag.Parse explicitly. Command line flags are always parsed by the time test
// or benchmark functions run.
//
// A simple implementation of TestMain is:
//
// func TestMain(m *testing.M) {
// // call flag.Parse() here if TestMain uses flags
// os.Exit(m.Run())
// }
//
// TestMain is a low-level primitive and should not be necessary for casual
// testing needs, where ordinary test functions suffice.
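//
// A sketch of a TestMain that performs setup and teardown around m.Run
// (the helper names are hypothetical):
//
// func TestMain(m *testing.M) {
// flag.Parse()
// setup() // e.g. start a test server
// code := m.Run()
// teardown() // e.g. stop the server and remove scratch files
// os.Exit(code)
// }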
package testing
import (
"bytes"
"errors"
"flag"
"fmt"
"internal/goexperiment"
"internal/race"
"io"
"math/rand"
"os"
"reflect"
"runtime"
"runtime/debug"
"runtime/trace"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"unicode"
"unicode/utf8"
)
var initRan bool
// Init registers testing flags. These flags are automatically registered by
// the "go test" command before running test functions, so Init is only needed
// when calling functions such as Benchmark without using "go test".
//
// Init has no effect if it was already called.
func Init() {
if initRan {
return
}
initRan = true
// The short flag requests that tests run more quickly, but its functionality
// is provided by test writers themselves. The testing package is just its
// home. The all.bash installation script sets it to make installation more
// efficient, but by default the flag is off so a plain "go test" will do a
// full test of the package.
short = flag.Bool("test.short", false, "run smaller test suite to save time")
// The failfast flag requests that test execution stop after the first test failure.
failFast = flag.Bool("test.failfast", false, "do not start new tests after the first test failure")
// The directory in which to create profile files and the like. When run from
// "go test", the binary always runs in the source directory for the package;
// this flag lets "go test" tell the binary to write the files in the directory where
// the "go test" command is run.
outputDir = flag.String("test.outputdir", "", "write profiles to `dir`")
// Report as tests are run; default is silent for success.
flag.Var(&chatty, "test.v", "verbose: print additional output")
count = flag.Uint("test.count", 1, "run tests and benchmarks `n` times")
coverProfile = flag.String("test.coverprofile", "", "write a coverage profile to `file`")
gocoverdir = flag.String("test.gocoverdir", "", "write coverage intermediate files to this directory")
matchList = flag.String("test.list", "", "list tests, examples, and benchmarks matching `regexp` then exit")
match = flag.String("test.run", "", "run only tests and examples matching `regexp`")
skip = flag.String("test.skip", "", "do not list or run tests matching `regexp`")
memProfile = flag.String("test.memprofile", "", "write an allocation profile to `file`")
memProfileRate = flag.Int("test.memprofilerate", 0, "set memory allocation profiling `rate` (see runtime.MemProfileRate)")
cpuProfile = flag.String("test.cpuprofile", "", "write a cpu profile to `file`")
blockProfile = flag.String("test.blockprofile", "", "write a goroutine blocking profile to `file`")
blockProfileRate = flag.Int("test.blockprofilerate", 1, "set blocking profile `rate` (see runtime.SetBlockProfileRate)")
mutexProfile = flag.String("test.mutexprofile", "", "write a mutex contention profile to the named file after execution")
mutexProfileFraction = flag.Int("test.mutexprofilefraction", 1, "if >= 0, calls runtime.SetMutexProfileFraction()")
panicOnExit0 = flag.Bool("test.paniconexit0", false, "panic on call to os.Exit(0)")
traceFile = flag.String("test.trace", "", "write an execution trace to `file`")
timeout = flag.Duration("test.timeout", 0, "panic test binary after duration `d` (default 0, timeout disabled)")
cpuListStr = flag.String("test.cpu", "", "comma-separated `list` of cpu counts to run each test with")
parallel = flag.Int("test.parallel", runtime.GOMAXPROCS(0), "run at most `n` tests in parallel")
testlog = flag.String("test.testlogfile", "", "write test action log to `file` (for use only by cmd/go)")
shuffle = flag.String("test.shuffle", "off", "randomize the execution order of tests and benchmarks")
fullPath = flag.Bool("test.fullpath", false, "show full file names in error messages")
initBenchmarkFlags()
initFuzzFlags()
}
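// A minimal sketch of the case Init exists for: driving Benchmark from an
// ordinary main function instead of "go test" (the benchmarked work is
// illustrative):
//
// package main
//
// import (
// "flag"
// "fmt"
// "strings"
// "testing"
// )
//
// func main() {
// testing.Init()
// flag.Parse()
// res := testing.Benchmark(func(b *testing.B) {
// for i := 0; i < b.N; i++ {
// strings.Repeat("x", 64)
// }
// })
// fmt.Println(res)
// }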
var (
// Flags, registered during Init.
short *bool
failFast *bool
outputDir *string
chatty chattyFlag
count *uint
coverProfile *string
gocoverdir *string
matchList *string
match *string
skip *string
memProfile *string
memProfileRate *int
cpuProfile *string
blockProfile *string
blockProfileRate *int
mutexProfile *string
mutexProfileFraction *int
panicOnExit0 *bool
traceFile *string
timeout *time.Duration
cpuListStr *string
parallel *int
shuffle *string
testlog *string
fullPath *bool
haveExamples bool // are there examples?
cpuList []int
testlogFile *os.File
numFailed atomic.Uint32 // number of test failures
running sync.Map // map[string]time.Time of running, unpaused tests
)
type chattyFlag struct {
on bool // -v is set in some form
json bool // -v=test2json is set, to make output better for test2json
}
func (*chattyFlag) IsBoolFlag() bool { return true }
func (f *chattyFlag) Set(arg string) error {
switch arg {
default:
return fmt.Errorf("invalid flag -test.v=%s", arg)
case "true", "test2json":
f.on = true
f.json = arg == "test2json"
case "false":
f.on = false
f.json = false
}
return nil
}
func (f *chattyFlag) String() string {
if f.json {
return "test2json"
}
if f.on {
return "true"
}
return "false"
}
func (f *chattyFlag) Get() any {
if f.json {
return "test2json"
}
return f.on
}
const marker = byte(0x16) // ^V for framing
func (f *chattyFlag) prefix() string {
if f.json {
return string(marker)
}
return ""
}
type chattyPrinter struct {
w io.Writer
lastNameMu sync.Mutex // guards lastName
lastName string // last printed test name in chatty mode
json bool // -v=json output mode
}
func newChattyPrinter(w io.Writer) *chattyPrinter {
return &chattyPrinter{w: w, json: chatty.json}
}
// prefix is like chatty.prefix but using p.json instead of chatty.json.
// Using p.json allows tests to check the json behavior without modifying
// the global variable. For convenience, we allow p == nil and treat
// that as not in json mode (because it's not chatty at all).
func (p *chattyPrinter) prefix() string {
if p != nil && p.json {
return string(marker)
}
return ""
}
// Updatef prints a message about the status of the named test to w.
//
// The formatted message must include the test name itself.
func (p *chattyPrinter) Updatef(testName, format string, args ...any) {
p.lastNameMu.Lock()
defer p.lastNameMu.Unlock()
// Since the message already implies an association with a specific new test,
// we don't need to check what the old test name was or log an extra NAME line
// for it. (We're updating it anyway, and the current message already includes
// the test name.)
p.lastName = testName
fmt.Fprintf(p.w, p.prefix()+format, args...)
}
// Printf prints a message, generated by the named test, that does not
// necessarily mention that test's name itself.
func (p *chattyPrinter) Printf(testName, format string, args ...any) {
p.lastNameMu.Lock()
defer p.lastNameMu.Unlock()
if p.lastName == "" {
p.lastName = testName
} else if p.lastName != testName {
fmt.Fprintf(p.w, "%s=== NAME %s\n", p.prefix(), testName)
p.lastName = testName
}
fmt.Fprintf(p.w, format, args...)
}
// The maximum number of stack frames to go through when skipping helper functions for
// the purpose of decorating log messages.
const maxStackLen = 50
// common holds the elements common between T and B and
// captures common methods such as Errorf.
type common struct {
mu sync.RWMutex // guards this group of fields
output []byte // Output generated by test or benchmark.
w io.Writer // For flushToParent.
ran bool // Test or benchmark (or one of its subtests) was executed.
failed bool // Test or benchmark has failed.
skipped bool // Test or benchmark has been skipped.
done bool // Test is finished and all subtests have completed.
helperPCs map[uintptr]struct{} // functions to be skipped when writing file/line info
helperNames map[string]struct{} // helperPCs converted to function names
cleanups []func() // optional functions to be called at the end of the test
cleanupName string // Name of the cleanup function.
cleanupPc []uintptr // The stack trace at the point where Cleanup was called.
finished bool // Test function has completed.
inFuzzFn bool // Whether the fuzz target, if this is one, is running.
chatty *chattyPrinter // A copy of chattyPrinter, if the chatty flag is set.
bench bool // Whether the current test is a benchmark.
hasSub atomic.Bool // whether there are subtests or sub-benchmarks.
cleanupStarted atomic.Bool // Registered cleanup callbacks have started to execute
raceErrors int // Number of races detected during test.
runner string // Function name of tRunner running the test.
isParallel bool // Whether the test is parallel.
parent *common
level int // Nesting depth of test or benchmark.
creator []uintptr // If level > 0, the stack trace at the point where the parent called t.Run.
name string // Name of test or benchmark.
start time.Time // Time test or benchmark started
duration time.Duration
barrier chan bool // To signal parallel subtests they may start. Nil when T.Parallel is not present (B) or not usable (when fuzzing).
signal chan bool // To signal a test is done.
sub []*T // Queue of subtests to be run in parallel.
tempDirMu sync.Mutex
tempDir string
tempDirErr error
tempDirSeq int32
}
// Short reports whether the -test.short flag is set.
func Short() bool {
if short == nil {
panic("testing: Short called before Init")
}
// Catch code that calls this from TestMain without first calling flag.Parse.
if !flag.Parsed() {
panic("testing: Short called before Parse")
}
return *short
}
// CoverMode reports what the test coverage mode is set to. The
// values are "set", "count", or "atomic". The return value will be
// empty if test coverage is not enabled.
func CoverMode() string {
if goexperiment.CoverageRedesign {
return cover2.mode
}
return cover.Mode
}
// Verbose reports whether the -test.v flag is set.
func Verbose() bool {
// Same as in Short.
if !flag.Parsed() {
panic("testing: Verbose called before Parse")
}
return chatty.on
}
func (c *common) checkFuzzFn(name string) {
if c.inFuzzFn {
panic(fmt.Sprintf("testing: f.%s was called inside the fuzz target, use t.%s instead", name, name))
}
}
// frameSkip searches, starting after skip frames, for the first caller frame
// in a function not marked as a helper and returns that frame.
// The search stops if it finds a tRunner function that
// was the entry point into the test and the test is not a subtest.
// This function must be called with c.mu held.
func (c *common) frameSkip(skip int) runtime.Frame {
// If the search continues into the parent test, we'll have to hold
// its mu temporarily. If we then return, we need to unlock it.
shouldUnlock := false
defer func() {
if shouldUnlock {
c.mu.Unlock()
}
}()
var pc [maxStackLen]uintptr
// Skip two extra frames to account for this function
// and runtime.Callers itself.
n := runtime.Callers(skip+2, pc[:])
if n == 0 {
panic("testing: zero callers found")
}
frames := runtime.CallersFrames(pc[:n])
var firstFrame, prevFrame, frame runtime.Frame
for more := true; more; prevFrame = frame {
frame, more = frames.Next()
if frame.Function == "runtime.gopanic" {
continue
}
if frame.Function == c.cleanupName {
frames = runtime.CallersFrames(c.cleanupPc)
continue
}
if firstFrame.PC == 0 {
firstFrame = frame
}
if frame.Function == c.runner {
// We've gone up all the way to the tRunner calling
// the test function (so the user must have
// called tb.Helper from inside that test function).
// If this is a top-level test, only skip up to the test function itself.
// If we're in a subtest, continue searching in the parent test,
// starting from the point of the call to Run which created this subtest.
if c.level > 1 {
frames = runtime.CallersFrames(c.creator)
parent := c.parent
// We're no longer looking at the current c after this point,
// so we should unlock its mu, unless it's the original receiver,
// in which case our caller doesn't expect us to do that.
if shouldUnlock {
c.mu.Unlock()
}
c = parent
// Remember to unlock c.mu when we no longer need it, either
// because we went up another nesting level, or because we
// returned.
shouldUnlock = true
c.mu.Lock()
continue
}
return prevFrame
}
// If more helper PCs have been added since we last did the conversion,
// rebuild the name set.
if c.helperNames == nil {
c.helperNames = make(map[string]struct{})
for pc := range c.helperPCs {
c.helperNames[pcToName(pc)] = struct{}{}
}
}
if _, ok := c.helperNames[frame.Function]; !ok {
// Found a frame that wasn't inside a helper function.
return frame
}
}
return firstFrame
}
// decorate prefixes the string with the file and line of the call site
// and inserts the final newline if needed and indentation spaces for formatting.
// This function must be called with c.mu held.
func (c *common) decorate(s string, skip int) string {
frame := c.frameSkip(skip)
file := frame.File
line := frame.Line
if file != "" {
if *fullPath {
// Keep the full path; otherwise the file name is truncated at the
// last path separator below.
} else if index := strings.LastIndex(file, "/"); index >= 0 {
file = file[index+1:]
} else if index = strings.LastIndex(file, "\\"); index >= 0 {
file = file[index+1:]
}
} else {
file = "???"
}
if line == 0 {
line = 1
}
buf := new(strings.Builder)
// Every line is indented at least 4 spaces.
buf.WriteString(" ")
fmt.Fprintf(buf, "%s:%d: ", file, line)
lines := strings.Split(s, "\n")
if l := len(lines); l > 1 && lines[l-1] == "" {
lines = lines[:l-1]
}
for i, line := range lines {
if i > 0 {
// Second and subsequent lines are indented an additional 4 spaces.
buf.WriteString("\n ")
}
buf.WriteString(line)
}
buf.WriteByte('\n')
return buf.String()
}
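// For instance (the line number is hypothetical), decorate("msg\nmore", 1)
// called from a test at foo_test.go:10 yields:
//
//     foo_test.go:10: msg
//         more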
// flushToParent writes c.output to the parent after first writing the header
// with the given format and arguments.
func (c *common) flushToParent(testName, format string, args ...any) {
p := c.parent
p.mu.Lock()
defer p.mu.Unlock()
c.mu.Lock()
defer c.mu.Unlock()
if len(c.output) > 0 {
// Add the current c.output to the print,
// and then arrange for the print to replace c.output.
// (This displays the logged output after the --- FAIL line.)
format += "%s"
args = append(args[:len(args):len(args)], c.output)
c.output = c.output[:0]
}
if c.chatty != nil && (p.w == c.chatty.w || c.chatty.json) {
// We're flushing to the actual output, so track that this output is
// associated with a specific test (and, specifically, that the next output
// is *not* associated with that test).
//
// Moreover, if c.output is non-empty it is important that this write be
// atomic with respect to the output of other tests, so that we don't end up
// with confusing '=== NAME' lines in the middle of our '--- PASS' block.
// Neither humans nor cmd/test2json can parse those easily.
// (See https://go.dev/issue/40771.)
//
// If test2json is used, we never flush to parent tests,
// so that the json stream shows subtests as they finish.
// (See https://go.dev/issue/29811.)
c.chatty.Updatef(testName, format, args...)
} else {
// We're flushing to the output buffer of the parent test, which will
// itself follow a test-name header when it is finally flushed to stdout.
fmt.Fprintf(p.w, c.chatty.prefix()+format, args...)
}
}
type indenter struct {
c *common
}
func (w indenter) Write(b []byte) (n int, err error) {
n = len(b)
for len(b) > 0 {
end := bytes.IndexByte(b, '\n')
if end == -1 {
end = len(b)
} else {
end++
}
// An indent of 4 spaces will neatly align the dashes with the status
// indicator of the parent.
line := b[:end]
if line[0] == marker {
w.c.output = append(w.c.output, marker)
line = line[1:]
}
const indent = " "
w.c.output = append(w.c.output, indent...)
w.c.output = append(w.c.output, line...)
b = b[end:]
}
return
}
// fmtDuration returns a string representing d in the form "87.00s".
func fmtDuration(d time.Duration) string {
return fmt.Sprintf("%.2fs", d.Seconds())
}
// TB is the interface common to T, B, and F.
type TB interface {
Cleanup(func())
Error(args ...any)
Errorf(format string, args ...any)
Fail()
FailNow()
Failed() bool
Fatal(args ...any)
Fatalf(format string, args ...any)
Helper()
Log(args ...any)
Logf(format string, args ...any)
Name() string
Setenv(key, value string)
Skip(args ...any)
SkipNow()
Skipf(format string, args ...any)
Skipped() bool
TempDir() string
// A private method to prevent users from implementing the
// interface, so that future additions to it will not
// violate Go 1 compatibility.
private()
}
var _ TB = (*T)(nil)
var _ TB = (*B)(nil)
// T is a type passed to Test functions to manage test state and support formatted test logs.
//
// A test ends when its Test function returns or calls any of the methods
// FailNow, Fatal, Fatalf, SkipNow, Skip, or Skipf. Those methods, as well as
// the Parallel method, must be called only from the goroutine running the
// Test function.
//
// The other reporting methods, such as the variations of Log and Error,
// may be called simultaneously from multiple goroutines.
type T struct {
common
isEnvSet bool
context *testContext // For running tests and subtests.
}
func (c *common) private() {}
// Name returns the name of the running (sub-) test or benchmark.
//
// The name will include the name of the test along with the names of
// any nested sub-tests. If two sibling sub-tests have the same name,
// Name will append a suffix to guarantee the returned name is unique.
func (c *common) Name() string {
return c.name
}
func (c *common) setRan() {
if c.parent != nil {
c.parent.setRan()
}
c.mu.Lock()
defer c.mu.Unlock()
c.ran = true
}
// Fail marks the function as having failed but continues execution.
func (c *common) Fail() {
if c.parent != nil {
c.parent.Fail()
}
c.mu.Lock()
defer c.mu.Unlock()
// c.done needs to be locked to synchronize checks to c.done in parent tests.
if c.done {
panic("Fail in goroutine after " + c.name + " has completed")
}
c.failed = true
}
// Failed reports whether the function has failed.
func (c *common) Failed() bool {
c.mu.RLock()
failed := c.failed
c.mu.RUnlock()
return failed || c.raceErrors+race.Errors() > 0
}
// FailNow marks the function as having failed and stops its execution
// by calling runtime.Goexit (which then runs all deferred calls in the
// current goroutine).
// Execution will continue at the next test or benchmark.
// FailNow must be called from the goroutine running the
// test or benchmark function, not from other goroutines
// created during the test. Calling FailNow does not stop
// those other goroutines.
func (c *common) FailNow() {
c.checkFuzzFn("FailNow")
c.Fail()
// Calling runtime.Goexit will exit the goroutine, which
// will run the deferred functions in this goroutine,
// which will eventually run the deferred lines in tRunner,
// which will signal to the test loop that this test is done.
//
// A previous version of this code said:
//
// c.duration = ...
// c.signal <- c.self
// runtime.Goexit()
//
// This previous version duplicated code (those lines are in
// tRunner no matter what), but worse the goroutine teardown
// implicit in runtime.Goexit was not guaranteed to complete
// before the test exited. If a test deferred an important cleanup
// function (like removing temporary files), there was no guarantee
// it would run on a test failure. Because we send on c.signal during
// a top-of-stack deferred function now, we know that the send
// only happens after any other stacked defers have completed.
c.mu.Lock()
c.finished = true
c.mu.Unlock()
runtime.Goexit()
}
// log generates the output. It's always at the same stack depth.
func (c *common) log(s string) {
c.logDepth(s, 3) // logDepth + log + public function
}
// logDepth generates the output at an arbitrary stack depth.
func (c *common) logDepth(s string, depth int) {
c.mu.Lock()
defer c.mu.Unlock()
if c.done {
// This test has already finished. Try to log this message
// with our parent. If we don't have a parent, panic.
for parent := c.parent; parent != nil; parent = parent.parent {
parent.mu.Lock()
defer parent.mu.Unlock()
if !parent.done {
parent.output = append(parent.output, parent.decorate(s, depth+1)...)
return
}
}
panic("Log in goroutine after " + c.name + " has completed: " + s)
} else {
if c.chatty != nil {
if c.bench {
// Benchmarks don't print === CONT, so we should skip the test
// printer and just print straight to stdout.
fmt.Print(c.decorate(s, depth+1))
} else {
c.chatty.Printf(c.name, "%s", c.decorate(s, depth+1))
}
return
}
c.output = append(c.output, c.decorate(s, depth+1)...)
}
}
// Log formats its arguments using default formatting, analogous to Println,
// and records the text in the error log. For tests, the text will be printed only if
// the test fails or the -test.v flag is set. For benchmarks, the text is always
// printed to avoid having performance depend on the value of the -test.v flag.
func (c *common) Log(args ...any) {
c.checkFuzzFn("Log")
c.log(fmt.Sprintln(args...))
}
// Logf formats its arguments according to the format, analogous to Printf, and
// records the text in the error log. A final newline is added if not provided. For
// tests, the text will be printed only if the test fails or the -test.v flag is
// set. For benchmarks, the text is always printed to avoid having performance
// depend on the value of the -test.v flag.
func (c *common) Logf(format string, args ...any) {
c.checkFuzzFn("Logf")
c.log(fmt.Sprintf(format, args...))
}
// Error is equivalent to Log followed by Fail.
func (c *common) Error(args ...any) {
c.checkFuzzFn("Error")
c.log(fmt.Sprintln(args...))
c.Fail()
}
// Errorf is equivalent to Logf followed by Fail.
func (c *common) Errorf(format string, args ...any) {
c.checkFuzzFn("Errorf")
c.log(fmt.Sprintf(format, args...))
c.Fail()
}
// Fatal is equivalent to Log followed by FailNow.
func (c *common) Fatal(args ...any) {
c.checkFuzzFn("Fatal")
c.log(fmt.Sprintln(args...))
c.FailNow()
}
// Fatalf is equivalent to Logf followed by FailNow.
func (c *common) Fatalf(format string, args ...any) {
c.checkFuzzFn("Fatalf")
c.log(fmt.Sprintf(format, args...))
c.FailNow()
}
// Skip is equivalent to Log followed by SkipNow.
func (c *common) Skip(args ...any) {
c.checkFuzzFn("Skip")
c.log(fmt.Sprintln(args...))
c.SkipNow()
}
// Skipf is equivalent to Logf followed by SkipNow.
func (c *common) Skipf(format string, args ...any) {
c.checkFuzzFn("Skipf")
c.log(fmt.Sprintf(format, args...))
c.SkipNow()
}
// SkipNow marks the test as having been skipped and stops its execution
// by calling runtime.Goexit.
// If a test fails (see Error, Errorf, Fail) and is then skipped,
// it is still considered to have failed.
// Execution will continue at the next test or benchmark. See also FailNow.
// SkipNow must be called from the goroutine running the test, not from
// other goroutines created during the test. Calling SkipNow does not stop
// those other goroutines.
func (c *common) SkipNow() {
c.checkFuzzFn("SkipNow")
c.mu.Lock()
c.skipped = true
c.finished = true
c.mu.Unlock()
runtime.Goexit()
}
// Skipped reports whether the test was skipped.
func (c *common) Skipped() bool {
c.mu.RLock()
defer c.mu.RUnlock()
return c.skipped
}
// Helper marks the calling function as a test helper function.
// When printing file and line information, that function will be skipped.
// Helper may be called simultaneously from multiple goroutines.
func (c *common) Helper() {
c.mu.Lock()
defer c.mu.Unlock()
if c.helperPCs == nil {
c.helperPCs = make(map[uintptr]struct{})
}
// repeating code from callerName here to save walking a stack frame
var pc [1]uintptr
n := runtime.Callers(2, pc[:]) // skip runtime.Callers + Helper
if n == 0 {
panic("testing: zero callers found")
}
if _, found := c.helperPCs[pc[0]]; !found {
c.helperPCs[pc[0]] = struct{}{}
c.helperNames = nil // map will be recreated next time it is needed
}
}
// Cleanup registers a function to be called when the test (or subtest) and all its
// subtests complete. Cleanup functions will be called in last added,
// first called order.
func (c *common) Cleanup(f func()) {
c.checkFuzzFn("Cleanup")
var pc [maxStackLen]uintptr
// Skip two extra frames to account for this function and runtime.Callers itself.
n := runtime.Callers(2, pc[:])
cleanupPc := pc[:n]
fn := func() {
defer func() {
c.mu.Lock()
defer c.mu.Unlock()
c.cleanupName = ""
c.cleanupPc = nil
}()
name := callerName(0)
c.mu.Lock()
c.cleanupName = name
c.cleanupPc = cleanupPc
c.mu.Unlock()
f()
}
c.mu.Lock()
defer c.mu.Unlock()
c.cleanups = append(c.cleanups, fn)
}
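// Usage sketch: cleanups run in last-added, first-called order, so this
// logs "close conn" before "stop server" (the messages are hypothetical):
//
// t.Cleanup(func() { log.Println("stop server") })
// t.Cleanup(func() { log.Println("close conn") })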
// TempDir returns a temporary directory for the test to use.
// The directory is automatically removed by Cleanup when the test and
// all its subtests complete.
// Each subsequent call to t.TempDir returns a unique directory;
// if the directory creation fails, TempDir terminates the test by calling Fatal.
func (c *common) TempDir() string {
c.checkFuzzFn("TempDir")
// Use a single parent directory for all the temporary directories
// created by a test, each numbered sequentially.
c.tempDirMu.Lock()
var nonExistent bool
if c.tempDir == "" { // Usually the case with js/wasm
nonExistent = true
} else {
_, err := os.Stat(c.tempDir)
nonExistent = os.IsNotExist(err)
if err != nil && !nonExistent {
c.Fatalf("TempDir: %v", err)
}
}
if nonExistent {
c.Helper()
// Drop unusual characters (such as path separators or
// characters interacting with globs) from the directory name to
// avoid surprising os.MkdirTemp behavior.
mapper := func(r rune) rune {
if r < utf8.RuneSelf {
const allowed = "!#$%&()+,-.=@^_{}~ "
if '0' <= r && r <= '9' ||
'a' <= r && r <= 'z' ||
'A' <= r && r <= 'Z' {
return r
}
if strings.ContainsRune(allowed, r) {
return r
}
} else if unicode.IsLetter(r) || unicode.IsNumber(r) {
return r
}
return -1
}
pattern := strings.Map(mapper, c.Name())
c.tempDir, c.tempDirErr = os.MkdirTemp("", pattern)
if c.tempDirErr == nil {
c.Cleanup(func() {
if err := removeAll(c.tempDir); err != nil {
c.Errorf("TempDir RemoveAll cleanup: %v", err)
}
})
}
}
if c.tempDirErr == nil {
c.tempDirSeq++
}
seq := c.tempDirSeq
c.tempDirMu.Unlock()
if c.tempDirErr != nil {
c.Fatalf("TempDir: %v", c.tempDirErr)
}
dir := fmt.Sprintf("%s%c%03d", c.tempDir, os.PathSeparator, seq)
if err := os.Mkdir(dir, 0777); err != nil {
c.Fatalf("TempDir: %v", err)
}
return dir
}
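// Usage sketch (the reported path is illustrative):
//
// dir := t.TempDir() // e.g. /tmp/TestFoo1265023905/001
// cfg := filepath.Join(dir, "config.json")
// if err := os.WriteFile(cfg, []byte("{}"), 0600); err != nil {
// t.Fatal(err)
// }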
// removeAll is like os.RemoveAll, but retries Windows "Access is denied."
// errors up to an arbitrary timeout.
//
// Those errors have been known to occur spuriously on at least the
// windows-amd64-2012 builder (https://go.dev/issue/50051), and can only occur
// legitimately if the test leaves behind a temp file that either is still open
// or the test otherwise lacks permission to delete. In the case of legitimate
// failures, a failing test may take a bit longer to fail, but once the test is
// fixed the extra latency will go away.
func removeAll(path string) error {
const arbitraryTimeout = 2 * time.Second
var (
start time.Time
nextSleep = 1 * time.Millisecond
)
for {
err := os.RemoveAll(path)
if !isWindowsRetryable(err) {
return err
}
if start.IsZero() {
start = time.Now()
} else if d := time.Since(start) + nextSleep; d >= arbitraryTimeout {
return err
}
time.Sleep(nextSleep)
nextSleep += time.Duration(rand.Int63n(int64(nextSleep)))
}
}
// Setenv calls os.Setenv(key, value) and uses Cleanup to
// restore the environment variable to its original value
// after the test.
//
// Because Setenv affects the whole process, it cannot be used
// in parallel tests or tests with parallel ancestors.
func (c *common) Setenv(key, value string) {
c.checkFuzzFn("Setenv")
prevValue, ok := os.LookupEnv(key)
if err := os.Setenv(key, value); err != nil {
c.Fatalf("cannot set environment variable: %v", err)
}
if ok {
c.Cleanup(func() {
os.Setenv(key, prevValue)
})
} else {
c.Cleanup(func() {
os.Unsetenv(key)
})
}
}
// panicHandling is an argument to runCleanup.
type panicHandling int
const (
normalPanic panicHandling = iota
recoverAndReturnPanic
)
// runCleanup is called at the end of the test.
// If ph is recoverAndReturnPanic, this will catch panics and return the
// recovered value, if any.
func (c *common) runCleanup(ph panicHandling) (panicVal any) {
c.cleanupStarted.Store(true)
defer c.cleanupStarted.Store(false)
if ph == recoverAndReturnPanic {
defer func() {
panicVal = recover()
}()
}
// Make sure that if a cleanup function panics,
// we still run the remaining cleanup functions.
defer func() {
c.mu.Lock()
recur := len(c.cleanups) > 0
c.mu.Unlock()
if recur {
c.runCleanup(normalPanic)
}
}()
for {
var cleanup func()
c.mu.Lock()
if len(c.cleanups) > 0 {
last := len(c.cleanups) - 1
cleanup = c.cleanups[last]
c.cleanups = c.cleanups[:last]
}
c.mu.Unlock()
if cleanup == nil {
return nil
}
cleanup()
}
}
// callerName gives the function name (qualified with a package path)
// for the caller after skip frames (where 0 means the current function).
func callerName(skip int) string {
var pc [1]uintptr
n := runtime.Callers(skip+2, pc[:]) // skip + runtime.Callers + callerName
if n == 0 {
panic("testing: zero callers found")
}
return pcToName(pc[0])
}
func pcToName(pc uintptr) string {
pcs := []uintptr{pc}
frames := runtime.CallersFrames(pcs)
frame, _ := frames.Next()
return frame.Function
}
// Parallel signals that this test is to be run in parallel with (and only with)
// other parallel tests. When a test is run multiple times due to use of
// -test.count or -test.cpu, multiple instances of a single test never run in
// parallel with each other.
func (t *T) Parallel() {
if t.isParallel {
panic("testing: t.Parallel called multiple times")
}
if t.isEnvSet {
panic("testing: t.Parallel called after t.Setenv; cannot set environment variables in parallel tests")
}
t.isParallel = true
if t.parent.barrier == nil {
// T.Parallel has no effect when fuzzing.
// Multiple processes may run in parallel, but only one input can run at a
// time per process so we can attribute crashes to specific inputs.
return
}
// We don't want to include the time we spend waiting for serial tests
// in the test duration. Record the elapsed time thus far and reset the
// timer afterwards.
t.duration += time.Since(t.start)
// Add to the list of tests to be released by the parent.
t.parent.sub = append(t.parent.sub, t)
t.raceErrors += race.Errors()
if t.chatty != nil {
t.chatty.Updatef(t.name, "=== PAUSE %s\n", t.name)
}
running.Delete(t.name)
t.signal <- true // Release calling test.
<-t.parent.barrier // Wait for the parent test to complete.
t.context.waitParallel()
if t.chatty != nil {
t.chatty.Updatef(t.name, "=== CONT %s\n", t.name)
}
running.Store(t.name, time.Now())
t.start = time.Now()
t.raceErrors += -race.Errors()
}
// Setenv calls os.Setenv(key, value) and uses Cleanup to
// restore the environment variable to its original value
// after the test.
//
// Because Setenv affects the whole process, it cannot be used
// in parallel tests or tests with parallel ancestors.
func (t *T) Setenv(key, value string) {
// Non-parallel subtests that have parallel ancestors may still
// run in parallel with other tests: they are only non-parallel
// with respect to the other subtests of the same parent.
// Since Setenv affects the whole process, we need to disallow it
// if the current test or any parent is parallel.
isParallel := false
for c := &t.common; c != nil; c = c.parent {
if c.isParallel {
isParallel = true
break
}
}
if isParallel {
panic("testing: t.Setenv called after t.Parallel; cannot set environment variables in parallel tests")
}
t.isEnvSet = true
t.common.Setenv(key, value)
}
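// Usage sketch (the variable name is illustrative):
//
// t.Setenv("APP_CONFIG_DIR", t.TempDir()) // restored by Cleanup when the test ends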
// InternalTest is an internal type but exported because it is cross-package;
// it is part of the implementation of the "go test" command.
type InternalTest struct {
Name string
F func(*T)
}
var errNilPanicOrGoexit = errors.New("test executed panic(nil) or runtime.Goexit")
func tRunner(t *T, fn func(t *T)) {
t.runner = callerName(0)
// When this goroutine is done, either because fn(t)
// returned normally or because a test failure triggered
// a call to runtime.Goexit, record the duration and send
// a signal saying that the test is done.
defer func() {
if t.Failed() {
numFailed.Add(1)
}
if t.raceErrors+race.Errors() > 0 {
t.Errorf("race detected during execution of test")
}
// Check if the test panicked or Goexited inappropriately.
//
// If this happens in a normal test, print output but continue panicking.
// tRunner is called in its own goroutine, so this terminates the process.
//
// If this happens while fuzzing, recover from the panic and treat it like a
// normal failure. It's important that the process keeps running in order to
// find short inputs that cause panics.
err := recover()
signal := true
t.mu.RLock()
finished := t.finished
t.mu.RUnlock()
if !finished && err == nil {
err = errNilPanicOrGoexit
for p := t.parent; p != nil; p = p.parent {
p.mu.RLock()
finished = p.finished
p.mu.RUnlock()
if finished {
if !t.isParallel {
t.Errorf("%v: subtest may have called FailNow on a parent test", err)
err = nil
}
signal = false
break
}
}
}
if err != nil && t.context.isFuzzing {
prefix := "panic: "
if err == errNilPanicOrGoexit {
prefix = ""
}
t.Errorf("%s%s\n%s\n", prefix, err, string(debug.Stack()))
t.mu.Lock()
t.finished = true
t.mu.Unlock()
err = nil
}
// Use a deferred call to ensure that we report that the test is
// complete even if a cleanup function calls t.FailNow. See issue 41355.
didPanic := false
defer func() {
// Only report that the test is complete if it doesn't panic,
// as otherwise the test binary can exit before the panic is
// reported to the user. See issue 41479.
if didPanic {
return
}
if err != nil {
panic(err)
}
running.Delete(t.name)
t.signal <- signal
}()
doPanic := func(err any) {
t.Fail()
if r := t.runCleanup(recoverAndReturnPanic); r != nil {
t.Logf("cleanup panicked with %v", r)
}
// Flush the output log up to the root before dying.
for root := &t.common; root.parent != nil; root = root.parent {
root.mu.Lock()
root.duration += time.Since(root.start)
d := root.duration
root.mu.Unlock()
root.flushToParent(root.name, "--- FAIL: %s (%s)\n", root.name, fmtDuration(d))
if r := root.parent.runCleanup(recoverAndReturnPanic); r != nil {
fmt.Fprintf(root.parent.w, "cleanup panicked with %v", r)
}
}
didPanic = true
panic(err)
}
if err != nil {
doPanic(err)
}
t.duration += time.Since(t.start)
if len(t.sub) > 0 {
// Run parallel subtests.
// Decrease the running count for this test.
t.context.release()
// Release the parallel subtests.
close(t.barrier)
// Wait for subtests to complete.
for _, sub := range t.sub {
<-sub.signal
}
cleanupStart := time.Now()
err := t.runCleanup(recoverAndReturnPanic)
t.duration += time.Since(cleanupStart)
if err != nil {
doPanic(err)
}
if !t.isParallel {
// Reacquire the count for sequential tests. See comment in Run.
t.context.waitParallel()
}
} else if t.isParallel {
// Only release the count for this test if it was run as a parallel
// test. See comment in Run method.
t.context.release()
}
t.report() // Report after all subtests have finished.
// Do not lock t.done, to allow the race detector to detect a race in case
// the user does not appropriately synchronize a goroutine.
t.done = true
if t.parent != nil && !t.hasSub.Load() {
t.setRan()
}
}()
defer func() {
if len(t.sub) == 0 {
t.runCleanup(normalPanic)
}
}()
t.start = time.Now()
t.raceErrors = -race.Errors()
fn(t)
// code beyond here will not be executed when FailNow is invoked
t.mu.Lock()
t.finished = true
t.mu.Unlock()
}
// Run runs f as a subtest of t called name. It runs f in a separate goroutine
// and blocks until f returns or calls t.Parallel to become a parallel test.
// Run reports whether f succeeded (or at least did not fail before calling t.Parallel).
//
// Run may be called simultaneously from multiple goroutines, but all such calls
// must return before the outer test function for t returns.
func (t *T) Run(name string, f func(t *T)) bool {
if t.cleanupStarted.Load() {
panic("testing: t.Run called during t.Cleanup")
}
t.hasSub.Store(true)
testName, ok, _ := t.context.match.fullName(&t.common, name)
if !ok || shouldFailFast() {
return true
}
// Record the stack trace at the point of this call so that if the subtest
// function - which runs in a separate stack - is marked as a helper, we can
// continue walking the stack into the parent test.
var pc [maxStackLen]uintptr
n := runtime.Callers(2, pc[:])
t = &T{
common: common{
barrier: make(chan bool),
signal: make(chan bool, 1),
name: testName,
parent: &t.common,
level: t.level + 1,
creator: pc[:n],
chatty: t.chatty,
},
context: t.context,
}
t.w = indenter{&t.common}
if t.chatty != nil {
t.chatty.Updatef(t.name, "=== RUN %s\n", t.name)
}
running.Store(t.name, time.Now())
// Instead of reducing the running count of this test before calling the
// tRunner and increasing it afterwards, we rely on tRunner keeping the
// count correct. This ensures that a sequence of sequential tests runs
// without being preempted, even when their parent is a parallel test. This
// may especially reduce surprises if *parallel == 1.
go tRunner(t, f)
if !<-t.signal {
// At this point, it is likely that FailNow was called on one of the
// parent tests by one of the subtests. Continue aborting up the chain.
runtime.Goexit()
}
if t.chatty != nil && t.chatty.json {
t.chatty.Updatef(t.parent.name, "=== NAME %s\n", t.parent.name)
}
return !t.failed
}
// Deadline reports the time at which the test binary will have
// exceeded the timeout specified by the -timeout flag.
//
// The ok result is false if the -timeout flag indicates “no timeout” (0).
func (t *T) Deadline() (deadline time.Time, ok bool) {
deadline = t.context.deadline
return deadline, !deadline.IsZero()
}
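// Usage sketch: derive a context that ends slightly before the binary's
// deadline (the one-second margin and helper name are illustrative):
//
// if d, ok := t.Deadline(); ok {
// ctx, cancel := context.WithDeadline(context.Background(), d.Add(-time.Second))
// defer cancel()
// doSlowThing(ctx) // hypothetical helper
// }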
// testContext holds all fields that are common to all tests. This includes
// synchronization primitives to run at most *parallel tests.
type testContext struct {
match *matcher
deadline time.Time
// isFuzzing is true in the context used when generating random inputs
// for fuzz targets. isFuzzing is false when running normal tests and
// when running fuzz tests as unit tests (without -fuzz or when -fuzz
// does not match).
isFuzzing bool
mu sync.Mutex
// Channel used to signal tests that are ready to be run in parallel.
startParallel chan bool
// running is the number of tests currently running in parallel.
// This does not include tests that are waiting for subtests to complete.
running int
// numWaiting is the number of tests waiting to be run in parallel.
numWaiting int
// maxParallel is a copy of the parallel flag.
maxParallel int
}
func newTestContext(maxParallel int, m *matcher) *testContext {
return &testContext{
match: m,
startParallel: make(chan bool),
maxParallel: maxParallel,
running: 1, // Set the count to 1 for the main (sequential) test.
}
}
func (c *testContext) waitParallel() {
c.mu.Lock()
if c.running < c.maxParallel {
c.running++
c.mu.Unlock()
return
}
c.numWaiting++
c.mu.Unlock()
<-c.startParallel
}
func (c *testContext) release() {
c.mu.Lock()
if c.numWaiting == 0 {
c.running--
c.mu.Unlock()
return
}
c.numWaiting--
c.mu.Unlock()
c.startParallel <- true // Pick a waiting test to be run.
}
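// Sketch of the handoff with -test.parallel=2 (counts illustrative; the
// main sequential test holds one slot from the start, so running begins at 1):
//
// waitParallel() // running 1 -> 2
// waitParallel() // at the limit: blocks, numWaiting 0 -> 1
// release() // hands the slot to the waiter; running stays 2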
// No one should be using func Main anymore.
// See the doc comment on func Main and use MainStart instead.
var errMain = errors.New("testing: unexpected use of func Main")
type matchStringOnly func(pat, str string) (bool, error)
func (f matchStringOnly) MatchString(pat, str string) (bool, error) { return f(pat, str) }
func (f matchStringOnly) StartCPUProfile(w io.Writer) error { return errMain }
func (f matchStringOnly) StopCPUProfile() {}
func (f matchStringOnly) WriteProfileTo(string, io.Writer, int) error { return errMain }
func (f matchStringOnly) ImportPath() string { return "" }
func (f matchStringOnly) StartTestLog(io.Writer) {}
func (f matchStringOnly) StopTestLog() error { return errMain }
func (f matchStringOnly) SetPanicOnExit0(bool) {}
func (f matchStringOnly) CoordinateFuzzing(time.Duration, int64, time.Duration, int64, int, []corpusEntry, []reflect.Type, string, string) error {
return errMain
}
func (f matchStringOnly) RunFuzzWorker(func(corpusEntry) error) error { return errMain }
func (f matchStringOnly) ReadCorpus(string, []reflect.Type) ([]corpusEntry, error) {
return nil, errMain
}
func (f matchStringOnly) CheckCorpus([]any, []reflect.Type) error { return nil }
func (f matchStringOnly) ResetCoverage() {}
func (f matchStringOnly) SnapshotCoverage() {}
// Main is an internal function, part of the implementation of the "go test" command.
// It was exported because it is cross-package and predates "internal" packages.
// It is no longer used by "go test" but preserved, as much as possible, for other
// systems that simulate "go test" using Main, but Main sometimes cannot be updated as
// new functionality is added to the testing package.
// Systems simulating "go test" should be updated to use MainStart.
func Main(matchString func(pat, str string) (bool, error), tests []InternalTest, benchmarks []InternalBenchmark, examples []InternalExample) {
os.Exit(MainStart(matchStringOnly(matchString), tests, benchmarks, nil, examples).Run())
}
// M is a type passed to a TestMain function to run the actual tests.
type M struct {
deps testDeps
tests []InternalTest
benchmarks []InternalBenchmark
fuzzTargets []InternalFuzzTarget
examples []InternalExample
timer *time.Timer
afterOnce sync.Once
numRun int
// exitCode is the value to pass to os.Exit; the outer test func main
// harness calls os.Exit with this code. See #34129.
exitCode int
}
// testDeps is an internal interface of functionality that is
// passed into this package by a test's generated main package.
// The canonical implementation of this interface is
// testing/internal/testdeps's TestDeps.
type testDeps interface {
ImportPath() string
MatchString(pat, str string) (bool, error)
SetPanicOnExit0(bool)
StartCPUProfile(io.Writer) error
StopCPUProfile()
StartTestLog(io.Writer)
StopTestLog() error
WriteProfileTo(string, io.Writer, int) error
CoordinateFuzzing(time.Duration, int64, time.Duration, int64, int, []corpusEntry, []reflect.Type, string, string) error
RunFuzzWorker(func(corpusEntry) error) error
ReadCorpus(string, []reflect.Type) ([]corpusEntry, error)
CheckCorpus([]any, []reflect.Type) error
ResetCoverage()
SnapshotCoverage()
}
// MainStart is meant for use by tests generated by 'go test'.
// It is not meant to be called directly and is not subject to the Go 1 compatibility document.
// It may change signature from release to release.
func MainStart(deps testDeps, tests []InternalTest, benchmarks []InternalBenchmark, fuzzTargets []InternalFuzzTarget, examples []InternalExample) *M {
Init()
return &M{
deps: deps,
tests: tests,
benchmarks: benchmarks,
fuzzTargets: fuzzTargets,
examples: examples,
}
}
var testingTesting bool
var realStderr *os.File
// Run runs the tests. It returns an exit code to pass to os.Exit.
func (m *M) Run() (code int) {
defer func() {
code = m.exitCode
}()
// Count the number of calls to m.Run.
// We only ever expected 1, but we didn't enforce that,
// and now there are tests in the wild that call m.Run multiple times.
// Sigh. go.dev/issue/23129.
m.numRun++
// TestMain may have already called flag.Parse.
if !flag.Parsed() {
flag.Parse()
}
if chatty.json {
// With -v=json, stdout and stderr are pointing to the same pipe,
// which is leading into test2json. In general, operating systems
// do a good job of ensuring that writes to the same pipe through
// different file descriptors are delivered whole, so that writing
// AAA to stdout and BBB to stderr simultaneously produces
// AAABBB or BBBAAA on the pipe, not something like AABBBA.
// However, the exception to this is when the pipe fills: in that
// case, Go's use of non-blocking I/O means that writing AAA
// or BBB might be split across multiple system calls, making it
// entirely possible to get output like AABBBA. The same problem
// happens inside the operating system kernel if we switch to
// blocking I/O on the pipe. This interleaved output can do things
// like print unrelated messages in the middle of a TestFoo line,
// which confuses test2json. Setting os.Stderr = os.Stdout will make
// them share a single pfd, which will hold a lock for each program
// write, preventing any interleaving.
//
// It might be nice to set Stderr = Stdout always, or perhaps if
// we can tell they are the same file, but for now -v=json is
// a very clear signal. Making the two files the same may cause
// surprises if programs close os.Stdout but expect to be able
// to continue to write to os.Stderr, but it's hard to see why a
// test would think it could take over global state that way.
//
// This fix only helps programs where the output is coming directly
// from Go code. It does not help programs in which a subprocess is
// writing to stderr or stdout at the same time that a Go test is writing output.
// It also does not help when the output is coming from the runtime,
// such as when using the print/println functions, since that code writes
// directly to fd 2 without any locking.
// We keep realStderr around to prevent fd 2 from being closed.
//
// See go.dev/issue/33419.
realStderr = os.Stderr
os.Stderr = os.Stdout
}
if *parallel < 1 {
fmt.Fprintln(os.Stderr, "testing: -parallel can only be given a positive integer")
flag.Usage()
m.exitCode = 2
return
}
if *matchFuzz != "" && *fuzzCacheDir == "" {
fmt.Fprintln(os.Stderr, "testing: -test.fuzzcachedir must be set if -test.fuzz is set")
flag.Usage()
m.exitCode = 2
return
}
if *matchList != "" {
listTests(m.deps.MatchString, m.tests, m.benchmarks, m.fuzzTargets, m.examples)
m.exitCode = 0
return
}
if *shuffle != "off" {
var n int64
var err error
if *shuffle == "on" {
n = time.Now().UnixNano()
} else {
n, err = strconv.ParseInt(*shuffle, 10, 64)
if err != nil {
fmt.Fprintln(os.Stderr, `testing: -shuffle should be "off", "on", or a valid integer:`, err)
m.exitCode = 2
return
}
}
fmt.Println("-test.shuffle", n)
rng := rand.New(rand.NewSource(n))
rng.Shuffle(len(m.tests), func(i, j int) { m.tests[i], m.tests[j] = m.tests[j], m.tests[i] })
rng.Shuffle(len(m.benchmarks), func(i, j int) { m.benchmarks[i], m.benchmarks[j] = m.benchmarks[j], m.benchmarks[i] })
}
parseCpuList()
m.before()
defer m.after()
// Run tests, examples, and benchmarks unless this is a fuzz worker process.
// Workers start after this is done by their parent process, and they should
// not repeat this work.
if !*isFuzzWorker {
deadline := m.startAlarm()
haveExamples = len(m.examples) > 0
testRan, testOk := runTests(m.deps.MatchString, m.tests, deadline)
fuzzTargetsRan, fuzzTargetsOk := runFuzzTests(m.deps, m.fuzzTargets, deadline)
exampleRan, exampleOk := runExamples(m.deps.MatchString, m.examples)
m.stopAlarm()
if !testRan && !exampleRan && !fuzzTargetsRan && *matchBenchmarks == "" && *matchFuzz == "" {
fmt.Fprintln(os.Stderr, "testing: warning: no tests to run")
if testingTesting && *match != "^$" {
// If this happens during testing of package testing it could be that
// package testing's own logic for when to run a test is broken,
// in which case every test will run nothing and succeed,
// with no obvious way to detect this problem (since no tests are running).
// So make 'no tests to run' a hard failure when testing package testing itself.
// The compile-only builders use -run=^$ to run no tests, so allow that.
fmt.Print(chatty.prefix(), "FAIL: package testing must run tests\n")
testOk = false
}
}
if !testOk || !exampleOk || !fuzzTargetsOk || !runBenchmarks(m.deps.ImportPath(), m.deps.MatchString, m.benchmarks) || race.Errors() > 0 {
fmt.Print(chatty.prefix(), "FAIL\n")
m.exitCode = 1
return
}
}
fuzzingOk := runFuzzing(m.deps, m.fuzzTargets)
if !fuzzingOk {
fmt.Print(chatty.prefix(), "FAIL\n")
if *isFuzzWorker {
m.exitCode = fuzzWorkerExitCode
} else {
m.exitCode = 1
}
return
}
m.exitCode = 0
if !*isFuzzWorker {
fmt.Print(chatty.prefix(), "PASS\n")
}
return
}
func (t *T) report() {
if t.parent == nil {
return
}
dstr := fmtDuration(t.duration)
format := "--- %s: %s (%s)\n"
if t.Failed() {
t.flushToParent(t.name, format, "FAIL", t.name, dstr)
} else if t.chatty != nil {
if t.Skipped() {
t.flushToParent(t.name, format, "SKIP", t.name, dstr)
} else {
t.flushToParent(t.name, format, "PASS", t.name, dstr)
}
}
}
func listTests(matchString func(pat, str string) (bool, error), tests []InternalTest, benchmarks []InternalBenchmark, fuzzTargets []InternalFuzzTarget, examples []InternalExample) {
if _, err := matchString(*matchList, "non-empty"); err != nil {
fmt.Fprintf(os.Stderr, "testing: invalid regexp in -test.list (%q): %s\n", *matchList, err)
os.Exit(1)
}
for _, test := range tests {
if ok, _ := matchString(*matchList, test.Name); ok {
fmt.Println(test.Name)
}
}
for _, bench := range benchmarks {
if ok, _ := matchString(*matchList, bench.Name); ok {
fmt.Println(bench.Name)
}
}
for _, fuzzTarget := range fuzzTargets {
if ok, _ := matchString(*matchList, fuzzTarget.Name); ok {
fmt.Println(fuzzTarget.Name)
}
}
for _, example := range examples {
if ok, _ := matchString(*matchList, example.Name); ok {
fmt.Println(example.Name)
}
}
}
// RunTests is an internal function but exported because it is cross-package;
// it is part of the implementation of the "go test" command.
func RunTests(matchString func(pat, str string) (bool, error), tests []InternalTest) (ok bool) {
var deadline time.Time
if *timeout > 0 {
deadline = time.Now().Add(*timeout)
}
ran, ok := runTests(matchString, tests, deadline)
if !ran && !haveExamples {
fmt.Fprintln(os.Stderr, "testing: warning: no tests to run")
}
return ok
}
func runTests(matchString func(pat, str string) (bool, error), tests []InternalTest, deadline time.Time) (ran, ok bool) {
ok = true
for _, procs := range cpuList {
runtime.GOMAXPROCS(procs)
for i := uint(0); i < *count; i++ {
if shouldFailFast() {
break
}
if i > 0 && !ran {
// There were no tests to run on the first
// iteration. This won't change, so no reason
// to keep trying.
break
}
ctx := newTestContext(*parallel, newMatcher(matchString, *match, "-test.run", *skip))
ctx.deadline = deadline
t := &T{
common: common{
signal: make(chan bool, 1),
barrier: make(chan bool),
w: os.Stdout,
},
context: ctx,
}
if Verbose() {
t.chatty = newChattyPrinter(t.w)
}
tRunner(t, func(t *T) {
for _, test := range tests {
t.Run(test.Name, test.F)
}
})
select {
case <-t.signal:
default:
panic("internal error: tRunner exited without sending on t.signal")
}
ok = ok && !t.Failed()
ran = ran || t.ran
}
}
return ran, ok
}
// before runs before all testing.
func (m *M) before() {
if *memProfileRate > 0 {
runtime.MemProfileRate = *memProfileRate
}
if *cpuProfile != "" {
f, err := os.Create(toOutputDir(*cpuProfile))
if err != nil {
fmt.Fprintf(os.Stderr, "testing: %s\n", err)
return
}
if err := m.deps.StartCPUProfile(f); err != nil {
fmt.Fprintf(os.Stderr, "testing: can't start cpu profile: %s\n", err)
f.Close()
return
}
// Could save f so after can call f.Close; not worth the effort.
}
if *traceFile != "" {
f, err := os.Create(toOutputDir(*traceFile))
if err != nil {
fmt.Fprintf(os.Stderr, "testing: %s\n", err)
return
}
if err := trace.Start(f); err != nil {
fmt.Fprintf(os.Stderr, "testing: can't start tracing: %s\n", err)
f.Close()
return
}
// Could save f so after can call f.Close; not worth the effort.
}
if *blockProfile != "" && *blockProfileRate >= 0 {
runtime.SetBlockProfileRate(*blockProfileRate)
}
if *mutexProfile != "" && *mutexProfileFraction >= 0 {
runtime.SetMutexProfileFraction(*mutexProfileFraction)
}
if *coverProfile != "" && CoverMode() == "" {
fmt.Fprintf(os.Stderr, "testing: cannot use -test.coverprofile because test binary was not built with coverage enabled\n")
os.Exit(2)
}
if *gocoverdir != "" && CoverMode() == "" {
fmt.Fprintf(os.Stderr, "testing: cannot use -test.gocoverdir because test binary was not built with coverage enabled\n")
os.Exit(2)
}
if *testlog != "" {
// Note: Not using toOutputDir.
// This file is for use by cmd/go, not users.
var f *os.File
var err error
if m.numRun == 1 {
f, err = os.Create(*testlog)
} else {
f, err = os.OpenFile(*testlog, os.O_WRONLY, 0)
if err == nil {
f.Seek(0, io.SeekEnd)
}
}
if err != nil {
fmt.Fprintf(os.Stderr, "testing: %s\n", err)
os.Exit(2)
}
m.deps.StartTestLog(f)
testlogFile = f
}
if *panicOnExit0 {
m.deps.SetPanicOnExit0(true)
}
}
// after runs after all testing.
func (m *M) after() {
m.afterOnce.Do(func() {
m.writeProfiles()
})
// Restore PanicOnExit0 after every run, because we set it to true before
// every run. Otherwise, if m.Run is called multiple times the behavior of
// os.Exit(0) will not be restored after the second run.
if *panicOnExit0 {
m.deps.SetPanicOnExit0(false)
}
}
func (m *M) writeProfiles() {
if *testlog != "" {
if err := m.deps.StopTestLog(); err != nil {
fmt.Fprintf(os.Stderr, "testing: can't write %s: %s\n", *testlog, err)
os.Exit(2)
}
if err := testlogFile.Close(); err != nil {
fmt.Fprintf(os.Stderr, "testing: can't write %s: %s\n", *testlog, err)
os.Exit(2)
}
}
if *cpuProfile != "" {
m.deps.StopCPUProfile() // flushes profile to disk
}
if *traceFile != "" {
trace.Stop() // flushes trace to disk
}
if *memProfile != "" {
f, err := os.Create(toOutputDir(*memProfile))
if err != nil {
fmt.Fprintf(os.Stderr, "testing: %s\n", err)
os.Exit(2)
}
runtime.GC() // materialize all statistics
if err = m.deps.WriteProfileTo("allocs", f, 0); err != nil {
fmt.Fprintf(os.Stderr, "testing: can't write %s: %s\n", *memProfile, err)
os.Exit(2)
}
f.Close()
}
if *blockProfile != "" && *blockProfileRate >= 0 {
f, err := os.Create(toOutputDir(*blockProfile))
if err != nil {
fmt.Fprintf(os.Stderr, "testing: %s\n", err)
os.Exit(2)
}
if err = m.deps.WriteProfileTo("block", f, 0); err != nil {
fmt.Fprintf(os.Stderr, "testing: can't write %s: %s\n", *blockProfile, err)
os.Exit(2)
}
f.Close()
}
if *mutexProfile != "" && *mutexProfileFraction >= 0 {
f, err := os.Create(toOutputDir(*mutexProfile))
if err != nil {
fmt.Fprintf(os.Stderr, "testing: %s\n", err)
os.Exit(2)
}
if err = m.deps.WriteProfileTo("mutex", f, 0); err != nil {
fmt.Fprintf(os.Stderr, "testing: can't write %s: %s\n", *mutexProfile, err)
os.Exit(2)
}
f.Close()
}
if CoverMode() != "" {
coverReport()
}
}
// toOutputDir returns the file name relocated, if required, to outputDir.
// Simple implementation to avoid pulling in path/filepath.
func toOutputDir(path string) string {
if *outputDir == "" || path == "" {
return path
}
// On Windows, it's clumsy, but we can almost always be correct
// by just looking for a drive letter and a colon.
// Absolute paths always have a drive letter (ignoring UNC).
// Problem: if path == "C:A" and outputdir == "C:\Go" it's unclear
// what to do, but even then path/filepath doesn't help.
// TODO: Worth doing better? Probably not, because we're here only
// under the management of go test.
if runtime.GOOS == "windows" && len(path) >= 2 {
letter, colon := path[0], path[1]
if ('a' <= letter && letter <= 'z' || 'A' <= letter && letter <= 'Z') && colon == ':' {
// If path starts with a drive letter we're stuck with it regardless.
return path
}
}
if os.IsPathSeparator(path[0]) {
return path
}
return fmt.Sprintf("%s%c%s", *outputDir, os.PathSeparator, path)
}
// startAlarm starts an alarm if requested.
func (m *M) startAlarm() time.Time {
if *timeout <= 0 {
return time.Time{}
}
deadline := time.Now().Add(*timeout)
m.timer = time.AfterFunc(*timeout, func() {
m.after()
debug.SetTraceback("all")
extra := ""
if list := runningList(); len(list) > 0 {
var b strings.Builder
b.WriteString("\nrunning tests:")
for _, name := range list {
b.WriteString("\n\t")
b.WriteString(name)
}
extra = b.String()
}
panic(fmt.Sprintf("test timed out after %v%s", *timeout, extra))
})
return deadline
}
// runningList returns the list of running tests.
func runningList() []string {
var list []string
running.Range(func(k, v any) bool {
list = append(list, fmt.Sprintf("%s (%v)", k.(string), time.Since(v.(time.Time)).Round(time.Second)))
return true
})
sort.Strings(list)
return list
}
// stopAlarm turns off the alarm.
func (m *M) stopAlarm() {
if *timeout > 0 {
m.timer.Stop()
}
}
func parseCpuList() {
for _, val := range strings.Split(*cpuListStr, ",") {
val = strings.TrimSpace(val)
if val == "" {
continue
}
cpu, err := strconv.Atoi(val)
if err != nil || cpu <= 0 {
fmt.Fprintf(os.Stderr, "testing: invalid value %q for -test.cpu\n", val)
os.Exit(1)
}
cpuList = append(cpuList, cpu)
}
if cpuList == nil {
cpuList = append(cpuList, runtime.GOMAXPROCS(-1))
}
}
func shouldFailFast() bool {
return *failFast && numFailed.Load() > 0
}
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !windows
package testing
// isWindowsRetryable reports whether err is a Windows error code
// that may be fixed by retrying a failed filesystem operation.
func isWindowsRetryable(err error) bool {
return false
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package scanner provides a scanner and tokenizer for UTF-8-encoded text.
// It takes an io.Reader providing the source, which can then be tokenized
// through repeated calls to the Scan function. For compatibility with
// existing tools, the NUL character is not allowed. If the first character
// in the source is a UTF-8 encoded byte order mark (BOM), it is discarded.
//
// By default, a Scanner skips white space and Go comments and recognizes all
// literals as defined by the Go language specification. It may be
// customized to recognize only a subset of those literals and to recognize
// different identifier and white space characters.
package scanner
import (
"bytes"
"fmt"
"io"
"os"
"unicode"
"unicode/utf8"
)
// Position is a value that represents a source position.
// A position is valid if Line > 0.
type Position struct {
Filename string // filename, if any
Offset int // byte offset, starting at 0
Line int // line number, starting at 1
Column int // column number, starting at 1 (character count per line)
}
// IsValid reports whether the position is valid.
func (pos *Position) IsValid() bool { return pos.Line > 0 }
func (pos Position) String() string {
s := pos.Filename
if s == "" {
s = "<input>"
}
if pos.IsValid() {
s += fmt.Sprintf(":%d:%d", pos.Line, pos.Column)
}
return s
}
// Predefined mode bits to control recognition of tokens. For instance,
// to configure a Scanner such that it only recognizes (Go) identifiers
// and integers, and skips comments, set the Scanner's Mode field to:
//
//	ScanIdents | ScanInts | SkipComments
//
// With the exception of comments, which are skipped if SkipComments is
// set, unrecognized tokens are not ignored. Instead, the scanner simply
// returns the respective individual characters (or possibly sub-tokens).
// For instance, if the mode is ScanIdents (not ScanStrings), the string
// "foo" is scanned as the token sequence '"' Ident '"'.
//
// Use GoTokens to configure the Scanner such that it accepts all Go
// literal tokens including Go identifiers. Comments will be skipped.
const (
ScanIdents = 1 << -Ident
ScanInts = 1 << -Int
ScanFloats = 1 << -Float // includes Ints and hexadecimal floats
ScanChars = 1 << -Char
ScanStrings = 1 << -String
ScanRawStrings = 1 << -RawString
ScanComments = 1 << -Comment
SkipComments = 1 << -skipComment // if set with ScanComments, comments become white space
GoTokens = ScanIdents | ScanFloats | ScanChars | ScanStrings | ScanRawStrings | ScanComments | SkipComments
)
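// Illustrative sketch (editor's addition, not part of the original
// source): a standalone program showing how the mode bits above
// compose. With ScanIdents|ScanInts set, the scanner returns Ident and
// Int tokens and hands back every other character individually:
//
//	package main
//
//	import (
//		"fmt"
//		"strings"
//		"text/scanner"
//	)
//
//	func main() {
//		var s scanner.Scanner
//		s.Init(strings.NewReader("x = 42"))
//		s.Mode = scanner.ScanIdents | scanner.ScanInts | scanner.SkipComments
//		for tok := s.Scan(); tok != scanner.EOF; tok = s.Scan() {
//			fmt.Printf("%s: %q\n", scanner.TokenString(tok), s.TokenText())
//		}
//		// Prints: Ident: "x", then "=": "=", then Int: "42".
//	}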
// The result of Scan is one of these tokens or a Unicode character.
const (
EOF = -(iota + 1)
Ident
Int
Float
Char
String
RawString
Comment
// internal use only
skipComment
)
var tokenString = map[rune]string{
EOF: "EOF",
Ident: "Ident",
Int: "Int",
Float: "Float",
Char: "Char",
String: "String",
RawString: "RawString",
Comment: "Comment",
}
// TokenString returns a printable string for a token or Unicode character.
func TokenString(tok rune) string {
if s, found := tokenString[tok]; found {
return s
}
return fmt.Sprintf("%q", string(tok))
}
// GoWhitespace is the default value for the Scanner's Whitespace field.
// Its value selects Go's white space characters.
const GoWhitespace = 1<<'\t' | 1<<'\n' | 1<<'\r' | 1<<' '
const bufLen = 1024 // at least utf8.UTFMax
// A Scanner implements reading of Unicode characters and tokens from an io.Reader.
type Scanner struct {
// Input
src io.Reader
// Source buffer
srcBuf [bufLen + 1]byte // +1 for sentinel for common case of s.next()
srcPos int // reading position (srcBuf index)
srcEnd int // source end (srcBuf index)
// Source position
srcBufOffset int // byte offset of srcBuf[0] in source
line int // line count
column int // character count
lastLineLen int // length of last line in characters (for correct column reporting)
lastCharLen int // length of last character in bytes
// Token text buffer
// Typically, token text is stored completely in srcBuf, but in general
// the token text's head may be buffered in tokBuf while the token text's
// tail is stored in srcBuf.
tokBuf bytes.Buffer // token text head that is not in srcBuf anymore
tokPos int // token text tail position (srcBuf index); valid if >= 0
tokEnd int // token text tail end (srcBuf index)
// One character look-ahead
ch rune // character before current srcPos
// Error is called for each error encountered. If no Error
// function is set, the error is reported to os.Stderr.
Error func(s *Scanner, msg string)
// ErrorCount is incremented by one for each error encountered.
ErrorCount int
// The Mode field controls which tokens are recognized. For instance,
// to recognize Ints, set the ScanInts bit in Mode. The field may be
// changed at any time.
Mode uint
// The Whitespace field controls which characters are recognized
// as white space. To recognize a character ch <= ' ' as white space,
// set the ch'th bit in Whitespace (the Scanner's behavior is undefined
// for values ch > ' '). The field may be changed at any time.
Whitespace uint64
// IsIdentRune is a predicate controlling the characters accepted
// as the ith rune in an identifier. The set of valid characters
// must not intersect with the set of white space characters.
// If no IsIdentRune function is set, regular Go identifiers are
// accepted instead. The field may be changed at any time.
IsIdentRune func(ch rune, i int) bool
// Start position of most recently scanned token; set by Scan.
// Calling Init or Next invalidates the position (Line == 0).
// The Filename field is always left untouched by the Scanner.
// If an error is reported (via Error) and Position is invalid,
// the scanner is not inside a token. Call Pos to obtain an error
// position in that case, or to obtain the position immediately
// after the most recently scanned token.
Position
}
// Init initializes a Scanner with a new source and returns s.
// Error is set to nil, ErrorCount is set to 0, Mode is set to GoTokens,
// and Whitespace is set to GoWhitespace.
func (s *Scanner) Init(src io.Reader) *Scanner {
s.src = src
// initialize source buffer
// (the first call to next() will fill it by calling src.Read)
s.srcBuf[0] = utf8.RuneSelf // sentinel
s.srcPos = 0
s.srcEnd = 0
// initialize source position
s.srcBufOffset = 0
s.line = 1
s.column = 0
s.lastLineLen = 0
s.lastCharLen = 0
// initialize token text buffer
// (required for first call to next()).
s.tokPos = -1
// initialize one character look-ahead
s.ch = -2 // no char read yet, not EOF
// initialize public fields
s.Error = nil
s.ErrorCount = 0
s.Mode = GoTokens
s.Whitespace = GoWhitespace
s.Line = 0 // invalidate token position
return s
}
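// Illustrative sketch (editor's addition): the canonical Init/Scan loop.
// Because Position is embedded in Scanner, s.Position after Scan holds
// the start position of the most recently scanned token. Assumes the
// imports of the sketch above:
//
//	var s scanner.Scanner
//	s.Init(strings.NewReader(`if x > 0 { return "yes" }`))
//	s.Filename = "example"
//	for tok := s.Scan(); tok != scanner.EOF; tok = s.Scan() {
//		fmt.Printf("%s: %s\n", s.Position, s.TokenText())
//	}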
// next reads and returns the next Unicode character. It is designed such
// that only a minimal amount of work needs to be done in the common ASCII
// case (one test to check for both ASCII and end-of-buffer, and one test
// to check for newlines).
func (s *Scanner) next() rune {
ch, width := rune(s.srcBuf[s.srcPos]), 1
if ch >= utf8.RuneSelf {
// uncommon case: not ASCII or not enough bytes
for s.srcPos+utf8.UTFMax > s.srcEnd && !utf8.FullRune(s.srcBuf[s.srcPos:s.srcEnd]) {
// not enough bytes: read some more, but first
// save away token text if any
if s.tokPos >= 0 {
s.tokBuf.Write(s.srcBuf[s.tokPos:s.srcPos])
s.tokPos = 0
// s.tokEnd is set by Scan()
}
// move unread bytes to beginning of buffer
copy(s.srcBuf[0:], s.srcBuf[s.srcPos:s.srcEnd])
s.srcBufOffset += s.srcPos
// read more bytes
// (an io.Reader must return io.EOF when it reaches
// the end of what it is reading - simply returning
// n == 0 will make this loop retry forever; but the
// error is in the reader implementation in that case)
i := s.srcEnd - s.srcPos
n, err := s.src.Read(s.srcBuf[i:bufLen])
s.srcPos = 0
s.srcEnd = i + n
s.srcBuf[s.srcEnd] = utf8.RuneSelf // sentinel
if err != nil {
if err != io.EOF {
s.error(err.Error())
}
if s.srcEnd == 0 {
if s.lastCharLen > 0 {
// previous character was not EOF
s.column++
}
s.lastCharLen = 0
return EOF
}
// If err == EOF, we won't be getting more
// bytes; break to avoid infinite loop. If
// err is something else, we don't know if
// we can get more bytes; thus also break.
break
}
}
// at least one byte
ch = rune(s.srcBuf[s.srcPos])
if ch >= utf8.RuneSelf {
// uncommon case: not ASCII
ch, width = utf8.DecodeRune(s.srcBuf[s.srcPos:s.srcEnd])
if ch == utf8.RuneError && width == 1 {
// advance for correct error position
s.srcPos += width
s.lastCharLen = width
s.column++
s.error("invalid UTF-8 encoding")
return ch
}
}
}
// advance
s.srcPos += width
s.lastCharLen = width
s.column++
// special situations
switch ch {
case 0:
// for compatibility with other tools
s.error("invalid character NUL")
case '\n':
s.line++
s.lastLineLen = s.column
s.column = 0
}
return ch
}
// Next reads and returns the next Unicode character.
// It returns EOF at the end of the source. It reports
// a read error by calling s.Error, if not nil; otherwise
// it prints an error message to os.Stderr. Next does not
// update the Scanner's Position field; use Pos() to
// get the current position.
func (s *Scanner) Next() rune {
s.tokPos = -1 // don't collect token text
s.Line = 0 // invalidate token position
ch := s.Peek()
if ch != EOF {
s.ch = s.next()
}
return ch
}
// Peek returns the next Unicode character in the source without advancing
// the scanner. It returns EOF if the scanner's position is at the last
// character of the source.
func (s *Scanner) Peek() rune {
if s.ch == -2 {
// this code is only run for the very first character
s.ch = s.next()
if s.ch == '\uFEFF' {
s.ch = s.next() // ignore BOM
}
}
return s.ch
}
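// Illustrative sketch (editor's addition): Next and Peek support plain
// character-at-a-time reading, independent of token recognition; no
// token text is collected and the Position field stays invalid:
//
//	var s scanner.Scanner
//	s.Init(strings.NewReader("héllo"))
//	for ch := s.Next(); ch != scanner.EOF; ch = s.Next() {
//		fmt.Printf("%c", ch)
//	}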
func (s *Scanner) error(msg string) {
s.tokEnd = s.srcPos - s.lastCharLen // make sure token text is terminated
s.ErrorCount++
if s.Error != nil {
s.Error(s, msg)
return
}
pos := s.Position
if !pos.IsValid() {
pos = s.Pos()
}
fmt.Fprintf(os.Stderr, "%s: %s\n", pos, msg)
}
func (s *Scanner) errorf(format string, args ...any) {
s.error(fmt.Sprintf(format, args...))
}
func (s *Scanner) isIdentRune(ch rune, i int) bool {
if s.IsIdentRune != nil {
return ch != EOF && s.IsIdentRune(ch, i)
}
return ch == '_' || unicode.IsLetter(ch) || unicode.IsDigit(ch) && i > 0
}
func (s *Scanner) scanIdentifier() rune {
// we know the zero'th rune is OK; start scanning at the next one
ch := s.next()
for i := 1; s.isIdentRune(ch, i); i++ {
ch = s.next()
}
return ch
}
func lower(ch rune) rune { return ('a' - 'A') | ch } // returns lower-case ch iff ch is ASCII letter
func isDecimal(ch rune) bool { return '0' <= ch && ch <= '9' }
func isHex(ch rune) bool { return '0' <= ch && ch <= '9' || 'a' <= lower(ch) && lower(ch) <= 'f' }
// digits accepts the sequence { digit | '_' } starting with ch0.
// If base <= 10, digits accepts any decimal digit but records
// the first invalid digit >= base in *invalid if *invalid == 0.
// digits returns the first rune that is not part of the sequence
// anymore, and a bitset describing whether the sequence contained
// digits (bit 0 is set), or separators '_' (bit 1 is set).
func (s *Scanner) digits(ch0 rune, base int, invalid *rune) (ch rune, digsep int) {
ch = ch0
if base <= 10 {
max := rune('0' + base)
for isDecimal(ch) || ch == '_' {
ds := 1
if ch == '_' {
ds = 2
} else if ch >= max && *invalid == 0 {
*invalid = ch
}
digsep |= ds
ch = s.next()
}
} else {
for isHex(ch) || ch == '_' {
ds := 1
if ch == '_' {
ds = 2
}
digsep |= ds
ch = s.next()
}
}
return
}
func (s *Scanner) scanNumber(ch rune, seenDot bool) (rune, rune) {
base := 10 // number base
prefix := rune(0) // one of 0 (decimal), '0' (0-octal), 'x', 'o', or 'b'
digsep := 0 // bit 0: digit present, bit 1: '_' present
invalid := rune(0) // invalid digit in literal, or 0
// integer part
var tok rune
var ds int
if !seenDot {
tok = Int
if ch == '0' {
ch = s.next()
switch lower(ch) {
case 'x':
ch = s.next()
base, prefix = 16, 'x'
case 'o':
ch = s.next()
base, prefix = 8, 'o'
case 'b':
ch = s.next()
base, prefix = 2, 'b'
default:
base, prefix = 8, '0'
digsep = 1 // leading 0
}
}
ch, ds = s.digits(ch, base, &invalid)
digsep |= ds
if ch == '.' && s.Mode&ScanFloats != 0 {
ch = s.next()
seenDot = true
}
}
// fractional part
if seenDot {
tok = Float
if prefix == 'o' || prefix == 'b' {
s.error("invalid radix point in " + litname(prefix))
}
ch, ds = s.digits(ch, base, &invalid)
digsep |= ds
}
if digsep&1 == 0 {
s.error(litname(prefix) + " has no digits")
}
// exponent
if e := lower(ch); (e == 'e' || e == 'p') && s.Mode&ScanFloats != 0 {
switch {
case e == 'e' && prefix != 0 && prefix != '0':
s.errorf("%q exponent requires decimal mantissa", ch)
case e == 'p' && prefix != 'x':
s.errorf("%q exponent requires hexadecimal mantissa", ch)
}
ch = s.next()
tok = Float
if ch == '+' || ch == '-' {
ch = s.next()
}
ch, ds = s.digits(ch, 10, nil)
digsep |= ds
if ds&1 == 0 {
s.error("exponent has no digits")
}
} else if prefix == 'x' && tok == Float {
s.error("hexadecimal mantissa requires a 'p' exponent")
}
if tok == Int && invalid != 0 {
s.errorf("invalid digit %q in %s", invalid, litname(prefix))
}
if digsep&2 != 0 {
s.tokEnd = s.srcPos - s.lastCharLen // make sure token text is terminated
if i := invalidSep(s.TokenText()); i >= 0 {
s.error("'_' must separate successive digits")
}
}
return tok, ch
}
func litname(prefix rune) string {
switch prefix {
default:
return "decimal literal"
case 'x':
return "hexadecimal literal"
case 'o', '0':
return "octal literal"
case 'b':
return "binary literal"
}
}
// invalidSep returns the index of the first invalid separator in x, or -1.
func invalidSep(x string) int {
x1 := ' ' // prefix char, we only care if it's 'x'
d := '.' // digit, one of '_', '0' (a digit), or '.' (anything else)
i := 0
// a prefix counts as a digit
if len(x) >= 2 && x[0] == '0' {
x1 = lower(rune(x[1]))
if x1 == 'x' || x1 == 'o' || x1 == 'b' {
d = '0'
i = 2
}
}
// mantissa and exponent
for ; i < len(x); i++ {
p := d // previous digit
d = rune(x[i])
switch {
case d == '_':
if p != '0' {
return i
}
case isDecimal(d) || x1 == 'x' && isHex(d):
d = '0'
default:
if p == '_' {
return i - 1
}
d = '.'
}
}
if d == '_' {
return len(x) - 1
}
return -1
}
func digitVal(ch rune) int {
switch {
case '0' <= ch && ch <= '9':
return int(ch - '0')
case 'a' <= lower(ch) && lower(ch) <= 'f':
return int(lower(ch) - 'a' + 10)
}
return 16 // larger than any legal digit val
}
func (s *Scanner) scanDigits(ch rune, base, n int) rune {
for n > 0 && digitVal(ch) < base {
ch = s.next()
n--
}
if n > 0 {
s.error("invalid char escape")
}
return ch
}
func (s *Scanner) scanEscape(quote rune) rune {
ch := s.next() // read character after '\\'
switch ch {
case 'a', 'b', 'f', 'n', 'r', 't', 'v', '\\', quote:
// nothing to do
ch = s.next()
case '0', '1', '2', '3', '4', '5', '6', '7':
ch = s.scanDigits(ch, 8, 3)
case 'x':
ch = s.scanDigits(s.next(), 16, 2)
case 'u':
ch = s.scanDigits(s.next(), 16, 4)
case 'U':
ch = s.scanDigits(s.next(), 16, 8)
default:
s.error("invalid char escape")
}
return ch
}
func (s *Scanner) scanString(quote rune) (n int) {
ch := s.next() // read character after quote
for ch != quote {
if ch == '\n' || ch < 0 {
s.error("literal not terminated")
return
}
if ch == '\\' {
ch = s.scanEscape(quote)
} else {
ch = s.next()
}
n++
}
return
}
func (s *Scanner) scanRawString() {
ch := s.next() // read character after '`'
for ch != '`' {
if ch < 0 {
s.error("literal not terminated")
return
}
ch = s.next()
}
}
func (s *Scanner) scanChar() {
if s.scanString('\'') != 1 {
s.error("invalid char literal")
}
}
func (s *Scanner) scanComment(ch rune) rune {
// ch == '/' || ch == '*'
if ch == '/' {
// line comment
ch = s.next() // read character after "//"
for ch != '\n' && ch >= 0 {
ch = s.next()
}
return ch
}
// general comment
ch = s.next() // read character after "/*"
for {
if ch < 0 {
s.error("comment not terminated")
break
}
ch0 := ch
ch = s.next()
if ch0 == '*' && ch == '/' {
ch = s.next()
break
}
}
return ch
}
// Scan reads the next token or Unicode character from source and returns it.
// It only recognizes tokens t for which the respective Mode bit (1<<-t) is set.
// It returns EOF at the end of the source. It reports scanner errors (read and
// token errors) by calling s.Error, if not nil; otherwise it prints an error
// message to os.Stderr.
func (s *Scanner) Scan() rune {
ch := s.Peek()
// reset token text position
s.tokPos = -1
s.Line = 0
redo:
// skip white space
for s.Whitespace&(1<<uint(ch)) != 0 {
ch = s.next()
}
// start collecting token text
s.tokBuf.Reset()
s.tokPos = s.srcPos - s.lastCharLen
// set token position
// (this is a slightly optimized version of the code in Pos())
s.Offset = s.srcBufOffset + s.tokPos
if s.column > 0 {
// common case: last character was not a '\n'
s.Line = s.line
s.Column = s.column
} else {
// last character was a '\n'
// (we cannot be at the beginning of the source
// since we have called next() at least once)
s.Line = s.line - 1
s.Column = s.lastLineLen
}
// determine token value
tok := ch
switch {
case s.isIdentRune(ch, 0):
if s.Mode&ScanIdents != 0 {
tok = Ident
ch = s.scanIdentifier()
} else {
ch = s.next()
}
case isDecimal(ch):
if s.Mode&(ScanInts|ScanFloats) != 0 {
tok, ch = s.scanNumber(ch, false)
} else {
ch = s.next()
}
default:
switch ch {
case EOF:
break
case '"':
if s.Mode&ScanStrings != 0 {
s.scanString('"')
tok = String
}
ch = s.next()
case '\'':
if s.Mode&ScanChars != 0 {
s.scanChar()
tok = Char
}
ch = s.next()
case '.':
ch = s.next()
if isDecimal(ch) && s.Mode&ScanFloats != 0 {
tok, ch = s.scanNumber(ch, true)
}
case '/':
ch = s.next()
if (ch == '/' || ch == '*') && s.Mode&ScanComments != 0 {
if s.Mode&SkipComments != 0 {
s.tokPos = -1 // don't collect token text
ch = s.scanComment(ch)
goto redo
}
ch = s.scanComment(ch)
tok = Comment
}
case '`':
if s.Mode&ScanRawStrings != 0 {
s.scanRawString()
tok = RawString
}
ch = s.next()
default:
ch = s.next()
}
}
// end of token text
s.tokEnd = s.srcPos - s.lastCharLen
s.ch = ch
return tok
}
// Pos returns the position of the character immediately after
// the character or token returned by the last call to Next or Scan.
// Use the Scanner's Position field for the start position of the most
// recently scanned token.
func (s *Scanner) Pos() (pos Position) {
pos.Filename = s.Filename
pos.Offset = s.srcBufOffset + s.srcPos - s.lastCharLen
switch {
case s.column > 0:
// common case: last character was not a '\n'
pos.Line = s.line
pos.Column = s.column
case s.lastLineLen > 0:
// last character was a '\n'
pos.Line = s.line - 1
pos.Column = s.lastLineLen
default:
// at the beginning of the source
pos.Line = 1
pos.Column = 1
}
return
}
// TokenText returns the string corresponding to the most recently scanned token.
// Valid after calling Scan and in calls of Scanner.Error.
func (s *Scanner) TokenText() string {
if s.tokPos < 0 {
// no token text
return ""
}
if s.tokEnd < s.tokPos {
// if EOF was reached, s.tokEnd is set to -1 (s.srcPos == 0)
s.tokEnd = s.tokPos
}
// s.tokEnd >= s.tokPos
if s.tokBuf.Len() == 0 {
// common case: the entire token text is still in srcBuf
return string(s.srcBuf[s.tokPos:s.tokEnd])
}
// part of the token text was saved in tokBuf: save the rest in
// tokBuf as well and return its content
s.tokBuf.Write(s.srcBuf[s.tokPos:s.tokEnd])
s.tokPos = s.tokEnd // ensure idempotency of TokenText() call
return s.tokBuf.String()
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package tabwriter implements a write filter (tabwriter.Writer) that
// translates tabbed columns in input into properly aligned text.
//
// The package uses the Elastic Tabstops algorithm described at
// http://nickgravgaard.com/elastictabstops/index.html.
//
// The text/tabwriter package is frozen and is not accepting new features.
package tabwriter
import (
"io"
"unicode/utf8"
)
// ----------------------------------------------------------------------------
// Filter implementation
// A cell represents a segment of text terminated by tabs or line breaks.
// The text itself is stored in a separate buffer; cell only describes the
// segment's size in bytes, its width in runes, and whether it's an htab
// ('\t') terminated cell.
type cell struct {
size int // cell size in bytes
width int // cell width in runes
htab bool // true if the cell is terminated by an htab ('\t')
}
// A Writer is a filter that inserts padding around tab-delimited
// columns in its input to align them in the output.
//
// The Writer treats incoming bytes as UTF-8-encoded text consisting
// of cells terminated by horizontal ('\t') or vertical ('\v') tabs,
// and newline ('\n') or formfeed ('\f') characters; both newline and
// formfeed act as line breaks.
//
// Tab-terminated cells in contiguous lines constitute a column. The
// Writer inserts padding as needed to make all cells in a column have
// the same width, effectively aligning the columns. It assumes that
// all characters have the same width, except for tabs for which a
// tabwidth must be specified. Column cells must be tab-terminated, not
// tab-separated: non-tab terminated trailing text at the end of a line
// forms a cell but that cell is not part of an aligned column.
// For instance, in this example (where | stands for a horizontal tab):
//
// aaaa|bbb|d
// aa |b |dd
// a |
// aa |cccc|eee
//
// the b and c are in distinct columns (the b column is not contiguous
// all the way). The d and e are not in a column at all (there's no
// terminating tab, nor would the column be contiguous).
//
// The Writer assumes that all Unicode code points have the same width;
// this may not be true in some fonts or if the string contains combining
// characters.
//
// If DiscardEmptyColumns is set, empty columns that are terminated
// entirely by vertical (or "soft") tabs are discarded. Columns
// terminated by horizontal (or "hard") tabs are not affected by
// this flag.
//
// If a Writer is configured to filter HTML, HTML tags and entities
// are passed through. The widths of tags and entities are
// assumed to be zero (tags) and one (entities) for formatting purposes.
//
// A segment of text may be escaped by bracketing it with Escape
// characters. The tabwriter passes escaped text segments through
// unchanged. In particular, it does not interpret any tabs or line
// breaks within the segment. If the StripEscape flag is set, the
// Escape characters are stripped from the output; otherwise they
// are passed through as well. For the purpose of formatting, the
// width of the escaped text is always computed excluding the Escape
// characters.
//
// The formfeed character acts like a newline but it also terminates
// all columns in the current line (effectively calling Flush). Tab-
// terminated cells in the next line start new columns. Unless found
// inside an HTML tag or inside an escaped text segment, formfeed
// characters appear as newlines in the output.
//
// The Writer must buffer input internally, because proper spacing
// of one line may depend on the cells in future lines. Clients must
// call Flush when done calling Write.
type Writer struct {
// configuration
output io.Writer
minwidth int
tabwidth int
padding int
padbytes [8]byte
flags uint
// current state
buf []byte // collected text excluding tabs or line breaks
pos int // buffer position up to which cell.width of incomplete cell has been computed
cell cell // current incomplete cell; cell.width is up to buf[pos] excluding ignored sections
endChar byte // terminating char of escaped sequence (Escape for escapes, '>', ';' for HTML tags/entities, or 0)
lines [][]cell // list of lines; each line is a list of cells
widths []int // list of column widths in runes - re-used during formatting
}
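// Illustrative sketch (editor's addition): the column behavior described
// above as a runnable program. The Debug flag makes the column breaks
// visible; only NewWriter, Write, and Flush from this package are used:
//
//	package main
//
//	import (
//		"os"
//		"text/tabwriter"
//	)
//
//	func main() {
//		w := tabwriter.NewWriter(os.Stdout, 0, 8, 1, ' ', tabwriter.Debug)
//		w.Write([]byte("aaaa\tbbb\td\n"))
//		w.Write([]byte("aa\tb\tdd\n"))
//		w.Write([]byte("a\t\n"))
//		w.Write([]byte("aa\tcccc\teee\n"))
//		w.Flush()
//	}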
// addLine adds a new line.
// flushed is a hint indicating whether the underlying writer was just flushed.
// If so, the previous line is not likely to be a good indicator of the new line's cells.
func (b *Writer) addLine(flushed bool) {
// Grow slice instead of appending,
// as that gives us an opportunity
// to re-use an existing []cell.
if n := len(b.lines) + 1; n <= cap(b.lines) {
b.lines = b.lines[:n]
b.lines[n-1] = b.lines[n-1][:0]
} else {
b.lines = append(b.lines, nil)
}
if !flushed {
// The previous line is probably a good indicator
// of how many cells the current line will have.
// If the current line's capacity is smaller than that,
// abandon it and make a new one.
if n := len(b.lines); n >= 2 {
if prev := len(b.lines[n-2]); prev > cap(b.lines[n-1]) {
b.lines[n-1] = make([]cell, 0, prev)
}
}
}
}
// Reset the current state.
func (b *Writer) reset() {
b.buf = b.buf[:0]
b.pos = 0
b.cell = cell{}
b.endChar = 0
b.lines = b.lines[0:0]
b.widths = b.widths[0:0]
b.addLine(true)
}
// Internal representation (current state):
//
// - all text written is appended to buf; tabs and line breaks are stripped away
// - at any given time there is a (possibly empty) incomplete cell at the end
// (the cell starts after a tab or line break)
// - cell.size is the number of bytes belonging to the cell so far
// - cell.width is text width in runes of that cell from the start of the cell to
// position pos; html tags and entities are excluded from this width if html
// filtering is enabled
// - the sizes and widths of processed text are kept in the lines list
// which contains a list of cells for each line
// - the widths list is a temporary list with current widths used during
// formatting; it is kept in Writer because it's re-used
//
// |<---------- size ---------->|
// | |
// |<- width ->|<- ignored ->| |
// | | | |
// [---processed---tab------------<tag>...</tag>...]
// ^ ^ ^
// | | |
// buf start of incomplete cell pos
// Formatting can be controlled with these flags.
const (
// Ignore html tags and treat entities (starting with '&'
// and ending in ';') as single characters (width = 1).
FilterHTML uint = 1 << iota
// Strip Escape characters bracketing escaped text segments
// instead of passing them through unchanged with the text.
StripEscape
// Force right-alignment of cell content.
// Default is left-alignment.
AlignRight
// Handle empty columns as if they were not present in
// the input in the first place.
DiscardEmptyColumns
// Always use tabs for indentation columns (i.e., padding of
// leading empty cells on the left) independent of padchar.
TabIndent
// Print a vertical bar ('|') between columns (after formatting).
// Discarded columns appear as zero-width columns ("||").
Debug
)
// A Writer must be initialized with a call to Init. The first parameter (output)
// specifies the filter output. The remaining parameters control the formatting:
//
// minwidth minimal cell width including any padding
// tabwidth width of tab characters (equivalent number of spaces)
// padding padding added to a cell before computing its width
// padchar ASCII char used for padding
// if padchar == '\t', the Writer will assume that the
// width of a '\t' in the formatted output is tabwidth,
// and cells are left-aligned regardless of the AlignRight flag
// (for correct-looking results, tabwidth must correspond
// to the tab width in the viewer displaying the result)
// flags formatting control
func (b *Writer) Init(output io.Writer, minwidth, tabwidth, padding int, padchar byte, flags uint) *Writer {
if minwidth < 0 || tabwidth < 0 || padding < 0 {
panic("negative minwidth, tabwidth, or padding")
}
b.output = output
b.minwidth = minwidth
b.tabwidth = tabwidth
b.padding = padding
for i := range b.padbytes {
b.padbytes[i] = padchar
}
if padchar == '\t' {
// tab padding enforces left-alignment
flags &^= AlignRight
}
b.flags = flags
b.reset()
return b
}
// debugging support (keep code around)
func (b *Writer) dump() {
pos := 0
for i, line := range b.lines {
print("(", i, ") ")
for _, c := range line {
print("[", string(b.buf[pos:pos+c.size]), "]")
pos += c.size
}
print("\n")
}
print("\n")
}
// local error wrapper so we can distinguish errors we want to return
// as errors from genuine panics (which we don't want to return as errors)
type osError struct {
err error
}
func (b *Writer) write0(buf []byte) {
n, err := b.output.Write(buf)
if n != len(buf) && err == nil {
err = io.ErrShortWrite
}
if err != nil {
panic(osError{err})
}
}
func (b *Writer) writeN(src []byte, n int) {
for n > len(src) {
b.write0(src)
n -= len(src)
}
b.write0(src[0:n])
}
var (
newline = []byte{'\n'}
tabs = []byte("\t\t\t\t\t\t\t\t")
)
func (b *Writer) writePadding(textw, cellw int, useTabs bool) {
if b.padbytes[0] == '\t' || useTabs {
// padding is done with tabs
if b.tabwidth == 0 {
return // tabs have no width - can't do any padding
}
// round cellw up to the smallest multiple of b.tabwidth
cellw = (cellw + b.tabwidth - 1) / b.tabwidth * b.tabwidth
n := cellw - textw // amount of padding
if n < 0 {
panic("internal error")
}
b.writeN(tabs, (n+b.tabwidth-1)/b.tabwidth)
return
}
// padding is done with non-tab characters
b.writeN(b.padbytes[0:], cellw-textw)
}
var vbar = []byte{'|'}
func (b *Writer) writeLines(pos0 int, line0, line1 int) (pos int) {
pos = pos0
for i := line0; i < line1; i++ {
line := b.lines[i]
// if TabIndent is set, use tabs to pad leading empty cells
useTabs := b.flags&TabIndent != 0
for j, c := range line {
if j > 0 && b.flags&Debug != 0 {
// indicate column break
b.write0(vbar)
}
if c.size == 0 {
// empty cell
if j < len(b.widths) {
b.writePadding(c.width, b.widths[j], useTabs)
}
} else {
// non-empty cell
useTabs = false
if b.flags&AlignRight == 0 { // align left
b.write0(b.buf[pos : pos+c.size])
pos += c.size
if j < len(b.widths) {
b.writePadding(c.width, b.widths[j], false)
}
} else { // align right
if j < len(b.widths) {
b.writePadding(c.width, b.widths[j], false)
}
b.write0(b.buf[pos : pos+c.size])
pos += c.size
}
}
}
if i+1 == len(b.lines) {
// last buffered line - we don't have a newline, so just write
// any outstanding buffered data
b.write0(b.buf[pos : pos+b.cell.size])
pos += b.cell.size
} else {
// not the last line - write newline
b.write0(newline)
}
}
return
}
// Format the text between line0 and line1 (excluding line1); pos
// is the buffer position corresponding to the beginning of line0.
// Returns the buffer position corresponding to the beginning of
// line1. Write errors are reported via panic (recovered by handlePanic).
func (b *Writer) format(pos0 int, line0, line1 int) (pos int) {
pos = pos0
column := len(b.widths)
for this := line0; this < line1; this++ {
line := b.lines[this]
if column >= len(line)-1 {
continue
}
// cell exists in this column => this line
// has more cells than the previous line
// (the last cell per line is ignored because cells are
// tab-terminated; the last cell per line describes the
// text before the newline/formfeed and does not belong
// to a column)
// print unprinted lines until beginning of block
pos = b.writeLines(pos, line0, this)
line0 = this
// column block begin
width := b.minwidth // minimal column width
discardable := true // true if all cells in this column are empty and "soft"
for ; this < line1; this++ {
line = b.lines[this]
if column >= len(line)-1 {
break
}
// cell exists in this column
c := line[column]
// update width
if w := c.width + b.padding; w > width {
width = w
}
// update discardable
if c.width > 0 || c.htab {
discardable = false
}
}
// column block end
// discard empty columns if necessary
if discardable && b.flags&DiscardEmptyColumns != 0 {
width = 0
}
// format and print all columns to the right of this column
// (we know the widths of this column and all columns to the left)
b.widths = append(b.widths, width) // push width
pos = b.format(pos, line0, this)
b.widths = b.widths[0 : len(b.widths)-1] // pop width
line0 = this
}
// print unprinted lines until end
return b.writeLines(pos, line0, line1)
}
// Append text to current cell.
func (b *Writer) append(text []byte) {
b.buf = append(b.buf, text...)
b.cell.size += len(text)
}
// Update the cell width.
func (b *Writer) updateWidth() {
b.cell.width += utf8.RuneCount(b.buf[b.pos:])
b.pos = len(b.buf)
}
// To escape a text segment, bracket it with Escape characters.
// For instance, the tab in this string "Ignore this tab: \xff\t\xff"
// does not terminate a cell and constitutes a single character of
// width one for formatting purposes.
//
// The value 0xff was chosen because it cannot appear in a valid UTF-8 sequence.
const Escape = '\xff'
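// Illustrative sketch (editor's addition): bracketing a segment with
// Escape keeps the tab inside it from terminating a cell. Given a
// Writer w as in the sketches above (and "fmt" imported):
//
//	esc := string(tabwriter.Escape)
//	fmt.Fprintf(w, "key\t%sa\tb%s\tvalue\n", esc, esc)
//	// The escaped segment "a\tb" is passed through unchanged and
//	// counts as three characters of width for formatting purposes.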
// Start escaped mode.
func (b *Writer) startEscape(ch byte) {
switch ch {
case Escape:
b.endChar = Escape
case '<':
b.endChar = '>'
case '&':
b.endChar = ';'
}
}
// Terminate escaped mode. If the escaped text was an HTML tag, its width
// is assumed to be zero for formatting purposes; if it was an HTML entity,
// its width is assumed to be one. In all other cases, the width is the
// Unicode width of the text.
func (b *Writer) endEscape() {
switch b.endChar {
case Escape:
b.updateWidth()
if b.flags&StripEscape == 0 {
b.cell.width -= 2 // don't count the Escape chars
}
case '>': // tag of zero width
case ';':
b.cell.width++ // entity, count as one rune
}
b.pos = len(b.buf)
b.endChar = 0
}
// Terminate the current cell by adding it to the list of cells of the
// current line. Returns the number of cells in that line.
func (b *Writer) terminateCell(htab bool) int {
b.cell.htab = htab
line := &b.lines[len(b.lines)-1]
*line = append(*line, b.cell)
b.cell = cell{}
return len(*line)
}
func (b *Writer) handlePanic(err *error, op string) {
if e := recover(); e != nil {
if op == "Flush" {
// If Flush ran into a panic, we still need to reset.
b.reset()
}
if nerr, ok := e.(osError); ok {
*err = nerr.err
return
}
panic("tabwriter: panic during " + op)
}
}
// Flush should be called after the last call to Write to ensure
// that any data buffered in the Writer is written to output. Any
// incomplete escape sequence at the end is considered
// complete for formatting purposes.
func (b *Writer) Flush() error {
return b.flush()
}
// flush is the internal version of Flush, with a named return value which we
// don't want to expose.
func (b *Writer) flush() (err error) {
defer b.handlePanic(&err, "Flush")
b.flushNoDefers()
return nil
}
// flushNoDefers is like flush, but without a deferred handlePanic call. This
// can be called from other methods that already have their own deferred
// handlePanic calls, such as Write, to avoid the extra defer work.
func (b *Writer) flushNoDefers() {
// add current cell if not empty
if b.cell.size > 0 {
if b.endChar != 0 {
// inside escape - terminate it even if incomplete
b.endEscape()
}
b.terminateCell(false)
}
// format contents of buffer
b.format(0, 0, len(b.lines))
b.reset()
}
var hbar = []byte("---\n")
// Write writes buf to the writer b.
// The only errors returned are ones encountered
// while writing to the underlying output stream.
func (b *Writer) Write(buf []byte) (n int, err error) {
defer b.handlePanic(&err, "Write")
// split text into cells
n = 0
for i, ch := range buf {
if b.endChar == 0 {
// outside escape
switch ch {
case '\t', '\v', '\n', '\f':
// end of cell
b.append(buf[n:i])
b.updateWidth()
n = i + 1 // ch consumed
ncells := b.terminateCell(ch == '\t')
if ch == '\n' || ch == '\f' {
// terminate line
b.addLine(ch == '\f')
if ch == '\f' || ncells == 1 {
// A '\f' always forces a flush. Otherwise, if the previous
// line has only one cell, that cell has no impact on the
// formatting of the following lines (the last cell per
// line is ignored by format()), so we can flush the
// Writer contents.
b.flushNoDefers()
if ch == '\f' && b.flags&Debug != 0 {
// indicate section break
b.write0(hbar)
}
}
}
case Escape:
// start of escaped sequence
b.append(buf[n:i])
b.updateWidth()
n = i
if b.flags&StripEscape != 0 {
n++ // strip Escape
}
b.startEscape(Escape)
case '<', '&':
// possibly an html tag/entity
if b.flags&FilterHTML != 0 {
// begin of tag/entity
b.append(buf[n:i])
b.updateWidth()
n = i
b.startEscape(ch)
}
}
} else {
// inside escape
if ch == b.endChar {
// end of tag/entity
j := i + 1
if ch == Escape && b.flags&StripEscape != 0 {
j = i // strip Escape
}
b.append(buf[n:j])
n = i + 1 // ch consumed
b.endEscape()
}
}
}
// append leftover text
b.append(buf[n:])
n = len(buf)
return
}
// NewWriter allocates and initializes a new tabwriter.Writer.
// The parameters are the same as for the Init function.
func NewWriter(output io.Writer, minwidth, tabwidth, padding int, padchar byte, flags uint) *Writer {
return new(Writer).Init(output, minwidth, tabwidth, padding, padchar, flags)
}
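// Illustrative sketch (editor's addition): typical use of NewWriter,
// writing tab-terminated cells and flushing once at the end:
//
//	w := tabwriter.NewWriter(os.Stdout, 0, 8, 2, ' ', 0)
//	fmt.Fprintln(w, "name\tage\t")
//	fmt.Fprintln(w, "Alice\t30\t")
//	fmt.Fprintln(w, "Bob\t7\t")
//	w.Flush() // output is only guaranteed to be complete after Flush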
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package template
import (
"errors"
"fmt"
"internal/fmtsort"
"io"
"reflect"
"runtime"
"strings"
"text/template/parse"
)
// maxExecDepth specifies the maximum stack depth of templates within
// templates. This limit is only practically reached by accidentally
// recursive template invocations. This limit allows us to return
// an error instead of triggering a stack overflow.
var maxExecDepth = initMaxExecDepth()
func initMaxExecDepth() int {
if runtime.GOARCH == "wasm" {
return 1000
}
return 100000
}
// state represents the state of an execution. It's not part of the
// template so that multiple executions of the same template
// can execute in parallel.
type state struct {
tmpl *Template
wr io.Writer
node parse.Node // current node, for errors
vars []variable // push-down stack of variable values.
depth int // the height of the stack of executing templates.
}
// variable holds the dynamic value of a variable such as $, $x etc.
type variable struct {
name string
value reflect.Value
}
// push pushes a new variable on the stack.
func (s *state) push(name string, value reflect.Value) {
s.vars = append(s.vars, variable{name, value})
}
// mark returns the length of the variable stack.
func (s *state) mark() int {
return len(s.vars)
}
// pop pops the variable stack up to the mark.
func (s *state) pop(mark int) {
s.vars = s.vars[0:mark]
}
// setVar overwrites the last declared variable with the given name.
// Used by variable assignments.
func (s *state) setVar(name string, value reflect.Value) {
for i := s.mark() - 1; i >= 0; i-- {
if s.vars[i].name == name {
s.vars[i].value = value
return
}
}
s.errorf("undefined variable: %s", name)
}
// setTopVar overwrites the top-nth variable on the stack. Used by range iterations.
func (s *state) setTopVar(n int, value reflect.Value) {
s.vars[len(s.vars)-n].value = value
}
// varValue returns the value of the named variable.
func (s *state) varValue(name string) reflect.Value {
for i := s.mark() - 1; i >= 0; i-- {
if s.vars[i].name == name {
return s.vars[i].value
}
}
s.errorf("undefined variable: %s", name)
return zero
}
var zero reflect.Value
type missingValType struct{}
var missingVal = reflect.ValueOf(missingValType{})
var missingValReflectType = reflect.TypeOf(missingValType{})
func isMissing(v reflect.Value) bool {
return v.IsValid() && v.Type() == missingValReflectType
}
// at marks the state to be on node n, for error reporting.
func (s *state) at(node parse.Node) {
s.node = node
}
// doublePercent returns the string with %'s replaced by %%, if necessary,
// so it can be used safely inside a Printf format string.
func doublePercent(str string) string {
return strings.ReplaceAll(str, "%", "%%")
}
// TODO: It would be nice if ExecError were more broken down, but
// the way ErrorContext embeds the template name makes the
// processing too clumsy.
// ExecError is the custom error type returned when Execute has an
// error evaluating its template. (If a write error occurs, the actual
// error is returned; it will not be of type ExecError.)
type ExecError struct {
Name string // Name of template.
Err error // Pre-formatted error.
}
func (e ExecError) Error() string {
return e.Err.Error()
}
func (e ExecError) Unwrap() error {
return e.Err
}
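// Illustrative sketch (editor's addition): since only evaluation errors
// are wrapped in ExecError, a caller can separate them from write errors
// with errors.As:
//
//	err := tmpl.Execute(w, data)
//	var execErr template.ExecError
//	if errors.As(err, &execErr) {
//		// the template failed to evaluate; execErr.Name names it
//	} else if err != nil {
//		// the underlying io.Writer failed
//	}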
// errorf records an ExecError and terminates processing.
func (s *state) errorf(format string, args ...any) {
name := doublePercent(s.tmpl.Name())
if s.node == nil {
format = fmt.Sprintf("template: %s: %s", name, format)
} else {
location, context := s.tmpl.ErrorContext(s.node)
format = fmt.Sprintf("template: %s: executing %q at <%s>: %s", location, name, doublePercent(context), format)
}
panic(ExecError{
Name: s.tmpl.Name(),
Err: fmt.Errorf(format, args...),
})
}
// writeError is the wrapper type used internally when Execute has an
// error writing to its output. We strip the wrapper in errRecover.
// Note that this is not an implementation of error, so it cannot escape
// from the package as an error value.
type writeError struct {
Err error // Original error.
}
func (s *state) writeError(err error) {
panic(writeError{
Err: err,
})
}
// errRecover is the handler that turns panics into returns from the top
// level of template execution (it is deferred by the execute method).
func errRecover(errp *error) {
e := recover()
if e != nil {
switch err := e.(type) {
case runtime.Error:
panic(e)
case writeError:
*errp = err.Err // Strip the wrapper.
case ExecError:
*errp = err // Keep the wrapper.
default:
panic(e)
}
}
}
// ExecuteTemplate applies the template associated with t that has the given name
// to the specified data object and writes the output to wr.
// If an error occurs executing the template or writing its output,
// execution stops, but partial results may already have been written to
// the output writer.
// A template may be executed safely in parallel, although if parallel
// executions share a Writer the output may be interleaved.
func (t *Template) ExecuteTemplate(wr io.Writer, name string, data any) error {
tmpl := t.Lookup(name)
if tmpl == nil {
return fmt.Errorf("template: no template %q associated with template %q", name, t.name)
}
return tmpl.Execute(wr, data)
}
// Execute applies a parsed template to the specified data object,
// and writes the output to wr.
// If an error occurs executing the template or writing its output,
// execution stops, but partial results may already have been written to
// the output writer.
// A template may be executed safely in parallel, although if parallel
// executions share a Writer the output may be interleaved.
//
// If data is a reflect.Value, the template applies to the concrete
// value that the reflect.Value holds, as in fmt.Print.
func (t *Template) Execute(wr io.Writer, data any) error {
return t.execute(wr, data)
}
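// Illustrative sketch (editor's addition): a complete program using
// Execute, assuming the package's New, Must, and Parse functions:
//
//	package main
//
//	import (
//		"os"
//		"text/template"
//	)
//
//	func main() {
//		t := template.Must(template.New("greet").Parse("Hello, {{.Name}}!\n"))
//		data := struct{ Name string }{"Gopher"}
//		if err := t.Execute(os.Stdout, data); err != nil {
//			panic(err) // partial output may already have been written
//		}
//	}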
func (t *Template) execute(wr io.Writer, data any) (err error) {
defer errRecover(&err)
value, ok := data.(reflect.Value)
if !ok {
value = reflect.ValueOf(data)
}
state := &state{
tmpl: t,
wr: wr,
vars: []variable{{"$", value}},
}
if t.Tree == nil || t.Root == nil {
state.errorf("%q is an incomplete or empty template", t.Name())
}
state.walk(value, t.Root)
return
}
// DefinedTemplates returns a string listing the defined templates,
// prefixed by the string "; defined templates are: ". If there are none,
// it returns the empty string. It is used for generating an error message
// here and in html/template.
func (t *Template) DefinedTemplates() string {
if t.common == nil {
return ""
}
var b strings.Builder
t.muTmpl.RLock()
defer t.muTmpl.RUnlock()
for name, tmpl := range t.tmpl {
if tmpl.Tree == nil || tmpl.Root == nil {
continue
}
if b.Len() == 0 {
b.WriteString("; defined templates are: ")
} else {
b.WriteString(", ")
}
fmt.Fprintf(&b, "%q", name)
}
return b.String()
}
// Sentinel errors for use with panic to signal early exits from range loops.
var (
walkBreak = errors.New("break")
walkContinue = errors.New("continue")
)
// Walk functions step through the major pieces of the template structure,
// generating output as they go.
func (s *state) walk(dot reflect.Value, node parse.Node) {
s.at(node)
switch node := node.(type) {
case *parse.ActionNode:
// Do not pop variables so they persist until next end.
// Also, if the action declares variables, don't print the result.
val := s.evalPipeline(dot, node.Pipe)
if len(node.Pipe.Decl) == 0 {
s.printValue(node, val)
}
case *parse.BreakNode:
panic(walkBreak)
case *parse.CommentNode:
case *parse.ContinueNode:
panic(walkContinue)
case *parse.IfNode:
s.walkIfOrWith(parse.NodeIf, dot, node.Pipe, node.List, node.ElseList)
case *parse.ListNode:
for _, node := range node.Nodes {
s.walk(dot, node)
}
case *parse.RangeNode:
s.walkRange(dot, node)
case *parse.TemplateNode:
s.walkTemplate(dot, node)
case *parse.TextNode:
if _, err := s.wr.Write(node.Text); err != nil {
s.writeError(err)
}
case *parse.WithNode:
s.walkIfOrWith(parse.NodeWith, dot, node.Pipe, node.List, node.ElseList)
default:
s.errorf("unknown node: %s", node)
}
}
// walkIfOrWith walks an 'if' or 'with' node. The two control structures
// are identical in behavior except that 'with' sets dot.
func (s *state) walkIfOrWith(typ parse.NodeType, dot reflect.Value, pipe *parse.PipeNode, list, elseList *parse.ListNode) {
defer s.pop(s.mark())
val := s.evalPipeline(dot, pipe)
truth, ok := isTrue(indirectInterface(val))
if !ok {
s.errorf("if/with can't use %v", val)
}
if truth {
if typ == parse.NodeWith {
s.walk(val, list)
} else {
s.walk(dot, list)
}
} else if elseList != nil {
s.walk(dot, elseList)
}
}
// IsTrue reports whether the value is 'true', in the sense of not the zero of its type,
// and whether the value has a meaningful truth value. This is the definition of
// truth used by if and other such actions.
func IsTrue(val any) (truth, ok bool) {
return isTrue(reflect.ValueOf(val))
}
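// Illustrative sketch (editor's addition): IsTrue on a few values,
// matching the kinds handled by isTrue below:
//
//	fmt.Println(template.IsTrue(0))          // false true
//	fmt.Println(template.IsTrue("x"))        // true true
//	fmt.Println(template.IsTrue([]int{}))    // false true
//	fmt.Println(template.IsTrue(struct{}{})) // true true
//	fmt.Println(template.IsTrue(nil))        // false true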
func isTrue(val reflect.Value) (truth, ok bool) {
if !val.IsValid() {
// Something like var x interface{}, never set. It's a form of nil.
return false, true
}
switch val.Kind() {
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
truth = val.Len() > 0
case reflect.Bool:
truth = val.Bool()
case reflect.Complex64, reflect.Complex128:
truth = val.Complex() != 0
case reflect.Chan, reflect.Func, reflect.Pointer, reflect.Interface:
truth = !val.IsNil()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
truth = val.Int() != 0
case reflect.Float32, reflect.Float64:
truth = val.Float() != 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
truth = val.Uint() != 0
case reflect.Struct:
truth = true // Struct values are always true.
default:
return
}
return truth, true
}
func (s *state) walkRange(dot reflect.Value, r *parse.RangeNode) {
s.at(r)
defer func() {
if r := recover(); r != nil && r != walkBreak {
panic(r)
}
}()
defer s.pop(s.mark())
val, _ := indirect(s.evalPipeline(dot, r.Pipe))
// mark top of stack before any variables in the body are pushed.
mark := s.mark()
oneIteration := func(index, elem reflect.Value) {
// Set top var (lexically the second if there are two) to the element.
if len(r.Pipe.Decl) > 0 {
if r.Pipe.IsAssign {
s.setVar(r.Pipe.Decl[0].Ident[0], elem)
} else {
s.setTopVar(1, elem)
}
}
// Set next var (lexically the first if there are two) to the index.
if len(r.Pipe.Decl) > 1 {
if r.Pipe.IsAssign {
s.setVar(r.Pipe.Decl[1].Ident[0], index)
} else {
s.setTopVar(2, index)
}
}
defer s.pop(mark)
defer func() {
// Consume panic(walkContinue)
if r := recover(); r != nil && r != walkContinue {
panic(r)
}
}()
s.walk(elem, r.List)
}
switch val.Kind() {
case reflect.Array, reflect.Slice:
if val.Len() == 0 {
break
}
for i := 0; i < val.Len(); i++ {
oneIteration(reflect.ValueOf(i), val.Index(i))
}
return
case reflect.Map:
if val.Len() == 0 {
break
}
om := fmtsort.Sort(val)
for i, key := range om.Key {
oneIteration(key, om.Value[i])
}
return
case reflect.Chan:
if val.IsNil() {
break
}
if val.Type().ChanDir() == reflect.SendDir {
s.errorf("range over send-only channel %v", val)
break
}
i := 0
for ; ; i++ {
elem, ok := val.Recv()
if !ok {
break
}
oneIteration(reflect.ValueOf(i), elem)
}
if i == 0 {
break
}
return
case reflect.Invalid:
break // An invalid value is likely a nil map, etc. and acts like an empty map.
default:
s.errorf("range can't iterate over %v", val)
}
if r.ElseList != nil {
s.walk(dot, r.ElseList)
}
}
func (s *state) walkTemplate(dot reflect.Value, t *parse.TemplateNode) {
s.at(t)
tmpl := s.tmpl.Lookup(t.Name)
if tmpl == nil {
s.errorf("template %q not defined", t.Name)
}
if s.depth == maxExecDepth {
s.errorf("exceeded maximum template depth (%v)", maxExecDepth)
}
// Variables declared by the pipeline persist.
dot = s.evalPipeline(dot, t.Pipe)
newState := *s
newState.depth++
newState.tmpl = tmpl
// No dynamic scoping: template invocations inherit no variables.
newState.vars = []variable{{"$", dot}}
newState.walk(dot, tmpl.Root)
}
// Eval functions evaluate pipelines, commands, and their elements and extract
// values from the data structure by examining fields, calling methods, and so on.
// The printing of those values happens only through walk functions.
// evalPipeline returns the value acquired by evaluating a pipeline. If the
// pipeline has a variable declaration, the variable will be pushed on the
// stack. Callers should therefore pop the stack after they are finished
// executing commands depending on the pipeline value.
func (s *state) evalPipeline(dot reflect.Value, pipe *parse.PipeNode) (value reflect.Value) {
if pipe == nil {
return
}
s.at(pipe)
value = missingVal
for _, cmd := range pipe.Cmds {
value = s.evalCommand(dot, cmd, value) // previous value is this one's final arg.
// If the object has type interface{}, dig down one level to the thing inside.
if value.Kind() == reflect.Interface && value.Type().NumMethod() == 0 {
value = reflect.ValueOf(value.Interface()) // lovely!
}
}
for _, variable := range pipe.Decl {
if pipe.IsAssign {
s.setVar(variable.Ident[0], value)
} else {
s.push(variable.Ident[0], value)
}
}
return value
}
func (s *state) notAFunction(args []parse.Node, final reflect.Value) {
if len(args) > 1 || !isMissing(final) {
s.errorf("can't give argument to non-function %s", args[0])
}
}
func (s *state) evalCommand(dot reflect.Value, cmd *parse.CommandNode, final reflect.Value) reflect.Value {
firstWord := cmd.Args[0]
switch n := firstWord.(type) {
case *parse.FieldNode:
return s.evalFieldNode(dot, n, cmd.Args, final)
case *parse.ChainNode:
return s.evalChainNode(dot, n, cmd.Args, final)
case *parse.IdentifierNode:
// Must be a function.
return s.evalFunction(dot, n, cmd, cmd.Args, final)
case *parse.PipeNode:
// Parenthesized pipeline. The arguments are all inside the pipeline; final must be absent.
s.notAFunction(cmd.Args, final)
return s.evalPipeline(dot, n)
case *parse.VariableNode:
return s.evalVariableNode(dot, n, cmd.Args, final)
}
s.at(firstWord)
s.notAFunction(cmd.Args, final)
switch word := firstWord.(type) {
case *parse.BoolNode:
return reflect.ValueOf(word.True)
case *parse.DotNode:
return dot
case *parse.NilNode:
s.errorf("nil is not a command")
case *parse.NumberNode:
return s.idealConstant(word)
case *parse.StringNode:
return reflect.ValueOf(word.Text)
}
s.errorf("can't evaluate command %q", firstWord)
panic("not reached")
}
// idealConstant is called to return the value of a number in a context where
// we don't know the type. In that case, the syntax of the number tells us
// its type, and we use Go rules to resolve. Note there is no such thing as
// a uint ideal constant in this situation - the value must be of int type.
func (s *state) idealConstant(constant *parse.NumberNode) reflect.Value {
// These are ideal constants but we don't know the type
// and we have no context. (If it was a method argument,
// we'd know what we need.) The syntax guides us to some extent.
s.at(constant)
switch {
case constant.IsComplex:
return reflect.ValueOf(constant.Complex128) // incontrovertible.
case constant.IsFloat &&
!isHexInt(constant.Text) && !isRuneInt(constant.Text) &&
strings.ContainsAny(constant.Text, ".eEpP"):
return reflect.ValueOf(constant.Float64)
case constant.IsInt:
n := int(constant.Int64)
if int64(n) != constant.Int64 {
s.errorf("%s overflows int", constant.Text)
}
return reflect.ValueOf(n)
case constant.IsUint:
s.errorf("%s overflows int", constant.Text)
}
return zero
}
func isRuneInt(s string) bool {
return len(s) > 0 && s[0] == '\''
}
func isHexInt(s string) bool {
return len(s) > 2 && s[0] == '0' && (s[1] == 'x' || s[1] == 'X') && !strings.ContainsAny(s, "pP")
}
func (s *state) evalFieldNode(dot reflect.Value, field *parse.FieldNode, args []parse.Node, final reflect.Value) reflect.Value {
s.at(field)
return s.evalFieldChain(dot, dot, field, field.Ident, args, final)
}
func (s *state) evalChainNode(dot reflect.Value, chain *parse.ChainNode, args []parse.Node, final reflect.Value) reflect.Value {
s.at(chain)
if len(chain.Field) == 0 {
s.errorf("internal error: no fields in evalChainNode")
}
if chain.Node.Type() == parse.NodeNil {
s.errorf("indirection through explicit nil in %s", chain)
}
// (pipe).Field1.Field2 has pipe as .Node, fields as .Field. Eval the pipeline, then the fields.
pipe := s.evalArg(dot, nil, chain.Node)
return s.evalFieldChain(dot, pipe, chain, chain.Field, args, final)
}
func (s *state) evalVariableNode(dot reflect.Value, variable *parse.VariableNode, args []parse.Node, final reflect.Value) reflect.Value {
// $x.Field has $x as the first ident, Field as the second. Eval the var, then the fields.
s.at(variable)
value := s.varValue(variable.Ident[0])
if len(variable.Ident) == 1 {
s.notAFunction(args, final)
return value
}
return s.evalFieldChain(dot, value, variable, variable.Ident[1:], args, final)
}
// evalFieldChain evaluates .X.Y.Z possibly followed by arguments.
// dot is the environment in which to evaluate arguments, while
// receiver is the value being walked along the chain.
func (s *state) evalFieldChain(dot, receiver reflect.Value, node parse.Node, ident []string, args []parse.Node, final reflect.Value) reflect.Value {
n := len(ident)
for i := 0; i < n-1; i++ {
receiver = s.evalField(dot, ident[i], node, nil, missingVal, receiver)
}
// Now if it's a method, it gets the arguments.
return s.evalField(dot, ident[n-1], node, args, final, receiver)
}
func (s *state) evalFunction(dot reflect.Value, node *parse.IdentifierNode, cmd parse.Node, args []parse.Node, final reflect.Value) reflect.Value {
s.at(node)
name := node.Ident
function, isBuiltin, ok := findFunction(name, s.tmpl)
if !ok {
s.errorf("%q is not a defined function", name)
}
return s.evalCall(dot, function, isBuiltin, cmd, name, args, final)
}
// evalField evaluates an expression like (.Field) or (.Field arg1 arg2).
// The 'final' argument represents the return value from the preceding
// value of the pipeline, if any.
func (s *state) evalField(dot reflect.Value, fieldName string, node parse.Node, args []parse.Node, final, receiver reflect.Value) reflect.Value {
if !receiver.IsValid() {
if s.tmpl.option.missingKey == mapError { // Treat invalid value as missing map key.
s.errorf("nil data; no entry for key %q", fieldName)
}
return zero
}
typ := receiver.Type()
receiver, isNil := indirect(receiver)
if receiver.Kind() == reflect.Interface && isNil {
// Calling a method on a nil interface can't work. The
// MethodByName method call below would panic.
s.errorf("nil pointer evaluating %s.%s", typ, fieldName)
return zero
}
// Unless it's an interface, need to get to a value of type *T to guarantee
// we see all methods of T and *T.
ptr := receiver
if ptr.Kind() != reflect.Interface && ptr.Kind() != reflect.Pointer && ptr.CanAddr() {
ptr = ptr.Addr()
}
if method := ptr.MethodByName(fieldName); method.IsValid() {
return s.evalCall(dot, method, false, node, fieldName, args, final)
}
hasArgs := len(args) > 1 || !isMissing(final)
// It's not a method; must be a field of a struct or an element of a map.
switch receiver.Kind() {
case reflect.Struct:
tField, ok := receiver.Type().FieldByName(fieldName)
if ok {
field, err := receiver.FieldByIndexErr(tField.Index)
if !tField.IsExported() {
s.errorf("%s is an unexported field of struct type %s", fieldName, typ)
}
if err != nil {
s.errorf("%v", err)
}
// If it's a function, we must call it.
if hasArgs {
s.errorf("%s has arguments but cannot be invoked as function", fieldName)
}
return field
}
case reflect.Map:
// If it's a map, attempt to use the field name as a key.
nameVal := reflect.ValueOf(fieldName)
if nameVal.Type().AssignableTo(receiver.Type().Key()) {
if hasArgs {
s.errorf("%s is not a method but has arguments", fieldName)
}
result := receiver.MapIndex(nameVal)
if !result.IsValid() {
switch s.tmpl.option.missingKey {
case mapInvalid:
// Just use the invalid value.
case mapZeroValue:
result = reflect.Zero(receiver.Type().Elem())
case mapError:
s.errorf("map has no entry for key %q", fieldName)
}
}
return result
}
case reflect.Pointer:
etyp := receiver.Type().Elem()
if etyp.Kind() == reflect.Struct {
if _, ok := etyp.FieldByName(fieldName); !ok {
// If there's no such field, say "can't evaluate"
// instead of "nil pointer evaluating".
break
}
}
if isNil {
s.errorf("nil pointer evaluating %s.%s", typ, fieldName)
}
}
s.errorf("can't evaluate field %s in type %s", fieldName, typ)
panic("not reached")
}
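// For illustration, a hedged sketch of the lookup order implemented above
// (method first, then struct field, then map key):
//
//	type T struct{ Name string }
//	func (T) Upper() string { return "..." }
//
//	{{.Name}}    reads the struct field
//	{{.Upper}}   calls the method
//	{{.absent}}  on a map is governed by the missingkey option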
var (
errorType = reflect.TypeOf((*error)(nil)).Elem()
fmtStringerType = reflect.TypeOf((*fmt.Stringer)(nil)).Elem()
reflectValueType = reflect.TypeOf((*reflect.Value)(nil)).Elem()
)
// evalCall executes a function or method call. If it's a method, fun already has the receiver bound, so
// it looks just like a function call. The arg list, if non-nil, includes (in the manner of the shell)
// arg[0] as the function itself.
func (s *state) evalCall(dot, fun reflect.Value, isBuiltin bool, node parse.Node, name string, args []parse.Node, final reflect.Value) reflect.Value {
if args != nil {
args = args[1:] // Zeroth arg is function name/node; not passed to function.
}
typ := fun.Type()
numIn := len(args)
if !isMissing(final) {
numIn++
}
numFixed := len(args)
if typ.IsVariadic() {
numFixed = typ.NumIn() - 1 // last arg is the variadic one.
if numIn < numFixed {
s.errorf("wrong number of args for %s: want at least %d got %d", name, typ.NumIn()-1, len(args))
}
} else if numIn != typ.NumIn() {
s.errorf("wrong number of args for %s: want %d got %d", name, typ.NumIn(), numIn)
}
if !goodFunc(typ) {
// TODO: This could still be a confusing error; maybe goodFunc should provide info.
s.errorf("can't call method/function %q with %d results", name, typ.NumOut())
}
unwrap := func(v reflect.Value) reflect.Value {
if v.Type() == reflectValueType {
v = v.Interface().(reflect.Value)
}
return v
}
// Special case for builtin and/or, which short-circuit.
if isBuiltin && (name == "and" || name == "or") {
argType := typ.In(0)
var v reflect.Value
for _, arg := range args {
v = s.evalArg(dot, argType, arg).Interface().(reflect.Value)
if truth(v) == (name == "or") {
// This value was already unwrapped
// by the .Interface().(reflect.Value).
return v
}
}
if final != missingVal {
// The last argument to and/or is coming from
// the pipeline. We didn't short circuit on an earlier
// argument, so we are going to return this one.
// We don't have to evaluate final, but we do
// have to check its type. Then, since we are
// going to return it, we have to unwrap it.
v = unwrap(s.validateType(final, argType))
}
return v
}
// Build the arg list.
argv := make([]reflect.Value, numIn)
// Args must be evaluated. Fixed args first.
i := 0
for ; i < numFixed && i < len(args); i++ {
argv[i] = s.evalArg(dot, typ.In(i), args[i])
}
// Now the ... args.
if typ.IsVariadic() {
argType := typ.In(typ.NumIn() - 1).Elem() // Argument is a slice.
for ; i < len(args); i++ {
argv[i] = s.evalArg(dot, argType, args[i])
}
}
// Add final value if necessary.
if !isMissing(final) {
t := typ.In(typ.NumIn() - 1)
if typ.IsVariadic() {
if numIn-1 < numFixed {
// The added final argument corresponds to a fixed parameter of the function.
// Validate against the type of the actual parameter.
t = typ.In(numIn - 1)
} else {
// The added final argument corresponds to the variadic part.
// Validate against the type of the elements of the variadic slice.
t = t.Elem()
}
}
argv[i] = s.validateType(final, t)
}
v, err := safeCall(fun, argv)
// If we have an error that is not nil, stop execution and return that
// error to the caller.
if err != nil {
s.at(node)
s.errorf("error calling %s: %w", name, err)
}
return unwrap(v)
}
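// For illustration, a sketch of the short-circuit behavior the special case
// above provides:
//
//	{{and .A .B}}   evaluates .B only if .A is true; yields the first
//	                false argument, or the last argument
//	{{or .A .B}}    evaluates .B only if .A is false; yields the first
//	                true argument, or the last argument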
// canBeNil reports whether an untyped nil can be assigned to the type. See reflect.Zero.
func canBeNil(typ reflect.Type) bool {
switch typ.Kind() {
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
return true
case reflect.Struct:
return typ == reflectValueType
}
return false
}
// validateType guarantees that the value is valid and assignable to the type.
func (s *state) validateType(value reflect.Value, typ reflect.Type) reflect.Value {
if !value.IsValid() {
if typ == nil {
// An untyped nil interface{}. Accept as a proper nil value.
return reflect.ValueOf(nil)
}
if canBeNil(typ) {
// Like above, but use the zero value of the non-nil type.
return reflect.Zero(typ)
}
s.errorf("invalid value; expected %s", typ)
}
if typ == reflectValueType && value.Type() != typ {
return reflect.ValueOf(value)
}
if typ != nil && !value.Type().AssignableTo(typ) {
if value.Kind() == reflect.Interface && !value.IsNil() {
value = value.Elem()
if value.Type().AssignableTo(typ) {
return value
}
// fallthrough
}
// Does one dereference or indirection work? We could do more, as we
// do with method receivers, but that gets messy and method receivers
// are much more constrained, so it makes more sense there than here.
// Besides, one is almost always all you need.
switch {
case value.Kind() == reflect.Pointer && value.Type().Elem().AssignableTo(typ):
value = value.Elem()
if !value.IsValid() {
s.errorf("dereference of nil pointer of type %s", typ)
}
case reflect.PointerTo(value.Type()).AssignableTo(typ) && value.CanAddr():
value = value.Addr()
default:
s.errorf("wrong type for value; expected %s; got %s", typ, value.Type())
}
}
return value
}
func (s *state) evalArg(dot reflect.Value, typ reflect.Type, n parse.Node) reflect.Value {
s.at(n)
switch arg := n.(type) {
case *parse.DotNode:
return s.validateType(dot, typ)
case *parse.NilNode:
if canBeNil(typ) {
return reflect.Zero(typ)
}
s.errorf("cannot assign nil to %s", typ)
case *parse.FieldNode:
return s.validateType(s.evalFieldNode(dot, arg, []parse.Node{n}, missingVal), typ)
case *parse.VariableNode:
return s.validateType(s.evalVariableNode(dot, arg, nil, missingVal), typ)
case *parse.PipeNode:
return s.validateType(s.evalPipeline(dot, arg), typ)
case *parse.IdentifierNode:
return s.validateType(s.evalFunction(dot, arg, arg, nil, missingVal), typ)
case *parse.ChainNode:
return s.validateType(s.evalChainNode(dot, arg, nil, missingVal), typ)
}
switch typ.Kind() {
case reflect.Bool:
return s.evalBool(typ, n)
case reflect.Complex64, reflect.Complex128:
return s.evalComplex(typ, n)
case reflect.Float32, reflect.Float64:
return s.evalFloat(typ, n)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return s.evalInteger(typ, n)
case reflect.Interface:
if typ.NumMethod() == 0 {
return s.evalEmptyInterface(dot, n)
}
case reflect.Struct:
if typ == reflectValueType {
return reflect.ValueOf(s.evalEmptyInterface(dot, n))
}
case reflect.String:
return s.evalString(typ, n)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return s.evalUnsignedInteger(typ, n)
}
s.errorf("can't handle %s for arg of type %s", n, typ)
panic("not reached")
}
func (s *state) evalBool(typ reflect.Type, n parse.Node) reflect.Value {
s.at(n)
if n, ok := n.(*parse.BoolNode); ok {
value := reflect.New(typ).Elem()
value.SetBool(n.True)
return value
}
s.errorf("expected bool; found %s", n)
panic("not reached")
}
func (s *state) evalString(typ reflect.Type, n parse.Node) reflect.Value {
s.at(n)
if n, ok := n.(*parse.StringNode); ok {
value := reflect.New(typ).Elem()
value.SetString(n.Text)
return value
}
s.errorf("expected string; found %s", n)
panic("not reached")
}
func (s *state) evalInteger(typ reflect.Type, n parse.Node) reflect.Value {
s.at(n)
if n, ok := n.(*parse.NumberNode); ok && n.IsInt {
value := reflect.New(typ).Elem()
value.SetInt(n.Int64)
return value
}
s.errorf("expected integer; found %s", n)
panic("not reached")
}
func (s *state) evalUnsignedInteger(typ reflect.Type, n parse.Node) reflect.Value {
s.at(n)
if n, ok := n.(*parse.NumberNode); ok && n.IsUint {
value := reflect.New(typ).Elem()
value.SetUint(n.Uint64)
return value
}
s.errorf("expected unsigned integer; found %s", n)
panic("not reached")
}
func (s *state) evalFloat(typ reflect.Type, n parse.Node) reflect.Value {
s.at(n)
if n, ok := n.(*parse.NumberNode); ok && n.IsFloat {
value := reflect.New(typ).Elem()
value.SetFloat(n.Float64)
return value
}
s.errorf("expected float; found %s", n)
panic("not reached")
}
func (s *state) evalComplex(typ reflect.Type, n parse.Node) reflect.Value {
if n, ok := n.(*parse.NumberNode); ok && n.IsComplex {
value := reflect.New(typ).Elem()
value.SetComplex(n.Complex128)
return value
}
s.errorf("expected complex; found %s", n)
panic("not reached")
}
func (s *state) evalEmptyInterface(dot reflect.Value, n parse.Node) reflect.Value {
s.at(n)
switch n := n.(type) {
case *parse.BoolNode:
return reflect.ValueOf(n.True)
case *parse.DotNode:
return dot
case *parse.FieldNode:
return s.evalFieldNode(dot, n, nil, missingVal)
case *parse.IdentifierNode:
return s.evalFunction(dot, n, n, nil, missingVal)
case *parse.NilNode:
// NilNode is handled in evalArg, the only place that calls here.
s.errorf("evalEmptyInterface: nil (can't happen)")
case *parse.NumberNode:
return s.idealConstant(n)
case *parse.StringNode:
return reflect.ValueOf(n.Text)
case *parse.VariableNode:
return s.evalVariableNode(dot, n, nil, missingVal)
case *parse.PipeNode:
return s.evalPipeline(dot, n)
}
s.errorf("can't handle assignment of %s to empty interface argument", n)
panic("not reached")
}
// indirect returns the item at the end of indirection, and a bool to indicate
// if it's nil. If the returned bool is true, the returned value's kind will be
// either a pointer or interface.
func indirect(v reflect.Value) (rv reflect.Value, isNil bool) {
for ; v.Kind() == reflect.Pointer || v.Kind() == reflect.Interface; v = v.Elem() {
if v.IsNil() {
return v, true
}
}
return v, false
}
// indirectInterface returns the concrete value in an interface value,
// or else the zero reflect.Value.
// That is, if v represents the interface value x, the result is the same as reflect.ValueOf(x):
// the fact that x was an interface value is forgotten.
func indirectInterface(v reflect.Value) reflect.Value {
if v.Kind() != reflect.Interface {
return v
}
if v.IsNil() {
return reflect.Value{}
}
return v.Elem()
}
// printValue writes the textual representation of the value to the output of
// the template.
func (s *state) printValue(n parse.Node, v reflect.Value) {
s.at(n)
iface, ok := printableValue(v)
if !ok {
s.errorf("can't print %s of type %s", n, v.Type())
}
_, err := fmt.Fprint(s.wr, iface)
if err != nil {
s.writeError(err)
}
}
// printableValue returns the (possibly indirected) interface value inside v
// that is best for a call to a formatted printer.
func printableValue(v reflect.Value) (any, bool) {
if v.Kind() == reflect.Pointer {
v, _ = indirect(v) // fmt.Fprint handles nil.
}
if !v.IsValid() {
return "<no value>", true
}
if !v.Type().Implements(errorType) && !v.Type().Implements(fmtStringerType) {
if v.CanAddr() && (reflect.PointerTo(v.Type()).Implements(errorType) || reflect.PointerTo(v.Type()).Implements(fmtStringerType)) {
v = v.Addr()
} else {
switch v.Kind() {
case reflect.Chan, reflect.Func:
return nil, false
}
}
}
return v.Interface(), true
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package template
import (
"errors"
"fmt"
"io"
"net/url"
"reflect"
"strings"
"sync"
"unicode"
"unicode/utf8"
)
// FuncMap is the type of the map defining the mapping from names to functions.
// Each function must have either a single return value, or two return values of
// which the second has type error. In that case, if the second (error)
// return value evaluates to non-nil during execution, execution terminates and
// Execute returns that error.
//
// Errors returned by Execute wrap the underlying error; call errors.As to
// uncover them.
//
// When template execution invokes a function with an argument list, that list
// must be assignable to the function's parameter types. Functions meant to
// apply to arguments of arbitrary type can use parameters of type interface{} or
// of type reflect.Value. Similarly, functions meant to return a result of arbitrary
// type can return interface{} or reflect.Value.
type FuncMap map[string]any
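// For illustration, a minimal sketch of a FuncMap obeying the rules above
// (the names "upper" and "atoi" are arbitrary):
//
//	var funcs = template.FuncMap{
//		"upper": strings.ToUpper, // single result
//		"atoi":  strconv.Atoi,    // (value, error) pair
//	}
//	var t = template.Must(template.New("t").Funcs(funcs).Parse(`{{upper .}}`))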
// builtins returns the FuncMap.
// It is not a global variable so the linker can dead code eliminate
// more when this isn't called. See golang.org/issue/36021.
// TODO: revert this back to a global map once golang.org/issue/2559 is fixed.
func builtins() FuncMap {
return FuncMap{
"and": and,
"call": call,
"html": HTMLEscaper,
"index": index,
"slice": slice,
"js": JSEscaper,
"len": length,
"not": not,
"or": or,
"print": fmt.Sprint,
"printf": fmt.Sprintf,
"println": fmt.Sprintln,
"urlquery": URLQueryEscaper,
// Comparisons
"eq": eq, // ==
"ge": ge, // >=
"gt": gt, // >
"le": le, // <=
"lt": lt, // <
"ne": ne, // !=
}
}
var builtinFuncsOnce struct {
sync.Once
v map[string]reflect.Value
}
// builtinFuncsOnce lazily computes & caches the builtinFuncs map.
// TODO: revert this back to a global map once golang.org/issue/2559 is fixed.
func builtinFuncs() map[string]reflect.Value {
builtinFuncsOnce.Do(func() {
builtinFuncsOnce.v = createValueFuncs(builtins())
})
return builtinFuncsOnce.v
}
// createValueFuncs turns a FuncMap into a map[string]reflect.Value
func createValueFuncs(funcMap FuncMap) map[string]reflect.Value {
m := make(map[string]reflect.Value)
addValueFuncs(m, funcMap)
return m
}
// addValueFuncs adds to out the functions in in, converting them to reflect.Values.
func addValueFuncs(out map[string]reflect.Value, in FuncMap) {
for name, fn := range in {
if !goodName(name) {
panic(fmt.Errorf("function name %q is not a valid identifier", name))
}
v := reflect.ValueOf(fn)
if v.Kind() != reflect.Func {
panic("value for " + name + " not a function")
}
if !goodFunc(v.Type()) {
panic(fmt.Errorf("can't install method/function %q with %d results", name, v.Type().NumOut()))
}
out[name] = v
}
}
// addFuncs adds to out the functions in in. It does no checking of the input -
// call addValueFuncs first.
func addFuncs(out, in FuncMap) {
for name, fn := range in {
out[name] = fn
}
}
// goodFunc reports whether the function or method has the right result signature.
func goodFunc(typ reflect.Type) bool {
// We allow functions with 1 result or 2 results where the second is an error.
switch {
case typ.NumOut() == 1:
return true
case typ.NumOut() == 2 && typ.Out(1) == errorType:
return true
}
return false
}
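// For illustration, result shapes as goodFunc classifies them (a sketch):
//
//	func(string) string         // 1 result: accepted
//	func(string) (int, error)   // 2 results, second is error: accepted
//	func() (int, int)           // 2 non-error results: rejected
//	func()                      // no results: rejected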
// goodName reports whether the function name is a valid identifier.
func goodName(name string) bool {
if name == "" {
return false
}
for i, r := range name {
switch {
case r == '_':
case i == 0 && !unicode.IsLetter(r):
return false
case !unicode.IsLetter(r) && !unicode.IsDigit(r):
return false
}
}
return true
}
// findFunction looks for a function in the template's function map and then in the global map.
func findFunction(name string, tmpl *Template) (v reflect.Value, isBuiltin, ok bool) {
if tmpl != nil && tmpl.common != nil {
tmpl.muFuncs.RLock()
defer tmpl.muFuncs.RUnlock()
if fn := tmpl.execFuncs[name]; fn.IsValid() {
return fn, false, true
}
}
if fn := builtinFuncs()[name]; fn.IsValid() {
return fn, true, true
}
return reflect.Value{}, false, false
}
// prepareArg checks if value can be used as an argument of type argType, and
// converts an invalid value to the appropriate zero if possible.
func prepareArg(value reflect.Value, argType reflect.Type) (reflect.Value, error) {
if !value.IsValid() {
if !canBeNil(argType) {
return reflect.Value{}, fmt.Errorf("value is nil; should be of type %s", argType)
}
value = reflect.Zero(argType)
}
if value.Type().AssignableTo(argType) {
return value, nil
}
if intLike(value.Kind()) && intLike(argType.Kind()) && value.Type().ConvertibleTo(argType) {
value = value.Convert(argType)
return value, nil
}
return reflect.Value{}, fmt.Errorf("value has type %s; should be %s", value.Type(), argType)
}
func intLike(typ reflect.Kind) bool {
switch typ {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return true
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return true
}
return false
}
// indexArg checks if a reflect.Value can be used as an index, and converts it to int if possible.
func indexArg(index reflect.Value, cap int) (int, error) {
var x int64
switch index.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
x = index.Int()
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
x = int64(index.Uint())
case reflect.Invalid:
return 0, fmt.Errorf("cannot index slice/array with nil")
default:
return 0, fmt.Errorf("cannot index slice/array with type %s", index.Type())
}
if x < 0 || int(x) < 0 || int(x) > cap {
return 0, fmt.Errorf("index out of range: %d", x)
}
return int(x), nil
}
// Indexing.
// index returns the result of indexing its first argument by the following
// arguments. Thus "index x 1 2 3" is, in Go syntax, x[1][2][3]. Each
// indexed item must be a map, slice, or array.
func index(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
item = indirectInterface(item)
if !item.IsValid() {
return reflect.Value{}, fmt.Errorf("index of untyped nil")
}
for _, index := range indexes {
index = indirectInterface(index)
var isNil bool
if item, isNil = indirect(item); isNil {
return reflect.Value{}, fmt.Errorf("index of nil pointer")
}
switch item.Kind() {
case reflect.Array, reflect.Slice, reflect.String:
x, err := indexArg(index, item.Len())
if err != nil {
return reflect.Value{}, err
}
item = item.Index(x)
case reflect.Map:
index, err := prepareArg(index, item.Type().Key())
if err != nil {
return reflect.Value{}, err
}
if x := item.MapIndex(index); x.IsValid() {
item = x
} else {
item = reflect.Zero(item.Type().Elem())
}
case reflect.Invalid:
// unreachable: the loop invariant guarantees item.IsValid()
panic("unreachable")
default:
return reflect.Value{}, fmt.Errorf("can't index item of type %s", item.Type())
}
}
return item, nil
}
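// For illustration, a sketch of the builtin as seen from template text:
//
//	{{index .Slice 2}}     is .Slice[2]
//	{{index .Map "k"}}     is .Map["k"], or the zero element if absent
//	{{index .Matrix 1 2}}  is .Matrix[1][2]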
// Slicing.
// slice returns the result of slicing its first argument by the remaining
// arguments. Thus "slice x 1 2" is, in Go syntax, x[1:2], while "slice x"
// is x[:], "slice x 1" is x[1:], and "slice x 1 2 3" is x[1:2:3]. The first
// argument must be a string, slice, or array.
func slice(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
item = indirectInterface(item)
if !item.IsValid() {
return reflect.Value{}, fmt.Errorf("slice of untyped nil")
}
if len(indexes) > 3 {
return reflect.Value{}, fmt.Errorf("too many slice indexes: %d", len(indexes))
}
var cap int
switch item.Kind() {
case reflect.String:
if len(indexes) == 3 {
return reflect.Value{}, fmt.Errorf("cannot 3-index slice a string")
}
cap = item.Len()
case reflect.Array, reflect.Slice:
cap = item.Cap()
default:
return reflect.Value{}, fmt.Errorf("can't slice item of type %s", item.Type())
}
// set default values for cases item[:], item[i:].
idx := [3]int{0, item.Len()}
for i, index := range indexes {
x, err := indexArg(index, cap)
if err != nil {
return reflect.Value{}, err
}
idx[i] = x
}
// given item[i:j], make sure i <= j.
if idx[0] > idx[1] {
return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[0], idx[1])
}
if len(indexes) < 3 {
return item.Slice(idx[0], idx[1]), nil
}
// given item[i:j:k], make sure i <= j <= k.
if idx[1] > idx[2] {
return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[1], idx[2])
}
return item.Slice3(idx[0], idx[1], idx[2]), nil
}
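// For illustration, a sketch of the builtin as seen from template text:
//
//	{{slice "hello" 1 3}}  is "hello"[1:3], that is, "el"
//	{{slice .S}}           is .S[:]
//	{{slice .S 1 2 3}}     is .S[1:2:3] (rejected for strings, as above)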
// Length
// length returns the length of the item, with an error if it has no defined length.
func length(item reflect.Value) (int, error) {
item, isNil := indirect(item)
if isNil {
return 0, fmt.Errorf("len of nil pointer")
}
switch item.Kind() {
case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String:
return item.Len(), nil
}
return 0, fmt.Errorf("len of type %s", item.Type())
}
// Function invocation
// call returns the result of evaluating the first argument as a function.
// The function must return 1 result, or 2 results, the second of which is an error.
func call(fn reflect.Value, args ...reflect.Value) (reflect.Value, error) {
fn = indirectInterface(fn)
if !fn.IsValid() {
return reflect.Value{}, fmt.Errorf("call of nil")
}
typ := fn.Type()
if typ.Kind() != reflect.Func {
return reflect.Value{}, fmt.Errorf("non-function of type %s", typ)
}
if !goodFunc(typ) {
return reflect.Value{}, fmt.Errorf("function called with %d args; should be 1 or 2", typ.NumOut())
}
numIn := typ.NumIn()
var dddType reflect.Type
if typ.IsVariadic() {
if len(args) < numIn-1 {
return reflect.Value{}, fmt.Errorf("wrong number of args: got %d want at least %d", len(args), numIn-1)
}
dddType = typ.In(numIn - 1).Elem()
} else {
if len(args) != numIn {
return reflect.Value{}, fmt.Errorf("wrong number of args: got %d want %d", len(args), numIn)
}
}
argv := make([]reflect.Value, len(args))
for i, arg := range args {
arg = indirectInterface(arg)
// Compute the expected type. Clumsy because of variadics.
argType := dddType
if !typ.IsVariadic() || i < numIn-1 {
argType = typ.In(i)
}
var err error
if argv[i], err = prepareArg(arg, argType); err != nil {
return reflect.Value{}, fmt.Errorf("arg %d: %w", i, err)
}
}
return safeCall(fn, argv)
}
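// For illustration, assuming dot has a field F of type func(int) string
// (a hedged sketch):
//
//	{{call .F 42}}  invokes the function value with the argument 42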
// safeCall runs fun.Call(args), and returns the resulting value and error, if
// any. If the call panics, the panic value is returned as an error.
func safeCall(fun reflect.Value, args []reflect.Value) (val reflect.Value, err error) {
defer func() {
if r := recover(); r != nil {
if e, ok := r.(error); ok {
err = e
} else {
err = fmt.Errorf("%v", r)
}
}
}()
ret := fun.Call(args)
if len(ret) == 2 && !ret[1].IsNil() {
return ret[0], ret[1].Interface().(error)
}
return ret[0], nil
}
// Boolean logic.
func truth(arg reflect.Value) bool {
t, _ := isTrue(indirectInterface(arg))
return t
}
// and computes the Boolean AND of its arguments, returning
// the first false argument it encounters, or the last argument.
func and(arg0 reflect.Value, args ...reflect.Value) reflect.Value {
panic("unreachable") // implemented as a special case in evalCall
}
// or computes the Boolean OR of its arguments, returning
// the first true argument it encounters, or the last argument.
func or(arg0 reflect.Value, args ...reflect.Value) reflect.Value {
panic("unreachable") // implemented as a special case in evalCall
}
// not returns the Boolean negation of its argument.
func not(arg reflect.Value) bool {
return !truth(arg)
}
// Comparison.
// TODO: Perhaps allow comparison between signed and unsigned integers.
var (
errBadComparisonType = errors.New("invalid type for comparison")
errBadComparison = errors.New("incompatible types for comparison")
errNoComparison = errors.New("missing argument for comparison")
)
type kind int
const (
invalidKind kind = iota
boolKind
complexKind
intKind
floatKind
stringKind
uintKind
)
func basicKind(v reflect.Value) (kind, error) {
switch v.Kind() {
case reflect.Bool:
return boolKind, nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return intKind, nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return uintKind, nil
case reflect.Float32, reflect.Float64:
return floatKind, nil
case reflect.Complex64, reflect.Complex128:
return complexKind, nil
case reflect.String:
return stringKind, nil
}
return invalidKind, errBadComparisonType
}
// isNil returns true if v is the zero reflect.Value, or nil of its type.
func isNil(v reflect.Value) bool {
if !v.IsValid() {
return true
}
switch v.Kind() {
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
return v.IsNil()
}
return false
}
// canCompare reports whether v1 and v2 are both the same kind, or one is nil.
// Called only when dealing with nillable types, or there's about to be an error.
func canCompare(v1, v2 reflect.Value) bool {
k1 := v1.Kind()
k2 := v2.Kind()
if k1 == k2 {
return true
}
// We know the type can be compared to nil.
return k1 == reflect.Invalid || k2 == reflect.Invalid
}
// eq evaluates the comparison a == b || a == c || ...
func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) {
arg1 = indirectInterface(arg1)
if len(arg2) == 0 {
return false, errNoComparison
}
k1, _ := basicKind(arg1)
for _, arg := range arg2 {
arg = indirectInterface(arg)
k2, _ := basicKind(arg)
truth := false
if k1 != k2 {
// Special case: Can compare integer values regardless of type's sign.
switch {
case k1 == intKind && k2 == uintKind:
truth = arg1.Int() >= 0 && uint64(arg1.Int()) == arg.Uint()
case k1 == uintKind && k2 == intKind:
truth = arg.Int() >= 0 && arg1.Uint() == uint64(arg.Int())
default:
if arg1 != zero && arg != zero {
return false, errBadComparison
}
}
} else {
switch k1 {
case boolKind:
truth = arg1.Bool() == arg.Bool()
case complexKind:
truth = arg1.Complex() == arg.Complex()
case floatKind:
truth = arg1.Float() == arg.Float()
case intKind:
truth = arg1.Int() == arg.Int()
case stringKind:
truth = arg1.String() == arg.String()
case uintKind:
truth = arg1.Uint() == arg.Uint()
default:
if !canCompare(arg1, arg) {
return false, fmt.Errorf("non-comparable types %s: %v, %s: %v", arg1, arg1.Type(), arg.Type(), arg)
}
if isNil(arg1) || isNil(arg) {
truth = isNil(arg) == isNil(arg1)
} else {
if !arg.Type().Comparable() {
return false, fmt.Errorf("non-comparable type %s: %v", arg, arg.Type())
}
truth = arg1.Interface() == arg.Interface()
}
}
}
if truth {
return true, nil
}
}
return false, nil
}
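// For illustration, a sketch of eq's behavior as coded above:
//
//	{{eq .X 1 2 3}}  is true if .X equals any of 1, 2, or 3
//
// and the mixed-sign special case means eq of int(-1) and a huge uint is
// simply false rather than a wraparound match.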
// ne evaluates the comparison a != b.
func ne(arg1, arg2 reflect.Value) (bool, error) {
// != is the inverse of ==.
equal, err := eq(arg1, arg2)
return !equal, err
}
// lt evaluates the comparison a < b.
func lt(arg1, arg2 reflect.Value) (bool, error) {
arg1 = indirectInterface(arg1)
k1, err := basicKind(arg1)
if err != nil {
return false, err
}
arg2 = indirectInterface(arg2)
k2, err := basicKind(arg2)
if err != nil {
return false, err
}
truth := false
if k1 != k2 {
// Special case: Can compare integer values regardless of type's sign.
switch {
case k1 == intKind && k2 == uintKind:
truth = arg1.Int() < 0 || uint64(arg1.Int()) < arg2.Uint()
case k1 == uintKind && k2 == intKind:
truth = arg2.Int() >= 0 && arg1.Uint() < uint64(arg2.Int())
default:
return false, errBadComparison
}
} else {
switch k1 {
case boolKind, complexKind:
return false, errBadComparisonType
case floatKind:
truth = arg1.Float() < arg2.Float()
case intKind:
truth = arg1.Int() < arg2.Int()
case stringKind:
truth = arg1.String() < arg2.String()
case uintKind:
truth = arg1.Uint() < arg2.Uint()
default:
panic("invalid kind")
}
}
return truth, nil
}
// le evaluates the comparison a <= b.
func le(arg1, arg2 reflect.Value) (bool, error) {
// <= is < or ==.
lessThan, err := lt(arg1, arg2)
if lessThan || err != nil {
return lessThan, err
}
return eq(arg1, arg2)
}
// gt evaluates the comparison a > b.
func gt(arg1, arg2 reflect.Value) (bool, error) {
// > is the inverse of <=.
lessOrEqual, err := le(arg1, arg2)
if err != nil {
return false, err
}
return !lessOrEqual, nil
}
// ge evaluates the comparison a >= b.
func ge(arg1, arg2 reflect.Value) (bool, error) {
// >= is the inverse of <.
lessThan, err := lt(arg1, arg2)
if err != nil {
return false, err
}
return !lessThan, nil
}
// HTML escaping.
var (
htmlQuot = []byte("&#34;") // shorter than "&quot;"
htmlApos = []byte("&#39;") // shorter than "&apos;" and apos was not in HTML until HTML5
htmlAmp = []byte("&amp;")
htmlLt = []byte("&lt;")
htmlGt = []byte("&gt;")
htmlNull = []byte("\uFFFD")
)
// HTMLEscape writes to w the escaped HTML equivalent of the plain text data b.
func HTMLEscape(w io.Writer, b []byte) {
last := 0
for i, c := range b {
var html []byte
switch c {
case '\000':
html = htmlNull
case '"':
html = htmlQuot
case '\'':
html = htmlApos
case '&':
html = htmlAmp
case '<':
html = htmlLt
case '>':
html = htmlGt
default:
continue
}
w.Write(b[last:i])
w.Write(html)
last = i + 1
}
w.Write(b[last:])
}
// HTMLEscapeString returns the escaped HTML equivalent of the plain text data s.
func HTMLEscapeString(s string) string {
// Avoid allocation if we can.
if !strings.ContainsAny(s, "'\"&<>\000") {
return s
}
var b strings.Builder
HTMLEscape(&b, []byte(s))
return b.String()
}
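// For illustration (a sketch):
//
//	HTMLEscapeString(`<b>"quoted"</b>`)
//	// returns `&lt;b&gt;&#34;quoted&#34;&lt;/b&gt;`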
// HTMLEscaper returns the escaped HTML equivalent of the textual
// representation of its arguments.
func HTMLEscaper(args ...any) string {
return HTMLEscapeString(evalArgs(args))
}
// JavaScript escaping.
var (
jsLowUni = []byte(`\u00`)
hex = []byte("0123456789ABCDEF")
jsBackslash = []byte(`\\`)
jsApos = []byte(`\'`)
jsQuot = []byte(`\"`)
jsLt = []byte(`\u003C`)
jsGt = []byte(`\u003E`)
jsAmp = []byte(`\u0026`)
jsEq = []byte(`\u003D`)
)
// JSEscape writes to w the escaped JavaScript equivalent of the plain text data b.
func JSEscape(w io.Writer, b []byte) {
last := 0
for i := 0; i < len(b); i++ {
c := b[i]
if !jsIsSpecial(rune(c)) {
// fast path: nothing to do
continue
}
w.Write(b[last:i])
if c < utf8.RuneSelf {
// Quotes, slashes and angle brackets get quoted.
// Control characters get written as \u00XX.
switch c {
case '\\':
w.Write(jsBackslash)
case '\'':
w.Write(jsApos)
case '"':
w.Write(jsQuot)
case '<':
w.Write(jsLt)
case '>':
w.Write(jsGt)
case '&':
w.Write(jsAmp)
case '=':
w.Write(jsEq)
default:
w.Write(jsLowUni)
t, b := c>>4, c&0x0f
w.Write(hex[t : t+1])
w.Write(hex[b : b+1])
}
} else {
// Unicode rune.
r, size := utf8.DecodeRune(b[i:])
if unicode.IsPrint(r) {
w.Write(b[i : i+size])
} else {
fmt.Fprintf(w, "\\u%04X", r)
}
i += size - 1
}
last = i + 1
}
w.Write(b[last:])
}
// JSEscapeString returns the escaped JavaScript equivalent of the plain text data s.
func JSEscapeString(s string) string {
// Avoid allocation if we can.
if strings.IndexFunc(s, jsIsSpecial) < 0 {
return s
}
var b strings.Builder
JSEscape(&b, []byte(s))
return b.String()
}
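// For illustration (a sketch):
//
//	JSEscapeString(`<tag> 'single'`)
//	// returns `\u003Ctag\u003E \'single\'`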
func jsIsSpecial(r rune) bool {
switch r {
case '\\', '\'', '"', '<', '>', '&', '=':
return true
}
return r < ' ' || utf8.RuneSelf <= r
}
// JSEscaper returns the escaped JavaScript equivalent of the textual
// representation of its arguments.
func JSEscaper(args ...any) string {
return JSEscapeString(evalArgs(args))
}
// URLQueryEscaper returns the escaped value of the textual representation of
// its arguments in a form suitable for embedding in a URL query.
func URLQueryEscaper(args ...any) string {
return url.QueryEscape(evalArgs(args))
}
// evalArgs formats the list of arguments into a string. It is therefore equivalent to
//
// fmt.Sprint(args...)
//
// except that each argument is indirected (if a pointer), as required,
// using the same rules as the default string evaluation during template
// execution.
func evalArgs(args []any) string {
ok := false
var s string
// Fast path for simple common case.
if len(args) == 1 {
s, ok = args[0].(string)
}
if !ok {
for i, arg := range args {
a, ok := printableValue(reflect.ValueOf(arg))
if ok {
args[i] = a
} // else let fmt do its thing
}
s = fmt.Sprint(args...)
}
return s
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Helper functions to make constructing templates easier.
package template
import (
"fmt"
"io/fs"
"os"
"path"
"path/filepath"
)
// Functions and methods to parse templates.
// Must is a helper that wraps a call to a function returning (*Template, error)
// and panics if the error is non-nil. It is intended for use in variable
// initializations such as
//
// var t = template.Must(template.New("name").Parse("text"))
func Must(t *Template, err error) *Template {
if err != nil {
panic(err)
}
return t
}
// ParseFiles creates a new Template and parses the template definitions from
// the named files. The returned template's name will have the base name and
// parsed contents of the first file. There must be at least one file.
// If an error occurs, parsing stops and the returned *Template is nil.
//
// When parsing multiple files with the same name in different directories,
// the last one mentioned will be the one that results.
// For instance, ParseFiles("a/foo", "b/foo") stores "b/foo" as the template
// named "foo", while "a/foo" is unavailable.
func ParseFiles(filenames ...string) (*Template, error) {
return parseFiles(nil, readFileOS, filenames...)
}
// ParseFiles parses the named files and associates the resulting templates with
// t. If an error occurs, parsing stops and the returned template is nil;
// otherwise it is t. There must be at least one file.
// Since the templates created by ParseFiles are named by the base
// names of the argument files, t should usually have the name of one
// of the (base) names of the files. If it does not, depending on t's
// contents before calling ParseFiles, t.Execute may fail. In that
// case use t.ExecuteTemplate to execute a valid template.
//
// When parsing multiple files with the same name in different directories,
// the last one mentioned will be the one that results.
func (t *Template) ParseFiles(filenames ...string) (*Template, error) {
t.init()
return parseFiles(t, readFileOS, filenames...)
}
// parseFiles is the helper for the method and function. If the argument
// template is nil, it is created from the first file.
func parseFiles(t *Template, readFile func(string) (string, []byte, error), filenames ...string) (*Template, error) {
if len(filenames) == 0 {
// Not really a problem, but be consistent.
return nil, fmt.Errorf("template: no files named in call to ParseFiles")
}
for _, filename := range filenames {
name, b, err := readFile(filename)
if err != nil {
return nil, err
}
s := string(b)
// First template becomes return value if not already defined,
// and we use that one for subsequent New calls to associate
// all the templates together. Also, if this file has the same name
// as t, this file becomes the contents of t, so
// t, err := New(name).Funcs(xxx).ParseFiles(name)
// works. Otherwise we create a new template associated with t.
var tmpl *Template
if t == nil {
t = New(name)
}
if name == t.Name() {
tmpl = t
} else {
tmpl = t.New(name)
}
_, err = tmpl.Parse(s)
if err != nil {
return nil, err
}
}
return t, nil
}
// ParseGlob creates a new Template and parses the template definitions from
// the files identified by the pattern. The files are matched according to the
// semantics of filepath.Match, and the pattern must match at least one file.
// The returned template will have the (base) name and (parsed) contents of the
// first file matched by the pattern. ParseGlob is equivalent to calling
// ParseFiles with the list of files matched by the pattern.
//
// When parsing multiple files with the same name in different directories,
// the last one mentioned will be the one that results.
func ParseGlob(pattern string) (*Template, error) {
return parseGlob(nil, pattern)
}
// ParseGlob parses the template definitions in the files identified by the
// pattern and associates the resulting templates with t. The files are matched
// according to the semantics of filepath.Match, and the pattern must match at
// least one file. ParseGlob is equivalent to calling t.ParseFiles with the
// list of files matched by the pattern.
//
// When parsing multiple files with the same name in different directories,
// the last one mentioned will be the one that results.
func (t *Template) ParseGlob(pattern string) (*Template, error) {
t.init()
return parseGlob(t, pattern)
}
// parseGlob is the implementation of the function and method ParseGlob.
func parseGlob(t *Template, pattern string) (*Template, error) {
filenames, err := filepath.Glob(pattern)
if err != nil {
return nil, err
}
if len(filenames) == 0 {
return nil, fmt.Errorf("template: pattern matches no files: %#q", pattern)
}
return parseFiles(t, readFileOS, filenames...)
}
// ParseFS is like ParseFiles or ParseGlob but reads from the file system fsys
// instead of the host operating system's file system.
// It accepts a list of glob patterns.
// (Note that most file names serve as glob patterns matching only themselves.)
func ParseFS(fsys fs.FS, patterns ...string) (*Template, error) {
return parseFS(nil, fsys, patterns)
}
// ParseFS is like ParseFiles or ParseGlob but reads from the file system fsys
// instead of the host operating system's file system.
// It accepts a list of glob patterns.
// (Note that most file names serve as glob patterns matching only themselves.)
func (t *Template) ParseFS(fsys fs.FS, patterns ...string) (*Template, error) {
t.init()
return parseFS(t, fsys, patterns)
}
func parseFS(t *Template, fsys fs.FS, patterns []string) (*Template, error) {
var filenames []string
for _, pattern := range patterns {
list, err := fs.Glob(fsys, pattern)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, fmt.Errorf("template: pattern matches no files: %#q", pattern)
}
filenames = append(filenames, list...)
}
return parseFiles(t, readFileFS(fsys), filenames...)
}
func readFileOS(file string) (name string, b []byte, err error) {
name = filepath.Base(file)
b, err = os.ReadFile(file)
return
}
func readFileFS(fsys fs.FS) func(string) (string, []byte, error) {
return func(file string) (name string, b []byte, err error) {
name = path.Base(file)
b, err = fs.ReadFile(fsys, file)
return
}
}
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file contains the code to handle template options.
package template
import "strings"
// missingKeyAction defines how to respond to indexing a map with a key that is not present.
type missingKeyAction int
const (
mapInvalid missingKeyAction = iota // Return an invalid reflect.Value.
mapZeroValue // Return the zero value for the map element.
mapError // Error out
)
type option struct {
missingKey missingKeyAction
}
// Option sets options for the template. Options are described by
// strings, either a simple string or "key=value". There can be at
// most one equals sign in an option string. If the option string
// is unrecognized or otherwise invalid, Option panics.
//
// Known options:
//
// missingkey: Control the behavior during execution if a map is
// indexed with a key that is not present in the map.
//
// "missingkey=default" or "missingkey=invalid"
// The default behavior: Do nothing and continue execution.
// If printed, the result of the index operation is the string
// "<no value>".
// "missingkey=zero"
// The operation returns the zero value for the map type's element.
// "missingkey=error"
// Execution stops immediately with an error.
func (t *Template) Option(opt ...string) *Template {
t.init()
for _, s := range opt {
t.setOption(s)
}
return t
}
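// For illustration, a hedged sketch of the missingkey option in use:
//
//	t := template.Must(template.New("t").Parse(`{{.missing}}`))
//	t.Option("missingkey=error")
//	// Executing against a map without a "missing" key now reports an
//	// error instead of printing "<no value>".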
func (t *Template) setOption(opt string) {
if opt == "" {
panic("empty option string")
}
// key=value
if key, value, ok := strings.Cut(opt, "="); ok {
switch key {
case "missingkey":
switch value {
case "invalid", "default":
t.option.missingKey = mapInvalid
return
case "zero":
t.option.missingKey = mapZeroValue
return
case "error":
t.option.missingKey = mapError
return
}
}
}
panic("unrecognized option: " + opt)
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package parse
import (
"fmt"
"strings"
"unicode"
"unicode/utf8"
)
// item represents a token or text string returned from the scanner.
type item struct {
typ itemType // The type of this item.
pos Pos // The starting position, in bytes, of this item in the input string.
val string // The value of this item.
line int // The line number at the start of this item.
}
func (i item) String() string {
switch {
case i.typ == itemEOF:
return "EOF"
case i.typ == itemError:
return i.val
case i.typ > itemKeyword:
return fmt.Sprintf("<%s>", i.val)
case len(i.val) > 10:
return fmt.Sprintf("%.10q...", i.val)
}
return fmt.Sprintf("%q", i.val)
}
// itemType identifies the type of lex items.
type itemType int
const (
itemError itemType = iota // error occurred; value is text of error
itemBool // boolean constant
itemChar // printable ASCII character; grab bag for comma etc.
itemCharConstant // character constant
itemComment // comment text
itemComplex // complex constant (1+2i); imaginary is just a number
itemAssign // equals ('=') introducing an assignment
itemDeclare // colon-equals (':=') introducing a declaration
itemEOF
itemField // alphanumeric identifier starting with '.'
itemIdentifier // alphanumeric identifier not starting with '.'
itemLeftDelim // left action delimiter
itemLeftParen // '(' inside action
itemNumber // simple number, including imaginary
itemPipe // pipe symbol
itemRawString // raw quoted string (includes quotes)
itemRightDelim // right action delimiter
itemRightParen // ')' inside action
itemSpace // run of spaces separating arguments
itemString // quoted string (includes quotes)
itemText // plain text
itemVariable // variable starting with '$', such as '$' or '$1' or '$hello'
// Keywords appear after all the rest.
itemKeyword // used only to delimit the keywords
itemBlock // block keyword
itemBreak // break keyword
itemContinue // continue keyword
itemDot // the cursor, spelled '.'
itemDefine // define keyword
itemElse // else keyword
itemEnd // end keyword
itemIf // if keyword
itemNil // the untyped nil constant, easiest to treat as a keyword
itemRange // range keyword
itemTemplate // template keyword
itemWith // with keyword
)
var key = map[string]itemType{
".": itemDot,
"block": itemBlock,
"break": itemBreak,
"continue": itemContinue,
"define": itemDefine,
"else": itemElse,
"end": itemEnd,
"if": itemIf,
"range": itemRange,
"nil": itemNil,
"template": itemTemplate,
"with": itemWith,
}
const eof = -1
// Trimming spaces.
// If the action begins "{{- " rather than "{{", then all space/tab/newlines
// preceding the action are trimmed; conversely if it ends " -}}" the
// spaces following the action are trimmed. This is done entirely in the
// lexer; the parser never sees it happen. We require an ASCII space
// (' ', \t, \r, \n) to be present to avoid ambiguity with things like
// "{{-3}}". It reads better with the space present anyway. For simplicity,
// only ASCII space characters do the job.
const (
spaceChars = " \t\r\n" // These are the space characters defined by Go itself.
trimMarker = '-' // Attached to left/right delimiter, trims trailing spaces from preceding/following text.
trimMarkerLen = Pos(1 + 1) // marker plus space before or after
)
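// For illustration, a sketch of trim markers at work:
//
//	"{{23 -}} < {{- 45}}"  renders as "23<45"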
// stateFn represents the state of the scanner as a function that returns the next state.
type stateFn func(*lexer) stateFn
// lexer holds the state of the scanner.
type lexer struct {
name string // the name of the input; used only for error reports
input string // the string being scanned
leftDelim string // start of action marker
rightDelim string // end of action marker
pos Pos // current position in the input
start Pos // start position of this item
atEOF bool // we have hit the end of input and returned eof
parenDepth int // nesting depth of ( ) exprs
line int // 1+number of newlines seen
startLine int // start line of this item
item item // item to return to parser
insideAction bool // are we inside an action?
options lexOptions
}
// lexOptions control behavior of the lexer. All default to false.
type lexOptions struct {
emitComment bool // emit itemComment tokens.
breakOK bool // break keyword allowed
continueOK bool // continue keyword allowed
}
// next returns the next rune in the input.
func (l *lexer) next() rune {
if int(l.pos) >= len(l.input) {
l.atEOF = true
return eof
}
r, w := utf8.DecodeRuneInString(l.input[l.pos:])
l.pos += Pos(w)
if r == '\n' {
l.line++
}
return r
}
// peek returns but does not consume the next rune in the input.
func (l *lexer) peek() rune {
r := l.next()
l.backup()
return r
}
// backup steps back one rune.
func (l *lexer) backup() {
if !l.atEOF && l.pos > 0 {
r, w := utf8.DecodeLastRuneInString(l.input[:l.pos])
l.pos -= Pos(w)
// Correct newline count.
if r == '\n' {
l.line--
}
}
}
// thisItem returns the item at the current input point with the specified type
// and advances the input.
func (l *lexer) thisItem(t itemType) item {
i := item{t, l.start, l.input[l.start:l.pos], l.startLine}
l.start = l.pos
l.startLine = l.line
return i
}
// emit passes the trailing text as an item back to the parser.
func (l *lexer) emit(t itemType) stateFn {
return l.emitItem(l.thisItem(t))
}
// emitItem passes the specified item to the parser.
func (l *lexer) emitItem(i item) stateFn {
l.item = i
return nil
}
// ignore skips over the pending input before this point.
// It tracks newlines in the ignored text, so use it only
// for text that is skipped without calling l.next.
func (l *lexer) ignore() {
l.line += strings.Count(l.input[l.start:l.pos], "\n")
l.start = l.pos
l.startLine = l.line
}
// accept consumes the next rune if it's from the valid set.
func (l *lexer) accept(valid string) bool {
if strings.ContainsRune(valid, l.next()) {
return true
}
l.backup()
return false
}
// acceptRun consumes a run of runes from the valid set.
func (l *lexer) acceptRun(valid string) {
for strings.ContainsRune(valid, l.next()) {
}
l.backup()
}
// errorf returns an error token and terminates the scan by passing
// back a nil pointer as the next state, which ends l.nextItem.
func (l *lexer) errorf(format string, args ...any) stateFn {
l.item = item{itemError, l.start, fmt.Sprintf(format, args...), l.startLine}
l.start = 0
l.pos = 0
l.input = l.input[:0]
return nil
}
// nextItem returns the next item from the input.
// Called by the parser; it runs the state machine synchronously until an
// item is emitted (there is no separate lexing goroutine).
func (l *lexer) nextItem() item {
l.item = item{itemEOF, l.pos, "EOF", l.startLine}
state := lexText
if l.insideAction {
state = lexInsideAction
}
for {
state = state(l)
if state == nil {
return l.item
}
}
}
// lex creates a new scanner for the input string.
func lex(name, input, left, right string) *lexer {
if left == "" {
left = leftDelim
}
if right == "" {
right = rightDelim
}
l := &lexer{
name: name,
input: input,
leftDelim: left,
rightDelim: right,
line: 1,
startLine: 1,
insideAction: false,
}
return l
}
// state functions
const (
leftDelim = "{{"
rightDelim = "}}"
leftComment = "/*"
rightComment = "*/"
)
// lexText scans until an opening action delimiter, "{{".
func lexText(l *lexer) stateFn {
if x := strings.Index(l.input[l.pos:], l.leftDelim); x >= 0 {
if x > 0 {
l.pos += Pos(x)
// Do we trim any trailing space?
trimLength := Pos(0)
delimEnd := l.pos + Pos(len(l.leftDelim))
if hasLeftTrimMarker(l.input[delimEnd:]) {
trimLength = rightTrimLength(l.input[l.start:l.pos])
}
l.pos -= trimLength
l.line += strings.Count(l.input[l.start:l.pos], "\n")
i := l.thisItem(itemText)
l.pos += trimLength
l.ignore()
if len(i.val) > 0 {
return l.emitItem(i)
}
}
return lexLeftDelim
}
l.pos = Pos(len(l.input))
// Correctly reached EOF.
if l.pos > l.start {
l.line += strings.Count(l.input[l.start:l.pos], "\n")
return l.emit(itemText)
}
return l.emit(itemEOF)
}
// rightTrimLength returns the length of the spaces at the end of the string.
func rightTrimLength(s string) Pos {
return Pos(len(s) - len(strings.TrimRight(s, spaceChars)))
}
// atRightDelim reports whether the lexer is at a right delimiter, possibly preceded by a trim marker.
func (l *lexer) atRightDelim() (delim, trimSpaces bool) {
if hasRightTrimMarker(l.input[l.pos:]) && strings.HasPrefix(l.input[l.pos+trimMarkerLen:], l.rightDelim) { // With trim marker.
return true, true
}
if strings.HasPrefix(l.input[l.pos:], l.rightDelim) { // Without trim marker.
return true, false
}
return false, false
}
// leftTrimLength returns the length of the spaces at the beginning of the string.
func leftTrimLength(s string) Pos {
return Pos(len(s) - len(strings.TrimLeft(s, spaceChars)))
}
// lexLeftDelim scans the left delimiter, which is known to be present, possibly with a trim marker.
// (The text to be trimmed has already been emitted.)
func lexLeftDelim(l *lexer) stateFn {
l.pos += Pos(len(l.leftDelim))
trimSpace := hasLeftTrimMarker(l.input[l.pos:])
afterMarker := Pos(0)
if trimSpace {
afterMarker = trimMarkerLen
}
if strings.HasPrefix(l.input[l.pos+afterMarker:], leftComment) {
l.pos += afterMarker
l.ignore()
return lexComment
}
i := l.thisItem(itemLeftDelim)
l.insideAction = true
l.pos += afterMarker
l.ignore()
l.parenDepth = 0
return l.emitItem(i)
}
// lexComment scans a comment. The left comment marker is known to be present.
func lexComment(l *lexer) stateFn {
l.pos += Pos(len(leftComment))
x := strings.Index(l.input[l.pos:], rightComment)
if x < 0 {
return l.errorf("unclosed comment")
}
l.pos += Pos(x + len(rightComment))
delim, trimSpace := l.atRightDelim()
if !delim {
return l.errorf("comment ends before closing delimiter")
}
i := l.thisItem(itemComment)
if trimSpace {
l.pos += trimMarkerLen
}
l.pos += Pos(len(l.rightDelim))
if trimSpace {
l.pos += leftTrimLength(l.input[l.pos:])
}
l.ignore()
if l.options.emitComment {
return l.emitItem(i)
}
return lexText
}
// lexRightDelim scans the right delimiter, which is known to be present, possibly with a trim marker.
func lexRightDelim(l *lexer) stateFn {
_, trimSpace := l.atRightDelim()
if trimSpace {
l.pos += trimMarkerLen
l.ignore()
}
l.pos += Pos(len(l.rightDelim))
i := l.thisItem(itemRightDelim)
if trimSpace {
l.pos += leftTrimLength(l.input[l.pos:])
l.ignore()
}
l.insideAction = false
return l.emitItem(i)
}
// lexInsideAction scans the elements inside action delimiters.
func lexInsideAction(l *lexer) stateFn {
// Either number, quoted string, or identifier.
// Spaces separate arguments; runs of spaces turn into itemSpace.
// Pipe symbols separate and are emitted.
delim, _ := l.atRightDelim()
if delim {
if l.parenDepth == 0 {
return lexRightDelim
}
return l.errorf("unclosed left paren")
}
switch r := l.next(); {
case r == eof:
return l.errorf("unclosed action")
case isSpace(r):
l.backup() // Put space back in case we have " -}}".
return lexSpace
case r == '=':
return l.emit(itemAssign)
case r == ':':
if l.next() != '=' {
return l.errorf("expected :=")
}
return l.emit(itemDeclare)
case r == '|':
return l.emit(itemPipe)
case r == '"':
return lexQuote
case r == '`':
return lexRawQuote
case r == '$':
return lexVariable
case r == '\'':
return lexChar
case r == '.':
// special look-ahead for ".field" so we don't break l.backup().
if l.pos < Pos(len(l.input)) {
r := l.input[l.pos]
if r < '0' || '9' < r {
return lexField
}
}
fallthrough // '.' can start a number.
case r == '+' || r == '-' || ('0' <= r && r <= '9'):
l.backup()
return lexNumber
case isAlphaNumeric(r):
l.backup()
return lexIdentifier
case r == '(':
l.parenDepth++
return l.emit(itemLeftParen)
case r == ')':
l.parenDepth--
if l.parenDepth < 0 {
return l.errorf("unexpected right paren")
}
return l.emit(itemRightParen)
case r <= unicode.MaxASCII && unicode.IsPrint(r):
return l.emit(itemChar)
default:
return l.errorf("unrecognized character in action: %#U", r)
}
}
// lexSpace scans a run of space characters.
// We have not consumed the first space, which is known to be present.
// Take care if there is a trim-marked right delimiter, which starts with a space.
func lexSpace(l *lexer) stateFn {
var r rune
var numSpaces int
for {
r = l.peek()
if !isSpace(r) {
break
}
l.next()
numSpaces++
}
// Be careful about a trim-marked closing delimiter, which has a minus
// after a space. We know there is a space, so check for the '-' that might follow.
if hasRightTrimMarker(l.input[l.pos-1:]) && strings.HasPrefix(l.input[l.pos-1+trimMarkerLen:], l.rightDelim) {
l.backup() // Before the space.
if numSpaces == 1 {
return lexRightDelim // On the delim, so go right to that.
}
}
return l.emit(itemSpace)
}
// lexIdentifier scans an alphanumeric.
func lexIdentifier(l *lexer) stateFn {
for {
switch r := l.next(); {
case isAlphaNumeric(r):
// absorb.
default:
l.backup()
word := l.input[l.start:l.pos]
if !l.atTerminator() {
return l.errorf("bad character %#U", r)
}
switch {
case key[word] > itemKeyword:
item := key[word]
if item == itemBreak && !l.options.breakOK || item == itemContinue && !l.options.continueOK {
return l.emit(itemIdentifier)
}
return l.emit(item)
case word[0] == '.':
return l.emit(itemField)
case word == "true", word == "false":
return l.emit(itemBool)
default:
return l.emit(itemIdentifier)
}
}
}
}
// lexField scans a field: .Alphanumeric.
// The . has been scanned.
func lexField(l *lexer) stateFn {
return lexFieldOrVariable(l, itemField)
}
// lexVariable scans a Variable: $Alphanumeric.
// The $ has been scanned.
func lexVariable(l *lexer) stateFn {
if l.atTerminator() { // Nothing interesting follows -> "$".
return l.emit(itemVariable)
}
return lexFieldOrVariable(l, itemVariable)
}
// lexFieldOrVariable scans a field or variable: [.$]Alphanumeric.
// The . or $ has been scanned.
func lexFieldOrVariable(l *lexer, typ itemType) stateFn {
if l.atTerminator() { // Nothing interesting follows -> "." or "$".
if typ == itemVariable {
return l.emit(itemVariable)
}
return l.emit(itemDot)
}
var r rune
for {
r = l.next()
if !isAlphaNumeric(r) {
l.backup()
break
}
}
if !l.atTerminator() {
return l.errorf("bad character %#U", r)
}
return l.emit(typ)
}
// atTerminator reports whether the input is at a valid termination character
// to appear after an identifier. Breaks .X.Y into two pieces. Also catches cases
// like "$x+2" not being acceptable without a space, in case we decide one
// day to implement arithmetic.
func (l *lexer) atTerminator() bool {
r := l.peek()
if isSpace(r) {
return true
}
switch r {
case eof, '.', ',', '|', ':', ')', '(':
return true
}
return strings.HasPrefix(l.input[l.pos:], l.rightDelim)
}
// lexChar scans a character constant. The initial quote is already
// scanned. Syntax checking is done by the parser.
func lexChar(l *lexer) stateFn {
Loop:
for {
switch l.next() {
case '\\':
if r := l.next(); r != eof && r != '\n' {
break
}
fallthrough
case eof, '\n':
return l.errorf("unterminated character constant")
case '\'':
break Loop
}
}
return l.emit(itemCharConstant)
}
// lexNumber scans a number: decimal, octal, hex, float, or imaginary. This
// isn't a perfect number scanner - for instance it accepts "." and "0x0.2"
// and "089" - but when it's wrong the input is invalid and the parser (via
// strconv) will notice.
func lexNumber(l *lexer) stateFn {
if !l.scanNumber() {
return l.errorf("bad number syntax: %q", l.input[l.start:l.pos])
}
if sign := l.peek(); sign == '+' || sign == '-' {
// Complex: 1+2i. No spaces, must end in 'i'.
if !l.scanNumber() || l.input[l.pos-1] != 'i' {
return l.errorf("bad number syntax: %q", l.input[l.start:l.pos])
}
return l.emit(itemComplex)
}
return l.emit(itemNumber)
}
func (l *lexer) scanNumber() bool {
// Optional leading sign.
l.accept("+-")
// Is it hex?
digits := "0123456789_"
if l.accept("0") {
// Note: Leading 0 does not mean octal in floats.
if l.accept("xX") {
digits = "0123456789abcdefABCDEF_"
} else if l.accept("oO") {
digits = "01234567_"
} else if l.accept("bB") {
digits = "01_"
}
}
l.acceptRun(digits)
if l.accept(".") {
l.acceptRun(digits)
}
// Only decimal digits (len == 10+1, counting '_') may take a decimal exponent.
if len(digits) == 10+1 && l.accept("eE") {
l.accept("+-")
l.acceptRun("0123456789_")
}
// Only hex digits (len == 16+6+1, counting '_') may take a 'p' exponent.
if len(digits) == 16+6+1 && l.accept("pP") {
l.accept("+-")
l.acceptRun("0123456789_")
}
// Is it imaginary?
l.accept("i")
// Next thing mustn't be alphanumeric.
if isAlphaNumeric(l.peek()) {
l.next()
return false
}
return true
}
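// For example, scanNumber accepts "42", "1_000", "0x1F", "0o17", "0b101",
// "1e3", "0x1p-2", and "2i". As noted above it is deliberately permissive:
// dubious forms such as "0x0.2" and "089" also get through, and it is the
// parser (via strconv) that ultimately decides validity.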
// lexQuote scans a quoted string.
func lexQuote(l *lexer) stateFn {
Loop:
for {
switch l.next() {
case '\\':
if r := l.next(); r != eof && r != '\n' {
break
}
fallthrough
case eof, '\n':
return l.errorf("unterminated quoted string")
case '"':
break Loop
}
}
return l.emit(itemString)
}
// lexRawQuote scans a raw quoted string.
func lexRawQuote(l *lexer) stateFn {
Loop:
for {
switch l.next() {
case eof:
return l.errorf("unterminated raw quoted string")
case '`':
break Loop
}
}
return l.emit(itemRawString)
}
// isSpace reports whether r is a space character.
func isSpace(r rune) bool {
return r == ' ' || r == '\t' || r == '\r' || r == '\n'
}
// isAlphaNumeric reports whether r is an alphabetic, digit, or underscore.
func isAlphaNumeric(r rune) bool {
return r == '_' || unicode.IsLetter(r) || unicode.IsDigit(r)
}
// hasLeftTrimMarker reports whether s begins with a left trim marker: '-' followed by a space.
func hasLeftTrimMarker(s string) bool {
return len(s) >= 2 && s[0] == trimMarker && isSpace(rune(s[1]))
}
// hasRightTrimMarker reports whether s begins with a right trim marker: a space followed by '-'.
func hasRightTrimMarker(s string) bool {
return len(s) >= 2 && isSpace(rune(s[0])) && s[1] == trimMarker
}
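// A minimal sketch of the trim markers these helpers detect, written against
// the public text/template API (assumed here only for illustration):
//
//	package main
//
//	import (
//		"os"
//		"text/template"
//	)
//
//	func main() {
//		// "{{- " trims the white space before the action; " -}}" trims after it.
//		t := template.Must(template.New("trim").Parse("a  {{- `b` -}}  c"))
//		t.Execute(os.Stdout, nil) // prints "abc"
//	}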
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Parse nodes.
package parse
import (
"fmt"
"strconv"
"strings"
)
var textFormat = "%s" // Changed to "%q" in tests for better error messages.
// A Node is an element in the parse tree. The interface is trivial.
// The interface contains an unexported method so that only
// types local to this package can satisfy it.
type Node interface {
Type() NodeType
String() string
// Copy does a deep copy of the Node and all its components.
// To avoid type assertions, some XxxNodes also have specialized
// CopyXxx methods that return *XxxNode.
Copy() Node
Position() Pos // byte position of start of node in full original input string
// tree returns the containing *Tree.
// It is unexported so all implementations of Node are in this package.
tree() *Tree
// writeTo writes the String output to the builder.
writeTo(*strings.Builder)
}
// NodeType identifies the type of a parse tree node.
type NodeType int
// Pos represents a byte position in the original input text from which
// this template was parsed.
type Pos int
func (p Pos) Position() Pos {
return p
}
// Type returns itself and provides an easy default implementation
// for embedding in a Node. Embedded in all non-trivial Nodes.
func (t NodeType) Type() NodeType {
return t
}
const (
NodeText NodeType = iota // Plain text.
NodeAction // A non-control action such as a field evaluation.
NodeBool // A boolean constant.
NodeChain // A sequence of field accesses.
NodeCommand // An element of a pipeline.
NodeDot // The cursor, dot.
nodeElse // An else action. Not added to tree.
nodeEnd // An end action. Not added to tree.
NodeField // A field or method name.
NodeIdentifier // An identifier; always a function name.
NodeIf // An if action.
NodeList // A list of Nodes.
NodeNil // An untyped nil constant.
NodeNumber // A numerical constant.
NodePipe // A pipeline of commands.
NodeRange // A range action.
NodeString // A string constant.
NodeTemplate // A template invocation action.
NodeVariable // A $ variable.
NodeWith // A with action.
NodeComment // A comment.
NodeBreak // A break action.
NodeContinue // A continue action.
)
// Nodes.
// ListNode holds a sequence of nodes.
type ListNode struct {
NodeType
Pos
tr *Tree
Nodes []Node // The element nodes in lexical order.
}
func (t *Tree) newList(pos Pos) *ListNode {
return &ListNode{tr: t, NodeType: NodeList, Pos: pos}
}
func (l *ListNode) append(n Node) {
l.Nodes = append(l.Nodes, n)
}
func (l *ListNode) tree() *Tree {
return l.tr
}
func (l *ListNode) String() string {
var sb strings.Builder
l.writeTo(&sb)
return sb.String()
}
func (l *ListNode) writeTo(sb *strings.Builder) {
for _, n := range l.Nodes {
n.writeTo(sb)
}
}
func (l *ListNode) CopyList() *ListNode {
if l == nil {
return l
}
n := l.tr.newList(l.Pos)
for _, elem := range l.Nodes {
n.append(elem.Copy())
}
return n
}
func (l *ListNode) Copy() Node {
return l.CopyList()
}
// TextNode holds plain text.
type TextNode struct {
NodeType
Pos
tr *Tree
Text []byte // The text; may span newlines.
}
func (t *Tree) newText(pos Pos, text string) *TextNode {
return &TextNode{tr: t, NodeType: NodeText, Pos: pos, Text: []byte(text)}
}
func (t *TextNode) String() string {
return fmt.Sprintf(textFormat, t.Text)
}
func (t *TextNode) writeTo(sb *strings.Builder) {
sb.WriteString(t.String())
}
func (t *TextNode) tree() *Tree {
return t.tr
}
func (t *TextNode) Copy() Node {
return &TextNode{tr: t.tr, NodeType: NodeText, Pos: t.Pos, Text: append([]byte{}, t.Text...)}
}
// CommentNode holds a comment.
type CommentNode struct {
NodeType
Pos
tr *Tree
Text string // Comment text.
}
func (t *Tree) newComment(pos Pos, text string) *CommentNode {
return &CommentNode{tr: t, NodeType: NodeComment, Pos: pos, Text: text}
}
func (c *CommentNode) String() string {
var sb strings.Builder
c.writeTo(&sb)
return sb.String()
}
func (c *CommentNode) writeTo(sb *strings.Builder) {
sb.WriteString("{{")
sb.WriteString(c.Text)
sb.WriteString("}}")
}
func (c *CommentNode) tree() *Tree {
return c.tr
}
func (c *CommentNode) Copy() Node {
return &CommentNode{tr: c.tr, NodeType: NodeComment, Pos: c.Pos, Text: c.Text}
}
// PipeNode holds a pipeline with optional declaration.
type PipeNode struct {
NodeType
Pos
tr *Tree
Line int // The line number in the input. Deprecated: Kept for compatibility.
IsAssign bool // The variables are being assigned, not declared.
Decl []*VariableNode // Variables in lexical order.
Cmds []*CommandNode // The commands in lexical order.
}
func (t *Tree) newPipeline(pos Pos, line int, vars []*VariableNode) *PipeNode {
return &PipeNode{tr: t, NodeType: NodePipe, Pos: pos, Line: line, Decl: vars}
}
func (p *PipeNode) append(command *CommandNode) {
p.Cmds = append(p.Cmds, command)
}
func (p *PipeNode) String() string {
var sb strings.Builder
p.writeTo(&sb)
return sb.String()
}
func (p *PipeNode) writeTo(sb *strings.Builder) {
if len(p.Decl) > 0 {
for i, v := range p.Decl {
if i > 0 {
sb.WriteString(", ")
}
v.writeTo(sb)
}
sb.WriteString(" := ")
}
for i, c := range p.Cmds {
if i > 0 {
sb.WriteString(" | ")
}
c.writeTo(sb)
}
}
func (p *PipeNode) tree() *Tree {
return p.tr
}
func (p *PipeNode) CopyPipe() *PipeNode {
if p == nil {
return p
}
vars := make([]*VariableNode, len(p.Decl))
for i, d := range p.Decl {
vars[i] = d.Copy().(*VariableNode)
}
n := p.tr.newPipeline(p.Pos, p.Line, vars)
n.IsAssign = p.IsAssign
for _, c := range p.Cmds {
n.append(c.Copy().(*CommandNode))
}
return n
}
func (p *PipeNode) Copy() Node {
return p.CopyPipe()
}
// ActionNode holds an action (something bounded by delimiters).
// Control actions have their own nodes; ActionNode represents simple
// ones such as field evaluations and parenthesized pipelines.
type ActionNode struct {
NodeType
Pos
tr *Tree
Line int // The line number in the input. Deprecated: Kept for compatibility.
Pipe *PipeNode // The pipeline in the action.
}
func (t *Tree) newAction(pos Pos, line int, pipe *PipeNode) *ActionNode {
return &ActionNode{tr: t, NodeType: NodeAction, Pos: pos, Line: line, Pipe: pipe}
}
func (a *ActionNode) String() string {
var sb strings.Builder
a.writeTo(&sb)
return sb.String()
}
func (a *ActionNode) writeTo(sb *strings.Builder) {
sb.WriteString("{{")
a.Pipe.writeTo(sb)
sb.WriteString("}}")
}
func (a *ActionNode) tree() *Tree {
return a.tr
}
func (a *ActionNode) Copy() Node {
return a.tr.newAction(a.Pos, a.Line, a.Pipe.CopyPipe())
}
// CommandNode holds a command (a pipeline inside an evaluating action).
type CommandNode struct {
NodeType
Pos
tr *Tree
Args []Node // Arguments in lexical order: Identifier, field, or constant.
}
func (t *Tree) newCommand(pos Pos) *CommandNode {
return &CommandNode{tr: t, NodeType: NodeCommand, Pos: pos}
}
func (c *CommandNode) append(arg Node) {
c.Args = append(c.Args, arg)
}
func (c *CommandNode) String() string {
var sb strings.Builder
c.writeTo(&sb)
return sb.String()
}
func (c *CommandNode) writeTo(sb *strings.Builder) {
for i, arg := range c.Args {
if i > 0 {
sb.WriteByte(' ')
}
if arg, ok := arg.(*PipeNode); ok {
sb.WriteByte('(')
arg.writeTo(sb)
sb.WriteByte(')')
continue
}
arg.writeTo(sb)
}
}
func (c *CommandNode) tree() *Tree {
return c.tr
}
func (c *CommandNode) Copy() Node {
if c == nil {
return c
}
n := c.tr.newCommand(c.Pos)
for _, c := range c.Args {
n.append(c.Copy())
}
return n
}
// IdentifierNode holds an identifier.
type IdentifierNode struct {
NodeType
Pos
tr *Tree
Ident string // The identifier's name.
}
// NewIdentifier returns a new IdentifierNode with the given identifier name.
func NewIdentifier(ident string) *IdentifierNode {
return &IdentifierNode{NodeType: NodeIdentifier, Ident: ident}
}
// SetPos sets the position. NewIdentifier is a public method so we can't modify its signature.
// Chained for convenience.
// TODO: fix one day?
func (i *IdentifierNode) SetPos(pos Pos) *IdentifierNode {
i.Pos = pos
return i
}
// SetTree sets the parent tree for the node. NewIdentifier is a public method so we can't modify its signature.
// Chained for convenience.
// TODO: fix one day?
func (i *IdentifierNode) SetTree(t *Tree) *IdentifierNode {
i.tr = t
return i
}
func (i *IdentifierNode) String() string {
return i.Ident
}
func (i *IdentifierNode) writeTo(sb *strings.Builder) {
sb.WriteString(i.String())
}
func (i *IdentifierNode) tree() *Tree {
return i.tr
}
func (i *IdentifierNode) Copy() Node {
return NewIdentifier(i.Ident).SetTree(i.tr).SetPos(i.Pos)
}
// VariableNode holds a list of variable names, possibly with chained field
// accesses. The dollar sign is part of the (first) name.
type VariableNode struct {
NodeType
Pos
tr *Tree
Ident []string // Variable name and fields in lexical order.
}
func (t *Tree) newVariable(pos Pos, ident string) *VariableNode {
return &VariableNode{tr: t, NodeType: NodeVariable, Pos: pos, Ident: strings.Split(ident, ".")}
}
func (v *VariableNode) String() string {
var sb strings.Builder
v.writeTo(&sb)
return sb.String()
}
func (v *VariableNode) writeTo(sb *strings.Builder) {
for i, id := range v.Ident {
if i > 0 {
sb.WriteByte('.')
}
sb.WriteString(id)
}
}
func (v *VariableNode) tree() *Tree {
return v.tr
}
func (v *VariableNode) Copy() Node {
return &VariableNode{tr: v.tr, NodeType: NodeVariable, Pos: v.Pos, Ident: append([]string{}, v.Ident...)}
}
// DotNode holds the special identifier '.'.
type DotNode struct {
NodeType
Pos
tr *Tree
}
func (t *Tree) newDot(pos Pos) *DotNode {
return &DotNode{tr: t, NodeType: NodeDot, Pos: pos}
}
func (d *DotNode) Type() NodeType {
// Override method on embedded NodeType for API compatibility.
// TODO: Not really a problem; could change API without effect but
// api tool complains.
return NodeDot
}
func (d *DotNode) String() string {
return "."
}
func (d *DotNode) writeTo(sb *strings.Builder) {
sb.WriteString(d.String())
}
func (d *DotNode) tree() *Tree {
return d.tr
}
func (d *DotNode) Copy() Node {
return d.tr.newDot(d.Pos)
}
// NilNode holds the special identifier 'nil' representing an untyped nil constant.
type NilNode struct {
NodeType
Pos
tr *Tree
}
func (t *Tree) newNil(pos Pos) *NilNode {
return &NilNode{tr: t, NodeType: NodeNil, Pos: pos}
}
func (n *NilNode) Type() NodeType {
// Override method on embedded NodeType for API compatibility.
// TODO: Not really a problem; could change API without effect but
// api tool complains.
return NodeNil
}
func (n *NilNode) String() string {
return "nil"
}
func (n *NilNode) writeTo(sb *strings.Builder) {
sb.WriteString(n.String())
}
func (n *NilNode) tree() *Tree {
return n.tr
}
func (n *NilNode) Copy() Node {
return n.tr.newNil(n.Pos)
}
// FieldNode holds a field (identifier starting with '.').
// The names may be chained ('.x.y').
// The period is dropped from each ident.
type FieldNode struct {
NodeType
Pos
tr *Tree
Ident []string // The identifiers in lexical order.
}
func (t *Tree) newField(pos Pos, ident string) *FieldNode {
return &FieldNode{tr: t, NodeType: NodeField, Pos: pos, Ident: strings.Split(ident[1:], ".")} // [1:] to drop leading period
}
func (f *FieldNode) String() string {
var sb strings.Builder
f.writeTo(&sb)
return sb.String()
}
func (f *FieldNode) writeTo(sb *strings.Builder) {
for _, id := range f.Ident {
sb.WriteByte('.')
sb.WriteString(id)
}
}
func (f *FieldNode) tree() *Tree {
return f.tr
}
func (f *FieldNode) Copy() Node {
return &FieldNode{tr: f.tr, NodeType: NodeField, Pos: f.Pos, Ident: append([]string{}, f.Ident...)}
}
// ChainNode holds a term followed by a chain of field accesses (identifier starting with '.').
// The names may be chained ('.x.y').
// The periods are dropped from each ident.
type ChainNode struct {
NodeType
Pos
tr *Tree
Node Node
Field []string // The identifiers in lexical order.
}
func (t *Tree) newChain(pos Pos, node Node) *ChainNode {
return &ChainNode{tr: t, NodeType: NodeChain, Pos: pos, Node: node}
}
// Add adds the named field (which should start with a period) to the end of the chain.
func (c *ChainNode) Add(field string) {
if len(field) == 0 || field[0] != '.' {
panic("no dot in field")
}
field = field[1:] // Remove leading dot.
if field == "" {
panic("empty field")
}
c.Field = append(c.Field, field)
}
func (c *ChainNode) String() string {
var sb strings.Builder
c.writeTo(&sb)
return sb.String()
}
func (c *ChainNode) writeTo(sb *strings.Builder) {
if _, ok := c.Node.(*PipeNode); ok {
sb.WriteByte('(')
c.Node.writeTo(sb)
sb.WriteByte(')')
} else {
c.Node.writeTo(sb)
}
for _, field := range c.Field {
sb.WriteByte('.')
sb.WriteString(field)
}
}
func (c *ChainNode) tree() *Tree {
return c.tr
}
func (c *ChainNode) Copy() Node {
return &ChainNode{tr: c.tr, NodeType: NodeChain, Pos: c.Pos, Node: c.Node, Field: append([]string{}, c.Field...)}
}
// BoolNode holds a boolean constant.
type BoolNode struct {
NodeType
Pos
tr *Tree
True bool // The value of the boolean constant.
}
func (t *Tree) newBool(pos Pos, truth bool) *BoolNode {
return &BoolNode{tr: t, NodeType: NodeBool, Pos: pos, True: truth}
}
func (b *BoolNode) String() string {
if b.True {
return "true"
}
return "false"
}
func (b *BoolNode) writeTo(sb *strings.Builder) {
sb.WriteString(b.String())
}
func (b *BoolNode) tree() *Tree {
return b.tr
}
func (b *BoolNode) Copy() Node {
return b.tr.newBool(b.Pos, b.True)
}
// NumberNode holds a number: signed or unsigned integer, float, or complex.
// The value is parsed and stored under all the types that can represent the value.
// This simulates in a small amount of code the behavior of Go's ideal constants.
type NumberNode struct {
NodeType
Pos
tr *Tree
IsInt bool // Number has an integral value.
IsUint bool // Number has an unsigned integral value.
IsFloat bool // Number has a floating-point value.
IsComplex bool // Number is complex.
Int64 int64 // The signed integer value.
Uint64 uint64 // The unsigned integer value.
Float64 float64 // The floating-point value.
Complex128 complex128 // The complex value.
Text string // The original textual representation from the input.
}
func (t *Tree) newNumber(pos Pos, text string, typ itemType) (*NumberNode, error) {
n := &NumberNode{tr: t, NodeType: NodeNumber, Pos: pos, Text: text}
switch typ {
case itemCharConstant:
rune, _, tail, err := strconv.UnquoteChar(text[1:], text[0])
if err != nil {
return nil, err
}
if tail != "'" {
return nil, fmt.Errorf("malformed character constant: %s", text)
}
n.Int64 = int64(rune)
n.IsInt = true
n.Uint64 = uint64(rune)
n.IsUint = true
n.Float64 = float64(rune) // odd but those are the rules.
n.IsFloat = true
return n, nil
case itemComplex:
// fmt.Sscan can parse the pair, so let it do the work.
if _, err := fmt.Sscan(text, &n.Complex128); err != nil {
return nil, err
}
n.IsComplex = true
n.simplifyComplex()
return n, nil
}
// Imaginary constants can only be complex unless they are zero.
if len(text) > 0 && text[len(text)-1] == 'i' {
f, err := strconv.ParseFloat(text[:len(text)-1], 64)
if err == nil {
n.IsComplex = true
n.Complex128 = complex(0, f)
n.simplifyComplex()
return n, nil
}
}
// Do integer test first so we get 0x123 etc.
u, err := strconv.ParseUint(text, 0, 64) // will fail for -0; fixed below.
if err == nil {
n.IsUint = true
n.Uint64 = u
}
i, err := strconv.ParseInt(text, 0, 64)
if err == nil {
n.IsInt = true
n.Int64 = i
if i == 0 {
n.IsUint = true // in case of -0.
n.Uint64 = u
}
}
// If an integer extraction succeeded, promote the float.
if n.IsInt {
n.IsFloat = true
n.Float64 = float64(n.Int64)
} else if n.IsUint {
n.IsFloat = true
n.Float64 = float64(n.Uint64)
} else {
f, err := strconv.ParseFloat(text, 64)
if err == nil {
// If we parsed it as a float but it looks like an integer,
// it's a huge number too large to fit in an int. Reject it.
if !strings.ContainsAny(text, ".eEpP") {
return nil, fmt.Errorf("integer overflow: %q", text)
}
n.IsFloat = true
n.Float64 = f
// If a floating-point extraction succeeded, extract the int if needed.
if !n.IsInt && float64(int64(f)) == f {
n.IsInt = true
n.Int64 = int64(f)
}
if !n.IsUint && float64(uint64(f)) == f {
n.IsUint = true
n.Uint64 = uint64(f)
}
}
}
if !n.IsInt && !n.IsUint && !n.IsFloat {
return nil, fmt.Errorf("illegal number syntax: %q", text)
}
return n, nil
}
// simplifyComplex pulls out any other types that are represented by the complex number.
// These all require that the imaginary part be zero.
func (n *NumberNode) simplifyComplex() {
n.IsFloat = imag(n.Complex128) == 0
if n.IsFloat {
n.Float64 = real(n.Complex128)
n.IsInt = float64(int64(n.Float64)) == n.Float64
if n.IsInt {
n.Int64 = int64(n.Float64)
}
n.IsUint = float64(uint64(n.Float64)) == n.Float64
if n.IsUint {
n.Uint64 = uint64(n.Float64)
}
}
}
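// For example, the constant 3+0i simplifies all the way down: IsComplex,
// IsFloat, IsInt, and IsUint all become true, with Complex128 == 3+0i,
// Float64 == 3, Int64 == 3, and Uint64 == 3.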
func (n *NumberNode) String() string {
return n.Text
}
func (n *NumberNode) writeTo(sb *strings.Builder) {
sb.WriteString(n.String())
}
func (n *NumberNode) tree() *Tree {
return n.tr
}
func (n *NumberNode) Copy() Node {
nn := new(NumberNode)
*nn = *n // Easy, fast, correct.
return nn
}
// StringNode holds a string constant. The value has been "unquoted".
type StringNode struct {
NodeType
Pos
tr *Tree
Quoted string // The original text of the string, with quotes.
Text string // The string, after quote processing.
}
func (t *Tree) newString(pos Pos, orig, text string) *StringNode {
return &StringNode{tr: t, NodeType: NodeString, Pos: pos, Quoted: orig, Text: text}
}
func (s *StringNode) String() string {
return s.Quoted
}
func (s *StringNode) writeTo(sb *strings.Builder) {
sb.WriteString(s.String())
}
func (s *StringNode) tree() *Tree {
return s.tr
}
func (s *StringNode) Copy() Node {
return s.tr.newString(s.Pos, s.Quoted, s.Text)
}
// endNode represents an {{end}} action.
// It does not appear in the final parse tree.
type endNode struct {
NodeType
Pos
tr *Tree
}
func (t *Tree) newEnd(pos Pos) *endNode {
return &endNode{tr: t, NodeType: nodeEnd, Pos: pos}
}
func (e *endNode) String() string {
return "{{end}}"
}
func (e *endNode) writeTo(sb *strings.Builder) {
sb.WriteString(e.String())
}
func (e *endNode) tree() *Tree {
return e.tr
}
func (e *endNode) Copy() Node {
return e.tr.newEnd(e.Pos)
}
// elseNode represents an {{else}} action. Does not appear in the final tree.
type elseNode struct {
NodeType
Pos
tr *Tree
Line int // The line number in the input. Deprecated: Kept for compatibility.
}
func (t *Tree) newElse(pos Pos, line int) *elseNode {
return &elseNode{tr: t, NodeType: nodeElse, Pos: pos, Line: line}
}
func (e *elseNode) Type() NodeType {
return nodeElse
}
func (e *elseNode) String() string {
return "{{else}}"
}
func (e *elseNode) writeTo(sb *strings.Builder) {
sb.WriteString(e.String())
}
func (e *elseNode) tree() *Tree {
return e.tr
}
func (e *elseNode) Copy() Node {
return e.tr.newElse(e.Pos, e.Line)
}
// BranchNode is the common representation of if, range, and with.
type BranchNode struct {
NodeType
Pos
tr *Tree
Line int // The line number in the input. Deprecated: Kept for compatibility.
Pipe *PipeNode // The pipeline to be evaluated.
List *ListNode // What to execute if the value is non-empty.
ElseList *ListNode // What to execute if the value is empty (nil if absent).
}
func (b *BranchNode) String() string {
var sb strings.Builder
b.writeTo(&sb)
return sb.String()
}
func (b *BranchNode) writeTo(sb *strings.Builder) {
name := ""
switch b.NodeType {
case NodeIf:
name = "if"
case NodeRange:
name = "range"
case NodeWith:
name = "with"
default:
panic("unknown branch type")
}
sb.WriteString("{{")
sb.WriteString(name)
sb.WriteByte(' ')
b.Pipe.writeTo(sb)
sb.WriteString("}}")
b.List.writeTo(sb)
if b.ElseList != nil {
sb.WriteString("{{else}}")
b.ElseList.writeTo(sb)
}
sb.WriteString("{{end}}")
}
func (b *BranchNode) tree() *Tree {
return b.tr
}
func (b *BranchNode) Copy() Node {
switch b.NodeType {
case NodeIf:
return b.tr.newIf(b.Pos, b.Line, b.Pipe, b.List, b.ElseList)
case NodeRange:
return b.tr.newRange(b.Pos, b.Line, b.Pipe, b.List, b.ElseList)
case NodeWith:
return b.tr.newWith(b.Pos, b.Line, b.Pipe, b.List, b.ElseList)
default:
panic("unknown branch type")
}
}
// IfNode represents an {{if}} action and its commands.
type IfNode struct {
BranchNode
}
func (t *Tree) newIf(pos Pos, line int, pipe *PipeNode, list, elseList *ListNode) *IfNode {
return &IfNode{BranchNode{tr: t, NodeType: NodeIf, Pos: pos, Line: line, Pipe: pipe, List: list, ElseList: elseList}}
}
func (i *IfNode) Copy() Node {
return i.tr.newIf(i.Pos, i.Line, i.Pipe.CopyPipe(), i.List.CopyList(), i.ElseList.CopyList())
}
// BreakNode represents a {{break}} action.
type BreakNode struct {
tr *Tree
NodeType
Pos
Line int
}
func (t *Tree) newBreak(pos Pos, line int) *BreakNode {
return &BreakNode{tr: t, NodeType: NodeBreak, Pos: pos, Line: line}
}
func (b *BreakNode) Copy() Node { return b.tr.newBreak(b.Pos, b.Line) }
func (b *BreakNode) String() string { return "{{break}}" }
func (b *BreakNode) tree() *Tree { return b.tr }
func (b *BreakNode) writeTo(sb *strings.Builder) { sb.WriteString("{{break}}") }
// ContinueNode represents a {{continue}} action.
type ContinueNode struct {
tr *Tree
NodeType
Pos
Line int
}
func (t *Tree) newContinue(pos Pos, line int) *ContinueNode {
return &ContinueNode{tr: t, NodeType: NodeContinue, Pos: pos, Line: line}
}
func (c *ContinueNode) Copy() Node { return c.tr.newContinue(c.Pos, c.Line) }
func (c *ContinueNode) String() string { return "{{continue}}" }
func (c *ContinueNode) tree() *Tree { return c.tr }
func (c *ContinueNode) writeTo(sb *strings.Builder) { sb.WriteString("{{continue}}") }
// RangeNode represents a {{range}} action and its commands.
type RangeNode struct {
BranchNode
}
func (t *Tree) newRange(pos Pos, line int, pipe *PipeNode, list, elseList *ListNode) *RangeNode {
return &RangeNode{BranchNode{tr: t, NodeType: NodeRange, Pos: pos, Line: line, Pipe: pipe, List: list, ElseList: elseList}}
}
func (r *RangeNode) Copy() Node {
return r.tr.newRange(r.Pos, r.Line, r.Pipe.CopyPipe(), r.List.CopyList(), r.ElseList.CopyList())
}
// WithNode represents a {{with}} action and its commands.
type WithNode struct {
BranchNode
}
func (t *Tree) newWith(pos Pos, line int, pipe *PipeNode, list, elseList *ListNode) *WithNode {
return &WithNode{BranchNode{tr: t, NodeType: NodeWith, Pos: pos, Line: line, Pipe: pipe, List: list, ElseList: elseList}}
}
func (w *WithNode) Copy() Node {
return w.tr.newWith(w.Pos, w.Line, w.Pipe.CopyPipe(), w.List.CopyList(), w.ElseList.CopyList())
}
// TemplateNode represents a {{template}} action.
type TemplateNode struct {
NodeType
Pos
tr *Tree
Line int // The line number in the input. Deprecated: Kept for compatibility.
Name string // The name of the template (unquoted).
Pipe *PipeNode // The command to evaluate as dot for the template.
}
func (t *Tree) newTemplate(pos Pos, line int, name string, pipe *PipeNode) *TemplateNode {
return &TemplateNode{tr: t, NodeType: NodeTemplate, Pos: pos, Line: line, Name: name, Pipe: pipe}
}
func (t *TemplateNode) String() string {
var sb strings.Builder
t.writeTo(&sb)
return sb.String()
}
func (t *TemplateNode) writeTo(sb *strings.Builder) {
sb.WriteString("{{template ")
sb.WriteString(strconv.Quote(t.Name))
if t.Pipe != nil {
sb.WriteByte(' ')
t.Pipe.writeTo(sb)
}
sb.WriteString("}}")
}
func (t *TemplateNode) tree() *Tree {
return t.tr
}
func (t *TemplateNode) Copy() Node {
return t.tr.newTemplate(t.Pos, t.Line, t.Name, t.Pipe.CopyPipe())
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package parse builds parse trees for templates as defined by text/template
// and html/template. Clients should use those packages to construct templates
// rather than this one, which provides shared internal data structures not
// intended for general use.
package parse
import (
"bytes"
"fmt"
"runtime"
"strconv"
"strings"
)
// Tree is the representation of a single parsed template.
type Tree struct {
Name string // name of the template represented by the tree.
ParseName string // name of the top-level template during parsing, for error messages.
Root *ListNode // top-level root of the tree.
Mode Mode // parsing mode.
text string // text parsed to create the template (or its parent)
// Parsing only; cleared after parse.
funcs []map[string]any
lex *lexer
token [3]item // three-token lookahead for parser.
peekCount int
vars []string // variables defined at the moment.
treeSet map[string]*Tree
actionLine int // line of left delim starting action
rangeDepth int
}
// A mode value is a set of flags (or 0). Modes control parser behavior.
type Mode uint
const (
ParseComments Mode = 1 << iota // parse comments and add them to AST
SkipFuncCheck // do not check that functions are defined
)
// Copy returns a copy of the Tree. Any parsing state is discarded.
func (t *Tree) Copy() *Tree {
if t == nil {
return nil
}
return &Tree{
Name: t.Name,
ParseName: t.ParseName,
Root: t.Root.CopyList(),
text: t.text,
}
}
// Parse returns a map from template name to parse.Tree, created by parsing the
// templates described in the argument string. The top-level template will be
// given the specified name. If an error is encountered, parsing stops and an
// empty map is returned with the error.
func Parse(name, text, leftDelim, rightDelim string, funcs ...map[string]any) (map[string]*Tree, error) {
treeSet := make(map[string]*Tree)
t := New(name)
t.text = text
_, err := t.Parse(text, leftDelim, rightDelim, treeSet, funcs...)
return treeSet, err
}
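// A minimal sketch of direct use (most clients should go through
// text/template or html/template instead, as the package comment says):
//
//	trees, err := parse.Parse("root", `{{define "T"}}hello{{end}}`, "", "")
//	if err != nil {
//		// handle the parse error
//	}
//	_ = trees["T"] // the tree for the embedded {{define}}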
// next returns the next token.
func (t *Tree) next() item {
if t.peekCount > 0 {
t.peekCount--
} else {
t.token[0] = t.lex.nextItem()
}
return t.token[t.peekCount]
}
// backup backs the input stream up one token.
func (t *Tree) backup() {
t.peekCount++
}
// backup2 backs the input stream up two tokens.
// The zeroth token is already there.
func (t *Tree) backup2(t1 item) {
t.token[1] = t1
t.peekCount = 2
}
// backup3 backs the input stream up three tokens.
// The zeroth token is already there.
func (t *Tree) backup3(t2, t1 item) { // Reverse order: we're pushing back.
t.token[1] = t1
t.token[2] = t2
t.peekCount = 3
}
// peek returns but does not consume the next token.
func (t *Tree) peek() item {
if t.peekCount > 0 {
return t.token[t.peekCount-1]
}
t.peekCount = 1
t.token[0] = t.lex.nextItem()
return t.token[0]
}
// nextNonSpace returns the next non-space token.
func (t *Tree) nextNonSpace() (token item) {
for {
token = t.next()
if token.typ != itemSpace {
break
}
}
return token
}
// peekNonSpace returns but does not consume the next non-space token.
func (t *Tree) peekNonSpace() item {
token := t.nextNonSpace()
t.backup()
return token
}
// Parsing.
// New allocates a new parse tree with the given name.
func New(name string, funcs ...map[string]any) *Tree {
return &Tree{
Name: name,
funcs: funcs,
}
}
// ErrorContext returns a textual representation of the location of the node in the input text.
// The receiver is only used when the node does not have a pointer to the tree inside,
// which can occur in old code.
func (t *Tree) ErrorContext(n Node) (location, context string) {
pos := int(n.Position())
tree := n.tree()
if tree == nil {
tree = t
}
text := tree.text[:pos]
byteNum := strings.LastIndex(text, "\n")
if byteNum == -1 {
byteNum = pos // On first line.
} else {
byteNum++ // After the newline.
byteNum = pos - byteNum
}
lineNum := 1 + strings.Count(text, "\n")
context = n.String()
return fmt.Sprintf("%s:%d:%d", tree.ParseName, lineNum, byteNum), context
}
// errorf formats the error and terminates processing.
func (t *Tree) errorf(format string, args ...any) {
t.Root = nil
format = fmt.Sprintf("template: %s:%d: %s", t.ParseName, t.token[0].line, format)
panic(fmt.Errorf(format, args...))
}
// error terminates processing.
func (t *Tree) error(err error) {
t.errorf("%s", err)
}
// expect consumes the next token and guarantees it has the required type.
func (t *Tree) expect(expected itemType, context string) item {
token := t.nextNonSpace()
if token.typ != expected {
t.unexpected(token, context)
}
return token
}
// expectOneOf consumes the next token and guarantees it has one of the required types.
func (t *Tree) expectOneOf(expected1, expected2 itemType, context string) item {
token := t.nextNonSpace()
if token.typ != expected1 && token.typ != expected2 {
t.unexpected(token, context)
}
return token
}
// unexpected complains about the token and terminates processing.
func (t *Tree) unexpected(token item, context string) {
if token.typ == itemError {
extra := ""
if t.actionLine != 0 && t.actionLine != token.line {
extra = fmt.Sprintf(" in action started at %s:%d", t.ParseName, t.actionLine)
if strings.HasSuffix(token.val, " action") {
extra = extra[len(" in action"):] // avoid "action in action"
}
}
t.errorf("%s%s", token, extra)
}
t.errorf("unexpected %s in %s", token, context)
}
// recover is the handler that turns panics into returns from the top level of Parse.
func (t *Tree) recover(errp *error) {
e := recover()
if e != nil {
if _, ok := e.(runtime.Error); ok {
panic(e)
}
if t != nil {
t.stopParse()
}
*errp = e.(error)
}
}
// startParse initializes the parser, using the lexer.
func (t *Tree) startParse(funcs []map[string]any, lex *lexer, treeSet map[string]*Tree) {
t.Root = nil
t.lex = lex
t.vars = []string{"$"}
t.funcs = funcs
t.treeSet = treeSet
lex.options = lexOptions{
emitComment: t.Mode&ParseComments != 0,
breakOK: !t.hasFunction("break"),
continueOK: !t.hasFunction("continue"),
}
}
// stopParse terminates parsing.
func (t *Tree) stopParse() {
t.lex = nil
t.vars = nil
t.funcs = nil
t.treeSet = nil
}
// Parse parses the template definition string to construct a representation of
// the template for execution. If either action delimiter string is empty, the
// default ("{{" or "}}") is used. Embedded template definitions are added to
// the treeSet map.
func (t *Tree) Parse(text, leftDelim, rightDelim string, treeSet map[string]*Tree, funcs ...map[string]any) (tree *Tree, err error) {
defer t.recover(&err)
t.ParseName = t.Name
lexer := lex(t.Name, text, leftDelim, rightDelim)
t.startParse(funcs, lexer, treeSet)
t.text = text
t.parse()
t.add()
t.stopParse()
return t, nil
}
// add adds tree to t.treeSet.
func (t *Tree) add() {
tree := t.treeSet[t.Name]
if tree == nil || IsEmptyTree(tree.Root) {
t.treeSet[t.Name] = t
return
}
if !IsEmptyTree(t.Root) {
t.errorf("template: multiple definition of template %q", t.Name)
}
}
// IsEmptyTree reports whether this tree (node) is empty of everything but space or comments.
func IsEmptyTree(n Node) bool {
switch n := n.(type) {
case nil:
return true
case *ActionNode:
case *CommentNode:
return true
case *IfNode:
case *ListNode:
for _, node := range n.Nodes {
if !IsEmptyTree(node) {
return false
}
}
return true
case *RangeNode:
case *TemplateNode:
case *TextNode:
return len(bytes.TrimSpace(n.Text)) == 0
case *WithNode:
default:
panic("unknown node: " + n.String())
}
return false
}
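// For example, a root list whose only node is the text "  \n\t " is empty,
// while a list containing any action, template invocation, or non-blank text
// is not.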
// parse is the top-level parser for a template, essentially the same
// as itemList except it also parses {{define}} actions.
// It runs to EOF.
func (t *Tree) parse() {
t.Root = t.newList(t.peek().pos)
for t.peek().typ != itemEOF {
if t.peek().typ == itemLeftDelim {
delim := t.next()
if t.nextNonSpace().typ == itemDefine {
newT := New("definition") // name will be updated once we know it.
newT.text = t.text
newT.Mode = t.Mode
newT.ParseName = t.ParseName
newT.startParse(t.funcs, t.lex, t.treeSet)
newT.parseDefinition()
continue
}
t.backup2(delim)
}
switch n := t.textOrAction(); n.Type() {
case nodeEnd, nodeElse:
t.errorf("unexpected %s", n)
default:
t.Root.append(n)
}
}
}
// parseDefinition parses a {{define}} ... {{end}} template definition and
// installs the definition in t.treeSet. The "define" keyword has already
// been scanned.
func (t *Tree) parseDefinition() {
const context = "define clause"
name := t.expectOneOf(itemString, itemRawString, context)
var err error
t.Name, err = strconv.Unquote(name.val)
if err != nil {
t.error(err)
}
t.expect(itemRightDelim, context)
var end Node
t.Root, end = t.itemList()
if end.Type() != nodeEnd {
t.errorf("unexpected %s in %s", end, context)
}
t.add()
t.stopParse()
}
// itemList:
//
// textOrAction*
//
// Terminates at {{end}} or {{else}}, returned separately.
func (t *Tree) itemList() (list *ListNode, next Node) {
list = t.newList(t.peekNonSpace().pos)
for t.peekNonSpace().typ != itemEOF {
n := t.textOrAction()
switch n.Type() {
case nodeEnd, nodeElse:
return list, n
}
list.append(n)
}
t.errorf("unexpected EOF")
return
}
// textOrAction:
//
// text | comment | action
func (t *Tree) textOrAction() Node {
switch token := t.nextNonSpace(); token.typ {
case itemText:
return t.newText(token.pos, token.val)
case itemLeftDelim:
t.actionLine = token.line
defer t.clearActionLine()
return t.action()
case itemComment:
return t.newComment(token.pos, token.val)
default:
t.unexpected(token, "input")
}
return nil
}
func (t *Tree) clearActionLine() {
t.actionLine = 0
}
// Action:
//
// control
// command ("|" command)*
//
// Left delim is past. Now get actions.
// First word could be a keyword such as range.
func (t *Tree) action() (n Node) {
switch token := t.nextNonSpace(); token.typ {
case itemBlock:
return t.blockControl()
case itemBreak:
return t.breakControl(token.pos, token.line)
case itemContinue:
return t.continueControl(token.pos, token.line)
case itemElse:
return t.elseControl()
case itemEnd:
return t.endControl()
case itemIf:
return t.ifControl()
case itemRange:
return t.rangeControl()
case itemTemplate:
return t.templateControl()
case itemWith:
return t.withControl()
}
t.backup()
token := t.peek()
// Do not pop variables; they persist until "end".
return t.newAction(token.pos, token.line, t.pipeline("command", itemRightDelim))
}
// Break:
//
// {{break}}
//
// Break keyword is past.
func (t *Tree) breakControl(pos Pos, line int) Node {
if token := t.nextNonSpace(); token.typ != itemRightDelim {
t.unexpected(token, "{{break}}")
}
if t.rangeDepth == 0 {
t.errorf("{{break}} outside {{range}}")
}
return t.newBreak(pos, line)
}
// Continue:
//
// {{continue}}
//
// Continue keyword is past.
func (t *Tree) continueControl(pos Pos, line int) Node {
if token := t.nextNonSpace(); token.typ != itemRightDelim {
t.unexpected(token, "{{continue}}")
}
if t.rangeDepth == 0 {
t.errorf("{{continue}} outside {{range}}")
}
return t.newContinue(pos, line)
}
// Pipeline:
//
// declarations? command ('|' command)*
func (t *Tree) pipeline(context string, end itemType) (pipe *PipeNode) {
token := t.peekNonSpace()
pipe = t.newPipeline(token.pos, token.line, nil)
// Are there declarations or assignments?
decls:
if v := t.peekNonSpace(); v.typ == itemVariable {
t.next()
// Since space is a token, we need 3-token look-ahead here in the worst case:
// in "$x foo" we need to read "foo" (as opposed to ":=") to know that $x is an
// argument variable rather than a declaration. So remember the token
// adjacent to the variable so we can push it back if necessary.
tokenAfterVariable := t.peek()
next := t.peekNonSpace()
switch {
case next.typ == itemAssign, next.typ == itemDeclare:
pipe.IsAssign = next.typ == itemAssign
t.nextNonSpace()
pipe.Decl = append(pipe.Decl, t.newVariable(v.pos, v.val))
t.vars = append(t.vars, v.val)
case next.typ == itemChar && next.val == ",":
t.nextNonSpace()
pipe.Decl = append(pipe.Decl, t.newVariable(v.pos, v.val))
t.vars = append(t.vars, v.val)
if context == "range" && len(pipe.Decl) < 2 {
switch t.peekNonSpace().typ {
case itemVariable, itemRightDelim, itemRightParen:
// second initialized variable in a range pipeline
goto decls
default:
t.errorf("range can only initialize variables")
}
}
t.errorf("too many declarations in %s", context)
case tokenAfterVariable.typ == itemSpace:
t.backup3(v, tokenAfterVariable)
default:
t.backup2(v)
}
}
for {
switch token := t.nextNonSpace(); token.typ {
case end:
// At this point, the pipeline is complete
t.checkPipeline(pipe, context)
return
case itemBool, itemCharConstant, itemComplex, itemDot, itemField, itemIdentifier,
itemNumber, itemNil, itemRawString, itemString, itemVariable, itemLeftParen:
t.backup()
pipe.append(t.command())
default:
t.unexpected(token, context)
}
}
}
func (t *Tree) checkPipeline(pipe *PipeNode, context string) {
// Reject empty pipelines
if len(pipe.Cmds) == 0 {
t.errorf("missing value for %s", context)
}
// Only the first command of a pipeline can start with a non-executable operand.
for i, c := range pipe.Cmds[1:] {
switch c.Args[0].Type() {
case NodeBool, NodeDot, NodeNil, NodeNumber, NodeString:
// With A|B|C, pipeline stage 2 is B
t.errorf("non executable command in pipeline stage %d", i+2)
}
}
}
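// For example, {{$i, $v := .Items}} fails with "too many declarations in
// command" because only range pipelines may initialize two variables, and
// {{.X | 3}} fails with "non executable command in pipeline stage 2" because
// only the first command may start with a literal operand.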
func (t *Tree) parseControl(allowElseIf bool, context string) (pos Pos, line int, pipe *PipeNode, list, elseList *ListNode) {
defer t.popVars(len(t.vars))
pipe = t.pipeline(context, itemRightDelim)
if context == "range" {
t.rangeDepth++
}
var next Node
list, next = t.itemList()
if context == "range" {
t.rangeDepth--
}
switch next.Type() {
case nodeEnd: // done
case nodeElse:
if allowElseIf {
// Special case for "else if". If the "else" is followed immediately by an "if",
// the elseControl will have left the "if" token pending. Treat
// {{if a}}_{{else if b}}_{{end}}
// as
// {{if a}}_{{else}}{{if b}}_{{end}}{{end}}.
// To do this, parse the if as usual and stop at its {{end}}; the subsequent {{end}}
// is assumed. This technique works even for long if-else-if chains.
// TODO: Should we allow else-if in with and range?
if t.peek().typ == itemIf {
t.next() // Consume the "if" token.
elseList = t.newList(next.Position())
elseList.append(t.ifControl())
// Do not consume the next item - only one {{end}} required.
break
}
}
elseList, next = t.itemList()
if next.Type() != nodeEnd {
t.errorf("expected end; found %s", next)
}
}
return pipe.Position(), pipe.Line, pipe, list, elseList
}
// If:
//
// {{if pipeline}} itemList {{end}}
// {{if pipeline}} itemList {{else}} itemList {{end}}
//
// If keyword is past.
func (t *Tree) ifControl() Node {
return t.newIf(t.parseControl(true, "if"))
}
// Range:
//
// {{range pipeline}} itemList {{end}}
// {{range pipeline}} itemList {{else}} itemList {{end}}
//
// Range keyword is past.
func (t *Tree) rangeControl() Node {
r := t.newRange(t.parseControl(false, "range"))
return r
}
// With:
//
// {{with pipeline}} itemList {{end}}
// {{with pipeline}} itemList {{else}} itemList {{end}}
//
// If keyword is past.
func (t *Tree) withControl() Node {
return t.newWith(t.parseControl(false, "with"))
}
// End:
//
// {{end}}
//
// End keyword is past.
func (t *Tree) endControl() Node {
return t.newEnd(t.expect(itemRightDelim, "end").pos)
}
// Else:
//
// {{else}}
//
// Else keyword is past.
func (t *Tree) elseControl() Node {
// Special case for "else if".
peek := t.peekNonSpace()
if peek.typ == itemIf {
// We see "{{else if ... " but in effect rewrite it to {{else}}{{if ... ".
return t.newElse(peek.pos, peek.line)
}
token := t.expect(itemRightDelim, "else")
return t.newElse(token.pos, token.line)
}
// Block:
//
// {{block stringValue pipeline}}
//
// Block keyword is past.
// The name must be something that can evaluate to a string.
// The pipeline is mandatory.
func (t *Tree) blockControl() Node {
const context = "block clause"
token := t.nextNonSpace()
name := t.parseTemplateName(token, context)
pipe := t.pipeline(context, itemRightDelim)
block := New(name) // The block's name is already known; parse its body as the definition.
block.text = t.text
block.Mode = t.Mode
block.ParseName = t.ParseName
block.startParse(t.funcs, t.lex, t.treeSet)
var end Node
block.Root, end = block.itemList()
if end.Type() != nodeEnd {
t.errorf("unexpected %s in %s", end, context)
}
block.add()
block.stopParse()
return t.newTemplate(token.pos, token.line, name, pipe)
}
// Template:
//
// {{template stringValue pipeline}}
//
// Template keyword is past. The name must be something that can evaluate
// to a string.
func (t *Tree) templateControl() Node {
const context = "template clause"
token := t.nextNonSpace()
name := t.parseTemplateName(token, context)
var pipe *PipeNode
if t.nextNonSpace().typ != itemRightDelim {
t.backup()
// Do not pop variables; they persist until "end".
pipe = t.pipeline(context, itemRightDelim)
}
return t.newTemplate(token.pos, token.line, name, pipe)
}
func (t *Tree) parseTemplateName(token item, context string) (name string) {
switch token.typ {
case itemString, itemRawString:
s, err := strconv.Unquote(token.val)
if err != nil {
t.error(err)
}
name = s
default:
t.unexpected(token, context)
}
return
}
// command:
//
// operand (space operand)*
//
// Space-separated arguments up to a pipeline character or right delimiter.
// We consume the pipe character but leave the right delim to terminate the action.
func (t *Tree) command() *CommandNode {
cmd := t.newCommand(t.peekNonSpace().pos)
for {
t.peekNonSpace() // skip leading spaces.
operand := t.operand()
if operand != nil {
cmd.append(operand)
}
switch token := t.next(); token.typ {
case itemSpace:
continue
case itemRightDelim, itemRightParen:
t.backup()
case itemPipe:
// nothing here; break loop below
default:
t.unexpected(token, "operand")
}
break
}
if len(cmd.Args) == 0 {
t.errorf("empty command")
}
return cmd
}
// operand:
//
// term .Field*
//
// An operand is a space-separated component of a command,
// a term possibly followed by field accesses.
// A nil return means the next item is not an operand.
func (t *Tree) operand() Node {
node := t.term()
if node == nil {
return nil
}
if t.peek().typ == itemField {
chain := t.newChain(t.peek().pos, node)
for t.peek().typ == itemField {
chain.Add(t.next().val)
}
// Compatibility with original API: If the term is of type NodeField
// or NodeVariable, just put more fields on the original.
// Otherwise, keep the Chain node.
// Obvious parsing errors involving literal values are detected here.
// More complex error cases will have to be handled at execution time.
switch node.Type() {
case NodeField:
node = t.newField(chain.Position(), chain.String())
case NodeVariable:
node = t.newVariable(chain.Position(), chain.String())
case NodeBool, NodeString, NodeNumber, NodeNil, NodeDot:
t.errorf("unexpected . after term %q", node.String())
default:
node = chain
}
}
return node
}
// term:
//
// literal (number, string, nil, boolean)
// function (identifier)
// .
// .Field
// $
// '(' pipeline ')'
//
// A term is a simple "expression".
// A nil return means the next item is not a term.
func (t *Tree) term() Node {
switch token := t.nextNonSpace(); token.typ {
case itemIdentifier:
checkFunc := t.Mode&SkipFuncCheck == 0
if checkFunc && !t.hasFunction(token.val) {
t.errorf("function %q not defined", token.val)
}
return NewIdentifier(token.val).SetTree(t).SetPos(token.pos)
case itemDot:
return t.newDot(token.pos)
case itemNil:
return t.newNil(token.pos)
case itemVariable:
return t.useVar(token.pos, token.val)
case itemField:
return t.newField(token.pos, token.val)
case itemBool:
return t.newBool(token.pos, token.val == "true")
case itemCharConstant, itemComplex, itemNumber:
number, err := t.newNumber(token.pos, token.val, token.typ)
if err != nil {
t.error(err)
}
return number
case itemLeftParen:
return t.pipeline("parenthesized pipeline", itemRightParen)
case itemString, itemRawString:
s, err := strconv.Unquote(token.val)
if err != nil {
t.error(err)
}
return t.newString(token.pos, token.val, s)
}
t.backup()
return nil
}
// hasFunction reports whether a function name exists in the Tree's maps.
func (t *Tree) hasFunction(name string) bool {
for _, funcMap := range t.funcs {
if funcMap == nil {
continue
}
if funcMap[name] != nil {
return true
}
}
return false
}
// popVars trims the variable list to the specified length.
func (t *Tree) popVars(n int) {
t.vars = t.vars[:n]
}
// useVar returns a node for a variable reference. It errors if the
// variable is not defined.
func (t *Tree) useVar(pos Pos, name string) Node {
v := t.newVariable(pos, name)
for _, varName := range t.vars {
if varName == v.Ident[0] {
return v
}
}
t.errorf("undefined variable %q", v.Ident[0])
return nil
}
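// Variables are thus checked statically at parse time: {{$x := 1}}{{$x}}
// parses, while {{$oops}} with no declaration in scope fails with an
// "undefined variable" error. Declarations persist until the {{end}} that
// closes the control structure in which they were made.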
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package template
import (
"reflect"
"sync"
"text/template/parse"
)
// common holds the information shared by related templates.
type common struct {
tmpl map[string]*Template // Map from name to defined templates.
muTmpl sync.RWMutex // protects tmpl
option option
// We use two maps, one for parsing and one for execution.
// This separation makes the API cleaner since it doesn't
// expose reflection to the client.
muFuncs sync.RWMutex // protects parseFuncs and execFuncs
parseFuncs FuncMap
execFuncs map[string]reflect.Value
}
// Template is the representation of a parsed template. The *parse.Tree
// field is exported only for use by html/template and should be treated
// as unexported by all other clients.
type Template struct {
name string
*parse.Tree
*common
leftDelim string
rightDelim string
}
// New allocates a new, undefined template with the given name.
func New(name string) *Template {
t := &Template{
name: name,
}
t.init()
return t
}
// Name returns the name of the template.
func (t *Template) Name() string {
return t.name
}
// New allocates a new, undefined template associated with the given one and with the same
// delimiters. The association, which is transitive, allows one template to
// invoke another with a {{template}} action.
//
// Because associated templates share underlying data, template construction
// cannot be done safely in parallel. Once the templates are constructed, they
// can be executed in parallel.
func (t *Template) New(name string) *Template {
t.init()
nt := &Template{
name: name,
common: t.common,
leftDelim: t.leftDelim,
rightDelim: t.rightDelim,
}
return nt
}
// init guarantees that t has a valid common structure.
func (t *Template) init() {
if t.common == nil {
c := new(common)
c.tmpl = make(map[string]*Template)
c.parseFuncs = make(FuncMap)
c.execFuncs = make(map[string]reflect.Value)
t.common = c
}
}
// Clone returns a duplicate of the template, including all associated
// templates. The actual representation is not copied, but the name space of
// associated templates is, so further calls to Parse in the copy will add
// templates to the copy but not to the original. Clone can be used to prepare
// common templates and use them with variant definitions for other templates
// by adding the variants after the clone is made.
func (t *Template) Clone() (*Template, error) {
nt := t.copy(nil)
nt.init()
if t.common == nil {
return nt, nil
}
t.muTmpl.RLock()
defer t.muTmpl.RUnlock()
for k, v := range t.tmpl {
if k == t.name {
nt.tmpl[t.name] = nt
continue
}
// The associated templates share nt's common structure.
tmpl := v.copy(nt.common)
nt.tmpl[k] = tmpl
}
t.muFuncs.RLock()
defer t.muFuncs.RUnlock()
for k, v := range t.parseFuncs {
nt.parseFuncs[k] = v
}
for k, v := range t.execFuncs {
nt.execFuncs[k] = v
}
return nt, nil
}
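// A minimal sketch of the variant-definition pattern described above
// (assuming the public text/template API):
//
//	base := template.Must(template.New("page").Parse(`header {{template "body"}}`))
//	variant := template.Must(base.Clone())
//	template.Must(variant.Parse(`{{define "body"}}variant body{{end}}`))
//	// "body" is now defined for variant but not for base.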
// copy returns a shallow copy of t, with common set to the argument.
func (t *Template) copy(c *common) *Template {
return &Template{
name: t.name,
Tree: t.Tree,
common: c,
leftDelim: t.leftDelim,
rightDelim: t.rightDelim,
}
}
// AddParseTree associates the argument parse tree with the template t, giving
// it the specified name. If the template has not been defined, this tree becomes
// its definition. If it has been defined and already has that name, the existing
// definition is replaced; otherwise a new template is created, defined, and returned.
func (t *Template) AddParseTree(name string, tree *parse.Tree) (*Template, error) {
t.init()
t.muTmpl.Lock()
defer t.muTmpl.Unlock()
nt := t
if name != t.name {
nt = t.New(name)
}
// Even if nt == t, we need to install it in the common.tmpl map.
if t.associate(nt, tree) || nt.Tree == nil {
nt.Tree = tree
}
return nt, nil
}
// Templates returns a slice of defined templates associated with t.
func (t *Template) Templates() []*Template {
if t.common == nil {
return nil
}
// Return a slice so we don't expose the map.
t.muTmpl.RLock()
defer t.muTmpl.RUnlock()
m := make([]*Template, 0, len(t.tmpl))
for _, v := range t.tmpl {
m = append(m, v)
}
return m
}
// Delims sets the action delimiters to the specified strings, to be used in
// subsequent calls to Parse, ParseFiles, or ParseGlob. Nested template
// definitions will inherit the settings. An empty delimiter stands for the
// corresponding default: {{ or }}.
// The return value is the template, so calls can be chained.
func (t *Template) Delims(left, right string) *Template {
t.init()
t.leftDelim = left
t.rightDelim = right
return t
}
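// For example (imports omitted; assumes the public API as above):
//
//	t := template.Must(template.New("t").Delims("<<", ">>").Parse("<<.>>"))
//	t.Execute(os.Stdout, "hi") // prints "hi"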
// Funcs adds the elements of the argument map to the template's function map.
// It must be called before the template is parsed.
// It panics if a value in the map is not a function with appropriate return
// type or if the name cannot be used syntactically as a function in a template.
// It is legal to overwrite elements of the map. The return value is the template,
// so calls can be chained.
func (t *Template) Funcs(funcMap FuncMap) *Template {
t.init()
t.muFuncs.Lock()
defer t.muFuncs.Unlock()
addValueFuncs(t.execFuncs, funcMap)
addFuncs(t.parseFuncs, funcMap)
return t
}
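// For example, a sketch that installs one function before parsing
// (imports omitted):
//
//	t := template.Must(template.New("t").
//		Funcs(template.FuncMap{"upper": strings.ToUpper}).
//		Parse(`{{upper "go"}}`))
//	t.Execute(os.Stdout, nil) // prints "GO"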
// Lookup returns the template with the given name that is associated with t.
// It returns nil if there is no such template or the template has no definition.
func (t *Template) Lookup(name string) *Template {
if t.common == nil {
return nil
}
t.muTmpl.RLock()
defer t.muTmpl.RUnlock()
return t.tmpl[name]
}
// Parse parses text as a template body for t.
// Named template definitions ({{define ...}} or {{block ...}} statements) in text
// define additional templates associated with t and are removed from the
// definition of t itself.
//
// Templates can be redefined in successive calls to Parse.
// A template definition with a body containing only white space and comments
// is considered empty and will not replace an existing template's body.
// This allows using Parse to add new named template definitions without
// overwriting the main template body.
func (t *Template) Parse(text string) (*Template, error) {
t.init()
t.muFuncs.RLock()
trees, err := parse.Parse(t.name, text, t.leftDelim, t.rightDelim, t.parseFuncs, builtins())
t.muFuncs.RUnlock()
if err != nil {
return nil, err
}
// Add the newly parsed trees, including the one for t, into our common structure.
for name, tree := range trees {
if _, err := t.AddParseTree(name, tree); err != nil {
return nil, err
}
}
return t, nil
}
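// A minimal end-to-end sketch, pairing Parse with Execute from this
// package's public API:
//
//	package main
//
//	import (
//		"os"
//		"text/template"
//	)
//
//	func main() {
//		t := template.Must(template.New("letter").Parse("Dear {{.Name}},\n"))
//		t.Execute(os.Stdout, map[string]string{"Name": "Gopher"}) // Dear Gopher,
//	}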
// associate installs the new template into the group of templates associated
// with t. The two are already known to share the common structure.
// The boolean return value reports whether to store this tree as t.Tree.
func (t *Template) associate(new *Template, tree *parse.Tree) bool {
if new.common != t.common {
panic("internal error: associate not common")
}
if old := t.tmpl[new.name]; old != nil && parse.IsEmptyTree(tree.Root) && old.Tree != nil {
// If a template by that name exists,
// don't replace it with an empty template.
return false
}
t.tmpl[new.name] = new
return true
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package time
import "errors"
// These are predefined layouts for use in Time.Format and time.Parse.
// The reference time used in these layouts is the specific time stamp:
//
// 01/02 03:04:05PM '06 -0700
//
// (January 2, 15:04:05, 2006, in time zone seven hours west of GMT).
// That value is recorded as the constant named Layout, listed below. As a Unix
// time, this is 1136239445. Since MST is GMT-0700, the reference would be
// printed by the Unix date command as:
//
// Mon Jan 2 15:04:05 MST 2006
//
// It is a regrettable historic error that the date uses the American convention
// of putting the numerical month before the day.
//
// The example for Time.Format demonstrates the working of the layout string
// in detail and is a good reference.
//
// Note that the RFC822, RFC850, and RFC1123 formats should be applied
// only to local times. Applying them to UTC times will use "UTC" as the
// time zone abbreviation, while strictly speaking those RFCs require the
// use of "GMT" in that case.
// In general RFC1123Z should be used instead of RFC1123 for servers
// that insist on that format, and RFC3339 should be preferred for new protocols.
// RFC3339, RFC822, RFC822Z, RFC1123, and RFC1123Z are useful for formatting;
// when used with time.Parse they do not accept all the time formats
// permitted by the RFCs and they do accept time formats not formally defined.
// The RFC3339Nano format removes trailing zeros from the seconds field
// and thus may not sort correctly once formatted.
//
// Most programs can use one of the defined constants as the layout passed to
// Format or Parse. The rest of this comment can be ignored unless you are
// creating a custom layout string.
//
// To define your own format, write down what the reference time would look like
// formatted your way; see the values of constants like ANSIC, StampMicro or
// Kitchen for examples. The model is to demonstrate what the reference time
// looks like so that the Format and Parse methods can apply the same
// transformation to a general time value.
//
// Here is a summary of the components of a layout string. Each element shows by
// example the formatting of an element of the reference time. Only these values
// are recognized. Text in the layout string that is not recognized as part of
// the reference time is echoed verbatim during Format and expected to appear
// verbatim in the input to Parse.
//
// Year: "2006" "06"
// Month: "Jan" "January" "01" "1"
// Day of the week: "Mon" "Monday"
// Day of the month: "2" "_2" "02"
// Day of the year: "__2" "002"
// Hour: "15" "3" "03" (PM or AM)
// Minute: "4" "04"
// Second: "5" "05"
// AM/PM mark: "PM"
//
// Numeric time zone offsets format as follows:
//
// "-0700" ±hhmm
// "-07:00" ±hh:mm
// "-07" ±hh
// "-070000" ±hhmmss
// "-07:00:00" ±hh:mm:ss
//
// Replacing the sign in the format with a Z triggers
// the ISO 8601 behavior of printing Z instead of an
// offset for the UTC zone. Thus:
//
// "Z0700" Z or ±hhmm
// "Z07:00" Z or ±hh:mm
// "Z07" Z or ±hh
// "Z070000" Z or ±hhmmss
// "Z07:00:00" Z or ±hh:mm:ss
//
// Within the format string, the underscores in "_2" and "__2" represent spaces
// that may be replaced by digits if the following number has multiple digits,
// for compatibility with fixed-width Unix time formats. A leading zero represents
// a zero-padded value.
//
// The formats __2 and 002 are space-padded and zero-padded
// three-character day of year; there is no unpadded day of year format.
//
// A comma or decimal point followed by one or more zeros represents
// a fractional second, printed to the given number of decimal places.
// A comma or decimal point followed by one or more nines represents
// a fractional second, printed to the given number of decimal places, with
// trailing zeros removed.
// For example "15:04:05,000" or "15:04:05.000" formats or parses with
// millisecond precision.
//
// Some valid layouts are invalid time values for time.Parse, due to formats
// such as _ for space padding and Z for zone information.
const (
Layout = "01/02 03:04:05PM '06 -0700" // The reference time, in numerical order.
ANSIC = "Mon Jan _2 15:04:05 2006"
UnixDate = "Mon Jan _2 15:04:05 MST 2006"
RubyDate = "Mon Jan 02 15:04:05 -0700 2006"
RFC822 = "02 Jan 06 15:04 MST"
RFC822Z = "02 Jan 06 15:04 -0700" // RFC822 with numeric zone
RFC850 = "Monday, 02-Jan-06 15:04:05 MST"
RFC1123 = "Mon, 02 Jan 2006 15:04:05 MST"
RFC1123Z = "Mon, 02 Jan 2006 15:04:05 -0700" // RFC1123 with numeric zone
RFC3339 = "2006-01-02T15:04:05Z07:00"
RFC3339Nano = "2006-01-02T15:04:05.999999999Z07:00"
Kitchen = "3:04PM"
// Handy time stamps.
Stamp = "Jan _2 15:04:05"
StampMilli = "Jan _2 15:04:05.000"
StampMicro = "Jan _2 15:04:05.000000"
StampNano = "Jan _2 15:04:05.000000000"
DateTime = "2006-01-02 15:04:05"
DateOnly = "2006-01-02"
TimeOnly = "15:04:05"
)
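// As an illustrative sketch (the values here are arbitrary, not part of the package):
//
//	t := time.Date(2006, time.January, 2, 15, 4, 5, 0, time.UTC)
//	t.Format(time.Kitchen) // "3:04PM"
//	t.Format(time.RFC3339) // "2006-01-02T15:04:05Z"
//	u, err := time.Parse(time.DateOnly, "2024-06-30")
//	// u is 2024-06-30 00:00:00 +0000 UTC; err is nil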
const (
_ = iota
stdLongMonth = iota + stdNeedDate // "January"
stdMonth // "Jan"
stdNumMonth // "1"
stdZeroMonth // "01"
stdLongWeekDay // "Monday"
stdWeekDay // "Mon"
stdDay // "2"
stdUnderDay // "_2"
stdZeroDay // "02"
stdUnderYearDay // "__2"
stdZeroYearDay // "002"
stdHour = iota + stdNeedClock // "15"
stdHour12 // "3"
stdZeroHour12 // "03"
stdMinute // "4"
stdZeroMinute // "04"
stdSecond // "5"
stdZeroSecond // "05"
stdLongYear = iota + stdNeedDate // "2006"
stdYear // "06"
stdPM = iota + stdNeedClock // "PM"
stdpm // "pm"
stdTZ = iota // "MST"
stdISO8601TZ // "Z0700" // prints Z for UTC
stdISO8601SecondsTZ // "Z070000"
stdISO8601ShortTZ // "Z07"
stdISO8601ColonTZ // "Z07:00" // prints Z for UTC
stdISO8601ColonSecondsTZ // "Z07:00:00"
stdNumTZ // "-0700" // always numeric
stdNumSecondsTz // "-070000"
stdNumShortTZ // "-07" // always numeric
stdNumColonTZ // "-07:00" // always numeric
stdNumColonSecondsTZ // "-07:00:00"
stdFracSecond0 // ".0", ".00", ... , trailing zeros included
stdFracSecond9 // ".9", ".99", ..., trailing zeros omitted
stdNeedDate = 1 << 8 // need month, day, year
stdNeedClock = 2 << 8 // need hour, minute, second
stdArgShift = 16 // extra argument in high bits, above low stdArgShift
stdSeparatorShift = 28 // extra argument in high 4 bits for fractional second separators
stdMask = 1<<stdArgShift - 1 // mask out argument
)
// std0x records the std values for "01", "02", ..., "06".
var std0x = [...]int{stdZeroMonth, stdZeroDay, stdZeroHour12, stdZeroMinute, stdZeroSecond, stdYear}
// startsWithLowerCase reports whether the string has a lower-case letter at the beginning.
// Its purpose is to prevent matching strings like "Month" when looking for "Mon".
func startsWithLowerCase(str string) bool {
if len(str) == 0 {
return false
}
c := str[0]
return 'a' <= c && c <= 'z'
}
// nextStdChunk finds the first occurrence of a std string in
// layout and returns the text before, the std string, and the text after.
func nextStdChunk(layout string) (prefix string, std int, suffix string) {
for i := 0; i < len(layout); i++ {
switch c := int(layout[i]); c {
case 'J': // January, Jan
if len(layout) >= i+3 && layout[i:i+3] == "Jan" {
if len(layout) >= i+7 && layout[i:i+7] == "January" {
return layout[0:i], stdLongMonth, layout[i+7:]
}
if !startsWithLowerCase(layout[i+3:]) {
return layout[0:i], stdMonth, layout[i+3:]
}
}
case 'M': // Monday, Mon, MST
if len(layout) >= i+3 {
if layout[i:i+3] == "Mon" {
if len(layout) >= i+6 && layout[i:i+6] == "Monday" {
return layout[0:i], stdLongWeekDay, layout[i+6:]
}
if !startsWithLowerCase(layout[i+3:]) {
return layout[0:i], stdWeekDay, layout[i+3:]
}
}
if layout[i:i+3] == "MST" {
return layout[0:i], stdTZ, layout[i+3:]
}
}
case '0': // 01, 02, 03, 04, 05, 06, 002
if len(layout) >= i+2 && '1' <= layout[i+1] && layout[i+1] <= '6' {
return layout[0:i], std0x[layout[i+1]-'1'], layout[i+2:]
}
if len(layout) >= i+3 && layout[i+1] == '0' && layout[i+2] == '2' {
return layout[0:i], stdZeroYearDay, layout[i+3:]
}
case '1': // 15, 1
if len(layout) >= i+2 && layout[i+1] == '5' {
return layout[0:i], stdHour, layout[i+2:]
}
return layout[0:i], stdNumMonth, layout[i+1:]
case '2': // 2006, 2
if len(layout) >= i+4 && layout[i:i+4] == "2006" {
return layout[0:i], stdLongYear, layout[i+4:]
}
return layout[0:i], stdDay, layout[i+1:]
case '_': // _2, _2006, __2
if len(layout) >= i+2 && layout[i+1] == '2' {
// _2006 is really a literal _, followed by stdLongYear.
if len(layout) >= i+5 && layout[i+1:i+5] == "2006" {
return layout[0 : i+1], stdLongYear, layout[i+5:]
}
return layout[0:i], stdUnderDay, layout[i+2:]
}
if len(layout) >= i+3 && layout[i+1] == '_' && layout[i+2] == '2' {
return layout[0:i], stdUnderYearDay, layout[i+3:]
}
case '3':
return layout[0:i], stdHour12, layout[i+1:]
case '4':
return layout[0:i], stdMinute, layout[i+1:]
case '5':
return layout[0:i], stdSecond, layout[i+1:]
case 'P': // PM
if len(layout) >= i+2 && layout[i+1] == 'M' {
return layout[0:i], stdPM, layout[i+2:]
}
case 'p': // pm
if len(layout) >= i+2 && layout[i+1] == 'm' {
return layout[0:i], stdpm, layout[i+2:]
}
case '-': // -070000, -07:00:00, -0700, -07:00, -07
if len(layout) >= i+7 && layout[i:i+7] == "-070000" {
return layout[0:i], stdNumSecondsTz, layout[i+7:]
}
if len(layout) >= i+9 && layout[i:i+9] == "-07:00:00" {
return layout[0:i], stdNumColonSecondsTZ, layout[i+9:]
}
if len(layout) >= i+5 && layout[i:i+5] == "-0700" {
return layout[0:i], stdNumTZ, layout[i+5:]
}
if len(layout) >= i+6 && layout[i:i+6] == "-07:00" {
return layout[0:i], stdNumColonTZ, layout[i+6:]
}
if len(layout) >= i+3 && layout[i:i+3] == "-07" {
return layout[0:i], stdNumShortTZ, layout[i+3:]
}
case 'Z': // Z070000, Z07:00:00, Z0700, Z07:00,
if len(layout) >= i+7 && layout[i:i+7] == "Z070000" {
return layout[0:i], stdISO8601SecondsTZ, layout[i+7:]
}
if len(layout) >= i+9 && layout[i:i+9] == "Z07:00:00" {
return layout[0:i], stdISO8601ColonSecondsTZ, layout[i+9:]
}
if len(layout) >= i+5 && layout[i:i+5] == "Z0700" {
return layout[0:i], stdISO8601TZ, layout[i+5:]
}
if len(layout) >= i+6 && layout[i:i+6] == "Z07:00" {
return layout[0:i], stdISO8601ColonTZ, layout[i+6:]
}
if len(layout) >= i+3 && layout[i:i+3] == "Z07" {
return layout[0:i], stdISO8601ShortTZ, layout[i+3:]
}
case '.', ',': // ,000, or .000, or ,999, or .999 - repeated digits for fractional seconds.
if i+1 < len(layout) && (layout[i+1] == '0' || layout[i+1] == '9') {
ch := layout[i+1]
j := i + 1
for j < len(layout) && layout[j] == ch {
j++
}
// String of digits must end here - only fractional second is all digits.
if !isDigit(layout, j) {
code := stdFracSecond0
if layout[i+1] == '9' {
code = stdFracSecond9
}
std := stdFracSecond(code, j-(i+1), c)
return layout[0:i], std, layout[j:]
}
}
}
}
return layout, 0, ""
}
var longDayNames = []string{
"Sunday",
"Monday",
"Tuesday",
"Wednesday",
"Thursday",
"Friday",
"Saturday",
}
var shortDayNames = []string{
"Sun",
"Mon",
"Tue",
"Wed",
"Thu",
"Fri",
"Sat",
}
var shortMonthNames = []string{
"Jan",
"Feb",
"Mar",
"Apr",
"May",
"Jun",
"Jul",
"Aug",
"Sep",
"Oct",
"Nov",
"Dec",
}
var longMonthNames = []string{
"January",
"February",
"March",
"April",
"May",
"June",
"July",
"August",
"September",
"October",
"November",
"December",
}
// match reports whether s1 and s2 match ignoring case.
// It is assumed s1 and s2 are the same length.
func match(s1, s2 string) bool {
for i := 0; i < len(s1); i++ {
c1 := s1[i]
c2 := s2[i]
if c1 != c2 {
// Switch to lower-case; 'a'-'A' is known to be a single bit.
c1 |= 'a' - 'A'
c2 |= 'a' - 'A'
if c1 != c2 || c1 < 'a' || c1 > 'z' {
return false
}
}
}
return true
}
func lookup(tab []string, val string) (int, string, error) {
for i, v := range tab {
if len(val) >= len(v) && match(val[0:len(v)], v) {
return i, val[len(v):], nil
}
}
return -1, val, errBad
}
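// Behavior sketch (illustrative only):
//
//	lookup(shortMonthNames, "January 2") // 0, "uary 2", nil ("Jan" matched)
//	lookup(shortDayNames, "Xday")        // -1, "Xday", errBad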
// appendInt appends the decimal form of x to b and returns the result.
// If the decimal form (excluding sign) is shorter than width, the result is padded with leading 0's.
// Duplicates functionality in strconv, but avoids dependency.
func appendInt(b []byte, x int, width int) []byte {
u := uint(x)
if x < 0 {
b = append(b, '-')
u = uint(-x)
}
// 2-digit and 4-digit fields are the most common in time formats.
utod := func(u uint) byte { return '0' + byte(u) }
switch {
case width == 2 && u < 1e2:
return append(b, utod(u/1e1), utod(u%1e1))
case width == 4 && u < 1e4:
return append(b, utod(u/1e3), utod(u/1e2%1e1), utod(u/1e1%1e1), utod(u%1e1))
}
// Compute the number of decimal digits.
var n int
if u == 0 {
n = 1
}
for u2 := u; u2 > 0; u2 /= 10 {
n++
}
// Add 0-padding.
for pad := width - n; pad > 0; pad-- {
b = append(b, '0')
}
// Ensure capacity.
if len(b)+n <= cap(b) {
b = b[:len(b)+n]
} else {
b = append(b, make([]byte, n)...)
}
// Assemble decimal in reverse order.
i := len(b) - 1
for u >= 10 && i > 0 {
q := u / 10
b[i] = utod(u - q*10)
u = q
i--
}
b[i] = utod(u)
return b
}
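// Behavior sketch (illustrative only):
//
//	appendInt(nil, 7, 2)    // []byte("07")
//	appendInt(nil, 2024, 4) // []byte("2024")
//	appendInt(nil, -5, 2)   // []byte("-05")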
// Never printed, just needs to be non-nil for return by atoi.
var errAtoi = errors.New("time: invalid number")
// Duplicates functionality in strconv, but avoids dependency.
func atoi[bytes []byte | string](s bytes) (x int, err error) {
neg := false
if len(s) > 0 && (s[0] == '-' || s[0] == '+') {
neg = s[0] == '-'
s = s[1:]
}
q, rem, err := leadingInt(s)
x = int(q)
if err != nil || len(rem) > 0 {
return 0, errAtoi
}
if neg {
x = -x
}
return x, nil
}
// The "std" value passed to appendNano contains two packed fields: the number of
// digits after the decimal and the separator character (period or comma).
// These functions pack and unpack that variable.
func stdFracSecond(code, n, c int) int {
// Use 0xfff to bound the digit count: an absurd count stays obviously absurd
// but cannot clobber the other packed fields.
if c == '.' {
return code | ((n & 0xfff) << stdArgShift)
}
return code | ((n & 0xfff) << stdArgShift) | 1<<stdSeparatorShift
}
func digitsLen(std int) int {
return (std >> stdArgShift) & 0xfff
}
func separator(std int) byte {
if (std >> stdSeparatorShift) == 0 {
return '.'
}
return ','
}
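// Round-trip sketch of the packing (illustrative only):
//
//	std := stdFracSecond(stdFracSecond9, 3, ',')
//	digitsLen(std) // 3
//	separator(std) // ','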
// appendNano appends a fractional second, as nanoseconds, to b
// and returns the result. The nanosec must be within [0, 999999999].
func appendNano(b []byte, nanosec int, std int) []byte {
trim := std&stdMask == stdFracSecond9
n := digitsLen(std)
if trim && (n == 0 || nanosec == 0) {
return b
}
dot := separator(std)
b = append(b, dot)
b = appendInt(b, nanosec, 9)
if n < 9 {
b = b[:len(b)-9+n]
}
if trim {
for len(b) > 0 && b[len(b)-1] == '0' {
b = b[:len(b)-1]
}
if len(b) > 0 && b[len(b)-1] == dot {
b = b[:len(b)-1]
}
}
return b
}
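// For example (illustrative only), nine requested digits with trailing
// zeros trimmed:
//
//	std := stdFracSecond(stdFracSecond9, 9, '.')
//	appendNano([]byte("12:00:00"), 123000000, std) // "12:00:00.123"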
// String returns the time formatted using the format string
//
// "2006-01-02 15:04:05.999999999 -0700 MST"
//
// If the time has a monotonic clock reading, the returned string
// includes a final field "m=±<value>", where value is the monotonic
// clock reading formatted as a decimal number of seconds.
//
// The returned string is meant for debugging; for a stable serialized
// representation, use t.MarshalText, t.MarshalBinary, or t.Format
// with an explicit format string.
func (t Time) String() string {
s := t.Format("2006-01-02 15:04:05.999999999 -0700 MST")
// Format monotonic clock reading as m=±ddd.nnnnnnnnn.
if t.wall&hasMonotonic != 0 {
m2 := uint64(t.ext)
sign := byte('+')
if t.ext < 0 {
sign = '-'
m2 = -m2
}
m1, m2 := m2/1e9, m2%1e9
m0, m1 := m1/1e9, m1%1e9
buf := make([]byte, 0, 24)
buf = append(buf, " m="...)
buf = append(buf, sign)
wid := 0
if m0 != 0 {
buf = appendInt(buf, int(m0), 0)
wid = 9
}
buf = appendInt(buf, int(m1), wid)
buf = append(buf, '.')
buf = appendInt(buf, int(m2), 9)
s += string(buf)
}
return s
}
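// A sketch of typical output (the exact values vary from run to run):
//
//	fmt.Println(time.Now())
//	// e.g. 2024-06-30 12:00:00.123456789 +0200 CEST m=+0.000001234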
// GoString implements fmt.GoStringer and formats t to be printed in Go source
// code.
func (t Time) GoString() string {
abs := t.abs()
year, month, day, _ := absDate(abs, true)
hour, minute, second := absClock(abs)
buf := make([]byte, 0, len("time.Date(9999, time.September, 31, 23, 59, 59, 999999999, time.Local)"))
buf = append(buf, "time.Date("...)
buf = appendInt(buf, year, 0)
if January <= month && month <= December {
buf = append(buf, ", time."...)
buf = append(buf, longMonthNames[month-1]...)
} else {
// It's difficult to construct a time.Time with a date outside the
// standard range but we might as well try to handle the case.
buf = appendInt(buf, int(month), 0)
}
buf = append(buf, ", "...)
buf = appendInt(buf, day, 0)
buf = append(buf, ", "...)
buf = appendInt(buf, hour, 0)
buf = append(buf, ", "...)
buf = appendInt(buf, minute, 0)
buf = append(buf, ", "...)
buf = appendInt(buf, second, 0)
buf = append(buf, ", "...)
buf = appendInt(buf, t.Nanosecond(), 0)
buf = append(buf, ", "...)
switch loc := t.Location(); loc {
case UTC, nil:
buf = append(buf, "time.UTC"...)
case Local:
buf = append(buf, "time.Local"...)
default:
// there are several options for how we could display this, none of
// which are great:
//
// - use Location(loc.name), which is not technically valid syntax
// - use LoadLocation(loc.name), which will cause a syntax error when
// embedded and also would require us to escape the string without
// importing fmt or strconv
// - try to use FixedZone, which would also require escaping the name
// and would represent e.g. "America/Los_Angeles" daylight saving time
// shifts inaccurately
// - use the pointer format, which is no worse than you'd get with the
// old fmt.Sprintf("%#v", t) format.
//
// Of these, Location(loc.name) is the least disruptive. This is an edge
// case we hope not to hit too often.
buf = append(buf, `time.Location(`...)
buf = append(buf, quote(loc.name)...)
buf = append(buf, ')')
}
buf = append(buf, ')')
return string(buf)
}
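// For example (illustrative only):
//
//	t := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)
//	t.GoString() // "time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)"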
// Format returns a textual representation of the time value formatted according
// to the layout defined by the argument. See the documentation for the
// constant called Layout to see how to represent the layout format.
//
// The executable example for Time.Format demonstrates the working
// of the layout string in detail and is a good reference.
func (t Time) Format(layout string) string {
const bufSize = 64
var b []byte
max := len(layout) + 10
if max < bufSize {
var buf [bufSize]byte
b = buf[:0]
} else {
b = make([]byte, 0, max)
}
b = t.AppendFormat(b, layout)
return string(b)
}
// AppendFormat is like Format but appends the textual
// representation to b and returns the extended buffer.
func (t Time) AppendFormat(b []byte, layout string) []byte {
// Optimize for RFC3339 as it accounts for over half of all representations.
switch layout {
case RFC3339:
return t.appendFormatRFC3339(b, false)
case RFC3339Nano:
return t.appendFormatRFC3339(b, true)
default:
return t.appendFormat(b, layout)
}
}
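// A minimal usage sketch (illustrative only); appending to a caller-owned
// buffer avoids the intermediate string that Format allocates:
//
//	b := make([]byte, 0, 64)
//	b = t.AppendFormat(b, time.RFC3339)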
func (t Time) appendFormat(b []byte, layout string) []byte {
var (
name, offset, abs = t.locabs()
year int = -1
month Month
day int
yday int
hour int = -1
min int
sec int
)
// Each iteration generates one std value.
for layout != "" {
prefix, std, suffix := nextStdChunk(layout)
if prefix != "" {
b = append(b, prefix...)
}
if std == 0 {
break
}
layout = suffix
// Compute year, month, day if needed.
if year < 0 && std&stdNeedDate != 0 {
year, month, day, yday = absDate(abs, true)
yday++
}
// Compute hour, minute, second if needed.
if hour < 0 && std&stdNeedClock != 0 {
hour, min, sec = absClock(abs)
}
switch std & stdMask {
case stdYear:
y := year
if y < 0 {
y = -y
}
b = appendInt(b, y%100, 2)
case stdLongYear:
b = appendInt(b, year, 4)
case stdMonth:
b = append(b, month.String()[:3]...)
case stdLongMonth:
m := month.String()
b = append(b, m...)
case stdNumMonth:
b = appendInt(b, int(month), 0)
case stdZeroMonth:
b = appendInt(b, int(month), 2)
case stdWeekDay:
b = append(b, absWeekday(abs).String()[:3]...)
case stdLongWeekDay:
s := absWeekday(abs).String()
b = append(b, s...)
case stdDay:
b = appendInt(b, day, 0)
case stdUnderDay:
if day < 10 {
b = append(b, ' ')
}
b = appendInt(b, day, 0)
case stdZeroDay:
b = appendInt(b, day, 2)
case stdUnderYearDay:
if yday < 100 {
b = append(b, ' ')
if yday < 10 {
b = append(b, ' ')
}
}
b = appendInt(b, yday, 0)
case stdZeroYearDay:
b = appendInt(b, yday, 3)
case stdHour:
b = appendInt(b, hour, 2)
case stdHour12:
// Noon is 12PM, midnight is 12AM.
hr := hour % 12
if hr == 0 {
hr = 12
}
b = appendInt(b, hr, 0)
case stdZeroHour12:
// Noon is 12PM, midnight is 12AM.
hr := hour % 12
if hr == 0 {
hr = 12
}
b = appendInt(b, hr, 2)
case stdMinute:
b = appendInt(b, min, 0)
case stdZeroMinute:
b = appendInt(b, min, 2)
case stdSecond:
b = appendInt(b, sec, 0)
case stdZeroSecond:
b = appendInt(b, sec, 2)
case stdPM:
if hour >= 12 {
b = append(b, "PM"...)
} else {
b = append(b, "AM"...)
}
case stdpm:
if hour >= 12 {
b = append(b, "pm"...)
} else {
b = append(b, "am"...)
}
case stdISO8601TZ, stdISO8601ColonTZ, stdISO8601SecondsTZ, stdISO8601ShortTZ, stdISO8601ColonSecondsTZ, stdNumTZ, stdNumColonTZ, stdNumSecondsTz, stdNumShortTZ, stdNumColonSecondsTZ:
// Ugly special case. We cheat and take the "Z" variants
// to mean "the time zone as formatted for ISO 8601".
if offset == 0 && (std == stdISO8601TZ || std == stdISO8601ColonTZ || std == stdISO8601SecondsTZ || std == stdISO8601ShortTZ || std == stdISO8601ColonSecondsTZ) {
b = append(b, 'Z')
break
}
zone := offset / 60 // convert to minutes
absoffset := offset
if zone < 0 {
b = append(b, '-')
zone = -zone
absoffset = -absoffset
} else {
b = append(b, '+')
}
b = appendInt(b, zone/60, 2)
if std == stdISO8601ColonTZ || std == stdNumColonTZ || std == stdISO8601ColonSecondsTZ || std == stdNumColonSecondsTZ {
b = append(b, ':')
}
if std != stdNumShortTZ && std != stdISO8601ShortTZ {
b = appendInt(b, zone%60, 2)
}
// append seconds if appropriate
if std == stdISO8601SecondsTZ || std == stdNumSecondsTz || std == stdNumColonSecondsTZ || std == stdISO8601ColonSecondsTZ {
if std == stdNumColonSecondsTZ || std == stdISO8601ColonSecondsTZ {
b = append(b, ':')
}
b = appendInt(b, absoffset%60, 2)
}
case stdTZ:
if name != "" {
b = append(b, name...)
break
}
// No time zone known for this time, but we must print one.
// Use the -0700 format.
zone := offset / 60 // convert to minutes
if zone < 0 {
b = append(b, '-')
zone = -zone
} else {
b = append(b, '+')
}
b = appendInt(b, zone/60, 2)
b = appendInt(b, zone%60, 2)
case stdFracSecond0, stdFracSecond9:
b = appendNano(b, t.Nanosecond(), std)
}
}
return b
}
var errBad = errors.New("bad value for field") // placeholder not passed to user
// ParseError describes a problem parsing a time string.
type ParseError struct {
Layout string
Value string
LayoutElem string
ValueElem string
Message string
}
// newParseError creates a new ParseError.
// The provided value and valueElem are cloned to avoid escaping their values.
func newParseError(layout, value, layoutElem, valueElem, message string) *ParseError {
valueCopy := cloneString(value)
valueElemCopy := cloneString(valueElem)
return &ParseError{layout, valueCopy, layoutElem, valueElemCopy, message}
}
// cloneString returns a string copy of s.
// It does not use strings.Clone, to avoid a dependency on the strings package.
func cloneString(s string) string {
return string([]byte(s))
}
// These are borrowed from unicode/utf8 and strconv and replicate behavior in
// those packages, since we can't take a dependency on either.
const (
lowerhex = "0123456789abcdef"
runeSelf = 0x80
runeError = '\uFFFD'
)
func quote(s string) string {
buf := make([]byte, 1, len(s)+2) // slice will be at least len(s) + quotes
buf[0] = '"'
for i, c := range s {
if c >= runeSelf || c < ' ' {
// This means you are asking us to parse a time.Duration or
// time.Location with unprintable or non-ASCII characters in it.
// We don't expect to hit this case very often. We could try to
// reproduce strconv.Quote's behavior with full fidelity but
// given how rarely we expect to hit these edge cases, speed and
// conciseness are better.
var width int
if c == runeError {
width = 1
if i+2 < len(s) && s[i:i+3] == string(runeError) {
width = 3
}
} else {
width = len(string(c))
}
for j := 0; j < width; j++ {
buf = append(buf, `\x`...)
buf = append(buf, lowerhex[s[i+j]>>4])
buf = append(buf, lowerhex[s[i+j]&0xF])
}
} else {
if c == '"' || c == '\\' {
buf = append(buf, '\\')
}
buf = append(buf, string(c)...)
}
}
buf = append(buf, '"')
return string(buf)
}
// Error returns the string representation of a ParseError.
func (e *ParseError) Error() string {
if e.Message == "" {
return "parsing time " +
quote(e.Value) + " as " +
quote(e.Layout) + ": cannot parse " +
quote(e.ValueElem) + " as " +
quote(e.LayoutElem)
}
return "parsing time " +
quote(e.Value) + e.Message
}
// isDigit reports whether s[i] is in range and is a decimal digit.
func isDigit[bytes []byte | string](s bytes, i int) bool {
if len(s) <= i {
return false
}
c := s[i]
return '0' <= c && c <= '9'
}
// getnum parses s[0:1] or s[0:2] (fixed forces s[0:2])
// as a decimal integer and returns the integer and the
// remainder of the string.
func getnum(s string, fixed bool) (int, string, error) {
if !isDigit(s, 0) {
return 0, s, errBad
}
if !isDigit(s, 1) {
if fixed {
return 0, s, errBad
}
return int(s[0] - '0'), s[1:], nil
}
return int(s[0]-'0')*10 + int(s[1]-'0'), s[2:], nil
}
// getnum3 parses s[0:1], s[0:2], or s[0:3] (fixed forces s[0:3])
// as a decimal integer and returns the integer and the remainder
// of the string.
func getnum3(s string, fixed bool) (int, string, error) {
var n, i int
for i = 0; i < 3 && isDigit(s, i); i++ {
n = n*10 + int(s[i]-'0')
}
if i == 0 || fixed && i != 3 {
return 0, s, errBad
}
return n, s[i:], nil
}
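// Behavior sketch (illustrative only):
//
//	getnum("15:04", false) // 15, ":04", nil
//	getnum3("002", true)   // 2, "", nil
//	getnum3("7xyz", false) // 7, "xyz", nil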
func cutspace(s string) string {
for len(s) > 0 && s[0] == ' ' {
s = s[1:]
}
return s
}
// skip removes the given prefix from value,
// treating runs of space characters as equivalent.
func skip(value, prefix string) (string, error) {
for len(prefix) > 0 {
if prefix[0] == ' ' {
if len(value) > 0 && value[0] != ' ' {
return value, errBad
}
prefix = cutspace(prefix)
value = cutspace(value)
continue
}
if len(value) == 0 || value[0] != prefix[0] {
return value, errBad
}
prefix = prefix[1:]
value = value[1:]
}
return value, nil
}
// Parse parses a formatted string and returns the time value it represents.
// See the documentation for the constant called Layout to see how to
// represent the format. The second argument must be parseable using
// the format string (layout) provided as the first argument.
//
// The example for Time.Format demonstrates the working of the layout string
// in detail and is a good reference.
//
// When parsing (only), the input may contain a fractional second
// field immediately after the seconds field, even if the layout does not
// signify its presence. In that case either a comma or a decimal point
// followed by a maximal series of digits is parsed as a fractional second.
// Fractional seconds are truncated to nanosecond precision.
//
// Elements omitted from the layout are assumed to be zero or, when
// zero is impossible, one, so parsing "3:04pm" returns the time
// corresponding to Jan 1, year 0, 15:04:00 UTC (note that because the year is
// 0, this time is before the zero Time).
// Years must be in the range 0000..9999. The day of the week is checked
// for syntax but it is otherwise ignored.
//
// For layouts specifying the two-digit year 06, a value NN >= 69 will be treated
// as 19NN and a value NN < 69 will be treated as 20NN.
//
// The remainder of this comment describes the handling of time zones.
//
// In the absence of a time zone indicator, Parse returns a time in UTC.
//
// When parsing a time with a zone offset like -0700, if the offset corresponds
// to a time zone used by the current location (Local), then Parse uses that
// location and zone in the returned time. Otherwise it records the time as
// being in a fabricated location with time fixed at the given zone offset.
//
// When parsing a time with a zone abbreviation like MST, if the zone abbreviation
// has a defined offset in the current location, then that offset is used.
// The zone abbreviation "UTC" is recognized as UTC regardless of location.
// If the zone abbreviation is unknown, Parse records the time as being
// in a fabricated location with the given zone abbreviation and a zero offset.
// This choice means that such a time can be parsed and reformatted with the
// same layout losslessly, but the exact instant used in the representation will
// differ by the actual zone offset. To avoid such problems, prefer time layouts
// that use a numeric zone offset, or use ParseInLocation.
func Parse(layout, value string) (Time, error) {
// Optimize for RFC3339 as it accounts for over half of all representations.
if layout == RFC3339 || layout == RFC3339Nano {
if t, ok := parseRFC3339(value, Local); ok {
return t, nil
}
}
return parse(layout, value, UTC, Local)
}
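// A sketch of the zone handling described above (the result depends on the
// local time zone database):
//
//	t, _ := time.Parse(time.RFC822, "02 Jan 06 15:04 MST")
//	// If the local zone does not define MST, t carries a fabricated
//	// location named "MST" with a zero offset.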
// ParseInLocation is like Parse but differs in two important ways.
// First, in the absence of time zone information, Parse interprets a time as UTC;
// ParseInLocation interprets the time as in the given location.
// Second, when given a zone offset or abbreviation, Parse tries to match it
// against the Local location; ParseInLocation uses the given location.
func ParseInLocation(layout, value string, loc *Location) (Time, error) {
// Optimize for RFC3339 as it accounts for over half of all representations.
if layout == RFC3339 || layout == RFC3339Nano {
if t, ok := parseRFC3339(value, loc); ok {
return t, nil
}
}
return parse(layout, value, loc, loc)
}
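// A minimal sketch (the IANA zone name is only an example):
//
//	loc, err := time.LoadLocation("Europe/Berlin")
//	if err == nil {
//		t, _ := time.ParseInLocation(time.DateTime, "2024-06-30 12:00:00", loc)
//		// t is interpreted in Europe/Berlin rather than in UTC.
//	}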
func parse(layout, value string, defaultLocation, local *Location) (Time, error) {
alayout, avalue := layout, value
rangeErrString := "" // set if a value is out of range
amSet := false // do we need to subtract 12 from the hour for midnight?
pmSet := false // do we need to add 12 to the hour?
// Time being constructed.
var (
year int
month int = -1
day int = -1
yday int = -1
hour int
min int
sec int
nsec int
z *Location
zoneOffset int = -1
zoneName string
)
// Each iteration processes one std value.
for {
var err error
prefix, std, suffix := nextStdChunk(layout)
stdstr := layout[len(prefix) : len(layout)-len(suffix)]
value, err = skip(value, prefix)
if err != nil {
return Time{}, newParseError(alayout, avalue, prefix, value, "")
}
if std == 0 {
if len(value) != 0 {
return Time{}, newParseError(alayout, avalue, "", value, ": extra text: "+quote(value))
}
break
}
layout = suffix
var p string
hold := value
switch std & stdMask {
case stdYear:
if len(value) < 2 {
err = errBad
break
}
p, value = value[0:2], value[2:]
year, err = atoi(p)
if err != nil {
break
}
if year >= 69 { // Unix time starts Dec 31 1969 in some time zones
year += 1900
} else {
year += 2000
}
case stdLongYear:
if len(value) < 4 || !isDigit(value, 0) {
err = errBad
break
}
p, value = value[0:4], value[4:]
year, err = atoi(p)
case stdMonth:
month, value, err = lookup(shortMonthNames, value)
month++
case stdLongMonth:
month, value, err = lookup(longMonthNames, value)
month++
case stdNumMonth, stdZeroMonth:
month, value, err = getnum(value, std == stdZeroMonth)
if err == nil && (month <= 0 || 12 < month) {
rangeErrString = "month"
}
case stdWeekDay:
// Ignore weekday except for error checking.
_, value, err = lookup(shortDayNames, value)
case stdLongWeekDay:
_, value, err = lookup(longDayNames, value)
case stdDay, stdUnderDay, stdZeroDay:
if std == stdUnderDay && len(value) > 0 && value[0] == ' ' {
value = value[1:]
}
day, value, err = getnum(value, std == stdZeroDay)
// Note that we allow any one- or two-digit day here.
// The month, day, year combination is validated after we've completed parsing.
case stdUnderYearDay, stdZeroYearDay:
for i := 0; i < 2; i++ {
if std == stdUnderYearDay && len(value) > 0 && value[0] == ' ' {
value = value[1:]
}
}
yday, value, err = getnum3(value, std == stdZeroYearDay)
// Note that we allow any one-, two-, or three-digit year-day here.
// The year-day, year combination is validated after we've completed parsing.
case stdHour:
hour, value, err = getnum(value, false)
if hour < 0 || 24 <= hour {
rangeErrString = "hour"
}
case stdHour12, stdZeroHour12:
hour, value, err = getnum(value, std == stdZeroHour12)
if hour < 0 || 12 < hour {
rangeErrString = "hour"
}
case stdMinute, stdZeroMinute:
min, value, err = getnum(value, std == stdZeroMinute)
if min < 0 || 60 <= min {
rangeErrString = "minute"
}
case stdSecond, stdZeroSecond:
sec, value, err = getnum(value, std == stdZeroSecond)
if err != nil {
break
}
if sec < 0 || 60 <= sec {
rangeErrString = "second"
break
}
// Special case: do we have a fractional second but no
// fractional second in the format?
if len(value) >= 2 && commaOrPeriod(value[0]) && isDigit(value, 1) {
_, std, _ = nextStdChunk(layout)
std &= stdMask
if std == stdFracSecond0 || std == stdFracSecond9 {
// Fractional second in the layout; proceed normally
break
}
// No fractional second in the layout but we have one in the input.
n := 2
for ; n < len(value) && isDigit(value, n); n++ {
}
nsec, rangeErrString, err = parseNanoseconds(value, n)
value = value[n:]
}
case stdPM:
if len(value) < 2 {
err = errBad
break
}
p, value = value[0:2], value[2:]
switch p {
case "PM":
pmSet = true
case "AM":
amSet = true
default:
err = errBad
}
case stdpm:
if len(value) < 2 {
err = errBad
break
}
p, value = value[0:2], value[2:]
switch p {
case "pm":
pmSet = true
case "am":
amSet = true
default:
err = errBad
}
case stdISO8601TZ, stdISO8601ColonTZ, stdISO8601SecondsTZ, stdISO8601ShortTZ, stdISO8601ColonSecondsTZ, stdNumTZ, stdNumShortTZ, stdNumColonTZ, stdNumSecondsTz, stdNumColonSecondsTZ:
if (std == stdISO8601TZ || std == stdISO8601ShortTZ || std == stdISO8601ColonTZ) && len(value) >= 1 && value[0] == 'Z' {
value = value[1:]
z = UTC
break
}
var sign, hour, min, seconds string
if std == stdISO8601ColonTZ || std == stdNumColonTZ {
if len(value) < 6 {
err = errBad
break
}
if value[3] != ':' {
err = errBad
break
}
sign, hour, min, seconds, value = value[0:1], value[1:3], value[4:6], "00", value[6:]
} else if std == stdNumShortTZ || std == stdISO8601ShortTZ {
if len(value) < 3 {
err = errBad
break
}
sign, hour, min, seconds, value = value[0:1], value[1:3], "00", "00", value[3:]
} else if std == stdISO8601ColonSecondsTZ || std == stdNumColonSecondsTZ {
if len(value) < 9 {
err = errBad
break
}
if value[3] != ':' || value[6] != ':' {
err = errBad
break
}
sign, hour, min, seconds, value = value[0:1], value[1:3], value[4:6], value[7:9], value[9:]
} else if std == stdISO8601SecondsTZ || std == stdNumSecondsTz {
if len(value) < 7 {
err = errBad
break
}
sign, hour, min, seconds, value = value[0:1], value[1:3], value[3:5], value[5:7], value[7:]
} else {
if len(value) < 5 {
err = errBad
break
}
sign, hour, min, seconds, value = value[0:1], value[1:3], value[3:5], "00", value[5:]
}
var hr, mm, ss int
hr, _, err = getnum(hour, true)
if err == nil {
mm, _, err = getnum(min, true)
}
if err == nil {
ss, _, err = getnum(seconds, true)
}
zoneOffset = (hr*60+mm)*60 + ss // offset is in seconds
switch sign[0] {
case '+':
case '-':
zoneOffset = -zoneOffset
default:
err = errBad
}
case stdTZ:
// Does it look like a time zone?
if len(value) >= 3 && value[0:3] == "UTC" {
z = UTC
value = value[3:]
break
}
n, ok := parseTimeZone(value)
if !ok {
err = errBad
break
}
zoneName, value = value[:n], value[n:]
case stdFracSecond0:
// stdFracSecond0 requires the exact number of digits as specified in
// the layout.
ndigit := 1 + digitsLen(std)
if len(value) < ndigit {
err = errBad
break
}
nsec, rangeErrString, err = parseNanoseconds(value, ndigit)
value = value[ndigit:]
case stdFracSecond9:
if len(value) < 2 || !commaOrPeriod(value[0]) || value[1] < '0' || '9' < value[1] {
// Fractional second omitted.
break
}
// Take any number of digits, even more than asked for,
// because it is what the stdSecond case would do.
i := 0
for i+1 < len(value) && '0' <= value[i+1] && value[i+1] <= '9' {
i++
}
nsec, rangeErrString, err = parseNanoseconds(value, 1+i)
value = value[1+i:]
}
if rangeErrString != "" {
return Time{}, newParseError(alayout, avalue, stdstr, value, ": "+rangeErrString+" out of range")
}
if err != nil {
return Time{}, newParseError(alayout, avalue, stdstr, hold, "")
}
}
if pmSet && hour < 12 {
hour += 12
} else if amSet && hour == 12 {
hour = 0
}
// Convert yday to day, month.
if yday >= 0 {
var d int
var m int
if isLeap(year) {
if yday == 31+29 {
m = int(February)
d = 29
} else if yday > 31+29 {
yday--
}
}
if yday < 1 || yday > 365 {
return Time{}, newParseError(alayout, avalue, "", value, ": day-of-year out of range")
}
if m == 0 {
m = (yday-1)/31 + 1
if int(daysBefore[m]) < yday {
m++
}
d = yday - int(daysBefore[m-1])
}
// If month, day already seen, yday's m, d must match.
// Otherwise, set them from m, d.
if month >= 0 && month != m {
return Time{}, newParseError(alayout, avalue, "", value, ": day-of-year does not match month")
}
month = m
if day >= 0 && day != d {
return Time{}, newParseError(alayout, avalue, "", value, ": day-of-year does not match day")
}
day = d
} else {
if month < 0 {
month = int(January)
}
if day < 0 {
day = 1
}
}
// Validate the day of the month.
if day < 1 || day > daysIn(Month(month), year) {
return Time{}, newParseError(alayout, avalue, "", value, ": day out of range")
}
if z != nil {
return Date(year, Month(month), day, hour, min, sec, nsec, z), nil
}
if zoneOffset != -1 {
t := Date(year, Month(month), day, hour, min, sec, nsec, UTC)
t.addSec(-int64(zoneOffset))
// Look for local zone with the given offset.
// If that zone was in effect at the given time, use it.
name, offset, _, _, _ := local.lookup(t.unixSec())
if offset == zoneOffset && (zoneName == "" || name == zoneName) {
t.setLoc(local)
return t, nil
}
// Otherwise create fake zone to record offset.
zoneNameCopy := cloneString(zoneName) // avoid leaking the input value
t.setLoc(FixedZone(zoneNameCopy, zoneOffset))
return t, nil
}
if zoneName != "" {
t := Date(year, Month(month), day, hour, min, sec, nsec, UTC)
// Look for local zone with the given offset.
// If that zone was in effect at the given time, use it.
offset, ok := local.lookupName(zoneName, t.unixSec())
if ok {
t.addSec(-int64(offset))
t.setLoc(local)
return t, nil
}
// Otherwise, create fake zone with unknown offset.
if len(zoneName) > 3 && zoneName[:3] == "GMT" {
offset, _ = atoi(zoneName[3:]) // Guaranteed OK by parseGMT.
offset *= 3600
}
zoneNameCopy := cloneString(zoneName) // avoid leaking the input value
t.setLoc(FixedZone(zoneNameCopy, offset))
return t, nil
}
// Otherwise, fall back to default.
return Date(year, Month(month), day, hour, min, sec, nsec, defaultLocation), nil
}
// parseTimeZone parses a time zone string and returns its length. Time zones
// are human-generated and unpredictable. We can't do precise error checking.
// On the other hand, for a correct parse there must be a time zone at the
// beginning of the string, so it's almost always true that there's one
// there. We look at the beginning of the string for a run of upper-case letters.
// If there are more than 5, it's an error.
// If there are 4 or 5 and the last is a T, it's a time zone.
// If there are 3, it's a time zone.
// Otherwise, other than special cases, it's not a time zone.
// GMT is special because it can have an hour offset.
func parseTimeZone(value string) (length int, ok bool) {
if len(value) < 3 {
return 0, false
}
// Special case 1: ChST and MeST are the only zones with a lower-case letter.
if len(value) >= 4 && (value[:4] == "ChST" || value[:4] == "MeST") {
return 4, true
}
// Special case 2: GMT may have an hour offset; treat it specially.
if value[:3] == "GMT" {
length = parseGMT(value)
return length, true
}
// Special case 3: some time zones are not named, but have the +/-00 format.
if value[0] == '+' || value[0] == '-' {
length = parseSignedOffset(value)
ok := length > 0 // parseSignedOffset returns 0 in case of bad input
return length, ok
}
// How many upper-case letters are there? Need at least three, at most five.
var nUpper int
for nUpper = 0; nUpper < 6; nUpper++ {
if nUpper >= len(value) {
break
}
if c := value[nUpper]; c < 'A' || 'Z' < c {
break
}
}
switch nUpper {
case 0, 1, 2, 6:
return 0, false
case 5: // Must end in T to match.
if value[4] == 'T' {
return 5, true
}
case 4:
// Must end in T, except one special case.
if value[3] == 'T' || value[:4] == "WITA" {
return 4, true
}
case 3:
return 3, true
}
return 0, false
}
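// Behavior sketch (illustrative only):
//
//	parseTimeZone("PST is fun")  // 3, true
//	parseTimeZone("ChST is fun") // 4, true
//	parseTimeZone("GMT-5 only")  // 5, true (GMT plus the signed offset "-5")
//	parseTimeZone("xyzzy")       // 0, false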
// parseGMT parses a GMT time zone. The input string is known to start "GMT".
// The function checks whether that is followed by a sign and a number in the
// range -23 through +23 excluding zero.
func parseGMT(value string) int {
value = value[3:]
if len(value) == 0 {
return 3
}
return 3 + parseSignedOffset(value)
}
// parseSignedOffset parses a signed timezone offset (e.g. "+03" or "-04").
// The function checks for a signed number in the range -23 through +23 excluding zero.
// Returns length of the found offset string or 0 otherwise.
func parseSignedOffset(value string) int {
sign := value[0]
if sign != '-' && sign != '+' {
return 0
}
x, rem, err := leadingInt(value[1:])
// fail if nothing consumed by leadingInt
if err != nil || value[1:] == rem {
return 0
}
if x > 23 {
return 0
}
return len(value) - len(rem)
}
func commaOrPeriod(b byte) bool {
return b == '.' || b == ','
}
func parseNanoseconds[bytes []byte | string](value bytes, nbytes int) (ns int, rangeErrString string, err error) {
if !commaOrPeriod(value[0]) {
err = errBad
return
}
if nbytes > 10 {
value = value[:10]
nbytes = 10
}
if ns, err = atoi(value[1:nbytes]); err != nil {
return
}
if ns < 0 {
rangeErrString = "fractional second"
return
}
// We need nanoseconds, which means scaling by the number
// of missing digits in the format, maximum length 10.
scaleDigits := 10 - nbytes
for i := 0; i < scaleDigits; i++ {
ns *= 10
}
return
}
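// For example (illustrative only):
//
//	parseNanoseconds(".123", 4) // 123000000, "", nil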
var errLeadingInt = errors.New("time: bad [0-9]*") // never printed
// leadingInt consumes the leading [0-9]* from s.
func leadingInt[bytes []byte | string](s bytes) (x uint64, rem bytes, err error) {
i := 0
for ; i < len(s); i++ {
c := s[i]
if c < '0' || c > '9' {
break
}
if x > 1<<63/10 {
// overflow
return 0, rem, errLeadingInt
}
x = x*10 + uint64(c) - '0'
if x > 1<<63 {
// overflow
return 0, rem, errLeadingInt
}
}
return x, s[i:], nil
}
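// For example (illustrative only):
//
//	leadingInt("123abc") // 123, "abc", nil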
// leadingFraction consumes the leading [0-9]* from s.
// It is used only for fractions, so does not return an error on overflow,
// it just stops accumulating precision.
func leadingFraction(s string) (x uint64, scale float64, rem string) {
i := 0
scale = 1
overflow := false
for ; i < len(s); i++ {
c := s[i]
if c < '0' || c > '9' {
break
}
if overflow {
continue
}
if x > (1<<63-1)/10 {
// It's possible for overflow to give a positive number, so take care.
overflow = true
continue
}
y := x*10 + uint64(c) - '0'
if y > 1<<63 {
overflow = true
continue
}
x = y
scale *= 10
}
return x, scale, s[i:]
}
var unitMap = map[string]uint64{
"ns": uint64(Nanosecond),
"us": uint64(Microsecond),
"µs": uint64(Microsecond), // U+00B5 = micro symbol
"μs": uint64(Microsecond), // U+03BC = Greek letter mu
"ms": uint64(Millisecond),
"s": uint64(Second),
"m": uint64(Minute),
"h": uint64(Hour),
}
// ParseDuration parses a duration string.
// A duration string is a possibly signed sequence of
// decimal numbers, each with optional fraction and a unit suffix,
// such as "300ms", "-1.5h" or "2h45m".
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
func ParseDuration(s string) (Duration, error) {
// [-+]?([0-9]*(\.[0-9]*)?[a-z]+)+
orig := s
var d uint64
neg := false
// Consume [-+]?
if s != "" {
c := s[0]
if c == '-' || c == '+' {
neg = c == '-'
s = s[1:]
}
}
// Special case: if all that is left is "0", this is zero.
if s == "0" {
return 0, nil
}
if s == "" {
return 0, errors.New("time: invalid duration " + quote(orig))
}
for s != "" {
var (
v, f uint64 // integers before, after decimal point
scale float64 = 1 // value = v + f/scale
)
var err error
// The next character must be [0-9.]
if !(s[0] == '.' || '0' <= s[0] && s[0] <= '9') {
return 0, errors.New("time: invalid duration " + quote(orig))
}
// Consume [0-9]*
pl := len(s)
v, s, err = leadingInt(s)
if err != nil {
return 0, errors.New("time: invalid duration " + quote(orig))
}
pre := pl != len(s) // whether we consumed anything before a period
// Consume (\.[0-9]*)?
post := false
if s != "" && s[0] == '.' {
s = s[1:]
pl := len(s)
f, scale, s = leadingFraction(s)
post = pl != len(s)
}
if !pre && !post {
// no digits (e.g. ".s" or "-.s")
return 0, errors.New("time: invalid duration " + quote(orig))
}
// Consume unit.
i := 0
for ; i < len(s); i++ {
c := s[i]
if c == '.' || '0' <= c && c <= '9' {
break
}
}
if i == 0 {
return 0, errors.New("time: missing unit in duration " + quote(orig))
}
u := s[:i]
s = s[i:]
unit, ok := unitMap[u]
if !ok {
return 0, errors.New("time: unknown unit " + quote(u) + " in duration " + quote(orig))
}
if v > 1<<63/unit {
// overflow
return 0, errors.New("time: invalid duration " + quote(orig))
}
v *= unit
if f > 0 {
// float64 is needed to be nanosecond accurate for fractions of hours.
// v >= 0 && (f*unit/scale) <= 3.6e+12 (ns/h, h is the largest unit)
v += uint64(float64(f) * (float64(unit) / scale))
if v > 1<<63 {
// overflow
return 0, errors.New("time: invalid duration " + quote(orig))
}
}
d += v
if d > 1<<63 {
return 0, errors.New("time: invalid duration " + quote(orig))
}
}
if neg {
return -Duration(d), nil
}
if d > 1<<63-1 {
return 0, errors.New("time: invalid duration " + quote(orig))
}
return Duration(d), nil
}
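// A usage sketch (illustrative only):
//
//	d, _ := time.ParseDuration("1h30m") // d == 90*time.Minute
//	d, _ = time.ParseDuration("-1.5h")  // d == -90*time.Minute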
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package time
import "errors"
// RFC 3339 is the most commonly used format.
//
// It is implicitly used by the Time.(Marshal|Unmarshal)(Text|JSON) methods.
// Also, according to analysis on https://go.dev/issue/52746,
// RFC 3339 accounts for 57% of all explicitly specified time formats,
// with the second most popular format only being used 8% of the time.
// The overwhelming use of RFC 3339 compared to all other formats justifies
// the addition of logic to optimize formatting and parsing.
func (t Time) appendFormatRFC3339(b []byte, nanos bool) []byte {
_, offset, abs := t.locabs()
// Format date.
year, month, day, _ := absDate(abs, true)
b = appendInt(b, year, 4)
b = append(b, '-')
b = appendInt(b, int(month), 2)
b = append(b, '-')
b = appendInt(b, day, 2)
b = append(b, 'T')
// Format time.
hour, min, sec := absClock(abs)
b = appendInt(b, hour, 2)
b = append(b, ':')
b = appendInt(b, min, 2)
b = append(b, ':')
b = appendInt(b, sec, 2)
if nanos {
std := stdFracSecond(stdFracSecond9, 9, '.')
b = appendNano(b, t.Nanosecond(), std)
}
if offset == 0 {
return append(b, 'Z')
}
// Format zone.
zone := offset / 60 // convert to minutes
if zone < 0 {
b = append(b, '-')
zone = -zone
} else {
b = append(b, '+')
}
b = appendInt(b, zone/60, 2)
b = append(b, ':')
b = appendInt(b, zone%60, 2)
return b
}
func (t Time) appendStrictRFC3339(b []byte) ([]byte, error) {
n0 := len(b)
b = t.appendFormatRFC3339(b, true)
// Not all valid Go timestamps can be serialized as valid RFC 3339.
// Explicitly check for these edge cases.
// See https://go.dev/issue/4556 and https://go.dev/issue/54580.
num2 := func(b []byte) byte { return 10*(b[0]-'0') + (b[1] - '0') }
switch {
case b[n0+len("9999")] != '-': // year must be exactly 4 digits wide
return b, errors.New("year outside of range [0,9999]")
case b[len(b)-1] != 'Z':
c := b[len(b)-len("Z07:00")]
if ('0' <= c && c <= '9') || num2(b[len(b)-len("07:00"):]) >= 24 {
return b, errors.New("timezone hour outside of range [0,23]")
}
}
return b, nil
}
func parseRFC3339[bytes []byte | string](s bytes, local *Location) (Time, bool) {
// parseUint parses s as an unsigned decimal integer and
// verifies that it is within some range.
// If it is invalid or out-of-range,
// it sets ok to false and returns the min value.
ok := true
parseUint := func(s bytes, min, max int) (x int) {
for _, c := range []byte(s) {
if c < '0' || '9' < c {
ok = false
return min
}
x = x*10 + int(c) - '0'
}
if x < min || max < x {
ok = false
return min
}
return x
}
// Parse the date and time.
if len(s) < len("2006-01-02T15:04:05") {
return Time{}, false
}
year := parseUint(s[0:4], 0, 9999) // e.g., 2006
month := parseUint(s[5:7], 1, 12) // e.g., 01
day := parseUint(s[8:10], 1, daysIn(Month(month), year)) // e.g., 02
hour := parseUint(s[11:13], 0, 23) // e.g., 15
min := parseUint(s[14:16], 0, 59) // e.g., 04
sec := parseUint(s[17:19], 0, 59) // e.g., 05
if !ok || !(s[4] == '-' && s[7] == '-' && s[10] == 'T' && s[13] == ':' && s[16] == ':') {
return Time{}, false
}
s = s[19:]
// Parse the fractional second.
var nsec int
if len(s) >= 2 && s[0] == '.' && isDigit(s, 1) {
n := 2
for ; n < len(s) && isDigit(s, n); n++ {
}
nsec, _, _ = parseNanoseconds(s, n)
s = s[n:]
}
// Parse the time zone.
t := Date(year, Month(month), day, hour, min, sec, nsec, UTC)
if len(s) != 1 || s[0] != 'Z' {
if len(s) != len("-07:00") {
return Time{}, false
}
hr := parseUint(s[1:3], 0, 23) // e.g., 07
mm := parseUint(s[4:6], 0, 59) // e.g., 00
if !ok || !((s[0] == '-' || s[0] == '+') && s[3] == ':') {
return Time{}, false
}
zoneOffset := (hr*60 + mm) * 60
if s[0] == '-' {
zoneOffset *= -1
}
t.addSec(-int64(zoneOffset))
// Use local zone with the given offset if possible.
if _, offset, _, _, _ := local.lookup(t.unixSec()); offset == zoneOffset {
t.setLoc(local)
} else {
t.setLoc(FixedZone("", zoneOffset))
}
}
return t, true
}
func parseStrictRFC3339(b []byte) (Time, error) {
t, ok := parseRFC3339(b, Local)
if !ok {
t, err := Parse(RFC3339, string(b))
if err != nil {
return Time{}, err
}
// The parse template syntax cannot correctly validate RFC 3339.
// Explicitly check for cases that Parse is unable to validate for.
// See https://go.dev/issue/54580.
num2 := func(b []byte) byte { return 10*(b[0]-'0') + (b[1] - '0') }
switch {
// TODO(https://go.dev/issue/54580): Strict parsing is disabled for now.
// Enable this again with a GODEBUG opt-out.
case true:
return t, nil
case b[len("2006-01-02T")+1] == ':': // hour must be two digits
return Time{}, &ParseError{RFC3339, string(b), "15", string(b[len("2006-01-02T"):][:1]), ""}
case b[len("2006-01-02T15:04:05")] == ',': // sub-second separator must be a period
return Time{}, &ParseError{RFC3339, string(b), ".", ",", ""}
case b[len(b)-1] != 'Z':
switch {
case num2(b[len(b)-len("07:00"):]) >= 24: // timezone hour must be in range
return Time{}, &ParseError{RFC3339, string(b), "Z07:00", string(b[len(b)-len("Z07:00"):]), ": timezone hour out of range"}
case num2(b[len(b)-len("00"):]) >= 60: // timezone minute must be in range
return Time{}, &ParseError{RFC3339, string(b), "Z07:00", string(b[len(b)-len("Z07:00"):]), ": timezone minute out of range"}
}
default: // unknown error; should not occur
return Time{}, &ParseError{RFC3339, string(b), RFC3339, string(b), ""}
}
}
return t, nil
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package time
// Sleep pauses the current goroutine for at least the duration d.
// A negative or zero duration causes Sleep to return immediately.
func Sleep(d Duration)
// Interface to timers implemented in package runtime.
// Must be in sync with ../runtime/time.go:/^type timer
type runtimeTimer struct {
pp uintptr
when int64
period int64
f func(any, uintptr) // NOTE: must not be closure
arg any
seq uintptr
nextwhen int64
status uint32
}
// when is a helper function for setting the 'when' field of a runtimeTimer.
// It returns the runtime clock reading, in nanoseconds, Duration d in the future.
// If d is negative, it is ignored. If the returned value would be less than
// zero because of an overflow, MaxInt64 is returned.
func when(d Duration) int64 {
if d <= 0 {
return runtimeNano()
}
t := runtimeNano() + int64(d)
if t < 0 {
// N.B. runtimeNano() and d are always positive, so addition
// (including overflow) will never result in t == 0.
t = 1<<63 - 1 // math.MaxInt64
}
return t
}
func startTimer(*runtimeTimer)
func stopTimer(*runtimeTimer) bool
func resetTimer(*runtimeTimer, int64) bool
func modTimer(t *runtimeTimer, when, period int64, f func(any, uintptr), arg any, seq uintptr)
// The Timer type represents a single event.
// When the Timer expires, the current time will be sent on C,
// unless the Timer was created by AfterFunc.
// A Timer must be created with NewTimer or AfterFunc.
type Timer struct {
C <-chan Time
r runtimeTimer
}
// Stop prevents the Timer from firing.
// It returns true if the call stops the timer, false if the timer has already
// expired or been stopped.
// Stop does not close the channel, to prevent a read from the channel succeeding
// incorrectly.
//
// To ensure the channel is empty after a call to Stop, check the
// return value and drain the channel.
// For example, assuming the program has not received from t.C already:
//
// if !t.Stop() {
// <-t.C
// }
//
// This cannot be done concurrent to other receives from the Timer's
// channel or other calls to the Timer's Stop method.
//
// For a timer created with AfterFunc(d, f), if t.Stop returns false, then the timer
// has already expired and the function f has been started in its own goroutine;
// Stop does not wait for f to complete before returning.
// If the caller needs to know whether f is completed, it must coordinate
// with f explicitly.
func (t *Timer) Stop() bool {
if t.r.f == nil {
panic("time: Stop called on uninitialized Timer")
}
return stopTimer(&t.r)
}
// NewTimer creates a new Timer that will send
// the current time on its channel after at least duration d.
func NewTimer(d Duration) *Timer {
c := make(chan Time, 1)
t := &Timer{
C: c,
r: runtimeTimer{
when: when(d),
f: sendTime,
arg: c,
},
}
startTimer(&t.r)
return t
}
// Reset changes the timer to expire after duration d.
// It returns true if the timer had been active, false if the timer had
// expired or been stopped.
//
// For a Timer created with NewTimer, Reset should be invoked only on
// stopped or expired timers with drained channels.
//
// If a program has already received a value from t.C, the timer is known
// to have expired and the channel drained, so t.Reset can be used directly.
// If a program has not yet received a value from t.C, however,
// the timer must be stopped and—if Stop reports that the timer expired
// before being stopped—the channel explicitly drained:
//
// if !t.Stop() {
// <-t.C
// }
// t.Reset(d)
//
// This should not be done concurrent to other receives from the Timer's
// channel.
//
// Note that it is not possible to use Reset's return value correctly, as there
// is a race condition between draining the channel and the new timer expiring.
// Reset should always be invoked on stopped or expired channels, as described above.
// The return value exists to preserve compatibility with existing programs.
//
// For a Timer created with AfterFunc(d, f), Reset either reschedules
// when f will run, in which case Reset returns true, or schedules f
// to run again, in which case it returns false.
// When Reset returns false, Reset neither waits for the prior f to
// complete before returning nor does it guarantee that the subsequent
// goroutine running f does not run concurrently with the prior
// one. If the caller needs to know whether the prior execution of
// f is completed, it must coordinate with f explicitly.
func (t *Timer) Reset(d Duration) bool {
if t.r.f == nil {
panic("time: Reset called on uninitialized Timer")
}
w := when(d)
return resetTimer(&t.r, w)
}
// sendTime does a non-blocking send of the current time on c.
func sendTime(c any, seq uintptr) {
select {
case c.(chan Time) <- Now():
default:
}
}
// After waits for the duration to elapse and then sends the current time
// on the returned channel.
// It is equivalent to NewTimer(d).C.
// The underlying Timer is not recovered by the garbage collector
// until the timer fires. If efficiency is a concern, use NewTimer
// instead and call Timer.Stop if the timer is no longer needed.
func After(d Duration) <-chan Time {
return NewTimer(d).C
}
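// A common timeout idiom, sketched here with a hypothetical channel ch and
// handler handle:
//
//	select {
//	case m := <-ch:
//		handle(m)
//	case <-time.After(5 * time.Second):
//		fmt.Println("timed out")
//	}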
// AfterFunc waits for the duration to elapse and then calls f
// in its own goroutine. It returns a Timer that can
// be used to cancel the call using its Stop method.
func AfterFunc(d Duration, f func()) *Timer {
t := &Timer{
r: runtimeTimer{
when: when(d),
f: goFunc,
arg: f,
},
}
startTimer(&t.r)
return t
}
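// A minimal usage sketch (illustrative only):
//
//	t := time.AfterFunc(2*time.Second, func() { fmt.Println("fired") })
//	defer t.Stop() // cancel the call if it has not yet fired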
func goFunc(arg any, seq uintptr) {
go arg.(func())()
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix || (js && wasm)
package time
import (
"errors"
"syscall"
)
// for testing: whatever interrupts a sleep
func interrupt() {
syscall.Kill(syscall.Getpid(), syscall.SIGCHLD)
}
func open(name string) (uintptr, error) {
fd, err := syscall.Open(name, syscall.O_RDONLY, 0)
if err != nil {
return 0, err
}
return uintptr(fd), nil
}
func read(fd uintptr, buf []byte) (int, error) {
return syscall.Read(int(fd), buf)
}
func closefd(fd uintptr) {
syscall.Close(int(fd))
}
func preadn(fd uintptr, buf []byte, off int) error {
whence := seekStart
if off < 0 {
whence = seekEnd
}
if _, err := syscall.Seek(int(fd), int64(off), whence); err != nil {
return err
}
for len(buf) > 0 {
m, err := syscall.Read(int(fd), buf)
if m <= 0 {
if err == nil {
return errors.New("short read")
}
return err
}
buf = buf[m:]
}
return nil
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package time
// A Ticker holds a channel that delivers “ticks” of a clock
// at intervals.
type Ticker struct {
C <-chan Time // The channel on which the ticks are delivered.
r runtimeTimer
}
// NewTicker returns a new Ticker containing a channel that will send
// the current time on the channel after each tick. The period of the
// ticks is specified by the duration argument. The ticker will adjust
// the time interval or drop ticks to make up for slow receivers.
// The duration d must be greater than zero; if not, NewTicker will
// panic. Stop the ticker to release associated resources.
func NewTicker(d Duration) *Ticker {
if d <= 0 {
panic("non-positive interval for NewTicker")
}
// Give the channel a 1-element time buffer.
// If the client falls behind while reading, we drop ticks
// on the floor until the client catches up.
c := make(chan Time, 1)
t := &Ticker{
C: c,
r: runtimeTimer{
when: when(d),
period: int64(d),
f: sendTime,
arg: c,
},
}
startTimer(&t.r)
return t
}
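// A minimal usage sketch (illustrative only):
//
//	ticker := time.NewTicker(500 * time.Millisecond)
//	defer ticker.Stop()
//	for i := 0; i < 3; i++ {
//		fmt.Println("tick at", <-ticker.C)
//	}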
// Stop turns off a ticker. After Stop, no more ticks will be sent.
// Stop does not close the channel, to prevent a concurrent goroutine
// reading from the channel from seeing an erroneous "tick".
func (t *Ticker) Stop() {
stopTimer(&t.r)
}
// Reset stops a ticker and resets its period to the specified duration.
// The next tick will arrive after the new period elapses. The duration d
// must be greater than zero; if not, Reset will panic.
func (t *Ticker) Reset(d Duration) {
if d <= 0 {
panic("non-positive interval for Ticker.Reset")
}
if t.r.f == nil {
panic("time: Reset called on uninitialized Ticker")
}
modTimer(&t.r, when(d), int64(d), t.r.f, t.r.arg, t.r.seq)
}
// Tick is a convenience wrapper for NewTicker providing access to the ticking
// channel only. While Tick is useful for clients that have no need to shut down
// the Ticker, be aware that without a way to shut it down the underlying
// Ticker cannot be recovered by the garbage collector; it "leaks".
// Unlike NewTicker, Tick will return nil if d <= 0.
func Tick(d Duration) <-chan Time {
if d <= 0 {
return nil
}
return NewTicker(d).C
}
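// The sketch below (added for illustration; not part of the original file)
// shows why NewTicker plus a deferred Stop is usually preferable to Tick:
// the Ticker's resources can be released when the loop exits.
func exampleTickerSketch() {
	ticker := NewTicker(100 * Millisecond)
	defer ticker.Stop() // release the ticker when done
	for i := 0; i < 3; i++ {
		<-ticker.C // block until the next tick is delivered
	}
}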
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package time provides functionality for measuring and displaying time.
//
// The calendrical calculations always assume a Gregorian calendar, with
// no leap seconds.
//
// # Monotonic Clocks
//
// Operating systems provide both a “wall clock,” which is subject to
// changes for clock synchronization, and a “monotonic clock,” which is
// not. The general rule is that the wall clock is for telling time and
// the monotonic clock is for measuring time. Rather than split the API,
// in this package the Time returned by time.Now contains both a wall
// clock reading and a monotonic clock reading; later time-telling
// operations use the wall clock reading, but later time-measuring
// operations, specifically comparisons and subtractions, use the
// monotonic clock reading.
//
// For example, this code always computes a positive elapsed time of
// approximately 20 milliseconds, even if the wall clock is changed during
// the operation being timed:
//
// start := time.Now()
// ... operation that takes 20 milliseconds ...
// t := time.Now()
// elapsed := t.Sub(start)
//
// Other idioms, such as time.Since(start), time.Until(deadline), and
// time.Now().Before(deadline), are similarly robust against wall clock
// resets.
//
// The rest of this section gives the precise details of how operations
// use monotonic clocks, but understanding those details is not required
// to use this package.
//
// The Time returned by time.Now contains a monotonic clock reading.
// If Time t has a monotonic clock reading, t.Add adds the same duration to
// both the wall clock and monotonic clock readings to compute the result.
// Because t.AddDate(y, m, d), t.Round(d), and t.Truncate(d) are wall time
// computations, they always strip any monotonic clock reading from their results.
// Because t.In, t.Local, and t.UTC are used for their effect on the interpretation
// of the wall time, they also strip any monotonic clock reading from their results.
// The canonical way to strip a monotonic clock reading is to use t = t.Round(0).
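//
// For example (an illustrative sketch):
//
//	t := time.Now() // contains both wall and monotonic readings
//	t = t.Round(0)  // same wall time, monotonic reading stripped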
//
// If Times t and u both contain monotonic clock readings, the operations
// t.After(u), t.Before(u), t.Equal(u), t.Compare(u), and t.Sub(u) are carried out
// using the monotonic clock readings alone, ignoring the wall clock
// readings. If either t or u contains no monotonic clock reading, these
// operations fall back to using the wall clock readings.
//
// On some systems the monotonic clock will stop if the computer goes to sleep.
// On such a system, t.Sub(u) may not accurately reflect the actual
// time that passed between t and u.
//
// Because the monotonic clock reading has no meaning outside
// the current process, the serialized forms generated by t.GobEncode,
// t.MarshalBinary, t.MarshalJSON, and t.MarshalText omit the monotonic
// clock reading, and t.Format provides no format for it. Similarly, the
// constructors time.Date, time.Parse, time.ParseInLocation, and time.Unix,
// as well as the unmarshalers t.GobDecode, t.UnmarshalBinary,
// t.UnmarshalJSON, and t.UnmarshalText always create times with
// no monotonic clock reading.
//
// The monotonic clock reading exists only in Time values. It is not
// a part of Duration values or the Unix times returned by t.Unix and
// friends.
//
// Note that the Go == operator compares not just the time instant but
// also the Location and the monotonic clock reading. See the
// documentation for the Time type for a discussion of equality
// testing for Time values.
//
// For debugging, the result of t.String does include the monotonic
// clock reading if present. If t != u because of different monotonic clock readings,
// that difference will be visible when printing t.String() and u.String().
package time
import (
"errors"
_ "unsafe" // for go:linkname
)
// A Time represents an instant in time with nanosecond precision.
//
// Programs using times should typically store and pass them as values,
// not pointers. That is, time variables and struct fields should be of
// type time.Time, not *time.Time.
//
// A Time value can be used by multiple goroutines simultaneously except
// that the methods GobDecode, UnmarshalBinary, UnmarshalJSON and
// UnmarshalText are not concurrency-safe.
//
// Time instants can be compared using the Before, After, and Equal methods.
// The Sub method subtracts two instants, producing a Duration.
// The Add method adds a Time and a Duration, producing a Time.
//
// The zero value of type Time is January 1, year 1, 00:00:00.000000000 UTC.
// As this time is unlikely to come up in practice, the IsZero method gives
// a simple way of detecting a time that has not been initialized explicitly.
//
// Each Time has associated with it a Location, consulted when computing the
// presentation form of the time, such as in the Format, Hour, and Year methods.
// The methods Local, UTC, and In return a Time with a specific location.
// Changing the location in this way changes only the presentation; it does not
// change the instant in time being denoted and therefore does not affect the
// computations described in earlier paragraphs.
//
// Representations of a Time value saved by the GobEncode, MarshalBinary,
// MarshalJSON, and MarshalText methods store the Time.Location's offset, but not
// the location name. They therefore lose information about Daylight Saving Time.
//
// In addition to the required “wall clock” reading, a Time may contain an optional
// reading of the current process's monotonic clock, to provide additional precision
// for comparison or subtraction.
// See the “Monotonic Clocks” section in the package documentation for details.
//
// Note that the Go == operator compares not just the time instant but also the
// Location and the monotonic clock reading. Therefore, Time values should not
// be used as map or database keys without first guaranteeing that the
// identical Location has been set for all values, which can be achieved
// through use of the UTC or Local method, and that the monotonic clock reading
// has been stripped by setting t = t.Round(0). In general, prefer t.Equal(u)
// to t == u, since t.Equal uses the most accurate comparison available and
// correctly handles the case when only one of its arguments has a monotonic
// clock reading.
type Time struct {
// wall and ext encode the wall time seconds, wall time nanoseconds,
// and optional monotonic clock reading in nanoseconds.
//
// From high to low bit position, wall encodes a 1-bit flag (hasMonotonic),
// a 33-bit seconds field, and a 30-bit wall time nanoseconds field.
// The nanoseconds field is in the range [0, 999999999].
// If the hasMonotonic bit is 0, then the 33-bit field must be zero
// and the full signed 64-bit wall seconds since Jan 1 year 1 is stored in ext.
// If the hasMonotonic bit is 1, then the 33-bit field holds a 33-bit
// unsigned wall seconds since Jan 1 year 1885, and ext holds a
// signed 64-bit monotonic clock reading, nanoseconds since process start.
wall uint64
ext int64
// loc specifies the Location that should be used to
// determine the minute, hour, month, day, and year
// that correspond to this Time.
// The nil location means UTC.
// All UTC times are represented with loc==nil, never loc==&utcLoc.
loc *Location
}
const (
hasMonotonic = 1 << 63
maxWall = wallToInternal + (1<<33 - 1) // year 2157
minWall = wallToInternal // year 1885
nsecMask = 1<<30 - 1
nsecShift = 30
)
// These helpers for manipulating the wall and monotonic clock readings
// take pointer receivers, even when they don't modify the time,
// to make them cheaper to call.
// nsec returns the time's nanoseconds.
func (t *Time) nsec() int32 {
return int32(t.wall & nsecMask)
}
// sec returns the time's seconds since Jan 1 year 1.
func (t *Time) sec() int64 {
if t.wall&hasMonotonic != 0 {
return wallToInternal + int64(t.wall<<1>>(nsecShift+1))
}
return t.ext
}
// unixSec returns the time's seconds since Jan 1 1970 (Unix time).
func (t *Time) unixSec() int64 { return t.sec() + internalToUnix }
// addSec adds d seconds to the time.
func (t *Time) addSec(d int64) {
if t.wall&hasMonotonic != 0 {
sec := int64(t.wall << 1 >> (nsecShift + 1))
dsec := sec + d
if 0 <= dsec && dsec <= 1<<33-1 {
t.wall = t.wall&nsecMask | uint64(dsec)<<nsecShift | hasMonotonic
return
}
// Wall second now out of range for packed field.
// Move to ext.
t.stripMono()
}
// Check if the sum of t.ext and d overflows and handle it properly.
sum := t.ext + d
if (sum > t.ext) == (d > 0) {
t.ext = sum
} else if d > 0 {
t.ext = 1<<63 - 1
} else {
t.ext = -(1<<63 - 1)
}
}
// setLoc sets the location associated with the time.
func (t *Time) setLoc(loc *Location) {
if loc == &utcLoc {
loc = nil
}
t.stripMono()
t.loc = loc
}
// stripMono strips the monotonic clock reading in t.
func (t *Time) stripMono() {
if t.wall&hasMonotonic != 0 {
t.ext = t.sec()
t.wall &= nsecMask
}
}
// setMono sets the monotonic clock reading in t.
// If t cannot hold a monotonic clock reading,
// because its wall time is too large,
// setMono is a no-op.
func (t *Time) setMono(m int64) {
if t.wall&hasMonotonic == 0 {
sec := t.ext
if sec < minWall || maxWall < sec {
return
}
t.wall |= hasMonotonic | uint64(sec-minWall)<<nsecShift
}
t.ext = m
}
// mono returns t's monotonic clock reading.
// It returns 0 for a missing reading.
// This function is used only for testing,
// so it's OK that technically 0 is a valid
// monotonic clock reading as well.
func (t *Time) mono() int64 {
if t.wall&hasMonotonic == 0 {
return 0
}
return t.ext
}
// After reports whether the time instant t is after u.
func (t Time) After(u Time) bool {
if t.wall&u.wall&hasMonotonic != 0 {
return t.ext > u.ext
}
ts := t.sec()
us := u.sec()
return ts > us || ts == us && t.nsec() > u.nsec()
}
// Before reports whether the time instant t is before u.
func (t Time) Before(u Time) bool {
if t.wall&u.wall&hasMonotonic != 0 {
return t.ext < u.ext
}
ts := t.sec()
us := u.sec()
return ts < us || ts == us && t.nsec() < u.nsec()
}
// Compare compares the time instant t with u. If t is before u, it returns -1;
// if t is after u, it returns +1; if they're the same, it returns 0.
func (t Time) Compare(u Time) int {
var tc, uc int64
if t.wall&u.wall&hasMonotonic != 0 {
tc, uc = t.ext, u.ext
} else {
tc, uc = t.sec(), u.sec()
if tc == uc {
tc, uc = int64(t.nsec()), int64(u.nsec())
}
}
switch {
case tc < uc:
return -1
case tc > uc:
return +1
}
return 0
}
// Equal reports whether t and u represent the same time instant.
// Two times can be equal even if they are in different locations.
// For example, 6:00 +0200 and 4:00 UTC are Equal.
// See the documentation on the Time type for the pitfalls of using == with
// Time values; most code should use Equal instead.
func (t Time) Equal(u Time) bool {
if t.wall&u.wall&hasMonotonic != 0 {
return t.ext == u.ext
}
return t.sec() == u.sec() && t.nsec() == u.nsec()
}
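// The sketch below (added for illustration; not part of the original file)
// contrasts Equal with ==: the two values denote the same instant, so Equal
// reports true, but == reports false because their Locations differ.
func exampleEqualSketch() {
	plus2 := FixedZone("UTC+2", 2*60*60)
	t := Date(2000, January, 1, 6, 0, 0, 0, plus2)
	u := Date(2000, January, 1, 4, 0, 0, 0, UTC)
	_ = t.Equal(u) // true: same instant
	_ = t == u     // false: different Locations
}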
// A Month specifies a month of the year (January = 1, ...).
type Month int
const (
January Month = 1 + iota
February
March
April
May
June
July
August
September
October
November
December
)
// String returns the English name of the month ("January", "February", ...).
func (m Month) String() string {
if January <= m && m <= December {
return longMonthNames[m-1]
}
buf := make([]byte, 20)
n := fmtInt(buf, uint64(m))
return "%!Month(" + string(buf[n:]) + ")"
}
// A Weekday specifies a day of the week (Sunday = 0, ...).
type Weekday int
const (
Sunday Weekday = iota
Monday
Tuesday
Wednesday
Thursday
Friday
Saturday
)
// String returns the English name of the day ("Sunday", "Monday", ...).
func (d Weekday) String() string {
if Sunday <= d && d <= Saturday {
return longDayNames[d]
}
buf := make([]byte, 20)
n := fmtInt(buf, uint64(d))
return "%!Weekday(" + string(buf[n:]) + ")"
}
// Computations on time.
//
// The zero value for a Time is defined to be
// January 1, year 1, 00:00:00.000000000 UTC
// which (1) looks like a zero, or as close as you can get in a date
// (1-1-1 00:00:00 UTC), (2) is unlikely enough to arise in practice to
// be a suitable "not set" sentinel, unlike Jan 1 1970, and (3) has a
// non-negative year even in time zones west of UTC, unlike 1-1-0
// 00:00:00 UTC, which would be 12-31-(-1) 19:00:00 in New York.
//
// The zero Time value does not force a specific epoch for the time
// representation. For example, to use the Unix epoch internally, we
// could define that to distinguish a zero value from Jan 1 1970, that
// time would be represented by sec=-1, nsec=1e9. However, it does
// suggest a representation, namely using 1-1-1 00:00:00 UTC as the
// epoch, and that's what we do.
//
// The Add and Sub computations are oblivious to the choice of epoch.
//
// The presentation computations - year, month, minute, and so on - all
// rely heavily on division and modulus by positive constants. For
// calendrical calculations we want these divisions to round down, even
// for negative values, so that the remainder is always positive, but
// Go's division (like most hardware division instructions) rounds to
// zero. We can still do those computations and then adjust the result
// for a negative numerator, but it's annoying to write the adjustment
// over and over. Instead, we can change to a different epoch so long
// ago that all the times we care about will be positive, and then round
// to zero and round down coincide. These presentation routines already
// have to add the zone offset, so adding the translation to the
// alternate epoch is cheap. For example, having a non-negative time t
// means that we can write
//
// sec = t % 60
//
// instead of
//
// sec = t % 60
// if sec < 0 {
// sec += 60
// }
//
// everywhere.
//
// The calendar runs on an exact 400 year cycle: a 400-year calendar
// printed for 1970-2369 will apply as well to 2370-2769. Even the days
// of the week match up. It simplifies the computations to choose the
// cycle boundaries so that the exceptional years are always delayed as
// long as possible. That means choosing a year equal to 1 mod 400, so
// that the first leap year is the 4th year, the first missed leap year
// is the 100th year, and the missed missed leap year is the 400th year.
// So we'd prefer instead to print a calendar for 2001-2400 and reuse it
// for 2401-2800.
//
// Finally, it's convenient if the delta between the Unix epoch and
// long-ago epoch is representable by an int64 constant.
//
// These three considerations—choose an epoch as early as possible, that
// uses a year equal to 1 mod 400, and that is no more than 2⁶³ seconds
// earlier than 1970—bring us to the year -292277022399. We refer to
// this year as the absolute zero year, and to times measured as a uint64
// seconds since this year as absolute times.
//
// Times measured as an int64 seconds since the year 1—the representation
// used for Time's sec field—are called internal times.
//
// Times measured as an int64 seconds since the year 1970 are called Unix
// times.
//
// It is tempting to just use the year 1 as the absolute epoch, defining
// that the routines are only valid for years >= 1. However, the
// routines would then be invalid when displaying the epoch in time zones
// west of UTC, since it is year 0. It doesn't seem tenable to say that
// printing the zero time correctly isn't supported in half the time
// zones. By comparison, it's reasonable to mishandle some times in
// the year -292277022399.
//
// All this is opaque to clients of the API and can be changed if a
// better implementation presents itself.
const (
// The unsigned zero year for internal calculations.
// Must be 1 mod 400, and times before it will not compute correctly,
// but otherwise can be changed at will.
absoluteZeroYear = -292277022399
// The year of the zero Time.
// Assumed by the unixToInternal computation below.
internalYear = 1
// Offsets to convert between internal and absolute or Unix times.
absoluteToInternal int64 = (absoluteZeroYear - internalYear) * 365.2425 * secondsPerDay
internalToAbsolute = -absoluteToInternal
unixToInternal int64 = (1969*365 + 1969/4 - 1969/100 + 1969/400) * secondsPerDay
internalToUnix int64 = -unixToInternal
wallToInternal int64 = (1884*365 + 1884/4 - 1884/100 + 1884/400) * secondsPerDay
)
// IsZero reports whether t represents the zero time instant,
// January 1, year 1, 00:00:00 UTC.
func (t Time) IsZero() bool {
return t.sec() == 0 && t.nsec() == 0
}
// abs returns the time t as an absolute time, adjusted by the zone offset.
// It is called when computing a presentation property like Month or Hour.
func (t Time) abs() uint64 {
l := t.loc
// Avoid function calls when possible.
if l == nil || l == &localLoc {
l = l.get()
}
sec := t.unixSec()
if l != &utcLoc {
if l.cacheZone != nil && l.cacheStart <= sec && sec < l.cacheEnd {
sec += int64(l.cacheZone.offset)
} else {
_, offset, _, _, _ := l.lookup(sec)
sec += int64(offset)
}
}
return uint64(sec + (unixToInternal + internalToAbsolute))
}
// locabs is a combination of the Zone and abs methods,
// extracting both return values from a single zone lookup.
func (t Time) locabs() (name string, offset int, abs uint64) {
l := t.loc
if l == nil || l == &localLoc {
l = l.get()
}
// Avoid function call if we hit the local time cache.
sec := t.unixSec()
if l != &utcLoc {
if l.cacheZone != nil && l.cacheStart <= sec && sec < l.cacheEnd {
name = l.cacheZone.name
offset = l.cacheZone.offset
} else {
name, offset, _, _, _ = l.lookup(sec)
}
sec += int64(offset)
} else {
name = "UTC"
}
abs = uint64(sec + (unixToInternal + internalToAbsolute))
return
}
// Date returns the year, month, and day in which t occurs.
func (t Time) Date() (year int, month Month, day int) {
year, month, day, _ = t.date(true)
return
}
// Year returns the year in which t occurs.
func (t Time) Year() int {
year, _, _, _ := t.date(false)
return year
}
// Month returns the month of the year specified by t.
func (t Time) Month() Month {
_, month, _, _ := t.date(true)
return month
}
// Day returns the day of the month specified by t.
func (t Time) Day() int {
_, _, day, _ := t.date(true)
return day
}
// Weekday returns the day of the week specified by t.
func (t Time) Weekday() Weekday {
return absWeekday(t.abs())
}
// absWeekday is like Weekday but operates on an absolute time.
func absWeekday(abs uint64) Weekday {
// January 1 of the absolute year, like January 1 of 2001, was a Monday.
sec := (abs + uint64(Monday)*secondsPerDay) % secondsPerWeek
return Weekday(int(sec) / secondsPerDay)
}
// ISOWeek returns the ISO 8601 year and week number in which t occurs.
// Week ranges from 1 to 53. Jan 01 to Jan 03 of year n might belong to
// week 52 or 53 of year n-1, and Dec 29 to Dec 31 might belong to week 1
// of year n+1.
func (t Time) ISOWeek() (year, week int) {
// By the ISO 8601 rule, the first calendar week of a calendar year is
// the week including the first Thursday of that year, and the last one is
// the week immediately preceding the first calendar week of the next calendar year.
// See https://www.iso.org/obp/ui#iso:std:iso:8601:-1:ed-1:v1:en:term:3.1.1.23 for details.
// weeks start with Monday
// Monday Tuesday Wednesday Thursday Friday Saturday Sunday
// 1 2 3 4 5 6 7
// +3 +2 +1 0 -1 -2 -3
// the offset to Thursday
abs := t.abs()
d := Thursday - absWeekday(abs)
// handle Sunday
if d == 4 {
d = -3
}
// find the Thursday of the calendar week
abs += uint64(d) * secondsPerDay
year, _, _, yday := absDate(abs, false)
return year, yday/7 + 1
}
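// The sketch below (added for illustration; not part of the original file)
// shows an ISO week edge case: January 1, 2016 fell on a Friday, so it
// belongs to week 53 of ISO year 2015.
func exampleISOWeekSketch() {
	t := Date(2016, January, 1, 0, 0, 0, 0, UTC)
	year, week := t.ISOWeek()
	_, _ = year, week // 2015, 53
}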
// Clock returns the hour, minute, and second within the day specified by t.
func (t Time) Clock() (hour, min, sec int) {
return absClock(t.abs())
}
// absClock is like clock but operates on an absolute time.
func absClock(abs uint64) (hour, min, sec int) {
sec = int(abs % secondsPerDay)
hour = sec / secondsPerHour
sec -= hour * secondsPerHour
min = sec / secondsPerMinute
sec -= min * secondsPerMinute
return
}
// Hour returns the hour within the day specified by t, in the range [0, 23].
func (t Time) Hour() int {
return int(t.abs()%secondsPerDay) / secondsPerHour
}
// Minute returns the minute offset within the hour specified by t, in the range [0, 59].
func (t Time) Minute() int {
return int(t.abs()%secondsPerHour) / secondsPerMinute
}
// Second returns the second offset within the minute specified by t, in the range [0, 59].
func (t Time) Second() int {
return int(t.abs() % secondsPerMinute)
}
// Nanosecond returns the nanosecond offset within the second specified by t,
// in the range [0, 999999999].
func (t Time) Nanosecond() int {
return int(t.nsec())
}
// YearDay returns the day of the year specified by t, in the range [1,365] for non-leap years,
// and [1,366] in leap years.
func (t Time) YearDay() int {
_, _, _, yday := t.date(false)
return yday + 1
}
// A Duration represents the elapsed time between two instants
// as an int64 nanosecond count. The representation limits the
// largest representable duration to approximately 290 years.
type Duration int64
const (
minDuration Duration = -1 << 63
maxDuration Duration = 1<<63 - 1
)
// Common durations. There is no definition for units of Day or larger
// to avoid confusion across daylight savings time zone transitions.
//
// To count the number of units in a Duration, divide:
//
// second := time.Second
// fmt.Print(int64(second/time.Millisecond)) // prints 1000
//
// To convert an integer number of units to a Duration, multiply:
//
// seconds := 10
// fmt.Print(time.Duration(seconds)*time.Second) // prints 10s
const (
Nanosecond Duration = 1
Microsecond = 1000 * Nanosecond
Millisecond = 1000 * Microsecond
Second = 1000 * Millisecond
Minute = 60 * Second
Hour = 60 * Minute
)
// String returns a string representing the duration in the form "72h3m0.5s".
// Leading zero units are omitted. As a special case, durations less than one
// second are formatted using a smaller unit (milli-, micro-, or nanoseconds)
// to ensure that the leading digit is non-zero. The zero duration formats as 0s.
func (d Duration) String() string {
// Largest time is 2540400h10m10.000000000s
var buf [32]byte
w := len(buf)
u := uint64(d)
neg := d < 0
if neg {
u = -u
}
if u < uint64(Second) {
// Special case: if duration is smaller than a second,
// use smaller units, like 1.2ms
var prec int
w--
buf[w] = 's'
w--
switch {
case u == 0:
return "0s"
case u < uint64(Microsecond):
// print nanoseconds
prec = 0
buf[w] = 'n'
case u < uint64(Millisecond):
// print microseconds
prec = 3
// U+00B5 'µ' micro sign == 0xC2 0xB5
w-- // Need room for two bytes.
copy(buf[w:], "µ")
default:
// print milliseconds
prec = 6
buf[w] = 'm'
}
w, u = fmtFrac(buf[:w], u, prec)
w = fmtInt(buf[:w], u)
} else {
w--
buf[w] = 's'
w, u = fmtFrac(buf[:w], u, 9)
// u is now integer seconds
w = fmtInt(buf[:w], u%60)
u /= 60
// u is now integer minutes
if u > 0 {
w--
buf[w] = 'm'
w = fmtInt(buf[:w], u%60)
u /= 60
// u is now integer hours
// Stop at hours because days can be different lengths.
if u > 0 {
w--
buf[w] = 'h'
w = fmtInt(buf[:w], u)
}
}
}
if neg {
w--
buf[w] = '-'
}
return string(buf[w:])
}
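// The sketch below (added for illustration; not part of the original file)
// shows how String selects units: sub-second durations use a smaller unit
// so that the leading digit is non-zero.
func exampleDurationStringSketch() {
	_ = (90 * Minute).String()        // "1h30m0s"
	_ = (1500 * Microsecond).String() // "1.5ms"
	_ = Duration(0).String()          // "0s"
}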
// fmtFrac formats the fraction of v/10**prec (e.g., ".12345") into the
// tail of buf, omitting trailing zeros. It omits the decimal
// point too when the fraction is 0. It returns the index where the
// output bytes begin and the value v/10**prec.
func fmtFrac(buf []byte, v uint64, prec int) (nw int, nv uint64) {
// Omit trailing zeros up to and including decimal point.
w := len(buf)
print := false
for i := 0; i < prec; i++ {
digit := v % 10
print = print || digit != 0
if print {
w--
buf[w] = byte(digit) + '0'
}
v /= 10
}
if print {
w--
buf[w] = '.'
}
return w, v
}
// fmtInt formats v into the tail of buf.
// It returns the index where the output begins.
func fmtInt(buf []byte, v uint64) int {
w := len(buf)
if v == 0 {
w--
buf[w] = '0'
} else {
for v > 0 {
w--
buf[w] = byte(v%10) + '0'
v /= 10
}
}
return w
}
// Nanoseconds returns the duration as an integer nanosecond count.
func (d Duration) Nanoseconds() int64 { return int64(d) }
// Microseconds returns the duration as an integer microsecond count.
func (d Duration) Microseconds() int64 { return int64(d) / 1e3 }
// Milliseconds returns the duration as an integer millisecond count.
func (d Duration) Milliseconds() int64 { return int64(d) / 1e6 }
// These methods return float64 because the dominant
// use case is for printing a floating point number like 1.5s, and
// a truncation to integer would make them not useful in those cases.
// Splitting the integer and fraction ourselves guarantees that
// converting the returned float64 to an integer rounds the same
// way that a pure integer conversion would have, even in cases
// where, say, float64(d.Nanoseconds())/1e9 would have rounded
// differently.
// Seconds returns the duration as a floating point number of seconds.
func (d Duration) Seconds() float64 {
sec := d / Second
nsec := d % Second
return float64(sec) + float64(nsec)/1e9
}
// Minutes returns the duration as a floating point number of minutes.
func (d Duration) Minutes() float64 {
min := d / Minute
nsec := d % Minute
return float64(min) + float64(nsec)/(60*1e9)
}
// Hours returns the duration as a floating point number of hours.
func (d Duration) Hours() float64 {
hour := d / Hour
nsec := d % Hour
return float64(hour) + float64(nsec)/(60*60*1e9)
}
// Truncate returns the result of rounding d toward zero to a multiple of m.
// If m <= 0, Truncate returns d unchanged.
func (d Duration) Truncate(m Duration) Duration {
if m <= 0 {
return d
}
return d - d%m
}
// lessThanHalf reports whether x+x < y but avoids overflow,
// assuming x and y are both positive (Duration is signed).
func lessThanHalf(x, y Duration) bool {
return uint64(x)+uint64(x) < uint64(y)
}
// Round returns the result of rounding d to the nearest multiple of m.
// The rounding behavior for halfway values is to round away from zero.
// If the result exceeds the maximum (or minimum)
// value that can be stored in a Duration,
// Round returns the maximum (or minimum) duration.
// If m <= 0, Round returns d unchanged.
func (d Duration) Round(m Duration) Duration {
if m <= 0 {
return d
}
r := d % m
if d < 0 {
r = -r
if lessThanHalf(r, m) {
return d + r
}
if d1 := d - m + r; d1 < d {
return d1
}
return minDuration // overflow
}
if lessThanHalf(r, m) {
return d - r
}
if d1 := d + m - r; d1 > d {
return d1
}
return maxDuration // overflow
}
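// The sketch below (added for illustration; not part of the original file)
// contrasts Truncate (round toward zero) with Round (round to nearest,
// halfway values away from zero).
func exampleDurationRoundSketch() {
	d := 150 * Millisecond
	_ = d.Truncate(100 * Millisecond) // 100ms
	_ = d.Round(100 * Millisecond)    // 200ms: halfway rounds away from zero
}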
// Abs returns the absolute value of d.
// As a special case, math.MinInt64 is converted to math.MaxInt64.
func (d Duration) Abs() Duration {
switch {
case d >= 0:
return d
case d == minDuration:
return maxDuration
default:
return -d
}
}
// Add returns the time t+d.
func (t Time) Add(d Duration) Time {
dsec := int64(d / 1e9)
nsec := t.nsec() + int32(d%1e9)
if nsec >= 1e9 {
dsec++
nsec -= 1e9
} else if nsec < 0 {
dsec--
nsec += 1e9
}
t.wall = t.wall&^nsecMask | uint64(nsec) // update nsec
t.addSec(dsec)
if t.wall&hasMonotonic != 0 {
te := t.ext + int64(d)
if d < 0 && te > t.ext || d > 0 && te < t.ext {
// Monotonic clock reading now out of range; degrade to wall-only.
t.stripMono()
} else {
t.ext = te
}
}
return t
}
// Sub returns the duration t-u. If the result exceeds the maximum (or minimum)
// value that can be stored in a Duration, the maximum (or minimum) duration
// will be returned.
// To compute t-d for a duration d, use t.Add(-d).
func (t Time) Sub(u Time) Duration {
if t.wall&u.wall&hasMonotonic != 0 {
te := t.ext
ue := u.ext
d := Duration(te - ue)
if d < 0 && te > ue {
return maxDuration // t - u is positive out of range
}
if d > 0 && te < ue {
return minDuration // t - u is negative out of range
}
return d
}
d := Duration(t.sec()-u.sec())*Second + Duration(t.nsec()-u.nsec())
// Check for overflow or underflow.
switch {
case u.Add(d).Equal(t):
return d // d is correct
case t.Before(u):
return minDuration // t - u is negative out of range
default:
return maxDuration // t - u is positive out of range
}
}
// Since returns the time elapsed since t.
// It is shorthand for time.Now().Sub(t).
func Since(t Time) Duration {
var now Time
if t.wall&hasMonotonic != 0 {
// Common case optimization: if t has monotonic time, then Sub will use only it.
now = Time{hasMonotonic, runtimeNano() - startNano, nil}
} else {
now = Now()
}
return now.Sub(t)
}
// Until returns the duration until t.
// It is shorthand for t.Sub(time.Now()).
func Until(t Time) Duration {
var now Time
if t.wall&hasMonotonic != 0 {
// Common case optimization: if t has monotonic time, then Sub will use only it.
now = Time{hasMonotonic, runtimeNano() - startNano, nil}
} else {
now = Now()
}
return t.Sub(now)
}
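// The sketch below (added for illustration; not part of the original file)
// shows the usual elapsed-time idiom: because start carries a monotonic
// clock reading, Since is immune to wall-clock resets.
func exampleSinceSketch() {
	start := Now()
	// ... the operation being timed ...
	elapsed := Since(start) // equivalent to Now().Sub(start)
	_ = elapsed
}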
// AddDate returns the time corresponding to adding the
// given number of years, months, and days to t.
// For example, AddDate(-1, 2, 3) applied to January 1, 2011
// returns March 4, 2010.
//
// AddDate normalizes its result in the same way that Date does,
// so, for example, adding one month to October 31 yields
// December 1, the normalized form for November 31.
func (t Time) AddDate(years int, months int, days int) Time {
year, month, day := t.Date()
hour, min, sec := t.Clock()
return Date(year+years, month+Month(months), day+days, hour, min, sec, int(t.nsec()), t.Location())
}
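// The sketch below (added for illustration; not part of the original file)
// shows AddDate's normalization: adding one month to October 31 yields
// December 1, the normalized form of November 31.
func exampleAddDateSketch() {
	t := Date(2023, October, 31, 0, 0, 0, 0, UTC)
	u := t.AddDate(0, 1, 0)
	_, _, _ = u.Date() // 2023 December 1
}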
const (
secondsPerMinute = 60
secondsPerHour = 60 * secondsPerMinute
secondsPerDay = 24 * secondsPerHour
secondsPerWeek = 7 * secondsPerDay
daysPer400Years = 365*400 + 97
daysPer100Years = 365*100 + 24
daysPer4Years = 365*4 + 1
)
// date computes the year, day of year, and when full=true,
// the month and day in which t occurs.
func (t Time) date(full bool) (year int, month Month, day int, yday int) {
return absDate(t.abs(), full)
}
// absDate is like date but operates on an absolute time.
func absDate(abs uint64, full bool) (year int, month Month, day int, yday int) {
// Split into time and day.
d := abs / secondsPerDay
// Account for 400 year cycles.
n := d / daysPer400Years
y := 400 * n
d -= daysPer400Years * n
// Cut off 100-year cycles.
// The last cycle has one extra leap year, so on the last day
// of that year, day / daysPer100Years will be 4 instead of 3.
// Cut it back down to 3 by subtracting n>>2.
n = d / daysPer100Years
n -= n >> 2
y += 100 * n
d -= daysPer100Years * n
// Cut off 4-year cycles.
// The last cycle has a missing leap year, which does not
// affect the computation.
n = d / daysPer4Years
y += 4 * n
d -= daysPer4Years * n
// Cut off years within a 4-year cycle.
// The last year is a leap year, so on the last day of that year,
// day / 365 will be 4 instead of 3. Cut it back down to 3
// by subtracting n>>2.
n = d / 365
n -= n >> 2
y += n
d -= 365 * n
year = int(int64(y) + absoluteZeroYear)
yday = int(d)
if !full {
return
}
day = yday
if isLeap(year) {
// Leap year
switch {
case day > 31+29-1:
// After leap day; pretend it wasn't there.
day--
case day == 31+29-1:
// Leap day.
month = February
day = 29
return
}
}
// Estimate month on assumption that every month has 31 days.
// The estimate may be too low by at most one month, so adjust.
month = Month(day / 31)
end := int(daysBefore[month+1])
var begin int
if day >= end {
month++
begin = end
} else {
begin = int(daysBefore[month])
}
month++ // because January is 1
day = day - begin + 1
return
}
// daysBefore[m] counts the number of days in a non-leap year
// before month m begins. There is an entry for m=12, counting
// the number of days before January of next year (365).
var daysBefore = [...]int32{
0,
31,
31 + 28,
31 + 28 + 31,
31 + 28 + 31 + 30,
31 + 28 + 31 + 30 + 31,
31 + 28 + 31 + 30 + 31 + 30,
31 + 28 + 31 + 30 + 31 + 30 + 31,
31 + 28 + 31 + 30 + 31 + 30 + 31 + 31,
31 + 28 + 31 + 30 + 31 + 30 + 31 + 31 + 30,
31 + 28 + 31 + 30 + 31 + 30 + 31 + 31 + 30 + 31,
31 + 28 + 31 + 30 + 31 + 30 + 31 + 31 + 30 + 31 + 30,
31 + 28 + 31 + 30 + 31 + 30 + 31 + 31 + 30 + 31 + 30 + 31,
}
func daysIn(m Month, year int) int {
if m == February && isLeap(year) {
return 29
}
return int(daysBefore[m] - daysBefore[m-1])
}
// daysSinceEpoch takes a year and returns the number of days from
// the absolute epoch to the start of that year.
// This is basically (year - zeroYear) * 365, but accounting for leap days.
func daysSinceEpoch(year int) uint64 {
y := uint64(int64(year) - absoluteZeroYear)
// Add in days from 400-year cycles.
n := y / 400
y -= 400 * n
d := daysPer400Years * n
// Add in 100-year cycles.
n = y / 100
y -= 100 * n
d += daysPer100Years * n
// Add in 4-year cycles.
n = y / 4
y -= 4 * n
d += daysPer4Years * n
// Add in non-leap years.
n = y
d += 365 * n
return d
}
// Provided by package runtime.
func now() (sec int64, nsec int32, mono int64)
// runtimeNano returns the current value of the runtime clock in nanoseconds.
//
//go:linkname runtimeNano runtime.nanotime
func runtimeNano() int64
// Monotonic times are reported as offsets from startNano.
// We initialize startNano to runtimeNano() - 1 so that on systems where
// monotonic time resolution is fairly low (e.g. Windows 2008,
// which appears to have a default resolution of 15ms),
// we avoid ever reporting a monotonic time of 0.
// (Callers may want to use 0 as "time not set".)
var startNano int64 = runtimeNano() - 1
// Now returns the current local time.
func Now() Time {
sec, nsec, mono := now()
mono -= startNano
sec += unixToInternal - minWall
if uint64(sec)>>33 != 0 {
// Seconds field overflowed the 33 bits available when
// storing a monotonic time. This will be true after
// March 16, 2157.
return Time{uint64(nsec), sec + minWall, Local}
}
return Time{hasMonotonic | uint64(sec)<<nsecShift | uint64(nsec), mono, Local}
}
func unixTime(sec int64, nsec int32) Time {
return Time{uint64(nsec), sec + unixToInternal, Local}
}
// UTC returns t with the location set to UTC.
func (t Time) UTC() Time {
t.setLoc(&utcLoc)
return t
}
// Local returns t with the location set to local time.
func (t Time) Local() Time {
t.setLoc(Local)
return t
}
// In returns a copy of t representing the same time instant, but
// with the copy's location information set to loc for display
// purposes.
//
// In panics if loc is nil.
func (t Time) In(loc *Location) Time {
if loc == nil {
panic("time: missing Location in call to Time.In")
}
t.setLoc(loc)
return t
}
// Location returns the time zone information associated with t.
func (t Time) Location() *Location {
l := t.loc
if l == nil {
l = UTC
}
return l
}
// Zone computes the time zone in effect at time t, returning the abbreviated
// name of the zone (such as "CET") and its offset in seconds east of UTC.
func (t Time) Zone() (name string, offset int) {
name, offset, _, _, _ = t.loc.lookup(t.unixSec())
return
}
// ZoneBounds returns the bounds of the time zone in effect at time t.
// The zone begins at start and the next zone begins at end.
// If the zone begins at the beginning of time, start will be returned as a zero Time.
// If the zone goes on forever, end will be returned as a zero Time.
// The Location of the returned times will be the same as t.
func (t Time) ZoneBounds() (start, end Time) {
_, _, startSec, endSec, _ := t.loc.lookup(t.unixSec())
if startSec != alpha {
start = unixTime(startSec, 0)
start.setLoc(t.loc)
}
if endSec != omega {
end = unixTime(endSec, 0)
end.setLoc(t.loc)
}
return
}
// Unix returns t as a Unix time, the number of seconds elapsed
// since January 1, 1970 UTC. The result does not depend on the
// location associated with t.
// Unix-like operating systems often record time as a 32-bit
// count of seconds, but since the method here returns a 64-bit
// value it is valid for billions of years into the past or future.
func (t Time) Unix() int64 {
return t.unixSec()
}
// UnixMilli returns t as a Unix time, the number of milliseconds elapsed since
// January 1, 1970 UTC. The result is undefined if the Unix time in
// milliseconds cannot be represented by an int64 (a date more than 292 million
// years before or after 1970). The result does not depend on the
// location associated with t.
func (t Time) UnixMilli() int64 {
return t.unixSec()*1e3 + int64(t.nsec())/1e6
}
// UnixMicro returns t as a Unix time, the number of microseconds elapsed since
// January 1, 1970 UTC. The result is undefined if the Unix time in
// microseconds cannot be represented by an int64 (a date before year -290307 or
// after year 294246). The result does not depend on the location associated
// with t.
func (t Time) UnixMicro() int64 {
return t.unixSec()*1e6 + int64(t.nsec())/1e3
}
// UnixNano returns t as a Unix time, the number of nanoseconds elapsed
// since January 1, 1970 UTC. The result is undefined if the Unix time
// in nanoseconds cannot be represented by an int64 (a date before the year
// 1678 or after 2262). Note that this means the result of calling UnixNano
// on the zero Time is undefined. The result does not depend on the
// location associated with t.
func (t Time) UnixNano() int64 {
return (t.unixSec())*1e9 + int64(t.nsec())
}
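// The sketch below (added for illustration; not part of the original file)
// shows the scaling relationship among the Unix accessors for a time one and
// a half seconds past the epoch.
func exampleUnixSketch() {
	t := Date(1970, January, 1, 0, 0, 1, 500000000, UTC)
	_ = t.Unix()      // 1 (whole seconds; the fraction is discarded)
	_ = t.UnixMilli() // 1500
	_ = t.UnixNano()  // 1500000000
}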
const (
timeBinaryVersionV1 byte = iota + 1 // For the general case
timeBinaryVersionV2 // For zone offsets not aligned to a whole minute (LMT)
)
// MarshalBinary implements the encoding.BinaryMarshaler interface.
func (t Time) MarshalBinary() ([]byte, error) {
var offsetMin int16 // minutes east of UTC. -1 is UTC.
var offsetSec int8
version := timeBinaryVersionV1
if t.Location() == UTC {
offsetMin = -1
} else {
_, offset := t.Zone()
if offset%60 != 0 {
version = timeBinaryVersionV2
offsetSec = int8(offset % 60)
}
offset /= 60
if offset < -32768 || offset == -1 || offset > 32767 {
return nil, errors.New("Time.MarshalBinary: unexpected zone offset")
}
offsetMin = int16(offset)
}
sec := t.sec()
nsec := t.nsec()
enc := []byte{
version, // byte 0 : version
byte(sec >> 56), // bytes 1-8: seconds
byte(sec >> 48),
byte(sec >> 40),
byte(sec >> 32),
byte(sec >> 24),
byte(sec >> 16),
byte(sec >> 8),
byte(sec),
byte(nsec >> 24), // bytes 9-12: nanoseconds
byte(nsec >> 16),
byte(nsec >> 8),
byte(nsec),
byte(offsetMin >> 8), // bytes 13-14: zone offset in minutes
byte(offsetMin),
}
if version == timeBinaryVersionV2 {
enc = append(enc, byte(offsetSec))
}
return enc, nil
}
// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface.
func (t *Time) UnmarshalBinary(data []byte) error {
buf := data
if len(buf) == 0 {
return errors.New("Time.UnmarshalBinary: no data")
}
version := buf[0]
if version != timeBinaryVersionV1 && version != timeBinaryVersionV2 {
return errors.New("Time.UnmarshalBinary: unsupported version")
}
wantLen := /*version*/ 1 + /*sec*/ 8 + /*nsec*/ 4 + /*zone offset*/ 2
if version == timeBinaryVersionV2 {
wantLen++
}
if len(buf) != wantLen {
return errors.New("Time.UnmarshalBinary: invalid length")
}
buf = buf[1:]
sec := int64(buf[7]) | int64(buf[6])<<8 | int64(buf[5])<<16 | int64(buf[4])<<24 |
int64(buf[3])<<32 | int64(buf[2])<<40 | int64(buf[1])<<48 | int64(buf[0])<<56
buf = buf[8:]
nsec := int32(buf[3]) | int32(buf[2])<<8 | int32(buf[1])<<16 | int32(buf[0])<<24
buf = buf[4:]
offset := int(int16(buf[1])|int16(buf[0])<<8) * 60
if version == timeBinaryVersionV2 {
offset += int(buf[2])
}
*t = Time{}
t.wall = uint64(nsec)
t.ext = sec
if offset == -1*60 {
t.setLoc(&utcLoc)
} else if _, localoff, _, _, _ := Local.lookup(t.unixSec()); offset == localoff {
t.setLoc(Local)
} else {
t.setLoc(FixedZone("", offset))
}
return nil
}
// TODO(rsc): Remove GobEncoder, GobDecoder, MarshalJSON, UnmarshalJSON in Go 2.
// The same semantics will be provided by the generic MarshalBinary, MarshalText,
// UnmarshalBinary, UnmarshalText.
// GobEncode implements the gob.GobEncoder interface.
func (t Time) GobEncode() ([]byte, error) {
return t.MarshalBinary()
}
// GobDecode implements the gob.GobDecoder interface.
func (t *Time) GobDecode(data []byte) error {
return t.UnmarshalBinary(data)
}
// MarshalJSON implements the json.Marshaler interface.
// The time is a quoted string in the RFC 3339 format with sub-second precision.
// If the timestamp cannot be represented as valid RFC 3339
// (e.g., the year is out of range), then an error is reported.
func (t Time) MarshalJSON() ([]byte, error) {
b := make([]byte, 0, len(RFC3339Nano)+len(`""`))
b = append(b, '"')
b, err := t.appendStrictRFC3339(b)
b = append(b, '"')
if err != nil {
return nil, errors.New("Time.MarshalJSON: " + err.Error())
}
return b, nil
}
// UnmarshalJSON implements the json.Unmarshaler interface.
// The time must be a quoted string in the RFC 3339 format.
func (t *Time) UnmarshalJSON(data []byte) error {
if string(data) == "null" {
return nil
}
// TODO(https://go.dev/issue/47353): Properly unescape a JSON string.
if len(data) < 2 || data[0] != '"' || data[len(data)-1] != '"' {
return errors.New("Time.UnmarshalJSON: input is not a JSON string")
}
data = data[len(`"`) : len(data)-len(`"`)]
var err error
*t, err = parseStrictRFC3339(data)
return err
}
// MarshalText implements the encoding.TextMarshaler interface.
// The time is formatted in RFC 3339 format with sub-second precision.
// If the timestamp cannot be represented as valid RFC 3339
// (e.g., the year is out of range), then an error is reported.
func (t Time) MarshalText() ([]byte, error) {
b := make([]byte, 0, len(RFC3339Nano))
b, err := t.appendStrictRFC3339(b)
if err != nil {
return nil, errors.New("Time.MarshalText: " + err.Error())
}
return b, nil
}
// UnmarshalText implements the encoding.TextUnmarshaler interface.
// The time must be in the RFC 3339 format.
func (t *Time) UnmarshalText(data []byte) error {
var err error
*t, err = parseStrictRFC3339(data)
return err
}
// Unix returns the local Time corresponding to the given Unix time,
// sec seconds and nsec nanoseconds since January 1, 1970 UTC.
// It is valid to pass nsec outside the range [0, 999999999].
// Not all sec values have a corresponding time value. One such
// value is 1<<63-1 (the largest int64 value).
func Unix(sec int64, nsec int64) Time {
if nsec < 0 || nsec >= 1e9 {
n := nsec / 1e9
sec += n
nsec -= n * 1e9
if nsec < 0 {
nsec += 1e9
sec--
}
}
return unixTime(sec, int32(nsec))
}
// UnixMilli returns the local Time corresponding to the given Unix time,
// msec milliseconds since January 1, 1970 UTC.
func UnixMilli(msec int64) Time {
return Unix(msec/1e3, (msec%1e3)*1e6)
}
// UnixMicro returns the local Time corresponding to the given Unix time,
// usec microseconds since January 1, 1970 UTC.
func UnixMicro(usec int64) Time {
return Unix(usec/1e6, (usec%1e6)*1e3)
}
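// The sketch below (added for illustration; not part of the original file)
// shows that the constructors and accessors round-trip: a Unix time in a
// given unit converts back to the same count.
func exampleUnixMilliSketch() {
	const msec = 1700000000123
	t := UnixMilli(msec)
	_ = t.UnixMilli() == msec // true
}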
// IsDST reports whether the time in the configured location is in Daylight Savings Time.
func (t Time) IsDST() bool {
_, _, _, _, isDST := t.loc.lookup(t.Unix())
return isDST
}
func isLeap(year int) bool {
return year%4 == 0 && (year%100 != 0 || year%400 == 0)
}
// norm returns nhi, nlo such that
//
// hi * base + lo == nhi * base + nlo
// 0 <= nlo < base
func norm(hi, lo, base int) (nhi, nlo int) {
if lo < 0 {
n := (-lo-1)/base + 1
hi -= n
lo += n * base
}
if lo >= base {
n := lo / base
hi += n
lo -= n * base
}
return hi, lo
}
// Date returns the Time corresponding to
//
// yyyy-mm-dd hh:mm:ss + nsec nanoseconds
//
// in the appropriate zone for that time in the given location.
//
// The month, day, hour, min, sec, and nsec values may be outside
// their usual ranges and will be normalized during the conversion.
// For example, October 32 converts to November 1.
//
// A daylight savings time transition skips or repeats times.
// For example, in the United States, March 13, 2011 2:15am never occurred,
// while November 6, 2011 1:15am occurred twice. In such cases, the
// choice of time zone, and therefore the time, is not well-defined.
// Date returns a time that is correct in one of the two zones involved
// in the transition, but it does not guarantee which.
//
// Date panics if loc is nil.
func Date(year int, month Month, day, hour, min, sec, nsec int, loc *Location) Time {
if loc == nil {
panic("time: missing Location in call to Date")
}
// Normalize month, overflowing into year.
m := int(month) - 1
year, m = norm(year, m, 12)
month = Month(m) + 1
// Normalize nsec, sec, min, hour, overflowing into day.
sec, nsec = norm(sec, nsec, 1e9)
min, sec = norm(min, sec, 60)
hour, min = norm(hour, min, 60)
day, hour = norm(day, hour, 24)
// Compute days since the absolute epoch.
d := daysSinceEpoch(year)
// Add in days before this month.
d += uint64(daysBefore[month-1])
if isLeap(year) && month >= March {
d++ // February 29
}
// Add in days before today.
d += uint64(day - 1)
// Add in time elapsed today.
abs := d * secondsPerDay
abs += uint64(hour*secondsPerHour + min*secondsPerMinute + sec)
unix := int64(abs) + (absoluteToInternal + internalToUnix)
// Look for zone offset for expected time, so we can adjust to UTC.
// The lookup function expects UTC, so first we pass unix in the
// hope that it will not be too close to a zone transition,
// and then adjust if it is.
_, offset, start, end, _ := loc.lookup(unix)
if offset != 0 {
utc := unix - int64(offset)
// If utc is valid for the time zone we found, then we have the right offset.
// If not, we get the correct offset by looking up utc in the location.
if utc < start || utc >= end {
_, offset, _, _, _ = loc.lookup(utc)
}
unix -= int64(offset)
}
t := unixTime(unix, int32(nsec))
t.setLoc(loc)
return t
}
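// The sketch below (added for illustration; not part of the original file)
// shows Date's out-of-range normalization: October 32 converts to November 1.
func exampleDateSketch() {
	t := Date(2009, October, 32, 0, 0, 0, 0, UTC)
	_, _, _ = t.Date() // 2009 November 1
}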
// Truncate returns the result of rounding t down to a multiple of d (since the zero time).
// If d <= 0, Truncate returns t stripped of any monotonic clock reading but otherwise unchanged.
//
// Truncate operates on the time as an absolute duration since the
// zero time; it does not operate on the presentation form of the
// time. Thus, Truncate(Hour) may return a time with a non-zero
// minute, depending on the time's Location.
func (t Time) Truncate(d Duration) Time {
t.stripMono()
if d <= 0 {
return t
}
_, r := div(t, d)
return t.Add(-r)
}
// Round returns the result of rounding t to the nearest multiple of d (since the zero time).
// The rounding behavior for halfway values is to round up.
// If d <= 0, Round returns t stripped of any monotonic clock reading but otherwise unchanged.
//
// Round operates on the time as an absolute duration since the
// zero time; it does not operate on the presentation form of the
// time. Thus, Round(Hour) may return a time with a non-zero
// minute, depending on the time's Location.
func (t Time) Round(d Duration) Time {
t.stripMono()
if d <= 0 {
return t
}
_, r := div(t, d)
if lessThanHalf(r, d) {
return t.Add(-r)
}
return t.Add(d - r)
}
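// The sketch below (added for illustration; not part of the original file)
// shows the Location caveat: Truncate operates on the absolute time line, so
// in a half-hour-offset zone, truncating to the hour leaves :30 minutes.
func exampleTimeTruncateSketch() {
	loc := FixedZone("UTC+5:30", 5*60*60+30*60)
	t := Date(2023, June, 1, 10, 45, 0, 0, loc)
	_ = t.Truncate(Hour).Minute() // 30, not 0
}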
// div divides t by d and returns the quotient parity and remainder.
// We don't use the quotient parity anymore (round half up instead of round to even)
// but it's still here in case we change our minds.
func div(t Time, d Duration) (qmod2 int, r Duration) {
neg := false
nsec := t.nsec()
sec := t.sec()
if sec < 0 {
// Operate on absolute value.
neg = true
sec = -sec
nsec = -nsec
if nsec < 0 {
nsec += 1e9
sec-- // sec >= 1 before the -- so safe
}
}
switch {
// Special case: 2d divides 1 second.
case d < Second && Second%(d+d) == 0:
qmod2 = int(nsec/int32(d)) & 1
r = Duration(nsec % int32(d))
// Special case: d is a multiple of 1 second.
case d%Second == 0:
d1 := int64(d / Second)
qmod2 = int(sec/d1) & 1
r = Duration(sec%d1)*Second + Duration(nsec)
// General case.
// This could be faster if more cleverness were applied,
// but it's really only here to avoid special case restrictions in the API.
// No one will care about these cases.
default:
// Compute nanoseconds as 128-bit number.
sec := uint64(sec)
tmp := (sec >> 32) * 1e9
u1 := tmp >> 32
u0 := tmp << 32
tmp = (sec & 0xFFFFFFFF) * 1e9
u0x, u0 := u0, u0+tmp
if u0 < u0x {
u1++
}
u0x, u0 = u0, u0+uint64(nsec)
if u0 < u0x {
u1++
}
// Compute remainder by subtracting r<<k for decreasing k.
// Quotient parity is whether we subtract on last round.
d1 := uint64(d)
for d1>>63 != 1 {
d1 <<= 1
}
d0 := uint64(0)
for {
qmod2 = 0
if u1 > d1 || u1 == d1 && u0 >= d0 {
// subtract
qmod2 = 1
u0x, u0 = u0, u0-d0
if u0 > u0x {
u1--
}
u1 -= d1
}
if d1 == 0 && d0 == uint64(d) {
break
}
d0 >>= 1
d0 |= (d1 & 1) << 63
d1 >>= 1
}
r = Duration(u0)
}
if neg && r != 0 {
// If input was negative and not an exact multiple of d, we computed q, r such that
// q*d + r = -t
// But the right answers are given by -(q-1), d-r:
// q*d + r = -t
// -q*d - r = t
// -(q-1)*d + (d - r) = t
qmod2 ^= 1
r = d - r
}
return
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package time
import (
"errors"
"sync"
"syscall"
)
//go:generate env ZONEINFO=$GOROOT/lib/time/zoneinfo.zip go run genzabbrs.go -output zoneinfo_abbrs_windows.go
// A Location maps time instants to the zone in use at that time.
// Typically, the Location represents the collection of time offsets
// in use in a geographical area. For many Locations the time offset varies
// depending on whether daylight savings time is in use at the time instant.
type Location struct {
name string
zone []zone
tx []zoneTrans
// The tzdata information can be followed by a string that describes
// how to handle DST transitions not recorded in zoneTrans.
// The format is the TZ environment variable without a colon; see
// https://pubs.opengroup.org/onlinepubs/9699919799/basedefs/V1_chap08.html.
// Example string, for America/Los_Angeles: PST8PDT,M3.2.0,M11.1.0
extend string
// Most lookups will be for the current time.
// To avoid the binary search through tx, keep a
// static one-element cache that gives the correct
// zone for the time when the Location was created.
// if cacheStart <= t < cacheEnd,
// lookup can return cacheZone.
// The units for cacheStart and cacheEnd are seconds
// since January 1, 1970 UTC, to match the argument
// to lookup.
cacheStart int64
cacheEnd int64
cacheZone *zone
}
// A zone represents a single time zone such as CET.
type zone struct {
name string // abbreviated name, "CET"
offset int // seconds east of UTC
isDST bool // is this zone Daylight Savings Time?
}
// A zoneTrans represents a single time zone transition.
type zoneTrans struct {
when int64 // transition time, in seconds since 1970 GMT
index uint8 // the index of the zone that goes into effect at that time
isstd, isutc bool // ignored - no idea what these mean
}
// alpha and omega are the beginning and end of time for zone
// transitions.
const (
alpha = -1 << 63 // math.MinInt64
omega = 1<<63 - 1 // math.MaxInt64
)
// UTC represents Universal Coordinated Time (UTC).
var UTC *Location = &utcLoc
// utcLoc is separate so that get can refer to &utcLoc
// and ensure that it never returns a nil *Location,
// even if a badly behaved client has changed UTC.
var utcLoc = Location{name: "UTC"}
// Local represents the system's local time zone.
// On Unix systems, Local consults the TZ environment
// variable to find the time zone to use. No TZ means
// use the system default /etc/localtime.
// TZ="" means use UTC.
// TZ="foo" means use file foo in the system timezone directory.
var Local *Location = &localLoc
// localLoc is separate so that initLocal can initialize
// it even if a client has changed Local.
var localLoc Location
var localOnce sync.Once
func (l *Location) get() *Location {
if l == nil {
return &utcLoc
}
if l == &localLoc {
localOnce.Do(initLocal)
}
return l
}
// String returns a descriptive name for the time zone information,
// corresponding to the name argument to LoadLocation or FixedZone.
func (l *Location) String() string {
return l.get().name
}
var unnamedFixedZones []*Location
var unnamedFixedZonesOnce sync.Once
// FixedZone returns a Location that always uses
// the given zone name and offset (seconds east of UTC).
func FixedZone(name string, offset int) *Location {
// Most calls to FixedZone have an unnamed zone with an offset by the hour.
// Optimize for that case by returning the same *Location for a given hour.
const hoursBeforeUTC = 12
const hoursAfterUTC = 14
hour := offset / 60 / 60
if name == "" && -hoursBeforeUTC <= hour && hour <= +hoursAfterUTC && hour*60*60 == offset {
unnamedFixedZonesOnce.Do(func() {
unnamedFixedZones = make([]*Location, hoursBeforeUTC+1+hoursAfterUTC)
for hr := -hoursBeforeUTC; hr <= +hoursAfterUTC; hr++ {
unnamedFixedZones[hr+hoursBeforeUTC] = fixedZone("", hr*60*60)
}
})
return unnamedFixedZones[hour+hoursBeforeUTC]
}
return fixedZone(name, offset)
}
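// The sketch below (added for illustration; not part of the original file)
// shows a fixed-offset zone in use; Zone reports the name and offset back.
func exampleFixedZoneSketch() {
	loc := FixedZone("UTC-8", -8*60*60)
	t := Date(2009, November, 10, 23, 0, 0, 0, loc)
	name, offset := t.Zone() // "UTC-8", -28800
	_, _ = name, offset
}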
func fixedZone(name string, offset int) *Location {
l := &Location{
name: name,
zone: []zone{{name, offset, false}},
tx: []zoneTrans{{alpha, 0, false, false}},
cacheStart: alpha,
cacheEnd: omega,
}
l.cacheZone = &l.zone[0]
return l
}
// lookup returns information about the time zone in use at an
// instant in time expressed as seconds since January 1, 1970 00:00:00 UTC.
//
// The returned information gives the name of the zone (such as "CET"),
// the start and end times bracketing sec when that zone is in effect,
// the offset in seconds east of UTC (such as -5*60*60), and whether
// the daylight savings is being observed at that time.
func (l *Location) lookup(sec int64) (name string, offset int, start, end int64, isDST bool) {
l = l.get()
if len(l.zone) == 0 {
name = "UTC"
offset = 0
start = alpha
end = omega
isDST = false
return
}
if zone := l.cacheZone; zone != nil && l.cacheStart <= sec && sec < l.cacheEnd {
name = zone.name
offset = zone.offset
start = l.cacheStart
end = l.cacheEnd
isDST = zone.isDST
return
}
if len(l.tx) == 0 || sec < l.tx[0].when {
zone := &l.zone[l.lookupFirstZone()]
name = zone.name
offset = zone.offset
start = alpha
if len(l.tx) > 0 {
end = l.tx[0].when
} else {
end = omega
}
isDST = zone.isDST
return
}
// Binary search for entry with largest time <= sec.
// Not using sort.Search to avoid dependencies.
tx := l.tx
end = omega
lo := 0
hi := len(tx)
for hi-lo > 1 {
m := lo + (hi-lo)/2
lim := tx[m].when
if sec < lim {
end = lim
hi = m
} else {
lo = m
}
}
zone := &l.zone[tx[lo].index]
name = zone.name
offset = zone.offset
start = tx[lo].when
// end = maintained during the search
isDST = zone.isDST
// If we're at the end of the known zone transitions,
// try the extend string.
if lo == len(tx)-1 && l.extend != "" {
if ename, eoffset, estart, eend, eisDST, ok := tzset(l.extend, end, sec); ok {
return ename, eoffset, estart, eend, eisDST
}
}
return
}
// lookupFirstZone returns the index of the time zone to use for times
// before the first transition time, or when there are no transition
// times.
//
// The reference implementation in localtime.c from
// https://www.iana.org/time-zones/repository/releases/tzcode2013g.tar.gz
// implements the following algorithm for these cases:
// 1. If the first zone is unused by the transitions, use it.
// 2. Otherwise, if there are transition times, and the first
// transition is to a zone in daylight time, find the first
// non-daylight-time zone before and closest to the first transition
// zone.
// 3. Otherwise, use the first zone that is not daylight time, if
// there is one.
// 4. Otherwise, use the first zone.
func (l *Location) lookupFirstZone() int {
// Case 1.
if !l.firstZoneUsed() {
return 0
}
// Case 2.
if len(l.tx) > 0 && l.zone[l.tx[0].index].isDST {
for zi := int(l.tx[0].index) - 1; zi >= 0; zi-- {
if !l.zone[zi].isDST {
return zi
}
}
}
// Case 3.
for zi := range l.zone {
if !l.zone[zi].isDST {
return zi
}
}
// Case 4.
return 0
}
// firstZoneUsed reports whether the first zone is used by some
// transition.
func (l *Location) firstZoneUsed() bool {
for _, tx := range l.tx {
if tx.index == 0 {
return true
}
}
return false
}
// tzset takes a timezone string like the one found in the TZ environment
// variable, the end of the last time zone transition expressed as seconds
// since January 1, 1970 00:00:00 UTC, and a time expressed the same way.
// We call this a tzset string since in C the function tzset reads TZ.
// The return values are as for lookup, plus ok which reports whether the
// parse succeeded.
func tzset(s string, initEnd, sec int64) (name string, offset int, start, end int64, isDST, ok bool) {
var (
stdName, dstName string
stdOffset, dstOffset int
)
stdName, s, ok = tzsetName(s)
if ok {
stdOffset, s, ok = tzsetOffset(s)
}
if !ok {
return "", 0, 0, 0, false, false
}
// The numbers in the tzset string are added to local time to get UTC,
// but our offsets are added to UTC to get local time,
// so we negate the number we see here.
stdOffset = -stdOffset
if len(s) == 0 || s[0] == ',' {
// No daylight savings time.
return stdName, stdOffset, initEnd, omega, false, true
}
dstName, s, ok = tzsetName(s)
if ok {
if len(s) == 0 || s[0] == ',' {
dstOffset = stdOffset + secondsPerHour
} else {
dstOffset, s, ok = tzsetOffset(s)
dstOffset = -dstOffset // as with stdOffset, above
}
}
if !ok {
return "", 0, 0, 0, false, false
}
if len(s) == 0 {
// Default DST rules per tzcode.
s = ",M3.2.0,M11.1.0"
}
// The TZ definition does not mention ';' here but tzcode accepts it.
if s[0] != ',' && s[0] != ';' {
return "", 0, 0, 0, false, false
}
s = s[1:]
var startRule, endRule rule
startRule, s, ok = tzsetRule(s)
if !ok || len(s) == 0 || s[0] != ',' {
return "", 0, 0, 0, false, false
}
s = s[1:]
endRule, s, ok = tzsetRule(s)
if !ok || len(s) > 0 {
return "", 0, 0, 0, false, false
}
year, _, _, yday := absDate(uint64(sec+unixToInternal+internalToAbsolute), false)
ysec := int64(yday*secondsPerDay) + sec%secondsPerDay
// Compute start of year in seconds since Unix epoch.
d := daysSinceEpoch(year)
abs := int64(d * secondsPerDay)
abs += absoluteToInternal + internalToUnix
startSec := int64(tzruleTime(year, startRule, stdOffset))
endSec := int64(tzruleTime(year, endRule, dstOffset))
dstIsDST, stdIsDST := true, false
// Note: if the DST end time precedes the DST start time, the roles of
// "DST" and "STD" are swapped below while the labels are retained.
// This happens for locations in the southern hemisphere, so the variable
// names here are a little inconsistent with what they hold.
if endSec < startSec {
startSec, endSec = endSec, startSec
stdName, dstName = dstName, stdName
stdOffset, dstOffset = dstOffset, stdOffset
stdIsDST, dstIsDST = dstIsDST, stdIsDST
}
// The start and end values that we return are accurate
// close to a daylight savings transition, but are otherwise
// just the start and end of the year. That suffices for
// the only caller that cares, which is Date.
if ysec < startSec {
return stdName, stdOffset, abs, startSec + abs, stdIsDST, true
} else if ysec >= endSec {
return stdName, stdOffset, endSec + abs, abs + 365*secondsPerDay, stdIsDST, true
} else {
return dstName, dstOffset, startSec + abs, endSec + abs, dstIsDST, true
}
}
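// tzsetSketch is an illustrative sketch, not part of the original source:
// it shows how tzset interprets a typical POSIX TZ string. For
// "EST5EDT,M3.2.0,M11.1.0" standard time is UTC-5 ("EST"), daylight time is
// UTC-4 ("EDT"), DST starts the second Sunday in March and ends the first
// Sunday in November, both at the default 02:00 local time. initEnd is only
// consulted when the string has no DST rule, so omega is a harmless value.
func tzsetSketch(sec int64) (string, int, bool) {
	name, offset, _, _, isDST, ok := tzset("EST5EDT,M3.2.0,M11.1.0", omega, sec)
	if !ok {
		return "", 0, false
	}
	// For a sec falling in July, name == "EDT",
	// offset == -4*secondsPerHour, and isDST == true.
	return name, offset, isDST
}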
// tzsetName returns the timezone name at the start of the tzset string s,
// and the remainder of s, and reports whether the parsing is OK.
func tzsetName(s string) (string, string, bool) {
if len(s) == 0 {
return "", "", false
}
if s[0] != '<' {
for i, r := range s {
switch r {
case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ',', '-', '+':
if i < 3 {
return "", "", false
}
return s[:i], s[i:], true
}
}
if len(s) < 3 {
return "", "", false
}
return s, "", true
} else {
for i, r := range s {
if r == '>' {
return s[1:i], s[i+1:], true
}
}
return "", "", false
}
}
// tzsetOffset returns the timezone offset at the start of the tzset string s,
// and the remainder of s, and reports whether the parsing is OK.
// The timezone offset is returned as a number of seconds.
func tzsetOffset(s string) (offset int, rest string, ok bool) {
if len(s) == 0 {
return 0, "", false
}
neg := false
if s[0] == '+' {
s = s[1:]
} else if s[0] == '-' {
s = s[1:]
neg = true
}
// The tzdata code permits values up to 24 * 7 here,
// although POSIX does not.
var hours int
hours, s, ok = tzsetNum(s, 0, 24*7)
if !ok {
return 0, "", false
}
off := hours * secondsPerHour
if len(s) == 0 || s[0] != ':' {
if neg {
off = -off
}
return off, s, true
}
var mins int
mins, s, ok = tzsetNum(s[1:], 0, 59)
if !ok {
return 0, "", false
}
off += mins * secondsPerMinute
if len(s) == 0 || s[0] != ':' {
if neg {
off = -off
}
return off, s, true
}
var secs int
secs, s, ok = tzsetNum(s[1:], 0, 59)
if !ok {
return 0, "", false
}
off += secs
if neg {
off = -off
}
return off, s, true
}
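// tzsetOffsetSketch is an illustrative sketch, not part of the original
// source. Note the sign convention: tzsetOffset returns the raw parsed
// seconds; callers such as tzset negate the value to get an east-of-UTC
// offset.
func tzsetOffsetSketch() bool {
	off, rest, ok := tzsetOffset("-05:30")
	// off == -(5*secondsPerHour + 30*secondsPerMinute) == -19800.
	return ok && rest == "" && off == -19800
}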
// ruleKind is the kinds of rules that can be seen in a tzset string.
type ruleKind int
const (
ruleJulian ruleKind = iota
ruleDOY
ruleMonthWeekDay
)
// rule is a rule read from a tzset string.
type rule struct {
kind ruleKind
day int
week int
mon int
time int // transition time
}
// tzsetRule parses a rule from a tzset string.
// It returns the rule, and the remainder of the string, and reports success.
func tzsetRule(s string) (rule, string, bool) {
var r rule
if len(s) == 0 {
return rule{}, "", false
}
ok := false
if s[0] == 'J' {
var jday int
jday, s, ok = tzsetNum(s[1:], 1, 365)
if !ok {
return rule{}, "", false
}
r.kind = ruleJulian
r.day = jday
} else if s[0] == 'M' {
var mon int
mon, s, ok = tzsetNum(s[1:], 1, 12)
if !ok || len(s) == 0 || s[0] != '.' {
return rule{}, "", false
}
var week int
week, s, ok = tzsetNum(s[1:], 1, 5)
if !ok || len(s) == 0 || s[0] != '.' {
return rule{}, "", false
}
var day int
day, s, ok = tzsetNum(s[1:], 0, 6)
if !ok {
return rule{}, "", false
}
r.kind = ruleMonthWeekDay
r.day = day
r.week = week
r.mon = mon
} else {
var day int
day, s, ok = tzsetNum(s, 0, 365)
if !ok {
return rule{}, "", false
}
r.kind = ruleDOY
r.day = day
}
if len(s) == 0 || s[0] != '/' {
r.time = 2 * secondsPerHour // 2am is the default
return r, s, true
}
offset, s, ok := tzsetOffset(s[1:])
if !ok {
return rule{}, "", false
}
r.time = offset
return r, s, true
}
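// tzsetRuleSketch is an illustrative sketch, not part of the original
// source: "M3.2.0/2" parses as "month 3, week 2, weekday 0 (Sunday), at
// 02:00", i.e. the second Sunday in March.
func tzsetRuleSketch() bool {
	r, rest, ok := tzsetRule("M3.2.0/2")
	want := rule{kind: ruleMonthWeekDay, day: 0, week: 2, mon: 3, time: 2 * secondsPerHour}
	return ok && rest == "" && r == want
}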
// tzsetNum parses a number from a tzset string.
// It returns the number, and the remainder of the string, and reports success.
// The number must be between min and max.
func tzsetNum(s string, min, max int) (num int, rest string, ok bool) {
if len(s) == 0 {
return 0, "", false
}
num = 0
for i, r := range s {
if r < '0' || r > '9' {
if i == 0 || num < min {
return 0, "", false
}
return num, s[i:], true
}
num *= 10
num += int(r) - '0'
if num > max {
return 0, "", false
}
}
if num < min {
return 0, "", false
}
return num, "", true
}
// tzruleTime takes a year, a rule, and a timezone offset,
// and returns the number of seconds since the start of the year
// that the rule takes effect.
func tzruleTime(year int, r rule, off int) int {
var s int
switch r.kind {
case ruleJulian:
s = (r.day - 1) * secondsPerDay
if isLeap(year) && r.day >= 60 {
s += secondsPerDay
}
case ruleDOY:
s = r.day * secondsPerDay
case ruleMonthWeekDay:
// Zeller's Congruence.
m1 := (r.mon+9)%12 + 1
yy0 := year
if r.mon <= 2 {
yy0--
}
yy1 := yy0 / 100
yy2 := yy0 % 100
dow := ((26*m1-2)/10 + 1 + yy2 + yy2/4 + yy1/4 - 2*yy1) % 7
if dow < 0 {
dow += 7
}
// Now dow is the day-of-week of the first day of r.mon.
// Get the day-of-month of the first "dow" day.
d := r.day - dow
if d < 0 {
d += 7
}
for i := 1; i < r.week; i++ {
if d+7 >= daysIn(Month(r.mon), year) {
break
}
d += 7
}
d += int(daysBefore[r.mon-1])
if isLeap(year) && r.mon > 2 {
d++
}
s = d * secondsPerDay
}
return s + r.time - off
}
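// tzruleTimeSketch is an illustrative sketch, not part of the original
// source. For 2021, the rule M3.2.0/2 (second Sunday in March at 02:00)
// with a standard offset of UTC-5 lands on March 14, the zero-based day 72
// of the year, so the transition instant is
// 72*secondsPerDay + 2*secondsPerHour - (-5*secondsPerHour) seconds into
// the year, expressed in UTC.
func tzruleTimeSketch() bool {
	r := rule{kind: ruleMonthWeekDay, day: 0, week: 2, mon: 3, time: 2 * secondsPerHour}
	return tzruleTime(2021, r, -5*secondsPerHour) == 72*secondsPerDay+7*secondsPerHour
}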
// lookupName returns information about the time zone with
// the given name (such as "EST") at the given pseudo-Unix time
// (what the given time of day would be in UTC).
func (l *Location) lookupName(name string, unix int64) (offset int, ok bool) {
l = l.get()
// First try for a zone with the right name that was actually
// in effect at the given time. (In Sydney, Australia, both standard
// and daylight-savings time are abbreviated "EST". Using the
// offset helps us pick the right one for the given time.
// It's not perfect: during the backward transition we might pick
// either one.)
for i := range l.zone {
zone := &l.zone[i]
if zone.name == name {
nam, offset, _, _, _ := l.lookup(unix - int64(zone.offset))
if nam == zone.name {
return offset, true
}
}
}
// Otherwise fall back to an ordinary name match.
for i := range l.zone {
zone := &l.zone[i]
if zone.name == name {
return zone.offset, true
}
}
// Otherwise, give up.
return
}
// NOTE(rsc): Eventually we will need to accept the POSIX TZ environment
// syntax too, but I don't feel like implementing it today.
var errLocation = errors.New("time: invalid location name")
var zoneinfo *string
var zoneinfoOnce sync.Once
// LoadLocation returns the Location with the given name.
//
// If the name is "" or "UTC", LoadLocation returns UTC.
// If the name is "Local", LoadLocation returns Local.
//
// Otherwise, the name is taken to be a location name corresponding to a file
// in the IANA Time Zone database, such as "America/New_York".
//
// LoadLocation looks for the IANA Time Zone database in the following
// locations in order:
//
// - the directory or uncompressed zip file named by the ZONEINFO environment variable
// - on a Unix system, the system standard installation location
// - $GOROOT/lib/time/zoneinfo.zip
// - the time/tzdata package, if it was imported
func LoadLocation(name string) (*Location, error) {
if name == "" || name == "UTC" {
return UTC, nil
}
if name == "Local" {
return Local, nil
}
if containsDotDot(name) || name[0] == '/' || name[0] == '\\' {
// No valid IANA Time Zone name contains a single dot,
// much less dot dot. Likewise, none begin with a slash.
return nil, errLocation
}
zoneinfoOnce.Do(func() {
env, _ := syscall.Getenv("ZONEINFO")
zoneinfo = &env
})
var firstErr error
if *zoneinfo != "" {
if zoneData, err := loadTzinfoFromDirOrZip(*zoneinfo, name); err == nil {
if z, err := LoadLocationFromTZData(name, zoneData); err == nil {
return z, nil
}
firstErr = err
} else if err != syscall.ENOENT {
firstErr = err
}
}
if z, err := loadLocation(name, platformZoneSources); err == nil {
return z, nil
} else if firstErr == nil {
firstErr = err
}
return nil, firstErr
}
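// A typical use, as a sketch from a client of this package (error handling
// elided; the location name is just an example):
//
//	loc, err := time.LoadLocation("America/New_York")
//	if err != nil {
//		log.Fatal(err)
//	}
//	t := time.Date(2021, time.July, 4, 12, 0, 0, 0, loc)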
// containsDotDot reports whether s contains "..".
func containsDotDot(s string) bool {
if len(s) < 2 {
return false
}
for i := 0; i < len(s)-1; i++ {
if s[i] == '.' && s[i+1] == '.' {
return true
}
}
return false
}
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !ios && !android
package time
func gorootZoneSource(goroot string) (string, bool) {
if goroot == "" {
return "", false
}
return goroot + "/lib/time/zoneinfo.zip", true
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Parse "zoneinfo" time zone file.
// This is a fairly standard file format used on OS X, Linux, BSD, Sun, and others.
// See tzfile(5), https://en.wikipedia.org/wiki/Zoneinfo,
// and ftp://munnari.oz.au/pub/oldtz/
package time
import (
"errors"
"runtime"
"syscall"
)
// registerLoadFromEmbeddedTZData is called by the time/tzdata package,
// if it is imported.
func registerLoadFromEmbeddedTZData(f func(string) (string, error)) {
loadFromEmbeddedTZData = f
}
// loadFromEmbeddedTZData is used to load a specific tzdata file
// from tzdata information embedded in the binary itself.
// This is set when the time/tzdata package is imported,
// via registerLoadFromEmbeddedTZData.
var loadFromEmbeddedTZData func(zipname string) (string, error)
// maxFileSize is the max permitted size of files read by readFile.
// As reference, the zoneinfo.zip distributed by Go is ~350 KB,
// so 10MB is overkill.
const maxFileSize = 10 << 20
type fileSizeError string
func (f fileSizeError) Error() string {
return "time: file " + string(f) + " is too large"
}
// Copies of io.Seek* constants to avoid importing "io":
const (
seekStart = 0
seekCurrent = 1
seekEnd = 2
)
// Simple I/O interface to binary blob of data.
type dataIO struct {
p []byte
error bool
}
func (d *dataIO) read(n int) []byte {
if len(d.p) < n {
d.p = nil
d.error = true
return nil
}
p := d.p[0:n]
d.p = d.p[n:]
return p
}
func (d *dataIO) big4() (n uint32, ok bool) {
p := d.read(4)
if len(p) < 4 {
d.error = true
return 0, false
}
return uint32(p[3]) | uint32(p[2])<<8 | uint32(p[1])<<16 | uint32(p[0])<<24, true
}
func (d *dataIO) big8() (n uint64, ok bool) {
n1, ok1 := d.big4()
n2, ok2 := d.big4()
if !ok1 || !ok2 {
d.error = true
return 0, false
}
return (uint64(n1) << 32) | uint64(n2), true
}
func (d *dataIO) byte() (n byte, ok bool) {
p := d.read(1)
if len(p) < 1 {
d.error = true
return 0, false
}
return p[0], true
}
// rest returns the remainder of the data in the buffer.
func (d *dataIO) rest() []byte {
r := d.p
d.p = nil
return r
}
// Make a string by stopping at the first NUL
func byteString(p []byte) string {
for i := 0; i < len(p); i++ {
if p[i] == 0 {
return string(p[0:i])
}
}
return string(p)
}
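// byteStringSketch is an illustrative sketch, not part of the original
// source: the abbrev block of a TZif file is a sequence of NUL-terminated
// names, so slicing at a name index and stopping at the first NUL extracts
// one abbreviation.
func byteStringSketch() bool {
	abbrev := []byte("LMT\x00EST\x00EDT\x00")
	return byteString(abbrev[4:]) == "EST"
}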
var errBadData = errors.New("malformed time zone information")
// LoadLocationFromTZData returns a Location with the given name
// initialized from the IANA Time Zone database-formatted data.
// The data should be in the format of a standard IANA time zone file
// (for example, the content of /etc/localtime on Unix systems).
func LoadLocationFromTZData(name string, data []byte) (*Location, error) {
d := dataIO{data, false}
// 4-byte magic "TZif"
if magic := d.read(4); string(magic) != "TZif" {
return nil, errBadData
}
// 1-byte version, then 15 bytes of padding
var version int
var p []byte
if p = d.read(16); len(p) != 16 {
return nil, errBadData
} else {
switch p[0] {
case 0:
version = 1
case '2':
version = 2
case '3':
version = 3
default:
return nil, errBadData
}
}
// six big-endian 32-bit integers:
// number of UTC/local indicators
// number of standard/wall indicators
// number of leap seconds
// number of transition times
// number of local time zones
// number of characters of time zone abbrev strings
const (
NUTCLocal = iota
NStdWall
NLeap
NTime
NZone
NChar
)
var n [6]int
for i := 0; i < 6; i++ {
nn, ok := d.big4()
if !ok {
return nil, errBadData
}
if uint32(int(nn)) != nn {
return nil, errBadData
}
n[i] = int(nn)
}
// If we have version 2 or 3, then the data is first written out
// in a 32-bit format, then written out again in a 64-bit format.
// Skip the 32-bit format and read the 64-bit one, as it can
// describe a broader range of dates.
is64 := false
if version > 1 {
// Skip the 32-bit data.
skip := n[NTime]*4 +
n[NTime] +
n[NZone]*6 +
n[NChar] +
n[NLeap]*8 +
n[NStdWall] +
n[NUTCLocal]
// Skip the version 2 header that we just read.
skip += 4 + 16
d.read(skip)
is64 = true
// Read the counts again, they can differ.
for i := 0; i < 6; i++ {
nn, ok := d.big4()
if !ok {
return nil, errBadData
}
if uint32(int(nn)) != nn {
return nil, errBadData
}
n[i] = int(nn)
}
}
size := 4
if is64 {
size = 8
}
// Transition times.
txtimes := dataIO{d.read(n[NTime] * size), false}
// Time zone indices for transition times.
txzones := d.read(n[NTime])
// Zone info structures
zonedata := dataIO{d.read(n[NZone] * 6), false}
// Time zone abbreviations.
abbrev := d.read(n[NChar])
// Leap-second time pairs
d.read(n[NLeap] * (size + 4))
// Whether tx times associated with local time types
// are specified as standard time or wall time.
isstd := d.read(n[NStdWall])
// Whether tx times associated with local time types
// are specified as UTC or local time.
isutc := d.read(n[NUTCLocal])
if d.error { // ran out of data
return nil, errBadData
}
var extend string
rest := d.rest()
if len(rest) > 2 && rest[0] == '\n' && rest[len(rest)-1] == '\n' {
extend = string(rest[1 : len(rest)-1])
}
// Now we can build up a useful data structure.
// First the zone information.
// utcoff[4] isdst[1] nameindex[1]
nzone := n[NZone]
if nzone == 0 {
// Reject tzdata files with no zones. There's nothing useful in them.
// This also avoids a panic later when we add and then use a fake transition (golang.org/issue/29437).
return nil, errBadData
}
zones := make([]zone, nzone)
for i := range zones {
var ok bool
var n uint32
if n, ok = zonedata.big4(); !ok {
return nil, errBadData
}
if uint32(int(n)) != n {
return nil, errBadData
}
zones[i].offset = int(int32(n))
var b byte
if b, ok = zonedata.byte(); !ok {
return nil, errBadData
}
zones[i].isDST = b != 0
if b, ok = zonedata.byte(); !ok || int(b) >= len(abbrev) {
return nil, errBadData
}
zones[i].name = byteString(abbrev[b:])
if runtime.GOOS == "aix" && len(name) > 8 && (name[:8] == "Etc/GMT+" || name[:8] == "Etc/GMT-") {
// There is a bug with AIX 7.2 TL 0 with files in Etc,
// GMT+1 will return GMT-1 instead of GMT+1 or -01.
if name != "Etc/GMT+0" {
// GMT+0 is OK
zones[i].name = name[4:]
}
}
}
// Now the transition time info.
tx := make([]zoneTrans, n[NTime])
for i := range tx {
var n int64
if !is64 {
if n4, ok := txtimes.big4(); !ok {
return nil, errBadData
} else {
n = int64(int32(n4))
}
} else {
if n8, ok := txtimes.big8(); !ok {
return nil, errBadData
} else {
n = int64(n8)
}
}
tx[i].when = n
if int(txzones[i]) >= len(zones) {
return nil, errBadData
}
tx[i].index = txzones[i]
if i < len(isstd) {
tx[i].isstd = isstd[i] != 0
}
if i < len(isutc) {
tx[i].isutc = isutc[i] != 0
}
}
if len(tx) == 0 {
// Build fake transition to cover all time.
// This happens in fixed locations like "Etc/GMT0".
tx = append(tx, zoneTrans{when: alpha, index: 0})
}
// Committed to succeed.
l := &Location{zone: zones, tx: tx, name: name, extend: extend}
// Fill in the cache with information about right now,
// since that will be the most common lookup.
sec, _, _ := now()
for i := range tx {
if tx[i].when <= sec && (i+1 == len(tx) || sec < tx[i+1].when) {
l.cacheStart = tx[i].when
l.cacheEnd = omega
l.cacheZone = &l.zone[tx[i].index]
if i+1 < len(tx) {
l.cacheEnd = tx[i+1].when
} else if l.extend != "" {
// If we're at the end of the known zone transitions,
// try the extend string.
if name, offset, estart, eend, isDST, ok := tzset(l.extend, l.cacheEnd, sec); ok {
l.cacheStart = estart
l.cacheEnd = eend
// Find the zone that is returned by tzset to avoid allocation if possible.
if zoneIdx := findZone(l.zone, name, offset, isDST); zoneIdx != -1 {
l.cacheZone = &l.zone[zoneIdx]
} else {
l.cacheZone = &zone{
name: name,
offset: offset,
isDST: isDST,
}
}
}
}
break
}
}
return l, nil
}
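// A usage sketch, not part of the original source: callers that carry their
// own copy of a time zone file can build a Location directly. The file path
// below is hypothetical.
//
//	data, err := os.ReadFile("testdata/America_New_York") // hypothetical path
//	if err != nil {
//		log.Fatal(err)
//	}
//	loc, err := time.LoadLocationFromTZData("America/New_York", data)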
func findZone(zones []zone, name string, offset int, isDST bool) int {
for i, z := range zones {
if z.name == name && z.offset == offset && z.isDST == isDST {
return i
}
}
return -1
}
// loadTzinfoFromDirOrZip returns the contents of the file with the given name
// in dir. dir can either be an uncompressed zip file, or a directory.
func loadTzinfoFromDirOrZip(dir, name string) ([]byte, error) {
if len(dir) > 4 && dir[len(dir)-4:] == ".zip" {
return loadTzinfoFromZip(dir, name)
}
if dir != "" {
name = dir + "/" + name
}
return readFile(name)
}
// There are 500+ zoneinfo files. Rather than distribute them all
// individually, we ship them in an uncompressed zip file.
// Used this way, the zip file format serves as a commonly readable
// container for the individual small files. We choose zip over tar
// because zip files have a contiguous table of contents, making
// individual file lookups faster, and because the per-file overhead
// in a zip file is considerably less than tar's 512 bytes.
// get4 returns the little-endian 32-bit value in b.
func get4(b []byte) int {
if len(b) < 4 {
return 0
}
return int(b[0]) | int(b[1])<<8 | int(b[2])<<16 | int(b[3])<<24
}
// get2 returns the little-endian 16-bit value in b.
func get2(b []byte) int {
if len(b) < 2 {
return 0
}
return int(b[0]) | int(b[1])<<8
}
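// getSketch is an illustrative sketch, not part of the original source: zip
// metadata is little-endian, so the end-of-central-directory magic
// 0x06054b50 appears on disk as the bytes "PK\x05\x06".
func getSketch() bool {
	return get4([]byte{'P', 'K', 0x05, 0x06}) == 0x06054b50
}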
// loadTzinfoFromZip returns the contents of the file with the given name
// in the given uncompressed zip file.
func loadTzinfoFromZip(zipfile, name string) ([]byte, error) {
fd, err := open(zipfile)
if err != nil {
return nil, err
}
defer closefd(fd)
const (
zecheader = 0x06054b50
zcheader = 0x02014b50
ztailsize = 22
zheadersize = 30
zheader = 0x04034b50
)
buf := make([]byte, ztailsize)
if err := preadn(fd, buf, -ztailsize); err != nil || get4(buf) != zecheader {
return nil, errors.New("corrupt zip file " + zipfile)
}
n := get2(buf[10:])
size := get4(buf[12:])
off := get4(buf[16:])
buf = make([]byte, size)
if err := preadn(fd, buf, off); err != nil {
return nil, errors.New("corrupt zip file " + zipfile)
}
for i := 0; i < n; i++ {
// zip entry layout:
// 0 magic[4]
// 4 madevers[1]
// 5 madeos[1]
// 6 extvers[1]
// 7 extos[1]
// 8 flags[2]
// 10 meth[2]
// 12 modtime[2]
// 14 moddate[2]
// 16 crc[4]
// 20 csize[4]
// 24 uncsize[4]
// 28 namelen[2]
// 30 xlen[2]
// 32 fclen[2]
// 34 disknum[2]
// 36 iattr[2]
// 38 eattr[4]
// 42 off[4]
// 46 name[namelen]
// 46+namelen+xlen+fclen - next header
//
if get4(buf) != zcheader {
break
}
meth := get2(buf[10:])
size := get4(buf[24:])
namelen := get2(buf[28:])
xlen := get2(buf[30:])
fclen := get2(buf[32:])
off := get4(buf[42:])
zname := buf[46 : 46+namelen]
buf = buf[46+namelen+xlen+fclen:]
if string(zname) != name {
continue
}
if meth != 0 {
return nil, errors.New("unsupported compression for " + name + " in " + zipfile)
}
// zip per-file header layout:
// 0 magic[4]
// 4 extvers[1]
// 5 extos[1]
// 6 flags[2]
// 8 meth[2]
// 10 modtime[2]
// 12 moddate[2]
// 14 crc[4]
// 18 csize[4]
// 22 uncsize[4]
// 26 namelen[2]
// 28 xlen[2]
// 30 name[namelen]
// 30+namelen+xlen - file data
//
buf = make([]byte, zheadersize+namelen)
if err := preadn(fd, buf, off); err != nil ||
get4(buf) != zheader ||
get2(buf[8:]) != meth ||
get2(buf[26:]) != namelen ||
string(buf[30:30+namelen]) != name {
return nil, errors.New("corrupt zip file " + zipfile)
}
xlen = get2(buf[28:])
buf = make([]byte, size)
if err := preadn(fd, buf, off+30+namelen+xlen); err != nil {
return nil, errors.New("corrupt zip file " + zipfile)
}
return buf, nil
}
return nil, syscall.ENOENT
}
// loadTzinfoFromTzdata returns the time zone information of the time zone
// with the given name, from a tzdata database file as they are typically
// found on Android.
var loadTzinfoFromTzdata func(file, name string) ([]byte, error)
// loadTzinfo returns the time zone information of the time zone
// with the given name, from a given source. A source may be a
// timezone database directory, a tzdata database file, or an uncompressed
// zip file containing the contents of such a directory.
func loadTzinfo(name string, source string) ([]byte, error) {
if len(source) >= 6 && source[len(source)-6:] == "tzdata" {
return loadTzinfoFromTzdata(source, name)
}
return loadTzinfoFromDirOrZip(source, name)
}
// loadLocation returns the Location with the given name from one of
// the specified sources. See loadTzinfo for a list of supported sources.
// The first timezone data matching the given name that is successfully loaded
// and parsed is returned as a Location.
func loadLocation(name string, sources []string) (z *Location, firstErr error) {
for _, source := range sources {
zoneData, err := loadTzinfo(name, source)
if err == nil {
if z, err = LoadLocationFromTZData(name, zoneData); err == nil {
return z, nil
}
}
if firstErr == nil && err != syscall.ENOENT {
firstErr = err
}
}
if loadFromEmbeddedTZData != nil {
zoneData, err := loadFromEmbeddedTZData(name)
if err == nil {
if z, err = LoadLocationFromTZData(name, []byte(zoneData)); err == nil {
return z, nil
}
}
if firstErr == nil && err != syscall.ENOENT {
firstErr = err
}
}
if source, ok := gorootZoneSource(runtime.GOROOT()); ok {
zoneData, err := loadTzinfo(name, source)
if err == nil {
if z, err = LoadLocationFromTZData(name, zoneData); err == nil {
return z, nil
}
}
if firstErr == nil && err != syscall.ENOENT {
firstErr = err
}
}
if firstErr != nil {
return nil, firstErr
}
return nil, errors.New("unknown time zone " + name)
}
// readFile reads and returns the content of the named file.
// It is a trivial implementation of os.ReadFile, reimplemented
// here to avoid depending on io/ioutil or os.
// It returns an error if the file exceeds maxFileSize bytes.
func readFile(name string) ([]byte, error) {
f, err := open(name)
if err != nil {
return nil, err
}
defer closefd(f)
var (
buf [4096]byte
ret []byte
n int
)
for {
n, err = read(f, buf[:])
if n > 0 {
ret = append(ret, buf[:n]...)
}
if n == 0 || err != nil {
break
}
if len(ret) > maxFileSize {
return nil, fileSizeError(name)
}
}
return ret, err
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix && !ios && !android
// Parse "zoneinfo" time zone file.
// This is a fairly standard file format used on OS X, Linux, BSD, Sun, and others.
// See tzfile(5), https://en.wikipedia.org/wiki/Zoneinfo,
// and ftp://munnari.oz.au/pub/oldtz/
package time
import (
"syscall"
)
// Many systems use /usr/share/zoneinfo, Solaris 2 has
// /usr/share/lib/zoneinfo, IRIX 6 has /usr/lib/locale/TZ,
// NixOS has /etc/zoneinfo.
var platformZoneSources = []string{
"/usr/share/zoneinfo/",
"/usr/share/lib/zoneinfo/",
"/usr/lib/locale/TZ/",
"/etc/zoneinfo",
}
func initLocal() {
// consult $TZ to find the time zone to use.
// no $TZ means use the system default /etc/localtime.
// $TZ="" means use UTC.
// $TZ="foo" or $TZ=":foo" if foo is an absolute path, then the file pointed
// by foo will be used to initialize timezone; otherwise, file
// /usr/share/zoneinfo/foo will be used.
tz, ok := syscall.Getenv("TZ")
switch {
case !ok:
z, err := loadLocation("localtime", []string{"/etc"})
if err == nil {
localLoc = *z
localLoc.name = "Local"
return
}
case tz != "":
if tz[0] == ':' {
tz = tz[1:]
}
if tz != "" && tz[0] == '/' {
if z, err := loadLocation(tz, []string{""}); err == nil {
localLoc = *z
if tz == "/etc/localtime" {
localLoc.name = "Local"
} else {
localLoc.name = tz
}
return
}
} else if tz != "" && tz != "UTC" {
if z, err := loadLocation(tz, platformZoneSources); err == nil {
localLoc = *z
return
}
}
}
// Fall back to UTC.
localLoc.name = "UTC"
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package unicode
// IsDigit reports whether the rune is a decimal digit.
func IsDigit(r rune) bool {
if r <= MaxLatin1 {
return '0' <= r && r <= '9'
}
return isExcludingLatin(Digit, r)
}
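// isDigitSketch is an illustrative sketch, not part of the original source:
// IsDigit covers decimal digits outside Latin-1 as well, such as U+0667
// ARABIC-INDIC DIGIT SEVEN.
func isDigitSketch() bool {
	return IsDigit('7') && IsDigit('\u0667') && !IsDigit('a')
}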
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package unicode
// Bit masks for each code point under U+0100, for fast lookup.
const (
pC = 1 << iota // a control character.
pP // a punctuation character.
pN // a numeral.
pS // a symbolic character.
pZ // a spacing character.
pLu // an upper-case letter.
pLl // a lower-case letter.
pp // a printable character according to Go's definition.
pg = pp | pZ // a graphical character according to the Unicode definition.
pLo = pLl | pLu // a letter that is neither upper nor lower case.
pLmask = pLo
)
// GraphicRanges defines the set of graphic characters according to Unicode.
var GraphicRanges = []*RangeTable{
L, M, N, P, S, Zs,
}
// PrintRanges defines the set of printable characters according to Go.
// ASCII space, U+0020, is handled separately.
var PrintRanges = []*RangeTable{
L, M, N, P, S,
}
// IsGraphic reports whether the rune is defined as a Graphic by Unicode.
// Such characters include letters, marks, numbers, punctuation, symbols, and
// spaces, from categories L, M, N, P, S, Zs.
func IsGraphic(r rune) bool {
// We convert to uint32 to avoid the extra test for negative,
// and in the index we convert to uint8 to avoid the range check.
if uint32(r) <= MaxLatin1 {
return properties[uint8(r)]&pg != 0
}
return In(r, GraphicRanges...)
}
// IsPrint reports whether the rune is defined as printable by Go. Such
// characters include letters, marks, numbers, punctuation, symbols, and the
// ASCII space character, from categories L, M, N, P, S. This categorization
// is the same as IsGraphic except that the only spacing character is ASCII
// space, U+0020.
func IsPrint(r rune) bool {
if uint32(r) <= MaxLatin1 {
return properties[uint8(r)]&pp != 0
}
return In(r, PrintRanges...)
}
// IsOneOf reports whether the rune is a member of one of the ranges.
// The function "In" provides a nicer signature and should be used in preference to IsOneOf.
func IsOneOf(ranges []*RangeTable, r rune) bool {
for _, inside := range ranges {
if Is(inside, r) {
return true
}
}
return false
}
// In reports whether the rune is a member of one of the ranges.
func In(r rune, ranges ...*RangeTable) bool {
for _, inside := range ranges {
if Is(inside, r) {
return true
}
}
return false
}
// IsControl reports whether the rune is a control character.
// The C (Other) Unicode category includes more code points
// such as surrogates; use Is(C, r) to test for them.
func IsControl(r rune) bool {
if uint32(r) <= MaxLatin1 {
return properties[uint8(r)]&pC != 0
}
// All control characters are < MaxLatin1.
return false
}
// IsLetter reports whether the rune is a letter (category L).
func IsLetter(r rune) bool {
if uint32(r) <= MaxLatin1 {
return properties[uint8(r)]&(pLmask) != 0
}
return isExcludingLatin(Letter, r)
}
// IsMark reports whether the rune is a mark character (category M).
func IsMark(r rune) bool {
// There are no mark characters in Latin-1.
return isExcludingLatin(Mark, r)
}
// IsNumber reports whether the rune is a number (category N).
func IsNumber(r rune) bool {
if uint32(r) <= MaxLatin1 {
return properties[uint8(r)]&pN != 0
}
return isExcludingLatin(Number, r)
}
// IsPunct reports whether the rune is a Unicode punctuation character
// (category P).
func IsPunct(r rune) bool {
if uint32(r) <= MaxLatin1 {
return properties[uint8(r)]&pP != 0
}
return Is(Punct, r)
}
// IsSpace reports whether the rune is a space character as defined
// by Unicode's White Space property; in the Latin-1 space
// this is
//
// '\t', '\n', '\v', '\f', '\r', ' ', U+0085 (NEL), U+00A0 (NBSP).
//
// Other definitions of spacing characters are set by category
// Z and property Pattern_White_Space.
func IsSpace(r rune) bool {
// This property isn't the same as Z; special-case it.
if uint32(r) <= MaxLatin1 {
switch r {
case '\t', '\n', '\v', '\f', '\r', ' ', 0x85, 0xA0:
return true
}
return false
}
return isExcludingLatin(White_Space, r)
}
// IsSymbol reports whether the rune is a symbolic character.
func IsSymbol(r rune) bool {
if uint32(r) <= MaxLatin1 {
return properties[uint8(r)]&pS != 0
}
return isExcludingLatin(Symbol, r)
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package unicode provides data and functions to test some properties of
// Unicode code points.
package unicode
const (
MaxRune = '\U0010FFFF' // Maximum valid Unicode code point.
ReplacementChar = '\uFFFD' // Represents invalid code points.
MaxASCII = '\u007F' // maximum ASCII value.
MaxLatin1 = '\u00FF' // maximum Latin-1 value.
)
// RangeTable defines a set of Unicode code points by listing the ranges of
// code points within the set. The ranges are listed in two slices
// to save space: a slice of 16-bit ranges and a slice of 32-bit ranges.
// The two slices must be in sorted order and non-overlapping.
// Also, R32 should contain only values >= 0x10000 (1<<16).
type RangeTable struct {
R16 []Range16
R32 []Range32
LatinOffset int // number of entries in R16 with Hi <= MaxLatin1
}
// Range16 represents a range of 16-bit Unicode code points. The range runs from Lo to Hi
// inclusive and has the specified stride.
type Range16 struct {
Lo uint16
Hi uint16
Stride uint16
}
// Range32 represents a range of Unicode code points and is used when one or
// more of the values will not fit in 16 bits. The range runs from Lo to Hi
// inclusive and has the specified stride. Lo and Hi must always be >= 1<<16.
type Range32 struct {
Lo uint32
Hi uint32
Stride uint32
}
// CaseRange represents a range of Unicode code points for simple (one
// code point to one code point) case conversion.
// The range runs from Lo to Hi inclusive, with a fixed stride of 1. Deltas
// are the number to add to the code point to reach the code point for a
// different case for that character. They may be negative. If zero, it
// means the character is in the corresponding case. There is a special
// case representing sequences of alternating corresponding Upper and Lower
// pairs. It appears with a fixed Delta of
//
// {UpperLower, UpperLower, UpperLower}
//
// The constant UpperLower has an otherwise impossible delta value.
type CaseRange struct {
Lo uint32
Hi uint32
Delta d
}
// SpecialCase represents language-specific case mappings such as Turkish.
// Methods of SpecialCase customize (by overriding) the standard mappings.
type SpecialCase []CaseRange
// BUG(r): There is no mechanism for full case folding, that is, for
// characters that involve multiple runes in the input or output.
// Indices into the Delta arrays inside CaseRanges for case mapping.
const (
UpperCase = iota
LowerCase
TitleCase
MaxCase
)
type d [MaxCase]rune // to make the CaseRanges text shorter
// If the Delta field of a CaseRange is UpperLower, it means
// this CaseRange represents a sequence of the form (say)
// Upper Lower Upper Lower.
const (
UpperLower = MaxRune + 1 // (Cannot be a valid delta.)
)
// linearMax is the maximum table size for linear search of a non-Latin1 rune.
// Derived by running 'go test -calibrate'.
const linearMax = 18
// is16 reports whether r is in the sorted slice of 16-bit ranges.
func is16(ranges []Range16, r uint16) bool {
if len(ranges) <= linearMax || r <= MaxLatin1 {
for i := range ranges {
range_ := &ranges[i]
if r < range_.Lo {
return false
}
if r <= range_.Hi {
return range_.Stride == 1 || (r-range_.Lo)%range_.Stride == 0
}
}
return false
}
// binary search over ranges
lo := 0
hi := len(ranges)
for lo < hi {
m := lo + (hi-lo)/2
range_ := &ranges[m]
if range_.Lo <= r && r <= range_.Hi {
return range_.Stride == 1 || (r-range_.Lo)%range_.Stride == 0
}
if r < range_.Lo {
hi = m
} else {
lo = m + 1
}
}
return false
}
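// is16Sketch is an illustrative sketch, not part of the original source:
// Stride selects every n-th code point in [Lo, Hi]. With the synthetic
// range {Lo: 0x20, Hi: 0x28, Stride: 2}, only 0x20, 0x22, 0x24, 0x26 and
// 0x28 are members.
func is16Sketch() bool {
	ranges := []Range16{{Lo: 0x20, Hi: 0x28, Stride: 2}}
	return is16(ranges, 0x24) && !is16(ranges, 0x25)
}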
// is32 reports whether r is in the sorted slice of 32-bit ranges.
func is32(ranges []Range32, r uint32) bool {
if len(ranges) <= linearMax {
for i := range ranges {
range_ := &ranges[i]
if r < range_.Lo {
return false
}
if r <= range_.Hi {
return range_.Stride == 1 || (r-range_.Lo)%range_.Stride == 0
}
}
return false
}
// binary search over ranges
lo := 0
hi := len(ranges)
for lo < hi {
m := lo + (hi-lo)/2
range_ := ranges[m]
if range_.Lo <= r && r <= range_.Hi {
return range_.Stride == 1 || (r-range_.Lo)%range_.Stride == 0
}
if r < range_.Lo {
hi = m
} else {
lo = m + 1
}
}
return false
}
// Is reports whether the rune is in the specified table of ranges.
func Is(rangeTab *RangeTable, r rune) bool {
r16 := rangeTab.R16
// Compare as uint32 to correctly handle negative runes.
if len(r16) > 0 && uint32(r) <= uint32(r16[len(r16)-1].Hi) {
return is16(r16, uint16(r))
}
r32 := rangeTab.R32
if len(r32) > 0 && r >= rune(r32[0].Lo) {
return is32(r32, uint32(r))
}
return false
}
func isExcludingLatin(rangeTab *RangeTable, r rune) bool {
r16 := rangeTab.R16
// Compare as uint32 to correctly handle negative runes.
if off := rangeTab.LatinOffset; len(r16) > off && uint32(r) <= uint32(r16[len(r16)-1].Hi) {
return is16(r16[off:], uint16(r))
}
r32 := rangeTab.R32
if len(r32) > 0 && r >= rune(r32[0].Lo) {
return is32(r32, uint32(r))
}
return false
}
// IsUpper reports whether the rune is an upper case letter.
func IsUpper(r rune) bool {
// See comment in IsGraphic.
if uint32(r) <= MaxLatin1 {
return properties[uint8(r)]&pLmask == pLu
}
return isExcludingLatin(Upper, r)
}
// IsLower reports whether the rune is a lower case letter.
func IsLower(r rune) bool {
// See comment in IsGraphic.
if uint32(r) <= MaxLatin1 {
return properties[uint8(r)]&pLmask == pLl
}
return isExcludingLatin(Lower, r)
}
// IsTitle reports whether the rune is a title case letter.
func IsTitle(r rune) bool {
if r <= MaxLatin1 {
return false
}
return isExcludingLatin(Title, r)
}
// to maps the rune using the specified case mapping.
// It additionally reports whether caseRange contained a mapping for r.
func to(_case int, r rune, caseRange []CaseRange) (mappedRune rune, foundMapping bool) {
if _case < 0 || MaxCase <= _case {
return ReplacementChar, false // as reasonable an error as any
}
// binary search over ranges
lo := 0
hi := len(caseRange)
for lo < hi {
m := lo + (hi-lo)/2
cr := caseRange[m]
if rune(cr.Lo) <= r && r <= rune(cr.Hi) {
delta := cr.Delta[_case]
if delta > MaxRune {
// In an Upper-Lower sequence, which always starts with
// an UpperCase letter, the real deltas always look like:
// {0, 1, 0} UpperCase (Lower is next)
// {-1, 0, -1} LowerCase (Upper, Title are previous)
// The characters at even offsets from the beginning of the
// sequence are upper case; the ones at odd offsets are lower.
// The correct mapping can be done by clearing or setting the low
// bit in the sequence offset.
// The constants UpperCase and TitleCase are even while LowerCase
// is odd so we take the low bit from _case.
return rune(cr.Lo) + ((r-rune(cr.Lo))&^1 | rune(_case&1)), true
}
return r + delta, true
}
if r < rune(cr.Lo) {
hi = m
} else {
lo = m + 1
}
}
return r, false
}
// To maps the rune to the specified case: UpperCase, LowerCase, or TitleCase.
func To(_case int, r rune) rune {
r, _ = to(_case, r, CaseRanges)
return r
}
// ToUpper maps the rune to upper case.
func ToUpper(r rune) rune {
if r <= MaxASCII {
if 'a' <= r && r <= 'z' {
r -= 'a' - 'A'
}
return r
}
return To(UpperCase, r)
}
// ToLower maps the rune to lower case.
func ToLower(r rune) rune {
if r <= MaxASCII {
if 'A' <= r && r <= 'Z' {
r += 'a' - 'A'
}
return r
}
return To(LowerCase, r)
}
// ToTitle maps the rune to title case.
func ToTitle(r rune) rune {
if r <= MaxASCII {
if 'a' <= r && r <= 'z' { // title case is upper case for ASCII
r -= 'a' - 'A'
}
return r
}
return To(TitleCase, r)
}
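// caseSketch is an illustrative sketch, not part of the original source:
// the three simple mappings agree for ASCII but can differ elsewhere.
// U+01F3 'ǳ' maps to U+01F1 'Ǳ' in upper case but to U+01F2 'ǲ' in title
// case.
func caseSketch() bool {
	return ToUpper('ǳ') == 'Ǳ' && ToTitle('ǳ') == 'ǲ' && ToLower('Ǳ') == 'ǳ'
}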
// ToUpper maps the rune to upper case giving priority to the special mapping.
func (special SpecialCase) ToUpper(r rune) rune {
r1, hadMapping := to(UpperCase, r, []CaseRange(special))
if r1 == r && !hadMapping {
r1 = ToUpper(r)
}
return r1
}
// ToTitle maps the rune to title case giving priority to the special mapping.
func (special SpecialCase) ToTitle(r rune) rune {
r1, hadMapping := to(TitleCase, r, []CaseRange(special))
if r1 == r && !hadMapping {
r1 = ToTitle(r)
}
return r1
}
// ToLower maps the rune to lower case giving priority to the special mapping.
func (special SpecialCase) ToLower(r rune) rune {
r1, hadMapping := to(LowerCase, r, []CaseRange(special))
if r1 == r && !hadMapping {
r1 = ToLower(r)
}
return r1
}
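// specialCaseSketch is an illustrative sketch, not part of the original
// source, assuming the TurkishCase table defined alongside the package's
// case tables: under Turkish mappings the dotted/dotless i distinction is
// preserved, so 'i' upper-cases to 'İ' (U+0130) rather than 'I'.
func specialCaseSketch() bool {
	return TurkishCase.ToUpper('i') == 'İ'
}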
// caseOrbit is defined in tables.go as []foldPair. Right now all the
// entries fit in uint16, so use uint16. If that changes, compilation
// will fail (the constants in the composite literal will not fit in uint16)
// and the types here can change to uint32.
type foldPair struct {
From uint16
To uint16
}
// SimpleFold iterates over Unicode code points equivalent under
// the Unicode-defined simple case folding. Among the code points
// equivalent to rune (including rune itself), SimpleFold returns the
// smallest rune > r if one exists, or else the smallest rune >= 0.
// If r is not a valid Unicode code point, SimpleFold(r) returns r.
//
// For example:
//
// SimpleFold('A') = 'a'
// SimpleFold('a') = 'A'
//
// SimpleFold('K') = 'k'
// SimpleFold('k') = '\u212A' (Kelvin symbol, K)
// SimpleFold('\u212A') = 'K'
//
// SimpleFold('1') = '1'
//
// SimpleFold(-2) = -2
func SimpleFold(r rune) rune {
if r < 0 || r > MaxRune {
return r
}
if int(r) < len(asciiFold) {
return rune(asciiFold[r])
}
// Consult caseOrbit table for special cases.
lo := 0
hi := len(caseOrbit)
for lo < hi {
m := lo + (hi-lo)/2
if rune(caseOrbit[m].From) < r {
lo = m + 1
} else {
hi = m
}
}
if lo < len(caseOrbit) && rune(caseOrbit[lo].From) == r {
return rune(caseOrbit[lo].To)
}
// No folding specified. This is a one- or two-element
// equivalence class containing rune and ToLower(rune)
// and ToUpper(rune) if they are different from rune.
if l := ToLower(r); l != r {
return l
}
return ToUpper(r)
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package utf16 implements encoding and decoding of UTF-16 sequences.
package utf16
// The conditions replacementChar==unicode.ReplacementChar and
// maxRune==unicode.MaxRune are verified in the tests.
// Defining them locally avoids this package depending on package unicode.
const (
replacementChar = '\uFFFD' // Unicode replacement character
maxRune = '\U0010FFFF' // Maximum valid Unicode code point.
)
const (
// 0xd800-0xdc00 encodes the high 10 bits of a pair.
// 0xdc00-0xe000 encodes the low 10 bits of a pair.
// the value is those 20 bits plus 0x10000.
surr1 = 0xd800
surr2 = 0xdc00
surr3 = 0xe000
surrSelf = 0x10000
)
// IsSurrogate reports whether the specified Unicode code point
// can appear in a surrogate pair.
func IsSurrogate(r rune) bool {
return surr1 <= r && r < surr3
}
// DecodeRune returns the UTF-16 decoding of a surrogate pair.
// If the pair is not a valid UTF-16 surrogate pair, DecodeRune returns
// the Unicode replacement code point U+FFFD.
func DecodeRune(r1, r2 rune) rune {
if surr1 <= r1 && r1 < surr2 && surr2 <= r2 && r2 < surr3 {
return (r1-surr1)<<10 | (r2 - surr2) + surrSelf
}
return replacementChar
}
// EncodeRune returns the UTF-16 surrogate pair r1, r2 for the given rune.
// If the rune is not a valid Unicode code point or does not need encoding,
// EncodeRune returns U+FFFD, U+FFFD.
func EncodeRune(r rune) (r1, r2 rune) {
if r < surrSelf || r > maxRune {
return replacementChar, replacementChar
}
r -= surrSelf
return surr1 + (r>>10)&0x3ff, surr2 + r&0x3ff
}
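// encodeRuneSketch is an illustrative sketch, not part of the original
// source: U+1F600 (😀) is above surrSelf, so it becomes the surrogate pair
// D83D DE00, and DecodeRune reverses the mapping.
func encodeRuneSketch() bool {
	r1, r2 := EncodeRune(0x1F600)
	return r1 == 0xD83D && r2 == 0xDE00 && DecodeRune(r1, r2) == 0x1F600
}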
// Encode returns the UTF-16 encoding of the Unicode code point sequence s.
func Encode(s []rune) []uint16 {
n := len(s)
for _, v := range s {
if v >= surrSelf {
n++
}
}
a := make([]uint16, n)
n = 0
for _, v := range s {
switch {
case 0 <= v && v < surr1, surr3 <= v && v < surrSelf:
// normal rune
a[n] = uint16(v)
n++
case surrSelf <= v && v <= maxRune:
// needs surrogate sequence
r1, r2 := EncodeRune(v)
a[n] = uint16(r1)
a[n+1] = uint16(r2)
n += 2
default:
a[n] = uint16(replacementChar)
n++
}
}
return a[:n]
}
// AppendRune appends the UTF-16 encoding of the Unicode code point r
// to the end of p and returns the extended buffer. If the rune is not
// a valid Unicode code point, it appends the encoding of U+FFFD.
func AppendRune(a []uint16, r rune) []uint16 {
// This function is inlineable for fast handling of ASCII.
switch {
case 0 <= r && r < surr1, surr3 <= r && r < surrSelf:
// normal rune
return append(a, uint16(r))
case surrSelf <= r && r <= maxRune:
// needs surrogate sequence
r1, r2 := EncodeRune(r)
return append(a, uint16(r1), uint16(r2))
}
return append(a, replacementChar)
}
// Decode returns the Unicode code point sequence represented
// by the UTF-16 encoding s.
func Decode(s []uint16) []rune {
// Preallocate capacity to hold up to 64 runes.
// Decode inlines, so the allocation can live on the stack.
buf := make([]rune, 0, 64)
return decode(s, buf)
}
// decode appends to buf the Unicode code point sequence represented
// by the UTF-16 encoding s and returns the extended buffer.
func decode(s []uint16, buf []rune) []rune {
for i := 0; i < len(s); i++ {
var ar rune
switch r := s[i]; {
case r < surr1, surr3 <= r:
// normal rune
ar = rune(r)
case surr1 <= r && r < surr2 && i+1 < len(s) &&
surr2 <= s[i+1] && s[i+1] < surr3:
// valid surrogate sequence
ar = DecodeRune(rune(r), rune(s[i+1]))
i++
default:
// invalid surrogate sequence
ar = replacementChar
}
buf = append(buf, ar)
}
return buf
}
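// roundTripSketch is an illustrative sketch, not part of the original
// source: Encode and Decode are inverses for well-formed input, including
// runes that need surrogate pairs.
func roundTripSketch() bool {
	s := "héllo 😀"
	return string(Decode(Encode([]rune(s)))) == s
}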
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package utf8 implements functions and constants to support text encoded in
// UTF-8. It includes functions to translate between runes and UTF-8 byte sequences.
// See https://en.wikipedia.org/wiki/UTF-8
package utf8
// The conditions RuneError==unicode.ReplacementChar and
// MaxRune==unicode.MaxRune are verified in the tests.
// Defining them locally avoids this package depending on package unicode.
// Numbers fundamental to the encoding.
const (
RuneError = '\uFFFD' // the "error" Rune or "Unicode replacement character"
RuneSelf = 0x80 // characters below RuneSelf are represented as themselves in a single byte.
MaxRune = '\U0010FFFF' // Maximum valid Unicode code point.
UTFMax = 4 // maximum number of bytes of a UTF-8 encoded Unicode character.
)
// Code points in the surrogate range are not valid for UTF-8.
const (
surrogateMin = 0xD800
surrogateMax = 0xDFFF
)
const (
t1 = 0b00000000
tx = 0b10000000
t2 = 0b11000000
t3 = 0b11100000
t4 = 0b11110000
t5 = 0b11111000
maskx = 0b00111111
mask2 = 0b00011111
mask3 = 0b00001111
mask4 = 0b00000111
rune1Max = 1<<7 - 1
rune2Max = 1<<11 - 1
rune3Max = 1<<16 - 1
// The default lowest and highest continuation byte.
locb = 0b10000000
hicb = 0b10111111
// The names of these constants are chosen to give nice alignment in the
// table below. The first nibble is an index into acceptRanges or F for
// special one-byte cases. The second nibble is the Rune length or the
// Status for the special one-byte case.
xx = 0xF1 // invalid: size 1
as = 0xF0 // ASCII: size 1
s1 = 0x02 // accept 0, size 2
s2 = 0x13 // accept 1, size 3
s3 = 0x03 // accept 0, size 3
s4 = 0x23 // accept 2, size 3
s5 = 0x34 // accept 3, size 4
s6 = 0x04 // accept 0, size 4
s7 = 0x44 // accept 4, size 4
)
// first is information about the first byte in a UTF-8 sequence.
var first = [256]uint8{
// 1 2 3 4 5 6 7 8 9 A B C D E F
as, as, as, as, as, as, as, as, as, as, as, as, as, as, as, as, // 0x00-0x0F
as, as, as, as, as, as, as, as, as, as, as, as, as, as, as, as, // 0x10-0x1F
as, as, as, as, as, as, as, as, as, as, as, as, as, as, as, as, // 0x20-0x2F
as, as, as, as, as, as, as, as, as, as, as, as, as, as, as, as, // 0x30-0x3F
as, as, as, as, as, as, as, as, as, as, as, as, as, as, as, as, // 0x40-0x4F
as, as, as, as, as, as, as, as, as, as, as, as, as, as, as, as, // 0x50-0x5F
as, as, as, as, as, as, as, as, as, as, as, as, as, as, as, as, // 0x60-0x6F
as, as, as, as, as, as, as, as, as, as, as, as, as, as, as, as, // 0x70-0x7F
// 1 2 3 4 5 6 7 8 9 A B C D E F
xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, // 0x80-0x8F
xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, // 0x90-0x9F
xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, // 0xA0-0xAF
xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, // 0xB0-0xBF
xx, xx, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, // 0xC0-0xCF
s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, // 0xD0-0xDF
s2, s3, s3, s3, s3, s3, s3, s3, s3, s3, s3, s3, s3, s4, s3, s3, // 0xE0-0xEF
s5, s6, s6, s6, s7, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, // 0xF0-0xFF
}
// acceptRange gives the range of valid values for the second byte in a UTF-8
// sequence.
type acceptRange struct {
lo uint8 // lowest value for second byte.
hi uint8 // highest value for second byte.
}
// acceptRanges has size 16 to avoid bounds checks in the code that uses it.
var acceptRanges = [16]acceptRange{
0: {locb, hicb},
1: {0xA0, hicb},
2: {locb, 0x9F},
3: {0x90, hicb},
4: {locb, 0x8F},
}
// FullRune reports whether the bytes in p begin with a full UTF-8 encoding of a rune.
// An invalid encoding is considered a full Rune since it will convert as a width-1 error rune.
func FullRune(p []byte) bool {
n := len(p)
if n == 0 {
return false
}
x := first[p[0]]
if n >= int(x&7) {
return true // ASCII, invalid or valid.
}
// Must be short or invalid.
accept := acceptRanges[x>>4]
if n > 1 && (p[1] < accept.lo || accept.hi < p[1]) {
return true
} else if n > 2 && (p[2] < locb || hicb < p[2]) {
return true
}
return false
}
// FullRuneInString is like FullRune but its input is a string.
func FullRuneInString(s string) bool {
n := len(s)
if n == 0 {
return false
}
x := first[s[0]]
if n >= int(x&7) {
return true // ASCII, invalid, or valid.
}
// Must be short or invalid.
accept := acceptRanges[x>>4]
if n > 1 && (s[1] < accept.lo || accept.hi < s[1]) {
return true
} else if n > 2 && (s[2] < locb || hicb < s[2]) {
return true
}
return false
}
// DecodeRune unpacks the first UTF-8 encoding in p and returns the rune and
// its width in bytes. If p is empty it returns (RuneError, 0). Otherwise, if
// the encoding is invalid, it returns (RuneError, 1). Both are impossible
// results for correct, non-empty UTF-8.
//
// An encoding is invalid if it is incorrect UTF-8, encodes a rune that is
// out of range, or is not the shortest possible UTF-8 encoding for the
// value. No other validation is performed.
func DecodeRune(p []byte) (r rune, size int) {
n := len(p)
if n < 1 {
return RuneError, 0
}
p0 := p[0]
x := first[p0]
if x >= as {
// The following code simulates an additional check for x == xx and
// handling the ASCII and invalid cases accordingly. This mask-and-or
// approach prevents an additional branch.
mask := rune(x) << 31 >> 31 // Create 0x0000 or 0xFFFF.
return rune(p[0])&^mask | RuneError&mask, 1
}
sz := int(x & 7)
accept := acceptRanges[x>>4]
if n < sz {
return RuneError, 1
}
b1 := p[1]
if b1 < accept.lo || accept.hi < b1 {
return RuneError, 1
}
if sz <= 2 { // <= instead of == to help the compiler eliminate some bounds checks
return rune(p0&mask2)<<6 | rune(b1&maskx), 2
}
b2 := p[2]
if b2 < locb || hicb < b2 {
return RuneError, 1
}
if sz <= 3 {
return rune(p0&mask3)<<12 | rune(b1&maskx)<<6 | rune(b2&maskx), 3
}
b3 := p[3]
if b3 < locb || hicb < b3 {
return RuneError, 1
}
return rune(p0&mask4)<<18 | rune(b1&maskx)<<12 | rune(b2&maskx)<<6 | rune(b3&maskx), 4
}
// DecodeRuneInString is like DecodeRune but its input is a string. If s is
// empty it returns (RuneError, 0). Otherwise, if the encoding is invalid, it
// returns (RuneError, 1). Both are impossible results for correct, non-empty
// UTF-8.
//
// An encoding is invalid if it is incorrect UTF-8, encodes a rune that is
// out of range, or is not the shortest possible UTF-8 encoding for the
// value. No other validation is performed.
func DecodeRuneInString(s string) (r rune, size int) {
n := len(s)
if n < 1 {
return RuneError, 0
}
s0 := s[0]
x := first[s0]
if x >= as {
// The following code simulates an additional check for x == xx and
// handling the ASCII and invalid cases accordingly. This mask-and-or
// approach prevents an additional branch.
mask := rune(x) << 31 >> 31 // Create 0x0000 or 0xFFFF.
return rune(s[0])&^mask | RuneError&mask, 1
}
sz := int(x & 7)
accept := acceptRanges[x>>4]
if n < sz {
return RuneError, 1
}
s1 := s[1]
if s1 < accept.lo || accept.hi < s1 {
return RuneError, 1
}
if sz <= 2 { // <= instead of == to help the compiler eliminate some bounds checks
return rune(s0&mask2)<<6 | rune(s1&maskx), 2
}
s2 := s[2]
if s2 < locb || hicb < s2 {
return RuneError, 1
}
if sz <= 3 {
return rune(s0&mask3)<<12 | rune(s1&maskx)<<6 | rune(s2&maskx), 3
}
s3 := s[3]
if s3 < locb || hicb < s3 {
return RuneError, 1
}
return rune(s0&mask4)<<18 | rune(s1&maskx)<<12 | rune(s2&maskx)<<6 | rune(s3&maskx), 4
}
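// decodeSketch is an illustrative sketch, not part of the original source:
// a manual rune loop advances by the reported size, and the (RuneError, 1)
// convention means invalid bytes never stall the loop.
func decodeSketch(s string) int {
	n := 0
	for i := 0; i < len(s); {
		_, size := DecodeRuneInString(s[i:])
		i += size // size >= 1 for non-empty input
		n++
	}
	return n // rune count, with each invalid byte counted as one rune
}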
// DecodeLastRune unpacks the last UTF-8 encoding in p and returns the rune and
// its width in bytes. If p is empty it returns (RuneError, 0). Otherwise, if
// the encoding is invalid, it returns (RuneError, 1). Both are impossible
// results for correct, non-empty UTF-8.
//
// An encoding is invalid if it is incorrect UTF-8, encodes a rune that is
// out of range, or is not the shortest possible UTF-8 encoding for the
// value. No other validation is performed.
func DecodeLastRune(p []byte) (r rune, size int) {
end := len(p)
if end == 0 {
return RuneError, 0
}
start := end - 1
r = rune(p[start])
if r < RuneSelf {
return r, 1
}
// guard against O(n^2) behavior when traversing
// backwards through strings with long sequences of
// invalid UTF-8.
lim := end - UTFMax
if lim < 0 {
lim = 0
}
for start--; start >= lim; start-- {
if RuneStart(p[start]) {
break
}
}
if start < 0 {
start = 0
}
r, size = DecodeRune(p[start:end])
if start+size != end {
return RuneError, 1
}
return r, size
}
// DecodeLastRuneInString is like DecodeLastRune but its input is a string. If
// s is empty it returns (RuneError, 0). Otherwise, if the encoding is invalid,
// it returns (RuneError, 1). Both are impossible results for correct,
// non-empty UTF-8.
//
// An encoding is invalid if it is incorrect UTF-8, encodes a rune that is
// out of range, or is not the shortest possible UTF-8 encoding for the
// value. No other validation is performed.
func DecodeLastRuneInString(s string) (r rune, size int) {
end := len(s)
if end == 0 {
return RuneError, 0
}
start := end - 1
r = rune(s[start])
if r < RuneSelf {
return r, 1
}
// guard against O(n^2) behavior when traversing
// backwards through strings with long sequences of
// invalid UTF-8.
lim := end - UTFMax
if lim < 0 {
lim = 0
}
for start--; start >= lim; start-- {
if RuneStart(s[start]) {
break
}
}
if start < 0 {
start = 0
}
r, size = DecodeRuneInString(s[start:end])
if start+size != end {
return RuneError, 1
}
return r, size
}
// RuneLen returns the number of bytes required to encode the rune.
// It returns -1 if the rune is not a valid value to encode in UTF-8.
func RuneLen(r rune) int {
switch {
case r < 0:
return -1
case r <= rune1Max:
return 1
case r <= rune2Max:
return 2
case surrogateMin <= r && r <= surrogateMax:
return -1
case r <= rune3Max:
return 3
case r <= MaxRune:
return 4
}
return -1
}
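// runeLenSketch is an illustrative sketch, not part of the original source:
// the size classes follow the payload bits (7 bits fit in 1 byte, 11 in 2,
// 16 in 3, 21 in 4), while surrogate halves and out-of-range values report
// -1.
func runeLenSketch() bool {
	return RuneLen('A') == 1 && RuneLen('é') == 2 && RuneLen('€') == 3 &&
		RuneLen('😀') == 4 && RuneLen(0xD800) == -1
}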
// EncodeRune writes into p (which must be large enough) the UTF-8 encoding of the rune.
// If the rune is out of range, it writes the encoding of RuneError.
// It returns the number of bytes written.
func EncodeRune(p []byte, r rune) int {
// Negative values are erroneous. Making it unsigned addresses the problem.
switch i := uint32(r); {
case i <= rune1Max:
p[0] = byte(r)
return 1
case i <= rune2Max:
_ = p[1] // eliminate bounds checks
p[0] = t2 | byte(r>>6)
p[1] = tx | byte(r)&maskx
return 2
case i > MaxRune, surrogateMin <= i && i <= surrogateMax:
r = RuneError
fallthrough
case i <= rune3Max:
_ = p[2] // eliminate bounds checks
p[0] = t3 | byte(r>>12)
p[1] = tx | byte(r>>6)&maskx
p[2] = tx | byte(r)&maskx
return 3
default:
_ = p[3] // eliminate bounds checks
p[0] = t4 | byte(r>>18)
p[1] = tx | byte(r>>12)&maskx
p[2] = tx | byte(r>>6)&maskx
p[3] = tx | byte(r)&maskx
return 4
}
}
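// encodeSketch is an illustrative sketch, not part of the original source:
// 'é' (U+00E9) needs two bytes, 0b110_00011 0b10_101001, i.e. 0xC3 0xA9.
func encodeSketch() bool {
	var buf [UTFMax]byte
	n := EncodeRune(buf[:], 'é')
	return n == 2 && buf[0] == 0xC3 && buf[1] == 0xA9
}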
// AppendRune appends the UTF-8 encoding of r to the end of p and
// returns the extended buffer. If the rune is out of range,
// it appends the encoding of RuneError.
func AppendRune(p []byte, r rune) []byte {
// This function is inlineable for fast handling of ASCII.
if uint32(r) <= rune1Max {
return append(p, byte(r))
}
return appendRuneNonASCII(p, r)
}
func appendRuneNonASCII(p []byte, r rune) []byte {
// Negative values are erroneous. Making it unsigned addresses the problem.
switch i := uint32(r); {
case i <= rune2Max:
return append(p, t2|byte(r>>6), tx|byte(r)&maskx)
case i > MaxRune, surrogateMin <= i && i <= surrogateMax:
r = RuneError
fallthrough
case i <= rune3Max:
return append(p, t3|byte(r>>12), tx|byte(r>>6)&maskx, tx|byte(r)&maskx)
default:
return append(p, t4|byte(r>>18), tx|byte(r>>12)&maskx, tx|byte(r>>6)&maskx, tx|byte(r)&maskx)
}
}
// RuneCount returns the number of runes in p. Erroneous and short
// encodings are treated as single runes of width 1 byte.
func RuneCount(p []byte) int {
np := len(p)
var n int
for i := 0; i < np; {
n++
c := p[i]
if c < RuneSelf {
// ASCII fast path
i++
continue
}
x := first[c]
if x == xx {
i++ // invalid.
continue
}
size := int(x & 7)
if i+size > np {
i++ // Short or invalid.
continue
}
accept := acceptRanges[x>>4]
if c := p[i+1]; c < accept.lo || accept.hi < c {
size = 1
} else if size == 2 {
} else if c := p[i+2]; c < locb || hicb < c {
size = 1
} else if size == 3 {
} else if c := p[i+3]; c < locb || hicb < c {
size = 1
}
i += size
}
return n
}
// RuneCountInString is like RuneCount but its input is a string.
func RuneCountInString(s string) (n int) {
ns := len(s)
for i := 0; i < ns; n++ {
c := s[i]
if c < RuneSelf {
// ASCII fast path
i++
continue
}
x := first[c]
if x == xx {
i++ // invalid.
continue
}
size := int(x & 7)
if i+size > ns {
i++ // Short or invalid.
continue
}
accept := acceptRanges[x>>4]
if c := s[i+1]; c < accept.lo || accept.hi < c {
size = 1
} else if size == 2 {
} else if c := s[i+2]; c < locb || hicb < c {
size = 1
} else if size == 3 {
} else if c := s[i+3]; c < locb || hicb < c {
size = 1
}
i += size
}
return n
}
// RuneStart reports whether the byte could be the first byte of an encoded,
// possibly invalid rune. Second and subsequent bytes always have the top two
// bits set to 10.
func RuneStart(b byte) bool { return b&0xC0 != 0x80 }
// Valid reports whether p consists entirely of valid UTF-8-encoded runes.
func Valid(p []byte) bool {
// This optimization avoids the need to recompute the capacity
// when generating code for p[8:], bringing it to parity with
// ValidString, which was 20% faster on long ASCII strings.
p = p[:len(p):len(p)]
// Fast path. Check for and skip 8 bytes of ASCII characters per iteration.
for len(p) >= 8 {
// Combining two 32 bit loads allows the same code to be used
// for 32 and 64 bit platforms.
// The compiler can generate a 32bit load for first32 and second32
// on many platforms. See test/codegen/memcombine.go.
first32 := uint32(p[0]) | uint32(p[1])<<8 | uint32(p[2])<<16 | uint32(p[3])<<24
second32 := uint32(p[4]) | uint32(p[5])<<8 | uint32(p[6])<<16 | uint32(p[7])<<24
if (first32|second32)&0x80808080 != 0 {
// Found a non ASCII byte (>= RuneSelf).
break
}
p = p[8:]
}
n := len(p)
for i := 0; i < n; {
pi := p[i]
if pi < RuneSelf {
i++
continue
}
x := first[pi]
if x == xx {
return false // Illegal starter byte.
}
size := int(x & 7)
if i+size > n {
return false // Short or invalid.
}
accept := acceptRanges[x>>4]
if c := p[i+1]; c < accept.lo || accept.hi < c {
return false
} else if size == 2 {
} else if c := p[i+2]; c < locb || hicb < c {
return false
} else if size == 3 {
} else if c := p[i+3]; c < locb || hicb < c {
return false
}
i += size
}
return true
}
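// validSketch is an illustrative sketch, not part of the original source:
// an overlong encoding such as 0xC0 0x80 (a two-byte form of NUL) is
// rejected, because first[0xC0] is xx.
func validSketch() bool {
	return !Valid([]byte{0xC0, 0x80}) && Valid([]byte("héllo"))
}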
// ValidString reports whether s consists entirely of valid UTF-8-encoded runes.
func ValidString(s string) bool {
// Fast path. Check for and skip 8 bytes of ASCII characters per iteration.
for len(s) >= 8 {
// Combining two 32 bit loads allows the same code to be used
// for 32 and 64 bit platforms.
// The compiler can generate a 32bit load for first32 and second32
// on many platforms. See test/codegen/memcombine.go.
first32 := uint32(s[0]) | uint32(s[1])<<8 | uint32(s[2])<<16 | uint32(s[3])<<24
second32 := uint32(s[4]) | uint32(s[5])<<8 | uint32(s[6])<<16 | uint32(s[7])<<24
if (first32|second32)&0x80808080 != 0 {
// Found a non ASCII byte (>= RuneSelf).
break
}
s = s[8:]
}
n := len(s)
for i := 0; i < n; {
si := s[i]
if si < RuneSelf {
i++
continue
}
x := first[si]
if x == xx {
return false // Illegal starter byte.
}
size := int(x & 7)
if i+size > n {
return false // Short or invalid.
}
accept := acceptRanges[x>>4]
if c := s[i+1]; c < accept.lo || accept.hi < c {
return false
} else if size == 2 {
} else if c := s[i+2]; c < locb || hicb < c {
return false
} else if size == 3 {
} else if c := s[i+3]; c < locb || hicb < c {
return false
}
i += size
}
return true
}
// ValidRune reports whether r can be legally encoded as UTF-8.
// Code points that are out of range or a surrogate half are illegal.
func ValidRune(r rune) bool {
switch {
case 0 <= r && r < surrogateMin:
return true
case surrogateMax < r && r <= MaxRune:
return true
}
return false
}